mirror of
https://github.com/zhayujie/chatgpt-on-wechat.git
synced 2026-06-02 00:57:41 +08:00
Allows CowAgent to dynamically load tools from any MCP server at startup,
extending the agent from a fixed toolset to an open, extensible tool ecosystem.
## What's added
- `agent/tools/mcp/mcp_client.py`: lightweight JSON-RPC client supporting both
stdio (subprocess) and SSE (HTTP) transports — zero extra dependencies
- `agent/tools/mcp/mcp_tool.py`: `McpTool` wraps a single MCP tool as a
`BaseTool`, with dynamic name/description/params set at instance level
- `agent/tools/tool_manager.py`: new `_load_mcp_tools()` loads MCP servers at
startup via `McpClientRegistry`; falls back gracefully on any error; no-op
when `mcp_servers` is not configured
- `config.py`: registers `mcp_servers` in `available_setting` with inline docs
## Design
- No new dependencies — JSON-RPC implemented from scratch using stdlib only
- MCP clients are long-lived (initialized once, shared across tool calls)
- `McpClientRegistry` holds all subprocess handles and shuts them down cleanly
- Server init failures are non-fatal: logged as warnings, agent continues normally
- Zero overhead when `mcp_servers` is absent from config
## Config example
```json
"mcp_servers": [
{
"name": "filesystem",
"type": "stdio",
"command": "npx",
"args": ["-y", "@modelcontextprotocol/server-filesystem", "/tmp"]
}
]
```
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
345 lines
12 KiB
Python
345 lines
12 KiB
Python
"""
|
|
MCP (Model Context Protocol) client module.
|
|
|
|
Implements JSON-RPC 2.0 over stdio and SSE transports without any external
|
|
MCP SDK dependency.
|
|
"""
|
|
|
|
import json
|
|
import subprocess
|
|
import threading
|
|
import urllib.request
|
|
import urllib.error
|
|
from typing import Optional
|
|
|
|
from common.log import logger
|
|
|
|
|
|
class McpClient:
|
|
"""Single MCP Server client supporting stdio and SSE transports."""
|
|
|
|
def __init__(self, config: dict):
|
|
"""
|
|
config examples:
|
|
stdio: {"name": "filesystem", "type": "stdio", "command": "npx", "args": [...]}
|
|
SSE: {"name": "my-api", "type": "sse", "url": "http://localhost:8000/sse"}
|
|
"""
|
|
self.config = config
|
|
self.name: str = config.get("name", "unknown")
|
|
self.transport: str = config.get("type", "stdio")
|
|
|
|
# stdio state
|
|
self._proc: Optional[subprocess.Popen] = None
|
|
|
|
# SSE state
|
|
self._sse_url: Optional[str] = None
|
|
self._post_url: Optional[str] = None # endpoint for sending messages (resolved from SSE)
|
|
|
|
# Shared state
|
|
self._next_id = 1
|
|
self._id_lock = threading.Lock()
|
|
self._initialized = False
|
|
|
|
# ------------------------------------------------------------------
|
|
# Public interface
|
|
# ------------------------------------------------------------------
|
|
|
|
def initialize(self) -> bool:
|
|
"""Connect and perform the MCP handshake. Returns True on success."""
|
|
try:
|
|
if self.transport == "stdio":
|
|
return self._init_stdio()
|
|
elif self.transport == "sse":
|
|
return self._init_sse()
|
|
else:
|
|
logger.warning(f"[MCP:{self.name}] Unknown transport type: {self.transport!r}")
|
|
return False
|
|
except Exception as e:
|
|
logger.warning(f"[MCP:{self.name}] Initialization failed: {e}")
|
|
return False
|
|
|
|
def list_tools(self) -> list:
|
|
"""Return the tool list from this server.
|
|
|
|
Each item is a dict: {"name": str, "description": str, "inputSchema": dict}
|
|
"""
|
|
try:
|
|
resp = self._send_request("tools/list", {})
|
|
tools = resp.get("result", {}).get("tools", [])
|
|
return [
|
|
{
|
|
"name": t.get("name", ""),
|
|
"description": t.get("description", ""),
|
|
"inputSchema": t.get("inputSchema", {}),
|
|
}
|
|
for t in tools
|
|
]
|
|
except Exception as e:
|
|
logger.warning(f"[MCP:{self.name}] list_tools failed: {e}")
|
|
return []
|
|
|
|
def call_tool(self, name: str, arguments: dict) -> str:
|
|
"""Call a tool and return the result as a string."""
|
|
try:
|
|
resp = self._send_request("tools/call", {"name": name, "arguments": arguments})
|
|
content = resp.get("result", {}).get("content", [])
|
|
parts = [item.get("text", "") for item in content if item.get("type") == "text"]
|
|
return "\n".join(parts)
|
|
except Exception as e:
|
|
logger.warning(f"[MCP:{self.name}] call_tool({name}) failed: {e}")
|
|
return f"Error: {e}"
|
|
|
|
def shutdown(self):
|
|
"""Close the connection / terminate the child process."""
|
|
if self._proc is not None:
|
|
try:
|
|
self._proc.stdin.close()
|
|
except Exception:
|
|
pass
|
|
try:
|
|
self._proc.terminate()
|
|
self._proc.wait(timeout=5)
|
|
except Exception:
|
|
try:
|
|
self._proc.kill()
|
|
except Exception:
|
|
pass
|
|
self._proc = None
|
|
logger.debug(f"[MCP:{self.name}] stdio process terminated")
|
|
self._initialized = False
|
|
|
|
# ------------------------------------------------------------------
|
|
# stdio transport
|
|
# ------------------------------------------------------------------
|
|
|
|
def _init_stdio(self) -> bool:
|
|
command = self.config.get("command")
|
|
if not command:
|
|
logger.warning(f"[MCP:{self.name}] stdio config missing 'command'")
|
|
return False
|
|
|
|
args = self.config.get("args", [])
|
|
env = self.config.get("env", None)
|
|
|
|
self._proc = subprocess.Popen(
|
|
[command] + list(args),
|
|
stdin=subprocess.PIPE,
|
|
stdout=subprocess.PIPE,
|
|
stderr=subprocess.PIPE,
|
|
text=True,
|
|
encoding="utf-8",
|
|
env=env,
|
|
)
|
|
logger.debug(f"[MCP:{self.name}] stdio process started (pid={self._proc.pid})")
|
|
return self._handshake()
|
|
|
|
def _stdio_send(self, message: dict) -> dict:
|
|
"""Send a JSON-RPC message over stdio and read the response."""
|
|
raw = json.dumps(message) + "\n"
|
|
self._proc.stdin.write(raw)
|
|
self._proc.stdin.flush()
|
|
|
|
while True:
|
|
line = self._proc.stdout.readline()
|
|
if not line:
|
|
raise IOError(f"[MCP:{self.name}] stdio process closed unexpectedly")
|
|
line = line.strip()
|
|
if not line:
|
|
continue
|
|
return json.loads(line)
|
|
|
|
# ------------------------------------------------------------------
|
|
# SSE transport
|
|
# ------------------------------------------------------------------
|
|
|
|
def _init_sse(self) -> bool:
|
|
url = self.config.get("url")
|
|
if not url:
|
|
logger.warning(f"[MCP:{self.name}] SSE config missing 'url'")
|
|
return False
|
|
|
|
self._sse_url = url
|
|
|
|
# Read the first SSE event to discover the POST endpoint
|
|
try:
|
|
self._post_url = self._sse_discover_endpoint()
|
|
except Exception as e:
|
|
logger.warning(f"[MCP:{self.name}] SSE endpoint discovery failed: {e}")
|
|
return False
|
|
|
|
return self._handshake()
|
|
|
|
def _sse_discover_endpoint(self) -> str:
|
|
"""Open SSE stream and read the 'endpoint' event to learn the POST URL."""
|
|
req = urllib.request.Request(
|
|
self._sse_url,
|
|
headers={"Accept": "text/event-stream"},
|
|
)
|
|
with urllib.request.urlopen(req, timeout=10) as resp:
|
|
for raw_line in resp:
|
|
line = raw_line.decode("utf-8").rstrip("\n\r")
|
|
if line.startswith("data:"):
|
|
data = line[len("data:"):].strip()
|
|
# Some servers send JSON with a "uri" or plain path
|
|
if data.startswith("{"):
|
|
parsed = json.loads(data)
|
|
return parsed.get("uri") or parsed.get("url") or parsed.get("endpoint")
|
|
# Plain relative or absolute URL
|
|
if data.startswith("http"):
|
|
return data
|
|
# Relative path: resolve against SSE base
|
|
from urllib.parse import urljoin
|
|
return urljoin(self._sse_url, data)
|
|
raise ValueError(f"[MCP:{self.name}] No endpoint event received from SSE stream")
|
|
|
|
def _sse_send(self, message: dict) -> dict:
|
|
"""POST a JSON-RPC message to the server and return the response."""
|
|
body = json.dumps(message).encode("utf-8")
|
|
req = urllib.request.Request(
|
|
self._post_url,
|
|
data=body,
|
|
method="POST",
|
|
headers={"Content-Type": "application/json"},
|
|
)
|
|
with urllib.request.urlopen(req, timeout=30) as resp:
|
|
raw = resp.read().decode("utf-8")
|
|
return json.loads(raw)
|
|
|
|
# ------------------------------------------------------------------
|
|
# Common JSON-RPC helpers
|
|
# ------------------------------------------------------------------
|
|
|
|
def _next_request_id(self) -> int:
|
|
with self._id_lock:
|
|
rid = self._next_id
|
|
self._next_id += 1
|
|
return rid
|
|
|
|
def _build_request(self, method: str, params: dict) -> dict:
|
|
return {
|
|
"jsonrpc": "2.0",
|
|
"id": self._next_request_id(),
|
|
"method": method,
|
|
"params": params,
|
|
}
|
|
|
|
def _build_notification(self, method: str, params: dict) -> dict:
|
|
return {"jsonrpc": "2.0", "method": method, "params": params}
|
|
|
|
def _send_request(self, method: str, params: dict) -> dict:
|
|
"""Send a request and return the full response dict."""
|
|
if not self._initialized and method != "initialize":
|
|
raise RuntimeError(f"[MCP:{self.name}] Client not initialized")
|
|
|
|
message = self._build_request(method, params)
|
|
|
|
if self.transport == "stdio":
|
|
return self._stdio_send(message)
|
|
elif self.transport == "sse":
|
|
return self._sse_send(message)
|
|
else:
|
|
raise ValueError(f"[MCP:{self.name}] Unsupported transport: {self.transport}")
|
|
|
|
def _send_notification(self, method: str, params: dict):
|
|
"""Fire-and-forget notification (no response expected)."""
|
|
notification = self._build_notification(method, params)
|
|
raw = json.dumps(notification) + "\n"
|
|
|
|
if self.transport == "stdio":
|
|
self._proc.stdin.write(raw)
|
|
self._proc.stdin.flush()
|
|
elif self.transport == "sse":
|
|
body = raw.encode("utf-8")
|
|
req = urllib.request.Request(
|
|
self._post_url,
|
|
data=body,
|
|
method="POST",
|
|
headers={"Content-Type": "application/json"},
|
|
)
|
|
try:
|
|
with urllib.request.urlopen(req, timeout=10):
|
|
pass
|
|
except Exception:
|
|
pass # notifications are fire-and-forget
|
|
|
|
def _handshake(self) -> bool:
|
|
"""Perform the MCP initialize / notifications/initialized handshake."""
|
|
init_params = {
|
|
"protocolVersion": "2024-11-05",
|
|
"capabilities": {},
|
|
"clientInfo": {"name": "CowAgent", "version": "1.0"},
|
|
}
|
|
# Temporarily mark as initialized so _send_request doesn't block
|
|
self._initialized = True
|
|
try:
|
|
resp = self._send_request("initialize", init_params)
|
|
except Exception as e:
|
|
self._initialized = False
|
|
logger.warning(f"[MCP:{self.name}] Handshake initialize failed: {e}")
|
|
return False
|
|
|
|
if "error" in resp:
|
|
self._initialized = False
|
|
logger.warning(f"[MCP:{self.name}] Handshake error: {resp['error']}")
|
|
return False
|
|
|
|
self._send_notification("notifications/initialized", {})
|
|
logger.debug(f"[MCP:{self.name}] Handshake complete")
|
|
return True
|
|
|
|
|
|
class McpClientRegistry:
|
|
"""Global singleton managing the lifecycle of all MCP Server clients."""
|
|
|
|
_instance = None
|
|
_instance_lock = threading.Lock()
|
|
|
|
def __new__(cls):
|
|
with cls._instance_lock:
|
|
if cls._instance is None:
|
|
obj = super().__new__(cls)
|
|
obj._clients: dict[str, McpClient] = {}
|
|
obj._registry_lock = threading.Lock()
|
|
cls._instance = obj
|
|
return cls._instance
|
|
|
|
def start_all(self, configs: list) -> None:
|
|
"""Initialize McpClient for each config entry; skip failures with a warning."""
|
|
if not configs:
|
|
return
|
|
|
|
for cfg in configs:
|
|
name = cfg.get("name", "<unnamed>")
|
|
client = McpClient(cfg)
|
|
ok = client.initialize()
|
|
if ok:
|
|
with self._registry_lock:
|
|
self._clients[name] = client
|
|
logger.info(f"[MCP] Server '{name}' initialized successfully")
|
|
else:
|
|
logger.warning(f"[MCP] Server '{name}' failed to initialize — skipping")
|
|
|
|
def get(self, server_name: str) -> Optional[McpClient]:
|
|
"""Return the initialized client for server_name, or None."""
|
|
with self._registry_lock:
|
|
return self._clients.get(server_name)
|
|
|
|
def all_clients(self) -> dict:
|
|
"""Return a copy of the {name: McpClient} mapping."""
|
|
with self._registry_lock:
|
|
return dict(self._clients)
|
|
|
|
def shutdown_all(self) -> None:
|
|
"""Shut down all managed clients."""
|
|
with self._registry_lock:
|
|
clients = list(self._clients.values())
|
|
self._clients.clear()
|
|
|
|
for client in clients:
|
|
try:
|
|
client.shutdown()
|
|
except Exception as e:
|
|
logger.warning(f"[MCP] Error shutting down '{client.name}': {e}")
|
|
|
|
logger.info("[MCP] All servers shut down")
|