feat: add resource library and real image workflow
This commit is contained in:
163
app/application/services/image_generation_service.py
Normal file
163
app/application/services/image_generation_service.py
Normal file
@@ -0,0 +1,163 @@
|
||||
"""Application-facing orchestration for image-generation providers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
|
||||
import httpx
|
||||
from fastapi import HTTPException, status
|
||||
|
||||
from app.config.settings import get_settings
|
||||
from app.infra.image_generation.base import GeneratedImageResult, SourceImage
|
||||
from app.infra.image_generation.gemini_provider import GeminiImageProvider
|
||||
|
||||
TRYON_PROMPT = (
|
||||
"Use the first image as the base person image and the second image as the garment reference. "
|
||||
"Generate a realistic virtual try-on result where the person in the first image is wearing the garment from "
|
||||
"the second image. Preserve the person's identity, body proportions, pose, camera angle, lighting direction, "
|
||||
"and background composition from the first image. Preserve the garment's key silhouette, material feel, color, "
|
||||
"and visible design details from the second image. Produce a clean e-commerce quality result without adding "
|
||||
"extra accessories, text, watermarks, or duplicate limbs."
|
||||
)
|
||||
|
||||
SCENE_PROMPT = (
|
||||
"Use the first image as the finished subject image and the second image as the target scene reference. "
|
||||
"Generate a realistic composite where the person, clothing, body proportions, facial identity, pose, and camera "
|
||||
"framing from the first image are preserved, while the environment is adapted to match the second image. "
|
||||
"Keep the garment details, colors, and texture from the first image intact. Blend the subject naturally into the "
|
||||
"target scene with coherent perspective, lighting, shadows, and depth. Do not add extra people, accessories, "
|
||||
"text, logos, watermarks, or duplicate body parts."
|
||||
)
|
||||
|
||||
MOCK_PNG_BASE64 = (
|
||||
"iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mP8/x8AAwMCAO5Ww4QAAAAASUVORK5CYII="
|
||||
)
|
||||
|
||||
|
||||
class MockImageGenerationService:
|
||||
"""Deterministic fallback used in tests and when no real provider is configured."""
|
||||
|
||||
async def generate_tryon_image(self, *, person_image_url: str, garment_image_url: str) -> GeneratedImageResult:
|
||||
return GeneratedImageResult(
|
||||
image_bytes=base64.b64decode(MOCK_PNG_BASE64),
|
||||
mime_type="image/png",
|
||||
provider="mock",
|
||||
model="mock-image-provider",
|
||||
prompt=TRYON_PROMPT,
|
||||
)
|
||||
|
||||
async def generate_scene_image(self, *, source_image_url: str, scene_image_url: str) -> GeneratedImageResult:
|
||||
return GeneratedImageResult(
|
||||
image_bytes=base64.b64decode(MOCK_PNG_BASE64),
|
||||
mime_type="image/png",
|
||||
provider="mock",
|
||||
model="mock-image-provider",
|
||||
prompt=SCENE_PROMPT,
|
||||
)
|
||||
|
||||
|
||||
class ImageGenerationService:
|
||||
"""Download source assets and dispatch them to the configured provider."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
provider: GeminiImageProvider | None = None,
|
||||
downloader: httpx.AsyncClient | None = None,
|
||||
) -> None:
|
||||
self.settings = get_settings()
|
||||
self.provider = provider or GeminiImageProvider(
|
||||
api_key=self.settings.gemini_api_key,
|
||||
base_url=self.settings.gemini_base_url,
|
||||
model=self.settings.gemini_model,
|
||||
timeout_seconds=self.settings.gemini_timeout_seconds,
|
||||
max_attempts=self.settings.gemini_max_attempts,
|
||||
)
|
||||
self._downloader = downloader
|
||||
|
||||
async def generate_tryon_image(
|
||||
self,
|
||||
*,
|
||||
person_image_url: str,
|
||||
garment_image_url: str,
|
||||
) -> GeneratedImageResult:
|
||||
"""Generate a try-on image using the configured provider."""
|
||||
|
||||
if not self.settings.gemini_api_key:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="GEMINI_API_KEY is required when image_generation_provider=gemini",
|
||||
)
|
||||
|
||||
person_image, garment_image = await self._download_inputs(
|
||||
first_image_url=person_image_url,
|
||||
second_image_url=garment_image_url,
|
||||
)
|
||||
return await self.provider.generate_tryon_image(
|
||||
prompt=TRYON_PROMPT,
|
||||
person_image=person_image,
|
||||
garment_image=garment_image,
|
||||
)
|
||||
|
||||
async def generate_scene_image(
|
||||
self,
|
||||
*,
|
||||
source_image_url: str,
|
||||
scene_image_url: str,
|
||||
) -> GeneratedImageResult:
|
||||
"""Generate a scene-composited image using the configured provider."""
|
||||
|
||||
if not self.settings.gemini_api_key:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="GEMINI_API_KEY is required when image_generation_provider=gemini",
|
||||
)
|
||||
|
||||
source_image, scene_image = await self._download_inputs(
|
||||
first_image_url=source_image_url,
|
||||
second_image_url=scene_image_url,
|
||||
)
|
||||
return await self.provider.generate_scene_image(
|
||||
prompt=SCENE_PROMPT,
|
||||
source_image=source_image,
|
||||
scene_image=scene_image,
|
||||
)
|
||||
|
||||
async def _download_inputs(
|
||||
self,
|
||||
*,
|
||||
first_image_url: str,
|
||||
second_image_url: str,
|
||||
) -> tuple[SourceImage, SourceImage]:
|
||||
owns_client = self._downloader is None
|
||||
client = self._downloader or httpx.AsyncClient(timeout=self.settings.gemini_timeout_seconds)
|
||||
try:
|
||||
first_response = await client.get(first_image_url)
|
||||
first_response.raise_for_status()
|
||||
second_response = await client.get(second_image_url)
|
||||
second_response.raise_for_status()
|
||||
finally:
|
||||
if owns_client:
|
||||
await client.aclose()
|
||||
|
||||
first_mime = first_response.headers.get("content-type", "image/png").split(";")[0].strip()
|
||||
second_mime = second_response.headers.get("content-type", "image/png").split(";")[0].strip()
|
||||
|
||||
return (
|
||||
SourceImage(url=first_image_url, mime_type=first_mime or "image/png", data=first_response.content),
|
||||
SourceImage(url=second_image_url, mime_type=second_mime or "image/png", data=second_response.content),
|
||||
)
|
||||
|
||||
|
||||
def build_image_generation_service():
|
||||
"""Return the configured image-generation service."""
|
||||
|
||||
settings = get_settings()
|
||||
if settings.image_generation_provider.lower() == "mock":
|
||||
return MockImageGenerationService()
|
||||
if settings.image_generation_provider.lower() == "gemini":
|
||||
return ImageGenerationService()
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Unsupported image_generation_provider: {settings.image_generation_provider}",
|
||||
)
|
||||
366
app/application/services/library_service.py
Normal file
366
app/application/services/library_service.py
Normal file
@@ -0,0 +1,366 @@
|
||||
"""Library resource application service."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import HTTPException, status
|
||||
from sqlalchemy import func, or_, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.api.schemas.library import (
|
||||
ArchiveLibraryResourceResponse,
|
||||
CreateLibraryResourceRequest,
|
||||
LibraryResourceFileRead,
|
||||
LibraryResourceListResponse,
|
||||
LibraryResourceRead,
|
||||
PresignUploadResponse,
|
||||
UpdateLibraryResourceRequest,
|
||||
)
|
||||
from app.config.settings import get_settings
|
||||
from app.domain.enums import LibraryFileRole, LibraryResourceStatus, LibraryResourceType
|
||||
from app.infra.db.models.library_resource import LibraryResourceORM
|
||||
from app.infra.db.models.library_resource_file import LibraryResourceFileORM
|
||||
from app.infra.storage.s3 import RESOURCE_PREFIXES, S3PresignService
|
||||
|
||||
|
||||
class LibraryService:
|
||||
"""Application service for resource-library uploads and queries."""
|
||||
|
||||
def __init__(self, presign_service: S3PresignService | None = None) -> None:
|
||||
self.presign_service = presign_service or S3PresignService()
|
||||
|
||||
def create_upload_presign(
|
||||
self,
|
||||
resource_type: LibraryResourceType,
|
||||
file_name: str,
|
||||
content_type: str,
|
||||
) -> PresignUploadResponse:
|
||||
"""Create upload metadata for a direct S3 PUT."""
|
||||
|
||||
settings = get_settings()
|
||||
if not all(
|
||||
[
|
||||
settings.s3_bucket,
|
||||
settings.s3_region,
|
||||
settings.s3_access_key,
|
||||
settings.s3_secret_key,
|
||||
settings.s3_endpoint,
|
||||
]
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="S3 upload settings are incomplete",
|
||||
)
|
||||
|
||||
storage_key, upload_url = self.presign_service.create_upload(
|
||||
resource_type=resource_type,
|
||||
file_name=file_name,
|
||||
content_type=content_type,
|
||||
)
|
||||
return PresignUploadResponse(
|
||||
method="PUT",
|
||||
upload_url=upload_url,
|
||||
headers={"content-type": content_type},
|
||||
storage_key=storage_key,
|
||||
public_url=self.presign_service.get_public_url(storage_key),
|
||||
)
|
||||
|
||||
async def create_resource(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
payload: CreateLibraryResourceRequest,
|
||||
) -> LibraryResourceRead:
|
||||
"""Persist a resource and its uploaded file metadata."""
|
||||
|
||||
self._validate_payload(payload)
|
||||
|
||||
resource = LibraryResourceORM(
|
||||
resource_type=payload.resource_type,
|
||||
name=payload.name.strip(),
|
||||
description=payload.description.strip() if payload.description else None,
|
||||
tags=[tag.strip() for tag in payload.tags if tag.strip()],
|
||||
status=LibraryResourceStatus.ACTIVE,
|
||||
gender=payload.gender.strip() if payload.gender else None,
|
||||
age_group=payload.age_group.strip() if payload.age_group else None,
|
||||
pose_id=payload.pose_id,
|
||||
environment=payload.environment.strip() if payload.environment else None,
|
||||
category=payload.category.strip() if payload.category else None,
|
||||
)
|
||||
session.add(resource)
|
||||
await session.flush()
|
||||
|
||||
file_models: list[LibraryResourceFileORM] = []
|
||||
for item in payload.files:
|
||||
file_model = LibraryResourceFileORM(
|
||||
resource_id=resource.id,
|
||||
file_role=item.file_role,
|
||||
storage_key=item.storage_key,
|
||||
public_url=item.public_url,
|
||||
bucket=get_settings().s3_bucket,
|
||||
mime_type=item.mime_type,
|
||||
size_bytes=item.size_bytes,
|
||||
sort_order=item.sort_order,
|
||||
width=item.width,
|
||||
height=item.height,
|
||||
)
|
||||
session.add(file_model)
|
||||
file_models.append(file_model)
|
||||
|
||||
await session.flush()
|
||||
resource.original_file_id = self._find_role(file_models, LibraryFileRole.ORIGINAL).id
|
||||
resource.cover_file_id = self._find_role(file_models, LibraryFileRole.THUMBNAIL).id
|
||||
await session.commit()
|
||||
|
||||
return self._to_read(resource, files=file_models)
|
||||
|
||||
async def list_resources(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
*,
|
||||
resource_type: LibraryResourceType | None = None,
|
||||
query: str | None = None,
|
||||
gender: str | None = None,
|
||||
age_group: str | None = None,
|
||||
environment: str | None = None,
|
||||
category: str | None = None,
|
||||
page: int = 1,
|
||||
limit: int = 20,
|
||||
) -> LibraryResourceListResponse:
|
||||
"""List persisted resources with simple filter support."""
|
||||
|
||||
filters = [LibraryResourceORM.status == LibraryResourceStatus.ACTIVE]
|
||||
if resource_type is not None:
|
||||
filters.append(LibraryResourceORM.resource_type == resource_type)
|
||||
if query:
|
||||
like = f"%{query.strip()}%"
|
||||
filters.append(
|
||||
or_(
|
||||
LibraryResourceORM.name.ilike(like),
|
||||
LibraryResourceORM.description.ilike(like),
|
||||
)
|
||||
)
|
||||
if gender:
|
||||
filters.append(LibraryResourceORM.gender == gender)
|
||||
if age_group:
|
||||
filters.append(LibraryResourceORM.age_group == age_group)
|
||||
if environment:
|
||||
filters.append(LibraryResourceORM.environment == environment)
|
||||
if category:
|
||||
filters.append(LibraryResourceORM.category == category)
|
||||
|
||||
total = (
|
||||
await session.execute(select(func.count(LibraryResourceORM.id)).where(*filters))
|
||||
).scalar_one()
|
||||
|
||||
result = await session.execute(
|
||||
select(LibraryResourceORM)
|
||||
.options(selectinload(LibraryResourceORM.files))
|
||||
.where(*filters)
|
||||
.order_by(LibraryResourceORM.created_at.desc(), LibraryResourceORM.id.desc())
|
||||
.offset((page - 1) * limit)
|
||||
.limit(limit)
|
||||
)
|
||||
items = result.scalars().all()
|
||||
return LibraryResourceListResponse(total=total, items=[self._to_read(item) for item in items])
|
||||
|
||||
async def update_resource(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
resource_id: int,
|
||||
payload: UpdateLibraryResourceRequest,
|
||||
) -> LibraryResourceRead:
|
||||
"""Update editable metadata on a library resource."""
|
||||
|
||||
resource = await self._get_resource_or_404(session, resource_id)
|
||||
|
||||
if payload.name is not None:
|
||||
name = payload.name.strip()
|
||||
if not name:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Resource name is required",
|
||||
)
|
||||
resource.name = name
|
||||
|
||||
if payload.description is not None:
|
||||
resource.description = payload.description.strip() or None
|
||||
|
||||
if payload.tags is not None:
|
||||
resource.tags = [tag.strip() for tag in payload.tags if tag.strip()]
|
||||
|
||||
if resource.resource_type == LibraryResourceType.MODEL:
|
||||
if payload.gender is not None:
|
||||
resource.gender = payload.gender.strip() or None
|
||||
if payload.age_group is not None:
|
||||
resource.age_group = payload.age_group.strip() or None
|
||||
if payload.pose_id is not None:
|
||||
resource.pose_id = payload.pose_id
|
||||
elif resource.resource_type == LibraryResourceType.SCENE:
|
||||
if payload.environment is not None:
|
||||
resource.environment = payload.environment.strip() or None
|
||||
elif resource.resource_type == LibraryResourceType.GARMENT:
|
||||
if payload.category is not None:
|
||||
resource.category = payload.category.strip() or None
|
||||
|
||||
if payload.cover_file_id is not None:
|
||||
cover_file = next(
|
||||
(file for file in resource.files if file.id == payload.cover_file_id),
|
||||
None,
|
||||
)
|
||||
if cover_file is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Cover file does not belong to the resource",
|
||||
)
|
||||
resource.cover_file_id = cover_file.id
|
||||
|
||||
self._validate_existing_resource(resource)
|
||||
await session.commit()
|
||||
await session.refresh(resource, attribute_names=["files"])
|
||||
return self._to_read(resource)
|
||||
|
||||
async def archive_resource(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
resource_id: int,
|
||||
) -> ArchiveLibraryResourceResponse:
|
||||
"""Soft delete a library resource by archiving it."""
|
||||
|
||||
resource = await self._get_resource_or_404(session, resource_id)
|
||||
resource.status = LibraryResourceStatus.ARCHIVED
|
||||
await session.commit()
|
||||
return ArchiveLibraryResourceResponse(id=resource.id)
|
||||
|
||||
def _validate_payload(self, payload: CreateLibraryResourceRequest) -> None:
|
||||
if not payload.name.strip():
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Resource name is required")
|
||||
|
||||
if not payload.files:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="At least one file is required")
|
||||
|
||||
originals = [file for file in payload.files if file.file_role == LibraryFileRole.ORIGINAL]
|
||||
thumbnails = [file for file in payload.files if file.file_role == LibraryFileRole.THUMBNAIL]
|
||||
if len(originals) != 1:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Exactly one original file is required",
|
||||
)
|
||||
if len(thumbnails) != 1:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Exactly one thumbnail file is required",
|
||||
)
|
||||
|
||||
expected_prefix = f"library/{RESOURCE_PREFIXES[payload.resource_type]}/"
|
||||
for file in payload.files:
|
||||
if not file.storage_key.startswith(expected_prefix):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Uploaded file key does not match resource type",
|
||||
)
|
||||
|
||||
if payload.resource_type == LibraryResourceType.MODEL:
|
||||
if not payload.gender or not payload.age_group:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Model resources require gender and age_group",
|
||||
)
|
||||
elif payload.resource_type == LibraryResourceType.SCENE:
|
||||
if not payload.environment:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Scene resources require environment",
|
||||
)
|
||||
elif payload.resource_type == LibraryResourceType.GARMENT and not payload.category:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Garment resources require category",
|
||||
)
|
||||
|
||||
def _find_role(
|
||||
self,
|
||||
files: list[LibraryResourceFileORM],
|
||||
role: LibraryFileRole,
|
||||
) -> LibraryResourceFileORM:
|
||||
return next(file for file in files if file.file_role == role)
|
||||
|
||||
def _to_read(
|
||||
self,
|
||||
resource: LibraryResourceORM,
|
||||
*,
|
||||
files: list[LibraryResourceFileORM] | None = None,
|
||||
) -> LibraryResourceRead:
|
||||
resource_files = files if files is not None else list(resource.files)
|
||||
files_sorted = sorted(resource_files, key=lambda item: (item.sort_order, item.id))
|
||||
cover = next((item for item in files_sorted if item.id == resource.cover_file_id), None)
|
||||
original = next((item for item in files_sorted if item.id == resource.original_file_id), None)
|
||||
return LibraryResourceRead(
|
||||
id=resource.id,
|
||||
resource_type=resource.resource_type,
|
||||
name=resource.name,
|
||||
description=resource.description,
|
||||
tags=resource.tags,
|
||||
status=resource.status,
|
||||
gender=resource.gender,
|
||||
age_group=resource.age_group,
|
||||
pose_id=resource.pose_id,
|
||||
environment=resource.environment,
|
||||
category=resource.category,
|
||||
cover_url=cover.public_url if cover else None,
|
||||
original_url=original.public_url if original else None,
|
||||
files=[
|
||||
LibraryResourceFileRead(
|
||||
id=file.id,
|
||||
file_role=file.file_role,
|
||||
storage_key=file.storage_key,
|
||||
public_url=file.public_url,
|
||||
bucket=file.bucket,
|
||||
mime_type=file.mime_type,
|
||||
size_bytes=file.size_bytes,
|
||||
sort_order=file.sort_order,
|
||||
width=file.width,
|
||||
height=file.height,
|
||||
created_at=file.created_at,
|
||||
)
|
||||
for file in files_sorted
|
||||
],
|
||||
created_at=resource.created_at,
|
||||
updated_at=resource.updated_at,
|
||||
)
|
||||
|
||||
async def _get_resource_or_404(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
resource_id: int,
|
||||
) -> LibraryResourceORM:
|
||||
result = await session.execute(
|
||||
select(LibraryResourceORM)
|
||||
.options(selectinload(LibraryResourceORM.files))
|
||||
.where(LibraryResourceORM.id == resource_id)
|
||||
)
|
||||
resource = result.scalar_one_or_none()
|
||||
if resource is None:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Resource not found")
|
||||
return resource
|
||||
|
||||
def _validate_existing_resource(self, resource: LibraryResourceORM) -> None:
|
||||
if not resource.name.strip():
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Resource name is required")
|
||||
|
||||
if resource.resource_type == LibraryResourceType.MODEL:
|
||||
if not resource.gender or not resource.age_group:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Model resources require gender and age_group",
|
||||
)
|
||||
elif resource.resource_type == LibraryResourceType.SCENE:
|
||||
if not resource.environment:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Scene resources require environment",
|
||||
)
|
||||
elif resource.resource_type == LibraryResourceType.GARMENT and not resource.category:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Garment resources require category",
|
||||
)
|
||||
Reference in New Issue
Block a user