mirror of
https://github.com/zhayujie/chatgpt-on-wechat.git
synced 2026-06-02 00:57:41 +08:00
feat: personal ai agent framework
This commit is contained in:
3
.gitignore
vendored
3
.gitignore
vendored
@@ -35,3 +35,6 @@ plugins/banwords/lib/__pycache__
|
||||
!plugins/linkai
|
||||
!plugins/agent
|
||||
client_config.json
|
||||
ref/
|
||||
.cursor/
|
||||
local/
|
||||
|
||||
10
agent/memory/__init__.py
Normal file
10
agent/memory/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
||||
"""
|
||||
Memory module for AgentMesh
|
||||
|
||||
Provides long-term memory capabilities with hybrid search (vector + keyword)
|
||||
"""
|
||||
|
||||
from agent.memory.manager import MemoryManager
|
||||
from agent.memory.config import MemoryConfig, get_default_memory_config, set_global_memory_config
|
||||
|
||||
__all__ = ['MemoryManager', 'MemoryConfig', 'get_default_memory_config', 'set_global_memory_config']
|
||||
139
agent/memory/chunker.py
Normal file
139
agent/memory/chunker.py
Normal file
@@ -0,0 +1,139 @@
|
||||
"""
|
||||
Text chunking utilities for memory
|
||||
|
||||
Splits text into chunks with token limits and overlap
|
||||
"""
|
||||
|
||||
from typing import List, Tuple
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
class TextChunk:
    """A piece of text together with the source line range it covers."""

    # Chunk contents
    text: str
    # Line number of the first source line in this chunk
    start_line: int
    # Line number of the last source line in this chunk
    end_line: int
|
||||
|
||||
|
||||
class TextChunker:
    """Split text into line-aligned chunks bounded by an estimated token budget.

    Token counts are approximated from character counts
    (``chars_per_token`` characters per token); no tokenizer is involved.
    """

    def __init__(self, max_tokens: int = 500, overlap_tokens: int = 50,
                 chars_per_token: int = 4):
        """
        Initialize chunker

        Args:
            max_tokens: Maximum estimated tokens per chunk (must be positive)
            overlap_tokens: Estimated tokens carried over from the previous chunk
            chars_per_token: Characters assumed per token (must be positive).
                The default of 4 is a rough estimate for mixed English/Chinese text.

        Raises:
            ValueError: If max_tokens or chars_per_token is not positive
        """
        # A non-positive budget would otherwise surface much later as a
        # range() error or as silently dropped text in _split_long_line.
        if max_tokens <= 0:
            raise ValueError(f"max_tokens must be positive, got {max_tokens}")
        if chars_per_token <= 0:
            raise ValueError(f"chars_per_token must be positive, got {chars_per_token}")

        self.max_tokens = max_tokens
        self.overlap_tokens = overlap_tokens
        self.chars_per_token = chars_per_token

    def chunk_text(self, text: str) -> List["TextChunk"]:
        """
        Chunk text into overlapping segments

        Lines are accumulated until the character budget would be exceeded;
        the next chunk then starts with the trailing lines of the previous one
        that fit into the overlap budget. A single line longer than the budget
        is cut into fixed-size pieces that all report the same line number.

        Args:
            text: Input text to chunk

        Returns:
            List of TextChunk objects (empty if text is blank)
        """
        if not text.strip():
            return []

        lines = text.split('\n')
        max_chars = self.max_tokens * self.chars_per_token
        overlap_chars = self.overlap_tokens * self.chars_per_token

        chunks = []
        current_chunk = []
        current_chars = 0  # sum of line lengths; joining newlines are not counted
        start_line = 1

        for i, line in enumerate(lines, start=1):
            line_chars = len(line)

            # A single over-long line cannot share a chunk with anything else
            if line_chars > max_chars:
                if current_chunk:
                    chunks.append(TextChunk(
                        text='\n'.join(current_chunk),
                        start_line=start_line,
                        end_line=i - 1
                    ))
                    current_chunk = []
                    current_chars = 0

                chunks.extend(
                    TextChunk(text=piece, start_line=i, end_line=i)
                    for piece in self._split_long_line(line, max_chars)
                )
                start_line = i + 1
                continue

            if current_chunk and current_chars + line_chars > max_chars:
                # Budget exhausted: emit what we have ...
                chunks.append(TextChunk(
                    text='\n'.join(current_chunk),
                    start_line=start_line,
                    end_line=i - 1
                ))

                # ... and seed the next chunk with the overlapping tail
                overlap_lines = self._get_overlap_lines(current_chunk, overlap_chars)
                current_chunk = overlap_lines + [line]
                current_chars = sum(len(item) for item in current_chunk)
                start_line = i - len(overlap_lines)
            else:
                current_chunk.append(line)
                current_chars += line_chars

        # Whatever is left forms the final chunk
        if current_chunk:
            chunks.append(TextChunk(
                text='\n'.join(current_chunk),
                start_line=start_line,
                end_line=len(lines)
            ))

        return chunks

    def _split_long_line(self, line: str, max_chars: int) -> List[str]:
        """Cut a single line into consecutive pieces of at most max_chars characters"""
        return [line[i:i + max_chars] for i in range(0, len(line), max_chars)]

    def _get_overlap_lines(self, lines: List[str], target_chars: int) -> List[str]:
        """
        Get the trailing lines that fit within target_chars for overlap

        Returns the lines in their original order. With a non-positive budget
        nothing is returned, so empty trailing lines are not carried over when
        overlap is disabled.
        """
        if target_chars <= 0:
            return []

        overlap = []
        chars = 0

        for line in reversed(lines):
            if chars + len(line) > target_chars:
                break
            overlap.append(line)
            chars += len(line)

        # Collected back-to-front; restore source order
        overlap.reverse()
        return overlap

    def chunk_markdown(self, text: str) -> List["TextChunk"]:
        """
        Chunk markdown text

        Currently identical to chunk_text; markdown structure is not yet
        taken into account.
        """
        return self.chunk_text(text)
|
||||
114
agent/memory/config.py
Normal file
114
agent/memory/config.py
Normal file
@@ -0,0 +1,114 @@
|
||||
"""
|
||||
Memory configuration module
|
||||
|
||||
Provides global memory configuration with simplified workspace structure
|
||||
"""
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional, List
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
@dataclass
class MemoryConfig:
    """Configuration for memory storage and search

    All directories are derived from ``workspace_root``.
    """

    # Storage paths (default: ~/cow)
    workspace_root: str = field(default_factory=lambda: os.path.expanduser("~/cow"))

    # Embedding config
    embedding_provider: str = "openai"  # "openai" | "local"
    embedding_model: str = "text-embedding-3-small"
    embedding_dim: int = 1536

    # Chunking config
    chunk_max_tokens: int = 500
    chunk_overlap_tokens: int = 50

    # Search config
    max_results: int = 10
    min_score: float = 0.3

    # Hybrid search weights
    vector_weight: float = 0.7
    keyword_weight: float = 0.3

    # Memory sources
    sources: List[str] = field(default_factory=lambda: ["memory", "session"])

    # Sync config
    enable_auto_sync: bool = True
    sync_on_search: bool = True

    def get_workspace(self) -> Path:
        """
        Get workspace root directory

        A leading ``~`` in ``workspace_root`` is expanded, so values such as
        ``"~/my_agents"`` resolve to the home directory instead of a literal
        ``~`` folder under the current working directory.
        """
        return Path(self.workspace_root).expanduser()

    def get_memory_dir(self) -> Path:
        """Get memory files directory (``<workspace>/memory``)"""
        return self.get_workspace() / "memory"

    def get_db_path(self) -> Path:
        """
        Get SQLite database path for long-term memory index

        Side effect: creates ``<workspace>/memory/long-term`` if it is missing.
        """
        index_dir = self.get_memory_dir() / "long-term"
        index_dir.mkdir(parents=True, exist_ok=True)
        return index_dir / "index.db"

    def get_skills_dir(self) -> Path:
        """Get skills directory (``<workspace>/skills``)"""
        return self.get_workspace() / "skills"

    def get_agent_workspace(self, agent_name: Optional[str] = None) -> Path:
        """
        Get workspace directory for an agent

        Args:
            agent_name: Optional agent name (not used in current implementation)

        Returns:
            Path to workspace directory (created if it does not exist)
        """
        workspace = self.get_workspace()
        workspace.mkdir(parents=True, exist_ok=True)
        return workspace
|
||||
|
||||
|
||||
# Module-wide MemoryConfig; None until one is created lazily or set explicitly
_global_memory_config: Optional[MemoryConfig] = None
|
||||
|
||||
|
||||
def get_default_memory_config() -> MemoryConfig:
    """
    Return the module-wide memory configuration.

    On first use, when nothing has been set, a MemoryConfig with default
    values is created, stored as the global configuration and returned.

    Returns:
        MemoryConfig instance
    """
    global _global_memory_config
    current = _global_memory_config
    if current is None:
        current = MemoryConfig()
        _global_memory_config = current
    return current
|
||||
|
||||
|
||||
def set_global_memory_config(config: MemoryConfig):
    """
    Install ``config`` as the module-wide memory configuration.

    Call this before creating MemoryManager instances that should pick the
    configuration up.

    Args:
        config: MemoryConfig instance to use globally

    Example:
        >>> from agent.memory import MemoryConfig, set_global_memory_config
        >>> config = MemoryConfig(
        ...     workspace_root="~/my_agents",
        ...     embedding_provider="openai",
        ...     vector_weight=0.8
        ... )
        >>> set_global_memory_config(config)
    """
    global _global_memory_config
    _global_memory_config = config
|
||||
175
agent/memory/embedding.py
Normal file
175
agent/memory/embedding.py
Normal file
@@ -0,0 +1,175 @@
|
||||
"""
|
||||
Embedding providers for memory
|
||||
|
||||
Supports OpenAI and local embedding models
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
from abc import ABC, abstractmethod
|
||||
import hashlib
|
||||
import json
|
||||
|
||||
|
||||
class EmbeddingProvider(ABC):
    """Abstract interface every embedding backend has to implement"""

    @abstractmethod
    def embed(self, text: str) -> List[float]:
        """Return the embedding vector for a single text"""
        ...

    @abstractmethod
    def embed_batch(self, texts: List[str]) -> List[List[float]]:
        """Return one embedding vector per input text"""
        ...

    @property
    @abstractmethod
    def dimensions(self) -> int:
        """Length of the vectors produced by this provider"""
        ...
|
||||
|
||||
|
||||
class OpenAIEmbeddingProvider(EmbeddingProvider):
    """OpenAI embedding provider"""

    # Vector sizes of the known OpenAI embedding models
    _KNOWN_DIMENSIONS = {
        "text-embedding-3-small": 1536,
        "text-embedding-3-large": 3072,
        "text-embedding-ada-002": 1536,
    }

    def __init__(self, model: str = "text-embedding-3-small", api_key: Optional[str] = None, api_base: Optional[str] = None):
        """
        Initialize OpenAI embedding provider

        Args:
            model: Model name (e.g. text-embedding-3-small or text-embedding-3-large)
            api_key: OpenAI API key
            api_base: Optional API base URL

        Raises:
            ImportError: If the ``openai`` package is not installed
        """
        self.model = model
        self.api_key = api_key
        self.api_base = api_base or "https://api.openai.com/v1"

        # Lazy import so the module can be loaded without the openai package
        try:
            from openai import OpenAI
        except ImportError as exc:
            raise ImportError("OpenAI package not installed. Install with: pip install openai") from exc

        # The raw api_base (possibly None) is passed on purpose so the client
        # applies its own default when no base URL was given.
        self.client = OpenAI(api_key=api_key, base_url=api_base)

        # Prefer the exact size of a known model; for unknown names fall back
        # to the name-based guess ("small" -> 1536, anything else -> 3072).
        known = self._KNOWN_DIMENSIONS.get(model)
        if known is not None:
            self._dimensions = known
        else:
            self._dimensions = 1536 if "small" in model else 3072

    def embed(self, text: str) -> List[float]:
        """Generate embedding for text"""
        response = self.client.embeddings.create(
            input=text,
            model=self.model
        )
        return response.data[0].embedding

    def embed_batch(self, texts: List[str]) -> List[List[float]]:
        """Generate embeddings for multiple texts (no request is made for an empty list)"""
        if not texts:
            return []

        response = self.client.embeddings.create(
            input=texts,
            model=self.model
        )
        return [item.embedding for item in response.data]

    @property
    def dimensions(self) -> int:
        """Embedding vector length assumed for the configured model"""
        return self._dimensions
|
||||
|
||||
|
||||
class LocalEmbeddingProvider(EmbeddingProvider):
    """Embedding provider that runs a sentence-transformers model locally"""

    def __init__(self, model: str = "all-MiniLM-L6-v2"):
        """
        Load the sentence-transformers model

        Args:
            model: Model name from sentence-transformers

        Raises:
            ImportError: If sentence-transformers is not installed
        """
        self.model_name = model

        try:
            from sentence_transformers import SentenceTransformer
            loaded = SentenceTransformer(model)
            self.model = loaded
            self._dimensions = loaded.get_sentence_embedding_dimension()
        except ImportError:
            message = (
                "sentence-transformers not installed. "
                "Install with: pip install sentence-transformers"
            )
            raise ImportError(message)

    def embed(self, text: str) -> List[float]:
        """Encode one text and return the vector as a plain list"""
        vector = self.model.encode(text, convert_to_numpy=True)
        return vector.tolist()

    def embed_batch(self, texts: List[str]) -> List[List[float]]:
        """Encode several texts at once; an empty input yields an empty list"""
        if texts:
            matrix = self.model.encode(texts, convert_to_numpy=True)
            return matrix.tolist()
        return []

    @property
    def dimensions(self) -> int:
        """Embedding size reported by the loaded model"""
        return self._dimensions
|
||||
|
||||
|
||||
class EmbeddingCache:
    """In-memory store of embeddings keyed by provider, model and text"""

    def __init__(self):
        # Maps the MD5 key from _compute_key to the stored embedding
        self.cache = dict()

    def get(self, text: str, provider: str, model: str) -> Optional[List[float]]:
        """Look up a stored embedding; returns None when nothing is cached"""
        return self.cache.get(self._compute_key(text, provider, model))

    def put(self, text: str, provider: str, model: str, embedding: List[float]):
        """Store (or overwrite) the embedding for this provider/model/text"""
        self.cache[self._compute_key(text, provider, model)] = embedding

    @staticmethod
    def _compute_key(text: str, provider: str, model: str) -> str:
        """Build the cache key: MD5 hex digest of "provider:model:text" """
        raw = f"{provider}:{model}:{text}".encode('utf-8')
        return hashlib.md5(raw).hexdigest()

    def clear(self):
        """Drop every cached embedding"""
        self.cache.clear()
|
||||
|
||||
|
||||
def create_embedding_provider(
    provider: str = "openai",
    model: Optional[str] = None,
    api_key: Optional[str] = None,
    api_base: Optional[str] = None
) -> EmbeddingProvider:
    """
    Build the embedding provider selected by name

    Args:
        provider: Provider name ("openai" or "local")
        model: Model name (provider-specific); a provider default is used when empty
        api_key: API key for remote providers
        api_base: API base URL for remote providers

    Returns:
        EmbeddingProvider instance

    Raises:
        ValueError: If the provider name is not recognized
    """
    if provider == "openai":
        return OpenAIEmbeddingProvider(
            model=model or "text-embedding-3-small",
            api_key=api_key,
            api_base=api_base,
        )

    if provider == "local":
        return LocalEmbeddingProvider(model=model or "all-MiniLM-L6-v2")

    raise ValueError(f"Unknown embedding provider: {provider}")
|
||||
623
agent/memory/manager.py
Normal file
623
agent/memory/manager.py
Normal file
@@ -0,0 +1,623 @@
|
||||
"""
|
||||
Memory manager for AgentMesh
|
||||
|
||||
Provides high-level interface for memory operations
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import List, Optional, Dict, Any
|
||||
from pathlib import Path
|
||||
import hashlib
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from agent.memory.config import MemoryConfig, get_default_memory_config
|
||||
from agent.memory.storage import MemoryStorage, MemoryChunk, SearchResult
|
||||
from agent.memory.chunker import TextChunker
|
||||
from agent.memory.embedding import create_embedding_provider, EmbeddingProvider
|
||||
from agent.memory.summarizer import MemoryFlushManager, create_memory_files_if_needed
|
||||
|
||||
|
||||
class MemoryManager:
|
||||
"""
|
||||
Memory manager with hybrid search capabilities
|
||||
|
||||
Provides long-term memory for agents with vector and keyword search
|
||||
"""
|
||||
|
||||
def __init__(
    self,
    config: Optional[MemoryConfig] = None,
    embedding_provider: Optional[EmbeddingProvider] = None,
    llm_model: Optional[Any] = None
):
    """
    Set up storage, chunker, embedding provider and flush manager

    Args:
        config: Memory configuration (uses global config if not provided)
        embedding_provider: Custom embedding provider (optional)
        llm_model: LLM model handed to the memory flush manager (optional)
    """
    self.config = config or get_default_memory_config()

    # Storage backed by the database path from the configuration
    self.storage = MemoryStorage(self.config.get_db_path())

    # Chunker sized from the configuration
    self.chunker = TextChunker(
        max_tokens=self.config.chunk_max_tokens,
        overlap_tokens=self.config.chunk_overlap_tokens
    )

    # Embedding provider: use the supplied one, otherwise try to build one.
    # Construction failures are tolerated; the manager then runs without
    # an embedding provider.
    if embedding_provider:
        self.embedding_provider = embedding_provider
    else:
        self.embedding_provider = None
        try:
            # Credentials come from the environment
            env_key = os.environ.get('OPENAI_API_KEY')
            env_base = os.environ.get('OPENAI_API_BASE')
            self.embedding_provider = create_embedding_provider(
                provider=self.config.embedding_provider,
                model=self.config.embedding_model,
                api_key=env_key,
                api_base=env_base
            )
        except Exception as e:
            print(f"⚠️ Warning: Embedding provider initialization failed: {e}")
            print(f"ℹ️ Memory will work with keyword search only (no semantic search)")

    # Flush manager operating on the workspace root
    self.flush_manager = MemoryFlushManager(
        workspace_dir=self.config.get_workspace(),
        llm_model=llm_model
    )

    # Make sure the workspace layout exists
    self._init_workspace()

    self._dirty = False
|
||||
|
||||
def _init_workspace(self):
    """Create the memory directory and the default memory files"""
    self.config.get_memory_dir().mkdir(parents=True, exist_ok=True)

    # Default memory files are created relative to the workspace root
    create_memory_files_if_needed(self.config.get_workspace())
|
||||
|
||||
async def search(
    self,
    query: str,
    user_id: Optional[str] = None,
    max_results: Optional[int] = None,
    min_score: Optional[float] = None,
    include_shared: bool = True
) -> List["SearchResult"]:
    """
    Search memory with hybrid search (vector + keyword)

    Args:
        query: Search query
        user_id: User ID for scoped search
        max_results: Maximum results to return (config default when omitted)
        min_score: Minimum score threshold (config default when None;
            an explicit 0.0 is honored and disables score filtering)
        include_shared: Include shared memories

    Returns:
        List of search results sorted by relevance; empty when neither the
        shared nor the user scope is selected
    """
    # A falsy max_results (None or 0) falls back to the configured default
    max_results = max_results or self.config.max_results
    # Only None means "not given": 0.0 is a legitimate threshold
    if min_score is None:
        min_score = self.config.min_score

    # Determine scopes
    scopes = []
    if include_shared:
        scopes.append("shared")
    if user_id:
        scopes.append("user")

    if not scopes:
        return []

    # Pick up pending file changes before searching
    if self.config.sync_on_search and self._dirty:
        await self.sync()

    # Fetch twice the requested amount from each backend so merging has
    # enough candidates to choose from
    candidate_limit = max_results * 2

    # Vector search is only possible with an embedding provider
    vector_results = []
    if self.embedding_provider:
        query_embedding = self.embedding_provider.embed(query)
        vector_results = self.storage.search_vector(
            query_embedding=query_embedding,
            user_id=user_id,
            scopes=scopes,
            limit=candidate_limit
        )

    keyword_results = self.storage.search_keyword(
        query=query,
        user_id=user_id,
        scopes=scopes,
        limit=candidate_limit
    )

    merged = self._merge_results(
        vector_results,
        keyword_results,
        self.config.vector_weight,
        self.config.keyword_weight
    )

    # Filter by min score and limit
    filtered = [r for r in merged if r.score >= min_score]
    return filtered[:max_results]
|
||||
|
||||
async def add_memory(
    self,
    content: str,
    user_id: Optional[str] = None,
    scope: str = "shared",
    source: str = "memory",
    path: Optional[str] = None,
    metadata: Optional[Dict[str, Any]] = None
):
    """
    Add new memory content

    The content is chunked, embedded (when an embedding provider is
    available) and written to storage. Blank content is ignored.

    Args:
        content: Memory content
        user_id: User ID for user-scoped memory
        scope: Memory scope ("shared", "user", "session")
        source: Memory source ("memory" or "session")
        path: File path (auto-generated from a content hash if not provided)
        metadata: Additional metadata
    """
    if not content.strip():
        return

    # Generate path if not provided
    if not path:
        content_hash = hashlib.md5(content.encode('utf-8')).hexdigest()[:8]
        if user_id and scope == "user":
            path = f"memory/users/{user_id}/memory_{content_hash}.md"
        else:
            path = f"memory/shared/memory_{content_hash}.md"

    # Chunk content
    chunks = self.chunker.chunk_text(content)

    # Generate embeddings (if provider available)
    texts = [chunk.text for chunk in chunks]
    if self.embedding_provider:
        embeddings = self.embedding_provider.embed_batch(texts)
    else:
        # No provider: chunks are stored without an embedding
        embeddings = [None] * len(texts)

    # Create memory chunks
    memory_chunks = [
        MemoryChunk(
            id=self._generate_chunk_id(path, chunk.start_line, chunk.end_line),
            agent_id="default",
            user_id=user_id,
            scope=scope,
            source=source,
            path=path,
            start_line=chunk.start_line,
            end_line=chunk.end_line,
            text=chunk.text,
            embedding=embedding,
            hash=MemoryStorage.compute_hash(chunk.text),
            metadata=metadata
        )
        for chunk, embedding in zip(chunks, embeddings)
    ]

    # Save to storage
    self.storage.save_chunks_batch(memory_chunks)

    # Update file metadata. There is no file on disk to stat here, so the
    # current time is recorded as mtime (previously the modification time of
    # this source module was stored by mistake). Size is in UTF-8 bytes to
    # match the st_size value recorded by _sync_file.
    self.storage.update_file_metadata(
        path=path,
        source=source,
        file_hash=MemoryStorage.compute_hash(content),
        mtime=int(datetime.now().timestamp()),
        size=len(content.encode('utf-8'))
    )
|
||||
|
||||
async def sync(self, force: bool = False):
    """
    Index the markdown files of the memory directory

    memory/MEMORY.md is synced first, then every ``*.md`` file below the
    memory directory. Scope and user are derived from the path relative to
    the workspace root:

    - below ``daily``: user-scoped when the path also contains ``users`` or
      has more than three components (user id = component after ``daily``),
      otherwise shared
    - below ``users``: user-scoped (user id = component after ``users``)
    - anything else: shared

    Args:
        force: Force full reindex
            NOTE(review): this flag is not read anywhere in this method;
            unchanged files appear to be skipped by _sync_file regardless.
    """
    memory_dir = self.config.get_memory_dir()
    workspace_dir = self.config.get_workspace()

    # The curated long-term memory file goes first
    main_memory = memory_dir / "MEMORY.md"
    if main_memory.exists():
        await self._sync_file(main_memory, "memory", "shared", None)

    # Then everything in the memory directory (including daily summaries)
    if memory_dir.exists():
        for md_file in memory_dir.rglob("*.md"):
            parts = md_file.relative_to(workspace_dir).parts

            # Name of the path component that precedes the user id,
            # or None for shared files
            if "daily" in parts:
                owner_marker = "daily" if ("users" in parts or len(parts) > 3) else None
            elif "users" in parts:
                owner_marker = "users"
            else:
                owner_marker = None

            if owner_marker is None:
                scope = "shared"
                user_id = None
            else:
                owner_pos = parts.index(owner_marker) + 1
                user_id = parts[owner_pos] if owner_pos < len(parts) else None
                scope = "user"

            await self._sync_file(md_file, "memory", scope, user_id)

    self._dirty = False
|
||||
|
||||
async def _sync_file(
    self,
    file_path: Path,
    source: str,
    scope: str,
    user_id: Optional[str]
):
    """
    Sync a single file into the index

    The file is skipped when its content hash equals the stored one.
    Otherwise its old chunks are deleted, the content is re-chunked and
    re-embedded, and the file metadata is refreshed.

    Args:
        file_path: File to index (must be inside the workspace)
        source: Memory source label stored with each chunk
        scope: Memory scope stored with each chunk
        user_id: Owning user for user-scoped files, otherwise None
    """
    # Memory files are read as UTF-8 (as in load_bootstrap_memories) rather
    # than with the platform default encoding
    content = file_path.read_text(encoding='utf-8')
    file_hash = MemoryStorage.compute_hash(content)

    # Paths are stored relative to the workspace root
    workspace_dir = self.config.get_workspace()
    rel_path = str(file_path.relative_to(workspace_dir))

    # Check if file changed
    stored_hash = self.storage.get_file_hash(rel_path)
    if stored_hash == file_hash:
        return  # No changes

    # Delete old chunks
    self.storage.delete_by_path(rel_path)

    # Chunk and embed
    chunks = self.chunker.chunk_text(content)
    if chunks:
        texts = [chunk.text for chunk in chunks]
        if self.embedding_provider:
            embeddings = self.embedding_provider.embed_batch(texts)
        else:
            embeddings = [None] * len(texts)

        memory_chunks = [
            MemoryChunk(
                id=self._generate_chunk_id(rel_path, chunk.start_line, chunk.end_line),
                agent_id="default",
                user_id=user_id,
                scope=scope,
                source=source,
                path=rel_path,
                start_line=chunk.start_line,
                end_line=chunk.end_line,
                text=chunk.text,
                embedding=embedding,
                hash=MemoryStorage.compute_hash(chunk.text),
                metadata=None
            )
            for chunk, embedding in zip(chunks, embeddings)
        ]

        self.storage.save_chunks_batch(memory_chunks)

    # Record the new hash even when the file produced no chunks. Otherwise a
    # file that is emptied and later restored to its previous content would
    # match the stale stored hash and never be re-indexed, although its
    # chunks were deleted above.
    stat = file_path.stat()
    self.storage.update_file_metadata(
        path=rel_path,
        source=source,
        file_hash=file_hash,
        mtime=int(stat.st_mtime),
        size=stat.st_size
    )
|
||||
|
||||
def should_flush_memory(
    self,
    current_tokens: int,
    context_window: int = 128000,
    reserve_tokens: int = 20000,
    soft_threshold: int = 4000
) -> bool:
    """
    Ask the flush manager whether a memory flush is due

    Args:
        current_tokens: Current session token count
        context_window: Model's context window size (default: 128K)
        reserve_tokens: Reserve tokens for compaction overhead (default: 20K)
        soft_threshold: Trigger N tokens before threshold (default: 4K)

    Returns:
        The flush manager's decision (True if memory flush should run)
    """
    limits = {
        'current_tokens': current_tokens,
        'context_window': context_window,
        'reserve_tokens': reserve_tokens,
        'soft_threshold': soft_threshold,
    }
    return self.flush_manager.should_flush(**limits)
|
||||
|
||||
async def execute_memory_flush(
    self,
    agent_executor,
    current_tokens: int,
    user_id: Optional[str] = None,
    **executor_kwargs
) -> bool:
    """
    Run a memory flush through the flush manager

    Delegates to the flush manager; after a successful flush the manager is
    marked dirty so the next search re-syncs the memory files.

    Args:
        agent_executor: Async function to execute agent with prompt
        current_tokens: Current session token count
        user_id: Optional user ID
        **executor_kwargs: Additional kwargs for agent executor

    Returns:
        True if flush completed successfully

    Example:
        >>> async def run_agent(prompt, system_prompt, silent=False):
        ...     # Your agent execution logic
        ...     pass
        >>>
        >>> if manager.should_flush_memory(current_tokens=100000):
        ...     await manager.execute_memory_flush(
        ...         agent_executor=run_agent,
        ...         current_tokens=100000
        ...     )
    """
    flushed = await self.flush_manager.execute_flush(
        agent_executor=agent_executor,
        current_tokens=current_tokens,
        user_id=user_id,
        **executor_kwargs
    )

    if not flushed:
        return flushed

    # New memories were written: have the next search sync them
    self._dirty = True
    return flushed
|
||||
|
||||
def build_memory_guidance(self, lang: str = "en", include_context: bool = True) -> str:
    """
    Compose the memory section of the agent system prompt

    The text tells the agent how to recall memories (memory_search /
    memory_get), when to store new ones (MEMORY.md vs. today's file) and
    not to volunteer memory contents unprompted.

    Args:
        lang: Language for guidance ("zh" for Chinese, anything else gives English)
        include_context: Whether to append the bootstrap memory context
            returned by load_bootstrap_memories() under a
            "Background Context" heading (default: True)

    Returns:
        Memory guidance text (and optionally context) for system prompt
    """
    # File name of today's memory file, interpolated into the guidance below
    today_file = self.flush_manager.get_today_memory_file().name

    zh_guidance = f"""## 记忆召回
回答关于过去工作、决策、日期、人物、偏好或待办事项的问题前:先用 memory_search 搜索 MEMORY.md + memory/*.md;然后用 memory_get 只读取需要的行。如果搜索后仍不确定,说明你已检查过。

## 记忆存储
当用户分享持久偏好、决策或重要事实时(无论是否明确要求"记住"),主动存储:
- 持久信息(偏好、决策、人物信息)→ memory/MEMORY.md
- 当天的笔记和上下文 → memory/{today_file}
- 静默存储,仅在用户明确要求时确认

## 记忆使用原则
- 不要主动提起或列举记忆内容
- 只在用户明确询问相关信息时才使用记忆
- 记忆是背景知识,不是要展示的内容
- 自然使用记忆,就像你本来就知道这些信息"""

    en_guidance = f"""## Memory Recall
Before answering anything about prior work, decisions, dates, people, preferences, or todos: run memory_search on MEMORY.md + memory/*.md; then use memory_get to pull only the needed lines. If low confidence after search, say you checked.

## Memory Storage
When user shares durable preferences, decisions, or important facts (whether or not they explicitly say "remember"), proactively store:
- Durable info (preferences, decisions, people) → memory/MEMORY.md
- Daily notes and context → memory/{today_file}
- Store silently; only confirm when explicitly requested

## Memory Usage Principles
- Don't proactively mention or list memory contents
- Only use memories when user explicitly asks about them
- Memories are background knowledge, not content to showcase
- Use memories naturally as if you inherently knew this information"""

    guidance = zh_guidance if lang == "zh" else en_guidance

    if not include_context:
        return guidance

    # Append the bootstrap context only when there is any
    bootstrap_context = self.load_bootstrap_memories()
    if bootstrap_context:
        guidance += f"\n\n## Background Context\n\n{bootstrap_context}"

    return guidance
|
||||
|
||||
def load_bootstrap_memories(self, user_id: Optional[str] = None) -> str:
    """
    Load bootstrap memory files for session start

    Reads memory/MEMORY.md and, when ``user_id`` is given,
    memory/users/<user_id>/MEMORY.md. Daily files are not read here.
    The stripped file contents are joined with a blank line and no added
    headers, so they can be injected into a prompt as plain background text.

    Unreadable files are reported with a printed warning and skipped.

    Args:
        user_id: Optional user ID for user-specific memories

    Returns:
        Combined memory content, or an empty string when nothing was found
    """
    memory_dir = self.config.get_memory_dir()

    # (file, label used in the warning message), in output order
    candidates = [(memory_dir / "MEMORY.md", "memory/MEMORY.md")]
    if user_id:
        candidates.append((memory_dir / "users" / user_id / "MEMORY.md", "user memory"))

    sections = []
    for memory_file, label in candidates:
        if not memory_file.exists():
            continue
        try:
            content = memory_file.read_text(encoding='utf-8').strip()
        except Exception as e:
            # Best effort: a broken memory file must not break prompt building
            print(f"Warning: Failed to read {label}: {e}")
            continue
        if content:
            sections.append(content)

    # Joining an empty list yields "", which covers the nothing-found case
    return "\n\n".join(sections)
|
||||
|
||||
def get_status(self) -> Dict[str, Any]:
    """
    Summarise the current state of the memory subsystem.

    Returns:
        A dict with the stored chunk/file counts, the workspace path as a
        string, the dirty flag, and a description of the embedding setup
        (provider, model and resulting search mode).
    """
    counts = self.storage.get_stats()

    # The descriptive fields depend on whether an embedding provider is configured.
    if self.embedding_provider:
        provider_label = self.config.embedding_provider
        model_label = self.config.embedding_model
        mode_label = 'hybrid (vector + keyword)'
    else:
        provider_label = 'disabled'
        model_label = 'N/A'
        mode_label = 'keyword only (FTS5)'

    return dict(
        chunks=counts['chunks'],
        files=counts['files'],
        workspace=str(self.config.get_workspace()),
        dirty=self._dirty,
        embedding_enabled=self.embedding_provider is not None,
        embedding_provider=provider_label,
        embedding_model=model_label,
        search_mode=mode_label,
    )
|
||||
|
||||
def mark_dirty(self):
    """Flag the memory state as changed so that a later sync knows work is pending."""
    self._dirty = True
|
||||
|
||||
def close(self):
    """Shut the manager down by closing its underlying storage."""
    self.storage.close()
|
||||
|
||||
# Helper methods
|
||||
|
||||
def _generate_chunk_id(self, path: str, start_line: int, end_line: int) -> str:
|
||||
"""Generate unique chunk ID"""
|
||||
content = f"{path}:{start_line}:{end_line}"
|
||||
return hashlib.md5(content.encode('utf-8')).hexdigest()
|
||||
|
||||
def _merge_results(
|
||||
self,
|
||||
vector_results: List[SearchResult],
|
||||
keyword_results: List[SearchResult],
|
||||
vector_weight: float,
|
||||
keyword_weight: float
|
||||
) -> List[SearchResult]:
|
||||
"""Merge vector and keyword search results"""
|
||||
# Create a map by (path, start_line, end_line)
|
||||
merged_map = {}
|
||||
|
||||
for result in vector_results:
|
||||
key = (result.path, result.start_line, result.end_line)
|
||||
merged_map[key] = {
|
||||
'result': result,
|
||||
'vector_score': result.score,
|
||||
'keyword_score': 0.0
|
||||
}
|
||||
|
||||
for result in keyword_results:
|
||||
key = (result.path, result.start_line, result.end_line)
|
||||
if key in merged_map:
|
||||
merged_map[key]['keyword_score'] = result.score
|
||||
else:
|
||||
merged_map[key] = {
|
||||
'result': result,
|
||||
'vector_score': 0.0,
|
||||
'keyword_score': result.score
|
||||
}
|
||||
|
||||
# Calculate combined scores
|
||||
merged_results = []
|
||||
for entry in merged_map.values():
|
||||
combined_score = (
|
||||
vector_weight * entry['vector_score'] +
|
||||
keyword_weight * entry['keyword_score']
|
||||
)
|
||||
|
||||
result = entry['result']
|
||||
merged_results.append(SearchResult(
|
||||
path=result.path,
|
||||
start_line=result.start_line,
|
||||
end_line=result.end_line,
|
||||
score=combined_score,
|
||||
snippet=result.snippet,
|
||||
source=result.source,
|
||||
user_id=result.user_id
|
||||
))
|
||||
|
||||
# Sort by score
|
||||
merged_results.sort(key=lambda r: r.score, reverse=True)
|
||||
return merged_results
|
||||
418
agent/memory/storage.py
Normal file
418
agent/memory/storage.py
Normal file
@@ -0,0 +1,418 @@
|
||||
"""
|
||||
Storage layer for memory using SQLite + FTS5
|
||||
|
||||
Provides vector and keyword search capabilities
|
||||
"""
|
||||
|
||||
import sqlite3
|
||||
import json
|
||||
import hashlib
|
||||
from typing import List, Dict, Optional, Any
|
||||
from pathlib import Path
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
class MemoryChunk:
    """
    One stored slice of a memory file, optionally carrying its embedding.

    Attributes:
        id: Primary key of the chunk.
        user_id: Owning user, or None.
        scope: One of "shared", "user" or "session".
        source: One of "memory" or "session".
        path: Path of the file the chunk was taken from.
        start_line: First line of the chunk in that file.
        end_line: Last line of the chunk in that file.
        text: Chunk text.
        embedding: Embedding vector, or None when not embedded.
        hash: Hash string stored alongside the chunk.
        metadata: Optional extra data (stored as JSON by the storage layer).
    """
    id: str
    user_id: Optional[str]
    scope: str
    source: str
    path: str
    start_line: int
    end_line: int
    text: str
    embedding: Optional[List[float]]
    hash: str
    metadata: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
