"""Tests for the Gemini image provider.""" from __future__ import annotations import base64 import httpx import pytest from app.infra.image_generation.base import SourceImage from app.infra.image_generation.gemini_provider import GeminiImageProvider @pytest.mark.asyncio async def test_gemini_provider_uses_configured_base_url_and_model(): """The provider should call the configured endpoint and decode the returned image bytes.""" expected_url = "https://gemini.example.com/custom/models/gemini-test-model:generateContent" png_bytes = b"generated-png" async def handler(request: httpx.Request) -> httpx.Response: assert str(request.url) == expected_url assert request.headers["x-goog-api-key"] == "test-key" payload = request.read().decode("utf-8") assert "gemini-test-model" not in payload assert "Dress the person in the garment" in payload assert "cGVyc29uLWJ5dGVz" in payload assert "Z2FybWVudC1ieXRlcw==" in payload return httpx.Response( 200, json={ "candidates": [ { "content": { "parts": [ {"text": "done"}, { "inlineData": { "mimeType": "image/png", "data": base64.b64encode(png_bytes).decode("utf-8"), } }, ] } } ] }, ) provider = GeminiImageProvider( api_key="test-key", base_url="https://gemini.example.com/custom", model="gemini-test-model", timeout_seconds=5, http_client=httpx.AsyncClient(transport=httpx.MockTransport(handler)), ) try: result = await provider.generate_tryon_image( prompt="Dress the person in the garment", person_image=SourceImage( url="https://images.example.com/person.png", mime_type="image/png", data=b"person-bytes", ), garment_image=SourceImage( url="https://images.example.com/garment.png", mime_type="image/png", data=b"garment-bytes", ), ) finally: await provider._http_client.aclose() assert result.image_bytes == png_bytes assert result.mime_type == "image/png" assert result.provider == "gemini" assert result.model == "gemini-test-model" @pytest.mark.asyncio async def test_gemini_provider_retries_remote_disconnect_and_uses_request_timeout(): """The provider should retry transient transport failures and pass timeout per request.""" expected_url = "https://gemini.example.com/custom/models/gemini-test-model:generateContent" jpeg_bytes = b"generated-jpeg" class FakeClient: def __init__(self) -> None: self.attempts = 0 self.timeouts: list[int | None] = [] async def post(self, url: str, **kwargs) -> httpx.Response: self.attempts += 1 self.timeouts.append(kwargs.get("timeout")) assert url == expected_url if self.attempts == 1: raise httpx.RemoteProtocolError("Server disconnected without sending a response.") return httpx.Response( 200, json={ "candidates": [ { "content": { "parts": [ { "inlineData": { "mimeType": "image/jpeg", "data": base64.b64encode(jpeg_bytes).decode("utf-8"), } } ] } } ] }, request=httpx.Request("POST", url), ) async def aclose(self) -> None: return None fake_client = FakeClient() provider = GeminiImageProvider( api_key="test-key", base_url="https://gemini.example.com/custom", model="gemini-test-model", timeout_seconds=300, http_client=fake_client, ) result = await provider.generate_tryon_image( prompt="Dress the person in the garment", person_image=SourceImage( url="https://images.example.com/person.png", mime_type="image/png", data=b"person-bytes", ), garment_image=SourceImage( url="https://images.example.com/garment.png", mime_type="image/png", data=b"garment-bytes", ), ) assert fake_client.attempts == 2 assert fake_client.timeouts == [300, 300] assert result.image_bytes == jpeg_bytes assert result.mime_type == "image/jpeg"