Implement FastAPI Temporal MVP pipeline

This commit is contained in:
Codex
2026-03-27 00:10:28 +08:00
commit cc03da8a94
52 changed files with 3619 additions and 0 deletions

View File

@@ -0,0 +1,20 @@
"""Export mock activity."""
from temporalio import activity
from app.domain.enums import AssetType
from app.workers.activities.tryon_activities import execute_asset_step
from app.workers.workflows.types import MockActivityResult, StepActivityInput
@activity.defn
async def run_export_activity(payload: StepActivityInput) -> MockActivityResult:
"""Mock final asset export."""
return await execute_asset_step(
payload,
AssetType.FINAL,
filename="final.png",
finalize=True,
)

View File

@@ -0,0 +1,15 @@
"""Face mock activity."""
from temporalio import activity
from app.domain.enums import AssetType
from app.workers.activities.tryon_activities import execute_asset_step
from app.workers.workflows.types import MockActivityResult, StepActivityInput
@activity.defn
async def run_face_activity(payload: StepActivityInput) -> MockActivityResult:
"""Mock face enhancement."""
return await execute_asset_step(payload, AssetType.FACE)

View File

@@ -0,0 +1,19 @@
"""Fusion mock activity."""
from temporalio import activity
from app.domain.enums import AssetType
from app.workers.activities.tryon_activities import execute_asset_step
from app.workers.workflows.types import MockActivityResult, StepActivityInput
@activity.defn
async def run_fusion_activity(payload: StepActivityInput) -> MockActivityResult:
"""Mock face and body fusion."""
return await execute_asset_step(
payload,
AssetType.FUSION,
extra_metadata={"face_asset_id": payload.selected_asset_id},
)

View File

@@ -0,0 +1,69 @@
"""Quality-control mock activity."""
from temporalio import activity
from app.domain.enums import AssetType, OrderStatus, StepStatus
from app.infra.db.models.asset import AssetORM
from app.infra.db.session import get_session_factory
from app.workers.activities.tryon_activities import create_step_record, jsonable, load_order_and_run, mock_uri, utc_now
from app.workers.workflows.types import MockActivityResult, StepActivityInput
@activity.defn
async def run_qc_activity(payload: StepActivityInput) -> MockActivityResult:
"""Mock automated quality control."""
async with get_session_factory()() as session:
order, workflow_run = await load_order_and_run(session, payload.order_id, payload.workflow_run_id)
step = create_step_record(payload)
session.add(step)
order.status = OrderStatus.RUNNING
workflow_run.status = OrderStatus.RUNNING
workflow_run.current_step = payload.step_name
await session.flush()
try:
passed = not payload.metadata.get("force_fail", False)
candidate_asset_ids: list[int] = []
candidate_uri: str | None = None
if passed:
candidate = AssetORM(
order_id=payload.order_id,
asset_type=AssetType.QC_CANDIDATE,
step_name=payload.step_name,
uri=mock_uri(payload.order_id, payload.step_name.value, "candidate.png"),
metadata_json=jsonable({"source_asset_id": payload.source_asset_id}),
)
session.add(candidate)
await session.flush()
candidate_asset_ids = [candidate.id]
candidate_uri = candidate.uri
result = MockActivityResult(
step_name=payload.step_name,
success=True,
asset_id=candidate_asset_ids[0] if candidate_asset_ids else None,
uri=candidate_uri,
score=0.95 if passed else 0.35,
passed=passed,
message="mock success" if passed else "mock qc rejected",
candidate_asset_ids=candidate_asset_ids,
metadata={"source_asset_id": payload.source_asset_id},
)
step.step_status = StepStatus.SUCCEEDED if passed else StepStatus.FAILED
step.output_json = jsonable(result)
step.error_message = None if passed else "QC rejected the asset"
step.ended_at = utc_now()
await session.commit()
return result
except Exception as exc:
step.step_status = StepStatus.FAILED
step.error_message = str(exc)
step.ended_at = utc_now()
order.status = OrderStatus.FAILED
workflow_run.status = OrderStatus.FAILED
await session.commit()
raise

View File

