46 lines
1.1 KiB
Python
46 lines
1.1 KiB
Python
"""Temporal client 辅助函数。
|
||
|
||
API 和 worker 都需要连接 Temporal Server。
|
||
这里做了一个简单的单例缓存,避免重复建立连接。
|
||
"""
|
||
|
||
import asyncio
|
||
|
||
from temporalio.client import Client
|
||
|
||
from app.config.settings import get_settings
|
||
|
||
_client: Client | None = None
|
||
_client_lock = asyncio.Lock()
|
||
|
||
|
||
async def get_temporal_client() -> Client:
|
||
"""返回缓存后的 Temporal Client。
|
||
|
||
第一次调用时才真正连接 Temporal;后续复用同一个 client。
|
||
"""
|
||
|
||
global _client
|
||
if _client is not None:
|
||
return _client
|
||
|
||
async with _client_lock:
|
||
# 双重检查,避免并发场景下重复 connect。
|
||
if _client is None:
|
||
settings = get_settings()
|
||
_client = await Client.connect(
|
||
settings.temporal_address,
|
||
namespace=settings.temporal_namespace,
|
||
)
|
||
return _client
|
||
|
||
|
||
def set_temporal_client(client: Client | None) -> None:
|
||
"""覆盖缓存的 Temporal Client。
|
||
|
||
主要用于测试场景,把真实连接替换成 Temporal 测试环境里的 client。
|
||
"""
|
||
|
||
global _client
|
||
_client = client
|