118 lines
4.8 KiB
Python
118 lines
4.8 KiB
Python
"""Review state management mock activities."""
|
|
|
|
from sqlalchemy import select
|
|
|
|
from temporalio import activity
|
|
|
|
from app.domain.enums import OrderStatus, ReviewDecision, ReviewTaskStatus, StepStatus, WorkflowStepName
|
|
from app.infra.db.models.review_task import ReviewTaskORM
|
|
from app.infra.db.models.workflow_step import WorkflowStepORM
|
|
from app.infra.db.session import get_session_factory
|
|
from app.workers.activities.tryon_activities import jsonable, load_order_and_run, utc_now
|
|
from app.workers.workflows.types import (
|
|
ReviewResolutionActivityInput,
|
|
ReviewWaitActivityInput,
|
|
WorkflowFailureActivityInput,
|
|
)
|
|
|
|
|
|
@activity.defn
async def mark_waiting_for_review_activity(payload: ReviewWaitActivityInput) -> None:
    """Mark a workflow run as waiting for a human review.

    Within a single transaction this:

    * records a ``REVIEW`` workflow step in ``WAITING`` state, using the
      serialized payload as its input,
    * opens a ``PENDING`` review task for the order, pre-selecting the first
      candidate asset id (or ``None`` when there are no candidates), and
    * moves both the order and the workflow run to ``WAITING_REVIEW`` with the
      run's ``current_step`` set to ``REVIEW``.

    Temporal may execute an activity more than once. For example, the worker
    can die after the commit but before the result is reported. The step and
    task rows are therefore only inserted when the run has no ``WAITING``
    review step yet. A retry re-applies the status fields without creating
    duplicate rows.

    Args:
        payload: Identifies the order and workflow run and carries the
            candidate asset ids and the comment stored on the review task.
    """
    async with get_session_factory()() as session:
        order, workflow_run = await load_order_and_run(session, payload.order_id, payload.workflow_run_id)

        # A WAITING review step for this run means an earlier attempt already
        # committed both rows (they share one transaction). A review wait that
        # was resolved through complete_review_wait_activity is no longer
        # WAITING, so a later, legitimate wait in the same run is not blocked.
        existing_result = await session.execute(
            select(WorkflowStepORM.id)
            .where(
                WorkflowStepORM.workflow_run_id == payload.workflow_run_id,
                WorkflowStepORM.step_name == WorkflowStepName.REVIEW,
                WorkflowStepORM.step_status == StepStatus.WAITING,
            )
            .limit(1)
        )
        already_waiting = existing_result.scalars().first() is not None

        if not already_waiting:
            session.add(
                WorkflowStepORM(
                    workflow_run_id=payload.workflow_run_id,
                    step_name=WorkflowStepName.REVIEW,
                    step_status=StepStatus.WAITING,
                    input_json=jsonable(payload),
                    started_at=utc_now(),
                )
            )
            session.add(
                ReviewTaskORM(
                    order_id=payload.order_id,
                    status=ReviewTaskStatus.PENDING,
                    selected_asset_id=payload.candidate_asset_ids[0] if payload.candidate_asset_ids else None,
                    comment=payload.comment,
                )
            )

        # Status updates are idempotent, so they are applied on every attempt.
        order.status = OrderStatus.WAITING_REVIEW
        workflow_run.status = OrderStatus.WAITING_REVIEW
        workflow_run.current_step = WorkflowStepName.REVIEW
        await session.commit()
