Files
auto-virtual-tryon/tests/test_image_generation_service.py

92 lines
3.2 KiB
Python

"""Tests for application-level image-generation orchestration."""
from __future__ import annotations
import httpx
import pytest
from app.application.services.image_generation_service import ImageGenerationService
from app.infra.image_generation.base import GeneratedImageResult, SourceImage
class FakeProvider:
    """In-memory stand-in for an image-generation provider.

    Each generation method appends ``(kind, first_url, second_url)`` to
    ``calls`` and then returns a canned ``GeneratedImageResult`` labelled
    with provider ``"gemini"`` and model ``"gemini-test"``, echoing back the
    prompt it was given.
    """

    def __init__(self) -> None:
        # Ordered log of every generation request, inspected by the tests.
        self.calls: list[tuple[str, str, str]] = []

    @staticmethod
    def _canned(payload: bytes, mime: str, prompt: str) -> GeneratedImageResult:
        """Build the fixed result shared by both fake generation methods."""
        return GeneratedImageResult(
            image_bytes=payload,
            mime_type=mime,
            provider="gemini",
            model="gemini-test",
            prompt=prompt,
        )

    async def generate_tryon_image(
        self,
        *,
        prompt: str,
        person_image: SourceImage,
        garment_image: SourceImage,
    ) -> GeneratedImageResult:
        """Log a try-on request and return PNG bytes ``b"tryon"``."""
        entry = ("tryon", person_image.url, garment_image.url)
        self.calls.append(entry)
        return self._canned(b"tryon", "image/png", prompt)

    async def generate_scene_image(
        self,
        *,
        prompt: str,
        source_image: SourceImage,
        scene_image: SourceImage,
    ) -> GeneratedImageResult:
        """Log a scene-composition request and return JPEG bytes ``b"scene"``."""
        entry = ("scene", source_image.url, scene_image.url)
        self.calls.append(entry)
        return self._canned(b"scene", "image/jpeg", prompt)
@pytest.mark.asyncio
async def test_image_generation_service_downloads_inputs_for_tryon_and_scene(monkeypatch):
    """The service should download both inputs and dispatch them to the matching provider method."""
    responses = {
        "https://images.example.com/person.png": (b"person-bytes", "image/png"),
        "https://images.example.com/garment.png": (b"garment-bytes", "image/png"),
        "https://images.example.com/source.jpg": (b"source-bytes", "image/jpeg"),
        "https://images.example.com/scene.jpg": (b"scene-bytes", "image/jpeg"),
    }
    # Every URL the mock transport serves. Without this log the test could not
    # tell a service that downloads its inputs from one that only forwards URLs.
    requested: list[str] = []

    async def handler(request: httpx.Request) -> httpx.Response:
        url = str(request.url)
        requested.append(url)
        # An unexpected URL raises KeyError here and fails the test loudly.
        body, mime = responses[url]
        return httpx.Response(200, content=body, headers={"content-type": mime}, request=request)

    monkeypatch.setenv("GEMINI_API_KEY", "test-key")
    from app.config.settings import get_settings

    # Drop any cached settings so the patched API key is picked up.
    get_settings.cache_clear()
    provider = FakeProvider()
    try:
        # `async with` closes the client even if building the service fails.
        async with httpx.AsyncClient(transport=httpx.MockTransport(handler)) as downloader:
            service = ImageGenerationService(provider=provider, downloader=downloader)
            tryon = await service.generate_tryon_image(
                person_image_url="https://images.example.com/person.png",
                garment_image_url="https://images.example.com/garment.png",
            )
            scene = await service.generate_scene_image(
                source_image_url="https://images.example.com/source.jpg",
                scene_image_url="https://images.example.com/scene.jpg",
            )
    finally:
        # Don't leak settings built from the patched environment into later tests.
        get_settings.cache_clear()

    # Downloads may run concurrently, so compare as a set rather than by order.
    assert set(requested) == set(responses)
    assert provider.calls == [
        ("tryon", "https://images.example.com/person.png", "https://images.example.com/garment.png"),
        ("scene", "https://images.example.com/source.jpg", "https://images.example.com/scene.jpg"),
    ]
    assert tryon.image_bytes == b"tryon"
    assert scene.image_bytes == b"scene"