"""Temporal workflow service layer.

This layer sits between the API and Temporal and is responsible for:

1. choosing which workflow to start for an order,
2. sending signals to running workflows,
3. reading the workflow state we have persisted in our own database.
"""

from datetime import timedelta
from math import ceil

from fastapi import HTTPException, status
from sqlalchemy import String, cast, func, or_, select
from sqlalchemy.orm import selectinload

from app.api.schemas.workflow import (
    WorkflowListItemResponse,
    WorkflowListResponse,
    WorkflowStatusResponse,
    WorkflowStepRead,
)
from app.application.services.revision_service import RevisionService
from app.domain.enums import OrderStatus, ServiceMode
from app.infra.db.models.workflow_run import WorkflowRunORM
from app.infra.temporal.client import get_temporal_client
from app.infra.temporal.task_queues import IMAGE_PIPELINE_CONTROL_TASK_QUEUE
from app.workers.workflows.low_end_pipeline import LowEndPipelineWorkflow
from app.workers.workflows.mid_end_pipeline import MidEndPipelineWorkflow
from app.workers.workflows.types import PipelineWorkflowInput, ReviewSignalPayload

# Escape character passed to ILIKE so user search text is matched literally.
_LIKE_ESCAPE = "\\"


def _escape_like(value: str) -> str:
    """Escape LIKE metacharacters in *value* using ``_LIKE_ESCAPE``.

    Without this, ``%`` in user input would match any run of characters and
    ``_`` would match any single character (e.g. ``order_1`` would also match
    ``orderX1``). The escape character itself is escaped first so it cannot
    be used to neutralise the escaping.
    """
    return (
        value.replace(_LIKE_ESCAPE, _LIKE_ESCAPE * 2)
        .replace("%", f"{_LIKE_ESCAPE}%")
        .replace("_", f"{_LIKE_ESCAPE}_")
    )


