"""Gemini-backed image-generation provider.""" from __future__ import annotations import base64 import httpx from app.infra.image_generation.base import GeneratedImageResult, SourceImage class GeminiImageProvider: """Call the Gemini image-generation API via a configurable REST endpoint.""" provider_name = "gemini" def __init__( self, *, api_key: str, base_url: str, model: str, timeout_seconds: int, max_attempts: int = 2, http_client: httpx.AsyncClient | None = None, ) -> None: self.api_key = api_key self.base_url = base_url.rstrip("/") self.model = model self.timeout_seconds = timeout_seconds self.max_attempts = max(1, max_attempts) self._http_client = http_client async def generate_tryon_image( self, *, prompt: str, person_image: SourceImage, garment_image: SourceImage, ) -> GeneratedImageResult: """Generate a try-on image from a prepared person image and a garment image.""" return await self._generate_two_image_edit( prompt=prompt, first_image=person_image, second_image=garment_image, ) async def generate_scene_image( self, *, prompt: str, source_image: SourceImage, scene_image: SourceImage, ) -> GeneratedImageResult: """Generate a scene-composited image from a rendered subject image and a scene reference.""" return await self._generate_two_image_edit( prompt=prompt, first_image=source_image, second_image=scene_image, ) async def _generate_two_image_edit( self, *, prompt: str, first_image: SourceImage, second_image: SourceImage, ) -> GeneratedImageResult: """Generate an edited image from a prompt plus two inline image inputs.""" payload = { "contents": [ { "parts": [ {"text": prompt}, { "inline_data": { "mime_type": first_image.mime_type, "data": base64.b64encode(first_image.data).decode("utf-8"), } }, { "inline_data": { "mime_type": second_image.mime_type, "data": base64.b64encode(second_image.data).decode("utf-8"), } }, ] } ], "generationConfig": { "responseModalities": ["TEXT", "IMAGE"], }, } owns_client = self._http_client is None client = self._http_client or httpx.AsyncClient(timeout=self.timeout_seconds) try: response = None for attempt in range(1, self.max_attempts + 1): try: response = await client.post( f"{self.base_url}/models/{self.model}:generateContent", headers={ "x-goog-api-key": self.api_key, "Content-Type": "application/json", }, json=payload, timeout=self.timeout_seconds, ) response.raise_for_status() break except httpx.TransportError: if attempt >= self.max_attempts: raise finally: if owns_client: await client.aclose() if response is None: raise RuntimeError("Gemini provider did not receive a response") body = response.json() image_part = self._find_image_part(body) image_data = image_part.get("inlineData") or image_part.get("inline_data") mime_type = image_data.get("mimeType") or image_data.get("mime_type") or "image/png" data = image_data.get("data") if not data: raise ValueError("Gemini response did not include image bytes") return GeneratedImageResult( image_bytes=base64.b64decode(data), mime_type=mime_type, provider=self.provider_name, model=self.model, prompt=prompt, ) @staticmethod def _find_image_part(body: dict) -> dict: candidates = body.get("candidates") or [] for candidate in candidates: content = candidate.get("content") or {} for part in content.get("parts") or []: if part.get("inlineData") or part.get("inline_data"): return part raise ValueError("Gemini response did not contain an image part")