"""Prepare-model and try-on activities plus shared helpers.""" from __future__ import annotations from dataclasses import asdict, is_dataclass from datetime import datetime, timezone from enum import Enum from typing import Any from uuid import uuid4 from sqlalchemy import select from sqlalchemy.orm import selectinload from temporalio import activity from app.application.services.image_generation_service import build_image_generation_service from app.domain.enums import AssetType, LibraryResourceStatus, LibraryResourceType, OrderStatus, StepStatus from app.infra.db.models.asset import AssetORM from app.infra.db.models.library_resource import LibraryResourceORM from app.infra.db.models.library_resource_file import LibraryResourceFileORM from app.infra.db.models.order import OrderORM from app.infra.db.models.workflow_run import WorkflowRunORM from app.infra.db.models.workflow_step import WorkflowStepORM from app.infra.db.session import get_session_factory from app.infra.storage.s3 import S3ObjectStorageService from app.workers.workflows.types import MockActivityResult, StepActivityInput def utc_now() -> datetime: """Return the current UTC timestamp.""" return datetime.now(timezone.utc) def jsonable(value: Any) -> Any: """Convert enums, dataclasses, and nested values to JSON-safe structures.""" if value is None: return None if isinstance(value, Enum): return value.value if isinstance(value, datetime): return value.isoformat() if is_dataclass(value): return jsonable(asdict(value)) if isinstance(value, dict): return {key: jsonable(item) for key, item in value.items() if item is not None} if isinstance(value, (list, tuple, set)): return [jsonable(item) for item in value] return value def mock_uri(order_id: int, step_name: str, filename: str = "result.png") -> str: """Build a deterministic-looking mock URI for an order step.""" return f"mock://orders/{order_id}/{step_name}/{uuid4().hex[:8]}-{filename}" async def load_order_and_run(session, order_id: int, workflow_run_id: int) -> tuple[OrderORM, WorkflowRunORM]: """Load the order and workflow run required by an activity.""" order = await session.get(OrderORM, order_id) workflow_run = await session.get(WorkflowRunORM, workflow_run_id) if order is None or workflow_run is None: raise ValueError("Order or workflow run not found for activity execution") return order, workflow_run def create_step_record(payload: StepActivityInput) -> WorkflowStepORM: """Create a running workflow step row for an activity execution.""" return WorkflowStepORM( workflow_run_id=payload.workflow_run_id, step_name=payload.step_name, step_status=StepStatus.RUNNING, input_json=jsonable(payload), started_at=utc_now(), ) async def load_active_library_resource( session, resource_id: int, *, resource_type: LibraryResourceType, ) -> tuple[LibraryResourceORM, LibraryResourceFileORM]: """Load an active library resource and its original file from the library.""" result = await session.execute( select(LibraryResourceORM) .options(selectinload(LibraryResourceORM.files)) .where(LibraryResourceORM.id == resource_id) ) resource = result.scalar_one_or_none() if resource is None: raise ValueError(f"Library resource {resource_id} not found") if resource.resource_type != resource_type: raise ValueError(f"Resource {resource_id} is not a {resource_type.value} resource") if resource.status != LibraryResourceStatus.ACTIVE: raise ValueError(f"Library resource {resource_id} is not active") if resource.original_file_id is None: raise ValueError(f"Library resource {resource_id} is missing an original file") original_file = next((item for item in resource.files if item.id == resource.original_file_id), None) if original_file is None: raise ValueError(f"Library resource {resource_id} original file record not found") return resource, original_file def get_image_generation_service(): """Return the configured image-generation service.""" return build_image_generation_service() def get_order_artifact_storage_service() -> S3ObjectStorageService: """Return the object-storage service for workflow-generated images.""" return S3ObjectStorageService() def build_resource_input_snapshot( resource: LibraryResourceORM, original_file: LibraryResourceFileORM, ) -> dict[str, Any]: """Build a frontend-friendly snapshot of one library input resource.""" return { "resource_id": resource.id, "resource_name": resource.name, "original_file_id": original_file.id, "original_url": original_file.public_url, "mime_type": original_file.mime_type, "width": original_file.width, "height": original_file.height, } async def execute_asset_step( payload: StepActivityInput, asset_type: AssetType, *, score: float = 0.95, filename: str = "result.png", message: str = "mock success", extra_metadata: dict[str, Any] | None = None, finalize: bool = False, ) -> MockActivityResult: """Persist a mock asset-producing step and return its result.""" async with get_session_factory()() as session: order, workflow_run = await load_order_and_run(session, payload.order_id, payload.workflow_run_id) step = create_step_record(payload) session.add(step) order.status = OrderStatus.RUNNING workflow_run.status = OrderStatus.RUNNING workflow_run.current_step = payload.step_name await session.flush() try: metadata = { **payload.metadata, "source_asset_id": payload.source_asset_id, "selected_asset_id": payload.selected_asset_id, **(extra_metadata or {}), } metadata = {key: value for key, value in metadata.items() if value is not None} asset = AssetORM( order_id=payload.order_id, asset_type=asset_type, step_name=payload.step_name, uri=mock_uri(payload.order_id, payload.step_name.value, filename), metadata_json=jsonable(metadata), ) session.add(asset) await session.flush() result = MockActivityResult( step_name=payload.step_name, success=True, asset_id=asset.id, uri=asset.uri, score=score, passed=True, message=message, metadata=jsonable(metadata) or {}, ) if finalize: order.final_asset_id = asset.id order.status = OrderStatus.SUCCEEDED workflow_run.status = OrderStatus.SUCCEEDED step.step_status = StepStatus.SUCCEEDED step.output_json = jsonable(result) step.ended_at = utc_now() await session.commit() return result except Exception as exc: step.step_status = StepStatus.FAILED step.error_message = str(exc) step.ended_at = utc_now() order.status = OrderStatus.FAILED workflow_run.status = OrderStatus.FAILED await session.commit() raise @activity.defn async def prepare_model_activity(payload: StepActivityInput) -> MockActivityResult: """Resolve a model resource into an order-scoped prepared-model asset.""" if payload.model_id is None: raise ValueError("prepare_model_activity requires model_id") async with get_session_factory()() as session: order, workflow_run = await load_order_and_run(session, payload.order_id, payload.workflow_run_id) step = create_step_record(payload) session.add(step) order.status = OrderStatus.RUNNING workflow_run.status = OrderStatus.RUNNING workflow_run.current_step = payload.step_name await session.flush() try: resource, original_file = await load_active_library_resource( session, payload.model_id, resource_type=LibraryResourceType.MODEL, ) garment_snapshot = None if payload.garment_asset_id is not None: garment_resource, garment_original = await load_active_library_resource( session, payload.garment_asset_id, resource_type=LibraryResourceType.GARMENT, ) garment_snapshot = build_resource_input_snapshot(garment_resource, garment_original) scene_snapshot = None if payload.scene_ref_asset_id is not None: scene_resource, scene_original = await load_active_library_resource( session, payload.scene_ref_asset_id, resource_type=LibraryResourceType.SCENE, ) scene_snapshot = build_resource_input_snapshot(scene_resource, scene_original) metadata = { **payload.metadata, "model_id": payload.model_id, "pose_id": payload.pose_id, "garment_asset_id": payload.garment_asset_id, "scene_ref_asset_id": payload.scene_ref_asset_id, "library_resource_id": resource.id, "library_original_file_id": original_file.id, "library_original_url": original_file.public_url, "library_original_mime_type": original_file.mime_type, "library_original_width": original_file.width, "library_original_height": original_file.height, "model_input": build_resource_input_snapshot(resource, original_file), "garment_input": garment_snapshot, "scene_input": scene_snapshot, "pose_input": {"pose_id": payload.pose_id} if payload.pose_id is not None else None, "normalized": False, } metadata = {key: value for key, value in metadata.items() if value is not None} asset = AssetORM( order_id=payload.order_id, asset_type=AssetType.PREPARED_MODEL, step_name=payload.step_name, uri=original_file.public_url, metadata_json=jsonable(metadata), ) session.add(asset) await session.flush() result = MockActivityResult( step_name=payload.step_name, success=True, asset_id=asset.id, uri=asset.uri, score=1.0, passed=True, message="prepared model ready", metadata=jsonable(metadata) or {}, ) step.step_status = StepStatus.SUCCEEDED step.output_json = jsonable(result) step.ended_at = utc_now() await session.commit() return result except Exception as exc: step.step_status = StepStatus.FAILED step.error_message = str(exc) step.ended_at = utc_now() order.status = OrderStatus.FAILED workflow_run.status = OrderStatus.FAILED await session.commit() raise @activity.defn async def run_tryon_activity(payload: StepActivityInput) -> MockActivityResult: """执行试衣渲染步骤,或在未接真实能力时走 mock 分支。 流程: 1. 读取当前配置的图片生成 service。 2. 如果当前仍是 mock service,就退回通用的 mock 资产产出逻辑。 3. 如果是真实模式,先读取 prepare_model_activity 产出的 prepared_model 资产。 4. 再读取本次选中的服装资源,并解析出它的原图 URL。 5. 把“模特准备图 + 服装原图”一起发给 provider 做试衣生成。 6. 把生成结果上传到 S3,并落成订单内的一条 TRYON 资产。 7. 在 metadata 里记录 provider、model、prompt 和输入来源,方便追踪排查。 """ service = get_image_generation_service() # 保留旧的 mock 分支,这样在没有接入真实 provider 时 workflow 也还能跑通。 if service.__class__.__name__ == "MockImageGenerationService": return await execute_asset_step( payload, AssetType.TRYON, extra_metadata={"prepared_asset_id": payload.source_asset_id}, ) if payload.source_asset_id is None: raise ValueError("run_tryon_activity requires source_asset_id") if payload.garment_asset_id is None: raise ValueError("run_tryon_activity requires garment_asset_id") async with get_session_factory()() as session: order, workflow_run = await load_order_and_run(session, payload.order_id, payload.workflow_run_id) step = create_step_record(payload) session.add(step) order.status = OrderStatus.RUNNING workflow_run.status = OrderStatus.RUNNING workflow_run.current_step = payload.step_name await session.flush() try: # 试衣步骤的输入起点固定是上一阶段产出的 prepared_model 资产。 prepared_asset = await session.get(AssetORM, payload.source_asset_id) if prepared_asset is None or prepared_asset.order_id != payload.order_id: raise ValueError(f"Prepared asset {payload.source_asset_id} not found for order {payload.order_id}") if prepared_asset.asset_type != AssetType.PREPARED_MODEL: raise ValueError(f"Asset {payload.source_asset_id} is not a prepared_model asset") # 服装素材来自共享资源库,workflow 实际消费的是资源库里的原图。 garment_resource, garment_original = await load_active_library_resource( session, payload.garment_asset_id, resource_type=LibraryResourceType.GARMENT, ) # provider 接收两张真实输入图:模特准备图 + 服装参考图。 generated = await service.generate_tryon_image( person_image_url=prepared_asset.uri, garment_image_url=garment_original.public_url, ) # workflow 生成出的结果图先上传到 S3,再登记成订单资产。 storage_key, public_url = await get_order_artifact_storage_service().upload_generated_image( order_id=payload.order_id, step_name=payload.step_name, image_bytes=generated.image_bytes, mime_type=generated.mime_type, ) # 记录足够的追踪信息,方便回溯这张试衣图由哪些输入和哪个 provider 产出。 metadata = { **payload.metadata, "prepared_asset_id": prepared_asset.id, "garment_resource_id": garment_resource.id, "garment_original_file_id": garment_original.id, "garment_original_url": garment_original.public_url, "provider": generated.provider, "model": generated.model, "storage_key": storage_key, "mime_type": generated.mime_type, "prompt": generated.prompt, } metadata = {key: value for key, value in metadata.items() if value is not None} asset = AssetORM( order_id=payload.order_id, asset_type=AssetType.TRYON, step_name=payload.step_name, uri=public_url, metadata_json=jsonable(metadata), ) session.add(asset) await session.flush() result = MockActivityResult( step_name=payload.step_name, success=True, asset_id=asset.id, uri=asset.uri, score=1.0, passed=True, message="try-on generated", metadata=jsonable(metadata) or {}, ) step.step_status = StepStatus.SUCCEEDED step.output_json = jsonable(result) step.ended_at = utc_now() await session.commit() return result except Exception as exc: step.step_status = StepStatus.FAILED step.error_message = str(exc) step.ended_at = utc_now() order.status = OrderStatus.FAILED workflow_run.status = OrderStatus.FAILED await session.commit() raise