feat: add resource library and real image workflow

This commit is contained in:
afei A
2026-03-29 00:24:29 +08:00
parent eeaff269eb
commit 04da401ab4
38 changed files with 3033 additions and 117 deletions

View File

@@ -5,3 +5,9 @@ AUTO_CREATE_TABLES=true
DATABASE_URL=sqlite+aiosqlite:///./temporal_demo.db
TEMPORAL_ADDRESS=localhost:7233
TEMPORAL_NAMESPACE=default
IMAGE_GENERATION_PROVIDER=mock
GEMINI_API_KEY=
GEMINI_BASE_URL=https://api.museidea.com/v1beta
GEMINI_MODEL=gemini-3.1-flash-image-preview
GEMINI_TIMEOUT_SECONDS=300
GEMINI_MAX_ATTEMPTS=2

View File

@@ -8,12 +8,14 @@ from sqlalchemy import engine_from_config, pool
from app.config.settings import get_settings
from app.infra.db.base import Base
from app.infra.db.models.asset import AssetORM
from app.infra.db.models.library_resource import LibraryResourceORM
from app.infra.db.models.library_resource_file import LibraryResourceFileORM
from app.infra.db.models.order import OrderORM
from app.infra.db.models.review_task import ReviewTaskORM
from app.infra.db.models.workflow_run import WorkflowRunORM
from app.infra.db.models.workflow_step import WorkflowStepORM
del AssetORM, OrderORM, ReviewTaskORM, WorkflowRunORM, WorkflowStepORM
del AssetORM, LibraryResourceORM, LibraryResourceFileORM, OrderORM, ReviewTaskORM, WorkflowRunORM, WorkflowStepORM
config = context.config
@@ -58,4 +60,3 @@ if context.is_offline_mode():
run_migrations_offline()
else:
run_migrations_online()

View File

@@ -0,0 +1,92 @@
"""resource library schema
Revision ID: 20260328_0003
Revises: 20260327_0002
Create Date: 2026-03-28 10:20:00.000000
"""
from collections.abc import Sequence
from alembic import op
import sqlalchemy as sa
revision: str = "20260328_0003"
down_revision: str | None = "20260327_0002"
branch_labels: Sequence[str] | None = None
depends_on: Sequence[str] | None = None
def upgrade() -> None:
    """Create the resource-library tables.

    Adds ``library_resources`` (logical resource metadata) and
    ``library_resource_files`` (per-file upload metadata), plus the
    single-column indexes used by the list endpoint's filters.
    """
    op.create_table(
        "library_resources",
        sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True),
        # native_enum=False stores enum values as VARCHAR (portable to SQLite).
        sa.Column(
            "resource_type",
            sa.Enum("model", "scene", "garment", name="libraryresourcetype", native_enum=False),
            nullable=False,
        ),
        sa.Column("name", sa.String(length=255), nullable=False),
        sa.Column("description", sa.Text(), nullable=True),
        sa.Column("tags", sa.JSON(), nullable=False),
        sa.Column(
            "status",
            sa.Enum("active", "archived", name="libraryresourcestatus", native_enum=False),
            nullable=False,
        ),
        # Type-specific attributes; which ones are populated depends on
        # resource_type (model: gender/age_group/pose_id, scene: environment,
        # garment: category) — enforced at the service layer, not the schema.
        sa.Column("gender", sa.String(length=32), nullable=True),
        sa.Column("age_group", sa.String(length=32), nullable=True),
        sa.Column("pose_id", sa.Integer(), nullable=True),
        sa.Column("environment", sa.String(length=32), nullable=True),
        sa.Column("category", sa.String(length=128), nullable=True),
        # Plain integer pointers into library_resource_files; intentionally no
        # FK constraint here to avoid a circular dependency between the tables.
        sa.Column("cover_file_id", sa.Integer(), nullable=True),
        sa.Column("original_file_id", sa.Integer(), nullable=True),
        sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
        sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
    )
    op.create_index("ix_library_resources_resource_type", "library_resources", ["resource_type"])
    op.create_index("ix_library_resources_status", "library_resources", ["status"])
    op.create_index("ix_library_resources_gender", "library_resources", ["gender"])
    op.create_index("ix_library_resources_age_group", "library_resources", ["age_group"])
    op.create_index("ix_library_resources_environment", "library_resources", ["environment"])
    op.create_index("ix_library_resources_category", "library_resources", ["category"])
    op.create_table(
        "library_resource_files",
        sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True),
        sa.Column("resource_id", sa.Integer(), sa.ForeignKey("library_resources.id"), nullable=False),
        sa.Column(
            "file_role",
            sa.Enum("original", "thumbnail", "gallery", name="libraryfilerole", native_enum=False),
            nullable=False,
        ),
        sa.Column("storage_key", sa.String(length=500), nullable=False),
        sa.Column("public_url", sa.String(length=500), nullable=False),
        sa.Column("bucket", sa.String(length=255), nullable=False),
        sa.Column("mime_type", sa.String(length=255), nullable=False),
        sa.Column("size_bytes", sa.Integer(), nullable=False),
        sa.Column("sort_order", sa.Integer(), nullable=False),
        sa.Column("width", sa.Integer(), nullable=True),
        sa.Column("height", sa.Integer(), nullable=True),
        sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
        sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
    )
    op.create_index("ix_library_resource_files_resource_id", "library_resource_files", ["resource_id"])
    op.create_index("ix_library_resource_files_file_role", "library_resource_files", ["file_role"])
def downgrade() -> None:
    """Drop the resource-library tables.

    Indexes and tables are removed in the reverse order of creation so the
    child table (and its FK to library_resources) goes first.
    """
    op.drop_index("ix_library_resource_files_file_role", table_name="library_resource_files")
    op.drop_index("ix_library_resource_files_resource_id", table_name="library_resource_files")
    op.drop_table("library_resource_files")
    op.drop_index("ix_library_resources_category", table_name="library_resources")
    op.drop_index("ix_library_resources_environment", table_name="library_resources")
    op.drop_index("ix_library_resources_age_group", table_name="library_resources")
    op.drop_index("ix_library_resources_gender", table_name="library_resources")
    op.drop_index("ix_library_resources_status", table_name="library_resources")
    op.drop_index("ix_library_resources_resource_type", table_name="library_resources")
    op.drop_table("library_resources")

View File

@@ -0,0 +1,25 @@
"""Make order pose_id optional for MVP.
Revision ID: 20260328_0004
Revises: 20260328_0003
Create Date: 2026-03-28 11:08:00.000000
"""
from alembic import op
import sqlalchemy as sa
revision = "20260328_0004"
down_revision = "20260328_0003"
branch_labels = None
depends_on = None
def upgrade() -> None:
    """Relax ``orders.pose_id`` to nullable (pose selection optional in MVP)."""
    # batch_alter_table recreates the table under the hood, which is required
    # for ALTER COLUMN on SQLite.
    with op.batch_alter_table("orders") as batch_op:
        batch_op.alter_column("pose_id", existing_type=sa.Integer(), nullable=True)
def downgrade() -> None:
    """Restore NOT NULL on ``orders.pose_id``.

    NOTE(review): this fails if any row has a NULL pose_id by the time the
    downgrade runs — backfill would be needed first.
    """
    with op.batch_alter_table("orders") as batch_op:
        batch_op.alter_column("pose_id", existing_type=sa.Integer(), nullable=False)

View File

@@ -0,0 +1,25 @@
"""Make order scene_ref_asset_id optional for MVP.
Revision ID: 20260328_0005
Revises: 20260328_0004
Create Date: 2026-03-28 14:20:00.000000
"""
from alembic import op
import sqlalchemy as sa
revision = "20260328_0005"
down_revision = "20260328_0004"
branch_labels = None
depends_on = None
def upgrade() -> None:
    """Relax ``orders.scene_ref_asset_id`` to nullable (scene optional in MVP)."""
    # batch_alter_table recreates the table, required for ALTER COLUMN on SQLite.
    with op.batch_alter_table("orders") as batch_op:
        batch_op.alter_column("scene_ref_asset_id", existing_type=sa.Integer(), nullable=True)
def downgrade() -> None:
    """Restore NOT NULL on ``orders.scene_ref_asset_id``.

    NOTE(review): fails on existing NULL rows — backfill required before
    downgrading.
    """
    with op.batch_alter_table("orders") as batch_op:
        batch_op.alter_column("scene_ref_asset_id", existing_type=sa.Integer(), nullable=False)

View File

@@ -0,0 +1,89 @@
"""Library resource routes."""
from fastapi import APIRouter, Depends, Query, status
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.schemas.library import (
ArchiveLibraryResourceResponse,
CreateLibraryResourceRequest,
LibraryResourceListResponse,
LibraryResourceRead,
PresignUploadRequest,
PresignUploadResponse,
UpdateLibraryResourceRequest,
)
from app.application.services.library_service import LibraryService
from app.domain.enums import LibraryResourceType
from app.infra.db.session import get_db_session
router = APIRouter(prefix="/library", tags=["library"])
library_service = LibraryService()
@router.post("/uploads/presign", response_model=PresignUploadResponse)
async def create_library_upload_presign(payload: PresignUploadRequest) -> PresignUploadResponse:
    """Create direct-upload metadata for a library file.

    Pure metadata generation — no DB session is involved; the client PUTs
    the bytes to storage itself using the returned URL and headers.

    NOTE(review): ``payload.file_role`` is required by the request schema but
    is not forwarded to the service here — confirm whether the storage key
    should encode the role.
    """
    return library_service.create_upload_presign(
        resource_type=payload.resource_type,
        file_name=payload.file_name,
        content_type=payload.content_type,
    )
@router.post("/resources", response_model=LibraryResourceRead, status_code=status.HTTP_201_CREATED)
async def create_library_resource(
    payload: CreateLibraryResourceRequest,
    session: AsyncSession = Depends(get_db_session),
) -> LibraryResourceRead:
    """Create a library resource with already-uploaded file metadata.

    Validation (name, file roles, type-specific fields) happens inside the
    service, which raises HTTP 400 on bad payloads.
    """
    return await library_service.create_resource(session, payload)
@router.get("/resources", response_model=LibraryResourceListResponse)
async def list_library_resources(
    resource_type: LibraryResourceType | None = Query(default=None),
    query: str | None = Query(default=None, min_length=1),
    gender: str | None = Query(default=None),
    age_group: str | None = Query(default=None),
    environment: str | None = Query(default=None),
    category: str | None = Query(default=None),
    page: int = Query(default=1, ge=1),
    limit: int = Query(default=20, ge=1, le=100),
    session: AsyncSession = Depends(get_db_session),
) -> LibraryResourceListResponse:
    """List library resources with basic filtering.

    All filters are optional and combined with AND; ``query`` does a
    case-insensitive substring match on name/description (see service).
    Pagination is page/limit based with limit capped at 100.
    """
    return await library_service.list_resources(
        session,
        resource_type=resource_type,
        query=query,
        gender=gender,
        age_group=age_group,
        environment=environment,
        category=category,
        page=page,
        limit=limit,
    )
@router.patch("/resources/{resource_id}", response_model=LibraryResourceRead)
async def update_library_resource(
    resource_id: int,
    payload: UpdateLibraryResourceRequest,
    session: AsyncSession = Depends(get_db_session),
) -> LibraryResourceRead:
    """Update editable metadata on a library resource.

    Partial update: only fields present in the payload are applied; raises
    HTTP 404 for unknown ids and HTTP 400 for invalid field values.
    """
    return await library_service.update_resource(session, resource_id, payload)
@router.delete("/resources/{resource_id}", response_model=ArchiveLibraryResourceResponse)
async def archive_library_resource(
    resource_id: int,
    session: AsyncSession = Depends(get_db_session),
) -> ArchiveLibraryResourceResponse:
    """Archive a library resource instead of hard deleting it.

    Soft delete: the row is kept with status=archived so existing references
    (e.g. from orders) stay resolvable.
    """
    return await library_service.archive_resource(session, resource_id)

118
app/api/schemas/library.py Normal file
View File

@@ -0,0 +1,118 @@
"""Library resource API schemas."""
from datetime import datetime
from pydantic import BaseModel, Field
from app.domain.enums import LibraryFileRole, LibraryResourceStatus, LibraryResourceType
class PresignUploadRequest(BaseModel):
    """Request payload for generating direct-upload metadata."""

    resource_type: LibraryResourceType
    # NOTE(review): required by the schema but the presign route does not pass
    # it to the service — confirm whether the role should affect the key.
    file_role: LibraryFileRole
    file_name: str
    content_type: str
class PresignUploadResponse(BaseModel):
    """Response returned when generating direct-upload metadata."""

    # HTTP verb the client must use for the upload ("PUT" in practice).
    method: str
    upload_url: str
    # Headers the client must send verbatim on the upload request.
    headers: dict[str, str] = Field(default_factory=dict)
    # Key to echo back in CreateLibraryResourceFileRequest after uploading.
    storage_key: str
    public_url: str
class CreateLibraryResourceFileRequest(BaseModel):
    """Metadata for one already-uploaded file."""

    file_role: LibraryFileRole
    storage_key: str
    public_url: str
    mime_type: str
    size_bytes: int
    # Display ordering within the resource's gallery; lower comes first.
    sort_order: int = 0
    # Pixel dimensions when known; optional because not all clients report them.
    width: int | None = None
    height: int | None = None
class CreateLibraryResourceRequest(BaseModel):
    """One-shot resource creation request.

    Type-specific fields (gender/age_group/pose_id for models, environment
    for scenes, category for garments) are validated by the service, which
    also requires exactly one ``original`` and one ``thumbnail`` file.
    """

    resource_type: LibraryResourceType
    name: str
    description: str | None = None
    tags: list[str] = Field(default_factory=list)
    gender: str | None = None
    age_group: str | None = None
    pose_id: int | None = None
    environment: str | None = None
    category: str | None = None
    files: list[CreateLibraryResourceFileRequest]
class UpdateLibraryResourceRequest(BaseModel):
    """Partial update request for a library resource.

    Every field is optional; ``None`` means "leave unchanged". The service
    only applies type-specific fields matching the resource's type.
    """

    name: str | None = None
    description: str | None = None
    tags: list[str] | None = None
    gender: str | None = None
    age_group: str | None = None
    pose_id: int | None = None
    environment: str | None = None
    category: str | None = None
    # Must reference a file already attached to this resource (service-checked).
    cover_file_id: int | None = None
class ArchiveLibraryResourceResponse(BaseModel):
    """Response returned after archiving a resource (echoes its id)."""

    id: int
class LibraryResourceFileRead(BaseModel):
    """Serialized library file metadata."""

    id: int
    file_role: LibraryFileRole
    storage_key: str
    public_url: str
    bucket: str
    mime_type: str
    size_bytes: int
    sort_order: int
    width: int | None = None
    height: int | None = None
    created_at: datetime
class LibraryResourceRead(BaseModel):
    """Serialized library resource."""

    id: int
    resource_type: LibraryResourceType
    name: str
    description: str | None = None
    tags: list[str]
    status: LibraryResourceStatus
    gender: str | None = None
    age_group: str | None = None
    pose_id: int | None = None
    environment: str | None = None
    category: str | None = None
    # Convenience URLs resolved from cover_file_id / original_file_id.
    cover_url: str | None = None
    original_url: str | None = None
    # All files attached to the resource, sorted by (sort_order, id).
    files: list[LibraryResourceFileRead]
    created_at: datetime
    updated_at: datetime
class LibraryResourceListResponse(BaseModel):
    """Paginated library resource response.

    ``total`` is the filtered count across all pages, not the page size.
    """

    total: int
    items: list[LibraryResourceRead]

View File

@@ -14,9 +14,9 @@ class CreateOrderRequest(BaseModel):
customer_level: CustomerLevel
service_mode: ServiceMode
model_id: int
pose_id: int
pose_id: int | None = None
garment_asset_id: int
scene_ref_asset_id: int
scene_ref_asset_id: int | None = None
class CreateOrderResponse(BaseModel):
@@ -35,9 +35,9 @@ class OrderDetailResponse(BaseModel):
service_mode: ServiceMode
status: OrderStatus
model_id: int
pose_id: int
pose_id: int | None
garment_asset_id: int
scene_ref_asset_id: int
scene_ref_asset_id: int | None
final_asset_id: int | None
workflow_id: str | None
current_step: WorkflowStepName | None

View File

@@ -0,0 +1,163 @@
"""Application-facing orchestration for image-generation providers."""
from __future__ import annotations
import base64
import httpx
from fastapi import HTTPException, status
from app.config.settings import get_settings
from app.infra.image_generation.base import GeneratedImageResult, SourceImage
from app.infra.image_generation.gemini_provider import GeminiImageProvider
TRYON_PROMPT = (
"Use the first image as the base person image and the second image as the garment reference. "
"Generate a realistic virtual try-on result where the person in the first image is wearing the garment from "
"the second image. Preserve the person's identity, body proportions, pose, camera angle, lighting direction, "
"and background composition from the first image. Preserve the garment's key silhouette, material feel, color, "
"and visible design details from the second image. Produce a clean e-commerce quality result without adding "
"extra accessories, text, watermarks, or duplicate limbs."
)
SCENE_PROMPT = (
"Use the first image as the finished subject image and the second image as the target scene reference. "
"Generate a realistic composite where the person, clothing, body proportions, facial identity, pose, and camera "
"framing from the first image are preserved, while the environment is adapted to match the second image. "
"Keep the garment details, colors, and texture from the first image intact. Blend the subject naturally into the "
"target scene with coherent perspective, lighting, shadows, and depth. Do not add extra people, accessories, "
"text, logos, watermarks, or duplicate body parts."
)
MOCK_PNG_BASE64 = (
"iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mP8/x8AAwMCAO5Ww4QAAAAASUVORK5CYII="
)
class MockImageGenerationService:
    """Deterministic fallback used in tests and when no real provider is configured.

    Both generation methods ignore their URL arguments and return the same
    canned 1x1 PNG; only the recorded prompt differs between the two flows.
    """

    @staticmethod
    def _canned_result(prompt: str) -> GeneratedImageResult:
        # Shared construction for both flavours — the prompt is the only delta.
        return GeneratedImageResult(
            image_bytes=base64.b64decode(MOCK_PNG_BASE64),
            mime_type="image/png",
            provider="mock",
            model="mock-image-provider",
            prompt=prompt,
        )

    async def generate_tryon_image(self, *, person_image_url: str, garment_image_url: str) -> GeneratedImageResult:
        """Return the canned try-on result; the input URLs are intentionally unused."""
        return self._canned_result(TRYON_PROMPT)

    async def generate_scene_image(self, *, source_image_url: str, scene_image_url: str) -> GeneratedImageResult:
        """Return the canned scene result; the input URLs are intentionally unused."""
        return self._canned_result(SCENE_PROMPT)
