193 lines
7.4 KiB
Python
193 lines
7.4 KiB
Python
"""Temporal 工作流服务层。

这一层位于 API 和 Temporal 之间，负责：

1. 选择该启动哪个 workflow
2. 发送 signal
3. 查询已持久化的 workflow 状态
"""
|
|
|
|
from datetime import timedelta
|
|
from math import ceil
|
|
|
|
from fastapi import HTTPException, status
|
|
from sqlalchemy import String, cast, func, or_, select
|
|
from sqlalchemy.orm import selectinload
|
|
|
|
from app.api.schemas.workflow import (
|
|
WorkflowListItemResponse,
|
|
WorkflowListResponse,
|
|
WorkflowStatusResponse,
|
|
WorkflowStepRead,
|
|
)
|
|
from app.domain.enums import OrderStatus
|
|
from app.application.services.revision_service import RevisionService
|
|
from app.domain.enums import ServiceMode
|
|
from app.infra.db.models.workflow_run import WorkflowRunORM
|
|
from app.infra.temporal.client import get_temporal_client
|
|
from app.infra.temporal.task_queues import IMAGE_PIPELINE_CONTROL_TASK_QUEUE
|
|
from app.workers.workflows.low_end_pipeline import LowEndPipelineWorkflow
|
|
from app.workers.workflows.mid_end_pipeline import MidEndPipelineWorkflow
|
|
from app.workers.workflows.types import PipelineWorkflowInput, ReviewSignalPayload
|
|
|
|
|
|
class WorkflowService:
    """Temporal orchestration service.

    Sits between the API layer and Temporal: picks which workflow to start for
    an order, sends signals to running workflows, and reads the workflow state
    that is mirrored into our own database.
    """

    # Character used to escape LIKE wildcards in user-supplied search terms.
    _LIKE_ESCAPE = "/"

    def __init__(self) -> None:
        self.revision_service = RevisionService()

    @staticmethod
    def _workflow_class_for_mode(service_mode: ServiceMode) -> type:
        """Return the workflow class that handles ``service_mode``.

        ``ServiceMode.AUTO_BASIC`` maps to ``LowEndPipelineWorkflow``; every
        other mode falls through to ``MidEndPipelineWorkflow``.
        """
        if service_mode == ServiceMode.AUTO_BASIC:
            return LowEndPipelineWorkflow
        return MidEndPipelineWorkflow

    @staticmethod
    def workflow_type_for_mode(service_mode: ServiceMode) -> str:
        """Return the workflow type name (the workflow class name) for a service mode."""
        return WorkflowService._workflow_class_for_mode(service_mode).__name__

    @staticmethod
    def _escape_like(term: str) -> str:
        """Escape LIKE/ILIKE wildcards in ``term`` so it is matched literally.

        The escape character itself is escaped first so it cannot combine
        with the characters that follow it. Use together with
        ``escape=WorkflowService._LIKE_ESCAPE``.
        """
        esc = WorkflowService._LIKE_ESCAPE
        return (
            term.replace(esc, esc + esc)
            .replace("%", esc + "%")
            .replace("_", esc + "_")
        )

    async def start_workflow(self, workflow_input: PipelineWorkflowInput) -> None:
        """Start the Temporal workflow that matches the order's service mode.

        This only kicks off execution; the actual pipeline step order is still
        defined inside the workflow classes.

        Args:
            workflow_input: Passed to the workflow's ``run`` method. Its
                ``service_mode`` selects the workflow class and its
                ``order_id`` determines the workflow id.
        """
        client = await get_temporal_client()
        workflow_cls = self._workflow_class_for_mode(workflow_input.service_mode)
        await client.start_workflow(
            workflow_cls.run,
            workflow_input,
            # The workflow id is pinned to order-{order_id} so the API can look
            # the workflow up by order later.
            id=f"order-{workflow_input.order_id}",
            task_queue=IMAGE_PIPELINE_CONTROL_TASK_QUEUE,
            run_timeout=timedelta(minutes=30),
            task_timeout=timedelta(seconds=30),
        )

    async def signal_review(self, workflow_id: str, payload: ReviewSignalPayload) -> None:
        """Send a review signal to a running workflow.

        Args:
            workflow_id: Id of the target workflow (e.g. ``order-{order_id}``).
            payload: Review data delivered as the signal argument.
        """
        client = await get_temporal_client()
        handle = client.get_workflow_handle(workflow_id=workflow_id)
        # "submit_review" must match the name of the method marked with
        # @workflow.signal in the workflow class.
        await handle.signal("submit_review", payload)

    async def get_workflow_status(self, session, order_id: int) -> WorkflowStatusResponse:
        """Return the persisted workflow status for an order.

        This reads our own database mirror of the workflow state rather than
        querying Temporal history directly, which is better suited to a
        business-facing API.

        Args:
            session: Async database session (it is awaited for ``execute``).
            order_id: Order whose workflow run should be returned.

        Returns:
            The workflow run, its steps and the order's revision snapshot.

        Raises:
            HTTPException: 404 if no workflow run exists for ``order_id``.
        """
        result = await session.execute(
            select(WorkflowRunORM)
            .where(WorkflowRunORM.order_id == order_id)
            .options(selectinload(WorkflowRunORM.steps))
        )
        # NOTE(review): scalar_one_or_none raises if several runs share this
        # order_id; this assumes one run row per order — confirm against schema.
        workflow_run = result.scalar_one_or_none()
        if workflow_run is None:
            raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Workflow not found")

        snapshot = await self.revision_service.get_revision_snapshot(session, order_id)

        return WorkflowStatusResponse(
            order_id=workflow_run.order_id,
            workflow_id=workflow_run.workflow_id,
            workflow_type=workflow_run.workflow_type,
            workflow_status=workflow_run.status,
            current_step=workflow_run.current_step,
            current_revision_asset_id=snapshot.current_revision_asset_id,
            current_revision_version=snapshot.current_revision_version,
            latest_revision_asset_id=snapshot.latest_revision_asset_id,
            latest_revision_version=snapshot.latest_revision_version,
            revision_count=snapshot.revision_count,
            review_task_status=snapshot.review_task_status,
            pending_manual_confirm=snapshot.pending_manual_confirm,
            steps=[WorkflowStepRead.model_validate(step) for step in workflow_run.steps],
            created_at=workflow_run.created_at,
            updated_at=workflow_run.updated_at,
        )

    async def list_workflows(
        self,
        session,
        *,
        page: int = 1,
        limit: int = 20,
        query: str | None = None,
        status_filter: OrderStatus | None = None,
        order_id: int | None = None,
    ) -> WorkflowListResponse:
        """Return recent workflow runs for dashboard lookup pages.

        Runs are ordered by ``updated_at`` descending, with ``id`` descending
        as a tie-breaker, and paginated.

        Args:
            session: Async database session (it is awaited for ``execute``).
            page: 1-based page number.
            limit: Page size.
            query: Optional free-text search. It is matched as a prefix of the
                order id and as a substring of the workflow id, case-insensitively
                and with wildcards taken literally.
            status_filter: Optional exact match on the workflow run status.
            order_id: Optional exact match on the order id.

        Returns:
            A page of list items plus total count and total page count.

        Raises:
            HTTPException: 400 if ``page`` or ``limit`` is less than 1.
        """
        # limit=0 would divide by zero when computing total_pages, and page<1
        # would produce a negative OFFSET; reject both up front.
        if page < 1 or limit < 1:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="page and limit must be >= 1",
            )

        filters = []
        if status_filter is not None:
            filters.append(WorkflowRunORM.status == status_filter)
        if order_id is not None:
            filters.append(WorkflowRunORM.order_id == order_id)

        search_term = (query or "").strip()
        if search_term:
            # Escape % and _ so user input is not interpreted as LIKE wildcards.
            pattern = self._escape_like(search_term)
            filters.append(
                or_(
                    cast(WorkflowRunORM.order_id, String).ilike(
                        f"{pattern}%", escape=self._LIKE_ESCAPE
                    ),
                    WorkflowRunORM.workflow_id.ilike(
                        f"%{pattern}%", escape=self._LIKE_ESCAPE
                    ),
                )
            )

        stmt = select(WorkflowRunORM).options(selectinload(WorkflowRunORM.steps))
        count_stmt = select(func.count()).select_from(WorkflowRunORM)
        if filters:
            stmt = stmt.where(*filters)
            count_stmt = count_stmt.where(*filters)

        total = (await session.execute(count_stmt)).scalar_one()
        total_pages = ceil(total / limit) if total else 0
        offset = (page - 1) * limit

        stmt = (
            stmt.order_by(WorkflowRunORM.updated_at.desc(), WorkflowRunORM.id.desc())
            .offset(offset)
            .limit(limit)
        )
        workflow_runs = (await session.execute(stmt)).scalars().all()

        items = []
        for workflow_run in workflow_runs:
            # NOTE(review): one revision-snapshot lookup per row (N+1); consider a
            # batched RevisionService API if page sizes grow.
            snapshot = await self.revision_service.get_revision_snapshot(
                session,
                workflow_run.order_id,
            )
            items.append(
                WorkflowListItemResponse(
                    order_id=workflow_run.order_id,
                    workflow_id=workflow_run.workflow_id,
                    workflow_type=workflow_run.workflow_type,
                    workflow_status=workflow_run.status,
                    current_step=workflow_run.current_step,
                    updated_at=workflow_run.updated_at,
                    failure_count=sum(
                        1 for step in workflow_run.steps if step.step_status.value == "failed"
                    ),
                    review_task_status=snapshot.review_task_status,
                    latest_revision_asset_id=snapshot.latest_revision_asset_id,
                    latest_revision_version=snapshot.latest_revision_version,
                    revision_count=snapshot.revision_count,
                    pending_manual_confirm=snapshot.pending_manual_confirm,
                )
            )

        return WorkflowListResponse(
            page=page,
            limit=limit,
            total=total,
            total_pages=total_pages,
            items=items,
        )
|