@@ -0,0 +1,117 @@
"""Review state management mock activities."""
from sqlalchemy import select
from temporalio import activity
from app.domain.enums import OrderStatus, ReviewDecision, ReviewTaskStatus, StepStatus, WorkflowStepName
from app.infra.db.models.review_task import ReviewTaskORM
from app.infra.db.models.workflow_step import WorkflowStepORM
from app.infra.db.session import get_session_factory
from app.workers.activities.tryon_activities import jsonable, load_order_and_run, utc_now
from app.workers.workflows.types import (
ReviewResolutionActivityInput,
ReviewWaitActivityInput,
WorkflowFailureActivityInput,
)
@activity.defn
async def mark_waiting_for_review_activity(payload: ReviewWaitActivityInput) -> None:
"""Mark a workflow as waiting for a human review."""
async with get_session_factory()() as session:
order, workflow_run = await load_order_and_run(session, payload.order_id, payload.workflow_run_id)
review_step = WorkflowStepORM(
workflow_run_id=payload.workflow_run_id,
step_name=WorkflowStepName.REVIEW,
step_status=StepStatus.WAITING,
input_json=jsonable(payload),
started_at=utc_now(),
)
session.add(review_step)
session.add(
ReviewTaskORM(
order_id=payload.order_id,
status=ReviewTaskStatus.PENDING,
selected_asset_id=payload.candidate_asset_ids[0] if payload.candidate_asset_ids else None,
comment=payload.comment,
)
)
order.status = OrderStatus.WAITING_REVIEW
workflow_run.status = OrderStatus.WAITING_REVIEW
workflow_run.current_step = WorkflowStepName.REVIEW
await session.commit()
@activity.defn
async def complete_review_wait_activity(payload: ReviewResolutionActivityInput) -> None:
"""Resolve the current waiting-review step before the next branch runs."""
async with get_session_factory()() as session:
order, workflow_run = await load_order_and_run(session, payload.order_id, payload.workflow_run_id)
step_result = await session.execute(
select(WorkflowStepORM)
.where(
WorkflowStepORM.workflow_run_id == payload.workflow_run_id,
WorkflowStepORM.step_name == WorkflowStepName.REVIEW,
WorkflowStepORM.step_status == StepStatus.WAITING,
)
.order_by(WorkflowStepORM.started_at.desc(), WorkflowStepORM.id.desc())
)
review_step = step_result.scalars().first()
if review_step is not None:
review_step.step_status = (
StepStatus.FAILED if payload.decision == ReviewDecision.REJECT else StepStatus.SUCCEEDED
)
review_step.output_json = jsonable(payload)
review_step.error_message = payload.comment if payload.decision == ReviewDecision.REJECT else None
review_step.ended_at = utc_now()
if payload.decision == ReviewDecision.REJECT:
order.status = OrderStatus.FAILED
workflow_run.status = OrderStatus.FAILED
else:
order.status = OrderStatus.RUNNING
workflow_run.status = OrderStatus.RUNNING
workflow_run.current_step = WorkflowStepName.REVIEW
await session.commit()
@activity.defn
async def mark_workflow_failed_activity(payload: WorkflowFailureActivityInput) -> None:
"""Mark the persisted workflow state as failed."""
async with get_session_factory()() as session:
order, workflow_run = await load_order_and_run(session, payload.order_id, payload.workflow_run_id)
step_result = await session.execute(
select(WorkflowStepORM)
.where(
WorkflowStepORM.workflow_run_id == payload.workflow_run_id,
WorkflowStepORM.step_name == payload.current_step,
)
.order_by(WorkflowStepORM.started_at.desc(), WorkflowStepORM.id.desc())
)
workflow_step = step_result.scalars().first()
if workflow_step is None:
workflow_step = WorkflowStepORM(
workflow_run_id=payload.workflow_run_id,
step_name=payload.current_step,
step_status=StepStatus.FAILED,
input_json=jsonable(payload),
started_at=utc_now(),
)
session.add(workflow_step)
workflow_step.step_status = StepStatus.FAILED
workflow_step.error_message = payload.message
workflow_step.output_json = jsonable({"message": payload.message, "status": payload.status.value})
workflow_step.ended_at = workflow_step.ended_at or utc_now()
order.status = payload.status
workflow_run.status = payload.status
workflow_run.current_step = payload.current_step
await session.commit()

View File

@@ -0,0 +1,19 @@
"""Scene mock activity."""
from temporalio import activity
from app.domain.enums import AssetType
from app.workers.activities.tryon_activities import execute_asset_step
from app.workers.workflows.types import MockActivityResult, StepActivityInput
@activity.defn
async def run_scene_activity(payload: StepActivityInput) -> MockActivityResult:
"""Mock scene replacement."""
return await execute_asset_step(
payload,
AssetType.SCENE,
extra_metadata={"scene_ref_asset_id": payload.scene_ref_asset_id},
)

View File

@@ -0,0 +1,15 @@
"""Texture mock activity."""
from temporalio import activity
from app.domain.enums import AssetType
from app.workers.activities.tryon_activities import execute_asset_step
from app.workers.workflows.types import MockActivityResult, StepActivityInput
@activity.defn
async def run_texture_activity(payload: StepActivityInput) -> MockActivityResult:
"""Mock garment texture enhancement."""
return await execute_asset_step(payload, AssetType.TEXTURE)

View File

