feat: add resource library and real image workflow

This commit is contained in:
afei A
2026-03-29 00:24:29 +08:00
parent eeaff269eb
commit 04da401ab4
38 changed files with 3033 additions and 117 deletions

View File

@@ -0,0 +1,89 @@
"""Library resource routes."""
from fastapi import APIRouter, Depends, Query, status
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.schemas.library import (
ArchiveLibraryResourceResponse,
CreateLibraryResourceRequest,
LibraryResourceListResponse,
LibraryResourceRead,
PresignUploadRequest,
PresignUploadResponse,
UpdateLibraryResourceRequest,
)
from app.application.services.library_service import LibraryService
from app.domain.enums import LibraryResourceType
from app.infra.db.session import get_db_session
router = APIRouter(prefix="/library", tags=["library"])
library_service = LibraryService()
@router.post("/uploads/presign", response_model=PresignUploadResponse)
async def create_library_upload_presign(payload: PresignUploadRequest) -> PresignUploadResponse:
"""Create direct-upload metadata for a library file."""
return library_service.create_upload_presign(
resource_type=payload.resource_type,
file_name=payload.file_name,
content_type=payload.content_type,
)
@router.post("/resources", response_model=LibraryResourceRead, status_code=status.HTTP_201_CREATED)
async def create_library_resource(
payload: CreateLibraryResourceRequest,
session: AsyncSession = Depends(get_db_session),
) -> LibraryResourceRead:
"""Create a library resource with already-uploaded file metadata."""
return await library_service.create_resource(session, payload)
@router.get("/resources", response_model=LibraryResourceListResponse)
async def list_library_resources(
resource_type: LibraryResourceType | None = Query(default=None),
query: str | None = Query(default=None, min_length=1),
gender: str | None = Query(default=None),
age_group: str | None = Query(default=None),
environment: str | None = Query(default=None),
category: str | None = Query(default=None),
page: int = Query(default=1, ge=1),
limit: int = Query(default=20, ge=1, le=100),
session: AsyncSession = Depends(get_db_session),
) -> LibraryResourceListResponse:
"""List library resources with basic filtering."""
return await library_service.list_resources(
session,
resource_type=resource_type,
query=query,
gender=gender,
age_group=age_group,
environment=environment,
category=category,
page=page,
limit=limit,
)
@router.patch("/resources/{resource_id}", response_model=LibraryResourceRead)
async def update_library_resource(
resource_id: int,
payload: UpdateLibraryResourceRequest,
session: AsyncSession = Depends(get_db_session),
) -> LibraryResourceRead:
"""Update editable metadata on a library resource."""
return await library_service.update_resource(session, resource_id, payload)
@router.delete("/resources/{resource_id}", response_model=ArchiveLibraryResourceResponse)
async def archive_library_resource(
resource_id: int,
session: AsyncSession = Depends(get_db_session),
) -> ArchiveLibraryResourceResponse:
"""Archive a library resource instead of hard deleting it."""
return await library_service.archive_resource(session, resource_id)

118
app/api/schemas/library.py Normal file
View File

@@ -0,0 +1,118 @@
"""Library resource API schemas."""
from datetime import datetime
from pydantic import BaseModel, Field
from app.domain.enums import LibraryFileRole, LibraryResourceStatus, LibraryResourceType
class PresignUploadRequest(BaseModel):
    """Request payload for generating direct-upload metadata."""

    resource_type: LibraryResourceType  # drives the storage-key prefix used for the upload
    file_role: LibraryFileRole  # NOTE(review): not forwarded by the presign route — confirm intent
    content_type: str  # MIME type the client will send with the PUT
    file_name: str  # original client-side file name
class PresignUploadResponse(BaseModel):
    """Response returned when generating direct-upload metadata."""

    method: str  # HTTP verb the client must use (service sets "PUT")
    upload_url: str  # presigned URL to upload the file to
    headers: dict[str, str] = Field(default_factory=dict)  # headers that must accompany the upload
    storage_key: str  # object key to echo back when registering the resource
    public_url: str  # publicly readable URL of the uploaded object
class CreateLibraryResourceFileRequest(BaseModel):
    """Metadata for one already-uploaded file."""

    file_role: LibraryFileRole  # exactly one ORIGINAL and one THUMBNAIL per resource (validated server-side)
    storage_key: str  # must start with the prefix matching the resource type
    public_url: str
    mime_type: str
    size_bytes: int
    sort_order: int = 0  # display ordering within the resource
    width: int | None = None  # pixel dimensions, if known
    height: int | None = None
class CreateLibraryResourceRequest(BaseModel):
    """One-shot resource creation request.

    Type-specific fields are conditionally required and enforced by the
    service: MODEL needs gender + age_group, SCENE needs environment,
    GARMENT needs category.
    """

    resource_type: LibraryResourceType
    name: str
    description: str | None = None
    tags: list[str] = Field(default_factory=list)
    gender: str | None = None  # MODEL only
    age_group: str | None = None  # MODEL only
    pose_id: int | None = None  # MODEL only
    environment: str | None = None  # SCENE only
    category: str | None = None  # GARMENT only
    files: list[CreateLibraryResourceFileRequest]  # at least one file required
class UpdateLibraryResourceRequest(BaseModel):
    """Partial update request for a library resource.

    Every field defaults to None, which means "leave unchanged"; the
    service only applies fields that were explicitly provided.
    """

    name: str | None = None
    description: str | None = None
    tags: list[str] | None = None
    gender: str | None = None  # applied only on MODEL resources
    age_group: str | None = None  # applied only on MODEL resources
    pose_id: int | None = None  # applied only on MODEL resources
    environment: str | None = None  # applied only on SCENE resources
    category: str | None = None  # applied only on GARMENT resources
    cover_file_id: int | None = None  # must reference a file belonging to the resource
class ArchiveLibraryResourceResponse(BaseModel):
    """Response returned after archiving a resource."""

    id: int  # identifier of the archived resource
class LibraryResourceFileRead(BaseModel):
    """Serialized library file metadata."""

    id: int
    file_role: LibraryFileRole
    storage_key: str
    public_url: str
    bucket: str  # S3 bucket the object lives in
    mime_type: str
    size_bytes: int
    sort_order: int
    width: int | None = None
    height: int | None = None
    created_at: datetime
class LibraryResourceRead(BaseModel):
    """Serialized library resource."""

    id: int
    resource_type: LibraryResourceType
    name: str
    description: str | None = None
    tags: list[str]
    status: LibraryResourceStatus
    gender: str | None = None  # populated for MODEL resources
    age_group: str | None = None  # populated for MODEL resources
    pose_id: int | None = None  # populated for MODEL resources
    environment: str | None = None  # populated for SCENE resources
    category: str | None = None  # populated for GARMENT resources
    cover_url: str | None = None  # public URL of the designated cover file, if resolvable
    original_url: str | None = None  # public URL of the original-role file, if resolvable
    files: list[LibraryResourceFileRead]
    created_at: datetime
    updated_at: datetime
class LibraryResourceListResponse(BaseModel):
    """Paginated library resource response."""

    total: int  # total matching rows, independent of the current page
    items: list[LibraryResourceRead]  # the current page of results

View File

@@ -14,9 +14,9 @@ class CreateOrderRequest(BaseModel):
customer_level: CustomerLevel
service_mode: ServiceMode
model_id: int
pose_id: int
pose_id: int | None = None
garment_asset_id: int
scene_ref_asset_id: int
scene_ref_asset_id: int | None = None
class CreateOrderResponse(BaseModel):
@@ -35,9 +35,9 @@ class OrderDetailResponse(BaseModel):
service_mode: ServiceMode
status: OrderStatus
model_id: int
pose_id: int
pose_id: int | None
garment_asset_id: int
scene_ref_asset_id: int
scene_ref_asset_id: int | None
final_asset_id: int | None
workflow_id: str | None
current_step: WorkflowStepName | None

View File

@@ -0,0 +1,163 @@
"""Application-facing orchestration for image-generation providers."""
from __future__ import annotations
import base64
import httpx
from fastapi import HTTPException, status
from app.config.settings import get_settings
from app.infra.image_generation.base import GeneratedImageResult, SourceImage
from app.infra.image_generation.gemini_provider import GeminiImageProvider
# Prompt used for the try-on step; the provider receives two inline images
# in order: [person image, garment image].
TRYON_PROMPT = (
    "Use the first image as the base person image and the second image as the garment reference. "
    "Generate a realistic virtual try-on result where the person in the first image is wearing the garment from "
    "the second image. Preserve the person's identity, body proportions, pose, camera angle, lighting direction, "
    "and background composition from the first image. Preserve the garment's key silhouette, material feel, color, "
    "and visible design details from the second image. Produce a clean e-commerce quality result without adding "
    "extra accessories, text, watermarks, or duplicate limbs."
)
# Prompt used for the scene-composite step; the provider receives two inline
# images in order: [finished subject image, target scene reference].
SCENE_PROMPT = (
    "Use the first image as the finished subject image and the second image as the target scene reference. "
    "Generate a realistic composite where the person, clothing, body proportions, facial identity, pose, and camera "
    "framing from the first image are preserved, while the environment is adapted to match the second image. "
    "Keep the garment details, colors, and texture from the first image intact. Blend the subject naturally into the "
    "target scene with coherent perspective, lighting, shadows, and depth. Do not add extra people, accessories, "
    "text, logos, watermarks, or duplicate body parts."
)
# Base64 of a tiny 1x1 PNG, returned by the mock provider so callers always
# receive valid image bytes without hitting a real API.
MOCK_PNG_BASE64 = (
    "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mP8/x8AAwMCAO5Ww4QAAAAASUVORK5CYII="
)
class MockImageGenerationService:
    """Deterministic fallback used in tests and when no real provider is configured."""

    @staticmethod
    def _mock_result(prompt: str) -> GeneratedImageResult:
        # Decode the 1x1 PNG placeholder on each call so every caller gets
        # a fresh bytes object; only the prompt varies between the two APIs.
        return GeneratedImageResult(
            image_bytes=base64.b64decode(MOCK_PNG_BASE64),
            mime_type="image/png",
            provider="mock",
            model="mock-image-provider",
            prompt=prompt,
        )

    async def generate_tryon_image(self, *, person_image_url: str, garment_image_url: str) -> GeneratedImageResult:
        """Return a canned try-on result; the input URLs are intentionally ignored."""
        return self._mock_result(TRYON_PROMPT)

    async def generate_scene_image(self, *, source_image_url: str, scene_image_url: str) -> GeneratedImageResult:
        """Return a canned scene-composite result; the input URLs are intentionally ignored."""
        return self._mock_result(SCENE_PROMPT)