class WorkflowService:
    """Orchestration service in front of Temporal."""

    def __init__(self) -> None:
        self.revision_service = RevisionService()

    @staticmethod
    def _workflow_class_for_mode(service_mode: ServiceMode) -> type:
        """Return the workflow class that handles *service_mode*.

        ``ServiceMode.AUTO_BASIC`` runs the low-end pipeline; every other mode
        runs the mid-end pipeline. Both ``workflow_type_for_mode`` and
        ``start_workflow`` go through this helper so the reported type name
        always matches the workflow that is actually launched.
        """
        if service_mode == ServiceMode.AUTO_BASIC:
            return LowEndPipelineWorkflow
        return MidEndPipelineWorkflow

    @staticmethod
    def workflow_type_for_mode(service_mode: ServiceMode) -> str:
        """Return the workflow type name (the workflow class name) for *service_mode*."""
        return WorkflowService._workflow_class_for_mode(service_mode).__name__

    async def start_workflow(self, workflow_input: PipelineWorkflowInput) -> None:
        """Start the Temporal workflow matching the order's service mode.

        This only *launches* the execution; the actual order of pipeline steps
        is defined inside the workflow classes themselves.

        Args:
            workflow_input: Input passed to the workflow's ``run`` method; its
                ``service_mode`` selects the workflow and its ``order_id``
                determines the workflow id.
        """
        client = await get_temporal_client()
        workflow_cls = self._workflow_class_for_mode(workflow_input.service_mode)
        await client.start_workflow(
            workflow_cls.run,
            workflow_input,
            # The workflow id is fixed to order-{order_id} so the API can look
            # the workflow up again by order later on.
            id=f"order-{workflow_input.order_id}",
            task_queue=IMAGE_PIPELINE_CONTROL_TASK_QUEUE,
            run_timeout=timedelta(minutes=30),
            task_timeout=timedelta(seconds=30),
        )

    async def signal_review(self, workflow_id: str, payload: ReviewSignalPayload) -> None:
        """Send a review signal to a running workflow.

        Args:
            workflow_id: Temporal workflow id (``order-{order_id}`` for
                workflows started by ``start_workflow``).
            payload: Review decision forwarded as the signal argument.
        """
        client = await get_temporal_client()
        handle = client.get_workflow_handle(workflow_id=workflow_id)
        # "submit_review" is the name of the workflow method decorated with
        # @workflow.signal.
        await handle.signal("submit_review", payload)

    async def get_workflow_status(self, session, order_id: int) -> WorkflowStatusResponse:
        """Return the persisted workflow state for an order.

        This reads the state mirror kept in our own database rather than
        querying Temporal history live, which is better suited to a business
        API exposed to clients.

        Args:
            session: SQLAlchemy async session (``execute`` is awaited).
            order_id: Order whose workflow run should be returned.

        Raises:
            HTTPException: 404 if no workflow run exists for ``order_id``.
        """
        result = await session.execute(
            select(WorkflowRunORM)
            .where(WorkflowRunORM.order_id == order_id)
            .options(selectinload(WorkflowRunORM.steps))
        )
        workflow_run = result.scalar_one_or_none()
        if workflow_run is None:
            raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Workflow not found")

        snapshot = await self.revision_service.get_revision_snapshot(session, order_id)
        return WorkflowStatusResponse(
            order_id=workflow_run.order_id,
            workflow_id=workflow_run.workflow_id,
            workflow_type=workflow_run.workflow_type,
            workflow_status=workflow_run.status,
            current_step=workflow_run.current_step,
            current_revision_asset_id=snapshot.current_revision_asset_id,
            current_revision_version=snapshot.current_revision_version,
            latest_revision_asset_id=snapshot.latest_revision_asset_id,
            latest_revision_version=snapshot.latest_revision_version,
            revision_count=snapshot.revision_count,
            review_task_status=snapshot.review_task_status,
            pending_manual_confirm=snapshot.pending_manual_confirm,
            steps=[WorkflowStepRead.model_validate(step) for step in workflow_run.steps],
            created_at=workflow_run.created_at,
            updated_at=workflow_run.updated_at,
        )

    @staticmethod
    def _build_list_filters(
        *,
        query: str | None,
        status_filter: OrderStatus | None,
        order_id: int | None,
    ) -> list:
        """Build the WHERE clauses shared by the list and count queries."""
        filters = []
        if status_filter is not None:
            filters.append(WorkflowRunORM.status == status_filter)
        if order_id is not None:
            filters.append(WorkflowRunORM.order_id == order_id)

        search_term = query.strip() if query else ""
        if search_term:
            escaped = _escape_like(search_term)
            filters.append(
                or_(
                    # Order ids match by prefix, workflow ids by substring.
                    cast(WorkflowRunORM.order_id, String).ilike(f"{escaped}%", escape=_LIKE_ESCAPE),
                    WorkflowRunORM.workflow_id.ilike(f"%{escaped}%", escape=_LIKE_ESCAPE),
                )
            )
        return filters

    @staticmethod
    def _to_list_item(workflow_run, snapshot) -> WorkflowListItemResponse:
        """Combine a workflow run row and its revision snapshot into a list item."""
        return WorkflowListItemResponse(
            order_id=workflow_run.order_id,
            workflow_id=workflow_run.workflow_id,
            workflow_type=workflow_run.workflow_type,
            workflow_status=workflow_run.status,
            current_step=workflow_run.current_step,
            updated_at=workflow_run.updated_at,
            failure_count=sum(
                1 for step in workflow_run.steps if step.step_status.value == "failed"
            ),
            review_task_status=snapshot.review_task_status,
            latest_revision_asset_id=snapshot.latest_revision_asset_id,
            latest_revision_version=snapshot.latest_revision_version,
            revision_count=snapshot.revision_count,
            pending_manual_confirm=snapshot.pending_manual_confirm,
        )

    async def list_workflows(
        self,
        session,
        *,
        page: int = 1,
        limit: int = 20,
        query: str | None = None,
        status_filter: OrderStatus | None = None,
        order_id: int | None = None,
    ) -> WorkflowListResponse:
        """Return recent workflow runs for dashboard lookup pages.

        Results are ordered by ``updated_at`` descending, then ``id``
        descending, and paginated with ``page``/``limit``.

        Args:
            session: SQLAlchemy async session (``execute`` is awaited).
            page: 1-based page number.
            limit: Page size.
            query: Optional free-text search. Surrounding whitespace is
                ignored. It matches order ids by prefix and workflow ids by
                substring, case-insensitively; ``%`` and ``_`` are matched
                literally.
            status_filter: Only include runs whose status equals this value.
            order_id: Only include runs for this order.

        Raises:
            HTTPException: 400 if ``page`` or ``limit`` is less than 1.
                Otherwise ``limit=0`` would divide by zero and ``page<1``
                would produce a negative OFFSET.
        """
        if page < 1 or limit < 1:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="page and limit must be >= 1",
            )

        filters = self._build_list_filters(
            query=query,
            status_filter=status_filter,
            order_id=order_id,
        )

        list_stmt = select(WorkflowRunORM).options(selectinload(WorkflowRunORM.steps))
        count_stmt = select(func.count()).select_from(WorkflowRunORM)
        if filters:
            list_stmt = list_stmt.where(*filters)
            count_stmt = count_stmt.where(*filters)

        total = (await session.execute(count_stmt)).scalar_one()
        total_pages = ceil(total / limit) if total else 0

        list_stmt = (
            list_stmt.order_by(WorkflowRunORM.updated_at.desc(), WorkflowRunORM.id.desc())
            .offset((page - 1) * limit)
            .limit(limit)
        )
        workflow_runs = (await session.execute(list_stmt)).scalars().all()

        items = []
        for workflow_run in workflow_runs:
            # NOTE(review): one snapshot lookup per row — presumably one query
            # each (N+1); consider a batched RevisionService call if pages grow.
            snapshot = await self.revision_service.get_revision_snapshot(
                session,
                workflow_run.order_id,
            )
            items.append(self._to_list_item(workflow_run, snapshot))

        return WorkflowListResponse(
            page=page,
            limit=limit,
            total=total,
            total_pages=total_pages,
            items=items,
        )