From 04da401ab4aa0d1a3c8dc5cb1dc7d733ee6b805c Mon Sep 17 00:00:00 2001 From: afei A <57030625+NewHubBoy@users.noreply.github.com> Date: Sun, 29 Mar 2026 00:24:29 +0800 Subject: [PATCH] feat: add resource library and real image workflow --- .env.example | 6 + alembic/env.py | 5 +- .../20260328_0003_resource_library.py | 92 ++++ .../20260328_0004_optional_pose_id.py | 25 + ...260328_0005_optional_scene_ref_asset_id.py | 25 + app/api/routers/library.py | 89 ++++ app/api/schemas/library.py | 118 +++++ app/api/schemas/order.py | 8 +- .../services/image_generation_service.py | 163 +++++++ app/application/services/library_service.py | 366 ++++++++++++++ app/config/settings.py | 20 +- app/domain/enums.py | 23 + app/domain/models/order.py | 5 +- app/infra/db/models/library_resource.py | 45 ++ app/infra/db/models/library_resource_file.py | 37 ++ app/infra/db/models/order.py | 5 +- app/infra/db/session.py | 4 +- app/infra/image_generation/base.py | 38 ++ app/infra/image_generation/gemini_provider.py | 149 ++++++ app/infra/storage/s3.py | 107 +++++ app/main.py | 2 + app/workers/activities/export_activities.py | 82 +++- app/workers/activities/qc_activities.py | 13 +- app/workers/activities/scene_activities.py | 121 ++++- app/workers/activities/tryon_activities.py | 289 ++++++++++- app/workers/workflows/low_end_pipeline.py | 58 ++- app/workers/workflows/mid_end_pipeline.py | 73 ++- app/workers/workflows/timeout_policy.py | 24 + app/workers/workflows/types.py | 50 +- pyproject.toml | 1 + tests/conftest.py | 8 +- tests/test_api.py | 449 +++++++++++++++++- tests/test_gemini_provider.py | 150 ++++++ tests/test_image_generation_service.py | 91 ++++ tests/test_library_resource_actions.py | 63 +++ tests/test_settings.py | 15 + tests/test_tryon_activity.py | 298 ++++++++++++ tests/test_workflow_timeouts.py | 33 ++ 38 files changed, 3033 insertions(+), 117 deletions(-) create mode 100644 alembic/versions/20260328_0003_resource_library.py create mode 100644 
alembic/versions/20260328_0004_optional_pose_id.py create mode 100644 alembic/versions/20260328_0005_optional_scene_ref_asset_id.py create mode 100644 app/api/routers/library.py create mode 100644 app/api/schemas/library.py create mode 100644 app/application/services/image_generation_service.py create mode 100644 app/application/services/library_service.py create mode 100644 app/infra/db/models/library_resource.py create mode 100644 app/infra/db/models/library_resource_file.py create mode 100644 app/infra/image_generation/base.py create mode 100644 app/infra/image_generation/gemini_provider.py create mode 100644 app/infra/storage/s3.py create mode 100644 app/workers/workflows/timeout_policy.py create mode 100644 tests/test_gemini_provider.py create mode 100644 tests/test_image_generation_service.py create mode 100644 tests/test_library_resource_actions.py create mode 100644 tests/test_settings.py create mode 100644 tests/test_tryon_activity.py create mode 100644 tests/test_workflow_timeouts.py diff --git a/.env.example b/.env.example index 7d3b879..41e635d 100644 --- a/.env.example +++ b/.env.example @@ -5,3 +5,9 @@ AUTO_CREATE_TABLES=true DATABASE_URL=sqlite+aiosqlite:///./temporal_demo.db TEMPORAL_ADDRESS=localhost:7233 TEMPORAL_NAMESPACE=default +IMAGE_GENERATION_PROVIDER=mock +GEMINI_API_KEY= +GEMINI_BASE_URL=https://api.museidea.com/v1beta +GEMINI_MODEL=gemini-3.1-flash-image-preview +GEMINI_TIMEOUT_SECONDS=300 +GEMINI_MAX_ATTEMPTS=2 diff --git a/alembic/env.py b/alembic/env.py index ebc97c4..51d48db 100644 --- a/alembic/env.py +++ b/alembic/env.py @@ -8,12 +8,14 @@ from sqlalchemy import engine_from_config, pool from app.config.settings import get_settings from app.infra.db.base import Base from app.infra.db.models.asset import AssetORM +from app.infra.db.models.library_resource import LibraryResourceORM +from app.infra.db.models.library_resource_file import LibraryResourceFileORM from app.infra.db.models.order import OrderORM from 
app.infra.db.models.review_task import ReviewTaskORM from app.infra.db.models.workflow_run import WorkflowRunORM from app.infra.db.models.workflow_step import WorkflowStepORM -del AssetORM, OrderORM, ReviewTaskORM, WorkflowRunORM, WorkflowStepORM +del AssetORM, LibraryResourceORM, LibraryResourceFileORM, OrderORM, ReviewTaskORM, WorkflowRunORM, WorkflowStepORM config = context.config @@ -58,4 +60,3 @@ if context.is_offline_mode(): run_migrations_offline() else: run_migrations_online() - diff --git a/alembic/versions/20260328_0003_resource_library.py b/alembic/versions/20260328_0003_resource_library.py new file mode 100644 index 0000000..c5f8a39 --- /dev/null +++ b/alembic/versions/20260328_0003_resource_library.py @@ -0,0 +1,92 @@ +"""resource library schema + +Revision ID: 20260328_0003 +Revises: 20260327_0002 +Create Date: 2026-03-28 10:20:00.000000 +""" + +from collections.abc import Sequence + +from alembic import op +import sqlalchemy as sa + + +revision: str = "20260328_0003" +down_revision: str | None = "20260327_0002" +branch_labels: Sequence[str] | None = None +depends_on: Sequence[str] | None = None + + +def upgrade() -> None: + """Create the resource-library tables.""" + + op.create_table( + "library_resources", + sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True), + sa.Column( + "resource_type", + sa.Enum("model", "scene", "garment", name="libraryresourcetype", native_enum=False), + nullable=False, + ), + sa.Column("name", sa.String(length=255), nullable=False), + sa.Column("description", sa.Text(), nullable=True), + sa.Column("tags", sa.JSON(), nullable=False), + sa.Column( + "status", + sa.Enum("active", "archived", name="libraryresourcestatus", native_enum=False), + nullable=False, + ), + sa.Column("gender", sa.String(length=32), nullable=True), + sa.Column("age_group", sa.String(length=32), nullable=True), + sa.Column("pose_id", sa.Integer(), nullable=True), + sa.Column("environment", sa.String(length=32), nullable=True), + 
sa.Column("category", sa.String(length=128), nullable=True), + sa.Column("cover_file_id", sa.Integer(), nullable=True), + sa.Column("original_file_id", sa.Integer(), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False), + ) + op.create_index("ix_library_resources_resource_type", "library_resources", ["resource_type"]) + op.create_index("ix_library_resources_status", "library_resources", ["status"]) + op.create_index("ix_library_resources_gender", "library_resources", ["gender"]) + op.create_index("ix_library_resources_age_group", "library_resources", ["age_group"]) + op.create_index("ix_library_resources_environment", "library_resources", ["environment"]) + op.create_index("ix_library_resources_category", "library_resources", ["category"]) + + op.create_table( + "library_resource_files", + sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True), + sa.Column("resource_id", sa.Integer(), sa.ForeignKey("library_resources.id"), nullable=False), + sa.Column( + "file_role", + sa.Enum("original", "thumbnail", "gallery", name="libraryfilerole", native_enum=False), + nullable=False, + ), + sa.Column("storage_key", sa.String(length=500), nullable=False), + sa.Column("public_url", sa.String(length=500), nullable=False), + sa.Column("bucket", sa.String(length=255), nullable=False), + sa.Column("mime_type", sa.String(length=255), nullable=False), + sa.Column("size_bytes", sa.Integer(), nullable=False), + sa.Column("sort_order", sa.Integer(), nullable=False), + sa.Column("width", sa.Integer(), nullable=True), + sa.Column("height", sa.Integer(), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False), + ) + op.create_index("ix_library_resource_files_resource_id", "library_resource_files", ["resource_id"]) + op.create_index("ix_library_resource_files_file_role", 
"library_resource_files", ["file_role"]) + + +def downgrade() -> None: + """Drop the resource-library tables.""" + + op.drop_index("ix_library_resource_files_file_role", table_name="library_resource_files") + op.drop_index("ix_library_resource_files_resource_id", table_name="library_resource_files") + op.drop_table("library_resource_files") + op.drop_index("ix_library_resources_category", table_name="library_resources") + op.drop_index("ix_library_resources_environment", table_name="library_resources") + op.drop_index("ix_library_resources_age_group", table_name="library_resources") + op.drop_index("ix_library_resources_gender", table_name="library_resources") + op.drop_index("ix_library_resources_status", table_name="library_resources") + op.drop_index("ix_library_resources_resource_type", table_name="library_resources") + op.drop_table("library_resources") diff --git a/alembic/versions/20260328_0004_optional_pose_id.py b/alembic/versions/20260328_0004_optional_pose_id.py new file mode 100644 index 0000000..563a3e5 --- /dev/null +++ b/alembic/versions/20260328_0004_optional_pose_id.py @@ -0,0 +1,25 @@ +"""Make order pose_id optional for MVP. 
+ +Revision ID: 20260328_0004 +Revises: 20260328_0003 +Create Date: 2026-03-28 11:08:00.000000 +""" + +from alembic import op +import sqlalchemy as sa + + +revision = "20260328_0004" +down_revision = "20260328_0003" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + with op.batch_alter_table("orders") as batch_op: + batch_op.alter_column("pose_id", existing_type=sa.Integer(), nullable=True) + + +def downgrade() -> None: + with op.batch_alter_table("orders") as batch_op: + batch_op.alter_column("pose_id", existing_type=sa.Integer(), nullable=False) diff --git a/alembic/versions/20260328_0005_optional_scene_ref_asset_id.py b/alembic/versions/20260328_0005_optional_scene_ref_asset_id.py new file mode 100644 index 0000000..cab10d1 --- /dev/null +++ b/alembic/versions/20260328_0005_optional_scene_ref_asset_id.py @@ -0,0 +1,25 @@ +"""Make order scene_ref_asset_id optional for MVP. + +Revision ID: 20260328_0005 +Revises: 20260328_0004 +Create Date: 2026-03-28 14:20:00.000000 +""" + +from alembic import op +import sqlalchemy as sa + + +revision = "20260328_0005" +down_revision = "20260328_0004" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + with op.batch_alter_table("orders") as batch_op: + batch_op.alter_column("scene_ref_asset_id", existing_type=sa.Integer(), nullable=True) + + +def downgrade() -> None: + with op.batch_alter_table("orders") as batch_op: + batch_op.alter_column("scene_ref_asset_id", existing_type=sa.Integer(), nullable=False) diff --git a/app/api/routers/library.py b/app/api/routers/library.py new file mode 100644 index 0000000..7c3fb9c --- /dev/null +++ b/app/api/routers/library.py @@ -0,0 +1,89 @@ +"""Library resource routes.""" + +from fastapi import APIRouter, Depends, Query, status +from sqlalchemy.ext.asyncio import AsyncSession + +from app.api.schemas.library import ( + ArchiveLibraryResourceResponse, + CreateLibraryResourceRequest, + LibraryResourceListResponse, + LibraryResourceRead, + 
PresignUploadRequest, + PresignUploadResponse, + UpdateLibraryResourceRequest, +) +from app.application.services.library_service import LibraryService +from app.domain.enums import LibraryResourceType +from app.infra.db.session import get_db_session + +router = APIRouter(prefix="/library", tags=["library"]) +library_service = LibraryService() + + +@router.post("/uploads/presign", response_model=PresignUploadResponse) +async def create_library_upload_presign(payload: PresignUploadRequest) -> PresignUploadResponse: + """Create direct-upload metadata for a library file.""" + + return library_service.create_upload_presign( + resource_type=payload.resource_type, + file_name=payload.file_name, + content_type=payload.content_type, + ) + + +@router.post("/resources", response_model=LibraryResourceRead, status_code=status.HTTP_201_CREATED) +async def create_library_resource( + payload: CreateLibraryResourceRequest, + session: AsyncSession = Depends(get_db_session), +) -> LibraryResourceRead: + """Create a library resource with already-uploaded file metadata.""" + + return await library_service.create_resource(session, payload) + + +@router.get("/resources", response_model=LibraryResourceListResponse) +async def list_library_resources( + resource_type: LibraryResourceType | None = Query(default=None), + query: str | None = Query(default=None, min_length=1), + gender: str | None = Query(default=None), + age_group: str | None = Query(default=None), + environment: str | None = Query(default=None), + category: str | None = Query(default=None), + page: int = Query(default=1, ge=1), + limit: int = Query(default=20, ge=1, le=100), + session: AsyncSession = Depends(get_db_session), +) -> LibraryResourceListResponse: + """List library resources with basic filtering.""" + + return await library_service.list_resources( + session, + resource_type=resource_type, + query=query, + gender=gender, + age_group=age_group, + environment=environment, + category=category, + page=page, + 
limit=limit, + ) + + +@router.patch("/resources/{resource_id}", response_model=LibraryResourceRead) +async def update_library_resource( + resource_id: int, + payload: UpdateLibraryResourceRequest, + session: AsyncSession = Depends(get_db_session), +) -> LibraryResourceRead: + """Update editable metadata on a library resource.""" + + return await library_service.update_resource(session, resource_id, payload) + + +@router.delete("/resources/{resource_id}", response_model=ArchiveLibraryResourceResponse) +async def archive_library_resource( + resource_id: int, + session: AsyncSession = Depends(get_db_session), +) -> ArchiveLibraryResourceResponse: + """Archive a library resource instead of hard deleting it.""" + + return await library_service.archive_resource(session, resource_id) diff --git a/app/api/schemas/library.py b/app/api/schemas/library.py new file mode 100644 index 0000000..0580898 --- /dev/null +++ b/app/api/schemas/library.py @@ -0,0 +1,118 @@ +"""Library resource API schemas.""" + +from datetime import datetime + +from pydantic import BaseModel, Field + +from app.domain.enums import LibraryFileRole, LibraryResourceStatus, LibraryResourceType + + +class PresignUploadRequest(BaseModel): + """Request payload for generating direct-upload metadata.""" + + resource_type: LibraryResourceType + file_role: LibraryFileRole + file_name: str + content_type: str + + +class PresignUploadResponse(BaseModel): + """Response returned when generating direct-upload metadata.""" + + method: str + upload_url: str + headers: dict[str, str] = Field(default_factory=dict) + storage_key: str + public_url: str + + +class CreateLibraryResourceFileRequest(BaseModel): + """Metadata for one already-uploaded file.""" + + file_role: LibraryFileRole + storage_key: str + public_url: str + mime_type: str + size_bytes: int + sort_order: int = 0 + width: int | None = None + height: int | None = None + + +class CreateLibraryResourceRequest(BaseModel): + """One-shot resource creation request.""" 
+ + resource_type: LibraryResourceType + name: str + description: str | None = None + tags: list[str] = Field(default_factory=list) + gender: str | None = None + age_group: str | None = None + pose_id: int | None = None + environment: str | None = None + category: str | None = None + files: list[CreateLibraryResourceFileRequest] + + +class UpdateLibraryResourceRequest(BaseModel): + """Partial update request for a library resource.""" + + name: str | None = None + description: str | None = None + tags: list[str] | None = None + gender: str | None = None + age_group: str | None = None + pose_id: int | None = None + environment: str | None = None + category: str | None = None + cover_file_id: int | None = None + + +class ArchiveLibraryResourceResponse(BaseModel): + """Response returned after archiving a resource.""" + + id: int + + +class LibraryResourceFileRead(BaseModel): + """Serialized library file metadata.""" + + id: int + file_role: LibraryFileRole + storage_key: str + public_url: str + bucket: str + mime_type: str + size_bytes: int + sort_order: int + width: int | None = None + height: int | None = None + created_at: datetime + + +class LibraryResourceRead(BaseModel): + """Serialized library resource.""" + + id: int + resource_type: LibraryResourceType + name: str + description: str | None = None + tags: list[str] + status: LibraryResourceStatus + gender: str | None = None + age_group: str | None = None + pose_id: int | None = None + environment: str | None = None + category: str | None = None + cover_url: str | None = None + original_url: str | None = None + files: list[LibraryResourceFileRead] + created_at: datetime + updated_at: datetime + + +class LibraryResourceListResponse(BaseModel): + """Paginated library resource response.""" + + total: int + items: list[LibraryResourceRead] diff --git a/app/api/schemas/order.py b/app/api/schemas/order.py index de4f252..4f3ff93 100644 --- a/app/api/schemas/order.py +++ b/app/api/schemas/order.py @@ -14,9 +14,9 @@ 
class CreateOrderRequest(BaseModel): customer_level: CustomerLevel service_mode: ServiceMode model_id: int - pose_id: int + pose_id: int | None = None garment_asset_id: int - scene_ref_asset_id: int + scene_ref_asset_id: int | None = None class CreateOrderResponse(BaseModel): @@ -35,9 +35,9 @@ class OrderDetailResponse(BaseModel): service_mode: ServiceMode status: OrderStatus model_id: int - pose_id: int + pose_id: int | None garment_asset_id: int - scene_ref_asset_id: int + scene_ref_asset_id: int | None final_asset_id: int | None workflow_id: str | None current_step: WorkflowStepName | None diff --git a/app/application/services/image_generation_service.py b/app/application/services/image_generation_service.py new file mode 100644 index 0000000..6356f40 --- /dev/null +++ b/app/application/services/image_generation_service.py @@ -0,0 +1,163 @@ +"""Application-facing orchestration for image-generation providers.""" + +from __future__ import annotations + +import base64 + +import httpx +from fastapi import HTTPException, status + +from app.config.settings import get_settings +from app.infra.image_generation.base import GeneratedImageResult, SourceImage +from app.infra.image_generation.gemini_provider import GeminiImageProvider + +TRYON_PROMPT = ( + "Use the first image as the base person image and the second image as the garment reference. " + "Generate a realistic virtual try-on result where the person in the first image is wearing the garment from " + "the second image. Preserve the person's identity, body proportions, pose, camera angle, lighting direction, " + "and background composition from the first image. Preserve the garment's key silhouette, material feel, color, " + "and visible design details from the second image. Produce a clean e-commerce quality result without adding " + "extra accessories, text, watermarks, or duplicate limbs." 
+) + +SCENE_PROMPT = ( + "Use the first image as the finished subject image and the second image as the target scene reference. " + "Generate a realistic composite where the person, clothing, body proportions, facial identity, pose, and camera " + "framing from the first image are preserved, while the environment is adapted to match the second image. " + "Keep the garment details, colors, and texture from the first image intact. Blend the subject naturally into the " + "target scene with coherent perspective, lighting, shadows, and depth. Do not add extra people, accessories, " + "text, logos, watermarks, or duplicate body parts." +) + +MOCK_PNG_BASE64 = ( + "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mP8/x8AAwMCAO5Ww4QAAAAASUVORK5CYII=" +) + + +class MockImageGenerationService: + """Deterministic fallback used in tests and when no real provider is configured.""" + + async def generate_tryon_image(self, *, person_image_url: str, garment_image_url: str) -> GeneratedImageResult: + return GeneratedImageResult( + image_bytes=base64.b64decode(MOCK_PNG_BASE64), + mime_type="image/png", + provider="mock", + model="mock-image-provider", + prompt=TRYON_PROMPT, + ) + + async def generate_scene_image(self, *, source_image_url: str, scene_image_url: str) -> GeneratedImageResult: + return GeneratedImageResult( + image_bytes=base64.b64decode(MOCK_PNG_BASE64), + mime_type="image/png", + provider="mock", + model="mock-image-provider", + prompt=SCENE_PROMPT, + ) + + +class ImageGenerationService: + """Download source assets and dispatch them to the configured provider.""" + + def __init__( + self, + *, + provider: GeminiImageProvider | None = None, + downloader: httpx.AsyncClient | None = None, + ) -> None: + self.settings = get_settings() + self.provider = provider or GeminiImageProvider( + api_key=self.settings.gemini_api_key, + base_url=self.settings.gemini_base_url, + model=self.settings.gemini_model, + timeout_seconds=self.settings.gemini_timeout_seconds, + 
max_attempts=self.settings.gemini_max_attempts, + ) + self._downloader = downloader + + async def generate_tryon_image( + self, + *, + person_image_url: str, + garment_image_url: str, + ) -> GeneratedImageResult: + """Generate a try-on image using the configured provider.""" + + if not self.settings.gemini_api_key: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="GEMINI_API_KEY is required when image_generation_provider=gemini", + ) + + person_image, garment_image = await self._download_inputs( + first_image_url=person_image_url, + second_image_url=garment_image_url, + ) + return await self.provider.generate_tryon_image( + prompt=TRYON_PROMPT, + person_image=person_image, + garment_image=garment_image, + ) + + async def generate_scene_image( + self, + *, + source_image_url: str, + scene_image_url: str, + ) -> GeneratedImageResult: + """Generate a scene-composited image using the configured provider.""" + + if not self.settings.gemini_api_key: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="GEMINI_API_KEY is required when image_generation_provider=gemini", + ) + + source_image, scene_image = await self._download_inputs( + first_image_url=source_image_url, + second_image_url=scene_image_url, + ) + return await self.provider.generate_scene_image( + prompt=SCENE_PROMPT, + source_image=source_image, + scene_image=scene_image, + ) + + async def _download_inputs( + self, + *, + first_image_url: str, + second_image_url: str, + ) -> tuple[SourceImage, SourceImage]: + owns_client = self._downloader is None + client = self._downloader or httpx.AsyncClient(timeout=self.settings.gemini_timeout_seconds) + try: + first_response = await client.get(first_image_url) + first_response.raise_for_status() + second_response = await client.get(second_image_url) + second_response.raise_for_status() + finally: + if owns_client: + await client.aclose() + + first_mime = first_response.headers.get("content-type", 
"image/png").split(";")[0].strip() + second_mime = second_response.headers.get("content-type", "image/png").split(";")[0].strip() + + return ( + SourceImage(url=first_image_url, mime_type=first_mime or "image/png", data=first_response.content), + SourceImage(url=second_image_url, mime_type=second_mime or "image/png", data=second_response.content), + ) + + +def build_image_generation_service(): + """Return the configured image-generation service.""" + + settings = get_settings() + if settings.image_generation_provider.lower() == "mock": + return MockImageGenerationService() + if settings.image_generation_provider.lower() == "gemini": + return ImageGenerationService() + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Unsupported image_generation_provider: {settings.image_generation_provider}", + ) diff --git a/app/application/services/library_service.py b/app/application/services/library_service.py new file mode 100644 index 0000000..5e6948f --- /dev/null +++ b/app/application/services/library_service.py @@ -0,0 +1,366 @@ +"""Library resource application service.""" + +from __future__ import annotations + +from fastapi import HTTPException, status +from sqlalchemy import func, or_, select +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import selectinload + +from app.api.schemas.library import ( + ArchiveLibraryResourceResponse, + CreateLibraryResourceRequest, + LibraryResourceFileRead, + LibraryResourceListResponse, + LibraryResourceRead, + PresignUploadResponse, + UpdateLibraryResourceRequest, +) +from app.config.settings import get_settings +from app.domain.enums import LibraryFileRole, LibraryResourceStatus, LibraryResourceType +from app.infra.db.models.library_resource import LibraryResourceORM +from app.infra.db.models.library_resource_file import LibraryResourceFileORM +from app.infra.storage.s3 import RESOURCE_PREFIXES, S3PresignService + + +class LibraryService: + """Application service for 
resource-library uploads and queries.""" + + def __init__(self, presign_service: S3PresignService | None = None) -> None: + self.presign_service = presign_service or S3PresignService() + + def create_upload_presign( + self, + resource_type: LibraryResourceType, + file_name: str, + content_type: str, + ) -> PresignUploadResponse: + """Create upload metadata for a direct S3 PUT.""" + + settings = get_settings() + if not all( + [ + settings.s3_bucket, + settings.s3_region, + settings.s3_access_key, + settings.s3_secret_key, + settings.s3_endpoint, + ] + ): + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="S3 upload settings are incomplete", + ) + + storage_key, upload_url = self.presign_service.create_upload( + resource_type=resource_type, + file_name=file_name, + content_type=content_type, + ) + return PresignUploadResponse( + method="PUT", + upload_url=upload_url, + headers={"content-type": content_type}, + storage_key=storage_key, + public_url=self.presign_service.get_public_url(storage_key), + ) + + async def create_resource( + self, + session: AsyncSession, + payload: CreateLibraryResourceRequest, + ) -> LibraryResourceRead: + """Persist a resource and its uploaded file metadata.""" + + self._validate_payload(payload) + + resource = LibraryResourceORM( + resource_type=payload.resource_type, + name=payload.name.strip(), + description=payload.description.strip() if payload.description else None, + tags=[tag.strip() for tag in payload.tags if tag.strip()], + status=LibraryResourceStatus.ACTIVE, + gender=payload.gender.strip() if payload.gender else None, + age_group=payload.age_group.strip() if payload.age_group else None, + pose_id=payload.pose_id, + environment=payload.environment.strip() if payload.environment else None, + category=payload.category.strip() if payload.category else None, + ) + session.add(resource) + await session.flush() + + file_models: list[LibraryResourceFileORM] = [] + for item in payload.files: + 
file_model = LibraryResourceFileORM( + resource_id=resource.id, + file_role=item.file_role, + storage_key=item.storage_key, + public_url=item.public_url, + bucket=get_settings().s3_bucket, + mime_type=item.mime_type, + size_bytes=item.size_bytes, + sort_order=item.sort_order, + width=item.width, + height=item.height, + ) + session.add(file_model) + file_models.append(file_model) + + await session.flush() + resource.original_file_id = self._find_role(file_models, LibraryFileRole.ORIGINAL).id + resource.cover_file_id = self._find_role(file_models, LibraryFileRole.THUMBNAIL).id + await session.commit() + + return self._to_read(resource, files=file_models) + + async def list_resources( + self, + session: AsyncSession, + *, + resource_type: LibraryResourceType | None = None, + query: str | None = None, + gender: str | None = None, + age_group: str | None = None, + environment: str | None = None, + category: str | None = None, + page: int = 1, + limit: int = 20, + ) -> LibraryResourceListResponse: + """List persisted resources with simple filter support.""" + + filters = [LibraryResourceORM.status == LibraryResourceStatus.ACTIVE] + if resource_type is not None: + filters.append(LibraryResourceORM.resource_type == resource_type) + if query: + like = f"%{query.strip()}%" + filters.append( + or_( + LibraryResourceORM.name.ilike(like), + LibraryResourceORM.description.ilike(like), + ) + ) + if gender: + filters.append(LibraryResourceORM.gender == gender) + if age_group: + filters.append(LibraryResourceORM.age_group == age_group) + if environment: + filters.append(LibraryResourceORM.environment == environment) + if category: + filters.append(LibraryResourceORM.category == category) + + total = ( + await session.execute(select(func.count(LibraryResourceORM.id)).where(*filters)) + ).scalar_one() + + result = await session.execute( + select(LibraryResourceORM) + .options(selectinload(LibraryResourceORM.files)) + .where(*filters) + .order_by(LibraryResourceORM.created_at.desc(), 
LibraryResourceORM.id.desc()) + .offset((page - 1) * limit) + .limit(limit) + ) + items = result.scalars().all() + return LibraryResourceListResponse(total=total, items=[self._to_read(item) for item in items]) + + async def update_resource( + self, + session: AsyncSession, + resource_id: int, + payload: UpdateLibraryResourceRequest, + ) -> LibraryResourceRead: + """Update editable metadata on a library resource.""" + + resource = await self._get_resource_or_404(session, resource_id) + + if payload.name is not None: + name = payload.name.strip() + if not name: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Resource name is required", + ) + resource.name = name + + if payload.description is not None: + resource.description = payload.description.strip() or None + + if payload.tags is not None: + resource.tags = [tag.strip() for tag in payload.tags if tag.strip()] + + if resource.resource_type == LibraryResourceType.MODEL: + if payload.gender is not None: + resource.gender = payload.gender.strip() or None + if payload.age_group is not None: + resource.age_group = payload.age_group.strip() or None + if payload.pose_id is not None: + resource.pose_id = payload.pose_id + elif resource.resource_type == LibraryResourceType.SCENE: + if payload.environment is not None: + resource.environment = payload.environment.strip() or None + elif resource.resource_type == LibraryResourceType.GARMENT: + if payload.category is not None: + resource.category = payload.category.strip() or None + + if payload.cover_file_id is not None: + cover_file = next( + (file for file in resource.files if file.id == payload.cover_file_id), + None, + ) + if cover_file is None: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Cover file does not belong to the resource", + ) + resource.cover_file_id = cover_file.id + + self._validate_existing_resource(resource) + await session.commit() + await session.refresh(resource, attribute_names=["files"]) + 
return self._to_read(resource) + + async def archive_resource( + self, + session: AsyncSession, + resource_id: int, + ) -> ArchiveLibraryResourceResponse: + """Soft delete a library resource by archiving it.""" + + resource = await self._get_resource_or_404(session, resource_id) + resource.status = LibraryResourceStatus.ARCHIVED + await session.commit() + return ArchiveLibraryResourceResponse(id=resource.id) + + def _validate_payload(self, payload: CreateLibraryResourceRequest) -> None: + if not payload.name.strip(): + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Resource name is required") + + if not payload.files: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="At least one file is required") + + originals = [file for file in payload.files if file.file_role == LibraryFileRole.ORIGINAL] + thumbnails = [file for file in payload.files if file.file_role == LibraryFileRole.THUMBNAIL] + if len(originals) != 1: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Exactly one original file is required", + ) + if len(thumbnails) != 1: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Exactly one thumbnail file is required", + ) + + expected_prefix = f"library/{RESOURCE_PREFIXES[payload.resource_type]}/" + for file in payload.files: + if not file.storage_key.startswith(expected_prefix): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Uploaded file key does not match resource type", + ) + + if payload.resource_type == LibraryResourceType.MODEL: + if not payload.gender or not payload.age_group: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Model resources require gender and age_group", + ) + elif payload.resource_type == LibraryResourceType.SCENE: + if not payload.environment: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Scene resources require environment", + ) + elif payload.resource_type == 
LibraryResourceType.GARMENT and not payload.category: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Garment resources require category", + ) + + def _find_role( + self, + files: list[LibraryResourceFileORM], + role: LibraryFileRole, + ) -> LibraryResourceFileORM: + return next(file for file in files if file.file_role == role) + + def _to_read( + self, + resource: LibraryResourceORM, + *, + files: list[LibraryResourceFileORM] | None = None, + ) -> LibraryResourceRead: + resource_files = files if files is not None else list(resource.files) + files_sorted = sorted(resource_files, key=lambda item: (item.sort_order, item.id)) + cover = next((item for item in files_sorted if item.id == resource.cover_file_id), None) + original = next((item for item in files_sorted if item.id == resource.original_file_id), None) + return LibraryResourceRead( + id=resource.id, + resource_type=resource.resource_type, + name=resource.name, + description=resource.description, + tags=resource.tags, + status=resource.status, + gender=resource.gender, + age_group=resource.age_group, + pose_id=resource.pose_id, + environment=resource.environment, + category=resource.category, + cover_url=cover.public_url if cover else None, + original_url=original.public_url if original else None, + files=[ + LibraryResourceFileRead( + id=file.id, + file_role=file.file_role, + storage_key=file.storage_key, + public_url=file.public_url, + bucket=file.bucket, + mime_type=file.mime_type, + size_bytes=file.size_bytes, + sort_order=file.sort_order, + width=file.width, + height=file.height, + created_at=file.created_at, + ) + for file in files_sorted + ], + created_at=resource.created_at, + updated_at=resource.updated_at, + ) + + async def _get_resource_or_404( + self, + session: AsyncSession, + resource_id: int, + ) -> LibraryResourceORM: + result = await session.execute( + select(LibraryResourceORM) + .options(selectinload(LibraryResourceORM.files)) + .where(LibraryResourceORM.id == 
resource_id) + ) + resource = result.scalar_one_or_none() + if resource is None: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Resource not found") + return resource + + def _validate_existing_resource(self, resource: LibraryResourceORM) -> None: + if not resource.name.strip(): + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Resource name is required") + + if resource.resource_type == LibraryResourceType.MODEL: + if not resource.gender or not resource.age_group: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Model resources require gender and age_group", + ) + elif resource.resource_type == LibraryResourceType.SCENE: + if not resource.environment: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Scene resources require environment", + ) + elif resource.resource_type == LibraryResourceType.GARMENT and not resource.category: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Garment resources require category", + ) diff --git a/app/config/settings.py b/app/config/settings.py index fa301db..57838d0 100644 --- a/app/config/settings.py +++ b/app/config/settings.py @@ -1,9 +1,12 @@ """Application settings.""" from functools import lru_cache +from pathlib import Path from pydantic_settings import BaseSettings, SettingsConfigDict +PROJECT_ROOT = Path(__file__).resolve().parents[2] + class Settings(BaseSettings): """Runtime settings loaded from environment variables.""" @@ -15,11 +18,25 @@ class Settings(BaseSettings): database_url: str = "sqlite+aiosqlite:///./temporal_demo.db" temporal_address: str = "localhost:7233" temporal_namespace: str = "default" + s3_access_key: str = "" + s3_secret_key: str = "" + s3_bucket: str = "" + s3_region: str = "" + s3_endpoint: str = "" + s3_cname: str = "" + s3_presign_expiry_seconds: int = 900 + image_generation_provider: str = "mock" + gemini_api_key: str = "" + gemini_base_url: str = 
class LibraryResourceType(str, Enum):
    """Kind of reusable item stored in the resource library."""

    MODEL = "model"
    SCENE = "scene"
    GARMENT = "garment"


class LibraryFileRole(str, Enum):
    """Role a stored file plays inside one library resource."""

    ORIGINAL = "original"
    THUMBNAIL = "thumbnail"
    GALLERY = "gallery"


class LibraryResourceStatus(str, Enum):
    """Whether a library resource is usable or has been soft-deleted."""

    ACTIVE = "active"
    ARCHIVED = "archived"
class LibraryResourceORM(TimestampMixin, Base):
    """Persisted library resource independent from order-generated assets."""

    __tablename__ = "library_resources"

    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
    # Discriminator: model / scene / garment, stored as a plain string column.
    resource_type: Mapped[LibraryResourceType] = mapped_column(
        Enum(LibraryResourceType, native_enum=False),
        nullable=False,
        index=True,
    )
    name: Mapped[str] = mapped_column(String(255), nullable=False)
    description: Mapped[str | None] = mapped_column(Text, nullable=True)
    # Free-form tag list serialized as JSON.
    tags: Mapped[list[str]] = mapped_column(JSON, nullable=False, default=list)
    # Lifecycle flag; archived rows are soft-deleted, never removed.
    status: Mapped[LibraryResourceStatus] = mapped_column(
        Enum(LibraryResourceStatus, native_enum=False),
        nullable=False,
        default=LibraryResourceStatus.ACTIVE,
        index=True,
    )
    # Type-specific attributes (only a subset applies per resource_type).
    gender: Mapped[str | None] = mapped_column(String(32), nullable=True, index=True)
    age_group: Mapped[str | None] = mapped_column(String(32), nullable=True, index=True)
    pose_id: Mapped[int | None] = mapped_column(Integer, nullable=True)
    environment: Mapped[str | None] = mapped_column(String(32), nullable=True, index=True)
    category: Mapped[str | None] = mapped_column(String(128), nullable=True, index=True)
    # Denormalized pointers into the files collection for fast cover/original lookup.
    cover_file_id: Mapped[int | None] = mapped_column(Integer, nullable=True)
    original_file_id: Mapped[int | None] = mapped_column(Integer, nullable=True)

    # Files are loaded eagerly and removed together with their resource.
    files = relationship(
        "LibraryResourceFileORM",
        back_populates="resource",
        lazy="selectin",
        cascade="all, delete-orphan",
    )
class LibraryResourceFileORM(TimestampMixin, Base):
    """Persisted uploaded file metadata for a library resource."""

    __tablename__ = "library_resource_files"

    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
    # Owning resource; one resource has many files.
    resource_id: Mapped[int] = mapped_column(
        ForeignKey("library_resources.id"),
        nullable=False,
        index=True,
    )
    # original / thumbnail / gallery, stored as a plain string column.
    file_role: Mapped[LibraryFileRole] = mapped_column(
        Enum(LibraryFileRole, native_enum=False),
        nullable=False,
        index=True,
    )
    # Object-store location and its publicly reachable URL.
    storage_key: Mapped[str] = mapped_column(String(500), nullable=False)
    public_url: Mapped[str] = mapped_column(String(500), nullable=False)
    bucket: Mapped[str] = mapped_column(String(255), nullable=False)
    mime_type: Mapped[str] = mapped_column(String(255), nullable=False)
    size_bytes: Mapped[int] = mapped_column(Integer, nullable=False)
    # Display ordering within the resource's file list.
    sort_order: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
    # Pixel dimensions when known (may be absent for non-image uploads).
    width: Mapped[int | None] = mapped_column(Integer, nullable=True)
    height: Mapped[int | None] = mapped_column(Integer, nullable=True)

    resource = relationship("LibraryResourceORM", back_populates="files")
nullable=False) + scene_ref_asset_id: Mapped[int | None] = mapped_column(Integer, nullable=True) final_asset_id: Mapped[int | None] = mapped_column(Integer, nullable=True) assets = relationship("AssetORM", back_populates="order", lazy="selectin") review_tasks = relationship("ReviewTaskORM", back_populates="order", lazy="selectin") workflow_runs = relationship("WorkflowRunORM", back_populates="order", lazy="selectin") - diff --git a/app/infra/db/session.py b/app/infra/db/session.py index 78b092f..d207720 100644 --- a/app/infra/db/session.py +++ b/app/infra/db/session.py @@ -44,12 +44,14 @@ async def init_database() -> None: """Create database tables when running the MVP without migrations.""" from app.infra.db.models.asset import AssetORM + from app.infra.db.models.library_resource import LibraryResourceORM + from app.infra.db.models.library_resource_file import LibraryResourceFileORM from app.infra.db.models.order import OrderORM from app.infra.db.models.review_task import ReviewTaskORM from app.infra.db.models.workflow_run import WorkflowRunORM from app.infra.db.models.workflow_step import WorkflowStepORM - del AssetORM, OrderORM, ReviewTaskORM, WorkflowRunORM, WorkflowStepORM + del AssetORM, LibraryResourceORM, LibraryResourceFileORM, OrderORM, ReviewTaskORM, WorkflowRunORM, WorkflowStepORM async with get_async_engine().begin() as connection: await connection.run_sync(Base.metadata.create_all) diff --git a/app/infra/image_generation/base.py b/app/infra/image_generation/base.py new file mode 100644 index 0000000..13ea73a --- /dev/null +++ b/app/infra/image_generation/base.py @@ -0,0 +1,38 @@ +"""Shared types for image-generation providers.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Protocol + + +@dataclass(slots=True) +class SourceImage: + """A binary source image passed into an image-generation provider.""" + + url: str + mime_type: str + data: bytes + + +@dataclass(slots=True) +class GeneratedImageResult: + 
class GeminiImageProvider:
    """Call the Gemini image-generation API via a configurable REST endpoint.

    Posts a text prompt plus two inline images to
    ``{base_url}/models/{model}:generateContent`` and decodes the first
    inline-image part found in the response.
    """

    provider_name = "gemini"

    def __init__(
        self,
        *,
        api_key: str,
        base_url: str,
        model: str,
        timeout_seconds: int,
        max_attempts: int = 2,
        http_client: httpx.AsyncClient | None = None,
    ) -> None:
        """Configure endpoint, credentials, and retry budget.

        *http_client* may be injected for testing; when omitted, a client is
        created (and closed) per request.
        """
        self.api_key = api_key
        self.base_url = base_url.rstrip("/")
        self.model = model
        self.timeout_seconds = timeout_seconds
        # Never allow a zero/negative attempt budget from config.
        self.max_attempts = max(1, max_attempts)
        self._http_client = http_client

    async def generate_tryon_image(
        self,
        *,
        prompt: str,
        person_image: SourceImage,
        garment_image: SourceImage,
    ) -> GeneratedImageResult:
        """Generate a try-on image from a prepared person image and a garment image."""
        return await self._generate_two_image_edit(
            prompt=prompt,
            first_image=person_image,
            second_image=garment_image,
        )

    async def generate_scene_image(
        self,
        *,
        prompt: str,
        source_image: SourceImage,
        scene_image: SourceImage,
    ) -> GeneratedImageResult:
        """Generate a scene-composited image from a rendered subject image and a scene reference."""
        return await self._generate_two_image_edit(
            prompt=prompt,
            first_image=source_image,
            second_image=scene_image,
        )

    async def _generate_two_image_edit(
        self,
        *,
        prompt: str,
        first_image: SourceImage,
        second_image: SourceImage,
    ) -> GeneratedImageResult:
        """Generate an edited image from a prompt plus two inline image inputs.

        Fix: the retry loop previously honored ``max_attempts`` only for
        ``httpx.TransportError``; a transient HTTP 5xx raised immediately via
        ``raise_for_status()``. Now both transport failures and 5xx responses
        are retried up to the budget, while 4xx responses stay final.
        """
        payload = {
            "contents": [
                {
                    "parts": [
                        {"text": prompt},
                        self._inline_part(first_image),
                        self._inline_part(second_image),
                    ]
                }
            ],
            "generationConfig": {
                "responseModalities": ["TEXT", "IMAGE"],
            },
        }

        owns_client = self._http_client is None
        client = self._http_client or httpx.AsyncClient(timeout=self.timeout_seconds)
        try:
            response = None
            for attempt in range(1, self.max_attempts + 1):
                try:
                    response = await client.post(
                        f"{self.base_url}/models/{self.model}:generateContent",
                        headers={
                            "x-goog-api-key": self.api_key,
                            "Content-Type": "application/json",
                        },
                        json=payload,
                        timeout=self.timeout_seconds,
                    )
                    response.raise_for_status()
                    break
                except httpx.TransportError:
                    # Network-level failure: retry until the budget runs out.
                    if attempt >= self.max_attempts:
                        raise
                except httpx.HTTPStatusError as exc:
                    # Retry transient server-side errors; client errors (4xx) are final.
                    if exc.response.status_code < 500 or attempt >= self.max_attempts:
                        raise
                    response = None
        finally:
            if owns_client:
                await client.aclose()

        if response is None:
            raise RuntimeError("Gemini provider did not receive a response")

        body = response.json()
        image_part = self._find_image_part(body)
        # The API may answer in camelCase or snake_case depending on the proxy.
        image_data = image_part.get("inlineData") or image_part.get("inline_data")
        mime_type = image_data.get("mimeType") or image_data.get("mime_type") or "image/png"
        data = image_data.get("data")
        if not data:
            raise ValueError("Gemini response did not include image bytes")

        return GeneratedImageResult(
            image_bytes=base64.b64decode(data),
            mime_type=mime_type,
            provider=self.provider_name,
            model=self.model,
            prompt=prompt,
        )

    @staticmethod
    def _inline_part(image: SourceImage) -> dict:
        """Build one Gemini ``inline_data`` request part from a source image."""
        return {
            "inline_data": {
                "mime_type": image.mime_type,
                "data": base64.b64encode(image.data).decode("utf-8"),
            }
        }

    @staticmethod
    def _find_image_part(body: dict) -> dict:
        """Return the first response part carrying inline image data, or raise ValueError."""
        for candidate in body.get("candidates") or []:
            content = candidate.get("content") or {}
            for part in content.get("parts") or []:
                if part.get("inlineData") or part.get("inline_data"):
                    return part
        raise ValueError("Gemini response did not contain an image part")
