Files
auto-virtual-tryon/app/workers/activities/tryon_activities.py

422 lines
16 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Prepare-model and try-on activities plus shared helpers."""
from __future__ import annotations
from dataclasses import asdict, is_dataclass
from datetime import datetime, timezone
from enum import Enum
from typing import Any
from uuid import uuid4
from sqlalchemy import select
from sqlalchemy.orm import selectinload
from temporalio import activity
from app.application.services.image_generation_service import build_image_generation_service
from app.domain.enums import AssetType, LibraryResourceStatus, LibraryResourceType, OrderStatus, StepStatus
from app.infra.db.models.asset import AssetORM
from app.infra.db.models.library_resource import LibraryResourceORM
from app.infra.db.models.library_resource_file import LibraryResourceFileORM
from app.infra.db.models.order import OrderORM
from app.infra.db.models.workflow_run import WorkflowRunORM
from app.infra.db.models.workflow_step import WorkflowStepORM
from app.infra.db.session import get_session_factory
from app.infra.storage.s3 import S3ObjectStorageService
from app.workers.workflows.types import MockActivityResult, StepActivityInput
def utc_now() -> datetime:
    """Get the present moment as a timezone-aware ``datetime`` in UTC."""
    current = datetime.now(tz=timezone.utc)
    return current
def jsonable(value: Any) -> Any:
    """Recursively convert a value into a JSON-serialisable structure.

    Conversion rules:

    - ``None`` is returned unchanged.
    - ``Enum`` members are replaced by their ``.value``.
    - ``datetime`` objects become ISO-8601 strings via ``isoformat()``.
    - Dataclass *instances* are converted with ``dataclasses.asdict`` and the
      result is processed recursively. Dataclass *classes* fall through and
      are returned unchanged.
    - Dicts are processed value by value. Entries whose original value is
      ``None`` are dropped. Keys are left as-is.
    - Lists, tuples and sets become lists. Set iteration order is not
      guaranteed.
    - Anything else is returned as-is.

    Args:
        value: Arbitrary value to convert.

    Returns:
        The converted value.
    """
    if value is None:
        return None
    if isinstance(value, Enum):
        return value.value
    if isinstance(value, datetime):
        return value.isoformat()
    # is_dataclass() is also true for the dataclass *type* itself, but asdict()
    # only accepts instances and raises TypeError for a class.
    if is_dataclass(value) and not isinstance(value, type):
        return jsonable(asdict(value))
    if isinstance(value, dict):
        return {key: jsonable(item) for key, item in value.items() if item is not None}
    if isinstance(value, (list, tuple, set)):
        return [jsonable(item) for item in value]
    return value
def mock_uri(order_id: int, step_name: str, filename: str = "result.png") -> str:
    """Return a fake ``mock://`` URI scoped to one order step.

    An 8-character hex token taken from ``uuid4()`` prefixes the filename, so
    repeated calls for the same order and step give distinct URIs.
    """
    token = uuid4().hex[:8]
    leaf = f"{token}-{filename}"
    return f"mock://orders/{order_id}/{step_name}/{leaf}"
async def load_order_and_run(session, order_id: int, workflow_run_id: int) -> tuple[OrderORM, WorkflowRunORM]:
    """Fetch the order row and workflow-run row that an activity operates on.

    Both rows are looked up by primary key with ``session.get``: the order
    first, then the workflow run.

    Raises:
        ValueError: If either lookup returns ``None``.
    """
    found_order = await session.get(OrderORM, order_id)
    found_run = await session.get(WorkflowRunORM, workflow_run_id)
    if found_order is not None and found_run is not None:
        return found_order, found_run
    raise ValueError("Order or workflow run not found for activity execution")
def create_step_record(payload: StepActivityInput) -> WorkflowStepORM:
    """Build an unsaved RUNNING workflow-step row for ``payload``.

    The whole payload is snapshotted into ``input_json`` through ``jsonable``.
    The start time is stamped with ``utc_now()``. The caller is responsible
    for adding the row to a session.
    """
    fields = dict(
        workflow_run_id=payload.workflow_run_id,
        step_name=payload.step_name,
        step_status=StepStatus.RUNNING,
        input_json=jsonable(payload),
        started_at=utc_now(),
    )
    return WorkflowStepORM(**fields)
async def load_active_library_resource(
    session,
    resource_id: int,
    *,
    resource_type: LibraryResourceType,
) -> tuple[LibraryResourceORM, LibraryResourceFileORM]:
    """Fetch a library resource by id and return it with its original file.

    The resource's ``files`` collection is eager-loaded with ``selectinload``.
    The original file is the entry of ``files`` whose ``id`` equals
    ``resource.original_file_id``.

    Raises:
        ValueError: If any of these holds:
            - the resource does not exist;
            - its ``resource_type`` differs from ``resource_type``;
            - it is not ``ACTIVE``;
            - it has no ``original_file_id``;
            - no loaded file record matches ``original_file_id``.
    """
    stmt = (
        select(LibraryResourceORM)
        .options(selectinload(LibraryResourceORM.files))
        .where(LibraryResourceORM.id == resource_id)
    )
    resource = (await session.execute(stmt)).scalar_one_or_none()
    if resource is None:
        raise ValueError(f"Library resource {resource_id} not found")
    if resource.resource_type != resource_type:
        raise ValueError(f"Resource {resource_id} is not a {resource_type.value} resource")
    if resource.status != LibraryResourceStatus.ACTIVE:
        raise ValueError(f"Library resource {resource_id} is not active")
    wanted_file_id = resource.original_file_id
    if wanted_file_id is None:
        raise ValueError(f"Library resource {resource_id} is missing an original file")
    for candidate in resource.files:
        if candidate.id == wanted_file_id:
            return resource, candidate
    raise ValueError(f"Library resource {resource_id} original file record not found")
def get_image_generation_service():
    """Return the image-generation service chosen by configuration.

    Delegates entirely to ``build_image_generation_service()``.
    """
    factory = build_image_generation_service
    return factory()
def get_order_artifact_storage_service() -> S3ObjectStorageService:
    """Create the S3-backed storage service used for workflow-generated images.

    A new ``S3ObjectStorageService`` instance is constructed on every call.
    """
    storage = S3ObjectStorageService()
    return storage
def build_resource_input_snapshot(
    resource: LibraryResourceORM,
    original_file: LibraryResourceFileORM,
) -> dict[str, Any]:
    """Summarise a library resource and its original file as a flat dict.

    Returns:
        A dict with keys ``resource_id``, ``resource_name``,
        ``original_file_id``, ``original_url``, ``mime_type``, ``width`` and
        ``height``, copied straight from the two ORM objects. ``None`` values
        are kept.
    """
    snapshot: dict[str, Any] = {"resource_id": resource.id, "resource_name": resource.name}
    snapshot.update(
        original_file_id=original_file.id,
        original_url=original_file.public_url,
        mime_type=original_file.mime_type,
        width=original_file.width,
        height=original_file.height,
    )
    return snapshot
async def execute_asset_step(
    payload: StepActivityInput,
    asset_type: AssetType,
    *,
    score: float = 0.95,
    filename: str = "result.png",
    message: str = "mock success",
    extra_metadata: dict[str, Any] | None = None,
    finalize: bool = False,
) -> MockActivityResult:
    """Persist a mock asset-producing step and return its result.

    Records a RUNNING workflow-step row and creates an ``AssetORM`` of
    ``asset_type``. The asset URI comes from ``mock_uri``; no image is
    generated or uploaded. The step is then marked SUCCEEDED and the session
    is committed.

    Args:
        payload: Step input. Supplies:
            - the order and workflow-run ids;
            - the step name;
            - the base metadata;
            - ``source_asset_id`` / ``selected_asset_id``, which are recorded
              in the asset metadata.
        asset_type: Type stored on the created asset row.
        score: Score copied into the returned result.
        filename: Filename suffix used when building the mock URI.
        message: Message copied into the returned result.
        extra_metadata: Extra keys merged after the payload metadata and the
            source/selected ids, so they override those keys. ``None`` values
            are dropped from the merged dict.
        finalize: When true, the new asset becomes the order's
            ``final_asset_id``, and both the order and the workflow run are
            marked SUCCEEDED. Otherwise both stay RUNNING.

    Returns:
        The ``MockActivityResult``. It is also stored (via ``jsonable``) as
        the step's ``output_json``.

    Raises:
        ValueError: If the order or workflow run does not exist. This is
            raised before any step row is created.
        Exception: Any error inside the ``try`` block is re-raised after the
            step, order and workflow run are marked FAILED and committed.
    """
    async with get_session_factory()() as session:
        order, workflow_run = await load_order_and_run(session, payload.order_id, payload.workflow_run_id)
        step = create_step_record(payload)
        session.add(step)
        order.status = OrderStatus.RUNNING
        workflow_run.status = OrderStatus.RUNNING
        workflow_run.current_step = payload.step_name
        # Send the RUNNING step row and status changes to the database inside
        # the still-open transaction. They become durable only on commit below.
        await session.flush()
        try:
            metadata = {
                **payload.metadata,
                "source_asset_id": payload.source_asset_id,
                "selected_asset_id": payload.selected_asset_id,
                **(extra_metadata or {}),
            }
            # Drop keys whose value is None (e.g. an unset source/selected id).
            metadata = {key: value for key, value in metadata.items() if value is not None}
            asset = AssetORM(
                order_id=payload.order_id,
                asset_type=asset_type,
                step_name=payload.step_name,
                uri=mock_uri(payload.order_id, payload.step_name.value, filename),
                metadata_json=jsonable(metadata),
            )
            session.add(asset)
            # Flush so ``asset.id`` is populated before it is read below
            # (assumes a database-generated primary key; confirm).
            await session.flush()
            result = MockActivityResult(
                step_name=payload.step_name,
                success=True,
                asset_id=asset.id,
                uri=asset.uri,
                score=score,
                passed=True,
                message=message,
                metadata=jsonable(metadata) or {},
            )
            if finalize:
                order.final_asset_id = asset.id
                order.status = OrderStatus.SUCCEEDED
                workflow_run.status = OrderStatus.SUCCEEDED
            step.step_status = StepStatus.SUCCEEDED
            step.output_json = jsonable(result)
            step.ended_at = utc_now()
            await session.commit()
            return result
        except Exception as exc:
            # NOTE(review): if the failure came from a flush or the commit
            # above, the session is likely already rolled back. This commit
            # may then raise as well, masking ``exc`` and leaving the FAILED
            # status unpersisted. Confirm.
            step.step_status = StepStatus.FAILED
            step.error_message = str(exc)
            step.ended_at = utc_now()
            order.status = OrderStatus.FAILED
            workflow_run.status = OrderStatus.FAILED
            await session.commit()
            raise
@activity.defn
async def prepare_model_activity(payload: StepActivityInput) -> MockActivityResult:
    """Resolve a model resource into an order-scoped prepared-model asset.

    The model library resource must pass ``load_active_library_resource``
    validation: it must exist, have the expected type, be ACTIVE and have an
    original file record. The same applies to the garment and scene library
    resources when their ids are given; otherwise ``ValueError`` is raised.

    No image processing happens here. The created ``PREPARED_MODEL`` asset
    points at the model's original file URL, and its metadata records
    ``"normalized": False``.

    On success the step is marked SUCCEEDED and committed. The order and the
    workflow run are left in RUNNING status.

    Args:
        payload: Step input.
            - ``model_id`` is required.
            - ``garment_asset_id`` and ``scene_ref_asset_id`` are optional.
              They are looked up as library resource ids of type GARMENT and
              SCENE respectively.
            - ``pose_id`` is only recorded in metadata.

    Returns:
        A ``MockActivityResult`` for the new asset with ``score=1.0``. It is
        also stored (via ``jsonable``) as the step's ``output_json``.

    Raises:
        ValueError: If ``payload.model_id`` is ``None`` (checked before any
            database access), or if the order or workflow run is missing.
        Exception: Any error inside the ``try`` block is re-raised after the
            step, order and workflow run are marked FAILED and committed.
            This includes resource validation failures.
    """
    if payload.model_id is None:
        raise ValueError("prepare_model_activity requires model_id")
    async with get_session_factory()() as session:
        order, workflow_run = await load_order_and_run(session, payload.order_id, payload.workflow_run_id)
        step = create_step_record(payload)
        session.add(step)
        order.status = OrderStatus.RUNNING
        workflow_run.status = OrderStatus.RUNNING
        workflow_run.current_step = payload.step_name
        # Send the RUNNING step row and status changes to the database inside
        # the still-open transaction. They become durable only on commit below.
        await session.flush()
        try:
            resource, original_file = await load_active_library_resource(
                session,
                payload.model_id,
                resource_type=LibraryResourceType.MODEL,
            )
            # Optional inputs are snapshotted only when an id was supplied.
            # The payload field is named *_asset_id, but it is resolved here
            # as a library resource id.
            garment_snapshot = None
            if payload.garment_asset_id is not None:
                garment_resource, garment_original = await load_active_library_resource(
                    session,
                    payload.garment_asset_id,
                    resource_type=LibraryResourceType.GARMENT,
                )
                garment_snapshot = build_resource_input_snapshot(garment_resource, garment_original)
            scene_snapshot = None
            if payload.scene_ref_asset_id is not None:
                scene_resource, scene_original = await load_active_library_resource(
                    session,
                    payload.scene_ref_asset_id,
                    resource_type=LibraryResourceType.SCENE,
                )
                scene_snapshot = build_resource_input_snapshot(scene_resource, scene_original)
            metadata = {
                **payload.metadata,
                "model_id": payload.model_id,
                "pose_id": payload.pose_id,
                "garment_asset_id": payload.garment_asset_id,
                "scene_ref_asset_id": payload.scene_ref_asset_id,
                "library_resource_id": resource.id,
                "library_original_file_id": original_file.id,
                "library_original_url": original_file.public_url,
                "library_original_mime_type": original_file.mime_type,
                "library_original_width": original_file.width,
                "library_original_height": original_file.height,
                "model_input": build_resource_input_snapshot(resource, original_file),
                "garment_input": garment_snapshot,
                "scene_input": scene_snapshot,
                "pose_input": {"pose_id": payload.pose_id} if payload.pose_id is not None else None,
                "normalized": False,
            }
            # Drop keys whose value is None (e.g. absent optional inputs).
            metadata = {key: value for key, value in metadata.items() if value is not None}
            asset = AssetORM(
                order_id=payload.order_id,
                asset_type=AssetType.PREPARED_MODEL,
                step_name=payload.step_name,
                # The prepared model reuses the library original image URL as-is.
                uri=original_file.public_url,
                metadata_json=jsonable(metadata),
            )
            session.add(asset)
            # Flush so ``asset.id`` is populated before it is read below
            # (assumes a database-generated primary key; confirm).
            await session.flush()
            result = MockActivityResult(
                step_name=payload.step_name,
                success=True,
                asset_id=asset.id,
                uri=asset.uri,
                score=1.0,
                passed=True,
                message="prepared model ready",
                metadata=jsonable(metadata) or {},
            )
            step.step_status = StepStatus.SUCCEEDED
            step.output_json = jsonable(result)
            step.ended_at = utc_now()
            await session.commit()
            return result
        except Exception as exc:
            # NOTE(review): if the failure came from a flush or the commit
            # above, the session is likely already rolled back. This commit
            # may then raise as well, masking ``exc`` and leaving the FAILED
            # status unpersisted. Confirm.
            step.step_status = StepStatus.FAILED
            step.error_message = str(exc)
            step.ended_at = utc_now()
            order.status = OrderStatus.FAILED
            workflow_run.status = OrderStatus.FAILED
            await session.commit()
            raise
@activity.defn
async def run_tryon_activity(payload: StepActivityInput) -> MockActivityResult:
    """Run the try-on rendering step, or take the mock branch when no real provider is wired up.

    Flow:
    1. Get the currently configured image-generation service.
    2. If it is still the mock service, fall back to the generic mock
       asset-producing logic (``execute_asset_step``).
    3. In real mode, first load the ``prepared_model`` asset produced by
       ``prepare_model_activity``.
    4. Then load the selected garment library resource and resolve its
       original image URL.
    5. Send the prepared model image and the garment original image together
       to the provider for try-on generation.
    6. Upload the generated image to S3 and record it as a TRYON asset on the
       order.
    7. Record the provider, model, prompt and input sources in metadata, for
       tracing and debugging.

    Args:
        payload: Step input.
            - ``source_asset_id`` must be a PREPARED_MODEL asset of this order.
            - ``garment_asset_id`` is resolved as a GARMENT library resource id.
            - Both are required in real mode only.

    Returns:
        A ``MockActivityResult`` for the new TRYON asset. On the real path it
        has ``score=1.0``; the mock path uses ``execute_asset_step``'s
        defaults.

    Raises:
        ValueError: In real mode, if ``source_asset_id`` or
            ``garment_asset_id`` is ``None``. These checks happen before any
            step row is created. Also raised if the order or workflow run is
            missing.
        Exception: Any error inside the ``try`` block is re-raised after the
            step, order and workflow run are marked FAILED and committed.
            This includes asset/resource validation, provider and upload
            errors.
    """
    service = get_image_generation_service()
    # Keep the legacy mock branch so the workflow still runs end to end when
    # no real provider is connected.
    # NOTE(review): mock detection compares the class name as a string, so a
    # renamed mock class or a subclass of it would not take this branch.
    if service.__class__.__name__ == "MockImageGenerationService":
        return await execute_asset_step(
            payload,
            AssetType.TRYON,
            extra_metadata={"prepared_asset_id": payload.source_asset_id},
        )
    if payload.source_asset_id is None:
        raise ValueError("run_tryon_activity requires source_asset_id")
    if payload.garment_asset_id is None:
        raise ValueError("run_tryon_activity requires garment_asset_id")
    async with get_session_factory()() as session:
        order, workflow_run = await load_order_and_run(session, payload.order_id, payload.workflow_run_id)
        step = create_step_record(payload)
        session.add(step)
        order.status = OrderStatus.RUNNING
        workflow_run.status = OrderStatus.RUNNING
        workflow_run.current_step = payload.step_name
        # Send the RUNNING step row and status changes to the database inside
        # the still-open transaction. They become durable only on commit below.
        await session.flush()
        try:
            # The try-on step always starts from the prepared_model asset
            # produced by the previous stage, and it must belong to this order.
            prepared_asset = await session.get(AssetORM, payload.source_asset_id)
            if prepared_asset is None or prepared_asset.order_id != payload.order_id:
                raise ValueError(f"Prepared asset {payload.source_asset_id} not found for order {payload.order_id}")
            if prepared_asset.asset_type != AssetType.PREPARED_MODEL:
                raise ValueError(f"Asset {payload.source_asset_id} is not a prepared_model asset")
            # Garment material comes from the shared resource library. The
            # workflow consumes the original image stored in the library.
            garment_resource, garment_original = await load_active_library_resource(
                session,
                payload.garment_asset_id,
                resource_type=LibraryResourceType.GARMENT,
            )
            # The provider receives two real input images: the prepared model
            # image and the garment reference image.
            generated = await service.generate_tryon_image(
                person_image_url=prepared_asset.uri,
                garment_image_url=garment_original.public_url,
            )
            # Upload the generated image to S3 first, then register it as an
            # order asset.
            # NOTE(review): the upload is not undone if the later database work
            # fails, so a failed flush or commit can leave an orphaned object.
            # Also, ``step_name`` is passed as-is here, while the mock path uses
            # ``payload.step_name.value``. Confirm that upload_generated_image
            # accepts the raw value.
            storage_key, public_url = await get_order_artifact_storage_service().upload_generated_image(
                order_id=payload.order_id,
                step_name=payload.step_name,
                image_bytes=generated.image_bytes,
                mime_type=generated.mime_type,
            )
            # Record enough tracing information to tell which inputs and which
            # provider produced this try-on image.
            metadata = {
                **payload.metadata,
                "prepared_asset_id": prepared_asset.id,
                "garment_resource_id": garment_resource.id,
                "garment_original_file_id": garment_original.id,
                "garment_original_url": garment_original.public_url,
                "provider": generated.provider,
                "model": generated.model,
                "storage_key": storage_key,
                "mime_type": generated.mime_type,
                "prompt": generated.prompt,
            }
            # Drop keys whose value is None.
            metadata = {key: value for key, value in metadata.items() if value is not None}
            asset = AssetORM(
                order_id=payload.order_id,
                asset_type=AssetType.TRYON,
                step_name=payload.step_name,
                uri=public_url,
                metadata_json=jsonable(metadata),
            )
            session.add(asset)
            # Flush so ``asset.id`` is populated before it is read below
            # (assumes a database-generated primary key; confirm).
            await session.flush()
            result = MockActivityResult(
                step_name=payload.step_name,
                success=True,
                asset_id=asset.id,
                uri=asset.uri,
                score=1.0,
                passed=True,
                message="try-on generated",
                metadata=jsonable(metadata) or {},
            )
            step.step_status = StepStatus.SUCCEEDED
            step.output_json = jsonable(result)
            step.ended_at = utc_now()
            await session.commit()
            return result
        except Exception as exc:
            # NOTE(review): if the failure came from a flush or the commit
            # above, the session is likely already rolled back. This commit
            # may then raise as well, masking ``exc`` and leaving the FAILED
            # status unpersisted. Confirm.
            step.step_status = StepStatus.FAILED
            step.error_message = str(exc)
            step.ended_at = utc_now()
            order.status = OrderStatus.FAILED
            workflow_run.status = OrderStatus.FAILED
            await session.commit()
            raise