mirror of
https://github.com/zhayujie/chatgpt-on-wechat.git
synced 2026-06-02 00:57:41 +08:00
feat: personal ai agent framework
This commit is contained in:
3
.gitignore
vendored
3
.gitignore
vendored
@@ -35,3 +35,6 @@ plugins/banwords/lib/__pycache__
|
|||||||
!plugins/linkai
|
!plugins/linkai
|
||||||
!plugins/agent
|
!plugins/agent
|
||||||
client_config.json
|
client_config.json
|
||||||
|
ref/
|
||||||
|
.cursor/
|
||||||
|
local/
|
||||||
|
|||||||
10
agent/memory/__init__.py
Normal file
10
agent/memory/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
"""
|
||||||
|
Memory module for AgentMesh
|
||||||
|
|
||||||
|
Provides long-term memory capabilities with hybrid search (vector + keyword)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from agent.memory.manager import MemoryManager
|
||||||
|
from agent.memory.config import MemoryConfig, get_default_memory_config, set_global_memory_config
|
||||||
|
|
||||||
|
__all__ = ['MemoryManager', 'MemoryConfig', 'get_default_memory_config', 'set_global_memory_config']
|
||||||
139
agent/memory/chunker.py
Normal file
139
agent/memory/chunker.py
Normal file
@@ -0,0 +1,139 @@
|
|||||||
|
"""
|
||||||
|
Text chunking utilities for memory
|
||||||
|
|
||||||
|
Splits text into chunks with token limits and overlap
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import List, Tuple
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class TextChunk:
    """A slice of the input text together with the lines it covers.

    ``start_line`` and ``end_line`` are 1-based and inclusive.
    """
    text: str
    start_line: int
    end_line: int


class TextChunker:
    """Split text into line-aligned, overlapping chunks under a token budget.

    Token counts are estimated from character counts (``chars_per_token``
    characters per token); no tokenizer is involved. Newline characters that
    join the lines of a chunk are not counted against the budget.
    """

    def __init__(self, max_tokens: int = 500, overlap_tokens: int = 50):
        """
        Initialize chunker

        Args:
            max_tokens: Maximum tokens per chunk (must be positive)
            overlap_tokens: Overlap tokens between chunks

        Raises:
            ValueError: If max_tokens is not positive
        """
        # A non-positive budget would either raise an obscure range() error
        # later or silently drop every line, so reject it up front.
        if max_tokens <= 0:
            raise ValueError(f"max_tokens must be positive, got {max_tokens}")

        self.max_tokens = max_tokens
        self.overlap_tokens = overlap_tokens
        # Rough estimation: ~4 chars per token for English/Chinese mixed
        self.chars_per_token = 4

    def chunk_text(self, text: str) -> List[TextChunk]:
        """
        Chunk text into overlapping segments

        Lines are accumulated until the next line would push the chunk over
        the character budget. A line that is longer than the budget on its
        own is cut into fixed-size pieces, each reported with that line's
        number.

        Args:
            text: Input text to chunk

        Returns:
            List of TextChunk objects (empty for blank input)
        """
        if not text.strip():
            return []

        lines = text.split('\n')
        chunks: List[TextChunk] = []

        max_chars = self.max_tokens * self.chars_per_token
        overlap_chars = self.overlap_tokens * self.chars_per_token

        current_chunk: List[str] = []
        current_chars = 0
        start_line = 1

        for i, line in enumerate(lines, start=1):
            line_chars = len(line)

            # A single over-long line cannot share a chunk with anything else
            if line_chars > max_chars:
                if current_chunk:
                    chunks.append(TextChunk(
                        text='\n'.join(current_chunk),
                        start_line=start_line,
                        end_line=i - 1
                    ))
                    current_chunk = []
                    current_chars = 0

                chunks.extend(
                    TextChunk(text=piece, start_line=i, end_line=i)
                    for piece in self._split_long_line(line, max_chars)
                )

                start_line = i + 1
                continue

            if current_chunk and current_chars + line_chars > max_chars:
                chunks.append(TextChunk(
                    text='\n'.join(current_chunk),
                    start_line=start_line,
                    end_line=i - 1
                ))

                # Only carry over as much overlap as still leaves room for the
                # incoming line; otherwise the new chunk would start out
                # already larger than the budget.
                overlap_budget = min(overlap_chars, max_chars - line_chars)
                overlap_lines = self._get_overlap_lines(current_chunk, overlap_budget)
                current_chunk = overlap_lines + [line]
                current_chars = sum(len(kept) for kept in current_chunk)
                start_line = i - len(overlap_lines)
            else:
                current_chunk.append(line)
                current_chars += line_chars

        # Whatever is still buffered forms the final chunk
        if current_chunk:
            chunks.append(TextChunk(
                text='\n'.join(current_chunk),
                start_line=start_line,
                end_line=len(lines)
            ))

        return chunks

    def _split_long_line(self, line: str, max_chars: int) -> List[str]:
        """Split a single long line into consecutive pieces of at most max_chars"""
        return [line[pos:pos + max_chars] for pos in range(0, len(line), max_chars)]

    def _get_overlap_lines(self, lines: List[str], target_chars: int) -> List[str]:
        """Get the trailing lines (in original order) that fit within target_chars"""
        picked: List[str] = []
        used = 0

        # Walk backwards from the end and stop at the first line that does not fit
        for line in reversed(lines):
            if used + len(line) > target_chars:
                break
            picked.append(line)
            used += len(line)

        picked.reverse()
        return picked

    def chunk_markdown(self, text: str) -> List[TextChunk]:
        """
        Chunk markdown text

        Currently identical to chunk_text; markdown structure is not yet
        taken into account.
        """
        return self.chunk_text(text)
|
||||||
114
agent/memory/config.py
Normal file
114
agent/memory/config.py
Normal file
@@ -0,0 +1,114 @@
|
|||||||
|
"""
|
||||||
|
Memory configuration module
|
||||||
|
|
||||||
|
Provides global memory configuration with simplified workspace structure
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Optional, List
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class MemoryConfig:
    """Configuration for memory storage and search"""

    # Storage paths (default: ~/cow)
    workspace_root: str = field(default_factory=lambda: os.path.expanduser("~/cow"))

    # Embedding config
    embedding_provider: str = "openai"  # "openai" | "local"
    embedding_model: str = "text-embedding-3-small"
    embedding_dim: int = 1536

    # Chunking config
    chunk_max_tokens: int = 500
    chunk_overlap_tokens: int = 50

    # Search config
    max_results: int = 10
    min_score: float = 0.3

    # Hybrid search weights
    vector_weight: float = 0.7
    keyword_weight: float = 0.3

    # Memory sources
    sources: List[str] = field(default_factory=lambda: ["memory", "session"])

    # Sync config
    enable_auto_sync: bool = True
    sync_on_search: bool = True

    def get_workspace(self) -> Path:
        """
        Get workspace root directory

        A leading ``~`` in ``workspace_root`` is expanded to the user's home
        directory, so a caller-supplied value such as ``"~/my_agents"``
        resolves the same way the default does.
        """
        return Path(os.path.expanduser(self.workspace_root))

    def get_memory_dir(self) -> Path:
        """Get memory files directory (``<workspace>/memory``)"""
        return self.get_workspace() / "memory"

    def get_db_path(self) -> Path:
        """
        Get SQLite database path for long-term memory index

        Creates the containing ``memory/long-term`` directory if it is missing.
        """
        index_dir = self.get_memory_dir() / "long-term"
        index_dir.mkdir(parents=True, exist_ok=True)
        return index_dir / "index.db"

    def get_skills_dir(self) -> Path:
        """Get skills directory (``<workspace>/skills``)"""
        return self.get_workspace() / "skills"

    def get_agent_workspace(self, agent_name: Optional[str] = None) -> Path:
        """
        Get workspace directory for an agent

        Args:
            agent_name: Optional agent name (not used in current implementation)

        Returns:
            Path to workspace directory (created if it does not exist)
        """
        workspace = self.get_workspace()
        # Ensure workspace directory exists
        workspace.mkdir(parents=True, exist_ok=True)
        return workspace
|
||||||
|
|
||||||
|
|
||||||
|
# Module-wide memory configuration; None until it is first requested or set
_global_memory_config: Optional[MemoryConfig] = None


def get_default_memory_config() -> MemoryConfig:
    """
    Return the module-wide memory configuration.

    When none has been stored yet, a MemoryConfig with default field
    values is created, stored, and returned.

    Returns:
        The shared MemoryConfig instance
    """
    global _global_memory_config
    if _global_memory_config is not None:
        return _global_memory_config

    _global_memory_config = MemoryConfig()
    return _global_memory_config
|
||||||
|
|
||||||
|
|
||||||
|
def set_global_memory_config(config: MemoryConfig):
    """
    Install ``config`` as the module-wide memory configuration.

    Call this before any MemoryManager is created so that managers built
    without an explicit config pick it up.

    Args:
        config: MemoryConfig instance to use globally

    Example:
        >>> from agent.memory import MemoryConfig, set_global_memory_config
        >>> set_global_memory_config(MemoryConfig(
        ...     workspace_root="/data/my_agents",
        ...     embedding_provider="openai",
        ...     vector_weight=0.8
        ... ))
    """
    global _global_memory_config
    _global_memory_config = config
|
||||||
175
agent/memory/embedding.py
Normal file
175
agent/memory/embedding.py
Normal file
@@ -0,0 +1,175 @@
|
|||||||
|
"""
|
||||||
|
Embedding providers for memory
|
||||||
|
|
||||||
|
Supports OpenAI and local embedding models
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import List, Optional
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
import hashlib
|
||||||
|
import json
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingProvider(ABC):
    """Abstract interface that every embedding backend implements"""

    @abstractmethod
    def embed(self, text: str) -> List[float]:
        """Produce the embedding vector for one text"""
        ...

    @abstractmethod
    def embed_batch(self, texts: List[str]) -> List[List[float]]:
        """Produce embedding vectors for a list of texts"""
        ...

    @property
    @abstractmethod
    def dimensions(self) -> int:
        """Size of the embedding vectors"""
        ...
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAIEmbeddingProvider(EmbeddingProvider):
    """OpenAI embedding provider"""

    # Output vector sizes of the known OpenAI embedding models. Looked up
    # first so that models without "small" in their name (e.g. ada-002) are
    # not misreported as 3072-dimensional.
    _KNOWN_DIMENSIONS = {
        "text-embedding-3-small": 1536,
        "text-embedding-3-large": 3072,
        "text-embedding-ada-002": 1536,
    }

    def __init__(self, model: str = "text-embedding-3-small", api_key: Optional[str] = None, api_base: Optional[str] = None):
        """
        Initialize OpenAI embedding provider

        Args:
            model: Model name (text-embedding-3-small or text-embedding-3-large)
            api_key: OpenAI API key
            api_base: Optional API base URL

        Raises:
            ImportError: If the openai package is not installed
        """
        self.model = model
        self.api_key = api_key
        self.api_base = api_base or "https://api.openai.com/v1"

        # Lazy import to avoid dependency issues
        try:
            from openai import OpenAI
        except ImportError as exc:
            raise ImportError("OpenAI package not installed. Install with: pip install openai") from exc

        # The caller's api_base is passed through untouched (None lets the
        # client library pick its own default).
        self.client = OpenAI(api_key=api_key, base_url=api_base)

        # Set dimensions based on model; unknown names keep the original
        # "small" -> 1536, otherwise 3072 heuristic.
        known = self._KNOWN_DIMENSIONS.get(model)
        if known is not None:
            self._dimensions = known
        else:
            self._dimensions = 1536 if "small" in model else 3072

    def embed(self, text: str) -> List[float]:
        """Generate embedding for text"""
        response = self.client.embeddings.create(
            input=text,
            model=self.model
        )
        return response.data[0].embedding

    def embed_batch(self, texts: List[str]) -> List[List[float]]:
        """Generate embeddings for multiple texts (no API call for an empty list)"""
        if not texts:
            return []

        response = self.client.embeddings.create(
            input=texts,
            model=self.model
        )
        return [item.embedding for item in response.data]

    @property
    def dimensions(self) -> int:
        """Embedding vector size for the configured model"""
        return self._dimensions
|
||||||
|
|
||||||
|
|
||||||
|
class LocalEmbeddingProvider(EmbeddingProvider):
    """Embedding provider backed by a sentence-transformers model"""

    def __init__(self, model: str = "all-MiniLM-L6-v2"):
        """
        Load the sentence-transformers model

        Args:
            model: Model name from sentence-transformers

        Raises:
            ImportError: If loading raises ImportError (e.g. the package is missing)
        """
        self.model_name = model

        try:
            from sentence_transformers import SentenceTransformer

            loaded = SentenceTransformer(model)
            self.model = loaded
            self._dimensions = loaded.get_sentence_embedding_dimension()
        except ImportError:
            raise ImportError(
                "sentence-transformers not installed. Install with: pip install sentence-transformers"
            )

    def embed(self, text: str) -> List[float]:
        """Encode one text and return the vector as a plain list"""
        return self.model.encode(text, convert_to_numpy=True).tolist()

    def embed_batch(self, texts: List[str]) -> List[List[float]]:
        """Encode several texts at once; an empty input yields an empty list"""
        if texts:
            return self.model.encode(texts, convert_to_numpy=True).tolist()
        return []

    @property
    def dimensions(self) -> int:
        """Embedding size reported by the loaded model"""
        return self._dimensions
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingCache:
    """In-memory store of embeddings keyed by provider, model and text"""

    def __init__(self):
        self.cache = {}

    def get(self, text: str, provider: str, model: str) -> Optional[List[float]]:
        """Look up a stored embedding; None when nothing is cached"""
        return self.cache.get(self._compute_key(text, provider, model))

    def put(self, text: str, provider: str, model: str, embedding: List[float]):
        """Store an embedding, replacing any previous entry for the same key"""
        self.cache[self._compute_key(text, provider, model)] = embedding

    @staticmethod
    def _compute_key(text: str, provider: str, model: str) -> str:
        """Derive the cache key: MD5 hex digest of "provider:model:text" """
        digest = hashlib.md5()
        digest.update(f"{provider}:{model}:{text}".encode('utf-8'))
        return digest.hexdigest()

    def clear(self):
        """Drop every cached embedding"""
        self.cache.clear()
|
||||||
|
|
||||||
|
|
||||||
|
def create_embedding_provider(
    provider: str = "openai",
    model: Optional[str] = None,
    api_key: Optional[str] = None,
    api_base: Optional[str] = None
) -> EmbeddingProvider:
    """
    Build the embedding provider selected by name

    Args:
        provider: Provider name ("openai" or "local")
        model: Model name (provider-specific); each provider has its own default
        api_key: API key, forwarded to the "openai" provider only
        api_base: API base URL, forwarded to the "openai" provider only

    Returns:
        EmbeddingProvider instance

    Raises:
        ValueError: If the provider name is not recognised
    """
    if provider == "openai":
        return OpenAIEmbeddingProvider(
            model=model or "text-embedding-3-small",
            api_key=api_key,
            api_base=api_base,
        )

    if provider == "local":
        return LocalEmbeddingProvider(model=model or "all-MiniLM-L6-v2")

    raise ValueError(f"Unknown embedding provider: {provider}")
|
||||||
623
agent/memory/manager.py
Normal file
623
agent/memory/manager.py
Normal file
@@ -0,0 +1,623 @@
|
|||||||
|
"""
|
||||||
|
Memory manager for AgentMesh
|
||||||
|
|
||||||
|
Provides high-level interface for memory operations
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from typing import List, Optional, Dict, Any
|
||||||
|
from pathlib import Path
|
||||||
|
import hashlib
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
|
from agent.memory.config import MemoryConfig, get_default_memory_config
|
||||||
|
from agent.memory.storage import MemoryStorage, MemoryChunk, SearchResult
|
||||||
|
from agent.memory.chunker import TextChunker
|
||||||
|
from agent.memory.embedding import create_embedding_provider, EmbeddingProvider
|
||||||
|
from agent.memory.summarizer import MemoryFlushManager, create_memory_files_if_needed
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryManager:
|
||||||
|
"""
|
||||||
|
Memory manager with hybrid search capabilities
|
||||||
|
|
||||||
|
Provides long-term memory for agents with vector and keyword search
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
    self,
    config: Optional[MemoryConfig] = None,
    embedding_provider: Optional[EmbeddingProvider] = None,
    llm_model: Optional[Any] = None
):
    """
    Set up storage, chunking, embeddings and the flush manager

    Args:
        config: Memory configuration (the global config is used when omitted)
        embedding_provider: Custom embedding provider (optional)
        llm_model: LLM model handed to the flush manager (optional)
    """
    self.config = config or get_default_memory_config()

    # Storage lives in the index database under the workspace
    self.storage = MemoryStorage(self.config.get_db_path())

    # Chunk sizes come straight from the configuration
    self.chunker = TextChunker(
        max_tokens=self.config.chunk_max_tokens,
        overlap_tokens=self.config.chunk_overlap_tokens
    )

    # Embeddings are optional: without a provider the attribute stays None
    self.embedding_provider = None
    if not embedding_provider:
        try:
            # Credentials for the default provider are read from the environment
            env_key = os.environ.get('OPENAI_API_KEY')
            env_base = os.environ.get('OPENAI_API_BASE')

            self.embedding_provider = create_embedding_provider(
                provider=self.config.embedding_provider,
                model=self.config.embedding_model,
                api_key=env_key,
                api_base=env_base
            )
        except Exception as exc:
            # Deliberately best-effort: any failure here only costs vector search
            print(f"⚠️ Warning: Embedding provider initialization failed: {exc}")
            print(f"ℹ️ Memory will work with keyword search only (no semantic search)")
    else:
        self.embedding_provider = embedding_provider

    self.flush_manager = MemoryFlushManager(
        workspace_dir=self.config.get_workspace(),
        llm_model=llm_model
    )

    # Make sure the on-disk layout exists before first use
    self._init_workspace()

    # Set to True when files changed and the index needs a re-sync
    self._dirty = False
|
||||||
|
|
||||||
|
def _init_workspace(self):
    """Create the memory directory and the default memory files if missing"""
    self.config.get_memory_dir().mkdir(parents=True, exist_ok=True)

    # Default memory files are created relative to the workspace root
    create_memory_files_if_needed(self.config.get_workspace())
|
||||||
|
|
||||||
|
async def search(
    self,
    query: str,
    user_id: Optional[str] = None,
    max_results: Optional[int] = None,
    min_score: Optional[float] = None,
    include_shared: bool = True
) -> List[SearchResult]:
    """
    Search memory with hybrid search (vector + keyword)

    Args:
        query: Search query
        user_id: User ID for scoped search
        max_results: Maximum results to return (config default when None)
        min_score: Minimum score threshold (config default when None;
            an explicit 0 disables the threshold)
        include_shared: Include shared memories

    Returns:
        List of search results sorted by relevance
    """
    # Compare against None so explicit zero values are honoured instead of
    # being silently replaced by the configured defaults.
    if max_results is None:
        max_results = self.config.max_results
    if min_score is None:
        min_score = self.config.min_score

    if max_results <= 0:
        return []

    # Determine scopes
    scopes = []
    if include_shared:
        scopes.append("shared")
    if user_id:
        scopes.append("user")

    if not scopes:
        return []

    # Sync if needed
    if self.config.sync_on_search and self._dirty:
        await self.sync()

    # Both searches fetch twice the requested amount to give the merge
    # step more candidates to rank.
    candidate_limit = max_results * 2

    # Perform vector search (if embedding provider available)
    vector_results = []
    if self.embedding_provider:
        query_embedding = self.embedding_provider.embed(query)
        vector_results = self.storage.search_vector(
            query_embedding=query_embedding,
            user_id=user_id,
            scopes=scopes,
            limit=candidate_limit
        )

    # Perform keyword search
    keyword_results = self.storage.search_keyword(
        query=query,
        user_id=user_id,
        scopes=scopes,
        limit=candidate_limit
    )

    # Merge results
    merged = self._merge_results(
        vector_results,
        keyword_results,
        self.config.vector_weight,
        self.config.keyword_weight
    )

    # Filter by min score and limit
    filtered = [r for r in merged if r.score >= min_score]
    return filtered[:max_results]
|
||||||
|
|
||||||
|
async def add_memory(
    self,
    content: str,
    user_id: Optional[str] = None,
    scope: str = "shared",
    source: str = "memory",
    path: Optional[str] = None,
    metadata: Optional[Dict[str, Any]] = None
):
    """
    Add new memory content

    The content is chunked, embedded when an embedding provider is
    available, and written to storage together with a file-metadata record.
    Blank content is ignored.

    Args:
        content: Memory content
        user_id: User ID for user-scoped memory
        scope: Memory scope ("shared", "user", "session")
        source: Memory source ("memory" or "session")
        path: File path (auto-generated if not provided)
        metadata: Additional metadata
    """
    if not content.strip():
        return

    # Generate path if not provided
    if not path:
        content_hash = hashlib.md5(content.encode('utf-8')).hexdigest()[:8]
        if user_id and scope == "user":
            path = f"memory/users/{user_id}/memory_{content_hash}.md"
        else:
            path = f"memory/shared/memory_{content_hash}.md"

    # Chunk content
    chunks = self.chunker.chunk_text(content)

    # Generate embeddings (if provider available)
    texts = [chunk.text for chunk in chunks]
    if self.embedding_provider:
        embeddings = self.embedding_provider.embed_batch(texts)
    else:
        # No embeddings, just use None
        embeddings = [None] * len(texts)

    # Create memory chunks
    memory_chunks = [
        MemoryChunk(
            id=self._generate_chunk_id(path, chunk.start_line, chunk.end_line),
            agent_id="default",
            user_id=user_id,
            scope=scope,
            source=source,
            path=path,
            start_line=chunk.start_line,
            end_line=chunk.end_line,
            text=chunk.text,
            embedding=embedding,
            hash=MemoryStorage.compute_hash(chunk.text),
            metadata=metadata
        )
        for chunk, embedding in zip(chunks, embeddings)
    ]

    # Save to storage
    self.storage.save_chunks_batch(memory_chunks)

    # Update file metadata. There is no file on disk to stat, so record the
    # current wall-clock time (not this module's own mtime) and the UTF-8
    # byte length, matching what _sync_file records from stat().
    self.storage.update_file_metadata(
        path=path,
        source=source,
        file_hash=MemoryStorage.compute_hash(content),
        mtime=int(datetime.now().timestamp()),
        size=len(content.encode('utf-8'))
    )
|
||||||
|
|
||||||
|
async def sync(self, force: bool = False):
    """
    Re-index the markdown files under the memory directory

    MEMORY.md is handed to _sync_file first; afterwards every ``*.md`` file
    found recursively under the memory directory is handed over as well,
    with scope and user id derived from its path relative to the workspace.
    Clears the dirty flag when done.

    Args:
        force: Force full reindex (accepted but not consulted by this method)
    """
    memory_dir = self.config.get_memory_dir()
    workspace_dir = self.config.get_workspace()

    main_file = memory_dir / "MEMORY.md"
    if main_file.exists():
        await self._sync_file(main_file, "memory", "shared", None)

    if memory_dir.exists():
        for file_path in memory_dir.rglob("*.md"):
            parts = file_path.relative_to(workspace_dir).parts

            # Pick the path component that is directly followed by the user
            # id, or None when the file is shared.
            if "daily" in parts:
                # memory/daily/{user_id}/<date>.md is user-scoped,
                # memory/daily/<date>.md is shared
                # NOTE(review): when both "daily" and "users" occur, the
                # component after "daily" is taken as the user id - confirm
                # this matches the intended directory layout.
                anchor = "daily" if ("users" in parts or len(parts) > 3) else None
            elif "users" in parts:
                anchor = "users"
            else:
                anchor = None

            if anchor is None:
                scope, user_id = "shared", None
            else:
                scope = "user"
                owner_idx = parts.index(anchor) + 1
                user_id = parts[owner_idx] if owner_idx < len(parts) else None

            await self._sync_file(file_path, "memory", scope, user_id)

    self._dirty = False
|
||||||
|
|
||||||
|
async def _sync_file(
    self,
    file_path: Path,
    source: str,
    scope: str,
    user_id: Optional[str]
):
    """
    Sync a single file into the index

    The file is skipped when its content hash equals the stored one.
    Otherwise its old chunks are removed, the content is re-chunked and
    (when a provider is available) re-embedded, and the file metadata is
    refreshed.

    Args:
        file_path: File to index; must be located under the workspace
        source: Memory source recorded on each chunk
        scope: Memory scope recorded on each chunk
        user_id: Owner recorded on each chunk (None for shared files)
    """
    # Read as UTF-8 explicitly so decoding does not depend on the platform's
    # default encoding.
    content = file_path.read_text(encoding='utf-8')
    file_hash = MemoryStorage.compute_hash(content)

    # Get relative path
    workspace_dir = self.config.get_workspace()
    rel_path = str(file_path.relative_to(workspace_dir))

    # Check if file changed
    if self.storage.get_file_hash(rel_path) == file_hash:
        return  # No changes

    # Delete old chunks
    self.storage.delete_by_path(rel_path)

    # Chunk and embed
    chunks = self.chunker.chunk_text(content)
    if chunks:
        texts = [chunk.text for chunk in chunks]
        if self.embedding_provider:
            embeddings = self.embedding_provider.embed_batch(texts)
        else:
            embeddings = [None] * len(texts)

        # Create memory chunks
        memory_chunks = [
            MemoryChunk(
                id=self._generate_chunk_id(rel_path, chunk.start_line, chunk.end_line),
                agent_id="default",
                user_id=user_id,
                scope=scope,
                source=source,
                path=rel_path,
                start_line=chunk.start_line,
                end_line=chunk.end_line,
                text=chunk.text,
                embedding=embedding,
                hash=MemoryStorage.compute_hash(chunk.text),
                metadata=None
            )
            for chunk, embedding in zip(chunks, embeddings)
        ]

        # Save
        self.storage.save_chunks_batch(memory_chunks)

    # Update file metadata even when the file produced no chunks, so the
    # stored hash always describes the content whose chunks are (or are
    # not) in the index. Otherwise a file that is blanked and later
    # restored could match a stale hash and never be re-indexed.
    stat = file_path.stat()
    self.storage.update_file_metadata(
        path=rel_path,
        source=source,
        file_hash=file_hash,
        mtime=int(stat.st_mtime),
        size=stat.st_size
    )
|
||||||
|
|
||||||
|
def should_flush_memory(
    self,
    current_tokens: int,
    context_window: int = 128000,
    reserve_tokens: int = 20000,
    soft_threshold: int = 4000
) -> bool:
    """
    Ask the flush manager whether a memory flush is due

    All arguments are forwarded unchanged, by keyword, to the flush
    manager's should_flush.

    Args:
        current_tokens: Current session token count
        context_window: Model's context window size (default: 128K)
        reserve_tokens: Reserve tokens for compaction overhead (default: 20K)
        soft_threshold: Trigger N tokens before threshold (default: 4K)

    Returns:
        Whatever the flush manager decides (True if memory flush should run)
    """
    flush_args = {
        "current_tokens": current_tokens,
        "context_window": context_window,
        "reserve_tokens": reserve_tokens,
        "soft_threshold": soft_threshold,
    }
    return self.flush_manager.should_flush(**flush_args)
|
||||||
|
|
||||||
|
async def execute_memory_flush(
    self,
    agent_executor,
    current_tokens: int,
    user_id: Optional[str] = None,
    **executor_kwargs
) -> bool:
    """
    Run a memory flush through the flush manager

    The call is delegated to the flush manager's execute_flush. When it
    reports success, the manager is marked dirty so that the next search
    re-syncs the index.

    Args:
        agent_executor: Async function to execute agent with prompt
        current_tokens: Current session token count
        user_id: Optional user ID
        **executor_kwargs: Additional kwargs passed through to execute_flush

    Returns:
        The flush manager's result (True if flush completed successfully)

    Example:
        >>> async def run_agent(prompt, system_prompt, silent=False):
        ...     # Your agent execution logic
        ...     pass
        >>>
        >>> if manager.should_flush_memory(current_tokens=100000):
        ...     await manager.execute_memory_flush(
        ...         agent_executor=run_agent,
        ...         current_tokens=100000
        ...     )
    """
    flushed = await self.flush_manager.execute_flush(
        agent_executor=agent_executor,
        current_tokens=current_tokens,
        user_id=user_id,
        **executor_kwargs
    )

    if not flushed:
        return flushed

    # New memories were written; the next search must sync them
    self._dirty = True
    return flushed
|
||||||
|
|
||||||
|
def build_memory_guidance(self, lang: str = "en", include_context: bool = True) -> str:
    """
    Assemble the memory section of the agent's system prompt.

    The text tells the agent how to recall memories (memory_search, then
    memory_get), where to store new ones (MEMORY.md versus today's daily
    file) and to avoid volunteering memory contents.

    Args:
        lang: "zh" selects the Chinese text; any other value selects English
        include_context: When True, the result of ``load_bootstrap_memories()``
                         is appended under a "Background Context" heading
                         (only if it is non-empty)

    Returns:
        Guidance text, optionally followed by the bootstrap memory content
    """
    today_file = self.flush_manager.get_today_memory_file().name

    if lang != "zh":
        guidance = f"""## Memory Recall
Before answering anything about prior work, decisions, dates, people, preferences, or todos: run memory_search on MEMORY.md + memory/*.md; then use memory_get to pull only the needed lines. If low confidence after search, say you checked.

## Memory Storage
When user shares durable preferences, decisions, or important facts (whether or not they explicitly say "remember"), proactively store:
- Durable info (preferences, decisions, people) → memory/MEMORY.md
- Daily notes and context → memory/{today_file}
- Store silently; only confirm when explicitly requested

## Memory Usage Principles
- Don't proactively mention or list memory contents
- Only use memories when user explicitly asks about them
- Memories are background knowledge, not content to showcase
- Use memories naturally as if you inherently knew this information"""
    else:
        guidance = f"""## 记忆召回
回答关于过去工作、决策、日期、人物、偏好或待办事项的问题前:先用 memory_search 搜索 MEMORY.md + memory/*.md;然后用 memory_get 只读取需要的行。如果搜索后仍不确定,说明你已检查过。

## 记忆存储
当用户分享持久偏好、决策或重要事实时(无论是否明确要求"记住"),主动存储:
- 持久信息(偏好、决策、人物信息)→ memory/MEMORY.md
- 当天的笔记和上下文 → memory/{today_file}
- 静默存储,仅在用户明确要求时确认

## 记忆使用原则
- 不要主动提起或列举记忆内容
- 只在用户明确询问相关信息时才使用记忆
- 记忆是背景知识,不是要展示的内容
- 自然使用记忆,就像你本来就知道这些信息"""

    if not include_context:
        return guidance

    # Only MEMORY.md is bootstrapped; daily files are reached via memory_search.
    bootstrap_context = self.load_bootstrap_memories()
    if not bootstrap_context:
        return guidance

    return guidance + f"\n\n## Background Context\n\n{bootstrap_context}"
|
||||||
|
|
||||||
|
def load_bootstrap_memories(self, user_id: Optional[str] = None) -> str:
    """
    Collect the long-term memory text injected at session start.

    Reads ``MEMORY.md`` from the configured memory directory and, when
    ``user_id`` is given, ``users/<user_id>/MEMORY.md`` below it. Daily
    files are deliberately not read here. The pieces are returned without
    any added headers so they read as plain background knowledge.

    Args:
        user_id: Optional user ID for user-specific memories

    Returns:
        The stripped, non-empty file contents joined by a blank line, or an
        empty string when nothing could be loaded. Read failures are
        reported on stdout and otherwise ignored.
    """
    # The original also resolved the workspace here; the call is kept in
    # case resolving it has side effects, although the value is not used.
    self.config.get_workspace()
    memory_dir = self.config.get_memory_dir()

    # (file, label used in the warning message), shared memory first.
    candidates = [(memory_dir / "MEMORY.md", "memory/MEMORY.md")]
    if user_id:
        candidates.append((memory_dir / "users" / user_id / "MEMORY.md", "user memory"))

    sections = []
    for memory_file, label in candidates:
        if not memory_file.exists():
            continue
        try:
            text = memory_file.read_text(encoding='utf-8').strip()
        except Exception as e:
            print(f"Warning: Failed to read {label}: {e}")
            continue
        if text:
            sections.append(text)

    # Joining an empty list yields "", which covers the nothing-loaded case.
    return "\n\n".join(sections)
|
||||||
|
|
||||||
|
def get_status(self) -> Dict[str, Any]:
    """
    Summarise the memory subsystem as a plain dictionary.

    Returns:
        Chunk and file counts from storage, the workspace path, the dirty
        flag, and a description of the embedding setup / search mode
    """
    stats = self.storage.get_stats()
    status = {
        'chunks': stats['chunks'],
        'files': stats['files'],
        'workspace': str(self.config.get_workspace()),
        'dirty': self._dirty,
        'embedding_enabled': self.embedding_provider is not None,
    }

    # Provider details are only read from the config when a provider is set.
    if self.embedding_provider:
        status['embedding_provider'] = self.config.embedding_provider
        status['embedding_model'] = self.config.embedding_model
        status['search_mode'] = 'hybrid (vector + keyword)'
    else:
        status['embedding_provider'] = 'disabled'
        status['embedding_model'] = 'N/A'
        status['search_mode'] = 'keyword only (FTS5)'

    return status
|
||||||
|
|
||||||
|
def mark_dirty(self):
    """Flag the memory index as stale so that it gets synced again."""
    self._dirty = True
|
||||||
|
|
||||||
|
def close(self):
    """Release the manager's resources by closing the underlying storage."""
    self.storage.close()
