Files
auto-virtual-tryon/tests/test_tryon_activity.py

299 lines
10 KiB
Python

"""Focused tests for the try-on and scene activity implementations."""
from __future__ import annotations
from dataclasses import dataclass
import pytest
from app.domain.enums import AssetType, LibraryFileRole, LibraryResourceStatus, LibraryResourceType, OrderStatus, WorkflowStepName
from app.infra.db.models.asset import AssetORM
from app.infra.db.models.library_resource import LibraryResourceORM
from app.infra.db.models.library_resource_file import LibraryResourceFileORM
from app.infra.db.models.order import OrderORM
from app.infra.db.models.workflow_run import WorkflowRunORM
from app.infra.db.session import get_session_factory
from app.workers.activities import tryon_activities
from app.workers.workflows.types import StepActivityInput
@dataclass(slots=True)
class FakeGeneratedImage:
image_bytes: bytes
mime_type: str
provider: str
model: str
prompt: str
class FakeImageGenerationService:
    """Deterministic replacement for the image generation service.

    Each coroutine first checks that the activity passed the URLs seeded by the
    tests, then returns a fixed ``FakeGeneratedImage`` instead of calling a
    real provider.
    """

    async def generate_tryon_image(self, *, person_image_url: str, garment_image_url: str) -> FakeGeneratedImage:
        """Verify the try-on inputs and return a canned PNG result."""
        expected_person = "https://images.example.com/orders/1/prepared-model.png"
        expected_garment = "https://images.example.com/library/garments/cream-dress/original.png"
        assert person_image_url == expected_person
        assert garment_image_url == expected_garment
        canned = FakeGeneratedImage(
            image_bytes=b"fake-png-binary",
            mime_type="image/png",
            provider="gemini",
            model="gemini-test-image",
            prompt="test prompt",
        )
        return canned

    async def generate_scene_image(self, *, source_image_url: str, scene_image_url: str) -> FakeGeneratedImage:
        """Verify the scene inputs and return a canned JPEG result."""
        expected_source = "https://images.example.com/orders/1/tryon/generated.png"
        expected_scene = "https://images.example.com/library/scenes/studio/original.png"
        assert source_image_url == expected_source
        assert scene_image_url == expected_scene
        canned = FakeGeneratedImage(
            image_bytes=b"fake-scene-binary",
            mime_type="image/jpeg",
            provider="gemini",
            model="gemini-test-image",
            prompt="scene prompt",
        )
        return canned
class FakeOrderArtifactStorageService:
    """Storage fake for the try-on test: checks the upload, returns fixed locations.

    Nothing is written anywhere; the returned pair is a path-like key and an
    https URL that the try-on test expects to be persisted on the asset.
    """

    async def upload_generated_image(
        self,
        *,
        order_id: int,
        step_name: WorkflowStepName,
        image_bytes: bytes,
        mime_type: str,
    ) -> tuple[str, str]:
        """Assert the try-on upload request and return a canned (key, URL) pair."""
        # Must match what FakeImageGenerationService.generate_tryon_image produced.
        assert order_id == 1
        assert step_name == WorkflowStepName.TRYON
        assert image_bytes == b"fake-png-binary"
        assert mime_type == "image/png"
        object_key = "orders/1/tryon/generated.png"
        public_url = "https://images.example.com/orders/1/tryon/generated.png"
        return object_key, public_url
class FakeSceneArtifactStorageService:
    """Storage fake for the scene test: checks the upload, returns fixed locations.

    Nothing is written anywhere; the returned pair is a path-like key and an
    https URL that the scene test expects to be persisted on the asset.
    """

    async def upload_generated_image(
        self,
        *,
        order_id: int,
        step_name: WorkflowStepName,
        image_bytes: bytes,
        mime_type: str,
    ) -> tuple[str, str]:
        """Assert the scene upload request and return a canned (key, URL) pair."""
        # Must match what FakeImageGenerationService.generate_scene_image produced.
        assert order_id == 1
        assert step_name == WorkflowStepName.SCENE
        assert image_bytes == b"fake-scene-binary"
        assert mime_type == "image/jpeg"
        object_key = "orders/1/scene/generated.jpg"
        public_url = "https://images.example.com/orders/1/scene/generated.jpg"
        return object_key, public_url