class ImageGenerationService:
    """Download source assets and dispatch them to the configured provider."""

    def __init__(
        self,
        *,
        provider: GeminiImageProvider | None = None,
        downloader: httpx.AsyncClient | None = None,
    ) -> None:
        """Build the service from settings.

        Args:
            provider: Optional provider override (injected in tests);
                defaults to a Gemini provider configured from settings.
            downloader: Optional shared HTTP client for fetching the source
                images; when None a one-shot client is created per request.
        """
        self.settings = get_settings()
        self.provider = provider or GeminiImageProvider(
            api_key=self.settings.gemini_api_key,
            base_url=self.settings.gemini_base_url,
            model=self.settings.gemini_model,
            timeout_seconds=self.settings.gemini_timeout_seconds,
            max_attempts=self.settings.gemini_max_attempts,
        )
        self._downloader = downloader

    def _require_api_key(self) -> None:
        """Fail fast with HTTP 500 when no Gemini API key is configured.

        Extracted so both generation entry points share one guard instead of
        duplicating the check and the error message.
        """
        if not self.settings.gemini_api_key:
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="GEMINI_API_KEY is required when image_generation_provider=gemini",
            )

    async def generate_tryon_image(
        self,
        *,
        person_image_url: str,
        garment_image_url: str,
    ) -> GeneratedImageResult:
        """Generate a try-on image using the configured provider.

        Raises:
            HTTPException: 500 when the Gemini API key is missing.
            httpx.HTTPStatusError: when a source image download fails.
        """
        self._require_api_key()
        person_image, garment_image = await self._download_inputs(
            first_image_url=person_image_url,
            second_image_url=garment_image_url,
        )
        return await self.provider.generate_tryon_image(
            prompt=TRYON_PROMPT,
            person_image=person_image,
            garment_image=garment_image,
        )

    async def generate_scene_image(
        self,
        *,
        source_image_url: str,
        scene_image_url: str,
    ) -> GeneratedImageResult:
        """Generate a scene-composited image using the configured provider.

        Raises:
            HTTPException: 500 when the Gemini API key is missing.
            httpx.HTTPStatusError: when a source image download fails.
        """
        self._require_api_key()
        source_image, scene_image = await self._download_inputs(
            first_image_url=source_image_url,
            second_image_url=scene_image_url,
        )
        return await self.provider.generate_scene_image(
            prompt=SCENE_PROMPT,
            source_image=source_image,
            scene_image=scene_image,
        )

    async def _download_inputs(
        self,
        *,
        first_image_url: str,
        second_image_url: str,
    ) -> tuple[SourceImage, SourceImage]:
        """Fetch both source images and wrap them as SourceImage values.

        Uses the injected client when available, otherwise creates and closes
        a one-shot client. Response bodies are read eagerly by ``client.get``,
        so ``.content`` stays valid after the client is closed.
        """
        owns_client = self._downloader is None
        client = self._downloader or httpx.AsyncClient(timeout=self.settings.gemini_timeout_seconds)
        try:
            first_response = await client.get(first_image_url)
            first_response.raise_for_status()
            second_response = await client.get(second_image_url)
            second_response.raise_for_status()
        finally:
            # Only close clients we created ourselves; an injected client is
            # owned by the caller.
            if owns_client:
                await client.aclose()
        # Strip any ";charset=..." suffix and fall back to image/png when the
        # server sent no usable content type.
        first_mime = first_response.headers.get("content-type", "image/png").split(";")[0].strip()
        second_mime = second_response.headers.get("content-type", "image/png").split(";")[0].strip()
        return (
            SourceImage(url=first_image_url, mime_type=first_mime or "image/png", data=first_response.content),
            SourceImage(url=second_image_url, mime_type=second_mime or "image/png", data=second_response.content),
        )
def build_image_generation_service():
    """Return the configured image-generation service.

    Returns a MockImageGenerationService for "mock" and an
    ImageGenerationService for "gemini"; any other value is a server-side
    misconfiguration and raises HTTP 500.
    """
    settings = get_settings()
    # Normalize once instead of calling .lower() in every branch.
    provider = settings.image_generation_provider.lower()
    if provider == "mock":
        return MockImageGenerationService()
    if provider == "gemini":
        return ImageGenerationService()
    raise HTTPException(
        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        detail=f"Unsupported image_generation_provider: {settings.image_generation_provider}",
    )

View File

@@ -0,0 +1,366 @@
"""Library resource application service."""
from __future__ import annotations
from fastapi import HTTPException, status
from sqlalchemy import func, or_, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from app.api.schemas.library import (
ArchiveLibraryResourceResponse,
CreateLibraryResourceRequest,
LibraryResourceFileRead,
LibraryResourceListResponse,
LibraryResourceRead,
PresignUploadResponse,
UpdateLibraryResourceRequest,
)
from app.config.settings import get_settings
from app.domain.enums import LibraryFileRole, LibraryResourceStatus, LibraryResourceType
from app.infra.db.models.library_resource import LibraryResourceORM
from app.infra.db.models.library_resource_file import LibraryResourceFileORM
from app.infra.storage.s3 import RESOURCE_PREFIXES, S3PresignService
class LibraryService:
    """Application service for resource-library uploads and queries."""

    def __init__(self, presign_service: S3PresignService | None = None) -> None:
        # Injectable for tests; defaults to the real S3 presign implementation.
        self.presign_service = presign_service or S3PresignService()

    def create_upload_presign(
        self,
        resource_type: LibraryResourceType,
        file_name: str,
        content_type: str,
    ) -> PresignUploadResponse:
        """Create upload metadata for a direct S3 PUT."""
        settings = get_settings()
        # Refuse to hand out presigned URLs when any S3 setting is missing —
        # otherwise the client would receive an upload URL that cannot work.
        if not all(
            [
                settings.s3_bucket,
                settings.s3_region,
                settings.s3_access_key,
                settings.s3_secret_key,
                settings.s3_endpoint,
            ]
        ):
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="S3 upload settings are incomplete",
            )
        storage_key, upload_url = self.presign_service.create_upload(
            resource_type=resource_type,
            file_name=file_name,
            content_type=content_type,
        )
        return PresignUploadResponse(
            method="PUT",
            upload_url=upload_url,
            # The client must PUT with the same content type declared here.
            headers={"content-type": content_type},
            storage_key=storage_key,
            public_url=self.presign_service.get_public_url(storage_key),
        )

    async def create_resource(
        self,
        session: AsyncSession,
        payload: CreateLibraryResourceRequest,
    ) -> LibraryResourceRead:
        """Persist a resource and its uploaded file metadata.

        Raises HTTP 400 via _validate_payload when the payload is invalid.
        """
        self._validate_payload(payload)
        resource = LibraryResourceORM(
            resource_type=payload.resource_type,
            # Free-text fields are trimmed; empty optionals stay None.
            name=payload.name.strip(),
            description=payload.description.strip() if payload.description else None,
            tags=[tag.strip() for tag in payload.tags if tag.strip()],
            status=LibraryResourceStatus.ACTIVE,
            gender=payload.gender.strip() if payload.gender else None,
            age_group=payload.age_group.strip() if payload.age_group else None,
            pose_id=payload.pose_id,
            environment=payload.environment.strip() if payload.environment else None,
            category=payload.category.strip() if payload.category else None,
        )
        session.add(resource)
        # Flush so resource.id is assigned before the file rows reference it.
        await session.flush()
        file_models: list[LibraryResourceFileORM] = []
        for item in payload.files:
            file_model = LibraryResourceFileORM(
                resource_id=resource.id,
                file_role=item.file_role,
                storage_key=item.storage_key,
                public_url=item.public_url,
                # All files are stored in the single configured bucket.
                bucket=get_settings().s3_bucket,
                mime_type=item.mime_type,
                size_bytes=item.size_bytes,
                sort_order=item.sort_order,
                width=item.width,
                height=item.height,
            )
            session.add(file_model)
            file_models.append(file_model)
        # Second flush assigns file IDs so the original/cover pointers can be set.
        await session.flush()
        # _validate_payload guarantees exactly one ORIGINAL and one THUMBNAIL,
        # so the next(...) lookups in _find_role cannot fail here.
        resource.original_file_id = self._find_role(file_models, LibraryFileRole.ORIGINAL).id
        resource.cover_file_id = self._find_role(file_models, LibraryFileRole.THUMBNAIL).id
        await session.commit()
        # Pass the in-memory file list to avoid a lazy reload after commit.
        return self._to_read(resource, files=file_models)

    async def list_resources(
        self,
        session: AsyncSession,
        *,
        resource_type: LibraryResourceType | None = None,
        query: str | None = None,
        gender: str | None = None,
        age_group: str | None = None,
        environment: str | None = None,
        category: str | None = None,
        page: int = 1,
        limit: int = 20,
    ) -> LibraryResourceListResponse:
        """List persisted resources with simple filter support."""
        # Archived resources are always excluded.
        filters = [LibraryResourceORM.status == LibraryResourceStatus.ACTIVE]
        if resource_type is not None:
            filters.append(LibraryResourceORM.resource_type == resource_type)
        if query:
            # Case-insensitive substring match on name or description.
            like = f"%{query.strip()}%"
            filters.append(
                or_(
                    LibraryResourceORM.name.ilike(like),
                    LibraryResourceORM.description.ilike(like),
                )
            )
        if gender:
            filters.append(LibraryResourceORM.gender == gender)
        if age_group:
            filters.append(LibraryResourceORM.age_group == age_group)
        if environment:
            filters.append(LibraryResourceORM.environment == environment)
        if category:
            filters.append(LibraryResourceORM.category == category)
        # Total count uses the same filters but no pagination.
        total = (
            await session.execute(select(func.count(LibraryResourceORM.id)).where(*filters))
        ).scalar_one()
        result = await session.execute(
            select(LibraryResourceORM)
            .options(selectinload(LibraryResourceORM.files))
            .where(*filters)
            # id desc as a tiebreaker keeps the ordering stable across pages.
            .order_by(LibraryResourceORM.created_at.desc(), LibraryResourceORM.id.desc())
            .offset((page - 1) * limit)
            .limit(limit)
        )
        items = result.scalars().all()
        return LibraryResourceListResponse(total=total, items=[self._to_read(item) for item in items])

    async def update_resource(
        self,
        session: AsyncSession,
        resource_id: int,
        payload: UpdateLibraryResourceRequest,
    ) -> LibraryResourceRead:
        """Update editable metadata on a library resource.

        Fields left as None in the payload are not touched. Type-specific
        fields are only applied when they match the resource's type.
        """
        resource = await self._get_resource_or_404(session, resource_id)
        if payload.name is not None:
            name = payload.name.strip()
            if not name:
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail="Resource name is required",
                )
            resource.name = name
        if payload.description is not None:
            # An all-whitespace description clears the field.
            resource.description = payload.description.strip() or None
        if payload.tags is not None:
            resource.tags = [tag.strip() for tag in payload.tags if tag.strip()]
        if resource.resource_type == LibraryResourceType.MODEL:
            if payload.gender is not None:
                resource.gender = payload.gender.strip() or None
            if payload.age_group is not None:
                resource.age_group = payload.age_group.strip() or None
            # NOTE(review): pose_id cannot be cleared back to None through
            # this API because None means "unchanged" — confirm intent.
            if payload.pose_id is not None:
                resource.pose_id = payload.pose_id
        elif resource.resource_type == LibraryResourceType.SCENE:
            if payload.environment is not None:
                resource.environment = payload.environment.strip() or None
        elif resource.resource_type == LibraryResourceType.GARMENT:
            if payload.category is not None:
                resource.category = payload.category.strip() or None
        if payload.cover_file_id is not None:
            # The new cover must be one of the resource's own files.
            cover_file = next(
                (file for file in resource.files if file.id == payload.cover_file_id),
                None,
            )
            if cover_file is None:
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail="Cover file does not belong to the resource",
                )
            resource.cover_file_id = cover_file.id
        # Re-check invariants (e.g. MODEL still has gender/age_group) before
        # committing the mutated row.
        self._validate_existing_resource(resource)
        await session.commit()
        await session.refresh(resource, attribute_names=["files"])
        return self._to_read(resource)

    async def archive_resource(
        self,
        session: AsyncSession,
        resource_id: int,
    ) -> ArchiveLibraryResourceResponse:
        """Soft delete a library resource by archiving it."""
        resource = await self._get_resource_or_404(session, resource_id)
        resource.status = LibraryResourceStatus.ARCHIVED
        await session.commit()
        return ArchiveLibraryResourceResponse(id=resource.id)

    def _validate_payload(self, payload: CreateLibraryResourceRequest) -> None:
        # All failures are reported as HTTP 400 with a specific message.
        if not payload.name.strip():
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Resource name is required")
        if not payload.files:
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="At least one file is required")
        originals = [file for file in payload.files if file.file_role == LibraryFileRole.ORIGINAL]
        thumbnails = [file for file in payload.files if file.file_role == LibraryFileRole.THUMBNAIL]
        # create_resource relies on exactly one of each role existing.
        if len(originals) != 1:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Exactly one original file is required",
            )
        if len(thumbnails) != 1:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Exactly one thumbnail file is required",
            )
        # Every uploaded key must live under the prefix reserved for this
        # resource type, so files cannot be attached across types.
        expected_prefix = f"library/{RESOURCE_PREFIXES[payload.resource_type]}/"
        for file in payload.files:
            if not file.storage_key.startswith(expected_prefix):
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail="Uploaded file key does not match resource type",
                )
        # Type-specific required metadata.
        if payload.resource_type == LibraryResourceType.MODEL:
            if not payload.gender or not payload.age_group:
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail="Model resources require gender and age_group",
                )
        elif payload.resource_type == LibraryResourceType.SCENE:
            if not payload.environment:
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail="Scene resources require environment",
                )
        elif payload.resource_type == LibraryResourceType.GARMENT and not payload.category:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Garment resources require category",
            )

    def _find_role(
        self,
        files: list[LibraryResourceFileORM],
        role: LibraryFileRole,
    ) -> LibraryResourceFileORM:
        # Caller must guarantee the role exists (see _validate_payload);
        # otherwise next() raises StopIteration.
        return next(file for file in files if file.file_role == role)

    def _to_read(
        self,
        resource: LibraryResourceORM,
        *,
        files: list[LibraryResourceFileORM] | None = None,
    ) -> LibraryResourceRead:
        # `files` lets callers supply an in-memory list (e.g. right after
        # commit) instead of triggering a relationship load.
        resource_files = files if files is not None else list(resource.files)
        files_sorted = sorted(resource_files, key=lambda item: (item.sort_order, item.id))
        # Resolve the cover/original pointers to concrete files; either may
        # be missing, in which case the URL is None.
        cover = next((item for item in files_sorted if item.id == resource.cover_file_id), None)
        original = next((item for item in files_sorted if item.id == resource.original_file_id), None)
        return LibraryResourceRead(
            id=resource.id,
            resource_type=resource.resource_type,
            name=resource.name,
            description=resource.description,
            tags=resource.tags,
            status=resource.status,
            gender=resource.gender,
            age_group=resource.age_group,
            pose_id=resource.pose_id,
            environment=resource.environment,
            category=resource.category,
            cover_url=cover.public_url if cover else None,
            original_url=original.public_url if original else None,
            files=[
                LibraryResourceFileRead(
                    id=file.id,
                    file_role=file.file_role,
                    storage_key=file.storage_key,
                    public_url=file.public_url,
                    bucket=file.bucket,
                    mime_type=file.mime_type,
                    size_bytes=file.size_bytes,
                    sort_order=file.sort_order,
                    width=file.width,
                    height=file.height,
                    created_at=file.created_at,
                )
                for file in files_sorted
            ],
            created_at=resource.created_at,
            updated_at=resource.updated_at,
        )

    async def _get_resource_or_404(
        self,
        session: AsyncSession,
        resource_id: int,
    ) -> LibraryResourceORM:
        # Eager-load files so later access never triggers lazy IO.
        result = await session.execute(
            select(LibraryResourceORM)
            .options(selectinload(LibraryResourceORM.files))
            .where(LibraryResourceORM.id == resource_id)
        )
        resource = result.scalar_one_or_none()
        if resource is None:
            raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Resource not found")
        return resource

    def _validate_existing_resource(self, resource: LibraryResourceORM) -> None:
        # Mirrors the type-specific rules of _validate_payload, applied to a
        # mutated ORM row before commit.
        if not resource.name.strip():
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Resource name is required")
        if resource.resource_type == LibraryResourceType.MODEL:
            if not resource.gender or not resource.age_group:
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail="Model resources require gender and age_group",
                )
        elif resource.resource_type == LibraryResourceType.SCENE:
            if not resource.environment:
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail="Scene resources require environment",
                )
        elif resource.resource_type == LibraryResourceType.GARMENT and not resource.category:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Garment resources require category",
            )