class ImageGenerationService:
    """Download source assets and dispatch them to the configured provider."""

    def __init__(
        self,
        *,
        provider: GeminiImageProvider | None = None,
        downloader: httpx.AsyncClient | None = None,
    ) -> None:
        # Both collaborators are injectable for tests; by default a Gemini
        # provider is built from settings and an HTTP client is created
        # per-download (see _download_inputs ownership logic).
        self.settings = get_settings()
        self.provider = provider or GeminiImageProvider(
            api_key=self.settings.gemini_api_key,
            base_url=self.settings.gemini_base_url,
            model=self.settings.gemini_model,
            timeout_seconds=self.settings.gemini_timeout_seconds,
            max_attempts=self.settings.gemini_max_attempts,
        )
        self._downloader = downloader

    async def generate_tryon_image(
        self,
        *,
        person_image_url: str,
        garment_image_url: str,
    ) -> GeneratedImageResult:
        """Generate a try-on image using the configured provider.

        Raises HTTP 500 when no Gemini API key is configured, since this
        class is only built for the "gemini" provider setting.
        """
        if not self.settings.gemini_api_key:
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="GEMINI_API_KEY is required when image_generation_provider=gemini",
            )
        person_image, garment_image = await self._download_inputs(
            first_image_url=person_image_url,
            second_image_url=garment_image_url,
        )
        return await self.provider.generate_tryon_image(
            prompt=TRYON_PROMPT,
            person_image=person_image,
            garment_image=garment_image,
        )

    async def generate_scene_image(
        self,
        *,
        source_image_url: str,
        scene_image_url: str,
    ) -> GeneratedImageResult:
        """Generate a scene-composited image using the configured provider.

        Same API-key precondition and download flow as generate_tryon_image.
        """
        if not self.settings.gemini_api_key:
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="GEMINI_API_KEY is required when image_generation_provider=gemini",
            )
        source_image, scene_image = await self._download_inputs(
            first_image_url=source_image_url,
            second_image_url=scene_image_url,
        )
        return await self.provider.generate_scene_image(
            prompt=SCENE_PROMPT,
            source_image=source_image,
            scene_image=scene_image,
        )

    async def _download_inputs(
        self,
        *,
        first_image_url: str,
        second_image_url: str,
    ) -> tuple[SourceImage, SourceImage]:
        """Fetch both input images and wrap them as SourceImage values.

        If no client was injected, a throwaway AsyncClient is created and
        closed in the ``finally`` — accessing response headers/content after
        the close is safe because httpx reads the full body inside ``get``.
        Downloads are sequential; raise_for_status surfaces HTTP errors.
        """
        owns_client = self._downloader is None
        client = self._downloader or httpx.AsyncClient(timeout=self.settings.gemini_timeout_seconds)
        try:
            first_response = await client.get(first_image_url)
            first_response.raise_for_status()
            second_response = await client.get(second_image_url)
            second_response.raise_for_status()
        finally:
            if owns_client:
                await client.aclose()
        # Strip any "; charset=..." suffix; fall back to image/png when the
        # server sends no usable content-type.
        first_mime = first_response.headers.get("content-type", "image/png").split(";")[0].strip()
        second_mime = second_response.headers.get("content-type", "image/png").split(";")[0].strip()
        return (
            SourceImage(url=first_image_url, mime_type=first_mime or "image/png", data=first_response.content),
            SourceImage(url=second_image_url, mime_type=second_mime or "image/png", data=second_response.content),
        )
def build_image_generation_service():
    """Return the configured image-generation service.

    ``mock`` -> MockImageGenerationService, ``gemini`` -> ImageGenerationService
    (case-insensitive); any other value is a configuration error reported as
    HTTP 500.
    """
    settings = get_settings()
    provider = settings.image_generation_provider.lower()
    if provider == "mock":
        return MockImageGenerationService()
    if provider == "gemini":
        return ImageGenerationService()
    raise HTTPException(
        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        detail=f"Unsupported image_generation_provider: {settings.image_generation_provider}",
    )

View File

@@ -0,0 +1,366 @@
"""Library resource application service."""
from __future__ import annotations
from fastapi import HTTPException, status
from sqlalchemy import func, or_, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from app.api.schemas.library import (
ArchiveLibraryResourceResponse,
CreateLibraryResourceRequest,
LibraryResourceFileRead,
LibraryResourceListResponse,
LibraryResourceRead,
PresignUploadResponse,
UpdateLibraryResourceRequest,
)
from app.config.settings import get_settings
from app.domain.enums import LibraryFileRole, LibraryResourceStatus, LibraryResourceType
from app.infra.db.models.library_resource import LibraryResourceORM
from app.infra.db.models.library_resource_file import LibraryResourceFileORM
from app.infra.storage.s3 import RESOURCE_PREFIXES, S3PresignService
class LibraryService:
    """Application service for resource-library uploads and queries."""

    def __init__(self, presign_service: S3PresignService | None = None) -> None:
        # Injectable for tests; defaults to the real S3 presigner.
        self.presign_service = presign_service or S3PresignService()

    def create_upload_presign(
        self,
        resource_type: LibraryResourceType,
        file_name: str,
        content_type: str,
    ) -> PresignUploadResponse:
        """Create upload metadata for a direct S3 PUT.

        Raises HTTP 500 when any S3 credential/endpoint setting is missing,
        so misconfiguration is surfaced before a client attempts an upload.
        """
        settings = get_settings()
        if not all(
            [
                settings.s3_bucket,
                settings.s3_region,
                settings.s3_access_key,
                settings.s3_secret_key,
                settings.s3_endpoint,
            ]
        ):
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="S3 upload settings are incomplete",
            )
        storage_key, upload_url = self.presign_service.create_upload(
            resource_type=resource_type,
            file_name=file_name,
            content_type=content_type,
        )
        return PresignUploadResponse(
            method="PUT",
            # The client must echo this content-type header on the PUT or the
            # presigned signature will not match.
            headers={"content-type": content_type},
            upload_url=upload_url,
            storage_key=storage_key,
            public_url=self.presign_service.get_public_url(storage_key),
        )

    async def create_resource(
        self,
        session: AsyncSession,
        payload: CreateLibraryResourceRequest,
    ) -> LibraryResourceRead:
        """Persist a resource and its uploaded file metadata.

        Flow: validate payload -> insert resource (flush for id) -> insert
        files (flush for ids) -> link original/cover file ids -> commit.
        """
        self._validate_payload(payload)
        resource = LibraryResourceORM(
            resource_type=payload.resource_type,
            name=payload.name.strip(),
            description=payload.description.strip() if payload.description else None,
            # Drop empty/whitespace-only tags on the way in.
            tags=[tag.strip() for tag in payload.tags if tag.strip()],
            status=LibraryResourceStatus.ACTIVE,
            gender=payload.gender.strip() if payload.gender else None,
            age_group=payload.age_group.strip() if payload.age_group else None,
            pose_id=payload.pose_id,
            environment=payload.environment.strip() if payload.environment else None,
            category=payload.category.strip() if payload.category else None,
        )
        session.add(resource)
        # First flush assigns resource.id, needed for the file FK below.
        await session.flush()
        file_models: list[LibraryResourceFileORM] = []
        for item in payload.files:
            file_model = LibraryResourceFileORM(
                resource_id=resource.id,
                file_role=item.file_role,
                storage_key=item.storage_key,
                public_url=item.public_url,
                # NOTE(review): get_settings() is re-read per file; harmless
                # (cached via lru_cache per the settings module convention?)
                # but could be hoisted — confirm settings caching.
                bucket=get_settings().s3_bucket,
                mime_type=item.mime_type,
                size_bytes=item.size_bytes,
                sort_order=item.sort_order,
                width=item.width,
                height=item.height,
            )
            session.add(file_model)
            file_models.append(file_model)
        # Second flush assigns file ids so we can wire the pointers below.
        await session.flush()
        # Safe lookups: _validate_payload guarantees exactly one original and
        # exactly one thumbnail, so _find_role's next() cannot exhaust.
        resource.original_file_id = self._find_role(file_models, LibraryFileRole.ORIGINAL).id
        resource.cover_file_id = self._find_role(file_models, LibraryFileRole.THUMBNAIL).id
        await session.commit()
        # Pass the in-memory file list to avoid a lazy load after commit.
        return self._to_read(resource, files=file_models)

    async def list_resources(
        self,
        session: AsyncSession,
        *,
        resource_type: LibraryResourceType | None = None,
        query: str | None = None,
        gender: str | None = None,
        age_group: str | None = None,
        environment: str | None = None,
        category: str | None = None,
        page: int = 1,
        limit: int = 20,
    ) -> LibraryResourceListResponse:
        """List persisted resources with simple filter support.

        Only ACTIVE resources are returned; filters are ANDed together and
        ``query`` matches name OR description case-insensitively. Results
        are newest-first with id as a deterministic tiebreaker.
        """
        filters = [LibraryResourceORM.status == LibraryResourceStatus.ACTIVE]
        if resource_type is not None:
            filters.append(LibraryResourceORM.resource_type == resource_type)
        if query:
            like = f"%{query.strip()}%"
            filters.append(
                or_(
                    LibraryResourceORM.name.ilike(like),
                    LibraryResourceORM.description.ilike(like),
                )
            )
        if gender:
            filters.append(LibraryResourceORM.gender == gender)
        if age_group:
            filters.append(LibraryResourceORM.age_group == age_group)
        if environment:
            filters.append(LibraryResourceORM.environment == environment)
        if category:
            filters.append(LibraryResourceORM.category == category)
        # Separate COUNT query so `total` reflects all pages, not this one.
        total = (
            await session.execute(select(func.count(LibraryResourceORM.id)).where(*filters))
        ).scalar_one()
        result = await session.execute(
            select(LibraryResourceORM)
            .options(selectinload(LibraryResourceORM.files))
            .where(*filters)
            .order_by(LibraryResourceORM.created_at.desc(), LibraryResourceORM.id.desc())
            .offset((page - 1) * limit)
            .limit(limit)
        )
        items = result.scalars().all()
        return LibraryResourceListResponse(total=total, items=[self._to_read(item) for item in items])

    async def update_resource(
        self,
        session: AsyncSession,
        resource_id: int,
        payload: UpdateLibraryResourceRequest,
    ) -> LibraryResourceRead:
        """Update editable metadata on a library resource.

        ``None`` payload fields are left unchanged. Type-specific fields are
        only applied when they match the resource's type; a cover change must
        reference a file already attached to the resource. Raises HTTP 404
        for unknown ids and HTTP 400 for invalid values.
        """
        resource = await self._get_resource_or_404(session, resource_id)
        if payload.name is not None:
            name = payload.name.strip()
            if not name:
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail="Resource name is required",
                )
            resource.name = name
        if payload.description is not None:
            # Empty string clears the description.
            resource.description = payload.description.strip() or None
        if payload.tags is not None:
            resource.tags = [tag.strip() for tag in payload.tags if tag.strip()]
        if resource.resource_type == LibraryResourceType.MODEL:
            if payload.gender is not None:
                resource.gender = payload.gender.strip() or None
            if payload.age_group is not None:
                resource.age_group = payload.age_group.strip() or None
            if payload.pose_id is not None:
                resource.pose_id = payload.pose_id
        elif resource.resource_type == LibraryResourceType.SCENE:
            if payload.environment is not None:
                resource.environment = payload.environment.strip() or None
        elif resource.resource_type == LibraryResourceType.GARMENT:
            if payload.category is not None:
                resource.category = payload.category.strip() or None
        if payload.cover_file_id is not None:
            cover_file = next(
                (file for file in resource.files if file.id == payload.cover_file_id),
                None,
            )
            if cover_file is None:
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail="Cover file does not belong to the resource",
                )
            resource.cover_file_id = cover_file.id
        # Re-check type invariants after mutation (e.g. gender cleared to None).
        self._validate_existing_resource(resource)
        await session.commit()
        # Re-load files post-commit so _to_read sees fresh relationship state.
        await session.refresh(resource, attribute_names=["files"])
        return self._to_read(resource)

    async def archive_resource(
        self,
        session: AsyncSession,
        resource_id: int,
    ) -> ArchiveLibraryResourceResponse:
        """Soft delete a library resource by archiving it.

        The row is preserved with status=archived; list_resources filters it
        out. Raises HTTP 404 for unknown ids.
        """
        resource = await self._get_resource_or_404(session, resource_id)
        resource.status = LibraryResourceStatus.ARCHIVED
        await session.commit()
        return ArchiveLibraryResourceResponse(id=resource.id)

    def _validate_payload(self, payload: CreateLibraryResourceRequest) -> None:
        """Reject malformed creation payloads with HTTP 400.

        Invariants enforced: non-empty name, at least one file, exactly one
        ORIGINAL and one THUMBNAIL file, storage keys under the type's
        library prefix, and the type-specific required fields.
        """
        if not payload.name.strip():
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Resource name is required")
        if not payload.files:
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="At least one file is required")
        originals = [file for file in payload.files if file.file_role == LibraryFileRole.ORIGINAL]
        thumbnails = [file for file in payload.files if file.file_role == LibraryFileRole.THUMBNAIL]
        if len(originals) != 1:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Exactly one original file is required",
            )
        if len(thumbnails) != 1:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Exactly one thumbnail file is required",
            )
        # Keys must live under library/<type-prefix>/ to prevent cross-type
        # (or arbitrary) key injection from the client.
        expected_prefix = f"library/{RESOURCE_PREFIXES[payload.resource_type]}/"
        for file in payload.files:
            if not file.storage_key.startswith(expected_prefix):
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail="Uploaded file key does not match resource type",
                )
        if payload.resource_type == LibraryResourceType.MODEL:
            if not payload.gender or not payload.age_group:
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail="Model resources require gender and age_group",
                )
        elif payload.resource_type == LibraryResourceType.SCENE:
            if not payload.environment:
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail="Scene resources require environment",
                )
        elif payload.resource_type == LibraryResourceType.GARMENT and not payload.category:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Garment resources require category",
            )

    def _find_role(
        self,
        files: list[LibraryResourceFileORM],
        role: LibraryFileRole,
    ) -> LibraryResourceFileORM:
        # Caller must guarantee a file with this role exists (see
        # _validate_payload); otherwise next() raises StopIteration.
        return next(file for file in files if file.file_role == role)

    def _to_read(
        self,
        resource: LibraryResourceORM,
        *,
        files: list[LibraryResourceFileORM] | None = None,
    ) -> LibraryResourceRead:
        """Serialize an ORM resource (plus its files) into the read schema.

        ``files`` may be supplied to avoid touching the (possibly expired)
        relationship after a commit; cover/original URLs are resolved from
        the stored file-id pointers.
        """
        resource_files = files if files is not None else list(resource.files)
        files_sorted = sorted(resource_files, key=lambda item: (item.sort_order, item.id))
        cover = next((item for item in files_sorted if item.id == resource.cover_file_id), None)
        original = next((item for item in files_sorted if item.id == resource.original_file_id), None)
        return LibraryResourceRead(
            id=resource.id,
            resource_type=resource.resource_type,
            name=resource.name,
            description=resource.description,
            tags=resource.tags,
            status=resource.status,
            gender=resource.gender,
            age_group=resource.age_group,
            pose_id=resource.pose_id,
            environment=resource.environment,
            category=resource.category,
            cover_url=cover.public_url if cover else None,
            original_url=original.public_url if original else None,
            files=[
                LibraryResourceFileRead(
                    id=file.id,
                    file_role=file.file_role,
                    storage_key=file.storage_key,
                    public_url=file.public_url,
                    bucket=file.bucket,
                    mime_type=file.mime_type,
                    size_bytes=file.size_bytes,
                    sort_order=file.sort_order,
                    width=file.width,
                    height=file.height,
                    created_at=file.created_at,
                )
                for file in files_sorted
            ],
            created_at=resource.created_at,
            updated_at=resource.updated_at,
        )

    async def _get_resource_or_404(
        self,
        session: AsyncSession,
        resource_id: int,
    ) -> LibraryResourceORM:
        """Load a resource (with files eagerly) or raise HTTP 404."""
        result = await session.execute(
            select(LibraryResourceORM)
            .options(selectinload(LibraryResourceORM.files))
            .where(LibraryResourceORM.id == resource_id)
        )
        resource = result.scalar_one_or_none()
        if resource is None:
            raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Resource not found")
        return resource

    def _validate_existing_resource(self, resource: LibraryResourceORM) -> None:
        """Re-assert type invariants on an already-persisted resource.

        Mirrors the type-specific rules of _validate_payload so a PATCH
        cannot strip required fields (e.g. clearing a model's gender).
        """
        if not resource.name.strip():
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Resource name is required")
        if resource.resource_type == LibraryResourceType.MODEL:
            if not resource.gender or not resource.age_group:
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail="Model resources require gender and age_group",
                )
        elif resource.resource_type == LibraryResourceType.SCENE:
            if not resource.environment:
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail="Scene resources require environment",
                )
        elif resource.resource_type == LibraryResourceType.GARMENT and not resource.category:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Garment resources require category",
            )

