Files
auto-virtual-tryon/app/application/services/image_generation_service.py

164 lines
6.4 KiB
Python

"""Application-facing orchestration for image-generation providers."""
from __future__ import annotations
import base64
import httpx
from fastapi import HTTPException, status
from app.config.settings import get_settings
from app.infra.image_generation.base import GeneratedImageResult, SourceImage
from app.infra.image_generation.gemini_provider import GeminiImageProvider
# Instruction text passed as ``prompt`` for virtual try-on generation.
# It tells the model to treat the first image as the person and the second
# image as the garment reference. Adjacent string literals are concatenated
# into one string, so each fragment ends with a trailing space.
TRYON_PROMPT = (
"Use the first image as the base person image and the second image as the garment reference. "
"Generate a realistic virtual try-on result where the person in the first image is wearing the garment from "
"the second image. Preserve the person's identity, body proportions, pose, camera angle, lighting direction, "
"and background composition from the first image. Preserve the garment's key silhouette, material feel, color, "
"and visible design details from the second image. Produce a clean e-commerce quality result without adding "
"extra accessories, text, watermarks, or duplicate limbs."
)
# Instruction text passed as ``prompt`` for scene compositing.
# It tells the model to treat the first image as the finished subject and the
# second image as the target scene reference.
SCENE_PROMPT = (
"Use the first image as the finished subject image and the second image as the target scene reference. "
"Generate a realistic composite where the person, clothing, body proportions, facial identity, pose, and camera "
"framing from the first image are preserved, while the environment is adapted to match the second image. "
"Keep the garment details, colors, and texture from the first image intact. Blend the subject naturally into the "
"target scene with coherent perspective, lighting, shadows, and depth. Do not add extra people, accessories, "
"text, logos, watermarks, or duplicate body parts."
)
# Base64 text that the mock service decodes and returns as its "generated"
# image, labelled ``image/png``.
# NOTE(review): presumably a minimal 1x1 PNG - confirm the payload decodes to
# a valid image if any consumer actually parses it.
MOCK_PNG_BASE64 = (
"iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mP8/x8AAwMCAO5Ww4QAAAAASUVORK5CYII="
)
class MockImageGenerationService:
    """Deterministic stand-in for a real image provider.

    Every call returns the same embedded PNG payload without touching the
    network; the URL arguments are accepted for interface parity and ignored.
    """

    @staticmethod
    def _canned_result(prompt: str) -> GeneratedImageResult:
        """Wrap the embedded mock PNG in a result that records *prompt*."""
        png_bytes = base64.b64decode(MOCK_PNG_BASE64)
        return GeneratedImageResult(
            image_bytes=png_bytes,
            mime_type="image/png",
            provider="mock",
            model="mock-image-provider",
            prompt=prompt,
        )

    async def generate_tryon_image(
        self,
        *,
        person_image_url: str,
        garment_image_url: str,
    ) -> GeneratedImageResult:
        """Return the canned image tagged with the try-on prompt."""
        return self._canned_result(TRYON_PROMPT)

    async def generate_scene_image(
        self,
        *,
        source_image_url: str,
        scene_image_url: str,
    ) -> GeneratedImageResult:
        """Return the canned image tagged with the scene prompt."""
        return self._canned_result(SCENE_PROMPT)