View File

@@ -1,9 +1,12 @@
"""Application settings."""
from functools import lru_cache
from pathlib import Path
from pydantic_settings import BaseSettings, SettingsConfigDict
PROJECT_ROOT = Path(__file__).resolve().parents[2]
class Settings(BaseSettings):
"""Runtime settings loaded from environment variables."""
@@ -15,11 +18,25 @@ class Settings(BaseSettings):
database_url: str = "sqlite+aiosqlite:///./temporal_demo.db"
temporal_address: str = "localhost:7233"
temporal_namespace: str = "default"
s3_access_key: str = ""
s3_secret_key: str = ""
s3_bucket: str = ""
s3_region: str = ""
s3_endpoint: str = ""
s3_cname: str = ""
s3_presign_expiry_seconds: int = 900
image_generation_provider: str = "mock"
gemini_api_key: str = ""
gemini_base_url: str = "https://api.museidea.com/v1beta"
gemini_model: str = "gemini-3.1-flash-image-preview"
gemini_timeout_seconds: int = 300
gemini_max_attempts: int = 2
model_config = SettingsConfigDict(
env_file=".env",
env_file=PROJECT_ROOT / ".env",
env_file_encoding="utf-8",
case_sensitive=False,
extra="ignore",
)
@property
@@ -34,4 +51,3 @@ def get_settings() -> Settings:
"""Return the cached application settings."""
return Settings()

View File

@@ -82,3 +82,26 @@ class AssetType(str, Enum):
QC_CANDIDATE = "qc_candidate"
MANUAL_REVISION = "manual_revision"
FINAL = "final"
class LibraryResourceType(str, Enum):
    """Supported resource-library item types."""

    MODEL = "model"  # person/model imagery; carries gender, age_group, pose_id metadata
    SCENE = "scene"  # background/scene references; carries environment metadata
    GARMENT = "garment"  # clothing items; carries category metadata
class LibraryFileRole(str, Enum):
    """Supported file roles within a library resource."""

    ORIGINAL = "original"  # the full-size source file; exactly one per resource
    THUMBNAIL = "thumbnail"  # the default cover image; exactly one per resource
    GALLERY = "gallery"  # additional display images; any number allowed
class LibraryResourceStatus(str, Enum):
    """Lifecycle state for a library resource."""

    ACTIVE = "active"  # visible in listings
    ARCHIVED = "archived"  # soft-deleted; excluded from listings

View File

@@ -15,12 +15,11 @@ class Order:
service_mode: ServiceMode
status: OrderStatus
model_id: int
pose_id: int
pose_id: int | None
garment_asset_id: int
scene_ref_asset_id: int
scene_ref_asset_id: int | None
final_asset_id: int | None
workflow_id: str | None
current_step: WorkflowStepName | None
created_at: datetime
updated_at: datetime

View File

@@ -0,0 +1,45 @@
"""Library resource ORM model."""
from __future__ import annotations
from sqlalchemy import Enum, Integer, JSON, String, Text
from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.domain.enums import LibraryResourceStatus, LibraryResourceType
from app.infra.db.base import Base, TimestampMixin
class LibraryResourceORM(TimestampMixin, Base):
    """Persisted library resource independent from order-generated assets."""

    __tablename__ = "library_resources"

    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
    # Stored as a plain string (native_enum=False) for portability across DBs.
    resource_type: Mapped[LibraryResourceType] = mapped_column(
        Enum(LibraryResourceType, native_enum=False),
        nullable=False,
        index=True,
    )
    name: Mapped[str] = mapped_column(String(255), nullable=False)
    description: Mapped[str | None] = mapped_column(Text, nullable=True)
    # Free-form string tags serialized as a JSON array.
    tags: Mapped[list[str]] = mapped_column(JSON, nullable=False, default=list)
    status: Mapped[LibraryResourceStatus] = mapped_column(
        Enum(LibraryResourceStatus, native_enum=False),
        nullable=False,
        default=LibraryResourceStatus.ACTIVE,
        index=True,
    )
    # Type-specific metadata: gender/age_group/pose_id apply to MODEL,
    # environment to SCENE, category to GARMENT; unused ones stay NULL.
    gender: Mapped[str | None] = mapped_column(String(32), nullable=True, index=True)
    age_group: Mapped[str | None] = mapped_column(String(32), nullable=True, index=True)
    pose_id: Mapped[int | None] = mapped_column(Integer, nullable=True)
    environment: Mapped[str | None] = mapped_column(String(32), nullable=True, index=True)
    category: Mapped[str | None] = mapped_column(String(128), nullable=True, index=True)
    # NOTE(review): these point at library_resource_files.id but are plain
    # Integer columns, not ForeignKeys — confirm that is intentional (likely
    # to avoid a circular FK between the two tables).
    cover_file_id: Mapped[int | None] = mapped_column(Integer, nullable=True)
    original_file_id: Mapped[int | None] = mapped_column(Integer, nullable=True)
    # Files are eagerly loaded and deleted along with the resource row.
    files = relationship(
        "LibraryResourceFileORM",
        back_populates="resource",
        lazy="selectin",
        cascade="all, delete-orphan",
    )

View File

@@ -0,0 +1,37 @@
"""Library resource file ORM model."""
from __future__ import annotations
from sqlalchemy import Enum, ForeignKey, Integer, String
from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.domain.enums import LibraryFileRole
from app.infra.db.base import Base, TimestampMixin
class LibraryResourceFileORM(TimestampMixin, Base):
    """Persisted uploaded file metadata for a library resource."""

    __tablename__ = "library_resource_files"

    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
    # Owning resource; files are cascade-deleted with it (see the parent model).
    resource_id: Mapped[int] = mapped_column(
        ForeignKey("library_resources.id"),
        nullable=False,
        index=True,
    )
    # Stored as a plain string (native_enum=False) for portability across DBs.
    file_role: Mapped[LibraryFileRole] = mapped_column(
        Enum(LibraryFileRole, native_enum=False),
        nullable=False,
        index=True,
    )
    storage_key: Mapped[str] = mapped_column(String(500), nullable=False)  # S3 object key
    public_url: Mapped[str] = mapped_column(String(500), nullable=False)
    bucket: Mapped[str] = mapped_column(String(255), nullable=False)  # bucket at upload time
    mime_type: Mapped[str] = mapped_column(String(255), nullable=False)
    size_bytes: Mapped[int] = mapped_column(Integer, nullable=False)
    sort_order: Mapped[int] = mapped_column(Integer, nullable=False, default=0)  # display ordering
    # Pixel dimensions, when known at registration time.
    width: Mapped[int | None] = mapped_column(Integer, nullable=True)
    height: Mapped[int | None] = mapped_column(Integer, nullable=True)
    resource = relationship("LibraryResourceORM", back_populates="files")

