Files
auto-virtual-tryon/app/workers/workflows/mid_end_pipeline.py

361 lines
16 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""中端图片流水线工作流。
和低端流程相比,这里最大的区别是:
1. 会在 QC 之后停在 waiting_review
2. 通过 Temporal signal 接收人工审核结果
3. 可以按审核意见回流到 scene / face / fusion 重新跑
"""
from datetime import timedelta
from temporalio import workflow
from temporalio.common import RetryPolicy
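# These imports belong to the world outside the workflow, so they are explicitly marked as
# pass-through to keep them from being treated as workflow logic that has to be replayed.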
with workflow.unsafe.imports_passed_through():
from app.domain.enums import OrderStatus, ReviewDecision, WorkflowStepName
from app.infra.temporal.task_queues import (
IMAGE_PIPELINE_CONTROL_TASK_QUEUE,
IMAGE_PIPELINE_EXPORT_TASK_QUEUE,
IMAGE_PIPELINE_IMAGE_GEN_TASK_QUEUE,
IMAGE_PIPELINE_POST_PROCESS_TASK_QUEUE,
IMAGE_PIPELINE_QC_TASK_QUEUE,
)
from app.workers.activities.export_activities import run_export_activity
from app.workers.activities.face_activities import run_face_activity
from app.workers.activities.fusion_activities import run_fusion_activity
from app.workers.activities.qc_activities import run_qc_activity
from app.workers.activities.review_activities import (
complete_review_wait_activity,
mark_waiting_for_review_activity,
mark_workflow_failed_activity,
)
from app.workers.activities.scene_activities import run_scene_activity
from app.workers.activities.texture_activities import run_texture_activity
from app.workers.activities.tryon_activities import prepare_model_activity, run_tryon_activity
from app.workers.workflows.types import (
MockActivityResult,
PipelineWorkflowInput,
ReviewResolutionActivityInput,
ReviewSignalPayload,
ReviewWaitActivityInput,
StepActivityInput,
WorkflowFailureActivityInput,
)
# Start-to-close budget passed to every activity call in this module.
ACTIVITY_TIMEOUT = timedelta(seconds=30)

# Retry policy passed to every activity call in this module: at most three
# attempts, with the first retry delay at one second and a backoff factor of 2.
ACTIVITY_RETRY_POLICY = RetryPolicy(
    maximum_attempts=3,
    initial_interval=timedelta(seconds=1),
    backoff_coefficient=2.0,
)
@workflow.defn
class MidEndPipelineWorkflow:
    """Semi-automatic mid-tier image pipeline workflow.

    The workflow generates an image automatically, parks until a human review
    arrives through the ``submit_review`` signal, and then exports, stops, or
    re-runs part of the chain depending on the reviewer's decision.
    """

    def __init__(self) -> None:
        # Most recent review signal that the main coroutine has not consumed
        # yet. The signal handler fills this slot; `_wait_for_review` drains it.
        self._review_payload: ReviewSignalPayload | None = None

    @workflow.signal
    def submit_review(self, payload: ReviewSignalPayload) -> None:
        """Record a review decision delivered by signal.

        Nothing is executed here: the payload is only stored on the instance.
        The main flow resumes from `_wait_for_review`.

        Args:
            payload: The reviewer's decision and its accompanying fields.
        """
        # NOTE(review): only one pending payload is kept. A signal arriving
        # before the previous one was consumed replaces it, and a signal
        # arriving while stages are being re-run is consumed by the next wait
        # without blocking. Confirm both behaviours are acceptable.
        self._review_payload = payload

    @workflow.run
    async def run(self, payload: PipelineWorkflowInput) -> dict[str, int | str | None]:
        """Drive one order through the mid-tier pipeline.

        Stage order::

            prepare_model -> tryon -> scene -> texture -> face -> fusion -> qc
            -> review -> export (approve) | stop (reject) | partial rerun

        After a rerun the result goes through QC again and, if it passes,
        returns to the review stage for another human decision.

        Args:
            payload: Identifiers of the order, the workflow run and the input
                assets.

        Returns:
            A dict with the keys ``order_id``, ``status`` (an ``OrderStatus``
            value) and ``final_asset_id`` (``None`` unless the export ran).

        Raises:
            Exception: Any exception raised inside the flow is re-raised after
                the failure has been recorded through `_mark_failed`.
        """
        # Stage currently in flight; reported when a failure is recorded.
        stage = WorkflowStepName.PREPARE_MODEL
        try:
            prepare_input = StepActivityInput(
                order_id=payload.order_id,
                workflow_run_id=payload.workflow_run_id,
                step_name=WorkflowStepName.PREPARE_MODEL,
                model_id=payload.model_id,
                pose_id=payload.pose_id,
                garment_asset_id=payload.garment_asset_id,
                scene_ref_asset_id=payload.scene_ref_asset_id,
            )
            model = await workflow.execute_activity(
                prepare_model_activity,
                prepare_input,
                task_queue=IMAGE_PIPELINE_CONTROL_TASK_QUEUE,
                start_to_close_timeout=ACTIVITY_TIMEOUT,
                retry_policy=ACTIVITY_RETRY_POLICY,
            )

            stage = WorkflowStepName.TRYON
            tryon_input = StepActivityInput(
                order_id=payload.order_id,
                workflow_run_id=payload.workflow_run_id,
                step_name=WorkflowStepName.TRYON,
                source_asset_id=model.asset_id,
                garment_asset_id=payload.garment_asset_id,
            )
            tryon = await workflow.execute_activity(
                run_tryon_activity,
                tryon_input,
                task_queue=IMAGE_PIPELINE_IMAGE_GEN_TASK_QUEUE,
                start_to_close_timeout=ACTIVITY_TIMEOUT,
                retry_policy=ACTIVITY_RETRY_POLICY,
            )

            stage = WorkflowStepName.SCENE
            scene = await self._run_scene(payload, tryon.asset_id)

            stage = WorkflowStepName.TEXTURE
            texture = await self._run_texture(payload, scene.asset_id)

            stage = WorkflowStepName.FACE
            face = await self._run_face(payload, texture.asset_id)

            stage = WorkflowStepName.FUSION
            fusion = await self._run_fusion(payload, scene.asset_id, face.asset_id)

            stage = WorkflowStepName.QC
            qc = await self._run_qc(payload, fusion.asset_id)
            if not qc.passed:
                await self._mark_failed(payload, stage, qc.message)
                return self._failed_outcome(payload)

            # The loop ends in one of three ways:
            #   1. the reviewer approves and the export completes,
            #   2. the reviewer rejects,
            #   3. QC fails after a rerun.
            # Otherwise a rerun leads back to another review round.
            while True:
                stage = WorkflowStepName.REVIEW
                review = await self._collect_review(payload, qc)
                decision = review.decision

                if decision == ReviewDecision.APPROVE:
                    stage = WorkflowStepName.EXPORT
                    # Export the asset the reviewer picked explicitly;
                    # otherwise take the first QC candidate, falling back to
                    # the fusion output when QC listed none.
                    chosen_id = review.selected_asset_id
                    if chosen_id is None:
                        fallback_ids = qc.candidate_asset_ids or [fusion.asset_id]
                        chosen_id = fallback_ids[0]
                    export_input = StepActivityInput(
                        order_id=payload.order_id,
                        workflow_run_id=payload.workflow_run_id,
                        step_name=WorkflowStepName.EXPORT,
                        source_asset_id=chosen_id,
                    )
                    exported = await workflow.execute_activity(
                        run_export_activity,
                        export_input,
                        task_queue=IMAGE_PIPELINE_EXPORT_TASK_QUEUE,
                        start_to_close_timeout=ACTIVITY_TIMEOUT,
                        retry_policy=ACTIVITY_RETRY_POLICY,
                    )
                    return {
                        "order_id": payload.order_id,
                        "status": OrderStatus.SUCCEEDED.value,
                        "final_asset_id": exported.asset_id,
                    }

                if decision == ReviewDecision.REJECT:
                    # A rejection ends the workflow without any rerun.
                    return self._failed_outcome(payload)

                # A rerun regenerates the requested stage and everything after
                # it. Each flag implies the ones below it, so a scene rerun
                # also redoes texture, face and fusion, and a face rerun also
                # redoes fusion.
                redo_scene = decision == ReviewDecision.RERUN_SCENE
                redo_face = redo_scene or decision == ReviewDecision.RERUN_FACE
                redo_fusion = redo_face or decision == ReviewDecision.RERUN_FUSION

                if redo_scene:
                    stage = WorkflowStepName.SCENE
                    scene = await self._run_scene(payload, tryon.asset_id)
                    stage = WorkflowStepName.TEXTURE
                    texture = await self._run_texture(payload, scene.asset_id)
                if redo_face:
                    stage = WorkflowStepName.FACE
                    face = await self._run_face(payload, texture.asset_id)
                if redo_fusion:
                    stage = WorkflowStepName.FUSION
                    fusion = await self._run_fusion(payload, scene.asset_id, face.asset_id)

                # Whatever was regenerated has to pass QC before the next
                # review round.
                stage = WorkflowStepName.QC
                qc = await self._run_qc(payload, fusion.asset_id)
                if not qc.passed:
                    await self._mark_failed(payload, stage, qc.message)
                    return self._failed_outcome(payload)
        except Exception as error:
            await self._mark_failed(payload, stage, str(error))
            raise

    def _failed_outcome(self, payload: PipelineWorkflowInput) -> dict[str, int | str | None]:
        """Build the result dict returned when the order ends as failed."""
        return {"order_id": payload.order_id, "status": OrderStatus.FAILED.value, "final_asset_id": None}

    async def _collect_review(
        self,
        payload: PipelineWorkflowInput,
        qc_result: MockActivityResult,
    ) -> ReviewSignalPayload:
        """Run one review round and return the reviewer's payload.

        The round has three steps: announce that the order is waiting for
        review, block until a review signal is available, then record the
        decision that was received.

        Args:
            payload: Workflow input carrying the order and run identifiers.
            qc_result: QC output whose ``candidate_asset_ids`` are offered to
                the reviewer.

        Returns:
            The review payload that was consumed.
        """
        # Per the original author's notes, this activity moves the order to
        # waiting_review and creates the review task that the API lists;
        # that cannot be verified from this file.
        waiting_input = ReviewWaitActivityInput(
            order_id=payload.order_id,
            workflow_run_id=payload.workflow_run_id,
            candidate_asset_ids=qc_result.candidate_asset_ids,
        )
        await workflow.execute_activity(
            mark_waiting_for_review_activity,
            waiting_input,
            task_queue=IMAGE_PIPELINE_CONTROL_TASK_QUEUE,
            start_to_close_timeout=ACTIVITY_TIMEOUT,
            retry_policy=ACTIVITY_RETRY_POLICY,
        )

        # The workflow parks here until a review signal is available.
        review = await self._wait_for_review()

        # Close the waiting state of the review step before acting on the
        # decision.
        resolution_input = ReviewResolutionActivityInput(
            order_id=payload.order_id,
            workflow_run_id=payload.workflow_run_id,
            decision=review.decision,
            reviewer_id=review.reviewer_id,
            selected_asset_id=review.selected_asset_id,
            comment=review.comment,
        )
        await workflow.execute_activity(
            complete_review_wait_activity,
            resolution_input,
            task_queue=IMAGE_PIPELINE_CONTROL_TASK_QUEUE,
            start_to_close_timeout=ACTIVITY_TIMEOUT,
            retry_policy=ACTIVITY_RETRY_POLICY,
        )
        return review

    async def _wait_for_review(self) -> ReviewSignalPayload:
        """Block until a review payload is available, then consume it.

        Waiting goes through ``workflow.wait_condition`` rather than a
        sleep-and-poll loop. The stored payload is cleared before returning
        so that the next review round waits for a fresh signal.

        Returns:
            The review payload that was pending on the instance.
        """
        pending = self._review_payload
        if pending is None:
            await workflow.wait_condition(lambda: self._review_payload is not None)
            pending = self._review_payload
        assert pending is not None
        self._review_payload = None
        return pending

    async def _run_scene(self, payload: PipelineWorkflowInput, source_asset_id: int | None) -> MockActivityResult:
        """Run the scene activity on the image-generation task queue.

        Kept as a helper so the first pass and a scene rerun share one call.

        Args:
            payload: Workflow input; also supplies ``scene_ref_asset_id``.
            source_asset_id: Asset the scene stage works from.

        Returns:
            The result of the scene activity.
        """
        step_input = StepActivityInput(
            order_id=payload.order_id,
            workflow_run_id=payload.workflow_run_id,
            step_name=WorkflowStepName.SCENE,
            source_asset_id=source_asset_id,
            scene_ref_asset_id=payload.scene_ref_asset_id,
        )
        return await workflow.execute_activity(
            run_scene_activity,
            step_input,
            task_queue=IMAGE_PIPELINE_IMAGE_GEN_TASK_QUEUE,
            start_to_close_timeout=ACTIVITY_TIMEOUT,
            retry_policy=ACTIVITY_RETRY_POLICY,
        )

    async def _run_texture(self, payload: PipelineWorkflowInput, source_asset_id: int | None) -> MockActivityResult:
        """Run the texture activity on the post-process task queue.

        Args:
            payload: Workflow input carrying the order and run identifiers.
            source_asset_id: Asset the texture stage works from.

        Returns:
            The result of the texture activity.
        """
        step_input = StepActivityInput(
            order_id=payload.order_id,
            workflow_run_id=payload.workflow_run_id,
            step_name=WorkflowStepName.TEXTURE,
            source_asset_id=source_asset_id,
        )
        return await workflow.execute_activity(
            run_texture_activity,
            step_input,
            task_queue=IMAGE_PIPELINE_POST_PROCESS_TASK_QUEUE,
            start_to_close_timeout=ACTIVITY_TIMEOUT,
            retry_policy=ACTIVITY_RETRY_POLICY,
        )

    async def _run_face(self, payload: PipelineWorkflowInput, source_asset_id: int | None) -> MockActivityResult:
        """Run the face activity on the post-process task queue.

        Args:
            payload: Workflow input carrying the order and run identifiers.
            source_asset_id: Asset the face stage works from.

        Returns:
            The result of the face activity.
        """
        step_input = StepActivityInput(
            order_id=payload.order_id,
            workflow_run_id=payload.workflow_run_id,
            step_name=WorkflowStepName.FACE,
            source_asset_id=source_asset_id,
        )
        return await workflow.execute_activity(
            run_face_activity,
            step_input,
            task_queue=IMAGE_PIPELINE_POST_PROCESS_TASK_QUEUE,
            start_to_close_timeout=ACTIVITY_TIMEOUT,
            retry_policy=ACTIVITY_RETRY_POLICY,
        )

    async def _run_fusion(
        self,
        payload: PipelineWorkflowInput,
        source_asset_id: int | None,
        face_asset_id: int | None,
    ) -> MockActivityResult:
        """Run the fusion activity on the post-process task queue.

        Args:
            payload: Workflow input carrying the order and run identifiers.
            source_asset_id: Asset passed as the step's ``source_asset_id``.
            face_asset_id: Face-stage asset, passed to the activity through
                the ``selected_asset_id`` field of the step input.

        Returns:
            The result of the fusion activity.
        """
        step_input = StepActivityInput(
            order_id=payload.order_id,
            workflow_run_id=payload.workflow_run_id,
            step_name=WorkflowStepName.FUSION,
            source_asset_id=source_asset_id,
            selected_asset_id=face_asset_id,
        )
        return await workflow.execute_activity(
            run_fusion_activity,
            step_input,
            task_queue=IMAGE_PIPELINE_POST_PROCESS_TASK_QUEUE,
            start_to_close_timeout=ACTIVITY_TIMEOUT,
            retry_policy=ACTIVITY_RETRY_POLICY,
        )

    async def _run_qc(self, payload: PipelineWorkflowInput, source_asset_id: int | None) -> MockActivityResult:
        """Run the QC activity on the QC task queue.

        Args:
            payload: Workflow input carrying the order and run identifiers.
            source_asset_id: Asset to be checked.

        Returns:
            The result of the QC activity; `run` reads its ``passed``,
            ``message`` and ``candidate_asset_ids`` fields.
        """
        step_input = StepActivityInput(
            order_id=payload.order_id,
            workflow_run_id=payload.workflow_run_id,
            step_name=WorkflowStepName.QC,
            source_asset_id=source_asset_id,
        )
        return await workflow.execute_activity(
            run_qc_activity,
            step_input,
            task_queue=IMAGE_PIPELINE_QC_TASK_QUEUE,
            start_to_close_timeout=ACTIVITY_TIMEOUT,
            retry_policy=ACTIVITY_RETRY_POLICY,
        )

    async def _mark_failed(
        self,
        payload: PipelineWorkflowInput,
        current_step: WorkflowStepName,
        message: str,
    ) -> None:
        """Record the workflow failure through the failure activity.

        Args:
            payload: Workflow input carrying the order and run identifiers.
            current_step: Stage that was in flight when the failure occurred.
            message: Failure description to store.
        """
        failure_input = WorkflowFailureActivityInput(
            order_id=payload.order_id,
            workflow_run_id=payload.workflow_run_id,
            current_step=current_step,
            message=message,
        )
        await workflow.execute_activity(
            mark_workflow_failed_activity,
            failure_input,
            task_queue=IMAGE_PIPELINE_CONTROL_TASK_QUEUE,
            start_to_close_timeout=ACTIVITY_TIMEOUT,
            retry_policy=ACTIVITY_RETRY_POLICY,
        )