Files
auto-virtual-tryon/app/workers/workflows/low_end_pipeline.py
2026-03-27 00:10:28 +08:00

153 lines
6.0 KiB
Python

"""Low-end image pipeline workflow."""
from datetime import timedelta
from temporalio import workflow
from temporalio.common import RetryPolicy
with workflow.unsafe.imports_passed_through():
from app.domain.enums import OrderStatus, WorkflowStepName
from app.infra.temporal.task_queues import (
IMAGE_PIPELINE_CONTROL_TASK_QUEUE,
IMAGE_PIPELINE_EXPORT_TASK_QUEUE,
IMAGE_PIPELINE_IMAGE_GEN_TASK_QUEUE,
IMAGE_PIPELINE_QC_TASK_QUEUE,
)
from app.workers.activities.export_activities import run_export_activity
from app.workers.activities.qc_activities import run_qc_activity
from app.workers.activities.review_activities import mark_workflow_failed_activity
from app.workers.activities.scene_activities import run_scene_activity
from app.workers.activities.tryon_activities import prepare_model_activity, run_tryon_activity
from app.workers.workflows.types import (
PipelineWorkflowInput,
StepActivityInput,
WorkflowFailureActivityInput,
)
# Start-to-close budget handed to every activity scheduled by this workflow.
ACTIVITY_TIMEOUT = timedelta(seconds=30)