View File

@@ -27,12 +27,11 @@ class OrderORM(TimestampMixin, Base):
default=OrderStatus.CREATED,
)
model_id: Mapped[int] = mapped_column(Integer, nullable=False)
pose_id: Mapped[int] = mapped_column(Integer, nullable=False)
pose_id: Mapped[int | None] = mapped_column(Integer, nullable=True)
garment_asset_id: Mapped[int] = mapped_column(Integer, nullable=False)
scene_ref_asset_id: Mapped[int] = mapped_column(Integer, nullable=False)
scene_ref_asset_id: Mapped[int | None] = mapped_column(Integer, nullable=True)
final_asset_id: Mapped[int | None] = mapped_column(Integer, nullable=True)
assets = relationship("AssetORM", back_populates="order", lazy="selectin")
review_tasks = relationship("ReviewTaskORM", back_populates="order", lazy="selectin")
workflow_runs = relationship("WorkflowRunORM", back_populates="order", lazy="selectin")

View File

@@ -44,12 +44,14 @@ async def init_database() -> None:
"""Create database tables when running the MVP without migrations."""
from app.infra.db.models.asset import AssetORM
from app.infra.db.models.library_resource import LibraryResourceORM
from app.infra.db.models.library_resource_file import LibraryResourceFileORM
from app.infra.db.models.order import OrderORM
from app.infra.db.models.review_task import ReviewTaskORM
from app.infra.db.models.workflow_run import WorkflowRunORM
from app.infra.db.models.workflow_step import WorkflowStepORM
del AssetORM, OrderORM, ReviewTaskORM, WorkflowRunORM, WorkflowStepORM
del AssetORM, LibraryResourceORM, LibraryResourceFileORM, OrderORM, ReviewTaskORM, WorkflowRunORM, WorkflowStepORM
async with get_async_engine().begin() as connection:
await connection.run_sync(Base.metadata.create_all)

View File

@@ -0,0 +1,38 @@
"""Shared types for image-generation providers."""
from __future__ import annotations
from dataclasses import dataclass
from typing import Protocol
@dataclass(slots=True)
class SourceImage:
    """A binary source image passed into an image-generation provider."""

    url: str  # where the image was downloaded from (kept for traceability)
    mime_type: str  # e.g. "image/png"; providers inline it alongside the bytes
    data: bytes  # raw image bytes
@dataclass(slots=True)
class GeneratedImageResult:
    """Normalized image-generation output returned by providers."""

    image_bytes: bytes  # decoded image payload
    mime_type: str  # MIME type reported by the provider (falls back to image/png)
    provider: str  # provider identifier, e.g. "gemini" or "mock"
    model: str  # model name used for the generation
    prompt: str  # full prompt that produced the image
class ImageGenerationProvider(Protocol):
    """Contract implemented by concrete image-generation providers.

    Structural typing only: any object exposing this coroutine with
    matching keyword-only parameters satisfies the protocol.
    """

    async def generate_tryon_image(
        self,
        *,
        prompt: str,
        person_image: SourceImage,
        garment_image: SourceImage,
    ) -> GeneratedImageResult: ...

View File

@@ -0,0 +1,149 @@
"""Gemini-backed image-generation provider."""
from __future__ import annotations
import base64
import httpx
from app.infra.image_generation.base import GeneratedImageResult, SourceImage
class GeminiImageProvider:
"""Call the Gemini image-generation API via a configurable REST endpoint."""
provider_name = "gemini"
    def __init__(
        self,
        *,
        api_key: str,
        base_url: str,
        model: str,
        timeout_seconds: int,
        max_attempts: int = 2,
        http_client: httpx.AsyncClient | None = None,
    ) -> None:
        """Configure the provider.

        Args:
            api_key: Gemini API key, sent via the x-goog-api-key header.
            base_url: REST endpoint root; a trailing slash is stripped so
                URL composition is uniform.
            model: Model identifier inserted into the request path.
            timeout_seconds: Per-request timeout for the HTTP call.
            max_attempts: Retry budget for transport errors; clamped to >= 1.
            http_client: Optional shared client; when None a one-shot client
                is created (and closed) per generation call.
        """
        self.api_key = api_key
        self.base_url = base_url.rstrip("/")
        self.model = model
        self.timeout_seconds = timeout_seconds
        # Guard against a zero/negative setting disabling requests entirely.
        self.max_attempts = max(1, max_attempts)
        self._http_client = http_client
    async def generate_tryon_image(
        self,
        *,
        prompt: str,
        person_image: SourceImage,
        garment_image: SourceImage,
    ) -> GeneratedImageResult:
        """Generate a try-on image from a prepared person image and a garment image.

        Image order matters: the prompt refers to the person image as the
        "first image" and the garment as the "second image".
        """
        return await self._generate_two_image_edit(
            prompt=prompt,
            first_image=person_image,
            second_image=garment_image,
        )
async def generate_scene_image(
self,
*,
prompt: str,
source_image: SourceImage,
scene_image: SourceImage,
) -> GeneratedImageResult:
"""Generate a scene-composited image from a rendered subject image and a scene reference."""
return await self._generate_two_image_edit(
prompt=prompt,
first_image=source_image,
second_image=scene_image,
)
async def _generate_two_image_edit(
self,
*,
prompt: str,
first_image: SourceImage,
second_image: SourceImage,
) -> GeneratedImageResult:
"""Generate an edited image from a prompt plus two inline image inputs."""
payload = {
"contents": [
{
"parts": [
{"text": prompt},
{
"inline_data": {
"mime_type": first_image.mime_type,
"data": base64.b64encode(first_image.data).decode("utf-8"),
}
},
{
"inline_data": {
"mime_type": second_image.mime_type,
"data": base64.b64encode(second_image.data).decode("utf-8"),
}
},
]
}
],
"generationConfig": {
"responseModalities": ["TEXT", "IMAGE"],
},
}
owns_client = self._http_client is None
client = self._http_client or httpx.AsyncClient(timeout=self.timeout_seconds)
try:
response = None
for attempt in range(1, self.max_attempts + 1):
try:
response = await client.post(
f"{self.base_url}/models/{self.model}:generateContent",
headers={
"x-goog-api-key": self.api_key,
"Content-Type": "application/json",
},
json=payload,
timeout=self.timeout_seconds,
)
response.raise_for_status()
break
except httpx.TransportError:
if attempt >= self.max_attempts:
raise
finally:
if owns_client:
await client.aclose()
if response is None:
raise RuntimeError("Gemini provider did not receive a response")
body = response.json()
image_part = self._find_image_part(body)
image_data = image_part.get("inlineData") or image_part.get("inline_data")
mime_type = image_data.get("mimeType") or image_data.get("mime_type") or "image/png"
data = image_data.get("data")
if not data:
raise ValueError("Gemini response did not include image bytes")
return GeneratedImageResult(
image_bytes=base64.b64decode(data),
mime_type=mime_type,
provider=self.provider_name,
model=self.model,
prompt=prompt,
)
@staticmethod
def _find_image_part(body: dict) -> dict:
candidates = body.get("candidates") or []
for candidate in candidates:
content = candidate.get("content") or {}
for part in content.get("parts") or []:
if part.get("inlineData") or part.get("inline_data"):
return part
raise ValueError("Gemini response did not contain an image part")

107
app/infra/storage/s3.py Normal file
View File

@@ -0,0 +1,107 @@
"""S3 direct-upload helpers."""
from __future__ import annotations

import asyncio
from pathlib import Path
from uuid import uuid4

import boto3

from app.config.settings import get_settings
from app.domain.enums import LibraryResourceType, WorkflowStepName
# Storage-key prefix (under "library/") used for each uploadable resource type.
RESOURCE_PREFIXES: dict[LibraryResourceType, str] = {
    LibraryResourceType.MODEL: "models",
    LibraryResourceType.SCENE: "scenes",
    LibraryResourceType.GARMENT: "garments",
}
class S3PresignService:
    """Generate presigned upload URLs and derived public URLs."""

    def __init__(self) -> None:
        """Build an S3 client from application settings (empty values become None)."""
        self.settings = get_settings()
        cfg = self.settings
        self._client = boto3.client(
            "s3",
            region_name=cfg.s3_region or None,
            endpoint_url=cfg.s3_endpoint or None,
            aws_access_key_id=cfg.s3_access_key or None,
            aws_secret_access_key=cfg.s3_secret_key or None,
        )

    def create_upload(self, resource_type: LibraryResourceType, file_name: str, content_type: str) -> tuple[str, str]:
        """Return a storage key and presigned PUT URL for a resource file."""
        key = self._build_storage_key(resource_type, file_name)
        request_params = {
            "Bucket": self.settings.s3_bucket,
            "Key": key,
            "ContentType": content_type,
        }
        url = self._client.generate_presigned_url(
            "put_object",
            Params=request_params,
            ExpiresIn=self.settings.s3_presign_expiry_seconds,
            HttpMethod="PUT",
        )
        return key, url

    def get_public_url(self, storage_key: str) -> str:
        """Return the public CDN URL for an uploaded object."""
        cname = self.settings.s3_cname
        if cname:
            # Accept bare hostnames as well as full URLs; default to https.
            base = cname if cname.startswith(("http://", "https://")) else f"https://{cname}"
            return f"{base.rstrip('/')}/{storage_key}"
        # No CNAME configured: fall back to the raw endpoint/bucket path.
        return f"{self.settings.s3_endpoint.rstrip('/')}/{self.settings.s3_bucket}/{storage_key}"

    def _build_storage_key(self, resource_type: LibraryResourceType, file_name: str) -> str:
        """Build a unique, URL-safe key under the per-type library prefix."""
        name = Path(file_name)
        extension = (name.suffix or ".bin").lower()
        slug = name.stem.replace(" ", "-").lower() or "file"
        return f"library/{RESOURCE_PREFIXES[resource_type]}/{uuid4().hex}-{slug}{extension}"
class S3ObjectStorageService:
    """Upload generated workflow artifacts to the configured object store."""

    def __init__(self) -> None:
        """Build an S3 client from application settings (empty values become None)."""
        self.settings = get_settings()
        self._client = boto3.client(
            "s3",
            region_name=self.settings.s3_region or None,
            endpoint_url=self.settings.s3_endpoint or None,
            aws_access_key_id=self.settings.s3_access_key or None,
            aws_secret_access_key=self.settings.s3_secret_key or None,
        )
        # Reused solely for deriving public URLs from storage keys.
        self._presign = S3PresignService()

    async def upload_generated_image(
        self,
        *,
        order_id: int,
        step_name: WorkflowStepName,
        image_bytes: bytes,
        mime_type: str,
    ) -> tuple[str, str]:
        """Upload bytes and return the storage key plus public URL.

        boto3's ``put_object`` is synchronous; run it in a worker thread so the
        upload does not block the event loop that drives workflow activities.
        """
        storage_key = self._build_storage_key(order_id=order_id, step_name=step_name, mime_type=mime_type)
        await asyncio.to_thread(
            self._client.put_object,
            Bucket=self.settings.s3_bucket,
            Key=storage_key,
            Body=image_bytes,
            ContentType=mime_type,
        )
        return storage_key, self._presign.get_public_url(storage_key)

    @staticmethod
    def _build_storage_key(*, order_id: int, step_name: WorkflowStepName, mime_type: str) -> str:
        """Derive an order-scoped key with a file extension matching *mime_type*."""
        # Unknown content types fall back to an opaque ".bin" suffix.
        suffix = {
            "image/png": ".png",
            "image/jpeg": ".jpg",
            "image/webp": ".webp",
        }.get(mime_type, ".bin")
        return f"orders/{order_id}/{step_name.value}/{uuid4().hex}{suffix}"