View File

@@ -1,9 +1,12 @@
"""Application settings."""
from functools import lru_cache
from pathlib import Path
from pydantic_settings import BaseSettings, SettingsConfigDict
PROJECT_ROOT = Path(__file__).resolve().parents[2]
class Settings(BaseSettings):
"""Runtime settings loaded from environment variables."""
@@ -15,11 +18,25 @@ class Settings(BaseSettings):
database_url: str = "sqlite+aiosqlite:///./temporal_demo.db"
temporal_address: str = "localhost:7233"
temporal_namespace: str = "default"
s3_access_key: str = ""
s3_secret_key: str = ""
s3_bucket: str = ""
s3_region: str = ""
s3_endpoint: str = ""
s3_cname: str = ""
s3_presign_expiry_seconds: int = 900
image_generation_provider: str = "mock"
gemini_api_key: str = ""
gemini_base_url: str = "https://api.museidea.com/v1beta"
gemini_model: str = "gemini-3.1-flash-image-preview"
gemini_timeout_seconds: int = 300
gemini_max_attempts: int = 2
model_config = SettingsConfigDict(
env_file=".env",
env_file=PROJECT_ROOT / ".env",
env_file_encoding="utf-8",
case_sensitive=False,
extra="ignore",
)
@property
@@ -34,4 +51,3 @@ def get_settings() -> Settings:
"""Return the cached application settings."""
return Settings()

View File

@@ -82,3 +82,26 @@ class AssetType(str, Enum):
QC_CANDIDATE = "qc_candidate"
MANUAL_REVISION = "manual_revision"
FINAL = "final"
class LibraryResourceType(str, Enum):
    """Supported resource-library item types.

    Values are persisted as plain strings (the tables use native_enum=False),
    so renaming a value requires a data migration.
    """

    MODEL = "model"
    SCENE = "scene"
    GARMENT = "garment"
class LibraryFileRole(str, Enum):
    """Supported file roles within a library resource.

    The library service requires exactly one ORIGINAL and one THUMBNAIL per
    resource; GALLERY files are unbounded extras.
    """

    ORIGINAL = "original"
    THUMBNAIL = "thumbnail"
    GALLERY = "gallery"
class LibraryResourceStatus(str, Enum):
    """Lifecycle state for a library resource.

    Deletion is soft: resources move to ARCHIVED and are excluded from
    listings rather than being removed from the database.
    """

    ACTIVE = "active"
    ARCHIVED = "archived"

View File

@@ -15,12 +15,11 @@ class Order:
service_mode: ServiceMode
status: OrderStatus
model_id: int
pose_id: int
pose_id: int | None
garment_asset_id: int
scene_ref_asset_id: int
scene_ref_asset_id: int | None
final_asset_id: int | None
workflow_id: str | None
current_step: WorkflowStepName | None
created_at: datetime
updated_at: datetime

View File

@@ -0,0 +1,45 @@
"""Library resource ORM model."""
from __future__ import annotations
from sqlalchemy import Enum, Integer, JSON, String, Text
from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.domain.enums import LibraryResourceStatus, LibraryResourceType
from app.infra.db.base import Base, TimestampMixin
class LibraryResourceORM(TimestampMixin, Base):
    """Persisted library resource independent from order-generated assets."""
    __tablename__ = "library_resources"
    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
    # Stored as a plain VARCHAR (native_enum=False) so the same schema works
    # on SQLite and Postgres without a database-level enum type.
    resource_type: Mapped[LibraryResourceType] = mapped_column(
        Enum(LibraryResourceType, native_enum=False),
        nullable=False,
        index=True,
    )
    name: Mapped[str] = mapped_column(String(255), nullable=False)
    description: Mapped[str | None] = mapped_column(Text, nullable=True)
    # JSON array of free-form tag strings; `default=list` is a factory, so
    # each row gets its own empty list rather than a shared mutable default.
    tags: Mapped[list[str]] = mapped_column(JSON, nullable=False, default=list)
    status: Mapped[LibraryResourceStatus] = mapped_column(
        Enum(LibraryResourceStatus, native_enum=False),
        nullable=False,
        default=LibraryResourceStatus.ACTIVE,
        index=True,
    )
    # Optional filter attributes; presumably type-specific (gender/age_group
    # for models, environment for scenes, category for garments) — confirm
    # against the API layer.
    gender: Mapped[str | None] = mapped_column(String(32), nullable=True, index=True)
    age_group: Mapped[str | None] = mapped_column(String(32), nullable=True, index=True)
    pose_id: Mapped[int | None] = mapped_column(Integer, nullable=True)
    environment: Mapped[str | None] = mapped_column(String(32), nullable=True, index=True)
    category: Mapped[str | None] = mapped_column(String(128), nullable=True, index=True)
    # Plain integer references into library_resource_files (no ForeignKey at
    # the DB level); integrity is maintained by application code.
    cover_file_id: Mapped[int | None] = mapped_column(Integer, nullable=True)
    original_file_id: Mapped[int | None] = mapped_column(Integer, nullable=True)
    # Eagerly loaded via SELECT IN so listings can read files without lazy
    # awaits; removing a resource also deletes its file rows (delete-orphan).
    files = relationship(
        "LibraryResourceFileORM",
        back_populates="resource",
        lazy="selectin",
        cascade="all, delete-orphan",
    )

View File

@@ -0,0 +1,37 @@
"""Library resource file ORM model."""
from __future__ import annotations
from sqlalchemy import Enum, ForeignKey, Integer, String
from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.domain.enums import LibraryFileRole
from app.infra.db.base import Base, TimestampMixin
class LibraryResourceFileORM(TimestampMixin, Base):
    """Persisted uploaded file metadata for a library resource."""
    __tablename__ = "library_resource_files"
    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
    # Owning resource; indexed because files are always queried per resource.
    resource_id: Mapped[int] = mapped_column(
        ForeignKey("library_resources.id"),
        nullable=False,
        index=True,
    )
    # Role of this file within the resource (original/thumbnail/gallery),
    # stored as a plain string column (native_enum=False).
    file_role: Mapped[LibraryFileRole] = mapped_column(
        Enum(LibraryFileRole, native_enum=False),
        nullable=False,
        index=True,
    )
    # Object-store key (path inside the bucket).
    storage_key: Mapped[str] = mapped_column(String(500), nullable=False)
    # Fully-qualified URL derived from the CDN/endpoint configuration.
    public_url: Mapped[str] = mapped_column(String(500), nullable=False)
    bucket: Mapped[str] = mapped_column(String(255), nullable=False)
    mime_type: Mapped[str] = mapped_column(String(255), nullable=False)
    size_bytes: Mapped[int] = mapped_column(Integer, nullable=False)
    # Display ordering within a role (e.g. gallery sequence).
    sort_order: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
    # Pixel dimensions when known; nullable for uploads whose size was not
    # measured.
    width: Mapped[int | None] = mapped_column(Integer, nullable=True)
    height: Mapped[int | None] = mapped_column(Integer, nullable=True)
    resource = relationship("LibraryResourceORM", back_populates="files")

View File

@@ -27,12 +27,11 @@ class OrderORM(TimestampMixin, Base):
default=OrderStatus.CREATED,
)
model_id: Mapped[int] = mapped_column(Integer, nullable=False)
pose_id: Mapped[int] = mapped_column(Integer, nullable=False)
pose_id: Mapped[int | None] = mapped_column(Integer, nullable=True)
garment_asset_id: Mapped[int] = mapped_column(Integer, nullable=False)
scene_ref_asset_id: Mapped[int] = mapped_column(Integer, nullable=False)
scene_ref_asset_id: Mapped[int | None] = mapped_column(Integer, nullable=True)
final_asset_id: Mapped[int | None] = mapped_column(Integer, nullable=True)
assets = relationship("AssetORM", back_populates="order", lazy="selectin")
review_tasks = relationship("ReviewTaskORM", back_populates="order", lazy="selectin")
workflow_runs = relationship("WorkflowRunORM", back_populates="order", lazy="selectin")

View File

@@ -44,12 +44,14 @@ async def init_database() -> None:
"""Create database tables when running the MVP without migrations."""
from app.infra.db.models.asset import AssetORM
from app.infra.db.models.library_resource import LibraryResourceORM
from app.infra.db.models.library_resource_file import LibraryResourceFileORM
from app.infra.db.models.order import OrderORM
from app.infra.db.models.review_task import ReviewTaskORM
from app.infra.db.models.workflow_run import WorkflowRunORM
from app.infra.db.models.workflow_step import WorkflowStepORM
del AssetORM, OrderORM, ReviewTaskORM, WorkflowRunORM, WorkflowStepORM
del AssetORM, LibraryResourceORM, LibraryResourceFileORM, OrderORM, ReviewTaskORM, WorkflowRunORM, WorkflowStepORM
async with get_async_engine().begin() as connection:
await connection.run_sync(Base.metadata.create_all)

View File

@@ -0,0 +1,38 @@
"""Shared types for image-generation providers."""
from __future__ import annotations
from dataclasses import dataclass
from typing import Protocol
@dataclass(slots=True)
class SourceImage:
    """A binary source image passed into an image-generation provider."""
    # Where the bytes came from (e.g. a CDN/S3 URL); informational/tracing.
    url: str
    # IANA media type (e.g. "image/png"); forwarded as the inline-data type.
    mime_type: str
    # Raw image bytes; providers send these (base64-encoded) rather than
    # fetching `url` themselves.
    data: bytes
@dataclass(slots=True)
class GeneratedImageResult:
    """Normalized image-generation output returned by providers."""
    # Decoded image payload ready for upload/persistence.
    image_bytes: bytes
    # Media type reported by the provider (callers may default it upstream).
    mime_type: str
    # Provider identifier (e.g. "gemini") for traceability metadata.
    provider: str
    # Model name used for this generation.
    model: str
    # The exact prompt sent, recorded for audit/debugging.
    prompt: str
class ImageGenerationProvider(Protocol):
    """Contract implemented by concrete image-generation providers."""
    # Structural (duck-typed) protocol: any object exposing this coroutine
    # signature qualifies; no inheritance required.
    async def generate_tryon_image(
        self,
        *,
        prompt: str,
        person_image: SourceImage,
        garment_image: SourceImage,
    ) -> GeneratedImageResult: ...

View File

@@ -0,0 +1,149 @@
"""Gemini-backed image-generation provider."""
from __future__ import annotations
import base64
import httpx
from app.infra.image_generation.base import GeneratedImageResult, SourceImage
class GeminiImageProvider:
    """Call the Gemini image-generation API via a configurable REST endpoint."""
    provider_name = "gemini"
    def __init__(
        self,
        *,
        api_key: str,
        base_url: str,
        model: str,
        timeout_seconds: int,
        max_attempts: int = 2,
        http_client: httpx.AsyncClient | None = None,
    ) -> None:
        """Configure the provider.

        Args:
            api_key: Key sent in the ``x-goog-api-key`` request header.
            base_url: API root (trailing slash stripped), e.g. ``.../v1beta``.
            model: Model segment of the ``models/{model}:generateContent`` URL.
            timeout_seconds: Timeout applied to each POST request.
            max_attempts: Total attempts on transport failures; clamped to >= 1.
            http_client: Optional caller-owned client. When given it is reused
                and never closed here; otherwise a short-lived client is
                created and closed on every generation call.
        """
        self.api_key = api_key
        self.base_url = base_url.rstrip("/")
        self.model = model
        self.timeout_seconds = timeout_seconds
        # Guard against nonsensical configuration (zero/negative attempts).
        self.max_attempts = max(1, max_attempts)
        self._http_client = http_client
    async def generate_tryon_image(
        self,
        *,
        prompt: str,
        person_image: SourceImage,
        garment_image: SourceImage,
    ) -> GeneratedImageResult:
        """Generate a try-on image from a prepared person image and a garment image."""
        # Try-on and scene composition share the same two-image edit request;
        # only the roles of the two inputs differ.
        return await self._generate_two_image_edit(
            prompt=prompt,
            first_image=person_image,
            second_image=garment_image,
        )
    async def generate_scene_image(
        self,
        *,
        prompt: str,
        source_image: SourceImage,
        scene_image: SourceImage,
    ) -> GeneratedImageResult:
        """Generate a scene-composited image from a rendered subject image and a scene reference."""
        return await self._generate_two_image_edit(
            prompt=prompt,
            first_image=source_image,
            second_image=scene_image,
        )
    async def _generate_two_image_edit(
        self,
        *,
        prompt: str,
        first_image: SourceImage,
        second_image: SourceImage,
    ) -> GeneratedImageResult:
        """Generate an edited image from a prompt plus two inline image inputs.

        Raises:
            httpx.HTTPStatusError: On non-2xx responses (not retried).
            httpx.TransportError: After ``max_attempts`` network failures.
            ValueError: When the response carries no decodable image part.
        """
        # generateContent payload: one text part followed by two inline
        # (base64-encoded) image parts.
        payload = {
            "contents": [
                {
                    "parts": [
                        {"text": prompt},
                        {
                            "inline_data": {
                                "mime_type": first_image.mime_type,
                                "data": base64.b64encode(first_image.data).decode("utf-8"),
                            }
                        },
                        {
                            "inline_data": {
                                "mime_type": second_image.mime_type,
                                "data": base64.b64encode(second_image.data).decode("utf-8"),
                            }
                        },
                    ]
                }
            ],
            "generationConfig": {
                # Request an image modality back; TEXT alone would omit it.
                "responseModalities": ["TEXT", "IMAGE"],
            },
        }
        # Close the client afterwards only if this call created it.
        owns_client = self._http_client is None
        client = self._http_client or httpx.AsyncClient(timeout=self.timeout_seconds)
        try:
            response = None
            # Retry only transport-level failures (DNS/connect/read). HTTP
            # error statuses from raise_for_status propagate immediately
            # without a retry.
            for attempt in range(1, self.max_attempts + 1):
                try:
                    response = await client.post(
                        f"{self.base_url}/models/{self.model}:generateContent",
                        headers={
                            "x-goog-api-key": self.api_key,
                            "Content-Type": "application/json",
                        },
                        json=payload,
                        timeout=self.timeout_seconds,
                    )
                    response.raise_for_status()
                    break
                except httpx.TransportError:
                    if attempt >= self.max_attempts:
                        raise
        finally:
            if owns_client:
                await client.aclose()
        # Parsing happens after the owned client may have been closed; this is
        # safe because httpx buffers the full body for non-streaming requests.
        if response is None:
            # Defensive: the loop either sets `response` or re-raises.
            raise RuntimeError("Gemini provider did not receive a response")
        body = response.json()
        image_part = self._find_image_part(body)
        # The REST API emits camelCase keys; snake_case is accepted as well
        # (presumably for proxy/compat layers — confirm before removing).
        image_data = image_part.get("inlineData") or image_part.get("inline_data")
        mime_type = image_data.get("mimeType") or image_data.get("mime_type") or "image/png"
        data = image_data.get("data")
        if not data:
            raise ValueError("Gemini response did not include image bytes")
        return GeneratedImageResult(
            image_bytes=base64.b64decode(data),
            mime_type=mime_type,
            provider=self.provider_name,
            model=self.model,
            prompt=prompt,
        )
    @staticmethod
    def _find_image_part(body: dict) -> dict:
        """Return the first response part carrying inline image data.

        Raises:
            ValueError: When no candidate part contains inline image data.
        """
        candidates = body.get("candidates") or []
        for candidate in candidates:
            content = candidate.get("content") or {}
            for part in content.get("parts") or []:
                if part.get("inlineData") or part.get("inline_data"):
                    return part
        raise ValueError("Gemini response did not contain an image part")

107
app/infra/storage/s3.py Normal file
View File

@@ -0,0 +1,107 @@
"""S3 direct-upload helpers."""
from __future__ import annotations

import asyncio
from pathlib import Path
from uuid import uuid4

import boto3

from app.config.settings import get_settings
from app.domain.enums import LibraryResourceType, WorkflowStepName
RESOURCE_PREFIXES: dict[LibraryResourceType, str] = {
LibraryResourceType.MODEL: "models",
LibraryResourceType.SCENE: "scenes",
LibraryResourceType.GARMENT: "garments",
}
class S3PresignService:
    """Generate presigned upload URLs and derived public URLs."""
    def __init__(self) -> None:
        """Build an S3 client from application settings (empty values -> SDK defaults)."""
        self.settings = get_settings()
        self._client = boto3.client(
            "s3",
            region_name=self.settings.s3_region or None,
            endpoint_url=self.settings.s3_endpoint or None,
            aws_access_key_id=self.settings.s3_access_key or None,
            aws_secret_access_key=self.settings.s3_secret_key or None,
        )
    def create_upload(self, resource_type: LibraryResourceType, file_name: str, content_type: str) -> tuple[str, str]:
        """Return a storage key and presigned PUT URL for a resource file."""
        key = self._build_storage_key(resource_type, file_name)
        # The presigned URL pins bucket, key, and content type; expiry comes
        # from settings so ops can tune it without a deploy.
        presigned_put = self._client.generate_presigned_url(
            "put_object",
            Params={
                "Bucket": self.settings.s3_bucket,
                "Key": key,
                "ContentType": content_type,
            },
            ExpiresIn=self.settings.s3_presign_expiry_seconds,
            HttpMethod="PUT",
        )
        return key, presigned_put
    def get_public_url(self, storage_key: str) -> str:
        """Return the public CDN URL for an uploaded object."""
        cname = self.settings.s3_cname
        if cname:
            # A bare hostname in settings is promoted to an https URL.
            if not cname.startswith(("http://", "https://")):
                cname = f"https://{cname}"
            return f"{cname.rstrip('/')}/{storage_key}"
        # No CDN configured: fall back to endpoint/bucket addressing.
        return f"{self.settings.s3_endpoint.rstrip('/')}/{self.settings.s3_bucket}/{storage_key}"
    def _build_storage_key(self, resource_type: LibraryResourceType, file_name: str) -> str:
        """Build a unique library key that keeps a slug of the upload name."""
        path = Path(file_name)
        extension = (path.suffix or ".bin").lower()
        slug = path.stem.replace(" ", "-").lower() or "file"
        return f"library/{RESOURCE_PREFIXES[resource_type]}/{uuid4().hex}-{slug}{extension}"
