feat: add resource library and real image workflow
This commit is contained in:
@@ -5,3 +5,9 @@ AUTO_CREATE_TABLES=true
|
||||
DATABASE_URL=sqlite+aiosqlite:///./temporal_demo.db
|
||||
TEMPORAL_ADDRESS=localhost:7233
|
||||
TEMPORAL_NAMESPACE=default
|
||||
IMAGE_GENERATION_PROVIDER=mock
|
||||
GEMINI_API_KEY=
|
||||
GEMINI_BASE_URL=https://api.museidea.com/v1beta
|
||||
GEMINI_MODEL=gemini-3.1-flash-image-preview
|
||||
GEMINI_TIMEOUT_SECONDS=300
|
||||
GEMINI_MAX_ATTEMPTS=2
|
||||
|
||||
@@ -8,12 +8,14 @@ from sqlalchemy import engine_from_config, pool
|
||||
from app.config.settings import get_settings
|
||||
from app.infra.db.base import Base
|
||||
from app.infra.db.models.asset import AssetORM
|
||||
from app.infra.db.models.library_resource import LibraryResourceORM
|
||||
from app.infra.db.models.library_resource_file import LibraryResourceFileORM
|
||||
from app.infra.db.models.order import OrderORM
|
||||
from app.infra.db.models.review_task import ReviewTaskORM
|
||||
from app.infra.db.models.workflow_run import WorkflowRunORM
|
||||
from app.infra.db.models.workflow_step import WorkflowStepORM
|
||||
|
||||
del AssetORM, OrderORM, ReviewTaskORM, WorkflowRunORM, WorkflowStepORM
|
||||
del AssetORM, LibraryResourceORM, LibraryResourceFileORM, OrderORM, ReviewTaskORM, WorkflowRunORM, WorkflowStepORM
|
||||
|
||||
config = context.config
|
||||
|
||||
@@ -58,4 +60,3 @@ if context.is_offline_mode():
|
||||
run_migrations_offline()
|
||||
else:
|
||||
run_migrations_online()
|
||||
|
||||
|
||||
92
alembic/versions/20260328_0003_resource_library.py
Normal file
92
alembic/versions/20260328_0003_resource_library.py
Normal file
@@ -0,0 +1,92 @@
|
||||
"""resource library schema
|
||||
|
||||
Revision ID: 20260328_0003
|
||||
Revises: 20260327_0002
|
||||
Create Date: 2026-03-28 10:20:00.000000
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
revision: str = "20260328_0003"
|
||||
down_revision: str | None = "20260327_0002"
|
||||
branch_labels: Sequence[str] | None = None
|
||||
depends_on: Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Create the resource-library tables."""
|
||||
|
||||
op.create_table(
|
||||
"library_resources",
|
||||
sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True),
|
||||
sa.Column(
|
||||
"resource_type",
|
||||
sa.Enum("model", "scene", "garment", name="libraryresourcetype", native_enum=False),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("name", sa.String(length=255), nullable=False),
|
||||
sa.Column("description", sa.Text(), nullable=True),
|
||||
sa.Column("tags", sa.JSON(), nullable=False),
|
||||
sa.Column(
|
||||
"status",
|
||||
sa.Enum("active", "archived", name="libraryresourcestatus", native_enum=False),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("gender", sa.String(length=32), nullable=True),
|
||||
sa.Column("age_group", sa.String(length=32), nullable=True),
|
||||
sa.Column("pose_id", sa.Integer(), nullable=True),
|
||||
sa.Column("environment", sa.String(length=32), nullable=True),
|
||||
sa.Column("category", sa.String(length=128), nullable=True),
|
||||
sa.Column("cover_file_id", sa.Integer(), nullable=True),
|
||||
sa.Column("original_file_id", sa.Integer(), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
|
||||
)
|
||||
op.create_index("ix_library_resources_resource_type", "library_resources", ["resource_type"])
|
||||
op.create_index("ix_library_resources_status", "library_resources", ["status"])
|
||||
op.create_index("ix_library_resources_gender", "library_resources", ["gender"])
|
||||
op.create_index("ix_library_resources_age_group", "library_resources", ["age_group"])
|
||||
op.create_index("ix_library_resources_environment", "library_resources", ["environment"])
|
||||
op.create_index("ix_library_resources_category", "library_resources", ["category"])
|
||||
|
||||
op.create_table(
|
||||
"library_resource_files",
|
||||
sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True),
|
||||
sa.Column("resource_id", sa.Integer(), sa.ForeignKey("library_resources.id"), nullable=False),
|
||||
sa.Column(
|
||||
"file_role",
|
||||
sa.Enum("original", "thumbnail", "gallery", name="libraryfilerole", native_enum=False),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("storage_key", sa.String(length=500), nullable=False),
|
||||
sa.Column("public_url", sa.String(length=500), nullable=False),
|
||||
sa.Column("bucket", sa.String(length=255), nullable=False),
|
||||
sa.Column("mime_type", sa.String(length=255), nullable=False),
|
||||
sa.Column("size_bytes", sa.Integer(), nullable=False),
|
||||
sa.Column("sort_order", sa.Integer(), nullable=False),
|
||||
sa.Column("width", sa.Integer(), nullable=True),
|
||||
sa.Column("height", sa.Integer(), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
|
||||
)
|
||||
op.create_index("ix_library_resource_files_resource_id", "library_resource_files", ["resource_id"])
|
||||
op.create_index("ix_library_resource_files_file_role", "library_resource_files", ["file_role"])
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Drop the resource-library tables."""
|
||||
|
||||
op.drop_index("ix_library_resource_files_file_role", table_name="library_resource_files")
|
||||
op.drop_index("ix_library_resource_files_resource_id", table_name="library_resource_files")
|
||||
op.drop_table("library_resource_files")
|
||||
op.drop_index("ix_library_resources_category", table_name="library_resources")
|
||||
op.drop_index("ix_library_resources_environment", table_name="library_resources")
|
||||
op.drop_index("ix_library_resources_age_group", table_name="library_resources")
|
||||
op.drop_index("ix_library_resources_gender", table_name="library_resources")
|
||||
op.drop_index("ix_library_resources_status", table_name="library_resources")
|
||||
op.drop_index("ix_library_resources_resource_type", table_name="library_resources")
|
||||
op.drop_table("library_resources")
|
||||
25
alembic/versions/20260328_0004_optional_pose_id.py
Normal file
25
alembic/versions/20260328_0004_optional_pose_id.py
Normal file
@@ -0,0 +1,25 @@
|
||||
"""Make order pose_id optional for MVP.
|
||||
|
||||
Revision ID: 20260328_0004
|
||||
Revises: 20260328_0003
|
||||
Create Date: 2026-03-28 11:08:00.000000
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
revision = "20260328_0004"
|
||||
down_revision = "20260328_0003"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
with op.batch_alter_table("orders") as batch_op:
|
||||
batch_op.alter_column("pose_id", existing_type=sa.Integer(), nullable=True)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
with op.batch_alter_table("orders") as batch_op:
|
||||
batch_op.alter_column("pose_id", existing_type=sa.Integer(), nullable=False)
|
||||
@@ -0,0 +1,25 @@
|
||||
"""Make order scene_ref_asset_id optional for MVP.
|
||||
|
||||
Revision ID: 20260328_0005
|
||||
Revises: 20260328_0004
|
||||
Create Date: 2026-03-28 14:20:00.000000
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
revision = "20260328_0005"
|
||||
down_revision = "20260328_0004"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
with op.batch_alter_table("orders") as batch_op:
|
||||
batch_op.alter_column("scene_ref_asset_id", existing_type=sa.Integer(), nullable=True)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
with op.batch_alter_table("orders") as batch_op:
|
||||
batch_op.alter_column("scene_ref_asset_id", existing_type=sa.Integer(), nullable=False)
|
||||
89
app/api/routers/library.py
Normal file
89
app/api/routers/library.py
Normal file
@@ -0,0 +1,89 @@
|
||||
"""Library resource routes."""
|
||||
|
||||
from fastapi import APIRouter, Depends, Query, status
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.schemas.library import (
|
||||
ArchiveLibraryResourceResponse,
|
||||
CreateLibraryResourceRequest,
|
||||
LibraryResourceListResponse,
|
||||
LibraryResourceRead,
|
||||
PresignUploadRequest,
|
||||
PresignUploadResponse,
|
||||
UpdateLibraryResourceRequest,
|
||||
)
|
||||
from app.application.services.library_service import LibraryService
|
||||
from app.domain.enums import LibraryResourceType
|
||||
from app.infra.db.session import get_db_session
|
||||
|
||||
router = APIRouter(prefix="/library", tags=["library"])
|
||||
library_service = LibraryService()
|
||||
|
||||
|
||||
@router.post("/uploads/presign", response_model=PresignUploadResponse)
|
||||
async def create_library_upload_presign(payload: PresignUploadRequest) -> PresignUploadResponse:
|
||||
"""Create direct-upload metadata for a library file."""
|
||||
|
||||
return library_service.create_upload_presign(
|
||||
resource_type=payload.resource_type,
|
||||
file_name=payload.file_name,
|
||||
content_type=payload.content_type,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/resources", response_model=LibraryResourceRead, status_code=status.HTTP_201_CREATED)
|
||||
async def create_library_resource(
|
||||
payload: CreateLibraryResourceRequest,
|
||||
session: AsyncSession = Depends(get_db_session),
|
||||
) -> LibraryResourceRead:
|
||||
"""Create a library resource with already-uploaded file metadata."""
|
||||
|
||||
return await library_service.create_resource(session, payload)
|
||||
|
||||
|
||||
@router.get("/resources", response_model=LibraryResourceListResponse)
|
||||
async def list_library_resources(
|
||||
resource_type: LibraryResourceType | None = Query(default=None),
|
||||
query: str | None = Query(default=None, min_length=1),
|
||||
gender: str | None = Query(default=None),
|
||||
age_group: str | None = Query(default=None),
|
||||
environment: str | None = Query(default=None),
|
||||
category: str | None = Query(default=None),
|
||||
page: int = Query(default=1, ge=1),
|
||||
limit: int = Query(default=20, ge=1, le=100),
|
||||
session: AsyncSession = Depends(get_db_session),
|
||||
) -> LibraryResourceListResponse:
|
||||
"""List library resources with basic filtering."""
|
||||
|
||||
return await library_service.list_resources(
|
||||
session,
|
||||
resource_type=resource_type,
|
||||
query=query,
|
||||
gender=gender,
|
||||
age_group=age_group,
|
||||
environment=environment,
|
||||
category=category,
|
||||
page=page,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
|
||||
@router.patch("/resources/{resource_id}", response_model=LibraryResourceRead)
|
||||
async def update_library_resource(
|
||||
resource_id: int,
|
||||
payload: UpdateLibraryResourceRequest,
|
||||
session: AsyncSession = Depends(get_db_session),
|
||||
) -> LibraryResourceRead:
|
||||
"""Update editable metadata on a library resource."""
|
||||
|
||||
return await library_service.update_resource(session, resource_id, payload)
|
||||
|
||||
|
||||
@router.delete("/resources/{resource_id}", response_model=ArchiveLibraryResourceResponse)
|
||||
async def archive_library_resource(
|
||||
resource_id: int,
|
||||
session: AsyncSession = Depends(get_db_session),
|
||||
) -> ArchiveLibraryResourceResponse:
|
||||
"""Archive a library resource instead of hard deleting it."""
|
||||
|
||||
return await library_service.archive_resource(session, resource_id)
|
||||
118
app/api/schemas/library.py
Normal file
118
app/api/schemas/library.py
Normal file
@@ -0,0 +1,118 @@
|
||||
"""Library resource API schemas."""
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.domain.enums import LibraryFileRole, LibraryResourceStatus, LibraryResourceType
|
||||
|
||||
|
||||
class PresignUploadRequest(BaseModel):
|
||||
"""Request payload for generating direct-upload metadata."""
|
||||
|
||||
resource_type: LibraryResourceType
|
||||
file_role: LibraryFileRole
|
||||
file_name: str
|
||||
content_type: str
|
||||
|
||||
|
||||
class PresignUploadResponse(BaseModel):
|
||||
"""Response returned when generating direct-upload metadata."""
|
||||
|
||||
method: str
|
||||
upload_url: str
|
||||
headers: dict[str, str] = Field(default_factory=dict)
|
||||
storage_key: str
|
||||
public_url: str
|
||||
|
||||
|
||||
class CreateLibraryResourceFileRequest(BaseModel):
|
||||
"""Metadata for one already-uploaded file."""
|
||||
|
||||
file_role: LibraryFileRole
|
||||
storage_key: str
|
||||
public_url: str
|
||||
mime_type: str
|
||||
size_bytes: int
|
||||
sort_order: int = 0
|
||||
width: int | None = None
|
||||
height: int | None = None
|
||||
|
||||
|
||||
class CreateLibraryResourceRequest(BaseModel):
|
||||
"""One-shot resource creation request."""
|
||||
|
||||
resource_type: LibraryResourceType
|
||||
name: str
|
||||
description: str | None = None
|
||||
tags: list[str] = Field(default_factory=list)
|
||||
gender: str | None = None
|
||||
age_group: str | None = None
|
||||
pose_id: int | None = None
|
||||
environment: str | None = None
|
||||
category: str | None = None
|
||||
files: list[CreateLibraryResourceFileRequest]
|
||||
|
||||
|
||||
class UpdateLibraryResourceRequest(BaseModel):
|
||||
"""Partial update request for a library resource."""
|
||||
|
||||
name: str | None = None
|
||||
description: str | None = None
|
||||
tags: list[str] | None = None
|
||||
gender: str | None = None
|
||||
age_group: str | None = None
|
||||
pose_id: int | None = None
|
||||
environment: str | None = None
|
||||
category: str | None = None
|
||||
cover_file_id: int | None = None
|
||||
|
||||
|
||||
class ArchiveLibraryResourceResponse(BaseModel):
|
||||
"""Response returned after archiving a resource."""
|
||||
|
||||
id: int
|
||||
|
||||
|
||||
class LibraryResourceFileRead(BaseModel):
|
||||
"""Serialized library file metadata."""
|
||||
|
||||
id: int
|
||||
file_role: LibraryFileRole
|
||||
storage_key: str
|
||||
public_url: str
|
||||
bucket: str
|
||||
mime_type: str
|
||||
size_bytes: int
|
||||
sort_order: int
|
||||
width: int | None = None
|
||||
height: int | None = None
|
||||
created_at: datetime
|
||||
|
||||
|
||||
class LibraryResourceRead(BaseModel):
|
||||
"""Serialized library resource."""
|
||||
|
||||
id: int
|
||||
resource_type: LibraryResourceType
|
||||
name: str
|
||||
description: str | None = None
|
||||
tags: list[str]
|
||||
status: LibraryResourceStatus
|
||||
gender: str | None = None
|
||||
age_group: str | None = None
|
||||
pose_id: int | None = None
|
||||
environment: str | None = None
|
||||
category: str | None = None
|
||||
cover_url: str | None = None
|
||||
original_url: str | None = None
|
||||
files: list[LibraryResourceFileRead]
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
|
||||
class LibraryResourceListResponse(BaseModel):
|
||||
"""Paginated library resource response."""
|
||||
|
||||
total: int
|
||||
items: list[LibraryResourceRead]
|
||||
@@ -14,9 +14,9 @@ class CreateOrderRequest(BaseModel):
|
||||
customer_level: CustomerLevel
|
||||
service_mode: ServiceMode
|
||||
model_id: int
|
||||
pose_id: int
|
||||
pose_id: int | None = None
|
||||
garment_asset_id: int
|
||||
scene_ref_asset_id: int
|
||||
scene_ref_asset_id: int | None = None
|
||||
|
||||
|
||||
class CreateOrderResponse(BaseModel):
|
||||
@@ -35,9 +35,9 @@ class OrderDetailResponse(BaseModel):
|
||||
service_mode: ServiceMode
|
||||
status: OrderStatus
|
||||
model_id: int
|
||||
pose_id: int
|
||||
pose_id: int | None
|
||||
garment_asset_id: int
|
||||
scene_ref_asset_id: int
|
||||
scene_ref_asset_id: int | None
|
||||
final_asset_id: int | None
|
||||
workflow_id: str | None
|
||||
current_step: WorkflowStepName | None
|
||||
|
||||
163
app/application/services/image_generation_service.py
Normal file
163
app/application/services/image_generation_service.py
Normal file
@@ -0,0 +1,163 @@
|
||||
"""Application-facing orchestration for image-generation providers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
|
||||
import httpx
|
||||
from fastapi import HTTPException, status
|
||||
|
||||
from app.config.settings import get_settings
|
||||
from app.infra.image_generation.base import GeneratedImageResult, SourceImage
|
||||
from app.infra.image_generation.gemini_provider import GeminiImageProvider
|
||||
|
||||
TRYON_PROMPT = (
|
||||
"Use the first image as the base person image and the second image as the garment reference. "
|
||||
"Generate a realistic virtual try-on result where the person in the first image is wearing the garment from "
|
||||
"the second image. Preserve the person's identity, body proportions, pose, camera angle, lighting direction, "
|
||||
"and background composition from the first image. Preserve the garment's key silhouette, material feel, color, "
|
||||
"and visible design details from the second image. Produce a clean e-commerce quality result without adding "
|
||||
"extra accessories, text, watermarks, or duplicate limbs."
|
||||
)
|
||||
|
||||
SCENE_PROMPT = (
|
||||
"Use the first image as the finished subject image and the second image as the target scene reference. "
|
||||
"Generate a realistic composite where the person, clothing, body proportions, facial identity, pose, and camera "
|
||||
"framing from the first image are preserved, while the environment is adapted to match the second image. "
|
||||
"Keep the garment details, colors, and texture from the first image intact. Blend the subject naturally into the "
|
||||
"target scene with coherent perspective, lighting, shadows, and depth. Do not add extra people, accessories, "
|
||||
"text, logos, watermarks, or duplicate body parts."
|
||||
)
|
||||
|
||||
MOCK_PNG_BASE64 = (
|
||||
"iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mP8/x8AAwMCAO5Ww4QAAAAASUVORK5CYII="
|
||||
)
|
||||
|
||||
|
||||
class MockImageGenerationService:
|
||||
"""Deterministic fallback used in tests and when no real provider is configured."""
|
||||
|
||||
async def generate_tryon_image(self, *, person_image_url: str, garment_image_url: str) -> GeneratedImageResult:
|
||||
return GeneratedImageResult(
|
||||
image_bytes=base64.b64decode(MOCK_PNG_BASE64),
|
||||
mime_type="image/png",
|
||||
provider="mock",
|
||||
model="mock-image-provider",
|
||||
prompt=TRYON_PROMPT,
|
||||
)
|
||||
|
||||
async def generate_scene_image(self, *, source_image_url: str, scene_image_url: str) -> GeneratedImageResult:
|
||||
return GeneratedImageResult(
|
||||
image_bytes=base64.b64decode(MOCK_PNG_BASE64),
|
||||
mime_type="image/png",
|
||||
provider="mock",
|
||||
model="mock-image-provider",
|
||||
prompt=SCENE_PROMPT,
|
||||
)
|
||||
|
||||
|
||||
class ImageGenerationService:
|
||||
"""Download source assets and dispatch them to the configured provider."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
provider: GeminiImageProvider | None = None,
|
||||
downloader: httpx.AsyncClient | None = None,
|
||||
) -> None:
|
||||
self.settings = get_settings()
|
||||
self.provider = provider or GeminiImageProvider(
|
||||
api_key=self.settings.gemini_api_key,
|
||||
base_url=self.settings.gemini_base_url,
|
||||
model=self.settings.gemini_model,
|
||||
timeout_seconds=self.settings.gemini_timeout_seconds,
|
||||
max_attempts=self.settings.gemini_max_attempts,
|
||||
)
|
||||
self._downloader = downloader
|
||||
|
||||
async def generate_tryon_image(
|
||||
self,
|
||||
*,
|
||||
person_image_url: str,
|
||||
garment_image_url: str,
|
||||
) -> GeneratedImageResult:
|
||||
"""Generate a try-on image using the configured provider."""
|
||||
|
||||
if not self.settings.gemini_api_key:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="GEMINI_API_KEY is required when image_generation_provider=gemini",
|
||||
)
|
||||
|
||||
person_image, garment_image = await self._download_inputs(
|
||||
first_image_url=person_image_url,
|
||||
second_image_url=garment_image_url,
|
||||
)
|
||||
return await self.provider.generate_tryon_image(
|
||||
prompt=TRYON_PROMPT,
|
||||
person_image=person_image,
|
||||
garment_image=garment_image,
|
||||
)
|
||||
|
||||
async def generate_scene_image(
|
||||
self,
|
||||
*,
|
||||
source_image_url: str,
|
||||
scene_image_url: str,
|
||||
) -> GeneratedImageResult:
|
||||
"""Generate a scene-composited image using the configured provider."""
|
||||
|
||||
if not self.settings.gemini_api_key:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="GEMINI_API_KEY is required when image_generation_provider=gemini",
|
||||
)
|
||||
|
||||
source_image, scene_image = await self._download_inputs(
|
||||
first_image_url=source_image_url,
|
||||
second_image_url=scene_image_url,
|
||||
)
|
||||
return await self.provider.generate_scene_image(
|
||||
prompt=SCENE_PROMPT,
|
||||
source_image=source_image,
|
||||
scene_image=scene_image,
|
||||
)
|
||||
|
||||
async def _download_inputs(
|
||||
self,
|
||||
*,
|
||||
first_image_url: str,
|
||||
second_image_url: str,
|
||||
) -> tuple[SourceImage, SourceImage]:
|
||||
owns_client = self._downloader is None
|
||||
client = self._downloader or httpx.AsyncClient(timeout=self.settings.gemini_timeout_seconds)
|
||||
try:
|
||||
first_response = await client.get(first_image_url)
|
||||
first_response.raise_for_status()
|
||||
second_response = await client.get(second_image_url)
|
||||
second_response.raise_for_status()
|
||||
finally:
|
||||
if owns_client:
|
||||
await client.aclose()
|
||||
|
||||
first_mime = first_response.headers.get("content-type", "image/png").split(";")[0].strip()
|
||||
second_mime = second_response.headers.get("content-type", "image/png").split(";")[0].strip()
|
||||
|
||||
return (
|
||||
SourceImage(url=first_image_url, mime_type=first_mime or "image/png", data=first_response.content),
|
||||
SourceImage(url=second_image_url, mime_type=second_mime or "image/png", data=second_response.content),
|
||||
)
|
||||
|
||||
|
||||
def build_image_generation_service():
|
||||
"""Return the configured image-generation service."""
|
||||
|
||||
settings = get_settings()
|
||||
if settings.image_generation_provider.lower() == "mock":
|
||||
return MockImageGenerationService()
|
||||
if settings.image_generation_provider.lower() == "gemini":
|
||||
return ImageGenerationService()
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Unsupported image_generation_provider: {settings.image_generation_provider}",
|
||||
)
|
||||
366
app/application/services/library_service.py
Normal file
366
app/application/services/library_service.py
Normal file
@@ -0,0 +1,366 @@
|
||||
"""Library resource application service."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import HTTPException, status
|
||||
from sqlalchemy import func, or_, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.api.schemas.library import (
|
||||
ArchiveLibraryResourceResponse,
|
||||
CreateLibraryResourceRequest,
|
||||
LibraryResourceFileRead,
|
||||
LibraryResourceListResponse,
|
||||
LibraryResourceRead,
|
||||
PresignUploadResponse,
|
||||
UpdateLibraryResourceRequest,
|
||||
)
|
||||
from app.config.settings import get_settings
|
||||
from app.domain.enums import LibraryFileRole, LibraryResourceStatus, LibraryResourceType
|
||||
from app.infra.db.models.library_resource import LibraryResourceORM
|
||||
from app.infra.db.models.library_resource_file import LibraryResourceFileORM
|
||||
from app.infra.storage.s3 import RESOURCE_PREFIXES, S3PresignService
|
||||
|
||||
|
||||
class LibraryService:
|
||||
"""Application service for resource-library uploads and queries."""
|
||||
|
||||
def __init__(self, presign_service: S3PresignService | None = None) -> None:
|
||||
self.presign_service = presign_service or S3PresignService()
|
||||
|
||||
def create_upload_presign(
|
||||
self,
|
||||
resource_type: LibraryResourceType,
|
||||
file_name: str,
|
||||
content_type: str,
|
||||
) -> PresignUploadResponse:
|
||||
"""Create upload metadata for a direct S3 PUT."""
|
||||
|
||||
settings = get_settings()
|
||||
if not all(
|
||||
[
|
||||
settings.s3_bucket,
|
||||
settings.s3_region,
|
||||
settings.s3_access_key,
|
||||
settings.s3_secret_key,
|
||||
settings.s3_endpoint,
|
||||
]
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="S3 upload settings are incomplete",
|
||||
)
|
||||
|
||||
storage_key, upload_url = self.presign_service.create_upload(
|
||||
resource_type=resource_type,
|
||||
file_name=file_name,
|
||||
content_type=content_type,
|
||||
)
|
||||
return PresignUploadResponse(
|
||||
method="PUT",
|
||||
upload_url=upload_url,
|
||||
headers={"content-type": content_type},
|
||||
storage_key=storage_key,
|
||||
public_url=self.presign_service.get_public_url(storage_key),
|
||||
)
|
||||
|
||||
async def create_resource(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
payload: CreateLibraryResourceRequest,
|
||||
) -> LibraryResourceRead:
|
||||
"""Persist a resource and its uploaded file metadata."""
|
||||
|
||||
self._validate_payload(payload)
|
||||
|
||||
resource = LibraryResourceORM(
|
||||
resource_type=payload.resource_type,
|
||||
name=payload.name.strip(),
|
||||
description=payload.description.strip() if payload.description else None,
|
||||
tags=[tag.strip() for tag in payload.tags if tag.strip()],
|
||||
status=LibraryResourceStatus.ACTIVE,
|
||||
gender=payload.gender.strip() if payload.gender else None,
|
||||
age_group=payload.age_group.strip() if payload.age_group else None,
|
||||
pose_id=payload.pose_id,
|
||||
environment=payload.environment.strip() if payload.environment else None,
|
||||
category=payload.category.strip() if payload.category else None,
|
||||
)
|
||||
session.add(resource)
|
||||
await session.flush()
|
||||
|
||||
file_models: list[LibraryResourceFileORM] = []
|
||||
for item in payload.files:
|
||||
file_model = LibraryResourceFileORM(
|
||||
resource_id=resource.id,
|
||||
file_role=item.file_role,
|
||||
storage_key=item.storage_key,
|
||||
public_url=item.public_url,
|
||||
bucket=get_settings().s3_bucket,
|
||||
mime_type=item.mime_type,
|
||||
size_bytes=item.size_bytes,
|
||||
sort_order=item.sort_order,
|
||||
width=item.width,
|
||||
height=item.height,
|
||||
)
|
||||
session.add(file_model)
|
||||
file_models.append(file_model)
|
||||
|
||||
await session.flush()
|
||||
resource.original_file_id = self._find_role(file_models, LibraryFileRole.ORIGINAL).id
|
||||
resource.cover_file_id = self._find_role(file_models, LibraryFileRole.THUMBNAIL).id
|
||||
await session.commit()
|
||||
|
||||
return self._to_read(resource, files=file_models)
|
||||
|
||||
async def list_resources(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
*,
|
||||
resource_type: LibraryResourceType | None = None,
|
||||
query: str | None = None,
|
||||
gender: str | None = None,
|
||||
age_group: str | None = None,
|
||||
environment: str | None = None,
|
||||
category: str | None = None,
|
||||
page: int = 1,
|
||||
limit: int = 20,
|
||||
) -> LibraryResourceListResponse:
|
||||
"""List persisted resources with simple filter support."""
|
||||
|
||||
filters = [LibraryResourceORM.status == LibraryResourceStatus.ACTIVE]
|
||||
if resource_type is not None:
|
||||
filters.append(LibraryResourceORM.resource_type == resource_type)
|
||||
if query:
|
||||
like = f"%{query.strip()}%"
|
||||
filters.append(
|
||||
or_(
|
||||
LibraryResourceORM.name.ilike(like),
|
||||
LibraryResourceORM.description.ilike(like),
|
||||
)
|
||||
)
|
||||
if gender:
|
||||
filters.append(LibraryResourceORM.gender == gender)
|
||||
if age_group:
|
||||
filters.append(LibraryResourceORM.age_group == age_group)
|
||||
if environment:
|
||||
filters.append(LibraryResourceORM.environment == environment)
|
||||
if category:
|
||||
filters.append(LibraryResourceORM.category == category)
|
||||
|
||||
total = (
|
||||
await session.execute(select(func.count(LibraryResourceORM.id)).where(*filters))
|
||||
).scalar_one()
|
||||
|
||||
result = await session.execute(
|
||||
select(LibraryResourceORM)
|
||||
.options(selectinload(LibraryResourceORM.files))
|
||||
.where(*filters)
|
||||
.order_by(LibraryResourceORM.created_at.desc(), LibraryResourceORM.id.desc())
|
||||
.offset((page - 1) * limit)
|
||||
.limit(limit)
|
||||
)
|
||||
items = result.scalars().all()
|
||||
return LibraryResourceListResponse(total=total, items=[self._to_read(item) for item in items])
|
||||
|
||||
async def update_resource(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
resource_id: int,
|
||||
payload: UpdateLibraryResourceRequest,
|
||||
) -> LibraryResourceRead:
|
||||
"""Update editable metadata on a library resource."""
|
||||
|
||||
resource = await self._get_resource_or_404(session, resource_id)
|
||||
|
||||
if payload.name is not None:
|
||||
name = payload.name.strip()
|
||||
if not name:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Resource name is required",
|
||||
)
|
||||
resource.name = name
|
||||
|
||||
if payload.description is not None:
|
||||
resource.description = payload.description.strip() or None
|
||||
|
||||
if payload.tags is not None:
|
||||
resource.tags = [tag.strip() for tag in payload.tags if tag.strip()]
|
||||
|
||||
if resource.resource_type == LibraryResourceType.MODEL:
|
||||
if payload.gender is not None:
|
||||
resource.gender = payload.gender.strip() or None
|
||||
if payload.age_group is not None:
|
||||
resource.age_group = payload.age_group.strip() or None
|
||||
if payload.pose_id is not None:
|
||||
resource.pose_id = payload.pose_id
|
||||
elif resource.resource_type == LibraryResourceType.SCENE:
|
||||
if payload.environment is not None:
|
||||
resource.environment = payload.environment.strip() or None
|
||||
elif resource.resource_type == LibraryResourceType.GARMENT:
|
||||
if payload.category is not None:
|
||||
resource.category = payload.category.strip() or None
|
||||
|
||||
if payload.cover_file_id is not None:
|
||||
cover_file = next(
|
||||
(file for file in resource.files if file.id == payload.cover_file_id),
|
||||
None,
|
||||
)
|
||||
if cover_file is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Cover file does not belong to the resource",
|
||||
)
|
||||
resource.cover_file_id = cover_file.id
|
||||
|
||||
self._validate_existing_resource(resource)
|
||||
await session.commit()
|
||||
await session.refresh(resource, attribute_names=["files"])
|
||||
return self._to_read(resource)
|
||||
|
||||
async def archive_resource(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
resource_id: int,
|
||||
) -> ArchiveLibraryResourceResponse:
|
||||
"""Soft delete a library resource by archiving it."""
|
||||
|
||||
resource = await self._get_resource_or_404(session, resource_id)
|
||||
resource.status = LibraryResourceStatus.ARCHIVED
|
||||
await session.commit()
|
||||
return ArchiveLibraryResourceResponse(id=resource.id)
|
||||
|
||||
def _validate_payload(self, payload: CreateLibraryResourceRequest) -> None:
|
||||
if not payload.name.strip():
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Resource name is required")
|
||||
|
||||
if not payload.files:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="At least one file is required")
|
||||
|
||||
originals = [file for file in payload.files if file.file_role == LibraryFileRole.ORIGINAL]
|
||||
thumbnails = [file for file in payload.files if file.file_role == LibraryFileRole.THUMBNAIL]
|
||||
if len(originals) != 1:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Exactly one original file is required",
|
||||
)
|
||||
if len(thumbnails) != 1:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Exactly one thumbnail file is required",
|
||||
)
|
||||
|
||||
expected_prefix = f"library/{RESOURCE_PREFIXES[payload.resource_type]}/"
|
||||
for file in payload.files:
|
||||
if not file.storage_key.startswith(expected_prefix):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Uploaded file key does not match resource type",
|
||||
)
|
||||
|
||||
if payload.resource_type == LibraryResourceType.MODEL:
|
||||
if not payload.gender or not payload.age_group:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Model resources require gender and age_group",
|
||||
)
|
||||
elif payload.resource_type == LibraryResourceType.SCENE:
|
||||
if not payload.environment:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Scene resources require environment",
|
||||
)
|
||||
elif payload.resource_type == LibraryResourceType.GARMENT and not payload.category:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Garment resources require category",
|
||||
)
|
||||
|
||||
def _find_role(
|
||||
self,
|
||||
files: list[LibraryResourceFileORM],
|
||||
role: LibraryFileRole,
|
||||
) -> LibraryResourceFileORM:
|
||||
return next(file for file in files if file.file_role == role)
|
||||
|
||||
def _to_read(
|
||||
self,
|
||||
resource: LibraryResourceORM,
|
||||
*,
|
||||
files: list[LibraryResourceFileORM] | None = None,
|
||||
) -> LibraryResourceRead:
|
||||
resource_files = files if files is not None else list(resource.files)
|
||||
files_sorted = sorted(resource_files, key=lambda item: (item.sort_order, item.id))
|
||||
cover = next((item for item in files_sorted if item.id == resource.cover_file_id), None)
|
||||
original = next((item for item in files_sorted if item.id == resource.original_file_id), None)
|
||||
return LibraryResourceRead(
|
||||
id=resource.id,
|
||||
resource_type=resource.resource_type,
|
||||
name=resource.name,
|
||||
description=resource.description,
|
||||
tags=resource.tags,
|
||||
status=resource.status,
|
||||
gender=resource.gender,
|
||||
age_group=resource.age_group,
|
||||
pose_id=resource.pose_id,
|
||||
environment=resource.environment,
|
||||
category=resource.category,
|
||||
cover_url=cover.public_url if cover else None,
|
||||
original_url=original.public_url if original else None,
|
||||
files=[
|
||||
LibraryResourceFileRead(
|
||||
id=file.id,
|
||||
file_role=file.file_role,
|
||||
storage_key=file.storage_key,
|
||||
public_url=file.public_url,
|
||||
bucket=file.bucket,
|
||||
mime_type=file.mime_type,
|
||||
size_bytes=file.size_bytes,
|
||||
sort_order=file.sort_order,
|
||||
width=file.width,
|
||||
height=file.height,
|
||||
created_at=file.created_at,
|
||||
)
|
||||
for file in files_sorted
|
||||
],
|
||||
created_at=resource.created_at,
|
||||
updated_at=resource.updated_at,
|
||||
)
|
||||
|
||||
async def _get_resource_or_404(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
resource_id: int,
|
||||
) -> LibraryResourceORM:
|
||||
result = await session.execute(
|
||||
select(LibraryResourceORM)
|
||||
.options(selectinload(LibraryResourceORM.files))
|
||||
.where(LibraryResourceORM.id == resource_id)
|
||||
)
|
||||
resource = result.scalar_one_or_none()
|
||||
if resource is None:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Resource not found")
|
||||
return resource
|
||||
|
||||
def _validate_existing_resource(self, resource: LibraryResourceORM) -> None:
|
||||
if not resource.name.strip():
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Resource name is required")
|
||||
|
||||
if resource.resource_type == LibraryResourceType.MODEL:
|
||||
if not resource.gender or not resource.age_group:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Model resources require gender and age_group",
|
||||
)
|
||||
elif resource.resource_type == LibraryResourceType.SCENE:
|
||||
if not resource.environment:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Scene resources require environment",
|
||||
)
|
||||
elif resource.resource_type == LibraryResourceType.GARMENT and not resource.category:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Garment resources require category",
|
||||
)
|
||||
@@ -1,9 +1,12 @@
|
||||
"""Application settings."""
|
||||
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
PROJECT_ROOT = Path(__file__).resolve().parents[2]
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
"""Runtime settings loaded from environment variables."""
|
||||
@@ -15,11 +18,25 @@ class Settings(BaseSettings):
|
||||
database_url: str = "sqlite+aiosqlite:///./temporal_demo.db"
|
||||
temporal_address: str = "localhost:7233"
|
||||
temporal_namespace: str = "default"
|
||||
s3_access_key: str = ""
|
||||
s3_secret_key: str = ""
|
||||
s3_bucket: str = ""
|
||||
s3_region: str = ""
|
||||
s3_endpoint: str = ""
|
||||
s3_cname: str = ""
|
||||
s3_presign_expiry_seconds: int = 900
|
||||
image_generation_provider: str = "mock"
|
||||
gemini_api_key: str = ""
|
||||
gemini_base_url: str = "https://api.museidea.com/v1beta"
|
||||
gemini_model: str = "gemini-3.1-flash-image-preview"
|
||||
gemini_timeout_seconds: int = 300
|
||||
gemini_max_attempts: int = 2
|
||||
|
||||
model_config = SettingsConfigDict(
|
||||
env_file=".env",
|
||||
env_file=PROJECT_ROOT / ".env",
|
||||
env_file_encoding="utf-8",
|
||||
case_sensitive=False,
|
||||
extra="ignore",
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -34,4 +51,3 @@ def get_settings() -> Settings:
|
||||
"""Return the cached application settings."""
|
||||
|
||||
return Settings()
|
||||
|
||||
|
||||
@@ -82,3 +82,26 @@ class AssetType(str, Enum):
|
||||
QC_CANDIDATE = "qc_candidate"
|
||||
MANUAL_REVISION = "manual_revision"
|
||||
FINAL = "final"
|
||||
|
||||
|
||||
class LibraryResourceType(str, Enum):
|
||||
"""Supported resource-library item types."""
|
||||
|
||||
MODEL = "model"
|
||||
SCENE = "scene"
|
||||
GARMENT = "garment"
|
||||
|
||||
|
||||
class LibraryFileRole(str, Enum):
|
||||
"""Supported file roles within a library resource."""
|
||||
|
||||
ORIGINAL = "original"
|
||||
THUMBNAIL = "thumbnail"
|
||||
GALLERY = "gallery"
|
||||
|
||||
|
||||
class LibraryResourceStatus(str, Enum):
|
||||
"""Lifecycle state for a library resource."""
|
||||
|
||||
ACTIVE = "active"
|
||||
ARCHIVED = "archived"
|
||||
|
||||
@@ -15,12 +15,11 @@ class Order:
|
||||
service_mode: ServiceMode
|
||||
status: OrderStatus
|
||||
model_id: int
|
||||
pose_id: int
|
||||
pose_id: int | None
|
||||
garment_asset_id: int
|
||||
scene_ref_asset_id: int
|
||||
scene_ref_asset_id: int | None
|
||||
final_asset_id: int | None
|
||||
workflow_id: str | None
|
||||
current_step: WorkflowStepName | None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
|
||||
45
app/infra/db/models/library_resource.py
Normal file
45
app/infra/db/models/library_resource.py
Normal file
@@ -0,0 +1,45 @@
|
||||
"""Library resource ORM model."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from sqlalchemy import Enum, Integer, JSON, String, Text
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.domain.enums import LibraryResourceStatus, LibraryResourceType
|
||||
from app.infra.db.base import Base, TimestampMixin
|
||||
|
||||
|
||||
class LibraryResourceORM(TimestampMixin, Base):
|
||||
"""Persisted library resource independent from order-generated assets."""
|
||||
|
||||
__tablename__ = "library_resources"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||
resource_type: Mapped[LibraryResourceType] = mapped_column(
|
||||
Enum(LibraryResourceType, native_enum=False),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
description: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
tags: Mapped[list[str]] = mapped_column(JSON, nullable=False, default=list)
|
||||
status: Mapped[LibraryResourceStatus] = mapped_column(
|
||||
Enum(LibraryResourceStatus, native_enum=False),
|
||||
nullable=False,
|
||||
default=LibraryResourceStatus.ACTIVE,
|
||||
index=True,
|
||||
)
|
||||
gender: Mapped[str | None] = mapped_column(String(32), nullable=True, index=True)
|
||||
age_group: Mapped[str | None] = mapped_column(String(32), nullable=True, index=True)
|
||||
pose_id: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
||||
environment: Mapped[str | None] = mapped_column(String(32), nullable=True, index=True)
|
||||
category: Mapped[str | None] = mapped_column(String(128), nullable=True, index=True)
|
||||
cover_file_id: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
||||
original_file_id: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
||||
|
||||
files = relationship(
|
||||
"LibraryResourceFileORM",
|
||||
back_populates="resource",
|
||||
lazy="selectin",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
37
app/infra/db/models/library_resource_file.py
Normal file
37
app/infra/db/models/library_resource_file.py
Normal file
@@ -0,0 +1,37 @@
|
||||
"""Library resource file ORM model."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from sqlalchemy import Enum, ForeignKey, Integer, String
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.domain.enums import LibraryFileRole
|
||||
from app.infra.db.base import Base, TimestampMixin
|
||||
|
||||
|
||||
class LibraryResourceFileORM(TimestampMixin, Base):
|
||||
"""Persisted uploaded file metadata for a library resource."""
|
||||
|
||||
__tablename__ = "library_resource_files"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||
resource_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("library_resources.id"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
file_role: Mapped[LibraryFileRole] = mapped_column(
|
||||
Enum(LibraryFileRole, native_enum=False),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
storage_key: Mapped[str] = mapped_column(String(500), nullable=False)
|
||||
public_url: Mapped[str] = mapped_column(String(500), nullable=False)
|
||||
bucket: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
mime_type: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
size_bytes: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||
sort_order: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||
width: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
||||
height: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
||||
|
||||
resource = relationship("LibraryResourceORM", back_populates="files")
|
||||
@@ -27,12 +27,11 @@ class OrderORM(TimestampMixin, Base):
|
||||
default=OrderStatus.CREATED,
|
||||
)
|
||||
model_id: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||
pose_id: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||
pose_id: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
||||
garment_asset_id: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||
scene_ref_asset_id: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||
scene_ref_asset_id: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
||||
final_asset_id: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
||||
|
||||
assets = relationship("AssetORM", back_populates="order", lazy="selectin")
|
||||
review_tasks = relationship("ReviewTaskORM", back_populates="order", lazy="selectin")
|
||||
workflow_runs = relationship("WorkflowRunORM", back_populates="order", lazy="selectin")
|
||||
|
||||
|
||||
@@ -44,12 +44,14 @@ async def init_database() -> None:
|
||||
"""Create database tables when running the MVP without migrations."""
|
||||
|
||||
from app.infra.db.models.asset import AssetORM
|
||||
from app.infra.db.models.library_resource import LibraryResourceORM
|
||||
from app.infra.db.models.library_resource_file import LibraryResourceFileORM
|
||||
from app.infra.db.models.order import OrderORM
|
||||
from app.infra.db.models.review_task import ReviewTaskORM
|
||||
from app.infra.db.models.workflow_run import WorkflowRunORM
|
||||
from app.infra.db.models.workflow_step import WorkflowStepORM
|
||||
|
||||
del AssetORM, OrderORM, ReviewTaskORM, WorkflowRunORM, WorkflowStepORM
|
||||
del AssetORM, LibraryResourceORM, LibraryResourceFileORM, OrderORM, ReviewTaskORM, WorkflowRunORM, WorkflowStepORM
|
||||
|
||||
async with get_async_engine().begin() as connection:
|
||||
await connection.run_sync(Base.metadata.create_all)
|
||||
|
||||
38
app/infra/image_generation/base.py
Normal file
38
app/infra/image_generation/base.py
Normal file
@@ -0,0 +1,38 @@
|
||||
"""Shared types for image-generation providers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Protocol
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class SourceImage:
|
||||
"""A binary source image passed into an image-generation provider."""
|
||||
|
||||
url: str
|
||||
mime_type: str
|
||||
data: bytes
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class GeneratedImageResult:
|
||||
"""Normalized image-generation output returned by providers."""
|
||||
|
||||
image_bytes: bytes
|
||||
mime_type: str
|
||||
provider: str
|
||||
model: str
|
||||
prompt: str
|
||||
|
||||
|
||||
class ImageGenerationProvider(Protocol):
|
||||
"""Contract implemented by concrete image-generation providers."""
|
||||
|
||||
async def generate_tryon_image(
|
||||
self,
|
||||
*,
|
||||
prompt: str,
|
||||
person_image: SourceImage,
|
||||
garment_image: SourceImage,
|
||||
) -> GeneratedImageResult: ...
|
||||
149
app/infra/image_generation/gemini_provider.py
Normal file
149
app/infra/image_generation/gemini_provider.py
Normal file
@@ -0,0 +1,149 @@
|
||||
"""Gemini-backed image-generation provider."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
|
||||
import httpx
|
||||
|
||||
from app.infra.image_generation.base import GeneratedImageResult, SourceImage
|
||||
|
||||
|
||||
class GeminiImageProvider:
|
||||
"""Call the Gemini image-generation API via a configurable REST endpoint."""
|
||||
|
||||
provider_name = "gemini"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
base_url: str,
|
||||
model: str,
|
||||
timeout_seconds: int,
|
||||
max_attempts: int = 2,
|
||||
http_client: httpx.AsyncClient | None = None,
|
||||
) -> None:
|
||||
self.api_key = api_key
|
||||
self.base_url = base_url.rstrip("/")
|
||||
self.model = model
|
||||
self.timeout_seconds = timeout_seconds
|
||||
self.max_attempts = max(1, max_attempts)
|
||||
self._http_client = http_client
|
||||
|
||||
async def generate_tryon_image(
|
||||
self,
|
||||
*,
|
||||
prompt: str,
|
||||
person_image: SourceImage,
|
||||
garment_image: SourceImage,
|
||||
) -> GeneratedImageResult:
|
||||
"""Generate a try-on image from a prepared person image and a garment image."""
|
||||
|
||||
return await self._generate_two_image_edit(
|
||||
prompt=prompt,
|
||||
first_image=person_image,
|
||||
second_image=garment_image,
|
||||
)
|
||||
|
||||
async def generate_scene_image(
|
||||
self,
|
||||
*,
|
||||
prompt: str,
|
||||
source_image: SourceImage,
|
||||
scene_image: SourceImage,
|
||||
) -> GeneratedImageResult:
|
||||
"""Generate a scene-composited image from a rendered subject image and a scene reference."""
|
||||
|
||||
return await self._generate_two_image_edit(
|
||||
prompt=prompt,
|
||||
first_image=source_image,
|
||||
second_image=scene_image,
|
||||
)
|
||||
|
||||
async def _generate_two_image_edit(
|
||||
self,
|
||||
*,
|
||||
prompt: str,
|
||||
first_image: SourceImage,
|
||||
second_image: SourceImage,
|
||||
) -> GeneratedImageResult:
|
||||
"""Generate an edited image from a prompt plus two inline image inputs."""
|
||||
|
||||
payload = {
|
||||
"contents": [
|
||||
{
|
||||
"parts": [
|
||||
{"text": prompt},
|
||||
{
|
||||
"inline_data": {
|
||||
"mime_type": first_image.mime_type,
|
||||
"data": base64.b64encode(first_image.data).decode("utf-8"),
|
||||
}
|
||||
},
|
||||
{
|
||||
"inline_data": {
|
||||
"mime_type": second_image.mime_type,
|
||||
"data": base64.b64encode(second_image.data).decode("utf-8"),
|
||||
}
|
||||
},
|
||||
]
|
||||
}
|
||||
],
|
||||
"generationConfig": {
|
||||
"responseModalities": ["TEXT", "IMAGE"],
|
||||
},
|
||||
}
|
||||
|
||||
owns_client = self._http_client is None
|
||||
client = self._http_client or httpx.AsyncClient(timeout=self.timeout_seconds)
|
||||
try:
|
||||
response = None
|
||||
for attempt in range(1, self.max_attempts + 1):
|
||||
try:
|
||||
response = await client.post(
|
||||
f"{self.base_url}/models/{self.model}:generateContent",
|
||||
headers={
|
||||
"x-goog-api-key": self.api_key,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
json=payload,
|
||||
timeout=self.timeout_seconds,
|
||||
)
|
||||
response.raise_for_status()
|
||||
break
|
||||
except httpx.TransportError:
|
||||
if attempt >= self.max_attempts:
|
||||
raise
|
||||
finally:
|
||||
if owns_client:
|
||||
await client.aclose()
|
||||
|
||||
if response is None:
|
||||
raise RuntimeError("Gemini provider did not receive a response")
|
||||
|
||||
body = response.json()
|
||||
image_part = self._find_image_part(body)
|
||||
image_data = image_part.get("inlineData") or image_part.get("inline_data")
|
||||
mime_type = image_data.get("mimeType") or image_data.get("mime_type") or "image/png"
|
||||
data = image_data.get("data")
|
||||
if not data:
|
||||
raise ValueError("Gemini response did not include image bytes")
|
||||
|
||||
return GeneratedImageResult(
|
||||
image_bytes=base64.b64decode(data),
|
||||
mime_type=mime_type,
|
||||
provider=self.provider_name,
|
||||
model=self.model,
|
||||
prompt=prompt,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _find_image_part(body: dict) -> dict:
|
||||
candidates = body.get("candidates") or []
|
||||
for candidate in candidates:
|
||||
content = candidate.get("content") or {}
|
||||
for part in content.get("parts") or []:
|
||||
if part.get("inlineData") or part.get("inline_data"):
|
||||
return part
|
||||
raise ValueError("Gemini response did not contain an image part")
|
||||
107
app/infra/storage/s3.py
Normal file
107
app/infra/storage/s3.py
Normal file
@@ -0,0 +1,107 @@
|
||||
"""S3 direct-upload helpers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from uuid import uuid4
|
||||
|
||||
import boto3
|
||||
|
||||
from app.config.settings import get_settings
|
||||
from app.domain.enums import LibraryResourceType, WorkflowStepName
|
||||
|
||||
RESOURCE_PREFIXES: dict[LibraryResourceType, str] = {
|
||||
LibraryResourceType.MODEL: "models",
|
||||
LibraryResourceType.SCENE: "scenes",
|
||||
LibraryResourceType.GARMENT: "garments",
|
||||
}
|
||||
|
||||
|
||||
class S3PresignService:
|
||||
"""Generate presigned upload URLs and derived public URLs."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.settings = get_settings()
|
||||
self._client = boto3.client(
|
||||
"s3",
|
||||
region_name=self.settings.s3_region or None,
|
||||
endpoint_url=self.settings.s3_endpoint or None,
|
||||
aws_access_key_id=self.settings.s3_access_key or None,
|
||||
aws_secret_access_key=self.settings.s3_secret_key or None,
|
||||
)
|
||||
|
||||
def create_upload(self, resource_type: LibraryResourceType, file_name: str, content_type: str) -> tuple[str, str]:
|
||||
"""Return a storage key and presigned PUT URL for a resource file."""
|
||||
|
||||
storage_key = self._build_storage_key(resource_type, file_name)
|
||||
upload_url = self._client.generate_presigned_url(
|
||||
"put_object",
|
||||
Params={
|
||||
"Bucket": self.settings.s3_bucket,
|
||||
"Key": storage_key,
|
||||
"ContentType": content_type,
|
||||
},
|
||||
ExpiresIn=self.settings.s3_presign_expiry_seconds,
|
||||
HttpMethod="PUT",
|
||||
)
|
||||
return storage_key, upload_url
|
||||
|
||||
def get_public_url(self, storage_key: str) -> str:
|
||||
"""Return the public CDN URL for an uploaded object."""
|
||||
|
||||
if self.settings.s3_cname:
|
||||
base = self.settings.s3_cname
|
||||
if not base.startswith("http://") and not base.startswith("https://"):
|
||||
base = f"https://{base}"
|
||||
return f"{base.rstrip('/')}/{storage_key}"
|
||||
|
||||
endpoint = self.settings.s3_endpoint.rstrip("/")
|
||||
return f"{endpoint}/{self.settings.s3_bucket}/{storage_key}"
|
||||
|
||||
def _build_storage_key(self, resource_type: LibraryResourceType, file_name: str) -> str:
|
||||
suffix = Path(file_name).suffix or ".bin"
|
||||
stem = Path(file_name).stem.replace(" ", "-").lower() or "file"
|
||||
return f"library/{RESOURCE_PREFIXES[resource_type]}/{uuid4().hex}-{stem}{suffix.lower()}"
|
||||
|
||||
|
||||
class S3ObjectStorageService:
|
||||
"""Upload generated workflow artifacts to the configured object store."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.settings = get_settings()
|
||||
self._client = boto3.client(
|
||||
"s3",
|
||||
region_name=self.settings.s3_region or None,
|
||||
endpoint_url=self.settings.s3_endpoint or None,
|
||||
aws_access_key_id=self.settings.s3_access_key or None,
|
||||
aws_secret_access_key=self.settings.s3_secret_key or None,
|
||||
)
|
||||
self._presign = S3PresignService()
|
||||
|
||||
async def upload_generated_image(
|
||||
self,
|
||||
*,
|
||||
order_id: int,
|
||||
step_name: WorkflowStepName,
|
||||
image_bytes: bytes,
|
||||
mime_type: str,
|
||||
) -> tuple[str, str]:
|
||||
"""Upload bytes and return the storage key plus public URL."""
|
||||
|
||||
storage_key = self._build_storage_key(order_id=order_id, step_name=step_name, mime_type=mime_type)
|
||||
self._client.put_object(
|
||||
Bucket=self.settings.s3_bucket,
|
||||
Key=storage_key,
|
||||
Body=image_bytes,
|
||||
ContentType=mime_type,
|
||||
)
|
||||
return storage_key, self._presign.get_public_url(storage_key)
|
||||
|
||||
@staticmethod
|
||||
def _build_storage_key(*, order_id: int, step_name: WorkflowStepName, mime_type: str) -> str:
|
||||
suffix = {
|
||||
"image/png": ".png",
|
||||
"image/jpeg": ".jpg",
|
||||
"image/webp": ".webp",
|
||||
}.get(mime_type, ".bin")
|
||||
return f"orders/{order_id}/{step_name.value}/{uuid4().hex}{suffix}"
|
||||
@@ -6,6 +6,7 @@ from fastapi import FastAPI
|
||||
|
||||
from app.api.routers.assets import router as assets_router
|
||||
from app.api.routers.health import router as health_router
|
||||
from app.api.routers.library import router as library_router
|
||||
from app.api.routers.orders import router as orders_router
|
||||
from app.api.routers.revisions import router as revisions_router
|
||||
from app.api.routers.reviews import router as reviews_router
|
||||
@@ -30,6 +31,7 @@ def create_app() -> FastAPI:
|
||||
settings = get_settings()
|
||||
app = FastAPI(title=settings.app_name, debug=settings.debug, lifespan=lifespan)
|
||||
app.include_router(health_router)
|
||||
app.include_router(library_router, prefix=settings.api_prefix)
|
||||
app.include_router(orders_router, prefix=settings.api_prefix)
|
||||
app.include_router(assets_router, prefix=settings.api_prefix)
|
||||
app.include_router(revisions_router, prefix=settings.api_prefix)
|
||||
|
||||
@@ -1,20 +1,82 @@
|
||||
"""Export mock activity."""
|
||||
"""Export activity."""
|
||||
|
||||
from app.domain.enums import AssetType, OrderStatus, StepStatus
|
||||
from app.infra.db.models.asset import AssetORM
|
||||
from app.infra.db.session import get_session_factory
|
||||
from temporalio import activity
|
||||
|
||||
from app.domain.enums import AssetType
|
||||
from app.workers.activities.tryon_activities import execute_asset_step
|
||||
from app.workers.activities.tryon_activities import create_step_record, jsonable, load_order_and_run, utc_now
|
||||
from app.workers.workflows.types import MockActivityResult, StepActivityInput
|
||||
|
||||
|
||||
@activity.defn
|
||||
async def run_export_activity(payload: StepActivityInput) -> MockActivityResult:
|
||||
"""Mock final asset export."""
|
||||
"""Finalize the chosen source asset as the order's exported deliverable."""
|
||||
|
||||
return await execute_asset_step(
|
||||
payload,
|
||||
AssetType.FINAL,
|
||||
filename="final.png",
|
||||
finalize=True,
|
||||
)
|
||||
if payload.source_asset_id is None:
|
||||
raise ValueError("run_export_activity requires source_asset_id")
|
||||
|
||||
async with get_session_factory()() as session:
|
||||
order, workflow_run = await load_order_and_run(session, payload.order_id, payload.workflow_run_id)
|
||||
step = create_step_record(payload)
|
||||
session.add(step)
|
||||
|
||||
order.status = OrderStatus.RUNNING
|
||||
workflow_run.status = OrderStatus.RUNNING
|
||||
workflow_run.current_step = payload.step_name
|
||||
await session.flush()
|
||||
|
||||
try:
|
||||
source_asset = await session.get(AssetORM, payload.source_asset_id)
|
||||
if source_asset is None:
|
||||
raise ValueError(f"Source asset {payload.source_asset_id} not found")
|
||||
if source_asset.order_id != payload.order_id:
|
||||
raise ValueError(
|
||||
f"Source asset {payload.source_asset_id} does not belong to order {payload.order_id}"
|
||||
)
|
||||
|
||||
metadata = {
|
||||
**payload.metadata,
|
||||
"source_asset_id": payload.source_asset_id,
|
||||
"selected_asset_id": payload.selected_asset_id,
|
||||
}
|
||||
metadata = {key: value for key, value in metadata.items() if value is not None}
|
||||
|
||||
asset = AssetORM(
|
||||
order_id=payload.order_id,
|
||||
asset_type=AssetType.FINAL,
|
||||
step_name=payload.step_name,
|
||||
uri=source_asset.uri,
|
||||
metadata_json=jsonable(metadata),
|
||||
)
|
||||
session.add(asset)
|
||||
await session.flush()
|
||||
|
||||
result = MockActivityResult(
|
||||
step_name=payload.step_name,
|
||||
success=True,
|
||||
asset_id=asset.id,
|
||||
uri=asset.uri,
|
||||
score=0.95,
|
||||
passed=True,
|
||||
message="mock success",
|
||||
metadata=jsonable(metadata) or {},
|
||||
)
|
||||
|
||||
order.final_asset_id = asset.id
|
||||
order.status = OrderStatus.SUCCEEDED
|
||||
workflow_run.status = OrderStatus.SUCCEEDED
|
||||
|
||||
step.step_status = StepStatus.SUCCEEDED
|
||||
step.output_json = jsonable(result)
|
||||
step.ended_at = utc_now()
|
||||
await session.commit()
|
||||
return result
|
||||
except Exception as exc:
|
||||
step.step_status = StepStatus.FAILED
|
||||
step.error_message = str(exc)
|
||||
step.ended_at = utc_now()
|
||||
order.status = OrderStatus.FAILED
|
||||
workflow_run.status = OrderStatus.FAILED
|
||||
await session.commit()
|
||||
raise
|
||||
|
||||
@@ -29,11 +29,22 @@ async def run_qc_activity(payload: StepActivityInput) -> MockActivityResult:
|
||||
candidate_uri: str | None = None
|
||||
|
||||
if passed:
|
||||
if payload.source_asset_id is None:
|
||||
raise ValueError("run_qc_activity requires source_asset_id")
|
||||
|
||||
source_asset = await session.get(AssetORM, payload.source_asset_id)
|
||||
if source_asset is None:
|
||||
raise ValueError(f"Source asset {payload.source_asset_id} not found")
|
||||
if source_asset.order_id != payload.order_id:
|
||||
raise ValueError(
|
||||
f"Source asset {payload.source_asset_id} does not belong to order {payload.order_id}"
|
||||
)
|
||||
|
||||
candidate = AssetORM(
|
||||
order_id=payload.order_id,
|
||||
asset_type=AssetType.QC_CANDIDATE,
|
||||
step_name=payload.step_name,
|
||||
uri=mock_uri(payload.order_id, payload.step_name.value, "candidate.png"),
|
||||
uri=source_asset.uri,
|
||||
metadata_json=jsonable({"source_asset_id": payload.source_asset_id}),
|
||||
)
|
||||
session.add(candidate)
|
||||
|
||||
@@ -1,19 +1,122 @@
|
||||
"""Scene mock activity."""
|
||||
"""Scene activities."""
|
||||
|
||||
from temporalio import activity
|
||||
|
||||
from app.domain.enums import AssetType
|
||||
from app.workers.activities.tryon_activities import execute_asset_step
|
||||
from app.domain.enums import AssetType, LibraryResourceType, OrderStatus, StepStatus
|
||||
from app.infra.db.models.asset import AssetORM
|
||||
from app.infra.db.session import get_session_factory
|
||||
from app.workers.activities.tryon_activities import (
|
||||
create_step_record,
|
||||
execute_asset_step,
|
||||
get_image_generation_service,
|
||||
get_order_artifact_storage_service,
|
||||
jsonable,
|
||||
load_active_library_resource,
|
||||
load_order_and_run,
|
||||
utc_now,
|
||||
)
|
||||
from app.workers.workflows.types import MockActivityResult, StepActivityInput
|
||||
|
||||
|
||||
@activity.defn
|
||||
async def run_scene_activity(payload: StepActivityInput) -> MockActivityResult:
|
||||
"""Mock scene replacement."""
|
||||
"""Generate a scene-composited asset, or fall back to mock mode when configured."""
|
||||
|
||||
return await execute_asset_step(
|
||||
payload,
|
||||
AssetType.SCENE,
|
||||
extra_metadata={"scene_ref_asset_id": payload.scene_ref_asset_id},
|
||||
)
|
||||
service = get_image_generation_service()
|
||||
if service.__class__.__name__ == "MockImageGenerationService":
|
||||
return await execute_asset_step(
|
||||
payload,
|
||||
AssetType.SCENE,
|
||||
extra_metadata={"scene_ref_asset_id": payload.scene_ref_asset_id},
|
||||
)
|
||||
|
||||
if payload.source_asset_id is None:
|
||||
raise ValueError("run_scene_activity requires source_asset_id")
|
||||
if payload.scene_ref_asset_id is None:
|
||||
raise ValueError("run_scene_activity requires scene_ref_asset_id")
|
||||
|
||||
async with get_session_factory()() as session:
|
||||
order, workflow_run = await load_order_and_run(session, payload.order_id, payload.workflow_run_id)
|
||||
step = create_step_record(payload)
|
||||
session.add(step)
|
||||
|
||||
order.status = OrderStatus.RUNNING
|
||||
workflow_run.status = OrderStatus.RUNNING
|
||||
workflow_run.current_step = payload.step_name
|
||||
await session.flush()
|
||||
|
||||
try:
|
||||
source_asset = await session.get(AssetORM, payload.source_asset_id)
|
||||
if source_asset is None:
|
||||
raise ValueError(f"Source asset {payload.source_asset_id} not found")
|
||||
if source_asset.order_id != payload.order_id:
|
||||
raise ValueError(
|
||||
f"Source asset {payload.source_asset_id} does not belong to order {payload.order_id}"
|
||||
)
|
||||
|
||||
scene_resource, scene_original = await load_active_library_resource(
|
||||
session,
|
||||
payload.scene_ref_asset_id,
|
||||
resource_type=LibraryResourceType.SCENE,
|
||||
)
|
||||
|
||||
generated = await service.generate_scene_image(
|
||||
source_image_url=source_asset.uri,
|
||||
scene_image_url=scene_original.public_url,
|
||||
)
|
||||
storage_key, public_url = await get_order_artifact_storage_service().upload_generated_image(
|
||||
order_id=payload.order_id,
|
||||
step_name=payload.step_name,
|
||||
image_bytes=generated.image_bytes,
|
||||
mime_type=generated.mime_type,
|
||||
)
|
||||
|
||||
metadata = {
|
||||
**payload.metadata,
|
||||
"source_asset_id": source_asset.id,
|
||||
"scene_ref_asset_id": payload.scene_ref_asset_id,
|
||||
"scene_resource_id": scene_resource.id,
|
||||
"scene_original_file_id": scene_original.id,
|
||||
"scene_original_url": scene_original.public_url,
|
||||
"provider": generated.provider,
|
||||
"model": generated.model,
|
||||
"storage_key": storage_key,
|
||||
"mime_type": generated.mime_type,
|
||||
"prompt": generated.prompt,
|
||||
}
|
||||
metadata = {key: value for key, value in metadata.items() if value is not None}
|
||||
|
||||
asset = AssetORM(
|
||||
order_id=payload.order_id,
|
||||
asset_type=AssetType.SCENE,
|
||||
step_name=payload.step_name,
|
||||
uri=public_url,
|
||||
metadata_json=jsonable(metadata),
|
||||
)
|
||||
session.add(asset)
|
||||
await session.flush()
|
||||
|
||||
result = MockActivityResult(
|
||||
step_name=payload.step_name,
|
||||
success=True,
|
||||
asset_id=asset.id,
|
||||
uri=asset.uri,
|
||||
score=1.0,
|
||||
passed=True,
|
||||
message="scene generated",
|
||||
metadata=jsonable(metadata) or {},
|
||||
)
|
||||
|
||||
step.step_status = StepStatus.SUCCEEDED
|
||||
step.output_json = jsonable(result)
|
||||
step.ended_at = utc_now()
|
||||
await session.commit()
|
||||
return result
|
||||
except Exception as exc:
|
||||
step.step_status = StepStatus.FAILED
|
||||
step.error_message = str(exc)
|
||||
step.ended_at = utc_now()
|
||||
order.status = OrderStatus.FAILED
|
||||
workflow_run.status = OrderStatus.FAILED
|
||||
await session.commit()
|
||||
raise
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""Prepare-model and try-on mock activities plus shared helpers."""
|
||||
"""Prepare-model and try-on activities plus shared helpers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
@@ -8,14 +8,20 @@ from enum import Enum
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
from temporalio import activity
|
||||
|
||||
from app.domain.enums import AssetType, OrderStatus, StepStatus
|
||||
from app.application.services.image_generation_service import build_image_generation_service
|
||||
from app.domain.enums import AssetType, LibraryResourceStatus, LibraryResourceType, OrderStatus, StepStatus
|
||||
from app.infra.db.models.asset import AssetORM
|
||||
from app.infra.db.models.library_resource import LibraryResourceORM
|
||||
from app.infra.db.models.library_resource_file import LibraryResourceFileORM
|
||||
from app.infra.db.models.order import OrderORM
|
||||
from app.infra.db.models.workflow_run import WorkflowRunORM
|
||||
from app.infra.db.models.workflow_step import WorkflowStepORM
|
||||
from app.infra.db.session import get_session_factory
|
||||
from app.infra.storage.s3 import S3ObjectStorageService
|
||||
from app.workers.workflows.types import MockActivityResult, StepActivityInput
|
||||
|
||||
|
||||
@@ -71,6 +77,64 @@ def create_step_record(payload: StepActivityInput) -> WorkflowStepORM:
|
||||
)
|
||||
|
||||
|
||||
async def load_active_library_resource(
|
||||
session,
|
||||
resource_id: int,
|
||||
*,
|
||||
resource_type: LibraryResourceType,
|
||||
) -> tuple[LibraryResourceORM, LibraryResourceFileORM]:
|
||||
"""Load an active library resource and its original file from the library."""
|
||||
|
||||
result = await session.execute(
|
||||
select(LibraryResourceORM)
|
||||
.options(selectinload(LibraryResourceORM.files))
|
||||
.where(LibraryResourceORM.id == resource_id)
|
||||
)
|
||||
resource = result.scalar_one_or_none()
|
||||
if resource is None:
|
||||
raise ValueError(f"Library resource {resource_id} not found")
|
||||
if resource.resource_type != resource_type:
|
||||
raise ValueError(f"Resource {resource_id} is not a {resource_type.value} resource")
|
||||
if resource.status != LibraryResourceStatus.ACTIVE:
|
||||
raise ValueError(f"Library resource {resource_id} is not active")
|
||||
if resource.original_file_id is None:
|
||||
raise ValueError(f"Library resource {resource_id} is missing an original file")
|
||||
|
||||
original_file = next((item for item in resource.files if item.id == resource.original_file_id), None)
|
||||
if original_file is None:
|
||||
raise ValueError(f"Library resource {resource_id} original file record not found")
|
||||
return resource, original_file
|
||||
|
||||
|
||||
def get_image_generation_service():
|
||||
"""Return the configured image-generation service."""
|
||||
|
||||
return build_image_generation_service()
|
||||
|
||||
|
||||
def get_order_artifact_storage_service() -> S3ObjectStorageService:
|
||||
"""Return the object-storage service for workflow-generated images."""
|
||||
|
||||
return S3ObjectStorageService()
|
||||
|
||||
|
||||
def build_resource_input_snapshot(
|
||||
resource: LibraryResourceORM,
|
||||
original_file: LibraryResourceFileORM,
|
||||
) -> dict[str, Any]:
|
||||
"""Build a frontend-friendly snapshot of one library input resource."""
|
||||
|
||||
return {
|
||||
"resource_id": resource.id,
|
||||
"resource_name": resource.name,
|
||||
"original_file_id": original_file.id,
|
||||
"original_url": original_file.public_url,
|
||||
"mime_type": original_file.mime_type,
|
||||
"width": original_file.width,
|
||||
"height": original_file.height,
|
||||
}
|
||||
|
||||
|
||||
async def execute_asset_step(
|
||||
payload: StepActivityInput,
|
||||
asset_type: AssetType,
|
||||
@@ -145,26 +209,213 @@ async def execute_asset_step(
|
||||
|
||||
@activity.defn
|
||||
async def prepare_model_activity(payload: StepActivityInput) -> MockActivityResult:
|
||||
"""Mock model preparation for the pipeline."""
|
||||
"""Resolve a model resource into an order-scoped prepared-model asset."""
|
||||
|
||||
return await execute_asset_step(
|
||||
payload,
|
||||
AssetType.PREPARED_MODEL,
|
||||
extra_metadata={
|
||||
"model_id": payload.model_id,
|
||||
"pose_id": payload.pose_id,
|
||||
"garment_asset_id": payload.garment_asset_id,
|
||||
"scene_ref_asset_id": payload.scene_ref_asset_id,
|
||||
},
|
||||
)
|
||||
if payload.model_id is None:
|
||||
raise ValueError("prepare_model_activity requires model_id")
|
||||
|
||||
async with get_session_factory()() as session:
|
||||
order, workflow_run = await load_order_and_run(session, payload.order_id, payload.workflow_run_id)
|
||||
step = create_step_record(payload)
|
||||
session.add(step)
|
||||
|
||||
order.status = OrderStatus.RUNNING
|
||||
workflow_run.status = OrderStatus.RUNNING
|
||||
workflow_run.current_step = payload.step_name
|
||||
await session.flush()
|
||||
|
||||
try:
|
||||
resource, original_file = await load_active_library_resource(
|
||||
session,
|
||||
payload.model_id,
|
||||
resource_type=LibraryResourceType.MODEL,
|
||||
)
|
||||
garment_snapshot = None
|
||||
if payload.garment_asset_id is not None:
|
||||
garment_resource, garment_original = await load_active_library_resource(
|
||||
session,
|
||||
payload.garment_asset_id,
|
||||
resource_type=LibraryResourceType.GARMENT,
|
||||
)
|
||||
garment_snapshot = build_resource_input_snapshot(garment_resource, garment_original)
|
||||
|
||||
scene_snapshot = None
|
||||
if payload.scene_ref_asset_id is not None:
|
||||
scene_resource, scene_original = await load_active_library_resource(
|
||||
session,
|
||||
payload.scene_ref_asset_id,
|
||||
resource_type=LibraryResourceType.SCENE,
|
||||
)
|
||||
scene_snapshot = build_resource_input_snapshot(scene_resource, scene_original)
|
||||
|
||||
metadata = {
|
||||
**payload.metadata,
|
||||
"model_id": payload.model_id,
|
||||
"pose_id": payload.pose_id,
|
||||
"garment_asset_id": payload.garment_asset_id,
|
||||
"scene_ref_asset_id": payload.scene_ref_asset_id,
|
||||
"library_resource_id": resource.id,
|
||||
"library_original_file_id": original_file.id,
|
||||
"library_original_url": original_file.public_url,
|
||||
"library_original_mime_type": original_file.mime_type,
|
||||
"library_original_width": original_file.width,
|
||||
"library_original_height": original_file.height,
|
||||
"model_input": build_resource_input_snapshot(resource, original_file),
|
||||
"garment_input": garment_snapshot,
|
||||
"scene_input": scene_snapshot,
|
||||
"pose_input": {"pose_id": payload.pose_id} if payload.pose_id is not None else None,
|
||||
"normalized": False,
|
||||
}
|
||||
metadata = {key: value for key, value in metadata.items() if value is not None}
|
||||
|
||||
asset = AssetORM(
|
||||
order_id=payload.order_id,
|
||||
asset_type=AssetType.PREPARED_MODEL,
|
||||
step_name=payload.step_name,
|
||||
uri=original_file.public_url,
|
||||
metadata_json=jsonable(metadata),
|
||||
)
|
||||
session.add(asset)
|
||||
await session.flush()
|
||||
|
||||
result = MockActivityResult(
|
||||
step_name=payload.step_name,
|
||||
success=True,
|
||||
asset_id=asset.id,
|
||||
uri=asset.uri,
|
||||
score=1.0,
|
||||
passed=True,
|
||||
message="prepared model ready",
|
||||
metadata=jsonable(metadata) or {},
|
||||
)
|
||||
|
||||
step.step_status = StepStatus.SUCCEEDED
|
||||
step.output_json = jsonable(result)
|
||||
step.ended_at = utc_now()
|
||||
await session.commit()
|
||||
return result
|
||||
except Exception as exc:
|
||||
step.step_status = StepStatus.FAILED
|
||||
step.error_message = str(exc)
|
||||
step.ended_at = utc_now()
|
||||
order.status = OrderStatus.FAILED
|
||||
workflow_run.status = OrderStatus.FAILED
|
||||
await session.commit()
|
||||
raise
|
||||
|
||||
|
||||
@activity.defn
|
||||
async def run_tryon_activity(payload: StepActivityInput) -> MockActivityResult:
|
||||
"""Mock try-on rendering."""
|
||||
"""执行试衣渲染步骤,或在未接真实能力时走 mock 分支。
|
||||
|
||||
return await execute_asset_step(
|
||||
payload,
|
||||
AssetType.TRYON,
|
||||
extra_metadata={"prepared_asset_id": payload.source_asset_id},
|
||||
)
|
||||
流程:
|
||||
1. 读取当前配置的图片生成 service。
|
||||
2. 如果当前仍是 mock service,就退回通用的 mock 资产产出逻辑。
|
||||
3. 如果是真实模式,先读取 prepare_model_activity 产出的 prepared_model 资产。
|
||||
4. 再读取本次选中的服装资源,并解析出它的原图 URL。
|
||||
5. 把“模特准备图 + 服装原图”一起发给 provider 做试衣生成。
|
||||
6. 把生成结果上传到 S3,并落成订单内的一条 TRYON 资产。
|
||||
7. 在 metadata 里记录 provider、model、prompt 和输入来源,方便追踪排查。
|
||||
"""
|
||||
|
||||
service = get_image_generation_service()
|
||||
# 保留旧的 mock 分支,这样在没有接入真实 provider 时 workflow 也还能跑通。
|
||||
if service.__class__.__name__ == "MockImageGenerationService":
|
||||
return await execute_asset_step(
|
||||
payload,
|
||||
AssetType.TRYON,
|
||||
extra_metadata={"prepared_asset_id": payload.source_asset_id},
|
||||
)
|
||||
|
||||
if payload.source_asset_id is None:
|
||||
raise ValueError("run_tryon_activity requires source_asset_id")
|
||||
if payload.garment_asset_id is None:
|
||||
raise ValueError("run_tryon_activity requires garment_asset_id")
|
||||
|
||||
async with get_session_factory()() as session:
|
||||
order, workflow_run = await load_order_and_run(session, payload.order_id, payload.workflow_run_id)
|
||||
step = create_step_record(payload)
|
||||
session.add(step)
|
||||
|
||||
order.status = OrderStatus.RUNNING
|
||||
workflow_run.status = OrderStatus.RUNNING
|
||||
workflow_run.current_step = payload.step_name
|
||||
await session.flush()
|
||||
|
||||
try:
|
||||
# 试衣步骤的输入起点固定是上一阶段产出的 prepared_model 资产。
|
||||
prepared_asset = await session.get(AssetORM, payload.source_asset_id)
|
||||
if prepared_asset is None or prepared_asset.order_id != payload.order_id:
|
||||
raise ValueError(f"Prepared asset {payload.source_asset_id} not found for order {payload.order_id}")
|
||||
if prepared_asset.asset_type != AssetType.PREPARED_MODEL:
|
||||
raise ValueError(f"Asset {payload.source_asset_id} is not a prepared_model asset")
|
||||
|
||||
# 服装素材来自共享资源库,workflow 实际消费的是资源库里的原图。
|
||||
garment_resource, garment_original = await load_active_library_resource(
|
||||
session,
|
||||
payload.garment_asset_id,
|
||||
resource_type=LibraryResourceType.GARMENT,
|
||||
)
|
||||
|
||||
# provider 接收两张真实输入图:模特准备图 + 服装参考图。
|
||||
generated = await service.generate_tryon_image(
|
||||
person_image_url=prepared_asset.uri,
|
||||
garment_image_url=garment_original.public_url,
|
||||
)
|
||||
# workflow 生成出的结果图先上传到 S3,再登记成订单资产。
|
||||
storage_key, public_url = await get_order_artifact_storage_service().upload_generated_image(
|
||||
order_id=payload.order_id,
|
||||
step_name=payload.step_name,
|
||||
image_bytes=generated.image_bytes,
|
||||
mime_type=generated.mime_type,
|
||||
)
|
||||
|
||||
# 记录足够的追踪信息,方便回溯这张试衣图由哪些输入和哪个 provider 产出。
|
||||
metadata = {
|
||||
**payload.metadata,
|
||||
"prepared_asset_id": prepared_asset.id,
|
||||
"garment_resource_id": garment_resource.id,
|
||||
"garment_original_file_id": garment_original.id,
|
||||
"garment_original_url": garment_original.public_url,
|
||||
"provider": generated.provider,
|
||||
"model": generated.model,
|
||||
"storage_key": storage_key,
|
||||
"mime_type": generated.mime_type,
|
||||
"prompt": generated.prompt,
|
||||
}
|
||||
metadata = {key: value for key, value in metadata.items() if value is not None}
|
||||
|
||||
asset = AssetORM(
|
||||
order_id=payload.order_id,
|
||||
asset_type=AssetType.TRYON,
|
||||
step_name=payload.step_name,
|
||||
uri=public_url,
|
||||
metadata_json=jsonable(metadata),
|
||||
)
|
||||
session.add(asset)
|
||||
await session.flush()
|
||||
|
||||
result = MockActivityResult(
|
||||
step_name=payload.step_name,
|
||||
success=True,
|
||||
asset_id=asset.id,
|
||||
uri=asset.uri,
|
||||
score=1.0,
|
||||
passed=True,
|
||||
message="try-on generated",
|
||||
metadata=jsonable(metadata) or {},
|
||||
)
|
||||
|
||||
step.step_status = StepStatus.SUCCEEDED
|
||||
step.output_json = jsonable(result)
|
||||
step.ended_at = utc_now()
|
||||
await session.commit()
|
||||
return result
|
||||
except Exception as exc:
|
||||
step.step_status = StepStatus.FAILED
|
||||
step.error_message = str(exc)
|
||||
step.ended_at = utc_now()
|
||||
order.status = OrderStatus.FAILED
|
||||
workflow_run.status = OrderStatus.FAILED
|
||||
await session.commit()
|
||||
raise
|
||||
|
||||
@@ -26,14 +26,13 @@ with workflow.unsafe.imports_passed_through():
|
||||
from app.workers.activities.review_activities import mark_workflow_failed_activity
|
||||
from app.workers.activities.scene_activities import run_scene_activity
|
||||
from app.workers.activities.tryon_activities import prepare_model_activity, run_tryon_activity
|
||||
from app.workers.workflows.timeout_policy import DEFAULT_ACTIVITY_TIMEOUT, activity_timeout_for_task_queue
|
||||
from app.workers.workflows.types import (
|
||||
PipelineWorkflowInput,
|
||||
StepActivityInput,
|
||||
WorkflowFailureActivityInput,
|
||||
)
|
||||
|
||||
|
||||
ACTIVITY_TIMEOUT = timedelta(seconds=30)
|
||||
ACTIVITY_RETRY_POLICY = RetryPolicy(
|
||||
initial_interval=timedelta(seconds=1),
|
||||
backoff_coefficient=2.0,
|
||||
@@ -64,6 +63,8 @@ class LowEndPipelineWorkflow:
|
||||
try:
|
||||
# 每个步骤都通过 execute_activity 触发真正执行。
|
||||
# workflow 自己不做计算,只负责调度。
|
||||
# prepare_model 的职责是把“模特资源 + 订单上下文”整理成后续可消费的人物底图。
|
||||
# 这一步产出的 prepared.asset_id 会作为 tryon 的 source_asset_id。
|
||||
prepared = await workflow.execute_activity(
|
||||
prepare_model_activity,
|
||||
StepActivityInput(
|
||||
@@ -76,12 +77,14 @@ class LowEndPipelineWorkflow:
|
||||
scene_ref_asset_id=payload.scene_ref_asset_id,
|
||||
),
|
||||
task_queue=IMAGE_PIPELINE_CONTROL_TASK_QUEUE,
|
||||
start_to_close_timeout=ACTIVITY_TIMEOUT,
|
||||
start_to_close_timeout=activity_timeout_for_task_queue(IMAGE_PIPELINE_CONTROL_TASK_QUEUE),
|
||||
retry_policy=ACTIVITY_RETRY_POLICY,
|
||||
)
|
||||
|
||||
current_step = WorkflowStepName.TRYON
|
||||
# 下游步骤通过 source_asset_id 引用上一步生成的资产。
|
||||
# tryon 是换装主步骤:把 prepared 模特图和 garment_asset_id 组合成试衣结果。
|
||||
# 它产出的 tryon_result.asset_id 是后续 scene / qc 的基础输入。
|
||||
tryon_result = await workflow.execute_activity(
|
||||
run_tryon_activity,
|
||||
StepActivityInput(
|
||||
@@ -92,38 +95,46 @@ class LowEndPipelineWorkflow:
|
||||
garment_asset_id=payload.garment_asset_id,
|
||||
),
|
||||
task_queue=IMAGE_PIPELINE_IMAGE_GEN_TASK_QUEUE,
|
||||
start_to_close_timeout=ACTIVITY_TIMEOUT,
|
||||
start_to_close_timeout=activity_timeout_for_task_queue(IMAGE_PIPELINE_IMAGE_GEN_TASK_QUEUE),
|
||||
retry_policy=ACTIVITY_RETRY_POLICY,
|
||||
)
|
||||
|
||||
current_step = WorkflowStepName.SCENE
|
||||
scene_result = await workflow.execute_activity(
|
||||
run_scene_activity,
|
||||
StepActivityInput(
|
||||
order_id=payload.order_id,
|
||||
workflow_run_id=payload.workflow_run_id,
|
||||
step_name=WorkflowStepName.SCENE,
|
||||
source_asset_id=tryon_result.asset_id,
|
||||
scene_ref_asset_id=payload.scene_ref_asset_id,
|
||||
),
|
||||
task_queue=IMAGE_PIPELINE_IMAGE_GEN_TASK_QUEUE,
|
||||
start_to_close_timeout=ACTIVITY_TIMEOUT,
|
||||
retry_policy=ACTIVITY_RETRY_POLICY,
|
||||
)
|
||||
scene_source_asset_id = tryon_result.asset_id
|
||||
if payload.scene_ref_asset_id is not None:
|
||||
current_step = WorkflowStepName.SCENE
|
||||
# scene 是可选步骤。
|
||||
# 有场景图时,用 scene_ref_asset_id 把 tryon 结果合成到目标背景;
|
||||
# 没有场景图时,直接沿用 tryon 结果继续往下走。
|
||||
scene_result = await workflow.execute_activity(
|
||||
run_scene_activity,
|
||||
StepActivityInput(
|
||||
order_id=payload.order_id,
|
||||
workflow_run_id=payload.workflow_run_id,
|
||||
step_name=WorkflowStepName.SCENE,
|
||||
source_asset_id=tryon_result.asset_id,
|
||||
scene_ref_asset_id=payload.scene_ref_asset_id,
|
||||
),
|
||||
task_queue=IMAGE_PIPELINE_IMAGE_GEN_TASK_QUEUE,
|
||||
start_to_close_timeout=activity_timeout_for_task_queue(IMAGE_PIPELINE_IMAGE_GEN_TASK_QUEUE),
|
||||
retry_policy=ACTIVITY_RETRY_POLICY,
|
||||
)
|
||||
scene_source_asset_id = scene_result.asset_id
|
||||
|
||||
current_step = WorkflowStepName.QC
|
||||
# QC 是流程里的“闸门”。
|
||||
# 如果这里不通过,低端流程直接失败,不会再 export。
|
||||
# source_asset_id 指向当前链路里“最接近最终成品”的那张图:
|
||||
# 有场景时是 scene 结果,没有场景时就是 tryon 结果。
|
||||
qc_result = await workflow.execute_activity(
|
||||
run_qc_activity,
|
||||
StepActivityInput(
|
||||
order_id=payload.order_id,
|
||||
workflow_run_id=payload.workflow_run_id,
|
||||
step_name=WorkflowStepName.QC,
|
||||
source_asset_id=scene_result.asset_id,
|
||||
source_asset_id=scene_source_asset_id,
|
||||
),
|
||||
task_queue=IMAGE_PIPELINE_QC_TASK_QUEUE,
|
||||
start_to_close_timeout=ACTIVITY_TIMEOUT,
|
||||
start_to_close_timeout=activity_timeout_for_task_queue(IMAGE_PIPELINE_QC_TASK_QUEUE),
|
||||
retry_policy=ACTIVITY_RETRY_POLICY,
|
||||
)
|
||||
|
||||
@@ -134,16 +145,17 @@ class LowEndPipelineWorkflow:
|
||||
current_step = WorkflowStepName.EXPORT
|
||||
# candidate_asset_ids 是 QC 推荐可导出的候选资产。
|
||||
# 当前 MVP 只会返回一个候选;如果没有,就退回 scene 结果导出。
|
||||
# export 是最后的收口步骤:生成最终交付资产,并把订单状态写成 succeeded。
|
||||
final_result = await workflow.execute_activity(
|
||||
run_export_activity,
|
||||
StepActivityInput(
|
||||
order_id=payload.order_id,
|
||||
workflow_run_id=payload.workflow_run_id,
|
||||
step_name=WorkflowStepName.EXPORT,
|
||||
source_asset_id=(qc_result.candidate_asset_ids or [scene_result.asset_id])[0],
|
||||
source_asset_id=(qc_result.candidate_asset_ids or [scene_source_asset_id])[0],
|
||||
),
|
||||
task_queue=IMAGE_PIPELINE_EXPORT_TASK_QUEUE,
|
||||
start_to_close_timeout=ACTIVITY_TIMEOUT,
|
||||
start_to_close_timeout=activity_timeout_for_task_queue(IMAGE_PIPELINE_EXPORT_TASK_QUEUE),
|
||||
retry_policy=ACTIVITY_RETRY_POLICY,
|
||||
)
|
||||
return {
|
||||
@@ -177,6 +189,6 @@ class LowEndPipelineWorkflow:
|
||||
message=message,
|
||||
),
|
||||
task_queue=IMAGE_PIPELINE_CONTROL_TASK_QUEUE,
|
||||
start_to_close_timeout=ACTIVITY_TIMEOUT,
|
||||
start_to_close_timeout=DEFAULT_ACTIVITY_TIMEOUT,
|
||||
retry_policy=ACTIVITY_RETRY_POLICY,
|
||||
)
|
||||
|
||||
@@ -34,6 +34,7 @@ with workflow.unsafe.imports_passed_through():
|
||||
from app.workers.activities.scene_activities import run_scene_activity
|
||||
from app.workers.activities.texture_activities import run_texture_activity
|
||||
from app.workers.activities.tryon_activities import prepare_model_activity, run_tryon_activity
|
||||
from app.workers.workflows.timeout_policy import DEFAULT_ACTIVITY_TIMEOUT, activity_timeout_for_task_queue
|
||||
from app.workers.workflows.types import (
|
||||
MockActivityResult,
|
||||
PipelineWorkflowInput,
|
||||
@@ -44,8 +45,6 @@ with workflow.unsafe.imports_passed_through():
|
||||
WorkflowFailureActivityInput,
|
||||
)
|
||||
|
||||
|
||||
ACTIVITY_TIMEOUT = timedelta(seconds=30)
|
||||
ACTIVITY_RETRY_POLICY = RetryPolicy(
|
||||
initial_interval=timedelta(seconds=1),
|
||||
backoff_coefficient=2.0,
|
||||
@@ -87,6 +86,8 @@ class MidEndPipelineWorkflow:
|
||||
# current_step 用于失败时记录“最后跑到哪一步”。
|
||||
current_step = WorkflowStepName.PREPARE_MODEL
|
||||
try:
|
||||
# prepare_model / tryon / scene 这三段和低端流程共享同一套 activity,
|
||||
# 区别只在于中端流程后面还会继续进入 texture / face / fusion / review。
|
||||
prepared = await workflow.execute_activity(
|
||||
prepare_model_activity,
|
||||
StepActivityInput(
|
||||
@@ -99,11 +100,12 @@ class MidEndPipelineWorkflow:
|
||||
scene_ref_asset_id=payload.scene_ref_asset_id,
|
||||
),
|
||||
task_queue=IMAGE_PIPELINE_CONTROL_TASK_QUEUE,
|
||||
start_to_close_timeout=ACTIVITY_TIMEOUT,
|
||||
start_to_close_timeout=activity_timeout_for_task_queue(IMAGE_PIPELINE_CONTROL_TASK_QUEUE),
|
||||
retry_policy=ACTIVITY_RETRY_POLICY,
|
||||
)
|
||||
|
||||
current_step = WorkflowStepName.TRYON
|
||||
# tryon 产出基础换装图,后面的所有增强步骤都以它为起点。
|
||||
tryon_result = await workflow.execute_activity(
|
||||
run_tryon_activity,
|
||||
StepActivityInput(
|
||||
@@ -114,23 +116,37 @@ class MidEndPipelineWorkflow:
|
||||
garment_asset_id=payload.garment_asset_id,
|
||||
),
|
||||
task_queue=IMAGE_PIPELINE_IMAGE_GEN_TASK_QUEUE,
|
||||
start_to_close_timeout=ACTIVITY_TIMEOUT,
|
||||
start_to_close_timeout=activity_timeout_for_task_queue(IMAGE_PIPELINE_IMAGE_GEN_TASK_QUEUE),
|
||||
retry_policy=ACTIVITY_RETRY_POLICY,
|
||||
)
|
||||
|
||||
current_step = WorkflowStepName.SCENE
|
||||
scene_result = await self._run_scene(payload, tryon_result.asset_id)
|
||||
scene_source_asset_id = tryon_result.asset_id
|
||||
if payload.scene_ref_asset_id is not None:
|
||||
current_step = WorkflowStepName.SCENE
|
||||
# scene 仍然是可选步骤。
|
||||
# 如果存在场景图,就先完成换背景,再把结果交给 texture / fusion;
|
||||
# 如果没有场景图,就直接把 tryon 结果当成“场景基底”继续流程。
|
||||
scene_result = await self._run_scene(payload, tryon_result.asset_id)
|
||||
scene_source_asset_id = scene_result.asset_id
|
||||
else:
|
||||
scene_result = tryon_result
|
||||
|
||||
current_step = WorkflowStepName.TEXTURE
|
||||
texture_result = await self._run_texture(payload, scene_result.asset_id)
|
||||
# texture 是复杂流独有步骤。
|
||||
# 它通常负责衣物纹理、材质细节或局部增强,输入是当前场景基底图。
|
||||
texture_result = await self._run_texture(payload, scene_source_asset_id)
|
||||
|
||||
current_step = WorkflowStepName.FACE
|
||||
# face 也是复杂流独有步骤。
|
||||
# 它在 texture 结果基础上做脸部修复/增强,产出供 fusion 合成的人像版本。
|
||||
face_result = await self._run_face(payload, texture_result.asset_id)
|
||||
|
||||
current_step = WorkflowStepName.FUSION
|
||||
fusion_result = await self._run_fusion(payload, scene_result.asset_id, face_result.asset_id)
|
||||
# fusion 负责把“场景基底”和“脸部增强结果”合成成候选成品。
|
||||
fusion_result = await self._run_fusion(payload, scene_source_asset_id, face_result.asset_id)
|
||||
|
||||
current_step = WorkflowStepName.QC
|
||||
# QC 和低端流程复用同一套 activity,但这里检查的是 fusion 产物。
|
||||
qc_result = await self._run_qc(payload, fusion_result.asset_id)
|
||||
if not qc_result.passed:
|
||||
await self._mark_failed(payload, current_step, qc_result.message)
|
||||
@@ -144,6 +160,8 @@ class MidEndPipelineWorkflow:
|
||||
current_step = WorkflowStepName.REVIEW
|
||||
# 这里通过 activity 把数据库里的订单状态更新成 waiting_review,
|
||||
# 同时创建 review_task,供 API 查询待审核列表。
|
||||
# review 是复杂流真正区别于低端流程的核心:
|
||||
# 在 export 前插入人工决策点,并支持按意见回流重跑。
|
||||
await workflow.execute_activity(
|
||||
mark_waiting_for_review_activity,
|
||||
ReviewWaitActivityInput(
|
||||
@@ -152,7 +170,7 @@ class MidEndPipelineWorkflow:
|
||||
candidate_asset_ids=qc_result.candidate_asset_ids,
|
||||
),
|
||||
task_queue=IMAGE_PIPELINE_CONTROL_TASK_QUEUE,
|
||||
start_to_close_timeout=ACTIVITY_TIMEOUT,
|
||||
start_to_close_timeout=activity_timeout_for_task_queue(IMAGE_PIPELINE_CONTROL_TASK_QUEUE),
|
||||
retry_policy=ACTIVITY_RETRY_POLICY,
|
||||
)
|
||||
|
||||
@@ -172,7 +190,7 @@ class MidEndPipelineWorkflow:
|
||||
comment=review_payload.comment,
|
||||
),
|
||||
task_queue=IMAGE_PIPELINE_CONTROL_TASK_QUEUE,
|
||||
start_to_close_timeout=ACTIVITY_TIMEOUT,
|
||||
start_to_close_timeout=activity_timeout_for_task_queue(IMAGE_PIPELINE_CONTROL_TASK_QUEUE),
|
||||
retry_policy=ACTIVITY_RETRY_POLICY,
|
||||
)
|
||||
|
||||
@@ -180,6 +198,7 @@ class MidEndPipelineWorkflow:
|
||||
current_step = WorkflowStepName.EXPORT
|
||||
# 如果审核人显式选了资产,就导出该资产;
|
||||
# 否则默认导出 QC 候选资产。
|
||||
# export 本身仍然和低端流程是同一个 activity。
|
||||
export_source_id = review_payload.selected_asset_id
|
||||
if export_source_id is None:
|
||||
export_source_id = (qc_result.candidate_asset_ids or [fusion_result.asset_id])[0]
|
||||
@@ -192,7 +211,7 @@ class MidEndPipelineWorkflow:
|
||||
source_asset_id=export_source_id,
|
||||
),
|
||||
task_queue=IMAGE_PIPELINE_EXPORT_TASK_QUEUE,
|
||||
start_to_close_timeout=ACTIVITY_TIMEOUT,
|
||||
start_to_close_timeout=activity_timeout_for_task_queue(IMAGE_PIPELINE_EXPORT_TASK_QUEUE),
|
||||
retry_policy=ACTIVITY_RETRY_POLICY,
|
||||
)
|
||||
return {
|
||||
@@ -208,22 +227,30 @@ class MidEndPipelineWorkflow:
|
||||
# rerun 的核心思想是:
|
||||
# 把指定节点后的链路重新跑一遍,然后再次进入 QC 和 waiting_review。
|
||||
if review_payload.decision == ReviewDecision.RERUN_SCENE:
|
||||
current_step = WorkflowStepName.SCENE
|
||||
scene_result = await self._run_scene(payload, tryon_result.asset_id)
|
||||
# 从 scene 开始重跑意味着 scene / texture / face / fusion 全部重算。
|
||||
if payload.scene_ref_asset_id is not None:
|
||||
current_step = WorkflowStepName.SCENE
|
||||
scene_result = await self._run_scene(payload, tryon_result.asset_id)
|
||||
scene_source_asset_id = scene_result.asset_id
|
||||
else:
|
||||
scene_result = tryon_result
|
||||
scene_source_asset_id = tryon_result.asset_id
|
||||
current_step = WorkflowStepName.TEXTURE
|
||||
texture_result = await self._run_texture(payload, scene_result.asset_id)
|
||||
texture_result = await self._run_texture(payload, scene_source_asset_id)
|
||||
current_step = WorkflowStepName.FACE
|
||||
face_result = await self._run_face(payload, texture_result.asset_id)
|
||||
current_step = WorkflowStepName.FUSION
|
||||
fusion_result = await self._run_fusion(payload, scene_result.asset_id, face_result.asset_id)
|
||||
fusion_result = await self._run_fusion(payload, scene_source_asset_id, face_result.asset_id)
|
||||
elif review_payload.decision == ReviewDecision.RERUN_FACE:
|
||||
# 从 face 开始重跑时,scene / texture 结果保持不变。
|
||||
current_step = WorkflowStepName.FACE
|
||||
face_result = await self._run_face(payload, texture_result.asset_id)
|
||||
current_step = WorkflowStepName.FUSION
|
||||
fusion_result = await self._run_fusion(payload, scene_result.asset_id, face_result.asset_id)
|
||||
fusion_result = await self._run_fusion(payload, scene_source_asset_id, face_result.asset_id)
|
||||
elif review_payload.decision == ReviewDecision.RERUN_FUSION:
|
||||
# 从 fusion 重跑是最小范围回流,只重做最终合成。
|
||||
current_step = WorkflowStepName.FUSION
|
||||
fusion_result = await self._run_fusion(payload, scene_result.asset_id, face_result.asset_id)
|
||||
fusion_result = await self._run_fusion(payload, scene_source_asset_id, face_result.asset_id)
|
||||
|
||||
current_step = WorkflowStepName.QC
|
||||
qc_result = await self._run_qc(payload, fusion_result.asset_id)
|
||||
@@ -264,7 +291,7 @@ class MidEndPipelineWorkflow:
|
||||
scene_ref_asset_id=payload.scene_ref_asset_id,
|
||||
),
|
||||
task_queue=IMAGE_PIPELINE_IMAGE_GEN_TASK_QUEUE,
|
||||
start_to_close_timeout=ACTIVITY_TIMEOUT,
|
||||
start_to_close_timeout=activity_timeout_for_task_queue(IMAGE_PIPELINE_IMAGE_GEN_TASK_QUEUE),
|
||||
retry_policy=ACTIVITY_RETRY_POLICY,
|
||||
)
|
||||
|
||||
@@ -280,7 +307,7 @@ class MidEndPipelineWorkflow:
|
||||
source_asset_id=source_asset_id,
|
||||
),
|
||||
task_queue=IMAGE_PIPELINE_POST_PROCESS_TASK_QUEUE,
|
||||
start_to_close_timeout=ACTIVITY_TIMEOUT,
|
||||
start_to_close_timeout=activity_timeout_for_task_queue(IMAGE_PIPELINE_POST_PROCESS_TASK_QUEUE),
|
||||
retry_policy=ACTIVITY_RETRY_POLICY,
|
||||
)
|
||||
|
||||
@@ -296,7 +323,7 @@ class MidEndPipelineWorkflow:
|
||||
source_asset_id=source_asset_id,
|
||||
),
|
||||
task_queue=IMAGE_PIPELINE_POST_PROCESS_TASK_QUEUE,
|
||||
start_to_close_timeout=ACTIVITY_TIMEOUT,
|
||||
start_to_close_timeout=activity_timeout_for_task_queue(IMAGE_PIPELINE_POST_PROCESS_TASK_QUEUE),
|
||||
retry_policy=ACTIVITY_RETRY_POLICY,
|
||||
)
|
||||
|
||||
@@ -318,7 +345,7 @@ class MidEndPipelineWorkflow:
|
||||
selected_asset_id=face_asset_id,
|
||||
),
|
||||
task_queue=IMAGE_PIPELINE_POST_PROCESS_TASK_QUEUE,
|
||||
start_to_close_timeout=ACTIVITY_TIMEOUT,
|
||||
start_to_close_timeout=activity_timeout_for_task_queue(IMAGE_PIPELINE_POST_PROCESS_TASK_QUEUE),
|
||||
retry_policy=ACTIVITY_RETRY_POLICY,
|
||||
)
|
||||
|
||||
@@ -334,7 +361,7 @@ class MidEndPipelineWorkflow:
|
||||
source_asset_id=source_asset_id,
|
||||
),
|
||||
task_queue=IMAGE_PIPELINE_QC_TASK_QUEUE,
|
||||
start_to_close_timeout=ACTIVITY_TIMEOUT,
|
||||
start_to_close_timeout=activity_timeout_for_task_queue(IMAGE_PIPELINE_QC_TASK_QUEUE),
|
||||
retry_policy=ACTIVITY_RETRY_POLICY,
|
||||
)
|
||||
|
||||
@@ -355,6 +382,6 @@ class MidEndPipelineWorkflow:
|
||||
message=message,
|
||||
),
|
||||
task_queue=IMAGE_PIPELINE_CONTROL_TASK_QUEUE,
|
||||
start_to_close_timeout=ACTIVITY_TIMEOUT,
|
||||
start_to_close_timeout=DEFAULT_ACTIVITY_TIMEOUT,
|
||||
retry_policy=ACTIVITY_RETRY_POLICY,
|
||||
)
|
||||
|
||||
24
app/workers/workflows/timeout_policy.py
Normal file
24
app/workers/workflows/timeout_policy.py
Normal file
@@ -0,0 +1,24 @@
|
||||
"""Shared activity timeout policy for Temporal workflows."""
|
||||
|
||||
from datetime import timedelta
|
||||
|
||||
from app.infra.temporal.task_queues import (
|
||||
IMAGE_PIPELINE_IMAGE_GEN_TASK_QUEUE,
|
||||
IMAGE_PIPELINE_POST_PROCESS_TASK_QUEUE,
|
||||
)
|
||||
|
||||
DEFAULT_ACTIVITY_TIMEOUT = timedelta(seconds=30)
|
||||
LONG_RUNNING_ACTIVITY_TIMEOUT = timedelta(minutes=5)
|
||||
|
||||
LONG_RUNNING_TASK_QUEUES = {
|
||||
IMAGE_PIPELINE_IMAGE_GEN_TASK_QUEUE,
|
||||
IMAGE_PIPELINE_POST_PROCESS_TASK_QUEUE,
|
||||
}
|
||||
|
||||
|
||||
def activity_timeout_for_task_queue(task_queue: str) -> timedelta:
|
||||
"""Return the timeout that matches the workload behind the given task queue."""
|
||||
|
||||
if task_queue in LONG_RUNNING_TASK_QUEUES:
|
||||
return LONG_RUNNING_ACTIVITY_TIMEOUT
|
||||
return DEFAULT_ACTIVITY_TIMEOUT
|
||||
@@ -32,14 +32,22 @@ class PipelineWorkflowInput:
|
||||
这是一张订单进入 workflow 时携带的最小上下文。
|
||||
"""
|
||||
|
||||
# 订单主键。整个 workflow 的所有步骤都会围绕这张订单落库、查状态。
|
||||
order_id: int
|
||||
# workflow_run 主键。用于把每一步 step 记录关联到同一次运行。
|
||||
workflow_run_id: int
|
||||
# 客户层级。决定业务侧订单归属,也会影响后续统计和展示。
|
||||
customer_level: CustomerLevel
|
||||
# 服务模式。这里直接决定启动简单流程还是复杂流程。
|
||||
service_mode: ServiceMode
|
||||
# 模特资源 ID。prepare_model 会基于它准备人物素材。
|
||||
model_id: int
|
||||
pose_id: int
|
||||
# 服装资源/资产 ID。tryon 会把它作为换装输入。
|
||||
garment_asset_id: int
|
||||
scene_ref_asset_id: int
|
||||
# 场景参考图 ID。可为空;为空时跳过 scene 步骤。
|
||||
scene_ref_asset_id: int | None = None
|
||||
# 模特姿势 ID。当前 MVP 已经放宽为可空,后续需要姿势控制时再启用。
|
||||
pose_id: int | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""在反序列化后把枚举字段修正回来。"""
|
||||
@@ -59,15 +67,25 @@ class StepActivityInput:
|
||||
- 上一步产出的 asset_id
|
||||
"""
|
||||
|
||||
# 当前步骤属于哪张订单。
|
||||
order_id: int
|
||||
# 当前步骤属于哪次 workflow 运行。
|
||||
workflow_run_id: int
|
||||
# 当前执行的步骤名,例如 tryon / qc / export。
|
||||
step_name: WorkflowStepName
|
||||
# 模特资源 ID。只有 prepare_model 等少数步骤会直接消费它。
|
||||
model_id: int | None = None
|
||||
# 姿势 ID。当前大多为透传预留字段。
|
||||
pose_id: int | None = None
|
||||
# 服装资源/资产 ID。tryon 是主要消费者。
|
||||
garment_asset_id: int | None = None
|
||||
# 场景参考图 ID。scene 步骤会用它完成换背景。
|
||||
scene_ref_asset_id: int | None = None
|
||||
# 上一步产出的资产 ID。绝大多数步骤都靠它串起资产链路。
|
||||
source_asset_id: int | None = None
|
||||
# 人工选中的资产 ID。review / fusion / export 等需要显式选图时使用。
|
||||
selected_asset_id: int | None = None
|
||||
# 额外扩展参数。给特殊步骤塞一些不值得单独建字段的临时上下文。
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
@@ -80,14 +98,23 @@ class StepActivityInput:
|
||||
class MockActivityResult:
|
||||
"""通用 mock activity 返回结构。"""
|
||||
|
||||
# 这是哪一步返回的结果,方便 workflow 和落库代码识别来源。
|
||||
step_name: WorkflowStepName
|
||||
# activity 是否执行成功。一般表示“代码执行成功”,不等于业务一定通过。
|
||||
success: bool
|
||||
# 这一步新生成的资产 ID;如果步骤不产出资产,可以为空。
|
||||
asset_id: int | None
|
||||
# 产出资产的访问地址;当前 mock 阶段通常是 mock:// URI。
|
||||
uri: str | None
|
||||
# 模型分数/质量分数之类的数值结果,非所有步骤都有。
|
||||
score: float | None = None
|
||||
# 业务是否通过。典型场景是 QC:activity 成功执行,但 passed 可能为 False。
|
||||
passed: bool | None = None
|
||||
# 给 workflow 或 API 展示的简短结果消息。
|
||||
message: str = "mock success"
|
||||
# 候选资产列表。主要给 QC / review 这类“多候选图”步骤使用。
|
||||
candidate_asset_ids: list[int] = field(default_factory=list)
|
||||
# 补充元数据。用于把步骤内部的一些上下文回传给调用方。
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
@@ -100,9 +127,13 @@ class MockActivityResult:
|
||||
class ReviewSignalPayload:
|
||||
"""API 发给中端 workflow 的审核 signal 载荷。"""
|
||||
|
||||
# 审核动作:通过 / 拒绝 / 从某一步重跑。
|
||||
decision: ReviewDecision
|
||||
# 操作审核的用户 ID,便于审计和任务归属。
|
||||
reviewer_id: int
|
||||
# 审核人最终选中的候选资产 ID;approve 时最常见。
|
||||
selected_asset_id: int | None = None
|
||||
# 审核备注,例如“脸部不自然,重跑融合”。
|
||||
comment: str | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
@@ -115,9 +146,13 @@ class ReviewSignalPayload:
|
||||
class ReviewWaitActivityInput:
|
||||
"""把流程切到 waiting_review 时传给 activity 的输入。"""
|
||||
|
||||
# 当前审核任务属于哪张订单。
|
||||
order_id: int
|
||||
# 当前审核任务属于哪次 workflow 运行。
|
||||
workflow_run_id: int
|
||||
# 提交给审核端看的候选资产列表。
|
||||
candidate_asset_ids: list[int] = field(default_factory=list)
|
||||
# 进入审核态时附带的说明文案,当前大多留空。
|
||||
comment: str | None = None
|
||||
|
||||
|
||||
@@ -125,11 +160,17 @@ class ReviewWaitActivityInput:
|
||||
class ReviewResolutionActivityInput:
|
||||
"""审核结果到达后,用于结束 waiting_review 的输入。"""
|
||||
|
||||
# 当前审核结果属于哪张订单。
|
||||
order_id: int
|
||||
# 当前审核结果属于哪次 workflow 运行。
|
||||
workflow_run_id: int
|
||||
# 审核最终决策。
|
||||
decision: ReviewDecision
|
||||
# 处理这次审核的审核人 ID。
|
||||
reviewer_id: int
|
||||
# 如果审核人明确挑了一张图,这里记录最终选择的资产 ID。
|
||||
selected_asset_id: int | None = None
|
||||
# 审核备注或重跑原因。
|
||||
comment: str | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
@@ -142,10 +183,15 @@ class ReviewResolutionActivityInput:
|
||||
class WorkflowFailureActivityInput:
|
||||
"""流程失败收尾 activity 的输入。"""
|
||||
|
||||
# 失败发生在哪张订单。
|
||||
order_id: int
|
||||
# 失败发生在哪次 workflow 运行。
|
||||
workflow_run_id: int
|
||||
# 失败停留在哪个步骤,用于更新 workflow_run.current_step。
|
||||
current_step: WorkflowStepName
|
||||
# 失败原因文本,通常来自异常或 QC 驳回消息。
|
||||
message: str
|
||||
# 要写回数据库的最终状态,默认就是 failed。
|
||||
status: OrderStatus = OrderStatus.FAILED
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
|
||||
@@ -11,6 +11,7 @@ requires-python = ">=3.11"
|
||||
dependencies = [
|
||||
"aiosqlite>=0.20,<1.0",
|
||||
"alembic>=1.13,<2.0",
|
||||
"boto3>=1.35,<2.0",
|
||||
"fastapi>=0.115,<1.0",
|
||||
"greenlet>=3.1,<4.0",
|
||||
"httpx>=0.27,<1.0",
|
||||
|
||||
@@ -20,6 +20,13 @@ async def api_runtime(tmp_path, monkeypatch):
|
||||
db_path = tmp_path / "test.db"
|
||||
monkeypatch.setenv("DATABASE_URL", f"sqlite+aiosqlite:///{db_path.as_posix()}")
|
||||
monkeypatch.setenv("AUTO_CREATE_TABLES", "true")
|
||||
monkeypatch.setenv("S3_ACCESS_KEY", "test-access")
|
||||
monkeypatch.setenv("S3_SECRET_KEY", "test-secret")
|
||||
monkeypatch.setenv("S3_BUCKET", "test-bucket")
|
||||
monkeypatch.setenv("S3_REGION", "ap-southeast-1")
|
||||
monkeypatch.setenv("S3_ENDPOINT", "https://s3.example.com")
|
||||
monkeypatch.setenv("S3_CNAME", "images.example.com")
|
||||
monkeypatch.setenv("IMAGE_GENERATION_PROVIDER", "mock")
|
||||
|
||||
get_settings.cache_clear()
|
||||
await dispose_database()
|
||||
@@ -41,4 +48,3 @@ async def api_runtime(tmp_path, monkeypatch):
|
||||
set_temporal_client(None)
|
||||
await dispose_database()
|
||||
get_settings.cache_clear()
|
||||
|
||||
|
||||
@@ -39,15 +39,15 @@ async def wait_for_step_count(client, order_id: int, step_name: str, minimum_cou
|
||||
async def create_mid_end_order(client):
|
||||
"""Create a standard semi-pro order for review-path tests."""
|
||||
|
||||
resources = await create_workflow_resources(client, include_scene=True)
|
||||
response = await client.post(
|
||||
"/api/v1/orders",
|
||||
json={
|
||||
"customer_level": "mid",
|
||||
"service_mode": "semi_pro",
|
||||
"model_id": 101,
|
||||
"pose_id": 3,
|
||||
"garment_asset_id": 9001,
|
||||
"scene_ref_asset_id": 8001,
|
||||
"model_id": resources["model"]["id"],
|
||||
"garment_asset_id": resources["garment"]["id"],
|
||||
"scene_ref_asset_id": resources["scene"]["id"],
|
||||
},
|
||||
)
|
||||
|
||||
@@ -55,6 +55,115 @@ async def create_mid_end_order(client):
|
||||
return response.json()
|
||||
|
||||
|
||||
async def create_library_resource(client, payload: dict) -> dict:
|
||||
"""Create a resource-library item for workflow integration tests."""
|
||||
|
||||
response = await client.post("/api/v1/library/resources", json=payload)
|
||||
assert response.status_code == 201
|
||||
return response.json()
|
||||
|
||||
|
||||
async def create_workflow_resources(client, *, include_scene: bool) -> dict[str, dict]:
|
||||
"""Create a minimal set of real library resources for workflow tests."""
|
||||
|
||||
model_resource = await create_library_resource(
|
||||
client,
|
||||
{
|
||||
"resource_type": "model",
|
||||
"name": "Ava Studio",
|
||||
"description": "棚拍女模特",
|
||||
"tags": ["女装", "棚拍"],
|
||||
"gender": "female",
|
||||
"age_group": "adult",
|
||||
"files": [
|
||||
{
|
||||
"file_role": "original",
|
||||
"storage_key": "library/models/ava/original.png",
|
||||
"public_url": "https://images.example.com/library/models/ava/original.png",
|
||||
"mime_type": "image/png",
|
||||
"size_bytes": 1024,
|
||||
"sort_order": 0,
|
||||
"width": 1200,
|
||||
"height": 1600,
|
||||
},
|
||||
{
|
||||
"file_role": "thumbnail",
|
||||
"storage_key": "library/models/ava/thumb.png",
|
||||
"public_url": "https://images.example.com/library/models/ava/thumb.png",
|
||||
"mime_type": "image/png",
|
||||
"size_bytes": 256,
|
||||
"sort_order": 0,
|
||||
},
|
||||
],
|
||||
},
|
||||
)
|
||||
garment_resource = await create_library_resource(
|
||||
client,
|
||||
{
|
||||
"resource_type": "garment",
|
||||
"name": "Cream Dress",
|
||||
"description": "米白色连衣裙",
|
||||
"tags": ["女装"],
|
||||
"category": "dress",
|
||||
"files": [
|
||||
{
|
||||
"file_role": "original",
|
||||
"storage_key": "library/garments/cream-dress/original.png",
|
||||
"public_url": "https://images.example.com/library/garments/cream-dress/original.png",
|
||||
"mime_type": "image/png",
|
||||
"size_bytes": 1024,
|
||||
"sort_order": 0,
|
||||
},
|
||||
{
|
||||
"file_role": "thumbnail",
|
||||
"storage_key": "library/garments/cream-dress/thumb.png",
|
||||
"public_url": "https://images.example.com/library/garments/cream-dress/thumb.png",
|
||||
"mime_type": "image/png",
|
||||
"size_bytes": 256,
|
||||
"sort_order": 0,
|
||||
},
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
resources = {
|
||||
"model": model_resource,
|
||||
"garment": garment_resource,
|
||||
}
|
||||
|
||||
if include_scene:
|
||||
resources["scene"] = await create_library_resource(
|
||||
client,
|
||||
{
|
||||
"resource_type": "scene",
|
||||
"name": "Loft Window",
|
||||
"description": "暖调室内场景",
|
||||
"tags": ["室内"],
|
||||
"environment": "indoor",
|
||||
"files": [
|
||||
{
|
||||
"file_role": "original",
|
||||
"storage_key": "library/scenes/loft/original.png",
|
||||
"public_url": "https://images.example.com/library/scenes/loft/original.png",
|
||||
"mime_type": "image/png",
|
||||
"size_bytes": 1024,
|
||||
"sort_order": 0,
|
||||
},
|
||||
{
|
||||
"file_role": "thumbnail",
|
||||
"storage_key": "library/scenes/loft/thumb.png",
|
||||
"public_url": "https://images.example.com/library/scenes/loft/thumb.png",
|
||||
"mime_type": "image/png",
|
||||
"size_bytes": 256,
|
||||
"sort_order": 0,
|
||||
},
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
return resources
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_healthcheck(api_runtime):
|
||||
"""The health endpoint should always respond successfully."""
|
||||
@@ -71,15 +180,15 @@ async def test_low_end_order_completes(api_runtime):
|
||||
"""Low-end orders should run through the full automated pipeline."""
|
||||
|
||||
client, env = api_runtime
|
||||
resources = await create_workflow_resources(client, include_scene=True)
|
||||
response = await client.post(
|
||||
"/api/v1/orders",
|
||||
json={
|
||||
"customer_level": "low",
|
||||
"service_mode": "auto_basic",
|
||||
"model_id": 101,
|
||||
"pose_id": 3,
|
||||
"garment_asset_id": 9001,
|
||||
"scene_ref_asset_id": 8001,
|
||||
"model_id": resources["model"]["id"],
|
||||
"garment_asset_id": resources["garment"]["id"],
|
||||
"scene_ref_asset_id": resources["scene"]["id"],
|
||||
},
|
||||
)
|
||||
|
||||
@@ -105,6 +214,160 @@ async def test_low_end_order_completes(api_runtime):
|
||||
assert workflow_response.json()["workflow_status"] == "succeeded"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prepare_model_step_uses_library_original_file(api_runtime):
|
||||
"""prepare_model should use the model resource's original file rather than a mock URI."""
|
||||
|
||||
def expected_input_snapshot(resource: dict) -> dict:
|
||||
original_file = next(file for file in resource["files"] if file["file_role"] == "original")
|
||||
snapshot = {
|
||||
"resource_id": resource["id"],
|
||||
"resource_name": resource["name"],
|
||||
"original_file_id": original_file["id"],
|
||||
"original_url": resource["original_url"],
|
||||
"mime_type": original_file["mime_type"],
|
||||
"width": original_file["width"],
|
||||
"height": original_file["height"],
|
||||
}
|
||||
return {key: value for key, value in snapshot.items() if value is not None}
|
||||
|
||||
client, env = api_runtime
|
||||
resources = await create_workflow_resources(client, include_scene=True)
|
||||
model_resource = resources["model"]
|
||||
garment_resource = resources["garment"]
|
||||
scene_resource = resources["scene"]
|
||||
|
||||
response = await client.post(
|
||||
"/api/v1/orders",
|
||||
json={
|
||||
"customer_level": "low",
|
||||
"service_mode": "auto_basic",
|
||||
"model_id": model_resource["id"],
|
||||
"garment_asset_id": garment_resource["id"],
|
||||
"scene_ref_asset_id": scene_resource["id"],
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 201
|
||||
payload = response.json()
|
||||
|
||||
handle = env.client.get_workflow_handle(payload["workflow_id"])
|
||||
result = await handle.result()
|
||||
assert result["status"] == "succeeded"
|
||||
|
||||
assets_response = await client.get(f"/api/v1/orders/{payload['order_id']}/assets")
|
||||
assert assets_response.status_code == 200
|
||||
prepared_asset = next(
|
||||
asset for asset in assets_response.json() if asset["asset_type"] == "prepared_model"
|
||||
)
|
||||
assert prepared_asset["uri"] == model_resource["original_url"]
|
||||
assert prepared_asset["metadata_json"]["library_resource_id"] == model_resource["id"]
|
||||
assert prepared_asset["metadata_json"]["library_original_file_id"] == next(
|
||||
file["id"] for file in model_resource["files"] if file["file_role"] == "original"
|
||||
)
|
||||
assert prepared_asset["metadata_json"]["library_original_url"] == model_resource["original_url"]
|
||||
assert prepared_asset["metadata_json"]["model_input"] == expected_input_snapshot(model_resource)
|
||||
assert prepared_asset["metadata_json"]["garment_input"] == expected_input_snapshot(garment_resource)
|
||||
assert prepared_asset["metadata_json"]["scene_input"] == expected_input_snapshot(scene_resource)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_low_end_order_completes_without_scene(api_runtime):
|
||||
"""Low-end orders should still complete when scene input is omitted."""
|
||||
|
||||
client, env = api_runtime
|
||||
resources = await create_workflow_resources(client, include_scene=False)
|
||||
response = await client.post(
|
||||
"/api/v1/orders",
|
||||
json={
|
||||
"customer_level": "low",
|
||||
"service_mode": "auto_basic",
|
||||
"model_id": resources["model"]["id"],
|
||||
"garment_asset_id": resources["garment"]["id"],
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 201
|
||||
payload = response.json()
|
||||
|
||||
handle = env.client.get_workflow_handle(payload["workflow_id"])
|
||||
result = await handle.result()
|
||||
|
||||
assert result["status"] == "succeeded"
|
||||
|
||||
order_response = await client.get(f"/api/v1/orders/{payload['order_id']}")
|
||||
assert order_response.status_code == 200
|
||||
assert order_response.json()["scene_ref_asset_id"] is None
|
||||
assert order_response.json()["status"] == "succeeded"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_low_end_order_reuses_tryon_uri_for_qc_and_final_without_scene(api_runtime):
|
||||
"""Without scene input, qc/export should keep the real try-on image URI."""
|
||||
|
||||
client, env = api_runtime
|
||||
resources = await create_workflow_resources(client, include_scene=False)
|
||||
response = await client.post(
|
||||
"/api/v1/orders",
|
||||
json={
|
||||
"customer_level": "low",
|
||||
"service_mode": "auto_basic",
|
||||
"model_id": resources["model"]["id"],
|
||||
"garment_asset_id": resources["garment"]["id"],
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 201
|
||||
payload = response.json()
|
||||
|
||||
handle = env.client.get_workflow_handle(payload["workflow_id"])
|
||||
result = await handle.result()
|
||||
assert result["status"] == "succeeded"
|
||||
|
||||
assets_response = await client.get(f"/api/v1/orders/{payload['order_id']}/assets")
|
||||
assert assets_response.status_code == 200
|
||||
assets = assets_response.json()
|
||||
|
||||
tryon_asset = next(asset for asset in assets if asset["asset_type"] == "tryon")
|
||||
qc_asset = next(asset for asset in assets if asset["asset_type"] == "qc_candidate")
|
||||
final_asset = next(asset for asset in assets if asset["asset_type"] == "final")
|
||||
|
||||
assert qc_asset["uri"] == tryon_asset["uri"]
|
||||
assert final_asset["uri"] == tryon_asset["uri"]
|
||||
assert final_asset["metadata_json"]["source_asset_id"] == qc_asset["id"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mid_end_order_waits_review_then_approves_without_scene(api_runtime):
|
||||
"""Mid-end orders should still reach review and approve when scene input is omitted."""
|
||||
|
||||
client, env = api_runtime
|
||||
resources = await create_workflow_resources(client, include_scene=False)
|
||||
response = await client.post(
|
||||
"/api/v1/orders",
|
||||
json={
|
||||
"customer_level": "mid",
|
||||
"service_mode": "semi_pro",
|
||||
"model_id": resources["model"]["id"],
|
||||
"garment_asset_id": resources["garment"]["id"],
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 201
|
||||
payload = response.json()
|
||||
|
||||
await wait_for_workflow_status(client, payload["order_id"], "waiting_review")
|
||||
|
||||
review_response = await client.post(
|
||||
f"/api/v1/reviews/{payload['order_id']}/submit",
|
||||
json={"decision": "approve", "reviewer_id": 77, "comment": "通过"},
|
||||
)
|
||||
assert review_response.status_code == 200
|
||||
|
||||
handle = env.client.get_workflow_handle(payload["workflow_id"])
|
||||
result = await handle.result()
|
||||
assert result["status"] == "succeeded"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mid_end_order_waits_review_then_approves(api_runtime):
|
||||
"""Mid-end orders should pause for review and continue after approval."""
|
||||
@@ -133,6 +396,162 @@ async def test_mid_end_order_waits_review_then_approves(api_runtime):
|
||||
assert order_response.json()["status"] == "succeeded"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_library_upload_presign_returns_direct_upload_metadata(api_runtime):
|
||||
"""Library upload presign should return S3 direct-upload metadata and a public URL."""
|
||||
|
||||
client, _ = api_runtime
|
||||
response = await client.post(
|
||||
"/api/v1/library/uploads/presign",
|
||||
json={
|
||||
"resource_type": "model",
|
||||
"file_role": "original",
|
||||
"file_name": "ava.png",
|
||||
"content_type": "image/png",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
payload = response.json()
|
||||
assert payload["method"] == "PUT"
|
||||
assert payload["storage_key"].startswith("library/models/")
|
||||
assert payload["public_url"].startswith("https://")
|
||||
assert "library/models/" in payload["upload_url"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_library_resource_can_be_created_and_listed(api_runtime):
|
||||
"""Creating a library resource should persist original, thumbnail, and gallery files."""
|
||||
|
||||
client, _ = api_runtime
|
||||
create_response = await client.post(
|
||||
"/api/v1/library/resources",
|
||||
json={
|
||||
"resource_type": "model",
|
||||
"name": "Ava Studio",
|
||||
"description": "棚拍女模特",
|
||||
"tags": ["女装", "棚拍"],
|
||||
"gender": "female",
|
||||
"age_group": "adult",
|
||||
"files": [
|
||||
{
|
||||
"file_role": "original",
|
||||
"storage_key": "library/models/ava/original.png",
|
||||
"public_url": "https://images.marcusd.me/library/models/ava/original.png",
|
||||
"mime_type": "image/png",
|
||||
"size_bytes": 1024,
|
||||
"sort_order": 0,
|
||||
},
|
||||
{
|
||||
"file_role": "thumbnail",
|
||||
"storage_key": "library/models/ava/thumb.png",
|
||||
"public_url": "https://images.marcusd.me/library/models/ava/thumb.png",
|
||||
"mime_type": "image/png",
|
||||
"size_bytes": 256,
|
||||
"sort_order": 0,
|
||||
},
|
||||
{
|
||||
"file_role": "gallery",
|
||||
"storage_key": "library/models/ava/gallery-1.png",
|
||||
"public_url": "https://images.marcusd.me/library/models/ava/gallery-1.png",
|
||||
"mime_type": "image/png",
|
||||
"size_bytes": 2048,
|
||||
"sort_order": 1,
|
||||
},
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
assert create_response.status_code == 201
|
||||
created = create_response.json()
|
||||
assert created["resource_type"] == "model"
|
||||
assert created["pose_id"] is None
|
||||
assert created["cover_url"] == "https://images.marcusd.me/library/models/ava/thumb.png"
|
||||
assert created["original_url"] == "https://images.marcusd.me/library/models/ava/original.png"
|
||||
assert len(created["files"]) == 3
|
||||
|
||||
list_response = await client.get("/api/v1/library/resources", params={"resource_type": "model"})
|
||||
assert list_response.status_code == 200
|
||||
listing = list_response.json()
|
||||
assert listing["total"] == 1
|
||||
assert listing["items"][0]["name"] == "Ava Studio"
|
||||
assert listing["items"][0]["gender"] == "female"
|
||||
assert listing["items"][0]["pose_id"] is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_library_resource_list_supports_type_specific_filters(api_runtime):
|
||||
"""Library listing should support resource-type-specific filter fields."""
|
||||
|
||||
client, _ = api_runtime
|
||||
|
||||
payloads = [
|
||||
{
|
||||
"resource_type": "scene",
|
||||
"name": "Loft Window",
|
||||
"description": "暖调室内场景",
|
||||
"tags": ["室内"],
|
||||
"environment": "indoor",
|
||||
"files": [
|
||||
{
|
||||
"file_role": "original",
|
||||
"storage_key": "library/scenes/loft/original.png",
|
||||
"public_url": "https://images.marcusd.me/library/scenes/loft/original.png",
|
||||
"mime_type": "image/png",
|
||||
"size_bytes": 1024,
|
||||
"sort_order": 0,
|
||||
},
|
||||
{
|
||||
"file_role": "thumbnail",
|
||||
"storage_key": "library/scenes/loft/thumb.png",
|
||||
"public_url": "https://images.marcusd.me/library/scenes/loft/thumb.png",
|
||||
"mime_type": "image/png",
|
||||
"size_bytes": 256,
|
||||
"sort_order": 0,
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
"resource_type": "scene",
|
||||
"name": "Garden Walk",
|
||||
"description": "自然光室外场景",
|
||||
"tags": ["室外"],
|
||||
"environment": "outdoor",
|
||||
"files": [
|
||||
{
|
||||
"file_role": "original",
|
||||
"storage_key": "library/scenes/garden/original.png",
|
||||
"public_url": "https://images.marcusd.me/library/scenes/garden/original.png",
|
||||
"mime_type": "image/png",
|
||||
"size_bytes": 1024,
|
||||
"sort_order": 0,
|
||||
},
|
||||
{
|
||||
"file_role": "thumbnail",
|
||||
"storage_key": "library/scenes/garden/thumb.png",
|
||||
"public_url": "https://images.marcusd.me/library/scenes/garden/thumb.png",
|
||||
"mime_type": "image/png",
|
||||
"size_bytes": 256,
|
||||
"sort_order": 0,
|
||||
},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
for payload in payloads:
|
||||
response = await client.post("/api/v1/library/resources", json=payload)
|
||||
assert response.status_code == 201
|
||||
|
||||
indoor_response = await client.get(
|
||||
"/api/v1/library/resources",
|
||||
params={"resource_type": "scene", "environment": "indoor"},
|
||||
)
|
||||
assert indoor_response.status_code == 200
|
||||
indoor_payload = indoor_response.json()
|
||||
assert indoor_payload["total"] == 1
|
||||
assert indoor_payload["items"][0]["name"] == "Loft Window"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
("decision", "expected_step"),
|
||||
@@ -312,15 +731,16 @@ async def test_orders_list_returns_recent_orders_with_revision_summary(api_runti
|
||||
|
||||
client, env = api_runtime
|
||||
|
||||
low_resources = await create_workflow_resources(client, include_scene=True)
|
||||
low_order = await client.post(
|
||||
"/api/v1/orders",
|
||||
json={
|
||||
"customer_level": "low",
|
||||
"service_mode": "auto_basic",
|
||||
"model_id": 201,
|
||||
"model_id": low_resources["model"]["id"],
|
||||
"pose_id": 11,
|
||||
"garment_asset_id": 9101,
|
||||
"scene_ref_asset_id": 8101,
|
||||
"garment_asset_id": low_resources["garment"]["id"],
|
||||
"scene_ref_asset_id": low_resources["scene"]["id"],
|
||||
},
|
||||
)
|
||||
assert low_order.status_code == 201
|
||||
@@ -394,15 +814,16 @@ async def test_workflows_list_returns_recent_runs_with_failure_count(api_runtime
|
||||
|
||||
client, env = api_runtime
|
||||
|
||||
low_resources = await create_workflow_resources(client, include_scene=True)
|
||||
low_order = await client.post(
|
||||
"/api/v1/orders",
|
||||
json={
|
||||
"customer_level": "low",
|
||||
"service_mode": "auto_basic",
|
||||
"model_id": 301,
|
||||
"model_id": low_resources["model"]["id"],
|
||||
"pose_id": 21,
|
||||
"garment_asset_id": 9201,
|
||||
"scene_ref_asset_id": 8201,
|
||||
"garment_asset_id": low_resources["garment"]["id"],
|
||||
"scene_ref_asset_id": low_resources["scene"]["id"],
|
||||
},
|
||||
)
|
||||
assert low_order.status_code == 201
|
||||
|
||||
150
tests/test_gemini_provider.py
Normal file
150
tests/test_gemini_provider.py
Normal file
@@ -0,0 +1,150 @@
|
||||
"""Tests for the Gemini image provider."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from app.infra.image_generation.base import SourceImage
|
||||
from app.infra.image_generation.gemini_provider import GeminiImageProvider
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_gemini_provider_uses_configured_base_url_and_model():
|
||||
"""The provider should call the configured endpoint and decode the returned image bytes."""
|
||||
|
||||
expected_url = "https://gemini.example.com/custom/models/gemini-test-model:generateContent"
|
||||
png_bytes = b"generated-png"
|
||||
|
||||
async def handler(request: httpx.Request) -> httpx.Response:
|
||||
assert str(request.url) == expected_url
|
||||
assert request.headers["x-goog-api-key"] == "test-key"
|
||||
payload = request.read().decode("utf-8")
|
||||
assert "gemini-test-model" not in payload
|
||||
assert "Dress the person in the garment" in payload
|
||||
assert "cGVyc29uLWJ5dGVz" in payload
|
||||
assert "Z2FybWVudC1ieXRlcw==" in payload
|
||||
return httpx.Response(
|
||||
200,
|
||||
json={
|
||||
"candidates": [
|
||||
{
|
||||
"content": {
|
||||
"parts": [
|
||||
{"text": "done"},
|
||||
{
|
||||
"inlineData": {
|
||||
"mimeType": "image/png",
|
||||
"data": base64.b64encode(png_bytes).decode("utf-8"),
|
||||
}
|
||||
},
|
||||
]
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
provider = GeminiImageProvider(
|
||||
api_key="test-key",
|
||||
base_url="https://gemini.example.com/custom",
|
||||
model="gemini-test-model",
|
||||
timeout_seconds=5,
|
||||
http_client=httpx.AsyncClient(transport=httpx.MockTransport(handler)),
|
||||
)
|
||||
|
||||
try:
|
||||
result = await provider.generate_tryon_image(
|
||||
prompt="Dress the person in the garment",
|
||||
person_image=SourceImage(
|
||||
url="https://images.example.com/person.png",
|
||||
mime_type="image/png",
|
||||
data=b"person-bytes",
|
||||
),
|
||||
garment_image=SourceImage(
|
||||
url="https://images.example.com/garment.png",
|
||||
mime_type="image/png",
|
||||
data=b"garment-bytes",
|
||||
),
|
||||
)
|
||||
finally:
|
||||
await provider._http_client.aclose()
|
||||
|
||||
assert result.image_bytes == png_bytes
|
||||
assert result.mime_type == "image/png"
|
||||
assert result.provider == "gemini"
|
||||
assert result.model == "gemini-test-model"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_gemini_provider_retries_remote_disconnect_and_uses_request_timeout():
|
||||
"""The provider should retry transient transport failures and pass timeout per request."""
|
||||
|
||||
expected_url = "https://gemini.example.com/custom/models/gemini-test-model:generateContent"
|
||||
jpeg_bytes = b"generated-jpeg"
|
||||
|
||||
class FakeClient:
|
||||
def __init__(self) -> None:
|
||||
self.attempts = 0
|
||||
self.timeouts: list[int | None] = []
|
||||
|
||||
async def post(self, url: str, **kwargs) -> httpx.Response:
|
||||
self.attempts += 1
|
||||
self.timeouts.append(kwargs.get("timeout"))
|
||||
assert url == expected_url
|
||||
if self.attempts == 1:
|
||||
raise httpx.RemoteProtocolError("Server disconnected without sending a response.")
|
||||
return httpx.Response(
|
||||
200,
|
||||
json={
|
||||
"candidates": [
|
||||
{
|
||||
"content": {
|
||||
"parts": [
|
||||
{
|
||||
"inlineData": {
|
||||
"mimeType": "image/jpeg",
|
||||
"data": base64.b64encode(jpeg_bytes).decode("utf-8"),
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
request=httpx.Request("POST", url),
|
||||
)
|
||||
|
||||
async def aclose(self) -> None:
|
||||
return None
|
||||
|
||||
fake_client = FakeClient()
|
||||
|
||||
provider = GeminiImageProvider(
|
||||
api_key="test-key",
|
||||
base_url="https://gemini.example.com/custom",
|
||||
model="gemini-test-model",
|
||||
timeout_seconds=300,
|
||||
http_client=fake_client,
|
||||
)
|
||||
|
||||
result = await provider.generate_tryon_image(
|
||||
prompt="Dress the person in the garment",
|
||||
person_image=SourceImage(
|
||||
url="https://images.example.com/person.png",
|
||||
mime_type="image/png",
|
||||
data=b"person-bytes",
|
||||
),
|
||||
garment_image=SourceImage(
|
||||
url="https://images.example.com/garment.png",
|
||||
mime_type="image/png",
|
||||
data=b"garment-bytes",
|
||||
),
|
||||
)
|
||||
|
||||
assert fake_client.attempts == 2
|
||||
assert fake_client.timeouts == [300, 300]
|
||||
assert result.image_bytes == jpeg_bytes
|
||||
assert result.mime_type == "image/jpeg"
|
||||
91
tests/test_image_generation_service.py
Normal file
91
tests/test_image_generation_service.py
Normal file
@@ -0,0 +1,91 @@
|
||||
"""Tests for application-level image-generation orchestration."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from app.application.services.image_generation_service import ImageGenerationService
|
||||
from app.infra.image_generation.base import GeneratedImageResult, SourceImage
|
||||
|
||||
|
||||
class FakeProvider:
|
||||
def __init__(self) -> None:
|
||||
self.calls: list[tuple[str, str, str]] = []
|
||||
|
||||
async def generate_tryon_image(
|
||||
self,
|
||||
*,
|
||||
prompt: str,
|
||||
person_image: SourceImage,
|
||||
garment_image: SourceImage,
|
||||
) -> GeneratedImageResult:
|
||||
self.calls.append(("tryon", person_image.url, garment_image.url))
|
||||
return GeneratedImageResult(
|
||||
image_bytes=b"tryon",
|
||||
mime_type="image/png",
|
||||
provider="gemini",
|
||||
model="gemini-test",
|
||||
prompt=prompt,
|
||||
)
|
||||
|
||||
async def generate_scene_image(
|
||||
self,
|
||||
*,
|
||||
prompt: str,
|
||||
source_image: SourceImage,
|
||||
scene_image: SourceImage,
|
||||
) -> GeneratedImageResult:
|
||||
self.calls.append(("scene", source_image.url, scene_image.url))
|
||||
return GeneratedImageResult(
|
||||
image_bytes=b"scene",
|
||||
mime_type="image/jpeg",
|
||||
provider="gemini",
|
||||
model="gemini-test",
|
||||
prompt=prompt,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_image_generation_service_downloads_inputs_for_tryon_and_scene(monkeypatch):
|
||||
"""The service should download both inputs and dispatch them to the matching provider method."""
|
||||
|
||||
responses = {
|
||||
"https://images.example.com/person.png": (b"person-bytes", "image/png"),
|
||||
"https://images.example.com/garment.png": (b"garment-bytes", "image/png"),
|
||||
"https://images.example.com/source.jpg": (b"source-bytes", "image/jpeg"),
|
||||
"https://images.example.com/scene.jpg": (b"scene-bytes", "image/jpeg"),
|
||||
}
|
||||
|
||||
async def handler(request: httpx.Request) -> httpx.Response:
|
||||
body, mime = responses[str(request.url)]
|
||||
return httpx.Response(200, content=body, headers={"content-type": mime}, request=request)
|
||||
|
||||
monkeypatch.setenv("GEMINI_API_KEY", "test-key")
|
||||
|
||||
from app.config.settings import get_settings
|
||||
|
||||
get_settings.cache_clear()
|
||||
provider = FakeProvider()
|
||||
downloader = httpx.AsyncClient(transport=httpx.MockTransport(handler))
|
||||
service = ImageGenerationService(provider=provider, downloader=downloader)
|
||||
|
||||
try:
|
||||
tryon = await service.generate_tryon_image(
|
||||
person_image_url="https://images.example.com/person.png",
|
||||
garment_image_url="https://images.example.com/garment.png",
|
||||
)
|
||||
scene = await service.generate_scene_image(
|
||||
source_image_url="https://images.example.com/source.jpg",
|
||||
scene_image_url="https://images.example.com/scene.jpg",
|
||||
)
|
||||
finally:
|
||||
await downloader.aclose()
|
||||
get_settings.cache_clear()
|
||||
|
||||
assert provider.calls == [
|
||||
("tryon", "https://images.example.com/person.png", "https://images.example.com/garment.png"),
|
||||
("scene", "https://images.example.com/source.jpg", "https://images.example.com/scene.jpg"),
|
||||
]
|
||||
assert tryon.image_bytes == b"tryon"
|
||||
assert scene.image_bytes == b"scene"
|
||||
63
tests/test_library_resource_actions.py
Normal file
63
tests/test_library_resource_actions.py
Normal file
@@ -0,0 +1,63 @@
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_library_resource_can_be_updated_and_archived(api_runtime):
|
||||
client, _ = api_runtime
|
||||
|
||||
create_response = await client.post(
|
||||
"/api/v1/library/resources",
|
||||
json={
|
||||
"resource_type": "model",
|
||||
"name": "Ava Studio",
|
||||
"description": "棚拍女模特",
|
||||
"tags": ["女装", "棚拍"],
|
||||
"gender": "female",
|
||||
"age_group": "adult",
|
||||
"files": [
|
||||
{
|
||||
"file_role": "original",
|
||||
"storage_key": "library/models/ava/original.png",
|
||||
"public_url": "https://images.marcusd.me/library/models/ava/original.png",
|
||||
"mime_type": "image/png",
|
||||
"size_bytes": 1024,
|
||||
"sort_order": 0,
|
||||
},
|
||||
{
|
||||
"file_role": "thumbnail",
|
||||
"storage_key": "library/models/ava/thumb.png",
|
||||
"public_url": "https://images.marcusd.me/library/models/ava/thumb.png",
|
||||
"mime_type": "image/png",
|
||||
"size_bytes": 256,
|
||||
"sort_order": 0,
|
||||
},
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
assert create_response.status_code == 201
|
||||
resource = create_response.json()
|
||||
|
||||
update_response = await client.patch(
|
||||
f"/api/v1/library/resources/{resource['id']}",
|
||||
json={
|
||||
"name": "Ava Studio Updated",
|
||||
"description": "新的描述",
|
||||
"tags": ["女装", "更新"],
|
||||
},
|
||||
)
|
||||
|
||||
assert update_response.status_code == 200
|
||||
updated = update_response.json()
|
||||
assert updated["name"] == "Ava Studio Updated"
|
||||
assert updated["description"] == "新的描述"
|
||||
assert updated["tags"] == ["女装", "更新"]
|
||||
|
||||
archive_response = await client.delete(f"/api/v1/library/resources/{resource['id']}")
|
||||
|
||||
assert archive_response.status_code == 200
|
||||
assert archive_response.json()["id"] == resource["id"]
|
||||
|
||||
list_response = await client.get("/api/v1/library/resources", params={"resource_type": "model"})
|
||||
assert list_response.status_code == 200
|
||||
assert list_response.json()["items"] == []
|
||||
15
tests/test_settings.py
Normal file
15
tests/test_settings.py
Normal file
@@ -0,0 +1,15 @@
|
||||
"""Tests for application settings resolution."""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from app.config.settings import Settings
|
||||
|
||||
|
||||
def test_settings_env_file_uses_backend_repo_path() -> None:
|
||||
"""Settings should resolve .env relative to the backend repo, not the launch cwd."""
|
||||
|
||||
env_file = Settings.model_config.get("env_file")
|
||||
|
||||
assert isinstance(env_file, Path)
|
||||
assert env_file.is_absolute()
|
||||
assert env_file.name == ".env"
|
||||
298
tests/test_tryon_activity.py
Normal file
298
tests/test_tryon_activity.py
Normal file
@@ -0,0 +1,298 @@
|
||||
"""Focused tests for the try-on activity implementation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import pytest
|
||||
|
||||
from app.domain.enums import AssetType, LibraryFileRole, LibraryResourceStatus, LibraryResourceType, OrderStatus, WorkflowStepName
|
||||
from app.infra.db.models.asset import AssetORM
|
||||
from app.infra.db.models.library_resource import LibraryResourceORM
|
||||
from app.infra.db.models.library_resource_file import LibraryResourceFileORM
|
||||
from app.infra.db.models.order import OrderORM
|
||||
from app.infra.db.models.workflow_run import WorkflowRunORM
|
||||
from app.infra.db.session import get_session_factory
|
||||
from app.workers.activities import tryon_activities
|
||||
from app.workers.workflows.types import StepActivityInput
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class FakeGeneratedImage:
|
||||
image_bytes: bytes
|
||||
mime_type: str
|
||||
provider: str
|
||||
model: str
|
||||
prompt: str
|
||||
|
||||
|
||||
class FakeImageGenerationService:
|
||||
async def generate_tryon_image(self, *, person_image_url: str, garment_image_url: str) -> FakeGeneratedImage:
|
||||
assert person_image_url == "https://images.example.com/orders/1/prepared-model.png"
|
||||
assert garment_image_url == "https://images.example.com/library/garments/cream-dress/original.png"
|
||||
return FakeGeneratedImage(
|
||||
image_bytes=b"fake-png-binary",
|
||||
mime_type="image/png",
|
||||
provider="gemini",
|
||||
model="gemini-test-image",
|
||||
prompt="test prompt",
|
||||
)
|
||||
|
||||
async def generate_scene_image(self, *, source_image_url: str, scene_image_url: str) -> FakeGeneratedImage:
|
||||
assert source_image_url == "https://images.example.com/orders/1/tryon/generated.png"
|
||||
assert scene_image_url == "https://images.example.com/library/scenes/studio/original.png"
|
||||
return FakeGeneratedImage(
|
||||
image_bytes=b"fake-scene-binary",
|
||||
mime_type="image/jpeg",
|
||||
provider="gemini",
|
||||
model="gemini-test-image",
|
||||
prompt="scene prompt",
|
||||
)
|
||||
|
||||
|
||||
class FakeOrderArtifactStorageService:
|
||||
async def upload_generated_image(
|
||||
self,
|
||||
*,
|
||||
order_id: int,
|
||||
step_name: WorkflowStepName,
|
||||
image_bytes: bytes,
|
||||
mime_type: str,
|
||||
) -> tuple[str, str]:
|
||||
assert order_id == 1
|
||||
assert step_name == WorkflowStepName.TRYON
|
||||
assert image_bytes == b"fake-png-binary"
|
||||
assert mime_type == "image/png"
|
||||
return (
|
||||
"orders/1/tryon/generated.png",
|
||||
"https://images.example.com/orders/1/tryon/generated.png",
|
||||
)
|
||||
|
||||
|
||||
class FakeSceneArtifactStorageService:
|
||||
async def upload_generated_image(
|
||||
self,
|
||||
*,
|
||||
order_id: int,
|
||||
step_name: WorkflowStepName,
|
||||
image_bytes: bytes,
|
||||
mime_type: str,
|
||||
) -> tuple[str, str]:
|
||||
assert order_id == 1
|
||||
assert step_name == WorkflowStepName.SCENE
|
||||
assert image_bytes == b"fake-scene-binary"
|
||||
assert mime_type == "image/jpeg"
|
||||
return (
|
||||
"orders/1/scene/generated.jpg",
|
||||
"https://images.example.com/orders/1/scene/generated.jpg",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_tryon_activity_persists_uploaded_asset_with_provider_metadata(api_runtime, monkeypatch):
|
||||
"""Gemini-mode try-on should persist the uploaded output URL instead of a mock URI."""
|
||||
|
||||
monkeypatch.setattr(
|
||||
tryon_activities,
|
||||
"get_image_generation_service",
|
||||
lambda: FakeImageGenerationService(),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
tryon_activities,
|
||||
"get_order_artifact_storage_service",
|
||||
lambda: FakeOrderArtifactStorageService(),
|
||||
)
|
||||
|
||||
async with get_session_factory()() as session:
|
||||
order = OrderORM(
|
||||
customer_level="low",
|
||||
service_mode="auto_basic",
|
||||
status=OrderStatus.CREATED,
|
||||
model_id=1,
|
||||
garment_asset_id=2,
|
||||
)
|
||||
session.add(order)
|
||||
await session.flush()
|
||||
|
||||
workflow_run = WorkflowRunORM(
|
||||
order_id=order.id,
|
||||
workflow_id=f"order-{order.id}",
|
||||
workflow_type="LowEndPipelineWorkflow",
|
||||
status=OrderStatus.CREATED,
|
||||
)
|
||||
session.add(workflow_run)
|
||||
await session.flush()
|
||||
|
||||
prepared_asset = AssetORM(
|
||||
order_id=order.id,
|
||||
asset_type=AssetType.PREPARED_MODEL,
|
||||
step_name=WorkflowStepName.PREPARE_MODEL,
|
||||
uri="https://images.example.com/orders/1/prepared-model.png",
|
||||
metadata_json={"library_resource_id": 1},
|
||||
)
|
||||
session.add(prepared_asset)
|
||||
|
||||
garment_resource = LibraryResourceORM(
|
||||
resource_type=LibraryResourceType.GARMENT,
|
||||
name="Cream Dress",
|
||||
description="米白色连衣裙",
|
||||
tags=["女装"],
|
||||
status=LibraryResourceStatus.ACTIVE,
|
||||
category="dress",
|
||||
)
|
||||
session.add(garment_resource)
|
||||
await session.flush()
|
||||
|
||||
garment_original = LibraryResourceFileORM(
|
||||
resource_id=garment_resource.id,
|
||||
file_role=LibraryFileRole.ORIGINAL,
|
||||
storage_key="library/garments/cream-dress/original.png",
|
||||
public_url="https://images.example.com/library/garments/cream-dress/original.png",
|
||||
bucket="test-bucket",
|
||||
mime_type="image/png",
|
||||
size_bytes=1024,
|
||||
sort_order=0,
|
||||
)
|
||||
garment_thumb = LibraryResourceFileORM(
|
||||
resource_id=garment_resource.id,
|
||||
file_role=LibraryFileRole.THUMBNAIL,
|
||||
storage_key="library/garments/cream-dress/thumb.png",
|
||||
public_url="https://images.example.com/library/garments/cream-dress/thumb.png",
|
||||
bucket="test-bucket",
|
||||
mime_type="image/png",
|
||||
size_bytes=256,
|
||||
sort_order=0,
|
||||
)
|
||||
session.add_all([garment_original, garment_thumb])
|
||||
await session.flush()
|
||||
garment_resource.original_file_id = garment_original.id
|
||||
garment_resource.cover_file_id = garment_thumb.id
|
||||
await session.commit()
|
||||
|
||||
payload = StepActivityInput(
|
||||
order_id=1,
|
||||
workflow_run_id=1,
|
||||
step_name=WorkflowStepName.TRYON,
|
||||
source_asset_id=1,
|
||||
garment_asset_id=1,
|
||||
)
|
||||
|
||||
result = await tryon_activities.run_tryon_activity(payload)
|
||||
|
||||
assert result.uri == "https://images.example.com/orders/1/tryon/generated.png"
|
||||
assert result.metadata["provider"] == "gemini"
|
||||
assert result.metadata["model"] == "gemini-test-image"
|
||||
assert result.metadata["prepared_asset_id"] == 1
|
||||
assert result.metadata["garment_resource_id"] == 1
|
||||
|
||||
async with get_session_factory()() as session:
|
||||
assets = (await session.execute(
|
||||
AssetORM.__table__.select().where(AssetORM.order_id == 1, AssetORM.asset_type == AssetType.TRYON)
|
||||
)).mappings().all()
|
||||
|
||||
assert len(assets) == 1
|
||||
assert assets[0]["uri"] == "https://images.example.com/orders/1/tryon/generated.png"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_scene_activity_persists_uploaded_asset_with_provider_metadata(api_runtime, monkeypatch):
|
||||
"""Gemini-mode scene should persist the uploaded output URL instead of a mock URI."""
|
||||
|
||||
from app.workers.activities import scene_activities
|
||||
|
||||
monkeypatch.setattr(
|
||||
scene_activities,
|
||||
"get_image_generation_service",
|
||||
lambda: FakeImageGenerationService(),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
scene_activities,
|
||||
"get_order_artifact_storage_service",
|
||||
lambda: FakeSceneArtifactStorageService(),
|
||||
)
|
||||
|
||||
async with get_session_factory()() as session:
|
||||
order = OrderORM(
|
||||
customer_level="low",
|
||||
service_mode="auto_basic",
|
||||
status=OrderStatus.CREATED,
|
||||
model_id=1,
|
||||
garment_asset_id=2,
|
||||
scene_ref_asset_id=3,
|
||||
)
|
||||
session.add(order)
|
||||
await session.flush()
|
||||
|
||||
workflow_run = WorkflowRunORM(
|
||||
order_id=order.id,
|
||||
workflow_id=f"order-{order.id}",
|
||||
workflow_type="LowEndPipelineWorkflow",
|
||||
status=OrderStatus.CREATED,
|
||||
)
|
||||
session.add(workflow_run)
|
||||
await session.flush()
|
||||
|
||||
tryon_asset = AssetORM(
|
||||
order_id=order.id,
|
||||
asset_type=AssetType.TRYON,
|
||||
step_name=WorkflowStepName.TRYON,
|
||||
uri="https://images.example.com/orders/1/tryon/generated.png",
|
||||
metadata_json={"prepared_asset_id": 1},
|
||||
)
|
||||
session.add(tryon_asset)
|
||||
|
||||
scene_resource = LibraryResourceORM(
|
||||
resource_type=LibraryResourceType.SCENE,
|
||||
name="Studio Background",
|
||||
description="摄影棚背景",
|
||||
tags=["室内"],
|
||||
status=LibraryResourceStatus.ACTIVE,
|
||||
environment="indoor",
|
||||
)
|
||||
session.add(scene_resource)
|
||||
await session.flush()
|
||||
|
||||
scene_original = LibraryResourceFileORM(
|
||||
resource_id=scene_resource.id,
|
||||
file_role=LibraryFileRole.ORIGINAL,
|
||||
storage_key="library/scenes/studio/original.png",
|
||||
public_url="https://images.example.com/library/scenes/studio/original.png",
|
||||
bucket="test-bucket",
|
||||
mime_type="image/png",
|
||||
size_bytes=2048,
|
||||
sort_order=0,
|
||||
)
|
||||
session.add(scene_original)
|
||||
await session.flush()
|
||||
scene_resource.original_file_id = scene_original.id
|
||||
scene_resource.cover_file_id = scene_original.id
|
||||
await session.commit()
|
||||
|
||||
payload = StepActivityInput(
|
||||
order_id=1,
|
||||
workflow_run_id=1,
|
||||
step_name=WorkflowStepName.SCENE,
|
||||
source_asset_id=1,
|
||||
scene_ref_asset_id=1,
|
||||
)
|
||||
|
||||
result = await scene_activities.run_scene_activity(payload)
|
||||
|
||||
assert result.uri == "https://images.example.com/orders/1/scene/generated.jpg"
|
||||
assert result.metadata["provider"] == "gemini"
|
||||
assert result.metadata["model"] == "gemini-test-image"
|
||||
assert result.metadata["source_asset_id"] == 1
|
||||
assert result.metadata["scene_resource_id"] == 1
|
||||
|
||||
async with get_session_factory()() as session:
|
||||
assets = (
|
||||
await session.execute(
|
||||
AssetORM.__table__.select().where(
|
||||
AssetORM.order_id == 1,
|
||||
AssetORM.asset_type == AssetType.SCENE,
|
||||
)
|
||||
)
|
||||
).mappings().all()
|
||||
|
||||
assert len(assets) == 1
|
||||
assert assets[0]["uri"] == "https://images.example.com/orders/1/scene/generated.jpg"
|
||||
33
tests/test_workflow_timeouts.py
Normal file
33
tests/test_workflow_timeouts.py
Normal file
@@ -0,0 +1,33 @@
|
||||
"""Tests for workflow activity timeout policies."""
|
||||
|
||||
from datetime import timedelta
|
||||
|
||||
from app.infra.temporal.task_queues import (
|
||||
IMAGE_PIPELINE_CONTROL_TASK_QUEUE,
|
||||
IMAGE_PIPELINE_EXPORT_TASK_QUEUE,
|
||||
IMAGE_PIPELINE_IMAGE_GEN_TASK_QUEUE,
|
||||
IMAGE_PIPELINE_POST_PROCESS_TASK_QUEUE,
|
||||
IMAGE_PIPELINE_QC_TASK_QUEUE,
|
||||
)
|
||||
from app.workers.workflows.timeout_policy import (
|
||||
DEFAULT_ACTIVITY_TIMEOUT,
|
||||
LONG_RUNNING_ACTIVITY_TIMEOUT,
|
||||
activity_timeout_for_task_queue,
|
||||
)
|
||||
|
||||
|
||||
def test_activity_timeout_for_task_queue_uses_long_timeout_for_image_work():
|
||||
"""Image generation and post-processing queues should get a longer timeout."""
|
||||
|
||||
assert DEFAULT_ACTIVITY_TIMEOUT == timedelta(seconds=30)
|
||||
assert LONG_RUNNING_ACTIVITY_TIMEOUT == timedelta(minutes=5)
|
||||
assert activity_timeout_for_task_queue(IMAGE_PIPELINE_IMAGE_GEN_TASK_QUEUE) == LONG_RUNNING_ACTIVITY_TIMEOUT
|
||||
assert activity_timeout_for_task_queue(IMAGE_PIPELINE_POST_PROCESS_TASK_QUEUE) == LONG_RUNNING_ACTIVITY_TIMEOUT
|
||||
|
||||
|
||||
def test_activity_timeout_for_task_queue_keeps_short_timeout_for_light_steps():
|
||||
"""Control, QC, and export queues should stay on the short timeout."""
|
||||
|
||||
assert activity_timeout_for_task_queue(IMAGE_PIPELINE_CONTROL_TASK_QUEUE) == DEFAULT_ACTIVITY_TIMEOUT
|
||||
assert activity_timeout_for_task_queue(IMAGE_PIPELINE_QC_TASK_QUEUE) == DEFAULT_ACTIVITY_TIMEOUT
|
||||
assert activity_timeout_for_task_queue(IMAGE_PIPELINE_EXPORT_TASK_QUEUE) == DEFAULT_ACTIVITY_TIMEOUT
|
||||
Reference in New Issue
Block a user