"""Prepare-model and try-on activities plus shared helpers."""

from __future__ import annotations

from dataclasses import asdict, is_dataclass
from datetime import date, datetime, timezone
from enum import Enum
from typing import Any
from uuid import uuid4

from sqlalchemy import select
from sqlalchemy.orm import selectinload
from temporalio import activity

from app.application.services.image_generation_service import build_image_generation_service
from app.domain.enums import AssetType, LibraryResourceStatus, LibraryResourceType, OrderStatus, StepStatus
from app.infra.db.models.asset import AssetORM
from app.infra.db.models.library_resource import LibraryResourceORM
from app.infra.db.models.library_resource_file import LibraryResourceFileORM
from app.infra.db.models.order import OrderORM
from app.infra.db.models.workflow_run import WorkflowRunORM
from app.infra.db.models.workflow_step import WorkflowStepORM
from app.infra.db.session import get_session_factory
from app.infra.storage.s3 import S3ObjectStorageService
from app.workers.workflows.types import MockActivityResult, StepActivityInput


def utc_now() -> datetime:
    """Return the current time as a timezone-aware UTC ``datetime``."""
    return datetime.now(timezone.utc)


def jsonable(value: Any) -> Any:
    """Convert enums, dataclasses, and nested values to JSON-safe structures.

    Conversion rules, applied recursively:

    * ``None`` is returned unchanged.
    * ``Enum`` members become their ``.value``.
    * ``datetime`` and ``date`` objects become ISO-8601 strings.
    * Dataclass *instances* become dicts (via ``asdict``).
    * Dicts keep their keys; entries whose value is ``None`` are dropped.
    * Lists, tuples, sets, and frozensets become lists.
    * Anything else (including dataclass *classes*) is returned as-is.
    """
    if value is None:
        return None
    if isinstance(value, Enum):
        return value.value
    # ``datetime`` subclasses ``date``, so a single check covers both.
    if isinstance(value, date):
        return value.isoformat()
    # ``is_dataclass`` is also true for dataclass *types*, which ``asdict``
    # rejects with TypeError; only instances are converted.
    if is_dataclass(value) and not isinstance(value, type):
        return jsonable(asdict(value))
    if isinstance(value, dict):
        return {key: jsonable(item) for key, item in value.items() if item is not None}
    if isinstance(value, (list, tuple, set, frozenset)):
        return [jsonable(item) for item in value]
    return value


def mock_uri(order_id: int, step_name: str, filename: str = "result.png") -> str:
    """Build a mock URI for an order step.

    The URI has the form ``mock://orders/<order_id>/<step_name>/<token>-<filename>``
    where ``<token>`` is the first 8 hex characters of a fresh ``uuid4``, so
    two calls with the same arguments return different URIs.
    """
    token = uuid4().hex[:8]
    return f"mock://orders/{order_id}/{step_name}/{token}-{filename}"


async def load_order_and_run(session, order_id: int, workflow_run_id: int) -> tuple[OrderORM, WorkflowRunORM]:
    """Load the order and workflow run required by an activity.

    Args:
        session: Database session whose ``get(model, pk)`` is awaitable.
        order_id: Primary key of the order to load.
        workflow_run_id: Primary key of the workflow run to load.

    Returns:
        The ``(order, workflow_run)`` pair.

    Raises:
        ValueError: If either row does not exist.
    """
    order = await session.get(OrderORM, order_id)
    workflow_run = await session.get(WorkflowRunORM, workflow_run_id)
    if order is None or workflow_run is None:
        raise ValueError("Order or workflow run not found for activity execution")
    return order, workflow_run


def create_step_record(payload: StepActivityInput) -> WorkflowStepORM:
    """Create a running workflow step row for an activity execution.

    The row is only constructed here; adding it to a session and flushing
    or committing is left to the caller.

    Args:
        payload: Activity input; its ``workflow_run_id`` and ``step_name``
            identify the step, and the whole payload is stored (through
            ``jsonable``) as the step's ``input_json``.

    Returns:
        A new ``WorkflowStepORM`` with status ``RUNNING`` and ``started_at``
        set to the current UTC time.
    """
    return WorkflowStepORM(
        workflow_run_id=payload.workflow_run_id,
        step_name=payload.step_name,
        step_status=StepStatus.RUNNING,
        input_json=jsonable(payload),
        started_at=utc_now(),
    )