class S3ObjectStorageService:
    """Upload generated workflow artifacts to the configured object store."""
    def __init__(self) -> None:
        """Build an S3 client plus a presign helper for deriving public URLs."""
        self.settings = get_settings()
        self._client = boto3.client(
            "s3",
            region_name=self.settings.s3_region or None,
            endpoint_url=self.settings.s3_endpoint or None,
            aws_access_key_id=self.settings.s3_access_key or None,
            aws_secret_access_key=self.settings.s3_secret_key or None,
        )
        # Reused only for get_public_url; note it creates a second boto3 client.
        self._presign = S3PresignService()
    async def upload_generated_image(
        self,
        *,
        order_id: int,
        step_name: WorkflowStepName,
        image_bytes: bytes,
        mime_type: str,
    ) -> tuple[str, str]:
        """Upload bytes and return the storage key plus public URL.

        Args:
            order_id: Order the artifact belongs to (key namespace).
            step_name: Workflow step producing the artifact (key namespace).
            image_bytes: Raw image payload to store.
            mime_type: Content type; also selects the key's file extension.
        """
        storage_key = self._build_storage_key(order_id=order_id, step_name=step_name, mime_type=mime_type)
        # boto3 is synchronous; run the upload in a worker thread so the event
        # loop shared with other async activities is not blocked for the
        # duration of the network transfer.
        await asyncio.to_thread(
            self._client.put_object,
            Bucket=self.settings.s3_bucket,
            Key=storage_key,
            Body=image_bytes,
            ContentType=mime_type,
        )
        return storage_key, self._presign.get_public_url(storage_key)
    @staticmethod
    def _build_storage_key(*, order_id: int, step_name: WorkflowStepName, mime_type: str) -> str:
        """Derive an order-scoped unique key with an extension matching mime_type."""
        suffix = {
            "image/png": ".png",
            "image/jpeg": ".jpg",
            "image/webp": ".webp",
        }.get(mime_type, ".bin")
        return f"orders/{order_id}/{step_name.value}/{uuid4().hex}{suffix}"

View File

@@ -6,6 +6,7 @@ from fastapi import FastAPI
from app.api.routers.assets import router as assets_router
from app.api.routers.health import router as health_router
from app.api.routers.library import router as library_router
from app.api.routers.orders import router as orders_router
from app.api.routers.revisions import router as revisions_router
from app.api.routers.reviews import router as reviews_router
@@ -30,6 +31,7 @@ def create_app() -> FastAPI:
settings = get_settings()
app = FastAPI(title=settings.app_name, debug=settings.debug, lifespan=lifespan)
app.include_router(health_router)
app.include_router(library_router, prefix=settings.api_prefix)
app.include_router(orders_router, prefix=settings.api_prefix)
app.include_router(assets_router, prefix=settings.api_prefix)
app.include_router(revisions_router, prefix=settings.api_prefix)

View File

@@ -1,20 +1,82 @@
"""Export mock activity."""
"""Export activity."""
from app.domain.enums import AssetType, OrderStatus, StepStatus
from app.infra.db.models.asset import AssetORM
from app.infra.db.session import get_session_factory
from temporalio import activity
from app.domain.enums import AssetType
from app.workers.activities.tryon_activities import execute_asset_step
from app.workers.activities.tryon_activities import create_step_record, jsonable, load_order_and_run, utc_now
from app.workers.workflows.types import MockActivityResult, StepActivityInput
@activity.defn
async def run_export_activity(payload: StepActivityInput) -> MockActivityResult:
"""Mock final asset export."""
"""Finalize the chosen source asset as the order's exported deliverable."""
return await execute_asset_step(
payload,
AssetType.FINAL,
filename="final.png",
finalize=True,
)
if payload.source_asset_id is None:
raise ValueError("run_export_activity requires source_asset_id")
async with get_session_factory()() as session:
order, workflow_run = await load_order_and_run(session, payload.order_id, payload.workflow_run_id)
step = create_step_record(payload)
session.add(step)
order.status = OrderStatus.RUNNING
workflow_run.status = OrderStatus.RUNNING
workflow_run.current_step = payload.step_name
await session.flush()
try:
source_asset = await session.get(AssetORM, payload.source_asset_id)
if source_asset is None:
raise ValueError(f"Source asset {payload.source_asset_id} not found")
if source_asset.order_id != payload.order_id:
raise ValueError(
f"Source asset {payload.source_asset_id} does not belong to order {payload.order_id}"
)
metadata = {
**payload.metadata,
"source_asset_id": payload.source_asset_id,
"selected_asset_id": payload.selected_asset_id,
}
metadata = {key: value for key, value in metadata.items() if value is not None}
asset = AssetORM(
order_id=payload.order_id,
asset_type=AssetType.FINAL,
step_name=payload.step_name,
uri=source_asset.uri,
metadata_json=jsonable(metadata),
)
session.add(asset)
await session.flush()
result = MockActivityResult(
step_name=payload.step_name,
success=True,
asset_id=asset.id,
uri=asset.uri,
score=0.95,
passed=True,
message="mock success",
metadata=jsonable(metadata) or {},
)
order.final_asset_id = asset.id
order.status = OrderStatus.SUCCEEDED
workflow_run.status = OrderStatus.SUCCEEDED
step.step_status = StepStatus.SUCCEEDED
step.output_json = jsonable(result)
step.ended_at = utc_now()
await session.commit()
return result
except Exception as exc:
step.step_status = StepStatus.FAILED
step.error_message = str(exc)
step.ended_at = utc_now()
order.status = OrderStatus.FAILED
workflow_run.status = OrderStatus.FAILED
await session.commit()
raise

View File

@@ -29,11 +29,22 @@ async def run_qc_activity(payload: StepActivityInput) -> MockActivityResult:
candidate_uri: str | None = None
if passed:
if payload.source_asset_id is None:
raise ValueError("run_qc_activity requires source_asset_id")
source_asset = await session.get(AssetORM, payload.source_asset_id)
if source_asset is None:
raise ValueError(f"Source asset {payload.source_asset_id} not found")
if source_asset.order_id != payload.order_id:
raise ValueError(
f"Source asset {payload.source_asset_id} does not belong to order {payload.order_id}"
)
candidate = AssetORM(
order_id=payload.order_id,
asset_type=AssetType.QC_CANDIDATE,
step_name=payload.step_name,
uri=mock_uri(payload.order_id, payload.step_name.value, "candidate.png"),
uri=source_asset.uri,
metadata_json=jsonable({"source_asset_id": payload.source_asset_id}),
)
session.add(candidate)

View File

@@ -1,19 +1,122 @@
"""Scene mock activity."""
"""Scene activities."""
from temporalio import activity
from app.domain.enums import AssetType
from app.workers.activities.tryon_activities import execute_asset_step
from app.domain.enums import AssetType, LibraryResourceType, OrderStatus, StepStatus
from app.infra.db.models.asset import AssetORM
from app.infra.db.session import get_session_factory
from app.workers.activities.tryon_activities import (
create_step_record,
execute_asset_step,
get_image_generation_service,
get_order_artifact_storage_service,
jsonable,
load_active_library_resource,
load_order_and_run,
utc_now,
)
from app.workers.workflows.types import MockActivityResult, StepActivityInput
@activity.defn
async def run_scene_activity(payload: StepActivityInput) -> MockActivityResult:
"""Mock scene replacement."""
"""Generate a scene-composited asset, or fall back to mock mode when configured."""
return await execute_asset_step(
payload,
AssetType.SCENE,
extra_metadata={"scene_ref_asset_id": payload.scene_ref_asset_id},
)
service = get_image_generation_service()
if service.__class__.__name__ == "MockImageGenerationService":
return await execute_asset_step(
payload,
AssetType.SCENE,
extra_metadata={"scene_ref_asset_id": payload.scene_ref_asset_id},
)
if payload.source_asset_id is None:
raise ValueError("run_scene_activity requires source_asset_id")
if payload.scene_ref_asset_id is None:
raise ValueError("run_scene_activity requires scene_ref_asset_id")
async with get_session_factory()() as session:
order, workflow_run = await load_order_and_run(session, payload.order_id, payload.workflow_run_id)
step = create_step_record(payload)
session.add(step)
order.status = OrderStatus.RUNNING
workflow_run.status = OrderStatus.RUNNING
workflow_run.current_step = payload.step_name
await session.flush()
try:
source_asset = await session.get(AssetORM, payload.source_asset_id)
if source_asset is None:
raise ValueError(f"Source asset {payload.source_asset_id} not found")
if source_asset.order_id != payload.order_id:
raise ValueError(
f"Source asset {payload.source_asset_id} does not belong to order {payload.order_id}"
)
scene_resource, scene_original = await load_active_library_resource(
session,
payload.scene_ref_asset_id,
resource_type=LibraryResourceType.SCENE,
)
generated = await service.generate_scene_image(
source_image_url=source_asset.uri,
scene_image_url=scene_original.public_url,
)
storage_key, public_url = await get_order_artifact_storage_service().upload_generated_image(
order_id=payload.order_id,
step_name=payload.step_name,
image_bytes=generated.image_bytes,
mime_type=generated.mime_type,
)
metadata = {
**payload.metadata,
"source_asset_id": source_asset.id,
"scene_ref_asset_id": payload.scene_ref_asset_id,
"scene_resource_id": scene_resource.id,
"scene_original_file_id": scene_original.id,
"scene_original_url": scene_original.public_url,
"provider": generated.provider,
"model": generated.model,
"storage_key": storage_key,
"mime_type": generated.mime_type,
"prompt": generated.prompt,
}
metadata = {key: value for key, value in metadata.items() if value is not None}
asset = AssetORM(
order_id=payload.order_id,
asset_type=AssetType.SCENE,
step_name=payload.step_name,
uri=public_url,
metadata_json=jsonable(metadata),
)
session.add(asset)
await session.flush()
result = MockActivityResult(
step_name=payload.step_name,
success=True,
asset_id=asset.id,
uri=asset.uri,
score=1.0,
passed=True,
message="scene generated",
metadata=jsonable(metadata) or {},
)
step.step_status = StepStatus.SUCCEEDED
step.output_json = jsonable(result)
step.ended_at = utc_now()
await session.commit()
return result
except Exception as exc:
step.step_status = StepStatus.FAILED
step.error_message = str(exc)
step.ended_at = utc_now()
order.status = OrderStatus.FAILED
workflow_run.status = OrderStatus.FAILED
await session.commit()
raise

View File

@@ -1,4 +1,4 @@
"""Prepare-model and try-on mock activities plus shared helpers."""
"""Prepare-model and try-on activities plus shared helpers."""
from __future__ import annotations
@@ -8,14 +8,20 @@ from enum import Enum
from typing import Any
from uuid import uuid4
from sqlalchemy import select
from sqlalchemy.orm import selectinload
from temporalio import activity
from app.domain.enums import AssetType, OrderStatus, StepStatus
from app.application.services.image_generation_service import build_image_generation_service
from app.domain.enums import AssetType, LibraryResourceStatus, LibraryResourceType, OrderStatus, StepStatus
from app.infra.db.models.asset import AssetORM
from app.infra.db.models.library_resource import LibraryResourceORM
from app.infra.db.models.library_resource_file import LibraryResourceFileORM
from app.infra.db.models.order import OrderORM
from app.infra.db.models.workflow_run import WorkflowRunORM
from app.infra.db.models.workflow_step import WorkflowStepORM
from app.infra.db.session import get_session_factory
from app.infra.storage.s3 import S3ObjectStorageService
from app.workers.workflows.types import MockActivityResult, StepActivityInput
@@ -71,6 +77,64 @@ def create_step_record(payload: StepActivityInput) -> WorkflowStepORM:
)
async def load_active_library_resource(
    session,
    resource_id: int,
    *,
    resource_type: LibraryResourceType,
) -> tuple[LibraryResourceORM, LibraryResourceFileORM]:
    """Fetch one active library resource plus its original file, validating both."""
    query = (
        select(LibraryResourceORM)
        .options(selectinload(LibraryResourceORM.files))
        .where(LibraryResourceORM.id == resource_id)
    )
    found = (await session.execute(query)).scalar_one_or_none()
    # Validation gauntlet: existence, expected type, active status, and a
    # resolvable original-file reference — each failure is a distinct error.
    if found is None:
        raise ValueError(f"Library resource {resource_id} not found")
    if found.resource_type != resource_type:
        raise ValueError(f"Resource {resource_id} is not a {resource_type.value} resource")
    if found.status != LibraryResourceStatus.ACTIVE:
        raise ValueError(f"Library resource {resource_id} is not active")
    if found.original_file_id is None:
        raise ValueError(f"Library resource {resource_id} is missing an original file")
    original = None
    for candidate in found.files:
        if candidate.id == found.original_file_id:
            original = candidate
            break
    if original is None:
        raise ValueError(f"Library resource {resource_id} original file record not found")
    return found, original
def get_image_generation_service():
    """Return the configured image-generation service."""
    # Indirection point so activities resolve the provider at call time; a
    # fresh service is built on every call rather than cached.
    return build_image_generation_service()
def get_order_artifact_storage_service() -> S3ObjectStorageService:
    """Return the object-storage service for workflow-generated images."""
    # Constructed per call; each instance creates its own boto3 client.
    return S3ObjectStorageService()
def build_resource_input_snapshot(
    resource: LibraryResourceORM,
    original_file: LibraryResourceFileORM,
) -> dict[str, Any]:
    """Summarize one library input (resource plus original file) as plain metadata."""
    snapshot: dict[str, Any] = {}
    # Resource identity.
    snapshot["resource_id"] = resource.id
    snapshot["resource_name"] = resource.name
    # Original-file details consumed by the frontend/metadata trail.
    snapshot["original_file_id"] = original_file.id
    snapshot["original_url"] = original_file.public_url
    snapshot["mime_type"] = original_file.mime_type
    snapshot["width"] = original_file.width
    snapshot["height"] = original_file.height
    return snapshot