@dataclass
class SearchResult:
    """
    A single search hit.

    Attributes:
        path: Path of the file the hit belongs to.
        start_line: First line of the matching chunk.
        end_line: Last line of the matching chunk.
        score: Relevance score assigned by the search that produced the hit.
        snippet: Chunk text, possibly truncated.
        source: Source label of the chunk.
        user_id: Owning user of the chunk, or None.
    """
    path: str
    start_line: int
    end_line: int
    score: float
    snippet: str
    source: str
    user_id: Optional[str] = None
|
||||
|
||||
|
||||
class MemoryStorage:
    """
    SQLite-backed chunk store with an FTS5 index for keyword search.

    Chunks live in the ``chunks`` table; ``chunks_fts`` is an external-content
    FTS5 table over it that is kept in sync by triggers. Per-file bookkeeping
    (hash, mtime, size) lives in the ``files`` table. Vector search is done in
    Python over the JSON-encoded embeddings stored with each chunk.
    """

    def __init__(self, db_path: Path):
        """
        Open the database and make sure the schema exists.

        Args:
            db_path: Location of the SQLite database file.
        """
        self.db_path = db_path
        self.conn: Optional[sqlite3.Connection] = None
        self._init_db()

    def _init_db(self):
        """Connect to the database and create tables, indexes, the FTS index and its triggers."""
        self.conn = sqlite3.connect(str(self.db_path))
        self.conn.row_factory = sqlite3.Row

        # Use write-ahead logging for the journal
        self.conn.execute("PRAGMA journal_mode=WAL")

        # Chunks are written with INSERT OR REPLACE. SQLite only fires DELETE
        # triggers for rows removed by REPLACE when recursive triggers are on;
        # without this the FTS index keeps entries for replaced rows.
        self.conn.execute("PRAGMA recursive_triggers=ON")

        # Chunk table (embedding and metadata are JSON text)
        self.conn.execute("""
            CREATE TABLE IF NOT EXISTS chunks (
                id TEXT PRIMARY KEY,
                user_id TEXT,
                scope TEXT NOT NULL DEFAULT 'shared',
                source TEXT NOT NULL DEFAULT 'memory',
                path TEXT NOT NULL,
                start_line INTEGER NOT NULL,
                end_line INTEGER NOT NULL,
                text TEXT NOT NULL,
                embedding TEXT,
                hash TEXT NOT NULL,
                metadata TEXT,
                created_at INTEGER DEFAULT (strftime('%s', 'now')),
                updated_at INTEGER DEFAULT (strftime('%s', 'now'))
            )
        """)

        # Indexes
        self.conn.execute("""
            CREATE INDEX IF NOT EXISTS idx_chunks_user
            ON chunks(user_id)
        """)

        self.conn.execute("""
            CREATE INDEX IF NOT EXISTS idx_chunks_scope
            ON chunks(scope)
        """)

        self.conn.execute("""
            CREATE INDEX IF NOT EXISTS idx_chunks_hash
            ON chunks(path, hash)
        """)

        # FTS5 virtual table for keyword search (external content: rows live in `chunks`)
        self.conn.execute("""
            CREATE VIRTUAL TABLE IF NOT EXISTS chunks_fts USING fts5(
                text,
                id UNINDEXED,
                user_id UNINDEXED,
                path UNINDEXED,
                source UNINDEXED,
                scope UNINDEXED,
                content='chunks',
                content_rowid='rowid'
            )
        """)

        # Triggers that keep the FTS index in sync with `chunks`.
        # They are dropped and recreated so that a database holding an earlier
        # trigger definition picks up the current one.
        for trigger_name in ("chunks_ai", "chunks_ad", "chunks_au"):
            self.conn.execute(f"DROP TRIGGER IF EXISTS {trigger_name}")

        self.conn.execute("""
            CREATE TRIGGER chunks_ai AFTER INSERT ON chunks BEGIN
                INSERT INTO chunks_fts(rowid, text, id, user_id, path, source, scope)
                VALUES (new.rowid, new.text, new.id, new.user_id, new.path, new.source, new.scope);
            END
        """)

        # For an external-content table the old column values must be handed to
        # FTS5 through its special 'delete' command: by the time an AFTER
        # trigger runs, the row can no longer be read back from `chunks`.
        self.conn.execute("""
            CREATE TRIGGER chunks_ad AFTER DELETE ON chunks BEGIN
                INSERT INTO chunks_fts(chunks_fts, rowid, text, id, user_id, path, source, scope)
                VALUES ('delete', old.rowid, old.text, old.id, old.user_id, old.path, old.source, old.scope);
            END
        """)

        self.conn.execute("""
            CREATE TRIGGER chunks_au AFTER UPDATE ON chunks BEGIN
                INSERT INTO chunks_fts(chunks_fts, rowid, text, id, user_id, path, source, scope)
                VALUES ('delete', old.rowid, old.text, old.id, old.user_id, old.path, old.source, old.scope);
                INSERT INTO chunks_fts(rowid, text, id, user_id, path, source, scope)
                VALUES (new.rowid, new.text, new.id, new.user_id, new.path, new.source, new.scope);
            END
        """)

        # Files metadata table
        self.conn.execute("""
            CREATE TABLE IF NOT EXISTS files (
                path TEXT PRIMARY KEY,
                source TEXT NOT NULL DEFAULT 'memory',
                hash TEXT NOT NULL,
                mtime INTEGER NOT NULL,
                size INTEGER NOT NULL,
                updated_at INTEGER DEFAULT (strftime('%s', 'now'))
            )
        """)

        self.conn.commit()

    def save_chunk(self, chunk: MemoryChunk):
        """
        Insert a chunk, replacing any existing chunk with the same ID, and commit.

        Falsy embeddings/metadata (None or empty) are stored as NULL.
        """
        self.conn.execute("""
            INSERT OR REPLACE INTO chunks
            (id, user_id, scope, source, path, start_line, end_line, text, embedding, hash, metadata, updated_at)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, strftime('%s', 'now'))
        """, (
            chunk.id,
            chunk.user_id,
            chunk.scope,
            chunk.source,
            chunk.path,
            chunk.start_line,
            chunk.end_line,
            chunk.text,
            json.dumps(chunk.embedding) if chunk.embedding else None,
            chunk.hash,
            json.dumps(chunk.metadata) if chunk.metadata else None
        ))
        self.conn.commit()

    def save_chunks_batch(self, chunks: List[MemoryChunk]):
        """Insert or replace several chunks with one ``executemany`` call, then commit."""
        self.conn.executemany("""
            INSERT OR REPLACE INTO chunks
            (id, user_id, scope, source, path, start_line, end_line, text, embedding, hash, metadata, updated_at)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, strftime('%s', 'now'))
        """, [
            (
                c.id, c.user_id, c.scope, c.source, c.path,
                c.start_line, c.end_line, c.text,
                json.dumps(c.embedding) if c.embedding else None,
                c.hash,
                json.dumps(c.metadata) if c.metadata else None
            )
            for c in chunks
        ])
        self.conn.commit()

    def get_chunk(self, chunk_id: str) -> Optional[MemoryChunk]:
        """Return the chunk with the given ID, or None if there is no such row."""
        row = self.conn.execute("""
            SELECT * FROM chunks WHERE id = ?
        """, (chunk_id,)).fetchone()

        if not row:
            return None

        return self._row_to_chunk(row)

    def search_vector(
        self,
        query_embedding: List[float],
        user_id: Optional[str] = None,
        scopes: Optional[List[str]] = None,
        limit: int = 10
    ) -> List[SearchResult]:
        """
        Vector similarity search using in-memory cosine similarity.

        Every candidate row with a stored embedding is loaded and scored in
        Python (sqlite-vec can be added later for better performance).

        Args:
            query_embedding: Embedding of the query.
            user_id: When given, rows must be shared or belong to this user.
            scopes: Scopes to search; defaults to ["shared"], plus "user"
                when ``user_id`` is given. The list is not modified.
            limit: Maximum number of results.

        Returns:
            Hits with similarity > 0, best first, with snippets truncated to
            500 characters.
        """
        if scopes is None:
            scopes = ["shared"]
            if user_id:
                scopes.append("user")

        # Build query
        scope_placeholders = ','.join('?' * len(scopes))
        # Copy: the parameter list is extended below and must not alias the
        # caller's scopes list (it would gain a stray user_id entry).
        params = list(scopes)

        if user_id:
            query = f"""
                SELECT * FROM chunks
                WHERE scope IN ({scope_placeholders})
                AND (scope = 'shared' OR user_id = ?)
                AND embedding IS NOT NULL
            """
            params.append(user_id)
        else:
            query = f"""
                SELECT * FROM chunks
                WHERE scope IN ({scope_placeholders})
                AND embedding IS NOT NULL
            """

        rows = self.conn.execute(query, params).fetchall()

        # Calculate cosine similarity
        results = []
        for row in rows:
            embedding = json.loads(row['embedding'])
            similarity = self._cosine_similarity(query_embedding, embedding)

            if similarity > 0:
                results.append((similarity, row))

        # Sort by similarity and limit
        results.sort(key=lambda x: x[0], reverse=True)
        results = results[:limit]

        return [
            SearchResult(
                path=row['path'],
                start_line=row['start_line'],
                end_line=row['end_line'],
                score=score,
                snippet=self._truncate_text(row['text'], 500),
                source=row['source'],
                user_id=row['user_id']
            )
            for score, row in results
        ]

    def search_keyword(
        self,
        query: str,
        user_id: Optional[str] = None,
        scopes: Optional[List[str]] = None,
        limit: int = 10
    ) -> List[SearchResult]:
        """
        Keyword search using FTS5.

        Args:
            query: Raw query text; it is tokenised by ``_build_fts_query``.
            user_id: When given, rows must be shared or belong to this user.
            scopes: Scopes to search; defaults to ["shared"], plus "user"
                when ``user_id`` is given. The list is not modified.
            limit: Maximum number of results.

        Returns:
            Matching hits ordered by BM25 rank, or an empty list when the
            query contains no usable tokens.
        """
        if scopes is None:
            scopes = ["shared"]
            if user_id:
                scopes.append("user")

        # Build FTS query
        fts_query = self._build_fts_query(query)
        if not fts_query:
            return []

        scope_placeholders = ','.join('?' * len(scopes))
        params = [fts_query] + list(scopes)

        # The FTS rowid is the rowid of the content row, so join on it directly.
        if user_id:
            sql_query = f"""
                SELECT chunks.*, bm25(chunks_fts) as rank
                FROM chunks_fts
                JOIN chunks ON chunks.rowid = chunks_fts.rowid
                WHERE chunks_fts MATCH ?
                AND chunks.scope IN ({scope_placeholders})
                AND (chunks.scope = 'shared' OR chunks.user_id = ?)
                ORDER BY rank
                LIMIT ?
            """
            params.extend([user_id, limit])
        else:
            sql_query = f"""
                SELECT chunks.*, bm25(chunks_fts) as rank
                FROM chunks_fts
                JOIN chunks ON chunks.rowid = chunks_fts.rowid
                WHERE chunks_fts MATCH ?
                AND chunks.scope IN ({scope_placeholders})
                ORDER BY rank
                LIMIT ?
            """
            params.append(limit)

        rows = self.conn.execute(sql_query, params).fetchall()

        return [
            SearchResult(
                path=row['path'],
                start_line=row['start_line'],
                end_line=row['end_line'],
                score=self._bm25_rank_to_score(row['rank']),
                snippet=self._truncate_text(row['text'], 500),
                source=row['source'],
                user_id=row['user_id']
            )
            for row in rows
        ]

    def delete_by_path(self, path: str):
        """Delete all chunks that came from the given file path and commit."""
        self.conn.execute("""
            DELETE FROM chunks WHERE path = ?
        """, (path,))
        self.conn.commit()

    def get_file_hash(self, path: str) -> Optional[str]:
        """Return the hash recorded for a file in the ``files`` table, or None if unknown."""
        row = self.conn.execute("""
            SELECT hash FROM files WHERE path = ?
        """, (path,)).fetchone()
        return row['hash'] if row else None

    def update_file_metadata(self, path: str, source: str, file_hash: str, mtime: int, size: int):
        """Insert or replace the ``files`` row for ``path`` and commit."""
        self.conn.execute("""
            INSERT OR REPLACE INTO files (path, source, hash, mtime, size, updated_at)
            VALUES (?, ?, ?, ?, ?, strftime('%s', 'now'))
        """, (path, source, file_hash, mtime, size))
        self.conn.commit()

    def get_stats(self) -> Dict[str, int]:
        """Return the row counts of the ``chunks`` and ``files`` tables."""
        chunks_count = self.conn.execute("""
            SELECT COUNT(*) as cnt FROM chunks
        """).fetchone()['cnt']

        files_count = self.conn.execute("""
            SELECT COUNT(*) as cnt FROM files
        """).fetchone()['cnt']

        return {
            'chunks': chunks_count,
            'files': files_count
        }

    def close(self):
        """Close the database connection if one is open."""
        if self.conn:
            self.conn.close()

    # Helper methods

    def _row_to_chunk(self, row) -> MemoryChunk:
        """Convert a ``chunks`` row to a MemoryChunk, decoding the JSON columns."""
        return MemoryChunk(
            id=row['id'],
            user_id=row['user_id'],
            scope=row['scope'],
            source=row['source'],
            path=row['path'],
            start_line=row['start_line'],
            end_line=row['end_line'],
            text=row['text'],
            embedding=json.loads(row['embedding']) if row['embedding'] else None,
            hash=row['hash'],
            metadata=json.loads(row['metadata']) if row['metadata'] else None
        )

    @staticmethod
    def _cosine_similarity(vec1: List[float], vec2: List[float]) -> float:
        """
        Calculate cosine similarity between two vectors.

        Returns 0.0 when the vectors differ in length or either has zero norm.
        """
        if len(vec1) != len(vec2):
            return 0.0

        dot_product = sum(a * b for a, b in zip(vec1, vec2))
        norm1 = sum(a * a for a in vec1) ** 0.5
        norm2 = sum(b * b for b in vec2) ** 0.5

        if norm1 == 0 or norm2 == 0:
            return 0.0

        return dot_product / (norm1 * norm2)

    @staticmethod
    def _build_fts_query(raw_query: str) -> Optional[str]:
        """
        Build an FTS5 query from raw text.

        Runs of ASCII letters, digits, underscores and CJK characters
        (U+4E00-U+9FFF) are extracted, each is double-quoted, and the terms
        are joined with AND. Returns None when no such run is found.
        """
        import re
        tokens = re.findall(r'[A-Za-z0-9_\u4e00-\u9fff]+', raw_query)
        if not tokens:
            return None
        quoted = [f'"{t}"' for t in tokens]
        return ' AND '.join(quoted)

    @staticmethod
    def _bm25_rank_to_score(rank: float) -> float:
        """
        Convert a BM25 rank to a score in (0, 1].

        Negative ranks are clamped to 0 (score 1.0); a missing rank is
        treated as 999.
        """
        # NOTE(review): SQLite's bm25() reportedly returns lower (negative)
        # values for better matches; if so this clamp gives every hit 1.0 —
        # confirm that flattening is intended before relying on keyword scores.
        normalized = max(0, rank) if rank is not None else 999
        return 1 / (1 + normalized)

    @staticmethod
    def _truncate_text(text: str, max_chars: int) -> str:
        """Return ``text`` unchanged if short enough, else its first ``max_chars`` characters plus "..."."""
        if len(text) <= max_chars:
            return text
        return text[:max_chars] + "..."

    @staticmethod
    def compute_hash(content: str) -> str:
        """Compute the SHA256 hex digest of the UTF-8 encoded content."""
        return hashlib.sha256(content.encode('utf-8')).hexdigest()
|
||||
235
agent/memory/summarizer.py
Normal file
235
agent/memory/summarizer.py
Normal file
@@ -0,0 +1,235 @@
|
||||
"""
|
||||
Memory flush manager
|
||||
|
||||
Triggers memory flush before context compaction (similar to clawdbot)
|
||||
"""
|
||||
|
||||
from typing import Optional, Callable, Any
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
class MemoryFlushManager:
    """
    Decides when to flush memories to disk ahead of context compaction and
    runs that flush as a silent agent turn.

    Modelled on clawdbot's memory flush mechanism:
    - daily notes go to memory/YYYY-MM-DD.md
    - long-term curated memories go to memory/MEMORY.md
    - per-user variants live under memory/users/<user_id>/
    """

    def __init__(
        self,
        workspace_dir: Path,
        llm_model: Optional[Any] = None
    ):
        """
        Set up the manager and create the ``memory`` directory if missing.

        Args:
            workspace_dir: Workspace directory
            llm_model: LLM model for agent execution (optional; only stored here)
        """
        self.workspace_dir = workspace_dir
        self.llm_model = llm_model

        self.memory_dir = workspace_dir / "memory"
        self.memory_dir.mkdir(parents=True, exist_ok=True)

        # Bookkeeping about the most recent successful flush
        self.last_flush_token_count: Optional[int] = None
        self.last_flush_timestamp: Optional[datetime] = None

    def _resolve_dir(self, user_id: Optional[str]) -> Path:
        """Return the directory holding memory files for ``user_id`` (created on demand), or the shared one."""
        if not user_id:
            return self.memory_dir
        user_dir = self.memory_dir / "users" / user_id
        user_dir.mkdir(parents=True, exist_ok=True)
        return user_dir

    def should_flush(
        self,
        current_tokens: int,
        context_window: int,
        reserve_tokens: int = 20000,
        soft_threshold: int = 4000
    ) -> bool:
        """
        Decide whether a memory flush should run now.

        The trigger point is
        ``context_window - reserve_tokens - soft_threshold`` (floored at 0),
        following clawdbot's shouldRunMemoryFlush logic.

        Args:
            current_tokens: Current session token count
            context_window: Model's context window size
            reserve_tokens: Reserve tokens for compaction overhead
            soft_threshold: Trigger flush N tokens before threshold

        Returns:
            True when the token count has reached the trigger point and has
            grown by more than ``soft_threshold`` since the last flush.
        """
        if current_tokens <= 0:
            return False

        trigger_point = max(0, context_window - reserve_tokens - soft_threshold)
        if trigger_point <= 0 or current_tokens < trigger_point:
            return False

        # A flush already happened close to this token count: skip to avoid
        # flushing twice within the same compaction cycle.
        previous = self.last_flush_token_count
        flushed_recently = (
            previous is not None
            and current_tokens <= previous + soft_threshold
        )
        return not flushed_recently

    def get_today_memory_file(self, user_id: Optional[str] = None) -> Path:
        """
        Path of today's notes file: memory/YYYY-MM-DD.md

        Args:
            user_id: Optional user ID; selects (and creates) the user's directory

        Returns:
            Path to today's memory file (the file itself is not created)
        """
        date_stamp = datetime.now().strftime("%Y-%m-%d")
        return self._resolve_dir(user_id) / f"{date_stamp}.md"

    def get_main_memory_file(self, user_id: Optional[str] = None) -> Path:
        """
        Path of the long-term memory file: memory/MEMORY.md

        Args:
            user_id: Optional user ID; selects (and creates) the user's directory

        Returns:
            Path to the main memory file (the file itself is not created)
        """
        return self._resolve_dir(user_id) / "MEMORY.md"

    def create_flush_prompt(self) -> str:
        """
        Build the user prompt for the flush turn, naming today's notes file.

        Similar to clawdbot's DEFAULT_MEMORY_FLUSH_PROMPT
        """
        today = datetime.now().strftime("%Y-%m-%d")
        parts = [
            "Pre-compaction memory flush. ",
            f"Store durable memories now (use memory/{today}.md for daily notes; ",
            "create memory/ if needed). ",
            "If nothing to store, reply with NO_REPLY.",
        ]
        return "".join(parts)

    def create_flush_system_prompt(self) -> str:
        """
        Build the system prompt for the flush turn.

        Similar to clawdbot's DEFAULT_MEMORY_FLUSH_SYSTEM_PROMPT
        """
        parts = [
            "Pre-compaction memory flush turn. ",
            "The session is near auto-compaction; capture durable memories to disk. ",
            "You may reply, but usually NO_REPLY is correct.",
        ]
        return "".join(parts)

    async def execute_flush(
        self,
        agent_executor: Callable,
        current_tokens: int,
        user_id: Optional[str] = None,
        **executor_kwargs
    ) -> bool:
        """
        Run the flush as a silent agent turn and record it on success.

        Args:
            agent_executor: Awaitable callable invoked with ``prompt``,
                ``system_prompt``, ``silent=True`` and ``executor_kwargs``
            current_tokens: Current token count (stored as the last flush point)
            user_id: Optional user ID (accepted but not used by this method)
            **executor_kwargs: Additional kwargs for agent executor

        Returns:
            True if the executor finished without raising, False otherwise
        """
        try:
            flush_prompt = self.create_flush_prompt()
            flush_system_prompt = self.create_flush_system_prompt()

            # No user-visible reply is expected from this turn
            await agent_executor(
                prompt=flush_prompt,
                system_prompt=flush_system_prompt,
                silent=True,
                **executor_kwargs
            )
        except Exception as e:
            print(f"Memory flush failed: {e}")
            return False

        # Remember where and when the flush happened
        self.last_flush_token_count = current_tokens
        self.last_flush_timestamp = datetime.now()
        return True

    def get_status(self) -> dict:
        """Report the last flush (token count, ISO timestamp or None) and the shared memory file paths."""
        stamp = self.last_flush_timestamp
        return {
            'last_flush_tokens': self.last_flush_token_count,
            'last_flush_time': stamp.isoformat() if stamp else None,
            'today_file': str(self.get_today_memory_file()),
            'main_file': str(self.get_main_memory_file())
        }
|
||||
|
||||
|
||||
def create_memory_files_if_needed(workspace_dir: Path, user_id: Optional[str] = None):
    """
    Make sure the default memory files exist, without touching existing ones.

    Creates ``memory/`` under the workspace and, inside it (or inside
    ``memory/users/<user_id>/`` when a user is given), an empty ``MEMORY.md``
    and today's ``YYYY-MM-DD.md`` with a short header.

    Args:
        workspace_dir: Workspace directory
        user_id: Optional user ID for user-specific files
    """
    memory_root = workspace_dir / "memory"
    memory_root.mkdir(parents=True, exist_ok=True)

    # Pick the directory the files belong to
    target_dir = memory_root
    if user_id:
        target_dir = memory_root / "users" / user_id
        target_dir.mkdir(parents=True, exist_ok=True)

    # Long-term memory starts out empty: no visible "Memory" header, so the
    # content can blend into the context when it is loaded later.
    long_term_file = target_dir / "MEMORY.md"
    if not long_term_file.exists():
        long_term_file.write_text("")

    # Today's notes file gets a small title block
    date_stamp = datetime.now().strftime("%Y-%m-%d")
    daily_file = target_dir / f"{date_stamp}.md"
    if not daily_file.exists():
        header = (
            f"# Daily Memory: {date_stamp}\n\n"
            f"Day-to-day notes and running context.\n\n"
        )
        daily_file.write_text(header)
|
||||
10
agent/memory/tools/__init__.py
Normal file
10
agent/memory/tools/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
||||
"""
|
||||
Memory tools for AgentMesh
|
||||
|
||||
Provides memory_search and memory_get tools for agents
|
||||
"""
|
||||
|
||||
from agent.memory.tools.memory_search import MemorySearchTool
|
||||
from agent.memory.tools.memory_get import MemoryGetTool
|
||||
|
||||
__all__ = ['MemorySearchTool', 'MemoryGetTool']
|
||||
118
agent/memory/tools/memory_get.py
Normal file
118
agent/memory/tools/memory_get.py
Normal file
@@ -0,0 +1,118 @@
|
||||
"""
|
||||
Memory get tool
|
||||
|
||||
Allows agents to read specific sections from memory files
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, Optional
|
||||
from pathlib import Path
|
||||
from agent.tools.base_tool import BaseTool
|
||||
from agent.memory.manager import MemoryManager
|
||||
|
||||
|
||||
class MemoryGetTool(BaseTool):
    """Tool that lets an agent read a memory file (or a line range of it) from the workspace."""

    def __init__(self, memory_manager: MemoryManager):
        """
        Initialize memory get tool

        Args:
            memory_manager: MemoryManager instance; its config supplies the workspace directory
        """
        super().__init__()
        self.memory_manager = memory_manager
        self._name = "memory_get"
        self._description = (
            "Read specific memory file content by path and line range. "
            "Use after memory_search to get full context from historical memory files."
        )

    @property
    def name(self) -> str:
        """Tool name exposed to the agent."""
        return self._name

    @property
    def description(self) -> str:
        """Tool description exposed to the agent."""
        return self._description

    @property
    def parameters(self) -> Dict[str, Any]:
        """JSON-schema style description of the accepted arguments."""
        return {
            "type": "object",
            "properties": {
                "path": {
                    "type": "string",
                    # Paths are resolved against the workspace, and the memory
                    # files live in its memory/ subdirectory.
                    "description": "Relative path to the memory file (e.g., 'memory/MEMORY.md', 'memory/2024-01-29.md')"
                },
                "start_line": {
                    "type": "integer",
                    "description": "Starting line number (optional, default: 1)",
                    "default": 1
                },
                "num_lines": {
                    "type": "integer",
                    "description": "Number of lines to read (optional, reads all if not specified)"
                }
            },
            "required": ["path"]
        }

    async def execute(self, **kwargs) -> str:
        """
        Read a memory file and return the requested lines with a short header.

        Args:
            path: File path relative to the workspace
            start_line: 1-based first line (values below 1, or a missing/null value, mean 1)
            num_lines: Number of lines to return (missing, null or 0 means "to the end")

        Returns:
            "File: ..." / "Lines: ..." header followed by the content, or a
            string starting with "Error" when the path is missing, outside
            the workspace, not found or unreadable.
        """
        path = kwargs.get("path")
        # An explicit null would otherwise fail the numeric comparison below.
        start_line = kwargs.get("start_line") or 1
        num_lines = kwargs.get("num_lines")

        if not path:
            return "Error: path parameter is required"

        try:
            workspace_dir = Path(self.memory_manager.config.get_workspace()).resolve()
            file_path = (workspace_dir / path).resolve()

            # The path comes from the model: refuse anything that resolves
            # outside the workspace ("../", absolute paths, escaping symlinks).
            try:
                file_path.relative_to(workspace_dir)
            except ValueError:
                return f"Error: Path is outside the workspace: {path}"

            if not file_path.exists():
                return f"Error: File not found: {path}"

            # Memory files are read as UTF-8 elsewhere in the memory code;
            # relying on the platform default would garble non-ASCII notes.
            content = file_path.read_text(encoding='utf-8')
            lines = content.split('\n')

            # Handle line range
            if start_line < 1:
                start_line = 1

            start_idx = start_line - 1

            if num_lines:
                end_idx = start_idx + num_lines
                selected_lines = lines[start_idx:end_idx]
            else:
                selected_lines = lines[start_idx:]

            result = '\n'.join(selected_lines)

            # Add metadata
            total_lines = len(lines)
            shown_lines = len(selected_lines)

            output = [
                f"File: {path}",
                f"Lines: {start_line}-{start_line + shown_lines - 1} (total: {total_lines})",
                "",
                result
            ]

            return '\n'.join(output)

        except Exception as e:
            # Tool calls report failures as text instead of raising.
            return f"Error reading memory file: {str(e)}"
|
||||
106
agent/memory/tools/memory_search.py
Normal file
106
agent/memory/tools/memory_search.py
Normal file
@@ -0,0 +1,106 @@
|
||||
"""
|
||||
Memory search tool
|
||||
|
||||
Allows agents to search their memory using semantic and keyword search
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, Optional
|
||||
from agent.tools.base_tool import BaseTool
|
||||
from agent.memory.manager import MemoryManager
|
||||
|
||||
|
||||
class MemorySearchTool(BaseTool):
    """Tool that lets an agent search its memory through the MemoryManager."""

    def __init__(self, memory_manager: MemoryManager, user_id: Optional[str] = None):
        """
        Initialize memory search tool

        Args:
            memory_manager: MemoryManager instance
            user_id: Optional user ID for scoped search
        """
        super().__init__()
        self.memory_manager = memory_manager
        self.user_id = user_id
        self._name = "memory_search"
        # Only MEMORY.md is loaded at session start; daily notes (including
        # today's) are reachable through this tool alone, so the description
        # must not tell the model that recent days are already in context.
        self._description = (
            "Search memory files (daily notes and other indexed memories) using semantic and keyword search. "
            "Long-term memory (MEMORY.md) is already loaded. "
            "Use this for daily notes, specific past events, or when current context lacks needed info."
        )

    @property
    def name(self) -> str:
        """Tool name exposed to the agent."""
        return self._name

    @property
    def description(self) -> str:
        """Tool description exposed to the agent."""
        return self._description

    @property
    def parameters(self) -> Dict[str, Any]:
        """JSON-schema style description of the accepted arguments."""
        return {
            "type": "object",
            "properties": {
                "query": {
                    "type": "string",
                    "description": "Search query (can be natural language question or keywords)"
                },
                "max_results": {
                    "type": "integer",
                    "description": "Maximum number of results to return (default: 10)",
                    "default": 10
                },
                "min_score": {
                    "type": "number",
                    "description": "Minimum relevance score (0-1, default: 0.3)",
                    "default": 0.3
                }
            },
            "required": ["query"]
        }

    async def execute(self, **kwargs) -> str:
        """
        Run a memory search and format the hits as text.

        Args:
            query: Search query
            max_results: Maximum results (default 10)
            min_score: Minimum score (default 0.3)

        Returns:
            A numbered list of hits (path, line range, score, snippet), a
            "No relevant memories found" message, or a string starting with
            "Error" when the query is missing or the search raises.
        """
        query = kwargs.get("query")
        max_results = kwargs.get("max_results", 10)
        min_score = kwargs.get("min_score", 0.3)

        if not query:
            return "Error: query parameter is required"

        try:
            results = await self.memory_manager.search(
                query=query,
                user_id=self.user_id,
                max_results=max_results,
                min_score=min_score,
                include_shared=True
            )

            if not results:
                return f"No relevant memories found for query: {query}"

            # Format results
            output = [f"Found {len(results)} relevant memories:\n"]

            for i, result in enumerate(results, 1):
                output.append(f"\n{i}. {result.path} (lines {result.start_line}-{result.end_line})")
                output.append(f"   Score: {result.score:.3f}")
                output.append(f"   Snippet: {result.snippet}")

            return "\n".join(output)

        except Exception as e:
            # Tool calls report failures as text instead of raising.
            return f"Error searching memory: {str(e)}"
