mirror of
https://github.com/zhayujie/chatgpt-on-wechat.git
synced 2026-06-02 00:57:41 +08:00
122 lines
4.1 KiB
Python
122 lines
4.1 KiB
Python
"""
|
|
Cancel token registry for aborting in-flight agent runs.
|
|
|
|
A user cancel (web Cancel button, /cancel command) sets a threading.Event
|
|
that the agent loop polls at safe checkpoints. Tokens are keyed by
|
|
request_id (preferred) and tracked under session_id as a fallback. Entries
|
|
are released after the run completes to keep the registry bounded.
|
|
|
|
No project deps — importable from any layer without circular imports.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import threading
|
|
from typing import Dict, Optional
|
|
|
|
|
|
class AgentCancelledError(Exception):
    """Signal, raised from within the agent loop, that a stop was requested.

    Handling contract (per the original module notes): the agent stream
    executor catches this exception, adds a "[Interrupted]" note to the
    message history while keeping tool_use/tool_result pairs intact, and
    hands a partial response back to the caller.
    """
|
|
|
|
|
|
class _CancelEntry:
    """Per-request cancellation record stored by ``CancelTokenRegistry``."""

    # event:      threading.Event that is set once cancellation is requested.
    # session_id: session the request is indexed under, or None when unknown.
    __slots__ = ("event", "session_id")

    def __init__(self, session_id: Optional[str]):
        self.event = threading.Event()
        self.session_id = session_id


class CancelTokenRegistry:
    """In-process registry mapping request_id -> cancel Event.

    Thread-safe: every read and write of the two internal maps happens while
    holding ``self._lock``. Singleton via module-level ``_registry``.

    Invariant: a request_id appears in ``_by_session[s]`` only while its
    entry exists in ``_by_request`` and that entry records ``session_id == s``.
    This is what lets ``unregister`` clean up the session index completely.
    """

    def __init__(self):
        self._lock = threading.Lock()
        self._by_request: Dict[str, _CancelEntry] = {}
        # session_id -> set of request_ids currently in flight (usually 1).
        self._by_session: Dict[str, set] = {}

    def register(self, request_id: str, session_id: Optional[str] = None) -> threading.Event:
        """Create (or return existing) cancel event for a request.

        Args:
            request_id: Key for the run. When falsy, nothing is recorded and
                a fresh, untracked event is returned.
            session_id: Optional session to index the request under so that
                ``cancel_session`` / ``has_active`` can find it.

        Returns:
            The threading.Event the caller should poll via ``is_set()``.

        Re-registering a known request_id returns the same event. If the
        existing entry has no session yet, it adopts ``session_id``; an entry
        that already has a session keeps it, so the session index never
        drifts away from what ``unregister`` will later clean up.
        """
        if not request_id:
            # No key to track under: this event is unreachable by any cancel.
            return threading.Event()
        with self._lock:
            entry = self._by_request.get(request_id)
            if entry is None:
                entry = _CancelEntry(session_id)
                self._by_request[request_id] = entry
            elif session_id and not entry.session_id:
                # Late-arriving session for a request registered without one.
                entry.session_id = session_id
            # Always index under the entry's own session so both maps agree.
            if entry.session_id:
                self._by_session.setdefault(entry.session_id, set()).add(request_id)
            return entry.event

    def get_event(self, request_id: str) -> Optional[threading.Event]:
        """Return the cancel event for ``request_id``, or None if not registered."""
        if not request_id:
            return None
        with self._lock:
            entry = self._by_request.get(request_id)
            return entry.event if entry else None

    def cancel_request(self, request_id: str) -> bool:
        """Trigger cancel for a specific request. Returns True when matched."""
        if not request_id:
            return False
        with self._lock:
            entry = self._by_request.get(request_id)
            if entry is None:
                return False
            entry.event.set()
            return True

    def cancel_session(self, session_id: str) -> int:
        """Trigger cancel for every in-flight request of a session.

        Returns the number of requests cancelled (0 when nothing was running).
        """
        if not session_id:
            return 0
        with self._lock:
            # Snapshot the matching entries; the bucket may change after release.
            request_ids = list(self._by_session.get(session_id, ()))
            entries = [self._by_request[r] for r in request_ids if r in self._by_request]
        for entry in entries:
            entry.event.set()
        return len(entries)

    def unregister(self, request_id: str) -> None:
        """Remove an entry once the agent run is done. Safe to call twice."""
        if not request_id:
            return
        with self._lock:
            entry = self._by_request.pop(request_id, None)
            if entry and entry.session_id:
                bucket = self._by_session.get(entry.session_id)
                if bucket is not None:
                    bucket.discard(request_id)
                    if not bucket:
                        # Drop empty buckets so the session map stays bounded.
                        self._by_session.pop(entry.session_id, None)

    def has_active(self, session_id: str) -> bool:
        """Return True when at least one request is registered under ``session_id``."""
        if not session_id:
            return False
        with self._lock:
            bucket = self._by_session.get(session_id)
            return bool(bucket)
|
|
|
|
|
|
# Single shared registry instance, built at import time; get_cancel_registry() returns it.
_registry = CancelTokenRegistry()
|
|
|
|
|
|
def get_cancel_registry() -> CancelTokenRegistry:
    """Return the shared ``CancelTokenRegistry`` held in module-level ``_registry``."""
    return _registry
|