async def execute_asset_step(
payload: StepActivityInput,
asset_type: AssetType,
@@ -145,26 +209,213 @@ async def execute_asset_step(
@activity.defn
async def prepare_model_activity(payload: StepActivityInput) -> MockActivityResult:
"""Mock model preparation for the pipeline."""
"""Resolve a model resource into an order-scoped prepared-model asset."""
return await execute_asset_step(
payload,
AssetType.PREPARED_MODEL,
extra_metadata={
"model_id": payload.model_id,
"pose_id": payload.pose_id,
"garment_asset_id": payload.garment_asset_id,
"scene_ref_asset_id": payload.scene_ref_asset_id,
},
)
if payload.model_id is None:
raise ValueError("prepare_model_activity requires model_id")
async with get_session_factory()() as session:
order, workflow_run = await load_order_and_run(session, payload.order_id, payload.workflow_run_id)
step = create_step_record(payload)
session.add(step)
order.status = OrderStatus.RUNNING
workflow_run.status = OrderStatus.RUNNING
workflow_run.current_step = payload.step_name
await session.flush()
try:
resource, original_file = await load_active_library_resource(
session,
payload.model_id,
resource_type=LibraryResourceType.MODEL,
)
garment_snapshot = None
if payload.garment_asset_id is not None:
garment_resource, garment_original = await load_active_library_resource(
session,
payload.garment_asset_id,
resource_type=LibraryResourceType.GARMENT,
)
garment_snapshot = build_resource_input_snapshot(garment_resource, garment_original)
scene_snapshot = None
if payload.scene_ref_asset_id is not None:
scene_resource, scene_original = await load_active_library_resource(
session,
payload.scene_ref_asset_id,
resource_type=LibraryResourceType.SCENE,
)
scene_snapshot = build_resource_input_snapshot(scene_resource, scene_original)
metadata = {
**payload.metadata,
"model_id": payload.model_id,
"pose_id": payload.pose_id,
"garment_asset_id": payload.garment_asset_id,
"scene_ref_asset_id": payload.scene_ref_asset_id,
"library_resource_id": resource.id,
"library_original_file_id": original_file.id,
"library_original_url": original_file.public_url,
"library_original_mime_type": original_file.mime_type,
"library_original_width": original_file.width,
"library_original_height": original_file.height,
"model_input": build_resource_input_snapshot(resource, original_file),
"garment_input": garment_snapshot,
"scene_input": scene_snapshot,
"pose_input": {"pose_id": payload.pose_id} if payload.pose_id is not None else None,
"normalized": False,
}
metadata = {key: value for key, value in metadata.items() if value is not None}
asset = AssetORM(
order_id=payload.order_id,
asset_type=AssetType.PREPARED_MODEL,
step_name=payload.step_name,
uri=original_file.public_url,
metadata_json=jsonable(metadata),
)
session.add(asset)
await session.flush()
result = MockActivityResult(
step_name=payload.step_name,
success=True,
asset_id=asset.id,
uri=asset.uri,
score=1.0,
passed=True,
message="prepared model ready",
metadata=jsonable(metadata) or {},
)
step.step_status = StepStatus.SUCCEEDED
step.output_json = jsonable(result)
step.ended_at = utc_now()
await session.commit()
return result
except Exception as exc:
step.step_status = StepStatus.FAILED
step.error_message = str(exc)
step.ended_at = utc_now()
order.status = OrderStatus.FAILED
workflow_run.status = OrderStatus.FAILED
await session.commit()
raise
@activity.defn
async def run_tryon_activity(payload: StepActivityInput) -> MockActivityResult:
"""Mock try-on rendering."""
"""执行试衣渲染步骤,或在未接真实能力时走 mock 分支。
return await execute_asset_step(
payload,
AssetType.TRYON,
extra_metadata={"prepared_asset_id": payload.source_asset_id},
)
流程:
1. 读取当前配置的图片生成 service。
2. 如果当前仍是 mock service就退回通用的 mock 资产产出逻辑。
3. 如果是真实模式,先读取 prepare_model_activity 产出的 prepared_model 资产。
4. 再读取本次选中的服装资源,并解析出它的原图 URL。
5. 把“模特准备图 + 服装原图”一起发给 provider 做试衣生成。
6. 把生成结果上传到 S3并落成订单内的一条 TRYON 资产。
7. 在 metadata 里记录 provider、model、prompt 和输入来源,方便追踪排查。
"""
service = get_image_generation_service()
# 保留旧的 mock 分支,这样在没有接入真实 provider 时 workflow 也还能跑通。
if service.__class__.__name__ == "MockImageGenerationService":
return await execute_asset_step(
payload,
AssetType.TRYON,
extra_metadata={"prepared_asset_id": payload.source_asset_id},
)
if payload.source_asset_id is None:
raise ValueError("run_tryon_activity requires source_asset_id")
if payload.garment_asset_id is None:
raise ValueError("run_tryon_activity requires garment_asset_id")
async with get_session_factory()() as session:
order, workflow_run = await load_order_and_run(session, payload.order_id, payload.workflow_run_id)
step = create_step_record(payload)
session.add(step)
order.status = OrderStatus.RUNNING
workflow_run.status = OrderStatus.RUNNING
workflow_run.current_step = payload.step_name
await session.flush()
try:
# 试衣步骤的输入起点固定是上一阶段产出的 prepared_model 资产。
prepared_asset = await session.get(AssetORM, payload.source_asset_id)
if prepared_asset is None or prepared_asset.order_id != payload.order_id:
raise ValueError(f"Prepared asset {payload.source_asset_id} not found for order {payload.order_id}")
if prepared_asset.asset_type != AssetType.PREPARED_MODEL:
raise ValueError(f"Asset {payload.source_asset_id} is not a prepared_model asset")
# 服装素材来自共享资源库workflow 实际消费的是资源库里的原图。
garment_resource, garment_original = await load_active_library_resource(
session,
payload.garment_asset_id,
resource_type=LibraryResourceType.GARMENT,
)
# provider 接收两张真实输入图:模特准备图 + 服装参考图。
generated = await service.generate_tryon_image(
person_image_url=prepared_asset.uri,
garment_image_url=garment_original.public_url,
)
# workflow 生成出的结果图先上传到 S3再登记成订单资产。
storage_key, public_url = await get_order_artifact_storage_service().upload_generated_image(
order_id=payload.order_id,
step_name=payload.step_name,
image_bytes=generated.image_bytes,
mime_type=generated.mime_type,
)
# 记录足够的追踪信息,方便回溯这张试衣图由哪些输入和哪个 provider 产出。
metadata = {
**payload.metadata,
"prepared_asset_id": prepared_asset.id,
"garment_resource_id": garment_resource.id,
"garment_original_file_id": garment_original.id,
"garment_original_url": garment_original.public_url,
"provider": generated.provider,
"model": generated.model,
"storage_key": storage_key,
"mime_type": generated.mime_type,
"prompt": generated.prompt,
}
metadata = {key: value for key, value in metadata.items() if value is not None}
asset = AssetORM(
order_id=payload.order_id,
asset_type=AssetType.TRYON,
step_name=payload.step_name,
uri=public_url,
metadata_json=jsonable(metadata),
)
session.add(asset)
await session.flush()
result = MockActivityResult(
step_name=payload.step_name,
success=True,
asset_id=asset.id,
uri=asset.uri,
score=1.0,
passed=True,
message="try-on generated",
metadata=jsonable(metadata) or {},
)
step.step_status = StepStatus.SUCCEEDED
step.output_json = jsonable(result)
step.ended_at = utc_now()
await session.commit()
return result
except Exception as exc:
step.step_status = StepStatus.FAILED
step.error_message = str(exc)
step.ended_at = utc_now()
order.status = OrderStatus.FAILED
workflow_run.status = OrderStatus.FAILED
await session.commit()
raise

View File

@@ -26,14 +26,13 @@ with workflow.unsafe.imports_passed_through():
from app.workers.activities.review_activities import mark_workflow_failed_activity
from app.workers.activities.scene_activities import run_scene_activity
from app.workers.activities.tryon_activities import prepare_model_activity, run_tryon_activity
from app.workers.workflows.timeout_policy import DEFAULT_ACTIVITY_TIMEOUT, activity_timeout_for_task_queue
from app.workers.workflows.types import (
PipelineWorkflowInput,
StepActivityInput,
WorkflowFailureActivityInput,
)
ACTIVITY_TIMEOUT = timedelta(seconds=30)
ACTIVITY_RETRY_POLICY = RetryPolicy(
initial_interval=timedelta(seconds=1),
backoff_coefficient=2.0,
@@ -64,6 +63,8 @@ class LowEndPipelineWorkflow:
try:
# 每个步骤都通过 execute_activity 触发真正执行。
# workflow 自己不做计算,只负责调度。
# prepare_model 的职责是把“模特资源 + 订单上下文”整理成后续可消费的人物底图。
# 这一步产出的 prepared.asset_id 会作为 tryon 的 source_asset_id。
prepared = await workflow.execute_activity(
prepare_model_activity,
StepActivityInput(
@@ -76,12 +77,14 @@ class LowEndPipelineWorkflow:
scene_ref_asset_id=payload.scene_ref_asset_id,
),
task_queue=IMAGE_PIPELINE_CONTROL_TASK_QUEUE,
start_to_close_timeout=ACTIVITY_TIMEOUT,
start_to_close_timeout=activity_timeout_for_task_queue(IMAGE_PIPELINE_CONTROL_TASK_QUEUE),
retry_policy=ACTIVITY_RETRY_POLICY,
)
current_step = WorkflowStepName.TRYON
# 下游步骤通过 source_asset_id 引用上一步生成的资产。
# tryon 是换装主步骤:把 prepared 模特图和 garment_asset_id 组合成试衣结果。
# 它产出的 tryon_result.asset_id 是后续 scene / qc 的基础输入。
tryon_result = await workflow.execute_activity(
run_tryon_activity,
StepActivityInput(
@@ -92,38 +95,46 @@ class LowEndPipelineWorkflow:
garment_asset_id=payload.garment_asset_id,
),
task_queue=IMAGE_PIPELINE_IMAGE_GEN_TASK_QUEUE,
start_to_close_timeout=ACTIVITY_TIMEOUT,
start_to_close_timeout=activity_timeout_for_task_queue(IMAGE_PIPELINE_IMAGE_GEN_TASK_QUEUE),
retry_policy=ACTIVITY_RETRY_POLICY,
)
current_step = WorkflowStepName.SCENE
scene_result = await workflow.execute_activity(
run_scene_activity,
StepActivityInput(
order_id=payload.order_id,
workflow_run_id=payload.workflow_run_id,
step_name=WorkflowStepName.SCENE,
source_asset_id=tryon_result.asset_id,
scene_ref_asset_id=payload.scene_ref_asset_id,
),
task_queue=IMAGE_PIPELINE_IMAGE_GEN_TASK_QUEUE,
start_to_close_timeout=ACTIVITY_TIMEOUT,
retry_policy=ACTIVITY_RETRY_POLICY,
)
scene_source_asset_id = tryon_result.asset_id
if payload.scene_ref_asset_id is not None:
current_step = WorkflowStepName.SCENE
# scene 是可选步骤。
# 有场景图时,用 scene_ref_asset_id 把 tryon 结果合成到目标背景;
# 没有场景图时,直接沿用 tryon 结果继续往下走。
scene_result = await workflow.execute_activity(
run_scene_activity,
StepActivityInput(
order_id=payload.order_id,
workflow_run_id=payload.workflow_run_id,
step_name=WorkflowStepName.SCENE,
source_asset_id=tryon_result.asset_id,
scene_ref_asset_id=payload.scene_ref_asset_id,
),
task_queue=IMAGE_PIPELINE_IMAGE_GEN_TASK_QUEUE,
start_to_close_timeout=activity_timeout_for_task_queue(IMAGE_PIPELINE_IMAGE_GEN_TASK_QUEUE),
retry_policy=ACTIVITY_RETRY_POLICY,
)
scene_source_asset_id = scene_result.asset_id
current_step = WorkflowStepName.QC
# QC 是流程里的“闸门”。
# 如果这里不通过,低端流程直接失败,不会再 export。
# source_asset_id 指向当前链路里“最接近最终成品”的那张图:
# 有场景时是 scene 结果,没有场景时就是 tryon 结果。
qc_result = await workflow.execute_activity(
run_qc_activity,
StepActivityInput(
order_id=payload.order_id,
workflow_run_id=payload.workflow_run_id,
step_name=WorkflowStepName.QC,
source_asset_id=scene_result.asset_id,
source_asset_id=scene_source_asset_id,
),
task_queue=IMAGE_PIPELINE_QC_TASK_QUEUE,
start_to_close_timeout=ACTIVITY_TIMEOUT,
start_to_close_timeout=activity_timeout_for_task_queue(IMAGE_PIPELINE_QC_TASK_QUEUE),
retry_policy=ACTIVITY_RETRY_POLICY,
)
@@ -134,16 +145,17 @@ class LowEndPipelineWorkflow:
current_step = WorkflowStepName.EXPORT
# candidate_asset_ids 是 QC 推荐可导出的候选资产。
# 当前 MVP 只会返回一个候选;如果没有,就退回 scene 结果导出。
# export 是最后的收口步骤:生成最终交付资产,并把订单状态写成 succeeded。
final_result = await workflow.execute_activity(
run_export_activity,
StepActivityInput(
order_id=payload.order_id,
workflow_run_id=payload.workflow_run_id,
step_name=WorkflowStepName.EXPORT,
source_asset_id=(qc_result.candidate_asset_ids or [scene_result.asset_id])[0],
source_asset_id=(qc_result.candidate_asset_ids or [scene_source_asset_id])[0],
),
task_queue=IMAGE_PIPELINE_EXPORT_TASK_QUEUE,
start_to_close_timeout=ACTIVITY_TIMEOUT,
start_to_close_timeout=activity_timeout_for_task_queue(IMAGE_PIPELINE_EXPORT_TASK_QUEUE),
retry_policy=ACTIVITY_RETRY_POLICY,
)
return {
@@ -177,6 +189,6 @@ class LowEndPipelineWorkflow:
message=message,
),
task_queue=IMAGE_PIPELINE_CONTROL_TASK_QUEUE,
start_to_close_timeout=ACTIVITY_TIMEOUT,
start_to_close_timeout=DEFAULT_ACTIVITY_TIMEOUT,
retry_policy=ACTIVITY_RETRY_POLICY,
)

View File

@@ -34,6 +34,7 @@ with workflow.unsafe.imports_passed_through():
from app.workers.activities.scene_activities import run_scene_activity
from app.workers.activities.texture_activities import run_texture_activity
from app.workers.activities.tryon_activities import prepare_model_activity, run_tryon_activity
from app.workers.workflows.timeout_policy import DEFAULT_ACTIVITY_TIMEOUT, activity_timeout_for_task_queue
from app.workers.workflows.types import (
MockActivityResult,
PipelineWorkflowInput,
@@ -44,8 +45,6 @@ with workflow.unsafe.imports_passed_through():
WorkflowFailureActivityInput,
)
ACTIVITY_TIMEOUT = timedelta(seconds=30)
ACTIVITY_RETRY_POLICY = RetryPolicy(
initial_interval=timedelta(seconds=1),
backoff_coefficient=2.0,
@@ -87,6 +86,8 @@ class MidEndPipelineWorkflow:
# current_step 用于失败时记录“最后跑到哪一步”。
current_step = WorkflowStepName.PREPARE_MODEL
try:
# prepare_model / tryon / scene 这三段和低端流程共享同一套 activity
# 区别只在于中端流程后面还会继续进入 texture / face / fusion / review。
prepared = await workflow.execute_activity(
prepare_model_activity,
StepActivityInput(
@@ -99,11 +100,12 @@ class MidEndPipelineWorkflow:
scene_ref_asset_id=payload.scene_ref_asset_id,
),
task_queue=IMAGE_PIPELINE_CONTROL_TASK_QUEUE,
start_to_close_timeout=ACTIVITY_TIMEOUT,
start_to_close_timeout=activity_timeout_for_task_queue(IMAGE_PIPELINE_CONTROL_TASK_QUEUE),
retry_policy=ACTIVITY_RETRY_POLICY,
)
current_step = WorkflowStepName.TRYON
# tryon 产出基础换装图,后面的所有增强步骤都以它为起点。
tryon_result = await workflow.execute_activity(
run_tryon_activity,
StepActivityInput(
@@ -114,23 +116,37 @@ class MidEndPipelineWorkflow:
garment_asset_id=payload.garment_asset_id,
),
task_queue=IMAGE_PIPELINE_IMAGE_GEN_TASK_QUEUE,
start_to_close_timeout=ACTIVITY_TIMEOUT,
start_to_close_timeout=activity_timeout_for_task_queue(IMAGE_PIPELINE_IMAGE_GEN_TASK_QUEUE),
retry_policy=ACTIVITY_RETRY_POLICY,
)
current_step = WorkflowStepName.SCENE
scene_result = await self._run_scene(payload, tryon_result.asset_id)
scene_source_asset_id = tryon_result.asset_id
if payload.scene_ref_asset_id is not None:
current_step = WorkflowStepName.SCENE
# scene 仍然是可选步骤。
# 如果存在场景图,就先完成换背景,再把结果交给 texture / fusion
# 如果没有场景图,就直接把 tryon 结果当成“场景基底”继续流程。
scene_result = await self._run_scene(payload, tryon_result.asset_id)
scene_source_asset_id = scene_result.asset_id
else:
scene_result = tryon_result
current_step = WorkflowStepName.TEXTURE
texture_result = await self._run_texture(payload, scene_result.asset_id)
# texture 是复杂流独有步骤。
# 它通常负责衣物纹理、材质细节或局部增强,输入是当前场景基底图。
texture_result = await self._run_texture(payload, scene_source_asset_id)
current_step = WorkflowStepName.FACE
# face 也是复杂流独有步骤。
# 它在 texture 结果基础上做脸部修复/增强,产出供 fusion 合成的人像版本。
face_result = await self._run_face(payload, texture_result.asset_id)
current_step = WorkflowStepName.FUSION
fusion_result = await self._run_fusion(payload, scene_result.asset_id, face_result.asset_id)
# fusion 负责把“场景基底”和“脸部增强结果”合成成候选成品。
fusion_result = await self._run_fusion(payload, scene_source_asset_id, face_result.asset_id)
current_step = WorkflowStepName.QC
# QC 和低端流程复用同一套 activity但这里检查的是 fusion 产物。
qc_result = await self._run_qc(payload, fusion_result.asset_id)
if not qc_result.passed:
await self._mark_failed(payload, current_step, qc_result.message)
@@ -144,6 +160,8 @@ class MidEndPipelineWorkflow:
current_step = WorkflowStepName.REVIEW
# 这里通过 activity 把数据库里的订单状态更新成 waiting_review
# 同时创建 review_task供 API 查询待审核列表。
# review 是复杂流真正区别于低端流程的核心:
# 在 export 前插入人工决策点,并支持按意见回流重跑。
await workflow.execute_activity(
mark_waiting_for_review_activity,
ReviewWaitActivityInput(
@@ -152,7 +170,7 @@ class MidEndPipelineWorkflow:
candidate_asset_ids=qc_result.candidate_asset_ids,
),
task_queue=IMAGE_PIPELINE_CONTROL_TASK_QUEUE,
start_to_close_timeout=ACTIVITY_TIMEOUT,
start_to_close_timeout=activity_timeout_for_task_queue(IMAGE_PIPELINE_CONTROL_TASK_QUEUE),
retry_policy=ACTIVITY_RETRY_POLICY,
)
@@ -172,7 +190,7 @@ class MidEndPipelineWorkflow:
comment=review_payload.comment,
),
task_queue=IMAGE_PIPELINE_CONTROL_TASK_QUEUE,
start_to_close_timeout=ACTIVITY_TIMEOUT,
start_to_close_timeout=activity_timeout_for_task_queue(IMAGE_PIPELINE_CONTROL_TASK_QUEUE),
retry_policy=ACTIVITY_RETRY_POLICY,
)
@@ -180,6 +198,7 @@ class MidEndPipelineWorkflow:
current_step = WorkflowStepName.EXPORT
# 如果审核人显式选了资产,就导出该资产;
# 否则默认导出 QC 候选资产。
# export 本身仍然和低端流程是同一个 activity。
export_source_id = review_payload.selected_asset_id
if export_source_id is None:
export_source_id = (qc_result.candidate_asset_ids or [fusion_result.asset_id])[0]
@@ -192,7 +211,7 @@ class MidEndPipelineWorkflow:
source_asset_id=export_source_id,
),
task_queue=IMAGE_PIPELINE_EXPORT_TASK_QUEUE,
start_to_close_timeout=ACTIVITY_TIMEOUT,
start_to_close_timeout=activity_timeout_for_task_queue(IMAGE_PIPELINE_EXPORT_TASK_QUEUE),
retry_policy=ACTIVITY_RETRY_POLICY,
)
return {
@@ -208,22 +227,30 @@ class MidEndPipelineWorkflow:
# rerun 的核心思想是:
# 把指定节点后的链路重新跑一遍,然后再次进入 QC 和 waiting_review。
if review_payload.decision == ReviewDecision.RERUN_SCENE:
current_step = WorkflowStepName.SCENE
scene_result = await self._run_scene(payload, tryon_result.asset_id)
# 从 scene 开始重跑意味着 scene / texture / face / fusion 全部重算。
if payload.scene_ref_asset_id is not None:
current_step = WorkflowStepName.SCENE
scene_result = await self._run_scene(payload, tryon_result.asset_id)
scene_source_asset_id = scene_result.asset_id
else:
scene_result = tryon_result
scene_source_asset_id = tryon_result.asset_id
current_step = WorkflowStepName.TEXTURE
texture_result = await self._run_texture(payload, scene_result.asset_id)
texture_result = await self._run_texture(payload, scene_source_asset_id)
current_step = WorkflowStepName.FACE
face_result = await self._run_face(payload, texture_result.asset_id)
current_step = WorkflowStepName.FUSION
fusion_result = await self._run_fusion(payload, scene_result.asset_id, face_result.asset_id)
fusion_result = await self._run_fusion(payload, scene_source_asset_id, face_result.asset_id)
elif review_payload.decision == ReviewDecision.RERUN_FACE:
# 从 face 开始重跑时scene / texture 结果保持不变。
current_step = WorkflowStepName.FACE
face_result = await self._run_face(payload, texture_result.asset_id)
current_step = WorkflowStepName.FUSION
fusion_result = await self._run_fusion(payload, scene_result.asset_id, face_result.asset_id)
fusion_result = await self._run_fusion(payload, scene_source_asset_id, face_result.asset_id)
elif review_payload.decision == ReviewDecision.RERUN_FUSION:
# 从 fusion 重跑是最小范围回流,只重做最终合成。
current_step = WorkflowStepName.FUSION
fusion_result = await self._run_fusion(payload, scene_result.asset_id, face_result.asset_id)
fusion_result = await self._run_fusion(payload, scene_source_asset_id, face_result.asset_id)
current_step = WorkflowStepName.QC
qc_result = await self._run_qc(payload, fusion_result.asset_id)
@@ -264,7 +291,7 @@ class MidEndPipelineWorkflow:
scene_ref_asset_id=payload.scene_ref_asset_id,
),
task_queue=IMAGE_PIPELINE_IMAGE_GEN_TASK_QUEUE,
start_to_close_timeout=ACTIVITY_TIMEOUT,
start_to_close_timeout=activity_timeout_for_task_queue(IMAGE_PIPELINE_IMAGE_GEN_TASK_QUEUE),
retry_policy=ACTIVITY_RETRY_POLICY,
)
@@ -280,7 +307,7 @@ class MidEndPipelineWorkflow:
source_asset_id=source_asset_id,
),
task_queue=IMAGE_PIPELINE_POST_PROCESS_TASK_QUEUE,
start_to_close_timeout=ACTIVITY_TIMEOUT,
start_to_close_timeout=activity_timeout_for_task_queue(IMAGE_PIPELINE_POST_PROCESS_TASK_QUEUE),
retry_policy=ACTIVITY_RETRY_POLICY,
)
@@ -296,7 +323,7 @@ class MidEndPipelineWorkflow:
source_asset_id=source_asset_id,
),
task_queue=IMAGE_PIPELINE_POST_PROCESS_TASK_QUEUE,
start_to_close_timeout=ACTIVITY_TIMEOUT,
start_to_close_timeout=activity_timeout_for_task_queue(IMAGE_PIPELINE_POST_PROCESS_TASK_QUEUE),
retry_policy=ACTIVITY_RETRY_POLICY,
)
@@ -318,7 +345,7 @@ class MidEndPipelineWorkflow:
selected_asset_id=face_asset_id,
),
task_queue=IMAGE_PIPELINE_POST_PROCESS_TASK_QUEUE,
start_to_close_timeout=ACTIVITY_TIMEOUT,
start_to_close_timeout=activity_timeout_for_task_queue(IMAGE_PIPELINE_POST_PROCESS_TASK_QUEUE),
retry_policy=ACTIVITY_RETRY_POLICY,
)
@@ -334,7 +361,7 @@ class MidEndPipelineWorkflow:
source_asset_id=source_asset_id,
),
task_queue=IMAGE_PIPELINE_QC_TASK_QUEUE,
start_to_close_timeout=ACTIVITY_TIMEOUT,
start_to_close_timeout=activity_timeout_for_task_queue(IMAGE_PIPELINE_QC_TASK_QUEUE),
retry_policy=ACTIVITY_RETRY_POLICY,
)
@@ -355,6 +382,6 @@ class MidEndPipelineWorkflow:
message=message,
),
task_queue=IMAGE_PIPELINE_CONTROL_TASK_QUEUE,
start_to_close_timeout=ACTIVITY_TIMEOUT,
start_to_close_timeout=DEFAULT_ACTIVITY_TIMEOUT,
retry_policy=ACTIVITY_RETRY_POLICY,
)