|
||||
20
agent/protocol/__init__.py
Normal file
20
agent/protocol/__init__.py
Normal file
@@ -0,0 +1,20 @@
|
||||
from .agent import Agent
|
||||
from .agent_stream import AgentStreamExecutor
|
||||
from .task import Task, TaskType, TaskStatus
|
||||
from .result import AgentResult, AgentAction, AgentActionType, ToolResult
|
||||
from .models import LLMModel, LLMRequest, ModelFactory
|
||||
|
||||
__all__ = [
|
||||
'Agent',
|
||||
'AgentStreamExecutor',
|
||||
'Task',
|
||||
'TaskType',
|
||||
'TaskStatus',
|
||||
'AgentResult',
|
||||
'AgentAction',
|
||||
'AgentActionType',
|
||||
'ToolResult',
|
||||
'LLMModel',
|
||||
'LLMRequest',
|
||||
'ModelFactory'
|
||||
]
|
||||
292
agent/protocol/agent.py
Normal file
292
agent/protocol/agent.py
Normal file
@@ -0,0 +1,292 @@
|
||||
import json
|
||||
import time
|
||||
|
||||
from common.log import logger
|
||||
from agent.protocol.models import LLMRequest, LLMModel
|
||||
from agent.protocol.agent_stream import AgentStreamExecutor
|
||||
from agent.protocol.result import AgentAction, AgentActionType, ToolResult, AgentResult
|
||||
from agent.tools.base_tool import BaseTool, ToolStage
|
||||
|
||||
|
||||
class Agent:
    """
    Single LLM-backed agent that owns its tools, conversation history and captured tool actions.

    A conversation turn is executed with ``run_stream``, which delegates the tool-call loop
    to ``AgentStreamExecutor`` and afterwards runs every post-process stage tool.
    """

    def __init__(self, system_prompt: str, description: str = "AI Agent", model: "LLMModel" = None,
                 tools=None, output_mode="print", max_steps=100, max_context_tokens=None,
                 context_reserve_tokens=None, memory_manager=None, name: str = None):
        """
        Initialize the Agent with system prompt, model, description.

        :param system_prompt: The system prompt for the agent.
        :param description: A description of the agent.
        :param model: An instance of LLMModel to be used by the agent.
        :param tools: Optional list of tool instances; each one is registered via ``add_tool``.
        :param output_mode: Control how execution progress is displayed:
                            "print" for console output; any other value routes to the logger.
        :param max_steps: Maximum number of steps the agent can take (default: 100)
        :param max_context_tokens: Maximum tokens to keep in context. Stored on the instance;
                                   the methods of this class do not read it.
        :param context_reserve_tokens: Reserve tokens for new requests
                                       (default: None, derived from the model's context window)
        :param memory_manager: Optional MemoryManager instance for memory operations
        :param name: [Deprecated] The name of the agent (no longer used in single-agent system)
        """
        self.name = name or "Agent"
        self.system_prompt = system_prompt
        self.model: "LLMModel" = model  # Instance of LLMModel
        self.description = description
        self.tools: list = []
        self.max_steps = max_steps  # max tool-call steps, default 100
        self.max_context_tokens = max_context_tokens  # max tokens in context
        self.context_reserve_tokens = context_reserve_tokens  # reserve tokens for new requests
        self.captured_actions = []  # AgentAction records appended by capture_tool_use
        self.output_mode = output_mode
        self.last_usage = None  # Store last API response usage info
        self.messages = []  # Unified message history for stream mode
        self.memory_manager = memory_manager  # Memory manager for auto memory flush
        if tools:
            for tool in tools:
                self.add_tool(tool)

    def add_tool(self, tool: "BaseTool"):
        """
        Register a tool instance on the agent.

        The tool's ``model`` attribute is set to the agent's model before it is stored.

        :param tool: The tool instance to add
        """
        tool.model = self.model
        self.tools.append(tool)

    def _get_model_context_window(self) -> int:
        """
        Get the model's context window size in tokens.
        Auto-detect based on a case-insensitive substring match on the model name.

        Model context windows:
        - Claude 3.x / Claude Sonnet: 200K tokens
        - GPT-4 Turbo/128K: 128K tokens
        - GPT-4: 8K-32K tokens
        - GPT-3.5: 4K-16K tokens
        - DeepSeek: 64K tokens

        :return: Context window size in tokens (10000 when the model or its name is unknown)
        """
        # A model object may exist without a usable name (LLMModel defaults ``model`` to None);
        # in that case fall through to the conservative default instead of crashing.
        raw_name = getattr(self.model, 'model', None) if self.model else None
        if isinstance(raw_name, str):
            model_name = raw_name.lower()

            # Claude models - 200K context
            if 'claude-3' in model_name or 'claude-sonnet' in model_name:
                return 200000

            # GPT-4 models
            if 'gpt-4' in model_name:
                if 'turbo' in model_name or '128k' in model_name:
                    return 128000
                if '32k' in model_name:
                    return 32000
                return 8000

            # GPT-3.5
            if 'gpt-3.5' in model_name:
                return 16000 if '16k' in model_name else 4000

            # DeepSeek
            if 'deepseek' in model_name:
                return 64000

        # Default conservative value
        return 10000

    def _get_context_reserve_tokens(self) -> int:
        """
        Get the number of tokens to reserve for new requests.
        This prevents context overflow by keeping a buffer.

        :return: The explicitly configured reserve if set, otherwise 20% of the
                 context window with a floor of 4000 tokens
        """
        if self.context_reserve_tokens is not None:
            return self.context_reserve_tokens

        # Reserve ~20% of context window for new requests
        context_window = self._get_model_context_window()
        return max(4000, int(context_window * 0.2))

    @staticmethod
    def _payload_length(payload) -> int:
        """Return the character length of a tool payload (string, or JSON-encoded structure)."""
        if payload is None:
            return 0
        if isinstance(payload, str):
            return len(payload)
        try:
            return len(json.dumps(payload, ensure_ascii=False, default=str))
        except (TypeError, ValueError):
            # Unsupported dict keys or circular references: fall back to repr-style length
            return len(str(payload))

    def _estimate_message_tokens(self, message: dict) -> int:
        """
        Estimate token count for a message using a chars/4 heuristic.

        String content is measured directly. List content is summed block by block:
        ``text`` blocks by their text, ``image`` blocks as a flat ~1200 tokens,
        ``tool_use`` blocks by tool name plus serialized input, and ``tool_result``
        blocks by their serialized content. Tool blocks must be counted because the
        stream executor stores every tool call and result in this format.

        :param message: Message dict with 'role' and 'content'
        :return: Estimated token count (always at least 1)
        """
        content = message.get('content', '')
        if isinstance(content, str):
            return max(1, len(content) // 4)

        if isinstance(content, list):
            total_chars = 0
            for part in content:
                if not isinstance(part, dict):
                    continue
                part_type = part.get('type')
                if part_type == 'text':
                    total_chars += len(part.get('text') or '')
                elif part_type == 'image':
                    # Estimate images as ~1200 tokens
                    total_chars += 4800
                elif part_type == 'tool_use':
                    total_chars += len(str(part.get('name') or ''))
                    total_chars += self._payload_length(part.get('input'))
                elif part_type == 'tool_result':
                    total_chars += self._payload_length(part.get('content'))
            return max(1, total_chars // 4)

        return 1

    def _find_tool(self, tool_name: str):
        """
        Find and return a tool with the specified name.

        Only the first tool whose name matches is considered. It is returned (with its
        ``model`` and ``context`` refreshed) when it is a pre-process stage tool;
        otherwise a warning is logged and None is returned.
        """
        for tool in self.tools:
            if tool.name == tool_name:
                # Only pre-process stage tools can be actively called
                if tool.stage == ToolStage.PRE_PROCESS:
                    tool.model = self.model
                    tool.context = self  # Set tool context
                    return tool
                else:
                    # If it's a post-process tool, return None to prevent direct calling
                    logger.warning(f"Tool {tool_name} is a post-process tool and cannot be called directly.")
                    return None
        return None

    # output function based on mode
    def output(self, message="", end="\n"):
        """
        Emit a progress message.

        In "print" mode the message is printed with the given ``end``; in any other
        mode non-empty messages go to ``logger.info`` and ``end`` is ignored.
        """
        if self.output_mode == "print":
            print(message, end=end)
        elif message:
            logger.info(message)

    def _execute_post_process_tools(self):
        """Execute all post-process stage tools and record each run as a captured action."""
        # Get all post-process stage tools
        post_process_tools = [tool for tool in self.tools if tool.stage == ToolStage.POST_PROCESS]

        # Execute each tool
        for tool in post_process_tools:
            # Set tool context
            tool.context = self

            # Record start time for execution timing
            start_time = time.time()

            # Execute tool (with empty parameters, tool will extract needed info from context)
            result = tool.execute({})

            # Calculate execution time
            execution_time = time.time() - start_time

            # Capture tool use for tracking
            self.capture_tool_use(
                tool_name=tool.name,
                input_params={},  # Post-process tools typically don't take parameters
                output=result.result,
                status=result.status,
                error_message=str(result.result) if result.status == "error" else None,
                execution_time=execution_time
            )

            # Log result. ``default=str`` keeps a non-JSON-serializable tool result from
            # raising here, after the agent response has already been produced.
            if result.status == "success":
                self.output(f"\n🛠️ {tool.name}: {json.dumps(result.result, default=str)}")
            else:
                self.output(f"\n🛠️ {tool.name}: {json.dumps({'status': 'error', 'message': str(result.result)})}")

    def capture_tool_use(self, tool_name, input_params, output, status, thought=None, error_message=None,
                         execution_time=0.0):
        """
        Capture a tool use action and append it to ``captured_actions``.

        :param thought: thought content
        :param tool_name: Name of the tool used
        :param input_params: Parameters passed to the tool
        :param output: Output from the tool
        :param status: Status of the tool execution
        :param error_message: Error message if the tool execution failed
        :param execution_time: Time taken to execute the tool
        :return: The AgentAction that was recorded
        """
        tool_result = ToolResult(
            tool_name=tool_name,
            input_params=input_params,
            output=output,
            status=status,
            error_message=error_message,
            execution_time=execution_time
        )

        action = AgentAction(
            # Fall back to the object identity when no explicit ``id`` attribute exists
            agent_id=self.id if hasattr(self, 'id') else str(id(self)),
            agent_name=self.name,
            action_type=AgentActionType.TOOL_USE,
            tool_result=tool_result,
            thought=thought
        )

        self.captured_actions.append(action)

        return action

    def run_stream(self, user_message: str, on_event=None, clear_history: bool = False) -> str:
        """
        Execute single agent task with streaming (based on tool-call)

        This method supports:
        - Streaming output
        - Multi-turn reasoning based on tool-call
        - Event callbacks
        - Persistent conversation history across calls

        Args:
            user_message: User message
            on_event: Event callback function callback(event: dict)
                     event = {"type": str, "timestamp": float, "data": dict}
            clear_history: If True, clear conversation history before this call (default: False)

        Returns:
            Final response text

        Raises:
            ValueError: If the agent has no model.

        Example:
            # Multi-turn conversation with memory
            response1 = agent.run_stream("My name is Alice")
            response2 = agent.run_stream("What's my name?")  # Will remember Alice

            # Single-turn without memory
            response = agent.run_stream("Hello", clear_history=True)
        """
        # Clear history if requested
        if clear_history:
            self.messages = []

        # Get model to use
        if not self.model:
            raise ValueError("No model available for agent")

        # Create stream executor with agent's message history
        executor = AgentStreamExecutor(
            agent=self,
            model=self.model,
            system_prompt=self.system_prompt,
            tools=self.tools,
            max_turns=self.max_steps,
            on_event=on_event,
            messages=self.messages  # Pass agent's message history
        )

        # Execute
        response = executor.run_stream(user_message)

        # Update agent's message history from executor (the executor may have
        # replaced its list while trimming)
        self.messages = executor.messages

        # Execute all post-process tools
        self._execute_post_process_tools()

        return response

    def clear_history(self):
        """Clear conversation history and captured actions"""
        self.messages = []
        self.captured_actions = []
|
||||
461
agent/protocol/agent_stream.py
Normal file
461
agent/protocol/agent_stream.py
Normal file
@@ -0,0 +1,461 @@
|
||||
"""
|
||||
Agent Stream Execution Module - Multi-turn reasoning based on tool-call
|
||||
|
||||
Provides streaming output, event system, and complete tool-call loop
|
||||
"""
|
||||
import json
|
||||
import time
|
||||
from typing import List, Dict, Any, Optional, Callable
|
||||
|
||||
from common.log import logger
|
||||
from agent.protocol.models import LLMRequest, LLMModel
|
||||
from agent.tools.base_tool import BaseTool, ToolResult
|
||||
|
||||
|
||||
class AgentStreamExecutor:
    """
    Agent Stream Executor

    Handles multi-turn reasoning loop based on tool-call:
    1. LLM generates response (may include tool calls)
    2. Execute tools
    3. Return results to LLM
    4. Repeat until no more tool calls
    """

    def __init__(
        self,
        agent,  # Agent instance
        model: "LLMModel",
        system_prompt: str,
        tools: "List[BaseTool]",
        max_turns: int = 50,
        on_event: Optional[Callable] = None,
        messages: Optional[List[Dict]] = None
    ):
        """
        Initialize stream executor

        Args:
            agent: Agent instance (for accessing context)
            model: LLM model
            system_prompt: System prompt
            tools: List of available tools (a ready-made name -> tool mapping is also accepted)
            max_turns: Maximum number of turns
            on_event: Event callback function
            messages: Optional existing message history (for persistent conversations).
                      The list is used as-is and appended to in place.
        """
        self.agent = agent
        self.model = model
        self.system_prompt = system_prompt
        # Convert tools list to dict keyed by tool name; tolerate None
        self.tools = {tool.name: tool for tool in tools} if isinstance(tools, list) else (tools or {})
        self.max_turns = max_turns
        self.on_event = on_event

        # Message history - use provided messages or create new list
        self.messages = messages if messages is not None else []

    def _emit_event(self, event_type: str, data: dict = None):
        """Deliver an event dict to ``on_event``; callback errors are logged, never raised."""
        if self.on_event:
            try:
                self.on_event({
                    "type": event_type,
                    "timestamp": time.time(),
                    "data": data or {}
                })
            except Exception as e:
                logger.error(f"Event callback error: {e}")

    def run_stream(self, user_message: str) -> str:
        """
        Execute streaming reasoning loop

        Appends the user message to the history, then alternates LLM calls and tool
        execution until the LLM answers without tool calls or ``max_turns`` is reached.

        Args:
            user_message: User message

        Returns:
            Final response text (the text of the last assistant message)

        Raises:
            Exception: Any error from the LLM call is logged, reported through an
                "error" event and re-raised. "agent_end" is emitted in every case.
        """
        # Log user message
        logger.info(f"\n{'='*50}")
        logger.info(f"👤 用户: {user_message}")
        logger.info(f"{'='*50}")

        # Add user message (Claude format - use content blocks for consistency)
        self.messages.append({
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": user_message
                }
            ]
        })

        self._emit_event("agent_start")

        final_response = ""
        turn = 0

        try:
            while turn < self.max_turns:
                turn += 1
                logger.info(f"\n{'='*50} 第 {turn} 轮 {'='*50}")
                self._emit_event("turn_start", {"turn": turn})

                # Check if memory flush is needed (before calling LLM)
                self._signal_memory_flush_if_needed()

                # Call LLM
                assistant_msg, tool_calls = self._call_llm_stream()
                final_response = assistant_msg

                # No tool calls, end loop
                if not tool_calls:
                    if assistant_msg:
                        logger.info(f"💭 {assistant_msg[:150]}{'...' if len(assistant_msg) > 150 else ''}")
                    logger.info("✅ 完成 (无工具调用)")
                    self._emit_event("turn_end", {
                        "turn": turn,
                        "has_tool_calls": False
                    })
                    break

                # Log tool calls in compact format
                tool_names = [tc['name'] for tc in tool_calls]
                logger.info(f"🔧 调用工具: {', '.join(tool_names)}")

                # Execute tools and add their results to history as a user message (Claude format)
                self.messages.append({
                    "role": "user",
                    "content": self._run_tool_calls(tool_calls)
                })

                self._emit_event("turn_end", {
                    "turn": turn,
                    "has_tool_calls": True,
                    "tool_count": len(tool_calls)
                })
            else:
                # Reached only when the loop ran out of turns, not when the LLM
                # finished normally on the last allowed turn.
                logger.warning(f"⚠️ 已达到最大轮数限制: {self.max_turns}")

        except Exception as e:
            logger.error(f"❌ Agent执行错误: {e}")
            self._emit_event("error", {"error": str(e)})
            raise

        finally:
            logger.info(f"{'='*50} 完成({turn}轮) {'='*50}\n")
            self._emit_event("agent_end", {"final_response": final_response})

        return final_response

    def _signal_memory_flush_if_needed(self):
        """Emit "memory_flush_start" when the memory manager recommends flushing; no flush is performed here."""
        if not (self.agent.memory_manager and hasattr(self.agent, 'last_usage')):
            return
        usage = self.agent.last_usage
        if not (usage and 'input_tokens' in usage):
            return

        current_tokens = usage.get('input_tokens', 0)
        context_window = self.agent._get_model_context_window()
        reserve_tokens = self.agent.context_reserve_tokens or 20000

        if self.agent.memory_manager.should_flush_memory(
            current_tokens=current_tokens,
            context_window=context_window,
            reserve_tokens=reserve_tokens
        ):
            self._emit_event("memory_flush_start", {
                "current_tokens": current_tokens,
                "threshold": context_window - reserve_tokens - 4000
            })

            # TODO: Execute memory flush in background
            # This would require async support
            logger.info(f"Memory flush recommended at {current_tokens} tokens")

    def _run_tool_calls(self, tool_calls: List[Dict]) -> List[Dict]:
        """Execute each tool call in order and return the matching ``tool_result`` content blocks."""
        tool_result_blocks = []
        for tool_call in tool_calls:
            result = self._execute_tool(tool_call)

            # Log tool result in compact format
            status_emoji = "✅" if result.get("status") == "success" else "❌"
            result_str = str(result.get('result', ''))
            logger.info(f" {status_emoji} {tool_call['name']} ({result.get('execution_time', 0):.2f}s): {result_str[:200]}{'...' if len(result_str) > 200 else ''}")

            # Build tool result block (Claude format); content is a string representation
            # of the result. ``default=str`` prevents a non-JSON-serializable tool result
            # from aborting the loop after the tool_use message was already recorded.
            result_content = json.dumps(result, default=str) if not isinstance(result, str) else result
            tool_result_blocks.append({
                "type": "tool_result",
                "tool_use_id": tool_call["id"],
                "content": result_content
            })
        return tool_result_blocks

    def _build_tools_schema(self) -> Optional[List[Dict]]:
        """Describe the registered tools for the LLM request, or return None when there are none."""
        if not self.tools:
            return None
        return [
            {
                "name": tool.name,
                "description": tool.description,
                "input_schema": tool.params  # Claude uses input_schema
            }
            for tool in self.tools.values()
        ]

    @staticmethod
    def _merge_tool_call_delta(buffer: Dict, tc_delta: Dict) -> None:
        """
        Fold one streamed tool-call fragment into ``buffer`` ({index: {id, name, arguments}}).

        Missing or None fields are ignored so a later fragment cannot erase an id or
        name received earlier; argument fragments are concatenated in arrival order.
        """
        index = tc_delta.get("index") or 0
        entry = buffer.setdefault(index, {"id": "", "name": "", "arguments": ""})

        if tc_delta.get("id"):
            entry["id"] = tc_delta["id"]

        func = tc_delta.get("function") or {}
        if func.get("name"):
            entry["name"] = func["name"]
        if func.get("arguments"):
            entry["arguments"] += func["arguments"]

    def _call_llm_stream(self) -> tuple[str, List[Dict]]:
        """
        Call LLM with streaming

        Trims the history, streams one response, appends the assistant message
        (text and tool_use blocks) to the history, and emits message events.

        Returns:
            (response_text, tool_calls) where each tool call is
            {"id": str, "name": str, "arguments": dict}

        Raises:
            Exception: On an error chunk from the model or any failure while streaming.
        """
        # Trim messages if needed (using agent's context management)
        self._trim_messages()

        # Prepare messages
        messages = self._prepare_messages()

        # Debug: log message structure
        logger.debug(f"Sending {len(messages)} messages to LLM")
        for i, msg in enumerate(messages):
            role = msg.get("role", "unknown")
            content = msg.get("content", "")
            if isinstance(content, list):
                content_types = [c.get("type") for c in content if isinstance(c, dict)]
                logger.debug(f" Message {i}: role={role}, content_blocks={content_types}")
            else:
                logger.debug(f" Message {i}: role={role}, content_length={len(str(content))}")

        # Create request
        request = LLMRequest(
            messages=messages,
            temperature=0,
            stream=True,
            tools=self._build_tools_schema(),
            system=self.system_prompt  # Pass system prompt separately for Claude API
        )

        self._emit_event("message_start", {"role": "assistant"})

        # Streaming response
        full_content = ""
        tool_calls_buffer = {}  # {index: {id, name, arguments}}

        try:
            for chunk in self.model.call_stream(request):
                if not isinstance(chunk, dict):
                    continue

                # Check for errors
                if chunk.get("error"):
                    error_msg = chunk.get("message", "Unknown error")
                    status_code = chunk.get("status_code", "N/A")
                    logger.error(f"API Error: {error_msg} (Status: {status_code})")
                    logger.error(f"Full error chunk: {chunk}")
                    raise Exception(f"{error_msg} (Status: {status_code})")

                # Some providers send chunks whose "choices" list is empty
                # (e.g. usage-only chunks); there is nothing to parse in those.
                choices = chunk.get("choices")
                if not choices:
                    continue
                delta = choices[0].get("delta") or {}

                # Handle text content
                content_delta = delta.get("content")
                if content_delta:
                    full_content += content_delta
                    self._emit_event("message_update", {"delta": content_delta})

                # Handle tool calls (the key may be present with a None value)
                for tc_delta in delta.get("tool_calls") or []:
                    self._merge_tool_call_delta(tool_calls_buffer, tc_delta)

        except Exception as e:
            logger.error(f"LLM call error: {e}")
            raise

        # Parse tool calls in stream index order
        tool_calls = []
        for idx in sorted(tool_calls_buffer.keys()):
            tc = tool_calls_buffer[idx]
            try:
                arguments = json.loads(tc["arguments"]) if tc["arguments"] else {}
            except json.JSONDecodeError:
                # Malformed arguments: still issue the call, with empty arguments
                logger.error(f"Failed to parse tool arguments: {tc['arguments']}")
                arguments = {}

            tool_calls.append({
                "id": tc["id"],
                "name": tc["name"],
                "arguments": arguments
            })

        # Add assistant message to history (Claude format uses content blocks)
        assistant_msg = {"role": "assistant", "content": []}

        # Add text content block if present
        if full_content:
            assistant_msg["content"].append({
                "type": "text",
                "text": full_content
            })

        # Add tool_use blocks if present
        for tc in tool_calls:
            assistant_msg["content"].append({
                "type": "tool_use",
                "id": tc["id"],
                "name": tc["name"],
                "input": tc["arguments"]
            })

        # Only append if content is not empty
        if assistant_msg["content"]:
            self.messages.append(assistant_msg)

        self._emit_event("message_end", {
            "content": full_content,
            "tool_calls": tool_calls
        })

        return full_content, tool_calls

    def _execute_tool(self, tool_call: Dict) -> Dict[str, Any]:
        """
        Execute tool

        Failures (unknown tool name or an exception raised by the tool) are converted
        into an error result instead of propagating.

        Args:
            tool_call: {"id": str, "name": str, "arguments": dict}

        Returns:
            {"status": str, "result": Any, "execution_time": float}
        """
        tool_name = tool_call["name"]
        tool_id = tool_call["id"]
        arguments = tool_call["arguments"]

        self._emit_event("tool_execution_start", {
            "tool_call_id": tool_id,
            "tool_name": tool_name,
            "arguments": arguments
        })

        try:
            tool = self.tools.get(tool_name)
            if not tool:
                raise ValueError(f"Tool '{tool_name}' not found")

            # Set tool context
            tool.model = self.model
            tool.context = self.agent

            # Execute tool
            start_time = time.time()
            result = tool.execute_tool(arguments)
            execution_time = time.time() - start_time

            result_dict = {
                "status": result.status,
                "result": result.result,
                "execution_time": execution_time
            }

            self._emit_event("tool_execution_end", {
                "tool_call_id": tool_id,
                "tool_name": tool_name,
                **result_dict
            })

            return result_dict

        except Exception as e:
            logger.error(f"Tool execution error: {e}")
            error_result = {
                "status": "error",
                "result": str(e),
                "execution_time": 0
            }
            self._emit_event("tool_execution_end", {
                "tool_call_id": tool_id,
                "tool_name": tool_name,
                **error_result
            })
            return error_result

    @staticmethod
    def _can_start_history(message: Dict) -> bool:
        """Return True if the message is a user message that carries no ``tool_result`` block."""
        if message.get("role") != "user":
            return False
        content = message.get("content")
        if isinstance(content, list):
            return not any(
                isinstance(block, dict) and block.get("type") == "tool_result"
                for block in content
            )
        return True

    def _trim_messages(self):
        """
        Trim message history to stay within context limits.
        Uses agent's context management configuration.

        Messages are kept from the newest backwards while they fit the budget
        (context window - reserve - system prompt). The newest message is always
        kept. The cut point is then moved so the history starts with a plain user
        message: first by dropping further leading messages, and, if no such
        message exists in the kept window, by extending the window back to the
        nearest earlier one (which may exceed the budget). This avoids sending a
        history that begins with an assistant message or with a ``tool_result``
        whose ``tool_use`` was cut off.
        """
        if not self.messages or not self.agent:
            return

        estimate = self.agent._estimate_message_tokens

        # Get context window and reserve tokens from agent
        context_window = self.agent._get_model_context_window()
        reserve_tokens = self.agent._get_context_reserve_tokens()
        max_tokens = context_window - reserve_tokens

        # Estimate current tokens, including the system prompt
        system_tokens = estimate({"role": "system", "content": self.system_prompt})
        current_tokens = sum(estimate(msg) for msg in self.messages) + system_tokens

        # If under limit, no need to trim
        if current_tokens <= max_tokens:
            return

        # Walk from newest to oldest to find the first message that still fits
        available_tokens = max_tokens - system_tokens
        total = len(self.messages)
        cut = total
        budget_used = 0
        for index in range(total - 1, -1, -1):
            msg_tokens = estimate(self.messages[index])
            if cut < total and budget_used + msg_tokens > available_tokens:
                break
            cut = index
            budget_used += msg_tokens

        # Snap the cut to a message that can legally open the conversation
        start = next(
            (i for i in range(cut, total) if self._can_start_history(self.messages[i])),
            None
        )
        if start is None:
            start = next(
                (i for i in range(cut - 1, -1, -1) if self._can_start_history(self.messages[i])),
                cut
            )

        old_count = total
        self.messages = self.messages[start:]
        new_count = len(self.messages)
        accumulated_tokens = sum(estimate(msg) for msg in self.messages)

        if old_count > new_count:
            logger.info(
                f"Context trimmed: {old_count} -> {new_count} messages "
                f"(~{current_tokens} -> ~{system_tokens + accumulated_tokens} tokens, "
                f"limit: {max_tokens})"
            )

    def _prepare_messages(self) -> List[Dict[str, Any]]:
        """
        Prepare messages to send to LLM

        Note: For Claude API, system prompt should be passed separately via system parameter,
        not as a message, so the history is returned unchanged.
        """
        # Don't add system message here - it will be handled separately by the LLM adapter
        return self.messages
|
||||
27
agent/protocol/context.py
Normal file
27
agent/protocol/context.py
Normal file
@@ -0,0 +1,27 @@
|
||||
class TeamContext:
    """Mutable state shared by a group of agents: configuration, current task and step counters."""

    def __init__(self, name: str, description: str, rule: str, agents: list, max_steps: int = 100):
        """
        Store the team configuration and reset all task/execution state.

        :param name: The name of the group context.
        :param description: A description of the group context.
        :param rule: The rules governing the group context.
        :param agents: A list of agents in the context.
        :param max_steps: Step limit stored on the context (default: 100).
        """
        # Caller-supplied configuration
        self.name, self.description, self.rule, self.agents = name, description, rule, agents

        # Task state: nothing is assigned until a task is attached after construction
        self.user_task = ""  # For backward compatibility
        self.task = self.model = self.task_short_name = None  # Task instance / LLMModel / task directory name

        # Execution bookkeeping
        self.agent_outputs: list = []  # Outputs of agents that have been executed
        self.current_steps = 0
        self.max_steps = max_steps
|
||||
|
||||
|
||||
class AgentOutput:
    """Pairs an agent's name with the output it produced."""

    def __init__(self, agent_name: str, output: str):
        """Store ``agent_name`` and ``output`` unchanged on the instance."""
        self.agent_name, self.output = agent_name, output
|
||||
57
agent/protocol/models.py
Normal file
57
agent/protocol/models.py
Normal file
@@ -0,0 +1,57 @@
|
||||
"""
|
||||
Models module for agent system.
|
||||
Provides basic model classes needed by tools and bridge integration.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
|
||||
class LLMRequest:
    """
    Parameter bundle for a single LLM call.

    Besides the named fields, any extra keyword argument becomes an attribute of the
    request (for example ``system=...``). Extras are applied last, so an extra whose
    name matches a core field replaces that field's value.
    """

    def __init__(self, messages: List[Dict[str, str]] = None, model: Optional[str] = None,
                 temperature: float = 0.7, max_tokens: Optional[int] = None,
                 stream: bool = False, tools: Optional[List] = None, **kwargs):
        # A missing or empty message list is normalised to a fresh empty list
        self.messages = messages if messages else []
        self.model = model
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.stream = stream
        self.tools = tools
        # Allow extra attributes
        for extra_name in kwargs:
            setattr(self, extra_name, kwargs[extra_name])
|
||||
|
||||
|
||||
class LLMModel:
    """
    Base class for LLM models.

    Holds the model name and arbitrary configuration; both call methods are
    placeholders that always raise NotImplementedError.
    """

    def __init__(self, model: str = None, **kwargs):
        """Store the model name and keep every extra keyword argument in ``config``."""
        self.config = kwargs
        self.model = model

    def call(self, request: LLMRequest):
        """
        Call the model with a request.

        Placeholder: always raises NotImplementedError.
        """
        message = "LLMModel.call not implemented in this context"
        raise NotImplementedError(message)

    def call_stream(self, request: LLMRequest):
        """
        Call the model with streaming.

        Placeholder: always raises NotImplementedError.
        """
        message = "LLMModel.call_stream not implemented in this context"
        raise NotImplementedError(message)
|
||||
|
||||
|
||||
class ModelFactory:
    """Factory for creating model instances (placeholder only)."""

    @staticmethod
    def create_model(model_type: str, **kwargs):
        """
        Create a model instance based on type.

        Placeholder: always raises NotImplementedError, whatever the arguments.
        """
        message = "ModelFactory.create_model not implemented in this context"
        raise NotImplementedError(message)
|
||||
96
agent/protocol/result.py
Normal file
96
agent/protocol/result.py
Normal file
@@ -0,0 +1,96 @@
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import List, Dict, Any, Optional
|
||||
|
||||
from agent.protocol.task import Task, TaskStatus
|
||||
|
||||
|
||||
class AgentActionType(Enum):
    """Kinds of action an agent can record; each member's value is its string form."""

    # The agent invoked a tool
    TOOL_USE = "tool_use"
    # The agent produced intermediate reasoning
    THINKING = "thinking"
    # The agent produced its final answer
    FINAL_ANSWER = "final_answer"
|
||||
|
||||
|
||||
@dataclass
class ToolResult:
    """Record of a single tool invocation: what was called, with what, and how it ended."""

    # Name of the tool used
    tool_name: str
    # Parameters passed to the tool
    input_params: Dict[str, Any]
    # Output from the tool
    output: Any
    # Status of the tool execution (success/error)
    status: str
    # Error message if the tool execution failed
    error_message: Optional[str] = None
    # Time taken to execute the tool
    execution_time: float = 0.0
|
||||
|
||||
|
||||
@dataclass
class AgentAction:
    """One action taken by an agent, stamped with a generated id and creation time."""

    # ID of the agent that performed the action
    agent_id: str
    # Name of the agent that performed the action
    agent_name: str
    # Type of action (tool use, thinking, final answer)
    action_type: AgentActionType
    # Unique identifier for the action (random UUID4 string by default)
    id: str = field(default_factory=lambda: str(uuid.uuid4()))
    # Content of the action (thought content, final answer content)
    content: str = ""
    # Tool use details if action_type is TOOL_USE
    tool_result: Optional[ToolResult] = None
    # Optional thought text attached to the action
    thought: Optional[str] = None
    # When the action was performed (time.time() at creation by default)
    timestamp: float = field(default_factory=time.time)
|
||||
|
||||
|
||||
@dataclass
class AgentResult:
    """Final outcome of an agent run, built through ``success`` or ``error``."""

    # The final answer provided by the agent
    final_answer: str
    # Number of steps taken by the agent
    step_count: int
    # Status of the execution (success/error)
    status: str = "success"
    # Error message if execution failed
    error_message: Optional[str] = None

    @classmethod
    def success(cls, final_answer: str, step_count: int) -> "AgentResult":
        """Build a result with the default "success" status and no error message."""
        return cls(final_answer, step_count)

    @classmethod
    def error(cls, error_message: str, step_count: int = 0) -> "AgentResult":
        """Build an error result whose final answer is the message prefixed with "Error: "."""
        failure_text = f"Error: {error_message}"
        return cls(failure_text, step_count, "error", error_message)

    @property
    def is_error(self) -> bool:
        """True exactly when ``status`` equals "error"."""
        return not self.status != "error"
|
||||
95
agent/protocol/task.py
Normal file
95
agent/protocol/task.py
Normal file
@@ -0,0 +1,95 @@
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Dict, Any, List
|
||||
|
||||
|
||||
class TaskType(Enum):
    """Kinds of content a task can carry."""
    TEXT = "text"
    IMAGE = "image"
    VIDEO = "video"
    AUDIO = "audio"
    FILE = "file"
    MIXED = "mixed"
|
||||
|
||||
|
||||
class TaskStatus(Enum):
    """Lifecycle states of a task."""
    INIT = "init"              # created, not started yet
    PROCESSING = "processing"  # currently being worked on
    COMPLETED = "completed"    # finished
    FAILED = "failed"          # ended with a failure
|
||||
|
||||
|
||||
@dataclass
class Task:
    """
    A unit of work handed to an agent.

    Attributes:
        id: Unique identifier for the task
        content: Primary text content
        type: Kind of task (see TaskType)
        status: Current lifecycle state (see TaskStatus)
        created_at: Creation timestamp (time.time() value)
        updated_at: Timestamp of the last status change
        metadata: Free-form extra data
        images: Image URLs or base64 encoded images
        videos: Video URLs
        audios: Audio URLs or base64 encoded audios
        files: File URLs or paths
    """
    id: str = field(default_factory=lambda: str(uuid.uuid4()))
    content: str = ""
    type: TaskType = TaskType.TEXT
    status: TaskStatus = TaskStatus.INIT
    created_at: float = field(default_factory=time.time)
    updated_at: float = field(default_factory=time.time)
    metadata: Dict[str, Any] = field(default_factory=dict)

    # Media content
    images: List[str] = field(default_factory=list)
    videos: List[str] = field(default_factory=list)
    audios: List[str] = field(default_factory=list)
    files: List[str] = field(default_factory=list)

    def __init__(self, content: str = "", **kwargs):
        """
        Create a task from its text plus optional keyword overrides.

        Because this __init__ is defined explicitly, it is the one used
        instead of a dataclass-generated initializer.

        Args:
            content: The text content of the task
            **kwargs: Values for any of the other attributes listed on the
                class; keys that are not read below are ignored
        """
        option = kwargs.get
        self.content = content
        self.id = option('id', str(uuid.uuid4()))
        self.type = option('type', TaskType.TEXT)
        self.status = option('status', TaskStatus.INIT)
        self.created_at = option('created_at', time.time())
        self.updated_at = option('updated_at', time.time())
        self.metadata = option('metadata', {})
        # A fresh list is created per attribute so instances never share one.
        for media_kind in ('images', 'videos', 'audios', 'files'):
            setattr(self, media_kind, option(media_kind, []))

    def get_text(self) -> str:
        """
        Return the text content of the task.

        Returns:
            The text content
        """
        return self.content

    def update_status(self, status: TaskStatus) -> None:
        """
        Move the task to a new status and refresh ``updated_at``.

        Args:
            status: The new status
        """
        self.status = status
        self.updated_at = time.time()
|
||||
101
agent/tools/__init__.py
Normal file
101
agent/tools/__init__.py
Normal file
@@ -0,0 +1,101 @@
|
||||
# Import base tool
|
||||
from agent.tools.base_tool import BaseTool
|
||||
from agent.tools.tool_manager import ToolManager
|
||||
|
||||
# Import basic tools (no external dependencies)
|
||||
from agent.tools.calculator.calculator import Calculator
|
||||
from agent.tools.current_time.current_time import CurrentTime
|
||||
|
||||
# Import file operation tools
|
||||
from agent.tools.read.read import Read
|
||||
from agent.tools.write.write import Write
|
||||
from agent.tools.edit.edit import Edit
|
||||
from agent.tools.bash.bash import Bash
|
||||
from agent.tools.grep.grep import Grep
|
||||
from agent.tools.find.find import Find
|
||||
from agent.tools.ls.ls import Ls
|
||||
|
||||
# Import memory tools
|
||||
from agent.tools.memory.memory_search import MemorySearchTool
|
||||
from agent.tools.memory.memory_get import MemoryGetTool
|
||||
|
||||
# Import tools with optional dependencies
|
||||
def _import_optional_tools():
    """
    Collect the optional tool classes that can be imported right now.

    Each import is attempted independently; an ImportError simply leaves
    that tool out of the returned mapping.

    :return: dict mapping tool class name to the class itself
    """
    available = {}

    # Google Search (requires requests)
    try:
        from agent.tools.google_search.google_search import GoogleSearch
    except ImportError:
        pass
    else:
        available['GoogleSearch'] = GoogleSearch

    # File Save (may have dependencies)
    try:
        from agent.tools.file_save.file_save import FileSave
    except ImportError:
        pass
    else:
        available['FileSave'] = FileSave

    # Terminal (basic, should work)
    try:
        from agent.tools.terminal.terminal import Terminal
    except ImportError:
        pass
    else:
        available['Terminal'] = Terminal

    return available
|
||||
|
||||
# Load optional tools
|
||||
_optional_tools = _import_optional_tools()
# Each name is bound to the tool class, or to None when its import failed.
GoogleSearch, FileSave, Terminal = (
    _optional_tools.get(tool_name)
    for tool_name in ('GoogleSearch', 'FileSave', 'Terminal')
)
|
||||
|
||||
|
||||
# Delayed import for BrowserTool
|
||||
def _import_browser_tool():
    """
    Return the real BrowserTool class, or a stand-in when it cannot be imported.

    The stand-in raises ImportError with installation instructions as soon
    as someone tries to instantiate it.
    """
    try:
        from agent.tools.browser.browser_tool import BrowserTool
    except ImportError:
        class BrowserToolPlaceholder:
            def __init__(self, *args, **kwargs):
                raise ImportError(
                    "The 'browser-use' package is required to use BrowserTool. "
                    "Please install it with 'pip install browser-use>=0.1.40'."
                )

        return BrowserToolPlaceholder
    else:
        return BrowserTool
|
||||
|
||||
|
||||
# Dynamically set BrowserTool
|
||||
BrowserTool = _import_browser_tool()
|
||||
|
||||
# Export all tools (including optional ones that might be None)
|
||||
__all__ = [
|
||||
'BaseTool',
|
||||
'ToolManager',
|
||||
'Calculator',
|
||||
'CurrentTime',
|
||||
'Read',
|
||||
'Write',
|
||||
'Edit',
|
||||
'Bash',
|
||||
'Grep',
|
||||
'Find',
|
||||
'Ls',
|
||||
'MemorySearchTool',
|
||||
'MemoryGetTool',
|
||||
# Optional tools (may be None if dependencies not available)
|
||||
'GoogleSearch',
|
||||
'FileSave',
|
||||
'Terminal',
|
||||
'BrowserTool'
|
||||
]
|
||||
|
||||
"""
|
||||
Tools module for Agent.
|
||||
"""
|
||||
99
agent/tools/base_tool.py
Normal file
99
agent/tools/base_tool.py
Normal file
@@ -0,0 +1,99 @@
|
||||
from enum import Enum
|
||||
from typing import Any, Optional
|
||||
from common.log import logger
|
||||
import copy
|
||||
|
||||
|
||||
class ToolStage(Enum):
    """Stage of the agent loop in which a tool is run."""
    # Tools the agent has to pick explicitly
    PRE_PROCESS = "pre_process"
    # Tools that run automatically once the final answer exists
    POST_PROCESS = "post_process"
|
||||
|
||||
|
||||
class ToolResult:
    """Outcome of a single tool invocation."""

    def __init__(self, status: str = None, result: Any = None, ext_data: Any = None):
        self.status, self.result, self.ext_data = status, result, ext_data

    @staticmethod
    def success(result, ext_data: Any = None):
        """Wrap *result* in a ToolResult whose status is "success"."""
        return ToolResult("success", result, ext_data)

    @staticmethod
    def fail(result, ext_data: Any = None):
        """Wrap *result* in a ToolResult whose status is "error"."""
        return ToolResult("error", result, ext_data)
