Implement FastAPI Temporal MVP pipeline
This commit is contained in:
26
app/application/services/asset_service.py
Normal file
26
app/application/services/asset_service.py
Normal file
@@ -0,0 +1,26 @@
|
||||
"""Asset application service."""
|
||||
|
||||
from fastapi import HTTPException, status
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.schemas.asset import AssetRead
|
||||
from app.infra.db.models.asset import AssetORM
|
||||
from app.infra.db.models.order import OrderORM
|
||||
|
||||
|
||||
class AssetService:
|
||||
"""Application service for asset queries."""
|
||||
|
||||
async def list_order_assets(self, session: AsyncSession, order_id: int) -> list[AssetRead]:
|
||||
"""Return all assets belonging to an order."""
|
||||
|
||||
order = await session.get(OrderORM, order_id)
|
||||
if order is None:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Order not found")
|
||||
|
||||
result = await session.execute(
|
||||
select(AssetORM).where(AssetORM.order_id == order_id).order_by(AssetORM.created_at.asc())
|
||||
)
|
||||
return [AssetRead.model_validate(asset) for asset in result.scalars().all()]
|
||||
|
||||
122
app/application/services/order_service.py
Normal file
122
app/application/services/order_service.py
Normal file
@@ -0,0 +1,122 @@
|
||||
"""Order application service."""
|
||||
|
||||
from fastapi import HTTPException, status
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.api.schemas.asset import AssetRead
|
||||
from app.api.schemas.order import CreateOrderRequest, CreateOrderResponse, OrderDetailResponse
|
||||
from app.application.services.workflow_service import WorkflowService
|
||||
from app.domain.enums import CustomerLevel, OrderStatus, ServiceMode
|
||||
from app.infra.db.models.order import OrderORM
|
||||
from app.infra.db.models.workflow_run import WorkflowRunORM
|
||||
from app.workers.workflows.types import PipelineWorkflowInput
|
||||
|
||||
|
||||
class OrderService:
|
||||
"""Application service for order management."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.workflow_service = WorkflowService()
|
||||
|
||||
async def create_order(self, session, payload: CreateOrderRequest) -> CreateOrderResponse:
|
||||
"""Create an order, persist a workflow run, and start Temporal execution."""
|
||||
|
||||
self._validate_mode(payload.customer_level, payload.service_mode)
|
||||
|
||||
order = OrderORM(
|
||||
customer_level=payload.customer_level,
|
||||
service_mode=payload.service_mode,
|
||||
status=OrderStatus.CREATED,
|
||||
model_id=payload.model_id,
|
||||
pose_id=payload.pose_id,
|
||||
garment_asset_id=payload.garment_asset_id,
|
||||
scene_ref_asset_id=payload.scene_ref_asset_id,
|
||||
)
|
||||
session.add(order)
|
||||
await session.flush()
|
||||
|
||||
workflow_id = f"order-{order.id}"
|
||||
workflow_run = WorkflowRunORM(
|
||||
order_id=order.id,
|
||||
workflow_id=workflow_id,
|
||||
workflow_type=self.workflow_service.workflow_type_for_mode(payload.service_mode),
|
||||
status=OrderStatus.CREATED,
|
||||
)
|
||||
session.add(workflow_run)
|
||||
await session.commit()
|
||||
|
||||
workflow_input = PipelineWorkflowInput(
|
||||
order_id=order.id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
customer_level=order.customer_level,
|
||||
service_mode=order.service_mode,
|
||||
model_id=order.model_id,
|
||||
pose_id=order.pose_id,
|
||||
garment_asset_id=order.garment_asset_id,
|
||||
scene_ref_asset_id=order.scene_ref_asset_id,
|
||||
)
|
||||
|
||||
try:
|
||||
await self.workflow_service.start_workflow(workflow_input)
|
||||
except Exception as exc:
|
||||
order.status = OrderStatus.FAILED
|
||||
workflow_run.status = OrderStatus.FAILED
|
||||
await session.commit()
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail=f"Failed to start Temporal workflow: {exc}",
|
||||
) from exc
|
||||
|
||||
return CreateOrderResponse(order_id=order.id, workflow_id=workflow_id, status=order.status)
|
||||
|
||||
async def get_order(self, session, order_id: int) -> OrderDetailResponse:
|
||||
"""Return a single order with workflow context and final asset."""
|
||||
|
||||
result = await session.execute(
|
||||
select(OrderORM)
|
||||
.where(OrderORM.id == order_id)
|
||||
.options(
|
||||
selectinload(OrderORM.assets),
|
||||
selectinload(OrderORM.workflow_runs),
|
||||
)
|
||||
)
|
||||
order = result.scalar_one_or_none()
|
||||
if order is None:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Order not found")
|
||||
|
||||
workflow_run = order.workflow_runs[0] if order.workflow_runs else None
|
||||
final_asset = next((asset for asset in order.assets if asset.id == order.final_asset_id), None)
|
||||
|
||||
return OrderDetailResponse(
|
||||
order_id=order.id,
|
||||
customer_level=order.customer_level,
|
||||
service_mode=order.service_mode,
|
||||
status=order.status,
|
||||
model_id=order.model_id,
|
||||
pose_id=order.pose_id,
|
||||
garment_asset_id=order.garment_asset_id,
|
||||
scene_ref_asset_id=order.scene_ref_asset_id,
|
||||
final_asset_id=order.final_asset_id,
|
||||
workflow_id=workflow_run.workflow_id if workflow_run else None,
|
||||
current_step=workflow_run.current_step if workflow_run else None,
|
||||
final_asset=AssetRead.model_validate(final_asset) if final_asset else None,
|
||||
created_at=order.created_at,
|
||||
updated_at=order.updated_at,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _validate_mode(customer_level: CustomerLevel, service_mode: ServiceMode) -> None:
|
||||
"""Validate the allowed customer-level and service-mode combinations."""
|
||||
|
||||
if customer_level == CustomerLevel.LOW and service_mode != ServiceMode.AUTO_BASIC:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Low-level customers only support auto_basic",
|
||||
)
|
||||
if customer_level == CustomerLevel.MID and service_mode != ServiceMode.SEMI_PRO:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Mid-level customers only support semi_pro",
|
||||
)
|
||||
|
||||
112
app/application/services/review_service.py
Normal file
112
app/application/services/review_service.py
Normal file
@@ -0,0 +1,112 @@
|
||||
"""Review application service."""
|
||||
|
||||
from fastapi import HTTPException, status
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.api.schemas.review import PendingReviewResponse, SubmitReviewRequest, SubmitReviewResponse
|
||||
from app.application.services.workflow_service import WorkflowService
|
||||
from app.domain.enums import OrderStatus, ReviewTaskStatus
|
||||
from app.infra.db.models.asset import AssetORM
|
||||
from app.infra.db.models.order import OrderORM
|
||||
from app.infra.db.models.review_task import ReviewTaskORM
|
||||
from app.infra.db.models.workflow_run import WorkflowRunORM
|
||||
from app.workers.workflows.types import ReviewSignalPayload
|
||||
|
||||
|
||||
class ReviewService:
|
||||
"""Application service for review flows."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.workflow_service = WorkflowService()
|
||||
|
||||
async def list_pending_reviews(self, session) -> list[PendingReviewResponse]:
|
||||
"""Return all pending review tasks."""
|
||||
|
||||
result = await session.execute(
|
||||
select(ReviewTaskORM, WorkflowRunORM)
|
||||
.join(WorkflowRunORM, WorkflowRunORM.order_id == ReviewTaskORM.order_id)
|
||||
.where(ReviewTaskORM.status == ReviewTaskStatus.PENDING)
|
||||
.order_by(ReviewTaskORM.created_at.asc())
|
||||
)
|
||||
|
||||
return [
|
||||
PendingReviewResponse(
|
||||
review_task_id=review_task.id,
|
||||
order_id=review_task.order_id,
|
||||
workflow_id=workflow_run.workflow_id,
|
||||
current_step=workflow_run.current_step,
|
||||
created_at=review_task.created_at,
|
||||
)
|
||||
for review_task, workflow_run in result.all()
|
||||
]
|
||||
|
||||
async def submit_review(self, session, order_id: int, payload: SubmitReviewRequest) -> SubmitReviewResponse:
|
||||
"""Persist the review submission and signal the Temporal workflow."""
|
||||
|
||||
order = await session.get(OrderORM, order_id)
|
||||
if order is None:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Order not found")
|
||||
if order.status != OrderStatus.WAITING_REVIEW:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail="Order is not waiting for review",
|
||||
)
|
||||
|
||||
workflow_result = await session.execute(
|
||||
select(WorkflowRunORM).where(WorkflowRunORM.order_id == order_id)
|
||||
)
|
||||
workflow_run = workflow_result.scalar_one_or_none()
|
||||
if workflow_run is None:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Workflow not found")
|
||||
|
||||
if payload.selected_asset_id is not None:
|
||||
asset = await session.get(AssetORM, payload.selected_asset_id)
|
||||
if asset is None or asset.order_id != order_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Selected asset does not belong to the order",
|
||||
)
|
||||
|
||||
pending_result = await session.execute(
|
||||
select(ReviewTaskORM)
|
||||
.where(
|
||||
ReviewTaskORM.order_id == order_id,
|
||||
ReviewTaskORM.status == ReviewTaskStatus.PENDING,
|
||||
)
|
||||
.order_by(ReviewTaskORM.created_at.desc())
|
||||
)
|
||||
review_task = pending_result.scalars().first()
|
||||
if review_task is None:
|
||||
review_task = ReviewTaskORM(order_id=order_id, status=ReviewTaskStatus.SUBMITTED)
|
||||
session.add(review_task)
|
||||
|
||||
review_task.status = ReviewTaskStatus.SUBMITTED
|
||||
review_task.decision = payload.decision
|
||||
review_task.reviewer_id = payload.reviewer_id
|
||||
review_task.selected_asset_id = payload.selected_asset_id
|
||||
review_task.comment = payload.comment
|
||||
await session.commit()
|
||||
|
||||
try:
|
||||
await self.workflow_service.signal_review(
|
||||
workflow_run.workflow_id,
|
||||
ReviewSignalPayload(
|
||||
decision=payload.decision,
|
||||
reviewer_id=payload.reviewer_id,
|
||||
selected_asset_id=payload.selected_asset_id,
|
||||
comment=payload.comment,
|
||||
),
|
||||
)
|
||||
except Exception as exc:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail=f"Failed to signal Temporal workflow: {exc}",
|
||||
) from exc
|
||||
|
||||
return SubmitReviewResponse(
|
||||
order_id=order_id,
|
||||
workflow_id=workflow_run.workflow_id,
|
||||
decision=payload.decision,
|
||||
status="submitted",
|
||||
)
|
||||
|
||||
77
app/application/services/workflow_service.py
Normal file
77
app/application/services/workflow_service.py
Normal file
@@ -0,0 +1,77 @@
|
||||
"""Temporal workflow application service."""
|
||||
|
||||
from datetime import timedelta
|
||||
|
||||
from fastapi import HTTPException, status
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.api.schemas.workflow import WorkflowStatusResponse, WorkflowStepRead
|
||||
from app.domain.enums import ServiceMode
|
||||
from app.infra.db.models.workflow_run import WorkflowRunORM
|
||||
from app.infra.temporal.client import get_temporal_client
|
||||
from app.infra.temporal.task_queues import IMAGE_PIPELINE_CONTROL_TASK_QUEUE
|
||||
from app.workers.workflows.low_end_pipeline import LowEndPipelineWorkflow
|
||||
from app.workers.workflows.mid_end_pipeline import MidEndPipelineWorkflow
|
||||
from app.workers.workflows.types import PipelineWorkflowInput, ReviewSignalPayload
|
||||
|
||||
|
||||
class WorkflowService:
|
||||
"""Application service for Temporal workflow orchestration."""
|
||||
|
||||
@staticmethod
|
||||
def workflow_type_for_mode(service_mode: ServiceMode) -> str:
|
||||
"""Return the workflow class name for a service mode."""
|
||||
|
||||
if service_mode == ServiceMode.AUTO_BASIC:
|
||||
return LowEndPipelineWorkflow.__name__
|
||||
return MidEndPipelineWorkflow.__name__
|
||||
|
||||
async def start_workflow(self, workflow_input: PipelineWorkflowInput) -> None:
|
||||
"""Start the appropriate Temporal workflow for an order."""
|
||||
|
||||
client = await get_temporal_client()
|
||||
workflow_id = f"order-{workflow_input.order_id}"
|
||||
workflow_callable = (
|
||||
LowEndPipelineWorkflow.run
|
||||
if workflow_input.service_mode == ServiceMode.AUTO_BASIC
|
||||
else MidEndPipelineWorkflow.run
|
||||
)
|
||||
await client.start_workflow(
|
||||
workflow_callable,
|
||||
workflow_input,
|
||||
id=workflow_id,
|
||||
task_queue=IMAGE_PIPELINE_CONTROL_TASK_QUEUE,
|
||||
run_timeout=timedelta(minutes=30),
|
||||
task_timeout=timedelta(seconds=30),
|
||||
)
|
||||
|
||||
async def signal_review(self, workflow_id: str, payload: ReviewSignalPayload) -> None:
|
||||
"""Send a review signal to a running Temporal workflow."""
|
||||
|
||||
client = await get_temporal_client()
|
||||
handle = client.get_workflow_handle(workflow_id=workflow_id)
|
||||
await handle.signal("submit_review", payload)
|
||||
|
||||
async def get_workflow_status(self, session, order_id: int) -> WorkflowStatusResponse:
|
||||
"""Return persisted workflow execution state for an order."""
|
||||
|
||||
result = await session.execute(
|
||||
select(WorkflowRunORM)
|
||||
.where(WorkflowRunORM.order_id == order_id)
|
||||
.options(selectinload(WorkflowRunORM.steps))
|
||||
)
|
||||
workflow_run = result.scalar_one_or_none()
|
||||
if workflow_run is None:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Workflow not found")
|
||||
|
||||
return WorkflowStatusResponse(
|
||||
order_id=workflow_run.order_id,
|
||||
workflow_id=workflow_run.workflow_id,
|
||||
workflow_type=workflow_run.workflow_type,
|
||||
workflow_status=workflow_run.status,
|
||||
current_step=workflow_run.current_step,
|
||||
steps=[WorkflowStepRead.model_validate(step) for step in workflow_run.steps],
|
||||
created_at=workflow_run.created_at,
|
||||
updated_at=workflow_run.updated_at,
|
||||
)
|
||||
Reference in New Issue
Block a user