diff --git a/agent/tools/__init__.py b/agent/tools/__init__.py index 8bfc6de9..78e7b979 100644 --- a/agent/tools/__init__.py +++ b/agent/tools/__init__.py @@ -107,6 +107,22 @@ def _import_browser_tool(): BrowserTool = _import_browser_tool() +# MCP Tools (no extra dependencies, loaded on demand) +def _import_mcp_tools(): + """导入 MCP 工具模块(无额外依赖,按需加载)""" + from common.log import logger + try: + from agent.tools.mcp.mcp_tool import McpTool + from agent.tools.mcp.mcp_client import McpClientRegistry + return {'McpTool': McpTool, 'McpClientRegistry': McpClientRegistry} + except Exception as e: + logger.warning(f"[Tools] MCP tools not loaded: {e}") + return {} + +_mcp_tools = _import_mcp_tools() +McpTool = _mcp_tools.get('McpTool') +McpClientRegistry = _mcp_tools.get('McpClientRegistry') + # Export all tools (including optional ones that might be None) __all__ = [ 'BaseTool', @@ -125,6 +141,7 @@ __all__ = [ 'WebFetch', 'Vision', 'BrowserTool', + 'McpTool', ] """ diff --git a/agent/tools/mcp/__init__.py b/agent/tools/mcp/__init__.py new file mode 100644 index 00000000..cd4065ad --- /dev/null +++ b/agent/tools/mcp/__init__.py @@ -0,0 +1,4 @@ +from agent.tools.mcp.mcp_client import McpClient, McpClientRegistry +from agent.tools.mcp.mcp_tool import McpTool + +__all__ = ["McpClient", "McpClientRegistry", "McpTool"] diff --git a/agent/tools/mcp/mcp_client.py b/agent/tools/mcp/mcp_client.py new file mode 100644 index 00000000..3148a74b --- /dev/null +++ b/agent/tools/mcp/mcp_client.py @@ -0,0 +1,344 @@ +""" +MCP (Model Context Protocol) client module. + +Implements JSON-RPC 2.0 over stdio and SSE transports without any external +MCP SDK dependency. +""" + +import json +import subprocess +import threading +import urllib.request +import urllib.error +from typing import Optional + +from common.log import logger + + +class McpClient: + """Single MCP Server client supporting stdio and SSE transports.""" + + def __init__(self, config: dict): + """ + config examples: + stdio: {"name": "filesystem", "type": "stdio", "command": "npx", "args": [...]} + SSE: {"name": "my-api", "type": "sse", "url": "http://localhost:8000/sse"} + """ + self.config = config + self.name: str = config.get("name", "unknown") + self.transport: str = config.get("type", "stdio") + + # stdio state + self._proc: Optional[subprocess.Popen] = None + + # SSE state + self._sse_url: Optional[str] = None + self._post_url: Optional[str] = None # endpoint for sending messages (resolved from SSE) + + # Shared state + self._next_id = 1 + self._id_lock = threading.Lock() + self._initialized = False + + # ------------------------------------------------------------------ + # Public interface + # ------------------------------------------------------------------ + + def initialize(self) -> bool: + """Connect and perform the MCP handshake. Returns True on success.""" + try: + if self.transport == "stdio": + return self._init_stdio() + elif self.transport == "sse": + return self._init_sse() + else: + logger.warning(f"[MCP:{self.name}] Unknown transport type: {self.transport!r}") + return False + except Exception as e: + logger.warning(f"[MCP:{self.name}] Initialization failed: {e}") + return False + + def list_tools(self) -> list: + """Return the tool list from this server. + + Each item is a dict: {"name": str, "description": str, "inputSchema": dict} + """ + try: + resp = self._send_request("tools/list", {}) + tools = resp.get("result", {}).get("tools", []) + return [ + { + "name": t.get("name", ""), + "description": t.get("description", ""), + "inputSchema": t.get("inputSchema", {}), + } + for t in tools + ] + except Exception as e: + logger.warning(f"[MCP:{self.name}] list_tools failed: {e}") + return [] + + def call_tool(self, name: str, arguments: dict) -> str: + """Call a tool and return the result as a string.""" + try: + resp = self._send_request("tools/call", {"name": name, "arguments": arguments}) + content = resp.get("result", {}).get("content", []) + parts = [item.get("text", "") for item in content if item.get("type") == "text"] + return "\n".join(parts) + except Exception as e: + logger.warning(f"[MCP:{self.name}] call_tool({name}) failed: {e}") + return f"Error: {e}" + + def shutdown(self): + """Close the connection / terminate the child process.""" + if self._proc is not None: + try: + self._proc.stdin.close() + except Exception: + pass + try: + self._proc.terminate() + self._proc.wait(timeout=5) + except Exception: + try: + self._proc.kill() + except Exception: + pass + self._proc = None + logger.debug(f"[MCP:{self.name}] stdio process terminated") + self._initialized = False + + # ------------------------------------------------------------------ + # stdio transport + # ------------------------------------------------------------------ + + def _init_stdio(self) -> bool: + command = self.config.get("command") + if not command: + logger.warning(f"[MCP:{self.name}] stdio config missing 'command'") + return False + + args = self.config.get("args", []) + env = self.config.get("env", None) + + self._proc = subprocess.Popen( + [command] + list(args), + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + encoding="utf-8", + env=env, + ) + logger.debug(f"[MCP:{self.name}] stdio process started (pid={self._proc.pid})") + return self._handshake() + + def _stdio_send(self, message: dict) -> dict: + """Send a JSON-RPC message over stdio and read the response.""" + raw = json.dumps(message) + "\n" + self._proc.stdin.write(raw) + self._proc.stdin.flush() + + while True: + line = self._proc.stdout.readline() + if not line: + raise IOError(f"[MCP:{self.name}] stdio process closed unexpectedly") + line = line.strip() + if not line: + continue + return json.loads(line) + + # ------------------------------------------------------------------ + # SSE transport + # ------------------------------------------------------------------ + + def _init_sse(self) -> bool: + url = self.config.get("url") + if not url: + logger.warning(f"[MCP:{self.name}] SSE config missing 'url'") + return False + + self._sse_url = url + + # Read the first SSE event to discover the POST endpoint + try: + self._post_url = self._sse_discover_endpoint() + except Exception as e: + logger.warning(f"[MCP:{self.name}] SSE endpoint discovery failed: {e}") + return False + + return self._handshake() + + def _sse_discover_endpoint(self) -> str: + """Open SSE stream and read the 'endpoint' event to learn the POST URL.""" + req = urllib.request.Request( + self._sse_url, + headers={"Accept": "text/event-stream"}, + ) + with urllib.request.urlopen(req, timeout=10) as resp: + for raw_line in resp: + line = raw_line.decode("utf-8").rstrip("\n\r") + if line.startswith("data:"): + data = line[len("data:"):].strip() + # Some servers send JSON with a "uri" or plain path + if data.startswith("{"): + parsed = json.loads(data) + return parsed.get("uri") or parsed.get("url") or parsed.get("endpoint") + # Plain relative or absolute URL + if data.startswith("http"): + return data + # Relative path: resolve against SSE base + from urllib.parse import urljoin + return urljoin(self._sse_url, data) + raise ValueError(f"[MCP:{self.name}] No endpoint event received from SSE stream") + + def _sse_send(self, message: dict) -> dict: + """POST a JSON-RPC message to the server and return the response.""" + body = json.dumps(message).encode("utf-8") + req = urllib.request.Request( + self._post_url, + data=body, + method="POST", + headers={"Content-Type": "application/json"}, + ) + with urllib.request.urlopen(req, timeout=30) as resp: + raw = resp.read().decode("utf-8") + return json.loads(raw) + + # ------------------------------------------------------------------ + # Common JSON-RPC helpers + # ------------------------------------------------------------------ + + def _next_request_id(self) -> int: + with self._id_lock: + rid = self._next_id + self._next_id += 1 + return rid + + def _build_request(self, method: str, params: dict) -> dict: + return { + "jsonrpc": "2.0", + "id": self._next_request_id(), + "method": method, + "params": params, + } + + def _build_notification(self, method: str, params: dict) -> dict: + return {"jsonrpc": "2.0", "method": method, "params": params} + + def _send_request(self, method: str, params: dict) -> dict: + """Send a request and return the full response dict.""" + if not self._initialized and method != "initialize": + raise RuntimeError(f"[MCP:{self.name}] Client not initialized") + + message = self._build_request(method, params) + + if self.transport == "stdio": + return self._stdio_send(message) + elif self.transport == "sse": + return self._sse_send(message) + else: + raise ValueError(f"[MCP:{self.name}] Unsupported transport: {self.transport}") + + def _send_notification(self, method: str, params: dict): + """Fire-and-forget notification (no response expected).""" + notification = self._build_notification(method, params) + raw = json.dumps(notification) + "\n" + + if self.transport == "stdio": + self._proc.stdin.write(raw) + self._proc.stdin.flush() + elif self.transport == "sse": + body = raw.encode("utf-8") + req = urllib.request.Request( + self._post_url, + data=body, + method="POST", + headers={"Content-Type": "application/json"}, + ) + try: + with urllib.request.urlopen(req, timeout=10): + pass + except Exception: + pass # notifications are fire-and-forget + + def _handshake(self) -> bool: + """Perform the MCP initialize / notifications/initialized handshake.""" + init_params = { + "protocolVersion": "2024-11-05", + "capabilities": {}, + "clientInfo": {"name": "CowAgent", "version": "1.0"}, + } + # Temporarily mark as initialized so _send_request doesn't block + self._initialized = True + try: + resp = self._send_request("initialize", init_params) + except Exception as e: + self._initialized = False + logger.warning(f"[MCP:{self.name}] Handshake initialize failed: {e}") + return False + + if "error" in resp: + self._initialized = False + logger.warning(f"[MCP:{self.name}] Handshake error: {resp['error']}") + return False + + self._send_notification("notifications/initialized", {}) + logger.debug(f"[MCP:{self.name}] Handshake complete") + return True + + +class McpClientRegistry: + """Global singleton managing the lifecycle of all MCP Server clients.""" + + _instance = None + _instance_lock = threading.Lock() + + def __new__(cls): + with cls._instance_lock: + if cls._instance is None: + obj = super().__new__(cls) + obj._clients: dict[str, McpClient] = {} + obj._registry_lock = threading.Lock() + cls._instance = obj + return cls._instance + + def start_all(self, configs: list) -> None: + """Initialize McpClient for each config entry; skip failures with a warning.""" + if not configs: + return + + for cfg in configs: + name = cfg.get("name", "") + client = McpClient(cfg) + ok = client.initialize() + if ok: + with self._registry_lock: + self._clients[name] = client + logger.info(f"[MCP] Server '{name}' initialized successfully") + else: + logger.warning(f"[MCP] Server '{name}' failed to initialize — skipping") + + def get(self, server_name: str) -> Optional[McpClient]: + """Return the initialized client for server_name, or None.""" + with self._registry_lock: + return self._clients.get(server_name) + + def all_clients(self) -> dict: + """Return a copy of the {name: McpClient} mapping.""" + with self._registry_lock: + return dict(self._clients) + + def shutdown_all(self) -> None: + """Shut down all managed clients.""" + with self._registry_lock: + clients = list(self._clients.values()) + self._clients.clear() + + for client in clients: + try: + client.shutdown() + except Exception as e: + logger.warning(f"[MCP] Error shutting down '{client.name}': {e}") + + logger.info("[MCP] All servers shut down") diff --git a/agent/tools/mcp/mcp_tool.py b/agent/tools/mcp/mcp_tool.py new file mode 100644 index 00000000..ef1e814d --- /dev/null +++ b/agent/tools/mcp/mcp_tool.py @@ -0,0 +1,31 @@ +from agent.tools.base_tool import BaseTool, ToolResult +from common.log import logger + + +class McpTool(BaseTool): + """ + 将单个 MCP 工具包装为 BaseTool。 + 一个 MCP Server 可以提供多个工具,每个工具对应一个 McpTool 实例。 + """ + + def __init__(self, client, tool_schema: dict, server_name: str): + """ + :param client: 该工具所属的 McpClient 实例 + :param tool_schema: MCP 返回的工具描述,格式: + {"name": str, "description": str, "inputSchema": dict} + :param server_name: Server 名称,用于日志 + """ + self.client = client + self.server_name = server_name + self.name = tool_schema["name"] + self.description = tool_schema.get("description", "") + self.params = tool_schema.get("inputSchema", {}) + + def execute(self, params: dict) -> ToolResult: + logger.info(f"[McpTool] server={self.server_name} tool={self.name} params={params}") + try: + result = self.client.call_tool(self.name, params) + return ToolResult.success(result) + except Exception as e: + logger.error(f"[McpTool] server={self.server_name} tool={self.name} error: {e}") + return ToolResult.fail(str(e)) diff --git a/agent/tools/tool_manager.py b/agent/tools/tool_manager.py index 929d60a1..4a40474e 100644 --- a/agent/tools/tool_manager.py +++ b/agent/tools/tool_manager.py @@ -25,6 +25,10 @@ class ToolManager: # Initialize only once if not hasattr(self, 'tool_classes'): self.tool_classes = {} # Dictionary to store tool classes + if not hasattr(self, '_mcp_registry'): + self._mcp_registry = None # 懒初始化,有配置时才创建 + if not hasattr(self, '_mcp_tool_instances'): + self._mcp_tool_instances: dict = {} # tool_name -> McpTool instance def load_tools(self, tools_dir: str = "", config_dict=None): """ @@ -39,6 +43,8 @@ class ToolManager: self._load_tools_from_init() self._configure_tools_from_config(config_dict) + self._load_mcp_tools() + def _load_tools_from_init(self) -> bool: """ Load tool classes from tools.__init__.__all__ @@ -70,10 +76,14 @@ class ToolManager: and cls != BaseTool ): try: - # Skip memory tools (they need special initialization with memory_manager) + # Skip tools that need special initialization if class_name in ["MemorySearchTool", "MemoryGetTool"]: logger.debug(f"Skipped tool {class_name} (requires memory_manager)") continue + # McpTool instances are registered dynamically via _load_mcp_tools() + if class_name == "McpTool": + logger.debug(f"Skipped tool {class_name} (registered dynamically via mcp_servers config)") + continue # Create a temporary instance to get the name temp_instance = cls() @@ -212,6 +222,36 @@ class ToolManager: except Exception as e: logger.error(f"Error configuring tools from config: {e}") + def _load_mcp_tools(self): + """Load MCP tools from mcp_servers config. Failures are non-fatal.""" + try: + mcp_servers_config = conf().get("mcp_servers", []) + if not mcp_servers_config: + return + + from agent.tools.mcp.mcp_client import McpClientRegistry + from agent.tools.mcp.mcp_tool import McpTool + + self._mcp_registry = McpClientRegistry() + self._mcp_registry.start_all(mcp_servers_config) + + for server_name, client in self._mcp_registry.all_clients().items(): + try: + tool_schemas = client.list_tools() + for schema in tool_schemas: + tool_name = schema.get("name", "") + if not tool_name: + continue + mcp_tool = McpTool(client, schema, server_name) + self._mcp_tool_instances[tool_name] = mcp_tool + logger.debug(f"[ToolManager] Loaded MCP tool: {tool_name} from server '{server_name}'") + except Exception as e: + logger.warning(f"[ToolManager] Failed to list tools from MCP server '{server_name}': {e}") + + logger.info(f"[ToolManager] Loaded {len(self._mcp_tool_instances)} MCP tool(s) in total") + except Exception as e: + logger.warning(f"[ToolManager] MCP tool loading failed, skipping: {e}") + def create_tool(self, name: str) -> BaseTool: """ Get a new instance of a tool by name. @@ -229,6 +269,12 @@ class ToolManager: tool_instance.config = self.tool_configs[name] return tool_instance + + # Fall back to MCP tool instances + mcp_tool = self._mcp_tool_instances.get(name) + if mcp_tool: + return mcp_tool + return None def list_tools(self) -> dict: @@ -245,4 +291,17 @@ class ToolManager: "description": temp_instance.description, "parameters": temp_instance.get_json_schema() } + + # Include MCP tool instances + for name, mcp_tool in self._mcp_tool_instances.items(): + result[name] = { + "description": mcp_tool.description, + "parameters": mcp_tool.params, + } + return result + + def shutdown_mcp(self): + """Shut down all MCP server clients.""" + if self._mcp_registry: + self._mcp_registry.shutdown_all() diff --git a/config.py b/config.py index 156c26c1..76e19c50 100644 --- a/config.py +++ b/config.py @@ -219,6 +219,19 @@ available_setting = { # using the rule: skill[][] -> SKILL__ # (e.g. skill["image-generation"].model -> SKILL_IMAGE_GENERATION_MODEL). "skill": {}, + # MCP (Model Context Protocol) server list. + # Each entry describes one MCP server to connect at startup. + # Supported types: + # stdio — launch a local process and communicate over stdin/stdout + # sse — connect to a remote server via HTTP + Server-Sent Events + # + # Example: + # "mcp_servers": [ + # {"name": "filesystem", "type": "stdio", "command": "npx", + # "args": ["-y", "@modelcontextprotocol/server-filesystem", "/tmp"]}, + # {"name": "my-api", "type": "sse", "url": "http://localhost:8000/sse"} + # ] + "mcp_servers": [], }