|
||||
|
||||
|
||||
class BaseTool:
    """
    Base class for all tools.

    Subclasses override the class attributes below and implement
    :meth:`execute`. Callers should go through :meth:`execute_tool`, which
    converts any exception into a failed ToolResult.
    """

    # Default decision stage is pre-process
    stage = ToolStage.PRE_PROCESS

    # Class attributes must be inherited
    name: str = "base_tool"
    description: str = "Base tool"
    params: dict = {}  # Store JSON Schema
    model: Optional[Any] = None  # LLM model instance, type depends on bot implementation

    # JSON Schema primitive type name -> Python type, used by _parse_schema
    _JSON_TYPE_MAP = {
        "string": str,
        "number": float,
        "integer": int,
        "boolean": bool,
        "array": list,
        "object": dict
    }

    @classmethod
    def get_json_schema(cls) -> dict:
        """Return the tool's name, description and parameter schema as a dict."""
        return {
            "name": cls.name,
            "description": cls.description,
            "parameters": cls.params
        }

    def execute_tool(self, params: dict) -> ToolResult:
        """
        Run :meth:`execute` and never let an exception escape.

        :param params: Arguments for the tool
        :return: Whatever execute() returned, or a failed ToolResult that
                 carries the exception text if execute() raised
        """
        try:
            return self.execute(params)
        except Exception as e:
            logger.error(e)
            # Return an explicit failure instead of falling through to None,
            # so the result always matches the declared return type.
            return ToolResult.fail(f"Error executing tool '{self.name}': {e}")

    def execute(self, params: dict) -> ToolResult:
        """Specific logic to be implemented by subclasses"""
        raise NotImplementedError

    @classmethod
    def _parse_schema(cls) -> dict:
        """
        Convert the JSON Schema in ``params`` to ``{name: (type, default)}``.

        A property without a "default" gets ``...`` (Ellipsis) as its default.
        A schema without "properties" (such as the empty base-class schema)
        yields an empty dict. An unknown "type" value raises KeyError.
        """
        properties = (cls.params or {}).get("properties", {})
        return {
            name: (cls._JSON_TYPE_MAP[prop["type"]], prop.get("default", ...))
            for name, prop in properties.items()
        }

    def should_auto_execute(self, context) -> bool:
        """
        Determine if this tool should be automatically executed based on context.

        :param context: The agent context (not inspected by this base implementation)
        :return: True if the tool should be executed, False otherwise
        """
        # Only tools in post-process stage will be automatically executed
        return self.stage == ToolStage.POST_PROCESS

    def close(self):
        """
        Close any resources used by the tool.

        This method should be overridden by tools that need to clean up resources
        such as browser connections, file handles, etc.

        By default, this method does nothing.
        """
        pass
|
||||
3
agent/tools/bash/__init__.py
Normal file
3
agent/tools/bash/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .bash import Bash
|
||||
|
||||
__all__ = ['Bash']
|
||||
187
agent/tools/bash/bash.py
Normal file
187
agent/tools/bash/bash.py
Normal file
@@ -0,0 +1,187 @@
|
||||
"""
|
||||
Bash tool - Execute bash commands
|
||||
"""
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
import tempfile
|
||||
from typing import Dict, Any
|
||||
|
||||
from agent.tools.base_tool import BaseTool, ToolResult
|
||||
from agent.tools.utils.truncate import truncate_tail, format_size, DEFAULT_MAX_LINES, DEFAULT_MAX_BYTES
|
||||
|
||||
|
||||
class Bash(BaseTool):
    """Tool for executing bash commands"""

    name: str = "bash"
    description: str = f"""Execute a bash command in the current working directory. Returns stdout and stderr. Output is truncated to last {DEFAULT_MAX_LINES} lines or {DEFAULT_MAX_BYTES // 1024}KB (whichever is hit first). If truncated, full output is saved to a temp file.

IMPORTANT SAFETY GUIDELINES:
- You can freely create, modify, and delete files within the current workspace
- For operations outside the workspace or potentially destructive commands (rm -rf, system commands, etc.), always explain what you're about to do and ask for user confirmation first
- Be especially careful with: file deletions, system modifications, network operations, or commands that might affect system stability
- When in doubt, describe the command's purpose and ask for permission before executing"""

    params: dict = {
        "type": "object",
        "properties": {
            "command": {
                "type": "string",
                "description": "Bash command to execute"
            },
            "timeout": {
                "type": "integer",
                "description": "Timeout in seconds (optional, default: 30)"
            }
        },
        "required": ["command"]
    }

    # (substring, warning) pairs; a command containing the substring is refused.
    # NOTE(review): matching is plain substring search on the lower-cased
    # command, so e.g. "rm -rf /" also matches any "rm -rf /abs/path" and
    # "halt" matches inside longer words. Kept as-is to avoid weakening the guard.
    _DANGEROUS_PATTERNS = (
        # System shutdown/reboot
        ("shutdown", "This command will shut down the system"),
        ("reboot", "This command will reboot the system"),
        ("halt", "This command will halt the system"),
        ("poweroff", "This command will power off the system"),

        # Critical system modifications
        ("rm -rf /", "This command will delete the entire filesystem"),
        ("rm -rf /*", "This command will delete the entire filesystem"),
        ("dd if=/dev/zero", "This command can destroy disk data"),
        ("mkfs", "This command will format a filesystem, destroying all data"),
        ("fdisk", "This command modifies disk partitions"),

        # User/system management (only if targeting system users)
        ("userdel root", "This command will delete the root user"),
        ("passwd root", "This command will change the root password"),
    )

    _SYSTEM_DIRS = ("/bin", "/usr", "/etc", "/var", "/home", "/root", "/sys", "/proc")

    def __init__(self, config: dict = None):
        """
        :param config: Optional settings: "cwd" (working directory, defaults to
                       the process cwd), "timeout" (default seconds, 30) and
                       "safety_mode" (default True)
        """
        self.config = config or {}
        self.cwd = self.config.get("cwd", os.getcwd())
        # Ensure working directory exists
        if not os.path.exists(self.cwd):
            os.makedirs(self.cwd, exist_ok=True)
        self.default_timeout = self.config.get("timeout", 30)
        # Enable safety mode by default (can be disabled in config)
        self.safety_mode = self.config.get("safety_mode", True)

    def execute(self, args: Dict[str, Any]) -> ToolResult:
        """
        Execute a bash command

        :param args: Dictionary containing the command and optional timeout
        :return: On completion, a ToolResult whose result is a dict with
                 "output", "exit_code" and "details"; it is a failure when the
                 exit code is non-zero. Validation, safety, timeout and
                 unexpected errors produce a failure with a message string.
        """
        # `or` guards against an explicit null for either argument; a null
        # timeout would otherwise disable the timeout entirely.
        command = (args.get("command") or "").strip()
        timeout = args.get("timeout") or self.default_timeout

        if not command:
            return ToolResult.fail("Error: command parameter is required")

        # Optional safety check - only warn about extremely dangerous commands
        if self.safety_mode:
            warning = self._get_safety_warning(command)
            if warning:
                return ToolResult.fail(
                    f"Safety Warning: {warning}\n\nIf you believe this command is safe and necessary, please ask the user for confirmation first, explaining what the command does and why it's needed.")

        try:
            result = subprocess.run(
                command,
                shell=True,
                cwd=self.cwd,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                text=True,
                # Undecodable bytes become replacement characters instead of
                # aborting the whole call with UnicodeDecodeError.
                errors="replace",
                timeout=timeout
            )

            # Combine stdout and stderr
            output = result.stdout
            if result.stderr:
                output += "\n" + result.stderr

            # Apply tail truncation
            truncation = truncate_tail(output)
            output_text = truncation.content or "(no output)"

            # Save the full output whenever something was cut (by lines OR by
            # bytes), so the notice below never points at a missing file.
            temp_file_path = None
            if truncation.truncated or len(output.encode('utf-8')) > DEFAULT_MAX_BYTES:
                temp_file_path = self._save_full_output(output)

            details = {}
            if truncation.truncated:
                details["truncation"] = truncation.to_dict()
                if temp_file_path:
                    details["full_output_path"] = temp_file_path
                output_text += self._truncation_notice(truncation, output, temp_file_path)

            payload = {
                "exit_code": result.returncode,
                "details": details if details else None
            }

            # Check exit code
            if result.returncode != 0:
                output_text += f"\n\nCommand exited with code {result.returncode}"
                payload["output"] = output_text
                return ToolResult.fail(payload)

            payload["output"] = output_text
            return ToolResult.success(payload)

        except subprocess.TimeoutExpired:
            return ToolResult.fail(f"Error: Command timed out after {timeout} seconds")
        except Exception as e:
            return ToolResult.fail(f"Error executing command: {str(e)}")

    @staticmethod
    def _save_full_output(output: str):
        """Write *output* to a persistent temp file; return its path, or None if writing failed."""
        try:
            with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.log', prefix='bash-',
                                             encoding='utf-8') as f:
                f.write(output)
                return f.name
        except OSError:
            # Losing the side file must not lose the command result itself.
            return None

    @staticmethod
    def _truncation_notice(truncation, output: str, temp_file_path) -> str:
        """Build the bracketed note appended to truncated output."""
        full_ref = f" Full output: {temp_file_path}" if temp_file_path else ""
        start_line = truncation.total_lines - truncation.output_lines + 1
        end_line = truncation.total_lines

        if truncation.last_line_partial:
            # Edge case: the last line alone exceeds the byte limit
            last_line = output.split('\n')[-1] if output else ""
            last_line_size = format_size(len(last_line.encode('utf-8')))
            return f"\n\n[Showing last {format_size(truncation.output_bytes)} of line {end_line} (line is {last_line_size}).{full_ref}]"
        if truncation.truncated_by == "lines":
            return f"\n\n[Showing lines {start_line}-{end_line} of {truncation.total_lines}.{full_ref}]"
        return f"\n\n[Showing lines {start_line}-{end_line} of {truncation.total_lines} ({format_size(DEFAULT_MAX_BYTES)} limit).{full_ref}]"

    def _get_safety_warning(self, command: str) -> str:
        """
        Get safety warning for potentially dangerous commands
        Only warns about extremely dangerous system-level operations

        :param command: Command to check
        :return: Warning message if dangerous, empty string if safe
        """
        cmd_lower = command.lower().strip()

        # Only block extremely dangerous system operations
        for pattern, warning in self._DANGEROUS_PATTERNS:
            if pattern in cmd_lower:
                return warning

        # Check for recursive deletion outside workspace
        if "rm" in cmd_lower and "-rf" in cmd_lower:
            # Allow deletion within current workspace
            inside_workspace = "./" in cmd_lower or self.cwd.lower() in cmd_lower
            if not inside_workspace:
                # Check if targeting system directories
                if any(sysdir in cmd_lower for sysdir in self._SYSTEM_DIRS):
                    return "This command will recursively delete system directories"

        return ""  # No warning needed
|
||||
59
agent/tools/browser/browser_action.py
Normal file
59
agent/tools/browser/browser_action.py
Normal file
@@ -0,0 +1,59 @@
|
||||
class BrowserAction:
    """Base class for browser actions.

    Each subclass carries a ``code`` (the operation name) and a
    human-readable ``description`` of that operation.
    """
    code = ""
    description = ""


class Navigate(BrowserAction):
    """Navigate to a URL in the current tab"""
    code = "navigate"
    description = "Navigate to URL in the current tab"


class ClickElement(BrowserAction):
    """Click an element on the page"""
    code = "click_element"
    description = "Click element"


class ExtractContent(BrowserAction):
    """Extract content from the page"""
    code = "extract_content"
    description = "Extract the page content to retrieve specific information for a goal"


class InputText(BrowserAction):
    """Input text into an element"""
    code = "input_text"
    description = "Input text into a input interactive element"


class ScrollDown(BrowserAction):
    """Scroll down the page"""
    code = "scroll_down"
    description = "Scroll down the page by pixel amount"


class ScrollUp(BrowserAction):
    """Scroll up the page"""
    code = "scroll_up"
    description = "Scroll up the page by pixel amount - if no amount is specified, scroll up one page"


class OpenTab(BrowserAction):
    """Open a URL in a new tab"""
    code = "open_tab"
    description = "Open url in new tab"


class SwitchTab(BrowserAction):
    """Switch to another open page"""
    code = "switch_tab"
    description = "Switched to tab"


class SendKeys(BrowserAction):
    """Send special keyboard keys to the page"""
    code = "send_keys"
    description = (
        "Send strings of special keyboard keys like Escape, Backspace, Insert, PageDown, Delete, Enter, "
        "ArrowRight, ArrowUp, etc"
    )
|
||||
317
agent/tools/browser/browser_tool.py
Normal file
317
agent/tools/browser/browser_tool.py
Normal file
@@ -0,0 +1,317 @@
|
||||
import asyncio
|
||||
from typing import Any, Dict
|
||||
import json
|
||||
import re
|
||||
import os
|
||||
import platform
|
||||
from browser_use import Browser
|
||||
from browser_use import BrowserConfig
|
||||
from browser_use.browser.context import BrowserContext, BrowserContextConfig
|
||||
from agent.tools.base_tool import BaseTool, ToolResult
|
||||
from agent.tools.browser.browser_action import *
|
||||
from agent.models import LLMRequest
|
||||
from agent.models.model_factory import ModelFactory
|
||||
from browser_use.dom.service import DomService
|
||||
from common.log import logger
|
||||
|
||||
|
||||
# Use lazy import, only import when actually used
|
||||
def _import_browser_use():
    """
    Import and return the ``browser_use`` module.

    :raises ImportError: with installation instructions when the package
                         cannot be imported
    """
    try:
        import browser_use
    except ImportError:
        raise ImportError(
            "The 'browser-use' package is required to use BrowserTool. "
            "Please install it with 'pip install browser-use>=0.1.40' or "
            "'pip install agentmesh-sdk[full]'."
        )
    else:
        return browser_use
|
||||
|
||||
|
||||
def _get_action_prompt():
    """Return one "code: description" line per supported browser action."""
    supported = (
        Navigate, ClickElement, ExtractContent, InputText, OpenTab,
        SwitchTab, ScrollDown, ScrollUp, SendKeys,
    )
    listing = "\n".join(f"{entry.code}: {entry.description}" for entry in supported)
    return listing.strip()
|
||||
|
||||
|
||||
def _header_less() -> bool:
    """Return True on Linux when neither DISPLAY nor WAYLAND_DISPLAY is set."""
    on_linux = platform.system() == "Linux"
    has_display = bool(os.environ.get("DISPLAY") or os.environ.get("WAYLAND_DISPLAY"))
    return on_linux and not has_display
|
||||
|
||||
|
||||
class BrowserTool(BaseTool):
    """Tool that drives a single shared browser (navigation, clicks, typing, extraction).

    The browser, its context and the event loop are stored on the class, so
    every instance operates on the same browser session.
    """
    name: str = "browser"
    description: str = "A tool to perform browser operations like navigating to URLs, element interaction, " \
                       "and extracting content."
    params: dict = {
        "type": "object",
        "properties": {
            "operation": {
                "type": "string",
                "description": f"The browser operation to perform: \n{_get_action_prompt()}"
            },
            "url": {
                "type": "string",
                "description": f"The URL to navigate to (required for '{Navigate.code}', '{OpenTab.code}' actions). "
            },
            "goal": {
                "type": "string",
                "description": f"The goal of extracting page content (required for '{ExtractContent.code}' action)."
            },
            "text": {
                "type": "string",
                "description": f"Text to type (required for '{InputText.code}' action)."
            },
            "index": {
                "type": "integer",
                "description": f"Element index (required for '{ClickElement.code}', '{InputText.code}' actions)",
            },
            "tab_id": {
                "type": "integer",
                "description": f"Page tab ID (required for '{SwitchTab.code}' action)",
            },
            "scroll_amount": {
                "type": "integer",
                "description": f"The number of pixels to scroll (required for '{ScrollDown.code}', '{ScrollUp.code}' action)."
            },
            "keys": {
                "type": "string",
                "description": f"Keys to send (required for '{SendKeys.code}' action)"
            }
        },
        "required": ["operation"]
    }

    # Class variable to ensure only one browser instance is created
    browser = None
    browser_context: BrowserContext = None
    dom_service: DomService = None
    _initialized = False

    # Event loop shared by all calls so the browser objects stay bound to one loop
    _event_loop = None

    def __init__(self):
        # Fails early with installation instructions when browser-use is missing.
        # The browser itself is not started here; it is created on first execution.
        self.browser_use = _import_browser_use()

    @staticmethod
    def _as_url(url: str) -> str:
        """Turn an absolute filesystem path ("/...") into a file:// URL; leave other values alone."""
        return f"file://{url}" if url.startswith("/") else url

    async def _init_browser(self) -> BrowserContext:
        """Create the shared browser and context on first use and return the context."""
        if not BrowserTool._initialized:
            os.environ['BROWSER_USE_LOGGING_LEVEL'] = 'error'
            print("Initializing browser...")
            BrowserTool.browser = Browser(BrowserConfig(headless=_header_less(),
                                                        disable_security=True))
            context_config = BrowserContextConfig()
            context_config.highlight_elements = True
            BrowserTool.browser_context = await BrowserTool.browser.new_context(context_config)
            BrowserTool._initialized = True
            print("Browser initialized successfully")
            BrowserTool.dom_service = DomService(await BrowserTool.browser_context.get_current_page())
        return BrowserTool.browser_context

    def execute(self, params: Dict[str, Any]) -> ToolResult:
        """
        Execute browser operations based on the provided arguments.

        :param params: Dictionary containing the action and related parameters
        :return: Result of the browser operation; any exception is turned
                 into a failed ToolResult
        """
        # Ensure browser_use is imported
        if not hasattr(self, 'browser_use'):
            self.browser_use = _import_browser_use()
        action = params.get("operation", "").lower()

        try:
            # Use a single event loop
            if BrowserTool._event_loop is None:
                BrowserTool._event_loop = asyncio.new_event_loop()
                asyncio.set_event_loop(BrowserTool._event_loop)
            # Run tasks in the existing event loop
            return BrowserTool._event_loop.run_until_complete(self._execute_async(action, params))
        except Exception as e:
            print(f"Error executing browser action: {e}")
            return ToolResult.fail(result=f"Error executing browser action: {str(e)}")

    async def _get_page_state(self, context: BrowserContext):
        """Return a dict with url, title, scroll offsets, tabs and interactive elements of the current page."""
        state = await self._get_state(context)
        include_attributes = ["img", "div", "button", "input"]
        elements = state.element_tree.clickable_elements_to_string(include_attributes)
        # Matches entries of the form "[<index>]<... />"
        pattern = r'\[\d+\]<[^>]+\/>'
        interactive_elements = re.findall(pattern, elements)
        page_state = {
            "url": state.url,
            "title": state.title,
            "pixels_above": getattr(state, "pixels_above", 0),
            "pixels_below": getattr(state, "pixels_below", 0),
            "tabs": [tab.model_dump() for tab in state.tabs],
            "interactive_elements": interactive_elements,
        }
        return page_state

    async def _get_state(self, context: BrowserContext, cache_clickable_elements_hashes=True):
        """Call context.get_state(), retrying with the keyword argument if the bare call raises TypeError."""
        try:
            return await context.get_state()
        except TypeError:
            return await context.get_state(cache_clickable_elements_hashes=cache_clickable_elements_hashes)

    async def _get_page_info(self, context: BrowserContext):
        """Render the current page state as a text section for the agent."""
        page_state = await self._get_page_state(context)
        state_str = f"""## Current browser state
The following is the information of the current browser page. Each serial number in interactive_elements represents the element index:
{json.dumps(page_state, indent=4, ensure_ascii=False)}
"""
        return state_str

    async def _execute_async(self, action: str, params: Dict[str, Any]) -> ToolResult:
        """
        Asynchronously execute one browser operation.

        :param action: Lower-cased operation code (see the action classes)
        :param params: Operation parameters as described in ``params``
        :return: ToolResult for the operation; unknown operations fail
        """
        # Use the browser context from the class variable
        context = await self._init_browser()

        if action == Navigate.code:
            url = params.get("url")
            if not url:
                return ToolResult.fail(result="URL is required for navigate action")
            url = self._as_url(url)
            print(f"Navigating to {url}...")
            page = await context.get_current_page()
            await page.goto(url)
            await page.wait_for_load_state()
            state = await self._get_page_info(context)
            print(f"Navigation complete")
            return ToolResult.success(result=f"Navigated to {url}", ext_data=state)

        elif action == OpenTab.code:
            url = params.get("url")
            # Without this guard a missing url crashed on None.startswith().
            if not url:
                return ToolResult.fail(result="URL is required for open_tab action")
            url = self._as_url(url)
            await context.create_new_tab(url)
            msg = f"Opened new tab with {url}"
            return ToolResult.success(result=msg)

        elif action == ExtractContent.code:
            try:
                goal = params.get("goal")
                page = await context.get_current_page()
                if params.get("url"):
                    await page.goto(params.get("url"))
                    await page.wait_for_load_state()
                import markdownify
                content = markdownify.markdownify(await page.content())
                elements = await self._get_page_state(context)
                prompt = f"Your task is to extract the content of the page. You will be given a page and a goal and you should extract all relevant information around this goal from the page. If the goal is vague, " \
                         f"summarize the page. Respond in json format. elements: {elements.get('interactive_elements')}, extraction goal: {goal}, Page: {content},"
                request = LLMRequest(
                    messages=[{"role": "user", "content": prompt}],
                    temperature=0,
                    json_format=True
                )
                model = self.model or ModelFactory().get_model(model_name="gpt-4o")
                response = model.call(request)
                if response.success:
                    extract_content = response.data["choices"][0]["message"]["content"]
                    print(f"Extract from page: {extract_content}")
                    return ToolResult.success(result=f"Extract from page: {extract_content}",
                                              ext_data=await self._get_page_info(context))
                else:
                    return ToolResult.fail(result=f"Extract from page failed: {response.get_error_msg()}")
            except Exception as e:
                logger.error(e)
                # Report the failure instead of falling through and returning None.
                return ToolResult.fail(result=f"Extract from page failed: {e}")

        elif action == ClickElement.code:
            index = params.get("index")
            element = await context.get_dom_element_by_index(index)
            await context._click_element_node(element)
            msg = f"Clicked element at index {index}"
            print(msg)
            return ToolResult.success(result=msg, ext_data=await self._get_page_info(context))

        elif action == InputText.code:
            index = params.get("index")
            text = params.get("text")
            element = await context.get_dom_element_by_index(index)
            await context._input_text_element_node(element, text)
            await asyncio.sleep(1)
            msg = f"Input text into element successfully, index: {index}, text: {text}"
            return ToolResult.success(result=msg, ext_data=await self._get_page_info(context))

        elif action == SwitchTab.code:
            tab_id = params.get("tab_id")
            print(f"Switch tab, tab_id={tab_id}")
            await context.switch_to_tab(tab_id)
            page = await context.get_current_page()
            await page.wait_for_load_state()
            msg = f"Switched to tab {tab_id}"
            return ToolResult.success(result=msg, ext_data=await self._get_page_info(context))

        elif action in [ScrollDown.code, ScrollUp.code]:
            scroll_amount = params.get("scroll_amount")
            if not scroll_amount:
                # No amount given: scroll by one window height
                scroll_amount = context.config.browser_window_size["height"]
            print(f"Scrolling by {scroll_amount} pixels")
            scroll_amount = scroll_amount if action == ScrollDown.code else (scroll_amount * -1)
            await context.execute_javascript(f"window.scrollBy(0, {scroll_amount});")
            msg = f"{action} by {scroll_amount} pixels"
            return ToolResult.success(result=msg, ext_data=await self._get_page_info(context))

        elif action == SendKeys.code:
            keys = params.get("keys")
            if not keys:
                return ToolResult.fail(result="Keys are required for send_keys action")
            page = await context.get_current_page()
            await page.keyboard.press(keys)
            msg = f"Sent keys: {keys}"
            print(msg)
            # ToolResult has no `output` argument; build the result the same
            # way the other branches do.
            return ToolResult.success(result=msg, ext_data=await self._get_page_info(context))

        else:
            msg = "Failed to operate the browser"
            return ToolResult.fail(result=msg)

    def close(self):
        """
        Close browser resources.
        This method handles the asynchronous closing of browser and browser context,
        resets the shared class state and closes the shared event loop.
        It does nothing when the browser was never initialized.
        """
        if not BrowserTool._initialized:
            return

        try:
            # Use the existing event loop to close browser resources
            if BrowserTool._event_loop is not None:
                async def close_browser_async():
                    if BrowserTool.browser_context is not None:
                        try:
                            await BrowserTool.browser_context.close()
                        except Exception as e:
                            logger.error(f"Error closing browser context: {e}")

                    if BrowserTool.browser is not None:
                        try:
                            await BrowserTool.browser.close()
                        except Exception as e:
                            logger.error(f"Error closing browser: {e}")

                    # Reset the shared state so a later call re-initializes
                    BrowserTool._initialized = False
                    BrowserTool.browser = None
                    BrowserTool.browser_context = None
                    BrowserTool.dom_service = None

                # Run the async close function in the existing event loop
                BrowserTool._event_loop.run_until_complete(close_browser_async())

                # Close the event loop
                BrowserTool._event_loop.close()
                BrowserTool._event_loop = None
        except Exception as e:
            print(f"Error during browser cleanup: {e}")
|
||||
18
agent/tools/browser_tool.py
Normal file
18
agent/tools/browser_tool.py
Normal file
@@ -0,0 +1,18 @@
|
||||
def copy(self):
    """
    Build a sibling tool object that reuses this tool's browser.

    A fresh instance of the same class is created, then ``model`` is copied,
    ``context`` and ``config`` are copied (``None`` when absent), and the
    ``browser`` reference is shared when this object has one.

    :return: The new instance
    """
    clone = self.__class__()

    # model is mandatory; context/config fall back to None when missing
    clone.model = self.model
    for optional_attr in ('context', 'config'):
        setattr(clone, optional_attr, getattr(self, optional_attr, None))

    # Hand over the existing browser rather than starting another one
    if hasattr(self, 'browser'):
        clone.browser = self.browser

    return clone
|
||||
58
agent/tools/calculator/calculator.py
Normal file
58
agent/tools/calculator/calculator.py
Normal file
@@ -0,0 +1,58 @@
|
||||
import math
|
||||
|
||||
from agent.tools.base_tool import BaseTool, ToolResult
|
||||
|
||||
|
||||
class Calculator(BaseTool):
    """Tool that evaluates a single mathematical expression given as a string."""

    name: str = "calculator"
    description: str = "A tool to perform basic mathematical calculations."
    params: dict = {
        "type": "object",
        "properties": {
            "expression": {
                "type": "string",
                "description": "The mathematical expression to evaluate (e.g., '2 + 2', '5 * 3', 'sqrt(16)'). "
                               "Ensure your input is a valid Python expression, it will be evaluated directly."
            }
        },
        "required": ["expression"]
    }
    config: dict = {}

    def execute(self, args: dict) -> ToolResult:
        """
        Evaluate ``args["expression"]`` and return the value.

        :param args: Tool arguments; must contain the key ``expression``
        :return: A success result holding ``result`` and ``expression``, or a
                 failure result holding ``error`` and ``expression`` when the
                 expression is missing or cannot be evaluated
        """
        try:
            # Get the expression
            expression = args["expression"]

            # Names the expression is allowed to reference
            safe_locals = {
                "abs": abs,
                "round": round,
                "max": max,
                "min": min,
                "pow": pow,
                "sqrt": math.sqrt,
                "sin": math.sin,
                "cos": math.cos,
                "tan": math.tan,
                "pi": math.pi,
                "e": math.e,
                "log": math.log,
                "log10": math.log10,
                "exp": math.exp,
                "floor": math.floor,
                "ceil": math.ceil
            }

            # NOTE(review): eval() with emptied __builtins__ is not a real sandbox;
            # attribute access on literals can still reach arbitrary objects, and
            # expressions such as 9**9**9 can exhaust CPU/memory. The expression
            # comes from model output, so consider an ast-based evaluator.
            result = eval(expression, {"__builtins__": {}}, safe_locals)

            return ToolResult.success({
                "result": result,
                "expression": expression
            })
        except Exception as e:
            # Report the failure with a failure status (same payload shape as
            # before) so callers can tell it apart from a computed value
            return ToolResult.fail({
                "error": str(e),
                "expression": args.get("expression", "")
            })
|
||||
75
agent/tools/current_time/current_time.py
Normal file
75
agent/tools/current_time/current_time.py
Normal file
@@ -0,0 +1,75 @@
|
||||
import datetime
|
||||
import time
|
||||
|
||||
from agent.tools.base_tool import BaseTool, ToolResult
|
||||
|
||||
|
||||
class CurrentTime(BaseTool):
    """Tool that reports the current date and time in a chosen format."""

    name: str = "time"
    description: str = "A tool to get current date and time information."
    params: dict = {
        "type": "object",
        "properties": {
            "format": {
                "type": "string",
                "description": "Optional format for the time (e.g., 'iso', 'unix', 'human'). Default is 'human'."
            },
            "timezone": {
                "type": "string",
                "description": "Optional timezone specification (e.g., 'UTC', 'local'). Default is 'local'."
            }
        },
        "required": []
    }
    config: dict = {}

    def execute(self, args: dict) -> ToolResult:
        """
        Return the current time plus its individual components.

        :param args: Optional keys ``format`` ('iso', 'unix', anything else means
                     human-readable) and ``timezone`` ('utc', anything else means
                     local time); both are case-insensitive
        :return: A success result with ``current_time``, ``components``, ``format``
                 and ``timezone``; a failure result with the error text otherwise
        """
        try:
            # `or` also covers an explicit None passed for an optional parameter
            time_format = (args.get("format") or "human").lower()
            tz_name = (args.get("timezone") or "local").lower()

            # Read the clock once. The UTC value is made naive again so the
            # formatted output keeps the shape utcnow() used to produce.
            if tz_name == "utc":
                current_time = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
            else:
                current_time = datetime.datetime.now()

            # Format the time according to the specified format
            if time_format == "iso":
                # ISO 8601 format
                formatted_time = current_time.isoformat()
            elif time_format == "unix":
                # Unix timestamp (seconds since epoch)
                formatted_time = time.time()
            else:
                # Human-readable format
                formatted_time = current_time.strftime("%Y-%m-%d %H:%M:%S")

            result = {
                "current_time": formatted_time,
                "components": {
                    "year": current_time.year,
                    "month": current_time.month,
                    "day": current_time.day,
                    "hour": current_time.hour,
                    "minute": current_time.minute,
                    "second": current_time.second,
                    "weekday": current_time.strftime("%A")  # Full weekday name
                },
                "format": time_format,
                "timezone": tz_name
            }
            return ToolResult.success(result=result)
        except Exception as e:
            return ToolResult.fail(result=str(e))