@@ -0,0 +1,170 @@
"""Prepare-model and try-on mock activities plus shared helpers."""
from __future__ import annotations
from dataclasses import asdict, is_dataclass
from datetime import datetime, timezone
from enum import Enum
from typing import Any
from uuid import uuid4
from temporalio import activity
from app.domain.enums import AssetType, OrderStatus, StepStatus
from app.infra.db.models.asset import AssetORM
from app.infra.db.models.order import OrderORM
from app.infra.db.models.workflow_run import WorkflowRunORM
from app.infra.db.models.workflow_step import WorkflowStepORM
from app.infra.db.session import get_session_factory
from app.workers.workflows.types import MockActivityResult, StepActivityInput
def utc_now() -> datetime:
"""Return the current UTC timestamp."""
return datetime.now(timezone.utc)
def jsonable(value: Any) -> Any:
"""Convert enums, dataclasses, and nested values to JSON-safe structures."""
if value is None:
return None
if isinstance(value, Enum):
return value.value
if isinstance(value, datetime):
return value.isoformat()
if is_dataclass(value):
return jsonable(asdict(value))
if isinstance(value, dict):
return {key: jsonable(item) for key, item in value.items() if item is not None}
if isinstance(value, (list, tuple, set)):
return [jsonable(item) for item in value]
return value
def mock_uri(order_id: int, step_name: str, filename: str = "result.png") -> str:
"""Build a deterministic-looking mock URI for an order step."""
return f"mock://orders/{order_id}/{step_name}/{uuid4().hex[:8]}-{filename}"
async def load_order_and_run(session, order_id: int, workflow_run_id: int) -> tuple[OrderORM, WorkflowRunORM]:
"""Load the order and workflow run required by an activity."""
order = await session.get(OrderORM, order_id)
workflow_run = await session.get(WorkflowRunORM, workflow_run_id)
if order is None or workflow_run is None:
raise ValueError("Order or workflow run not found for activity execution")
return order, workflow_run
def create_step_record(payload: StepActivityInput) -> WorkflowStepORM:
"""Create a running workflow step row for an activity execution."""
return WorkflowStepORM(
workflow_run_id=payload.workflow_run_id,
step_name=payload.step_name,
step_status=StepStatus.RUNNING,
input_json=jsonable(payload),
started_at=utc_now(),
)
async def execute_asset_step(
payload: StepActivityInput,
asset_type: AssetType,
*,
score: float = 0.95,
filename: str = "result.png",
message: str = "mock success",
extra_metadata: dict[str, Any] | None = None,
finalize: bool = False,
) -> MockActivityResult:
"""Persist a mock asset-producing step and return its result."""
async with get_session_factory()() as session:
order, workflow_run = await load_order_and_run(session, payload.order_id, payload.workflow_run_id)
step = create_step_record(payload)
session.add(step)
order.status = OrderStatus.RUNNING
workflow_run.status = OrderStatus.RUNNING
workflow_run.current_step = payload.step_name
await session.flush()
try:
metadata = {
**payload.metadata,
"source_asset_id": payload.source_asset_id,
"selected_asset_id": payload.selected_asset_id,
**(extra_metadata or {}),
}
metadata = {key: value for key, value in metadata.items() if value is not None}
asset = AssetORM(
order_id=payload.order_id,
asset_type=asset_type,
step_name=payload.step_name,
uri=mock_uri(payload.order_id, payload.step_name.value, filename),
metadata_json=jsonable(metadata),
)
session.add(asset)
await session.flush()
result = MockActivityResult(
step_name=payload.step_name,
success=True,
asset_id=asset.id,
uri=asset.uri,
score=score,
passed=True,
message=message,
metadata=jsonable(metadata) or {},
)
if finalize:
order.final_asset_id = asset.id
order.status = OrderStatus.SUCCEEDED
workflow_run.status = OrderStatus.SUCCEEDED
step.step_status = StepStatus.SUCCEEDED
step.output_json = jsonable(result)
step.ended_at = utc_now()
await session.commit()
return result
except Exception as exc:
step.step_status = StepStatus.FAILED
step.error_message = str(exc)
step.ended_at = utc_now()
order.status = OrderStatus.FAILED
workflow_run.status = OrderStatus.FAILED
await session.commit()
raise
@activity.defn
async def prepare_model_activity(payload: StepActivityInput) -> MockActivityResult:
"""Mock model preparation for the pipeline."""
return await execute_asset_step(
payload,
AssetType.PREPARED_MODEL,
extra_metadata={
"model_id": payload.model_id,
"pose_id": payload.pose_id,
"garment_asset_id": payload.garment_asset_id,
"scene_ref_asset_id": payload.scene_ref_asset_id,
},
)
@activity.defn
async def run_tryon_activity(payload: StepActivityInput) -> MockActivityResult:
"""Mock try-on rendering."""
return await execute_asset_step(
payload,
AssetType.TRYON,
extra_metadata={"prepared_asset_id": payload.source_asset_id},
)

84
app/workers/runner.py Normal file
View File