async def load_active_library_resource(
    session,
    resource_id: int,
    *,
    resource_type: LibraryResourceType,
) -> tuple[LibraryResourceORM, LibraryResourceFileORM]:
    """Load an active library resource and its original file from the library.

    Args:
        session: Database session whose ``execute`` is awaitable.
        resource_id: Primary key of the library resource.
        resource_type: Type the resource is required to have.

    Returns:
        The ``(resource, original_file)`` pair, where ``original_file`` is the
        entry of ``resource.files`` whose id equals
        ``resource.original_file_id``.

    Raises:
        ValueError: If the resource does not exist, has a different type, is
            not ``ACTIVE``, has no ``original_file_id``, or the referenced
            file is not among its files.
    """
    # Eager-load ``files`` with the resource so the original file can be
    # picked out below without touching the relationship lazily.
    result = await session.execute(
        select(LibraryResourceORM)
        .options(selectinload(LibraryResourceORM.files))
        .where(LibraryResourceORM.id == resource_id)
    )
    resource = result.scalar_one_or_none()
    if resource is None:
        raise ValueError(f"Library resource {resource_id} not found")
    if resource.resource_type != resource_type:
        raise ValueError(f"Resource {resource_id} is not a {resource_type.value} resource")
    if resource.status != LibraryResourceStatus.ACTIVE:
        raise ValueError(f"Library resource {resource_id} is not active")
    if resource.original_file_id is None:
        raise ValueError(f"Library resource {resource_id} is missing an original file")

    original_file = next((item for item in resource.files if item.id == resource.original_file_id), None)
    if original_file is None:
        raise ValueError(f"Library resource {resource_id} original file record not found")
    return resource, original_file


def get_image_generation_service():
    """Return the configured image-generation service.

    Thin wrapper around ``build_image_generation_service``; a new call is
    made on every invocation.
    """
    return build_image_generation_service()


def get_order_artifact_storage_service() -> S3ObjectStorageService:
    """Return the object-storage service for workflow-generated images.

    A new ``S3ObjectStorageService`` is constructed on every call.
    """
    return S3ObjectStorageService()


def build_resource_input_snapshot(
    resource: LibraryResourceORM,
    original_file: LibraryResourceFileORM,
) -> dict[str, Any]:
    """Build a frontend-friendly snapshot of one library input resource.

    Args:
        resource: Library resource supplying ``id`` and ``name``.
        original_file: The resource's original file, supplying ``id``,
            ``public_url``, ``mime_type``, ``width`` and ``height``.

    Returns:
        A flat dict with the keys ``resource_id``, ``resource_name``,
        ``original_file_id``, ``original_url``, ``mime_type``, ``width`` and
        ``height``. Values are copied as-is, including ``None``.
    """
    return {
        "resource_id": resource.id,
        "resource_name": resource.name,
        "original_file_id": original_file.id,
        "original_url": original_file.public_url,
        "mime_type": original_file.mime_type,
        "width": original_file.width,
        "height": original_file.height,
    }