@pytest.mark.asyncio
async def test_run_tryon_activity_persists_uploaded_asset_with_provider_metadata(api_runtime, monkeypatch):
    """Gemini-mode try-on should persist the uploaded output URL instead of a mock URI.

    The image generation and artifact storage getters on ``tryon_activities``
    are patched with fakes. The test then seeds an order, a workflow run, a
    prepared-model asset and a garment library resource (original +
    thumbnail files), runs the activity, and checks both the returned result
    and the persisted TRYON asset row.

    Row ids are read from the flushed ORM objects instead of being hard-coded,
    so the payload and assertions always refer to the rows created here.
    NOTE(review): ``FakeOrderArtifactStorageService`` still asserts
    ``order_id == 1``, so this test still assumes an empty database —
    presumably guaranteed by the ``api_runtime`` fixture; confirm in conftest.
    """
    # The getters are called with no arguments, so the fake classes themselves
    # can serve as the replacement factories.
    monkeypatch.setattr(tryon_activities, "get_image_generation_service", FakeImageGenerationService)
    monkeypatch.setattr(tryon_activities, "get_order_artifact_storage_service", FakeOrderArtifactStorageService)

    async with get_session_factory()() as session:
        order = OrderORM(
            customer_level="low",
            service_mode="auto_basic",
            status=OrderStatus.CREATED,
            model_id=1,
            garment_asset_id=2,
        )
        session.add(order)
        await session.flush()

        workflow_run = WorkflowRunORM(
            order_id=order.id,
            workflow_id=f"order-{order.id}",
            workflow_type="LowEndPipelineWorkflow",
            status=OrderStatus.CREATED,
        )
        session.add(workflow_run)
        await session.flush()

        # Person image for the try-on step; its URI is what the fake generator
        # expects as ``person_image_url``.
        prepared_asset = AssetORM(
            order_id=order.id,
            asset_type=AssetType.PREPARED_MODEL,
            step_name=WorkflowStepName.PREPARE_MODEL,
            uri="https://images.example.com/orders/1/prepared-model.png",
            metadata_json={"library_resource_id": 1},
        )
        session.add(prepared_asset)

        garment_resource = LibraryResourceORM(
            resource_type=LibraryResourceType.GARMENT,
            name="Cream Dress",
            description="米白色连衣裙",
            tags=["女装"],
            status=LibraryResourceStatus.ACTIVE,
            category="dress",
        )
        session.add(garment_resource)
        await session.flush()

        # The ORIGINAL file's public_url is what the fake generator expects as
        # ``garment_image_url``; the thumbnail is only used as the cover.
        garment_original = LibraryResourceFileORM(
            resource_id=garment_resource.id,
            file_role=LibraryFileRole.ORIGINAL,
            storage_key="library/garments/cream-dress/original.png",
            public_url="https://images.example.com/library/garments/cream-dress/original.png",
            bucket="test-bucket",
            mime_type="image/png",
            size_bytes=1024,
            sort_order=0,
        )
        garment_thumb = LibraryResourceFileORM(
            resource_id=garment_resource.id,
            file_role=LibraryFileRole.THUMBNAIL,
            storage_key="library/garments/cream-dress/thumb.png",
            public_url="https://images.example.com/library/garments/cream-dress/thumb.png",
            bucket="test-bucket",
            mime_type="image/png",
            size_bytes=256,
            sort_order=0,
        )
        session.add_all([garment_original, garment_thumb])
        await session.flush()
        garment_resource.original_file_id = garment_original.id
        garment_resource.cover_file_id = garment_thumb.id

        # Read primary keys before commit while the session is still open. After
        # commit the instances may be expired, and once the async session closes
        # they are detached and cannot lazily reload.
        order_id = order.id
        workflow_run_id = workflow_run.id
        prepared_asset_id = prepared_asset.id
        garment_resource_id = garment_resource.id
        await session.commit()

    payload = StepActivityInput(
        order_id=order_id,
        workflow_run_id=workflow_run_id,
        step_name=WorkflowStepName.TRYON,
        source_asset_id=prepared_asset_id,
        # Here garment_asset_id carries the garment *library resource* id; the
        # result metadata below reports it back as ``garment_resource_id``.
        garment_asset_id=garment_resource_id,
    )
    result = await tryon_activities.run_tryon_activity(payload)

    assert result.uri == "https://images.example.com/orders/1/tryon/generated.png"
    assert result.metadata["provider"] == "gemini"
    assert result.metadata["model"] == "gemini-test-image"
    assert result.metadata["prepared_asset_id"] == prepared_asset_id
    assert result.metadata["garment_resource_id"] == garment_resource_id

    # Exactly one TRYON asset must have been stored, pointing at the uploaded URL.
    async with get_session_factory()() as session:
        assets = (
            await session.execute(
                AssetORM.__table__.select().where(
                    AssetORM.order_id == order_id,
                    AssetORM.asset_type == AssetType.TRYON,
                )
            )
        ).mappings().all()
    assert len(assets) == 1
    assert assets[0]["uri"] == "https://images.example.com/orders/1/tryon/generated.png"