|
||||
3
agent/tools/edit/__init__.py
Normal file
3
agent/tools/edit/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .edit import Edit
|
||||
|
||||
__all__ = ['Edit']
|
||||
164
agent/tools/edit/edit.py
Normal file
164
agent/tools/edit/edit.py
Normal file
@@ -0,0 +1,164 @@
|
||||
"""
|
||||
Edit tool - Precise file editing
|
||||
Edit files through exact text replacement
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Dict, Any
|
||||
|
||||
from agent.tools.base_tool import BaseTool, ToolResult
|
||||
from agent.tools.utils.diff import (
|
||||
strip_bom,
|
||||
detect_line_ending,
|
||||
normalize_to_lf,
|
||||
restore_line_endings,
|
||||
normalize_for_fuzzy_match,
|
||||
fuzzy_find_text,
|
||||
generate_diff_string
|
||||
)
|
||||
|
||||
|
||||
class Edit(BaseTool):
    """Tool for precise file editing"""

    name: str = "edit"
    description: str = "Edit a file by replacing exact text. The oldText must match exactly (including whitespace). Use this for precise, surgical edits."

    params: dict = {
        "type": "object",
        "properties": {
            "path": {
                "type": "string",
                "description": "Path to the file to edit (relative or absolute)"
            },
            "oldText": {
                "type": "string",
                "description": "Exact text to find and replace (must match exactly)"
            },
            "newText": {
                "type": "string",
                "description": "New text to replace the old text with"
            }
        },
        "required": ["path", "oldText", "newText"]
    }

    def __init__(self, config: dict = None):
        """
        :param config: Optional settings; ``cwd`` is the base directory for
                       relative paths (defaults to the process working directory)
        """
        self.config = config or {}
        self.cwd = self.config.get("cwd", os.getcwd())

    def execute(self, args: Dict[str, Any]) -> ToolResult:
        """
        Execute file edit operation

        Replaces the single occurrence of ``oldText`` in the file with ``newText``
        and writes the file back, keeping a leading BOM and the file's line endings.

        :param args: Contains file path, old text and new text
        :return: Operation result; on success it carries ``message``, ``path``,
                 ``diff`` and ``first_changed_line``
        """
        # `or ""` also covers an explicit None for path
        path = (args.get("path") or "").strip()
        old_text = args.get("oldText", "")
        new_text = args.get("newText", "")

        if not path:
            return ToolResult.fail("Error: path parameter is required")

        # An empty search text has no position in the file to replace
        if not old_text:
            return ToolResult.fail("Error: oldText parameter must not be empty")

        # Resolve path
        absolute_path = self._resolve_path(path)

        # Check if file exists
        if not os.path.exists(absolute_path):
            return ToolResult.fail(f"Error: File not found: {path}")

        # Check if readable/writable
        if not os.access(absolute_path, os.R_OK | os.W_OK):
            return ToolResult.fail(f"Error: File is not readable/writable: {path}")

        try:
            # newline='' turns off newline translation, so the content keeps the
            # file's real line endings for the detection step below
            with open(absolute_path, 'r', encoding='utf-8', newline='') as f:
                raw_content = f.read()

            # Remove BOM (LLM won't include invisible BOM in oldText)
            bom, content = strip_bom(raw_content)

            # Detect original line ending
            original_ending = detect_line_ending(content)

            # Normalize to LF
            normalized_content = normalize_to_lf(content)
            normalized_old_text = normalize_to_lf(old_text)
            normalized_new_text = normalize_to_lf(new_text)

            # Use fuzzy matching to find old text (try exact match first, then fuzzy match)
            match_result = fuzzy_find_text(normalized_content, normalized_old_text)

            if not match_result.found:
                return ToolResult.fail(
                    f"Error: Could not find the exact text in {path}. "
                    "The old text must match exactly including all whitespace and newlines."
                )

            # Calculate occurrence count (use fuzzy normalized content for consistency)
            fuzzy_content = normalize_for_fuzzy_match(normalized_content)
            fuzzy_old_text = normalize_for_fuzzy_match(normalized_old_text)
            occurrences = fuzzy_content.count(fuzzy_old_text)

            if occurrences > 1:
                return ToolResult.fail(
                    f"Error: Found {occurrences} occurrences of the text in {path}. "
                    "The text must be unique. Please provide more context to make it unique."
                )

            # Execute replacement (use matched text position)
            base_content = match_result.content_for_replacement
            match_end = match_result.index + match_result.match_length
            new_content = (
                base_content[:match_result.index] +
                normalized_new_text +
                base_content[match_end:]
            )

            # Verify replacement actually changed content
            if base_content == new_content:
                return ToolResult.fail(
                    f"Error: No changes made to {path}. "
                    "The replacement produced identical content. "
                    "This might indicate an issue with special characters or the text not existing as expected."
                )

            # Restore original line endings
            final_content = bom + restore_line_endings(new_content, original_ending)

            # newline='' writes the restored endings verbatim instead of letting
            # the platform translate them a second time
            with open(absolute_path, 'w', encoding='utf-8', newline='') as f:
                f.write(final_content)

            # Generate diff
            diff_result = generate_diff_string(base_content, new_content)

            result = {
                "message": f"Successfully replaced text in {path}",
                "path": path,
                "diff": diff_result['diff'],
                "first_changed_line": diff_result['first_changed_line']
            }

            return ToolResult.success(result)

        except UnicodeDecodeError:
            return ToolResult.fail(f"Error: File is not a valid text file (encoding error): {path}")
        except PermissionError:
            return ToolResult.fail(f"Error: Permission denied accessing {path}")
        except Exception as e:
            return ToolResult.fail(f"Error editing file: {str(e)}")

    def _resolve_path(self, path: str) -> str:
        """
        Resolve path to absolute path

        :param path: Relative or absolute path; a leading ``~`` is expanded
        :return: Absolute path; relative paths are resolved against ``self.cwd``
        """
        # Expand ~ to user home directory
        path = os.path.expanduser(path)
        if os.path.isabs(path):
            return path
        return os.path.abspath(os.path.join(self.cwd, path))
|
||||
3
agent/tools/file_save/__init__.py
Normal file
3
agent/tools/file_save/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .file_save import FileSave
|
||||
|
||||
__all__ = ['FileSave']
|
||||
770
agent/tools/file_save/file_save.py
Normal file
770
agent/tools/file_save/file_save.py
Normal file
@@ -0,0 +1,770 @@
|
||||
import os
|
||||
import time
|
||||
import re
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, Optional, Tuple
|
||||
|
||||
from agent.tools.base_tool import BaseTool, ToolResult, ToolStage
|
||||
from agent.models import LLMRequest
|
||||
from common.log import logger
|
||||
|
||||
|
||||
class FileSave(BaseTool):
|
||||
"""Tool for saving content to files in the workspace directory."""
|
||||
|
||||
name = "file_save"
|
||||
description = "Save the agent's output to a file in the workspace directory. Content is automatically extracted from the agent's previous outputs."
|
||||
|
||||
# Set as post-process stage tool
|
||||
stage = ToolStage.POST_PROCESS
|
||||
|
||||
params = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"file_name": {
|
||||
"type": "string",
|
||||
"description": "Optional. The name of the file to save. If not provided, a name will be generated based on the content."
|
||||
},
|
||||
"file_type": {
|
||||
"type": "string",
|
||||
"description": "Optional. The type/extension of the file (e.g., 'txt', 'md', 'py', 'java'). If not provided, it will be inferred from the content."
|
||||
},
|
||||
"extract_code": {
|
||||
"type": "boolean",
|
||||
"description": "Optional. If true, will attempt to extract code blocks from the content. Default is false."
|
||||
}
|
||||
},
|
||||
"required": [] # No required fields, as everything can be extracted from context
|
||||
}
|
||||
|
||||
def __init__(self):
|
||||
self.context = None
|
||||
self.config = {}
|
||||
self.workspace_dir = Path("workspace")
|
||||
|
||||
def execute(self, params: Dict[str, Any]) -> ToolResult:
    """
    Save content to a file in the workspace directory.

    The content is taken from the agent context. File name, file type and the
    code-extraction flag are first requested from the model; if that raises, they
    fall back to ``params`` and the inference helpers. The file is written to
    ``<workspace_dir>/<team_name>/<task_dir>/``.

    :param params: The parameters for the file output operation; ``task_dir``,
                   ``file_name``, ``file_type`` and ``extract_code`` are only read
                   on the fallback path
    :return: Result of the operation.
    """
    # Extract content from context
    if not hasattr(self, 'context') or not self.context:
        return ToolResult.fail("Error: No context available to extract content from.")

    content = self._extract_content_from_context()

    # If no content could be extracted, return error
    if not content:
        return ToolResult.fail("Error: Couldn't extract content from context.")

    # Use model to determine file parameters
    try:
        task_dir = self._get_task_dir_from_context()
        file_name, file_type, extract_code = self._get_file_params_from_model(content)
    except Exception as e:
        logger.error(f"Error determining file parameters: {str(e)}")
        # Fall back to manual parameter extraction
        # NOTE(review): _get_task_id_from_context is not visible in this part of
        # the file; confirm it exists alongside _get_task_dir_from_context.
        task_dir = params.get("task_dir") or self._get_task_id_from_context() or f"task_{int(time.time())}"
        file_name = params.get("file_name") or self._infer_file_name(content)
        file_type = params.get("file_type") or self._infer_file_type(content)
        extract_code = params.get("extract_code", False)

    # Same fallback _handle_multiple_code_blocks applies; an empty task_dir would
    # otherwise break the path join below
    if not task_dir:
        task_dir = f"task_{int(time.time())}"

    # Get team_name from context
    team_name = self._get_team_name_from_context() or "default_team"

    task_dir_path = self.workspace_dir / team_name / task_dir

    try:
        # Create directory structure
        task_dir_path.mkdir(parents=True, exist_ok=True)

        if extract_code:
            # Save the complete content as markdown
            md_file_path = task_dir_path / f"{file_name}.md"
            with open(md_file_path, 'w', encoding='utf-8') as f:
                f.write(content)
    except Exception as e:
        return ToolResult.fail(f"Error saving file: {str(e)}")

    if extract_code:
        return self._handle_multiple_code_blocks(content)

    # Ensure file_name has the correct extension
    if file_type and not file_name.endswith(f".{file_type}"):
        file_name = f"{file_name}.{file_type}"

    # Create the full file path
    file_path = task_dir_path / file_name

    # Get absolute path for storage in team_context
    abs_file_path = file_path.absolute()

    try:
        # Write content to file
        with open(file_path, 'w', encoding='utf-8') as f:
            f.write(content)
    except Exception as e:
        return ToolResult.fail(f"Error saving file: {str(e)}")

    # Append the saved location to the latest agent output. A problem here must
    # not turn an already successful save into a reported failure.
    try:
        team_context = getattr(self.context, 'team_context', None)
        agent_outputs = getattr(team_context, 'agent_outputs', None)
        if agent_outputs:
            # Store with absolute path in team_context
            agent_outputs[-1].output += f"\n\nSaved file: {abs_file_path}"
    except Exception as e:
        logger.warning(f"Could not record saved file in team context: {str(e)}")

    return ToolResult.success({
        "status": "success",
        "file_path": str(file_path)  # Return relative path in result
    })
|
||||
|
||||
def _handle_multiple_code_blocks(self, content: str) -> ToolResult:
    """
    Handle content with multiple code blocks, extracting and saving each as a separate file.

    Every block is written to ``<workspace_dir>/<team_name>/<task_dir>/``. A block
    that fails to save is logged and skipped; the call fails only when no block
    could be extracted or none could be saved.

    :param content: The content containing multiple code blocks
    :return: Result of the operation; on success ``files`` lists the saved paths
    """
    # Extract code blocks with context (including potential file name information)
    code_blocks_with_context = self._extract_code_blocks_with_context(content)

    if not code_blocks_with_context:
        return ToolResult.fail("No code blocks found in the content.")

    # Get task directory and team name
    task_dir = self._get_task_dir_from_context() or f"task_{int(time.time())}"
    team_name = self._get_team_name_from_context() or "default_team"

    # Create directory structure
    task_dir_path = self.workspace_dir / team_name / task_dir
    task_dir_path.mkdir(parents=True, exist_ok=True)

    saved_files = []
    used_names = set()

    for index, block_with_context in enumerate(code_blocks_with_context, start=1):
        try:
            # Use model to determine file name for this code block
            block_file_name, block_file_type = self._get_filename_for_code_block(block_with_context)

            # Clean the code block (remove md code markers)
            clean_code = self._clean_code_block(block_with_context)

            # An empty name would make the target path the directory itself
            if not block_file_name:
                block_file_name = f"code_block_{index}"

            # Ensure file_name has the correct extension
            if block_file_type and not block_file_name.endswith(f".{block_file_type}"):
                block_file_name = f"{block_file_name}.{block_file_type}"

            # Two blocks of one response resolving to the same name (for example
            # two timestamp-based fallback names) must not overwrite each other
            if block_file_name in used_names:
                stem, ext = os.path.splitext(block_file_name)
                suffix = 2
                while f"{stem}_{suffix}{ext}" in used_names:
                    suffix += 1
                block_file_name = f"{stem}_{suffix}{ext}"
            used_names.add(block_file_name)

            # Create the full file path (no subdirectories)
            file_path = task_dir_path / block_file_name

            # Get absolute path for storage in team_context
            abs_file_path = file_path.absolute()

            # Write content to file
            with open(file_path, 'w', encoding='utf-8') as f:
                f.write(clean_code)

            saved_files.append({
                "file_path": str(file_path),
                "abs_file_path": str(abs_file_path),  # Store absolute path for internal use
                "file_name": block_file_name,
                "size": len(clean_code),
                "status": "success",
                "type": "code"
            })

        except Exception as e:
            logger.error(f"Error saving code block: {str(e)}")
            # Continue with the next block even if this one fails

    if not saved_files:
        return ToolResult.fail("Failed to save any code blocks.")

    # Append the list of saved files to the latest agent output
    if hasattr(self, 'context') and self.context:
        team_context = getattr(self.context, 'team_context', None)
        agent_outputs = getattr(team_context, 'agent_outputs', None)
        if agent_outputs:
            # Store with absolute paths in team_context
            abs_info = f"\n\nSaved files to {task_dir_path.absolute()}:\n" + "\n".join(
                [f"- {f['abs_file_path']}" for f in saved_files])

            # Compare against the exact text that gets appended, so a repeated
            # call does not add the same listing twice
            latest_output = agent_outputs[-1]
            if not latest_output.output.endswith(abs_info):
                latest_output.output += abs_info

    result = {
        "status": "success",
        "files": [{"file_path": f["file_path"]} for f in saved_files]
    }

    return ToolResult.success(result)
|
||||
|
||||
def _extract_code_blocks_with_context(self, content: str) -> list:
|
||||
"""
|
||||
Extract code blocks from content, including context lines before the block.
|
||||
|
||||
:param content: The content to extract code blocks from
|
||||
:return: List of code blocks with context
|
||||
"""
|
||||
# Check if content starts with <!DOCTYPE or <html - likely a full HTML file
|
||||
if content.strip().startswith(("<!DOCTYPE", "<html", "<?xml")):
|
||||
return [content] # Return the entire content as a single block
|
||||
|
||||
# Split content into lines
|
||||
lines = content.split('\n')
|
||||
|
||||
blocks = []
|
||||
in_code_block = False
|
||||
current_block = []
|
||||
context_lines = []
|
||||
|
||||
# Check if there are any code block markers in the content
|
||||
if not re.search(r'```\w+', content):
|
||||
# If no code block markers and content looks like code, return the entire content
|
||||
if self._is_likely_code(content):
|
||||
return [content]
|
||||
|
||||
for line in lines:
|
||||
if line.strip().startswith('```'):
|
||||
if in_code_block:
|
||||
# End of code block
|
||||
current_block.append(line)
|
||||
# Only add blocks that have a language specified
|
||||
block_content = '\n'.join(current_block)
|
||||
if re.search(r'```\w+', current_block[0]):
|
||||
# Combine context with code block
|
||||
blocks.append('\n'.join(context_lines + current_block))
|
||||
current_block = []
|
||||
context_lines = []
|
||||
in_code_block = False
|
||||
else:
|
||||
# Start of code block - check if it has a language specified
|
||||
if re.search(r'```\w+', line) and not re.search(r'```language=\s*$', line):
|
||||
# Start of code block with language
|
||||
in_code_block = True
|
||||
current_block = [line]
|
||||
# Keep only the last few context lines
|
||||
context_lines = context_lines[-5:] if context_lines else []
|
||||
|
||||
elif in_code_block:
|
||||
current_block.append(line)
|
||||
else:
|
||||
# Store context lines when not in a code block
|
||||
context_lines.append(line)
|
||||
|
||||
return blocks
|
||||
|
||||
def _get_filename_for_code_block(self, block_with_context: str) -> Tuple[str, str]:
|
||||
"""
|
||||
Determine the file name for a code block.
|
||||
|
||||
:param block_with_context: The code block with context lines
|
||||
:return: Tuple of (file_name, file_type)
|
||||
"""
|
||||
# Define common code file extensions
|
||||
COMMON_CODE_EXTENSIONS = {
|
||||
'py', 'js', 'java', 'c', 'cpp', 'h', 'hpp', 'cs', 'go', 'rb', 'php',
|
||||
'html', 'css', 'ts', 'jsx', 'tsx', 'vue', 'sh', 'sql', 'json', 'xml',
|
||||
'yaml', 'yml', 'md', 'rs', 'swift', 'kt', 'scala', 'pl', 'r', 'lua'
|
||||
}
|
||||
|
||||
# Split the block into lines to examine only the context around code block markers
|
||||
lines = block_with_context.split('\n')
|
||||
|
||||
# Find the code block start marker line index
|
||||
start_marker_idx = -1
|
||||
for i, line in enumerate(lines):
|
||||
if line.strip().startswith('```') and not line.strip() == '```':
|
||||
start_marker_idx = i
|
||||
break
|
||||
|
||||
if start_marker_idx == -1:
|
||||
# No code block marker found
|
||||
return "", ""
|
||||
|
||||
# Extract the language from the code block marker
|
||||
code_marker = lines[start_marker_idx].strip()
|
||||
language = ""
|
||||
if len(code_marker) > 3:
|
||||
language = code_marker[3:].strip().split('=')[0].strip()
|
||||
|
||||
# Define the context range (5 lines before and 2 after the marker)
|
||||
context_start = max(0, start_marker_idx - 5)
|
||||
context_end = min(len(lines), start_marker_idx + 3)
|
||||
|
||||
# Extract only the relevant context lines
|
||||
context_lines = lines[context_start:context_end]
|
||||
|
||||
# First, check for explicit file headers like "## filename.ext"
|
||||
for line in context_lines:
|
||||
# Match patterns like "## filename.ext" or "# filename.ext"
|
||||
header_match = re.search(r'^\s*#{1,6}\s+([a-zA-Z0-9_-]+\.[a-zA-Z0-9]+)\s*$', line)
|
||||
if header_match:
|
||||
file_name = header_match.group(1)
|
||||
file_type = os.path.splitext(file_name)[1].lstrip('.')
|
||||
if file_type in COMMON_CODE_EXTENSIONS:
|
||||
return os.path.splitext(file_name)[0], file_type
|
||||
|
||||
# Simple patterns to match explicit file names in the context
|
||||
file_patterns = [
|
||||
# Match explicit file names in headers or text
|
||||
r'(?:file|filename)[:=\s]+[\'"]?([a-zA-Z0-9_-]+\.[a-zA-Z0-9]+)[\'"]?',
|
||||
# Match language=filename.ext in code markers
|
||||
r'language=([a-zA-Z0-9_-]+\.[a-zA-Z0-9]+)',
|
||||
# Match standalone filenames with extensions
|
||||
r'\b([a-zA-Z0-9_-]+\.(py|js|java|c|cpp|h|hpp|cs|go|rb|php|html|css|ts|jsx|tsx|vue|sh|sql|json|xml|yaml|yml|md|rs|swift|kt|scala|pl|r|lua))\b',
|
||||
# Match file paths in comments
|
||||
r'#\s*([a-zA-Z0-9_/-]+\.[a-zA-Z0-9]+)'
|
||||
]
|
||||
|
||||
# Check each context line for file name patterns
|
||||
for line in context_lines:
|
||||
line = line.strip()
|
||||
for pattern in file_patterns:
|
||||
matches = re.findall(pattern, line)
|
||||
if matches:
|
||||
for match in matches:
|
||||
if isinstance(match, tuple):
|
||||
# If the match is a tuple (filename, extension)
|
||||
file_name = match[0]
|
||||
file_type = match[1]
|
||||
# Verify it's not a code reference like Direction.DOWN
|
||||
if not any(keyword in file_name for keyword in ['class.', 'enum.', 'import.']):
|
||||
return os.path.splitext(file_name)[0], file_type
|
||||
else:
|
||||
# If the match is a string (full filename)
|
||||
file_name = match
|
||||
file_type = os.path.splitext(file_name)[1].lstrip('.')
|
||||
# Verify it's not a code reference
|
||||
if file_type in COMMON_CODE_EXTENSIONS and not any(
|
||||
keyword in file_name for keyword in ['class.', 'enum.', 'import.']):
|
||||
return os.path.splitext(file_name)[0], file_type
|
||||
|
||||
# If no explicit file name found, use LLM to infer from code content
|
||||
# Extract the code content
|
||||
code_content = block_with_context
|
||||
|
||||
# Get the first 20 lines of code for LLM analysis
|
||||
code_lines = code_content.split('\n')
|
||||
code_preview = '\n'.join(code_lines[:20])
|
||||
|
||||
# Get the model to use
|
||||
model_to_use = None
|
||||
if hasattr(self, 'context') and self.context:
|
||||
if hasattr(self.context, 'model') and self.context.model:
|
||||
model_to_use = self.context.model
|
||||
elif hasattr(self.context, 'team_context') and self.context.team_context:
|
||||
if hasattr(self.context.team_context, 'model') and self.context.team_context.model:
|
||||
model_to_use = self.context.team_context.model
|
||||
|
||||
# If no model is available in context, use the tool's model
|
||||
if not model_to_use and hasattr(self, 'model') and self.model:
|
||||
model_to_use = self.model
|
||||
|
||||
if model_to_use:
|
||||
# Prepare a prompt for the model
|
||||
prompt = f"""Analyze the following code and determine the most appropriate file name and file type/extension.
|
||||
The file name should be descriptive but concise, using snake_case (lowercase with underscores).
|
||||
The file type should be a standard file extension (e.g., py, js, html, css, java).
|
||||
|
||||
Code preview (first 20 lines):
|
||||
{code_preview}
|
||||
|
||||
Return your answer in JSON format with these fields:
|
||||
- file_name: The suggested file name (without extension)
|
||||
- file_type: The suggested file extension
|
||||
|
||||
JSON response:"""
|
||||
|
||||
# Create a request to the model
|
||||
request = LLMRequest(
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
temperature=0,
|
||||
json_format=True
|
||||
)
|
||||
|
||||
try:
|
||||
response = model_to_use.call(request)
|
||||
|
||||
if not response.is_error:
|
||||
# Clean the JSON response
|
||||
json_content = self._clean_json_response(response.data["choices"][0]["message"]["content"])
|
||||
result = json.loads(json_content)
|
||||
|
||||
file_name = result.get("file_name", "")
|
||||
file_type = result.get("file_type", "")
|
||||
|
||||
if file_name and file_type:
|
||||
return file_name, file_type
|
||||
except Exception as e:
|
||||
logger.error(f"Error using model to determine file name: {str(e)}")
|
||||
|
||||
# If we still don't have a file name, use the language as file type
|
||||
if language and language in COMMON_CODE_EXTENSIONS:
|
||||
timestamp = int(time.time())
|
||||
return f"code_{timestamp}", language
|
||||
|
||||
# If all else fails, return empty strings
|
||||
return "", ""
|
||||
|
||||
def _clean_json_response(self, text: str) -> str:
|
||||
"""
|
||||
Clean JSON response from LLM by removing markdown code block markers.
|
||||
|
||||
:param text: The text containing JSON possibly wrapped in markdown code blocks
|
||||
:return: Clean JSON string
|
||||
"""
|
||||
# Remove markdown code block markers if present
|
||||
if text.startswith("```json"):
|
||||
text = text[7:]
|
||||
elif text.startswith("```"):
|
||||
# Find the first newline to skip the language identifier line
|
||||
first_newline = text.find('\n')
|
||||
if first_newline != -1:
|
||||
text = text[first_newline + 1:]
|
||||
|
||||
if text.endswith("```"):
|
||||
text = text[:-3]
|
||||
|
||||
return text.strip()
|
||||
|
||||
def _clean_code_block(self, block_with_context: str) -> str:
|
||||
"""
|
||||
Clean a code block by removing markdown code markers and context lines.
|
||||
|
||||
:param block_with_context: Code block with context lines
|
||||
:return: Clean code ready for execution
|
||||
"""
|
||||
# Check if this is a full HTML or XML document
|
||||
if block_with_context.strip().startswith(("<!DOCTYPE", "<html", "<?xml")):
|
||||
return block_with_context
|
||||
|
||||
# Find the code block
|
||||
code_block_match = re.search(r'```(?:\w+)?(?:[:=][^\n]+)?\n([\s\S]*?)\n```', block_with_context)
|
||||
|
||||
if code_block_match:
|
||||
return code_block_match.group(1)
|
||||
|
||||
# If no match found, try to extract anything between ``` markers
|
||||
lines = block_with_context.split('\n')
|
||||
start_idx = None
|
||||
end_idx = None
|
||||
|
||||
for i, line in enumerate(lines):
|
||||
if line.strip().startswith('```'):
|
||||
if start_idx is None:
|
||||
start_idx = i
|
||||
else:
|
||||
end_idx = i
|
||||
break
|
||||
|
||||
if start_idx is not None and end_idx is not None:
|
||||
# Extract the code between the markers, excluding the markers themselves
|
||||
code_lines = lines[start_idx + 1:end_idx]
|
||||
return '\n'.join(code_lines)
|
||||
|
||||
# If all else fails, return the original content
|
||||
return block_with_context
|
||||
|
||||
def _get_file_params_from_model(self, content, model=None):
|
||||
"""
|
||||
Use LLM to determine if the content is code and suggest appropriate file parameters.
|
||||
|
||||
Args:
|
||||
content: The content to analyze
|
||||
model: Optional model to use for the analysis
|
||||
|
||||
Returns:
|
||||
tuple: (file_name, file_type, extract_code) for backward compatibility
|
||||
"""
|
||||
if model is None:
|
||||
model = self.model
|
||||
|
||||
if not model:
|
||||
# Default fallback if no model is available
|
||||
return "output", "txt", False
|
||||
|
||||
prompt = f"""
|
||||
Analyze the following content and determine:
|
||||
1. Is this primarily code implementation (where most of the content consists of code blocks)?
|
||||
2. What would be an appropriate filename and file extension?
|
||||
|
||||
Content to analyze: ```
|
||||
{content[:500]} # Only show first 500 chars to avoid token limits ```
|
||||
|
||||
{"..." if len(content) > 500 else ""}
|
||||
|
||||
Respond in JSON format only with the following structure:
|
||||
{{
|
||||
"is_code": true/false, # Whether this is primarily code implementation
|
||||
"filename": "suggested_filename", # Don't include extension, english words
|
||||
"extension": "appropriate_extension" # Don't include the dot, e.g., "md", "py", "js"
|
||||
}}
|
||||
"""
|
||||
|
||||
try:
|
||||
# Create a request to the model
|
||||
request = LLMRequest(
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
temperature=0.1,
|
||||
json_format=True
|
||||
)
|
||||
|
||||
# Call the model using the standard interface
|
||||
response = model.call(request)
|
||||
|
||||
if response.is_error:
|
||||
logger.warning(f"Error from model: {response.error_message}")
|
||||
raise Exception(f"Model error: {response.error_message}")
|
||||
|
||||
# Extract JSON from response
|
||||
result = response.data["choices"][0]["message"]["content"]
|
||||
|
||||
# Clean the JSON response
|
||||
result = self._clean_json_response(result)
|
||||
|
||||
# Parse the JSON
|
||||
params = json.loads(result)
|
||||
|
||||
# For backward compatibility, return tuple format
|
||||
file_name = params.get("filename", "output")
|
||||
# Remove dot from extension if present
|
||||
file_type = params.get("extension", "md").lstrip(".")
|
||||
extract_code = params.get("is_code", False)
|
||||
|
||||
return file_name, file_type, extract_code
|
||||
except Exception as e:
|
||||
logger.warning(f"Error getting file parameters from model: {e}")
|
||||
# Default fallback
|
||||
return "output", "md", False
|
||||
|
||||
def _get_team_name_from_context(self) -> Optional[str]:
|
||||
"""
|
||||
Get team name from the agent's context.
|
||||
|
||||
:return: Team name or None if not found
|
||||
"""
|
||||
if hasattr(self, 'context') and self.context:
|
||||
# Try to get team name from team_context
|
||||
if hasattr(self.context, 'team_context') and self.context.team_context:
|
||||
return self.context.team_context.name
|
||||
|
||||
# Try direct team_name attribute
|
||||
if hasattr(self.context, 'name'):
|
||||
return self.context.name
|
||||
|
||||
return None
|
||||
|
||||
def _get_task_id_from_context(self) -> Optional[str]:
|
||||
"""
|
||||
Get task ID from the agent's context.
|
||||
|
||||
:return: Task ID or None if not found
|
||||
"""
|
||||
if hasattr(self, 'context') and self.context:
|
||||
# Try to get task ID from task object
|
||||
if hasattr(self.context, 'task') and self.context.task:
|
||||
return self.context.task.id
|
||||
|
||||
# Try team_context's task
|
||||
if hasattr(self.context, 'team_context') and self.context.team_context:
|
||||
if hasattr(self.context.team_context, 'task') and self.context.team_context.task:
|
||||
return self.context.team_context.task.id
|
||||
|
||||
return None
|
||||
|
||||
def _get_task_dir_from_context(self) -> Optional[str]:
|
||||
"""
|
||||
Get task directory name from the team context.
|
||||
|
||||
:return: Task directory name or None if not found
|
||||
"""
|
||||
if hasattr(self, 'context') and self.context:
|
||||
# Try to get from team_context
|
||||
if hasattr(self.context, 'team_context') and self.context.team_context:
|
||||
if hasattr(self.context.team_context, 'task_short_name') and self.context.team_context.task_short_name:
|
||||
return self.context.team_context.task_short_name
|
||||
|
||||
# Fall back to task ID if available
|
||||
return self._get_task_id_from_context()
|
||||
|
||||
def _extract_content_from_context(self) -> str:
|
||||
"""
|
||||
Extract content from the agent's context.
|
||||
|
||||
:return: Extracted content
|
||||
"""
|
||||
# Check if we have access to the agent's context
|
||||
if not hasattr(self, 'context') or not self.context:
|
||||
return ""
|
||||
|
||||
# Try to get the most recent final answer from the agent
|
||||
if hasattr(self.context, 'final_answer') and self.context.final_answer:
|
||||
return self.context.final_answer
|
||||
|
||||
# Try to get the most recent final answer from team context
|
||||
if hasattr(self.context, 'team_context') and self.context.team_context:
|
||||
if hasattr(self.context.team_context, 'agent_outputs') and self.context.team_context.agent_outputs:
|
||||
latest_output = self.context.team_context.agent_outputs[-1].output
|
||||
return latest_output
|
||||
|
||||
# If we have action history, try to get the most recent final answer
|
||||
if hasattr(self.context, 'action_history') and self.context.action_history:
|
||||
for action in reversed(self.context.action_history):
|
||||
if "final_answer" in action and action["final_answer"]:
|
||||
return action["final_answer"]
|
||||
|
||||
return ""
|
||||
|
||||
def _extract_code_blocks(self, content: str) -> str:
|
||||
"""
|
||||
Extract code blocks from markdown content.
|
||||
|
||||
:param content: The content to extract code blocks from
|
||||
:return: Extracted code blocks
|
||||
"""
|
||||
# Pattern to match markdown code blocks
|
||||
code_block_pattern = r'```(?:\w+)?\n([\s\S]*?)\n```'
|
||||
|
||||
# Find all code blocks
|
||||
code_blocks = re.findall(code_block_pattern, content)
|
||||
|
||||
if code_blocks:
|
||||
# Join all code blocks with newlines
|
||||
return '\n\n'.join(code_blocks)
|
||||
|
||||
return content # Return original content if no code blocks found
|
||||
|
||||
def _infer_file_name(self, content: str) -> str:
|
||||
"""
|
||||
Infer a file name from the content.
|
||||
|
||||
:param content: The content to analyze.
|
||||
:return: A suggested file name.
|
||||
"""
|
||||
# Check for title patterns in markdown
|
||||
title_match = re.search(r'^#\s+(.+)$', content, re.MULTILINE)
|
||||
if title_match:
|
||||
# Convert title to a valid filename
|
||||
title = title_match.group(1).strip()
|
||||
return self._sanitize_filename(title)
|
||||
|
||||
# Check for class/function definitions in code
|
||||
code_match = re.search(r'(class|def|function)\s+(\w+)', content)
|
||||
if code_match:
|
||||
return self._sanitize_filename(code_match.group(2))
|
||||
|
||||
# Default name based on content type
|
||||
if self._is_likely_code(content):
|
||||
return "code"
|
||||
elif self._is_likely_markdown(content):
|
||||
return "document"
|
||||
elif self._is_likely_json(content):
|
||||
return "data"
|
||||
else:
|
||||
return "output"
|
||||
|
||||
def _infer_file_type(self, content: str) -> str:
|
||||
"""
|
||||
Infer the file type/extension from the content.
|
||||
|
||||
:param content: The content to analyze.
|
||||
:return: A suggested file extension.
|
||||
"""
|
||||
# Check for common programming language patterns
|
||||
if re.search(r'(import\s+[a-zA-Z0-9_]+|from\s+[a-zA-Z0-9_\.]+\s+import)', content):
|
||||
return "py" # Python
|
||||
elif re.search(r'(public\s+class|private\s+class|protected\s+class)', content):
|
||||
return "java" # Java
|
||||
elif re.search(r'(function\s+\w+\s*\(|const\s+\w+\s*=|let\s+\w+\s*=|var\s+\w+\s*=)', content):
|
||||
return "js" # JavaScript
|
||||
elif re.search(r'(<html|<body|<div|<p>)', content):
|
||||
return "html" # HTML
|
||||
elif re.search(r'(#include\s+<\w+\.h>|int\s+main\s*\()', content):
|
||||
return "cpp" # C/C++
|
||||
|
||||
# Check for markdown
|
||||
if self._is_likely_markdown(content):
|
||||
return "md"
|
||||
|
||||
# Check for JSON
|
||||
if self._is_likely_json(content):
|
||||
return "json"
|
||||
|
||||
# Default to text
|
||||
return "txt"
|
||||
|
||||
def _is_likely_code(self, content: str) -> bool:
|
||||
"""Check if the content is likely code."""
|
||||
# First check for common HTML/XML patterns
|
||||
if content.strip().startswith(("<!DOCTYPE", "<html", "<?xml", "<head", "<body")):
|
||||
return True
|
||||
|
||||
code_patterns = [
|
||||
r'(class|def|function|import|from|public|private|protected|#include)',
|
||||
r'(\{\s*\n|\}\s*\n|\[\s*\n|\]\s*\n)',
|
||||
r'(if\s*\(|for\s*\(|while\s*\()',
|
||||
r'(<\w+>.*?</\w+>)', # HTML/XML tags
|
||||
r'(var|let|const)\s+\w+\s*=', # JavaScript variable declarations
|
||||
r'#\s*\w+', # CSS ID selectors or Python comments
|
||||
r'\.\w+\s*\{', # CSS class selectors
|
||||
r'@media|@import|@font-face' # CSS at-rules
|
||||
]
|
||||
return any(re.search(pattern, content) for pattern in code_patterns)
|
||||
|
||||
def _is_likely_markdown(self, content: str) -> bool:
|
||||
"""Check if the content is likely markdown."""
|
||||
md_patterns = [
|
||||
r'^#\s+.+$', # Headers
|
||||
r'^\*\s+.+$', # Unordered lists
|
||||
r'^\d+\.\s+.+$', # Ordered lists
|
||||
r'\[.+\]\(.+\)', # Links
|
||||
r'!\[.+\]\(.+\)' # Images
|
||||
]
|
||||
return any(re.search(pattern, content, re.MULTILINE) for pattern in md_patterns)
|
||||
|
||||
def _is_likely_json(self, content: str) -> bool:
|
||||
"""Check if the content is likely JSON."""
|
||||
try:
|
||||
content = content.strip()
|
||||
if (content.startswith('{') and content.endswith('}')) or (
|
||||
content.startswith('[') and content.endswith(']')):
|
||||
json.loads(content)
|
||||
return True
|
||||
except:
|
||||
pass
|
||||
return False
|
||||
|
||||
def _sanitize_filename(self, name: str) -> str:
|
||||
"""
|
||||
Sanitize a string to be used as a filename.
|
||||
|
||||
:param name: The string to sanitize.
|
||||
:return: A sanitized filename.
|
||||
"""
|
||||
# Replace spaces with underscores
|
||||
name = name.replace(' ', '_')
|
||||
|
||||
# Remove invalid characters
|
||||
name = re.sub(r'[^\w\-\.]', '', name)
|
||||
|
||||
# Limit length
|
||||
if len(name) > 50:
|
||||
name = name[:50]
|
||||
|
||||
return name.lower()
|
||||
|
||||
def _process_file_path(self, file_path: str) -> Tuple[str, str]:
|
||||
"""
|
||||
Process a file path to extract the file name and type, and create directories if needed.
|
||||
|
||||
:param file_path: The file path to process
|
||||
:return: Tuple of (file_name, file_type)
|
||||
"""
|
||||
# Get the file name and extension
|
||||
file_name = os.path.basename(file_path)
|
||||
file_type = os.path.splitext(file_name)[1].lstrip('.')
|
||||
|
||||
return os.path.splitext(file_name)[0], file_type
|
||||
3
agent/tools/find/__init__.py
Normal file
3
agent/tools/find/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .find import Find
|
||||
|
||||
__all__ = ['Find']
|
||||
177
agent/tools/find/find.py
Normal file
177
agent/tools/find/find.py
Normal file
@@ -0,0 +1,177 @@
|
||||
"""
|
||||
Find tool - Search for files by glob pattern
|
||||
"""
|
||||
|
||||
import os
|
||||
import glob as glob_module
|
||||
from typing import Dict, Any, List
|
||||
|
||||
from agent.tools.base_tool import BaseTool, ToolResult
|
||||
from agent.tools.utils.truncate import truncate_head, format_size, DEFAULT_MAX_BYTES
|
||||
|
||||
|
||||
DEFAULT_LIMIT = 1000
|
||||
|
||||
|
||||
class Find(BaseTool):
    """
    Tool for finding files by glob pattern.

    Matches are collected with the standard ``glob`` module, filtered through
    the patterns of the search directory's ``.gitignore`` plus a few built-in
    ignore patterns, and returned as paths relative to the search directory.
    """

    name: str = "find"
    description: str = f"Search for files by glob pattern. Returns matching file paths relative to the search directory. Respects .gitignore. Output is truncated to {DEFAULT_LIMIT} results or {DEFAULT_MAX_BYTES // 1024}KB (whichever is hit first)."

    params: dict = {
        "type": "object",
        "properties": {
            "pattern": {
                "type": "string",
                "description": "Glob pattern to match files, e.g. '*.ts', '**/*.json', or 'src/**/*.spec.ts'"
            },
            "path": {
                "type": "string",
                "description": "Directory to search in (default: current directory)"
            },
            "limit": {
                "type": "integer",
                "description": f"Maximum number of results (default: {DEFAULT_LIMIT})"
            }
        },
        "required": ["pattern"]
    }

    def __init__(self, config: dict = None):
        """
        :param config: Optional tool configuration; ``cwd`` sets the directory that
                       relative search paths are resolved against
        """
        self.config = config or {}
        self.cwd = self.config.get("cwd", os.getcwd())

    def execute(self, args: Dict[str, Any]) -> ToolResult:
        """
        Execute file search

        :param args: Search parameters (``pattern``, optional ``path`` and ``limit``)
        :return: Search results or error
        """
        # `or` also covers an explicit null value, which has no .strip()
        pattern = (args.get("pattern") or "").strip()
        search_path = (args.get("path") or ".").strip()
        limit = self._normalize_limit(args.get("limit"))

        if not pattern:
            return ToolResult.fail("Error: pattern parameter is required")

        # Resolve search path
        absolute_path = self._resolve_path(search_path)

        if not os.path.exists(absolute_path):
            return ToolResult.fail(f"Error: Path not found: {search_path}")

        if not os.path.isdir(absolute_path):
            return ToolResult.fail(f"Error: Not a directory: {search_path}")

        try:
            # Load .gitignore patterns
            ignore_patterns = self._load_gitignore(absolute_path)

            # Escape the directory part so that glob metacharacters in the
            # directory name ('[', '*', '?') are not treated as wildcards.
            search_pattern = os.path.join(glob_module.escape(absolute_path), pattern)

            # iglob is lazy, so the walk stops as soon as the limit is reached
            results = []
            for file_path in glob_module.iglob(search_pattern, recursive=True):
                # Skip if matches ignore patterns
                if self._should_ignore(file_path, absolute_path, ignore_patterns):
                    continue

                # Get relative path
                relative_path = os.path.relpath(file_path, absolute_path)

                # Add trailing slash for directories
                if os.path.isdir(file_path):
                    relative_path += '/'

                results.append(relative_path)

                if len(results) >= limit:
                    break

            if not results:
                return ToolResult.success({"message": "No files found matching pattern", "files": []})

            # Sort results
            results.sort()

            # Format output
            raw_output = '\n'.join(results)
            truncation = truncate_head(raw_output, max_lines=999999)  # Only limit by bytes

            output = truncation.content
            details = {}
            notices = []

            result_limit_reached = len(results) >= limit
            if result_limit_reached:
                notices.append(f"{limit} results limit reached. Use limit={limit * 2} for more, or refine pattern")
                details["result_limit_reached"] = limit

            if truncation.truncated:
                notices.append(f"{format_size(DEFAULT_MAX_BYTES)} limit reached")
                details["truncation"] = truncation.to_dict()

            if notices:
                output += f"\n\n[{'. '.join(notices)}]"

            return ToolResult.success({
                "output": output,
                "file_count": len(results),
                "details": details if details else None
            })

        except Exception as e:
            return ToolResult.fail(f"Error executing find: {str(e)}")

    @staticmethod
    def _normalize_limit(value: Any) -> int:
        """Return ``value`` as a positive int, or DEFAULT_LIMIT when it is missing or invalid."""
        try:
            limit = int(value)
        except (TypeError, ValueError):
            return DEFAULT_LIMIT
        return limit if limit > 0 else DEFAULT_LIMIT

    def _resolve_path(self, path: str) -> str:
        """Resolve path to absolute path"""
        # Expand ~ to user home directory
        path = os.path.expanduser(path)
        if os.path.isabs(path):
            return path
        return os.path.abspath(os.path.join(self.cwd, path))

    def _load_gitignore(self, directory: str) -> List[str]:
        """
        Load .gitignore patterns from directory

        :param directory: Directory whose ``.gitignore`` (if any) is read
        :return: The non-empty, non-comment lines of the file followed by the
                 built-in ignore patterns
        """
        patterns = []
        gitignore_path = os.path.join(directory, '.gitignore')

        if os.path.exists(gitignore_path):
            try:
                with open(gitignore_path, 'r', encoding='utf-8') as f:
                    for line in f:
                        line = line.strip()
                        if line and not line.startswith('#'):
                            patterns.append(line)
            except (OSError, UnicodeDecodeError):
                # Best effort: an unreadable .gitignore only disables its own patterns
                pass

        # Add common ignore patterns
        patterns.extend([
            '.git',
            '__pycache__',
            '*.pyc',
            'node_modules',
            '.DS_Store'
        ])

        return patterns

    def _should_ignore(self, file_path: str, base_path: str, patterns: List[str]) -> bool:
        """
        Check if file should be ignored based on patterns

        Patterns are matched as shell-style wildcards rather than as plain
        substrings, so ``*.pyc`` matches compiled files and ``build`` does not
        hide ``rebuild.py``. This is a simplified subset of gitignore semantics:
        negated (``!``) patterns are skipped and ``*`` may match across ``/``.

        :param file_path: Path of the candidate file or directory
        :param base_path: Directory the patterns are relative to
        :param patterns: Ignore patterns as returned by ``_load_gitignore``
        :return: True when the path, or one of its parent directories, matches
        """
        import fnmatch  # local import: the module's import block does not include it

        relative_path = os.path.relpath(file_path, base_path).replace(os.sep, '/')
        parts = relative_path.split('/')
        # Components that are known to name directories
        dir_parts = parts if os.path.isdir(file_path) else parts[:-1]

        for raw_pattern in patterns:
            if raw_pattern.startswith('!'):
                continue  # re-include rules are not supported

            directory_only = raw_pattern.endswith('/')
            pattern = raw_pattern.strip('/')
            if not pattern:
                continue

            candidates = dir_parts if directory_only else parts

            if '/' in pattern:
                # Path pattern: compare against the path and each leading prefix,
                # so that everything below a matching directory is ignored too.
                prefixes = ('/'.join(candidates[:end]) for end in range(1, len(candidates) + 1))
                if any(fnmatch.fnmatchcase(prefix, pattern) for prefix in prefixes):
                    return True
            elif any(fnmatch.fnmatchcase(part, pattern) for part in candidates):
                # Name pattern: a match on any single path component is enough
                return True

        return False
|
||||
48
agent/tools/google_search/google_search.py
Normal file
48
agent/tools/google_search/google_search.py
Normal file
@@ -0,0 +1,48 @@
|
||||
import requests
|
||||
|
||||
from agent.tools.base_tool import BaseTool, ToolResult
|
||||
|
||||
|
||||
class GoogleSearch(BaseTool):
    """
    Tool that runs a Google search through the Serper API.

    The API key is read from the tool config (``api_key``); an optional
    ``timeout`` (seconds, default 30) bounds the HTTP request.
    """

    name: str = "google_search"
    description: str = "A tool to perform Google searches using the Serper API."
    params: dict = {
        "type": "object",
        "properties": {
            "query": {
                "type": "string",
                "description": "The search query to perform."
            }
        },
        "required": ["query"]
    }
    config: dict = {}

    def __init__(self, config=None):
        """
        :param config: Optional tool configuration (``api_key``, ``timeout``)
        """
        self.config = config or {}

    def execute(self, args: dict) -> ToolResult:
        """
        Run the search and return the organic results.

        :param args: Tool arguments; ``query`` is required
        :return: Success with the list under the response's ``organic`` key (an
                 empty list when it is absent), or a failure describing what
                 went wrong
        """
        api_key = self.config.get("api_key")
        if not api_key:
            return ToolResult.fail(result="Error: google_search requires 'api_key' in the tool config")

        query = args.get("query")
        if not query or not str(query).strip():
            return ToolResult.fail(result="Error: query parameter is required")

        url = "https://google.serper.dev/search"
        headers = {
            "X-API-KEY": api_key,
            "Content-Type": "application/json"
        }
        data = {
            "q": query,
            # NOTE(review): Serper appears to name the result-count field "num"; confirm whether "k" has any effect
            "k": 10
        }

        try:
            # Without a timeout a stalled connection would block the agent indefinitely
            response = requests.post(url, headers=headers, json=data, timeout=self.config.get("timeout", 30))
            result = response.json()
        except requests.RequestException as e:
            return ToolResult.fail(result=f"Error: search request failed: {e}")
        except ValueError as e:
            # Raised when the body is not valid JSON
            return ToolResult.fail(result=f"Error: invalid response from search API: {e}")

        if not isinstance(result, dict):
            return ToolResult.success(result=[])

        # Report API-level (statusCode 503) and HTTP-level errors as failures
        # instead of as an empty but successful search
        if result.get("statusCode") == 503 or not response.ok:
            return ToolResult.fail(result=result)

        # Only a list under 'organic' is returned; anything else becomes an empty list
        organic = result.get('organic')
        result_data = organic if isinstance(organic, list) else []
        return ToolResult.success(result=result_data)
|
||||
3
agent/tools/grep/__init__.py
Normal file
3
agent/tools/grep/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .grep import Grep
|
||||
|
||||
__all__ = ['Grep']
|
||||
248
agent/tools/grep/grep.py
Normal file
248
agent/tools/grep/grep.py
Normal file
@@ -0,0 +1,248 @@
|
||||
"""
|
||||
Grep tool - Search file contents for patterns
|
||||
Uses ripgrep (rg) for fast searching
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
import json
|
||||
from typing import Dict, Any, List, Optional
|
||||
|
||||
from agent.tools.base_tool import BaseTool, ToolResult
|
||||
from agent.tools.utils.truncate import (
|
||||
truncate_head, truncate_line, format_size,
|
||||
DEFAULT_MAX_BYTES, GREP_MAX_LINE_LENGTH
|
||||
)
|
||||
|
||||
|
||||
DEFAULT_LIMIT = 100
|
||||
|
||||
|
||||
class Grep(BaseTool):
    """
    Tool for searching file contents.

    Delegates the search to ripgrep (``rg --json``) and then re-reads the
    matching files to render each match, with optional context lines.
    """

    name: str = "grep"
    description: str = f"Search file contents for a pattern. Returns matching lines with file paths and line numbers. Respects .gitignore. Output is truncated to {DEFAULT_LIMIT} matches or {DEFAULT_MAX_BYTES // 1024}KB (whichever is hit first). Long lines are truncated to {GREP_MAX_LINE_LENGTH} chars."

    params: dict = {
        "type": "object",
        "properties": {
            "pattern": {
                "type": "string",
                "description": "Search pattern (regex or literal string)"
            },
            "path": {
                "type": "string",
                "description": "Directory or file to search (default: current directory)"
            },
            "glob": {
                "type": "string",
                "description": "Filter files by glob pattern, e.g. '*.ts' or '**/*.spec.ts'"
            },
            "ignoreCase": {
                "type": "boolean",
                "description": "Case-insensitive search (default: false)"
            },
            "literal": {
                "type": "boolean",
                "description": "Treat pattern as literal string instead of regex (default: false)"
            },
            "context": {
                "type": "integer",
                "description": "Number of lines to show before and after each match (default: 0)"
            },
            "limit": {
                "type": "integer",
                "description": f"Maximum number of matches to return (default: {DEFAULT_LIMIT})"
            }
        },
        "required": ["pattern"]
    }

    def __init__(self, config: dict = None):
        """
        :param config: Optional tool configuration; ``cwd`` sets the directory that
                       relative search paths are resolved against
        """
        self.config = config or {}
        self.cwd = self.config.get("cwd", os.getcwd())
        self.rg_path = self._find_ripgrep()

    def _find_ripgrep(self) -> Optional[str]:
        """Return the path of the ripgrep executable found on PATH, or None"""
        import shutil  # local import: the module's import block does not include it

        # shutil.which needs no subprocess and also works where the `which`
        # command itself is unavailable
        return shutil.which('rg')

    def execute(self, args: Dict[str, Any]) -> ToolResult:
        """
        Execute grep search

        :param args: Search parameters (see ``params``)
        :return: Search results or error
        """
        if not self.rg_path:
            return ToolResult.fail("Error: ripgrep (rg) is not installed. Please install it first.")

        # `or` also covers an explicit null value, which has no .strip()
        pattern = (args.get("pattern") or "").strip()
        search_path = (args.get("path") or ".").strip()
        glob = args.get("glob")
        ignore_case = args.get("ignoreCase", False)
        literal = args.get("literal", False)
        context = self._to_int(args.get("context"), default=0, minimum=0)
        limit = self._to_int(args.get("limit"), default=DEFAULT_LIMIT, minimum=1)

        if not pattern:
            return ToolResult.fail("Error: pattern parameter is required")

        # Resolve search path
        absolute_path = self._resolve_path(search_path)

        if not os.path.exists(absolute_path):
            return ToolResult.fail(f"Error: Path not found: {search_path}")

        # Build ripgrep command
        cmd = [
            self.rg_path,
            '--json',
            '--line-number',
            '--color=never',
            '--hidden'
        ]

        if ignore_case:
            cmd.append('--ignore-case')

        if literal:
            cmd.append('--fixed-strings')

        if glob:
            cmd.extend(['--glob', glob])

        # '--' ends option parsing, so a pattern beginning with '-' is searched
        # for instead of being interpreted as a ripgrep flag
        cmd.extend(['--', pattern, absolute_path])

        try:
            # Execute ripgrep
            result = subprocess.run(
                cmd,
                cwd=self.cwd,
                capture_output=True,
                text=True,
                timeout=30
            )

            matches = self._parse_matches(result.stdout, limit)
            match_count = len(matches)

            if match_count == 0:
                # Exit status 0/1 means "matched"/"no match"; anything else is an
                # error (e.g. an invalid regex) and must not be reported as
                # "no matches"
                if result.returncode not in (0, 1):
                    message = (result.stderr or "").strip() or f"ripgrep exited with status {result.returncode}"
                    return ToolResult.fail(f"Error executing grep: {message}")
                return ToolResult.success({"message": "No matches found", "matches": []})

            # Format output with context
            output_lines = []
            lines_truncated = False
            is_directory = os.path.isdir(absolute_path)
            # Each file is read once, however many matches it contains
            file_cache: Dict[str, Optional[List[str]]] = {}

            for match in matches:
                file_path = match['file']
                line_number = match['line']

                # Format file path
                if is_directory:
                    relative_path = os.path.relpath(file_path, absolute_path)
                else:
                    relative_path = os.path.basename(file_path)

                file_lines = self._read_lines(file_path, file_cache)
                if file_lines is None or line_number > len(file_lines):
                    output_lines.append(f"{relative_path}:{line_number}: (unable to read file)")
                    continue

                # Calculate context range (0-based, end exclusive)
                start = max(0, line_number - 1 - context)
                end = min(len(file_lines), line_number + context)

                for i in range(start, end):
                    line_text = file_lines[i].replace('\r', '')

                    # Truncate long lines
                    truncated_text, was_truncated = truncate_line(line_text)
                    if was_truncated:
                        lines_truncated = True

                    # Matching lines use ':' separators, context lines use '-'
                    current_line = i + 1
                    if current_line == line_number:
                        output_lines.append(f"{relative_path}:{current_line}: {truncated_text}")
                    else:
                        output_lines.append(f"{relative_path}-{current_line}- {truncated_text}")

            # Apply byte truncation
            raw_output = '\n'.join(output_lines)
            truncation = truncate_head(raw_output, max_lines=999999)  # Only limit by bytes

            output = truncation.content
            details = {}
            notices = []

            if match_count >= limit:
                notices.append(f"{limit} matches limit reached. Use limit={limit * 2} for more, or refine pattern")
                details["match_limit_reached"] = limit

            if truncation.truncated:
                notices.append(f"{format_size(DEFAULT_MAX_BYTES)} limit reached")
                details["truncation"] = truncation.to_dict()

            if lines_truncated:
                notices.append(f"Some lines truncated to {GREP_MAX_LINE_LENGTH} chars. Use read tool to see full lines")
                details["lines_truncated"] = True

            if notices:
                output += f"\n\n[{'. '.join(notices)}]"

            return ToolResult.success({
                "output": output,
                "match_count": match_count,
                "details": details if details else None
            })

        except subprocess.TimeoutExpired:
            return ToolResult.fail("Error: Search timed out after 30 seconds")
        except Exception as e:
            return ToolResult.fail(f"Error executing grep: {str(e)}")

    @staticmethod
    def _to_int(value: Any, default: int, minimum: int) -> int:
        """Return ``value`` as an int of at least ``minimum``, or ``default`` when missing or invalid."""
        try:
            number = int(value)
        except (TypeError, ValueError):
            return default
        return number if number >= minimum else default

    @staticmethod
    def _parse_matches(stdout: str, limit: int) -> List[Dict[str, Any]]:
        """Collect up to ``limit`` {'file', 'line'} entries from ripgrep's JSON-lines output."""
        matches = []
        for line in stdout.splitlines():
            if not line.strip():
                continue
            try:
                event = json.loads(line)
            except json.JSONDecodeError:
                continue
            if not isinstance(event, dict) or event.get('type') != 'match':
                continue

            data = event.get('data', {})
            file_path = data.get('path', {}).get('text')
            line_number = data.get('line_number')
            if file_path and line_number:
                matches.append({'file': file_path, 'line': line_number})
                if len(matches) >= limit:
                    break
        return matches

    @staticmethod
    def _read_lines(file_path: str, cache: Dict[str, Optional[List[str]]]) -> Optional[List[str]]:
        """Return the file's lines (cached per path), or None when it cannot be read as UTF-8."""
        if file_path not in cache:
            try:
                with open(file_path, 'r', encoding='utf-8') as f:
                    cache[file_path] = f.read().split('\n')
            except (OSError, UnicodeDecodeError):
                cache[file_path] = None
        return cache[file_path]

    def _resolve_path(self, path: str) -> str:
        """Resolve path to absolute path"""
        # Expand ~ to user home directory
        path = os.path.expanduser(path)
        if os.path.isabs(path):
            return path
        return os.path.abspath(os.path.join(self.cwd, path))
