316 lines
14 KiB
Python
316 lines
14 KiB
Python
"""Mid-end image pipeline workflow with review signal support."""
|
|
|
|
from datetime import timedelta
|
|
|
|
from temporalio import workflow
|
|
from temporalio.common import RetryPolicy
|
|
|
|
with workflow.unsafe.imports_passed_through():
|
|
from app.domain.enums import OrderStatus, ReviewDecision, WorkflowStepName
|
|
from app.infra.temporal.task_queues import (
|
|
IMAGE_PIPELINE_CONTROL_TASK_QUEUE,
|
|
IMAGE_PIPELINE_EXPORT_TASK_QUEUE,
|
|
IMAGE_PIPELINE_IMAGE_GEN_TASK_QUEUE,
|
|
IMAGE_PIPELINE_POST_PROCESS_TASK_QUEUE,
|
|
IMAGE_PIPELINE_QC_TASK_QUEUE,
|
|
)
|
|
from app.workers.activities.export_activities import run_export_activity
|
|
from app.workers.activities.face_activities import run_face_activity
|
|
from app.workers.activities.fusion_activities import run_fusion_activity
|
|
from app.workers.activities.qc_activities import run_qc_activity
|
|
from app.workers.activities.review_activities import (
|
|
complete_review_wait_activity,
|
|
mark_waiting_for_review_activity,
|
|
mark_workflow_failed_activity,
|
|
)
|
|
from app.workers.activities.scene_activities import run_scene_activity
|
|
from app.workers.activities.texture_activities import run_texture_activity
|
|
from app.workers.activities.tryon_activities import prepare_model_activity, run_tryon_activity
|
|
from app.workers.workflows.types import (
|
|
MockActivityResult,
|
|
PipelineWorkflowInput,
|
|
ReviewResolutionActivityInput,
|
|
ReviewSignalPayload,
|
|
ReviewWaitActivityInput,
|
|
StepActivityInput,
|
|
WorkflowFailureActivityInput,
|
|
)
|
|
|
|
|
|
# Start-to-close limit applied to every activity attempt scheduled by this module.
ACTIVITY_TIMEOUT = timedelta(seconds=30)

