"""Application-facing orchestration for image-generation providers."""

from __future__ import annotations

import base64

import httpx
from fastapi import HTTPException, status

from app.config.settings import get_settings
from app.infra.image_generation.base import GeneratedImageResult, SourceImage
from app.infra.image_generation.gemini_provider import GeminiImageProvider

# Prompt for try-on requests: first image is the person, second is the garment.
TRYON_PROMPT = (
    "Use the first image as the base person image and the second image as the garment reference. "
    "Generate a realistic virtual try-on result where the person in the first image is wearing the garment from "
    "the second image. Preserve the person's identity, body proportions, pose, camera angle, lighting direction, "
    "and background composition from the first image. Preserve the garment's key silhouette, material feel, color, "
    "and visible design details from the second image. Produce a clean e-commerce quality result without adding "
    "extra accessories, text, watermarks, or duplicate limbs."
)

# Prompt for scene requests: first image is the subject, second is the scene.
SCENE_PROMPT = (
    "Use the first image as the finished subject image and the second image as the target scene reference. "
    "Generate a realistic composite where the person, clothing, body proportions, facial identity, pose, and camera "
    "framing from the first image are preserved, while the environment is adapted to match the second image. "
    "Keep the garment details, colors, and texture from the first image intact. Blend the subject naturally into the "
    "target scene with coherent perspective, lighting, shadows, and depth. Do not add extra people, accessories, "
    "text, logos, watermarks, or duplicate body parts."
)

# Base64-encoded placeholder PNG returned by the mock service.
MOCK_PNG_BASE64 = (
    "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mP8/x8AAwMCAO5Ww4QAAAAASUVORK5CYII="
)

# MIME type assumed when a download has no usable Content-Type header.
_DEFAULT_MIME_TYPE = "image/png"


class MockImageGenerationService:
    """Deterministic fallback used in tests and when no real provider is configured."""

    async def generate_tryon_image(self, *, person_image_url: str, garment_image_url: str) -> GeneratedImageResult:
        """Return the placeholder image tagged with the try-on prompt.

        The URL arguments are accepted for interface parity with
        ``ImageGenerationService`` and are not fetched.
        """
        return self._build_result(TRYON_PROMPT)

    async def generate_scene_image(self, *, source_image_url: str, scene_image_url: str) -> GeneratedImageResult:
        """Return the placeholder image tagged with the scene prompt.

        The URL arguments are accepted for interface parity with
        ``ImageGenerationService`` and are not fetched.
        """
        return self._build_result(SCENE_PROMPT)

    @staticmethod
    def _build_result(prompt: str) -> GeneratedImageResult:
        """Build the fixed mock result, varying only the recorded prompt."""
        return GeneratedImageResult(
            image_bytes=base64.b64decode(MOCK_PNG_BASE64),
            mime_type="image/png",
            provider="mock",
            model="mock-image-provider",
            prompt=prompt,
        )


