150 lines
4.9 KiB
Python
150 lines
4.9 KiB
Python
"""Gemini-backed image-generation provider."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import base64
|
|
|
|
import httpx
|
|
|
|
from app.infra.image_generation.base import GeneratedImageResult, SourceImage
|
|
|
|
|
|
class GeminiImageProvider:
    """Call the Gemini image-generation API via a configurable REST endpoint.

    Both public entry points send a text prompt plus two inline images to
    ``{base_url}/models/{model}:generateContent`` and return the first image
    found in the response as a ``GeneratedImageResult``.
    """

    provider_name = "gemini"

    def __init__(
        self,
        *,
        api_key: str,
        base_url: str,
        model: str,
        timeout_seconds: int,
        max_attempts: int = 2,
        http_client: httpx.AsyncClient | None = None,
    ) -> None:
        """Store the connection settings.

        Args:
            api_key: Value sent in the ``x-goog-api-key`` request header.
            base_url: API root; trailing slashes are stripped.
            model: Model identifier interpolated into the request URL.
            timeout_seconds: Timeout passed to every request (and to the
                client when this provider creates its own).
            max_attempts: Total number of POST attempts; clamped to at least 1.
            http_client: Optional client to reuse. When omitted, a temporary
                client is created and closed for each generation call.
        """
        self.api_key = api_key
        self.base_url = base_url.rstrip("/")
        self.model = model
        self.timeout_seconds = timeout_seconds
        self.max_attempts = max(1, max_attempts)
        self._http_client = http_client

    async def generate_tryon_image(
        self,
        *,
        prompt: str,
        person_image: SourceImage,
        garment_image: SourceImage,
    ) -> GeneratedImageResult:
        """Generate a try-on image from a prepared person image and a garment image."""
        return await self._generate_two_image_edit(
            prompt=prompt,
            first_image=person_image,
            second_image=garment_image,
        )

    async def generate_scene_image(
        self,
        *,
        prompt: str,
        source_image: SourceImage,
        scene_image: SourceImage,
    ) -> GeneratedImageResult:
        """Generate a scene-composited image from a rendered subject image and a scene reference."""
        return await self._generate_two_image_edit(
            prompt=prompt,
            first_image=source_image,
            second_image=scene_image,
        )

    async def _generate_two_image_edit(
        self,
        *,
        prompt: str,
        first_image: SourceImage,
        second_image: SourceImage,
    ) -> GeneratedImageResult:
        """Generate an edited image from a prompt plus two inline image inputs.

        Raises:
            ValueError: The response holds no inline image part, or the chosen
                part carries no image bytes.
            RuntimeError: No request attempt produced a response.

        Exceptions raised by the HTTP client (including the final transport
        error once all attempts are used up) propagate unchanged.
        """
        payload = self._build_payload(prompt, first_image, second_image)
        response = await self._post_with_retries(payload)
        mime_type, encoded = self._extract_inline_image(response.json())
        return GeneratedImageResult(
            image_bytes=base64.b64decode(encoded),
            mime_type=mime_type,
            provider=self.provider_name,
            model=self.model,
            prompt=prompt,
        )

    @staticmethod
    def _inline_image_part(image: SourceImage) -> dict:
        """Return the request part that embeds ``image`` as base64 text."""
        return {
            "inline_data": {
                "mime_type": image.mime_type,
                "data": base64.b64encode(image.data).decode("utf-8"),
            }
        }

    @classmethod
    def _build_payload(
        cls,
        prompt: str,
        first_image: SourceImage,
        second_image: SourceImage,
    ) -> dict:
        """Build the ``generateContent`` JSON body: prompt first, then both images."""
        return {
            "contents": [
                {
                    "parts": [
                        {"text": prompt},
                        cls._inline_image_part(first_image),
                        cls._inline_image_part(second_image),
                    ]
                }
            ],
            "generationConfig": {
                "responseModalities": ["TEXT", "IMAGE"],
            },
        }

    async def _post_with_retries(self, payload: dict) -> httpx.Response:
        """POST ``payload`` and return the first successful response.

        Only ``httpx.TransportError`` is retried, up to ``max_attempts`` total
        attempts; any other exception, including one raised by
        ``raise_for_status()``, propagates immediately.
        """
        # Decide ownership and pick the client with the same `is None` test, so
        # an injected client is always the one used and a temporary client is
        # always the one closed.
        owns_client = self._http_client is None
        if owns_client:
            client = httpx.AsyncClient(timeout=self.timeout_seconds)
        else:
            client = self._http_client

        url = f"{self.base_url}/models/{self.model}:generateContent"
        headers = {
            "x-goog-api-key": self.api_key,
            "Content-Type": "application/json",
        }
        try:
            for attempt in range(1, self.max_attempts + 1):
                try:
                    response = await client.post(
                        url,
                        headers=headers,
                        json=payload,
                        timeout=self.timeout_seconds,
                    )
                    response.raise_for_status()
                except httpx.TransportError:
                    if attempt >= self.max_attempts:
                        raise
                else:
                    return response
        finally:
            if owns_client:
                await client.aclose()

        # Reached only when the loop body never ran (max_attempts < 1).
        raise RuntimeError("Gemini provider did not receive a response")

    @classmethod
    def _extract_inline_image(cls, body: dict) -> tuple[str, str]:
        """Return ``(mime_type, base64_data)`` for the image in ``body``.

        The MIME type falls back to ``"image/png"`` when the response omits it.

        Raises:
            ValueError: No inline image part exists, or it has no data.
        """
        image_part = cls._find_image_part(body)
        image_data = image_part.get("inlineData") or image_part.get("inline_data")
        mime_type = image_data.get("mimeType") or image_data.get("mime_type") or "image/png"
        data = image_data.get("data")
        if not data:
            raise ValueError("Gemini response did not include image bytes")
        return mime_type, data

    @staticmethod
    def _find_image_part(body: dict) -> dict:
        """Return the response part that holds the generated image.

        Both ``inlineData`` and ``inline_data`` spellings are accepted. A part
        whose inline payload actually contains ``data`` is preferred; if none
        does, the first part with any inline payload is returned so the caller
        can report the missing bytes.

        Raises:
            ValueError: No part in any candidate has an inline payload.
        """
        first_without_data = None
        for candidate in body.get("candidates") or []:
            content = candidate.get("content") or {}
            for part in content.get("parts") or []:
                inline = part.get("inlineData") or part.get("inline_data")
                if not inline:
                    continue
                if inline.get("data"):
                    return part
                if first_without_data is None:
                    first_without_data = part
        if first_without_data is not None:
            return first_without_data
        raise ValueError("Gemini response did not contain an image part")
|