# Retry policy shared by every activity in this module: at most 3 attempts in
# total, starting with a 1 second wait between attempts and doubling it on each
# further retry.
ACTIVITY_RETRY_POLICY = RetryPolicy(
    maximum_attempts=3,
    backoff_coefficient=2.0,
    initial_interval=timedelta(seconds=1),
)
|
|
|
|
|
|
@workflow.defn
class MidEndPipelineWorkflow:
    """Mid-end workflow that pauses for human review and supports reruns.

    Steps run in order: prepare model -> try-on -> scene -> texture -> face ->
    fusion -> QC. When QC passes, the workflow records that it is waiting for
    review, blocks until ``submit_review`` delivers a decision, and then either
    exports (APPROVE), stops (REJECT), or re-runs part of the pipeline followed
    by QC (RERUN_SCENE / RERUN_FACE / RERUN_FUSION) before asking for review
    again.
    """

    # Patch id guarding the "discard stale review decisions" behaviour.
    # ``workflow.patched`` keeps histories recorded before this change replaying
    # with the previous behaviour.
    _DISCARD_STALE_REVIEW_PATCH = "discard-stale-review-signal"

    def __init__(self) -> None:
        # Latest decision delivered by ``submit_review`` that the run loop has
        # not consumed yet.
        self._review_payload: ReviewSignalPayload | None = None

    @workflow.signal
    def submit_review(self, payload: ReviewSignalPayload) -> None:
        """Receive a review decision from the API layer.

        Only the latest unconsumed decision is kept: a newer signal overwrites
        an older one that the run loop has not picked up yet.
        """
        self._review_payload = payload

    @workflow.run
    async def run(self, payload: PipelineWorkflowInput) -> dict[str, int | str | None]:
        """Execute the mid-end workflow with a human review loop.

        Args:
            payload: Order / workflow-run identifiers and the input asset ids.

        Returns:
            A dict with ``order_id``, ``status`` and ``final_asset_id``.
            ``status`` is ``OrderStatus.SUCCEEDED.value`` together with the
            exported asset id on approval, or ``OrderStatus.FAILED.value``
            with ``final_asset_id=None`` when QC fails or the reviewer rejects.

        Raises:
            Exception: any exception raised by a step is re-raised after an
                attempt to record the failure via ``mark_workflow_failed_activity``.
        """
        current_step = WorkflowStepName.PREPARE_MODEL
        try:
            prepared = await workflow.execute_activity(
                prepare_model_activity,
                StepActivityInput(
                    order_id=payload.order_id,
                    workflow_run_id=payload.workflow_run_id,
                    step_name=WorkflowStepName.PREPARE_MODEL,
                    model_id=payload.model_id,
                    pose_id=payload.pose_id,
                    garment_asset_id=payload.garment_asset_id,
                    scene_ref_asset_id=payload.scene_ref_asset_id,
                ),
                task_queue=IMAGE_PIPELINE_CONTROL_TASK_QUEUE,
                start_to_close_timeout=ACTIVITY_TIMEOUT,
                retry_policy=ACTIVITY_RETRY_POLICY,
            )

            current_step = WorkflowStepName.TRYON
            tryon_result = await workflow.execute_activity(
                run_tryon_activity,
                StepActivityInput(
                    order_id=payload.order_id,
                    workflow_run_id=payload.workflow_run_id,
                    step_name=WorkflowStepName.TRYON,
                    source_asset_id=prepared.asset_id,
                    garment_asset_id=payload.garment_asset_id,
                ),
                task_queue=IMAGE_PIPELINE_IMAGE_GEN_TASK_QUEUE,
                start_to_close_timeout=ACTIVITY_TIMEOUT,
                retry_policy=ACTIVITY_RETRY_POLICY,
            )

            current_step = WorkflowStepName.SCENE
            scene_result = await self._run_scene(payload, tryon_result.asset_id)

            current_step = WorkflowStepName.TEXTURE
            texture_result = await self._run_texture(payload, scene_result.asset_id)

            current_step = WorkflowStepName.FACE
            face_result = await self._run_face(payload, texture_result.asset_id)

            current_step = WorkflowStepName.FUSION
            fusion_result = await self._run_fusion(payload, scene_result.asset_id, face_result.asset_id)

            # Each round: QC the current fusion output, then ask a human to review it.
            while True:
                current_step = WorkflowStepName.QC
                qc_result = await self._run_qc(payload, fusion_result.asset_id)
                if not qc_result.passed:
                    await self._mark_failed(payload, current_step, qc_result.message)
                    return self._failed_result(payload)

                current_step = WorkflowStepName.REVIEW
                review_payload = await self._request_review(payload, qc_result)
                decision = review_payload.decision

                if decision == ReviewDecision.APPROVE:
                    current_step = WorkflowStepName.EXPORT
                    export_source_id = review_payload.selected_asset_id
                    if export_source_id is None:
                        # No explicit pick: use the first QC candidate, or the
                        # fusion output when QC produced no candidates.
                        export_source_id = (qc_result.candidate_asset_ids or [fusion_result.asset_id])[0]
                    final_result = await self._run_export(payload, export_source_id)
                    return {
                        "order_id": payload.order_id,
                        "status": OrderStatus.SUCCEEDED.value,
                        "final_asset_id": final_result.asset_id,
                    }

                if decision == ReviewDecision.REJECT:
                    return self._failed_result(payload)

                # Reruns cascade: re-running a step also re-runs every later
                # step that consumes its output (texture <- scene,
                # face <- texture, fusion <- scene + face).
                if decision == ReviewDecision.RERUN_SCENE:
                    current_step = WorkflowStepName.SCENE
                    scene_result = await self._run_scene(payload, tryon_result.asset_id)
                    current_step = WorkflowStepName.TEXTURE
                    texture_result = await self._run_texture(payload, scene_result.asset_id)
                if decision in (ReviewDecision.RERUN_SCENE, ReviewDecision.RERUN_FACE):
                    current_step = WorkflowStepName.FACE
                    face_result = await self._run_face(payload, texture_result.asset_id)
                if decision in (
                    ReviewDecision.RERUN_SCENE,
                    ReviewDecision.RERUN_FACE,
                    ReviewDecision.RERUN_FUSION,
                ):
                    current_step = WorkflowStepName.FUSION
                    fusion_result = await self._run_fusion(payload, scene_result.asset_id, face_result.asset_id)
                # Any other decision leaves the outputs unchanged; the loop
                # repeats QC on the current fusion output and asks for review again.
        except Exception as exc:
            try:
                await self._mark_failed(payload, current_step, str(exc))
            except Exception:
                # Recording the failure is best effort: if it also fails, log
                # it and let the original error (not this one) fail the workflow.
                workflow.logger.exception("Failed to record failure of step %s", current_step)
            raise

    async def _request_review(
        self,
        payload: PipelineWorkflowInput,
        qc_result: MockActivityResult,
    ) -> ReviewSignalPayload:
        """Run one review round and return the reviewer's decision.

        Marks the run as waiting for review of ``qc_result.candidate_asset_ids``,
        blocks until ``submit_review`` delivers a decision, then records the
        resolution via ``complete_review_wait_activity``.
        """
        if workflow.patched(self._DISCARD_STALE_REVIEW_PATCH):
            # A decision delivered before this round started (e.g. a duplicate
            # submit that landed while a rerun was executing) was made about an
            # earlier candidate set. Drop it so it cannot be applied to
            # candidates nobody has reviewed.
            self._review_payload = None

        await workflow.execute_activity(
            mark_waiting_for_review_activity,
            ReviewWaitActivityInput(
                order_id=payload.order_id,
                workflow_run_id=payload.workflow_run_id,
                candidate_asset_ids=qc_result.candidate_asset_ids,
            ),
            task_queue=IMAGE_PIPELINE_CONTROL_TASK_QUEUE,
            start_to_close_timeout=ACTIVITY_TIMEOUT,
            retry_policy=ACTIVITY_RETRY_POLICY,
        )

        review_payload = await self._wait_for_review()

        await workflow.execute_activity(
            complete_review_wait_activity,
            ReviewResolutionActivityInput(
                order_id=payload.order_id,
                workflow_run_id=payload.workflow_run_id,
                decision=review_payload.decision,
                reviewer_id=review_payload.reviewer_id,
                selected_asset_id=review_payload.selected_asset_id,
                comment=review_payload.comment,
            ),
            task_queue=IMAGE_PIPELINE_CONTROL_TASK_QUEUE,
            start_to_close_timeout=ACTIVITY_TIMEOUT,
            retry_policy=ACTIVITY_RETRY_POLICY,
        )
        return review_payload

    async def _wait_for_review(self) -> ReviewSignalPayload:
        """Suspend until a review signal arrives, then consume and return it."""
        if self._review_payload is None:
            await workflow.wait_condition(lambda: self._review_payload is not None)
        assert self._review_payload is not None  # narrows the type for checkers
        review_payload = self._review_payload
        # Consume the decision so a later round waits for a fresh one.
        self._review_payload = None
        return review_payload

    @staticmethod
    def _failed_result(payload: PipelineWorkflowInput) -> dict[str, int | str | None]:
        """Build the run result returned when the order ends without an export."""
        return {"order_id": payload.order_id, "status": OrderStatus.FAILED.value, "final_asset_id": None}

    async def _run_scene(self, payload: PipelineWorkflowInput, source_asset_id: int | None) -> MockActivityResult:
        """Execute the scene activity on the image-gen queue, using the order's scene reference."""
        return await workflow.execute_activity(
            run_scene_activity,
            StepActivityInput(
                order_id=payload.order_id,
                workflow_run_id=payload.workflow_run_id,
                step_name=WorkflowStepName.SCENE,
                source_asset_id=source_asset_id,
                scene_ref_asset_id=payload.scene_ref_asset_id,
            ),
            task_queue=IMAGE_PIPELINE_IMAGE_GEN_TASK_QUEUE,
            start_to_close_timeout=ACTIVITY_TIMEOUT,
            retry_policy=ACTIVITY_RETRY_POLICY,
        )

    async def _run_texture(self, payload: PipelineWorkflowInput, source_asset_id: int | None) -> MockActivityResult:
        """Execute the texture activity on the post-process queue."""
        return await workflow.execute_activity(
            run_texture_activity,
            StepActivityInput(
                order_id=payload.order_id,
                workflow_run_id=payload.workflow_run_id,
                step_name=WorkflowStepName.TEXTURE,
                source_asset_id=source_asset_id,
            ),
            task_queue=IMAGE_PIPELINE_POST_PROCESS_TASK_QUEUE,
            start_to_close_timeout=ACTIVITY_TIMEOUT,
            retry_policy=ACTIVITY_RETRY_POLICY,
        )

    async def _run_face(self, payload: PipelineWorkflowInput, source_asset_id: int | None) -> MockActivityResult:
        """Execute the face activity on the post-process queue."""
        return await workflow.execute_activity(
            run_face_activity,
            StepActivityInput(
                order_id=payload.order_id,
                workflow_run_id=payload.workflow_run_id,
                step_name=WorkflowStepName.FACE,
                source_asset_id=source_asset_id,
            ),
            task_queue=IMAGE_PIPELINE_POST_PROCESS_TASK_QUEUE,
            start_to_close_timeout=ACTIVITY_TIMEOUT,
            retry_policy=ACTIVITY_RETRY_POLICY,
        )

    async def _run_fusion(
        self,
        payload: PipelineWorkflowInput,
        source_asset_id: int | None,
        face_asset_id: int | None,
    ) -> MockActivityResult:
        """Execute the fusion activity on the post-process queue.

        The face output is passed to the activity as ``selected_asset_id``.
        """
        return await workflow.execute_activity(
            run_fusion_activity,
            StepActivityInput(
                order_id=payload.order_id,
                workflow_run_id=payload.workflow_run_id,
                step_name=WorkflowStepName.FUSION,
                source_asset_id=source_asset_id,
                selected_asset_id=face_asset_id,
            ),
            task_queue=IMAGE_PIPELINE_POST_PROCESS_TASK_QUEUE,
            start_to_close_timeout=ACTIVITY_TIMEOUT,
            retry_policy=ACTIVITY_RETRY_POLICY,
        )

    async def _run_qc(self, payload: PipelineWorkflowInput, source_asset_id: int | None) -> MockActivityResult:
        """Execute the QC activity on the QC queue."""
        return await workflow.execute_activity(
            run_qc_activity,
            StepActivityInput(
                order_id=payload.order_id,
                workflow_run_id=payload.workflow_run_id,
                step_name=WorkflowStepName.QC,
                source_asset_id=source_asset_id,
            ),
            task_queue=IMAGE_PIPELINE_QC_TASK_QUEUE,
            start_to_close_timeout=ACTIVITY_TIMEOUT,
            retry_policy=ACTIVITY_RETRY_POLICY,
        )

    async def _run_export(self, payload: PipelineWorkflowInput, source_asset_id: int | None) -> MockActivityResult:
        """Execute the export activity on the export queue for the approved asset."""
        return await workflow.execute_activity(
            run_export_activity,
            StepActivityInput(
                order_id=payload.order_id,
                workflow_run_id=payload.workflow_run_id,
                step_name=WorkflowStepName.EXPORT,
                source_asset_id=source_asset_id,
            ),
            task_queue=IMAGE_PIPELINE_EXPORT_TASK_QUEUE,
            start_to_close_timeout=ACTIVITY_TIMEOUT,
            retry_policy=ACTIVITY_RETRY_POLICY,
        )

    async def _mark_failed(
        self,
        payload: PipelineWorkflowInput,
        current_step: WorkflowStepName,
        message: str,
    ) -> None:
        """Persist workflow failure state for ``current_step`` with ``message``."""
        await workflow.execute_activity(
            mark_workflow_failed_activity,
            WorkflowFailureActivityInput(
                order_id=payload.order_id,
                workflow_run_id=payload.workflow_run_id,
                current_step=current_step,
                message=message,
            ),
            task_queue=IMAGE_PIPELINE_CONTROL_TASK_QUEUE,
            start_to_close_timeout=ACTIVITY_TIMEOUT,
            retry_policy=ACTIVITY_RETRY_POLICY,
        )
|