class ImageGenerationService:
    """Download source assets and dispatch them to the configured provider.

    Each public method checks that a Gemini API key is configured, downloads
    the two input images over HTTP, and forwards them with the matching
    prompt to the provider.

    Note that the API-key check reads the application settings, so it also
    applies when a custom ``provider`` was injected.
    """

    # MIME type assumed when a download response carries no usable
    # ``content-type`` header.
    _DEFAULT_MIME_TYPE = "image/png"

    def __init__(
        self,
        *,
        provider: GeminiImageProvider | None = None,
        downloader: httpx.AsyncClient | None = None,
    ) -> None:
        """Load settings and wire up the provider and optional HTTP client.

        Args:
            provider: Provider to dispatch to. When omitted, a
                ``GeminiImageProvider`` is built from the application
                settings (API key, base URL, model, timeout, max attempts).
            downloader: HTTP client reused for all source-image downloads.
                It is never closed by this service. When omitted, a
                short-lived client is created for each pair of downloads
                and closed afterwards.
        """
        self.settings = get_settings()
        self.provider = provider or GeminiImageProvider(
            api_key=self.settings.gemini_api_key,
            base_url=self.settings.gemini_base_url,
            model=self.settings.gemini_model,
            timeout_seconds=self.settings.gemini_timeout_seconds,
            max_attempts=self.settings.gemini_max_attempts,
        )
        self._downloader = downloader

    async def generate_tryon_image(
        self,
        *,
        person_image_url: str,
        garment_image_url: str,
    ) -> GeneratedImageResult:
        """Generate a try-on image using the configured provider.

        Args:
            person_image_url: URL of the base person image (sent first).
            garment_image_url: URL of the garment reference (sent second).

        Returns:
            Whatever the provider's ``generate_tryon_image`` returns.

        Raises:
            HTTPException: 500 when no Gemini API key is configured.
            httpx.HTTPStatusError: when a download answers with an error
                status.
            httpx.RequestError: when a download fails at transport level.
        """
        self._require_api_key()
        person_image, garment_image = await self._download_inputs(
            first_image_url=person_image_url,
            second_image_url=garment_image_url,
        )
        return await self.provider.generate_tryon_image(
            prompt=TRYON_PROMPT,
            person_image=person_image,
            garment_image=garment_image,
        )

    async def generate_scene_image(
        self,
        *,
        source_image_url: str,
        scene_image_url: str,
    ) -> GeneratedImageResult:
        """Generate a scene-composited image using the configured provider.

        Args:
            source_image_url: URL of the finished subject image (sent first).
            scene_image_url: URL of the target scene reference (sent second).

        Returns:
            Whatever the provider's ``generate_scene_image`` returns.

        Raises:
            HTTPException: 500 when no Gemini API key is configured.
            httpx.HTTPStatusError: when a download answers with an error
                status.
            httpx.RequestError: when a download fails at transport level.
        """
        self._require_api_key()
        source_image, scene_image = await self._download_inputs(
            first_image_url=source_image_url,
            second_image_url=scene_image_url,
        )
        return await self.provider.generate_scene_image(
            prompt=SCENE_PROMPT,
            source_image=source_image,
            scene_image=scene_image,
        )

    def _require_api_key(self) -> None:
        """Raise HTTP 500 unless a Gemini API key is present in settings."""
        if not self.settings.gemini_api_key:
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="GEMINI_API_KEY is required when image_generation_provider=gemini",
            )

    async def _download_inputs(
        self,
        *,
        first_image_url: str,
        second_image_url: str,
    ) -> tuple[SourceImage, SourceImage]:
        """Download both input images, in order, and return them as a pair.

        The second download only starts after the first one succeeded. A
        client created here is always closed, including on failure; an
        injected ``downloader`` is left open for its owner to manage.

        NOTE(review): the URLs are fetched exactly as given. If they can
        originate from end users, confirm they are validated upstream
        (server-side request forgery).
        """
        owns_client = self._downloader is None
        # The provider timeout setting doubles as the download timeout.
        client = self._downloader or httpx.AsyncClient(timeout=self.settings.gemini_timeout_seconds)
        try:
            first_image = await self._fetch_image(client, first_image_url)
            second_image = await self._fetch_image(client, second_image_url)
        finally:
            if owns_client:
                await client.aclose()
        return first_image, second_image

    @classmethod
    async def _fetch_image(cls, client: httpx.AsyncClient, url: str) -> SourceImage:
        """Fetch one image and pair its bytes with the reported MIME type."""
        response = await client.get(url)
        response.raise_for_status()
        # Drop parameters such as "; charset=..." and fall back to the default
        # when the header is missing or empty.
        raw_content_type = response.headers.get("content-type", cls._DEFAULT_MIME_TYPE)
        mime_type = raw_content_type.split(";")[0].strip() or cls._DEFAULT_MIME_TYPE
        return SourceImage(url=url, mime_type=mime_type, data=response.content)
def build_image_generation_service():
    """Instantiate the image-generation service selected in settings.

    The ``image_generation_provider`` setting is matched case-insensitively:
    ``"mock"`` yields a ``MockImageGenerationService`` and ``"gemini"`` an
    ``ImageGenerationService``.

    Raises:
        HTTPException: 500 when the setting names any other provider.
    """
    settings = get_settings()
    service_classes = {
        "mock": MockImageGenerationService,
        "gemini": ImageGenerationService,
    }
    service_class = service_classes.get(settings.image_generation_provider.lower())
    if service_class is None:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Unsupported image_generation_provider: {settings.image_generation_provider}",
        )
    return service_class()