@@ -0,0 +1,84 @@
"""Temporal worker runner."""
import asyncio
from contextlib import AsyncExitStack
from temporalio.client import Client
from temporalio.worker import Worker
from app.infra.temporal.client import get_temporal_client
from app.infra.temporal.task_queues import (
IMAGE_PIPELINE_CONTROL_TASK_QUEUE,
IMAGE_PIPELINE_EXPORT_TASK_QUEUE,
IMAGE_PIPELINE_IMAGE_GEN_TASK_QUEUE,
IMAGE_PIPELINE_POST_PROCESS_TASK_QUEUE,
IMAGE_PIPELINE_QC_TASK_QUEUE,
)
from app.workers.activities.export_activities import run_export_activity
from app.workers.activities.face_activities import run_face_activity
from app.workers.activities.fusion_activities import run_fusion_activity
from app.workers.activities.qc_activities import run_qc_activity
from app.workers.activities.review_activities import (
complete_review_wait_activity,
mark_waiting_for_review_activity,
mark_workflow_failed_activity,
)
from app.workers.activities.scene_activities import run_scene_activity
from app.workers.activities.texture_activities import run_texture_activity
from app.workers.activities.tryon_activities import prepare_model_activity, run_tryon_activity
from app.workers.workflows.low_end_pipeline import LowEndPipelineWorkflow
from app.workers.workflows.mid_end_pipeline import MidEndPipelineWorkflow
def build_workers(client: Client) -> list[Worker]:
"""Create the worker set needed for the task queues in this MVP."""
return [
Worker(
client,
task_queue=IMAGE_PIPELINE_CONTROL_TASK_QUEUE,
workflows=[LowEndPipelineWorkflow, MidEndPipelineWorkflow],
activities=[
prepare_model_activity,
mark_waiting_for_review_activity,
complete_review_wait_activity,
mark_workflow_failed_activity,
],
),
Worker(
client,
task_queue=IMAGE_PIPELINE_IMAGE_GEN_TASK_QUEUE,
activities=[run_tryon_activity, run_scene_activity],
),
Worker(
client,
task_queue=IMAGE_PIPELINE_POST_PROCESS_TASK_QUEUE,
activities=[run_texture_activity, run_face_activity, run_fusion_activity],
),
Worker(
client,
task_queue=IMAGE_PIPELINE_QC_TASK_QUEUE,
activities=[run_qc_activity],
),
Worker(
client,
task_queue=IMAGE_PIPELINE_EXPORT_TASK_QUEUE,
activities=[run_export_activity],
),
]
async def run_workers() -> None:
"""Start all Temporal workers and keep the process alive."""
client = await get_temporal_client()
workers = build_workers(client)
async with AsyncExitStack() as stack:
for worker in workers:
await stack.enter_async_context(worker)
await asyncio.Event().wait()
if __name__ == "__main__":
asyncio.run(run_workers())

View File