|
||||
3
agent/tools/ls/__init__.py
Normal file
3
agent/tools/ls/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .ls import Ls
|
||||
|
||||
__all__ = ['Ls']
|
||||
125
agent/tools/ls/ls.py
Normal file
125
agent/tools/ls/ls.py
Normal file
@@ -0,0 +1,125 @@
|
||||
"""
|
||||
Ls tool - List directory contents
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Dict, Any
|
||||
|
||||
from agent.tools.base_tool import BaseTool, ToolResult
|
||||
from agent.tools.utils.truncate import truncate_head, format_size, DEFAULT_MAX_BYTES
|
||||
|
||||
|
||||
DEFAULT_LIMIT = 500
|
||||
|
||||
|
||||
class Ls(BaseTool):
    """
    Tool for listing directory contents.

    Entries are sorted case-insensitively, directories get a trailing '/', and
    dotfiles are included.
    """

    name: str = "ls"
    description: str = f"List directory contents. Returns entries sorted alphabetically, with '/' suffix for directories. Includes dotfiles. Output is truncated to {DEFAULT_LIMIT} entries or {DEFAULT_MAX_BYTES // 1024}KB (whichever is hit first)."

    params: dict = {
        "type": "object",
        "properties": {
            "path": {
                "type": "string",
                "description": "Directory to list (default: current directory)"
            },
            "limit": {
                "type": "integer",
                "description": f"Maximum number of entries to return (default: {DEFAULT_LIMIT})"
            }
        },
        "required": []
    }

    def __init__(self, config: dict = None):
        """
        :param config: Optional tool configuration; ``cwd`` sets the directory that
                       relative paths are resolved against
        """
        self.config = config or {}
        self.cwd = self.config.get("cwd", os.getcwd())

    def execute(self, args: Dict[str, Any]) -> ToolResult:
        """
        Execute directory listing

        :param args: Listing parameters (optional ``path`` and ``limit``)
        :return: Directory contents or error
        """
        # `or` also covers an explicit null value, which has no .strip()
        path = (args.get("path") or ".").strip()
        limit = self._normalize_limit(args.get("limit"))

        # Resolve path
        absolute_path = self._resolve_path(path)

        if not os.path.exists(absolute_path):
            return ToolResult.fail(f"Error: Path not found: {path}")

        if not os.path.isdir(absolute_path):
            return ToolResult.fail(f"Error: Not a directory: {path}")

        try:
            # Read directory entries
            entries = os.listdir(absolute_path)

            # Sort alphabetically (case-insensitive)
            entries.sort(key=lambda x: x.lower())

            # Format entries with directory indicators
            results = []
            entry_limit_reached = False

            for entry in entries:
                if len(results) >= limit:
                    entry_limit_reached = True
                    break

                full_path = os.path.join(absolute_path, entry)

                try:
                    if os.path.isdir(full_path):
                        results.append(entry + '/')
                    else:
                        results.append(entry)
                except OSError:
                    # Skip entries we can't stat
                    continue

            if not results:
                return ToolResult.success({"message": "(empty directory)", "entries": []})

            # Format output
            raw_output = '\n'.join(results)
            truncation = truncate_head(raw_output, max_lines=999999)  # Only limit by bytes

            output = truncation.content
            details = {}
            notices = []

            if entry_limit_reached:
                notices.append(f"{limit} entries limit reached. Use limit={limit * 2} for more")
                details["entry_limit_reached"] = limit

            if truncation.truncated:
                notices.append(f"{format_size(DEFAULT_MAX_BYTES)} limit reached")
                details["truncation"] = truncation.to_dict()

            if notices:
                output += f"\n\n[{'. '.join(notices)}]"

            return ToolResult.success({
                "output": output,
                "entry_count": len(results),
                "details": details if details else None
            })

        except PermissionError:
            return ToolResult.fail(f"Error: Permission denied reading directory: {path}")
        except Exception as e:
            return ToolResult.fail(f"Error listing directory: {str(e)}")

    @staticmethod
    def _normalize_limit(value: Any) -> int:
        """
        Return ``value`` as a positive int, or DEFAULT_LIMIT when it is missing or invalid.

        Without this, a limit of 0 would stop before the first entry and the
        directory would be reported as empty.
        """
        try:
            limit = int(value)
        except (TypeError, ValueError):
            return DEFAULT_LIMIT
        return limit if limit > 0 else DEFAULT_LIMIT

    def _resolve_path(self, path: str) -> str:
        """Resolve path to absolute path"""
        # Expand ~ to user home directory
        path = os.path.expanduser(path)
        if os.path.isabs(path):
            return path
        return os.path.abspath(os.path.join(self.cwd, path))
|
||||
10
agent/tools/memory/__init__.py
Normal file
10
agent/tools/memory/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
||||
"""
|
||||
Memory tools for Agent
|
||||
|
||||
Provides memory_search and memory_get tools
|
||||
"""
|
||||
|
||||
from agent.tools.memory.memory_search import MemorySearchTool
|
||||
from agent.tools.memory.memory_get import MemoryGetTool
|
||||
|
||||
__all__ = ['MemorySearchTool', 'MemoryGetTool']
|
||||
107
agent/tools/memory/memory_get.py
Normal file
107
agent/tools/memory/memory_get.py
Normal file
@@ -0,0 +1,107 @@
|
||||
"""
|
||||
Memory get tool
|
||||
|
||||
Allows agents to read specific sections from memory files
|
||||
"""
|
||||
|
||||
from typing import Dict, Any
|
||||
from pathlib import Path
|
||||
from agent.tools.base_tool import BaseTool
|
||||
|
||||
|
||||
class MemoryGetTool(BaseTool):
    """
    Tool for reading memory file contents.

    Paths are interpreted relative to the workspace directory reported by
    ``memory_manager.config.get_workspace()``. A path that resolves outside
    that directory (via ``..``, an absolute path or a symlink) is rejected,
    because the path comes from the model and must not be able to read
    arbitrary files.
    """

    name: str = "memory_get"
    description: str = (
        "Read specific content from memory files. "
        "Use this to get full context from a memory file or specific line range."
    )
    params: dict = {
        "type": "object",
        "properties": {
            "path": {
                "type": "string",
                "description": "Relative path to the memory file (e.g., 'MEMORY.md', 'memory/2024-01-29.md')"
            },
            "start_line": {
                "type": "integer",
                "description": "Starting line number (optional, default: 1)",
                "default": 1
            },
            "num_lines": {
                "type": "integer",
                "description": "Number of lines to read (optional, reads all if not specified)"
            }
        },
        "required": ["path"]
    }

    def __init__(self, memory_manager):
        """
        Initialize memory get tool

        Args:
            memory_manager: MemoryManager instance
        """
        super().__init__()
        self.memory_manager = memory_manager

    def execute(self, args: dict):
        """
        Execute memory file read

        Args:
            args: Dictionary with path, start_line, num_lines

        Returns:
            ToolResult with a small header (file name, line range, total
            line count) followed by the selected lines, or a failed
            ToolResult describing what went wrong
        """
        # Imported here rather than at module level, as in the original code.
        from agent.tools.base_tool import ToolResult

        path = args.get("path")
        if not path:
            return ToolResult.fail("Error: path parameter is required")

        # Tool arguments come from the model and may be null or strings;
        # normalise them up front so the slicing below only sees integers.
        try:
            start_line = int(args.get("start_line") or 1)
            raw_num_lines = args.get("num_lines")
            num_lines = int(raw_num_lines) if raw_num_lines is not None else None
        except (TypeError, ValueError):
            return ToolResult.fail("Error: start_line and num_lines must be integers")

        try:
            workspace_dir = Path(self.memory_manager.config.get_workspace()).resolve()
            file_path = (workspace_dir / path).resolve()

            # Refuse anything that escapes the workspace after resolution.
            try:
                file_path.relative_to(workspace_dir)
            except ValueError:
                return ToolResult.fail(f"Error: Path is outside the memory workspace: {path}")

            if not file_path.is_file():
                return ToolResult.fail(f"Error: File not found: {path}")

            # Prefer UTF-8 so results do not depend on the platform locale;
            # fall back to the locale encoding for files written that way.
            try:
                content = file_path.read_text(encoding="utf-8")
            except UnicodeDecodeError:
                content = file_path.read_text()
            lines = content.split('\n')
            total_lines = len(lines)

            # Handle line range
            start_line = max(start_line, 1)
            if start_line > total_lines:
                return ToolResult.fail(
                    f"Error: start_line {start_line} is beyond end of file ({total_lines} lines total)"
                )
            start_idx = start_line - 1

            # A missing, zero or negative num_lines means "read to the end".
            if num_lines is not None and num_lines > 0:
                selected_lines = lines[start_idx:start_idx + num_lines]
            else:
                selected_lines = lines[start_idx:]

            result = '\n'.join(selected_lines)
            shown_lines = len(selected_lines)

            output = [
                f"File: {path}",
                f"Lines: {start_line}-{start_line + shown_lines - 1} (total: {total_lines})",
                "",
                result
            ]

            return ToolResult.success('\n'.join(output))

        except Exception as e:
            return ToolResult.fail(f"Error reading memory file: {str(e)}")
|
||||
96
agent/tools/memory/memory_search.py
Normal file
96
agent/tools/memory/memory_search.py
Normal file
@@ -0,0 +1,96 @@
|
||||
"""
|
||||
Memory search tool
|
||||
|
||||
Allows agents to search their memory using semantic and keyword search
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, Optional
|
||||
from agent.tools.base_tool import BaseTool
|
||||
|
||||
|
||||
class MemorySearchTool(BaseTool):
    """
    Tool that lets an agent query its long-term memory.

    The search itself is delegated to ``memory_manager.search``; this class
    validates the arguments and renders the hits as a plain-text report.
    """

    name: str = "memory_search"
    description: str = (
        "Search agent's long-term memory using semantic and keyword search. "
        "Use this to recall past conversations, preferences, and knowledge."
    )
    params: dict = {
        "type": "object",
        "properties": {
            "query": {
                "type": "string",
                "description": "Search query (can be natural language question or keywords)"
            },
            "max_results": {
                "type": "integer",
                "description": "Maximum number of results to return (default: 10)",
                "default": 10
            },
            "min_score": {
                "type": "number",
                "description": "Minimum relevance score (0-1, default: 0.3)",
                "default": 0.3
            }
        },
        "required": ["query"]
    }

    def __init__(self, memory_manager, user_id: Optional[str] = None):
        """
        Create the tool.

        Args:
            memory_manager: MemoryManager instance used to run searches
            user_id: Optional user ID passed through to the search call
        """
        super().__init__()
        self.user_id = user_id
        self.memory_manager = memory_manager

    def execute(self, args: dict):
        """
        Run a memory search and format the hits.

        Args:
            args: Dictionary with query, max_results, min_score

        Returns:
            ToolResult whose payload is a numbered text listing of the hits
            (path, line range, score, snippet), a "no relevant memories"
            message, or a failure describing the error
        """
        from agent.tools.base_tool import ToolResult
        import asyncio

        query = args.get("query")
        if not query:
            return ToolResult.fail("Error: query parameter is required")

        max_results = args.get("max_results", 10)
        min_score = args.get("min_score", 0.3)

        try:
            # memory_manager.search is a coroutine; drive it to completion
            # from this synchronous entry point.
            search_call = self.memory_manager.search(
                query=query,
                user_id=self.user_id,
                max_results=max_results,
                min_score=min_score,
                include_shared=True
            )
            hits = asyncio.run(search_call)

            if not hits:
                return ToolResult.success(f"No relevant memories found for query: {query}")

            report = [f"Found {len(hits)} relevant memories:\n"]
            for rank, hit in enumerate(hits, 1):
                report.extend((
                    f"\n{rank}. {hit.path} (lines {hit.start_line}-{hit.end_line})",
                    f" Score: {hit.score:.3f}",
                    f" Snippet: {hit.snippet}",
                ))

            return ToolResult.success("\n".join(report))

        except Exception as e:
            return ToolResult.fail(f"Error searching memory: {str(e)}")
