164 lines
6.4 KiB
Python
164 lines
6.4 KiB
Python
"""Application-facing orchestration for image-generation providers."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import base64
|
|
|
|
import httpx
|
|
from fastapi import HTTPException, status
|
|
|
|
from app.config.settings import get_settings
|
|
from app.infra.image_generation.base import GeneratedImageResult, SourceImage
|
|
from app.infra.image_generation.gemini_provider import GeminiImageProvider
|
|
|
|
# Instruction text for virtual try-on generation.
# Image order matters: the first image is the person, the second is the garment.
TRYON_PROMPT = (
    "Use the first image as the base person image and the second image as the garment reference. "
    "Generate a realistic virtual try-on result where the person in the first image is wearing the garment from "
    "the second image. Preserve the person's identity, body proportions, pose, camera angle, lighting direction, "
    "and background composition from the first image. Preserve the garment's key silhouette, material feel, color, "
    "and visible design details from the second image. Produce a clean e-commerce quality result without adding "
    "extra accessories, text, watermarks, or duplicate limbs."
)
|
|
|
|
# Instruction text for scene compositing.
# Image order matters: the first image is the finished subject, the second is the target scene.
SCENE_PROMPT = (
    "Use the first image as the finished subject image and the second image as the target scene reference. "
    "Generate a realistic composite where the person, clothing, body proportions, facial identity, pose, and camera "
    "framing from the first image are preserved, while the environment is adapted to match the second image. "
    "Keep the garment details, colors, and texture from the first image intact. Blend the subject naturally into the "
    "target scene with coherent perspective, lighting, shadows, and depth. Do not add extra people, accessories, "
    "text, logos, watermarks, or duplicate body parts."
)
|
|
|
|
# Base64 text of the small PNG payload that the mock service returns as its "generated" image.
MOCK_PNG_BASE64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mP8/x8AAwMCAO5Ww4QAAAAASUVORK5CYII="
|
|
|
|
|
|
class MockImageGenerationService:
    """Deterministic stand-in that answers every request with the same canned PNG.

    Mirrors the two public coroutine names and keyword-only parameters of
    ``ImageGenerationService``. The URL arguments are accepted but never
    fetched or inspected.
    """

    async def generate_tryon_image(self, *, person_image_url: str, garment_image_url: str) -> GeneratedImageResult:
        """Return the canned result labelled with ``TRYON_PROMPT``."""
        return self._canned_result(TRYON_PROMPT)

    async def generate_scene_image(self, *, source_image_url: str, scene_image_url: str) -> GeneratedImageResult:
        """Return the canned result labelled with ``SCENE_PROMPT``."""
        return self._canned_result(SCENE_PROMPT)

    @staticmethod
    def _canned_result(prompt: str) -> GeneratedImageResult:
        """Build the fixed mock result; only the recorded prompt differs between the two flows."""
        png_bytes = base64.b64decode(MOCK_PNG_BASE64)
        return GeneratedImageResult(
            image_bytes=png_bytes,
            mime_type="image/png",
            provider="mock",
            model="mock-image-provider",
            prompt=prompt,
        )
|
|
|
|
|
|
class ImageGenerationService:
    """Download source assets and dispatch them to the configured provider.

    Each public coroutine follows the same three steps: verify that a Gemini
    API key is configured, download the two input images over HTTP, then hand
    the downloaded images plus the matching prompt constant to the provider.
    """

    # MIME type assumed when a download response carries no usable Content-Type header.
    _DEFAULT_MIME_TYPE = "image/png"

    def __init__(
        self,
        *,
        provider: GeminiImageProvider | None = None,
        downloader: httpx.AsyncClient | None = None,
    ) -> None:
        """Store settings and collaborators.

        Args:
            provider: Provider to dispatch to. When omitted, a
                ``GeminiImageProvider`` is built from the ``gemini_*`` settings.
            downloader: HTTP client used to fetch the input images. When
                omitted, a short-lived client is created (and closed) for each
                download; an injected client is never closed by this class.
        """
        self.settings = get_settings()
        # Compare against None explicitly so an injected collaborator is never
        # discarded merely because it happens to evaluate as falsy.
        if provider is None:
            provider = GeminiImageProvider(
                api_key=self.settings.gemini_api_key,
                base_url=self.settings.gemini_base_url,
                model=self.settings.gemini_model,
                timeout_seconds=self.settings.gemini_timeout_seconds,
                max_attempts=self.settings.gemini_max_attempts,
            )
        self.provider = provider
        self._downloader = downloader

    async def generate_tryon_image(
        self,
        *,
        person_image_url: str,
        garment_image_url: str,
    ) -> GeneratedImageResult:
        """Generate a try-on image using the configured provider.

        Args:
            person_image_url: URL of the base person image (sent first).
            garment_image_url: URL of the garment reference image (sent second).

        Returns:
            Whatever the provider's ``generate_tryon_image`` returns.

        Raises:
            HTTPException: 500 when ``gemini_api_key`` is not configured.
            Errors raised while downloading (including those from
            ``raise_for_status()``) or by the provider propagate unchanged.
        """
        self._require_api_key()
        person_image, garment_image = await self._download_inputs(
            first_image_url=person_image_url,
            second_image_url=garment_image_url,
        )
        return await self.provider.generate_tryon_image(
            prompt=TRYON_PROMPT,
            person_image=person_image,
            garment_image=garment_image,
        )

    async def generate_scene_image(
        self,
        *,
        source_image_url: str,
        scene_image_url: str,
    ) -> GeneratedImageResult:
        """Generate a scene-composited image using the configured provider.

        Args:
            source_image_url: URL of the finished subject image (sent first).
            scene_image_url: URL of the target scene reference (sent second).

        Returns:
            Whatever the provider's ``generate_scene_image`` returns.

        Raises:
            HTTPException: 500 when ``gemini_api_key`` is not configured.
            Errors raised while downloading (including those from
            ``raise_for_status()``) or by the provider propagate unchanged.
        """
        self._require_api_key()
        source_image, scene_image = await self._download_inputs(
            first_image_url=source_image_url,
            second_image_url=scene_image_url,
        )
        return await self.provider.generate_scene_image(
            prompt=SCENE_PROMPT,
            source_image=source_image,
            scene_image=scene_image,
        )

    def _require_api_key(self) -> None:
        """Fail fast, before any download is attempted, when no Gemini API key is configured."""
        if not self.settings.gemini_api_key:
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="GEMINI_API_KEY is required when image_generation_provider=gemini",
            )

    async def _download_inputs(
        self,
        *,
        first_image_url: str,
        second_image_url: str,
    ) -> tuple[SourceImage, SourceImage]:
        """Fetch both input images, in order, and return them as ``SourceImage`` objects.

        The downloads run sequentially; a failure on the first URL means the
        second is never requested. A client created here is always closed,
        while an injected downloader is left open for its owner to manage.
        """
        injected = self._downloader
        owns_client = injected is None
        # Ownership and client selection use the same `is None` test, so a
        # client created here can never escape the cleanup in `finally`.
        client = httpx.AsyncClient(timeout=self.settings.gemini_timeout_seconds) if owns_client else injected
        try:
            first_image = await self._fetch_image(client, first_image_url)
            second_image = await self._fetch_image(client, second_image_url)
        finally:
            if owns_client:
                await client.aclose()
        return first_image, second_image

    @classmethod
    async def _fetch_image(cls, client: httpx.AsyncClient, url: str) -> SourceImage:
        """Download one image and wrap its bytes with the MIME type the server reported."""
        response = await client.get(url)
        response.raise_for_status()
        # Keep only the media type: drop parameters such as "; charset=..." and
        # fall back to the default when the header is absent or blank.
        content_type = response.headers.get("content-type", cls._DEFAULT_MIME_TYPE)
        mime_type = content_type.split(";")[0].strip()
        return SourceImage(url=url, mime_type=mime_type or cls._DEFAULT_MIME_TYPE, data=response.content)
|
|
|
|
|
|
def build_image_generation_service():
    """Instantiate the service selected by the ``image_generation_provider`` setting.

    The setting is matched case-insensitively: ``"mock"`` yields a
    ``MockImageGenerationService`` and ``"gemini"`` an ``ImageGenerationService``.
    Any other value raises an ``HTTPException`` with status 500.
    """
    settings = get_settings()
    service_classes = {
        "mock": MockImageGenerationService,
        "gemini": ImageGenerationService,
    }
    service_class = service_classes.get(settings.image_generation_provider.lower())
    if service_class is None:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Unsupported image_generation_provider: {settings.image_generation_provider}",
        )
    return service_class()
|