View File

@@ -6,6 +6,7 @@ from fastapi import FastAPI
from app.api.routers.assets import router as assets_router
from app.api.routers.health import router as health_router
from app.api.routers.library import router as library_router
from app.api.routers.orders import router as orders_router
from app.api.routers.revisions import router as revisions_router
from app.api.routers.reviews import router as reviews_router
@@ -30,6 +31,7 @@ def create_app() -> FastAPI:
settings = get_settings()
app = FastAPI(title=settings.app_name, debug=settings.debug, lifespan=lifespan)
app.include_router(health_router)
app.include_router(library_router, prefix=settings.api_prefix)
app.include_router(orders_router, prefix=settings.api_prefix)
app.include_router(assets_router, prefix=settings.api_prefix)
app.include_router(revisions_router, prefix=settings.api_prefix)

View File

@@ -1,20 +1,82 @@
"""Export mock activity."""
"""Export activity."""
from app.domain.enums import AssetType, OrderStatus, StepStatus
from app.infra.db.models.asset import AssetORM
from app.infra.db.session import get_session_factory
from temporalio import activity
from app.domain.enums import AssetType
from app.workers.activities.tryon_activities import execute_asset_step
from app.workers.activities.tryon_activities import create_step_record, jsonable, load_order_and_run, utc_now
from app.workers.workflows.types import MockActivityResult, StepActivityInput
@activity.defn
async def run_export_activity(payload: StepActivityInput) -> MockActivityResult:
"""Mock final asset export."""
"""Finalize the chosen source asset as the order's exported deliverable."""
return await execute_asset_step(
payload,
AssetType.FINAL,
filename="final.png",
finalize=True,
)
if payload.source_asset_id is None:
raise ValueError("run_export_activity requires source_asset_id")
async with get_session_factory()() as session:
order, workflow_run = await load_order_and_run(session, payload.order_id, payload.workflow_run_id)
step = create_step_record(payload)
session.add(step)
order.status = OrderStatus.RUNNING
workflow_run.status = OrderStatus.RUNNING
workflow_run.current_step = payload.step_name
await session.flush()
try:
source_asset = await session.get(AssetORM, payload.source_asset_id)
if source_asset is None:
raise ValueError(f"Source asset {payload.source_asset_id} not found")
if source_asset.order_id != payload.order_id:
raise ValueError(
f"Source asset {payload.source_asset_id} does not belong to order {payload.order_id}"
)
metadata = {
**payload.metadata,
"source_asset_id": payload.source_asset_id,
"selected_asset_id": payload.selected_asset_id,
}
metadata = {key: value for key, value in metadata.items() if value is not None}
asset = AssetORM(
order_id=payload.order_id,
asset_type=AssetType.FINAL,
step_name=payload.step_name,
uri=source_asset.uri,
metadata_json=jsonable(metadata),
)
session.add(asset)
await session.flush()
result = MockActivityResult(
step_name=payload.step_name,
success=True,
asset_id=asset.id,
uri=asset.uri,
score=0.95,
passed=True,
message="mock success",
metadata=jsonable(metadata) or {},
)
order.final_asset_id = asset.id
order.status = OrderStatus.SUCCEEDED
workflow_run.status = OrderStatus.SUCCEEDED
step.step_status = StepStatus.SUCCEEDED
step.output_json = jsonable(result)
step.ended_at = utc_now()
await session.commit()
return result
except Exception as exc:
step.step_status = StepStatus.FAILED
step.error_message = str(exc)
step.ended_at = utc_now()
order.status = OrderStatus.FAILED
workflow_run.status = OrderStatus.FAILED
await session.commit()
raise

View File

@@ -29,11 +29,22 @@ async def run_qc_activity(payload: StepActivityInput) -> MockActivityResult:
candidate_uri: str | None = None
if passed:
if payload.source_asset_id is None:
raise ValueError("run_qc_activity requires source_asset_id")
source_asset = await session.get(AssetORM, payload.source_asset_id)
if source_asset is None:
raise ValueError(f"Source asset {payload.source_asset_id} not found")
if source_asset.order_id != payload.order_id:
raise ValueError(
f"Source asset {payload.source_asset_id} does not belong to order {payload.order_id}"
)
candidate = AssetORM(
order_id=payload.order_id,
asset_type=AssetType.QC_CANDIDATE,
step_name=payload.step_name,
uri=mock_uri(payload.order_id, payload.step_name.value, "candidate.png"),
uri=source_asset.uri,
metadata_json=jsonable({"source_asset_id": payload.source_asset_id}),
)
session.add(candidate)

View File

@@ -1,19 +1,122 @@
"""Scene mock activity."""
"""Scene activities."""
from temporalio import activity
from app.domain.enums import AssetType
from app.workers.activities.tryon_activities import execute_asset_step
from app.domain.enums import AssetType, LibraryResourceType, OrderStatus, StepStatus
from app.infra.db.models.asset import AssetORM
from app.infra.db.session import get_session_factory
from app.workers.activities.tryon_activities import (
create_step_record,
execute_asset_step,
get_image_generation_service,
get_order_artifact_storage_service,
jsonable,
load_active_library_resource,
load_order_and_run,
utc_now,
)
from app.workers.workflows.types import MockActivityResult, StepActivityInput
@activity.defn
async def run_scene_activity(payload: StepActivityInput) -> MockActivityResult:
"""Mock scene replacement."""
"""Generate a scene-composited asset, or fall back to mock mode when configured."""
return await execute_asset_step(
payload,
AssetType.SCENE,
extra_metadata={"scene_ref_asset_id": payload.scene_ref_asset_id},
)
service = get_image_generation_service()
if service.__class__.__name__ == "MockImageGenerationService":
return await execute_asset_step(
payload,
AssetType.SCENE,
extra_metadata={"scene_ref_asset_id": payload.scene_ref_asset_id},
)
if payload.source_asset_id is None:
raise ValueError("run_scene_activity requires source_asset_id")
if payload.scene_ref_asset_id is None:
raise ValueError("run_scene_activity requires scene_ref_asset_id")
async with get_session_factory()() as session:
order, workflow_run = await load_order_and_run(session, payload.order_id, payload.workflow_run_id)
step = create_step_record(payload)
session.add(step)
order.status = OrderStatus.RUNNING
workflow_run.status = OrderStatus.RUNNING
workflow_run.current_step = payload.step_name
await session.flush()
try:
source_asset = await session.get(AssetORM, payload.source_asset_id)
if source_asset is None:
raise ValueError(f"Source asset {payload.source_asset_id} not found")
if source_asset.order_id != payload.order_id:
raise ValueError(
f"Source asset {payload.source_asset_id} does not belong to order {payload.order_id}"
)
scene_resource, scene_original = await load_active_library_resource(
session,
payload.scene_ref_asset_id,
resource_type=LibraryResourceType.SCENE,
)
generated = await service.generate_scene_image(
source_image_url=source_asset.uri,
scene_image_url=scene_original.public_url,
)
storage_key, public_url = await get_order_artifact_storage_service().upload_generated_image(
order_id=payload.order_id,
step_name=payload.step_name,
image_bytes=generated.image_bytes,
mime_type=generated.mime_type,
)
metadata = {
**payload.metadata,
"source_asset_id": source_asset.id,
"scene_ref_asset_id": payload.scene_ref_asset_id,
"scene_resource_id": scene_resource.id,
"scene_original_file_id": scene_original.id,
"scene_original_url": scene_original.public_url,
"provider": generated.provider,
"model": generated.model,
"storage_key": storage_key,
"mime_type": generated.mime_type,
"prompt": generated.prompt,
}
metadata = {key: value for key, value in metadata.items() if value is not None}
asset = AssetORM(
order_id=payload.order_id,
asset_type=AssetType.SCENE,
step_name=payload.step_name,
uri=public_url,
metadata_json=jsonable(metadata),
)
session.add(asset)
await session.flush()
result = MockActivityResult(
step_name=payload.step_name,
success=True,
asset_id=asset.id,
uri=asset.uri,
score=1.0,
passed=True,
message="scene generated",
metadata=jsonable(metadata) or {},
)
step.step_status = StepStatus.SUCCEEDED
step.output_json = jsonable(result)
step.ended_at = utc_now()
await session.commit()
return result
except Exception as exc:
step.step_status = StepStatus.FAILED
step.error_message = str(exc)
step.ended_at = utc_now()
order.status = OrderStatus.FAILED
workflow_run.status = OrderStatus.FAILED
await session.commit()
raise

View File