class S3ObjectStorageService:
    """Upload generated workflow artifacts to the configured object store."""

    def __init__(self) -> None:
        """Build a boto3 S3 client from settings; empty settings fall back to SDK defaults."""
        self.settings = get_settings()
        self._client = boto3.client(
            "s3",
            region_name=self.settings.s3_region or None,
            endpoint_url=self.settings.s3_endpoint or None,
            aws_access_key_id=self.settings.s3_access_key or None,
            aws_secret_access_key=self.settings.s3_secret_key or None,
        )
        # Reused for deriving the public URL of uploaded objects.
        self._presign = S3PresignService()

    async def upload_generated_image(
        self,
        *,
        order_id: int,
        step_name: WorkflowStepName,
        image_bytes: bytes,
        mime_type: str,
    ) -> tuple[str, str]:
        """Upload bytes and return the storage key plus public URL.

        Fix: boto3's ``put_object`` is blocking; run it in a worker thread so
        this async activity does not stall the event loop during large uploads.
        """
        import asyncio  # local import keeps the module's import block unchanged

        storage_key = self._build_storage_key(order_id=order_id, step_name=step_name, mime_type=mime_type)
        await asyncio.to_thread(
            self._client.put_object,
            Bucket=self.settings.s3_bucket,
            Key=storage_key,
            Body=image_bytes,
            ContentType=mime_type,
        )
        return storage_key, self._presign.get_public_url(storage_key)

    @staticmethod
    def _build_storage_key(*, order_id: int, step_name: WorkflowStepName, mime_type: str) -> str:
        """Derive a unique per-order object key whose extension matches the mime type."""
        suffix = {
            "image/png": ".png",
            "image/jpeg": ".jpg",
            "image/webp": ".webp",
        }.get(mime_type, ".bin")
        return f"orders/{order_id}/{step_name.value}/{uuid4().hex}{suffix}"