|
|
|
|
|
|
@activity.defn
async def complete_review_wait_activity(payload: ReviewResolutionActivityInput) -> None:
    """Close out the pending review step once a review decision arrives.

    The function looks up the run's most recent ``REVIEW`` step that is still
    ``WAITING``: newest ``started_at`` first, with ties broken by the highest
    ``id``. If one exists, it is finalized:

    * a ``REJECT`` decision marks it ``FAILED`` and stores the payload comment
      as its error message;
    * any other decision marks it ``SUCCEEDED`` and clears the error message.

    In both cases the serialized payload becomes the step output and
    ``ended_at`` is stamped.

    Whether or not a step was found, the order and the workflow run move to
    ``FAILED`` on rejection and back to ``RUNNING`` otherwise. The run's
    ``current_step`` is (re)set to ``REVIEW``. All changes are committed in a
    single transaction.

    Args:
        payload: Identifies the order and workflow run and carries the review
            decision and comment.
    """
    async with get_session_factory()() as session:
        order, workflow_run = await load_order_and_run(session, payload.order_id, payload.workflow_run_id)
        rejected = payload.decision == ReviewDecision.REJECT

        pending_review_query = (
            select(WorkflowStepORM)
            .where(
                WorkflowStepORM.workflow_run_id == payload.workflow_run_id,
                WorkflowStepORM.step_name == WorkflowStepName.REVIEW,
                WorkflowStepORM.step_status == StepStatus.WAITING,
            )
            .order_by(WorkflowStepORM.started_at.desc(), WorkflowStepORM.id.desc())
        )
        pending_step = (await session.execute(pending_review_query)).scalars().first()

        if pending_step is not None:
            if rejected:
                pending_step.step_status = StepStatus.FAILED
                pending_step.output_json = jsonable(payload)
                pending_step.error_message = payload.comment
            else:
                pending_step.step_status = StepStatus.SUCCEEDED
                pending_step.output_json = jsonable(payload)
                pending_step.error_message = None
            pending_step.ended_at = utc_now()

        next_status = OrderStatus.FAILED if rejected else OrderStatus.RUNNING
        order.status = next_status
        workflow_run.status = next_status
        workflow_run.current_step = WorkflowStepName.REVIEW
        await session.commit()
|
|
|
|
|
|
@activity.defn
async def mark_workflow_failed_activity(payload: WorkflowFailureActivityInput) -> None:
    """Persist a failure for the workflow run and the step it failed on.

    The function reuses the most recent step row named
    ``payload.current_step``, whatever its current status: newest
    ``started_at`` first, with ties broken by the highest ``id``. When the run
    has no such row, a new one is added with the serialized payload as its
    input.

    Either way the step ends up ``FAILED``, with:

    * ``payload.message`` as its error message,
    * an output of ``{"message": ..., "status": ...}``, and
    * an ``ended_at`` timestamp that is only filled in when not already set.

    The order and the workflow run both take ``payload.status``, and the run's
    ``current_step`` becomes ``payload.current_step``. Everything is committed
    together.

    Args:
        payload: Identifies the order and workflow run, the step that failed,
            the failure message and the status to record. ``payload.status``
            must expose ``.value``, since that value is serialized into the
            step output.
    """
    async with get_session_factory()() as session:
        order, workflow_run = await load_order_and_run(session, payload.order_id, payload.workflow_run_id)

        latest_step_query = (
            select(WorkflowStepORM)
            .where(
                WorkflowStepORM.workflow_run_id == payload.workflow_run_id,
                WorkflowStepORM.step_name == payload.current_step,
            )
            .order_by(WorkflowStepORM.started_at.desc(), WorkflowStepORM.id.desc())
        )
        failed_step = (await session.execute(latest_step_query)).scalars().first()

        if failed_step is None:
            # Nothing was recorded for this step yet; create a row so the failure is visible.
            failed_step = WorkflowStepORM(
                workflow_run_id=payload.workflow_run_id,
                step_name=payload.current_step,
                step_status=StepStatus.FAILED,
                input_json=jsonable(payload),
                started_at=utc_now(),
            )
            session.add(failed_step)

        failure_output = {"message": payload.message, "status": payload.status.value}
        failed_step.step_status = StepStatus.FAILED
        failed_step.error_message = payload.message
        failed_step.output_json = jsonable(failure_output)
        # Keep an existing end timestamp; only stamp one if it is missing.
        failed_step.ended_at = failed_step.ended_at or utc_now()

        final_status = payload.status
        order.status = final_status
        workflow_run.status = final_status
        workflow_run.current_step = payload.current_step
        await session.commit()
|