92 lines
3.2 KiB
Python
92 lines
3.2 KiB
Python
"""Tests for application-level image-generation orchestration."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import httpx
|
|
import pytest
|
|
|
|
from app.application.services.image_generation_service import ImageGenerationService
|
|
from app.infra.image_generation.base import GeneratedImageResult, SourceImage
|
|
|
|
|
|
class FakeProvider:
    """Test double for an image-generation provider.

    Every call is logged on ``calls`` as ``(kind, first_url, second_url)`` and
    answered with a fixed ``GeneratedImageResult`` that echoes the prompt.
    """

    def __init__(self) -> None:
        # Chronological log of (kind, first image URL, second image URL).
        self.calls: list[tuple[str, str, str]] = []

    @staticmethod
    def _canned_result(payload: bytes, mime_type: str, prompt: str) -> GeneratedImageResult:
        """Build the fixed result shared by both generation methods."""
        return GeneratedImageResult(
            image_bytes=payload,
            mime_type=mime_type,
            provider="gemini",
            model="gemini-test",
            prompt=prompt,
        )

    async def generate_tryon_image(
        self,
        *,
        prompt: str,
        person_image: SourceImage,
        garment_image: SourceImage,
    ) -> GeneratedImageResult:
        """Log a ``"tryon"`` call and return a canned PNG result."""
        entry = ("tryon", person_image.url, garment_image.url)
        # Record first so the call is logged even if building the result fails.
        self.calls.append(entry)
        return self._canned_result(b"tryon", "image/png", prompt)

    async def generate_scene_image(
        self,
        *,
        prompt: str,
        source_image: SourceImage,
        scene_image: SourceImage,
    ) -> GeneratedImageResult:
        """Log a ``"scene"`` call and return a canned JPEG result."""
        entry = ("scene", source_image.url, scene_image.url)
        # Record first so the call is logged even if building the result fails.
        self.calls.append(entry)
        return self._canned_result(b"scene", "image/jpeg", prompt)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_image_generation_service_downloads_inputs_for_tryon_and_scene(monkeypatch):
    """The service should download both inputs and dispatch them to the matching provider method.

    Image downloads are served by an in-process ``httpx.MockTransport``; the
    test checks which URLs were fetched, which provider methods were called
    (and with which image URLs), and that the provider's bytes are returned.
    """

    responses = {
        "https://images.example.com/person.png": (b"person-bytes", "image/png"),
        "https://images.example.com/garment.png": (b"garment-bytes", "image/png"),
        "https://images.example.com/source.jpg": (b"source-bytes", "image/jpeg"),
        "https://images.example.com/scene.jpg": (b"scene-bytes", "image/jpeg"),
    }
    # URLs actually requested through the mock transport.
    downloaded: list[str] = []

    async def handler(request: httpx.Request) -> httpx.Response:
        url = str(request.url)
        downloaded.append(url)
        # An unexpected URL raises KeyError here and fails the test.
        body, mime = responses[url]
        return httpx.Response(200, content=body, headers={"content-type": mime}, request=request)

    monkeypatch.setenv("GEMINI_API_KEY", "test-key")

    # Imported locally, after the env var is patched (kept as in the original).
    from app.config.settings import get_settings

    get_settings.cache_clear()
    provider = FakeProvider()

    try:
        # ``async with`` closes the client even if building the service raises.
        async with httpx.AsyncClient(transport=httpx.MockTransport(handler)) as downloader:
            service = ImageGenerationService(provider=provider, downloader=downloader)
            tryon = await service.generate_tryon_image(
                person_image_url="https://images.example.com/person.png",
                garment_image_url="https://images.example.com/garment.png",
            )
            scene = await service.generate_scene_image(
                source_image_url="https://images.example.com/source.jpg",
                scene_image_url="https://images.example.com/scene.jpg",
            )
    finally:
        # Always drop settings cached while the test API key was in place.
        get_settings.cache_clear()

    # Compared as sets so the check does not depend on download order or repeats.
    assert set(downloaded) == set(responses)
    assert provider.calls == [
        ("tryon", "https://images.example.com/person.png", "https://images.example.com/garment.png"),
        ("scene", "https://images.example.com/source.jpg", "https://images.example.com/scene.jpg"),
    ]
    assert tryon.image_bytes == b"tryon"
    assert scene.image_bytes == b"scene"
|