@@ -0,0 +1,152 @@
"""Low-end image pipeline workflow."""
from datetime import timedelta
from temporalio import workflow
from temporalio.common import RetryPolicy
with workflow.unsafe.imports_passed_through():
from app.domain.enums import OrderStatus, WorkflowStepName
from app.infra.temporal.task_queues import (
IMAGE_PIPELINE_CONTROL_TASK_QUEUE,
IMAGE_PIPELINE_EXPORT_TASK_QUEUE,
IMAGE_PIPELINE_IMAGE_GEN_TASK_QUEUE,
IMAGE_PIPELINE_QC_TASK_QUEUE,
)
from app.workers.activities.export_activities import run_export_activity
from app.workers.activities.qc_activities import run_qc_activity
from app.workers.activities.review_activities import mark_workflow_failed_activity
from app.workers.activities.scene_activities import run_scene_activity
from app.workers.activities.tryon_activities import prepare_model_activity, run_tryon_activity
from app.workers.workflows.types import (
PipelineWorkflowInput,
StepActivityInput,
WorkflowFailureActivityInput,
)
ACTIVITY_TIMEOUT = timedelta(seconds=30)
ACTIVITY_RETRY_POLICY = RetryPolicy(
initial_interval=timedelta(seconds=1),
backoff_coefficient=2.0,
maximum_attempts=3,
)
@workflow.defn
class LowEndPipelineWorkflow:
"""Low-end fully automated image pipeline."""
@workflow.run
async def run(self, payload: PipelineWorkflowInput) -> dict[str, int | str | None]:
"""Execute the low-end workflow from start to finish."""
current_step = WorkflowStepName.PREPARE_MODEL
try:
prepared = await workflow.execute_activity(
prepare_model_activity,
StepActivityInput(
order_id=payload.order_id,
workflow_run_id=payload.workflow_run_id,
step_name=WorkflowStepName.PREPARE_MODEL,
model_id=payload.model_id,
pose_id=payload.pose_id,
garment_asset_id=payload.garment_asset_id,
scene_ref_asset_id=payload.scene_ref_asset_id,
),
task_queue=IMAGE_PIPELINE_CONTROL_TASK_QUEUE,
start_to_close_timeout=ACTIVITY_TIMEOUT,
retry_policy=ACTIVITY_RETRY_POLICY,
)
current_step = WorkflowStepName.TRYON
tryon_result = await workflow.execute_activity(
run_tryon_activity,
StepActivityInput(
order_id=payload.order_id,
workflow_run_id=payload.workflow_run_id,
step_name=WorkflowStepName.TRYON,
source_asset_id=prepared.asset_id,
garment_asset_id=payload.garment_asset_id,
),
task_queue=IMAGE_PIPELINE_IMAGE_GEN_TASK_QUEUE,
start_to_close_timeout=ACTIVITY_TIMEOUT,
retry_policy=ACTIVITY_RETRY_POLICY,
)
current_step = WorkflowStepName.SCENE
scene_result = await workflow.execute_activity(
run_scene_activity,
StepActivityInput(
order_id=payload.order_id,
workflow_run_id=payload.workflow_run_id,
step_name=WorkflowStepName.SCENE,
source_asset_id=tryon_result.asset_id,
scene_ref_asset_id=payload.scene_ref_asset_id,
),
task_queue=IMAGE_PIPELINE_IMAGE_GEN_TASK_QUEUE,
start_to_close_timeout=ACTIVITY_TIMEOUT,
retry_policy=ACTIVITY_RETRY_POLICY,
)
current_step = WorkflowStepName.QC
qc_result = await workflow.execute_activity(
run_qc_activity,
StepActivityInput(
order_id=payload.order_id,
workflow_run_id=payload.workflow_run_id,
step_name=WorkflowStepName.QC,
source_asset_id=scene_result.asset_id,
),
task_queue=IMAGE_PIPELINE_QC_TASK_QUEUE,
start_to_close_timeout=ACTIVITY_TIMEOUT,
retry_policy=ACTIVITY_RETRY_POLICY,
)
if not qc_result.passed:
await self._mark_failed(payload, current_step, qc_result.message)
return {"order_id": payload.order_id, "status": OrderStatus.FAILED.value, "final_asset_id": None}
current_step = WorkflowStepName.EXPORT
final_result = await workflow.execute_activity(
run_export_activity,
StepActivityInput(
order_id=payload.order_id,
workflow_run_id=payload.workflow_run_id,
step_name=WorkflowStepName.EXPORT,
source_asset_id=(qc_result.candidate_asset_ids or [scene_result.asset_id])[0],
),
task_queue=IMAGE_PIPELINE_EXPORT_TASK_QUEUE,
start_to_close_timeout=ACTIVITY_TIMEOUT,
retry_policy=ACTIVITY_RETRY_POLICY,
)
return {
"order_id": payload.order_id,
"status": OrderStatus.SUCCEEDED.value,
"final_asset_id": final_result.asset_id,
}
except Exception as exc:
await self._mark_failed(payload, current_step, str(exc))
raise
async def _mark_failed(
self,
payload: PipelineWorkflowInput,
current_step: WorkflowStepName,
message: str,
) -> None:
"""Persist workflow failure state."""
await workflow.execute_activity(
mark_workflow_failed_activity,
WorkflowFailureActivityInput(
order_id=payload.order_id,
workflow_run_id=payload.workflow_run_id,
current_step=current_step,
message=message,
),
task_queue=IMAGE_PIPELINE_CONTROL_TASK_QUEUE,
start_to_close_timeout=ACTIVITY_TIMEOUT,
retry_policy=ACTIVITY_RETRY_POLICY,
)

View File

