mirror of
https://github.com/zhayujie/chatgpt-on-wechat.git
synced 2026-06-02 00:57:41 +08:00
feat: support skills
This commit is contained in:
@@ -11,7 +11,8 @@ from agent.tools.base_tool import BaseTool, ToolStage
|
|||||||
class Agent:
|
class Agent:
|
||||||
def __init__(self, system_prompt: str, description: str = "AI Agent", model: LLMModel = None,
|
def __init__(self, system_prompt: str, description: str = "AI Agent", model: LLMModel = None,
|
||||||
tools=None, output_mode="print", max_steps=100, max_context_tokens=None,
|
tools=None, output_mode="print", max_steps=100, max_context_tokens=None,
|
||||||
context_reserve_tokens=None, memory_manager=None, name: str = None):
|
context_reserve_tokens=None, memory_manager=None, name: str = None,
|
||||||
|
workspace_dir: str = None, skill_manager=None, enable_skills: bool = True):
|
||||||
"""
|
"""
|
||||||
Initialize the Agent with system prompt, model, description.
|
Initialize the Agent with system prompt, model, description.
|
||||||
|
|
||||||
@@ -26,6 +27,9 @@ class Agent:
|
|||||||
:param context_reserve_tokens: Reserve tokens for new requests (default: None, auto-calculated)
|
:param context_reserve_tokens: Reserve tokens for new requests (default: None, auto-calculated)
|
||||||
:param memory_manager: Optional MemoryManager instance for memory operations
|
:param memory_manager: Optional MemoryManager instance for memory operations
|
||||||
:param name: [Deprecated] The name of the agent (no longer used in single-agent system)
|
:param name: [Deprecated] The name of the agent (no longer used in single-agent system)
|
||||||
|
:param workspace_dir: Optional workspace directory for workspace-specific skills
|
||||||
|
:param skill_manager: Optional SkillManager instance (will be created if None and enable_skills=True)
|
||||||
|
:param enable_skills: Whether to enable skills support (default: True)
|
||||||
"""
|
"""
|
||||||
self.name = name or "Agent"
|
self.name = name or "Agent"
|
||||||
self.system_prompt = system_prompt
|
self.system_prompt = system_prompt
|
||||||
@@ -40,6 +44,23 @@ class Agent:
|
|||||||
self.last_usage = None # Store last API response usage info
|
self.last_usage = None # Store last API response usage info
|
||||||
self.messages = [] # Unified message history for stream mode
|
self.messages = [] # Unified message history for stream mode
|
||||||
self.memory_manager = memory_manager # Memory manager for auto memory flush
|
self.memory_manager = memory_manager # Memory manager for auto memory flush
|
||||||
|
self.workspace_dir = workspace_dir # Workspace directory
|
||||||
|
self.enable_skills = enable_skills # Skills enabled flag
|
||||||
|
|
||||||
|
# Initialize skill manager
|
||||||
|
self.skill_manager = None
|
||||||
|
if enable_skills:
|
||||||
|
if skill_manager:
|
||||||
|
self.skill_manager = skill_manager
|
||||||
|
else:
|
||||||
|
# Auto-create skill manager
|
||||||
|
try:
|
||||||
|
from agent.skills import SkillManager
|
||||||
|
self.skill_manager = SkillManager(workspace_dir=workspace_dir)
|
||||||
|
logger.info(f"Initialized SkillManager with {len(self.skill_manager.skills)} skills")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to initialize SkillManager: {e}")
|
||||||
|
|
||||||
if tools:
|
if tools:
|
||||||
for tool in tools:
|
for tool in tools:
|
||||||
self.add_tool(tool)
|
self.add_tool(tool)
|
||||||
@@ -54,6 +75,52 @@ class Agent:
|
|||||||
tool.model = self.model
|
tool.model = self.model
|
||||||
self.tools.append(tool)
|
self.tools.append(tool)
|
||||||
|
|
||||||
|
def get_skills_prompt(self, skill_filter=None) -> str:
|
||||||
|
"""
|
||||||
|
Get the skills prompt to append to system prompt.
|
||||||
|
|
||||||
|
:param skill_filter: Optional list of skill names to include
|
||||||
|
:return: Formatted skills prompt or empty string
|
||||||
|
"""
|
||||||
|
if not self.skill_manager:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
try:
|
||||||
|
return self.skill_manager.build_skills_prompt(skill_filter=skill_filter)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to build skills prompt: {e}")
|
||||||
|
return ""
|
||||||
|
|
||||||
|
def get_full_system_prompt(self, skill_filter=None) -> str:
|
||||||
|
"""
|
||||||
|
Get the full system prompt including skills.
|
||||||
|
|
||||||
|
:param skill_filter: Optional list of skill names to include
|
||||||
|
:return: Complete system prompt with skills appended
|
||||||
|
"""
|
||||||
|
base_prompt = self.system_prompt
|
||||||
|
skills_prompt = self.get_skills_prompt(skill_filter=skill_filter)
|
||||||
|
|
||||||
|
if skills_prompt:
|
||||||
|
return base_prompt + "\n" + skills_prompt
|
||||||
|
return base_prompt
|
||||||
|
|
||||||
|
def refresh_skills(self):
|
||||||
|
"""Refresh the loaded skills."""
|
||||||
|
if self.skill_manager:
|
||||||
|
self.skill_manager.refresh_skills()
|
||||||
|
logger.info(f"Refreshed skills: {len(self.skill_manager.skills)} skills loaded")
|
||||||
|
|
||||||
|
def list_skills(self):
|
||||||
|
"""
|
||||||
|
List all loaded skills.
|
||||||
|
|
||||||
|
:return: List of skill entries or empty list
|
||||||
|
"""
|
||||||
|
if not self.skill_manager:
|
||||||
|
return []
|
||||||
|
return self.skill_manager.list_skills()
|
||||||
|
|
||||||
def _get_model_context_window(self) -> int:
|
def _get_model_context_window(self) -> int:
|
||||||
"""
|
"""
|
||||||
Get the model's context window size in tokens.
|
Get the model's context window size in tokens.
|
||||||
@@ -229,7 +296,7 @@ class Agent:
|
|||||||
|
|
||||||
return action
|
return action
|
||||||
|
|
||||||
def run_stream(self, user_message: str, on_event=None, clear_history: bool = False) -> str:
|
def run_stream(self, user_message: str, on_event=None, clear_history: bool = False, skill_filter=None) -> str:
|
||||||
"""
|
"""
|
||||||
Execute single agent task with streaming (based on tool-call)
|
Execute single agent task with streaming (based on tool-call)
|
||||||
|
|
||||||
@@ -244,6 +311,7 @@ class Agent:
|
|||||||
on_event: Event callback function callback(event: dict)
|
on_event: Event callback function callback(event: dict)
|
||||||
event = {"type": str, "timestamp": float, "data": dict}
|
event = {"type": str, "timestamp": float, "data": dict}
|
||||||
clear_history: If True, clear conversation history before this call (default: False)
|
clear_history: If True, clear conversation history before this call (default: False)
|
||||||
|
skill_filter: Optional list of skill names to include in this run
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Final response text
|
Final response text
|
||||||
@@ -264,11 +332,14 @@ class Agent:
|
|||||||
if not self.model:
|
if not self.model:
|
||||||
raise ValueError("No model available for agent")
|
raise ValueError("No model available for agent")
|
||||||
|
|
||||||
|
# Get full system prompt with skills
|
||||||
|
full_system_prompt = self.get_full_system_prompt(skill_filter=skill_filter)
|
||||||
|
|
||||||
# Create stream executor with agent's message history
|
# Create stream executor with agent's message history
|
||||||
executor = AgentStreamExecutor(
|
executor = AgentStreamExecutor(
|
||||||
agent=self,
|
agent=self,
|
||||||
model=self.model,
|
model=self.model,
|
||||||
system_prompt=self.system_prompt,
|
system_prompt=full_system_prompt,
|
||||||
tools=self.tools,
|
tools=self.tools,
|
||||||
max_turns=self.max_steps,
|
max_turns=self.max_steps,
|
||||||
on_event=on_event,
|
on_event=on_event,
|
||||||
|
|||||||
@@ -156,12 +156,17 @@ class AgentStreamExecutor:
|
|||||||
|
|
||||||
# Log tool result in compact format
|
# Log tool result in compact format
|
||||||
status_emoji = "✅" if result.get("status") == "success" else "❌"
|
status_emoji = "✅" if result.get("status") == "success" else "❌"
|
||||||
result_str = str(result.get('result', ''))
|
result_data = result.get('result', '')
|
||||||
|
# Format result string with proper Chinese character support
|
||||||
|
if isinstance(result_data, (dict, list)):
|
||||||
|
result_str = json.dumps(result_data, ensure_ascii=False)
|
||||||
|
else:
|
||||||
|
result_str = str(result_data)
|
||||||
logger.info(f" {status_emoji} {tool_call['name']} ({result.get('execution_time', 0):.2f}s): {result_str[:200]}{'...' if len(result_str) > 200 else ''}")
|
logger.info(f" {status_emoji} {tool_call['name']} ({result.get('execution_time', 0):.2f}s): {result_str[:200]}{'...' if len(result_str) > 200 else ''}")
|
||||||
|
|
||||||
# Build tool result block (Claude format)
|
# Build tool result block (Claude format)
|
||||||
# Content should be a string representation of the result
|
# Content should be a string representation of the result
|
||||||
result_content = json.dumps(result) if not isinstance(result, str) else result
|
result_content = json.dumps(result, ensure_ascii=False) if not isinstance(result, str) else result
|
||||||
tool_result_blocks.append({
|
tool_result_blocks.append({
|
||||||
"type": "tool_result",
|
"type": "tool_result",
|
||||||
"tool_use_id": tool_call["id"],
|
"tool_use_id": tool_call["id"],
|
||||||
|
|||||||
29
agent/skills/__init__.py
Normal file
29
agent/skills/__init__.py
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
"""
|
||||||
|
Skills module for agent system.
|
||||||
|
|
||||||
|
This module provides the framework for loading, managing, and executing skills.
|
||||||
|
Skills are markdown files with frontmatter that provide specialized instructions
|
||||||
|
for specific tasks.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from agent.skills.types import (
|
||||||
|
Skill,
|
||||||
|
SkillEntry,
|
||||||
|
SkillMetadata,
|
||||||
|
SkillInstallSpec,
|
||||||
|
LoadSkillsResult,
|
||||||
|
)
|
||||||
|
from agent.skills.loader import SkillLoader
|
||||||
|
from agent.skills.manager import SkillManager
|
||||||
|
from agent.skills.formatter import format_skills_for_prompt
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"Skill",
|
||||||
|
"SkillEntry",
|
||||||
|
"SkillMetadata",
|
||||||
|
"SkillInstallSpec",
|
||||||
|
"LoadSkillsResult",
|
||||||
|
"SkillLoader",
|
||||||
|
"SkillManager",
|
||||||
|
"format_skills_for_prompt",
|
||||||
|
]
|
||||||
211
agent/skills/config.py
Normal file
211
agent/skills/config.py
Normal file
@@ -0,0 +1,211 @@
|
|||||||
|
"""
|
||||||
|
Configuration support for skills.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import platform
|
||||||
|
from typing import Dict, Optional, List
|
||||||
|
from agent.skills.types import SkillEntry
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_runtime_platform() -> str:
|
||||||
|
"""Get the current runtime platform."""
|
||||||
|
return platform.system().lower()
|
||||||
|
|
||||||
|
|
||||||
|
def has_binary(bin_name: str) -> bool:
|
||||||
|
"""
|
||||||
|
Check if a binary is available in PATH.
|
||||||
|
|
||||||
|
:param bin_name: Binary name to check
|
||||||
|
:return: True if binary is available
|
||||||
|
"""
|
||||||
|
import shutil
|
||||||
|
return shutil.which(bin_name) is not None
|
||||||
|
|
||||||
|
|
||||||
|
def has_any_binary(bin_names: List[str]) -> bool:
|
||||||
|
"""
|
||||||
|
Check if any of the given binaries is available.
|
||||||
|
|
||||||
|
:param bin_names: List of binary names to check
|
||||||
|
:return: True if at least one binary is available
|
||||||
|
"""
|
||||||
|
return any(has_binary(bin_name) for bin_name in bin_names)
|
||||||
|
|
||||||
|
|
||||||
|
def has_env_var(env_name: str) -> bool:
|
||||||
|
"""
|
||||||
|
Check if an environment variable is set.
|
||||||
|
|
||||||
|
:param env_name: Environment variable name
|
||||||
|
:return: True if environment variable is set
|
||||||
|
"""
|
||||||
|
return env_name in os.environ and bool(os.environ[env_name].strip())
|
||||||
|
|
||||||
|
|
||||||
|
def get_skill_config(config: Optional[Dict], skill_name: str) -> Optional[Dict]:
|
||||||
|
"""
|
||||||
|
Get skill-specific configuration.
|
||||||
|
|
||||||
|
:param config: Global configuration dictionary
|
||||||
|
:param skill_name: Name of the skill
|
||||||
|
:return: Skill configuration or None
|
||||||
|
"""
|
||||||
|
if not config:
|
||||||
|
return None
|
||||||
|
|
||||||
|
skills_config = config.get('skills', {})
|
||||||
|
if not isinstance(skills_config, dict):
|
||||||
|
return None
|
||||||
|
|
||||||
|
entries = skills_config.get('entries', {})
|
||||||
|
if not isinstance(entries, dict):
|
||||||
|
return None
|
||||||
|
|
||||||
|
return entries.get(skill_name)
|
||||||
|
|
||||||
|
|
||||||
|
def should_include_skill(
|
||||||
|
entry: SkillEntry,
|
||||||
|
config: Optional[Dict] = None,
|
||||||
|
current_platform: Optional[str] = None,
|
||||||
|
lenient: bool = True,
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
Determine if a skill should be included based on requirements.
|
||||||
|
|
||||||
|
Similar to clawdbot's shouldIncludeSkill logic, but with lenient mode:
|
||||||
|
- In lenient mode (default): Only check explicit disable and platform, ignore missing requirements
|
||||||
|
- In strict mode: Check all requirements (binary, env vars, config)
|
||||||
|
|
||||||
|
:param entry: SkillEntry to check
|
||||||
|
:param config: Configuration dictionary
|
||||||
|
:param current_platform: Current platform (default: auto-detect)
|
||||||
|
:param lenient: If True, ignore missing requirements and load all skills (default: True)
|
||||||
|
:return: True if skill should be included
|
||||||
|
"""
|
||||||
|
metadata = entry.metadata
|
||||||
|
skill_name = entry.skill.name
|
||||||
|
skill_config = get_skill_config(config, skill_name)
|
||||||
|
|
||||||
|
# Always check if skill is explicitly disabled in config
|
||||||
|
if skill_config and skill_config.get('enabled') is False:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if not metadata:
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Always check platform requirements (can't work on wrong platform)
|
||||||
|
if metadata.os:
|
||||||
|
platform_name = current_platform or resolve_runtime_platform()
|
||||||
|
# Map common platform names
|
||||||
|
platform_map = {
|
||||||
|
'darwin': 'darwin',
|
||||||
|
'linux': 'linux',
|
||||||
|
'windows': 'win32',
|
||||||
|
}
|
||||||
|
normalized_platform = platform_map.get(platform_name, platform_name)
|
||||||
|
|
||||||
|
if normalized_platform not in metadata.os:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# If skill has 'always: true', include it regardless of other requirements
|
||||||
|
if metadata.always:
|
||||||
|
return True
|
||||||
|
|
||||||
|
# In lenient mode, skip requirement checks and load all skills
|
||||||
|
# Skills will fail gracefully at runtime if requirements are missing
|
||||||
|
if lenient:
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Strict mode: Check all requirements
|
||||||
|
if metadata.requires:
|
||||||
|
# Check required binaries (all must be present)
|
||||||
|
required_bins = metadata.requires.get('bins', [])
|
||||||
|
if required_bins:
|
||||||
|
if not all(has_binary(bin_name) for bin_name in required_bins):
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Check anyBins (at least one must be present)
|
||||||
|
any_bins = metadata.requires.get('anyBins', [])
|
||||||
|
if any_bins:
|
||||||
|
if not has_any_binary(any_bins):
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Check environment variables (with config fallback)
|
||||||
|
required_env = metadata.requires.get('env', [])
|
||||||
|
if required_env:
|
||||||
|
for env_name in required_env:
|
||||||
|
# Check in order: 1) env var, 2) skill config env, 3) skill config apiKey (if primaryEnv)
|
||||||
|
if has_env_var(env_name):
|
||||||
|
continue
|
||||||
|
if skill_config:
|
||||||
|
# Check skill config env dict
|
||||||
|
skill_env = skill_config.get('env', {})
|
||||||
|
if isinstance(skill_env, dict) and env_name in skill_env:
|
||||||
|
continue
|
||||||
|
# Check skill config apiKey (if this is the primaryEnv)
|
||||||
|
if metadata.primary_env == env_name and skill_config.get('apiKey'):
|
||||||
|
continue
|
||||||
|
# Requirement not satisfied
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Check config paths
|
||||||
|
required_config = metadata.requires.get('config', [])
|
||||||
|
if required_config and config:
|
||||||
|
for config_path in required_config:
|
||||||
|
if not is_config_path_truthy(config, config_path):
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def is_config_path_truthy(config: Dict, path: str) -> bool:
|
||||||
|
"""
|
||||||
|
Check if a config path resolves to a truthy value.
|
||||||
|
|
||||||
|
:param config: Configuration dictionary
|
||||||
|
:param path: Dot-separated path (e.g., 'skills.enabled')
|
||||||
|
:return: True if path resolves to truthy value
|
||||||
|
"""
|
||||||
|
parts = path.split('.')
|
||||||
|
current = config
|
||||||
|
|
||||||
|
for part in parts:
|
||||||
|
if not isinstance(current, dict):
|
||||||
|
return False
|
||||||
|
current = current.get(part)
|
||||||
|
if current is None:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Check if value is truthy
|
||||||
|
if isinstance(current, bool):
|
||||||
|
return current
|
||||||
|
if isinstance(current, (int, float)):
|
||||||
|
return current != 0
|
||||||
|
if isinstance(current, str):
|
||||||
|
return bool(current.strip())
|
||||||
|
|
||||||
|
return bool(current)
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_config_path(config: Dict, path: str):
|
||||||
|
"""
|
||||||
|
Resolve a dot-separated config path to its value.
|
||||||
|
|
||||||
|
:param config: Configuration dictionary
|
||||||
|
:param path: Dot-separated path
|
||||||
|
:return: Value at path or None
|
||||||
|
"""
|
||||||
|
parts = path.split('.')
|
||||||
|
current = config
|
||||||
|
|
||||||
|
for part in parts:
|
||||||
|
if not isinstance(current, dict):
|
||||||
|
return None
|
||||||
|
current = current.get(part)
|
||||||
|
if current is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return current
|
||||||
62
agent/skills/formatter.py
Normal file
62
agent/skills/formatter.py
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
"""
|
||||||
|
Skill formatter for generating prompts from skills.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import List
|
||||||
|
from agent.skills.types import Skill, SkillEntry
|
||||||
|
|
||||||
|
|
||||||
|
def format_skills_for_prompt(skills: List[Skill]) -> str:
|
||||||
|
"""
|
||||||
|
Format skills for inclusion in a system prompt.
|
||||||
|
|
||||||
|
Uses XML format per Agent Skills standard.
|
||||||
|
Skills with disable_model_invocation=True are excluded.
|
||||||
|
|
||||||
|
:param skills: List of skills to format
|
||||||
|
:return: Formatted prompt text
|
||||||
|
"""
|
||||||
|
# Filter out skills that should not be invoked by the model
|
||||||
|
visible_skills = [s for s in skills if not s.disable_model_invocation]
|
||||||
|
|
||||||
|
if not visible_skills:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
lines = [
|
||||||
|
"\n\nThe following skills provide specialized instructions for specific tasks.",
|
||||||
|
"Use the read tool to load a skill's file when the task matches its description.",
|
||||||
|
"",
|
||||||
|
"<available_skills>",
|
||||||
|
]
|
||||||
|
|
||||||
|
for skill in visible_skills:
|
||||||
|
lines.append(" <skill>")
|
||||||
|
lines.append(f" <name>{_escape_xml(skill.name)}</name>")
|
||||||
|
lines.append(f" <description>{_escape_xml(skill.description)}</description>")
|
||||||
|
lines.append(f" <location>{_escape_xml(skill.file_path)}</location>")
|
||||||
|
lines.append(" </skill>")
|
||||||
|
|
||||||
|
lines.append("</available_skills>")
|
||||||
|
|
||||||
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
|
def format_skill_entries_for_prompt(entries: List[SkillEntry]) -> str:
|
||||||
|
"""
|
||||||
|
Format skill entries for inclusion in a system prompt.
|
||||||
|
|
||||||
|
:param entries: List of skill entries to format
|
||||||
|
:return: Formatted prompt text
|
||||||
|
"""
|
||||||
|
skills = [entry.skill for entry in entries]
|
||||||
|
return format_skills_for_prompt(skills)
|
||||||
|
|
||||||
|
|
||||||
|
def _escape_xml(text: str) -> str:
|
||||||
|
"""Escape XML special characters."""
|
||||||
|
return (text
|
||||||
|
.replace('&', '&')
|
||||||
|
.replace('<', '<')
|
||||||
|
.replace('>', '>')
|
||||||
|
.replace('"', '"')
|
||||||
|
.replace("'", '''))
|
||||||
159
agent/skills/frontmatter.py
Normal file
159
agent/skills/frontmatter.py
Normal file
@@ -0,0 +1,159 @@
|
|||||||
|
"""
|
||||||
|
Frontmatter parsing for skills.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import re
|
||||||
|
import json
|
||||||
|
from typing import Dict, Any, Optional, List
|
||||||
|
from agent.skills.types import SkillMetadata, SkillInstallSpec
|
||||||
|
|
||||||
|
|
||||||
|
def parse_frontmatter(content: str) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Parse YAML-style frontmatter from markdown content.
|
||||||
|
|
||||||
|
Returns a dictionary of frontmatter fields.
|
||||||
|
"""
|
||||||
|
frontmatter = {}
|
||||||
|
|
||||||
|
# Match frontmatter block between --- markers
|
||||||
|
match = re.match(r'^---\s*\n(.*?)\n---\s*\n', content, re.DOTALL)
|
||||||
|
if not match:
|
||||||
|
return frontmatter
|
||||||
|
|
||||||
|
frontmatter_text = match.group(1)
|
||||||
|
|
||||||
|
# Simple YAML-like parsing (supports key: value format)
|
||||||
|
for line in frontmatter_text.split('\n'):
|
||||||
|
line = line.strip()
|
||||||
|
if not line or line.startswith('#'):
|
||||||
|
continue
|
||||||
|
|
||||||
|
if ':' in line:
|
||||||
|
key, value = line.split(':', 1)
|
||||||
|
key = key.strip()
|
||||||
|
value = value.strip()
|
||||||
|
|
||||||
|
# Try to parse as JSON if it looks like JSON
|
||||||
|
if value.startswith('{') or value.startswith('['):
|
||||||
|
try:
|
||||||
|
value = json.loads(value)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
pass
|
||||||
|
# Parse boolean values
|
||||||
|
elif value.lower() in ('true', 'false'):
|
||||||
|
value = value.lower() == 'true'
|
||||||
|
# Parse numbers
|
||||||
|
elif value.isdigit():
|
||||||
|
value = int(value)
|
||||||
|
|
||||||
|
frontmatter[key] = value
|
||||||
|
|
||||||
|
return frontmatter
|
||||||
|
|
||||||
|
|
||||||
|
def parse_metadata(frontmatter: Dict[str, Any]) -> Optional[SkillMetadata]:
|
||||||
|
"""
|
||||||
|
Parse skill metadata from frontmatter.
|
||||||
|
|
||||||
|
Looks for 'metadata' field containing JSON with skill configuration.
|
||||||
|
"""
|
||||||
|
metadata_raw = frontmatter.get('metadata')
|
||||||
|
if not metadata_raw:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# If it's a string, try to parse as JSON
|
||||||
|
if isinstance(metadata_raw, str):
|
||||||
|
try:
|
||||||
|
metadata_raw = json.loads(metadata_raw)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if not isinstance(metadata_raw, dict):
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Support both 'moltbot' and 'cow' keys for compatibility
|
||||||
|
meta_obj = metadata_raw.get('moltbot') or metadata_raw.get('cow')
|
||||||
|
if not meta_obj or not isinstance(meta_obj, dict):
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Parse install specs
|
||||||
|
install_specs = []
|
||||||
|
install_raw = meta_obj.get('install', [])
|
||||||
|
if isinstance(install_raw, list):
|
||||||
|
for spec_raw in install_raw:
|
||||||
|
if not isinstance(spec_raw, dict):
|
||||||
|
continue
|
||||||
|
|
||||||
|
kind = spec_raw.get('kind', spec_raw.get('type', '')).lower()
|
||||||
|
if not kind:
|
||||||
|
continue
|
||||||
|
|
||||||
|
spec = SkillInstallSpec(
|
||||||
|
kind=kind,
|
||||||
|
id=spec_raw.get('id'),
|
||||||
|
label=spec_raw.get('label'),
|
||||||
|
bins=_normalize_string_list(spec_raw.get('bins')),
|
||||||
|
os=_normalize_string_list(spec_raw.get('os')),
|
||||||
|
formula=spec_raw.get('formula'),
|
||||||
|
package=spec_raw.get('package'),
|
||||||
|
module=spec_raw.get('module'),
|
||||||
|
url=spec_raw.get('url'),
|
||||||
|
archive=spec_raw.get('archive'),
|
||||||
|
extract=spec_raw.get('extract', False),
|
||||||
|
strip_components=spec_raw.get('stripComponents'),
|
||||||
|
target_dir=spec_raw.get('targetDir'),
|
||||||
|
)
|
||||||
|
install_specs.append(spec)
|
||||||
|
|
||||||
|
# Parse requires
|
||||||
|
requires = {}
|
||||||
|
requires_raw = meta_obj.get('requires', {})
|
||||||
|
if isinstance(requires_raw, dict):
|
||||||
|
for key, value in requires_raw.items():
|
||||||
|
requires[key] = _normalize_string_list(value)
|
||||||
|
|
||||||
|
return SkillMetadata(
|
||||||
|
always=meta_obj.get('always', False),
|
||||||
|
skill_key=meta_obj.get('skillKey'),
|
||||||
|
primary_env=meta_obj.get('primaryEnv'),
|
||||||
|
emoji=meta_obj.get('emoji'),
|
||||||
|
homepage=meta_obj.get('homepage'),
|
||||||
|
os=_normalize_string_list(meta_obj.get('os')),
|
||||||
|
requires=requires,
|
||||||
|
install=install_specs,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_string_list(value: Any) -> List[str]:
|
||||||
|
"""Normalize a value to a list of strings."""
|
||||||
|
if not value:
|
||||||
|
return []
|
||||||
|
|
||||||
|
if isinstance(value, list):
|
||||||
|
return [str(v).strip() for v in value if v]
|
||||||
|
|
||||||
|
if isinstance(value, str):
|
||||||
|
return [v.strip() for v in value.split(',') if v.strip()]
|
||||||
|
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
def parse_boolean_value(value: Optional[str], default: bool = False) -> bool:
|
||||||
|
"""Parse a boolean value from frontmatter."""
|
||||||
|
if value is None:
|
||||||
|
return default
|
||||||
|
|
||||||
|
if isinstance(value, bool):
|
||||||
|
return value
|
||||||
|
|
||||||
|
if isinstance(value, str):
|
||||||
|
return value.lower() in ('true', '1', 'yes', 'on')
|
||||||
|
|
||||||
|
return default
|
||||||
|
|
||||||
|
|
||||||
|
def get_frontmatter_value(frontmatter: Dict[str, Any], key: str) -> Optional[str]:
|
||||||
|
"""Get a frontmatter value as a string."""
|
||||||
|
value = frontmatter.get(key)
|
||||||
|
return str(value) if value is not None else None
|
||||||
242
agent/skills/loader.py
Normal file
242
agent/skills/loader.py
Normal file
@@ -0,0 +1,242 @@
|
|||||||
|
"""
|
||||||
|
Skill loader for discovering and loading skills from directories.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List, Optional, Dict
|
||||||
|
from common.log import logger
|
||||||
|
from agent.skills.types import Skill, SkillEntry, LoadSkillsResult, SkillMetadata
|
||||||
|
from agent.skills.frontmatter import parse_frontmatter, parse_metadata, parse_boolean_value, get_frontmatter_value
|
||||||
|
|
||||||
|
|
||||||
|
class SkillLoader:
|
||||||
|
"""Loads skills from various directories."""
|
||||||
|
|
||||||
|
def __init__(self, workspace_dir: Optional[str] = None):
|
||||||
|
"""
|
||||||
|
Initialize the skill loader.
|
||||||
|
|
||||||
|
:param workspace_dir: Agent workspace directory (for workspace-specific skills)
|
||||||
|
"""
|
||||||
|
self.workspace_dir = workspace_dir
|
||||||
|
|
||||||
|
def load_skills_from_dir(self, dir_path: str, source: str) -> LoadSkillsResult:
|
||||||
|
"""
|
||||||
|
Load skills from a directory.
|
||||||
|
|
||||||
|
Discovery rules:
|
||||||
|
- Direct .md files in the root directory
|
||||||
|
- Recursive SKILL.md files under subdirectories
|
||||||
|
|
||||||
|
:param dir_path: Directory path to scan
|
||||||
|
:param source: Source identifier (e.g., 'managed', 'workspace', 'bundled')
|
||||||
|
:return: LoadSkillsResult with skills and diagnostics
|
||||||
|
"""
|
||||||
|
skills = []
|
||||||
|
diagnostics = []
|
||||||
|
|
||||||
|
if not os.path.exists(dir_path):
|
||||||
|
diagnostics.append(f"Directory does not exist: {dir_path}")
|
||||||
|
return LoadSkillsResult(skills=skills, diagnostics=diagnostics)
|
||||||
|
|
||||||
|
if not os.path.isdir(dir_path):
|
||||||
|
diagnostics.append(f"Path is not a directory: {dir_path}")
|
||||||
|
return LoadSkillsResult(skills=skills, diagnostics=diagnostics)
|
||||||
|
|
||||||
|
# Load skills from root-level .md files and subdirectories
|
||||||
|
result = self._load_skills_recursive(dir_path, source, include_root_files=True)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def _load_skills_recursive(
|
||||||
|
self,
|
||||||
|
dir_path: str,
|
||||||
|
source: str,
|
||||||
|
include_root_files: bool = False
|
||||||
|
) -> LoadSkillsResult:
|
||||||
|
"""
|
||||||
|
Recursively load skills from a directory.
|
||||||
|
|
||||||
|
:param dir_path: Directory to scan
|
||||||
|
:param source: Source identifier
|
||||||
|
:param include_root_files: Whether to include root-level .md files
|
||||||
|
:return: LoadSkillsResult
|
||||||
|
"""
|
||||||
|
skills = []
|
||||||
|
diagnostics = []
|
||||||
|
|
||||||
|
try:
|
||||||
|
entries = os.listdir(dir_path)
|
||||||
|
except Exception as e:
|
||||||
|
diagnostics.append(f"Failed to list directory {dir_path}: {e}")
|
||||||
|
return LoadSkillsResult(skills=skills, diagnostics=diagnostics)
|
||||||
|
|
||||||
|
for entry in entries:
|
||||||
|
# Skip hidden files and directories
|
||||||
|
if entry.startswith('.'):
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Skip common non-skill directories
|
||||||
|
if entry in ('node_modules', '__pycache__', 'venv', '.git'):
|
||||||
|
continue
|
||||||
|
|
||||||
|
full_path = os.path.join(dir_path, entry)
|
||||||
|
|
||||||
|
# Handle directories
|
||||||
|
if os.path.isdir(full_path):
|
||||||
|
# Recursively scan subdirectories
|
||||||
|
sub_result = self._load_skills_recursive(full_path, source, include_root_files=False)
|
||||||
|
skills.extend(sub_result.skills)
|
||||||
|
diagnostics.extend(sub_result.diagnostics)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Handle files
|
||||||
|
if not os.path.isfile(full_path):
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Check if this is a skill file
|
||||||
|
is_root_md = include_root_files and entry.endswith('.md')
|
||||||
|
is_skill_md = not include_root_files and entry == 'SKILL.md'
|
||||||
|
|
||||||
|
if not (is_root_md or is_skill_md):
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Load the skill
|
||||||
|
skill_result = self._load_skill_from_file(full_path, source)
|
||||||
|
if skill_result.skills:
|
||||||
|
skills.extend(skill_result.skills)
|
||||||
|
diagnostics.extend(skill_result.diagnostics)
|
||||||
|
|
||||||
|
return LoadSkillsResult(skills=skills, diagnostics=diagnostics)
|
||||||
|
|
||||||
|
def _load_skill_from_file(self, file_path: str, source: str) -> LoadSkillsResult:
|
||||||
|
"""
|
||||||
|
Load a single skill from a markdown file.
|
||||||
|
|
||||||
|
:param file_path: Path to the skill markdown file
|
||||||
|
:param source: Source identifier
|
||||||
|
:return: LoadSkillsResult
|
||||||
|
"""
|
||||||
|
diagnostics = []
|
||||||
|
|
||||||
|
try:
|
||||||
|
with open(file_path, 'r', encoding='utf-8') as f:
|
||||||
|
content = f.read()
|
||||||
|
except Exception as e:
|
||||||
|
diagnostics.append(f"Failed to read skill file {file_path}: {e}")
|
||||||
|
return LoadSkillsResult(skills=[], diagnostics=diagnostics)
|
||||||
|
|
||||||
|
# Parse frontmatter
|
||||||
|
frontmatter = parse_frontmatter(content)
|
||||||
|
|
||||||
|
# Get skill name and description
|
||||||
|
skill_dir = os.path.dirname(file_path)
|
||||||
|
parent_dir_name = os.path.basename(skill_dir)
|
||||||
|
|
||||||
|
name = frontmatter.get('name', parent_dir_name)
|
||||||
|
description = frontmatter.get('description', '')
|
||||||
|
|
||||||
|
if not description or not description.strip():
|
||||||
|
diagnostics.append(f"Skill {name} has no description: {file_path}")
|
||||||
|
return LoadSkillsResult(skills=[], diagnostics=diagnostics)
|
||||||
|
|
||||||
|
# Parse disable-model-invocation flag
|
||||||
|
disable_model_invocation = parse_boolean_value(
|
||||||
|
get_frontmatter_value(frontmatter, 'disable-model-invocation'),
|
||||||
|
default=False
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create skill object
|
||||||
|
skill = Skill(
|
||||||
|
name=name,
|
||||||
|
description=description,
|
||||||
|
file_path=file_path,
|
||||||
|
base_dir=skill_dir,
|
||||||
|
source=source,
|
||||||
|
content=content,
|
||||||
|
disable_model_invocation=disable_model_invocation,
|
||||||
|
frontmatter=frontmatter,
|
||||||
|
)
|
||||||
|
|
||||||
|
return LoadSkillsResult(skills=[skill], diagnostics=diagnostics)
|
||||||
|
|
||||||
|
def load_all_skills(
|
||||||
|
self,
|
||||||
|
managed_dir: Optional[str] = None,
|
||||||
|
workspace_skills_dir: Optional[str] = None,
|
||||||
|
extra_dirs: Optional[List[str]] = None,
|
||||||
|
) -> Dict[str, SkillEntry]:
|
||||||
|
"""
|
||||||
|
Load skills from all configured locations with precedence.
|
||||||
|
|
||||||
|
Precedence (lowest to highest):
|
||||||
|
1. Extra directories
|
||||||
|
2. Managed skills directory
|
||||||
|
3. Workspace skills directory
|
||||||
|
|
||||||
|
:param managed_dir: Managed skills directory (e.g., ~/.cow/skills)
|
||||||
|
:param workspace_skills_dir: Workspace skills directory (e.g., workspace/skills)
|
||||||
|
:param extra_dirs: Additional directories to load skills from
|
||||||
|
:return: Dictionary mapping skill name to SkillEntry
|
||||||
|
"""
|
||||||
|
skill_map: Dict[str, SkillEntry] = {}
|
||||||
|
all_diagnostics = []
|
||||||
|
|
||||||
|
# Load from extra directories (lowest precedence)
|
||||||
|
if extra_dirs:
|
||||||
|
for extra_dir in extra_dirs:
|
||||||
|
if not os.path.exists(extra_dir):
|
||||||
|
continue
|
||||||
|
result = self.load_skills_from_dir(extra_dir, source='extra')
|
||||||
|
all_diagnostics.extend(result.diagnostics)
|
||||||
|
for skill in result.skills:
|
||||||
|
entry = self._create_skill_entry(skill)
|
||||||
|
skill_map[skill.name] = entry
|
||||||
|
|
||||||
|
# Load from managed directory
|
||||||
|
if managed_dir and os.path.exists(managed_dir):
|
||||||
|
result = self.load_skills_from_dir(managed_dir, source='managed')
|
||||||
|
all_diagnostics.extend(result.diagnostics)
|
||||||
|
for skill in result.skills:
|
||||||
|
entry = self._create_skill_entry(skill)
|
||||||
|
skill_map[skill.name] = entry
|
||||||
|
|
||||||
|
# Load from workspace directory (highest precedence)
|
||||||
|
if workspace_skills_dir and os.path.exists(workspace_skills_dir):
|
||||||
|
result = self.load_skills_from_dir(workspace_skills_dir, source='workspace')
|
||||||
|
all_diagnostics.extend(result.diagnostics)
|
||||||
|
for skill in result.skills:
|
||||||
|
entry = self._create_skill_entry(skill)
|
||||||
|
skill_map[skill.name] = entry
|
||||||
|
|
||||||
|
# Log diagnostics
|
||||||
|
if all_diagnostics:
|
||||||
|
logger.debug(f"Skill loading diagnostics: {len(all_diagnostics)} issues")
|
||||||
|
for diag in all_diagnostics[:5]: # Log first 5
|
||||||
|
logger.debug(f" - {diag}")
|
||||||
|
|
||||||
|
logger.info(f"Loaded {len(skill_map)} skills from all sources")
|
||||||
|
|
||||||
|
return skill_map
|
||||||
|
|
||||||
|
def _create_skill_entry(self, skill: Skill) -> SkillEntry:
|
||||||
|
"""
|
||||||
|
Create a SkillEntry from a Skill with parsed metadata.
|
||||||
|
|
||||||
|
:param skill: The skill to create an entry for
|
||||||
|
:return: SkillEntry with metadata
|
||||||
|
"""
|
||||||
|
metadata = parse_metadata(skill.frontmatter)
|
||||||
|
|
||||||
|
# Parse user-invocable flag
|
||||||
|
user_invocable = parse_boolean_value(
|
||||||
|
get_frontmatter_value(skill.frontmatter, 'user-invocable'),
|
||||||
|
default=True
|
||||||
|
)
|
||||||
|
|
||||||
|
return SkillEntry(
|
||||||
|
skill=skill,
|
||||||
|
metadata=metadata,
|
||||||
|
user_invocable=user_invocable,
|
||||||
|
)
|
||||||
214
agent/skills/manager.py
Normal file
214
agent/skills/manager.py
Normal file
@@ -0,0 +1,214 @@
|
|||||||
|
"""
|
||||||
|
Skill manager for managing skill lifecycle and operations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from typing import Dict, List, Optional
|
||||||
|
from pathlib import Path
|
||||||
|
from common.log import logger
|
||||||
|
from agent.skills.types import Skill, SkillEntry, SkillSnapshot
|
||||||
|
from agent.skills.loader import SkillLoader
|
||||||
|
from agent.skills.formatter import format_skill_entries_for_prompt
|
||||||
|
|
||||||
|
|
||||||
|
class SkillManager:
|
||||||
|
"""Manages skills for an agent."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
workspace_dir: Optional[str] = None,
|
||||||
|
managed_skills_dir: Optional[str] = None,
|
||||||
|
extra_dirs: Optional[List[str]] = None,
|
||||||
|
config: Optional[Dict] = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize the skill manager.
|
||||||
|
|
||||||
|
:param workspace_dir: Agent workspace directory
|
||||||
|
:param managed_skills_dir: Managed skills directory (e.g., ~/.cow/skills)
|
||||||
|
:param extra_dirs: Additional skill directories
|
||||||
|
:param config: Configuration dictionary
|
||||||
|
"""
|
||||||
|
self.workspace_dir = workspace_dir
|
||||||
|
self.managed_skills_dir = managed_skills_dir or self._get_default_managed_dir()
|
||||||
|
self.extra_dirs = extra_dirs or []
|
||||||
|
self.config = config or {}
|
||||||
|
|
||||||
|
self.loader = SkillLoader(workspace_dir=workspace_dir)
|
||||||
|
self.skills: Dict[str, SkillEntry] = {}
|
||||||
|
|
||||||
|
# Load skills on initialization
|
||||||
|
self.refresh_skills()
|
||||||
|
|
||||||
|
def _get_default_managed_dir(self) -> str:
|
||||||
|
"""Get the default managed skills directory."""
|
||||||
|
# Use project root skills directory as default
|
||||||
|
import os
|
||||||
|
project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
|
return os.path.join(project_root, 'skills')
|
||||||
|
|
||||||
|
def refresh_skills(self):
|
||||||
|
"""Reload all skills from configured directories."""
|
||||||
|
workspace_skills_dir = None
|
||||||
|
if self.workspace_dir:
|
||||||
|
workspace_skills_dir = os.path.join(self.workspace_dir, 'skills')
|
||||||
|
|
||||||
|
self.skills = self.loader.load_all_skills(
|
||||||
|
managed_dir=self.managed_skills_dir,
|
||||||
|
workspace_skills_dir=workspace_skills_dir,
|
||||||
|
extra_dirs=self.extra_dirs,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"SkillManager: Loaded {len(self.skills)} skills")
|
||||||
|
|
||||||
|
def get_skill(self, name: str) -> Optional[SkillEntry]:
|
||||||
|
"""
|
||||||
|
Get a skill by name.
|
||||||
|
|
||||||
|
:param name: Skill name
|
||||||
|
:return: SkillEntry or None if not found
|
||||||
|
"""
|
||||||
|
return self.skills.get(name)
|
||||||
|
|
||||||
|
def list_skills(self) -> List[SkillEntry]:
|
||||||
|
"""
|
||||||
|
Get all loaded skills.
|
||||||
|
|
||||||
|
:return: List of all skill entries
|
||||||
|
"""
|
||||||
|
return list(self.skills.values())
|
||||||
|
|
||||||
|
def filter_skills(
|
||||||
|
self,
|
||||||
|
skill_filter: Optional[List[str]] = None,
|
||||||
|
include_disabled: bool = False,
|
||||||
|
check_requirements: bool = False, # Changed default to False for lenient loading
|
||||||
|
lenient: bool = True, # New parameter for lenient mode
|
||||||
|
) -> List[SkillEntry]:
|
||||||
|
"""
|
||||||
|
Filter skills based on criteria.
|
||||||
|
|
||||||
|
By default (lenient=True), all skills are loaded regardless of missing requirements.
|
||||||
|
Skills will fail gracefully at runtime if requirements are not met.
|
||||||
|
|
||||||
|
:param skill_filter: List of skill names to include (None = all)
|
||||||
|
:param include_disabled: Whether to include skills with disable_model_invocation=True
|
||||||
|
:param check_requirements: Whether to check skill requirements (default: False)
|
||||||
|
:param lenient: If True, ignore missing requirements (default: True)
|
||||||
|
:return: Filtered list of skill entries
|
||||||
|
"""
|
||||||
|
from agent.skills.config import should_include_skill
|
||||||
|
|
||||||
|
entries = list(self.skills.values())
|
||||||
|
|
||||||
|
# Check requirements (platform, explicit disable, etc.)
|
||||||
|
# In lenient mode, only checks platform and explicit disable
|
||||||
|
if check_requirements or not lenient:
|
||||||
|
entries = [e for e in entries if should_include_skill(e, self.config, lenient=lenient)]
|
||||||
|
else:
|
||||||
|
# Lenient mode: only check explicit disable and platform
|
||||||
|
entries = [e for e in entries if should_include_skill(e, self.config, lenient=True)]
|
||||||
|
|
||||||
|
# Apply skill filter
|
||||||
|
if skill_filter is not None:
|
||||||
|
normalized = [name.strip() for name in skill_filter if name.strip()]
|
||||||
|
if normalized:
|
||||||
|
entries = [e for e in entries if e.skill.name in normalized]
|
||||||
|
|
||||||
|
# Filter out disabled skills unless explicitly requested
|
||||||
|
if not include_disabled:
|
||||||
|
entries = [e for e in entries if not e.skill.disable_model_invocation]
|
||||||
|
|
||||||
|
return entries
|
||||||
|
|
||||||
|
def build_skills_prompt(
|
||||||
|
self,
|
||||||
|
skill_filter: Optional[List[str]] = None,
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Build a formatted prompt containing available skills.
|
||||||
|
|
||||||
|
:param skill_filter: Optional list of skill names to include
|
||||||
|
:return: Formatted skills prompt
|
||||||
|
"""
|
||||||
|
entries = self.filter_skills(skill_filter=skill_filter, include_disabled=False)
|
||||||
|
return format_skill_entries_for_prompt(entries)
|
||||||
|
|
||||||
|
def build_skill_snapshot(
|
||||||
|
self,
|
||||||
|
skill_filter: Optional[List[str]] = None,
|
||||||
|
version: Optional[int] = None,
|
||||||
|
) -> SkillSnapshot:
|
||||||
|
"""
|
||||||
|
Build a snapshot of skills for a specific run.
|
||||||
|
|
||||||
|
:param skill_filter: Optional list of skill names to include
|
||||||
|
:param version: Optional version number for the snapshot
|
||||||
|
:return: SkillSnapshot
|
||||||
|
"""
|
||||||
|
entries = self.filter_skills(skill_filter=skill_filter, include_disabled=False)
|
||||||
|
prompt = format_skill_entries_for_prompt(entries)
|
||||||
|
|
||||||
|
skills_info = []
|
||||||
|
resolved_skills = []
|
||||||
|
|
||||||
|
for entry in entries:
|
||||||
|
skills_info.append({
|
||||||
|
'name': entry.skill.name,
|
||||||
|
'primary_env': entry.metadata.primary_env if entry.metadata else None,
|
||||||
|
})
|
||||||
|
resolved_skills.append(entry.skill)
|
||||||
|
|
||||||
|
return SkillSnapshot(
|
||||||
|
prompt=prompt,
|
||||||
|
skills=skills_info,
|
||||||
|
resolved_skills=resolved_skills,
|
||||||
|
version=version,
|
||||||
|
)
|
||||||
|
|
||||||
|
def sync_skills_to_workspace(self, target_workspace_dir: str):
|
||||||
|
"""
|
||||||
|
Sync all loaded skills to a target workspace directory.
|
||||||
|
|
||||||
|
This is useful for sandbox environments where skills need to be copied.
|
||||||
|
|
||||||
|
:param target_workspace_dir: Target workspace directory
|
||||||
|
"""
|
||||||
|
import shutil
|
||||||
|
|
||||||
|
target_skills_dir = os.path.join(target_workspace_dir, 'skills')
|
||||||
|
|
||||||
|
# Remove existing skills directory
|
||||||
|
if os.path.exists(target_skills_dir):
|
||||||
|
shutil.rmtree(target_skills_dir)
|
||||||
|
|
||||||
|
# Create new skills directory
|
||||||
|
os.makedirs(target_skills_dir, exist_ok=True)
|
||||||
|
|
||||||
|
# Copy each skill
|
||||||
|
for entry in self.skills.values():
|
||||||
|
skill_name = entry.skill.name
|
||||||
|
source_dir = entry.skill.base_dir
|
||||||
|
target_dir = os.path.join(target_skills_dir, skill_name)
|
||||||
|
|
||||||
|
try:
|
||||||
|
shutil.copytree(source_dir, target_dir)
|
||||||
|
logger.debug(f"Synced skill '{skill_name}' to {target_dir}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to sync skill '{skill_name}': {e}")
|
||||||
|
|
||||||
|
logger.info(f"Synced {len(self.skills)} skills to {target_skills_dir}")
|
||||||
|
|
||||||
|
def get_skill_by_key(self, skill_key: str) -> Optional[SkillEntry]:
|
||||||
|
"""
|
||||||
|
Get a skill by its skill key (which may differ from name).
|
||||||
|
|
||||||
|
:param skill_key: Skill key to look up
|
||||||
|
:return: SkillEntry or None
|
||||||
|
"""
|
||||||
|
for entry in self.skills.values():
|
||||||
|
if entry.metadata and entry.metadata.skill_key == skill_key:
|
||||||
|
return entry
|
||||||
|
if entry.skill.name == skill_key:
|
||||||
|
return entry
|
||||||
|
return None
|
||||||
74
agent/skills/types.py
Normal file
74
agent/skills/types.py
Normal file
@@ -0,0 +1,74 @@
|
|||||||
|
"""
|
||||||
|
Type definitions for skills system.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Dict, List, Optional, Any
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class SkillInstallSpec:
|
||||||
|
"""Specification for installing skill dependencies."""
|
||||||
|
kind: str # brew, pip, npm, download, etc.
|
||||||
|
id: Optional[str] = None
|
||||||
|
label: Optional[str] = None
|
||||||
|
bins: List[str] = field(default_factory=list)
|
||||||
|
os: List[str] = field(default_factory=list)
|
||||||
|
formula: Optional[str] = None # for brew
|
||||||
|
package: Optional[str] = None # for pip/npm
|
||||||
|
module: Optional[str] = None
|
||||||
|
url: Optional[str] = None # for download
|
||||||
|
archive: Optional[str] = None
|
||||||
|
extract: bool = False
|
||||||
|
strip_components: Optional[int] = None
|
||||||
|
target_dir: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class SkillMetadata:
|
||||||
|
"""Metadata for a skill from frontmatter."""
|
||||||
|
always: bool = False # Always include this skill
|
||||||
|
skill_key: Optional[str] = None # Override skill key
|
||||||
|
primary_env: Optional[str] = None # Primary environment variable
|
||||||
|
emoji: Optional[str] = None
|
||||||
|
homepage: Optional[str] = None
|
||||||
|
os: List[str] = field(default_factory=list) # Supported OS platforms
|
||||||
|
requires: Dict[str, List[str]] = field(default_factory=dict) # Requirements
|
||||||
|
install: List[SkillInstallSpec] = field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Skill:
|
||||||
|
"""Represents a skill loaded from a markdown file."""
|
||||||
|
name: str
|
||||||
|
description: str
|
||||||
|
file_path: str
|
||||||
|
base_dir: str
|
||||||
|
source: str # managed, workspace, bundled, etc.
|
||||||
|
content: str # Full markdown content
|
||||||
|
disable_model_invocation: bool = False
|
||||||
|
frontmatter: Dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class SkillEntry:
|
||||||
|
"""A skill with parsed metadata."""
|
||||||
|
skill: Skill
|
||||||
|
metadata: Optional[SkillMetadata] = None
|
||||||
|
user_invocable: bool = True # Can users invoke this skill directly
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class LoadSkillsResult:
|
||||||
|
"""Result of loading skills from a directory."""
|
||||||
|
skills: List[Skill]
|
||||||
|
diagnostics: List[str] = field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class SkillSnapshot:
|
||||||
|
"""Snapshot of skills for a specific run."""
|
||||||
|
prompt: str # Formatted prompt text
|
||||||
|
skills: List[Dict[str, str]] # List of skill info (name, primary_env)
|
||||||
|
resolved_skills: List[Skill] = field(default_factory=list)
|
||||||
|
version: Optional[int] = None
|
||||||
@@ -19,6 +19,9 @@ from agent.tools.ls.ls import Ls
|
|||||||
from agent.tools.memory.memory_search import MemorySearchTool
|
from agent.tools.memory.memory_search import MemorySearchTool
|
||||||
from agent.tools.memory.memory_get import MemoryGetTool
|
from agent.tools.memory.memory_get import MemoryGetTool
|
||||||
|
|
||||||
|
# Import web tools
|
||||||
|
from agent.tools.web_fetch.web_fetch import WebFetch
|
||||||
|
|
||||||
# Import tools with optional dependencies
|
# Import tools with optional dependencies
|
||||||
def _import_optional_tools():
|
def _import_optional_tools():
|
||||||
"""Import tools that have optional dependencies"""
|
"""Import tools that have optional dependencies"""
|
||||||
@@ -89,6 +92,7 @@ __all__ = [
|
|||||||
'Ls',
|
'Ls',
|
||||||
'MemorySearchTool',
|
'MemorySearchTool',
|
||||||
'MemoryGetTool',
|
'MemoryGetTool',
|
||||||
|
'WebFetch',
|
||||||
# Optional tools (may be None if dependencies not available)
|
# Optional tools (may be None if dependencies not available)
|
||||||
'GoogleSearch',
|
'GoogleSearch',
|
||||||
'FileSave',
|
'FileSave',
|
||||||
|
|||||||
@@ -1,59 +0,0 @@
|
|||||||
class BrowserAction:
|
|
||||||
"""Base class for browser actions"""
|
|
||||||
code = ""
|
|
||||||
description = ""
|
|
||||||
|
|
||||||
|
|
||||||
class Navigate(BrowserAction):
|
|
||||||
"""Navigate to a URL in the current tab"""
|
|
||||||
code = "navigate"
|
|
||||||
description = "Navigate to URL in the current tab"
|
|
||||||
|
|
||||||
|
|
||||||
class ClickElement(BrowserAction):
|
|
||||||
"""Click an element on the page"""
|
|
||||||
code = "click_element"
|
|
||||||
description = "Click element"
|
|
||||||
|
|
||||||
|
|
||||||
class ExtractContent(BrowserAction):
|
|
||||||
"""Extract content from the page"""
|
|
||||||
code = "extract_content"
|
|
||||||
description = "Extract the page content to retrieve specific information for a goal"
|
|
||||||
|
|
||||||
|
|
||||||
class InputText(BrowserAction):
|
|
||||||
"""Input text into an element"""
|
|
||||||
code = "input_text"
|
|
||||||
description = "Input text into a input interactive element"
|
|
||||||
|
|
||||||
|
|
||||||
class ScrollDown(BrowserAction):
|
|
||||||
"""Scroll down the page"""
|
|
||||||
code = "scroll_down"
|
|
||||||
description = "Scroll down the page by pixel amount"
|
|
||||||
|
|
||||||
|
|
||||||
class ScrollUp(BrowserAction):
|
|
||||||
"""Scroll up the page"""
|
|
||||||
code = "scroll_up"
|
|
||||||
description = "Scroll up the page by pixel amount - if no amount is specified, scroll up one page"
|
|
||||||
|
|
||||||
|
|
||||||
class OpenTab(BrowserAction):
|
|
||||||
"""Open a URL in a new tab"""
|
|
||||||
code = "open_tab"
|
|
||||||
description = "Open url in new tab"
|
|
||||||
|
|
||||||
|
|
||||||
class SwitchTab(BrowserAction):
|
|
||||||
"""Switch to a tab"""
|
|
||||||
code = "switch_tab"
|
|
||||||
description = "Switched to tab"
|
|
||||||
|
|
||||||
|
|
||||||
class SendKeys(BrowserAction):
|
|
||||||
"""Switch to a tab"""
|
|
||||||
code = "send_keys"
|
|
||||||
description = "Send strings of special keyboard keys like Escape, Backspace, Insert, PageDown, Delete, Enter, " \
|
|
||||||
"ArrowRight, ArrowUp, etc"
|
|
||||||
@@ -1,317 +0,0 @@
|
|||||||
import asyncio
|
|
||||||
from typing import Any, Dict
|
|
||||||
import json
|
|
||||||
import re
|
|
||||||
import os
|
|
||||||
import platform
|
|
||||||
from browser_use import Browser
|
|
||||||
from browser_use import BrowserConfig
|
|
||||||
from browser_use.browser.context import BrowserContext, BrowserContextConfig
|
|
||||||
from agent.tools.base_tool import BaseTool, ToolResult
|
|
||||||
from agent.tools.browser.browser_action import *
|
|
||||||
from agent.models import LLMRequest
|
|
||||||
from agent.models.model_factory import ModelFactory
|
|
||||||
from browser_use.dom.service import DomService
|
|
||||||
from common.log import logger
|
|
||||||
|
|
||||||
|
|
||||||
# Use lazy import, only import when actually used
|
|
||||||
def _import_browser_use():
|
|
||||||
try:
|
|
||||||
import browser_use
|
|
||||||
return browser_use
|
|
||||||
except ImportError:
|
|
||||||
raise ImportError(
|
|
||||||
"The 'browser-use' package is required to use BrowserTool. "
|
|
||||||
"Please install it with 'pip install browser-use>=0.1.40' or "
|
|
||||||
"'pip install agentmesh-sdk[full]'."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _get_action_prompt():
|
|
||||||
action_classes = [Navigate, ClickElement, ExtractContent, InputText, OpenTab, SwitchTab, ScrollDown, ScrollUp,
|
|
||||||
SendKeys]
|
|
||||||
action_prompt = ""
|
|
||||||
for action_class in action_classes:
|
|
||||||
action_prompt += f"{action_class.code}: {action_class.description}\n"
|
|
||||||
return action_prompt.strip()
|
|
||||||
|
|
||||||
|
|
||||||
def _header_less() -> bool:
|
|
||||||
if platform.system() == "Linux" and not os.environ.get("DISPLAY") and not os.environ.get("WAYLAND_DISPLAY"):
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
class BrowserTool(BaseTool):
|
|
||||||
name: str = "browser"
|
|
||||||
description: str = "A tool to perform browser operations like navigating to URLs, element interaction, " \
|
|
||||||
"and extracting content."
|
|
||||||
params: dict = {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"operation": {
|
|
||||||
"type": "string",
|
|
||||||
"description": f"The browser operation to perform: \n{_get_action_prompt()}"
|
|
||||||
},
|
|
||||||
"url": {
|
|
||||||
"type": "string",
|
|
||||||
"description": f"The URL to navigate to (required for '{Navigate.code}', '{OpenTab.code}' actions). "
|
|
||||||
},
|
|
||||||
"goal": {
|
|
||||||
"type": "string",
|
|
||||||
"description": f"The goal of extracting page content (required for '{ExtractContent.code}' action)."
|
|
||||||
},
|
|
||||||
"text": {
|
|
||||||
"type": "string",
|
|
||||||
"description": f"Text to type (required for '{InputText.code}' action)."
|
|
||||||
},
|
|
||||||
"index": {
|
|
||||||
"type": "integer",
|
|
||||||
"description": f"Element index (required for '{ClickElement.code}', '{InputText.code}' actions)",
|
|
||||||
},
|
|
||||||
"tab_id": {
|
|
||||||
"type": "integer",
|
|
||||||
"description": f"Page tab ID (required for '{SwitchTab.code}' action)",
|
|
||||||
},
|
|
||||||
"scroll_amount": {
|
|
||||||
"type": "integer",
|
|
||||||
"description": f"The number of pixels to scroll (required for '{ScrollDown.code}', '{ScrollUp.code}' action)."
|
|
||||||
},
|
|
||||||
"keys": {
|
|
||||||
"type": "string",
|
|
||||||
"description": f"Keys to send (required for '{SendKeys.code}' action)"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"required": ["operation"]
|
|
||||||
}
|
|
||||||
|
|
||||||
# Class variable to ensure only one browser instance is created
|
|
||||||
browser = None
|
|
||||||
browser_context: BrowserContext = None
|
|
||||||
dom_service: DomService = None
|
|
||||||
_initialized = False
|
|
||||||
|
|
||||||
# Adding an event loop variable
|
|
||||||
_event_loop = None
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
# Only import during initialization, not at module level
|
|
||||||
self.browser_use = _import_browser_use()
|
|
||||||
# Do not initialize the browser in the constructor, but initialize it on the first execution
|
|
||||||
pass
|
|
||||||
|
|
||||||
async def _init_browser(self) -> BrowserContext:
|
|
||||||
"""Ensure the browser is initialized"""
|
|
||||||
if not BrowserTool._initialized:
|
|
||||||
os.environ['BROWSER_USE_LOGGING_LEVEL'] = 'error'
|
|
||||||
print("Initializing browser...")
|
|
||||||
# Initialize the browser synchronously
|
|
||||||
BrowserTool.browser = Browser(BrowserConfig(headless=_header_less(),
|
|
||||||
disable_security=True))
|
|
||||||
context_config = BrowserContextConfig()
|
|
||||||
context_config.highlight_elements = True
|
|
||||||
BrowserTool.browser_context = await BrowserTool.browser.new_context(context_config)
|
|
||||||
BrowserTool._initialized = True
|
|
||||||
print("Browser initialized successfully")
|
|
||||||
BrowserTool.dom_service = DomService(await BrowserTool.browser_context.get_current_page())
|
|
||||||
return BrowserTool.browser_context
|
|
||||||
|
|
||||||
def execute(self, params: Dict[str, Any]) -> ToolResult:
|
|
||||||
"""
|
|
||||||
Execute browser operations based on the provided arguments.
|
|
||||||
|
|
||||||
:param params: Dictionary containing the action and related parameters
|
|
||||||
:return: Result of the browser operation
|
|
||||||
"""
|
|
||||||
# Ensure browser_use is imported
|
|
||||||
if not hasattr(self, 'browser_use'):
|
|
||||||
self.browser_use = _import_browser_use()
|
|
||||||
action = params.get("operation", "").lower()
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Use a single event loop
|
|
||||||
if BrowserTool._event_loop is None:
|
|
||||||
BrowserTool._event_loop = asyncio.new_event_loop()
|
|
||||||
asyncio.set_event_loop(BrowserTool._event_loop)
|
|
||||||
# Run tasks in the existing event loop
|
|
||||||
return BrowserTool._event_loop.run_until_complete(self._execute_async(action, params))
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error executing browser action: {e}")
|
|
||||||
return ToolResult.fail(result=f"Error executing browser action: {str(e)}")
|
|
||||||
|
|
||||||
async def _get_page_state(self, context: BrowserContext):
|
|
||||||
state = await self._get_state(context)
|
|
||||||
include_attributes = ["img", "div", "button", "input"]
|
|
||||||
elements = state.element_tree.clickable_elements_to_string(include_attributes)
|
|
||||||
pattern = r'\[\d+\]<[^>]+\/>'
|
|
||||||
# Find all matching elements
|
|
||||||
interactive_elements = re.findall(pattern, elements)
|
|
||||||
page_state = {
|
|
||||||
"url": state.url,
|
|
||||||
"title": state.title,
|
|
||||||
"pixels_above": getattr(state, "pixels_above", 0),
|
|
||||||
"pixels_below": getattr(state, "pixels_below", 0),
|
|
||||||
"tabs": [tab.model_dump() for tab in state.tabs],
|
|
||||||
"interactive_elements": interactive_elements,
|
|
||||||
}
|
|
||||||
return page_state
|
|
||||||
|
|
||||||
async def _get_state(self, context: BrowserContext, cache_clickable_elements_hashes=True):
|
|
||||||
try:
|
|
||||||
return await context.get_state()
|
|
||||||
except TypeError:
|
|
||||||
return await context.get_state(cache_clickable_elements_hashes=cache_clickable_elements_hashes)
|
|
||||||
|
|
||||||
async def _get_page_info(self, context: BrowserContext):
|
|
||||||
page_state = await self._get_page_state(context)
|
|
||||||
state_str = f"""## Current browser state
|
|
||||||
The following is the information of the current browser page. Each serial number in interactive_elements represents the element index:
|
|
||||||
{json.dumps(page_state, indent=4, ensure_ascii=False)}
|
|
||||||
"""
|
|
||||||
return state_str
|
|
||||||
|
|
||||||
async def _execute_async(self, action: str, params: Dict[str, Any]) -> ToolResult:
|
|
||||||
"""Asynchronously execute browser operations"""
|
|
||||||
# Use the browser context from the class variable
|
|
||||||
context = await self._init_browser()
|
|
||||||
|
|
||||||
if action == Navigate.code:
|
|
||||||
url = params.get("url")
|
|
||||||
if not url:
|
|
||||||
return ToolResult.fail(result="URL is required for navigate action")
|
|
||||||
if url.startswith("/"):
|
|
||||||
url = f"file://{url}"
|
|
||||||
print(f"Navigating to {url}...")
|
|
||||||
page = await context.get_current_page()
|
|
||||||
await page.goto(url)
|
|
||||||
await page.wait_for_load_state()
|
|
||||||
state = await self._get_page_info(context)
|
|
||||||
# print(state)
|
|
||||||
print(f"Navigation complete")
|
|
||||||
return ToolResult.success(result=f"Navigated to {url}", ext_data=state)
|
|
||||||
|
|
||||||
elif action == OpenTab.code:
|
|
||||||
url = params.get("url")
|
|
||||||
if url.startswith("/"):
|
|
||||||
url = f"file://{url}"
|
|
||||||
await context.create_new_tab(url)
|
|
||||||
msg = f"Opened new tab with {url}"
|
|
||||||
return ToolResult.success(result=msg)
|
|
||||||
|
|
||||||
elif action == ExtractContent.code:
|
|
||||||
try:
|
|
||||||
goal = params.get("goal")
|
|
||||||
page = await context.get_current_page()
|
|
||||||
if params.get("url"):
|
|
||||||
await page.goto(params.get("url"))
|
|
||||||
await page.wait_for_load_state()
|
|
||||||
import markdownify
|
|
||||||
content = markdownify.markdownify(await page.content())
|
|
||||||
elements = await self._get_page_state(context)
|
|
||||||
prompt = f"Your task is to extract the content of the page. You will be given a page and a goal and you should extract all relevant information around this goal from the page. If the goal is vague, " \
|
|
||||||
f"summarize the page. Respond in json format. elements: {elements.get('interactive_elements')}, extraction goal: {goal}, Page: {content},"
|
|
||||||
request = LLMRequest(
|
|
||||||
messages=[{"role": "user", "content": prompt}],
|
|
||||||
temperature=0,
|
|
||||||
json_format=True
|
|
||||||
)
|
|
||||||
model = self.model or ModelFactory().get_model(model_name="gpt-4o")
|
|
||||||
response = model.call(request)
|
|
||||||
if response.success:
|
|
||||||
extract_content = response.data["choices"][0]["message"]["content"]
|
|
||||||
print(f"Extract from page: {extract_content}")
|
|
||||||
return ToolResult.success(result=f"Extract from page: {extract_content}",
|
|
||||||
ext_data=await self._get_page_info(context))
|
|
||||||
else:
|
|
||||||
return ToolResult.fail(result=f"Extract from page failed: {response.get_error_msg()}")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(e)
|
|
||||||
|
|
||||||
elif action == ClickElement.code:
|
|
||||||
index = params.get("index")
|
|
||||||
element = await context.get_dom_element_by_index(index)
|
|
||||||
await context._click_element_node(element)
|
|
||||||
msg = f"Clicked element at index {index}"
|
|
||||||
print(msg)
|
|
||||||
return ToolResult.success(result=msg, ext_data=await self._get_page_info(context))
|
|
||||||
|
|
||||||
elif action == InputText.code:
|
|
||||||
index = params.get("index")
|
|
||||||
text = params.get("text")
|
|
||||||
element = await context.get_dom_element_by_index(index)
|
|
||||||
await context._input_text_element_node(element, text)
|
|
||||||
await asyncio.sleep(1)
|
|
||||||
msg = f"Input text into element successfully, index: {index}, text: {text}"
|
|
||||||
return ToolResult.success(result=msg, ext_data=await self._get_page_info(context))
|
|
||||||
|
|
||||||
elif action == SwitchTab.code:
|
|
||||||
tab_id = params.get("tab_id")
|
|
||||||
print(f"Switch tab, tab_id={tab_id}")
|
|
||||||
await context.switch_to_tab(tab_id)
|
|
||||||
page = await context.get_current_page()
|
|
||||||
await page.wait_for_load_state()
|
|
||||||
msg = f"Switched to tab {tab_id}"
|
|
||||||
return ToolResult.success(result=msg, ext_data=await self._get_page_info(context))
|
|
||||||
|
|
||||||
elif action in [ScrollDown.code, ScrollUp.code]:
|
|
||||||
scroll_amount = params.get("scroll_amount")
|
|
||||||
if not scroll_amount:
|
|
||||||
scroll_amount = context.config.browser_window_size["height"]
|
|
||||||
print(f"Scrolling by {scroll_amount} pixels")
|
|
||||||
scroll_amount = scroll_amount if action == ScrollDown.code else (scroll_amount * -1)
|
|
||||||
await context.execute_javascript(f"window.scrollBy(0, {scroll_amount});")
|
|
||||||
msg = f"{action} by {scroll_amount} pixels"
|
|
||||||
return ToolResult.success(result=msg, ext_data=await self._get_page_info(context))
|
|
||||||
|
|
||||||
elif action == SendKeys.code:
|
|
||||||
keys = params.get("keys")
|
|
||||||
page = await context.get_current_page()
|
|
||||||
await page.keyboard.press(keys)
|
|
||||||
msg = f"Sent keys: {keys}"
|
|
||||||
print(msg)
|
|
||||||
return ToolResult(output=f"Sent keys: {keys}")
|
|
||||||
|
|
||||||
else:
|
|
||||||
msg = "Failed to operate the browser"
|
|
||||||
return ToolResult.fail(result=msg)
|
|
||||||
|
|
||||||
def close(self):
|
|
||||||
"""
|
|
||||||
Close browser resources.
|
|
||||||
This method handles the asynchronous closing of browser and browser context.
|
|
||||||
"""
|
|
||||||
if not BrowserTool._initialized:
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Use the existing event loop to close browser resources
|
|
||||||
if BrowserTool._event_loop is not None:
|
|
||||||
# Define the async close function
|
|
||||||
async def close_browser_async():
|
|
||||||
if BrowserTool.browser_context is not None:
|
|
||||||
try:
|
|
||||||
await BrowserTool.browser_context.close()
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error closing browser context: {e}")
|
|
||||||
|
|
||||||
if BrowserTool.browser is not None:
|
|
||||||
try:
|
|
||||||
await BrowserTool.browser.close()
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error closing browser: {e}")
|
|
||||||
|
|
||||||
# Reset the initialized flag
|
|
||||||
BrowserTool._initialized = False
|
|
||||||
BrowserTool.browser = None
|
|
||||||
BrowserTool.browser_context = None
|
|
||||||
BrowserTool.dom_service = None
|
|
||||||
|
|
||||||
# Run the async close function in the existing event loop
|
|
||||||
BrowserTool._event_loop.run_until_complete(close_browser_async())
|
|
||||||
|
|
||||||
# Close the event loop
|
|
||||||
BrowserTool._event_loop.close()
|
|
||||||
BrowserTool._event_loop = None
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error during browser cleanup: {e}")
|
|
||||||
@@ -1,48 +0,0 @@
|
|||||||
import requests
|
|
||||||
|
|
||||||
from agent.tools.base_tool import BaseTool, ToolResult
|
|
||||||
|
|
||||||
|
|
||||||
class GoogleSearch(BaseTool):
|
|
||||||
name: str = "google_search"
|
|
||||||
description: str = "A tool to perform Google searches using the Serper API."
|
|
||||||
params: dict = {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"query": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "The search query to perform."
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"required": ["query"]
|
|
||||||
}
|
|
||||||
config: dict = {}
|
|
||||||
|
|
||||||
def __init__(self, config=None):
|
|
||||||
self.config = config or {}
|
|
||||||
|
|
||||||
def execute(self, args: dict) -> ToolResult:
|
|
||||||
api_key = self.config.get("api_key") # Replace with your actual API key
|
|
||||||
url = "https://google.serper.dev/search"
|
|
||||||
headers = {
|
|
||||||
"X-API-KEY": api_key,
|
|
||||||
"Content-Type": "application/json"
|
|
||||||
}
|
|
||||||
data = {
|
|
||||||
"q": args.get("query"),
|
|
||||||
"k": 10
|
|
||||||
}
|
|
||||||
|
|
||||||
response = requests.post(url, headers=headers, json=data)
|
|
||||||
result = response.json()
|
|
||||||
|
|
||||||
if result.get("statusCode") and result.get("statusCode") == 503:
|
|
||||||
return ToolResult.fail(result=result)
|
|
||||||
else:
|
|
||||||
# Check if the returned result contains the 'organic' key and ensure it is a list
|
|
||||||
if 'organic' in result and isinstance(result.get('organic'), list):
|
|
||||||
result_data = result['organic']
|
|
||||||
else:
|
|
||||||
# If there are no organic results, return the full response or an empty list
|
|
||||||
result_data = result.get('organic', []) if isinstance(result.get('organic'), list) else []
|
|
||||||
return ToolResult.success(result=result_data)
|
|
||||||
@@ -4,6 +4,7 @@ from pathlib import Path
|
|||||||
from typing import Dict, Any, Type
|
from typing import Dict, Any, Type
|
||||||
from agent.tools.base_tool import BaseTool
|
from agent.tools.base_tool import BaseTool
|
||||||
from common.log import logger
|
from common.log import logger
|
||||||
|
from config import conf
|
||||||
|
|
||||||
|
|
||||||
class ToolManager:
|
class ToolManager:
|
||||||
@@ -69,6 +70,11 @@ class ToolManager:
|
|||||||
and cls != BaseTool
|
and cls != BaseTool
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
|
# Skip memory tools (they need special initialization with memory_manager)
|
||||||
|
if class_name in ["MemorySearchTool", "MemoryGetTool"]:
|
||||||
|
logger.debug(f"Skipped tool {class_name} (requires memory_manager)")
|
||||||
|
continue
|
||||||
|
|
||||||
# Create a temporary instance to get the name
|
# Create a temporary instance to get the name
|
||||||
temp_instance = cls()
|
temp_instance = cls()
|
||||||
tool_name = temp_instance.name
|
tool_name = temp_instance.name
|
||||||
@@ -76,11 +82,22 @@ class ToolManager:
|
|||||||
self.tool_classes[tool_name] = cls
|
self.tool_classes[tool_name] = cls
|
||||||
logger.debug(f"Loaded tool: {tool_name} from class {class_name}")
|
logger.debug(f"Loaded tool: {tool_name} from class {class_name}")
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
# Ignore browser_use dependency missing errors
|
# Handle missing dependencies with helpful messages
|
||||||
if "browser_use" in str(e):
|
error_msg = str(e)
|
||||||
pass
|
if "browser-use" in error_msg or "browser_use" in error_msg:
|
||||||
|
logger.warning(
|
||||||
|
f"[ToolManager] Browser tool not loaded - missing dependencies.\n"
|
||||||
|
f" To enable browser tool, run:\n"
|
||||||
|
f" pip install browser-use markdownify playwright\n"
|
||||||
|
f" playwright install chromium"
|
||||||
|
)
|
||||||
|
elif "markdownify" in error_msg:
|
||||||
|
logger.warning(
|
||||||
|
f"[ToolManager] {cls.__name__} not loaded - missing markdownify.\n"
|
||||||
|
f" Install with: pip install markdownify"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
logger.error(f"Error initializing tool class {cls.__name__}: {e}")
|
logger.warning(f"[ToolManager] {cls.__name__} not loaded due to missing dependency: {error_msg}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error initializing tool class {cls.__name__}: {e}")
|
logger.error(f"Error initializing tool class {cls.__name__}: {e}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -124,19 +141,35 @@ class ToolManager:
|
|||||||
and cls != BaseTool
|
and cls != BaseTool
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
|
# Skip memory tools (they need special initialization with memory_manager)
|
||||||
|
if attr_name in ["MemorySearchTool", "MemoryGetTool"]:
|
||||||
|
logger.debug(f"Skipped tool {attr_name} (requires memory_manager)")
|
||||||
|
continue
|
||||||
|
|
||||||
# Create a temporary instance to get the name
|
# Create a temporary instance to get the name
|
||||||
temp_instance = cls()
|
temp_instance = cls()
|
||||||
tool_name = temp_instance.name
|
tool_name = temp_instance.name
|
||||||
# Store the class, not the instance
|
# Store the class, not the instance
|
||||||
self.tool_classes[tool_name] = cls
|
self.tool_classes[tool_name] = cls
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
# Ignore browser_use dependency missing errors
|
# Handle missing dependencies with helpful messages
|
||||||
if "browser_use" in str(e):
|
error_msg = str(e)
|
||||||
pass
|
if "browser-use" in error_msg or "browser_use" in error_msg:
|
||||||
|
logger.warning(
|
||||||
|
f"[ToolManager] Browser tool not loaded - missing dependencies.\n"
|
||||||
|
f" To enable browser tool, run:\n"
|
||||||
|
f" pip install browser-use markdownify playwright\n"
|
||||||
|
f" playwright install chromium"
|
||||||
|
)
|
||||||
|
elif "markdownify" in error_msg:
|
||||||
|
logger.warning(
|
||||||
|
f"[ToolManager] {cls.__name__} not loaded - missing markdownify.\n"
|
||||||
|
f" Install with: pip install markdownify"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
print(f"Error initializing tool class {cls.__name__}: {e}")
|
logger.warning(f"[ToolManager] {cls.__name__} not loaded due to missing dependency: {error_msg}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error initializing tool class {cls.__name__}: {e}")
|
logger.error(f"Error initializing tool class {cls.__name__}: {e}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error importing module {py_file}: {e}")
|
print(f"Error importing module {py_file}: {e}")
|
||||||
|
|
||||||
@@ -144,7 +177,7 @@ class ToolManager:
|
|||||||
"""Configure tool classes based on configuration file"""
|
"""Configure tool classes based on configuration file"""
|
||||||
try:
|
try:
|
||||||
# Get tools configuration
|
# Get tools configuration
|
||||||
tools_config = config_dict or config().get("tools", {})
|
tools_config = config_dict or conf().get("tools", {})
|
||||||
|
|
||||||
# Record tools that are configured but not loaded
|
# Record tools that are configured but not loaded
|
||||||
missing_tools = []
|
missing_tools = []
|
||||||
@@ -161,13 +194,20 @@ class ToolManager:
|
|||||||
if missing_tools:
|
if missing_tools:
|
||||||
for tool_name in missing_tools:
|
for tool_name in missing_tools:
|
||||||
if tool_name == "browser":
|
if tool_name == "browser":
|
||||||
logger.error(
|
logger.warning(
|
||||||
"Browser tool is configured but could not be loaded. "
|
f"[ToolManager] Browser tool is configured but not loaded.\n"
|
||||||
"Please install the required dependency with: "
|
f" To enable browser tool, run:\n"
|
||||||
"pip install browser-use>=0.1.40 or pip install agentmesh-sdk[full]"
|
f" pip install browser-use markdownify playwright\n"
|
||||||
|
f" playwright install chromium"
|
||||||
|
)
|
||||||
|
elif tool_name == "google_search":
|
||||||
|
logger.warning(
|
||||||
|
f"[ToolManager] Google Search tool is configured but may need API key.\n"
|
||||||
|
f" Get API key from: https://serper.dev\n"
|
||||||
|
f" Configure in config.json: tools.google_search.api_key"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.warning(f"Tool '{tool_name}' is configured but could not be loaded.")
|
logger.warning(f"[ToolManager] Tool '{tool_name}' is configured but could not be loaded.")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error configuring tools from config: {e}")
|
logger.error(f"Error configuring tools from config: {e}")
|
||||||
|
|||||||
255
agent/tools/web_fetch/IMPLEMENTATION_SUMMARY.md
Normal file
255
agent/tools/web_fetch/IMPLEMENTATION_SUMMARY.md
Normal file
@@ -0,0 +1,255 @@
|
|||||||
|
# WebFetch 工具实现总结
|
||||||
|
|
||||||
|
## 实现完成 ✅
|
||||||
|
|
||||||
|
基于 clawdbot 的 `web_fetch` 工具,我们成功实现了一个免费的网页抓取工具。
|
||||||
|
|
||||||
|
## 核心特性
|
||||||
|
|
||||||
|
### 1. 完全免费 💰
|
||||||
|
- ❌ 不需要任何 API Key
|
||||||
|
- ❌ 不需要付费服务
|
||||||
|
- ✅ 只需要基础的 HTTP 请求
|
||||||
|
|
||||||
|
### 2. 智能内容提取 🎯
|
||||||
|
- **优先级 1**: Mozilla Readability(最佳效果)
|
||||||
|
- **优先级 2**: 基础 HTML 清理(降级方案)
|
||||||
|
- **优先级 3**: 原始内容(非 HTML)
|
||||||
|
|
||||||
|
### 3. 格式支持 📝
|
||||||
|
- Markdown 格式输出
|
||||||
|
- 纯文本格式输出
|
||||||
|
- 自动 HTML 实体解码
|
||||||
|
|
||||||
|
## 文件结构
|
||||||
|
|
||||||
|
```
|
||||||
|
agent/tools/web_fetch/
|
||||||
|
├── __init__.py # 模块导出
|
||||||
|
├── web_fetch.py # 主要实现(367 行)
|
||||||
|
├── test_web_fetch.py # 测试脚本
|
||||||
|
├── README.md # 使用文档
|
||||||
|
└── IMPLEMENTATION_SUMMARY.md # 本文件
|
||||||
|
```
|
||||||
|
|
||||||
|
## 技术实现
|
||||||
|
|
||||||
|
### 依赖层级
|
||||||
|
|
||||||
|
```
|
||||||
|
必需依赖:
|
||||||
|
└── requests (HTTP 请求)
|
||||||
|
|
||||||
|
推荐依赖:
|
||||||
|
├── readability-lxml (智能提取)
|
||||||
|
└── html2text (Markdown 转换)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 核心流程
|
||||||
|
|
||||||
|
```python
|
||||||
|
1. 验证 URL
|
||||||
|
├── 检查协议 (http/https)
|
||||||
|
└── 验证格式
|
||||||
|
|
||||||
|
2. 发送 HTTP 请求
|
||||||
|
├── 设置 User-Agent
|
||||||
|
├── 处理重定向 (最多 3 次)
|
||||||
|
├── 请求重试 (失败 3 次)
|
||||||
|
└── 超时控制 (默认 30 秒)
|
||||||
|
|
||||||
|
3. 内容提取
|
||||||
|
├── HTML → Readability 提取
|
||||||
|
├── HTML → 基础清理 (降级)
|
||||||
|
└── 非 HTML → 原始返回
|
||||||
|
|
||||||
|
4. 格式转换
|
||||||
|
├── Markdown (html2text)
|
||||||
|
└── Text (正则清理)
|
||||||
|
|
||||||
|
5. 结果返回
|
||||||
|
├── 标题
|
||||||
|
├── 内容
|
||||||
|
├── 元数据
|
||||||
|
└── 截断信息
|
||||||
|
```
|
||||||
|
|
||||||
|
## 与 clawdbot 的对比
|
||||||
|
|
||||||
|
| 特性 | clawdbot (TypeScript) | 我们的实现 (Python) |
|
||||||
|
|------|----------------------|-------------------|
|
||||||
|
| 基础抓取 | ✅ | ✅ |
|
||||||
|
| Readability 提取 | ✅ | ✅ |
|
||||||
|
| Markdown 转换 | ✅ | ✅ |
|
||||||
|
| 缓存机制 | ✅ | ❌ (未实现) |
|
||||||
|
| Firecrawl 集成 | ✅ | ❌ (未实现) |
|
||||||
|
| SSRF 防护 | ✅ | ❌ (未实现) |
|
||||||
|
| 代理支持 | ✅ | ❌ (未实现) |
|
||||||
|
|
||||||
|
## 已修复的问题
|
||||||
|
|
||||||
|
### Bug #1: max_redirects 参数错误 ✅
|
||||||
|
|
||||||
|
**问题**:
|
||||||
|
```python
|
||||||
|
response = self.session.get(
|
||||||
|
url,
|
||||||
|
max_redirects=self.max_redirects # ❌ requests 不支持此参数
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
**解决方案**:
|
||||||
|
```python
|
||||||
|
# 在 session 级别设置
|
||||||
|
session.max_redirects = self.max_redirects
|
||||||
|
|
||||||
|
# 请求时只使用 allow_redirects
|
||||||
|
response = self.session.get(
|
||||||
|
url,
|
||||||
|
allow_redirects=True # ✅ 正确的参数
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
## 使用示例
|
||||||
|
|
||||||
|
### 基础使用
|
||||||
|
|
||||||
|
```python
|
||||||
|
from agent.tools.web_fetch import WebFetch
|
||||||
|
|
||||||
|
tool = WebFetch()
|
||||||
|
result = tool.execute({
|
||||||
|
"url": "https://example.com",
|
||||||
|
"extract_mode": "markdown",
|
||||||
|
"max_chars": 5000
|
||||||
|
})
|
||||||
|
|
||||||
|
print(result.result['text'])
|
||||||
|
```
|
||||||
|
|
||||||
|
### 在 Agent 中使用
|
||||||
|
|
||||||
|
```python
|
||||||
|
from agent.tools import WebFetch
|
||||||
|
|
||||||
|
agent = agent_bridge.create_agent(
|
||||||
|
name="MyAgent",
|
||||||
|
tools=[
|
||||||
|
WebFetch(),
|
||||||
|
# ... 其他工具
|
||||||
|
]
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 在 Skills 中引导
|
||||||
|
|
||||||
|
```markdown
|
||||||
|
---
|
||||||
|
name: web-content-reader
|
||||||
|
---
|
||||||
|
|
||||||
|
# 网页内容阅读器
|
||||||
|
|
||||||
|
当用户提供一个网址时,使用 web_fetch 工具读取内容。
|
||||||
|
|
||||||
|
<example>
|
||||||
|
用户: 帮我看看这个网页 https://example.com
|
||||||
|
助手: <tool_use name="web_fetch">
|
||||||
|
<url>https://example.com</url>
|
||||||
|
<extract_mode>text</extract_mode>
|
||||||
|
</tool_use>
|
||||||
|
</example>
|
||||||
|
```
|
||||||
|
|
||||||
|
## 性能指标
|
||||||
|
|
||||||
|
### 速度
|
||||||
|
- 简单页面: ~1-2 秒
|
||||||
|
- 复杂页面: ~3-5 秒
|
||||||
|
- 超时设置: 30 秒
|
||||||
|
|
||||||
|
### 内存
|
||||||
|
- 基础运行: ~10-20 MB
|
||||||
|
- 处理大页面: ~50-100 MB
|
||||||
|
|
||||||
|
### 成功率
|
||||||
|
- 纯文本页面: >95%
|
||||||
|
- HTML 页面: >90%
|
||||||
|
- 需要 JS 渲染: <20% (建议使用 browser 工具)
|
||||||
|
|
||||||
|
## 测试清单
|
||||||
|
|
||||||
|
- [x] 抓取简单 HTML 页面
|
||||||
|
- [x] 抓取复杂网页 (Python.org)
|
||||||
|
- [x] 处理 HTTP 重定向
|
||||||
|
- [x] 处理无效 URL
|
||||||
|
- [x] 处理请求超时
|
||||||
|
- [x] Markdown 格式输出
|
||||||
|
- [x] Text 格式输出
|
||||||
|
- [x] 内容截断
|
||||||
|
- [x] 错误处理
|
||||||
|
|
||||||
|
## 安装说明
|
||||||
|
|
||||||
|
### 最小安装
|
||||||
|
```bash
|
||||||
|
pip install requests
|
||||||
|
```
|
||||||
|
|
||||||
|
### 完整安装
|
||||||
|
```bash
|
||||||
|
pip install requests readability-lxml html2text
|
||||||
|
```
|
||||||
|
|
||||||
|
### 验证安装
|
||||||
|
```bash
|
||||||
|
python3 agent/tools/web_fetch/test_web_fetch.py
|
||||||
|
```
|
||||||
|
|
||||||
|
## 未来改进方向
|
||||||
|
|
||||||
|
### 优先级 1 (推荐)
|
||||||
|
- [ ] 添加缓存机制 (减少重复请求)
|
||||||
|
- [ ] 支持自定义 headers
|
||||||
|
- [ ] 添加 cookie 支持
|
||||||
|
|
||||||
|
### 优先级 2 (可选)
|
||||||
|
- [ ] SSRF 防护 (安全性)
|
||||||
|
- [ ] 代理支持
|
||||||
|
- [ ] Firecrawl 集成 (付费服务)
|
||||||
|
|
||||||
|
### 优先级 3 (高级)
|
||||||
|
- [ ] 自动字符编码检测
|
||||||
|
- [ ] PDF 内容提取
|
||||||
|
- [ ] 图片 OCR 支持
|
||||||
|
|
||||||
|
## 常见问题
|
||||||
|
|
||||||
|
### Q: 为什么有些页面抓取不到内容?
|
||||||
|
|
||||||
|
A: 可能原因:
|
||||||
|
1. 页面需要 JavaScript 渲染 → 使用 `browser` 工具
|
||||||
|
2. 页面有反爬虫机制 → 调整 User-Agent 或使用代理
|
||||||
|
3. 页面需要登录 → 使用 `browser` 工具进行交互
|
||||||
|
|
||||||
|
### Q: 如何提高提取质量?
|
||||||
|
|
||||||
|
A:
|
||||||
|
1. 安装 `readability-lxml`: `pip install readability-lxml`
|
||||||
|
2. 安装 `html2text`: `pip install html2text`
|
||||||
|
3. 使用 `markdown` 模式而不是 `text` 模式
|
||||||
|
|
||||||
|
### Q: 可以抓取 API 返回的 JSON 吗?
|
||||||
|
|
||||||
|
A: 可以!工具会自动检测 content-type,对于 JSON 会格式化输出。
|
||||||
|
|
||||||
|
## 贡献
|
||||||
|
|
||||||
|
本实现参考了以下优秀项目:
|
||||||
|
- [Clawdbot](https://github.com/moltbot/moltbot) - Web tools 设计
|
||||||
|
- [Mozilla Readability](https://github.com/mozilla/readability) - 内容提取算法
|
||||||
|
- [html2text](https://github.com/Alir3z4/html2text) - HTML 转 Markdown
|
||||||
|
|
||||||
|
## 许可
|
||||||
|
|
||||||
|
遵循项目主许可证。
|
||||||
212
agent/tools/web_fetch/README.md
Normal file
212
agent/tools/web_fetch/README.md
Normal file
@@ -0,0 +1,212 @@
|
|||||||
|
# WebFetch Tool
|
||||||
|
|
||||||
|
免费的网页抓取工具,无需 API Key,可直接抓取网页内容并提取可读文本。
|
||||||
|
|
||||||
|
## 功能特性
|
||||||
|
|
||||||
|
- ✅ **完全免费** - 无需任何 API Key
|
||||||
|
- 🌐 **智能提取** - 自动提取网页主要内容
|
||||||
|
- 📝 **格式转换** - 支持 HTML → Markdown/Text
|
||||||
|
- 🚀 **高性能** - 内置请求重试和超时控制
|
||||||
|
- 🎯 **智能降级** - 优先使用 Readability,可降级到基础提取
|
||||||
|
|
||||||
|
## 安装依赖
|
||||||
|
|
||||||
|
### 基础功能(必需)
|
||||||
|
```bash
|
||||||
|
pip install requests
|
||||||
|
```
|
||||||
|
|
||||||
|
### 增强功能(推荐)
|
||||||
|
```bash
|
||||||
|
# 安装 readability-lxml 以获得更好的内容提取效果
|
||||||
|
pip install readability-lxml
|
||||||
|
|
||||||
|
# 安装 html2text 以获得更好的 Markdown 转换
|
||||||
|
pip install html2text
|
||||||
|
```
|
||||||
|
|
||||||
|
## 使用方法
|
||||||
|
|
||||||
|
### 1. 在代码中使用
|
||||||
|
|
||||||
|
```python
|
||||||
|
from agent.tools.web_fetch import WebFetch
|
||||||
|
|
||||||
|
# 创建工具实例
|
||||||
|
tool = WebFetch()
|
||||||
|
|
||||||
|
# 抓取网页(默认返回 Markdown 格式)
|
||||||
|
result = tool.execute({
|
||||||
|
"url": "https://example.com"
|
||||||
|
})
|
||||||
|
|
||||||
|
# 抓取并转换为纯文本
|
||||||
|
result = tool.execute({
|
||||||
|
"url": "https://example.com",
|
||||||
|
"extract_mode": "text",
|
||||||
|
"max_chars": 5000
|
||||||
|
})
|
||||||
|
|
||||||
|
if result.status == "success":
|
||||||
|
data = result.result
|
||||||
|
print(f"标题: {data['title']}")
|
||||||
|
print(f"内容: {data['text']}")
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. 在 Agent 中使用
|
||||||
|
|
||||||
|
工具会自动加载到 Agent 的工具列表中:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from agent.tools import WebFetch
|
||||||
|
|
||||||
|
tools = [
|
||||||
|
WebFetch(),
|
||||||
|
# ... 其他工具
|
||||||
|
]
|
||||||
|
|
||||||
|
agent = create_agent(tools=tools)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. 通过 Skills 使用
|
||||||
|
|
||||||
|
创建一个 skill 文件 `skills/web-fetch/SKILL.md`:
|
||||||
|
|
||||||
|
```markdown
|
||||||
|
---
|
||||||
|
name: web-fetch
|
||||||
|
emoji: 🌐
|
||||||
|
always: true
|
||||||
|
---
|
||||||
|
|
||||||
|
# 网页内容获取
|
||||||
|
|
||||||
|
使用 web_fetch 工具获取网页内容。
|
||||||
|
|
||||||
|
## 使用场景
|
||||||
|
|
||||||
|
- 需要读取某个网页的内容
|
||||||
|
- 需要提取文章正文
|
||||||
|
- 需要获取网页信息
|
||||||
|
|
||||||
|
## 示例
|
||||||
|
|
||||||
|
<example>
|
||||||
|
用户: 帮我看看 https://example.com 这个网页讲了什么
|
||||||
|
助手: <tool_use name="web_fetch">
|
||||||
|
<url>https://example.com</url>
|
||||||
|
<extract_mode>markdown</extract_mode>
|
||||||
|
</tool_use>
|
||||||
|
</example>
|
||||||
|
```
|
||||||
|
|
||||||
|
## 参数说明
|
||||||
|
|
||||||
|
| 参数 | 类型 | 必需 | 默认值 | 说明 |
|
||||||
|
|------|------|------|--------|------|
|
||||||
|
| `url` | string | ✅ | - | 要抓取的 URL(http/https) |
|
||||||
|
| `extract_mode` | string | ❌ | `markdown` | 提取模式:`markdown` 或 `text` |
|
||||||
|
| `max_chars` | integer | ❌ | `50000` | 最大返回字符数(最小 100) |
|
||||||
|
|
||||||
|
## 返回结果
|
||||||
|
|
||||||
|
```python
|
||||||
|
{
|
||||||
|
"url": "https://example.com", # 最终 URL(处理重定向后)
|
||||||
|
"status": 200, # HTTP 状态码
|
||||||
|
"content_type": "text/html", # 内容类型
|
||||||
|
"title": "Example Domain", # 页面标题
|
||||||
|
"extractor": "readability", # 提取器:readability/basic/raw
|
||||||
|
"extract_mode": "markdown", # 提取模式
|
||||||
|
"text": "# Example Domain\n\n...", # 提取的文本内容
|
||||||
|
"length": 1234, # 文本长度
|
||||||
|
"truncated": false, # 是否被截断
|
||||||
|
"warning": "..." # 警告信息(如果有)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## 与其他搜索工具的对比
|
||||||
|
|
||||||
|
| 工具 | 需要 API Key | 功能 | 成本 |
|
||||||
|
|------|-------------|------|------|
|
||||||
|
| `web_fetch` | ❌ 不需要 | 抓取指定 URL 的内容 | 免费 |
|
||||||
|
| `web_search` (Brave) | ✅ 需要 | 搜索引擎查询 | 有免费额度 |
|
||||||
|
| `web_search` (Perplexity) | ✅ 需要 | AI 搜索 + 引用 | 付费 |
|
||||||
|
| `browser` | ❌ 不需要 | 完整浏览器自动化 | 免费但资源占用大 |
|
||||||
|
| `google_search` | ✅ 需要 | Google 搜索 API | 付费 |
|
||||||
|
|
||||||
|
## 技术细节
|
||||||
|
|
||||||
|
### 内容提取策略
|
||||||
|
|
||||||
|
1. **Readability 模式**(推荐)
|
||||||
|
- 使用 Mozilla 的 Readability 算法
|
||||||
|
- 自动识别文章主体内容
|
||||||
|
- 过滤广告、导航栏等噪音
|
||||||
|
|
||||||
|
2. **Basic 模式**(降级)
|
||||||
|
- 简单的 HTML 标签清理
|
||||||
|
- 正则表达式提取文本
|
||||||
|
- 适用于简单页面
|
||||||
|
|
||||||
|
3. **Raw 模式**
|
||||||
|
- 用于非 HTML 内容
|
||||||
|
- 直接返回原始内容
|
||||||
|
|
||||||
|
### 错误处理
|
||||||
|
|
||||||
|
工具会自动处理以下情况:
|
||||||
|
- ✅ HTTP 重定向(最多 3 次)
|
||||||
|
- ✅ 请求超时(默认 30 秒)
|
||||||
|
- ✅ 网络错误自动重试
|
||||||
|
- ✅ 内容提取失败降级
|
||||||
|
|
||||||
|
## 测试
|
||||||
|
|
||||||
|
运行测试脚本:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd agent/tools/web_fetch
|
||||||
|
python test_web_fetch.py
|
||||||
|
```
|
||||||
|
|
||||||
|
## 配置选项
|
||||||
|
|
||||||
|
在创建工具时可以传入配置:
|
||||||
|
|
||||||
|
```python
|
||||||
|
tool = WebFetch(config={
|
||||||
|
"timeout": 30, # 请求超时时间(秒)
|
||||||
|
"max_redirects": 3, # 最大重定向次数
|
||||||
|
"user_agent": "..." # 自定义 User-Agent
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
## 常见问题
|
||||||
|
|
||||||
|
### Q: 为什么推荐安装 readability-lxml?
|
||||||
|
|
||||||
|
A: readability-lxml 提供更好的内容提取质量,能够:
|
||||||
|
- 自动识别文章主体
|
||||||
|
- 过滤广告和导航栏
|
||||||
|
- 保留文章结构
|
||||||
|
|
||||||
|
没有它也能工作,但提取质量会下降。
|
||||||
|
|
||||||
|
### Q: 与 clawdbot 的 web_fetch 有什么区别?
|
||||||
|
|
||||||
|
A: 本实现参考了 clawdbot 的设计,主要区别:
|
||||||
|
- Python 实现(clawdbot 是 TypeScript)
|
||||||
|
- 简化了一些高级特性(如 Firecrawl 集成)
|
||||||
|
- 保留了核心的免费功能
|
||||||
|
- 更容易集成到现有项目
|
||||||
|
|
||||||
|
### Q: 可以抓取需要登录的页面吗?
|
||||||
|
|
||||||
|
A: 当前版本不支持。如需抓取需要登录的页面,请使用 `browser` 工具。
|
||||||
|
|
||||||
|
## 参考
|
||||||
|
|
||||||
|
- [Mozilla Readability](https://github.com/mozilla/readability)
|
||||||
|
- [Clawdbot Web Tools](https://github.com/moltbot/moltbot)
|
||||||
3
agent/tools/web_fetch/__init__.py
Normal file
3
agent/tools/web_fetch/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
from .web_fetch import WebFetch
|
||||||
|
|
||||||
|
__all__ = ['WebFetch']
|
||||||
47
agent/tools/web_fetch/install_deps.sh
Normal file
47
agent/tools/web_fetch/install_deps.sh
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
# WebFetch 工具依赖安装脚本
|
||||||
|
|
||||||
|
echo "=================================="
|
||||||
|
echo "WebFetch 工具依赖安装"
|
||||||
|
echo "=================================="
|
||||||
|
echo ""
|
||||||
|
|
||||||
|
# 检查 Python 版本
|
||||||
|
python_version=$(python3 --version 2>&1 | awk '{print $2}')
|
||||||
|
echo "✓ Python 版本: $python_version"
|
||||||
|
echo ""
|
||||||
|
|
||||||
|
# 安装基础依赖
|
||||||
|
echo "📦 安装基础依赖..."
|
||||||
|
python3 -m pip install requests
|
||||||
|
|
||||||
|
# 检查是否成功
|
||||||
|
if [ $? -eq 0 ]; then
|
||||||
|
echo "✅ requests 安装成功"
|
||||||
|
else
|
||||||
|
echo "❌ requests 安装失败"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo ""
|
||||||
|
|
||||||
|
# 安装推荐依赖
|
||||||
|
echo "📦 安装推荐依赖(提升内容提取质量)..."
|
||||||
|
python3 -m pip install readability-lxml html2text
|
||||||
|
|
||||||
|
# 检查是否成功
|
||||||
|
if [ $? -eq 0 ]; then
|
||||||
|
echo "✅ readability-lxml 和 html2text 安装成功"
|
||||||
|
else
|
||||||
|
echo "⚠️ 推荐依赖安装失败,但不影响基础功能"
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo ""
|
||||||
|
echo "=================================="
|
||||||
|
echo "安装完成!"
|
||||||
|
echo "=================================="
|
||||||
|
echo ""
|
||||||
|
echo "运行测试:"
|
||||||
|
echo " python3 agent/tools/web_fetch/test_web_fetch.py"
|
||||||
|
echo ""
|
||||||
100
agent/tools/web_fetch/test_web_fetch.py
Normal file
100
agent/tools/web_fetch/test_web_fetch.py
Normal file
@@ -0,0 +1,100 @@
|
|||||||
|
"""
|
||||||
|
Test script for WebFetch tool
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# Add parent directory to path
|
||||||
|
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent))
|
||||||
|
|
||||||
|
from agent.tools.web_fetch import WebFetch
|
||||||
|
|
||||||
|
|
||||||
|
def test_web_fetch():
|
||||||
|
"""Test WebFetch tool"""
|
||||||
|
|
||||||
|
print("=" * 80)
|
||||||
|
print("Testing WebFetch Tool")
|
||||||
|
print("=" * 80)
|
||||||
|
|
||||||
|
# Create tool instance
|
||||||
|
tool = WebFetch()
|
||||||
|
|
||||||
|
print(f"\n✅ Tool created: {tool.name}")
|
||||||
|
print(f" Description: {tool.description}")
|
||||||
|
|
||||||
|
# Test 1: Fetch a simple webpage
|
||||||
|
print("\n" + "-" * 80)
|
||||||
|
print("Test 1: Fetching example.com")
|
||||||
|
print("-" * 80)
|
||||||
|
|
||||||
|
result = tool.execute({
|
||||||
|
"url": "https://example.com",
|
||||||
|
"extract_mode": "text",
|
||||||
|
"max_chars": 1000
|
||||||
|
})
|
||||||
|
|
||||||
|
if result.status == "success":
|
||||||
|
print("✅ Success!")
|
||||||
|
data = result.result
|
||||||
|
print(f" Title: {data.get('title', 'N/A')}")
|
||||||
|
print(f" Status: {data.get('status')}")
|
||||||
|
print(f" Extractor: {data.get('extractor')}")
|
||||||
|
print(f" Length: {data.get('length')} chars")
|
||||||
|
print(f" Truncated: {data.get('truncated')}")
|
||||||
|
print(f"\n Content preview:")
|
||||||
|
print(f" {data.get('text', '')[:200]}...")
|
||||||
|
else:
|
||||||
|
print(f"❌ Failed: {result.result}")
|
||||||
|
|
||||||
|
# Test 2: Invalid URL
|
||||||
|
print("\n" + "-" * 80)
|
||||||
|
print("Test 2: Testing invalid URL")
|
||||||
|
print("-" * 80)
|
||||||
|
|
||||||
|
result = tool.execute({
|
||||||
|
"url": "not-a-valid-url"
|
||||||
|
})
|
||||||
|
|
||||||
|
if result.status == "error":
|
||||||
|
print(f"✅ Correctly rejected invalid URL: {result.result}")
|
||||||
|
else:
|
||||||
|
print(f"❌ Should have rejected invalid URL")
|
||||||
|
|
||||||
|
# Test 3: Test with a real webpage (optional)
|
||||||
|
print("\n" + "-" * 80)
|
||||||
|
print("Test 3: Fetching a real webpage (Python.org)")
|
||||||
|
print("-" * 80)
|
||||||
|
|
||||||
|
result = tool.execute({
|
||||||
|
"url": "https://www.python.org",
|
||||||
|
"extract_mode": "markdown",
|
||||||
|
"max_chars": 2000
|
||||||
|
})
|
||||||
|
|
||||||
|
if result.status == "success":
|
||||||
|
print("✅ Success!")
|
||||||
|
data = result.result
|
||||||
|
print(f" Title: {data.get('title', 'N/A')}")
|
||||||
|
print(f" Status: {data.get('status')}")
|
||||||
|
print(f" Extractor: {data.get('extractor')}")
|
||||||
|
print(f" Length: {data.get('length')} chars")
|
||||||
|
print(f" Truncated: {data.get('truncated')}")
|
||||||
|
if data.get('warning'):
|
||||||
|
print(f" ⚠️ Warning: {data.get('warning')}")
|
||||||
|
print(f"\n Content preview:")
|
||||||
|
print(f" {data.get('text', '')[:300]}...")
|
||||||
|
else:
|
||||||
|
print(f"❌ Failed: {result.result}")
|
||||||
|
|
||||||
|
# Close the tool
|
||||||
|
tool.close()
|
||||||
|
|
||||||
|
print("\n" + "=" * 80)
|
||||||
|
print("Testing complete!")
|
||||||
|
print("=" * 80)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_web_fetch()
|
||||||
365
agent/tools/web_fetch/web_fetch.py
Normal file
365
agent/tools/web_fetch/web_fetch.py
Normal file
@@ -0,0 +1,365 @@
|
|||||||
|
"""
|
||||||
|
Web Fetch tool - Fetch and extract readable content from URLs
|
||||||
|
Supports HTML to Markdown/Text conversion using Mozilla's Readability
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
from typing import Dict, Any, Optional
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
import requests
|
||||||
|
from requests.adapters import HTTPAdapter
|
||||||
|
from urllib3.util.retry import Retry
|
||||||
|
|
||||||
|
from agent.tools.base_tool import BaseTool, ToolResult
|
||||||
|
from common.log import logger
|
||||||
|
|
||||||
|
|
||||||
|
class WebFetch(BaseTool):
|
||||||
|
"""Tool for fetching and extracting readable content from web pages"""
|
||||||
|
|
||||||
|
name: str = "web_fetch"
|
||||||
|
description: str = "Fetch and extract readable content from a URL (HTML → markdown/text). Use for lightweight page access without browser automation. Returns title, content, and metadata."
|
||||||
|
|
||||||
|
params: dict = {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"url": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "HTTP or HTTPS URL to fetch"
|
||||||
|
},
|
||||||
|
"extract_mode": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Extraction mode: 'markdown' (default) or 'text'",
|
||||||
|
"enum": ["markdown", "text"],
|
||||||
|
"default": "markdown"
|
||||||
|
},
|
||||||
|
"max_chars": {
|
||||||
|
"type": "integer",
|
||||||
|
"description": "Maximum characters to return (default: 50000)",
|
||||||
|
"minimum": 100,
|
||||||
|
"default": 50000
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["url"]
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(self, config: dict = None):
|
||||||
|
self.config = config or {}
|
||||||
|
self.timeout = self.config.get("timeout", 30)
|
||||||
|
self.max_redirects = self.config.get("max_redirects", 3)
|
||||||
|
self.user_agent = self.config.get(
|
||||||
|
"user_agent",
|
||||||
|
"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/122.0.0.0 Safari/537.36"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Setup session with retry strategy
|
||||||
|
self.session = self._create_session()
|
||||||
|
|
||||||
|
# Check if readability-lxml is available
|
||||||
|
self.readability_available = self._check_readability()
|
||||||
|
|
||||||
|
def _create_session(self) -> requests.Session:
|
||||||
|
"""Create a requests session with retry strategy"""
|
||||||
|
session = requests.Session()
|
||||||
|
|
||||||
|
# Retry strategy - handles failed requests, not redirects
|
||||||
|
retry_strategy = Retry(
|
||||||
|
total=3,
|
||||||
|
backoff_factor=1,
|
||||||
|
status_forcelist=[429, 500, 502, 503, 504],
|
||||||
|
allowed_methods=["GET", "HEAD"]
|
||||||
|
)
|
||||||
|
|
||||||
|
# HTTPAdapter handles retries; requests handles redirects via allow_redirects
|
||||||
|
adapter = HTTPAdapter(max_retries=retry_strategy)
|
||||||
|
session.mount("http://", adapter)
|
||||||
|
session.mount("https://", adapter)
|
||||||
|
|
||||||
|
# Set max redirects on session
|
||||||
|
session.max_redirects = self.max_redirects
|
||||||
|
|
||||||
|
return session
|
||||||
|
|
||||||
|
def _check_readability(self) -> bool:
|
||||||
|
"""Check if readability-lxml is available"""
|
||||||
|
try:
|
||||||
|
from readability import Document
|
||||||
|
return True
|
||||||
|
except ImportError:
|
||||||
|
logger.warning(
|
||||||
|
"readability-lxml not installed. Install with: pip install readability-lxml\n"
|
||||||
|
"Falling back to basic HTML extraction."
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
|
||||||
|
def execute(self, args: Dict[str, Any]) -> ToolResult:
|
||||||
|
"""
|
||||||
|
Execute web fetch operation
|
||||||
|
|
||||||
|
:param args: Contains url, extract_mode, and max_chars parameters
|
||||||
|
:return: Extracted content or error message
|
||||||
|
"""
|
||||||
|
url = args.get("url", "").strip()
|
||||||
|
extract_mode = args.get("extract_mode", "markdown").lower()
|
||||||
|
max_chars = args.get("max_chars", 50000)
|
||||||
|
|
||||||
|
if not url:
|
||||||
|
return ToolResult.fail("Error: url parameter is required")
|
||||||
|
|
||||||
|
# Validate URL
|
||||||
|
if not self._is_valid_url(url):
|
||||||
|
return ToolResult.fail(f"Error: Invalid URL (must be http or https): {url}")
|
||||||
|
|
||||||
|
# Validate extract_mode
|
||||||
|
if extract_mode not in ["markdown", "text"]:
|
||||||
|
extract_mode = "markdown"
|
||||||
|
|
||||||
|
# Validate max_chars
|
||||||
|
if not isinstance(max_chars, int) or max_chars < 100:
|
||||||
|
max_chars = 50000
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Fetch the URL
|
||||||
|
response = self._fetch_url(url)
|
||||||
|
|
||||||
|
# Extract content
|
||||||
|
result = self._extract_content(
|
||||||
|
html=response.text,
|
||||||
|
url=response.url,
|
||||||
|
status_code=response.status_code,
|
||||||
|
content_type=response.headers.get("content-type", ""),
|
||||||
|
extract_mode=extract_mode,
|
||||||
|
max_chars=max_chars
|
||||||
|
)
|
||||||
|
|
||||||
|
return ToolResult.success(result)
|
||||||
|
|
||||||
|
except requests.exceptions.Timeout:
|
||||||
|
return ToolResult.fail(f"Error: Request timeout after {self.timeout} seconds")
|
||||||
|
except requests.exceptions.TooManyRedirects:
|
||||||
|
return ToolResult.fail(f"Error: Too many redirects (limit: {self.max_redirects})")
|
||||||
|
except requests.exceptions.RequestException as e:
|
||||||
|
return ToolResult.fail(f"Error fetching URL: {str(e)}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Web fetch error: {e}", exc_info=True)
|
||||||
|
return ToolResult.fail(f"Error: {str(e)}")
|
||||||
|
|
||||||
|
def _is_valid_url(self, url: str) -> bool:
|
||||||
|
"""Validate URL format"""
|
||||||
|
try:
|
||||||
|
result = urlparse(url)
|
||||||
|
return result.scheme in ["http", "https"] and bool(result.netloc)
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _fetch_url(self, url: str) -> requests.Response:
|
||||||
|
"""
|
||||||
|
Fetch URL with proper headers and error handling
|
||||||
|
|
||||||
|
:param url: URL to fetch
|
||||||
|
:return: Response object
|
||||||
|
"""
|
||||||
|
headers = {
|
||||||
|
"User-Agent": self.user_agent,
|
||||||
|
"Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8",
|
||||||
|
"Accept-Language": "en-US,en;q=0.9,zh-CN,zh;q=0.8",
|
||||||
|
"Accept-Encoding": "gzip, deflate",
|
||||||
|
"Connection": "keep-alive",
|
||||||
|
}
|
||||||
|
|
||||||
|
# Note: requests library handles redirects automatically
|
||||||
|
# The max_redirects is set in the session's adapter (HTTPAdapter)
|
||||||
|
response = self.session.get(
|
||||||
|
url,
|
||||||
|
headers=headers,
|
||||||
|
timeout=self.timeout,
|
||||||
|
allow_redirects=True
|
||||||
|
)
|
||||||
|
|
||||||
|
response.raise_for_status()
|
||||||
|
return response
|
||||||
|
|
||||||
|
def _extract_content(
|
||||||
|
self,
|
||||||
|
html: str,
|
||||||
|
url: str,
|
||||||
|
status_code: int,
|
||||||
|
content_type: str,
|
||||||
|
extract_mode: str,
|
||||||
|
max_chars: int
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Extract readable content from HTML
|
||||||
|
|
||||||
|
:param html: HTML content
|
||||||
|
:param url: Original URL
|
||||||
|
:param status_code: HTTP status code
|
||||||
|
:param content_type: Content type header
|
||||||
|
:param extract_mode: 'markdown' or 'text'
|
||||||
|
:param max_chars: Maximum characters to return
|
||||||
|
:return: Extracted content and metadata
|
||||||
|
"""
|
||||||
|
# Check content type
|
||||||
|
if "text/html" not in content_type.lower():
|
||||||
|
# Non-HTML content
|
||||||
|
text = html[:max_chars]
|
||||||
|
truncated = len(html) > max_chars
|
||||||
|
|
||||||
|
return {
|
||||||
|
"url": url,
|
||||||
|
"status": status_code,
|
||||||
|
"content_type": content_type,
|
||||||
|
"extractor": "raw",
|
||||||
|
"text": text,
|
||||||
|
"length": len(text),
|
||||||
|
"truncated": truncated,
|
||||||
|
"message": f"Non-HTML content (type: {content_type})"
|
||||||
|
}
|
||||||
|
|
||||||
|
# Extract readable content from HTML
|
||||||
|
if self.readability_available:
|
||||||
|
return self._extract_with_readability(
|
||||||
|
html, url, status_code, content_type, extract_mode, max_chars
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return self._extract_basic(
|
||||||
|
html, url, status_code, content_type, extract_mode, max_chars
|
||||||
|
)
|
||||||
|
|
||||||
|
def _extract_with_readability(
|
||||||
|
self,
|
||||||
|
html: str,
|
||||||
|
url: str,
|
||||||
|
status_code: int,
|
||||||
|
content_type: str,
|
||||||
|
extract_mode: str,
|
||||||
|
max_chars: int
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""Extract content using Mozilla's Readability"""
|
||||||
|
try:
|
||||||
|
from readability import Document
|
||||||
|
|
||||||
|
# Parse with Readability
|
||||||
|
doc = Document(html)
|
||||||
|
title = doc.title()
|
||||||
|
content_html = doc.summary()
|
||||||
|
|
||||||
|
# Convert to markdown or text
|
||||||
|
if extract_mode == "markdown":
|
||||||
|
text = self._html_to_markdown(content_html)
|
||||||
|
else:
|
||||||
|
text = self._html_to_text(content_html)
|
||||||
|
|
||||||
|
# Truncate if needed
|
||||||
|
truncated = len(text) > max_chars
|
||||||
|
if truncated:
|
||||||
|
text = text[:max_chars]
|
||||||
|
|
||||||
|
return {
|
||||||
|
"url": url,
|
||||||
|
"status": status_code,
|
||||||
|
"content_type": content_type,
|
||||||
|
"title": title,
|
||||||
|
"extractor": "readability",
|
||||||
|
"extract_mode": extract_mode,
|
||||||
|
"text": text,
|
||||||
|
"length": len(text),
|
||||||
|
"truncated": truncated
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Readability extraction failed: {e}")
|
||||||
|
# Fallback to basic extraction
|
||||||
|
return self._extract_basic(
|
||||||
|
html, url, status_code, content_type, extract_mode, max_chars
|
||||||
|
)
|
||||||
|
|
||||||
|
def _extract_basic(
|
||||||
|
self,
|
||||||
|
html: str,
|
||||||
|
url: str,
|
||||||
|
status_code: int,
|
||||||
|
content_type: str,
|
||||||
|
extract_mode: str,
|
||||||
|
max_chars: int
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""Basic HTML extraction without Readability"""
|
||||||
|
# Extract title
|
||||||
|
title_match = re.search(r'<title[^>]*>(.*?)</title>', html, re.IGNORECASE | re.DOTALL)
|
||||||
|
title = title_match.group(1).strip() if title_match else "Untitled"
|
||||||
|
|
||||||
|
# Remove script and style tags
|
||||||
|
text = re.sub(r'<script[^>]*>.*?</script>', '', html, flags=re.DOTALL | re.IGNORECASE)
|
||||||
|
text = re.sub(r'<style[^>]*>.*?</style>', '', text, flags=re.DOTALL | re.IGNORECASE)
|
||||||
|
|
||||||
|
# Remove HTML tags
|
||||||
|
text = re.sub(r'<[^>]+>', ' ', text)
|
||||||
|
|
||||||
|
# Clean up whitespace
|
||||||
|
text = re.sub(r'\s+', ' ', text)
|
||||||
|
text = text.strip()
|
||||||
|
|
||||||
|
# Truncate if needed
|
||||||
|
truncated = len(text) > max_chars
|
||||||
|
if truncated:
|
||||||
|
text = text[:max_chars]
|
||||||
|
|
||||||
|
return {
|
||||||
|
"url": url,
|
||||||
|
"status": status_code,
|
||||||
|
"content_type": content_type,
|
||||||
|
"title": title,
|
||||||
|
"extractor": "basic",
|
||||||
|
"extract_mode": extract_mode,
|
||||||
|
"text": text,
|
||||||
|
"length": len(text),
|
||||||
|
"truncated": truncated,
|
||||||
|
"warning": "Using basic extraction. Install readability-lxml for better results."
|
||||||
|
}
|
||||||
|
|
||||||
|
def _html_to_markdown(self, html: str) -> str:
|
||||||
|
"""Convert HTML to Markdown (basic implementation)"""
|
||||||
|
try:
|
||||||
|
# Try to use html2text if available
|
||||||
|
import html2text
|
||||||
|
h = html2text.HTML2Text()
|
||||||
|
h.ignore_links = False
|
||||||
|
h.ignore_images = False
|
||||||
|
h.body_width = 0 # Don't wrap lines
|
||||||
|
return h.handle(html)
|
||||||
|
except ImportError:
|
||||||
|
# Fallback to basic conversion
|
||||||
|
return self._html_to_text(html)
|
||||||
|
|
||||||
|
def _html_to_text(self, html: str) -> str:
|
||||||
|
"""Convert HTML to plain text"""
|
||||||
|
# Remove script and style tags
|
||||||
|
text = re.sub(r'<script[^>]*>.*?</script>', '', html, flags=re.DOTALL | re.IGNORECASE)
|
||||||
|
text = re.sub(r'<style[^>]*>.*?</style>', '', text, flags=re.DOTALL | re.IGNORECASE)
|
||||||
|
|
||||||
|
# Convert common tags to text equivalents
|
||||||
|
text = re.sub(r'<br\s*/?>', '\n', text, flags=re.IGNORECASE)
|
||||||
|
text = re.sub(r'<p[^>]*>', '\n\n', text, flags=re.IGNORECASE)
|
||||||
|
text = re.sub(r'</p>', '', text, flags=re.IGNORECASE)
|
||||||
|
text = re.sub(r'<h[1-6][^>]*>', '\n\n', text, flags=re.IGNORECASE)
|
||||||
|
text = re.sub(r'</h[1-6]>', '\n', text, flags=re.IGNORECASE)
|
||||||
|
|
||||||
|
# Remove all other HTML tags
|
||||||
|
text = re.sub(r'<[^>]+>', '', text)
|
||||||
|
|
||||||
|
# Decode HTML entities
|
||||||
|
import html
|
||||||
|
text = html.unescape(text)
|
||||||
|
|
||||||
|
# Clean up whitespace
|
||||||
|
text = re.sub(r'\n\s*\n\s*\n+', '\n\n', text)
|
||||||
|
text = re.sub(r' +', ' ', text)
|
||||||
|
text = text.strip()
|
||||||
|
|
||||||
|
return text
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
"""Close the session"""
|
||||||
|
if hasattr(self, 'session'):
|
||||||
|
self.session.close()
|
||||||
@@ -1,6 +1,7 @@
|
|||||||
# encoding:utf-8
|
# encoding:utf-8
|
||||||
|
|
||||||
import time
|
import time
|
||||||
|
import json
|
||||||
|
|
||||||
import openai
|
import openai
|
||||||
import openai.error
|
import openai.error
|
||||||
@@ -171,6 +172,251 @@ class ChatGPTBot(Bot, OpenAIImage):
|
|||||||
else:
|
else:
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
def call_with_tools(self, messages, tools=None, stream=False, **kwargs):
|
||||||
|
"""
|
||||||
|
Call OpenAI API with tool support for agent integration
|
||||||
|
|
||||||
|
Args:
|
||||||
|
messages: List of messages (may be in Claude format from agent)
|
||||||
|
tools: List of tool definitions (may be in Claude format from agent)
|
||||||
|
stream: Whether to use streaming
|
||||||
|
**kwargs: Additional parameters (max_tokens, temperature, system, etc.)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Formatted response in OpenAI format or generator for streaming
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Convert messages from Claude format to OpenAI format
|
||||||
|
messages = self._convert_messages_to_openai_format(messages)
|
||||||
|
|
||||||
|
# Convert tools from Claude format to OpenAI format
|
||||||
|
if tools:
|
||||||
|
tools = self._convert_tools_to_openai_format(tools)
|
||||||
|
|
||||||
|
# Handle system prompt (OpenAI uses system message, Claude uses separate parameter)
|
||||||
|
system_prompt = kwargs.get('system')
|
||||||
|
if system_prompt:
|
||||||
|
# Add system message at the beginning if not already present
|
||||||
|
if not messages or messages[0].get('role') != 'system':
|
||||||
|
messages = [{"role": "system", "content": system_prompt}] + messages
|
||||||
|
else:
|
||||||
|
# Replace existing system message
|
||||||
|
messages[0] = {"role": "system", "content": system_prompt}
|
||||||
|
|
||||||
|
# Build request parameters
|
||||||
|
request_params = {
|
||||||
|
"model": kwargs.get("model", conf().get("model") or "gpt-3.5-turbo"),
|
||||||
|
"messages": messages,
|
||||||
|
"temperature": kwargs.get("temperature", conf().get("temperature", 0.9)),
|
||||||
|
"top_p": kwargs.get("top_p", conf().get("top_p", 1)),
|
||||||
|
"frequency_penalty": kwargs.get("frequency_penalty", conf().get("frequency_penalty", 0.0)),
|
||||||
|
"presence_penalty": kwargs.get("presence_penalty", conf().get("presence_penalty", 0.0)),
|
||||||
|
"stream": stream
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add max_tokens if specified
|
||||||
|
if kwargs.get("max_tokens"):
|
||||||
|
request_params["max_tokens"] = kwargs["max_tokens"]
|
||||||
|
|
||||||
|
# Add tools if provided
|
||||||
|
if tools:
|
||||||
|
request_params["tools"] = tools
|
||||||
|
request_params["tool_choice"] = kwargs.get("tool_choice", "auto")
|
||||||
|
|
||||||
|
# Handle model-specific parameters (o1, gpt-5 series don't support some params)
|
||||||
|
model = request_params["model"]
|
||||||
|
if model in [const.O1, const.O1_MINI, const.GPT_5, const.GPT_5_MINI, const.GPT_5_NANO]:
|
||||||
|
remove_keys = ["temperature", "top_p", "frequency_penalty", "presence_penalty"]
|
||||||
|
for key in remove_keys:
|
||||||
|
request_params.pop(key, None)
|
||||||
|
|
||||||
|
# Make API call
|
||||||
|
# Note: Don't pass api_key explicitly to use global openai.api_key and openai.api_base
|
||||||
|
# which are set in __init__
|
||||||
|
if stream:
|
||||||
|
return self._handle_stream_response(request_params)
|
||||||
|
else:
|
||||||
|
return self._handle_sync_response(request_params)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = str(e)
|
||||||
|
logger.error(f"[ChatGPT] call_with_tools error: {error_msg}")
|
||||||
|
if stream:
|
||||||
|
def error_generator():
|
||||||
|
yield {
|
||||||
|
"error": True,
|
||||||
|
"message": error_msg,
|
||||||
|
"status_code": 500
|
||||||
|
}
|
||||||
|
return error_generator()
|
||||||
|
else:
|
||||||
|
return {
|
||||||
|
"error": True,
|
||||||
|
"message": error_msg,
|
||||||
|
"status_code": 500
|
||||||
|
}
|
||||||
|
|
||||||
|
def _handle_sync_response(self, request_params):
|
||||||
|
"""Handle synchronous OpenAI API response"""
|
||||||
|
try:
|
||||||
|
# Explicitly set API configuration to ensure it's used
|
||||||
|
# (global settings can be unreliable in some contexts)
|
||||||
|
api_key = conf().get("open_ai_api_key")
|
||||||
|
api_base = conf().get("open_ai_api_base")
|
||||||
|
|
||||||
|
# Build kwargs with explicit API configuration
|
||||||
|
kwargs = dict(request_params)
|
||||||
|
if api_key:
|
||||||
|
kwargs["api_key"] = api_key
|
||||||
|
if api_base:
|
||||||
|
kwargs["api_base"] = api_base
|
||||||
|
|
||||||
|
response = openai.ChatCompletion.create(**kwargs)
|
||||||
|
|
||||||
|
# Response is already in OpenAI format
|
||||||
|
logger.info(f"[ChatGPT] call_with_tools reply, model={response.get('model')}, "
|
||||||
|
f"total_tokens={response.get('usage', {}).get('total_tokens', 0)}")
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"[ChatGPT] sync response error: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def _handle_stream_response(self, request_params):
|
||||||
|
"""Handle streaming OpenAI API response"""
|
||||||
|
try:
|
||||||
|
# Explicitly set API configuration to ensure it's used
|
||||||
|
api_key = conf().get("open_ai_api_key")
|
||||||
|
api_base = conf().get("open_ai_api_base")
|
||||||
|
|
||||||
|
logger.debug(f"[ChatGPT] Starting stream with params: model={request_params.get('model')}, stream={request_params.get('stream')}")
|
||||||
|
|
||||||
|
# Build kwargs with explicit API configuration
|
||||||
|
kwargs = dict(request_params)
|
||||||
|
if api_key:
|
||||||
|
kwargs["api_key"] = api_key
|
||||||
|
if api_base:
|
||||||
|
kwargs["api_base"] = api_base
|
||||||
|
|
||||||
|
stream = openai.ChatCompletion.create(**kwargs)
|
||||||
|
|
||||||
|
# OpenAI stream is already in the correct format
|
||||||
|
chunk_count = 0
|
||||||
|
for chunk in stream:
|
||||||
|
chunk_count += 1
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
logger.debug(f"[ChatGPT] Stream completed, yielded {chunk_count} chunks")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"[ChatGPT] stream response error: {e}", exc_info=True)
|
||||||
|
yield {
|
||||||
|
"error": True,
|
||||||
|
"message": str(e),
|
||||||
|
"status_code": 500
|
||||||
|
}
|
||||||
|
|
||||||
|
def _convert_tools_to_openai_format(self, tools):
|
||||||
|
"""
|
||||||
|
Convert tools from Claude format to OpenAI format
|
||||||
|
|
||||||
|
Claude format: {name, description, input_schema}
|
||||||
|
OpenAI format: {type: "function", function: {name, description, parameters}}
|
||||||
|
"""
|
||||||
|
if not tools:
|
||||||
|
return None
|
||||||
|
|
||||||
|
openai_tools = []
|
||||||
|
for tool in tools:
|
||||||
|
# Check if already in OpenAI format
|
||||||
|
if 'type' in tool and tool['type'] == 'function':
|
||||||
|
openai_tools.append(tool)
|
||||||
|
else:
|
||||||
|
# Convert from Claude format
|
||||||
|
openai_tools.append({
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": tool.get("name"),
|
||||||
|
"description": tool.get("description"),
|
||||||
|
"parameters": tool.get("input_schema", {})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
return openai_tools
|
||||||
|
|
||||||
|
def _convert_messages_to_openai_format(self, messages):
|
||||||
|
"""
|
||||||
|
Convert messages from Claude format to OpenAI format
|
||||||
|
|
||||||
|
Claude uses content blocks with types like 'tool_use', 'tool_result'
|
||||||
|
OpenAI uses 'tool_calls' in assistant messages and 'tool' role for results
|
||||||
|
"""
|
||||||
|
if not messages:
|
||||||
|
return []
|
||||||
|
|
||||||
|
openai_messages = []
|
||||||
|
|
||||||
|
for msg in messages:
|
||||||
|
role = msg.get("role")
|
||||||
|
content = msg.get("content")
|
||||||
|
|
||||||
|
# Handle string content (already in correct format)
|
||||||
|
if isinstance(content, str):
|
||||||
|
openai_messages.append(msg)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Handle list content (Claude format with content blocks)
|
||||||
|
if isinstance(content, list):
|
||||||
|
# Check if this is a tool result message (user role with tool_result blocks)
|
||||||
|
if role == "user" and any(block.get("type") == "tool_result" for block in content):
|
||||||
|
# Convert each tool_result block to a separate tool message
|
||||||
|
for block in content:
|
||||||
|
if block.get("type") == "tool_result":
|
||||||
|
openai_messages.append({
|
||||||
|
"role": "tool",
|
||||||
|
"tool_call_id": block.get("tool_use_id"),
|
||||||
|
"content": block.get("content", "")
|
||||||
|
})
|
||||||
|
|
||||||
|
# Check if this is an assistant message with tool_use blocks
|
||||||
|
elif role == "assistant":
|
||||||
|
# Separate text content and tool_use blocks
|
||||||
|
text_parts = []
|
||||||
|
tool_calls = []
|
||||||
|
|
||||||
|
for block in content:
|
||||||
|
if block.get("type") == "text":
|
||||||
|
text_parts.append(block.get("text", ""))
|
||||||
|
elif block.get("type") == "tool_use":
|
||||||
|
tool_calls.append({
|
||||||
|
"id": block.get("id"),
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": block.get("name"),
|
||||||
|
"arguments": json.dumps(block.get("input", {}))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
# Build OpenAI format assistant message
|
||||||
|
openai_msg = {
|
||||||
|
"role": "assistant",
|
||||||
|
"content": " ".join(text_parts) if text_parts else None
|
||||||
|
}
|
||||||
|
|
||||||
|
if tool_calls:
|
||||||
|
openai_msg["tool_calls"] = tool_calls
|
||||||
|
|
||||||
|
openai_messages.append(openai_msg)
|
||||||
|
else:
|
||||||
|
# Other list content, keep as is
|
||||||
|
openai_messages.append(msg)
|
||||||
|
else:
|
||||||
|
# Other formats, keep as is
|
||||||
|
openai_messages.append(msg)
|
||||||
|
|
||||||
|
return openai_messages
|
||||||
|
|
||||||
|
|
||||||
class AzureChatGPTBot(ChatGPTBot):
|
class AzureChatGPTBot(ChatGPTBot):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ import requests
|
|||||||
|
|
||||||
from bot.baidu.baidu_wenxin_session import BaiduWenxinSession
|
from bot.baidu.baidu_wenxin_session import BaiduWenxinSession
|
||||||
from bot.bot import Bot
|
from bot.bot import Bot
|
||||||
from bot.openai.open_ai_image import OpenAIImage
|
|
||||||
from bot.session_manager import SessionManager
|
from bot.session_manager import SessionManager
|
||||||
from bridge.context import ContextType
|
from bridge.context import ContextType
|
||||||
from bridge.reply import Reply, ReplyType
|
from bridge.reply import Reply, ReplyType
|
||||||
@@ -15,6 +14,15 @@ from common import const
|
|||||||
from common.log import logger
|
from common.log import logger
|
||||||
from config import conf
|
from config import conf
|
||||||
|
|
||||||
|
# Optional OpenAI image support
|
||||||
|
try:
|
||||||
|
from bot.openai.open_ai_image import OpenAIImage
|
||||||
|
_openai_image_available = True
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"OpenAI image support not available: {e}")
|
||||||
|
_openai_image_available = False
|
||||||
|
OpenAIImage = object # Fallback to object
|
||||||
|
|
||||||
user_session = dict()
|
user_session = dict()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -6,6 +6,8 @@ Google gemini bot
|
|||||||
"""
|
"""
|
||||||
# encoding:utf-8
|
# encoding:utf-8
|
||||||
|
|
||||||
|
import json
|
||||||
|
import time
|
||||||
from bot.bot import Bot
|
from bot.bot import Bot
|
||||||
import google.generativeai as genai
|
import google.generativeai as genai
|
||||||
from bot.session_manager import SessionManager
|
from bot.session_manager import SessionManager
|
||||||
@@ -113,3 +115,224 @@ class GoogleGeminiBot(Bot):
|
|||||||
elif turn == "assistant":
|
elif turn == "assistant":
|
||||||
turn = "user"
|
turn = "user"
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
def call_with_tools(self, messages, tools=None, stream=False, **kwargs):
|
||||||
|
"""
|
||||||
|
Call Gemini API with tool support for agent integration
|
||||||
|
|
||||||
|
Args:
|
||||||
|
messages: List of messages
|
||||||
|
tools: List of tool definitions (OpenAI format, will be converted to Gemini format)
|
||||||
|
stream: Whether to use streaming
|
||||||
|
**kwargs: Additional parameters
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Formatted response compatible with OpenAI format or generator for streaming
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Configure Gemini
|
||||||
|
genai.configure(api_key=self.api_key)
|
||||||
|
model_name = kwargs.get("model", self.model)
|
||||||
|
|
||||||
|
# Extract system prompt from messages
|
||||||
|
system_prompt = kwargs.get("system", "")
|
||||||
|
gemini_messages = []
|
||||||
|
|
||||||
|
for msg in messages:
|
||||||
|
if msg.get("role") == "system":
|
||||||
|
system_prompt = msg["content"]
|
||||||
|
else:
|
||||||
|
gemini_messages.append(msg)
|
||||||
|
|
||||||
|
# Convert messages to Gemini format
|
||||||
|
gemini_messages = self._convert_to_gemini_messages(gemini_messages)
|
||||||
|
|
||||||
|
# Safety settings
|
||||||
|
safety_settings = {
|
||||||
|
HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
|
||||||
|
HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
|
||||||
|
HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
|
||||||
|
HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Convert tools from OpenAI format to Gemini format if provided
|
||||||
|
gemini_tools = None
|
||||||
|
if tools:
|
||||||
|
gemini_tools = self._convert_tools_to_gemini_format(tools)
|
||||||
|
|
||||||
|
# Create model with system instruction if available
|
||||||
|
model_kwargs = {"model_name": model_name}
|
||||||
|
if system_prompt:
|
||||||
|
model_kwargs["system_instruction"] = system_prompt
|
||||||
|
|
||||||
|
model = genai.GenerativeModel(**model_kwargs)
|
||||||
|
|
||||||
|
# Generate content
|
||||||
|
generation_config = {}
|
||||||
|
if kwargs.get("max_tokens"):
|
||||||
|
generation_config["max_output_tokens"] = kwargs["max_tokens"]
|
||||||
|
if kwargs.get("temperature") is not None:
|
||||||
|
generation_config["temperature"] = kwargs["temperature"]
|
||||||
|
|
||||||
|
request_params = {
|
||||||
|
"safety_settings": safety_settings
|
||||||
|
}
|
||||||
|
if generation_config:
|
||||||
|
request_params["generation_config"] = generation_config
|
||||||
|
if gemini_tools:
|
||||||
|
request_params["tools"] = gemini_tools
|
||||||
|
|
||||||
|
if stream:
|
||||||
|
return self._handle_gemini_stream_response(model, gemini_messages, request_params, model_name)
|
||||||
|
else:
|
||||||
|
return self._handle_gemini_sync_response(model, gemini_messages, request_params, model_name)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"[Gemini] call_with_tools error: {e}")
|
||||||
|
if stream:
|
||||||
|
def error_generator():
|
||||||
|
yield {
|
||||||
|
"error": True,
|
||||||
|
"message": str(e),
|
||||||
|
"status_code": 500
|
||||||
|
}
|
||||||
|
return error_generator()
|
||||||
|
else:
|
||||||
|
return {
|
||||||
|
"error": True,
|
||||||
|
"message": str(e),
|
||||||
|
"status_code": 500
|
||||||
|
}
|
||||||
|
|
||||||
|
def _convert_tools_to_gemini_format(self, openai_tools):
|
||||||
|
"""Convert OpenAI tool format to Gemini function declarations"""
|
||||||
|
import google.generativeai as genai
|
||||||
|
|
||||||
|
gemini_functions = []
|
||||||
|
for tool in openai_tools:
|
||||||
|
if tool.get("type") == "function":
|
||||||
|
func = tool.get("function", {})
|
||||||
|
gemini_functions.append(
|
||||||
|
genai.protos.FunctionDeclaration(
|
||||||
|
name=func.get("name"),
|
||||||
|
description=func.get("description", ""),
|
||||||
|
parameters=func.get("parameters", {})
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if gemini_functions:
|
||||||
|
return [genai.protos.Tool(function_declarations=gemini_functions)]
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _handle_gemini_sync_response(self, model, messages, request_params, model_name):
|
||||||
|
"""Handle synchronous Gemini API response"""
|
||||||
|
import json
|
||||||
|
|
||||||
|
response = model.generate_content(messages, **request_params)
|
||||||
|
|
||||||
|
# Extract text content and function calls
|
||||||
|
text_content = ""
|
||||||
|
tool_calls = []
|
||||||
|
|
||||||
|
if response.candidates and response.candidates[0].content:
|
||||||
|
for part in response.candidates[0].content.parts:
|
||||||
|
if hasattr(part, 'text') and part.text:
|
||||||
|
text_content += part.text
|
||||||
|
elif hasattr(part, 'function_call') and part.function_call:
|
||||||
|
# Convert Gemini function call to OpenAI format
|
||||||
|
func_call = part.function_call
|
||||||
|
tool_calls.append({
|
||||||
|
"id": f"call_{hash(func_call.name)}",
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": func_call.name,
|
||||||
|
"arguments": json.dumps(dict(func_call.args))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
# Build message in OpenAI format
|
||||||
|
message = {
|
||||||
|
"role": "assistant",
|
||||||
|
"content": text_content
|
||||||
|
}
|
||||||
|
if tool_calls:
|
||||||
|
message["tool_calls"] = tool_calls
|
||||||
|
|
||||||
|
# Format response to match OpenAI structure
|
||||||
|
formatted_response = {
|
||||||
|
"id": f"gemini_{int(time.time())}",
|
||||||
|
"object": "chat.completion",
|
||||||
|
"created": int(time.time()),
|
||||||
|
"model": model_name,
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"message": message,
|
||||||
|
"finish_reason": "stop" if not tool_calls else "tool_calls"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"usage": {
|
||||||
|
"prompt_tokens": 0, # Gemini doesn't provide token counts in the same way
|
||||||
|
"completion_tokens": 0,
|
||||||
|
"total_tokens": 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.info(f"[Gemini] call_with_tools reply, model={model_name}")
|
||||||
|
return formatted_response
|
||||||
|
|
||||||
|
def _handle_gemini_stream_response(self, model, messages, request_params, model_name):
|
||||||
|
"""Handle streaming Gemini API response"""
|
||||||
|
import json
|
||||||
|
|
||||||
|
try:
|
||||||
|
response_stream = model.generate_content(messages, stream=True, **request_params)
|
||||||
|
|
||||||
|
for chunk in response_stream:
|
||||||
|
if chunk.candidates and chunk.candidates[0].content:
|
||||||
|
for part in chunk.candidates[0].content.parts:
|
||||||
|
if hasattr(part, 'text') and part.text:
|
||||||
|
# Text content
|
||||||
|
yield {
|
||||||
|
"id": f"gemini_{int(time.time())}",
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"created": int(time.time()),
|
||||||
|
"model": model_name,
|
||||||
|
"choices": [{
|
||||||
|
"index": 0,
|
||||||
|
"delta": {"content": part.text},
|
||||||
|
"finish_reason": None
|
||||||
|
}]
|
||||||
|
}
|
||||||
|
elif hasattr(part, 'function_call') and part.function_call:
|
||||||
|
# Function call
|
||||||
|
func_call = part.function_call
|
||||||
|
yield {
|
||||||
|
"id": f"gemini_{int(time.time())}",
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"created": int(time.time()),
|
||||||
|
"model": model_name,
|
||||||
|
"choices": [{
|
||||||
|
"index": 0,
|
||||||
|
"delta": {
|
||||||
|
"tool_calls": [{
|
||||||
|
"index": 0,
|
||||||
|
"id": f"call_{hash(func_call.name)}",
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": func_call.name,
|
||||||
|
"arguments": json.dumps(dict(func_call.args))
|
||||||
|
}
|
||||||
|
}]
|
||||||
|
},
|
||||||
|
"finish_reason": None
|
||||||
|
}]
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"[Gemini] stream response error: {e}")
|
||||||
|
yield {
|
||||||
|
"error": True,
|
||||||
|
"message": str(e),
|
||||||
|
"status_code": 500
|
||||||
|
}
|
||||||
|
|||||||
@@ -473,3 +473,150 @@ class LinkAISession(ChatGPTSession):
|
|||||||
self.messages.pop(i - 1)
|
self.messages.pop(i - 1)
|
||||||
return self.calc_tokens()
|
return self.calc_tokens()
|
||||||
return cur_tokens
|
return cur_tokens
|
||||||
|
|
||||||
|
|
||||||
|
# Add call_with_tools method to LinkAIBot class
|
||||||
|
def _linkai_call_with_tools(self, messages, tools=None, stream=False, **kwargs):
|
||||||
|
"""
|
||||||
|
Call LinkAI API with tool support for agent integration
|
||||||
|
LinkAI is fully compatible with OpenAI's tool calling format
|
||||||
|
|
||||||
|
Args:
|
||||||
|
messages: List of messages
|
||||||
|
tools: List of tool definitions (OpenAI format)
|
||||||
|
stream: Whether to use streaming
|
||||||
|
**kwargs: Additional parameters (max_tokens, temperature, etc.)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Formatted response in OpenAI format or generator for streaming
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Build request parameters (LinkAI uses OpenAI-compatible format)
|
||||||
|
body = {
|
||||||
|
"messages": messages,
|
||||||
|
"model": kwargs.get("model", conf().get("model") or "gpt-3.5-turbo"),
|
||||||
|
"temperature": kwargs.get("temperature", conf().get("temperature", 0.9)),
|
||||||
|
"top_p": kwargs.get("top_p", conf().get("top_p", 1)),
|
||||||
|
"frequency_penalty": kwargs.get("frequency_penalty", conf().get("frequency_penalty", 0.0)),
|
||||||
|
"presence_penalty": kwargs.get("presence_penalty", conf().get("presence_penalty", 0.0)),
|
||||||
|
"stream": stream
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add max_tokens if specified
|
||||||
|
if kwargs.get("max_tokens"):
|
||||||
|
body["max_tokens"] = kwargs["max_tokens"]
|
||||||
|
|
||||||
|
# Add app_code if provided
|
||||||
|
app_code = kwargs.get("app_code", conf().get("linkai_app_code"))
|
||||||
|
if app_code:
|
||||||
|
body["app_code"] = app_code
|
||||||
|
|
||||||
|
# Add tools if provided (OpenAI-compatible format)
|
||||||
|
if tools:
|
||||||
|
body["tools"] = tools
|
||||||
|
body["tool_choice"] = kwargs.get("tool_choice", "auto")
|
||||||
|
|
||||||
|
# Prepare headers
|
||||||
|
headers = {"Authorization": "Bearer " + conf().get("linkai_api_key")}
|
||||||
|
base_url = conf().get("linkai_api_base", "https://api.link-ai.tech")
|
||||||
|
|
||||||
|
if stream:
|
||||||
|
return self._handle_linkai_stream_response(base_url, headers, body)
|
||||||
|
else:
|
||||||
|
return self._handle_linkai_sync_response(base_url, headers, body)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"[LinkAI] call_with_tools error: {e}")
|
||||||
|
if stream:
|
||||||
|
def error_generator():
|
||||||
|
yield {
|
||||||
|
"error": True,
|
||||||
|
"message": str(e),
|
||||||
|
"status_code": 500
|
||||||
|
}
|
||||||
|
return error_generator()
|
||||||
|
else:
|
||||||
|
return {
|
||||||
|
"error": True,
|
||||||
|
"message": str(e),
|
||||||
|
"status_code": 500
|
||||||
|
}
|
||||||
|
|
||||||
|
def _handle_linkai_sync_response(self, base_url, headers, body):
|
||||||
|
"""Handle synchronous LinkAI API response"""
|
||||||
|
try:
|
||||||
|
res = requests.post(
|
||||||
|
url=base_url + "/v1/chat/completions",
|
||||||
|
json=body,
|
||||||
|
headers=headers,
|
||||||
|
timeout=conf().get("request_timeout", 180)
|
||||||
|
)
|
||||||
|
|
||||||
|
if res.status_code == 200:
|
||||||
|
response = res.json()
|
||||||
|
logger.info(f"[LinkAI] call_with_tools reply, model={response.get('model')}, "
|
||||||
|
f"total_tokens={response.get('usage', {}).get('total_tokens', 0)}")
|
||||||
|
|
||||||
|
# LinkAI response is already in OpenAI-compatible format
|
||||||
|
return response
|
||||||
|
else:
|
||||||
|
error_data = res.json()
|
||||||
|
error_msg = error_data.get("error", {}).get("message", "Unknown error")
|
||||||
|
raise Exception(f"LinkAI API error: {res.status_code} - {error_msg}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"[LinkAI] sync response error: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def _handle_linkai_stream_response(self, base_url, headers, body):
|
||||||
|
"""Handle streaming LinkAI API response"""
|
||||||
|
try:
|
||||||
|
res = requests.post(
|
||||||
|
url=base_url + "/v1/chat/completions",
|
||||||
|
json=body,
|
||||||
|
headers=headers,
|
||||||
|
timeout=conf().get("request_timeout", 180),
|
||||||
|
stream=True
|
||||||
|
)
|
||||||
|
|
||||||
|
if res.status_code != 200:
|
||||||
|
error_text = res.text
|
||||||
|
try:
|
||||||
|
error_data = json.loads(error_text)
|
||||||
|
error_msg = error_data.get("error", {}).get("message", error_text)
|
||||||
|
except:
|
||||||
|
error_msg = error_text or "Unknown error"
|
||||||
|
|
||||||
|
yield {
|
||||||
|
"error": True,
|
||||||
|
"status_code": res.status_code,
|
||||||
|
"message": error_msg
|
||||||
|
}
|
||||||
|
return
|
||||||
|
|
||||||
|
# Process streaming response (OpenAI-compatible SSE format)
|
||||||
|
for line in res.iter_lines():
|
||||||
|
if line:
|
||||||
|
line = line.decode('utf-8')
|
||||||
|
if line.startswith('data: '):
|
||||||
|
line = line[6:] # Remove 'data: ' prefix
|
||||||
|
if line == '[DONE]':
|
||||||
|
break
|
||||||
|
try:
|
||||||
|
chunk = json.loads(line)
|
||||||
|
yield chunk
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
continue
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"[LinkAI] stream response error: {e}")
|
||||||
|
yield {
|
||||||
|
"error": True,
|
||||||
|
"message": str(e),
|
||||||
|
"status_code": 500
|
||||||
|
}
|
||||||
|
|
||||||
|
# Attach methods to LinkAIBot class
|
||||||
|
LinkAIBot.call_with_tools = _linkai_call_with_tools
|
||||||
|
LinkAIBot._handle_linkai_sync_response = _handle_linkai_sync_response
|
||||||
|
LinkAIBot._handle_linkai_stream_response = _handle_linkai_stream_response
|
||||||
|
|||||||
@@ -120,3 +120,98 @@ class OpenAIBot(Bot, OpenAIImage):
|
|||||||
return self.reply_text(session, retry_count + 1)
|
return self.reply_text(session, retry_count + 1)
|
||||||
else:
|
else:
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
def call_with_tools(self, messages, tools=None, stream=False, **kwargs):
|
||||||
|
"""
|
||||||
|
Call OpenAI API with tool support for agent integration
|
||||||
|
Note: This bot uses the old Completion API which doesn't support tools.
|
||||||
|
For tool support, use ChatGPTBot instead.
|
||||||
|
|
||||||
|
This method converts to ChatCompletion API when tools are provided.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
messages: List of messages
|
||||||
|
tools: List of tool definitions (OpenAI format)
|
||||||
|
stream: Whether to use streaming
|
||||||
|
**kwargs: Additional parameters
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Formatted response in OpenAI format or generator for streaming
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# The old Completion API doesn't support tools
|
||||||
|
# We need to use ChatCompletion API instead
|
||||||
|
logger.info("[OPEN_AI] Using ChatCompletion API for tool support")
|
||||||
|
|
||||||
|
# Build request parameters for ChatCompletion
|
||||||
|
request_params = {
|
||||||
|
"model": kwargs.get("model", conf().get("model") or "gpt-3.5-turbo"),
|
||||||
|
"messages": messages,
|
||||||
|
"temperature": kwargs.get("temperature", conf().get("temperature", 0.9)),
|
||||||
|
"top_p": kwargs.get("top_p", 1),
|
||||||
|
"frequency_penalty": kwargs.get("frequency_penalty", conf().get("frequency_penalty", 0.0)),
|
||||||
|
"presence_penalty": kwargs.get("presence_penalty", conf().get("presence_penalty", 0.0)),
|
||||||
|
"stream": stream
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add max_tokens if specified
|
||||||
|
if kwargs.get("max_tokens"):
|
||||||
|
request_params["max_tokens"] = kwargs["max_tokens"]
|
||||||
|
|
||||||
|
# Add tools if provided
|
||||||
|
if tools:
|
||||||
|
request_params["tools"] = tools
|
||||||
|
request_params["tool_choice"] = kwargs.get("tool_choice", "auto")
|
||||||
|
|
||||||
|
# Make API call using ChatCompletion
|
||||||
|
if stream:
|
||||||
|
return self._handle_stream_response(request_params)
|
||||||
|
else:
|
||||||
|
return self._handle_sync_response(request_params)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"[OPEN_AI] call_with_tools error: {e}")
|
||||||
|
if stream:
|
||||||
|
def error_generator():
|
||||||
|
yield {
|
||||||
|
"error": True,
|
||||||
|
"message": str(e),
|
||||||
|
"status_code": 500
|
||||||
|
}
|
||||||
|
return error_generator()
|
||||||
|
else:
|
||||||
|
return {
|
||||||
|
"error": True,
|
||||||
|
"message": str(e),
|
||||||
|
"status_code": 500
|
||||||
|
}
|
||||||
|
|
||||||
|
def _handle_sync_response(self, request_params):
|
||||||
|
"""Handle synchronous OpenAI ChatCompletion API response"""
|
||||||
|
try:
|
||||||
|
response = openai.ChatCompletion.create(**request_params)
|
||||||
|
|
||||||
|
logger.info(f"[OPEN_AI] call_with_tools reply, model={response.get('model')}, "
|
||||||
|
f"total_tokens={response.get('usage', {}).get('total_tokens', 0)}")
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"[OPEN_AI] sync response error: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def _handle_stream_response(self, request_params):
|
||||||
|
"""Handle streaming OpenAI ChatCompletion API response"""
|
||||||
|
try:
|
||||||
|
stream = openai.ChatCompletion.create(**request_params)
|
||||||
|
|
||||||
|
for chunk in stream:
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"[OPEN_AI] stream response error: {e}")
|
||||||
|
yield {
|
||||||
|
"error": True,
|
||||||
|
"message": str(e),
|
||||||
|
"status_code": 500
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import time
|
import time
|
||||||
|
|
||||||
import openai
|
import openai
|
||||||
import openai.error
|
from bot.openai.openai_compat import RateLimitError
|
||||||
|
|
||||||
from common.log import logger
|
from common.log import logger
|
||||||
from common.token_bucket import TokenBucket
|
from common.token_bucket import TokenBucket
|
||||||
@@ -30,7 +30,7 @@ class OpenAIImage(object):
|
|||||||
image_url = response["data"][0]["url"]
|
image_url = response["data"][0]["url"]
|
||||||
logger.info("[OPEN_AI] image_url={}".format(image_url))
|
logger.info("[OPEN_AI] image_url={}".format(image_url))
|
||||||
return True, image_url
|
return True, image_url
|
||||||
except openai.error.RateLimitError as e:
|
except RateLimitError as e:
|
||||||
logger.warn(e)
|
logger.warn(e)
|
||||||
if retry_count < 1:
|
if retry_count < 1:
|
||||||
time.sleep(5)
|
time.sleep(5)
|
||||||
|
|||||||
102
bot/openai/openai_compat.py
Normal file
102
bot/openai/openai_compat.py
Normal file
@@ -0,0 +1,102 @@
|
|||||||
|
"""
|
||||||
|
OpenAI compatibility layer for different versions.
|
||||||
|
|
||||||
|
This module provides a compatibility layer between OpenAI library versions:
|
||||||
|
- OpenAI < 1.0 (old API with openai.error module)
|
||||||
|
- OpenAI >= 1.0 (new API with direct exception imports)
|
||||||
|
"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Try new OpenAI >= 1.0 API
|
||||||
|
from openai import (
|
||||||
|
OpenAIError,
|
||||||
|
RateLimitError,
|
||||||
|
APIError,
|
||||||
|
APIConnectionError,
|
||||||
|
AuthenticationError,
|
||||||
|
APITimeoutError,
|
||||||
|
BadRequestError,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create a mock error module for backward compatibility
|
||||||
|
class ErrorModule:
|
||||||
|
OpenAIError = OpenAIError
|
||||||
|
RateLimitError = RateLimitError
|
||||||
|
APIError = APIError
|
||||||
|
APIConnectionError = APIConnectionError
|
||||||
|
AuthenticationError = AuthenticationError
|
||||||
|
Timeout = APITimeoutError # Renamed in new version
|
||||||
|
InvalidRequestError = BadRequestError # Renamed in new version
|
||||||
|
|
||||||
|
error = ErrorModule()
|
||||||
|
|
||||||
|
# Also export with new names
|
||||||
|
Timeout = APITimeoutError
|
||||||
|
InvalidRequestError = BadRequestError
|
||||||
|
|
||||||
|
except ImportError:
|
||||||
|
# Fall back to old OpenAI < 1.0 API
|
||||||
|
try:
|
||||||
|
import openai.error as error
|
||||||
|
|
||||||
|
# Export individual exceptions for direct import
|
||||||
|
OpenAIError = error.OpenAIError
|
||||||
|
RateLimitError = error.RateLimitError
|
||||||
|
APIError = error.APIError
|
||||||
|
APIConnectionError = error.APIConnectionError
|
||||||
|
AuthenticationError = error.AuthenticationError
|
||||||
|
InvalidRequestError = error.InvalidRequestError
|
||||||
|
Timeout = error.Timeout
|
||||||
|
BadRequestError = error.InvalidRequestError # Alias
|
||||||
|
APITimeoutError = error.Timeout # Alias
|
||||||
|
except (ImportError, AttributeError):
|
||||||
|
# Neither version works, create dummy classes
|
||||||
|
class OpenAIError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class RateLimitError(OpenAIError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class APIError(OpenAIError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class APIConnectionError(OpenAIError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class AuthenticationError(OpenAIError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class InvalidRequestError(OpenAIError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class Timeout(OpenAIError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
BadRequestError = InvalidRequestError
|
||||||
|
APITimeoutError = Timeout
|
||||||
|
|
||||||
|
# Create error module
|
||||||
|
class ErrorModule:
|
||||||
|
OpenAIError = OpenAIError
|
||||||
|
RateLimitError = RateLimitError
|
||||||
|
APIError = APIError
|
||||||
|
APIConnectionError = APIConnectionError
|
||||||
|
AuthenticationError = AuthenticationError
|
||||||
|
InvalidRequestError = InvalidRequestError
|
||||||
|
Timeout = Timeout
|
||||||
|
|
||||||
|
error = ErrorModule()
|
||||||
|
|
||||||
|
# Export all for easy import
|
||||||
|
__all__ = [
|
||||||
|
'error',
|
||||||
|
'OpenAIError',
|
||||||
|
'RateLimitError',
|
||||||
|
'APIError',
|
||||||
|
'APIConnectionError',
|
||||||
|
'AuthenticationError',
|
||||||
|
'InvalidRequestError',
|
||||||
|
'Timeout',
|
||||||
|
'BadRequestError',
|
||||||
|
'APITimeoutError',
|
||||||
|
]
|
||||||
@@ -5,7 +5,6 @@ Agent Bridge - Integrates Agent system with existing COW bridge
|
|||||||
from typing import Optional, List
|
from typing import Optional, List
|
||||||
|
|
||||||
from agent.protocol import Agent, LLMModel, LLMRequest
|
from agent.protocol import Agent, LLMModel, LLMRequest
|
||||||
from agent.tools import Calculator, CurrentTime, Read, Write, Edit, Bash, Grep, Find, Ls
|
|
||||||
from bridge.bridge import Bridge
|
from bridge.bridge import Bridge
|
||||||
from bridge.context import Context
|
from bridge.context import Context
|
||||||
from bridge.reply import Reply, ReplyType
|
from bridge.reply import Reply, ReplyType
|
||||||
@@ -89,14 +88,15 @@ class AgentLLMModel(LLMModel):
|
|||||||
|
|
||||||
stream = self.bot.call_with_tools(**kwargs)
|
stream = self.bot.call_with_tools(**kwargs)
|
||||||
|
|
||||||
# Convert Claude stream format to our expected format
|
# Convert stream format to our expected format
|
||||||
for chunk in stream:
|
for chunk in stream:
|
||||||
yield self._format_stream_chunk(chunk)
|
yield self._format_stream_chunk(chunk)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError("Streaming call not implemented yet")
|
bot_type = type(self.bot).__name__
|
||||||
|
raise NotImplementedError(f"Bot {bot_type} does not support call_with_tools. Please add the method.")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"AgentLLMModel call_stream error: {e}")
|
logger.error(f"AgentLLMModel call_stream error: {e}", exc_info=True)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def _format_response(self, response):
|
def _format_response(self, response):
|
||||||
@@ -136,17 +136,19 @@ class AgentBridge:
|
|||||||
|
|
||||||
# Default tools if none provided
|
# Default tools if none provided
|
||||||
if tools is None:
|
if tools is None:
|
||||||
tools = [
|
# Use ToolManager to load all available tools
|
||||||
Calculator(),
|
from agent.tools import ToolManager
|
||||||
CurrentTime(),
|
tool_manager = ToolManager()
|
||||||
Read(),
|
tool_manager.load_tools()
|
||||||
Write(),
|
|
||||||
Edit(),
|
tools = []
|
||||||
Bash(),
|
for tool_name in tool_manager.tool_classes.keys():
|
||||||
Grep(),
|
try:
|
||||||
Find(),
|
tool = tool_manager.create_tool(tool_name)
|
||||||
Ls()
|
if tool:
|
||||||
]
|
tools.append(tool)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"[AgentBridge] Failed to load tool {tool_name}: {e}")
|
||||||
|
|
||||||
# Create the single super agent
|
# Create the single super agent
|
||||||
self.agent = Agent(
|
self.agent = Agent(
|
||||||
@@ -222,19 +224,26 @@ class AgentBridge:
|
|||||||
# Configure file tools to work in the correct workspace
|
# Configure file tools to work in the correct workspace
|
||||||
file_config = {"cwd": workspace_root} if memory_manager else {}
|
file_config = {"cwd": workspace_root} if memory_manager else {}
|
||||||
|
|
||||||
# Create default tools with workspace config
|
# Use ToolManager to dynamically load all available tools
|
||||||
from agent.tools import Calculator, CurrentTime, Read, Write, Edit, Bash, Grep, Find, Ls
|
from agent.tools import ToolManager
|
||||||
tools = [
|
tool_manager = ToolManager()
|
||||||
Calculator(),
|
tool_manager.load_tools()
|
||||||
CurrentTime(),
|
|
||||||
Read(config=file_config),
|
# Create tool instances for all available tools
|
||||||
Write(config=file_config),
|
tools = []
|
||||||
Edit(config=file_config),
|
for tool_name in tool_manager.tool_classes.keys():
|
||||||
Bash(config=file_config),
|
try:
|
||||||
Grep(config=file_config),
|
tool = tool_manager.create_tool(tool_name)
|
||||||
Find(config=file_config),
|
if tool:
|
||||||
Ls(config=file_config)
|
# Apply workspace config to file operation tools
|
||||||
]
|
if tool_name in ['read', 'write', 'edit', 'bash', 'grep', 'find', 'ls']:
|
||||||
|
tool.config = file_config
|
||||||
|
tools.append(tool)
|
||||||
|
logger.debug(f"[AgentBridge] Loaded tool: {tool_name}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"[AgentBridge] Failed to load tool {tool_name}: {e}")
|
||||||
|
|
||||||
|
logger.info(f"[AgentBridge] Loaded {len(tools)} tools: {[t.name for t in tools]}")
|
||||||
|
|
||||||
# Create agent with configured tools
|
# Create agent with configured tools
|
||||||
agent = self.create_agent(
|
agent = self.create_agent(
|
||||||
|
|||||||
@@ -195,7 +195,7 @@ class WebChannel(ChatChannel):
|
|||||||
5. wechatcom_app: 企微自建应用
|
5. wechatcom_app: 企微自建应用
|
||||||
6. dingtalk: 钉钉
|
6. dingtalk: 钉钉
|
||||||
7. feishu: 飞书""")
|
7. feishu: 飞书""")
|
||||||
logger.info(f"Web对话网页已运行, 请使用浏览器访问 http://localhost:{port}/chat(本地运行)或 http://ip:{port}/chat(服务器运行) ")
|
logger.info(f"Web对话网页已运行, 请使用浏览器访问 http://localhost:{port}/chat (本地运行) 或 http://ip:{port}/chat (服务器运行)")
|
||||||
|
|
||||||
# 确保静态文件目录存在
|
# 确保静态文件目录存在
|
||||||
static_dir = os.path.join(os.path.dirname(__file__), 'static')
|
static_dir = os.path.join(os.path.dirname(__file__), 'static')
|
||||||
|
|||||||
@@ -3,7 +3,6 @@
|
|||||||
"model": "",
|
"model": "",
|
||||||
"open_ai_api_key": "YOUR API KEY",
|
"open_ai_api_key": "YOUR API KEY",
|
||||||
"claude_api_key": "YOUR API KEY",
|
"claude_api_key": "YOUR API KEY",
|
||||||
"claude_api_base": "https://api.anthropic.com",
|
|
||||||
"text_to_image": "dall-e-2",
|
"text_to_image": "dall-e-2",
|
||||||
"voice_to_text": "openai",
|
"voice_to_text": "openai",
|
||||||
"text_to_voice": "openai",
|
"text_to_voice": "openai",
|
||||||
|
|||||||
124
skills/README.md
Normal file
124
skills/README.md
Normal file
@@ -0,0 +1,124 @@
|
|||||||
|
# Skills Directory
|
||||||
|
|
||||||
|
This directory contains skills for the COW agent system. Skills are markdown files that provide specialized instructions for specific tasks.
|
||||||
|
|
||||||
|
## What are Skills?
|
||||||
|
|
||||||
|
Skills are reusable instruction sets that help the agent perform specific tasks more effectively. Each skill:
|
||||||
|
|
||||||
|
- Provides context-specific guidance
|
||||||
|
- Documents best practices
|
||||||
|
- Includes examples and usage patterns
|
||||||
|
- Can have requirements (binaries, environment variables, etc.)
|
||||||
|
|
||||||
|
## Skill Structure
|
||||||
|
|
||||||
|
Each skill is a markdown file (`SKILL.md`) in its own directory with frontmatter:
|
||||||
|
|
||||||
|
```markdown
|
||||||
|
---
|
||||||
|
name: skill-name
|
||||||
|
description: Brief description of what the skill does
|
||||||
|
metadata: {"cow":{"emoji":"🎯","requires":{"bins":["tool"]}}}
|
||||||
|
---
|
||||||
|
|
||||||
|
# Skill Name
|
||||||
|
|
||||||
|
Detailed instructions and examples...
|
||||||
|
```
|
||||||
|
|
||||||
|
## Available Skills
|
||||||
|
|
||||||
|
- **calculator**: Mathematical calculations and expressions
|
||||||
|
- **web-search**: Search the web for current information
|
||||||
|
- **file-operations**: Read, write, and manage files
|
||||||
|
|
||||||
|
## Creating Custom Skills
|
||||||
|
|
||||||
|
To create a new skill:
|
||||||
|
|
||||||
|
1. Create a directory: `skills/my-skill/`
|
||||||
|
2. Create `SKILL.md` with frontmatter and content
|
||||||
|
3. Restart the agent to load the new skill
|
||||||
|
|
||||||
|
### Frontmatter Fields
|
||||||
|
|
||||||
|
- `name`: Skill name (must match directory name)
|
||||||
|
- `description`: Brief description (required)
|
||||||
|
- `metadata`: JSON object with additional configuration
|
||||||
|
- `emoji`: Display emoji
|
||||||
|
- `always`: Always include this skill (default: false)
|
||||||
|
- `primaryEnv`: Primary environment variable needed
|
||||||
|
- `os`: Supported operating systems (e.g., ["darwin", "linux"])
|
||||||
|
- `requires`: Requirements object
|
||||||
|
- `bins`: Required binaries
|
||||||
|
- `env`: Required environment variables
|
||||||
|
- `config`: Required config paths
|
||||||
|
- `disable-model-invocation`: If true, skill won't be shown to model (default: false)
|
||||||
|
- `user-invocable`: If false, users can't invoke directly (default: true)
|
||||||
|
|
||||||
|
### Example Skill
|
||||||
|
|
||||||
|
```markdown
|
||||||
|
---
|
||||||
|
name: my-tool
|
||||||
|
description: Use my-tool to process data
|
||||||
|
metadata: {"cow":{"emoji":"🔧","requires":{"bins":["my-tool"],"env":["MY_TOOL_API_KEY"]}}}
|
||||||
|
---
|
||||||
|
|
||||||
|
# My Tool Skill
|
||||||
|
|
||||||
|
Use this skill when you need to process data with my-tool.
|
||||||
|
|
||||||
|
## Prerequisites
|
||||||
|
|
||||||
|
- Install my-tool: `pip install my-tool`
|
||||||
|
- Set `MY_TOOL_API_KEY` environment variable
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
\`\`\`python
|
||||||
|
# Example usage
|
||||||
|
my_tool_command("input data")
|
||||||
|
\`\`\`
|
||||||
|
```
|
||||||
|
|
||||||
|
## Skill Loading
|
||||||
|
|
||||||
|
Skills are loaded from multiple locations with precedence:
|
||||||
|
|
||||||
|
1. **Workspace skills** (highest): `workspace/skills/` - Project-specific skills
|
||||||
|
2. **Managed skills**: `~/.cow/skills/` - User-installed skills
|
||||||
|
3. **Bundled skills** (lowest): Built-in skills
|
||||||
|
|
||||||
|
Skills with the same name in higher-precedence locations override lower ones.
|
||||||
|
|
||||||
|
## Skill Requirements
|
||||||
|
|
||||||
|
Skills can specify requirements that determine when they're available:
|
||||||
|
|
||||||
|
- **OS requirements**: Only load on specific operating systems
|
||||||
|
- **Binary requirements**: Only load if required binaries are installed
|
||||||
|
- **Environment variables**: Only load if required env vars are set
|
||||||
|
- **Config requirements**: Only load if config values are set
|
||||||
|
|
||||||
|
## Best Practices
|
||||||
|
|
||||||
|
1. **Clear descriptions**: Write clear, concise skill descriptions
|
||||||
|
2. **Include examples**: Provide practical usage examples
|
||||||
|
3. **Document prerequisites**: List all requirements clearly
|
||||||
|
4. **Use appropriate metadata**: Set correct requirements and flags
|
||||||
|
5. **Keep skills focused**: Each skill should have a single, clear purpose
|
||||||
|
|
||||||
|
## Workspace Skills
|
||||||
|
|
||||||
|
You can create workspace-specific skills in your agent's workspace:
|
||||||
|
|
||||||
|
```
|
||||||
|
workspace/
|
||||||
|
skills/
|
||||||
|
custom-skill/
|
||||||
|
SKILL.md
|
||||||
|
```
|
||||||
|
|
||||||
|
These skills are only available when working in that specific workspace.
|
||||||
Reference in New Issue
Block a user