@@ -1,4 +1,4 @@
"""Prepare-model and try-on mock activities plus shared helpers."""
"""Prepare-model and try-on activities plus shared helpers."""
from __future__ import annotations
@@ -8,14 +8,20 @@ from enum import Enum
from typing import Any
from uuid import uuid4
from sqlalchemy import select
from sqlalchemy.orm import selectinload
from temporalio import activity
from app.domain.enums import AssetType, OrderStatus, StepStatus
from app.application.services.image_generation_service import build_image_generation_service
from app.domain.enums import AssetType, LibraryResourceStatus, LibraryResourceType, OrderStatus, StepStatus
from app.infra.db.models.asset import AssetORM
from app.infra.db.models.library_resource import LibraryResourceORM
from app.infra.db.models.library_resource_file import LibraryResourceFileORM
from app.infra.db.models.order import OrderORM
from app.infra.db.models.workflow_run import WorkflowRunORM
from app.infra.db.models.workflow_step import WorkflowStepORM
from app.infra.db.session import get_session_factory
from app.infra.storage.s3 import S3ObjectStorageService
from app.workers.workflows.types import MockActivityResult, StepActivityInput
@@ -71,6 +77,64 @@ def create_step_record(payload: StepActivityInput) -> WorkflowStepORM:
)
async def load_active_library_resource(
session,
resource_id: int,
*,
resource_type: LibraryResourceType,
) -> tuple[LibraryResourceORM, LibraryResourceFileORM]:
"""Load an active library resource and its original file from the library."""
result = await session.execute(
select(LibraryResourceORM)
.options(selectinload(LibraryResourceORM.files))
.where(LibraryResourceORM.id == resource_id)
)
resource = result.scalar_one_or_none()
if resource is None:
raise ValueError(f"Library resource {resource_id} not found")
if resource.resource_type != resource_type:
raise ValueError(f"Resource {resource_id} is not a {resource_type.value} resource")
if resource.status != LibraryResourceStatus.ACTIVE:
raise ValueError(f"Library resource {resource_id} is not active")
if resource.original_file_id is None:
raise ValueError(f"Library resource {resource_id} is missing an original file")
original_file = next((item for item in resource.files if item.id == resource.original_file_id), None)
if original_file is None:
raise ValueError(f"Library resource {resource_id} original file record not found")
return resource, original_file
def get_image_generation_service():
"""Return the configured image-generation service."""
return build_image_generation_service()
def get_order_artifact_storage_service() -> S3ObjectStorageService:
"""Return the object-storage service for workflow-generated images."""
return S3ObjectStorageService()
def build_resource_input_snapshot(
resource: LibraryResourceORM,
original_file: LibraryResourceFileORM,
) -> dict[str, Any]:
"""Build a frontend-friendly snapshot of one library input resource."""
return {
"resource_id": resource.id,
"resource_name": resource.name,
"original_file_id": original_file.id,
"original_url": original_file.public_url,
"mime_type": original_file.mime_type,
"width": original_file.width,
"height": original_file.height,
}
async def execute_asset_step(
payload: StepActivityInput,
asset_type: AssetType,
@@ -145,26 +209,213 @@ async def execute_asset_step(
@activity.defn
async def prepare_model_activity(payload: StepActivityInput) -> MockActivityResult:
"""Mock model preparation for the pipeline."""
"""Resolve a model resource into an order-scoped prepared-model asset."""
return await execute_asset_step(
payload,
AssetType.PREPARED_MODEL,
extra_metadata={
"model_id": payload.model_id,
"pose_id": payload.pose_id,
"garment_asset_id": payload.garment_asset_id,
"scene_ref_asset_id": payload.scene_ref_asset_id,
},
)
if payload.model_id is None:
raise ValueError("prepare_model_activity requires model_id")
async with get_session_factory()() as session:
order, workflow_run = await load_order_and_run(session, payload.order_id, payload.workflow_run_id)
step = create_step_record(payload)
session.add(step)
order.status = OrderStatus.RUNNING
workflow_run.status = OrderStatus.RUNNING
workflow_run.current_step = payload.step_name
await session.flush()
try:
resource, original_file = await load_active_library_resource(
session,
payload.model_id,
resource_type=LibraryResourceType.MODEL,
)
garment_snapshot = None
if payload.garment_asset_id is not None:
garment_resource, garment_original = await load_active_library_resource(
session,
payload.garment_asset_id,
resource_type=LibraryResourceType.GARMENT,
)
garment_snapshot = build_resource_input_snapshot(garment_resource, garment_original)
scene_snapshot = None
if payload.scene_ref_asset_id is not None:
scene_resource, scene_original = await load_active_library_resource(
session,
payload.scene_ref_asset_id,
resource_type=LibraryResourceType.SCENE,
)
scene_snapshot = build_resource_input_snapshot(scene_resource, scene_original)
metadata = {
**payload.metadata,
"model_id": payload.model_id,
"pose_id": payload.pose_id,
"garment_asset_id": payload.garment_asset_id,
"scene_ref_asset_id": payload.scene_ref_asset_id,
"library_resource_id": resource.id,
"library_original_file_id": original_file.id,
"library_original_url": original_file.public_url,
"library_original_mime_type": original_file.mime_type,
"library_original_width": original_file.width,
"library_original_height": original_file.height,
"model_input": build_resource_input_snapshot(resource, original_file),
"garment_input": garment_snapshot,
"scene_input": scene_snapshot,
"pose_input": {"pose_id": payload.pose_id} if payload.pose_id is not None else None,
"normalized": False,
}
metadata = {key: value for key, value in metadata.items() if value is not None}
asset = AssetORM(
order_id=payload.order_id,
asset_type=AssetType.PREPARED_MODEL,
step_name=payload.step_name,
uri=original_file.public_url,
metadata_json=jsonable(metadata),
)
session.add(asset)
await session.flush()
result = MockActivityResult(
step_name=payload.step_name,
success=True,
asset_id=asset.id,
uri=asset.uri,
score=1.0,
passed=True,
message="prepared model ready",
metadata=jsonable(metadata) or {},
)
step.step_status = StepStatus.SUCCEEDED
step.output_json = jsonable(result)
step.ended_at = utc_now()
await session.commit()
return result
except Exception as exc:
step.step_status = StepStatus.FAILED
step.error_message = str(exc)
step.ended_at = utc_now()
order.status = OrderStatus.FAILED
workflow_run.status = OrderStatus.FAILED
await session.commit()
raise
@activity.defn
async def run_tryon_activity(payload: StepActivityInput) -> MockActivityResult:
"""Mock try-on rendering."""
"""执行试衣渲染步骤,或在未接真实能力时走 mock 分支。
return await execute_asset_step(
payload,
AssetType.TRYON,
extra_metadata={"prepared_asset_id": payload.source_asset_id},
)
流程:
1. 读取当前配置的图片生成 service。
2. 如果当前仍是 mock service就退回通用的 mock 资产产出逻辑。
3. 如果是真实模式,先读取 prepare_model_activity 产出的 prepared_model 资产。
4. 再读取本次选中的服装资源,并解析出它的原图 URL。
5. 把“模特准备图 + 服装原图”一起发给 provider 做试衣生成。
6. 把生成结果上传到 S3并落成订单内的一条 TRYON 资产。
7. 在 metadata 里记录 provider、model、prompt 和输入来源,方便追踪排查。
"""
service = get_image_generation_service()
# 保留旧的 mock 分支,这样在没有接入真实 provider 时 workflow 也还能跑通。
if service.__class__.__name__ == "MockImageGenerationService":
return await execute_asset_step(
payload,
AssetType.TRYON,
extra_metadata={"prepared_asset_id": payload.source_asset_id},
)
if payload.source_asset_id is None:
raise ValueError("run_tryon_activity requires source_asset_id")
if payload.garment_asset_id is None:
raise ValueError("run_tryon_activity requires garment_asset_id")
async with get_session_factory()() as session:
order, workflow_run = await load_order_and_run(session, payload.order_id, payload.workflow_run_id)
step = create_step_record(payload)
session.add(step)
order.status = OrderStatus.RUNNING
workflow_run.status = OrderStatus.RUNNING
workflow_run.current_step = payload.step_name
await session.flush()
try:
# 试衣步骤的输入起点固定是上一阶段产出的 prepared_model 资产。
prepared_asset = await session.get(AssetORM, payload.source_asset_id)
if prepared_asset is None or prepared_asset.order_id != payload.order_id:
raise ValueError(f"Prepared asset {payload.source_asset_id} not found for order {payload.order_id}")
if prepared_asset.asset_type != AssetType.PREPARED_MODEL:
raise ValueError(f"Asset {payload.source_asset_id} is not a prepared_model asset")
# 服装素材来自共享资源库workflow 实际消费的是资源库里的原图。
garment_resource, garment_original = await load_active_library_resource(
session,
payload.garment_asset_id,
resource_type=LibraryResourceType.GARMENT,
)
# provider 接收两张真实输入图:模特准备图 + 服装参考图。
generated = await service.generate_tryon_image(
person_image_url=prepared_asset.uri,
garment_image_url=garment_original.public_url,
)
# workflow 生成出的结果图先上传到 S3再登记成订单资产。
storage_key, public_url = await get_order_artifact_storage_service().upload_generated_image(
order_id=payload.order_id,
step_name=payload.step_name,
image_bytes=generated.image_bytes,
mime_type=generated.mime_type,
)
# 记录足够的追踪信息,方便回溯这张试衣图由哪些输入和哪个 provider 产出。
metadata = {
**payload.metadata,
"prepared_asset_id": prepared_asset.id,
"garment_resource_id": garment_resource.id,
"garment_original_file_id": garment_original.id,
"garment_original_url": garment_original.public_url,
"provider": generated.provider,
"model": generated.model,
"storage_key": storage_key,
"mime_type": generated.mime_type,
"prompt": generated.prompt,
}
metadata = {key: value for key, value in metadata.items() if value is not None}
asset = AssetORM(
order_id=payload.order_id,
asset_type=AssetType.TRYON,
step_name=payload.step_name,
uri=public_url,
metadata_json=jsonable(metadata),
)
session.add(asset)
await session.flush()
result = MockActivityResult(
step_name=payload.step_name,
success=True,
asset_id=asset.id,
uri=asset.uri,
score=1.0,
passed=True,
message="try-on generated",
metadata=jsonable(metadata) or {},
)
step.step_status = StepStatus.SUCCEEDED
step.output_json = jsonable(result)
step.ended_at = utc_now()
await session.commit()
return result
except Exception as exc:
step.step_status = StepStatus.FAILED
step.error_message = str(exc)
step.ended_at = utc_now()
order.status = OrderStatus.FAILED
workflow_run.status = OrderStatus.FAILED
await session.commit()
raise

View File

@@ -26,14 +26,13 @@ with workflow.unsafe.imports_passed_through():
from app.workers.activities.review_activities import mark_workflow_failed_activity
from app.workers.activities.scene_activities import run_scene_activity
from app.workers.activities.tryon_activities import prepare_model_activity, run_tryon_activity
from app.workers.workflows.timeout_policy import DEFAULT_ACTIVITY_TIMEOUT, activity_timeout_for_task_queue
from app.workers.workflows.types import (
PipelineWorkflowInput,
StepActivityInput,
WorkflowFailureActivityInput,
)
ACTIVITY_TIMEOUT = timedelta(seconds=30)
ACTIVITY_RETRY_POLICY = RetryPolicy(
initial_interval=timedelta(seconds=1),
backoff_coefficient=2.0,
@@ -64,6 +63,8 @@ class LowEndPipelineWorkflow:
try:
# 每个步骤都通过 execute_activity 触发真正执行。
# workflow 自己不做计算,只负责调度。
# prepare_model 的职责是把“模特资源 + 订单上下文”整理成后续可消费的人物底图。
# 这一步产出的 prepared.asset_id 会作为 tryon 的 source_asset_id。
prepared = await workflow.execute_activity(
prepare_model_activity,
StepActivityInput(
@@ -76,12 +77,14 @@ class LowEndPipelineWorkflow:
scene_ref_asset_id=payload.scene_ref_asset_id,
),
task_queue=IMAGE_PIPELINE_CONTROL_TASK_QUEUE,
start_to_close_timeout=ACTIVITY_TIMEOUT,
start_to_close_timeout=activity_timeout_for_task_queue(IMAGE_PIPELINE_CONTROL_TASK_QUEUE),
retry_policy=ACTIVITY_RETRY_POLICY,
)
current_step = WorkflowStepName.TRYON
# 下游步骤通过 source_asset_id 引用上一步生成的资产。
# tryon 是换装主步骤:把 prepared 模特图和 garment_asset_id 组合成试衣结果。
# 它产出的 tryon_result.asset_id 是后续 scene / qc 的基础输入。
tryon_result = await workflow.execute_activity(
run_tryon_activity,
StepActivityInput(
@@ -92,38 +95,46 @@ class LowEndPipelineWorkflow:
garment_asset_id=payload.garment_asset_id,
),
task_queue=IMAGE_PIPELINE_IMAGE_GEN_TASK_QUEUE,
start_to_close_timeout=ACTIVITY_TIMEOUT,
start_to_close_timeout=activity_timeout_for_task_queue(IMAGE_PIPELINE_IMAGE_GEN_TASK_QUEUE),
retry_policy=ACTIVITY_RETRY_POLICY,
)
current_step = WorkflowStepName.SCENE
scene_result = await workflow.execute_activity(
run_scene_activity,
StepActivityInput(
order_id=payload.order_id,
workflow_run_id=payload.workflow_run_id,
step_name=WorkflowStepName.SCENE,
source_asset_id=tryon_result.asset_id,
scene_ref_asset_id=payload.scene_ref_asset_id,
),
task_queue=IMAGE_PIPELINE_IMAGE_GEN_TASK_QUEUE,
start_to_close_timeout=ACTIVITY_TIMEOUT,
retry_policy=ACTIVITY_RETRY_POLICY,
)
scene_source_asset_id = tryon_result.asset_id
if payload.scene_ref_asset_id is not None:
current_step = WorkflowStepName.SCENE
# scene 是可选步骤。
# 有场景图时,用 scene_ref_asset_id 把 tryon 结果合成到目标背景;
# 没有场景图时,直接沿用 tryon 结果继续往下走。
scene_result = await workflow.execute_activity(
run_scene_activity,
StepActivityInput(
order_id=payload.order_id,
workflow_run_id=payload.workflow_run_id,
step_name=WorkflowStepName.SCENE,
source_asset_id=tryon_result.asset_id,
scene_ref_asset_id=payload.scene_ref_asset_id,
),
task_queue=IMAGE_PIPELINE_IMAGE_GEN_TASK_QUEUE,
start_to_close_timeout=activity_timeout_for_task_queue(IMAGE_PIPELINE_IMAGE_GEN_TASK_QUEUE),
retry_policy=ACTIVITY_RETRY_POLICY,
)
scene_source_asset_id = scene_result.asset_id
current_step = WorkflowStepName.QC
# QC 是流程里的“闸门”。
# 如果这里不通过,低端流程直接失败,不会再 export。
# source_asset_id 指向当前链路里“最接近最终成品”的那张图:
# 有场景时是 scene 结果,没有场景时就是 tryon 结果。
qc_result = await workflow.execute_activity(
run_qc_activity,
StepActivityInput(
order_id=payload.order_id,
workflow_run_id=payload.workflow_run_id,
step_name=WorkflowStepName.QC,
source_asset_id=scene_result.asset_id,
source_asset_id=scene_source_asset_id,
),
task_queue=IMAGE_PIPELINE_QC_TASK_QUEUE,
start_to_close_timeout=ACTIVITY_TIMEOUT,
start_to_close_timeout=activity_timeout_for_task_queue(IMAGE_PIPELINE_QC_TASK_QUEUE),
retry_policy=ACTIVITY_RETRY_POLICY,
)
@@ -134,16 +145,17 @@ class LowEndPipelineWorkflow:
current_step = WorkflowStepName.EXPORT
# candidate_asset_ids 是 QC 推荐可导出的候选资产。
# 当前 MVP 只会返回一个候选;如果没有,就退回 scene 结果导出。
# export 是最后的收口步骤:生成最终交付资产,并把订单状态写成 succeeded。
final_result = await workflow.execute_activity(
run_export_activity,
StepActivityInput(
order_id=payload.order_id,
workflow_run_id=payload.workflow_run_id,
step_name=WorkflowStepName.EXPORT,
source_asset_id=(qc_result.candidate_asset_ids or [scene_result.asset_id])[0],
source_asset_id=(qc_result.candidate_asset_ids or [scene_source_asset_id])[0],
),
task_queue=IMAGE_PIPELINE_EXPORT_TASK_QUEUE,
start_to_close_timeout=ACTIVITY_TIMEOUT,
start_to_close_timeout=activity_timeout_for_task_queue(IMAGE_PIPELINE_EXPORT_TASK_QUEUE),
retry_policy=ACTIVITY_RETRY_POLICY,
)
return {
@@ -177,6 +189,6 @@ class LowEndPipelineWorkflow:
message=message,
),
task_queue=IMAGE_PIPELINE_CONTROL_TASK_QUEUE,
start_to_close_timeout=ACTIVITY_TIMEOUT,
start_to_close_timeout=DEFAULT_ACTIVITY_TIMEOUT,
retry_policy=ACTIVITY_RETRY_POLICY,
)

View File

@@ -34,6 +34,7 @@ with workflow.unsafe.imports_passed_through():
from app.workers.activities.scene_activities import run_scene_activity
from app.workers.activities.texture_activities import run_texture_activity
from app.workers.activities.tryon_activities import prepare_model_activity, run_tryon_activity
from app.workers.workflows.timeout_policy import DEFAULT_ACTIVITY_TIMEOUT, activity_timeout_for_task_queue
from app.workers.workflows.types import (
MockActivityResult,
PipelineWorkflowInput,
@@ -44,8 +45,6 @@ with workflow.unsafe.imports_passed_through():
WorkflowFailureActivityInput,
)
ACTIVITY_TIMEOUT = timedelta(seconds=30)
ACTIVITY_RETRY_POLICY = RetryPolicy(
initial_interval=timedelta(seconds=1),
backoff_coefficient=2.0,
@@ -87,6 +86,8 @@ class MidEndPipelineWorkflow:
# current_step 用于失败时记录“最后跑到哪一步”。
current_step = WorkflowStepName.PREPARE_MODEL
try:
# prepare_model / tryon / scene 这三段和低端流程共享同一套 activity
# 区别只在于中端流程后面还会继续进入 texture / face / fusion / review。
prepared = await workflow.execute_activity(
prepare_model_activity,
StepActivityInput(
@@ -99,11 +100,12 @@ class MidEndPipelineWorkflow:
scene_ref_asset_id=payload.scene_ref_asset_id,
),
task_queue=IMAGE_PIPELINE_CONTROL_TASK_QUEUE,
start_to_close_timeout=ACTIVITY_TIMEOUT,
start_to_close_timeout=activity_timeout_for_task_queue(IMAGE_PIPELINE_CONTROL_TASK_QUEUE),
retry_policy=ACTIVITY_RETRY_POLICY,
)
current_step = WorkflowStepName.TRYON
# tryon 产出基础换装图,后面的所有增强步骤都以它为起点。
tryon_result = await workflow.execute_activity(
run_tryon_activity,
StepActivityInput(
@@ -114,23 +116,37 @@ class MidEndPipelineWorkflow:
garment_asset_id=payload.garment_asset_id,
),
task_queue=IMAGE_PIPELINE_IMAGE_GEN_TASK_QUEUE,
start_to_close_timeout=ACTIVITY_TIMEOUT,
start_to_close_timeout=activity_timeout_for_task_queue(IMAGE_PIPELINE_IMAGE_GEN_TASK_QUEUE),
retry_policy=ACTIVITY_RETRY_POLICY,
)
current_step = WorkflowStepName.SCENE
scene_result = await self._run_scene(payload, tryon_result.asset_id)
scene_source_asset_id = tryon_result.asset_id
if payload.scene_ref_asset_id is not None:
current_step = WorkflowStepName.SCENE
# scene 仍然是可选步骤。
# 如果存在场景图,就先完成换背景,再把结果交给 texture / fusion
# 如果没有场景图,就直接把 tryon 结果当成“场景基底”继续流程。
scene_result = await self._run_scene(payload, tryon_result.asset_id)
scene_source_asset_id = scene_result.asset_id
else:
scene_result = tryon_result
current_step = WorkflowStepName.TEXTURE
texture_result = await self._run_texture(payload, scene_result.asset_id)
# texture 是复杂流独有步骤。
# 它通常负责衣物纹理、材质细节或局部增强,输入是当前场景基底图。
texture_result = await self._run_texture(payload, scene_source_asset_id)
current_step = WorkflowStepName.FACE
# face 也是复杂流独有步骤。
# 它在 texture 结果基础上做脸部修复/增强,产出供 fusion 合成的人像版本。
face_result = await self._run_face(payload, texture_result.asset_id)
current_step = WorkflowStepName.FUSION
fusion_result = await self._run_fusion(payload, scene_result.asset_id, face_result.asset_id)
# fusion 负责把“场景基底”和“脸部增强结果”合成成候选成品。
fusion_result = await self._run_fusion(payload, scene_source_asset_id, face_result.asset_id)
current_step = WorkflowStepName.QC
# QC 和低端流程复用同一套 activity但这里检查的是 fusion 产物。
qc_result = await self._run_qc(payload, fusion_result.asset_id)
if not qc_result.passed:
await self._mark_failed(payload, current_step, qc_result.message)
@@ -144,6 +160,8 @@ class MidEndPipelineWorkflow:
current_step = WorkflowStepName.REVIEW
# 这里通过 activity 把数据库里的订单状态更新成 waiting_review
# 同时创建 review_task供 API 查询待审核列表。
# review 是复杂流真正区别于低端流程的核心:
# 在 export 前插入人工决策点,并支持按意见回流重跑。
await workflow.execute_activity(
mark_waiting_for_review_activity,
ReviewWaitActivityInput(
@@ -152,7 +170,7 @@ class MidEndPipelineWorkflow:
candidate_asset_ids=qc_result.candidate_asset_ids,
),
task_queue=IMAGE_PIPELINE_CONTROL_TASK_QUEUE,
start_to_close_timeout=ACTIVITY_TIMEOUT,
start_to_close_timeout=activity_timeout_for_task_queue(IMAGE_PIPELINE_CONTROL_TASK_QUEUE),
retry_policy=ACTIVITY_RETRY_POLICY,
)
@@ -172,7 +190,7 @@ class MidEndPipelineWorkflow:
comment=review_payload.comment,
),
task_queue=IMAGE_PIPELINE_CONTROL_TASK_QUEUE,
start_to_close_timeout=ACTIVITY_TIMEOUT,
start_to_close_timeout=activity_timeout_for_task_queue(IMAGE_PIPELINE_CONTROL_TASK_QUEUE),
retry_policy=ACTIVITY_RETRY_POLICY,
)
@@ -180,6 +198,7 @@ class MidEndPipelineWorkflow:
current_step = WorkflowStepName.EXPORT
# 如果审核人显式选了资产,就导出该资产;
# 否则默认导出 QC 候选资产。
# export 本身仍然和低端流程是同一个 activity。
export_source_id = review_payload.selected_asset_id
if export_source_id is None:
export_source_id = (qc_result.candidate_asset_ids or [fusion_result.asset_id])[0]
@@ -192,7 +211,7 @@ class MidEndPipelineWorkflow:
source_asset_id=export_source_id,
),
task_queue=IMAGE_PIPELINE_EXPORT_TASK_QUEUE,
start_to_close_timeout=ACTIVITY_TIMEOUT,
start_to_close_timeout=activity_timeout_for_task_queue(IMAGE_PIPELINE_EXPORT_TASK_QUEUE),
retry_policy=ACTIVITY_RETRY_POLICY,
)
return {
@@ -208,22 +227,30 @@ class MidEndPipelineWorkflow:
# rerun 的核心思想是:
# 把指定节点后的链路重新跑一遍,然后再次进入 QC 和 waiting_review。
if review_payload.decision == ReviewDecision.RERUN_SCENE:
current_step = WorkflowStepName.SCENE
scene_result = await self._run_scene(payload, tryon_result.asset_id)
# 从 scene 开始重跑意味着 scene / texture / face / fusion 全部重算。
if payload.scene_ref_asset_id is not None:
current_step = WorkflowStepName.SCENE
scene_result = await self._run_scene(payload, tryon_result.asset_id)
scene_source_asset_id = scene_result.asset_id
else:
scene_result = tryon_result
scene_source_asset_id = tryon_result.asset_id
current_step = WorkflowStepName.TEXTURE
texture_result = await self._run_texture(payload, scene_result.asset_id)
texture_result = await self._run_texture(payload, scene_source_asset_id)
current_step = WorkflowStepName.FACE
face_result = await self._run_face(payload, texture_result.asset_id)
current_step = WorkflowStepName.FUSION
fusion_result = await self._run_fusion(payload, scene_result.asset_id, face_result.asset_id)
fusion_result = await self._run_fusion(payload, scene_source_asset_id, face_result.asset_id)
elif review_payload.decision == ReviewDecision.RERUN_FACE:
# 从 face 开始重跑时scene / texture 结果保持不变。
current_step = WorkflowStepName.FACE
face_result = await self._run_face(payload, texture_result.asset_id)
current_step = WorkflowStepName.FUSION
fusion_result = await self._run_fusion(payload, scene_result.asset_id, face_result.asset_id)
fusion_result = await self._run_fusion(payload, scene_source_asset_id, face_result.asset_id)
elif review_payload.decision == ReviewDecision.RERUN_FUSION:
# 从 fusion 重跑是最小范围回流,只重做最终合成。
current_step = WorkflowStepName.FUSION
fusion_result = await self._run_fusion(payload, scene_result.asset_id, face_result.asset_id)
fusion_result = await self._run_fusion(payload, scene_source_asset_id, face_result.asset_id)
current_step = WorkflowStepName.QC
qc_result = await self._run_qc(payload, fusion_result.asset_id)
@@ -264,7 +291,7 @@ class MidEndPipelineWorkflow:
scene_ref_asset_id=payload.scene_ref_asset_id,
),
task_queue=IMAGE_PIPELINE_IMAGE_GEN_TASK_QUEUE,
start_to_close_timeout=ACTIVITY_TIMEOUT,
start_to_close_timeout=activity_timeout_for_task_queue(IMAGE_PIPELINE_IMAGE_GEN_TASK_QUEUE),
retry_policy=ACTIVITY_RETRY_POLICY,
)
@@ -280,7 +307,7 @@ class MidEndPipelineWorkflow:
source_asset_id=source_asset_id,
),
task_queue=IMAGE_PIPELINE_POST_PROCESS_TASK_QUEUE,
start_to_close_timeout=ACTIVITY_TIMEOUT,
start_to_close_timeout=activity_timeout_for_task_queue(IMAGE_PIPELINE_POST_PROCESS_TASK_QUEUE),
retry_policy=ACTIVITY_RETRY_POLICY,
)
@@ -296,7 +323,7 @@ class MidEndPipelineWorkflow:
source_asset_id=source_asset_id,
),
task_queue=IMAGE_PIPELINE_POST_PROCESS_TASK_QUEUE,
start_to_close_timeout=ACTIVITY_TIMEOUT,
start_to_close_timeout=activity_timeout_for_task_queue(IMAGE_PIPELINE_POST_PROCESS_TASK_QUEUE),
retry_policy=ACTIVITY_RETRY_POLICY,
)
@@ -318,7 +345,7 @@ class MidEndPipelineWorkflow:
selected_asset_id=face_asset_id,
),
task_queue=IMAGE_PIPELINE_POST_PROCESS_TASK_QUEUE,
start_to_close_timeout=ACTIVITY_TIMEOUT,
start_to_close_timeout=activity_timeout_for_task_queue(IMAGE_PIPELINE_POST_PROCESS_TASK_QUEUE),
retry_policy=ACTIVITY_RETRY_POLICY,
)
@@ -334,7 +361,7 @@ class MidEndPipelineWorkflow:
source_asset_id=source_asset_id,
),
task_queue=IMAGE_PIPELINE_QC_TASK_QUEUE,
start_to_close_timeout=ACTIVITY_TIMEOUT,
start_to_close_timeout=activity_timeout_for_task_queue(IMAGE_PIPELINE_QC_TASK_QUEUE),
retry_policy=ACTIVITY_RETRY_POLICY,
)
@@ -355,6 +382,6 @@ class MidEndPipelineWorkflow:
message=message,
),
task_queue=IMAGE_PIPELINE_CONTROL_TASK_QUEUE,
start_to_close_timeout=ACTIVITY_TIMEOUT,
start_to_close_timeout=DEFAULT_ACTIVITY_TIMEOUT,
retry_policy=ACTIVITY_RETRY_POLICY,
)