async def execute_asset_step(
    payload: StepActivityInput,
    asset_type: AssetType,
    *,
    score: float = 0.95,
    filename: str = "result.png",
    message: str = "mock success",
    extra_metadata: dict[str, Any] | None = None,
    finalize: bool = False,
) -> MockActivityResult:
    """Persist a mock asset-producing step and return its result.

    Creates a RUNNING step row, marks the order and workflow run as RUNNING,
    stores one asset with a mock URI, and marks the step SUCCEEDED.

    Args:
        payload: Activity input identifying the order, run, and step.
        asset_type: Type to assign to the created asset.
        score: Score reported on the result.
        filename: File name used as the tail of the mock URI.
        message: Message reported on the result.
        extra_metadata: Extra entries merged over the payload metadata.
        finalize: When true, also set the asset as the order's final asset
            and mark the order and workflow run SUCCEEDED.

    Returns:
        The result describing the created asset.

    Raises:
        ValueError: If the order or workflow run does not exist.
        Exception: Any error raised while producing the asset is re-raised
            after a best-effort attempt to record the failure.
    """
    async with get_session_factory()() as session:
        order, workflow_run = await load_order_and_run(session, payload.order_id, payload.workflow_run_id)
        step = create_step_record(payload)
        session.add(step)

        order.status = OrderStatus.RUNNING
        workflow_run.status = OrderStatus.RUNNING
        workflow_run.current_step = payload.step_name
        await session.flush()

        try:
            # Later entries win: explicit ids override payload metadata and
            # ``extra_metadata`` overrides both. ``None`` values are dropped.
            metadata = {
                **payload.metadata,
                "source_asset_id": payload.source_asset_id,
                "selected_asset_id": payload.selected_asset_id,
                **(extra_metadata or {}),
            }
            metadata = {key: value for key, value in metadata.items() if value is not None}

            asset = AssetORM(
                order_id=payload.order_id,
                asset_type=asset_type,
                step_name=payload.step_name,
                uri=mock_uri(payload.order_id, payload.step_name.value, filename),
                metadata_json=jsonable(metadata),
            )
            session.add(asset)
            # Flush so ``asset.id`` is available for the result below.
            await session.flush()

            result = MockActivityResult(
                step_name=payload.step_name,
                success=True,
                asset_id=asset.id,
                uri=asset.uri,
                score=score,
                passed=True,
                message=message,
                metadata=jsonable(metadata) or {},
            )

            if finalize:
                order.final_asset_id = asset.id
                order.status = OrderStatus.SUCCEEDED
                workflow_run.status = OrderStatus.SUCCEEDED

            step.step_status = StepStatus.SUCCEEDED
            step.output_json = jsonable(result)
            step.ended_at = utc_now()
            await session.commit()
            return result
        except Exception as exc:
            # Failure bookkeeping is best-effort: the commit can itself fail
            # (for example when the session was invalidated by a failed
            # flush), and that must not replace the original error.
            try:
                step.step_status = StepStatus.FAILED
                step.error_message = str(exc)
                step.ended_at = utc_now()
                order.status = OrderStatus.FAILED
                workflow_run.status = OrderStatus.FAILED
                await session.commit()
            except Exception:
                activity.logger.exception(
                    "Could not record failure of step %s for order %s",
                    payload.step_name,
                    payload.order_id,
                )
            raise


