mirror of
https://github.com/zhayujie/chatgpt-on-wechat.git
synced 2026-06-02 00:57:41 +08:00
feat: personal ai agent framework
This commit is contained in:
20
agent/protocol/__init__.py
Normal file
20
agent/protocol/__init__.py
Normal file
@@ -0,0 +1,20 @@
|
||||
from .agent import Agent
|
||||
from .agent_stream import AgentStreamExecutor
|
||||
from .task import Task, TaskType, TaskStatus
|
||||
from .result import AgentResult, AgentAction, AgentActionType, ToolResult
|
||||
from .models import LLMModel, LLMRequest, ModelFactory
|
||||
|
||||
__all__ = [
|
||||
'Agent',
|
||||
'AgentStreamExecutor',
|
||||
'Task',
|
||||
'TaskType',
|
||||
'TaskStatus',
|
||||
'AgentResult',
|
||||
'AgentAction',
|
||||
'AgentActionType',
|
||||
'ToolResult',
|
||||
'LLMModel',
|
||||
'LLMRequest',
|
||||
'ModelFactory'
|
||||
]
|
||||
292
agent/protocol/agent.py
Normal file
292
agent/protocol/agent.py
Normal file
@@ -0,0 +1,292 @@
|
||||
import json
|
||||
import time
|
||||
|
||||
from common.log import logger
|
||||
from agent.protocol.models import LLMRequest, LLMModel
|
||||
from agent.protocol.agent_stream import AgentStreamExecutor
|
||||
from agent.protocol.result import AgentAction, AgentActionType, ToolResult, AgentResult
|
||||
from agent.tools.base_tool import BaseTool, ToolStage
|
||||
|
||||
|
||||
class Agent:
|
||||
def __init__(self, system_prompt: str, description: str = "AI Agent", model: LLMModel = None,
|
||||
tools=None, output_mode="print", max_steps=100, max_context_tokens=None,
|
||||
context_reserve_tokens=None, memory_manager=None, name: str = None):
|
||||
"""
|
||||
Initialize the Agent with system prompt, model, description.
|
||||
|
||||
:param system_prompt: The system prompt for the agent.
|
||||
:param description: A description of the agent.
|
||||
:param model: An instance of LLMModel to be used by the agent.
|
||||
:param tools: Optional list of tools for the agent to use.
|
||||
:param output_mode: Control how execution progress is displayed:
|
||||
"print" for console output or "logger" for using logger
|
||||
:param max_steps: Maximum number of steps the agent can take (default: 100)
|
||||
:param max_context_tokens: Maximum tokens to keep in context (default: None, auto-calculated based on model)
|
||||
:param context_reserve_tokens: Reserve tokens for new requests (default: None, auto-calculated)
|
||||
:param memory_manager: Optional MemoryManager instance for memory operations
|
||||
:param name: [Deprecated] The name of the agent (no longer used in single-agent system)
|
||||
"""
|
||||
self.name = name or "Agent"
|
||||
self.system_prompt = system_prompt
|
||||
self.model: LLMModel = model # Instance of LLMModel
|
||||
self.description = description
|
||||
self.tools: list = []
|
||||
self.max_steps = max_steps # max tool-call steps, default 100
|
||||
self.max_context_tokens = max_context_tokens # max tokens in context
|
||||
self.context_reserve_tokens = context_reserve_tokens # reserve tokens for new requests
|
||||
self.captured_actions = [] # Initialize captured actions list
|
||||
self.output_mode = output_mode
|
||||
self.last_usage = None # Store last API response usage info
|
||||
self.messages = [] # Unified message history for stream mode
|
||||
self.memory_manager = memory_manager # Memory manager for auto memory flush
|
||||
if tools:
|
||||
for tool in tools:
|
||||
self.add_tool(tool)
|
||||
|
||||
def add_tool(self, tool: BaseTool):
|
||||
"""
|
||||
Add a tool to the agent.
|
||||
|
||||
:param tool: The tool to add (either a tool instance or a tool name)
|
||||
"""
|
||||
# If tool is already an instance, use it directly
|
||||
tool.model = self.model
|
||||
self.tools.append(tool)
|
||||
|
||||
def _get_model_context_window(self) -> int:
|
||||
"""
|
||||
Get the model's context window size in tokens.
|
||||
Auto-detect based on model name.
|
||||
|
||||
Model context windows:
|
||||
- Claude 3.5/3.7 Sonnet: 200K tokens
|
||||
- Claude 3 Opus: 200K tokens
|
||||
- GPT-4 Turbo/128K: 128K tokens
|
||||
- GPT-4: 8K-32K tokens
|
||||
- GPT-3.5: 16K tokens
|
||||
- DeepSeek: 64K tokens
|
||||
|
||||
:return: Context window size in tokens
|
||||
"""
|
||||
if self.model and hasattr(self.model, 'model'):
|
||||
model_name = self.model.model.lower()
|
||||
|
||||
# Claude models - 200K context
|
||||
if 'claude-3' in model_name or 'claude-sonnet' in model_name:
|
||||
return 200000
|
||||
|
||||
# GPT-4 models
|
||||
elif 'gpt-4' in model_name:
|
||||
if 'turbo' in model_name or '128k' in model_name:
|
||||
return 128000
|
||||
elif '32k' in model_name:
|
||||
return 32000
|
||||
else:
|
||||
return 8000
|
||||
|
||||
# GPT-3.5
|
||||
elif 'gpt-3.5' in model_name:
|
||||
if '16k' in model_name:
|
||||
return 16000
|
||||
else:
|
||||
return 4000
|
||||
|
||||
# DeepSeek
|
||||
elif 'deepseek' in model_name:
|
||||
return 64000
|
||||
|
||||
# Default conservative value
|
||||
return 10000
|
||||
|
||||
def _get_context_reserve_tokens(self) -> int:
|
||||
"""
|
||||
Get the number of tokens to reserve for new requests.
|
||||
This prevents context overflow by keeping a buffer.
|
||||
|
||||
:return: Number of tokens to reserve
|
||||
"""
|
||||
if self.context_reserve_tokens is not None:
|
||||
return self.context_reserve_tokens
|
||||
|
||||
# Reserve ~20% of context window for new requests
|
||||
context_window = self._get_model_context_window()
|
||||
return max(4000, int(context_window * 0.2))
|
||||
|
||||
def _estimate_message_tokens(self, message: dict) -> int:
|
||||
"""
|
||||
Estimate token count for a message using chars/4 heuristic.
|
||||
This is a conservative estimate (tends to overestimate).
|
||||
|
||||
:param message: Message dict with 'role' and 'content'
|
||||
:return: Estimated token count
|
||||
"""
|
||||
content = message.get('content', '')
|
||||
if isinstance(content, str):
|
||||
return max(1, len(content) // 4)
|
||||
elif isinstance(content, list):
|
||||
# Handle multi-part content (text + images)
|
||||
total_chars = 0
|
||||
for part in content:
|
||||
if isinstance(part, dict) and part.get('type') == 'text':
|
||||
total_chars += len(part.get('text', ''))
|
||||
elif isinstance(part, dict) and part.get('type') == 'image':
|
||||
# Estimate images as ~1200 tokens
|
||||
total_chars += 4800
|
||||
return max(1, total_chars // 4)
|
||||
return 1
|
||||
|
||||
def _find_tool(self, tool_name: str):
|
||||
"""Find and return a tool with the specified name"""
|
||||
for tool in self.tools:
|
||||
if tool.name == tool_name:
|
||||
# Only pre-process stage tools can be actively called
|
||||
if tool.stage == ToolStage.PRE_PROCESS:
|
||||
tool.model = self.model
|
||||
tool.context = self # Set tool context
|
||||
return tool
|
||||
else:
|
||||
# If it's a post-process tool, return None to prevent direct calling
|
||||
logger.warning(f"Tool {tool_name} is a post-process tool and cannot be called directly.")
|
||||
return None
|
||||
return None
|
||||
|
||||
# output function based on mode
|
||||
def output(self, message="", end="\n"):
|
||||
if self.output_mode == "print":
|
||||
print(message, end=end)
|
||||
elif message:
|
||||
logger.info(message)
|
||||
|
||||
def _execute_post_process_tools(self):
|
||||
"""Execute all post-process stage tools"""
|
||||
# Get all post-process stage tools
|
||||
post_process_tools = [tool for tool in self.tools if tool.stage == ToolStage.POST_PROCESS]
|
||||
|
||||
# Execute each tool
|
||||
for tool in post_process_tools:
|
||||
# Set tool context
|
||||
tool.context = self
|
||||
|
||||
# Record start time for execution timing
|
||||
start_time = time.time()
|
||||
|
||||
# Execute tool (with empty parameters, tool will extract needed info from context)
|
||||
result = tool.execute({})
|
||||
|
||||
# Calculate execution time
|
||||
execution_time = time.time() - start_time
|
||||
|
||||
# Capture tool use for tracking
|
||||
self.capture_tool_use(
|
||||
tool_name=tool.name,
|
||||
input_params={}, # Post-process tools typically don't take parameters
|
||||
output=result.result,
|
||||
status=result.status,
|
||||
error_message=str(result.result) if result.status == "error" else None,
|
||||
execution_time=execution_time
|
||||
)
|
||||
|
||||
# Log result
|
||||
if result.status == "success":
|
||||
# Print tool execution result in the desired format
|
||||
self.output(f"\n🛠️ {tool.name}: {json.dumps(result.result)}")
|
||||
else:
|
||||
# Print failure in print mode
|
||||
self.output(f"\n🛠️ {tool.name}: {json.dumps({'status': 'error', 'message': str(result.result)})}")
|
||||
|
||||
def capture_tool_use(self, tool_name, input_params, output, status, thought=None, error_message=None,
|
||||
execution_time=0.0):
|
||||
"""
|
||||
Capture a tool use action.
|
||||
|
||||
:param thought: thought content
|
||||
:param tool_name: Name of the tool used
|
||||
:param input_params: Parameters passed to the tool
|
||||
:param output: Output from the tool
|
||||
:param status: Status of the tool execution
|
||||
:param error_message: Error message if the tool execution failed
|
||||
:param execution_time: Time taken to execute the tool
|
||||
"""
|
||||
tool_result = ToolResult(
|
||||
tool_name=tool_name,
|
||||
input_params=input_params,
|
||||
output=output,
|
||||
status=status,
|
||||
error_message=error_message,
|
||||
execution_time=execution_time
|
||||
)
|
||||
|
||||
action = AgentAction(
|
||||
agent_id=self.id if hasattr(self, 'id') else str(id(self)),
|
||||
agent_name=self.name,
|
||||
action_type=AgentActionType.TOOL_USE,
|
||||
tool_result=tool_result,
|
||||
thought=thought
|
||||
)
|
||||
|
||||
self.captured_actions.append(action)
|
||||
|
||||
return action
|
||||
|
||||
def run_stream(self, user_message: str, on_event=None, clear_history: bool = False) -> str:
|
||||
"""
|
||||
Execute single agent task with streaming (based on tool-call)
|
||||
|
||||
This method supports:
|
||||
- Streaming output
|
||||
- Multi-turn reasoning based on tool-call
|
||||
- Event callbacks
|
||||
- Persistent conversation history across calls
|
||||
|
||||
Args:
|
||||
user_message: User message
|
||||
on_event: Event callback function callback(event: dict)
|
||||
event = {"type": str, "timestamp": float, "data": dict}
|
||||
clear_history: If True, clear conversation history before this call (default: False)
|
||||
|
||||
Returns:
|
||||
Final response text
|
||||
|
||||
Example:
|
||||
# Multi-turn conversation with memory
|
||||
response1 = agent.run_stream("My name is Alice")
|
||||
response2 = agent.run_stream("What's my name?") # Will remember Alice
|
||||
|
||||
# Single-turn without memory
|
||||
response = agent.run_stream("Hello", clear_history=True)
|
||||
"""
|
||||
# Clear history if requested
|
||||
if clear_history:
|
||||
self.messages = []
|
||||
|
||||
# Get model to use
|
||||
if not self.model:
|
||||
raise ValueError("No model available for agent")
|
||||
|
||||
# Create stream executor with agent's message history
|
||||
executor = AgentStreamExecutor(
|
||||
agent=self,
|
||||
model=self.model,
|
||||
system_prompt=self.system_prompt,
|
||||
tools=self.tools,
|
||||
max_turns=self.max_steps,
|
||||
on_event=on_event,
|
||||
messages=self.messages # Pass agent's message history
|
||||
)
|
||||
|
||||
# Execute
|
||||
response = executor.run_stream(user_message)
|
||||
|
||||
# Update agent's message history from executor
|
||||
self.messages = executor.messages
|
||||
|
||||
# Execute all post-process tools
|
||||
self._execute_post_process_tools()
|
||||
|
||||
return response
|
||||
|
||||
def clear_history(self):
|
||||
"""Clear conversation history and captured actions"""
|
||||
self.messages = []
|
||||
self.captured_actions = []
|
||||
461
agent/protocol/agent_stream.py
Normal file
461
agent/protocol/agent_stream.py
Normal file
@@ -0,0 +1,461 @@
|
||||
"""
|
||||
Agent Stream Execution Module - Multi-turn reasoning based on tool-call
|
||||
|
||||
Provides streaming output, event system, and complete tool-call loop
|
||||
"""
|
||||
import json
|
||||
import time
|
||||
from typing import List, Dict, Any, Optional, Callable
|
||||
|
||||
from common.log import logger
|
||||
from agent.protocol.models import LLMRequest, LLMModel
|
||||
from agent.tools.base_tool import BaseTool, ToolResult
|
||||
|
||||
|
||||
class AgentStreamExecutor:
|
||||
"""
|
||||
Agent Stream Executor
|
||||
|
||||
Handles multi-turn reasoning loop based on tool-call:
|
||||
1. LLM generates response (may include tool calls)
|
||||
2. Execute tools
|
||||
3. Return results to LLM
|
||||
4. Repeat until no more tool calls
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
agent, # Agent instance
|
||||
model: LLMModel,
|
||||
system_prompt: str,
|
||||
tools: List[BaseTool],
|
||||
max_turns: int = 50,
|
||||
on_event: Optional[Callable] = None,
|
||||
messages: Optional[List[Dict]] = None
|
||||
):
|
||||
"""
|
||||
Initialize stream executor
|
||||
|
||||
Args:
|
||||
agent: Agent instance (for accessing context)
|
||||
model: LLM model
|
||||
system_prompt: System prompt
|
||||
tools: List of available tools
|
||||
max_turns: Maximum number of turns
|
||||
on_event: Event callback function
|
||||
messages: Optional existing message history (for persistent conversations)
|
||||
"""
|
||||
self.agent = agent
|
||||
self.model = model
|
||||
self.system_prompt = system_prompt
|
||||
# Convert tools list to dict
|
||||
self.tools = {tool.name: tool for tool in tools} if isinstance(tools, list) else tools
|
||||
self.max_turns = max_turns
|
||||
self.on_event = on_event
|
||||
|
||||
# Message history - use provided messages or create new list
|
||||
self.messages = messages if messages is not None else []
|
||||
|
||||
def _emit_event(self, event_type: str, data: dict = None):
|
||||
"""Emit event"""
|
||||
if self.on_event:
|
||||
try:
|
||||
self.on_event({
|
||||
"type": event_type,
|
||||
"timestamp": time.time(),
|
||||
"data": data or {}
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"Event callback error: {e}")
|
||||
|
||||
def run_stream(self, user_message: str) -> str:
|
||||
"""
|
||||
Execute streaming reasoning loop
|
||||
|
||||
Args:
|
||||
user_message: User message
|
||||
|
||||
Returns:
|
||||
Final response text
|
||||
"""
|
||||
# Log user message
|
||||
logger.info(f"\n{'='*50}")
|
||||
logger.info(f"👤 用户: {user_message}")
|
||||
logger.info(f"{'='*50}")
|
||||
|
||||
# Add user message (Claude format - use content blocks for consistency)
|
||||
self.messages.append({
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": user_message
|
||||
}
|
||||
]
|
||||
})
|
||||
|
||||
self._emit_event("agent_start")
|
||||
|
||||
final_response = ""
|
||||
turn = 0
|
||||
|
||||
try:
|
||||
while turn < self.max_turns:
|
||||
turn += 1
|
||||
logger.info(f"\n{'='*50} 第 {turn} 轮 {'='*50}")
|
||||
self._emit_event("turn_start", {"turn": turn})
|
||||
|
||||
# Check if memory flush is needed (before calling LLM)
|
||||
if self.agent.memory_manager and hasattr(self.agent, 'last_usage'):
|
||||
usage = self.agent.last_usage
|
||||
if usage and 'input_tokens' in usage:
|
||||
current_tokens = usage.get('input_tokens', 0)
|
||||
context_window = self.agent._get_model_context_window()
|
||||
reserve_tokens = self.agent.context_reserve_tokens or 20000
|
||||
|
||||
if self.agent.memory_manager.should_flush_memory(
|
||||
current_tokens=current_tokens,
|
||||
context_window=context_window,
|
||||
reserve_tokens=reserve_tokens
|
||||
):
|
||||
self._emit_event("memory_flush_start", {
|
||||
"current_tokens": current_tokens,
|
||||
"threshold": context_window - reserve_tokens - 4000
|
||||
})
|
||||
|
||||
# TODO: Execute memory flush in background
|
||||
# This would require async support
|
||||
logger.info(f"Memory flush recommended at {current_tokens} tokens")
|
||||
|
||||
# Call LLM
|
||||
assistant_msg, tool_calls = self._call_llm_stream()
|
||||
final_response = assistant_msg
|
||||
|
||||
# No tool calls, end loop
|
||||
if not tool_calls:
|
||||
if assistant_msg:
|
||||
logger.info(f"💭 {assistant_msg[:150]}{'...' if len(assistant_msg) > 150 else ''}")
|
||||
logger.info(f"✅ 完成 (无工具调用)")
|
||||
self._emit_event("turn_end", {
|
||||
"turn": turn,
|
||||
"has_tool_calls": False
|
||||
})
|
||||
break
|
||||
|
||||
# Log tool calls in compact format
|
||||
tool_names = [tc['name'] for tc in tool_calls]
|
||||
logger.info(f"🔧 调用工具: {', '.join(tool_names)}")
|
||||
|
||||
# Execute tools
|
||||
tool_results = []
|
||||
tool_result_blocks = []
|
||||
|
||||
for tool_call in tool_calls:
|
||||
result = self._execute_tool(tool_call)
|
||||
tool_results.append(result)
|
||||
|
||||
# Log tool result in compact format
|
||||
status_emoji = "✅" if result.get("status") == "success" else "❌"
|
||||
result_str = str(result.get('result', ''))
|
||||
logger.info(f" {status_emoji} {tool_call['name']} ({result.get('execution_time', 0):.2f}s): {result_str[:200]}{'...' if len(result_str) > 200 else ''}")
|
||||
|
||||
# Build tool result block (Claude format)
|
||||
# Content should be a string representation of the result
|
||||
result_content = json.dumps(result) if not isinstance(result, str) else result
|
||||
tool_result_blocks.append({
|
||||
"type": "tool_result",
|
||||
"tool_use_id": tool_call["id"],
|
||||
"content": result_content
|
||||
})
|
||||
|
||||
# Add tool results to message history as user message (Claude format)
|
||||
self.messages.append({
|
||||
"role": "user",
|
||||
"content": tool_result_blocks
|
||||
})
|
||||
|
||||
self._emit_event("turn_end", {
|
||||
"turn": turn,
|
||||
"has_tool_calls": True,
|
||||
"tool_count": len(tool_calls)
|
||||
})
|
||||
|
||||
if turn >= self.max_turns:
|
||||
logger.warning(f"⚠️ 已达到最大轮数限制: {self.max_turns}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Agent执行错误: {e}")
|
||||
self._emit_event("error", {"error": str(e)})
|
||||
raise
|
||||
|
||||
finally:
|
||||
logger.info(f"{'='*50} 完成({turn}轮) {'='*50}\n")
|
||||
self._emit_event("agent_end", {"final_response": final_response})
|
||||
|
||||
return final_response
|
||||
|
||||
def _call_llm_stream(self) -> tuple[str, List[Dict]]:
|
||||
"""
|
||||
Call LLM with streaming
|
||||
|
||||
Returns:
|
||||
(response_text, tool_calls)
|
||||
"""
|
||||
# Trim messages if needed (using agent's context management)
|
||||
self._trim_messages()
|
||||
|
||||
# Prepare messages
|
||||
messages = self._prepare_messages()
|
||||
|
||||
# Debug: log message structure
|
||||
logger.debug(f"Sending {len(messages)} messages to LLM")
|
||||
for i, msg in enumerate(messages):
|
||||
role = msg.get("role", "unknown")
|
||||
content = msg.get("content", "")
|
||||
if isinstance(content, list):
|
||||
content_types = [c.get("type") for c in content if isinstance(c, dict)]
|
||||
logger.debug(f" Message {i}: role={role}, content_blocks={content_types}")
|
||||
else:
|
||||
logger.debug(f" Message {i}: role={role}, content_length={len(str(content))}")
|
||||
|
||||
# Prepare tool definitions (OpenAI/Claude format)
|
||||
tools_schema = None
|
||||
if self.tools:
|
||||
tools_schema = []
|
||||
for tool in self.tools.values():
|
||||
tools_schema.append({
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"input_schema": tool.params # Claude uses input_schema
|
||||
})
|
||||
|
||||
# Create request
|
||||
request = LLMRequest(
|
||||
messages=messages,
|
||||
temperature=0,
|
||||
stream=True,
|
||||
tools=tools_schema,
|
||||
system=self.system_prompt # Pass system prompt separately for Claude API
|
||||
)
|
||||
|
||||
self._emit_event("message_start", {"role": "assistant"})
|
||||
|
||||
# Streaming response
|
||||
full_content = ""
|
||||
tool_calls_buffer = {} # {index: {id, name, arguments}}
|
||||
|
||||
try:
|
||||
stream = self.model.call_stream(request)
|
||||
|
||||
for chunk in stream:
|
||||
# Check for errors
|
||||
if isinstance(chunk, dict) and chunk.get("error"):
|
||||
error_msg = chunk.get("message", "Unknown error")
|
||||
status_code = chunk.get("status_code", "N/A")
|
||||
logger.error(f"API Error: {error_msg} (Status: {status_code})")
|
||||
logger.error(f"Full error chunk: {chunk}")
|
||||
raise Exception(f"{error_msg} (Status: {status_code})")
|
||||
|
||||
# Parse chunk
|
||||
if isinstance(chunk, dict) and "choices" in chunk:
|
||||
choice = chunk["choices"][0]
|
||||
delta = choice.get("delta", {})
|
||||
|
||||
# Handle text content
|
||||
if "content" in delta and delta["content"]:
|
||||
content_delta = delta["content"]
|
||||
full_content += content_delta
|
||||
self._emit_event("message_update", {"delta": content_delta})
|
||||
|
||||
# Handle tool calls
|
||||
if "tool_calls" in delta:
|
||||
for tc_delta in delta["tool_calls"]:
|
||||
index = tc_delta.get("index", 0)
|
||||
|
||||
if index not in tool_calls_buffer:
|
||||
tool_calls_buffer[index] = {
|
||||
"id": "",
|
||||
"name": "",
|
||||
"arguments": ""
|
||||
}
|
||||
|
||||
if "id" in tc_delta:
|
||||
tool_calls_buffer[index]["id"] = tc_delta["id"]
|
||||
|
||||
if "function" in tc_delta:
|
||||
func = tc_delta["function"]
|
||||
if "name" in func:
|
||||
tool_calls_buffer[index]["name"] = func["name"]
|
||||
if "arguments" in func:
|
||||
tool_calls_buffer[index]["arguments"] += func["arguments"]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"LLM call error: {e}")
|
||||
raise
|
||||
|
||||
# Parse tool calls
|
||||
tool_calls = []
|
||||
for idx in sorted(tool_calls_buffer.keys()):
|
||||
tc = tool_calls_buffer[idx]
|
||||
try:
|
||||
arguments = json.loads(tc["arguments"]) if tc["arguments"] else {}
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"Failed to parse tool arguments: {tc['arguments']}")
|
||||
arguments = {}
|
||||
|
||||
tool_calls.append({
|
||||
"id": tc["id"],
|
||||
"name": tc["name"],
|
||||
"arguments": arguments
|
||||
})
|
||||
|
||||
# Add assistant message to history (Claude format uses content blocks)
|
||||
assistant_msg = {"role": "assistant", "content": []}
|
||||
|
||||
# Add text content block if present
|
||||
if full_content:
|
||||
assistant_msg["content"].append({
|
||||
"type": "text",
|
||||
"text": full_content
|
||||
})
|
||||
|
||||
# Add tool_use blocks if present
|
||||
if tool_calls:
|
||||
for tc in tool_calls:
|
||||
assistant_msg["content"].append({
|
||||
"type": "tool_use",
|
||||
"id": tc["id"],
|
||||
"name": tc["name"],
|
||||
"input": tc["arguments"]
|
||||
})
|
||||
|
||||
# Only append if content is not empty
|
||||
if assistant_msg["content"]:
|
||||
self.messages.append(assistant_msg)
|
||||
|
||||
self._emit_event("message_end", {
|
||||
"content": full_content,
|
||||
"tool_calls": tool_calls
|
||||
})
|
||||
|
||||
return full_content, tool_calls
|
||||
|
||||
def _execute_tool(self, tool_call: Dict) -> Dict[str, Any]:
|
||||
"""
|
||||
Execute tool
|
||||
|
||||
Args:
|
||||
tool_call: {"id": str, "name": str, "arguments": dict}
|
||||
|
||||
Returns:
|
||||
Tool execution result
|
||||
"""
|
||||
tool_name = tool_call["name"]
|
||||
tool_id = tool_call["id"]
|
||||
arguments = tool_call["arguments"]
|
||||
|
||||
self._emit_event("tool_execution_start", {
|
||||
"tool_call_id": tool_id,
|
||||
"tool_name": tool_name,
|
||||
"arguments": arguments
|
||||
})
|
||||
|
||||
try:
|
||||
tool = self.tools.get(tool_name)
|
||||
if not tool:
|
||||
raise ValueError(f"Tool '{tool_name}' not found")
|
||||
|
||||
# Set tool context
|
||||
tool.model = self.model
|
||||
tool.context = self.agent
|
||||
|
||||
# Execute tool
|
||||
start_time = time.time()
|
||||
result: ToolResult = tool.execute_tool(arguments)
|
||||
execution_time = time.time() - start_time
|
||||
|
||||
result_dict = {
|
||||
"status": result.status,
|
||||
"result": result.result,
|
||||
"execution_time": execution_time
|
||||
}
|
||||
|
||||
self._emit_event("tool_execution_end", {
|
||||
"tool_call_id": tool_id,
|
||||
"tool_name": tool_name,
|
||||
**result_dict
|
||||
})
|
||||
|
||||
return result_dict
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Tool execution error: {e}")
|
||||
error_result = {
|
||||
"status": "error",
|
||||
"result": str(e),
|
||||
"execution_time": 0
|
||||
}
|
||||
self._emit_event("tool_execution_end", {
|
||||
"tool_call_id": tool_id,
|
||||
"tool_name": tool_name,
|
||||
**error_result
|
||||
})
|
||||
return error_result
|
||||
|
||||
def _trim_messages(self):
|
||||
"""
|
||||
Trim message history to stay within context limits.
|
||||
Uses agent's context management configuration.
|
||||
"""
|
||||
if not self.messages or not self.agent:
|
||||
return
|
||||
|
||||
# Get context window and reserve tokens from agent
|
||||
context_window = self.agent._get_model_context_window()
|
||||
reserve_tokens = self.agent._get_context_reserve_tokens()
|
||||
max_tokens = context_window - reserve_tokens
|
||||
|
||||
# Estimate current tokens
|
||||
current_tokens = sum(self.agent._estimate_message_tokens(msg) for msg in self.messages)
|
||||
|
||||
# Add system prompt tokens
|
||||
system_tokens = self.agent._estimate_message_tokens({"role": "system", "content": self.system_prompt})
|
||||
current_tokens += system_tokens
|
||||
|
||||
# If under limit, no need to trim
|
||||
if current_tokens <= max_tokens:
|
||||
return
|
||||
|
||||
# Keep messages from newest, accumulating tokens
|
||||
available_tokens = max_tokens - system_tokens
|
||||
kept_messages = []
|
||||
accumulated_tokens = 0
|
||||
|
||||
for msg in reversed(self.messages):
|
||||
msg_tokens = self.agent._estimate_message_tokens(msg)
|
||||
if accumulated_tokens + msg_tokens <= available_tokens:
|
||||
kept_messages.insert(0, msg)
|
||||
accumulated_tokens += msg_tokens
|
||||
else:
|
||||
break
|
||||
|
||||
old_count = len(self.messages)
|
||||
self.messages = kept_messages
|
||||
new_count = len(self.messages)
|
||||
|
||||
if old_count > new_count:
|
||||
logger.info(
|
||||
f"Context trimmed: {old_count} -> {new_count} messages "
|
||||
f"(~{current_tokens} -> ~{system_tokens + accumulated_tokens} tokens, "
|
||||
f"limit: {max_tokens})"
|
||||
)
|
||||
|
||||
def _prepare_messages(self) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Prepare messages to send to LLM
|
||||
|
||||
Note: For Claude API, system prompt should be passed separately via system parameter,
|
||||
not as a message. The AgentLLMModel will handle this.
|
||||
"""
|
||||
# Don't add system message here - it will be handled separately by the LLM adapter
|
||||
return self.messages
|
||||
27
agent/protocol/context.py
Normal file
27
agent/protocol/context.py
Normal file
@@ -0,0 +1,27 @@
|
||||
class TeamContext:
|
||||
def __init__(self, name: str, description: str, rule: str, agents: list, max_steps: int = 100):
|
||||
"""
|
||||
Initialize the TeamContext with a name, description, rules, a list of agents, and a user question.
|
||||
:param name: The name of the group context.
|
||||
:param description: A description of the group context.
|
||||
:param rule: The rules governing the group context.
|
||||
:param agents: A list of agents in the context.
|
||||
"""
|
||||
self.name = name
|
||||
self.description = description
|
||||
self.rule = rule
|
||||
self.agents = agents
|
||||
self.user_task = "" # For backward compatibility
|
||||
self.task = None # Will be a Task instance
|
||||
self.model = None # Will be an instance of LLMModel
|
||||
self.task_short_name = None # Store the task directory name
|
||||
# List of agents that have been executed
|
||||
self.agent_outputs: list = []
|
||||
self.current_steps = 0
|
||||
self.max_steps = max_steps
|
||||
|
||||
|
||||
class AgentOutput:
|
||||
def __init__(self, agent_name: str, output: str):
|
||||
self.agent_name = agent_name
|
||||
self.output = output
|
||||
57
agent/protocol/models.py
Normal file
57
agent/protocol/models.py
Normal file
@@ -0,0 +1,57 @@
|
||||
"""
|
||||
Models module for agent system.
|
||||
Provides basic model classes needed by tools and bridge integration.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
|
||||
class LLMRequest:
|
||||
"""Request model for LLM operations"""
|
||||
|
||||
def __init__(self, messages: List[Dict[str, str]] = None, model: Optional[str] = None,
|
||||
temperature: float = 0.7, max_tokens: Optional[int] = None,
|
||||
stream: bool = False, tools: Optional[List] = None, **kwargs):
|
||||
self.messages = messages or []
|
||||
self.model = model
|
||||
self.temperature = temperature
|
||||
self.max_tokens = max_tokens
|
||||
self.stream = stream
|
||||
self.tools = tools
|
||||
# Allow extra attributes
|
||||
for key, value in kwargs.items():
|
||||
setattr(self, key, value)
|
||||
|
||||
|
||||
class LLMModel:
|
||||
"""Base class for LLM models"""
|
||||
|
||||
def __init__(self, model: str = None, **kwargs):
|
||||
self.model = model
|
||||
self.config = kwargs
|
||||
|
||||
def call(self, request: LLMRequest):
|
||||
"""
|
||||
Call the model with a request.
|
||||
This is a placeholder implementation.
|
||||
"""
|
||||
raise NotImplementedError("LLMModel.call not implemented in this context")
|
||||
|
||||
def call_stream(self, request: LLMRequest):
|
||||
"""
|
||||
Call the model with streaming.
|
||||
This is a placeholder implementation.
|
||||
"""
|
||||
raise NotImplementedError("LLMModel.call_stream not implemented in this context")
|
||||
|
||||
|
||||
class ModelFactory:
|
||||
"""Factory for creating model instances"""
|
||||
|
||||
@staticmethod
|
||||
def create_model(model_type: str, **kwargs):
|
||||
"""
|
||||
Create a model instance based on type.
|
||||
This is a placeholder implementation.
|
||||
"""
|
||||
raise NotImplementedError("ModelFactory.create_model not implemented in this context")
|
||||
96
agent/protocol/result.py
Normal file
96
agent/protocol/result.py
Normal file
@@ -0,0 +1,96 @@
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import List, Dict, Any, Optional
|
||||
|
||||
from agent.protocol.task import Task, TaskStatus
|
||||
|
||||
|
||||
class AgentActionType(Enum):
|
||||
"""Enum representing different types of agent actions."""
|
||||
TOOL_USE = "tool_use"
|
||||
THINKING = "thinking"
|
||||
FINAL_ANSWER = "final_answer"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolResult:
|
||||
"""
|
||||
Represents the result of a tool use.
|
||||
|
||||
Attributes:
|
||||
tool_name: Name of the tool used
|
||||
input_params: Parameters passed to the tool
|
||||
output: Output from the tool
|
||||
status: Status of the tool execution (success/error)
|
||||
error_message: Error message if the tool execution failed
|
||||
execution_time: Time taken to execute the tool
|
||||
"""
|
||||
tool_name: str
|
||||
input_params: Dict[str, Any]
|
||||
output: Any
|
||||
status: str
|
||||
error_message: Optional[str] = None
|
||||
execution_time: float = 0.0
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentAction:
|
||||
"""
|
||||
Represents an action taken by an agent.
|
||||
|
||||
Attributes:
|
||||
id: Unique identifier for the action
|
||||
agent_id: ID of the agent that performed the action
|
||||
agent_name: Name of the agent that performed the action
|
||||
action_type: Type of action (tool use, thinking, final answer)
|
||||
content: Content of the action (thought content, final answer content)
|
||||
tool_result: Tool use details if action_type is TOOL_USE
|
||||
timestamp: When the action was performed
|
||||
"""
|
||||
agent_id: str
|
||||
agent_name: str
|
||||
action_type: AgentActionType
|
||||
id: str = field(default_factory=lambda: str(uuid.uuid4()))
|
||||
content: str = ""
|
||||
tool_result: Optional[ToolResult] = None
|
||||
thought: Optional[str] = None
|
||||
timestamp: float = field(default_factory=time.time)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentResult:
|
||||
"""
|
||||
Represents the result of an agent's execution.
|
||||
|
||||
Attributes:
|
||||
final_answer: The final answer provided by the agent
|
||||
step_count: Number of steps taken by the agent
|
||||
status: Status of the execution (success/error)
|
||||
error_message: Error message if execution failed
|
||||
"""
|
||||
final_answer: str
|
||||
step_count: int
|
||||
status: str = "success"
|
||||
error_message: Optional[str] = None
|
||||
|
||||
@classmethod
|
||||
def success(cls, final_answer: str, step_count: int) -> "AgentResult":
|
||||
"""Create a successful result"""
|
||||
return cls(final_answer=final_answer, step_count=step_count)
|
||||
|
||||
@classmethod
|
||||
def error(cls, error_message: str, step_count: int = 0) -> "AgentResult":
|
||||
"""Create an error result"""
|
||||
return cls(
|
||||
final_answer=f"Error: {error_message}",
|
||||
step_count=step_count,
|
||||
status="error",
|
||||
error_message=error_message
|
||||
)
|
||||
|
||||
@property
|
||||
def is_error(self) -> bool:
|
||||
"""Check if the result represents an error"""
|
||||
return self.status == "error"
|
||||
95
agent/protocol/task.py
Normal file
95
agent/protocol/task.py
Normal file
@@ -0,0 +1,95 @@
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Dict, Any, List
|
||||
|
||||
|
||||
class TaskType(Enum):
|
||||
"""Enum representing different types of tasks."""
|
||||
TEXT = "text"
|
||||
IMAGE = "image"
|
||||
VIDEO = "video"
|
||||
AUDIO = "audio"
|
||||
FILE = "file"
|
||||
MIXED = "mixed"
|
||||
|
||||
|
||||
class TaskStatus(Enum):
|
||||
"""Enum representing the status of a task."""
|
||||
INIT = "init" # Initial state
|
||||
PROCESSING = "processing" # In progress
|
||||
COMPLETED = "completed" # Completed
|
||||
FAILED = "failed" # Failed
|
||||
|
||||
|
||||
@dataclass
|
||||
class Task:
|
||||
"""
|
||||
Represents a task to be processed by an agent.
|
||||
|
||||
Attributes:
|
||||
id: Unique identifier for the task
|
||||
content: The primary text content of the task
|
||||
type: Type of the task
|
||||
status: Current status of the task
|
||||
created_at: Timestamp when the task was created
|
||||
updated_at: Timestamp when the task was last updated
|
||||
metadata: Additional metadata for the task
|
||||
images: List of image URLs or base64 encoded images
|
||||
videos: List of video URLs
|
||||
audios: List of audio URLs or base64 encoded audios
|
||||
files: List of file URLs or paths
|
||||
"""
|
||||
id: str = field(default_factory=lambda: str(uuid.uuid4()))
|
||||
content: str = ""
|
||||
type: TaskType = TaskType.TEXT
|
||||
status: TaskStatus = TaskStatus.INIT
|
||||
created_at: float = field(default_factory=time.time)
|
||||
updated_at: float = field(default_factory=time.time)
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
# Media content
|
||||
images: List[str] = field(default_factory=list)
|
||||
videos: List[str] = field(default_factory=list)
|
||||
audios: List[str] = field(default_factory=list)
|
||||
files: List[str] = field(default_factory=list)
|
||||
|
||||
def __init__(self, content: str = "", **kwargs):
|
||||
"""
|
||||
Initialize a Task with content and optional keyword arguments.
|
||||
|
||||
Args:
|
||||
content: The text content of the task
|
||||
**kwargs: Additional attributes to set
|
||||
"""
|
||||
self.id = kwargs.get('id', str(uuid.uuid4()))
|
||||
self.content = content
|
||||
self.type = kwargs.get('type', TaskType.TEXT)
|
||||
self.status = kwargs.get('status', TaskStatus.INIT)
|
||||
self.created_at = kwargs.get('created_at', time.time())
|
||||
self.updated_at = kwargs.get('updated_at', time.time())
|
||||
self.metadata = kwargs.get('metadata', {})
|
||||
self.images = kwargs.get('images', [])
|
||||
self.videos = kwargs.get('videos', [])
|
||||
self.audios = kwargs.get('audios', [])
|
||||
self.files = kwargs.get('files', [])
|
||||
|
||||
def get_text(self) -> str:
|
||||
"""
|
||||
Get the text content of the task.
|
||||
|
||||
Returns:
|
||||
The text content
|
||||
"""
|
||||
return self.content
|
||||
|
||||
def update_status(self, status: TaskStatus) -> None:
|
||||
"""
|
||||
Update the status of the task.
|
||||
|
||||
Args:
|
||||
status: The new status
|
||||
"""
|
||||
self.status = status
|
||||
self.updated_at = time.time()
|
||||
Reference in New Issue
Block a user