feat: add resource library and real image workflow
This commit is contained in:
150
tests/test_gemini_provider.py
Normal file
150
tests/test_gemini_provider.py
Normal file
@@ -0,0 +1,150 @@
|
||||
"""Tests for the Gemini image provider."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from app.infra.image_generation.base import SourceImage
|
||||
from app.infra.image_generation.gemini_provider import GeminiImageProvider
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_gemini_provider_uses_configured_base_url_and_model():
|
||||
"""The provider should call the configured endpoint and decode the returned image bytes."""
|
||||
|
||||
expected_url = "https://gemini.example.com/custom/models/gemini-test-model:generateContent"
|
||||
png_bytes = b"generated-png"
|
||||
|
||||
async def handler(request: httpx.Request) -> httpx.Response:
|
||||
assert str(request.url) == expected_url
|
||||
assert request.headers["x-goog-api-key"] == "test-key"
|
||||
payload = request.read().decode("utf-8")
|
||||
assert "gemini-test-model" not in payload
|
||||
assert "Dress the person in the garment" in payload
|
||||
assert "cGVyc29uLWJ5dGVz" in payload
|
||||
assert "Z2FybWVudC1ieXRlcw==" in payload
|
||||
return httpx.Response(
|
||||
200,
|
||||
json={
|
||||
"candidates": [
|
||||
{
|
||||
"content": {
|
||||
"parts": [
|
||||
{"text": "done"},
|
||||
{
|
||||
"inlineData": {
|
||||
"mimeType": "image/png",
|
||||
"data": base64.b64encode(png_bytes).decode("utf-8"),
|
||||
}
|
||||
},
|
||||
]
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
provider = GeminiImageProvider(
|
||||
api_key="test-key",
|
||||
base_url="https://gemini.example.com/custom",
|
||||
model="gemini-test-model",
|
||||
timeout_seconds=5,
|
||||
http_client=httpx.AsyncClient(transport=httpx.MockTransport(handler)),
|
||||
)
|
||||
|
||||
try:
|
||||
result = await provider.generate_tryon_image(
|
||||
prompt="Dress the person in the garment",
|
||||
person_image=SourceImage(
|
||||
url="https://images.example.com/person.png",
|
||||
mime_type="image/png",
|
||||
data=b"person-bytes",
|
||||
),
|
||||
garment_image=SourceImage(
|
||||
url="https://images.example.com/garment.png",
|
||||
mime_type="image/png",
|
||||
data=b"garment-bytes",
|
||||
),
|
||||
)
|
||||
finally:
|
||||
await provider._http_client.aclose()
|
||||
|
||||
assert result.image_bytes == png_bytes
|
||||
assert result.mime_type == "image/png"
|
||||
assert result.provider == "gemini"
|
||||
assert result.model == "gemini-test-model"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_gemini_provider_retries_remote_disconnect_and_uses_request_timeout():
|
||||
"""The provider should retry transient transport failures and pass timeout per request."""
|
||||
|
||||
expected_url = "https://gemini.example.com/custom/models/gemini-test-model:generateContent"
|
||||
jpeg_bytes = b"generated-jpeg"
|
||||
|
||||
class FakeClient:
|
||||
def __init__(self) -> None:
|
||||
self.attempts = 0
|
||||
self.timeouts: list[int | None] = []
|
||||
|
||||
async def post(self, url: str, **kwargs) -> httpx.Response:
|
||||
self.attempts += 1
|
||||
self.timeouts.append(kwargs.get("timeout"))
|
||||
assert url == expected_url
|
||||
if self.attempts == 1:
|
||||
raise httpx.RemoteProtocolError("Server disconnected without sending a response.")
|
||||
return httpx.Response(
|
||||
200,
|
||||
json={
|
||||
"candidates": [
|
||||
{
|
||||
"content": {
|
||||
"parts": [
|
||||
{
|
||||
"inlineData": {
|
||||
"mimeType": "image/jpeg",
|
||||
"data": base64.b64encode(jpeg_bytes).decode("utf-8"),
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
request=httpx.Request("POST", url),
|
||||
)
|
||||
|
||||
async def aclose(self) -> None:
|
||||
return None
|
||||
|
||||
fake_client = FakeClient()
|
||||
|
||||
provider = GeminiImageProvider(
|
||||
api_key="test-key",
|
||||
base_url="https://gemini.example.com/custom",
|
||||
model="gemini-test-model",
|
||||
timeout_seconds=300,
|
||||
http_client=fake_client,
|
||||
)
|
||||
|
||||
result = await provider.generate_tryon_image(
|
||||
prompt="Dress the person in the garment",
|
||||
person_image=SourceImage(
|
||||
url="https://images.example.com/person.png",
|
||||
mime_type="image/png",
|
||||
data=b"person-bytes",
|
||||
),
|
||||
garment_image=SourceImage(
|
||||
url="https://images.example.com/garment.png",
|
||||
mime_type="image/png",
|
||||
data=b"garment-bytes",
|
||||
),
|
||||
)
|
||||
|
||||
assert fake_client.attempts == 2
|
||||
assert fake_client.timeouts == [300, 300]
|
||||
assert result.image_bytes == jpeg_bytes
|
||||
assert result.mime_type == "image/jpeg"
|
||||
Reference in New Issue
Block a user