171 lines
5.6 KiB
Python
171 lines
5.6 KiB
Python
"""Prepare-model and try-on mock activities plus shared helpers."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from dataclasses import asdict, is_dataclass
|
|
from datetime import datetime, timezone
|
|
from enum import Enum
|
|
from typing import Any
|
|
from uuid import uuid4
|
|
|
|
from temporalio import activity
|
|
|
|
from app.domain.enums import AssetType, OrderStatus, StepStatus
|
|
from app.infra.db.models.asset import AssetORM
|
|
from app.infra.db.models.order import OrderORM
|
|
from app.infra.db.models.workflow_run import WorkflowRunORM
|
|
from app.infra.db.models.workflow_step import WorkflowStepORM
|
|
from app.infra.db.session import get_session_factory
|
|
from app.workers.workflows.types import MockActivityResult, StepActivityInput
|
|
|
|
|
|
def utc_now() -> datetime:
    """Return the current moment as a timezone-aware datetime in UTC."""
    return datetime.now(tz=timezone.utc)
|
|
|
|
|
|
def jsonable(value: Any) -> Any:
    """Convert enums, datetimes, dataclasses and nested containers to JSON-safe values.

    Conversion rules:

    - ``None`` is returned unchanged.
    - ``Enum`` members become their ``.value``.
    - ``datetime`` objects become ISO-8601 strings.
    - Dataclass *instances* are converted with ``asdict`` and then recursively.
    - ``dict`` values are converted recursively; entries whose value is
      ``None`` are dropped, and ``Enum`` keys are replaced by their ``.value``.
    - ``list``/``tuple``/``set`` become lists (``None`` items are kept; the
      order of items from a ``set`` is arbitrary).
    - Anything else is returned unchanged.
    """
    if value is None:
        return None
    if isinstance(value, Enum):
        return value.value
    if isinstance(value, datetime):
        return value.isoformat()
    # is_dataclass() is also true for dataclass *classes*, but asdict() only
    # accepts instances; a class falls through and is returned unchanged.
    if is_dataclass(value) and not isinstance(value, type):
        return jsonable(asdict(value))
    if isinstance(value, dict):
        return {
            # JSON object keys must be primitives; plain Enum keys are not.
            (key.value if isinstance(key, Enum) else key): jsonable(item)
            for key, item in value.items()
            if item is not None
        }
    if isinstance(value, (list, tuple, set)):
        return [jsonable(item) for item in value]
    return value
|
|
|
|
|
|
def mock_uri(order_id: int, step_name: str, filename: str = "result.png") -> str:
    """Return a ``mock://`` URI for an asset produced by an order step.

    An 8-hex-character token taken from a random ``uuid4`` prefixes
    *filename*, so repeated calls for the same order/step yield distinct URIs.
    """
    token = uuid4().hex[:8]
    return f"mock://orders/{order_id}/{step_name}/{token}-{filename}"
|
|
|
|
|
|
async def load_order_and_run(session, order_id: int, workflow_run_id: int) -> tuple[OrderORM, WorkflowRunORM]:
    """Load the order and workflow run required by an activity.

    Args:
        session: Database session whose ``get`` coroutine is awaited for each row.
        order_id: Primary key of the ``OrderORM`` row.
        workflow_run_id: Primary key of the ``WorkflowRunORM`` row.

    Returns:
        The ``(order, workflow_run)`` pair.

    Raises:
        ValueError: If either row is missing; the message names which one(s)
            and the id that was looked up.
    """
    order = await session.get(OrderORM, order_id)
    workflow_run = await session.get(WorkflowRunORM, workflow_run_id)

    missing = []
    if order is None:
        missing.append(f"order {order_id}")
    if workflow_run is None:
        missing.append(f"workflow run {workflow_run_id}")
    if missing:
        # Keep the historical prefix so existing log searches still match.
        raise ValueError(
            f"Order or workflow run not found for activity execution: {', '.join(missing)}"
        )
    return order, workflow_run
|
|
|
|
|
|
def create_step_record(payload: StepActivityInput) -> WorkflowStepORM:
    """Build a RUNNING ``WorkflowStepORM`` for *payload*.

    The row is only constructed, not added to a session. Its ``input_json`` is
    the JSON-safe form of the whole payload, and ``started_at`` is the current
    UTC time.
    """
    input_snapshot = jsonable(payload)
    started = utc_now()
    return WorkflowStepORM(
        workflow_run_id=payload.workflow_run_id,
        step_name=payload.step_name,
        step_status=StepStatus.RUNNING,
        input_json=input_snapshot,
        started_at=started,
    )
|
|
|
|
|
|
def _build_asset_metadata(
    payload: StepActivityInput,
    extra_metadata: dict[str, Any] | None,
) -> dict[str, Any]:
    """Merge payload metadata, the payload's asset ids and *extra_metadata*, dropping ``None`` values.

    On key clashes later sources win: ``extra_metadata`` overrides the asset
    ids, which override ``payload.metadata``.
    """
    merged = {
        # Tolerate a payload whose metadata mapping is None.
        **(payload.metadata or {}),
        "source_asset_id": payload.source_asset_id,
        "selected_asset_id": payload.selected_asset_id,
        **(extra_metadata or {}),
    }
    return {key: value for key, value in merged.items() if value is not None}


async def _record_step_failure(
    session,
    payload: StepActivityInput,
    started_at: datetime,
    exc: Exception,
) -> None:
    """Persist a FAILED step row and mark the order and workflow run FAILED.

    The session is rolled back first. If *exc* came from a flush or commit,
    the session refuses further work until it is rolled back, and committing
    directly would raise ``PendingRollbackError`` instead of recording the
    failure. The rollback also discards the uncommitted RUNNING step row and
    any partially written asset, so a fresh step row is written here.
    """
    await session.rollback()
    order, workflow_run = await load_order_and_run(session, payload.order_id, payload.workflow_run_id)

    failed_step = create_step_record(payload)
    # Report when the step actually started, not when recovery ran.
    failed_step.started_at = started_at
    failed_step.step_status = StepStatus.FAILED
    failed_step.error_message = str(exc)
    failed_step.ended_at = utc_now()
    session.add(failed_step)

    order.status = OrderStatus.FAILED
    workflow_run.status = OrderStatus.FAILED
    workflow_run.current_step = payload.step_name
    await session.commit()


async def execute_asset_step(
    payload: StepActivityInput,
    asset_type: AssetType,
    *,
    score: float = 0.95,
    filename: str = "result.png",
    message: str = "mock success",
    extra_metadata: dict[str, Any] | None = None,
    finalize: bool = False,
) -> MockActivityResult:
    """Persist a mock asset-producing step and return its result.

    The function:

    1. Creates a RUNNING step row and marks the order and workflow run as
       RUNNING with ``current_step`` set to the payload's step.
    2. Creates an ``AssetORM`` with a mock URI and the merged metadata.
    3. Marks the step SUCCEEDED, storing the result as its output.

    With ``finalize=True``, the new asset also becomes the order's final asset
    and the order and workflow run are marked SUCCEEDED.

    Args:
        payload: Step input identifying the order, workflow run and step.
        asset_type: Type recorded on the created asset.
        score: Score reported in the returned result.
        filename: Filename suffix used when building the mock URI.
        message: Message reported in the returned result.
        extra_metadata: Extra keys merged into the asset metadata; entries
            with ``None`` values are dropped.
        finalize: Whether this step completes the order.

    Returns:
        The ``MockActivityResult`` that is also stored as the step's output.

    Raises:
        ValueError: If the order or workflow run does not exist.
        Exception: Any error raised while producing the asset is re-raised
            after the step, order and workflow run are recorded as FAILED. If
            recording the failure itself fails, that error propagates with the
            original one as its ``__context__``.
    """
    async with get_session_factory()() as session:
        order, workflow_run = await load_order_and_run(session, payload.order_id, payload.workflow_run_id)
        step = create_step_record(payload)
        # Captured now so the failure path can reuse it after a rollback.
        started_at = step.started_at
        session.add(step)

        order.status = OrderStatus.RUNNING
        workflow_run.status = OrderStatus.RUNNING
        workflow_run.current_step = payload.step_name
        await session.flush()

        try:
            metadata = _build_asset_metadata(payload, extra_metadata)

            asset = AssetORM(
                order_id=payload.order_id,
                asset_type=asset_type,
                step_name=payload.step_name,
                # step_name is used as an Enum here; its .value is the URI segment.
                uri=mock_uri(payload.order_id, payload.step_name.value, filename),
                metadata_json=jsonable(metadata),
            )
            session.add(asset)
            # Flush so asset.id is populated before it is reported below.
            await session.flush()

            result = MockActivityResult(
                step_name=payload.step_name,
                success=True,
                asset_id=asset.id,
                uri=asset.uri,
                score=score,
                passed=True,
                message=message,
                metadata=jsonable(metadata) or {},
            )

            if finalize:
                order.final_asset_id = asset.id
                order.status = OrderStatus.SUCCEEDED
                workflow_run.status = OrderStatus.SUCCEEDED

            step.step_status = StepStatus.SUCCEEDED
            step.output_json = jsonable(result)
            step.ended_at = utc_now()
            await session.commit()
            return result
        except Exception as exc:
            await _record_step_failure(session, payload, started_at, exc)
            raise
|
|
|
|
|
|
@activity.defn
async def prepare_model_activity(payload: StepActivityInput) -> MockActivityResult:
    """Mock the model-preparation step of the pipeline.

    Produces a ``PREPARED_MODEL`` asset. Its metadata additionally carries the
    payload's model, pose, garment and scene-reference ids (``None`` entries
    are dropped by ``execute_asset_step``).
    """
    preparation_inputs = {
        "model_id": payload.model_id,
        "pose_id": payload.pose_id,
        "garment_asset_id": payload.garment_asset_id,
        "scene_ref_asset_id": payload.scene_ref_asset_id,
    }
    return await execute_asset_step(
        payload,
        AssetType.PREPARED_MODEL,
        extra_metadata=preparation_inputs,
    )
|
|
|
|
|
|
@activity.defn
async def run_tryon_activity(payload: StepActivityInput) -> MockActivityResult:
    """Mock the try-on rendering step.

    Produces a ``TRYON`` asset whose metadata also records the payload's
    ``source_asset_id`` under the ``prepared_asset_id`` key.
    """
    tryon_inputs = {"prepared_asset_id": payload.source_asset_id}
    return await execute_asset_step(payload, AssetType.TRYON, extra_metadata=tryon_inputs)
|