"""Order application service.""" from fastapi import HTTPException, status from math import ceil from sqlalchemy import String, cast, func, or_, select from sqlalchemy.orm import selectinload from app.api.schemas.asset import AssetRead from app.api.schemas.order import ( CreateOrderRequest, CreateOrderResponse, OrderDetailResponse, OrderListItemResponse, OrderListResponse, ) from app.application.services.revision_service import RevisionService from app.application.services.workflow_service import WorkflowService from app.domain.enums import CustomerLevel, OrderStatus, ServiceMode from app.infra.db.models.order import OrderORM from app.infra.db.models.workflow_run import WorkflowRunORM from app.workers.workflows.types import PipelineWorkflowInput class OrderService: """Application service for order management.""" def __init__(self) -> None: self.workflow_service = WorkflowService() self.revision_service = RevisionService() async def create_order(self, session, payload: CreateOrderRequest) -> CreateOrderResponse: """Create an order, persist a workflow run, and start Temporal execution.""" self._validate_mode(payload.customer_level, payload.service_mode) order = OrderORM( customer_level=payload.customer_level, service_mode=payload.service_mode, status=OrderStatus.CREATED, model_id=payload.model_id, pose_id=payload.pose_id, garment_asset_id=payload.garment_asset_id, scene_ref_asset_id=payload.scene_ref_asset_id, ) session.add(order) await session.flush() workflow_id = f"order-{order.id}" workflow_run = WorkflowRunORM( order_id=order.id, workflow_id=workflow_id, workflow_type=self.workflow_service.workflow_type_for_mode(payload.service_mode), status=OrderStatus.CREATED, ) session.add(workflow_run) await session.commit() workflow_input = PipelineWorkflowInput( order_id=order.id, workflow_run_id=workflow_run.id, customer_level=order.customer_level, service_mode=order.service_mode, model_id=order.model_id, pose_id=order.pose_id, garment_asset_id=order.garment_asset_id, scene_ref_asset_id=order.scene_ref_asset_id, ) try: await self.workflow_service.start_workflow(workflow_input) except Exception as exc: order.status = OrderStatus.FAILED workflow_run.status = OrderStatus.FAILED await session.commit() raise HTTPException( status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=f"Failed to start Temporal workflow: {exc}", ) from exc return CreateOrderResponse(order_id=order.id, workflow_id=workflow_id, status=order.status) async def get_order(self, session, order_id: int) -> OrderDetailResponse: """Return a single order with workflow context and final asset.""" result = await session.execute( select(OrderORM) .where(OrderORM.id == order_id) .options( selectinload(OrderORM.assets), selectinload(OrderORM.workflow_runs), ) ) order = result.scalar_one_or_none() if order is None: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Order not found") workflow_run = order.workflow_runs[0] if order.workflow_runs else None final_asset = next((asset for asset in order.assets if asset.id == order.final_asset_id), None) snapshot = await self.revision_service.get_revision_snapshot(session, order_id) return OrderDetailResponse( order_id=order.id, customer_level=order.customer_level, service_mode=order.service_mode, status=order.status, model_id=order.model_id, pose_id=order.pose_id, garment_asset_id=order.garment_asset_id, scene_ref_asset_id=order.scene_ref_asset_id, final_asset_id=order.final_asset_id, workflow_id=workflow_run.workflow_id if workflow_run else None, current_step=workflow_run.current_step if workflow_run else None, current_revision_asset_id=snapshot.current_revision_asset_id, current_revision_version=snapshot.current_revision_version, latest_revision_asset_id=snapshot.latest_revision_asset_id, latest_revision_version=snapshot.latest_revision_version, revision_count=snapshot.revision_count, review_task_status=snapshot.review_task_status, pending_manual_confirm=snapshot.pending_manual_confirm, final_asset=AssetRead.model_validate(final_asset) if final_asset else None, created_at=order.created_at, updated_at=order.updated_at, ) async def list_orders( self, session, *, page: int = 1, limit: int = 20, query: str | None = None, status_filter: OrderStatus | None = None, order_id: int | None = None, ) -> OrderListResponse: """Return recent orders for dashboard overview pages.""" filters = [] if status_filter is not None: filters.append(OrderORM.status == status_filter) if order_id is not None: filters.append(OrderORM.id == order_id) if query: search_term = query.strip() if search_term: filters.append( or_( cast(OrderORM.id, String).ilike(f"{search_term}%"), OrderORM.workflow_runs.any( WorkflowRunORM.workflow_id.ilike(f"%{search_term}%") ), ) ) query = select(OrderORM).options(selectinload(OrderORM.workflow_runs)) count_query = select(func.count()).select_from(OrderORM) if filters: query = query.where(*filters) count_query = count_query.where(*filters) total = (await session.execute(count_query)).scalar_one() total_pages = ceil(total / limit) if total else 0 offset = (page - 1) * limit query = query.order_by(OrderORM.updated_at.desc(), OrderORM.id.desc()).offset(offset).limit(limit) result = await session.execute(query) orders = result.scalars().all() items = [] for order in orders: workflow_run = order.workflow_runs[0] if order.workflow_runs else None snapshot = await self.revision_service.get_revision_snapshot(session, order.id) items.append( OrderListItemResponse( order_id=order.id, workflow_id=workflow_run.workflow_id if workflow_run else None, customer_level=order.customer_level, service_mode=order.service_mode, status=order.status, current_step=workflow_run.current_step if workflow_run else None, updated_at=order.updated_at, final_asset_id=order.final_asset_id, review_task_status=snapshot.review_task_status, latest_revision_asset_id=snapshot.latest_revision_asset_id, latest_revision_version=snapshot.latest_revision_version, revision_count=snapshot.revision_count, pending_manual_confirm=snapshot.pending_manual_confirm, ) ) return OrderListResponse( page=page, limit=limit, total=total, total_pages=total_pages, items=items, ) @staticmethod def _validate_mode(customer_level: CustomerLevel, service_mode: ServiceMode) -> None: """Validate the allowed customer-level and service-mode combinations.""" if customer_level == CustomerLevel.LOW and service_mode != ServiceMode.AUTO_BASIC: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="Low-level customers only support auto_basic", ) if customer_level == CustomerLevel.MID and service_mode != ServiceMode.SEMI_PRO: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="Mid-level customers only support semi_pro", )