151 lines
5.1 KiB
Python
151 lines
5.1 KiB
Python
"""Tests for the Gemini image provider."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import base64
|
|
|
|
import httpx
|
|
import pytest
|
|
|
|
from app.infra.image_generation.base import SourceImage
|
|
from app.infra.image_generation.gemini_provider import GeminiImageProvider
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_gemini_provider_uses_configured_base_url_and_model():
    """The provider should call the configured endpoint and decode the returned image bytes."""

    expected_url = "https://gemini.example.com/custom/models/gemini-test-model:generateContent"
    png_bytes = b"generated-png"
    person_bytes = b"person-bytes"
    garment_bytes = b"garment-bytes"

    # Requests are recorded here and inspected after the call rather than asserted
    # inside the handler: an AssertionError raised from the transport would travel
    # through the provider's own error handling (which may catch or retry it) and
    # surface as a misleading failure, or not at all.
    captured: list[httpx.Request] = []

    async def handler(request: httpx.Request) -> httpx.Response:
        captured.append(request)
        return httpx.Response(
            200,
            json={
                "candidates": [
                    {
                        "content": {
                            "parts": [
                                {"text": "done"},
                                {
                                    "inlineData": {
                                        "mimeType": "image/png",
                                        "data": base64.b64encode(png_bytes).decode("utf-8"),
                                    }
                                },
                            ]
                        }
                    }
                ]
            },
        )

    # The test owns the client's lifecycle, so it closes it itself instead of
    # reaching into the provider's private ``_http_client`` attribute.
    async with httpx.AsyncClient(transport=httpx.MockTransport(handler)) as http_client:
        provider = GeminiImageProvider(
            api_key="test-key",
            base_url="https://gemini.example.com/custom",
            model="gemini-test-model",
            timeout_seconds=5,
            http_client=http_client,
        )
        result = await provider.generate_tryon_image(
            prompt="Dress the person in the garment",
            person_image=SourceImage(
                url="https://images.example.com/person.png",
                mime_type="image/png",
                data=person_bytes,
            ),
            garment_image=SourceImage(
                url="https://images.example.com/garment.png",
                mime_type="image/png",
                data=garment_bytes,
            ),
        )

    # A successful response must not trigger any extra (retry) requests.
    assert len(captured) == 1
    request = captured[0]
    assert str(request.url) == expected_url
    assert request.headers["x-goog-api-key"] == "test-key"

    payload = request.read().decode("utf-8")
    # The model is addressed through the URL path, so it must not be repeated in the body.
    assert "gemini-test-model" not in payload
    assert "Dress the person in the garment" in payload
    # Both source images must be sent inline as base64.
    assert base64.b64encode(person_bytes).decode("ascii") in payload
    assert base64.b64encode(garment_bytes).decode("ascii") in payload

    assert result.image_bytes == png_bytes
    assert result.mime_type == "image/png"
    assert result.provider == "gemini"
    assert result.model == "gemini-test-model"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_gemini_provider_retries_remote_disconnect_and_uses_request_timeout():
    """The provider should retry transient transport failures and pass timeout per request."""

    expected_url = "https://gemini.example.com/custom/models/gemini-test-model:generateContent"
    jpeg_bytes = b"generated-jpeg"

    class FakeClient:
        """Minimal stand-in for an async HTTP client whose first ``post`` disconnects.

        Each call's URL and ``timeout`` keyword are recorded so the test can check
        them after the provider returns. Nothing is asserted inside ``post``: an
        AssertionError raised there could be caught or retried by the provider.
        """

        def __init__(self) -> None:
            self.attempts = 0
            self.urls: list[str] = []
            self.timeouts: list[object] = []

        async def post(self, url: str, **kwargs) -> httpx.Response:
            self.attempts += 1
            self.urls.append(url)
            self.timeouts.append(kwargs.get("timeout"))
            if self.attempts == 1:
                # Simulate the server dropping the connection before responding.
                raise httpx.RemoteProtocolError("Server disconnected without sending a response.")
            return httpx.Response(
                200,
                json={
                    "candidates": [
                        {
                            "content": {
                                "parts": [
                                    {
                                        "inlineData": {
                                            "mimeType": "image/jpeg",
                                            "data": base64.b64encode(jpeg_bytes).decode("utf-8"),
                                        }
                                    }
                                ]
                            }
                        }
                    ]
                },
                # Attaching a request lets methods such as raise_for_status() work.
                request=httpx.Request("POST", url),
            )

        async def aclose(self) -> None:
            return None

    fake_client = FakeClient()

    provider = GeminiImageProvider(
        api_key="test-key",
        base_url="https://gemini.example.com/custom",
        model="gemini-test-model",
        timeout_seconds=300,
        http_client=fake_client,
    )

    result = await provider.generate_tryon_image(
        prompt="Dress the person in the garment",
        person_image=SourceImage(
            url="https://images.example.com/person.png",
            mime_type="image/png",
            data=b"person-bytes",
        ),
        garment_image=SourceImage(
            url="https://images.example.com/garment.png",
            mime_type="image/png",
            data=b"garment-bytes",
        ),
    )

    # One failed attempt plus one successful retry, both against the same endpoint,
    # each carrying the configured timeout explicitly.
    assert fake_client.attempts == 2
    assert fake_client.urls == [expected_url, expected_url]
    assert fake_client.timeouts == [300, 300]

    assert result.image_bytes == jpeg_bytes
    assert result.mime_type == "image/jpeg"
    assert result.provider == "gemini"
    assert result.model == "gemini-test-model"
|