View File

@@ -0,0 +1,24 @@
"""Shared activity timeout policy for Temporal workflows."""
from datetime import timedelta
from app.infra.temporal.task_queues import (
IMAGE_PIPELINE_IMAGE_GEN_TASK_QUEUE,
IMAGE_PIPELINE_POST_PROCESS_TASK_QUEUE,
)
DEFAULT_ACTIVITY_TIMEOUT = timedelta(seconds=30)
LONG_RUNNING_ACTIVITY_TIMEOUT = timedelta(minutes=5)
LONG_RUNNING_TASK_QUEUES = {
IMAGE_PIPELINE_IMAGE_GEN_TASK_QUEUE,
IMAGE_PIPELINE_POST_PROCESS_TASK_QUEUE,
}
def activity_timeout_for_task_queue(task_queue: str) -> timedelta:
"""Return the timeout that matches the workload behind the given task queue."""
if task_queue in LONG_RUNNING_TASK_QUEUES:
return LONG_RUNNING_ACTIVITY_TIMEOUT
return DEFAULT_ACTIVITY_TIMEOUT

View File

@@ -32,14 +32,22 @@ class PipelineWorkflowInput:
这是一张订单进入 workflow 时携带的最小上下文。
"""
# 订单主键。整个 workflow 的所有步骤都会围绕这张订单落库、查状态。
order_id: int
# workflow_run 主键。用于把每一步 step 记录关联到同一次运行。
workflow_run_id: int
# 客户层级。决定业务侧订单归属,也会影响后续统计和展示。
customer_level: CustomerLevel
# 服务模式。这里直接决定启动简单流程还是复杂流程。
service_mode: ServiceMode
# 模特资源 ID。prepare_model 会基于它准备人物素材。
model_id: int
pose_id: int
# 服装资源/资产 ID。tryon 会把它作为换装输入。
garment_asset_id: int
scene_ref_asset_id: int
# 场景参考图 ID。可为空为空时跳过 scene 步骤。
scene_ref_asset_id: int | None = None
# 模特姿势 ID。当前 MVP 已经放宽为可空,后续需要姿势控制时再启用。
pose_id: int | None = None
def __post_init__(self) -> None:
"""在反序列化后把枚举字段修正回来。"""
@@ -59,15 +67,25 @@ class StepActivityInput:
- 上一步产出的 asset_id
"""
# 当前步骤属于哪张订单。
order_id: int
# 当前步骤属于哪次 workflow 运行。
workflow_run_id: int
# 当前执行的步骤名,例如 tryon / qc / export。
step_name: WorkflowStepName
# 模特资源 ID。只有 prepare_model 等少数步骤会直接消费它。
model_id: int | None = None
# 姿势 ID。当前大多为透传预留字段。
pose_id: int | None = None
# 服装资源/资产 ID。tryon 是主要消费者。
garment_asset_id: int | None = None
# 场景参考图 ID。scene 步骤会用它完成换背景。
scene_ref_asset_id: int | None = None
# 上一步产出的资产 ID。绝大多数步骤都靠它串起资产链路。
source_asset_id: int | None = None
# 人工选中的资产 ID。review / fusion / export 等需要显式选图时使用。
selected_asset_id: int | None = None
# 额外扩展参数。给特殊步骤塞一些不值得单独建字段的临时上下文。
metadata: dict[str, Any] = field(default_factory=dict)
def __post_init__(self) -> None:
@@ -80,14 +98,23 @@ class StepActivityInput:
class MockActivityResult:
"""通用 mock activity 返回结构。"""
# 这是哪一步返回的结果,方便 workflow 和落库代码识别来源。
step_name: WorkflowStepName
# activity 是否执行成功。一般表示“代码执行成功”,不等于业务一定通过。
success: bool
# 这一步新生成的资产 ID如果步骤不产出资产可以为空。
asset_id: int | None
# 产出资产的访问地址;当前 mock 阶段通常是 mock:// URI。
uri: str | None
# 模型分数/质量分数之类的数值结果,非所有步骤都有。
score: float | None = None
# 业务是否通过。典型场景是 QCactivity 成功执行,但 passed 可能为 False。
passed: bool | None = None
# 给 workflow 或 API 展示的简短结果消息。
message: str = "mock success"
# 候选资产列表。主要给 QC / review 这类“多候选图”步骤使用。
candidate_asset_ids: list[int] = field(default_factory=list)
# 补充元数据。用于把步骤内部的一些上下文回传给调用方。
metadata: dict[str, Any] = field(default_factory=dict)
def __post_init__(self) -> None:
@@ -100,9 +127,13 @@ class MockActivityResult:
class ReviewSignalPayload:
"""API 发给中端 workflow 的审核 signal 载荷。"""
# 审核动作:通过 / 拒绝 / 从某一步重跑。
decision: ReviewDecision
# 操作审核的用户 ID便于审计和任务归属。
reviewer_id: int
# 审核人最终选中的候选资产 IDapprove 时最常见。
selected_asset_id: int | None = None
# 审核备注,例如“脸部不自然,重跑融合”。
comment: str | None = None
def __post_init__(self) -> None:
@@ -115,9 +146,13 @@ class ReviewSignalPayload:
class ReviewWaitActivityInput:
"""把流程切到 waiting_review 时传给 activity 的输入。"""
# 当前审核任务属于哪张订单。
order_id: int
# 当前审核任务属于哪次 workflow 运行。
workflow_run_id: int
# 提交给审核端看的候选资产列表。
candidate_asset_ids: list[int] = field(default_factory=list)
# 进入审核态时附带的说明文案,当前大多留空。
comment: str | None = None
@@ -125,11 +160,17 @@ class ReviewWaitActivityInput:
class ReviewResolutionActivityInput:
"""审核结果到达后,用于结束 waiting_review 的输入。"""
# 当前审核结果属于哪张订单。
order_id: int
# 当前审核结果属于哪次 workflow 运行。
workflow_run_id: int
# 审核最终决策。
decision: ReviewDecision
# 处理这次审核的审核人 ID。
reviewer_id: int
# 如果审核人明确挑了一张图,这里记录最终选择的资产 ID。
selected_asset_id: int | None = None
# 审核备注或重跑原因。
comment: str | None = None
def __post_init__(self) -> None:
@@ -142,10 +183,15 @@ class ReviewResolutionActivityInput:
class WorkflowFailureActivityInput:
"""流程失败收尾 activity 的输入。"""
# 失败发生在哪张订单。
order_id: int
# 失败发生在哪次 workflow 运行。
workflow_run_id: int
# 失败停留在哪个步骤,用于更新 workflow_run.current_step。
current_step: WorkflowStepName
# 失败原因文本,通常来自异常或 QC 驳回消息。
message: str
# 要写回数据库的最终状态,默认就是 failed。
status: OrderStatus = OrderStatus.FAILED
def __post_init__(self) -> None:

View File