@activity.defn
async def run_export_activity(payload: StepActivityInput) -> MockActivityResult:
    """Finalize the chosen source asset as the order's exported deliverable.

    Copies the source asset's URI into a FINAL asset, marks the order and
    workflow run as succeeded, and records the step outcome. Any failure
    flips order/run/step to FAILED before re-raising.
    """
    if payload.source_asset_id is None:
        raise ValueError("run_export_activity requires source_asset_id")

    session_factory = get_session_factory()
    async with session_factory() as session:
        order, workflow_run = await load_order_and_run(session, payload.order_id, payload.workflow_run_id)
        step = create_step_record(payload)
        session.add(step)

        # Mark everything as running before attempting the export.
        order.status = OrderStatus.RUNNING
        workflow_run.status = OrderStatus.RUNNING
        workflow_run.current_step = payload.step_name
        await session.flush()

        try:
            src = await session.get(AssetORM, payload.source_asset_id)
            if src is None:
                raise ValueError(f"Source asset {payload.source_asset_id} not found")
            if src.order_id != payload.order_id:
                raise ValueError(
                    f"Source asset {payload.source_asset_id} does not belong to order {payload.order_id}"
                )

            # Merge payload metadata with the lineage ids, dropping None values.
            raw_metadata = dict(payload.metadata)
            raw_metadata["source_asset_id"] = payload.source_asset_id
            raw_metadata["selected_asset_id"] = payload.selected_asset_id
            metadata = {key: value for key, value in raw_metadata.items() if value is not None}

            final_asset = AssetORM(
                order_id=payload.order_id,
                asset_type=AssetType.FINAL,
                step_name=payload.step_name,
                uri=src.uri,
                metadata_json=jsonable(metadata),
            )
            session.add(final_asset)
            await session.flush()

            result = MockActivityResult(
                step_name=payload.step_name,
                success=True,
                asset_id=final_asset.id,
                uri=final_asset.uri,
                score=0.95,
                passed=True,
                message="mock success",
                metadata=jsonable(metadata) or {},
            )

            order.final_asset_id = final_asset.id
            order.status = OrderStatus.SUCCEEDED
            workflow_run.status = OrderStatus.SUCCEEDED

            step.step_status = StepStatus.SUCCEEDED
            step.output_json = jsonable(result)
            step.ended_at = utc_now()
            await session.commit()
            return result
        except Exception as exc:
            step.step_status = StepStatus.FAILED
            step.error_message = str(exc)
            step.ended_at = utc_now()
            order.status = OrderStatus.FAILED
            workflow_run.status = OrderStatus.FAILED
            await session.commit()
            raise
