diff --git a/.gitignore b/.gitignore index 0612e1e3..de10c0b7 100644 --- a/.gitignore +++ b/.gitignore @@ -32,7 +32,6 @@ plugins/banwords/lib/__pycache__ !plugins/role !plugins/keyword !plugins/linkai -!plugins/agent !plugins/cow_cli client_config.json ref/ diff --git a/README.md b/README.md index 951e5e37..1edf1d3f 100644 --- a/README.md +++ b/README.md @@ -1,918 +1,257 @@ -

CowAgent

+

CowAgent

Latest release License: MIT Stars
- [中文] | [English] | [日本語] + [English] | [中文] | [日本語]

-**CowAgent** 是基于大模型的超级 AI 助理,能够主动思考和任务规划、操作计算机和外部资源、创造和执行 Skills、拥有长期记忆和知识库并不断成长,比 OpenClaw 更轻量和便捷。CowAgent 支持灵活切换多种模型,能处理文本、语音、图片、文件等多模态消息,可接入微信、飞书、钉钉、企微智能机器人、QQ、企微自建应用、微信公众号、网页中使用,7*24小时运行于你的个人电脑或服务器中。 +**CowAgent** is an open-source super AI assistant that proactively plans tasks, controls your computer and external services, creates and runs Skills, and grows alongside you through a personal knowledge base and long-term memory — a reference implementation of Agent Harness engineering. + +CowAgent is lightweight, easy to deploy, and built to extend. Plug in any major LLM provider and run it 24/7 on a personal computer or server, across the web and all major IM platforms.

- 🌐 官网  ·  - 📖 文档中心  ·  - 🚀 快速开始  ·  - 🧩 技能广场  ·  - ☁️ 在线体验 + 🌐 Website  ·  + 📖 Docs  ·  + 🚀 Quick Start  ·  + 🧩 Skill Hub  ·  + ☁️ Try Online