|
||||||
|
|
||||||
|
# Helper methods
|
||||||
|
|
||||||
|
def _generate_chunk_id(self, path: str, start_line: int, end_line: int) -> str:
|
||||||
|
"""Generate unique chunk ID"""
|
||||||
|
content = f"{path}:{start_line}:{end_line}"
|
||||||
|
return hashlib.md5(content.encode('utf-8')).hexdigest()
|
||||||
|
|
||||||
|
def _merge_results(
|
||||||
|
self,
|
||||||
|
vector_results: List[SearchResult],
|
||||||
|
keyword_results: List[SearchResult],
|
||||||
|
vector_weight: float,
|
||||||
|
keyword_weight: float
|
||||||
|
) -> List[SearchResult]:
|
||||||
|
"""Merge vector and keyword search results"""
|
||||||
|
# Create a map by (path, start_line, end_line)
|
||||||
|
merged_map = {}
|
||||||
|
|
||||||
|
for result in vector_results:
|
||||||
|
key = (result.path, result.start_line, result.end_line)
|
||||||
|
merged_map[key] = {
|
||||||
|
'result': result,
|
||||||
|
'vector_score': result.score,
|
||||||
|
'keyword_score': 0.0
|
||||||
|
}
|
||||||
|
|
||||||
|
for result in keyword_results:
|
||||||
|
key = (result.path, result.start_line, result.end_line)
|
||||||
|
if key in merged_map:
|
||||||
|
merged_map[key]['keyword_score'] = result.score
|
||||||
|
else:
|
||||||
|
merged_map[key] = {
|
||||||
|
'result': result,
|
||||||
|
'vector_score': 0.0,
|
||||||
|
'keyword_score': result.score
|
||||||
|
}
|
||||||
|
|
||||||
|
# Calculate combined scores
|
||||||
|
merged_results = []
|
||||||
|
for entry in merged_map.values():
|
||||||
|
combined_score = (
|
||||||
|
vector_weight * entry['vector_score'] +
|
||||||
|
keyword_weight * entry['keyword_score']
|
||||||
|
)
|
||||||
|
|
||||||
|
result = entry['result']
|
||||||
|
merged_results.append(SearchResult(
|
||||||
|
path=result.path,
|
||||||
|
start_line=result.start_line,
|
||||||
|
end_line=result.end_line,
|
||||||
|
score=combined_score,
|
||||||
|
snippet=result.snippet,
|
||||||
|
source=result.source,
|
||||||
|
user_id=result.user_id
|
||||||
|
))
|
||||||
|
|
||||||
|
# Sort by score
|
||||||
|
merged_results.sort(key=lambda r: r.score, reverse=True)
|
||||||
|
return merged_results
|
||||||
418
agent/memory/storage.py
Normal file
418
agent/memory/storage.py
Normal file
@@ -0,0 +1,418 @@
|
|||||||
|
"""
|
||||||
|
Storage layer for memory using SQLite + FTS5
|
||||||
|
|
||||||
|
Provides vector and keyword search capabilities
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sqlite3
|
||||||
|
import json
|
||||||
|
import hashlib
|
||||||
|
from typing import List, Dict, Optional, Any
|
||||||
|
from pathlib import Path
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class MemoryChunk:
    """A slice of a memory file as stored in the index (text plus optional embedding)."""
    id: str                                   # primary key of the chunk
    user_id: Optional[str]                    # None when not tied to a user
    scope: str                                # "shared" | "user" | "session"
    source: str                               # "memory" | "session"
    path: str                                 # file the chunk comes from
    start_line: int
    end_line: int
    text: str
    embedding: Optional[List[float]]          # None when no vector is available
    hash: str
    metadata: Optional[Dict[str, Any]] = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class SearchResult:
    """One search hit: where it was found, how well it scored, and a text snippet."""
    path: str                       # file containing the hit
    start_line: int
    end_line: int
    score: float                    # relevance, higher is better
    snippet: str                    # (possibly truncated) chunk text
    source: str
    user_id: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryStorage:
    """
    SQLite-based storage with FTS5 for keyword search

    Chunks live in the ``chunks`` table (embeddings as JSON text). Keyword
    search uses ``chunks_fts``, an external-content FTS5 table that is kept
    in step with ``chunks`` by triggers. Vector search loads the candidate
    embeddings and scores them in Python.
    """

    # Shared by save_chunk and save_chunks_batch.
    _INSERT_CHUNK_SQL = """
        INSERT INTO chunks
        (id, user_id, scope, source, path, start_line, end_line, text, embedding, hash, metadata, updated_at)
        VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, strftime('%s', 'now'))
    """

    def __init__(self, db_path: Path):
        """Open (or create) the database at ``db_path`` and ensure the schema exists."""
        self.db_path = db_path
        self.conn: Optional[sqlite3.Connection] = None
        self._init_db()

    def _init_db(self):
        """Initialize database with schema"""
        self.conn = sqlite3.connect(str(self.db_path))
        self.conn.row_factory = sqlite3.Row

        # Use write-ahead logging for the journal.
        self.conn.execute("PRAGMA journal_mode=WAL")

        # Create chunks table with embeddings
        self.conn.execute("""
            CREATE TABLE IF NOT EXISTS chunks (
                id TEXT PRIMARY KEY,
                user_id TEXT,
                scope TEXT NOT NULL DEFAULT 'shared',
                source TEXT NOT NULL DEFAULT 'memory',
                path TEXT NOT NULL,
                start_line INTEGER NOT NULL,
                end_line INTEGER NOT NULL,
                text TEXT NOT NULL,
                embedding TEXT,
                hash TEXT NOT NULL,
                metadata TEXT,
                created_at INTEGER DEFAULT (strftime('%s', 'now')),
                updated_at INTEGER DEFAULT (strftime('%s', 'now'))
            )
        """)

        # Create indexes
        self.conn.execute("""
            CREATE INDEX IF NOT EXISTS idx_chunks_user
            ON chunks(user_id)
        """)

        self.conn.execute("""
            CREATE INDEX IF NOT EXISTS idx_chunks_scope
            ON chunks(scope)
        """)

        self.conn.execute("""
            CREATE INDEX IF NOT EXISTS idx_chunks_hash
            ON chunks(path, hash)
        """)

        # Create FTS5 virtual table for keyword search
        self.conn.execute("""
            CREATE VIRTUAL TABLE IF NOT EXISTS chunks_fts USING fts5(
                text,
                id UNINDEXED,
                user_id UNINDEXED,
                path UNINDEXED,
                source UNINDEXED,
                scope UNINDEXED,
                content='chunks',
                content_rowid='rowid'
            )
        """)

        # Create triggers to keep FTS in sync
        self.conn.execute("""
            CREATE TRIGGER IF NOT EXISTS chunks_ai AFTER INSERT ON chunks BEGIN
                INSERT INTO chunks_fts(rowid, text, id, user_id, path, source, scope)
                VALUES (new.rowid, new.text, new.id, new.user_id, new.path, new.source, new.scope);
            END
        """)

        # For an external-content FTS5 table, old entries must be removed
        # with the special 'delete' command and the OLD column values: by
        # the time an AFTER trigger runs, the content row no longer holds
        # them, so a plain DELETE/UPDATE on chunks_fts leaves stale tokens.
        # The triggers are dropped first so databases created with the
        # earlier definitions pick up the corrected ones. (An index already
        # damaged by the old triggers can be repaired with FTS5 'rebuild'.)
        self.conn.execute("DROP TRIGGER IF EXISTS chunks_ad")
        self.conn.execute("""
            CREATE TRIGGER chunks_ad AFTER DELETE ON chunks BEGIN
                INSERT INTO chunks_fts(chunks_fts, rowid, text, id, user_id, path, source, scope)
                VALUES ('delete', old.rowid, old.text, old.id, old.user_id, old.path, old.source, old.scope);
            END
        """)

        self.conn.execute("DROP TRIGGER IF EXISTS chunks_au")
        self.conn.execute("""
            CREATE TRIGGER chunks_au AFTER UPDATE ON chunks BEGIN
                INSERT INTO chunks_fts(chunks_fts, rowid, text, id, user_id, path, source, scope)
                VALUES ('delete', old.rowid, old.text, old.id, old.user_id, old.path, old.source, old.scope);
                INSERT INTO chunks_fts(rowid, text, id, user_id, path, source, scope)
                VALUES (new.rowid, new.text, new.id, new.user_id, new.path, new.source, new.scope);
            END
        """)

        # Create files metadata table
        self.conn.execute("""
            CREATE TABLE IF NOT EXISTS files (
                path TEXT PRIMARY KEY,
                source TEXT NOT NULL DEFAULT 'memory',
                hash TEXT NOT NULL,
                mtime INTEGER NOT NULL,
                size INTEGER NOT NULL,
                updated_at INTEGER DEFAULT (strftime('%s', 'now'))
            )
        """)

        self.conn.commit()

    def save_chunk(self, chunk: MemoryChunk):
        """
        Save a memory chunk, replacing any existing chunk with the same id

        The old row is deleted explicitly rather than via INSERT OR REPLACE:
        REPLACE removes the conflicting row without firing delete triggers
        (unless recursive triggers are enabled), which would leave the old
        text in the FTS index.
        """
        self.conn.execute("DELETE FROM chunks WHERE id = ?", (chunk.id,))
        self.conn.execute(self._INSERT_CHUNK_SQL, self._chunk_params(chunk))
        self.conn.commit()

    def save_chunks_batch(self, chunks: List[MemoryChunk]):
        """Save multiple chunks in one transaction (same replace semantics as save_chunk)"""
        chunks = list(chunks)
        self.conn.executemany("DELETE FROM chunks WHERE id = ?", [(c.id,) for c in chunks])
        self.conn.executemany(self._INSERT_CHUNK_SQL, [self._chunk_params(c) for c in chunks])
        self.conn.commit()

    def get_chunk(self, chunk_id: str) -> Optional[MemoryChunk]:
        """Get a chunk by ID, or None if there is no such chunk"""
        row = self.conn.execute("""
            SELECT * FROM chunks WHERE id = ?
        """, (chunk_id,)).fetchone()

        if not row:
            return None

        return self._row_to_chunk(row)

    def search_vector(
        self,
        query_embedding: List[float],
        user_id: Optional[str] = None,
        scopes: List[str] = None,
        limit: int = 10
    ) -> List[SearchResult]:
        """
        Vector similarity search using in-memory cosine similarity
        (sqlite-vec can be added later for better performance)

        Args:
            query_embedding: Embedding of the query
            user_id: When given, "user"-scoped rows must belong to this user
            scopes: Scopes to search; defaults to ["shared"], plus "user"
                    when user_id is given. The list is not modified.
            limit: Maximum number of results

        Returns:
            Hits with positive similarity, best first
        """
        scopes = self._resolve_scopes(scopes, user_id)

        # Build query
        scope_placeholders = ','.join('?' * len(scopes))
        # Copy: the parameter list grows below and must not alias `scopes`.
        params = list(scopes)

        if user_id:
            query = f"""
                SELECT * FROM chunks
                WHERE scope IN ({scope_placeholders})
                AND (scope = 'shared' OR user_id = ?)
                AND embedding IS NOT NULL
            """
            params.append(user_id)
        else:
            query = f"""
                SELECT * FROM chunks
                WHERE scope IN ({scope_placeholders})
                AND embedding IS NOT NULL
            """

        rows = self.conn.execute(query, params).fetchall()

        # Calculate cosine similarity
        results = []
        for row in rows:
            embedding = json.loads(row['embedding'])
            similarity = self._cosine_similarity(query_embedding, embedding)

            if similarity > 0:
                results.append((similarity, row))

        # Sort by similarity and limit
        results.sort(key=lambda x: x[0], reverse=True)
        results = results[:limit]

        return [self._row_to_result(row, score) for score, row in results]

    def search_keyword(
        self,
        query: str,
        user_id: Optional[str] = None,
        scopes: List[str] = None,
        limit: int = 10
    ) -> List[SearchResult]:
        """
        Keyword search using FTS5

        Args:
            query: Raw query text; it is tokenised by _build_fts_query
            user_id: When given, "user"-scoped rows must belong to this user
            scopes: Scopes to search; defaults to ["shared"], plus "user"
                    when user_id is given. The list is not modified.
            limit: Maximum number of results

        Returns:
            Matching chunks in bm25 order; empty if the query has no tokens
        """
        scopes = self._resolve_scopes(scopes, user_id)

        # Build FTS query
        fts_query = self._build_fts_query(query)
        if not fts_query:
            return []

        scope_placeholders = ','.join('?' * len(scopes))
        params = [fts_query] + scopes

        # The FTS rowid is the chunks rowid (see the triggers), so join on it
        # instead of reading the id column back through the FTS table.
        if user_id:
            sql_query = f"""
                SELECT chunks.*, bm25(chunks_fts) as rank
                FROM chunks_fts
                JOIN chunks ON chunks.rowid = chunks_fts.rowid
                WHERE chunks_fts MATCH ?
                AND chunks.scope IN ({scope_placeholders})
                AND (chunks.scope = 'shared' OR chunks.user_id = ?)
                ORDER BY rank
                LIMIT ?
            """
            params.extend([user_id, limit])
        else:
            sql_query = f"""
                SELECT chunks.*, bm25(chunks_fts) as rank
                FROM chunks_fts
                JOIN chunks ON chunks.rowid = chunks_fts.rowid
                WHERE chunks_fts MATCH ?
                AND chunks.scope IN ({scope_placeholders})
                ORDER BY rank
                LIMIT ?
            """
            params.append(limit)

        rows = self.conn.execute(sql_query, params).fetchall()

        return [
            self._row_to_result(row, self._bm25_rank_to_score(row['rank']))
            for row in rows
        ]

    def delete_by_path(self, path: str):
        """Delete all chunks from a file"""
        self.conn.execute("""
            DELETE FROM chunks WHERE path = ?
        """, (path,))
        self.conn.commit()

    def get_file_hash(self, path: str) -> Optional[str]:
        """Get stored file hash, or None if the file is not recorded"""
        row = self.conn.execute("""
            SELECT hash FROM files WHERE path = ?
        """, (path,)).fetchone()
        return row['hash'] if row else None

    def update_file_metadata(self, path: str, source: str, file_hash: str, mtime: int, size: int):
        """Insert or replace the metadata row for a file"""
        self.conn.execute("""
            INSERT OR REPLACE INTO files (path, source, hash, mtime, size, updated_at)
            VALUES (?, ?, ?, ?, ?, strftime('%s', 'now'))
        """, (path, source, file_hash, mtime, size))
        self.conn.commit()

    def get_stats(self) -> Dict[str, int]:
        """Get storage statistics (row counts of the chunks and files tables)"""
        chunks_count = self.conn.execute("""
            SELECT COUNT(*) as cnt FROM chunks
        """).fetchone()['cnt']

        files_count = self.conn.execute("""
            SELECT COUNT(*) as cnt FROM files
        """).fetchone()['cnt']

        return {
            'chunks': chunks_count,
            'files': files_count
        }

    def close(self):
        """Close database connection"""
        if self.conn:
            self.conn.close()

    # Helper methods

    @staticmethod
    def _resolve_scopes(scopes: Optional[List[str]], user_id: Optional[str]) -> List[str]:
        """Return a private copy of the scope filter, applying the defaults"""
        if scopes is None:
            return ["shared", "user"] if user_id else ["shared"]
        return list(scopes)

    @staticmethod
    def _chunk_params(chunk: MemoryChunk) -> tuple:
        """Build the parameter tuple for _INSERT_CHUNK_SQL"""
        return (
            chunk.id,
            chunk.user_id,
            chunk.scope,
            chunk.source,
            chunk.path,
            chunk.start_line,
            chunk.end_line,
            chunk.text,
            json.dumps(chunk.embedding) if chunk.embedding else None,
            chunk.hash,
            json.dumps(chunk.metadata) if chunk.metadata else None
        )

    def _row_to_result(self, row, score: float) -> SearchResult:
        """Convert a chunks row plus its score to a SearchResult (snippet capped at 500 chars)"""
        return SearchResult(
            path=row['path'],
            start_line=row['start_line'],
            end_line=row['end_line'],
            score=score,
            snippet=self._truncate_text(row['text'], 500),
            source=row['source'],
            user_id=row['user_id']
        )

    def _row_to_chunk(self, row) -> MemoryChunk:
        """Convert database row to MemoryChunk"""
        return MemoryChunk(
            id=row['id'],
            user_id=row['user_id'],
            scope=row['scope'],
            source=row['source'],
            path=row['path'],
            start_line=row['start_line'],
            end_line=row['end_line'],
            text=row['text'],
            embedding=json.loads(row['embedding']) if row['embedding'] else None,
            hash=row['hash'],
            metadata=json.loads(row['metadata']) if row['metadata'] else None
        )

    @staticmethod
    def _cosine_similarity(vec1: List[float], vec2: List[float]) -> float:
        """Calculate cosine similarity; 0.0 for mismatched lengths or a zero vector"""
        if len(vec1) != len(vec2):
            return 0.0

        dot_product = sum(a * b for a, b in zip(vec1, vec2))
        norm1 = sum(a * a for a in vec1) ** 0.5
        norm2 = sum(b * b for b in vec2) ** 0.5

        if norm1 == 0 or norm2 == 0:
            return 0.0

        return dot_product / (norm1 * norm2)

    @staticmethod
    def _build_fts_query(raw_query: str) -> Optional[str]:
        """
        Build FTS5 query from raw text

        Runs of ASCII word characters or CJK ideographs become quoted terms
        joined with AND; returns None when the text contains no such run.
        """
        import re
        tokens = re.findall(r'[A-Za-z0-9_\u4e00-\u9fff]+', raw_query)
        if not tokens:
            return None
        quoted = [f'"{t}"' for t in tokens]
        return ' AND '.join(quoted)

    @staticmethod
    def _bm25_rank_to_score(rank: float) -> float:
        """
        Convert BM25 rank to 0-1 score

        NOTE(review): SQLite's bm25() returns negative values (more negative
        means a better match), so max(0, rank) is 0 for every real match
        and this returns 1.0 for all of them. Behaviour is left unchanged
        because score thresholds applied by callers are not visible here;
        confirm the intended mapping before changing it.
        """
        normalized = max(0, rank) if rank is not None else 999
        return 1 / (1 + normalized)

    @staticmethod
    def _truncate_text(text: str, max_chars: int) -> str:
        """Truncate text to max characters, appending "..." when it was cut"""
        if len(text) <= max_chars:
            return text
        return text[:max_chars] + "..."

    @staticmethod
    def compute_hash(content: str) -> str:
        """Compute SHA256 hash of content"""
        return hashlib.sha256(content.encode('utf-8')).hexdigest()
|
||||||
235
agent/memory/summarizer.py
Normal file
235
agent/memory/summarizer.py
Normal file
@@ -0,0 +1,235 @@
|
|||||||
|
"""
|
||||||
|
Memory flush manager
|
||||||
|
|
||||||
|
Triggers memory flush before context compaction (similar to clawdbot)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Optional, Callable, Any
|
||||||
|
from pathlib import Path
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryFlushManager:
    """
    Decides when to flush memories to disk and runs the flush turn

    Modelled on clawdbot's pre-compaction memory flush:
    - fires when the session approaches the context limit
    - runs one silent agent turn that writes memories to disk
    - daily notes go to memory/YYYY-MM-DD.md
    - long-term curated memories go to MEMORY.md
    """

    def __init__(
        self,
        workspace_dir: Path,
        llm_model: Optional[Any] = None
    ):
        """
        Set up the manager and make sure ``<workspace_dir>/memory`` exists

        Args:
            workspace_dir: Workspace directory
            llm_model: LLM model for agent execution (optional; only stored)
        """
        self.workspace_dir = workspace_dir
        self.llm_model = llm_model

        self.memory_dir = workspace_dir / "memory"
        self.memory_dir.mkdir(parents=True, exist_ok=True)

        # Bookkeeping of the most recent successful flush
        self.last_flush_token_count: Optional[int] = None
        self.last_flush_timestamp: Optional[datetime] = None

    def should_flush(
        self,
        current_tokens: int,
        context_window: int,
        reserve_tokens: int = 20000,
        soft_threshold: int = 4000
    ) -> bool:
        """
        Tell whether a memory flush is due

        The trigger point follows clawdbot's shouldRunMemoryFlush:
            threshold = context_window - reserve_tokens - soft_threshold

        Args:
            current_tokens: Current session token count
            context_window: Model's context window size
            reserve_tokens: Tokens reserved for compaction overhead
            soft_threshold: How many tokens before the limit to trigger

        Returns:
            True if flush should run
        """
        if current_tokens <= 0:
            return False

        threshold = max(0, context_window - reserve_tokens - soft_threshold)
        if threshold <= 0 or current_tokens < threshold:
            return False

        # Skip if a flush already happened within soft_threshold tokens of
        # the current count (same compaction cycle).
        previous = self.last_flush_token_count
        already_flushed = previous is not None and current_tokens <= previous + soft_threshold
        return not already_flushed

    def get_today_memory_file(self, user_id: Optional[str] = None) -> Path:
        """
        Path of today's daily memory file (YYYY-MM-DD.md)

        Args:
            user_id: Optional user ID; selects memory/users/<user_id>/, which
                     is created if missing

        Returns:
            Path to today's memory file (the file itself is not created)
        """
        filename = f"{datetime.now().strftime('%Y-%m-%d')}.md"

        if not user_id:
            return self.memory_dir / filename

        user_dir = self.memory_dir / "users" / user_id
        user_dir.mkdir(parents=True, exist_ok=True)
        return user_dir / filename

    def get_main_memory_file(self, user_id: Optional[str] = None) -> Path:
        """
        Path of the long-term memory file (MEMORY.md)

        Args:
            user_id: Optional user ID; selects memory/users/<user_id>/, which
                     is created if missing

        Returns:
            Path to main memory file (the file itself is not created)
        """
        if not user_id:
            return self.memory_dir / "MEMORY.md"

        user_dir = self.memory_dir / "users" / user_id
        user_dir.mkdir(parents=True, exist_ok=True)
        return user_dir / "MEMORY.md"

    def create_flush_prompt(self) -> str:
        """
        Build the user prompt for the flush turn (names today's daily file)

        Similar to clawdbot's DEFAULT_MEMORY_FLUSH_PROMPT
        """
        today = datetime.now().strftime("%Y-%m-%d")
        return (
            "Pre-compaction memory flush. "
            f"Store durable memories now (use memory/{today}.md for daily notes; "
            "create memory/ if needed). "
            "If nothing to store, reply with NO_REPLY."
        )

    def create_flush_system_prompt(self) -> str:
        """
        Build the system prompt for the flush turn

        Similar to clawdbot's DEFAULT_MEMORY_FLUSH_SYSTEM_PROMPT
        """
        sentences = (
            "Pre-compaction memory flush turn.",
            "The session is near auto-compaction; capture durable memories to disk.",
            "You may reply, but usually NO_REPLY is correct.",
        )
        return " ".join(sentences)

    async def execute_flush(
        self,
        agent_executor: Callable,
        current_tokens: int,
        user_id: Optional[str] = None,
        **executor_kwargs
    ) -> bool:
        """
        Run one silent agent turn that writes memories to disk

        Args:
            agent_executor: Async callable invoked with ``prompt``,
                            ``system_prompt``, ``silent=True`` and the extra kwargs
            current_tokens: Current token count, recorded on success
            user_id: Optional user ID (not used by this method)
            **executor_kwargs: Additional kwargs for agent executor

        Returns:
            True if the executor finished without raising; False otherwise
            (the error is printed, not propagated)
        """
        try:
            await agent_executor(
                prompt=self.create_flush_prompt(),
                system_prompt=self.create_flush_system_prompt(),
                silent=True,  # NO_REPLY expected
                **executor_kwargs
            )
        except Exception as e:
            print(f"Memory flush failed: {e}")
            return False

        # Remember this flush so should_flush can avoid an immediate repeat
        self.last_flush_token_count = current_tokens
        self.last_flush_timestamp = datetime.now()
        return True

    def get_status(self) -> dict:
        """Report the last flush (token count, ISO time) and the current memory file paths"""
        flushed_at = self.last_flush_timestamp
        return {
            'last_flush_tokens': self.last_flush_token_count,
            'last_flush_time': flushed_at.isoformat() if flushed_at else None,
            'today_file': str(self.get_today_memory_file()),
            'main_file': str(self.get_main_memory_file())
        }
|
||||||
|
|
||||||
|
|
||||||
|
def create_memory_files_if_needed(workspace_dir: Path, user_id: Optional[str] = None):
    """
    Create default memory files if they don't exist

    Ensures ``<workspace_dir>/memory`` exists (and ``users/<user_id>`` below
    it when ``user_id`` is given), then creates an empty ``MEMORY.md`` and
    today's ``YYYY-MM-DD.md`` with a short header. Existing files are left
    untouched.

    Args:
        workspace_dir: Workspace directory
        user_id: Optional user ID for user-specific files
    """
    memory_dir = workspace_dir / "memory"
    memory_dir.mkdir(parents=True, exist_ok=True)

    # Both files live in the same directory: the shared one, or the user's.
    target_dir = memory_dir
    if user_id:
        target_dir = memory_dir / "users" / user_id
        target_dir.mkdir(parents=True, exist_ok=True)

    # MEMORY.md starts empty (no "Memory" header) so that its content blends
    # into the prompt as background context.
    main_memory = target_dir / "MEMORY.md"
    if not main_memory.exists():
        # Written as UTF-8 explicitly, since these files are read back as UTF-8.
        main_memory.write_text("", encoding="utf-8")

    today = datetime.now().strftime("%Y-%m-%d")
    today_memory = target_dir / f"{today}.md"
    if not today_memory.exists():
        today_memory.write_text(
            f"# Daily Memory: {today}\n\n"
            "Day-to-day notes and running context.\n\n",
            encoding="utf-8"
        )
|
||||||
10
agent/memory/tools/__init__.py
Normal file
10
agent/memory/tools/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
"""
|
||||||
|
Memory tools for AgentMesh
|
||||||
|
|
||||||
|
Provides memory_search and memory_get tools for agents
|
||||||
|
"""
|
||||||
|
|
||||||
|
from agent.memory.tools.memory_search import MemorySearchTool
|
||||||
|
from agent.memory.tools.memory_get import MemoryGetTool
|
||||||
|
|
||||||
|
__all__ = ['MemorySearchTool', 'MemoryGetTool']
|
||||||
118
agent/memory/tools/memory_get.py
Normal file
118
agent/memory/tools/memory_get.py
Normal file
@@ -0,0 +1,118 @@
|
|||||||
|
"""
|
||||||
|
Memory get tool
|
||||||
|
|
||||||
|
Allows agents to read specific sections from memory files
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Dict, Any, Optional
|
||||||
|
from pathlib import Path
|
||||||
|
from agent.tools.base_tool import BaseTool
|
||||||
|
from agent.memory.manager import MemoryManager
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryGetTool(BaseTool):
    """Tool for reading memory file contents.

    Resolves a workspace-relative path supplied by the model and returns the
    requested window of lines, prefixed with a short metadata header.

    NOTE(review): this class exposes its schema as ``parameters`` and an async
    ``execute(**kwargs)``; confirm that BaseTool bridges these to whatever the
    agent loop calls.
    """

    def __init__(self, memory_manager: MemoryManager):
        """
        Initialize memory get tool

        Args:
            memory_manager: MemoryManager instance
        """
        super().__init__()
        self.memory_manager = memory_manager
        self._name = "memory_get"
        self._description = (
            "Read specific memory file content by path and line range. "
            "Use after memory_search to get full context from historical memory files."
        )

    @property
    def name(self) -> str:
        """Tool name advertised to the model."""
        return self._name

    @property
    def description(self) -> str:
        """Human/model readable description of the tool."""
        return self._description

    @property
    def parameters(self) -> Dict[str, Any]:
        """JSON-schema description of the accepted arguments."""
        return {
            "type": "object",
            "properties": {
                "path": {
                    "type": "string",
                    "description": "Relative path to the memory file (e.g., 'MEMORY.md', 'memory/2024-01-29.md')"
                },
                "start_line": {
                    "type": "integer",
                    "description": "Starting line number (optional, default: 1)",
                    "default": 1
                },
                "num_lines": {
                    "type": "integer",
                    "description": "Number of lines to read (optional, reads all if not specified)"
                }
            },
            "required": ["path"]
        }

    async def execute(self, **kwargs) -> str:
        """
        Execute memory file read

        Args:
            path: File path, relative to the memory workspace
            start_line: 1-based first line to return (values below 1 are clamped to 1)
            num_lines: Number of lines to return; missing or non-positive means
                "read to the end of the file"

        Returns:
            A header (file, line range, total line count) followed by the selected
            lines, or a string starting with "Error" when the request cannot be served.
        """
        path = kwargs.get("path")
        start_line = kwargs.get("start_line", 1)
        num_lines = kwargs.get("num_lines")

        if not path:
            return "Error: path parameter is required"

        try:
            workspace_dir = Path(self.memory_manager.config.get_workspace()).resolve()
            file_path = (workspace_dir / path).resolve()

            # The path is chosen by the model, so refuse anything that escapes the
            # workspace ('../..', absolute paths, symlinks pointing elsewhere).
            try:
                file_path.relative_to(workspace_dir)
            except ValueError:
                return f"Error: Path is outside the memory workspace: {path}"

            if not file_path.is_file():
                return f"Error: File not found: {path}"

            content = file_path.read_text(encoding="utf-8")
            lines = content.split('\n')

            # Normalise the requested window; an explicit null falls back to defaults.
            start_line = 1 if start_line is None else max(1, int(start_line))
            start_idx = start_line - 1

            if num_lines is not None and int(num_lines) > 0:
                selected_lines = lines[start_idx:start_idx + int(num_lines)]
            else:
                # A negative count must not turn into a "drop the last N lines" slice.
                selected_lines = lines[start_idx:]

            total_lines = len(lines)
            if selected_lines:
                last_line = start_line + len(selected_lines) - 1
                line_info = f"Lines: {start_line}-{last_line} (total: {total_lines})"
            else:
                line_info = f"Lines: none (start_line {start_line} is beyond total: {total_lines})"

            output = [
                f"File: {path}",
                line_info,
                "",
                '\n'.join(selected_lines)
            ]

            return '\n'.join(output)

        except Exception as e:
            # Tool boundary: report the failure to the model instead of raising.
            return f"Error reading memory file: {str(e)}"
|
||||||
106
agent/memory/tools/memory_search.py
Normal file
106
agent/memory/tools/memory_search.py
Normal file
@@ -0,0 +1,106 @@
|
|||||||
|
"""
|
||||||
|
Memory search tool
|
||||||
|
|
||||||
|
Allows agents to search their memory using semantic and keyword search
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Dict, Any, Optional
|
||||||
|
from agent.tools.base_tool import BaseTool
|
||||||
|
from agent.memory.manager import MemoryManager
|
||||||
|
|
||||||
|
|
||||||
|
class MemorySearchTool(BaseTool):
    """Agent-facing tool that queries the memory manager and renders hits as text.

    NOTE(review): this class exposes its schema as ``parameters`` and an async
    ``execute(**kwargs)``; confirm that BaseTool bridges these to whatever the
    agent loop calls.
    """

    def __init__(self, memory_manager: MemoryManager, user_id: Optional[str] = None):
        """
        Set up the search tool.

        Args:
            memory_manager: MemoryManager instance that performs the search
            user_id: Optional user ID forwarded to the search for scoping
        """
        super().__init__()
        self.memory_manager = memory_manager
        self.user_id = user_id
        self._name = "memory_search"
        description_parts = (
            "Search historical memory files (beyond today/yesterday) using semantic and keyword search. ",
            "Recent context (MEMORY.md + today + yesterday) is already loaded. ",
            "Use this ONLY for older dates, specific past events, or when current context lacks needed info.",
        )
        self._description = "".join(description_parts)

    @property
    def name(self) -> str:
        """Tool name advertised to the model."""
        return self._name

    @property
    def description(self) -> str:
        """Description advertised to the model."""
        return self._description

    @property
    def parameters(self) -> Dict[str, Any]:
        """JSON-schema description of the accepted arguments."""
        query_schema = {
            "type": "string",
            "description": "Search query (can be natural language question or keywords)"
        }
        max_results_schema = {
            "type": "integer",
            "description": "Maximum number of results to return (default: 10)",
            "default": 10
        }
        min_score_schema = {
            "type": "number",
            "description": "Minimum relevance score (0-1, default: 0.3)",
            "default": 0.3
        }
        return {
            "type": "object",
            "properties": {
                "query": query_schema,
                "max_results": max_results_schema,
                "min_score": min_score_schema
            },
            "required": ["query"]
        }

    async def execute(self, **kwargs) -> str:
        """
        Run a memory search and format the hits.

        Args:
            query: Search query
            max_results: Maximum results (default 10)
            min_score: Minimum score (default 0.3)

        Returns:
            Formatted search results, a "no results" notice, or an error string
        """
        query = kwargs.get("query")
        if not query:
            return "Error: query parameter is required"

        search_options = {
            "query": query,
            "user_id": self.user_id,
            "max_results": kwargs.get("max_results", 10),
            "min_score": kwargs.get("min_score", 0.3),
            "include_shared": True,
        }

        try:
            hits = await self.memory_manager.search(**search_options)
            if not hits:
                return f"No relevant memories found for query: {query}"

            # One header line, then three lines per hit.
            report = [f"Found {len(hits)} relevant memories:\n"]
            for rank, hit in enumerate(hits, 1):
                report.extend((
                    f"\n{rank}. {hit.path} (lines {hit.start_line}-{hit.end_line})",
                    f" Score: {hit.score:.3f}",
                    f" Snippet: {hit.snippet}",
                ))
            return "\n".join(report)

        except Exception as exc:
            return f"Error searching memory: {str(exc)}"