@@ -0,0 +1,315 @@
"""Mid-end image pipeline workflow with review signal support."""
from datetime import timedelta
from temporalio import workflow
from temporalio.common import RetryPolicy
with workflow.unsafe.imports_passed_through():
from app.domain.enums import OrderStatus, ReviewDecision, WorkflowStepName
from app.infra.temporal.task_queues import (
IMAGE_PIPELINE_CONTROL_TASK_QUEUE,
IMAGE_PIPELINE_EXPORT_TASK_QUEUE,
IMAGE_PIPELINE_IMAGE_GEN_TASK_QUEUE,
IMAGE_PIPELINE_POST_PROCESS_TASK_QUEUE,
IMAGE_PIPELINE_QC_TASK_QUEUE,
)
from app.workers.activities.export_activities import run_export_activity
from app.workers.activities.face_activities import run_face_activity
from app.workers.activities.fusion_activities import run_fusion_activity
from app.workers.activities.qc_activities import run_qc_activity
from app.workers.activities.review_activities import (
complete_review_wait_activity,
mark_waiting_for_review_activity,
mark_workflow_failed_activity,
)
from app.workers.activities.scene_activities import run_scene_activity
from app.workers.activities.texture_activities import run_texture_activity
from app.workers.activities.tryon_activities import prepare_model_activity, run_tryon_activity
from app.workers.workflows.types import (
MockActivityResult,
PipelineWorkflowInput,
ReviewResolutionActivityInput,
ReviewSignalPayload,
ReviewWaitActivityInput,
StepActivityInput,
WorkflowFailureActivityInput,
)
ACTIVITY_TIMEOUT = timedelta(seconds=30)
ACTIVITY_RETRY_POLICY = RetryPolicy(
initial_interval=timedelta(seconds=1),
backoff_coefficient=2.0,
maximum_attempts=3,
)
@workflow.defn
class MidEndPipelineWorkflow:
"""Mid-end workflow that pauses for human review and supports reruns."""
def __init__(self) -> None:
self._review_payload: ReviewSignalPayload | None = None
@workflow.signal
def submit_review(self, payload: ReviewSignalPayload) -> None:
"""Receive a review decision from the API layer."""
self._review_payload = payload
@workflow.run
async def run(self, payload: PipelineWorkflowInput) -> dict[str, int | str | None]:
"""Execute the mid-end workflow with a human review loop."""
current_step = WorkflowStepName.PREPARE_MODEL
try:
prepared = await workflow.execute_activity(
prepare_model_activity,
StepActivityInput(
order_id=payload.order_id,
workflow_run_id=payload.workflow_run_id,
step_name=WorkflowStepName.PREPARE_MODEL,
model_id=payload.model_id,
pose_id=payload.pose_id,
garment_asset_id=payload.garment_asset_id,
scene_ref_asset_id=payload.scene_ref_asset_id,
),
task_queue=IMAGE_PIPELINE_CONTROL_TASK_QUEUE,
start_to_close_timeout=ACTIVITY_TIMEOUT,
retry_policy=ACTIVITY_RETRY_POLICY,
)
current_step = WorkflowStepName.TRYON
tryon_result = await workflow.execute_activity(
run_tryon_activity,
StepActivityInput(
order_id=payload.order_id,
workflow_run_id=payload.workflow_run_id,
step_name=WorkflowStepName.TRYON,
source_asset_id=prepared.asset_id,
garment_asset_id=payload.garment_asset_id,
),
task_queue=IMAGE_PIPELINE_IMAGE_GEN_TASK_QUEUE,
start_to_close_timeout=ACTIVITY_TIMEOUT,
retry_policy=ACTIVITY_RETRY_POLICY,
)
current_step = WorkflowStepName.SCENE
scene_result = await self._run_scene(payload, tryon_result.asset_id)
current_step = WorkflowStepName.TEXTURE
texture_result = await self._run_texture(payload, scene_result.asset_id)
current_step = WorkflowStepName.FACE
face_result = await self._run_face(payload, texture_result.asset_id)
current_step = WorkflowStepName.FUSION
fusion_result = await self._run_fusion(payload, scene_result.asset_id, face_result.asset_id)
current_step = WorkflowStepName.QC
qc_result = await self._run_qc(payload, fusion_result.asset_id)
if not qc_result.passed:
await self._mark_failed(payload, current_step, qc_result.message)
return {"order_id": payload.order_id, "status": OrderStatus.FAILED.value, "final_asset_id": None}
while True:
current_step = WorkflowStepName.REVIEW
await workflow.execute_activity(
mark_waiting_for_review_activity,
ReviewWaitActivityInput(
order_id=payload.order_id,
workflow_run_id=payload.workflow_run_id,
candidate_asset_ids=qc_result.candidate_asset_ids,
),
task_queue=IMAGE_PIPELINE_CONTROL_TASK_QUEUE,
start_to_close_timeout=ACTIVITY_TIMEOUT,
retry_policy=ACTIVITY_RETRY_POLICY,
)
review_payload = await self._wait_for_review()
await workflow.execute_activity(
complete_review_wait_activity,
ReviewResolutionActivityInput(
order_id=payload.order_id,
workflow_run_id=payload.workflow_run_id,
decision=review_payload.decision,
reviewer_id=review_payload.reviewer_id,
selected_asset_id=review_payload.selected_asset_id,
comment=review_payload.comment,
),
task_queue=IMAGE_PIPELINE_CONTROL_TASK_QUEUE,
start_to_close_timeout=ACTIVITY_TIMEOUT,
retry_policy=ACTIVITY_RETRY_POLICY,
)
if review_payload.decision == ReviewDecision.APPROVE:
current_step = WorkflowStepName.EXPORT
export_source_id = review_payload.selected_asset_id
if export_source_id is None:
export_source_id = (qc_result.candidate_asset_ids or [fusion_result.asset_id])[0]
final_result = await workflow.execute_activity(
run_export_activity,
StepActivityInput(
order_id=payload.order_id,
workflow_run_id=payload.workflow_run_id,
step_name=WorkflowStepName.EXPORT,
source_asset_id=export_source_id,
),
task_queue=IMAGE_PIPELINE_EXPORT_TASK_QUEUE,
start_to_close_timeout=ACTIVITY_TIMEOUT,
retry_policy=ACTIVITY_RETRY_POLICY,
)
return {
"order_id": payload.order_id,
"status": OrderStatus.SUCCEEDED.value,
"final_asset_id": final_result.asset_id,
}
if review_payload.decision == ReviewDecision.REJECT:
return {"order_id": payload.order_id, "status": OrderStatus.FAILED.value, "final_asset_id": None}
if review_payload.decision == ReviewDecision.RERUN_SCENE:
current_step = WorkflowStepName.SCENE
scene_result = await self._run_scene(payload, tryon_result.asset_id)
current_step = WorkflowStepName.TEXTURE
texture_result = await self._run_texture(payload, scene_result.asset_id)
current_step = WorkflowStepName.FACE
face_result = await self._run_face(payload, texture_result.asset_id)
current_step = WorkflowStepName.FUSION
fusion_result = await self._run_fusion(payload, scene_result.asset_id, face_result.asset_id)
elif review_payload.decision == ReviewDecision.RERUN_FACE:
current_step = WorkflowStepName.FACE
face_result = await self._run_face(payload, texture_result.asset_id)
current_step = WorkflowStepName.FUSION
fusion_result = await self._run_fusion(payload, scene_result.asset_id, face_result.asset_id)
elif review_payload.decision == ReviewDecision.RERUN_FUSION:
current_step = WorkflowStepName.FUSION
fusion_result = await self._run_fusion(payload, scene_result.asset_id, face_result.asset_id)
current_step = WorkflowStepName.QC
qc_result = await self._run_qc(payload, fusion_result.asset_id)
if not qc_result.passed:
await self._mark_failed(payload, current_step, qc_result.message)
return {"order_id": payload.order_id, "status": OrderStatus.FAILED.value, "final_asset_id": None}
except Exception as exc:
await self._mark_failed(payload, current_step, str(exc))
raise
async def _wait_for_review(self) -> ReviewSignalPayload:
"""Suspend the workflow until a review signal arrives."""
if self._review_payload is None:
await workflow.wait_condition(lambda: self._review_payload is not None)
assert self._review_payload is not None
review_payload = self._review_payload
self._review_payload = None
return review_payload
async def _run_scene(self, payload: PipelineWorkflowInput, source_asset_id: int | None) -> MockActivityResult:
"""Execute the scene activity."""
return await workflow.execute_activity(
run_scene_activity,
StepActivityInput(
order_id=payload.order_id,
workflow_run_id=payload.workflow_run_id,
step_name=WorkflowStepName.SCENE,
source_asset_id=source_asset_id,
scene_ref_asset_id=payload.scene_ref_asset_id,
),
task_queue=IMAGE_PIPELINE_IMAGE_GEN_TASK_QUEUE,
start_to_close_timeout=ACTIVITY_TIMEOUT,
retry_policy=ACTIVITY_RETRY_POLICY,
)
async def _run_texture(self, payload: PipelineWorkflowInput, source_asset_id: int | None) -> MockActivityResult:
"""Execute the texture activity."""
return await workflow.execute_activity(
run_texture_activity,
StepActivityInput(
order_id=payload.order_id,
workflow_run_id=payload.workflow_run_id,
step_name=WorkflowStepName.TEXTURE,
source_asset_id=source_asset_id,
),
task_queue=IMAGE_PIPELINE_POST_PROCESS_TASK_QUEUE,
start_to_close_timeout=ACTIVITY_TIMEOUT,
retry_policy=ACTIVITY_RETRY_POLICY,
)
async def _run_face(self, payload: PipelineWorkflowInput, source_asset_id: int | None) -> MockActivityResult:
"""Execute the face activity."""
return await workflow.execute_activity(
run_face_activity,
StepActivityInput(
order_id=payload.order_id,
workflow_run_id=payload.workflow_run_id,
step_name=WorkflowStepName.FACE,
source_asset_id=source_asset_id,
),
task_queue=IMAGE_PIPELINE_POST_PROCESS_TASK_QUEUE,
start_to_close_timeout=ACTIVITY_TIMEOUT,
retry_policy=ACTIVITY_RETRY_POLICY,
)
async def _run_fusion(
self,
payload: PipelineWorkflowInput,
source_asset_id: int | None,
face_asset_id: int | None,
) -> MockActivityResult:
"""Execute the fusion activity."""
return await workflow.execute_activity(
run_fusion_activity,
StepActivityInput(
order_id=payload.order_id,
workflow_run_id=payload.workflow_run_id,
step_name=WorkflowStepName.FUSION,
source_asset_id=source_asset_id,
selected_asset_id=face_asset_id,
),
task_queue=IMAGE_PIPELINE_POST_PROCESS_TASK_QUEUE,
start_to_close_timeout=ACTIVITY_TIMEOUT,
retry_policy=ACTIVITY_RETRY_POLICY,
)
async def _run_qc(self, payload: PipelineWorkflowInput, source_asset_id: int | None) -> MockActivityResult:
"""Execute the QC activity."""
return await workflow.execute_activity(
run_qc_activity,
StepActivityInput(
order_id=payload.order_id,
workflow_run_id=payload.workflow_run_id,
step_name=WorkflowStepName.QC,
source_asset_id=source_asset_id,
),
task_queue=IMAGE_PIPELINE_QC_TASK_QUEUE,
start_to_close_timeout=ACTIVITY_TIMEOUT,
retry_policy=ACTIVITY_RETRY_POLICY,
)
async def _mark_failed(
self,
payload: PipelineWorkflowInput,
current_step: WorkflowStepName,
message: str,
) -> None:
"""Persist workflow failure state."""
await workflow.execute_activity(
mark_workflow_failed_activity,
WorkflowFailureActivityInput(
order_id=payload.order_id,
workflow_run_id=payload.workflow_run_id,
current_step=current_step,
message=message,
),
task_queue=IMAGE_PIPELINE_CONTROL_TASK_QUEUE,
start_to_close_timeout=ACTIVITY_TIMEOUT,
retry_policy=ACTIVITY_RETRY_POLICY,
)