+
-# 简介 +## 🌟 Highlights -> 该项目既是一个可以开箱即用的超级 AI 助理,也是一个支持高扩展的 Agent 框架,可以通过为项目扩展大模型接口、接入渠道、内置工具、Skills 系统来灵活实现各种定制需求。核心能力如下: - -- ✅ **自主任务规划**:能够理解复杂任务并自主规划执行,持续思考和调用工具直到完成目标 -- ✅ **长期记忆:** 自动将对话记忆持久化至本地文件和数据库中,包括核心记忆、日级记忆和梦境蒸馏,支持关键词及向量检索 -- ✅ **个人知识库:** 自动整理结构化知识,通过交叉引用构建知识图谱,支持通过对话管理和可视化浏览知识库 -- ✅ **技能系统:** Skills 安装和运行的引擎,支持从 [Skill Hub](https://skills.cowagent.ai/)、GitHub 等一键安装技能,或通过对话创造 Skills -- ✅ **工具系统:** 内置文件读写、终端执行、浏览器操作、定时任务等工具,支持 MCP 协议,通过 Agent 自主调用完成复杂任务 -- ✅ **CLI系统:** 提供终端命令和对话命令,支持进程管理、技能安装、配置修改等操作 -- ✅ **多模态消息:** 支持对文本、图片、语音、文件等多类型消息进行解析、处理、生成、发送等操作 -- ✅ **多模型支持:** 支持 DeepSeek、MiniMax、Claude、Gemini、OpenAI、GLM、Qwen、Doubao、Kimi 等国内外主流模型厂商 -- ✅ **多通道接入:** 支持运行在本地计算机或服务器,可集成到微信、飞书、钉钉、企业微信、QQ、微信公众号、网页中使用 - -## 声明 - -1. 本项目遵循 [MIT 开源协议](/LICENSE),主要用于技术研究和学习,使用本项目时需遵守所在地法律法规、相关政策以及企业章程,禁止用于任何违法或侵犯他人权益的行为。任何个人、团队和企业,无论以何种方式使用该项目、对何对象提供服务,所产生的一切后果,本项目均不承担任何责任。 -2. 成本与安全:Agent 模式下 Token 使用量高于普通对话模式,请根据效果及成本综合选择模型。Agent 具有访问所在操作系统的能力,请谨慎选择项目部署环境。同时项目也会持续升级安全机制、并降低模型消耗成本。 -3. CowAgent 项目专注于开源技术开发,不会参与、授权或发行任何加密货币。 - -## 演示 - -- 使用说明( Agent 模式):[CowAgent 介绍](https://docs.cowagent.ai/intro/features) - -- 免部署在线体验:[CowAgent](https://link-ai.tech/cowagent/create) - -- DEMO 视频(对话模式):https://cdn.link-ai.tech/doc/cow_demo.mp4 - -## 社区 - -添加小助手微信加入开源项目交流群: - - +| Capability | Description | +| :--- | :--- | +| [Planning](https://docs.cowagent.ai/en/intro/architecture) | Decomposes complex tasks and executes them step by step, looping over tools until the goal is reached | +| [Memory](https://docs.cowagent.ai/en/memory/index) | Three-tier architecture (context → daily → core), automatic Deep Dream distillation, hybrid keyword + vector retrieval | +| [Knowledge](https://docs.cowagent.ai/en/knowledge/index) | Auto-curates structured knowledge into a Markdown wiki, builds an evolving knowledge graph with visual browsing | +| [Skills](https://docs.cowagent.ai/en/skills/index) | One-click install from [Skill Hub](https://skills.cowagent.ai/), GitHub, ClawHub; or create custom skills via natural-language conversation | +| [Tools](https://docs.cowagent.ai/en/tools/index) | Built-in file I/O, terminal, browser, scheduler, memory retrieval, web search, and 10+ more tools — with native MCP integration | +| [Channels](https://docs.cowagent.ai/en/channels/index) | Integrates with Web, WeChat, Feishu, DingTalk, WeCom, QQ, Official Accounts, Telegram, and Slack | +| Multimodal | First-class support for text, images, voice, and files — recognition, generation, and delivery | +| [Models](https://docs.cowagent.ai/en/models/index) | Claude, GPT, Gemini, DeepSeek, Qwen, GLM, Kimi, MiniMax, Doubao, and more — swap providers from the Web console with one click | +| [Deploy](https://docs.cowagent.ai/en/guide/quick-start) | One-line installer, unified Web console, multiple deployment modes (local, Docker, server) |
-# 企业服务 +## 🏗️ Architecture - +CowAgent Architecture -> [LinkAI](https://link-ai.tech/) 是面向企业和个人的一站式 AI 智能体平台,聚合多模态大模型、知识库、技能、工作流等能力,支持一键接入主流平台并管理,支持 SaaS、私有化部署等多种模式,可免部署在线运行[CowAgent 助理](https://link-ai.tech/cowagent/create)。 -> -> LinkAI 目前已在智能客服、私域运营、企业效率助手等场景积累了丰富的 AI 解决方案,在消费、健康、文教、科技制造等各行业沉淀了大模型落地应用的最佳实践,致力于帮助更多企业和开发者拥抱 AI 生产力。 +CowAgent is a complete **Agent Harness**: messages flow in through **Channels**; the **Agent Core** plans and reasons over memory, knowledge, and the available tools and skills; **Models** generate the response, which is sent back through the originating channel. Every layer is decoupled and independently extensible. -**产品咨询和企业服务** 可联系产品客服: - - +Read more in [Architecture](https://docs.cowagent.ai/en/intro/architecture).
-# 🏷 更新日志 +## 🚀 Quick Start ->**2026.05.06:** [2.0.8版本](https://github.com/zhayujie/CowAgent/releases/tag/2.0.8),飞书渠道全面升级(语音、流式输出和Markdown、一键扫码接入)、新模型支持(DeepSeek V4、百度千帆)、定时任务工具增强等 +A one-line installer takes care of dependencies, configuration, and startup: ->**2026.04.22:** [2.0.7版本](https://github.com/zhayujie/CowAgent/releases/tag/2.0.7),图像生成内置技能(GPT Image 2、Nano Banana 等)、新模型支持(Kimi K2.6、Claude Opus 4.7、GLM 5.1)、知识库和记忆增强、Web 控制台优化 +**Linux / macOS:** ->**2026.04.14:** [2.0.6版本](https://github.com/zhayujie/CowAgent/releases/tag/2.0.6),知识库系统、梦境记忆模块、上下文智能压缩、Web 控制台多会话及多项优化。 - ->**2026.04.01:** [2.0.5版本](https://github.com/zhayujie/CowAgent/releases/tag/2.0.5),Cow CLI 命令系统、Skill Hub 开源、浏览器工具、企微扫码创建、多项优化和修复。 - ->**2026.03.22:** [2.0.4版本](https://github.com/zhayujie/CowAgent/releases/tag/2.0.4),新增个人微信通道(微信扫码即用)、新增 MiniMax-M2.7 和 GLM-5-Turbo 模型、run.sh 脚本重构、日文文档及多项修复。 - ->**2026.03.18:** [2.0.3版本](https://github.com/zhayujie/CowAgent/releases/tag/2.0.3),新增企微智能机器人和 QQ 通道、支持 Coding Plan、新增多个模型、Web 端文件处理、记忆系统升级。 - ->**2026.02.27:** [2.0.2版本](https://github.com/zhayujie/CowAgent/releases/tag/2.0.2),Web 控制台全面升级(流式对话、模型/技能/记忆/通道/定时任务/日志管理)、支持多通道同时运行、会话持久化存储、新增多个模型。 - ->**2026.02.13:** [2.0.1版本](https://github.com/zhayujie/CowAgent/releases/tag/2.0.1),内置 Web Search 工具、智能上下文裁剪策略、运行时信息动态更新、Windows 兼容性适配,修复定时任务记忆丢失、飞书连接等多项问题。 - ->**2026.02.03:** [2.0.0版本](https://github.com/zhayujie/CowAgent/releases/tag/2.0.0),正式升级为超级 Agent 助理,支持多轮任务决策、具备长期记忆、实现多种系统工具、支持 Skills 框架,新增多种模型并优化了接入渠道。 - -更多更新历史请查看: [更新日志](https://docs.cowagent.ai/releases) - -
- -# 🚀 快速开始 - -项目提供了一键安装、配置、启动、管理程序的脚本,推荐使用脚本快速运行,也可以根据下文中的详细指引一步步安装运行。 - -在终端执行以下命令: - -**Linux / macOS:** ```bash bash <(curl -fsSL https://cdn.link-ai.tech/code/cow/run.sh) ``` -**Windows(PowerShell):** +**Windows (PowerShell):** + ```powershell irm https://cdn.link-ai.tech/code/cow/run.ps1 | iex ``` -脚本使用说明:[一键运行脚本](https://docs.cowagent.ai/guide/quick-start)。安装后可使用 `cow start`、`cow stop` 等 [CLI 命令](https://docs.cowagent.ai/cli/index) 管理服务。 - - -## 一、准备 - -### 1. 模型API - -项目支持国内外主流厂商的模型接口,可选模型及配置说明参考:[模型说明](#模型说明)。 - -> 注:Agent 模式下推荐使用以下模型,可根据效果及成本综合选择:deepseek-v4-flash、MiniMax-M2.7、glm-5.1、kimi-k2.6、qwen3.5-plus、claude-sonnet-4-6、gemini-3.1-pro-preview、gpt-5.4、gpt-5.4-mini、ernie-5.1 - -同时支持使用 **LinkAI 平台** 接口,支持上述全部模型,并支持知识库、工作流、插件等 Agent 技能,参考 [接口文档](https://docs.link-ai.tech/platform/api)。 - -### 2.环境安装 - -支持 Linux、MacOS、Windows 操作系统,可在个人计算机及服务器上运行,需安装 `Python`,Python 版本需在 3.7 ~ 3.13 之间。 - -> 注意:Agent 模式推荐使用源码运行,若选择 Docker 部署则无需安装 python 环境和下载源码,可直接快进到下一节。 - -**(1) 克隆项目代码:** - -```bash -git clone https://github.com/zhayujie/CowAgent -cd CowAgent/ -``` - -若遇到网络问题可使用国内仓库地址:https://gitee.com/zhayujie/CowAgent - -**(2) 安装核心依赖 (必选):** - -```bash -pip3 install -r requirements.txt -``` - -**(3) 拓展依赖 (可选,建议安装):** - -```bash -pip3 install -r requirements-optional.txt -``` - -> 国内网络可使用镜像源加速:`pip3 install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple` - -如果某项依赖安装失败可注释掉对应的行后重试。 - -**(4) 安装 Cow CLI (推荐):** - -```bash -pip3 install -e . -``` - -安装后可使用 `cow` 命令管理服务(启动、停止、更新等)和技能,详见 [命令文档](https://docs.cowagent.ai/cli/index)。 - -**(5) 安装浏览器工具 (可选):** - -如果需要 Agent 操作浏览器(如访问网页、填写表单等),需要额外安装浏览器依赖: - -```bash -cow install-browser -``` - -该命令会自动安装 `playwright` 和 Chromium 浏览器,国内网络自动使用镜像加速。详见 [浏览器工具文档](https://docs.cowagent.ai/tools/browser)。 - -## 二、配置 - -配置文件的模板在根目录的 `config-template.json` 中,需复制该模板创建最终生效的 `config.json` 文件: - -```bash - cp config-template.json config.json -``` - -然后在 `config.json` 中填入配置,以下是对默认配置的说明,可根据需要进行自定义修改(注意实际使用时请去掉注释,保证 JSON 格式的规范): - -```bash -# config.json 文件内容示例 -{ - "channel_type": "weixin", # 接入渠道类型,默认为 weixin, 支持修改为 feishu,dingtalk,wecom_bot,qq,wechatcom_app,wechatmp_service,wechatmp,terminal - "model": "deepseek-v4-flash", # 模型名称 - "deepseek_api_key": "", # DeepSeek API Key - "deepseek_api_base": "https://api.deepseek.com/v1", # DeepSeek API 地址 - "minimax_api_key": "", # MiniMax API Key - "zhipu_ai_api_key": "", # 智谱 GLM API Key - "moonshot_api_key": "", # Kimi/Moonshot API Key - "ark_api_key": "", # 豆包(火山方舟) API Key - "dashscope_api_key": "", # 百炼(通义千问) API Key - "claude_api_key": "", # Claude API Key - "claude_api_base": "https://api.anthropic.com/v1", # Claude API 地址,修改可接入三方代理平台 - "gemini_api_key": "", # Gemini API Key - "gemini_api_base": "https://generativelanguage.googleapis.com", # Gemini API 地址 - "open_ai_api_key": "", # OpenAI API Key - "open_ai_api_base": "https://api.openai.com/v1", # OpenAI API 地址 - "linkai_api_key": "", # LinkAI API Key - "proxy": "", # 代理客户端的 ip 和端口,国内环境需要开启代理的可填写该项,如 "127.0.0.1:7890" - "speech_recognition": false, # 是否开启语音识别 - "group_speech_recognition": false, # 是否开启群组语音识别 - "voice_reply_voice": false, # 是否使用语音回复语音 - "use_linkai": false, # 是否使用 LinkAI 接口,默认关闭,设置为 true 后可对接 LinkAI 平台模型 - "web_password": "", # Web 控制台访问密码,留空则不启用密码保护(监听 0.0.0.0 时务必设置) - "agent": true, # 是否启用 Agent 模式,启用后拥有多轮工具决策、长期记忆、Skills 能力等 - "agent_workspace": "~/cow", # Agent 的工作空间路径,用于存储 memory、skills、系统设定等 - "agent_max_context_tokens": 50000, # Agent 模式下最大上下文 tokens,超出将自动智能压缩处理 - "agent_max_context_turns": 20, # Agent 模式下最大上下文记忆轮次,一问一答为一轮,超出后智能压缩处理 - "agent_max_steps": 20, # Agent 模式下单次任务的最大决策步数,超出后将停止继续调用工具 - "enable_thinking": false # 是否启用深度思考模式 -} -``` - -**配置补充说明:** - -
-1. 语音配置 - -+ 添加 `"speech_recognition": true` 将开启语音识别,默认使用 openai 的 whisper 模型识别为文字,同时以文字回复,该参数仅支持私聊 (注意由于语音消息无法匹配前缀,一旦开启将对所有语音自动回复,支持语音触发画图); -+ 添加 `"group_speech_recognition": true` 将开启群组语音识别,默认使用 openai 的 whisper 模型识别为文字,同时以文字回复,参数仅支持群聊 (会匹配 group_chat_prefix 和 group_chat_keyword, 支持语音触发画图); -+ 添加 `"voice_reply_voice": true` 将开启语音回复语音(同时作用于私聊和群聊) -+ 使用 MiniMax TTS:设置 `"text_to_voice": "minimax"`,并配置 `minimax_api_key`;可通过 `"tts_voice_id"` 指定发音人(如 `English_Graceful_Lady`),`"text_to_voice_model"` 指定模型(如 `speech-2.8-hd`、`speech-2.8-turbo`) -
- -
-2. 其他配置 - -+ `model`: 模型名称,Agent 模式下推荐使用 `deepseek-v4-flash`、`MiniMax-M2.7`、`glm-5.1`、`kimi-k2.6`、`qwen3.6-plus`、`claude-sonnet-4-6`、`gemini-3.1-pro-preview`,全部模型名称参考[common/const.py](https://github.com/zhayujie/CowAgent/blob/master/common/const.py)文件 -+ `character_desc`:普通对话模式下的机器人系统提示词。在 Agent 模式下该配置不生效,由工作空间中的文件内容构成。 -+ `subscribe_msg`:订阅消息,公众号和企业微信 channel 中请填写,当被订阅时会自动回复, 可使用特殊占位符。目前支持的占位符有{trigger_prefix},在程序中它会自动替换成 bot 的触发词。 -
- -
-3. LinkAI 配置 - -+ `use_linkai`: 是否使用 LinkAI 接口,默认关闭,设置为 true 后可对接 LinkAI 平台,使用模型、知识库、工作流、插件等技能, 参考[接口文档](https://docs.link-ai.tech/platform/api/chat) -+ `linkai_api_key`: LinkAI Api Key,可在 [控制台](https://link-ai.tech/console/interface) 创建 -
- -注:全部配置项说明可在 [`config.py`](https://github.com/zhayujie/CowAgent/blob/master/config.py) 文件中查看。 - -## 三、运行 - -### 1.本地运行 - -如果是个人计算机 **本地运行**,直接在项目根目录下执行: - -```bash -cow start # 推荐,需先安装 Cow CLI -python3 app.py # 或直接运行,windows 环境下该命令通常为 python app.py -``` - -运行后默认会启动 web 服务,可通过访问 `http://localhost:9899/chat` 在网页端对话。 - -如果需要接入其他应用通道只需修改 `config.json` 配置文件中的 `channel_type` 参数,详情参考:[通道说明](#通道说明)。 - - -### 2.服务器部署 - -推荐使用 `cow` 命令管理服务: - -```bash -cow start # 后台启动 -cow stop # 停止服务 -cow restart # 重启服务 -cow status # 查看运行状态 -cow logs # 查看日志 -cow update # 拉取最新代码并重启 -``` - -也可以使用传统方式后台运行: - -```bash -nohup python3 app.py & tail -f nohup.out -``` - -此外,项目根目录下的 `run.sh` 脚本也支持一键管理服务,包括 `./run.sh start`、`./run.sh stop`、`./run.sh restart` 等命令,执行 `./run.sh help` 可查看全部用法。 - -> 如果需要通过浏览器访问 Web 控制台,请确保服务器的 `9899` 端口已在防火墙或安全组中放行,建议仅对指定 IP 开放以保证安全。 - -### 3.Docker部署 - -使用 docker 部署无需下载源码和安装依赖,只需要获取 `docker-compose.yml` 配置文件并启动容器即可。Agent 模式下更推荐使用源码进行部署,以获得更多系统访问能力。 - -> 前提是需要安装好 `docker` 及 `docker-compose`,安装成功后执行 `docker -v` 和 `docker-compose version` (或 `docker compose version`) 可查看到版本号。安装地址为 [docker官网](https://docs.docker.com/engine/install/) 。 - -**(1) 下载 docker-compose.yml 文件** +**Docker:** ```bash curl -O https://cdn.link-ai.tech/code/cow/docker-compose.yml +docker compose up -d ``` -下载完成后打开 `docker-compose.yml` 填写所需配置,例如 `CHANNEL_TYPE`、`OPEN_AI_API_KEY` 和等配置。 +Once started, open `http://localhost:9899` to access the **Web console** — your one-stop hub to chat with the Agent, configure models, connect channels, and install skills. -**(2) 启动容器** +> Deploying on a server? Set `web_host` to `0.0.0.0` in `config.json` to make the console reachable from outside, and set `web_password` to protect it. Don't forget to open port `9899` in your firewall or security group. -在 `docker-compose.yml` 所在目录下执行以下命令启动容器: +> 📖 Detailed guides: [Quick Start](https://docs.cowagent.ai/en/guide/quick-start) · [Install from Source](https://docs.cowagent.ai/en/guide/manual-install) · [Upgrade](https://docs.cowagent.ai/en/guide/upgrade) + +After installation, manage the service with the [cow CLI](https://docs.cowagent.ai/en/cli/index): ```bash -sudo docker compose up -d # 若docker-compose为 1.X 版本,则执行 `sudo docker-compose up -d` +cow start | stop | restart # service control +cow status | logs # status and logs +cow update # pull latest code and restart +cow skill install # install a skill +cow install-browser # install browser automation ``` -运行命令后,会自动取 [docker hub](https://hub.docker.com/r/zhayujie/chatgpt-on-wechat) 拉取最新 release 版本的镜像。当执行 `sudo docker ps` 能查看到 NAMES 为 chatgpt-on-wechat 的容器即表示运行成功。最后执行以下命令可查看容器的运行日志: - -```bash -sudo docker logs -f chatgpt-on-wechat -``` - -> 如果需要通过浏览器访问 Web 控制台,请确保服务器的 `9899` 端口已在防火墙或安全组中放行,建议仅对指定 IP 开放以保证安全。 - -## 模型说明 - -推荐通过 Web 控制台在线管理模型配置,无需手动编辑文件,详见 [模型文档](https://docs.cowagent.ai/models)。以下是手动修改 `config.json` 配置模型的说明: - -
-DeepSeek - -1. API Key 创建:在 [DeepSeek 平台](https://platform.deepseek.com/api_keys) 创建 API Key - -2. 填写配置 - -方式一:官方接入(推荐): - -```json -{ - "model": "deepseek-v4-flash", - "deepseek_api_key": "sk-xxxxxxxxxxx" -} -``` - - - `model`: 推荐填写 `deepseek-v4-flash`、`deepseek-v4-pro` - - `deepseek_api_key`: DeepSeek 平台的 API Key - - `deepseek_api_base`: 可选,默认为 `https://api.deepseek.com/v1`,可修改为第三方代理地址 - -方式二:OpenAI 兼容方式接入: - -```json -{ - "model": "deepseek-v4-flash", - "bot_type": "openai", - "open_ai_api_key": "sk-xxxxxxxxxxx", - "open_ai_api_base": "https://api.deepseek.com/v1" -} -``` - -
- -
-MiniMax - -方式一:官方接入,配置如下(推荐): - -```json -{ - "model": "MiniMax-M2.7", - "minimax_api_key": "" -} -``` - - `model`: 可填写 `MiniMax-M2.7、MiniMax-M2.7-highspeed、MiniMax-M2.5、MiniMax-M2.1、MiniMax-M2.1-lightning、MiniMax-M2、abab6.5-chat` 等 - - `minimax_api_key`:MiniMax 平台的 API-KEY,在 [控制台](https://platform.minimaxi.com/user-center/basic-information/interface-key) 创建 - -方式二:OpenAI 兼容方式接入,配置如下: -```json -{ - "bot_type": "openai", - "model": "MiniMax-M2.7", - "open_ai_api_base": "https://api.minimaxi.com/v1", - "open_ai_api_key": "" -} -``` -- `bot_type`: OpenAI 兼容方式 -- `model`: 可填 `MiniMax-M2.7、MiniMax-M2.7-highspeed、MiniMax-M2.5、MiniMax-M2.1、MiniMax-M2.1-lightning、MiniMax-M2`,参考[API文档](https://platform.minimaxi.com/document/%E5%AF%B9%E8%AF%9D?key=66701d281d57f38758d581d0#QklxsNSbaf6kM4j6wjO5eEek) -- `open_ai_api_base`: MiniMax 平台 API 的 BASE URL -- `open_ai_api_key`: MiniMax 平台的 API-KEY -
- -
-Claude - -1. API Key 创建:在 [Claude控制台](https://console.anthropic.com/settings/keys) 创建 API Key - -2. 填写配置 - -```json -{ - "model": "claude-sonnet-4-6", - "claude_api_key": "YOUR_API_KEY" -} -``` - - `model`: 参考 [官方模型ID](https://docs.anthropic.com/en/docs/about-claude/models/overview#model-aliases) ,支持 `claude-sonnet-4-6、claude-opus-4-7、claude-opus-4-6、claude-sonnet-4-5、claude-sonnet-4-0、claude-opus-4-0、claude-3-5-sonnet-latest` 等 -
- -
-Gemini - -API Key 创建:在 [控制台](https://aistudio.google.com/app/apikey?hl=zh-cn) 创建 API Key ,配置如下 -```json -{ - "model": "gemini-3.1-flash-lite-preview", - "gemini_api_key": "" -} -``` - - `model`: 参考[官方文档-模型列表](https://ai.google.dev/gemini-api/docs/models?hl=zh-cn),支持 `gemini-3.1-flash-lite-preview、gemini-3.1-pro-preview、gemini-3-flash-preview、gemini-3-pro-preview` 等 -
- -
-OpenAI - -1. API Key 创建:在 [OpenAI平台](https://platform.openai.com/api-keys) 创建 API Key - -2. 填写配置 - -```json -{ - "model": "gpt-5.4", - "open_ai_api_key": "YOUR_API_KEY", - "open_ai_api_base": "https://api.openai.com/v1", - "bot_type": "openai" -} -``` - - - `model`: 与 OpenAI 接口的 [model参数](https://platform.openai.com/docs/models) 一致,支持包括 gpt-5.4、gpt-5.4-mini、gpt-5.4-nano、o 系列、gpt-4.1 等模型,Agent 模式推荐使用 `gpt-5.4`、`gpt-5.4-mini` - - `open_ai_api_base`: 如果需要接入第三方代理接口,可通过修改该参数进行接入 - - `bot_type`: 使用 OpenAI 相关模型时无需填写。当使用第三方代理接口接入 Claude 等非 OpenAI 官方模型时,该参数设为 `openai` -
- -
-智谱AI (GLM) - -方式一:官方接入,配置如下(推荐): - -```json -{ - "model": "glm-5.1", - "zhipu_ai_api_key": "" -} -``` - - `model`: 可填 `glm-5.1、glm-5-turbo、glm-5、glm-4.7、glm-4-plus、glm-4-flash、glm-4-air、glm-4-airx、glm-4-long` 等, 参考 [glm 系列模型编码](https://bigmodel.cn/dev/api/normal-model/glm-4) - - `zhipu_ai_api_key`: 智谱AI 平台的 API KEY,在 [控制台](https://www.bigmodel.cn/usercenter/proj-mgmt/apikeys) 创建 - -方式二:OpenAI 兼容方式接入,配置如下: -```json -{ - "bot_type": "openai", - "model": "glm-5.1", - "open_ai_api_base": "https://open.bigmodel.cn/api/paas/v4", - "open_ai_api_key": "" -} -``` -- `bot_type`: OpenAI 兼容方式 -- `model`: 可填 `glm-5.1、glm-5-turbo、glm-5、glm-4.7、glm-4-plus、glm-4-flash、glm-4-air、glm-4-airx、glm-4-long` 等 -- `open_ai_api_base`: 智谱AI 平台的 BASE URL -- `open_ai_api_key`: 智谱AI 平台的 API KEY -
- -
-通义千问 (Qwen) - -方式一:官方 SDK 接入,配置如下(推荐): - -```json -{ - "model": "qwen3.6-plus", - "dashscope_api_key": "sk-qVxxxxG" -} -``` - - `model`: 可填写 `qwen3.6-plus、qwen3.5-plus、qwen3-max、qwen-max、qwen-plus、qwen-turbo、qwen-long、qwq-plus` 等 - - `dashscope_api_key`: 通义千问的 API-KEY,参考 [官方文档](https://bailian.console.aliyun.com/?tab=api#/api) ,在 [百炼控制台](https://bailian.console.aliyun.com/?tab=model#/api-key) 创建 - -方式二:OpenAI 兼容方式接入,配置如下: -```json -{ - "bot_type": "openai", - "model": "qwen3.6-plus", - "open_ai_api_base": "https://dashscope.aliyuncs.com/compatible-mode/v1", - "open_ai_api_key": "sk-qVxxxxG" -} -``` -- `bot_type`: OpenAI 兼容方式 -- `model`: 支持官方所有模型,参考[模型列表](https://help.aliyun.com/zh/model-studio/models?spm=a2c4g.11186623.0.0.78d84823Kth5on#9f8890ce29g5u) -- `open_ai_api_base`: 通义千问 API 的 BASE URL -- `open_ai_api_key`: 通义千问的 API-KEY -
- -
-豆包 (Doubao) - -1. API Key 创建:在 [火山方舟控制台](https://console.volcengine.com/ark/region:ark+cn-beijing/apikey) 创建API Key - -2. 填写配置 - -```json -{ - "model": "doubao-seed-2-0-code-preview-260215", - "ark_api_key": "YOUR_API_KEY" -} -``` - - `model`: 可填写 `doubao-seed-2-0-code-preview-260215、doubao-seed-2-0-pro-260215、doubao-seed-2-0-lite-260215、doubao-seed-2-0-mini-260215` 等 - - `ark_api_key`: 火山方舟平台的 API Key,在 [控制台](https://console.volcengine.com/ark/region:ark+cn-beijing/apikey) 创建 - - `ark_base_url`: 可选,默认为 `https://ark.cn-beijing.volces.com/api/v3` -
- -
-Kimi (Moonshot) - -方式一:官方接入,配置如下: - -```json -{ - "model": "kimi-k2.6", - "moonshot_api_key": "" -} -``` - - `model`: 可填写 `kimi-k2.6、kimi-k2.5、kimi-k2、moonshot-v1-8k、moonshot-v1-32k、moonshot-v1-128k` - - `moonshot_api_key`: Moonshot 的 API-KEY,在 [控制台](https://platform.moonshot.cn/console/api-keys) 创建 - -方式二:OpenAI 兼容方式接入,配置如下: -```json -{ - "bot_type": "openai", - "model": "kimi-k2.6", - "open_ai_api_base": "https://api.moonshot.cn/v1", - "open_ai_api_key": "" -} -``` -- `bot_type`: OpenAI 兼容方式 -- `model`: 可填写 `kimi-k2.6、kimi-k2.5、kimi-k2、moonshot-v1-8k、moonshot-v1-32k、moonshot-v1-128k` -- `open_ai_api_base`: Moonshot 的 BASE URL -- `open_ai_api_key`: Moonshot 的 API-KEY -
- -
-ModelScope - -```json -{ - "bot_type": "modelscope", - "model": "Qwen/QwQ-32B", - "modelscope_api_key": "your_api_key", - "modelscope_base_url": "https://api-inference.modelscope.cn/v1/chat/completions", - "text_to_image": "MusePublic/489_ckpt_FLUX_1" -} -``` - -- `bot_type`: modelscope 接口格式 -- `model`: 参考[模型列表](https://www.modelscope.cn/models?filter=inference_type&page=1) -- `modelscope_api_key`: 参考 [官方文档-访问令牌](https://modelscope.cn/docs/accounts/token) ,在 [控制台](https://modelscope.cn/my/myaccesstoken) -- `modelscope_base_url`: modelscope 平台的 BASE URL -- `text_to_image`: 图像生成模型,参考[模型列表](https://www.modelscope.cn/models?filter=inference_type&page=1) -
- -
-LinkAI - -1. API Key 创建:在 [LinkAI平台](https://link-ai.tech/console/interface) 创建 API Key - -2. 填写配置 - -```json -{ - "model": "gpt-5.4-mini", - "use_linkai": true, - "linkai_api_key": "YOUR API KEY" -} -``` - -+ `use_linkai`: 是否使用 LinkAI 接口,默认关闭,设置为 true 后可对接 LinkAI 平台的模型,并使用知识库、工作流、数据库、插件等丰富的 Agent 技能 -+ `linkai_api_key`: LinkAI 平台的 API Key,可在 [控制台](https://link-ai.tech/console/interface) 中创建 -+ `model`: [模型列表](https://link-ai.tech/console/models)中的全部模型均可使用 -
- -
-Azure - -1. API Key 创建:在 [Azure平台](https://oai.azure.com/) 创建 API Key - -2. 填写配置 - -```json -{ - "model": "", - "use_azure_chatgpt": true, - "open_ai_api_key": "", - "open_ai_api_base": "", - "azure_deployment_id": "", - "azure_api_version": "2025-01-01-preview" -} -``` - - - `model`: 留空即可 - - `use_azure_chatgpt`: 设为 true - - `open_ai_api_key`: Azure 平台的密钥 - - `open_ai_api_base`: Azure 平台的 BASE URL - - `azure_deployment_id`: Azure 平台部署的模型名称 - - `azure_api_version`: api 版本以及以上参数可以在部署的 [模型配置](https://oai.azure.com/resource/deployments) 界面查看 -
- -
-百度千帆 / ERNIE - -方式一:官方接入(推荐),配置如下: - -```json -{ - "model": "ernie-5.1", - "qianfan_api_key": "", - "qianfan_api_base": "https://qianfan.baidubce.com/v2" -} -``` - - - `model`: 默认推荐填写 `ernie-5.1`(多模态,可直接识图),也可填写 `ernie-5.0`、`ernie-x1.1`、`ernie-4.5-turbo-128k`、`ernie-4.5-turbo-32k`;当主模型为纯文本 ERNIE 时,Vision 工具会自动 fallback 到 `ernie-4.5-turbo-vl` - - `qianfan_api_key`: 百度千帆 API Key,通常以 `bce-v3/` 开头,可在百度智能云控制台创建 - - `qianfan_api_base`: 可选,默认为 `https://qianfan.baidubce.com/v2` - -方式二:OpenAI 兼容方式接入,配置如下: -```json -{ - "bot_type": "openai", - "model": "ernie-5.1", - "open_ai_api_base": "https://qianfan.baidubce.com/v2", - "open_ai_api_key": "" -} -``` -- `bot_type`: OpenAI 兼容方式 -- `model`: 支持千帆平台上的 ERNIE 模型 -- `open_ai_api_base`: 百度千帆 OpenAI 兼容 API 的 BASE URL -- `open_ai_api_key`: 百度千帆 API Key - -
- -
-讯飞星火 - -方式一:官方接入,配置如下: -参考 [官方文档-快速指引](https://www.xfyun.cn/doc/platform/quickguide.html#%E7%AC%AC%E4%BA%8C%E6%AD%A5-%E5%88%9B%E5%BB%BA%E6%82%A8%E7%9A%84%E7%AC%AC%E4%B8%80%E4%B8%AA%E5%BA%94%E7%94%A8-%E5%BC%80%E5%A7%8B%E4%BD%BF%E7%94%A8%E6%9C%8D%E5%8A%A1) 获取 `APPID、 APISecret、 APIKey` 三个参数 - -```json -{ - "model": "xunfei", - "xunfei_app_id": "", - "xunfei_api_key": "", - "xunfei_api_secret": "", - "xunfei_domain": "4.0Ultra", - "xunfei_spark_url": "wss://spark-api.xf-yun.com/v4.0/chat" -} -``` - - `model`: 填 `xunfei` - - `xunfei_domain`: 可填写 `4.0Ultra、generalv3.5、max-32k、generalv3、pro-128k、lite` - - `xunfei_spark_url`: 填写参考 [官方文档-请求地址](https://www.xfyun.cn/doc/spark/Web.html#_1-1-%E8%AF%B7%E6%B1%82%E5%9C%B0%E5%9D%80) 的说明 - -方式二:OpenAI 兼容方式接入,配置如下: -```json -{ - "bot_type": "openai", - "model": "4.0Ultra", - "open_ai_api_base": "https://spark-api-open.xf-yun.com/v1", - "open_ai_api_key": "" -} -``` -- `bot_type`: OpenAI 兼容方式 -- `model`: 可填写 `4.0Ultra、generalv3.5、max-32k、generalv3、pro-128k、lite` -- `open_ai_api_base`: 讯飞星火平台的 BASE URL -- `open_ai_api_key`: 讯飞星火平台的[APIPassword](https://console.xfyun.cn/services/bm3) ,因模型而已 -
- -
-Coding Plan - -Coding Plan 是各厂商推出的编程包月套餐,所有厂商均可通过 OpenAI 兼容方式接入: - -```json -{ - "bot_type": "openai", - "model": "模型名称", - "open_ai_api_base": "厂商 Coding Plan API Base", - "open_ai_api_key": "YOUR_API_KEY" -} -``` - -目前支持阿里云、MiniMax、智谱 GLM、Kimi、火山引擎等厂商,各厂商详细配置请参考 [Coding Plan 文档](https://docs.cowagent.ai/models/coding-plan)。 -
- - -## 通道说明 - -推荐通过 Web 控制台在线管理通道配置,无需手动编辑文件,详见 [通道文档](https://docs.cowagent.ai/channels/weixin)。以下为手动修改 `config.json` 配置通道的说明: - -支持同时可接入多个通道,配置时可通过逗号进行分割,例如 `"channel_type": "feishu,dingtalk"`。 - -
-1. Weixin - 微信 - -接入个人微信,扫码登录即可使用,支持文本、图片、语音、文件等消息收发。 - -```json -{ - "channel_type": "weixin" -} -``` - -启动后终端会显示二维码,使用微信扫码授权即可,也可以在 Web 控制台的「通道」页面中扫码接入。登录凭证会自动保存至 `~/.weixin_cow_credentials.json`,下次启动无需重新扫码,如需重新登录删除该文件后重启即可。 - -详细步骤和参数说明参考 [微信接入](https://docs.cowagent.ai/channels/weixin) - -
- -
-2. Web - -项目启动后会默认运行 Web 控制台,配置如下: - -```json -{ - "channel_type": "web", - "web_host": "0.0.0.0", - "web_password": "YOUR PASSWORD", - "web_port": 9899 -} -``` - -- `web_host`: 监听地址,默认 `127.0.0.1`(仅本机),如需公网访问请改为 `0.0.0.0` 并设置密码 -- `web_port`: 默认为 9899,可按需更改,需要服务器防火墙和安全组放行该端口 -- `web_password`: 访问密码,留空则不启用密码保护。部署在公网环境时请务必设置 -- 如本地运行,启动后请访问 `http://localhost:9899` ;如服务器运行,请访问 `http://YOUR_IP:9899` -> 注:请将上述 url 中的 ip 或者 port 替换为实际的值 -
- -
-3. Feishu - 飞书 - -飞书使用 WebSocket 长连接模式,无需公网 IP。详细步骤参考 [飞书接入](https://docs.cowagent.ai/channels/feishu)。 - -**方式一:扫码一键创建(推荐)** - -启动 Cow 后打开 Web 控制台,**通道** → **接入通道** → 选择 **飞书** → 扫码创建。也支持 CLI 启动时在终端打印二维码。 - -**方式二:手动配置** - -在飞书开放平台创建自建应用并配置权限后,将凭据填入 `config.json`: - -```json -{ - "channel_type": "feishu", - "feishu_app_id": "APP_ID", - "feishu_app_secret": "APP_SECRET", - "feishu_stream_reply": true -} -``` - -- `feishu_stream_reply`:是否开启流式打字机回复,默认开启(需 `cardkit:card:write` 权限 + 飞书客户端 ≥ 7.20) - -
- -
-4. DingTalk - 钉钉 - -钉钉需要在开放平台创建智能机器人应用,将以下配置填入 `config.json`: - -```json -{ - "channel_type": "dingtalk", - "dingtalk_client_id": "CLIENT_ID", - "dingtalk_client_secret": "CLIENT_SECRET" -} -``` -详细步骤和参数说明参考 [钉钉接入](https://docs.cowagent.ai/channels/dingtalk) -
- -
-5. WeCom Bot - 企微智能机器人 - -企微智能机器人使用 WebSocket 长连接模式,无需公网 IP 和域名。详细步骤参考 [企微智能机器人接入](https://docs.cowagent.ai/channels/wecom-bot)。 - -**方式一:扫码一键创建(推荐)** - -启动 Cow 后打开 Web 控制台,**通道** → **接入通道** → 选择 **企微智能机器人** → 使用企业微信扫码创建。 - -**方式二:手动配置** - -在企业微信中创建智能机器人并选择**长连接模式**,记录 Bot ID 和 Secret 后填入 `config.json`: - -```json -{ - "channel_type": "wecom_bot", - "wecom_bot_id": "YOUR_BOT_ID", - "wecom_bot_secret": "YOUR_SECRET" -} -``` - -
- -
-6. QQ - QQ 机器人 - -QQ 机器人使用 WebSocket 长连接模式,无需公网 IP 和域名,支持 QQ 单聊、群聊和频道消息: - -```json -{ - "channel_type": "qq", - "qq_app_id": "YOUR_APP_ID", - "qq_app_secret": "YOUR_APP_SECRET" -} -``` -详细步骤和参数说明参考 [QQ 机器人接入](https://docs.cowagent.ai/channels/qq) - -
- -
-7. WeCom App - 企业微信应用 - -企业微信自建应用接入需在后台创建应用并启用消息回调,配置示例: - -```json -{ - "channel_type": "wechatcom_app", - "wechatcom_corp_id": "CORPID", - "wechatcomapp_token": "TOKEN", - "wechatcomapp_port": 9898, - "wechatcomapp_secret": "SECRET", - "wechatcomapp_agent_id": "AGENTID", - "wechatcomapp_aes_key": "AESKEY" -} -``` -详细步骤和参数说明参考 [企微自建应用接入](https://docs.cowagent.ai/channels/wecom) - -
- -
-8. WeChat MP - 微信公众号 - -本项目支持订阅号和服务号两种公众号,通过服务号(`wechatmp_service`)体验更佳。 - -**个人订阅号(wechatmp)** - -```json -{ - "channel_type": "wechatmp", - "wechatmp_token": "TOKEN", - "wechatmp_port": 80, - "wechatmp_app_id": "APPID", - "wechatmp_app_secret": "APPSECRET", - "wechatmp_aes_key": "" -} -``` - -**企业服务号(wechatmp_service)** - -```json -{ - "channel_type": "wechatmp_service", - "wechatmp_token": "TOKEN", - "wechatmp_port": 80, - "wechatmp_app_id": "APPID", - "wechatmp_app_secret": "APPSECRET", - "wechatmp_aes_key": "" -} -``` - -详细步骤和参数说明参考 [微信公众号接入](https://docs.cowagent.ai/channels/wechatmp) - -
- -
-9. Terminal - 终端 - -修改 `config.json` 中的 `channel_type` 字段: - -```json -{ - "channel_type": "terminal" -} -``` - -运行后可在终端与机器人进行对话。 - -
-
-# 🔗 相关项目 +## 🤖 Models -- [Cow Skill Hub](https://github.com/zhayujie/cow-skill-hub):开源的 AI Agent 技能广场,浏览、搜索、安装和发布技能,支持 CowAgent、OpenClaw、Claude Code 等多种 Agent。 -- [bot-on-anything](https://github.com/zhayujie/bot-on-anything):轻量和高可扩展的大模型应用框架,支持接入 Slack, Telegram, Discord, Gmail 等海外平台,可作为本项目的补充使用。 -- [AgentMesh](https://github.com/MinimalFuture/AgentMesh):开源的多智能体( Multi-Agent )框架,可以通过多智能体团队的协同来解决复杂问题。 +CowAgent supports all mainstream LLM providers. **Chat, vision, image generation, ASR/TTS, and embeddings** can each be routed to a different vendor. Providers are configured directly in the Web console — no manual file editing required. +| Provider | Featured Models | Chat | Vision | Image Gen | ASR | TTS | Embedding | +| --- | --- | :-: | :-: | :-: | :-: | :-: | :-: | +| [Claude](https://docs.cowagent.ai/en/models/claude) | claude-opus-4-8 | ✅ | ✅ | | | | | +| [OpenAI](https://docs.cowagent.ai/en/models/openai) | gpt-5.5, o-series | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | +| [Gemini](https://docs.cowagent.ai/en/models/gemini) | gemini-3.5-flash | ✅ | ✅ | ✅ | | | | +| [DeepSeek](https://docs.cowagent.ai/en/models/deepseek) | deepseek-v4-flash / pro | ✅ | | | | | | +| [Qwen](https://docs.cowagent.ai/en/models/qwen) | qwen3.7-max | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | +| [GLM](https://docs.cowagent.ai/en/models/glm) | glm-5.1, glm-5v-turbo | ✅ | ✅ | | ✅ | | ✅ | +| [Doubao](https://docs.cowagent.ai/en/models/doubao) | doubao-seed-2.0 series | ✅ | ✅ | ✅ | | | ✅ | +| [Kimi](https://docs.cowagent.ai/en/models/kimi) | kimi-k2.6 | ✅ | ✅ | | | | | +| [MiniMax](https://docs.cowagent.ai/en/models/minimax) | MiniMax-M2.7 | ✅ | ✅ | ✅ | | ✅ | | +| [ERNIE](https://docs.cowagent.ai/en/models/qianfan) | ernie-5.1 | ✅ | ✅ | | | | | +| [MiMo](https://docs.cowagent.ai/en/models/mimo) | mimo-v2.5 / pro | ✅ | ✅ | | | ✅ | | +| [LinkAI](https://docs.cowagent.ai/en/models/linkai) | One key for 100+ models | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | +| [Custom](https://docs.cowagent.ai/en/models/custom) | Local models / third-party proxy | ✅ | | | | | | +> For details on each provider, see the [Models overview](https://docs.cowagent.ai/en/models/index). +
-# 🔎 常见问题 +## 💬 Channels -FAQs: +A single Agent instance can serve multiple channels in parallel. Most channels can be onboarded right from the Web console. -或直接在线咨询 [项目小助手](https://link-ai.tech/app/Kv2fXJcH) (知识库持续完善中,回复供参考) +| Channel | Text | Image | File | Voice | Group | +| --- | :-: | :-: | :-: | :-: | :-: | +| [Web Console](https://docs.cowagent.ai/en/channels/web) (default) | ✅ | ✅ | ✅ | ✅ | | +| [WeChat](https://docs.cowagent.ai/en/channels/weixin) | ✅ | ✅ | ✅ | ✅ | | +| [Feishu / Lark](https://docs.cowagent.ai/en/channels/feishu) | ✅ | ✅ | ✅ | ✅ | ✅ | +| [DingTalk](https://docs.cowagent.ai/en/channels/dingtalk) | ✅ | ✅ | ✅ | ✅ | ✅ | +| [WeCom Bot](https://docs.cowagent.ai/en/channels/wecom-bot) | ✅ | ✅ | ✅ | ✅ | ✅ | +| [QQ](https://docs.cowagent.ai/en/channels/qq) | ✅ | ✅ | ✅ | | ✅ | +| [WeCom App](https://docs.cowagent.ai/en/channels/wecom) | ✅ | ✅ | ✅ | ✅ | | +| [WeChat Official Account](https://docs.cowagent.ai/en/channels/wechatmp) | ✅ | ✅ | | ✅ | | +| [Telegram](https://docs.cowagent.ai/en/channels/telegram) | ✅ | ✅ | ✅ | ✅ | ✅ | +| [Slack](https://docs.cowagent.ai/en/channels/slack) | ✅ | ✅ | ✅ | | ✅ | -# 🛠️ 开发 +> See the [Channels overview](https://docs.cowagent.ai/en/channels/index) for setup details. -欢迎接入更多应用通道,参考 [飞书通道](https://github.com/zhayujie/CowAgent/blob/master/channel/feishu/feishu_channel.py) 新增自定义通道,实现接收和发送消息逻辑即可完成接入。同时欢迎贡献新的 Skills,向 [Skill Hub](https://skills.cowagent.ai/submit) 提交技能。 +CowAgent Web Console -# ✉ 联系 +*The Web console is the default channel and the unified entry point to configure models, channels, skills, memory, and more.* -欢迎提交PR、Issues进行反馈,以及通过 🌟Star 支持并关注项目更新。项目运行遇到问题可以查看 [常见问题列表](https://github.com/zhayujie/CowAgent/wiki/FAQs) ,以及前往 [Issues](https://github.com/zhayujie/CowAgent/issues) 中搜索。个人开发者可加入开源交流群参与更多讨论,企业用户可联系[产品客服](https://cdn.link-ai.tech/portal/linkai-customer-service.png)咨询。 +
-# 🌟 贡献者 +## 🧠 Memory & Knowledge Base + +**Long-term memory** uses a three-tier architecture: conversation context (short-term) → daily memory (mid-term) → MEMORY.md (long-term). A nightly **Deep Dream** pass distills scattered memories into refined long-term entries and a narrative journal. See [Long-term Memory](https://docs.cowagent.ai/en/memory/index) · [Deep Dream](https://docs.cowagent.ai/en/memory/deep-dream). + +**Personal knowledge base** complements the time-ordered memory by organizing structured knowledge **by topic**. The Agent automatically curates valuable information from conversations, maintains cross-references and indexes, and the Web console offers an interactive knowledge-graph view. See [Personal Knowledge Base](https://docs.cowagent.ai/en/knowledge/index). + + + + + + +
+ Long-term Memory +

Long-term Memory · Three-tier architecture + Deep Dream

+
+ Personal Knowledge Base +

Knowledge Base · Auto-curated Markdown wiki

+
+ +
+ +## 🔧 Tools & Skills + +**Tools** are atomic capabilities the Agent uses to interact with system resources. **Skills** are higher-level workflows defined by a manifest file that compose multiple tools to accomplish complex tasks. + +### Tool System + +**Built-in tools** cover file I/O (`read` / `write` / `edit` / `ls`), terminal (`bash`), file sending (`send`), memory retrieval (`memory`), environment variables (`env_config`), web fetching (`web_fetch`), scheduling (`scheduler`), web search (`web_search`), vision (`vision`), and browser automation (`browser`). + +**MCP protocol** integrates the open ecosystem of [Model Context Protocol](https://modelcontextprotocol.io) servers. A single `mcp.json` is enough — supports stdio / SSE transports, hot reload, and zero-code integration. + +Learn more: [Tools overview](https://docs.cowagent.ai/en/tools/index) · [MCP integration](https://docs.cowagent.ai/en/tools/mcp). + +### Skills System + +- **[Skill Hub](https://skills.cowagent.ai/)** — open skill marketplace: browse, search, install in one click +- **GitHub / ClawHub / URL and more** — install skills from any source +- **Conversational authoring** — generate custom skills through dialogue with `skill-creator`; turn any workflow or third-party API into a reusable skill + +```bash +/skill list # list installed skills +/skill search # search the marketplace +/skill install # one-click install +``` + +Learn more: [Skills overview](https://docs.cowagent.ai/en/skills/index) · [Creating Skills](https://docs.cowagent.ai/en/skills/create). + +
+ +## 🏷 Changelog + +> **2026.05.22:** [v2.0.9](https://github.com/zhayujie/CowAgent/releases/tag/2.0.9) — Model management, MCP protocol support, persistent browser sessions, new models (gpt-5.5, gemini-3.5-flash, qwen3.7-max), deployment hardening. + +> **2026.05.06:** [v2.0.8](https://github.com/zhayujie/CowAgent/releases/tag/2.0.8) — Feishu channel overhaul (voice, streaming, QR onboarding), DeepSeek V4 and Baidu Qianfan support, scheduler tool upgrades. + +> **2026.04.22:** [v2.0.7](https://github.com/zhayujie/CowAgent/releases/tag/2.0.7) — Built-in image generation (GPT Image 2, Nano Banana), new models (Kimi K2.6, Claude Opus 4.7, GLM 5.1), memory and knowledge enhancements. + +> **2026.04.14:** [v2.0.6](https://github.com/zhayujie/CowAgent/releases/tag/2.0.6) — Knowledge base, Deep Dream memory distillation, smart context compression, multi-session Web console. + +> **2026.04.01:** [v2.0.5](https://github.com/zhayujie/CowAgent/releases/tag/2.0.5) — Cow CLI, Skill Hub open source, browser tool, WeCom Bot QR onboarding. + +> **2026.02.03:** [v2.0.0](https://github.com/zhayujie/CowAgent/releases/tag/2.0.0) — Major upgrade to a super Agent assistant with multi-step task planning, long-term memory, and the Skills framework. + +Full history: [Release Notes](https://docs.cowagent.ai/en/releases/overview) + +
+ +## 🤝 Community & Support + +[File an issue](https://github.com/zhayujie/CowAgent/issues) on GitHub, or scan the QR code below to join our WeChat community: + + + +
+ +## 🔗 Related Projects + +- **[Cow Skill Hub](https://github.com/zhayujie/cow-skill-hub)** — open skill marketplace for AI Agents; works with CowAgent, OpenClaw, Claude Code, and more +- **[bot-on-anything](https://github.com/zhayujie/bot-on-anything)** — lightweight LLM application framework with integrations for Slack, Telegram, Discord, Gmail, and more +- **[AgentMesh](https://github.com/MinimalFuture/AgentMesh)** — open-source multi-agent framework for solving complex problems through team collaboration + +
+ +## 🏢 Enterprise Services + +[**LinkAI**](https://link-ai.tech/) is an all-in-one AI Agent platform for enterprises and developers, offering managed hosting and enterprise-grade support for CowAgent: + +- **🚀 Zero-deployment hosted runtime** — spin up a [CowAgent online assistant](https://link-ai.tech/cowagent/create) in under a minute, no server required +- **🧠 Agent infrastructure** — unified access to LLMs, knowledge bases, databases, skills, and workflows; plug-and-play building blocks that extend what CowAgent can do +- **🏢 Team & enterprise features** — workspaces, role-based access, audit logs, and private deployment for production use cases + +For enterprise inquiries: sales@simple-future.tech or [scan the QR code](https://cdn.link-ai.tech/consultant.jpg) to reach our team on WeChat. + +
+ +## 🛠️ Development & Contributing + +Contributions are welcome — add a new channel by following the [Feishu channel reference](https://github.com/zhayujie/CowAgent/blob/master/channel/feishu/feishu_channel.py), or contribute new skills to [Skill Hub](https://skills.cowagent.ai/submit). + +⭐ Star the project to follow updates, and feel free to open PRs and Issues. + +## 🌟 Contributors ![cow contributors](https://contrib.rocks/image?repo=zhayujie/CowAgent&max=1000) -# 📌 项目更名说明 +
-本项目原名 `chatgpt-on-wechat`(GitHub 原地址:https://github.com/zhayujie/chatgpt-on-wechat ), -于 2026.04.13 正式更名为 **CowAgent**。GitHub 已自动设置重定向,原有链接仍可正常访问。 +## ⚠️ Disclaimer -如需更新本地仓库的远程地址(可选): -```bash -git remote set-url origin https://github.com/zhayujie/CowAgent.git -``` +1. This project is licensed under the [MIT License](/LICENSE) and is intended for technical research and learning. You are responsible for complying with applicable laws and regulations in your jurisdiction; the maintainers assume no liability for any consequences arising from use of this project. +2. **Cost & safety:** Agent mode consumes substantially more tokens than regular chat — pick models that balance quality and cost. The Agent has access to your local operating system, so only deploy it in trusted environments. +3. CowAgent is a pure open-source project and does not participate in, authorize, or issue any cryptocurrency. + +
+ +## 📌 Project Renaming Notice + +This project was previously named `chatgpt-on-wechat` and is now officially **CowAgent**. The old GitHub URL redirects automatically; existing users may optionally run `git remote set-url origin https://github.com/zhayujie/CowAgent.git` to update the local remote. diff --git a/agent/memory/conversation_store.py b/agent/memory/conversation_store.py index c5d215bf..48148f61 100644 --- a/agent/memory/conversation_store.py +++ b/agent/memory/conversation_store.py @@ -44,6 +44,7 @@ CREATE TABLE IF NOT EXISTS messages ( role TEXT NOT NULL, content TEXT NOT NULL, created_at INTEGER NOT NULL, + extras TEXT NOT NULL DEFAULT '', UNIQUE (session_id, seq) ); @@ -67,6 +68,12 @@ _MIGRATION_ADD_CONTEXT_START_SEQ = """ ALTER TABLE sessions ADD COLUMN context_start_seq INTEGER NOT NULL DEFAULT 0; """ +# Generic JSON sidecar for per-message attachments (TTS audio URL, future use). +# Always optional — readers must tolerate missing column / empty / invalid JSON. +_MIGRATION_ADD_MSG_EXTRAS = """ +ALTER TABLE messages ADD COLUMN extras TEXT NOT NULL DEFAULT ''; +""" + DEFAULT_MAX_AGE_DAYS: int = 30 @@ -169,20 +176,26 @@ def _group_into_display_turns( cur_rest: List[tuple] = [] started = False - for role, raw_content, created_at in rows: + for role, raw_content, created_at, raw_extras in rows: try: content = json.loads(raw_content) except Exception: content = raw_content + try: + extras = json.loads(raw_extras) if raw_extras else {} + if not isinstance(extras, dict): + extras = {} + except Exception: + extras = {} if role == "user" and _is_visible_user_message(content): if started: groups.append((cur_user, cur_rest)) - cur_user = (content, created_at) + cur_user = (content, created_at, extras) cur_rest = [] started = True else: - cur_rest.append((role, content, created_at)) + cur_rest.append((role, content, created_at, extras)) if started: groups.append((cur_user, cur_rest)) @@ -195,7 +208,7 @@ def _group_into_display_turns( for user_row, rest in groups: # User turn if user_row: - content, created_at = user_row + content, created_at, _u_extras = user_row text = _extract_display_text(content) if text: turns.append({"role": "user", "content": text, "created_at": created_at}) @@ -206,8 +219,11 @@ def _group_into_display_turns( tool_results: Dict[str, str] = {} final_text = "" final_ts: Optional[int] = None + merged_extras: Dict[str, Any] = {} - for role, content, created_at in rest: + for role, content, created_at, extras in rest: + if role == "assistant" and isinstance(extras, dict): + merged_extras.update(extras) if role == "user": tool_results.update(_extract_tool_results(content)) elif role == "assistant": @@ -256,6 +272,8 @@ def _group_into_display_turns( "steps": steps, "created_at": final_ts or (user_row[1] if user_row else 0), } + if merged_extras: + turn["extras"] = merged_extras turns.append(turn) return turns @@ -411,13 +429,15 @@ class ConversationStore: content = json.dumps( msg.get("content", ""), ensure_ascii=False ) + extras_obj = msg.get("extras") or {} + extras = json.dumps(extras_obj, ensure_ascii=False) if extras_obj else "" conn.execute( """ INSERT OR IGNORE INTO messages - (session_id, seq, role, content, created_at) - VALUES (?, ?, ?, ?, ?) + (session_id, seq, role, content, created_at, extras) + VALUES (?, ?, ?, ?, ?, ?) """, - (session_id, next_seq, role, content, now), + (session_id, next_seq, role, content, now, extras), ) next_seq += 1 @@ -651,6 +671,55 @@ class ConversationStore: logger.info(f"[ConversationStore] Pruned {deleted} expired sessions") return deleted + def attach_extras_to_last_assistant( + self, + session_id: str, + extras: Dict[str, Any], + ) -> Optional[int]: + """ + Merge ``extras`` into the latest assistant message of a session. + + Used by post-processing (e.g. TTS) that needs to annotate an already + persisted bot reply with attachments such as audio URLs. + + Returns the message seq that was updated, or ``None`` if no assistant + message exists or the update could not be applied. + """ + if not extras: + return None + with self._lock: + conn = self._connect() + try: + row = conn.execute( + """ + SELECT seq, extras FROM messages + WHERE session_id = ? AND role = 'assistant' + ORDER BY seq DESC LIMIT 1 + """, + (session_id,), + ).fetchone() + if not row: + return None + seq, raw = row + try: + cur = json.loads(raw) if raw else {} + if not isinstance(cur, dict): + cur = {} + except Exception: + cur = {} + cur.update(extras) + conn.execute( + "UPDATE messages SET extras = ? WHERE session_id = ? AND seq = ?", + (json.dumps(cur, ensure_ascii=False), session_id, seq), + ) + conn.commit() + return seq + except Exception as e: + logger.warning(f"[ConversationStore] attach_extras failed: {e}") + return None + finally: + conn.close() + def load_history_page( self, session_id: str, @@ -698,15 +767,31 @@ class ConversationStore: ).fetchone() ctx_start = ctx_row[0] if ctx_row else 0 - rows = conn.execute( - """ - SELECT seq, role, content, created_at - FROM messages - WHERE session_id = ? - ORDER BY seq ASC - """, - (session_id,), - ).fetchall() + # extras column is added by migration; tolerate older DBs that + # might miss it by falling back to a NULL literal. + try: + rows = conn.execute( + """ + SELECT seq, role, content, created_at, extras + FROM messages + WHERE session_id = ? + ORDER BY seq ASC + """, + (session_id,), + ).fetchall() + except sqlite3.OperationalError: + rows = [ + (seq, role, content, created_at, "") + for (seq, role, content, created_at) in conn.execute( + """ + SELECT seq, role, content, created_at + FROM messages + WHERE session_id = ? + ORDER BY seq ASC + """, + (session_id,), + ).fetchall() + ] finally: conn.close() @@ -719,13 +804,16 @@ class ConversationStore: include_thinking = False # Strip seq for display grouping, but record max seq per visible user group - plain_rows = [(role, content, created_at) for _seq, role, content, created_at in rows] + plain_rows = [ + (role, content, created_at, extras_raw) + for _seq, role, content, created_at, extras_raw in rows + ] visible = _group_into_display_turns(plain_rows, include_thinking=include_thinking) # Build a mapping: find the seq of each visible user message to annotate context boundary. # Walk through rows to find visible user message seqs in order. visible_user_seqs: List[int] = [] - for seq, role, raw_content, _ts in rows: + for seq, role, raw_content, _ts, _extras in rows: if role != "user": continue try: @@ -911,6 +999,18 @@ class ConversationStore: except Exception as e: logger.warning(f"[ConversationStore] Migration (context_start_seq) failed: {e}") + msg_cols = { + row[1] + for row in conn.execute("PRAGMA table_info(messages)").fetchall() + } + if "extras" not in msg_cols: + try: + conn.execute(_MIGRATION_ADD_MSG_EXTRAS) + conn.commit() + logger.info("[ConversationStore] Migrated: added messages.extras column") + except Exception as e: + logger.warning(f"[ConversationStore] Migration (extras) failed: {e}") + def _connect(self) -> sqlite3.Connection: conn = sqlite3.connect(str(self._db_path), timeout=10) conn.execute("PRAGMA journal_mode=WAL") diff --git a/agent/memory/embedding/state.py b/agent/memory/embedding/state.py index 3fb60b23..5efffef2 100644 --- a/agent/memory/embedding/state.py +++ b/agent/memory/embedding/state.py @@ -31,9 +31,13 @@ def detect_index_dim(storage) -> Optional[int]: if not row or not row["embedding"]: return None try: - emb = json.loads(row["embedding"]) + raw = row["embedding"] + if isinstance(raw, (bytes, bytearray)): + # New BLOB format: 4 bytes per float32 + return len(raw) // 4 + emb = json.loads(raw) return len(emb) if isinstance(emb, list) else None - except (json.JSONDecodeError, TypeError): + except (json.JSONDecodeError, TypeError, Exception): return None diff --git a/agent/memory/manager.py b/agent/memory/manager.py index 6aaac767..5ec2ade7 100644 --- a/agent/memory/manager.py +++ b/agent/memory/manager.py @@ -13,7 +13,7 @@ from datetime import datetime, timedelta from agent.memory.config import MemoryConfig, get_default_memory_config from agent.memory.storage import MemoryStorage, MemoryChunk, SearchResult from agent.memory.chunker import TextChunker -from agent.memory.embedding import EmbeddingProvider +from agent.memory.embedding import EmbeddingProvider, EmbeddingCache from agent.memory.summarizer import MemoryFlushManager, create_memory_files_if_needed @@ -61,7 +61,11 @@ class MemoryManager: logger.info( "[MemoryManager] No embedding provider; memory will use keyword search only" ) - + + # Cache for query embeddings (avoids redundant API calls within a session) + self._embedding_cache = EmbeddingCache() + + # Initialize memory flush manager workspace_dir = self.config.get_workspace() self.flush_manager = MemoryFlushManager( @@ -128,7 +132,14 @@ class MemoryManager: vector_results = [] if self.embedding_provider: try: - query_embedding = self.embedding_provider.embed_query(query) + provider_name = type(self.embedding_provider).__name__ + model_name = getattr(self.embedding_provider, 'model', '') + cached = self._embedding_cache.get(query, provider_name, model_name) + if cached is not None: + query_embedding = cached + else: + query_embedding = self.embedding_provider.embed_query(query) + self._embedding_cache.put(query, provider_name, model_name, query_embedding) vector_results = self.storage.search_vector( query_embedding=query_embedding, user_id=user_id, diff --git a/agent/memory/storage.py b/agent/memory/storage.py index 0a4e6edb..683b083f 100644 --- a/agent/memory/storage.py +++ b/agent/memory/storage.py @@ -5,12 +5,42 @@ Provides vector and keyword search capabilities """ from __future__ import annotations +import re import sqlite3 import json import hashlib +import threading from typing import List, Dict, Optional, Any from pathlib import Path from dataclasses import dataclass +try: + import numpy as np + _HAS_NUMPY = True +except ImportError: + _HAS_NUMPY = False + np = None # type: ignore[assignment] + +# UPSERT (INSERT … ON CONFLICT DO UPDATE) requires SQLite ≥ 3.24.0 (2018). +# Older systems (e.g. CentOS 7 ships SQLite 3.7) fall back to INSERT OR REPLACE, +# which risks FTS5 rowid drift on chunk updates (see save_chunk docstring). +_HAS_UPSERT = sqlite3.sqlite_version_info >= (3, 24, 0) + +# --------------------------------------------------------------------------- +# CJK character ranges, compiled once at module load. +# Covers: CJK Symbols/Punctuation, Japanese kana (hiragana + katakana), +# CJK Unified Ideographs + Extension A, Korean syllables (Hangul), +# CJK Compatibility Ideographs, and CJK Extension B–F. +# --------------------------------------------------------------------------- +_CJK_RANGES = ( + r'\u3000-\u30ff' # CJK Symbols/Punctuation + Japanese kana + r'\u3400-\u9fff' # CJK Unified Ideographs (incl. Extension A) + r'\uac00-\ud7af' # Korean syllables (Hangul) + r'\uf900-\ufaff' # CJK Compatibility Ideographs + r'\U00020000-\U0002fa1f' # CJK Extension B–F +) +_RE_CONTAINS_CJK = re.compile(f'[{_CJK_RANGES}]') +_RE_CJK_WORDS = re.compile(f'[{_CJK_RANGES}]+') +_RE_TRIGRAM_TOKENS = re.compile(f'[{_CJK_RANGES}]+|[A-Za-z0-9_]+') @dataclass @@ -48,6 +78,10 @@ class MemoryStorage: self.db_path = db_path self.conn: Optional[sqlite3.Connection] = None self.fts5_available = False # Track FTS5 availability + # RLock protects concurrent writes from the same process. + # SQLite WAL mode handles read/write concurrency at the file level, + # but same-process concurrent writes still need a Python-level lock. + self._lock = threading.RLock() self._init_db() def _check_fts5_support(self) -> bool: @@ -69,6 +103,14 @@ class MemoryStorage: # Check FTS5 support self.fts5_available = self._check_fts5_support() + if not _HAS_UPSERT: + from common.log import logger + logger.warning( + "[MemoryStorage] SQLite %s < 3.24 — UPSERT unavailable. " + "Falling back to INSERT OR REPLACE; FTS5 rowid may drift on " + "chunk updates (rebuild index periodically to recover).", + sqlite3.sqlite_version, + ) if not self.fts5_available: from common.log import logger logger.debug("[MemoryStorage] FTS5 not available, using LIKE-based keyword search") @@ -175,6 +217,75 @@ class MemoryStorage: ) self._rebuild_fts5_from_chunks() + # Internal key-value store for persistent flags (e.g. backfill tracking) + self.conn.execute(""" + CREATE TABLE IF NOT EXISTS _meta ( + key TEXT PRIMARY KEY, + value TEXT NOT NULL + ) + """) + + # Create trigram FTS5 table for CJK / mixed-language search + self.trigram_fts5_available = False + if self.fts5_available: + try: + self.conn.execute(""" + CREATE VIRTUAL TABLE IF NOT EXISTS chunks_fts_trigram USING fts5( + text, + id UNINDEXED, + user_id UNINDEXED, + path UNINDEXED, + source UNINDEXED, + scope UNINDEXED, + content='chunks', + content_rowid='rowid', + tokenize='trigram case_sensitive 0' + ) + """) + self.conn.execute(""" + CREATE TRIGGER IF NOT EXISTS chunks_trigram_ai + AFTER INSERT ON chunks BEGIN + INSERT INTO chunks_fts_trigram(rowid, text, id, user_id, path, source, scope) + VALUES (new.rowid, new.text, new.id, new.user_id, new.path, new.source, new.scope); + END + """) + self.conn.execute(""" + CREATE TRIGGER IF NOT EXISTS chunks_trigram_ad + AFTER DELETE ON chunks BEGIN + DELETE FROM chunks_fts_trigram WHERE rowid = old.rowid; + END + """) + self.conn.execute(""" + CREATE TRIGGER IF NOT EXISTS chunks_trigram_au + AFTER UPDATE ON chunks BEGIN + UPDATE chunks_fts_trigram + SET text=new.text, id=new.id, user_id=new.user_id, + path=new.path, source=new.source, scope=new.scope + WHERE rowid = new.rowid; + END + """) + # One-time backfill for existing rows. + # NOTE: COUNT(*) on an FTS5 content table always returns 0, so we + # use a persistent flag in _meta instead of counting trigram rows. + backfill_done = self.conn.execute( + "SELECT 1 FROM _meta WHERE key = 'trigram_backfill_done'" + ).fetchone() + chunks_count = self.conn.execute( + "SELECT COUNT(*) as c FROM chunks" + ).fetchone()['c'] + if chunks_count > 0 and not backfill_done: + self.conn.execute( + "INSERT INTO chunks_fts_trigram(chunks_fts_trigram) VALUES('rebuild')" + ) + self.conn.execute( + "INSERT OR REPLACE INTO _meta(key, value) VALUES('trigram_backfill_done', '1')" + ) + self.trigram_fts5_available = True + except Exception: + from common.log import logger + logger.warning("[MemoryStorage] trigram FTS5 unavailable, CJK search will use LIKE fallback", exc_info=True) + self.trigram_fts5_available = False + # Create files metadata table self.conn.execute(""" CREATE TABLE IF NOT EXISTS files ( @@ -186,7 +297,7 @@ class MemoryStorage: updated_at INTEGER DEFAULT (strftime('%s', 'now')) ) """) - + self.conn.commit() def _fts5_state_inconsistent(self) -> bool: @@ -299,43 +410,98 @@ class MemoryStorage: self.conn.commit() def save_chunk(self, chunk: MemoryChunk): - """Save a memory chunk""" - self.conn.execute(""" - INSERT OR REPLACE INTO chunks - (id, user_id, scope, source, path, start_line, end_line, text, embedding, hash, metadata, updated_at) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, strftime('%s', 'now')) - """, ( - chunk.id, - chunk.user_id, - chunk.scope, - chunk.source, - chunk.path, - chunk.start_line, - chunk.end_line, - chunk.text, - json.dumps(chunk.embedding) if chunk.embedding else None, + """Save a memory chunk (insert or update by id). + + Uses SQLite UPSERT (INSERT … ON CONFLICT DO UPDATE) instead of + INSERT OR REPLACE. INSERT OR REPLACE internally does DELETE+INSERT, + which changes the row's rowid. Because both FTS5 tables use + content_rowid='rowid', a new rowid would leave the old FTS index + entries pointing at a non-existent rowid and trigger + "fts5: missing row N from content table" errors. + ON CONFLICT DO UPDATE fires the AFTER UPDATE trigger (chunks_au / + chunks_trigram_au) and keeps the original rowid intact. + """ + if _HAS_UPSERT: + _SQL = """ + INSERT INTO chunks + (id, user_id, scope, source, path, start_line, end_line, + text, embedding, hash, metadata, updated_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, strftime('%s', 'now')) + ON CONFLICT(id) DO UPDATE SET + user_id = excluded.user_id, + scope = excluded.scope, + source = excluded.source, + path = excluded.path, + start_line = excluded.start_line, + end_line = excluded.end_line, + text = excluded.text, + embedding = excluded.embedding, + hash = excluded.hash, + metadata = excluded.metadata, + updated_at = strftime('%s', 'now') + """ + else: + _SQL = """ + INSERT OR REPLACE INTO chunks + (id, user_id, scope, source, path, start_line, end_line, + text, embedding, hash, metadata, updated_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, strftime('%s', 'now')) + """ + params = ( + chunk.id, chunk.user_id, chunk.scope, chunk.source, chunk.path, + chunk.start_line, chunk.end_line, chunk.text, + self._encode_embedding(chunk.embedding), chunk.hash, - json.dumps(chunk.metadata) if chunk.metadata else None - )) - self.conn.commit() - + json.dumps(chunk.metadata) if chunk.metadata else None, + ) + with self._lock: + self.conn.execute(_SQL, params) + self.conn.commit() + def save_chunks_batch(self, chunks: List[MemoryChunk]): - """Save multiple chunks in a batch""" - self.conn.executemany(""" - INSERT OR REPLACE INTO chunks - (id, user_id, scope, source, path, start_line, end_line, text, embedding, hash, metadata, updated_at) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, strftime('%s', 'now')) - """, [ + """Save multiple chunks in a batch (insert or update by id). + + See save_chunk for why UPSERT is used instead of INSERT OR REPLACE. + """ + if _HAS_UPSERT: + _SQL = """ + INSERT INTO chunks + (id, user_id, scope, source, path, start_line, end_line, + text, embedding, hash, metadata, updated_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, strftime('%s', 'now')) + ON CONFLICT(id) DO UPDATE SET + user_id = excluded.user_id, + scope = excluded.scope, + source = excluded.source, + path = excluded.path, + start_line = excluded.start_line, + end_line = excluded.end_line, + text = excluded.text, + embedding = excluded.embedding, + hash = excluded.hash, + metadata = excluded.metadata, + updated_at = strftime('%s', 'now') + """ + else: + _SQL = """ + INSERT OR REPLACE INTO chunks + (id, user_id, scope, source, path, start_line, end_line, + text, embedding, hash, metadata, updated_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, strftime('%s', 'now')) + """ + params_list = [ ( c.id, c.user_id, c.scope, c.source, c.path, c.start_line, c.end_line, c.text, - json.dumps(c.embedding) if c.embedding else None, + self._encode_embedding(c.embedding), c.hash, - json.dumps(c.metadata) if c.metadata else None + json.dumps(c.metadata) if c.metadata else None, ) for c in chunks - ]) - self.conn.commit() + ] + with self._lock: + self.conn.executemany(_SQL, params_list) + self.conn.commit() def get_chunk(self, chunk_id: str) -> Optional[MemoryChunk]: """Get a chunk by ID""" @@ -356,21 +522,21 @@ class MemoryStorage: limit: int = 10 ) -> List[SearchResult]: """ - Vector similarity search using in-memory cosine similarity - (sqlite-vec can be added later for better performance) + Vector similarity search using numpy-vectorized cosine similarity. + All embeddings are loaded then scored in a single BLAS matrix-vector + multiply, which is ~100x faster than the pure-Python per-row loop. """ if scopes is None: scopes = ["shared"] if user_id: scopes.append("user") - - # Build query + scope_placeholders = ','.join('?' * len(scopes)) - params = scopes - + params = list(scopes) + if user_id: query = f""" - SELECT * FROM chunks + SELECT * FROM chunks WHERE scope IN ({scope_placeholders}) AND (scope = 'shared' OR user_id = ?) AND embedding IS NOT NULL @@ -378,51 +544,95 @@ class MemoryStorage: params.append(user_id) else: query = f""" - SELECT * FROM chunks + SELECT * FROM chunks WHERE scope IN ({scope_placeholders}) AND embedding IS NOT NULL """ - + rows = self.conn.execute(query, params).fetchall() + if not rows: + return [] - # Calculate cosine similarity. We probe the first row's dim to fail - # loudly on a query/index dim mismatch — otherwise every doc would - # score 0 silently, leaving the user wondering why search broke. - results = [] - query_dim = len(query_embedding) - if rows: - first = json.loads(rows[0]['embedding']) - if isinstance(first, list) and len(first) != query_dim: - raise ValueError( - f"Embedding dim mismatch: query is {query_dim}-dim but " - f"index stores {len(first)}-dim vectors. The configured " - f"embedding model differs from the one that built the " - f"index — run /memory rebuild-index to re-embed." - ) - + # Parse embeddings and build a (N, D) matrix in one pass. + # New rows store BLOB bytes (np.frombuffer); legacy rows fall back to JSON. + # Filter out rows whose embedding dimension differs from the query — + # mixing dimensions would cause np.array() to produce an object array + # and matrix @ q_vec to raise ValueError. + expected_dim = len(query_embedding) + valid_rows = [] + vectors = [] for row in rows: - embedding = json.loads(row['embedding']) - similarity = self._cosine_similarity(query_embedding, embedding) + vec = self._decode_embedding(row['embedding']) + if not vec: + continue + if len(vec) != expected_dim: + from common.log import logger + logger.warning( + "[MemoryStorage] Skipping chunk %s: embedding dim %d != query dim %d", + row['id'], len(vec), expected_dim + ) + continue + valid_rows.append(row) + vectors.append(vec) - if similarity > 0: - results.append((similarity, row)) - - # Sort by similarity and limit - results.sort(key=lambda x: x[0], reverse=True) - results = results[:limit] - - return [ - SearchResult( - path=row['path'], - start_line=row['start_line'], - end_line=row['end_line'], - score=score, - snippet=self._truncate_text(row['text'], 500), - source=row['source'], - user_id=row['user_id'] - ) - for score, row in results - ] + if not vectors: + return [] + + if _HAS_NUMPY: + matrix = np.array(vectors, dtype=np.float32) # (N, D) + q_vec = np.array(query_embedding, dtype=np.float32) # (D,) + + # Vectorized cosine similarity: dot(matrix, q) / (||matrix|| * ||q||) + dots = matrix @ q_vec # (N,) + row_norms = np.linalg.norm(matrix, axis=1) # (N,) + q_norm = float(np.linalg.norm(q_vec)) + denominators = row_norms * q_norm + np.maximum(denominators, 1e-10, out=denominators) # avoid div-by-zero + sims = dots / denominators # (N,) + + # Select TopK using argpartition (O(N) average), then sort only those K + k = min(limit, len(valid_rows)) + top_idx = np.argpartition(sims, -k)[-k:] + top_idx = top_idx[np.argsort(sims[top_idx])[::-1]] + + return [ + SearchResult( + path=valid_rows[i]['path'], + start_line=valid_rows[i]['start_line'], + end_line=valid_rows[i]['end_line'], + score=float(sims[i]), + snippet=self._truncate_text(valid_rows[i]['text'], 500), + source=valid_rows[i]['source'], + user_id=valid_rows[i]['user_id'] + ) + for i in top_idx + if sims[i] > 0 + ] + else: + # Pure-Python cosine similarity fallback (numpy not installed) + import math + q = query_embedding + q_norm = math.sqrt(sum(x * x for x in q)) or 1e-10 + scored = [] + for i, vec in enumerate(vectors): + dot = sum(a * b for a, b in zip(vec, q)) + v_norm = math.sqrt(sum(x * x for x in vec)) or 1e-10 + sim = dot / (v_norm * q_norm) + if sim > 0: + scored.append((sim, valid_rows[i])) + scored.sort(key=lambda x: x[0], reverse=True) + return [ + SearchResult( + path=row['path'], + start_line=row['start_line'], + end_line=row['end_line'], + score=sim, + snippet=self._truncate_text(row['text'], 500), + source=row['source'], + user_id=row['user_id'] + ) + for sim, row in scored[:limit] + ] def search_keyword( self, @@ -445,12 +655,37 @@ class MemoryStorage: if user_id: scopes.append("user") - if self.fts5_available: + # Step 1: Standard FTS5 (unicode61) — pure ASCII queries only. + # Skipped when query contains any CJK characters: unicode61 tokenises CJK + # as individual characters without forming meaningful tokens, so it would + # match only the ASCII portion of a mixed query (e.g. "Python" from + # "Python教程") and silently discard the CJK part. Those queries go + # directly to Step 2 (trigram), which handles both ASCII and CJK together. + fts1_attempted = False + if (self.fts5_available + and not MemoryStorage._contains_cjk(query) + and MemoryStorage._build_fts_query(query)): + fts1_attempted = True fts_results = self._search_fts5(query, user_id, scopes, limit) if fts_results: return fts_results - return self._search_like(query, user_id, scopes, limit) + # Step 2: Trigram FTS5 — CJK/mixed queries, plus fallback when unicode61 + # returned nothing (trigram indexes all scripts with 3-char sliding windows, + # so it can catch terms that unicode61 tokenisation misses). + if self.trigram_fts5_available and ( + MemoryStorage._contains_cjk(query) or fts1_attempted + ): + trigram_results = self._search_fts5_trigram(query, user_id, scopes, limit) + if trigram_results: + return trigram_results + + # Step 3: LIKE fallback — last resort (FTS5 unavailable, or CJK tokens + # shorter than 3 characters that trigram cannot match, e.g. a single-char query). + if not self.fts5_available or MemoryStorage._contains_cjk(query): + return self._search_like(query, user_id, scopes, limit) + + return [] def _search_fts5( self, @@ -471,7 +706,7 @@ class MemoryStorage: sql_query = f""" SELECT chunks.*, bm25(chunks_fts) as rank FROM chunks_fts - JOIN chunks ON chunks.id = chunks_fts.id + JOIN chunks ON chunks.rowid = chunks_fts.rowid WHERE chunks_fts MATCH ? AND chunks.scope IN ({scope_placeholders}) AND (chunks.scope = 'shared' OR chunks.user_id = ?) @@ -483,7 +718,7 @@ class MemoryStorage: sql_query = f""" SELECT chunks.*, bm25(chunks_fts) as rank FROM chunks_fts - JOIN chunks ON chunks.id = chunks_fts.id + JOIN chunks ON chunks.rowid = chunks_fts.rowid WHERE chunks_fts MATCH ? AND chunks.scope IN ({scope_placeholders}) ORDER BY rank @@ -505,13 +740,11 @@ class MemoryStorage: ) for row in rows ] - except Exception as e: + except Exception: from common.log import logger - logger.error( - f"[MemoryStorage] FTS5 search failed (caller will fall back to LIKE): {e}" - ) + logger.warning("[MemoryStorage] _search_fts5 failed, returning empty", exc_info=True) return [] - + def _search_like( self, query: str, @@ -522,12 +755,11 @@ class MemoryStorage: """LIKE-based search. Used as the keyword-search fallback when FTS5 is unavailable, fails, - or returns empty. Supports both CJK runs and ASCII word tokens so it - can serve as a true safety net for any query. + or returns empty. Supports both CJK runs (1+ chars) and ASCII word + tokens (3+ chars) so it can serve as a true safety net for any query. """ - import re - # CJK runs (2+ chars) + ASCII word tokens (3+ chars to avoid noise) - cjk_words = re.findall(r'[\u4e00-\u9fff]{2,}', query) + # CJK runs (1+ chars, wide Unicode range) + ASCII words (3+ chars to avoid noise) + cjk_words = _RE_CJK_WORDS.findall(query) ascii_words = [t for t in re.findall(r'[A-Za-z0-9_]+', query) if len(t) >= 3] words = cjk_words + ascii_words if not words: @@ -565,44 +797,54 @@ class MemoryStorage: try: rows = self.conn.execute(sql_query, params).fetchall() - return [ - SearchResult( + results = [] + for row in rows: + # Dynamic score: reward chunks that contain more of the query words. + # Use all tokens (CJK + ASCII) so pure-ASCII queries are not skipped. + # matched_count is always ≥1 because the WHERE clause uses OR, but + # guard defensively so unexpected zero-match rows are never surfaced. + text_lower = row['text'].lower() + matched_count = sum(1 for w in words if w.lower() in text_lower) + if matched_count == 0: + continue + score = min(0.85, 0.3 + 0.15 * matched_count) + results.append(SearchResult( path=row['path'], start_line=row['start_line'], end_line=row['end_line'], - score=0.5, # Fixed score for LIKE search + score=score, snippet=self._truncate_text(row['text'], 500), source=row['source'], user_id=row['user_id'] - ) - for row in rows - ] - except Exception as e: + )) + results.sort(key=lambda r: r.score, reverse=True) + return results + except Exception: from common.log import logger - logger.error(f"[MemoryStorage] LIKE search failed: {e}") + logger.warning("[MemoryStorage] _search_like failed, returning empty", exc_info=True) return [] - + def delete_by_path(self, path: str): """Delete all chunks from a file""" - self.conn.execute(""" - DELETE FROM chunks WHERE path = ? - """, (path,)) - self.conn.commit() - + with self._lock: + self.conn.execute("DELETE FROM chunks WHERE path = ?", (path,)) + self.conn.commit() + def get_file_hash(self, path: str) -> Optional[str]: """Get stored file hash""" row = self.conn.execute(""" SELECT hash FROM files WHERE path = ? """, (path,)).fetchone() return row['hash'] if row else None - + def update_file_metadata(self, path: str, source: str, file_hash: str, mtime: int, size: int): """Update file metadata""" - self.conn.execute(""" - INSERT OR REPLACE INTO files (path, source, hash, mtime, size, updated_at) - VALUES (?, ?, ?, ?, ?, strftime('%s', 'now')) - """, (path, source, file_hash, mtime, size)) - self.conn.commit() + with self._lock: + self.conn.execute(""" + INSERT OR REPLACE INTO files (path, source, hash, mtime, size, updated_at) + VALUES (?, ?, ?, ?, ?, strftime('%s', 'now')) + """, (path, source, file_hash, mtime, size)) + self.conn.commit() def get_stats(self) -> Dict[str, int]: """Get storage statistics""" @@ -632,7 +874,8 @@ class MemoryStorage: self.conn.close() self.conn = None # Mark as closed except Exception as e: - print(f"⚠️ Error closing database connection: {e}") + from common.log import logger + logger.warning("[MemoryStorage] Error closing database connection: %s", e) def __del__(self): """Destructor to ensure connection is closed""" @@ -642,7 +885,33 @@ class MemoryStorage: pass # Ignore errors during cleanup # Helper methods - + + @staticmethod + def _encode_embedding(embedding: Optional[List[float]]) -> Optional[bytes]: + """Encode embedding as float32 BLOB bytes (~6x smaller and faster than JSON). + Falls back to struct.pack when numpy is unavailable.""" + if embedding is None: + return None + if _HAS_NUMPY: + return np.array(embedding, dtype=np.float32).tobytes() + import struct + return struct.pack(f'{len(embedding)}f', *embedding) + + @staticmethod + def _decode_embedding(raw) -> Optional[List[float]]: + """Decode embedding from BLOB bytes or legacy JSON string. + Handles both numpy and numpy-free environments.""" + if raw is None: + return None + if isinstance(raw, (bytes, bytearray)): + if _HAS_NUMPY: + return np.frombuffer(raw, dtype=np.float32).tolist() + import struct + n = len(raw) // 4 + return list(struct.unpack(f'{n}f', raw)) + # Legacy JSON format written by older versions + return json.loads(raw) + def _row_to_chunk(self, row) -> MemoryChunk: """Convert database row to MemoryChunk""" return MemoryChunk( @@ -654,32 +923,89 @@ class MemoryStorage: start_line=row['start_line'], end_line=row['end_line'], text=row['text'], - embedding=json.loads(row['embedding']) if row['embedding'] else None, + embedding=self._decode_embedding(row['embedding']), hash=row['hash'], metadata=json.loads(row['metadata']) if row['metadata'] else None ) @staticmethod - def _cosine_similarity(vec1: List[float], vec2: List[float]) -> float: - """Calculate cosine similarity between two vectors""" - if len(vec1) != len(vec2): - return 0.0 - - dot_product = sum(a * b for a, b in zip(vec1, vec2)) - norm1 = sum(a * a for a in vec1) ** 0.5 - norm2 = sum(b * b for b in vec2) ** 0.5 - - if norm1 == 0 or norm2 == 0: - return 0.0 - - return dot_product / (norm1 * norm2) + def _contains_cjk(text: str) -> bool: + """Check if text contains CJK or related characters (Chinese, Japanese, Korean).""" + return bool(_RE_CONTAINS_CJK.search(text)) @staticmethod - def _contains_cjk(text: str) -> bool: - """Check if text contains CJK (Chinese/Japanese/Korean) characters""" - import re - return bool(re.search(r'[\u4e00-\u9fff]', text)) - + def _build_trigram_query(raw_query: str) -> Optional[str]: + """ + Build FTS5 MATCH query for the trigram tokenizer. + Extracts CJK sequences (including single characters) and ASCII words, + joining them with AND so all terms must appear in the matched chunk. + """ + tokens = _RE_TRIGRAM_TOKENS.findall(raw_query) + tokens = [t for t in tokens if t] + if not tokens: + return None + # Escape embedded double-quotes (FTS5 uses "" inside quoted phrases) + quoted = [f'"{t.replace(chr(34), chr(34)*2)}"' for t in tokens] + return ' AND '.join(quoted) + + def _search_fts5_trigram( + self, + query: str, + user_id: Optional[str], + scopes: List[str], + limit: int + ) -> List[SearchResult]: + """Trigram FTS5 search — handles CJK and mixed queries with BM25 ranking.""" + trigram_query = self._build_trigram_query(query) + if not trigram_query: + return [] + + scope_placeholders = ','.join('?' * len(scopes)) + params = [trigram_query] + list(scopes) + + if user_id: + sql = f""" + SELECT chunks.*, bm25(chunks_fts_trigram) as rank + FROM chunks_fts_trigram + JOIN chunks ON chunks.rowid = chunks_fts_trigram.rowid + WHERE chunks_fts_trigram MATCH ? + AND chunks.scope IN ({scope_placeholders}) + AND (chunks.scope = 'shared' OR chunks.user_id = ?) + ORDER BY rank + LIMIT ? + """ + params.extend([user_id, limit]) + else: + sql = f""" + SELECT chunks.*, bm25(chunks_fts_trigram) as rank + FROM chunks_fts_trigram + JOIN chunks ON chunks.rowid = chunks_fts_trigram.rowid + WHERE chunks_fts_trigram MATCH ? + AND chunks.scope IN ({scope_placeholders}) + ORDER BY rank + LIMIT ? + """ + params.append(limit) + + try: + rows = self.conn.execute(sql, params).fetchall() + return [ + SearchResult( + path=row['path'], + start_line=row['start_line'], + end_line=row['end_line'], + score=self._bm25_rank_to_score(row['rank']), + snippet=self._truncate_text(row['text'], 500), + source=row['source'], + user_id=row['user_id'] + ) + for row in rows + ] + except Exception: + from common.log import logger + logger.warning("[MemoryStorage] _search_fts5_trigram failed, returning empty", exc_info=True) + return [] + @staticmethod def _build_fts_query(raw_query: str) -> Optional[str]: """ @@ -688,7 +1014,6 @@ class MemoryStorage: Works best for English and word-based languages. For CJK characters, LIKE search will be used as fallback. """ - import re # Extract words (primarily English words and numbers) tokens = re.findall(r'[A-Za-z0-9_]+', raw_query) if not tokens: @@ -701,9 +1026,22 @@ class MemoryStorage: @staticmethod def _bm25_rank_to_score(rank: float) -> float: - """Convert BM25 rank to 0-1 score""" - normalized = max(0, rank) if rank is not None else 999 - return 1 / (1 + normalized) + """Convert SQLite BM25 rank to a [0, 1) relevance score. + + SQLite's bm25() returns a non-positive float (0 or negative). + More negative = more relevant. max(0, rank) would clip every + negative value to 0, making every score 1/(1+0) = 1.0 and + destroying all ranking information. + + abs(rank) / (1 + abs(rank)) maps the absolute relevance magnitude + to [0, 1): larger |rank| (stronger match) → score closer to 1. + """ + if rank is None: + return 0.0 + # Add a floor of 0.3 so any FTS5 match always exceeds typical + # min_score thresholds (default 0.1). Small-corpus ranks close to + # 0 would otherwise produce score≈0 and be filtered out downstream. + return 0.3 + 0.69 * (abs(rank) / (1.0 + abs(rank))) @staticmethod def _truncate_text(text: str, max_chars: int) -> str: diff --git a/agent/protocol/__init__.py b/agent/protocol/__init__.py index a9fe5a3e..f0a7a4e2 100644 --- a/agent/protocol/__init__.py +++ b/agent/protocol/__init__.py @@ -3,6 +3,11 @@ from .agent_stream import AgentStreamExecutor from .task import Task, TaskType, TaskStatus from .result import AgentResult, AgentAction, AgentActionType, ToolResult from .models import LLMModel, LLMRequest, ModelFactory +from .cancel import ( + AgentCancelledError, + CancelTokenRegistry, + get_cancel_registry, +) __all__ = [ 'Agent', @@ -16,5 +21,8 @@ __all__ = [ 'ToolResult', 'LLMModel', 'LLMRequest', - 'ModelFactory' -] \ No newline at end of file + 'ModelFactory', + 'AgentCancelledError', + 'CancelTokenRegistry', + 'get_cancel_registry', +] diff --git a/agent/protocol/agent.py b/agent/protocol/agent.py index 285a9732..d944660b 100644 --- a/agent/protocol/agent.py +++ b/agent/protocol/agent.py @@ -365,7 +365,8 @@ class Agent: return action - def run_stream(self, user_message: str, on_event=None, clear_history: bool = False, skill_filter=None) -> str: + def run_stream(self, user_message: str, on_event=None, clear_history: bool = False, + skill_filter=None, cancel_event=None) -> str: """ Execute single agent task with streaming (based on tool-call) @@ -374,6 +375,7 @@ class Agent: - Multi-turn reasoning based on tool-call - Event callbacks - Persistent conversation history across calls + - User-initiated cancellation via ``cancel_event`` Args: user_message: User message @@ -381,6 +383,11 @@ class Agent: event = {"type": str, "timestamp": float, "data": dict} clear_history: If True, clear conversation history before this call (default: False) skill_filter: Optional list of skill names to include in this run + cancel_event: Optional threading.Event polled at agent checkpoints. + When set, the loop exits at the next safe point, injects a + "[Interrupted by user]" assistant note, and returns the + partial response. ``messages`` stays in a valid state + (tool_use/tool_result pairs preserved). Returns: Final response text @@ -424,7 +431,8 @@ class Agent: max_turns=self.max_steps, on_event=on_event, messages=messages_copy, # Pass copied message history - max_context_turns=max_context_turns + max_context_turns=max_context_turns, + cancel_event=cancel_event, ) # Execute diff --git a/agent/protocol/agent_stream.py b/agent/protocol/agent_stream.py index 75b4f4ff..e3be20b8 100644 --- a/agent/protocol/agent_stream.py +++ b/agent/protocol/agent_stream.py @@ -7,11 +7,19 @@ import json import time from typing import List, Dict, Any, Optional, Callable, Tuple +from agent.protocol.cancel import AgentCancelledError from agent.protocol.models import LLMRequest, LLMModel from agent.protocol.message_utils import sanitize_claude_messages, compress_turn_to_text_only from agent.tools.base_tool import BaseTool, ToolResult from common.log import logger +# Optional: repair malformed JSON args from non-strict providers (e.g. unescaped quotes in long content). +try: + from json_repair import repair_json as _repair_json + _HAS_JSON_REPAIR = True +except ImportError: + _HAS_JSON_REPAIR = False + # Maximum number of characters of model "reasoning / thinking" content to persist # in conversation history. The full reasoning is still streamed to the UI in real @@ -44,6 +52,30 @@ def _truncate_reasoning_for_storage(text: str) -> str: return head + _REASONING_TRUNCATE_MARKER.format(omitted=omitted) + tail +def _parse_tool_args(args_str: str, finish_reason: Optional[str]) -> Tuple[dict, Optional[str]]: + """Parse tool args JSON. Returns (args, error_msg); error_msg is None on success. + + On JSONDecodeError: detect truncation first (skip repair, surface max_tokens hint); + otherwise try json-repair for escape issues; finally fall back to the raw decoder error. + """ + if not args_str: + return {}, None + try: + return json.loads(args_str), None + except json.JSONDecodeError as e: + if finish_reason in ("length", "max_tokens") or not args_str.rstrip().endswith("}"): + return {}, "Output truncated (max_tokens reached). Split content into smaller chunks across multiple tool calls." + if _HAS_JSON_REPAIR: + try: + repaired = _repair_json(args_str, return_objects=True) + if isinstance(repaired, dict): + logger.warning(f"Tool args JSON repaired ({len(args_str)} chars)") + return repaired, None + except Exception: + pass + return {}, f"Invalid JSON in tool arguments: {e.msg}" + + class AgentStreamExecutor: """ Agent Stream Executor @@ -64,7 +96,8 @@ class AgentStreamExecutor: max_turns: int = 50, on_event: Optional[Callable] = None, messages: Optional[List[Dict]] = None, - max_context_turns: int = 30 + max_context_turns: int = 30, + cancel_event=None, ): """ Initialize stream executor @@ -78,6 +111,10 @@ class AgentStreamExecutor: on_event: Event callback function messages: Optional existing message history (for persistent conversations) max_context_turns: Maximum number of conversation turns to keep in context + cancel_event: Optional threading.Event used to signal user cancel. + Checked at every safe point (turn boundary, before tool execution, + during LLM streaming). When set, raises AgentCancelledError which + run_stream catches to gracefully wind down. """ self.agent = agent self.model = model @@ -87,6 +124,7 @@ class AgentStreamExecutor: self.max_turns = max_turns self.on_event = on_event self.max_context_turns = max_context_turns + self.cancel_event = cancel_event # Message history - use provided messages or create new list self.messages = messages if messages is not None else [] @@ -97,6 +135,73 @@ class AgentStreamExecutor: # Track files to send (populated by read tool) self.files_to_send = [] # List of file metadata dicts + def _check_cancelled(self) -> None: + """Raise AgentCancelledError if the user requested cancellation. + + Called at safe points (turn start, between tool calls, between LLM + chunks). Cheap to call: just an Event.is_set() probe. + """ + if self.cancel_event is not None and self.cancel_event.is_set(): + raise AgentCancelledError("agent cancelled by user") + + def _handle_cancelled(self, partial_response: str) -> None: + """Wind down ``self.messages`` after a user-initiated cancel. + + The messages list may be in any of these states when we get here: + (a) Last message is an assistant message containing tool_use + blocks but the matching tool_result has not been appended yet. + (b) Last message is an assistant text-only reply (cancel happened + right before the next turn started). + (c) Last message is a user tool_result message and we cancelled + between turns. + + For (a) we MUST synthesise tool_result blocks, otherwise the next + request will fail Claude/OpenAI's strict pairing validation. For + (b)/(c) the state is already valid and we just append a small + cancellation note so the user/LLM both see the boundary clearly. + """ + try: + # Step 1: close any orphaned tool_use in the trailing assistant + # message by injecting matching tool_result blocks. + if self.messages and isinstance(self.messages[-1], dict) \ + and self.messages[-1].get("role") == "assistant": + last = self.messages[-1] + content = last.get("content") + if isinstance(content, list): + pending_tool_use_ids = [ + block.get("id") + for block in content + if isinstance(block, dict) and block.get("type") == "tool_use" + ] + pending_tool_use_ids = [tid for tid in pending_tool_use_ids if tid] + if pending_tool_use_ids: + tool_result_blocks = [ + { + "type": "tool_result", + "tool_use_id": tid, + "content": "Cancelled by user before this tool finished.", + "is_error": True, + } + for tid in pending_tool_use_ids + ] + self.messages.append({ + "role": "user", + "content": tool_result_blocks, + }) + logger.info( + f"[Agent] Injected {len(tool_result_blocks)} cancellation " + f"tool_result blocks to keep message history valid" + ) + + # Step 2: append a stable "interrupted" marker so the LLM sees a + # clear stop boundary on the next turn. + self.messages.append({ + "role": "assistant", + "content": [{"type": "text", "text": "_(Cancelled by user)_"}], + }) + except Exception as e: + logger.warning(f"[Agent] _handle_cancelled cleanup failed: {e}") + def _emit_event(self, event_type: str, data: dict = None): """Emit event""" if self.on_event: @@ -270,8 +375,13 @@ class AgentStreamExecutor: final_response = "" turn = 0 + cancelled = False try: while turn < self.max_turns: + # Check at the very top of every turn so a cancel arriving + # between turns short-circuits cleanly. + self._check_cancelled() + turn += 1 logger.info(f"[Agent] 第 {turn} 轮") self._emit_event("turn_start", {"turn": turn}) @@ -375,6 +485,8 @@ class AgentStreamExecutor: try: for tool_call in tool_calls: + # Honour cancel between tool invocations within the same turn + self._check_cancelled() result = self._execute_tool(tool_call) tool_results.append(result) @@ -557,6 +669,15 @@ class AgentStreamExecutor: self.messages.pop(prompt_insert_idx) logger.debug("[Agent] Removed injected max-steps prompt from message history") + except AgentCancelledError: + # User-initiated stop: wind down message history cleanly so the + # next turn is unaffected; channels emit a "cancelled" UI event. + cancelled = True + logger.info(f"[Agent] 🛑 已被用户中止 (第 {turn} 轮)") + self._handle_cancelled(final_response) + if not final_response or not final_response.strip(): + final_response = "_(Cancelled)_" + except Exception as e: logger.error(f"❌ Agent执行错误: {e}") self._emit_event("error", {"error": str(e)}) @@ -564,8 +685,11 @@ class AgentStreamExecutor: finally: final_response = final_response.strip() if final_response else final_response - logger.info(f"[Agent] 🏁 完成 ({turn}轮)") - self._emit_event("agent_end", {"final_response": final_response}) + if cancelled: + # Emit before agent_end so channels can mark UI as cancelled + self._emit_event("agent_cancelled", {"final_response": final_response}) + logger.info(f"[Agent] 🏁 完成 ({turn}轮)" + (" [cancelled]" if cancelled else "")) + self._emit_event("agent_end", {"final_response": final_response, "cancelled": cancelled}) return final_response @@ -603,15 +727,24 @@ class AgentStreamExecutor: except Exception as e: logger.debug(f"[Agent] MCP sync skipped: {e}") - # Prepare tool definitions (OpenAI/Claude format) + # Prepare tool definitions. Prefer get_json_schema() when it yields + # real properties (lets tools augment schema at runtime), otherwise + # fall back to the static `tool.params` (MCP tools rely on this). tools_schema = None if self.tools: tools_schema = [] for tool in self.tools.values(): + input_schema = tool.params + try: + dynamic = (tool.get_json_schema() or {}).get("parameters") or {} + if dynamic.get("properties"): + input_schema = dynamic + except Exception: + pass tools_schema.append({ "name": tool.name, "description": tool.description, - "input_schema": tool.params # Claude uses input_schema + "input_schema": input_schema, }) # Create request @@ -635,7 +768,32 @@ class AgentStreamExecutor: try: stream = self.model.call_stream(request) + # Probe cancel every N chunks to bound reaction time without + # checking on every token. + _cancel_probe_counter = 0 + _CANCEL_PROBE_EVERY = 8 + for chunk in stream: + _cancel_probe_counter += 1 + if _cancel_probe_counter >= _CANCEL_PROBE_EVERY: + _cancel_probe_counter = 0 + if self.cancel_event is not None and self.cancel_event.is_set(): + # Persist partial text only; tool_use args may be + # truncated mid-stream and would fail validation. + logger.info("[Agent] cancel detected mid-stream, aborting LLM call") + if full_content: + partial_msg = { + "role": "assistant", + "content": [{"type": "text", "text": full_content}], + } + self.messages.append(partial_msg) + self._emit_event("message_end", { + "content": full_content, + "tool_calls": [], + "cancelled": True, + }) + raise AgentCancelledError("cancelled during LLM streaming") + # Check for errors if isinstance(chunk, dict) and chunk.get("error"): # Extract error message from nested structure @@ -729,6 +887,10 @@ class AgentStreamExecutor: elif isinstance(choice, dict) and choice.get("_gemini_raw_parts"): gemini_raw_parts = choice["_gemini_raw_parts"] + except AgentCancelledError: + # Must propagate untouched; never treat as a retryable error. + raise + except Exception as e: error_str = str(e) error_str_lower = error_str.lower() @@ -842,26 +1004,17 @@ class AgentStreamExecutor: import uuid tool_id = f"call_{uuid.uuid4().hex[:24]}" - try: - # Safely get arguments, handle None case - args_str = tc.get("arguments") or "" - arguments = json.loads(args_str) if args_str else {} - except json.JSONDecodeError as e: - # Handle None or invalid arguments safely - args_str = tc.get('arguments') or "" - args_preview = args_str[:200] if len(args_str) > 200 else args_str - logger.error(f"Failed to parse tool arguments for {tc['name']}") - logger.error(f"Arguments length: {len(args_str)} chars") - logger.error(f"Arguments preview: {args_preview}...") - logger.error(f"JSON decode error: {e}") - - # Return a clear error message to the LLM instead of empty dict - # This helps the LLM understand what went wrong + args_str = tc.get("arguments") or "" + arguments, parse_err = _parse_tool_args(args_str, stop_reason) + if parse_err: + logger.error( + f"Tool args parse failed for {tc['name']} ({len(args_str)} chars): {parse_err}" + ) tool_calls.append({ "id": tool_id, "name": tc["name"], "arguments": {}, - "_parse_error": f"Invalid JSON in tool arguments: {args_preview}... Error: {str(e)}. Tip: For large content, consider splitting into smaller chunks or using a different approach." + "_parse_error": parse_err, }) continue @@ -949,14 +1102,11 @@ class AgentStreamExecutor: tool_id = tool_call["id"] arguments = tool_call["arguments"] - # Check if there was a JSON parse error if "_parse_error" in tool_call: - parse_error = tool_call["_parse_error"] - logger.error(f"Skipping tool execution due to parse error: {parse_error}") result = { "status": "error", - "result": f"Failed to parse tool arguments. {parse_error}. Please ensure your tool call uses valid JSON format with all required parameters.", - "execution_time": 0 + "result": tool_call["_parse_error"], + "execution_time": 0, } self._record_tool_result(tool_name, arguments, False) return result diff --git a/agent/protocol/cancel.py b/agent/protocol/cancel.py new file mode 100644 index 00000000..6354cd38 --- /dev/null +++ b/agent/protocol/cancel.py @@ -0,0 +1,121 @@ +""" +Cancel token registry for aborting in-flight agent runs. + +A user cancel (web Cancel button, /cancel command) sets a threading.Event +that the agent loop polls at safe checkpoints. Tokens are keyed by +request_id (preferred) and tracked under session_id as a fallback. Entries +are released after the run completes to keep the registry bounded. + +No project deps — importable from any layer without circular imports. +""" + +from __future__ import annotations + +import threading +from typing import Dict, Optional + + +class AgentCancelledError(Exception): + """Raised inside the agent loop when a stop has been requested. + + The agent stream executor catches this, injects a "[Interrupted]" note + into the message history (preserving tool_use/tool_result integrity) + and returns a partial response to the caller. + """ + + +class _CancelEntry: + __slots__ = ("event", "session_id") + + def __init__(self, session_id: Optional[str]): + self.event = threading.Event() + self.session_id = session_id + + +class CancelTokenRegistry: + """In-process registry mapping request_id -> cancel Event. + + Thread-safe. Singleton via module-level ``_registry``. + """ + + def __init__(self): + self._lock = threading.Lock() + self._by_request: Dict[str, _CancelEntry] = {} + # session_id -> set of request_ids currently in flight (usually 1). + self._by_session: Dict[str, set] = {} + + def register(self, request_id: str, session_id: Optional[str] = None) -> threading.Event: + """Create (or return existing) cancel event for a request. + + Returns the threading.Event the caller should poll via ``is_set()``. + """ + if not request_id: + return threading.Event() + with self._lock: + entry = self._by_request.get(request_id) + if entry is None: + entry = _CancelEntry(session_id) + self._by_request[request_id] = entry + if session_id: + self._by_session.setdefault(session_id, set()).add(request_id) + return entry.event + + def get_event(self, request_id: str) -> Optional[threading.Event]: + if not request_id: + return None + with self._lock: + entry = self._by_request.get(request_id) + return entry.event if entry else None + + def cancel_request(self, request_id: str) -> bool: + """Trigger cancel for a specific request. Returns True when matched.""" + if not request_id: + return False + with self._lock: + entry = self._by_request.get(request_id) + if entry is None: + return False + entry.event.set() + return True + + def cancel_session(self, session_id: str) -> int: + """Trigger cancel for every in-flight request of a session. + + Returns the number of requests cancelled (0 when nothing was running). + """ + if not session_id: + return 0 + with self._lock: + request_ids = list(self._by_session.get(session_id, ())) + entries = [self._by_request[r] for r in request_ids if r in self._by_request] + for entry in entries: + entry.event.set() + return len(entries) + + def unregister(self, request_id: str) -> None: + """Remove an entry once the agent run is done. Safe to call twice.""" + if not request_id: + return + with self._lock: + entry = self._by_request.pop(request_id, None) + if entry and entry.session_id: + bucket = self._by_session.get(entry.session_id) + if bucket is not None: + bucket.discard(request_id) + if not bucket: + self._by_session.pop(entry.session_id, None) + + def has_active(self, session_id: str) -> bool: + if not session_id: + return False + with self._lock: + bucket = self._by_session.get(session_id) + return bool(bucket) + + +_registry = CancelTokenRegistry() + + +def get_cancel_registry() -> CancelTokenRegistry: + """Module-level accessor for the singleton registry.""" + return _registry diff --git a/agent/tools/browser/browser_service.py b/agent/tools/browser/browser_service.py index 69ec0e06..f499fb29 100644 --- a/agent/tools/browser/browser_service.py +++ b/agent/tools/browser/browser_service.py @@ -15,7 +15,7 @@ import threading from typing import Optional, Dict, Any, List, Callable from common.log import logger -from common.utils import expand_path +from common.utils import expand_path, is_cloud_deployment _DEFAULT_USER_DATA_DIR = "~/.cow/browser_profile" @@ -436,6 +436,20 @@ class BrowserService: if self._headless: launch_args.append("--no-sandbox") + if is_cloud_deployment(): + launch_args.extend([ + "--disable-gpu", + "--disable-software-rasterizer", + "--disable-extensions", + "--disable-background-networking", + "--disable-background-timer-throttling", + "--disable-renderer-backgrounding", + "--disable-features=site-per-process,TranslateUI,IsolateOrigins", + "--no-zygote", + "--js-flags=--max-old-space-size=384", + "--memory-pressure-off", + ]) + extra_args = self._config.get("launch_args", []) if extra_args: launch_args.extend(extra_args) diff --git a/agent/tools/browser/browser_tool.py b/agent/tools/browser/browser_tool.py index c5139812..c91be26c 100644 --- a/agent/tools/browser/browser_tool.py +++ b/agent/tools/browser/browser_tool.py @@ -145,7 +145,8 @@ class BrowserTool(BaseTool): url = args.get("url", "").strip() if not url: return ToolResult.fail("Error: 'url' is required for navigate action") - if not url.startswith(("http://", "https://")): + # Only auto-prepend https:// for bare hosts; preserve file://, about:, data:, etc. + if "://" not in url and not url.startswith(("about:", "data:")): url = "https://" + url timeout = args.get("timeout", 30000) service = self._get_service() diff --git a/agent/tools/mcp/mcp_client.py b/agent/tools/mcp/mcp_client.py index 694a0c46..be93c716 100644 --- a/agent/tools/mcp/mcp_client.py +++ b/agent/tools/mcp/mcp_client.py @@ -1,8 +1,8 @@ """ MCP (Model Context Protocol) client module. -Implements JSON-RPC 2.0 over stdio and SSE transports without any external -MCP SDK dependency. +Implements JSON-RPC 2.0 over stdio, SSE and Streamable HTTP transports +without any external MCP SDK dependency. """ import json @@ -17,18 +17,29 @@ from typing import Optional from common.log import logger +# Aliases accepted for the Streamable HTTP transport type +_STREAMABLE_HTTP_ALIASES = {"streamable-http", "streamable_http", "streamablehttp", "http"} + + class McpClient: - """Single MCP Server client supporting stdio and SSE transports.""" + """Single MCP Server client supporting stdio, SSE and Streamable HTTP transports.""" def __init__(self, config: dict): """ config examples: - stdio: {"name": "filesystem", "type": "stdio", "command": "npx", "args": [...]} - SSE: {"name": "my-api", "type": "sse", "url": "http://localhost:8000/sse"} + stdio: {"name": "filesystem", "type": "stdio", "command": "npx", "args": [...]} + SSE: {"name": "my-api", "type": "sse", "url": "http://localhost:8000/sse"} + streamable-http: {"name": "pubmed", "type": "streamable-http", "url": "https://x/mcp"} """ self.config = config self.name: str = config.get("name", "unknown") - self.transport: str = config.get("type", "stdio") + raw_transport: str = config.get("type", "stdio") + # Normalize streamable-http aliases to a single internal key + self.transport: str = ( + "streamable-http" + if raw_transport.lower() in _STREAMABLE_HTTP_ALIASES + else raw_transport + ) # stdio state self._proc: Optional[subprocess.Popen] = None @@ -37,6 +48,11 @@ class McpClient: self._sse_url: Optional[str] = None self._post_url: Optional[str] = None # endpoint for sending messages (resolved from SSE) + # Streamable HTTP state + self._http_url: Optional[str] = None + self._http_headers: dict = {} # extra headers from user config (e.g. Authorization) + self._http_session_id: Optional[str] = None # Mcp-Session-Id assigned by the server + # Shared state self._next_id = 1 self._id_lock = threading.Lock() @@ -54,6 +70,8 @@ class McpClient: return self._init_stdio() elif self.transport == "sse": return self._init_sse() + elif self.transport == "streamable-http": + return self._init_streamable_http() else: logger.warning(f"[MCP:{self.name}] Unknown transport type: {self.transport!r}") return False @@ -109,6 +127,21 @@ class McpClient: pass self._proc = None logger.debug(f"[MCP:{self.name}] stdio process terminated") + + # Best-effort streamable-http session termination + if self.transport == "streamable-http" and self._http_session_id and self._http_url: + try: + req = urllib.request.Request( + self._http_url, + method="DELETE", + headers={"Mcp-Session-Id": self._http_session_id, **self._http_headers}, + ) + with urllib.request.urlopen(req, timeout=5): + pass + except Exception: + pass + self._http_session_id = None + self._initialized = False # ------------------------------------------------------------------ @@ -234,6 +267,120 @@ class McpClient: raw = resp.read().decode("utf-8") return json.loads(raw) + # ------------------------------------------------------------------ + # Streamable HTTP transport (MCP spec 2025-03-26) + # ------------------------------------------------------------------ + + def _init_streamable_http(self) -> bool: + url = self.config.get("url") + if not url: + logger.warning(f"[MCP:{self.name}] streamable-http config missing 'url'") + return False + + self._http_url = url + # Allow user-provided headers (e.g. {"Authorization": "Bearer xxx"}) + extra_headers = self.config.get("headers") or {} + if isinstance(extra_headers, dict): + self._http_headers = {str(k): str(v) for k, v in extra_headers.items()} + + return self._handshake() + + def _streamable_http_send(self, message: dict) -> dict: + """POST a JSON-RPC request and return the response (JSON or SSE-wrapped).""" + return self._streamable_http_post(message, expect_response=True) + + def _streamable_http_post(self, message: dict, expect_response: bool) -> dict: + """ + POST a JSON-RPC message over Streamable HTTP. + + Per the spec, the response Content-Type can be either: + - application/json -> single JSON-RPC response in body + - text/event-stream -> SSE stream; we read until we get a matching response + """ + body = json.dumps(message).encode("utf-8") + headers = { + "Content-Type": "application/json", + "Accept": "application/json, text/event-stream", + } + if self._http_session_id: + headers["Mcp-Session-Id"] = self._http_session_id + headers.update(self._http_headers) + + req = urllib.request.Request( + self._http_url, + data=body, + method="POST", + headers=headers, + ) + + try: + resp = urllib.request.urlopen(req, timeout=30) + except urllib.error.HTTPError as e: + # Surface the server-provided error body for easier debugging + detail = "" + try: + detail = e.read().decode("utf-8", errors="ignore") + except Exception: + pass + raise IOError( + f"[MCP:{self.name}] streamable-http HTTP {e.code}: {detail[:200]}" + ) + + with resp: + # Capture session id assigned by the server (if any) + session_id = resp.headers.get("Mcp-Session-Id") + if session_id and not self._http_session_id: + self._http_session_id = session_id + + status = resp.status if hasattr(resp, "status") else resp.getcode() + + # Notifications: server may reply with 202 Accepted and no body + if not expect_response or status == 202: + try: + resp.read() + except Exception: + pass + return {} + + content_type = (resp.headers.get("Content-Type") or "").lower() + expected_id = message.get("id") + + if "text/event-stream" in content_type: + return self._read_sse_response(resp, expected_id) + + raw = resp.read().decode("utf-8") + if not raw: + return {} + return json.loads(raw) + + def _read_sse_response(self, resp, expected_id) -> dict: + """Read an SSE stream and return the first JSON-RPC response with matching id.""" + data_buf: list = [] + for raw_line in resp: + line = raw_line.decode("utf-8").rstrip("\n\r") + if line == "": + # End of an SSE event, attempt to parse accumulated data + if data_buf: + payload = "\n".join(data_buf) + data_buf = [] + try: + msg = json.loads(payload) + except json.JSONDecodeError: + continue + # Skip notifications / mismatched ids + if "id" not in msg: + continue + if expected_id is None or msg.get("id") == expected_id: + return msg + continue + if line.startswith(":"): + continue # SSE comment / keepalive + if line.startswith("data:"): + data_buf.append(line[len("data:"):].lstrip()) + # Ignore 'event:' / 'id:' lines; we only care about JSON-RPC payloads + + raise IOError(f"[MCP:{self.name}] streamable-http SSE stream closed before response") + # ------------------------------------------------------------------ # Common JSON-RPC helpers # ------------------------------------------------------------------ @@ -267,6 +414,8 @@ class McpClient: return self._stdio_send(message) elif self.transport == "sse": return self._sse_send(message) + elif self.transport == "streamable-http": + return self._streamable_http_send(message) else: raise ValueError(f"[MCP:{self.name}] Unsupported transport: {self.transport}") @@ -291,6 +440,11 @@ class McpClient: pass except Exception: pass # notifications are fire-and-forget + elif self.transport == "streamable-http": + try: + self._streamable_http_post(notification, expect_response=False) + except Exception: + pass # notifications are fire-and-forget def _handshake(self) -> bool: """Perform the MCP initialize / notifications/initialized handshake.""" diff --git a/agent/tools/scheduler/integration.py b/agent/tools/scheduler/integration.py index 9e559a43..7421a525 100644 --- a/agent/tools/scheduler/integration.py +++ b/agent/tools/scheduler/integration.py @@ -57,34 +57,44 @@ def init_scheduler(agent_bridge) -> bool: _task_store = TaskStore(store_path) logger.debug(f"[Scheduler] Task store initialized: {store_path}") - # Create execute callback + # Create execute callback. Returns True on success, False to ask + # the scheduler to retry on the next tick (e.g. channel not yet + # ready right after process start). def execute_task_callback(task: dict): - """Callback to execute a scheduled task""" try: action = task.get("action", {}) action_type = action.get("type") + channel_type = action.get("channel_type", "unknown") + receiver = action.get("receiver", "") + + if not _is_channel_ready(channel_type, receiver): + logger.warning( + f"[Scheduler] Task {task.get('id')}: channel " + f"'{channel_type}' not ready for receiver={receiver} " + f"(no inbound msg cached since restart?); deferring" + ) + return False if action_type == "agent_task": - _execute_agent_task(task, agent_bridge) + return _execute_agent_task(task, agent_bridge) elif action_type == "send_message": - # Legacy support for old tasks - _execute_send_message(task, agent_bridge) + return _execute_send_message(task, agent_bridge) elif action_type == "tool_call": - # Legacy support for old tasks - _execute_tool_call(task, agent_bridge) + return _execute_tool_call(task, agent_bridge) elif action_type == "skill_call": - # Legacy support for old tasks - _execute_skill_call(task, agent_bridge) + return _execute_skill_call(task, agent_bridge) else: logger.warning(f"[Scheduler] Unknown action type: {action_type}") + return True except Exception as e: logger.error(f"[Scheduler] Error executing task {task.get('id')}: {e}") + return False # Create scheduler service _scheduler_service = SchedulerService(_task_store, execute_task_callback) _scheduler_service.start() - logger.debug("[Scheduler] Scheduler service initialized and started") + logger.info("[Scheduler] Service initialized and started") return True except Exception as e: @@ -92,6 +102,40 @@ def init_scheduler(agent_bridge) -> bool: return False +def _is_channel_ready(channel_type: str, receiver: str) -> bool: + """Best-effort readiness probe for outbound channels. + + Returns False when we know the send will drop (e.g. weixin not yet + logged in, web session has no polling queue), so the scheduler can + defer instead of consuming the task. Unknown channels return True + to preserve previous behaviour. + """ + if not channel_type or channel_type == "unknown": + return True + try: + from channel.channel_factory import create_channel + channel = create_channel(channel_type) + if channel is None: + return False + + if channel_type == "weixin": + tokens = getattr(channel, "_context_tokens", None) + if not tokens or receiver not in tokens: + return False + return True + + if channel_type == "web": + queues = getattr(channel, "session_queues", None) + if not queues or receiver not in queues: + return False + return True + + return True + except Exception as e: + logger.warning(f"[Scheduler] Channel readiness check failed for {channel_type}: {e}") + return True + + def get_task_store(): """Get the global task store instance""" return _task_store @@ -145,13 +189,10 @@ def _remember_delivered_output( ) -def _execute_agent_task(task: dict, agent_bridge): +def _execute_agent_task(task: dict, agent_bridge) -> bool: """ - Execute an agent_task action - let Agent handle the task - - Args: - task: Task dictionary - agent_bridge: AgentBridge instance + Execute an agent_task action - let Agent handle the task. + Returns True on successful delivery, False to retry next tick. """ try: action = task.get("action", {}) @@ -162,11 +203,11 @@ def _execute_agent_task(task: dict, agent_bridge): if not task_description: logger.error(f"[Scheduler] Task {task['id']}: No task_description specified") - return + return True # malformed task, don't loop forever if not receiver: logger.error(f"[Scheduler] Task {task['id']}: No receiver specified") - return + return True # Check for unsupported channels if channel_type == "dingtalk": @@ -209,51 +250,47 @@ def _execute_agent_task(task: dict, agent_bridge): try: # Don't clear history - scheduler tasks use isolated session_id so they won't pollute user conversations reply = agent_bridge.agent_reply(task_description, context=context, on_event=None, clear_history=False) - - if reply and reply.content: - # Send the reply via channel - from channel.channel_factory import create_channel - - try: - channel = create_channel(channel_type) - if channel: - # For web channel, register request_id - if channel_type == "web" and hasattr(channel, 'request_to_session'): - request_id = context.get("request_id") - if request_id: - channel.request_to_session[request_id] = receiver - logger.debug(f"[Scheduler] Registered request_id {request_id} -> session {receiver}") - - # Send the reply - channel.send(reply, context) - _remember_delivered_output(agent_bridge, task, channel_type, reply.content) - logger.info(f"[Scheduler] Task {task['id']} executed successfully, result sent to {receiver}") - else: - logger.error(f"[Scheduler] Failed to create channel: {channel_type}") - except Exception as e: - logger.error(f"[Scheduler] Failed to send result: {e}") - else: + + if not (reply and reply.content): logger.error(f"[Scheduler] Task {task['id']}: No result from agent execution") - + return True # agent ran but produced nothing; don't loop + + from channel.channel_factory import create_channel + channel = create_channel(channel_type) + if not channel: + logger.error(f"[Scheduler] Failed to create channel: {channel_type}") + return False + + if channel_type == "web" and hasattr(channel, 'request_to_session'): + request_id = context.get("request_id") + if request_id: + channel.request_to_session[request_id] = receiver + + try: + channel.send(reply, context) + except Exception as e: + logger.error(f"[Scheduler] Failed to send result: {e}") + return False + + _remember_delivered_output(agent_bridge, task, channel_type, reply.content) + logger.info(f"[Scheduler] Task {task['id']} executed successfully, result sent to {receiver}") + return True + except Exception as e: logger.error(f"[Scheduler] Failed to execute task via Agent: {e}") import traceback logger.error(f"[Scheduler] Traceback: {traceback.format_exc()}") - + return False + except Exception as e: logger.error(f"[Scheduler] Error in _execute_agent_task: {e}") import traceback logger.error(f"[Scheduler] Traceback: {traceback.format_exc()}") + return False -def _execute_send_message(task: dict, agent_bridge): - """ - Execute a send_message action - - Args: - task: Task dictionary - agent_bridge: AgentBridge instance - """ +def _execute_send_message(task: dict, agent_bridge) -> bool: + """Execute a send_message action. Returns True/False for delivery.""" try: action = task.get("action", {}) content = action.get("content", "") @@ -263,7 +300,7 @@ def _execute_send_message(task: dict, agent_bridge): if not receiver: logger.error(f"[Scheduler] Task {task['id']}: No receiver specified") - return + return True # Create context for sending message context = Context(ContextType.TEXT, content) @@ -308,169 +345,135 @@ def _execute_send_message(task: dict, agent_bridge): # Get channel and send from channel.channel_factory import create_channel + channel = create_channel(channel_type) + if not channel: + logger.error(f"[Scheduler] Failed to create channel: {channel_type}") + return False + + if channel_type == "web" and hasattr(channel, 'request_to_session'): + channel.request_to_session[request_id] = receiver + try: - channel = create_channel(channel_type) - if channel: - # For web channel, register the request_id to session mapping - if channel_type == "web" and hasattr(channel, 'request_to_session'): - channel.request_to_session[request_id] = receiver - logger.debug(f"[Scheduler] Registered request_id {request_id} -> session {receiver}") - - channel.send(reply, context) - _remember_delivered_output(agent_bridge, task, channel_type, content) - logger.info(f"[Scheduler] Task {task['id']} executed: sent message to {receiver}") - else: - logger.error(f"[Scheduler] Failed to create channel: {channel_type}") + channel.send(reply, context) except Exception as e: logger.error(f"[Scheduler] Failed to send message: {e}") - import traceback - logger.error(f"[Scheduler] Traceback: {traceback.format_exc()}") - + return False + + _remember_delivered_output(agent_bridge, task, channel_type, content) + logger.info(f"[Scheduler] Task {task['id']} executed: sent message to {receiver}") + return True + except Exception as e: logger.error(f"[Scheduler] Error in _execute_send_message: {e}") import traceback logger.error(f"[Scheduler] Traceback: {traceback.format_exc()}") + return False -def _execute_tool_call(task: dict, agent_bridge): - """ - Execute a tool_call action - - Args: - task: Task dictionary - agent_bridge: AgentBridge instance - """ +def _execute_tool_call(task: dict, agent_bridge) -> bool: + """Execute a tool_call action. Returns True/False for delivery.""" try: action = task.get("action", {}) - # Support both old and new field names tool_name = action.get("call_name") or action.get("tool_name") tool_params = action.get("call_params") or action.get("tool_params", {}) result_prefix = action.get("result_prefix", "") receiver = action.get("receiver") is_group = action.get("is_group", False) channel_type = action.get("channel_type", "unknown") - + if not tool_name: logger.error(f"[Scheduler] Task {task['id']}: No tool_name specified") - return - + return True if not receiver: logger.error(f"[Scheduler] Task {task['id']}: No receiver specified") - return - - # Get tool manager and create tool instance + return True + from agent.tools.tool_manager import ToolManager - tool_manager = ToolManager() - tool = tool_manager.create_tool(tool_name) - + tool = ToolManager().create_tool(tool_name) if not tool: logger.error(f"[Scheduler] Task {task['id']}: Tool '{tool_name}' not found") - return - - # Execute tool + return True + logger.info(f"[Scheduler] Task {task['id']}: Executing tool '{tool_name}' with params {tool_params}") result = tool.execute(tool_params) - - # Get result content - if hasattr(result, 'result'): - content = result.result - else: - content = str(result) - - # Add prefix if specified + content = result.result if hasattr(result, 'result') else str(result) if result_prefix: content = f"{result_prefix}\n\n{content}" - - # Send result as message + context = Context(ContextType.TEXT, content) context["receiver"] = receiver context["isgroup"] = is_group context["session_id"] = receiver - - # Channel-specific context setup + + request_id = None if channel_type == "web": - # Web channel needs request_id import uuid request_id = f"scheduler_{task['id']}_{uuid.uuid4().hex[:8]}" context["request_id"] = request_id - logger.debug(f"[Scheduler] Generated request_id for web channel: {request_id}") elif channel_type == "feishu": context["receive_id_type"] = "chat_id" if is_group else "open_id" context["msg"] = None - logger.debug(f"[Scheduler] Feishu: receive_id_type={context['receive_id_type']}, is_group={is_group}, receiver={receiver}") elif channel_type == "wecom_bot": context["msg"] = None reply = Reply(ReplyType.TEXT, content) - # Get channel and send from channel.channel_factory import create_channel + channel = create_channel(channel_type) + if not channel: + logger.error(f"[Scheduler] Failed to create channel: {channel_type}") + return False + + if channel_type == "web" and request_id and hasattr(channel, 'request_to_session'): + channel.request_to_session[request_id] = receiver try: - channel = create_channel(channel_type) - if channel: - if channel_type == "web" and hasattr(channel, 'request_to_session'): - channel.request_to_session[request_id] = receiver - logger.debug(f"[Scheduler] Registered request_id {request_id} -> session {receiver}") - - channel.send(reply, context) - _remember_delivered_output(agent_bridge, task, channel_type, content) - logger.info(f"[Scheduler] Task {task['id']} executed: sent tool result to {receiver}") - else: - logger.error(f"[Scheduler] Failed to create channel: {channel_type}") + channel.send(reply, context) except Exception as e: logger.error(f"[Scheduler] Failed to send tool result: {e}") + return False + + _remember_delivered_output(agent_bridge, task, channel_type, content) + logger.info(f"[Scheduler] Task {task['id']} executed: sent tool result to {receiver}") + return True except Exception as e: logger.error(f"[Scheduler] Error in _execute_tool_call: {e}") + return False -def _execute_skill_call(task: dict, agent_bridge): - """ - Execute a skill_call action by asking Agent to run the skill - - Args: - task: Task dictionary - agent_bridge: AgentBridge instance - """ +def _execute_skill_call(task: dict, agent_bridge) -> bool: + """Execute a skill_call action by asking Agent to run the skill. + Returns True/False for delivery.""" try: action = task.get("action", {}) - # Support both old and new field names skill_name = action.get("call_name") or action.get("skill_name") skill_params = action.get("call_params") or action.get("skill_params", {}) result_prefix = action.get("result_prefix", "") receiver = action.get("receiver") is_group = action.get("isgroup", False) channel_type = action.get("channel_type", "unknown") - + if not skill_name: logger.error(f"[Scheduler] Task {task['id']}: No skill_name specified") - return - + return True if not receiver: logger.error(f"[Scheduler] Task {task['id']}: No receiver specified") - return - + return True + logger.info(f"[Scheduler] Task {task['id']}: Executing skill '{skill_name}' with params {skill_params}") - - # Create a unique session_id for this scheduled task to avoid polluting user's conversation - # Format: scheduler__ to ensure isolation + scheduler_session_id = f"scheduler_{receiver}_{task['id']}" - - # Build a natural language query for the Agent to execute the skill - # Format: "Use skill-name to do something with params" param_str = ", ".join([f"{k}={v}" for k, v in skill_params.items()]) query = f"Use {skill_name} skill" if param_str: query += f" with {param_str}" - - # Create context for Agent + context = Context(ContextType.TEXT, query) context["receiver"] = receiver context["isgroup"] = is_group context["session_id"] = scheduler_session_id - - # Channel-specific setup + if channel_type == "web": import uuid request_id = f"scheduler_{task['id']}_{uuid.uuid4().hex[:8]}" @@ -481,49 +484,48 @@ def _execute_skill_call(task: dict, agent_bridge): elif channel_type == "wecom_bot": context["msg"] = None - # Use Agent to execute the skill try: - # Don't clear history - scheduler tasks use isolated session_id so they won't pollute user conversations reply = agent_bridge.agent_reply(query, context=context, on_event=None, clear_history=False) - - if reply and reply.content: - content = reply.content - - # Add prefix if specified - if result_prefix: - content = f"{result_prefix}\n\n{content}" - - # Send the result via channel - from channel.channel_factory import create_channel - - try: - channel = create_channel(channel_type) - if channel: - # For web channel, register request_id - if channel_type == "web" and hasattr(channel, 'request_to_session'): - req_id = context.get("request_id") - if req_id: - channel.request_to_session[req_id] = receiver - logger.debug(f"[Scheduler] Registered request_id {req_id} -> session {receiver}") - - channel.send(Reply(ReplyType.TEXT, content), context) - _remember_delivered_output(agent_bridge, task, channel_type, content) - except Exception as e: - logger.error(f"[Scheduler] Failed to send skill result: {e}") - - logger.info(f"[Scheduler] Task {task['id']} executed: skill result sent to {receiver}") - else: - logger.error(f"[Scheduler] Task {task['id']}: No result from skill execution") - except Exception as e: logger.error(f"[Scheduler] Failed to execute skill via Agent: {e}") import traceback logger.error(f"[Scheduler] Traceback: {traceback.format_exc()}") - + return False + + if not (reply and reply.content): + logger.error(f"[Scheduler] Task {task['id']}: No result from skill execution") + return True + + content = reply.content + if result_prefix: + content = f"{result_prefix}\n\n{content}" + + from channel.channel_factory import create_channel + channel = create_channel(channel_type) + if not channel: + logger.error(f"[Scheduler] Failed to create channel: {channel_type}") + return False + + if channel_type == "web" and hasattr(channel, 'request_to_session'): + req_id = context.get("request_id") + if req_id: + channel.request_to_session[req_id] = receiver + + try: + channel.send(Reply(ReplyType.TEXT, content), context) + except Exception as e: + logger.error(f"[Scheduler] Failed to send skill result: {e}") + return False + + _remember_delivered_output(agent_bridge, task, channel_type, content) + logger.info(f"[Scheduler] Task {task['id']} executed: skill result sent to {receiver}") + return True + except Exception as e: logger.error(f"[Scheduler] Error in _execute_skill_call: {e}") import traceback logger.error(f"[Scheduler] Traceback: {traceback.format_exc()}") + return False def attach_scheduler_to_tool(tool, context: Context = None): diff --git a/agent/tools/scheduler/scheduler_service.py b/agent/tools/scheduler/scheduler_service.py index dd5369cb..1f4bc6fb 100644 --- a/agent/tools/scheduler/scheduler_service.py +++ b/agent/tools/scheduler/scheduler_service.py @@ -52,7 +52,6 @@ class SchedulerService: self.running = True self.thread = threading.Thread(target=self._run_loop, daemon=True) self.thread.start() - logger.debug("[Scheduler] Service started") def stop(self): """Stop the scheduler service""" @@ -67,7 +66,7 @@ class SchedulerService: def _run_loop(self): """Main scheduler loop""" - logger.debug("[Scheduler] Scheduler loop started") + logger.info("[Scheduler] Scheduler loop started") while self.running: try: @@ -84,12 +83,18 @@ class SchedulerService: for task in tasks: try: - # Check if task is due if self._is_task_due(task, now): logger.info(f"[Scheduler] Executing task: {task['id']} - {task['name']}") - self._execute_task(task) - - # Update next run time + ok = self._execute_task(task) + if not ok: + # Leave next_run_at as-is so the next loop retries. + # Cron tasks within the catch-up window will keep + # firing; beyond it _is_task_due will reschedule. + logger.warning( + f"[Scheduler] Task {task['id']} delivery failed, will retry next tick" + ) + continue + next_run = self._calculate_next_run(task, now) if next_run: self.task_store.update_task(task['id'], { @@ -97,7 +102,6 @@ class SchedulerService: "last_run_at": now.isoformat() }) else: - # One-time task completed, remove it self.task_store.delete_task(task['id']) logger.info(f"[Scheduler] One-time task completed and removed: {task['id']}") except Exception as e: @@ -128,30 +132,35 @@ class SchedulerService: try: next_run = _parse_naive_local(next_run_str) - # Check if task is overdue (e.g., service restart) if next_run < now: time_diff = (now - next_run).total_seconds() - - # If overdue by more than 5 minutes, skip this run and schedule next - if time_diff > 300: # 5 minutes - logger.warning(f"[Scheduler] Task {task['id']} is overdue by {int(time_diff)}s, skipping and scheduling next run") - - # For one-time tasks, remove them directly - schedule = task.get("schedule", {}) - if schedule.get("type") == "once": - self.task_store.delete_task(task['id']) - logger.info(f"[Scheduler] One-time task {task['id']} expired, removed") - return False - - # For recurring tasks, calculate next run from now - next_next_run = self._calculate_next_run(task, now) - if next_next_run: - self.task_store.update_task(task['id'], { - "next_run_at": next_next_run.isoformat() - }) - logger.info(f"[Scheduler] Rescheduled task {task['id']} to {next_next_run}") + schedule = task.get("schedule", {}) + schedule_type = schedule.get("type") + + # Catch-up window: fire if we're within 10 minutes of the + # scheduled tick. Beyond that we'd rather skip than push a + # stale daily report to the user. + if time_diff <= 600: + return True + + logger.warning( + f"[Scheduler] Task {task['id']} is overdue by {int(time_diff)}s, " + f"skipping and scheduling next run" + ) + + if schedule_type == "once": + self.task_store.delete_task(task['id']) + logger.info(f"[Scheduler] One-time task {task['id']} expired, removed") return False - + + next_next_run = self._calculate_next_run(task, now) + if next_next_run: + self.task_store.update_task(task['id'], { + "next_run_at": next_next_run.isoformat() + }) + logger.info(f"[Scheduler] Rescheduled task {task['id']} to {next_next_run}") + return False + return now >= next_run except Exception as e: logger.error( @@ -213,20 +222,22 @@ class SchedulerService: return None - def _execute_task(self, task: dict): + def _execute_task(self, task: dict) -> bool: """ - Execute a task - - Args: - task: Task dictionary + Execute a task. + + Returns True if delivery succeeded (caller should advance state), + False if it failed (caller should keep next_run_at so the next + loop iteration retries). Callback may return None for legacy + behaviour, treated as success. """ try: - # Call the execute callback - self.execute_callback(task) + result = self.execute_callback(task) + return False if result is False else True except Exception as e: logger.error(f"[Scheduler] Error executing task {task['id']}: {e}") - # Update task with error self.task_store.update_task(task['id'], { "last_error": str(e), "last_error_at": datetime.now().isoformat() }) + return False diff --git a/agent/tools/vision/vision.py b/agent/tools/vision/vision.py index a1c3265f..498f3cd8 100644 --- a/agent/tools/vision/vision.py +++ b/agent/tools/vision/vision.py @@ -3,7 +3,7 @@ Vision tool - Analyze images using Vision API. Supports local files (auto base64-encoded) and HTTP URLs. Provider resolution: - - tool.vision.model (if set) means "prefer this model first; fall back to + - tools.vision.model (if set) means "prefer this model first; fall back to other configured providers if it fails". The model name is mapped to its native provider (e.g. doubao-* → Doubao, kimi-* → Moonshot, gpt-* → OpenAI/LinkAI). That provider is tried first, then the standard auto @@ -53,14 +53,15 @@ _DISCOVERABLE_MODELS = [ ("ark_api_key", const.DOUBAO, const.DOUBAO_SEED_2_PRO, "Doubao"), ("dashscope_api_key", const.QWEN_DASHSCOPE, const.QWEN36_PLUS, "DashScope"), ("claude_api_key", const.CLAUDEAPI, const.CLAUDE_4_6_SONNET, "Claude"), - ("gemini_api_key", const.GEMINI, const.GEMINI_31_FLASH_LITE_PRE, "Gemini"), + ("gemini_api_key", const.GEMINI, const.GEMINI_35_FLASH, "Gemini"), ("qianfan_api_key", const.QIANFAN, const.ERNIE_45_TURBO_VL, "Qianfan"), ("zhipu_ai_api_key", const.ZHIPU_AI, const.GLM_4_7, "ZhipuAI"), ("minimax_api_key", const.MiniMax, const.MINIMAX_M2_7, "MiniMax"), + ("mimo_api_key", const.MIMO, const.MIMO_V2_5_PRO, "MiMo"), ] # Model name prefix → discoverable provider display_name. -# Used to auto-route tool.vision.model to its native provider. +# Used to auto-route tools.vision.model to its native provider. # Matched case-insensitively; longest prefix wins. _MODEL_PREFIX_TO_PROVIDER = [ ("doubao-", "Doubao"), @@ -73,11 +74,29 @@ _MODEL_PREFIX_TO_PROVIDER = [ ("glm-", "ZhipuAI"), ("minimax-", "MiniMax"), ("abab", "MiniMax"), + ("mimo-", "MiMo"), ] # Model prefixes that natively belong to OpenAI / LinkAI (raw HTTP providers). _OPENAI_MODEL_PREFIXES = ("gpt-", "o1-", "o3-", "o4-", "chatgpt-") +# Maps the UI provider id (persisted in tools.vision.provider) to the internal +# display name used in VisionProvider.name. Keep in sync with _DISCOVERABLE_MODELS +# and the openai/linkai branches in _route_by_model_name. +_PROVIDER_ID_TO_DISPLAY = { + "openai": "OpenAI", + "linkai": "LinkAI", + "moonshot": "Moonshot", + "doubao": "Doubao", + "dashscope": "DashScope", + "claudeAPI": "Claude", + "gemini": "Gemini", + "qianfan": "Qianfan", + "zhipu": "ZhipuAI", + "minimax": "MiniMax", + "mimo": "MiMo", +} + @dataclass class VisionProvider: @@ -154,7 +173,7 @@ class Vision(BaseTool): # Default model is only used as a last-resort placeholder for providers # whose VisionProvider.model_override is None (e.g. raw OpenAI provider - # when the user did not configure tool.vision.model). + # when the user did not configure tools.vision.model). return self._call_with_fallback(providers, DEFAULT_MODEL, question, image_content) def _call_with_fallback(self, providers: List[VisionProvider], model: str, @@ -193,12 +212,12 @@ class Vision(BaseTool): """ Build an ordered list of providers to try. - Semantics of `tool.vision.model`: + Semantics of `tools.vision.model`: "Prefer this model first; fall back to other configured providers if it fails." Order: - 1. The provider that natively serves `tool.vision.model` (if any + 1. The provider that natively serves `tools.vision.model` (if any and its API key is configured) — using the user-specified model name verbatim. 2. Auto-discovery chain as fallback: @@ -211,13 +230,19 @@ class Vision(BaseTool): are de-duplicated to avoid retrying the same endpoint twice. """ user_model = self._resolve_user_vision_model() + user_provider = self._resolve_user_vision_provider() providers: List[VisionProvider] = [] - # Step 1: preferred provider derived from tool.vision.model - if user_model: + # Step 1: preferred provider — explicit `tools.vision.provider` + # wins so custom model names can still be routed correctly. Falls + # through to model-name prefix inference when provider is unset. + preferred = None + if user_provider and user_model: + preferred = self._route_by_provider_id(user_provider, user_model) + if not preferred and user_model: preferred = self._route_by_model_name(user_model) - if preferred: - providers.extend(preferred) + if preferred: + providers.extend(preferred) # Step 2: auto-discovery chain as fallback existing = {p.name for p in providers} @@ -251,11 +276,11 @@ class Vision(BaseTool): @staticmethod def _resolve_user_vision_model() -> Optional[str]: - """Read tool.vision.model from config; return None if unset/blank.""" - tool_conf = conf().get("tool", {}) - if not isinstance(tool_conf, dict): + """Read tools.vision.model (singular ``tool`` kept as runtime fallback).""" + tools_conf = conf().get("tools") or conf().get("tool") or {} + if not isinstance(tools_conf, dict): return None - vision_conf = tool_conf.get("vision", {}) + vision_conf = tools_conf.get("vision", {}) if not isinstance(vision_conf, dict): return None m = vision_conf.get("model") @@ -263,6 +288,24 @@ class Vision(BaseTool): return m.strip() return None + @staticmethod + def _resolve_user_vision_provider() -> Optional[str]: + """Read tools.vision.provider — the UI-persisted vendor id. + + Lets users pin a vendor for custom model names that prefix-inference + can't recognize. Returns None when unset/blank. + """ + tools_conf = conf().get("tools") or conf().get("tool") or {} + if not isinstance(tools_conf, dict): + return None + vision_conf = tools_conf.get("vision", {}) + if not isinstance(vision_conf, dict): + return None + p = vision_conf.get("provider") + if isinstance(p, str) and p.strip(): + return p.strip() + return None + @staticmethod def _infer_provider_from_model(model_name: str) -> Optional[str]: """ @@ -279,6 +322,54 @@ class Vision(BaseTool): return display_name return None + def _route_by_provider_id(self, provider_id: str, user_model: str) -> Optional[List[VisionProvider]]: + """Route by the UI-persisted provider id. + + Returns: + - [provider] : provider id is known and its key is configured. + - None : unknown provider id, or the bot can't be created. + Caller falls through to model-name-based routing. + """ + display_name = _PROVIDER_ID_TO_DISPLAY.get(provider_id) + if not display_name: + return None + + # OpenAI / LinkAI use raw HTTP providers, not the discoverable bot path. + if provider_id == "openai": + p = self._build_openai_provider(user_model) + return [p] if p else None + if provider_id == "linkai": + p = self._build_linkai_provider(user_model) + return [p] if p else None + + # Discoverable bot-backed providers. + for config_key, bot_type, _default_model, name in _DISCOVERABLE_MODELS: + if name != display_name: + continue + api_key = conf().get(config_key, "") + if not api_key or not api_key.strip(): + logger.warning(f"[Vision] tools.vision.provider='{provider_id}' " + f"but '{config_key}' is not configured. Falling back.") + return None + try: + from models.bot_factory import create_bot + bot = create_bot(bot_type) + if not hasattr(bot, 'call_vision'): + logger.warning(f"[Vision] '{display_name}' bot does not implement call_vision.") + return None + except Exception as e: + logger.warning(f"[Vision] Failed to create '{display_name}' bot: {e}") + return None + return [VisionProvider( + name=display_name, + api_key="", + api_base="", + model_override=user_model, + use_bot=True, + fallback_bot=bot, + )] + return None + def _route_by_model_name(self, user_model: str) -> Optional[List[VisionProvider]]: """ Try to build a provider list using the user-specified model name. @@ -303,7 +394,7 @@ class Vision(BaseTool): self._append_provider(providers, lambda: self._build_linkai_provider(user_model)) if providers: return providers - logger.warning(f"[Vision] tool.vision.model='{user_model}' looks like an OpenAI " + logger.warning(f"[Vision] tools.vision.model='{user_model}' looks like an OpenAI " f"model but neither OPENAI_API_KEY nor LINKAI_API_KEY is configured.") return None # fall through to auto @@ -317,7 +408,7 @@ class Vision(BaseTool): continue api_key = conf().get(config_key, "") if not api_key or not api_key.strip(): - logger.warning(f"[Vision] tool.vision.model='{user_model}' routes to " + logger.warning(f"[Vision] tools.vision.model='{user_model}' routes to " f"'{display_name}' but '{config_key}' is not configured. " f"Falling back to auto-discovery.") return None # fall through to auto @@ -452,8 +543,8 @@ class Vision(BaseTool): if not self._main_bot_supports_vision(bot): return None - # Use the configured main model name; do NOT inject tool.vision.model - # here, because by the time we reach this branch the tool.vision.model + # Use the configured main model name; do NOT inject tools.vision.model + # here, because by the time we reach this branch the tools.vision.model # routing has already been attempted (and either matched the main bot # or failed to find a provider). main_model_name = conf().get("model") or None diff --git a/agent/tools/web_search/web_search.py b/agent/tools/web_search/web_search.py index 4c6d1e45..ca56567d 100644 --- a/agent/tools/web_search/web_search.py +++ b/agent/tools/web_search/web_search.py @@ -1,13 +1,27 @@ -""" -Web Search tool - Search the web using Bocha or LinkAI search API. -Supports two backends with unified response format: - 1. Bocha Search (primary, requires BOCHA_API_KEY) - 2. LinkAI Search (fallback, requires LINKAI_API_KEY) +"""Web Search tool. Supports four backends with a unified response format: + - bocha (https://open.bochaai.com) + - zhipu (https://docs.bigmodel.cn/cn/guide/tools/web-search) + - qianfan (https://cloud.baidu.com/doc/qianfan/s/2mh4su4uy) + - linkai (https://link-ai.tech, fallback) + +Provider selection + - strategy 'auto' (default): pick the first configured provider in the + canonical order [bocha, zhipu, qianfan, linkai]. When the caller passes + an explicit `provider` it overrides the pick; an invalid/unconfigured + one silently falls back to the auto order. + - strategy 'fixed': use the configured provider; if its credential is + missing at call time, silently fall back to auto order (no card hint). + +Credentials + - bocha : tools.web_search.bocha_api_key -> env BOCHA_API_KEY + - zhipu : conf.zhipu_ai_api_key -> env ZHIPUAI_API_KEY + - qianfan : conf.qianfan_api_key -> env QIANFAN_API_KEY + - linkai : conf.linkai_api_key -> env LINKAI_API_KEY """ -import os import json -from typing import Dict, Any, Optional +import os +from typing import Any, Dict, List, Optional import requests @@ -16,12 +30,63 @@ from common.log import logger from config import conf -# Default timeout for API requests (seconds) DEFAULT_TIMEOUT = 30 +# Canonical fallback order. Empirically ordered by Chinese real-time +# quality + relevance: bocha (best overall), qianfan (best for hot news), +# zhipu (strong on long-form articles), linkai (cloud aggregator, last +# resort). +PROVIDER_ORDER = ("bocha", "qianfan", "zhipu", "linkai") + +PROVIDER_LABELS = { + "bocha": "Bocha", + "zhipu": "Zhipu", + "qianfan": "Baidu Qianfan", + "linkai": "LinkAI", +} + + +def _tools_web_search_conf() -> dict: + """Return the tools.web_search config block (dict-like).""" + tools_cfg = conf().get("tools") or {} + if not isinstance(tools_cfg, dict): + return {} + block = tools_cfg.get("web_search") or {} + return block if isinstance(block, dict) else {} + + +def _get_api_key(provider: str) -> str: + """Resolve API key for a provider, with conf -> env fallback.""" + if provider == "bocha": + key = (_tools_web_search_conf().get("bocha_api_key") or "").strip() + return key or os.environ.get("BOCHA_API_KEY", "").strip() + if provider == "zhipu": + key = (conf().get("zhipu_ai_api_key") or "").strip() + return key or os.environ.get("ZHIPUAI_API_KEY", "").strip() + if provider == "qianfan": + key = (conf().get("qianfan_api_key") or "").strip() + return key or os.environ.get("QIANFAN_API_KEY", "").strip() + if provider == "linkai": + key = (conf().get("linkai_api_key") or "").strip() + return key or os.environ.get("LINKAI_API_KEY", "").strip() + return "" + + +def configured_providers() -> List[str]: + """Return configured providers in canonical order.""" + return [p for p in PROVIDER_ORDER if _get_api_key(p)] + + +def _configured_strategy() -> str: + return (_tools_web_search_conf().get("strategy") or "auto").strip().lower() + + +def _configured_provider() -> str: + return (_tools_web_search_conf().get("provider") or "").strip().lower() + class WebSearch(BaseTool): - """Tool for searching the web using Bocha or LinkAI search API""" + """Tool for searching the web across multiple providers.""" name: str = "web_search" description: str = "Search the web for real-time information. Returns titles, URLs, and snippets." @@ -55,264 +120,368 @@ class WebSearch(BaseTool): def __init__(self, config: dict = None): self.config = config or {} - self._backend = None # Will be resolved on first execute @staticmethod def is_available() -> bool: - """Check if web search is available (at least one API key is configured)""" - return bool(os.environ.get("BOCHA_API_KEY") or os.environ.get("LINKAI_API_KEY")) + """Tool is offered to the agent when at least one provider has a key.""" + return bool(configured_providers()) - def _resolve_backend(self) -> Optional[str]: - """ - Determine which search backend to use. - Priority: Bocha > LinkAI + @classmethod + def get_json_schema(cls) -> dict: + """Augment the static schema with a `provider` field — only when the + user has ≥2 providers configured AND strategy is 'auto'. Otherwise + the backend picks silently and exposing the field would only waste + the agent's tokens.""" + schema = { + "name": cls.name, + "description": cls.description, + "parameters": json.loads(json.dumps(cls.params)), # deep copy + } + if _configured_strategy() != "auto": + return schema + available = configured_providers() + if len(available) < 2: + return schema - :return: 'bocha', 'linkai', or None + schema["parameters"]["properties"]["provider"] = { + "type": "string", + "enum": available, + "description": "Optional. Specifies the search backend. You may switch between providers when the user wants results from a particular source or from multiple sources.", + } + return schema + + # ------------------------------------------------------------------ + # Provider resolution + # ------------------------------------------------------------------ + + def _resolve_provider(self, requested: Optional[str]) -> Optional[str]: + """Pick a provider for this call. + + Priority: caller-supplied (if configured) > fixed strategy (if + configured) > first configured in PROVIDER_ORDER. Silent fallback + when the desired one has no key. """ - if os.environ.get("BOCHA_API_KEY"): - return "bocha" - if os.environ.get("LINKAI_API_KEY"): - return "linkai" - return None + available = configured_providers() + if not available: + return None + + if requested: + req = requested.strip().lower() + if req in available: + return req + logger.warning(f"[WebSearch] requested provider '{requested}' unavailable, falling back") + + if _configured_strategy() == "fixed": + pinned = _configured_provider() + if pinned in available: + return pinned + if pinned: + logger.warning(f"[WebSearch] pinned provider '{pinned}' unavailable, falling back to auto") + + return available[0] + + @staticmethod + def _resolution_reason(requested: Optional[str], chosen: str) -> str: + """Human-readable explanation for why `chosen` won the resolver.""" + if requested and requested.strip().lower() == chosen: + return "caller-requested" + strategy = _configured_strategy() + if strategy == "fixed" and _configured_provider() == chosen: + return "fixed-strategy" + return "auto-fallback" + + # ------------------------------------------------------------------ + # Entry point + # ------------------------------------------------------------------ def execute(self, args: Dict[str, Any]) -> ToolResult: - """ - Execute web search - - :param args: Search parameters (query, count, freshness, summary) - :return: Search results - """ - query = args.get("query", "").strip() + query = (args.get("query") or "").strip() if not query: return ToolResult.fail("Error: 'query' parameter is required") count = args.get("count", 10) freshness = args.get("freshness", "noLimit") summary = args.get("summary", False) - - # Validate count if not isinstance(count, int) or count < 1 or count > 50: count = 10 - # Resolve backend - backend = self._resolve_backend() - if not backend: + requested = args.get("provider") + provider = self._resolve_provider(requested) + if not provider: return ToolResult.fail( - "Error: No search API key configured. " - "Please set BOCHA_API_KEY or LINKAI_API_KEY using env_config tool.\n" - " - Bocha Search: https://open.bocha.cn\n" - " - LinkAI Search: https://link-ai.tech" + "Error: No search provider configured. " + "Configure one of BOCHA_API_KEY / zhipu_ai_api_key / qianfan_api_key / linkai_api_key." ) + # Always log the routing decision so multi-provider deployments can + # tell at a glance which backend served any given query. + available = configured_providers() + reason = self._resolution_reason(requested, provider) + q_preview = query if len(query) <= 60 else (query[:57] + "...") + logger.info( + f"[WebSearch] provider={provider} reason={reason} " + f"available={list(available)} query={q_preview!r} count={count} freshness={freshness}" + ) + try: - if backend == "bocha": + if provider == "bocha": return self._search_bocha(query, count, freshness, summary) - else: + if provider == "zhipu": + return self._search_zhipu(query, count, freshness) + if provider == "qianfan": + return self._search_qianfan(query, count, freshness) + if provider == "linkai": return self._search_linkai(query, count, freshness) + return ToolResult.fail(f"Error: Unknown provider '{provider}'") except requests.Timeout: return ToolResult.fail(f"Error: Search request timed out after {DEFAULT_TIMEOUT}s") except requests.ConnectionError: return ToolResult.fail("Error: Failed to connect to search API") except Exception as e: - logger.error(f"[WebSearch] Unexpected error: {e}", exc_info=True) + logger.error(f"[WebSearch] Unexpected error ({provider}): {e}", exc_info=True) return ToolResult.fail(f"Error: Search failed - {str(e)}") + # ------------------------------------------------------------------ + # Bocha + # ------------------------------------------------------------------ + def _search_bocha(self, query: str, count: int, freshness: str, summary: bool) -> ToolResult: - """ - Search using Bocha API - - :param query: Search query - :param count: Number of results - :param freshness: Time range filter - :param summary: Whether to include summary - :return: Formatted search results - """ - api_key = os.environ.get("BOCHA_API_KEY", "") - url = "https://api.bocha.cn/v1/web-search" - + api_key = _get_api_key("bocha") + url = "https://api.bochaai.com/v1/web-search" headers = { "Authorization": f"Bearer {api_key}", "Content-Type": "application/json", - "Accept": "application/json" + "Accept": "application/json", } + payload = {"query": query, "count": count, "freshness": freshness, "summary": summary} - payload = { - "query": query, - "count": count, - "freshness": freshness, - "summary": summary - } + logger.debug(f"[WebSearch] bocha: query='{query}', count={count}") + resp = requests.post(url, headers=headers, json=payload, timeout=DEFAULT_TIMEOUT) - logger.debug(f"[WebSearch] Bocha search: query='{query}', count={count}") + if resp.status_code == 401: + return ToolResult.fail("Error: Invalid bocha API key.") + if resp.status_code == 403: + return ToolResult.fail("Error: bocha API — insufficient balance. Top up at https://open.bochaai.com") + if resp.status_code == 429: + return ToolResult.fail("Error: bocha API rate limit reached.") + if resp.status_code != 200: + return ToolResult.fail(f"Error: bocha API returned HTTP {resp.status_code}") - response = requests.post(url, headers=headers, json=payload, timeout=DEFAULT_TIMEOUT) - - if response.status_code == 401: - return ToolResult.fail("Error: Invalid BOCHA_API_KEY. Please check your API key.") - if response.status_code == 403: - return ToolResult.fail("Error: Bocha API - insufficient balance. Please top up at https://open.bocha.cn") - if response.status_code == 429: - return ToolResult.fail("Error: Bocha API rate limit reached. Please try again later.") - if response.status_code != 200: - return ToolResult.fail(f"Error: Bocha API returned HTTP {response.status_code}") - - data = response.json() - - # Check API-level error code + data = resp.json() api_code = data.get("code") if api_code is not None and api_code != 200: msg = data.get("msg") or "Unknown error" - return ToolResult.fail(f"Error: Bocha API error (code={api_code}): {msg}") - - # Extract and format results - return self._format_bocha_results(data, query) - - def _format_bocha_results(self, data: dict, query: str) -> ToolResult: - """ - Format Bocha API response into unified result structure - - :param data: Raw API response - :param query: Original query - :return: Formatted ToolResult - """ - search_data = data.get("data", {}) - web_pages = search_data.get("webPages", {}) - pages = web_pages.get("value", []) - - if not pages: - return ToolResult.success({ - "query": query, - "backend": "bocha", - "total": 0, - "results": [], - "message": "No results found" - }) + return ToolResult.fail(f"Error: bocha API error (code={api_code}): {msg}") + pages = (data.get("data") or {}).get("webPages", {}).get("value", []) or [] results = [] - for page in pages: - result = { - "title": page.get("name", ""), - "url": page.get("url", ""), - "snippet": page.get("snippet", ""), - "siteName": page.get("siteName", ""), - "datePublished": page.get("datePublished") or page.get("dateLastCrawled", ""), + for p in pages: + item = { + "title": p.get("name", ""), + "url": p.get("url", ""), + "snippet": p.get("snippet", ""), + "siteName": p.get("siteName", ""), + "datePublished": p.get("datePublished") or p.get("dateLastCrawled", ""), } - # Include summary only if present - if page.get("summary"): - result["summary"] = page["summary"] - results.append(result) - - total = web_pages.get("totalEstimatedMatches", len(results)) - + if p.get("summary"): + item["summary"] = p["summary"] + results.append(item) + total = (data.get("data") or {}).get("webPages", {}).get("totalEstimatedMatches", len(results)) return ToolResult.success({ - "query": query, - "backend": "bocha", - "total": total, - "count": len(results), - "results": results + "query": query, "backend": "bocha", + "total": total, "count": len(results), "results": results, }) - def _search_linkai(self, query: str, count: int, freshness: str) -> ToolResult: - """ - Search using LinkAI plugin API + # ------------------------------------------------------------------ + # Zhipu + # ------------------------------------------------------------------ - :param query: Search query - :param count: Number of results - :param freshness: Time range filter - :return: Formatted search results - """ - api_key = os.environ.get("LINKAI_API_KEY", "") - api_base = conf().get("linkai_api_base", "https://api.link-ai.tech") - url = f"{api_base.rstrip('/')}/v1/plugin/execute" + def _search_zhipu(self, query: str, count: int, freshness: str) -> ToolResult: + api_key = _get_api_key("zhipu") + api_base = (conf().get("zhipu_ai_api_base") or "https://open.bigmodel.cn/api/paas/v4").rstrip("/") + url = f"{api_base}/web_search" + headers = { + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json", + } + + # Zhipu Web Search expects `search_query` <= 70 chars; truncate + # gracefully so a long agent-supplied query doesn't get rejected. + trimmed_query = (query or "")[:70] + engine = (_tools_web_search_conf().get("zhipu_search_engine") or "search_pro").strip().lower() + if engine not in ("search_std", "search_pro", "search_pro_sogou", "search_pro_quark"): + engine = "search_pro" + + payload: Dict[str, Any] = { + "search_engine": engine, + "search_query": trimmed_query, + "search_intent": False, + "count": max(1, min(int(count or 10), 50)), + "search_recency_filter": freshness if freshness in ( + "oneDay", "oneWeek", "oneMonth", "oneYear", "noLimit" + ) else "noLimit", + } + content_size = (_tools_web_search_conf().get("zhipu_content_size") or "").strip().lower() + if content_size in ("medium", "high"): + payload["content_size"] = content_size + + logger.debug(f"[WebSearch] zhipu: query='{trimmed_query}', count={payload['count']}, engine={engine}") + resp = requests.post(url, headers=headers, json=payload, timeout=DEFAULT_TIMEOUT) + + if resp.status_code == 401: + return ToolResult.fail("Error: Invalid Zhipu API key.") + if resp.status_code != 200: + return ToolResult.fail(f"Error: Zhipu API returned HTTP {resp.status_code}: {resp.text[:200]}") + + data = resp.json() + # Business-level errors (1701/1702/1703 etc.) come back as + # {"error": {"code","message"}} even on HTTP 200. + if isinstance(data, dict) and data.get("error"): + err = data["error"] or {} + return ToolResult.fail(f"Error: Zhipu returned {err.get('code')}: {err.get('message','')}") + + items = data.get("search_result") or (data.get("data") or {}).get("search_result") or [] + results = [] + for it in items: + results.append({ + "title": it.get("title", ""), + "url": it.get("link") or it.get("url", ""), + "snippet": it.get("content") or it.get("snippet", ""), + "siteName": it.get("media") or it.get("siteName", ""), + "datePublished": it.get("publish_date") or it.get("datePublished", ""), + }) + return ToolResult.success({ + "query": query, "backend": "zhipu", + "total": len(results), "count": len(results), "results": results, + }) + + # ------------------------------------------------------------------ + # Qianfan (Baidu) + # ------------------------------------------------------------------ + + def _search_qianfan(self, query: str, count: int, freshness: str) -> ToolResult: + api_key = _get_api_key("qianfan") + api_base = (conf().get("qianfan_api_base") or "https://qianfan.baidubce.com/v2").rstrip("/") + url = f"{api_base}/ai_search/web_search" + headers = { + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json", + "X-Appbuilder-From": "cow", + } + + count = max(1, min(int(count or 10), 50)) + payload: Dict[str, Any] = { + "messages": [{"role": "user", "content": query}], + "search_source": "baidu_search_v2", + "resource_type_filter": [{"type": "web", "top_k": count}], + } + + # Baidu AI Search expects freshness as a date-range filter, not a + # named recency token. Translate our shared vocabulary into the + # underlying page_time range expected by the API. + search_filter = self._qianfan_build_freshness_filter(freshness) + if search_filter: + payload["search_filter"] = search_filter + + logger.debug(f"[WebSearch] qianfan: query='{query}', count={count}, freshness={freshness!r}") + resp = requests.post(url, headers=headers, json=payload, timeout=DEFAULT_TIMEOUT) + + if resp.status_code == 401: + return ToolResult.fail("Error: Invalid Qianfan API key.") + if resp.status_code != 200: + return ToolResult.fail(f"Error: Qianfan API returned HTTP {resp.status_code}: {resp.text[:200]}") + + data = resp.json() + # Even on HTTP 200 Baidu surfaces business errors as {"code","message"}. + if isinstance(data, dict) and data.get("code"): + return ToolResult.fail(f"Error: Qianfan returned {data.get('code')}: {data.get('message','')}") + + refs = data.get("references") or [] + results = [] + for d in refs: + results.append({ + "title": d.get("title", ""), + "url": d.get("url", ""), + "snippet": (d.get("content") or "")[:200], + "siteName": d.get("web_anchor") or d.get("website") or "", + "datePublished": d.get("date", ""), + }) + return ToolResult.success({ + "query": query, "backend": "qianfan", + "total": len(results), "count": len(results), "results": results, + }) + + @staticmethod + def _qianfan_build_freshness_filter(freshness: str) -> Optional[Dict[str, Any]]: + if not freshness or freshness == "noLimit": + return None + delta_days = {"oneDay": 1, "oneWeek": 7, "oneMonth": 30, "oneYear": 365}.get(freshness) + if not delta_days: + return None + from datetime import datetime, timedelta + now = datetime.now() + end_date = (now + timedelta(days=1)).strftime("%Y-%m-%d") + start_date = (now - timedelta(days=delta_days)).strftime("%Y-%m-%d") + return {"range": {"page_time": {"gte": start_date, "lt": end_date}}} + + # ------------------------------------------------------------------ + # LinkAI (plugin) + # ------------------------------------------------------------------ + + def _search_linkai(self, query: str, count: int, freshness: str) -> ToolResult: + api_key = _get_api_key("linkai") + api_base = (conf().get("linkai_api_base") or "https://api.link-ai.tech").rstrip("/") + url = f"{api_base}/v1/plugin/execute" from common.utils import get_cloud_headers headers = get_cloud_headers(api_key) - payload = { - "code": "web-search", - "args": { - "query": query, - "count": count, - "freshness": freshness - } - } + payload = {"code": "web-search", "args": {"query": query, "count": count, "freshness": freshness}} + logger.debug(f"[WebSearch] linkai: query='{query}', count={count}") + resp = requests.post(url, headers=headers, json=payload, timeout=DEFAULT_TIMEOUT) - logger.debug(f"[WebSearch] LinkAI search: query='{query}', count={count}") - - response = requests.post(url, headers=headers, json=payload, timeout=DEFAULT_TIMEOUT) - - if response.status_code == 401: - return ToolResult.fail("Error: Invalid LINKAI_API_KEY. Please check your API key.") - if response.status_code != 200: - return ToolResult.fail(f"Error: LinkAI API returned HTTP {response.status_code}") - - data = response.json() + if resp.status_code == 401: + return ToolResult.fail("Error: Invalid LinkAI API key.") + if resp.status_code != 200: + return ToolResult.fail(f"Error: LinkAI API returned HTTP {resp.status_code}") + data = resp.json() if not data.get("success"): msg = data.get("message") or "Unknown error" return ToolResult.fail(f"Error: LinkAI search failed: {msg}") - return self._format_linkai_results(data, query) - - def _format_linkai_results(self, data: dict, query: str) -> ToolResult: - """ - Format LinkAI API response into unified result structure. - LinkAI returns the search data in data.data field, which follows - the same Bing-compatible format as Bocha. - - :param data: Raw API response - :param query: Original query - :return: Formatted ToolResult - """ - raw_data = data.get("data", "") - - # LinkAI may return data as a JSON string - if isinstance(raw_data, str): + raw = data.get("data", "") + if isinstance(raw, str): try: - raw_data = json.loads(raw_data) + raw = json.loads(raw) except (json.JSONDecodeError, TypeError): - # If data is plain text, return it as a single result return ToolResult.success({ - "query": query, - "backend": "linkai", - "total": 1, - "count": 1, - "results": [{"content": raw_data}] + "query": query, "backend": "linkai", + "total": 1, "count": 1, "results": [{"content": raw}], }) - # If the response follows Bing-compatible structure - if isinstance(raw_data, dict): - web_pages = raw_data.get("webPages", {}) - pages = web_pages.get("value", []) - + if isinstance(raw, dict): + pages = (raw.get("webPages") or {}).get("value", []) or [] if pages: results = [] - for page in pages: - result = { - "title": page.get("name", ""), - "url": page.get("url", ""), - "snippet": page.get("snippet", ""), - "siteName": page.get("siteName", ""), - "datePublished": page.get("datePublished") or page.get("dateLastCrawled", ""), + for p in pages: + item = { + "title": p.get("name", ""), + "url": p.get("url", ""), + "snippet": p.get("snippet", ""), + "siteName": p.get("siteName", ""), + "datePublished": p.get("datePublished") or p.get("dateLastCrawled", ""), } - if page.get("summary"): - result["summary"] = page["summary"] - results.append(result) - - total = web_pages.get("totalEstimatedMatches", len(results)) + if p.get("summary"): + item["summary"] = p["summary"] + results.append(item) + total = (raw.get("webPages") or {}).get("totalEstimatedMatches", len(results)) return ToolResult.success({ - "query": query, - "backend": "linkai", - "total": total, - "count": len(results), - "results": results + "query": query, "backend": "linkai", + "total": total, "count": len(results), "results": results, }) - # Fallback: return raw data return ToolResult.success({ - "query": query, - "backend": "linkai", - "total": 1, - "count": 1, - "results": [{"content": str(raw_data)}] + "query": query, "backend": "linkai", + "total": 1, "count": 1, "results": [{"content": str(raw)}], }) diff --git a/app.py b/app.py index ba2ab265..dbb15209 100644 --- a/app.py +++ b/app.py @@ -289,6 +289,16 @@ def _warmup_mcp_tools(): logger.warning(f"[App] MCP warmup failed (non-fatal): {e}") +def _warmup_scheduler(): + """Eager-init AgentBridge so the scheduler thread starts at process + boot rather than waiting for the first user message.""" + try: + from bridge.bridge import Bridge + Bridge().get_agent_bridge() + except Exception as e: + logger.warning(f"[App] Scheduler warmup failed: {e}") + + def _sync_builtin_skills(): """Sync builtin skills from project skills/ to workspace skills/ on startup.""" import shutil @@ -354,6 +364,8 @@ def run(): # latency isn't dominated by npx package downloads. _warmup_mcp_tools() + _warmup_scheduler() + logger.info(f"[App] Starting channels: {channel_names}") _channel_mgr = ChannelManager() diff --git a/bridge/agent_bridge.py b/bridge/agent_bridge.py index e60ffd9d..a924dab2 100644 --- a/bridge/agent_bridge.py +++ b/bridge/agent_bridge.py @@ -5,7 +5,7 @@ Agent Bridge - Integrates Agent system with existing COW bridge import os from typing import Optional, List -from agent.protocol import Agent, LLMModel, LLMRequest +from agent.protocol import Agent, LLMModel, LLMRequest, get_cancel_registry from bridge.agent_event_handler import AgentEventHandler from bridge.agent_initializer import AgentInitializer from bridge.bridge import Bridge @@ -285,6 +285,15 @@ class AgentBridge: # Create helper instances self.initializer = AgentInitializer(bridge, self) + + # Eager-start the scheduler so cron tasks fire without waiting + # for the first user message. init_scheduler is idempotent. + try: + from agent.tools.scheduler.integration import init_scheduler + if init_scheduler(self): + self.scheduler_initialized = True + except Exception as e: + logger.warning(f"[AgentBridge] Eager scheduler init failed: {e}") def create_agent(self, system_prompt: str, tools: List = None, **kwargs) -> Agent: """ Create the super agent with COW integration @@ -390,11 +399,22 @@ class AgentBridge: """ session_id = None agent = None + request_id = None + cancel_event = None try: # Extract session_id from context for user isolation if context: session_id = context.kwargs.get("session_id") or context.get("session_id") - + request_id = context.kwargs.get("request_id") or context.get("request_id") + + # Register a cancel token. Prefer per-turn request_id (web), + # fall back to session_id (IM channels). The Event is polled by + # AgentStreamExecutor at safe checkpoints. + registry = get_cancel_registry() + token_key = request_id or session_id + if token_key: + cancel_event = registry.register(token_key, session_id=session_id) + # Get agent for this session (will auto-initialize if needed) agent = self.get_agent(session_id=session_id) if not agent: @@ -449,7 +469,8 @@ class AgentBridge: response = agent.run_stream( user_message=query, on_event=event_handler.handle_event, - clear_history=clear_history + clear_history=clear_history, + cancel_event=cancel_event, ) finally: # Restore original tools @@ -459,6 +480,13 @@ class AgentBridge: # Log execution summary event_handler.log_summary() + # Release cancel token; keep registry bounded. + if token_key: + try: + registry.unregister(token_key) + except Exception: + pass + # Persist new messages generated during this run if session_id: channel_type = (context.get("channel_type") or "") if context else "" @@ -512,6 +540,12 @@ class AgentBridge: logger.info(f"[AgentBridge] Cleared DB for session after error: {session_id}") except Exception as db_err: logger.warning(f"[AgentBridge] Failed to clear DB after error: {db_err}") + # Release cancel token on error path too (idempotent). + if cancel_event is not None and (request_id or session_id): + try: + get_cancel_registry().unregister(request_id or session_id) + except Exception: + pass return Reply(ReplyType.ERROR, f"Agent error: {str(e)}") def _schedule_mcp_hot_reload(self, agent): diff --git a/bridge/agent_event_handler.py b/bridge/agent_event_handler.py index 50826235..35173730 100644 --- a/bridge/agent_event_handler.py +++ b/bridge/agent_event_handler.py @@ -2,44 +2,40 @@ Agent Event Handler - Handles agent events and thinking process output """ +from common import const from common.log import logger +# Cap intermediate thinking messages on weixin to stay within send quota. +WEIXIN_THINKING_INSTANT_MAX = 7 + class AgentEventHandler: """ Handles agent events and optionally sends intermediate messages to channel """ - + def __init__(self, context=None, original_callback=None): - """ - Initialize event handler - - Args: - context: COW context (for accessing channel) - original_callback: Original event callback to chain - """ self.context = context self.original_callback = original_callback - - # Get channel for sending intermediate messages + self.channel = None if context: self.channel = context.kwargs.get("channel") if hasattr(context, "kwargs") else None - + self.current_content = "" self.turn_number = 0 - + + channel_type = "" + if context and hasattr(context, "kwargs"): + channel_type = context.kwargs.get("channel_type", "") or "" + self._is_weixin = channel_type == const.WEIXIN + self._thinking_sent_count = 0 + self._merged_buf: list[str] = [] + def handle_event(self, event): - """ - Main event handler - - Args: - event: Event dict with type and data - """ event_type = event.get("type") data = event.get("data", {}) - - # Dispatch to specific handlers + if event_type == "turn_start": self._handle_turn_start(data) elif event_type == "message_update": @@ -52,25 +48,23 @@ class AgentEventHandler: self._handle_tool_execution_start(data) elif event_type == "tool_execution_end": self._handle_tool_execution_end(data) - - # Call original callback if provided + elif event_type == "agent_end": + self._handle_agent_end(data) + if self.original_callback: self.original_callback(event) - + def _handle_turn_start(self, data): - """Handle turn start event""" self.turn_number = data.get("turn", 0) self.current_content = "" - + def _handle_message_update(self, data): - """Handle message update event (streaming content text)""" delta = data.get("delta", "") self.current_content += delta - + def _handle_message_end(self, data): - """Handle message end event""" tool_calls = data.get("tool_calls", []) - + if tool_calls: if self.current_content.strip(): logger.info(f"💭 {self.current_content.strip()[:200]}{'...' if len(self.current_content) > 200 else ''}") @@ -78,35 +72,54 @@ class AgentEventHandler: else: if self.current_content.strip(): logger.debug(f"💬 {self.current_content.strip()[:200]}{'...' if len(self.current_content) > 200 else ''}") - + # Drain weixin buffer before final reply leaves chat_channel + self._flush_merged_now() + self.current_content = "" - + + def _handle_agent_end(self, data): + self._flush_merged_now() + def _handle_tool_execution_start(self, data): - """Handle tool execution start event - logged by agent_stream.py""" pass - + def _handle_tool_execution_end(self, data): - """Handle tool execution end event - logged by agent_stream.py""" pass - + def _send_to_channel(self, message): - """ - Try to send intermediate message to channel. - Skipped in SSE mode because thinking text is already streamed via on_event. - """ if self.context and self.context.get("on_event"): return + if not self.channel: + return + + if not self._is_weixin: + self._do_send(message) + return + + if self._thinking_sent_count < WEIXIN_THINKING_INSTANT_MAX: + self._do_send(message) + self._thinking_sent_count += 1 + return + + self._merged_buf.append(message) + + def _flush_merged_now(self): + if not self._merged_buf: + return + merged = "\n\n".join(self._merged_buf) + count = len(self._merged_buf) + self._merged_buf = [] + logger.debug(f"[AgentEventHandler] Flushing {count} merged thinking msgs, len={len(merged)}") + self._do_send(merged) + self._thinking_sent_count += 1 + + def _do_send(self, message): + try: + from bridge.reply import Reply, ReplyType + reply = Reply(ReplyType.TEXT, message) + self.channel._send(reply, self.context) + except Exception as e: + logger.debug(f"[AgentEventHandler] Failed to send to channel: {e}") - if self.channel: - try: - from bridge.reply import Reply, ReplyType - reply = Reply(ReplyType.TEXT, message) - self.channel._send(reply, self.context) - except Exception as e: - logger.debug(f"[AgentEventHandler] Failed to send to channel: {e}") - def log_summary(self): - """Log execution summary - simplified""" - # Summary removed as per user request - # Real-time logging during execution is sufficient pass diff --git a/bridge/agent_initializer.py b/bridge/agent_initializer.py index d17dcb0c..7d5afb4a 100644 --- a/bridge/agent_initializer.py +++ b/bridge/agent_initializer.py @@ -521,7 +521,7 @@ class AgentInitializer: if tool_name == "web_search": from agent.tools.web_search.web_search import WebSearch if not WebSearch.is_available(): - logger.debug("[AgentInitializer] WebSearch skipped - no BOCHA_API_KEY or LINKAI_API_KEY") + logger.debug("[AgentInitializer] WebSearch skipped - no search provider configured") continue # Special handling for EnvConfig tool diff --git a/bridge/bridge.py b/bridge/bridge.py index 753e394a..6eeb0887 100644 --- a/bridge/bridge.py +++ b/bridge/bridge.py @@ -14,7 +14,9 @@ class Bridge(object): def __init__(self): self.btype = { "chat": const.OPENAI, - "voice_to_text": conf().get("voice_to_text", "openai"), + # Empty `voice_to_text` (the default in new configs) triggers + # the auto-pick below — see _auto_pick_voice_to_text for order. + "voice_to_text": conf().get("voice_to_text") or self._auto_pick_voice_to_text(), "text_to_voice": conf().get("text_to_voice", "google"), "translate": conf().get("translate", "baidu"), } @@ -61,6 +63,10 @@ class Bridge(object): if model_type and model_type.startswith("deepseek"): self.btype["chat"] = const.DEEPSEEK + # 小米 MiMo 系列模型,全部以 mimo- 开头 + if model_type and model_type.startswith("mimo-"): + self.btype["chat"] = const.MIMO + if model_type and isinstance(model_type, str): lowered_model_type = model_type.lower() if lowered_model_type == const.QIANFAN or lowered_model_type.startswith("ernie"): @@ -84,6 +90,46 @@ class Bridge(object): self.chat_bots = {} self._agent_bridge = None + def refresh_voice(self): + """Re-read voice_to_text / text_to_voice from config and drop the + cached voice bots so the next call picks up the new provider. + Used by the web console after the user edits voice settings. + Does NOT touch the agent_bridge / agent state. + """ + new_v2t = conf().get("voice_to_text") or self._auto_pick_voice_to_text() + new_t2v = conf().get("text_to_voice", "google") + if conf().get("use_linkai") and conf().get("linkai_api_key"): + if not conf().get("voice_to_text") or conf().get("voice_to_text") in ["openai"]: + new_v2t = const.LINKAI + if not conf().get("text_to_voice") or conf().get("text_to_voice") in ["openai", const.TTS_1, const.TTS_1_HD]: + new_t2v = const.LINKAI + self.btype["voice_to_text"] = new_v2t + self.btype["text_to_voice"] = new_t2v + self.bots.pop("voice_to_text", None) + self.bots.pop("text_to_voice", None) + logger.info(f"[Bridge] voice refreshed: voice_to_text={new_v2t}, text_to_voice={new_t2v}") + + @staticmethod + def _auto_pick_voice_to_text() -> str: + """Pick an ASR provider by configured api keys when voice_to_text is + unset. Order matches the web console: openai → dashscope → zhipu → + linkai. Falls back to 'openai' when nothing is configured so the + original "missing key" error is preserved. + """ + def has(k: str) -> bool: + v = (conf().get(k) or "").strip() + return v != "" and v not in ("YOUR API KEY", "YOUR_API_KEY") + + for key, provider in ( + ("open_ai_api_key", "openai"), + ("dashscope_api_key", "dashscope"), + ("zhipu_ai_api_key", "zhipu"), + ("linkai_api_key", "linkai"), + ): + if has(key): + return provider + return "openai" + # 模型对应的接口 def get_bot(self, typename): if self.bots.get(typename) is None: diff --git a/channel/channel_factory.py b/channel/channel_factory.py index 10000226..2645945e 100644 --- a/channel/channel_factory.py +++ b/channel/channel_factory.py @@ -42,6 +42,12 @@ def create_channel(channel_type) -> Channel: elif channel_type == const.QQ: from channel.qq.qq_channel import QQChannel ch = QQChannel() + elif channel_type == const.TELEGRAM: + from channel.telegram.telegram_channel import TelegramChannel + ch = TelegramChannel() + elif channel_type == const.SLACK: + from channel.slack.slack_channel import SlackChannel + ch = SlackChannel() elif channel_type in (const.WEIXIN, "wx"): from channel.weixin.weixin_channel import WeixinChannel ch = WeixinChannel() diff --git a/channel/chat_channel.py b/channel/chat_channel.py index 3251c286..6a9a1952 100644 --- a/channel/chat_channel.py +++ b/channel/chat_channel.py @@ -171,7 +171,13 @@ class ChatChannel(Channel): if "desire_rtype" not in context and conf().get("always_reply_voice") and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE: context["desire_rtype"] = ReplyType.VOICE elif context.type == ContextType.VOICE: - if "desire_rtype" not in context and conf().get("voice_reply_voice") and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE: + # Voice input replies with voice when either voice_reply_voice + # (mirror voice) or the global always_reply_voice toggle is on. + if ( + "desire_rtype" not in context + and (conf().get("voice_reply_voice") or conf().get("always_reply_voice")) + and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE + ): context["desire_rtype"] = ReplyType.VOICE return context @@ -264,6 +270,8 @@ class ChatChannel(Channel): if reply.type == ReplyType.TEXT: reply_text = reply.content if desire_rtype == ReplyType.VOICE and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE: + # Preserve original text for the "text-then-voice" pattern in _send_reply. + context["voice_reply_text"] = reply.content reply = super().build_text_to_voice(reply.content) return self._decorate_reply(context, reply) if context.get("isgroup", False): @@ -311,6 +319,15 @@ class ChatChannel(Channel): # 短暂延迟后发送图片 time.sleep(0.3) self._send(reply, context) + # Send text bubble before voice, unless channel already streamed + # the text (feishu) or natively renders STT under the voice (wechatcom). + elif reply.type == ReplyType.VOICE and context.get("voice_reply_text") \ + and not context.get("feishu_streamed") \ + and context.get("channel_type") not in ("wechatcom_app",): + text_reply = Reply(ReplyType.TEXT, context.get("voice_reply_text")) + self._send(text_reply, context) + time.sleep(0.3) + self._send(reply, context) else: self._send(reply, context) @@ -421,8 +438,21 @@ class ChatChannel(Channel): return func + # Chat commands that must bypass the per-session serial queue, + # otherwise /cancel would queue behind the task it tries to cancel. + # Use /cancel (not /stop) to avoid colliding with `cow stop` CLI. + _BYPASS_QUEUE_COMMANDS = ("/cancel",) + def produce(self, context: Context): session_id = context["session_id"] + + # Fast path: /cancel must not enter the queue. + if context.type == ContextType.TEXT and context.content: + stripped = context.content.strip().lower() + if stripped in self._BYPASS_QUEUE_COMMANDS: + self._handle_cancel_command(context, session_id) + return + with self.lock: if session_id not in self.sessions: self.sessions[session_id] = [ @@ -434,6 +464,29 @@ class ChatChannel(Channel): else: self.sessions[session_id][0].put(context) + def _handle_cancel_command(self, context: Context, session_id: str) -> None: + """Cancel any in-flight agent run for *session_id* and reply inline. + + Runs synchronously on the caller's thread. Reply is sent through + _send_reply so plugins (e.g. logging) still observe it. + """ + try: + from agent.protocol import get_cancel_registry + from bridge.reply import Reply, ReplyType + + cancelled = get_cancel_registry().cancel_session(session_id) + text = ( + "🛑 已中止" + if cancelled > 0 + else "当前没有可中止的任务。" + ) + logger.info( + f"[chat_channel] /cancel fast-path: session={session_id}, cancelled={cancelled}" + ) + self._send_reply(context, Reply(ReplyType.TEXT, text)) + except Exception as e: + logger.warning(f"[chat_channel] /cancel fast-path failed: {e}") + # 消费者函数,单独线程,用于从消息队列中取出消息并处理 def consume(self): while True: diff --git a/channel/dingtalk/dingtalk_channel.py b/channel/dingtalk/dingtalk_channel.py index d572e35d..b1ae86c2 100644 --- a/channel/dingtalk/dingtalk_channel.py +++ b/channel/dingtalk/dingtalk_channel.py @@ -86,6 +86,8 @@ def _check(func): @singleton class DingTalkChanel(ChatChannel, dingtalk_stream.ChatbotHandler): + NOT_SUPPORT_REPLYTYPE = [] + dingtalk_client_id = conf().get('dingtalk_client_id') dingtalk_client_secret = conf().get('dingtalk_client_secret') @@ -870,6 +872,48 @@ class DingTalkChanel(ChatChannel, dingtalk_stream.ChatbotHandler): self.reply_text("抱歉,文件上传失败", incoming_message) return + # Native sampleAudio. Upload only accepts ogg/amr, so convert TTS mp3/wav to amr. + elif reply.type == ReplyType.VOICE: + logger.info(f"[DingTalk] Sending voice: {reply.content}") + access_token = self.get_access_token() + if not access_token: + logger.error("[DingTalk] Cannot get access token for voice") + self.reply_text("抱歉,语音发送失败(无法获取token)", incoming_message) + return + + voice_path = reply.content + if voice_path.startswith("file://"): + voice_path = voice_path[7:] + + amr_path = voice_path + duration_ms = 0 + if not voice_path.lower().endswith((".amr", ".ogg")): + try: + from voice.audio_convert import any_to_amr + amr_path = os.path.splitext(voice_path)[0] + ".amr" + duration_ms = int(any_to_amr(voice_path, amr_path) or 0) + except Exception as e: + logger.error(f"[DingTalk] Failed to convert voice to amr: {e}") + self.reply_text("抱歉,语音转码失败", incoming_message) + return + + media_id = self.upload_media(amr_path, media_type="voice") + if not media_id: + logger.error("[DingTalk] Failed to upload voice media") + self.reply_text("抱歉,语音上传失败", incoming_message) + return + + msg_param = { + "mediaId": media_id, + "duration": str(duration_ms or 1000), + } + success = self._send_file_message( + access_token, incoming_message, "sampleAudio", msg_param, isgroup + ) + if not success: + self.reply_text("抱歉,语音发送失败", incoming_message) + return + # 处理文本消息 elif reply.type == ReplyType.TEXT: logger.info(f"[DingTalk] Sending text message, length={len(reply.content)}") diff --git a/channel/feishu/feishu_channel.py b/channel/feishu/feishu_channel.py index f479394a..9a9f3307 100644 --- a/channel/feishu/feishu_channel.py +++ b/channel/feishu/feishu_channel.py @@ -752,6 +752,9 @@ class FeiShuChanel(ChatChannel): init_in_flight = [False] # 一旦初始化失败就长期标记为 disabled,本次回复不再尝试任何流式调用 disabled = [False] + # True after agent_cancelled: agent_end stops rewriting the card + # with stale final_response and just finalizes current content. + cancelled = [False] lock = threading.Lock() # ---- 异步推送队列 ---------------------------------------------------- @@ -1076,18 +1079,42 @@ class FeiShuChanel(ChatChannel): message_id[0] = None sequence[0] = 0 + elif event_type == "agent_cancelled": + # Lock channel into "no-rewrite" mode: the subsequent + # agent_end's final_response is from the last *completed* + # turn (the user already saw it), so rewriting the card + # would duplicate it visually. + with lock: + cancelled[0] = True + elif event_type == "agent_end": # 最终回复:用 final_response 覆盖当前流式卡片,然后关闭流式模式。 final_response = data.get("final_response", "") - if not final_response: - return - final_text = str(final_response) # 标记 streamed 让 chat_channel 跳过 send() context["feishu_streamed"] = True with lock: + was_cancelled = cancelled[0] has_card = card_id[0] is not None init_busy = init_in_flight[0] + pending_text = current_text[0] + + if was_cancelled: + # Cancelled path: finalize the in-flight card with + # partial output (or a short marker if empty); drop + # stale final_response to avoid duplicating last turn. + if has_card: + _drain_push_queue() + partial = (pending_text or "").rstrip() + final_text = partial or "_(已中止)_" + _stream_update_text(final_text) + _close_streaming_mode(final_text) + push_queue.put(None) + return + + if not final_response: + return + final_text = str(final_response) # 罕见情况:agent_end 触发时还没创建过卡片(极快返回 / 没有 # message_update),主动创建一张承载 final_text。 @@ -1515,10 +1542,16 @@ class FeiShuChanel(ChatChannel): else: context.type = ContextType.TEXT context.content = content.strip() + # Text input opts into voice replies only when the always-on toggle is set. + if "desire_rtype" not in context and conf().get("always_reply_voice"): + context["desire_rtype"] = ReplyType.VOICE elif context.type == ContextType.VOICE: - # 2.语音请求 - if "desire_rtype" not in context and conf().get("voice_reply_voice"): + # 2.语音请求: voice input replies with voice if either + # voice_reply_voice (mirror reply) or always_reply_voice is on. + if "desire_rtype" not in context and ( + conf().get("voice_reply_voice") or conf().get("always_reply_voice") + ): context["desire_rtype"] = ReplyType.VOICE return context diff --git a/channel/slack/__init__.py b/channel/slack/__init__.py new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/channel/slack/__init__.py @@ -0,0 +1 @@ + diff --git a/channel/slack/slack_channel.py b/channel/slack/slack_channel.py new file mode 100644 index 00000000..8e82fcc5 --- /dev/null +++ b/channel/slack/slack_channel.py @@ -0,0 +1,506 @@ +""" +Slack channel via Bolt for Python (Socket Mode). + +Features: +- Direct message & channel chat (text / image / file) +- Channel trigger: @mention or reply in a thread the bot is in (configurable) +- /cancel fast-path matches Web channel behaviour +- Socket Mode: no public IP / callback URL required, works behind NAT + +Implementation note: + slack_bolt's SocketModeHandler is blocking and runs its own background + threads. We start it in a dedicated thread so the rest of cow (sync) stays + untouched. Inbound events are dispatched onto cow's existing sync + ChatChannel.produce() pipeline; outbound send() calls the Slack Web API + client directly (it is sync-safe). +""" + +import os +import re +import threading + +import requests + +from bridge.context import Context, ContextType +from bridge.reply import Reply, ReplyType +from channel.chat_channel import ChatChannel, check_prefix +from channel.slack.slack_message import SlackMessage +from common.expired_dict import ExpiredDict +from common.log import logger +from common.singleton import singleton +from config import conf + + +@singleton +class SlackChannel(ChatChannel): + NOT_SUPPORT_REPLYTYPE = [] + + def __init__(self): + super().__init__() + self.bot_token = "" + self.app_token = "" + self.bot_user_id = "" # used to strip @mention and ignore self messages + self._app = None + self._handler = None + self._client = None + self._loop_thread = None + # Idempotent dedup; Slack retries event delivery on slow ack + self._received_msgs = ExpiredDict(60 * 60 * 1) + + # Disable group whitelist / prefix checks (we handle triggering ourselves + # in _should_reply_in_channel), aligned with telegram / feishu channels. + conf()["group_name_white_list"] = ["ALL_GROUP"] + conf()["single_chat_prefix"] = [""] + + # ------------------------------------------------------------------ + # Lifecycle + # ------------------------------------------------------------------ + + def startup(self): + self.bot_token = conf().get("slack_bot_token", "") + self.app_token = conf().get("slack_app_token", "") + if not self.bot_token or not self.app_token: + err = "[Slack] slack_bot_token and slack_app_token are both required" + logger.error(err) + self.report_startup_error(err) + return + + # Guard against the common mistake of swapping the two tokens: + # bot token must start with xoxb-, app-level token with xapp-. + if not self.bot_token.startswith("xoxb-") or not self.app_token.startswith("xapp-"): + err = ( + "[Slack] token type mismatch: slack_bot_token must start with 'xoxb-' " + "and slack_app_token must start with 'xapp-' (they look swapped)" + ) + logger.error(err) + self.report_startup_error(err) + return + + try: + from slack_bolt import App + from slack_bolt.adapter.socket_mode import SocketModeHandler + except ImportError: + err = ( + "[Slack] slack_bolt is not installed. " + "Run: pip install slack_bolt" + ) + logger.error(err) + self.report_startup_error(err) + return + + try: + self._app = App(token=self.bot_token) + self._client = self._app.client + + # Resolve our own bot user id (needed for @mention strip / self-ignore) + auth = self._client.auth_test() + self.bot_user_id = auth.get("user_id", "") + self.name = self.bot_user_id # ChatChannel uses self.name to strip @-mention + logger.info(f"[Slack] Bot logged in as user_id={self.bot_user_id}, team={auth.get('team')}") + except Exception as e: + err = f"[Slack] auth_test failed: {e}" + logger.error(err) + self.report_startup_error(err) + return + + self._register_handlers() + + self._handler = SocketModeHandler(self._app, self.app_token) + + def _run(): + try: + logger.info("[Slack] Starting Socket Mode connection...") + self.report_startup_success() + logger.info("[Slack] ✅ Slack bot ready, listening for events") + self._handler.start() + except Exception as e: + logger.error(f"[Slack] socket mode crashed: {e}", exc_info=True) + self.report_startup_error(str(e)) + finally: + logger.info("[Slack] socket mode exited") + + self._loop_thread = threading.Thread(target=_run, daemon=True, name="slack-socket") + self._loop_thread.start() + # Block startup() until the handler thread exits, matching other channels' + # behaviour (startup is a blocking call). + self._loop_thread.join() + + def _register_handlers(self): + app = self._app + + # app_mention: bot is @-mentioned in a channel + @app.event("app_mention") + def _on_app_mention(event, ack): + ack() + self._handle_event(event, is_group=True) + + # message: DMs and channel messages (including thread replies) + @app.event("message") + def _on_message(event, ack): + ack() + self._handle_message_event(event) + + def stop(self): + logger.info("[Slack] stop() called") + try: + if self._handler is not None: + self._handler.close() + except Exception as e: + logger.warning(f"[Slack] handler close error: {e}") + if self._loop_thread and self._loop_thread.is_alive(): + try: + self._loop_thread.join(timeout=10) + except Exception: + pass + logger.info("[Slack] stop() completed") + + # ------------------------------------------------------------------ + # Inbound: slack event -> ChatMessage -> ChatChannel.produce + # ------------------------------------------------------------------ + + def _handle_message_event(self, event: dict): + """Route a raw `message` event: skip bot/system noise, decide grouping.""" + try: + logger.debug( + f"[Slack] message event: channel_type={event.get('channel_type')}, " + f"subtype={event.get('subtype')}, user={event.get('user')}, " + f"ts={event.get('ts')}, thread_ts={event.get('thread_ts')}" + ) + # Ignore bot messages (including our own) and message edits/deletes + if event.get("bot_id") or event.get("subtype") in ("bot_message", "message_changed", "message_deleted"): + return + if event.get("user") == self.bot_user_id: + return + + channel_type = event.get("channel_type", "") + # DM (im) is single chat; channel/group is group chat. app_mention + # already covers channel @-mentions, so for plain channel messages we + # only react when configured / thread-following. + is_group = channel_type in ("channel", "group", "mpim") + if is_group: + # app_mention handler covers explicit @bot; here we only handle + # follow-up replies in threads the bot participates in. + if not self._should_reply_in_channel(event): + return + self._handle_event(event, is_group=is_group) + except Exception as e: + logger.error(f"[Slack] _handle_message_event error: {e}", exc_info=True) + + def _handle_event(self, event: dict, is_group: bool): + """Parse event -> build SlackMessage -> produce().""" + try: + channel_id = event.get("channel", "") + ts = event.get("ts", "") + if not channel_id: + return + + # Idempotent dedup + msg_uid = f"{channel_id}:{ts}" + if self._received_msgs.get(msg_uid): + return + self._received_msgs[msg_uid] = True + + # Parse type + download media if needed. + ctype, content, caption = self._parse_event(event) + if ctype is None: + logger.debug(f"[Slack] unsupported message type, skip. event={event}") + return + + # Strip <@bot_user_id> mention from channel text + if is_group and self.bot_user_id: + if ctype == ContextType.TEXT and content: + content = self._strip_at_mention(content) + if caption: + caption = self._strip_at_mention(caption) + + slack_msg = SlackMessage( + event, + is_group=is_group, + bot_user_id=self.bot_user_id, + ctype=ctype, + content=content, + ) + slack_msg.is_at = is_group # if we reached here in a channel, bot is mentioned/threaded + + from channel.file_cache import get_file_cache + file_cache = get_file_cache() + session_id = self._compute_session_id(event, is_group) + + # Media + caption together: treat as a complete query and bypass the cache + if ctype in (ContextType.IMAGE, ContextType.FILE) and caption: + tag = "image" if ctype == ContextType.IMAGE else "file" + merged_text = f"{caption}\n[{tag}: {content}]" + slack_msg.ctype = ContextType.TEXT + slack_msg.content = merged_text + ctype = ContextType.TEXT + logger.info(f"[Slack] Media+caption merged for session {session_id}") + # fallthrough to the TEXT branch below + + elif ctype == ContextType.IMAGE: + file_cache.add(session_id, content, file_type="image") + logger.info(f"[Slack] Image cached for session {session_id}, waiting for query...") + return + elif ctype == ContextType.FILE: + file_cache.add(session_id, content, file_type="file") + logger.info(f"[Slack] File cached for session {session_id}: {content}") + return + + if ctype == ContextType.TEXT: + # Fast-path: /cancel mirrors Web channel behaviour + if (content or "").strip().lower() in ("/cancel", "cancel"): + self._do_cancel(session_id, channel_id, event) + return + + cached_files = file_cache.get(session_id) + if cached_files: + refs = [] + for fi in cached_files: + ftype = fi["type"] + tag = ftype if ftype in ("image", "video") else "file" + refs.append(f"[{tag}: {fi['path']}]") + slack_msg.content = (slack_msg.content or "") + "\n" + "\n".join(refs) + file_cache.clear(session_id) + logger.info(f"[Slack] Attached {len(cached_files)} cached file(s) to query") + + # Reply in the originating thread when present, else start one on this msg + thread_ts = event.get("thread_ts") or ts + + context = self._compose_context( + slack_msg.ctype, + slack_msg.content, + isgroup=is_group, + msg=slack_msg, + # Replies go back into the thread, no manual @mention needed + no_need_at=True, + ) + if context: + context["session_id"] = session_id + context["receiver"] = channel_id + context["slack_channel"] = channel_id + context["slack_thread_ts"] = thread_ts if is_group else None + self.produce(context) + logger.debug(f"[Slack] received: type={ctype}, content={str(slack_msg.content)[:80]}") + except Exception as e: + logger.error(f"[Slack] _handle_event error: {e}", exc_info=True) + + def _do_cancel(self, session_id: str, channel_id: str, event: dict): + """Fast-path: /cancel calls cancel_session directly without going through agent.""" + try: + from agent.protocol import get_cancel_registry + cancelled = get_cancel_registry().cancel_session(session_id) + text = "Current task cancelled." if cancelled else "No running task to cancel." + thread_ts = event.get("thread_ts") or event.get("ts") + self._client.chat_postMessage(channel=channel_id, text=text, thread_ts=thread_ts) + logger.info(f"[Slack] /cancel session={session_id}, cancelled={cancelled}") + except Exception as e: + logger.error(f"[Slack] /cancel error: {e}", exc_info=True) + + def _parse_event(self, event: dict): + """Parse a slack event and return (ctype, content, caption). + + - content is text for ContextType.TEXT, otherwise the local file path + - caption is the optional text accompanying a file; empty for plain text + """ + text = (event.get("text") or "").strip() + files = event.get("files") or [] + + if files: + # Handle the first attachment; caption is the accompanying message text + f = files[0] + mimetype = (f.get("mimetype") or "").lower() + url = f.get("url_private_download") or f.get("url_private") + name = f.get("name") or f.get("id") or "file" + if not url: + return (None, None, "") + path = self._download_file(url, name) + if not path: + return (None, None, "") + if mimetype.startswith("image/"): + return (ContextType.IMAGE, path, text) + return (ContextType.FILE, path, text) + + if text: + return (ContextType.TEXT, text, "") + + return (None, None, "") + + def _download_file(self, url: str, name: str): + """Download a Slack private file (requires bot token auth) to local tmp dir.""" + try: + headers = {"Authorization": f"Bearer {self.bot_token}"} + resp = requests.get(url, headers=headers, timeout=60, stream=True) + resp.raise_for_status() + tmp_dir = SlackMessage.get_tmp_dir() + # Sanitize the name and keep it unique-ish via the url tail + safe_name = re.sub(r"[^\w.\-]", "_", name) + local_path = os.path.join(tmp_dir, safe_name) + with open(local_path, "wb") as fp: + for chunk in resp.iter_content(chunk_size=8192): + if chunk: + fp.write(chunk) + logger.debug(f"[Slack] downloaded {name} -> {local_path}") + return local_path + except Exception as e: + logger.error(f"[Slack] download_file failed ({name}): {e}") + return None + + # ------------------------------------------------------------------ + # Channel trigger logic + # ------------------------------------------------------------------ + + def _should_reply_in_channel(self, event: dict) -> bool: + """Decide whether to reply to a plain channel message (no @mention). + + app_mention already handles explicit @bot, so here we only deal with + follow-up messages. `all` replies to every message; `mention_or_reply` + replies inside threads the bot already participates in. + """ + mode = conf().get("slack_group_trigger", "mention_or_reply") + if mode == "all": + return True + if mode == "mention_only": + return False + # mention_or_reply: follow up only within an existing thread + return bool(event.get("thread_ts")) + + def _strip_at_mention(self, content: str) -> str: + """Strip <@BOT_USER_ID> from channel text.""" + if not content or not self.bot_user_id: + return content + pattern = re.compile(r"<@" + re.escape(self.bot_user_id) + r">", re.IGNORECASE) + return pattern.sub("", content).strip() + + @staticmethod + def _compute_session_id(event: dict, is_group: bool) -> str: + channel_id = event.get("channel", "") + user_id = event.get("user", "") + if is_group: + if conf().get("group_shared_session", True): + return f"slack_channel_{channel_id}" + return f"slack_channel_{channel_id}_{user_id}" + return f"slack_user_{user_id}" + + # ------------------------------------------------------------------ + # Override _compose_context: skip the parent's group whitelist/at checks + # (already handled via _should_reply_in_channel). Same idea as telegram. + # ------------------------------------------------------------------ + + def _compose_context(self, ctype: ContextType, content, **kwargs): + context = Context(ctype, content) + context.kwargs = kwargs + if "channel_type" not in context: + context["channel_type"] = self.channel_type + if "origin_ctype" not in context: + context["origin_ctype"] = ctype + + cmsg = context["msg"] + if cmsg.is_group: + if conf().get("group_shared_session", True): + context["session_id"] = cmsg.other_user_id + else: + context["session_id"] = f"{cmsg.from_user_id}:{cmsg.other_user_id}" + else: + context["session_id"] = cmsg.from_user_id + context["receiver"] = cmsg.other_user_id + + if ctype == ContextType.TEXT: + img_match_prefix = check_prefix(content, conf().get("image_create_prefix")) + if img_match_prefix: + content = content.replace(img_match_prefix, "", 1) + context.type = ContextType.IMAGE_CREATE + else: + context.type = ContextType.TEXT + context.content = (content or "").strip() + if "desire_rtype" not in context and conf().get("always_reply_voice"): + context["desire_rtype"] = ReplyType.VOICE + elif ctype == ContextType.VOICE: + if "desire_rtype" not in context and ( + conf().get("voice_reply_voice") or conf().get("always_reply_voice") + ): + context["desire_rtype"] = ReplyType.VOICE + + return context + + # ------------------------------------------------------------------ + # Outbound: ChatChannel.send -> Slack Web API + # ------------------------------------------------------------------ + + def send(self, reply: Reply, context: Context): + """Called from cow's sync main thread; Slack Web client is sync-safe.""" + if self._client is None: + logger.warning("[Slack] client not ready, drop reply") + return + + channel_id = context.get("slack_channel") + thread_ts = context.get("slack_thread_ts") + if not channel_id: + logger.warning("[Slack] no slack_channel in context, drop reply") + return + + try: + self._do_send(reply, channel_id, thread_ts) + logger.info(f"[Slack] sent reply (type={reply.type}, channel={channel_id})") + except Exception as e: + logger.error(f"[Slack] send failed: {e}", exc_info=True) + + def _do_send(self, reply: Reply, channel_id: str, thread_ts): + rtype = reply.type + content = reply.content + + if rtype in (ReplyType.TEXT, ReplyType.INFO, ReplyType.ERROR): + text = str(content) if content is not None else "" + if not text: + return + # Slack caps a message around 40k chars; split conservatively + for chunk in _split_text(text, 3500): + self._client.chat_postMessage(channel=channel_id, text=chunk, thread_ts=thread_ts) + + elif rtype == ReplyType.IMAGE: + # Already a local BytesIO; upload it directly + content.seek(0) + self._client.files_upload_v2( + channel=channel_id, file=content, filename="image.png", thread_ts=thread_ts, + ) + + elif rtype == ReplyType.IMAGE_URL: + url = str(content) + if url.startswith("file://"): + local = url[7:] + self._client.files_upload_v2( + channel=channel_id, file=local, thread_ts=thread_ts, + ) + else: + # Post the URL as text; Slack will unfurl it as an image preview + self._client.chat_postMessage(channel=channel_id, text=url, thread_ts=thread_ts) + + elif rtype in (ReplyType.VOICE, ReplyType.FILE): + local = content[7:] if isinstance(content, str) and content.startswith("file://") else content + caption = getattr(reply, "text_content", None) or None + self._client.files_upload_v2( + channel=channel_id, file=local, initial_comment=caption, thread_ts=thread_ts, + ) + + else: + # Fallback: send as plain text + self._client.chat_postMessage(channel=channel_id, text=str(content), thread_ts=thread_ts) + + +def _split_text(text: str, limit: int): + """Split long text preferring line breaks to keep markdown structure intact.""" + if len(text) <= limit: + yield text + return + buf = [] + size = 0 + for line in text.splitlines(keepends=True): + if size + len(line) > limit and buf: + yield "".join(buf) + buf, size = [], 0 + # Hard-split single lines that exceed the limit + while len(line) > limit: + yield line[:limit] + line = line[limit:] + buf.append(line) + size += len(line) + if buf: + yield "".join(buf) diff --git a/channel/slack/slack_message.py b/channel/slack/slack_message.py new file mode 100644 index 00000000..39f215bd --- /dev/null +++ b/channel/slack/slack_message.py @@ -0,0 +1,60 @@ +""" +Slack message adapter. + +Convert a Slack event payload into cow's unified ChatMessage. +File downloads are NOT performed here; the channel layer downloads files +on demand because it needs the bot token for authenticated download URLs. +""" +import os + +from bridge.context import ContextType +from channel.chat_message import ChatMessage +from common.utils import expand_path +from config import conf + + +class SlackMessage(ChatMessage): + """Wrap a Slack event into the unified ChatMessage.""" + + def __init__(self, event: dict, is_group: bool = False, bot_user_id: str = "", + ctype: ContextType = ContextType.TEXT, content: str = ""): + super().__init__(event) + # Basic fields + self.msg_id = event.get("client_msg_id") or event.get("ts") or "" + try: + self.create_time = int(float(event.get("ts", 0))) + except (TypeError, ValueError): + self.create_time = 0 + self.ctype = ctype + self.content = content + + # Sender / chat info + from_user_id = event.get("user", "unknown") + channel_id = event.get("channel", "") + self.from_user_id = from_user_id + self.from_user_nickname = from_user_id + self.to_user_id = bot_user_id or "slack_bot" + self.to_user_nickname = bot_user_id or "slack_bot" + + self.is_group = is_group + if is_group: + # Channel chat: other_user_id = channel_id, actual_user_id = sender id + self.other_user_id = channel_id + self.other_user_nickname = channel_id + self.actual_user_id = from_user_id + self.actual_user_nickname = from_user_id + else: + # DM: use channel_id so replies go back to the same DM channel + self.other_user_id = channel_id or from_user_id + self.other_user_nickname = from_user_id + + # Whether the bot was triggered by @-mention (set by channel layer) + self.is_at = False + + @staticmethod + def get_tmp_dir() -> str: + """Local download directory, aligned with other channels (agent_workspace/tmp).""" + workspace_root = expand_path(conf().get("agent_workspace", "~/cow")) + tmp_dir = os.path.join(workspace_root, "tmp") + os.makedirs(tmp_dir, exist_ok=True) + return tmp_dir diff --git a/channel/telegram/__init__.py b/channel/telegram/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/channel/telegram/telegram_channel.py b/channel/telegram/telegram_channel.py new file mode 100644 index 00000000..9e40c59f --- /dev/null +++ b/channel/telegram/telegram_channel.py @@ -0,0 +1,719 @@ +""" +Telegram channel via Bot API (long polling mode). + +Features: +- Single chat & group chat (text / photo / voice / video / document) +- Group trigger: @mention or reply-to-bot (configurable) +- /cancel fast-path matches Web channel behaviour +- Auto-register bot commands menu on startup (mirrors Web slash menu) +- Optional HTTP/SOCKS5 proxy support for restricted networks + +Implementation note: + python-telegram-bot is async-first. We run the bot inside a dedicated + thread with its own asyncio loop so the rest of cow (which is sync) + stays untouched. Inbound updates are dispatched onto cow's existing + sync ChatChannel.produce() pipeline; outbound send() schedules + coroutines back onto that loop via asyncio.run_coroutine_threadsafe. +""" + +import asyncio +import os +import re +import threading + +from bridge.context import Context, ContextType +from bridge.reply import Reply, ReplyType +from channel.chat_channel import ChatChannel, check_prefix +from channel.telegram.telegram_message import TelegramMessage +from common.expired_dict import ExpiredDict +from common.log import logger +from common.singleton import singleton +from config import conf + +# Bot command menu, aligned with Web slash commands. +# Top-level commands only; sub-commands are entered with a space (e.g. "/skill list"). +TELEGRAM_BOT_COMMANDS = [ + ("help", "Show command help"), + ("status", "Show running status"), + ("context", "View/clear conversation context (sub: clear)"), + ("skill", "Manage skills (list/search/install/...)"), + ("memory", "Manage memory (sub: dream)"), + ("knowledge", "Manage knowledge base (list/on/off)"), + ("config", "Show current config"), + ("cancel", "Cancel running agent task"), + ("logs", "Show recent logs"), + ("version", "Show version"), +] + + +@singleton +class TelegramChannel(ChatChannel): + NOT_SUPPORT_REPLYTYPE = [] + + def __init__(self): + super().__init__() + self.bot_token = "" + self.bot_username = "" # used for @-mention matching + self._bot = None + self._application = None + self._loop = None + self._loop_thread = None + self._stop_event = threading.Event() + # Idempotent dedup; TG occasionally redelivers the same update on flaky networks + self._received_msgs = ExpiredDict(60 * 60 * 1) + + # Disable group whitelist / prefix checks (we handle triggering ourselves + # in _should_reply_in_group), aligned with feishu / wecom_bot channels. + conf()["group_name_white_list"] = ["ALL_GROUP"] + conf()["single_chat_prefix"] = [""] + + # ------------------------------------------------------------------ + # Lifecycle + # ------------------------------------------------------------------ + + def startup(self): + self.bot_token = conf().get("telegram_token", "") + if not self.bot_token: + err = "[Telegram] telegram_token is required" + logger.error(err) + self.report_startup_error(err) + return + + try: + from telegram.ext import ( + Application, + MessageHandler, + CommandHandler, + filters, + ) + except ImportError: + err = ( + "[Telegram] python-telegram-bot is not installed. " + "Run: pip install python-telegram-bot" + ) + logger.error(err) + self.report_startup_error(err) + return + + # Run the asyncio event loop in a dedicated thread so the sync cow body + # is untouched. + self._loop = asyncio.new_event_loop() + + def _run_loop(): + asyncio.set_event_loop(self._loop) + try: + self._loop.run_until_complete(self._async_main(Application, MessageHandler, CommandHandler, filters)) + except Exception as e: + logger.error(f"[Telegram] event loop crashed: {e}", exc_info=True) + self.report_startup_error(str(e)) + finally: + try: + self._loop.close() + except Exception: + pass + logger.info("[Telegram] event loop exited") + + self._loop_thread = threading.Thread(target=_run_loop, daemon=True, name="telegram-loop") + self._loop_thread.start() + # Block startup() until the loop thread exits, matching other channels' + # behaviour (startup is a blocking call). + self._loop_thread.join() + + async def _async_main(self, Application, MessageHandler, CommandHandler, filters): + """Build Application, register handlers, and run polling.""" + builder = Application.builder().token(self.bot_token) + + # Proxy: prefer telegram_proxy config, fall back to HTTPS_PROXY env var + proxy_url = conf().get("telegram_proxy", "") or os.environ.get("HTTPS_PROXY", "") + if proxy_url: + try: + builder = builder.proxy(proxy_url).get_updates_proxy(proxy_url) + logger.info(f"[Telegram] using proxy: {proxy_url}") + except Exception as e: + logger.warning(f"[Telegram] proxy config failed, fallback to direct: {e}") + + # Media uploads (photo/voice/video/document) over a proxy can be slow, + # bump read/write/connect/pool timeouts. + builder = ( + builder + .read_timeout(60) + .write_timeout(120) + .connect_timeout(30) + .pool_timeout(30) + ) + + application = builder.build() + self._application = application + self._bot = application.bot + + # Fetch our own username (needed for @-mention matching in groups) + try: + me = await self._bot.get_me() + self.bot_username = me.username or "" + self.name = self.bot_username # ChatChannel uses self.name to strip @-mention + logger.info(f"[Telegram] Bot logged in as @{self.bot_username} (id={me.id})") + except Exception as e: + err = f"[Telegram] get_me failed: {e}" + logger.error(err) + self.report_startup_error(err) + return + + # Register the command menu (failure is non-fatal) + if conf().get("telegram_register_commands", True): + try: + from telegram import BotCommand + cmds = [BotCommand(name, desc) for name, desc in TELEGRAM_BOT_COMMANDS] + await self._bot.set_my_commands(cmds) + logger.info(f"[Telegram] Registered {len(cmds)} bot commands") + except Exception as e: + logger.warning(f"[Telegram] set_my_commands failed: {e}") + + # Handlers: + # 1) /cancel uses the fast-path + application.add_handler(CommandHandler("cancel", self._on_cancel)) + # 2) Normal messages (text + media) + application.add_handler(MessageHandler(filters.ALL & ~filters.COMMAND, self._on_message)) + # 3) Other slash commands are forwarded as plain text for the agent to handle + application.add_handler(MessageHandler(filters.COMMAND, self._on_command_passthrough)) + + # Start polling. drop_pending_updates avoids replaying backlog after restart. + # Transient "Server disconnected" / RemoteProtocolError during get_updates + # are common over proxies/flaky networks; PTB's network loop auto-retries, + # so we only need to keep the noise down (see _quiet_polling_network_errors). + self._quiet_polling_network_errors() + logger.info("[Telegram] Starting long polling...") + await application.initialize() + await application.start() + await application.updater.start_polling( + drop_pending_updates=True, + # Long-poll hold time on the server side; smaller value = reconnect more + # often but each hung connection fails faster. + timeout=30, + # Retry forever on transient get_updates network errors instead of giving up. + bootstrap_retries=-1, + ) + self.report_startup_success() + logger.info("[Telegram] ✅ Telegram bot ready, polling for updates") + + # Block until stop() + try: + while not self._stop_event.is_set(): + await asyncio.sleep(0.5) + finally: + try: + await application.updater.stop() + await application.stop() + await application.shutdown() + except Exception as e: + logger.warning(f"[Telegram] shutdown error: {e}") + + @staticmethod + def _quiet_polling_network_errors(): + """Downgrade PTB's noisy 'Exception happened while polling for updates' logs. + + These transient get_updates errors (RemoteProtocolError / NetworkError / + TimedOut, typically over a proxy) are auto-retried by PTB's network loop, + so logging the full traceback at ERROR is just noise. We attach a filter + that drops these specific records while leaving real errors untouched. + """ + import logging + + class _PollingNoiseFilter(logging.Filter): + _NEEDLES = ( + "Exception happened while polling for updates", + "Server disconnected without sending a response", + ) + + def filter(self, record: logging.LogRecord) -> bool: + try: + msg = record.getMessage() + except Exception: + return True + if any(n in msg for n in self._NEEDLES): + # Keep a single-line breadcrumb at DEBUG, drop the traceback. + logger.debug(f"[Telegram] transient polling network error (auto-retrying): {msg.splitlines()[0]}") + return False + return True + + noise_filter = _PollingNoiseFilter() + for name in ("telegram.ext.Updater", "telegram.ext._updater", "telegram.ext"): + logging.getLogger(name).addFilter(noise_filter) + + def stop(self): + logger.info("[Telegram] stop() called") + self._stop_event.set() + if self._loop_thread and self._loop_thread.is_alive(): + try: + self._loop_thread.join(timeout=10) + except Exception: + pass + logger.info("[Telegram] stop() completed") + + # ------------------------------------------------------------------ + # Inbound: telegram update -> ChatMessage -> ChatChannel.produce + # ------------------------------------------------------------------ + + async def _on_cancel(self, update, _context): + """Fast-path: /cancel calls cancel_session directly without going through agent.""" + try: + from agent.protocol import get_cancel_registry + session_id = self._compute_session_id(update) + cancelled = get_cancel_registry().cancel_session(session_id) + text = "Current task cancelled." if cancelled else "No running task to cancel." + await update.effective_message.reply_text(text) + logger.info(f"[Telegram] /cancel session={session_id}, cancelled={cancelled}") + except Exception as e: + logger.error(f"[Telegram] /cancel error: {e}", exc_info=True) + try: + await update.effective_message.reply_text(f"⚠️ /cancel failed: {e}") + except Exception: + pass + + async def _on_command_passthrough(self, update, _context): + """All non-/cancel commands fall through to plain message handling.""" + await self._on_message(update, _context) + + async def _on_message(self, update, _context): + """Telegram update entry: parse message -> build ChatMessage -> produce().""" + try: + message = update.effective_message + chat = update.effective_chat + if not message or not chat: + return + + # Idempotent dedup + msg_uid = f"{chat.id}:{message.message_id}" + if self._received_msgs.get(msg_uid): + return + self._received_msgs[msg_uid] = True + + is_group = chat.type in ("group", "supergroup") + + # Debug log: helpful when group messages are silently dropped + if is_group: + logger.debug( + f"[Telegram] group update received: chat_id={chat.id}, " + f"text={(message.text or message.caption or '')[:40]!r}, " + f"reply_to_bot={bool(message.reply_to_message and message.reply_to_message.from_user and message.reply_to_message.from_user.username == self.bot_username)}" + ) + + # Group trigger gate (silently drop if not triggered) + if is_group and not self._should_reply_in_group(update): + logger.debug(f"[Telegram] group message not triggered (need @{self.bot_username} or reply), skip") + return + + # Parse message type + download media if needed. + # Media messages with caption return both the local path and the caption text. + ctype, content, caption = await self._parse_message(message) + if ctype is None: + logger.debug(f"[Telegram] unsupported message type, skip. msg={message}") + return + + # Strip @bot mention for group text/caption + if is_group and self.bot_username: + if ctype == ContextType.TEXT and content: + content = self._strip_at_mention(content) + if caption: + caption = self._strip_at_mention(caption) + + tg_msg = TelegramMessage( + update, + is_group=is_group, + bot_username=self.bot_username, + ctype=ctype, + content=content, + ) + tg_msg.is_at = is_group # If we got here in a group, the bot is mentioned/replied + + # File cache: standalone media goes into cache, the next text query attaches them + from channel.file_cache import get_file_cache + file_cache = get_file_cache() + session_id = self._compute_session_id(update) + + # Media + caption together: treat as a complete query and bypass the cache + if ctype in (ContextType.IMAGE, ContextType.FILE) and caption: + tag = "image" if ctype == ContextType.IMAGE else "file" + merged_text = f"{caption}\n[{tag}: {content}]" + tg_msg.ctype = ContextType.TEXT + tg_msg.content = merged_text + ctype = ContextType.TEXT + logger.info(f"[Telegram] Media+caption merged for session {session_id}") + # fallthrough to the TEXT branch below + + elif ctype == ContextType.IMAGE: + file_cache.add(session_id, content, file_type="image") + logger.info(f"[Telegram] Image cached for session {session_id}, waiting for query...") + return + elif ctype == ContextType.FILE: + file_cache.add(session_id, content, file_type="file") + logger.info(f"[Telegram] File cached for session {session_id}: {content}") + return + + if ctype == ContextType.TEXT: + cached_files = file_cache.get(session_id) + if cached_files: + refs = [] + for fi in cached_files: + ftype = fi["type"] + tag = ftype if ftype in ("image", "video") else "file" + refs.append(f"[{tag}: {fi['path']}]") + tg_msg.content = (tg_msg.content or "") + "\n" + "\n".join(refs) + file_cache.clear(session_id) + logger.info(f"[Telegram] Attached {len(cached_files)} cached file(s) to query") + + # Dispatch to cow main pipeline (reuses ChatChannel._compose_context routing) + context = self._compose_context( + tg_msg.ctype, + tg_msg.content, + isgroup=is_group, + msg=tg_msg, + ) + if context: + context["session_id"] = session_id + context["receiver"] = str(chat.id) + context["telegram_chat_id"] = chat.id + context["telegram_reply_to_msg_id"] = message.message_id if is_group else None + self.produce(context) + logger.debug(f"[Telegram] received: type={ctype}, content={str(tg_msg.content)[:80]}") + + except Exception as e: + logger.error(f"[Telegram] _on_message error: {e}", exc_info=True) + + async def _parse_message(self, message): + """Parse a telegram message and return (ctype, content, caption). + + - content is text for ContextType.TEXT, otherwise the local file path + - caption is the optional text accompanying a media message; empty for plain text + """ + caption = (message.caption or "").strip() + + if message.photo: + largest = message.photo[-1] + path = await self._download_file(largest.file_id, suffix=".jpg") + return (ContextType.IMAGE, path, caption) if path else (None, None, "") + + if message.voice or message.audio: + audio_obj = message.voice or message.audio + suffix = ".ogg" if message.voice else ( + "." + (audio_obj.mime_type.split("/")[-1] if getattr(audio_obj, "mime_type", "") else "mp3") + ) + path = await self._download_file(audio_obj.file_id, suffix=suffix) + return (ContextType.VOICE, path, caption) if path else (None, None, "") + + if message.video or message.video_note: + video_obj = message.video or message.video_note + path = await self._download_file(video_obj.file_id, suffix=".mp4") + return (ContextType.FILE, path, caption) if path else (None, None, "") + + if message.document: + doc = message.document + ext = "" + if doc.file_name and "." in doc.file_name: + ext = "." + doc.file_name.rsplit(".", 1)[-1] + path = await self._download_file(doc.file_id, suffix=ext, original_name=doc.file_name) + if not path: + return (None, None, "") + # Image-typed documents (user picked "send as file") are treated as images + mime = (doc.mime_type or "").lower() + if mime.startswith("image/"): + return (ContextType.IMAGE, path, caption) + return (ContextType.FILE, path, caption) + + if message.text: + return (ContextType.TEXT, message.text.strip(), "") + + return (None, None, "") + + async def _download_file(self, file_id: str, suffix: str = "", original_name: str = ""): + """Download via bot.get_file into the local tmp dir; return path or None on failure.""" + try: + f = await self._bot.get_file(file_id) + tmp_dir = TelegramMessage.get_tmp_dir() + base = original_name or f"{file_id}{suffix or ''}" + # Prefix with file_id to avoid name collisions / weird chars + safe_name = f"{file_id}_{base}" if original_name else base + local_path = os.path.join(tmp_dir, safe_name) + await f.download_to_drive(custom_path=local_path) + logger.debug(f"[Telegram] downloaded file_id={file_id} -> {local_path}") + return local_path + except Exception as e: + logger.error(f"[Telegram] download_file failed (file_id={file_id}): {e}") + return None + + # ------------------------------------------------------------------ + # Group trigger logic + # ------------------------------------------------------------------ + + def _should_reply_in_group(self, update) -> bool: + """Decide whether to reply to a group message based on configuration.""" + mode = conf().get("telegram_group_trigger", "mention_or_reply") + if mode == "all": + return True + + message = update.effective_message + if not message: + return False + + # 1) Mentioned + if self.bot_username and self._is_mentioned(message, self.bot_username): + return True + + # 2) Reply to a bot message + if mode == "mention_or_reply": + reply = message.reply_to_message + if reply and reply.from_user and reply.from_user.username == self.bot_username: + return True + + return False + + @staticmethod + def _is_mentioned(message, bot_username: str) -> bool: + """Check whether entities/caption_entities contain a @mention of the bot.""" + bot_at = "@" + bot_username.lower() + text = (message.text or message.caption or "").lower() + if bot_at in text: + return True + # Also check entities strictly to support text_mention (no-username @) + for ent in (message.entities or []) + (message.caption_entities or []): + if ent.type == "mention": + src = message.text or message.caption or "" + if src[ent.offset: ent.offset + ent.length].lower() == bot_at: + return True + return False + + def _strip_at_mention(self, content: str) -> str: + """Strip @bot_username from group text (case-insensitive).""" + if not content or not self.bot_username: + return content + pattern = re.compile(r"@" + re.escape(self.bot_username), re.IGNORECASE) + return pattern.sub("", content).strip() + + @staticmethod + def _compute_session_id(update) -> str: + chat = update.effective_chat + user = update.effective_user + is_group = chat.type in ("group", "supergroup") + if is_group: + if conf().get("group_shared_session", True): + return f"tg_group_{chat.id}" + return f"tg_group_{chat.id}_{user.id}" + return f"tg_user_{user.id}" + + # ------------------------------------------------------------------ + # Override _compose_context: skip the parent's group whitelist/at checks + # (already handled in _on_message via _should_reply_in_group). Same idea + # as the feishu channel. + # ------------------------------------------------------------------ + + def _compose_context(self, ctype: ContextType, content, **kwargs): + context = Context(ctype, content) + context.kwargs = kwargs + if "channel_type" not in context: + context["channel_type"] = self.channel_type + if "origin_ctype" not in context: + context["origin_ctype"] = ctype + + cmsg = context["msg"] + if cmsg.is_group: + if conf().get("group_shared_session", True): + context["session_id"] = cmsg.other_user_id + else: + context["session_id"] = f"{cmsg.from_user_id}:{cmsg.other_user_id}" + else: + context["session_id"] = cmsg.from_user_id + context["receiver"] = cmsg.other_user_id + + if ctype == ContextType.TEXT: + img_match_prefix = check_prefix(content, conf().get("image_create_prefix")) + if img_match_prefix: + content = content.replace(img_match_prefix, "", 1) + context.type = ContextType.IMAGE_CREATE + else: + context.type = ContextType.TEXT + context.content = (content or "").strip() + if "desire_rtype" not in context and conf().get("always_reply_voice"): + context["desire_rtype"] = ReplyType.VOICE + elif ctype == ContextType.VOICE: + if "desire_rtype" not in context and ( + conf().get("voice_reply_voice") or conf().get("always_reply_voice") + ): + context["desire_rtype"] = ReplyType.VOICE + + return context + + # ------------------------------------------------------------------ + # Outbound: ChatChannel.send -> Telegram API + # ------------------------------------------------------------------ + + def send(self, reply: Reply, context: Context): + """Called from cow's sync main thread; we marshal the coroutine onto the loop thread.""" + if self._loop is None or self._bot is None: + logger.warning("[Telegram] bot not ready, drop reply") + return + + chat_id = context.get("telegram_chat_id") + reply_to = context.get("telegram_reply_to_msg_id") + if chat_id is None: + logger.warning("[Telegram] no telegram_chat_id in context, drop reply") + return + + coro = self._async_send(reply, chat_id, reply_to) + try: + future = asyncio.run_coroutine_threadsafe(coro, self._loop) + # Media uploads through a proxy can be slow; let PTB's own timeouts win + future.result(timeout=180) + except Exception as e: + logger.error(f"[Telegram] send failed: {e}") + + # Number of retries for transient network errors (proxy hiccups etc.) + _SEND_RETRIES = 2 + _SEND_RETRY_BACKOFF = 2.0 # seconds + + async def _send_with_retry(self, send_fn, *, label: str): + """Run a single Telegram API call with retries for transient network errors.""" + from telegram.error import NetworkError, TimedOut + last_err = None + for attempt in range(self._SEND_RETRIES + 1): + try: + return await send_fn() + except (NetworkError, TimedOut) as e: + last_err = e + if attempt >= self._SEND_RETRIES: + break + wait = self._SEND_RETRY_BACKOFF * (attempt + 1) + logger.warning( + f"[Telegram] {label} transient error (attempt {attempt + 1}/" + f"{self._SEND_RETRIES + 1}): {e}; retry in {wait}s" + ) + await asyncio.sleep(wait) + raise last_err + + async def _async_send(self, reply: Reply, chat_id, reply_to_msg_id): + try: + rtype = reply.type + content = reply.content + + if rtype == ReplyType.TEXT or rtype == ReplyType.INFO or rtype == ReplyType.ERROR: + # Telegram caps a single text message at 4096 chars; auto-split + text = str(content) if content is not None else "" + if not text: + return + for chunk in _split_text(text, 4000): + await self._send_with_retry( + lambda c=chunk: self._bot.send_message( + chat_id=chat_id, + text=c, + reply_to_message_id=reply_to_msg_id, + # Avoid failing the whole send if reply_to was deleted + allow_sending_without_reply=True, + ), + label="send_message", + ) + + elif rtype == ReplyType.IMAGE: + # Already a local BytesIO; send it directly + content.seek(0) + await self._send_with_retry( + lambda: self._bot.send_photo( + chat_id=chat_id, + photo=content, + reply_to_message_id=reply_to_msg_id, + allow_sending_without_reply=True, + ), + label="send_photo", + ) + + elif rtype == ReplyType.IMAGE_URL: + url = str(content) + if url.startswith("file://"): + local = url[7:] + # Open inside the lambda so each retry gets a fresh stream + async def _send_local_photo(): + with open(local, "rb") as f: + return await self._bot.send_photo( + chat_id=chat_id, photo=f, + reply_to_message_id=reply_to_msg_id, + allow_sending_without_reply=True, + ) + await self._send_with_retry(_send_local_photo, label="send_photo(file)") + else: + await self._send_with_retry( + lambda: self._bot.send_photo( + chat_id=chat_id, photo=url, + reply_to_message_id=reply_to_msg_id, + allow_sending_without_reply=True, + ), + label="send_photo(url)", + ) + + elif rtype == ReplyType.VOICE: + local = content[7:] if isinstance(content, str) and content.startswith("file://") else content + async def _send_voice(): + with open(local, "rb") as f: + return await self._bot.send_voice( + chat_id=chat_id, voice=f, + reply_to_message_id=reply_to_msg_id, + allow_sending_without_reply=True, + ) + await self._send_with_retry(_send_voice, label="send_voice") + + elif rtype == ReplyType.FILE: + # Videos go through send_video, everything else through send_document + local = content[7:] if isinstance(content, str) and content.startswith("file://") else content + # File replies may carry an accompanying text caption + caption = getattr(reply, "text_content", None) or None + is_video = isinstance(local, str) and local.lower().endswith( + (".mp4", ".mov", ".avi", ".mkv", ".webm") + ) + + async def _send_file(): + with open(local, "rb") as f: + if is_video: + return await self._bot.send_video( + chat_id=chat_id, video=f, caption=caption, + reply_to_message_id=reply_to_msg_id, + allow_sending_without_reply=True, + ) + return await self._bot.send_document( + chat_id=chat_id, document=f, caption=caption, + reply_to_message_id=reply_to_msg_id, + allow_sending_without_reply=True, + ) + await self._send_with_retry(_send_file, label="send_video" if is_video else "send_document") + + else: + # Fallback: send as plain text + await self._send_with_retry( + lambda: self._bot.send_message( + chat_id=chat_id, text=str(content), + reply_to_message_id=reply_to_msg_id, + allow_sending_without_reply=True, + ), + label="send_message(fallback)", + ) + + logger.info(f"[Telegram] sent reply (type={rtype}, chat_id={chat_id})") + + except Exception as e: + logger.error(f"[Telegram] _async_send error: {e}", exc_info=True) + + +def _split_text(text: str, limit: int): + """Split long text preferring line breaks to keep markdown structure intact.""" + if len(text) <= limit: + yield text + return + buf = [] + size = 0 + for line in text.splitlines(keepends=True): + if size + len(line) > limit and buf: + yield "".join(buf) + buf, size = [], 0 + # Hard-split single lines that exceed the limit + while len(line) > limit: + yield line[:limit] + line = line[limit:] + buf.append(line) + size += len(line) + if buf: + yield "".join(buf) diff --git a/channel/telegram/telegram_message.py b/channel/telegram/telegram_message.py new file mode 100644 index 00000000..c97c6059 --- /dev/null +++ b/channel/telegram/telegram_message.py @@ -0,0 +1,62 @@ +""" +Telegram message adapter. + +Convert a python-telegram-bot Update into cow's unified ChatMessage. +File downloads are NOT performed here; the channel layer triggers +bot.get_file() on demand because it requires the async event loop. +""" +import os + +from bridge.context import ContextType +from channel.chat_message import ChatMessage +from common.utils import expand_path +from config import conf + + +class TelegramMessage(ChatMessage): + """Wrap a Telegram Update into the unified ChatMessage.""" + + def __init__(self, update, is_group: bool = False, bot_username: str = "", + ctype: ContextType = ContextType.TEXT, content: str = ""): + super().__init__(update) + message = update.effective_message + chat = update.effective_chat + user = update.effective_user + + # Basic fields + self.msg_id = str(message.message_id) if message else "" + self.create_time = int(message.date.timestamp()) if message and message.date else 0 + self.ctype = ctype + self.content = content + + # Sender / chat info + from_user_id = str(user.id) if user else "unknown" + from_user_nick = ( + user.full_name if user and user.full_name else (user.username if user else "unknown") + ) + self.from_user_id = from_user_id + self.from_user_nickname = from_user_nick or from_user_id + self.to_user_id = bot_username or "telegram_bot" + self.to_user_nickname = bot_username or "telegram_bot" + + self.is_group = is_group + if is_group: + # Group: other_user_id = group_id, actual_user_id = sender id + self.other_user_id = str(chat.id) + self.other_user_nickname = chat.title or str(chat.id) + self.actual_user_id = from_user_id + self.actual_user_nickname = self.from_user_nickname + else: + self.other_user_id = from_user_id + self.other_user_nickname = self.from_user_nickname + + # Whether the bot was triggered by @-mention or reply (set by channel layer) + self.is_at = False + + @staticmethod + def get_tmp_dir() -> str: + """Local download directory, aligned with other channels (agent_workspace/tmp).""" + workspace_root = expand_path(conf().get("agent_workspace", "~/cow")) + tmp_dir = os.path.join(workspace_root, "tmp") + os.makedirs(tmp_dir, exist_ok=True) + return tmp_dir diff --git a/channel/web/chat.html b/channel/web/chat.html index 56ce808f..d90adb15 100644 --- a/channel/web/chat.html +++ b/channel/web/chat.html @@ -137,6 +137,11 @@ 配置 + + + 模型 + @@ -417,21 +422,30 @@ - +
+ + +
@@ -460,6 +474,11 @@

模型配置

+
+ 高级配置 + +
@@ -850,6 +869,41 @@
+ + + +
+ + +
+
+
+
+

模型管理

+

统一管理对话、视觉、语音、向量、图像、搜索能力

+
+ +
+
+ Loading... +
+ +
+
+
+ @@ -959,7 +1013,7 @@ - `; @@ -1480,13 +1993,33 @@ function startSSE(requestId, loadingEl, timestamp, titleInfo) { stepsEl.appendChild(wrap); scrollChatToBottom(); - } else if (item.type === 'done') { - done = true; - es.close(); - delete activeStreams[requestId]; + } else if (item.type === 'cancelled') { + // Agent acknowledged the stop; mark the bubble. A trailing + // "done" still arrives with the partial answer. + ensureBotEl(); + if (currentReasoningEl) { + finalizeThinking(currentReasoningEl, reasoningStartTime, reasoningText); + currentReasoningEl = null; + reasoningText = ''; + } + if (!botEl.querySelector('.agent-cancelled-tag')) { + const tag = document.createElement('div'); + tag.className = 'agent-cancelled-tag text-xs text-amber-600 dark:text-amber-400 mt-1'; + tag.textContent = (currentLang === 'zh') ? '已中止' : 'Cancelled'; + stepsEl.appendChild(tag); + } + resetSendBtnSendMode(); - // item.content may be empty when "done" is only a stream-close signal after media. - const finalText = item.content || accumulatedText; + } else if (item.type === 'done') { + // Don't close the stream yet: the backend keeps it open + // for a short tail to deliver async attachments such as + // TTS audio (`voice_attach`). It will close the stream on + // its own via onerror once the tail expires. + done = true; + resetSendBtnSendMode(); + + const finalTextRaw = item.content || accumulatedText; + const finalText = localizeCancelMarker(finalTextRaw); if (!botEl && finalText) { if (loadingEl) { loadingEl.remove(); loadingEl = null; } @@ -1494,11 +2027,12 @@ function startSSE(requestId, loadingEl, timestamp, titleInfo) { } else if (botEl) { contentEl.classList.remove('sse-streaming'); if (finalText) contentEl.innerHTML = renderMarkdown(finalText); - contentEl.dataset.rawMd = finalText || ''; + contentEl.dataset.rawMd = finalTextRaw || ''; const copyBtn = botEl.querySelector('.copy-msg-btn'); if (copyBtn && finalText) copyBtn.style.display = ''; applyHighlighting(botEl); } + renderBotSpeakerButton(botEl, finalText); scrollChatToBottom(); if (titleInfo) { @@ -1508,12 +2042,22 @@ function startSSE(requestId, loadingEl, timestamp, titleInfo) { loadSessionList(); } + } else if (item.type === 'voice_attach') { + // TTS finished — attach a playable audio element to the + // current bot bubble. The stream closes right after. + if (botEl && item.url) { + attachAudioToBotBubble(botEl, item.url, { autoplay: true }); + } + es.close(); + delete activeStreams[requestId]; + } else if (item.type === 'error') { done = true; es.close(); delete activeStreams[requestId]; if (loadingEl) { loadingEl.remove(); loadingEl = null; } addBotMessage(t('error_send'), new Date()); + resetSendBtnSendMode(); } }; @@ -1521,7 +2065,10 @@ function startSSE(requestId, loadingEl, timestamp, titleInfo) { es.close(); delete activeStreams[requestId]; - if (done) return; + if (done) { + // Normal close after the post-done tail expired; nothing to do. + return; + } if (currentReasoningEl) { finalizeThinking(currentReasoningEl, reasoningStartTime, reasoningText); @@ -1547,6 +2094,7 @@ function startSSE(requestId, loadingEl, timestamp, titleInfo) { applyHighlighting(botEl); bindChatKnowledgeLinks(botEl); } + resetSendBtnSendMode(); }; } @@ -1785,13 +2333,23 @@ function _renderSentFileFromToolResult(step) { ` ${escapeHtml(fileName)}`; } +// Cosmetic translator for cancel markers persisted in history. +// History keeps the English canonical form for the LLM; only display is localized. +function localizeCancelMarker(text) { + if (!text) return text; + if (currentLang !== 'zh') return text; + return text + .replace(/_\(Cancelled by user\)_/g, '_(用户已中止)_') + .replace(/_\(Cancelled\)_/g, '_(已中止)_'); +} + function createBotMessageEl(content, timestamp, requestId, msg) { const el = document.createElement('div'); el.className = 'flex gap-3 px-4 sm:px-6 py-3'; if (requestId) el.dataset.requestId = requestId; let stepsHtml = ''; - let displayContent = content; + let displayContent = localizeCancelMarker(content); if (msg && msg.steps && msg.steps.length > 0) { // New format: ordered steps with interleaved content @@ -1812,21 +2370,174 @@ function createBotMessageEl(content, timestamp, requestId, msg) {
${stepsHtml ? `
${stepsHtml}
` : ''}
${renderMarkdown(displayContent)}
+
${formatTime(timestamp)} +
`; el.querySelector('.answer-content').dataset.rawMd = displayContent; + // Existing TTS attachment (history replay): mount the player up-front. + const existingAudio = msg && msg.extras && msg.extras.audio && msg.extras.audio.url; + if (existingAudio) { + attachAudioToBotBubble(el, existingAudio, { autoplay: false }); + } + renderBotSpeakerButton(el, displayContent); applyHighlighting(el); bindChatKnowledgeLinks(el); return el; } +// Append (or replace) a small audio player inside a bot bubble's +// dedicated `.bot-audio-slot`. Used by both live TTS pushes and history +// replay. Silent failures: never throws. +function attachAudioToBotBubble(botEl, audioUrl, opts) { + try { + if (!botEl || !audioUrl) return; + const slot = botEl.querySelector('.bot-audio-slot'); + if (!slot) return; + slot.innerHTML = ''; + slot.style.marginTop = '6px'; + const pill = renderVoicePill(audioUrl, { autoplay: !!(opts && opts.autoplay) }); + slot.appendChild(pill); + const speakBtn = botEl.querySelector('.speak-msg-btn'); + if (speakBtn) speakBtn.style.display = 'none'; + } catch (_) { /* silent */ } +} + +// Build a compact play/pause + progress + duration pill that wraps a +// hidden