@@ -11,6 +11,7 @@ requires-python = ">=3.11"
dependencies = [
"aiosqlite>=0.20,<1.0",
"alembic>=1.13,<2.0",
"boto3>=1.35,<2.0",
"fastapi>=0.115,<1.0",
"greenlet>=3.1,<4.0",
"httpx>=0.27,<1.0",

View File

@@ -20,6 +20,13 @@ async def api_runtime(tmp_path, monkeypatch):
db_path = tmp_path / "test.db"
monkeypatch.setenv("DATABASE_URL", f"sqlite+aiosqlite:///{db_path.as_posix()}")
monkeypatch.setenv("AUTO_CREATE_TABLES", "true")
monkeypatch.setenv("S3_ACCESS_KEY", "test-access")
monkeypatch.setenv("S3_SECRET_KEY", "test-secret")
monkeypatch.setenv("S3_BUCKET", "test-bucket")
monkeypatch.setenv("S3_REGION", "ap-southeast-1")
monkeypatch.setenv("S3_ENDPOINT", "https://s3.example.com")
monkeypatch.setenv("S3_CNAME", "images.example.com")
monkeypatch.setenv("IMAGE_GENERATION_PROVIDER", "mock")
get_settings.cache_clear()
await dispose_database()
@@ -41,4 +48,3 @@ async def api_runtime(tmp_path, monkeypatch):
set_temporal_client(None)
await dispose_database()
get_settings.cache_clear()

View File

@@ -39,15 +39,15 @@ async def wait_for_step_count(client, order_id: int, step_name: str, minimum_cou
async def create_mid_end_order(client):
"""Create a standard semi-pro order for review-path tests."""
resources = await create_workflow_resources(client, include_scene=True)
response = await client.post(
"/api/v1/orders",
json={
"customer_level": "mid",
"service_mode": "semi_pro",
"model_id": 101,
"pose_id": 3,
"garment_asset_id": 9001,
"scene_ref_asset_id": 8001,
"model_id": resources["model"]["id"],
"garment_asset_id": resources["garment"]["id"],
"scene_ref_asset_id": resources["scene"]["id"],
},
)
@@ -55,6 +55,115 @@ async def create_mid_end_order(client):
return response.json()
async def create_library_resource(client, payload: dict) -> dict:
    """Persist one resource-library item via the API and return the created record."""
    created = await client.post("/api/v1/library/resources", json=payload)
    assert created.status_code == 201
    return created.json()
def _library_file_payloads(prefix: str, original_extra: dict | None = None) -> list[dict]:
    """Build the standard original + thumbnail file entries for a library resource.

    `prefix` is the storage-key folder (e.g. "library/models/ava").
    `original_extra`, when given, merges extra fields (e.g. width/height)
    into the original-file entry.
    """
    original = {
        "file_role": "original",
        "storage_key": f"{prefix}/original.png",
        "public_url": f"https://images.example.com/{prefix}/original.png",
        "mime_type": "image/png",
        "size_bytes": 1024,
        "sort_order": 0,
    }
    original.update(original_extra or {})
    thumbnail = {
        "file_role": "thumbnail",
        "storage_key": f"{prefix}/thumb.png",
        "public_url": f"https://images.example.com/{prefix}/thumb.png",
        "mime_type": "image/png",
        "size_bytes": 256,
        "sort_order": 0,
    }
    return [original, thumbnail]


async def create_workflow_resources(client, *, include_scene: bool) -> dict[str, dict]:
    """Create a minimal set of real library resources for workflow tests.

    Returns a mapping with "model" and "garment" entries, plus "scene" when
    `include_scene` is true.
    """
    model_resource = await create_library_resource(
        client,
        {
            "resource_type": "model",
            "name": "Ava Studio",
            "description": "棚拍女模特",
            "tags": ["女装", "棚拍"],
            "gender": "female",
            "age_group": "adult",
            # Only the model original carries dimensions in this fixture set.
            "files": _library_file_payloads(
                "library/models/ava",
                {"width": 1200, "height": 1600},
            ),
        },
    )
    garment_resource = await create_library_resource(
        client,
        {
            "resource_type": "garment",
            "name": "Cream Dress",
            "description": "米白色连衣裙",
            "tags": ["女装"],
            "category": "dress",
            "files": _library_file_payloads("library/garments/cream-dress"),
        },
    )
    resources = {
        "model": model_resource,
        "garment": garment_resource,
    }
    if include_scene:
        resources["scene"] = await create_library_resource(
            client,
            {
                "resource_type": "scene",
                "name": "Loft Window",
                "description": "暖调室内场景",
                "tags": ["室内"],
                "environment": "indoor",
                "files": _library_file_payloads("library/scenes/loft"),
            },
        )
    return resources
@pytest.mark.asyncio
async def test_healthcheck(api_runtime):
"""The health endpoint should always respond successfully."""
@@ -71,15 +180,15 @@ async def test_low_end_order_completes(api_runtime):
"""Low-end orders should run through the full automated pipeline."""
client, env = api_runtime
resources = await create_workflow_resources(client, include_scene=True)
response = await client.post(
"/api/v1/orders",
json={
"customer_level": "low",
"service_mode": "auto_basic",
"model_id": 101,
"pose_id": 3,
"garment_asset_id": 9001,
"scene_ref_asset_id": 8001,
"model_id": resources["model"]["id"],
"garment_asset_id": resources["garment"]["id"],
"scene_ref_asset_id": resources["scene"]["id"],
},
)
@@ -105,6 +214,160 @@ async def test_low_end_order_completes(api_runtime):
assert workflow_response.json()["workflow_status"] == "succeeded"
@pytest.mark.asyncio
async def test_prepare_model_step_uses_library_original_file(api_runtime):
"""prepare_model should use the model resource's original file rather than a mock URI."""
def expected_input_snapshot(resource: dict) -> dict:
original_file = next(file for file in resource["files"] if file["file_role"] == "original")
snapshot = {
"resource_id": resource["id"],
"resource_name": resource["name"],
"original_file_id": original_file["id"],
"original_url": resource["original_url"],
"mime_type": original_file["mime_type"],
"width": original_file["width"],
"height": original_file["height"],
}
return {key: value for key, value in snapshot.items() if value is not None}
client, env = api_runtime
resources = await create_workflow_resources(client, include_scene=True)
model_resource = resources["model"]
garment_resource = resources["garment"]
scene_resource = resources["scene"]
response = await client.post(
"/api/v1/orders",
json={
"customer_level": "low",
"service_mode": "auto_basic",
"model_id": model_resource["id"],
"garment_asset_id": garment_resource["id"],
"scene_ref_asset_id": scene_resource["id"],
},
)
assert response.status_code == 201
payload = response.json()
handle = env.client.get_workflow_handle(payload["workflow_id"])
result = await handle.result()
assert result["status"] == "succeeded"
assets_response = await client.get(f"/api/v1/orders/{payload['order_id']}/assets")
assert assets_response.status_code == 200
prepared_asset = next(
asset for asset in assets_response.json() if asset["asset_type"] == "prepared_model"
)
assert prepared_asset["uri"] == model_resource["original_url"]
assert prepared_asset["metadata_json"]["library_resource_id"] == model_resource["id"]
assert prepared_asset["metadata_json"]["library_original_file_id"] == next(
file["id"] for file in model_resource["files"] if file["file_role"] == "original"
)
assert prepared_asset["metadata_json"]["library_original_url"] == model_resource["original_url"]
assert prepared_asset["metadata_json"]["model_input"] == expected_input_snapshot(model_resource)
assert prepared_asset["metadata_json"]["garment_input"] == expected_input_snapshot(garment_resource)
assert prepared_asset["metadata_json"]["scene_input"] == expected_input_snapshot(scene_resource)
@pytest.mark.asyncio
async def test_low_end_order_completes_without_scene(api_runtime):
    """Low-end orders should still complete when scene input is omitted."""
    client, env = api_runtime
    resources = await create_workflow_resources(client, include_scene=False)

    create_response = await client.post(
        "/api/v1/orders",
        json={
            "customer_level": "low",
            "service_mode": "auto_basic",
            "model_id": resources["model"]["id"],
            "garment_asset_id": resources["garment"]["id"],
        },
    )
    assert create_response.status_code == 201
    order_payload = create_response.json()

    # Without a scene reference the workflow skips the scene step and must
    # still drive the order to a successful terminal state.
    handle = env.client.get_workflow_handle(order_payload["workflow_id"])
    result = await handle.result()
    assert result["status"] == "succeeded"

    order_response = await client.get(f"/api/v1/orders/{order_payload['order_id']}")
    assert order_response.status_code == 200
    order = order_response.json()
    assert order["scene_ref_asset_id"] is None
    assert order["status"] == "succeeded"
@pytest.mark.asyncio
async def test_low_end_order_reuses_tryon_uri_for_qc_and_final_without_scene(api_runtime):
"""Without scene input, qc/export should keep the real try-on image URI."""
client, env = api_runtime
resources = await create_workflow_resources(client, include_scene=False)
response = await client.post(
"/api/v1/orders",
json={
"customer_level": "low",
"service_mode": "auto_basic",
"model_id": resources["model"]["id"],
"garment_asset_id": resources["garment"]["id"],
},
)
assert response.status_code == 201
payload = response.json()
handle = env.client.get_workflow_handle(payload["workflow_id"])
result = await handle.result()
assert result["status"] == "succeeded"
assets_response = await client.get(f"/api/v1/orders/{payload['order_id']}/assets")
assert assets_response.status_code == 200
assets = assets_response.json()
tryon_asset = next(asset for asset in assets if asset["asset_type"] == "tryon")
qc_asset = next(asset for asset in assets if asset["asset_type"] == "qc_candidate")
final_asset = next(asset for asset in assets if asset["asset_type"] == "final")
assert qc_asset["uri"] == tryon_asset["uri"]
assert final_asset["uri"] == tryon_asset["uri"]
assert final_asset["metadata_json"]["source_asset_id"] == qc_asset["id"]
@pytest.mark.asyncio
async def test_mid_end_order_waits_review_then_approves_without_scene(api_runtime):
"""Mid-end orders should still reach review and approve when scene input is omitted."""
client, env = api_runtime
resources = await create_workflow_resources(client, include_scene=False)
response = await client.post(
"/api/v1/orders",
json={
"customer_level": "mid",
"service_mode": "semi_pro",
"model_id": resources["model"]["id"],
"garment_asset_id": resources["garment"]["id"],
},
)
assert response.status_code == 201
payload = response.json()
await wait_for_workflow_status(client, payload["order_id"], "waiting_review")
review_response = await client.post(
f"/api/v1/reviews/{payload['order_id']}/submit",
json={"decision": "approve", "reviewer_id": 77, "comment": "通过"},
)
assert review_response.status_code == 200
handle = env.client.get_workflow_handle(payload["workflow_id"])
result = await handle.result()
assert result["status"] == "succeeded"
@pytest.mark.asyncio
async def test_mid_end_order_waits_review_then_approves(api_runtime):
"""Mid-end orders should pause for review and continue after approval."""
@@ -133,6 +396,162 @@ async def test_mid_end_order_waits_review_then_approves(api_runtime):
assert order_response.json()["status"] == "succeeded"
@pytest.mark.asyncio
async def test_library_upload_presign_returns_direct_upload_metadata(api_runtime):
"""Library upload presign should return S3 direct-upload metadata and a public URL."""
client, _ = api_runtime
response = await client.post(
"/api/v1/library/uploads/presign",
json={
"resource_type": "model",
"file_role": "original",
"file_name": "ava.png",
"content_type": "image/png",
},
)
assert response.status_code == 200
payload = response.json()
assert payload["method"] == "PUT"
assert payload["storage_key"].startswith("library/models/")
assert payload["public_url"].startswith("https://")
assert "library/models/" in payload["upload_url"]
@pytest.mark.asyncio
async def test_library_resource_can_be_created_and_listed(api_runtime):
"""Creating a library resource should persist original, thumbnail, and gallery files."""
client, _ = api_runtime
create_response = await client.post(
"/api/v1/library/resources",
json={
"resource_type": "model",
"name": "Ava Studio",
"description": "棚拍女模特",
"tags": ["女装", "棚拍"],
"gender": "female",
"age_group": "adult",
"files": [
{
"file_role": "original",
"storage_key": "library/models/ava/original.png",
"public_url": "https://images.marcusd.me/library/models/ava/original.png",
"mime_type": "image/png",
"size_bytes": 1024,
"sort_order": 0,
},
{
"file_role": "thumbnail",
"storage_key": "library/models/ava/thumb.png",
"public_url": "https://images.marcusd.me/library/models/ava/thumb.png",
"mime_type": "image/png",
"size_bytes": 256,
"sort_order": 0,
},
{
"file_role": "gallery",
"storage_key": "library/models/ava/gallery-1.png",
"public_url": "https://images.marcusd.me/library/models/ava/gallery-1.png",
"mime_type": "image/png",
"size_bytes": 2048,
"sort_order": 1,
},
],
},
)
assert create_response.status_code == 201
created = create_response.json()
assert created["resource_type"] == "model"
assert created["pose_id"] is None
assert created["cover_url"] == "https://images.marcusd.me/library/models/ava/thumb.png"
assert created["original_url"] == "https://images.marcusd.me/library/models/ava/original.png"
assert len(created["files"]) == 3
list_response = await client.get("/api/v1/library/resources", params={"resource_type": "model"})
assert list_response.status_code == 200
listing = list_response.json()
assert listing["total"] == 1
assert listing["items"][0]["name"] == "Ava Studio"
assert listing["items"][0]["gender"] == "female"
assert listing["items"][0]["pose_id"] is None
@pytest.mark.asyncio
async def test_library_resource_list_supports_type_specific_filters(api_runtime):
"""Library listing should support resource-type-specific filter fields."""
client, _ = api_runtime
payloads = [
{
"resource_type": "scene",
"name": "Loft Window",
"description": "暖调室内场景",
"tags": ["室内"],
"environment": "indoor",
"files": [
{
"file_role": "original",
"storage_key": "library/scenes/loft/original.png",
"public_url": "https://images.marcusd.me/library/scenes/loft/original.png",
"mime_type": "image/png",
"size_bytes": 1024,
"sort_order": 0,
},
{
"file_role": "thumbnail",
"storage_key": "library/scenes/loft/thumb.png",
"public_url": "https://images.marcusd.me/library/scenes/loft/thumb.png",
"mime_type": "image/png",
"size_bytes": 256,
"sort_order": 0,
},
],
},
{
"resource_type": "scene",
"name": "Garden Walk",
"description": "自然光室外场景",
"tags": ["室外"],
"environment": "outdoor",
"files": [
{
"file_role": "original",
"storage_key": "library/scenes/garden/original.png",
"public_url": "https://images.marcusd.me/library/scenes/garden/original.png",
"mime_type": "image/png",
"size_bytes": 1024,
"sort_order": 0,
},
{
"file_role": "thumbnail",
"storage_key": "library/scenes/garden/thumb.png",
"public_url": "https://images.marcusd.me/library/scenes/garden/thumb.png",
"mime_type": "image/png",
"size_bytes": 256,
"sort_order": 0,
},
],
},
]
for payload in payloads:
response = await client.post("/api/v1/library/resources", json=payload)
assert response.status_code == 201
indoor_response = await client.get(
"/api/v1/library/resources",
params={"resource_type": "scene", "environment": "indoor"},
)
assert indoor_response.status_code == 200
indoor_payload = indoor_response.json()
assert indoor_payload["total"] == 1
assert indoor_payload["items"][0]["name"] == "Loft Window"
@pytest.mark.asyncio
@pytest.mark.parametrize(
("decision", "expected_step"),
@@ -312,15 +731,16 @@ async def test_orders_list_returns_recent_orders_with_revision_summary(api_runti
client, env = api_runtime
low_resources = await create_workflow_resources(client, include_scene=True)
low_order = await client.post(
"/api/v1/orders",
json={
"customer_level": "low",
"service_mode": "auto_basic",
"model_id": 201,
"model_id": low_resources["model"]["id"],
"pose_id": 11,
"garment_asset_id": 9101,
"scene_ref_asset_id": 8101,
"garment_asset_id": low_resources["garment"]["id"],
"scene_ref_asset_id": low_resources["scene"]["id"],
},
)
assert low_order.status_code == 201
@@ -394,15 +814,16 @@ async def test_workflows_list_returns_recent_runs_with_failure_count(api_runtime
client, env = api_runtime
low_resources = await create_workflow_resources(client, include_scene=True)
low_order = await client.post(
"/api/v1/orders",
json={
"customer_level": "low",
"service_mode": "auto_basic",
"model_id": 301,
"model_id": low_resources["model"]["id"],
"pose_id": 21,
"garment_asset_id": 9201,
"scene_ref_asset_id": 8201,
"garment_asset_id": low_resources["garment"]["id"],
"scene_ref_asset_id": low_resources["scene"]["id"],
},
)
assert low_order.status_code == 201

View File

@@ -0,0 +1,150 @@
"""Tests for the Gemini image provider."""
from __future__ import annotations
import base64
import httpx
import pytest
from app.infra.image_generation.base import SourceImage
from app.infra.image_generation.gemini_provider import GeminiImageProvider
@pytest.mark.asyncio
async def test_gemini_provider_uses_configured_base_url_and_model():
    """The provider should call the configured endpoint and decode the returned image bytes.

    Asserts that the model name is routed into the URL path (never the JSON
    body), the API key travels in the ``x-goog-api-key`` header, and both
    source images are inlined as base64 in the request payload.
    """
    expected_url = "https://gemini.example.com/custom/models/gemini-test-model:generateContent"
    png_bytes = b"generated-png"

    async def handler(request: httpx.Request) -> httpx.Response:
        # The model belongs in the URL path, so it must not leak into the body.
        assert str(request.url) == expected_url
        assert request.headers["x-goog-api-key"] == "test-key"
        payload = request.read().decode("utf-8")
        assert "gemini-test-model" not in payload
        assert "Dress the person in the garment" in payload
        # base64("person-bytes") and base64("garment-bytes") must be inlined.
        assert "cGVyc29uLWJ5dGVz" in payload
        assert "Z2FybWVudC1ieXRlcw==" in payload
        return httpx.Response(
            200,
            json={
                "candidates": [
                    {
                        "content": {
                            "parts": [
                                {"text": "done"},
                                {
                                    "inlineData": {
                                        "mimeType": "image/png",
                                        "data": base64.b64encode(png_bytes).decode("utf-8"),
                                    }
                                },
                            ]
                        }
                    }
                ]
            },
        )

    # Keep our own handle on the transport client so teardown does not need
    # to reach into the provider's private ``_http_client`` attribute.
    http_client = httpx.AsyncClient(transport=httpx.MockTransport(handler))
    provider = GeminiImageProvider(
        api_key="test-key",
        base_url="https://gemini.example.com/custom",
        model="gemini-test-model",
        timeout_seconds=5,
        http_client=http_client,
    )
    try:
        result = await provider.generate_tryon_image(
            prompt="Dress the person in the garment",
            person_image=SourceImage(
                url="https://images.example.com/person.png",
                mime_type="image/png",
                data=b"person-bytes",
            ),
            garment_image=SourceImage(
                url="https://images.example.com/garment.png",
                mime_type="image/png",
                data=b"garment-bytes",
            ),
        )
    finally:
        await http_client.aclose()

    assert result.image_bytes == png_bytes
    assert result.mime_type == "image/png"
    assert result.provider == "gemini"
    assert result.model == "gemini-test-model"
@pytest.mark.asyncio
async def test_gemini_provider_retries_remote_disconnect_and_uses_request_timeout():
    """The provider should retry transient transport failures and pass timeout per request.

    A hand-rolled client double (instead of MockTransport) lets the test count
    POST attempts and record the ``timeout`` kwarg each attempt received.
    """
    expected_url = "https://gemini.example.com/custom/models/gemini-test-model:generateContent"
    jpeg_bytes = b"generated-jpeg"
    class FakeClient:
        # Stand-in for httpx.AsyncClient exposing only what the provider calls.
        def __init__(self) -> None:
            self.attempts = 0
            self.timeouts: list[int | None] = []
        async def post(self, url: str, **kwargs) -> httpx.Response:
            self.attempts += 1
            # The provider is expected to forward timeout_seconds on every request.
            self.timeouts.append(kwargs.get("timeout"))
            assert url == expected_url
            if self.attempts == 1:
                # First attempt: simulate the server dropping the connection mid-flight.
                raise httpx.RemoteProtocolError("Server disconnected without sending a response.")
            return httpx.Response(
                200,
                json={
                    "candidates": [
                        {
                            "content": {
                                "parts": [
                                    {
                                        "inlineData": {
                                            "mimeType": "image/jpeg",
                                            "data": base64.b64encode(jpeg_bytes).decode("utf-8"),
                                        }
                                    }
                                ]
                            }
                        }
                    ]
                },
                request=httpx.Request("POST", url),
            )
        async def aclose(self) -> None:
            return None
    fake_client = FakeClient()
    provider = GeminiImageProvider(
        api_key="test-key",
        base_url="https://gemini.example.com/custom",
        model="gemini-test-model",
        timeout_seconds=300,
        http_client=fake_client,
    )
    result = await provider.generate_tryon_image(
        prompt="Dress the person in the garment",
        person_image=SourceImage(
            url="https://images.example.com/person.png",
            mime_type="image/png",
            data=b"person-bytes",
        ),
        garment_image=SourceImage(
            url="https://images.example.com/garment.png",
            mime_type="image/png",
            data=b"garment-bytes",
        ),
    )
    # One failed plus one successful attempt, each with the configured timeout.
    assert fake_client.attempts == 2
    assert fake_client.timeouts == [300, 300]
    assert result.image_bytes == jpeg_bytes
    assert result.mime_type == "image/jpeg"

View File

@@ -0,0 +1,91 @@
"""Tests for application-level image-generation orchestration."""
from __future__ import annotations
import httpx
import pytest
from app.application.services.image_generation_service import ImageGenerationService
from app.infra.image_generation.base import GeneratedImageResult, SourceImage
class FakeProvider:
    """Provider double: records each dispatch and hands back a canned result."""

    def __init__(self) -> None:
        self.calls: list[tuple[str, str, str]] = []

    def _record_and_build(
        self,
        kind: str,
        first_url: str,
        second_url: str,
        *,
        image_bytes: bytes,
        mime_type: str,
        prompt: str,
    ) -> GeneratedImageResult:
        # Remember how the service dispatched to us, then return a stub result.
        self.calls.append((kind, first_url, second_url))
        return GeneratedImageResult(
            image_bytes=image_bytes,
            mime_type=mime_type,
            provider="gemini",
            model="gemini-test",
            prompt=prompt,
        )

    async def generate_tryon_image(
        self,
        *,
        prompt: str,
        person_image: SourceImage,
        garment_image: SourceImage,
    ) -> GeneratedImageResult:
        return self._record_and_build(
            "tryon",
            person_image.url,
            garment_image.url,
            image_bytes=b"tryon",
            mime_type="image/png",
            prompt=prompt,
        )

    async def generate_scene_image(
        self,
        *,
        prompt: str,
        source_image: SourceImage,
        scene_image: SourceImage,
    ) -> GeneratedImageResult:
        return self._record_and_build(
            "scene",
            source_image.url,
            scene_image.url,
            image_bytes=b"scene",
            mime_type="image/jpeg",
            prompt=prompt,
        )
@pytest.mark.asyncio
async def test_image_generation_service_downloads_inputs_for_tryon_and_scene(monkeypatch):
    """The service should download both inputs and dispatch them to the matching provider method."""
    # Canned bodies keyed by URL; the handler serves them with the right content-type.
    responses = {
        "https://images.example.com/person.png": (b"person-bytes", "image/png"),
        "https://images.example.com/garment.png": (b"garment-bytes", "image/png"),
        "https://images.example.com/source.jpg": (b"source-bytes", "image/jpeg"),
        "https://images.example.com/scene.jpg": (b"scene-bytes", "image/jpeg"),
    }
    async def handler(request: httpx.Request) -> httpx.Response:
        body, mime = responses[str(request.url)]
        return httpx.Response(200, content=body, headers={"content-type": mime}, request=request)
    # Settings are cached; clear the cache so the patched GEMINI_API_KEY is picked up.
    monkeypatch.setenv("GEMINI_API_KEY", "test-key")
    from app.config.settings import get_settings
    get_settings.cache_clear()
    provider = FakeProvider()
    downloader = httpx.AsyncClient(transport=httpx.MockTransport(handler))
    service = ImageGenerationService(provider=provider, downloader=downloader)
    try:
        tryon = await service.generate_tryon_image(
            person_image_url="https://images.example.com/person.png",
            garment_image_url="https://images.example.com/garment.png",
        )
        scene = await service.generate_scene_image(
            source_image_url="https://images.example.com/source.jpg",
            scene_image_url="https://images.example.com/scene.jpg",
        )
    finally:
        await downloader.aclose()
        # Drop the patched settings so later tests see a clean cache.
        get_settings.cache_clear()
    # FakeProvider records one (kind, first_url, second_url) tuple per dispatch.
    assert provider.calls == [
        ("tryon", "https://images.example.com/person.png", "https://images.example.com/garment.png"),
        ("scene", "https://images.example.com/source.jpg", "https://images.example.com/scene.jpg"),
    ]
    assert tryon.image_bytes == b"tryon"
    assert scene.image_bytes == b"scene"

View File

@@ -0,0 +1,63 @@
import pytest
@pytest.mark.asyncio
async def test_library_resource_can_be_updated_and_archived(api_runtime):
    """A model resource can be created, patched, then soft-deleted (archived)."""
    client, _ = api_runtime
    create_payload = {
        "resource_type": "model",
        "name": "Ava Studio",
        "description": "棚拍女模特",
        "tags": ["女装", "棚拍"],
        "gender": "female",
        "age_group": "adult",
        "files": [
            {
                "file_role": "original",
                "storage_key": "library/models/ava/original.png",
                "public_url": "https://images.marcusd.me/library/models/ava/original.png",
                "mime_type": "image/png",
                "size_bytes": 1024,
                "sort_order": 0,
            },
            {
                "file_role": "thumbnail",
                "storage_key": "library/models/ava/thumb.png",
                "public_url": "https://images.marcusd.me/library/models/ava/thumb.png",
                "mime_type": "image/png",
                "size_bytes": 256,
                "sort_order": 0,
            },
        ],
    }
    created = await client.post("/api/v1/library/resources", json=create_payload)
    assert created.status_code == 201
    resource = created.json()
    patched = await client.patch(
        f"/api/v1/library/resources/{resource['id']}",
        json={
            "name": "Ava Studio Updated",
            "description": "新的描述",
            "tags": ["女装", "更新"],
        },
    )
    assert patched.status_code == 200
    body = patched.json()
    assert body["name"] == "Ava Studio Updated"
    assert body["description"] == "新的描述"
    assert body["tags"] == ["女装", "更新"]
    archived = await client.delete(f"/api/v1/library/resources/{resource['id']}")
    assert archived.status_code == 200
    assert archived.json()["id"] == resource["id"]
    # Archived resources should no longer appear in the active listing.
    listing = await client.get("/api/v1/library/resources", params={"resource_type": "model"})
    assert listing.status_code == 200
    assert listing.json()["items"] == []

15
tests/test_settings.py Normal file
View File

@@ -0,0 +1,15 @@
"""Tests for application settings resolution."""
from pathlib import Path
from app.config.settings import Settings
def test_settings_env_file_uses_backend_repo_path() -> None:
    """Settings should resolve .env relative to the backend repo, not the launch cwd."""
    configured_env_file = Settings.model_config.get("env_file")
    # An absolute Path named ".env" proves the location is repo-anchored.
    assert isinstance(configured_env_file, Path)
    assert configured_env_file.is_absolute()
    assert configured_env_file.name == ".env"

View File

@@ -0,0 +1,298 @@
"""Focused tests for the try-on activity implementation."""
from __future__ import annotations
from dataclasses import dataclass
import pytest
from app.domain.enums import AssetType, LibraryFileRole, LibraryResourceStatus, LibraryResourceType, OrderStatus, WorkflowStepName
from app.infra.db.models.asset import AssetORM
from app.infra.db.models.library_resource import LibraryResourceORM
from app.infra.db.models.library_resource_file import LibraryResourceFileORM
from app.infra.db.models.order import OrderORM
from app.infra.db.models.workflow_run import WorkflowRunORM
from app.infra.db.session import get_session_factory
from app.workers.activities import tryon_activities
from app.workers.workflows.types import StepActivityInput
@dataclass(slots=True)
class FakeGeneratedImage:
    """Minimal stand-in for the provider's generated-image result; carries only the fields the activities under test read."""

    image_bytes: bytes
    mime_type: str
    provider: str
    model: str
    prompt: str
class FakeImageGenerationService:
    """Image-generation double: verifies the input URLs and returns canned images."""

    async def generate_tryon_image(self, *, person_image_url: str, garment_image_url: str) -> FakeGeneratedImage:
        # The activity must pass the prepared-model and garment-original URLs.
        expected_person = "https://images.example.com/orders/1/prepared-model.png"
        expected_garment = "https://images.example.com/library/garments/cream-dress/original.png"
        assert person_image_url == expected_person
        assert garment_image_url == expected_garment
        return FakeGeneratedImage(
            image_bytes=b"fake-png-binary",
            mime_type="image/png",
            provider="gemini",
            model="gemini-test-image",
            prompt="test prompt",
        )

    async def generate_scene_image(self, *, source_image_url: str, scene_image_url: str) -> FakeGeneratedImage:
        # The activity must pass the try-on output and scene-original URLs.
        expected_source = "https://images.example.com/orders/1/tryon/generated.png"
        expected_scene = "https://images.example.com/library/scenes/studio/original.png"
        assert source_image_url == expected_source
        assert scene_image_url == expected_scene
        return FakeGeneratedImage(
            image_bytes=b"fake-scene-binary",
            mime_type="image/jpeg",
            provider="gemini",
            model="gemini-test-image",
            prompt="scene prompt",
        )
class FakeOrderArtifactStorageService:
    """Storage double for the try-on step: checks the upload, returns a fixed key/URL."""

    async def upload_generated_image(
        self,
        *,
        order_id: int,
        step_name: WorkflowStepName,
        image_bytes: bytes,
        mime_type: str,
    ) -> tuple[str, str]:
        # The try-on activity must upload exactly the fake PNG it was given.
        expected_call = (1, WorkflowStepName.TRYON, b"fake-png-binary", "image/png")
        assert (order_id, step_name, image_bytes, mime_type) == expected_call
        return (
            "orders/1/tryon/generated.png",
            "https://images.example.com/orders/1/tryon/generated.png",
        )
class FakeSceneArtifactStorageService:
    """Storage double for the scene step: checks the upload, returns a fixed key/URL."""

    async def upload_generated_image(
        self,
        *,
        order_id: int,
        step_name: WorkflowStepName,
        image_bytes: bytes,
        mime_type: str,
    ) -> tuple[str, str]:
        # The scene activity must upload exactly the fake JPEG it was given.
        expected_call = (1, WorkflowStepName.SCENE, b"fake-scene-binary", "image/jpeg")
        assert (order_id, step_name, image_bytes, mime_type) == expected_call
        return (
            "orders/1/scene/generated.jpg",
            "https://images.example.com/orders/1/scene/generated.jpg",
        )
@pytest.mark.asyncio
async def test_run_tryon_activity_persists_uploaded_asset_with_provider_metadata(api_runtime, monkeypatch):
    """Gemini-mode try-on should persist the uploaded output URL instead of a mock URI."""
    # Swap the real image-generation and storage services for in-test fakes.
    monkeypatch.setattr(
        tryon_activities,
        "get_image_generation_service",
        lambda: FakeImageGenerationService(),
    )
    monkeypatch.setattr(
        tryon_activities,
        "get_order_artifact_storage_service",
        lambda: FakeOrderArtifactStorageService(),
    )
    # Seed the database: an order, its workflow run, the prepared-model asset
    # the activity reads from, and a garment library resource with its files.
    async with get_session_factory()() as session:
        order = OrderORM(
            customer_level="low",
            service_mode="auto_basic",
            status=OrderStatus.CREATED,
            model_id=1,
            garment_asset_id=2,
        )
        session.add(order)
        await session.flush()
        workflow_run = WorkflowRunORM(
            order_id=order.id,
            workflow_id=f"order-{order.id}",
            workflow_type="LowEndPipelineWorkflow",
            status=OrderStatus.CREATED,
        )
        session.add(workflow_run)
        await session.flush()
        prepared_asset = AssetORM(
            order_id=order.id,
            asset_type=AssetType.PREPARED_MODEL,
            step_name=WorkflowStepName.PREPARE_MODEL,
            uri="https://images.example.com/orders/1/prepared-model.png",
            metadata_json={"library_resource_id": 1},
        )
        session.add(prepared_asset)
        garment_resource = LibraryResourceORM(
            resource_type=LibraryResourceType.GARMENT,
            name="Cream Dress",
            description="米白色连衣裙",
            tags=["女装"],
            status=LibraryResourceStatus.ACTIVE,
            category="dress",
        )
        session.add(garment_resource)
        await session.flush()
        garment_original = LibraryResourceFileORM(
            resource_id=garment_resource.id,
            file_role=LibraryFileRole.ORIGINAL,
            storage_key="library/garments/cream-dress/original.png",
            public_url="https://images.example.com/library/garments/cream-dress/original.png",
            bucket="test-bucket",
            mime_type="image/png",
            size_bytes=1024,
            sort_order=0,
        )
        garment_thumb = LibraryResourceFileORM(
            resource_id=garment_resource.id,
            file_role=LibraryFileRole.THUMBNAIL,
            storage_key="library/garments/cream-dress/thumb.png",
            public_url="https://images.example.com/library/garments/cream-dress/thumb.png",
            bucket="test-bucket",
            mime_type="image/png",
            size_bytes=256,
            sort_order=0,
        )
        session.add_all([garment_original, garment_thumb])
        await session.flush()
        # Wire the resource's pointer columns once the file rows have ids.
        garment_resource.original_file_id = garment_original.id
        garment_resource.cover_file_id = garment_thumb.id
        await session.commit()
    payload = StepActivityInput(
        order_id=1,
        workflow_run_id=1,
        step_name=WorkflowStepName.TRYON,
        source_asset_id=1,
        garment_asset_id=1,
    )
    result = await tryon_activities.run_tryon_activity(payload)
    # The activity result should carry the uploaded URL plus provider metadata.
    assert result.uri == "https://images.example.com/orders/1/tryon/generated.png"
    assert result.metadata["provider"] == "gemini"
    assert result.metadata["model"] == "gemini-test-image"
    assert result.metadata["prepared_asset_id"] == 1
    assert result.metadata["garment_resource_id"] == 1
    # And exactly one TRYON asset row should be persisted with that URL.
    async with get_session_factory()() as session:
        assets = (await session.execute(
            AssetORM.__table__.select().where(AssetORM.order_id == 1, AssetORM.asset_type == AssetType.TRYON)
        )).mappings().all()
        assert len(assets) == 1
        assert assets[0]["uri"] == "https://images.example.com/orders/1/tryon/generated.png"
@pytest.mark.asyncio
async def test_run_scene_activity_persists_uploaded_asset_with_provider_metadata(api_runtime, monkeypatch):
    """Gemini-mode scene should persist the uploaded output URL instead of a mock URI."""
    from app.workers.activities import scene_activities
    # Swap the real image-generation and storage services for in-test fakes.
    monkeypatch.setattr(
        scene_activities,
        "get_image_generation_service",
        lambda: FakeImageGenerationService(),
    )
    monkeypatch.setattr(
        scene_activities,
        "get_order_artifact_storage_service",
        lambda: FakeSceneArtifactStorageService(),
    )
    # Seed the database: an order, its workflow run, the try-on asset the
    # scene step consumes, and an indoor scene library resource with a file.
    async with get_session_factory()() as session:
        order = OrderORM(
            customer_level="low",
            service_mode="auto_basic",
            status=OrderStatus.CREATED,
            model_id=1,
            garment_asset_id=2,
            scene_ref_asset_id=3,
        )
        session.add(order)
        await session.flush()
        workflow_run = WorkflowRunORM(
            order_id=order.id,
            workflow_id=f"order-{order.id}",
            workflow_type="LowEndPipelineWorkflow",
            status=OrderStatus.CREATED,
        )
        session.add(workflow_run)
        await session.flush()
        tryon_asset = AssetORM(
            order_id=order.id,
            asset_type=AssetType.TRYON,
            step_name=WorkflowStepName.TRYON,
            uri="https://images.example.com/orders/1/tryon/generated.png",
            metadata_json={"prepared_asset_id": 1},
        )
        session.add(tryon_asset)
        scene_resource = LibraryResourceORM(
            resource_type=LibraryResourceType.SCENE,
            name="Studio Background",
            description="摄影棚背景",
            tags=["室内"],
            status=LibraryResourceStatus.ACTIVE,
            environment="indoor",
        )
        session.add(scene_resource)
        await session.flush()
        scene_original = LibraryResourceFileORM(
            resource_id=scene_resource.id,
            file_role=LibraryFileRole.ORIGINAL,
            storage_key="library/scenes/studio/original.png",
            public_url="https://images.example.com/library/scenes/studio/original.png",
            bucket="test-bucket",
            mime_type="image/png",
            size_bytes=2048,
            sort_order=0,
        )
        session.add(scene_original)
        await session.flush()
        # Wire the resource's pointer columns once the file row has an id.
        scene_resource.original_file_id = scene_original.id
        scene_resource.cover_file_id = scene_original.id
        await session.commit()
    payload = StepActivityInput(
        order_id=1,
        workflow_run_id=1,
        step_name=WorkflowStepName.SCENE,
        source_asset_id=1,
        scene_ref_asset_id=1,
    )
    result = await scene_activities.run_scene_activity(payload)
    # The activity result should carry the uploaded URL plus provider metadata.
    assert result.uri == "https://images.example.com/orders/1/scene/generated.jpg"
    assert result.metadata["provider"] == "gemini"
    assert result.metadata["model"] == "gemini-test-image"
    assert result.metadata["source_asset_id"] == 1
    assert result.metadata["scene_resource_id"] == 1
    # And exactly one SCENE asset row should be persisted with that URL.
    async with get_session_factory()() as session:
        assets = (
            await session.execute(
                AssetORM.__table__.select().where(
                    AssetORM.order_id == 1,
                    AssetORM.asset_type == AssetType.SCENE,
                )
            )
        ).mappings().all()
        assert len(assets) == 1
        assert assets[0]["uri"] == "https://images.example.com/orders/1/scene/generated.jpg"

View File

@@ -0,0 +1,33 @@
"""Tests for workflow activity timeout policies."""
from datetime import timedelta
from app.infra.temporal.task_queues import (
IMAGE_PIPELINE_CONTROL_TASK_QUEUE,
IMAGE_PIPELINE_EXPORT_TASK_QUEUE,
IMAGE_PIPELINE_IMAGE_GEN_TASK_QUEUE,
IMAGE_PIPELINE_POST_PROCESS_TASK_QUEUE,
IMAGE_PIPELINE_QC_TASK_QUEUE,
)
from app.workers.workflows.timeout_policy import (
DEFAULT_ACTIVITY_TIMEOUT,
LONG_RUNNING_ACTIVITY_TIMEOUT,
activity_timeout_for_task_queue,
)
def test_activity_timeout_for_task_queue_uses_long_timeout_for_image_work():
    """Image generation and post-processing queues should get a longer timeout."""
    # Pin the policy constants first so the per-queue checks are meaningful.
    assert DEFAULT_ACTIVITY_TIMEOUT == timedelta(seconds=30)
    assert LONG_RUNNING_ACTIVITY_TIMEOUT == timedelta(minutes=5)
    for heavy_queue in (IMAGE_PIPELINE_IMAGE_GEN_TASK_QUEUE, IMAGE_PIPELINE_POST_PROCESS_TASK_QUEUE):
        assert activity_timeout_for_task_queue(heavy_queue) == LONG_RUNNING_ACTIVITY_TIMEOUT
def test_activity_timeout_for_task_queue_keeps_short_timeout_for_light_steps():
    """Control, QC, and export queues should stay on the short timeout."""
    light_queues = (
        IMAGE_PIPELINE_CONTROL_TASK_QUEUE,
        IMAGE_PIPELINE_QC_TASK_QUEUE,
        IMAGE_PIPELINE_EXPORT_TASK_QUEUE,
    )
    for light_queue in light_queues:
        assert activity_timeout_for_task_queue(light_queue) == DEFAULT_ACTIVITY_TIMEOUT