|
||||||
20
agent/protocol/__init__.py
Normal file
20
agent/protocol/__init__.py
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
from .agent import Agent
|
||||||
|
from .agent_stream import AgentStreamExecutor
|
||||||
|
from .task import Task, TaskType, TaskStatus
|
||||||
|
from .result import AgentResult, AgentAction, AgentActionType, ToolResult
|
||||||
|
from .models import LLMModel, LLMRequest, ModelFactory
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
'Agent',
|
||||||
|
'AgentStreamExecutor',
|
||||||
|
'Task',
|
||||||
|
'TaskType',
|
||||||
|
'TaskStatus',
|
||||||
|
'AgentResult',
|
||||||
|
'AgentAction',
|
||||||
|
'AgentActionType',
|
||||||
|
'ToolResult',
|
||||||
|
'LLMModel',
|
||||||
|
'LLMRequest',
|
||||||
|
'ModelFactory'
|
||||||
|
]
|
||||||
292
agent/protocol/agent.py
Normal file
292
agent/protocol/agent.py
Normal file
@@ -0,0 +1,292 @@
|
|||||||
|
import json
|
||||||
|
import time
|
||||||
|
|
||||||
|
from common.log import logger
|
||||||
|
from agent.protocol.models import LLMRequest, LLMModel
|
||||||
|
from agent.protocol.agent_stream import AgentStreamExecutor
|
||||||
|
from agent.protocol.result import AgentAction, AgentActionType, ToolResult, AgentResult
|
||||||
|
from agent.tools.base_tool import BaseTool, ToolStage
|
||||||
|
|
||||||
|
|
||||||
|
class Agent:
    """Single agent: owns the model, the tool list, the conversation history and
    the record of captured tool actions, and drives AgentStreamExecutor."""

    def __init__(self, system_prompt: str, description: str = "AI Agent", model: "LLMModel" = None,
                 tools=None, output_mode="print", max_steps=100, max_context_tokens=None,
                 context_reserve_tokens=None, memory_manager=None, name: str = None):
        """
        Initialize the Agent with system prompt, model, description.

        :param system_prompt: The system prompt for the agent.
        :param description: A description of the agent.
        :param model: An instance of LLMModel to be used by the agent.
        :param tools: Optional list of tools for the agent to use.
        :param output_mode: Control how execution progress is displayed:
                           "print" for console output or "logger" for using logger
        :param max_steps: Maximum number of steps the agent can take (default: 100)
        :param max_context_tokens: Maximum tokens to keep in context (default: None, auto-calculated based on model)
        :param context_reserve_tokens: Reserve tokens for new requests (default: None, auto-calculated)
        :param memory_manager: Optional MemoryManager instance for memory operations
        :param name: [Deprecated] The name of the agent (no longer used in single-agent system)
        """
        self.name = name or "Agent"
        self.system_prompt = system_prompt
        self.model: LLMModel = model  # Instance of LLMModel
        self.description = description
        self.tools: list = []
        self.max_steps = max_steps  # max tool-call steps, default 100
        self.max_context_tokens = max_context_tokens  # max tokens in context
        self.context_reserve_tokens = context_reserve_tokens  # reserve tokens for new requests
        self.captured_actions = []  # Initialize captured actions list
        self.output_mode = output_mode
        self.last_usage = None  # Store last API response usage info
        self.messages = []  # Unified message history for stream mode
        self.memory_manager = memory_manager  # Memory manager for auto memory flush
        if tools:
            for tool in tools:
                self.add_tool(tool)

    def add_tool(self, tool: "BaseTool"):
        """
        Add a tool to the agent.

        :param tool: The tool instance to add; its ``model`` attribute is set to the agent's model
        """
        tool.model = self.model
        self.tools.append(tool)

    def _get_model_context_window(self) -> int:
        """
        Get the model's context window size in tokens.
        Auto-detect based on model name.

        Model context windows used here:
        - Claude 3.x / Claude Sonnet: 200K tokens
        - GPT-4o / GPT-4.1 / GPT-4 Turbo / GPT-4 *-preview / 128K variants: 128K tokens
          (a conservative lower bound for the newer families)
        - GPT-4 32K: 32K tokens, other GPT-4: 8K tokens
        - GPT-3.5: 16K for "16k" variants, otherwise 4K tokens
        - DeepSeek: 64K tokens

        :return: Context window size in tokens (10000 when the model is unknown)
        """
        default_window = 10000  # conservative value for unknown models

        raw_name = getattr(self.model, 'model', None) if self.model else None
        if not isinstance(raw_name, str):
            # No model, or a model object without a usable name.
            return default_window
        model_name = raw_name.lower()

        # Claude models - 200K context
        if 'claude-3' in model_name or 'claude-sonnet' in model_name:
            return 200000

        # GPT-4 family: 'gpt-4o' and 'gpt-4.1' also contain 'gpt-4', so they must be
        # recognised here or they would fall through to the 8K legacy value.
        if 'gpt-4' in model_name:
            large_markers = ('gpt-4o', 'gpt-4.1', 'turbo', '128k', 'preview')
            if any(marker in model_name for marker in large_markers):
                return 128000
            if '32k' in model_name:
                return 32000
            return 8000

        # GPT-3.5
        if 'gpt-3.5' in model_name:
            return 16000 if '16k' in model_name else 4000

        # DeepSeek
        if 'deepseek' in model_name:
            return 64000

        return default_window

    def _get_context_reserve_tokens(self) -> int:
        """
        Get the number of tokens to reserve for new requests.
        This prevents context overflow by keeping a buffer.

        :return: The configured ``context_reserve_tokens`` when set, otherwise
                 20% of the context window with a floor of 4000
        """
        if self.context_reserve_tokens is not None:
            return self.context_reserve_tokens

        # Reserve ~20% of context window for new requests
        context_window = self._get_model_context_window()
        return max(4000, int(context_window * 0.2))

    def _estimate_block_chars(self, block) -> int:
        """Return the character count one content block contributes to the estimate."""
        if isinstance(block, str):
            return len(block)
        if not isinstance(block, dict):
            return 0

        block_type = block.get('type')
        if block_type == 'text':
            return len(block.get('text') or '')
        if block_type == 'image':
            # Estimate images as ~1200 tokens
            return 4800
        if block_type == 'tool_use':
            # The serialized arguments are what is sent, so measure those.
            serialized = json.dumps(block.get('input', {}), ensure_ascii=False, default=str)
            return len(str(block.get('name') or '')) + len(serialized)
        if block_type == 'tool_result':
            inner = block.get('content', '')
            if isinstance(inner, str):
                return len(inner)
            if isinstance(inner, list):
                return sum(self._estimate_block_chars(part) for part in inner)
            return len(str(inner))
        return 0

    def _estimate_message_tokens(self, message: dict) -> int:
        """
        Estimate token count for a message using chars/4 heuristic.
        This is a rough approximation, not a tokenizer-accurate count.

        Text, image, tool_use and tool_result content blocks are all counted.

        :param message: Message dict with 'role' and 'content'
        :return: Estimated token count (always at least 1)
        """
        content = message.get('content', '')
        if isinstance(content, str):
            return max(1, len(content) // 4)
        if isinstance(content, list):
            # Handle multi-part content (text, images, tool calls and tool results)
            total_chars = sum(self._estimate_block_chars(part) for part in content)
            return max(1, total_chars // 4)
        return 1

    def _find_tool(self, tool_name: str):
        """Find and return a callable (pre-process stage) tool with the specified name, else None."""
        for tool in self.tools:
            if tool.name != tool_name:
                continue
            # Only pre-process stage tools can be actively called
            if tool.stage == ToolStage.PRE_PROCESS:
                tool.model = self.model
                tool.context = self  # Set tool context
                return tool
            # If it's a post-process tool, return None to prevent direct calling
            logger.warning(f"Tool {tool_name} is a post-process tool and cannot be called directly.")
            return None
        return None

    # output function based on mode
    def output(self, message="", end="\n"):
        """Print the message in "print" mode; otherwise log non-empty messages at info level."""
        if self.output_mode == "print":
            print(message, end=end)
        elif message:
            logger.info(message)

    def _execute_post_process_tools(self):
        """Execute all post-process stage tools, capturing and reporting each result."""
        # Get all post-process stage tools
        post_process_tools = [tool for tool in self.tools if tool.stage == ToolStage.POST_PROCESS]

        for tool in post_process_tools:
            # Set tool context
            tool.context = self

            # Record start time for execution timing
            start_time = time.time()

            # Execute tool (with empty parameters, tool will extract needed info from context)
            result = tool.execute({})

            # Calculate execution time
            execution_time = time.time() - start_time

            # Capture tool use for tracking
            self.capture_tool_use(
                tool_name=tool.name,
                input_params={},  # Post-process tools typically don't take parameters
                output=result.result,
                status=result.status,
                error_message=str(result.result) if result.status == "error" else None,
                execution_time=execution_time
            )

            # Report the outcome. default=str / ensure_ascii=False keep this progress
            # line from raising on non-JSON results and keep non-ASCII text readable.
            if result.status == "success":
                payload = result.result
            else:
                payload = {'status': 'error', 'message': str(result.result)}
            self.output(f"\n🛠️ {tool.name}: {json.dumps(payload, ensure_ascii=False, default=str)}")

    def capture_tool_use(self, tool_name, input_params, output, status, thought=None, error_message=None,
                         execution_time=0.0):
        """
        Capture a tool use action.

        :param thought: thought content
        :param tool_name: Name of the tool used
        :param input_params: Parameters passed to the tool
        :param output: Output from the tool
        :param status: Status of the tool execution
        :param error_message: Error message if the tool execution failed
        :param execution_time: Time taken to execute the tool
        :return: The AgentAction that was appended to ``captured_actions``
        """
        tool_result = ToolResult(
            tool_name=tool_name,
            input_params=input_params,
            output=output,
            status=status,
            error_message=error_message,
            execution_time=execution_time
        )

        action = AgentAction(
            # Fall back to the object identity when no explicit id attribute exists.
            agent_id=self.id if hasattr(self, 'id') else str(id(self)),
            agent_name=self.name,
            action_type=AgentActionType.TOOL_USE,
            tool_result=tool_result,
            thought=thought
        )

        self.captured_actions.append(action)

        return action

    def run_stream(self, user_message: str, on_event=None, clear_history: bool = False) -> str:
        """
        Execute single agent task with streaming (based on tool-call)

        This method supports:
        - Streaming output
        - Multi-turn reasoning based on tool-call
        - Event callbacks
        - Persistent conversation history across calls

        Args:
            user_message: User message
            on_event: Event callback function callback(event: dict)
                     event = {"type": str, "timestamp": float, "data": dict}
            clear_history: If True, clear conversation history before this call (default: False)

        Returns:
            Final response text

        Raises:
            ValueError: If the agent has no model

        Example:
            # Multi-turn conversation with memory
            response1 = agent.run_stream("My name is Alice")
            response2 = agent.run_stream("What's my name?")  # Will remember Alice

            # Single-turn without memory
            response = agent.run_stream("Hello", clear_history=True)
        """
        # Clear history if requested
        if clear_history:
            self.messages = []

        # Get model to use
        if not self.model:
            raise ValueError("No model available for agent")

        # Create stream executor with agent's message history
        executor = AgentStreamExecutor(
            agent=self,
            model=self.model,
            system_prompt=self.system_prompt,
            tools=self.tools,
            max_turns=self.max_steps,
            on_event=on_event,
            messages=self.messages  # Pass agent's message history
        )

        # Execute
        response = executor.run_stream(user_message)

        # Update agent's message history from executor
        self.messages = executor.messages

        # Execute all post-process tools
        self._execute_post_process_tools()

        return response

    def clear_history(self):
        """Clear conversation history and captured actions"""
        self.messages = []
        self.captured_actions = []
|
||||||
461
agent/protocol/agent_stream.py
Normal file
461
agent/protocol/agent_stream.py
Normal file
@@ -0,0 +1,461 @@
|
|||||||
|
"""
|
||||||
|
Agent Stream Execution Module - Multi-turn reasoning based on tool-call
|
||||||
|
|
||||||
|
Provides streaming output, event system, and complete tool-call loop
|
||||||
|
"""
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
from typing import List, Dict, Any, Optional, Callable
|
||||||
|
|
||||||
|
from common.log import logger
|
||||||
|
from agent.protocol.models import LLMRequest, LLMModel
|
||||||
|
from agent.tools.base_tool import BaseTool, ToolResult
|
||||||
|
|
||||||
|
|
||||||
|
class AgentStreamExecutor:
|
||||||
|
"""
|
||||||
|
Agent Stream Executor
|
||||||
|
|
||||||
|
Handles multi-turn reasoning loop based on tool-call:
|
||||||
|
1. LLM generates response (may include tool calls)
|
||||||
|
2. Execute tools
|
||||||
|
3. Return results to LLM
|
||||||
|
4. Repeat until no more tool calls
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
    self,
    agent,  # Agent instance
    model: "LLMModel",
    system_prompt: str,
    tools: "List[BaseTool]",
    max_turns: int = 50,
    on_event: Optional[Callable] = None,
    messages: Optional[List[Dict]] = None
):
    """
    Initialize stream executor

    Args:
        agent: Agent instance (for accessing context)
        model: LLM model
        system_prompt: System prompt
        tools: Available tools. A list/tuple/set is indexed by each tool's
            ``name``; ``None`` becomes an empty registry; any other object
            (e.g. an existing name -> tool dict) is stored unchanged.
        max_turns: Maximum number of turns
        on_event: Event callback function
        messages: Optional existing message history (for persistent conversations);
            the same list object is reused, so appends are visible to the caller
    """
    self.agent = agent
    self.model = model
    self.system_prompt = system_prompt

    # Normalise the tools into a name -> tool dict so later `.values()` / `.get()`
    # calls work regardless of how the caller passed them.
    if tools is None:
        self.tools = {}
    elif isinstance(tools, (list, tuple, set, frozenset)):
        self.tools = {tool.name: tool for tool in tools}
    else:
        self.tools = tools

    self.max_turns = max_turns
    self.on_event = on_event

    # Message history - use provided messages or create new list
    self.messages = messages if messages is not None else []
|
||||||
|
|
||||||
|
def _emit_event(self, event_type: str, data: dict = None):
    """Send an event dict (type, timestamp, data) to the registered callback, if any.

    Exceptions raised by the callback are logged and not propagated.
    """
    callback = self.on_event
    if not callback:
        return

    try:
        event = {
            "type": event_type,
            "timestamp": time.time(),
            "data": data or {}
        }
        callback(event)
    except Exception as e:
        logger.error(f"Event callback error: {e}")
|
||||||
|
|
||||||
|
def run_stream(self, user_message: str) -> str:
    """
    Execute streaming reasoning loop

    Appends the user message to the history, then alternates between calling
    the LLM and executing the requested tools until the LLM answers without
    tool calls or ``max_turns`` is reached.

    Args:
        user_message: User message

    Returns:
        Final response text (the text of the last LLM turn)

    Raises:
        Exception: Any error from the LLM call is logged, emitted as an
            "error" event and re-raised; "agent_end" is emitted either way.
    """
    # Log user message
    logger.info(f"\n{'='*50}")
    logger.info(f"👤 用户: {user_message}")
    logger.info(f"{'='*50}")

    # Add user message (Claude format - use content blocks for consistency)
    self.messages.append({
        "role": "user",
        "content": [
            {
                "type": "text",
                "text": user_message
            }
        ]
    })

    self._emit_event("agent_start")

    final_response = ""
    turn = 0

    try:
        while turn < self.max_turns:
            turn += 1
            logger.info(f"\n{'='*50} 第 {turn} 轮 {'='*50}")
            self._emit_event("turn_start", {"turn": turn})

            # Check if memory flush is needed (before calling LLM)
            usage = getattr(self.agent, 'last_usage', None)
            if self.agent.memory_manager and usage and 'input_tokens' in usage:
                current_tokens = usage.get('input_tokens', 0)
                context_window = self.agent._get_model_context_window()

                # Use the agent's own reserve policy so the threshold scales with the
                # model's window; a fixed 20000 exceeds the window of small models.
                reserve_getter = getattr(self.agent, '_get_context_reserve_tokens', None)
                if callable(reserve_getter):
                    reserve_tokens = reserve_getter()
                else:
                    reserve_tokens = self.agent.context_reserve_tokens or 20000

                if self.agent.memory_manager.should_flush_memory(
                    current_tokens=current_tokens,
                    context_window=context_window,
                    reserve_tokens=reserve_tokens
                ):
                    self._emit_event("memory_flush_start", {
                        "current_tokens": current_tokens,
                        "threshold": context_window - reserve_tokens - 4000
                    })

                    # TODO: Execute memory flush in background
                    # This would require async support
                    logger.info(f"Memory flush recommended at {current_tokens} tokens")

            # Call LLM
            assistant_msg, tool_calls = self._call_llm_stream()
            final_response = assistant_msg

            # No tool calls, end loop
            if not tool_calls:
                if assistant_msg:
                    logger.info(f"💭 {assistant_msg[:150]}{'...' if len(assistant_msg) > 150 else ''}")
                logger.info(f"✅ 完成 (无工具调用)")
                self._emit_event("turn_end", {
                    "turn": turn,
                    "has_tool_calls": False
                })
                break

            # Log tool calls in compact format
            tool_names = [tc['name'] for tc in tool_calls]
            logger.info(f"🔧 调用工具: {', '.join(tool_names)}")

            # Execute tools and collect one tool_result block per call
            tool_result_blocks = []

            for tool_call in tool_calls:
                result = self._execute_tool(tool_call)

                # Log tool result in compact format
                status_emoji = "✅" if result.get("status") == "success" else "❌"
                result_str = str(result.get('result', ''))
                logger.info(f" {status_emoji} {tool_call['name']} ({result.get('execution_time', 0):.2f}s): {result_str[:200]}{'...' if len(result_str) > 200 else ''}")

                # Build tool result block (Claude format); content must be a string.
                # ensure_ascii=False keeps non-ASCII text unescaped, and default=str
                # stops a non-JSON-serializable tool result from aborting the loop.
                if isinstance(result, str):
                    result_content = result
                else:
                    result_content = json.dumps(result, ensure_ascii=False, default=str)
                tool_result_blocks.append({
                    "type": "tool_result",
                    "tool_use_id": tool_call["id"],
                    "content": result_content
                })

            # Add tool results to message history as user message (Claude format)
            self.messages.append({
                "role": "user",
                "content": tool_result_blocks
            })

            self._emit_event("turn_end", {
                "turn": turn,
                "has_tool_calls": True,
                "tool_count": len(tool_calls)
            })

        if turn >= self.max_turns:
            logger.warning(f"⚠️ 已达到最大轮数限制: {self.max_turns}")

    except Exception as e:
        logger.error(f"❌ Agent执行错误: {e}")
        self._emit_event("error", {"error": str(e)})
        raise

    finally:
        logger.info(f"{'='*50} 完成({turn}轮) {'='*50}\n")
        self._emit_event("agent_end", {"final_response": final_response})

    return final_response
|
||||||
|
|
||||||
|
def _call_llm_stream(self) -> tuple[str, List[Dict]]:
    """
    Call LLM with streaming

    Trims and prepares the history, streams one model response, accumulates
    text and tool-call deltas, appends the resulting assistant message to
    ``self.messages`` and emits message_start / message_update / message_end.

    Returns:
        (response_text, tool_calls) where each tool call is
        {"id": str, "name": str, "arguments": dict}

    Raises:
        Exception: When the stream yields an error chunk or the model call
            fails; the error is logged before being re-raised.
    """
    # Trim messages if needed (using agent's context management)
    self._trim_messages()

    # Prepare messages
    messages = self._prepare_messages()

    # Debug: log message structure
    logger.debug(f"Sending {len(messages)} messages to LLM")
    for i, msg in enumerate(messages):
        role = msg.get("role", "unknown")
        content = msg.get("content", "")
        if isinstance(content, list):
            content_types = [c.get("type") for c in content if isinstance(c, dict)]
            logger.debug(f" Message {i}: role={role}, content_blocks={content_types}")
        else:
            logger.debug(f" Message {i}: role={role}, content_length={len(str(content))}")

    # Prepare tool definitions (OpenAI/Claude format)
    tools_schema = None
    if self.tools:
        tools_schema = []
        for tool in self.tools.values():
            # Some tools expose their schema as `parameters` rather than `params`;
            # fall back to it when `params` is missing or empty.
            input_schema = getattr(tool, "params", None) or getattr(tool, "parameters", None) or {}
            tools_schema.append({
                "name": tool.name,
                "description": tool.description,
                "input_schema": input_schema  # Claude uses input_schema
            })

    # Create request
    request = LLMRequest(
        messages=messages,
        temperature=0,
        stream=True,
        tools=tools_schema,
        system=self.system_prompt  # Pass system prompt separately for Claude API
    )

    self._emit_event("message_start", {"role": "assistant"})

    # Streaming response
    full_content = ""
    tool_calls_buffer = {}  # {index: {id, name, arguments}}

    try:
        stream = self.model.call_stream(request)

        for chunk in stream:
            if not isinstance(chunk, dict):
                continue

            # Check for errors
            if chunk.get("error"):
                error_msg = chunk.get("message", "Unknown error")
                status_code = chunk.get("status_code", "N/A")
                logger.error(f"API Error: {error_msg} (Status: {status_code})")
                logger.error(f"Full error chunk: {chunk}")
                raise Exception(f"{error_msg} (Status: {status_code})")

            # Skip chunks that carry no choices (missing key or empty list).
            choices = chunk.get("choices")
            if not choices:
                continue
            delta = choices[0].get("delta") or {}

            # Handle text content
            content_delta = delta.get("content")
            if content_delta:
                full_content += content_delta
                self._emit_event("message_update", {"delta": content_delta})

            # Handle tool calls
            for tc_delta in delta.get("tool_calls") or []:
                index = tc_delta.get("index") or 0
                entry = tool_calls_buffer.setdefault(index, {
                    "id": "",
                    "name": "",
                    "arguments": ""
                })

                # A delta may carry id/name/arguments as None or "" rather than
                # omitting the key; only truthy values may update the buffer,
                # otherwise an already captured id or name would be wiped.
                if tc_delta.get("id"):
                    entry["id"] = tc_delta["id"]

                func = tc_delta.get("function") or {}
                if func.get("name"):
                    entry["name"] = func["name"]
                if func.get("arguments"):
                    entry["arguments"] += func["arguments"]

    except Exception as e:
        logger.error(f"LLM call error: {e}")
        raise

    # Parse tool calls
    tool_calls = []
    for idx in sorted(tool_calls_buffer.keys()):
        tc = tool_calls_buffer[idx]
        try:
            arguments = json.loads(tc["arguments"]) if tc["arguments"] else {}
        except json.JSONDecodeError:
            # Malformed arguments: log and fall back to an empty argument dict.
            logger.error(f"Failed to parse tool arguments: {tc['arguments']}")
            arguments = {}

        tool_calls.append({
            "id": tc["id"],
            "name": tc["name"],
            "arguments": arguments
        })

    # Add assistant message to history (Claude format uses content blocks)
    assistant_msg = {"role": "assistant", "content": []}

    # Add text content block if present
    if full_content:
        assistant_msg["content"].append({
            "type": "text",
            "text": full_content
        })

    # Add tool_use blocks if present
    for tc in tool_calls:
        assistant_msg["content"].append({
            "type": "tool_use",
            "id": tc["id"],
            "name": tc["name"],
            "input": tc["arguments"]
        })

    # Only append if content is not empty
    if assistant_msg["content"]:
        self.messages.append(assistant_msg)

    self._emit_event("message_end", {
        "content": full_content,
        "tool_calls": tool_calls
    })

    return full_content, tool_calls
|
||||||
|
|
||||||
|
def _execute_tool(self, tool_call: Dict) -> Dict[str, Any]:
    """
    Execute a single tool call and report its outcome.

    Emits "tool_execution_start" before the lookup and "tool_execution_end"
    once the outcome is known. Exceptions raised while looking up or running
    the tool are logged and converted into an error result instead of
    propagating.

    Args:
        tool_call: {"id": str, "name": str, "arguments": dict}

    Returns:
        Dict with "status", "result" and "execution_time" (seconds) keys
    """
    tool_name = tool_call["name"]
    tool_id = tool_call["id"]
    arguments = tool_call["arguments"]

    self._emit_event("tool_execution_start", {
        "tool_call_id": tool_id,
        "tool_name": tool_name,
        "arguments": arguments
    })

    # Stays None until the tool actually starts running, so a lookup
    # failure is still reported with an execution time of 0
    start_time = None

    try:
        tool = self.tools.get(tool_name)
        if not tool:
            raise ValueError(f"Tool '{tool_name}' not found")

        # Set tool context
        tool.model = self.model
        tool.context = self.agent

        # Execute tool
        start_time = time.time()
        result = tool.execute_tool(arguments)
        execution_time = time.time() - start_time

        # A tool that returns nothing would otherwise fail below with an
        # opaque "'NoneType' object has no attribute 'status'" message
        if result is None:
            raise RuntimeError(f"Tool '{tool_name}' returned no result")

        result_dict = {
            "status": result.status,
            "result": result.result,
            "execution_time": execution_time
        }

        self._emit_event("tool_execution_end", {
            "tool_call_id": tool_id,
            "tool_name": tool_name,
            **result_dict
        })

        return result_dict

    except Exception as e:
        logger.error(f"Tool execution error: {e}")
        error_result = {
            "status": "error",
            "result": str(e),
            "execution_time": time.time() - start_time if start_time is not None else 0
        }
        self._emit_event("tool_execution_end", {
            "tool_call_id": tool_id,
            "tool_name": tool_name,
            **error_result
        })
        return error_result
|
||||||
|
|
||||||
|
def _trim_messages(self):
|
||||||
|
"""
|
||||||
|
Trim message history to stay within context limits.
|
||||||
|
Uses agent's context management configuration.
|
||||||
|
"""
|
||||||
|
if not self.messages or not self.agent:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Get context window and reserve tokens from agent
|
||||||
|
context_window = self.agent._get_model_context_window()
|
||||||
|
reserve_tokens = self.agent._get_context_reserve_tokens()
|
||||||
|
max_tokens = context_window - reserve_tokens
|
||||||
|
|
||||||
|
# Estimate current tokens
|
||||||
|
current_tokens = sum(self.agent._estimate_message_tokens(msg) for msg in self.messages)
|
||||||
|
|
||||||
|
# Add system prompt tokens
|
||||||
|
system_tokens = self.agent._estimate_message_tokens({"role": "system", "content": self.system_prompt})
|
||||||
|
current_tokens += system_tokens
|
||||||
|
|
||||||
|
# If under limit, no need to trim
|
||||||
|
if current_tokens <= max_tokens:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Keep messages from newest, accumulating tokens
|
||||||
|
available_tokens = max_tokens - system_tokens
|
||||||
|
kept_messages = []
|
||||||
|
accumulated_tokens = 0
|
||||||
|
|
||||||
|
for msg in reversed(self.messages):
|
||||||
|
msg_tokens = self.agent._estimate_message_tokens(msg)
|
||||||
|
if accumulated_tokens + msg_tokens <= available_tokens:
|
||||||
|
kept_messages.insert(0, msg)
|
||||||
|
accumulated_tokens += msg_tokens
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
|
||||||
|
old_count = len(self.messages)
|
||||||
|
self.messages = kept_messages
|
||||||
|
new_count = len(self.messages)
|
||||||
|
|
||||||
|
if old_count > new_count:
|
||||||
|
logger.info(
|
||||||
|
f"Context trimmed: {old_count} -> {new_count} messages "
|
||||||
|
f"(~{current_tokens} -> ~{system_tokens + accumulated_tokens} tokens, "
|
||||||
|
f"limit: {max_tokens})"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _prepare_messages(self) -> List[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Prepare messages to send to LLM
|
||||||
|
|
||||||
|
Note: For Claude API, system prompt should be passed separately via system parameter,
|
||||||
|
not as a message. The AgentLLMModel will handle this.
|
||||||
|
"""
|
||||||
|
# Don't add system message here - it will be handled separately by the LLM adapter
|
||||||
|
return self.messages
|
||||||
27
agent/protocol/context.py
Normal file
27
agent/protocol/context.py
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
class TeamContext:
    """Holds a team's configuration together with its run-time state."""

    def __init__(self, name: str, description: str, rule: str, agents: list, max_steps: int = 100):
        """
        Initialize the TeamContext with a name, description, rules and a list of agents.

        :param name: The name of the group context.
        :param description: A description of the group context.
        :param rule: The rules governing the group context.
        :param agents: A list of agents in the context.
        :param max_steps: Step limit stored on the context (default 100).
        """
        # Configuration supplied by the caller
        self.name, self.description, self.rule = name, description, rule
        self.agents = agents
        self.max_steps = max_steps

        # Task-related state, filled in after construction
        self.user_task = ""  # For backward compatibility
        self.task = None  # Will be a Task instance
        self.model = None  # Will be an instance of LLMModel
        self.task_short_name = None  # Store the task directory name

        # Execution bookkeeping: outputs of agents that have run, and a step counter
        self.agent_outputs: list = []
        self.current_steps = 0
|
||||||
|
|
||||||
|
|
||||||
|
class AgentOutput:
    """Pairs an agent's name with the output it produced."""

    def __init__(self, agent_name: str, output: str):
        # Both values are stored unchanged
        self.agent_name, self.output = agent_name, output
|
||||||
57
agent/protocol/models.py
Normal file
57
agent/protocol/models.py
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
"""
|
||||||
|
Models module for agent system.
|
||||||
|
Provides basic model classes needed by tools and bridge integration.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
|
||||||
|
class LLMRequest:
    """Request model for LLM operations"""

    def __init__(self, messages: List[Dict[str, str]] = None, model: Optional[str] = None,
                 temperature: float = 0.7, max_tokens: Optional[int] = None,
                 stream: bool = False, tools: Optional[List] = None, **kwargs):
        """
        Store the request parameters as attributes.

        A missing or empty ``messages`` value becomes a new empty list. Every
        extra keyword argument is set on the instance as an attribute of the
        same name.
        """
        self.messages = messages if messages else []
        self.model = model
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.stream = stream
        self.tools = tools

        # Allow extra attributes
        for attr_name in kwargs:
            setattr(self, attr_name, kwargs[attr_name])
|
||||||
|
|
||||||
|
|
||||||
|
class LLMModel:
    """Base class for LLM models"""

    def __init__(self, model: str = None, **kwargs):
        """Remember the model name; all extra keyword arguments are kept in ``config``."""
        self.model, self.config = model, kwargs

    def call(self, request: LLMRequest):
        """
        Call the model with a request.

        Placeholder only: always raises NotImplementedError.
        """
        message = "LLMModel.call not implemented in this context"
        raise NotImplementedError(message)

    def call_stream(self, request: LLMRequest):
        """
        Call the model with streaming.

        Placeholder only: always raises NotImplementedError.
        """
        message = "LLMModel.call_stream not implemented in this context"
        raise NotImplementedError(message)
|
||||||
|
|
||||||
|
|
||||||
|
class ModelFactory:
    """Factory for creating model instances"""

    @staticmethod
    def create_model(model_type: str, **kwargs):
        """
        Create a model instance based on type.

        Placeholder only: always raises NotImplementedError.
        """
        message = "ModelFactory.create_model not implemented in this context"
        raise NotImplementedError(message)
|
||||||
96
agent/protocol/result.py
Normal file
96
agent/protocol/result.py
Normal file
@@ -0,0 +1,96 @@
|
|||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from enum import Enum
|
||||||
|
from typing import List, Dict, Any, Optional
|
||||||
|
|
||||||
|
from agent.protocol.task import Task, TaskStatus
|
||||||
|
|
||||||
|
|
||||||
|
class AgentActionType(Enum):
    """Kinds of action an agent can record."""

    # The agent used a tool
    TOOL_USE = "tool_use"
    # The agent produced thought content
    THINKING = "thinking"
    # The agent produced its final answer
    FINAL_ANSWER = "final_answer"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class ToolResult:
    """
    Record of a single tool use.

    The first four fields are required; the remaining two have defaults.

    Attributes:
        tool_name: Name of the tool used
        input_params: Parameters passed to the tool
        output: Output from the tool
        status: Status of the tool execution (success/error)
        error_message: Error message if the tool execution failed
        execution_time: Time taken to execute the tool
    """
    # Required
    tool_name: str
    input_params: Dict[str, Any]
    output: Any
    status: str
    # Optional
    error_message: Optional[str] = None
    execution_time: float = 0.0
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class AgentAction:
    """
    One action taken by an agent.

    ``agent_id``, ``agent_name`` and ``action_type`` are required; every other
    field has a default.

    Attributes:
        id: Unique identifier for the action (a fresh UUID4 string by default)
        agent_id: ID of the agent that performed the action
        agent_name: Name of the agent that performed the action
        action_type: Type of action (tool use, thinking, final answer)
        content: Content of the action (thought content, final answer content)
        tool_result: Tool use details if action_type is TOOL_USE
        thought: Optional thought text attached to the action
        timestamp: When the action was performed (``time.time()`` at creation by default)
    """
    # Required
    agent_id: str
    agent_name: str
    action_type: AgentActionType
    # Defaulted
    id: str = field(default_factory=lambda: str(uuid.uuid4()))
    content: str = ""
    tool_result: Optional[ToolResult] = None
    thought: Optional[str] = None
    timestamp: float = field(default_factory=time.time)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class AgentResult:
    """
    Outcome of an agent's execution.

    Attributes:
        final_answer: The final answer provided by the agent
        step_count: Number of steps taken by the agent
        status: Status of the execution (success/error)
        error_message: Error message if execution failed
    """
    final_answer: str
    step_count: int
    status: str = "success"
    error_message: Optional[str] = None

    @classmethod
    def success(cls, final_answer: str, step_count: int) -> "AgentResult":
        """Build a result with the default "success" status and no error message"""
        return cls(final_answer, step_count)

    @classmethod
    def error(cls, error_message: str, step_count: int = 0) -> "AgentResult":
        """Build an "error" result whose final answer is the message prefixed with "Error: " """
        answer = f"Error: {error_message}"
        return cls(answer, step_count, "error", error_message)

    @property
    def is_error(self) -> bool:
        """True when the status is "error" """
        return "error" == self.status
|
||||||
95
agent/protocol/task.py
Normal file
95
agent/protocol/task.py
Normal file
@@ -0,0 +1,95 @@
|
|||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Dict, Any, List
|
||||||
|
|
||||||
|
|
||||||
|
class TaskType(Enum):
    """Content kinds a task can carry."""

    # Single-medium tasks
    TEXT = "text"
    IMAGE = "image"
    VIDEO = "video"
    AUDIO = "audio"
    FILE = "file"
    # Declared last; the name suggests a combination of the kinds above
    MIXED = "mixed"
|
||||||
|
|
||||||
|
|
||||||
|
class TaskStatus(Enum):
    """Lifecycle states of a task."""

    # Initial state
    INIT = "init"
    # In progress
    PROCESSING = "processing"
    # Completed
    COMPLETED = "completed"
    # Failed
    FAILED = "failed"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class Task:
    """
    A task to be processed by an agent.

    The class defines its own ``__init__``, so the dataclass decorator does
    not generate one; the field declarations below still drive the other
    generated methods.

    Attributes:
        id: Unique identifier for the task
        content: The primary text content of the task
        type: Type of the task
        status: Current status of the task
        created_at: Timestamp when the task was created
        updated_at: Timestamp when the task was last updated
        metadata: Additional metadata for the task
        images: List of image URLs or base64 encoded images
        videos: List of video URLs
        audios: List of audio URLs or base64 encoded audios
        files: List of file URLs or paths
    """
    id: str = field(default_factory=lambda: str(uuid.uuid4()))
    content: str = ""
    type: TaskType = TaskType.TEXT
    status: TaskStatus = TaskStatus.INIT
    created_at: float = field(default_factory=time.time)
    updated_at: float = field(default_factory=time.time)
    metadata: Dict[str, Any] = field(default_factory=dict)

    # Media content
    images: List[str] = field(default_factory=list)
    videos: List[str] = field(default_factory=list)
    audios: List[str] = field(default_factory=list)
    files: List[str] = field(default_factory=list)

    def __init__(self, content: str = "", **kwargs):
        """
        Create a task from its text content plus optional overrides.

        Args:
            content: The text content of the task
            **kwargs: Optional values for id, type, status, created_at,
                updated_at, metadata, images, videos, audios and files;
                any other keyword is ignored
        """
        option = kwargs.get

        # Identity and content
        self.id = option('id', str(uuid.uuid4()))
        self.content = content

        # Classification and lifecycle
        self.type = option('type', TaskType.TEXT)
        self.status = option('status', TaskStatus.INIT)
        self.created_at = option('created_at', time.time())
        self.updated_at = option('updated_at', time.time())
        self.metadata = option('metadata', {})

        # Media attachments; each default is a fresh empty list
        self.images = option('images', [])
        self.videos = option('videos', [])
        self.audios = option('audios', [])
        self.files = option('files', [])

    def get_text(self) -> str:
        """
        Return the text content of the task.

        Returns:
            The text content
        """
        text = self.content
        return text

    def update_status(self, status: TaskStatus) -> None:
        """
        Set a new status and refresh ``updated_at`` to the current time.

        Args:
            status: The new status
        """
        self.status = status
        self.updated_at = time.time()
|
||||||
101
agent/tools/__init__.py
Normal file
101
agent/tools/__init__.py
Normal file
@@ -0,0 +1,101 @@
|
|||||||
|
# Import base tool
|
||||||
|
from agent.tools.base_tool import BaseTool
|
||||||
|
from agent.tools.tool_manager import ToolManager
|
||||||
|
|
||||||
|
# Import basic tools (no external dependencies)
|
||||||
|
from agent.tools.calculator.calculator import Calculator
|
||||||
|
from agent.tools.current_time.current_time import CurrentTime
|
||||||
|
|
||||||
|
# Import file operation tools
|
||||||
|
from agent.tools.read.read import Read
|
||||||
|
from agent.tools.write.write import Write
|
||||||
|
from agent.tools.edit.edit import Edit
|
||||||
|
from agent.tools.bash.bash import Bash
|
||||||
|
from agent.tools.grep.grep import Grep
|
||||||
|
from agent.tools.find.find import Find
|
||||||
|
from agent.tools.ls.ls import Ls
|
||||||
|
|
||||||
|
# Import memory tools
|
||||||
|
from agent.tools.memory.memory_search import MemorySearchTool
|
||||||
|
from agent.tools.memory.memory_get import MemoryGetTool
|
||||||
|
|
||||||
|
# Import tools with optional dependencies
|
||||||
|
def _import_optional_tools():
|
||||||
|
"""Import tools that have optional dependencies"""
|
||||||
|
tools = {}
|
||||||
|
|
||||||
|
# Google Search (requires requests)
|
||||||
|
try:
|
||||||
|
from agent.tools.google_search.google_search import GoogleSearch
|
||||||
|
tools['GoogleSearch'] = GoogleSearch
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# File Save (may have dependencies)
|
||||||
|
try:
|
||||||
|
from agent.tools.file_save.file_save import FileSave
|
||||||
|
tools['FileSave'] = FileSave
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Terminal (basic, should work)
|
||||||
|
try:
|
||||||
|
from agent.tools.terminal.terminal import Terminal
|
||||||
|
tools['Terminal'] = Terminal
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return tools
|
||||||
|
|
||||||
|
# Load optional tools
# Each name below is bound to the tool class, or to None when that tool
# is missing from the dict returned by _import_optional_tools()
_optional_tools = _import_optional_tools()
GoogleSearch = _optional_tools.get('GoogleSearch')
FileSave = _optional_tools.get('FileSave')
Terminal = _optional_tools.get('Terminal')
|
||||||
|
|
||||||
|
|
||||||
|
# Delayed import for BrowserTool
|
||||||
|
def _import_browser_tool():
|
||||||
|
try:
|
||||||
|
from agent.tools.browser.browser_tool import BrowserTool
|
||||||
|
return BrowserTool
|
||||||
|
except ImportError:
|
||||||
|
# Return a placeholder class that will prompt the user to install dependencies when instantiated
|
||||||
|
class BrowserToolPlaceholder:
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
raise ImportError(
|
||||||
|
"The 'browser-use' package is required to use BrowserTool. "
|
||||||
|
"Please install it with 'pip install browser-use>=0.1.40'."
|
||||||
|
)
|
||||||
|
|
||||||
|
return BrowserToolPlaceholder
|
||||||
|
|
||||||
|
|
||||||
|
# Dynamically set BrowserTool
# Bound to whatever _import_browser_tool() returns: the real class or a placeholder
BrowserTool = _import_browser_tool()

# Export all tools (including optional ones that might be None)
__all__ = [
    'BaseTool',
    'ToolManager',
    'Calculator',
    'CurrentTime',
    'Read',
    'Write',
    'Edit',
    'Bash',
    'Grep',
    'Find',
    'Ls',
    'MemorySearchTool',
    'MemoryGetTool',
    # Optional tools (may be None if dependencies not available)
    'GoogleSearch',
    'FileSave',
    'Terminal',
    'BrowserTool'
]

# NOTE(review): this string is the last statement of the module, not the first,
# so it is a no-op expression and does not become the module docstring
"""
Tools module for Agent.
"""
|
||||||
99
agent/tools/base_tool.py
Normal file
99
agent/tools/base_tool.py
Normal file
@@ -0,0 +1,99 @@
|
|||||||
|
from enum import Enum
|
||||||
|
from typing import Any, Optional
|
||||||
|
from common.log import logger
|
||||||
|
import copy
|
||||||
|
|
||||||
|
|
||||||
|
class ToolStage(Enum):
    """Stage at which a tool is chosen or run."""

    # Tools that need to be actively selected by the agent
    PRE_PROCESS = "pre_process"
    # Tools that automatically execute after final_answer
    POST_PROCESS = "post_process"
|
||||||
|
|
||||||
|
|
||||||
|
class ToolResult:
    """Tool execution result"""

    def __init__(self, status: str = None, result: Any = None, ext_data: Any = None):
        """Store the status string, the result payload and optional extra data."""
        self.status, self.result, self.ext_data = status, result, ext_data

    @staticmethod
    def success(result, ext_data: Any = None):
        """Build a result whose status is "success"."""
        return ToolResult("success", result, ext_data)

    @staticmethod
    def fail(result, ext_data: Any = None):
        """Build a result whose status is "error"."""
        return ToolResult("error", result, ext_data)
|
||||||
|
|
||||||
|
|
||||||
|
class BaseTool:
    """Base class for all tools."""

    # Default decision stage is pre-process
    stage = ToolStage.PRE_PROCESS

    # Class attributes must be inherited
    name: str = "base_tool"
    description: str = "Base tool"
    params: dict = {}  # Store JSON Schema
    model: Optional[Any] = None  # LLM model instance, type depends on bot implementation

    @classmethod
    def get_json_schema(cls) -> dict:
        """Get the standard description of the tool"""
        return {
            "name": cls.name,
            "description": cls.description,
            "parameters": cls.params
        }

    def execute_tool(self, params: dict) -> ToolResult:
        """
        Run ``execute`` and always return a ToolResult.

        An exception raised by ``execute`` is logged and converted into a
        failed ToolResult, so callers can rely on the annotated return type
        instead of receiving None.

        :param params: Arguments passed through to ``execute``
        :return: The result of ``execute``, or a failed result describing the exception
        """
        try:
            return self.execute(params)
        except Exception as e:
            logger.error(e)
            # Include the exception type because str(e) can be empty
            return ToolResult.fail(f"{type(e).__name__}: {e}")

    def execute(self, params: dict) -> ToolResult:
        """Specific logic to be implemented by subclasses"""
        raise NotImplementedError

    @classmethod
    def _parse_schema(cls) -> dict:
        """
        Convert JSON Schema to Pydantic fields

        :return: Dict mapping each property name to a ``(python_type, default)``
                 tuple, where the default is ``...`` when the schema gives none.
                 Empty when the schema declares no properties.
        :raises KeyError: If a property has no "type" or an unsupported one
        """
        # Convert JSON Schema types to Python types
        type_map = {
            "string": str,
            "number": float,
            "integer": int,
            "boolean": bool,
            "array": list,
            "object": dict
        }
        # The default ``params`` is an empty dict, so "properties" may be missing
        properties = cls.params.get("properties", {})
        return {
            name: (type_map[prop["type"]], prop.get("default", ...))
            for name, prop in properties.items()
        }

    def should_auto_execute(self, context) -> bool:
        """
        Determine if this tool should be automatically executed based on context.

        :param context: The agent context (not used by this base implementation)
        :return: True if the tool should be executed, False otherwise
        """
        # Only tools in post-process stage will be automatically executed
        return self.stage == ToolStage.POST_PROCESS

    def close(self):
        """
        Close any resources used by the tool.
        This method should be overridden by tools that need to clean up resources
        such as browser connections, file handles, etc.

        By default, this method does nothing.
        """
        pass
|
||||||
3
agent/tools/bash/__init__.py
Normal file
3
agent/tools/bash/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
from .bash import Bash
|
||||||
|
|
||||||
|
__all__ = ['Bash']
|
||||||
187
agent/tools/bash/bash.py
Normal file
187
agent/tools/bash/bash.py
Normal file
@@ -0,0 +1,187 @@
|
|||||||
|
"""
|
||||||
|
Bash tool - Execute bash commands
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import subprocess
|
||||||
|
import tempfile
|
||||||
|
from typing import Dict, Any
|
||||||
|
|
||||||
|
from agent.tools.base_tool import BaseTool, ToolResult
|
||||||
|
from agent.tools.utils.truncate import truncate_tail, format_size, DEFAULT_MAX_LINES, DEFAULT_MAX_BYTES
|
||||||
|
|
||||||
|
|
||||||
|
class Bash(BaseTool):
    """Tool for executing bash commands"""

    name: str = "bash"
    description: str = f"""Execute a bash command in the current working directory. Returns stdout and stderr. Output is truncated to last {DEFAULT_MAX_LINES} lines or {DEFAULT_MAX_BYTES // 1024}KB (whichever is hit first). If truncated, full output is saved to a temp file.

IMPORTANT SAFETY GUIDELINES:
- You can freely create, modify, and delete files within the current workspace
- For operations outside the workspace or potentially destructive commands (rm -rf, system commands, etc.), always explain what you're about to do and ask for user confirmation first
- Be especially careful with: file deletions, system modifications, network operations, or commands that might affect system stability
- When in doubt, describe the command's purpose and ask for permission before executing"""

    params: dict = {
        "type": "object",
        "properties": {
            "command": {
                "type": "string",
                "description": "Bash command to execute"
            },
            "timeout": {
                "type": "integer",
                "description": "Timeout in seconds (optional, default: 30)"
            }
        },
        "required": ["command"]
    }

    def __init__(self, config: dict = None):
        """
        Configure the tool.

        :param config: Optional settings. Recognised keys: "cwd" (working
                       directory, created if missing; defaults to the process
                       working directory), "timeout" (default timeout in
                       seconds, 30) and "safety_mode" (default True).
        """
        self.config = config or {}
        self.cwd = self.config.get("cwd", os.getcwd())
        # Ensure working directory exists
        if not os.path.exists(self.cwd):
            os.makedirs(self.cwd, exist_ok=True)
        self.default_timeout = self.config.get("timeout", 30)
        # Enable safety mode by default (can be disabled in config)
        self.safety_mode = self.config.get("safety_mode", True)

    def execute(self, args: Dict[str, Any]) -> ToolResult:
        """
        Execute a bash command

        :param args: Dictionary containing the command and optional timeout
        :return: Command output or error
        """
        # Tolerate an explicit null command instead of failing on .strip()
        command = (args.get("command") or "").strip()
        # An explicit null/zero timeout falls back to the configured default
        # rather than silently disabling the time limit
        timeout = args.get("timeout") or self.default_timeout

        if not command:
            return ToolResult.fail("Error: command parameter is required")

        # Optional safety check - only warn about extremely dangerous commands
        if self.safety_mode:
            warning = self._get_safety_warning(command)
            if warning:
                return ToolResult.fail(
                    f"Safety Warning: {warning}\n\nIf you believe this command is safe and necessary, please ask the user for confirmation first, explaining what the command does and why it's needed.")

        try:
            # Execute command
            result = subprocess.run(
                command,
                shell=True,
                cwd=self.cwd,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                text=True,
                timeout=timeout
            )

            # Combine stdout and stderr
            output = result.stdout
            if result.stderr:
                output += "\n" + result.stderr

            # Apply tail truncation
            truncation = truncate_tail(output)
            output_text = truncation.content or "(no output)"

            # Save the full output whenever anything was cut, whether by the
            # line limit or the byte limit, so the notice below always
            # points at a real file
            temp_file_path = None
            if truncation.truncated or len(output.encode('utf-8')) > DEFAULT_MAX_BYTES:
                temp_file_path = self._save_full_output(output)

            # Build result
            details = {}

            if truncation.truncated:
                details["truncation"] = truncation.to_dict()
                details["full_output_path"] = temp_file_path

                # Build notice
                start_line = truncation.total_lines - truncation.output_lines + 1
                end_line = truncation.total_lines

                if truncation.last_line_partial:
                    # Edge case: only part of the last line is shown
                    last_line = output.split('\n')[-1] if output else ""
                    last_line_size = format_size(len(last_line.encode('utf-8')))
                    output_text += f"\n\n[Showing last {format_size(truncation.output_bytes)} of line {end_line} (line is {last_line_size}). Full output: {temp_file_path}]"
                elif truncation.truncated_by == "lines":
                    output_text += f"\n\n[Showing lines {start_line}-{end_line} of {truncation.total_lines}. Full output: {temp_file_path}]"
                else:
                    output_text += f"\n\n[Showing lines {start_line}-{end_line} of {truncation.total_lines} ({format_size(DEFAULT_MAX_BYTES)} limit). Full output: {temp_file_path}]"

            # Check exit code
            if result.returncode != 0:
                output_text += f"\n\nCommand exited with code {result.returncode}"
                return ToolResult.fail({
                    "output": output_text,
                    "exit_code": result.returncode,
                    "details": details if details else None
                })

            return ToolResult.success({
                "output": output_text,
                "exit_code": result.returncode,
                "details": details if details else None
            })

        except subprocess.TimeoutExpired:
            return ToolResult.fail(f"Error: Command timed out after {timeout} seconds")
        except Exception as e:
            return ToolResult.fail(f"Error executing command: {str(e)}")

    @staticmethod
    def _save_full_output(output: str) -> str:
        """Write the untruncated output to a persistent temp .log file and return its path"""
        # UTF-8 matches the encoding used above to measure the output size
        with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.log', prefix='bash-',
                                         encoding='utf-8') as f:
            f.write(output)
            return f.name

    def _get_safety_warning(self, command: str) -> str:
        """
        Get safety warning for potentially dangerous commands
        Only warns about extremely dangerous system-level operations

        Matching is plain substring search on the lower-cased command.

        :param command: Command to check
        :return: Warning message if dangerous, empty string if safe
        """
        cmd_lower = command.lower().strip()

        # Only block extremely dangerous system operations
        dangerous_patterns = [
            # System shutdown/reboot
            ("shutdown", "This command will shut down the system"),
            ("reboot", "This command will reboot the system"),
            ("halt", "This command will halt the system"),
            ("poweroff", "This command will power off the system"),

            # Critical system modifications
            ("rm -rf /", "This command will delete the entire filesystem"),
            ("rm -rf /*", "This command will delete the entire filesystem"),
            ("dd if=/dev/zero", "This command can destroy disk data"),
            ("mkfs", "This command will format a filesystem, destroying all data"),
            ("fdisk", "This command modifies disk partitions"),

            # User/system management (only if targeting system users)
            ("userdel root", "This command will delete the root user"),
            ("passwd root", "This command will change the root password"),
        ]

        for pattern, warning in dangerous_patterns:
            if pattern in cmd_lower:
                return warning

        # Check for recursive deletion outside workspace
        if "rm" in cmd_lower and "-rf" in cmd_lower:
            # Allow deletion within current workspace
            if not any(path in cmd_lower for path in ["./", self.cwd.lower()]):
                # Check if targeting system directories
                system_dirs = ["/bin", "/usr", "/etc", "/var", "/home", "/root", "/sys", "/proc"]
                if any(sysdir in cmd_lower for sysdir in system_dirs):
                    return "This command will recursively delete system directories"

        return ""  # No warning needed
|
||||||
59
agent/tools/browser/browser_action.py
Normal file
59
agent/tools/browser/browser_action.py
Normal file
@@ -0,0 +1,59 @@
|
|||||||
|
class BrowserAction:
    """Base class for browser actions"""

    # Action identifier; empty on the base class, set by each subclass
    code = ""
    # Description text for the action; empty on the base class
    description = ""
|
||||||
|
|
||||||
|
|
||||||
|
class Navigate(BrowserAction):
    """Action: load a URL in the current tab"""

    code = "navigate"
    description = "Navigate to URL in the current tab"
|
||||||
|
|
||||||
|
|
||||||
|
class ClickElement(BrowserAction):
    """Action: click an element on the page"""

    code = "click_element"
    description = "Click element"
|
||||||
|
|
||||||
|
|
||||||
|
class ExtractContent(BrowserAction):
    """Action: extract content from the page"""

    code = "extract_content"
    description = "Extract the page content to retrieve specific information for a goal"
|
||||||
|
|
||||||
|
|
||||||
|
class InputText(BrowserAction):
|
||||||
|
"""Input text into an element"""
|
||||||
|
code = "input_text"
|
||||||
|
description = "Input text into a input interactive element"
|
||||||
|
|
||||||
|
|
||||||
|
class ScrollDown(BrowserAction):
|
||||||
|
"""Scroll down the page"""
|
||||||
|
code = "scroll_down"
|
||||||
|
description = "Scroll down the page by pixel amount"
|
||||||
|
|
||||||
|
|
||||||
|
class ScrollUp(BrowserAction):
|
||||||
|
"""Scroll up the page"""
|
||||||
|
code = "scroll_up"
|
||||||
|
description = "Scroll up the page by pixel amount - if no amount is specified, scroll up one page"
|
||||||
|
|
||||||
|
|
||||||
|
class OpenTab(BrowserAction):
|
||||||
|
"""Open a URL in a new tab"""
|
||||||
|
code = "open_tab"
|
||||||
|
description = "Open url in new tab"
|
||||||
|
|
||||||
|
|
||||||
|
class SwitchTab(BrowserAction):
|
||||||
|
"""Switch to a tab"""
|
||||||
|
code = "switch_tab"
|
||||||
|
description = "Switched to tab"
|
||||||
|
|
||||||
|
|
||||||
|
class SendKeys(BrowserAction):
|
||||||
|
"""Switch to a tab"""
|
||||||
|
code = "send_keys"
|
||||||
|
description = "Send strings of special keyboard keys like Escape, Backspace, Insert, PageDown, Delete, Enter, " \
|
||||||
|
"ArrowRight, ArrowUp, etc"
|
||||||
317
agent/tools/browser/browser_tool.py
Normal file
317
agent/tools/browser/browser_tool.py
Normal file
@@ -0,0 +1,317 @@
|
|||||||
|
import asyncio
|
||||||
|
from typing import Any, Dict
|
||||||
|
import json
|
||||||
|
import re
|
||||||
|
import os
|
||||||
|
import platform
|
||||||
|
from browser_use import Browser
|
||||||
|
from browser_use import BrowserConfig
|
||||||
|
from browser_use.browser.context import BrowserContext, BrowserContextConfig
|
||||||
|
from agent.tools.base_tool import BaseTool, ToolResult
|
||||||
|
from agent.tools.browser.browser_action import *
|
||||||
|
from agent.models import LLMRequest
|
||||||
|
from agent.models.model_factory import ModelFactory
|
||||||
|
from browser_use.dom.service import DomService
|
||||||
|
from common.log import logger
|
||||||
|
|
||||||
|
|
||||||
|
# Use lazy import, only import when actually used
|
||||||
|
def _import_browser_use():
    """
    Import the ``browser_use`` package on demand and hand the module back.

    :return: The imported ``browser_use`` module
    :raises ImportError: With installation instructions when the package is not installed
    """
    try:
        import browser_use
    except ImportError:
        install_hint = (
            "The 'browser-use' package is required to use BrowserTool. "
            "Please install it with 'pip install browser-use>=0.1.40' or "
            "'pip install agentmesh-sdk[full]'."
        )
        raise ImportError(install_hint)
    else:
        return browser_use
|
||||||
|
|
||||||
|
|
||||||
|
def _get_action_prompt():
    """
    Build the operation listing, one "<code>: <description>" line per supported action.

    :return: The listing as a single string with surrounding whitespace stripped
    """
    supported_actions = (
        Navigate, ClickElement, ExtractContent, InputText, OpenTab, SwitchTab, ScrollDown, ScrollUp,
        SendKeys,
    )
    listing = "\n".join(f"{entry.code}: {entry.description}" for entry in supported_actions)
    return listing.strip()
|
||||||
|
|
||||||
|
|
||||||
|
def _header_less() -> bool:
    """
    Report whether this looks like a Linux host without a graphical display.

    :return: True only on Linux when neither DISPLAY nor WAYLAND_DISPLAY is set to a non-empty value
    """
    if platform.system() != "Linux":
        return False
    display_configured = os.environ.get("DISPLAY") or os.environ.get("WAYLAND_DISPLAY")
    return not display_configured
|
||||||
|
|
||||||
|
|
||||||
|
class BrowserTool(BaseTool):
    """
    Tool exposing browser operations (navigation, element interaction, content extraction).

    The browser, browser context, DOM service and event loop are stored on the class itself,
    so every instance works against the same browser session.
    """

    name: str = "browser"
    description: str = "A tool to perform browser operations like navigating to URLs, element interaction, " \
                       "and extracting content."
    params: dict = {
        "type": "object",
        "properties": {
            "operation": {
                "type": "string",
                "description": f"The browser operation to perform: \n{_get_action_prompt()}"
            },
            "url": {
                "type": "string",
                "description": f"The URL to navigate to (required for '{Navigate.code}', '{OpenTab.code}' actions). "
            },
            "goal": {
                "type": "string",
                "description": f"The goal of extracting page content (required for '{ExtractContent.code}' action)."
            },
            "text": {
                "type": "string",
                "description": f"Text to type (required for '{InputText.code}' action)."
            },
            "index": {
                "type": "integer",
                "description": f"Element index (required for '{ClickElement.code}', '{InputText.code}' actions)",
            },
            "tab_id": {
                "type": "integer",
                "description": f"Page tab ID (required for '{SwitchTab.code}' action)",
            },
            "scroll_amount": {
                "type": "integer",
                "description": f"The number of pixels to scroll (required for '{ScrollDown.code}', '{ScrollUp.code}' action)."
            },
            "keys": {
                "type": "string",
                "description": f"Keys to send (required for '{SendKeys.code}' action)"
            }
        },
        "required": ["operation"]
    }

    # Class variables to ensure only one browser instance is created
    browser = None
    browser_context: BrowserContext = None
    dom_service: DomService = None
    _initialized = False

    # Single event loop reused by every execute()/close() call
    _event_loop = None

    def __init__(self):
        # Raises ImportError with an install hint when browser-use is missing.
        # The browser itself is not started here; it is created on the first execution.
        self.browser_use = _import_browser_use()

    @staticmethod
    def _normalize_url(url: str) -> str:
        """Turn an absolute filesystem path (leading "/") into a file:// URL; leave other URLs untouched."""
        if url.startswith("/"):
            return f"file://{url}"
        return url

    async def _init_browser(self) -> BrowserContext:
        """
        Ensure the shared browser and browser context exist.

        :return: The shared browser context
        """
        if not BrowserTool._initialized:
            os.environ['BROWSER_USE_LOGGING_LEVEL'] = 'error'
            print("Initializing browser...")
            BrowserTool.browser = Browser(BrowserConfig(headless=_header_less(),
                                                        disable_security=True))
            context_config = BrowserContextConfig()
            context_config.highlight_elements = True
            BrowserTool.browser_context = await BrowserTool.browser.new_context(context_config)
            BrowserTool._initialized = True
            print("Browser initialized successfully")
        # Rebuilt on every call so the DOM service follows the current page
        BrowserTool.dom_service = DomService(await BrowserTool.browser_context.get_current_page())
        return BrowserTool.browser_context

    def execute(self, params: Dict[str, Any]) -> ToolResult:
        """
        Execute browser operations based on the provided arguments.

        :param params: Dictionary containing the action and related parameters
        :return: Result of the browser operation; any exception is converted into a failed result
        """
        # Ensure browser_use is imported (instances created without __init__ lack the attribute)
        if not hasattr(self, 'browser_use'):
            self.browser_use = _import_browser_use()
        action = params.get("operation", "").lower()

        try:
            # Use a single event loop for all browser calls
            if BrowserTool._event_loop is None:
                BrowserTool._event_loop = asyncio.new_event_loop()
                asyncio.set_event_loop(BrowserTool._event_loop)
            return BrowserTool._event_loop.run_until_complete(self._execute_async(action, params))
        except Exception as e:
            print(f"Error executing browser action: {e}")
            return ToolResult.fail(result=f"Error executing browser action: {str(e)}")

    async def _get_page_state(self, context: BrowserContext):
        """
        Collect url, title, scroll offsets, tabs and interactive elements of the current page.

        :param context: Browser context to inspect
        :return: Dictionary describing the page state
        """
        state = await self._get_state(context)
        include_attributes = ["img", "div", "button", "input"]
        elements = state.element_tree.clickable_elements_to_string(include_attributes)
        # Matches entries of the form "[<index>]<tag ... />"
        pattern = r'\[\d+\]<[^>]+\/>'
        interactive_elements = re.findall(pattern, elements)
        page_state = {
            "url": state.url,
            "title": state.title,
            "pixels_above": getattr(state, "pixels_above", 0),
            "pixels_below": getattr(state, "pixels_below", 0),
            "tabs": [tab.model_dump() for tab in state.tabs],
            "interactive_elements": interactive_elements,
        }
        return page_state

    async def _get_state(self, context: BrowserContext, cache_clickable_elements_hashes=True):
        """
        Fetch the context state, retrying with the keyword argument when the plain call raises TypeError.

        NOTE(review): the retry presumably covers browser-use versions whose get_state() requires
        cache_clickable_elements_hashes - confirm against the installed version.
        """
        try:
            return await context.get_state()
        except TypeError:
            return await context.get_state(cache_clickable_elements_hashes=cache_clickable_elements_hashes)

    async def _get_page_info(self, context: BrowserContext):
        """
        Render the current page state as a text block with the state embedded as JSON.

        :param context: Browser context to inspect
        :return: Formatted description of the current browser state
        """
        page_state = await self._get_page_state(context)
        state_str = f"""## Current browser state
The following is the information of the current browser page. Each serial number in interactive_elements represents the element index:
{json.dumps(page_state, indent=4, ensure_ascii=False)}
"""
        return state_str

    async def _execute_async(self, action: str, params: Dict[str, Any]) -> ToolResult:
        """
        Asynchronously execute one browser operation.

        :param action: Lower-cased operation code (one of the BrowserAction ``code`` values)
        :param params: Parameters of the operation
        :return: Result of the operation; every branch returns a ToolResult
        """
        # Use the browser context from the class variable
        context = await self._init_browser()

        if action == Navigate.code:
            url = params.get("url")
            if not url:
                return ToolResult.fail(result="URL is required for navigate action")
            url = self._normalize_url(url)
            print(f"Navigating to {url}...")
            page = await context.get_current_page()
            await page.goto(url)
            await page.wait_for_load_state()
            state = await self._get_page_info(context)
            print("Navigation complete")
            return ToolResult.success(result=f"Navigated to {url}", ext_data=state)

        elif action == OpenTab.code:
            url = params.get("url")
            # Without this guard a missing URL crashed on None.startswith
            if not url:
                return ToolResult.fail(result="URL is required for open_tab action")
            url = self._normalize_url(url)
            await context.create_new_tab(url)
            msg = f"Opened new tab with {url}"
            return ToolResult.success(result=msg)

        elif action == ExtractContent.code:
            try:
                goal = params.get("goal")
                page = await context.get_current_page()
                if params.get("url"):
                    await page.goto(params.get("url"))
                    await page.wait_for_load_state()
                import markdownify
                content = markdownify.markdownify(await page.content())
                elements = await self._get_page_state(context)
                prompt = f"Your task is to extract the content of the page. You will be given a page and a goal and you should extract all relevant information around this goal from the page. If the goal is vague, " \
                         f"summarize the page. Respond in json format. elements: {elements.get('interactive_elements')}, extraction goal: {goal}, Page: {content},"
                request = LLMRequest(
                    messages=[{"role": "user", "content": prompt}],
                    temperature=0,
                    json_format=True
                )
                model = getattr(self, "model", None) or ModelFactory().get_model(model_name="gpt-4o")
                response = model.call(request)
                if response.success:
                    extract_content = response.data["choices"][0]["message"]["content"]
                    print(f"Extract from page: {extract_content}")
                    return ToolResult.success(result=f"Extract from page: {extract_content}",
                                              ext_data=await self._get_page_info(context))
                else:
                    return ToolResult.fail(result=f"Extract from page failed: {response.get_error_msg()}")
            except Exception as e:
                logger.error(e)
                # Previously fell through here and returned None instead of a ToolResult
                return ToolResult.fail(result=f"Extract from page failed: {str(e)}")

        elif action == ClickElement.code:
            index = params.get("index")
            if index is None:
                return ToolResult.fail(result="Element index is required for click_element action")
            element = await context.get_dom_element_by_index(index)
            await context._click_element_node(element)
            msg = f"Clicked element at index {index}"
            print(msg)
            return ToolResult.success(result=msg, ext_data=await self._get_page_info(context))

        elif action == InputText.code:
            index = params.get("index")
            text = params.get("text")
            if index is None or text is None:
                return ToolResult.fail(result="Element index and text are required for input_text action")
            element = await context.get_dom_element_by_index(index)
            await context._input_text_element_node(element, text)
            await asyncio.sleep(1)
            msg = f"Input text into element successfully, index: {index}, text: {text}"
            return ToolResult.success(result=msg, ext_data=await self._get_page_info(context))

        elif action == SwitchTab.code:
            tab_id = params.get("tab_id")
            if tab_id is None:
                return ToolResult.fail(result="tab_id is required for switch_tab action")
            print(f"Switch tab, tab_id={tab_id}")
            await context.switch_to_tab(tab_id)
            page = await context.get_current_page()
            await page.wait_for_load_state()
            msg = f"Switched to tab {tab_id}"
            return ToolResult.success(result=msg, ext_data=await self._get_page_info(context))

        elif action in [ScrollDown.code, ScrollUp.code]:
            scroll_amount = params.get("scroll_amount")
            if not scroll_amount:
                # Default to one window height when no amount is given
                scroll_amount = context.config.browser_window_size["height"]
            print(f"Scrolling by {scroll_amount} pixels")
            # Scrolling up is a negative vertical offset
            scroll_amount = scroll_amount if action == ScrollDown.code else (scroll_amount * -1)
            await context.execute_javascript(f"window.scrollBy(0, {scroll_amount});")
            msg = f"{action} by {scroll_amount} pixels"
            return ToolResult.success(result=msg, ext_data=await self._get_page_info(context))

        elif action == SendKeys.code:
            keys = params.get("keys")
            if not keys:
                return ToolResult.fail(result="Keys are required for send_keys action")
            page = await context.get_current_page()
            await page.keyboard.press(keys)
            msg = f"Sent keys: {keys}"
            print(msg)
            # Built through ToolResult.success like every other branch (was ToolResult(output=...))
            return ToolResult.success(result=msg)

        else:
            msg = "Failed to operate the browser"
            return ToolResult.fail(result=msg)

    def close(self):
        """
        Close browser resources.

        Closes the shared browser context and browser on the shared event loop, resets the
        class-level state, then closes the loop. Does nothing when the browser was never
        initialized. Errors are reported but not raised.
        """
        if not BrowserTool._initialized:
            return

        try:
            # Use the existing event loop to close browser resources
            if BrowserTool._event_loop is not None:
                async def close_browser_async():
                    if BrowserTool.browser_context is not None:
                        try:
                            await BrowserTool.browser_context.close()
                        except Exception as e:
                            logger.error(f"Error closing browser context: {e}")

                    if BrowserTool.browser is not None:
                        try:
                            await BrowserTool.browser.close()
                        except Exception as e:
                            logger.error(f"Error closing browser: {e}")

                    # Reset the shared state so the next execution re-initializes the browser
                    BrowserTool._initialized = False
                    BrowserTool.browser = None
                    BrowserTool.browser_context = None
                    BrowserTool.dom_service = None

                # Run the async close function in the existing event loop
                BrowserTool._event_loop.run_until_complete(close_browser_async())

                # Close the event loop
                BrowserTool._event_loop.close()
                BrowserTool._event_loop = None
        except Exception as e:
            print(f"Error during browser cleanup: {e}")
|
||||||
18
agent/tools/browser_tool.py
Normal file
18
agent/tools/browser_tool.py
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
def copy(self):
    """
    Build a new tool instance that reuses this tool's state instead of starting a new browser.

    :return: A new instance of the same class carrying this tool's model, its context and
             config (None when absent), and the same browser reference when one is set
    """
    clone = self.__class__()

    # Carry over the essential attributes; context/config are optional on the source
    clone.model = self.model
    for optional_attr in ('context', 'config'):
        setattr(clone, optional_attr, getattr(self, optional_attr, None))

    # Reuse the existing browser rather than creating another one
    if hasattr(self, 'browser'):
        clone.browser = self.browser

    return clone
|
||||||
58
agent/tools/calculator/calculator.py
Normal file
58
agent/tools/calculator/calculator.py
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
import math
|
||||||
|
|
||||||
|
from agent.tools.base_tool import BaseTool, ToolResult
|
||||||
|
|
||||||
|
|
||||||
|
class Calculator(BaseTool):
    """Tool that evaluates a mathematical expression with a small set of math names in scope."""

    name: str = "calculator"
    description: str = "A tool to perform basic mathematical calculations."
    params: dict = {
        "type": "object",
        "properties": {
            "expression": {
                "type": "string",
                "description": "The mathematical expression to evaluate (e.g., '2 + 2', '5 * 3', 'sqrt(16)'). "
                               "Ensure your input is a valid Python expression, it will be evaluated directly."
            }
        },
        "required": ["expression"]
    }
    config: dict = {}

    def execute(self, args: dict) -> ToolResult:
        """
        Evaluate ``args["expression"]`` and return the value.

        :param args: Dictionary with the ``expression`` string to evaluate
        :return: Successful result with ``result`` and ``expression``; failed result with
                 ``error`` and ``expression`` when the expression is missing or cannot be evaluated
        """
        try:
            # Get the expression (a missing key is reported through the except branch)
            expression = args["expression"]

            # Names the expression may reference; built per call so an evaluation
            # cannot leave anything behind for the next one
            safe_locals = {
                "abs": abs,
                "round": round,
                "max": max,
                "min": min,
                "pow": pow,
                "sqrt": math.sqrt,
                "sin": math.sin,
                "cos": math.cos,
                "tan": math.tan,
                "pi": math.pi,
                "e": math.e,
                "log": math.log,
                "log10": math.log10,
                "exp": math.exp,
                "floor": math.floor,
                "ceil": math.ceil
            }

            # SECURITY NOTE(review): eval() with an empty __builtins__ is NOT a sandbox;
            # attribute access on literals can still reach arbitrary objects. Do not expose
            # this tool to untrusted input without a real expression parser.
            result = eval(expression, {"__builtins__": {}}, safe_locals)
        except Exception as e:
            # Report evaluation errors as a failure, consistent with the other tools
            return ToolResult.fail(result={
                "error": str(e),
                "expression": args.get("expression", "")
            })

        return ToolResult.success({
            "result": result,
            "expression": expression
        })
|
||||||
75
agent/tools/current_time/current_time.py
Normal file
75
agent/tools/current_time/current_time.py
Normal file
@@ -0,0 +1,75 @@
|
|||||||
|
import datetime
|
||||||
|
import time
|
||||||
|
|
||||||
|
from agent.tools.base_tool import BaseTool, ToolResult
|
||||||
|
|
||||||
|
|
||||||
|
class CurrentTime(BaseTool):
    """Tool that reports the current date and time in local time or UTC."""

    name: str = "time"
    description: str = "A tool to get current date and time information."
    params: dict = {
        "type": "object",
        "properties": {
            "format": {
                "type": "string",
                "description": "Optional format for the time (e.g., 'iso', 'unix', 'human'). Default is 'human'."
            },
            "timezone": {
                "type": "string",
                "description": "Optional timezone specification (e.g., 'UTC', 'local'). Default is 'local'."
            }
        },
        "required": []
    }
    config: dict = {}

    def execute(self, args: dict) -> ToolResult:
        """
        Return the current time plus its individual components.

        Only the 'utc' and 'local' timezones and the 'iso', 'unix' and 'human' formats are
        implemented. Any other value falls back to 'local' / 'human', and the result reports
        the timezone and format that were actually applied.

        :param args: Optional ``format`` and ``timezone`` strings (missing or null use the defaults)
        :return: Successful result with ``current_time``, ``components``, ``format`` and
                 ``timezone``; failed result carrying the error text on an unexpected exception
        """
        try:
            # "or" also covers an explicit null, which previously crashed on .lower()
            requested_format = str(args.get("format") or "human").lower()
            requested_timezone = str(args.get("timezone") or "local").lower()

            if requested_timezone == "utc":
                timezone = "utc"
                # Naive UTC wall-clock time, the same value utcnow() produced,
                # without calling the deprecated utcnow()
                current_time = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
            else:
                timezone = "local"
                current_time = datetime.datetime.now()

            # Format the time according to the requested format
            if requested_format == "iso":
                time_format = "iso"
                formatted_time = current_time.isoformat()
            elif requested_format == "unix":
                time_format = "unix"
                # Seconds since the epoch; independent of the selected timezone
                formatted_time = time.time()
            else:
                time_format = "human"
                formatted_time = current_time.strftime("%Y-%m-%d %H:%M:%S")

            result = {
                "current_time": formatted_time,
                "components": {
                    "year": current_time.year,
                    "month": current_time.month,
                    "day": current_time.day,
                    "hour": current_time.hour,
                    "minute": current_time.minute,
                    "second": current_time.second,
                    "weekday": current_time.strftime("%A")  # Full weekday name
                },
                "format": time_format,
                "timezone": timezone
            }
            return ToolResult.success(result=result)
        except Exception as e:
            return ToolResult.fail(result=str(e))
|
||||||
3
agent/tools/edit/__init__.py
Normal file
3
agent/tools/edit/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
from .edit import Edit
|
||||||
|
|
||||||
|
__all__ = ['Edit']
|
||||||
164
agent/tools/edit/edit.py
Normal file
164
agent/tools/edit/edit.py
Normal file
@@ -0,0 +1,164 @@
|
|||||||
|
"""
|
||||||
|
Edit tool - Precise file editing
|
||||||
|
Edit files through exact text replacement
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from typing import Dict, Any
|
||||||
|
|
||||||
|
from agent.tools.base_tool import BaseTool, ToolResult
|
||||||
|
from agent.tools.utils.diff import (
|
||||||
|
strip_bom,
|
||||||
|
detect_line_ending,
|
||||||
|
normalize_to_lf,
|
||||||
|
restore_line_endings,
|
||||||
|
normalize_for_fuzzy_match,
|
||||||
|
fuzzy_find_text,
|
||||||
|
generate_diff_string
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Edit(BaseTool):
    """Tool for precise file editing"""

    name: str = "edit"
    description: str = "Edit a file by replacing exact text. The oldText must match exactly (including whitespace). Use this for precise, surgical edits."

    params: dict = {
        "type": "object",
        "properties": {
            "path": {
                "type": "string",
                "description": "Path to the file to edit (relative or absolute)"
            },
            "oldText": {
                "type": "string",
                "description": "Exact text to find and replace (must match exactly)"
            },
            "newText": {
                "type": "string",
                "description": "New text to replace the old text with"
            }
        },
        "required": ["path", "oldText", "newText"]
    }

    def __init__(self, config: dict = None):
        """
        :param config: Optional settings; ``cwd`` is the base for relative paths
                       (defaults to the process working directory)
        """
        self.config = config or {}
        self.cwd = self.config.get("cwd", os.getcwd())

    def execute(self, args: Dict[str, Any]) -> ToolResult:
        """
        Execute file edit operation

        Replaces a single, unique occurrence of ``oldText`` with ``newText`` and writes the
        file back with its BOM and original line endings restored.

        :param args: Contains file path, old text and new text
        :return: Successful result with message, path, diff and first changed line;
                 failed result describing why the edit was not applied
        """
        path = args.get("path", "").strip()
        old_text = args.get("oldText", "")
        new_text = args.get("newText", "")

        if not path:
            return ToolResult.fail("Error: path parameter is required")

        # Resolve path
        absolute_path = self._resolve_path(path)

        # Check if file exists
        if not os.path.exists(absolute_path):
            return ToolResult.fail(f"Error: File not found: {path}")

        # Check if readable/writable
        if not os.access(absolute_path, os.R_OK | os.W_OK):
            return ToolResult.fail(f"Error: File is not readable/writable: {path}")

        try:
            # newline='' turns off newline translation; in default text mode "\r\n" is
            # converted to "\n" while reading, so the original line ending could never
            # be detected and CRLF files were silently rewritten
            with open(absolute_path, 'r', encoding='utf-8', newline='') as f:
                raw_content = f.read()

            # Remove BOM (LLM won't include invisible BOM in oldText)
            bom, content = strip_bom(raw_content)

            # Detect original line ending
            original_ending = detect_line_ending(content)

            # Normalize to LF
            normalized_content = normalize_to_lf(content)
            normalized_old_text = normalize_to_lf(old_text)
            normalized_new_text = normalize_to_lf(new_text)

            # Use fuzzy matching to find old text (try exact match first, then fuzzy match)
            match_result = fuzzy_find_text(normalized_content, normalized_old_text)

            if not match_result.found:
                return ToolResult.fail(
                    f"Error: Could not find the exact text in {path}. "
                    "The old text must match exactly including all whitespace and newlines."
                )

            # Calculate occurrence count (use fuzzy normalized content for consistency)
            fuzzy_content = normalize_for_fuzzy_match(normalized_content)
            fuzzy_old_text = normalize_for_fuzzy_match(normalized_old_text)
            occurrences = fuzzy_content.count(fuzzy_old_text)

            if occurrences > 1:
                return ToolResult.fail(
                    f"Error: Found {occurrences} occurrences of the text in {path}. "
                    "The text must be unique. Please provide more context to make it unique."
                )

            # Execute replacement (use matched text position)
            base_content = match_result.content_for_replacement
            match_end = match_result.index + match_result.match_length
            new_content = (
                base_content[:match_result.index] +
                normalized_new_text +
                base_content[match_end:]
            )

            # Verify replacement actually changed content
            if base_content == new_content:
                return ToolResult.fail(
                    f"Error: No changes made to {path}. "
                    "The replacement produced identical content. "
                    "This might indicate an issue with special characters or the text not existing as expected."
                )

            # Restore original line endings
            final_content = bom + restore_line_endings(new_content, original_ending)

            # newline='' again: the endings are already restored explicitly, so the
            # platform must not translate "\n" a second time
            with open(absolute_path, 'w', encoding='utf-8', newline='') as f:
                f.write(final_content)

            # Generate diff
            diff_result = generate_diff_string(base_content, new_content)

            result = {
                "message": f"Successfully replaced text in {path}",
                "path": path,
                "diff": diff_result['diff'],
                "first_changed_line": diff_result['first_changed_line']
            }

            return ToolResult.success(result)

        except UnicodeDecodeError:
            return ToolResult.fail(f"Error: File is not a valid text file (encoding error): {path}")
        except PermissionError:
            return ToolResult.fail(f"Error: Permission denied accessing {path}")
        except Exception as e:
            return ToolResult.fail(f"Error editing file: {str(e)}")

    def _resolve_path(self, path: str) -> str:
        """
        Resolve path to absolute path

        :param path: Relative or absolute path; a leading ~ is expanded to the user home directory
        :return: Absolute path (relative paths are resolved against ``self.cwd``)
        """
        # Expand ~ to user home directory
        path = os.path.expanduser(path)
        if os.path.isabs(path):
            return path
        return os.path.abspath(os.path.join(self.cwd, path))
|
||||||
3
agent/tools/file_save/__init__.py
Normal file
3
agent/tools/file_save/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
from .file_save import FileSave
|
||||||
|
|
||||||
|
__all__ = ['FileSave']
|
||||||
770
agent/tools/file_save/file_save.py
Normal file
770
agent/tools/file_save/file_save.py
Normal file
@@ -0,0 +1,770 @@
|
|||||||
|
import os
|
||||||
|
import time
|
||||||
|
import re
|
||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, Any, Optional, Tuple
|
||||||
|
|
||||||
|
from agent.tools.base_tool import BaseTool, ToolResult, ToolStage
|
||||||
|
from agent.models import LLMRequest
|
||||||
|
from common.log import logger
|
||||||
|
|
||||||
|
|
||||||
|
class FileSave(BaseTool):
|
||||||
|
"""Tool for saving content to files in the workspace directory."""
|
||||||
|
|
||||||
|
name = "file_save"
|
||||||
|
description = "Save the agent's output to a file in the workspace directory. Content is automatically extracted from the agent's previous outputs."
|
||||||
|
|
||||||
|
# Set as post-process stage tool
|
||||||
|
stage = ToolStage.POST_PROCESS
|
||||||
|
|
||||||
|
params = {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"file_name": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Optional. The name of the file to save. If not provided, a name will be generated based on the content."
|
||||||
|
},
|
||||||
|
"file_type": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Optional. The type/extension of the file (e.g., 'txt', 'md', 'py', 'java'). If not provided, it will be inferred from the content."
|
||||||
|
},
|
||||||
|
"extract_code": {
|
||||||
|
"type": "boolean",
|
||||||
|
"description": "Optional. If true, will attempt to extract code blocks from the content. Default is false."
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": [] # No required fields, as everything can be extracted from context
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(self):
    """Start with no context, an empty config, and the relative "workspace" directory as root."""
    self.workspace_dir = Path("workspace")
    self.config = {}
    self.context = None
|
||||||
|
|
||||||
|
def execute(self, params: Dict[str, Any]) -> ToolResult:
    """
    Save content to a file in the workspace directory.

    The content comes from the tool's context. File name, type and the extract-code flag
    are determined by the model, falling back to ``params`` / inference when that fails.
    Files are written to ``<workspace_dir>/<team_name>/<task_dir>/``.

    :param params: The parameters for the file output operation.
    :return: Result of the operation.
    """
    # Extract content from context
    if not hasattr(self, 'context') or not self.context:
        return ToolResult.fail("Error: No context available to extract content from.")

    content = self._extract_content_from_context()

    # If no content could be extracted, return error
    if not content:
        return ToolResult.fail("Error: Couldn't extract content from context.")

    # Use model to determine file parameters
    try:
        task_dir = self._get_task_dir_from_context()
        file_name, file_type, extract_code = self._get_file_params_from_model(content)
    except Exception as e:
        logger.error(f"Error determining file parameters: {str(e)}")
        # Fall back to manual parameter extraction
        # NOTE(review): _get_task_id_from_context is not visible in this chunk - confirm it exists
        task_dir = params.get("task_dir") or self._get_task_id_from_context() or f"task_{int(time.time())}"
        file_name = params.get("file_name") or self._infer_file_name(content)
        file_type = params.get("file_type") or self._infer_file_type(content)
        extract_code = params.get("extract_code", False)

    # The context may not provide a task directory; use the same timestamp fallback
    # as _handle_multiple_code_blocks so the path below can always be built
    task_dir = task_dir or f"task_{int(time.time())}"

    # Get team_name from context
    team_name = self._get_team_name_from_context() or "default_team"

    # Create directory structure
    task_dir_path = self.workspace_dir / team_name / task_dir
    try:
        task_dir_path.mkdir(parents=True, exist_ok=True)
    except OSError as e:
        return ToolResult.fail(f"Error saving file: {str(e)}")

    if extract_code:
        # Save the complete content as markdown, then save each code block separately
        md_file_path = task_dir_path / f"{file_name}.md"
        try:
            with open(md_file_path, 'w', encoding='utf-8') as f:
                f.write(content)
        except OSError as e:
            return ToolResult.fail(f"Error saving file: {str(e)}")

        return self._handle_multiple_code_blocks(content)

    # Ensure file_name has the correct extension
    if file_type and not file_name.endswith(f".{file_type}"):
        file_name = f"{file_name}.{file_type}"

    # Create the full file path
    file_path = task_dir_path / file_name

    # Get absolute path for storage in team_context
    abs_file_path = file_path.absolute()

    try:
        # Write content to file
        with open(file_path, 'w', encoding='utf-8') as f:
            f.write(content)
    except Exception as e:
        return ToolResult.fail(f"Error saving file: {str(e)}")

    # Append the saved path to the latest agent output. This is best-effort: the file is
    # already on disk, so a problem here must not turn the save into a failure.
    try:
        team_context = getattr(self.context, 'team_context', None)
        agent_outputs = getattr(team_context, 'agent_outputs', None)
        if agent_outputs:
            # Store with absolute path in team_context
            agent_outputs[-1].output += f"\n\nSaved file: {abs_file_path}"
    except Exception as e:
        logger.error(f"Error recording saved file in team context: {str(e)}")

    return ToolResult.success({
        "status": "success",
        "file_path": str(file_path)  # Return relative path in result
    })
|
||||||
|
|
||||||
|
def _handle_multiple_code_blocks(self, content: str) -> ToolResult:
    """
    Extract every fenced code block from ``content`` and save each one as its own file.

    Files are written to ``<workspace_dir>/<team_name>/<task_dir>``. A block that
    fails to save is logged and skipped so the remaining blocks are still written.

    :param content: The content containing one or more code blocks
    :return: ``ToolResult.success`` with ``{"status", "files"}`` when at least one
        file was written, otherwise ``ToolResult.fail``
    """
    # Extract code blocks with context (including potential file name information)
    code_blocks_with_context = self._extract_code_blocks_with_context(content)
    if not code_blocks_with_context:
        return ToolResult.fail("No code blocks found in the content.")

    # Get task directory and team name
    task_dir = self._get_task_dir_from_context() or f"task_{int(time.time())}"
    team_name = self._get_team_name_from_context() or "default_team"

    # Create directory structure
    task_dir_path = self.workspace_dir / team_name / task_dir
    task_dir_path.mkdir(parents=True, exist_ok=True)

    saved_files = []
    used_names = set()

    for index, block_with_context in enumerate(code_blocks_with_context, start=1):
        try:
            block_file_name, block_file_type = self._get_filename_for_code_block(block_with_context)

            # Clean the code block (remove md code markers)
            clean_code = self._clean_code_block(block_with_context)

            # An empty name would make the target path the task directory itself,
            # so fall back to a positional name instead of losing the block.
            if not block_file_name:
                block_file_name = f"code_block_{index}"

            # Ensure file_name has the correct extension
            if block_file_type and not block_file_name.endswith(f".{block_file_type}"):
                block_file_name = f"{block_file_name}.{block_file_type}"

            # Two blocks can resolve to the same name (e.g. the timestamp fallback
            # within one second); suffix later ones so nothing is overwritten.
            if block_file_name in used_names:
                stem, ext = os.path.splitext(block_file_name)
                suffix = index
                while f"{stem}_{suffix}{ext}" in used_names:
                    suffix += 1
                block_file_name = f"{stem}_{suffix}{ext}"
            used_names.add(block_file_name)

            # Create the full file path (no subdirectories)
            file_path = task_dir_path / block_file_name

            # Get absolute path for storage in team_context
            abs_file_path = file_path.absolute()

            with open(file_path, 'w', encoding='utf-8') as f:
                f.write(clean_code)

            saved_files.append({
                "file_path": str(file_path),
                "abs_file_path": str(abs_file_path),  # Store absolute path for internal use
                "file_name": block_file_name,
                "size": len(clean_code),
                "status": "success",
                "type": "code"
            })
        except Exception as e:
            # Continue with the next block even if this one fails
            logger.error(f"Error saving code block: {str(e)}")

    if not saved_files:
        return ToolResult.fail("Failed to save any code blocks.")

    # Append the saved-file list to the latest agent output, when one exists
    team_context = getattr(getattr(self, 'context', None), 'team_context', None)
    agent_outputs = getattr(team_context, 'agent_outputs', None)
    if agent_outputs:
        abs_info = f"\n\nSaved files to {task_dir_path.absolute()}:\n" + "\n".join(
            [f"- {f['abs_file_path']}" for f in saved_files])
        # Compare against the exact text that gets appended so a repeated call
        # does not duplicate the listing.
        if not agent_outputs[-1].output.endswith(abs_info):
            agent_outputs[-1].output += abs_info

    result = {
        "status": "success",
        "files": [{"file_path": f["file_path"]} for f in saved_files]
    }

    return ToolResult.success(result)
|
||||||
|
|
||||||
|
def _extract_code_blocks_with_context(self, content: str) -> list:
|
||||||
|
"""
|
||||||
|
Extract code blocks from content, including context lines before the block.
|
||||||
|
|
||||||
|
:param content: The content to extract code blocks from
|
||||||
|
:return: List of code blocks with context
|
||||||
|
"""
|
||||||
|
# Check if content starts with <!DOCTYPE or <html - likely a full HTML file
|
||||||
|
if content.strip().startswith(("<!DOCTYPE", "<html", "<?xml")):
|
||||||
|
return [content] # Return the entire content as a single block
|
||||||
|
|
||||||
|
# Split content into lines
|
||||||
|
lines = content.split('\n')
|
||||||
|
|
||||||
|
blocks = []
|
||||||
|
in_code_block = False
|
||||||
|
current_block = []
|
||||||
|
context_lines = []
|
||||||
|
|
||||||
|
# Check if there are any code block markers in the content
|
||||||
|
if not re.search(r'```\w+', content):
|
||||||
|
# If no code block markers and content looks like code, return the entire content
|
||||||
|
if self._is_likely_code(content):
|
||||||
|
return [content]
|
||||||
|
|
||||||
|
for line in lines:
|
||||||
|
if line.strip().startswith('```'):
|
||||||
|
if in_code_block:
|
||||||
|
# End of code block
|
||||||
|
current_block.append(line)
|
||||||
|
# Only add blocks that have a language specified
|
||||||
|
block_content = '\n'.join(current_block)
|
||||||
|
if re.search(r'```\w+', current_block[0]):
|
||||||
|
# Combine context with code block
|
||||||
|
blocks.append('\n'.join(context_lines + current_block))
|
||||||
|
current_block = []
|
||||||
|
context_lines = []
|
||||||
|
in_code_block = False
|
||||||
|
else:
|
||||||
|
# Start of code block - check if it has a language specified
|
||||||
|
if re.search(r'```\w+', line) and not re.search(r'```language=\s*$', line):
|
||||||
|
# Start of code block with language
|
||||||
|
in_code_block = True
|
||||||
|
current_block = [line]
|
||||||
|
# Keep only the last few context lines
|
||||||
|
context_lines = context_lines[-5:] if context_lines else []
|
||||||
|
|
||||||
|
elif in_code_block:
|
||||||
|
current_block.append(line)
|
||||||
|
else:
|
||||||
|
# Store context lines when not in a code block
|
||||||
|
context_lines.append(line)
|
||||||
|
|
||||||
|
return blocks
|
||||||
|
|
||||||
|
def _get_filename_for_code_block(self, block_with_context: str) -> Tuple[str, str]:
|
||||||
|
"""
|
||||||
|
Determine the file name for a code block.
|
||||||
|
|
||||||
|
:param block_with_context: The code block with context lines
|
||||||
|
:return: Tuple of (file_name, file_type)
|
||||||
|
"""
|
||||||
|
# Define common code file extensions
|
||||||
|
COMMON_CODE_EXTENSIONS = {
|
||||||
|
'py', 'js', 'java', 'c', 'cpp', 'h', 'hpp', 'cs', 'go', 'rb', 'php',
|
||||||
|
'html', 'css', 'ts', 'jsx', 'tsx', 'vue', 'sh', 'sql', 'json', 'xml',
|
||||||
|
'yaml', 'yml', 'md', 'rs', 'swift', 'kt', 'scala', 'pl', 'r', 'lua'
|
||||||
|
}
|
||||||
|
|
||||||
|
# Split the block into lines to examine only the context around code block markers
|
||||||
|
lines = block_with_context.split('\n')
|
||||||
|
|
||||||
|
# Find the code block start marker line index
|
||||||
|
start_marker_idx = -1
|
||||||
|
for i, line in enumerate(lines):
|
||||||
|
if line.strip().startswith('```') and not line.strip() == '```':
|
||||||
|
start_marker_idx = i
|
||||||
|
break
|
||||||
|
|
||||||
|
if start_marker_idx == -1:
|
||||||
|
# No code block marker found
|
||||||
|
return "", ""
|
||||||
|
|
||||||
|
# Extract the language from the code block marker
|
||||||
|
code_marker = lines[start_marker_idx].strip()
|
||||||
|
language = ""
|
||||||
|
if len(code_marker) > 3:
|
||||||
|
language = code_marker[3:].strip().split('=')[0].strip()
|
||||||
|
|
||||||
|
# Define the context range (5 lines before and 2 after the marker)
|
||||||
|
context_start = max(0, start_marker_idx - 5)
|
||||||
|
context_end = min(len(lines), start_marker_idx + 3)
|
||||||
|
|
||||||
|
# Extract only the relevant context lines
|
||||||
|
context_lines = lines[context_start:context_end]
|
||||||
|
|
||||||
|
# First, check for explicit file headers like "## filename.ext"
|
||||||
|
for line in context_lines:
|
||||||
|
# Match patterns like "## filename.ext" or "# filename.ext"
|
||||||
|
header_match = re.search(r'^\s*#{1,6}\s+([a-zA-Z0-9_-]+\.[a-zA-Z0-9]+)\s*$', line)
|
||||||
|
if header_match:
|
||||||
|
file_name = header_match.group(1)
|
||||||
|
file_type = os.path.splitext(file_name)[1].lstrip('.')
|
||||||
|
if file_type in COMMON_CODE_EXTENSIONS:
|
||||||
|
return os.path.splitext(file_name)[0], file_type
|
||||||
|
|
||||||
|
# Simple patterns to match explicit file names in the context
|
||||||
|
file_patterns = [
|
||||||
|
# Match explicit file names in headers or text
|
||||||
|
r'(?:file|filename)[:=\s]+[\'"]?([a-zA-Z0-9_-]+\.[a-zA-Z0-9]+)[\'"]?',
|
||||||
|
# Match language=filename.ext in code markers
|
||||||
|
r'language=([a-zA-Z0-9_-]+\.[a-zA-Z0-9]+)',
|
||||||
|
# Match standalone filenames with extensions
|
||||||
|
r'\b([a-zA-Z0-9_-]+\.(py|js|java|c|cpp|h|hpp|cs|go|rb|php|html|css|ts|jsx|tsx|vue|sh|sql|json|xml|yaml|yml|md|rs|swift|kt|scala|pl|r|lua))\b',
|
||||||
|
# Match file paths in comments
|
||||||
|
r'#\s*([a-zA-Z0-9_/-]+\.[a-zA-Z0-9]+)'
|
||||||
|
]
|
||||||
|
|
||||||
|
# Check each context line for file name patterns
|
||||||
|
for line in context_lines:
|
||||||
|
line = line.strip()
|
||||||
|
for pattern in file_patterns:
|
||||||
|
matches = re.findall(pattern, line)
|
||||||
|
if matches:
|
||||||
|
for match in matches:
|
||||||
|
if isinstance(match, tuple):
|
||||||
|
# If the match is a tuple (filename, extension)
|
||||||
|
file_name = match[0]
|
||||||
|
file_type = match[1]
|
||||||
|
# Verify it's not a code reference like Direction.DOWN
|
||||||
|
if not any(keyword in file_name for keyword in ['class.', 'enum.', 'import.']):
|
||||||
|
return os.path.splitext(file_name)[0], file_type
|
||||||
|
else:
|
||||||
|
# If the match is a string (full filename)
|
||||||
|
file_name = match
|
||||||
|
file_type = os.path.splitext(file_name)[1].lstrip('.')
|
||||||
|
# Verify it's not a code reference
|
||||||
|
if file_type in COMMON_CODE_EXTENSIONS and not any(
|
||||||
|
keyword in file_name for keyword in ['class.', 'enum.', 'import.']):
|
||||||
|
return os.path.splitext(file_name)[0], file_type
|
||||||
|
|
||||||
|
# If no explicit file name found, use LLM to infer from code content
|
||||||
|
# Extract the code content
|
||||||
|
code_content = block_with_context
|
||||||
|
|
||||||
|
# Get the first 20 lines of code for LLM analysis
|
||||||
|
code_lines = code_content.split('\n')
|
||||||
|
code_preview = '\n'.join(code_lines[:20])
|
||||||
|
|
||||||
|
# Get the model to use
|
||||||
|
model_to_use = None
|
||||||
|
if hasattr(self, 'context') and self.context:
|
||||||
|
if hasattr(self.context, 'model') and self.context.model:
|
||||||
|
model_to_use = self.context.model
|
||||||
|
elif hasattr(self.context, 'team_context') and self.context.team_context:
|
||||||
|
if hasattr(self.context.team_context, 'model') and self.context.team_context.model:
|
||||||
|
model_to_use = self.context.team_context.model
|
||||||
|
|
||||||
|
# If no model is available in context, use the tool's model
|
||||||
|
if not model_to_use and hasattr(self, 'model') and self.model:
|
||||||
|
model_to_use = self.model
|
||||||
|
|
||||||
|
if model_to_use:
|
||||||
|
# Prepare a prompt for the model
|
||||||
|
prompt = f"""Analyze the following code and determine the most appropriate file name and file type/extension.
|
||||||
|
The file name should be descriptive but concise, using snake_case (lowercase with underscores).
|
||||||
|
The file type should be a standard file extension (e.g., py, js, html, css, java).
|
||||||
|
|
||||||
|
Code preview (first 20 lines):
|
||||||
|
{code_preview}
|
||||||
|
|
||||||
|
Return your answer in JSON format with these fields:
|
||||||
|
- file_name: The suggested file name (without extension)
|
||||||
|
- file_type: The suggested file extension
|
||||||
|
|
||||||
|
JSON response:"""
|
||||||
|
|
||||||
|
# Create a request to the model
|
||||||
|
request = LLMRequest(
|
||||||
|
messages=[{"role": "user", "content": prompt}],
|
||||||
|
temperature=0,
|
||||||
|
json_format=True
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = model_to_use.call(request)
|
||||||
|
|
||||||
|
if not response.is_error:
|
||||||
|
# Clean the JSON response
|
||||||
|
json_content = self._clean_json_response(response.data["choices"][0]["message"]["content"])
|
||||||
|
result = json.loads(json_content)
|
||||||
|
|
||||||
|
file_name = result.get("file_name", "")
|
||||||
|
file_type = result.get("file_type", "")
|
||||||
|
|
||||||
|
if file_name and file_type:
|
||||||
|
return file_name, file_type
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error using model to determine file name: {str(e)}")
|
||||||
|
|
||||||
|
# If we still don't have a file name, use the language as file type
|
||||||
|
if language and language in COMMON_CODE_EXTENSIONS:
|
||||||
|
timestamp = int(time.time())
|
||||||
|
return f"code_{timestamp}", language
|
||||||
|
|
||||||
|
# If all else fails, return empty strings
|
||||||
|
return "", ""
|
||||||
|
|
||||||
|
def _clean_json_response(self, text: str) -> str:
|
||||||
|
"""
|
||||||
|
Clean JSON response from LLM by removing markdown code block markers.
|
||||||
|
|
||||||
|
:param text: The text containing JSON possibly wrapped in markdown code blocks
|
||||||
|
:return: Clean JSON string
|
||||||
|
"""
|
||||||
|
# Remove markdown code block markers if present
|
||||||
|
if text.startswith("```json"):
|
||||||
|
text = text[7:]
|
||||||
|
elif text.startswith("```"):
|
||||||
|
# Find the first newline to skip the language identifier line
|
||||||
|
first_newline = text.find('\n')
|
||||||
|
if first_newline != -1:
|
||||||
|
text = text[first_newline + 1:]
|
||||||
|
|
||||||
|
if text.endswith("```"):
|
||||||
|
text = text[:-3]
|
||||||
|
|
||||||
|
return text.strip()
|
||||||
|
|
||||||
|
def _clean_code_block(self, block_with_context: str) -> str:
|
||||||
|
"""
|
||||||
|
Clean a code block by removing markdown code markers and context lines.
|
||||||
|
|
||||||
|
:param block_with_context: Code block with context lines
|
||||||
|
:return: Clean code ready for execution
|
||||||
|
"""
|
||||||
|
# Check if this is a full HTML or XML document
|
||||||
|
if block_with_context.strip().startswith(("<!DOCTYPE", "<html", "<?xml")):
|
||||||
|
return block_with_context
|
||||||
|
|
||||||
|
# Find the code block
|
||||||
|
code_block_match = re.search(r'```(?:\w+)?(?:[:=][^\n]+)?\n([\s\S]*?)\n```', block_with_context)
|
||||||
|
|
||||||
|
if code_block_match:
|
||||||
|
return code_block_match.group(1)
|
||||||
|
|
||||||
|
# If no match found, try to extract anything between ``` markers
|
||||||
|
lines = block_with_context.split('\n')
|
||||||
|
start_idx = None
|
||||||
|
end_idx = None
|
||||||
|
|
||||||
|
for i, line in enumerate(lines):
|
||||||
|
if line.strip().startswith('```'):
|
||||||
|
if start_idx is None:
|
||||||
|
start_idx = i
|
||||||
|
else:
|
||||||
|
end_idx = i
|
||||||
|
break
|
||||||
|
|
||||||
|
if start_idx is not None and end_idx is not None:
|
||||||
|
# Extract the code between the markers, excluding the markers themselves
|
||||||
|
code_lines = lines[start_idx + 1:end_idx]
|
||||||
|
return '\n'.join(code_lines)
|
||||||
|
|
||||||
|
# If all else fails, return the original content
|
||||||
|
return block_with_context
|
||||||
|
|
||||||
|
def _get_file_params_from_model(self, content, model=None):
|
||||||
|
"""
|
||||||
|
Use LLM to determine if the content is code and suggest appropriate file parameters.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: The content to analyze
|
||||||
|
model: Optional model to use for the analysis
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple: (file_name, file_type, extract_code) for backward compatibility
|
||||||
|
"""
|
||||||
|
if model is None:
|
||||||
|
model = self.model
|
||||||
|
|
||||||
|
if not model:
|
||||||
|
# Default fallback if no model is available
|
||||||
|
return "output", "txt", False
|
||||||
|
|
||||||
|
prompt = f"""
|
||||||
|
Analyze the following content and determine:
|
||||||
|
1. Is this primarily code implementation (where most of the content consists of code blocks)?
|
||||||
|
2. What would be an appropriate filename and file extension?
|
||||||
|
|
||||||
|
Content to analyze: ```
|
||||||
|
{content[:500]} # Only show first 500 chars to avoid token limits ```
|
||||||
|
|
||||||
|
{"..." if len(content) > 500 else ""}
|
||||||
|
|
||||||
|
Respond in JSON format only with the following structure:
|
||||||
|
{{
|
||||||
|
"is_code": true/false, # Whether this is primarily code implementation
|
||||||
|
"filename": "suggested_filename", # Don't include extension, english words
|
||||||
|
"extension": "appropriate_extension" # Don't include the dot, e.g., "md", "py", "js"
|
||||||
|
}}
|
||||||
|
"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Create a request to the model
|
||||||
|
request = LLMRequest(
|
||||||
|
messages=[{"role": "user", "content": prompt}],
|
||||||
|
temperature=0.1,
|
||||||
|
json_format=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# Call the model using the standard interface
|
||||||
|
response = model.call(request)
|
||||||
|
|
||||||
|
if response.is_error:
|
||||||
|
logger.warning(f"Error from model: {response.error_message}")
|
||||||
|
raise Exception(f"Model error: {response.error_message}")
|
||||||
|
|
||||||
|
# Extract JSON from response
|
||||||
|
result = response.data["choices"][0]["message"]["content"]
|
||||||
|
|
||||||
|
# Clean the JSON response
|
||||||
|
result = self._clean_json_response(result)
|
||||||
|
|
||||||
|
# Parse the JSON
|
||||||
|
params = json.loads(result)
|
||||||
|
|
||||||
|
# For backward compatibility, return tuple format
|
||||||
|
file_name = params.get("filename", "output")
|
||||||
|
# Remove dot from extension if present
|
||||||
|
file_type = params.get("extension", "md").lstrip(".")
|
||||||
|
extract_code = params.get("is_code", False)
|
||||||
|
|
||||||
|
return file_name, file_type, extract_code
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error getting file parameters from model: {e}")
|
||||||
|
# Default fallback
|
||||||
|
return "output", "md", False
|
||||||
|
|
||||||
|
def _get_team_name_from_context(self) -> Optional[str]:
|
||||||
|
"""
|
||||||
|
Get team name from the agent's context.
|
||||||
|
|
||||||
|
:return: Team name or None if not found
|
||||||
|
"""
|
||||||
|
if hasattr(self, 'context') and self.context:
|
||||||
|
# Try to get team name from team_context
|
||||||
|
if hasattr(self.context, 'team_context') and self.context.team_context:
|
||||||
|
return self.context.team_context.name
|
||||||
|
|
||||||
|
# Try direct team_name attribute
|
||||||
|
if hasattr(self.context, 'name'):
|
||||||
|
return self.context.name
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _get_task_id_from_context(self) -> Optional[str]:
|
||||||
|
"""
|
||||||
|
Get task ID from the agent's context.
|
||||||
|
|
||||||
|
:return: Task ID or None if not found
|
||||||
|
"""
|
||||||
|
if hasattr(self, 'context') and self.context:
|
||||||
|
# Try to get task ID from task object
|
||||||
|
if hasattr(self.context, 'task') and self.context.task:
|
||||||
|
return self.context.task.id
|
||||||
|
|
||||||
|
# Try team_context's task
|
||||||
|
if hasattr(self.context, 'team_context') and self.context.team_context:
|
||||||
|
if hasattr(self.context.team_context, 'task') and self.context.team_context.task:
|
||||||
|
return self.context.team_context.task.id
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _get_task_dir_from_context(self) -> Optional[str]:
|
||||||
|
"""
|
||||||
|
Get task directory name from the team context.
|
||||||
|
|
||||||
|
:return: Task directory name or None if not found
|
||||||
|
"""
|
||||||
|
if hasattr(self, 'context') and self.context:
|
||||||
|
# Try to get from team_context
|
||||||
|
if hasattr(self.context, 'team_context') and self.context.team_context:
|
||||||
|
if hasattr(self.context.team_context, 'task_short_name') and self.context.team_context.task_short_name:
|
||||||
|
return self.context.team_context.task_short_name
|
||||||
|
|
||||||
|
# Fall back to task ID if available
|
||||||
|
return self._get_task_id_from_context()
|
||||||
|
|
||||||
|
def _extract_content_from_context(self) -> str:
|
||||||
|
"""
|
||||||
|
Extract content from the agent's context.
|
||||||
|
|
||||||
|
:return: Extracted content
|
||||||
|
"""
|
||||||
|
# Check if we have access to the agent's context
|
||||||
|
if not hasattr(self, 'context') or not self.context:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
# Try to get the most recent final answer from the agent
|
||||||
|
if hasattr(self.context, 'final_answer') and self.context.final_answer:
|
||||||
|
return self.context.final_answer
|
||||||
|
|
||||||
|
# Try to get the most recent final answer from team context
|
||||||
|
if hasattr(self.context, 'team_context') and self.context.team_context:
|
||||||
|
if hasattr(self.context.team_context, 'agent_outputs') and self.context.team_context.agent_outputs:
|
||||||
|
latest_output = self.context.team_context.agent_outputs[-1].output
|
||||||
|
return latest_output
|
||||||
|
|
||||||
|
# If we have action history, try to get the most recent final answer
|
||||||
|
if hasattr(self.context, 'action_history') and self.context.action_history:
|
||||||
|
for action in reversed(self.context.action_history):
|
||||||
|
if "final_answer" in action and action["final_answer"]:
|
||||||
|
return action["final_answer"]
|
||||||
|
|
||||||
|
return ""
|
||||||
|
|
||||||
|
def _extract_code_blocks(self, content: str) -> str:
|
||||||
|
"""
|
||||||
|
Extract code blocks from markdown content.
|
||||||
|
|
||||||
|
:param content: The content to extract code blocks from
|
||||||
|
:return: Extracted code blocks
|
||||||
|
"""
|
||||||
|
# Pattern to match markdown code blocks
|
||||||
|
code_block_pattern = r'```(?:\w+)?\n([\s\S]*?)\n```'
|
||||||
|
|
||||||
|
# Find all code blocks
|
||||||
|
code_blocks = re.findall(code_block_pattern, content)
|
||||||
|
|
||||||
|
if code_blocks:
|
||||||
|
# Join all code blocks with newlines
|
||||||
|
return '\n\n'.join(code_blocks)
|
||||||
|
|
||||||
|
return content # Return original content if no code blocks found
|
||||||
|
|
||||||
|
def _infer_file_name(self, content: str) -> str:
|
||||||
|
"""
|
||||||
|
Infer a file name from the content.
|
||||||
|
|
||||||
|
:param content: The content to analyze.
|
||||||
|
:return: A suggested file name.
|
||||||
|
"""
|
||||||
|
# Check for title patterns in markdown
|
||||||
|
title_match = re.search(r'^#\s+(.+)$', content, re.MULTILINE)
|
||||||
|
if title_match:
|
||||||
|
# Convert title to a valid filename
|
||||||
|
title = title_match.group(1).strip()
|
||||||
|
return self._sanitize_filename(title)
|
||||||
|
|
||||||
|
# Check for class/function definitions in code
|
||||||
|
code_match = re.search(r'(class|def|function)\s+(\w+)', content)
|
||||||
|
if code_match:
|
||||||
|
return self._sanitize_filename(code_match.group(2))
|
||||||
|
|
||||||
|
# Default name based on content type
|
||||||
|
if self._is_likely_code(content):
|
||||||
|
return "code"
|
||||||
|
elif self._is_likely_markdown(content):
|
||||||
|
return "document"
|
||||||
|
elif self._is_likely_json(content):
|
||||||
|
return "data"
|
||||||
|
else:
|
||||||
|
return "output"
|
||||||
|
|
||||||
|
def _infer_file_type(self, content: str) -> str:
|
||||||
|
"""
|
||||||
|
Infer the file type/extension from the content.
|
||||||
|
|
||||||
|
:param content: The content to analyze.
|
||||||
|
:return: A suggested file extension.
|
||||||
|
"""
|
||||||
|
# Check for common programming language patterns
|
||||||
|
if re.search(r'(import\s+[a-zA-Z0-9_]+|from\s+[a-zA-Z0-9_\.]+\s+import)', content):
|
||||||
|
return "py" # Python
|
||||||
|
elif re.search(r'(public\s+class|private\s+class|protected\s+class)', content):
|
||||||
|
return "java" # Java
|
||||||
|
elif re.search(r'(function\s+\w+\s*\(|const\s+\w+\s*=|let\s+\w+\s*=|var\s+\w+\s*=)', content):
|
||||||
|
return "js" # JavaScript
|
||||||
|
elif re.search(r'(<html|<body|<div|<p>)', content):
|
||||||
|
return "html" # HTML
|
||||||
|
elif re.search(r'(#include\s+<\w+\.h>|int\s+main\s*\()', content):
|
||||||
|
return "cpp" # C/C++
|
||||||
|
|
||||||
|
# Check for markdown
|
||||||
|
if self._is_likely_markdown(content):
|
||||||
|
return "md"
|
||||||
|
|
||||||
|
# Check for JSON
|
||||||
|
if self._is_likely_json(content):
|
||||||
|
return "json"
|
||||||
|
|
||||||
|
# Default to text
|
||||||
|
return "txt"
|
||||||
|
|
||||||
|
def _is_likely_code(self, content: str) -> bool:
|
||||||
|
"""Check if the content is likely code."""
|
||||||
|
# First check for common HTML/XML patterns
|
||||||
|
if content.strip().startswith(("<!DOCTYPE", "<html", "<?xml", "<head", "<body")):
|
||||||
|
return True
|
||||||
|
|
||||||
|
code_patterns = [
|
||||||
|
r'(class|def|function|import|from|public|private|protected|#include)',
|
||||||
|
r'(\{\s*\n|\}\s*\n|\[\s*\n|\]\s*\n)',
|
||||||
|
r'(if\s*\(|for\s*\(|while\s*\()',
|
||||||
|
r'(<\w+>.*?</\w+>)', # HTML/XML tags
|
||||||
|
r'(var|let|const)\s+\w+\s*=', # JavaScript variable declarations
|
||||||
|
r'#\s*\w+', # CSS ID selectors or Python comments
|
||||||
|
r'\.\w+\s*\{', # CSS class selectors
|
||||||
|
r'@media|@import|@font-face' # CSS at-rules
|
||||||
|
]
|
||||||
|
return any(re.search(pattern, content) for pattern in code_patterns)
|
||||||
|
|
||||||
|
def _is_likely_markdown(self, content: str) -> bool:
|
||||||
|
"""Check if the content is likely markdown."""
|
||||||
|
md_patterns = [
|
||||||
|
r'^#\s+.+$', # Headers
|
||||||
|
r'^\*\s+.+$', # Unordered lists
|
||||||
|
r'^\d+\.\s+.+$', # Ordered lists
|
||||||
|
r'\[.+\]\(.+\)', # Links
|
||||||
|
r'!\[.+\]\(.+\)' # Images
|
||||||
|
]
|
||||||
|
return any(re.search(pattern, content, re.MULTILINE) for pattern in md_patterns)
|
||||||
|
|
||||||
|
def _is_likely_json(self, content: str) -> bool:
|
||||||
|
"""Check if the content is likely JSON."""
|
||||||
|
try:
|
||||||
|
content = content.strip()
|
||||||
|
if (content.startswith('{') and content.endswith('}')) or (
|
||||||
|
content.startswith('[') and content.endswith(']')):
|
||||||
|
json.loads(content)
|
||||||
|
return True
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _sanitize_filename(self, name: str) -> str:
|
||||||
|
"""
|
||||||
|
Sanitize a string to be used as a filename.
|
||||||
|
|
||||||
|
:param name: The string to sanitize.
|
||||||
|
:return: A sanitized filename.
|
||||||
|
"""
|
||||||
|
# Replace spaces with underscores
|
||||||
|
name = name.replace(' ', '_')
|
||||||
|
|
||||||
|
# Remove invalid characters
|
||||||
|
name = re.sub(r'[^\w\-\.]', '', name)
|
||||||
|
|
||||||
|
# Limit length
|
||||||
|
if len(name) > 50:
|
||||||
|
name = name[:50]
|
||||||
|
|
||||||
|
return name.lower()
|
||||||
|
|
||||||
|
def _process_file_path(self, file_path: str) -> Tuple[str, str]:
|
||||||
|
"""
|
||||||
|
Process a file path to extract the file name and type, and create directories if needed.
|
||||||
|
|
||||||
|
:param file_path: The file path to process
|
||||||
|
:return: Tuple of (file_name, file_type)
|
||||||
|
"""
|
||||||
|
# Get the file name and extension
|
||||||
|
file_name = os.path.basename(file_path)
|
||||||
|
file_type = os.path.splitext(file_name)[1].lstrip('.')
|
||||||
|
|
||||||
|
return os.path.splitext(file_name)[0], file_type
|
||||||
3
agent/tools/find/__init__.py
Normal file
3
agent/tools/find/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
from .find import Find
|
||||||
|
|
||||||
|
__all__ = ['Find']
|
||||||
177
agent/tools/find/find.py
Normal file
177
agent/tools/find/find.py
Normal file
@@ -0,0 +1,177 @@
|
|||||||
|
"""
|
||||||
|
Find tool - Search for files by glob pattern
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import glob as glob_module
|
||||||
|
from typing import Dict, Any, List
|
||||||
|
|
||||||
|
from agent.tools.base_tool import BaseTool, ToolResult
|
||||||
|
from agent.tools.utils.truncate import truncate_head, format_size, DEFAULT_MAX_BYTES
|
||||||
|
|
||||||
|
|
||||||
|
DEFAULT_LIMIT = 1000
|
||||||
|
|
||||||
|
|
||||||
|
class Find(BaseTool):
|
||||||
|
"""Tool for finding files by pattern"""
|
||||||
|
|
||||||
|
name: str = "find"
|
||||||
|
description: str = f"Search for files by glob pattern. Returns matching file paths relative to the search directory. Respects .gitignore. Output is truncated to {DEFAULT_LIMIT} results or {DEFAULT_MAX_BYTES // 1024}KB (whichever is hit first)."
|
||||||
|
|
||||||
|
params: dict = {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"pattern": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Glob pattern to match files, e.g. '*.ts', '**/*.json', or 'src/**/*.spec.ts'"
|
||||||
|
},
|
||||||
|
"path": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Directory to search in (default: current directory)"
|
||||||
|
},
|
||||||
|
"limit": {
|
||||||
|
"type": "integer",
|
||||||
|
"description": f"Maximum number of results (default: {DEFAULT_LIMIT})"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["pattern"]
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(self, config: dict = None):
|
||||||
|
self.config = config or {}
|
||||||
|
self.cwd = self.config.get("cwd", os.getcwd())
|
||||||
|
|
||||||
|
def execute(self, args: Dict[str, Any]) -> ToolResult:
|
||||||
|
"""
|
||||||
|
Execute file search
|
||||||
|
|
||||||
|
:param args: Search parameters
|
||||||
|
:return: Search results or error
|
||||||
|
"""
|
||||||
|
pattern = args.get("pattern", "").strip()
|
||||||
|
search_path = args.get("path", ".").strip()
|
||||||
|
limit = args.get("limit", DEFAULT_LIMIT)
|
||||||
|
|
||||||
|
if not pattern:
|
||||||
|
return ToolResult.fail("Error: pattern parameter is required")
|
||||||
|
|
||||||
|
# Resolve search path
|
||||||
|
absolute_path = self._resolve_path(search_path)
|
||||||
|
|
||||||
|
if not os.path.exists(absolute_path):
|
||||||
|
return ToolResult.fail(f"Error: Path not found: {search_path}")
|
||||||
|
|
||||||
|
if not os.path.isdir(absolute_path):
|
||||||
|
return ToolResult.fail(f"Error: Not a directory: {search_path}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Load .gitignore patterns
|
||||||
|
ignore_patterns = self._load_gitignore(absolute_path)
|
||||||
|
|
||||||
|
# Search for files
|
||||||
|
results = []
|
||||||
|
search_pattern = os.path.join(absolute_path, pattern)
|
||||||
|
|
||||||
|
# Use glob with recursive support
|
||||||
|
for file_path in glob_module.glob(search_pattern, recursive=True):
|
||||||
|
# Skip if matches ignore patterns
|
||||||
|
if self._should_ignore(file_path, absolute_path, ignore_patterns):
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Get relative path
|
||||||
|
relative_path = os.path.relpath(file_path, absolute_path)
|
||||||
|
|
||||||
|
# Add trailing slash for directories
|
||||||
|
if os.path.isdir(file_path):
|
||||||
|
relative_path += '/'
|
||||||
|
|
||||||
|
results.append(relative_path)
|
||||||
|
|
||||||
|
if len(results) >= limit:
|
||||||
|
break
|
||||||
|
|
||||||
|
if not results:
|
||||||
|
return ToolResult.success({"message": "No files found matching pattern", "files": []})
|
||||||
|
|
||||||
|
# Sort results
|
||||||
|
results.sort()
|
||||||
|
|
||||||
|
# Format output
|
||||||
|
raw_output = '\n'.join(results)
|
||||||
|
truncation = truncate_head(raw_output, max_lines=999999) # Only limit by bytes
|
||||||
|
|
||||||
|
output = truncation.content
|
||||||
|
details = {}
|
||||||
|
notices = []
|
||||||
|
|
||||||
|
result_limit_reached = len(results) >= limit
|
||||||
|
if result_limit_reached:
|
||||||
|
notices.append(f"{limit} results limit reached. Use limit={limit * 2} for more, or refine pattern")
|
||||||
|
details["result_limit_reached"] = limit
|
||||||
|
|
||||||
|
if truncation.truncated:
|
||||||
|
notices.append(f"{format_size(DEFAULT_MAX_BYTES)} limit reached")
|
||||||
|
details["truncation"] = truncation.to_dict()
|
||||||
|
|
||||||
|
if notices:
|
||||||
|
output += f"\n\n[{'. '.join(notices)}]"
|
||||||
|
|
||||||
|
return ToolResult.success({
|
||||||
|
"output": output,
|
||||||
|
"file_count": len(results),
|
||||||
|
"details": details if details else None
|
||||||
|
})
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return ToolResult.fail(f"Error executing find: {str(e)}")
|
||||||
|
|
||||||
|
def _resolve_path(self, path: str) -> str:
    """
    Turn a user-supplied path into an absolute path.

    A leading ``~`` is expanded to the user's home directory. A path that is
    absolute after expansion is returned unchanged (it is not normalized);
    anything else is joined onto ``self.cwd`` and normalized.

    :param path: Relative, absolute or ``~``-prefixed path
    :return: Absolute path
    """
    expanded = os.path.expanduser(path)
    if not os.path.isabs(expanded):
        # Relative paths are anchored at the tool's working directory
        return os.path.abspath(os.path.join(self.cwd, expanded))
    return expanded
|
||||||
|
|
||||||
|
def _load_gitignore(self, directory: str) -> List[str]:
    """
    Collect ignore patterns for a search rooted at ``directory``.

    Patterns come from ``<directory>/.gitignore`` (blank lines and lines
    starting with ``#`` are dropped, the rest are whitespace-stripped and kept
    verbatim), followed by a fixed set of built-in patterns.

    A missing or unreadable ``.gitignore`` is not an error: whatever was read
    before the failure is kept and the built-in patterns are still appended.

    :param directory: Directory whose ``.gitignore`` should be read
    :return: List of ignore patterns, file patterns first, built-ins last
    """
    patterns: List[str] = []
    gitignore_path = os.path.join(directory, '.gitignore')

    # Open directly instead of checking os.path.exists() first: one fewer
    # syscall and no window for the file to vanish between check and open.
    try:
        with open(gitignore_path, 'r', encoding='utf-8') as f:
            for raw_line in f:
                line = raw_line.strip()
                if line and not line.startswith('#'):
                    patterns.append(line)
    except (OSError, UnicodeDecodeError):
        # Best effort only. Catch just I/O and decoding failures so that
        # KeyboardInterrupt / SystemExit are no longer swallowed.
        pass

    # Add common ignore patterns
    patterns.extend([
        '.git',
        '__pycache__',
        '*.pyc',
        'node_modules',
        '.DS_Store',
    ])

    return patterns
|
||||||
|
|
||||||
|
def _should_ignore(self, file_path: str, base_path: str, patterns: List[str]) -> bool:
    """
    Check if a file should be ignored based on gitignore-style patterns.

    This is an approximation of gitignore semantics, not a full implementation:

    - A pattern without an inner ``/`` (``*.pyc``, ``node_modules``, ``build/``)
      is matched with shell wildcards against every component of the path
      relative to ``base_path``.
    - A pattern with an inner ``/`` (``src/gen``, ``/dist``) is matched against
      each leading prefix of the relative path, i.e. anchored at ``base_path``.
      ``*`` may span ``/`` here, which is looser than git.
    - A trailing ``/`` restricts the pattern to directories.
    - Negated patterns (``!keep``) are not supported and are skipped.

    The previous implementation used plain substring tests, so wildcard
    patterns never matched and ``.git`` also hid ``.gitignore`` / ``.github``.

    :param file_path: Path of the candidate file or directory
    :param base_path: Search root the patterns are relative to
    :param patterns: Ignore patterns (as returned by ``_load_gitignore``)
    :return: True if the path matches any pattern
    """
    import fnmatch  # function-scope import: only this method needs it

    relative_path = os.path.relpath(file_path, base_path).replace(os.sep, '/')
    parts = relative_path.split('/')
    last_index = len(parts) - 1

    for raw_pattern in patterns:
        pattern = raw_pattern.strip()
        if not pattern or pattern.startswith('#') or pattern.startswith('!'):
            continue

        dir_only = pattern.endswith('/')
        # A leading "**/" means "at any depth", which is what an unanchored
        # component match already does.
        while pattern.startswith('**/'):
            pattern = pattern[3:]
        anchored = '/' in pattern.rstrip('/')
        pattern = pattern.strip('/')
        if not pattern:
            continue

        if anchored:
            candidates = ['/'.join(parts[:i + 1]) for i in range(len(parts))]
        else:
            candidates = parts

        for index, candidate in enumerate(candidates):
            if not fnmatch.fnmatchcase(candidate, pattern):
                continue
            # A directory-only pattern that matches the final component only
            # applies when that final component really is a directory.
            if dir_only and index == last_index and not os.path.isdir(file_path):
                continue
            return True

    return False
|
||||||
48
agent/tools/google_search/google_search.py
Normal file
48
agent/tools/google_search/google_search.py
Normal file
@@ -0,0 +1,48 @@
|
|||||||
|
import requests
|
||||||
|
|
||||||
|
from agent.tools.base_tool import BaseTool, ToolResult
|
||||||
|
|
||||||
|
|
||||||
|
class GoogleSearch(BaseTool):
    """
    Web search tool backed by the Serper API (https://google.serper.dev).

    Configuration keys read from ``config``:
      - ``api_key``: Serper API key (required to execute a search)
      - ``timeout``: request timeout in seconds (optional, default 30)
    """

    name: str = "google_search"
    description: str = "A tool to perform Google searches using the Serper API."
    params: dict = {
        "type": "object",
        "properties": {
            "query": {
                "type": "string",
                "description": "The search query to perform."
            }
        },
        "required": ["query"]
    }
    config: dict = {}

    def __init__(self, config=None):
        """
        :param config: Optional tool configuration (see class docstring)
        """
        self.config = config or {}

    def execute(self, args: dict) -> ToolResult:
        """
        Run a search and return the organic results.

        :param args: Must contain ``query``
        :return: Success with the list of organic results (possibly empty), or
                 failure with an error message / the raw error payload
        """
        query = args.get("query")
        if not query:
            return ToolResult.fail(result="Error: query parameter is required")

        api_key = self.config.get("api_key")
        if not api_key:
            # Without a key the API can only answer with an auth error
            return ToolResult.fail(result="Error: Serper api_key is not configured")

        url = "https://google.serper.dev/search"
        headers = {
            "X-API-KEY": api_key,
            "Content-Type": "application/json"
        }
        data = {
            "q": query,
            "k": 10
        }

        try:
            # A timeout keeps a stalled connection from hanging the agent forever
            response = requests.post(
                url, headers=headers, json=data,
                timeout=self.config.get("timeout", 30)
            )
            result = response.json()
        except requests.RequestException as e:
            return ToolResult.fail(result=f"Error calling Serper API: {str(e)}")
        except ValueError:
            # Body was not JSON (older requests versions raise plain ValueError)
            return ToolResult.fail(result="Error: Serper API returned a non-JSON response")

        if not isinstance(result, dict):
            return ToolResult.fail(result="Error: unexpected response format from Serper API")

        # Treat every HTTP / payload error status as a failure, not only 503
        payload_status = result.get("statusCode")
        payload_failed = isinstance(payload_status, int) and payload_status >= 400
        if response.status_code >= 400 or payload_failed:
            return ToolResult.fail(result=result)

        organic = result.get('organic')
        result_data = organic if isinstance(organic, list) else []
        return ToolResult.success(result=result_data)
|
||||||
3
agent/tools/grep/__init__.py
Normal file
3
agent/tools/grep/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
from .grep import Grep
|
||||||
|
|
||||||
|
__all__ = ['Grep']
|
||||||
248
agent/tools/grep/grep.py
Normal file
248
agent/tools/grep/grep.py
Normal file
@@ -0,0 +1,248 @@
|
|||||||
|
"""
|
||||||
|
Grep tool - Search file contents for patterns
|
||||||
|
Uses ripgrep (rg) for fast searching
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import subprocess
|
||||||
|
import json
|
||||||
|
from typing import Dict, Any, List, Optional
|
||||||
|
|
||||||
|
from agent.tools.base_tool import BaseTool, ToolResult
|
||||||
|
from agent.tools.utils.truncate import (
|
||||||
|
truncate_head, truncate_line, format_size,
|
||||||
|
DEFAULT_MAX_BYTES, GREP_MAX_LINE_LENGTH
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
DEFAULT_LIMIT = 100
|
||||||
|
|
||||||
|
|
||||||
|
class Grep(BaseTool):
    """
    Tool for searching file contents.

    The search itself is delegated to ripgrep (``rg --json``). Each matching
    file is then read back to render the match line, plus optional context
    lines, as ``path:line: text`` (match) or ``path-line- text`` (context).
    """

    name: str = "grep"
    description: str = f"Search file contents for a pattern. Returns matching lines with file paths and line numbers. Respects .gitignore. Output is truncated to {DEFAULT_LIMIT} matches or {DEFAULT_MAX_BYTES // 1024}KB (whichever is hit first). Long lines are truncated to {GREP_MAX_LINE_LENGTH} chars."

    params: dict = {
        "type": "object",
        "properties": {
            "pattern": {
                "type": "string",
                "description": "Search pattern (regex or literal string)"
            },
            "path": {
                "type": "string",
                "description": "Directory or file to search (default: current directory)"
            },
            "glob": {
                "type": "string",
                "description": "Filter files by glob pattern, e.g. '*.ts' or '**/*.spec.ts'"
            },
            "ignoreCase": {
                "type": "boolean",
                "description": "Case-insensitive search (default: false)"
            },
            "literal": {
                "type": "boolean",
                "description": "Treat pattern as literal string instead of regex (default: false)"
            },
            "context": {
                "type": "integer",
                "description": "Number of lines to show before and after each match (default: 0)"
            },
            "limit": {
                "type": "integer",
                "description": f"Maximum number of matches to return (default: {DEFAULT_LIMIT})"
            }
        },
        "required": ["pattern"]
    }

    def __init__(self, config: dict = None):
        """
        :param config: Optional settings; ``cwd`` is the directory relative
                       paths are resolved against (default: process cwd)
        """
        self.config = config or {}
        self.cwd = self.config.get("cwd", os.getcwd())
        self.rg_path = self._find_ripgrep()

    def _find_ripgrep(self) -> Optional[str]:
        """Find ripgrep executable on PATH; returns None when it is not installed"""
        import shutil  # function-scope import: only needed here

        # shutil.which is portable and avoids spawning a `which` subprocess
        return shutil.which('rg')

    def execute(self, args: Dict[str, Any]) -> ToolResult:
        """
        Execute grep search

        :param args: Search parameters
        :return: Search results or error
        """
        if not self.rg_path:
            return ToolResult.fail("Error: ripgrep (rg) is not installed. Please install it first.")

        # `or` guards against explicit nulls, which .get() defaults do not cover
        pattern = (args.get("pattern") or "").strip()
        search_path = (args.get("path") or ".").strip()
        context = max(0, args.get("context") or 0)
        limit = args.get("limit") or DEFAULT_LIMIT

        if not pattern:
            return ToolResult.fail("Error: pattern parameter is required")

        # Resolve search path
        absolute_path = self._resolve_path(search_path)

        if not os.path.exists(absolute_path):
            return ToolResult.fail(f"Error: Path not found: {search_path}")

        cmd = self._build_command(args, pattern, absolute_path)

        try:
            # ripgrep's JSON output is UTF-8 regardless of the locale
            result = subprocess.run(
                cmd,
                cwd=self.cwd,
                capture_output=True,
                text=True,
                encoding='utf-8',
                errors='replace',
                timeout=30
            )

            matches = self._parse_matches(result.stdout, limit)
            match_count = len(matches)

            if match_count == 0:
                # rg exits 0 (matches) or 1 (no matches); anything else is a
                # real error such as an invalid regex, so surface it
                if result.returncode not in (0, 1):
                    detail = (result.stderr or "").strip() or f"exit code {result.returncode}"
                    return ToolResult.fail(f"Error executing grep: {detail}")
                return ToolResult.success({"message": "No matches found", "matches": []})

            output_lines, lines_truncated = self._format_matches(matches, absolute_path, context)

            # Apply byte truncation
            raw_output = '\n'.join(output_lines)
            truncation = truncate_head(raw_output, max_lines=999999)  # Only limit by bytes

            output = truncation.content
            details = {}
            notices = []

            if match_count >= limit:
                notices.append(f"{limit} matches limit reached. Use limit={limit * 2} for more, or refine pattern")
                details["match_limit_reached"] = limit

            if truncation.truncated:
                notices.append(f"{format_size(DEFAULT_MAX_BYTES)} limit reached")
                details["truncation"] = truncation.to_dict()

            if lines_truncated:
                notices.append(f"Some lines truncated to {GREP_MAX_LINE_LENGTH} chars. Use read tool to see full lines")
                details["lines_truncated"] = True

            if notices:
                output += f"\n\n[{'. '.join(notices)}]"

            return ToolResult.success({
                "output": output,
                "match_count": match_count,
                "details": details if details else None
            })

        except subprocess.TimeoutExpired:
            return ToolResult.fail("Error: Search timed out after 30 seconds")
        except Exception as e:
            return ToolResult.fail(f"Error executing grep: {str(e)}")

    def _build_command(self, args: Dict[str, Any], pattern: str, absolute_path: str) -> List[str]:
        """Assemble the ripgrep argument list for one search"""
        cmd = [
            self.rg_path,
            '--json',
            '--line-number',
            '--color=never',
            '--hidden'
        ]

        if args.get("ignoreCase", False):
            cmd.append('--ignore-case')

        if args.get("literal", False):
            cmd.append('--fixed-strings')

        glob_filter = args.get("glob")
        if glob_filter:
            # --glob=VALUE form so the value can never be read as a flag
            cmd.append(f'--glob={glob_filter}')

        # -e marks the next argument as the pattern even when it starts with
        # '-', and '--' ends option parsing before the path. Without these a
        # pattern such as "--pre=cmd" would be interpreted as an rg option.
        cmd.extend(['-e', pattern, '--', absolute_path])
        return cmd

    @staticmethod
    def _parse_matches(stdout: str, limit: int) -> List[tuple]:
        """Extract up to ``limit`` (file_path, line_number) pairs from rg --json output"""
        matches = []
        for line in stdout.splitlines():
            if not line.strip():
                continue
            try:
                event = json.loads(line)
            except json.JSONDecodeError:
                continue

            if event.get('type') != 'match':
                continue

            data = event.get('data', {})
            file_path = data.get('path', {}).get('text')
            line_number = data.get('line_number')
            if file_path and line_number:
                matches.append((file_path, line_number))
                if len(matches) >= limit:
                    break
        return matches

    @staticmethod
    def _read_lines(file_path: str) -> Optional[List[str]]:
        """Read a file as UTF-8 and split it into lines; None when it cannot be read"""
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                return f.read().split('\n')
        except (OSError, ValueError):  # ValueError covers UnicodeDecodeError
            return None

    def _format_matches(self, matches: List[tuple], absolute_path: str, context: int):
        """
        Render matches (with context lines) into output lines.

        :return: Tuple of (output lines, whether any line was truncated)
        """
        output_lines = []
        lines_truncated = False
        is_directory = os.path.isdir(absolute_path)
        # Several matches usually share a file; read each file only once
        file_cache: Dict[str, Optional[List[str]]] = {}

        for file_path, line_number in matches:
            # Format file path
            if is_directory:
                relative_path = os.path.relpath(file_path, absolute_path)
            else:
                relative_path = os.path.basename(file_path)

            if file_path not in file_cache:
                file_cache[file_path] = self._read_lines(file_path)
            file_lines = file_cache[file_path]

            if file_lines is None or line_number > len(file_lines):
                output_lines.append(f"{relative_path}:{line_number}: (unable to read file)")
                continue

            # Calculate context range (0-indexed, end exclusive)
            start = max(0, line_number - 1 - context)
            end = min(len(file_lines), line_number + context)

            for index in range(start, end):
                line_text = file_lines[index].replace('\r', '')

                # Truncate long lines
                truncated_text, was_truncated = truncate_line(line_text)
                if was_truncated:
                    lines_truncated = True

                current_line = index + 1
                if current_line == line_number:
                    output_lines.append(f"{relative_path}:{current_line}: {truncated_text}")
                else:
                    output_lines.append(f"{relative_path}-{current_line}- {truncated_text}")

        return output_lines, lines_truncated

    def _resolve_path(self, path: str) -> str:
        """Resolve path to absolute path"""
        # Expand ~ to user home directory
        path = os.path.expanduser(path)
        if os.path.isabs(path):
            return path
        return os.path.abspath(os.path.join(self.cwd, path))
|
||||||
3
agent/tools/ls/__init__.py
Normal file
3
agent/tools/ls/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
from .ls import Ls
|
||||||
|
|
||||||
|
__all__ = ['Ls']
|
||||||
125
agent/tools/ls/ls.py
Normal file
125
agent/tools/ls/ls.py
Normal file
@@ -0,0 +1,125 @@
|
|||||||
|
"""
|
||||||
|
Ls tool - List directory contents
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from typing import Dict, Any
|
||||||
|
|
||||||
|
from agent.tools.base_tool import BaseTool, ToolResult
|
||||||
|
from agent.tools.utils.truncate import truncate_head, format_size, DEFAULT_MAX_BYTES
|
||||||
|
|
||||||
|
|
||||||
|
DEFAULT_LIMIT = 500
|
||||||
|
|
||||||
|
|
||||||
|
class Ls(BaseTool):
    """
    Tool for listing directory contents.

    Entries are sorted case-insensitively, directories get a trailing ``/``,
    and the listing is capped both by entry count and by output size.
    """

    name: str = "ls"
    description: str = f"List directory contents. Returns entries sorted alphabetically, with '/' suffix for directories. Includes dotfiles. Output is truncated to {DEFAULT_LIMIT} entries or {DEFAULT_MAX_BYTES // 1024}KB (whichever is hit first)."

    params: dict = {
        "type": "object",
        "properties": {
            "path": {
                "type": "string",
                "description": "Directory to list (default: current directory)"
            },
            "limit": {
                "type": "integer",
                "description": f"Maximum number of entries to return (default: {DEFAULT_LIMIT})"
            }
        },
        "required": []
    }

    def __init__(self, config: dict = None):
        """
        :param config: Optional settings; ``cwd`` is the directory relative
                       paths are resolved against (default: process cwd)
        """
        self.config = config or {}
        self.cwd = self.config.get("cwd", os.getcwd())

    def execute(self, args: Dict[str, Any]) -> ToolResult:
        """
        Execute directory listing

        :param args: Listing parameters
        :return: Directory contents or error
        """
        # `or` guards against an explicit null, which .get() defaults do not cover
        path = (args.get("path") or ".").strip()
        limit = args.get("limit")
        if not isinstance(limit, int) or limit <= 0:
            # A zero/negative limit used to report a non-empty directory as
            # "(empty directory)"; fall back to the default instead
            limit = DEFAULT_LIMIT

        # Resolve path
        absolute_path = self._resolve_path(path)

        if not os.path.exists(absolute_path):
            return ToolResult.fail(f"Error: Path not found: {path}")

        if not os.path.isdir(absolute_path):
            return ToolResult.fail(f"Error: Not a directory: {path}")

        try:
            # Read directory entries, sorted alphabetically (case-insensitive)
            entries = sorted(os.listdir(absolute_path), key=str.lower)

            entry_limit_reached = len(entries) > limit

            # Format entries with directory indicators. os.path.isdir() reports
            # False rather than raising for entries it cannot stat, so no
            # exception handling is needed per entry.
            results = [
                entry + '/' if os.path.isdir(os.path.join(absolute_path, entry)) else entry
                for entry in entries[:limit]
            ]

            if not results:
                return ToolResult.success({"message": "(empty directory)", "entries": []})

            # Format output
            raw_output = '\n'.join(results)
            truncation = truncate_head(raw_output, max_lines=999999)  # Only limit by bytes

            output = truncation.content
            details = {}
            notices = []

            if entry_limit_reached:
                notices.append(f"{limit} entries limit reached. Use limit={limit * 2} for more")
                details["entry_limit_reached"] = limit

            if truncation.truncated:
                notices.append(f"{format_size(DEFAULT_MAX_BYTES)} limit reached")
                details["truncation"] = truncation.to_dict()

            if notices:
                output += f"\n\n[{'. '.join(notices)}]"

            return ToolResult.success({
                "output": output,
                "entry_count": len(results),
                "details": details if details else None
            })

        except PermissionError:
            return ToolResult.fail(f"Error: Permission denied reading directory: {path}")
        except Exception as e:
            return ToolResult.fail(f"Error listing directory: {str(e)}")

    def _resolve_path(self, path: str) -> str:
        """Resolve path to absolute path"""
        # Expand ~ to user home directory
        path = os.path.expanduser(path)
        if os.path.isabs(path):
            return path
        return os.path.abspath(os.path.join(self.cwd, path))
|
||||||
10
agent/tools/memory/__init__.py
Normal file
10
agent/tools/memory/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
"""
|
||||||
|
Memory tools for Agent
|
||||||
|
|
||||||
|
Provides memory_search and memory_get tools
|
||||||
|
"""
|
||||||
|
|
||||||
|
from agent.tools.memory.memory_search import MemorySearchTool
|
||||||
|
from agent.tools.memory.memory_get import MemoryGetTool
|
||||||
|
|
||||||
|
__all__ = ['MemorySearchTool', 'MemoryGetTool']
|
||||||
107
agent/tools/memory/memory_get.py
Normal file
107
agent/tools/memory/memory_get.py
Normal file
@@ -0,0 +1,107 @@
|
|||||||
|
"""
|
||||||
|
Memory get tool
|
||||||
|
|
||||||
|
Allows agents to read specific sections from memory files
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Dict, Any
|
||||||
|
from pathlib import Path
|
||||||
|
from agent.tools.base_tool import BaseTool
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryGetTool(BaseTool):
    """Tool for reading memory file contents"""

    name: str = "memory_get"
    description: str = (
        "Read specific content from memory files. "
        "Use this to get full context from a memory file or specific line range."
    )
    params: dict = {
        "type": "object",
        "properties": {
            "path": {
                "type": "string",
                "description": "Relative path to the memory file (e.g., 'MEMORY.md', 'memory/2024-01-29.md')"
            },
            "start_line": {
                "type": "integer",
                "description": "Starting line number (optional, default: 1)",
                "default": 1
            },
            "num_lines": {
                "type": "integer",
                "description": "Number of lines to read (optional, reads all if not specified)"
            }
        },
        "required": ["path"]
    }

    def __init__(self, memory_manager):
        """
        Initialize memory get tool

        Args:
            memory_manager: MemoryManager instance
        """
        super().__init__()
        self.memory_manager = memory_manager

    def execute(self, args: dict):
        """
        Execute memory file read

        Args:
            args: Dictionary with path, start_line, num_lines

        Returns:
            ToolResult with a "File/Lines" header followed by the selected
            lines, or a failure result when the file is missing, unreadable,
            or start_line lies beyond the end of the file
        """
        from agent.tools.base_tool import ToolResult

        path = args.get("path")
        # `or 1` also covers an explicit null; values below 1 are clamped to 1
        start_line = max(1, args.get("start_line") or 1)
        num_lines = args.get("num_lines")

        if not path:
            return ToolResult.fail("Error: path parameter is required")

        try:
            workspace_dir = self.memory_manager.config.get_workspace()
            file_path = workspace_dir / path

            if not file_path.exists():
                return ToolResult.fail(f"Error: File not found: {path}")

            # Explicit encoding: the platform default is not UTF-8 everywhere
            content = file_path.read_text(encoding='utf-8')
            lines = content.split('\n')
            total_lines = len(lines)

            start_idx = start_line - 1
            if start_idx >= total_lines:
                # Previously this produced an empty success such as "Lines: 10-9"
                return ToolResult.fail(
                    f"Error: start_line {start_line} is beyond end of file ({total_lines} lines total)"
                )

            # Handle line range; a missing or non-positive num_lines reads to the end
            if num_lines is not None and num_lines > 0:
                selected_lines = lines[start_idx:start_idx + num_lines]
            else:
                selected_lines = lines[start_idx:]

            result = '\n'.join(selected_lines)

            # Add metadata
            shown_lines = len(selected_lines)

            output = [
                f"File: {path}",
                f"Lines: {start_line}-{start_line + shown_lines - 1} (total: {total_lines})",
                "",
                result
            ]

            return ToolResult.success('\n'.join(output))

        except Exception as e:
            return ToolResult.fail(f"Error reading memory file: {str(e)}")
|
||||||
96
agent/tools/memory/memory_search.py
Normal file
96
agent/tools/memory/memory_search.py
Normal file
@@ -0,0 +1,96 @@
|
|||||||
|
"""
|
||||||
|
Memory search tool
|
||||||
|
|
||||||
|
Allows agents to search their memory using semantic and keyword search
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Dict, Any, Optional
|
||||||
|
from agent.tools.base_tool import BaseTool
|
||||||
|
|
||||||
|
|
||||||
|
class MemorySearchTool(BaseTool):
    """Tool for searching agent memory"""

    name: str = "memory_search"
    description: str = (
        "Search agent's long-term memory using semantic and keyword search. "
        "Use this to recall past conversations, preferences, and knowledge."
    )
    params: dict = {
        "type": "object",
        "properties": {
            "query": {
                "type": "string",
                "description": "Search query (can be natural language question or keywords)"
            },
            "max_results": {
                "type": "integer",
                "description": "Maximum number of results to return (default: 10)",
                "default": 10
            },
            "min_score": {
                "type": "number",
                "description": "Minimum relevance score (0-1, default: 0.3)",
                "default": 0.3
            }
        },
        "required": ["query"]
    }

    def __init__(self, memory_manager, user_id: Optional[str] = None):
        """
        Initialize memory search tool

        Args:
            memory_manager: MemoryManager instance
            user_id: Optional user ID for scoped search
        """
        super().__init__()
        self.memory_manager = memory_manager
        self.user_id = user_id

    @staticmethod
    def _run_coroutine(coro):
        """
        Run a coroutine to completion from synchronous code.

        asyncio.run() raises RuntimeError when the calling thread already has
        a running event loop. In that case the coroutine is executed on a
        fresh loop in a short-lived worker thread instead.
        """
        import asyncio
        import concurrent.futures

        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No loop in this thread: the simple path is safe
            return asyncio.run(coro)

        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
            return pool.submit(asyncio.run, coro).result()

    def execute(self, args: dict):
        """
        Execute memory search

        Args:
            args: Dictionary with query, max_results, min_score

        Returns:
            ToolResult with formatted search results
        """
        from agent.tools.base_tool import ToolResult

        query = args.get("query")
        # Explicit nulls fall back to the documented defaults; 0 stays valid
        max_results = args.get("max_results")
        if max_results is None:
            max_results = 10
        min_score = args.get("min_score")
        if min_score is None:
            min_score = 0.3

        if not query:
            return ToolResult.fail("Error: query parameter is required")

        try:
            # Run async search in sync context
            results = self._run_coroutine(self.memory_manager.search(
                query=query,
                user_id=self.user_id,
                max_results=max_results,
                min_score=min_score,
                include_shared=True
            ))

            if not results:
                return ToolResult.success(f"No relevant memories found for query: {query}")

            # Format results
            output = [f"Found {len(results)} relevant memories:\n"]

            for i, result in enumerate(results, 1):
                output.append(f"\n{i}. {result.path} (lines {result.start_line}-{result.end_line})")
                output.append(f"   Score: {result.score:.3f}")
                output.append(f"   Snippet: {result.snippet}")

            return ToolResult.success("\n".join(output))

        except Exception as e:
            return ToolResult.fail(f"Error searching memory: {str(e)}")
|
||||||
3
agent/tools/read/__init__.py
Normal file
3
agent/tools/read/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
from .read import Read
|
||||||
|
|
||||||
|
__all__ = ['Read']
|
||||||
336
agent/tools/read/read.py
Normal file
336
agent/tools/read/read.py
Normal file
@@ -0,0 +1,336 @@
|
|||||||
|
"""
|
||||||
|
Read tool - Read file contents
|
||||||
|
Supports text files, images (jpg, png, gif, webp), and PDF files
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from typing import Dict, Any
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from agent.tools.base_tool import BaseTool, ToolResult
|
||||||
|
from agent.tools.utils.truncate import truncate_head, format_size, DEFAULT_MAX_LINES, DEFAULT_MAX_BYTES
|
||||||
|
|
||||||
|
|
||||||
|
class Read(BaseTool):
|
||||||
|
"""Tool for reading file contents"""
|
||||||
|
|
||||||
|
name: str = "read"
|
||||||
|
description: str = f"Read the contents of a file. Supports text files, PDF files, and images (jpg, png, gif, webp). For text files, output is truncated to {DEFAULT_MAX_LINES} lines or {DEFAULT_MAX_BYTES // 1024}KB (whichever is hit first). Use offset/limit for large files."
|
||||||
|
|
||||||
|
params: dict = {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"path": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Path to the file to read (relative or absolute)"
|
||||||
|
},
|
||||||
|
"offset": {
|
||||||
|
"type": "integer",
|
||||||
|
"description": "Line number to start reading from (1-indexed, optional)"
|
||||||
|
},
|
||||||
|
"limit": {
|
||||||
|
"type": "integer",
|
||||||
|
"description": "Maximum number of lines to read (optional)"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["path"]
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(self, config: dict = None):
    """
    Set up the tool from an optional configuration dict.

    :param config: Optional settings; ``cwd`` is the directory relative paths
                   are resolved against (default: the process working directory)
    """
    settings = config or {}
    self.config = settings
    self.cwd = settings.get("cwd", os.getcwd())
    # File suffixes (lower-case, with dot) routed to the image reader
    self.image_extensions = {'.jpg', '.jpeg', '.png', '.gif', '.webp'}
    # File suffixes routed to the PDF reader
    self.pdf_extensions = {'.pdf'}
|
||||||
|
|
||||||
|
def execute(self, args: Dict[str, Any]) -> ToolResult:
    """
    Execute file read operation

    Validates the path, then dispatches on the file suffix to the image,
    PDF or text reader.

    :param args: Contains file path and optional offset/limit parameters
    :return: File content or error message
    """
    # `or ""` also covers an explicit null, which would otherwise raise
    # AttributeError on .strip() outside of any error handling
    path = str(args.get("path") or "").strip()
    offset = args.get("offset")
    limit = args.get("limit")

    if not path:
        return ToolResult.fail("Error: path parameter is required")

    # Resolve path
    absolute_path = self._resolve_path(path)

    # Check if file exists
    if not os.path.exists(absolute_path):
        return ToolResult.fail(f"Error: File not found: {path}")

    # Directories pass the exists/readable checks but cannot be read as files
    if os.path.isdir(absolute_path):
        return ToolResult.fail(f"Error: Path is a directory, not a file: {path}")

    # Check if readable
    if not os.access(absolute_path, os.R_OK):
        return ToolResult.fail(f"Error: File is not readable: {path}")

    # Check file type
    file_ext = Path(absolute_path).suffix.lower()

    if file_ext in self.image_extensions:
        return self._read_image(absolute_path, file_ext)

    if file_ext in self.pdf_extensions:
        return self._read_pdf(absolute_path, path, offset, limit)

    # Everything else is treated as a text file
    return self._read_text(absolute_path, path, offset, limit)
|
||||||
|
|
||||||
|
def _resolve_path(self, path: str) -> str:
    """
    Turn a user-supplied path into an absolute path.

    A leading ``~`` is expanded to the user's home directory. A path that is
    absolute after expansion is returned unchanged (it is not normalized);
    anything else is joined onto ``self.cwd`` and normalized.

    :param path: Relative or absolute path
    :return: Absolute path
    """
    expanded = os.path.expanduser(path)
    if not os.path.isabs(expanded):
        # Relative paths are anchored at the tool's working directory
        return os.path.abspath(os.path.join(self.cwd, expanded))
    return expanded
|
||||||
|
|
||||||
|
def _read_image(self, absolute_path: str, file_ext: str) -> ToolResult:
    """
    Load an image file and return it base64-encoded with basic metadata.

    :param absolute_path: Absolute path to the image file
    :param file_ext: Lower-case file extension including the dot
    :return: Success result with type, MIME type, size and base64 data, or a
             failure result if the file cannot be read
    """
    try:
        with open(absolute_path, 'rb') as image_file:
            raw_bytes = image_file.read()

        import base64
        encoded = base64.b64encode(raw_bytes).decode('utf-8')

        # Unknown extensions fall back to JPEG
        mime_types = {
            '.jpg': 'image/jpeg',
            '.jpeg': 'image/jpeg',
            '.png': 'image/png',
            '.gif': 'image/gif',
            '.webp': 'image/webp'
        }
        byte_count = len(raw_bytes)

        return ToolResult.success({
            "type": "image",
            "mime_type": mime_types.get(file_ext, 'image/jpeg'),
            "size": byte_count,
            "size_formatted": format_size(byte_count),
            "data": encoded  # Base64 encoded image data
        })

    except Exception as e:
        return ToolResult.fail(f"Error reading image file: {str(e)}")
|
||||||
|
|
||||||
|
def _read_text(self, absolute_path: str, display_path: str, offset: int = None, limit: int = None) -> ToolResult:
    """
    Read a UTF-8 text file, optionally windowed by line, and truncate large output.

    The file is split on ``\\n``; ``offset``/``limit`` select a window of those
    lines, and the selection is then passed through ``truncate_head``. The
    returned content carries a trailing bracketed hint whenever more lines
    remain, telling the caller which ``offset`` continues the read.

    :param absolute_path: Absolute path to the file
    :param display_path: Path to display in messages and shell hints
    :param offset: Starting line number (1-indexed); values below 1 start at line 1
    :param limit: Maximum number of lines to read; negative values are treated as 0
    :return: File content with line bookkeeping, or an error message
    """
    try:
        # Read file
        with open(absolute_path, 'r', encoding='utf-8') as f:
            content = f.read()

        all_lines = content.split('\n')
        total_file_lines = len(all_lines)

        # Apply offset (if specified)
        start_line = 0
        if offset is not None:
            start_line = max(0, offset - 1)  # Convert to 0-indexed
            if start_line >= total_file_lines:
                return ToolResult.fail(
                    f"Error: Offset {offset} is beyond end of file ({total_file_lines} lines total)"
                )

        start_line_display = start_line + 1  # For display (1-indexed)

        # If user specified limit, use it
        selected_content = content
        user_limited_lines = None
        if limit is not None:
            # A negative limit would become a negative slice bound and produce
            # nonsense counters below, so it is clamped to "zero lines".
            end_line = min(start_line + max(0, limit), total_file_lines)
            selected_content = '\n'.join(all_lines[start_line:end_line])
            user_limited_lines = end_line - start_line
        elif offset is not None:
            selected_content = '\n'.join(all_lines[start_line:])

        # Apply truncation (considering line count and byte limits)
        truncation = truncate_head(selected_content)

        output_text = ""
        details = {}

        if truncation.first_line_exceeds_limit:
            # The first selected line alone exceeds the byte limit, so nothing
            # can be returned without cutting a line in half.
            import shlex

            first_line_size = format_size(len(all_lines[start_line].encode('utf-8')))
            # Select the line first, then cap its bytes. Capping the whole
            # file first would only ever work for line 1.
            output_text = (
                f"[Line {start_line_display} is {first_line_size}, exceeds {format_size(DEFAULT_MAX_BYTES)} limit. "
                f"Use bash tool to read: sed -n '{start_line_display}p' {shlex.quote(str(display_path))} "
                f"| head -c {DEFAULT_MAX_BYTES}]"
            )
            details["truncation"] = truncation.to_dict()
        elif truncation.truncated:
            # Truncation occurred
            end_line_display = start_line_display + truncation.output_lines - 1
            next_offset = end_line_display + 1

            output_text = truncation.content

            if truncation.truncated_by == "lines":
                output_text += f"\n\n[Showing lines {start_line_display}-{end_line_display} of {total_file_lines}. Use offset={next_offset} to continue.]"
            else:
                output_text += f"\n\n[Showing lines {start_line_display}-{end_line_display} of {total_file_lines} ({format_size(DEFAULT_MAX_BYTES)} limit). Use offset={next_offset} to continue.]"

            details["truncation"] = truncation.to_dict()
        elif user_limited_lines is not None and start_line + user_limited_lines < total_file_lines:
            # User specified limit, more content available, but no truncation
            remaining = total_file_lines - (start_line + user_limited_lines)
            next_offset = start_line + user_limited_lines + 1

            output_text = truncation.content
            output_text += f"\n\n[{remaining} more lines in file. Use offset={next_offset} to continue.]"
        else:
            # No truncation, no exceeding user limit
            output_text = truncation.content

        result = {
            "content": output_text,
            "total_lines": total_file_lines,
            "start_line": start_line_display,
            "output_lines": truncation.output_lines
        }

        if details:
            result["details"] = details

        return ToolResult.success(result)

    except UnicodeDecodeError:
        return ToolResult.fail(f"Error: File is not a valid text file (encoding error): {display_path}")
    except Exception as e:
        return ToolResult.fail(f"Error reading file: {str(e)}")
|
||||||
|
|
||||||
|
def _read_pdf(self, absolute_path: str, display_path: str, offset: int = None, limit: int = None) -> ToolResult:
    """
    Extract the text of a PDF and return it with the same windowing as text files.

    Each page that yields non-blank text is prefixed with a ``--- Page N ---``
    marker; the pages are joined, split on ``\\n``, windowed by
    ``offset``/``limit`` and finally passed through ``truncate_head``.

    :param absolute_path: Absolute path to the file
    :param display_path: Path to display (not used by this method's messages)
    :param offset: Starting line number (1-indexed); values below 1 start at line 1
    :param limit: Maximum number of lines to read; negative values are treated as 0
    :return: PDF text content or error message
    """
    try:
        # pypdf is optional; report how to install it instead of crashing.
        try:
            from pypdf import PdfReader
        except ImportError:
            return ToolResult.fail(
                "Error: pypdf library not installed. Install with: pip install pypdf"
            )

        # Read PDF
        reader = PdfReader(absolute_path)
        total_pages = len(reader.pages)

        # Extract text from all pages
        text_parts = []
        for page_num, page in enumerate(reader.pages, 1):
            # Guard against a page yielding no text object at all.
            page_text = page.extract_text() or ""
            if page_text.strip():
                text_parts.append(f"--- Page {page_num} ---\n{page_text}")

        if not text_parts:
            return ToolResult.success({
                "content": f"[PDF file with {total_pages} pages, but no text content could be extracted]",
                "total_pages": total_pages,
                "message": "PDF may contain only images or be encrypted"
            })

        # Merge all text
        full_content = "\n\n".join(text_parts)
        all_lines = full_content.split('\n')
        total_lines = len(all_lines)

        # Apply offset and limit (same logic as text files)
        start_line = 0
        if offset is not None:
            start_line = max(0, offset - 1)
            if start_line >= total_lines:
                return ToolResult.fail(
                    f"Error: Offset {offset} is beyond end of content ({total_lines} lines total)"
                )

        start_line_display = start_line + 1

        selected_content = full_content
        user_limited_lines = None
        if limit is not None:
            # Clamp so a negative limit cannot become a negative slice bound.
            end_line = min(start_line + max(0, limit), total_lines)
            selected_content = '\n'.join(all_lines[start_line:end_line])
            user_limited_lines = end_line - start_line
        elif offset is not None:
            selected_content = '\n'.join(all_lines[start_line:])

        # Apply truncation
        truncation = truncate_head(selected_content)

        output_text = ""
        details = {}

        if truncation.first_line_exceeds_limit:
            # Nothing was emitted because the first selected line is larger
            # than the byte limit. Without this branch the caller would be
            # told to "continue" at the same offset and loop forever.
            first_line_size = format_size(len(all_lines[start_line].encode('utf-8')))
            output_text = (
                f"[Line {start_line_display} is {first_line_size}, exceeds "
                f"{format_size(DEFAULT_MAX_BYTES)} limit and cannot be shown."
            )
            if start_line_display < total_lines:
                output_text += f" Use offset={start_line_display + 1} to skip it."
            output_text += "]"
            details["truncation"] = truncation.to_dict()
        elif truncation.truncated:
            end_line_display = start_line_display + truncation.output_lines - 1
            next_offset = end_line_display + 1

            output_text = truncation.content

            if truncation.truncated_by == "lines":
                output_text += f"\n\n[Showing lines {start_line_display}-{end_line_display} of {total_lines}. Use offset={next_offset} to continue.]"
            else:
                output_text += f"\n\n[Showing lines {start_line_display}-{end_line_display} of {total_lines} ({format_size(DEFAULT_MAX_BYTES)} limit). Use offset={next_offset} to continue.]"

            details["truncation"] = truncation.to_dict()
        elif user_limited_lines is not None and start_line + user_limited_lines < total_lines:
            # The caller's own limit stopped the read before the end.
            remaining = total_lines - (start_line + user_limited_lines)
            next_offset = start_line + user_limited_lines + 1

            output_text = truncation.content
            output_text += f"\n\n[{remaining} more lines in file. Use offset={next_offset} to continue.]"
        else:
            output_text = truncation.content

        result = {
            "content": output_text,
            "total_pages": total_pages,
            "total_lines": total_lines,
            "start_line": start_line_display,
            "output_lines": truncation.output_lines
        }

        if details:
            result["details"] = details

        return ToolResult.success(result)

    except Exception as e:
        return ToolResult.fail(f"Error reading PDF file: {str(e)}")
|
||||||
3
agent/tools/terminal/__init__.py
Normal file
3
agent/tools/terminal/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
from .terminal import Terminal
|
||||||
|
|
||||||
|
__all__ = ['Terminal']
|
||||||
100
agent/tools/terminal/terminal.py
Normal file
100
agent/tools/terminal/terminal.py
Normal file
@@ -0,0 +1,100 @@
|
|||||||
|
import platform
|
||||||
|
import subprocess
|
||||||
|
from typing import Dict, Any
|
||||||
|
|
||||||
|
from agent.tools.base_tool import BaseTool, ToolResult
|
||||||
|
|
||||||
|
|
||||||
|
class Terminal(BaseTool):
    """
    Tool that runs a shell command on the local machine.

    Commands are screened by a small deny-list before being handed to the
    shell. The screening is a best-effort filter, not a sandbox.
    """
    name: str = "terminal"
    description: str = "A tool to run terminal commands on the local system"
    params: dict = {
        "type": "object",
        "properties": {
            "command": {
                "type": "string",
                "description": f"The terminal command to execute which should be valid in {platform.system()} platform"
            }
        },
        "required": ["command"]
    }
    config: dict = {}

    def __init__(self, config=None):
        """
        :param config: Optional tool configuration; ``timeout`` (seconds) is read from it
        """
        self.config = config or {}
        # Set of dangerous commands that should be blocked
        self.command_ban_set = {"halt", "poweroff", "shutdown", "reboot", "rm", "kill",
                                "exit", "sudo", "su", "userdel", "groupdel", "logout", "alias"}

    def execute(self, args: Dict[str, Any]) -> ToolResult:
        """
        Execute a terminal command after the deny-list check.

        :param args: Dictionary containing the command to execute under ``command``
        :return: Success result with stdout/stderr/return code, or a failure
            result for rejected, failing, or timed-out commands
        """
        # A missing or null command becomes an empty string, which is rejected below.
        command = str(args.get("command") or "").strip()

        # Check if the command is safe to execute
        if not self._is_safe_command(command):
            return ToolResult.fail(result=f"Command '{command}' is not allowed for security reasons.")

        try:
            # Resolve the timeout once so the value enforced and the value
            # reported in the timeout message cannot drift apart.
            timeout = self.config.get("timeout", 30)
            result = subprocess.run(
                command,
                shell=True,
                check=True,  # Raise exception on non-zero return code
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                text=True,
                timeout=timeout
            )

            return ToolResult.success({
                "stdout": result.stdout,
                "stderr": result.stderr,
                "return_code": result.returncode,
                "command": command
            })
        except subprocess.CalledProcessError as e:
            # Non-zero exit: still return the captured output to the caller.
            return ToolResult.fail({
                "stdout": e.stdout,
                "stderr": e.stderr,
                "return_code": e.returncode,
                "command": command
            })
        except subprocess.TimeoutExpired:
            return ToolResult.fail(result=f"Command timed out after {timeout} seconds.")
        except Exception as e:
            return ToolResult.fail(result=f"Error executing command: {str(e)}")

    def _is_safe_command(self, command: str) -> bool:
        """
        Check if a command is safe to execute.

        Every command in a chain or pipeline (split on ``;``, ``&&``, ``||``,
        ``|``, ``&`` and newlines) is checked, not just the first one, so that
        ``ls; rm x`` is rejected as well as ``rm x``.

        :param command: The command to check
        :return: True if the command passes the deny-list, False otherwise
        """
        if not command.split():
            return False

        # Check for sudo/su anywhere in the command line
        lowered = command.lower()
        if any(banned in lowered for banned in ["sudo ", "su -"]):
            return False

        # Flags that make rm-like commands destructive (looked up in the whole line)
        has_rm_style_flag = "-rf" in command or "-r" in command or "-f" in command

        # Break the line into the individual commands the shell would run.
        segments_text = command
        for separator in ("&&", "||", ";", "|", "&", "\n"):
            segments_text = segments_text.replace(separator, "\n")

        for segment in segments_text.split("\n"):
            parts = segment.split()
            if not parts:
                continue

            base_cmd = parts[0].lower()
            # Also compare the bare executable name so "/bin/rm" is treated like "rm".
            executable = base_cmd.rsplit("/", 1)[-1]

            if base_cmd in self.command_ban_set or executable in self.command_ban_set:
                return False

            # Check for rm -rf or similar dangerous patterns
            if "rm" in base_cmd and has_rm_style_flag:
                return False

        # Additional security checks can be added here

        return True
|
||||||
208
agent/tools/tool_manager.py
Normal file
208
agent/tools/tool_manager.py
Normal file
@@ -0,0 +1,208 @@
|
|||||||
|
import importlib
|
||||||
|
import importlib.util
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, Any, Type
|
||||||
|
from agent.tools.base_tool import BaseTool
|
||||||
|
from common.log import logger
|
||||||
|
|
||||||
|
|
||||||
|
class ToolManager:
    """
    Singleton registry of tool classes.

    Tool classes (not instances) are stored by tool name. ``create_tool``
    builds a fresh instance on demand and applies the per-tool configuration
    recorded by ``load_tools``.
    """
    _instance = None

    def __new__(cls):
        """Singleton pattern to ensure only one instance of ToolManager exists."""
        if cls._instance is None:
            cls._instance = super(ToolManager, cls).__new__(cls)
            cls._instance.tool_classes = {}  # Store tool classes instead of instances
            cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        # __new__ already attached tool_classes to the shared instance; only
        # create it when missing so repeated ToolManager() calls keep state.
        if not hasattr(self, 'tool_classes'):
            self.tool_classes = {}  # Dictionary to store tool classes

    def load_tools(self, tools_dir: str = "", config_dict=None):
        """
        Load tools from a directory (or the ``agent.tools`` package) and record their configuration.

        :param tools_dir: Directory to scan for tool modules; when empty, the
            classes exported by ``agent.tools.__all__`` are loaded instead
        :param config_dict: Optional mapping of tool name to that tool's configuration
        """
        if tools_dir:
            self._load_tools_from_directory(tools_dir)
        else:
            self._load_tools_from_init()
        # The configuration applies regardless of where the classes came from.
        self._configure_tools_from_config(config_dict)

    @staticmethod
    def _is_tool_class(candidate) -> bool:
        """Return True when candidate is a concrete subclass of BaseTool."""
        return (
            isinstance(candidate, type)
            and issubclass(candidate, BaseTool)
            and candidate is not BaseTool
        )

    def _register_tool_class(self, cls, source=None):
        """
        Register cls under the ``name`` of a temporary instance.

        :param cls: Tool class to register
        :param source: Optional label of where the class was found, used for the debug log
        """
        try:
            # Create a temporary instance to get the name
            tool_name = cls().name
            # Store the class, not the instance
            self.tool_classes[tool_name] = cls
            if source is not None:
                logger.debug(f"Loaded tool: {tool_name} from class {source}")
        except ImportError as e:
            # Ignore browser_use dependency missing errors
            if "browser_use" not in str(e):
                logger.error(f"Error initializing tool class {cls.__name__}: {e}")
        except Exception as e:
            logger.error(f"Error initializing tool class {cls.__name__}: {e}")

    def _load_tools_from_init(self) -> bool:
        """
        Load tool classes from tools.__init__.__all__

        :return: True if tools were loaded, False otherwise
        """
        try:
            # Try to import the tools package
            tools_package = importlib.import_module("agent.tools")

            # Without __all__ there is no list of classes to load.
            if not hasattr(tools_package, "__all__"):
                return False

            for class_name in tools_package.__all__:
                try:
                    # Skip base classes
                    if class_name in ["BaseTool", "ToolManager"]:
                        continue

                    candidate = getattr(tools_package, class_name, None)
                    if self._is_tool_class(candidate):
                        self._register_tool_class(candidate, source=class_name)
                except Exception as e:
                    logger.error(f"Error importing class {class_name}: {e}")

            return len(self.tool_classes) > 0
        except ImportError:
            logger.warning("Could not import agent.tools package")
            return False
        except Exception as e:
            logger.error(f"Error loading tools from __init__.__all__: {e}")
            return False

    def _load_tools_from_directory(self, tools_dir: str):
        """
        Dynamically load tool classes from every ``.py`` file under tools_dir.

        :param tools_dir: Directory that is searched recursively for tool modules
        """
        skipped_files = ("__init__.py", "base_tool.py", "tool_manager.py")

        for py_file in Path(tools_dir).rglob("*.py"):
            # Skip initialization files and base tool files
            if py_file.name in skipped_files:
                continue

            try:
                # Load module directly from file
                spec = importlib.util.spec_from_file_location(py_file.stem, py_file)
                if not (spec and spec.loader):
                    continue

                module = importlib.util.module_from_spec(spec)
                spec.loader.exec_module(module)

                # Find tool classes in the module
                for attr_name in dir(module):
                    candidate = getattr(module, attr_name)
                    if self._is_tool_class(candidate):
                        self._register_tool_class(candidate)
            except Exception as e:
                # Use the logger like the package loader does instead of print().
                logger.error(f"Error importing module {py_file}: {e}")

    def _configure_tools_from_config(self, config_dict=None):
        """
        Record per-tool configuration and warn about configured tools that are not loaded.

        :param config_dict: Mapping of tool name to configuration; treated as empty when omitted
        """
        try:
            # No configuration loader is imported in this module, so an
            # omitted config_dict simply means "no tool configuration".
            tools_config = config_dict or {}

            # Store configurations for later use when instantiating
            self.tool_configs = tools_config

            # Tools that are configured but whose class was never registered
            missing_tools = [tool_name for tool_name in tools_config if tool_name not in self.tool_classes]

            for tool_name in missing_tools:
                if tool_name == "browser":
                    logger.error(
                        "Browser tool is configured but could not be loaded. "
                        "Please install the required dependency with: "
                        "pip install browser-use>=0.1.40 or pip install agentmesh-sdk[full]"
                    )
                else:
                    logger.warning(f"Tool '{tool_name}' is configured but could not be loaded.")
        except Exception as e:
            logger.error(f"Error configuring tools from config: {e}")

    def create_tool(self, name: str) -> BaseTool:
        """
        Get a new instance of a tool by name.

        :param name: The name of the tool to get.
        :return: A new instance of the tool or None if not found.
        """
        tool_class = self.tool_classes.get(name)
        if tool_class:
            # Create a new instance
            tool_instance = tool_class()

            # Apply configuration if available
            if hasattr(self, 'tool_configs') and name in self.tool_configs:
                tool_instance.config = self.tool_configs[name]

            return tool_instance
        return None

    def list_tools(self) -> dict:
        """
        Get information about all loaded tools.

        :return: A dictionary mapping tool name to its description and JSON schema.
        """
        result = {}
        for name, tool_class in self.tool_classes.items():
            # Create a temporary instance to get schema
            temp_instance = tool_class()
            result[name] = {
                "description": temp_instance.description,
                "parameters": temp_instance.get_json_schema()
            }
        return result
|
||||||
40
agent/tools/utils/__init__.py
Normal file
40
agent/tools/utils/__init__.py
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
from .truncate import (
|
||||||
|
truncate_head,
|
||||||
|
truncate_tail,
|
||||||
|
truncate_line,
|
||||||
|
format_size,
|
||||||
|
TruncationResult,
|
||||||
|
DEFAULT_MAX_LINES,
|
||||||
|
DEFAULT_MAX_BYTES,
|
||||||
|
GREP_MAX_LINE_LENGTH
|
||||||
|
)
|
||||||
|
|
||||||
|
from .diff import (
|
||||||
|
strip_bom,
|
||||||
|
detect_line_ending,
|
||||||
|
normalize_to_lf,
|
||||||
|
restore_line_endings,
|
||||||
|
normalize_for_fuzzy_match,
|
||||||
|
fuzzy_find_text,
|
||||||
|
generate_diff_string,
|
||||||
|
FuzzyMatchResult
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
'truncate_head',
|
||||||
|
'truncate_tail',
|
||||||
|
'truncate_line',
|
||||||
|
'format_size',
|
||||||
|
'TruncationResult',
|
||||||
|
'DEFAULT_MAX_LINES',
|
||||||
|
'DEFAULT_MAX_BYTES',
|
||||||
|
'GREP_MAX_LINE_LENGTH',
|
||||||
|
'strip_bom',
|
||||||
|
'detect_line_ending',
|
||||||
|
'normalize_to_lf',
|
||||||
|
'restore_line_endings',
|
||||||
|
'normalize_for_fuzzy_match',
|
||||||
|
'fuzzy_find_text',
|
||||||
|
'generate_diff_string',
|
||||||
|
'FuzzyMatchResult'
|
||||||
|
]
|
||||||
167
agent/tools/utils/diff.py
Normal file
167
agent/tools/utils/diff.py
Normal file
@@ -0,0 +1,167 @@
|
|||||||
|
"""
|
||||||
|
Diff tools for file editing
|
||||||
|
Provides fuzzy matching and diff generation functionality
|
||||||
|
"""
|
||||||
|
|
||||||
|
import difflib
|
||||||
|
import re
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
|
||||||
|
def strip_bom(text: str) -> Tuple[str, str]:
    """
    Split a leading BOM (Byte Order Mark) off the text.

    :param text: Original text
    :return: (BOM, text without the BOM); the BOM element is empty when
        the text does not start with one
    """
    bom = '\ufeff'
    has_bom = text.startswith(bom)
    return (bom, text[len(bom):]) if has_bom else ('', text)
|
||||||
|
|
||||||
|
|
||||||
|
def detect_line_ending(text: str) -> str:
    """
    Report which line ending the text uses.

    :param text: Text content
    :return: CRLF if the text contains at least one CRLF pair, otherwise LF
    """
    crlf = '\r\n'
    return crlf if text.find(crlf) != -1 else '\n'
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_to_lf(text: str) -> str:
    """
    Convert every CRLF pair and every remaining lone CR into LF.

    :param text: Original text
    :return: Text that uses LF as its only line ending
    """
    # CRLF pairs go first so that each pair yields a single LF.
    without_crlf = text.replace('\r\n', '\n')
    return '\n'.join(without_crlf.split('\r'))
|
||||||
|
|
||||||
|
|
||||||
|
def restore_line_endings(text: str, original_ending: str) -> str:
    """
    Re-apply the original line ending to LF-normalized text.

    :param text: LF normalized text
    :param original_ending: Original line ending
    :return: Text with every LF turned into CRLF when the original ending
        was CRLF; otherwise the text unchanged
    """
    if original_ending != '\r\n':
        return text
    return '\r\n'.join(text.split('\n'))
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_for_fuzzy_match(text: str) -> str:
    """
    Normalize whitespace so that texts differing only in spacing compare equal.

    Steps, in order:
      1. every run of spaces/tabs becomes a single space;
      2. spaces directly before a newline are dropped;
      3. per line, each leading whitespace character is replaced by one
         space, and whitespace-only lines become empty.

    NOTE: because step 1 runs first, space/tab indentation is already
    collapsed to at most one space by the time step 3 sees it.

    :param text: Original text
    :return: Normalized text
    """
    collapsed = re.sub(r'[ \t]+', ' ', text)
    collapsed = re.sub(r' +\n', '\n', collapsed)

    def _normalize_line(line: str) -> str:
        body = line.lstrip()
        if not body:
            return ''
        # One space per stripped leading character.
        return ' ' * (len(line) - len(body)) + body

    return '\n'.join(_normalize_line(line) for line in collapsed.split('\n'))
|
||||||
|
|
||||||
|
|
||||||
|
class FuzzyMatchResult:
    """Outcome of searching for a text fragment inside some content."""

    def __init__(self, found: bool, index: int = -1, match_length: int = 0, content_for_replacement: str = ""):
        """
        :param found: Whether the fragment was located
        :param index: Position of the match inside content_for_replacement (-1 by default)
        :param match_length: Length of the matched fragment
        :param content_for_replacement: The content that index and match_length refer to
        """
        self.found, self.index = found, index
        self.match_length, self.content_for_replacement = match_length, content_for_replacement
|
||||||
|
|
||||||
|
|
||||||
|
def fuzzy_find_text(content: str, old_text: str) -> FuzzyMatchResult:
    """
    Locate old_text in content, exactly if possible, otherwise after whitespace normalization.

    On an exact hit the result refers to the original content. On a fuzzy
    hit, both the index and ``content_for_replacement`` refer to the
    normalized content, not the original.

    :param content: Content to search in
    :param old_text: Text to find
    :return: Match result; ``found`` is False when neither search succeeds
    """
    exact_at = content.find(old_text)
    if exact_at != -1:
        return FuzzyMatchResult(
            found=True,
            index=exact_at,
            match_length=len(old_text),
            content_for_replacement=content
        )

    # Second attempt: compare whitespace-normalized versions of both texts.
    normalized_content = normalize_for_fuzzy_match(content)
    normalized_needle = normalize_for_fuzzy_match(old_text)
    fuzzy_at = normalized_content.find(normalized_needle)

    if fuzzy_at == -1:
        return FuzzyMatchResult(found=False)

    return FuzzyMatchResult(
        found=True,
        index=fuzzy_at,
        match_length=len(normalized_needle),
        content_for_replacement=normalized_content
    )
|
||||||
|
|
||||||
|
|
||||||
|
def generate_diff_string(old_content: str, new_content: str) -> dict:
    """
    Build a unified diff between two texts.

    :param old_content: Old content
    :param new_content: New content
    :return: Dictionary with ``diff`` (the unified diff as one string) and
        ``first_changed_line`` (the new-file start line taken from the first
        hunk header, or None when there is no hunk)
    """
    diff_lines = list(difflib.unified_diff(
        old_content.split('\n'),
        new_content.split('\n'),
        fromfile='original',
        tofile='modified',
        lineterm=''
    ))

    # Hunk headers look like "@@ -1,3 +1,3 @@"; capture the "+" start line.
    hunk_header = re.compile(r'@@ -\d+,?\d* \+(\d+)')
    header_matches = (hunk_header.search(line) for line in diff_lines if line.startswith('@@'))
    first_changed_line = next(
        (int(found.group(1)) for found in header_matches if found),
        None
    )

    return {
        'diff': '\n'.join(diff_lines),
        'first_changed_line': first_changed_line
    }
|
||||||
292
agent/tools/utils/truncate.py
Normal file
292
agent/tools/utils/truncate.py
Normal file
@@ -0,0 +1,292 @@
|
|||||||
|
"""
|
||||||
|
Shared truncation utilities for tool outputs.
|
||||||
|
|
||||||
|
Truncation is based on two independent limits - whichever is hit first wins:
|
||||||
|
- Line limit (default: 2000 lines)
|
||||||
|
- Byte limit (default: 50KB)
|
||||||
|
|
||||||
|
Never returns partial lines (except bash tail truncation edge case).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Dict, Any, Optional, Literal
|
||||||
|
|
||||||
|
|
||||||
|
DEFAULT_MAX_LINES = 2000
|
||||||
|
DEFAULT_MAX_BYTES = 50 * 1024 # 50KB
|
||||||
|
GREP_MAX_LINE_LENGTH = 500 # Max chars per grep match line
|
||||||
|
|
||||||
|
|
||||||
|
class TruncationResult:
    """Plain record describing the outcome of a truncation."""

    def __init__(
        self,
        content: str,
        truncated: bool,
        truncated_by: Optional[Literal["lines", "bytes"]],
        total_lines: int,
        total_bytes: int,
        output_lines: int,
        output_bytes: int,
        last_line_partial: bool = False,
        first_line_exceeds_limit: bool = False,
        max_lines: int = DEFAULT_MAX_LINES,
        max_bytes: int = DEFAULT_MAX_BYTES
    ):
        """
        Store every argument unchanged on the instance under the same name.

        :param content: Text that is returned to the caller
        :param truncated: Whether anything was cut
        :param truncated_by: Which limit caused the cut ("lines" or "bytes"), or None
        :param total_lines: Line count of the input
        :param total_bytes: Byte count of the input
        :param output_lines: Line count of the returned content
        :param output_bytes: Byte count of the returned content
        :param last_line_partial: Whether the last returned line is incomplete
        :param first_line_exceeds_limit: Whether the first line alone is over the byte limit
        :param max_lines: Line limit that was applied
        :param max_bytes: Byte limit that was applied
        """
        # Size bookkeeping
        self.content = content
        self.total_lines, self.total_bytes = total_lines, total_bytes
        self.output_lines, self.output_bytes = output_lines, output_bytes
        # Truncation flags
        self.truncated = truncated
        self.truncated_by = truncated_by
        self.last_line_partial = last_line_partial
        self.first_line_exceeds_limit = first_line_exceeds_limit
        # Limits in effect
        self.max_lines, self.max_bytes = max_lines, max_bytes

    def to_dict(self) -> Dict[str, Any]:
        """Return all fields as a dictionary keyed by attribute name."""
        field_order = (
            "content",
            "truncated",
            "truncated_by",
            "total_lines",
            "total_bytes",
            "output_lines",
            "output_bytes",
            "last_line_partial",
            "first_line_exceeds_limit",
            "max_lines",
            "max_bytes",
        )
        return {field: getattr(self, field) for field in field_order}
|
||||||
|
|
||||||
|
|
||||||
|
def format_size(bytes_count: int) -> str:
    """
    Render a byte count as a short human-readable string.

    Counts below 1024 are shown as whole bytes ("B"); larger counts are
    shown with one decimal in "KB" or "MB" using 1024-based units.

    :param bytes_count: Number of bytes
    :return: Formatted size such as "512B", "1.5KB" or "2.0MB"
    """
    kib = 1024
    mib = kib * kib
    if bytes_count >= mib:
        return f"{bytes_count / mib:.1f}MB"
    if bytes_count >= kib:
        return f"{bytes_count / kib:.1f}KB"
    return f"{bytes_count}B"
|
||||||
|
|
||||||
|
|
||||||
|
def truncate_head(content: str, max_lines: Optional[int] = None, max_bytes: Optional[int] = None) -> TruncationResult:
    """
    Keep the beginning of content, bounded by a line limit and a byte limit.

    Only whole lines are kept. When the first line on its own is larger than
    the byte limit, the result carries empty content and
    first_line_exceeds_limit=True.

    :param content: Content to truncate
    :param max_lines: Maximum number of lines (None means DEFAULT_MAX_LINES)
    :param max_bytes: Maximum number of UTF-8 bytes (None means DEFAULT_MAX_BYTES)
    :return: Truncation result
    """
    line_cap = DEFAULT_MAX_LINES if max_lines is None else max_lines
    byte_cap = DEFAULT_MAX_BYTES if max_bytes is None else max_bytes

    all_lines = content.split('\n')
    line_total = len(all_lines)
    byte_total = len(content.encode('utf-8'))

    def _result(text, truncated, reason, kept_lines, kept_bytes, first_too_long=False):
        # Every exit path reports the same totals and limits; only the kept part differs.
        return TruncationResult(
            content=text,
            truncated=truncated,
            truncated_by=reason,
            total_lines=line_total,
            total_bytes=byte_total,
            output_lines=kept_lines,
            output_bytes=kept_bytes,
            last_line_partial=False,
            first_line_exceeds_limit=first_too_long,
            max_lines=line_cap,
            max_bytes=byte_cap
        )

    # Everything fits: hand the content back untouched
    if line_total <= line_cap and byte_total <= byte_cap:
        return _result(content, False, None, line_total, byte_total)

    # The very first line is already over the byte budget: keep nothing
    if len(all_lines[0].encode('utf-8')) > byte_cap:
        return _result("", True, "bytes", 0, 0, first_too_long=True)

    kept = []
    used_bytes = 0
    reason = "lines"

    for index, text in enumerate(all_lines):
        if index >= line_cap:
            break

        # Every line after the first costs one extra byte for its joining newline
        cost = len(text.encode('utf-8'))
        if index > 0:
            cost += 1

        if used_bytes + cost > byte_cap:
            reason = "bytes"
            break

        kept.append(text)
        used_bytes += cost

    head = '\n'.join(kept)
    return _result(head, True, reason, len(kept), len(head.encode('utf-8')))
|
||||||
|
|
||||||
|
|
||||||
|
def truncate_tail(content: str, max_lines: Optional[int] = None, max_bytes: Optional[int] = None) -> TruncationResult:
    """
    Truncate content from tail (keep last N lines/bytes).
    Suitable for bash output where you want to see the ending content (errors, final results).

    Whole lines are kept where possible. If the last line of the original
    content alone exceeds the byte limit, only the trailing portion of that
    line is returned, with last_line_partial=True and truncated_by="bytes".

    :param content: Content to truncate
    :param max_lines: Maximum lines (None means DEFAULT_MAX_LINES)
    :param max_bytes: Maximum UTF-8 bytes (None means DEFAULT_MAX_BYTES)
    :return: Truncation result
    """
    if max_lines is None:
        max_lines = DEFAULT_MAX_LINES
    if max_bytes is None:
        max_bytes = DEFAULT_MAX_BYTES

    total_bytes = len(content.encode('utf-8'))
    lines = content.split('\n')
    total_lines = len(lines)

    # Check if no truncation is needed
    if total_lines <= max_lines and total_bytes <= max_bytes:
        return TruncationResult(
            content=content,
            truncated=False,
            truncated_by=None,
            total_lines=total_lines,
            total_bytes=total_bytes,
            output_lines=total_lines,
            output_bytes=total_bytes,
            last_line_partial=False,
            first_line_exceeds_limit=False,
            max_lines=max_lines,
            max_bytes=max_bytes
        )

    # Work backwards from the end. Lines are collected newest-first and
    # reversed once afterwards, which avoids a quadratic insert(0, ...) loop.
    kept_reversed = []
    output_bytes_count = 0
    # "lines" is the verdict unless the byte budget is what stops the loop.
    truncated_by = "lines"
    last_line_partial = False

    for line in reversed(lines):
        if len(kept_reversed) >= max_lines:
            break

        # Every line except the first one collected costs an extra newline byte
        line_bytes = len(line.encode('utf-8')) + (1 if kept_reversed else 0)

        if output_bytes_count + line_bytes > max_bytes:
            truncated_by = "bytes"
            # Edge case: nothing collected yet and this single line is over
            # the budget, so keep only the end portion of it.
            if not kept_reversed:
                truncated_line = _truncate_string_to_bytes_from_end(line, max_bytes)
                kept_reversed.append(truncated_line)
                output_bytes_count = len(truncated_line.encode('utf-8'))
                last_line_partial = True
            break

        kept_reversed.append(line)
        output_bytes_count += line_bytes

    # No post-loop override of truncated_by: a byte-limited partial line with
    # max_lines == 1 must stay reported as "bytes", not be relabelled "lines".
    kept_reversed.reverse()
    output_content = '\n'.join(kept_reversed)
    final_output_bytes = len(output_content.encode('utf-8'))

    return TruncationResult(
        content=output_content,
        truncated=True,
        truncated_by=truncated_by,
        total_lines=total_lines,
        total_bytes=total_bytes,
        output_lines=len(kept_reversed),
        output_bytes=final_output_bytes,
        last_line_partial=last_line_partial,
        first_line_exceeds_limit=False,
        max_lines=max_lines,
        max_bytes=max_bytes
    )
|
||||||
|
|
||||||
|
|
||||||
|
def _truncate_string_to_bytes_from_end(text: str, max_bytes: int) -> str:
|
||||||
|
"""
|
||||||
|
Truncate string to fit byte limit (from end).
|
||||||
|
Properly handles multi-byte UTF-8 characters.
|
||||||
|
|
||||||
|
:param text: String to truncate
|
||||||
|
:param max_bytes: Maximum bytes
|
||||||
|
:return: Truncated string
|
||||||
|
"""
|
||||||
|
encoded = text.encode('utf-8')
|
||||||
|
if len(encoded) <= max_bytes:
|
||||||
|
return text
|
||||||
|
|
||||||
|
# Start from end, skip back maxBytes
|
||||||
|
start = len(encoded) - max_bytes
|
||||||
|
|
||||||
|
# Find valid UTF-8 boundary (character start)
|
||||||
|
while start < len(encoded) and (encoded[start] & 0xC0) == 0x80:
|
||||||
|
start += 1
|
||||||
|
|
||||||
|
return encoded[start:].decode('utf-8', errors='ignore')
|
||||||
|
|
||||||
|
|
||||||
|
def truncate_line(line: str, max_chars: int = GREP_MAX_LINE_LENGTH) -> tuple[str, bool]:
    """
    Cap a single line at max_chars characters.

    A line that is too long is cut to its first max_chars characters and
    gets a "... [truncated]" suffix. Used for grep match lines.

    :param line: Line to truncate
    :param max_chars: Maximum characters
    :return: (resulting text, whether it was truncated)
    """
    was_cut = len(line) > max_chars
    if was_cut:
        return f"{line[:max_chars]}... [truncated]", True
    return line, False
|
||||||
3
agent/tools/write/__init__.py
Normal file
3
agent/tools/write/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
from .write import Write
|
||||||
|
|
||||||
|
__all__ = ['Write']
|
||||||
91
agent/tools/write/write.py
Normal file
91
agent/tools/write/write.py
Normal file
@@ -0,0 +1,91 @@
|
|||||||
|
"""
|
||||||
|
Write tool - Write file content
|
||||||
|
Creates or overwrites files, automatically creates parent directories
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from typing import Dict, Any
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from agent.tools.base_tool import BaseTool, ToolResult
|
||||||
|
|
||||||
|
|
||||||
|
class Write(BaseTool):
    """Tool for writing file content"""

    name: str = "write"
    description: str = "Write content to a file. Creates the file if it doesn't exist, overwrites if it does. Automatically creates parent directories."

    params: dict = {
        "type": "object",
        "properties": {
            "path": {
                "type": "string",
                "description": "Path to the file to write (relative or absolute)"
            },
            "content": {
                "type": "string",
                "description": "Content to write to the file"
            }
        },
        "required": ["path", "content"]
    }

    def __init__(self, config: dict = None):
        """
        :param config: Optional settings; "cwd" sets the base directory for
            relative paths (defaults to the process working directory)
        """
        self.config = config or {}
        self.cwd = self.config.get("cwd", os.getcwd())

    def execute(self, args: Dict[str, Any]) -> ToolResult:
        """
        Execute file write operation

        Both arguments are validated before the filesystem is touched, so a
        malformed call can never truncate an existing file.

        :param args: Contains file path ("path") and content ("content")
        :return: Operation result; on success it carries the message, the
            path as given, and the number of UTF-8 bytes written
        """
        raw_path = args.get("path", "")
        content = args.get("content", "")

        # A missing, empty, or non-string path is reported the same way
        if not isinstance(raw_path, str) or not raw_path.strip():
            return ToolResult.fail("Error: path parameter is required")
        path = raw_path.strip()

        # Reject bad content up front: open(..., 'w') truncates the target
        # immediately, so failing inside f.write() would leave an empty file.
        if not isinstance(content, str):
            return ToolResult.fail("Error: content parameter must be a string")

        # Resolve path
        absolute_path = self._resolve_path(path)

        try:
            # Create parent directory (if needed)
            parent_dir = os.path.dirname(absolute_path)
            if parent_dir:
                os.makedirs(parent_dir, exist_ok=True)

            # Write file
            with open(absolute_path, 'w', encoding='utf-8') as f:
                f.write(content)

            # Get bytes written (size of the UTF-8 encoding of the content)
            bytes_written = len(content.encode('utf-8'))

            result = {
                "message": f"Successfully wrote {bytes_written} bytes to {path}",
                "path": path,
                "bytes_written": bytes_written
            }

            return ToolResult.success(result)

        except PermissionError:
            return ToolResult.fail(f"Error: Permission denied writing to {path}")
        except Exception as e:
            return ToolResult.fail(f"Error writing file: {str(e)}")

    def _resolve_path(self, path: str) -> str:
        """
        Resolve path to absolute path

        :param path: Relative or absolute path; a leading ~ is expanded
        :return: Absolute path (relative paths are resolved against self.cwd)
        """
        # Expand ~ to user home directory
        path = os.path.expanduser(path)
        if os.path.isabs(path):
            return path
        return os.path.abspath(os.path.join(self.cwd, path))
|
||||||
@@ -1,222 +0,0 @@
|
|||||||
import re
|
|
||||||
import time
|
|
||||||
import json
|
|
||||||
import uuid
|
|
||||||
from curl_cffi import requests
|
|
||||||
from bot.bot import Bot
|
|
||||||
from bot.claude.claude_ai_session import ClaudeAiSession
|
|
||||||
from bot.openai.open_ai_image import OpenAIImage
|
|
||||||
from bot.session_manager import SessionManager
|
|
||||||
from bridge.context import Context, ContextType
|
|
||||||
from bridge.reply import Reply, ReplyType
|
|
||||||
from common.log import logger
|
|
||||||
from config import conf
|
|
||||||
|
|
||||||
|
|
||||||
class ClaudeAIBot(Bot, OpenAIImage):
    """Bot that talks to the claude.ai web endpoints, authenticated by a browser cookie."""

    def __init__(self):
        super().__init__()
        self.sessions = SessionManager(ClaudeAiSession, model=conf().get("model") or "gpt-3.5-turbo")
        self.claude_api_cookie = conf().get("claude_api_cookie")
        self.proxy = conf().get("proxy")
        # Maps session_id -> conversation uuid created for that session
        self.con_uuid_dic = {}
        if self.proxy:
            self.proxies = {
                "http": self.proxy,
                "https": self.proxy
            }
        else:
            self.proxies = None
        # Last user-facing error message; returned by _chat when no organization id is available
        self.error = ""
        self.org_uuid = self.get_organization_id()

    def generate_uuid(self):
        """Return a new random UUID in its canonical 8-4-4-4-12 string form."""
        # str(uuid4()) is already in canonical form; no re-slicing needed
        return str(uuid.uuid4())

    def reply(self, query, context: Context = None) -> Reply:
        """
        Dispatch a request by context type.

        :param query: User input
        :param context: Request context; its type selects text chat or image creation
        :return: Reply for the request, or an error reply for unsupported types
        """
        if context.type == ContextType.TEXT:
            return self._chat(query, context)
        elif context.type == ContextType.IMAGE_CREATE:
            ok, res = self.create_img(query, 0)
            if ok:
                reply = Reply(ReplyType.IMAGE_URL, res)
            else:
                reply = Reply(ReplyType.ERROR, res)
            return reply
        else:
            reply = Reply(ReplyType.ERROR, "Bot不支持处理{}类型的消息".format(context.type))
            return reply

    def get_organization_id(self):
        """
        Fetch the uuid of the first organization visible to the configured cookie.

        :return: Organization uuid, or None on failure (self.error is set when
            the cause can be identified)
        """
        url = "https://claude.ai/api/organizations"
        headers = {
            'User-Agent':
                'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:109.0) Gecko/20100101 Firefox/115.0',
            'Accept-Language': 'en-US,en;q=0.5',
            'Referer': 'https://claude.ai/chats',
            'Content-Type': 'application/json',
            'Sec-Fetch-Dest': 'empty',
            'Sec-Fetch-Mode': 'cors',
            'Sec-Fetch-Site': 'same-origin',
            'Connection': 'keep-alive',
            'Cookie': f'{self.claude_api_cookie}'
        }
        # Bound before the try so the handler can tell "request failed" from
        # "request succeeded but the body was not what we expected".
        response = None
        try:
            response = requests.get(url, headers=headers, impersonate="chrome110", proxies=self.proxies, timeout=400)
            res = json.loads(response.text)
            org_uuid = res[0]['uuid']
        except Exception as e:
            body = response.text if response is not None else ""
            if "App unavailable" in body:
                logger.error("IP error: The IP is not allowed to be used on Claude")
                self.error = "ip所在地区不被claude支持"
            elif "Invalid authorization" in body:
                logger.error("Cookie error: Invalid authorization of claude, check cookie please.")
                self.error = "无法通过claude身份验证,请检查cookie"
            else:
                logger.error(f"[CLAUDEAI] failed to get organization id: {e}")
                if not self.error:
                    self.error = str(e)
            return None
        return org_uuid

    def conversation_share_check(self, session_id):
        """
        Return the conversation uuid to use for a session.

        A non-empty "claude_uuid" config value is used for every session;
        otherwise a conversation is created once per session and cached.

        :param session_id: Session identifier
        :return: Conversation uuid
        """
        if conf().get("claude_uuid") is not None and conf().get("claude_uuid") != "":
            con_uuid = conf().get("claude_uuid")
            return con_uuid
        if session_id not in self.con_uuid_dic:
            self.con_uuid_dic[session_id] = self.generate_uuid()
            self.create_new_chat(self.con_uuid_dic[session_id])
        return self.con_uuid_dic[session_id]

    def check_cookie(self):
        """Re-query the organization id; None means the cookie/IP check failed."""
        flag = self.get_organization_id()
        return flag

    def create_new_chat(self, con_uuid):
        """
        Create a new Claude conversation entity

        :param con_uuid: Conversation id
        :return: JSON of the newly created conversation information
        """
        url = f"https://claude.ai/api/organizations/{self.org_uuid}/chat_conversations"
        payload = json.dumps({"uuid": con_uuid, "name": ""})
        headers = {
            'User-Agent':
                'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:109.0) Gecko/20100101 Firefox/115.0',
            'Accept-Language': 'en-US,en;q=0.5',
            'Referer': 'https://claude.ai/chats',
            'Content-Type': 'application/json',
            'Origin': 'https://claude.ai',
            'DNT': '1',
            'Connection': 'keep-alive',
            'Cookie': self.claude_api_cookie,
            'Sec-Fetch-Dest': 'empty',
            'Sec-Fetch-Mode': 'cors',
            'Sec-Fetch-Site': 'same-origin',
            'TE': 'trailers'
        }
        response = requests.post(url, headers=headers, data=payload, impersonate="chrome110", proxies=self.proxies, timeout=400)
        # Returns JSON of the newly created conversation information
        return response.json()

    def _chat(self, query, context, retry_count=0) -> Reply:
        """
        Send a chat request

        :param query: Prompt text
        :param context: Conversation context
        :param retry_count: Current recursive retry count
        :return: Reply
        """
        if retry_count >= 2:
            # exit from retry 2 times
            logger.warn("[CLAUDEAI] failed after maximum number of retry times")
            return Reply(ReplyType.ERROR, "请再问我一次吧")

        try:
            session_id = context["session_id"]
            if self.org_uuid is None:
                return Reply(ReplyType.ERROR, self.error)

            session = self.sessions.session_query(query, session_id)
            con_uuid = self.conversation_share_check(session_id)

            model = conf().get("model") or "gpt-3.5-turbo"
            # remove system message
            if session.messages[0].get("role") == "system":
                if model == "wenxin" or model == "claude":
                    session.messages.pop(0)
            logger.info(f"[CLAUDEAI] query={query}")

            # do http request
            base_url = "https://claude.ai"
            payload = json.dumps({
                "completion": {
                    "prompt": f"{query}",
                    "timezone": "Asia/Kolkata",
                    "model": "claude-2"
                },
                "organization_uuid": f"{self.org_uuid}",
                "conversation_uuid": f"{con_uuid}",
                "text": f"{query}",
                "attachments": []
            })
            headers = {
                'User-Agent':
                    'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:109.0) Gecko/20100101 Firefox/115.0',
                'Accept': 'text/event-stream, text/event-stream',
                'Accept-Language': 'en-US,en;q=0.5',
                'Referer': 'https://claude.ai/chats',
                'Content-Type': 'application/json',
                'Origin': 'https://claude.ai',
                'DNT': '1',
                'Connection': 'keep-alive',
                'Cookie': f'{self.claude_api_cookie}',
                'Sec-Fetch-Dest': 'empty',
                'Sec-Fetch-Mode': 'cors',
                'Sec-Fetch-Site': 'same-origin',
                'TE': 'trailers'
            }

            res = requests.post(base_url + "/api/append_message", headers=headers, data=payload, impersonate="chrome110", proxies=self.proxies, timeout=400)
            # NOTE(review): "pemission" looks like a typo for "permission"; left
            # unchanged because the intended match target cannot be confirmed here.
            if res.status_code == 200 or "pemission" in res.text:
                # execute success: the body is a sequence of event lines, each
                # a 6-character prefix followed by a JSON object
                decoded_data = res.content.decode("utf-8")
                decoded_data = re.sub('\n+', '\n', decoded_data).strip()
                data_strings = decoded_data.split('\n')
                completions = []
                for data_string in data_strings:
                    json_str = data_string[6:].strip()
                    data = json.loads(json_str)
                    if 'completion' in data:
                        completions.append(data['completion'])

                reply_content = ''.join(completions)

                if "rate limi" in reply_content:
                    logger.error("rate limit error: The conversation has reached the system speed limit and is synchronized with Cladue. Please go to the official website to check the lifting time")
                    return Reply(ReplyType.ERROR, "对话达到系统速率限制,与cladue同步,请进入官网查看解除限制时间")
                logger.info(f"[CLAUDE] reply={reply_content}, total_tokens=invisible")
                self.sessions.session_reply(reply_content, session_id, 100)
                return Reply(ReplyType.TEXT, reply_content)
            else:
                flag = self.check_cookie()
                if flag is None:
                    return Reply(ReplyType.ERROR, self.error)

                response = res.json()
                # Tolerate a body without an "error" object so the failure is
                # still logged instead of raising on None.get
                error = response.get("error") or {}
                logger.error(f"[CLAUDE] chat failed, status_code={res.status_code}, "
                             f"msg={error.get('message')}, type={error.get('type')}, detail: {res.text}, uuid: {con_uuid}")

                if res.status_code >= 500:
                    # server error, need retry
                    time.sleep(2)
                    logger.warn(f"[CLAUDE] do retry, times={retry_count}")
                    return self._chat(query, context, retry_count + 1)
                return Reply(ReplyType.ERROR, "提问太快啦,请休息一下再问我吧")

        except Exception as e:
            logger.exception(e)
            # retry
            time.sleep(2)
            logger.warn(f"[CLAUDE] do retry, times={retry_count}")
            return self._chat(query, context, retry_count + 1)
|
|
||||||
@@ -1,9 +0,0 @@
|
|||||||
from bot.session_manager import Session
|
|
||||||
|
|
||||||
|
|
||||||
class ClaudeAiSession(Session):
    """Session that records the model name on top of the base Session state."""

    def __init__(self, session_id, system_prompt=None, model="claude"):
        super().__init__(session_id, system_prompt)
        self.model = model
        # The reverse-engineered Claude endpoint does not support a role prompt
        # self.reset()
|
|
||||||
@@ -1,19 +1,18 @@
|
|||||||
# encoding:utf-8
|
# encoding:utf-8
|
||||||
|
|
||||||
|
import json
|
||||||
import time
|
import time
|
||||||
|
|
||||||
import openai
|
import requests
|
||||||
import openai.error
|
|
||||||
import anthropic
|
|
||||||
|
|
||||||
|
from bot.baidu.baidu_wenxin_session import BaiduWenxinSession
|
||||||
from bot.bot import Bot
|
from bot.bot import Bot
|
||||||
from bot.openai.open_ai_image import OpenAIImage
|
from bot.openai.open_ai_image import OpenAIImage
|
||||||
from bot.baidu.baidu_wenxin_session import BaiduWenxinSession
|
|
||||||
from bot.session_manager import SessionManager
|
from bot.session_manager import SessionManager
|
||||||
from bridge.context import ContextType
|
from bridge.context import ContextType
|
||||||
from bridge.reply import Reply, ReplyType
|
from bridge.reply import Reply, ReplyType
|
||||||
from common.log import logger
|
|
||||||
from common import const
|
from common import const
|
||||||
|
from common.log import logger
|
||||||
from config import conf
|
from config import conf
|
||||||
|
|
||||||
user_session = dict()
|
user_session = dict()
|
||||||
@@ -23,13 +22,9 @@ user_session = dict()
|
|||||||
class ClaudeAPIBot(Bot, OpenAIImage):
|
class ClaudeAPIBot(Bot, OpenAIImage):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
proxy = conf().get("proxy", None)
|
self.api_key = conf().get("claude_api_key")
|
||||||
base_url = conf().get("open_ai_api_base", None) # 复用"open_ai_api_base"参数作为base_url
|
self.api_base = conf().get("open_ai_api_base") or "https://api.anthropic.com/v1"
|
||||||
self.claudeClient = anthropic.Anthropic(
|
self.proxy = conf().get("proxy", None)
|
||||||
api_key=conf().get("claude_api_key"),
|
|
||||||
proxies=proxy if proxy else None,
|
|
||||||
base_url=base_url if base_url else None
|
|
||||||
)
|
|
||||||
self.sessions = SessionManager(BaiduWenxinSession, model=conf().get("model") or "text-davinci-003")
|
self.sessions = SessionManager(BaiduWenxinSession, model=conf().get("model") or "text-davinci-003")
|
||||||
|
|
||||||
def reply(self, query, context=None):
|
def reply(self, query, context=None):
|
||||||
@@ -73,39 +68,104 @@ class ClaudeAPIBot(Bot, OpenAIImage):
|
|||||||
reply = Reply(ReplyType.ERROR, retstring)
|
reply = Reply(ReplyType.ERROR, retstring)
|
||||||
return reply
|
return reply
|
||||||
|
|
||||||
def reply_text(self, session: BaiduWenxinSession, retry_count=0):
|
def reply_text(self, session: BaiduWenxinSession, retry_count=0, tools=None):
|
||||||
try:
|
try:
|
||||||
actual_model = self._model_mapping(conf().get("model"))
|
actual_model = self._model_mapping(conf().get("model"))
|
||||||
response = self.claudeClient.messages.create(
|
|
||||||
model=actual_model,
|
# Prepare headers
|
||||||
max_tokens=4096,
|
headers = {
|
||||||
system=conf().get("character_desc", ""),
|
"x-api-key": self.api_key,
|
||||||
messages=session.messages
|
"anthropic-version": "2023-06-01",
|
||||||
|
"content-type": "application/json"
|
||||||
|
}
|
||||||
|
|
||||||
|
# Extract system prompt if present and prepare Claude-compatible messages
|
||||||
|
system_prompt = conf().get("character_desc", "")
|
||||||
|
claude_messages = []
|
||||||
|
|
||||||
|
for msg in session.messages:
|
||||||
|
if msg.get("role") == "system":
|
||||||
|
system_prompt = msg["content"]
|
||||||
|
else:
|
||||||
|
claude_messages.append(msg)
|
||||||
|
|
||||||
|
# Prepare request data
|
||||||
|
data = {
|
||||||
|
"model": actual_model,
|
||||||
|
"messages": claude_messages,
|
||||||
|
"max_tokens": self._get_max_tokens(actual_model)
|
||||||
|
}
|
||||||
|
|
||||||
|
if system_prompt:
|
||||||
|
data["system"] = system_prompt
|
||||||
|
|
||||||
|
if tools:
|
||||||
|
data["tools"] = tools
|
||||||
|
|
||||||
|
# Make HTTP request
|
||||||
|
proxies = {"http": self.proxy, "https": self.proxy} if self.proxy else None
|
||||||
|
response = requests.post(
|
||||||
|
f"{self.api_base}/messages",
|
||||||
|
headers=headers,
|
||||||
|
json=data,
|
||||||
|
proxies=proxies
|
||||||
)
|
)
|
||||||
# response = openai.Completion.create(prompt=str(session), **self.args)
|
|
||||||
res_content = response.content[0].text.strip().replace("<|endoftext|>", "")
|
if response.status_code != 200:
|
||||||
total_tokens = response.usage.input_tokens+response.usage.output_tokens
|
raise Exception(f"API request failed: {response.status_code} - {response.text}")
|
||||||
completion_tokens = response.usage.output_tokens
|
|
||||||
|
claude_response = response.json()
|
||||||
|
# Handle response content and tool calls
|
||||||
|
res_content = ""
|
||||||
|
tool_calls = []
|
||||||
|
|
||||||
|
content_blocks = claude_response.get("content", [])
|
||||||
|
for block in content_blocks:
|
||||||
|
if block.get("type") == "text":
|
||||||
|
res_content += block.get("text", "")
|
||||||
|
elif block.get("type") == "tool_use":
|
||||||
|
tool_calls.append({
|
||||||
|
"id": block.get("id", ""),
|
||||||
|
"name": block.get("name", ""),
|
||||||
|
"arguments": block.get("input", {})
|
||||||
|
})
|
||||||
|
|
||||||
|
res_content = res_content.strip().replace("<|endoftext|>", "")
|
||||||
|
usage = claude_response.get("usage", {})
|
||||||
|
total_tokens = usage.get("input_tokens", 0) + usage.get("output_tokens", 0)
|
||||||
|
completion_tokens = usage.get("output_tokens", 0)
|
||||||
|
|
||||||
logger.info("[CLAUDE_API] reply={}".format(res_content))
|
logger.info("[CLAUDE_API] reply={}".format(res_content))
|
||||||
return {
|
if tool_calls:
|
||||||
|
logger.info("[CLAUDE_API] tool_calls={}".format(tool_calls))
|
||||||
|
|
||||||
|
result = {
|
||||||
"total_tokens": total_tokens,
|
"total_tokens": total_tokens,
|
||||||
"completion_tokens": completion_tokens,
|
"completion_tokens": completion_tokens,
|
||||||
"content": res_content,
|
"content": res_content,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if tool_calls:
|
||||||
|
result["tool_calls"] = tool_calls
|
||||||
|
|
||||||
|
return result
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
need_retry = retry_count < 2
|
need_retry = retry_count < 2
|
||||||
result = {"total_tokens": 0, "completion_tokens": 0, "content": "我现在有点累了,等会再来吧"}
|
result = {"total_tokens": 0, "completion_tokens": 0, "content": "我现在有点累了,等会再来吧"}
|
||||||
if isinstance(e, openai.error.RateLimitError):
|
|
||||||
|
# Handle different types of errors
|
||||||
|
error_str = str(e).lower()
|
||||||
|
if "rate" in error_str or "limit" in error_str:
|
||||||
logger.warn("[CLAUDE_API] RateLimitError: {}".format(e))
|
logger.warn("[CLAUDE_API] RateLimitError: {}".format(e))
|
||||||
result["content"] = "提问太快啦,请休息一下再问我吧"
|
result["content"] = "提问太快啦,请休息一下再问我吧"
|
||||||
if need_retry:
|
if need_retry:
|
||||||
time.sleep(20)
|
time.sleep(20)
|
||||||
elif isinstance(e, openai.error.Timeout):
|
elif "timeout" in error_str:
|
||||||
logger.warn("[CLAUDE_API] Timeout: {}".format(e))
|
logger.warn("[CLAUDE_API] Timeout: {}".format(e))
|
||||||
result["content"] = "我没有收到你的消息"
|
result["content"] = "我没有收到你的消息"
|
||||||
if need_retry:
|
if need_retry:
|
||||||
time.sleep(5)
|
time.sleep(5)
|
||||||
elif isinstance(e, openai.error.APIConnectionError):
|
elif "connection" in error_str or "network" in error_str:
|
||||||
logger.warn("[CLAUDE_API] APIConnectionError: {}".format(e))
|
logger.warn("[CLAUDE_API] APIConnectionError: {}".format(e))
|
||||||
need_retry = False
|
need_retry = False
|
||||||
result["content"] = "我连接不到你的网络"
|
result["content"] = "我连接不到你的网络"
|
||||||
@@ -116,7 +176,7 @@ class ClaudeAPIBot(Bot, OpenAIImage):
|
|||||||
|
|
||||||
if need_retry:
|
if need_retry:
|
||||||
logger.warn("[CLAUDE_API] 第{}次重试".format(retry_count + 1))
|
logger.warn("[CLAUDE_API] 第{}次重试".format(retry_count + 1))
|
||||||
return self.reply_text(session, retry_count + 1)
|
return self.reply_text(session, retry_count + 1, tools)
|
||||||
else:
|
else:
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@@ -130,3 +190,288 @@ class ClaudeAPIBot(Bot, OpenAIImage):
|
|||||||
elif model == "claude-3.5-sonnet":
|
elif model == "claude-3.5-sonnet":
|
||||||
return const.CLAUDE_35_SONNET
|
return const.CLAUDE_35_SONNET
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
def _get_max_tokens(self, model: str) -> int:
|
||||||
|
"""
|
||||||
|
Get max_tokens for the model.
|
||||||
|
Reference from pi-mono:
|
||||||
|
- Claude 3.5/3.7: 8192
|
||||||
|
- Claude 3 Opus: 4096
|
||||||
|
- Default: 8192
|
||||||
|
"""
|
||||||
|
if model and (model.startswith("claude-3-5") or model.startswith("claude-3-7")):
|
||||||
|
return 8192
|
||||||
|
elif model and model.startswith("claude-3") and "opus" in model:
|
||||||
|
return 4096
|
||||||
|
elif model and (model.startswith("claude-sonnet-4") or model.startswith("claude-opus-4")):
|
||||||
|
return 64000
|
||||||
|
return 8192
|
||||||
|
|
||||||
|
def call_with_tools(self, messages, tools=None, stream=False, **kwargs):
    """
    Call Claude API with tool support for agent integration

    A message with role "system" is lifted out of the list and sent as the
    top-level system prompt; if there are several, the last one wins.

    Args:
        messages: List of messages
        tools: List of tool definitions
        stream: Whether to use streaming
        **kwargs: Additional parameters; "system" and "max_tokens" are honoured

    Returns:
        Formatted response compatible with OpenAI format or generator for streaming.
        On failure: an error dict (sync) or a generator yielding one error dict (stream).
    """
    actual_model = self._model_mapping(conf().get("model"))

    # Extract system prompt from messages if present
    system_prompt = kwargs.get("system", conf().get("character_desc", ""))
    claude_messages = []

    for msg in messages:
        if msg.get("role") == "system":
            system_prompt = msg["content"]
        else:
            claude_messages.append(msg)

    request_params = {
        "model": actual_model,
        "max_tokens": kwargs.get("max_tokens", self._get_max_tokens(actual_model)),
        "messages": claude_messages,
        "stream": stream
    }

    if system_prompt:
        request_params["system"] = system_prompt

    if tools:
        request_params["tools"] = tools

    try:
        if stream:
            return self._handle_stream_response(request_params)
        else:
            return self._handle_sync_response(request_params)
    except Exception as e:
        logger.error(f"Claude API call error: {e}")
        # Build the payload now: the name bound by "except ... as e" is
        # deleted when this clause exits, so a generator that referenced e
        # lazily would raise NameError when the caller finally iterates it.
        error_payload = {
            "error": True,
            "message": str(e),
            "status_code": 500
        }
        if stream:
            # Return error generator for stream
            def error_generator():
                yield error_payload

            return error_generator()
        else:
            # Return error response for sync
            return error_payload
|
||||||
|
|
||||||
|
def _handle_sync_response(self, request_params):
    """
    Send one non-streaming request to the Claude messages endpoint and
    convert the reply into an OpenAI-style chat completion dict.

    Args:
        request_params: JSON body to POST to ``{api_base}/messages``

    Returns:
        Dict with ``id``, ``object``, ``created``, ``model``, ``choices``
        and ``usage``. Text blocks are concatenated into the message
        content; ``tool_use`` blocks become ``tool_calls`` entries.

    Raises:
        Exception: if the HTTP status code is not 200
    """
    request_headers = {
        "x-api-key": self.api_key,
        "anthropic-version": "2023-06-01",
        "content-type": "application/json"
    }

    # Route both schemes through the configured proxy, if there is one
    proxy_map = None
    if self.proxy:
        proxy_map = {"http": self.proxy, "https": self.proxy}

    http_response = requests.post(
        f"{self.api_base}/messages",
        headers=request_headers,
        json=request_params,
        proxies=proxy_map
    )

    if http_response.status_code != 200:
        raise Exception(f"API request failed: {http_response.status_code} - {http_response.text}")

    payload = http_response.json()

    # Walk the content blocks once, separating text from tool invocations
    text_parts = ""
    converted_calls = []
    for item in payload.get("content", []):
        kind = item.get("type")
        if kind == "text":
            text_parts += item.get("text", "")
            continue
        if kind == "tool_use":
            function_spec = {
                "name": item.get("name", ""),
                # The tool input is serialised to a JSON string
                "arguments": json.dumps(item.get("input", {}))
            }
            converted_calls.append({
                "id": item.get("id", ""),
                "type": "function",
                "function": function_spec
            })

    assistant_message = {"role": "assistant", "content": text_parts}
    if converted_calls:
        assistant_message["tool_calls"] = converted_calls

    token_usage = payload.get("usage", {})
    prompt_count = token_usage.get("input_tokens", 0)
    completion_count = token_usage.get("output_tokens", 0)

    only_choice = {
        "index": 0,
        "message": assistant_message,
        # Claude's stop_reason is passed through unchanged
        "finish_reason": payload.get("stop_reason", "stop")
    }

    return {
        "id": payload.get("id", ""),
        "object": "chat.completion",
        "created": int(time.time()),
        "model": payload.get("model", request_params["model"]),
        "choices": [only_choice],
        "usage": {
            "prompt_tokens": prompt_count,
            "completion_tokens": completion_count,
            "total_tokens": prompt_count + completion_count
        }
    }
|
||||||
|
|
||||||
|
def _handle_stream_response(self, request_params):
    """
    Handle streaming Claude API response using HTTP requests

    This is a generator: the HTTP request is only sent once the caller
    starts iterating.

    Args:
        request_params: JSON body to POST to ``{api_base}/messages``.
            Its ``stream`` key is set to True in place.

    Yields:
        - for every ``text_delta``: a ``chat.completion.chunk`` dict whose
          delta carries the text under ``content``;
        - when a ``message_delta`` event arrives: one chunk per accumulated
          tool call, with the raw concatenated JSON text as ``arguments``;
        - on failure: a single dict with ``error``, ``message`` and
          ``status_code`` (the HTTP status for non-200 replies, 0 for
          ``requests`` errors, 500 for anything else).
    """
    # Prepare headers
    headers = {
        "x-api-key": self.api_key,
        "anthropic-version": "2023-06-01",
        "content-type": "application/json"
    }

    # Add stream parameter
    request_params["stream"] = True

    # Track tool use state
    tool_uses_map = {}  # {index: {id, name, input}}
    current_tool_use_index = -1

    # Kept outside the try so the finally clause can always release it
    response = None

    try:
        # Make streaming HTTP request
        proxies = {"http": self.proxy, "https": self.proxy} if self.proxy else None
        response = requests.post(
            f"{self.api_base}/messages",
            headers=headers,
            json=request_params,
            proxies=proxies,
            stream=True
        )

        if response.status_code != 200:
            error_text = response.text
            try:
                error_data = json.loads(error_text)
                error_msg = error_data.get("error", {}).get("message", error_text)
            except (ValueError, AttributeError):
                # Body is not JSON, or not shaped like {"error": {...}}
                error_msg = error_text or "Unknown error"

            yield {
                "error": True,
                "status_code": response.status_code,
                "message": error_msg
            }
            return

        # Process streaming response
        for line in response.iter_lines():
            if not line:
                continue

            line = line.decode('utf-8')
            if not line.startswith('data: '):
                continue

            line = line[6:]  # Remove 'data: ' prefix
            if line == '[DONE]':
                break

            try:
                event = json.loads(line)
            except json.JSONDecodeError:
                # Skip payloads that are not valid JSON
                continue

            event_type = event.get("type")

            if event_type == "content_block_start":
                # New content block
                block = event.get("content_block", {})
                if block.get("type") == "tool_use":
                    current_tool_use_index = event.get("index", 0)
                    tool_uses_map[current_tool_use_index] = {
                        "id": block.get("id", ""),
                        "name": block.get("name", ""),
                        "input": ""
                    }

            elif event_type == "content_block_delta":
                delta = event.get("delta", {})
                delta_type = delta.get("type")

                if delta_type == "text_delta":
                    # Text content
                    content = delta.get("text", "")
                    yield {
                        "id": event.get("id", ""),
                        "object": "chat.completion.chunk",
                        "created": int(time.time()),
                        "model": request_params["model"],
                        "choices": [{
                            "index": 0,
                            "delta": {"content": content},
                            "finish_reason": None
                        }]
                    }

                elif delta_type == "input_json_delta":
                    # Tool input accumulation: fragments are appended to
                    # the most recently started tool_use block
                    if current_tool_use_index >= 0:
                        tool_uses_map[current_tool_use_index]["input"] += delta.get("partial_json", "")

            elif event_type == "message_delta":
                # Message complete - yield tool calls if any
                if tool_uses_map:
                    for idx in sorted(tool_uses_map.keys()):
                        tool_data = tool_uses_map[idx]
                        yield {
                            "id": event.get("id", ""),
                            "object": "chat.completion.chunk",
                            "created": int(time.time()),
                            "model": request_params["model"],
                            "choices": [{
                                "index": 0,
                                "delta": {
                                    "tool_calls": [{
                                        "index": idx,
                                        "id": tool_data["id"],
                                        "type": "function",
                                        "function": {
                                            "name": tool_data["name"],
                                            "arguments": tool_data["input"]
                                        }
                                    }]
                                },
                                "finish_reason": None
                            }]
                        }

    except requests.RequestException as e:
        logger.error(f"Claude streaming request error: {e}")
        yield {
            "error": True,
            "message": f"Connection error: {str(e)}",
            "status_code": 0
        }
    except Exception as e:
        logger.error(f"Claude streaming error: {e}")
        yield {
            "error": True,
            "message": str(e),
            "status_code": 500
        }
    finally:
        # With stream=True the connection stays checked out until the body
        # is consumed or the response is closed; close it explicitly so an
        # early exit (error, '[DONE]', abandoned generator) does not leak it.
        if response is not None:
            response.close()
|
||||||
|
|||||||
288
bridge/agent_bridge.py
Normal file
288
bridge/agent_bridge.py
Normal file
@@ -0,0 +1,288 @@
|
|||||||
|
"""
|
||||||
|
Agent Bridge - Integrates Agent system with existing COW bridge
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Optional, List
|
||||||
|
|
||||||
|
from agent.protocol import Agent, LLMModel, LLMRequest
|
||||||
|
from agent.tools import Calculator, CurrentTime, Read, Write, Edit, Bash, Grep, Find, Ls
|
||||||
|
from bridge.bridge import Bridge
|
||||||
|
from bridge.context import Context
|
||||||
|
from bridge.reply import Reply, ReplyType
|
||||||
|
from common import const
|
||||||
|
from common.log import logger
|
||||||
|
|
||||||
|
|
||||||
|
class AgentLLMModel(LLMModel):
    """
    Adapter that lets the agent framework talk to a bot obtained from the
    COW bridge through the ``LLMModel`` interface
    """

    def __init__(self, bridge: Bridge, bot_type: str = "chat"):
        """
        Args:
            bridge: Bridge used to look up the bot
            bot_type: Key passed to ``bridge.get_bot``
        """
        from config import conf

        # The model name comes from config, falling back to const.GPT_41
        super().__init__(model=conf().get("model", const.GPT_41))
        self._bot = None
        self.bot_type = bot_type
        self.bridge = bridge

    @property
    def bot(self):
        """Bot instance, fetched from the bridge on first access and cached"""
        if self._bot is not None:
            return self._bot
        self._bot = self.bridge.get_bot(self.bot_type)
        return self._bot

    def call(self, request: LLMRequest):
        """
        Run a non-streaming request through the bot's ``call_with_tools``

        ``max_tokens`` is forwarded only when the request sets it.
        Errors are logged and re-raised; a bot without ``call_with_tools``
        results in NotImplementedError.
        """
        try:
            if not hasattr(self.bot, 'call_with_tools'):
                # No fallback path exists for bots lacking tool support
                raise NotImplementedError("Regular call not implemented yet")

            call_args = {
                'messages': request.messages,
                'tools': getattr(request, 'tools', None),
                'stream': False
            }
            if request.max_tokens is not None:
                call_args['max_tokens'] = request.max_tokens

            raw = self.bot.call_with_tools(**call_args)
            return self._format_response(raw)

        except Exception as e:
            logger.error(f"AgentLLMModel call error: {e}")
            raise

    def call_stream(self, request: LLMRequest):
        """
        Run a streaming request through the bot's ``call_with_tools`` and
        yield each chunk after passing it through ``_format_stream_chunk``

        ``max_tokens`` defaults to 4096 when the request leaves it unset,
        and the request's ``system`` attribute is forwarded when truthy.
        Errors are logged and re-raised.
        """
        try:
            if not hasattr(self.bot, 'call_with_tools'):
                raise NotImplementedError("Streaming call not implemented yet")

            token_limit = 4096 if request.max_tokens is None else request.max_tokens
            system_text = getattr(request, 'system', None)

            call_args = {
                'messages': request.messages,
                'tools': getattr(request, 'tools', None),
                'stream': True,
                'max_tokens': token_limit
            }
            if system_text:
                call_args['system'] = system_text

            for piece in self.bot.call_with_tools(**call_args):
                yield self._format_stream_chunk(piece)

        except Exception as e:
            logger.error(f"AgentLLMModel call_stream error: {e}")
            raise

    def _format_response(self, response):
        """Hook for adapting a full response; currently returns it unchanged"""
        return response

    def _format_stream_chunk(self, chunk):
        """Hook for adapting a stream chunk; currently returns it unchanged"""
        return chunk
|
||||||
|
|
||||||
|
|
||||||
|
class AgentBridge:
    """
    Bridge class that integrates single super Agent with COW
    """

    def __init__(self, bridge: Bridge):
        self.bridge = bridge
        # Created by create_agent(), or lazily by get_agent()
        self.agent: Optional[Agent] = None

    def create_agent(self, system_prompt: str, tools: Optional[List] = None, **kwargs) -> Agent:
        """
        Create the super agent with COW integration

        The new agent replaces any agent previously stored on this bridge.

        Args:
            system_prompt: System prompt
            tools: List of tools (optional); when None a default tool set
                is instantiated
            **kwargs: Additional agent parameters. Recognised keys:
                ``description`` (default "AI Super Agent"),
                ``max_steps`` (default 15), ``output_mode`` (default "logger")

        Returns:
            Agent instance
        """
        # Create LLM model that uses COW's bot infrastructure
        model = AgentLLMModel(self.bridge)

        # Default tools if none provided
        if tools is None:
            tools = [
                Calculator(),
                CurrentTime(),
                Read(),
                Write(),
                Edit(),
                Bash(),
                Grep(),
                Find(),
                Ls()
            ]

        # Create the single super agent
        self.agent = Agent(
            system_prompt=system_prompt,
            description=kwargs.get("description", "AI Super Agent"),
            model=model,
            tools=tools,
            max_steps=kwargs.get("max_steps", 15),
            output_mode=kwargs.get("output_mode", "logger")
        )

        return self.agent

    def get_agent(self) -> Optional[Agent]:
        """Get the super agent, create if not exists"""
        if self.agent is None:
            self._init_default_agent()
        return self.agent

    def _init_default_agent(self):
        """
        Initialize default super agent with config and memory

        Memory support is optional: if any step of setting it up raises,
        the agent is created with the plain ``character_desc`` prompt and
        without a memory manager or memory tools.
        """
        from config import conf
        import os

        # Get base system prompt from config
        base_prompt = conf().get("character_desc", "你是一个AI助手")

        # Setup memory if enabled
        memory_manager = None
        memory_tools = []

        try:
            # Try to initialize memory system
            from agent.memory import MemoryManager, MemoryConfig
            from agent.tools import MemorySearchTool, MemoryGetTool

            # Create memory config directly with sensible defaults
            workspace_root = os.path.expanduser("~/cow")
            memory_config = MemoryConfig(
                workspace_root=workspace_root,
                embedding_provider="local",  # Use local embedding (no API key needed)
                embedding_model="all-MiniLM-L6-v2"
            )

            # Create memory manager with the config
            memory_manager = MemoryManager(memory_config)

            # Create memory tools
            memory_tools = [
                MemorySearchTool(memory_manager),
                MemoryGetTool(memory_manager)
            ]

            # Build memory guidance and add to system prompt
            memory_guidance = memory_manager.build_memory_guidance(
                lang="zh",
                include_context=True
            )
            system_prompt = base_prompt + "\n\n" + memory_guidance

            logger.info("[AgentBridge] Memory system initialized")
            logger.info(f"[AgentBridge] Workspace: {memory_config.get_workspace()}")

        except Exception as e:
            logger.warning(f"[AgentBridge] Memory system not available: {e}")
            logger.info("[AgentBridge] Continuing without memory features")
            system_prompt = base_prompt
            # The failure may have happened after the manager or tools were
            # already assigned; discard them so the agent is not left with a
            # half-configured memory setup that contradicts the log above.
            memory_manager = None
            memory_tools = []
            import traceback
            traceback.print_exc()

        logger.info("[AgentBridge] Initializing super agent")

        # Configure file tools to work in the correct workspace.
        # workspace_root is only guaranteed to exist when memory setup
        # succeeded, hence the guard on memory_manager.
        file_config = {"cwd": workspace_root} if memory_manager else {}

        # Create default tools with workspace config
        tools = [
            Calculator(),
            CurrentTime(),
            Read(config=file_config),
            Write(config=file_config),
            Edit(config=file_config),
            Bash(config=file_config),
            Grep(config=file_config),
            Find(config=file_config),
            Ls(config=file_config)
        ]

        # Create agent with configured tools
        agent = self.create_agent(
            system_prompt=system_prompt,
            tools=tools,
            max_steps=50,
            output_mode="logger"
        )

        # Attach memory manager to agent if available
        if memory_manager:
            agent.memory_manager = memory_manager

        # Add memory tools if available
        if memory_tools:
            for tool in memory_tools:
                agent.add_tool(tool)
            logger.info(f"[AgentBridge] Added {len(memory_tools)} memory tools")

    def agent_reply(self, query: str, context: Context = None,
                    on_event=None, clear_history: bool = False) -> Reply:
        """
        Use super agent to reply to a query

        Args:
            query: User query
            context: COW context (optional); accepted but not used here
            on_event: Event callback (optional), forwarded to the agent
            clear_history: Whether to clear conversation history

        Returns:
            Reply object: ReplyType.TEXT with the agent's response, or
            ReplyType.ERROR if the agent is unavailable or raises
        """
        try:
            # Get agent (will auto-initialize if needed)
            agent = self.get_agent()
            if not agent:
                return Reply(ReplyType.ERROR, "Failed to initialize super agent")

            # Use agent's run_stream method
            response = agent.run_stream(
                user_message=query,
                on_event=on_event,
                clear_history=clear_history
            )

            return Reply(ReplyType.TEXT, response)

        except Exception as e:
            logger.error(f"Agent reply error: {e}")
            return Reply(ReplyType.ERROR, f"Agent error: {str(e)}")
|
||||||
@@ -23,7 +23,7 @@ class Bridge(object):
|
|||||||
if bot_type:
|
if bot_type:
|
||||||
self.btype["chat"] = bot_type
|
self.btype["chat"] = bot_type
|
||||||
else:
|
else:
|
||||||
model_type = conf().get("model") or const.GPT35
|
model_type = conf().get("model") or const.GPT_41_MINI
|
||||||
if model_type in ["text-davinci-003"]:
|
if model_type in ["text-davinci-003"]:
|
||||||
self.btype["chat"] = const.OPEN_AI
|
self.btype["chat"] = const.OPEN_AI
|
||||||
if conf().get("use_azure_chatgpt", False):
|
if conf().get("use_azure_chatgpt", False):
|
||||||
@@ -64,6 +64,7 @@ class Bridge(object):
|
|||||||
|
|
||||||
self.bots = {}
|
self.bots = {}
|
||||||
self.chat_bots = {}
|
self.chat_bots = {}
|
||||||
|
self._agent_bridge = None
|
||||||
|
|
||||||
# 模型对应的接口
|
# 模型对应的接口
|
||||||
def get_bot(self, typename):
|
def get_bot(self, typename):
|
||||||
@@ -104,3 +105,29 @@ class Bridge(object):
|
|||||||
重置bot路由
|
重置bot路由
|
||||||
"""
|
"""
|
||||||
self.__init__()
|
self.__init__()
|
||||||
|
|
||||||
|
def get_agent_bridge(self):
    """
    Return the AgentBridge for agent-based conversations, creating and
    caching it on first use
    """
    cached = self._agent_bridge
    if cached is not None:
        return cached

    # Imported at call time; presumably this avoids a circular import,
    # since bridge.agent_bridge itself imports Bridge -- TODO confirm
    from bridge.agent_bridge import AgentBridge
    self._agent_bridge = AgentBridge(self)
    return self._agent_bridge
|
||||||
|
|
||||||
|
def fetch_agent_reply(self, query: str, context: Context = None,
                      on_event=None, clear_history: bool = False) -> Reply:
    """
    Hand the query to the super agent and return its reply

    Args:
        query: User query
        context: Context object
        on_event: Event callback for streaming
        clear_history: Whether to clear conversation history

    Returns:
        Reply object produced by the agent bridge's ``agent_reply``
    """
    # Pure delegation: the agent bridge is obtained (and created on first
    # use) via get_agent_bridge()
    return self.get_agent_bridge().agent_reply(query, context, on_event, clear_history)
|
||||||
|
|||||||
@@ -5,6 +5,8 @@ Message sending channel abstract class
|
|||||||
from bridge.bridge import Bridge
|
from bridge.bridge import Bridge
|
||||||
from bridge.context import Context
|
from bridge.context import Context
|
||||||
from bridge.reply import *
|
from bridge.reply import *
|
||||||
|
from common.log import logger
|
||||||
|
from config import conf
|
||||||
|
|
||||||
|
|
||||||
class Channel(object):
|
class Channel(object):
|
||||||
@@ -35,6 +37,29 @@ class Channel(object):
|
|||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def build_reply_content(self, query, context: Context = None) -> Reply:
    """
    Produce the reply for a query

    When the ``agent`` config value is truthy the query goes to the agent
    bridge; if that call raises, the error is logged and the regular
    reply path is used instead. Otherwise the regular path is used
    directly.
    """
    agent_enabled = conf().get("agent", False)

    if agent_enabled:
        try:
            logger.info("[Channel] Using agent mode")
            return Bridge().fetch_agent_reply(
                query=query,
                context=context,
                on_event=None,
                clear_history=False
            )
        except Exception as e:
            # Do not propagate: fall through to the normal reply below
            logger.error(f"[Channel] Agent mode failed, fallback to normal mode: {e}")

    # Normal mode, also reached as the fallback after an agent failure
    return Bridge().fetch_reply_content(query, context)
|
||||||
|
|
||||||
def build_voice_to_text(self, voice_file) -> Reply:
|
def build_voice_to_text(self, voice_file) -> Reply:
|
||||||
|
|||||||
@@ -150,9 +150,6 @@ class WebChannel(ChatChannel):
|
|||||||
Poll for responses using the session_id.
|
Poll for responses using the session_id.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# 不记录轮询请求的日志
|
|
||||||
web.ctx.log_request = False
|
|
||||||
|
|
||||||
data = web.data()
|
data = web.data()
|
||||||
json_data = json.loads(data)
|
json_data = json.loads(data)
|
||||||
session_id = json_data.get('session_id')
|
session_id = json_data.get('session_id')
|
||||||
@@ -215,18 +212,19 @@ class WebChannel(ChatChannel):
|
|||||||
)
|
)
|
||||||
app = web.application(urls, globals(), autoreload=False)
|
app = web.application(urls, globals(), autoreload=False)
|
||||||
|
|
||||||
# 禁用web.py的默认日志输出
|
# 完全禁用web.py的HTTP日志输出
|
||||||
import io
|
# 创建一个空的日志处理函数
|
||||||
from contextlib import redirect_stdout
|
def null_log_function(status, environ):
|
||||||
|
pass
|
||||||
|
|
||||||
# 配置web.py的日志级别为ERROR,只显示错误
|
# 替换web.py的日志函数
|
||||||
|
web.httpserver.LogMiddleware.log = lambda self, status, environ: None
|
||||||
|
|
||||||
|
# 配置web.py的日志级别为ERROR
|
||||||
logging.getLogger("web").setLevel(logging.ERROR)
|
logging.getLogger("web").setLevel(logging.ERROR)
|
||||||
|
|
||||||
# 禁用web.httpserver的日志
|
|
||||||
logging.getLogger("web.httpserver").setLevel(logging.ERROR)
|
logging.getLogger("web.httpserver").setLevel(logging.ERROR)
|
||||||
|
|
||||||
# 临时重定向标准输出,捕获web.py的启动消息
|
# 启动服务器
|
||||||
with redirect_stdout(io.StringIO()):
|
|
||||||
web.httpserver.runsimple(app.wsgifunc(), ("0.0.0.0", port))
|
web.httpserver.runsimple(app.wsgifunc(), ("0.0.0.0", port))
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -3,6 +3,7 @@
|
|||||||
"model": "",
|
"model": "",
|
||||||
"open_ai_api_key": "YOUR API KEY",
|
"open_ai_api_key": "YOUR API KEY",
|
||||||
"claude_api_key": "YOUR API KEY",
|
"claude_api_key": "YOUR API KEY",
|
||||||
|
"claude_api_base": "https://api.anthropic.com",
|
||||||
"text_to_image": "dall-e-2",
|
"text_to_image": "dall-e-2",
|
||||||
"voice_to_text": "openai",
|
"voice_to_text": "openai",
|
||||||
"text_to_voice": "openai",
|
"text_to_voice": "openai",
|
||||||
@@ -30,8 +31,9 @@
|
|||||||
"expires_in_seconds": 3600,
|
"expires_in_seconds": 3600,
|
||||||
"character_desc": "你是基于大语言模型的AI智能助手,旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。",
|
"character_desc": "你是基于大语言模型的AI智能助手,旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。",
|
||||||
"temperature": 0.7,
|
"temperature": 0.7,
|
||||||
"subscribe_msg": "感谢您的关注!\n这里是AI智能助手,可以自由对话。\n支持语音对话。\n支持图片输入。\n支持图片输出,画字开头的消息将按要求创作图片。\n支持tool、角色扮演和文字冒险等丰富的插件。\n输入{trigger_prefix}#help 查看详细指令。",
|
"subscribe_msg": "感谢您的关注!\n这里是AI智能助手,可以自由对话。\n支持语音对话。\n支持图片输入。\n支持图片输出,画字开头的消息将按要求创作图片。\n支持tool、角色扮演和文字冒险等丰富的插件。\n输入{trigger_prefix}#help 查看详细指令。",
|
||||||
"use_linkai": false,
|
"use_linkai": false,
|
||||||
"linkai_api_key": "",
|
"linkai_api_key": "",
|
||||||
"linkai_app_code": ""
|
"linkai_app_code": "",
|
||||||
|
"agent": true
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,10 +1,10 @@
|
|||||||
# encoding:utf-8
|
# encoding:utf-8
|
||||||
|
|
||||||
|
import copy
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import pickle
|
import pickle
|
||||||
import copy
|
|
||||||
|
|
||||||
from common.log import logger
|
from common.log import logger
|
||||||
|
|
||||||
@@ -183,6 +183,7 @@ available_setting = {
|
|||||||
"Minimax_group_id": "",
|
"Minimax_group_id": "",
|
||||||
"Minimax_base_url": "",
|
"Minimax_base_url": "",
|
||||||
"web_port": 9899,
|
"web_port": 9899,
|
||||||
|
"agent": False # 是否开启Agent模式
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
5
memory/2026-01-29.md
Normal file
5
memory/2026-01-29.md
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
# 2026-01-29 记录
|
||||||
|
|
||||||
|
## 老王的重要决定
|
||||||
|
- 今天老王告诉我他决定要学AI了,这是一个重要的决策
|
||||||
|
- 这可能会是他学习和职业发展的一个转折点
|
||||||
21
memory/MEMORY.md
Normal file
21
memory/MEMORY.md
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
# Memory
|
||||||
|
|
||||||
|
Long-term curated memories and preferences.
|
||||||
|
|
||||||
|
## 用户信息
|
||||||
|
- 用户名:老王
|
||||||
|
|
||||||
|
## 用户信息
|
||||||
|
- 用户名:老王
|
||||||
|
|
||||||
|
## 用户偏好
|
||||||
|
- 喜欢吃红烧肉
|
||||||
|
- 爱打篮球
|
||||||
|
|
||||||
|
## 重要决策
|
||||||
|
- 决定要学习AI(2026-01-29)
|
||||||
|
|
||||||
|
## Notes
|
||||||
|
|
||||||
|
- Important decisions and facts go here
|
||||||
|
- This is your long-term knowledge base
|
||||||
Reference in New Issue
Block a user