feat(agent): support user-initiated cancel for in-flight agent runs

This commit is contained in:
zhayujie
2026-05-26 23:36:09 +08:00
parent ad2db1a776
commit 8d67177a1b
19 changed files with 691 additions and 22 deletions

View File

@@ -7,6 +7,7 @@ import json
import time
from typing import List, Dict, Any, Optional, Callable, Tuple
from agent.protocol.cancel import AgentCancelledError
from agent.protocol.models import LLMRequest, LLMModel
from agent.protocol.message_utils import sanitize_claude_messages, compress_turn_to_text_only
from agent.tools.base_tool import BaseTool, ToolResult
@@ -64,7 +65,8 @@ class AgentStreamExecutor:
max_turns: int = 50,
on_event: Optional[Callable] = None,
messages: Optional[List[Dict]] = None,
max_context_turns: int = 30
max_context_turns: int = 30,
cancel_event=None,
):
"""
Initialize stream executor
@@ -78,6 +80,10 @@ class AgentStreamExecutor:
on_event: Event callback function
messages: Optional existing message history (for persistent conversations)
max_context_turns: Maximum number of conversation turns to keep in context
cancel_event: Optional threading.Event used to signal user cancel.
Checked at every safe point (turn boundary, before tool execution,
during LLM streaming). When set, raises AgentCancelledError which
run_stream catches to gracefully wind down.
"""
self.agent = agent
self.model = model
@@ -87,6 +93,7 @@ class AgentStreamExecutor:
self.max_turns = max_turns
self.on_event = on_event
self.max_context_turns = max_context_turns
self.cancel_event = cancel_event
# Message history - use provided messages or create new list
self.messages = messages if messages is not None else []
@@ -97,6 +104,73 @@ class AgentStreamExecutor:
# Track files to send (populated by read tool)
self.files_to_send = [] # List of file metadata dicts
def _check_cancelled(self) -> None:
"""Raise AgentCancelledError if the user requested cancellation.
Called at safe points (turn start, between tool calls, between LLM
chunks). Cheap to call: just an Event.is_set() probe.
"""
if self.cancel_event is not None and self.cancel_event.is_set():
raise AgentCancelledError("agent cancelled by user")
def _handle_cancelled(self, partial_response: str) -> None:
"""Wind down ``self.messages`` after a user-initiated cancel.
The messages list may be in any of these states when we get here:
(a) Last message is an assistant message containing tool_use
blocks but the matching tool_result has not been appended yet.
(b) Last message is an assistant text-only reply (cancel happened
right before the next turn started).
(c) Last message is a user tool_result message and we cancelled
between turns.
For (a) we MUST synthesise tool_result blocks, otherwise the next
request will fail Claude/OpenAI's strict pairing validation. For
(b)/(c) the state is already valid and we just append a small
cancellation note so the user/LLM both see the boundary clearly.
"""
try:
# Step 1: close any orphaned tool_use in the trailing assistant
# message by injecting matching tool_result blocks.
if self.messages and isinstance(self.messages[-1], dict) \
and self.messages[-1].get("role") == "assistant":
last = self.messages[-1]
content = last.get("content")
if isinstance(content, list):
pending_tool_use_ids = [
block.get("id")
for block in content
if isinstance(block, dict) and block.get("type") == "tool_use"
]
pending_tool_use_ids = [tid for tid in pending_tool_use_ids if tid]
if pending_tool_use_ids:
tool_result_blocks = [
{
"type": "tool_result",
"tool_use_id": tid,
"content": "Cancelled by user before this tool finished.",
"is_error": True,
}
for tid in pending_tool_use_ids
]
self.messages.append({
"role": "user",
"content": tool_result_blocks,
})
logger.info(
f"[Agent] Injected {len(tool_result_blocks)} cancellation "
f"tool_result blocks to keep message history valid"
)
# Step 2: append a stable "interrupted" marker so the LLM sees a
# clear stop boundary on the next turn.
self.messages.append({
"role": "assistant",
"content": [{"type": "text", "text": "_(Cancelled by user)_"}],
})
except Exception as e:
logger.warning(f"[Agent] _handle_cancelled cleanup failed: {e}")
def _emit_event(self, event_type: str, data: dict = None):
"""Emit event"""
if self.on_event:
@@ -270,8 +344,13 @@ class AgentStreamExecutor:
final_response = ""
turn = 0
cancelled = False
try:
while turn < self.max_turns:
# Check at the very top of every turn so a cancel arriving
# between turns short-circuits cleanly.
self._check_cancelled()
turn += 1
logger.info(f"[Agent] 第 {turn}")
self._emit_event("turn_start", {"turn": turn})
@@ -375,6 +454,8 @@ class AgentStreamExecutor:
try:
for tool_call in tool_calls:
# Honour cancel between tool invocations within the same turn
self._check_cancelled()
result = self._execute_tool(tool_call)
tool_results.append(result)
@@ -557,6 +638,15 @@ class AgentStreamExecutor:
self.messages.pop(prompt_insert_idx)
logger.debug("[Agent] Removed injected max-steps prompt from message history")
except AgentCancelledError:
# User-initiated stop: wind down message history cleanly so the
# next turn is unaffected; channels emit a "cancelled" UI event.
cancelled = True
logger.info(f"[Agent] 🛑 已被用户中止 (第 {turn} 轮)")
self._handle_cancelled(final_response)
if not final_response or not final_response.strip():
final_response = "_(Cancelled)_"
except Exception as e:
logger.error(f"❌ Agent执行错误: {e}")
self._emit_event("error", {"error": str(e)})
@@ -564,8 +654,11 @@ class AgentStreamExecutor:
finally:
final_response = final_response.strip() if final_response else final_response
logger.info(f"[Agent] 🏁 完成 ({turn}轮)")
self._emit_event("agent_end", {"final_response": final_response})
if cancelled:
# Emit before agent_end so channels can mark UI as cancelled
self._emit_event("agent_cancelled", {"final_response": final_response})
logger.info(f"[Agent] 🏁 完成 ({turn}轮)" + (" [cancelled]" if cancelled else ""))
self._emit_event("agent_end", {"final_response": final_response, "cancelled": cancelled})
return final_response
@@ -644,7 +737,32 @@ class AgentStreamExecutor:
try:
stream = self.model.call_stream(request)
# Probe cancel every N chunks to bound reaction time without
# checking on every token.
_cancel_probe_counter = 0
_CANCEL_PROBE_EVERY = 8
for chunk in stream:
_cancel_probe_counter += 1
if _cancel_probe_counter >= _CANCEL_PROBE_EVERY:
_cancel_probe_counter = 0
if self.cancel_event is not None and self.cancel_event.is_set():
# Persist partial text only; tool_use args may be
# truncated mid-stream and would fail validation.
logger.info("[Agent] cancel detected mid-stream, aborting LLM call")
if full_content:
partial_msg = {
"role": "assistant",
"content": [{"type": "text", "text": full_content}],
}
self.messages.append(partial_msg)
self._emit_event("message_end", {
"content": full_content,
"tool_calls": [],
"cancelled": True,
})
raise AgentCancelledError("cancelled during LLM streaming")
# Check for errors
if isinstance(chunk, dict) and chunk.get("error"):
# Extract error message from nested structure
@@ -738,6 +856,10 @@ class AgentStreamExecutor:
elif isinstance(choice, dict) and choice.get("_gemini_raw_parts"):
gemini_raw_parts = choice["_gemini_raw_parts"]
except AgentCancelledError:
# Must propagate untouched; never treat as a retryable error.
raise
except Exception as e:
error_str = str(e)
error_str_lower = error_str.lower()