@activity.defn
async def run_scene_activity(payload: StepActivityInput) -> MockActivityResult:
    """Generate a scene-composited asset, or fall back to mock mode when configured.

    Real mode combines the upstream rendered asset with a scene resource from
    the library, uploads the result to object storage, and records a SCENE
    asset. Failures mark step/order/run FAILED and re-raise.
    """
    service = get_image_generation_service()
    # Name-based check avoids importing the mock implementation here.
    if service.__class__.__name__ == "MockImageGenerationService":
        return await execute_asset_step(
            payload,
            AssetType.SCENE,
            extra_metadata={"scene_ref_asset_id": payload.scene_ref_asset_id},
        )

    if payload.source_asset_id is None:
        raise ValueError("run_scene_activity requires source_asset_id")
    if payload.scene_ref_asset_id is None:
        raise ValueError("run_scene_activity requires scene_ref_asset_id")

    session_factory = get_session_factory()
    async with session_factory() as session:
        order, workflow_run = await load_order_and_run(session, payload.order_id, payload.workflow_run_id)
        step = create_step_record(payload)
        session.add(step)

        order.status = OrderStatus.RUNNING
        workflow_run.status = OrderStatus.RUNNING
        workflow_run.current_step = payload.step_name
        await session.flush()

        try:
            source_asset = await session.get(AssetORM, payload.source_asset_id)
            if source_asset is None:
                raise ValueError(f"Source asset {payload.source_asset_id} not found")
            if source_asset.order_id != payload.order_id:
                raise ValueError(
                    f"Source asset {payload.source_asset_id} does not belong to order {payload.order_id}"
                )

            scene_resource, scene_original = await load_active_library_resource(
                session,
                payload.scene_ref_asset_id,
                resource_type=LibraryResourceType.SCENE,
            )

            generated = await service.generate_scene_image(
                source_image_url=source_asset.uri,
                scene_image_url=scene_original.public_url,
            )
            storage_key, public_url = await get_order_artifact_storage_service().upload_generated_image(
                order_id=payload.order_id,
                step_name=payload.step_name,
                image_bytes=generated.image_bytes,
                mime_type=generated.mime_type,
            )

            # Merge payload metadata with generation lineage, dropping None values.
            combined = dict(payload.metadata)
            combined.update(
                source_asset_id=source_asset.id,
                scene_ref_asset_id=payload.scene_ref_asset_id,
                scene_resource_id=scene_resource.id,
                scene_original_file_id=scene_original.id,
                scene_original_url=scene_original.public_url,
                provider=generated.provider,
                model=generated.model,
                storage_key=storage_key,
                mime_type=generated.mime_type,
                prompt=generated.prompt,
            )
            metadata = {key: value for key, value in combined.items() if value is not None}

            asset = AssetORM(
                order_id=payload.order_id,
                asset_type=AssetType.SCENE,
                step_name=payload.step_name,
                uri=public_url,
                metadata_json=jsonable(metadata),
            )
            session.add(asset)
            await session.flush()

            result = MockActivityResult(
                step_name=payload.step_name,
                success=True,
                asset_id=asset.id,
                uri=asset.uri,
                score=1.0,
                passed=True,
                message="scene generated",
                metadata=jsonable(metadata) or {},
            )

            step.step_status = StepStatus.SUCCEEDED
            step.output_json = jsonable(result)
            step.ended_at = utc_now()
            await session.commit()
            return result
        except Exception as exc:
            step.step_status = StepStatus.FAILED
            step.error_message = str(exc)
            step.ended_at = utc_now()
            order.status = OrderStatus.FAILED
            workflow_run.status = OrderStatus.FAILED
            await session.commit()
            raise
async def load_active_library_resource(
    session,
    resource_id: int,
    *,
    resource_type: LibraryResourceType,
) -> tuple[LibraryResourceORM, LibraryResourceFileORM]:
    """Load an active library resource and its original file from the library.

    Raises ValueError when the resource is missing, of the wrong type,
    archived, or has no resolvable original file.
    """
    query = (
        select(LibraryResourceORM)
        .options(selectinload(LibraryResourceORM.files))
        .where(LibraryResourceORM.id == resource_id)
    )
    resource = (await session.execute(query)).scalar_one_or_none()

    if resource is None:
        raise ValueError(f"Library resource {resource_id} not found")
    if resource.resource_type != resource_type:
        raise ValueError(f"Resource {resource_id} is not a {resource_type.value} resource")
    if resource.status != LibraryResourceStatus.ACTIVE:
        raise ValueError(f"Library resource {resource_id} is not active")
    if resource.original_file_id is None:
        raise ValueError(f"Library resource {resource_id} is missing an original file")

    original_file = next(
        (item for item in resource.files if item.id == resource.original_file_id),
        None,
    )
    if original_file is None:
        raise ValueError(f"Library resource {resource_id} original file record not found")
    return resource, original_file


def get_image_generation_service():
    """Return the configured image-generation service."""
    return build_image_generation_service()


def get_order_artifact_storage_service() -> S3ObjectStorageService:
    """Return the object-storage service for workflow-generated images."""
    return S3ObjectStorageService()


def build_resource_input_snapshot(
    resource: LibraryResourceORM,
    original_file: LibraryResourceFileORM,
) -> dict[str, Any]:
    """Build a frontend-friendly snapshot of one library input resource."""
    return {
        "resource_id": resource.id,
        "resource_name": resource.name,
        "original_file_id": original_file.id,
        "original_url": original_file.public_url,
        "mime_type": original_file.mime_type,
        "width": original_file.width,
        "height": original_file.height,
    }