@activity.defn
async def prepare_model_activity(payload: StepActivityInput) -> MockActivityResult:
    """Resolve a model resource into an order-scoped prepared-model asset.

    Loads the model library resource named by ``payload.model_id`` (plus the
    garment and scene resources when their ids are given), snapshots their
    original files into the asset metadata, and stores a ``PREPARED_MODEL``
    asset whose URI is the model's original file URL.

    Args:
        payload: Activity input; ``model_id`` is required.

    Returns:
        The result describing the created prepared-model asset.

    Raises:
        ValueError: If ``model_id`` is missing, the order or workflow run does
            not exist, or a referenced library resource fails validation.
        Exception: Any other error is re-raised after a best-effort attempt
            to record the failure.
    """
    if payload.model_id is None:
        raise ValueError("prepare_model_activity requires model_id")

    async with get_session_factory()() as session:
        order, workflow_run = await load_order_and_run(session, payload.order_id, payload.workflow_run_id)
        step = create_step_record(payload)
        session.add(step)

        order.status = OrderStatus.RUNNING
        workflow_run.status = OrderStatus.RUNNING
        workflow_run.current_step = payload.step_name
        await session.flush()

        try:
            resource, original_file = await load_active_library_resource(
                session,
                payload.model_id,
                resource_type=LibraryResourceType.MODEL,
            )

            # Garment and scene are optional inputs; each is validated and
            # snapshotted only when its id is present.
            garment_snapshot = None
            if payload.garment_asset_id is not None:
                garment_resource, garment_original = await load_active_library_resource(
                    session,
                    payload.garment_asset_id,
                    resource_type=LibraryResourceType.GARMENT,
                )
                garment_snapshot = build_resource_input_snapshot(garment_resource, garment_original)

            scene_snapshot = None
            if payload.scene_ref_asset_id is not None:
                scene_resource, scene_original = await load_active_library_resource(
                    session,
                    payload.scene_ref_asset_id,
                    resource_type=LibraryResourceType.SCENE,
                )
                scene_snapshot = build_resource_input_snapshot(scene_resource, scene_original)

            metadata = {
                **payload.metadata,
                "model_id": payload.model_id,
                "pose_id": payload.pose_id,
                "garment_asset_id": payload.garment_asset_id,
                "scene_ref_asset_id": payload.scene_ref_asset_id,
                "library_resource_id": resource.id,
                "library_original_file_id": original_file.id,
                "library_original_url": original_file.public_url,
                "library_original_mime_type": original_file.mime_type,
                "library_original_width": original_file.width,
                "library_original_height": original_file.height,
                "model_input": build_resource_input_snapshot(resource, original_file),
                "garment_input": garment_snapshot,
                "scene_input": scene_snapshot,
                "pose_input": {"pose_id": payload.pose_id} if payload.pose_id is not None else None,
                "normalized": False,
            }
            # Drop absent optional inputs so the stored metadata stays compact.
            metadata = {key: value for key, value in metadata.items() if value is not None}

            asset = AssetORM(
                order_id=payload.order_id,
                asset_type=AssetType.PREPARED_MODEL,
                step_name=payload.step_name,
                uri=original_file.public_url,
                metadata_json=jsonable(metadata),
            )
            session.add(asset)
            # Flush so ``asset.id`` is available for the result below.
            await session.flush()

            result = MockActivityResult(
                step_name=payload.step_name,
                success=True,
                asset_id=asset.id,
                uri=asset.uri,
                score=1.0,
                passed=True,
                message="prepared model ready",
                metadata=jsonable(metadata) or {},
            )

            step.step_status = StepStatus.SUCCEEDED
            step.output_json = jsonable(result)
            step.ended_at = utc_now()
            await session.commit()
            return result
        except Exception as exc:
            # Failure bookkeeping is best-effort: the commit can itself fail
            # (for example when the session was invalidated by a failed
            # flush), and that must not replace the original error.
            try:
                step.step_status = StepStatus.FAILED
                step.error_message = str(exc)
                step.ended_at = utc_now()
                order.status = OrderStatus.FAILED
                workflow_run.status = OrderStatus.FAILED
                await session.commit()
            except Exception:
                activity.logger.exception(
                    "Could not record failure of step %s for order %s",
                    payload.step_name,
                    payload.order_id,
                )
            raise