View File

@@ -0,0 +1,137 @@
"""Shared workflow and activity payload types."""
from dataclasses import dataclass, field
from enum import Enum
from typing import Any
from app.domain.enums import CustomerLevel, OrderStatus, ReviewDecision, ServiceMode, WorkflowStepName
def _coerce_enum(value: Any, enum_cls: type[Enum]) -> Any:
"""Coerce raw Temporal payload values back into enum instances."""
if value is None or isinstance(value, enum_cls):
return value
if isinstance(value, list):
value = "".join(str(item) for item in value)
return enum_cls(value)
@dataclass(slots=True)
class PipelineWorkflowInput:
"""Temporal workflow input for an image pipeline order."""
order_id: int
workflow_run_id: int
customer_level: CustomerLevel
service_mode: ServiceMode
model_id: int
pose_id: int
garment_asset_id: int
scene_ref_asset_id: int
def __post_init__(self) -> None:
"""Normalize enum-like values after Temporal deserialization."""
self.customer_level = _coerce_enum(self.customer_level, CustomerLevel)
self.service_mode = _coerce_enum(self.service_mode, ServiceMode)
@dataclass(slots=True)
class StepActivityInput:
"""Input payload shared by the mock pipeline activities."""
order_id: int
workflow_run_id: int
step_name: WorkflowStepName
model_id: int | None = None
pose_id: int | None = None
garment_asset_id: int | None = None
scene_ref_asset_id: int | None = None
source_asset_id: int | None = None
selected_asset_id: int | None = None
metadata: dict[str, Any] = field(default_factory=dict)
def __post_init__(self) -> None:
"""Normalize enum-like values after Temporal deserialization."""
self.step_name = _coerce_enum(self.step_name, WorkflowStepName)
@dataclass(slots=True)
class MockActivityResult:
"""Common mock activity result structure."""
step_name: WorkflowStepName
success: bool
asset_id: int | None
uri: str | None
score: float | None = None
passed: bool | None = None
message: str = "mock success"
candidate_asset_ids: list[int] = field(default_factory=list)
metadata: dict[str, Any] = field(default_factory=dict)
def __post_init__(self) -> None:
"""Normalize enum-like values after Temporal deserialization."""
self.step_name = _coerce_enum(self.step_name, WorkflowStepName)
@dataclass(slots=True)
class ReviewSignalPayload:
"""Signal payload sent from the API to the mid-end workflow."""
decision: ReviewDecision
reviewer_id: int
selected_asset_id: int | None = None
comment: str | None = None
def __post_init__(self) -> None:
"""Normalize enum-like values after Temporal deserialization."""
self.decision = _coerce_enum(self.decision, ReviewDecision)
@dataclass(slots=True)
class ReviewWaitActivityInput:
"""Input for marking a workflow as waiting for review."""
order_id: int
workflow_run_id: int
candidate_asset_ids: list[int] = field(default_factory=list)
comment: str | None = None
@dataclass(slots=True)
class ReviewResolutionActivityInput:
"""Input for completing a waiting review state."""
order_id: int
workflow_run_id: int
decision: ReviewDecision
reviewer_id: int
selected_asset_id: int | None = None
comment: str | None = None
def __post_init__(self) -> None:
"""Normalize enum-like values after Temporal deserialization."""
self.decision = _coerce_enum(self.decision, ReviewDecision)
@dataclass(slots=True)
class WorkflowFailureActivityInput:
"""Input for marking a workflow as failed."""
order_id: int
workflow_run_id: int
current_step: WorkflowStepName
message: str
status: OrderStatus = OrderStatus.FAILED
def __post_init__(self) -> None:
"""Normalize enum-like values after Temporal deserialization."""
self.current_step = _coerce_enum(self.current_step, WorkflowStepName)
self.status = _coerce_enum(self.status, OrderStatus)