@pytest.mark.asyncio
async def test_run_scene_activity_persists_uploaded_asset_with_provider_metadata(api_runtime, monkeypatch):
    """Gemini-mode scene should persist the uploaded output URL instead of a mock URI.

    The image generation and artifact storage getters on ``scene_activities``
    are patched with fakes. The test then seeds an order, a workflow run, a
    TRYON asset (the scene step's input) and a scene library resource, runs
    the activity, and checks both the returned result and the persisted SCENE
    asset row.

    Row ids are read from the flushed ORM objects instead of being hard-coded,
    so the payload and assertions always refer to the rows created here.
    NOTE(review): ``FakeSceneArtifactStorageService`` still asserts
    ``order_id == 1``, so this test still assumes an empty database —
    presumably guaranteed by the ``api_runtime`` fixture; confirm in conftest.
    """
    from app.workers.activities import scene_activities

    # The getters are called with no arguments, so the fake classes themselves
    # can serve as the replacement factories.
    monkeypatch.setattr(scene_activities, "get_image_generation_service", FakeImageGenerationService)
    monkeypatch.setattr(scene_activities, "get_order_artifact_storage_service", FakeSceneArtifactStorageService)

    async with get_session_factory()() as session:
        order = OrderORM(
            customer_level="low",
            service_mode="auto_basic",
            status=OrderStatus.CREATED,
            model_id=1,
            garment_asset_id=2,
            scene_ref_asset_id=3,
        )
        session.add(order)
        await session.flush()

        workflow_run = WorkflowRunORM(
            order_id=order.id,
            workflow_id=f"order-{order.id}",
            workflow_type="LowEndPipelineWorkflow",
            status=OrderStatus.CREATED,
        )
        session.add(workflow_run)
        await session.flush()

        # Input image for the scene step; its URI is what the fake generator
        # expects as ``source_image_url``.
        tryon_asset = AssetORM(
            order_id=order.id,
            asset_type=AssetType.TRYON,
            step_name=WorkflowStepName.TRYON,
            uri="https://images.example.com/orders/1/tryon/generated.png",
            metadata_json={"prepared_asset_id": 1},
        )
        session.add(tryon_asset)

        scene_resource = LibraryResourceORM(
            resource_type=LibraryResourceType.SCENE,
            name="Studio Background",
            description="摄影棚背景",
            tags=["室内"],
            status=LibraryResourceStatus.ACTIVE,
            environment="indoor",
        )
        session.add(scene_resource)
        await session.flush()

        # The ORIGINAL file's public_url is what the fake generator expects as
        # ``scene_image_url``; it doubles as the cover here.
        scene_original = LibraryResourceFileORM(
            resource_id=scene_resource.id,
            file_role=LibraryFileRole.ORIGINAL,
            storage_key="library/scenes/studio/original.png",
            public_url="https://images.example.com/library/scenes/studio/original.png",
            bucket="test-bucket",
            mime_type="image/png",
            size_bytes=2048,
            sort_order=0,
        )
        session.add(scene_original)
        await session.flush()
        scene_resource.original_file_id = scene_original.id
        scene_resource.cover_file_id = scene_original.id

        # Read primary keys before commit while the session is still open. After
        # commit the instances may be expired, and once the async session closes
        # they are detached and cannot lazily reload.
        order_id = order.id
        workflow_run_id = workflow_run.id
        tryon_asset_id = tryon_asset.id
        scene_resource_id = scene_resource.id
        await session.commit()

    payload = StepActivityInput(
        order_id=order_id,
        workflow_run_id=workflow_run_id,
        step_name=WorkflowStepName.SCENE,
        source_asset_id=tryon_asset_id,
        # Here scene_ref_asset_id carries the scene *library resource* id; the
        # result metadata below reports it back as ``scene_resource_id``.
        scene_ref_asset_id=scene_resource_id,
    )
    result = await scene_activities.run_scene_activity(payload)

    assert result.uri == "https://images.example.com/orders/1/scene/generated.jpg"
    assert result.metadata["provider"] == "gemini"
    assert result.metadata["model"] == "gemini-test-image"
    assert result.metadata["source_asset_id"] == tryon_asset_id
    assert result.metadata["scene_resource_id"] == scene_resource_id

    # Exactly one SCENE asset must have been stored, pointing at the uploaded URL.
    async with get_session_factory()() as session:
        assets = (
            await session.execute(
                AssetORM.__table__.select().where(
                    AssetORM.order_id == order_id,
                    AssetORM.asset_type == AssetType.SCENE,
                )
            )
        ).mappings().all()
    assert len(assets) == 1
    assert assets[0]["uri"] == "https://images.example.com/orders/1/scene/generated.jpg"