@activity.defn
async def run_tryon_activity(payload: StepActivityInput) -> MockActivityResult:
    """Run the try-on rendering step, or take the mock branch when no real provider is wired in.

    Flow:
    1. Read the currently configured image-generation service.
    2. If it is still the mock service, fall back to the generic mock
       asset-producing logic.
    3. In real mode, first load the prepared_model asset produced by
       prepare_model_activity.
    4. Then load the selected garment resource and resolve its original
       image URL.
    5. Send the prepared model image and the garment original image to the
       provider for try-on generation.
    6. Upload the generated result to S3 and record it as a TRYON asset of
       the order.
    7. Record provider, model, prompt, and input sources in the metadata so
       the result can be traced.

    Raises:
        ValueError: In real mode, if ``source_asset_id`` or
            ``garment_asset_id`` is missing, the prepared asset does not
            belong to the order or has the wrong type, or the garment
            resource fails validation.
        Exception: Any other error is re-raised after a best-effort attempt
            to record the failure.
    """
    service = get_image_generation_service()
    # Keep the legacy mock branch so the workflow can still run end to end
    # when no real provider is connected. The check is by class name.
    if service.__class__.__name__ == "MockImageGenerationService":
        return await execute_asset_step(
            payload,
            AssetType.TRYON,
            extra_metadata={"prepared_asset_id": payload.source_asset_id},
        )

    if payload.source_asset_id is None:
        raise ValueError("run_tryon_activity requires source_asset_id")
    if payload.garment_asset_id is None:
        raise ValueError("run_tryon_activity requires garment_asset_id")

    async with get_session_factory()() as session:
        order, workflow_run = await load_order_and_run(session, payload.order_id, payload.workflow_run_id)
        step = create_step_record(payload)
        session.add(step)

        order.status = OrderStatus.RUNNING
        workflow_run.status = OrderStatus.RUNNING
        workflow_run.current_step = payload.step_name
        await session.flush()

        try:
            # The try-on step always starts from the prepared_model asset
            # produced by the previous stage.
            prepared_asset = await session.get(AssetORM, payload.source_asset_id)
            if prepared_asset is None or prepared_asset.order_id != payload.order_id:
                raise ValueError(f"Prepared asset {payload.source_asset_id} not found for order {payload.order_id}")
            if prepared_asset.asset_type != AssetType.PREPARED_MODEL:
                raise ValueError(f"Asset {payload.source_asset_id} is not a prepared_model asset")

            # Garment material comes from the shared resource library; the
            # workflow consumes the library's original image.
            garment_resource, garment_original = await load_active_library_resource(
                session,
                payload.garment_asset_id,
                resource_type=LibraryResourceType.GARMENT,
            )

            # The provider receives two real input images: the prepared
            # model image and the garment reference image.
            generated = await service.generate_tryon_image(
                person_image_url=prepared_asset.uri,
                garment_image_url=garment_original.public_url,
            )
            # The generated image is uploaded to S3 first, then registered
            # as an order asset.
            storage_key, public_url = await get_order_artifact_storage_service().upload_generated_image(
                order_id=payload.order_id,
                step_name=payload.step_name,
                image_bytes=generated.image_bytes,
                mime_type=generated.mime_type,
            )

            # Record enough tracing information to tell which inputs and
            # which provider produced this try-on image.
            metadata = {
                **payload.metadata,
                "prepared_asset_id": prepared_asset.id,
                "garment_resource_id": garment_resource.id,
                "garment_original_file_id": garment_original.id,
                "garment_original_url": garment_original.public_url,
                "provider": generated.provider,
                "model": generated.model,
                "storage_key": storage_key,
                "mime_type": generated.mime_type,
                "prompt": generated.prompt,
            }
            metadata = {key: value for key, value in metadata.items() if value is not None}

            asset = AssetORM(
                order_id=payload.order_id,
                asset_type=AssetType.TRYON,
                step_name=payload.step_name,
                uri=public_url,
                metadata_json=jsonable(metadata),
            )
            session.add(asset)
            # Flush so ``asset.id`` is available for the result below.
            await session.flush()

            result = MockActivityResult(
                step_name=payload.step_name,
                success=True,
                asset_id=asset.id,
                uri=asset.uri,
                score=1.0,
                passed=True,
                message="try-on generated",
                metadata=jsonable(metadata) or {},
            )

            step.step_status = StepStatus.SUCCEEDED
            step.output_json = jsonable(result)
            step.ended_at = utc_now()
            await session.commit()
            return result
        except Exception as exc:
            # Failure bookkeeping is best-effort: the commit can itself fail
            # (for example when the session was invalidated by a failed
            # flush), and that must not replace the original error.
            try:
                step.step_status = StepStatus.FAILED
                step.error_message = str(exc)
                step.ended_at = utc_now()
                order.status = OrderStatus.FAILED
                workflow_run.status = OrderStatus.FAILED
                await session.commit()
            except Exception:
                activity.logger.exception(
                    "Could not record failure of step %s for order %s",
                    payload.step_name,
                    payload.order_id,
                )
            raise