class ImageGenerationService:
    """Download source assets and dispatch them to the configured provider."""

    def __init__(
        self,
        *,
        provider: GeminiImageProvider | None = None,
        downloader: httpx.AsyncClient | None = None,
    ) -> None:
        """Store settings, the provider, and an optional shared HTTP client.

        Args:
            provider: Provider to dispatch to. When omitted, a
                ``GeminiImageProvider`` is built from the application settings.
            downloader: HTTP client used to fetch input images. When omitted,
                a short-lived client is created and closed per download batch.
        """
        self.settings = get_settings()
        if provider is None:
            provider = GeminiImageProvider(
                api_key=self.settings.gemini_api_key,
                base_url=self.settings.gemini_base_url,
                model=self.settings.gemini_model,
                timeout_seconds=self.settings.gemini_timeout_seconds,
                max_attempts=self.settings.gemini_max_attempts,
            )
        self.provider = provider
        self._downloader = downloader

    async def generate_tryon_image(
        self,
        *,
        person_image_url: str,
        garment_image_url: str,
    ) -> GeneratedImageResult:
        """Generate a try-on image using the configured provider.

        Args:
            person_image_url: URL of the base person image.
            garment_image_url: URL of the garment reference image.

        Returns:
            The provider's result for ``TRYON_PROMPT`` and the two downloads.

        Raises:
            HTTPException: 500 when no Gemini API key is configured.
            httpx.HTTPStatusError: When either download returns an error status.
        """
        self._require_api_key()
        person_image, garment_image = await self._download_inputs(
            first_image_url=person_image_url,
            second_image_url=garment_image_url,
        )
        return await self.provider.generate_tryon_image(
            prompt=TRYON_PROMPT,
            person_image=person_image,
            garment_image=garment_image,
        )

    async def generate_scene_image(
        self,
        *,
        source_image_url: str,
        scene_image_url: str,
    ) -> GeneratedImageResult:
        """Generate a scene-composited image using the configured provider.

        Args:
            source_image_url: URL of the finished subject image.
            scene_image_url: URL of the target scene reference image.

        Returns:
            The provider's result for ``SCENE_PROMPT`` and the two downloads.

        Raises:
            HTTPException: 500 when no Gemini API key is configured.
            httpx.HTTPStatusError: When either download returns an error status.
        """
        self._require_api_key()
        source_image, scene_image = await self._download_inputs(
            first_image_url=source_image_url,
            second_image_url=scene_image_url,
        )
        return await self.provider.generate_scene_image(
            prompt=SCENE_PROMPT,
            source_image=source_image,
            scene_image=scene_image,
        )

    def _require_api_key(self) -> None:
        """Raise a 500 ``HTTPException`` when the Gemini API key is unset or empty."""
        if not self.settings.gemini_api_key:
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="GEMINI_API_KEY is required when image_generation_provider=gemini",
            )

    async def _download_inputs(
        self,
        *,
        first_image_url: str,
        second_image_url: str,
    ) -> tuple[SourceImage, SourceImage]:
        """Fetch both input images, in order, and wrap them as ``SourceImage``.

        An injected downloader is reused and left open. Otherwise a temporary
        client is created with the Gemini timeout and closed before returning,
        including when a download fails.
        """
        owns_client = self._downloader is None
        client = (
            httpx.AsyncClient(timeout=self.settings.gemini_timeout_seconds)
            if owns_client
            else self._downloader
        )
        try:
            # Sequential on purpose: a failed first download skips the second.
            first_image = await self._fetch_image(client, first_image_url)
            second_image = await self._fetch_image(client, second_image_url)
        finally:
            if owns_client:
                await client.aclose()
        return first_image, second_image

    @staticmethod
    async def _fetch_image(client: httpx.AsyncClient, url: str) -> SourceImage:
        """Download one image and return it with its MIME type.

        The MIME type is the Content-Type header with parameters (such as
        ``; charset=...``) stripped, falling back to ``image/png`` when the
        header is missing or empty.
        """
        response = await client.get(url)
        response.raise_for_status()
        mime_type = response.headers.get("content-type", _DEFAULT_MIME_TYPE).split(";")[0].strip()
        return SourceImage(url=url, mime_type=mime_type or _DEFAULT_MIME_TYPE, data=response.content)


def build_image_generation_service() -> MockImageGenerationService | ImageGenerationService:
    """Return the configured image-generation service.

    The ``image_generation_provider`` setting is matched case-insensitively
    against ``"mock"`` and ``"gemini"``.

    Raises:
        HTTPException: 500 when the setting names any other provider.
    """
    settings = get_settings()
    provider_name = settings.image_generation_provider.lower()
    if provider_name == "mock":
        return MockImageGenerationService()
    if provider_name == "gemini":
        return ImageGenerationService()
    raise HTTPException(
        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        detail=f"Unsupported image_generation_provider: {settings.image_generation_provider}",
    )