Files
auto-virtual-tryon/app/workers/workflows/mid_end_pipeline.py
2026-03-27 00:10:28 +08:00

316 lines
14 KiB
Python

"""Mid-end image pipeline workflow with review signal support."""
from datetime import timedelta
from temporalio import workflow
from temporalio.common import RetryPolicy
with workflow.unsafe.imports_passed_through():
from app.domain.enums import OrderStatus, ReviewDecision, WorkflowStepName
from app.infra.temporal.task_queues import (
IMAGE_PIPELINE_CONTROL_TASK_QUEUE,
IMAGE_PIPELINE_EXPORT_TASK_QUEUE,
IMAGE_PIPELINE_IMAGE_GEN_TASK_QUEUE,
IMAGE_PIPELINE_POST_PROCESS_TASK_QUEUE,
IMAGE_PIPELINE_QC_TASK_QUEUE,
)
from app.workers.activities.export_activities import run_export_activity
from app.workers.activities.face_activities import run_face_activity
from app.workers.activities.fusion_activities import run_fusion_activity
from app.workers.activities.qc_activities import run_qc_activity
from app.workers.activities.review_activities import (
complete_review_wait_activity,
mark_waiting_for_review_activity,
mark_workflow_failed_activity,
)
from app.workers.activities.scene_activities import run_scene_activity
from app.workers.activities.texture_activities import run_texture_activity
from app.workers.activities.tryon_activities import prepare_model_activity, run_tryon_activity
from app.workers.workflows.types import (
MockActivityResult,
PipelineWorkflowInput,
ReviewResolutionActivityInput,
ReviewSignalPayload,
ReviewWaitActivityInput,
StepActivityInput,
WorkflowFailureActivityInput,
)
# Start-to-close timeout passed to every activity scheduled by this module.
ACTIVITY_TIMEOUT = timedelta(seconds=30)
# Retry policy shared by every activity scheduled by this module: the retry
# interval starts at 1 second, is multiplied by 2.0 after each retry, and
# maximum_attempts is capped at 3.
ACTIVITY_RETRY_POLICY = RetryPolicy(
initial_interval=timedelta(seconds=1),
backoff_coefficient=2.0,
maximum_attempts=3,
)
@workflow.defn
class MidEndPipelineWorkflow:
    """Mid-end workflow that pauses for human review and supports reruns.

    The run executes the prepare-model, try-on, scene, texture, face and
    fusion activities in order, then loops: run QC, publish the candidates for
    review, wait for a ``submit_review`` signal, and either export, stop, or
    rerun part of the chain before going through QC and review again.
    """

    def __init__(self) -> None:
        # Review decision delivered by signal and not yet consumed by the run
        # loop; ``None`` while nothing is pending.
        self._review_payload: ReviewSignalPayload | None = None

    @workflow.signal
    def submit_review(self, payload: ReviewSignalPayload) -> None:
        """Receive a review decision from the API layer.

        The payload replaces any previously stored decision that the run loop
        has not consumed yet.
        """
        self._review_payload = payload

    @workflow.run
    async def run(self, payload: PipelineWorkflowInput) -> dict[str, int | str | None]:
        """Execute the mid-end workflow with a human review loop.

        Args:
            payload: Order, workflow-run and asset identifiers for this run.

        Returns:
            A dict with ``order_id``, ``status`` (an ``OrderStatus`` value) and
            ``final_asset_id`` (``None`` unless the export step completed).

        Raises:
            Exception: Any error raised by a step is re-raised unchanged after
                a best-effort attempt to persist the failure state.
        """
        # Tracks the step being executed so a failure can be attributed to it.
        current_step = WorkflowStepName.PREPARE_MODEL
        try:
            prepared = await workflow.execute_activity(
                prepare_model_activity,
                StepActivityInput(
                    order_id=payload.order_id,
                    workflow_run_id=payload.workflow_run_id,
                    step_name=WorkflowStepName.PREPARE_MODEL,
                    model_id=payload.model_id,
                    pose_id=payload.pose_id,
                    garment_asset_id=payload.garment_asset_id,
                    scene_ref_asset_id=payload.scene_ref_asset_id,
                ),
                task_queue=IMAGE_PIPELINE_CONTROL_TASK_QUEUE,
                start_to_close_timeout=ACTIVITY_TIMEOUT,
                retry_policy=ACTIVITY_RETRY_POLICY,
            )

            current_step = WorkflowStepName.TRYON
            tryon_result = await workflow.execute_activity(
                run_tryon_activity,
                StepActivityInput(
                    order_id=payload.order_id,
                    workflow_run_id=payload.workflow_run_id,
                    step_name=WorkflowStepName.TRYON,
                    source_asset_id=prepared.asset_id,
                    garment_asset_id=payload.garment_asset_id,
                ),
                task_queue=IMAGE_PIPELINE_IMAGE_GEN_TASK_QUEUE,
                start_to_close_timeout=ACTIVITY_TIMEOUT,
                retry_policy=ACTIVITY_RETRY_POLICY,
            )

            current_step = WorkflowStepName.SCENE
            scene_result = await self._run_scene(payload, tryon_result.asset_id)
            current_step = WorkflowStepName.TEXTURE
            texture_result = await self._run_texture(payload, scene_result.asset_id)
            current_step = WorkflowStepName.FACE
            face_result = await self._run_face(payload, texture_result.asset_id)
            current_step = WorkflowStepName.FUSION
            fusion_result = await self._run_fusion(payload, scene_result.asset_id, face_result.asset_id)

            # Each iteration: QC the current fusion output, ask for a review,
            # then either finish or rerun part of the chain and go around again.
            while True:
                current_step = WorkflowStepName.QC
                qc_result = await self._run_qc(payload, fusion_result.asset_id)
                if not qc_result.passed:
                    await self._mark_failed(payload, current_step, qc_result.message)
                    return self._failed_result(payload)

                current_step = WorkflowStepName.REVIEW
                review_payload = await self._collect_review(payload, qc_result.candidate_asset_ids)
                decision = review_payload.decision

                if decision == ReviewDecision.APPROVE:
                    current_step = WorkflowStepName.EXPORT
                    export_source_id = review_payload.selected_asset_id
                    if export_source_id is None:
                        # No explicit pick: use the first QC candidate, or the
                        # fusion output when QC reported no candidates.
                        export_source_id = (qc_result.candidate_asset_ids or [fusion_result.asset_id])[0]
                    final_result = await workflow.execute_activity(
                        run_export_activity,
                        StepActivityInput(
                            order_id=payload.order_id,
                            workflow_run_id=payload.workflow_run_id,
                            step_name=WorkflowStepName.EXPORT,
                            source_asset_id=export_source_id,
                        ),
                        task_queue=IMAGE_PIPELINE_EXPORT_TASK_QUEUE,
                        start_to_close_timeout=ACTIVITY_TIMEOUT,
                        retry_policy=ACTIVITY_RETRY_POLICY,
                    )
                    return {
                        "order_id": payload.order_id,
                        "status": OrderStatus.SUCCEEDED.value,
                        "final_asset_id": final_result.asset_id,
                    }

                if decision == ReviewDecision.REJECT:
                    return self._failed_result(payload)

                # Rerun decisions cascade: restarting at an earlier step also
                # reruns every later step up to fusion. Any other decision
                # reruns nothing and simply repeats QC and review.
                rerun_scene = decision == ReviewDecision.RERUN_SCENE
                rerun_face = rerun_scene or decision == ReviewDecision.RERUN_FACE
                rerun_fusion = rerun_face or decision == ReviewDecision.RERUN_FUSION

                if rerun_scene:
                    current_step = WorkflowStepName.SCENE
                    scene_result = await self._run_scene(payload, tryon_result.asset_id)
                    current_step = WorkflowStepName.TEXTURE
                    texture_result = await self._run_texture(payload, scene_result.asset_id)
                if rerun_face:
                    current_step = WorkflowStepName.FACE
                    face_result = await self._run_face(payload, texture_result.asset_id)
                if rerun_fusion:
                    current_step = WorkflowStepName.FUSION
                    fusion_result = await self._run_fusion(payload, scene_result.asset_id, face_result.asset_id)
        except Exception as exc:
            # Persisting the failure is best-effort: if that activity fails as
            # well, log it and still surface the original error to the caller.
            try:
                await self._mark_failed(payload, current_step, self._describe_failure(exc))
            except Exception:
                workflow.logger.exception(
                    "Could not persist failure state for order %s at step %s",
                    payload.order_id,
                    current_step,
                )
            raise

    @staticmethod
    def _failed_result(payload: PipelineWorkflowInput) -> dict[str, int | str | None]:
        """Build the result returned when the order ends in the failed state."""
        return {"order_id": payload.order_id, "status": OrderStatus.FAILED.value, "final_asset_id": None}

    @staticmethod
    def _describe_failure(exc: BaseException) -> str:
        """Return the exception message followed by the messages of its causes.

        Activity failures typically surface as a wrapper exception whose own
        message is generic, with the underlying error attached as
        ``__cause__``; walking the chain keeps that detail in the stored text.
        """
        messages: list[str] = []
        visited: set[int] = set()
        link: BaseException | None = exc
        # ``visited`` guards against a cyclic cause chain.
        while link is not None and id(link) not in visited:
            visited.add(id(link))
            text = str(link) or type(link).__name__
            if text not in messages:
                messages.append(text)
            link = link.__cause__
        return ": ".join(messages)

    async def _collect_review(
        self,
        payload: PipelineWorkflowInput,
        candidate_asset_ids: list[int] | None,
    ) -> ReviewSignalPayload:
        """Publish candidates for review, wait for the decision and record it."""
        await workflow.execute_activity(
            mark_waiting_for_review_activity,
            ReviewWaitActivityInput(
                order_id=payload.order_id,
                workflow_run_id=payload.workflow_run_id,
                candidate_asset_ids=candidate_asset_ids,
            ),
            task_queue=IMAGE_PIPELINE_CONTROL_TASK_QUEUE,
            start_to_close_timeout=ACTIVITY_TIMEOUT,
            retry_policy=ACTIVITY_RETRY_POLICY,
        )
        review_payload = await self._wait_for_review()
        await workflow.execute_activity(
            complete_review_wait_activity,
            ReviewResolutionActivityInput(
                order_id=payload.order_id,
                workflow_run_id=payload.workflow_run_id,
                decision=review_payload.decision,
                reviewer_id=review_payload.reviewer_id,
                selected_asset_id=review_payload.selected_asset_id,
                comment=review_payload.comment,
            ),
            task_queue=IMAGE_PIPELINE_CONTROL_TASK_QUEUE,
            start_to_close_timeout=ACTIVITY_TIMEOUT,
            retry_policy=ACTIVITY_RETRY_POLICY,
        )
        return review_payload

    async def _wait_for_review(self) -> ReviewSignalPayload:
        """Suspend the workflow until a review signal arrives.

        A decision that was already stored before this call is returned
        without waiting. The stored decision is cleared before returning so
        the next review round waits for a fresh signal.
        """
        if self._review_payload is None:
            await workflow.wait_condition(lambda: self._review_payload is not None)
        assert self._review_payload is not None
        review_payload = self._review_payload
        self._review_payload = None
        return review_payload

    async def _run_scene(self, payload: PipelineWorkflowInput, source_asset_id: int | None) -> MockActivityResult:
        """Execute the scene activity on the image-generation task queue."""
        return await workflow.execute_activity(
            run_scene_activity,
            StepActivityInput(
                order_id=payload.order_id,
                workflow_run_id=payload.workflow_run_id,
                step_name=WorkflowStepName.SCENE,
                source_asset_id=source_asset_id,
                scene_ref_asset_id=payload.scene_ref_asset_id,
            ),
            task_queue=IMAGE_PIPELINE_IMAGE_GEN_TASK_QUEUE,
            start_to_close_timeout=ACTIVITY_TIMEOUT,
            retry_policy=ACTIVITY_RETRY_POLICY,
        )

    async def _run_texture(self, payload: PipelineWorkflowInput, source_asset_id: int | None) -> MockActivityResult:
        """Execute the texture activity on the post-process task queue."""
        return await workflow.execute_activity(
            run_texture_activity,
            StepActivityInput(
                order_id=payload.order_id,
                workflow_run_id=payload.workflow_run_id,
                step_name=WorkflowStepName.TEXTURE,
                source_asset_id=source_asset_id,
            ),
            task_queue=IMAGE_PIPELINE_POST_PROCESS_TASK_QUEUE,
            start_to_close_timeout=ACTIVITY_TIMEOUT,
            retry_policy=ACTIVITY_RETRY_POLICY,
        )

    async def _run_face(self, payload: PipelineWorkflowInput, source_asset_id: int | None) -> MockActivityResult:
        """Execute the face activity on the post-process task queue."""
        return await workflow.execute_activity(
            run_face_activity,
            StepActivityInput(
                order_id=payload.order_id,
                workflow_run_id=payload.workflow_run_id,
                step_name=WorkflowStepName.FACE,
                source_asset_id=source_asset_id,
            ),
            task_queue=IMAGE_PIPELINE_POST_PROCESS_TASK_QUEUE,
            start_to_close_timeout=ACTIVITY_TIMEOUT,
            retry_policy=ACTIVITY_RETRY_POLICY,
        )

    async def _run_fusion(
        self,
        payload: PipelineWorkflowInput,
        source_asset_id: int | None,
        face_asset_id: int | None,
    ) -> MockActivityResult:
        """Execute the fusion activity on the post-process task queue.

        The face asset is passed to the activity through the input's
        ``selected_asset_id`` field.
        """
        return await workflow.execute_activity(
            run_fusion_activity,
            StepActivityInput(
                order_id=payload.order_id,
                workflow_run_id=payload.workflow_run_id,
                step_name=WorkflowStepName.FUSION,
                source_asset_id=source_asset_id,
                selected_asset_id=face_asset_id,
            ),
            task_queue=IMAGE_PIPELINE_POST_PROCESS_TASK_QUEUE,
            start_to_close_timeout=ACTIVITY_TIMEOUT,
            retry_policy=ACTIVITY_RETRY_POLICY,
        )

    async def _run_qc(self, payload: PipelineWorkflowInput, source_asset_id: int | None) -> MockActivityResult:
        """Execute the QC activity on the QC task queue."""
        return await workflow.execute_activity(
            run_qc_activity,
            StepActivityInput(
                order_id=payload.order_id,
                workflow_run_id=payload.workflow_run_id,
                step_name=WorkflowStepName.QC,
                source_asset_id=source_asset_id,
            ),
            task_queue=IMAGE_PIPELINE_QC_TASK_QUEUE,
            start_to_close_timeout=ACTIVITY_TIMEOUT,
            retry_policy=ACTIVITY_RETRY_POLICY,
        )

    async def _mark_failed(
        self,
        payload: PipelineWorkflowInput,
        current_step: WorkflowStepName,
        message: str,
    ) -> None:
        """Persist workflow failure state.

        Args:
            payload: Workflow input identifying the order and workflow run.
            current_step: Step that was executing when the failure occurred.
            message: Failure description passed to the activity.
        """
        await workflow.execute_activity(
            mark_workflow_failed_activity,
            WorkflowFailureActivityInput(
                order_id=payload.order_id,
                workflow_run_id=payload.workflow_run_id,
                current_step=current_step,
                message=message,
            ),
            task_queue=IMAGE_PIPELINE_CONTROL_TASK_QUEUE,
            start_to_close_timeout=ACTIVITY_TIMEOUT,
            retry_policy=ACTIVITY_RETRY_POLICY,
        )