|
||||
3
agent/tools/read/__init__.py
Normal file
3
agent/tools/read/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .read import Read
|
||||
|
||||
__all__ = ['Read']
|
||||
336
agent/tools/read/read.py
Normal file
336
agent/tools/read/read.py
Normal file
@@ -0,0 +1,336 @@
|
||||
"""
|
||||
Read tool - Read file contents
|
||||
Supports text files, images (jpg, png, gif, webp), and PDF files
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Dict, Any
|
||||
from pathlib import Path
|
||||
|
||||
from agent.tools.base_tool import BaseTool, ToolResult
|
||||
from agent.tools.utils.truncate import truncate_head, format_size, DEFAULT_MAX_LINES, DEFAULT_MAX_BYTES
|
||||
|
||||
|
||||
class Read(BaseTool):
    """
    Tool for reading file contents.

    Dispatches on the file extension: images are returned base64-encoded,
    PDF files have their text extracted with ``pypdf``, and everything else
    is read as UTF-8 text. Text and PDF output is windowed by ``offset`` /
    ``limit`` and then passed through ``truncate_head``.
    """

    name: str = "read"
    description: str = f"Read the contents of a file. Supports text files, PDF files, and images (jpg, png, gif, webp). For text files, output is truncated to {DEFAULT_MAX_LINES} lines or {DEFAULT_MAX_BYTES // 1024}KB (whichever is hit first). Use offset/limit for large files."

    params: dict = {
        "type": "object",
        "properties": {
            "path": {
                "type": "string",
                "description": "Path to the file to read (relative or absolute)"
            },
            "offset": {
                "type": "integer",
                "description": "Line number to start reading from (1-indexed, optional)"
            },
            "limit": {
                "type": "integer",
                "description": "Maximum number of lines to read (optional)"
            }
        },
        "required": ["path"]
    }

    def __init__(self, config: dict = None):
        """
        :param config: Optional settings; ``cwd`` is the base directory for
                       relative paths (defaults to the process working directory)
        """
        self.config = config or {}
        self.cwd = self.config.get("cwd", os.getcwd())
        # Supported image formats
        self.image_extensions = {'.jpg', '.jpeg', '.png', '.gif', '.webp'}
        # Supported PDF format
        self.pdf_extensions = {'.pdf'}

    def execute(self, args: Dict[str, Any]) -> ToolResult:
        """
        Execute file read operation

        :param args: Contains file path and optional offset/limit parameters
        :return: File content or error message
        """
        path = (args.get("path") or "").strip()
        offset = args.get("offset")
        limit = args.get("limit")

        if not path:
            return ToolResult.fail("Error: path parameter is required")

        # Resolve path
        absolute_path = self._resolve_path(path)

        # Check if file exists
        if not os.path.exists(absolute_path):
            return ToolResult.fail(f"Error: File not found: {path}")

        # A directory would otherwise surface as an opaque open() error.
        if os.path.isdir(absolute_path):
            return ToolResult.fail(f"Error: Path is a directory, not a file: {path}")

        # Check if readable
        if not os.access(absolute_path, os.R_OK):
            return ToolResult.fail(f"Error: File is not readable: {path}")

        # Dispatch on the (lower-cased) extension
        file_ext = Path(absolute_path).suffix.lower()

        if file_ext in self.image_extensions:
            return self._read_image(absolute_path, file_ext)

        if file_ext in self.pdf_extensions:
            return self._read_pdf(absolute_path, path, offset, limit)

        return self._read_text(absolute_path, path, offset, limit)

    def _resolve_path(self, path: str) -> str:
        """
        Resolve path to absolute path

        :param path: Relative or absolute path
        :return: Absolute path
        """
        # Expand ~ to user home directory
        path = os.path.expanduser(path)
        if os.path.isabs(path):
            return path
        return os.path.abspath(os.path.join(self.cwd, path))

    def _read_image(self, absolute_path: str, file_ext: str) -> ToolResult:
        """
        Read image file

        :param absolute_path: Absolute path to the image file
        :param file_ext: Lower-cased file extension including the dot
        :return: Result containing MIME type, size and base64-encoded data
        """
        try:
            with open(absolute_path, 'rb') as f:
                image_data = f.read()

            file_size = len(image_data)

            import base64
            base64_data = base64.b64encode(image_data).decode('utf-8')

            # Determine MIME type
            mime_type_map = {
                '.jpg': 'image/jpeg',
                '.jpeg': 'image/jpeg',
                '.png': 'image/png',
                '.gif': 'image/gif',
                '.webp': 'image/webp'
            }
            mime_type = mime_type_map.get(file_ext, 'image/jpeg')

            result = {
                "type": "image",
                "mime_type": mime_type,
                "size": file_size,
                "size_formatted": format_size(file_size),
                "data": base64_data  # Base64 encoded image data
            }

            return ToolResult.success(result)

        except Exception as e:
            return ToolResult.fail(f"Error reading image file: {str(e)}")

    def _window_content(self, full_content: str, display_path: str, offset, limit,
                        end_label: str, first_line_hint: bool):
        """
        Apply offset/limit and truncation to already-loaded text.

        Shared by the text and PDF readers so both page through content the
        same way.

        :param full_content: Complete text to window
        :param display_path: Path shown in hints to the caller
        :param offset: Starting line number (1-indexed) or None
        :param limit: Maximum number of lines or None; negative values are
                      treated as 0
        :param end_label: Noun used in the "beyond end of ..." error message
        :param first_line_hint: Whether to replace the content with a shell
                                hint when the first selected line alone
                                exceeds the byte limit
        :return: ``(payload, None)`` on success or ``(None, error_message)``
        """
        all_lines = full_content.split('\n')
        total_lines = len(all_lines)

        # Apply offset (if specified)
        start_line = 0
        if offset is not None:
            start_line = max(0, offset - 1)  # Convert to 0-indexed
            if start_line >= total_lines:
                return None, (
                    f"Error: Offset {offset} is beyond end of {end_label} ({total_lines} lines total)"
                )

        start_line_display = start_line + 1  # For display (1-indexed)

        selected_content = full_content
        user_limited_lines = None
        if limit is not None:
            # Clamp so a negative limit selects nothing instead of slicing
            # from the end of the file.
            end_line = min(start_line + max(limit, 0), total_lines)
            selected_content = '\n'.join(all_lines[start_line:end_line])
            user_limited_lines = end_line - start_line
        elif offset is not None:
            selected_content = '\n'.join(all_lines[start_line:])

        # Apply truncation (considering line count and byte limits)
        truncation = truncate_head(selected_content)

        details = {}

        if first_line_hint and truncation.first_line_exceeds_limit:
            # The first selected line alone is over the byte limit, so point
            # the caller at a shell command that extracts exactly that line.
            import shlex
            first_line_size = format_size(len(all_lines[start_line].encode('utf-8')))
            output_text = (
                f"[Line {start_line_display} is {first_line_size}, exceeds "
                f"{format_size(DEFAULT_MAX_BYTES)} limit. Use bash tool to read: "
                f"sed -n '{start_line_display}p' {shlex.quote(display_path)} | head -c {DEFAULT_MAX_BYTES}]"
            )
            details["truncation"] = truncation.to_dict()
        elif truncation.truncated:
            end_line_display = start_line_display + truncation.output_lines - 1
            next_offset = end_line_display + 1

            output_text = truncation.content

            if truncation.truncated_by == "lines":
                output_text += f"\n\n[Showing lines {start_line_display}-{end_line_display} of {total_lines}. Use offset={next_offset} to continue.]"
            else:
                output_text += f"\n\n[Showing lines {start_line_display}-{end_line_display} of {total_lines} ({format_size(DEFAULT_MAX_BYTES)} limit). Use offset={next_offset} to continue.]"

            details["truncation"] = truncation.to_dict()
        elif user_limited_lines is not None and start_line + user_limited_lines < total_lines:
            # User specified limit, more content available, but no truncation
            remaining = total_lines - (start_line + user_limited_lines)
            next_offset = start_line + user_limited_lines + 1

            output_text = truncation.content
            output_text += f"\n\n[{remaining} more lines in file. Use offset={next_offset} to continue.]"
        else:
            # No truncation, no exceeding user limit
            output_text = truncation.content

        payload = {
            "content": output_text,
            "total_lines": total_lines,
            "start_line": start_line_display,
            "output_lines": truncation.output_lines
        }
        if details:
            payload["details"] = details

        return payload, None

    def _read_text(self, absolute_path: str, display_path: str, offset: int = None, limit: int = None) -> ToolResult:
        """
        Read text file

        :param absolute_path: Absolute path to the file
        :param display_path: Path to display
        :param offset: Starting line number (1-indexed)
        :param limit: Maximum number of lines to read
        :return: File content or error message
        """
        try:
            with open(absolute_path, 'r', encoding='utf-8') as f:
                content = f.read()

            payload, error = self._window_content(
                content, display_path, offset, limit,
                end_label="file", first_line_hint=True
            )
            if error:
                return ToolResult.fail(error)
            return ToolResult.success(payload)

        except UnicodeDecodeError:
            return ToolResult.fail(f"Error: File is not a valid text file (encoding error): {display_path}")
        except Exception as e:
            return ToolResult.fail(f"Error reading file: {str(e)}")

    def _read_pdf(self, absolute_path: str, display_path: str, offset: int = None, limit: int = None) -> ToolResult:
        """
        Read PDF file content

        :param absolute_path: Absolute path to the file
        :param display_path: Path to display
        :param offset: Starting line number (1-indexed)
        :param limit: Maximum number of lines to read
        :return: PDF text content or error message
        """
        try:
            # pypdf is optional; report how to install it if it is missing
            try:
                from pypdf import PdfReader
            except ImportError:
                return ToolResult.fail(
                    "Error: pypdf library not installed. Install with: pip install pypdf"
                )

            reader = PdfReader(absolute_path)
            total_pages = len(reader.pages)

            # Extract text from all pages, skipping pages with no text
            text_parts = []
            for page_num, page in enumerate(reader.pages, 1):
                # Guard against an extractor that yields None for a page
                page_text = page.extract_text() or ""
                if page_text.strip():
                    text_parts.append(f"--- Page {page_num} ---\n{page_text}")

            if not text_parts:
                return ToolResult.success({
                    "content": f"[PDF file with {total_pages} pages, but no text content could be extracted]",
                    "total_pages": total_pages,
                    "message": "PDF may contain only images or be encrypted"
                })

            # Merge all text and page through it like a text file. The shell
            # hint is disabled because it cannot extract text from a PDF.
            full_content = "\n\n".join(text_parts)
            payload, error = self._window_content(
                full_content, display_path, offset, limit,
                end_label="content", first_line_hint=False
            )
            if error:
                return ToolResult.fail(error)

            # Keep the original key order: content, total_pages, then the rest
            result = {"content": payload.pop("content"), "total_pages": total_pages}
            result.update(payload)

            return ToolResult.success(result)

        except Exception as e:
            return ToolResult.fail(f"Error reading PDF file: {str(e)}")
|
||||
3
agent/tools/terminal/__init__.py
Normal file
3
agent/tools/terminal/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .terminal import Terminal
|
||||
|
||||
__all__ = ['Terminal']
|
||||
100
agent/tools/terminal/terminal.py
Normal file
100
agent/tools/terminal/terminal.py
Normal file
@@ -0,0 +1,100 @@
|
||||
import platform
|
||||
import subprocess
|
||||
from typing import Dict, Any
|
||||
|
||||
from agent.tools.base_tool import BaseTool, ToolResult
|
||||
|
||||
|
||||
class Terminal(BaseTool):
    """
    Tool that runs a shell command on the local machine.

    Commands are screened by a best-effort blocklist before being handed to
    the shell. The blocklist is a guard rail, not a sandbox: the command
    still runs with ``shell=True``.
    """

    name: str = "terminal"
    description: str = "A tool to run terminal commands on the local system"
    params: dict = {
        "type": "object",
        "properties": {
            "command": {
                "type": "string",
                "description": f"The terminal command to execute which should be valid in {platform.system()} platform"
            }
        },
        "required": ["command"]
    }
    config: dict = {}

    def __init__(self, config=None):
        """
        :param config: Optional settings; ``timeout`` is the command timeout
                       in seconds (defaults to 30)
        """
        self.config = config or {}
        # Set of dangerous commands that should be blocked
        self.command_ban_set = {"halt", "poweroff", "shutdown", "reboot", "rm", "kill",
                                "exit", "sudo", "su", "userdel", "groupdel", "logout", "alias"}

    def execute(self, args: Dict[str, Any]) -> ToolResult:
        """
        Execute a terminal command safely.

        :param args: Dictionary containing the command to execute
        :return: Result of the command execution
        """
        # Tolerate an explicit null command instead of crashing on .strip()
        command = str(args.get("command") or "").strip()

        # Check if the command is safe to execute
        if not self._is_safe_command(command):
            return ToolResult.fail(result=f"Command '{command}' is not allowed for security reasons.")

        # Read once so the value enforced and the value reported agree
        timeout = self.config.get("timeout", 30)

        try:
            result = subprocess.run(
                command,
                shell=True,
                check=True,  # Raise exception on non-zero return code
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                text=True,
                timeout=timeout
            )

            return ToolResult.success({
                "stdout": result.stdout,
                "stderr": result.stderr,
                "return_code": result.returncode,
                "command": command
            })
        except subprocess.CalledProcessError as e:
            # Non-zero exit: still return the captured output to the caller
            return ToolResult.fail({
                "stdout": e.stdout,
                "stderr": e.stderr,
                "return_code": e.returncode,
                "command": command
            })
        except subprocess.TimeoutExpired:
            return ToolResult.fail(result=f"Command timed out after {timeout} seconds.")
        except Exception as e:
            return ToolResult.fail(result=f"Error executing command: {str(e)}")

    def _is_safe_command(self, command: str) -> bool:
        """
        Check if a command is safe to execute.

        Because the command is run through a shell, every segment of a
        chained command (split on ``;``, ``&&``, ``||``, ``|``, ``&`` and
        newlines) is checked, not only the first word. The split is purely
        textual, so separators inside quotes also split; this errs on the
        side of blocking.

        :param command: The command to check
        :return: True if the command is safe, False otherwise
        """
        import re

        if not command.split():
            return False

        # Check for sudo/su anywhere in the command line
        lowered = command.lower()
        if any(banned in lowered for banned in ["sudo ", "su -"]):
            return False

        for segment in re.split(r"&&|\|\||[;|&\n]", command):
            parts = segment.split()
            if not parts:
                continue

            base_cmd = parts[0].lower()

            # Check if the base command is in the ban list
            if base_cmd in self.command_ban_set:
                return False

            # Check for rm-like commands with recursive/force flags
            if "rm" in base_cmd and ("-r" in segment or "-f" in segment):
                return False

        return True
|
||||
208
agent/tools/tool_manager.py
Normal file
208
agent/tools/tool_manager.py
Normal file
@@ -0,0 +1,208 @@
|
||||
import importlib
|
||||
import importlib.util
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, Type
|
||||
from agent.tools.base_tool import BaseTool
|
||||
from common.log import logger
|
||||
|
||||
|
||||
class ToolManager:
    """
    Tool manager for managing tools.

    A process-wide singleton that maps tool names to tool *classes* and
    creates a fresh instance on every ``create_tool`` call.
    """
    _instance = None

    def __new__(cls):
        """Singleton pattern to ensure only one instance of ToolManager exists."""
        if cls._instance is None:
            cls._instance = super(ToolManager, cls).__new__(cls)
            cls._instance.tool_classes = {}  # Store tool classes instead of instances
            cls._instance.tool_configs = {}  # Per-tool configuration, keyed by tool name
            cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        # Runs on every ToolManager() call; never reset existing state
        if not hasattr(self, 'tool_classes'):
            self.tool_classes = {}  # Dictionary to store tool classes

    def load_tools(self, tools_dir: str = "", config_dict=None):
        """
        Load tools from both directory and configuration.

        :param tools_dir: Directory to scan for tool modules; when empty the
                          classes exported by ``agent.tools`` are used
        :param config_dict: Optional mapping of tool name to tool configuration
        """
        if tools_dir:
            self._load_tools_from_directory(tools_dir)
        else:
            self._load_tools_from_init()
        # Previously config_dict was dropped when tools_dir was given
        self._configure_tools_from_config(config_dict)

    @staticmethod
    def _is_tool_class(candidate) -> bool:
        """Return True when candidate is a proper subclass of BaseTool."""
        return (
            isinstance(candidate, type)
            and issubclass(candidate, BaseTool)
            and candidate != BaseTool
        )

    def _register_tool_class(self, cls, origin: str = "") -> None:
        """Instantiate cls once to learn its tool name and store the class under it."""
        try:
            # Create a temporary instance to get the name
            tool_name = cls().name
            # Store the class, not the instance
            self.tool_classes[tool_name] = cls
            logger.debug(f"Loaded tool: {tool_name} from class {origin or cls.__name__}")
        except ImportError as e:
            # Ignore browser_use dependency missing errors
            if "browser_use" not in str(e):
                logger.error(f"Error initializing tool class {cls.__name__}: {e}")
        except Exception as e:
            logger.error(f"Error initializing tool class {cls.__name__}: {e}")

    def _load_tools_from_init(self) -> bool:
        """
        Load tool classes from tools.__init__.__all__

        :return: True if tools were loaded, False otherwise
        """
        try:
            # Try to import the tools package
            tools_package = importlib.import_module("agent.tools")

            # Without __all__ there is nothing to enumerate
            if not hasattr(tools_package, "__all__"):
                return False

            for class_name in tools_package.__all__:
                try:
                    # Skip base classes
                    if class_name in ["BaseTool", "ToolManager"]:
                        continue

                    candidate = getattr(tools_package, class_name, None)
                    if self._is_tool_class(candidate):
                        self._register_tool_class(candidate, class_name)
                except Exception as e:
                    logger.error(f"Error importing class {class_name}: {e}")

            return len(self.tool_classes) > 0
        except ImportError:
            logger.warning("Could not import agent.tools package")
            return False
        except Exception as e:
            logger.error(f"Error loading tools from __init__.__all__: {e}")
            return False

    def _load_tools_from_directory(self, tools_dir: str):
        """
        Dynamically load tool classes from directory

        :param tools_dir: Directory searched recursively for ``*.py`` modules
        """
        tools_path = Path(tools_dir)

        # Traverse all .py files
        for py_file in tools_path.rglob("*.py"):
            # Skip initialization files and base tool files
            if py_file.name in ["__init__.py", "base_tool.py", "tool_manager.py"]:
                continue

            try:
                # Load module directly from file, named after the file stem
                spec = importlib.util.spec_from_file_location(py_file.stem, py_file)
                if not (spec and spec.loader):
                    continue
                module = importlib.util.module_from_spec(spec)
                spec.loader.exec_module(module)

                # Find tool classes in the module
                for attr_name in dir(module):
                    candidate = getattr(module, attr_name)
                    if self._is_tool_class(candidate):
                        self._register_tool_class(candidate)
            except Exception as e:
                # Use the logger like the rest of the class (was print)
                logger.error(f"Error importing module {py_file}: {e}")

    def _configure_tools_from_config(self, config_dict=None):
        """
        Configure tool classes based on the given configuration

        :param config_dict: Mapping of tool name to tool configuration. When
                            omitted, no per-tool configuration is applied.
        """
        try:
            # The previous fallback called a `config()` helper that this
            # module never imports, so it always raised NameError and left
            # tool_configs unset. Fall back to an empty configuration instead.
            tools_config = config_dict or {}

            # Store configurations for later use when instantiating
            self.tool_configs = tools_config

            # Check which configured tools are missing
            missing_tools = [
                tool_name for tool_name in tools_config
                if tool_name not in self.tool_classes
            ]

            for tool_name in missing_tools:
                if tool_name == "browser":
                    logger.error(
                        "Browser tool is configured but could not be loaded. "
                        "Please install the required dependency with: "
                        "pip install browser-use>=0.1.40 or pip install agentmesh-sdk[full]"
                    )
                else:
                    logger.warning(f"Tool '{tool_name}' is configured but could not be loaded.")

        except Exception as e:
            logger.error(f"Error configuring tools from config: {e}")

    def create_tool(self, name: str) -> "BaseTool":
        """
        Get a new instance of a tool by name.

        :param name: The name of the tool to get.
        :return: A new instance of the tool or None if not found.
        """
        tool_class = self.tool_classes.get(name)
        if not tool_class:
            return None

        # Create a new instance
        tool_instance = tool_class()

        # Apply configuration if available
        tool_configs = getattr(self, 'tool_configs', None) or {}
        if name in tool_configs:
            tool_instance.config = tool_configs[name]

        return tool_instance

    def list_tools(self) -> dict:
        """
        Get information about all loaded tools.

        :return: A dictionary mapping each tool name to its description and
                 JSON parameter schema.
        """
        result = {}
        for name, tool_class in self.tool_classes.items():
            # Create a temporary instance to get schema
            temp_instance = tool_class()
            result[name] = {
                "description": temp_instance.description,
                "parameters": temp_instance.get_json_schema()
            }
        return result
|
||||
40
agent/tools/utils/__init__.py
Normal file
40
agent/tools/utils/__init__.py
Normal file
@@ -0,0 +1,40 @@
|
||||
from .truncate import (
|
||||
truncate_head,
|
||||
truncate_tail,
|
||||
truncate_line,
|
||||
format_size,
|
||||
TruncationResult,
|
||||
DEFAULT_MAX_LINES,
|
||||
DEFAULT_MAX_BYTES,
|
||||
GREP_MAX_LINE_LENGTH
|
||||
)
|
||||
|
||||
from .diff import (
|
||||
strip_bom,
|
||||
detect_line_ending,
|
||||
normalize_to_lf,
|
||||
restore_line_endings,
|
||||
normalize_for_fuzzy_match,
|
||||
fuzzy_find_text,
|
||||
generate_diff_string,
|
||||
FuzzyMatchResult
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
'truncate_head',
|
||||
'truncate_tail',
|
||||
'truncate_line',
|
||||
'format_size',
|
||||
'TruncationResult',
|
||||
'DEFAULT_MAX_LINES',
|
||||
'DEFAULT_MAX_BYTES',
|
||||
'GREP_MAX_LINE_LENGTH',
|
||||
'strip_bom',
|
||||
'detect_line_ending',
|
||||
'normalize_to_lf',
|
||||
'restore_line_endings',
|
||||
'normalize_for_fuzzy_match',
|
||||
'fuzzy_find_text',
|
||||
'generate_diff_string',
|
||||
'FuzzyMatchResult'
|
||||
]
|
||||
167
agent/tools/utils/diff.py
Normal file
167
agent/tools/utils/diff.py
Normal file
@@ -0,0 +1,167 @@
|
||||
"""
|
||||
Diff tools for file editing
|
||||
Provides fuzzy matching and diff generation functionality
|
||||
"""
|
||||
|
||||
import difflib
|
||||
import re
|
||||
from typing import Optional, Tuple
|
||||
|
||||
|
||||
def strip_bom(text: str) -> Tuple[str, str]:
    """
    Split a leading BOM (U+FEFF) off the text.

    :param text: Original text
    :return: (BOM, remaining text); the BOM element is an empty string when
             the text does not start with one
    """
    bom = '\ufeff'
    if not text.startswith(bom):
        return '', text
    return bom, text[len(bom):]
|
||||
|
||||
|
||||
def detect_line_ending(text: str) -> str:
    """
    Report which line ending the text uses.

    :param text: Text content
    :return: CRLF if the text contains at least one CRLF pair, otherwise LF
    """
    return '\r\n' if '\r\n' in text else '\n'
|
||||
|
||||
|
||||
def normalize_to_lf(text: str) -> str:
    """
    Convert CRLF and lone CR line endings to LF.

    :param text: Original text
    :return: Text in which every line ending is LF
    """
    # CRLF must be collapsed first so it does not turn into two LFs.
    without_crlf = text.replace('\r\n', '\n')
    return without_crlf.replace('\r', '\n')
|
||||
|
||||
|
||||
def restore_line_endings(text: str, original_ending: str) -> str:
    """
    Re-apply the original line ending to LF-normalized text.

    :param text: LF normalized text
    :param original_ending: Original line ending
    :return: Text with every LF turned into CRLF when the original ending
             was CRLF; otherwise the text unchanged
    """
    if original_ending != '\r\n':
        return text
    return '\r\n'.join(text.split('\n'))
|
||||
|
||||
|
||||
def normalize_for_fuzzy_match(text: str) -> str:
    """
    Normalize whitespace so that texts can be compared leniently

    The following happens, in this order:

    - every run of spaces/tabs collapses to one space; this applies to
      indentation as well, so an indent of spaces/tabs shrinks to one space
    - spaces directly in front of a newline are dropped (spaces at the very
      end of the text are kept)
    - per line, each remaining leading whitespace character is replaced by
      a space, and lines holding only whitespace become empty

    :param text: Original text
    :return: Normalized text
    """
    collapsed = re.sub(r'[ \t]+', ' ', text)
    collapsed = re.sub(r' +\n', '\n', collapsed)

    def _normalize_line(raw_line: str) -> str:
        body = raw_line.lstrip()
        if not body:
            return ''
        return ' ' * (len(raw_line) - len(body)) + body

    return '\n'.join(_normalize_line(raw_line) for raw_line in collapsed.split('\n'))
|
||||
|
||||
|
||||
class FuzzyMatchResult:
    """Outcome of looking up a piece of text, exactly or fuzzily"""

    def __init__(self, found: bool, index: int = -1, match_length: int = 0, content_for_replacement: str = ""):
        # Whether the text was located, and at which offset (-1 by default)
        self.found, self.index = found, index
        # Length of the matched span, and the content that index/length refer to
        self.match_length, self.content_for_replacement = match_length, content_for_replacement
|
||||
|
||||
|
||||
def fuzzy_find_text(content: str, old_text: str) -> FuzzyMatchResult:
    """
    Locate old_text inside content, exactly if possible, otherwise fuzzily

    The exact search runs on the texts as given. Only when it fails are
    both texts passed through normalize_for_fuzzy_match and searched again;
    in that case the reported index and length refer to the normalized
    content, which is returned as content_for_replacement.

    :param content: Content to search in
    :param old_text: Text to find
    :return: Match result (found=False when neither search succeeds)
    """
    exact_pos = content.find(old_text)
    if exact_pos >= 0:
        return FuzzyMatchResult(
            found=True,
            index=exact_pos,
            match_length=len(old_text),
            content_for_replacement=content
        )

    # Fall back to whitespace-insensitive comparison
    normalized_content = normalize_for_fuzzy_match(content)
    normalized_needle = normalize_for_fuzzy_match(old_text)
    fuzzy_pos = normalized_content.find(normalized_needle)
    if fuzzy_pos < 0:
        return FuzzyMatchResult(found=False)

    return FuzzyMatchResult(
        found=True,
        index=fuzzy_pos,
        match_length=len(normalized_needle),
        content_for_replacement=normalized_content
    )
|
||||
|
||||
|
||||
def generate_diff_string(old_content: str, new_content: str) -> dict:
    """
    Generate unified diff string

    :param old_content: Old content
    :param new_content: New content
    :return: Dictionary with 'diff' (the unified diff text, '' when the
             contents are equal) and 'first_changed_line' (1-based line
             number, counted in the new content, of the first added or
             removed line; None when nothing changed)
    """
    old_lines = old_content.split('\n')
    new_lines = new_content.split('\n')

    # Generate unified diff
    diff_lines = list(difflib.unified_diff(
        old_lines,
        new_lines,
        lineterm='',
        fromfile='original',
        tofile='modified'
    ))

    # Find first changed line number. A hunk header only gives the line where
    # the hunk starts, which is usually a context line, so walk the first hunk
    # and count new-file lines until the first '+' or '-' entry.
    first_changed_line = None
    new_line_no = None
    for line in diff_lines:
        if new_line_no is None:
            # Still in the '---' / '+++' preamble: wait for "@@ -a,b +c,d @@"
            match = re.match(r'@@ -\d+(?:,\d+)? \+(\d+)', line)
            if match:
                new_line_no = int(match.group(1))
            continue
        if line.startswith(('+', '-')):
            first_changed_line = new_line_no
            break
        # Context line: present in the new content, so advance the counter
        new_line_no += 1

    diff_string = '\n'.join(diff_lines)

    return {
        'diff': diff_string,
        'first_changed_line': first_changed_line
    }
|
||||
292
agent/tools/utils/truncate.py
Normal file
292
agent/tools/utils/truncate.py
Normal file
@@ -0,0 +1,292 @@
|
||||
"""
|
||||
Shared truncation utilities for tool outputs.
|
||||
|
||||
Truncation is based on two independent limits - whichever is hit first wins:
|
||||
- Line limit (default: 2000 lines)
|
||||
- Byte limit (default: 50KB)
|
||||
|
||||
Never returns partial lines (except bash tail truncation edge case).
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, Optional, Literal
|
||||
|
||||
|
||||
DEFAULT_MAX_LINES = 2000
|
||||
DEFAULT_MAX_BYTES = 50 * 1024 # 50KB
|
||||
GREP_MAX_LINE_LENGTH = 500 # Max chars per grep match line
|
||||
|
||||
|
||||
class TruncationResult:
    """Outcome of a truncation pass together with its statistics"""

    def __init__(
        self,
        content: str,
        truncated: bool,
        truncated_by: Optional[Literal["lines", "bytes"]],
        total_lines: int,
        total_bytes: int,
        output_lines: int,
        output_bytes: int,
        last_line_partial: bool = False,
        first_line_exceeds_limit: bool = False,
        max_lines: int = DEFAULT_MAX_LINES,
        max_bytes: int = DEFAULT_MAX_BYTES
    ):
        # Kept text, whether anything was cut, and which limit caused the cut
        self.content, self.truncated, self.truncated_by = content, truncated, truncated_by
        # Size of the input
        self.total_lines, self.total_bytes = total_lines, total_bytes
        # Size of the kept text
        self.output_lines, self.output_bytes = output_lines, output_bytes
        # Edge-case flags set by the truncation functions
        self.last_line_partial = last_line_partial
        self.first_line_exceeds_limit = first_line_exceeds_limit
        # Limits that were in effect
        self.max_lines, self.max_bytes = max_lines, max_bytes

    def to_dict(self) -> Dict[str, Any]:
        """Return all attributes as a plain dictionary keyed by attribute name"""
        field_names = (
            "content",
            "truncated",
            "truncated_by",
            "total_lines",
            "total_bytes",
            "output_lines",
            "output_bytes",
            "last_line_partial",
            "first_line_exceeds_limit",
            "max_lines",
            "max_bytes",
        )
        return {field: getattr(self, field) for field in field_names}
|
||||
|
||||
|
||||
def format_size(bytes_count: int) -> str:
    """
    Render a byte count as a short human-readable string

    Values below 1024 are shown as whole bytes ("B"); larger values are
    shown with one decimal in "KB" (below 1024 * 1024) or "MB".

    :param bytes_count: Number of bytes
    :return: Formatted size, e.g. "512B", "1.5KB", "2.0MB"
    """
    kib = 1024
    mib = kib * kib
    if bytes_count >= mib:
        return f"{bytes_count / mib:.1f}MB"
    if bytes_count >= kib:
        return f"{bytes_count / kib:.1f}KB"
    return f"{bytes_count}B"
|
||||
|
||||
|
||||
def truncate_head(content: str, max_lines: Optional[int] = None, max_bytes: Optional[int] = None) -> TruncationResult:
    """
    Keep the beginning of the content (first N lines/bytes).
    Intended for file reads, where the start of the file matters most.

    Only whole lines are kept. When already the first line is larger than
    the byte limit, the result has empty content and
    first_line_exceeds_limit=True.

    :param content: Content to truncate
    :param max_lines: Maximum number of lines (None: DEFAULT_MAX_LINES)
    :param max_bytes: Maximum number of bytes (None: DEFAULT_MAX_BYTES)
    :return: Truncation result
    """
    line_cap = DEFAULT_MAX_LINES if max_lines is None else max_lines
    byte_cap = DEFAULT_MAX_BYTES if max_bytes is None else max_bytes

    lines = content.split('\n')
    total_lines = len(lines)
    total_bytes = len(content.encode('utf-8'))

    # Values that are identical in every result returned below
    shared = dict(
        total_lines=total_lines,
        total_bytes=total_bytes,
        last_line_partial=False,
        max_lines=line_cap,
        max_bytes=byte_cap,
    )

    # Everything fits: hand the content back untouched
    if total_lines <= line_cap and total_bytes <= byte_cap:
        return TruncationResult(
            content=content,
            truncated=False,
            truncated_by=None,
            output_lines=total_lines,
            output_bytes=total_bytes,
            first_line_exceeds_limit=False,
            **shared
        )

    # Not even the first line fits into the byte budget
    if len(lines[0].encode('utf-8')) > byte_cap:
        return TruncationResult(
            content="",
            truncated=True,
            truncated_by="bytes",
            output_lines=0,
            output_bytes=0,
            first_line_exceeds_limit=True,
            **shared
        )

    # Take whole lines from the top until one of the limits is reached
    kept = []
    used_bytes = 0
    reason = "lines"
    for position, line in enumerate(lines):
        if position >= line_cap:
            break

        cost = len(line.encode('utf-8'))
        if position:
            cost += 1  # the newline that joins this line to the previous one

        if used_bytes + cost > byte_cap:
            reason = "bytes"
            break

        kept.append(line)
        used_bytes += cost

    output_content = '\n'.join(kept)

    return TruncationResult(
        content=output_content,
        truncated=True,
        truncated_by=reason,
        output_lines=len(kept),
        output_bytes=len(output_content.encode('utf-8')),
        first_line_exceeds_limit=False,
        **shared
    )
|
||||
|
||||
|
||||
def truncate_tail(content: str, max_lines: Optional[int] = None, max_bytes: Optional[int] = None) -> TruncationResult:
    """
    Keep the end of the content (last N lines/bytes).
    Intended for bash output, where the ending (errors, final results)
    matters most.

    Whole lines are kept, with one exception: when the very last line alone
    is larger than the byte limit, only its trailing part is returned and
    last_line_partial is set.

    :param content: Content to truncate
    :param max_lines: Maximum lines (None: DEFAULT_MAX_LINES)
    :param max_bytes: Maximum bytes (None: DEFAULT_MAX_BYTES)
    :return: Truncation result
    """
    line_cap = DEFAULT_MAX_LINES if max_lines is None else max_lines
    byte_cap = DEFAULT_MAX_BYTES if max_bytes is None else max_bytes

    lines = content.split('\n')
    total_lines = len(lines)
    total_bytes = len(content.encode('utf-8'))

    # Everything fits: hand the content back untouched
    if total_lines <= line_cap and total_bytes <= byte_cap:
        return TruncationResult(
            content=content,
            truncated=False,
            truncated_by=None,
            total_lines=total_lines,
            total_bytes=total_bytes,
            output_lines=total_lines,
            output_bytes=total_bytes,
            last_line_partial=False,
            first_line_exceeds_limit=False,
            max_lines=line_cap,
            max_bytes=byte_cap
        )

    # Walk from the last line upwards; lines are collected newest-first and
    # put back into document order once the walk is over
    kept_newest_first = []
    used_bytes = 0
    reason = "lines"
    partial = False

    for line in reversed(lines):
        if len(kept_newest_first) >= line_cap:
            break

        # A joining newline is only needed once a line has been collected
        cost = len(line.encode('utf-8')) + (1 if kept_newest_first else 0)

        if used_bytes + cost > byte_cap:
            reason = "bytes"
            if not kept_newest_first:
                # Nothing collected yet and this single line is too large:
                # keep as much of its end as the byte limit allows
                tail_piece = _truncate_string_to_bytes_from_end(line, byte_cap)
                kept_newest_first.append(tail_piece)
                used_bytes = len(tail_piece.encode('utf-8'))
                partial = True
            break

        kept_newest_first.append(line)
        used_bytes += cost

    # The line limit takes precedence when it has been reached within budget
    if len(kept_newest_first) >= line_cap and used_bytes <= byte_cap:
        reason = "lines"

    kept = kept_newest_first[::-1]
    output_content = '\n'.join(kept)

    return TruncationResult(
        content=output_content,
        truncated=True,
        truncated_by=reason,
        total_lines=total_lines,
        total_bytes=total_bytes,
        output_lines=len(kept),
        output_bytes=len(output_content.encode('utf-8')),
        last_line_partial=partial,
        first_line_exceeds_limit=False,
        max_lines=line_cap,
        max_bytes=byte_cap
    )
|
||||
|
||||
|
||||
def _truncate_string_to_bytes_from_end(text: str, max_bytes: int) -> str:
|
||||
"""
|
||||
Truncate string to fit byte limit (from end).
|
||||
Properly handles multi-byte UTF-8 characters.
|
||||
|
||||
:param text: String to truncate
|
||||
:param max_bytes: Maximum bytes
|
||||
:return: Truncated string
|
||||
"""
|
||||
encoded = text.encode('utf-8')
|
||||
if len(encoded) <= max_bytes:
|
||||
return text
|
||||
|
||||
# Start from end, skip back maxBytes
|
||||
start = len(encoded) - max_bytes
|
||||
|
||||
# Find valid UTF-8 boundary (character start)
|
||||
while start < len(encoded) and (encoded[start] & 0xC0) == 0x80:
|
||||
start += 1
|
||||
|
||||
return encoded[start:].decode('utf-8', errors='ignore')
|
||||
|
||||
|
||||
def truncate_line(line: str, max_chars: int = GREP_MAX_LINE_LENGTH) -> tuple[str, bool]:
    """
    Cut a single line down to max_chars characters and mark the cut with
    a "... [truncated]" suffix. Used for grep match lines.

    :param line: Line to truncate
    :param max_chars: Maximum characters
    :return: (resulting text, whether it was truncated)
    """
    too_long = len(line) > max_chars
    if too_long:
        return f"{line[:max_chars]}... [truncated]", True
    return line, False
