mirror of
https://github.com/zhayujie/chatgpt-on-wechat.git
synced 2026-06-02 00:57:41 +08:00
feat: personal ai agent framework
This commit is contained in:
461
agent/protocol/agent_stream.py
Normal file
461
agent/protocol/agent_stream.py
Normal file
@@ -0,0 +1,461 @@
|
||||
"""
|
||||
Agent Stream Execution Module - Multi-turn reasoning based on tool-call
|
||||
|
||||
Provides streaming output, event system, and complete tool-call loop
|
||||
"""
|
||||
import json
|
||||
import time
|
||||
from typing import List, Dict, Any, Optional, Callable
|
||||
|
||||
from common.log import logger
|
||||
from agent.protocol.models import LLMRequest, LLMModel
|
||||
from agent.tools.base_tool import BaseTool, ToolResult
|
||||
|
||||
|
||||
class AgentStreamExecutor:
|
||||
"""
|
||||
Agent Stream Executor
|
||||
|
||||
Handles multi-turn reasoning loop based on tool-call:
|
||||
1. LLM generates response (may include tool calls)
|
||||
2. Execute tools
|
||||
3. Return results to LLM
|
||||
4. Repeat until no more tool calls
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
agent, # Agent instance
|
||||
model: LLMModel,
|
||||
system_prompt: str,
|
||||
tools: List[BaseTool],
|
||||
max_turns: int = 50,
|
||||
on_event: Optional[Callable] = None,
|
||||
messages: Optional[List[Dict]] = None
|
||||
):
|
||||
"""
|
||||
Initialize stream executor
|
||||
|
||||
Args:
|
||||
agent: Agent instance (for accessing context)
|
||||
model: LLM model
|
||||
system_prompt: System prompt
|
||||
tools: List of available tools
|
||||
max_turns: Maximum number of turns
|
||||
on_event: Event callback function
|
||||
messages: Optional existing message history (for persistent conversations)
|
||||
"""
|
||||
self.agent = agent
|
||||
self.model = model
|
||||
self.system_prompt = system_prompt
|
||||
# Convert tools list to dict
|
||||
self.tools = {tool.name: tool for tool in tools} if isinstance(tools, list) else tools
|
||||
self.max_turns = max_turns
|
||||
self.on_event = on_event
|
||||
|
||||
# Message history - use provided messages or create new list
|
||||
self.messages = messages if messages is not None else []
|
||||
|
||||
def _emit_event(self, event_type: str, data: dict = None):
|
||||
"""Emit event"""
|
||||
if self.on_event:
|
||||
try:
|
||||
self.on_event({
|
||||
"type": event_type,
|
||||
"timestamp": time.time(),
|
||||
"data": data or {}
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"Event callback error: {e}")
|
||||
|
||||
def run_stream(self, user_message: str) -> str:
|
||||
"""
|
||||
Execute streaming reasoning loop
|
||||
|
||||
Args:
|
||||
user_message: User message
|
||||
|
||||
Returns:
|
||||
Final response text
|
||||
"""
|
||||
# Log user message
|
||||
logger.info(f"\n{'='*50}")
|
||||
logger.info(f"👤 用户: {user_message}")
|
||||
logger.info(f"{'='*50}")
|
||||
|
||||
# Add user message (Claude format - use content blocks for consistency)
|
||||
self.messages.append({
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": user_message
|
||||
}
|
||||
]
|
||||
})
|
||||
|
||||
self._emit_event("agent_start")
|
||||
|
||||
final_response = ""
|
||||
turn = 0
|
||||
|
||||
try:
|
||||
while turn < self.max_turns:
|
||||
turn += 1
|
||||
logger.info(f"\n{'='*50} 第 {turn} 轮 {'='*50}")
|
||||
self._emit_event("turn_start", {"turn": turn})
|
||||
|
||||
# Check if memory flush is needed (before calling LLM)
|
||||
if self.agent.memory_manager and hasattr(self.agent, 'last_usage'):
|
||||
usage = self.agent.last_usage
|
||||
if usage and 'input_tokens' in usage:
|
||||
current_tokens = usage.get('input_tokens', 0)
|
||||
context_window = self.agent._get_model_context_window()
|
||||
reserve_tokens = self.agent.context_reserve_tokens or 20000
|
||||
|
||||
if self.agent.memory_manager.should_flush_memory(
|
||||
current_tokens=current_tokens,
|
||||
context_window=context_window,
|
||||
reserve_tokens=reserve_tokens
|
||||
):
|
||||
self._emit_event("memory_flush_start", {
|
||||
"current_tokens": current_tokens,
|
||||
"threshold": context_window - reserve_tokens - 4000
|
||||
})
|
||||
|
||||
# TODO: Execute memory flush in background
|
||||
# This would require async support
|
||||
logger.info(f"Memory flush recommended at {current_tokens} tokens")
|
||||
|
||||
# Call LLM
|
||||
assistant_msg, tool_calls = self._call_llm_stream()
|
||||
final_response = assistant_msg
|
||||
|
||||
# No tool calls, end loop
|
||||
if not tool_calls:
|
||||
if assistant_msg:
|
||||
logger.info(f"💭 {assistant_msg[:150]}{'...' if len(assistant_msg) > 150 else ''}")
|
||||
logger.info(f"✅ 完成 (无工具调用)")
|
||||
self._emit_event("turn_end", {
|
||||
"turn": turn,
|
||||
"has_tool_calls": False
|
||||
})
|
||||
break
|
||||
|
||||
# Log tool calls in compact format
|
||||
tool_names = [tc['name'] for tc in tool_calls]
|
||||
logger.info(f"🔧 调用工具: {', '.join(tool_names)}")
|
||||
|
||||
# Execute tools
|
||||
tool_results = []
|
||||
tool_result_blocks = []
|
||||
|
||||
for tool_call in tool_calls:
|
||||
result = self._execute_tool(tool_call)
|
||||
tool_results.append(result)
|
||||
|
||||
# Log tool result in compact format
|
||||
status_emoji = "✅" if result.get("status") == "success" else "❌"
|
||||
result_str = str(result.get('result', ''))
|
||||
logger.info(f" {status_emoji} {tool_call['name']} ({result.get('execution_time', 0):.2f}s): {result_str[:200]}{'...' if len(result_str) > 200 else ''}")
|
||||
|
||||
# Build tool result block (Claude format)
|
||||
# Content should be a string representation of the result
|
||||
result_content = json.dumps(result) if not isinstance(result, str) else result
|
||||
tool_result_blocks.append({
|
||||
"type": "tool_result",
|
||||
"tool_use_id": tool_call["id"],
|
||||
"content": result_content
|
||||
})
|
||||
|
||||
# Add tool results to message history as user message (Claude format)
|
||||
self.messages.append({
|
||||
"role": "user",
|
||||
"content": tool_result_blocks
|
||||
})
|
||||
|
||||
self._emit_event("turn_end", {
|
||||
"turn": turn,
|
||||
"has_tool_calls": True,
|
||||
"tool_count": len(tool_calls)
|
||||
})
|
||||
|
||||
if turn >= self.max_turns:
|
||||
logger.warning(f"⚠️ 已达到最大轮数限制: {self.max_turns}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Agent执行错误: {e}")
|
||||
self._emit_event("error", {"error": str(e)})
|
||||
raise
|
||||
|
||||
finally:
|
||||
logger.info(f"{'='*50} 完成({turn}轮) {'='*50}\n")
|
||||
self._emit_event("agent_end", {"final_response": final_response})
|
||||
|
||||
return final_response
|
||||
|
||||
def _call_llm_stream(self) -> tuple[str, List[Dict]]:
|
||||
"""
|
||||
Call LLM with streaming
|
||||
|
||||
Returns:
|
||||
(response_text, tool_calls)
|
||||
"""
|
||||
# Trim messages if needed (using agent's context management)
|
||||
self._trim_messages()
|
||||
|
||||
# Prepare messages
|
||||
messages = self._prepare_messages()
|
||||
|
||||
# Debug: log message structure
|
||||
logger.debug(f"Sending {len(messages)} messages to LLM")
|
||||
for i, msg in enumerate(messages):
|
||||
role = msg.get("role", "unknown")
|
||||
content = msg.get("content", "")
|
||||
if isinstance(content, list):
|
||||
content_types = [c.get("type") for c in content if isinstance(c, dict)]
|
||||
logger.debug(f" Message {i}: role={role}, content_blocks={content_types}")
|
||||
else:
|
||||
logger.debug(f" Message {i}: role={role}, content_length={len(str(content))}")
|
||||
|
||||
# Prepare tool definitions (OpenAI/Claude format)
|
||||
tools_schema = None
|
||||
if self.tools:
|
||||
tools_schema = []
|
||||
for tool in self.tools.values():
|
||||
tools_schema.append({
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"input_schema": tool.params # Claude uses input_schema
|
||||
})
|
||||
|
||||
# Create request
|
||||
request = LLMRequest(
|
||||
messages=messages,
|
||||
temperature=0,
|
||||
stream=True,
|
||||
tools=tools_schema,
|
||||
system=self.system_prompt # Pass system prompt separately for Claude API
|
||||
)
|
||||
|
||||
self._emit_event("message_start", {"role": "assistant"})
|
||||
|
||||
# Streaming response
|
||||
full_content = ""
|
||||
tool_calls_buffer = {} # {index: {id, name, arguments}}
|
||||
|
||||
try:
|
||||
stream = self.model.call_stream(request)
|
||||
|
||||
for chunk in stream:
|
||||
# Check for errors
|
||||
if isinstance(chunk, dict) and chunk.get("error"):
|
||||
error_msg = chunk.get("message", "Unknown error")
|
||||
status_code = chunk.get("status_code", "N/A")
|
||||
logger.error(f"API Error: {error_msg} (Status: {status_code})")
|
||||
logger.error(f"Full error chunk: {chunk}")
|
||||
raise Exception(f"{error_msg} (Status: {status_code})")
|
||||
|
||||
# Parse chunk
|
||||
if isinstance(chunk, dict) and "choices" in chunk:
|
||||
choice = chunk["choices"][0]
|
||||
delta = choice.get("delta", {})
|
||||
|
||||
# Handle text content
|
||||
if "content" in delta and delta["content"]:
|
||||
content_delta = delta["content"]
|
||||
full_content += content_delta
|
||||
self._emit_event("message_update", {"delta": content_delta})
|
||||
|
||||
# Handle tool calls
|
||||
if "tool_calls" in delta:
|
||||
for tc_delta in delta["tool_calls"]:
|
||||
index = tc_delta.get("index", 0)
|
||||
|
||||
if index not in tool_calls_buffer:
|
||||
tool_calls_buffer[index] = {
|
||||
"id": "",
|
||||
"name": "",
|
||||
"arguments": ""
|
||||
}
|
||||
|
||||
if "id" in tc_delta:
|
||||
tool_calls_buffer[index]["id"] = tc_delta["id"]
|
||||
|
||||
if "function" in tc_delta:
|
||||
func = tc_delta["function"]
|
||||
if "name" in func:
|
||||
tool_calls_buffer[index]["name"] = func["name"]
|
||||
if "arguments" in func:
|
||||
tool_calls_buffer[index]["arguments"] += func["arguments"]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"LLM call error: {e}")
|
||||
raise
|
||||
|
||||
# Parse tool calls
|
||||
tool_calls = []
|
||||
for idx in sorted(tool_calls_buffer.keys()):
|
||||
tc = tool_calls_buffer[idx]
|
||||
try:
|
||||
arguments = json.loads(tc["arguments"]) if tc["arguments"] else {}
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"Failed to parse tool arguments: {tc['arguments']}")
|
||||
arguments = {}
|
||||
|
||||
tool_calls.append({
|
||||
"id": tc["id"],
|
||||
"name": tc["name"],
|
||||
"arguments": arguments
|
||||
})
|
||||
|
||||
# Add assistant message to history (Claude format uses content blocks)
|
||||
assistant_msg = {"role": "assistant", "content": []}
|
||||
|
||||
# Add text content block if present
|
||||
if full_content:
|
||||
assistant_msg["content"].append({
|
||||
"type": "text",
|
||||
"text": full_content
|
||||
})
|
||||
|
||||
# Add tool_use blocks if present
|
||||
if tool_calls:
|
||||
for tc in tool_calls:
|
||||
assistant_msg["content"].append({
|
||||
"type": "tool_use",
|
||||
"id": tc["id"],
|
||||
"name": tc["name"],
|
||||
"input": tc["arguments"]
|
||||
})
|
||||
|
||||
# Only append if content is not empty
|
||||
if assistant_msg["content"]:
|
||||
self.messages.append(assistant_msg)
|
||||
|
||||
self._emit_event("message_end", {
|
||||
"content": full_content,
|
||||
"tool_calls": tool_calls
|
||||
})
|
||||
|
||||
return full_content, tool_calls
|
||||
|
||||
def _execute_tool(self, tool_call: Dict) -> Dict[str, Any]:
|
||||
"""
|
||||
Execute tool
|
||||
|
||||
Args:
|
||||
tool_call: {"id": str, "name": str, "arguments": dict}
|
||||
|
||||
Returns:
|
||||
Tool execution result
|
||||
"""
|
||||
tool_name = tool_call["name"]
|
||||
tool_id = tool_call["id"]
|
||||
arguments = tool_call["arguments"]
|
||||
|
||||
self._emit_event("tool_execution_start", {
|
||||
"tool_call_id": tool_id,
|
||||
"tool_name": tool_name,
|
||||
"arguments": arguments
|
||||
})
|
||||
|
||||
try:
|
||||
tool = self.tools.get(tool_name)
|
||||
if not tool:
|
||||
raise ValueError(f"Tool '{tool_name}' not found")
|
||||
|
||||
# Set tool context
|
||||
tool.model = self.model
|
||||
tool.context = self.agent
|
||||
|
||||
# Execute tool
|
||||
start_time = time.time()
|
||||
result: ToolResult = tool.execute_tool(arguments)
|
||||
execution_time = time.time() - start_time
|
||||
|
||||
result_dict = {
|
||||
"status": result.status,
|
||||
"result": result.result,
|
||||
"execution_time": execution_time
|
||||
}
|
||||
|
||||
self._emit_event("tool_execution_end", {
|
||||
"tool_call_id": tool_id,
|
||||
"tool_name": tool_name,
|
||||
**result_dict
|
||||
})
|
||||
|
||||
return result_dict
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Tool execution error: {e}")
|
||||
error_result = {
|
||||
"status": "error",
|
||||
"result": str(e),
|
||||
"execution_time": 0
|
||||
}
|
||||
self._emit_event("tool_execution_end", {
|
||||
"tool_call_id": tool_id,
|
||||
"tool_name": tool_name,
|
||||
**error_result
|
||||
})
|
||||
return error_result
|
||||
|
||||
def _trim_messages(self):
|
||||
"""
|
||||
Trim message history to stay within context limits.
|
||||
Uses agent's context management configuration.
|
||||
"""
|
||||
if not self.messages or not self.agent:
|
||||
return
|
||||
|
||||
# Get context window and reserve tokens from agent
|
||||
context_window = self.agent._get_model_context_window()
|
||||
reserve_tokens = self.agent._get_context_reserve_tokens()
|
||||
max_tokens = context_window - reserve_tokens
|
||||
|
||||
# Estimate current tokens
|
||||
current_tokens = sum(self.agent._estimate_message_tokens(msg) for msg in self.messages)
|
||||
|
||||
# Add system prompt tokens
|
||||
system_tokens = self.agent._estimate_message_tokens({"role": "system", "content": self.system_prompt})
|
||||
current_tokens += system_tokens
|
||||
|
||||
# If under limit, no need to trim
|
||||
if current_tokens <= max_tokens:
|
||||
return
|
||||
|
||||
# Keep messages from newest, accumulating tokens
|
||||
available_tokens = max_tokens - system_tokens
|
||||
kept_messages = []
|
||||
accumulated_tokens = 0
|
||||
|
||||
for msg in reversed(self.messages):
|
||||
msg_tokens = self.agent._estimate_message_tokens(msg)
|
||||
if accumulated_tokens + msg_tokens <= available_tokens:
|
||||
kept_messages.insert(0, msg)
|
||||
accumulated_tokens += msg_tokens
|
||||
else:
|
||||
break
|
||||
|
||||
old_count = len(self.messages)
|
||||
self.messages = kept_messages
|
||||
new_count = len(self.messages)
|
||||
|
||||
if old_count > new_count:
|
||||
logger.info(
|
||||
f"Context trimmed: {old_count} -> {new_count} messages "
|
||||
f"(~{current_tokens} -> ~{system_tokens + accumulated_tokens} tokens, "
|
||||
f"limit: {max_tokens})"
|
||||
)
|
||||
|
||||
def _prepare_messages(self) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Prepare messages to send to LLM
|
||||
|
||||
Note: For Claude API, system prompt should be passed separately via system parameter,
|
||||
not as a message. The AgentLLMModel will handle this.
|
||||
"""
|
||||
# Don't add system message here - it will be handled separately by the LLM adapter
|
||||
return self.messages
|
||||
Reference in New Issue
Block a user