View File

@@ -0,0 +1,24 @@
"""Shared activity timeout policy for Temporal workflows."""
from datetime import timedelta
from app.infra.temporal.task_queues import (
IMAGE_PIPELINE_IMAGE_GEN_TASK_QUEUE,
IMAGE_PIPELINE_POST_PROCESS_TASK_QUEUE,
)
# Short timeout used for ordinary activities (control-plane / bookkeeping work).
DEFAULT_ACTIVITY_TIMEOUT = timedelta(seconds=30)
# Extended timeout for heavyweight activities — see LONG_RUNNING_TASK_QUEUES below.
LONG_RUNNING_ACTIVITY_TIMEOUT = timedelta(minutes=5)
# Task queues whose activities are expected to run long (image generation and
# post-processing); activity_timeout_for_task_queue() maps these to the
# extended timeout instead of the 30s default.
LONG_RUNNING_TASK_QUEUES = {
    IMAGE_PIPELINE_IMAGE_GEN_TASK_QUEUE,
    IMAGE_PIPELINE_POST_PROCESS_TASK_QUEUE,
}
def activity_timeout_for_task_queue(task_queue: str) -> timedelta:
    """Return the activity timeout appropriate for ``task_queue``.

    Queues listed in ``LONG_RUNNING_TASK_QUEUES`` host heavyweight workloads
    and receive ``LONG_RUNNING_ACTIVITY_TIMEOUT``; any other queue gets
    ``DEFAULT_ACTIVITY_TIMEOUT``.
    """
    is_long_running = task_queue in LONG_RUNNING_TASK_QUEUES
    return LONG_RUNNING_ACTIVITY_TIMEOUT if is_long_running else DEFAULT_ACTIVITY_TIMEOUT

