Files
auto-virtual-tryon/tests/test_gemini_provider.py

151 lines
5.1 KiB
Python

"""Tests for the Gemini image provider."""
from __future__ import annotations
import base64
import httpx
import pytest
from app.infra.image_generation.base import SourceImage
from app.infra.image_generation.gemini_provider import GeminiImageProvider
@pytest.mark.asyncio
async def test_gemini_provider_uses_configured_base_url_and_model():
    """The provider should call the configured endpoint and decode the returned image bytes.

    The mock transport only *records* the outgoing request; every assertion
    about it runs after ``generate_tryon_image`` returns. Asserting inside the
    handler would raise from within the provider's HTTP call, where the
    provider's own error handling (it retries some transport failures) could
    mask or rewrap the failure.
    """
    expected_url = "https://gemini.example.com/custom/models/gemini-test-model:generateContent"
    png_bytes = b"generated-png"
    person_bytes = b"person-bytes"
    garment_bytes = b"garment-bytes"
    captured_requests: list[httpx.Request] = []

    async def handler(request: httpx.Request) -> httpx.Response:
        # Record the request for inspection after the call, then return a
        # canned response: a text part followed by an inline image part.
        captured_requests.append(request)
        return httpx.Response(
            200,
            json={
                "candidates": [
                    {
                        "content": {
                            "parts": [
                                {"text": "done"},
                                {
                                    "inlineData": {
                                        "mimeType": "image/png",
                                        "data": base64.b64encode(png_bytes).decode("utf-8"),
                                    }
                                },
                            ]
                        }
                    }
                ]
            },
        )

    # The test owns the client, so it is closed deterministically on exit
    # without touching the provider's private ``_http_client`` attribute.
    async with httpx.AsyncClient(transport=httpx.MockTransport(handler)) as http_client:
        provider = GeminiImageProvider(
            api_key="test-key",
            base_url="https://gemini.example.com/custom",
            model="gemini-test-model",
            timeout_seconds=5,
            http_client=http_client,
        )
        result = await provider.generate_tryon_image(
            prompt="Dress the person in the garment",
            person_image=SourceImage(
                url="https://images.example.com/person.png",
                mime_type="image/png",
                data=person_bytes,
            ),
            garment_image=SourceImage(
                url="https://images.example.com/garment.png",
                mime_type="image/png",
                data=garment_bytes,
            ),
        )

    # A successful first response must not trigger any further request.
    assert len(captured_requests) == 1
    request = captured_requests[0]
    assert str(request.url) == expected_url
    assert request.headers["x-goog-api-key"] == "test-key"

    # MockTransport reads the body before invoking the handler, so
    # ``content`` is available here.
    payload = request.content.decode("utf-8")
    # The model is selected through the URL path, not the request body.
    assert "gemini-test-model" not in payload
    assert "Dress the person in the garment" in payload
    # Both source images are embedded in the body as base64 text.
    assert base64.b64encode(person_bytes).decode("ascii") in payload
    assert base64.b64encode(garment_bytes).decode("ascii") in payload

    assert result.image_bytes == png_bytes
    assert result.mime_type == "image/png"
    assert result.provider == "gemini"
    assert result.model == "gemini-test-model"
@pytest.mark.asyncio
async def test_gemini_provider_retries_remote_disconnect_and_uses_request_timeout():
    """The provider should retry transient transport failures and pass timeout per request.

    The first ``post`` fails with ``httpx.RemoteProtocolError`` and the second
    succeeds. The fake client only records each attempt's URL and ``timeout``
    keyword; the assertions run after the call. This keeps a mismatch from
    being raised inside the provider's retry handling, where it could be
    swallowed or retried.
    """
    expected_url = "https://gemini.example.com/custom/models/gemini-test-model:generateContent"
    jpeg_bytes = b"generated-jpeg"

    class FakeClient:
        """Stand-in for ``httpx.AsyncClient`` exposing only ``post`` and ``aclose``."""

        def __init__(self) -> None:
            self.attempts = 0
            self.urls: list[str] = []
            # ``object`` because the provider may forward an int, float or httpx.Timeout.
            self.timeouts: list[object] = []

        async def post(self, url: str, **kwargs) -> httpx.Response:
            self.attempts += 1
            self.urls.append(url)
            self.timeouts.append(kwargs.get("timeout"))
            if self.attempts == 1:
                # Simulate a transient transport failure on the first attempt only.
                raise httpx.RemoteProtocolError("Server disconnected without sending a response.")
            return httpx.Response(
                200,
                json={
                    "candidates": [
                        {
                            "content": {
                                "parts": [
                                    {
                                        "inlineData": {
                                            "mimeType": "image/jpeg",
                                            "data": base64.b64encode(jpeg_bytes).decode("utf-8"),
                                        }
                                    }
                                ]
                            }
                        }
                    ]
                },
                # Attach a request so Response methods that require one
                # (e.g. ``raise_for_status``) work on this hand-built response.
                request=httpx.Request("POST", url),
            )

        async def aclose(self) -> None:
            return None

    fake_client = FakeClient()
    provider = GeminiImageProvider(
        api_key="test-key",
        base_url="https://gemini.example.com/custom",
        model="gemini-test-model",
        timeout_seconds=300,
        http_client=fake_client,
    )
    result = await provider.generate_tryon_image(
        prompt="Dress the person in the garment",
        person_image=SourceImage(
            url="https://images.example.com/person.png",
            mime_type="image/png",
            data=b"person-bytes",
        ),
        garment_image=SourceImage(
            url="https://images.example.com/garment.png",
            mime_type="image/png",
            data=b"garment-bytes",
        ),
    )

    # One failed attempt plus one successful retry, both against the same endpoint.
    assert fake_client.attempts == 2
    assert fake_client.urls == [expected_url, expected_url]
    # ``timeout_seconds`` is forwarded as the per-request ``timeout`` on every attempt.
    assert fake_client.timeouts == [300, 300]
    assert result.image_bytes == jpeg_bytes
    assert result.mime_type == "image/jpeg"
    assert result.provider == "gemini"
    assert result.model == "gemini-test-model"