153 lines
6.0 KiB
Python
153 lines
6.0 KiB
Python
"""Low-end image pipeline workflow."""
|
|
|
|
from datetime import timedelta
|
|
|
|
from temporalio import workflow
|
|
from temporalio.common import RetryPolicy
|
|
|
|
with workflow.unsafe.imports_passed_through():
|
|
from app.domain.enums import OrderStatus, WorkflowStepName
|
|
from app.infra.temporal.task_queues import (
|
|
IMAGE_PIPELINE_CONTROL_TASK_QUEUE,
|
|
IMAGE_PIPELINE_EXPORT_TASK_QUEUE,
|
|
IMAGE_PIPELINE_IMAGE_GEN_TASK_QUEUE,
|
|
IMAGE_PIPELINE_QC_TASK_QUEUE,
|
|
)
|
|
from app.workers.activities.export_activities import run_export_activity
|
|
from app.workers.activities.qc_activities import run_qc_activity
|
|
from app.workers.activities.review_activities import mark_workflow_failed_activity
|
|
from app.workers.activities.scene_activities import run_scene_activity
|
|
from app.workers.activities.tryon_activities import prepare_model_activity, run_tryon_activity
|
|
from app.workers.workflows.types import (
|
|
PipelineWorkflowInput,
|
|
StepActivityInput,
|
|
WorkflowFailureActivityInput,
|
|
)
|
|
|
|
|
|
# Start-to-close limit applied to each individual activity attempt in this workflow.
ACTIVITY_TIMEOUT = timedelta(seconds=30)

# Retry policy shared by every activity: at most 3 attempts in total, the first
# retry after 1 second and each subsequent wait doubled (backoff coefficient 2.0).
ACTIVITY_RETRY_POLICY = RetryPolicy(
    maximum_attempts=3,
    backoff_coefficient=2.0,
    initial_interval=timedelta(seconds=1),
)
|
|
|
|
|
|
@workflow.defn
class LowEndPipelineWorkflow:
    """Low-end fully automated image pipeline.

    Runs PREPARE_MODEL -> TRYON -> SCENE -> QC -> EXPORT as Temporal
    activities, feeding each step's output asset into the next step. Every
    activity uses ``ACTIVITY_TIMEOUT`` and ``ACTIVITY_RETRY_POLICY``. A QC
    rejection, or any exception raised while running the steps, is persisted
    through ``mark_workflow_failed_activity``.
    """

    @workflow.run
    async def run(self, payload: PipelineWorkflowInput) -> dict[str, int | str | None]:
        """Execute the low-end workflow from start to finish.

        Args:
            payload: Identifies the order and workflow run and carries the
                model, pose, garment and scene-reference asset ids that the
                steps consume.

        Returns:
            A dict with ``order_id``, ``status`` and ``final_asset_id``.
            On success ``status`` is ``OrderStatus.SUCCEEDED.value`` and
            ``final_asset_id`` is the exported asset's id. When QC rejects
            the result, ``status`` is ``OrderStatus.FAILED.value`` and
            ``final_asset_id`` is ``None``.

        Raises:
            Exception: Whatever a step raised, re-raised after the failure
                has been recorded against the step that was in progress.
        """
        current_step = WorkflowStepName.PREPARE_MODEL
        # Set once a failure is being recorded explicitly (QC rejection).
        # Without this guard, an error raised by that _mark_failed call would
        # be caught below and recorded a second time, with the bookkeeping
        # error as the message.
        failure_recorded = False
        try:
            prepared = await self._execute(
                prepare_model_activity,
                StepActivityInput(
                    order_id=payload.order_id,
                    workflow_run_id=payload.workflow_run_id,
                    step_name=WorkflowStepName.PREPARE_MODEL,
                    model_id=payload.model_id,
                    pose_id=payload.pose_id,
                    garment_asset_id=payload.garment_asset_id,
                    scene_ref_asset_id=payload.scene_ref_asset_id,
                ),
                IMAGE_PIPELINE_CONTROL_TASK_QUEUE,
            )

            current_step = WorkflowStepName.TRYON
            tryon_result = await self._execute(
                run_tryon_activity,
                StepActivityInput(
                    order_id=payload.order_id,
                    workflow_run_id=payload.workflow_run_id,
                    step_name=WorkflowStepName.TRYON,
                    source_asset_id=prepared.asset_id,
                    garment_asset_id=payload.garment_asset_id,
                ),
                IMAGE_PIPELINE_IMAGE_GEN_TASK_QUEUE,
            )

            current_step = WorkflowStepName.SCENE
            scene_result = await self._execute(
                run_scene_activity,
                StepActivityInput(
                    order_id=payload.order_id,
                    workflow_run_id=payload.workflow_run_id,
                    step_name=WorkflowStepName.SCENE,
                    source_asset_id=tryon_result.asset_id,
                    scene_ref_asset_id=payload.scene_ref_asset_id,
                ),
                IMAGE_PIPELINE_IMAGE_GEN_TASK_QUEUE,
            )

            current_step = WorkflowStepName.QC
            qc_result = await self._execute(
                run_qc_activity,
                StepActivityInput(
                    order_id=payload.order_id,
                    workflow_run_id=payload.workflow_run_id,
                    step_name=WorkflowStepName.QC,
                    source_asset_id=scene_result.asset_id,
                ),
                IMAGE_PIPELINE_QC_TASK_QUEUE,
            )

            if not qc_result.passed:
                failure_recorded = True
                await self._mark_failed(payload, current_step, qc_result.message)
                return {
                    "order_id": payload.order_id,
                    "status": OrderStatus.FAILED.value,
                    "final_asset_id": None,
                }

            current_step = WorkflowStepName.EXPORT
            final_result = await self._execute(
                run_export_activity,
                StepActivityInput(
                    order_id=payload.order_id,
                    workflow_run_id=payload.workflow_run_id,
                    step_name=WorkflowStepName.EXPORT,
                    # Export the first QC candidate; fall back to the scene
                    # output when QC returned no candidates.
                    source_asset_id=(qc_result.candidate_asset_ids or [scene_result.asset_id])[0],
                ),
                IMAGE_PIPELINE_EXPORT_TASK_QUEUE,
            )
            return {
                "order_id": payload.order_id,
                "status": OrderStatus.SUCCEEDED.value,
                "final_asset_id": final_result.asset_id,
            }
        except Exception as exc:
            if not failure_recorded:
                await self._mark_failed(payload, current_step, str(exc))
            raise

    async def _mark_failed(
        self,
        payload: PipelineWorkflowInput,
        current_step: WorkflowStepName,
        message: str,
    ) -> None:
        """Persist workflow failure state via ``mark_workflow_failed_activity``.

        Args:
            payload: Supplies the order and workflow-run ids to record against.
            current_step: Step that was running when the failure occurred.
            message: Failure reason to store.
        """
        await self._execute(
            mark_workflow_failed_activity,
            WorkflowFailureActivityInput(
                order_id=payload.order_id,
                workflow_run_id=payload.workflow_run_id,
                current_step=current_step,
                message=message,
            ),
            IMAGE_PIPELINE_CONTROL_TASK_QUEUE,
        )

    @staticmethod
    async def _execute(activity, activity_input, task_queue: str):
        """Run ``activity`` on ``task_queue`` with the shared timeout and retry policy.

        Returns whatever the activity returns.
        """
        return await workflow.execute_activity(
            activity,
            activity_input,
            task_queue=task_queue,
            start_to_close_timeout=ACTIVITY_TIMEOUT,
            retry_policy=ACTIVITY_RETRY_POLICY,
        )
|
|
|