229 lines
8.0 KiB
Python
229 lines
8.0 KiB
Python
"""
|
||
生成阶段状态管理器
|
||
用于支持生成任务的暂停、停止功能
|
||
使用轮询检查机制,确保线程安全
|
||
支持多用户/多generation_id并行管理
|
||
"""
|
||
|
||
import threading
|
||
import asyncio
|
||
import time
|
||
from typing import Optional, Dict
|
||
from enum import Enum
|
||
|
||
|
||
class GenerationStatus(Enum):
|
||
"""生成状态枚举"""
|
||
GENERATING = "generating" # 正在生成
|
||
PAUSED = "paused" # 已暂停
|
||
STOPPED = "stopped" # 已停止
|
||
COMPLETED = "completed" # 已完成
|
||
IDLE = "idle" # 空闲
|
||
|
||
|
||
class GenerationStateManager:
|
||
"""
|
||
生成阶段状态管理器
|
||
|
||
功能:
|
||
- 管理多用户/多generation_id的并行状态(使用字典存储)
|
||
- 管理生成任务状态(生成中/暂停/停止/完成)
|
||
- 提供线程安全的状态查询和修改接口
|
||
|
||
设计说明:
|
||
- 保持单例模式(Manager本身)
|
||
- 但内部状态按 generation_id 隔离存储
|
||
- 解决多用户并发生成时的干扰问题
|
||
"""
|
||
|
||
_instance: Optional['GenerationStateManager'] = None
|
||
_lock = threading.Lock()
|
||
|
||
def __new__(cls):
|
||
"""单例模式"""
|
||
if cls._instance is None:
|
||
with cls._lock:
|
||
if cls._instance is None:
|
||
cls._instance = super().__new__(cls)
|
||
cls._instance._initialized = False
|
||
return cls._instance
|
||
|
||
def __init__(self):
|
||
"""初始化状态管理器"""
|
||
if self._initialized:
|
||
return
|
||
|
||
self._initialized = True
|
||
|
||
# 状态存储:generation_id -> 状态字典
|
||
# 结构:{
|
||
# 'status': GenerationStatus,
|
||
# 'goal': str,
|
||
# 'should_stop': bool
|
||
# }
|
||
self._states: Dict[str, Dict] = {}
|
||
|
||
# 每个 generation_id 的锁(更细粒度的锁)
|
||
self._locks: Dict[str, threading.Lock] = {}
|
||
|
||
# 全局锁(用于管理 _states 和 _locks 本身的线程安全)
|
||
self._manager_lock = threading.Lock()
|
||
|
||
def _get_lock(self, generation_id: str) -> threading.Lock:
|
||
"""获取指定 generation_id 的锁,如果不存在则创建"""
|
||
with self._manager_lock:
|
||
if generation_id not in self._locks:
|
||
self._locks[generation_id] = threading.Lock()
|
||
return self._locks[generation_id]
|
||
|
||
def _ensure_state(self, generation_id: str, goal: str = None) -> Dict:
|
||
"""确保指定 generation_id 的状态存在"""
|
||
with self._manager_lock:
|
||
if generation_id not in self._states:
|
||
self._states[generation_id] = {
|
||
'status': GenerationStatus.IDLE,
|
||
'goal': goal,
|
||
'should_stop': False
|
||
}
|
||
return self._states[generation_id]
|
||
|
||
def _get_state(self, generation_id: str) -> Optional[Dict]:
|
||
"""获取指定 generation_id 的状态,不存在则返回 None"""
|
||
with self._manager_lock:
|
||
return self._states.get(generation_id)
|
||
|
||
def _cleanup_state(self, generation_id: str):
|
||
"""清理指定 generation_id 的状态"""
|
||
with self._manager_lock:
|
||
self._states.pop(generation_id, None)
|
||
self._locks.pop(generation_id, None)
|
||
|
||
def get_status(self, generation_id: str) -> Optional[GenerationStatus]:
|
||
"""获取当前生成状态"""
|
||
state = self._get_state(generation_id)
|
||
if state is None:
|
||
return None
|
||
with self._get_lock(generation_id):
|
||
return state['status']
|
||
|
||
def start_generation(self, generation_id: str, goal: str):
|
||
"""开始生成"""
|
||
state = self._ensure_state(generation_id, goal)
|
||
with self._get_lock(generation_id):
|
||
state['status'] = GenerationStatus.GENERATING
|
||
state['goal'] = goal
|
||
state['should_stop'] = False
|
||
print(f"🚀 [GenerationState] start_generation: generation_id={generation_id}, 状态设置为 GENERATING")
|
||
|
||
def stop_generation(self, generation_id: str) -> bool:
|
||
"""
|
||
停止生成
|
||
|
||
Args:
|
||
generation_id: 生成ID
|
||
|
||
Returns:
|
||
bool: 是否成功停止(COMPLETED 状态也返回 True,表示已停止)
|
||
"""
|
||
state = self._get_state(generation_id)
|
||
if state is None:
|
||
print(f"⚠️ [GenerationState] stop_generation: generation_id={generation_id} 不存在")
|
||
return True # 不存在也算停止成功
|
||
|
||
with self._get_lock(generation_id):
|
||
if state['status'] == GenerationStatus.STOPPED:
|
||
print(f"✅ [GenerationState] stop_generation: generation_id={generation_id} 已经是 STOPPED 状态")
|
||
return True # 已经停止也算成功
|
||
|
||
if state['status'] == GenerationStatus.COMPLETED:
|
||
print(f"✅ [GenerationState] stop_generation: generation_id={generation_id} 已经 COMPLETED,视为停止成功")
|
||
return True # 已完成也视为停止成功
|
||
|
||
if state['status'] == GenerationStatus.IDLE:
|
||
print(f"⚠️ [GenerationState] stop_generation: generation_id={generation_id} 是 IDLE 状态,无需停止")
|
||
return True # 空闲状态也视为无需停止
|
||
|
||
# 真正需要停止的情况
|
||
state['status'] = GenerationStatus.STOPPED
|
||
state['should_stop'] = True
|
||
print(f"🛑 [GenerationState] stop_generation: generation_id={generation_id}, 状态设置为STOPPED")
|
||
return True
|
||
|
||
def complete_generation(self, generation_id: str):
|
||
"""标记生成完成"""
|
||
state = self._ensure_state(generation_id)
|
||
with self._get_lock(generation_id):
|
||
state['status'] = GenerationStatus.COMPLETED
|
||
print(f"✅ [GenerationState] complete_generation: generation_id={generation_id}")
|
||
|
||
def cleanup(self, generation_id: str):
|
||
"""清理指定 generation_id 的所有状态"""
|
||
self._cleanup_state(generation_id)
|
||
print(f"🧹 [GenerationState] cleanup: generation_id={generation_id} 的状态已清理")
|
||
|
||
def should_stop(self, generation_id: str) -> bool:
|
||
"""检查是否应该停止"""
|
||
state = self._get_state(generation_id)
|
||
if state is None:
|
||
return False
|
||
with self._get_lock(generation_id):
|
||
return state.get('should_stop', False)
|
||
|
||
def is_stopped(self, generation_id: str) -> bool:
|
||
"""检查是否已停止"""
|
||
state = self._get_state(generation_id)
|
||
if state is None:
|
||
return False
|
||
with self._get_lock(generation_id):
|
||
return state['status'] == GenerationStatus.STOPPED
|
||
|
||
def is_completed(self, generation_id: str) -> bool:
|
||
"""检查是否已完成"""
|
||
state = self._get_state(generation_id)
|
||
if state is None:
|
||
return False
|
||
with self._get_lock(generation_id):
|
||
return state['status'] == GenerationStatus.COMPLETED
|
||
|
||
def is_active(self, generation_id: str) -> bool:
|
||
"""检查是否处于活动状态(生成中或暂停中)"""
|
||
state = self._get_state(generation_id)
|
||
if state is None:
|
||
return False
|
||
with self._get_lock(generation_id):
|
||
return state['status'] == GenerationStatus.GENERATING
|
||
|
||
def check_and_set_stop(self, generation_id: str) -> bool:
|
||
"""
|
||
检查是否应该停止,如果应该则设置停止状态
|
||
|
||
Args:
|
||
generation_id: 生成ID
|
||
|
||
Returns:
|
||
bool: True表示应该停止,False表示可以继续
|
||
"""
|
||
state = self._get_state(generation_id)
|
||
if state is None:
|
||
return False
|
||
with self._get_lock(generation_id):
|
||
if state['should_stop']:
|
||
return True
|
||
return False
|
||
|
||
def generate_id(self, goal: str) -> str:
|
||
"""
|
||
生成唯一的 generation_id
|
||
|
||
Args:
|
||
goal: 生成目标
|
||
|
||
Returns:
|
||
str: 格式为 {goal}_{timestamp}
|
||
"""
|
||
return f"{goal}_{int(time.time() * 1000)}"
|
||
|
||
|
||
# 全局单例实例
|
||
generation_state_manager = GenerationStateManager()
|