workflow_run.current_step = payload.step_name + await session.flush() + + try: + resource, original_file = await load_active_library_resource( + session, + payload.model_id, + resource_type=LibraryResourceType.MODEL, + ) + garment_snapshot = None + if payload.garment_asset_id is not None: + garment_resource, garment_original = await load_active_library_resource( + session, + payload.garment_asset_id, + resource_type=LibraryResourceType.GARMENT, + ) + garment_snapshot = build_resource_input_snapshot(garment_resource, garment_original) + + scene_snapshot = None + if payload.scene_ref_asset_id is not None: + scene_resource, scene_original = await load_active_library_resource( + session, + payload.scene_ref_asset_id, + resource_type=LibraryResourceType.SCENE, + ) + scene_snapshot = build_resource_input_snapshot(scene_resource, scene_original) + + metadata = { + **payload.metadata, + "model_id": payload.model_id, + "pose_id": payload.pose_id, + "garment_asset_id": payload.garment_asset_id, + "scene_ref_asset_id": payload.scene_ref_asset_id, + "library_resource_id": resource.id, + "library_original_file_id": original_file.id, + "library_original_url": original_file.public_url, + "library_original_mime_type": original_file.mime_type, + "library_original_width": original_file.width, + "library_original_height": original_file.height, + "model_input": build_resource_input_snapshot(resource, original_file), + "garment_input": garment_snapshot, + "scene_input": scene_snapshot, + "pose_input": {"pose_id": payload.pose_id} if payload.pose_id is not None else None, + "normalized": False, + } + metadata = {key: value for key, value in metadata.items() if value is not None} + + asset = AssetORM( + order_id=payload.order_id, + asset_type=AssetType.PREPARED_MODEL, + step_name=payload.step_name, + uri=original_file.public_url, + metadata_json=jsonable(metadata), + ) + session.add(asset) + await session.flush() + + result = MockActivityResult( + step_name=payload.step_name, + 
success=True, + asset_id=asset.id, + uri=asset.uri, + score=1.0, + passed=True, + message="prepared model ready", + metadata=jsonable(metadata) or {}, + ) + + step.step_status = StepStatus.SUCCEEDED + step.output_json = jsonable(result) + step.ended_at = utc_now() + await session.commit() + return result + except Exception as exc: + step.step_status = StepStatus.FAILED + step.error_message = str(exc) + step.ended_at = utc_now() + order.status = OrderStatus.FAILED + workflow_run.status = OrderStatus.FAILED + await session.commit() + raise @activity.defn async def run_tryon_activity(payload: StepActivityInput) -> MockActivityResult: - """Mock try-on rendering.""" + """执行试衣渲染步骤,或在未接真实能力时走 mock 分支。 - return await execute_asset_step( - payload, - AssetType.TRYON, - extra_metadata={"prepared_asset_id": payload.source_asset_id}, - ) + 流程: + 1. 读取当前配置的图片生成 service。 + 2. 如果当前仍是 mock service,就退回通用的 mock 资产产出逻辑。 + 3. 如果是真实模式,先读取 prepare_model_activity 产出的 prepared_model 资产。 + 4. 再读取本次选中的服装资源,并解析出它的原图 URL。 + 5. 把“模特准备图 + 服装原图”一起发给 provider 做试衣生成。 + 6. 把生成结果上传到 S3,并落成订单内的一条 TRYON 资产。 + 7. 
在 metadata 里记录 provider、model、prompt 和输入来源,方便追踪排查。 + """ + + service = get_image_generation_service() + # 保留旧的 mock 分支,这样在没有接入真实 provider 时 workflow 也还能跑通。 + if service.__class__.__name__ == "MockImageGenerationService": + return await execute_asset_step( + payload, + AssetType.TRYON, + extra_metadata={"prepared_asset_id": payload.source_asset_id}, + ) + + if payload.source_asset_id is None: + raise ValueError("run_tryon_activity requires source_asset_id") + if payload.garment_asset_id is None: + raise ValueError("run_tryon_activity requires garment_asset_id") + + async with get_session_factory()() as session: + order, workflow_run = await load_order_and_run(session, payload.order_id, payload.workflow_run_id) + step = create_step_record(payload) + session.add(step) + + order.status = OrderStatus.RUNNING + workflow_run.status = OrderStatus.RUNNING + workflow_run.current_step = payload.step_name + await session.flush() + + try: + # 试衣步骤的输入起点固定是上一阶段产出的 prepared_model 资产。 + prepared_asset = await session.get(AssetORM, payload.source_asset_id) + if prepared_asset is None or prepared_asset.order_id != payload.order_id: + raise ValueError(f"Prepared asset {payload.source_asset_id} not found for order {payload.order_id}") + if prepared_asset.asset_type != AssetType.PREPARED_MODEL: + raise ValueError(f"Asset {payload.source_asset_id} is not a prepared_model asset") + + # 服装素材来自共享资源库,workflow 实际消费的是资源库里的原图。 + garment_resource, garment_original = await load_active_library_resource( + session, + payload.garment_asset_id, + resource_type=LibraryResourceType.GARMENT, + ) + + # provider 接收两张真实输入图:模特准备图 + 服装参考图。 + generated = await service.generate_tryon_image( + person_image_url=prepared_asset.uri, + garment_image_url=garment_original.public_url, + ) + # workflow 生成出的结果图先上传到 S3,再登记成订单资产。 + storage_key, public_url = await get_order_artifact_storage_service().upload_generated_image( + order_id=payload.order_id, + step_name=payload.step_name, + image_bytes=generated.image_bytes, + 
mime_type=generated.mime_type, + ) + + # 记录足够的追踪信息,方便回溯这张试衣图由哪些输入和哪个 provider 产出。 + metadata = { + **payload.metadata, + "prepared_asset_id": prepared_asset.id, + "garment_resource_id": garment_resource.id, + "garment_original_file_id": garment_original.id, + "garment_original_url": garment_original.public_url, + "provider": generated.provider, + "model": generated.model, + "storage_key": storage_key, + "mime_type": generated.mime_type, + "prompt": generated.prompt, + } + metadata = {key: value for key, value in metadata.items() if value is not None} + + asset = AssetORM( + order_id=payload.order_id, + asset_type=AssetType.TRYON, + step_name=payload.step_name, + uri=public_url, + metadata_json=jsonable(metadata), + ) + session.add(asset) + await session.flush() + + result = MockActivityResult( + step_name=payload.step_name, + success=True, + asset_id=asset.id, + uri=asset.uri, + score=1.0, + passed=True, + message="try-on generated", + metadata=jsonable(metadata) or {}, + ) + + step.step_status = StepStatus.SUCCEEDED + step.output_json = jsonable(result) + step.ended_at = utc_now() + await session.commit() + return result + except Exception as exc: + step.step_status = StepStatus.FAILED + step.error_message = str(exc) + step.ended_at = utc_now() + order.status = OrderStatus.FAILED + workflow_run.status = OrderStatus.FAILED + await session.commit() + raise diff --git a/app/workers/workflows/low_end_pipeline.py b/app/workers/workflows/low_end_pipeline.py index 35ba9f1..d442feb 100644 --- a/app/workers/workflows/low_end_pipeline.py +++ b/app/workers/workflows/low_end_pipeline.py @@ -26,14 +26,13 @@ with workflow.unsafe.imports_passed_through(): from app.workers.activities.review_activities import mark_workflow_failed_activity from app.workers.activities.scene_activities import run_scene_activity from app.workers.activities.tryon_activities import prepare_model_activity, run_tryon_activity + from app.workers.workflows.timeout_policy import DEFAULT_ACTIVITY_TIMEOUT, 
activity_timeout_for_task_queue from app.workers.workflows.types import ( PipelineWorkflowInput, StepActivityInput, WorkflowFailureActivityInput, ) - -ACTIVITY_TIMEOUT = timedelta(seconds=30) ACTIVITY_RETRY_POLICY = RetryPolicy( initial_interval=timedelta(seconds=1), backoff_coefficient=2.0, @@ -64,6 +63,8 @@ class LowEndPipelineWorkflow: try: # 每个步骤都通过 execute_activity 触发真正执行。 # workflow 自己不做计算,只负责调度。 + # prepare_model 的职责是把“模特资源 + 订单上下文”整理成后续可消费的人物底图。 + # 这一步产出的 prepared.asset_id 会作为 tryon 的 source_asset_id。 prepared = await workflow.execute_activity( prepare_model_activity, StepActivityInput( @@ -76,12 +77,14 @@ class LowEndPipelineWorkflow: scene_ref_asset_id=payload.scene_ref_asset_id, ), task_queue=IMAGE_PIPELINE_CONTROL_TASK_QUEUE, - start_to_close_timeout=ACTIVITY_TIMEOUT, + start_to_close_timeout=activity_timeout_for_task_queue(IMAGE_PIPELINE_CONTROL_TASK_QUEUE), retry_policy=ACTIVITY_RETRY_POLICY, ) current_step = WorkflowStepName.TRYON # 下游步骤通过 source_asset_id 引用上一步生成的资产。 + # tryon 是换装主步骤:把 prepared 模特图和 garment_asset_id 组合成试衣结果。 + # 它产出的 tryon_result.asset_id 是后续 scene / qc 的基础输入。 tryon_result = await workflow.execute_activity( run_tryon_activity, StepActivityInput( @@ -92,38 +95,46 @@ class LowEndPipelineWorkflow: garment_asset_id=payload.garment_asset_id, ), task_queue=IMAGE_PIPELINE_IMAGE_GEN_TASK_QUEUE, - start_to_close_timeout=ACTIVITY_TIMEOUT, + start_to_close_timeout=activity_timeout_for_task_queue(IMAGE_PIPELINE_IMAGE_GEN_TASK_QUEUE), retry_policy=ACTIVITY_RETRY_POLICY, ) - current_step = WorkflowStepName.SCENE - scene_result = await workflow.execute_activity( - run_scene_activity, - StepActivityInput( - order_id=payload.order_id, - workflow_run_id=payload.workflow_run_id, - step_name=WorkflowStepName.SCENE, - source_asset_id=tryon_result.asset_id, - scene_ref_asset_id=payload.scene_ref_asset_id, - ), - task_queue=IMAGE_PIPELINE_IMAGE_GEN_TASK_QUEUE, - start_to_close_timeout=ACTIVITY_TIMEOUT, - retry_policy=ACTIVITY_RETRY_POLICY, - ) + 
scene_source_asset_id = tryon_result.asset_id + if payload.scene_ref_asset_id is not None: + current_step = WorkflowStepName.SCENE + # scene 是可选步骤。 + # 有场景图时,用 scene_ref_asset_id 把 tryon 结果合成到目标背景; + # 没有场景图时,直接沿用 tryon 结果继续往下走。 + scene_result = await workflow.execute_activity( + run_scene_activity, + StepActivityInput( + order_id=payload.order_id, + workflow_run_id=payload.workflow_run_id, + step_name=WorkflowStepName.SCENE, + source_asset_id=tryon_result.asset_id, + scene_ref_asset_id=payload.scene_ref_asset_id, + ), + task_queue=IMAGE_PIPELINE_IMAGE_GEN_TASK_QUEUE, + start_to_close_timeout=activity_timeout_for_task_queue(IMAGE_PIPELINE_IMAGE_GEN_TASK_QUEUE), + retry_policy=ACTIVITY_RETRY_POLICY, + ) + scene_source_asset_id = scene_result.asset_id current_step = WorkflowStepName.QC # QC 是流程里的“闸门”。 # 如果这里不通过,低端流程直接失败,不会再 export。 + # source_asset_id 指向当前链路里“最接近最终成品”的那张图: + # 有场景时是 scene 结果,没有场景时就是 tryon 结果。 qc_result = await workflow.execute_activity( run_qc_activity, StepActivityInput( order_id=payload.order_id, workflow_run_id=payload.workflow_run_id, step_name=WorkflowStepName.QC, - source_asset_id=scene_result.asset_id, + source_asset_id=scene_source_asset_id, ), task_queue=IMAGE_PIPELINE_QC_TASK_QUEUE, - start_to_close_timeout=ACTIVITY_TIMEOUT, + start_to_close_timeout=activity_timeout_for_task_queue(IMAGE_PIPELINE_QC_TASK_QUEUE), retry_policy=ACTIVITY_RETRY_POLICY, ) @@ -134,16 +145,17 @@ class LowEndPipelineWorkflow: current_step = WorkflowStepName.EXPORT # candidate_asset_ids 是 QC 推荐可导出的候选资产。 # 当前 MVP 只会返回一个候选;如果没有,就退回 scene 结果导出。 + # export 是最后的收口步骤:生成最终交付资产,并把订单状态写成 succeeded。 final_result = await workflow.execute_activity( run_export_activity, StepActivityInput( order_id=payload.order_id, workflow_run_id=payload.workflow_run_id, step_name=WorkflowStepName.EXPORT, - source_asset_id=(qc_result.candidate_asset_ids or [scene_result.asset_id])[0], + source_asset_id=(qc_result.candidate_asset_ids or [scene_source_asset_id])[0], ), 
task_queue=IMAGE_PIPELINE_EXPORT_TASK_QUEUE, - start_to_close_timeout=ACTIVITY_TIMEOUT, + start_to_close_timeout=activity_timeout_for_task_queue(IMAGE_PIPELINE_EXPORT_TASK_QUEUE), retry_policy=ACTIVITY_RETRY_POLICY, ) return { @@ -177,6 +189,6 @@ class LowEndPipelineWorkflow: message=message, ), task_queue=IMAGE_PIPELINE_CONTROL_TASK_QUEUE, - start_to_close_timeout=ACTIVITY_TIMEOUT, + start_to_close_timeout=DEFAULT_ACTIVITY_TIMEOUT, retry_policy=ACTIVITY_RETRY_POLICY, ) diff --git a/app/workers/workflows/mid_end_pipeline.py b/app/workers/workflows/mid_end_pipeline.py index 6aa9f8c..5ff0ebb 100644 --- a/app/workers/workflows/mid_end_pipeline.py +++ b/app/workers/workflows/mid_end_pipeline.py @@ -34,6 +34,7 @@ with workflow.unsafe.imports_passed_through(): from app.workers.activities.scene_activities import run_scene_activity from app.workers.activities.texture_activities import run_texture_activity from app.workers.activities.tryon_activities import prepare_model_activity, run_tryon_activity + from app.workers.workflows.timeout_policy import DEFAULT_ACTIVITY_TIMEOUT, activity_timeout_for_task_queue from app.workers.workflows.types import ( MockActivityResult, PipelineWorkflowInput, @@ -44,8 +45,6 @@ with workflow.unsafe.imports_passed_through(): WorkflowFailureActivityInput, ) - -ACTIVITY_TIMEOUT = timedelta(seconds=30) ACTIVITY_RETRY_POLICY = RetryPolicy( initial_interval=timedelta(seconds=1), backoff_coefficient=2.0, @@ -87,6 +86,8 @@ class MidEndPipelineWorkflow: # current_step 用于失败时记录“最后跑到哪一步”。 current_step = WorkflowStepName.PREPARE_MODEL try: + # prepare_model / tryon / scene 这三段和低端流程共享同一套 activity, + # 区别只在于中端流程后面还会继续进入 texture / face / fusion / review。 prepared = await workflow.execute_activity( prepare_model_activity, StepActivityInput( @@ -99,11 +100,12 @@ class MidEndPipelineWorkflow: scene_ref_asset_id=payload.scene_ref_asset_id, ), task_queue=IMAGE_PIPELINE_CONTROL_TASK_QUEUE, - start_to_close_timeout=ACTIVITY_TIMEOUT, + 
start_to_close_timeout=activity_timeout_for_task_queue(IMAGE_PIPELINE_CONTROL_TASK_QUEUE), retry_policy=ACTIVITY_RETRY_POLICY, ) current_step = WorkflowStepName.TRYON + # tryon 产出基础换装图,后面的所有增强步骤都以它为起点。 tryon_result = await workflow.execute_activity( run_tryon_activity, StepActivityInput( @@ -114,23 +116,37 @@ class MidEndPipelineWorkflow: garment_asset_id=payload.garment_asset_id, ), task_queue=IMAGE_PIPELINE_IMAGE_GEN_TASK_QUEUE, - start_to_close_timeout=ACTIVITY_TIMEOUT, + start_to_close_timeout=activity_timeout_for_task_queue(IMAGE_PIPELINE_IMAGE_GEN_TASK_QUEUE), retry_policy=ACTIVITY_RETRY_POLICY, ) - current_step = WorkflowStepName.SCENE - scene_result = await self._run_scene(payload, tryon_result.asset_id) + scene_source_asset_id = tryon_result.asset_id + if payload.scene_ref_asset_id is not None: + current_step = WorkflowStepName.SCENE + # scene 仍然是可选步骤。 + # 如果存在场景图,就先完成换背景,再把结果交给 texture / fusion; + # 如果没有场景图,就直接把 tryon 结果当成“场景基底”继续流程。 + scene_result = await self._run_scene(payload, tryon_result.asset_id) + scene_source_asset_id = scene_result.asset_id + else: + scene_result = tryon_result current_step = WorkflowStepName.TEXTURE - texture_result = await self._run_texture(payload, scene_result.asset_id) + # texture 是复杂流独有步骤。 + # 它通常负责衣物纹理、材质细节或局部增强,输入是当前场景基底图。 + texture_result = await self._run_texture(payload, scene_source_asset_id) current_step = WorkflowStepName.FACE + # face 也是复杂流独有步骤。 + # 它在 texture 结果基础上做脸部修复/增强,产出供 fusion 合成的人像版本。 face_result = await self._run_face(payload, texture_result.asset_id) current_step = WorkflowStepName.FUSION - fusion_result = await self._run_fusion(payload, scene_result.asset_id, face_result.asset_id) + # fusion 负责把“场景基底”和“脸部增强结果”合成成候选成品。 + fusion_result = await self._run_fusion(payload, scene_source_asset_id, face_result.asset_id) current_step = WorkflowStepName.QC + # QC 和低端流程复用同一套 activity,但这里检查的是 fusion 产物。 qc_result = await self._run_qc(payload, fusion_result.asset_id) if not qc_result.passed: await 
self._mark_failed(payload, current_step, qc_result.message) @@ -144,6 +160,8 @@ class MidEndPipelineWorkflow: current_step = WorkflowStepName.REVIEW # 这里通过 activity 把数据库里的订单状态更新成 waiting_review, # 同时创建 review_task,供 API 查询待审核列表。 + # review 是复杂流真正区别于低端流程的核心: + # 在 export 前插入人工决策点,并支持按意见回流重跑。 await workflow.execute_activity( mark_waiting_for_review_activity, ReviewWaitActivityInput( @@ -152,7 +170,7 @@ class MidEndPipelineWorkflow: candidate_asset_ids=qc_result.candidate_asset_ids, ), task_queue=IMAGE_PIPELINE_CONTROL_TASK_QUEUE, - start_to_close_timeout=ACTIVITY_TIMEOUT, + start_to_close_timeout=activity_timeout_for_task_queue(IMAGE_PIPELINE_CONTROL_TASK_QUEUE), retry_policy=ACTIVITY_RETRY_POLICY, ) @@ -172,7 +190,7 @@ class MidEndPipelineWorkflow: comment=review_payload.comment, ), task_queue=IMAGE_PIPELINE_CONTROL_TASK_QUEUE, - start_to_close_timeout=ACTIVITY_TIMEOUT, + start_to_close_timeout=activity_timeout_for_task_queue(IMAGE_PIPELINE_CONTROL_TASK_QUEUE), retry_policy=ACTIVITY_RETRY_POLICY, ) @@ -180,6 +198,7 @@ class MidEndPipelineWorkflow: current_step = WorkflowStepName.EXPORT # 如果审核人显式选了资产,就导出该资产; # 否则默认导出 QC 候选资产。 + # export 本身仍然和低端流程是同一个 activity。 export_source_id = review_payload.selected_asset_id if export_source_id is None: export_source_id = (qc_result.candidate_asset_ids or [fusion_result.asset_id])[0] @@ -192,7 +211,7 @@ class MidEndPipelineWorkflow: source_asset_id=export_source_id, ), task_queue=IMAGE_PIPELINE_EXPORT_TASK_QUEUE, - start_to_close_timeout=ACTIVITY_TIMEOUT, + start_to_close_timeout=activity_timeout_for_task_queue(IMAGE_PIPELINE_EXPORT_TASK_QUEUE), retry_policy=ACTIVITY_RETRY_POLICY, ) return { @@ -208,22 +227,30 @@ class MidEndPipelineWorkflow: # rerun 的核心思想是: # 把指定节点后的链路重新跑一遍,然后再次进入 QC 和 waiting_review。 if review_payload.decision == ReviewDecision.RERUN_SCENE: - current_step = WorkflowStepName.SCENE - scene_result = await self._run_scene(payload, tryon_result.asset_id) + # 从 scene 开始重跑意味着 scene / texture / face / fusion 全部重算。 + if 
payload.scene_ref_asset_id is not None: + current_step = WorkflowStepName.SCENE + scene_result = await self._run_scene(payload, tryon_result.asset_id) + scene_source_asset_id = scene_result.asset_id + else: + scene_result = tryon_result + scene_source_asset_id = tryon_result.asset_id current_step = WorkflowStepName.TEXTURE - texture_result = await self._run_texture(payload, scene_result.asset_id) + texture_result = await self._run_texture(payload, scene_source_asset_id) current_step = WorkflowStepName.FACE face_result = await self._run_face(payload, texture_result.asset_id) current_step = WorkflowStepName.FUSION - fusion_result = await self._run_fusion(payload, scene_result.asset_id, face_result.asset_id) + fusion_result = await self._run_fusion(payload, scene_source_asset_id, face_result.asset_id) elif review_payload.decision == ReviewDecision.RERUN_FACE: + # 从 face 开始重跑时,scene / texture 结果保持不变。 current_step = WorkflowStepName.FACE face_result = await self._run_face(payload, texture_result.asset_id) current_step = WorkflowStepName.FUSION - fusion_result = await self._run_fusion(payload, scene_result.asset_id, face_result.asset_id) + fusion_result = await self._run_fusion(payload, scene_source_asset_id, face_result.asset_id) elif review_payload.decision == ReviewDecision.RERUN_FUSION: + # 从 fusion 重跑是最小范围回流,只重做最终合成。 current_step = WorkflowStepName.FUSION - fusion_result = await self._run_fusion(payload, scene_result.asset_id, face_result.asset_id) + fusion_result = await self._run_fusion(payload, scene_source_asset_id, face_result.asset_id) current_step = WorkflowStepName.QC qc_result = await self._run_qc(payload, fusion_result.asset_id) @@ -264,7 +291,7 @@ class MidEndPipelineWorkflow: scene_ref_asset_id=payload.scene_ref_asset_id, ), task_queue=IMAGE_PIPELINE_IMAGE_GEN_TASK_QUEUE, - start_to_close_timeout=ACTIVITY_TIMEOUT, + start_to_close_timeout=activity_timeout_for_task_queue(IMAGE_PIPELINE_IMAGE_GEN_TASK_QUEUE), retry_policy=ACTIVITY_RETRY_POLICY, ) @@ -280,7 
+307,7 @@ class MidEndPipelineWorkflow: source_asset_id=source_asset_id, ), task_queue=IMAGE_PIPELINE_POST_PROCESS_TASK_QUEUE, - start_to_close_timeout=ACTIVITY_TIMEOUT, + start_to_close_timeout=activity_timeout_for_task_queue(IMAGE_PIPELINE_POST_PROCESS_TASK_QUEUE), retry_policy=ACTIVITY_RETRY_POLICY, ) @@ -296,7 +323,7 @@ class MidEndPipelineWorkflow: source_asset_id=source_asset_id, ), task_queue=IMAGE_PIPELINE_POST_PROCESS_TASK_QUEUE, - start_to_close_timeout=ACTIVITY_TIMEOUT, + start_to_close_timeout=activity_timeout_for_task_queue(IMAGE_PIPELINE_POST_PROCESS_TASK_QUEUE), retry_policy=ACTIVITY_RETRY_POLICY, ) @@ -318,7 +345,7 @@ class MidEndPipelineWorkflow: selected_asset_id=face_asset_id, ), task_queue=IMAGE_PIPELINE_POST_PROCESS_TASK_QUEUE, - start_to_close_timeout=ACTIVITY_TIMEOUT, + start_to_close_timeout=activity_timeout_for_task_queue(IMAGE_PIPELINE_POST_PROCESS_TASK_QUEUE), retry_policy=ACTIVITY_RETRY_POLICY, ) @@ -334,7 +361,7 @@ class MidEndPipelineWorkflow: source_asset_id=source_asset_id, ), task_queue=IMAGE_PIPELINE_QC_TASK_QUEUE, - start_to_close_timeout=ACTIVITY_TIMEOUT, + start_to_close_timeout=activity_timeout_for_task_queue(IMAGE_PIPELINE_QC_TASK_QUEUE), retry_policy=ACTIVITY_RETRY_POLICY, ) @@ -355,6 +382,6 @@ class MidEndPipelineWorkflow: message=message, ), task_queue=IMAGE_PIPELINE_CONTROL_TASK_QUEUE, - start_to_close_timeout=ACTIVITY_TIMEOUT, + start_to_close_timeout=DEFAULT_ACTIVITY_TIMEOUT, retry_policy=ACTIVITY_RETRY_POLICY, ) diff --git a/app/workers/workflows/timeout_policy.py b/app/workers/workflows/timeout_policy.py new file mode 100644 index 0000000..d3cccc7 --- /dev/null +++ b/app/workers/workflows/timeout_policy.py @@ -0,0 +1,24 @@ +"""Shared activity timeout policy for Temporal workflows.""" + +from datetime import timedelta + +from app.infra.temporal.task_queues import ( + IMAGE_PIPELINE_IMAGE_GEN_TASK_QUEUE, + IMAGE_PIPELINE_POST_PROCESS_TASK_QUEUE, +) + +DEFAULT_ACTIVITY_TIMEOUT = timedelta(seconds=30) 
+LONG_RUNNING_ACTIVITY_TIMEOUT = timedelta(minutes=5) + +LONG_RUNNING_TASK_QUEUES = { + IMAGE_PIPELINE_IMAGE_GEN_TASK_QUEUE, + IMAGE_PIPELINE_POST_PROCESS_TASK_QUEUE, +} + + +def activity_timeout_for_task_queue(task_queue: str) -> timedelta: + """Return the timeout that matches the workload behind the given task queue.""" + + if task_queue in LONG_RUNNING_TASK_QUEUES: + return LONG_RUNNING_ACTIVITY_TIMEOUT + return DEFAULT_ACTIVITY_TIMEOUT diff --git a/app/workers/workflows/types.py b/app/workers/workflows/types.py index aee6eae..8952cf7 100644 --- a/app/workers/workflows/types.py +++ b/app/workers/workflows/types.py @@ -32,14 +32,22 @@ class PipelineWorkflowInput: 这是一张订单进入 workflow 时携带的最小上下文。 """ + # 订单主键。整个 workflow 的所有步骤都会围绕这张订单落库、查状态。 order_id: int + # workflow_run 主键。用于把每一步 step 记录关联到同一次运行。 workflow_run_id: int + # 客户层级。决定业务侧订单归属,也会影响后续统计和展示。 customer_level: CustomerLevel + # 服务模式。这里直接决定启动简单流程还是复杂流程。 service_mode: ServiceMode + # 模特资源 ID。prepare_model 会基于它准备人物素材。 model_id: int - pose_id: int + # 服装资源/资产 ID。tryon 会把它作为换装输入。 garment_asset_id: int - scene_ref_asset_id: int + # 场景参考图 ID。可为空;为空时跳过 scene 步骤。 + scene_ref_asset_id: int | None = None + # 模特姿势 ID。当前 MVP 已经放宽为可空,后续需要姿势控制时再启用。 + pose_id: int | None = None def __post_init__(self) -> None: """在反序列化后把枚举字段修正回来。""" @@ -59,15 +67,25 @@ class StepActivityInput: - 上一步产出的 asset_id """ + # 当前步骤属于哪张订单。 order_id: int + # 当前步骤属于哪次 workflow 运行。 workflow_run_id: int + # 当前执行的步骤名,例如 tryon / qc / export。 step_name: WorkflowStepName + # 模特资源 ID。只有 prepare_model 等少数步骤会直接消费它。 model_id: int | None = None + # 姿势 ID。当前大多为透传预留字段。 pose_id: int | None = None + # 服装资源/资产 ID。tryon 是主要消费者。 garment_asset_id: int | None = None + # 场景参考图 ID。scene 步骤会用它完成换背景。 scene_ref_asset_id: int | None = None + # 上一步产出的资产 ID。绝大多数步骤都靠它串起资产链路。 source_asset_id: int | None = None + # 人工选中的资产 ID。review / fusion / export 等需要显式选图时使用。 selected_asset_id: int | None = None + # 额外扩展参数。给特殊步骤塞一些不值得单独建字段的临时上下文。 metadata: dict[str, Any] = 
field(default_factory=dict) def __post_init__(self) -> None: @@ -80,14 +98,23 @@ class StepActivityInput: class MockActivityResult: """通用 mock activity 返回结构。""" + # 这是哪一步返回的结果,方便 workflow 和落库代码识别来源。 step_name: WorkflowStepName + # activity 是否执行成功。一般表示“代码执行成功”,不等于业务一定通过。 success: bool + # 这一步新生成的资产 ID;如果步骤不产出资产,可以为空。 asset_id: int | None + # 产出资产的访问地址;当前 mock 阶段通常是 mock:// URI。 uri: str | None + # 模型分数/质量分数之类的数值结果,非所有步骤都有。 score: float | None = None + # 业务是否通过。典型场景是 QC:activity 成功执行,但 passed 可能为 False。 passed: bool | None = None + # 给 workflow 或 API 展示的简短结果消息。 message: str = "mock success" + # 候选资产列表。主要给 QC / review 这类“多候选图”步骤使用。 candidate_asset_ids: list[int] = field(default_factory=list) + # 补充元数据。用于把步骤内部的一些上下文回传给调用方。 metadata: dict[str, Any] = field(default_factory=dict) def __post_init__(self) -> None: @@ -100,9 +127,13 @@ class MockActivityResult: class ReviewSignalPayload: """API 发给中端 workflow 的审核 signal 载荷。""" + # 审核动作:通过 / 拒绝 / 从某一步重跑。 decision: ReviewDecision + # 操作审核的用户 ID,便于审计和任务归属。 reviewer_id: int + # 审核人最终选中的候选资产 ID;approve 时最常见。 selected_asset_id: int | None = None + # 审核备注,例如“脸部不自然,重跑融合”。 comment: str | None = None def __post_init__(self) -> None: @@ -115,9 +146,13 @@ class ReviewSignalPayload: class ReviewWaitActivityInput: """把流程切到 waiting_review 时传给 activity 的输入。""" + # 当前审核任务属于哪张订单。 order_id: int + # 当前审核任务属于哪次 workflow 运行。 workflow_run_id: int + # 提交给审核端看的候选资产列表。 candidate_asset_ids: list[int] = field(default_factory=list) + # 进入审核态时附带的说明文案,当前大多留空。 comment: str | None = None @@ -125,11 +160,17 @@ class ReviewWaitActivityInput: class ReviewResolutionActivityInput: """审核结果到达后,用于结束 waiting_review 的输入。""" + # 当前审核结果属于哪张订单。 order_id: int + # 当前审核结果属于哪次 workflow 运行。 workflow_run_id: int + # 审核最终决策。 decision: ReviewDecision + # 处理这次审核的审核人 ID。 reviewer_id: int + # 如果审核人明确挑了一张图,这里记录最终选择的资产 ID。 selected_asset_id: int | None = None + # 审核备注或重跑原因。 comment: str | None = None def __post_init__(self) -> None: @@ -142,10 +183,15 @@ class 
ReviewResolutionActivityInput: class WorkflowFailureActivityInput: """流程失败收尾 activity 的输入。""" + # 失败发生在哪张订单。 order_id: int + # 失败发生在哪次 workflow 运行。 workflow_run_id: int + # 失败停留在哪个步骤,用于更新 workflow_run.current_step。 current_step: WorkflowStepName + # 失败原因文本,通常来自异常或 QC 驳回消息。 message: str + # 要写回数据库的最终状态,默认就是 failed。 status: OrderStatus = OrderStatus.FAILED def __post_init__(self) -> None: diff --git a/pyproject.toml b/pyproject.toml index 4f3a2e3..b15a7e2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,6 +11,7 @@ requires-python = ">=3.11" dependencies = [ "aiosqlite>=0.20,<1.0", "alembic>=1.13,<2.0", + "boto3>=1.35,<2.0", "fastapi>=0.115,<1.0", "greenlet>=3.1,<4.0", "httpx>=0.27,<1.0", diff --git a/tests/conftest.py b/tests/conftest.py index c462711..cf0c9dd 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -20,6 +20,13 @@ async def api_runtime(tmp_path, monkeypatch): db_path = tmp_path / "test.db" monkeypatch.setenv("DATABASE_URL", f"sqlite+aiosqlite:///{db_path.as_posix()}") monkeypatch.setenv("AUTO_CREATE_TABLES", "true") + monkeypatch.setenv("S3_ACCESS_KEY", "test-access") + monkeypatch.setenv("S3_SECRET_KEY", "test-secret") + monkeypatch.setenv("S3_BUCKET", "test-bucket") + monkeypatch.setenv("S3_REGION", "ap-southeast-1") + monkeypatch.setenv("S3_ENDPOINT", "https://s3.example.com") + monkeypatch.setenv("S3_CNAME", "images.example.com") + monkeypatch.setenv("IMAGE_GENERATION_PROVIDER", "mock") get_settings.cache_clear() await dispose_database() @@ -41,4 +48,3 @@ async def api_runtime(tmp_path, monkeypatch): set_temporal_client(None) await dispose_database() get_settings.cache_clear() - diff --git a/tests/test_api.py b/tests/test_api.py index c0896d7..bd4983b 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -39,15 +39,15 @@ async def wait_for_step_count(client, order_id: int, step_name: str, minimum_cou async def create_mid_end_order(client): """Create a standard semi-pro order for review-path tests.""" + resources = await 
create_workflow_resources(client, include_scene=True) response = await client.post( "/api/v1/orders", json={ "customer_level": "mid", "service_mode": "semi_pro", - "model_id": 101, - "pose_id": 3, - "garment_asset_id": 9001, - "scene_ref_asset_id": 8001, + "model_id": resources["model"]["id"], + "garment_asset_id": resources["garment"]["id"], + "scene_ref_asset_id": resources["scene"]["id"], }, ) @@ -55,6 +55,115 @@ async def create_mid_end_order(client): return response.json() +async def create_library_resource(client, payload: dict) -> dict: + """Create a resource-library item for workflow integration tests.""" + + response = await client.post("/api/v1/library/resources", json=payload) + assert response.status_code == 201 + return response.json() + + +async def create_workflow_resources(client, *, include_scene: bool) -> dict[str, dict]: + """Create a minimal set of real library resources for workflow tests.""" + + model_resource = await create_library_resource( + client, + { + "resource_type": "model", + "name": "Ava Studio", + "description": "棚拍女模特", + "tags": ["女装", "棚拍"], + "gender": "female", + "age_group": "adult", + "files": [ + { + "file_role": "original", + "storage_key": "library/models/ava/original.png", + "public_url": "https://images.example.com/library/models/ava/original.png", + "mime_type": "image/png", + "size_bytes": 1024, + "sort_order": 0, + "width": 1200, + "height": 1600, + }, + { + "file_role": "thumbnail", + "storage_key": "library/models/ava/thumb.png", + "public_url": "https://images.example.com/library/models/ava/thumb.png", + "mime_type": "image/png", + "size_bytes": 256, + "sort_order": 0, + }, + ], + }, + ) + garment_resource = await create_library_resource( + client, + { + "resource_type": "garment", + "name": "Cream Dress", + "description": "米白色连衣裙", + "tags": ["女装"], + "category": "dress", + "files": [ + { + "file_role": "original", + "storage_key": "library/garments/cream-dress/original.png", + "public_url": 
"https://images.example.com/library/garments/cream-dress/original.png", + "mime_type": "image/png", + "size_bytes": 1024, + "sort_order": 0, + }, + { + "file_role": "thumbnail", + "storage_key": "library/garments/cream-dress/thumb.png", + "public_url": "https://images.example.com/library/garments/cream-dress/thumb.png", + "mime_type": "image/png", + "size_bytes": 256, + "sort_order": 0, + }, + ], + }, + ) + + resources = { + "model": model_resource, + "garment": garment_resource, + } + + if include_scene: + resources["scene"] = await create_library_resource( + client, + { + "resource_type": "scene", + "name": "Loft Window", + "description": "暖调室内场景", + "tags": ["室内"], + "environment": "indoor", + "files": [ + { + "file_role": "original", + "storage_key": "library/scenes/loft/original.png", + "public_url": "https://images.example.com/library/scenes/loft/original.png", + "mime_type": "image/png", + "size_bytes": 1024, + "sort_order": 0, + }, + { + "file_role": "thumbnail", + "storage_key": "library/scenes/loft/thumb.png", + "public_url": "https://images.example.com/library/scenes/loft/thumb.png", + "mime_type": "image/png", + "size_bytes": 256, + "sort_order": 0, + }, + ], + }, + ) + + return resources + + @pytest.mark.asyncio async def test_healthcheck(api_runtime): """The health endpoint should always respond successfully.""" @@ -71,15 +180,15 @@ async def test_low_end_order_completes(api_runtime): """Low-end orders should run through the full automated pipeline.""" client, env = api_runtime + resources = await create_workflow_resources(client, include_scene=True) response = await client.post( "/api/v1/orders", json={ "customer_level": "low", "service_mode": "auto_basic", - "model_id": 101, - "pose_id": 3, - "garment_asset_id": 9001, - "scene_ref_asset_id": 8001, + "model_id": resources["model"]["id"], + "garment_asset_id": resources["garment"]["id"], + "scene_ref_asset_id": resources["scene"]["id"], }, ) @@ -105,6 +214,160 @@ async def 
test_low_end_order_completes(api_runtime): assert workflow_response.json()["workflow_status"] == "succeeded" +@pytest.mark.asyncio +async def test_prepare_model_step_uses_library_original_file(api_runtime): + """prepare_model should use the model resource's original file rather than a mock URI.""" + + def expected_input_snapshot(resource: dict) -> dict: + original_file = next(file for file in resource["files"] if file["file_role"] == "original") + snapshot = { + "resource_id": resource["id"], + "resource_name": resource["name"], + "original_file_id": original_file["id"], + "original_url": resource["original_url"], + "mime_type": original_file["mime_type"], + "width": original_file["width"], + "height": original_file["height"], + } + return {key: value for key, value in snapshot.items() if value is not None} + + client, env = api_runtime + resources = await create_workflow_resources(client, include_scene=True) + model_resource = resources["model"] + garment_resource = resources["garment"] + scene_resource = resources["scene"] + + response = await client.post( + "/api/v1/orders", + json={ + "customer_level": "low", + "service_mode": "auto_basic", + "model_id": model_resource["id"], + "garment_asset_id": garment_resource["id"], + "scene_ref_asset_id": scene_resource["id"], + }, + ) + + assert response.status_code == 201 + payload = response.json() + + handle = env.client.get_workflow_handle(payload["workflow_id"]) + result = await handle.result() + assert result["status"] == "succeeded" + + assets_response = await client.get(f"/api/v1/orders/{payload['order_id']}/assets") + assert assets_response.status_code == 200 + prepared_asset = next( + asset for asset in assets_response.json() if asset["asset_type"] == "prepared_model" + ) + assert prepared_asset["uri"] == model_resource["original_url"] + assert prepared_asset["metadata_json"]["library_resource_id"] == model_resource["id"] + assert prepared_asset["metadata_json"]["library_original_file_id"] == next( + file["id"] 
for file in model_resource["files"] if file["file_role"] == "original" + ) + assert prepared_asset["metadata_json"]["library_original_url"] == model_resource["original_url"] + assert prepared_asset["metadata_json"]["model_input"] == expected_input_snapshot(model_resource) + assert prepared_asset["metadata_json"]["garment_input"] == expected_input_snapshot(garment_resource) + assert prepared_asset["metadata_json"]["scene_input"] == expected_input_snapshot(scene_resource) + + +@pytest.mark.asyncio +async def test_low_end_order_completes_without_scene(api_runtime): + """Low-end orders should still complete when scene input is omitted.""" + + client, env = api_runtime + resources = await create_workflow_resources(client, include_scene=False) + response = await client.post( + "/api/v1/orders", + json={ + "customer_level": "low", + "service_mode": "auto_basic", + "model_id": resources["model"]["id"], + "garment_asset_id": resources["garment"]["id"], + }, + ) + + assert response.status_code == 201 + payload = response.json() + + handle = env.client.get_workflow_handle(payload["workflow_id"]) + result = await handle.result() + + assert result["status"] == "succeeded" + + order_response = await client.get(f"/api/v1/orders/{payload['order_id']}") + assert order_response.status_code == 200 + assert order_response.json()["scene_ref_asset_id"] is None + assert order_response.json()["status"] == "succeeded" + + +@pytest.mark.asyncio +async def test_low_end_order_reuses_tryon_uri_for_qc_and_final_without_scene(api_runtime): + """Without scene input, qc/export should keep the real try-on image URI.""" + + client, env = api_runtime + resources = await create_workflow_resources(client, include_scene=False) + response = await client.post( + "/api/v1/orders", + json={ + "customer_level": "low", + "service_mode": "auto_basic", + "model_id": resources["model"]["id"], + "garment_asset_id": resources["garment"]["id"], + }, + ) + + assert response.status_code == 201 + payload = 
response.json() + + handle = env.client.get_workflow_handle(payload["workflow_id"]) + result = await handle.result() + assert result["status"] == "succeeded" + + assets_response = await client.get(f"/api/v1/orders/{payload['order_id']}/assets") + assert assets_response.status_code == 200 + assets = assets_response.json() + + tryon_asset = next(asset for asset in assets if asset["asset_type"] == "tryon") + qc_asset = next(asset for asset in assets if asset["asset_type"] == "qc_candidate") + final_asset = next(asset for asset in assets if asset["asset_type"] == "final") + + assert qc_asset["uri"] == tryon_asset["uri"] + assert final_asset["uri"] == tryon_asset["uri"] + assert final_asset["metadata_json"]["source_asset_id"] == qc_asset["id"] + +@pytest.mark.asyncio +async def test_mid_end_order_waits_review_then_approves_without_scene(api_runtime): + """Mid-end orders should still reach review and approve when scene input is omitted.""" + + client, env = api_runtime + resources = await create_workflow_resources(client, include_scene=False) + response = await client.post( + "/api/v1/orders", + json={ + "customer_level": "mid", + "service_mode": "semi_pro", + "model_id": resources["model"]["id"], + "garment_asset_id": resources["garment"]["id"], + }, + ) + + assert response.status_code == 201 + payload = response.json() + + await wait_for_workflow_status(client, payload["order_id"], "waiting_review") + + review_response = await client.post( + f"/api/v1/reviews/{payload['order_id']}/submit", + json={"decision": "approve", "reviewer_id": 77, "comment": "通过"}, + ) + assert review_response.status_code == 200 + + handle = env.client.get_workflow_handle(payload["workflow_id"]) + result = await handle.result() + assert result["status"] == "succeeded" + + @pytest.mark.asyncio async def test_mid_end_order_waits_review_then_approves(api_runtime): """Mid-end orders should pause for review and continue after approval.""" @@ -133,6 +396,162 @@ async def 
test_mid_end_order_waits_review_then_approves(api_runtime): assert order_response.json()["status"] == "succeeded" +@pytest.mark.asyncio +async def test_library_upload_presign_returns_direct_upload_metadata(api_runtime): + """Library upload presign should return S3 direct-upload metadata and a public URL.""" + + client, _ = api_runtime + response = await client.post( + "/api/v1/library/uploads/presign", + json={ + "resource_type": "model", + "file_role": "original", + "file_name": "ava.png", + "content_type": "image/png", + }, + ) + + assert response.status_code == 200 + payload = response.json() + assert payload["method"] == "PUT" + assert payload["storage_key"].startswith("library/models/") + assert payload["public_url"].startswith("https://") + assert "library/models/" in payload["upload_url"] + + +@pytest.mark.asyncio +async def test_library_resource_can_be_created_and_listed(api_runtime): + """Creating a library resource should persist original, thumbnail, and gallery files.""" + + client, _ = api_runtime + create_response = await client.post( + "/api/v1/library/resources", + json={ + "resource_type": "model", + "name": "Ava Studio", + "description": "棚拍女模特", + "tags": ["女装", "棚拍"], + "gender": "female", + "age_group": "adult", + "files": [ + { + "file_role": "original", + "storage_key": "library/models/ava/original.png", + "public_url": "https://images.marcusd.me/library/models/ava/original.png", + "mime_type": "image/png", + "size_bytes": 1024, + "sort_order": 0, + }, + { + "file_role": "thumbnail", + "storage_key": "library/models/ava/thumb.png", + "public_url": "https://images.marcusd.me/library/models/ava/thumb.png", + "mime_type": "image/png", + "size_bytes": 256, + "sort_order": 0, + }, + { + "file_role": "gallery", + "storage_key": "library/models/ava/gallery-1.png", + "public_url": "https://images.marcusd.me/library/models/ava/gallery-1.png", + "mime_type": "image/png", + "size_bytes": 2048, + "sort_order": 1, + }, + ], + }, + ) + + assert 
create_response.status_code == 201 + created = create_response.json() + assert created["resource_type"] == "model" + assert created["pose_id"] is None + assert created["cover_url"] == "https://images.marcusd.me/library/models/ava/thumb.png" + assert created["original_url"] == "https://images.marcusd.me/library/models/ava/original.png" + assert len(created["files"]) == 3 + + list_response = await client.get("/api/v1/library/resources", params={"resource_type": "model"}) + assert list_response.status_code == 200 + listing = list_response.json() + assert listing["total"] == 1 + assert listing["items"][0]["name"] == "Ava Studio" + assert listing["items"][0]["gender"] == "female" + assert listing["items"][0]["pose_id"] is None + + +@pytest.mark.asyncio +async def test_library_resource_list_supports_type_specific_filters(api_runtime): + """Library listing should support resource-type-specific filter fields.""" + + client, _ = api_runtime + + payloads = [ + { + "resource_type": "scene", + "name": "Loft Window", + "description": "暖调室内场景", + "tags": ["室内"], + "environment": "indoor", + "files": [ + { + "file_role": "original", + "storage_key": "library/scenes/loft/original.png", + "public_url": "https://images.marcusd.me/library/scenes/loft/original.png", + "mime_type": "image/png", + "size_bytes": 1024, + "sort_order": 0, + }, + { + "file_role": "thumbnail", + "storage_key": "library/scenes/loft/thumb.png", + "public_url": "https://images.marcusd.me/library/scenes/loft/thumb.png", + "mime_type": "image/png", + "size_bytes": 256, + "sort_order": 0, + }, + ], + }, + { + "resource_type": "scene", + "name": "Garden Walk", + "description": "自然光室外场景", + "tags": ["室外"], + "environment": "outdoor", + "files": [ + { + "file_role": "original", + "storage_key": "library/scenes/garden/original.png", + "public_url": "https://images.marcusd.me/library/scenes/garden/original.png", + "mime_type": "image/png", + "size_bytes": 1024, + "sort_order": 0, + }, + { + "file_role": "thumbnail", + 
"storage_key": "library/scenes/garden/thumb.png", + "public_url": "https://images.marcusd.me/library/scenes/garden/thumb.png", + "mime_type": "image/png", + "size_bytes": 256, + "sort_order": 0, + }, + ], + }, + ] + + for payload in payloads: + response = await client.post("/api/v1/library/resources", json=payload) + assert response.status_code == 201 + + indoor_response = await client.get( + "/api/v1/library/resources", + params={"resource_type": "scene", "environment": "indoor"}, + ) + assert indoor_response.status_code == 200 + indoor_payload = indoor_response.json() + assert indoor_payload["total"] == 1 + assert indoor_payload["items"][0]["name"] == "Loft Window" + + @pytest.mark.asyncio @pytest.mark.parametrize( ("decision", "expected_step"), @@ -312,15 +731,16 @@ async def test_orders_list_returns_recent_orders_with_revision_summary(api_runti client, env = api_runtime + low_resources = await create_workflow_resources(client, include_scene=True) low_order = await client.post( "/api/v1/orders", json={ "customer_level": "low", "service_mode": "auto_basic", - "model_id": 201, + "model_id": low_resources["model"]["id"], "pose_id": 11, - "garment_asset_id": 9101, - "scene_ref_asset_id": 8101, + "garment_asset_id": low_resources["garment"]["id"], + "scene_ref_asset_id": low_resources["scene"]["id"], }, ) assert low_order.status_code == 201 @@ -394,15 +814,16 @@ async def test_workflows_list_returns_recent_runs_with_failure_count(api_runtime client, env = api_runtime + low_resources = await create_workflow_resources(client, include_scene=True) low_order = await client.post( "/api/v1/orders", json={ "customer_level": "low", "service_mode": "auto_basic", - "model_id": 301, + "model_id": low_resources["model"]["id"], "pose_id": 21, - "garment_asset_id": 9201, - "scene_ref_asset_id": 8201, + "garment_asset_id": low_resources["garment"]["id"], + "scene_ref_asset_id": low_resources["scene"]["id"], }, ) assert low_order.status_code == 201 diff --git 
a/tests/test_gemini_provider.py b/tests/test_gemini_provider.py new file mode 100644 index 0000000..0968fb6 --- /dev/null +++ b/tests/test_gemini_provider.py @@ -0,0 +1,150 @@ +"""Tests for the Gemini image provider.""" + +from __future__ import annotations + +import base64 + +import httpx +import pytest + +from app.infra.image_generation.base import SourceImage +from app.infra.image_generation.gemini_provider import GeminiImageProvider + + +@pytest.mark.asyncio +async def test_gemini_provider_uses_configured_base_url_and_model(): + """The provider should call the configured endpoint and decode the returned image bytes.""" + + expected_url = "https://gemini.example.com/custom/models/gemini-test-model:generateContent" + png_bytes = b"generated-png" + + async def handler(request: httpx.Request) -> httpx.Response: + assert str(request.url) == expected_url + assert request.headers["x-goog-api-key"] == "test-key" + payload = request.read().decode("utf-8") + assert "gemini-test-model" not in payload + assert "Dress the person in the garment" in payload + assert "cGVyc29uLWJ5dGVz" in payload + assert "Z2FybWVudC1ieXRlcw==" in payload + return httpx.Response( + 200, + json={ + "candidates": [ + { + "content": { + "parts": [ + {"text": "done"}, + { + "inlineData": { + "mimeType": "image/png", + "data": base64.b64encode(png_bytes).decode("utf-8"), + } + }, + ] + } + } + ] + }, + ) + + provider = GeminiImageProvider( + api_key="test-key", + base_url="https://gemini.example.com/custom", + model="gemini-test-model", + timeout_seconds=5, + http_client=httpx.AsyncClient(transport=httpx.MockTransport(handler)), + ) + + try: + result = await provider.generate_tryon_image( + prompt="Dress the person in the garment", + person_image=SourceImage( + url="https://images.example.com/person.png", + mime_type="image/png", + data=b"person-bytes", + ), + garment_image=SourceImage( + url="https://images.example.com/garment.png", + mime_type="image/png", + data=b"garment-bytes", + ), + ) + 
finally: + await provider._http_client.aclose() + + assert result.image_bytes == png_bytes + assert result.mime_type == "image/png" + assert result.provider == "gemini" + assert result.model == "gemini-test-model" + + +@pytest.mark.asyncio +async def test_gemini_provider_retries_remote_disconnect_and_uses_request_timeout(): + """The provider should retry transient transport failures and pass timeout per request.""" + + expected_url = "https://gemini.example.com/custom/models/gemini-test-model:generateContent" + jpeg_bytes = b"generated-jpeg" + + class FakeClient: + def __init__(self) -> None: + self.attempts = 0 + self.timeouts: list[int | None] = [] + + async def post(self, url: str, **kwargs) -> httpx.Response: + self.attempts += 1 + self.timeouts.append(kwargs.get("timeout")) + assert url == expected_url + if self.attempts == 1: + raise httpx.RemoteProtocolError("Server disconnected without sending a response.") + return httpx.Response( + 200, + json={ + "candidates": [ + { + "content": { + "parts": [ + { + "inlineData": { + "mimeType": "image/jpeg", + "data": base64.b64encode(jpeg_bytes).decode("utf-8"), + } + } + ] + } + } + ] + }, + request=httpx.Request("POST", url), + ) + + async def aclose(self) -> None: + return None + + fake_client = FakeClient() + + provider = GeminiImageProvider( + api_key="test-key", + base_url="https://gemini.example.com/custom", + model="gemini-test-model", + timeout_seconds=300, + http_client=fake_client, + ) + + result = await provider.generate_tryon_image( + prompt="Dress the person in the garment", + person_image=SourceImage( + url="https://images.example.com/person.png", + mime_type="image/png", + data=b"person-bytes", + ), + garment_image=SourceImage( + url="https://images.example.com/garment.png", + mime_type="image/png", + data=b"garment-bytes", + ), + ) + + assert fake_client.attempts == 2 + assert fake_client.timeouts == [300, 300] + assert result.image_bytes == jpeg_bytes + assert result.mime_type == "image/jpeg" diff 
--git a/tests/test_image_generation_service.py b/tests/test_image_generation_service.py new file mode 100644 index 0000000..c6d0e5c --- /dev/null +++ b/tests/test_image_generation_service.py @@ -0,0 +1,91 @@ +"""Tests for application-level image-generation orchestration.""" + +from __future__ import annotations + +import httpx +import pytest + +from app.application.services.image_generation_service import ImageGenerationService +from app.infra.image_generation.base import GeneratedImageResult, SourceImage + + +class FakeProvider: + def __init__(self) -> None: + self.calls: list[tuple[str, str, str]] = [] + + async def generate_tryon_image( + self, + *, + prompt: str, + person_image: SourceImage, + garment_image: SourceImage, + ) -> GeneratedImageResult: + self.calls.append(("tryon", person_image.url, garment_image.url)) + return GeneratedImageResult( + image_bytes=b"tryon", + mime_type="image/png", + provider="gemini", + model="gemini-test", + prompt=prompt, + ) + + async def generate_scene_image( + self, + *, + prompt: str, + source_image: SourceImage, + scene_image: SourceImage, + ) -> GeneratedImageResult: + self.calls.append(("scene", source_image.url, scene_image.url)) + return GeneratedImageResult( + image_bytes=b"scene", + mime_type="image/jpeg", + provider="gemini", + model="gemini-test", + prompt=prompt, + ) + + +@pytest.mark.asyncio +async def test_image_generation_service_downloads_inputs_for_tryon_and_scene(monkeypatch): + """The service should download both inputs and dispatch them to the matching provider method.""" + + responses = { + "https://images.example.com/person.png": (b"person-bytes", "image/png"), + "https://images.example.com/garment.png": (b"garment-bytes", "image/png"), + "https://images.example.com/source.jpg": (b"source-bytes", "image/jpeg"), + "https://images.example.com/scene.jpg": (b"scene-bytes", "image/jpeg"), + } + + async def handler(request: httpx.Request) -> httpx.Response: + body, mime = responses[str(request.url)] + return 
httpx.Response(200, content=body, headers={"content-type": mime}, request=request) + + monkeypatch.setenv("GEMINI_API_KEY", "test-key") + + from app.config.settings import get_settings + + get_settings.cache_clear() + provider = FakeProvider() + downloader = httpx.AsyncClient(transport=httpx.MockTransport(handler)) + service = ImageGenerationService(provider=provider, downloader=downloader) + + try: + tryon = await service.generate_tryon_image( + person_image_url="https://images.example.com/person.png", + garment_image_url="https://images.example.com/garment.png", + ) + scene = await service.generate_scene_image( + source_image_url="https://images.example.com/source.jpg", + scene_image_url="https://images.example.com/scene.jpg", + ) + finally: + await downloader.aclose() + get_settings.cache_clear() + + assert provider.calls == [ + ("tryon", "https://images.example.com/person.png", "https://images.example.com/garment.png"), + ("scene", "https://images.example.com/source.jpg", "https://images.example.com/scene.jpg"), + ] + assert tryon.image_bytes == b"tryon" + assert scene.image_bytes == b"scene" diff --git a/tests/test_library_resource_actions.py b/tests/test_library_resource_actions.py new file mode 100644 index 0000000..c0ab09f --- /dev/null +++ b/tests/test_library_resource_actions.py @@ -0,0 +1,63 @@ +import pytest + + +@pytest.mark.asyncio +async def test_library_resource_can_be_updated_and_archived(api_runtime): + client, _ = api_runtime + + create_response = await client.post( + "/api/v1/library/resources", + json={ + "resource_type": "model", + "name": "Ava Studio", + "description": "棚拍女模特", + "tags": ["女装", "棚拍"], + "gender": "female", + "age_group": "adult", + "files": [ + { + "file_role": "original", + "storage_key": "library/models/ava/original.png", + "public_url": "https://images.marcusd.me/library/models/ava/original.png", + "mime_type": "image/png", + "size_bytes": 1024, + "sort_order": 0, + }, + { + "file_role": "thumbnail", + "storage_key": 
"library/models/ava/thumb.png", + "public_url": "https://images.marcusd.me/library/models/ava/thumb.png", + "mime_type": "image/png", + "size_bytes": 256, + "sort_order": 0, + }, + ], + }, + ) + + assert create_response.status_code == 201 + resource = create_response.json() + + update_response = await client.patch( + f"/api/v1/library/resources/{resource['id']}", + json={ + "name": "Ava Studio Updated", + "description": "新的描述", + "tags": ["女装", "更新"], + }, + ) + + assert update_response.status_code == 200 + updated = update_response.json() + assert updated["name"] == "Ava Studio Updated" + assert updated["description"] == "新的描述" + assert updated["tags"] == ["女装", "更新"] + + archive_response = await client.delete(f"/api/v1/library/resources/{resource['id']}") + + assert archive_response.status_code == 200 + assert archive_response.json()["id"] == resource["id"] + + list_response = await client.get("/api/v1/library/resources", params={"resource_type": "model"}) + assert list_response.status_code == 200 + assert list_response.json()["items"] == [] diff --git a/tests/test_settings.py b/tests/test_settings.py new file mode 100644 index 0000000..f335392 --- /dev/null +++ b/tests/test_settings.py @@ -0,0 +1,15 @@ +"""Tests for application settings resolution.""" + +from pathlib import Path + +from app.config.settings import Settings + + +def test_settings_env_file_uses_backend_repo_path() -> None: + """Settings should resolve .env relative to the backend repo, not the launch cwd.""" + + env_file = Settings.model_config.get("env_file") + + assert isinstance(env_file, Path) + assert env_file.is_absolute() + assert env_file.name == ".env" diff --git a/tests/test_tryon_activity.py b/tests/test_tryon_activity.py new file mode 100644 index 0000000..39681e0 --- /dev/null +++ b/tests/test_tryon_activity.py @@ -0,0 +1,298 @@ +"""Focused tests for the try-on activity implementation.""" + +from __future__ import annotations + +from dataclasses import dataclass + +import pytest + 
+from app.domain.enums import AssetType, LibraryFileRole, LibraryResourceStatus, LibraryResourceType, OrderStatus, WorkflowStepName +from app.infra.db.models.asset import AssetORM +from app.infra.db.models.library_resource import LibraryResourceORM +from app.infra.db.models.library_resource_file import LibraryResourceFileORM +from app.infra.db.models.order import OrderORM +from app.infra.db.models.workflow_run import WorkflowRunORM +from app.infra.db.session import get_session_factory +from app.workers.activities import tryon_activities +from app.workers.workflows.types import StepActivityInput + + +@dataclass(slots=True) +class FakeGeneratedImage: + image_bytes: bytes + mime_type: str + provider: str + model: str + prompt: str + + +class FakeImageGenerationService: + async def generate_tryon_image(self, *, person_image_url: str, garment_image_url: str) -> FakeGeneratedImage: + assert person_image_url == "https://images.example.com/orders/1/prepared-model.png" + assert garment_image_url == "https://images.example.com/library/garments/cream-dress/original.png" + return FakeGeneratedImage( + image_bytes=b"fake-png-binary", + mime_type="image/png", + provider="gemini", + model="gemini-test-image", + prompt="test prompt", + ) + + async def generate_scene_image(self, *, source_image_url: str, scene_image_url: str) -> FakeGeneratedImage: + assert source_image_url == "https://images.example.com/orders/1/tryon/generated.png" + assert scene_image_url == "https://images.example.com/library/scenes/studio/original.png" + return FakeGeneratedImage( + image_bytes=b"fake-scene-binary", + mime_type="image/jpeg", + provider="gemini", + model="gemini-test-image", + prompt="scene prompt", + ) + + +class FakeOrderArtifactStorageService: + async def upload_generated_image( + self, + *, + order_id: int, + step_name: WorkflowStepName, + image_bytes: bytes, + mime_type: str, + ) -> tuple[str, str]: + assert order_id == 1 + assert step_name == WorkflowStepName.TRYON + assert image_bytes == 
b"fake-png-binary" + assert mime_type == "image/png" + return ( + "orders/1/tryon/generated.png", + "https://images.example.com/orders/1/tryon/generated.png", + ) + + +class FakeSceneArtifactStorageService: + async def upload_generated_image( + self, + *, + order_id: int, + step_name: WorkflowStepName, + image_bytes: bytes, + mime_type: str, + ) -> tuple[str, str]: + assert order_id == 1 + assert step_name == WorkflowStepName.SCENE + assert image_bytes == b"fake-scene-binary" + assert mime_type == "image/jpeg" + return ( + "orders/1/scene/generated.jpg", + "https://images.example.com/orders/1/scene/generated.jpg", + ) + + +@pytest.mark.asyncio +async def test_run_tryon_activity_persists_uploaded_asset_with_provider_metadata(api_runtime, monkeypatch): + """Gemini-mode try-on should persist the uploaded output URL instead of a mock URI.""" + + monkeypatch.setattr( + tryon_activities, + "get_image_generation_service", + lambda: FakeImageGenerationService(), + ) + monkeypatch.setattr( + tryon_activities, + "get_order_artifact_storage_service", + lambda: FakeOrderArtifactStorageService(), + ) + + async with get_session_factory()() as session: + order = OrderORM( + customer_level="low", + service_mode="auto_basic", + status=OrderStatus.CREATED, + model_id=1, + garment_asset_id=2, + ) + session.add(order) + await session.flush() + + workflow_run = WorkflowRunORM( + order_id=order.id, + workflow_id=f"order-{order.id}", + workflow_type="LowEndPipelineWorkflow", + status=OrderStatus.CREATED, + ) + session.add(workflow_run) + await session.flush() + + prepared_asset = AssetORM( + order_id=order.id, + asset_type=AssetType.PREPARED_MODEL, + step_name=WorkflowStepName.PREPARE_MODEL, + uri="https://images.example.com/orders/1/prepared-model.png", + metadata_json={"library_resource_id": 1}, + ) + session.add(prepared_asset) + + garment_resource = LibraryResourceORM( + resource_type=LibraryResourceType.GARMENT, + name="Cream Dress", + description="米白色连衣裙", + tags=["女装"], + 
status=LibraryResourceStatus.ACTIVE, + category="dress", + ) + session.add(garment_resource) + await session.flush() + + garment_original = LibraryResourceFileORM( + resource_id=garment_resource.id, + file_role=LibraryFileRole.ORIGINAL, + storage_key="library/garments/cream-dress/original.png", + public_url="https://images.example.com/library/garments/cream-dress/original.png", + bucket="test-bucket", + mime_type="image/png", + size_bytes=1024, + sort_order=0, + ) + garment_thumb = LibraryResourceFileORM( + resource_id=garment_resource.id, + file_role=LibraryFileRole.THUMBNAIL, + storage_key="library/garments/cream-dress/thumb.png", + public_url="https://images.example.com/library/garments/cream-dress/thumb.png", + bucket="test-bucket", + mime_type="image/png", + size_bytes=256, + sort_order=0, + ) + session.add_all([garment_original, garment_thumb]) + await session.flush() + garment_resource.original_file_id = garment_original.id + garment_resource.cover_file_id = garment_thumb.id + await session.commit() + + payload = StepActivityInput( + order_id=1, + workflow_run_id=1, + step_name=WorkflowStepName.TRYON, + source_asset_id=1, + garment_asset_id=1, + ) + + result = await tryon_activities.run_tryon_activity(payload) + + assert result.uri == "https://images.example.com/orders/1/tryon/generated.png" + assert result.metadata["provider"] == "gemini" + assert result.metadata["model"] == "gemini-test-image" + assert result.metadata["prepared_asset_id"] == 1 + assert result.metadata["garment_resource_id"] == 1 + + async with get_session_factory()() as session: + assets = (await session.execute( + AssetORM.__table__.select().where(AssetORM.order_id == 1, AssetORM.asset_type == AssetType.TRYON) + )).mappings().all() + + assert len(assets) == 1 + assert assets[0]["uri"] == "https://images.example.com/orders/1/tryon/generated.png" + + +@pytest.mark.asyncio +async def test_run_scene_activity_persists_uploaded_asset_with_provider_metadata(api_runtime, monkeypatch): + 
"""Gemini-mode scene should persist the uploaded output URL instead of a mock URI.""" + + from app.workers.activities import scene_activities + + monkeypatch.setattr( + scene_activities, + "get_image_generation_service", + lambda: FakeImageGenerationService(), + ) + monkeypatch.setattr( + scene_activities, + "get_order_artifact_storage_service", + lambda: FakeSceneArtifactStorageService(), + ) + + async with get_session_factory()() as session: + order = OrderORM( + customer_level="low", + service_mode="auto_basic", + status=OrderStatus.CREATED, + model_id=1, + garment_asset_id=2, + scene_ref_asset_id=3, + ) + session.add(order) + await session.flush() + + workflow_run = WorkflowRunORM( + order_id=order.id, + workflow_id=f"order-{order.id}", + workflow_type="LowEndPipelineWorkflow", + status=OrderStatus.CREATED, + ) + session.add(workflow_run) + await session.flush() + + tryon_asset = AssetORM( + order_id=order.id, + asset_type=AssetType.TRYON, + step_name=WorkflowStepName.TRYON, + uri="https://images.example.com/orders/1/tryon/generated.png", + metadata_json={"prepared_asset_id": 1}, + ) + session.add(tryon_asset) + + scene_resource = LibraryResourceORM( + resource_type=LibraryResourceType.SCENE, + name="Studio Background", + description="摄影棚背景", + tags=["室内"], + status=LibraryResourceStatus.ACTIVE, + environment="indoor", + ) + session.add(scene_resource) + await session.flush() + + scene_original = LibraryResourceFileORM( + resource_id=scene_resource.id, + file_role=LibraryFileRole.ORIGINAL, + storage_key="library/scenes/studio/original.png", + public_url="https://images.example.com/library/scenes/studio/original.png", + bucket="test-bucket", + mime_type="image/png", + size_bytes=2048, + sort_order=0, + ) + session.add(scene_original) + await session.flush() + scene_resource.original_file_id = scene_original.id + scene_resource.cover_file_id = scene_original.id + await session.commit() + + payload = StepActivityInput( + order_id=1, + workflow_run_id=1, + 
step_name=WorkflowStepName.SCENE, + source_asset_id=1, + scene_ref_asset_id=1, + ) + + result = await scene_activities.run_scene_activity(payload) + + assert result.uri == "https://images.example.com/orders/1/scene/generated.jpg" + assert result.metadata["provider"] == "gemini" + assert result.metadata["model"] == "gemini-test-image" + assert result.metadata["source_asset_id"] == 1 + assert result.metadata["scene_resource_id"] == 1 + + async with get_session_factory()() as session: + assets = ( + await session.execute( + AssetORM.__table__.select().where( + AssetORM.order_id == 1, + AssetORM.asset_type == AssetType.SCENE, + ) + ) + ).mappings().all() + + assert len(assets) == 1 + assert assets[0]["uri"] == "https://images.example.com/orders/1/scene/generated.jpg" diff --git a/tests/test_workflow_timeouts.py b/tests/test_workflow_timeouts.py new file mode 100644 index 0000000..c88d07a --- /dev/null +++ b/tests/test_workflow_timeouts.py @@ -0,0 +1,33 @@ +"""Tests for workflow activity timeout policies.""" + +from datetime import timedelta + +from app.infra.temporal.task_queues import ( + IMAGE_PIPELINE_CONTROL_TASK_QUEUE, + IMAGE_PIPELINE_EXPORT_TASK_QUEUE, + IMAGE_PIPELINE_IMAGE_GEN_TASK_QUEUE, + IMAGE_PIPELINE_POST_PROCESS_TASK_QUEUE, + IMAGE_PIPELINE_QC_TASK_QUEUE, +) +from app.workers.workflows.timeout_policy import ( + DEFAULT_ACTIVITY_TIMEOUT, + LONG_RUNNING_ACTIVITY_TIMEOUT, + activity_timeout_for_task_queue, +) + + +def test_activity_timeout_for_task_queue_uses_long_timeout_for_image_work(): + """Image generation and post-processing queues should get a longer timeout.""" + + assert DEFAULT_ACTIVITY_TIMEOUT == timedelta(seconds=30) + assert LONG_RUNNING_ACTIVITY_TIMEOUT == timedelta(minutes=5) + assert activity_timeout_for_task_queue(IMAGE_PIPELINE_IMAGE_GEN_TASK_QUEUE) == LONG_RUNNING_ACTIVITY_TIMEOUT + assert activity_timeout_for_task_queue(IMAGE_PIPELINE_POST_PROCESS_TASK_QUEUE) == LONG_RUNNING_ACTIVITY_TIMEOUT + + +def 
test_activity_timeout_for_task_queue_keeps_short_timeout_for_light_steps(): + """Control, QC, and export queues should stay on the short timeout.""" + + assert activity_timeout_for_task_queue(IMAGE_PIPELINE_CONTROL_TASK_QUEUE) == DEFAULT_ACTIVITY_TIMEOUT + assert activity_timeout_for_task_queue(IMAGE_PIPELINE_QC_TASK_QUEUE) == DEFAULT_ACTIVITY_TIMEOUT + assert activity_timeout_for_task_queue(IMAGE_PIPELINE_EXPORT_TASK_QUEUE) == DEFAULT_ACTIVITY_TIMEOUT