View File

@@ -32,14 +32,22 @@ class PipelineWorkflowInput:
这是一张订单进入 workflow 时携带的最小上下文。
"""
# 订单主键。整个 workflow 的所有步骤都会围绕这张订单落库、查状态。
order_id: int
# workflow_run 主键。用于把每一步 step 记录关联到同一次运行。
workflow_run_id: int
# 客户层级。决定业务侧订单归属,也会影响后续统计和展示。
customer_level: CustomerLevel
# 服务模式。这里直接决定启动简单流程还是复杂流程。
service_mode: ServiceMode
# 模特资源 ID。prepare_model 会基于它准备人物素材。
model_id: int
pose_id: int
# 服装资源/资产 ID。tryon 会把它作为换装输入。
garment_asset_id: int
scene_ref_asset_id: int
# 场景参考图 ID。可为空为空时跳过 scene 步骤。
scene_ref_asset_id: int | None = None
# 模特姿势 ID。当前 MVP 已经放宽为可空,后续需要姿势控制时再启用。
pose_id: int | None = None
def __post_init__(self) -> None:
"""在反序列化后把枚举字段修正回来。"""
@@ -59,15 +67,25 @@ class StepActivityInput:
- 上一步产出的 asset_id
"""
# 当前步骤属于哪张订单。
order_id: int
# 当前步骤属于哪次 workflow 运行。
workflow_run_id: int
# 当前执行的步骤名,例如 tryon / qc / export。
step_name: WorkflowStepName
# 模特资源 ID。只有 prepare_model 等少数步骤会直接消费它。
model_id: int | None = None
# 姿势 ID。当前大多为透传预留字段。
pose_id: int | None = None
# 服装资源/资产 ID。tryon 是主要消费者。
garment_asset_id: int | None = None
# 场景参考图 ID。scene 步骤会用它完成换背景。
scene_ref_asset_id: int | None = None
# 上一步产出的资产 ID。绝大多数步骤都靠它串起资产链路。
source_asset_id: int | None = None
# 人工选中的资产 ID。review / fusion / export 等需要显式选图时使用。
selected_asset_id: int | None = None
# 额外扩展参数。给特殊步骤塞一些不值得单独建字段的临时上下文。
metadata: dict[str, Any] = field(default_factory=dict)
def __post_init__(self) -> None:
@@ -80,14 +98,23 @@ class StepActivityInput:
class MockActivityResult:
"""通用 mock activity 返回结构。"""
# 这是哪一步返回的结果,方便 workflow 和落库代码识别来源。
step_name: WorkflowStepName
# activity 是否执行成功。一般表示“代码执行成功”,不等于业务一定通过。
success: bool
# 这一步新生成的资产 ID如果步骤不产出资产可以为空。
asset_id: int | None
# 产出资产的访问地址;当前 mock 阶段通常是 mock:// URI。
uri: str | None
# 模型分数/质量分数之类的数值结果,非所有步骤都有。
score: float | None = None
# 业务是否通过。典型场景是 QCactivity 成功执行,但 passed 可能为 False。
passed: bool | None = None
# 给 workflow 或 API 展示的简短结果消息。
message: str = "mock success"
# 候选资产列表。主要给 QC / review 这类“多候选图”步骤使用。
candidate_asset_ids: list[int] = field(default_factory=list)
# 补充元数据。用于把步骤内部的一些上下文回传给调用方。
metadata: dict[str, Any] = field(default_factory=dict)
def __post_init__(self) -> None:
@@ -100,9 +127,13 @@ class MockActivityResult:
class ReviewSignalPayload:
"""API 发给中端 workflow 的审核 signal 载荷。"""
# 审核动作:通过 / 拒绝 / 从某一步重跑。
decision: ReviewDecision
# 操作审核的用户 ID便于审计和任务归属。
reviewer_id: int
# 审核人最终选中的候选资产 IDapprove 时最常见。
selected_asset_id: int | None = None
# 审核备注,例如“脸部不自然,重跑融合”。
comment: str | None = None
def __post_init__(self) -> None:
@@ -115,9 +146,13 @@ class ReviewSignalPayload:
class ReviewWaitActivityInput:
"""把流程切到 waiting_review 时传给 activity 的输入。"""
# 当前审核任务属于哪张订单。
order_id: int
# 当前审核任务属于哪次 workflow 运行。
workflow_run_id: int
# 提交给审核端看的候选资产列表。
candidate_asset_ids: list[int] = field(default_factory=list)
# 进入审核态时附带的说明文案,当前大多留空。
comment: str | None = None
@@ -125,11 +160,17 @@ class ReviewWaitActivityInput:
class ReviewResolutionActivityInput:
"""审核结果到达后,用于结束 waiting_review 的输入。"""
# 当前审核结果属于哪张订单。
order_id: int
# 当前审核结果属于哪次 workflow 运行。
workflow_run_id: int
# 审核最终决策。
decision: ReviewDecision
# 处理这次审核的审核人 ID。
reviewer_id: int
# 如果审核人明确挑了一张图,这里记录最终选择的资产 ID。
selected_asset_id: int | None = None
# 审核备注或重跑原因。
comment: str | None = None
def __post_init__(self) -> None:
@@ -142,10 +183,15 @@ class ReviewResolutionActivityInput:
class WorkflowFailureActivityInput:
"""流程失败收尾 activity 的输入。"""
# 失败发生在哪张订单。
order_id: int
# 失败发生在哪次 workflow 运行。
workflow_run_id: int
# 失败停留在哪个步骤,用于更新 workflow_run.current_step。
current_step: WorkflowStepName
# 失败原因文本,通常来自异常或 QC 驳回消息。
message: str
# 要写回数据库的最终状态,默认就是 failed。
status: OrderStatus = OrderStatus.FAILED
def __post_init__(self) -> None: