feat: personal ai agent framework

This commit is contained in:
saboteur7
2026-01-30 09:53:46 +08:00
parent 25cf6823d0
commit bb850bb6c5
62 changed files with 7675 additions and 275 deletions

3
.gitignore vendored
View File

@@ -35,3 +35,6 @@ plugins/banwords/lib/__pycache__
!plugins/linkai
!plugins/agent
client_config.json
ref/
.cursor/
local/

10
agent/memory/__init__.py Normal file
View File

@@ -0,0 +1,10 @@
"""
Memory module for AgentMesh
Provides long-term memory capabilities with hybrid search (vector + keyword)
"""
from agent.memory.manager import MemoryManager
from agent.memory.config import MemoryConfig, get_default_memory_config, set_global_memory_config
__all__ = ['MemoryManager', 'MemoryConfig', 'get_default_memory_config', 'set_global_memory_config']

139
agent/memory/chunker.py Normal file
View File

@@ -0,0 +1,139 @@
"""
Text chunking utilities for memory
Splits text into chunks with token limits and overlap
"""
from typing import List, Tuple
from dataclasses import dataclass
@dataclass
class TextChunk:
    """A slice of source text together with the line span it was taken from."""
    # Chunk content; multi-line chunks are joined with '\n' by TextChunker.
    text: str
    # 1-based number of the first source line in the chunk.
    start_line: int
    # 1-based number of the last source line in the chunk (inclusive).
    end_line: int
class TextChunker:
    """Chunks text by line with a character-based token estimate.

    Token budgets are converted to character budgets by multiplying with
    ``chars_per_token``; no real tokenizer is involved. Chunks are built
    from whole lines, except that a single line longer than the budget is
    cut into fixed-size pieces.
    """

    def __init__(self, max_tokens: int = 500, overlap_tokens: int = 50,
                 chars_per_token: int = 4):
        """
        Initialize chunker

        Args:
            max_tokens: Maximum (estimated) tokens per chunk
            overlap_tokens: (Estimated) tokens of trailing context repeated
                at the start of the following chunk
            chars_per_token: Characters assumed per token. The default of 4
                is a rough estimate for mixed English/Chinese text; pass a
                smaller value for text that tokenizes more densely.
        """
        self.max_tokens = max_tokens
        self.overlap_tokens = overlap_tokens
        self.chars_per_token = chars_per_token

    def chunk_text(self, text: str) -> List[TextChunk]:
        """
        Chunk text into overlapping segments

        Args:
            text: Input text to chunk

        Returns:
            List of TextChunk objects (empty for blank input). Line numbers
            are 1-based and inclusive.
        """
        if not text.strip():
            return []

        lines = text.split('\n')
        max_chars = self.max_tokens * self.chars_per_token
        overlap_chars = self.overlap_tokens * self.chars_per_token

        chunks: List[TextChunk] = []
        pending: List[str] = []   # lines collected for the chunk in progress
        pending_chars = 0         # character count of `pending` (newlines not counted)
        start_line = 1            # line number of pending[0]

        for line_no, line in enumerate(lines, start=1):
            line_chars = len(line)

            # A single line that cannot fit in any chunk is split on its own
            if line_chars > max_chars:
                if pending:
                    chunks.append(TextChunk(
                        text='\n'.join(pending),
                        start_line=start_line,
                        end_line=line_no - 1
                    ))
                    pending = []
                    pending_chars = 0
                chunks.extend(
                    TextChunk(text=piece, start_line=line_no, end_line=line_no)
                    for piece in self._split_long_line(line, max_chars)
                )
                start_line = line_no + 1
                continue

            if pending and pending_chars + line_chars > max_chars:
                # Budget exceeded: emit the chunk, then seed the next one
                # with the trailing overlap lines plus the current line
                chunks.append(TextChunk(
                    text='\n'.join(pending),
                    start_line=start_line,
                    end_line=line_no - 1
                ))
                overlap_lines = self._get_overlap_lines(pending, overlap_chars)
                pending = overlap_lines + [line]
                pending_chars = sum(len(item) for item in pending)
                start_line = line_no - len(overlap_lines)
            else:
                pending.append(line)
                pending_chars += line_chars

        # Whatever is left forms the final chunk
        if pending:
            chunks.append(TextChunk(
                text='\n'.join(pending),
                start_line=start_line,
                end_line=len(lines)
            ))
        return chunks

    def _split_long_line(self, line: str, max_chars: int) -> List[str]:
        """Cut a single long line into consecutive pieces of at most max_chars"""
        return [line[pos:pos + max_chars] for pos in range(0, len(line), max_chars)]

    def _get_overlap_lines(self, lines: List[str], target_chars: int) -> List[str]:
        """Get the trailing lines whose combined length fits within target_chars"""
        overlap: List[str] = []
        chars = 0
        # Walk backwards and stop at the first line that would overflow
        for line in reversed(lines):
            if chars + len(line) > target_chars:
                break
            overlap.insert(0, line)
            chars += len(line)
        return overlap

    def chunk_markdown(self, text: str) -> List[TextChunk]:
        """
        Chunk markdown text

        Currently identical to chunk_text; markdown structure is not yet
        taken into account.
        """
        return self.chunk_text(text)

114
agent/memory/config.py Normal file
View File

@@ -0,0 +1,114 @@
"""
Memory configuration module
Provides global memory configuration with simplified workspace structure
"""
import os
from dataclasses import dataclass, field
from typing import Optional, List
from pathlib import Path
@dataclass
class MemoryConfig:
    """Configuration for memory storage and search"""

    # Storage paths (default: ~/cow). A leading "~" is expanded on access.
    workspace_root: str = field(default_factory=lambda: os.path.expanduser("~/cow"))

    # Embedding config
    embedding_provider: str = "openai"  # "openai" | "local"
    embedding_model: str = "text-embedding-3-small"
    embedding_dim: int = 1536

    # Chunking config
    chunk_max_tokens: int = 500
    chunk_overlap_tokens: int = 50

    # Search config
    max_results: int = 10
    min_score: float = 0.3

    # Hybrid search weights
    vector_weight: float = 0.7
    keyword_weight: float = 0.3

    # Memory sources
    sources: List[str] = field(default_factory=lambda: ["memory", "session"])

    # Sync config
    enable_auto_sync: bool = True
    sync_on_search: bool = True

    def get_workspace(self) -> Path:
        """
        Get workspace root directory

        A user-supplied root such as "~/my_agents" is expanded to the home
        directory here; without this, a literal "~" directory would be
        created relative to the current working directory.
        """
        return Path(self.workspace_root).expanduser()

    def get_memory_dir(self) -> Path:
        """Get memory files directory (<workspace>/memory)"""
        return self.get_workspace() / "memory"

    def get_db_path(self) -> Path:
        """
        Get SQLite database path for long-term memory index

        Side effect: creates <workspace>/memory/long-term if it is missing.
        """
        index_dir = self.get_memory_dir() / "long-term"
        index_dir.mkdir(parents=True, exist_ok=True)
        return index_dir / "index.db"

    def get_skills_dir(self) -> Path:
        """Get skills directory (<workspace>/skills)"""
        return self.get_workspace() / "skills"

    def get_agent_workspace(self, agent_name: Optional[str] = None) -> Path:
        """
        Get workspace directory for an agent

        Args:
            agent_name: Optional agent name (not used in current implementation)

        Returns:
            Path to workspace directory (created if it does not exist)
        """
        workspace = self.get_workspace()
        # Ensure workspace directory exists
        workspace.mkdir(parents=True, exist_ok=True)
        return workspace
# Process-wide memory configuration; remains None until it is first
# requested or explicitly installed via set_global_memory_config().
_global_memory_config: Optional[MemoryConfig] = None


def get_default_memory_config() -> MemoryConfig:
    """
    Return the process-wide memory configuration.

    On first use, when nothing has been installed yet, a MemoryConfig with
    default values is created, remembered, and returned; later calls return
    that same instance.

    Returns:
        MemoryConfig instance
    """
    global _global_memory_config
    if _global_memory_config is not None:
        return _global_memory_config
    _global_memory_config = MemoryConfig()
    return _global_memory_config
def set_global_memory_config(config: MemoryConfig) -> None:
    """
    Set the global memory configuration.

    Replaces the module-level configuration returned by
    get_default_memory_config(). This should be called before creating any
    MemoryManager instances.

    Args:
        config: MemoryConfig instance to use globally

    Example:
        >>> from agent.memory import MemoryConfig, set_global_memory_config
        >>> config = MemoryConfig(
        ...     workspace_root="~/my_agents",
        ...     embedding_provider="openai",
        ...     vector_weight=0.8
        ... )
        >>> set_global_memory_config(config)
    """
    global _global_memory_config
    _global_memory_config = config

175
agent/memory/embedding.py Normal file
View File

@@ -0,0 +1,175 @@
"""
Embedding providers for memory
Supports OpenAI and local embedding models
"""
from typing import List, Optional
from abc import ABC, abstractmethod
import hashlib
import json
class EmbeddingProvider(ABC):
    """Abstract interface every embedding backend must implement."""

    @abstractmethod
    def embed(self, text: str) -> List[float]:
        """Return the embedding vector for a single text."""

    @abstractmethod
    def embed_batch(self, texts: List[str]) -> List[List[float]]:
        """Return one embedding vector per input text."""

    @property
    @abstractmethod
    def dimensions(self) -> int:
        """Length of the vectors produced by this provider."""
class OpenAIEmbeddingProvider(EmbeddingProvider):
    """OpenAI embedding provider"""

    def __init__(self, model: str = "text-embedding-3-small", api_key: Optional[str] = None, api_base: Optional[str] = None):
        """
        Initialize OpenAI embedding provider

        Args:
            model: Model name (text-embedding-3-small or text-embedding-3-large)
            api_key: OpenAI API key
            api_base: Optional API base URL

        Raises:
            ImportError: If the ``openai`` package is not installed
        """
        self.model = model
        self.api_key = api_key
        self.api_base = api_base or "https://api.openai.com/v1"

        # Lazy import to avoid a hard dependency on the openai package
        try:
            from openai import OpenAI
        except ImportError as err:
            raise ImportError(
                "OpenAI package not installed. Install with: pip install openai"
            ) from err

        # The raw api_base (possibly None) is passed on so the client can
        # apply its own default resolution
        self.client = OpenAI(api_key=api_key, base_url=api_base)

        # Set dimensions based on model: only the "large" model is 3072-dim;
        # text-embedding-3-small and text-embedding-ada-002 are 1536-dim.
        # NOTE(review): other/custom models fall back to 1536 — confirm.
        self._dimensions = 3072 if "large" in model else 1536

    def embed(self, text: str) -> List[float]:
        """Generate embedding for a single text via the embeddings endpoint"""
        response = self.client.embeddings.create(
            input=text,
            model=self.model
        )
        return response.data[0].embedding

    def embed_batch(self, texts: List[str]) -> List[List[float]]:
        """Generate embeddings for multiple texts in one request

        Returns an empty list without calling the API when texts is empty.
        """
        if not texts:
            return []
        response = self.client.embeddings.create(
            input=texts,
            model=self.model
        )
        return [item.embedding for item in response.data]

    @property
    def dimensions(self) -> int:
        """Embedding vector length inferred from the model name"""
        return self._dimensions
class LocalEmbeddingProvider(EmbeddingProvider):
    """Embedding provider that runs a sentence-transformers model locally"""

    def __init__(self, model: str = "all-MiniLM-L6-v2"):
        """
        Load the sentence-transformers model

        Args:
            model: Model name from sentence-transformers

        Raises:
            ImportError: If loading fails with an ImportError (typically
                because sentence-transformers is not installed)
        """
        self.model_name = model
        try:
            from sentence_transformers import SentenceTransformer
            encoder = SentenceTransformer(model)
            self.model = encoder
            self._dimensions = encoder.get_sentence_embedding_dimension()
        except ImportError:
            raise ImportError(
                "sentence-transformers not installed. "
                "Install with: pip install sentence-transformers"
            )

    def embed(self, text: str) -> List[float]:
        """Encode one text and return the vector as a plain list"""
        return self.model.encode(text, convert_to_numpy=True).tolist()

    def embed_batch(self, texts: List[str]) -> List[List[float]]:
        """Encode several texts at once; an empty input yields an empty list"""
        if texts:
            return self.model.encode(texts, convert_to_numpy=True).tolist()
        return []

    @property
    def dimensions(self) -> int:
        """Vector length reported by the loaded model"""
        return self._dimensions
class EmbeddingCache:
    """In-memory store of embeddings keyed by provider, model and text"""

    def __init__(self):
        # Maps an MD5 hex digest (see _compute_key) to the embedding vector
        self.cache = {}

    def get(self, text: str, provider: str, model: str) -> Optional[List[float]]:
        """Return the cached embedding, or None when it was never stored"""
        return self.cache.get(self._compute_key(text, provider, model))

    def put(self, text: str, provider: str, model: str, embedding: List[float]):
        """Store (or overwrite) the embedding for this provider/model/text"""
        self.cache[self._compute_key(text, provider, model)] = embedding

    @staticmethod
    def _compute_key(text: str, provider: str, model: str) -> str:
        """Derive the cache key: MD5 hex digest of "provider:model:text" """
        raw = f"{provider}:{model}:{text}".encode('utf-8')
        return hashlib.md5(raw).hexdigest()

    def clear(self):
        """Drop every cached embedding"""
        self.cache.clear()
def create_embedding_provider(
    provider: str = "openai",
    model: Optional[str] = None,
    api_key: Optional[str] = None,
    api_base: Optional[str] = None
) -> EmbeddingProvider:
    """
    Build the embedding provider selected by name

    Args:
        provider: Provider name ("openai" or "local")
        model: Model name (provider-specific); each provider has its own
            default when omitted
        api_key: API key for remote providers
        api_base: API base URL for remote providers

    Returns:
        EmbeddingProvider instance

    Raises:
        ValueError: If provider is neither "openai" nor "local"
    """
    if provider == "openai":
        return OpenAIEmbeddingProvider(
            model=model or "text-embedding-3-small",
            api_key=api_key,
            api_base=api_base
        )
    if provider == "local":
        # api_key / api_base are irrelevant for the local backend
        return LocalEmbeddingProvider(model=model or "all-MiniLM-L6-v2")
    raise ValueError(f"Unknown embedding provider: {provider}")

623
agent/memory/manager.py Normal file
View File

@@ -0,0 +1,623 @@
"""
Memory manager for AgentMesh
Provides high-level interface for memory operations
"""
import os
from typing import List, Optional, Dict, Any
from pathlib import Path
import hashlib
from datetime import datetime, timedelta
from agent.memory.config import MemoryConfig, get_default_memory_config
from agent.memory.storage import MemoryStorage, MemoryChunk, SearchResult
from agent.memory.chunker import TextChunker
from agent.memory.embedding import create_embedding_provider, EmbeddingProvider
from agent.memory.summarizer import MemoryFlushManager, create_memory_files_if_needed
class MemoryManager:
    """
    Memory manager with hybrid search capabilities

    Provides long-term memory for agents with vector and keyword search.
    Content is chunked, optionally embedded, and indexed in MemoryStorage;
    searches merge vector and keyword scores using the configured weights.
    """

    def __init__(
        self,
        config: Optional[MemoryConfig] = None,
        embedding_provider: Optional[EmbeddingProvider] = None,
        llm_model: Optional[Any] = None
    ):
        """
        Initialize memory manager

        Args:
            config: Memory configuration (uses global config if not provided)
            embedding_provider: Custom embedding provider (optional)
            llm_model: LLM model for summarization (optional)
        """
        self.config = config or get_default_memory_config()

        # Initialize storage
        db_path = self.config.get_db_path()
        self.storage = MemoryStorage(db_path)

        # Initialize chunker
        self.chunker = TextChunker(
            max_tokens=self.config.chunk_max_tokens,
            overlap_tokens=self.config.chunk_overlap_tokens
        )

        # Initialize embedding provider (optional)
        self.embedding_provider = None
        if embedding_provider:
            self.embedding_provider = embedding_provider
        else:
            # Try to create embedding provider, but allow failure
            try:
                # Get API key from environment
                api_key = os.environ.get('OPENAI_API_KEY')
                api_base = os.environ.get('OPENAI_API_BASE')
                self.embedding_provider = create_embedding_provider(
                    provider=self.config.embedding_provider,
                    model=self.config.embedding_model,
                    api_key=api_key,
                    api_base=api_base
                )
            except Exception as e:
                # Deliberately best-effort: without embeddings, keyword
                # search and file operations still work
                print(f"⚠️ Warning: Embedding provider initialization failed: {e}")
                print(f" Memory will work with keyword search only (no semantic search)")

        # Initialize memory flush manager
        self.flush_manager = MemoryFlushManager(
            workspace_dir=self.config.get_workspace(),
            llm_model=llm_model
        )

        # Ensure workspace directories exist
        self._init_workspace()
        self._dirty = False

    def _init_workspace(self):
        """Create the memory directory and the default memory files"""
        memory_dir = self.config.get_memory_dir()
        memory_dir.mkdir(parents=True, exist_ok=True)

        # Create default memory files
        workspace_dir = self.config.get_workspace()
        create_memory_files_if_needed(workspace_dir)

    async def search(
        self,
        query: str,
        user_id: Optional[str] = None,
        max_results: Optional[int] = None,
        min_score: Optional[float] = None,
        include_shared: bool = True
    ) -> List[SearchResult]:
        """
        Search memory with hybrid search (vector + keyword)

        Args:
            query: Search query
            user_id: User ID for scoped search
            max_results: Maximum results to return (None = configured default)
            min_score: Minimum score threshold (None = configured default;
                an explicit 0.0 disables the threshold)
            include_shared: Include shared memories

        Returns:
            List of search results sorted by relevance
        """
        # Compare against None so that explicit 0 / 0.0 are honoured
        if max_results is None:
            max_results = self.config.max_results
        if min_score is None:
            min_score = self.config.min_score

        # Determine scopes
        scopes = []
        if include_shared:
            scopes.append("shared")
        if user_id:
            scopes.append("user")
        if not scopes:
            return []

        # Sync if needed
        if self.config.sync_on_search and self._dirty:
            await self.sync()

        # Perform vector search (if embedding provider available).
        # Each storage call gets its own copy of `scopes` so that neither
        # can affect the other by mutating the list.
        vector_results = []
        if self.embedding_provider:
            query_embedding = self.embedding_provider.embed(query)
            vector_results = self.storage.search_vector(
                query_embedding=query_embedding,
                user_id=user_id,
                scopes=list(scopes),
                limit=max_results * 2  # Get more candidates for merging
            )

        # Perform keyword search
        keyword_results = self.storage.search_keyword(
            query=query,
            user_id=user_id,
            scopes=list(scopes),
            limit=max_results * 2
        )

        # Merge results
        merged = self._merge_results(
            vector_results,
            keyword_results,
            self.config.vector_weight,
            self.config.keyword_weight
        )

        # Filter by min score and limit
        filtered = [r for r in merged if r.score >= min_score]
        return filtered[:max_results]

    async def add_memory(
        self,
        content: str,
        user_id: Optional[str] = None,
        scope: str = "shared",
        source: str = "memory",
        path: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None
    ):
        """
        Add new memory content

        The content is chunked, embedded (when a provider is available) and
        written to the index only; no file is created on disk.

        Args:
            content: Memory content (blank content is ignored)
            user_id: User ID for user-scoped memory
            scope: Memory scope ("shared", "user", "session")
            source: Memory source ("memory" or "session")
            path: File path (auto-generated from a content hash if not provided)
            metadata: Additional metadata
        """
        if not content.strip():
            return

        # Generate path if not provided
        if not path:
            content_hash = hashlib.md5(content.encode('utf-8')).hexdigest()[:8]
            if user_id and scope == "user":
                path = f"memory/users/{user_id}/memory_{content_hash}.md"
            else:
                path = f"memory/shared/memory_{content_hash}.md"

        # Chunk, embed and persist
        chunks = self.chunker.chunk_text(content)
        memory_chunks = self._build_memory_chunks(
            chunks, path, scope, source, user_id, metadata
        )
        self.storage.save_chunks_batch(memory_chunks)

        # Update file metadata; there is no backing file, so the current
        # time stands in for the modification time
        self.storage.update_file_metadata(
            path=path,
            source=source,
            file_hash=MemoryStorage.compute_hash(content),
            mtime=int(datetime.now().timestamp()),
            size=len(content)
        )

    async def sync(self, force: bool = False):
        """
        Synchronize memory from files

        Indexes every *.md file below the memory directory (MEMORY.md, daily
        summaries and user files) and clears the dirty flag.

        Args:
            force: Force full reindex, i.e. re-chunk and re-embed files even
                when their content hash is unchanged
        """
        memory_dir = self.config.get_memory_dir()
        workspace_dir = self.config.get_workspace()

        # memory/MEMORY.md is picked up by this scan as well, so it needs
        # no separate pass
        if memory_dir.exists():
            for file_path in memory_dir.rglob("*.md"):
                parts = file_path.relative_to(workspace_dir).parts
                scope, user_id = self._scope_for_parts(parts)
                await self._sync_file(file_path, "memory", scope, user_id, force=force)

        self._dirty = False

    @staticmethod
    def _scope_for_parts(parts) -> tuple:
        """Derive (scope, user_id) from workspace-relative path components"""
        if "daily" in parts:
            if "users" in parts or len(parts) > 3:
                # User-scoped daily summary: memory/daily/{user_id}/2024-01-29.md
                user_idx = parts.index("daily") + 1
                return "user", (parts[user_idx] if user_idx < len(parts) else None)
            # Shared daily summary: memory/daily/2024-01-29.md
            return "shared", None
        if "users" in parts:
            # User-scoped memory: the component after "users" is the user id
            user_idx = parts.index("users") + 1
            return "user", (parts[user_idx] if user_idx < len(parts) else None)
        # Shared memory
        return "shared", None

    async def _sync_file(
        self,
        file_path: Path,
        source: str,
        scope: str,
        user_id: Optional[str],
        force: bool = False
    ):
        """
        Sync a single file into the index

        Args:
            file_path: File to index (must live below the workspace)
            source: Memory source stored with the chunks
            scope: Memory scope stored with the chunks
            user_id: Owner for user-scoped files, else None
            force: Reindex even when the stored hash matches
        """
        # Memory files are written and read as UTF-8 elsewhere in this
        # class; do not depend on the platform default encoding here
        content = file_path.read_text(encoding='utf-8')
        file_hash = MemoryStorage.compute_hash(content)

        # Get relative path
        workspace_dir = self.config.get_workspace()
        rel_path = str(file_path.relative_to(workspace_dir))

        # Skip unchanged files unless a full reindex was requested
        if not force and self.storage.get_file_hash(rel_path) == file_hash:
            return

        # Delete old chunks
        self.storage.delete_by_path(rel_path)

        # Chunk and embed
        chunks = self.chunker.chunk_text(content)
        if not chunks:
            return
        memory_chunks = self._build_memory_chunks(
            chunks, rel_path, scope, source, user_id, None
        )

        # Save
        self.storage.save_chunks_batch(memory_chunks)

        # Update file metadata
        stat = file_path.stat()
        self.storage.update_file_metadata(
            path=rel_path,
            source=source,
            file_hash=file_hash,
            mtime=int(stat.st_mtime),
            size=stat.st_size
        )

    def _build_memory_chunks(
        self,
        chunks,
        path: str,
        scope: str,
        source: str,
        user_id: Optional[str],
        metadata: Optional[Dict[str, Any]] = None
    ) -> List[MemoryChunk]:
        """Embed text chunks (when possible) and wrap them as MemoryChunk rows"""
        texts = [chunk.text for chunk in chunks]
        if self.embedding_provider:
            embeddings = self.embedding_provider.embed_batch(texts)
        else:
            # No embeddings available; rows are stored for keyword search only
            embeddings = [None] * len(texts)

        # Only fields declared on MemoryChunk are passed
        return [
            MemoryChunk(
                id=self._generate_chunk_id(path, chunk.start_line, chunk.end_line),
                user_id=user_id,
                scope=scope,
                source=source,
                path=path,
                start_line=chunk.start_line,
                end_line=chunk.end_line,
                text=chunk.text,
                embedding=embedding,
                hash=MemoryStorage.compute_hash(chunk.text),
                metadata=metadata
            )
            for chunk, embedding in zip(chunks, embeddings)
        ]

    def should_flush_memory(
        self,
        current_tokens: int,
        context_window: int = 128000,
        reserve_tokens: int = 20000,
        soft_threshold: int = 4000
    ) -> bool:
        """
        Check if memory flush should be triggered

        Delegates the decision to the flush manager.

        Args:
            current_tokens: Current session token count
            context_window: Model's context window size (default: 128K)
            reserve_tokens: Reserve tokens for compaction overhead (default: 20K)
            soft_threshold: Trigger N tokens before threshold (default: 4K)

        Returns:
            True if memory flush should run
        """
        return self.flush_manager.should_flush(
            current_tokens=current_tokens,
            context_window=context_window,
            reserve_tokens=reserve_tokens,
            soft_threshold=soft_threshold
        )

    async def execute_memory_flush(
        self,
        agent_executor,
        current_tokens: int,
        user_id: Optional[str] = None,
        **executor_kwargs
    ) -> bool:
        """
        Execute memory flush before compaction

        This runs a silent agent turn to write durable memories to disk.
        Similar to clawdbot's pre-compaction memory flush.

        Args:
            agent_executor: Async function to execute agent with prompt
            current_tokens: Current session token count
            user_id: Optional user ID
            **executor_kwargs: Additional kwargs for agent executor

        Returns:
            True if flush completed successfully

        Example:
            >>> async def run_agent(prompt, system_prompt, silent=False):
            ...     # Your agent execution logic
            ...     pass
            >>>
            >>> if manager.should_flush_memory(current_tokens=100000):
            ...     await manager.execute_memory_flush(
            ...         agent_executor=run_agent,
            ...         current_tokens=100000
            ...     )
        """
        success = await self.flush_manager.execute_flush(
            agent_executor=agent_executor,
            current_tokens=current_tokens,
            user_id=user_id,
            **executor_kwargs
        )
        if success:
            # Mark dirty so next search will sync the new memories
            self._dirty = True
        return success

    def build_memory_guidance(self, lang: str = "en", include_context: bool = True) -> str:
        """
        Build natural memory guidance for agent system prompt

        Following clawdbot's approach:
        1. Load MEMORY.md as bootstrap context (blends into background)
        2. Load daily files on-demand via memory_search tool
        3. Agent should NOT proactively mention memories unless user asks

        Args:
            lang: Language for guidance ("zh" for Chinese; anything else
                yields the English text)
            include_context: Whether to append the bootstrap memory context
                (MEMORY.md) after the guidance (default: True)

        Returns:
            Memory guidance text (and optionally context) for system prompt
        """
        today_file = self.flush_manager.get_today_memory_file().name

        if lang == "zh":
            guidance = f"""## 记忆召回
回答关于过去工作、决策、日期、人物、偏好或待办事项的问题前:先用 memory_search 搜索 MEMORY.md + memory/*.md然后用 memory_get 只读取需要的行。如果搜索后仍不确定,说明你已检查过。
## 记忆存储
当用户分享持久偏好、决策或重要事实时(无论是否明确要求"记住"),主动存储:
- 持久信息(偏好、决策、人物信息)→ memory/MEMORY.md
- 当天的笔记和上下文 → memory/{today_file}
- 静默存储,仅在用户明确要求时确认
## 记忆使用原则
- 不要主动提起或列举记忆内容
- 只在用户明确询问相关信息时才使用记忆
- 记忆是背景知识,不是要展示的内容
- 自然使用记忆,就像你本来就知道这些信息"""
        else:
            guidance = f"""## Memory Recall
Before answering anything about prior work, decisions, dates, people, preferences, or todos: run memory_search on MEMORY.md + memory/*.md; then use memory_get to pull only the needed lines. If low confidence after search, say you checked.
## Memory Storage
When user shares durable preferences, decisions, or important facts (whether or not they explicitly say "remember"), proactively store:
- Durable info (preferences, decisions, people) → memory/MEMORY.md
- Daily notes and context → memory/{today_file}
- Store silently; only confirm when explicitly requested
## Memory Usage Principles
- Don't proactively mention or list memory contents
- Only use memories when user explicitly asks about them
- Memories are background knowledge, not content to showcase
- Use memories naturally as if you inherently knew this information"""

        if include_context:
            # Load bootstrap context (MEMORY.md only, like clawdbot)
            bootstrap_context = self.load_bootstrap_memories()
            if bootstrap_context:
                guidance += f"\n\n## Background Context\n\n{bootstrap_context}"

        return guidance

    def load_bootstrap_memories(self, user_id: Optional[str] = None) -> str:
        """
        Load bootstrap memory files for session start

        Following clawdbot's design:
        - Only loads memory/MEMORY.md (long-term curated memory)
        - Daily files (YYYY-MM-DD.md) are accessed via memory_search tool, not bootstrap
        - User-specific MEMORY.md is also loaded if user_id provided

        Returns memory content WITHOUT obvious headers so it blends naturally
        into the context as background knowledge.

        Args:
            user_id: Optional user ID for user-specific memories

        Returns:
            Memory content to inject into system prompt, or "" when no
            memory file has content. Unreadable files are reported on
            stdout and skipped.
        """
        memory_dir = self.config.get_memory_dir()

        # (file, warning shown when it cannot be read), in load order
        candidates = [
            (memory_dir / "MEMORY.md", "Warning: Failed to read memory/MEMORY.md"),
        ]
        if user_id:
            candidates.append(
                (memory_dir / "users" / user_id / "MEMORY.md",
                 "Warning: Failed to read user memory")
            )

        sections = []
        for memory_file, warning in candidates:
            if not memory_file.exists():
                continue
            try:
                content = memory_file.read_text(encoding='utf-8').strip()
            except Exception as e:
                # Best-effort: a broken memory file must not block the session
                print(f"{warning}: {e}")
                continue
            if content:
                sections.append(content)

        # Join sections without obvious headers - let memories blend naturally
        return "\n\n".join(sections)

    def get_status(self) -> Dict[str, Any]:
        """Get memory status (index counts, workspace, dirty flag, search mode)"""
        stats = self.storage.get_stats()
        embedding_on = self.embedding_provider is not None
        return {
            'chunks': stats['chunks'],
            'files': stats['files'],
            'workspace': str(self.config.get_workspace()),
            'dirty': self._dirty,
            'embedding_enabled': embedding_on,
            'embedding_provider': self.config.embedding_provider if self.embedding_provider else 'disabled',
            'embedding_model': self.config.embedding_model if self.embedding_provider else 'N/A',
            'search_mode': 'hybrid (vector + keyword)' if self.embedding_provider else 'keyword only (FTS5)'
        }

    def mark_dirty(self):
        """Mark memory as dirty (needs sync)"""
        self._dirty = True

    def close(self):
        """Close memory manager and release resources"""
        self.storage.close()

    # Helper methods

    def _generate_chunk_id(self, path: str, start_line: int, end_line: int) -> str:
        """Generate a chunk ID: MD5 hex digest of "path:start_line:end_line" """
        content = f"{path}:{start_line}:{end_line}"
        return hashlib.md5(content.encode('utf-8')).hexdigest()

    def _merge_results(
        self,
        vector_results: List[SearchResult],
        keyword_results: List[SearchResult],
        vector_weight: float,
        keyword_weight: float
    ) -> List[SearchResult]:
        """
        Merge vector and keyword search results

        Results are matched on (path, start_line, end_line); the combined
        score is the weighted sum of both scores, where a result missing
        from one list contributes 0.0 for that list.

        Returns:
            New SearchResult objects sorted by combined score, descending
        """
        merged_map = {}

        for result in vector_results:
            key = (result.path, result.start_line, result.end_line)
            merged_map[key] = {
                'result': result,
                'vector_score': result.score,
                'keyword_score': 0.0
            }

        for result in keyword_results:
            key = (result.path, result.start_line, result.end_line)
            if key in merged_map:
                merged_map[key]['keyword_score'] = result.score
            else:
                merged_map[key] = {
                    'result': result,
                    'vector_score': 0.0,
                    'keyword_score': result.score
                }

        # Calculate combined scores
        merged_results = []
        for entry in merged_map.values():
            combined_score = (
                vector_weight * entry['vector_score'] +
                keyword_weight * entry['keyword_score']
            )
            result = entry['result']
            merged_results.append(SearchResult(
                path=result.path,
                start_line=result.start_line,
                end_line=result.end_line,
                score=combined_score,
                snippet=result.snippet,
                source=result.source,
                user_id=result.user_id
            ))

        # Sort by score
        merged_results.sort(key=lambda r: r.score, reverse=True)
        return merged_results

418
agent/memory/storage.py Normal file
View File

@@ -0,0 +1,418 @@
"""
Storage layer for memory using SQLite + FTS5
Provides vector and keyword search capabilities
"""
import sqlite3
import json
import hashlib
from typing import List, Dict, Optional, Any
from pathlib import Path
from dataclasses import dataclass
@dataclass
class MemoryChunk:
    """Represents a memory chunk with text and embedding"""
    # Primary key of the chunk row.
    id: str
    # Owning user, or None when the chunk is not tied to a user.
    user_id: Optional[str]
    scope: str  # "shared" | "user" | "session"
    source: str  # "memory" | "session"
    # Path the chunk was taken from.
    path: str
    # Line span of the chunk within that path.
    start_line: int
    end_line: int
    # Chunk content; this is the text indexed for keyword search.
    text: str
    # Embedding vector, stored as JSON text; None when no embedding exists.
    embedding: Optional[List[float]]
    # Content hash of the chunk.
    hash: str
    # Optional extra data, stored as JSON text.
    metadata: Optional[Dict[str, Any]] = None
@dataclass
class SearchResult:
    """Search result with score and snippet"""
    # Location of the matching chunk.
    path: str
    start_line: int
    end_line: int
    # Relevance score; higher is better.
    score: float
    # Chunk text, possibly truncated for display.
    snippet: str
    # Source of the matching chunk ("memory" | "session").
    source: str
    # Owner of the matching chunk, if any.
    user_id: Optional[str] = None
class MemoryStorage:
"""SQLite-based storage with FTS5 for keyword search"""
def __init__(self, db_path: Path):
    """
    Open the database at db_path and make sure the schema exists

    Args:
        db_path: Location of the SQLite database file
    """
    self.db_path = db_path
    # Set by _init_db(); declared here so the attribute always exists
    self.conn: Optional[sqlite3.Connection] = None
    self._init_db()
def _init_db(self):
    """
    Open the connection and create the schema if it is missing

    Creates the ``chunks`` table with its indexes, the ``chunks_fts`` FTS5
    index (an external-content table over ``chunks``), the triggers that
    keep the index in sync, and the ``files`` metadata table.
    """
    self.conn = sqlite3.connect(str(self.db_path))
    self.conn.row_factory = sqlite3.Row

    # Use write-ahead logging for the journal
    self.conn.execute("PRAGMA journal_mode=WAL")
    # Chunks are written with INSERT OR REPLACE. SQLite fires DELETE
    # triggers for the row that REPLACE removes only when recursive
    # triggers are enabled; without this the FTS index keeps stale entries.
    self.conn.execute("PRAGMA recursive_triggers=ON")

    # Create chunks table with embeddings
    self.conn.execute("""
        CREATE TABLE IF NOT EXISTS chunks (
            id TEXT PRIMARY KEY,
            user_id TEXT,
            scope TEXT NOT NULL DEFAULT 'shared',
            source TEXT NOT NULL DEFAULT 'memory',
            path TEXT NOT NULL,
            start_line INTEGER NOT NULL,
            end_line INTEGER NOT NULL,
            text TEXT NOT NULL,
            embedding TEXT,
            hash TEXT NOT NULL,
            metadata TEXT,
            created_at INTEGER DEFAULT (strftime('%s', 'now')),
            updated_at INTEGER DEFAULT (strftime('%s', 'now'))
        )
    """)

    # Create indexes
    self.conn.execute("""
        CREATE INDEX IF NOT EXISTS idx_chunks_user
        ON chunks(user_id)
    """)
    self.conn.execute("""
        CREATE INDEX IF NOT EXISTS idx_chunks_scope
        ON chunks(scope)
    """)
    self.conn.execute("""
        CREATE INDEX IF NOT EXISTS idx_chunks_hash
        ON chunks(path, hash)
    """)

    # Create FTS5 virtual table for keyword search
    self.conn.execute("""
        CREATE VIRTUAL TABLE IF NOT EXISTS chunks_fts USING fts5(
            text,
            id UNINDEXED,
            user_id UNINDEXED,
            path UNINDEXED,
            source UNINDEXED,
            scope UNINDEXED,
            content='chunks',
            content_rowid='rowid'
        )
    """)

    # Create triggers to keep FTS in sync
    self.conn.execute("""
        CREATE TRIGGER IF NOT EXISTS chunks_ai AFTER INSERT ON chunks BEGIN
            INSERT INTO chunks_fts(rowid, text, id, user_id, path, source, scope)
            VALUES (new.rowid, new.text, new.id, new.user_id, new.path, new.source, new.scope);
        END
    """)

    # An external-content FTS5 table cannot look up the old row once it is
    # gone, so removals must use the special 'delete' command and supply
    # the old values. The triggers are dropped first so that databases
    # created with the earlier definitions pick up the corrected ones.
    self.conn.execute("DROP TRIGGER IF EXISTS chunks_ad")
    self.conn.execute("""
        CREATE TRIGGER chunks_ad AFTER DELETE ON chunks BEGIN
            INSERT INTO chunks_fts(chunks_fts, rowid, text, id, user_id, path, source, scope)
            VALUES ('delete', old.rowid, old.text, old.id, old.user_id, old.path, old.source, old.scope);
        END
    """)
    self.conn.execute("DROP TRIGGER IF EXISTS chunks_au")
    self.conn.execute("""
        CREATE TRIGGER chunks_au AFTER UPDATE ON chunks BEGIN
            INSERT INTO chunks_fts(chunks_fts, rowid, text, id, user_id, path, source, scope)
            VALUES ('delete', old.rowid, old.text, old.id, old.user_id, old.path, old.source, old.scope);
            INSERT INTO chunks_fts(rowid, text, id, user_id, path, source, scope)
            VALUES (new.rowid, new.text, new.id, new.user_id, new.path, new.source, new.scope);
        END
    """)

    # Create files metadata table
    self.conn.execute("""
        CREATE TABLE IF NOT EXISTS files (
            path TEXT PRIMARY KEY,
            source TEXT NOT NULL DEFAULT 'memory',
            hash TEXT NOT NULL,
            mtime INTEGER NOT NULL,
            size INTEGER NOT NULL,
            updated_at INTEGER DEFAULT (strftime('%s', 'now'))
        )
    """)
    self.conn.commit()
def save_chunk(self, chunk: MemoryChunk):
    """
    Insert one chunk, replacing any existing row with the same id

    The embedding and metadata are stored as JSON text (NULL when empty or
    missing). The write is committed immediately.
    """
    embedding_json = json.dumps(chunk.embedding) if chunk.embedding else None
    metadata_json = json.dumps(chunk.metadata) if chunk.metadata else None
    row = (
        chunk.id,
        chunk.user_id,
        chunk.scope,
        chunk.source,
        chunk.path,
        chunk.start_line,
        chunk.end_line,
        chunk.text,
        embedding_json,
        chunk.hash,
        metadata_json,
    )
    self.conn.execute("""
        INSERT OR REPLACE INTO chunks
        (id, user_id, scope, source, path, start_line, end_line, text, embedding, hash, metadata, updated_at)
        VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, strftime('%s', 'now'))
    """, row)
    self.conn.commit()
def save_chunks_batch(self, chunks: List[MemoryChunk]):
    """
    Insert several chunks with one executemany call and a single commit

    Rows with an id that already exists are replaced. Embeddings and
    metadata are stored as JSON text (NULL when empty or missing).
    """
    def as_row(item):
        # Column order must match the INSERT statement below
        return (
            item.id, item.user_id, item.scope, item.source, item.path,
            item.start_line, item.end_line, item.text,
            json.dumps(item.embedding) if item.embedding else None,
            item.hash,
            json.dumps(item.metadata) if item.metadata else None
        )

    rows = [as_row(item) for item in chunks]
    self.conn.executemany("""
        INSERT OR REPLACE INTO chunks
        (id, user_id, scope, source, path, start_line, end_line, text, embedding, hash, metadata, updated_at)
        VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, strftime('%s', 'now'))
    """, rows)
    self.conn.commit()
def get_chunk(self, chunk_id: str) -> Optional[MemoryChunk]:
    """
    Look up a single chunk by its primary key.

    Args:
        chunk_id: ID of the chunk to fetch

    Returns:
        The matching MemoryChunk, or None when no row has that ID
    """
    cursor = self.conn.execute("""
        SELECT * FROM chunks WHERE id = ?
    """, (chunk_id,))
    found = cursor.fetchone()
    return self._row_to_chunk(found) if found else None
def search_vector(
    self,
    query_embedding: List[float],
    user_id: Optional[str] = None,
    scopes: List[str] = None,
    limit: int = 10
) -> List[SearchResult]:
    """
    Vector similarity search using in-memory cosine similarity
    (sqlite-vec can be added later for better performance)

    All candidate rows with a stored embedding are loaded, scored in
    Python, and the best `limit` rows with a strictly positive similarity
    are returned.

    Args:
        query_embedding: Embedding vector of the query
        user_id: When given, non-shared rows must belong to this user
        scopes: Scopes to search; defaults to ["shared"] plus "user" when
            user_id is given. The list is never modified.
        limit: Maximum number of results

    Returns:
        SearchResult list ordered by descending similarity
    """
    if scopes is None:
        scopes = ["shared"]
        if user_id:
            scopes.append("user")

    # Build query
    scope_placeholders = ','.join('?' * len(scopes))
    # Copy the scopes: the bind list gets user_id appended below, and that
    # must not leak into a list owned by the caller.
    params = list(scopes)
    if user_id:
        query = f"""
            SELECT * FROM chunks
            WHERE scope IN ({scope_placeholders})
            AND (scope = 'shared' OR user_id = ?)
            AND embedding IS NOT NULL
        """
        params.append(user_id)
    else:
        query = f"""
            SELECT * FROM chunks
            WHERE scope IN ({scope_placeholders})
            AND embedding IS NOT NULL
        """
    rows = self.conn.execute(query, params).fetchall()

    # Calculate cosine similarity; rows whose score is <= 0 (including
    # dimension mismatches, which score 0.0) are dropped
    results = []
    for row in rows:
        embedding = json.loads(row['embedding'])
        similarity = self._cosine_similarity(query_embedding, embedding)
        if similarity > 0:
            results.append((similarity, row))

    # Sort by similarity and limit
    results.sort(key=lambda x: x[0], reverse=True)
    results = results[:limit]

    return [
        SearchResult(
            path=row['path'],
            start_line=row['start_line'],
            end_line=row['end_line'],
            score=score,
            snippet=self._truncate_text(row['text'], 500),
            source=row['source'],
            user_id=row['user_id']
        )
        for score, row in results
    ]
def search_keyword(
    self,
    query: str,
    user_id: Optional[str] = None,
    scopes: List[str] = None,
    limit: int = 10
) -> List[SearchResult]:
    """
    Full-text keyword search backed by the chunks_fts FTS5 table.

    Args:
        query: Raw user query; converted to an FTS5 expression first
        user_id: When given, non-shared rows must belong to this user
        scopes: Scopes to search; defaults to ["shared"] plus "user" when
            user_id is given
        limit: Maximum number of rows to return

    Returns:
        SearchResult list in bm25 rank order; empty when the query
        contains no searchable token
    """
    if scopes is None:
        scopes = ["shared", "user"] if user_id else ["shared"]

    match_expr = self._build_fts_query(query)
    if not match_expr:
        return []

    scope_placeholders = ','.join('?' * len(scopes))
    # Bind order: MATCH expression, scopes, [user_id], limit
    bind_values = [match_expr] + scopes
    if user_id:
        statement = f"""
            SELECT chunks.*, bm25(chunks_fts) as rank
            FROM chunks_fts
            JOIN chunks ON chunks.id = chunks_fts.id
            WHERE chunks_fts MATCH ?
            AND chunks.scope IN ({scope_placeholders})
            AND (chunks.scope = 'shared' OR chunks.user_id = ?)
            ORDER BY rank
            LIMIT ?
        """
        bind_values.extend([user_id, limit])
    else:
        statement = f"""
            SELECT chunks.*, bm25(chunks_fts) as rank
            FROM chunks_fts
            JOIN chunks ON chunks.id = chunks_fts.id
            WHERE chunks_fts MATCH ?
            AND chunks.scope IN ({scope_placeholders})
            ORDER BY rank
            LIMIT ?
        """
        bind_values.append(limit)

    hits = []
    for record in self.conn.execute(statement, bind_values).fetchall():
        hits.append(SearchResult(
            path=record['path'],
            start_line=record['start_line'],
            end_line=record['end_line'],
            score=self._bm25_rank_to_score(record['rank']),
            snippet=self._truncate_text(record['text'], 500),
            source=record['source'],
            user_id=record['user_id']
        ))
    return hits
def delete_by_path(self, path: str):
    """
    Remove every chunk row that belongs to the given file path and commit.

    Args:
        path: File path whose chunks should be deleted
    """
    bound = (path,)
    self.conn.execute("""
        DELETE FROM chunks WHERE path = ?
    """, bound)
    self.conn.commit()
def get_file_hash(self, path: str) -> Optional[str]:
    """
    Return the content hash recorded for a file in the files table.

    Args:
        path: File path to look up

    Returns:
        The stored hash, or None when the path has no entry
    """
    record = self.conn.execute("""
        SELECT hash FROM files WHERE path = ?
    """, (path,)).fetchone()
    if record:
        return record['hash']
    return None
def update_file_metadata(self, path: str, source: str, file_hash: str, mtime: int, size: int):
    """
    Insert or replace the files-table entry for a path and commit.

    `updated_at` is set by SQLite to the current epoch second.

    Args:
        path: File path (primary key)
        source: Source label of the file
        file_hash: Content hash of the file
        mtime: Modification time to record
        size: File size to record
    """
    record = (path, source, file_hash, mtime, size)
    self.conn.execute("""
        INSERT OR REPLACE INTO files (path, source, hash, mtime, size, updated_at)
        VALUES (?, ?, ?, ?, ?, strftime('%s', 'now'))
    """, record)
    self.conn.commit()
def get_stats(self) -> Dict[str, int]:
    """
    Count the rows currently stored.

    Returns:
        Dict with 'chunks' (rows in chunks) and 'files' (rows in files)
    """
    queries = (
        ('chunks', "SELECT COUNT(*) as cnt FROM chunks"),
        ('files', "SELECT COUNT(*) as cnt FROM files"),
    )
    stats = {}
    for label, sql in queries:
        stats[label] = self.conn.execute(sql).fetchone()['cnt']
    return stats
def close(self):
    """Close the underlying database connection, if there is one."""
    connection = self.conn
    if not connection:
        return
    connection.close()
# Helper methods
def _row_to_chunk(self, row) -> MemoryChunk:
    """
    Build a MemoryChunk from a row of the chunks table.

    The embedding and metadata columns hold JSON text; they are decoded
    when non-empty and become None otherwise.

    Args:
        row: Row supporting access by column name

    Returns:
        The reconstructed MemoryChunk
    """
    raw_embedding = row['embedding']
    raw_metadata = row['metadata']
    decoded_embedding = json.loads(raw_embedding) if raw_embedding else None
    decoded_metadata = json.loads(raw_metadata) if raw_metadata else None
    return MemoryChunk(
        id=row['id'],
        user_id=row['user_id'],
        scope=row['scope'],
        source=row['source'],
        path=row['path'],
        start_line=row['start_line'],
        end_line=row['end_line'],
        text=row['text'],
        embedding=decoded_embedding,
        hash=row['hash'],
        metadata=decoded_metadata
    )
@staticmethod
def _cosine_similarity(vec1: List[float], vec2: List[float]) -> float:
"""Calculate cosine similarity between two vectors"""
if len(vec1) != len(vec2):
return 0.0
dot_product = sum(a * b for a, b in zip(vec1, vec2))
norm1 = sum(a * a for a in vec1) ** 0.5
norm2 = sum(b * b for b in vec2) ** 0.5
if norm1 == 0 or norm2 == 0:
return 0.0
return dot_product / (norm1 * norm2)
@staticmethod
def _build_fts_query(raw_query: str) -> Optional[str]:
"""Build FTS5 query from raw text"""
import re
tokens = re.findall(r'[A-Za-z0-9_\u4e00-\u9fff]+', raw_query)
if not tokens:
return None
quoted = [f'"{t}"' for t in tokens]
return ' AND '.join(quoted)
@staticmethod
def _bm25_rank_to_score(rank: float) -> float:
"""Convert BM25 rank to 0-1 score"""
normalized = max(0, rank) if rank is not None else 999
return 1 / (1 + normalized)
@staticmethod
def _truncate_text(text: str, max_chars: int) -> str:
"""Truncate text to max characters"""
if len(text) <= max_chars:
return text
return text[:max_chars] + "..."
@staticmethod
def compute_hash(content: str) -> str:
"""Compute SHA256 hash of content"""
return hashlib.sha256(content.encode('utf-8')).hexdigest()

235
agent/memory/summarizer.py Normal file
View File

@@ -0,0 +1,235 @@
"""
Memory flush manager
Triggers memory flush before context compaction (similar to clawdbot)
"""
from typing import Optional, Callable, Any
from pathlib import Path
from datetime import datetime
class MemoryFlushManager:
    """
    Coordinates "memory flush" turns that run before context compaction.

    Modeled on clawdbot's memory flush mechanism:
    - a flush is due once the session approaches the token limit
    - the flush is a silent agent turn that writes memories to disk
    - daily notes live in memory/YYYY-MM-DD.md
    - long-term curated memories live in MEMORY.md
    """

    def __init__(
        self,
        workspace_dir: Path,
        llm_model: Optional[Any] = None
    ):
        """
        Set up the manager and make sure <workspace>/memory exists.

        Args:
            workspace_dir: Workspace directory
            llm_model: LLM model for agent execution (optional)
        """
        self.workspace_dir = workspace_dir
        self.llm_model = llm_model
        self.memory_dir = workspace_dir / "memory"
        self.memory_dir.mkdir(parents=True, exist_ok=True)

        # Bookkeeping for the most recent successful flush
        self.last_flush_token_count: Optional[int] = None
        self.last_flush_timestamp: Optional[datetime] = None

    def should_flush(
        self,
        current_tokens: int,
        context_window: int,
        reserve_tokens: int = 20000,
        soft_threshold: int = 4000
    ) -> bool:
        """
        Decide whether a memory flush is due.

        Mirrors clawdbot's shouldRunMemoryFlush:
            threshold = contextWindow - reserveTokens - softThreshold

        Args:
            current_tokens: Current session token count
            context_window: Model's context window size
            reserve_tokens: Reserve tokens for compaction overhead
            soft_threshold: Trigger flush N tokens before threshold

        Returns:
            True if flush should run
        """
        if current_tokens <= 0:
            return False

        threshold = max(0, context_window - reserve_tokens - soft_threshold)
        if threshold <= 0 or current_tokens < threshold:
            return False

        # Skip when the last flush happened within soft_threshold tokens of
        # the current count (same compaction cycle)
        previous = self.last_flush_token_count
        flushed_recently = previous is not None and current_tokens <= previous + soft_threshold
        return not flushed_recently

    def _scoped_dir(self, user_id: Optional[str]) -> Path:
        """Return the memory directory for user_id (created on demand), or the shared one."""
        if not user_id:
            return self.memory_dir
        user_dir = self.memory_dir / "users" / user_id
        user_dir.mkdir(parents=True, exist_ok=True)
        return user_dir

    def get_today_memory_file(self, user_id: Optional[str] = None) -> Path:
        """
        Path of today's daily-notes file: memory/YYYY-MM-DD.md

        Args:
            user_id: Optional user ID; selects memory/users/<user_id>/
                (the directory is created if missing)

        Returns:
            Path to today's memory file
        """
        date_stamp = datetime.now().strftime("%Y-%m-%d")
        return self._scoped_dir(user_id) / f"{date_stamp}.md"

    def get_main_memory_file(self, user_id: Optional[str] = None) -> Path:
        """
        Path of the long-term memory file: memory/MEMORY.md

        Args:
            user_id: Optional user ID; selects memory/users/<user_id>/
                (the directory is created if missing)

        Returns:
            Path to main memory file
        """
        return self._scoped_dir(user_id) / "MEMORY.md"

    def create_flush_prompt(self) -> str:
        """
        Build the user prompt for the flush turn, naming today's notes file.

        Counterpart of clawdbot's DEFAULT_MEMORY_FLUSH_PROMPT
        """
        date_stamp = datetime.now().strftime("%Y-%m-%d")
        sentences = [
            "Pre-compaction memory flush.",
            f"Store durable memories now (use memory/{date_stamp}.md for daily notes; create memory/ if needed).",
            "If nothing to store, reply with NO_REPLY.",
        ]
        return " ".join(sentences)

    def create_flush_system_prompt(self) -> str:
        """
        Build the system prompt for the flush turn.

        Counterpart of clawdbot's DEFAULT_MEMORY_FLUSH_SYSTEM_PROMPT
        """
        sentences = [
            "Pre-compaction memory flush turn.",
            "The session is near auto-compaction; capture durable memories to disk.",
            "You may reply, but usually NO_REPLY is correct.",
        ]
        return " ".join(sentences)

    async def execute_flush(
        self,
        agent_executor: Callable,
        current_tokens: int,
        user_id: Optional[str] = None,
        **executor_kwargs
    ) -> bool:
        """
        Run one silent agent turn that asks the agent to persist memories.

        On success the token count and time of this flush are recorded.
        Any exception is reported with print() and turned into False.

        Args:
            agent_executor: Awaitable callable invoked with prompt,
                system_prompt, silent=True and **executor_kwargs
            current_tokens: Current token count
            user_id: Optional user ID (not used by this method)
            **executor_kwargs: Additional kwargs for agent executor

        Returns:
            True if flush completed successfully
        """
        try:
            flush_prompt = self.create_flush_prompt()
            flush_system_prompt = self.create_flush_system_prompt()

            # Silent turn: no user-visible reply is expected (NO_REPLY)
            await agent_executor(
                prompt=flush_prompt,
                system_prompt=flush_system_prompt,
                silent=True,
                **executor_kwargs
            )
        except Exception as e:
            print(f"Memory flush failed: {e}")
            return False
        else:
            # Remember when (in tokens and wall time) this flush happened
            self.last_flush_token_count = current_tokens
            self.last_flush_timestamp = datetime.now()
            return True

    def get_status(self) -> dict:
        """Report the last flush bookkeeping and the shared memory file paths."""
        stamp = self.last_flush_timestamp
        return {
            'last_flush_tokens': self.last_flush_token_count,
            'last_flush_time': stamp.isoformat() if stamp else None,
            'today_file': str(self.get_today_memory_file()),
            'main_file': str(self.get_main_memory_file())
        }
def create_memory_files_if_needed(workspace_dir: Path, user_id: Optional[str] = None):
    """
    Make sure the default memory files exist, without touching existing ones.

    Creates <workspace>/memory (and memory/users/<user_id> when a user is
    given), an empty MEMORY.md, and today's YYYY-MM-DD.md with a short header.

    Args:
        workspace_dir: Workspace directory
        user_id: Optional user ID for user-specific files
    """
    memory_dir = workspace_dir / "memory"
    memory_dir.mkdir(parents=True, exist_ok=True)

    # Both files live in the same directory: per-user or shared
    if user_id:
        target_dir = memory_dir / "users" / user_id
        target_dir.mkdir(parents=True, exist_ok=True)
    else:
        target_dir = memory_dir

    main_memory = target_dir / "MEMORY.md"
    if not main_memory.exists():
        # Deliberately empty (no "Memory" header), following clawdbot's
        # approach: memories should blend naturally into context
        main_memory.write_text("")

    today = datetime.now().strftime("%Y-%m-%d")
    today_memory = target_dir / f"{today}.md"
    if not today_memory.exists():
        header = f"# Daily Memory: {today}\n\n"
        intro = "Day-to-day notes and running context.\n\n"
        today_memory.write_text(header + intro)

View File

@@ -0,0 +1,10 @@
"""
Memory tools for AgentMesh
Provides memory_search and memory_get tools for agents
"""
from agent.memory.tools.memory_search import MemorySearchTool
from agent.memory.tools.memory_get import MemoryGetTool
__all__ = ['MemorySearchTool', 'MemoryGetTool']

View File

@@ -0,0 +1,118 @@
"""
Memory get tool
Allows agents to read specific sections from memory files
"""
from typing import Dict, Any, Optional
from pathlib import Path
from agent.tools.base_tool import BaseTool
from agent.memory.manager import MemoryManager
class MemoryGetTool(BaseTool):
    """Tool for reading memory file contents"""

    def __init__(self, memory_manager: MemoryManager):
        """
        Initialize memory get tool

        Args:
            memory_manager: MemoryManager instance
        """
        super().__init__()
        self.memory_manager = memory_manager
        self._name = "memory_get"
        self._description = (
            "Read specific memory file content by path and line range. "
            "Use after memory_search to get full context from historical memory files."
        )

    @property
    def name(self) -> str:
        """Tool name exposed to the model."""
        return self._name

    @property
    def description(self) -> str:
        """Tool description exposed to the model."""
        return self._description

    @property
    def parameters(self) -> Dict[str, Any]:
        """JSON-schema description of the arguments accepted by execute()."""
        return {
            "type": "object",
            "properties": {
                "path": {
                    "type": "string",
                    "description": "Relative path to the memory file (e.g., 'MEMORY.md', 'memory/2024-01-29.md')"
                },
                "start_line": {
                    "type": "integer",
                    "description": "Starting line number (optional, default: 1)",
                    "default": 1
                },
                "num_lines": {
                    "type": "integer",
                    "description": "Number of lines to read (optional, reads all if not specified)"
                }
            },
            "required": ["path"]
        }

    async def execute(self, **kwargs) -> str:
        """
        Execute memory file read

        The requested path is resolved against the memory workspace and
        must stay inside it: the path comes from the model, so "../"
        segments or absolute paths are rejected instead of being read.

        Args:
            path: File path, relative to the workspace
            start_line: 1-based start line (values below 1 are treated as 1)
            num_lines: Number of lines; all remaining lines when omitted or 0

        Returns:
            A "File:"/"Lines:" header followed by the selected content, or
            a string starting with "Error" when the request cannot be served
        """
        path = kwargs.get("path")
        start_line = kwargs.get("start_line", 1)
        num_lines = kwargs.get("num_lines")

        if not path:
            return "Error: path parameter is required"

        try:
            workspace_dir = Path(self.memory_manager.config.get_workspace()).resolve()
            file_path = (workspace_dir / path).resolve()

            # Containment check: relative_to raises ValueError when the
            # resolved target is not located under the workspace
            try:
                file_path.relative_to(workspace_dir)
            except ValueError:
                return f"Error: Path is outside the memory workspace: {path}"

            if not file_path.is_file():
                return f"Error: File not found: {path}"

            # Read as UTF-8 explicitly so the result does not depend on the
            # platform's default encoding
            content = file_path.read_text(encoding="utf-8")
            lines = content.split('\n')

            # Handle line range
            if start_line < 1:
                start_line = 1
            start_idx = start_line - 1

            if num_lines:
                end_idx = start_idx + num_lines
                selected_lines = lines[start_idx:end_idx]
            else:
                selected_lines = lines[start_idx:]

            result = '\n'.join(selected_lines)

            # Add metadata
            total_lines = len(lines)
            shown_lines = len(selected_lines)
            output = [
                f"File: {path}",
                f"Lines: {start_line}-{start_line + shown_lines - 1} (total: {total_lines})",
                "",
                result
            ]
            return '\n'.join(output)

        except Exception as e:
            # Tool errors are reported to the model as text rather than raised
            return f"Error reading memory file: {str(e)}"

View File

@@ -0,0 +1,106 @@
"""
Memory search tool
Allows agents to search their memory using semantic and keyword search
"""
from typing import Dict, Any, Optional
from agent.tools.base_tool import BaseTool
from agent.memory.manager import MemoryManager
class MemorySearchTool(BaseTool):
    """Agent tool that queries stored memories through the MemoryManager."""

    def __init__(self, memory_manager: MemoryManager, user_id: Optional[str] = None):
        """
        Create the search tool.

        Args:
            memory_manager: MemoryManager instance used to run searches
            user_id: Optional user ID passed along with every search
        """
        super().__init__()
        self.memory_manager = memory_manager
        self.user_id = user_id
        self._name = "memory_search"
        self._description = (
            "Search historical memory files (beyond today/yesterday) using semantic and keyword search. "
            "Recent context (MEMORY.md + today + yesterday) is already loaded. "
            "Use this ONLY for older dates, specific past events, or when current context lacks needed info."
        )

    @property
    def name(self) -> str:
        """Tool name exposed to the model."""
        return self._name

    @property
    def description(self) -> str:
        """Tool description exposed to the model."""
        return self._description

    @property
    def parameters(self) -> Dict[str, Any]:
        """JSON-schema description of the arguments accepted by execute()."""
        query_schema = {
            "type": "string",
            "description": "Search query (can be natural language question or keywords)"
        }
        max_results_schema = {
            "type": "integer",
            "description": "Maximum number of results to return (default: 10)",
            "default": 10
        }
        min_score_schema = {
            "type": "number",
            "description": "Minimum relevance score (0-1, default: 0.3)",
            "default": 0.3
        }
        return {
            "type": "object",
            "properties": {
                "query": query_schema,
                "max_results": max_results_schema,
                "min_score": min_score_schema
            },
            "required": ["query"]
        }

    async def execute(self, **kwargs) -> str:
        """
        Run a memory search and render the hits as text.

        Args:
            query: Search query
            max_results: Maximum results (default 10)
            min_score: Minimum score (default 0.3)

        Returns:
            A numbered list of hits (path, line range, score, snippet), a
            "no relevant memories" message, or a string starting with
            "Error" when the query is missing or the search raises
        """
        query = kwargs.get("query")
        if not query:
            return "Error: query parameter is required"

        max_results = kwargs.get("max_results", 10)
        min_score = kwargs.get("min_score", 0.3)

        try:
            hits = await self.memory_manager.search(
                query=query,
                user_id=self.user_id,
                max_results=max_results,
                min_score=min_score,
                include_shared=True
            )
            if not hits:
                return f"No relevant memories found for query: {query}"

            # One header line, then three lines per hit
            report = [f"Found {len(hits)} relevant memories:\n"]
            for position, hit in enumerate(hits, 1):
                report.extend([
                    f"\n{position}. {hit.path} (lines {hit.start_line}-{hit.end_line})",
                    f"   Score: {hit.score:.3f}",
                    f"   Snippet: {hit.snippet}",
                ])
            return "\n".join(report)
        except Exception as e:
            return f"Error searching memory: {str(e)}"

View File

@@ -0,0 +1,20 @@
from .agent import Agent
from .agent_stream import AgentStreamExecutor
from .task import Task, TaskType, TaskStatus
from .result import AgentResult, AgentAction, AgentActionType, ToolResult
from .models import LLMModel, LLMRequest, ModelFactory
__all__ = [
'Agent',
'AgentStreamExecutor',
'Task',
'TaskType',
'TaskStatus',
'AgentResult',
'AgentAction',
'AgentActionType',
'ToolResult',
'LLMModel',
'LLMRequest',
'ModelFactory'
]

292
agent/protocol/agent.py Normal file
View File

@@ -0,0 +1,292 @@
import json
import time
from common.log import logger
from agent.protocol.models import LLMRequest, LLMModel
from agent.protocol.agent_stream import AgentStreamExecutor
from agent.protocol.result import AgentAction, AgentActionType, ToolResult, AgentResult
from agent.tools.base_tool import BaseTool, ToolStage
class Agent:
    """Holds an agent's prompt, model, tools and message history, and runs streaming tool-call turns."""

    def __init__(self, system_prompt: str, description: str = "AI Agent", model: LLMModel = None,
                 tools=None, output_mode="print", max_steps=100, max_context_tokens=None,
                 context_reserve_tokens=None, memory_manager=None, name: str = None):
        """
        Create an agent.

        :param system_prompt: The system prompt for the agent.
        :param description: A description of the agent.
        :param model: An instance of LLMModel to be used by the agent.
        :param tools: Optional list of tools for the agent to use.
        :param output_mode: "print" for console output; any other value routes output to the logger
        :param max_steps: Maximum number of steps the agent can take (default: 100)
        :param max_context_tokens: Maximum tokens to keep in context (default: None)
        :param context_reserve_tokens: Reserve tokens for new requests (default: None, derived from the context window)
        :param memory_manager: Optional MemoryManager instance for memory operations
        :param name: [Deprecated] The name of the agent (no longer used in single-agent system)
        """
        self.name = name or "Agent"
        self.system_prompt = system_prompt
        self.model: LLMModel = model
        self.description = description
        self.tools: list = []
        self.max_steps = max_steps  # upper bound on tool-call steps
        self.max_context_tokens = max_context_tokens
        self.context_reserve_tokens = context_reserve_tokens
        self.captured_actions = []  # AgentAction records of tool use
        self.output_mode = output_mode
        self.last_usage = None  # usage info of the last API response
        self.messages = []  # message history shared across run_stream calls
        self.memory_manager = memory_manager

        # Register tools last: add_tool needs self.model and self.tools
        for tool in tools or []:
            self.add_tool(tool)

    def add_tool(self, tool: BaseTool):
        """
        Register a tool instance with the agent and hand it the agent's model.

        :param tool: The tool instance to add
        """
        tool.model = self.model
        self.tools.append(tool)

    def _get_model_context_window(self) -> int:
        """
        Guess the model's context window (in tokens) from its name.

        Matching is done on the lower-cased model name:
        - 'claude-3' / 'claude-sonnet': 200000
        - 'gpt-4': 128000 with 'turbo' or '128k', 32000 with '32k', else 8000
        - 'gpt-3.5': 16000 with '16k', else 4000
        - 'deepseek': 64000
        - anything else, or no model: 10000

        :return: Context window size in tokens
        """
        fallback_window = 10000
        if not (self.model and hasattr(self.model, 'model')):
            return fallback_window

        model_name = self.model.model.lower()
        if 'claude-3' in model_name or 'claude-sonnet' in model_name:
            return 200000
        if 'gpt-4' in model_name:
            if 'turbo' in model_name or '128k' in model_name:
                return 128000
            return 32000 if '32k' in model_name else 8000
        if 'gpt-3.5' in model_name:
            return 16000 if '16k' in model_name else 4000
        if 'deepseek' in model_name:
            return 64000
        return fallback_window

    def _get_context_reserve_tokens(self) -> int:
        """
        Number of tokens to keep free for new requests.

        Uses the configured value when set; otherwise 20% of the context
        window, but never less than 4000.

        :return: Number of tokens to reserve
        """
        configured = self.context_reserve_tokens
        if configured is not None:
            return configured
        window = self._get_model_context_window()
        return max(4000, int(window * 0.2))

    def _estimate_message_tokens(self, message: dict) -> int:
        """
        Estimate a message's token count as characters // 4 (minimum 1).

        For list content, only 'text' parts (by length) and 'image' parts
        (4800 characters each, i.e. ~1200 tokens) are counted.

        :param message: Message dict with 'role' and 'content'
        :return: Estimated token count
        """
        content = message.get('content', '')
        if isinstance(content, str):
            return max(1, len(content) // 4)
        if not isinstance(content, list):
            return 1

        char_total = 0
        for part in content:
            if not isinstance(part, dict):
                continue
            part_type = part.get('type')
            if part_type == 'text':
                char_total += len(part.get('text', ''))
            elif part_type == 'image':
                char_total += 4800
        return max(1, char_total // 4)

    def _find_tool(self, tool_name: str):
        """
        Look up a callable tool by name.

        Only the first tool with a matching name is considered. It is
        returned (with model and context attached) when it is a
        pre-process tool; post-process tools yield None and a warning.
        """
        candidate = next((tool for tool in self.tools if tool.name == tool_name), None)
        if candidate is None:
            return None
        if candidate.stage == ToolStage.PRE_PROCESS:
            candidate.model = self.model
            candidate.context = self  # give the tool access to the agent
            return candidate
        logger.warning(f"Tool {tool_name} is a post-process tool and cannot be called directly.")
        return None

    def output(self, message="", end="\n"):
        """Print the message in "print" mode; otherwise log it when non-empty."""
        if self.output_mode == "print":
            print(message, end=end)
            return
        if message:
            logger.info(message)

    def _execute_post_process_tools(self):
        """Run every post-process tool once, recording and reporting each result."""
        pending = [tool for tool in self.tools if tool.stage == ToolStage.POST_PROCESS]
        for tool in pending:
            tool.context = self

            # Time the call; the tool receives empty parameters
            started_at = time.time()
            result = tool.execute({})
            elapsed = time.time() - started_at

            failure_text = str(result.result) if result.status == "error" else None
            self.capture_tool_use(
                tool_name=tool.name,
                input_params={},
                output=result.result,
                status=result.status,
                error_message=failure_text,
                execution_time=elapsed
            )

            if result.status == "success":
                self.output(f"\n🛠️ {tool.name}: {json.dumps(result.result)}")
            else:
                self.output(f"\n🛠️ {tool.name}: {json.dumps({'status': 'error', 'message': str(result.result)})}")

    def capture_tool_use(self, tool_name, input_params, output, status, thought=None, error_message=None,
                         execution_time=0.0):
        """
        Record one tool invocation in captured_actions.

        :param tool_name: Name of the tool used
        :param input_params: Parameters passed to the tool
        :param output: Output from the tool
        :param status: Status of the tool execution
        :param thought: thought content
        :param error_message: Error message if the tool execution failed
        :param execution_time: Time taken to execute the tool
        :return: The AgentAction that was appended
        """
        # Use the agent's own id when it has one; otherwise the object id
        if hasattr(self, 'id'):
            agent_id = self.id
        else:
            agent_id = str(id(self))

        recorded = AgentAction(
            agent_id=agent_id,
            agent_name=self.name,
            action_type=AgentActionType.TOOL_USE,
            tool_result=ToolResult(
                tool_name=tool_name,
                input_params=input_params,
                output=output,
                status=status,
                error_message=error_message,
                execution_time=execution_time
            ),
            thought=thought
        )
        self.captured_actions.append(recorded)
        return recorded

    def run_stream(self, user_message: str, on_event=None, clear_history: bool = False) -> str:
        """
        Run one user message through the streaming tool-call loop.

        The conversation history is kept on the agent, so consecutive calls
        continue the same conversation unless clear_history is set.
        Post-process tools run after the loop finishes.

        Args:
            user_message: User message
            on_event: Event callback function callback(event: dict)
                event = {"type": str, "timestamp": float, "data": dict}
            clear_history: If True, clear conversation history before this call (default: False)

        Returns:
            Final response text

        Raises:
            ValueError: If the agent has no model

        Example:
            # Multi-turn conversation with memory
            response1 = agent.run_stream("My name is Alice")
            response2 = agent.run_stream("What's my name?")  # Will remember Alice
            # Single-turn without memory
            response = agent.run_stream("Hello", clear_history=True)
        """
        if clear_history:
            self.messages = []

        if not self.model:
            raise ValueError("No model available for agent")

        # The executor works on the agent's own history list
        stream_executor = AgentStreamExecutor(
            agent=self,
            model=self.model,
            system_prompt=self.system_prompt,
            tools=self.tools,
            max_turns=self.max_steps,
            on_event=on_event,
            messages=self.messages
        )
        final_text = stream_executor.run_stream(user_message)

        # Take over whatever history the executor ended up with
        self.messages = stream_executor.messages
        self._execute_post_process_tools()
        return final_text

    def clear_history(self):
        """Reset the conversation history and the captured tool actions."""
        self.messages = []
        self.captured_actions = []

View File

@@ -0,0 +1,461 @@
"""
Agent Stream Execution Module - Multi-turn reasoning based on tool-call
Provides streaming output, event system, and complete tool-call loop
"""
import json
import time
from typing import List, Dict, Any, Optional, Callable
from common.log import logger
from agent.protocol.models import LLMRequest, LLMModel
from agent.tools.base_tool import BaseTool, ToolResult
class AgentStreamExecutor:
"""
Agent Stream Executor
Handles multi-turn reasoning loop based on tool-call:
1. LLM generates response (may include tool calls)
2. Execute tools
3. Return results to LLM
4. Repeat until no more tool calls
"""
def __init__(
    self,
    agent,  # Agent instance
    model: LLMModel,
    system_prompt: str,
    tools: List[BaseTool],
    max_turns: int = 50,
    on_event: Optional[Callable] = None,
    messages: Optional[List[Dict]] = None
):
    """
    Set up the executor.

    Args:
        agent: Agent instance (for accessing context)
        model: LLM model
        system_prompt: System prompt
        tools: Available tools; a list is indexed by tool name, any
            other value is stored as given
        max_turns: Maximum number of turns
        on_event: Event callback function
        messages: Existing message history to continue (the same list
            object is used); a new empty list when None
    """
    self.agent = agent
    self.model = model
    self.system_prompt = system_prompt

    # Lists become a name -> tool mapping for lookup by tool-call name
    if isinstance(tools, list):
        self.tools = {tool.name: tool for tool in tools}
    else:
        self.tools = tools

    self.max_turns = max_turns
    self.on_event = on_event
    self.messages = [] if messages is None else messages
def _emit_event(self, event_type: str, data: dict = None):
"""Emit event"""
if self.on_event:
try:
self.on_event({
"type": event_type,
"timestamp": time.time(),
"data": data or {}
})
except Exception as e:
logger.error(f"Event callback error: {e}")
def run_stream(self, user_message: str) -> str:
    """
    Execute streaming reasoning loop

    Appends the user message to the history, then alternates between
    calling the LLM and executing the tool calls it requests until the LLM
    answers without tool calls or max_turns is reached. Events are emitted
    for agent start/end, each turn, and (when recommended) memory flush.

    Args:
        user_message: User message

    Returns:
        Final response text (the text of the last LLM reply)

    Raises:
        Exception: Any error from the LLM call or the loop is logged,
            reported as an "error" event and re-raised
    """
    # Log user message
    logger.info(f"\n{'='*50}")
    logger.info(f"👤 用户: {user_message}")
    logger.info(f"{'='*50}")

    # Add user message (Claude format - use content blocks for consistency)
    self.messages.append({
        "role": "user",
        "content": [
            {
                "type": "text",
                "text": user_message
            }
        ]
    })

    self._emit_event("agent_start")

    final_response = ""
    turn = 0

    try:
        while turn < self.max_turns:
            turn += 1
            logger.info(f"\n{'='*50}{turn}{'='*50}")
            self._emit_event("turn_start", {"turn": turn})

            # Check if memory flush is needed (before calling LLM)
            if self.agent.memory_manager and hasattr(self.agent, 'last_usage'):
                usage = self.agent.last_usage
                if usage and 'input_tokens' in usage:
                    current_tokens = usage.get('input_tokens', 0)
                    context_window = self.agent._get_model_context_window()
                    reserve_tokens = self.agent.context_reserve_tokens or 20000
                    if self.agent.memory_manager.should_flush_memory(
                        current_tokens=current_tokens,
                        context_window=context_window,
                        reserve_tokens=reserve_tokens
                    ):
                        self._emit_event("memory_flush_start", {
                            "current_tokens": current_tokens,
                            "threshold": context_window - reserve_tokens - 4000
                        })
                        # TODO: Execute memory flush in background
                        # This would require async support
                        logger.info(f"Memory flush recommended at {current_tokens} tokens")

            # Call LLM
            assistant_msg, tool_calls = self._call_llm_stream()
            final_response = assistant_msg

            # No tool calls, end loop
            if not tool_calls:
                if assistant_msg:
                    logger.info(f"💭 {assistant_msg[:150]}{'...' if len(assistant_msg) > 150 else ''}")
                logger.info(f"✅ 完成 (无工具调用)")
                self._emit_event("turn_end", {
                    "turn": turn,
                    "has_tool_calls": False
                })
                break

            # Log tool calls in compact format
            tool_names = [tc['name'] for tc in tool_calls]
            logger.info(f"🔧 调用工具: {', '.join(tool_names)}")

            # Execute tools
            tool_results = []
            tool_result_blocks = []
            for tool_call in tool_calls:
                result = self._execute_tool(tool_call)
                tool_results.append(result)

                # Log tool result in compact format
                status_emoji = "" if result.get("status") == "success" else ""
                result_str = str(result.get('result', ''))
                logger.info(f" {status_emoji} {tool_call['name']} ({result.get('execution_time', 0):.2f}s): {result_str[:200]}{'...' if len(result_str) > 200 else ''}")

                # Build tool result block (Claude format)
                # Content should be a string representation of the result
                result_content = json.dumps(result) if not isinstance(result, str) else result
                tool_result_blocks.append({
                    "type": "tool_result",
                    "tool_use_id": tool_call["id"],
                    "content": result_content
                })

            # Add tool results to message history as user message (Claude format)
            self.messages.append({
                "role": "user",
                "content": tool_result_blocks
            })

            self._emit_event("turn_end", {
                "turn": turn,
                "has_tool_calls": True,
                "tool_count": len(tool_calls)
            })
        else:
            # Reached only when the loop ran out of turns. A run that
            # finishes via `break` on the last allowed turn completed
            # normally and must not be reported as hitting the limit.
            logger.warning(f"⚠️ 已达到最大轮数限制: {self.max_turns}")

    except Exception as e:
        logger.error(f"❌ Agent执行错误: {e}")
        self._emit_event("error", {"error": str(e)})
        raise
    finally:
        logger.info(f"{'='*50} 完成({turn}轮) {'='*50}\n")
        self._emit_event("agent_end", {"final_response": final_response})

    return final_response
def _call_llm_stream(self) -> tuple[str, List[Dict]]:
    """
    Call LLM with streaming

    Trims and prepares the history, sends it with the tool schemas, and
    accumulates the streamed text and tool-call fragments. The resulting
    assistant message (text block plus tool_use blocks) is appended to the
    history when it has any content.

    Stream chunks are handled defensively: chunks without choices, and
    null delta / tool_calls / function / arguments fields, are skipped
    instead of raising, and a tool call's id or name is only replaced by a
    non-empty value.

    Returns:
        (response_text, tool_calls), where each tool call is
        {"id": str, "name": str, "arguments": dict}. Arguments that are
        not valid JSON are logged and replaced by {}.

    Raises:
        Exception: When the stream yields an error chunk or the model
            call fails; the error is logged and re-raised
    """
    # Trim messages if needed (using agent's context management)
    self._trim_messages()

    # Prepare messages
    messages = self._prepare_messages()

    # Debug: log message structure
    logger.debug(f"Sending {len(messages)} messages to LLM")
    for i, msg in enumerate(messages):
        role = msg.get("role", "unknown")
        content = msg.get("content", "")
        if isinstance(content, list):
            content_types = [c.get("type") for c in content if isinstance(c, dict)]
            logger.debug(f" Message {i}: role={role}, content_blocks={content_types}")
        else:
            logger.debug(f" Message {i}: role={role}, content_length={len(str(content))}")

    # Prepare tool definitions (OpenAI/Claude format)
    tools_schema = None
    if self.tools:
        tools_schema = [
            {
                "name": tool.name,
                "description": tool.description,
                "input_schema": tool.params  # Claude uses input_schema
            }
            for tool in self.tools.values()
        ]

    # Create request
    request = LLMRequest(
        messages=messages,
        temperature=0,
        stream=True,
        tools=tools_schema,
        system=self.system_prompt  # Pass system prompt separately for Claude API
    )

    self._emit_event("message_start", {"role": "assistant"})

    # Streaming response
    full_content = ""
    tool_calls_buffer = {}  # {index: {id, name, arguments}}

    try:
        stream = self.model.call_stream(request)
        for chunk in stream:
            if not isinstance(chunk, dict):
                continue

            # Check for errors
            if chunk.get("error"):
                error_msg = chunk.get("message", "Unknown error")
                status_code = chunk.get("status_code", "N/A")
                logger.error(f"API Error: {error_msg} (Status: {status_code})")
                logger.error(f"Full error chunk: {chunk}")
                raise Exception(f"{error_msg} (Status: {status_code})")

            # Parse chunk; chunks with no choices carry nothing to accumulate
            choices = chunk.get("choices")
            if not choices:
                continue
            delta = choices[0].get("delta") or {}

            # Handle text content
            content_delta = delta.get("content")
            if content_delta:
                full_content += content_delta
                self._emit_event("message_update", {"delta": content_delta})

            # Handle tool calls: fragments of one call share an index
            for tc_delta in delta.get("tool_calls") or []:
                index = tc_delta.get("index", 0)
                entry = tool_calls_buffer.setdefault(index, {
                    "id": "",
                    "name": "",
                    "arguments": ""
                })
                # Later fragments may repeat id/name as null or empty;
                # keep the value already collected in that case
                if tc_delta.get("id"):
                    entry["id"] = tc_delta["id"]
                func = tc_delta.get("function") or {}
                if func.get("name"):
                    entry["name"] = func["name"]
                if func.get("arguments"):
                    entry["arguments"] += func["arguments"]
    except Exception as e:
        logger.error(f"LLM call error: {e}")
        raise

    # Parse tool calls
    tool_calls = []
    for idx in sorted(tool_calls_buffer.keys()):
        tc = tool_calls_buffer[idx]
        try:
            arguments = json.loads(tc["arguments"]) if tc["arguments"] else {}
        except json.JSONDecodeError:
            # Best effort: fall back to empty arguments and let the tool report what is missing
            logger.error(f"Failed to parse tool arguments: {tc['arguments']}")
            arguments = {}
        tool_calls.append({
            "id": tc["id"],
            "name": tc["name"],
            "arguments": arguments
        })

    # Add assistant message to history (Claude format uses content blocks)
    assistant_msg = {"role": "assistant", "content": []}

    # Add text content block if present
    if full_content:
        assistant_msg["content"].append({
            "type": "text",
            "text": full_content
        })

    # Add tool_use blocks if present
    for tc in tool_calls:
        assistant_msg["content"].append({
            "type": "tool_use",
            "id": tc["id"],
            "name": tc["name"],
            "input": tc["arguments"]
        })

    # Only append if content is not empty
    if assistant_msg["content"]:
        self.messages.append(assistant_msg)

    self._emit_event("message_end", {
        "content": full_content,
        "tool_calls": tool_calls
    })

    return full_content, tool_calls
def _execute_tool(self, tool_call: Dict) -> Dict[str, Any]:
    """
    Execute a single tool call and report the outcome as a plain dict.

    A ``tool_execution_start`` event is emitted before the call and a
    ``tool_execution_end`` event afterwards, on success and on failure alike.
    Failures never propagate to the caller: they are logged and returned as
    an error result.

    Args:
        tool_call: {"id": str, "name": str, "arguments": dict}

    Returns:
        {"status": ..., "result": ..., "execution_time": ...}. On failure
        ``status`` is "error", ``result`` is the error text and
        ``execution_time`` is 0.
    """
    tool_name = tool_call["name"]
    tool_id = tool_call["id"]
    arguments = tool_call["arguments"]

    self._emit_event("tool_execution_start", {
        "tool_call_id": tool_id,
        "tool_name": tool_name,
        "arguments": arguments
    })

    try:
        tool = self.tools.get(tool_name)
        if not tool:
            raise ValueError(f"Tool '{tool_name}' not found")

        # Set tool context
        tool.model = self.model
        tool.context = self.agent

        # Execute tool
        start_time = time.time()
        result: ToolResult = tool.execute_tool(arguments)
        execution_time = time.time() - start_time

        # A tool wrapper that swallows its own exception hands back None.
        # Report that explicitly instead of failing below with an opaque
        # "'NoneType' object has no attribute 'status'".
        if result is None:
            raise RuntimeError(
                f"Tool '{tool_name}' returned no result "
                f"(it probably raised an exception; check the tool logs)"
            )

        result_dict = {
            "status": result.status,
            "result": result.result,
            "execution_time": execution_time
        }

        self._emit_event("tool_execution_end", {
            "tool_call_id": tool_id,
            "tool_name": tool_name,
            **result_dict
        })
        return result_dict

    except Exception as e:
        # Deliberately broad: a failing tool must not abort the agent loop.
        logger.error(f"Tool execution error: {e}")
        error_result = {
            "status": "error",
            "result": str(e),
            "execution_time": 0
        }
        self._emit_event("tool_execution_end", {
            "tool_call_id": tool_id,
            "tool_name": tool_name,
            **error_result
        })
        return error_result
def _trim_messages(self):
    """
    Trim message history to stay within context limits.

    The token budget is the agent's context window minus its reserve. When
    the history plus the system prompt exceeds that budget, the oldest
    messages are dropped and the longest suffix of recent messages that fits
    is kept. The newest message is always kept, even when it alone exceeds
    the budget, so the history is never emptied.

    Does nothing when there are no messages or no agent is attached.
    """
    if not self.messages or not self.agent:
        return

    estimate = self.agent._estimate_message_tokens

    # Budget = context window minus the tokens reserved by the agent
    context_window = self.agent._get_model_context_window()
    reserve_tokens = self.agent._get_context_reserve_tokens()
    max_tokens = context_window - reserve_tokens

    # The system prompt is sent with every request, so it counts too
    system_tokens = estimate({"role": "system", "content": self.system_prompt})
    current_tokens = system_tokens + sum(estimate(msg) for msg in self.messages)

    # If under limit, no need to trim
    if current_tokens <= max_tokens:
        return

    # Walk from newest to oldest, collecting messages while they still fit.
    # Appending and reversing once avoids repeated O(n) insert(0, ...).
    available_tokens = max_tokens - system_tokens
    kept_newest_first = []
    accumulated_tokens = 0
    for msg in reversed(self.messages):
        msg_tokens = estimate(msg)
        if accumulated_tokens + msg_tokens > available_tokens:
            break
        kept_newest_first.append(msg)
        accumulated_tokens += msg_tokens

    if not kept_newest_first:
        # Not even the newest message fits. Dropping it would send an empty
        # conversation, so keep it and let the model/API report the overflow.
        newest = self.messages[-1]
        kept_newest_first.append(newest)
        accumulated_tokens = estimate(newest)
        logger.warning(
            f"Newest message alone (~{accumulated_tokens} tokens) exceeds the "
            f"available context budget ({available_tokens}); keeping it anyway"
        )

    old_count = len(self.messages)
    self.messages = kept_newest_first[::-1]
    new_count = len(self.messages)

    if old_count > new_count:
        logger.info(
            f"Context trimmed: {old_count} -> {new_count} messages "
            f"(~{current_tokens} -> ~{system_tokens + accumulated_tokens} tokens, "
            f"limit: {max_tokens})"
        )
def _prepare_messages(self) -> List[Dict[str, Any]]:
    """
    Return the conversation history that will be sent to the LLM.

    The system prompt is intentionally NOT included as a message: for the
    Claude API it travels separately (via the request's system parameter),
    so the history is handed over as-is.
    """
    outgoing = self.messages
    return outgoing

27
agent/protocol/context.py Normal file
View File

@@ -0,0 +1,27 @@
class TeamContext:
    """Shared state for a group of agents working on one task."""

    def __init__(self, name: str, description: str, rule: str, agents: list, max_steps: int = 100):
        """
        Initialize the TeamContext.

        :param name: The name of the group context.
        :param description: A description of the group context.
        :param rule: The rules governing the group context.
        :param agents: A list of agents in the context.
        :param max_steps: Upper bound stored for the step counter (default 100).
        """
        # Values supplied by the caller
        self.name, self.description, self.rule = name, description, rule
        self.agents = agents
        self.max_steps = max_steps

        # Task-related state, filled in later
        self.user_task = ""  # For backward compatibility
        self.task = None  # Will be a Task instance
        self.task_short_name = None  # Store the task directory name
        self.model = None  # Will be an instance of LLMModel

        # Execution bookkeeping
        self.agent_outputs: list = []  # List of agents that have been executed
        self.current_steps = 0
class AgentOutput:
    """Pairs an agent's name with the output it produced."""

    def __init__(self, agent_name: str, output: str):
        """Store the producing agent's name and its output text."""
        self.output = output
        self.agent_name = agent_name

57
agent/protocol/models.py Normal file
View File

@@ -0,0 +1,57 @@
"""
Models module for agent system.
Provides basic model classes needed by tools and bridge integration.
"""
from typing import Any, Dict, List, Optional
class LLMRequest:
    """
    Request model for LLM operations.

    Besides the named parameters, any extra keyword argument is stored as an
    attribute of the same name on the request.
    """

    def __init__(self, messages: List[Dict[str, str]] = None, model: Optional[str] = None,
                 temperature: float = 0.7, max_tokens: Optional[int] = None,
                 stream: bool = False, tools: Optional[List] = None, **kwargs):
        # A missing (or empty) message list becomes a fresh empty list
        self.messages = messages if messages else []
        self.model = model
        self.tools = tools
        self.stream = stream
        self.max_tokens = max_tokens
        self.temperature = temperature

        # Allow extra attributes
        for attr_name, attr_value in kwargs.items():
            setattr(self, attr_name, attr_value)
class LLMModel:
    """
    Base class for LLM models.

    Stores the model name and keeps every extra keyword argument in
    ``config``. Both call methods are placeholders that always raise.
    """

    def __init__(self, model: str = None, **kwargs):
        self.config = kwargs
        self.model = model

    def call(self, request: LLMRequest):
        """
        Call the model with a request.

        Placeholder implementation: always raises NotImplementedError.
        """
        message = "LLMModel.call not implemented in this context"
        raise NotImplementedError(message)

    def call_stream(self, request: LLMRequest):
        """
        Call the model with streaming.

        Placeholder implementation: always raises NotImplementedError.
        """
        message = "LLMModel.call_stream not implemented in this context"
        raise NotImplementedError(message)
class ModelFactory:
    """Factory for creating model instances (placeholder in this module)."""

    @staticmethod
    def create_model(model_type: str, **kwargs):
        """
        Create a model instance based on type.

        Placeholder implementation: always raises NotImplementedError.
        """
        message = "ModelFactory.create_model not implemented in this context"
        raise NotImplementedError(message)

96
agent/protocol/result.py Normal file
View File

@@ -0,0 +1,96 @@
import time
import uuid
from dataclasses import dataclass, field
from enum import Enum
from typing import List, Dict, Any, Optional
from agent.protocol.task import Task, TaskStatus
class AgentActionType(Enum):
    """Kinds of action an agent can record; each value is its string tag."""

    # The agent invoked a tool
    TOOL_USE = "tool_use"
    # The agent produced intermediate reasoning
    THINKING = "thinking"
    # The agent produced its final answer
    FINAL_ANSWER = "final_answer"
@dataclass
class ToolResult:
    """
    Record of one tool invocation.

    The first four fields are required and positional; the last two have
    defaults.
    """

    # Name of the tool used
    tool_name: str
    # Parameters passed to the tool
    input_params: Dict[str, Any]
    # Output from the tool
    output: Any
    # Status of the tool execution (success/error)
    status: str
    # Error message if the tool execution failed
    error_message: Optional[str] = None
    # Time taken to execute the tool
    execution_time: float = 0.0
@dataclass
class AgentAction:
    """
    One action taken by an agent.

    ``agent_id``, ``agent_name`` and ``action_type`` are required; every
    other field has a default. ``id`` and ``timestamp`` are generated per
    instance when not supplied.
    """

    # ID of the agent that performed the action
    agent_id: str
    # Name of the agent that performed the action
    agent_name: str
    # Type of action (tool use, thinking, final answer)
    action_type: AgentActionType
    # Unique identifier for the action (random UUID4 string by default)
    id: str = field(default_factory=lambda: str(uuid.uuid4()))
    # Content of the action (thought content, final answer content)
    content: str = ""
    # Tool use details if action_type is TOOL_USE
    tool_result: Optional[ToolResult] = None
    # Optional thought text attached to the action
    thought: Optional[str] = None
    # When the action was performed (time.time() at creation by default)
    timestamp: float = field(default_factory=time.time)
@dataclass
class AgentResult:
    """
    Outcome of an agent run.

    Build instances through :meth:`success` or :meth:`error` rather than
    filling ``status`` by hand.
    """

    # The final answer provided by the agent
    final_answer: str
    # Number of steps taken by the agent
    step_count: int
    # Status of the execution (success/error)
    status: str = "success"
    # Error message if execution failed
    error_message: Optional[str] = None

    @classmethod
    def success(cls, final_answer: str, step_count: int) -> "AgentResult":
        """Create a successful result (status stays at its default)."""
        fields = {"final_answer": final_answer, "step_count": step_count}
        return cls(**fields)

    @classmethod
    def error(cls, error_message: str, step_count: int = 0) -> "AgentResult":
        """Create an error result whose final answer is ``Error: <message>``."""
        fields = {
            "final_answer": f"Error: {error_message}",
            "step_count": step_count,
            "status": "error",
            "error_message": error_message,
        }
        return cls(**fields)

    @property
    def is_error(self) -> bool:
        """Check if the result represents an error"""
        failed = self.status == "error"
        return failed

95
agent/protocol/task.py Normal file
View File

@@ -0,0 +1,95 @@
import time
import uuid
from dataclasses import dataclass, field
from enum import Enum
from typing import Dict, Any, List
class TaskType(Enum):
    """Content kinds a task can carry; each value is its string tag."""

    # Single-medium tasks
    TEXT = "text"
    IMAGE = "image"
    VIDEO = "video"
    AUDIO = "audio"
    FILE = "file"
    # More than one kind of content
    MIXED = "mixed"
class TaskStatus(Enum):
    """Lifecycle states of a task; each value is its string tag."""

    # Initial state
    INIT = "init"
    # In progress
    PROCESSING = "processing"
    # Completed
    COMPLETED = "completed"
    # Failed
    FAILED = "failed"
@dataclass
class Task:
    """
    A unit of work handed to an agent.

    The class is a dataclass, but it defines its own ``__init__``; that
    hand-written constructor is the one in effect. It takes the text content
    first and reads every other field from keyword arguments, silently
    ignoring keywords it does not know.
    """

    # Unique identifier for the task
    id: str = field(default_factory=lambda: str(uuid.uuid4()))
    # The primary text content of the task
    content: str = ""
    # Type of the task
    type: TaskType = TaskType.TEXT
    # Current status of the task
    status: TaskStatus = TaskStatus.INIT
    # Timestamps of creation and of the last update
    created_at: float = field(default_factory=time.time)
    updated_at: float = field(default_factory=time.time)
    # Additional metadata for the task
    metadata: Dict[str, Any] = field(default_factory=dict)

    # Media content
    images: List[str] = field(default_factory=list)  # image URLs or base64 encoded images
    videos: List[str] = field(default_factory=list)  # video URLs
    audios: List[str] = field(default_factory=list)  # audio URLs or base64 encoded audios
    files: List[str] = field(default_factory=list)  # file URLs or paths

    def __init__(self, content: str = "", **kwargs):
        """
        Initialize a Task with content and optional keyword arguments.

        Args:
            content: The text content of the task
            **kwargs: Values for any of the other fields; missing ones get
                a fresh default (new UUID, current time, empty containers)
        """
        option = kwargs.get

        self.id = option('id', str(uuid.uuid4()))
        self.content = content
        self.type = option('type', TaskType.TEXT)
        self.status = option('status', TaskStatus.INIT)
        self.created_at = option('created_at', time.time())
        self.updated_at = option('updated_at', time.time())
        self.metadata = option('metadata', {})
        self.images = option('images', [])
        self.videos = option('videos', [])
        self.audios = option('audios', [])
        self.files = option('files', [])

    def get_text(self) -> str:
        """
        Get the text content of the task.

        Returns:
            The text content
        """
        text = self.content
        return text

    def update_status(self, status: TaskStatus) -> None:
        """
        Update the status of the task and refresh ``updated_at``.

        Args:
            status: The new status
        """
        self.status = status
        self.updated_at = time.time()

101
agent/tools/__init__.py Normal file
View File

@@ -0,0 +1,101 @@
# Import base tool
from agent.tools.base_tool import BaseTool
from agent.tools.tool_manager import ToolManager
# Import basic tools (no external dependencies)
from agent.tools.calculator.calculator import Calculator
from agent.tools.current_time.current_time import CurrentTime
# Import file operation tools
from agent.tools.read.read import Read
from agent.tools.write.write import Write
from agent.tools.edit.edit import Edit
from agent.tools.bash.bash import Bash
from agent.tools.grep.grep import Grep
from agent.tools.find.find import Find
from agent.tools.ls.ls import Ls
# Import memory tools
from agent.tools.memory.memory_search import MemorySearchTool
from agent.tools.memory.memory_get import MemoryGetTool
# Import tools with optional dependencies
def _import_optional_tools():
    """
    Import the tools whose modules may be unavailable.

    Each import is attempted independently; one that raises ImportError is
    skipped without any message.

    Returns:
        dict mapping tool class name to class, containing only the tools
        that imported successfully.
    """
    available = {}

    # Google Search (requires requests)
    try:
        from agent.tools.google_search.google_search import GoogleSearch
    except ImportError:
        pass
    else:
        available['GoogleSearch'] = GoogleSearch

    # File Save (may have dependencies)
    try:
        from agent.tools.file_save.file_save import FileSave
    except ImportError:
        pass
    else:
        available['FileSave'] = FileSave

    # Terminal (basic, should work)
    try:
        from agent.tools.terminal.terminal import Terminal
    except ImportError:
        pass
    else:
        available['Terminal'] = Terminal

    return available
# Load optional tools
# The optional imports are attempted once, when this package is imported.
_optional_tools = _import_optional_tools()
# Each name below is the tool class when its import succeeded; otherwise it is
# None (dict.get on a missing key), so callers must check before using it.
GoogleSearch = _optional_tools.get('GoogleSearch')
FileSave = _optional_tools.get('FileSave')
Terminal = _optional_tools.get('Terminal')
# Delayed import for BrowserTool
def _import_browser_tool():
    """
    Return the BrowserTool class, or a stand-in when it cannot be imported.

    The stand-in can be referenced freely; it only fails, with an
    ImportError carrying install instructions, when someone instantiates it.
    """
    try:
        from agent.tools.browser.browser_tool import BrowserTool
    except ImportError:
        # Return a placeholder class that will prompt the user to install dependencies when instantiated
        class BrowserToolPlaceholder:
            def __init__(self, *args, **kwargs):
                install_hint = (
                    "The 'browser-use' package is required to use BrowserTool. "
                    "Please install it with 'pip install browser-use>=0.1.40'."
                )
                raise ImportError(install_hint)

        return BrowserToolPlaceholder
    else:
        return BrowserTool
# Dynamically set BrowserTool
# BrowserTool is either the real class or, when its import fails, a placeholder
# class whose constructor raises ImportError with install instructions.
BrowserTool = _import_browser_tool()
# Export all tools (including optional ones that might be None)
# Names listed here are what `from agent.tools import *` exposes.
__all__ = [
'BaseTool',
'ToolManager',
'Calculator',
'CurrentTime',
'Read',
'Write',
'Edit',
'Bash',
'Grep',
'Find',
'Ls',
'MemorySearchTool',
'MemoryGetTool',
# Optional tools (may be None if dependencies not available)
'GoogleSearch',
'FileSave',
'Terminal',
'BrowserTool'
]
# NOTE(review): the string below comes after other statements, so Python does
# not treat it as the module docstring (it is a no-op expression). Consider
# moving it to the very top of the file.
"""
Tools module for Agent.
"""

99
agent/tools/base_tool.py Normal file
View File

@@ -0,0 +1,99 @@
from enum import Enum
from typing import Any, Optional
from common.log import logger
import copy
class ToolStage(Enum):
    """Stages at which a tool takes part in the agent's decision process."""

    # Tools that need to be actively selected by the agent
    PRE_PROCESS = "pre_process"
    # Tools that automatically execute after final_answer
    POST_PROCESS = "post_process"
class ToolResult:
    """
    Tool execution result.

    Carries a status string, the result payload and optional extra data.
    Use :meth:`success` / :meth:`fail` to build instances with the
    conventional status values ("success" / "error").
    """

    def __init__(self, status: str = None, result: Any = None, ext_data: Any = None):
        self.status, self.result, self.ext_data = status, result, ext_data

    @staticmethod
    def success(result, ext_data: Any = None):
        """Build a result whose status is "success"."""
        return ToolResult("success", result, ext_data)

    @staticmethod
    def fail(result, ext_data: Any = None):
        """Build a result whose status is "error"."""
        return ToolResult("error", result, ext_data)
class BaseTool:
    """
    Base class for all tools.

    Subclasses override the class attributes ``name``, ``description`` and
    ``params`` (a JSON Schema describing the arguments) and implement
    :meth:`execute`. Callers should go through :meth:`execute_tool`, which
    turns exceptions into an error result.
    """

    # Default decision stage is pre-process
    stage = ToolStage.PRE_PROCESS

    # Class attributes must be inherited
    name: str = "base_tool"
    description: str = "Base tool"
    params: dict = {}  # Store JSON Schema

    model: Optional[Any] = None  # LLM model instance, type depends on bot implementation

    @classmethod
    def get_json_schema(cls) -> dict:
        """Get the standard description of the tool"""
        return {
            "name": cls.name,
            "description": cls.description,
            "parameters": cls.params
        }

    def execute_tool(self, params: dict) -> ToolResult:
        """
        Run :meth:`execute` and never let an exception escape.

        :param params: Arguments for the tool
        :return: The tool's own result, or an error ToolResult describing
                 the exception when the tool raised
        """
        try:
            return self.execute(params)
        except Exception as e:
            # Deliberately broad: this is the boundary that protects the
            # agent loop from a misbehaving tool. Returning an error result
            # (instead of falling through to None) honours the declared
            # return type and keeps the cause visible to the caller.
            logger.error(f"Tool '{self.name}' execution failed: {e}")
            return ToolResult.fail(f"Error executing tool '{self.name}': {e}")

    def execute(self, params: dict) -> ToolResult:
        """Specific logic to be implemented by subclasses"""
        raise NotImplementedError

    @classmethod
    def _parse_schema(cls) -> dict:
        """
        Convert JSON Schema to Pydantic fields

        :return: Mapping of property name to ``(python_type, default)``;
                 the default is ``...`` when the schema gives none. Empty
                 when ``params`` declares no properties. A property whose
                 type is missing or unknown maps to ``Any``.
        """
        # Convert JSON Schema types to Python types (built once, not per property)
        type_map = {
            "string": str,
            "number": float,
            "integer": int,
            "boolean": bool,
            "array": list,
            "object": dict
        }

        fields = {}
        # .get() so the default empty schema yields no fields instead of KeyError
        for name, prop in cls.params.get("properties", {}).items():
            fields[name] = (
                type_map.get(prop.get("type"), Any),
                prop.get("default", ...)
            )
        return fields

    def should_auto_execute(self, context) -> bool:
        """
        Determine if this tool should be automatically executed based on context.

        :param context: The agent context
        :return: True if the tool should be executed, False otherwise
        """
        # Only tools in post-process stage will be automatically executed
        return self.stage == ToolStage.POST_PROCESS

    def close(self):
        """
        Close any resources used by the tool.
        This method should be overridden by tools that need to clean up resources
        such as browser connections, file handles, etc.

        By default, this method does nothing.
        """
        pass

View File

@@ -0,0 +1,3 @@
from .bash import Bash
__all__ = ['Bash']

187
agent/tools/bash/bash.py Normal file
View File

@@ -0,0 +1,187 @@
"""
Bash tool - Execute bash commands
"""
import os
import subprocess
import tempfile
from typing import Dict, Any
from agent.tools.base_tool import BaseTool, ToolResult
from agent.tools.utils.truncate import truncate_tail, format_size, DEFAULT_MAX_LINES, DEFAULT_MAX_BYTES
class Bash(BaseTool):
    """
    Tool for executing bash commands.

    Commands run through the shell in the configured working directory.
    stdout and stderr are combined and tail-truncated; when truncation
    happens the complete output is written to a temp file whose path is
    reported back. An optional safety mode refuses a small set of
    destructive commands.
    """

    name: str = "bash"
    description: str = f"""Execute a bash command in the current working directory. Returns stdout and stderr. Output is truncated to last {DEFAULT_MAX_LINES} lines or {DEFAULT_MAX_BYTES // 1024}KB (whichever is hit first). If truncated, full output is saved to a temp file.
IMPORTANT SAFETY GUIDELINES:
- You can freely create, modify, and delete files within the current workspace
- For operations outside the workspace or potentially destructive commands (rm -rf, system commands, etc.), always explain what you're about to do and ask for user confirmation first
- Be especially careful with: file deletions, system modifications, network operations, or commands that might affect system stability
- When in doubt, describe the command's purpose and ask for permission before executing"""

    params: dict = {
        "type": "object",
        "properties": {
            "command": {
                "type": "string",
                "description": "Bash command to execute"
            },
            "timeout": {
                "type": "integer",
                "description": "Timeout in seconds (optional, default: 30)"
            }
        },
        "required": ["command"]
    }

    def __init__(self, config: dict = None):
        """
        :param config: Optional settings: ``cwd`` (working directory, created
                       if missing; defaults to the process cwd), ``timeout``
                       (default seconds per command, 30) and ``safety_mode``
                       (default True)
        """
        self.config = config or {}
        self.cwd = self.config.get("cwd", os.getcwd())
        # Ensure working directory exists
        if not os.path.exists(self.cwd):
            os.makedirs(self.cwd, exist_ok=True)
        self.default_timeout = self.config.get("timeout", 30)
        # Enable safety mode by default (can be disabled in config)
        self.safety_mode = self.config.get("safety_mode", True)

    def execute(self, args: Dict[str, Any]) -> ToolResult:
        """
        Execute a bash command

        :param args: Dictionary containing the command and optional timeout
        :return: Success result with ``output``, ``exit_code`` and ``details``
                 when the command exits with 0; otherwise a failure result
                 (non-zero exit, timeout, safety refusal or execution error)
        """
        # `or` guards against an explicit null from the caller as well as a
        # missing key; a null timeout would otherwise mean "no timeout".
        command = (args.get("command") or "").strip()
        timeout = args.get("timeout") or self.default_timeout

        if not command:
            return ToolResult.fail("Error: command parameter is required")

        # Optional safety check - only warn about extremely dangerous commands
        if self.safety_mode:
            warning = self._get_safety_warning(command)
            if warning:
                return ToolResult.fail(
                    f"Safety Warning: {warning}\n\nIf you believe this command is safe and necessary, please ask the user for confirmation first, explaining what the command does and why it's needed.")

        try:
            # NOTE: shell=True is intentional here - the tool's purpose is to
            # run an arbitrary shell command line supplied by the agent.
            # errors="replace" keeps binary/undecodable output from raising.
            result = subprocess.run(
                command,
                shell=True,
                cwd=self.cwd,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                text=True,
                errors="replace",
                timeout=timeout
            )

            # Combine stdout and stderr
            output = result.stdout
            if result.stderr:
                output += "\n" + result.stderr

            # Apply tail truncation
            truncation = truncate_tail(output)
            output_text = truncation.content or "(no output)"

            # Build result
            details = {}
            if truncation.truncated:
                # Save the full output whenever anything was cut (by lines or
                # by bytes) so the notice below always points at a real file.
                full_output_path = self._save_full_output(output)
                details["truncation"] = truncation.to_dict()
                details["full_output_path"] = full_output_path
                output_text += self._truncation_notice(output, truncation, full_output_path)

            # Check exit code
            if result.returncode != 0:
                output_text += f"\n\nCommand exited with code {result.returncode}"
                return ToolResult.fail({
                    "output": output_text,
                    "exit_code": result.returncode,
                    "details": details if details else None
                })

            return ToolResult.success({
                "output": output_text,
                "exit_code": result.returncode,
                "details": details if details else None
            })

        except subprocess.TimeoutExpired:
            return ToolResult.fail(f"Error: Command timed out after {timeout} seconds")
        except Exception as e:
            return ToolResult.fail(f"Error executing command: {str(e)}")

    @staticmethod
    def _save_full_output(output: str) -> str:
        """Write the untruncated output to a persistent temp file and return its path."""
        with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.log', prefix='bash-',
                                         encoding='utf-8') as f:
            f.write(output)
            return f.name

    @staticmethod
    def _truncation_notice(output: str, truncation, full_output_path: str) -> str:
        """Build the bracketed note that tells the reader which part of the output is shown."""
        start_line = truncation.total_lines - truncation.output_lines + 1
        end_line = truncation.total_lines

        if truncation.last_line_partial:
            # Edge case: the last line alone exceeds the byte limit
            last_line = output.split('\n')[-1] if output else ""
            last_line_size = format_size(len(last_line.encode('utf-8')))
            return f"\n\n[Showing last {format_size(truncation.output_bytes)} of line {end_line} (line is {last_line_size}). Full output: {full_output_path}]"
        if truncation.truncated_by == "lines":
            return f"\n\n[Showing lines {start_line}-{end_line} of {truncation.total_lines}. Full output: {full_output_path}]"
        return f"\n\n[Showing lines {start_line}-{end_line} of {truncation.total_lines} ({format_size(DEFAULT_MAX_BYTES)} limit). Full output: {full_output_path}]"

    def _get_safety_warning(self, command: str) -> str:
        """
        Get safety warning for potentially dangerous commands
        Only warns about extremely dangerous system-level operations

        Matching is a plain, case-insensitive substring test, so it can also
        trigger on commands that merely mention one of the words.

        :param command: Command to check
        :return: Warning message if dangerous, empty string if safe
        """
        cmd_lower = command.lower().strip()

        # Only block extremely dangerous system operations
        dangerous_patterns = [
            # System shutdown/reboot
            ("shutdown", "This command will shut down the system"),
            ("reboot", "This command will reboot the system"),
            ("halt", "This command will halt the system"),
            ("poweroff", "This command will power off the system"),

            # Critical system modifications
            ("dd if=/dev/zero", "This command can destroy disk data"),
            ("mkfs", "This command will format a filesystem, destroying all data"),
            ("fdisk", "This command modifies disk partitions"),

            # User/system management (only if targeting system users)
            ("userdel root", "This command will delete the root user"),
            ("passwd root", "This command will change the root password"),
        ]

        for pattern, warning in dangerous_patterns:
            if pattern in cmd_lower:
                return warning

        # Deleting the filesystem root itself. This is checked precisely: a
        # bare substring test for "rm -rf /" would also match every absolute
        # path (e.g. "rm -rf /tmp/x") and make the workspace/system-directory
        # logic below unreachable.
        if self._deletes_filesystem_root(cmd_lower):
            return "This command will delete the entire filesystem"

        # Check for recursive deletion outside workspace
        if "rm" in cmd_lower and "-rf" in cmd_lower:
            # Allow deletion within current workspace
            if not any(path in cmd_lower for path in ["./", self.cwd.lower()]):
                # Check if targeting system directories
                system_dirs = ["/bin", "/usr", "/etc", "/var", "/home", "/root", "/sys", "/proc"]
                if any(sysdir in cmd_lower for sysdir in system_dirs):
                    return "This command will recursively delete system directories"

        return ""  # No warning needed

    @staticmethod
    def _deletes_filesystem_root(cmd_lower: str) -> bool:
        """Return True when the command contains ``rm -rf /`` aimed at the root itself (``/`` or ``/*``)."""
        marker = "rm -rf /"
        pos = cmd_lower.find(marker)
        while pos != -1:
            following = cmd_lower[pos + len(marker):pos + len(marker) + 1]
            # Root is the target when the path ends right after the slash:
            # end of command, whitespace, a glob star or a shell separator.
            if following == "" or following.isspace() or following in "*;&|":
                return True
            pos = cmd_lower.find(marker, pos + 1)
        return False

View File

@@ -0,0 +1,59 @@
class BrowserAction:
    """
    Base class for browser actions.

    Every action is a plain class carrying two strings: ``code`` (the
    operation identifier) and ``description`` (human-readable summary).
    """
    code = ""
    description = ""


class Navigate(BrowserAction):
    """Load a URL in the tab that is currently active."""
    code = "navigate"
    description = "Navigate to URL in the current tab"


class ClickElement(BrowserAction):
    """Click one element of the page."""
    code = "click_element"
    description = "Click element"


class ExtractContent(BrowserAction):
    """Pull goal-specific information out of the page content."""
    code = "extract_content"
    description = "Extract the page content to retrieve specific information for a goal"


class InputText(BrowserAction):
    """Type text into an input element."""
    code = "input_text"
    description = "Input text into a input interactive element"


class ScrollDown(BrowserAction):
    """Scroll the page downwards."""
    code = "scroll_down"
    description = "Scroll down the page by pixel amount"


class ScrollUp(BrowserAction):
    """Scroll the page upwards."""
    code = "scroll_up"
    description = "Scroll up the page by pixel amount - if no amount is specified, scroll up one page"


class OpenTab(BrowserAction):
    """Open a URL in a new tab."""
    code = "open_tab"
    description = "Open url in new tab"


class SwitchTab(BrowserAction):
    """Make another tab the active one."""
    code = "switch_tab"
    description = "Switched to tab"


class SendKeys(BrowserAction):
    """Send special keyboard keys to the page."""
    code = "send_keys"
    description = (
        "Send strings of special keyboard keys like Escape, Backspace, Insert, PageDown, Delete, Enter, "
        "ArrowRight, ArrowUp, etc"
    )

View File

@@ -0,0 +1,317 @@
import asyncio
from typing import Any, Dict
import json
import re
import os
import platform
from browser_use import Browser
from browser_use import BrowserConfig
from browser_use.browser.context import BrowserContext, BrowserContextConfig
from agent.tools.base_tool import BaseTool, ToolResult
from agent.tools.browser.browser_action import *
from agent.models import LLMRequest
from agent.models.model_factory import ModelFactory
from browser_use.dom.service import DomService
from common.log import logger
# Use lazy import, only import when actually used
def _import_browser_use():
    """
    Import and return the ``browser_use`` module.

    Raises:
        ImportError: with install instructions when the package is missing.
    """
    try:
        import browser_use
    except ImportError:
        raise ImportError(
            "The 'browser-use' package is required to use BrowserTool. "
            "Please install it with 'pip install browser-use>=0.1.40' or "
            "'pip install agentmesh-sdk[full]'."
        )
    else:
        return browser_use
def _get_action_prompt():
    """
    Render the supported browser actions as text, one ``code: description``
    line per action, with surrounding whitespace stripped.
    """
    supported = (Navigate, ClickElement, ExtractContent, InputText, OpenTab, SwitchTab, ScrollDown, ScrollUp,
                 SendKeys)
    lines = [f"{action.code}: {action.description}\n" for action in supported]
    return "".join(lines).strip()
def _header_less() -> bool:
    """
    Decide whether the browser should run headless.

    True only on Linux when neither DISPLAY nor WAYLAND_DISPLAY is set
    (to a non-empty value) in the environment.
    """
    on_linux = platform.system() == "Linux"
    has_display = bool(os.environ.get("DISPLAY")) if on_linux else False
    has_wayland = bool(os.environ.get("WAYLAND_DISPLAY")) if on_linux and not has_display else False
    return on_linux and not has_display and not has_wayland
class BrowserTool(BaseTool):
name: str = "browser"
description: str = "A tool to perform browser operations like navigating to URLs, element interaction, " \
"and extracting content."
params: dict = {
"type": "object",
"properties": {
"operation": {
"type": "string",
"description": f"The browser operation to perform: \n{_get_action_prompt()}"
},
"url": {
"type": "string",
"description": f"The URL to navigate to (required for '{Navigate.code}', '{OpenTab.code}' actions). "
},
"goal": {
"type": "string",
"description": f"The goal of extracting page content (required for '{ExtractContent.code}' action)."
},
"text": {
"type": "string",
"description": f"Text to type (required for '{InputText.code}' action)."
},
"index": {
"type": "integer",
"description": f"Element index (required for '{ClickElement.code}', '{InputText.code}' actions)",
},
"tab_id": {
"type": "integer",
"description": f"Page tab ID (required for '{SwitchTab.code}' action)",
},
"scroll_amount": {
"type": "integer",
"description": f"The number of pixels to scroll (required for '{ScrollDown.code}', '{ScrollUp.code}' action)."
},
"keys": {
"type": "string",
"description": f"Keys to send (required for '{SendKeys.code}' action)"
}
},
"required": ["operation"]
}
# Class variable to ensure only one browser instance is created
browser = None
browser_context: BrowserContext = None
dom_service: DomService = None
_initialized = False
# Adding an event loop variable
_event_loop = None
def __init__(self):
    """
    Resolve the ``browser_use`` package for this instance.

    The browser itself is NOT started here; it is initialized on the first
    execution.
    """
    # Only import during initialization, not at module level
    browser_use_module = _import_browser_use()
    self.browser_use = browser_use_module
async def _init_browser(self) -> BrowserContext:
"""
Ensure the browser is initialized and return the shared browser context.

The browser, its context and the DOM service are stored on the
BrowserTool class (not the instance), guarded by the class-level
``_initialized`` flag.
"""
if not BrowserTool._initialized:
# Set browser_use's logging-level environment variable to 'error'
os.environ['BROWSER_USE_LOGGING_LEVEL'] = 'error'
print("Initializing browser...")
# Initialize the browser synchronously
# Headless mode is decided by _header_less(); security is disabled in the config
BrowserTool.browser = Browser(BrowserConfig(headless=_header_less(),
disable_security=True))
context_config = BrowserContextConfig()
context_config.highlight_elements = True
BrowserTool.browser_context = await BrowserTool.browser.new_context(context_config)
BrowserTool._initialized = True
print("Browser initialized successfully")
# NOTE(review): indentation is not visible in this view, so it is unclear
# whether the assignment below belongs to the one-time initialization
# branch or runs on every call - confirm against the real file.
BrowserTool.dom_service = DomService(await BrowserTool.browser_context.get_current_page())
return BrowserTool.browser_context
def execute(self, params: Dict[str, Any]) -> ToolResult:
"""
Execute browser operations based on the provided arguments.

Synchronous entry point: the async implementation is run to completion
on an event loop stored on the BrowserTool class. Any exception raised
while doing so is printed and returned as a failure result.

:param params: Dictionary containing the action and related parameters
:return: Result of the browser operation
"""
# Ensure browser_use is imported
if not hasattr(self, 'browser_use'):
self.browser_use = _import_browser_use()
# The operation name is lower-cased before dispatch.
# NOTE(review): this line is outside the try block, so a null "operation"
# value would raise AttributeError from here - confirm callers never send one.
action = params.get("operation", "").lower()
try:
# Use a single event loop
# NOTE(review): indentation is not visible in this view; presumably both
# the loop creation and set_event_loop run only when no loop exists yet - confirm.
if BrowserTool._event_loop is None:
BrowserTool._event_loop = asyncio.new_event_loop()
asyncio.set_event_loop(BrowserTool._event_loop)
# Run tasks in the existing event loop
return BrowserTool._event_loop.run_until_complete(self._execute_async(action, params))
except Exception as e:
print(f"Error executing browser action: {e}")
return ToolResult.fail(result=f"Error executing browser action: {str(e)}")
async def _get_page_state(self, context: BrowserContext):
    """
    Summarize the current page as a dict.

    :param context: Browser context to read the state from
    :return: dict with ``url``, ``title``, ``pixels_above``, ``pixels_below``
             (0 when the state lacks them), ``tabs`` (each tab dumped via
             ``model_dump()``) and ``interactive_elements`` (the
             ``[index]<.../>`` snippets found in the clickable-elements text)
    """
    state = await self._get_state(context)

    # Render clickable elements, then keep only the "[n]<.../>" entries
    rendered = state.element_tree.clickable_elements_to_string(["img", "div", "button", "input"])
    element_entries = re.findall(r'\[\d+\]<[^>]+\/>', rendered)

    return {
        "url": state.url,
        "title": state.title,
        "pixels_above": getattr(state, "pixels_above", 0),
        "pixels_below": getattr(state, "pixels_below", 0),
        "tabs": [tab.model_dump() for tab in state.tabs],
        "interactive_elements": element_entries,
    }
async def _get_state(self, context: BrowserContext, cache_clickable_elements_hashes=True):
    """
    Fetch the browser state from the context.

    ``get_state()`` is first called without arguments; if that raises
    TypeError it is called again with ``cache_clickable_elements_hashes``
    passed as a keyword.
    """
    try:
        state = await context.get_state()
    except TypeError:
        # Retry with the keyword argument
        # (presumably for a get_state signature that requires it - TODO confirm)
        state = await context.get_state(cache_clickable_elements_hashes=cache_clickable_elements_hashes)
    return state
async def _get_page_info(self, context: BrowserContext):
    """
    Describe the current page as text.

    :param context: Browser context to read the state from
    :return: A short heading and explanation followed by the page state
             serialized as indented JSON (non-ASCII characters kept as-is)
    """
    snapshot = await self._get_page_state(context)
    return f"""## Current browser state
The following is the information of the current browser page. Each serial number in interactive_elements represents the element index:
{json.dumps(snapshot, indent=4, ensure_ascii=False)}
"""
async def _execute_async(self, action: str, params: Dict[str, Any]) -> ToolResult:
    """
    Asynchronously execute a single browser operation.

    :param action: Action code, compared against the ``code`` attribute of the
                   action classes (Navigate, OpenTab, ExtractContent, ...)
    :param params: Parameters of the action; the keys read here are ``url``,
                   ``goal``, ``index``, ``text``, ``tab_id``, ``scroll_amount`` and ``keys``
    :return: ToolResult describing the outcome. Most successful actions attach
             the refreshed page description as ``ext_data``.
    """
    # Obtain the browser context all actions operate on
    context = await self._init_browser()

    if action == Navigate.code:
        url = params.get("url")
        if not url:
            return ToolResult.fail(result="URL is required for navigate action")
        # A leading slash is treated as a local file path
        if url.startswith("/"):
            url = f"file://{url}"
        print(f"Navigating to {url}...")
        page = await context.get_current_page()
        await page.goto(url)
        await page.wait_for_load_state()
        state = await self._get_page_info(context)
        print(f"Navigation complete")
        return ToolResult.success(result=f"Navigated to {url}", ext_data=state)

    elif action == OpenTab.code:
        url = params.get("url")
        # Validate like the navigate action does; a missing url would
        # otherwise raise AttributeError on url.startswith below
        if not url:
            return ToolResult.fail(result="URL is required for open tab action")
        if url.startswith("/"):
            url = f"file://{url}"
        await context.create_new_tab(url)
        msg = f"Opened new tab with {url}"
        return ToolResult.success(result=msg)

    elif action == ExtractContent.code:
        try:
            goal = params.get("goal")
            page = await context.get_current_page()
            if params.get("url"):
                await page.goto(params.get("url"))
                await page.wait_for_load_state()
            import markdownify
            content = markdownify.markdownify(await page.content())
            elements = await self._get_page_state(context)
            prompt = f"Your task is to extract the content of the page. You will be given a page and a goal and you should extract all relevant information around this goal from the page. If the goal is vague, " \
                     f"summarize the page. Respond in json format. elements: {elements.get('interactive_elements')}, extraction goal: {goal}, Page: {content},"
            request = LLMRequest(
                messages=[{"role": "user", "content": prompt}],
                temperature=0,
                json_format=True
            )
            model = self.model or ModelFactory().get_model(model_name="gpt-4o")
            response = model.call(request)
            if response.success:
                extract_content = response.data["choices"][0]["message"]["content"]
                print(f"Extract from page: {extract_content}")
                return ToolResult.success(result=f"Extract from page: {extract_content}",
                                          ext_data=await self._get_page_info(context))
            else:
                return ToolResult.fail(result=f"Extract from page failed: {response.get_error_msg()}")
        except Exception as e:
            logger.error(e)
            # Report the failure instead of implicitly returning None
            return ToolResult.fail(result=f"Extract from page failed: {e}")

    elif action == ClickElement.code:
        index = params.get("index")
        element = await context.get_dom_element_by_index(index)
        await context._click_element_node(element)
        msg = f"Clicked element at index {index}"
        print(msg)
        return ToolResult.success(result=msg, ext_data=await self._get_page_info(context))

    elif action == InputText.code:
        index = params.get("index")
        text = params.get("text")
        element = await context.get_dom_element_by_index(index)
        await context._input_text_element_node(element, text)
        # Short pause before the page state is captured for the result
        await asyncio.sleep(1)
        msg = f"Input text into element successfully, index: {index}, text: {text}"
        return ToolResult.success(result=msg, ext_data=await self._get_page_info(context))

    elif action == SwitchTab.code:
        tab_id = params.get("tab_id")
        print(f"Switch tab, tab_id={tab_id}")
        await context.switch_to_tab(tab_id)
        page = await context.get_current_page()
        await page.wait_for_load_state()
        msg = f"Switched to tab {tab_id}"
        return ToolResult.success(result=msg, ext_data=await self._get_page_info(context))

    elif action in [ScrollDown.code, ScrollUp.code]:
        scroll_amount = params.get("scroll_amount")
        if not scroll_amount:
            # Default to one window height
            scroll_amount = context.config.browser_window_size["height"]
        print(f"Scrolling by {scroll_amount} pixels")
        # Scrolling up is a negative vertical offset
        scroll_amount = scroll_amount if action == ScrollDown.code else (scroll_amount * -1)
        await context.execute_javascript(f"window.scrollBy(0, {scroll_amount});")
        msg = f"{action} by {scroll_amount} pixels"
        return ToolResult.success(result=msg, ext_data=await self._get_page_info(context))

    elif action == SendKeys.code:
        keys = params.get("keys")
        page = await context.get_current_page()
        await page.keyboard.press(keys)
        msg = f"Sent keys: {keys}"
        print(msg)
        # Built the same way as every other branch of this method
        return ToolResult.success(result=msg)

    else:
        msg = "Failed to operate the browser"
        return ToolResult.fail(result=msg)
def close(self):
    """
    Close browser resources.

    Runs the asynchronous shutdown of the shared browser context and browser on
    the event loop stored on ``BrowserTool``, then closes that loop. Nothing is
    done when the tool was never initialized or when no event loop is stored.
    Errors are logged rather than raised.
    """
    if not BrowserTool._initialized:
        return

    loop = BrowserTool._event_loop
    if loop is None:
        return

    async def close_browser_async():
        # Close the context first, then the browser; one failure must not
        # prevent the other from being attempted
        if BrowserTool.browser_context is not None:
            try:
                await BrowserTool.browser_context.close()
            except Exception as e:
                logger.error(f"Error closing browser context: {e}")
        if BrowserTool.browser is not None:
            try:
                await BrowserTool.browser.close()
            except Exception as e:
                logger.error(f"Error closing browser: {e}")
        # Reset the shared state so a later call starts from scratch
        BrowserTool._initialized = False
        BrowserTool.browser = None
        BrowserTool.browser_context = None
        BrowserTool.dom_service = None

    shutdown = close_browser_async()
    try:
        # Run the async close function in the existing event loop
        loop.run_until_complete(shutdown)
    except Exception as e:
        # If the loop refused to run the coroutine, dispose of it explicitly
        # so it is not left un-awaited
        shutdown.close()
        logger.error(f"Error during browser cleanup: {e}")
    finally:
        # Release the loop even when the shutdown itself failed; keep the
        # reference only if the loop could not be closed
        try:
            loop.close()
        except Exception as e:
            logger.error(f"Error closing event loop: {e}")
        else:
            BrowserTool._event_loop = None

View File

@@ -0,0 +1,18 @@
def copy(self):
    """
    Create a copy of this tool that reuses the existing browser instance.

    :return: A new instance of the same class carrying this tool's model,
             context and config, plus the same ``browser`` object when this
             tool has one
    """
    clone = self.__class__()

    # Essential attributes; context and config default to None when absent
    clone.model = self.model
    for optional_attr in ('context', 'config'):
        setattr(clone, optional_attr, getattr(self, optional_attr, None))

    # Share the browser reference rather than creating another browser
    if not hasattr(self, 'browser'):
        return clone
    clone.browser = self.browser
    return clone

View File

@@ -0,0 +1,58 @@
import math
from agent.tools.base_tool import BaseTool, ToolResult
class Calculator(BaseTool):
    """Tool that evaluates a mathematical expression against a whitelist of math helpers."""

    name: str = "calculator"
    description: str = "A tool to perform basic mathematical calculations."
    params: dict = {
        "type": "object",
        "properties": {
            "expression": {
                "type": "string",
                "description": "The mathematical expression to evaluate (e.g., '2 + 2', '5 * 3', 'sqrt(16)'). "
                               "Ensure your input is a valid Python expression, it will be evaluated directly."
            }
        },
        "required": ["expression"]
    }
    config: dict = {}

    def execute(self, args: dict) -> ToolResult:
        """
        Evaluate the expression given in ``args["expression"]``.

        :param args: Tool arguments; must contain the key ``expression``
        :return: A success result holding ``result`` and ``expression``, or a
                 failed result holding ``error`` and ``expression``
        """
        try:
            # Get the expression
            expression = args["expression"]

            # Names the expression is allowed to reference
            safe_locals = {
                "abs": abs,
                "round": round,
                "max": max,
                "min": min,
                "pow": pow,
                "sqrt": math.sqrt,
                "sin": math.sin,
                "cos": math.cos,
                "tan": math.tan,
                "pi": math.pi,
                "e": math.e,
                "log": math.log,
                "log10": math.log10,
                "exp": math.exp,
                "floor": math.floor,
                "ceil": math.ceil
            }

            # SECURITY NOTE(review): eval() with emptied builtins is NOT a sandbox.
            # Attribute chains through dunder names can still reach arbitrary
            # objects, so such expressions are rejected here. If this tool is fed
            # untrusted input, replace eval with a dedicated expression parser.
            if "__" in expression:
                raise ValueError("Expression must not contain double underscores")

            result = eval(expression, {"__builtins__": {}}, safe_locals)

            return ToolResult.success({
                "result": result,
                "expression": expression
            })
        except Exception as e:
            # Report evaluation errors as a failure, keeping the same payload shape
            return ToolResult.fail({
                "error": str(e),
                "expression": args.get("expression", "")
            })

View File

@@ -0,0 +1,75 @@
import datetime
import time
from agent.tools.base_tool import BaseTool, ToolResult
class CurrentTime(BaseTool):
    """Tool that reports the current date and time in local time or UTC."""

    name: str = "time"
    description: str = "A tool to get current date and time information."
    params: dict = {
        "type": "object",
        "properties": {
            "format": {
                "type": "string",
                "description": "Optional format for the time (e.g., 'iso', 'unix', 'human'). Default is 'human'."
            },
            "timezone": {
                "type": "string",
                "description": "Optional timezone specification (e.g., 'UTC', 'local'). Default is 'local'."
            }
        },
        "required": []
    }
    config: dict = {}

    def execute(self, args: dict) -> ToolResult:
        """
        Return the current time plus its individual components.

        :param args: Optional keys ``format`` ('iso', 'unix' or anything else for
                     the human-readable form) and ``timezone`` ('utc'; any other
                     value means local time)
        :return: A success result with ``current_time``, ``components``, ``format``
                 and the ``timezone`` that was actually used, or a failed result
                 carrying the error text
        """
        try:
            # ``or`` also covers an explicit None passed for an optional argument
            time_format = (args.get("format") or "human").lower()
            requested_timezone = (args.get("timezone") or "local").lower()

            if requested_timezone == "utc":
                # Naive UTC timestamp; tzinfo is dropped so the formatted output
                # keeps the same shape as the deprecated utcnow() produced
                current_time = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
                timezone = "utc"
            else:
                # Only 'utc' and local time are supported; report what was used
                # instead of echoing back an unsupported request
                current_time = datetime.datetime.now()
                timezone = "local"

            # Format the time according to the specified format
            if time_format == "iso":
                # ISO 8601 format
                formatted_time = current_time.isoformat()
            elif time_format == "unix":
                # Unix timestamp (seconds since epoch)
                formatted_time = time.time()
            else:
                # Human-readable format
                formatted_time = current_time.strftime("%Y-%m-%d %H:%M:%S")

            result = {
                "current_time": formatted_time,
                "components": {
                    "year": current_time.year,
                    "month": current_time.month,
                    "day": current_time.day,
                    "hour": current_time.hour,
                    "minute": current_time.minute,
                    "second": current_time.second,
                    "weekday": current_time.strftime("%A")  # Full weekday name
                },
                "format": time_format,
                "timezone": timezone
            }
            return ToolResult.success(result=result)
        except Exception as e:
            return ToolResult.fail(result=str(e))

View File

@@ -0,0 +1,3 @@
from .edit import Edit
__all__ = ['Edit']

164
agent/tools/edit/edit.py Normal file
View File

@@ -0,0 +1,164 @@
"""
Edit tool - Precise file editing
Edit files through exact text replacement
"""
import os
from typing import Dict, Any
from agent.tools.base_tool import BaseTool, ToolResult
from agent.tools.utils.diff import (
strip_bom,
detect_line_ending,
normalize_to_lf,
restore_line_endings,
normalize_for_fuzzy_match,
fuzzy_find_text,
generate_diff_string
)
class Edit(BaseTool):
    """Tool for precise file editing"""

    name: str = "edit"
    description: str = "Edit a file by replacing exact text. The oldText must match exactly (including whitespace). Use this for precise, surgical edits."
    params: dict = {
        "type": "object",
        "properties": {
            "path": {
                "type": "string",
                "description": "Path to the file to edit (relative or absolute)"
            },
            "oldText": {
                "type": "string",
                "description": "Exact text to find and replace (must match exactly)"
            },
            "newText": {
                "type": "string",
                "description": "New text to replace the old text with"
            }
        },
        "required": ["path", "oldText", "newText"]
    }

    def __init__(self, config: dict = None):
        """
        :param config: Optional settings; ``cwd`` is the base directory for
                       relative paths and defaults to the process working directory
        """
        self.config = config or {}
        self.cwd = self.config.get("cwd", os.getcwd())

    def execute(self, args: Dict[str, Any]) -> ToolResult:
        """
        Execute file edit operation

        :param args: Contains file path, old text and new text
        :return: Operation result; on success it carries a message, the path,
                 a diff and the first changed line
        """
        # ``or`` also covers an explicit None for the path
        path = (args.get("path") or "").strip()
        old_text = args.get("oldText", "")
        new_text = args.get("newText", "")

        if not path:
            return ToolResult.fail("Error: path parameter is required")

        # Resolve path
        absolute_path = self._resolve_path(path)

        # Check if file exists
        if not os.path.exists(absolute_path):
            return ToolResult.fail(f"Error: File not found: {path}")

        # Directories pass the existence check but cannot be edited
        if not os.path.isfile(absolute_path):
            return ToolResult.fail(f"Error: Path is not a file: {path}")

        # Check if readable/writable
        if not os.access(absolute_path, os.R_OK | os.W_OK):
            return ToolResult.fail(f"Error: File is not readable/writable: {path}")

        try:
            # newline='' disables universal-newline translation; without it
            # the line-ending detection below never sees the file's real endings
            with open(absolute_path, 'r', encoding='utf-8', newline='') as f:
                raw_content = f.read()

            # Remove BOM (LLM won't include invisible BOM in oldText)
            bom, content = strip_bom(raw_content)

            # Detect original line ending
            original_ending = detect_line_ending(content)

            # Normalize to LF
            normalized_content = normalize_to_lf(content)
            normalized_old_text = normalize_to_lf(old_text)
            normalized_new_text = normalize_to_lf(new_text)

            # Use fuzzy matching to find old text (try exact match first, then fuzzy match)
            match_result = fuzzy_find_text(normalized_content, normalized_old_text)

            if not match_result.found:
                return ToolResult.fail(
                    f"Error: Could not find the exact text in {path}. "
                    "The old text must match exactly including all whitespace and newlines."
                )

            # Calculate occurrence count (use fuzzy normalized content for consistency)
            fuzzy_content = normalize_for_fuzzy_match(normalized_content)
            fuzzy_old_text = normalize_for_fuzzy_match(normalized_old_text)
            occurrences = fuzzy_content.count(fuzzy_old_text)

            if occurrences > 1:
                return ToolResult.fail(
                    f"Error: Found {occurrences} occurrences of the text in {path}. "
                    "The text must be unique. Please provide more context to make it unique."
                )

            # Execute replacement (use matched text position)
            base_content = match_result.content_for_replacement
            new_content = (
                base_content[:match_result.index] +
                normalized_new_text +
                base_content[match_result.index + match_result.match_length:]
            )

            # Verify replacement actually changed content
            if base_content == new_content:
                return ToolResult.fail(
                    f"Error: No changes made to {path}. "
                    "The replacement produced identical content. "
                    "This might indicate an issue with special characters or the text not existing as expected."
                )

            # Restore original line endings
            final_content = bom + restore_line_endings(new_content, original_ending)

            # newline='' writes the restored endings verbatim instead of
            # translating '\n' to the platform default a second time
            with open(absolute_path, 'w', encoding='utf-8', newline='') as f:
                f.write(final_content)

            # Generate diff
            diff_result = generate_diff_string(base_content, new_content)

            result = {
                "message": f"Successfully replaced text in {path}",
                "path": path,
                "diff": diff_result['diff'],
                "first_changed_line": diff_result['first_changed_line']
            }
            return ToolResult.success(result)

        except UnicodeDecodeError:
            return ToolResult.fail(f"Error: File is not a valid text file (encoding error): {path}")
        except PermissionError:
            return ToolResult.fail(f"Error: Permission denied accessing {path}")
        except Exception as e:
            return ToolResult.fail(f"Error editing file: {str(e)}")

    def _resolve_path(self, path: str) -> str:
        """
        Resolve path to absolute path

        :param path: Relative or absolute path
        :return: Absolute path
        """
        # Expand ~ to user home directory
        path = os.path.expanduser(path)
        if os.path.isabs(path):
            return path
        return os.path.abspath(os.path.join(self.cwd, path))

View File

@@ -0,0 +1,3 @@
from .file_save import FileSave
__all__ = ['FileSave']

View File

@@ -0,0 +1,770 @@
import os
import time
import re
import json
from pathlib import Path
from typing import Dict, Any, Optional, Tuple
from agent.tools.base_tool import BaseTool, ToolResult, ToolStage
from agent.models import LLMRequest
from common.log import logger
class FileSave(BaseTool):
    """Tool for saving content to files in the workspace directory."""
    # Tool identifier and the summary text describing what the tool does
    name = "file_save"
    description = "Save the agent's output to a file in the workspace directory. Content is automatically extracted from the agent's previous outputs."
    # Set as post-process stage tool
    stage = ToolStage.POST_PROCESS
    # Schema of the accepted arguments; none of them is mandatory
    params = {
        "type": "object",
        "properties": {
            "file_name": {
                "type": "string",
                "description": "Optional. The name of the file to save. If not provided, a name will be generated based on the content."
            },
            "file_type": {
                "type": "string",
                "description": "Optional. The type/extension of the file (e.g., 'txt', 'md', 'py', 'java'). If not provided, it will be inferred from the content."
            },
            "extract_code": {
                "type": "boolean",
                "description": "Optional. If true, will attempt to extract code blocks from the content. Default is false."
            }
        },
        "required": []  # No required fields, as everything can be extracted from context
    }
def __init__(self):
self.context = None
self.config = {}
self.workspace_dir = Path("workspace")
def execute(self, params: Dict[str, Any]) -> ToolResult:
    """
    Save content to a file in the workspace directory.

    The content comes from the agent context. File name, type and the
    code-extraction decision are proposed by the model (or inferred from the
    content if that raises); ``file_name`` / ``file_type`` given in ``params``
    take precedence, and ``extract_code=True`` forces code extraction.

    :param params: The parameters for the file output operation.
    :return: Result of the operation.
    """
    # Extract content from context
    if not hasattr(self, 'context') or not self.context:
        return ToolResult.fail("Error: No context available to extract content from.")

    content = self._extract_content_from_context()

    # If no content could be extracted, return error
    if not content:
        return ToolResult.fail("Error: Couldn't extract content from context.")

    # Use model to determine file parameters
    try:
        task_dir = self._get_task_dir_from_context()
        file_name, file_type, extract_code = self._get_file_params_from_model(content)
    except Exception as e:
        logger.error(f"Error determining file parameters: {str(e)}")
        # Fall back to manual parameter extraction
        task_dir = params.get("task_dir") or self._get_task_id_from_context()
        file_name = params.get("file_name") or self._infer_file_name(content)
        file_type = params.get("file_type") or self._infer_file_type(content)
        extract_code = params.get("extract_code", False)

    # The context may provide no task name at all; a None path segment would
    # make the Path join below raise TypeError
    task_dir = task_dir or f"task_{int(time.time())}"

    # Explicitly supplied arguments win over model suggestions, as the
    # parameter descriptions promise
    requested_name = params.get("file_name")
    requested_type = params.get("file_type")
    if requested_name:
        file_name = requested_name
        requested_suffix = Path(requested_name).suffix.lstrip(".")
        if not requested_type and requested_suffix:
            # Keep the extension the caller already put in the name
            file_type = requested_suffix
    if requested_type:
        file_type = requested_type.lstrip(".")
    # An explicit True forces extraction; False/absent defers to detection
    extract_code = bool(params.get("extract_code")) or extract_code

    # Get team_name from context
    team_name = self._get_team_name_from_context() or "default_team"

    # Create directory structure
    task_dir_path = self.workspace_dir / team_name / task_dir
    task_dir_path.mkdir(parents=True, exist_ok=True)

    if extract_code:
        # Save the complete content as markdown
        md_file_path = task_dir_path / f"{file_name}.md"
        try:
            with open(md_file_path, 'w', encoding='utf-8') as f:
                f.write(content)
        except Exception as e:
            return ToolResult.fail(f"Error saving file: {str(e)}")
        return self._handle_multiple_code_blocks(content)

    # Ensure file_name has the correct extension
    if file_type and not file_name.endswith(f".{file_type}"):
        file_name = f"{file_name}.{file_type}"

    # Create the full file path
    file_path = task_dir_path / file_name

    # Get absolute path for storage in team_context
    abs_file_path = file_path.absolute()

    try:
        # Write content to file
        with open(file_path, 'w', encoding='utf-8') as f:
            f.write(content)

        # Append the file information to the latest agent output, if there is
        # one; a missing team context or empty output list must not turn a
        # successful save into a reported failure
        team_context = getattr(self.context, 'team_context', None)
        agent_outputs = getattr(team_context, 'agent_outputs', None)
        if agent_outputs:
            # Store with absolute path in team_context
            agent_outputs[-1].output += f"\n\nSaved file: {abs_file_path}"

        return ToolResult.success({
            "status": "success",
            "file_path": str(file_path)  # Return relative path in result
        })
    except Exception as e:
        return ToolResult.fail(f"Error saving file: {str(e)}")
def _handle_multiple_code_blocks(self, content: str) -> ToolResult:
    """
    Handle content with multiple code blocks, extracting and saving each as a separate file.

    :param content: The content containing multiple code blocks
    :return: Result of the operation; on success it lists the path of every
             file written, and it fails when no block was found or saved
    """
    # Extract code blocks with context (including potential file name information)
    code_blocks_with_context = self._extract_code_blocks_with_context(content)
    if not code_blocks_with_context:
        return ToolResult.fail("No code blocks found in the content.")

    # Get task directory and team name
    task_dir = self._get_task_dir_from_context() or f"task_{int(time.time())}"
    team_name = self._get_team_name_from_context() or "default_team"

    # Create directory structure
    task_dir_path = self.workspace_dir / team_name / task_dir
    task_dir_path.mkdir(parents=True, exist_ok=True)

    saved_files = []
    for position, block_with_context in enumerate(code_blocks_with_context, start=1):
        try:
            # Determine file name and type for this code block
            block_file_name, block_file_type = self._get_filename_for_code_block(block_with_context)

            # An empty name would resolve to the task directory itself and the
            # block would be lost; fall back to a positional name instead
            if not block_file_name:
                block_file_name = f"code_block_{position}"

            # Clean the code block (remove md code markers)
            clean_code = self._clean_code_block(block_with_context)

            # Ensure file_name has the correct extension
            if block_file_type and not block_file_name.endswith(f".{block_file_type}"):
                block_file_name = f"{block_file_name}.{block_file_type}"

            # Create the full file path (no subdirectories)
            file_path = task_dir_path / block_file_name

            # Get absolute path for storage in team_context
            abs_file_path = file_path.absolute()

            # Write content to file
            with open(file_path, 'w', encoding='utf-8') as f:
                f.write(clean_code)

            saved_files.append({
                "file_path": str(file_path),
                "abs_file_path": str(abs_file_path),  # Store absolute path for internal use
                "file_name": block_file_name,
                "size": len(clean_code),
                "status": "success",
                "type": "code"
            })
        except Exception as e:
            logger.error(f"Error saving code block: {str(e)}")
            # Continue with the next block even if this one fails

    if not saved_files:
        return ToolResult.fail("Failed to save any code blocks.")

    # Append the list of saved files to the latest agent output, when one exists
    context = getattr(self, 'context', None)
    team_context = getattr(context, 'team_context', None) if context else None
    agent_outputs = getattr(team_context, 'agent_outputs', None)
    if agent_outputs:
        # Store with absolute paths in team_context
        abs_info = f"\n\nSaved files to {task_dir_path.absolute()}:\n" + "\n".join(
            [f"- {f['abs_file_path']}" for f in saved_files])
        latest = agent_outputs[-1]
        # Compare against the text that is actually appended, so repeated
        # calls do not add it twice
        if not latest.output.endswith(abs_info):
            latest.output += abs_info

    result = {
        "status": "success",
        "files": [{"file_path": f["file_path"]} for f in saved_files]
    }
    return ToolResult.success(result)
def _extract_code_blocks_with_context(self, content: str) -> list:
"""
Extract code blocks from content, including context lines before the block.
:param content: The content to extract code blocks from
:return: List of code blocks with context
"""
# Check if content starts with <!DOCTYPE or <html - likely a full HTML file
if content.strip().startswith(("<!DOCTYPE", "<html", "<?xml")):
return [content] # Return the entire content as a single block
# Split content into lines
lines = content.split('\n')
blocks = []
in_code_block = False
current_block = []
context_lines = []
# Check if there are any code block markers in the content
if not re.search(r'```\w+', content):
# If no code block markers and content looks like code, return the entire content
if self._is_likely_code(content):
return [content]
for line in lines:
if line.strip().startswith('```'):
if in_code_block:
# End of code block
current_block.append(line)
# Only add blocks that have a language specified
block_content = '\n'.join(current_block)
if re.search(r'```\w+', current_block[0]):
# Combine context with code block
blocks.append('\n'.join(context_lines + current_block))
current_block = []
context_lines = []
in_code_block = False
else:
# Start of code block - check if it has a language specified
if re.search(r'```\w+', line) and not re.search(r'```language=\s*$', line):
# Start of code block with language
in_code_block = True
current_block = [line]
# Keep only the last few context lines
context_lines = context_lines[-5:] if context_lines else []
elif in_code_block:
current_block.append(line)
else:
# Store context lines when not in a code block
context_lines.append(line)
return blocks
def _get_filename_for_code_block(self, block_with_context: str) -> Tuple[str, str]:
"""
Determine the file name for a code block.
:param block_with_context: The code block with context lines
:return: Tuple of (file_name, file_type)
"""
# Define common code file extensions
COMMON_CODE_EXTENSIONS = {
'py', 'js', 'java', 'c', 'cpp', 'h', 'hpp', 'cs', 'go', 'rb', 'php',
'html', 'css', 'ts', 'jsx', 'tsx', 'vue', 'sh', 'sql', 'json', 'xml',
'yaml', 'yml', 'md', 'rs', 'swift', 'kt', 'scala', 'pl', 'r', 'lua'
}
# Split the block into lines to examine only the context around code block markers
lines = block_with_context.split('\n')
# Find the code block start marker line index
start_marker_idx = -1
for i, line in enumerate(lines):
if line.strip().startswith('```') and not line.strip() == '```':
start_marker_idx = i
break
if start_marker_idx == -1:
# No code block marker found
return "", ""
# Extract the language from the code block marker
code_marker = lines[start_marker_idx].strip()
language = ""
if len(code_marker) > 3:
language = code_marker[3:].strip().split('=')[0].strip()
# Define the context range (5 lines before and 2 after the marker)
context_start = max(0, start_marker_idx - 5)
context_end = min(len(lines), start_marker_idx + 3)
# Extract only the relevant context lines
context_lines = lines[context_start:context_end]
# First, check for explicit file headers like "## filename.ext"
for line in context_lines:
# Match patterns like "## filename.ext" or "# filename.ext"
header_match = re.search(r'^\s*#{1,6}\s+([a-zA-Z0-9_-]+\.[a-zA-Z0-9]+)\s*$', line)
if header_match:
file_name = header_match.group(1)
file_type = os.path.splitext(file_name)[1].lstrip('.')
if file_type in COMMON_CODE_EXTENSIONS:
return os.path.splitext(file_name)[0], file_type
# Simple patterns to match explicit file names in the context
file_patterns = [
# Match explicit file names in headers or text
r'(?:file|filename)[:=\s]+[\'"]?([a-zA-Z0-9_-]+\.[a-zA-Z0-9]+)[\'"]?',
# Match language=filename.ext in code markers
r'language=([a-zA-Z0-9_-]+\.[a-zA-Z0-9]+)',
# Match standalone filenames with extensions
r'\b([a-zA-Z0-9_-]+\.(py|js|java|c|cpp|h|hpp|cs|go|rb|php|html|css|ts|jsx|tsx|vue|sh|sql|json|xml|yaml|yml|md|rs|swift|kt|scala|pl|r|lua))\b',
# Match file paths in comments
r'#\s*([a-zA-Z0-9_/-]+\.[a-zA-Z0-9]+)'
]
# Check each context line for file name patterns
for line in context_lines:
line = line.strip()
for pattern in file_patterns:
matches = re.findall(pattern, line)
if matches:
for match in matches:
if isinstance(match, tuple):
# If the match is a tuple (filename, extension)
file_name = match[0]
file_type = match[1]
# Verify it's not a code reference like Direction.DOWN
if not any(keyword in file_name for keyword in ['class.', 'enum.', 'import.']):
return os.path.splitext(file_name)[0], file_type
else:
# If the match is a string (full filename)
file_name = match
file_type = os.path.splitext(file_name)[1].lstrip('.')
# Verify it's not a code reference
if file_type in COMMON_CODE_EXTENSIONS and not any(
keyword in file_name for keyword in ['class.', 'enum.', 'import.']):
return os.path.splitext(file_name)[0], file_type
# If no explicit file name found, use LLM to infer from code content
# Extract the code content
code_content = block_with_context
# Get the first 20 lines of code for LLM analysis
code_lines = code_content.split('\n')
code_preview = '\n'.join(code_lines[:20])
# Get the model to use
model_to_use = None
if hasattr(self, 'context') and self.context:
if hasattr(self.context, 'model') and self.context.model:
model_to_use = self.context.model
elif hasattr(self.context, 'team_context') and self.context.team_context:
if hasattr(self.context.team_context, 'model') and self.context.team_context.model:
model_to_use = self.context.team_context.model
# If no model is available in context, use the tool's model
if not model_to_use and hasattr(self, 'model') and self.model:
model_to_use = self.model
if model_to_use:
# Prepare a prompt for the model
prompt = f"""Analyze the following code and determine the most appropriate file name and file type/extension.
The file name should be descriptive but concise, using snake_case (lowercase with underscores).
The file type should be a standard file extension (e.g., py, js, html, css, java).
Code preview (first 20 lines):
{code_preview}
Return your answer in JSON format with these fields:
- file_name: The suggested file name (without extension)
- file_type: The suggested file extension
JSON response:"""
# Create a request to the model
request = LLMRequest(
messages=[{"role": "user", "content": prompt}],
temperature=0,
json_format=True
)
try:
response = model_to_use.call(request)
if not response.is_error:
# Clean the JSON response
json_content = self._clean_json_response(response.data["choices"][0]["message"]["content"])
result = json.loads(json_content)
file_name = result.get("file_name", "")
file_type = result.get("file_type", "")
if file_name and file_type:
return file_name, file_type
except Exception as e:
logger.error(f"Error using model to determine file name: {str(e)}")
# If we still don't have a file name, use the language as file type
if language and language in COMMON_CODE_EXTENSIONS:
timestamp = int(time.time())
return f"code_{timestamp}", language
# If all else fails, return empty strings
return "", ""
def _clean_json_response(self, text: str) -> str:
"""
Clean JSON response from LLM by removing markdown code block markers.
:param text: The text containing JSON possibly wrapped in markdown code blocks
:return: Clean JSON string
"""
# Remove markdown code block markers if present
if text.startswith("```json"):
text = text[7:]
elif text.startswith("```"):
# Find the first newline to skip the language identifier line
first_newline = text.find('\n')
if first_newline != -1:
text = text[first_newline + 1:]
if text.endswith("```"):
text = text[:-3]
return text.strip()
def _clean_code_block(self, block_with_context: str) -> str:
"""
Clean a code block by removing markdown code markers and context lines.
:param block_with_context: Code block with context lines
:return: Clean code ready for execution
"""
# Check if this is a full HTML or XML document
if block_with_context.strip().startswith(("<!DOCTYPE", "<html", "<?xml")):
return block_with_context
# Find the code block
code_block_match = re.search(r'```(?:\w+)?(?:[:=][^\n]+)?\n([\s\S]*?)\n```', block_with_context)
if code_block_match:
return code_block_match.group(1)
# If no match found, try to extract anything between ``` markers
lines = block_with_context.split('\n')
start_idx = None
end_idx = None
for i, line in enumerate(lines):
if line.strip().startswith('```'):
if start_idx is None:
start_idx = i
else:
end_idx = i
break
if start_idx is not None and end_idx is not None:
# Extract the code between the markers, excluding the markers themselves
code_lines = lines[start_idx + 1:end_idx]
return '\n'.join(code_lines)
# If all else fails, return the original content
return block_with_context
def _get_file_params_from_model(self, content, model=None):
"""
Use LLM to determine if the content is code and suggest appropriate file parameters.
Args:
content: The content to analyze
model: Optional model to use for the analysis
Returns:
tuple: (file_name, file_type, extract_code) for backward compatibility
"""
if model is None:
model = self.model
if not model:
# Default fallback if no model is available
return "output", "txt", False
prompt = f"""
Analyze the following content and determine:
1. Is this primarily code implementation (where most of the content consists of code blocks)?
2. What would be an appropriate filename and file extension?
Content to analyze: ```
{content[:500]} # Only show first 500 chars to avoid token limits ```
{"..." if len(content) > 500 else ""}
Respond in JSON format only with the following structure:
{{
"is_code": true/false, # Whether this is primarily code implementation
"filename": "suggested_filename", # Don't include extension, english words
"extension": "appropriate_extension" # Don't include the dot, e.g., "md", "py", "js"
}}
"""
try:
# Create a request to the model
request = LLMRequest(
messages=[{"role": "user", "content": prompt}],
temperature=0.1,
json_format=True
)
# Call the model using the standard interface
response = model.call(request)
if response.is_error:
logger.warning(f"Error from model: {response.error_message}")
raise Exception(f"Model error: {response.error_message}")
# Extract JSON from response
result = response.data["choices"][0]["message"]["content"]
# Clean the JSON response
result = self._clean_json_response(result)
# Parse the JSON
params = json.loads(result)
# For backward compatibility, return tuple format
file_name = params.get("filename", "output")
# Remove dot from extension if present
file_type = params.get("extension", "md").lstrip(".")
extract_code = params.get("is_code", False)
return file_name, file_type, extract_code
except Exception as e:
logger.warning(f"Error getting file parameters from model: {e}")
# Default fallback
return "output", "md", False
def _get_team_name_from_context(self) -> Optional[str]:
"""
Get team name from the agent's context.
:return: Team name or None if not found
"""
if hasattr(self, 'context') and self.context:
# Try to get team name from team_context
if hasattr(self.context, 'team_context') and self.context.team_context:
return self.context.team_context.name
# Try direct team_name attribute
if hasattr(self.context, 'name'):
return self.context.name
return None
def _get_task_id_from_context(self) -> Optional[str]:
"""
Get task ID from the agent's context.
:return: Task ID or None if not found
"""
if hasattr(self, 'context') and self.context:
# Try to get task ID from task object
if hasattr(self.context, 'task') and self.context.task:
return self.context.task.id
# Try team_context's task
if hasattr(self.context, 'team_context') and self.context.team_context:
if hasattr(self.context.team_context, 'task') and self.context.team_context.task:
return self.context.team_context.task.id
return None
def _get_task_dir_from_context(self) -> Optional[str]:
"""
Get task directory name from the team context.
:return: Task directory name or None if not found
"""
if hasattr(self, 'context') and self.context:
# Try to get from team_context
if hasattr(self.context, 'team_context') and self.context.team_context:
if hasattr(self.context.team_context, 'task_short_name') and self.context.team_context.task_short_name:
return self.context.team_context.task_short_name
# Fall back to task ID if available
return self._get_task_id_from_context()
def _extract_content_from_context(self) -> str:
"""
Extract content from the agent's context.
:return: Extracted content
"""
# Check if we have access to the agent's context
if not hasattr(self, 'context') or not self.context:
return ""
# Try to get the most recent final answer from the agent
if hasattr(self.context, 'final_answer') and self.context.final_answer:
return self.context.final_answer
# Try to get the most recent final answer from team context
if hasattr(self.context, 'team_context') and self.context.team_context:
if hasattr(self.context.team_context, 'agent_outputs') and self.context.team_context.agent_outputs:
latest_output = self.context.team_context.agent_outputs[-1].output
return latest_output
# If we have action history, try to get the most recent final answer
if hasattr(self.context, 'action_history') and self.context.action_history:
for action in reversed(self.context.action_history):
if "final_answer" in action and action["final_answer"]:
return action["final_answer"]
return ""
def _extract_code_blocks(self, content: str) -> str:
"""
Extract code blocks from markdown content.
:param content: The content to extract code blocks from
:return: Extracted code blocks
"""
# Pattern to match markdown code blocks
code_block_pattern = r'```(?:\w+)?\n([\s\S]*?)\n```'
# Find all code blocks
code_blocks = re.findall(code_block_pattern, content)
if code_blocks:
# Join all code blocks with newlines
return '\n\n'.join(code_blocks)
return content # Return original content if no code blocks found
def _infer_file_name(self, content: str) -> str:
"""
Infer a file name from the content.
:param content: The content to analyze.
:return: A suggested file name.
"""
# Check for title patterns in markdown
title_match = re.search(r'^#\s+(.+)$', content, re.MULTILINE)
if title_match:
# Convert title to a valid filename
title = title_match.group(1).strip()
return self._sanitize_filename(title)
# Check for class/function definitions in code
code_match = re.search(r'(class|def|function)\s+(\w+)', content)
if code_match:
return self._sanitize_filename(code_match.group(2))
# Default name based on content type
if self._is_likely_code(content):
return "code"
elif self._is_likely_markdown(content):
return "document"
elif self._is_likely_json(content):
return "data"
else:
return "output"
def _infer_file_type(self, content: str) -> str:
"""
Infer the file type/extension from the content.
:param content: The content to analyze.
:return: A suggested file extension.
"""
# Check for common programming language patterns
if re.search(r'(import\s+[a-zA-Z0-9_]+|from\s+[a-zA-Z0-9_\.]+\s+import)', content):
return "py" # Python
elif re.search(r'(public\s+class|private\s+class|protected\s+class)', content):
return "java" # Java
elif re.search(r'(function\s+\w+\s*\(|const\s+\w+\s*=|let\s+\w+\s*=|var\s+\w+\s*=)', content):
return "js" # JavaScript
elif re.search(r'(<html|<body|<div|<p>)', content):
return "html" # HTML
elif re.search(r'(#include\s+<\w+\.h>|int\s+main\s*\()', content):
return "cpp" # C/C++
# Check for markdown
if self._is_likely_markdown(content):
return "md"
# Check for JSON
if self._is_likely_json(content):
return "json"
# Default to text
return "txt"
def _is_likely_code(self, content: str) -> bool:
"""Check if the content is likely code."""
# First check for common HTML/XML patterns
if content.strip().startswith(("<!DOCTYPE", "<html", "<?xml", "<head", "<body")):
return True
code_patterns = [
r'(class|def|function|import|from|public|private|protected|#include)',
r'(\{\s*\n|\}\s*\n|\[\s*\n|\]\s*\n)',
r'(if\s*\(|for\s*\(|while\s*\()',
r'(<\w+>.*?</\w+>)', # HTML/XML tags
r'(var|let|const)\s+\w+\s*=', # JavaScript variable declarations
r'#\s*\w+', # CSS ID selectors or Python comments
r'\.\w+\s*\{', # CSS class selectors
r'@media|@import|@font-face' # CSS at-rules
]
return any(re.search(pattern, content) for pattern in code_patterns)
def _is_likely_markdown(self, content: str) -> bool:
"""Check if the content is likely markdown."""
md_patterns = [
r'^#\s+.+$', # Headers
r'^\*\s+.+$', # Unordered lists
r'^\d+\.\s+.+$', # Ordered lists
r'\[.+\]\(.+\)', # Links
r'!\[.+\]\(.+\)' # Images
]
return any(re.search(pattern, content, re.MULTILINE) for pattern in md_patterns)
def _is_likely_json(self, content: str) -> bool:
"""Check if the content is likely JSON."""
try:
content = content.strip()
if (content.startswith('{') and content.endswith('}')) or (
content.startswith('[') and content.endswith(']')):
json.loads(content)
return True
except:
pass
return False
def _sanitize_filename(self, name: str) -> str:
"""
Sanitize a string to be used as a filename.
:param name: The string to sanitize.
:return: A sanitized filename.
"""
# Replace spaces with underscores
name = name.replace(' ', '_')
# Remove invalid characters
name = re.sub(r'[^\w\-\.]', '', name)
# Limit length
if len(name) > 50:
name = name[:50]
return name.lower()
def _process_file_path(self, file_path: str) -> Tuple[str, str]:
"""
Process a file path to extract the file name and type, and create directories if needed.
:param file_path: The file path to process
:return: Tuple of (file_name, file_type)
"""
# Get the file name and extension
file_name = os.path.basename(file_path)
file_type = os.path.splitext(file_name)[1].lstrip('.')
return os.path.splitext(file_name)[0], file_type

View File

@@ -0,0 +1,3 @@
from .find import Find
__all__ = ['Find']

177
agent/tools/find/find.py Normal file
View File

@@ -0,0 +1,177 @@
"""
Find tool - Search for files by glob pattern
"""
import os
import glob as glob_module
from typing import Dict, Any, List
from agent.tools.base_tool import BaseTool, ToolResult
from agent.tools.utils.truncate import truncate_head, format_size, DEFAULT_MAX_BYTES
DEFAULT_LIMIT = 1000
class Find(BaseTool):
    """
    Tool for finding files by glob pattern.

    Matches are collected with ``glob`` below a search directory, filtered
    against that directory's ``.gitignore`` plus a few built-in ignore
    patterns, and returned as paths relative to the search directory.
    """

    name: str = "find"
    description: str = f"Search for files by glob pattern. Returns matching file paths relative to the search directory. Respects .gitignore. Output is truncated to {DEFAULT_LIMIT} results or {DEFAULT_MAX_BYTES // 1024}KB (whichever is hit first)."

    params: dict = {
        "type": "object",
        "properties": {
            "pattern": {
                "type": "string",
                "description": "Glob pattern to match files, e.g. '*.ts', '**/*.json', or 'src/**/*.spec.ts'"
            },
            "path": {
                "type": "string",
                "description": "Directory to search in (default: current directory)"
            },
            "limit": {
                "type": "integer",
                "description": f"Maximum number of results (default: {DEFAULT_LIMIT})"
            }
        },
        "required": ["pattern"]
    }

    def __init__(self, config: dict = None):
        """
        :param config: Optional settings; ``cwd`` is the base directory for
            relative search paths (defaults to the process working directory)
        """
        self.config = config or {}
        self.cwd = self.config.get("cwd", os.getcwd())

    def execute(self, args: Dict[str, Any]) -> ToolResult:
        """
        Execute file search

        :param args: Search parameters (``pattern``, optional ``path`` and ``limit``)
        :return: Search results or error
        """
        # ``or`` guards against explicit nulls, which would crash on .strip()
        pattern = (args.get("pattern") or "").strip()
        search_path = (args.get("path") or ".").strip()
        limit = args.get("limit", DEFAULT_LIMIT)
        if not isinstance(limit, int) or isinstance(limit, bool) or limit <= 0:
            limit = DEFAULT_LIMIT

        if not pattern:
            return ToolResult.fail("Error: pattern parameter is required")

        # Resolve search path
        absolute_path = self._resolve_path(search_path)

        if not os.path.exists(absolute_path):
            return ToolResult.fail(f"Error: Path not found: {search_path}")

        if not os.path.isdir(absolute_path):
            return ToolResult.fail(f"Error: Not a directory: {search_path}")

        try:
            ignore_patterns = self._load_gitignore(absolute_path)

            results = []
            result_limit_reached = False
            search_pattern = os.path.join(absolute_path, pattern)

            # iglob is lazy, so hitting the limit stops the directory walk early
            for file_path in glob_module.iglob(search_pattern, recursive=True):
                relative_path = os.path.relpath(file_path, absolute_path)

                # '**' patterns also yield the search root itself; it is not a result
                if relative_path == '.':
                    continue

                if self._should_ignore(file_path, absolute_path, ignore_patterns):
                    continue

                # Only flag the limit when there really is one more match
                if len(results) >= limit:
                    result_limit_reached = True
                    break

                # Add trailing slash for directories
                if os.path.isdir(file_path):
                    relative_path += '/'

                results.append(relative_path)

            if not results:
                return ToolResult.success({"message": "No files found matching pattern", "files": []})

            # The limit is applied in glob order; sorting only orders what was kept
            results.sort()

            raw_output = '\n'.join(results)
            # Line cap set high so that, per the original intent, only the byte cap applies
            truncation = truncate_head(raw_output, max_lines=999999)

            output = truncation.content
            details = {}
            notices = []

            if result_limit_reached:
                notices.append(f"{limit} results limit reached. Use limit={limit * 2} for more, or refine pattern")
                details["result_limit_reached"] = limit

            if truncation.truncated:
                notices.append(f"{format_size(DEFAULT_MAX_BYTES)} limit reached")
                details["truncation"] = truncation.to_dict()

            if notices:
                output += f"\n\n[{'. '.join(notices)}]"

            return ToolResult.success({
                "output": output,
                "file_count": len(results),
                "details": details if details else None
            })

        except Exception as e:
            return ToolResult.fail(f"Error executing find: {str(e)}")

    def _resolve_path(self, path: str) -> str:
        """Resolve path to absolute path (expands ``~``; relative paths are based on ``self.cwd``)"""
        path = os.path.expanduser(path)
        if os.path.isabs(path):
            return path
        return os.path.abspath(os.path.join(self.cwd, path))

    def _load_gitignore(self, directory: str) -> List[str]:
        """
        Load ignore patterns for ``directory``.

        :param directory: Directory whose ``.gitignore`` (if any) is read
        :return: Non-comment, non-blank ``.gitignore`` lines followed by the
            built-in patterns
        """
        patterns = []
        gitignore_path = os.path.join(directory, '.gitignore')

        if os.path.exists(gitignore_path):
            try:
                with open(gitignore_path, 'r', encoding='utf-8') as f:
                    for line in f:
                        line = line.strip()
                        if line and not line.startswith('#'):
                            patterns.append(line)
            except (OSError, UnicodeDecodeError):
                # Best effort: an unreadable .gitignore just means no extra patterns
                pass

        # Add common ignore patterns
        patterns.extend([
            '.git',
            '__pycache__',
            '*.pyc',
            'node_modules',
            '.DS_Store'
        ])

        return patterns

    def _should_ignore(self, file_path: str, base_path: str, patterns: List[str]) -> bool:
        """
        Check if file should be ignored based on patterns.

        This is an approximation of gitignore matching built on ``fnmatch``:
        patterns without a slash are tested against every path component,
        patterns containing a slash against every leading sub-path, and a
        trailing ``/`` restricts a pattern to directories. Negation patterns
        (``!...``) are not supported and are skipped. Unlike git, ``*`` may
        also match across ``/``.

        :param file_path: Path of the candidate entry
        :param base_path: Search root the patterns are relative to
        :param patterns: Patterns as returned by ``_load_gitignore``
        :return: True when the entry or one of its parent directories matches
        """
        import fnmatch

        relative_path = os.path.relpath(file_path, base_path).replace(os.sep, '/')
        parts = relative_path.split('/')
        is_dir = os.path.isdir(file_path)

        for raw_pattern in patterns:
            if raw_pattern.startswith('!'):
                continue

            dir_only = raw_pattern.endswith('/')
            pattern = raw_pattern.strip('/')
            if not pattern:
                continue
            anchored = raw_pattern.startswith('/') or '/' in pattern

            # A directory-only pattern may match any ancestor, but the entry
            # itself only when it is a directory
            depth = len(parts) if (is_dir or not dir_only) else len(parts) - 1

            if anchored:
                candidates = ['/'.join(parts[:i]) for i in range(1, depth + 1)]
            else:
                candidates = parts[:depth]

            if any(fnmatch.fnmatchcase(candidate, pattern) for candidate in candidates):
                return True

        return False

View File

@@ -0,0 +1,48 @@
import requests
from agent.tools.base_tool import BaseTool, ToolResult
class GoogleSearch(BaseTool):
    """
    Web search tool backed by the Serper API (https://google.serper.dev).

    Config keys read by this class: ``api_key`` (required for searching) and
    ``timeout`` (optional request timeout in seconds, default 30).
    """

    name: str = "google_search"
    description: str = "A tool to perform Google searches using the Serper API."
    params: dict = {
        "type": "object",
        "properties": {
            "query": {
                "type": "string",
                "description": "The search query to perform."
            }
        },
        "required": ["query"]
    }
    config: dict = {}

    def __init__(self, config=None):
        """
        :param config: Optional settings dict (see class docstring)
        """
        self.config = config or {}

    def execute(self, args: dict) -> ToolResult:
        """
        Run a search and return the organic results.

        :param args: Dictionary with the ``query`` string
        :return: Success with the list found under the response's ``organic``
            key (empty list when absent), or a failure for missing
            configuration, transport errors, non-JSON or error responses
        """
        api_key = self.config.get("api_key")
        if not api_key:
            return ToolResult.fail("Error: Serper api_key is not configured")

        query = args.get("query")
        if not query:
            return ToolResult.fail("Error: query parameter is required")

        url = "https://google.serper.dev/search"
        headers = {
            "X-API-KEY": api_key,
            "Content-Type": "application/json"
        }
        data = {
            "q": query,
            # NOTE(review): Serper's documented result-count field appears to
            # be "num"; "k" is kept as in the original - confirm.
            "k": 10
        }

        try:
            # Without a timeout a stalled connection would block the agent forever
            response = requests.post(url, headers=headers, json=data,
                                     timeout=self.config.get("timeout", 30))
        except requests.RequestException as e:
            return ToolResult.fail(f"Error calling Serper API: {str(e)}")

        try:
            result = response.json()
        except ValueError:
            return ToolResult.fail(
                f"Error: Serper API returned a non-JSON response (HTTP {response.status_code})"
            )

        if not isinstance(result, dict):
            return ToolResult.fail(result=result)

        # Surface HTTP-level errors (bad key, quota, ...) instead of reporting
        # them as an empty result list; the body-level 503 check is kept.
        if response.status_code >= 400 or result.get("statusCode") == 503:
            return ToolResult.fail(result=result)

        organic = result.get("organic")
        result_data = organic if isinstance(organic, list) else []
        return ToolResult.success(result=result_data)

View File

@@ -0,0 +1,3 @@
from .grep import Grep
__all__ = ['Grep']

248
agent/tools/grep/grep.py Normal file
View File

@@ -0,0 +1,248 @@
"""
Grep tool - Search file contents for patterns
Uses ripgrep (rg) for fast searching
"""
import os
import re
import subprocess
import json
from typing import Dict, Any, List, Optional
from agent.tools.base_tool import BaseTool, ToolResult
from agent.tools.utils.truncate import (
truncate_head, truncate_line, format_size,
DEFAULT_MAX_BYTES, GREP_MAX_LINE_LENGTH
)
DEFAULT_LIMIT = 100
class Grep(BaseTool):
    """
    Tool for searching file contents.

    The search itself is delegated to ripgrep (``rg --json``); matching lines
    and optional context lines are then read back from the files to build the
    output.
    """

    name: str = "grep"
    description: str = f"Search file contents for a pattern. Returns matching lines with file paths and line numbers. Respects .gitignore. Output is truncated to {DEFAULT_LIMIT} matches or {DEFAULT_MAX_BYTES // 1024}KB (whichever is hit first). Long lines are truncated to {GREP_MAX_LINE_LENGTH} chars."

    params: dict = {
        "type": "object",
        "properties": {
            "pattern": {
                "type": "string",
                "description": "Search pattern (regex or literal string)"
            },
            "path": {
                "type": "string",
                "description": "Directory or file to search (default: current directory)"
            },
            "glob": {
                "type": "string",
                "description": "Filter files by glob pattern, e.g. '*.ts' or '**/*.spec.ts'"
            },
            "ignoreCase": {
                "type": "boolean",
                "description": "Case-insensitive search (default: false)"
            },
            "literal": {
                "type": "boolean",
                "description": "Treat pattern as literal string instead of regex (default: false)"
            },
            "context": {
                "type": "integer",
                "description": "Number of lines to show before and after each match (default: 0)"
            },
            "limit": {
                "type": "integer",
                "description": f"Maximum number of matches to return (default: {DEFAULT_LIMIT})"
            }
        },
        "required": ["pattern"]
    }

    def __init__(self, config: dict = None):
        """
        :param config: Optional settings; ``cwd`` is the base directory for
            relative search paths (defaults to the process working directory)
        """
        self.config = config or {}
        self.cwd = self.config.get("cwd", os.getcwd())
        self.rg_path = self._find_ripgrep()

    def _find_ripgrep(self) -> Optional[str]:
        """
        Find ripgrep executable.

        :return: Path of ``rg`` on PATH, or None when it is not installed
        """
        # shutil.which is portable; spawning `which` only works on Unix-likes
        import shutil
        return shutil.which('rg')

    def execute(self, args: Dict[str, Any]) -> ToolResult:
        """
        Execute grep search

        :param args: Search parameters
        :return: Search results or error
        """
        if not self.rg_path:
            return ToolResult.fail("Error: ripgrep (rg) is not installed. Please install it first.")

        # ``or`` guards against explicit nulls in the arguments
        pattern = (args.get("pattern") or "").strip()
        search_path = (args.get("path") or ".").strip()
        glob_filter = args.get("glob")
        ignore_case = args.get("ignoreCase", False)
        literal = args.get("literal", False)
        context = max(0, args.get("context") or 0)
        limit = args.get("limit") or DEFAULT_LIMIT
        if limit <= 0:
            limit = DEFAULT_LIMIT

        if not pattern:
            return ToolResult.fail("Error: pattern parameter is required")

        # Resolve search path
        absolute_path = self._resolve_path(search_path)

        if not os.path.exists(absolute_path):
            return ToolResult.fail(f"Error: Path not found: {search_path}")

        # Build ripgrep command
        cmd = [
            self.rg_path,
            '--json',
            '--line-number',
            '--color=never',
            '--hidden'
        ]
        if ignore_case:
            cmd.append('--ignore-case')
        if literal:
            cmd.append('--fixed-strings')
        if glob_filter:
            cmd.extend(['--glob', glob_filter])
        # '--' ends option parsing so a pattern such as "-foo" is not read as a flag
        cmd.extend(['--', pattern, absolute_path])

        try:
            result = subprocess.run(
                cmd,
                cwd=self.cwd,
                capture_output=True,
                text=True,
                encoding='utf-8',
                timeout=30
            )

            matches = self._parse_matches(result.stdout, limit)
            match_count = len(matches)

            if match_count == 0:
                # ripgrep exits 0 on matches and 1 on none; anything else is an
                # error (e.g. an invalid regex) that must not read as "no matches"
                if result.returncode not in (0, 1):
                    message = (result.stderr or "").strip() or f"ripgrep exited with code {result.returncode}"
                    return ToolResult.fail(f"Error executing grep: {message}")
                return ToolResult.success({"message": "No matches found", "matches": []})

            output_lines, lines_truncated = self._format_matches(matches, absolute_path, context)

            # Apply byte truncation
            raw_output = '\n'.join(output_lines)
            # Line cap set high so that, per the original intent, only the byte cap applies
            truncation = truncate_head(raw_output, max_lines=999999)

            output = truncation.content
            details = {}
            notices = []

            if match_count >= limit:
                notices.append(f"{limit} matches limit reached. Use limit={limit * 2} for more, or refine pattern")
                details["match_limit_reached"] = limit

            if truncation.truncated:
                notices.append(f"{format_size(DEFAULT_MAX_BYTES)} limit reached")
                details["truncation"] = truncation.to_dict()

            if lines_truncated:
                notices.append(f"Some lines truncated to {GREP_MAX_LINE_LENGTH} chars. Use read tool to see full lines")
                details["lines_truncated"] = True

            if notices:
                output += f"\n\n[{'. '.join(notices)}]"

            return ToolResult.success({
                "output": output,
                "match_count": match_count,
                "details": details if details else None
            })

        except subprocess.TimeoutExpired:
            return ToolResult.fail("Error: Search timed out after 30 seconds")
        except Exception as e:
            return ToolResult.fail(f"Error executing grep: {str(e)}")

    def _parse_matches(self, stdout: str, limit: int) -> List[Dict[str, Any]]:
        """Collect up to ``limit`` ``{'file', 'line'}`` entries from ripgrep's JSON-lines output."""
        matches = []
        for line in stdout.splitlines():
            if not line.strip():
                continue
            try:
                event = json.loads(line)
            except json.JSONDecodeError:
                continue
            if event.get('type') != 'match':
                continue
            data = event.get('data', {})
            file_path = data.get('path', {}).get('text')
            line_number = data.get('line_number')
            if file_path and line_number:
                matches.append({'file': file_path, 'line': line_number})
                if len(matches) >= limit:
                    break
        return matches

    def _read_lines(self, file_path: str) -> Optional[List[str]]:
        """Return the file's lines split on '\\n' only, or None when it cannot be read as UTF-8."""
        try:
            # newline='' disables universal-newline translation so that line
            # numbers stay aligned with ripgrep's; stray '\r' is stripped later
            with open(file_path, 'r', encoding='utf-8', newline='') as f:
                return f.read().split('\n')
        except (OSError, UnicodeDecodeError):
            return None

    def _format_matches(self, matches: List[Dict[str, Any]], absolute_path: str, context: int):
        """
        Render matches (plus ``context`` surrounding lines) as output lines.

        :return: Tuple of (output lines, whether any line was shortened)
        """
        output_lines = []
        lines_truncated = False
        is_directory = os.path.isdir(absolute_path)
        # Each file is read once even when it contains many matches
        file_cache: Dict[str, Optional[List[str]]] = {}

        for match in matches:
            file_path = match['file']
            line_number = match['line']

            if is_directory:
                relative_path = os.path.relpath(file_path, absolute_path)
            else:
                relative_path = os.path.basename(file_path)

            if file_path not in file_cache:
                file_cache[file_path] = self._read_lines(file_path)
            file_lines = file_cache[file_path]

            if file_lines is None or line_number > len(file_lines):
                output_lines.append(f"{relative_path}:{line_number}: (unable to read file)")
                continue

            start = max(0, line_number - 1 - context)
            end = min(len(file_lines), line_number + context)

            for i in range(start, end):
                line_text = file_lines[i].replace('\r', '')

                truncated_text, was_truncated = truncate_line(line_text)
                if was_truncated:
                    lines_truncated = True

                # "path:N: text" marks the match itself, "path-N- text" a context line
                current_line = i + 1
                if current_line == line_number:
                    output_lines.append(f"{relative_path}:{current_line}: {truncated_text}")
                else:
                    output_lines.append(f"{relative_path}-{current_line}- {truncated_text}")

        return output_lines, lines_truncated

    def _resolve_path(self, path: str) -> str:
        """Resolve path to absolute path (expands ``~``; relative paths are based on ``self.cwd``)"""
        path = os.path.expanduser(path)
        if os.path.isabs(path):
            return path
        return os.path.abspath(os.path.join(self.cwd, path))

View File

@@ -0,0 +1,3 @@
from .ls import Ls
__all__ = ['Ls']

125
agent/tools/ls/ls.py Normal file
View File

@@ -0,0 +1,125 @@
"""
Ls tool - List directory contents
"""
import os
from typing import Dict, Any
from agent.tools.base_tool import BaseTool, ToolResult
from agent.tools.utils.truncate import truncate_head, format_size, DEFAULT_MAX_BYTES
DEFAULT_LIMIT = 500
class Ls(BaseTool):
    """
    Tool for listing directory contents.

    Entries come from ``os.listdir`` (so dotfiles are included), are sorted
    case-insensitively and directories get a trailing ``/``.
    """

    name: str = "ls"
    description: str = f"List directory contents. Returns entries sorted alphabetically, with '/' suffix for directories. Includes dotfiles. Output is truncated to {DEFAULT_LIMIT} entries or {DEFAULT_MAX_BYTES // 1024}KB (whichever is hit first)."

    params: dict = {
        "type": "object",
        "properties": {
            "path": {
                "type": "string",
                "description": "Directory to list (default: current directory)"
            },
            "limit": {
                "type": "integer",
                "description": f"Maximum number of entries to return (default: {DEFAULT_LIMIT})"
            }
        },
        "required": []
    }

    def __init__(self, config: dict = None):
        """
        :param config: Optional settings; ``cwd`` is the base directory for
            relative paths (defaults to the process working directory)
        """
        self.config = config or {}
        self.cwd = self.config.get("cwd", os.getcwd())

    def execute(self, args: Dict[str, Any]) -> ToolResult:
        """
        Execute directory listing

        :param args: Listing parameters (optional ``path`` and ``limit``)
        :return: Directory contents or error
        """
        # ``or`` guards against an explicit null path, which would crash on .strip()
        path = (args.get("path") or ".").strip()
        limit = args.get("limit", DEFAULT_LIMIT)
        if not isinstance(limit, int) or isinstance(limit, bool) or limit <= 0:
            limit = DEFAULT_LIMIT

        # Resolve path
        absolute_path = self._resolve_path(path)

        if not os.path.exists(absolute_path):
            return ToolResult.fail(f"Error: Path not found: {path}")

        if not os.path.isdir(absolute_path):
            return ToolResult.fail(f"Error: Not a directory: {path}")

        try:
            # Sort alphabetically (case-insensitive)
            entries = sorted(os.listdir(absolute_path), key=str.lower)

            entry_limit_reached = len(entries) > limit
            # os.path.isdir reports False instead of raising when an entry
            # cannot be stat'ed, so no per-entry error handling is needed
            results = [
                entry + '/' if os.path.isdir(os.path.join(absolute_path, entry)) else entry
                for entry in entries[:limit]
            ]

            if not results:
                return ToolResult.success({"message": "(empty directory)", "entries": []})

            raw_output = '\n'.join(results)
            # Line cap set high so that, per the original intent, only the byte cap applies
            truncation = truncate_head(raw_output, max_lines=999999)

            output = truncation.content
            details = {}
            notices = []

            if entry_limit_reached:
                notices.append(f"{limit} entries limit reached. Use limit={limit * 2} for more")
                details["entry_limit_reached"] = limit

            if truncation.truncated:
                notices.append(f"{format_size(DEFAULT_MAX_BYTES)} limit reached")
                details["truncation"] = truncation.to_dict()

            if notices:
                output += f"\n\n[{'. '.join(notices)}]"

            return ToolResult.success({
                "output": output,
                "entry_count": len(results),
                "details": details if details else None
            })

        except PermissionError:
            return ToolResult.fail(f"Error: Permission denied reading directory: {path}")
        except Exception as e:
            return ToolResult.fail(f"Error listing directory: {str(e)}")

    def _resolve_path(self, path: str) -> str:
        """Resolve path to absolute path (expands ``~``; relative paths are based on ``self.cwd``)"""
        path = os.path.expanduser(path)
        if os.path.isabs(path):
            return path
        return os.path.abspath(os.path.join(self.cwd, path))

View File

@@ -0,0 +1,10 @@
"""
Memory tools for Agent
Provides memory_search and memory_get tools
"""
from agent.tools.memory.memory_search import MemorySearchTool
from agent.tools.memory.memory_get import MemoryGetTool
__all__ = ['MemorySearchTool', 'MemoryGetTool']

View File

@@ -0,0 +1,107 @@
"""
Memory get tool
Allows agents to read specific sections from memory files
"""
from typing import Dict, Any
from pathlib import Path
from agent.tools.base_tool import BaseTool
class MemoryGetTool(BaseTool):
    """Tool for reading memory file contents from the memory workspace"""

    name: str = "memory_get"
    description: str = (
        "Read specific content from memory files. "
        "Use this to get full context from a memory file or specific line range."
    )
    params: dict = {
        "type": "object",
        "properties": {
            "path": {
                "type": "string",
                "description": "Relative path to the memory file (e.g., 'MEMORY.md', 'memory/2024-01-29.md')"
            },
            "start_line": {
                "type": "integer",
                "description": "Starting line number (optional, default: 1)",
                "default": 1
            },
            "num_lines": {
                "type": "integer",
                "description": "Number of lines to read (optional, reads all if not specified)"
            }
        },
        "required": ["path"]
    }

    def __init__(self, memory_manager):
        """
        Initialize memory get tool

        Args:
            memory_manager: MemoryManager instance; its
                ``config.get_workspace()`` supplies the workspace directory
        """
        super().__init__()
        self.memory_manager = memory_manager

    def execute(self, args: dict):
        """
        Execute memory file read

        Args:
            args: Dictionary with path, start_line, num_lines

        Returns:
            ToolResult with a small header (file, line range, total lines)
            followed by the selected lines, or a failure result
        """
        from agent.tools.base_tool import ToolResult

        path = args.get("path")
        start_line = args.get("start_line", 1)
        num_lines = args.get("num_lines")

        if not path:
            return ToolResult.fail("Error: path parameter is required")

        try:
            workspace_dir = Path(self.memory_manager.config.get_workspace()).resolve()
            file_path = (workspace_dir / path).resolve()

            # The path is documented as relative to the workspace; refuse
            # anything ('../', absolute paths) that resolves outside of it.
            # NOTE(review): resolve() follows symlinks, so a symlink inside the
            # workspace that points elsewhere is rejected too - confirm this
            # is acceptable.
            try:
                file_path.relative_to(workspace_dir)
            except ValueError:
                return ToolResult.fail(f"Error: Path is outside the memory workspace: {path}")

            if not file_path.is_file():
                return ToolResult.fail(f"Error: File not found: {path}")

            # Explicit encoding: the platform default is not UTF-8 everywhere
            content = file_path.read_text(encoding='utf-8')
            lines = content.split('\n')
            total_lines = len(lines)

            # Handle line range
            if not isinstance(start_line, int) or start_line < 1:
                start_line = 1
            start_idx = start_line - 1

            if start_idx >= total_lines:
                return ToolResult.fail(
                    f"Error: start_line {start_line} is beyond end of file ({total_lines} lines total)"
                )

            if num_lines and num_lines > 0:
                selected_lines = lines[start_idx:start_idx + num_lines]
            else:
                # Missing, zero or negative num_lines: read to the end
                selected_lines = lines[start_idx:]

            result = '\n'.join(selected_lines)

            # Add metadata
            shown_lines = len(selected_lines)

            output = [
                f"File: {path}",
                f"Lines: {start_line}-{start_line + shown_lines - 1} (total: {total_lines})",
                "",
                result
            ]

            return ToolResult.success('\n'.join(output))

        except Exception as e:
            return ToolResult.fail(f"Error reading memory file: {str(e)}")

View File

@@ -0,0 +1,96 @@
"""
Memory search tool
Allows agents to search their memory using semantic and keyword search
"""
from typing import Dict, Any, Optional
from agent.tools.base_tool import BaseTool
class MemorySearchTool(BaseTool):
    """Tool for searching agent memory through the memory manager's async ``search``"""

    name: str = "memory_search"
    description: str = (
        "Search agent's long-term memory using semantic and keyword search. "
        "Use this to recall past conversations, preferences, and knowledge."
    )
    params: dict = {
        "type": "object",
        "properties": {
            "query": {
                "type": "string",
                "description": "Search query (can be natural language question or keywords)"
            },
            "max_results": {
                "type": "integer",
                "description": "Maximum number of results to return (default: 10)",
                "default": 10
            },
            "min_score": {
                "type": "number",
                "description": "Minimum relevance score (0-1, default: 0.3)",
                "default": 0.3
            }
        },
        "required": ["query"]
    }

    def __init__(self, memory_manager, user_id: Optional[str] = None):
        """
        Initialize memory search tool

        Args:
            memory_manager: MemoryManager instance
            user_id: Optional user ID for scoped search
        """
        super().__init__()
        self.memory_manager = memory_manager
        self.user_id = user_id

    @staticmethod
    def _run_coroutine(coro):
        """
        Run ``coro`` to completion from synchronous code and return its result.

        ``asyncio.run`` refuses to start inside a thread that already has a
        running event loop; in that case the coroutine is run on a fresh loop
        in a short-lived worker thread instead.
        """
        import asyncio
        import concurrent.futures

        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No loop running in this thread: the plain path is safe
            return asyncio.run(coro)

        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
            return pool.submit(asyncio.run, coro).result()

    def execute(self, args: dict):
        """
        Execute memory search

        Args:
            args: Dictionary with query, max_results, min_score

        Returns:
            ToolResult with formatted search results
        """
        from agent.tools.base_tool import ToolResult

        query = args.get("query")
        # Explicit nulls fall back to the documented defaults
        max_results = args.get("max_results")
        if max_results is None:
            max_results = 10
        min_score = args.get("min_score")
        if min_score is None:
            min_score = 0.3

        if not query:
            return ToolResult.fail("Error: query parameter is required")

        try:
            # Run async search in sync context
            results = self._run_coroutine(self.memory_manager.search(
                query=query,
                user_id=self.user_id,
                max_results=max_results,
                min_score=min_score,
                include_shared=True
            ))

            if not results:
                return ToolResult.success(f"No relevant memories found for query: {query}")

            # Format results
            output = [f"Found {len(results)} relevant memories:\n"]

            for i, result in enumerate(results, 1):
                output.append(f"\n{i}. {result.path} (lines {result.start_line}-{result.end_line})")
                output.append(f"   Score: {result.score:.3f}")
                output.append(f"   Snippet: {result.snippet}")

            return ToolResult.success("\n".join(output))

        except Exception as e:
            return ToolResult.fail(f"Error searching memory: {str(e)}")

View File

@@ -0,0 +1,3 @@
from .read import Read
__all__ = ['Read']

336
agent/tools/read/read.py Normal file
View File

@@ -0,0 +1,336 @@
"""
Read tool - Read file contents
Supports text files, images (jpg, png, gif, webp), and PDF files
"""
import os
from typing import Dict, Any
from pathlib import Path
from agent.tools.base_tool import BaseTool, ToolResult
from agent.tools.utils.truncate import truncate_head, format_size, DEFAULT_MAX_LINES, DEFAULT_MAX_BYTES
class Read(BaseTool):
"""Tool for reading file contents"""
name: str = "read"
description: str = f"Read the contents of a file. Supports text files, PDF files, and images (jpg, png, gif, webp). For text files, output is truncated to {DEFAULT_MAX_LINES} lines or {DEFAULT_MAX_BYTES // 1024}KB (whichever is hit first). Use offset/limit for large files."
params: dict = {
"type": "object",
"properties": {
"path": {
"type": "string",
"description": "Path to the file to read (relative or absolute)"
},
"offset": {
"type": "integer",
"description": "Line number to start reading from (1-indexed, optional)"
},
"limit": {
"type": "integer",
"description": "Maximum number of lines to read (optional)"
}
},
"required": ["path"]
}
def __init__(self, config: dict = None):
self.config = config or {}
self.cwd = self.config.get("cwd", os.getcwd())
# Supported image formats
self.image_extensions = {'.jpg', '.jpeg', '.png', '.gif', '.webp'}
# Supported PDF format
self.pdf_extensions = {'.pdf'}
def execute(self, args: Dict[str, Any]) -> ToolResult:
    """
    Execute file read operation

    The file is dispatched by extension to the image, PDF or text reader.

    :param args: Contains file path and optional offset/limit parameters
    :return: File content or error message
    """
    # ``or`` guards against an explicit null path, which would crash on .strip()
    path = (args.get("path") or "").strip()
    offset = args.get("offset")
    limit = args.get("limit")

    if not path:
        return ToolResult.fail("Error: path parameter is required")

    # Resolve path
    absolute_path = self._resolve_path(path)

    # Check if file exists
    if not os.path.exists(absolute_path):
        return ToolResult.fail(f"Error: File not found: {path}")

    # A directory would otherwise only fail later, inside open(), with an
    # opaque errno message
    if os.path.isdir(absolute_path):
        return ToolResult.fail(f"Error: Path is a directory, not a file: {path}")

    # Check if readable
    if not os.access(absolute_path, os.R_OK):
        return ToolResult.fail(f"Error: File is not readable: {path}")

    # Dispatch on the lower-cased extension
    file_ext = Path(absolute_path).suffix.lower()

    if file_ext in self.image_extensions:
        return self._read_image(absolute_path, file_ext)

    if file_ext in self.pdf_extensions:
        return self._read_pdf(absolute_path, path, offset, limit)

    # Everything else is treated as text
    return self._read_text(absolute_path, path, offset, limit)
def _resolve_path(self, path: str) -> str:
"""
Resolve path to absolute path
:param path: Relative or absolute path
:return: Absolute path
"""
# Expand ~ to user home directory
path = os.path.expanduser(path)
if os.path.isabs(path):
return path
return os.path.abspath(os.path.join(self.cwd, path))
def _read_image(self, absolute_path: str, file_ext: str) -> ToolResult:
    """
    Load an image file and return it base64-encoded.

    :param absolute_path: Absolute path to the image file
    :param file_ext: File extension, used to pick the MIME type
        (unknown extensions fall back to ``image/jpeg``)
    :return: Result with type, MIME type, size (raw and formatted) and the
        base64 data, or a failure result if anything goes wrong
    """
    mime_types = {
        '.jpg': 'image/jpeg',
        '.jpeg': 'image/jpeg',
        '.png': 'image/png',
        '.gif': 'image/gif',
        '.webp': 'image/webp'
    }
    try:
        import base64

        with open(absolute_path, 'rb') as image_file:
            raw_bytes = image_file.read()

        byte_count = len(raw_bytes)
        encoded = base64.b64encode(raw_bytes).decode('utf-8')

        return ToolResult.success({
            "type": "image",
            "mime_type": mime_types.get(file_ext, 'image/jpeg'),
            "size": byte_count,
            "size_formatted": format_size(byte_count),
            "data": encoded  # Base64 encoded image data
        })
    except Exception as e:
        return ToolResult.fail(f"Error reading image file: {str(e)}")
def _read_text(self, absolute_path: str, display_path: str, offset: int = None, limit: int = None) -> ToolResult:
    """
    Read text file

    The requested line window is selected first, then passed through
    ``truncate_head``; the result carries a trailing bracketed notice telling
    the caller how to continue when not everything could be shown.

    :param absolute_path: Absolute path to the file
    :param display_path: Path to display
    :param offset: Starting line number (1-indexed)
    :param limit: Maximum number of lines to read
    :return: File content or error message
    """
    import shlex

    try:
        with open(absolute_path, 'r', encoding='utf-8') as f:
            content = f.read()

        all_lines = content.split('\n')
        total_file_lines = len(all_lines)

        # Apply offset (if specified)
        start_line = 0
        if offset is not None:
            start_line = max(0, offset - 1)  # Convert to 0-indexed
            if start_line >= total_file_lines:
                return ToolResult.fail(
                    f"Error: Offset {offset} is beyond end of file ({total_file_lines} lines total)"
                )

        start_line_display = start_line + 1  # For display (1-indexed)

        # If user specified limit, use it
        selected_content = content
        user_limited_lines = None
        if limit is not None:
            # A negative limit would produce a negative line count below
            limit = max(0, limit)
            end_line = min(start_line + limit, total_file_lines)
            selected_content = '\n'.join(all_lines[start_line:end_line])
            user_limited_lines = end_line - start_line
        elif offset is not None:
            selected_content = '\n'.join(all_lines[start_line:])

        # Apply truncation (considering line count and byte limits)
        truncation = truncate_head(selected_content)

        output_text = ""
        details = {}

        if truncation.first_line_exceeds_limit:
            # The first selected line alone exceeds the byte limit
            first_line_size = format_size(len(all_lines[start_line].encode('utf-8')))
            # `sed -n 'Np'` prints exactly line N and `head -c` caps its size.
            # The previous `head -c ... | tail -n +N` cut the file first, which
            # does not retrieve line N when N > 1. The absolute path is quoted
            # so the command survives spaces and does not depend on the cwd.
            output_text = f"[Line {start_line_display} is {first_line_size}, exceeds {format_size(DEFAULT_MAX_BYTES)} limit. Use bash tool to read: sed -n '{start_line_display}p' {shlex.quote(absolute_path)} | head -c {DEFAULT_MAX_BYTES}]"
            details["truncation"] = truncation.to_dict()
        elif truncation.truncated:
            # Truncation occurred
            end_line_display = start_line_display + truncation.output_lines - 1
            next_offset = end_line_display + 1

            output_text = truncation.content

            if truncation.truncated_by == "lines":
                output_text += f"\n\n[Showing lines {start_line_display}-{end_line_display} of {total_file_lines}. Use offset={next_offset} to continue.]"
            else:
                output_text += f"\n\n[Showing lines {start_line_display}-{end_line_display} of {total_file_lines} ({format_size(DEFAULT_MAX_BYTES)} limit). Use offset={next_offset} to continue.]"

            details["truncation"] = truncation.to_dict()
        elif user_limited_lines is not None and start_line + user_limited_lines < total_file_lines:
            # User specified limit, more content available, but no truncation
            remaining = total_file_lines - (start_line + user_limited_lines)
            next_offset = start_line + user_limited_lines + 1

            output_text = truncation.content
            output_text += f"\n\n[{remaining} more lines in file. Use offset={next_offset} to continue.]"
        else:
            # No truncation, no exceeding user limit
            output_text = truncation.content

        result = {
            "content": output_text,
            "total_lines": total_file_lines,
            "start_line": start_line_display,
            "output_lines": truncation.output_lines
        }

        if details:
            result["details"] = details

        return ToolResult.success(result)

    except UnicodeDecodeError:
        return ToolResult.fail(f"Error: File is not a valid text file (encoding error): {display_path}")
    except Exception as e:
        return ToolResult.fail(f"Error reading file: {str(e)}")
def _read_pdf(self, absolute_path: str, display_path: str, offset: int = None, limit: int = None) -> ToolResult:
    """
    Extract the text of a PDF and return a window of its lines.

    Every page with non-blank text is prefixed with a ``--- Page N ---`` marker,
    the pages are joined into one string, and the same offset/limit/truncation
    rules used for plain text files are applied to the resulting lines.

    :param absolute_path: Absolute path of the PDF handed to ``PdfReader``
    :param display_path: Path to display (not used when building the result)
    :param offset: Starting line number (1-indexed) within the extracted text
    :param limit: Maximum number of lines to return
    :return: ToolResult with the extracted text and line/page counts, or a failure
    """
    try:
        # pypdf is optional: report how to install it instead of crashing
        try:
            from pypdf import PdfReader
        except ImportError:
            return ToolResult.fail(
                "Error: pypdf library not installed. Install with: pip install pypdf"
            )

        pdf = PdfReader(absolute_path)
        page_count = len(pdf.pages)

        # Keep only pages that produced some non-whitespace text
        sections = []
        for number, page in enumerate(pdf.pages, 1):
            extracted = page.extract_text()
            if not extracted.strip():
                continue
            sections.append(f"--- Page {number} ---\n{extracted}")

        if not sections:
            return ToolResult.success({
                "content": f"[PDF file with {page_count} pages, but no text content could be extracted]",
                "total_pages": page_count,
                "message": "PDF may contain only images or be encrypted"
            })

        merged = "\n\n".join(sections)
        lines = merged.split('\n')
        line_count = len(lines)

        # Translate the 1-indexed offset into a 0-indexed start position
        first_idx = 0
        if offset is not None:
            first_idx = max(0, offset - 1)
            if first_idx >= line_count:
                return ToolResult.fail(
                    f"Error: Offset {offset} is beyond end of content ({line_count} lines total)"
                )
        first_display = first_idx + 1

        # Select the requested window of lines
        requested_lines = None
        if limit is not None:
            stop_idx = min(first_idx + limit, line_count)
            window = '\n'.join(lines[first_idx:stop_idx])
            requested_lines = stop_idx - first_idx
        elif offset is not None:
            window = '\n'.join(lines[first_idx:])
        else:
            window = merged

        # Enforce the shared line/byte caps on the selected window
        truncation = truncate_head(window)
        details = {}

        if truncation.truncated:
            last_display = first_display + truncation.output_lines - 1
            resume_at = last_display + 1
            if truncation.truncated_by == "lines":
                notice = f"\n\n[Showing lines {first_display}-{last_display} of {line_count}. Use offset={resume_at} to continue.]"
            else:
                notice = f"\n\n[Showing lines {first_display}-{last_display} of {line_count} ({format_size(DEFAULT_MAX_BYTES)} limit). Use offset={resume_at} to continue.]"
            output_text = truncation.content + notice
            details["truncation"] = truncation.to_dict()
        elif requested_lines is not None and first_idx + requested_lines < line_count:
            # The caller's limit stopped short of the end: tell them how to continue
            left_over = line_count - (first_idx + requested_lines)
            resume_at = first_idx + requested_lines + 1
            output_text = truncation.content + f"\n\n[{left_over} more lines in file. Use offset={resume_at} to continue.]"
        else:
            output_text = truncation.content

        result = {
            "content": output_text,
            "total_pages": page_count,
            "total_lines": line_count,
            "start_line": first_display,
            "output_lines": truncation.output_lines
        }
        if details:
            result["details"] = details
        return ToolResult.success(result)
    except Exception as e:
        return ToolResult.fail(f"Error reading PDF file: {str(e)}")

View File

@@ -0,0 +1,3 @@
from .terminal import Terminal
__all__ = ['Terminal']

View File

@@ -0,0 +1,100 @@
import platform
import subprocess
from typing import Dict, Any
from agent.tools.base_tool import BaseTool, ToolResult
class Terminal(BaseTool):
    """
    Tool that runs a shell command on the local machine.

    The command string is checked against a small deny-list and then passed to
    ``subprocess.run`` with ``shell=True``. The deny-list only inspects the first
    word of the command plus a few substrings, so it is a guard rail rather than
    a sandbox.
    """
    name: str = "terminal"
    description: str = "A tool to run terminal commands on the local system"
    params: dict = {
        "type": "object",
        "properties": {
            "command": {
                "type": "string",
                "description": f"The terminal command to execute which should be valid in {platform.system()} platform"
            }
        },
        "required": ["command"]
    }
    config: dict = {}

    def __init__(self, config=None):
        """
        :param config: Optional tool configuration; the ``timeout`` key (seconds) is read by execute()
        """
        self.config = config or {}
        # Set of dangerous commands that should be blocked
        self.command_ban_set = {"halt", "poweroff", "shutdown", "reboot", "rm", "kill",
                                "exit", "sudo", "su", "userdel", "groupdel", "logout", "alias"}

    def execute(self, args: Dict[str, Any]) -> ToolResult:
        """
        Execute a terminal command after the deny-list check.

        :param args: Dictionary containing the command to execute under the "command" key
        :return: Success result with stdout/stderr/return_code/command, or a failure result
                 when the command is rejected, exits non-zero, times out or cannot be run
        """
        command = args.get("command", "").strip()

        # Check if the command is safe to execute
        if not self._is_safe_command(command):
            return ToolResult.fail(result=f"Command '{command}' is not allowed for security reasons.")

        # Resolve the timeout once so the value enforced and the value reported
        # in the timeout message can never disagree.
        timeout = self.config.get("timeout", 30)

        try:
            result = subprocess.run(
                command,
                shell=True,
                check=True,  # Raise exception on non-zero return code
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                text=True,
                timeout=timeout
            )
            return ToolResult.success({
                "stdout": result.stdout,
                "stderr": result.stderr,
                "return_code": result.returncode,
                "command": command
            })
        except subprocess.CalledProcessError as e:
            # A non-zero exit code is reported with the captured output intact
            return ToolResult.fail({
                "stdout": e.stdout,
                "stderr": e.stderr,
                "return_code": e.returncode,
                "command": command
            })
        except subprocess.TimeoutExpired:
            return ToolResult.fail(result=f"Command timed out after {timeout} seconds.")
        except Exception as e:
            return ToolResult.fail(result=f"Error executing command: {str(e)}")

    def _is_safe_command(self, command: str) -> bool:
        """
        Check a command against the deny-list.

        :param command: The command to check
        :return: True if the command passes every check, False otherwise
                 (an empty command is rejected)
        """
        # Split the command to get the base command
        cmd_parts = command.split()
        if not cmd_parts:
            return False

        base_cmd = cmd_parts[0].lower()

        # Check if the base command is in the ban list
        if base_cmd in self.command_ban_set:
            return False

        # Check for sudo/su appearing anywhere in the command line
        if any(banned in command.lower() for banned in ["sudo ", "su -"]):
            return False

        # Catches base commands that merely contain "rm" (e.g. a path ending in rm)
        # when combined with recursive/force style flags
        if "rm" in base_cmd and ("-rf" in command or "-r" in command or "-f" in command):
            return False

        # Additional security checks can be added here
        return True

208
agent/tools/tool_manager.py Normal file
View File

@@ -0,0 +1,208 @@
import importlib
import importlib.util
from pathlib import Path
from typing import Dict, Any, Type
from agent.tools.base_tool import BaseTool
from common.log import logger
class ToolManager:
    """
    Singleton registry of tool classes.

    Tool classes (not instances) are stored by tool name; a fresh instance is
    created on every ``create_tool`` call and receives the per-tool
    configuration recorded by ``load_tools``.
    """
    _instance = None

    def __new__(cls):
        """Singleton pattern to ensure only one instance of ToolManager exists."""
        if cls._instance is None:
            cls._instance = super(ToolManager, cls).__new__(cls)
            cls._instance.tool_classes = {}  # Store tool classes instead of instances
            cls._instance.tool_configs = {}  # Per-tool configuration, keyed by tool name
            cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        # __init__ runs on every ToolManager() call; never reset existing state
        if not hasattr(self, 'tool_classes'):
            self.tool_classes = {}  # Dictionary to store tool classes

    def load_tools(self, tools_dir: str = "", config_dict=None):
        """
        Load tool classes and record their configuration.

        :param tools_dir: Directory to scan for tool modules; when empty the
                          classes exported by the ``agent.tools`` package are used
        :param config_dict: Mapping of tool name to that tool's configuration
        """
        if tools_dir:
            self._load_tools_from_directory(tools_dir)
        else:
            self._load_tools_from_init()
        # The configuration is applied whichever loading path was taken
        self._configure_tools_from_config(config_dict)

    @staticmethod
    def _is_tool_class(candidate) -> bool:
        """Return True when candidate is a proper subclass of BaseTool."""
        return (
            isinstance(candidate, type)
            and issubclass(candidate, BaseTool)
            and candidate != BaseTool
        )

    def _register_tool_class(self, cls) -> None:
        """Instantiate cls once to read its name and store the class under that name."""
        try:
            # Create a temporary instance to get the name
            tool_name = cls().name
            # Store the class, not the instance
            self.tool_classes[tool_name] = cls
            logger.debug(f"Loaded tool: {tool_name} from class {cls.__name__}")
        except ImportError as e:
            # Ignore browser_use dependency missing errors
            if "browser_use" not in str(e):
                logger.error(f"Error initializing tool class {cls.__name__}: {e}")
        except Exception as e:
            logger.error(f"Error initializing tool class {cls.__name__}: {e}")

    def _load_tools_from_init(self) -> bool:
        """
        Load tool classes from tools.__init__.__all__

        :return: True if at least one tool class is registered, False otherwise
        """
        try:
            # Try to import the tools package
            tools_package = importlib.import_module("agent.tools")

            # Without __all__ there is no list of exported classes to walk
            if not hasattr(tools_package, "__all__"):
                return False

            for class_name in tools_package.__all__:
                try:
                    # Skip base classes
                    if class_name in ["BaseTool", "ToolManager"]:
                        continue
                    if not hasattr(tools_package, class_name):
                        continue
                    cls = getattr(tools_package, class_name)
                    if self._is_tool_class(cls):
                        self._register_tool_class(cls)
                except Exception as e:
                    logger.error(f"Error importing class {class_name}: {e}")

            return len(self.tool_classes) > 0
        except ImportError:
            logger.warning("Could not import agent.tools package")
            return False
        except Exception as e:
            logger.error(f"Error loading tools from __init__.__all__: {e}")
            return False

    def _load_tools_from_directory(self, tools_dir: str):
        """
        Dynamically load tool classes from every .py file below tools_dir.

        :param tools_dir: Directory that is searched recursively
        """
        tools_path = Path(tools_dir)

        # Traverse all .py files
        for py_file in tools_path.rglob("*.py"):
            # Skip initialization files and base tool files
            if py_file.name in ["__init__.py", "base_tool.py", "tool_manager.py"]:
                continue

            # Get module name
            module_name = py_file.stem

            try:
                # Load module directly from file
                spec = importlib.util.spec_from_file_location(module_name, py_file)
                if not (spec and spec.loader):
                    continue
                module = importlib.util.module_from_spec(spec)
                spec.loader.exec_module(module)

                # Find tool classes in the module
                for attr_name in dir(module):
                    cls = getattr(module, attr_name)
                    if self._is_tool_class(cls):
                        self._register_tool_class(cls)
            except Exception as e:
                logger.error(f"Error importing module {py_file}: {e}")

    def _configure_tools_from_config(self, config_dict=None):
        """
        Record per-tool configuration and warn about configured tools that are not loaded.

        :param config_dict: Mapping of tool name to configuration; None or empty
                            means no tool has any configuration
        """
        try:
            # NOTE(review): the previous fallback called an undefined ``config()``
            # helper, which always raised NameError. Until a global configuration
            # source is wired in, a missing config_dict is treated as empty.
            tools_config = config_dict or {}

            # Store configurations for later use when instantiating
            self.tool_configs = tools_config

            # Record tools that are configured but not loaded
            missing_tools = [
                tool_name for tool_name in tools_config
                if tool_name not in self.tool_classes
            ]

            for tool_name in missing_tools:
                if tool_name == "browser":
                    logger.error(
                        "Browser tool is configured but could not be loaded. "
                        "Please install the required dependency with: "
                        "pip install browser-use>=0.1.40 or pip install agentmesh-sdk[full]"
                    )
                else:
                    logger.warning(f"Tool '{tool_name}' is configured but could not be loaded.")
        except Exception as e:
            logger.error(f"Error configuring tools from config: {e}")

    def create_tool(self, name: str) -> "BaseTool | None":
        """
        Get a new instance of a tool by name.

        :param name: The name of the tool to get.
        :return: A new instance of the tool or None if not found.
        """
        tool_class = self.tool_classes.get(name)
        if not tool_class:
            return None

        # Create a new instance
        tool_instance = tool_class()

        # Apply configuration if available
        tool_configs = getattr(self, 'tool_configs', None) or {}
        if name in tool_configs:
            tool_instance.config = tool_configs[name]

        return tool_instance

    def list_tools(self) -> dict:
        """
        Get information about all loaded tools.

        :return: A dictionary mapping tool name to its description and JSON schema.
        """
        result = {}
        for name, tool_class in self.tool_classes.items():
            # Create a temporary instance to get schema
            temp_instance = tool_class()
            result[name] = {
                "description": temp_instance.description,
                "parameters": temp_instance.get_json_schema()
            }
        return result

View File

@@ -0,0 +1,40 @@
from .truncate import (
truncate_head,
truncate_tail,
truncate_line,
format_size,
TruncationResult,
DEFAULT_MAX_LINES,
DEFAULT_MAX_BYTES,
GREP_MAX_LINE_LENGTH
)
from .diff import (
strip_bom,
detect_line_ending,
normalize_to_lf,
restore_line_endings,
normalize_for_fuzzy_match,
fuzzy_find_text,
generate_diff_string,
FuzzyMatchResult
)
__all__ = [
'truncate_head',
'truncate_tail',
'truncate_line',
'format_size',
'TruncationResult',
'DEFAULT_MAX_LINES',
'DEFAULT_MAX_BYTES',
'GREP_MAX_LINE_LENGTH',
'strip_bom',
'detect_line_ending',
'normalize_to_lf',
'restore_line_endings',
'normalize_for_fuzzy_match',
'fuzzy_find_text',
'generate_diff_string',
'FuzzyMatchResult'
]

167
agent/tools/utils/diff.py Normal file
View File

@@ -0,0 +1,167 @@
"""
Diff tools for file editing
Provides fuzzy matching and diff generation functionality
"""
import difflib
import re
from typing import Optional, Tuple
def strip_bom(text: str) -> Tuple[str, str]:
    """
    Split a leading BOM (Byte Order Mark) off the text.

    :param text: Original text
    :return: (BOM or empty string, text without the BOM)
    """
    bom = '\ufeff'
    if not text.startswith(bom):
        return '', text
    return bom, text[len(bom):]
def detect_line_ending(text: str) -> str:
    """
    Report which line ending the text uses.

    :param text: Text content
    :return: CRLF if the text contains at least one CRLF pair, otherwise LF
    """
    crlf = '\r\n'
    return crlf if crlf in text else '\n'
def normalize_to_lf(text: str) -> str:
    """
    Convert CRLF pairs and lone CR characters to LF.

    :param text: Original text
    :return: Text that only uses LF line endings
    """
    # A CR optionally followed by LF is one line break; each becomes a single LF
    return re.sub(r'\r\n?', '\n', text)
def restore_line_endings(text: str, original_ending: str) -> str:
    """
    Convert LF line endings back to the ending the file originally used.

    :param text: LF normalized text
    :param original_ending: Original line ending
    :return: Text with every LF replaced by CRLF when the original ending was CRLF,
             otherwise the text unchanged
    """
    wants_crlf = original_ending == '\r\n'
    return text.replace('\n', '\r\n') if wants_crlf else text
def normalize_for_fuzzy_match(text: str) -> str:
    """
    Normalize whitespace so that text can be compared loosely.

    Runs of spaces/tabs are collapsed to one space, spaces directly before a
    newline are removed, leading whitespace of each line is rewritten as the
    same number of spaces, and whitespace-only lines become empty.

    :param text: Original text
    :return: Normalized text
    """
    # Collapse space/tab runs first, then drop spaces that end a line
    collapsed = re.sub(r' +\n', '\n', re.sub(r'[ \t]+', ' ', text))

    def _normalize_line(line: str) -> str:
        body = line.lstrip()
        if not body:
            return ''
        # Whatever leading whitespace was stripped is replaced by plain spaces
        return ' ' * (len(line) - len(body)) + body

    return '\n'.join(_normalize_line(line) for line in collapsed.split('\n'))
class FuzzyMatchResult:
    """Outcome of a text search: whether a match exists, where, how long, and in which text."""

    def __init__(self, found: bool, index: int = -1, match_length: int = 0, content_for_replacement: str = ""):
        # Match flag and position of the match
        self.found, self.index = found, index
        # Length of the matched span and the text the index refers to
        self.match_length, self.content_for_replacement = match_length, content_for_replacement
def fuzzy_find_text(content: str, old_text: str) -> FuzzyMatchResult:
    """
    Locate old_text in content, exactly if possible and whitespace-insensitively otherwise.

    On a fuzzy hit the returned index and length refer to the normalized
    content, which is returned as ``content_for_replacement``.

    :param content: Content to search in
    :param old_text: Text to find
    :return: Match result
    """
    # Exact search on the untouched content
    position = content.find(old_text)
    if position != -1:
        return FuzzyMatchResult(True, position, len(old_text), content)

    # Retry with both sides whitespace-normalized
    relaxed_content = normalize_for_fuzzy_match(content)
    relaxed_needle = normalize_for_fuzzy_match(old_text)
    position = relaxed_content.find(relaxed_needle)
    if position == -1:
        return FuzzyMatchResult(found=False)

    return FuzzyMatchResult(True, position, len(relaxed_needle), relaxed_content)
def generate_diff_string(old_content: str, new_content: str) -> dict:
    """
    Generate a unified diff string and locate the first changed line.

    :param old_content: Old content
    :param new_content: New content
    :return: Dictionary with 'diff' (the unified diff text, empty when the contents
             are equal) and 'first_changed_line' (1-indexed line number in the new
             content where the first added or removed line sits, None when nothing changed)
    """
    old_lines = old_content.split('\n')
    new_lines = new_content.split('\n')

    # Generate unified diff
    diff_lines = list(difflib.unified_diff(
        old_lines,
        new_lines,
        lineterm='',
        fromfile='original',
        tofile='modified'
    ))

    # A hunk header only gives the start of the hunk, which includes leading
    # context lines. Walk the first hunk and count new-file lines until the
    # first real addition or removal is reached.
    first_changed_line = None
    new_line_no = None
    for line in diff_lines:
        if line.startswith('@@'):
            # Parse @@ -1,3 +1,3 @@ format
            match = re.search(r'@@ -\d+,?\d* \+(\d+)', line)
            if match:
                new_line_no = int(match.group(1))
            continue
        if new_line_no is None:
            # Still in the '---' / '+++' file header lines before the first hunk
            continue
        if line.startswith('+') or line.startswith('-'):
            first_changed_line = max(new_line_no, 1)
            break
        # Context line: present in the new file, so advance the counter
        new_line_no += 1

    diff_string = '\n'.join(diff_lines)

    return {
        'diff': diff_string,
        'first_changed_line': first_changed_line
    }

View File

@@ -0,0 +1,292 @@
"""
Shared truncation utilities for tool outputs.
Truncation is based on two independent limits - whichever is hit first wins:
- Line limit (default: 2000 lines)
- Byte limit (default: 50KB)
Never returns partial lines (except bash tail truncation edge case).
"""
from typing import Dict, Any, Optional, Literal
# Default line cap applied by truncate_head/truncate_tail when max_lines is None
DEFAULT_MAX_LINES = 2000
# Default byte cap (measured on the UTF-8 encoding) applied when max_bytes is None
DEFAULT_MAX_BYTES = 50 * 1024 # 50KB
# Default character cap used by truncate_line
GREP_MAX_LINE_LENGTH = 500 # Max chars per grep match line
class TruncationResult:
    """Value object describing what a truncation call kept and why it stopped."""

    def __init__(
        self,
        content: str,
        truncated: bool,
        truncated_by: Optional[Literal["lines", "bytes"]],
        total_lines: int,
        total_bytes: int,
        output_lines: int,
        output_bytes: int,
        last_line_partial: bool = False,
        first_line_exceeds_limit: bool = False,
        max_lines: int = DEFAULT_MAX_LINES,
        max_bytes: int = DEFAULT_MAX_BYTES
    ):
        # Kept text and the reason (if any) the rest was dropped
        self.content, self.truncated, self.truncated_by = content, truncated, truncated_by
        # Size of the input
        self.total_lines, self.total_bytes = total_lines, total_bytes
        # Size of the kept text
        self.output_lines, self.output_bytes = output_lines, output_bytes
        # Edge-case flags set by the truncation functions
        self.last_line_partial = last_line_partial
        self.first_line_exceeds_limit = first_line_exceeds_limit
        # Limits that were in force
        self.max_lines, self.max_bytes = max_lines, max_bytes

    def to_dict(self) -> Dict[str, Any]:
        """Return all fields as a plain dictionary keyed by attribute name."""
        field_names = (
            "content",
            "truncated",
            "truncated_by",
            "total_lines",
            "total_bytes",
            "output_lines",
            "output_bytes",
            "last_line_partial",
            "first_line_exceeds_limit",
            "max_lines",
            "max_bytes",
        )
        return {field: getattr(self, field) for field in field_names}
def format_size(bytes_count: int) -> str:
    """Render a byte count as B, KB or MB (one decimal place for KB and MB)."""
    kib = 1024
    mib = kib * kib
    if bytes_count >= mib:
        return f"{bytes_count / (1024 * 1024):.1f}MB"
    if bytes_count >= kib:
        return f"{bytes_count / 1024:.1f}KB"
    return f"{bytes_count}B"
def truncate_head(content: str, max_lines: Optional[int] = None, max_bytes: Optional[int] = None) -> TruncationResult:
    """
    Keep the beginning of content, cutting at whichever of the line or byte limit is hit first.

    Only whole lines are kept. When the first line alone is larger than the
    byte limit, the result has empty content and first_line_exceeds_limit=True.

    :param content: Content to truncate
    :param max_lines: Maximum number of lines (DEFAULT_MAX_LINES when None)
    :param max_bytes: Maximum number of UTF-8 bytes (DEFAULT_MAX_BYTES when None)
    :return: Truncation result
    """
    line_cap = DEFAULT_MAX_LINES if max_lines is None else max_lines
    byte_cap = DEFAULT_MAX_BYTES if max_bytes is None else max_bytes

    all_lines = content.split('\n')
    line_total = len(all_lines)
    byte_total = len(content.encode('utf-8'))

    def _result(text, truncated, reason, kept_lines, kept_bytes, first_too_big=False):
        # All results share the input totals and the limits in force
        return TruncationResult(
            content=text,
            truncated=truncated,
            truncated_by=reason,
            total_lines=line_total,
            total_bytes=byte_total,
            output_lines=kept_lines,
            output_bytes=kept_bytes,
            last_line_partial=False,
            first_line_exceeds_limit=first_too_big,
            max_lines=line_cap,
            max_bytes=byte_cap
        )

    # Everything fits: hand the content back untouched
    if line_total <= line_cap and byte_total <= byte_cap:
        return _result(content, False, None, line_total, byte_total)

    # Not even the first line fits into the byte budget
    if len(all_lines[0].encode('utf-8')) > byte_cap:
        return _result("", True, "bytes", 0, 0, first_too_big=True)

    kept = []
    used_bytes = 0
    reason = "lines"
    for position, line in enumerate(all_lines):
        if position >= line_cap:
            break
        # Every line after the first also costs one byte for its joining newline
        cost = len(line.encode('utf-8')) + (1 if position > 0 else 0)
        if used_bytes + cost > byte_cap:
            reason = "bytes"
            break
        kept.append(line)
        used_bytes += cost

    # If exited due to line limit
    if len(kept) >= line_cap and used_bytes <= byte_cap:
        reason = "lines"

    kept_text = '\n'.join(kept)
    return _result(kept_text, True, reason, len(kept), len(kept_text.encode('utf-8')))
def truncate_tail(content: str, max_lines: Optional[int] = None, max_bytes: Optional[int] = None) -> TruncationResult:
    """
    Keep the end of content, cutting at whichever of the line or byte limit is hit first.

    Lines are collected from the last one backwards. If the very first line
    considered (the last line of content) is already larger than the byte limit,
    only its trailing part is kept and last_line_partial is set.

    :param content: Content to truncate
    :param max_lines: Maximum lines (DEFAULT_MAX_LINES when None)
    :param max_bytes: Maximum UTF-8 bytes (DEFAULT_MAX_BYTES when None)
    :return: Truncation result
    """
    line_cap = DEFAULT_MAX_LINES if max_lines is None else max_lines
    byte_cap = DEFAULT_MAX_BYTES if max_bytes is None else max_bytes

    all_lines = content.split('\n')
    line_total = len(all_lines)
    byte_total = len(content.encode('utf-8'))

    # Everything fits: hand the content back untouched
    if line_total <= line_cap and byte_total <= byte_cap:
        return TruncationResult(
            content=content,
            truncated=False,
            truncated_by=None,
            total_lines=line_total,
            total_bytes=byte_total,
            output_lines=line_total,
            output_bytes=byte_total,
            last_line_partial=False,
            first_line_exceeds_limit=False,
            max_lines=line_cap,
            max_bytes=byte_cap
        )

    # Collected newest-first; reversed into document order once the walk is done
    collected = []
    used_bytes = 0
    reason = "lines"
    partial = False

    for line in reversed(all_lines):
        if len(collected) >= line_cap:
            break
        # Every line after the first collected one also costs a joining newline
        cost = len(line.encode('utf-8')) + (1 if collected else 0)
        if used_bytes + cost > byte_cap:
            reason = "bytes"
            if not collected:
                # Nothing kept yet and this line alone is too big: keep its tail
                clipped = _truncate_string_to_bytes_from_end(line, byte_cap)
                collected.append(clipped)
                used_bytes = len(clipped.encode('utf-8'))
                partial = True
            break
        collected.append(line)
        used_bytes += cost

    # If exited due to line limit
    if len(collected) >= line_cap and used_bytes <= byte_cap:
        reason = "lines"

    collected.reverse()
    kept_text = '\n'.join(collected)

    return TruncationResult(
        content=kept_text,
        truncated=True,
        truncated_by=reason,
        total_lines=line_total,
        total_bytes=byte_total,
        output_lines=len(collected),
        output_bytes=len(kept_text.encode('utf-8')),
        last_line_partial=partial,
        first_line_exceeds_limit=False,
        max_lines=line_cap,
        max_bytes=byte_cap
    )
def _truncate_string_to_bytes_from_end(text: str, max_bytes: int) -> str:
"""
Truncate string to fit byte limit (from end).
Properly handles multi-byte UTF-8 characters.
:param text: String to truncate
:param max_bytes: Maximum bytes
:return: Truncated string
"""
encoded = text.encode('utf-8')
if len(encoded) <= max_bytes:
return text
# Start from end, skip back maxBytes
start = len(encoded) - max_bytes
# Find valid UTF-8 boundary (character start)
while start < len(encoded) and (encoded[start] & 0xC0) == 0x80:
start += 1
return encoded[start:].decode('utf-8', errors='ignore')
def truncate_line(line: str, max_chars: int = GREP_MAX_LINE_LENGTH) -> tuple[str, bool]:
    """
    Cap a single line at max_chars characters and mark it when it was cut.

    :param line: Line to truncate
    :param max_chars: Maximum characters
    :return: (the line, False) when it fits, otherwise
             (first max_chars characters followed by "... [truncated]", True)
    """
    was_cut = len(line) > max_chars
    if was_cut:
        return line[:max_chars] + "... [truncated]", True
    return line, False

View File

@@ -0,0 +1,3 @@
from .write import Write
__all__ = ['Write']

View File

@@ -0,0 +1,91 @@
"""
Write tool - Write file content
Creates or overwrites files, automatically creates parent directories
"""
import os
from typing import Dict, Any
from pathlib import Path
from agent.tools.base_tool import BaseTool, ToolResult
class Write(BaseTool):
    """File-writing tool: creates or overwrites a file, creating parent directories as needed."""

    name: str = "write"
    description: str = "Write content to a file. Creates the file if it doesn't exist, overwrites if it does. Automatically creates parent directories."
    params: dict = {
        "type": "object",
        "properties": {
            "path": {
                "type": "string",
                "description": "Path to the file to write (relative or absolute)"
            },
            "content": {
                "type": "string",
                "description": "Content to write to the file"
            }
        },
        "required": ["path", "content"]
    }

    def __init__(self, config: dict = None):
        """
        :param config: Optional configuration; "cwd" is the base for relative paths
                       (defaults to the process working directory)
        """
        self.config = config or {}
        self.cwd = self.config.get("cwd", os.getcwd())

    def execute(self, args: Dict[str, Any]) -> ToolResult:
        """
        Write the given content to the given path as UTF-8 text.

        :param args: Contains "path" and "content"
        :return: Success result with message, path and bytes_written, or a failure result
        """
        target = args.get("path", "").strip()
        payload = args.get("content", "")

        if not target:
            return ToolResult.fail("Error: path parameter is required")

        resolved = self._resolve_path(target)

        try:
            # Make sure the directory that will hold the file exists
            folder = os.path.dirname(resolved)
            if folder:
                os.makedirs(folder, exist_ok=True)

            with open(resolved, 'w', encoding='utf-8') as handle:
                handle.write(payload)

            # Size is computed from the UTF-8 encoding of the content
            size = len(payload.encode('utf-8'))
            return ToolResult.success({
                "message": f"Successfully wrote {size} bytes to {target}",
                "path": target,
                "bytes_written": size
            })
        except PermissionError:
            return ToolResult.fail(f"Error: Permission denied writing to {target}")
        except Exception as e:
            return ToolResult.fail(f"Error writing file: {str(e)}")

    def _resolve_path(self, path: str) -> str:
        """
        Turn a user-supplied path into an absolute path.

        :param path: Relative or absolute path; a leading ~ is expanded
        :return: The expanded path if absolute, otherwise the path joined onto self.cwd
        """
        expanded = os.path.expanduser(path)
        if not os.path.isabs(expanded):
            return os.path.abspath(os.path.join(self.cwd, expanded))
        return expanded

View File

@@ -1,222 +0,0 @@
import re
import time
import json
import uuid
from curl_cffi import requests
from bot.bot import Bot
from bot.claude.claude_ai_session import ClaudeAiSession
from bot.openai.open_ai_image import OpenAIImage
from bot.session_manager import SessionManager
from bridge.context import Context, ContextType
from bridge.reply import Reply, ReplyType
from common.log import logger
from config import conf
class ClaudeAIBot(Bot, OpenAIImage):
    """
    Bot that talks to the claude.ai web endpoints using a browser cookie.

    Text contexts are answered through ``_chat``; image-creation contexts are
    delegated to ``create_img``.
    """

    def __init__(self):
        super().__init__()
        self.sessions = SessionManager(ClaudeAiSession, model=conf().get("model") or "gpt-3.5-turbo")
        self.claude_api_cookie = conf().get("claude_api_cookie")
        self.proxy = conf().get("proxy")
        # Maps a session id to the conversation uuid created for it
        self.con_uuid_dic = {}
        if self.proxy:
            self.proxies = {
                "http": self.proxy,
                "https": self.proxy
            }
        else:
            self.proxies = None
        # Last user-facing error set by get_organization_id
        self.error = ""
        self.org_uuid = self.get_organization_id()

    def generate_uuid(self):
        """Return a new random UUID4 as a hyphen-separated string."""
        random_uuid = uuid.uuid4()
        random_uuid_str = str(random_uuid)
        formatted_uuid = f"{random_uuid_str[0:8]}-{random_uuid_str[9:13]}-{random_uuid_str[14:18]}-{random_uuid_str[19:23]}-{random_uuid_str[24:]}"
        return formatted_uuid

    def reply(self, query, context: Context = None) -> Reply:
        """
        Answer a query according to the context type.

        :param query: User query text
        :param context: Context whose ``type`` selects the handling
        :return: Reply for the query, or an ERROR reply for unsupported types
        """
        if context.type == ContextType.TEXT:
            return self._chat(query, context)
        elif context.type == ContextType.IMAGE_CREATE:
            ok, res = self.create_img(query, 0)
            if ok:
                reply = Reply(ReplyType.IMAGE_URL, res)
            else:
                reply = Reply(ReplyType.ERROR, res)
            return reply
        else:
            reply = Reply(ReplyType.ERROR, "Bot不支持处理{}类型的消息".format(context.type))
            return reply

    def get_organization_id(self):
        """
        Fetch the uuid of the first organization of the account.

        :return: The organization uuid, or None when the request or the parsing
                 fails (self.error is set for the two recognised failure messages)
        """
        url = "https://claude.ai/api/organizations"
        headers = {
            'User-Agent':
                'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:109.0) Gecko/20100101 Firefox/115.0',
            'Accept-Language': 'en-US,en;q=0.5',
            'Referer': 'https://claude.ai/chats',
            'Content-Type': 'application/json',
            'Sec-Fetch-Dest': 'empty',
            'Sec-Fetch-Mode': 'cors',
            'Sec-Fetch-Site': 'same-origin',
            'Connection': 'keep-alive',
            'Cookie': f'{self.claude_api_cookie}'
        }
        # Stays None when the request itself raises, so the handler below
        # never touches an unbound name.
        response = None
        try:
            response = requests.get(url, headers=headers, impersonate="chrome110", proxies=self.proxies, timeout=400)
            res = json.loads(response.text)
            org_uuid = res[0]['uuid']
        except Exception:
            body = response.text if response is not None else ""
            if "App unavailable" in body:
                logger.error("IP error: The IP is not allowed to be used on Claude")
                self.error = "ip所在地区不被claude支持"
            elif "Invalid authorization" in body:
                logger.error("Cookie error: Invalid authorization of claude, check cookie please.")
                self.error = "无法通过claude身份验证请检查cookie"
            else:
                logger.exception("[CLAUDEAI] failed to fetch organization id")
            return None
        return org_uuid

    def conversation_share_check(self, session_id):
        """
        Return the conversation uuid to use for a session.

        A non-empty ``claude_uuid`` configuration value is used for every session;
        otherwise a conversation is created on first use and remembered.

        :param session_id: Session identifier
        :return: Conversation uuid
        """
        if conf().get("claude_uuid") is not None and conf().get("claude_uuid") != "":
            con_uuid = conf().get("claude_uuid")
            return con_uuid
        if session_id not in self.con_uuid_dic:
            self.con_uuid_dic[session_id] = self.generate_uuid()
            self.create_new_chat(self.con_uuid_dic[session_id])
        return self.con_uuid_dic[session_id]

    def check_cookie(self):
        """Re-run get_organization_id; None means the check failed."""
        flag = self.get_organization_id()
        return flag

    def create_new_chat(self, con_uuid):
        """
        Create a new Claude conversation entity.

        :param con_uuid: Conversation id
        :return: Parsed JSON of the response for the newly created conversation
        """
        url = f"https://claude.ai/api/organizations/{self.org_uuid}/chat_conversations"
        payload = json.dumps({"uuid": con_uuid, "name": ""})
        headers = {
            'User-Agent':
                'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:109.0) Gecko/20100101 Firefox/115.0',
            'Accept-Language': 'en-US,en;q=0.5',
            'Referer': 'https://claude.ai/chats',
            'Content-Type': 'application/json',
            'Origin': 'https://claude.ai',
            'DNT': '1',
            'Connection': 'keep-alive',
            'Cookie': self.claude_api_cookie,
            'Sec-Fetch-Dest': 'empty',
            'Sec-Fetch-Mode': 'cors',
            'Sec-Fetch-Site': 'same-origin',
            'TE': 'trailers'
        }
        response = requests.post(url, headers=headers, data=payload, impersonate="chrome110", proxies=self.proxies, timeout=400)

        # Returns JSON of the newly created conversation information
        return response.json()

    def _chat(self, query, context, retry_count=0) -> Reply:
        """
        Send a chat request.

        :param query: Prompt text
        :param context: Conversation context (provides "session_id")
        :param retry_count: Current recursive retry count
        :return: Reply
        """
        if retry_count >= 2:
            # exit from retry 2 times
            logger.warning("[CLAUDEAI] failed after maximum number of retry times")
            return Reply(ReplyType.ERROR, "请再问我一次吧")

        try:
            session_id = context["session_id"]
            if self.org_uuid is None:
                return Reply(ReplyType.ERROR, self.error)

            session = self.sessions.session_query(query, session_id)
            con_uuid = self.conversation_share_check(session_id)

            model = conf().get("model") or "gpt-3.5-turbo"
            # remove system message
            if session.messages[0].get("role") == "system":
                if model == "wenxin" or model == "claude":
                    session.messages.pop(0)
            logger.info(f"[CLAUDEAI] query={query}")

            # do http request
            base_url = "https://claude.ai"
            payload = json.dumps({
                "completion": {
                    "prompt": f"{query}",
                    "timezone": "Asia/Kolkata",
                    "model": "claude-2"
                },
                "organization_uuid": f"{self.org_uuid}",
                "conversation_uuid": f"{con_uuid}",
                "text": f"{query}",
                "attachments": []
            })
            headers = {
                'User-Agent':
                    'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:109.0) Gecko/20100101 Firefox/115.0',
                'Accept': 'text/event-stream, text/event-stream',
                'Accept-Language': 'en-US,en;q=0.5',
                'Referer': 'https://claude.ai/chats',
                'Content-Type': 'application/json',
                'Origin': 'https://claude.ai',
                'DNT': '1',
                'Connection': 'keep-alive',
                'Cookie': f'{self.claude_api_cookie}',
                'Sec-Fetch-Dest': 'empty',
                'Sec-Fetch-Mode': 'cors',
                'Sec-Fetch-Site': 'same-origin',
                'TE': 'trailers'
            }

            res = requests.post(base_url + "/api/append_message", headers=headers, data=payload, impersonate="chrome110", proxies=self.proxies, timeout=400)
            # NOTE(review): "pemission" looks like a typo of "permission" - confirm
            # against real responses before changing the matched text.
            if res.status_code == 200 or "pemission" in res.text:
                # execute success: each non-empty line is an event whose JSON
                # payload starts after the first 6 characters
                decoded_data = res.content.decode("utf-8")
                decoded_data = re.sub('\n+', '\n', decoded_data).strip()
                data_strings = decoded_data.split('\n')
                completions = []
                for data_string in data_strings:
                    json_str = data_string[6:].strip()
                    data = json.loads(json_str)
                    if 'completion' in data:
                        completions.append(data['completion'])

                reply_content = ''.join(completions)

                if "rate limi" in reply_content:
                    logger.error("rate limit error: The conversation has reached the system speed limit and is synchronized with Cladue. Please go to the official website to check the lifting time")
                    return Reply(ReplyType.ERROR, "对话达到系统速率限制与cladue同步请进入官网查看解除限制时间")
                logger.info(f"[CLAUDE] reply={reply_content}, total_tokens=invisible")
                self.sessions.session_reply(reply_content, session_id, 100)
                return Reply(ReplyType.TEXT, reply_content)
            else:
                flag = self.check_cookie()
                if flag is None:
                    return Reply(ReplyType.ERROR, self.error)

                response = res.json()
                error = response.get("error")
                logger.error(f"[CLAUDE] chat failed, status_code={res.status_code}, "
                             f"msg={error.get('message')}, type={error.get('type')}, detail: {res.text}, uuid: {con_uuid}")

                if res.status_code >= 500:
                    # server error, need retry
                    time.sleep(2)
                    logger.warning(f"[CLAUDE] do retry, times={retry_count}")
                    return self._chat(query, context, retry_count + 1)
                return Reply(ReplyType.ERROR, "提问太快啦,请休息一下再问我吧")

        except Exception as e:
            logger.exception(e)
            # retry
            time.sleep(2)
            logger.warning(f"[CLAUDE] do retry, times={retry_count}")
            return self._chat(query, context, retry_count + 1)

View File

@@ -1,9 +0,0 @@
from bot.session_manager import Session
class ClaudeAiSession(Session):
def __init__(self, session_id, system_prompt=None, model="claude"):
super().__init__(session_id, system_prompt)
self.model = model
# claude逆向不支持role prompt
# self.reset()

View File

@@ -1,19 +1,18 @@
# encoding:utf-8
import json
import time
import openai
import openai.error
import anthropic
import requests
from bot.baidu.baidu_wenxin_session import BaiduWenxinSession
from bot.bot import Bot
from bot.openai.open_ai_image import OpenAIImage
from bot.baidu.baidu_wenxin_session import BaiduWenxinSession
from bot.session_manager import SessionManager
from bridge.context import ContextType
from bridge.reply import Reply, ReplyType
from common.log import logger
from common import const
from common.log import logger
from config import conf
user_session = dict()
@@ -23,13 +22,9 @@ user_session = dict()
class ClaudeAPIBot(Bot, OpenAIImage):
def __init__(self):
super().__init__()
proxy = conf().get("proxy", None)
base_url = conf().get("open_ai_api_base", None) # 复用"open_ai_api_base"参数作为base_url
self.claudeClient = anthropic.Anthropic(
api_key=conf().get("claude_api_key"),
proxies=proxy if proxy else None,
base_url=base_url if base_url else None
)
self.api_key = conf().get("claude_api_key")
self.api_base = conf().get("open_ai_api_base") or "https://api.anthropic.com/v1"
self.proxy = conf().get("proxy", None)
self.sessions = SessionManager(BaiduWenxinSession, model=conf().get("model") or "text-davinci-003")
def reply(self, query, context=None):
@@ -73,39 +68,104 @@ class ClaudeAPIBot(Bot, OpenAIImage):
reply = Reply(ReplyType.ERROR, retstring)
return reply
def reply_text(self, session: BaiduWenxinSession, retry_count=0):
def reply_text(self, session: BaiduWenxinSession, retry_count=0, tools=None):
try:
actual_model = self._model_mapping(conf().get("model"))
response = self.claudeClient.messages.create(
model=actual_model,
max_tokens=4096,
system=conf().get("character_desc", ""),
messages=session.messages
# Prepare headers
headers = {
"x-api-key": self.api_key,
"anthropic-version": "2023-06-01",
"content-type": "application/json"
}
# Extract system prompt if present and prepare Claude-compatible messages
system_prompt = conf().get("character_desc", "")
claude_messages = []
for msg in session.messages:
if msg.get("role") == "system":
system_prompt = msg["content"]
else:
claude_messages.append(msg)
# Prepare request data
data = {
"model": actual_model,
"messages": claude_messages,
"max_tokens": self._get_max_tokens(actual_model)
}
if system_prompt:
data["system"] = system_prompt
if tools:
data["tools"] = tools
# Make HTTP request
proxies = {"http": self.proxy, "https": self.proxy} if self.proxy else None
response = requests.post(
f"{self.api_base}/messages",
headers=headers,
json=data,
proxies=proxies
)
# response = openai.Completion.create(prompt=str(session), **self.args)
res_content = response.content[0].text.strip().replace("<|endoftext|>", "")
total_tokens = response.usage.input_tokens+response.usage.output_tokens
completion_tokens = response.usage.output_tokens
if response.status_code != 200:
raise Exception(f"API request failed: {response.status_code} - {response.text}")
claude_response = response.json()
# Handle response content and tool calls
res_content = ""
tool_calls = []
content_blocks = claude_response.get("content", [])
for block in content_blocks:
if block.get("type") == "text":
res_content += block.get("text", "")
elif block.get("type") == "tool_use":
tool_calls.append({
"id": block.get("id", ""),
"name": block.get("name", ""),
"arguments": block.get("input", {})
})
res_content = res_content.strip().replace("<|endoftext|>", "")
usage = claude_response.get("usage", {})
total_tokens = usage.get("input_tokens", 0) + usage.get("output_tokens", 0)
completion_tokens = usage.get("output_tokens", 0)
logger.info("[CLAUDE_API] reply={}".format(res_content))
return {
if tool_calls:
logger.info("[CLAUDE_API] tool_calls={}".format(tool_calls))
result = {
"total_tokens": total_tokens,
"completion_tokens": completion_tokens,
"content": res_content,
}
if tool_calls:
result["tool_calls"] = tool_calls
return result
except Exception as e:
need_retry = retry_count < 2
result = {"total_tokens": 0, "completion_tokens": 0, "content": "我现在有点累了,等会再来吧"}
if isinstance(e, openai.error.RateLimitError):
# Handle different types of errors
error_str = str(e).lower()
if "rate" in error_str or "limit" in error_str:
logger.warn("[CLAUDE_API] RateLimitError: {}".format(e))
result["content"] = "提问太快啦,请休息一下再问我吧"
if need_retry:
time.sleep(20)
elif isinstance(e, openai.error.Timeout):
elif "timeout" in error_str:
logger.warn("[CLAUDE_API] Timeout: {}".format(e))
result["content"] = "我没有收到你的消息"
if need_retry:
time.sleep(5)
elif isinstance(e, openai.error.APIConnectionError):
elif "connection" in error_str or "network" in error_str:
logger.warn("[CLAUDE_API] APIConnectionError: {}".format(e))
need_retry = False
result["content"] = "我连接不到你的网络"
@@ -116,7 +176,7 @@ class ClaudeAPIBot(Bot, OpenAIImage):
if need_retry:
logger.warn("[CLAUDE_API] 第{}次重试".format(retry_count + 1))
return self.reply_text(session, retry_count + 1)
return self.reply_text(session, retry_count + 1, tools)
else:
return result
@@ -130,3 +190,288 @@ class ClaudeAPIBot(Bot, OpenAIImage):
elif model == "claude-3.5-sonnet":
return const.CLAUDE_35_SONNET
return model
def _get_max_tokens(self, model: str) -> int:
"""
Get max_tokens for the model.
Reference from pi-mono:
- Claude 3.5/3.7: 8192
- Claude 3 Opus: 4096
- Default: 8192
"""
if model and (model.startswith("claude-3-5") or model.startswith("claude-3-7")):
return 8192
elif model and model.startswith("claude-3") and "opus" in model:
return 4096
elif model and (model.startswith("claude-sonnet-4") or model.startswith("claude-opus-4")):
return 64000
return 8192
def call_with_tools(self, messages, tools=None, stream=False, **kwargs):
    """
    Call Claude API with tool support for agent integration

    Args:
        messages: List of message dicts. Entries whose role is "system" are
            left out of the request messages and their content becomes the
            system prompt (the last such entry wins).
        tools: List of tool definitions, forwarded unchanged under "tools"
        stream: Whether to use streaming
        **kwargs: Optional "system" (system prompt; defaults to the
            "character_desc" config value) and "max_tokens" (defaults to
            the value from _get_max_tokens for the mapped model)

    Returns:
        The result of _handle_sync_response when stream is False, or the
        generator from _handle_stream_response when stream is True. If the
        call raises, an error dict with keys "error", "message" and
        "status_code" is returned instead (wrapped in a one-item generator
        when stream is True).
    """
    actual_model = self._model_mapping(conf().get("model"))

    # Extract system prompt from messages if present
    system_prompt = kwargs.get("system", conf().get("character_desc", ""))
    claude_messages = []
    for msg in messages:
        if msg.get("role") == "system":
            system_prompt = msg["content"]
        else:
            claude_messages.append(msg)

    request_params = {
        "model": actual_model,
        "max_tokens": kwargs.get("max_tokens", self._get_max_tokens(actual_model)),
        "messages": claude_messages,
        "stream": stream
    }

    if system_prompt:
        request_params["system"] = system_prompt

    if tools:
        request_params["tools"] = tools

    try:
        if stream:
            return self._handle_stream_response(request_params)
        return self._handle_sync_response(request_params)
    except Exception as e:
        logger.error(f"Claude API call error: {e}")
        # Build the payload right away: the name bound by "except ... as"
        # is unbound when this clause exits, so a generator that reads it
        # lazily would fail with NameError once it is iterated.
        error_payload = {
            "error": True,
            "message": str(e),
            "status_code": 500
        }
        if stream:
            # Return error generator for stream
            def error_generator():
                yield error_payload
            return error_generator()
        # Return error response for sync
        return error_payload
def _handle_sync_response(self, request_params):
    """
    Send a non-streaming request to the Claude messages endpoint and
    reshape the reply into an OpenAI-style chat completion dict.

    Args:
        request_params: JSON body to post

    Returns:
        Dict with "id", "object", "created", "model", "choices" and "usage"

    Raises:
        Exception: if the HTTP status code is not 200
    """
    proxies = {"http": self.proxy, "https": self.proxy} if self.proxy else None
    response = requests.post(
        f"{self.api_base}/messages",
        headers={
            "x-api-key": self.api_key,
            "anthropic-version": "2023-06-01",
            "content-type": "application/json"
        },
        json=request_params,
        proxies=proxies
    )

    if response.status_code != 200:
        raise Exception(f"API request failed: {response.status_code} - {response.text}")

    payload = response.json()

    # Walk the content blocks: text is concatenated, tool_use blocks become
    # OpenAI-style function calls with JSON-encoded arguments
    text_content = ""
    tool_calls = []
    for block in payload.get("content", []):
        block_type = block.get("type")
        if block_type == "text":
            text_content += block.get("text", "")
        elif block_type == "tool_use":
            function_spec = {
                "name": block.get("name", ""),
                "arguments": json.dumps(block.get("input", {}))
            }
            tool_calls.append({
                "id": block.get("id", ""),
                "type": "function",
                "function": function_spec
            })

    # Assistant message; "tool_calls" is only present when there are any
    message = {"role": "assistant", "content": text_content}
    if tool_calls:
        message["tool_calls"] = tool_calls

    usage = payload.get("usage", {})
    prompt_tokens = usage.get("input_tokens", 0)
    completion_tokens = usage.get("output_tokens", 0)

    choice = {
        "index": 0,
        "message": message,
        "finish_reason": payload.get("stop_reason", "stop")
    }
    return {
        "id": payload.get("id", ""),
        "object": "chat.completion",
        "created": int(time.time()),
        "model": payload.get("model", request_params["model"]),
        "choices": [choice],
        "usage": {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens
        }
    }
def _handle_stream_response(self, request_params):
    """
    Handle streaming Claude API response using HTTP requests

    This is a generator, so the HTTP request is only sent once iteration
    starts.

    Args:
        request_params: JSON body to post; its "stream" key is set to True
            in place

    Yields:
        - one "chat.completion.chunk" dict per text delta;
        - on each "message_delta" event, one chunk per accumulated tool
          call, whose "arguments" is the concatenated partial JSON text;
        - on failure, a single dict with "error", "message" and
          "status_code" (the HTTP status for a non-200 reply, 0 for a
          requests.RequestException, 500 for any other exception).
    """
    # Prepare headers
    headers = {
        "x-api-key": self.api_key,
        "anthropic-version": "2023-06-01",
        "content-type": "application/json"
    }

    # Add stream parameter
    request_params["stream"] = True

    def make_chunk(event, delta):
        """Wrap a delta dict in an OpenAI-style chat.completion.chunk."""
        return {
            "id": event.get("id", ""),
            "object": "chat.completion.chunk",
            "created": int(time.time()),
            "model": request_params["model"],
            "choices": [{
                "index": 0,
                "delta": delta,
                "finish_reason": None
            }]
        }

    # Track tool use state
    tool_uses_map = {}  # {index: {id, name, input}}
    current_tool_use_index = -1

    try:
        # Make streaming HTTP request
        proxies = {"http": self.proxy, "https": self.proxy} if self.proxy else None
        # The context manager closes the response on every exit path: the
        # early return below, the '[DONE]' break, an exception, or the
        # consumer abandoning the generator part-way through.
        with requests.post(
            f"{self.api_base}/messages",
            headers=headers,
            json=request_params,
            proxies=proxies,
            stream=True
        ) as response:
            if response.status_code != 200:
                error_text = response.text
                try:
                    error_data = json.loads(error_text)
                    error_msg = error_data.get("error", {}).get("message", error_text)
                except (ValueError, AttributeError):
                    # Body is not JSON, or is not shaped like {"error": {...}}
                    error_msg = error_text or "Unknown error"

                yield {
                    "error": True,
                    "status_code": response.status_code,
                    "message": error_msg
                }
                return

            # Process streaming response
            for line in response.iter_lines():
                if not line:
                    continue
                line = line.decode('utf-8')
                if not line.startswith('data: '):
                    continue
                line = line[6:]  # Remove 'data: ' prefix
                if line == '[DONE]':
                    break

                try:
                    event = json.loads(line)
                except json.JSONDecodeError:
                    # Skip payloads that are not valid JSON
                    continue

                event_type = event.get("type")

                if event_type == "content_block_start":
                    # New content block; only tool_use blocks need state
                    block = event.get("content_block", {})
                    if block.get("type") == "tool_use":
                        current_tool_use_index = event.get("index", 0)
                        tool_uses_map[current_tool_use_index] = {
                            "id": block.get("id", ""),
                            "name": block.get("name", ""),
                            "input": ""
                        }

                elif event_type == "content_block_delta":
                    delta = event.get("delta", {})
                    delta_type = delta.get("type")

                    if delta_type == "text_delta":
                        # Text content
                        yield make_chunk(event, {"content": delta.get("text", "")})
                    elif delta_type == "input_json_delta":
                        # Tool input accumulation
                        if current_tool_use_index >= 0:
                            tool_uses_map[current_tool_use_index]["input"] += delta.get("partial_json", "")

                elif event_type == "message_delta":
                    # Message complete - yield tool calls if any
                    for idx in sorted(tool_uses_map.keys()):
                        tool_data = tool_uses_map[idx]
                        yield make_chunk(event, {
                            "tool_calls": [{
                                "index": idx,
                                "id": tool_data["id"],
                                "type": "function",
                                "function": {
                                    "name": tool_data["name"],
                                    "arguments": tool_data["input"]
                                }
                            }]
                        })

    except requests.RequestException as e:
        logger.error(f"Claude streaming request error: {e}")
        yield {
            "error": True,
            "message": f"Connection error: {str(e)}",
            "status_code": 0
        }
    except Exception as e:
        logger.error(f"Claude streaming error: {e}")
        yield {
            "error": True,
            "message": str(e),
            "status_code": 500
        }

288
bridge/agent_bridge.py Normal file
View File

@@ -0,0 +1,288 @@
"""
Agent Bridge - Integrates Agent system with existing COW bridge
"""
from typing import Optional, List
from agent.protocol import Agent, LLMModel, LLMRequest
from agent.tools import Calculator, CurrentTime, Read, Write, Edit, Bash, Grep, Find, Ls
from bridge.bridge import Bridge
from bridge.context import Context
from bridge.reply import Reply, ReplyType
from common import const
from common.log import logger
class AgentLLMModel(LLMModel):
    """
    LLM Model adapter that uses COW's existing bot infrastructure

    Requests are forwarded to the bot returned by ``bridge.get_bot(bot_type)``
    through that bot's ``call_with_tools`` method.
    """

    def __init__(self, bridge: Bridge, bot_type: str = "chat"):
        """
        Args:
            bridge: Bridge used to look up the bot
            bot_type: Type name passed to ``bridge.get_bot``
        """
        # Get model name directly from config
        from config import conf
        model_name = conf().get("model", const.GPT_41)
        super().__init__(model=model_name)
        self.bridge = bridge
        self.bot_type = bot_type
        self._bot = None  # resolved on first access of ``bot``

    @property
    def bot(self):
        """Lazy load the bot"""
        if self._bot is None:
            self._bot = self.bridge.get_bot(self.bot_type)
        return self._bot

    def _base_call_kwargs(self, request: LLMRequest, stream: bool) -> dict:
        """Build the call_with_tools arguments shared by call and call_stream."""
        kwargs = {
            'messages': request.messages,
            'tools': getattr(request, 'tools', None),
            'stream': stream
        }
        # Forward the request's system prompt on both paths so streaming and
        # non-streaming calls see the same prompt
        system_prompt = getattr(request, 'system', None)
        if system_prompt:
            kwargs['system'] = system_prompt
        return kwargs

    def call(self, request: LLMRequest):
        """
        Call the model using COW's bot infrastructure

        Returns:
            Whatever the bot's ``call_with_tools`` returns, passed through
            ``_format_response``

        Raises:
            NotImplementedError: if the bot has no ``call_with_tools`` method
            Exception: any error from the bot is logged and re-raised
        """
        try:
            if hasattr(self.bot, 'call_with_tools'):
                kwargs = self._base_call_kwargs(request, stream=False)

                # Only pass max_tokens if it's explicitly set
                if request.max_tokens is not None:
                    kwargs['max_tokens'] = request.max_tokens

                response = self.bot.call_with_tools(**kwargs)
                return self._format_response(response)
            else:
                # No fallback exists for bots without tool support
                raise NotImplementedError("Regular call not implemented yet")
        except Exception as e:
            logger.error(f"AgentLLMModel call error: {e}")
            raise

    def call_stream(self, request: LLMRequest):
        """
        Call the model with streaming using COW's bot infrastructure

        Yields:
            Each chunk produced by the bot's ``call_with_tools``, passed
            through ``_format_stream_chunk``

        Raises:
            NotImplementedError: if the bot has no ``call_with_tools`` method
            Exception: any error raised while streaming is logged and re-raised
        """
        try:
            if hasattr(self.bot, 'call_with_tools'):
                kwargs = self._base_call_kwargs(request, stream=True)

                # Ensure max_tokens is an integer, use default if None
                kwargs['max_tokens'] = request.max_tokens if request.max_tokens is not None else 4096

                stream = self.bot.call_with_tools(**kwargs)

                for chunk in stream:
                    yield self._format_stream_chunk(chunk)
            else:
                raise NotImplementedError("Streaming call not implemented yet")
        except Exception as e:
            logger.error(f"AgentLLMModel call_stream error: {e}")
            raise

    def _format_response(self, response):
        """Format a bot response to our expected format (currently returned unchanged)"""
        return response

    def _format_stream_chunk(self, chunk):
        """Format a stream chunk to our expected format (currently returned unchanged)"""
        return chunk
class AgentBridge:
    """
    Bridge class that integrates single super Agent with COW
    """

    def __init__(self, bridge: Bridge):
        """
        Args:
            bridge: Bridge handed to the AgentLLMModel created for the agent
        """
        self.bridge = bridge
        self.agent: Optional[Agent] = None  # set by create_agent

    def create_agent(self, system_prompt: str, tools: Optional[List] = None, **kwargs) -> Agent:
        """
        Create the super agent with COW integration

        Args:
            system_prompt: System prompt
            tools: List of tools (optional); a default tool set is used when None
            **kwargs: Additional agent parameters: "description"
                (default "AI Super Agent"), "max_steps" (default 15) and
                "output_mode" (default "logger")

        Returns:
            Agent instance, also stored on ``self.agent``
        """
        # Create LLM model that uses COW's bot infrastructure
        model = AgentLLMModel(self.bridge)

        # Default tools if none provided
        if tools is None:
            tools = [
                Calculator(),
                CurrentTime(),
                Read(),
                Write(),
                Edit(),
                Bash(),
                Grep(),
                Find(),
                Ls()
            ]

        # Create the single super agent
        self.agent = Agent(
            system_prompt=system_prompt,
            description=kwargs.get("description", "AI Super Agent"),
            model=model,
            tools=tools,
            max_steps=kwargs.get("max_steps", 15),
            output_mode=kwargs.get("output_mode", "logger")
        )

        return self.agent

    def get_agent(self) -> Optional[Agent]:
        """Get the super agent, create if not exists"""
        if self.agent is None:
            self._init_default_agent()
        return self.agent

    def _init_default_agent(self):
        """
        Initialize default super agent with config and memory

        Memory support is best effort: if any step of the memory setup
        fails, the agent is created with the base prompt and no memory
        manager or memory tools.
        """
        from config import conf
        import os

        # Get base system prompt from config
        base_prompt = conf().get("character_desc", "你是一个AI助手")

        # Setup memory if enabled
        memory_manager = None
        memory_tools = []

        try:
            # Try to initialize memory system
            from agent.memory import MemoryManager, MemoryConfig
            from agent.tools import MemorySearchTool, MemoryGetTool

            # Create memory config directly with sensible defaults
            workspace_root = os.path.expanduser("~/cow")
            memory_config = MemoryConfig(
                workspace_root=workspace_root,
                embedding_provider="local",  # Use local embedding (no API key needed)
                embedding_model="all-MiniLM-L6-v2"
            )

            # Create memory manager with the config
            memory_manager = MemoryManager(memory_config)

            # Create memory tools
            memory_tools = [
                MemorySearchTool(memory_manager),
                MemoryGetTool(memory_manager)
            ]

            # Build memory guidance and add to system prompt
            memory_guidance = memory_manager.build_memory_guidance(
                lang="zh",
                include_context=True
            )
            system_prompt = base_prompt + "\n\n" + memory_guidance

            logger.info("[AgentBridge] Memory system initialized")
            logger.info(f"[AgentBridge] Workspace: {memory_config.get_workspace()}")

        except Exception as e:
            logger.warning(f"[AgentBridge] Memory system not available: {e}")
            logger.info("[AgentBridge] Continuing without memory features")
            system_prompt = base_prompt
            # A failure after the manager or tools were built must not leave
            # them half-attached: drop them so the agent really runs without
            # memory, as the log line above states.
            memory_manager = None
            memory_tools = []
            import traceback
            traceback.print_exc()

        logger.info("[AgentBridge] Initializing super agent")

        # Configure file tools to work in the correct workspace
        file_config = {"cwd": workspace_root} if memory_manager else {}

        # Create default tools with workspace config
        from agent.tools import Calculator, CurrentTime, Read, Write, Edit, Bash, Grep, Find, Ls
        tools = [
            Calculator(),
            CurrentTime(),
            Read(config=file_config),
            Write(config=file_config),
            Edit(config=file_config),
            Bash(config=file_config),
            Grep(config=file_config),
            Find(config=file_config),
            Ls(config=file_config)
        ]

        # Create agent with configured tools
        agent = self.create_agent(
            system_prompt=system_prompt,
            tools=tools,
            max_steps=50,
            output_mode="logger"
        )

        # Attach memory manager to agent if available
        if memory_manager:
            agent.memory_manager = memory_manager

        # Add memory tools if available
        if memory_tools:
            for tool in memory_tools:
                agent.add_tool(tool)
            logger.info(f"[AgentBridge] Added {len(memory_tools)} memory tools")

    def agent_reply(self, query: str, context: Context = None,
                    on_event=None, clear_history: bool = False) -> Reply:
        """
        Use super agent to reply to a query

        Args:
            query: User query
            context: COW context (optional); accepted but not used here
            on_event: Event callback (optional)
            clear_history: Whether to clear conversation history

        Returns:
            Reply object: a TEXT reply holding the agent's response, or an
            ERROR reply if the agent is unavailable or raises
        """
        try:
            # Get agent (will auto-initialize if needed)
            agent = self.get_agent()
            if not agent:
                return Reply(ReplyType.ERROR, "Failed to initialize super agent")

            # Use agent's run_stream method
            response = agent.run_stream(
                user_message=query,
                on_event=on_event,
                clear_history=clear_history
            )

            return Reply(ReplyType.TEXT, response)

        except Exception as e:
            logger.error(f"Agent reply error: {e}")
            return Reply(ReplyType.ERROR, f"Agent error: {str(e)}")

View File

@@ -23,7 +23,7 @@ class Bridge(object):
if bot_type:
self.btype["chat"] = bot_type
else:
model_type = conf().get("model") or const.GPT35
model_type = conf().get("model") or const.GPT_41_MINI
if model_type in ["text-davinci-003"]:
self.btype["chat"] = const.OPEN_AI
if conf().get("use_azure_chatgpt", False):
@@ -64,6 +64,7 @@ class Bridge(object):
self.bots = {}
self.chat_bots = {}
self._agent_bridge = None
# 模型对应的接口
def get_bot(self, typename):
@@ -104,3 +105,29 @@ class Bridge(object):
重置bot路由
"""
self.__init__()
def get_agent_bridge(self):
    """
    Return the AgentBridge for agent-based conversations, building it on
    first access and reusing it afterwards.
    """
    cached = self._agent_bridge
    if cached is not None:
        return cached
    # Imported at call time rather than at module level
    from bridge.agent_bridge import AgentBridge
    self._agent_bridge = AgentBridge(self)
    return self._agent_bridge
def fetch_agent_reply(self, query: str, context: Context = None,
                      on_event=None, clear_history: bool = False) -> Reply:
    """
    Hand the query to the agent bridge and return its reply

    Args:
        query: User query
        context: Context object
        on_event: Event callback for streaming
        clear_history: Whether to clear conversation history

    Returns:
        Reply object produced by the agent bridge's ``agent_reply``
    """
    return self.get_agent_bridge().agent_reply(query, context, on_event, clear_history)

View File

@@ -5,6 +5,8 @@ Message sending channel abstract class
from bridge.bridge import Bridge
from bridge.context import Context
from bridge.reply import *
from common.log import logger
from config import conf
class Channel(object):
@@ -35,7 +37,30 @@ class Channel(object):
raise NotImplementedError
def build_reply_content(self, query, context: Context = None) -> Reply:
return Bridge().fetch_reply_content(query, context)
"""
Build reply content, using agent if enabled in config
"""
# Check if agent mode is enabled
use_agent = conf().get("agent", False)
if use_agent:
try:
logger.info("[Channel] Using agent mode")
# Use agent bridge to handle the query
return Bridge().fetch_agent_reply(
query=query,
context=context,
on_event=None,
clear_history=False
)
except Exception as e:
logger.error(f"[Channel] Agent mode failed, fallback to normal mode: {e}")
# Fallback to normal mode if agent fails
return Bridge().fetch_reply_content(query, context)
else:
# Normal mode
return Bridge().fetch_reply_content(query, context)
def build_voice_to_text(self, voice_file) -> Reply:
return Bridge().fetch_voice_to_text(voice_file)

View File

@@ -150,9 +150,6 @@ class WebChannel(ChatChannel):
Poll for responses using the session_id.
"""
try:
# 不记录轮询请求的日志
web.ctx.log_request = False
data = web.data()
json_data = json.loads(data)
session_id = json_data.get('session_id')
@@ -215,19 +212,20 @@ class WebChannel(ChatChannel):
)
app = web.application(urls, globals(), autoreload=False)
# 禁用web.py的默认日志输出
import io
from contextlib import redirect_stdout
# 完全禁用web.py的HTTP日志输出
# 创建一个空的日志处理函数
def null_log_function(status, environ):
pass
# 配置web.py的日志级别为ERROR只显示错误
# 替换web.py的日志函数
web.httpserver.LogMiddleware.log = lambda self, status, environ: None
# 配置web.py的日志级别为ERROR
logging.getLogger("web").setLevel(logging.ERROR)
# 禁用web.httpserver的日志
logging.getLogger("web.httpserver").setLevel(logging.ERROR)
# 临时重定向标准输出捕获web.py的启动消息
with redirect_stdout(io.StringIO()):
web.httpserver.runsimple(app.wsgifunc(), ("0.0.0.0", port))
# 启动服务器
web.httpserver.runsimple(app.wsgifunc(), ("0.0.0.0", port))
class RootHandler:

View File

@@ -3,6 +3,7 @@
"model": "",
"open_ai_api_key": "YOUR API KEY",
"claude_api_key": "YOUR API KEY",
"claude_api_base": "https://api.anthropic.com",
"text_to_image": "dall-e-2",
"voice_to_text": "openai",
"text_to_voice": "openai",
@@ -30,8 +31,9 @@
"expires_in_seconds": 3600,
"character_desc": "你是基于大语言模型的AI智能助手旨在回答并解决人们的任何问题并且可以使用多种语言与人交流。",
"temperature": 0.7,
"subscribe_msg": "感谢您的关注!\n这里是AI智能助手可以自由对话。\n支持语音对话。\n支持图片输入。\n支持图片输出画字开头的消息将按要求创作图片。\n支持tool、角色扮演和文字冒险等丰富的插件。\n输入{trigger_prefix}#help 查看详细指令。",
"subscribe_msg": "感谢您的关注!\n这里是AI智能助手,可以自由对话。\n支持语音对话。\n支持图片输入。\n支持图片输出画字开头的消息将按要求创作图片。\n支持tool、角色扮演和文字冒险等丰富的插件。\n输入{trigger_prefix}#help 查看详细指令。",
"use_linkai": false,
"linkai_api_key": "",
"linkai_app_code": ""
"linkai_app_code": "",
"agent": true
}

View File

@@ -1,10 +1,10 @@
# encoding:utf-8
import copy
import json
import logging
import os
import pickle
import copy
from common.log import logger
@@ -183,6 +183,7 @@ available_setting = {
"Minimax_group_id": "",
"Minimax_base_url": "",
"web_port": 9899,
"agent": False # 是否开启Agent模式
}

5
memory/2026-01-29.md Normal file
View File

@@ -0,0 +1,5 @@
# 2026-01-29 记录
## 老王的重要决定
- 今天老王告诉我,他决定要学AI了,这是一个重要的决策
- 这可能会是他学习和职业发展的一个转折点

21
memory/MEMORY.md Normal file
View File

@@ -0,0 +1,21 @@
# Memory
Long-term curated memories and preferences.
## 用户信息
- 用户名:老王
## 用户信息
- 用户名:老王
## 用户偏好
- 喜欢吃红烧肉
- 爱打篮球
## 重要决策
- 决定要学习AI(2026-01-29)
## Notes
- Important decisions and facts go here
- This is your long-term knowledge base