|
||||
3
agent/tools/write/__init__.py
Normal file
3
agent/tools/write/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .write import Write
|
||||
|
||||
__all__ = ['Write']
|
||||
91
agent/tools/write/write.py
Normal file
91
agent/tools/write/write.py
Normal file
@@ -0,0 +1,91 @@
|
||||
"""
|
||||
Write tool - Write file content
|
||||
Creates or overwrites files, automatically creates parent directories
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Dict, Any
|
||||
from pathlib import Path
|
||||
|
||||
from agent.tools.base_tool import BaseTool, ToolResult
|
||||
|
||||
|
||||
class Write(BaseTool):
    """Tool for writing file content"""

    name: str = "write"
    description: str = "Write content to a file. Creates the file if it doesn't exist, overwrites if it does. Automatically creates parent directories."

    params: dict = {
        "type": "object",
        "properties": {
            "path": {
                "type": "string",
                "description": "Path to the file to write (relative or absolute)"
            },
            "content": {
                "type": "string",
                "description": "Content to write to the file"
            }
        },
        "required": ["path", "content"]
    }

    def __init__(self, config: dict = None):
        """
        :param config: Optional tool configuration; the "cwd" entry is the
                       base directory for relative paths (defaults to the
                       process working directory)
        """
        self.config = config or {}
        self.cwd = self.config.get("cwd", os.getcwd())

    def execute(self, args: Dict[str, Any]) -> ToolResult:
        """
        Execute file write operation

        Both arguments are validated before the file system is touched, so
        a rejected call never creates directories or truncates a file.

        :param args: Contains file path ("path") and content ("content")
        :return: Operation result; failures are reported as a failed
                 ToolResult rather than raised
        """
        path = args.get("path", "")
        content = args.get("content", "")

        # A missing, non-string or blank path cannot be written to
        if not isinstance(path, str) or not path.strip():
            return ToolResult.fail("Error: path parameter is required")
        path = path.strip()

        # Reject non-string content up front: open(..., 'w') below truncates
        # an existing file before write() would notice the wrong type
        if not isinstance(content, str):
            return ToolResult.fail("Error: content parameter must be a string")

        # Resolve path
        absolute_path = self._resolve_path(path)

        try:
            # Create parent directory (if needed)
            parent_dir = os.path.dirname(absolute_path)
            if parent_dir:
                os.makedirs(parent_dir, exist_ok=True)

            # Write file
            with open(absolute_path, 'w', encoding='utf-8') as f:
                f.write(content)

            # Report the size of the content as UTF-8 bytes
            bytes_written = len(content.encode('utf-8'))

            result = {
                "message": f"Successfully wrote {bytes_written} bytes to {path}",
                "path": path,
                "bytes_written": bytes_written
            }

            return ToolResult.success(result)

        except PermissionError:
            return ToolResult.fail(f"Error: Permission denied writing to {path}")
        except Exception as e:
            return ToolResult.fail(f"Error writing file: {str(e)}")

    def _resolve_path(self, path: str) -> str:
        """
        Resolve path to absolute path

        :param path: Relative or absolute path; a leading ~ is expanded
        :return: Absolute path (relative paths are resolved against self.cwd)
        """
        # Expand ~ to user home directory
        path = os.path.expanduser(path)
        if os.path.isabs(path):
            return path
        return os.path.abspath(os.path.join(self.cwd, path))
|
||||
@@ -1,222 +0,0 @@
|
||||
import re
|
||||
import time
|
||||
import json
|
||||
import uuid
|
||||
from curl_cffi import requests
|
||||
from bot.bot import Bot
|
||||
from bot.claude.claude_ai_session import ClaudeAiSession
|
||||
from bot.openai.open_ai_image import OpenAIImage
|
||||
from bot.session_manager import SessionManager
|
||||
from bridge.context import Context, ContextType
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from common.log import logger
|
||||
from config import conf
|
||||
|
||||
|
||||
class ClaudeAIBot(Bot, OpenAIImage):
    """Bot that talks to the claude.ai web endpoints using a browser cookie"""

    def __init__(self):
        super().__init__()
        self.sessions = SessionManager(ClaudeAiSession, model=conf().get("model") or "gpt-3.5-turbo")
        self.claude_api_cookie = conf().get("claude_api_cookie")
        self.proxy = conf().get("proxy")
        # Maps a session id to the conversation uuid created for it
        self.con_uuid_dic = {}
        if self.proxy:
            self.proxies = {
                "http": self.proxy,
                "https": self.proxy
            }
        else:
            self.proxies = None
        # Last user-facing error set by get_organization_id()
        self.error = ""
        self.org_uuid = self.get_organization_id()

    def generate_uuid(self):
        """Return a new random UUID in its canonical 8-4-4-4-12 text form"""
        # str(uuid4()) already has the canonical layout; no re-slicing needed
        return str(uuid.uuid4())

    def reply(self, query, context: Context = None) -> Reply:
        """
        Answer a query according to the context type

        :param query: Prompt text
        :param context: Conversation context; TEXT and IMAGE_CREATE are handled
        :return: Reply (an ERROR reply for unsupported context types)
        """
        if context.type == ContextType.TEXT:
            return self._chat(query, context)
        elif context.type == ContextType.IMAGE_CREATE:
            ok, res = self.create_img(query, 0)
            if ok:
                reply = Reply(ReplyType.IMAGE_URL, res)
            else:
                reply = Reply(ReplyType.ERROR, res)
            return reply
        else:
            reply = Reply(ReplyType.ERROR, "Bot不支持处理{}类型的消息".format(context.type))
            return reply

    def get_organization_id(self):
        """
        Fetch the uuid of the first organization of the account

        :return: Organization uuid, or None on failure (self.error is set
                 when the failure cause is recognized in the response)
        """
        url = "https://claude.ai/api/organizations"
        headers = {
            'User-Agent':
                'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:109.0) Gecko/20100101 Firefox/115.0',
            'Accept-Language': 'en-US,en;q=0.5',
            'Referer': 'https://claude.ai/chats',
            'Content-Type': 'application/json',
            'Sec-Fetch-Dest': 'empty',
            'Sec-Fetch-Mode': 'cors',
            'Sec-Fetch-Site': 'same-origin',
            'Connection': 'keep-alive',
            'Cookie': f'{self.claude_api_cookie}'
        }
        # Stays None when the request itself fails, so the handler below can
        # tell "no response at all" apart from "unexpected response body"
        response = None
        try:
            response = requests.get(url, headers=headers, impersonate="chrome110", proxies=self.proxies, timeout=400)
            res = json.loads(response.text)
            return res[0]['uuid']
        except Exception as e:
            body = response.text if response is not None else ""
            if "App unavailable" in body:
                logger.error("IP error: The IP is not allowed to be used on Claude")
                self.error = "ip所在地区不被claude支持"
            elif "Invalid authorization" in body:
                logger.error("Cookie error: Invalid authorization of claude, check cookie please.")
                self.error = "无法通过claude身份验证,请检查cookie"
            else:
                logger.error(f"[CLAUDEAI] failed to get organization id: {e}")
            return None

    def conversation_share_check(self, session_id):
        """
        Return the conversation uuid to use for a session

        A non-empty "claude_uuid" config value is used for every session;
        otherwise one conversation is created per session id and reused.

        :param session_id: Session id
        :return: Conversation uuid
        """
        configured_uuid = conf().get("claude_uuid")
        if configured_uuid is not None and configured_uuid != "":
            return configured_uuid
        if session_id not in self.con_uuid_dic:
            self.con_uuid_dic[session_id] = self.generate_uuid()
            self.create_new_chat(self.con_uuid_dic[session_id])
        return self.con_uuid_dic[session_id]

    def check_cookie(self):
        """Re-query the organization id; None means the cookie/IP check failed"""
        return self.get_organization_id()

    def create_new_chat(self, con_uuid):
        """
        Create a new Claude conversation entity

        :param con_uuid: Conversation id
        :return: JSON of the newly created conversation information
        """
        url = f"https://claude.ai/api/organizations/{self.org_uuid}/chat_conversations"
        payload = json.dumps({"uuid": con_uuid, "name": ""})
        headers = {
            'User-Agent':
                'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:109.0) Gecko/20100101 Firefox/115.0',
            'Accept-Language': 'en-US,en;q=0.5',
            'Referer': 'https://claude.ai/chats',
            'Content-Type': 'application/json',
            'Origin': 'https://claude.ai',
            'DNT': '1',
            'Connection': 'keep-alive',
            'Cookie': self.claude_api_cookie,
            'Sec-Fetch-Dest': 'empty',
            'Sec-Fetch-Mode': 'cors',
            'Sec-Fetch-Site': 'same-origin',
            'TE': 'trailers'
        }
        response = requests.post(url, headers=headers, data=payload, impersonate="chrome110", proxies=self.proxies, timeout=400)
        # Returns JSON of the newly created conversation information
        return response.json()

    def _chat(self, query, context, retry_count=0) -> Reply:
        """
        Send a chat request

        :param query: Prompt text
        :param context: Conversation context
        :param retry_count: Current recursive retry count
        :return: Reply
        """
        if retry_count >= 2:
            # exit from retry 2 times
            logger.warn("[CLAUDEAI] failed after maximum number of retry times")
            return Reply(ReplyType.ERROR, "请再问我一次吧")

        try:
            session_id = context["session_id"]
            if self.org_uuid is None:
                return Reply(ReplyType.ERROR, self.error)

            session = self.sessions.session_query(query, session_id)
            con_uuid = self.conversation_share_check(session_id)

            model = conf().get("model") or "gpt-3.5-turbo"
            # remove system message
            if session.messages[0].get("role") == "system":
                if model == "wenxin" or model == "claude":
                    session.messages.pop(0)
            logger.info(f"[CLAUDEAI] query={query}")

            # do http request
            base_url = "https://claude.ai"
            payload = json.dumps({
                "completion": {
                    "prompt": f"{query}",
                    "timezone": "Asia/Kolkata",
                    "model": "claude-2"
                },
                "organization_uuid": f"{self.org_uuid}",
                "conversation_uuid": f"{con_uuid}",
                "text": f"{query}",
                "attachments": []
            })
            headers = {
                'User-Agent':
                    'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:109.0) Gecko/20100101 Firefox/115.0',
                'Accept': 'text/event-stream, text/event-stream',
                'Accept-Language': 'en-US,en;q=0.5',
                'Referer': 'https://claude.ai/chats',
                'Content-Type': 'application/json',
                'Origin': 'https://claude.ai',
                'DNT': '1',
                'Connection': 'keep-alive',
                'Cookie': f'{self.claude_api_cookie}',
                'Sec-Fetch-Dest': 'empty',
                'Sec-Fetch-Mode': 'cors',
                'Sec-Fetch-Site': 'same-origin',
                'TE': 'trailers'
            }

            res = requests.post(base_url + "/api/append_message", headers=headers, data=payload, impersonate="chrome110", proxies=self.proxies, timeout=400)
            # NOTE(review): "pemission" looks like a typo for "permission";
            # left untouched because the server's actual wording is not
            # visible here - confirm before changing.
            if res.status_code == 200 or "pemission" in res.text:
                # execute success: the body is a stream of "data: {json}" lines
                decoded_data = res.content.decode("utf-8")
                decoded_data = re.sub('\n+', '\n', decoded_data).strip()
                data_strings = decoded_data.split('\n')
                completions = []
                for data_string in data_strings:
                    # Drop the 6-character "data: " prefix
                    json_str = data_string[6:].strip()
                    data = json.loads(json_str)
                    if 'completion' in data:
                        completions.append(data['completion'])

                reply_content = ''.join(completions)

                if "rate limi" in reply_content:
                    logger.error("rate limit error: The conversation has reached the system speed limit and is synchronized with Cladue. Please go to the official website to check the lifting time")
                    return Reply(ReplyType.ERROR, "对话达到系统速率限制,与cladue同步,请进入官网查看解除限制时间")
                logger.info(f"[CLAUDE] reply={reply_content}, total_tokens=invisible")
                self.sessions.session_reply(reply_content, session_id, 100)
                return Reply(ReplyType.TEXT, reply_content)
            else:
                flag = self.check_cookie()
                if flag is None:
                    return Reply(ReplyType.ERROR, self.error)

                response = res.json()
                # A body without an "error" object must not raise here: the
                # exception would land in the retry path and resend the query
                error = response.get("error") or {}
                logger.error(f"[CLAUDE] chat failed, status_code={res.status_code}, "
                             f"msg={error.get('message')}, type={error.get('type')}, detail: {res.text}, uuid: {con_uuid}")

                if res.status_code >= 500:
                    # server error, need retry
                    time.sleep(2)
                    logger.warn(f"[CLAUDE] do retry, times={retry_count}")
                    return self._chat(query, context, retry_count + 1)
                return Reply(ReplyType.ERROR, "提问太快啦,请休息一下再问我吧")

        except Exception as e:
            logger.exception(e)
            # retry
            time.sleep(2)
            logger.warn(f"[CLAUDE] do retry, times={retry_count}")
            return self._chat(query, context, retry_count + 1)
|
||||
@@ -1,9 +0,0 @@
|
||||
from bot.session_manager import Session
|
||||
|
||||
|
||||
class ClaudeAiSession(Session):
    """Session that additionally records the model name it was created for."""

    def __init__(self, session_id, system_prompt=None, model="claude"):
        super().__init__(session_id, system_prompt)
        self.model = model
        # The reverse-engineered Claude endpoint does not support a role prompt
        # self.reset()
|
||||
@@ -1,19 +1,18 @@
|
||||
# encoding:utf-8
|
||||
|
||||
import json
|
||||
import time
|
||||
|
||||
import openai
|
||||
import openai.error
|
||||
import anthropic
|
||||
import requests
|
||||
|
||||
from bot.baidu.baidu_wenxin_session import BaiduWenxinSession
|
||||
from bot.bot import Bot
|
||||
from bot.openai.open_ai_image import OpenAIImage
|
||||
from bot.baidu.baidu_wenxin_session import BaiduWenxinSession
|
||||
from bot.session_manager import SessionManager
|
||||
from bridge.context import ContextType
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from common.log import logger
|
||||
from common import const
|
||||
from common.log import logger
|
||||
from config import conf
|
||||
|
||||
user_session = dict()
|
||||
@@ -23,13 +22,9 @@ user_session = dict()
|
||||
class ClaudeAPIBot(Bot, OpenAIImage):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
proxy = conf().get("proxy", None)
|
||||
base_url = conf().get("open_ai_api_base", None) # 复用"open_ai_api_base"参数作为base_url
|
||||
self.claudeClient = anthropic.Anthropic(
|
||||
api_key=conf().get("claude_api_key"),
|
||||
proxies=proxy if proxy else None,
|
||||
base_url=base_url if base_url else None
|
||||
)
|
||||
self.api_key = conf().get("claude_api_key")
|
||||
self.api_base = conf().get("open_ai_api_base") or "https://api.anthropic.com/v1"
|
||||
self.proxy = conf().get("proxy", None)
|
||||
self.sessions = SessionManager(BaiduWenxinSession, model=conf().get("model") or "text-davinci-003")
|
||||
|
||||
def reply(self, query, context=None):
|
||||
@@ -73,39 +68,104 @@ class ClaudeAPIBot(Bot, OpenAIImage):
|
||||
reply = Reply(ReplyType.ERROR, retstring)
|
||||
return reply
|
||||
|
||||
def reply_text(self, session: BaiduWenxinSession, retry_count=0):
|
||||
def reply_text(self, session: BaiduWenxinSession, retry_count=0, tools=None):
|
||||
try:
|
||||
actual_model = self._model_mapping(conf().get("model"))
|
||||
response = self.claudeClient.messages.create(
|
||||
model=actual_model,
|
||||
max_tokens=4096,
|
||||
system=conf().get("character_desc", ""),
|
||||
messages=session.messages
|
||||
|
||||
# Prepare headers
|
||||
headers = {
|
||||
"x-api-key": self.api_key,
|
||||
"anthropic-version": "2023-06-01",
|
||||
"content-type": "application/json"
|
||||
}
|
||||
|
||||
# Extract system prompt if present and prepare Claude-compatible messages
|
||||
system_prompt = conf().get("character_desc", "")
|
||||
claude_messages = []
|
||||
|
||||
for msg in session.messages:
|
||||
if msg.get("role") == "system":
|
||||
system_prompt = msg["content"]
|
||||
else:
|
||||
claude_messages.append(msg)
|
||||
|
||||
# Prepare request data
|
||||
data = {
|
||||
"model": actual_model,
|
||||
"messages": claude_messages,
|
||||
"max_tokens": self._get_max_tokens(actual_model)
|
||||
}
|
||||
|
||||
if system_prompt:
|
||||
data["system"] = system_prompt
|
||||
|
||||
if tools:
|
||||
data["tools"] = tools
|
||||
|
||||
# Make HTTP request
|
||||
proxies = {"http": self.proxy, "https": self.proxy} if self.proxy else None
|
||||
response = requests.post(
|
||||
f"{self.api_base}/messages",
|
||||
headers=headers,
|
||||
json=data,
|
||||
proxies=proxies
|
||||
)
|
||||
# response = openai.Completion.create(prompt=str(session), **self.args)
|
||||
res_content = response.content[0].text.strip().replace("<|endoftext|>", "")
|
||||
total_tokens = response.usage.input_tokens+response.usage.output_tokens
|
||||
completion_tokens = response.usage.output_tokens
|
||||
|
||||
if response.status_code != 200:
|
||||
raise Exception(f"API request failed: {response.status_code} - {response.text}")
|
||||
|
||||
claude_response = response.json()
|
||||
# Handle response content and tool calls
|
||||
res_content = ""
|
||||
tool_calls = []
|
||||
|
||||
content_blocks = claude_response.get("content", [])
|
||||
for block in content_blocks:
|
||||
if block.get("type") == "text":
|
||||
res_content += block.get("text", "")
|
||||
elif block.get("type") == "tool_use":
|
||||
tool_calls.append({
|
||||
"id": block.get("id", ""),
|
||||
"name": block.get("name", ""),
|
||||
"arguments": block.get("input", {})
|
||||
})
|
||||
|
||||
res_content = res_content.strip().replace("<|endoftext|>", "")
|
||||
usage = claude_response.get("usage", {})
|
||||
total_tokens = usage.get("input_tokens", 0) + usage.get("output_tokens", 0)
|
||||
completion_tokens = usage.get("output_tokens", 0)
|
||||
|
||||
logger.info("[CLAUDE_API] reply={}".format(res_content))
|
||||
return {
|
||||
if tool_calls:
|
||||
logger.info("[CLAUDE_API] tool_calls={}".format(tool_calls))
|
||||
|
||||
result = {
|
||||
"total_tokens": total_tokens,
|
||||
"completion_tokens": completion_tokens,
|
||||
"content": res_content,
|
||||
}
|
||||
|
||||
if tool_calls:
|
||||
result["tool_calls"] = tool_calls
|
||||
|
||||
return result
|
||||
except Exception as e:
|
||||
need_retry = retry_count < 2
|
||||
result = {"total_tokens": 0, "completion_tokens": 0, "content": "我现在有点累了,等会再来吧"}
|
||||
if isinstance(e, openai.error.RateLimitError):
|
||||
|
||||
# Handle different types of errors
|
||||
error_str = str(e).lower()
|
||||
if "rate" in error_str or "limit" in error_str:
|
||||
logger.warn("[CLAUDE_API] RateLimitError: {}".format(e))
|
||||
result["content"] = "提问太快啦,请休息一下再问我吧"
|
||||
if need_retry:
|
||||
time.sleep(20)
|
||||
elif isinstance(e, openai.error.Timeout):
|
||||
elif "timeout" in error_str:
|
||||
logger.warn("[CLAUDE_API] Timeout: {}".format(e))
|
||||
result["content"] = "我没有收到你的消息"
|
||||
if need_retry:
|
||||
time.sleep(5)
|
||||
elif isinstance(e, openai.error.APIConnectionError):
|
||||
elif "connection" in error_str or "network" in error_str:
|
||||
logger.warn("[CLAUDE_API] APIConnectionError: {}".format(e))
|
||||
need_retry = False
|
||||
result["content"] = "我连接不到你的网络"
|
||||
@@ -116,7 +176,7 @@ class ClaudeAPIBot(Bot, OpenAIImage):
|
||||
|
||||
if need_retry:
|
||||
logger.warn("[CLAUDE_API] 第{}次重试".format(retry_count + 1))
|
||||
return self.reply_text(session, retry_count + 1)
|
||||
return self.reply_text(session, retry_count + 1, tools)
|
||||
else:
|
||||
return result
|
||||
|
||||
@@ -130,3 +190,288 @@ class ClaudeAPIBot(Bot, OpenAIImage):
|
||||
elif model == "claude-3.5-sonnet":
|
||||
return const.CLAUDE_35_SONNET
|
||||
return model
|
||||
|
||||
def _get_max_tokens(self, model: str) -> int:
|
||||
"""
|
||||
Get max_tokens for the model.
|
||||
Reference from pi-mono:
|
||||
- Claude 3.5/3.7: 8192
|
||||
- Claude 3 Opus: 4096
|
||||
- Default: 8192
|
||||
"""
|
||||
if model and (model.startswith("claude-3-5") or model.startswith("claude-3-7")):
|
||||
return 8192
|
||||
elif model and model.startswith("claude-3") and "opus" in model:
|
||||
return 4096
|
||||
elif model and (model.startswith("claude-sonnet-4") or model.startswith("claude-opus-4")):
|
||||
return 64000
|
||||
return 8192
|
||||
|
||||
def call_with_tools(self, messages, tools=None, stream=False, **kwargs):
    """
    Call Claude API with tool support for agent integration

    Args:
        messages: List of message dicts. Entries whose ``role`` is
            ``"system"`` are removed from the list and used as the system
            prompt instead (the last one wins).
        tools: List of tool definitions, forwarded as-is when non-empty
        stream: Whether to use streaming
        **kwargs: Additional parameters. ``system`` supplies the system
            prompt (default: the ``character_desc`` config value) and
            ``max_tokens`` overrides the per-model default.

    Returns:
        Formatted response compatible with OpenAI format, or a generator of
        chunks when ``stream`` is true. On failure an error payload
        ``{"error": True, "message": ..., "status_code": 500}`` is returned
        (wrapped in a one-item generator when streaming) instead of raising.
    """
    actual_model = self._model_mapping(conf().get("model"))

    # Claude takes the system prompt as a top-level field, not as a message,
    # so system-role messages are pulled out of the conversation.
    system_prompt = kwargs.get("system", conf().get("character_desc", ""))
    claude_messages = []

    for msg in messages:
        if msg.get("role") == "system":
            system_prompt = msg["content"]
        else:
            claude_messages.append(msg)

    request_params = {
        "model": actual_model,
        "max_tokens": kwargs.get("max_tokens", self._get_max_tokens(actual_model)),
        "messages": claude_messages,
        "stream": stream
    }

    if system_prompt:
        request_params["system"] = system_prompt

    if tools:
        request_params["tools"] = tools

    try:
        if stream:
            return self._handle_stream_response(request_params)
        return self._handle_sync_response(request_params)
    except Exception as e:
        logger.error(f"Claude API call error: {e}")

        # Build the payload now: Python unbinds `e` when this except block
        # ends, and a generator body only runs once it is iterated, so the
        # generator must not reference `e` itself.
        error_payload = {
            "error": True,
            "message": str(e),
            "status_code": 500
        }

        if stream:
            # Return error generator for stream
            def error_generator():
                yield error_payload

            return error_generator()

        # Return error response for sync
        return error_payload
|
||||
|
||||
def _handle_sync_response(self, request_params):
    """
    Send one blocking request to the Claude messages endpoint and convert
    the reply into an OpenAI-style ``chat.completion`` dict.

    Args:
        request_params: JSON body to post.

    Returns:
        Dict with ``id``, ``object``, ``created``, ``model``, ``choices``
        (a single assistant message, with ``tool_calls`` when the reply
        contains ``tool_use`` blocks) and ``usage``.

    Raises:
        Exception: when the HTTP status code is not 200.
    """
    headers = {
        "x-api-key": self.api_key,
        "anthropic-version": "2023-06-01",
        "content-type": "application/json"
    }

    proxies = None
    if self.proxy:
        proxies = {"http": self.proxy, "https": self.proxy}

    response = requests.post(
        f"{self.api_base}/messages",
        headers=headers,
        json=request_params,
        proxies=proxies
    )

    if response.status_code != 200:
        raise Exception(f"API request failed: {response.status_code} - {response.text}")

    payload = response.json()

    # Walk the content blocks once: text is concatenated, tool_use blocks
    # become OpenAI-style function calls with JSON-encoded arguments.
    text_content = ""
    tool_calls = []
    for block in payload.get("content", []):
        block_type = block.get("type")
        if block_type == "text":
            text_content += block.get("text", "")
        elif block_type == "tool_use":
            function_spec = {
                "name": block.get("name", ""),
                "arguments": json.dumps(block.get("input", {}))
            }
            tool_calls.append({
                "id": block.get("id", ""),
                "type": "function",
                "function": function_spec
            })

    message = {"role": "assistant", "content": text_content}
    if tool_calls:
        message["tool_calls"] = tool_calls

    usage = payload.get("usage", {})
    prompt_tokens = usage.get("input_tokens", 0)
    completion_tokens = usage.get("output_tokens", 0)

    choice = {
        "index": 0,
        "message": message,
        "finish_reason": payload.get("stop_reason", "stop")
    }

    return {
        "id": payload.get("id", ""),
        "object": "chat.completion",
        "created": int(time.time()),
        "model": payload.get("model", request_params["model"]),
        "choices": [choice],
        "usage": {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens
        }
    }
|
||||
|
||||
def _handle_stream_response(self, request_params):
    """
    Handle streaming Claude API response using HTTP requests

    This is a generator. It yields:
    - one ``chat.completion.chunk`` dict per ``text_delta`` event, with the
      text in ``choices[0]["delta"]["content"]``;
    - one chunk per accumulated tool call, emitted when the
      ``message_delta`` event arrives, with the raw JSON argument string in
      ``function["arguments"]``;
    - a single ``{"error": True, "status_code": ..., "message": ...}`` dict
      when the HTTP status is not 200 or an exception occurs.

    Args:
        request_params: JSON body to post. Its ``stream`` key is set to
            ``True`` in place.
    """
    # Prepare headers
    headers = {
        "x-api-key": self.api_key,
        "anthropic-version": "2023-06-01",
        "content-type": "application/json"
    }

    # Add stream parameter
    request_params["stream"] = True

    # Track tool use state
    tool_uses_map = {}  # {content block index: {id, name, input}}
    current_tool_use_index = -1
    # The message id is only carried by the message_start event; remember it
    # so later chunks can report it.
    message_id = ""
    response = None

    try:
        # Make streaming HTTP request
        proxies = {"http": self.proxy, "https": self.proxy} if self.proxy else None
        response = requests.post(
            f"{self.api_base}/messages",
            headers=headers,
            json=request_params,
            proxies=proxies,
            stream=True
        )

        if response.status_code != 200:
            error_text = response.text
            try:
                error_data = json.loads(error_text)
                error_msg = error_data.get("error", {}).get("message", error_text)
            except (ValueError, AttributeError):
                # Body is not JSON, or not shaped like {"error": {...}}
                error_msg = error_text or "Unknown error"

            yield {
                "error": True,
                "status_code": response.status_code,
                "message": error_msg
            }
            return

        # Process streaming response
        for line in response.iter_lines():
            if not line:
                continue
            line = line.decode('utf-8')
            if not line.startswith('data: '):
                continue
            line = line[6:]  # Remove 'data: ' prefix
            if line == '[DONE]':
                break

            try:
                event = json.loads(line)
            except json.JSONDecodeError:
                # Skip payloads that are not valid JSON
                continue

            event_type = event.get("type")
            chunk_id = event.get("id") or message_id

            if event_type == "message_start":
                message_id = (event.get("message") or {}).get("id", "") or ""

            elif event_type == "content_block_start":
                # New content block
                block = event.get("content_block", {})
                if block.get("type") == "tool_use":
                    current_tool_use_index = event.get("index", 0)
                    tool_uses_map[current_tool_use_index] = {
                        "id": block.get("id", ""),
                        "name": block.get("name", ""),
                        "input": ""
                    }

            elif event_type == "content_block_delta":
                delta = event.get("delta", {})
                delta_type = delta.get("type")

                if delta_type == "text_delta":
                    # Text content
                    yield {
                        "id": chunk_id,
                        "object": "chat.completion.chunk",
                        "created": int(time.time()),
                        "model": request_params["model"],
                        "choices": [{
                            "index": 0,
                            "delta": {"content": delta.get("text", "")},
                            "finish_reason": None
                        }]
                    }

                elif delta_type == "input_json_delta":
                    # Tool input accumulation: prefer the block index carried
                    # by the event, falling back to the last opened tool block
                    target_index = event.get("index", current_tool_use_index)
                    if target_index in tool_uses_map:
                        tool_uses_map[target_index]["input"] += delta.get("partial_json", "")

            elif event_type == "message_delta":
                # Message complete - yield tool calls if any
                for idx in sorted(tool_uses_map.keys()):
                    tool_data = tool_uses_map[idx]
                    yield {
                        "id": chunk_id,
                        "object": "chat.completion.chunk",
                        "created": int(time.time()),
                        "model": request_params["model"],
                        "choices": [{
                            "index": 0,
                            "delta": {
                                "tool_calls": [{
                                    "index": idx,
                                    "id": tool_data["id"],
                                    "type": "function",
                                    "function": {
                                        "name": tool_data["name"],
                                        "arguments": tool_data["input"]
                                    }
                                }]
                            },
                            "finish_reason": None
                        }]
                    }

    except requests.RequestException as e:
        logger.error(f"Claude streaming request error: {e}")
        yield {
            "error": True,
            "message": f"Connection error: {str(e)}",
            "status_code": 0
        }
    except Exception as e:
        logger.error(f"Claude streaming error: {e}")
        yield {
            "error": True,
            "message": str(e),
            "status_code": 500
        }
    finally:
        # Release the connection even when the consumer stops iterating early
        if response is not None:
            response.close()
|
||||
|
||||
288
bridge/agent_bridge.py
Normal file
288
bridge/agent_bridge.py
Normal file
@@ -0,0 +1,288 @@
|
||||
"""
|
||||
Agent Bridge - Integrates Agent system with existing COW bridge
|
||||
"""
|
||||
|
||||
from typing import Optional, List
|
||||
|
||||
from agent.protocol import Agent, LLMModel, LLMRequest
|
||||
from agent.tools import Calculator, CurrentTime, Read, Write, Edit, Bash, Grep, Find, Ls
|
||||
from bridge.bridge import Bridge
|
||||
from bridge.context import Context
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from common import const
|
||||
from common.log import logger
|
||||
|
||||
|
||||
class AgentLLMModel(LLMModel):
    """
    LLM Model adapter that uses COW's existing bot infrastructure

    The model name is read from the ``model`` config entry. The bot itself is
    fetched from the bridge on first use and must expose ``call_with_tools``.
    """

    def __init__(self, bridge: Bridge, bot_type: str = "chat"):
        """
        Args:
            bridge: Bridge used to look up the bot.
            bot_type: Bot type name passed to ``bridge.get_bot``.
        """
        # Get model name directly from config
        from config import conf
        model_name = conf().get("model", const.GPT_41)
        super().__init__(model=model_name)
        self.bridge = bridge
        self.bot_type = bot_type
        self._bot = None

    @property
    def bot(self):
        """Lazy load the bot"""
        if self._bot is None:
            self._bot = self.bridge.get_bot(self.bot_type)
        return self._bot

    def call(self, request: LLMRequest):
        """
        Call the model using COW's bot infrastructure (non-streaming)

        Args:
            request: Request carrying ``messages`` and ``max_tokens``, and
                optionally ``tools`` and ``system``.

        Returns:
            Whatever the bot's ``call_with_tools`` returns.

        Raises:
            NotImplementedError: if the bot has no ``call_with_tools``.
            Exception: any error from the bot is logged and re-raised.
        """
        try:
            if not hasattr(self.bot, 'call_with_tools'):
                # Fallback to regular call is not available yet
                raise NotImplementedError("Regular call not implemented yet")

            kwargs = {
                'messages': request.messages,
                'tools': getattr(request, 'tools', None),
                'stream': False
            }
            # Only pass max_tokens if it's explicitly set
            if request.max_tokens is not None:
                kwargs['max_tokens'] = request.max_tokens

            # Forward the system prompt exactly as call_stream does, so both
            # call paths run with the same instructions
            system_prompt = getattr(request, 'system', None)
            if system_prompt:
                kwargs['system'] = system_prompt

            response = self.bot.call_with_tools(**kwargs)
            return self._format_response(response)

        except Exception as e:
            logger.error(f"AgentLLMModel call error: {e}")
            raise

    def call_stream(self, request: LLMRequest):
        """
        Call the model with streaming using COW's bot infrastructure

        This is a generator yielding the bot's stream chunks one by one.

        Args:
            request: Request carrying ``messages`` and ``max_tokens``, and
                optionally ``tools`` and ``system``. A ``max_tokens`` of
                ``None`` is replaced by 4096.

        Raises:
            NotImplementedError: if the bot has no ``call_with_tools``.
            Exception: any error from the bot is logged and re-raised.
        """
        try:
            if not hasattr(self.bot, 'call_with_tools'):
                raise NotImplementedError("Streaming call not implemented yet")

            # Ensure max_tokens is an integer, use default if None
            max_tokens = request.max_tokens if request.max_tokens is not None else 4096

            # Build kwargs for call_with_tools
            kwargs = {
                'messages': request.messages,
                'tools': getattr(request, 'tools', None),
                'stream': True,
                'max_tokens': max_tokens
            }

            # Add system prompt if present
            system_prompt = getattr(request, 'system', None)
            if system_prompt:
                kwargs['system'] = system_prompt

            stream = self.bot.call_with_tools(**kwargs)

            # Pass each chunk through the formatting hook
            for chunk in stream:
                yield self._format_stream_chunk(chunk)

        except Exception as e:
            logger.error(f"AgentLLMModel call_stream error: {e}")
            raise

    def _format_response(self, response):
        """Formatting hook for non-streaming responses; currently returns the response unchanged"""
        return response

    def _format_stream_chunk(self, chunk):
        """Formatting hook for stream chunks; currently returns the chunk unchanged"""
        return chunk
|
||||
|
||||
|
||||
class AgentBridge:
    """
    Bridge class that integrates single super Agent with COW

    Holds at most one ``Agent``; it is built on first use by ``get_agent``.
    """

    def __init__(self, bridge: Bridge):
        """
        Args:
            bridge: COW bridge handed to the LLM adapter.
        """
        self.bridge = bridge
        self.agent: Optional[Agent] = None

    def create_agent(self, system_prompt: str, tools: Optional[List] = None, **kwargs) -> Agent:
        """
        Create the super agent with COW integration

        The new agent replaces any agent previously stored on this bridge.

        Args:
            system_prompt: System prompt
            tools: List of tools (optional); a default tool set is used when
                ``None``
            **kwargs: Additional agent parameters: ``description``,
                ``max_steps`` (default 15) and ``output_mode``
                (default ``"logger"``)

        Returns:
            Agent instance
        """
        # Create LLM model that uses COW's bot infrastructure
        model = AgentLLMModel(self.bridge)

        # Default tools if none provided
        if tools is None:
            tools = [
                Calculator(),
                CurrentTime(),
                Read(),
                Write(),
                Edit(),
                Bash(),
                Grep(),
                Find(),
                Ls()
            ]

        # Create the single super agent
        self.agent = Agent(
            system_prompt=system_prompt,
            description=kwargs.get("description", "AI Super Agent"),
            model=model,
            tools=tools,
            max_steps=kwargs.get("max_steps", 15),
            output_mode=kwargs.get("output_mode", "logger")
        )

        return self.agent

    def get_agent(self) -> Optional[Agent]:
        """Get the super agent, create if not exists"""
        if self.agent is None:
            self._init_default_agent()
        return self.agent

    def _init_default_agent(self):
        """
        Initialize default super agent with config and memory

        The system prompt is the ``character_desc`` config value, extended
        with memory guidance when the memory system can be set up. Any
        failure while setting up memory is logged and the agent is built
        without memory features.
        """
        from config import conf
        import os

        # Get base system prompt from config
        base_prompt = conf().get("character_desc", "你是一个AI助手")

        workspace_root = os.path.expanduser("~/cow")

        # Setup memory if enabled
        memory_manager = None
        memory_tools = []

        try:
            # Try to initialize memory system
            from agent.memory import MemoryManager, MemoryConfig
            from agent.tools import MemorySearchTool, MemoryGetTool

            # Create memory config directly with sensible defaults
            memory_config = MemoryConfig(
                workspace_root=workspace_root,
                embedding_provider="local",  # Use local embedding (no API key needed)
                embedding_model="all-MiniLM-L6-v2"
            )

            # Create memory manager with the config
            memory_manager = MemoryManager(memory_config)

            # Create memory tools
            memory_tools = [
                MemorySearchTool(memory_manager),
                MemoryGetTool(memory_manager)
            ]

            # Build memory guidance and add to system prompt
            memory_guidance = memory_manager.build_memory_guidance(
                lang="zh",
                include_context=True
            )
            system_prompt = base_prompt + "\n\n" + memory_guidance

            logger.info("[AgentBridge] Memory system initialized")
            logger.info(f"[AgentBridge] Workspace: {memory_config.get_workspace()}")

        except Exception as e:
            # exc_info routes the traceback through the logger instead of
            # printing it straight to stderr
            logger.warning(f"[AgentBridge] Memory system not available: {e}", exc_info=True)
            logger.info("[AgentBridge] Continuing without memory features")
            # Drop anything built before the failure so the agent really runs
            # without memory, as the log message states
            memory_manager = None
            memory_tools = []
            system_prompt = base_prompt

        logger.info("[AgentBridge] Initializing super agent")

        # Configure file tools to work in the correct workspace
        file_config = {"cwd": workspace_root} if memory_manager else {}

        # Create default tools with workspace config
        tools = [
            Calculator(),
            CurrentTime(),
            Read(config=file_config),
            Write(config=file_config),
            Edit(config=file_config),
            Bash(config=file_config),
            Grep(config=file_config),
            Find(config=file_config),
            Ls(config=file_config)
        ]

        # Create agent with configured tools
        agent = self.create_agent(
            system_prompt=system_prompt,
            tools=tools,
            max_steps=50,
            output_mode="logger"
        )

        # Attach memory manager to agent if available
        if memory_manager:
            agent.memory_manager = memory_manager

        # Add memory tools if available
        if memory_tools:
            for tool in memory_tools:
                agent.add_tool(tool)
            logger.info(f"[AgentBridge] Added {len(memory_tools)} memory tools")

    def agent_reply(self, query: str, context: Context = None,
                    on_event=None, clear_history: bool = False) -> Reply:
        """
        Use super agent to reply to a query

        Args:
            query: User query
            context: COW context (optional); accepted but not used here
            on_event: Event callback (optional)
            clear_history: Whether to clear conversation history

        Returns:
            Reply object: ``ReplyType.TEXT`` with the agent's response, or
            ``ReplyType.ERROR`` when the agent is missing or raises
        """
        try:
            # Get agent (will auto-initialize if needed)
            agent = self.get_agent()
            if not agent:
                return Reply(ReplyType.ERROR, "Failed to initialize super agent")

            # Use agent's run_stream method
            response = agent.run_stream(
                user_message=query,
                on_event=on_event,
                clear_history=clear_history
            )

            return Reply(ReplyType.TEXT, response)

        except Exception as e:
            logger.error(f"Agent reply error: {e}")
            return Reply(ReplyType.ERROR, f"Agent error: {str(e)}")
|
||||
@@ -23,7 +23,7 @@ class Bridge(object):
|
||||
if bot_type:
|
||||
self.btype["chat"] = bot_type
|
||||
else:
|
||||
model_type = conf().get("model") or const.GPT35
|
||||
model_type = conf().get("model") or const.GPT_41_MINI
|
||||
if model_type in ["text-davinci-003"]:
|
||||
self.btype["chat"] = const.OPEN_AI
|
||||
if conf().get("use_azure_chatgpt", False):
|
||||
@@ -64,6 +64,7 @@ class Bridge(object):
|
||||
|
||||
self.bots = {}
|
||||
self.chat_bots = {}
|
||||
self._agent_bridge = None
|
||||
|
||||
# 模型对应的接口
|
||||
def get_bot(self, typename):
|
||||
@@ -104,3 +105,29 @@ class Bridge(object):
|
||||
重置bot路由
|
||||
"""
|
||||
self.__init__()
|
||||
|
||||
def get_agent_bridge(self):
    """
    Return the agent bridge for agent-based conversations.

    The bridge is created on the first call and cached on
    ``self._agent_bridge``; later calls return the cached object.
    """
    if self._agent_bridge is not None:
        return self._agent_bridge

    # Imported at call time rather than at module level
    # (presumably to avoid a circular import - TODO confirm)
    from bridge.agent_bridge import AgentBridge
    self._agent_bridge = AgentBridge(self)
    return self._agent_bridge
|
||||
|
||||
def fetch_agent_reply(self, query: str, context: Context = None,
                      on_event=None, clear_history: bool = False) -> Reply:
    """
    Hand the query to the super agent and return its reply.

    Args:
        query: User query
        context: Context object
        on_event: Event callback for streaming
        clear_history: Whether to clear conversation history

    Returns:
        Reply object produced by the agent bridge
    """
    return self.get_agent_bridge().agent_reply(query, context, on_event, clear_history)
|
||||
|
||||
@@ -5,6 +5,8 @@ Message sending channel abstract class
|
||||
from bridge.bridge import Bridge
|
||||
from bridge.context import Context
|
||||
from bridge.reply import *
|
||||
from common.log import logger
|
||||
from config import conf
|
||||
|
||||
|
||||
class Channel(object):
|
||||
@@ -35,6 +37,29 @@ class Channel(object):
|
||||
raise NotImplementedError
|
||||
|
||||
def build_reply_content(self, query, context: Context = None) -> Reply:
    """
    Build the reply for a query.

    When the ``agent`` config flag is truthy the query goes to the agent
    bridge; if that call raises, the error is logged and the normal reply
    path is used instead. With the flag off, the normal path is used
    directly.
    """
    if not conf().get("agent", False):
        # Normal mode
        return Bridge().fetch_reply_content(query, context)

    try:
        logger.info("[Channel] Using agent mode")
        agent_kwargs = {
            "query": query,
            "context": context,
            "on_event": None,
            "clear_history": False,
        }
        # Use agent bridge to handle the query
        return Bridge().fetch_agent_reply(**agent_kwargs)
    except Exception as e:
        logger.error(f"[Channel] Agent mode failed, fallback to normal mode: {e}")
        # Fallback to normal mode if agent fails
        return Bridge().fetch_reply_content(query, context)
|
||||
|
||||
def build_voice_to_text(self, voice_file) -> Reply:
|
||||
|
||||
@@ -150,9 +150,6 @@ class WebChannel(ChatChannel):
|
||||
Poll for responses using the session_id.
|
||||
"""
|
||||
try:
|
||||
# 不记录轮询请求的日志
|
||||
web.ctx.log_request = False
|
||||
|
||||
data = web.data()
|
||||
json_data = json.loads(data)
|
||||
session_id = json_data.get('session_id')
|
||||
@@ -215,18 +212,19 @@ class WebChannel(ChatChannel):
|
||||
)
|
||||
app = web.application(urls, globals(), autoreload=False)
|
||||
|
||||
# 禁用web.py的默认日志输出
|
||||
import io
|
||||
from contextlib import redirect_stdout
|
||||
# 完全禁用web.py的HTTP日志输出
|
||||
# 创建一个空的日志处理函数
|
||||
def null_log_function(status, environ):
|
||||
pass
|
||||
|
||||
# 配置web.py的日志级别为ERROR,只显示错误
|
||||
# 替换web.py的日志函数
|
||||
web.httpserver.LogMiddleware.log = lambda self, status, environ: None
|
||||
|
||||
# 配置web.py的日志级别为ERROR
|
||||
logging.getLogger("web").setLevel(logging.ERROR)
|
||||
|
||||
# 禁用web.httpserver的日志
|
||||
logging.getLogger("web.httpserver").setLevel(logging.ERROR)
|
||||
|
||||
# 临时重定向标准输出,捕获web.py的启动消息
|
||||
with redirect_stdout(io.StringIO()):
|
||||
# 启动服务器
|
||||
web.httpserver.runsimple(app.wsgifunc(), ("0.0.0.0", port))
|
||||
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
"model": "",
|
||||
"open_ai_api_key": "YOUR API KEY",
|
||||
"claude_api_key": "YOUR API KEY",
|
||||
"claude_api_base": "https://api.anthropic.com",
|
||||
"text_to_image": "dall-e-2",
|
||||
"voice_to_text": "openai",
|
||||
"text_to_voice": "openai",
|
||||
@@ -30,8 +31,9 @@
|
||||
"expires_in_seconds": 3600,
|
||||
"character_desc": "你是基于大语言模型的AI智能助手,旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。",
|
||||
"temperature": 0.7,
|
||||
"subscribe_msg": "感谢您的关注!\n这里是AI智能助手,可以自由对话。\n支持语音对话。\n支持图片输入。\n支持图片输出,画字开头的消息将按要求创作图片。\n支持tool、角色扮演和文字冒险等丰富的插件。\n输入{trigger_prefix}#help 查看详细指令。",
|
||||
"subscribe_msg": "感谢您的关注!\n这里是AI智能助手,可以自由对话。\n支持语音对话。\n支持图片输入。\n支持图片输出,画字开头的消息将按要求创作图片。\n支持tool、角色扮演和文字冒险等丰富的插件。\n输入{trigger_prefix}#help 查看详细指令。",
|
||||
"use_linkai": false,
|
||||
"linkai_api_key": "",
|
||||
"linkai_app_code": ""
|
||||
"linkai_app_code": "",
|
||||
"agent": true
|
||||
}
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
# encoding:utf-8
|
||||
|
||||
import copy
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import pickle
|
||||
import copy
|
||||
|
||||
from common.log import logger
|
||||
|
||||
@@ -183,6 +183,7 @@ available_setting = {
|
||||
"Minimax_group_id": "",
|
||||
"Minimax_base_url": "",
|
||||
"web_port": 9899,
|
||||
"agent": False # 是否开启Agent模式
|
||||
}
|
||||
|
||||
|
||||
|
||||
5
memory/2026-01-29.md
Normal file
5
memory/2026-01-29.md
Normal file
@@ -0,0 +1,5 @@
|
||||
# 2026-01-29 记录
|
||||
|
||||
## 老王的重要决定
|
||||
- 今天老王告诉我他决定要学AI了,这是一个重要的决策
|
||||
- 这可能会是他学习和职业发展的一个转折点
|
||||
21
memory/MEMORY.md
Normal file
21
memory/MEMORY.md
Normal file
@@ -0,0 +1,21 @@
|
||||
# Memory
|
||||
|
||||
Long-term curated memories and preferences.
|
||||
|
||||
## 用户信息
|
||||
- 用户名:老王
|
||||
|
||||
## 用户信息
|
||||
- 用户名:老王
|
||||
|
||||
## 用户偏好
|
||||
- 喜欢吃红烧肉
|
||||
- 爱打篮球
|
||||
|
||||
## 重要决策
|
||||
- 决定要学习AI(2026-01-29)
|
||||
|
||||
## Notes
|
||||
|
||||
- Important decisions and facts go here
|
||||
- This is your long-term knowledge base
|
||||
Reference in New Issue
Block a user