# Retry settings shared by every activity: at most 3 attempts in total, with
# a 1 second initial retry interval and a backoff coefficient of 2.
ACTIVITY_RETRY_POLICY = RetryPolicy(
    maximum_attempts=3,
    initial_interval=timedelta(seconds=1),
    backoff_coefficient=2.0,
)
@workflow.defn
class LowEndPipelineWorkflow:
    """Low-end fully automated image pipeline.

    Runs five activities in order: prepare model, try-on, scene, QC and
    export. Each step receives the asset produced by an earlier step as its
    source. Failures are reported through ``mark_workflow_failed_activity``.
    """

    @workflow.run
    async def run(self, payload: PipelineWorkflowInput) -> dict[str, int | str | None]:
        """Execute the low-end workflow from start to finish.

        Args:
            payload: Order, workflow-run and asset identifiers for this run.

        Returns:
            A dict with the keys ``order_id``, ``status`` and
            ``final_asset_id``. When QC rejects the image, ``status`` is
            ``OrderStatus.FAILED.value`` and ``final_asset_id`` is ``None``;
            otherwise ``status`` is ``OrderStatus.SUCCEEDED.value`` and
            ``final_asset_id`` is the asset returned by the export step.

        Raises:
            Exception: Whatever a step raised, re-raised unchanged once
                recording the failure has been attempted.
        """
        # Tracks the step in flight so a failure can be attributed to it.
        current_step = WorkflowStepName.PREPARE_MODEL
        try:
            prepared = await workflow.execute_activity(
                prepare_model_activity,
                StepActivityInput(
                    order_id=payload.order_id,
                    workflow_run_id=payload.workflow_run_id,
                    step_name=WorkflowStepName.PREPARE_MODEL,
                    model_id=payload.model_id,
                    pose_id=payload.pose_id,
                    garment_asset_id=payload.garment_asset_id,
                    scene_ref_asset_id=payload.scene_ref_asset_id,
                ),
                task_queue=IMAGE_PIPELINE_CONTROL_TASK_QUEUE,
                start_to_close_timeout=ACTIVITY_TIMEOUT,
                retry_policy=ACTIVITY_RETRY_POLICY,
            )

            current_step = WorkflowStepName.TRYON
            tryon_result = await workflow.execute_activity(
                run_tryon_activity,
                StepActivityInput(
                    order_id=payload.order_id,
                    workflow_run_id=payload.workflow_run_id,
                    step_name=WorkflowStepName.TRYON,
                    source_asset_id=prepared.asset_id,
                    garment_asset_id=payload.garment_asset_id,
                ),
                task_queue=IMAGE_PIPELINE_IMAGE_GEN_TASK_QUEUE,
                start_to_close_timeout=ACTIVITY_TIMEOUT,
                retry_policy=ACTIVITY_RETRY_POLICY,
            )

            current_step = WorkflowStepName.SCENE
            scene_result = await workflow.execute_activity(
                run_scene_activity,
                StepActivityInput(
                    order_id=payload.order_id,
                    workflow_run_id=payload.workflow_run_id,
                    step_name=WorkflowStepName.SCENE,
                    source_asset_id=tryon_result.asset_id,
                    scene_ref_asset_id=payload.scene_ref_asset_id,
                ),
                task_queue=IMAGE_PIPELINE_IMAGE_GEN_TASK_QUEUE,
                start_to_close_timeout=ACTIVITY_TIMEOUT,
                retry_policy=ACTIVITY_RETRY_POLICY,
            )

            current_step = WorkflowStepName.QC
            qc_result = await workflow.execute_activity(
                run_qc_activity,
                StepActivityInput(
                    order_id=payload.order_id,
                    workflow_run_id=payload.workflow_run_id,
                    step_name=WorkflowStepName.QC,
                    source_asset_id=scene_result.asset_id,
                ),
                task_queue=IMAGE_PIPELINE_QC_TASK_QUEUE,
                start_to_close_timeout=ACTIVITY_TIMEOUT,
                retry_policy=ACTIVITY_RETRY_POLICY,
            )
            if not qc_result.passed:
                # A QC rejection is a normal (non-exceptional) outcome: record
                # it and finish with a FAILED status instead of raising.
                await self._mark_failed(payload, current_step, qc_result.message)
                return {
                    "order_id": payload.order_id,
                    "status": OrderStatus.FAILED.value,
                    "final_asset_id": None,
                }

            current_step = WorkflowStepName.EXPORT
            final_result = await workflow.execute_activity(
                run_export_activity,
                StepActivityInput(
                    order_id=payload.order_id,
                    workflow_run_id=payload.workflow_run_id,
                    step_name=WorkflowStepName.EXPORT,
                    # Export the first QC candidate; fall back to the scene
                    # output when QC supplied no candidates.
                    source_asset_id=(qc_result.candidate_asset_ids or [scene_result.asset_id])[0],
                ),
                task_queue=IMAGE_PIPELINE_EXPORT_TASK_QUEUE,
                start_to_close_timeout=ACTIVITY_TIMEOUT,
                retry_policy=ACTIVITY_RETRY_POLICY,
            )
            return {
                "order_id": payload.order_id,
                "status": OrderStatus.SUCCEEDED.value,
                "final_asset_id": final_result.asset_id,
            }
        except Exception as exc:
            # Recording the failure is best-effort at this point: if the
            # bookkeeping activity itself fails, log it and still surface the
            # original error rather than letting the secondary one replace it.
            try:
                await self._mark_failed(payload, current_step, str(exc))
            except Exception:
                workflow.logger.exception(
                    "Could not record failure of step %s for order %s",
                    current_step,
                    payload.order_id,
                )
            raise

    async def _mark_failed(
        self,
        payload: PipelineWorkflowInput,
        current_step: WorkflowStepName,
        message: str,
    ) -> None:
        """Persist workflow failure state.

        Args:
            payload: The workflow input; supplies the order and run ids.
            current_step: The step the failure is attributed to.
            message: Failure description passed on to the activity.

        Raises:
            Exception: Whatever ``mark_workflow_failed_activity`` raises once
                its retries are exhausted.
        """
        await workflow.execute_activity(
            mark_workflow_failed_activity,
            WorkflowFailureActivityInput(
                order_id=payload.order_id,
                workflow_run_id=payload.workflow_run_id,
                current_step=current_step,
                message=message,
            ),
            task_queue=IMAGE_PIPELINE_CONTROL_TASK_QUEUE,
            start_to_close_timeout=ACTIVITY_TIMEOUT,
            retry_policy=ACTIVITY_RETRY_POLICY,
        )