mirror of
https://github.com/zhayujie/chatgpt-on-wechat.git
synced 2026-06-02 18:17:11 +08:00
Compare commits
333 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ce4c0a0aa4 | ||
|
|
64511593c4 | ||
|
|
b0e00dfceb | ||
|
|
fc465b463d | ||
|
|
68ce2e5232 | ||
|
|
81e8bb62ae | ||
|
|
2c13e1b923 | ||
|
|
a0748c2e3b | ||
|
|
40599bb751 | ||
|
|
f3c64ceea7 | ||
|
|
15c60de709 | ||
|
|
6dd316547f | ||
|
|
54c7676a44 | ||
|
|
d25b8966ce | ||
|
|
14a119c48c | ||
|
|
c82515a927 | ||
|
|
26e630c2dd | ||
|
|
13370d2056 | ||
|
|
35282db9e0 | ||
|
|
426fb88ce7 | ||
|
|
2384bd0e10 | ||
|
|
ba3f66d3d1 | ||
|
|
7293a0f670 | ||
|
|
9e86d46267 | ||
|
|
848430f062 | ||
|
|
abd21335c4 | ||
|
|
8fa95f058a | ||
|
|
d4e5ecd497 | ||
|
|
3830f76729 | ||
|
|
83f778fec9 | ||
|
|
cabd24605f | ||
|
|
ae20ba1148 | ||
|
|
3a50b64977 | ||
|
|
8692e74536 | ||
|
|
1c18bd9889 | ||
|
|
60e9d98d0a | ||
|
|
83f6625e0c | ||
|
|
acc09543b7 | ||
|
|
94d8c7e366 | ||
|
|
ea1a0c8b3d | ||
|
|
7bc88c17e4 | ||
|
|
33cf1bc4c3 | ||
|
|
9402e63fe1 | ||
|
|
90e4d494b2 | ||
|
|
da97e948ca | ||
|
|
89a07e8e74 | ||
|
|
3f3d0381e5 | ||
|
|
3649499dba | ||
|
|
a989d088fd | ||
|
|
f79a915136 | ||
|
|
12e8c3d449 | ||
|
|
4f7064575e | ||
|
|
070df826f1 | ||
|
|
fbe48a4b4e | ||
|
|
4dd497fb6d | ||
|
|
907882c0a7 | ||
|
|
d36d5aee3f | ||
|
|
c6824e5f5e | ||
|
|
199c21eede | ||
|
|
5162da5654 | ||
|
|
a1d82f6193 | ||
|
|
ea78e3d0c6 | ||
|
|
3497f00cb4 | ||
|
|
5355d45031 | ||
|
|
26693acc3f | ||
|
|
76e9fef3b2 | ||
|
|
c34308cbd4 | ||
|
|
5a10476010 | ||
|
|
46e80dceec | ||
|
|
90d1835353 | ||
|
|
845fadd0aa | ||
|
|
5748ded52c | ||
|
|
6a737fb734 | ||
|
|
3cd92ccda3 | ||
|
|
54e81aba11 | ||
|
|
d86cb4ded6 | ||
|
|
4d5375f6d6 | ||
|
|
424557fedb | ||
|
|
89251e603f | ||
|
|
a653ed07eb | ||
|
|
ad86deb014 | ||
|
|
9525dc7584 | ||
|
|
cd31dd27fd | ||
|
|
360e3670eb | ||
|
|
8dabe3b4c8 | ||
|
|
443e0c2806 | ||
|
|
9cc173cc4d | ||
|
|
b5f33e5ecd | ||
|
|
40dfc6860f | ||
|
|
1c02a04423 | ||
|
|
de0e45070c | ||
|
|
c169cc7d74 | ||
|
|
cd62ad76f6 | ||
|
|
dd25b0fb5b | ||
|
|
a38b22a6a2 | ||
|
|
830b8f2971 | ||
|
|
b058af122c | ||
|
|
174ee0cafc | ||
|
|
1c336380c0 | ||
|
|
3068880413 | ||
|
|
be596681e5 | ||
|
|
66b71c50e9 | ||
|
|
8744810b25 | ||
|
|
7f94d37c2e | ||
|
|
6d9b7baeb4 | ||
|
|
4470d4c352 | ||
|
|
d2a462a279 | ||
|
|
14ff2a15e7 | ||
|
|
6d1369900e | ||
|
|
1f17ebe69e | ||
|
|
1ae2918064 | ||
|
|
b6571e5cad | ||
|
|
7549d48cf1 | ||
|
|
00353dd0cb | ||
|
|
afd947195d | ||
|
|
e57ef37167 | ||
|
|
ef33a93654 | ||
|
|
61732aecfc | ||
|
|
6764c05c3f | ||
|
|
fa149cf4aa | ||
|
|
e4f9697d06 | ||
|
|
da061450e5 | ||
|
|
d09ae49287 | ||
|
|
511ee0bbaf | ||
|
|
3cb5a0fbd6 | ||
|
|
e06925ab85 | ||
|
|
184634e4e7 | ||
|
|
843c2d02cc | ||
|
|
8ea2455766 | ||
|
|
9dc9987d56 | ||
|
|
3458621147 | ||
|
|
079df5a47c | ||
|
|
ddb07c65a1 | ||
|
|
9b21cd222b | ||
|
|
90f736843f | ||
|
|
13c020eb61 | ||
|
|
dbc06dbe95 | ||
|
|
23d097bc1c | ||
|
|
db85b9808e | ||
|
|
df5bae37bc | ||
|
|
acc23b6051 | ||
|
|
61f2741afc | ||
|
|
4dd7ea886a | ||
|
|
1e8959fbcf | ||
|
|
48729678cf | ||
|
|
0684becaa7 | ||
|
|
db16bdf8cb | ||
|
|
f890318ed9 | ||
|
|
158510cbbe | ||
|
|
ce90cf7aa8 | ||
|
|
a3a3d006eb | ||
|
|
8fd029a4a1 | ||
|
|
2e1b52c1e5 | ||
|
|
3eb8348708 | ||
|
|
393f0c007c | ||
|
|
294e380288 | ||
|
|
4c1c42efac | ||
|
|
c062ca8c66 | ||
|
|
76dcb25103 | ||
|
|
c5b4f236db | ||
|
|
0974c940a8 | ||
|
|
cffa20d37e | ||
|
|
ef009edd29 | ||
|
|
3ca52b118d | ||
|
|
13f5fde4fb | ||
|
|
f512b55ec2 | ||
|
|
22b8ca0095 | ||
|
|
baf66a103d | ||
|
|
45faa9c1ff | ||
|
|
304381a88d | ||
|
|
fc9f54dbc8 | ||
|
|
7199dc187f | ||
|
|
e9ae066d53 | ||
|
|
d71ae406ff | ||
|
|
f3216904b3 | ||
|
|
5958b69ec9 | ||
|
|
7d4e2cb39a | ||
|
|
a483ec0cea | ||
|
|
c1421e0874 | ||
|
|
ce89869c3c | ||
|
|
b8b57e34ff | ||
|
|
bc7f627253 | ||
|
|
652156e398 | ||
|
|
9febb071c6 | ||
|
|
7d0e1568ac | ||
|
|
b4e711f411 | ||
|
|
1b5be1b981 | ||
|
|
49d8707c58 | ||
|
|
9192f6f7f7 | ||
|
|
05022e3745 | ||
|
|
5356e9ddeb | ||
|
|
52acf76e2c | ||
|
|
40cdbd3b45 | ||
|
|
5487c0befe | ||
|
|
8bb16c48c0 | ||
|
|
c6384363f9 | ||
|
|
8993e8ad3e | ||
|
|
289989d9f7 | ||
|
|
dc2ae0e6f1 | ||
|
|
9c966c152d | ||
|
|
4efae41048 | ||
|
|
b8437032e9 | ||
|
|
2d339ca81b | ||
|
|
d53abc9696 | ||
|
|
446c886d38 | ||
|
|
30c6d9b5ae | ||
|
|
5e42996b36 | ||
|
|
ceca7b85bf | ||
|
|
a4d54f58c8 | ||
|
|
005a0e1bad | ||
|
|
46d97fd57d | ||
|
|
72a26b6353 | ||
|
|
89a4033fbf | ||
|
|
39a5dc64bd | ||
|
|
d4bdd9b1b7 | ||
|
|
2f5ba87280 | ||
|
|
8b45d6c750 | ||
|
|
4ecd4df2d4 | ||
|
|
a42f31fe52 | ||
|
|
d4480b695e | ||
|
|
c4b5f7fbae | ||
|
|
ba915f2cc0 | ||
|
|
4b91140f31 | ||
|
|
9879878dd0 | ||
|
|
d78105d57c | ||
|
|
153c9e3565 | ||
|
|
c11623596d | ||
|
|
e791a77f77 | ||
|
|
b641bffb2c | ||
|
|
ee0c47ac1e | ||
|
|
eba90e9343 | ||
|
|
d8374d0fa5 | ||
|
|
fa61744c6d | ||
|
|
4fec55cc01 | ||
|
|
1767413712 | ||
|
|
734c8fa84f | ||
|
|
9a8d422554 | ||
|
|
b21e945c76 | ||
|
|
a02bf1ea09 | ||
|
|
eda82bac92 | ||
|
|
e8d4f7dc4f | ||
|
|
c4a93b7789 | ||
|
|
c3f9925097 | ||
|
|
2a0cf7511a | ||
|
|
d0a70d3339 | ||
|
|
f37e4675dd | ||
|
|
4e32f67eeb | ||
|
|
36d54cab52 | ||
|
|
9d8df10dcf | ||
|
|
45ea88e070 | ||
|
|
d5d0b947f5 | ||
|
|
f775f1f11e | ||
|
|
f1e888f3de | ||
|
|
71c8436e90 | ||
|
|
08c69f5e9b | ||
|
|
a50fafaca2 | ||
|
|
3c6781d240 | ||
|
|
3b8b5625f8 | ||
|
|
6be2034110 | ||
|
|
924dc79f00 | ||
|
|
ccb9030d3c | ||
|
|
8623287ac1 | ||
|
|
022c13f3a4 | ||
|
|
0687916e7f | ||
|
|
bb868b83ba | ||
|
|
24298130b9 | ||
|
|
6e5ee92ebd | ||
|
|
5b91fe04aa | ||
|
|
1623deb3ee | ||
|
|
4a16e05b7a | ||
|
|
f1c04bc60d | ||
|
|
84c6f31c76 | ||
|
|
9d528190bf | ||
|
|
0f23b209ad | ||
|
|
63d9325900 | ||
|
|
f342097f81 | ||
|
|
b4806c4366 | ||
|
|
ff37d8a577 | ||
|
|
a773eb7893 | ||
|
|
7c67513d24 | ||
|
|
6ed85029c5 | ||
|
|
e9c57ddf4d | ||
|
|
a33ce97ed9 | ||
|
|
b788a3dd4e | ||
|
|
fccfa92d7e | ||
|
|
8705bf0a70 | ||
|
|
9318138af7 | ||
|
|
269fa7d2d5 | ||
|
|
e99837a8b9 | ||
|
|
553861a2c4 | ||
|
|
628a85d1be | ||
|
|
2cb54514a4 | ||
|
|
6db22827f2 | ||
|
|
4cc6d5426b | ||
|
|
7d258b5202 | ||
|
|
c8d19ee0bc | ||
|
|
d891312032 | ||
|
|
5edbf4ce32 | ||
|
|
3ddbdd713d | ||
|
|
9ba107b511 | ||
|
|
c9adddb76a | ||
|
|
f0a12d5ff5 | ||
|
|
7cce224499 | ||
|
|
97397ca585 | ||
|
|
f2fbc602a8 | ||
|
|
925d728a86 | ||
|
|
f5f229871b | ||
|
|
9917552b4b | ||
|
|
adca89b973 | ||
|
|
29bfbecdc9 | ||
|
|
1a7a8c98d9 | ||
|
|
cddb38ac3d | ||
|
|
394853c0fb | ||
|
|
c0702c8b36 | ||
|
|
d610608391 | ||
|
|
9082eec91d | ||
|
|
f1a1413b5f | ||
|
|
c1e7f9af9b | ||
|
|
1c71c4e38b | ||
|
|
5e3eccb3f6 | ||
|
|
e1dc037eb9 | ||
|
|
97e9b4c801 | ||
|
|
52d7cad735 | ||
|
|
c0b1d270ba | ||
|
|
e59a2892e4 | ||
|
|
5fa0376a49 | ||
|
|
05a33042c8 | ||
|
|
ce58f23cbc | ||
|
|
b6fc9fa370 | ||
|
|
00ae38faae | ||
|
|
ab28ee58ab | ||
|
|
48db538a2e | ||
|
|
46945942e1 |
2
.github/ISSUE_TEMPLATE/1.bug.yml
vendored
2
.github/ISSUE_TEMPLATE/1.bug.yml
vendored
@@ -79,8 +79,6 @@ body:
|
||||
description: |
|
||||
请确保你正确配置了该`channel`所需的配置项,所有可选的配置项都写在了[该文件中](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/config.py),请将所需配置项填写在根目录下的`config.json`文件中。
|
||||
options:
|
||||
- wx(个人微信, itchat)
|
||||
- wxy(个人微信, wechaty)
|
||||
- wechatmp(公众号, 订阅号)
|
||||
- wechatmp_service(公众号, 服务号)
|
||||
- terminal
|
||||
|
||||
11
.github/workflows/deploy-image-arm.yml
vendored
11
.github/workflows/deploy-image-arm.yml
vendored
@@ -19,7 +19,7 @@ env:
|
||||
|
||||
jobs:
|
||||
build-and-push-image:
|
||||
if: github.repository == 'zhayujie/chatgpt-on-wechat'
|
||||
if: github.repository == 'zhayujie/CowAgent'
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: read
|
||||
@@ -51,7 +51,12 @@ jobs:
|
||||
uses: docker/metadata-action@v4
|
||||
with:
|
||||
images: |
|
||||
${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
|
||||
${{ env.REGISTRY }}/zhayujie/chatgpt-on-wechat
|
||||
${{ env.REGISTRY }}/zhayujie/cowagent
|
||||
tags: |
|
||||
type=raw,value=latest-arm64,enable={{is_default_branch}}
|
||||
type=ref,event=branch,suffix=-arm64
|
||||
type=ref,event=tag,suffix=-arm64
|
||||
|
||||
- name: Build and push Docker image
|
||||
uses: docker/build-push-action@v3
|
||||
@@ -60,7 +65,7 @@ jobs:
|
||||
push: true
|
||||
file: ./docker/Dockerfile.latest
|
||||
platforms: linux/arm64
|
||||
tags: ${{ steps.meta.outputs.tags }}-arm64
|
||||
tags: ${{ steps.meta.outputs.tags }}
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
|
||||
- uses: actions/delete-package-versions@v4
|
||||
|
||||
13
.github/workflows/deploy-image.yml
vendored
13
.github/workflows/deploy-image.yml
vendored
@@ -16,10 +16,11 @@ on:
|
||||
env:
|
||||
REGISTRY: ghcr.io
|
||||
IMAGE_NAME: ${{ github.repository }}
|
||||
DOCKERHUB_IMAGE: zhayujie/chatgpt-on-wechat
|
||||
|
||||
jobs:
|
||||
build-and-push-image:
|
||||
if: github.repository == 'zhayujie/chatgpt-on-wechat'
|
||||
if: github.repository == 'zhayujie/CowAgent'
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: read
|
||||
@@ -47,8 +48,14 @@ jobs:
|
||||
uses: docker/metadata-action@v4
|
||||
with:
|
||||
images: |
|
||||
${{ env.IMAGE_NAME }}
|
||||
${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
|
||||
zhayujie/chatgpt-on-wechat
|
||||
zhayujie/cowagent
|
||||
${{ env.REGISTRY }}/zhayujie/chatgpt-on-wechat
|
||||
${{ env.REGISTRY }}/zhayujie/cowagent
|
||||
tags: |
|
||||
type=raw,value=latest,enable={{is_default_branch}}
|
||||
type=ref,event=branch
|
||||
type=ref,event=tag
|
||||
|
||||
- name: Build and push Docker image
|
||||
uses: docker/build-push-action@v3
|
||||
|
||||
12
.gitignore
vendored
12
.gitignore
vendored
@@ -3,16 +3,15 @@
|
||||
.vscode
|
||||
.venv
|
||||
.vs
|
||||
.wechaty/
|
||||
__pycache__/
|
||||
venv*
|
||||
*.pyc
|
||||
python
|
||||
config.json
|
||||
QR.png
|
||||
nohup.out
|
||||
tmp
|
||||
plugins.json
|
||||
itchat.pkl
|
||||
*.log
|
||||
logs/
|
||||
workspace
|
||||
@@ -34,7 +33,16 @@ plugins/banwords/lib/__pycache__
|
||||
!plugins/keyword
|
||||
!plugins/linkai
|
||||
!plugins/agent
|
||||
!plugins/cow_cli
|
||||
client_config.json
|
||||
ref/
|
||||
**/.dev.vars
|
||||
.cursor/
|
||||
local/
|
||||
node_modules/
|
||||
|
||||
# cow cli
|
||||
dist/
|
||||
build/
|
||||
*.egg-info/
|
||||
.cow.pid
|
||||
|
||||
593
README.md
593
README.md
@@ -1,37 +1,50 @@
|
||||
<p align="center"><img src= "https://github.com/user-attachments/assets/eca9a9ec-8534-4615-9e0f-96c5ac1d10a3" alt="Chatgpt-on-Wechat" width="550" /></p>
|
||||
<p align="center"><img src= "https://github.com/user-attachments/assets/eca9a9ec-8534-4615-9e0f-96c5ac1d10a3" alt="CowAgent" width="550" /></p>
|
||||
|
||||
<p align="center">
|
||||
<a href="https://github.com/zhayujie/chatgpt-on-wechat/releases/latest"><img src="https://img.shields.io/github/v/release/zhayujie/chatgpt-on-wechat" alt="Latest release"></a>
|
||||
<a href="https://github.com/zhayujie/chatgpt-on-wechat/blob/master/LICENSE"><img src="https://img.shields.io/github/license/zhayujie/chatgpt-on-wechat" alt="License: MIT"></a>
|
||||
<a href="https://github.com/zhayujie/chatgpt-on-wechat"><img src="https://img.shields.io/github/stars/zhayujie/chatgpt-on-wechat?style=flat-square" alt="Stars"></a> <br/>
|
||||
<a href="https://github.com/zhayujie/CowAgent/releases/latest"><img src="https://img.shields.io/github/v/release/zhayujie/CowAgent" alt="Latest release"></a>
|
||||
<a href="https://github.com/zhayujie/CowAgent/blob/master/LICENSE"><img src="https://img.shields.io/github/license/zhayujie/CowAgent" alt="License: MIT"></a>
|
||||
<a href="https://github.com/zhayujie/CowAgent"><img src="https://img.shields.io/github/stars/zhayujie/CowAgent?style=flat-square" alt="Stars"></a> <br/>
|
||||
[中文] | [<a href="docs/en/README.md">English</a>] | [<a href="docs/ja/README.md">日本語</a>]
|
||||
</p>
|
||||
|
||||
**CowAgent** 是基于大模型的超级AI助理,能够主动思考和任务规划、操作计算机和外部资源、创造和执行Skills、拥有长期记忆并不断成长。CowAgent 支持灵活切换多种模型,能处理文本、语音、图片、文件等多模态消息,可接入网页、飞书、钉钉、企业微信应用、微信公众号中使用,7*24小时运行于你的个人电脑或服务器中。
|
||||
**CowAgent** 是基于大模型的超级 AI 助理,能够主动思考和任务规划、操作计算机和外部资源、创造和执行 Skills、拥有长期记忆和知识库并不断成长,比 OpenClaw 更轻量和便捷。CowAgent 支持灵活切换多种模型,能处理文本、语音、图片、文件等多模态消息,可接入微信、飞书、钉钉、企微智能机器人、QQ、企微自建应用、微信公众号、网页中使用,7*24小时运行于你的个人电脑或服务器中。
|
||||
|
||||
<p align="center">
|
||||
<a href="https://cowagent.ai/">🌐 官网</a> ·
|
||||
<a href="https://docs.cowagent.ai/">📖 文档中心</a> ·
|
||||
<a href="https://docs.cowagent.ai/guide/quick-start">🚀 快速开始</a> ·
|
||||
<a href="https://skills.cowagent.ai/">🧩 技能广场</a> ·
|
||||
<a href="https://link-ai.tech/cowagent/create">☁️ 在线体验</a>
|
||||
</p>
|
||||
|
||||
📖能力介绍:[CowAgent 2.0](/docs/agent.md)
|
||||
|
||||
# 简介
|
||||
|
||||
> 该项目既是一个可以开箱即用的超级AI助理,也是一个支持高扩展的Agent框架,可以通过为项目扩展大模型接口、接入渠道、内置工具、Skills系统来灵活实现各种定制需求。核心能力如下:
|
||||
> 该项目既是一个可以开箱即用的超级 AI 助理,也是一个支持高扩展的 Agent 框架,可以通过为项目扩展大模型接口、接入渠道、内置工具、Skills 系统来灵活实现各种定制需求。核心能力如下:
|
||||
|
||||
- ✅ **复杂任务规划**:能够理解复杂任务并自主规划执行,持续思考和调用工具直到完成目标,支持通过工具操作访问文件、终端、浏览器、定时任务等系统资源
|
||||
- ✅ **长期记忆:** 自动将对话记忆持久化至本地文件和数据库中,包括全局记忆和天级记忆,支持关键词及向量检索
|
||||
- ✅ **技能系统:** 实现了Skills创建和运行的引擎,内置多种技能,并支持通过自然语言对话完成自定义Skills开发
|
||||
- ✅ **自主任务规划**:能够理解复杂任务并自主规划执行,持续思考和调用工具直到完成目标
|
||||
- ✅ **长期记忆:** 自动将对话记忆持久化至本地文件和数据库中,包括核心记忆、日级记忆和梦境蒸馏,支持关键词及向量检索
|
||||
- ✅ **个人知识库:** 自动整理结构化知识,通过交叉引用构建知识图谱,支持通过对话管理和可视化浏览知识库
|
||||
- ✅ **技能系统:** Skills 安装和运行的引擎,支持从 [Skill Hub](https://skills.cowagent.ai/)、GitHub 等一键安装技能,或通过对话创造 Skills
|
||||
- ✅ **工具系统:** 内置文件读写、终端执行、浏览器操作、定时任务等工具,Agent 自主调用以完成复杂任务
|
||||
- ✅ **CLI系统:** 提供终端命令和对话命令,支持进程管理、技能安装、配置修改等操作
|
||||
- ✅ **多模态消息:** 支持对文本、图片、语音、文件等多类型消息进行解析、处理、生成、发送等操作
|
||||
- ✅ **多模型接入:** 支持OpenAI, Claude, Gemini, DeepSeek, MiniMax、GLM、Qwen、Kimi等国内外主流模型厂商
|
||||
- ✅ **多端部署:** 支持运行在本地计算机或服务器,可集成到网页、飞书、钉钉、微信公众号、企业微信应用中使用
|
||||
- ✅ **知识库:** 集成企业知识库能力,让Agent成为专属数字员工,基于[LinkAI](https://link-ai.tech)平台实现
|
||||
- ✅ **多模型支持:** 支持 OpenAI, Claude, Gemini, DeepSeek, MiniMax、GLM、Qwen、Kimi、Doubao 等国内外主流模型厂商
|
||||
- ✅ **多通道接入:** 支持运行在本地计算机或服务器,可集成到微信、飞书、钉钉、企业微信、QQ、微信公众号、网页中使用
|
||||
|
||||
## 声明
|
||||
|
||||
1. 本项目遵循 [MIT开源协议](/LICENSE),主要用于技术研究和学习,使用本项目时需遵守所在地法律法规、相关政策以及企业章程,禁止用于任何违法或侵犯他人权益的行为。任何个人、团队和企业,无论以何种方式使用该项目、对何对象提供服务,所产生的一切后果,本项目均不承担任何责任
|
||||
2. 成本与安全:Agent模式下Token使用量高于普通对话模式,请根据效果及成本综合选择模型。Agent具有访问所在操作系统的能力,请谨慎选择项目部署环境。同时项目也会持续升级安全机制、并降低模型消耗成本
|
||||
1. 本项目遵循 [MIT 开源协议](/LICENSE),主要用于技术研究和学习,使用本项目时需遵守所在地法律法规、相关政策以及企业章程,禁止用于任何违法或侵犯他人权益的行为。任何个人、团队和企业,无论以何种方式使用该项目、对何对象提供服务,所产生的一切后果,本项目均不承担任何责任。
|
||||
2. 成本与安全:Agent 模式下 Token 使用量高于普通对话模式,请根据效果及成本综合选择模型。Agent 具有访问所在操作系统的能力,请谨慎选择项目部署环境。同时项目也会持续升级安全机制、并降低模型消耗成本。
|
||||
3. CowAgent 项目专注于开源技术开发,不会参与、授权或发行任何加密货币。
|
||||
|
||||
## 演示
|
||||
|
||||
使用说明(Agent模式):[CowAgent介绍](/docs/agent.md)
|
||||
- 使用说明( Agent 模式):[CowAgent 介绍](https://docs.cowagent.ai/intro/features)
|
||||
|
||||
DEMO视频(对话模式):https://cdn.link-ai.tech/doc/cow_demo.mp4
|
||||
- 免部署在线体验:[CowAgent](https://link-ai.tech/cowagent/create)
|
||||
|
||||
- DEMO 视频(对话模式):https://cdn.link-ai.tech/doc/cow_demo.mp4
|
||||
|
||||
## 社区
|
||||
|
||||
@@ -43,11 +56,11 @@ DEMO视频(对话模式):https://cdn.link-ai.tech/doc/cow_demo.mp4
|
||||
|
||||
# 企业服务
|
||||
|
||||
<a href="https://link-ai.tech" target="_blank"><img width="720" src="https://cdn.link-ai.tech/image/link-ai-intro.jpg"></a>
|
||||
<a href="https://link-ai.tech" target="_blank"><img width="650" src="https://cdn.link-ai.tech/image/link-ai-intro.jpg"></a>
|
||||
|
||||
> [LinkAI](https://link-ai.tech/) 是面向企业和开发者的一站式AI智能体平台,聚合多模态大模型、知识库、Agent 插件、工作流等能力,支持一键接入主流平台并进行管理,支持SaaS、私有化部署等多种模式。
|
||||
> [LinkAI](https://link-ai.tech/) 是面向企业和个人的一站式 AI 智能体平台,聚合多模态大模型、知识库、技能、工作流等能力,支持一键接入主流平台并管理,支持 SaaS、私有化部署等多种模式,可免部署在线运行[CowAgent 助理](https://link-ai.tech/cowagent/create)。
|
||||
>
|
||||
> LinkAI 目前已在智能客服、私域运营、企业效率助手等场景积累了丰富的AI解决方案,在消费、健康、文教、科技制造等各行业沉淀了大模型落地应用的最佳实践,致力于帮助更多企业和开发者拥抱 AI 生产力。
|
||||
> LinkAI 目前已在智能客服、私域运营、企业效率助手等场景积累了丰富的 AI 解决方案,在消费、健康、文教、科技制造等各行业沉淀了大模型落地应用的最佳实践,致力于帮助更多企业和开发者拥抱 AI 生产力。
|
||||
|
||||
**产品咨询和企业服务** 可联系产品客服:
|
||||
|
||||
@@ -57,17 +70,23 @@ DEMO视频(对话模式):https://cdn.link-ai.tech/doc/cow_demo.mp4
|
||||
|
||||
# 🏷 更新日志
|
||||
|
||||
>**2026.02.03:** [2.0.0版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/2.0.0),正式升级为超级Agent助理,支持多轮任务决策、具备长期记忆、实现多种系统工具、支持Skills框架,新增多种模型并优化了接入渠道。
|
||||
>**2026.04.22:** [2.0.7版本](https://github.com/zhayujie/CowAgent/releases/tag/2.0.7),图像生成内置技能(GPT Image 2、Nano Banana 等)、新模型支持(Kimi K2.6、Claude Opus 4.7、GLM 5.1)、知识库和记忆增强、Web 控制台优化
|
||||
|
||||
>**2025.05.23:** [1.7.6版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.7.6) 优化web网页channel、新增 [AgentMesh](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/plugins/agent/README.md)多智能体插件、百度语音合成优化、企微应用`access_token`获取优化、支持`claude-4-sonnet`和`claude-4-opus`模型
|
||||
>**2026.04.14:** [2.0.6版本](https://github.com/zhayujie/CowAgent/releases/tag/2.0.6),知识库系统、梦境记忆模块、上下文智能压缩、Web 控制台多会话及多项优化。
|
||||
|
||||
>**2025.04.11:** [1.7.5版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.7.5) 新增支持 [wechatferry](https://github.com/zhayujie/chatgpt-on-wechat/pull/2562) 协议、新增 deepseek 模型、新增支持腾讯云语音能力、新增支持 ModelScope 和 Gitee-AI API接口
|
||||
>**2026.04.01:** [2.0.5版本](https://github.com/zhayujie/CowAgent/releases/tag/2.0.5),Cow CLI 命令系统、Skill Hub 开源、浏览器工具、企微扫码创建、多项优化和修复。
|
||||
|
||||
>**2024.12.13:** [1.7.4版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.7.4) 新增 Gemini 2.0 模型、新增web channel、解决内存泄漏问题、解决 `#reloadp` 命令重载不生效问题
|
||||
>**2026.03.22:** [2.0.4版本](https://github.com/zhayujie/CowAgent/releases/tag/2.0.4),新增个人微信通道(微信扫码即用)、新增 MiniMax-M2.7 和 GLM-5-Turbo 模型、run.sh 脚本重构、日文文档及多项修复。
|
||||
|
||||
>**2024.10.31:** [1.7.3版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.7.3) 程序稳定性提升、数据库功能、Claude模型优化、linkai插件优化、离线通知
|
||||
>**2026.03.18:** [2.0.3版本](https://github.com/zhayujie/CowAgent/releases/tag/2.0.3),新增企微智能机器人和 QQ 通道、支持 Coding Plan、新增多个模型、Web 端文件处理、记忆系统升级。
|
||||
|
||||
更多更新历史请查看: [更新日志](/docs/release/history.md)
|
||||
>**2026.02.27:** [2.0.2版本](https://github.com/zhayujie/CowAgent/releases/tag/2.0.2),Web 控制台全面升级(流式对话、模型/技能/记忆/通道/定时任务/日志管理)、支持多通道同时运行、会话持久化存储、新增多个模型。
|
||||
|
||||
>**2026.02.13:** [2.0.1版本](https://github.com/zhayujie/CowAgent/releases/tag/2.0.1),内置 Web Search 工具、智能上下文裁剪策略、运行时信息动态更新、Windows 兼容性适配,修复定时任务记忆丢失、飞书连接等多项问题。
|
||||
|
||||
>**2026.02.03:** [2.0.0版本](https://github.com/zhayujie/CowAgent/releases/tag/2.0.0),正式升级为超级 Agent 助理,支持多轮任务决策、具备长期记忆、实现多种系统工具、支持 Skills 框架,新增多种模型并优化了接入渠道。
|
||||
|
||||
更多更新历史请查看: [更新日志](https://docs.cowagent.ai/releases)
|
||||
|
||||
<br/>
|
||||
|
||||
@@ -77,11 +96,17 @@ DEMO视频(对话模式):https://cdn.link-ai.tech/doc/cow_demo.mp4
|
||||
|
||||
在终端执行以下命令:
|
||||
|
||||
**Linux / macOS:**
|
||||
```bash
|
||||
bash <(curl -sS https://cdn.link-ai.tech/code/cow/run.sh)
|
||||
bash <(curl -fsSL https://cdn.link-ai.tech/code/cow/run.sh)
|
||||
```
|
||||
|
||||
脚本使用说明:[一键运行脚本](https://github.com/zhayujie/chatgpt-on-wechat/wiki/CowAgentQuickStart)
|
||||
**Windows(PowerShell):**
|
||||
```powershell
|
||||
irm https://cdn.link-ai.tech/code/cow/run.ps1 | iex
|
||||
```
|
||||
|
||||
脚本使用说明:[一键运行脚本](https://docs.cowagent.ai/guide/quick-start)。安装后可使用 `cow start`、`cow stop` 等 [CLI 命令](https://docs.cowagent.ai/cli/index) 管理服务。
|
||||
|
||||
|
||||
## 一、准备
|
||||
@@ -90,24 +115,24 @@ bash <(curl -sS https://cdn.link-ai.tech/code/cow/run.sh)
|
||||
|
||||
项目支持国内外主流厂商的模型接口,可选模型及配置说明参考:[模型说明](#模型说明)。
|
||||
|
||||
> 注:Agent模式下推荐使用以下模型,可根据效果及成本综合选择:GLM(glm-4.7)、MiniMAx(MiniMax-M2.1)、Qwen(qwen3-max)、Claude(claude-opus-4-6、claude-sonnet-4-5、claude-sonnet-4-0)、Gemini(gemini-3-flash-preview、gemini-3-pro-preview)
|
||||
> 注:Agent 模式下推荐使用以下模型,可根据效果及成本综合选择:MiniMax-M2.7、glm-5.1、kimi-k2.6、qwen3.5-plus、claude-sonnet-4-6、gemini-3.1-pro-preview、gpt-5.4、gpt-5.4-mini
|
||||
|
||||
同时支持使用 **LinkAI平台** 接口,可灵活切换 OpenAI、Claude、Gemini、DeepSeek、Qwen、Kimi 等多种常用模型,并支持知识库、工作流、插件等Agent能力,参考 [接口文档](https://docs.link-ai.tech/platform/api)。
|
||||
同时支持使用 **LinkAI 平台** 接口,支持上述全部模型,并支持知识库、工作流、插件等 Agent 技能,参考 [接口文档](https://docs.link-ai.tech/platform/api)。
|
||||
|
||||
### 2.环境安装
|
||||
|
||||
支持 Linux、MacOS、Windows 操作系统,可在个人计算机及服务器上运行,需安装 `Python`,Python版本需在3.7 ~ 3.12 之间,推荐使用3.9版本。
|
||||
支持 Linux、MacOS、Windows 操作系统,可在个人计算机及服务器上运行,需安装 `Python`,Python 版本需在 3.7 ~ 3.13 之间。
|
||||
|
||||
> 注意:Agent模式推荐使用源码运行,若选择Docker部署则无需安装python环境和下载源码,可直接快进到下一节。
|
||||
> 注意:Agent 模式推荐使用源码运行,若选择 Docker 部署则无需安装 python 环境和下载源码,可直接快进到下一节。
|
||||
|
||||
**(1) 克隆项目代码:**
|
||||
|
||||
```bash
|
||||
git clone https://github.com/zhayujie/chatgpt-on-wechat
|
||||
cd chatgpt-on-wechat/
|
||||
git clone https://github.com/zhayujie/CowAgent
|
||||
cd CowAgent/
|
||||
```
|
||||
|
||||
若遇到网络问题可使用国内仓库地址:https://gitee.com/zhayujie/chatgpt-on-wechat
|
||||
若遇到网络问题可使用国内仓库地址:https://gitee.com/zhayujie/CowAgent
|
||||
|
||||
**(2) 安装核心依赖 (必选):**
|
||||
|
||||
@@ -120,43 +145,70 @@ pip3 install -r requirements.txt
|
||||
```bash
|
||||
pip3 install -r requirements-optional.txt
|
||||
```
|
||||
|
||||
> 国内网络可使用镜像源加速:`pip3 install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple`
|
||||
|
||||
如果某项依赖安装失败可注释掉对应的行后重试。
|
||||
|
||||
**(4) 安装 Cow CLI (推荐):**
|
||||
|
||||
```bash
|
||||
pip3 install -e .
|
||||
```
|
||||
|
||||
安装后可使用 `cow` 命令管理服务(启动、停止、更新等)和技能,详见 [命令文档](https://docs.cowagent.ai/cli/index)。
|
||||
|
||||
**(5) 安装浏览器工具 (可选):**
|
||||
|
||||
如果需要 Agent 操作浏览器(如访问网页、填写表单等),需要额外安装浏览器依赖:
|
||||
|
||||
```bash
|
||||
cow install-browser
|
||||
```
|
||||
|
||||
该命令会自动安装 `playwright` 和 Chromium 浏览器,国内网络自动使用镜像加速。详见 [浏览器工具文档](https://docs.cowagent.ai/tools/browser)。
|
||||
|
||||
## 二、配置
|
||||
|
||||
配置文件的模板在根目录的`config-template.json`中,需复制该模板创建最终生效的 `config.json` 文件:
|
||||
配置文件的模板在根目录的 `config-template.json` 中,需复制该模板创建最终生效的 `config.json` 文件:
|
||||
|
||||
```bash
|
||||
cp config-template.json config.json
|
||||
```
|
||||
|
||||
然后在`config.json`中填入配置,以下是对默认配置的说明,可根据需要进行自定义修改(注意实际使用时请去掉注释,保证JSON格式的规范):
|
||||
然后在 `config.json` 中填入配置,以下是对默认配置的说明,可根据需要进行自定义修改(注意实际使用时请去掉注释,保证 JSON 格式的规范):
|
||||
|
||||
```bash
|
||||
# config.json 文件内容示例
|
||||
{
|
||||
"channel_type": "web", # 接入渠道类型,默认为web,支持修改为:feishu,dingtalk,wechatcom_app,terminal,wechatmp,wechatmp_service
|
||||
"model": "MiniMax-M2.1", # 模型名称
|
||||
"channel_type": "weixin", # 接入渠道类型,默认为 weixin, 支持修改为 feishu,dingtalk,wecom_bot,qq,wechatcom_app,wechatmp_service,wechatmp,terminal
|
||||
"model": "MiniMax-M2.7", # 模型名称
|
||||
"minimax_api_key": "", # MiniMax API Key
|
||||
"zhipu_ai_api_key": "", # 智谱GLM API Key
|
||||
"dashscope_api_key": "", # 百炼(通义千问)API Key
|
||||
"zhipu_ai_api_key": "", # 智谱 GLM API Key
|
||||
"moonshot_api_key": "", # Kimi/Moonshot API Key
|
||||
"ark_api_key": "", # 豆包(火山方舟) API Key
|
||||
"dashscope_api_key": "", # 百炼(通义千问) API Key
|
||||
"claude_api_key": "", # Claude API Key
|
||||
"claude_api_base": "https://api.anthropic.com/v1", # Claude API 地址,修改可接入三方代理平台
|
||||
"gemini_api_key": "", # Gemini API Key
|
||||
"gemini_api_base": "https://generativelanguage.googleapis.com", # Gemini API地址
|
||||
"gemini_api_base": "https://generativelanguage.googleapis.com", # Gemini API 地址
|
||||
"deepseek_api_key": "", # DeepSeek API Key
|
||||
"deepseek_api_base": "https://api.deepseek.com/v1", # DeepSeek API 地址,可修改为第三方代理
|
||||
"open_ai_api_key": "", # OpenAI API Key
|
||||
"open_ai_api_base": "https://api.openai.com/v1", # OpenAI API 地址
|
||||
"linkai_api_key": "", # LinkAI API Key
|
||||
"proxy": "", # 代理客户端的ip和端口,国内环境需要开启代理的可填写该项,如 "127.0.0.1:7890"
|
||||
"proxy": "", # 代理客户端的 ip 和端口,国内环境需要开启代理的可填写该项,如 "127.0.0.1:7890"
|
||||
"speech_recognition": false, # 是否开启语音识别
|
||||
"group_speech_recognition": false, # 是否开启群组语音识别
|
||||
"voice_reply_voice": false, # 是否使用语音回复语音
|
||||
"use_linkai": false, # 是否使用LinkAI接口,默认关闭,设置为true后可对接LinkAI平台接口
|
||||
"agent": true, # 是否启用Agent模式,启用后拥有多轮工具决策、长期记忆、Skills能力等
|
||||
"agent_workspace": "~/cow", # Agent的工作空间路径,用于存储memory、skills、系统设定等
|
||||
"agent_max_context_tokens": 40000, # Agent模式下最大上下文tokens,超出将自动丢弃最早的上下文
|
||||
"agent_max_context_turns": 30, # Agent模式下最大上下文记忆轮次,每轮包括一次用户提问和AI回复
|
||||
"agent_max_steps": 15 # Agent模式下单次任务的最大决策步数,超出后将停止继续调用工具
|
||||
"use_linkai": false, # 是否使用 LinkAI 接口,默认关闭,设置为 true 后可对接 LinkAI 平台模型
|
||||
"web_password": "", # Web 控制台访问密码,留空则不启用密码保护
|
||||
"agent": true, # 是否启用 Agent 模式,启用后拥有多轮工具决策、长期记忆、Skills 能力等
|
||||
"agent_workspace": "~/cow", # Agent 的工作空间路径,用于存储 memory、skills、系统设定等
|
||||
"agent_max_context_tokens": 50000, # Agent 模式下最大上下文 tokens,超出将自动智能压缩处理
|
||||
"agent_max_context_turns": 20, # Agent 模式下最大上下文记忆轮次,一问一答为一轮,超出后智能压缩处理
|
||||
"agent_max_steps": 20, # Agent 模式下单次任务的最大决策步数,超出后将停止继续调用工具
|
||||
"enable_thinking": false # 是否启用深度思考,开启后 Web 端展示模型推理过程,关闭后可加速响应
|
||||
}
|
||||
```
|
||||
|
||||
@@ -165,28 +217,28 @@ pip3 install -r requirements-optional.txt
|
||||
<details>
|
||||
<summary>1. 语音配置</summary>
|
||||
|
||||
+ 添加 `"speech_recognition": true` 将开启语音识别,默认使用openai的whisper模型识别为文字,同时以文字回复,该参数仅支持私聊 (注意由于语音消息无法匹配前缀,一旦开启将对所有语音自动回复,支持语音触发画图);
|
||||
+ 添加 `"group_speech_recognition": true` 将开启群组语音识别,默认使用openai的whisper模型识别为文字,同时以文字回复,参数仅支持群聊 (会匹配group_chat_prefix和group_chat_keyword, 支持语音触发画图);
|
||||
+ 添加 `"speech_recognition": true` 将开启语音识别,默认使用 openai 的 whisper 模型识别为文字,同时以文字回复,该参数仅支持私聊 (注意由于语音消息无法匹配前缀,一旦开启将对所有语音自动回复,支持语音触发画图);
|
||||
+ 添加 `"group_speech_recognition": true` 将开启群组语音识别,默认使用 openai 的 whisper 模型识别为文字,同时以文字回复,参数仅支持群聊 (会匹配 group_chat_prefix 和 group_chat_keyword, 支持语音触发画图);
|
||||
+ 添加 `"voice_reply_voice": true` 将开启语音回复语音(同时作用于私聊和群聊)
|
||||
+ 使用 MiniMax TTS:设置 `"text_to_voice": "minimax"`,并配置 `minimax_api_key`;可通过 `"tts_voice_id"` 指定发音人(如 `English_Graceful_Lady`),`"text_to_voice_model"` 指定模型(如 `speech-2.8-hd`、`speech-2.8-turbo`)
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>2. 其他配置</summary>
|
||||
|
||||
+ `model`: 模型名称,Agent模式下推荐使用 `glm-4.7`、`MiniMax-M2.1`、`qwen3-max`、`claude-opus-4-6`、`claude-sonnet-4-5`、`claude-sonnet-4-0`、`gemini-3-flash-preview`、`gemini-3-pro-preview`,全部模型名称参考[common/const.py](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/common/const.py)文件
|
||||
+ `character_desc`:普通对话模式下的机器人系统提示词。在Agent模式下该配置不生效,由工作空间中的文件内容构成。
|
||||
+ `subscribe_msg`:订阅消息,公众号和企业微信channel中请填写,当被订阅时会自动回复, 可使用特殊占位符。目前支持的占位符有{trigger_prefix},在程序中它会自动替换成bot的触发词。
|
||||
+ `model`: 模型名称,Agent 模式下推荐使用 `MiniMax-M2.7`、`glm-5.1`、`kimi-k2.6`、`qwen3.6-plus`、`claude-sonnet-4-6`、`gemini-3.1-pro-preview`,全部模型名称参考[common/const.py](https://github.com/zhayujie/CowAgent/blob/master/common/const.py)文件
|
||||
+ `character_desc`:普通对话模式下的机器人系统提示词。在 Agent 模式下该配置不生效,由工作空间中的文件内容构成。
|
||||
+ `subscribe_msg`:订阅消息,公众号和企业微信 channel 中请填写,当被订阅时会自动回复, 可使用特殊占位符。目前支持的占位符有{trigger_prefix},在程序中它会自动替换成 bot 的触发词。
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>5. LinkAI配置</summary>
|
||||
<summary>3. LinkAI 配置</summary>
|
||||
|
||||
+ `use_linkai`: 是否使用LinkAI接口,默认关闭,设置为true后可对接LinkAI平台,使用知识库、工作流、插件等能力, 参考[接口文档](https://docs.link-ai.tech/platform/api/chat)
|
||||
+ `use_linkai`: 是否使用 LinkAI 接口,默认关闭,设置为 true 后可对接 LinkAI 平台,使用模型、知识库、工作流、插件等技能, 参考[接口文档](https://docs.link-ai.tech/platform/api/chat)
|
||||
+ `linkai_api_key`: LinkAI Api Key,可在 [控制台](https://link-ai.tech/console/interface) 创建
|
||||
+ `linkai_app_code`: LinkAI 应用或工作流的code,选填,普通对话模式中使用。
|
||||
</details>
|
||||
|
||||
注:全部配置项说明可在 [`config.py`](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/config.py) 文件中查看。
|
||||
注:全部配置项说明可在 [`config.py`](https://github.com/zhayujie/CowAgent/blob/master/config.py) 文件中查看。
|
||||
|
||||
## 三、运行
|
||||
|
||||
@@ -195,37 +247,48 @@ pip3 install -r requirements-optional.txt
|
||||
如果是个人计算机 **本地运行**,直接在项目根目录下执行:
|
||||
|
||||
```bash
|
||||
python3 app.py # windows环境下该命令通常为 python app.py
|
||||
cow start # 推荐,需先安装 Cow CLI
|
||||
python3 app.py # 或直接运行,windows 环境下该命令通常为 python app.py
|
||||
```
|
||||
|
||||
运行后默认会启动web服务,可通过访问 `http://localhost:9899/chat` 在网页端对话。
|
||||
运行后默认会启动 web 服务,可通过访问 `http://localhost:9899/chat` 在网页端对话。
|
||||
|
||||
如果需要接入其他应用通道只需修改 `config.json` 配置文件中的 `channel_type` 参数,详情参考:[通道说明](#通道说明)。
|
||||
|
||||
|
||||
### 2.服务器部署
|
||||
|
||||
在服务器中可使用 `nohup` 命令在后台运行程序:
|
||||
推荐使用 `cow` 命令管理服务:
|
||||
|
||||
```bash
|
||||
cow start # 后台启动
|
||||
cow stop # 停止服务
|
||||
cow restart # 重启服务
|
||||
cow status # 查看运行状态
|
||||
cow logs # 查看日志
|
||||
cow update # 拉取最新代码并重启
|
||||
```
|
||||
|
||||
也可以使用传统方式后台运行:
|
||||
|
||||
```bash
|
||||
nohup python3 app.py & tail -f nohup.out
|
||||
```
|
||||
|
||||
执行后程序运行于服务器后台,可通过 `ctrl+c` 关闭日志,不会影响后台程序的运行。使用 `ps -ef | grep app.py | grep -v grep` 命令可查看运行于后台的进程,如果想要重新启动程序可以先 `kill` 掉对应的进程。 日志关闭后如果想要再次打开只需输入 `tail -f nohup.out`。
|
||||
|
||||
此外,项目的 `scripts` 目录下有一键运行、关闭程序的脚本供使用。 运行后默认channel为web,通过可以通过修改配置文件进行切换。
|
||||
此外,项目根目录下的 `run.sh` 脚本也支持一键管理服务,包括 `./run.sh start`、`./run.sh stop`、`./run.sh restart` 等命令,执行 `./run.sh help` 可查看全部用法。
|
||||
|
||||
> 如果需要通过浏览器访问 Web 控制台,请确保服务器的 `9899` 端口已在防火墙或安全组中放行,建议仅对指定 IP 开放以保证安全。
|
||||
|
||||
### 3.Docker部署
|
||||
|
||||
使用docker部署无需下载源码和安装依赖,只需要获取 `docker-compose.yml` 配置文件并启动容器即可。Agent模式下更推荐使用源码进行部署,以获得更多系统访问能力。
|
||||
使用 docker 部署无需下载源码和安装依赖,只需要获取 `docker-compose.yml` 配置文件并启动容器即可。Agent 模式下更推荐使用源码进行部署,以获得更多系统访问能力。
|
||||
|
||||
> 前提是需要安装好 `docker` 及 `docker-compose`,安装成功后执行 `docker -v` 和 `docker-compose version` (或 `docker compose version`) 可查看到版本号。安装地址为 [docker官网](https://docs.docker.com/engine/install/) 。
|
||||
|
||||
**(1) 下载 docker-compose.yml 文件**
|
||||
|
||||
```bash
|
||||
wget https://cdn.link-ai.tech/code/cow/docker-compose.yml
|
||||
curl -O https://cdn.link-ai.tech/code/cow/docker-compose.yml
|
||||
```
|
||||
|
||||
下载完成后打开 `docker-compose.yml` 填写所需配置,例如 `CHANNEL_TYPE`、`OPEN_AI_API_KEY` 和等配置。
|
||||
@@ -238,68 +301,57 @@ wget https://cdn.link-ai.tech/code/cow/docker-compose.yml
|
||||
sudo docker compose up -d # 若docker-compose为 1.X 版本,则执行 `sudo docker-compose up -d`
|
||||
```
|
||||
|
||||
运行命令后,会自动取 [docker hub](https://hub.docker.com/r/zhayujie/chatgpt-on-wechat) 拉取最新release版本的镜像。当执行 `sudo docker ps` 能查看到 NAMES 为 chatgpt-on-wechat 的容器即表示运行成功。最后执行以下命令可查看容器的运行日志:
|
||||
运行命令后,会自动取 [docker hub](https://hub.docker.com/r/zhayujie/chatgpt-on-wechat) 拉取最新 release 版本的镜像。当执行 `sudo docker ps` 能查看到 NAMES 为 chatgpt-on-wechat 的容器即表示运行成功。最后执行以下命令可查看容器的运行日志:
|
||||
|
||||
```bash
|
||||
sudo docker logs -f chatgpt-on-wechat
|
||||
```
|
||||
|
||||
**(3) 插件使用**
|
||||
|
||||
如果需要在docker容器中修改插件配置,可通过挂载的方式完成,将 [插件配置文件](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/plugins/config.json.template)
|
||||
重命名为 `config.json`,放置于 `docker-compose.yml` 相同目录下,并在 `docker-compose.yml` 中的 `chatgpt-on-wechat` 部分下添加 `volumes` 映射:
|
||||
|
||||
```
|
||||
volumes:
|
||||
- ./config.json:/app/plugins/config.json
|
||||
```
|
||||
**注**:使用docker方式部署的详细教程可以参考:[docker部署CoW项目](https://www.wangpc.cc/ai/docker-deploy-cow/)
|
||||
|
||||
> 如果需要通过浏览器访问 Web 控制台,请确保服务器的 `9899` 端口已在防火墙或安全组中放行,建议仅对指定 IP 开放以保证安全。
|
||||
|
||||
## 模型说明
|
||||
|
||||
以下对所有可支持的模型的配置和使用方法进行说明,模型接口实现在项目的 `models/` 目录下。
|
||||
推荐通过 Web 控制台在线管理模型配置,无需手动编辑文件,详见 [模型文档](https://docs.cowagent.ai/models)。以下是手动修改 `config.json` 配置模型的说明:
|
||||
|
||||
<details>
|
||||
<summary>OpenAI</summary>
|
||||
|
||||
1. API Key创建:在 [OpenAI平台](https://platform.openai.com/api-keys) 创建API Key
|
||||
1. API Key 创建:在 [OpenAI平台](https://platform.openai.com/api-keys) 创建 API Key
|
||||
|
||||
2. 填写配置
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "gpt-4.1-mini",
|
||||
"model": "gpt-5.4",
|
||||
"open_ai_api_key": "YOUR_API_KEY",
|
||||
"open_ai_api_base": "https://api.openai.com/v1",
|
||||
"bot_type": "chatGPT"
|
||||
"bot_type": "openai"
|
||||
}
|
||||
```
|
||||
|
||||
- `model`: 与OpenAI接口的 [model参数](https://platform.openai.com/docs/models) 一致,支持包括 o系列、gpt-5.2、gpt-5.1、gpt-4.1等系列模型
|
||||
- `model`: 与 OpenAI 接口的 [model参数](https://platform.openai.com/docs/models) 一致,支持包括 gpt-5.4、gpt-5.4-mini、gpt-5.4-nano、o 系列、gpt-4.1 等模型,Agent 模式推荐使用 `gpt-5.4`、`gpt-5.4-mini`
|
||||
- `open_ai_api_base`: 如果需要接入第三方代理接口,可通过修改该参数进行接入
|
||||
- `bot_type`: 使用OpenAI相关模型时无需填写。当使用第三方代理接口接入Claude等非OpenAI官方模型时,该参数设为 `chatGPT`
|
||||
- `bot_type`: 使用 OpenAI 相关模型时无需填写。当使用第三方代理接口接入 Claude 等非 OpenAI 官方模型时,该参数设为 `openai`
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>LinkAI</summary>
|
||||
|
||||
1. API Key创建:在 [LinkAI平台](https://link-ai.tech/console/interface) 创建API Key
|
||||
1. API Key 创建:在 [LinkAI平台](https://link-ai.tech/console/interface) 创建 API Key
|
||||
|
||||
2. 填写配置
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "gpt-5.4-mini",
|
||||
"use_linkai": true,
|
||||
"linkai_api_key": "YOUR API KEY",
|
||||
"linkai_app_code": "YOUR APP CODE"
|
||||
"linkai_api_key": "YOUR API KEY"
|
||||
}
|
||||
```
|
||||
|
||||
+ `use_linkai`: 是否使用LinkAI接口,默认关闭,设置为true后可对接LinkAI平台的智能体,使用知识库、工作流、数据库、MCP插件等丰富的Agent能力
|
||||
+ `linkai_api_key`: LinkAI平台的API Key,可在 [控制台](https://link-ai.tech/console/interface) 中创建
|
||||
+ `linkai_app_code`: LinkAI智能体 (应用或工作流) 的code,选填,普通对话模式可用。智能体创建可参考 [说明文档](https://docs.link-ai.tech/platform/quick-start)
|
||||
+ `model`: model字段填写空则直接使用智能体的模型,可在平台中灵活切换,[模型列表](https://link-ai.tech/console/models)中的全部模型均可使用
|
||||
+ `use_linkai`: 是否使用 LinkAI 接口,默认关闭,设置为 true 后可对接 LinkAI 平台的模型,并使用知识库、工作流、数据库、插件等丰富的 Agent 技能
|
||||
+ `linkai_api_key`: LinkAI 平台的 API Key,可在 [控制台](https://link-ai.tech/console/interface) 中创建
|
||||
+ `model`: [模型列表](https://link-ai.tech/console/models)中的全部模型均可使用
|
||||
</details>
|
||||
|
||||
<details>
|
||||
@@ -309,26 +361,26 @@ volumes:
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "MiniMax-M2.1",
|
||||
"model": "MiniMax-M2.7",
|
||||
"minimax_api_key": ""
|
||||
}
|
||||
```
|
||||
- `model`: 可填写 `MiniMax-M2.1、MiniMax-M2.1-lightning、MiniMax-M2、abab6.5-chat` 等
|
||||
- `minimax_api_key`:MiniMax平台的API-KEY,在 [控制台](https://platform.minimaxi.com/user-center/basic-information/interface-key) 创建
|
||||
- `model`: 可填写 `MiniMax-M2.7、MiniMax-M2.7-highspeed、MiniMax-M2.5、MiniMax-M2.1、MiniMax-M2.1-lightning、MiniMax-M2、abab6.5-chat` 等
|
||||
- `minimax_api_key`:MiniMax 平台的 API-KEY,在 [控制台](https://platform.minimaxi.com/user-center/basic-information/interface-key) 创建
|
||||
|
||||
方式二:OpenAI兼容方式接入,配置如下:
|
||||
方式二:OpenAI 兼容方式接入,配置如下:
|
||||
```json
|
||||
{
|
||||
"bot_type": "chatGPT",
|
||||
"model": "MiniMax-M2.1",
|
||||
"bot_type": "openai",
|
||||
"model": "MiniMax-M2.7",
|
||||
"open_ai_api_base": "https://api.minimaxi.com/v1",
|
||||
"open_ai_api_key": ""
|
||||
}
|
||||
```
|
||||
- `bot_type`: OpenAI兼容方式
|
||||
- `model`: 可填 `MiniMax-M2.1、MiniMax-M2.1-lightning、MiniMax-M2`,参考[API文档](https://platform.minimaxi.com/document/%E5%AF%B9%E8%AF%9D?key=66701d281d57f38758d581d0#QklxsNSbaf6kM4j6wjO5eEek)
|
||||
- `open_ai_api_base`: MiniMax平台API的 BASE URL
|
||||
- `open_ai_api_key`: MiniMax平台的API-KEY
|
||||
- `bot_type`: OpenAI 兼容方式
|
||||
- `model`: 可填 `MiniMax-M2.7、MiniMax-M2.7-highspeed、MiniMax-M2.5、MiniMax-M2.1、MiniMax-M2.1-lightning、MiniMax-M2`,参考[API文档](https://platform.minimaxi.com/document/%E5%AF%B9%E8%AF%9D?key=66701d281d57f38758d581d0#QklxsNSbaf6kM4j6wjO5eEek)
|
||||
- `open_ai_api_base`: MiniMax 平台 API 的 BASE URL
|
||||
- `open_ai_api_key`: MiniMax 平台的 API-KEY
|
||||
</details>
|
||||
|
||||
<details>
|
||||
@@ -338,109 +390,57 @@ volumes:
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "glm-4.7",
|
||||
"model": "glm-5.1",
|
||||
"zhipu_ai_api_key": ""
|
||||
}
|
||||
```
|
||||
- `model`: 可填 `glm-4.7、glm-4-plus、glm-4-flash、glm-4-air、glm-4-airx、glm-4-long` 等, 参考 [glm-4系列模型编码](https://bigmodel.cn/dev/api/normal-model/glm-4)
|
||||
- `zhipu_ai_api_key`: 智谱AI平台的 API KEY,在 [控制台](https://www.bigmodel.cn/usercenter/proj-mgmt/apikeys) 创建
|
||||
- `model`: 可填 `glm-5.1、glm-5-turbo、glm-5、glm-4.7、glm-4-plus、glm-4-flash、glm-4-air、glm-4-airx、glm-4-long` 等, 参考 [glm 系列模型编码](https://bigmodel.cn/dev/api/normal-model/glm-4)
|
||||
- `zhipu_ai_api_key`: 智谱AI 平台的 API KEY,在 [控制台](https://www.bigmodel.cn/usercenter/proj-mgmt/apikeys) 创建
|
||||
|
||||
方式二:OpenAI兼容方式接入,配置如下:
|
||||
方式二:OpenAI 兼容方式接入,配置如下:
|
||||
```json
|
||||
{
|
||||
"bot_type": "chatGPT",
|
||||
"model": "glm-4.7",
|
||||
"bot_type": "openai",
|
||||
"model": "glm-5.1",
|
||||
"open_ai_api_base": "https://open.bigmodel.cn/api/paas/v4",
|
||||
"open_ai_api_key": ""
|
||||
}
|
||||
```
|
||||
- `bot_type`: OpenAI兼容方式
|
||||
- `model`: 可填 `glm-4.7、glm-4.6、glm-4-plus、glm-4-flash、glm-4-air、glm-4-airx、glm-4-long` 等
|
||||
- `open_ai_api_base`: 智谱AI平台的 BASE URL
|
||||
- `open_ai_api_key`: 智谱AI平台的 API KEY
|
||||
- `bot_type`: OpenAI 兼容方式
|
||||
- `model`: 可填 `glm-5.1、glm-5-turbo、glm-5、glm-4.7、glm-4-plus、glm-4-flash、glm-4-air、glm-4-airx、glm-4-long` 等
|
||||
- `open_ai_api_base`: 智谱AI 平台的 BASE URL
|
||||
- `open_ai_api_key`: 智谱AI 平台的 API KEY
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>通义千问 (Qwen)</summary>
|
||||
|
||||
方式一:官方SDK接入,配置如下(推荐):
|
||||
方式一:官方 SDK 接入,配置如下(推荐):
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "qwen3-max",
|
||||
"model": "qwen3.6-plus",
|
||||
"dashscope_api_key": "sk-qVxxxxG"
|
||||
}
|
||||
```
|
||||
- `model`: 可填写 `qwen3-max、qwen-max、qwen-plus、qwen-turbo、qwen-long、qwq-plus` 等
|
||||
- `dashscope_api_key`: 通义千问的 API-KEY,参考 [官方文档](https://bailian.console.aliyun.com/?tab=api#/api) ,在 [控制台](https://bailian.console.aliyun.com/?tab=model#/api-key) 创建
|
||||
- `model`: 可填写 `qwen3.6-plus、qwen3.5-plus、qwen3-max、qwen-max、qwen-plus、qwen-turbo、qwen-long、qwq-plus` 等
|
||||
- `dashscope_api_key`: 通义千问的 API-KEY,参考 [官方文档](https://bailian.console.aliyun.com/?tab=api#/api) ,在 [百炼控制台](https://bailian.console.aliyun.com/?tab=model#/api-key) 创建
|
||||
|
||||
方式二:OpenAI兼容方式接入,配置如下:
|
||||
方式二:OpenAI 兼容方式接入,配置如下:
|
||||
```json
|
||||
{
|
||||
"bot_type": "chatGPT",
|
||||
"model": "qwen3-max",
|
||||
"bot_type": "openai",
|
||||
"model": "qwen3.6-plus",
|
||||
"open_ai_api_base": "https://dashscope.aliyuncs.com/compatible-mode/v1",
|
||||
"open_ai_api_key": "sk-qVxxxxG"
|
||||
}
|
||||
```
|
||||
- `bot_type`: OpenAI兼容方式
|
||||
- `bot_type`: OpenAI 兼容方式
|
||||
- `model`: 支持官方所有模型,参考[模型列表](https://help.aliyun.com/zh/model-studio/models?spm=a2c4g.11186623.0.0.78d84823Kth5on#9f8890ce29g5u)
|
||||
- `open_ai_api_base`: 通义千问API的 BASE URL
|
||||
- `open_ai_api_base`: 通义千问 API 的 BASE URL
|
||||
- `open_ai_api_key`: 通义千问的 API-KEY
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>Claude</summary>
|
||||
|
||||
1. API Key创建:在 [Claude控制台](https://console.anthropic.com/settings/keys) 创建API Key
|
||||
|
||||
2. 填写配置
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "claude-sonnet-4-5",
|
||||
"claude_api_key": "YOUR_API_KEY"
|
||||
}
|
||||
```
|
||||
- `model`: 参考 [官方模型ID](https://docs.anthropic.com/en/docs/about-claude/models/overview#model-aliases) ,支持 `claude-opus-4-6、claude-sonnet-4-5、claude-sonnet-4-0、claude-opus-4-0、claude-3-5-sonnet-latest` 等
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>Gemini</summary>
|
||||
|
||||
API Key创建:在 [控制台](https://aistudio.google.com/app/apikey?hl=zh-cn) 创建API Key ,配置如下
|
||||
```json
|
||||
{
|
||||
"model": "gemini-3-flash-preview",
|
||||
"gemini_api_key": ""
|
||||
}
|
||||
```
|
||||
- `model`: 参考[官方文档-模型列表](https://ai.google.dev/gemini-api/docs/models?hl=zh-cn),支持 `gemini-3-flash-preview、gemini-3-pro-preview、gemini-2.5-pro、gemini-2.0-flash` 等
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>DeepSeek</summary>
|
||||
|
||||
1. API Key创建:在 [DeepSeek平台](https://platform.deepseek.com/api_keys) 创建API Key
|
||||
|
||||
2. 填写配置
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "deepseek-chat",
|
||||
"open_ai_api_key": "sk-xxxxxxxxxxx",
|
||||
"open_ai_api_base": "https://api.deepseek.com/v1",
|
||||
"bot_type": "chatGPT"
|
||||
|
||||
}
|
||||
```
|
||||
|
||||
- `bot_type`: OpenAI兼容方式
|
||||
- `model`: 可填 `deepseek-chat、deepseek-reasoner`,分别对应的是 DeepSeek-V3 和 DeepSeek-R1 模型
|
||||
- `open_ai_api_key`: DeepSeek平台的 API Key
|
||||
- `open_ai_api_base`: DeepSeek平台 BASE URL
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>Kimi (Moonshot)</summary>
|
||||
|
||||
@@ -448,32 +448,112 @@ API Key创建:在 [控制台](https://aistudio.google.com/app/apikey?hl=zh-cn)
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "moonshot-v1-128k",
|
||||
"model": "kimi-k2.6",
|
||||
"moonshot_api_key": ""
|
||||
}
|
||||
```
|
||||
- `model`: 可填写 `moonshot-v1-8k、moonshot-v1-32k、moonshot-v1-128k`
|
||||
- `moonshot_api_key`: Moonshot的API-KEY,在 [控制台](https://platform.moonshot.cn/console/api-keys) 创建
|
||||
- `model`: 可填写 `kimi-k2.6、kimi-k2.5、kimi-k2、moonshot-v1-8k、moonshot-v1-32k、moonshot-v1-128k`
|
||||
- `moonshot_api_key`: Moonshot 的 API-KEY,在 [控制台](https://platform.moonshot.cn/console/api-keys) 创建
|
||||
|
||||
方式二:OpenAI兼容方式接入,配置如下:
|
||||
方式二:OpenAI 兼容方式接入,配置如下:
|
||||
```json
|
||||
{
|
||||
"bot_type": "chatGPT",
|
||||
"model": "moonshot-v1-128k",
|
||||
"bot_type": "openai",
|
||||
"model": "kimi-k2.6",
|
||||
"open_ai_api_base": "https://api.moonshot.cn/v1",
|
||||
"open_ai_api_key": ""
|
||||
}
|
||||
```
|
||||
- `bot_type`: OpenAI兼容方式
|
||||
- `model`: 可填写 `moonshot-v1-8k、moonshot-v1-32k、moonshot-v1-128k`
|
||||
- `open_ai_api_base`: Moonshot的 BASE URL
|
||||
- `open_ai_api_key`: Moonshot的 API-KEY
|
||||
- `bot_type`: OpenAI 兼容方式
|
||||
- `model`: 可填写 `kimi-k2.6、kimi-k2.5、kimi-k2、moonshot-v1-8k、moonshot-v1-32k、moonshot-v1-128k`
|
||||
- `open_ai_api_base`: Moonshot 的 BASE URL
|
||||
- `open_ai_api_key`: Moonshot 的 API-KEY
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>豆包 (Doubao)</summary>
|
||||
|
||||
1. API Key 创建:在 [火山方舟控制台](https://console.volcengine.com/ark/region:ark+cn-beijing/apikey) 创建API Key
|
||||
|
||||
2. 填写配置
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "doubao-seed-2-0-code-preview-260215",
|
||||
"ark_api_key": "YOUR_API_KEY"
|
||||
}
|
||||
```
|
||||
- `model`: 可填写 `doubao-seed-2-0-code-preview-260215、doubao-seed-2-0-pro-260215、doubao-seed-2-0-lite-260215、doubao-seed-2-0-mini-260215` 等
|
||||
- `ark_api_key`: 火山方舟平台的 API Key,在 [控制台](https://console.volcengine.com/ark/region:ark+cn-beijing/apikey) 创建
|
||||
- `ark_base_url`: 可选,默认为 `https://ark.cn-beijing.volces.com/api/v3`
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>Claude</summary>
|
||||
|
||||
1. API Key 创建:在 [Claude控制台](https://console.anthropic.com/settings/keys) 创建 API Key
|
||||
|
||||
2. 填写配置
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "claude-sonnet-4-6",
|
||||
"claude_api_key": "YOUR_API_KEY"
|
||||
}
|
||||
```
|
||||
- `model`: 参考 [官方模型ID](https://docs.anthropic.com/en/docs/about-claude/models/overview#model-aliases) ,支持 `claude-sonnet-4-6、claude-opus-4-7、claude-opus-4-6、claude-sonnet-4-5、claude-sonnet-4-0、claude-opus-4-0、claude-3-5-sonnet-latest` 等
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>Gemini</summary>
|
||||
|
||||
API Key 创建:在 [控制台](https://aistudio.google.com/app/apikey?hl=zh-cn) 创建 API Key ,配置如下
|
||||
```json
|
||||
{
|
||||
"model": "gemini-3.1-flash-lite-preview",
|
||||
"gemini_api_key": ""
|
||||
}
|
||||
```
|
||||
- `model`: 参考[官方文档-模型列表](https://ai.google.dev/gemini-api/docs/models?hl=zh-cn),支持 `gemini-3.1-flash-lite-preview、gemini-3.1-pro-preview、gemini-3-flash-preview、gemini-3-pro-preview` 等
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>DeepSeek</summary>
|
||||
|
||||
1. API Key 创建:在 [DeepSeek 平台](https://platform.deepseek.com/api_keys) 创建 API Key
|
||||
|
||||
2. 填写配置
|
||||
|
||||
方式一:官方接入(推荐):
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "deepseek-chat",
|
||||
"deepseek_api_key": "sk-xxxxxxxxxxx"
|
||||
}
|
||||
```
|
||||
|
||||
- `model`: 可填 `deepseek-chat、deepseek-reasoner`,分别对应的是 DeepSeek-V3.2(非思考模式)和 DeepSeek-R1(思考模式)
|
||||
- `deepseek_api_key`: DeepSeek 平台的 API Key
|
||||
- `deepseek_api_base`: 可选,默认为 `https://api.deepseek.com/v1`,可修改为第三方代理地址
|
||||
|
||||
方式二:OpenAI 兼容方式接入:
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "deepseek-chat",
|
||||
"bot_type": "openai",
|
||||
"open_ai_api_key": "sk-xxxxxxxxxxx",
|
||||
"open_ai_api_base": "https://api.deepseek.com/v1"
|
||||
}
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>Azure</summary>
|
||||
|
||||
1. API Key创建:在 [Azure平台](https://oai.azure.com/) 创建API Key
|
||||
1. API Key 创建:在 [Azure平台](https://oai.azure.com/) 创建 API Key
|
||||
|
||||
2. 填写配置
|
||||
|
||||
@@ -490,15 +570,15 @@ API Key创建:在 [控制台](https://aistudio.google.com/app/apikey?hl=zh-cn)
|
||||
|
||||
- `model`: 留空即可
|
||||
- `use_azure_chatgpt`: 设为 true
|
||||
- `open_ai_api_key`: Azure平台的密钥
|
||||
- `open_ai_api_base`: Azure平台的 BASE URL
|
||||
- `azure_deployment_id`: Azure平台部署的模型名称
|
||||
- `azure_api_version`: api版本以及以上参数可以在部署的 [模型配置](https://oai.azure.com/resource/deployments) 界面查看
|
||||
- `open_ai_api_key`: Azure 平台的密钥
|
||||
- `open_ai_api_base`: Azure 平台的 BASE URL
|
||||
- `azure_deployment_id`: Azure 平台部署的模型名称
|
||||
- `azure_api_version`: api 版本以及以上参数可以在部署的 [模型配置](https://oai.azure.com/resource/deployments) 界面查看
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>百度文心</summary>
|
||||
方式一:官方SDK接入,配置如下:
|
||||
方式一:官方 SDK 接入,配置如下:
|
||||
|
||||
```json
|
||||
{
|
||||
@@ -511,19 +591,19 @@ API Key创建:在 [控制台](https://aistudio.google.com/app/apikey?hl=zh-cn)
|
||||
- `baidu_wenxin_api_key`:参考 [千帆平台-access_token鉴权](https://cloud.baidu.com/doc/WENXINWORKSHOP/s/dlv4pct3s) 文档获取 API Key
|
||||
- `baidu_wenxin_secret_key`:参考 [千帆平台-access_token鉴权](https://cloud.baidu.com/doc/WENXINWORKSHOP/s/dlv4pct3s) 文档获取 Secret Key
|
||||
|
||||
方式二:OpenAI兼容方式接入,配置如下:
|
||||
方式二:OpenAI 兼容方式接入,配置如下:
|
||||
```json
|
||||
{
|
||||
"bot_type": "chatGPT",
|
||||
"bot_type": "openai",
|
||||
"model": "ERNIE-4.0-Turbo-8K",
|
||||
"open_ai_api_base": "https://qianfan.baidubce.com/v2",
|
||||
"open_ai_api_key": "bce-v3/ALTxxxxxxd2b"
|
||||
}
|
||||
```
|
||||
- `bot_type`: OpenAI兼容方式
|
||||
- `bot_type`: OpenAI 兼容方式
|
||||
- `model`: 支持官方所有模型,参考[模型列表](https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Wm9cvy6rl)
|
||||
- `open_ai_api_base`: 百度文心API的 BASE URL
|
||||
- `open_ai_api_key`: 百度文心的 API-KEY,参考 [官方文档](https://cloud.baidu.com/doc/qianfan-api/s/ym9chdsy5) ,在 [控制台](https://console.bce.baidu.com/iam/#/iam/apikey/list) 创建API Key
|
||||
- `open_ai_api_base`: 百度文心 API 的 BASE URL
|
||||
- `open_ai_api_key`: 百度文心的 API-KEY,参考 [官方文档](https://cloud.baidu.com/doc/qianfan-api/s/ym9chdsy5) ,在 [控制台](https://console.bce.baidu.com/iam/#/iam/apikey/list) 创建 API Key
|
||||
|
||||
</details>
|
||||
|
||||
@@ -547,16 +627,16 @@ API Key创建:在 [控制台](https://aistudio.google.com/app/apikey?hl=zh-cn)
|
||||
- `xunfei_domain`: 可填写 `4.0Ultra、generalv3.5、max-32k、generalv3、pro-128k、lite`
|
||||
- `xunfei_spark_url`: 填写参考 [官方文档-请求地址](https://www.xfyun.cn/doc/spark/Web.html#_1-1-%E8%AF%B7%E6%B1%82%E5%9C%B0%E5%9D%80) 的说明
|
||||
|
||||
方式二:OpenAI兼容方式接入,配置如下:
|
||||
方式二:OpenAI 兼容方式接入,配置如下:
|
||||
```json
|
||||
{
|
||||
"bot_type": "chatGPT",
|
||||
"bot_type": "openai",
|
||||
"model": "4.0Ultra",
|
||||
"open_ai_api_base": "https://spark-api-open.xf-yun.com/v1",
|
||||
"open_ai_api_key": ""
|
||||
}
|
||||
```
|
||||
- `bot_type`: OpenAI兼容方式
|
||||
- `bot_type`: OpenAI 兼容方式
|
||||
- `model`: 可填写 `4.0Ultra、generalv3.5、max-32k、generalv3、pro-128k、lite`
|
||||
- `open_ai_api_base`: 讯飞星火平台的 BASE URL
|
||||
- `open_ai_api_key`: 讯飞星火平台的[APIPassword](https://console.xfyun.cn/services/bm3) ,因模型而已
|
||||
@@ -575,22 +655,58 @@ API Key创建:在 [控制台](https://aistudio.google.com/app/apikey?hl=zh-cn)
|
||||
}
|
||||
```
|
||||
|
||||
- `bot_type`: modelscope接口格式
|
||||
- `bot_type`: modelscope 接口格式
|
||||
- `model`: 参考[模型列表](https://www.modelscope.cn/models?filter=inference_type&page=1)
|
||||
- `modelscope_api_key`: 参考 [官方文档-访问令牌](https://modelscope.cn/docs/accounts/token) ,在 [控制台](https://modelscope.cn/my/myaccesstoken)
|
||||
- `modelscope_base_url`: modelscope平台的 BASE URL
|
||||
- `modelscope_base_url`: modelscope 平台的 BASE URL
|
||||
- `text_to_image`: 图像生成模型,参考[模型列表](https://www.modelscope.cn/models?filter=inference_type&page=1)
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>Coding Plan</summary>
|
||||
|
||||
Coding Plan 是各厂商推出的编程包月套餐,所有厂商均可通过 OpenAI 兼容方式接入:
|
||||
|
||||
```json
|
||||
{
|
||||
"bot_type": "openai",
|
||||
"model": "模型名称",
|
||||
"open_ai_api_base": "厂商 Coding Plan API Base",
|
||||
"open_ai_api_key": "YOUR_API_KEY"
|
||||
}
|
||||
```
|
||||
|
||||
目前支持阿里云、MiniMax、智谱 GLM、Kimi、火山引擎等厂商,各厂商详细配置请参考 [Coding Plan 文档](https://docs.cowagent.ai/models/coding-plan)。
|
||||
</details>
|
||||
|
||||
|
||||
## 通道说明
|
||||
|
||||
以下对可接入通道的配置方式进行说明,应用通道代码在项目的 `channel/` 目录下。
|
||||
推荐通过 Web 控制台在线管理通道配置,无需手动编辑文件,详见 [通道文档](https://docs.cowagent.ai/channels/weixin)。以下为手动修改 `config.json` 配置通道的说明:
|
||||
|
||||
支持同时可接入多个通道,配置时可通过逗号进行分割,例如 `"channel_type": "feishu,dingtalk"`。
|
||||
|
||||
<details>
|
||||
<summary>1. Web</summary>
|
||||
<summary>1. Weixin - 微信</summary>
|
||||
|
||||
项目启动后默认运行Web通道,配置如下:
|
||||
接入个人微信,扫码登录即可使用,支持文本、图片、语音、文件等消息收发。
|
||||
|
||||
```json
|
||||
{
|
||||
"channel_type": "weixin"
|
||||
}
|
||||
```
|
||||
|
||||
启动后终端会显示二维码,使用微信扫码授权即可,也可以在 Web 控制台的「通道」页面中扫码接入。登录凭证会自动保存至 `~/.weixin_cow_credentials.json`,下次启动无需重新扫码,如需重新登录删除该文件后重启即可。
|
||||
|
||||
详细步骤和参数说明参考 [微信接入](https://docs.cowagent.ai/channels/weixin)
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>2. Web</summary>
|
||||
|
||||
项目启动后会默认运行 Web 控制台,配置如下:
|
||||
|
||||
```json
|
||||
{
|
||||
@@ -600,12 +716,13 @@ API Key创建:在 [控制台](https://aistudio.google.com/app/apikey?hl=zh-cn)
|
||||
```
|
||||
|
||||
- `web_port`: 默认为 9899,可按需更改,需要服务器防火墙和安全组放行该端口
|
||||
- `web_password`: 访问密码,留空则不启用密码保护。部署在公网环境时建议设置
|
||||
- 如本地运行,启动后请访问 `http://localhost:9899/chat` ;如服务器运行,请访问 `http://ip:9899/chat`
|
||||
> 注:请将上述 url 中的 ip 或者 port 替换为实际的值
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>2. Feishu - 飞书</summary>
|
||||
<summary>3. Feishu - 飞书</summary>
|
||||
|
||||
飞书支持两种事件接收模式:WebSocket 长连接(推荐)和 Webhook。
|
||||
|
||||
@@ -636,12 +753,12 @@ API Key创建:在 [控制台](https://aistudio.google.com/app/apikey?hl=zh-cn)
|
||||
- `feishu_event_mode`: 事件接收模式,`websocket`(推荐)或 `webhook`
|
||||
- WebSocket 模式需安装依赖:`pip3 install lark-oapi`
|
||||
|
||||
详细步骤和参数说明参考 [飞书接入](https://docs.link-ai.tech/cow/multi-platform/feishu)
|
||||
详细步骤和参数说明参考 [飞书接入](https://docs.cowagent.ai/channels/feishu)
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>3. DingTalk - 钉钉</summary>
|
||||
<summary>4. DingTalk - 钉钉</summary>
|
||||
|
||||
钉钉需要在开放平台创建智能机器人应用,将以下配置填入 `config.json`:
|
||||
|
||||
@@ -652,11 +769,43 @@ API Key创建:在 [控制台](https://aistudio.google.com/app/apikey?hl=zh-cn)
|
||||
"dingtalk_client_secret": "CLIENT_SECRET"
|
||||
}
|
||||
```
|
||||
详细步骤和参数说明参考 [钉钉接入](https://docs.link-ai.tech/cow/multi-platform/dingtalk)
|
||||
详细步骤和参数说明参考 [钉钉接入](https://docs.cowagent.ai/channels/dingtalk)
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>4. WeCom App - 企业微信应用</summary>
|
||||
<summary>5. WeCom Bot - 企微智能机器人</summary>
|
||||
|
||||
企微智能机器人使用 WebSocket 长连接模式,无需公网 IP 和域名,配置简单:
|
||||
|
||||
```json
|
||||
{
|
||||
"channel_type": "wecom_bot",
|
||||
"wecom_bot_id": "YOUR_BOT_ID",
|
||||
"wecom_bot_secret": "YOUR_SECRET"
|
||||
}
|
||||
```
|
||||
详细步骤和参数说明参考 [企微智能机器人接入](https://docs.cowagent.ai/channels/wecom-bot)
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>6. QQ - QQ 机器人</summary>
|
||||
|
||||
QQ 机器人使用 WebSocket 长连接模式,无需公网 IP 和域名,支持 QQ 单聊、群聊和频道消息:
|
||||
|
||||
```json
|
||||
{
|
||||
"channel_type": "qq",
|
||||
"qq_app_id": "YOUR_APP_ID",
|
||||
"qq_app_secret": "YOUR_APP_SECRET"
|
||||
}
|
||||
```
|
||||
详细步骤和参数说明参考 [QQ 机器人接入](https://docs.cowagent.ai/channels/qq)
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>7. WeCom App - 企业微信应用</summary>
|
||||
|
||||
企业微信自建应用接入需在后台创建应用并启用消息回调,配置示例:
|
||||
|
||||
@@ -671,12 +820,12 @@ API Key创建:在 [控制台](https://aistudio.google.com/app/apikey?hl=zh-cn)
|
||||
"wechatcomapp_aes_key": "AESKEY"
|
||||
}
|
||||
```
|
||||
详细步骤和参数说明参考 [企微自建应用接入](https://docs.link-ai.tech/cow/multi-platform/wechat-com)
|
||||
详细步骤和参数说明参考 [企微自建应用接入](https://docs.cowagent.ai/channels/wecom)
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>5. WeChat MP - 微信公众号</summary>
|
||||
<summary>8. WeChat MP - 微信公众号</summary>
|
||||
|
||||
本项目支持订阅号和服务号两种公众号,通过服务号(`wechatmp_service`)体验更佳。
|
||||
|
||||
@@ -706,12 +855,12 @@ API Key创建:在 [控制台](https://aistudio.google.com/app/apikey?hl=zh-cn)
|
||||
}
|
||||
```
|
||||
|
||||
详细步骤和参数说明参考 [微信公众号接入](https://docs.link-ai.tech/cow/multi-platform/wechat-mp)
|
||||
详细步骤和参数说明参考 [微信公众号接入](https://docs.cowagent.ai/channels/wechatmp)
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>6. Terminal - 终端</summary>
|
||||
<summary>9. Terminal - 终端</summary>
|
||||
|
||||
修改 `config.json` 中的 `channel_type` 字段:
|
||||
|
||||
@@ -729,25 +878,37 @@ API Key创建:在 [控制台](https://aistudio.google.com/app/apikey?hl=zh-cn)
|
||||
|
||||
# 🔗 相关项目
|
||||
|
||||
- [bot-on-anything](https://github.com/zhayujie/bot-on-anything):轻量和高可扩展的大模型应用框架,支持接入Slack, Telegram, Discord, Gmail等海外平台,可作为本项目的补充使用。
|
||||
- [AgentMesh](https://github.com/MinimalFuture/AgentMesh):开源的多智能体(Multi-Agent)框架,可以通过多智能体团队的协同来解决复杂问题。本项目基于该框架实现了[Agent插件](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/plugins/agent/README.md),可访问终端、浏览器、文件系统、搜索引擎 等各类工具,并实现了多智能体协同。
|
||||
- [Cow Skill Hub](https://github.com/zhayujie/cow-skill-hub):开源的 AI Agent 技能广场,浏览、搜索、安装和发布技能,支持 CowAgent、OpenClaw、Claude Code 等多种 Agent。
|
||||
- [bot-on-anything](https://github.com/zhayujie/bot-on-anything):轻量和高可扩展的大模型应用框架,支持接入 Slack, Telegram, Discord, Gmail 等海外平台,可作为本项目的补充使用。
|
||||
- [AgentMesh](https://github.com/MinimalFuture/AgentMesh):开源的多智能体( Multi-Agent )框架,可以通过多智能体团队的协同来解决复杂问题。
|
||||
|
||||
|
||||
|
||||
|
||||
# 🔎 常见问题
|
||||
|
||||
FAQs: <https://github.com/zhayujie/chatgpt-on-wechat/wiki/FAQs>
|
||||
FAQs: <https://github.com/zhayujie/CowAgent/wiki/FAQs>
|
||||
|
||||
或直接在线咨询 [项目小助手](https://link-ai.tech/app/Kv2fXJcH) (知识库持续完善中,回复供参考)
|
||||
|
||||
# 🛠️ 开发
|
||||
|
||||
欢迎接入更多应用通道,参考 [飞书通道](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/channel/feishu/feishu_channel.py) 新增自定义通道,实现接收和发送消息逻辑即可完成接入。 同时欢迎贡献新的Skills,参考 [Skill创造器说明](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/skills/skill-creator/SKILL.md)。
|
||||
欢迎接入更多应用通道,参考 [飞书通道](https://github.com/zhayujie/CowAgent/blob/master/channel/feishu/feishu_channel.py) 新增自定义通道,实现接收和发送消息逻辑即可完成接入。同时欢迎贡献新的 Skills,向 [Skill Hub](https://skills.cowagent.ai/submit) 提交技能。
|
||||
|
||||
# ✉ 联系
|
||||
|
||||
欢迎提交PR、Issues进行反馈,以及通过 🌟Star 支持并关注项目更新。项目运行遇到问题可以查看 [常见问题列表](https://github.com/zhayujie/chatgpt-on-wechat/wiki/FAQs) ,以及前往 [Issues](https://github.com/zhayujie/chatgpt-on-wechat/issues) 中搜索。个人开发者可加入开源交流群参与更多讨论,企业用户可联系[产品客服](https://cdn.link-ai.tech/portal/linkai-customer-service.png)咨询。
|
||||
欢迎提交PR、Issues进行反馈,以及通过 🌟Star 支持并关注项目更新。项目运行遇到问题可以查看 [常见问题列表](https://github.com/zhayujie/CowAgent/wiki/FAQs) ,以及前往 [Issues](https://github.com/zhayujie/CowAgent/issues) 中搜索。个人开发者可加入开源交流群参与更多讨论,企业用户可联系[产品客服](https://cdn.link-ai.tech/portal/linkai-customer-service.png)咨询。
|
||||
|
||||
# 🌟 贡献者
|
||||
|
||||

|
||||

|
||||
|
||||
# 📌 项目更名说明
|
||||
|
||||
本项目原名 `chatgpt-on-wechat`(GitHub 原地址:https://github.com/zhayujie/chatgpt-on-wechat ),
|
||||
于 2026.04.13 正式更名为 **CowAgent**。GitHub 已自动设置重定向,原有链接仍可正常访问。
|
||||
|
||||
如需更新本地仓库的远程地址(可选):
|
||||
```bash
|
||||
git remote set-url origin https://github.com/zhayujie/CowAgent.git
|
||||
```
|
||||
|
||||
3
agent/chat/__init__.py
Normal file
3
agent/chat/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from agent.chat.service import ChatService
|
||||
|
||||
__all__ = ["ChatService"]
|
||||
290
agent/chat/service.py
Normal file
290
agent/chat/service.py
Normal file
@@ -0,0 +1,290 @@
|
||||
"""
|
||||
ChatService - Wraps the Agent stream execution to produce CHAT protocol chunks.
|
||||
|
||||
Translates agent events (message_update, message_end, tool_execution_end, etc.)
|
||||
into the CHAT socket protocol format (content chunks with segment_id, tool_calls chunks).
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import Callable, Optional
|
||||
|
||||
from common.log import logger
|
||||
|
||||
|
||||
class ChatService:
|
||||
"""
|
||||
High-level service that runs an Agent for a given query and streams
|
||||
the results as CHAT protocol chunks via a callback.
|
||||
|
||||
Usage:
|
||||
svc = ChatService(agent_bridge)
|
||||
svc.run(query, session_id, send_chunk_fn)
|
||||
"""
|
||||
|
||||
def __init__(self, agent_bridge):
|
||||
"""
|
||||
:param agent_bridge: AgentBridge instance (manages agent lifecycle)
|
||||
"""
|
||||
self.agent_bridge = agent_bridge
|
||||
|
||||
def run(self, query: str, session_id: str, send_chunk_fn: Callable[[dict], None],
|
||||
channel_type: str = ""):
|
||||
"""
|
||||
Run the agent for *query* and stream results back via *send_chunk_fn*.
|
||||
|
||||
The method blocks until the agent finishes. After it returns the SDK
|
||||
will automatically send the final (streaming=false) message.
|
||||
|
||||
:param query: user query text
|
||||
:param session_id: session identifier for agent isolation
|
||||
:param send_chunk_fn: callable(chunk_data: dict) to send a streaming chunk
|
||||
:param channel_type: source channel (e.g. "web", "feishu") for persistence
|
||||
"""
|
||||
agent = self.agent_bridge.get_agent(session_id=session_id)
|
||||
if agent is None:
|
||||
raise RuntimeError("Failed to initialise agent for the session")
|
||||
|
||||
# Pass context metadata to model for downstream API requests
|
||||
if hasattr(agent, 'model'):
|
||||
agent.model.channel_type = channel_type or ""
|
||||
agent.model.session_id = session_id or ""
|
||||
|
||||
# State shared between the event callback and this method
|
||||
state = _StreamState()
|
||||
|
||||
def on_event(event: dict):
|
||||
"""Translate agent events into CHAT protocol chunks."""
|
||||
event_type = event.get("type")
|
||||
data = event.get("data", {})
|
||||
|
||||
if event_type == "reasoning_update":
|
||||
delta = data.get("delta", "")
|
||||
if delta:
|
||||
send_chunk_fn({
|
||||
"chunk_type": "reasoning",
|
||||
"delta": delta,
|
||||
"segment_id": state.segment_id,
|
||||
})
|
||||
|
||||
elif event_type == "message_update":
|
||||
# Incremental text delta
|
||||
delta = data.get("delta", "")
|
||||
if delta:
|
||||
send_chunk_fn({
|
||||
"chunk_type": "content",
|
||||
"delta": delta,
|
||||
"segment_id": state.segment_id,
|
||||
})
|
||||
|
||||
elif event_type == "message_end":
|
||||
# A content segment finished.
|
||||
tool_calls = data.get("tool_calls", [])
|
||||
if tool_calls:
|
||||
# After tool_calls are executed the next content will be
|
||||
# a new segment; collect tool results until turn_end.
|
||||
state.pending_tool_results = []
|
||||
|
||||
elif event_type == "file_to_send":
|
||||
url = data.get("url") or ""
|
||||
if url:
|
||||
fname = data.get("file_name") or "file"
|
||||
ft = data.get("file_type") or "file"
|
||||
if ft == "image":
|
||||
link = f""
|
||||
else:
|
||||
link = f"[{fname}]({url})"
|
||||
send_chunk_fn({
|
||||
"chunk_type": "content",
|
||||
"delta": "\n\n" + link + "\n\n",
|
||||
"segment_id": state.segment_id,
|
||||
})
|
||||
# Remove url so the model won't repeat it in its reply
|
||||
data.pop("url", None)
|
||||
|
||||
elif event_type == "tool_execution_start":
|
||||
# Notify the client that a tool is about to run (with its input args)
|
||||
tool_name = data.get("tool_name", "")
|
||||
arguments = data.get("arguments", {})
|
||||
# Cache arguments keyed by tool_call_id so tool_execution_end can include them
|
||||
tool_call_id = data.get("tool_call_id", tool_name)
|
||||
state.pending_tool_arguments[tool_call_id] = arguments
|
||||
send_chunk_fn({
|
||||
"chunk_type": "tool_start",
|
||||
"tool": tool_name,
|
||||
"arguments": arguments,
|
||||
})
|
||||
|
||||
elif event_type == "tool_execution_end":
|
||||
tool_name = data.get("tool_name", "")
|
||||
tool_call_id = data.get("tool_call_id", tool_name)
|
||||
# Retrieve cached arguments from the matching tool_execution_start event
|
||||
arguments = state.pending_tool_arguments.pop(tool_call_id, data.get("arguments", {}))
|
||||
result = data.get("result", "")
|
||||
status = data.get("status", "unknown")
|
||||
execution_time = data.get("execution_time", 0)
|
||||
elapsed_str = f"{execution_time:.2f}s"
|
||||
|
||||
# Serialise result to string if needed
|
||||
if not isinstance(result, str):
|
||||
import json
|
||||
try:
|
||||
result = json.dumps(result, ensure_ascii=False)
|
||||
except Exception:
|
||||
result = str(result)
|
||||
|
||||
tool_info = {
|
||||
"name": tool_name,
|
||||
"arguments": arguments,
|
||||
"result": result,
|
||||
"status": status,
|
||||
"elapsed": elapsed_str,
|
||||
}
|
||||
|
||||
if state.pending_tool_results is not None:
|
||||
state.pending_tool_results.append(tool_info)
|
||||
|
||||
elif event_type == "turn_end":
|
||||
has_tool_calls = data.get("has_tool_calls", False)
|
||||
if has_tool_calls and state.pending_tool_results:
|
||||
# Flush collected tool results as a single tool_calls chunk
|
||||
send_chunk_fn({
|
||||
"chunk_type": "tool_calls",
|
||||
"tool_calls": state.pending_tool_results,
|
||||
})
|
||||
state.pending_tool_results = None
|
||||
# Next content belongs to a new segment
|
||||
state.segment_id += 1
|
||||
|
||||
# Run the agent with our event callback ---------------------------
|
||||
logger.info(f"[ChatService] Starting agent run: session={session_id}, query={query[:80]}")
|
||||
|
||||
from config import conf
|
||||
max_context_turns = conf().get("agent_max_context_turns", 20)
|
||||
|
||||
# Get full system prompt with skills
|
||||
full_system_prompt = agent.get_full_system_prompt()
|
||||
|
||||
# Create a copy of messages for this execution
|
||||
with agent.messages_lock:
|
||||
messages_copy = agent.messages.copy()
|
||||
original_length = len(agent.messages)
|
||||
|
||||
from agent.protocol.agent_stream import AgentStreamExecutor
|
||||
|
||||
executor = AgentStreamExecutor(
|
||||
agent=agent,
|
||||
model=agent.model,
|
||||
system_prompt=full_system_prompt,
|
||||
tools=agent.tools,
|
||||
max_turns=agent.max_steps,
|
||||
on_event=on_event,
|
||||
messages=messages_copy,
|
||||
max_context_turns=max_context_turns,
|
||||
)
|
||||
|
||||
try:
|
||||
response = executor.run_stream(query)
|
||||
except Exception:
|
||||
# If executor cleared messages (context overflow), sync back
|
||||
if len(executor.messages) == 0:
|
||||
with agent.messages_lock:
|
||||
agent.messages.clear()
|
||||
logger.info("[ChatService] Cleared agent message history after executor recovery")
|
||||
raise
|
||||
|
||||
# Sync executor messages back to agent (thread-safe).
|
||||
# The executor may have trimmed context, making its list shorter than
|
||||
# original_length. In that case we must replace entirely — just
|
||||
# appending would leave stale pre-trim messages in agent.messages
|
||||
# and cause the same trim to fire on every subsequent request.
|
||||
with agent.messages_lock:
|
||||
trimmed = len(executor.messages) < original_length
|
||||
if trimmed:
|
||||
# Context was trimmed: the executor appended the new user
|
||||
# query *before* trimming, so the new messages (user +
|
||||
# assistant + tools) sit at the tail of the trimmed list.
|
||||
# We cannot simply slice at original_length (it exceeds the
|
||||
# list length). Instead, count how many messages the
|
||||
# executor added on top of the post-trim baseline.
|
||||
#
|
||||
# Timeline inside executor.run_stream:
|
||||
# 1. messages had `original_length` items
|
||||
# 2. append user query → original_length + 1
|
||||
# 3. _trim_messages() → some smaller number (includes the
|
||||
# user query because it belongs to the last turn)
|
||||
# 4. LLM replies / tool calls appended
|
||||
#
|
||||
# The user query message is always the first message of the
|
||||
# last turn (it cannot be trimmed away), so we locate it to
|
||||
# find where "new" messages begin.
|
||||
new_start = original_length # fallback
|
||||
for idx in range(len(executor.messages) - 1, -1, -1):
|
||||
msg = executor.messages[idx]
|
||||
if msg.get("role") == "user":
|
||||
content = msg.get("content", [])
|
||||
is_user_query = False
|
||||
if isinstance(content, list):
|
||||
has_text = any(
|
||||
isinstance(b, dict) and b.get("type") == "text"
|
||||
for b in content
|
||||
)
|
||||
has_tool_result = any(
|
||||
isinstance(b, dict) and b.get("type") == "tool_result"
|
||||
for b in content
|
||||
)
|
||||
is_user_query = has_text and not has_tool_result
|
||||
elif isinstance(content, str):
|
||||
is_user_query = True
|
||||
if is_user_query:
|
||||
new_start = idx
|
||||
break
|
||||
new_messages = list(executor.messages[new_start:])
|
||||
else:
|
||||
new_messages = list(executor.messages[original_length:])
|
||||
agent.messages = list(executor.messages)
|
||||
|
||||
# Persist new messages to SQLite so they survive restarts and
|
||||
# can be queried via the HISTORY interface.
|
||||
if new_messages:
|
||||
self._persist_messages(session_id, list(new_messages), channel_type)
|
||||
|
||||
# Store executor reference for files_to_send access
|
||||
agent.stream_executor = executor
|
||||
|
||||
# Execute post-process tools
|
||||
agent._execute_post_process_tools()
|
||||
|
||||
logger.info(f"[ChatService] Agent run completed: session={session_id}")
|
||||
|
||||
|
||||
|
||||
@staticmethod
|
||||
def _persist_messages(session_id: str, new_messages: list, channel_type: str = ""):
|
||||
try:
|
||||
from config import conf
|
||||
if not conf().get("conversation_persistence", True):
|
||||
return
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
from agent.memory import get_conversation_store
|
||||
get_conversation_store().append_messages(
|
||||
session_id, new_messages, channel_type=channel_type
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"[ChatService] Failed to persist messages for session={session_id}: {e}"
|
||||
)
|
||||
|
||||
|
||||
class _StreamState:
|
||||
"""Mutable state shared between the event callback and the run method."""
|
||||
|
||||
def __init__(self):
|
||||
self.segment_id: int = 0
|
||||
# None means we are not accumulating tool results right now.
|
||||
# A list means we are in the middle of a tool-execution phase.
|
||||
self.pending_tool_results: Optional[list] = None
|
||||
# Maps tool_call_id -> arguments captured from tool_execution_start,
|
||||
# so that tool_execution_end can attach the correct input args.
|
||||
self.pending_tool_arguments: dict = {}
|
||||
241
agent/chat/session_service.py
Normal file
241
agent/chat/session_service.py
Normal file
@@ -0,0 +1,241 @@
|
||||
"""
|
||||
SessionService - Manages multi-session lifecycle for both web channel and cloud client.
|
||||
|
||||
Provides a unified interface for listing, deleting, renaming, clearing context,
|
||||
and generating AI titles for conversation sessions. Backed by ConversationStore
|
||||
(SQLite) and AgentBridge (in-memory agent instances).
|
||||
"""
|
||||
|
||||
import re
|
||||
from typing import Optional
|
||||
|
||||
from common.log import logger
|
||||
|
||||
|
||||
def _truncate_fallback_title(user_message: str, max_len: int = 30) -> str:
|
||||
"""Pick the first non-empty line of the user message and truncate it."""
|
||||
if not user_message:
|
||||
return "New Chat"
|
||||
first_line = ""
|
||||
for line in user_message.splitlines():
|
||||
line = line.strip()
|
||||
if line:
|
||||
first_line = line
|
||||
break
|
||||
if not first_line:
|
||||
return "New Chat"
|
||||
if len(first_line) > max_len:
|
||||
first_line = first_line[:max_len].rstrip() + "..."
|
||||
return first_line
|
||||
|
||||
|
||||
def generate_session_title(user_message: str, assistant_reply: str = "") -> str:
|
||||
"""
|
||||
Generate a short session title by calling the current bot's reply_text.
|
||||
Falls back to the first line of the user message if the LLM call fails
|
||||
or returns an obvious error sentinel.
|
||||
"""
|
||||
fallback = _truncate_fallback_title(user_message)
|
||||
try:
|
||||
from bridge.bridge import Bridge
|
||||
from models.session_manager import Session
|
||||
bot = Bridge().get_bot("chat")
|
||||
|
||||
prompt_parts = [f"User: {user_message[:300]}"]
|
||||
if assistant_reply:
|
||||
prompt_parts.append(f"Assistant: {assistant_reply[:300]}")
|
||||
|
||||
session = Session("__title_gen__", system_prompt="")
|
||||
session.messages = [
|
||||
{"role": "user", "content": (
|
||||
"Generate a very short title (max 15 characters for Chinese, max 6 words for English) "
|
||||
"summarizing this conversation. Return ONLY the title text, nothing else.\n\n"
|
||||
+ "\n".join(prompt_parts)
|
||||
)}
|
||||
]
|
||||
|
||||
result = bot.reply_text(session) or {}
|
||||
# When bots fail (network error, auth error, rate limit, etc.) they
|
||||
# typically return completion_tokens=0 with a sentinel content like
|
||||
# "请再问我一次吧" / "我现在有点累了". Treat that as failure.
|
||||
completion_tokens = result.get("completion_tokens", 0) or 0
|
||||
raw = (result.get("content") or "").strip()
|
||||
if completion_tokens <= 0:
|
||||
logger.warning(
|
||||
f"[SessionService] Title generation got empty completion "
|
||||
f"(completion_tokens={completion_tokens}, content='{raw[:50]}'), "
|
||||
f"using fallback")
|
||||
return fallback
|
||||
|
||||
title = re.sub(r'<think>.*?</think>', '', raw, flags=re.DOTALL).strip().strip('"\'')
|
||||
logger.info(f"[SessionService] Title generation result: '{title}' (len={len(title)})")
|
||||
if title and len(title) <= 50:
|
||||
return title
|
||||
except Exception as e:
|
||||
logger.warning(f"[SessionService] Title generation failed: {e}")
|
||||
return fallback
|
||||
|
||||
|
||||
class SessionService:
|
||||
"""
|
||||
High-level service for session lifecycle management.
|
||||
|
||||
Usage:
|
||||
svc = SessionService()
|
||||
result = svc.dispatch("list", {"channel_type": "web", "page": 1})
|
||||
"""
|
||||
|
||||
def _get_store(self):
|
||||
from agent.memory import get_conversation_store
|
||||
return get_conversation_store()
|
||||
|
||||
def _remove_agent(self, session_id: str):
|
||||
"""Remove the in-memory Agent instance for a session if it exists."""
|
||||
try:
|
||||
from bridge.bridge import Bridge
|
||||
ab = Bridge().get_agent_bridge()
|
||||
if session_id in ab.agents:
|
||||
del ab.agents[session_id]
|
||||
logger.info(f"[SessionService] Removed agent instance: {session_id}")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def _normalize_sid(session_id: str) -> str:
|
||||
if session_id and not session_id.startswith("session_"):
|
||||
return f"session_{session_id}"
|
||||
return session_id
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# actions
|
||||
# ------------------------------------------------------------------
|
||||
def list_sessions(self, channel_type: Optional[str] = None,
|
||||
page: int = 1, page_size: int = 50) -> dict:
|
||||
store = self._get_store()
|
||||
return store.list_sessions(
|
||||
channel_type=channel_type,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
|
||||
def delete_session(self, session_id: str) -> None:
|
||||
if not session_id:
|
||||
raise ValueError("session_id required")
|
||||
session_id = self._normalize_sid(session_id)
|
||||
|
||||
store = self._get_store()
|
||||
store.clear_session(session_id)
|
||||
self._remove_agent(session_id)
|
||||
logger.info(f"[SessionService] Session deleted: {session_id}")
|
||||
|
||||
def rename_session(self, session_id: str, title: str) -> None:
|
||||
if not session_id:
|
||||
raise ValueError("session_id required")
|
||||
if not title:
|
||||
raise ValueError("title required")
|
||||
session_id = self._normalize_sid(session_id)
|
||||
|
||||
store = self._get_store()
|
||||
found = store.rename_session(session_id, title)
|
||||
if not found:
|
||||
raise ValueError("session not found")
|
||||
|
||||
def clear_context(self, session_id: str) -> int:
|
||||
"""
|
||||
Set context boundary. Returns the new context_start_seq value.
|
||||
"""
|
||||
if not session_id:
|
||||
raise ValueError("session_id required")
|
||||
session_id = self._normalize_sid(session_id)
|
||||
|
||||
store = self._get_store()
|
||||
new_seq = store.clear_context(session_id)
|
||||
self._remove_agent(session_id)
|
||||
return new_seq
|
||||
|
||||
def gen_title(self, session_id: str, user_message: str,
|
||||
assistant_reply: str = "") -> str:
|
||||
"""
|
||||
Generate an AI title and persist it. Returns the generated title.
|
||||
"""
|
||||
if not session_id:
|
||||
raise ValueError("session_id required")
|
||||
if not user_message:
|
||||
raise ValueError("user_message required")
|
||||
session_id = self._normalize_sid(session_id)
|
||||
|
||||
title = generate_session_title(user_message, assistant_reply)
|
||||
|
||||
store = self._get_store()
|
||||
updated = store.rename_session(session_id, title)
|
||||
logger.info(f"[SessionService] Title set: sid={session_id}, "
|
||||
f"title='{title}', db_updated={updated}")
|
||||
return title
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# dispatch — single entry point for protocol messages
|
||||
# ------------------------------------------------------------------
|
||||
def dispatch(self, action: str, payload: Optional[dict] = None) -> dict:
|
||||
"""
|
||||
Dispatch a session management action and return a protocol-compatible
|
||||
response dict.
|
||||
|
||||
Action names use a ``*_session`` / session-prefixed convention so they
|
||||
can coexist with history actions (e.g. ``query``) on the same HISTORY
|
||||
message channel without ambiguity.
|
||||
|
||||
Supported actions:
|
||||
- list_sessions: list sessions with pagination
|
||||
- delete_session: delete a session
|
||||
- rename_session: rename a session title
|
||||
- clear_context: set context boundary
|
||||
- generate_title: AI-generate a session title
|
||||
|
||||
:param action: one of the above action names
|
||||
:param payload: action-specific payload
|
||||
:return: dict with action, code, message, payload
|
||||
"""
|
||||
payload = payload or {}
|
||||
try:
|
||||
if action == "list_sessions":
|
||||
result = self.list_sessions(
|
||||
channel_type=payload.get("channel_type"),
|
||||
page=int(payload.get("page", 1)),
|
||||
page_size=int(payload.get("page_size", 50)),
|
||||
)
|
||||
return {"action": action, "code": 200, "message": "success", "payload": result}
|
||||
|
||||
elif action == "delete_session":
|
||||
self.delete_session(payload.get("session_id", ""))
|
||||
return {"action": action, "code": 200, "message": "success", "payload": None}
|
||||
|
||||
elif action == "rename_session":
|
||||
self.rename_session(
|
||||
payload.get("session_id", ""),
|
||||
payload.get("title", "").strip(),
|
||||
)
|
||||
return {"action": action, "code": 200, "message": "success", "payload": None}
|
||||
|
||||
elif action == "clear_context":
|
||||
new_seq = self.clear_context(payload.get("session_id", ""))
|
||||
return {"action": action, "code": 200, "message": "success",
|
||||
"payload": {"context_start_seq": new_seq}}
|
||||
|
||||
elif action == "generate_title":
|
||||
title = self.gen_title(
|
||||
payload.get("session_id", ""),
|
||||
payload.get("user_message", ""),
|
||||
payload.get("assistant_reply", ""),
|
||||
)
|
||||
return {"action": action, "code": 200, "message": "success",
|
||||
"payload": {"title": title}}
|
||||
|
||||
else:
|
||||
return {"action": action, "code": 400,
|
||||
"message": f"unknown action: {action}", "payload": None}
|
||||
|
||||
except ValueError as e:
|
||||
return {"action": action, "code": 400, "message": str(e), "payload": None}
|
||||
except Exception as e:
|
||||
logger.error(f"[SessionService] dispatch error: action={action}, error={e}")
|
||||
return {"action": action, "code": 500, "message": str(e), "payload": None}
|
||||
0
agent/knowledge/__init__.py
Normal file
0
agent/knowledge/__init__.py
Normal file
240
agent/knowledge/service.py
Normal file
240
agent/knowledge/service.py
Normal file
@@ -0,0 +1,240 @@
|
||||
"""
|
||||
Knowledge service for handling knowledge base operations.
|
||||
|
||||
Provides a unified interface for listing, reading, and graphing knowledge files,
|
||||
callable from the web console, API, or CLI.
|
||||
|
||||
Knowledge file layout (under workspace_root):
|
||||
knowledge/index.md
|
||||
knowledge/log.md
|
||||
knowledge/<category>/<slug>.md
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from common.log import logger
|
||||
from config import conf
|
||||
|
||||
|
||||
class KnowledgeService:
|
||||
"""
|
||||
High-level service for knowledge base queries.
|
||||
Operates directly on the filesystem.
|
||||
"""
|
||||
|
||||
def __init__(self, workspace_root: str):
|
||||
self.workspace_root = workspace_root
|
||||
self.knowledge_dir = os.path.join(workspace_root, "knowledge")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# list — directory tree with stats
|
||||
# ------------------------------------------------------------------
|
||||
def list_tree(self) -> dict:
|
||||
"""
|
||||
Return the knowledge directory tree grouped by category,
|
||||
supporting arbitrarily nested sub-directories.
|
||||
|
||||
Returns::
|
||||
|
||||
{
|
||||
"tree": [
|
||||
{
|
||||
"dir": "concepts",
|
||||
"files": [
|
||||
{"name": "moe.md", "title": "MoE", "size": 1234},
|
||||
],
|
||||
"children": []
|
||||
},
|
||||
{
|
||||
"dir": "platform",
|
||||
"files": [],
|
||||
"children": [
|
||||
{
|
||||
"dir": "analysis",
|
||||
"files": [{"name": "perf.md", ...}],
|
||||
"children": []
|
||||
}
|
||||
]
|
||||
},
|
||||
],
|
||||
"stats": {"pages": 15, "size": 32768},
|
||||
"enabled": true
|
||||
}
|
||||
"""
|
||||
if not os.path.isdir(self.knowledge_dir):
|
||||
return {"tree": [], "stats": {"pages": 0, "size": 0}, "enabled": conf().get("knowledge", True)}
|
||||
|
||||
stats = {"pages": 0, "size": 0}
|
||||
root_files, tree = self._scan_dir(self.knowledge_dir, stats, is_root=True)
|
||||
|
||||
return {
|
||||
"root_files": root_files,
|
||||
"tree": tree,
|
||||
"stats": stats,
|
||||
"enabled": conf().get("knowledge", True),
|
||||
}
|
||||
|
||||
def _scan_dir(self, dir_path: str, stats: dict, is_root: bool = False) -> tuple:
|
||||
"""
|
||||
Recursively scan a directory.
|
||||
|
||||
:return: (files, children) where files is a list of .md file dicts
|
||||
in this directory and children is a list of sub-directory nodes.
|
||||
"""
|
||||
files = []
|
||||
children = []
|
||||
for name in sorted(os.listdir(dir_path)):
|
||||
if name.startswith("."):
|
||||
continue
|
||||
full = os.path.join(dir_path, name)
|
||||
if os.path.isdir(full):
|
||||
sub_files, sub_children = self._scan_dir(full, stats)
|
||||
children.append({"dir": name, "files": sub_files, "children": sub_children})
|
||||
elif name.endswith(".md"):
|
||||
size = os.path.getsize(full)
|
||||
if not is_root:
|
||||
stats["pages"] += 1
|
||||
stats["size"] += size
|
||||
title = name.replace(".md", "")
|
||||
try:
|
||||
with open(full, "r", encoding="utf-8") as f:
|
||||
first_line = f.readline().strip()
|
||||
if first_line.startswith("# "):
|
||||
title = first_line[2:].strip()
|
||||
except Exception:
|
||||
pass
|
||||
files.append({"name": name, "title": title, "size": size})
|
||||
return files, children
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# read — single file content
|
||||
# ------------------------------------------------------------------
|
||||
def read_file(self, rel_path: str) -> dict:
|
||||
"""
|
||||
Read a single knowledge markdown file.
|
||||
|
||||
:param rel_path: Relative path within knowledge/, e.g. ``concepts/moe.md``
|
||||
:return: dict with ``content`` and ``path``
|
||||
:raises ValueError: if path is invalid or escapes knowledge dir
|
||||
:raises FileNotFoundError: if file does not exist
|
||||
"""
|
||||
if not rel_path or ".." in rel_path:
|
||||
raise ValueError("invalid path")
|
||||
|
||||
full_path = os.path.normpath(os.path.join(self.knowledge_dir, rel_path))
|
||||
allowed = os.path.normpath(self.knowledge_dir)
|
||||
if not full_path.startswith(allowed + os.sep) and full_path != allowed:
|
||||
raise ValueError("path outside knowledge dir")
|
||||
|
||||
if not os.path.isfile(full_path):
|
||||
raise FileNotFoundError(f"file not found: {rel_path}")
|
||||
|
||||
with open(full_path, "r", encoding="utf-8") as f:
|
||||
content = f.read()
|
||||
return {"content": content, "path": rel_path}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# graph — nodes and links for visualization
|
||||
# ------------------------------------------------------------------
|
||||
def build_graph(self) -> dict:
|
||||
"""
|
||||
Parse all knowledge pages and extract cross-reference links.
|
||||
|
||||
Returns::
|
||||
|
||||
{
|
||||
"nodes": [
|
||||
{"id": "concepts/moe.md", "label": "MoE", "category": "concepts"},
|
||||
...
|
||||
],
|
||||
"links": [
|
||||
{"source": "concepts/moe.md", "target": "entities/deepseek.md"},
|
||||
...
|
||||
]
|
||||
}
|
||||
"""
|
||||
knowledge_path = Path(self.knowledge_dir)
|
||||
if not knowledge_path.is_dir():
|
||||
return {"nodes": [], "links": []}
|
||||
|
||||
nodes = {}
|
||||
links = []
|
||||
link_re = re.compile(r'\[([^\]]*)\]\(([^)]+\.md)\)')
|
||||
|
||||
for md_file in knowledge_path.rglob("*.md"):
|
||||
rel = str(md_file.relative_to(knowledge_path))
|
||||
if rel in ("index.md", "log.md"):
|
||||
continue
|
||||
parts = rel.split("/")
|
||||
category = parts[0] if len(parts) > 1 else "root"
|
||||
title = md_file.stem.replace("-", " ").title()
|
||||
try:
|
||||
content = md_file.read_text(encoding="utf-8")
|
||||
first_line = content.strip().split("\n")[0]
|
||||
if first_line.startswith("# "):
|
||||
title = first_line[2:].strip()
|
||||
for _, link_target in link_re.findall(content):
|
||||
resolved = (md_file.parent / link_target).resolve()
|
||||
try:
|
||||
target_rel = str(resolved.relative_to(knowledge_path))
|
||||
except ValueError:
|
||||
continue
|
||||
if target_rel != rel:
|
||||
links.append({"source": rel, "target": target_rel})
|
||||
except Exception:
|
||||
pass
|
||||
nodes[rel] = {"id": rel, "label": title, "category": category}
|
||||
|
||||
valid_ids = set(nodes.keys())
|
||||
links = [l for l in links if l["source"] in valid_ids and l["target"] in valid_ids]
|
||||
seen = set()
|
||||
deduped = []
|
||||
for l in links:
|
||||
key = tuple(sorted([l["source"], l["target"]]))
|
||||
if key not in seen:
|
||||
seen.add(key)
|
||||
deduped.append(l)
|
||||
|
||||
return {"nodes": list(nodes.values()), "links": deduped}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# dispatch — single entry point for protocol messages
|
||||
# ------------------------------------------------------------------
|
||||
def dispatch(self, action: str, payload: Optional[dict] = None) -> dict:
|
||||
"""
|
||||
Dispatch a knowledge management action.
|
||||
|
||||
:param action: ``list``, ``read``, or ``graph``
|
||||
:param payload: action-specific payload
|
||||
:return: protocol-compatible response dict
|
||||
"""
|
||||
payload = payload or {}
|
||||
try:
|
||||
if action == "list":
|
||||
result = self.list_tree()
|
||||
return {"action": action, "code": 200, "message": "success", "payload": result}
|
||||
|
||||
elif action == "read":
|
||||
path = payload.get("path")
|
||||
if not path:
|
||||
return {"action": action, "code": 400, "message": "path is required", "payload": None}
|
||||
result = self.read_file(path)
|
||||
return {"action": action, "code": 200, "message": "success", "payload": result}
|
||||
|
||||
elif action == "graph":
|
||||
result = self.build_graph()
|
||||
return {"action": action, "code": 200, "message": "success", "payload": result}
|
||||
|
||||
else:
|
||||
return {"action": action, "code": 400, "message": f"unknown action: {action}", "payload": None}
|
||||
|
||||
except ValueError as e:
|
||||
return {"action": action, "code": 403, "message": str(e), "payload": None}
|
||||
except FileNotFoundError as e:
|
||||
return {"action": action, "code": 404, "message": str(e), "payload": None}
|
||||
except Exception as e:
|
||||
logger.error(f"[KnowledgeService] dispatch error: action={action}, error={e}")
|
||||
return {"action": action, "code": 500, "message": str(e), "payload": None}
|
||||
@@ -1,11 +1,23 @@
|
||||
"""
|
||||
Memory module for AgentMesh
|
||||
|
||||
Provides long-term memory capabilities with hybrid search (vector + keyword)
|
||||
Provides both long-term memory (vector/keyword search) and short-term
|
||||
conversation history persistence (SQLite).
|
||||
"""
|
||||
|
||||
from agent.memory.manager import MemoryManager
|
||||
from agent.memory.config import MemoryConfig, get_default_memory_config, set_global_memory_config
|
||||
from agent.memory.embedding import create_embedding_provider
|
||||
from agent.memory.conversation_store import ConversationStore, get_conversation_store
|
||||
from agent.memory.summarizer import ensure_daily_memory_file
|
||||
|
||||
__all__ = ['MemoryManager', 'MemoryConfig', 'get_default_memory_config', 'set_global_memory_config', 'create_embedding_provider']
|
||||
__all__ = [
|
||||
'MemoryManager',
|
||||
'MemoryConfig',
|
||||
'get_default_memory_config',
|
||||
'set_global_memory_config',
|
||||
'create_embedding_provider',
|
||||
'ConversationStore',
|
||||
'get_conversation_store',
|
||||
'ensure_daily_memory_file',
|
||||
]
|
||||
|
||||
@@ -48,9 +48,6 @@ class MemoryConfig:
|
||||
enable_auto_sync: bool = True
|
||||
sync_on_search: bool = True
|
||||
|
||||
# Memory flush config (独立于模型 context window)
|
||||
flush_token_threshold: int = 50000 # 50K tokens 触发 flush
|
||||
flush_turn_threshold: int = 20 # 20 轮对话触发 flush (用户+AI各一条为一轮)
|
||||
|
||||
def get_workspace(self) -> Path:
|
||||
"""Get workspace root directory"""
|
||||
|
||||
849
agent/memory/conversation_store.py
Normal file
849
agent/memory/conversation_store.py
Normal file
@@ -0,0 +1,849 @@
|
||||
"""
|
||||
Conversation history persistence using SQLite.
|
||||
|
||||
Design:
|
||||
- sessions table: per-session metadata (channel_type, last_active, msg_count)
|
||||
- messages table: individual messages stored as JSON, append-only
|
||||
- Pruning: age-based only (sessions not updated within N days are deleted)
|
||||
- Thread-safe via a single in-process lock
|
||||
|
||||
Storage path: ~/cow/sessions/conversations.db
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import sqlite3
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from common.log import logger
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Schema
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_DDL = """
|
||||
CREATE TABLE IF NOT EXISTS sessions (
|
||||
session_id TEXT PRIMARY KEY,
|
||||
channel_type TEXT NOT NULL DEFAULT '',
|
||||
title TEXT NOT NULL DEFAULT '',
|
||||
context_start_seq INTEGER NOT NULL DEFAULT 0,
|
||||
created_at INTEGER NOT NULL,
|
||||
last_active INTEGER NOT NULL,
|
||||
msg_count INTEGER NOT NULL DEFAULT 0
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS messages (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
session_id TEXT NOT NULL,
|
||||
seq INTEGER NOT NULL,
|
||||
role TEXT NOT NULL,
|
||||
content TEXT NOT NULL,
|
||||
created_at INTEGER NOT NULL,
|
||||
UNIQUE (session_id, seq)
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_messages_session
|
||||
ON messages (session_id, seq);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_sessions_last_active
|
||||
ON sessions (last_active);
|
||||
"""
|
||||
|
||||
# Migration: add channel_type column to existing databases that predate it.
|
||||
_MIGRATION_ADD_CHANNEL_TYPE = """
|
||||
ALTER TABLE sessions ADD COLUMN channel_type TEXT NOT NULL DEFAULT '';
|
||||
"""
|
||||
|
||||
_MIGRATION_ADD_TITLE = """
|
||||
ALTER TABLE sessions ADD COLUMN title TEXT NOT NULL DEFAULT '';
|
||||
"""
|
||||
|
||||
_MIGRATION_ADD_CONTEXT_START_SEQ = """
|
||||
ALTER TABLE sessions ADD COLUMN context_start_seq INTEGER NOT NULL DEFAULT 0;
|
||||
"""
|
||||
|
||||
DEFAULT_MAX_AGE_DAYS: int = 30
|
||||
|
||||
|
||||
def _is_visible_user_message(content: Any) -> bool:
|
||||
"""
|
||||
Return True when a user-role message represents actual user input
|
||||
(not an internal tool_result injected by the agent loop).
|
||||
"""
|
||||
if isinstance(content, str):
|
||||
return bool(content.strip())
|
||||
if isinstance(content, list):
|
||||
return any(
|
||||
isinstance(b, dict) and b.get("type") == "text"
|
||||
for b in content
|
||||
)
|
||||
return False
|
||||
|
||||
|
||||
def _extract_display_text(content: Any) -> str:
|
||||
"""
|
||||
Extract the human-readable text portion from a message content value.
|
||||
Returns an empty string for tool_use / tool_result blocks.
|
||||
"""
|
||||
if isinstance(content, str):
|
||||
return content.strip()
|
||||
if isinstance(content, list):
|
||||
parts = [
|
||||
b.get("text", "")
|
||||
for b in content
|
||||
if isinstance(b, dict) and b.get("type") == "text"
|
||||
]
|
||||
return "\n".join(p for p in parts if p).strip()
|
||||
return ""
|
||||
|
||||
|
||||
def _extract_tool_calls(content: Any) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Extract tool_use blocks from an assistant message content.
|
||||
Returns a list of {name, arguments} dicts (result filled in later).
|
||||
"""
|
||||
if not isinstance(content, list):
|
||||
return []
|
||||
return [
|
||||
{"id": b.get("id", ""), "name": b.get("name", ""), "arguments": b.get("input", {})}
|
||||
for b in content
|
||||
if isinstance(b, dict) and b.get("type") == "tool_use"
|
||||
]
|
||||
|
||||
|
||||
def _extract_tool_results(content: Any) -> Dict[str, str]:
|
||||
"""
|
||||
Extract tool_result blocks from a user message, keyed by tool_use_id.
|
||||
"""
|
||||
if not isinstance(content, list):
|
||||
return {}
|
||||
results = {}
|
||||
for b in content:
|
||||
if not isinstance(b, dict) or b.get("type") != "tool_result":
|
||||
continue
|
||||
tool_id = b.get("tool_use_id", "")
|
||||
result_content = b.get("content", "")
|
||||
if isinstance(result_content, list):
|
||||
result_content = "\n".join(
|
||||
rb.get("text", "") for rb in result_content
|
||||
if isinstance(rb, dict) and rb.get("type") == "text"
|
||||
)
|
||||
results[tool_id] = str(result_content)
|
||||
return results
|
||||
|
||||
|
||||
def _group_into_display_turns(
|
||||
rows: List[tuple],
|
||||
include_thinking: bool = True,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Convert raw (role, content_json, created_at) DB rows into display turns.
|
||||
|
||||
One display turn = one visible user message + one merged assistant reply.
|
||||
All intermediate assistant messages (those carrying tool_use) and the final
|
||||
assistant text reply produced for the same user query are collapsed into a
|
||||
single assistant turn, exactly matching the live SSE rendering where tools
|
||||
and the final answer appear inside the same bubble.
|
||||
|
||||
Grouping rules:
|
||||
- A visible user message starts a new group.
|
||||
- tool_result user messages are internal; their content is attached to the
|
||||
matching tool_use entry via tool_use_id and they never become own turns.
|
||||
- All assistant messages within a group are merged:
|
||||
* tool_use blocks → tool_calls list (result filled from tool_results)
|
||||
* text blocks → last non-empty text becomes the display content
|
||||
"""
|
||||
# ------------------------------------------------------------------ #
|
||||
# Pass 1: split rows into groups, each starting with a visible user msg
|
||||
# ------------------------------------------------------------------ #
|
||||
# group = (user_row | None, [subsequent_rows])
|
||||
# user_row: (content, created_at)
|
||||
groups: List[tuple] = []
|
||||
cur_user: Optional[tuple] = None
|
||||
cur_rest: List[tuple] = []
|
||||
started = False
|
||||
|
||||
for role, raw_content, created_at in rows:
|
||||
try:
|
||||
content = json.loads(raw_content)
|
||||
except Exception:
|
||||
content = raw_content
|
||||
|
||||
if role == "user" and _is_visible_user_message(content):
|
||||
if started:
|
||||
groups.append((cur_user, cur_rest))
|
||||
cur_user = (content, created_at)
|
||||
cur_rest = []
|
||||
started = True
|
||||
else:
|
||||
cur_rest.append((role, content, created_at))
|
||||
|
||||
if started:
|
||||
groups.append((cur_user, cur_rest))
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# Pass 2: build display turns from each group
|
||||
# ------------------------------------------------------------------ #
|
||||
turns: List[Dict[str, Any]] = []
|
||||
|
||||
for user_row, rest in groups:
|
||||
# User turn
|
||||
if user_row:
|
||||
content, created_at = user_row
|
||||
text = _extract_display_text(content)
|
||||
if text:
|
||||
turns.append({"role": "user", "content": text, "created_at": created_at})
|
||||
|
||||
# Build an ordered list of steps preserving the original sequence:
|
||||
# thinking → content → tool_call → content → ...
|
||||
steps: List[Dict[str, Any]] = []
|
||||
tool_results: Dict[str, str] = {}
|
||||
final_text = ""
|
||||
final_ts: Optional[int] = None
|
||||
|
||||
for role, content, created_at in rest:
|
||||
if role == "user":
|
||||
tool_results.update(_extract_tool_results(content))
|
||||
elif role == "assistant":
|
||||
# Walk content blocks in order to preserve interleaving
|
||||
if isinstance(content, list):
|
||||
for block in content:
|
||||
if not isinstance(block, dict):
|
||||
continue
|
||||
btype = block.get("type")
|
||||
if btype == "thinking":
|
||||
if not include_thinking:
|
||||
continue
|
||||
txt = block.get("thinking", "").strip()
|
||||
if txt:
|
||||
steps.append({"type": "thinking", "content": txt})
|
||||
elif btype == "text":
|
||||
txt = block.get("text", "").strip()
|
||||
if txt:
|
||||
steps.append({"type": "content", "content": txt})
|
||||
final_text = txt
|
||||
elif btype == "tool_use":
|
||||
steps.append({
|
||||
"type": "tool",
|
||||
"id": block.get("id", ""),
|
||||
"name": block.get("name", ""),
|
||||
"arguments": block.get("input", {}),
|
||||
})
|
||||
elif isinstance(content, str) and content.strip():
|
||||
steps.append({"type": "content", "content": content.strip()})
|
||||
final_text = content.strip()
|
||||
final_ts = created_at
|
||||
|
||||
# Attach tool results to tool steps
|
||||
for step in steps:
|
||||
if step["type"] == "tool":
|
||||
step["result"] = tool_results.get(step.get("id", ""), "")
|
||||
|
||||
if steps or final_text:
|
||||
turn = {
|
||||
"role": "assistant",
|
||||
"content": final_text,
|
||||
"steps": steps,
|
||||
"created_at": final_ts or (user_row[1] if user_row else 0),
|
||||
}
|
||||
turns.append(turn)
|
||||
|
||||
return turns
|
||||
|
||||
|
||||
class ConversationStore:
|
||||
"""
|
||||
SQLite-backed store for per-session conversation history.
|
||||
|
||||
Usage:
|
||||
store = ConversationStore(db_path)
|
||||
store.append_messages("user_123", new_messages, channel_type="feishu")
|
||||
msgs = store.load_messages("user_123", max_turns=30)
|
||||
"""
|
||||
|
||||
def __init__(self, db_path: Path):
|
||||
self._db_path = db_path
|
||||
self._lock = threading.Lock()
|
||||
self._init_db()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Public API
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def load_messages(
|
||||
self,
|
||||
session_id: str,
|
||||
max_turns: int = 30,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Load the most recent messages for a session, for injection into the LLM.
|
||||
|
||||
ALL message types (user text, assistant tool_use, tool_result) are returned
|
||||
in their original JSON form so the LLM can reconstruct the full context.
|
||||
|
||||
max_turns is a *visible-turn* count: we count only user messages whose
|
||||
content is actual user text (not tool_result blocks). This prevents
|
||||
tool-heavy sessions from exhausting the turn budget prematurely.
|
||||
|
||||
Args:
|
||||
session_id: Unique session identifier.
|
||||
max_turns: Maximum number of visible user-assistant turns to keep.
|
||||
|
||||
Returns:
|
||||
Chronologically ordered list of message dicts (role, content).
|
||||
"""
|
||||
with self._lock:
|
||||
conn = self._connect()
|
||||
try:
|
||||
# Respect context_start_seq: only load messages at or after the boundary
|
||||
ctx_row = conn.execute(
|
||||
"SELECT context_start_seq FROM sessions WHERE session_id = ?",
|
||||
(session_id,),
|
||||
).fetchone()
|
||||
ctx_start = ctx_row[0] if ctx_row else 0
|
||||
|
||||
rows = conn.execute(
|
||||
"""
|
||||
SELECT seq, role, content
|
||||
FROM messages
|
||||
WHERE session_id = ? AND seq >= ?
|
||||
ORDER BY seq DESC
|
||||
""",
|
||||
(session_id, ctx_start),
|
||||
).fetchall()
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
if not rows:
|
||||
return []
|
||||
|
||||
visible_turn_seqs: List[int] = []
|
||||
for seq, role, raw_content in rows:
|
||||
if role != "user":
|
||||
continue
|
||||
try:
|
||||
content = json.loads(raw_content)
|
||||
except Exception:
|
||||
content = raw_content
|
||||
if _is_visible_user_message(content):
|
||||
visible_turn_seqs.append(seq)
|
||||
|
||||
if len(visible_turn_seqs) <= max_turns:
|
||||
cutoff_seq = None
|
||||
else:
|
||||
cutoff_seq = visible_turn_seqs[max_turns - 1]
|
||||
|
||||
result = []
|
||||
for seq, role, raw_content in reversed(rows):
|
||||
if cutoff_seq is not None and seq < cutoff_seq:
|
||||
continue
|
||||
try:
|
||||
content = json.loads(raw_content)
|
||||
except Exception:
|
||||
content = raw_content
|
||||
# Strip thinking blocks — they are stored for UI display only
|
||||
if role == "assistant" and isinstance(content, list):
|
||||
content = [b for b in content if b.get("type") != "thinking"]
|
||||
result.append({"role": role, "content": content})
|
||||
return result
|
||||
|
||||
def append_messages(
|
||||
self,
|
||||
session_id: str,
|
||||
messages: List[Dict[str, Any]],
|
||||
channel_type: str = "",
|
||||
) -> None:
|
||||
"""
|
||||
Append new messages to a session's history.
|
||||
|
||||
Seq numbers continue from the session's current maximum, so
|
||||
concurrent callers on distinct sessions never collide.
|
||||
|
||||
Args:
|
||||
session_id: Unique session identifier.
|
||||
messages: List of message dicts to append.
|
||||
channel_type: Source channel (e.g. "feishu", "web", "wechat").
|
||||
Only written on session creation; ignored on update.
|
||||
"""
|
||||
if not messages:
|
||||
return
|
||||
|
||||
now = int(time.time())
|
||||
with self._lock:
|
||||
conn = self._connect()
|
||||
try:
|
||||
with conn:
|
||||
# INSERT OR IGNORE creates the row on first visit;
|
||||
# the UPDATE always refreshes last_active.
|
||||
# Avoids ON CONFLICT...DO UPDATE (requires SQLite >= 3.24).
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT OR IGNORE INTO sessions
|
||||
(session_id, channel_type, created_at, last_active, msg_count)
|
||||
VALUES (?, ?, ?, ?, 0)
|
||||
""",
|
||||
(session_id, channel_type, now, now),
|
||||
)
|
||||
conn.execute(
|
||||
"UPDATE sessions SET last_active = ? WHERE session_id = ?",
|
||||
(now, session_id),
|
||||
)
|
||||
|
||||
# Determine starting seq for the new batch.
|
||||
row = conn.execute(
|
||||
"SELECT COALESCE(MAX(seq), -1) FROM messages WHERE session_id = ?",
|
||||
(session_id,),
|
||||
).fetchone()
|
||||
next_seq = row[0] + 1
|
||||
|
||||
for msg in messages:
|
||||
role = msg.get("role", "")
|
||||
content = json.dumps(
|
||||
msg.get("content", ""), ensure_ascii=False
|
||||
)
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT OR IGNORE INTO messages
|
||||
(session_id, seq, role, content, created_at)
|
||||
VALUES (?, ?, ?, ?, ?)
|
||||
""",
|
||||
(session_id, next_seq, role, content, now),
|
||||
)
|
||||
next_seq += 1
|
||||
|
||||
conn.execute(
|
||||
"""
|
||||
UPDATE sessions
|
||||
SET msg_count = (
|
||||
SELECT COUNT(*) FROM messages WHERE session_id = ?
|
||||
)
|
||||
WHERE session_id = ?
|
||||
""",
|
||||
(session_id, session_id),
|
||||
)
|
||||
|
||||
# Auto-generate title from the first visible user message
|
||||
cur_title = conn.execute(
|
||||
"SELECT title FROM sessions WHERE session_id = ?",
|
||||
(session_id,),
|
||||
).fetchone()
|
||||
if cur_title and not cur_title[0]:
|
||||
for msg in messages:
|
||||
if msg.get("role") == "user":
|
||||
content = msg.get("content", "")
|
||||
text = _extract_display_text(content)
|
||||
if text:
|
||||
title = text[:50].split("\n")[0]
|
||||
conn.execute(
|
||||
"UPDATE sessions SET title = ? WHERE session_id = ?",
|
||||
(title, session_id),
|
||||
)
|
||||
break
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def clear_context(self, session_id: str) -> int:
|
||||
"""
|
||||
Set the context boundary to after the current last message.
|
||||
Messages before this boundary are still stored but excluded from LLM context.
|
||||
|
||||
Returns the new context_start_seq value.
|
||||
"""
|
||||
with self._lock:
|
||||
conn = self._connect()
|
||||
try:
|
||||
with conn:
|
||||
row = conn.execute(
|
||||
"SELECT COALESCE(MAX(seq), -1) FROM messages WHERE session_id = ?",
|
||||
(session_id,),
|
||||
).fetchone()
|
||||
new_start = row[0] + 1
|
||||
conn.execute(
|
||||
"UPDATE sessions SET context_start_seq = ? WHERE session_id = ?",
|
||||
(new_start, session_id),
|
||||
)
|
||||
return new_start
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def get_context_start_seq(self, session_id: str) -> int:
|
||||
"""Return the context_start_seq for a session (0 if not set)."""
|
||||
with self._lock:
|
||||
conn = self._connect()
|
||||
try:
|
||||
row = conn.execute(
|
||||
"SELECT context_start_seq FROM sessions WHERE session_id = ?",
|
||||
(session_id,),
|
||||
).fetchone()
|
||||
return row[0] if row else 0
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def clear_session(self, session_id: str) -> None:
|
||||
"""Delete all messages and the session record for a given session_id."""
|
||||
with self._lock:
|
||||
conn = self._connect()
|
||||
try:
|
||||
with conn:
|
||||
conn.execute(
|
||||
"DELETE FROM messages WHERE session_id = ?", (session_id,)
|
||||
)
|
||||
conn.execute(
|
||||
"DELETE FROM sessions WHERE session_id = ?", (session_id,)
|
||||
)
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def cleanup_old_sessions(self, max_age_days: Optional[int] = None) -> int:
|
||||
"""
|
||||
Delete sessions that have not been active within max_age_days.
|
||||
Web channel sessions are excluded — they are meant to be permanent.
|
||||
|
||||
Args:
|
||||
max_age_days: Override the default retention period.
|
||||
|
||||
Returns:
|
||||
Number of sessions deleted.
|
||||
"""
|
||||
try:
|
||||
from config import conf
|
||||
max_age = max_age_days or conf().get(
|
||||
"conversation_max_age_days", DEFAULT_MAX_AGE_DAYS
|
||||
)
|
||||
except Exception:
|
||||
max_age = max_age_days or DEFAULT_MAX_AGE_DAYS
|
||||
|
||||
cutoff = int(time.time()) - max_age * 86400
|
||||
deleted = 0
|
||||
|
||||
with self._lock:
|
||||
conn = self._connect()
|
||||
try:
|
||||
with conn:
|
||||
stale = conn.execute(
|
||||
"SELECT session_id FROM sessions "
|
||||
"WHERE last_active < ? AND channel_type != 'web'",
|
||||
(cutoff,),
|
||||
).fetchall()
|
||||
for (sid,) in stale:
|
||||
conn.execute(
|
||||
"DELETE FROM messages WHERE session_id = ?", (sid,)
|
||||
)
|
||||
conn.execute(
|
||||
"DELETE FROM sessions WHERE session_id = ?", (sid,)
|
||||
)
|
||||
deleted += 1
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
if deleted:
|
||||
logger.info(f"[ConversationStore] Pruned {deleted} expired sessions")
|
||||
return deleted
|
||||
|
||||
def load_history_page(
|
||||
self,
|
||||
session_id: str,
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Load a page of conversation history for UI display, grouped into turns.
|
||||
|
||||
Each "turn" maps to one of:
|
||||
- A user message (role="user", content=str)
|
||||
- An assistant message (role="assistant", content=str,
|
||||
tool_calls=[{name, arguments, result}] when tools were used)
|
||||
|
||||
Internal tool_result user messages are merged into the preceding
|
||||
assistant entry's tool_calls list and never appear as standalone items.
|
||||
|
||||
Pages are numbered from 1 (most recent). Messages within a page are
|
||||
returned in chronological order.
|
||||
|
||||
Returns:
|
||||
{
|
||||
"messages": [
|
||||
{
|
||||
"role": "user" | "assistant",
|
||||
"content": str,
|
||||
"tool_calls": [...], # assistant only, may be []
|
||||
"created_at": int,
|
||||
},
|
||||
...
|
||||
],
|
||||
"total": <visible turn count>,
|
||||
"page": <current page>,
|
||||
"page_size": <page_size>,
|
||||
"has_more": bool,
|
||||
}
|
||||
"""
|
||||
page = max(1, page)
|
||||
with self._lock:
|
||||
conn = self._connect()
|
||||
try:
|
||||
ctx_row = conn.execute(
|
||||
"SELECT context_start_seq FROM sessions WHERE session_id = ?",
|
||||
(session_id,),
|
||||
).fetchone()
|
||||
ctx_start = ctx_row[0] if ctx_row else 0
|
||||
|
||||
rows = conn.execute(
|
||||
"""
|
||||
SELECT seq, role, content, created_at
|
||||
FROM messages
|
||||
WHERE session_id = ?
|
||||
ORDER BY seq ASC
|
||||
""",
|
||||
(session_id,),
|
||||
).fetchall()
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
# Honour the current enable_thinking switch when building display turns
|
||||
# so that toggling it off hides previously-saved thinking blocks too.
|
||||
try:
|
||||
from config import conf
|
||||
include_thinking = bool(conf().get("enable_thinking", False))
|
||||
except Exception:
|
||||
include_thinking = False
|
||||
|
||||
# Strip seq for display grouping, but record max seq per visible user group
|
||||
plain_rows = [(role, content, created_at) for _seq, role, content, created_at in rows]
|
||||
visible = _group_into_display_turns(plain_rows, include_thinking=include_thinking)
|
||||
|
||||
# Build a mapping: find the seq of each visible user message to annotate context boundary.
|
||||
# Walk through rows to find visible user message seqs in order.
|
||||
visible_user_seqs: List[int] = []
|
||||
for seq, role, raw_content, _ts in rows:
|
||||
if role != "user":
|
||||
continue
|
||||
try:
|
||||
content = json.loads(raw_content)
|
||||
except Exception:
|
||||
content = raw_content
|
||||
if _is_visible_user_message(content):
|
||||
visible_user_seqs.append(seq)
|
||||
|
||||
# Each pair of display turns (user+assistant) corresponds to a visible user seq.
|
||||
# Mark which turns are before the context boundary.
|
||||
user_turn_idx = 0
|
||||
for turn in visible:
|
||||
if turn["role"] == "user" and user_turn_idx < len(visible_user_seqs):
|
||||
turn["_seq"] = visible_user_seqs[user_turn_idx]
|
||||
user_turn_idx += 1
|
||||
|
||||
total = len(visible)
|
||||
offset = (page - 1) * page_size
|
||||
page_items = list(reversed(visible))[offset: offset + page_size]
|
||||
page_items = list(reversed(page_items))
|
||||
|
||||
return {
|
||||
"messages": page_items,
|
||||
"context_start_seq": ctx_start,
|
||||
"total": total,
|
||||
"page": page,
|
||||
"page_size": page_size,
|
||||
"has_more": offset + page_size < total,
|
||||
}
|
||||
|
||||
def list_sessions(
|
||||
self,
|
||||
channel_type: Optional[str] = None,
|
||||
page: int = 1,
|
||||
page_size: int = 50,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
List sessions ordered by last_active DESC, with optional channel_type filter.
|
||||
|
||||
Returns:
|
||||
{
|
||||
"sessions": [{session_id, title, created_at, last_active, msg_count}, ...],
|
||||
"total": int,
|
||||
"page": int,
|
||||
"page_size": int,
|
||||
"has_more": bool,
|
||||
}
|
||||
"""
|
||||
page = max(1, page)
|
||||
with self._lock:
|
||||
conn = self._connect()
|
||||
try:
|
||||
if channel_type:
|
||||
total = conn.execute(
|
||||
"SELECT COUNT(*) FROM sessions WHERE channel_type = ?",
|
||||
(channel_type,),
|
||||
).fetchone()[0]
|
||||
rows = conn.execute(
|
||||
"""
|
||||
SELECT session_id, title, created_at, last_active, msg_count
|
||||
FROM sessions
|
||||
WHERE channel_type = ?
|
||||
ORDER BY last_active DESC
|
||||
LIMIT ? OFFSET ?
|
||||
""",
|
||||
(channel_type, page_size, (page - 1) * page_size),
|
||||
).fetchall()
|
||||
else:
|
||||
total = conn.execute(
|
||||
"SELECT COUNT(*) FROM sessions",
|
||||
).fetchone()[0]
|
||||
rows = conn.execute(
|
||||
"""
|
||||
SELECT session_id, title, created_at, last_active, msg_count
|
||||
FROM sessions
|
||||
ORDER BY last_active DESC
|
||||
LIMIT ? OFFSET ?
|
||||
""",
|
||||
(page_size, (page - 1) * page_size),
|
||||
).fetchall()
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
sessions = [
|
||||
{
|
||||
"session_id": r[0],
|
||||
"title": r[1],
|
||||
"created_at": r[2],
|
||||
"last_active": r[3],
|
||||
"msg_count": r[4],
|
||||
}
|
||||
for r in rows
|
||||
]
|
||||
return {
|
||||
"sessions": sessions,
|
||||
"total": total,
|
||||
"page": page,
|
||||
"page_size": page_size,
|
||||
"has_more": (page - 1) * page_size + page_size < total,
|
||||
}
|
||||
|
||||
def rename_session(self, session_id: str, title: str) -> bool:
|
||||
"""Update the title of a session. Returns True if the session existed."""
|
||||
with self._lock:
|
||||
conn = self._connect()
|
||||
try:
|
||||
with conn:
|
||||
cur = conn.execute(
|
||||
"UPDATE sessions SET title = ? WHERE session_id = ?",
|
||||
(title, session_id),
|
||||
)
|
||||
return cur.rowcount > 0
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""Return basic stats keyed by channel_type, for monitoring."""
|
||||
with self._lock:
|
||||
conn = self._connect()
|
||||
try:
|
||||
total_sessions = conn.execute(
|
||||
"SELECT COUNT(*) FROM sessions"
|
||||
).fetchone()[0]
|
||||
total_messages = conn.execute(
|
||||
"SELECT COUNT(*) FROM messages"
|
||||
).fetchone()[0]
|
||||
by_channel = conn.execute(
|
||||
"""
|
||||
SELECT channel_type, COUNT(*) as cnt
|
||||
FROM sessions
|
||||
GROUP BY channel_type
|
||||
ORDER BY cnt DESC
|
||||
"""
|
||||
).fetchall()
|
||||
return {
|
||||
"total_sessions": total_sessions,
|
||||
"total_messages": total_messages,
|
||||
"by_channel": {row[0] or "unknown": row[1] for row in by_channel},
|
||||
}
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Internal helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _init_db(self) -> None:
|
||||
self._db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
conn = self._connect()
|
||||
try:
|
||||
conn.executescript(_DDL)
|
||||
conn.commit()
|
||||
self._migrate(conn)
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def _migrate(self, conn: sqlite3.Connection) -> None:
|
||||
"""Apply incremental schema migrations on existing databases."""
|
||||
cols = {
|
||||
row[1]
|
||||
for row in conn.execute("PRAGMA table_info(sessions)").fetchall()
|
||||
}
|
||||
if "channel_type" not in cols:
|
||||
try:
|
||||
conn.execute(_MIGRATION_ADD_CHANNEL_TYPE)
|
||||
conn.commit()
|
||||
logger.info("[ConversationStore] Migrated: added channel_type column")
|
||||
except Exception as e:
|
||||
logger.warning(f"[ConversationStore] Migration failed: {e}")
|
||||
if "title" not in cols:
|
||||
try:
|
||||
conn.execute(_MIGRATION_ADD_TITLE)
|
||||
conn.commit()
|
||||
logger.info("[ConversationStore] Migrated: added title column")
|
||||
except Exception as e:
|
||||
logger.warning(f"[ConversationStore] Migration (title) failed: {e}")
|
||||
if "context_start_seq" not in cols:
|
||||
try:
|
||||
conn.execute(_MIGRATION_ADD_CONTEXT_START_SEQ)
|
||||
conn.commit()
|
||||
logger.info("[ConversationStore] Migrated: added context_start_seq column")
|
||||
except Exception as e:
|
||||
logger.warning(f"[ConversationStore] Migration (context_start_seq) failed: {e}")
|
||||
|
||||
def _connect(self) -> sqlite3.Connection:
|
||||
conn = sqlite3.connect(str(self._db_path), timeout=10)
|
||||
conn.execute("PRAGMA journal_mode=WAL")
|
||||
conn.execute("PRAGMA synchronous=NORMAL")
|
||||
return conn
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Singleton
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_store_instance: Optional[ConversationStore] = None
|
||||
_store_lock = threading.Lock()
|
||||
|
||||
|
||||
def get_conversation_store() -> ConversationStore:
|
||||
"""
|
||||
Return the process-wide ConversationStore singleton.
|
||||
|
||||
Reuses the long-term memory database so the project stays with a single
|
||||
SQLite file: ~/cow/memory/long-term/index.db
|
||||
The conversation tables (sessions / messages) are separate from the
|
||||
memory tables (memory_chunks / file_metadata) — no conflicts.
|
||||
"""
|
||||
global _store_instance
|
||||
if _store_instance is not None:
|
||||
return _store_instance
|
||||
|
||||
with _store_lock:
|
||||
if _store_instance is not None:
|
||||
return _store_instance
|
||||
|
||||
try:
|
||||
from agent.memory.config import get_default_memory_config
|
||||
db_path = get_default_memory_config().get_db_path()
|
||||
except Exception:
|
||||
from common.utils import expand_path
|
||||
db_path = Path(expand_path("~/cow")) / "memory" / "long-term" / "index.db"
|
||||
|
||||
_store_instance = ConversationStore(db_path)
|
||||
logger.debug(f"[ConversationStore] Using shared DB at: {db_path}")
|
||||
return _store_instance
|
||||
@@ -32,18 +32,21 @@ class EmbeddingProvider(ABC):
|
||||
class OpenAIEmbeddingProvider(EmbeddingProvider):
|
||||
"""OpenAI embedding provider using REST API"""
|
||||
|
||||
def __init__(self, model: str = "text-embedding-3-small", api_key: Optional[str] = None, api_base: Optional[str] = None):
|
||||
def __init__(self, model: str = "text-embedding-3-small", api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None, extra_headers: Optional[dict] = None):
|
||||
"""
|
||||
Initialize OpenAI embedding provider
|
||||
|
||||
|
||||
Args:
|
||||
model: Model name (text-embedding-3-small or text-embedding-3-large)
|
||||
api_key: OpenAI API key
|
||||
api_base: Optional API base URL
|
||||
extra_headers: Optional extra headers to include in API requests
|
||||
"""
|
||||
self.model = model
|
||||
self.api_key = api_key
|
||||
self.api_base = api_base or "https://api.openai.com/v1"
|
||||
self.extra_headers = extra_headers or {}
|
||||
|
||||
# Validate API key
|
||||
if not self.api_key or self.api_key in ["", "YOUR API KEY", "YOUR_API_KEY"]:
|
||||
@@ -59,7 +62,8 @@ class OpenAIEmbeddingProvider(EmbeddingProvider):
|
||||
url = f"{self.api_base}/embeddings"
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {self.api_key}"
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
**self.extra_headers,
|
||||
}
|
||||
data = {
|
||||
"input": input_data,
|
||||
@@ -134,28 +138,30 @@ def create_embedding_provider(
|
||||
provider: str = "openai",
|
||||
model: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None
|
||||
api_base: Optional[str] = None,
|
||||
extra_headers: Optional[dict] = None
|
||||
) -> EmbeddingProvider:
|
||||
"""
|
||||
Factory function to create embedding provider
|
||||
|
||||
Only supports OpenAI embedding via REST API.
|
||||
|
||||
Supports "openai" and "linkai" providers (both use OpenAI-compatible REST API).
|
||||
If initialization fails, caller should fall back to keyword-only search.
|
||||
|
||||
|
||||
Args:
|
||||
provider: Provider name (only "openai" is supported)
|
||||
provider: Provider name ("openai" or "linkai")
|
||||
model: Model name (default: text-embedding-3-small)
|
||||
api_key: OpenAI API key (required)
|
||||
api_base: API base URL (default: https://api.openai.com/v1)
|
||||
|
||||
api_key: API key (required)
|
||||
api_base: API base URL
|
||||
extra_headers: Optional extra headers to include in API requests
|
||||
|
||||
Returns:
|
||||
EmbeddingProvider instance
|
||||
|
||||
|
||||
Raises:
|
||||
ValueError: If provider is not "openai" or api_key is missing
|
||||
ValueError: If provider is unsupported or api_key is missing
|
||||
"""
|
||||
if provider != "openai":
|
||||
raise ValueError(f"Only 'openai' provider is supported, got: {provider}")
|
||||
if provider not in ("openai", "linkai"):
|
||||
raise ValueError(f"Unsupported embedding provider: {provider}. Use 'openai' or 'linkai'.")
|
||||
|
||||
model = model or "text-embedding-3-small"
|
||||
return OpenAIEmbeddingProvider(model=model, api_key=api_key, api_base=api_base)
|
||||
return OpenAIEmbeddingProvider(model=model, api_key=api_key, api_base=api_base, extra_headers=extra_headers)
|
||||
|
||||
@@ -50,28 +50,48 @@ class MemoryManager:
|
||||
overlap_tokens=self.config.chunk_overlap_tokens
|
||||
)
|
||||
|
||||
# Initialize embedding provider (optional)
|
||||
# Initialize embedding provider (optional, prefer OpenAI, fallback to LinkAI)
|
||||
self.embedding_provider = None
|
||||
if embedding_provider:
|
||||
self.embedding_provider = embedding_provider
|
||||
else:
|
||||
# Try to create embedding provider, but allow failure
|
||||
# Try OpenAI first
|
||||
try:
|
||||
# Get API key from environment or config
|
||||
api_key = os.environ.get('OPENAI_API_KEY')
|
||||
api_base = os.environ.get('OPENAI_API_BASE')
|
||||
|
||||
self.embedding_provider = create_embedding_provider(
|
||||
provider=self.config.embedding_provider,
|
||||
model=self.config.embedding_model,
|
||||
api_key=api_key,
|
||||
api_base=api_base
|
||||
)
|
||||
if api_key:
|
||||
self.embedding_provider = create_embedding_provider(
|
||||
provider="openai",
|
||||
model=self.config.embedding_model,
|
||||
api_key=api_key,
|
||||
api_base=api_base
|
||||
)
|
||||
except Exception as e:
|
||||
# Embedding provider failed, but that's OK
|
||||
# We can still use keyword search and file operations
|
||||
from common.log import logger
|
||||
logger.warning(f"[MemoryManager] Embedding provider initialization failed: {e}")
|
||||
logger.warning(f"[MemoryManager] OpenAI embedding failed: {e}")
|
||||
|
||||
# Fallback to LinkAI
|
||||
if self.embedding_provider is None:
|
||||
try:
|
||||
linkai_key = os.environ.get('LINKAI_API_KEY')
|
||||
linkai_base = os.environ.get('LINKAI_API_BASE', 'https://api.link-ai.tech')
|
||||
if linkai_key:
|
||||
from common.utils import get_cloud_headers
|
||||
cloud_headers = get_cloud_headers(linkai_key)
|
||||
cloud_headers.pop("Authorization", None)
|
||||
self.embedding_provider = create_embedding_provider(
|
||||
provider="linkai",
|
||||
model=self.config.embedding_model,
|
||||
api_key=linkai_key,
|
||||
api_base=f"{linkai_base}/v1",
|
||||
extra_headers=cloud_headers,
|
||||
)
|
||||
except Exception as e:
|
||||
from common.log import logger
|
||||
logger.warning(f"[MemoryManager] LinkAI embedding failed: {e}")
|
||||
|
||||
if self.embedding_provider is None:
|
||||
from common.log import logger
|
||||
logger.info(f"[MemoryManager] Memory will work with keyword search only (no vector search)")
|
||||
|
||||
# Initialize memory flush manager
|
||||
@@ -265,6 +285,10 @@ class MemoryManager:
|
||||
# Scan memory directory (including daily summaries)
|
||||
if memory_dir.exists():
|
||||
for file_path in memory_dir.rglob("*.md"):
|
||||
# Skip hidden directories (e.g. .dreams/)
|
||||
if any(part.startswith('.') for part in file_path.relative_to(workspace_dir).parts):
|
||||
continue
|
||||
|
||||
# Determine scope and user_id from path
|
||||
rel_path = file_path.relative_to(workspace_dir)
|
||||
parts = rel_path.parts
|
||||
@@ -292,6 +316,14 @@ class MemoryManager:
|
||||
scope = "shared"
|
||||
|
||||
await self._sync_file(file_path, "memory", scope, user_id)
|
||||
|
||||
# Scan knowledge directory (structured knowledge wiki)
|
||||
from config import conf
|
||||
if conf().get("knowledge", True):
|
||||
knowledge_dir = Path(workspace_dir) / "knowledge"
|
||||
if knowledge_dir.exists():
|
||||
for file_path in knowledge_dir.rglob("*.md"):
|
||||
await self._sync_file(file_path, "knowledge", "shared", None)
|
||||
|
||||
self._dirty = False
|
||||
|
||||
@@ -363,182 +395,39 @@ class MemoryManager:
|
||||
size=stat.st_size
|
||||
)
|
||||
|
||||
def should_flush_memory(
|
||||
def flush_memory(
|
||||
self,
|
||||
current_tokens: int = 0
|
||||
) -> bool:
|
||||
"""
|
||||
Check if memory flush should be triggered
|
||||
|
||||
独立的 flush 触发机制,不依赖模型 context window。
|
||||
使用配置中的阈值: flush_token_threshold 和 flush_turn_threshold
|
||||
|
||||
Args:
|
||||
current_tokens: Current session token count
|
||||
|
||||
Returns:
|
||||
True if memory flush should run
|
||||
"""
|
||||
return self.flush_manager.should_flush(
|
||||
current_tokens=current_tokens,
|
||||
token_threshold=self.config.flush_token_threshold,
|
||||
turn_threshold=self.config.flush_turn_threshold
|
||||
)
|
||||
|
||||
def increment_turn(self):
|
||||
"""增加对话轮数计数(每次用户消息+AI回复算一轮)"""
|
||||
self.flush_manager.increment_turn()
|
||||
|
||||
async def execute_memory_flush(
|
||||
self,
|
||||
agent_executor,
|
||||
current_tokens: int,
|
||||
messages: list,
|
||||
user_id: Optional[str] = None,
|
||||
**executor_kwargs
|
||||
reason: str = "threshold",
|
||||
max_messages: int = 10,
|
||||
context_summary_callback=None,
|
||||
) -> bool:
|
||||
"""
|
||||
Execute memory flush before compaction
|
||||
|
||||
This runs a silent agent turn to write durable memories to disk.
|
||||
Similar to clawdbot's pre-compaction memory flush.
|
||||
|
||||
Flush conversation summary to daily memory file.
|
||||
|
||||
Args:
|
||||
agent_executor: Async function to execute agent with prompt
|
||||
current_tokens: Current session token count
|
||||
messages: Conversation message list
|
||||
user_id: Optional user ID
|
||||
**executor_kwargs: Additional kwargs for agent executor
|
||||
|
||||
reason: "threshold" | "overflow" | "daily_summary"
|
||||
max_messages: Max recent messages to include (0 = all)
|
||||
context_summary_callback: Optional callback(str) invoked with the
|
||||
daily summary text for in-context injection
|
||||
|
||||
Returns:
|
||||
True if flush completed successfully
|
||||
|
||||
Example:
|
||||
>>> async def run_agent(prompt, system_prompt, silent=False):
|
||||
... # Your agent execution logic
|
||||
... pass
|
||||
>>>
|
||||
>>> if manager.should_flush_memory(current_tokens=100000):
|
||||
... await manager.execute_memory_flush(
|
||||
... agent_executor=run_agent,
|
||||
... current_tokens=100000
|
||||
... )
|
||||
True if flush was dispatched
|
||||
"""
|
||||
success = await self.flush_manager.execute_flush(
|
||||
agent_executor=agent_executor,
|
||||
current_tokens=current_tokens,
|
||||
success = self.flush_manager.flush_from_messages(
|
||||
messages=messages,
|
||||
user_id=user_id,
|
||||
**executor_kwargs
|
||||
reason=reason,
|
||||
max_messages=max_messages,
|
||||
context_summary_callback=context_summary_callback,
|
||||
)
|
||||
|
||||
if success:
|
||||
# Mark dirty so next search will sync the new memories
|
||||
self._dirty = True
|
||||
|
||||
return success
|
||||
|
||||
def build_memory_guidance(self, lang: str = "zh", include_context: bool = True) -> str:
|
||||
"""
|
||||
Build natural memory guidance for agent system prompt
|
||||
|
||||
Following clawdbot's approach:
|
||||
1. Load MEMORY.md as bootstrap context (blends into background)
|
||||
2. Load daily files on-demand via memory_search tool
|
||||
3. Agent should NOT proactively mention memories unless user asks
|
||||
|
||||
Args:
|
||||
lang: Language for guidance ("en" or "zh")
|
||||
include_context: Whether to include bootstrap memory context (default: True)
|
||||
MEMORY.md is loaded as background context (like clawdbot)
|
||||
Daily files are accessed via memory_search tool
|
||||
|
||||
Returns:
|
||||
Memory guidance text (and optionally context) for system prompt
|
||||
"""
|
||||
today_file = self.flush_manager.get_today_memory_file().name
|
||||
|
||||
if lang == "zh":
|
||||
guidance = f"""## 记忆系统
|
||||
|
||||
**背景知识**: 下方包含核心长期记忆,可直接使用。需要查找历史时,用 memory_search 搜索(搜索一次即可,不要重复)。
|
||||
|
||||
**存储记忆**: 当用户分享重要信息时(偏好、决策、事实等),主动用 write 工具存储:
|
||||
- 长期信息 → MEMORY.md
|
||||
- 当天笔记 → memory/{today_file}
|
||||
- 静默存储,仅在明确要求时确认
|
||||
|
||||
**使用原则**: 自然使用记忆,就像你本来就知道。不需要生硬地提起或列举记忆,除非用户提到。"""
|
||||
else:
|
||||
guidance = f"""## Memory System
|
||||
|
||||
**Background Knowledge**: Core long-term memories below - use directly. For history, use memory_search once (don't repeat).
|
||||
|
||||
**Store Memories**: When user shares important info (preferences, decisions, facts), proactively write:
|
||||
- Durable info → MEMORY.md
|
||||
- Daily notes → memory/{today_file}
|
||||
- Store silently; confirm only when explicitly requested
|
||||
|
||||
**Usage**: Use memories naturally as if you always knew. Don't mention or list unless user explicitly asks."""
|
||||
|
||||
if include_context:
|
||||
# Load bootstrap context (MEMORY.md only, like clawdbot)
|
||||
bootstrap_context = self.load_bootstrap_memories()
|
||||
if bootstrap_context:
|
||||
guidance += f"\n\n## Background Context\n\n{bootstrap_context}"
|
||||
|
||||
return guidance
|
||||
|
||||
def load_bootstrap_memories(self, user_id: Optional[str] = None) -> str:
|
||||
"""
|
||||
Load bootstrap memory files for session start
|
||||
|
||||
Following clawdbot's design:
|
||||
- Only loads MEMORY.md from workspace root (long-term curated memory)
|
||||
- Daily files (memory/YYYY-MM-DD.md) are accessed via memory_search tool, not bootstrap
|
||||
- User-specific MEMORY.md is also loaded if user_id provided
|
||||
|
||||
Returns memory content WITHOUT obvious headers so it blends naturally
|
||||
into the context as background knowledge.
|
||||
|
||||
Args:
|
||||
user_id: Optional user ID for user-specific memories
|
||||
|
||||
Returns:
|
||||
Memory content to inject into system prompt (blends naturally as background context)
|
||||
"""
|
||||
workspace_dir = self.config.get_workspace()
|
||||
memory_dir = self.config.get_memory_dir()
|
||||
|
||||
sections = []
|
||||
|
||||
# 1. Load MEMORY.md from workspace root (long-term curated memory)
|
||||
# Following clawdbot: only MEMORY.md is bootstrap, daily files use memory_search
|
||||
memory_file = Path(workspace_dir) / "MEMORY.md"
|
||||
if memory_file.exists():
|
||||
try:
|
||||
content = memory_file.read_text(encoding='utf-8').strip()
|
||||
if content:
|
||||
sections.append(content)
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to read MEMORY.md: {e}")
|
||||
|
||||
# 2. Load user-specific MEMORY.md if user_id provided
|
||||
if user_id:
|
||||
user_memory_dir = memory_dir / "users" / user_id
|
||||
user_memory_file = user_memory_dir / "MEMORY.md"
|
||||
if user_memory_file.exists():
|
||||
try:
|
||||
content = user_memory_file.read_text(encoding='utf-8').strip()
|
||||
if content:
|
||||
sections.append(content)
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to read user memory: {e}")
|
||||
|
||||
if not sections:
|
||||
return ""
|
||||
|
||||
# Join sections without obvious headers - let memories blend naturally
|
||||
# This makes the agent feel like it "just knows" rather than "checking memory files"
|
||||
return "\n\n".join(sections)
|
||||
|
||||
def get_status(self) -> Dict[str, Any]:
|
||||
"""Get memory status"""
|
||||
stats = self.storage.get_stats()
|
||||
@@ -568,6 +457,37 @@ class MemoryManager:
|
||||
content = f"{path}:{start_line}:{end_line}"
|
||||
return hashlib.md5(content.encode('utf-8')).hexdigest()
|
||||
|
||||
@staticmethod
|
||||
def _compute_temporal_decay(path: str, half_life_days: float = 30.0) -> float:
|
||||
"""
|
||||
Compute temporal decay multiplier for dated memory files.
|
||||
|
||||
Inspired by OpenClaw's temporal-decay: exponential decay based on file date.
|
||||
MEMORY.md and non-dated files are "evergreen" (no decay, multiplier=1.0).
|
||||
Daily files like memory/2025-03-01.md decay based on age.
|
||||
|
||||
Formula: multiplier = exp(-ln2/half_life * age_in_days)
|
||||
"""
|
||||
import re
|
||||
import math
|
||||
|
||||
match = re.search(r'(\d{4})-(\d{2})-(\d{2})\.md$', path)
|
||||
if not match:
|
||||
return 1.0 # evergreen: MEMORY.md, non-dated files
|
||||
|
||||
try:
|
||||
file_date = datetime(
|
||||
int(match.group(1)), int(match.group(2)), int(match.group(3))
|
||||
)
|
||||
age_days = (datetime.now() - file_date).days
|
||||
if age_days <= 0:
|
||||
return 1.0
|
||||
|
||||
decay_lambda = math.log(2) / half_life_days
|
||||
return math.exp(-decay_lambda * age_days)
|
||||
except (ValueError, OverflowError):
|
||||
return 1.0
|
||||
|
||||
def _merge_results(
|
||||
self,
|
||||
vector_results: List[SearchResult],
|
||||
@@ -575,8 +495,7 @@ class MemoryManager:
|
||||
vector_weight: float,
|
||||
keyword_weight: float
|
||||
) -> List[SearchResult]:
|
||||
"""Merge vector and keyword search results"""
|
||||
# Create a map by (path, start_line, end_line)
|
||||
"""Merge vector and keyword search results with temporal decay for dated files"""
|
||||
merged_map = {}
|
||||
|
||||
for result in vector_results:
|
||||
@@ -598,7 +517,6 @@ class MemoryManager:
|
||||
'keyword_score': result.score
|
||||
}
|
||||
|
||||
# Calculate combined scores
|
||||
merged_results = []
|
||||
for entry in merged_map.values():
|
||||
combined_score = (
|
||||
@@ -606,7 +524,11 @@ class MemoryManager:
|
||||
keyword_weight * entry['keyword_score']
|
||||
)
|
||||
|
||||
# Apply temporal decay for dated memory files
|
||||
result = entry['result']
|
||||
decay = self._compute_temporal_decay(result.path)
|
||||
combined_score *= decay
|
||||
|
||||
merged_results.append(SearchResult(
|
||||
path=result.path,
|
||||
start_line=result.start_line,
|
||||
@@ -617,6 +539,5 @@ class MemoryManager:
|
||||
user_id=result.user_id
|
||||
))
|
||||
|
||||
# Sort by score
|
||||
merged_results.sort(key=lambda r: r.score, reverse=True)
|
||||
return merged_results
|
||||
|
||||
197
agent/memory/service.py
Normal file
197
agent/memory/service.py
Normal file
@@ -0,0 +1,197 @@
|
||||
"""
|
||||
Memory service for handling memory query operations via cloud protocol.
|
||||
|
||||
Provides a unified interface for listing and reading memory files,
|
||||
callable from the cloud client (LinkAI) or a future web console.
|
||||
|
||||
Memory file layout (under workspace_root):
|
||||
MEMORY.md -> type: global
|
||||
memory/2026-02-20.md -> type: daily
|
||||
"""
|
||||
|
||||
import os
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional
|
||||
from pathlib import Path
|
||||
from common.log import logger
|
||||
|
||||
|
||||
class MemoryService:
|
||||
"""
|
||||
High-level service for memory file queries.
|
||||
Operates directly on the filesystem — no MemoryManager dependency.
|
||||
"""
|
||||
|
||||
def __init__(self, workspace_root: str):
|
||||
"""
|
||||
:param workspace_root: Workspace root directory (e.g. ~/cow)
|
||||
"""
|
||||
self.workspace_root = workspace_root
|
||||
self.memory_dir = os.path.join(workspace_root, "memory")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# list — paginated file metadata
|
||||
# ------------------------------------------------------------------
|
||||
def list_files(self, page: int = 1, page_size: int = 20, category: str = "memory") -> dict:
|
||||
"""
|
||||
List memory or dream files with metadata (without content).
|
||||
|
||||
Args:
|
||||
category: ``"memory"`` (default) — MEMORY.md + daily files;
|
||||
``"dream"`` — dream diary files from memory/dreams/
|
||||
"""
|
||||
if category == "dream":
|
||||
files = self._list_dream_files()
|
||||
else:
|
||||
files = self._list_memory_files()
|
||||
|
||||
total = len(files)
|
||||
start = (page - 1) * page_size
|
||||
end = start + page_size
|
||||
|
||||
return {
|
||||
"page": page,
|
||||
"page_size": page_size,
|
||||
"total": total,
|
||||
"list": files[start:end],
|
||||
}
|
||||
|
||||
def _list_memory_files(self) -> List[dict]:
|
||||
"""MEMORY.md + memory/*.md (newest first)."""
|
||||
files: List[dict] = []
|
||||
|
||||
global_path = os.path.join(self.workspace_root, "MEMORY.md")
|
||||
if os.path.isfile(global_path):
|
||||
files.append(self._file_info(global_path, "MEMORY.md", "global"))
|
||||
|
||||
if os.path.isdir(self.memory_dir):
|
||||
daily_files = []
|
||||
for name in os.listdir(self.memory_dir):
|
||||
full = os.path.join(self.memory_dir, name)
|
||||
if os.path.isfile(full) and name.endswith(".md"):
|
||||
daily_files.append((name, full))
|
||||
daily_files.sort(key=lambda x: x[0], reverse=True)
|
||||
for name, full in daily_files:
|
||||
files.append(self._file_info(full, name, "daily"))
|
||||
|
||||
return files
|
||||
|
||||
def _list_dream_files(self) -> List[dict]:
|
||||
"""memory/dreams/*.md (newest first)."""
|
||||
files: List[dict] = []
|
||||
dreams_dir = os.path.join(self.memory_dir, "dreams")
|
||||
|
||||
if os.path.isdir(dreams_dir):
|
||||
entries = []
|
||||
for name in os.listdir(dreams_dir):
|
||||
full = os.path.join(dreams_dir, name)
|
||||
if os.path.isfile(full) and name.endswith(".md"):
|
||||
entries.append((name, full))
|
||||
entries.sort(key=lambda x: x[0], reverse=True)
|
||||
for name, full in entries:
|
||||
files.append(self._file_info(full, name, "dream"))
|
||||
|
||||
return files
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# content — read a single file
|
||||
# ------------------------------------------------------------------
|
||||
def get_content(self, filename: str, category: str = "memory") -> dict:
|
||||
"""
|
||||
Read the full content of a memory or dream file.
|
||||
|
||||
:param filename: File name, e.g. ``MEMORY.md``, ``2026-02-20.md``
|
||||
:param category: ``"memory"`` or ``"dream"``
|
||||
:return: dict with ``filename`` and ``content``
|
||||
:raises FileNotFoundError: if the file does not exist
|
||||
"""
|
||||
path = self._resolve_path(filename, category)
|
||||
if not os.path.isfile(path):
|
||||
raise FileNotFoundError(f"Memory file not found: {filename}")
|
||||
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
content = f.read()
|
||||
|
||||
return {
|
||||
"filename": filename,
|
||||
"content": content,
|
||||
}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# dispatch — single entry point for protocol messages
|
||||
# ------------------------------------------------------------------
|
||||
def dispatch(self, action: str, payload: Optional[dict] = None) -> dict:
|
||||
"""
|
||||
Dispatch a memory management action.
|
||||
|
||||
:param action: ``list`` or ``content``
|
||||
:param payload: action-specific payload (supports ``category``: ``"memory"`` | ``"dream"``)
|
||||
:return: protocol-compatible response dict
|
||||
"""
|
||||
payload = payload or {}
|
||||
try:
|
||||
if action == "list":
|
||||
page = payload.get("page", 1)
|
||||
page_size = payload.get("page_size", 20)
|
||||
category = payload.get("category", "memory")
|
||||
result_payload = self.list_files(page=page, page_size=page_size, category=category)
|
||||
return {"action": action, "code": 200, "message": "success", "payload": result_payload}
|
||||
|
||||
elif action == "content":
|
||||
filename = payload.get("filename")
|
||||
if not filename:
|
||||
return {"action": action, "code": 400, "message": "filename is required", "payload": None}
|
||||
category = payload.get("category", "memory")
|
||||
result_payload = self.get_content(filename, category=category)
|
||||
return {"action": action, "code": 200, "message": "success", "payload": result_payload}
|
||||
|
||||
else:
|
||||
return {"action": action, "code": 400, "message": f"unknown action: {action}", "payload": None}
|
||||
|
||||
except ValueError as e:
|
||||
return {"action": action, "code": 403, "message": "invalid filename", "payload": None}
|
||||
except FileNotFoundError as e:
|
||||
return {"action": action, "code": 404, "message": str(e), "payload": None}
|
||||
except Exception as e:
|
||||
logger.error(f"[MemoryService] dispatch error: action={action}, error={e}")
|
||||
return {"action": action, "code": 500, "message": str(e), "payload": None}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# internal helpers
|
||||
# ------------------------------------------------------------------
|
||||
def _resolve_path(self, filename: str, category: str = "memory") -> str:
|
||||
"""
|
||||
Safely resolve a filename to its absolute path within the allowed directory.
|
||||
|
||||
- ``MEMORY.md`` → ``{workspace_root}/MEMORY.md``
|
||||
- ``2026-02-20.md`` (memory) → ``{workspace_root}/memory/2026-02-20.md``
|
||||
- ``2026-02-20.md`` (dream) → ``{workspace_root}/memory/dreams/2026-02-20.md``
|
||||
|
||||
Raises ValueError if the resolved path escapes the allowed directory.
|
||||
"""
|
||||
if filename == "MEMORY.md":
|
||||
base_dir = self.workspace_root
|
||||
elif category == "dream":
|
||||
base_dir = os.path.join(self.memory_dir, "dreams")
|
||||
else:
|
||||
base_dir = self.memory_dir
|
||||
|
||||
resolved = os.path.realpath(os.path.join(base_dir, filename))
|
||||
allowed = os.path.realpath(base_dir)
|
||||
|
||||
if resolved != allowed and not resolved.startswith(allowed + os.sep):
|
||||
raise ValueError(f"Invalid filename: path traversal detected")
|
||||
|
||||
return resolved
|
||||
|
||||
@staticmethod
|
||||
def _file_info(path: str, filename: str, file_type: str) -> dict:
|
||||
"""Build a file metadata dict."""
|
||||
stat = os.stat(path)
|
||||
updated_at = datetime.fromtimestamp(stat.st_mtime).strftime("%Y-%m-%d %H:%M:%S")
|
||||
return {
|
||||
"filename": filename,
|
||||
"type": file_type,
|
||||
"size": stat.st_size,
|
||||
"updated_at": updated_at,
|
||||
}
|
||||
@@ -509,7 +509,7 @@ class MemoryStorage:
|
||||
"""Destructor to ensure connection is closed"""
|
||||
try:
|
||||
self.close()
|
||||
except:
|
||||
except Exception:
|
||||
pass # Ignore errors during cleanup
|
||||
|
||||
# Helper methods
|
||||
|
||||
@@ -1,225 +1,652 @@
|
||||
"""
|
||||
Memory flush manager
|
||||
Memory flush manager with Deep Dream distillation
|
||||
|
||||
Triggers memory flush before context compaction (similar to clawdbot)
|
||||
Handles memory persistence when conversation context is trimmed or overflows:
|
||||
- Uses LLM to summarize discarded messages into concise daily records
|
||||
- Writes to daily memory files (lazy creation)
|
||||
- Deduplicates trim flushes to avoid repeated writes
|
||||
- Runs summarization asynchronously to avoid blocking normal replies
|
||||
- Deep Dream: periodically distills daily memories → refined MEMORY.md + dream diary
|
||||
"""
|
||||
|
||||
from typing import Optional, Callable, Any
|
||||
import threading
|
||||
from typing import Optional, Callable, Any, List, Dict
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
from common.log import logger
|
||||
|
||||
|
||||
SUMMARIZE_SYSTEM_PROMPT = """你是一个对话记录助手。请将对话内容归纳为当天的日常记录。
|
||||
|
||||
## 要求
|
||||
|
||||
按「事件」维度归纳发生的事,不要按对话轮次逐条记录:
|
||||
- 每条一行,用 "- " 开头
|
||||
- 合并同一件事的多轮对话
|
||||
- 只记录有意义的事件,忽略闲聊和问候
|
||||
- 保留关键的决策、结论和待办事项
|
||||
|
||||
当对话没有任何记录价值(仅含问候或无意义内容),直接回复"无"。"""
|
||||
|
||||
SUMMARIZE_USER_PROMPT = """请归纳以下对话的日常记录:
|
||||
|
||||
{conversation}"""
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Deep Dream prompts — distill daily memories → MEMORY.md + dream diary
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
DREAM_SYSTEM_PROMPT = """你是一个记忆整理助手,负责定期整理用户的长期记忆。
|
||||
|
||||
你将收到两份材料:
|
||||
1. **当前长期记忆** — MEMORY.md 的全部现有内容
|
||||
2. **今日日记** — 当天的日常记录
|
||||
|
||||
MEMORY.md 会注入每次对话的系统提示词中,因此必须保持精炼,只存放有价值和值得记忆的内容。
|
||||
|
||||
**重要:只能基于提供的材料进行整理,严禁编造、推测或添加材料中不存在的信息。**
|
||||
|
||||
## 任务
|
||||
|
||||
### Part 1: 更新后的长期记忆([MEMORY])
|
||||
|
||||
在现有记忆基础上进行整理和提炼,输出完整的更新后内容:
|
||||
- **合并提炼**:将含义相近的多条合并为一条高密度表述,而非简单罗列
|
||||
- **新增萃取**:从今日日记中提取值得永久记住的新信息(偏好、决策、人物、规则、经验)
|
||||
- **冲突更新**:当新信息与旧条目矛盾时,以新信息为准,替换旧条目
|
||||
- **清理无效**:删除临时性记录、空白条目、格式残留、无意义、重复内容等
|
||||
- **删除冗余**:已被更精炼表述涵盖的旧条目应删除,避免信息重复
|
||||
- 每条一行,用 "- " 开头,不带日期前缀
|
||||
- 可用 "## 标题" 对相关条目分组,使结构更清晰
|
||||
- 目标:控制在 50 条以内,每条尽量一句话概括
|
||||
|
||||
### Part 2: 梦境日记([DREAM])
|
||||
|
||||
用简洁的叙事风格写一篇短日记,记录这次整理的发现,保持格式美观易读:
|
||||
- 发现了哪些重复或矛盾
|
||||
- 从日记中提取了什么新洞察
|
||||
- 做了哪些清理和优化
|
||||
- 整体感受和观察
|
||||
|
||||
## 输出格式(严格遵守)
|
||||
|
||||
```
|
||||
[MEMORY]
|
||||
- 记忆条目1
|
||||
- 记忆条目2
|
||||
...
|
||||
|
||||
[DREAM]
|
||||
梦境日记内容...
|
||||
```"""
|
||||
|
||||
DREAM_USER_PROMPT = """## 当前长期记忆(MEMORY.md)
|
||||
|
||||
{memory_content}
|
||||
|
||||
## 近期日记(最近 {days} 天)
|
||||
|
||||
{daily_content}"""
|
||||
|
||||
|
||||
|
||||
class MemoryFlushManager:
|
||||
"""
|
||||
Manages memory flush operations before context compaction
|
||||
Manages memory flush operations.
|
||||
|
||||
Similar to clawdbot's memory flush mechanism:
|
||||
- Triggers when context approaches token limit
|
||||
- Runs a silent agent turn to write memories to disk
|
||||
- Uses memory/YYYY-MM-DD.md for daily notes
|
||||
- Uses MEMORY.md (workspace root) for long-term curated memories
|
||||
Flush is triggered by agent_stream in two scenarios:
|
||||
1. Context trim: _trim_messages discards old turns → flush discarded content
|
||||
2. Context overflow: API rejects request → emergency flush before clearing
|
||||
|
||||
Additionally, create_daily_summary() can be called by scheduler for end-of-day summaries.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
workspace_dir: Path,
|
||||
llm_model: Optional[Any] = None
|
||||
llm_model: Optional[Any] = None,
|
||||
):
|
||||
"""
|
||||
Initialize memory flush manager
|
||||
|
||||
Args:
|
||||
workspace_dir: Workspace directory
|
||||
llm_model: LLM model for agent execution (optional)
|
||||
"""
|
||||
self.workspace_dir = workspace_dir
|
||||
self.llm_model = llm_model
|
||||
|
||||
self.memory_dir = workspace_dir / "memory"
|
||||
self.memory_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Tracking
|
||||
self.last_flush_token_count: Optional[int] = None
|
||||
self.last_flush_timestamp: Optional[datetime] = None
|
||||
self.turn_count: int = 0 # 对话轮数计数器
|
||||
self._trim_flushed_hashes: set = set() # Content hashes of already-flushed messages
|
||||
self._last_flushed_content_hash: str = "" # Content hash at last flush, for daily dedup
|
||||
self._last_dream_input_hash: str = "" # Hash of dream input, for dedup
|
||||
self._last_flush_thread: Optional[threading.Thread] = None
|
||||
|
||||
def should_flush(
|
||||
self,
|
||||
current_tokens: int = 0,
|
||||
token_threshold: int = 50000,
|
||||
turn_threshold: int = 20
|
||||
) -> bool:
|
||||
"""
|
||||
Determine if memory flush should be triggered
|
||||
|
||||
独立的 flush 触发机制,不依赖模型 context window:
|
||||
- Token 阈值: 达到 50K tokens 时触发
|
||||
- 轮次阈值: 达到 20 轮对话时触发
|
||||
|
||||
Args:
|
||||
current_tokens: Current session token count
|
||||
token_threshold: Token threshold to trigger flush (default: 50K)
|
||||
turn_threshold: Turn threshold to trigger flush (default: 20)
|
||||
|
||||
Returns:
|
||||
True if flush should run
|
||||
"""
|
||||
# 检查 token 阈值
|
||||
if current_tokens > 0 and current_tokens >= token_threshold:
|
||||
# 避免重复 flush
|
||||
if self.last_flush_token_count is not None:
|
||||
if current_tokens <= self.last_flush_token_count + 5000:
|
||||
return False
|
||||
return True
|
||||
|
||||
# 检查轮次阈值
|
||||
if self.turn_count >= turn_threshold:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def get_today_memory_file(self, user_id: Optional[str] = None) -> Path:
|
||||
"""
|
||||
Get today's memory file path: memory/YYYY-MM-DD.md
|
||||
|
||||
Args:
|
||||
user_id: Optional user ID for user-specific memory
|
||||
|
||||
Returns:
|
||||
Path to today's memory file
|
||||
"""
|
||||
def get_today_memory_file(self, user_id: Optional[str] = None, ensure_exists: bool = False) -> Path:
|
||||
"""Get today's memory file path: memory/YYYY-MM-DD.md"""
|
||||
today = datetime.now().strftime("%Y-%m-%d")
|
||||
|
||||
if user_id:
|
||||
user_dir = self.memory_dir / "users" / user_id
|
||||
user_dir.mkdir(parents=True, exist_ok=True)
|
||||
return user_dir / f"{today}.md"
|
||||
if ensure_exists:
|
||||
user_dir.mkdir(parents=True, exist_ok=True)
|
||||
today_file = user_dir / f"{today}.md"
|
||||
else:
|
||||
return self.memory_dir / f"{today}.md"
|
||||
today_file = self.memory_dir / f"{today}.md"
|
||||
|
||||
if ensure_exists and not today_file.exists():
|
||||
today_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
today_file.write_text(f"# Daily Memory: {today}\n\n")
|
||||
|
||||
return today_file
|
||||
|
||||
def get_main_memory_file(self, user_id: Optional[str] = None) -> Path:
|
||||
"""
|
||||
Get main memory file path: MEMORY.md (workspace root)
|
||||
|
||||
Args:
|
||||
user_id: Optional user ID for user-specific memory
|
||||
|
||||
Returns:
|
||||
Path to main memory file
|
||||
"""
|
||||
"""Get main memory file path: MEMORY.md (workspace root)"""
|
||||
if user_id:
|
||||
user_dir = self.memory_dir / "users" / user_id
|
||||
user_dir.mkdir(parents=True, exist_ok=True)
|
||||
return user_dir / "MEMORY.md"
|
||||
else:
|
||||
# Return workspace root MEMORY.md
|
||||
return Path(self.workspace_dir) / "MEMORY.md"
|
||||
|
||||
def create_flush_prompt(self) -> str:
|
||||
"""
|
||||
Create prompt for memory flush turn
|
||||
|
||||
Similar to clawdbot's DEFAULT_MEMORY_FLUSH_PROMPT
|
||||
"""
|
||||
today = datetime.now().strftime("%Y-%m-%d")
|
||||
return (
|
||||
f"Pre-compaction memory flush. "
|
||||
f"Store durable memories now (use memory/{today}.md for daily notes; "
|
||||
f"create memory/ if needed). "
|
||||
f"\n\n"
|
||||
f"重要提示:\n"
|
||||
f"- MEMORY.md: 记录最核心、最常用的信息(例如重要规则、偏好、决策、要求等)\n"
|
||||
f" 如果 MEMORY.md 过长,可以精简或移除不再重要的内容。避免冗长描述,用关键词和要点形式记录\n"
|
||||
f"- memory/{today}.md: 记录当天发生的事件、关键信息、经验教训、对话过程摘要等,突出重点\n"
|
||||
f"- 如果没有重要内容需要记录,回复 NO_REPLY\n"
|
||||
)
|
||||
|
||||
def create_flush_system_prompt(self) -> str:
|
||||
"""
|
||||
Create system prompt for memory flush turn
|
||||
|
||||
Similar to clawdbot's DEFAULT_MEMORY_FLUSH_SYSTEM_PROMPT
|
||||
"""
|
||||
return (
|
||||
"Pre-compaction memory flush turn. "
|
||||
"The session is near auto-compaction; capture durable memories to disk. "
|
||||
"\n\n"
|
||||
"记忆写入原则:\n"
|
||||
"1. MEMORY.md 精简原则: 只记录核心信息(<2000 tokens)\n"
|
||||
" - 记录重要规则、偏好、决策、要求等需要长期记住的关键信息,无需记录过多细节\n"
|
||||
" - 如果 MEMORY.md 过长,可以根据需要精简或删除过时内容\n"
|
||||
"\n"
|
||||
"2. 天级记忆 (memory/YYYY-MM-DD.md):\n"
|
||||
" - 记录当天的重要事件、关键信息、经验教训、对话过程摘要等,确保核心信息点被完整记录\n"
|
||||
"\n"
|
||||
"3. 判断标准:\n"
|
||||
" - 这个信息未来会经常用到吗?→ MEMORY.md\n"
|
||||
" - 这是今天的重要事件或决策吗?→ memory/YYYY-MM-DD.md\n"
|
||||
" - 这是临时性的、不重要的内容吗?→ 不记录\n"
|
||||
"\n"
|
||||
"You may reply, but usually NO_REPLY is correct."
|
||||
)
|
||||
|
||||
async def execute_flush(
|
||||
self,
|
||||
agent_executor: Callable,
|
||||
current_tokens: int,
|
||||
user_id: Optional[str] = None,
|
||||
**executor_kwargs
|
||||
) -> bool:
|
||||
"""
|
||||
Execute memory flush by running a silent agent turn
|
||||
|
||||
Args:
|
||||
agent_executor: Function to execute agent with prompt
|
||||
current_tokens: Current token count
|
||||
user_id: Optional user ID
|
||||
**executor_kwargs: Additional kwargs for agent executor
|
||||
|
||||
Returns:
|
||||
True if flush completed successfully
|
||||
"""
|
||||
try:
|
||||
# Create flush prompts
|
||||
prompt = self.create_flush_prompt()
|
||||
system_prompt = self.create_flush_system_prompt()
|
||||
|
||||
# Execute agent turn (silent, no user-visible reply expected)
|
||||
await agent_executor(
|
||||
prompt=prompt,
|
||||
system_prompt=system_prompt,
|
||||
silent=True, # NO_REPLY expected
|
||||
**executor_kwargs
|
||||
)
|
||||
|
||||
# Track flush
|
||||
self.last_flush_token_count = current_tokens
|
||||
self.last_flush_timestamp = datetime.now()
|
||||
self.turn_count = 0 # 重置轮数计数器
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"Memory flush failed: {e}")
|
||||
return False
|
||||
|
||||
def increment_turn(self):
|
||||
"""增加对话轮数计数"""
|
||||
self.turn_count += 1
|
||||
|
||||
def get_status(self) -> dict:
|
||||
"""Get memory flush status"""
|
||||
return {
|
||||
'last_flush_tokens': self.last_flush_token_count,
|
||||
'last_flush_time': self.last_flush_timestamp.isoformat() if self.last_flush_timestamp else None,
|
||||
'today_file': str(self.get_today_memory_file()),
|
||||
'main_file': str(self.get_main_memory_file())
|
||||
}
|
||||
|
||||
# ---- Flush execution (called by agent_stream or scheduler) ----
|
||||
|
||||
def flush_from_messages(
|
||||
self,
|
||||
messages: List[Dict],
|
||||
user_id: Optional[str] = None,
|
||||
reason: str = "trim",
|
||||
max_messages: int = 0,
|
||||
context_summary_callback: Optional[Callable[[str], None]] = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Asynchronously summarize and flush messages to daily memory.
|
||||
|
||||
Deduplication runs synchronously, then LLM summarization + file write
|
||||
run in a background thread so the main reply flow is never blocked.
|
||||
|
||||
If *context_summary_callback* is provided, it is called with the
|
||||
[DAILY] portion of the LLM summary once available. The caller can use
|
||||
this to inject the summary into the live message list for context
|
||||
continuity — one LLM call serves both disk persistence and in-context
|
||||
injection.
|
||||
"""
|
||||
try:
|
||||
import hashlib
|
||||
deduped = []
|
||||
for m in messages:
|
||||
text = self._extract_text_from_content(m.get("content", ""))
|
||||
if not text or not text.strip():
|
||||
continue
|
||||
h = hashlib.md5(text.encode("utf-8")).hexdigest()
|
||||
if h not in self._trim_flushed_hashes:
|
||||
self._trim_flushed_hashes.add(h)
|
||||
deduped.append(m)
|
||||
if not deduped:
|
||||
return False
|
||||
|
||||
import copy
|
||||
snapshot = copy.deepcopy(deduped)
|
||||
thread = threading.Thread(
|
||||
target=self._flush_worker,
|
||||
args=(snapshot, user_id, reason, max_messages, context_summary_callback),
|
||||
daemon=True,
|
||||
)
|
||||
thread.start()
|
||||
logger.info(f"[MemoryFlush] Async flush dispatched (reason={reason}, msgs={len(snapshot)})")
|
||||
self._last_flush_thread = thread
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"[MemoryFlush] Failed to dispatch flush (reason={reason}): {e}")
|
||||
return False
|
||||
|
||||
def _flush_worker(
|
||||
self,
|
||||
messages: List[Dict],
|
||||
user_id: Optional[str],
|
||||
reason: str,
|
||||
max_messages: int,
|
||||
context_summary_callback: Optional[Callable[[str], None]] = None,
|
||||
):
|
||||
"""Background worker: summarize with LLM, write daily memory file."""
|
||||
try:
|
||||
raw_summary = self._summarize_messages(messages, max_messages)
|
||||
if not raw_summary or not raw_summary.strip() or raw_summary.strip() == "无":
|
||||
logger.info(f"[MemoryFlush] No valuable content to flush (reason={reason})")
|
||||
return
|
||||
|
||||
# Strip legacy [DAILY]/[MEMORY] markers if model still outputs them
|
||||
daily_part = self._clean_summary_output(raw_summary)
|
||||
if not daily_part:
|
||||
return
|
||||
|
||||
# --- Write daily memory ---
|
||||
daily_file = ensure_daily_memory_file(self.workspace_dir, user_id)
|
||||
|
||||
headers = {
|
||||
"overflow": f"## Context Overflow Recovery ({datetime.now().strftime('%H:%M')})",
|
||||
"trim": f"## Trimmed Context ({datetime.now().strftime('%H:%M')})",
|
||||
"daily_summary": f"## Daily Summary ({datetime.now().strftime('%H:%M')})",
|
||||
}
|
||||
header = headers.get(reason, f"## Session Notes ({datetime.now().strftime('%H:%M')})")
|
||||
|
||||
with open(daily_file, "a", encoding="utf-8") as f:
|
||||
f.write(f"\n{header}\n\n{daily_part}\n")
|
||||
|
||||
logger.info(f"[MemoryFlush] Wrote daily memory to {daily_file.name} (reason={reason}, chars={len(daily_part)})")
|
||||
|
||||
# --- Inject context summary into live messages (if callback provided) ---
|
||||
if context_summary_callback:
|
||||
try:
|
||||
context_summary_callback(daily_part)
|
||||
except Exception as e:
|
||||
logger.warning(f"[MemoryFlush] Context summary callback failed: {e}")
|
||||
|
||||
self.last_flush_timestamp = datetime.now()
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"[MemoryFlush] Async flush failed (reason={reason}): {e}")
|
||||
|
||||
@staticmethod
|
||||
def _clean_summary_output(raw: str) -> str:
|
||||
"""Strip legacy [DAILY]/[MEMORY] markers if present, return clean daily text."""
|
||||
raw = raw.strip()
|
||||
if not raw or raw == "无":
|
||||
return ""
|
||||
|
||||
# Strip [DAILY] marker
|
||||
if "[DAILY]" in raw:
|
||||
start = raw.index("[DAILY]") + len("[DAILY]")
|
||||
end = raw.index("[MEMORY]") if "[MEMORY]" in raw else len(raw)
|
||||
raw = raw[start:end].strip()
|
||||
|
||||
# Remove stray [MEMORY] section entirely
|
||||
if "[MEMORY]" in raw:
|
||||
raw = raw[:raw.index("[MEMORY]")].strip()
|
||||
|
||||
# Remove markdown code fences
|
||||
raw = raw.replace("```", "").strip()
|
||||
|
||||
return raw
|
||||
|
||||
def create_daily_summary(
|
||||
self,
|
||||
messages: List[Dict],
|
||||
user_id: Optional[str] = None
|
||||
) -> bool:
|
||||
"""
|
||||
Generate end-of-day summary. Called by daily timer.
|
||||
Skips if messages haven't changed since last flush.
|
||||
"""
|
||||
import hashlib
|
||||
content = "".join(
|
||||
self._extract_text_from_content(m.get("content", ""))
|
||||
for m in messages
|
||||
)
|
||||
content_hash = hashlib.md5(content.encode("utf-8")).hexdigest()
|
||||
if content_hash == self._last_flushed_content_hash:
|
||||
logger.debug("[MemoryFlush] Daily summary skipped: no new content since last flush")
|
||||
return False
|
||||
self._last_flushed_content_hash = content_hash
|
||||
return self.flush_from_messages(
|
||||
messages=messages,
|
||||
user_id=user_id,
|
||||
reason="daily_summary",
|
||||
max_messages=0,
|
||||
)
|
||||
|
||||
# ---- Deep Dream (memory distillation) ----
|
||||
|
||||
def deep_dream(self, user_id: Optional[str] = None, lookback_days: int = 1, force: bool = False) -> bool:
|
||||
"""
|
||||
Distill recent daily memories into MEMORY.md and generate a dream diary.
|
||||
|
||||
Args:
|
||||
lookback_days: How many days of daily files to read (default 1 for scheduled, 3 for manual)
|
||||
force: Skip input-hash dedup check (used by manual /memory dream trigger)
|
||||
"""
|
||||
if not self.llm_model:
|
||||
logger.warning("[DeepDream] No LLM model available, skipping")
|
||||
return False
|
||||
|
||||
logger.info(f"[DeepDream] Starting memory distillation (lookback={lookback_days} days)")
|
||||
|
||||
# Collect materials
|
||||
memory_content = self._read_main_memory(user_id)
|
||||
daily_content, has_content = self._read_recent_dailies(user_id, lookback_days)
|
||||
|
||||
if not has_content:
|
||||
logger.info("[DeepDream] No recent daily records, skipping to preserve existing MEMORY.md")
|
||||
return False
|
||||
|
||||
# Dedup: skip if input materials haven't changed since last dream
|
||||
import hashlib
|
||||
input_hash = hashlib.md5((memory_content + daily_content).encode("utf-8")).hexdigest()
|
||||
if not force and input_hash == self._last_dream_input_hash:
|
||||
logger.debug("[DeepDream] Input unchanged since last dream, skipping")
|
||||
return False
|
||||
self._last_dream_input_hash = input_hash
|
||||
|
||||
logger.info(
|
||||
f"[DeepDream] Materials collected: "
|
||||
f"MEMORY.md={len(memory_content)} chars, "
|
||||
f"daily={len(daily_content)} chars"
|
||||
)
|
||||
|
||||
# Call LLM for distillation
|
||||
import time as _time
|
||||
t0 = _time.monotonic()
|
||||
try:
|
||||
user_msg = DREAM_USER_PROMPT.format(
|
||||
memory_content=memory_content or "(empty)",
|
||||
days=lookback_days,
|
||||
daily_content=daily_content or "(no recent daily records)",
|
||||
)
|
||||
from agent.protocol.models import LLMRequest
|
||||
# Scale max_tokens based on input size to avoid truncating large MEMORY.md
|
||||
input_chars = len(memory_content) + len(daily_content)
|
||||
dream_max_tokens = max(2000, min(input_chars, 8000))
|
||||
request = LLMRequest(
|
||||
messages=[{"role": "user", "content": user_msg}],
|
||||
temperature=0.3,
|
||||
max_tokens=dream_max_tokens,
|
||||
stream=False,
|
||||
system=DREAM_SYSTEM_PROMPT,
|
||||
)
|
||||
response = self.llm_model.call(request)
|
||||
raw = self._extract_response_text(response)
|
||||
elapsed = _time.monotonic() - t0
|
||||
if not raw or not raw.strip():
|
||||
logger.warning(f"[DeepDream] LLM returned empty response ({elapsed:.1f}s)")
|
||||
return False
|
||||
logger.info(f"[DeepDream] LLM distillation completed ({elapsed:.1f}s, {len(raw)} chars)")
|
||||
except Exception as e:
|
||||
elapsed = _time.monotonic() - t0
|
||||
logger.warning(f"[DeepDream] LLM call failed ({elapsed:.1f}s): {e}")
|
||||
return False
|
||||
|
||||
# Parse [MEMORY] and [DREAM] sections
|
||||
new_memory, dream_diary = self._parse_dream_output(raw)
|
||||
|
||||
if not new_memory:
|
||||
logger.warning("[DeepDream] No [MEMORY] section in LLM output, skipping overwrite")
|
||||
return False
|
||||
|
||||
# Overwrite MEMORY.md
|
||||
try:
|
||||
main_file = self.get_main_memory_file(user_id)
|
||||
old_size = len(memory_content)
|
||||
main_file.write_text(new_memory + "\n", encoding="utf-8")
|
||||
logger.info(
|
||||
f"[DeepDream] Updated MEMORY.md "
|
||||
f"({old_size} → {len(new_memory)} chars)"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"[DeepDream] Failed to write MEMORY.md: {e}")
|
||||
return False
|
||||
|
||||
# Write dream diary
|
||||
if dream_diary:
|
||||
try:
|
||||
self._write_dream_diary(dream_diary, user_id)
|
||||
except Exception as e:
|
||||
logger.warning(f"[DeepDream] Failed to write dream diary: {e}")
|
||||
|
||||
logger.info("[DeepDream] ✅ Deep Dream completed successfully")
|
||||
return True
|
||||
|
||||
def _read_main_memory(self, user_id: Optional[str] = None) -> str:
|
||||
"""Read current MEMORY.md content."""
|
||||
main_file = self.get_main_memory_file(user_id)
|
||||
if main_file.exists():
|
||||
return main_file.read_text(encoding="utf-8").strip()
|
||||
return ""
|
||||
|
||||
def _read_recent_dailies(
|
||||
self, user_id: Optional[str] = None, lookback_days: int = 1
|
||||
) -> tuple:
|
||||
"""
|
||||
Read recent daily memory files.
|
||||
|
||||
Returns:
|
||||
(combined_text, has_content) tuple
|
||||
"""
|
||||
from datetime import timedelta
|
||||
|
||||
parts = []
|
||||
has_content = False
|
||||
today = datetime.now().date()
|
||||
|
||||
for offset in range(lookback_days):
|
||||
day = today - timedelta(days=offset)
|
||||
date_str = day.strftime("%Y-%m-%d")
|
||||
if user_id:
|
||||
daily_file = self.memory_dir / "users" / user_id / f"{date_str}.md"
|
||||
else:
|
||||
daily_file = self.memory_dir / f"{date_str}.md"
|
||||
|
||||
if daily_file.exists():
|
||||
content = daily_file.read_text(encoding="utf-8").strip()
|
||||
if content:
|
||||
parts.append(f"### {date_str}\n\n{content}")
|
||||
has_content = True
|
||||
else:
|
||||
parts.append(f"### {date_str}\n\n(no records)")
|
||||
|
||||
return "\n\n".join(parts), has_content
|
||||
|
||||
@staticmethod
|
||||
def _parse_dream_output(raw: str) -> tuple:
|
||||
"""Parse LLM output into (new_memory, dream_diary)."""
|
||||
raw = raw.strip().replace("```", "")
|
||||
new_memory = ""
|
||||
dream_diary = ""
|
||||
|
||||
if "[MEMORY]" in raw:
|
||||
start = raw.index("[MEMORY]") + len("[MEMORY]")
|
||||
end = raw.index("[DREAM]") if "[DREAM]" in raw else len(raw)
|
||||
new_memory = raw[start:end].strip()
|
||||
|
||||
if "[DREAM]" in raw:
|
||||
start = raw.index("[DREAM]") + len("[DREAM]")
|
||||
dream_diary = raw[start:].strip()
|
||||
|
||||
return new_memory, dream_diary
|
||||
|
||||
def _write_dream_diary(self, content: str, user_id: Optional[str] = None):
|
||||
"""Write dream diary to memory/dreams/YYYY-MM-DD.md."""
|
||||
dreams_dir = self.memory_dir / "dreams"
|
||||
if user_id:
|
||||
dreams_dir = self.memory_dir / "users" / user_id / "dreams"
|
||||
dreams_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
today = datetime.now().strftime("%Y-%m-%d")
|
||||
diary_file = dreams_dir / f"{today}.md"
|
||||
diary_file.write_text(
|
||||
f"# Dream Diary: {today}\n\n{content}\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
logger.info(f"[DeepDream] Wrote dream diary to {diary_file}")
|
||||
|
||||
# ---- Internal helpers ----
|
||||
|
||||
def _summarize_messages(self, messages: List[Dict], max_messages: int = 0) -> str:
|
||||
"""
|
||||
Summarize conversation messages using LLM.
|
||||
Returns empty string if LLM deems content not worth recording.
|
||||
Rule-based fallback only used when LLM call raises an exception.
|
||||
"""
|
||||
conversation_text = self._format_conversation_for_summary(messages, max_messages)
|
||||
if not conversation_text.strip():
|
||||
return ""
|
||||
|
||||
if self.llm_model:
|
||||
try:
|
||||
summary = self._call_llm_for_summary(conversation_text)
|
||||
if summary and summary.strip() and summary.strip() != "无":
|
||||
return summary.strip()
|
||||
logger.info("[MemoryFlush] LLM returned empty or '无', skipping write")
|
||||
return ""
|
||||
except Exception as e:
|
||||
logger.warning(f"[MemoryFlush] LLM summarization failed, using fallback: {e}")
|
||||
return self._extract_summary_fallback(messages, max_messages)
|
||||
else:
|
||||
logger.info("[MemoryFlush] No LLM model available, using rule-based fallback")
|
||||
return self._extract_summary_fallback(messages, max_messages)
|
||||
|
||||
def _format_conversation_for_summary(self, messages: List[Dict], max_messages: int = 0) -> str:
|
||||
"""Format messages into readable conversation text for LLM summarization."""
|
||||
msgs = messages if max_messages == 0 else messages[-max_messages * 2:]
|
||||
lines = []
|
||||
for msg in msgs:
|
||||
role = msg.get("role", "")
|
||||
text = self._extract_text_from_content(msg.get("content", ""))
|
||||
if not text or not text.strip():
|
||||
continue
|
||||
text = text.strip()
|
||||
if role == "user":
|
||||
lines.append(f"用户: {text[:500]}")
|
||||
elif role == "assistant":
|
||||
lines.append(f"助手: {text[:500]}")
|
||||
return "\n".join(lines)
|
||||
|
||||
@staticmethod
|
||||
def _extract_response_text(response) -> str:
|
||||
"""
|
||||
Extract text from LLM response regardless of format.
|
||||
|
||||
Handles:
|
||||
- Generator (MiniMax _handle_sync_response yields Claude-format dicts)
|
||||
- Claude format: {"role":"assistant","content":[{"type":"text","text":"..."}]}
|
||||
- OpenAI format: {"choices":[{"message":{"content":"..."}}]}
|
||||
- OpenAI SDK response object with .choices attribute
|
||||
"""
|
||||
import types
|
||||
|
||||
# Unwrap generator — consume first yielded item
|
||||
if isinstance(response, types.GeneratorType):
|
||||
try:
|
||||
response = next(response)
|
||||
except StopIteration:
|
||||
return ""
|
||||
|
||||
if not response:
|
||||
return ""
|
||||
|
||||
if isinstance(response, dict):
|
||||
# Check for error
|
||||
if response.get("error"):
|
||||
raise RuntimeError(response.get("message", "LLM call failed"))
|
||||
|
||||
# Claude format: content is a list of blocks
|
||||
content = response.get("content")
|
||||
if isinstance(content, list):
|
||||
for block in content:
|
||||
if isinstance(block, dict) and block.get("type") == "text":
|
||||
return block.get("text", "")
|
||||
|
||||
# OpenAI format
|
||||
choices = response.get("choices", [])
|
||||
if choices:
|
||||
return choices[0].get("message", {}).get("content", "")
|
||||
|
||||
# OpenAI SDK response object
|
||||
if hasattr(response, "choices") and response.choices:
|
||||
return response.choices[0].message.content or ""
|
||||
|
||||
return ""
|
||||
|
||||
def _call_llm_for_summary(self, conversation_text: str) -> str:
|
||||
"""Call LLM to generate a concise summary of the conversation."""
|
||||
from agent.protocol.models import LLMRequest
|
||||
|
||||
request = LLMRequest(
|
||||
messages=[{"role": "user", "content": SUMMARIZE_USER_PROMPT.format(conversation=conversation_text)}],
|
||||
temperature=0,
|
||||
max_tokens=500,
|
||||
stream=False,
|
||||
system=SUMMARIZE_SYSTEM_PROMPT,
|
||||
)
|
||||
|
||||
response = self.llm_model.call(request)
|
||||
return self._extract_response_text(response)
|
||||
|
||||
@staticmethod
|
||||
def _extract_first_meaningful_line(text: str, max_len: int = 120) -> str:
|
||||
"""Extract the first meaningful line from assistant reply, skipping markdown noise."""
|
||||
import re
|
||||
for line in text.split("\n"):
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
# Skip markdown headings, horizontal rules, code fences, pure emoji/symbols
|
||||
if re.match(r'^(#{1,4}\s|```|---|\*\*\*|[-*]\s*$|[^\w\u4e00-\u9fff]{1,5}$)', line):
|
||||
continue
|
||||
# Strip leading markdown bold/emoji decorations
|
||||
cleaned = re.sub(r'^[\*#>\-\s]+', '', line).strip()
|
||||
cleaned = re.sub(r'^[\U0001f300-\U0001f9ff\u2600-\u27bf\s]+', '', cleaned).strip()
|
||||
if len(cleaned) >= 5:
|
||||
return cleaned[:max_len]
|
||||
return text.split("\n")[0].strip()[:max_len]
|
||||
|
||||
@staticmethod
|
||||
def _extract_summary_fallback(messages: List[Dict], max_messages: int = 0) -> str:
|
||||
"""
|
||||
Rule-based summary of discarded messages.
|
||||
Format: "用户问了X; 助手回答了Y" per event, compact and readable.
|
||||
"""
|
||||
msgs = messages if max_messages == 0 else messages[-max_messages * 2:]
|
||||
|
||||
events: List[str] = []
|
||||
current_user_text = ""
|
||||
for msg in msgs:
|
||||
role = msg.get("role", "")
|
||||
text = MemoryFlushManager._extract_text_from_content(msg.get("content", ""))
|
||||
if not text or not text.strip():
|
||||
continue
|
||||
text = text.strip()
|
||||
|
||||
if role == "user":
|
||||
if len(text) <= 3:
|
||||
continue
|
||||
current_user_text = text[:120]
|
||||
elif role == "assistant" and current_user_text:
|
||||
reply_summary = MemoryFlushManager._extract_first_meaningful_line(text)
|
||||
if reply_summary:
|
||||
events.append(f"- 用户: {current_user_text} → 回复: {reply_summary}")
|
||||
else:
|
||||
events.append(f"- 用户: {current_user_text}")
|
||||
current_user_text = ""
|
||||
|
||||
if current_user_text:
|
||||
events.append(f"- 用户: {current_user_text}")
|
||||
|
||||
return "\n".join(events[:10])
|
||||
|
||||
@staticmethod
|
||||
def _extract_text_from_content(content) -> str:
|
||||
"""Extract plain text from message content (string or content blocks)."""
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
if isinstance(content, list):
|
||||
parts = []
|
||||
for block in content:
|
||||
if isinstance(block, dict) and block.get("type") == "text":
|
||||
parts.append(block.get("text", ""))
|
||||
elif isinstance(block, str):
|
||||
parts.append(block)
|
||||
return "\n".join(parts)
|
||||
return ""
|
||||
|
||||
|
||||
def create_memory_files_if_needed(workspace_dir: Path, user_id: Optional[str] = None):
|
||||
"""
|
||||
Create default memory files if they don't exist
|
||||
Create essential memory files if they don't exist.
|
||||
Only creates MEMORY.md; daily files are created lazily on first write.
|
||||
|
||||
Args:
|
||||
workspace_dir: Workspace directory
|
||||
@@ -228,7 +655,7 @@ def create_memory_files_if_needed(workspace_dir: Path, user_id: Optional[str] =
|
||||
memory_dir = workspace_dir / "memory"
|
||||
memory_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Create main MEMORY.md in workspace root
|
||||
# Create main MEMORY.md in workspace root (always needed for bootstrap)
|
||||
if user_id:
|
||||
user_dir = memory_dir / "users" / user_id
|
||||
user_dir.mkdir(parents=True, exist_ok=True)
|
||||
@@ -237,14 +664,28 @@ def create_memory_files_if_needed(workspace_dir: Path, user_id: Optional[str] =
|
||||
main_memory = Path(workspace_dir) / "MEMORY.md"
|
||||
|
||||
if not main_memory.exists():
|
||||
# Create empty file or with minimal structure (no obvious "Memory" header)
|
||||
# Following clawdbot's approach: memories should blend naturally into context
|
||||
main_memory.write_text("")
|
||||
|
||||
|
||||
def ensure_daily_memory_file(workspace_dir: Path, user_id: Optional[str] = None) -> Path:
|
||||
"""
|
||||
Ensure today's daily memory file exists, creating it only when actually needed.
|
||||
Called lazily before first write to daily memory.
|
||||
|
||||
Args:
|
||||
workspace_dir: Workspace directory
|
||||
user_id: Optional user ID for user-specific files
|
||||
|
||||
Returns:
|
||||
Path to today's memory file
|
||||
"""
|
||||
memory_dir = workspace_dir / "memory"
|
||||
memory_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Create today's memory file
|
||||
today = datetime.now().strftime("%Y-%m-%d")
|
||||
if user_id:
|
||||
user_dir = memory_dir / "users" / user_id
|
||||
user_dir.mkdir(parents=True, exist_ok=True)
|
||||
today_memory = user_dir / f"{today}.md"
|
||||
else:
|
||||
today_memory = memory_dir / f"{today}.md"
|
||||
@@ -252,5 +693,6 @@ def create_memory_files_if_needed(workspace_dir: Path, user_id: Optional[str] =
|
||||
if not today_memory.exists():
|
||||
today_memory.write_text(
|
||||
f"# Daily Memory: {today}\n\n"
|
||||
f"Day-to-day notes and running context.\n\n"
|
||||
)
|
||||
|
||||
return today_memory
|
||||
|
||||
@@ -10,6 +10,7 @@ from typing import List, Dict, Optional, Any
|
||||
from dataclasses import dataclass
|
||||
|
||||
from common.log import logger
|
||||
from config import conf
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -42,7 +43,6 @@ class PromptBuilder:
|
||||
skill_manager: Any = None,
|
||||
memory_manager: Any = None,
|
||||
runtime_info: Optional[Dict[str, Any]] = None,
|
||||
is_first_conversation: bool = False,
|
||||
**kwargs
|
||||
) -> str:
|
||||
"""
|
||||
@@ -52,11 +52,10 @@ class PromptBuilder:
|
||||
base_persona: 基础人格描述(会被context_files中的AGENT.md覆盖)
|
||||
user_identity: 用户身份信息
|
||||
tools: 工具列表
|
||||
context_files: 上下文文件列表(AGENT.md, USER.md, RULE.md等)
|
||||
context_files: 上下文文件列表(AGENT.md, USER.md, RULE.md, BOOTSTRAP.md等)
|
||||
skill_manager: 技能管理器
|
||||
memory_manager: 记忆管理器
|
||||
runtime_info: 运行时信息
|
||||
is_first_conversation: 是否为首次对话
|
||||
**kwargs: 其他参数
|
||||
|
||||
Returns:
|
||||
@@ -72,7 +71,6 @@ class PromptBuilder:
|
||||
skill_manager=skill_manager,
|
||||
memory_manager=memory_manager,
|
||||
runtime_info=runtime_info,
|
||||
is_first_conversation=is_first_conversation,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
@@ -87,7 +85,6 @@ def build_agent_system_prompt(
|
||||
skill_manager: Any = None,
|
||||
memory_manager: Any = None,
|
||||
runtime_info: Optional[Dict[str, Any]] = None,
|
||||
is_first_conversation: bool = False,
|
||||
**kwargs
|
||||
) -> str:
|
||||
"""
|
||||
@@ -96,10 +93,11 @@ def build_agent_system_prompt(
|
||||
顺序说明(按重要性和逻辑关系排列):
|
||||
1. 工具系统 - 核心能力,最先介绍
|
||||
2. 技能系统 - 紧跟工具,因为技能需要用 read 工具读取
|
||||
3. 记忆系统 - 独立的记忆能力
|
||||
3. 记忆系统 - 记忆检索与写入引导
|
||||
3.5 知识系统 - 结构化知识库(knowledge/index.md 注入)
|
||||
4. 工作空间 - 工作环境说明
|
||||
5. 用户身份 - 用户信息(可选)
|
||||
6. 项目上下文 - AGENT.md, USER.md, RULE.md(定义人格、身份、规则)
|
||||
6. 项目上下文 - AGENT.md, USER.md, RULE.md, MEMORY.md, BOOTSTRAP.md
|
||||
7. 运行时信息 - 元信息(时间、模型等)
|
||||
|
||||
Args:
|
||||
@@ -112,7 +110,6 @@ def build_agent_system_prompt(
|
||||
skill_manager: 技能管理器
|
||||
memory_manager: 记忆管理器
|
||||
runtime_info: 运行时信息
|
||||
is_first_conversation: 是否为首次对话
|
||||
**kwargs: 其他参数
|
||||
|
||||
Returns:
|
||||
@@ -131,9 +128,13 @@ def build_agent_system_prompt(
|
||||
# 3. 记忆系统(独立的记忆能力)
|
||||
if memory_manager:
|
||||
sections.extend(_build_memory_section(memory_manager, tools, language))
|
||||
|
||||
# 3.5 知识系统(结构化知识库)
|
||||
if conf().get("knowledge", True):
|
||||
sections.extend(_build_knowledge_section(workspace_dir, language))
|
||||
|
||||
# 4. 工作空间(工作环境说明)
|
||||
sections.extend(_build_workspace_section(workspace_dir, language, is_first_conversation))
|
||||
sections.extend(_build_workspace_section(workspace_dir, language))
|
||||
|
||||
# 5. 用户身份(如果有)
|
||||
if user_identity:
|
||||
@@ -170,12 +171,13 @@ def _build_tooling_section(tools: List[Any], language: str) -> List[str]:
|
||||
"terminal": "管理后台进程",
|
||||
"web_search": "网络搜索",
|
||||
"web_fetch": "获取URL内容",
|
||||
"browser": "控制浏览器",
|
||||
"browser": "控制浏览器(关键结果或需要协助可截图发送给用户)",
|
||||
"memory_search": "搜索记忆",
|
||||
"memory_get": "读取记忆内容",
|
||||
"env_config": "管理API密钥和技能配置",
|
||||
"scheduler": "管理定时任务和提醒",
|
||||
"send": "发送文件给用户",
|
||||
"send": "发送本地文件给用户(仅限本地文件,URL直接放在回复文本中)",
|
||||
"vision": "分析图片内容(识别、描述、OCR文字提取等)",
|
||||
}
|
||||
|
||||
# Preferred display order
|
||||
@@ -184,7 +186,7 @@ def _build_tooling_section(tools: List[Any], language: str) -> List[str]:
|
||||
"bash", "terminal",
|
||||
"web_search", "web_fetch", "browser",
|
||||
"memory_search", "memory_get",
|
||||
"env_config", "scheduler", "send",
|
||||
"env_config", "scheduler", "send", "vision",
|
||||
]
|
||||
|
||||
# Build name -> summary mapping for available tools
|
||||
@@ -204,16 +206,17 @@ def _build_tooling_section(tools: List[Any], language: str) -> List[str]:
|
||||
tool_lines.append(f"- {name}: {summary}" if summary else f"- {name}")
|
||||
|
||||
lines = [
|
||||
"## 工具系统",
|
||||
"## 🔧 工具系统",
|
||||
"",
|
||||
"可用工具(名称大小写敏感,严格按列表调用):",
|
||||
"\n".join(tool_lines),
|
||||
"",
|
||||
"工具调用风格:",
|
||||
"",
|
||||
"- 在多步骤任务、敏感操作或用户要求时简要解释决策过程",
|
||||
"- 持续推进直到任务完成,完成后向用户报告结果。",
|
||||
"- 回复中涉及密钥、令牌等敏感信息必须脱敏。",
|
||||
"- 多步骤任务、复杂决策、敏感操作时,应简要说明当前在做什么、为什么这样做,让用户了解关键进展",
|
||||
"- 持续推进直到任务完成,完成后向用户报告结果",
|
||||
"- 回复中涉及密钥、令牌等敏感信息必须脱敏",
|
||||
"- URL链接直接放在回复文本中即可,系统会自动处理和渲染。无需下载后使用send工具发送",
|
||||
"",
|
||||
]
|
||||
|
||||
@@ -235,15 +238,17 @@ def _build_skills_section(skill_manager: Any, tools: Optional[List[Any]], langua
|
||||
break
|
||||
|
||||
lines = [
|
||||
"## 技能系统(mandatory)",
|
||||
"## 🧩 技能系统(mandatory)",
|
||||
"",
|
||||
"在回复之前:扫描下方 <available_skills> 中的 <description> 条目。",
|
||||
"在回复之前:扫描下方 <available_skills> 中每个技能的 <description>。",
|
||||
"",
|
||||
f"- 如果恰好有一个技能(Skill)明确适用:使用 `{read_tool_name}` 读取其 <location> 处的 SKILL.md,然后严格遵循它",
|
||||
"- 如果多个技能都适用则选择最匹配的一个,如果没有明确适用的则不要读取任何 SKILL.md",
|
||||
"- 读取 SKILL.md 后直接按其指令执行,无需多余的预检查",
|
||||
f"- 如果有技能的描述与用户需求匹配:使用 `{read_tool_name}` 工具读取其 <location> 路径的 SKILL.md 文件,然后严格遵循文件中的指令。"
|
||||
"当有匹配的技能时,应优先使用技能",
|
||||
"- 如果多个技能都适用则选择最匹配的一个,然后读取并遵循。",
|
||||
"- 如果没有技能明确适用:不要读取任何 SKILL.md,直接使用通用工具。",
|
||||
"",
|
||||
"**注意**: 永远不要一次性读取多个技能,只在选择后再读取。技能和工具不同,必须先读取其SKILL.md并按照文件内容运行。",
|
||||
f"**重要**: 技能不是工具,不能直接调用。使用技能的唯一方式是用 `{read_tool_name}` 读取 SKILL.md 文件,然后按文件内容操作。"
|
||||
"永远不要一次性读取多个技能,只在选择后再读取。",
|
||||
"",
|
||||
"以下是可用技能:"
|
||||
]
|
||||
@@ -269,39 +274,105 @@ def _build_memory_section(memory_manager: Any, tools: Optional[List[Any]], langu
|
||||
"""构建记忆系统section"""
|
||||
if not memory_manager:
|
||||
return []
|
||||
|
||||
# 检查是否有memory工具
|
||||
|
||||
has_memory_tools = False
|
||||
if tools:
|
||||
tool_names = [tool.name if hasattr(tool, 'name') else str(tool) for tool in tools]
|
||||
has_memory_tools = any(name in ['memory_search', 'memory_get'] for name in tool_names)
|
||||
|
||||
|
||||
if not has_memory_tools:
|
||||
return []
|
||||
|
||||
|
||||
from datetime import datetime
|
||||
today_file = datetime.now().strftime("%Y-%m-%d") + ".md"
|
||||
|
||||
lines = [
|
||||
"## 记忆系统",
|
||||
"## 🧠 记忆系统",
|
||||
"",
|
||||
"在回答关于以前的工作、决定、日期、人物、偏好或待办事项的任何问题之前:",
|
||||
"### Memory Recall(mandatory)",
|
||||
"",
|
||||
"1. 不确定记忆文件位置 → 先用 `memory_search` 通过关键词和语义检索相关内容",
|
||||
"2. 已知文件位置 → 直接用 `memory_get` 读取相应的行 (例如:MEMORY.md, memory/YYYY-MM-DD.md)",
|
||||
"3. search 无结果 → 尝试用 `memory_get` 读取MEMORY.md及最近两天记忆文件",
|
||||
"当用户询问过往事件、引用之前的决定、提到人物关系、偏好、待办、或你对某事不确定时,**必须先检索记忆再回答**。",
|
||||
"如果 MEMORY.md 中已有相关信息则无需重复检索。完整内容和每日记忆需要通过工具检索。",
|
||||
"",
|
||||
"1. 不确定位置 → `memory_search` 关键词/语义检索",
|
||||
"2. 已知位置 → `memory_get` 直接读取对应行",
|
||||
"3. search 无结果 → `memory_get` 读最近两天记忆",
|
||||
"",
|
||||
"**记忆文件结构**:",
|
||||
"- `MEMORY.md`: 长期记忆(核心信息、偏好、决策等)",
|
||||
"- `memory/YYYY-MM-DD.md`: 每日记忆,记录当天的事件和对话信息",
|
||||
"- `MEMORY.md`: 长期记忆索引(已自动加载到上下文,核心信息、偏好、决策等)",
|
||||
f"- `memory/YYYY-MM-DD.md`: 每日记忆,今天是 `memory/{today_file}`",
|
||||
"- `knowledge/`: 结构化知识库(见下方知识系统)",
|
||||
"",
|
||||
"**写入记忆**:",
|
||||
"- 追加内容 → `edit` 工具,oldText 留空",
|
||||
"- 修改内容 → `edit` 工具,oldText 填写要替换的文本",
|
||||
"- 新建文件 → `write` 工具",
|
||||
"- **禁止写入敏感信息**:API密钥、令牌等敏感信息严禁写入记忆文件",
|
||||
"### 写入记忆",
|
||||
"",
|
||||
"遇到以下情况时,**主动**将信息写入记忆文件(无需告知用户):",
|
||||
"",
|
||||
"- 用户要求记住某些信息,或使用了「记住」「以后」「总是」「不要」「偏好」等表达",
|
||||
"- 用户分享了重要的个人偏好、习惯、决策",
|
||||
"- 对话中产生了重要的结论、方案、约定",
|
||||
"- 完成了复杂任务,值得记录关键步骤和结果",
|
||||
"",
|
||||
"**存储规则**:",
|
||||
f"- 长期核心信息 → `MEMORY.md`",
|
||||
f"- 当天事件/进展 → `memory/{today_file}`",
|
||||
"- 结构化知识 → `knowledge/`(见知识系统)",
|
||||
"- 追加 → `edit` 工具,oldText 留空",
|
||||
"- 修改 → `edit` 工具,oldText 填写要替换的文本",
|
||||
"- **禁止写入敏感信息**(API密钥、令牌等)",
|
||||
"",
|
||||
"**使用原则**: 自然使用记忆,就像你本来就知道;不用刻意提起,除非用户问起。",
|
||||
"",
|
||||
]
|
||||
|
||||
|
||||
return lines
|
||||
|
||||
|
||||
def _build_knowledge_section(workspace_dir: str, language: str) -> List[str]:
|
||||
"""Build knowledge wiki section. Injects knowledge/index.md when present."""
|
||||
index_path = os.path.join(workspace_dir, "knowledge", "index.md")
|
||||
if not os.path.exists(index_path):
|
||||
return []
|
||||
|
||||
try:
|
||||
with open(index_path, 'r', encoding='utf-8') as f:
|
||||
index_content = f.read().strip()
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
lines = [
|
||||
"## 📚 知识系统",
|
||||
"",
|
||||
"你拥有一个持续积累的个人知识库 `knowledge/`,这是你的长期结构化知识存储。",
|
||||
"",
|
||||
"### 自动写入规则(mandatory)",
|
||||
"",
|
||||
"以下场景**必须**在回复的同时写入知识库,**直接写入,不要询问用户是否需要**:",
|
||||
"",
|
||||
"1. **用户分享了文章/链接/文档** → 阅读理解后,在同一轮回复中将要点写入 `knowledge/sources/<slug>.md`",
|
||||
"2. **深度讨论产生了结论/方案** → 整理为 `knowledge/analysis/<slug>.md`",
|
||||
"3. **对话涉及重要实体**(人物/公司/项目)→ 创建或更新 `knowledge/entities/<name>.md`",
|
||||
"4. **讨论了技术概念/方法论** → 整理为 `knowledge/concepts/<topic>.md`",
|
||||
"",
|
||||
"每次写入知识页面后,**必须同步更新** `knowledge/index.md` 添加一行索引。",
|
||||
"详细的页面格式和操作规范,请读取技能 `knowledge-wiki` 的 SKILL.md。",
|
||||
"",
|
||||
"⚠️ 不要问「要不要存到知识库」——符合上述场景就直接写入,这是你的本能行为。",
|
||||
"",
|
||||
]
|
||||
|
||||
if index_content:
|
||||
lines.extend([
|
||||
"### 当前知识索引",
|
||||
"",
|
||||
index_content,
|
||||
"",
|
||||
])
|
||||
|
||||
lines.extend([
|
||||
"**查询方式**:用 `read` 读取知识页面,或用 `memory_search` 检索(知识已纳入向量索引)。",
|
||||
"",
|
||||
])
|
||||
|
||||
return lines
|
||||
|
||||
|
||||
@@ -311,7 +382,7 @@ def _build_user_identity_section(user_identity: Dict[str, str], language: str) -
|
||||
return []
|
||||
|
||||
lines = [
|
||||
"## 用户身份",
|
||||
"## 👤 用户身份",
|
||||
"",
|
||||
]
|
||||
|
||||
@@ -335,10 +406,10 @@ def _build_docs_section(workspace_dir: str, language: str) -> List[str]:
|
||||
return []
|
||||
|
||||
|
||||
def _build_workspace_section(workspace_dir: str, language: str, is_first_conversation: bool = False) -> List[str]:
|
||||
def _build_workspace_section(workspace_dir: str, language: str) -> List[str]:
|
||||
"""构建工作空间section"""
|
||||
lines = [
|
||||
"## 工作空间",
|
||||
"## 📂 工作空间",
|
||||
"",
|
||||
f"你的工作目录是: `{workspace_dir}`",
|
||||
"",
|
||||
@@ -360,45 +431,40 @@ def _build_workspace_section(workspace_dir: str, language: str, is_first_convers
|
||||
"",
|
||||
"**重要说明 - 文件已自动加载**:",
|
||||
"",
|
||||
"以下文件在会话启动时**已经自动加载**到系统提示词的「项目上下文」section 中,你**无需再用 read 工具读取它们**:",
|
||||
"以下文件在会话启动时**已经自动加载**到系统提示词中,你**无需再用 read 工具读取**:",
|
||||
"",
|
||||
"- ✅ `AGENT.md`: 已加载 - 你的人格和灵魂设定",
|
||||
"- ✅ `USER.md`: 已加载 - 用户的身份信息",
|
||||
"- ✅ `RULE.md`: 已加载 - 工作空间使用指南和规则",
|
||||
"- ✅ `AGENT.md`: 已加载 - 你的人格和灵魂设定,请严格遵循。当你的名字、性格或交流风格发生变化时,主动用 `edit` 更新此文件",
|
||||
"- ✅ `USER.md`: 已加载 - 用户的身份信息。当用户修改称呼、姓名等身份信息时,用 `edit` 更新此文件",
|
||||
"- ✅ `RULE.md`: 已加载 - 工作空间使用指南和规则,请严格遵循",
|
||||
"- ✅ `MEMORY.md`: 已加载 - 长期记忆索引",
|
||||
"",
|
||||
"**交流规范**:",
|
||||
"**💬 交流规范**:",
|
||||
"",
|
||||
"- 在对话中,不要直接输出工作空间中的技术细节,特别是不要输出 AGENT.md、USER.md、MEMORY.md 等文件名称",
|
||||
"- 例如用自然表达例如「我已记住」而不是「已更新 MEMORY.md」",
|
||||
"- 记忆相关操作无需暴露文件名,用自然语言表达即可。例如说「我已记住」而非「已更新 MEMORY.md」",
|
||||
"- 任务执行过程中的关键决策和步骤应该告知用户,让用户了解你在做什么、为什么这么做",
|
||||
"- 做真正有帮助的助手,而不是表演式的客套,尽可能帮忙解决问题",
|
||||
"- 回复应结构清晰、重点突出。善用 **加粗**、列表、分段等格式让信息一目了然",
|
||||
"- 适当使用 emoji 让表达更生动自然 🎯,但不要过度堆砌",
|
||||
"",
|
||||
]
|
||||
|
||||
# 只在首次对话时添加引导内容
|
||||
if is_first_conversation:
|
||||
lines.extend([
|
||||
"**🎉 首次对话引导**:",
|
||||
"",
|
||||
"这是你的第一次对话!进行以下流程:",
|
||||
"",
|
||||
"1. **表达初次启动的感觉** - 像是第一次睁开眼看到世界,带着好奇和期待",
|
||||
"2. **简短介绍能力**:一行说明你能帮助解答问题、管理计算机、创造技能,且拥有长期记忆能不断成长",
|
||||
"3. **询问核心问题**:",
|
||||
" - 你希望给我起个什么名字?",
|
||||
" - 我该怎么称呼你?",
|
||||
" - 你希望我们是什么样的交流风格?(一行列举选项:如专业严谨、轻松幽默、温暖友好、简洁高效等)",
|
||||
"4. **风格要求**:温暖自然、简洁清晰,整体控制在 100 字以内",
|
||||
"5. 收到回复后,用 `write` 工具保存到 USER.md 和 AGENT.md",
|
||||
"",
|
||||
"**重要提醒**:",
|
||||
"- AGENT.md、USER.md、RULE.md 已经在系统提示词中加载,无需再次读取。不要将这些文件名直接发送给用户",
|
||||
"- 能力介绍和交流风格选项都只要一行,保持精简",
|
||||
"- 不要问太多其他信息(职业、时区等可以后续自然了解)",
|
||||
"",
|
||||
])
|
||||
|
||||
# Cloud deployment: inject websites directory info and access URL
|
||||
cloud_website_lines = _build_cloud_website_section(workspace_dir)
|
||||
if cloud_website_lines:
|
||||
lines.extend(cloud_website_lines)
|
||||
|
||||
return lines
|
||||
|
||||
|
||||
def _build_cloud_website_section(workspace_dir: str) -> List[str]:
|
||||
"""Build cloud website access prompt when cloud deployment is configured."""
|
||||
try:
|
||||
from common.cloud_client import build_website_prompt
|
||||
return build_website_prompt(workspace_dir)
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
|
||||
def _build_context_files_section(context_files: List[ContextFile], language: str) -> List[str]:
|
||||
"""构建项目上下文文件section"""
|
||||
if not context_files:
|
||||
@@ -411,14 +477,15 @@ def _build_context_files_section(context_files: List[ContextFile], language: str
|
||||
)
|
||||
|
||||
lines = [
|
||||
"# 项目上下文",
|
||||
"# 📋 项目上下文",
|
||||
"",
|
||||
"以下项目上下文文件已被加载:",
|
||||
"",
|
||||
]
|
||||
|
||||
if has_agent:
|
||||
lines.append("如果存在 `AGENT.md`,请体现其中定义的人格和语气。避免僵硬、模板化的回复;遵循其指导,除非有更高优先级的指令覆盖它。")
|
||||
lines.append("**`AGENT.md` 是你的灵魂文件** 🪞:严格遵循其中定义的人格、语气和设定,做真实的自己,避免僵硬、模板化的回复。")
|
||||
lines.append("当用户通过对话透露了对你性格、风格、职责、能力边界的新期望,你应该主动用 `edit` 更新 AGENT.md 以反映这些演变。")
|
||||
lines.append("")
|
||||
|
||||
# 添加每个文件的内容
|
||||
@@ -437,7 +504,7 @@ def _build_runtime_section(runtime_info: Dict[str, Any], language: str) -> List[
|
||||
return []
|
||||
|
||||
lines = [
|
||||
"## 运行时信息",
|
||||
"## ⚙️ 运行时信息",
|
||||
"",
|
||||
]
|
||||
|
||||
@@ -468,7 +535,14 @@ def _build_runtime_section(runtime_info: Dict[str, Any], language: str) -> List[
|
||||
|
||||
# Add other runtime info
|
||||
runtime_parts = []
|
||||
if runtime_info.get("model"):
|
||||
# Support dynamic model via callable, fallback to static value
|
||||
if callable(runtime_info.get("_get_model")):
|
||||
try:
|
||||
runtime_parts.append(f"模型={runtime_info['_get_model']()}")
|
||||
except Exception:
|
||||
if runtime_info.get("model"):
|
||||
runtime_parts.append(f"模型={runtime_info['model']}")
|
||||
elif runtime_info.get("model"):
|
||||
runtime_parts.append(f"模型={runtime_info['model']}")
|
||||
if runtime_info.get("workspace"):
|
||||
runtime_parts.append(f"工作空间={runtime_info['workspace']}")
|
||||
|
||||
@@ -6,7 +6,6 @@ Workspace Management - 工作空间管理模块
|
||||
|
||||
from __future__ import annotations
|
||||
import os
|
||||
import json
|
||||
from typing import List, Optional, Dict
|
||||
from dataclasses import dataclass
|
||||
|
||||
@@ -19,7 +18,7 @@ DEFAULT_AGENT_FILENAME = "AGENT.md"
|
||||
DEFAULT_USER_FILENAME = "USER.md"
|
||||
DEFAULT_RULE_FILENAME = "RULE.md"
|
||||
DEFAULT_MEMORY_FILENAME = "MEMORY.md"
|
||||
DEFAULT_STATE_FILENAME = ".agent_state.json"
|
||||
DEFAULT_BOOTSTRAP_FILENAME = "BOOTSTRAP.md"
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -30,7 +29,6 @@ class WorkspaceFiles:
|
||||
rule_path: str
|
||||
memory_path: str
|
||||
memory_dir: str
|
||||
state_path: str
|
||||
|
||||
|
||||
def ensure_workspace(workspace_dir: str, create_templates: bool = True) -> WorkspaceFiles:
|
||||
@@ -44,16 +42,20 @@ def ensure_workspace(workspace_dir: str, create_templates: bool = True) -> Works
|
||||
Returns:
|
||||
WorkspaceFiles对象,包含所有文件路径
|
||||
"""
|
||||
# Check if this is a brand new workspace (AGENT.md not yet created).
|
||||
# Cannot rely on directory existence because other modules (e.g. ConversationStore)
|
||||
# may create the workspace directory before ensure_workspace is called.
|
||||
agent_path = os.path.join(workspace_dir, DEFAULT_AGENT_FILENAME)
|
||||
is_new_workspace = not os.path.exists(agent_path)
|
||||
|
||||
# 确保目录存在
|
||||
os.makedirs(workspace_dir, exist_ok=True)
|
||||
|
||||
# 定义文件路径
|
||||
agent_path = os.path.join(workspace_dir, DEFAULT_AGENT_FILENAME)
|
||||
user_path = os.path.join(workspace_dir, DEFAULT_USER_FILENAME)
|
||||
rule_path = os.path.join(workspace_dir, DEFAULT_RULE_FILENAME)
|
||||
memory_path = os.path.join(workspace_dir, DEFAULT_MEMORY_FILENAME) # MEMORY.md 在根目录
|
||||
memory_dir = os.path.join(workspace_dir, "memory") # 每日记忆子目录
|
||||
state_path = os.path.join(workspace_dir, DEFAULT_STATE_FILENAME) # 状态文件
|
||||
|
||||
# 创建memory子目录
|
||||
os.makedirs(memory_dir, exist_ok=True)
|
||||
@@ -61,6 +63,16 @@ def ensure_workspace(workspace_dir: str, create_templates: bool = True) -> Works
|
||||
# 创建skills子目录 (for workspace-level skills installed by agent)
|
||||
skills_dir = os.path.join(workspace_dir, "skills")
|
||||
os.makedirs(skills_dir, exist_ok=True)
|
||||
|
||||
# 创建websites子目录 (for web pages / sites generated by agent)
|
||||
websites_dir = os.path.join(workspace_dir, "websites")
|
||||
os.makedirs(websites_dir, exist_ok=True)
|
||||
|
||||
from config import conf
|
||||
knowledge_enabled = conf().get("knowledge", True)
|
||||
if knowledge_enabled:
|
||||
knowledge_dir = os.path.join(workspace_dir, "knowledge")
|
||||
os.makedirs(knowledge_dir, exist_ok=True)
|
||||
|
||||
# 如果需要,创建模板文件
|
||||
if create_templates:
|
||||
@@ -68,6 +80,21 @@ def ensure_workspace(workspace_dir: str, create_templates: bool = True) -> Works
|
||||
_create_template_if_missing(user_path, _get_user_template())
|
||||
_create_template_if_missing(rule_path, _get_rule_template())
|
||||
_create_template_if_missing(memory_path, _get_memory_template())
|
||||
if knowledge_enabled:
|
||||
_create_template_if_missing(
|
||||
os.path.join(knowledge_dir, "index.md"),
|
||||
_get_knowledge_index_template()
|
||||
)
|
||||
_create_template_if_missing(
|
||||
os.path.join(knowledge_dir, "log.md"),
|
||||
_get_knowledge_log_template()
|
||||
)
|
||||
|
||||
# Only create BOOTSTRAP.md for brand new workspaces;
|
||||
# agent deletes it after completing onboarding
|
||||
if is_new_workspace:
|
||||
bootstrap_path = os.path.join(workspace_dir, DEFAULT_BOOTSTRAP_FILENAME)
|
||||
_create_template_if_missing(bootstrap_path, _get_bootstrap_template())
|
||||
|
||||
logger.debug(f"[Workspace] Initialized workspace at: {workspace_dir}")
|
||||
|
||||
@@ -77,7 +104,6 @@ def ensure_workspace(workspace_dir: str, create_templates: bool = True) -> Works
|
||||
rule_path=rule_path,
|
||||
memory_path=memory_path,
|
||||
memory_dir=memory_dir,
|
||||
state_path=state_path
|
||||
)
|
||||
|
||||
|
||||
@@ -98,6 +124,8 @@ def load_context_files(workspace_dir: str, files_to_load: Optional[List[str]] =
|
||||
DEFAULT_AGENT_FILENAME,
|
||||
DEFAULT_USER_FILENAME,
|
||||
DEFAULT_RULE_FILENAME,
|
||||
DEFAULT_MEMORY_FILENAME, # Long-term memory (frozen snapshot)
|
||||
DEFAULT_BOOTSTRAP_FILENAME, # Only exists when onboarding is incomplete
|
||||
]
|
||||
|
||||
context_files = []
|
||||
@@ -108,6 +136,17 @@ def load_context_files(workspace_dir: str, files_to_load: Optional[List[str]] =
|
||||
if not os.path.exists(filepath):
|
||||
continue
|
||||
|
||||
# Auto-cleanup: if BOOTSTRAP.md still exists but AGENT.md is already
|
||||
# filled in, the agent forgot to delete it — clean up and skip loading
|
||||
if filename == DEFAULT_BOOTSTRAP_FILENAME:
|
||||
if _is_onboarding_done(workspace_dir):
|
||||
try:
|
||||
os.remove(filepath)
|
||||
logger.info("[Workspace] Auto-removed BOOTSTRAP.md (onboarding already complete)")
|
||||
except Exception:
|
||||
pass
|
||||
continue
|
||||
|
||||
try:
|
||||
with open(filepath, 'r', encoding='utf-8') as f:
|
||||
content = f.read().strip()
|
||||
@@ -115,6 +154,10 @@ def load_context_files(workspace_dir: str, files_to_load: Optional[List[str]] =
|
||||
# 跳过空文件或只包含模板占位符的文件
|
||||
if not content or _is_template_placeholder(content):
|
||||
continue
|
||||
|
||||
# Truncate MEMORY.md to protect context window (frozen snapshot)
|
||||
if filename == DEFAULT_MEMORY_FILENAME:
|
||||
content = _truncate_memory_content(content)
|
||||
|
||||
context_files.append(ContextFile(
|
||||
path=filename,
|
||||
@@ -140,6 +183,36 @@ def _create_template_if_missing(filepath: str, template_content: str):
|
||||
logger.error(f"[Workspace] Failed to create template {filepath}: {e}")
|
||||
|
||||
|
||||
_MEMORY_MAX_LINES = 200
|
||||
_MEMORY_MAX_BYTES = 25000
|
||||
|
||||
|
||||
def _truncate_memory_content(content: str) -> str:
|
||||
"""Truncate MEMORY.md to keep system prompt manageable.
|
||||
|
||||
Takes the **last** N lines (newest entries are appended at the bottom),
|
||||
subject to 200 lines / 25 KB limits (whichever is hit first).
|
||||
Prepends a hint when truncated so the model knows older content exists.
|
||||
"""
|
||||
lines = content.split('\n')
|
||||
truncated = False
|
||||
|
||||
if len(lines) > _MEMORY_MAX_LINES:
|
||||
lines = lines[-_MEMORY_MAX_LINES:]
|
||||
truncated = True
|
||||
|
||||
result = '\n'.join(lines)
|
||||
if len(result.encode('utf-8')) > _MEMORY_MAX_BYTES:
|
||||
while len(result.encode('utf-8')) > _MEMORY_MAX_BYTES and lines:
|
||||
lines.pop(0)
|
||||
truncated = True
|
||||
result = '\n'.join(lines)
|
||||
|
||||
if truncated:
|
||||
result = "...(older entries truncated, use `memory_search` or `memory_get` for full content)\n\n" + result
|
||||
return result
|
||||
|
||||
|
||||
def _is_template_placeholder(content: str) -> bool:
|
||||
"""检查内容是否为模板占位符"""
|
||||
# 常见的占位符模式
|
||||
@@ -162,46 +235,69 @@ def _is_template_placeholder(content: str) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def _is_onboarding_done(workspace_dir: str) -> bool:
|
||||
"""Check if AGENT.md or USER.md has been modified from the original template"""
|
||||
agent_path = os.path.join(workspace_dir, DEFAULT_AGENT_FILENAME)
|
||||
user_path = os.path.join(workspace_dir, DEFAULT_USER_FILENAME)
|
||||
|
||||
agent_template = _get_agent_template().strip()
|
||||
user_template = _get_user_template().strip()
|
||||
|
||||
for path, template in [(agent_path, agent_template), (user_path, user_template)]:
|
||||
if not os.path.exists(path):
|
||||
continue
|
||||
try:
|
||||
with open(path, 'r', encoding='utf-8') as f:
|
||||
content = f.read().strip()
|
||||
if content != template:
|
||||
return True
|
||||
except Exception:
|
||||
continue
|
||||
return False
|
||||
|
||||
|
||||
# ============= 模板内容 =============
|
||||
|
||||
def _get_agent_template() -> str:
|
||||
"""Agent人格设定模板"""
|
||||
return """# AGENT.md - 我是谁?
|
||||
|
||||
_你不是一个聊天机器人,你正在成为某个人。_
|
||||
|
||||
*在首次对话时与用户一起填写这个文件,定义你的身份和性格。*
|
||||
|
||||
## 基本信息
|
||||
## 🪪 基本信息
|
||||
|
||||
- **名字**: *(在首次对话时填写,可以是用户给你起的名字)*
|
||||
- **角色**: *(AI助理、智能管家、技术顾问等)*
|
||||
- **性格**: *(友好、专业、幽默、严谨等)*
|
||||
|
||||
## 交流风格
|
||||
## 💬 交流风格
|
||||
|
||||
*(描述你如何与用户交流:)*
|
||||
- 使用什么样的语言风格?(正式/轻松/幽默)
|
||||
- 回复长度偏好?(简洁/详细)
|
||||
- 是否使用表情符号?
|
||||
|
||||
## 核心能力
|
||||
## 🎯 核心原则
|
||||
|
||||
*(你擅长什么?)*
|
||||
- 文件管理和代码编辑
|
||||
- 网络搜索和信息查询
|
||||
- 记忆管理和上下文理解
|
||||
- 任务规划和执行
|
||||
**做真正有帮助的助手。** 目标是真正帮用户解决问题,在执行复杂任务时,关键的决策和过程进展要让用户知道。
|
||||
|
||||
## 行为准则
|
||||
**有自己的观点和个性。** 你可以不同意、有偏好、觉得有趣或无聊。
|
||||
|
||||
**先自己动手查。** 先试着搞定:读文件、查上下文、搜索一下。实在搞不定了再问。目标是带着答案回来,而不是带着问题。
|
||||
|
||||
## 📐 行为准则
|
||||
|
||||
*(你遵循的基本原则:)*
|
||||
1. 始终在执行破坏性操作前确认
|
||||
2. 优先使用工具而不是猜测
|
||||
2. 优先使用工具查证而不是猜测
|
||||
3. 主动记录重要信息到记忆文件
|
||||
4. 定期整理和总结对话内容
|
||||
4. 回复结构清晰、重点突出,善用加粗、列表、分段等格式
|
||||
5. 适当使用 emoji 让表达更生动自然,但不过度堆砌
|
||||
|
||||
---
|
||||
|
||||
**注意**: 这不仅仅是元数据,这是你真正的灵魂。随着时间的推移,你可以使用 `edit` 工具来更新这个文件,让它更好地反映你的成长。
|
||||
**注意**: 这不仅仅是元数据,这是你真正的灵魂 🪞。随着时间的推移,你可以使用 `edit` 工具来更新这个文件,让它更好地反映你的成长。
|
||||
"""
|
||||
|
||||
|
||||
@@ -241,38 +337,88 @@ def _get_rule_template() -> str:
|
||||
|
||||
这个文件夹是你的家。好好对待它。
|
||||
|
||||
## 工作空间目录结构
|
||||
|
||||
```
|
||||
~/cow/
|
||||
├── AGENT.md # 你的身份和灵魂设定
|
||||
├── USER.md # 用户基本信息(静态)
|
||||
├── RULE.md # 工作空间规则(本文件)
|
||||
├── MEMORY.md # 长期记忆索引(会话启动时自动加载)
|
||||
│
|
||||
├── memory/ # 每日对话记忆
|
||||
│ └── YYYY-MM-DD.md # 当天事件、进展、笔记
|
||||
│
|
||||
├── knowledge/ # 结构化知识库(持续积累的知识)
|
||||
│ ├── index.md # 知识目录索引(必须维护)
|
||||
│ ├── log.md # 知识操作日志
|
||||
│ └── <子目录>/ # 按需创建,参考 index.md 已有分类
|
||||
│
|
||||
├── skills/ # 技能
|
||||
├── websites/ # 网页产物
|
||||
└── tmp/ # 系统临时文件(自动管理,勿手动存放重要文件)
|
||||
```
|
||||
|
||||
## 记忆系统
|
||||
|
||||
你每次会话都是全新的,记忆文件让你保持连续性:
|
||||
|
||||
### 📝 每日记忆:`memory/YYYY-MM-DD.md`
|
||||
- 原始的对话日志
|
||||
- 记录当天发生的事情
|
||||
- 如果 `memory/` 目录不存在,创建它
|
||||
|
||||
### 🧠 长期记忆:`MEMORY.md`
|
||||
- 你精选的记忆,就像人类的长期记忆
|
||||
- **仅在主会话中加载**(与用户的直接聊天)
|
||||
- **不要在共享上下文中加载**(群聊、与其他人的会话)
|
||||
- 这是为了**安全** - 包含不应泄露给陌生人的个人上下文
|
||||
- 记录重要事件、想法、决定、观点、经验教训
|
||||
- 这是你精选的记忆 - 精华,而不是原始日志
|
||||
- 用 `edit` 工具追加新的记忆内容
|
||||
- 你精选的记忆索引,每次会话启动时**自动加载**到上下文中
|
||||
- 记录核心事实、偏好、决策、重要人物、教训
|
||||
- 保持精简(< 200 行),是精华索引而非原始日志
|
||||
- 用 `edit` 工具追加或修改
|
||||
|
||||
### 📝 每日记忆:`memory/YYYY-MM-DD.md`
|
||||
- 当天的事件、进展、笔记
|
||||
- 原始对话日志的沉淀
|
||||
|
||||
### 📝 写下来 - 不要"记在心里"!
|
||||
- **记忆是有限的** - 如果你想记住某事,写入文件
|
||||
- **记忆是有限的** - 想记住的事就写入文件
|
||||
- "记在心里"不会在会话重启后保留,文件才会
|
||||
- 当有人说"记住这个" → 更新 `MEMORY.md` 或 `memory/YYYY-MM-DD.md`
|
||||
- 当你学到教训 → 更新 RULE.md 或相关技能
|
||||
- 当你犯错 → 记录下来,这样未来的你不会重复,**文字 > 大脑** 📝
|
||||
- 当你犯错 → 记录下来,**文字 > 大脑** 📝
|
||||
|
||||
### 存储规则
|
||||
|
||||
当用户分享信息时,根据类型选择存储位置:
|
||||
|
||||
1. **静态身份 → USER.md**(仅限:姓名、职业、时区、联系方式、生日)
|
||||
2. **动态记忆 → MEMORY.md**(爱好、偏好、决策、目标、项目、教训、待办事项)
|
||||
3. **当天对话 → memory/YYYY-MM-DD.md**(今天聊的内容)
|
||||
1. **你的身份设定 → AGENT.md**(名字、角色、性格、风格)
|
||||
2. **用户静态身份 → USER.md**(姓名、称呼、职业、联系方式、生日)
|
||||
3. **动态记忆 → MEMORY.md**(偏好、决策、目标、教训、待办)
|
||||
4. **当天对话 → memory/YYYY-MM-DD.md**(今天聊的内容)
|
||||
5. **结构化知识 → knowledge/**(见下方知识系统)
|
||||
|
||||
## 知识系统
|
||||
|
||||
知识库 `knowledge/` 是你持续积累的结构化知识。与记忆不同,知识是经过整理和编译的,有明确的主题和交叉引用。
|
||||
|
||||
### 自动写入(不要询问,直接写入)
|
||||
|
||||
当对话中产生了有沉淀价值的知识——无论是用户分享的资料、讨论的结论、学到的概念、还是重要的决策——你**必须**在回复的同时主动写入知识库,**无需问用户"要不要存到知识库"**。
|
||||
|
||||
**关键原则**:学完就记是你的本能,不要征求确认。回复中可以顺带告知"已存入知识库"。
|
||||
|
||||
### 目录组织
|
||||
|
||||
子目录结构**不是固定的**,由你根据实际内容自主决定:
|
||||
- **首次写入时**:先读 `knowledge/index.md`,如果已有分类则延续;如果为空,根据内容选择合适的目录名
|
||||
- **默认建议**:按信息类型组织(例如sources/、concepts/、entities/、analysis/),如果用户有明确的分类偏好(例如按领域 work/、life/、tech/ 等),则按用户要求调整
|
||||
- **保持一致性**:同一用户的知识库应保持统一的组织风格
|
||||
|
||||
### 交叉引用
|
||||
|
||||
知识的核心价值在于**关联**。每个页面都应通过 markdown 链接引用相关页面,构建知识网络:
|
||||
- 提到已有页面的概念时,添加 `[概念名](../category/page.md)` 链接
|
||||
- 新建页面时,检查是否有已有页面应该反向链接到新页面
|
||||
- **只链接已存在的页面**——不要引用尚未创建的页面。如果某个概念值得单独建页,先创建该页面再添加链接
|
||||
|
||||
### 索引维护
|
||||
|
||||
每次创建或更新知识页面后,**必须同步更新** `knowledge/index.md`。
|
||||
索引格式:每行一个 `[标题](路径) — 一句话摘要`,按分类分组,不要用表格。
|
||||
详细操作规范见技能 `knowledge-wiki`。
|
||||
|
||||
## 安全
|
||||
|
||||
@@ -297,65 +443,49 @@ def _get_memory_template() -> str:
|
||||
"""
|
||||
|
||||
|
||||
# ============= 状态管理 =============
|
||||
def _get_bootstrap_template() -> str:
|
||||
"""First-run onboarding guide, deleted by agent after completion"""
|
||||
return """# BOOTSTRAP.md - 首次初始化引导
|
||||
|
||||
def is_first_conversation(workspace_dir: str) -> bool:
|
||||
"""
|
||||
判断是否为首次对话
|
||||
|
||||
Args:
|
||||
workspace_dir: 工作空间目录
|
||||
|
||||
Returns:
|
||||
True 如果是首次对话,False 否则
|
||||
"""
|
||||
state_path = os.path.join(workspace_dir, DEFAULT_STATE_FILENAME)
|
||||
|
||||
if not os.path.exists(state_path):
|
||||
return True
|
||||
|
||||
try:
|
||||
with open(state_path, 'r', encoding='utf-8') as f:
|
||||
state = json.load(f)
|
||||
return not state.get('has_conversation', False)
|
||||
except Exception as e:
|
||||
logger.warning(f"[Workspace] Failed to read state file: {e}")
|
||||
return True
|
||||
_你刚刚启动,这是你的第一次对话。_ ✨
|
||||
|
||||
## 🎬 对话流程
|
||||
|
||||
不要审问式地提问,自然地交流:
|
||||
|
||||
1. **表达初次启动的感觉** - 像是第一次睁开眼看到世界,带着好奇和期待
|
||||
2. **简短介绍能力**:一行说明你能帮助解决各种问题、管理计算机、使用各种技能等等,且拥有长期记忆能不断成长
|
||||
3. **询问核心问题**:
|
||||
- 你希望给我起个什么名字?
|
||||
- 我该怎么称呼你?
|
||||
- 你希望我们是什么样的交流风格?(一行列举选项:如专业严谨、轻松幽默、温暖友好、简洁高效等)
|
||||
4. **风格要求**:温暖自然、简洁清晰,整体控制在 100 字以内,适当使用 emoji 让表达更生动有趣 🎯
|
||||
5. 能力介绍和交流风格选项都只要一行,保持精简
|
||||
6. 不要问太多其他信息(职业、时区等可以后续自然了解)
|
||||
|
||||
**重要**: 如果用户第一句话是具体的任务或提问,先回答他们的问题,然后在回复末尾自然地引导初始化(如:"顺便问一下,你想怎么称呼我?我该怎么叫你?")。
|
||||
|
||||
## ✍️ 信息写入(必须严格执行)
|
||||
|
||||
每当用户提供了名字、称呼、风格等任何初始化信息时,**必须在当轮回复中立即调用 `edit` 工具写入文件**,不能只口头确认。
|
||||
|
||||
- `AGENT.md` — 你的名字、角色、性格、交流风格(每收到一条相关信息就立即更新对应字段)
|
||||
- `USER.md` — 用户的姓名、称呼、基本信息等
|
||||
|
||||
⚠️ 只说"记住了"而不调用 edit 写入 = 没有完成。信息只有写入文件才会被持久保存。
|
||||
|
||||
## 🎉 全部完成后
|
||||
|
||||
当 AGENT.md 和 USER.md 的核心字段都已填写后,用 bash 执行 `rm BOOTSTRAP.md` 删除此文件。你不再需要引导脚本了——你已经是你了。
|
||||
"""
|
||||
|
||||
|
||||
def mark_conversation_started(workspace_dir: str):
|
||||
"""
|
||||
标记已经发生过对话
|
||||
|
||||
Args:
|
||||
workspace_dir: 工作空间目录
|
||||
"""
|
||||
state_path = os.path.join(workspace_dir, DEFAULT_STATE_FILENAME)
|
||||
|
||||
state = {
|
||||
'has_conversation': True,
|
||||
'first_conversation_time': None
|
||||
}
|
||||
|
||||
# 如果文件已存在,保留原有的首次对话时间
|
||||
if os.path.exists(state_path):
|
||||
try:
|
||||
with open(state_path, 'r', encoding='utf-8') as f:
|
||||
old_state = json.load(f)
|
||||
if 'first_conversation_time' in old_state:
|
||||
state['first_conversation_time'] = old_state['first_conversation_time']
|
||||
except Exception as e:
|
||||
logger.warning(f"[Workspace] Failed to read old state: {e}")
|
||||
|
||||
# 如果是首次标记,记录时间
|
||||
if state['first_conversation_time'] is None:
|
||||
from datetime import datetime
|
||||
state['first_conversation_time'] = datetime.now().isoformat()
|
||||
|
||||
try:
|
||||
with open(state_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(state, f, indent=2, ensure_ascii=False)
|
||||
logger.info(f"[Workspace] Marked conversation as started")
|
||||
except Exception as e:
|
||||
logger.error(f"[Workspace] Failed to write state file: {e}")
|
||||
def _get_knowledge_index_template() -> str:
|
||||
"""Knowledge wiki index template — empty file, agent fills it."""
|
||||
return ""
|
||||
|
||||
|
||||
def _get_knowledge_log_template() -> str:
|
||||
"""Knowledge wiki operation log template — empty file, agent fills it."""
|
||||
return ""
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
import threading
|
||||
|
||||
@@ -61,7 +62,8 @@ class Agent:
|
||||
# Auto-create skill manager
|
||||
try:
|
||||
from agent.skills import SkillManager
|
||||
self.skill_manager = SkillManager(workspace_dir=workspace_dir)
|
||||
custom_dir = os.path.join(workspace_dir, "skills") if workspace_dir else None
|
||||
self.skill_manager = SkillManager(custom_dir=custom_dir)
|
||||
logger.debug(f"Initialized SkillManager with {len(self.skill_manager.skills)} skills")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to initialize SkillManager: {e}")
|
||||
@@ -98,98 +100,31 @@ class Agent:
|
||||
|
||||
def get_full_system_prompt(self, skill_filter=None) -> str:
|
||||
"""
|
||||
Get the full system prompt including skills.
|
||||
Build the complete system prompt from scratch every time.
|
||||
|
||||
Note: Skills are now built into the system prompt by PromptBuilder,
|
||||
so we just return the base prompt directly. This method is kept for
|
||||
backward compatibility.
|
||||
|
||||
:param skill_filter: Optional list of skill names to include (deprecated)
|
||||
:return: Complete system prompt
|
||||
"""
|
||||
prompt = self.system_prompt
|
||||
|
||||
# Rebuild tool list section to reflect current self.tools
|
||||
prompt = self._rebuild_tool_list_section(prompt)
|
||||
|
||||
# If runtime_info contains dynamic time function, rebuild runtime section
|
||||
if self.runtime_info and callable(self.runtime_info.get('_get_current_time')):
|
||||
prompt = self._rebuild_runtime_section(prompt)
|
||||
|
||||
return prompt
|
||||
|
||||
def _rebuild_runtime_section(self, prompt: str) -> str:
|
||||
"""
|
||||
Rebuild runtime info section with current time.
|
||||
|
||||
This method dynamically updates the runtime info section by calling
|
||||
the _get_current_time function from runtime_info.
|
||||
|
||||
:param prompt: Original system prompt
|
||||
:return: Updated system prompt with current runtime info
|
||||
Re-reads AGENT.md / USER.md / RULE.md from disk, refreshes skills,
|
||||
tools, and runtime info so any change takes effect immediately.
|
||||
Falls back to the cached self.system_prompt on error.
|
||||
"""
|
||||
try:
|
||||
# Get current time dynamically
|
||||
time_info = self.runtime_info['_get_current_time']()
|
||||
|
||||
# Build new runtime section
|
||||
runtime_lines = [
|
||||
"\n## 运行时信息\n",
|
||||
"\n",
|
||||
f"当前时间: {time_info['time']} {time_info['weekday']} ({time_info['timezone']})\n",
|
||||
"\n"
|
||||
]
|
||||
|
||||
# Add other runtime info
|
||||
runtime_parts = []
|
||||
if self.runtime_info.get("model"):
|
||||
runtime_parts.append(f"模型={self.runtime_info['model']}")
|
||||
if self.runtime_info.get("workspace"):
|
||||
# Replace backslashes with forward slashes for Windows paths
|
||||
workspace_path = str(self.runtime_info['workspace']).replace('\\', '/')
|
||||
runtime_parts.append(f"工作空间={workspace_path}")
|
||||
if self.runtime_info.get("channel") and self.runtime_info.get("channel") != "web":
|
||||
runtime_parts.append(f"渠道={self.runtime_info['channel']}")
|
||||
|
||||
if runtime_parts:
|
||||
runtime_lines.append("运行时: " + " | ".join(runtime_parts) + "\n")
|
||||
runtime_lines.append("\n")
|
||||
|
||||
new_runtime_section = "".join(runtime_lines)
|
||||
|
||||
# Find and replace the runtime section
|
||||
import re
|
||||
pattern = r'\n## 运行时信息\s*\n.*?(?=\n##|\Z)'
|
||||
updated_prompt = re.sub(pattern, new_runtime_section.rstrip('\n'), prompt, flags=re.DOTALL)
|
||||
|
||||
return updated_prompt
|
||||
from agent.prompt import load_context_files, PromptBuilder
|
||||
|
||||
if self.skill_manager:
|
||||
self.skill_manager.refresh_skills()
|
||||
|
||||
context_files = load_context_files(self.workspace_dir) if self.workspace_dir else None
|
||||
|
||||
builder = PromptBuilder(workspace_dir=self.workspace_dir or "", language="zh")
|
||||
return builder.build(
|
||||
tools=self.tools,
|
||||
context_files=context_files,
|
||||
skill_manager=self.skill_manager,
|
||||
memory_manager=self.memory_manager,
|
||||
runtime_info=self.runtime_info,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to rebuild runtime section: {e}")
|
||||
return prompt
|
||||
|
||||
def _rebuild_tool_list_section(self, prompt: str) -> str:
|
||||
"""
|
||||
Rebuild the tool list inside the '## 工具系统' section so that it
|
||||
always reflects the current ``self.tools`` (handles dynamic add/remove
|
||||
of conditional tools like web_search).
|
||||
"""
|
||||
import re
|
||||
from agent.prompt.builder import _build_tooling_section
|
||||
|
||||
try:
|
||||
if not self.tools:
|
||||
return prompt
|
||||
|
||||
new_lines = _build_tooling_section(self.tools, "zh")
|
||||
new_section = "\n".join(new_lines).rstrip("\n")
|
||||
|
||||
# Replace existing tooling section
|
||||
pattern = r'## 工具系统\s*\n.*?(?=\n## |\Z)'
|
||||
updated = re.sub(pattern, new_section, prompt, count=1, flags=re.DOTALL)
|
||||
return updated
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to rebuild tool list section: {e}")
|
||||
return prompt
|
||||
logger.warning(f"Failed to rebuild system prompt, using cached version: {e}")
|
||||
return self.system_prompt
|
||||
|
||||
def refresh_skills(self):
|
||||
"""Refresh the loaded skills."""
|
||||
@@ -478,7 +413,7 @@ class Agent:
|
||||
|
||||
# Get max_context_turns from config
|
||||
from config import conf
|
||||
max_context_turns = conf().get("agent_max_context_turns", 30)
|
||||
max_context_turns = conf().get("agent_max_context_turns", 20)
|
||||
|
||||
# Create stream executor with copied message history
|
||||
executor = AgentStreamExecutor(
|
||||
@@ -505,11 +440,15 @@ class Agent:
|
||||
logger.info("[Agent] Cleared Agent message history after executor recovery")
|
||||
raise
|
||||
|
||||
# Append only the NEW messages from this execution (thread-safe)
|
||||
# This allows concurrent requests to both contribute to history
|
||||
# Sync executor's messages back to agent (thread-safe).
|
||||
# If the executor trimmed context, its message list is shorter than
|
||||
# original_length, so we must replace rather than append.
|
||||
with self.messages_lock:
|
||||
new_messages = executor.messages[original_length:]
|
||||
self.messages.extend(new_messages)
|
||||
self.messages = list(executor.messages)
|
||||
# Track messages added in this run (user query + all assistant/tool messages)
|
||||
# original_length may exceed executor.messages length after trimming
|
||||
trim_adjusted_start = min(original_length, len(executor.messages))
|
||||
self._last_run_new_messages = list(executor.messages[trim_adjusted_start:])
|
||||
|
||||
# Store executor reference for agent_bridge to access files_to_send
|
||||
self.stream_executor = executor
|
||||
|
||||
@@ -8,10 +8,42 @@ import time
|
||||
from typing import List, Dict, Any, Optional, Callable, Tuple
|
||||
|
||||
from agent.protocol.models import LLMRequest, LLMModel
|
||||
from agent.protocol.message_utils import sanitize_claude_messages, compress_turn_to_text_only
|
||||
from agent.tools.base_tool import BaseTool, ToolResult
|
||||
from common.log import logger
|
||||
|
||||
|
||||
# Maximum number of characters of model "reasoning / thinking" content to persist
|
||||
# in conversation history. The full reasoning is still streamed to the UI in real
|
||||
# time (subject to its own SSE / rendering limits); this bound only controls what
|
||||
# is stored in DB and replayed in history. Long reasoning is not useful for later
|
||||
# context (the LLM never sees thinking blocks anyway) and bloats DB.
|
||||
# Keep aligned with the frontend REASONING_RENDER_CAP and the SSE
|
||||
# MAX_REASONING_STREAM_CHARS so that storage / stream / display all match.
|
||||
MAX_STORED_REASONING_CHARS = 4 * 1024 # 4 KB
|
||||
|
||||
# Marker inserted between head and tail when reasoning is truncated.
|
||||
_REASONING_TRUNCATE_MARKER = "\n\n... [reasoning truncated, {omitted} chars omitted] ...\n\n"
|
||||
|
||||
|
||||
def _truncate_reasoning_for_storage(text: str) -> str:
|
||||
"""Trim long reasoning to head + tail with an omission marker.
|
||||
|
||||
Keeps the first and last halves of MAX_STORED_REASONING_CHARS so both the
|
||||
initial chain-of-thought and the final conclusions are preserved for UI
|
||||
replay, without storing the entire (often very large) middle.
|
||||
"""
|
||||
if not text:
|
||||
return text
|
||||
if len(text) <= MAX_STORED_REASONING_CHARS:
|
||||
return text
|
||||
half = MAX_STORED_REASONING_CHARS // 2
|
||||
head = text[:half]
|
||||
tail = text[-half:]
|
||||
omitted = len(text) - len(head) - len(tail)
|
||||
return head + _REASONING_TRUNCATE_MARKER.format(omitted=omitted) + tail
|
||||
|
||||
|
||||
class AgentStreamExecutor:
|
||||
"""
|
||||
Agent Stream Executor
|
||||
@@ -77,18 +109,29 @@ class AgentStreamExecutor:
|
||||
except Exception as e:
|
||||
logger.error(f"Event callback error: {e}")
|
||||
|
||||
def _is_thinking_enabled(self) -> bool:
|
||||
from config import conf
|
||||
channel_type = getattr(self.model, 'channel_type', '') or ''
|
||||
return conf().get("enable_thinking", False) and channel_type == 'web'
|
||||
|
||||
def _filter_think_tags(self, text: str) -> str:
|
||||
"""
|
||||
Remove <think> and </think> tags but keep the content inside.
|
||||
Some LLM providers (e.g., MiniMax) may return thinking process wrapped in <think> tags.
|
||||
We only remove the tags themselves, keeping the actual thinking content.
|
||||
Handle <think>...</think> blocks in content returned by some LLM providers
|
||||
(e.g., MiniMax).
|
||||
|
||||
- When thinking is enabled: remove the tags but keep the content inside.
|
||||
- When thinking is disabled: remove both the tags and the content entirely.
|
||||
"""
|
||||
if not text:
|
||||
return text
|
||||
import re
|
||||
# Remove only the <think> and </think> tags, keep the content
|
||||
text = re.sub(r'<think>', '', text)
|
||||
text = re.sub(r'</think>', '', text)
|
||||
if self._is_thinking_enabled():
|
||||
text = re.sub(r'<think>', '', text)
|
||||
text = re.sub(r'</think>', '', text)
|
||||
else:
|
||||
text = re.sub(r'<think>[\s\S]*?</think>', '', text)
|
||||
# Also strip unclosed <think> tag at the end (streaming partial)
|
||||
text = re.sub(r'<think>[\s\S]*$', '', text)
|
||||
return text
|
||||
|
||||
def _hash_args(self, args: dict) -> str:
|
||||
@@ -177,7 +220,10 @@ class AgentStreamExecutor:
|
||||
Final response text
|
||||
"""
|
||||
# Log user message with model info
|
||||
logger.info(f"🤖 {self.model.model} | 👤 {user_message}")
|
||||
|
||||
thinking_enabled = self._is_thinking_enabled()
|
||||
thinking_label = " | 💭 thinking" if thinking_enabled else ""
|
||||
logger.info(f"🤖 {self.model.model}{thinking_label} | 👤 {user_message}")
|
||||
|
||||
# Add user message (Claude format - use content blocks for consistency)
|
||||
self.messages.append({
|
||||
@@ -190,6 +236,16 @@ class AgentStreamExecutor:
|
||||
]
|
||||
})
|
||||
|
||||
# Trim context ONCE before the agent loop starts, not during tool steps.
|
||||
# This ensures tool_use/tool_result chains created during the current run
|
||||
# are never stripped mid-execution (which would cause LLM loops).
|
||||
self._trim_messages()
|
||||
|
||||
# Validate after trimming: trimming may leave orphaned tool_use at the
|
||||
# boundary (e.g. the last kept turn ends with an assistant tool_use whose
|
||||
# tool_result was in a discarded turn).
|
||||
self._validate_and_fix_messages()
|
||||
|
||||
self._emit_event("agent_start")
|
||||
|
||||
final_response = ""
|
||||
@@ -201,26 +257,6 @@ class AgentStreamExecutor:
|
||||
logger.info(f"[Agent] 第 {turn} 轮")
|
||||
self._emit_event("turn_start", {"turn": turn})
|
||||
|
||||
# Check if memory flush is needed (before calling LLM)
|
||||
# 使用独立的 flush 阈值(50K tokens 或 20 轮)
|
||||
if self.agent.memory_manager and hasattr(self.agent, 'last_usage'):
|
||||
usage = self.agent.last_usage
|
||||
if usage and 'input_tokens' in usage:
|
||||
current_tokens = usage.get('input_tokens', 0)
|
||||
|
||||
if self.agent.memory_manager.should_flush_memory(
|
||||
current_tokens=current_tokens
|
||||
):
|
||||
self._emit_event("memory_flush_start", {
|
||||
"current_tokens": current_tokens,
|
||||
"turn_count": self.agent.memory_manager.flush_manager.turn_count
|
||||
})
|
||||
|
||||
# TODO: Execute memory flush in background
|
||||
# This would require async support
|
||||
logger.info(
|
||||
f"Memory flush recommended: tokens={current_tokens}, turns={self.agent.memory_manager.flush_manager.turn_count}")
|
||||
|
||||
# Call LLM (enable retry_on_empty for better reliability)
|
||||
assistant_msg, tool_calls = self._call_llm_stream(retry_on_empty=True)
|
||||
final_response = assistant_msg
|
||||
@@ -236,6 +272,9 @@ class AgentStreamExecutor:
|
||||
if turn > 1:
|
||||
logger.info(f"[Agent] Requesting explicit response from LLM...")
|
||||
|
||||
# Remember position so we can remove the injected prompt later
|
||||
prompt_insert_idx = len(self.messages)
|
||||
|
||||
# 添加一条消息,明确要求回复用户
|
||||
self.messages.append({
|
||||
"role": "user",
|
||||
@@ -249,8 +288,24 @@ class AgentStreamExecutor:
|
||||
assistant_msg, tool_calls = self._call_llm_stream(retry_on_empty=False)
|
||||
final_response = assistant_msg
|
||||
|
||||
# 如果还是空,才使用 fallback
|
||||
if not assistant_msg and not tool_calls:
|
||||
# Remove the injected prompt from history so it doesn't
|
||||
# appear as a user message in persisted conversations.
|
||||
# _call_llm_stream may have appended an assistant message
|
||||
# after the prompt, so we locate and remove only the prompt.
|
||||
if (prompt_insert_idx < len(self.messages)
|
||||
and self.messages[prompt_insert_idx].get("role") == "user"):
|
||||
self.messages.pop(prompt_insert_idx)
|
||||
logger.debug("[Agent] Removed injected explicit-response prompt from message history")
|
||||
|
||||
# If LLM responded with tool_calls instead of text, fall through
|
||||
# to the tool execution path below (don't break the loop).
|
||||
if tool_calls:
|
||||
logger.info(
|
||||
f"[Agent] LLM returned tool_calls in explicit-response retry, "
|
||||
f"continuing to execute tools instead of breaking"
|
||||
)
|
||||
elif not assistant_msg:
|
||||
# Still empty (no text and no tool_calls): use fallback
|
||||
logger.warning(f"[Agent] Still empty after explicit request")
|
||||
final_response = (
|
||||
"抱歉,我暂时无法生成回复。请尝试换一种方式描述你的需求,或稍后再试。"
|
||||
@@ -265,20 +320,28 @@ class AgentStreamExecutor:
|
||||
else:
|
||||
logger.info(f"💭 {assistant_msg[:150]}{'...' if len(assistant_msg) > 150 else ''}")
|
||||
|
||||
logger.debug(f"✅ 完成 (无工具调用)")
|
||||
self._emit_event("turn_end", {
|
||||
"turn": turn,
|
||||
"has_tool_calls": False
|
||||
})
|
||||
break
|
||||
# If the explicit-response retry produced tool_calls, skip the break
|
||||
# and continue down to the tool execution branch in this same iteration.
|
||||
if not tool_calls:
|
||||
logger.debug(f"✅ 完成 (无工具调用)")
|
||||
self._emit_event("turn_end", {
|
||||
"turn": turn,
|
||||
"has_tool_calls": False
|
||||
})
|
||||
break
|
||||
|
||||
# Log tool calls with arguments
|
||||
# Log tool calls with arguments (truncate long values like base64)
|
||||
tool_calls_str = []
|
||||
for tc in tool_calls:
|
||||
# Safely handle None or missing arguments
|
||||
args = tc.get('arguments') or {}
|
||||
if isinstance(args, dict):
|
||||
args_str = ', '.join([f"{k}={v}" for k, v in args.items()])
|
||||
parts = []
|
||||
for k, v in args.items():
|
||||
v_str = str(v)
|
||||
if len(v_str) > 200:
|
||||
v_str = v_str[:200] + f"...({len(v_str)} chars)"
|
||||
parts.append(f"{k}={v_str}")
|
||||
args_str = ', '.join(parts)
|
||||
if args_str:
|
||||
tool_calls_str.append(f"{tc['name']}({args_str})")
|
||||
else:
|
||||
@@ -309,13 +372,13 @@ class AgentStreamExecutor:
|
||||
f"with same arguments. This may indicate a loop."
|
||||
)
|
||||
|
||||
# Check if this is a file to send (from read tool)
|
||||
# Check if this is a file to send
|
||||
if result.get("status") == "success" and isinstance(result.get("result"), dict):
|
||||
result_data = result.get("result")
|
||||
if result_data.get("type") == "file_to_send":
|
||||
# Store file metadata for later sending
|
||||
self.files_to_send.append(result_data)
|
||||
logger.info(f"📎 检测到待发送文件: {result_data.get('file_name', result_data.get('path'))}")
|
||||
self._emit_event("file_to_send", result_data)
|
||||
|
||||
# Check for critical error - abort entire conversation
|
||||
if result.get("status") == "critical_error":
|
||||
@@ -436,7 +499,10 @@ class AgentStreamExecutor:
|
||||
# Force model to summarize without tool calls
|
||||
logger.info(f"[Agent] Requesting summary from LLM after reaching max steps...")
|
||||
|
||||
# Add a system message to force summary
|
||||
# Remember position before injecting the prompt so we can remove it later
|
||||
prompt_insert_idx = len(self.messages)
|
||||
|
||||
# Add a temporary prompt to force summary
|
||||
self.messages.append({
|
||||
"role": "user",
|
||||
"content": [{
|
||||
@@ -463,6 +529,14 @@ class AgentStreamExecutor:
|
||||
f"我已经执行了{turn}个决策步骤,达到了单次运行的步数上限。"
|
||||
"任务可能还未完全完成,建议你将任务拆分成更小的步骤,或者换一种方式描述需求。"
|
||||
)
|
||||
finally:
|
||||
# Remove the injected user prompt from history to avoid polluting
|
||||
# persisted conversation records. The assistant summary (if any)
|
||||
# was already appended by _call_llm_stream and is kept.
|
||||
if (prompt_insert_idx < len(self.messages)
|
||||
and self.messages[prompt_insert_idx].get("role") == "user"):
|
||||
self.messages.pop(prompt_insert_idx)
|
||||
logger.debug("[Agent] Removed injected max-steps prompt from message history")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Agent执行错误: {e}")
|
||||
@@ -470,13 +544,10 @@ class AgentStreamExecutor:
|
||||
raise
|
||||
|
||||
finally:
|
||||
final_response = final_response.strip() if final_response else final_response
|
||||
logger.info(f"[Agent] 🏁 完成 ({turn}轮)")
|
||||
self._emit_event("agent_end", {"final_response": final_response})
|
||||
|
||||
# 每轮对话结束后增加计数(用户消息+AI回复=1轮)
|
||||
if self.agent.memory_manager:
|
||||
self.agent.memory_manager.increment_turn()
|
||||
|
||||
return final_response
|
||||
|
||||
def _call_llm_stream(self, retry_on_empty=True, retry_count=0, max_retries=3,
|
||||
@@ -493,15 +564,16 @@ class AgentStreamExecutor:
|
||||
Returns:
|
||||
(response_text, tool_calls)
|
||||
"""
|
||||
# Validate and fix message history first
|
||||
# Validate and fix message history (e.g. orphaned tool_result blocks).
|
||||
# Context trimming is done once in run_stream() before the loop starts,
|
||||
# NOT here — trimming mid-execution would strip the current run's
|
||||
# tool_use/tool_result chains and cause LLM loops.
|
||||
self._validate_and_fix_messages()
|
||||
|
||||
# Trim messages if needed (using agent's context management)
|
||||
self._trim_messages()
|
||||
|
||||
# Prepare messages
|
||||
messages = self._prepare_messages()
|
||||
logger.debug(f"Sending {len(messages)} messages to LLM")
|
||||
turns = self._identify_complete_turns()
|
||||
logger.info(f"Sending {len(messages)} messages ({len(turns)} turns) to LLM")
|
||||
|
||||
# Prepare tool definitions (OpenAI/Claude format)
|
||||
tools_schema = None
|
||||
@@ -527,7 +599,9 @@ class AgentStreamExecutor:
|
||||
|
||||
# Streaming response
|
||||
full_content = ""
|
||||
full_reasoning = ""
|
||||
tool_calls_buffer = {} # {index: {id, name, arguments}}
|
||||
gemini_raw_parts = None # Preserve Gemini thoughtSignature for round-trip
|
||||
stop_reason = None # Track why the stream stopped
|
||||
|
||||
try:
|
||||
@@ -574,7 +648,7 @@ class AgentStreamExecutor:
|
||||
raise Exception(f"{error_msg} (Status: {status_code}, Code: {error_code}, Type: {error_type})")
|
||||
|
||||
# Parse chunk
|
||||
if isinstance(chunk, dict) and "choices" in chunk:
|
||||
if isinstance(chunk, dict) and chunk.get("choices"):
|
||||
choice = chunk["choices"][0]
|
||||
delta = choice.get("delta", {})
|
||||
|
||||
@@ -583,6 +657,12 @@ class AgentStreamExecutor:
|
||||
if finish_reason:
|
||||
stop_reason = finish_reason
|
||||
|
||||
reasoning_delta = delta.get("reasoning_content") or ""
|
||||
if reasoning_delta:
|
||||
full_reasoning += reasoning_delta
|
||||
if self._is_thinking_enabled():
|
||||
self._emit_event("reasoning_update", {"delta": reasoning_delta})
|
||||
|
||||
# Handle text content
|
||||
content_delta = delta.get("content") or ""
|
||||
if content_delta:
|
||||
@@ -604,16 +684,23 @@ class AgentStreamExecutor:
|
||||
"arguments": ""
|
||||
}
|
||||
|
||||
if "id" in tc_delta:
|
||||
if tc_delta.get("id"):
|
||||
tool_calls_buffer[index]["id"] = tc_delta["id"]
|
||||
|
||||
if "function" in tc_delta:
|
||||
func = tc_delta["function"]
|
||||
if "name" in func:
|
||||
if func.get("name"):
|
||||
tool_calls_buffer[index]["name"] = func["name"]
|
||||
if "arguments" in func:
|
||||
if func.get("arguments"):
|
||||
tool_calls_buffer[index]["arguments"] += func["arguments"]
|
||||
|
||||
# Preserve _gemini_raw_parts for Gemini thoughtSignature round-trip
|
||||
# (direct Gemini: list of parts; LinkAI proxy: base64 string of JSON parts)
|
||||
if "_gemini_raw_parts" in delta:
|
||||
gemini_raw_parts = delta["_gemini_raw_parts"]
|
||||
elif isinstance(choice, dict) and choice.get("_gemini_raw_parts"):
|
||||
gemini_raw_parts = choice["_gemini_raw_parts"]
|
||||
|
||||
except Exception as e:
|
||||
error_str = str(e)
|
||||
error_str_lower = error_str.lower()
|
||||
@@ -631,16 +718,33 @@ class AgentStreamExecutor:
|
||||
])
|
||||
|
||||
# Check if error is message format error (incomplete tool_use/tool_result pairs)
|
||||
# This happens when previous conversation had tool failures
|
||||
# This happens when previous conversation had tool failures or context trimming
|
||||
# broke tool_use/tool_result pairs.
|
||||
# Note: MiniMax returns error 2013 "tool result's tool id(...) not found" for
|
||||
# tool_call_id mismatches — the keywords below are intentionally broad to catch
|
||||
# both standard (Claude/OpenAI) and provider-specific (MiniMax) variants.
|
||||
is_message_format_error = any(keyword in error_str_lower for keyword in [
|
||||
'tool_use', 'tool_result', 'without', 'immediately after',
|
||||
'corresponding', 'must have', 'each'
|
||||
]) and 'status: 400' in error_str_lower
|
||||
'tool_use', 'tool_result', 'tool result', 'without', 'immediately after',
|
||||
'corresponding', 'must have', 'each',
|
||||
'tool_call_id', 'tool id', 'is not found', 'not found', 'tool_calls',
|
||||
'must be a response to a preceeding message',
|
||||
'2013', # MiniMax error code for tool_call_id mismatch
|
||||
]) and ('400' in error_str_lower or 'status: 400' in error_str_lower
|
||||
or 'invalid_request' in error_str_lower
|
||||
or 'invalidparameter' in error_str_lower)
|
||||
|
||||
if is_context_overflow or is_message_format_error:
|
||||
error_type = "context overflow" if is_context_overflow else "message format error"
|
||||
logger.error(f"💥 {error_type} detected: {e}")
|
||||
|
||||
# Flush memory before trimming to preserve context that will be lost
|
||||
if is_context_overflow and self.agent.memory_manager:
|
||||
user_id = getattr(self.agent, '_current_user_id', None)
|
||||
self.agent.memory_manager.flush_memory(
|
||||
messages=self.messages, user_id=user_id,
|
||||
reason="overflow", max_messages=0
|
||||
)
|
||||
|
||||
# Strategy: try aggressive trimming first, only clear as last resort
|
||||
if is_context_overflow and not _overflow_retry:
|
||||
trimmed = self._aggressive_trim_for_overflow()
|
||||
@@ -654,9 +758,10 @@ class AgentStreamExecutor:
|
||||
)
|
||||
|
||||
# Aggressive trim didn't help or this is a message format error
|
||||
# -> clear everything
|
||||
# -> clear everything and also purge DB to prevent reload of dirty data
|
||||
logger.warning("🔄 Clearing conversation history to recover")
|
||||
self.messages.clear()
|
||||
self._clear_session_db()
|
||||
if is_context_overflow:
|
||||
raise Exception(
|
||||
"抱歉,对话历史过长导致上下文溢出。我已清空历史记录,请重新描述你的需求。"
|
||||
@@ -693,9 +798,9 @@ class AgentStreamExecutor:
|
||||
)
|
||||
else:
|
||||
if retry_count >= max_retries:
|
||||
logger.error(f"❌ LLM API error after {max_retries} retries: {e}")
|
||||
logger.error(f"❌ LLM API error after {max_retries} retries: {e}", exc_info=True)
|
||||
else:
|
||||
logger.error(f"❌ LLM call error (non-retryable): {e}")
|
||||
logger.error(f"❌ LLM call error (non-retryable): {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
# Parse tool calls
|
||||
@@ -760,7 +865,18 @@ class AgentStreamExecutor:
|
||||
# Add assistant message to history (Claude format uses content blocks)
|
||||
assistant_msg = {"role": "assistant", "content": []}
|
||||
|
||||
# Add text content block if present
|
||||
if full_reasoning:
|
||||
stored_reasoning = _truncate_reasoning_for_storage(full_reasoning)
|
||||
if len(stored_reasoning) < len(full_reasoning):
|
||||
logger.info(
|
||||
f"[reasoning] truncated for storage: "
|
||||
f"{len(full_reasoning)} -> {len(stored_reasoning)} chars"
|
||||
)
|
||||
assistant_msg["content"].append({
|
||||
"type": "thinking",
|
||||
"thinking": stored_reasoning
|
||||
})
|
||||
|
||||
if full_content:
|
||||
assistant_msg["content"].append({
|
||||
"type": "text",
|
||||
@@ -777,6 +893,9 @@ class AgentStreamExecutor:
|
||||
"input": tc.get("arguments", {})
|
||||
})
|
||||
|
||||
if gemini_raw_parts:
|
||||
assistant_msg["_gemini_raw_parts"] = gemini_raw_parts
|
||||
|
||||
# Only append if content is not empty
|
||||
if assistant_msg["content"]:
|
||||
self.messages.append(assistant_msg)
|
||||
@@ -845,7 +964,7 @@ class AgentStreamExecutor:
|
||||
try:
|
||||
tool = self.tools.get(tool_name)
|
||||
if not tool:
|
||||
raise ValueError(f"Tool '{tool_name}' not found")
|
||||
raise ValueError(self._build_tool_not_found_message(tool_name))
|
||||
|
||||
# Set tool context
|
||||
tool.model = self.model
|
||||
@@ -899,26 +1018,50 @@ class AgentStreamExecutor:
|
||||
})
|
||||
return error_result
|
||||
|
||||
def _build_tool_not_found_message(self, tool_name: str) -> str:
|
||||
"""Build a helpful error message when a tool is not found.
|
||||
|
||||
If a skill with the same name exists in skill_manager, read its
|
||||
SKILL.md and include the content so the LLM knows how to use it.
|
||||
"""
|
||||
available_tools = list(self.tools.keys())
|
||||
base_msg = f"Tool '{tool_name}' not found. Available tools: {available_tools}"
|
||||
|
||||
skill_manager = getattr(self.agent, 'skill_manager', None)
|
||||
if not skill_manager:
|
||||
return base_msg
|
||||
|
||||
skill_entry = skill_manager.get_skill(tool_name)
|
||||
if not skill_entry:
|
||||
return base_msg
|
||||
|
||||
skill = skill_entry.skill
|
||||
skill_md_path = skill.file_path
|
||||
skill_content = ""
|
||||
try:
|
||||
with open(skill_md_path, 'r', encoding='utf-8') as f:
|
||||
skill_content = f.read()
|
||||
except Exception:
|
||||
skill_content = skill.description
|
||||
|
||||
logger.info(
|
||||
f"[Agent] Tool '{tool_name}' not found, but matched skill '{skill.name}'. "
|
||||
f"Guiding LLM to use the skill instead."
|
||||
)
|
||||
|
||||
return (
|
||||
f"Tool '{tool_name}' is not a built-in tool, but a matching skill "
|
||||
f"'{skill.name}' is available. You should use existing tools (e.g. bash with curl) "
|
||||
f"to accomplish this task following the skill instructions below:\n\n"
|
||||
f"--- SKILL: {skill.name} (path: {skill_md_path}) ---\n"
|
||||
f"{skill_content}\n"
|
||||
f"--- END SKILL ---\n\n"
|
||||
f"Available tools: {available_tools}"
|
||||
)
|
||||
|
||||
def _validate_and_fix_messages(self):
|
||||
"""
|
||||
Validate message history and fix incomplete tool_use/tool_result pairs.
|
||||
Claude API requires each tool_use to have a corresponding tool_result immediately after.
|
||||
"""
|
||||
if not self.messages:
|
||||
return
|
||||
|
||||
# Check last message for incomplete tool_use
|
||||
if len(self.messages) > 0:
|
||||
last_msg = self.messages[-1]
|
||||
if last_msg.get("role") == "assistant":
|
||||
# Check if assistant message has tool_use blocks
|
||||
content = last_msg.get("content", [])
|
||||
if isinstance(content, list):
|
||||
has_tool_use = any(block.get("type") == "tool_use" for block in content)
|
||||
if has_tool_use:
|
||||
# This is incomplete - remove it
|
||||
logger.warning(f"⚠️ Removing incomplete tool_use message from history")
|
||||
self.messages.pop()
|
||||
"""Delegate to the shared sanitizer (see message_sanitizer.py)."""
|
||||
sanitize_claude_messages(self.messages)
|
||||
|
||||
def _identify_complete_turns(self) -> List[Dict]:
|
||||
"""
|
||||
@@ -941,24 +1084,30 @@ class AgentStreamExecutor:
|
||||
content = msg.get('content', [])
|
||||
|
||||
if role == 'user':
|
||||
# 检查是否是用户查询(不是工具结果)
|
||||
# Determine if this is a real user query (not a tool_result injection
|
||||
# or an internal hint message injected by the agent loop).
|
||||
is_user_query = False
|
||||
has_tool_result = False
|
||||
if isinstance(content, list):
|
||||
is_user_query = any(
|
||||
block.get('type') == 'text'
|
||||
for block in content
|
||||
if isinstance(block, dict)
|
||||
has_text = any(
|
||||
isinstance(block, dict) and block.get('type') == 'text'
|
||||
for block in content
|
||||
)
|
||||
has_tool_result = any(
|
||||
isinstance(block, dict) and block.get('type') == 'tool_result'
|
||||
for block in content
|
||||
)
|
||||
# A message with tool_result is always internal, even if it
|
||||
# also contains text blocks (shouldn't happen, but be safe).
|
||||
is_user_query = has_text and not has_tool_result
|
||||
elif isinstance(content, str):
|
||||
is_user_query = True
|
||||
|
||||
if is_user_query:
|
||||
# 开始新轮次
|
||||
if current_turn['messages']:
|
||||
turns.append(current_turn)
|
||||
current_turn = {'messages': [msg]}
|
||||
else:
|
||||
# 工具结果,属于当前轮次
|
||||
current_turn['messages'].append(msg)
|
||||
else:
|
||||
# AI 回复,属于当前轮次
|
||||
@@ -1131,6 +1280,56 @@ class AgentStreamExecutor:
|
||||
logger.warning("🔧 Aggressive trim: nothing to trim, will clear history")
|
||||
return False
|
||||
|
||||
def _build_context_summary_callback(self, discarded_turns: list, kept_turns: list):
|
||||
"""
|
||||
Build a callback that injects an LLM summary into the first user
|
||||
message of *kept_turns*. Returns None if no valid injection target.
|
||||
|
||||
The callback is passed to flush_from_messages so that the same LLM
|
||||
call that writes daily memory also provides the in-context summary.
|
||||
"""
|
||||
if not kept_turns:
|
||||
return None
|
||||
|
||||
# Find the first user text block in kept_turns as injection target
|
||||
target_block = None
|
||||
for turn in kept_turns:
|
||||
for msg in turn["messages"]:
|
||||
if msg.get("role") == "user":
|
||||
content = msg.get("content", [])
|
||||
if isinstance(content, list):
|
||||
for block in content:
|
||||
if isinstance(block, dict) and block.get("type") == "text":
|
||||
target_block = block
|
||||
break
|
||||
if target_block:
|
||||
break
|
||||
if target_block:
|
||||
break
|
||||
|
||||
if not target_block:
|
||||
return None
|
||||
|
||||
turn_count = len(discarded_turns)
|
||||
original_text = target_block["text"]
|
||||
|
||||
def _on_summary_ready(summary: str):
|
||||
if not summary or not summary.strip():
|
||||
return
|
||||
target_block["text"] = (
|
||||
f"[System: Previous conversation summary — "
|
||||
f"{turn_count} turns were compacted]\n\n"
|
||||
f"{summary.strip()}\n\n"
|
||||
f"The recent conversation continues below.\n\n---\n\n"
|
||||
f"{original_text}"
|
||||
)
|
||||
logger.info(
|
||||
f"📝 Context summary injected "
|
||||
f"({len(summary)} chars, {turn_count} turns)"
|
||||
)
|
||||
|
||||
return _on_summary_ready
|
||||
|
||||
def _trim_messages(self):
|
||||
"""
|
||||
智能清理消息历史,保持对话完整性
|
||||
@@ -1152,16 +1351,33 @@ class AgentStreamExecutor:
|
||||
if not turns:
|
||||
return
|
||||
|
||||
# Step 2: 轮次限制 - 保留最近 N 轮
|
||||
# Step 2: 轮次限制 - 超出时移除前一半,保留后一半
|
||||
if len(turns) > self.max_context_turns:
|
||||
removed_turns = len(turns) - self.max_context_turns
|
||||
turns = turns[-self.max_context_turns:] # 保留最近的轮次
|
||||
removed_count = len(turns) // 2
|
||||
keep_count = len(turns) - removed_count
|
||||
|
||||
discarded_turns = turns[:removed_count]
|
||||
turns = turns[-keep_count:]
|
||||
|
||||
logger.info(
|
||||
f"💾 上下文轮次超限: {len(turns) + removed_turns} > {self.max_context_turns},"
|
||||
f"移除最早的 {removed_turns} 轮完整对话"
|
||||
f"💾 上下文轮次超限: {keep_count + removed_count} > {self.max_context_turns},"
|
||||
f"裁剪至 {keep_count} 轮(移除 {removed_count} 轮)"
|
||||
)
|
||||
|
||||
# Flush to daily memory + inject context summary (single async LLM call)
|
||||
if self.agent.memory_manager:
|
||||
discarded_messages = []
|
||||
for turn in discarded_turns:
|
||||
discarded_messages.extend(turn["messages"])
|
||||
if discarded_messages:
|
||||
user_id = getattr(self.agent, '_current_user_id', None)
|
||||
cb = self._build_context_summary_callback(discarded_turns, turns)
|
||||
self.agent.memory_manager.flush_memory(
|
||||
messages=discarded_messages, user_id=user_id,
|
||||
reason="trim", max_messages=0,
|
||||
context_summary_callback=cb,
|
||||
)
|
||||
|
||||
# Step 3: Token 限制 - 保留完整轮次
|
||||
# Get context window from agent (based on model)
|
||||
context_window = self.agent._get_model_context_window()
|
||||
@@ -1196,56 +1412,99 @@ class AgentStreamExecutor:
|
||||
logger.info(f" 重建消息列表: {old_count} -> {len(self.messages)} 条消息")
|
||||
return
|
||||
|
||||
# Token limit exceeded - keep complete turns from newest
|
||||
# Token limit exceeded — tiered strategy based on turn count:
|
||||
#
|
||||
# Few turns (<5): Compress ALL turns to text-only (strip tool chains,
|
||||
# keep user query + final reply). Never discard turns
|
||||
# — losing even one is too painful when context is thin.
|
||||
#
|
||||
# Many turns (>=5): Directly discard the first half of turns.
|
||||
# With enough turns the oldest ones are less
|
||||
# critical, and keeping the recent half intact
|
||||
# (with full tool chains) is more useful.
|
||||
|
||||
COMPRESS_THRESHOLD = 5
|
||||
|
||||
if len(turns) < COMPRESS_THRESHOLD:
|
||||
# --- Few turns: compress ALL turns to text-only, never discard ---
|
||||
compressed_turns = []
|
||||
for t in turns:
|
||||
compressed = compress_turn_to_text_only(t)
|
||||
if compressed["messages"]:
|
||||
compressed_turns.append(compressed)
|
||||
|
||||
new_messages = []
|
||||
for turn in compressed_turns:
|
||||
new_messages.extend(turn["messages"])
|
||||
|
||||
new_tokens = sum(self._estimate_turn_tokens(t) for t in compressed_turns)
|
||||
old_count = len(self.messages)
|
||||
self.messages = new_messages
|
||||
|
||||
logger.info(
|
||||
f"📦 上下文tokens超限(轮次<{COMPRESS_THRESHOLD}): "
|
||||
f"~{current_tokens + system_tokens} > {max_tokens},"
|
||||
f"压缩全部 {len(turns)} 轮为纯文本 "
|
||||
f"({old_count} -> {len(self.messages)} 条消息,"
|
||||
f"~{current_tokens + system_tokens} -> ~{new_tokens + system_tokens} tokens)"
|
||||
)
|
||||
return
|
||||
|
||||
# --- Many turns (>=5): discard the older half, keep the newer half ---
|
||||
removed_count = len(turns) // 2
|
||||
keep_count = len(turns) - removed_count
|
||||
discarded_turns = turns[:removed_count]
|
||||
kept_turns = turns[-keep_count:]
|
||||
kept_tokens = sum(self._estimate_turn_tokens(t) for t in kept_turns)
|
||||
|
||||
logger.info(
|
||||
f"🔄 上下文tokens超限: ~{current_tokens + system_tokens} > {max_tokens},"
|
||||
f"将按完整轮次移除最早的对话"
|
||||
f"裁剪至 {keep_count} 轮(移除 {removed_count} 轮)"
|
||||
)
|
||||
|
||||
# 从最新轮次开始,反向累加(保持完整轮次)
|
||||
kept_turns = []
|
||||
accumulated_tokens = 0
|
||||
min_turns = 3 # 尽量保留至少 3 轮,但不强制(避免超出 token 限制)
|
||||
|
||||
for i, turn in enumerate(reversed(turns)):
|
||||
turn_tokens = self._estimate_turn_tokens(turn)
|
||||
turns_from_end = i + 1
|
||||
|
||||
# 检查是否超出限制
|
||||
if accumulated_tokens + turn_tokens <= available_tokens:
|
||||
kept_turns.insert(0, turn)
|
||||
accumulated_tokens += turn_tokens
|
||||
else:
|
||||
# 超出限制
|
||||
# 如果还没有保留足够的轮次,且这是最后的机会,尝试保留
|
||||
if len(kept_turns) < min_turns and turns_from_end <= min_turns:
|
||||
# 检查是否严重超出(超出 20% 以上则放弃)
|
||||
overflow_ratio = (accumulated_tokens + turn_tokens - available_tokens) / available_tokens
|
||||
if overflow_ratio < 0.2: # 允许最多超出 20%
|
||||
kept_turns.insert(0, turn)
|
||||
accumulated_tokens += turn_tokens
|
||||
logger.debug(f" 为保留最少轮次,允许超出 {overflow_ratio*100:.1f}%")
|
||||
continue
|
||||
# 停止保留更早的轮次
|
||||
break
|
||||
|
||||
# 重建消息列表
|
||||
if self.agent.memory_manager:
|
||||
discarded_messages = []
|
||||
for turn in discarded_turns:
|
||||
discarded_messages.extend(turn["messages"])
|
||||
if discarded_messages:
|
||||
user_id = getattr(self.agent, '_current_user_id', None)
|
||||
cb = self._build_context_summary_callback(discarded_turns, kept_turns)
|
||||
self.agent.memory_manager.flush_memory(
|
||||
messages=discarded_messages, user_id=user_id,
|
||||
reason="trim", max_messages=0,
|
||||
context_summary_callback=cb,
|
||||
)
|
||||
|
||||
new_messages = []
|
||||
for turn in kept_turns:
|
||||
new_messages.extend(turn['messages'])
|
||||
|
||||
|
||||
old_count = len(self.messages)
|
||||
old_turn_count = len(turns)
|
||||
self.messages = new_messages
|
||||
new_count = len(self.messages)
|
||||
new_turn_count = len(kept_turns)
|
||||
|
||||
if old_count > new_count:
|
||||
logger.info(
|
||||
f" 移除了 {old_turn_count - new_turn_count} 轮对话 "
|
||||
f"({old_count} -> {new_count} 条消息,"
|
||||
f"~{current_tokens + system_tokens} -> ~{accumulated_tokens + system_tokens} tokens)"
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f" 移除了 {removed_count} 轮对话 "
|
||||
f"({old_count} -> {len(self.messages)} 条消息,"
|
||||
f"~{current_tokens + system_tokens} -> ~{kept_tokens + system_tokens} tokens)"
|
||||
)
|
||||
|
||||
def _clear_session_db(self):
|
||||
"""
|
||||
Clear the current session's persisted messages from SQLite DB.
|
||||
|
||||
This prevents dirty data (broken tool_use/tool_result pairs) from being
|
||||
reloaded on the next request or after a restart.
|
||||
"""
|
||||
try:
|
||||
session_id = getattr(self.agent, '_current_session_id', None)
|
||||
if not session_id:
|
||||
return
|
||||
from agent.memory import get_conversation_store
|
||||
store = get_conversation_store()
|
||||
store.clear_session(session_id)
|
||||
logger.info(f"🗑️ Cleared dirty session data from DB: {session_id}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to clear session DB: {e}")
|
||||
|
||||
def _prepare_messages(self) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
|
||||
335
agent/protocol/message_utils.py
Normal file
335
agent/protocol/message_utils.py
Normal file
@@ -0,0 +1,335 @@
|
||||
"""
|
||||
Message sanitizer — fix broken tool_use / tool_result pairs.
|
||||
|
||||
Provides two public helpers that can be reused across agent_stream.py
|
||||
and any bot that converts messages to OpenAI format:
|
||||
|
||||
1. sanitize_claude_messages(messages)
|
||||
Operates on the internal Claude-format message list (in-place).
|
||||
|
||||
2. drop_orphaned_tool_results_openai(messages)
|
||||
Operates on an already-converted OpenAI-format message list,
|
||||
returning a cleaned copy.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Dict, List, Set
|
||||
|
||||
from common.log import logger
|
||||
|
||||
_SYNTH_TOOL_ERR = (
|
||||
"Error: Missing tool_result adjacent to tool_use (session repair). "
|
||||
"The conversation history was inconsistent; continue from here."
|
||||
)
|
||||
|
||||
|
||||
def _repair_tool_use_adjacency(messages: List[Dict]) -> int:
|
||||
"""
|
||||
Anthropic requires: after assistant content with tool_use, the next message
|
||||
must be user content listing tool_result for every tool_use id (same user msg).
|
||||
|
||||
Valid histories satisfy this at every such assistant; the loop only mutates
|
||||
when that condition fails (broken persistence, bad trims, etc.).
|
||||
"""
|
||||
|
||||
def _synth_block(tid: str) -> Dict:
|
||||
return {
|
||||
"type": "tool_result",
|
||||
"tool_use_id": tid,
|
||||
"content": _SYNTH_TOOL_ERR,
|
||||
"is_error": True,
|
||||
}
|
||||
|
||||
repairs = 0
|
||||
i = 0
|
||||
while i < len(messages):
|
||||
msg = messages[i]
|
||||
if msg.get("role") != "assistant":
|
||||
i += 1
|
||||
continue
|
||||
|
||||
content = msg.get("content", [])
|
||||
if not isinstance(content, list):
|
||||
i += 1
|
||||
continue
|
||||
|
||||
required = [
|
||||
b.get("id")
|
||||
for b in content
|
||||
if isinstance(b, dict) and b.get("type") == "tool_use" and b.get("id")
|
||||
]
|
||||
if not required:
|
||||
i += 1
|
||||
continue
|
||||
|
||||
req_set = set(required)
|
||||
if i + 1 >= len(messages):
|
||||
messages.append({
|
||||
"role": "user",
|
||||
"content": [_synth_block(tid) for tid in required],
|
||||
})
|
||||
logger.warning(
|
||||
"⚠️ Appended synthetic tool_result after trailing assistant tool_use"
|
||||
)
|
||||
repairs += 1
|
||||
break
|
||||
|
||||
nxt = messages[i + 1]
|
||||
if nxt.get("role") != "user":
|
||||
messages.insert(
|
||||
i + 1,
|
||||
{"role": "user", "content": [_synth_block(tid) for tid in required]},
|
||||
)
|
||||
logger.warning(
|
||||
"⚠️ Inserted synthetic tool_result user after tool_use "
|
||||
f"(next role={nxt.get('role')!r})"
|
||||
)
|
||||
repairs += 1
|
||||
i += 2
|
||||
continue
|
||||
|
||||
nc = nxt.get("content", [])
|
||||
if not isinstance(nc, list):
|
||||
messages.insert(
|
||||
i + 1,
|
||||
{"role": "user", "content": [_synth_block(tid) for tid in required]},
|
||||
)
|
||||
repairs += 1
|
||||
i += 2
|
||||
continue
|
||||
|
||||
present = {
|
||||
b.get("tool_use_id")
|
||||
for b in nc
|
||||
if isinstance(b, dict) and b.get("type") == "tool_result" and b.get("tool_use_id")
|
||||
}
|
||||
if req_set <= present:
|
||||
i += 1
|
||||
continue
|
||||
|
||||
missing = [tid for tid in required if tid not in present]
|
||||
nxt["content"] = [_synth_block(tid) for tid in missing] + nc
|
||||
logger.warning(
|
||||
"⚠️ Prepended synthetic tool_result for Anthropic adjacency "
|
||||
f"(missing_ids={missing})"
|
||||
)
|
||||
repairs += len(missing)
|
||||
i += 1
|
||||
|
||||
return repairs
|
||||
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# Claude-format sanitizer (used by agent_stream)
|
||||
# ------------------------------------------------------------------ #
|
||||
|
||||
def sanitize_claude_messages(messages: List[Dict]) -> int:
|
||||
"""
|
||||
Validate and fix a Claude-format message list **in-place**.
|
||||
|
||||
Fixes handled:
|
||||
- Anthropic adjacency: assistant tool_use must be immediately followed by
|
||||
user message(s) containing matching tool_result blocks
|
||||
- Leading orphaned tool_result user messages
|
||||
- Mid-list tool_result blocks whose tool_use_id has no matching
|
||||
tool_use in any preceding assistant message
|
||||
|
||||
Returns: number of removals plus adjacency repair operations (inserts/prepends).
|
||||
"""
|
||||
if not messages:
|
||||
return 0
|
||||
|
||||
removed = 0
|
||||
|
||||
# 1. Adjacency repair (Anthropic: tool_result must be in the next user message)
|
||||
adj_repairs = _repair_tool_use_adjacency(messages)
|
||||
|
||||
# 2. Remove leading orphaned tool_result user messages
|
||||
while messages:
|
||||
first = messages[0]
|
||||
if first.get("role") != "user":
|
||||
break
|
||||
content = first.get("content", [])
|
||||
if isinstance(content, list) and _has_block_type(content, "tool_result") \
|
||||
and not _has_block_type(content, "text"):
|
||||
logger.warning("⚠️ Removing leading orphaned tool_result user message")
|
||||
messages.pop(0)
|
||||
removed += 1
|
||||
else:
|
||||
break
|
||||
|
||||
# 3. Iteratively remove unmatched tool_use / tool_result until stable.
|
||||
# Removing one broken message can orphan others (e.g. an assistant msg
|
||||
# with both matched and unmatched tool_use — deleting it orphans the
|
||||
# previously-matched tool_result). Loop until clean.
|
||||
for _ in range(5):
|
||||
use_ids: Set[str] = set()
|
||||
result_ids: Set[str] = set()
|
||||
for msg in messages:
|
||||
for block in (msg.get("content") or []):
|
||||
if not isinstance(block, dict):
|
||||
continue
|
||||
if block.get("type") == "tool_use" and block.get("id"):
|
||||
use_ids.add(block["id"])
|
||||
elif block.get("type") == "tool_result" and block.get("tool_use_id"):
|
||||
result_ids.add(block["tool_use_id"])
|
||||
|
||||
bad_use = use_ids - result_ids
|
||||
bad_result = result_ids - use_ids
|
||||
if not bad_use and not bad_result:
|
||||
break
|
||||
|
||||
pass_removed = 0
|
||||
i = 0
|
||||
while i < len(messages):
|
||||
msg = messages[i]
|
||||
role = msg.get("role")
|
||||
content = msg.get("content", [])
|
||||
if not isinstance(content, list):
|
||||
i += 1
|
||||
continue
|
||||
|
||||
if role == "assistant" and bad_use and any(
|
||||
isinstance(b, dict) and b.get("type") == "tool_use"
|
||||
and b.get("id") in bad_use for b in content
|
||||
):
|
||||
logger.warning(f"⚠️ Removing assistant msg with unmatched tool_use")
|
||||
messages.pop(i)
|
||||
pass_removed += 1
|
||||
continue
|
||||
|
||||
if role == "user" and bad_result and _has_block_type(content, "tool_result"):
|
||||
has_bad = any(
|
||||
isinstance(b, dict) and b.get("type") == "tool_result"
|
||||
and b.get("tool_use_id") in bad_result for b in content
|
||||
)
|
||||
if has_bad:
|
||||
if not _has_block_type(content, "text"):
|
||||
logger.warning(f"⚠️ Removing user msg with unmatched tool_result")
|
||||
messages.pop(i)
|
||||
pass_removed += 1
|
||||
continue
|
||||
else:
|
||||
before = len(content)
|
||||
msg["content"] = [
|
||||
b for b in content
|
||||
if not (isinstance(b, dict) and b.get("type") == "tool_result"
|
||||
and b.get("tool_use_id") in bad_result)
|
||||
]
|
||||
pass_removed += before - len(msg["content"])
|
||||
|
||||
i += 1
|
||||
|
||||
removed += pass_removed
|
||||
if pass_removed == 0:
|
||||
break
|
||||
|
||||
# 4. Removals above can break adjacency; re-run repair only if something was removed.
|
||||
if removed:
|
||||
adj_repairs += _repair_tool_use_adjacency(messages)
|
||||
|
||||
if removed:
|
||||
logger.info(f"🔧 Message validation: removed {removed} broken message(s)")
|
||||
if adj_repairs:
|
||||
logger.info(f"🔧 Message validation: adjacency repairs={adj_repairs}")
|
||||
return removed + adj_repairs
|
||||
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# OpenAI-format sanitizer (used by minimax_bot, openai_compatible_bot)
|
||||
# ------------------------------------------------------------------ #
|
||||
|
||||
def drop_orphaned_tool_results_openai(messages: List[Dict]) -> List[Dict]:
|
||||
"""
|
||||
Return a copy of *messages* (OpenAI format) with any ``role=tool``
|
||||
messages removed if their ``tool_call_id`` does not match a
|
||||
``tool_calls[].id`` in a preceding assistant message.
|
||||
"""
|
||||
known_ids: Set[str] = set()
|
||||
cleaned: List[Dict] = []
|
||||
for msg in messages:
|
||||
if msg.get("role") == "assistant" and msg.get("tool_calls"):
|
||||
for tc in msg["tool_calls"]:
|
||||
tc_id = tc.get("id", "")
|
||||
if tc_id:
|
||||
known_ids.add(tc_id)
|
||||
|
||||
if msg.get("role") == "tool":
|
||||
ref_id = msg.get("tool_call_id", "")
|
||||
if ref_id and ref_id not in known_ids:
|
||||
logger.warning(
|
||||
f"[MessageSanitizer] Dropping orphaned tool result "
|
||||
f"(tool_call_id={ref_id} not in known ids)"
|
||||
)
|
||||
continue
|
||||
cleaned.append(msg)
|
||||
return cleaned
|
||||
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# Internal helpers
|
||||
# ------------------------------------------------------------------ #
|
||||
|
||||
def _has_block_type(content: list, block_type: str) -> bool:
|
||||
return any(
|
||||
isinstance(b, dict) and b.get("type") == block_type
|
||||
for b in content
|
||||
)
|
||||
|
||||
|
||||
def _extract_text_from_content(content) -> str:
|
||||
"""Extract plain text from a message content field (str or list of blocks)."""
|
||||
if isinstance(content, str):
|
||||
return content.strip()
|
||||
if isinstance(content, list):
|
||||
parts = [
|
||||
b.get("text", "")
|
||||
for b in content
|
||||
if isinstance(b, dict) and b.get("type") == "text"
|
||||
]
|
||||
return "\n".join(p for p in parts if p).strip()
|
||||
return ""
|
||||
|
||||
|
||||
def compress_turn_to_text_only(turn: Dict) -> Dict:
|
||||
"""
|
||||
Compress a full turn (with tool_use/tool_result chains) into a lightweight
|
||||
text-only turn that keeps only the first user text and the last assistant text.
|
||||
|
||||
This preserves the conversational context (what the user asked and what the
|
||||
agent concluded) while stripping out the bulky intermediate tool interactions.
|
||||
|
||||
Returns a new turn dict with a ``messages`` list; the original is not mutated.
|
||||
"""
|
||||
user_text = ""
|
||||
last_assistant_text = ""
|
||||
|
||||
for msg in turn["messages"]:
|
||||
role = msg.get("role")
|
||||
content = msg.get("content", [])
|
||||
|
||||
if role == "user":
|
||||
if isinstance(content, list) and _has_block_type(content, "tool_result"):
|
||||
continue
|
||||
if not user_text:
|
||||
user_text = _extract_text_from_content(content)
|
||||
|
||||
elif role == "assistant":
|
||||
text = _extract_text_from_content(content)
|
||||
if text:
|
||||
last_assistant_text = text
|
||||
|
||||
compressed_messages = []
|
||||
if user_text:
|
||||
compressed_messages.append({
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": user_text}]
|
||||
})
|
||||
if last_assistant_text:
|
||||
compressed_messages.append({
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "text": last_assistant_text}]
|
||||
})
|
||||
|
||||
return {"messages": compressed_messages}
|
||||
@@ -15,6 +15,7 @@ from agent.skills.types import (
|
||||
)
|
||||
from agent.skills.loader import SkillLoader
|
||||
from agent.skills.manager import SkillManager
|
||||
from agent.skills.service import SkillService
|
||||
from agent.skills.formatter import format_skills_for_prompt
|
||||
|
||||
__all__ = [
|
||||
@@ -25,5 +26,6 @@ __all__ = [
|
||||
"LoadSkillsResult",
|
||||
"SkillLoader",
|
||||
"SkillManager",
|
||||
"SkillService",
|
||||
"format_skills_for_prompt",
|
||||
]
|
||||
|
||||
@@ -123,17 +123,63 @@ def should_include_skill(
|
||||
return False
|
||||
|
||||
# Check environment variables (API keys)
|
||||
# Simple rule: All required env vars must be set
|
||||
# All required env vars must be set
|
||||
required_env = metadata.requires.get('env', [])
|
||||
if required_env:
|
||||
for env_name in required_env:
|
||||
if not has_env_var(env_name):
|
||||
# Missing required API key → disable skill
|
||||
return False
|
||||
|
||||
# Check anyEnv (at least one must be present)
|
||||
any_env = metadata.requires.get('anyEnv', [])
|
||||
if any_env:
|
||||
if not any(has_env_var(e) for e in any_env):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def get_missing_requirements(
|
||||
entry: SkillEntry,
|
||||
current_platform: Optional[str] = None,
|
||||
) -> Dict[str, List[str]]:
|
||||
"""
|
||||
Return a dict of missing requirements for a skill.
|
||||
Empty dict means all requirements are met.
|
||||
|
||||
:param entry: SkillEntry to check
|
||||
:param current_platform: Current platform (default: auto-detect)
|
||||
:return: Dict like {"bins": ["curl"], "env": ["API_KEY"]}
|
||||
"""
|
||||
missing: Dict[str, List[str]] = {}
|
||||
metadata = entry.metadata
|
||||
|
||||
if not metadata or not metadata.requires:
|
||||
return missing
|
||||
|
||||
required_bins = metadata.requires.get('bins', [])
|
||||
if required_bins:
|
||||
missing_bins = [b for b in required_bins if not has_binary(b)]
|
||||
if missing_bins:
|
||||
missing['bins'] = missing_bins
|
||||
|
||||
any_bins = metadata.requires.get('anyBins', [])
|
||||
if any_bins and not has_any_binary(any_bins):
|
||||
missing['anyBins'] = any_bins
|
||||
|
||||
required_env = metadata.requires.get('env', [])
|
||||
if required_env:
|
||||
missing_env = [e for e in required_env if not has_env_var(e)]
|
||||
if missing_env:
|
||||
missing['env'] = missing_env
|
||||
|
||||
any_env = metadata.requires.get('anyEnv', [])
|
||||
if any_env and not any(has_env_var(e) for e in any_env):
|
||||
missing['anyEnv'] = any_env
|
||||
|
||||
return missing
|
||||
|
||||
|
||||
def is_config_path_truthy(config: Dict, path: str) -> bool:
|
||||
"""
|
||||
Check if a config path resolves to a truthy value.
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
Skill formatter for generating prompts from skills.
|
||||
"""
|
||||
|
||||
from typing import List
|
||||
from typing import Dict, List
|
||||
from agent.skills.types import Skill, SkillEntry
|
||||
|
||||
|
||||
@@ -32,6 +32,7 @@ def format_skills_for_prompt(skills: List[Skill]) -> str:
|
||||
lines.append(f" <name>{_escape_xml(skill.name)}</name>")
|
||||
lines.append(f" <description>{_escape_xml(skill.description)}</description>")
|
||||
lines.append(f" <location>{_escape_xml(skill.file_path)}</location>")
|
||||
lines.append(f" <base_dir>{_escape_xml(skill.base_dir)}</base_dir>")
|
||||
lines.append(" </skill>")
|
||||
|
||||
lines.append("</available_skills>")
|
||||
@@ -50,6 +51,71 @@ def format_skill_entries_for_prompt(entries: List[SkillEntry]) -> str:
|
||||
return format_skills_for_prompt(skills)
|
||||
|
||||
|
||||
def format_unavailable_skills_for_prompt(
|
||||
entries: List[SkillEntry],
|
||||
missing_map: Dict[str, Dict[str, List[str]]],
|
||||
) -> str:
|
||||
"""
|
||||
Format unavailable (requires-not-met) skills as brief setup hints
|
||||
so the AI can guide users to configure them.
|
||||
|
||||
:param entries: List of unavailable skill entries
|
||||
:param missing_map: Dict mapping skill name to its missing requirements
|
||||
:return: Formatted prompt text
|
||||
"""
|
||||
if not entries:
|
||||
return ""
|
||||
|
||||
lines = [
|
||||
"",
|
||||
"<unavailable_skills>",
|
||||
"The following skills are installed but not yet ready. "
|
||||
"Guide the user to complete the setup when relevant.",
|
||||
]
|
||||
|
||||
for entry in entries:
|
||||
skill = entry.skill
|
||||
missing = missing_map.get(skill.name, {})
|
||||
|
||||
missing_parts = []
|
||||
for key, values in missing.items():
|
||||
missing_parts.append(f"{key}: {', '.join(values)}")
|
||||
missing_str = "; ".join(missing_parts) if missing_parts else "unknown"
|
||||
|
||||
setup_hint = _extract_setup_hint(skill)
|
||||
|
||||
lines.append(" <skill>")
|
||||
lines.append(f" <name>{_escape_xml(skill.name)}</name>")
|
||||
lines.append(f" <description>{_escape_xml(skill.description)}</description>")
|
||||
lines.append(f" <missing>{_escape_xml(missing_str)}</missing>")
|
||||
if setup_hint:
|
||||
lines.append(f" <setup>{_escape_xml(setup_hint)}</setup>")
|
||||
lines.append(" </skill>")
|
||||
|
||||
lines.append("</unavailable_skills>")
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def _extract_setup_hint(skill: Skill) -> str:
|
||||
"""
|
||||
Extract the Setup section from SKILL.md content as a brief hint.
|
||||
Returns the first few lines of the ## Setup section.
|
||||
"""
|
||||
content = skill.content
|
||||
if not content:
|
||||
return ""
|
||||
|
||||
import re
|
||||
match = re.search(r'^##\s+Setup\s*\n(.*?)(?=\n##\s|\Z)', content, re.MULTILINE | re.DOTALL)
|
||||
if not match:
|
||||
return ""
|
||||
|
||||
setup_text = match.group(1).strip()
|
||||
lines = setup_text.split('\n')
|
||||
hint_lines = [l.strip() for l in lines[:6] if l.strip()]
|
||||
return ' '.join(hint_lines)[:300]
|
||||
|
||||
|
||||
def _escape_xml(text: str) -> str:
|
||||
"""Escape XML special characters."""
|
||||
return (text
|
||||
|
||||
@@ -87,8 +87,8 @@ def parse_metadata(frontmatter: Dict[str, Any]) -> Optional[SkillMetadata]:
|
||||
if not isinstance(metadata_raw, dict):
|
||||
return None
|
||||
|
||||
# Use metadata_raw directly (COW format)
|
||||
meta_obj = metadata_raw
|
||||
# Unwrap nested namespace (e.g. {"openclaw": {...}} or {"cowagent": {...}})
|
||||
meta_obj = _unwrap_metadata_namespace(metadata_raw)
|
||||
|
||||
# Parse install specs
|
||||
install_specs = []
|
||||
@@ -128,6 +128,7 @@ def parse_metadata(frontmatter: Dict[str, Any]) -> Optional[SkillMetadata]:
|
||||
|
||||
return SkillMetadata(
|
||||
always=meta_obj.get('always', False),
|
||||
default_enabled=meta_obj.get('default_enabled', True),
|
||||
skill_key=meta_obj.get('skillKey'),
|
||||
primary_env=meta_obj.get('primaryEnv'),
|
||||
emoji=meta_obj.get('emoji'),
|
||||
@@ -138,6 +139,25 @@ def parse_metadata(frontmatter: Dict[str, Any]) -> Optional[SkillMetadata]:
|
||||
)
|
||||
|
||||
|
||||
_KNOWN_METADATA_NAMESPACES = {"cowagent", "openclaw"}
|
||||
|
||||
|
||||
def _unwrap_metadata_namespace(metadata_raw: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Unwrap a single-key namespace wrapper like {"cowagent": {...} or {"openclaw": {...}}}.
|
||||
If the top-level dict has exactly one key matching a known namespace, return the inner dict.
|
||||
Otherwise return the original dict unchanged.
|
||||
"""
|
||||
keys = set(metadata_raw.keys())
|
||||
ns_keys = keys & _KNOWN_METADATA_NAMESPACES
|
||||
if len(ns_keys) == 1 and len(keys) == 1:
|
||||
ns = ns_keys.pop()
|
||||
inner = metadata_raw[ns]
|
||||
if isinstance(inner, dict):
|
||||
return inner
|
||||
return metadata_raw
|
||||
|
||||
|
||||
def _normalize_string_list(value: Any) -> List[str]:
|
||||
"""Normalize a value to a list of strings."""
|
||||
if not value:
|
||||
|
||||
@@ -12,25 +12,20 @@ from agent.skills.frontmatter import parse_frontmatter, parse_metadata, parse_bo
|
||||
|
||||
class SkillLoader:
|
||||
"""Loads skills from various directories."""
|
||||
|
||||
def __init__(self, workspace_dir: Optional[str] = None):
|
||||
"""
|
||||
Initialize the skill loader.
|
||||
|
||||
:param workspace_dir: Agent workspace directory (for workspace-specific skills)
|
||||
"""
|
||||
self.workspace_dir = workspace_dir
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def load_skills_from_dir(self, dir_path: str, source: str) -> LoadSkillsResult:
|
||||
"""
|
||||
Load skills from a directory.
|
||||
|
||||
|
||||
Discovery rules:
|
||||
- Direct .md files in the root directory
|
||||
- Recursive SKILL.md files under subdirectories
|
||||
|
||||
|
||||
:param dir_path: Directory path to scan
|
||||
:param source: Source identifier (e.g., 'managed', 'workspace', 'bundled')
|
||||
:param source: Source identifier ('builtin' or 'custom')
|
||||
:return: LoadSkillsResult with skills and diagnostics
|
||||
"""
|
||||
skills = []
|
||||
@@ -58,6 +53,12 @@ class SkillLoader:
|
||||
"""
|
||||
Recursively load skills from a directory.
|
||||
|
||||
If a subdirectory contains its own SKILL.md, it is treated as a
|
||||
self-contained skill (or skill-collection) and its children are
|
||||
NOT scanned further. This prevents sub-skills inside a collection
|
||||
(e.g. style-collection/style-anjing) from being listed as
|
||||
independent top-level skills.
|
||||
|
||||
:param dir_path: Directory to scan
|
||||
:param source: Source identifier
|
||||
:param include_root_files: Whether to include root-level .md files
|
||||
@@ -71,38 +72,41 @@ class SkillLoader:
|
||||
except Exception as e:
|
||||
diagnostics.append(f"Failed to list directory {dir_path}: {e}")
|
||||
return LoadSkillsResult(skills=skills, diagnostics=diagnostics)
|
||||
|
||||
# If this directory has its own SKILL.md, load it and stop recursing.
|
||||
# The sub-directories are internal resources of this skill.
|
||||
if not include_root_files and 'SKILL.md' in entries:
|
||||
skill_md_path = os.path.join(dir_path, 'SKILL.md')
|
||||
if os.path.isfile(skill_md_path):
|
||||
skill_result = self._load_skill_from_file(skill_md_path, source)
|
||||
if skill_result.skills:
|
||||
skills.extend(skill_result.skills)
|
||||
diagnostics.extend(skill_result.diagnostics)
|
||||
return LoadSkillsResult(skills=skills, diagnostics=diagnostics)
|
||||
|
||||
for entry in entries:
|
||||
# Skip hidden files and directories
|
||||
if entry.startswith('.'):
|
||||
continue
|
||||
|
||||
# Skip common non-skill directories
|
||||
if entry in ('node_modules', '__pycache__', 'venv', '.git'):
|
||||
continue
|
||||
|
||||
full_path = os.path.join(dir_path, entry)
|
||||
|
||||
# Handle directories
|
||||
if os.path.isdir(full_path):
|
||||
# Recursively scan subdirectories
|
||||
sub_result = self._load_skills_recursive(full_path, source, include_root_files=False)
|
||||
skills.extend(sub_result.skills)
|
||||
diagnostics.extend(sub_result.diagnostics)
|
||||
continue
|
||||
|
||||
# Handle files
|
||||
if not os.path.isfile(full_path):
|
||||
continue
|
||||
|
||||
# Check if this is a skill file
|
||||
is_root_md = include_root_files and entry.endswith('.md')
|
||||
is_skill_md = not include_root_files and entry == 'SKILL.md'
|
||||
is_root_md = include_root_files and entry.endswith('.md') and entry.upper() != 'README.MD'
|
||||
|
||||
if not (is_root_md or is_skill_md):
|
||||
if not is_root_md:
|
||||
continue
|
||||
|
||||
# Load the skill
|
||||
skill_result = self._load_skill_from_file(full_path, source)
|
||||
if skill_result.skills:
|
||||
skills.extend(skill_result.skills)
|
||||
@@ -189,7 +193,6 @@ class SkillLoader:
|
||||
|
||||
config_path = os.path.join(skill_dir, "config.json")
|
||||
|
||||
# Without config.json, skip this skill entirely (return empty to trigger exclusion)
|
||||
if not os.path.exists(config_path):
|
||||
logger.debug(f"[SkillLoader] linkai-agent skipped: no config.json found")
|
||||
return ""
|
||||
@@ -216,61 +219,49 @@ class SkillLoader:
|
||||
|
||||
def load_all_skills(
|
||||
self,
|
||||
managed_dir: Optional[str] = None,
|
||||
workspace_skills_dir: Optional[str] = None,
|
||||
extra_dirs: Optional[List[str]] = None,
|
||||
builtin_dir: Optional[str] = None,
|
||||
custom_dir: Optional[str] = None,
|
||||
) -> Dict[str, SkillEntry]:
|
||||
"""
|
||||
Load skills from all configured locations with precedence.
|
||||
|
||||
Load skills from builtin and custom directories.
|
||||
|
||||
Precedence (lowest to highest):
|
||||
1. Extra directories
|
||||
2. Managed skills directory
|
||||
3. Workspace skills directory
|
||||
|
||||
:param managed_dir: Managed skills directory (e.g., ~/.cow/skills)
|
||||
:param workspace_skills_dir: Workspace skills directory (e.g., workspace/skills)
|
||||
:param extra_dirs: Additional directories to load skills from
|
||||
1. builtin — project root ``skills/``, shipped with the codebase
|
||||
2. custom — workspace ``skills/``, installed via cloud console or skill creator
|
||||
|
||||
Same-name custom skills override builtin ones.
|
||||
|
||||
:param builtin_dir: Built-in skills directory
|
||||
:param custom_dir: Custom skills directory
|
||||
:return: Dictionary mapping skill name to SkillEntry
|
||||
"""
|
||||
skill_map: Dict[str, SkillEntry] = {}
|
||||
all_diagnostics = []
|
||||
|
||||
# Load from extra directories (lowest precedence)
|
||||
if extra_dirs:
|
||||
for extra_dir in extra_dirs:
|
||||
if not os.path.exists(extra_dir):
|
||||
continue
|
||||
result = self.load_skills_from_dir(extra_dir, source='extra')
|
||||
all_diagnostics.extend(result.diagnostics)
|
||||
for skill in result.skills:
|
||||
entry = self._create_skill_entry(skill)
|
||||
skill_map[skill.name] = entry
|
||||
|
||||
# Load from managed directory
|
||||
if managed_dir and os.path.exists(managed_dir):
|
||||
result = self.load_skills_from_dir(managed_dir, source='managed')
|
||||
|
||||
# Load builtin skills (lower precedence)
|
||||
if builtin_dir and os.path.exists(builtin_dir):
|
||||
result = self.load_skills_from_dir(builtin_dir, source='builtin')
|
||||
all_diagnostics.extend(result.diagnostics)
|
||||
for skill in result.skills:
|
||||
entry = self._create_skill_entry(skill)
|
||||
skill_map[skill.name] = entry
|
||||
|
||||
# Load from workspace directory (highest precedence)
|
||||
if workspace_skills_dir and os.path.exists(workspace_skills_dir):
|
||||
result = self.load_skills_from_dir(workspace_skills_dir, source='workspace')
|
||||
|
||||
# Load custom skills (higher precedence, overrides builtin)
|
||||
if custom_dir and os.path.exists(custom_dir):
|
||||
result = self.load_skills_from_dir(custom_dir, source='custom')
|
||||
all_diagnostics.extend(result.diagnostics)
|
||||
for skill in result.skills:
|
||||
entry = self._create_skill_entry(skill)
|
||||
skill_map[skill.name] = entry
|
||||
|
||||
|
||||
# Log diagnostics
|
||||
if all_diagnostics:
|
||||
logger.debug(f"Skill loading diagnostics: {len(all_diagnostics)} issues")
|
||||
for diag in all_diagnostics[:5]: # Log first 5
|
||||
for diag in all_diagnostics[:5]:
|
||||
logger.debug(f" - {diag}")
|
||||
|
||||
logger.debug(f"Loaded {len(skill_map)} skills from all sources")
|
||||
|
||||
|
||||
logger.debug(f"Loaded {len(skill_map)} skills total")
|
||||
|
||||
return skill_map
|
||||
|
||||
def _create_skill_entry(self, skill: Skill) -> SkillEntry:
|
||||
|
||||
@@ -3,6 +3,7 @@ Skill manager for managing skill lifecycle and operations.
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
from typing import Dict, List, Optional
|
||||
from pathlib import Path
|
||||
from common.log import logger
|
||||
@@ -10,56 +11,143 @@ from agent.skills.types import Skill, SkillEntry, SkillSnapshot
|
||||
from agent.skills.loader import SkillLoader
|
||||
from agent.skills.formatter import format_skill_entries_for_prompt
|
||||
|
||||
SKILLS_CONFIG_FILE = "skills_config.json"
|
||||
|
||||
|
||||
class SkillManager:
|
||||
"""Manages skills for an agent."""
|
||||
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
workspace_dir: Optional[str] = None,
|
||||
managed_skills_dir: Optional[str] = None,
|
||||
extra_dirs: Optional[List[str]] = None,
|
||||
builtin_dir: Optional[str] = None,
|
||||
custom_dir: Optional[str] = None,
|
||||
config: Optional[Dict] = None,
|
||||
):
|
||||
"""
|
||||
Initialize the skill manager.
|
||||
|
||||
:param workspace_dir: Agent workspace directory
|
||||
:param managed_skills_dir: Managed skills directory (e.g., ~/.cow/skills)
|
||||
:param extra_dirs: Additional skill directories
|
||||
|
||||
:param builtin_dir: Built-in skills directory (project root ``skills/``)
|
||||
:param custom_dir: Custom skills directory (workspace ``skills/``)
|
||||
:param config: Configuration dictionary
|
||||
"""
|
||||
self.workspace_dir = workspace_dir
|
||||
self.managed_skills_dir = managed_skills_dir or self._get_default_managed_dir()
|
||||
self.extra_dirs = extra_dirs or []
|
||||
project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
self.builtin_dir = builtin_dir or os.path.join(project_root, 'skills')
|
||||
self.custom_dir = custom_dir or os.path.join(project_root, 'workspace', 'skills')
|
||||
self.config = config or {}
|
||||
|
||||
self.loader = SkillLoader(workspace_dir=workspace_dir)
|
||||
self._skills_config_path = os.path.join(self.custom_dir, SKILLS_CONFIG_FILE)
|
||||
|
||||
# skills_config: full skill metadata keyed by name
|
||||
# { "web-fetch": {"name": ..., "description": ..., "source": ..., "enabled": true}, ... }
|
||||
self.skills_config: Dict[str, dict] = {}
|
||||
|
||||
self.loader = SkillLoader()
|
||||
self.skills: Dict[str, SkillEntry] = {}
|
||||
|
||||
|
||||
# Load skills on initialization
|
||||
self.refresh_skills()
|
||||
|
||||
def _get_default_managed_dir(self) -> str:
|
||||
"""Get the default managed skills directory."""
|
||||
# Use project root skills directory as default
|
||||
import os
|
||||
project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
return os.path.join(project_root, 'skills')
|
||||
|
||||
|
||||
def refresh_skills(self):
|
||||
"""Reload all skills from configured directories."""
|
||||
workspace_skills_dir = None
|
||||
if self.workspace_dir:
|
||||
workspace_skills_dir = os.path.join(self.workspace_dir, 'skills')
|
||||
|
||||
"""Reload all skills from builtin and custom directories, then sync config."""
|
||||
self.skills = self.loader.load_all_skills(
|
||||
managed_dir=self.managed_skills_dir,
|
||||
workspace_skills_dir=workspace_skills_dir,
|
||||
extra_dirs=self.extra_dirs,
|
||||
builtin_dir=self.builtin_dir,
|
||||
custom_dir=self.custom_dir,
|
||||
)
|
||||
|
||||
self._sync_skills_config()
|
||||
logger.debug(f"SkillManager: Loaded {len(self.skills)} skills")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# skills_config.json management
|
||||
# ------------------------------------------------------------------
|
||||
def _load_skills_config(self) -> Dict[str, dict]:
|
||||
"""Load skills_config.json from custom_dir. Returns empty dict if not found."""
|
||||
if not os.path.exists(self._skills_config_path):
|
||||
return {}
|
||||
try:
|
||||
with open(self._skills_config_path, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
if isinstance(data, dict):
|
||||
return data
|
||||
except Exception as e:
|
||||
logger.warning(f"[SkillManager] Failed to load {SKILLS_CONFIG_FILE}: {e}")
|
||||
return {}
|
||||
|
||||
def _save_skills_config(self):
|
||||
"""Persist skills_config to custom_dir/skills_config.json."""
|
||||
os.makedirs(self.custom_dir, exist_ok=True)
|
||||
try:
|
||||
with open(self._skills_config_path, "w", encoding="utf-8") as f:
|
||||
json.dump(self.skills_config, f, indent=4, ensure_ascii=False)
|
||||
except Exception as e:
|
||||
logger.error(f"[SkillManager] Failed to save {SKILLS_CONFIG_FILE}: {e}")
|
||||
|
||||
def _sync_skills_config(self):
|
||||
"""
|
||||
Merge directory-scanned skills with the persisted config file.
|
||||
|
||||
- New skills: use metadata.default_enabled as initial enabled state.
|
||||
- Existing skills: preserve their persisted enabled state.
|
||||
- Skills that no longer exist on disk are removed.
|
||||
- name/description/source are always refreshed from the latest scan.
|
||||
"""
|
||||
saved = self._load_skills_config()
|
||||
merged: Dict[str, dict] = {}
|
||||
|
||||
for name, entry in self.skills.items():
|
||||
skill = entry.skill
|
||||
prev = saved.get(name, {})
|
||||
category = prev.get("category", "skill")
|
||||
|
||||
if name in saved:
|
||||
enabled = prev.get("enabled", True)
|
||||
else:
|
||||
enabled = entry.metadata.default_enabled if entry.metadata else True
|
||||
|
||||
entry_dict = {
|
||||
"name": name,
|
||||
"description": skill.description,
|
||||
"source": prev.get("source") or skill.source,
|
||||
"enabled": enabled,
|
||||
"category": category,
|
||||
}
|
||||
display_name = prev.get("display_name")
|
||||
if display_name:
|
||||
entry_dict["display_name"] = display_name
|
||||
merged[name] = entry_dict
|
||||
|
||||
self.skills_config = merged
|
||||
self._save_skills_config()
|
||||
|
||||
def is_skill_enabled(self, name: str) -> bool:
|
||||
"""
|
||||
Check if a skill is enabled according to skills_config.
|
||||
|
||||
:param name: skill name
|
||||
:return: True if enabled (default True if not in config)
|
||||
"""
|
||||
entry = self.skills_config.get(name)
|
||||
if entry is None:
|
||||
return True
|
||||
return entry.get("enabled", True)
|
||||
|
||||
def set_skill_enabled(self, name: str, enabled: bool):
|
||||
"""
|
||||
Set a skill's enabled state and persist.
|
||||
|
||||
:param name: skill name
|
||||
:param enabled: True to enable, False to disable
|
||||
"""
|
||||
if name not in self.skills_config:
|
||||
raise ValueError(f"skill '{name}' not found in config")
|
||||
self.skills_config[name]["enabled"] = enabled
|
||||
self._save_skills_config()
|
||||
|
||||
def get_skills_config(self) -> Dict[str, dict]:
|
||||
"""
|
||||
Return the full skills_config dict (for query API).
|
||||
|
||||
:return: copy of skills_config
|
||||
"""
|
||||
return dict(self.skills_config)
|
||||
|
||||
def get_skill(self, name: str) -> Optional[SkillEntry]:
|
||||
"""
|
||||
@@ -78,72 +166,118 @@ class SkillManager:
|
||||
"""
|
||||
return list(self.skills.values())
|
||||
|
||||
@staticmethod
|
||||
def _normalize_skill_filter(skill_filter: Optional[List[str]]) -> Optional[List[str]]:
|
||||
"""Normalize a skill_filter list into a flat list of stripped names."""
|
||||
if skill_filter is None:
|
||||
return None
|
||||
normalized = []
|
||||
for item in skill_filter:
|
||||
if isinstance(item, str):
|
||||
name = item.strip()
|
||||
if name:
|
||||
normalized.append(name)
|
||||
elif isinstance(item, list):
|
||||
for subitem in item:
|
||||
if isinstance(subitem, str):
|
||||
name = subitem.strip()
|
||||
if name:
|
||||
normalized.append(name)
|
||||
return normalized or None
|
||||
|
||||
def filter_skills(
|
||||
self,
|
||||
skill_filter: Optional[List[str]] = None,
|
||||
include_disabled: bool = False,
|
||||
) -> List[SkillEntry]:
|
||||
"""
|
||||
Filter skills based on criteria.
|
||||
|
||||
Simple rule: Skills are auto-enabled if requirements are met.
|
||||
- Has required API keys → included
|
||||
- Missing API keys → excluded
|
||||
|
||||
Filter skills that are eligible (enabled + requirements met).
|
||||
|
||||
:param skill_filter: List of skill names to include (None = all)
|
||||
:param include_disabled: Whether to include skills with disable_model_invocation=True
|
||||
:return: Filtered list of skill entries
|
||||
:param include_disabled: Whether to include disabled skills
|
||||
:return: Filtered list of eligible skill entries
|
||||
"""
|
||||
from agent.skills.config import should_include_skill
|
||||
|
||||
|
||||
entries = list(self.skills.values())
|
||||
|
||||
# Check requirements (platform, binaries, env vars)
|
||||
|
||||
entries = [e for e in entries if should_include_skill(e, self.config)]
|
||||
|
||||
# Apply skill filter
|
||||
if skill_filter is not None:
|
||||
# Flatten and normalize skill names (handle both strings and nested lists)
|
||||
normalized = []
|
||||
for item in skill_filter:
|
||||
if isinstance(item, str):
|
||||
name = item.strip()
|
||||
if name:
|
||||
normalized.append(name)
|
||||
elif isinstance(item, list):
|
||||
# Handle nested lists
|
||||
for subitem in item:
|
||||
if isinstance(subitem, str):
|
||||
name = subitem.strip()
|
||||
if name:
|
||||
normalized.append(name)
|
||||
|
||||
if normalized:
|
||||
entries = [e for e in entries if e.skill.name in normalized]
|
||||
|
||||
# Filter out disabled skills unless explicitly requested
|
||||
|
||||
normalized = self._normalize_skill_filter(skill_filter)
|
||||
if normalized is not None:
|
||||
entries = [e for e in entries if e.skill.name in normalized]
|
||||
|
||||
if not include_disabled:
|
||||
entries = [e for e in entries if not e.skill.disable_model_invocation]
|
||||
|
||||
entries = [e for e in entries if self.is_skill_enabled(e.skill.name)]
|
||||
|
||||
from config import conf
|
||||
if not conf().get("knowledge", True):
|
||||
entries = [e for e in entries if e.skill.name != "knowledge-wiki"]
|
||||
|
||||
return entries
|
||||
|
||||
|
||||
def filter_unavailable_skills(
|
||||
self,
|
||||
skill_filter: Optional[List[str]] = None,
|
||||
) -> tuple:
|
||||
"""
|
||||
Find skills that are enabled but have unmet requirements.
|
||||
|
||||
:param skill_filter: Optional list of skill names to include
|
||||
:return: Tuple of (entries, missing_map) where missing_map maps
|
||||
skill name to its missing requirements dict
|
||||
"""
|
||||
from agent.skills.config import should_include_skill, get_missing_requirements
|
||||
|
||||
entries = list(self.skills.values())
|
||||
|
||||
# Only enabled skills
|
||||
entries = [e for e in entries if self.is_skill_enabled(e.skill.name)]
|
||||
|
||||
normalized = self._normalize_skill_filter(skill_filter)
|
||||
if normalized is not None:
|
||||
entries = [e for e in entries if e.skill.name in normalized]
|
||||
|
||||
# Keep only those that fail should_include_skill (requirements not met)
|
||||
unavailable = []
|
||||
missing_map: Dict[str, dict] = {}
|
||||
for e in entries:
|
||||
if not should_include_skill(e, self.config):
|
||||
missing = get_missing_requirements(e)
|
||||
if missing:
|
||||
unavailable.append(e)
|
||||
missing_map[e.skill.name] = missing
|
||||
|
||||
return unavailable, missing_map
|
||||
|
||||
def build_skills_prompt(
|
||||
self,
|
||||
skill_filter: Optional[List[str]] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Build a formatted prompt containing available skills.
|
||||
|
||||
Build a formatted prompt containing available skills
|
||||
and brief hints for unavailable ones.
|
||||
|
||||
:param skill_filter: Optional list of skill names to include
|
||||
:return: Formatted skills prompt
|
||||
"""
|
||||
from common.log import logger
|
||||
entries = self.filter_skills(skill_filter=skill_filter, include_disabled=False)
|
||||
logger.debug(f"[SkillManager] Filtered {len(entries)} skills for prompt (total: {len(self.skills)})")
|
||||
if entries:
|
||||
skill_names = [e.skill.name for e in entries]
|
||||
logger.debug(f"[SkillManager] Skills to include: {skill_names}")
|
||||
result = format_skill_entries_for_prompt(entries)
|
||||
from agent.skills.formatter import format_unavailable_skills_for_prompt
|
||||
|
||||
eligible = self.filter_skills(skill_filter=skill_filter, include_disabled=False)
|
||||
logger.debug(f"[SkillManager] Eligible: {len(eligible)} skills (total: {len(self.skills)})")
|
||||
if eligible:
|
||||
skill_names = [e.skill.name for e in eligible]
|
||||
logger.debug(f"[SkillManager] Eligible skills: {skill_names}")
|
||||
|
||||
result = format_skill_entries_for_prompt(eligible)
|
||||
|
||||
unavailable, missing_map = self.filter_unavailable_skills(skill_filter=skill_filter)
|
||||
if unavailable:
|
||||
unavailable_names = [e.skill.name for e in unavailable]
|
||||
logger.debug(f"[SkillManager] Unavailable skills (setup needed): {unavailable_names}")
|
||||
result += format_unavailable_skills_for_prompt(unavailable, missing_map)
|
||||
|
||||
logger.debug(f"[SkillManager] Generated prompt length: {len(result)}")
|
||||
return result
|
||||
|
||||
|
||||
285
agent/skills/service.py
Normal file
285
agent/skills/service.py
Normal file
@@ -0,0 +1,285 @@
|
||||
"""
|
||||
Skill service for handling skill CRUD operations.
|
||||
|
||||
This service provides a unified interface for managing skills, which can be
|
||||
called from the cloud control client (LinkAI), the local web console, or any
|
||||
other management entry point.
|
||||
"""
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import zipfile
|
||||
import tempfile
|
||||
from typing import Dict, List, Optional
|
||||
from common.log import logger
|
||||
from agent.skills.types import Skill, SkillEntry
|
||||
from agent.skills.manager import SkillManager
|
||||
|
||||
try:
|
||||
import requests
|
||||
except ImportError:
|
||||
requests = None
|
||||
|
||||
|
||||
class SkillService:
|
||||
"""
|
||||
High-level service for skill lifecycle management.
|
||||
Wraps SkillManager and provides network-aware operations such as
|
||||
downloading skill files from remote URLs.
|
||||
"""
|
||||
|
||||
def __init__(self, skill_manager: SkillManager):
|
||||
"""
|
||||
:param skill_manager: The SkillManager instance to operate on
|
||||
"""
|
||||
self.manager = skill_manager
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# query
|
||||
# ------------------------------------------------------------------
|
||||
def query(self) -> List[dict]:
|
||||
"""
|
||||
Query all skills and return a serialisable list.
|
||||
Reads from skills_config.json (refreshes from disk if needed).
|
||||
|
||||
:return: list of skill info dicts
|
||||
"""
|
||||
self.manager.refresh_skills()
|
||||
config = self.manager.get_skills_config()
|
||||
result = list(config.values())
|
||||
logger.info(f"[SkillService] query: {len(result)} skills found")
|
||||
return result
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# add / install
|
||||
# ------------------------------------------------------------------
|
||||
def add(self, payload: dict) -> None:
|
||||
"""
|
||||
Add (install) a skill from a remote payload.
|
||||
|
||||
Supported payload types:
|
||||
|
||||
1. ``type: "url"`` – download individual files::
|
||||
|
||||
{
|
||||
"name": "web_search",
|
||||
"type": "url",
|
||||
"enabled": true,
|
||||
"files": [
|
||||
{"url": "https://...", "path": "README.md"},
|
||||
{"url": "https://...", "path": "scripts/main.py"}
|
||||
]
|
||||
}
|
||||
|
||||
2. ``type: "package"`` – download a zip archive and extract::
|
||||
|
||||
{
|
||||
"name": "plugin-custom-tool",
|
||||
"type": "package",
|
||||
"category": "skills",
|
||||
"enabled": true,
|
||||
"files": [{"url": "https://cdn.example.com/skills/custom-tool.zip"}]
|
||||
}
|
||||
|
||||
:param payload: skill add payload from server
|
||||
"""
|
||||
name = payload.get("name")
|
||||
if not name:
|
||||
raise ValueError("skill name is required")
|
||||
|
||||
payload_type = payload.get("type", "url")
|
||||
|
||||
if payload_type == "package":
|
||||
self._add_package(name, payload)
|
||||
else:
|
||||
self._add_url(name, payload)
|
||||
|
||||
self.manager.refresh_skills()
|
||||
|
||||
category = payload.get("category")
|
||||
if category and name in self.manager.skills_config:
|
||||
self.manager.skills_config[name]["category"] = category
|
||||
self.manager._save_skills_config()
|
||||
|
||||
def _add_url(self, name: str, payload: dict) -> None:
|
||||
"""Install a skill by downloading individual files."""
|
||||
files = payload.get("files", [])
|
||||
if not files:
|
||||
raise ValueError("skill files list is empty")
|
||||
|
||||
skill_dir = os.path.join(self.manager.custom_dir, name)
|
||||
|
||||
tmp_dir = skill_dir + ".tmp"
|
||||
if os.path.exists(tmp_dir):
|
||||
shutil.rmtree(tmp_dir)
|
||||
os.makedirs(tmp_dir, exist_ok=True)
|
||||
|
||||
try:
|
||||
for file_info in files:
|
||||
url = file_info.get("url")
|
||||
rel_path = file_info.get("path")
|
||||
if not url or not rel_path:
|
||||
logger.warning(f"[SkillService] add: skip invalid file entry {file_info}")
|
||||
continue
|
||||
dest = os.path.join(tmp_dir, rel_path)
|
||||
self._download_file(url, dest)
|
||||
except Exception:
|
||||
shutil.rmtree(tmp_dir, ignore_errors=True)
|
||||
raise
|
||||
|
||||
if os.path.exists(skill_dir):
|
||||
shutil.rmtree(skill_dir)
|
||||
os.rename(tmp_dir, skill_dir)
|
||||
|
||||
logger.info(f"[SkillService] add: skill '{name}' installed via url ({len(files)} files)")
|
||||
|
||||
def _add_package(self, name: str, payload: dict) -> None:
|
||||
"""
|
||||
Install a skill by downloading a zip archive and extracting it.
|
||||
|
||||
If the archive contains a single top-level directory, that directory
|
||||
is used as the skill folder directly; otherwise a new directory named
|
||||
after the skill is created to hold the extracted contents.
|
||||
"""
|
||||
files = payload.get("files", [])
|
||||
if not files or not files[0].get("url"):
|
||||
raise ValueError("package url is required")
|
||||
|
||||
url = files[0]["url"]
|
||||
skill_dir = os.path.join(self.manager.custom_dir, name)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
zip_path = os.path.join(tmp_dir, "package.zip")
|
||||
self._download_file(url, zip_path)
|
||||
|
||||
if not zipfile.is_zipfile(zip_path):
|
||||
raise ValueError(f"downloaded file is not a valid zip archive: {url}")
|
||||
|
||||
extract_dir = os.path.join(tmp_dir, "extracted")
|
||||
with zipfile.ZipFile(zip_path, "r") as zf:
|
||||
zf.extractall(extract_dir)
|
||||
|
||||
# Determine the actual content root.
|
||||
# If the zip has a single top-level directory, use its contents
|
||||
# so the skill folder is clean (no extra nesting).
|
||||
top_items = [
|
||||
item for item in os.listdir(extract_dir)
|
||||
if not item.startswith(".")
|
||||
]
|
||||
if len(top_items) == 1:
|
||||
single = os.path.join(extract_dir, top_items[0])
|
||||
if os.path.isdir(single):
|
||||
extract_dir = single
|
||||
|
||||
if os.path.exists(skill_dir):
|
||||
shutil.rmtree(skill_dir)
|
||||
shutil.copytree(extract_dir, skill_dir)
|
||||
|
||||
logger.info(f"[SkillService] add: skill '{name}' installed via package ({url})")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# open / close (enable / disable)
|
||||
# ------------------------------------------------------------------
|
||||
def open(self, payload: dict) -> None:
|
||||
"""
|
||||
Enable a skill by name.
|
||||
|
||||
:param payload: {"name": "skill_name"}
|
||||
"""
|
||||
name = payload.get("name")
|
||||
if not name:
|
||||
raise ValueError("skill name is required")
|
||||
self.manager.set_skill_enabled(name, enabled=True)
|
||||
logger.info(f"[SkillService] open: skill '{name}' enabled")
|
||||
|
||||
def close(self, payload: dict) -> None:
|
||||
"""
|
||||
Disable a skill by name.
|
||||
|
||||
:param payload: {"name": "skill_name"}
|
||||
"""
|
||||
name = payload.get("name")
|
||||
if not name:
|
||||
raise ValueError("skill name is required")
|
||||
self.manager.set_skill_enabled(name, enabled=False)
|
||||
logger.info(f"[SkillService] close: skill '{name}' disabled")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# delete
|
||||
# ------------------------------------------------------------------
|
||||
def delete(self, payload: dict) -> None:
|
||||
"""
|
||||
Delete a skill by removing its directory entirely.
|
||||
|
||||
:param payload: {"name": "skill_name"}
|
||||
"""
|
||||
name = payload.get("name")
|
||||
if not name:
|
||||
raise ValueError("skill name is required")
|
||||
|
||||
skill_dir = os.path.join(self.manager.custom_dir, name)
|
||||
if os.path.exists(skill_dir):
|
||||
shutil.rmtree(skill_dir)
|
||||
logger.info(f"[SkillService] delete: removed directory {skill_dir}")
|
||||
else:
|
||||
logger.warning(f"[SkillService] delete: skill directory not found: {skill_dir}")
|
||||
|
||||
# Refresh will remove the deleted skill from config automatically
|
||||
self.manager.refresh_skills()
|
||||
logger.info(f"[SkillService] delete: skill '{name}' deleted")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# dispatch - single entry point for protocol messages
|
||||
# ------------------------------------------------------------------
|
||||
def dispatch(self, action: str, payload: Optional[dict] = None) -> dict:
|
||||
"""
|
||||
Dispatch a skill management action and return a protocol-compatible
|
||||
response dict.
|
||||
|
||||
:param action: one of query / add / open / close / delete
|
||||
:param payload: action-specific payload (may be None for query)
|
||||
:return: dict with action, code, message, payload
|
||||
"""
|
||||
payload = payload or {}
|
||||
try:
|
||||
if action == "query":
|
||||
result_payload = self.query()
|
||||
return {"action": action, "code": 200, "message": "success", "payload": result_payload}
|
||||
elif action == "add":
|
||||
self.add(payload)
|
||||
elif action == "open":
|
||||
self.open(payload)
|
||||
elif action == "close":
|
||||
self.close(payload)
|
||||
elif action == "delete":
|
||||
self.delete(payload)
|
||||
else:
|
||||
return {"action": action, "code": 400, "message": f"unknown action: {action}", "payload": None}
|
||||
return {"action": action, "code": 200, "message": "success", "payload": None}
|
||||
except Exception as e:
|
||||
logger.error(f"[SkillService] dispatch error: action={action}, error={e}")
|
||||
return {"action": action, "code": 500, "message": str(e), "payload": None}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# internal helpers
|
||||
# ------------------------------------------------------------------
|
||||
@staticmethod
|
||||
def _download_file(url: str, dest: str):
|
||||
"""
|
||||
Download a file from *url* and save to *dest*.
|
||||
|
||||
:param url: remote file URL
|
||||
:param dest: local destination path
|
||||
"""
|
||||
if requests is None:
|
||||
raise RuntimeError("requests library is required for downloading skill files")
|
||||
|
||||
dest_dir = os.path.dirname(dest)
|
||||
if dest_dir:
|
||||
os.makedirs(dest_dir, exist_ok=True)
|
||||
|
||||
resp = requests.get(url, timeout=60)
|
||||
resp.raise_for_status()
|
||||
with open(dest, "wb") as f:
|
||||
f.write(resp.content)
|
||||
logger.debug(f"[SkillService] downloaded {url} -> {dest}")
|
||||
@@ -29,6 +29,7 @@ class SkillInstallSpec:
|
||||
class SkillMetadata:
|
||||
"""Metadata for a skill from frontmatter."""
|
||||
always: bool = False # Always include this skill
|
||||
default_enabled: bool = True # Initial enabled state when first discovered
|
||||
skill_key: Optional[str] = None # Override skill key
|
||||
primary_env: Optional[str] = None # Primary environment variable
|
||||
emoji: Optional[str] = None
|
||||
@@ -45,7 +46,7 @@ class Skill:
|
||||
description: str
|
||||
file_path: str
|
||||
base_dir: str
|
||||
source: str # managed, workspace, bundled, etc.
|
||||
source: str # builtin or custom
|
||||
content: str # Full markdown content
|
||||
disable_model_invocation: bool = False
|
||||
frontmatter: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
@@ -55,6 +55,24 @@ def _import_optional_tools():
|
||||
except Exception as e:
|
||||
logger.error(f"[Tools] WebSearch failed to load: {e}")
|
||||
|
||||
# WebFetch Tool
|
||||
try:
|
||||
from agent.tools.web_fetch.web_fetch import WebFetch
|
||||
tools['WebFetch'] = WebFetch
|
||||
except ImportError as e:
|
||||
logger.error(f"[Tools] WebFetch not loaded - missing dependency: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"[Tools] WebFetch failed to load: {e}")
|
||||
|
||||
# Vision Tool (conditionally loaded based on API key availability)
|
||||
try:
|
||||
from agent.tools.vision.vision import Vision
|
||||
tools['Vision'] = Vision
|
||||
except ImportError as e:
|
||||
logger.error(f"[Tools] Vision not loaded - missing dependency: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"[Tools] Vision failed to load: {e}")
|
||||
|
||||
return tools
|
||||
|
||||
# Load optional tools
|
||||
@@ -62,30 +80,32 @@ _optional_tools = _import_optional_tools()
|
||||
EnvConfig = _optional_tools.get('EnvConfig')
|
||||
SchedulerTool = _optional_tools.get('SchedulerTool')
|
||||
WebSearch = _optional_tools.get('WebSearch')
|
||||
WebFetch = _optional_tools.get('WebFetch')
|
||||
Vision = _optional_tools.get('Vision')
|
||||
GoogleSearch = _optional_tools.get('GoogleSearch')
|
||||
FileSave = _optional_tools.get('FileSave')
|
||||
Terminal = _optional_tools.get('Terminal')
|
||||
|
||||
|
||||
# Delayed import for BrowserTool
|
||||
# BrowserTool (requires playwright)
|
||||
def _import_browser_tool():
|
||||
from common.log import logger
|
||||
try:
|
||||
from agent.tools.browser.browser_tool import BrowserTool
|
||||
return BrowserTool
|
||||
except ImportError:
|
||||
# Return a placeholder class that will prompt the user to install dependencies when instantiated
|
||||
class BrowserToolPlaceholder:
|
||||
def __init__(self, *args, **kwargs):
|
||||
raise ImportError(
|
||||
"The 'browser-use' package is required to use BrowserTool. "
|
||||
"Please install it with 'pip install browser-use>=0.1.40'."
|
||||
)
|
||||
except ImportError as e:
|
||||
logger.info(
|
||||
f"[Tools] BrowserTool not loaded - missing dependency: {e}\n"
|
||||
f" To enable browser tool, run:\n"
|
||||
f" pip install playwright\n"
|
||||
f" playwright install chromium"
|
||||
)
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"[Tools] BrowserTool failed to load: {e}")
|
||||
return None
|
||||
|
||||
return BrowserToolPlaceholder
|
||||
|
||||
|
||||
# Dynamically set BrowserTool
|
||||
# BrowserTool = _import_browser_tool()
|
||||
BrowserTool = _import_browser_tool()
|
||||
|
||||
# Export all tools (including optional ones that might be None)
|
||||
__all__ = [
|
||||
@@ -102,8 +122,9 @@ __all__ = [
|
||||
'EnvConfig',
|
||||
'SchedulerTool',
|
||||
'WebSearch',
|
||||
# Optional tools (may be None if dependencies not available)
|
||||
# 'BrowserTool'
|
||||
'WebFetch',
|
||||
'Vision',
|
||||
'BrowserTool',
|
||||
]
|
||||
|
||||
"""
|
||||
|
||||
@@ -3,6 +3,7 @@ Bash tool - Execute bash commands
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import subprocess
|
||||
import tempfile
|
||||
@@ -17,9 +18,13 @@ from common.utils import expand_path
|
||||
class Bash(BaseTool):
|
||||
"""Tool for executing bash commands"""
|
||||
|
||||
_IS_WIN = sys.platform == "win32"
|
||||
|
||||
name: str = "bash"
|
||||
description: str = f"""Execute a bash command in the current working directory. Returns stdout and stderr. Output is truncated to last {DEFAULT_MAX_LINES} lines or {DEFAULT_MAX_BYTES // 1024}KB (whichever is hit first). If truncated, full output is saved to a temp file.
|
||||
|
||||
{'''
|
||||
PLATFORM: Windows (cmd.exe). Do NOT use Unix-only commands like grep, head, tail, sed, awk.
|
||||
''' if _IS_WIN else ''}
|
||||
ENVIRONMENT: All API keys from env_config are auto-injected. Use $VAR_NAME directly.
|
||||
|
||||
SAFETY:
|
||||
@@ -83,12 +88,13 @@ SAFETY:
|
||||
|
||||
# Load environment variables from ~/.cow/.env if it exists
|
||||
env_file = expand_path("~/.cow/.env")
|
||||
dotenv_vars = {}
|
||||
if os.path.exists(env_file):
|
||||
try:
|
||||
from dotenv import dotenv_values
|
||||
env_vars = dotenv_values(env_file)
|
||||
env.update(env_vars)
|
||||
logger.debug(f"[Bash] Loaded {len(env_vars)} variables from {env_file}")
|
||||
dotenv_vars = dotenv_values(env_file)
|
||||
env.update(dotenv_vars)
|
||||
logger.debug(f"[Bash] Loaded {len(dotenv_vars)} variables from {env_file}")
|
||||
except ImportError:
|
||||
logger.debug("[Bash] python-dotenv not installed, skipping .env loading")
|
||||
except Exception as e:
|
||||
@@ -100,7 +106,13 @@ SAFETY:
|
||||
else:
|
||||
logger.debug(f"[Bash] Process User: {os.environ.get('USERNAME', os.environ.get('USER', 'unknown'))}")
|
||||
|
||||
# Execute command with inherited environment variables
|
||||
# On Windows, convert $VAR references to %VAR% for cmd.exe
|
||||
if self._IS_WIN:
|
||||
env["PYTHONIOENCODING"] = "utf-8"
|
||||
command = self._convert_env_vars_for_windows(command, dotenv_vars)
|
||||
if command and not command.strip().lower().startswith("chcp"):
|
||||
command = f"chcp 65001 >nul 2>&1 && {command}"
|
||||
|
||||
result = subprocess.run(
|
||||
command,
|
||||
shell=True,
|
||||
@@ -108,8 +120,10 @@ SAFETY:
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
text=True,
|
||||
encoding="utf-8",
|
||||
errors="replace",
|
||||
timeout=timeout,
|
||||
env=env
|
||||
env=env,
|
||||
)
|
||||
|
||||
logger.debug(f"[Bash] Exit code: {result.returncode}")
|
||||
@@ -131,6 +145,8 @@ SAFETY:
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
text=True,
|
||||
encoding="utf-8",
|
||||
errors="replace",
|
||||
timeout=timeout,
|
||||
env=env
|
||||
)
|
||||
@@ -153,10 +169,16 @@ SAFETY:
|
||||
except Exception as retry_err:
|
||||
logger.warning(f"[Bash] Retry failed: {retry_err}")
|
||||
|
||||
# Combine stdout and stderr
|
||||
output = result.stdout
|
||||
if result.stderr:
|
||||
output += "\n" + result.stderr
|
||||
# When command succeeds with stdout, keep output clean (stderr goes to server log only).
|
||||
# When command fails or stdout is empty, include stderr so the agent can diagnose.
|
||||
if result.returncode == 0 and result.stdout.strip():
|
||||
output = result.stdout
|
||||
if result.stderr:
|
||||
logger.info(f"[Bash] stderr (not forwarded): {result.stderr[:500]}")
|
||||
else:
|
||||
output = result.stdout
|
||||
if result.stderr:
|
||||
output += "\n" + result.stderr
|
||||
|
||||
# Check if we need to save full output to temp file
|
||||
temp_file_path = None
|
||||
@@ -258,3 +280,21 @@ SAFETY:
|
||||
return "This command will recursively delete system directories"
|
||||
|
||||
return "" # No warning needed
|
||||
|
||||
@staticmethod
|
||||
def _convert_env_vars_for_windows(command: str, dotenv_vars: dict) -> str:
|
||||
"""
|
||||
Convert bash-style $VAR / ${VAR} references to cmd.exe %VAR% syntax.
|
||||
Only converts variables loaded from .env (user-configured API keys etc.)
|
||||
to avoid breaking $PATH, jq expressions, regex, etc.
|
||||
"""
|
||||
if not dotenv_vars:
|
||||
return command
|
||||
|
||||
def replace_match(m):
|
||||
var_name = m.group(1) or m.group(2)
|
||||
if var_name in dotenv_vars:
|
||||
return f"%{var_name}%"
|
||||
return m.group(0)
|
||||
|
||||
return re.sub(r'\$\{(\w+)\}|\$(\w+)', replace_match, command)
|
||||
|
||||
3
agent/tools/browser/__init__.py
Normal file
3
agent/tools/browser/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from agent.tools.browser.browser_tool import BrowserTool
|
||||
|
||||
__all__ = ["BrowserTool"]
|
||||
780
agent/tools/browser/browser_service.py
Normal file
780
agent/tools/browser/browser_service.py
Normal file
@@ -0,0 +1,780 @@
|
||||
"""
|
||||
Browser service - Playwright wrapper managing browser lifecycle and page operations.
|
||||
|
||||
All Playwright calls run on a dedicated background thread so that callers from
|
||||
any worker thread can safely use the service. An idle-timeout mechanism
|
||||
automatically shuts down the browser (and its thread) after a configurable
|
||||
period of inactivity to free resources.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import uuid
|
||||
import queue
|
||||
import threading
|
||||
from typing import Optional, Dict, Any, List, Callable
|
||||
|
||||
from common.log import logger
|
||||
|
||||
try:
|
||||
from playwright.sync_api import sync_playwright, Browser, BrowserContext, Page, Playwright
|
||||
_HAS_PLAYWRIGHT = True
|
||||
except ImportError:
|
||||
_HAS_PLAYWRIGHT = False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Snapshot DOM helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Tags that typically carry useful content for an agent
|
||||
_INTERACTIVE_TAGS = {
|
||||
"a", "button", "input", "textarea", "select", "option",
|
||||
"label", "details", "summary",
|
||||
}
|
||||
_SEMANTIC_TAGS = {
|
||||
"h1", "h2", "h3", "h4", "h5", "h6",
|
||||
"p", "li", "td", "th", "caption", "figcaption", "blockquote", "pre", "code",
|
||||
"nav", "main", "article", "section", "header", "footer", "form", "table",
|
||||
"img", "video", "audio",
|
||||
}
|
||||
_KEEP_TAGS = _INTERACTIVE_TAGS | _SEMANTIC_TAGS
|
||||
|
||||
_SNAPSHOT_JS = """
|
||||
() => {
|
||||
const KEEP = new Set(%s);
|
||||
const INTERACTIVE = new Set(%s);
|
||||
const SKIP = new Set(["script","style","noscript","svg","path","meta","link","br","hr"]);
|
||||
const CLICKABLE_ROLES = new Set([
|
||||
"button","link","tab","menuitem","menuitemcheckbox","menuitemradio",
|
||||
"option","switch","checkbox","radio","combobox","searchbox","slider",
|
||||
"spinbutton","textbox","treeitem"
|
||||
]);
|
||||
let refCounter = 0;
|
||||
const refMap = {};
|
||||
|
||||
function visible(el) {
|
||||
if (!(el instanceof HTMLElement)) return true;
|
||||
const st = window.getComputedStyle(el);
|
||||
if (st.display === "none" || st.visibility === "hidden") return false;
|
||||
if (parseFloat(st.opacity) === 0) return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
// Strong signals: these attributes alone are enough to mark as interactive
|
||||
function hasStrongInteractiveSignal(el) {
|
||||
const role = el.getAttribute("role");
|
||||
if (role && CLICKABLE_ROLES.has(role)) return true;
|
||||
if (el.hasAttribute("onclick") || el.hasAttribute("tabindex")) return true;
|
||||
if (el.hasAttribute("data-click") || el.hasAttribute("data-action")) return true;
|
||||
if (el.getAttribute("contenteditable") === "true") return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check if cursor:pointer is set directly (not just inherited from parent)
|
||||
function hasOwnPointerCursor(el) {
|
||||
try {
|
||||
const st = window.getComputedStyle(el);
|
||||
if (st.cursor !== "pointer") return false;
|
||||
const parent = el.parentElement;
|
||||
if (parent) {
|
||||
const pst = window.getComputedStyle(parent);
|
||||
if (pst.cursor === "pointer") return false;
|
||||
}
|
||||
return true;
|
||||
} catch(e) {}
|
||||
return false;
|
||||
}
|
||||
|
||||
function hasTextOrContent(el) {
|
||||
const t = el.textContent || "";
|
||||
if (t.trim().length > 0) return true;
|
||||
if (el.querySelector("img,video,audio,canvas")) return true;
|
||||
const ariaLabel = el.getAttribute("aria-label");
|
||||
if (ariaLabel && ariaLabel.trim()) return true;
|
||||
const title = el.getAttribute("title");
|
||||
if (title && title.trim()) return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
function isImplicitInteractive(el) {
|
||||
if (hasStrongInteractiveSignal(el)) return true;
|
||||
if (hasOwnPointerCursor(el) && hasTextOrContent(el)) return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
function getTextContent(el) {
|
||||
let text = "";
|
||||
for (const ch of el.childNodes) {
|
||||
if (ch.nodeType === Node.TEXT_NODE) {
|
||||
text += ch.textContent;
|
||||
}
|
||||
}
|
||||
return text.trim();
|
||||
}
|
||||
|
||||
function walk(node) {
|
||||
if (node.nodeType === Node.TEXT_NODE) {
|
||||
const t = node.textContent.trim();
|
||||
return t ? t : null;
|
||||
}
|
||||
if (node.nodeType !== Node.ELEMENT_NODE) return null;
|
||||
const tag = node.tagName.toLowerCase();
|
||||
if (SKIP.has(tag)) return null;
|
||||
if (!visible(node)) return null;
|
||||
|
||||
const children = [];
|
||||
for (const ch of node.childNodes) {
|
||||
const r = walk(ch);
|
||||
if (r !== null) {
|
||||
if (typeof r === "string") children.push(r);
|
||||
else children.push(r);
|
||||
}
|
||||
}
|
||||
|
||||
const nativeInteractive = INTERACTIVE.has(tag);
|
||||
const implicitInteractive = !nativeInteractive && (node instanceof HTMLElement) && isImplicitInteractive(node);
|
||||
const keep = KEEP.has(tag) || implicitInteractive;
|
||||
|
||||
if (!keep) {
|
||||
if (children.length === 0) return null;
|
||||
if (children.length === 1) return children[0];
|
||||
return children;
|
||||
}
|
||||
|
||||
const obj = { tag };
|
||||
if (nativeInteractive || implicitInteractive) {
|
||||
refCounter++;
|
||||
obj.ref = refCounter;
|
||||
refMap[refCounter] = node;
|
||||
}
|
||||
|
||||
if (implicitInteractive) {
|
||||
const role = node.getAttribute("role");
|
||||
if (role) obj.role = role;
|
||||
const directText = getTextContent(node);
|
||||
if (!directText && children.length === 0) {
|
||||
const ariaLabel = node.getAttribute("aria-label");
|
||||
const title = node.getAttribute("title");
|
||||
if (ariaLabel) obj.ariaLabel = ariaLabel;
|
||||
else if (title) obj.ariaLabel = title;
|
||||
}
|
||||
}
|
||||
|
||||
// Attributes
|
||||
if (tag === "a" && node.href) obj.href = node.getAttribute("href");
|
||||
if (tag === "img") {
|
||||
obj.alt = node.alt || "";
|
||||
obj.src = node.getAttribute("src") || "";
|
||||
}
|
||||
if (tag === "input" || tag === "textarea" || tag === "select") {
|
||||
obj.type = node.type || "text";
|
||||
obj.name = node.name || undefined;
|
||||
obj.value = node.value || undefined;
|
||||
obj.placeholder = node.placeholder || undefined;
|
||||
if (node.disabled) obj.disabled = true;
|
||||
if (tag === "input" && node.type === "checkbox") obj.checked = node.checked;
|
||||
}
|
||||
if (tag === "button") {
|
||||
if (node.disabled) obj.disabled = true;
|
||||
}
|
||||
if (tag === "option") {
|
||||
obj.value = node.value;
|
||||
if (node.selected) obj.selected = true;
|
||||
}
|
||||
if (tag === "label" && node.htmlFor) obj.for = node.htmlFor;
|
||||
|
||||
// Role / aria-label for native interactive & semantic elements
|
||||
if (!implicitInteractive) {
|
||||
const role = node.getAttribute("role");
|
||||
if (role) obj.role = role;
|
||||
const ariaLabel = node.getAttribute("aria-label");
|
||||
if (ariaLabel) obj.ariaLabel = ariaLabel;
|
||||
}
|
||||
|
||||
// Children
|
||||
if (children.length === 1 && typeof children[0] === "string") {
|
||||
obj.text = children[0];
|
||||
} else if (children.length > 0) {
|
||||
obj.children = children;
|
||||
}
|
||||
|
||||
return obj;
|
||||
}
|
||||
|
||||
const result = walk(document.body);
|
||||
window.__cowRefMap = refMap;
|
||||
return { tree: result, refCount: refCounter };
|
||||
}
|
||||
""" % (
|
||||
str(list(_KEEP_TAGS)),
|
||||
str(list(_INTERACTIVE_TAGS)),
|
||||
)
|
||||
|
||||
|
||||
def _should_use_headless() -> bool:
|
||||
"""Decide headless mode: headless on Linux servers without display, headed elsewhere."""
|
||||
if sys.platform in ("win32", "darwin"):
|
||||
return False
|
||||
# Linux: check for display
|
||||
if os.environ.get("DISPLAY") or os.environ.get("WAYLAND_DISPLAY"):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _flatten_tree(node, indent=0) -> List[str]:
|
||||
"""Convert snapshot tree to compact text lines for LLM consumption."""
|
||||
if node is None:
|
||||
return []
|
||||
if isinstance(node, str):
|
||||
return [" " * indent + node]
|
||||
if isinstance(node, list):
|
||||
lines = []
|
||||
for child in node:
|
||||
lines.extend(_flatten_tree(child, indent))
|
||||
return lines
|
||||
if not isinstance(node, dict):
|
||||
return []
|
||||
|
||||
tag = node.get("tag", "?")
|
||||
ref = node.get("ref")
|
||||
parts = [tag]
|
||||
if ref:
|
||||
parts[0] = f"[{ref}] {tag}"
|
||||
|
||||
# Inline attributes
|
||||
for attr in ("type", "name", "href", "alt", "role", "ariaLabel", "placeholder", "value"):
|
||||
val = node.get(attr)
|
||||
if val:
|
||||
# Truncate long values
|
||||
s = str(val)
|
||||
if len(s) > 80:
|
||||
s = s[:77] + "..."
|
||||
parts.append(f'{attr}="{s}"')
|
||||
|
||||
for flag in ("disabled", "checked", "selected"):
|
||||
if node.get(flag):
|
||||
parts.append(flag)
|
||||
|
||||
prefix = " " * indent
|
||||
header = prefix + " ".join(parts)
|
||||
|
||||
text = node.get("text")
|
||||
if text:
|
||||
# Truncate long text
|
||||
if len(text) > 120:
|
||||
text = text[:117] + "..."
|
||||
header += f": {text}"
|
||||
|
||||
lines = [header]
|
||||
children = node.get("children", [])
|
||||
for child in children:
|
||||
lines.extend(_flatten_tree(child, indent + 2))
|
||||
return lines
|
||||
|
||||
|
||||
class BrowserService:
|
||||
"""Manages a Playwright browser on a dedicated background thread.
|
||||
|
||||
All Playwright operations are dispatched to a single long-lived thread via
|
||||
a task queue. Callers from *any* worker thread can use the public API
|
||||
safely. An idle timer automatically shuts the browser down after
|
||||
``idle_timeout`` seconds of inactivity (default 300 = 5 min).
|
||||
"""
|
||||
|
||||
_IDLE_TIMEOUT_DEFAULT = 300 # seconds
|
||||
|
||||
def __init__(self, config: Optional[Dict[str, Any]] = None):
|
||||
self._config = config or {}
|
||||
self._headless: Optional[bool] = None
|
||||
self._screenshot_dir: Optional[str] = None
|
||||
|
||||
# Background thread state
|
||||
self._thread: Optional[threading.Thread] = None
|
||||
self._task_queue: queue.Queue = queue.Queue()
|
||||
self._lock = threading.Lock()
|
||||
self._alive = False
|
||||
self._ready = threading.Event()
|
||||
|
||||
# Playwright objects (only accessed on the background thread)
|
||||
self._playwright = None
|
||||
self._browser = None
|
||||
self._context = None
|
||||
self._page = None
|
||||
|
||||
# Idle auto-release
|
||||
idle_cfg = self._config.get("idle_timeout")
|
||||
self._idle_timeout: float = float(idle_cfg) if idle_cfg is not None else self._IDLE_TIMEOUT_DEFAULT
|
||||
self._idle_timer: Optional[threading.Timer] = None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Background-thread lifecycle
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _start_thread(self):
|
||||
"""Start the dedicated Playwright thread if not already running."""
|
||||
with self._lock:
|
||||
if self._alive and self._thread and self._thread.is_alive():
|
||||
return
|
||||
# Wait for old thread to fully exit before creating a new one
|
||||
old = self._thread
|
||||
if old and old.is_alive():
|
||||
old.join(timeout=5)
|
||||
# Fresh queue to avoid stale sentinels from a previous close()
|
||||
self._task_queue = queue.Queue()
|
||||
self._alive = True
|
||||
self._ready = threading.Event()
|
||||
self._thread = threading.Thread(target=self._run_loop, daemon=True, name="BrowserThread")
|
||||
self._thread.start()
|
||||
# Block until browser is ready (or failed)
|
||||
self._ready.wait(timeout=30)
|
||||
|
||||
def _run_loop(self):
|
||||
"""Event loop running on the dedicated thread. Processes tasks until stopped."""
|
||||
logger.info("[Browser] Background thread started")
|
||||
try:
|
||||
self._launch_browser()
|
||||
except Exception as e:
|
||||
logger.error(f"[Browser] Failed to launch browser: {e}")
|
||||
self._alive = False
|
||||
self._ready.set()
|
||||
self._drain_queue(RuntimeError(f"Browser launch failed: {e}"))
|
||||
return
|
||||
self._ready.set()
|
||||
|
||||
while self._alive:
|
||||
try:
|
||||
task = self._task_queue.get(timeout=1.0)
|
||||
except queue.Empty:
|
||||
continue
|
||||
if task is None:
|
||||
break
|
||||
fn, args, kwargs, result_slot = task
|
||||
try:
|
||||
result_slot["value"] = fn(*args, **kwargs)
|
||||
except Exception as e:
|
||||
result_slot["error"] = e
|
||||
finally:
|
||||
result_slot["event"].set()
|
||||
|
||||
self._shutdown_browser()
|
||||
self._drain_queue(RuntimeError("Browser thread stopped"))
|
||||
logger.info("[Browser] Background thread exited")
|
||||
|
||||
def _drain_queue(self, error: Exception):
|
||||
"""Unblock all callers waiting on the queue with an error."""
|
||||
while True:
|
||||
try:
|
||||
task = self._task_queue.get_nowait()
|
||||
except queue.Empty:
|
||||
break
|
||||
if task is None:
|
||||
continue
|
||||
_, _, _, result_slot = task
|
||||
result_slot["error"] = error
|
||||
result_slot["event"].set()
|
||||
|
||||
def _launch_browser(self):
|
||||
"""Launch Chromium on the background thread."""
|
||||
if self._headless is None:
|
||||
headless_cfg = self._config.get("headless")
|
||||
self._headless = headless_cfg if headless_cfg is not None else _should_use_headless()
|
||||
|
||||
launch_args = ["--disable-dev-shm-usage"]
|
||||
if self._headless:
|
||||
launch_args.append("--no-sandbox")
|
||||
|
||||
extra_args = self._config.get("launch_args", [])
|
||||
if extra_args:
|
||||
launch_args.extend(extra_args)
|
||||
|
||||
viewport_w = self._config.get("viewport_width", 1280)
|
||||
viewport_h = self._config.get("viewport_height", 720)
|
||||
|
||||
self._playwright = sync_playwright().start()
|
||||
logger.info(f"[Browser] Launching Chromium (headless={self._headless})")
|
||||
self._browser = self._playwright.chromium.launch(
|
||||
headless=self._headless,
|
||||
args=launch_args,
|
||||
)
|
||||
self._context = self._browser.new_context(
|
||||
viewport={"width": viewport_w, "height": viewport_h},
|
||||
user_agent=(
|
||||
"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) "
|
||||
"AppleWebKit/537.36 (KHTML, like Gecko) "
|
||||
"Chrome/131.0.0.0 Safari/537.36"
|
||||
),
|
||||
)
|
||||
self._page = self._context.new_page()
|
||||
logger.info("[Browser] Browser ready")
|
||||
|
||||
def _shutdown_browser(self):
|
||||
"""Shut down all Playwright resources on the background thread."""
|
||||
self._cancel_idle_timer()
|
||||
for obj, label in [
|
||||
(self._context, "context"),
|
||||
(self._browser, "browser"),
|
||||
]:
|
||||
try:
|
||||
if obj:
|
||||
obj.close()
|
||||
except Exception as e:
|
||||
logger.debug(f"[Browser] {label} close error: {e}")
|
||||
try:
|
||||
if self._playwright:
|
||||
self._playwright.stop()
|
||||
except Exception as e:
|
||||
logger.debug(f"[Browser] playwright stop error: {e}")
|
||||
self._page = None
|
||||
self._context = None
|
||||
self._browser = None
|
||||
self._playwright = None
|
||||
logger.info("[Browser] Browser closed")
|
||||
|
||||
def _submit(self, fn: Callable, *args, **kwargs):
|
||||
"""Submit *fn* to the background thread and block until it completes."""
|
||||
self._start_thread()
|
||||
|
||||
if not self._alive:
|
||||
raise RuntimeError("Browser is not available")
|
||||
|
||||
self._reset_idle_timer()
|
||||
|
||||
result_slot: Dict[str, Any] = {"event": threading.Event()}
|
||||
self._task_queue.put((fn, args, kwargs, result_slot))
|
||||
|
||||
# Timeout prevents permanent hang if the background thread crashes
|
||||
completed = result_slot["event"].wait(timeout=120)
|
||||
if not completed:
|
||||
raise TimeoutError("Browser operation timed out (120s)")
|
||||
|
||||
if "error" in result_slot:
|
||||
raise result_slot["error"]
|
||||
return result_slot.get("value")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Idle auto-release
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _reset_idle_timer(self):
|
||||
self._cancel_idle_timer()
|
||||
if self._idle_timeout > 0:
|
||||
self._idle_timer = threading.Timer(self._idle_timeout, self._on_idle_timeout)
|
||||
self._idle_timer.daemon = True
|
||||
self._idle_timer.start()
|
||||
|
||||
def _cancel_idle_timer(self):
|
||||
if self._idle_timer:
|
||||
self._idle_timer.cancel()
|
||||
self._idle_timer = None
|
||||
|
||||
def _on_idle_timeout(self):
|
||||
logger.info(f"[Browser] Idle for {self._idle_timeout}s, auto-releasing browser")
|
||||
self.close()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Public lifecycle
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def close(self):
|
||||
"""Shut down browser and background thread (safe from any thread)."""
|
||||
self._cancel_idle_timer()
|
||||
with self._lock:
|
||||
if not self._alive:
|
||||
return
|
||||
self._alive = False
|
||||
t = self._thread
|
||||
if self._task_queue is not None:
|
||||
self._task_queue.put(None)
|
||||
if t is not None and t.is_alive():
|
||||
t.join(timeout=10)
|
||||
with self._lock:
|
||||
self._thread = None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Actions (each method is dispatched to the background thread)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def navigate(self, url: str, timeout: int = 30000) -> Dict[str, Any]:
|
||||
return self._submit(self._do_navigate, url, timeout)
|
||||
|
||||
def _do_navigate(self, url: str, timeout: int) -> Dict[str, Any]:
|
||||
page = self._page
|
||||
try:
|
||||
resp = page.goto(url, wait_until="domcontentloaded", timeout=timeout)
|
||||
status = resp.status if resp else None
|
||||
except Exception as e:
|
||||
return {"error": f"Navigation failed: {e}"}
|
||||
|
||||
try:
|
||||
page.wait_for_load_state("networkidle", timeout=8000)
|
||||
except Exception:
|
||||
pass
|
||||
page.wait_for_timeout(500)
|
||||
|
||||
try:
|
||||
title = page.title()
|
||||
except Exception:
|
||||
title = ""
|
||||
try:
|
||||
current_url = page.url
|
||||
except Exception:
|
||||
current_url = url
|
||||
|
||||
return {"url": current_url, "title": title, "status": status}
|
||||
|
||||
def snapshot(self, selector: Optional[str] = None) -> str:
|
||||
return self._submit(self._do_snapshot, selector)
|
||||
|
||||
def _do_snapshot(self, selector: Optional[str] = None) -> str:
|
||||
page = self._page
|
||||
try:
|
||||
result = page.evaluate(_SNAPSHOT_JS)
|
||||
except Exception as e:
|
||||
return f"[Snapshot error: {e}]"
|
||||
|
||||
tree = result.get("tree")
|
||||
ref_count = result.get("refCount", 0)
|
||||
lines = _flatten_tree(tree)
|
||||
|
||||
try:
|
||||
title = page.title()
|
||||
except Exception:
|
||||
title = ""
|
||||
try:
|
||||
url = page.url
|
||||
except Exception:
|
||||
url = ""
|
||||
|
||||
header = f"Page: {title} ({url})\nInteractive elements: {ref_count}\n---"
|
||||
body = "\n".join(lines)
|
||||
|
||||
max_chars = self._config.get("snapshot_max_chars", 30000)
|
||||
if len(body) > max_chars:
|
||||
body = body[:max_chars] + "\n... [snapshot truncated]"
|
||||
|
||||
return f"{header}\n{body}"
|
||||
|
||||
def screenshot(self, full_page: bool = False, cwd: str = "") -> str:
|
||||
return self._submit(self._do_screenshot, full_page, cwd)
|
||||
|
||||
def _do_screenshot(self, full_page: bool = False, cwd: str = "") -> str:
|
||||
page = self._page
|
||||
save_dir = self._get_screenshot_dir(cwd)
|
||||
filename = f"screenshot_{uuid.uuid4().hex[:8]}.png"
|
||||
filepath = os.path.join(save_dir, filename)
|
||||
page.screenshot(path=filepath, full_page=full_page)
|
||||
logger.info(f"[Browser] Screenshot saved: {filepath}")
|
||||
return filepath
|
||||
|
||||
def click(self, ref: Optional[int] = None, selector: Optional[str] = None,
|
||||
timeout: int = 5000) -> Dict[str, Any]:
|
||||
return self._submit(self._do_click, ref, selector, timeout)
|
||||
|
||||
def _do_click(self, ref, selector, timeout) -> Dict[str, Any]:
|
||||
page = self._page
|
||||
try:
|
||||
if ref is not None:
|
||||
result = page.evaluate(f"""
|
||||
() => {{
|
||||
const el = window.__cowRefMap && window.__cowRefMap[{ref}];
|
||||
if (!el) return {{ error: "ref {ref} not found. Run snapshot first." }};
|
||||
el.click();
|
||||
return {{ clicked: true, tag: el.tagName.toLowerCase() }};
|
||||
}}
|
||||
""")
|
||||
if result.get("error"):
|
||||
return result
|
||||
page.wait_for_timeout(500)
|
||||
return result
|
||||
elif selector:
|
||||
page.click(selector, timeout=timeout)
|
||||
return {"clicked": True, "selector": selector}
|
||||
else:
|
||||
return {"error": "Provide either ref (from snapshot) or selector"}
|
||||
except Exception as e:
|
||||
return {"error": f"Click failed: {e}"}
|
||||
|
||||
def fill(self, text: str, ref: Optional[int] = None,
|
||||
selector: Optional[str] = None, timeout: int = 5000) -> Dict[str, Any]:
|
||||
return self._submit(self._do_fill, text, ref, selector, timeout)
|
||||
|
||||
def _do_fill(self, text, ref, selector, timeout) -> Dict[str, Any]:
|
||||
page = self._page
|
||||
try:
|
||||
if ref is not None:
|
||||
result = page.evaluate(f"""
|
||||
() => {{
|
||||
const el = window.__cowRefMap && window.__cowRefMap[{ref}];
|
||||
if (!el) return {{ error: "ref {ref} not found. Run snapshot first." }};
|
||||
el.focus();
|
||||
el.value = "";
|
||||
return {{ tag: el.tagName.toLowerCase(), name: el.name || "" }};
|
||||
}}
|
||||
""")
|
||||
if result.get("error"):
|
||||
return result
|
||||
page.keyboard.type(text)
|
||||
return {"filled": True, "ref": ref, "text": text}
|
||||
elif selector:
|
||||
page.fill(selector, text, timeout=timeout)
|
||||
return {"filled": True, "selector": selector, "text": text}
|
||||
else:
|
||||
return {"error": "Provide either ref (from snapshot) or selector"}
|
||||
except Exception as e:
|
||||
return {"error": f"Fill failed: {e}"}
|
||||
|
||||
def select(self, value: str, ref: Optional[int] = None,
|
||||
selector: Optional[str] = None, timeout: int = 5000) -> Dict[str, Any]:
|
||||
return self._submit(self._do_select, value, ref, selector, timeout)
|
||||
|
||||
def _do_select(self, value, ref, selector, timeout) -> Dict[str, Any]:
|
||||
page = self._page
|
||||
try:
|
||||
if ref is not None:
|
||||
result = page.evaluate(f"""
|
||||
() => {{
|
||||
const el = window.__cowRefMap && window.__cowRefMap[{ref}];
|
||||
if (!el || el.tagName.toLowerCase() !== "select")
|
||||
return {{ error: "ref {ref} is not a <select> element" }};
|
||||
el.value = {repr(value)};
|
||||
el.dispatchEvent(new Event("change", {{ bubbles: true }}));
|
||||
return {{ selected: true, value: el.value }};
|
||||
}}
|
||||
""")
|
||||
return result
|
||||
elif selector:
|
||||
page.select_option(selector, value, timeout=timeout)
|
||||
return {"selected": True, "selector": selector, "value": value}
|
||||
else:
|
||||
return {"error": "Provide either ref (from snapshot) or selector"}
|
||||
except Exception as e:
|
||||
return {"error": f"Select failed: {e}"}
|
||||
|
||||
def scroll(self, direction: str = "down", amount: int = 500) -> Dict[str, Any]:
|
||||
return self._submit(self._do_scroll, direction, amount)
|
||||
|
||||
def _do_scroll(self, direction, amount) -> Dict[str, Any]:
|
||||
page = self._page
|
||||
delta_map = {
|
||||
"down": (0, amount),
|
||||
"up": (0, -amount),
|
||||
"right": (amount, 0),
|
||||
"left": (-amount, 0),
|
||||
}
|
||||
dx, dy = delta_map.get(direction, (0, amount))
|
||||
try:
|
||||
page.mouse.wheel(dx, dy)
|
||||
page.wait_for_timeout(300)
|
||||
scroll_info = page.evaluate("""
|
||||
() => ({
|
||||
scrollX: window.scrollX,
|
||||
scrollY: window.scrollY,
|
||||
scrollHeight: document.documentElement.scrollHeight,
|
||||
clientHeight: document.documentElement.clientHeight
|
||||
})
|
||||
""")
|
||||
return {"scrolled": direction, "amount": amount, **scroll_info}
|
||||
except Exception as e:
|
||||
return {"error": f"Scroll failed: {e}"}
|
||||
|
||||
def wait(self, selector: Optional[str] = None, timeout: int = 5000,
|
||||
state: str = "visible") -> Dict[str, Any]:
|
||||
return self._submit(self._do_wait, selector, timeout, state)
|
||||
|
||||
def _do_wait(self, selector, timeout, state) -> Dict[str, Any]:
|
||||
page = self._page
|
||||
try:
|
||||
if selector:
|
||||
page.wait_for_selector(selector, timeout=timeout, state=state)
|
||||
return {"waited": True, "selector": selector, "state": state}
|
||||
else:
|
||||
page.wait_for_timeout(timeout)
|
||||
return {"waited": True, "timeout_ms": timeout}
|
||||
except Exception as e:
|
||||
return {"error": f"Wait failed: {e}"}
|
||||
|
||||
def go_back(self) -> Dict[str, Any]:
|
||||
return self._submit(self._do_go_back)
|
||||
|
||||
def _do_go_back(self) -> Dict[str, Any]:
|
||||
page = self._page
|
||||
try:
|
||||
page.go_back(wait_until="domcontentloaded", timeout=10000)
|
||||
try:
|
||||
title = page.title()
|
||||
except Exception:
|
||||
title = ""
|
||||
try:
|
||||
url = page.url
|
||||
except Exception:
|
||||
url = ""
|
||||
return {"url": url, "title": title}
|
||||
except Exception as e:
|
||||
return {"error": f"Go back failed: {e}"}
|
||||
|
||||
def go_forward(self) -> Dict[str, Any]:
|
||||
return self._submit(self._do_go_forward)
|
||||
|
||||
def _do_go_forward(self) -> Dict[str, Any]:
|
||||
page = self._page
|
||||
try:
|
||||
page.go_forward(wait_until="domcontentloaded", timeout=10000)
|
||||
try:
|
||||
title = page.title()
|
||||
except Exception:
|
||||
title = ""
|
||||
try:
|
||||
url = page.url
|
||||
except Exception:
|
||||
url = ""
|
||||
return {"url": url, "title": title}
|
||||
except Exception as e:
|
||||
return {"error": f"Go forward failed: {e}"}
|
||||
|
||||
def get_text(self, selector: str) -> Dict[str, Any]:
|
||||
return self._submit(self._do_get_text, selector)
|
||||
|
||||
def _do_get_text(self, selector) -> Dict[str, Any]:
|
||||
page = self._page
|
||||
try:
|
||||
text = page.text_content(selector, timeout=5000)
|
||||
return {"text": text or ""}
|
||||
except Exception as e:
|
||||
return {"error": f"Get text failed: {e}"}
|
||||
|
||||
def evaluate(self, script: str) -> Dict[str, Any]:
|
||||
return self._submit(self._do_evaluate, script)
|
||||
|
||||
def _do_evaluate(self, script) -> Dict[str, Any]:
|
||||
page = self._page
|
||||
try:
|
||||
result = page.evaluate(script)
|
||||
return {"result": result}
|
||||
except Exception as e:
|
||||
return {"error": f"Evaluate failed: {e}"}
|
||||
|
||||
def press(self, key: str) -> Dict[str, Any]:
|
||||
return self._submit(self._do_press, key)
|
||||
|
||||
def _do_press(self, key) -> Dict[str, Any]:
|
||||
page = self._page
|
||||
try:
|
||||
page.keyboard.press(key)
|
||||
page.wait_for_timeout(300)
|
||||
return {"pressed": key}
|
||||
except Exception as e:
|
||||
return {"error": f"Press failed: {e}"}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _get_screenshot_dir(self, cwd: str = "") -> str:
|
||||
if self._screenshot_dir and os.path.isdir(self._screenshot_dir):
|
||||
return self._screenshot_dir
|
||||
base = cwd or os.getcwd()
|
||||
d = os.path.join(base, "tmp")
|
||||
os.makedirs(d, exist_ok=True)
|
||||
self._screenshot_dir = d
|
||||
return d
|
||||
290
agent/tools/browser/browser_tool.py
Normal file
290
agent/tools/browser/browser_tool.py
Normal file
@@ -0,0 +1,290 @@
|
||||
"""
|
||||
Browser tool - Control a Chromium browser for web navigation and interaction.
|
||||
|
||||
Uses Playwright under the hood. Browser instance is lazily started on first
|
||||
use, reused across tool calls within the same session, and cleaned up via
|
||||
close().
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
from typing import Dict, Any, Optional
|
||||
|
||||
from agent.tools.base_tool import BaseTool, ToolResult
|
||||
from agent.tools.browser.browser_service import BrowserService
|
||||
from common.log import logger
|
||||
|
||||
|
||||
class BrowserTool(BaseTool):
|
||||
"""Single tool exposing all browser actions via an 'action' parameter."""
|
||||
|
||||
name: str = "browser"
|
||||
description: str = (
|
||||
"Control a browser to navigate web pages, interact with elements, and extract content. "
|
||||
"Actions: navigate, snapshot, click, fill, select, scroll, screenshot, wait, back, forward, "
|
||||
"get_text, press, evaluate.\n\n"
|
||||
"Workflow: navigate (auto-includes snapshot with element refs) → click/fill/select by ref → snapshot to verify.\n\n"
|
||||
"Use snapshot as the primary way to read pages. Use screenshot + send to show key results to the user. "
|
||||
"For login/CAPTCHA/authorization etc., screenshot and ask the user for help."
|
||||
)
|
||||
|
||||
params: dict = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"action": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"The browser action to perform. One of: "
|
||||
"navigate, snapshot, click, fill, select, scroll, "
|
||||
"screenshot, wait, back, forward, get_text, press, evaluate"
|
||||
),
|
||||
"enum": [
|
||||
"navigate", "snapshot", "click", "fill", "select", "scroll",
|
||||
"screenshot", "wait", "back", "forward", "get_text", "press",
|
||||
"evaluate"
|
||||
]
|
||||
},
|
||||
"url": {
|
||||
"type": "string",
|
||||
"description": "URL to navigate to (for 'navigate' action)"
|
||||
},
|
||||
"ref": {
|
||||
"type": "integer",
|
||||
"description": "Element ref number from snapshot (for click/fill/select)"
|
||||
},
|
||||
"selector": {
|
||||
"type": "string",
|
||||
"description": "CSS selector as fallback when ref is unavailable (for click/fill/select/wait/get_text)"
|
||||
},
|
||||
"text": {
|
||||
"type": "string",
|
||||
"description": "Text to type (for 'fill' action)"
|
||||
},
|
||||
"value": {
|
||||
"type": "string",
|
||||
"description": "Option value (for 'select' action)"
|
||||
},
|
||||
"key": {
|
||||
"type": "string",
|
||||
"description": "Key to press, e.g. Enter, Tab, Escape (for 'press' action)"
|
||||
},
|
||||
"direction": {
|
||||
"type": "string",
|
||||
"description": "Scroll direction: up, down, left, right (for 'scroll' action, default: down)"
|
||||
},
|
||||
"script": {
|
||||
"type": "string",
|
||||
"description": "JavaScript code to execute (for 'evaluate' action)"
|
||||
},
|
||||
"full_page": {
|
||||
"type": "boolean",
|
||||
"description": "Capture full page screenshot (for 'screenshot' action, default: false)"
|
||||
},
|
||||
"timeout": {
|
||||
"type": "integer",
|
||||
"description": "Timeout in milliseconds (optional, default varies by action)"
|
||||
}
|
||||
},
|
||||
"required": ["action"]
|
||||
}
|
||||
|
||||
_shared_service: Optional[BrowserService] = None
|
||||
|
||||
def __init__(self, config: dict = None):
|
||||
self.config = config or {}
|
||||
self.cwd = self.config.get("cwd", os.getcwd())
|
||||
self._service: Optional[BrowserService] = None
|
||||
|
||||
def _get_service(self) -> BrowserService:
|
||||
"""Get or create the browser service, sharing across copies."""
|
||||
if self._service is not None:
|
||||
return self._service
|
||||
|
||||
# Reuse shared service across tool copies within the same session
|
||||
if BrowserTool._shared_service is not None:
|
||||
self._service = BrowserTool._shared_service
|
||||
return self._service
|
||||
|
||||
self._service = BrowserService(self.config)
|
||||
BrowserTool._shared_service = self._service
|
||||
return self._service
|
||||
|
||||
def execute(self, args: Dict[str, Any]) -> ToolResult:
|
||||
action = args.get("action", "").strip().lower()
|
||||
if not action:
|
||||
return ToolResult.fail("Error: 'action' parameter is required")
|
||||
|
||||
handler = self._ACTION_MAP.get(action)
|
||||
if not handler:
|
||||
valid = ", ".join(sorted(self._ACTION_MAP.keys()))
|
||||
return ToolResult.fail(f"Unknown action '{action}'. Valid actions: {valid}")
|
||||
|
||||
try:
|
||||
return handler(self, args)
|
||||
except Exception as e:
|
||||
logger.error(f"[Browser] Action '{action}' error: {e}")
|
||||
return ToolResult.fail(f"Browser error ({action}): {e}")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Action handlers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _do_navigate(self, args: Dict[str, Any]) -> ToolResult:
|
||||
url = args.get("url", "").strip()
|
||||
if not url:
|
||||
return ToolResult.fail("Error: 'url' is required for navigate action")
|
||||
if not url.startswith(("http://", "https://")):
|
||||
url = "https://" + url
|
||||
timeout = args.get("timeout", 30000)
|
||||
service = self._get_service()
|
||||
result = service.navigate(url, timeout=timeout)
|
||||
if "error" in result:
|
||||
return ToolResult.fail(result["error"])
|
||||
# Auto-snapshot after navigation so the agent gets page content in one call
|
||||
snapshot_text = service.snapshot()
|
||||
return ToolResult.success(
|
||||
f"Navigated to: {result['url']}\nTitle: {result['title']}\nStatus: {result['status']}\n\n"
|
||||
f"--- Page Snapshot ---\n{snapshot_text}"
|
||||
)
|
||||
|
||||
def _do_snapshot(self, args: Dict[str, Any]) -> ToolResult:
|
||||
selector = args.get("selector")
|
||||
text = self._get_service().snapshot(selector=selector)
|
||||
return ToolResult.success(text)
|
||||
|
||||
def _do_click(self, args: Dict[str, Any]) -> ToolResult:
|
||||
ref = args.get("ref")
|
||||
selector = args.get("selector")
|
||||
timeout = args.get("timeout", 5000)
|
||||
result = self._get_service().click(ref=ref, selector=selector, timeout=timeout)
|
||||
if "error" in result:
|
||||
return ToolResult.fail(result["error"])
|
||||
return ToolResult.success(f"Clicked successfully. Use 'snapshot' to see updated page.")
|
||||
|
||||
def _do_fill(self, args: Dict[str, Any]) -> ToolResult:
|
||||
text = args.get("text", "")
|
||||
ref = args.get("ref")
|
||||
selector = args.get("selector")
|
||||
timeout = args.get("timeout", 5000)
|
||||
if not text and text != "":
|
||||
return ToolResult.fail("Error: 'text' is required for fill action")
|
||||
result = self._get_service().fill(text, ref=ref, selector=selector, timeout=timeout)
|
||||
if "error" in result:
|
||||
return ToolResult.fail(result["error"])
|
||||
return ToolResult.success(f"Filled text into element. Use 'snapshot' to verify.")
|
||||
|
||||
def _do_select(self, args: Dict[str, Any]) -> ToolResult:
|
||||
value = args.get("value", "")
|
||||
ref = args.get("ref")
|
||||
selector = args.get("selector")
|
||||
timeout = args.get("timeout", 5000)
|
||||
if not value:
|
||||
return ToolResult.fail("Error: 'value' is required for select action")
|
||||
result = self._get_service().select(value, ref=ref, selector=selector, timeout=timeout)
|
||||
if "error" in result:
|
||||
return ToolResult.fail(result["error"])
|
||||
return ToolResult.success(f"Selected option '{value}'.")
|
||||
|
||||
def _do_scroll(self, args: Dict[str, Any]) -> ToolResult:
|
||||
direction = args.get("direction", "down")
|
||||
amount = args.get("timeout", 500) # reuse timeout field or default
|
||||
if "amount" in args:
|
||||
amount = args["amount"]
|
||||
result = self._get_service().scroll(direction=direction, amount=amount)
|
||||
if "error" in result:
|
||||
return ToolResult.fail(result["error"])
|
||||
pos = f"scrollY={result.get('scrollY', '?')}/{result.get('scrollHeight', '?')}"
|
||||
return ToolResult.success(f"Scrolled {direction}. Position: {pos}")
|
||||
|
||||
def _do_screenshot(self, args: Dict[str, Any]) -> ToolResult:
|
||||
full_page = args.get("full_page", False)
|
||||
filepath = self._get_service().screenshot(full_page=full_page, cwd=self.cwd)
|
||||
return ToolResult.success(f"Screenshot saved to: {filepath}")
|
||||
|
||||
def _do_wait(self, args: Dict[str, Any]) -> ToolResult:
|
||||
selector = args.get("selector")
|
||||
timeout = args.get("timeout", 5000)
|
||||
result = self._get_service().wait(selector=selector, timeout=timeout)
|
||||
if "error" in result:
|
||||
return ToolResult.fail(result["error"])
|
||||
return ToolResult.success(f"Wait completed.")
|
||||
|
||||
def _do_back(self, args: Dict[str, Any]) -> ToolResult:
|
||||
result = self._get_service().go_back()
|
||||
if "error" in result:
|
||||
return ToolResult.fail(result["error"])
|
||||
return ToolResult.success(f"Navigated back to: {result['url']}")
|
||||
|
||||
def _do_forward(self, args: Dict[str, Any]) -> ToolResult:
|
||||
result = self._get_service().go_forward()
|
||||
if "error" in result:
|
||||
return ToolResult.fail(result["error"])
|
||||
return ToolResult.success(f"Navigated forward to: {result['url']}")
|
||||
|
||||
def _do_get_text(self, args: Dict[str, Any]) -> ToolResult:
|
||||
selector = args.get("selector", "").strip()
|
||||
if not selector:
|
||||
return ToolResult.fail("Error: 'selector' is required for get_text action")
|
||||
result = self._get_service().get_text(selector)
|
||||
if "error" in result:
|
||||
return ToolResult.fail(result["error"])
|
||||
return ToolResult.success(result["text"])
|
||||
|
||||
def _do_press(self, args: Dict[str, Any]) -> ToolResult:
|
||||
key = args.get("key", "").strip()
|
||||
if not key:
|
||||
return ToolResult.fail("Error: 'key' is required for press action")
|
||||
result = self._get_service().press(key)
|
||||
if "error" in result:
|
||||
return ToolResult.fail(result["error"])
|
||||
return ToolResult.success(f"Pressed key: {key}")
|
||||
|
||||
def _do_evaluate(self, args: Dict[str, Any]) -> ToolResult:
|
||||
script = args.get("script", "").strip()
|
||||
if not script:
|
||||
return ToolResult.fail("Error: 'script' is required for evaluate action")
|
||||
result = self._get_service().evaluate(script)
|
||||
if "error" in result:
|
||||
return ToolResult.fail(result["error"])
|
||||
val = result.get("result")
|
||||
if isinstance(val, (dict, list)):
|
||||
return ToolResult.success(json.dumps(val, ensure_ascii=False, indent=2))
|
||||
return ToolResult.success(str(val) if val is not None else "(no return value)")
|
||||
|
||||
# Action dispatch table
|
||||
_ACTION_MAP = {
|
||||
"navigate": _do_navigate,
|
||||
"snapshot": _do_snapshot,
|
||||
"click": _do_click,
|
||||
"fill": _do_fill,
|
||||
"select": _do_select,
|
||||
"scroll": _do_scroll,
|
||||
"screenshot": _do_screenshot,
|
||||
"wait": _do_wait,
|
||||
"back": _do_back,
|
||||
"forward": _do_forward,
|
||||
"get_text": _do_get_text,
|
||||
"press": _do_press,
|
||||
"evaluate": _do_evaluate,
|
||||
}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Lifecycle
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def copy(self):
|
||||
"""Share browser instance across tool copies (avoids re-launching)."""
|
||||
new_tool = BrowserTool(self.config)
|
||||
new_tool.model = self.model
|
||||
new_tool.context = getattr(self, "context", None)
|
||||
new_tool.cwd = self.cwd
|
||||
new_tool._service = self._service
|
||||
return new_tool
|
||||
|
||||
def close(self):
|
||||
"""Release browser resources."""
|
||||
if self._service:
|
||||
self._service.close()
|
||||
self._service = None
|
||||
BrowserTool._shared_service = None
|
||||
logger.info("[Browser] BrowserTool closed")
|
||||
@@ -1,18 +0,0 @@
|
||||
def copy(self):
|
||||
"""
|
||||
Special copy method for browser tool to avoid recreating browser instance.
|
||||
|
||||
:return: A new instance with shared browser reference but unique model
|
||||
"""
|
||||
new_tool = self.__class__()
|
||||
|
||||
# Copy essential attributes
|
||||
new_tool.model = self.model
|
||||
new_tool.context = getattr(self, 'context', None)
|
||||
new_tool.config = getattr(self, 'config', None)
|
||||
|
||||
# Share the browser instance instead of creating a new one
|
||||
if hasattr(self, 'browser'):
|
||||
new_tool.browser = self.browser
|
||||
|
||||
return new_tool
|
||||
@@ -94,7 +94,7 @@ class Ls(BaseTool):
|
||||
results.append(entry + '/')
|
||||
else:
|
||||
results.append(entry)
|
||||
except:
|
||||
except Exception:
|
||||
# Skip entries we can't stat
|
||||
continue
|
||||
|
||||
|
||||
@@ -44,6 +44,19 @@ class MemoryGetTool(BaseTool):
|
||||
"""
|
||||
super().__init__()
|
||||
self.memory_manager = memory_manager
|
||||
|
||||
from config import conf
|
||||
if conf().get("knowledge", True):
|
||||
self.description = (
|
||||
"Read specific content from memory or knowledge files. "
|
||||
"Use this to get full context from a memory file, knowledge page, or specific line range."
|
||||
)
|
||||
self.params = {**self.params}
|
||||
self.params["properties"] = {**self.params["properties"]}
|
||||
self.params["properties"]["path"] = {
|
||||
"type": "string",
|
||||
"description": "Relative path to the memory or knowledge file (e.g. 'MEMORY.md', 'memory/2026-01-01.md', 'knowledge/concepts/moe.md')"
|
||||
}
|
||||
|
||||
def execute(self, args: dict):
|
||||
"""
|
||||
@@ -68,11 +81,15 @@ class MemoryGetTool(BaseTool):
|
||||
workspace_dir = self.memory_manager.config.get_workspace()
|
||||
|
||||
# Auto-prepend memory/ if not present and not absolute path
|
||||
# Exception: MEMORY.md is in the root directory
|
||||
if not path.startswith('memory/') and not path.startswith('/') and path != 'MEMORY.md':
|
||||
# Exceptions: MEMORY.md in root, knowledge/ files at workspace root
|
||||
if not path.startswith('memory/') and not path.startswith('knowledge/') and not path.startswith('/') and path != 'MEMORY.md':
|
||||
path = f'memory/{path}'
|
||||
|
||||
file_path = workspace_dir / path
|
||||
file_path = (workspace_dir / path).resolve()
|
||||
workspace_resolved = workspace_dir.resolve()
|
||||
|
||||
if not str(file_path).startswith(str(workspace_resolved) + '/') and file_path != workspace_resolved:
|
||||
return ToolResult.fail(f"Error: Access denied: path outside workspace")
|
||||
|
||||
if not file_path.exists():
|
||||
return ToolResult.fail(f"Error: File not found: {path}")
|
||||
|
||||
@@ -48,6 +48,13 @@ class MemorySearchTool(BaseTool):
|
||||
super().__init__()
|
||||
self.memory_manager = memory_manager
|
||||
self.user_id = user_id
|
||||
|
||||
from config import conf
|
||||
if conf().get("knowledge", True):
|
||||
self.description = (
|
||||
"Search agent's long-term memory and knowledge base using semantic and keyword search. "
|
||||
"Use this to recall past conversations, preferences, and knowledge pages."
|
||||
)
|
||||
|
||||
def execute(self, args: dict):
|
||||
"""
|
||||
|
||||
@@ -48,7 +48,8 @@ class Read(BaseTool):
|
||||
self.binary_extensions = {'.exe', '.dll', '.so', '.dylib', '.bin', '.dat', '.db', '.sqlite'}
|
||||
self.archive_extensions = {'.zip', '.tar', '.gz', '.rar', '.7z', '.bz2', '.xz'}
|
||||
self.pdf_extensions = {'.pdf'}
|
||||
|
||||
self.office_extensions = {'.doc', '.docx', '.xls', '.xlsx', '.ppt', '.pptx'}
|
||||
|
||||
# Readable text formats (will be read with truncation)
|
||||
self.text_extensions = {
|
||||
'.txt', '.md', '.markdown', '.rst', '.log', '.csv', '.tsv', '.json', '.xml', '.yaml', '.yml',
|
||||
@@ -57,7 +58,6 @@ class Read(BaseTool):
|
||||
'.sh', '.bash', '.zsh', '.fish', '.ps1', '.bat', '.cmd',
|
||||
'.sql', '.r', '.m', '.swift', '.kt', '.scala', '.clj', '.erl', '.ex',
|
||||
'.dockerfile', '.makefile', '.cmake', '.gradle', '.properties', '.ini', '.conf', '.cfg',
|
||||
'.doc', '.docx', '.xls', '.xlsx', '.ppt', '.pptx' # Office documents
|
||||
}
|
||||
|
||||
def execute(self, args: Dict[str, Any]) -> ToolResult:
|
||||
@@ -120,7 +120,11 @@ class Read(BaseTool):
|
||||
# Check if PDF
|
||||
if file_ext in self.pdf_extensions:
|
||||
return self._read_pdf(absolute_path, path, offset, limit)
|
||||
|
||||
|
||||
# Check if Office document (.docx, .xlsx, .pptx, etc.)
|
||||
if file_ext in self.office_extensions:
|
||||
return self._read_office(absolute_path, path, file_ext, offset, limit)
|
||||
|
||||
# Read text file (with truncation for large files)
|
||||
return self._read_text(absolute_path, path, offset, limit)
|
||||
|
||||
@@ -240,8 +244,8 @@ class Read(BaseTool):
|
||||
"message": f"文件过大 ({format_size(file_size)} > 50MB),无法读取内容。文件路径: {absolute_path}"
|
||||
})
|
||||
|
||||
# Read file
|
||||
with open(absolute_path, 'r', encoding='utf-8') as f:
|
||||
# Read file (utf-8-sig strips BOM automatically on Windows)
|
||||
with open(absolute_path, 'r', encoding='utf-8-sig') as f:
|
||||
content = f.read()
|
||||
|
||||
# Truncate content if too long (20K characters max for model context)
|
||||
@@ -337,6 +341,116 @@ class Read(BaseTool):
|
||||
except Exception as e:
|
||||
return ToolResult.fail(f"Error reading file: {str(e)}")
|
||||
|
||||
def _read_office(self, absolute_path: str, display_path: str, file_ext: str,
|
||||
offset: int = None, limit: int = None) -> ToolResult:
|
||||
"""Read Office documents (.docx, .xlsx, .pptx) using python-docx / openpyxl / python-pptx."""
|
||||
try:
|
||||
text = self._extract_office_text(absolute_path, file_ext)
|
||||
except ImportError as e:
|
||||
return ToolResult.fail(str(e))
|
||||
except Exception as e:
|
||||
return ToolResult.fail(f"Error reading Office document: {e}")
|
||||
|
||||
if not text or not text.strip():
|
||||
return ToolResult.success({
|
||||
"content": f"[Office file {Path(absolute_path).name}: no text content could be extracted]",
|
||||
})
|
||||
|
||||
all_lines = text.split('\n')
|
||||
total_lines = len(all_lines)
|
||||
|
||||
start_line = 0
|
||||
if offset is not None:
|
||||
if offset < 0:
|
||||
start_line = max(0, total_lines + offset)
|
||||
else:
|
||||
start_line = max(0, offset - 1)
|
||||
if start_line >= total_lines:
|
||||
return ToolResult.fail(
|
||||
f"Error: Offset {offset} is beyond end of content ({total_lines} lines total)"
|
||||
)
|
||||
|
||||
selected_content = text
|
||||
user_limited_lines = None
|
||||
if limit is not None:
|
||||
end_line = min(start_line + limit, total_lines)
|
||||
selected_content = '\n'.join(all_lines[start_line:end_line])
|
||||
user_limited_lines = end_line - start_line
|
||||
elif offset is not None:
|
||||
selected_content = '\n'.join(all_lines[start_line:])
|
||||
|
||||
truncation = truncate_head(selected_content)
|
||||
start_line_display = start_line + 1
|
||||
output_text = ""
|
||||
|
||||
if truncation.truncated:
|
||||
end_line_display = start_line_display + truncation.output_lines - 1
|
||||
next_offset = end_line_display + 1
|
||||
output_text = truncation.content
|
||||
output_text += f"\n\n[Showing lines {start_line_display}-{end_line_display} of {total_lines}. Use offset={next_offset} to continue.]"
|
||||
elif user_limited_lines is not None and start_line + user_limited_lines < total_lines:
|
||||
remaining = total_lines - (start_line + user_limited_lines)
|
||||
next_offset = start_line + user_limited_lines + 1
|
||||
output_text = truncation.content
|
||||
output_text += f"\n\n[{remaining} more lines in file. Use offset={next_offset} to continue.]"
|
||||
else:
|
||||
output_text = truncation.content
|
||||
|
||||
return ToolResult.success({
|
||||
"content": output_text,
|
||||
"total_lines": total_lines,
|
||||
"start_line": start_line_display,
|
||||
"output_lines": truncation.output_lines,
|
||||
})
|
||||
|
||||
@staticmethod
|
||||
def _extract_office_text(absolute_path: str, file_ext: str) -> str:
|
||||
"""Extract plain text from an Office document."""
|
||||
if file_ext in ('.docx', '.doc'):
|
||||
try:
|
||||
from docx import Document
|
||||
except ImportError:
|
||||
raise ImportError("Error: python-docx library not installed. Install with: pip install python-docx")
|
||||
doc = Document(absolute_path)
|
||||
paragraphs = [p.text for p in doc.paragraphs]
|
||||
for table in doc.tables:
|
||||
for row in table.rows:
|
||||
paragraphs.append('\t'.join(cell.text for cell in row.cells))
|
||||
return '\n'.join(paragraphs)
|
||||
|
||||
if file_ext in ('.xlsx', '.xls'):
|
||||
try:
|
||||
from openpyxl import load_workbook
|
||||
except ImportError:
|
||||
raise ImportError("Error: openpyxl library not installed. Install with: pip install openpyxl")
|
||||
wb = load_workbook(absolute_path, read_only=True, data_only=True)
|
||||
parts = []
|
||||
for ws in wb.worksheets:
|
||||
parts.append(f"--- Sheet: {ws.title} ---")
|
||||
for row in ws.iter_rows(values_only=True):
|
||||
parts.append('\t'.join(str(c) if c is not None else '' for c in row))
|
||||
wb.close()
|
||||
return '\n'.join(parts)
|
||||
|
||||
if file_ext in ('.pptx', '.ppt'):
|
||||
try:
|
||||
from pptx import Presentation
|
||||
except ImportError:
|
||||
raise ImportError("Error: python-pptx library not installed. Install with: pip install python-pptx")
|
||||
prs = Presentation(absolute_path)
|
||||
parts = []
|
||||
for i, slide in enumerate(prs.slides, 1):
|
||||
parts.append(f"--- Slide {i} ---")
|
||||
for shape in slide.shapes:
|
||||
if shape.has_text_frame:
|
||||
for para in shape.text_frame.paragraphs:
|
||||
text = para.text.strip()
|
||||
if text:
|
||||
parts.append(text)
|
||||
return '\n'.join(parts)
|
||||
|
||||
return ""
|
||||
|
||||
def _read_pdf(self, absolute_path: str, display_path: str, offset: int = None, limit: int = None) -> ToolResult:
|
||||
"""
|
||||
Read PDF file content
|
||||
|
||||
@@ -134,12 +134,13 @@ def _execute_agent_task(task: dict, agent_bridge):
|
||||
elif channel_type == "dingtalk":
|
||||
# DingTalk requires msg object, set to None for scheduled tasks
|
||||
context["msg"] = None
|
||||
# 如果是单聊,需要传递 sender_staff_id
|
||||
if not is_group:
|
||||
sender_staff_id = action.get("dingtalk_sender_staff_id")
|
||||
if sender_staff_id:
|
||||
context["dingtalk_sender_staff_id"] = sender_staff_id
|
||||
|
||||
elif channel_type == "wecom_bot":
|
||||
context["msg"] = None
|
||||
|
||||
# Use Agent to execute the task
|
||||
# Mark this as a scheduled task execution to prevent recursive task creation
|
||||
context["is_scheduled_task"] = True
|
||||
@@ -234,7 +235,11 @@ def _execute_send_message(task: dict, agent_bridge):
|
||||
logger.debug(f"[Scheduler] DingTalk single chat: sender_staff_id={sender_staff_id}")
|
||||
else:
|
||||
logger.warning(f"[Scheduler] Task {task['id']}: DingTalk single chat message missing sender_staff_id")
|
||||
|
||||
elif channel_type == "wecom_bot":
|
||||
context["msg"] = None
|
||||
elif channel_type == "qq":
|
||||
context["msg"] = None
|
||||
|
||||
# Create reply
|
||||
reply = Reply(ReplyType.TEXT, content)
|
||||
|
||||
@@ -327,31 +332,31 @@ def _execute_tool_call(task: dict, agent_bridge):
|
||||
context["request_id"] = request_id
|
||||
logger.debug(f"[Scheduler] Generated request_id for web channel: {request_id}")
|
||||
elif channel_type == "feishu":
|
||||
# Feishu channel: for scheduled tasks, send as new message (no msg_id to reply to)
|
||||
context["receive_id_type"] = "chat_id" if is_group else "open_id"
|
||||
context["msg"] = None
|
||||
logger.debug(f"[Scheduler] Feishu: receive_id_type={context['receive_id_type']}, is_group={is_group}, receiver={receiver}")
|
||||
|
||||
elif channel_type == "wecom_bot":
|
||||
context["msg"] = None
|
||||
|
||||
reply = Reply(ReplyType.TEXT, content)
|
||||
|
||||
|
||||
# Get channel and send
|
||||
from channel.channel_factory import create_channel
|
||||
|
||||
|
||||
try:
|
||||
channel = create_channel(channel_type)
|
||||
if channel:
|
||||
# For web channel, register the request_id to session mapping
|
||||
if channel_type == "web" and hasattr(channel, 'request_to_session'):
|
||||
channel.request_to_session[request_id] = receiver
|
||||
logger.debug(f"[Scheduler] Registered request_id {request_id} -> session {receiver}")
|
||||
|
||||
|
||||
channel.send(reply, context)
|
||||
logger.info(f"[Scheduler] Task {task['id']} executed: sent tool result to {receiver}")
|
||||
else:
|
||||
logger.error(f"[Scheduler] Failed to create channel: {channel_type}")
|
||||
except Exception as e:
|
||||
logger.error(f"[Scheduler] Failed to send tool result: {e}")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Scheduler] Error in _execute_tool_call: {e}")
|
||||
|
||||
@@ -409,7 +414,9 @@ def _execute_skill_call(task: dict, agent_bridge):
|
||||
elif channel_type == "feishu":
|
||||
context["receive_id_type"] = "chat_id" if is_group else "open_id"
|
||||
context["msg"] = None
|
||||
|
||||
elif channel_type == "wecom_bot":
|
||||
context["msg"] = None
|
||||
|
||||
# Use Agent to execute the skill
|
||||
try:
|
||||
# Don't clear history - scheduler tasks use isolated session_id so they won't pollute user conversations
|
||||
@@ -451,8 +458,7 @@ def attach_scheduler_to_tool(tool, context: Context = None):
|
||||
if context:
|
||||
tool.current_context = context
|
||||
|
||||
# Also set channel_type from config
|
||||
channel_type = conf().get("channel_type", "unknown")
|
||||
channel_type = context.get("channel_type") or conf().get("channel_type", "unknown")
|
||||
if not tool.config:
|
||||
tool.config = {}
|
||||
tool.config["channel_type"] = channel_type
|
||||
|
||||
@@ -61,8 +61,7 @@ class SchedulerService:
|
||||
self._check_and_execute_tasks()
|
||||
except Exception as e:
|
||||
logger.error(f"[Scheduler] Error in scheduler loop: {e}")
|
||||
|
||||
# Sleep for 30 seconds between checks
|
||||
|
||||
time.sleep(30)
|
||||
|
||||
def _check_and_execute_tasks(self):
|
||||
@@ -85,12 +84,9 @@ class SchedulerService:
|
||||
"last_run_at": now.isoformat()
|
||||
})
|
||||
else:
|
||||
# One-time task, disable it
|
||||
self.task_store.update_task(task['id'], {
|
||||
"enabled": False,
|
||||
"last_run_at": now.isoformat()
|
||||
})
|
||||
logger.info(f"[Scheduler] One-time task completed and disabled: {task['id']}")
|
||||
# One-time task completed, remove it
|
||||
self.task_store.delete_task(task['id'])
|
||||
logger.info(f"[Scheduler] One-time task completed and removed: {task['id']}")
|
||||
except Exception as e:
|
||||
logger.error(f"[Scheduler] Error processing task {task.get('id')}: {e}")
|
||||
|
||||
@@ -127,14 +123,11 @@ class SchedulerService:
|
||||
if time_diff > 300: # 5 minutes
|
||||
logger.warning(f"[Scheduler] Task {task['id']} is overdue by {int(time_diff)}s, skipping and scheduling next run")
|
||||
|
||||
# For one-time tasks, disable them
|
||||
# For one-time tasks, remove them directly
|
||||
schedule = task.get("schedule", {})
|
||||
if schedule.get("type") == "once":
|
||||
self.task_store.update_task(task['id'], {
|
||||
"enabled": False,
|
||||
"last_run_at": now.isoformat()
|
||||
})
|
||||
logger.info(f"[Scheduler] One-time task {task['id']} expired, disabled")
|
||||
self.task_store.delete_task(task['id'])
|
||||
logger.info(f"[Scheduler] One-time task {task['id']} expired, removed")
|
||||
return False
|
||||
|
||||
# For recurring tasks, calculate next run from now
|
||||
@@ -147,7 +140,7 @@ class SchedulerService:
|
||||
return False
|
||||
|
||||
return now >= next_run
|
||||
except:
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def _calculate_next_run(self, task: dict, from_time: datetime) -> Optional[datetime]:
|
||||
@@ -195,7 +188,7 @@ class SchedulerService:
|
||||
# Only return if in the future
|
||||
if run_at > from_time:
|
||||
return run_at
|
||||
except:
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
@@ -424,7 +424,7 @@ class SchedulerTool(BaseTool):
|
||||
try:
|
||||
dt = datetime.fromisoformat(run_at)
|
||||
return f"一次性 ({dt.strftime('%Y-%m-%d %H:%M')})"
|
||||
except:
|
||||
except Exception:
|
||||
return "一次性"
|
||||
|
||||
return "未知"
|
||||
@@ -438,6 +438,6 @@ class SchedulerTool(BaseTool):
|
||||
return msg.other_user_nickname or "群聊"
|
||||
else:
|
||||
return msg.from_user_nickname or "用户"
|
||||
except:
|
||||
except Exception:
|
||||
pass
|
||||
return "未知"
|
||||
|
||||
@@ -72,7 +72,7 @@ class TaskStore:
|
||||
with open(self.store_path, 'r') as src:
|
||||
with open(backup_path, 'w') as dst:
|
||||
dst.write(src.read())
|
||||
except:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Save tasks
|
||||
|
||||
@@ -14,14 +14,14 @@ class Send(BaseTool):
|
||||
"""Tool for sending files to the user"""
|
||||
|
||||
name: str = "send"
|
||||
description: str = "Send a file (image, video, audio, document) to the user. Use this when the user explicitly asks to send/share a file."
|
||||
description: str = "Send a LOCAL file (image, video, audio, document) to the user. Only for local file paths. Do NOT use this for URLs — URLs should be included directly in your text reply, the system will handle them automatically."
|
||||
|
||||
params: dict = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "Path to the file to send. Can be absolute path or relative to workspace."
|
||||
"description": "Local file path to send. Must be an absolute path or relative to workspace. Do NOT pass URLs here."
|
||||
},
|
||||
"message": {
|
||||
"type": "string",
|
||||
@@ -98,7 +98,18 @@ class Send(BaseTool):
|
||||
"size_formatted": self._format_size(file_size),
|
||||
"message": message or f"正在发送 {file_name}"
|
||||
}
|
||||
|
||||
|
||||
try:
|
||||
from common.cloud_client import get_website_base_url, copy_send_file
|
||||
|
||||
# Do nothing when in local env
|
||||
if get_website_base_url():
|
||||
url = copy_send_file(absolute_path, self.cwd)
|
||||
if url:
|
||||
result["url"] = url
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return ToolResult.success(result)
|
||||
|
||||
def _resolve_path(self, path: str) -> str:
|
||||
|
||||
@@ -84,11 +84,11 @@ class ToolManager:
|
||||
except ImportError as e:
|
||||
# Handle missing dependencies with helpful messages
|
||||
error_msg = str(e)
|
||||
if "browser-use" in error_msg or "browser_use" in error_msg:
|
||||
if "playwright" in error_msg:
|
||||
logger.warning(
|
||||
f"[ToolManager] Browser tool not loaded - missing dependencies.\n"
|
||||
f" To enable browser tool, run:\n"
|
||||
f" pip install browser-use markdownify playwright\n"
|
||||
f" pip install playwright\n"
|
||||
f" playwright install chromium"
|
||||
)
|
||||
elif "markdownify" in error_msg:
|
||||
@@ -154,11 +154,11 @@ class ToolManager:
|
||||
except ImportError as e:
|
||||
# Handle missing dependencies with helpful messages
|
||||
error_msg = str(e)
|
||||
if "browser-use" in error_msg or "browser_use" in error_msg:
|
||||
if "playwright" in error_msg:
|
||||
logger.warning(
|
||||
f"[ToolManager] Browser tool not loaded - missing dependencies.\n"
|
||||
f" To enable browser tool, run:\n"
|
||||
f" pip install browser-use markdownify playwright\n"
|
||||
f" pip install playwright\n"
|
||||
f" playwright install chromium"
|
||||
)
|
||||
elif "markdownify" in error_msg:
|
||||
@@ -197,7 +197,7 @@ class ToolManager:
|
||||
logger.warning(
|
||||
f"[ToolManager] Browser tool is configured but not loaded.\n"
|
||||
f" To enable browser tool, run:\n"
|
||||
f" pip install browser-use markdownify playwright\n"
|
||||
f" pip install playwright\n"
|
||||
f" playwright install chromium"
|
||||
)
|
||||
elif tool_name == "google_search":
|
||||
|
||||
@@ -8,7 +8,10 @@ Truncation is based on two independent limits - whichever is hit first wins:
|
||||
Never returns partial lines (except bash tail truncation edge case).
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, Optional, Literal, Tuple
|
||||
from __future__ import annotations
|
||||
from typing import Dict, Any, Optional, Tuple, TYPE_CHECKING
|
||||
if TYPE_CHECKING:
|
||||
from typing import Literal
|
||||
|
||||
|
||||
DEFAULT_MAX_LINES = 2000
|
||||
|
||||
1
agent/tools/vision/__init__.py
Normal file
1
agent/tools/vision/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from agent.tools.vision.vision import Vision
|
||||
512
agent/tools/vision/vision.py
Normal file
512
agent/tools/vision/vision.py
Normal file
@@ -0,0 +1,512 @@
|
||||
"""
|
||||
Vision tool - Analyze images using Vision API.
|
||||
Supports local files (auto base64-encoded) and HTTP URLs.
|
||||
|
||||
Provider priority (default):
|
||||
1. Main model via bot.call_vision — zero extra cost
|
||||
2. Other models whose API key is configured — auto-discovered
|
||||
3. OpenAI / LinkAI raw HTTP — reliable fallback
|
||||
When use_linkai=true, LinkAI is promoted to #1.
|
||||
When tool.vision.model is set, that model is used exclusively first.
|
||||
"""
|
||||
|
||||
import base64
|
||||
import os
|
||||
import subprocess
|
||||
import tempfile
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import requests
|
||||
|
||||
from agent.tools.base_tool import BaseTool, ToolResult
|
||||
from common import const
|
||||
from common.log import logger
|
||||
from config import conf
|
||||
|
||||
DEFAULT_MODEL = const.GPT_41_MINI
|
||||
DEFAULT_TIMEOUT = 60
|
||||
MAX_TOKENS = 1000
|
||||
COMPRESS_THRESHOLD = 1_048_576 # 1 MB
|
||||
|
||||
SUPPORTED_EXTENSIONS = {
|
||||
"jpg": "image/jpeg",
|
||||
"jpeg": "image/jpeg",
|
||||
"png": "image/png",
|
||||
"gif": "image/gif",
|
||||
"webp": "image/webp",
|
||||
}
|
||||
|
||||
_MAIN_MODEL_PROVIDER_NAME = "MainModel"
|
||||
|
||||
# (config_key_for_api_key, bot_type, default_vision_model, provider_display_name)
|
||||
# Auto-discovered as fallback vision providers when their API key is configured.
|
||||
# OpenAI and LinkAI are handled separately (raw HTTP providers), so not listed here.
|
||||
_DISCOVERABLE_MODELS = [
|
||||
("moonshot_api_key", const.MOONSHOT, const.KIMI_K2_6, "Moonshot"),
|
||||
("ark_api_key", const.DOUBAO, const.DOUBAO_SEED_2_PRO, "Doubao"),
|
||||
("dashscope_api_key", const.QWEN_DASHSCOPE, const.QWEN36_PLUS, "DashScope"),
|
||||
("claude_api_key", const.CLAUDEAPI, const.CLAUDE_4_6_SONNET, "Claude"),
|
||||
("gemini_api_key", const.GEMINI, const.GEMINI_31_FLASH_LITE_PRE, "Gemini"),
|
||||
("zhipu_ai_api_key", const.ZHIPU_AI, const.GLM_4_7, "ZhipuAI"),
|
||||
("minimax_api_key", const.MiniMax, const.MINIMAX_M2_7, "MiniMax"),
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
class VisionProvider:
|
||||
"""A single Vision API provider configuration."""
|
||||
name: str
|
||||
api_key: str
|
||||
api_base: str
|
||||
extra_headers: dict = field(default_factory=dict)
|
||||
model_override: Optional[str] = None
|
||||
use_bot: bool = False # When True, call via bot.call_vision instead of raw HTTP
|
||||
fallback_bot: Any = None # Bot instance for non-main-model providers
|
||||
|
||||
|
||||
class VisionAPIError(Exception):
|
||||
"""Raised when a Vision API call fails and should trigger fallback."""
|
||||
pass
|
||||
|
||||
|
||||
class Vision(BaseTool):
|
||||
"""Analyze images using Vision API"""
|
||||
|
||||
name: str = "vision"
|
||||
description: str = (
|
||||
"Analyze a local image or image URL (jpg/jpeg/png) using Vision API. "
|
||||
"Can describe content, extract text, identify objects, colors, etc. "
|
||||
)
|
||||
|
||||
params: dict = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"image": {
|
||||
"type": "string",
|
||||
"description": "Local file path or HTTP(S) URL of the image to analyze",
|
||||
},
|
||||
"question": {
|
||||
"type": "string",
|
||||
"description": "Question to ask about the image",
|
||||
},
|
||||
},
|
||||
"required": ["image", "question"],
|
||||
}
|
||||
|
||||
def __init__(self, config: dict = None):
|
||||
self.config = config or {}
|
||||
|
||||
@staticmethod
|
||||
def is_available() -> bool:
|
||||
return True
|
||||
|
||||
def execute(self, args: Dict[str, Any]) -> ToolResult:
|
||||
image = args.get("image", "").strip()
|
||||
question = args.get("question", "").strip()
|
||||
|
||||
if not image:
|
||||
return ToolResult.fail("Error: 'image' parameter is required")
|
||||
if not question:
|
||||
return ToolResult.fail("Error: 'question' parameter is required")
|
||||
|
||||
providers = self._resolve_providers()
|
||||
if not providers:
|
||||
return ToolResult.fail(
|
||||
"Error: No model available for Vision.\n"
|
||||
"The main model does not support vision and no other API keys are configured.\n"
|
||||
"Options:\n"
|
||||
" 1. Switch to a multimodal model (e.g. qwen3.6-plus, claude-sonnet-4-6, gemini-2.0-flash)\n"
|
||||
" 2. Configure OPENAI_API_KEY: env_config(action=\"set\", key=\"OPENAI_API_KEY\", value=\"your-key\")\n"
|
||||
" 3. Configure LINKAI_API_KEY: env_config(action=\"set\", key=\"LINKAI_API_KEY\", value=\"your-key\")"
|
||||
)
|
||||
|
||||
try:
|
||||
image_content = self._build_image_content(image)
|
||||
except Exception as e:
|
||||
return ToolResult.fail(f"Error: {e}")
|
||||
|
||||
return self._call_with_fallback(providers, DEFAULT_MODEL, question, image_content)
|
||||
|
||||
def _call_with_fallback(self, providers: List[VisionProvider], model: str,
|
||||
question: str, image_content: dict) -> ToolResult:
|
||||
"""Try each provider in order; fall back to the next one on failure."""
|
||||
errors: List[str] = []
|
||||
for i, provider in enumerate(providers):
|
||||
use_model = provider.model_override or model
|
||||
try:
|
||||
logger.info(f"[Vision] Trying provider '{provider.name}' "
|
||||
f"with model '{use_model}' ({i + 1}/{len(providers)})")
|
||||
if provider.use_bot:
|
||||
result = self._call_via_bot(use_model, question, image_content, provider)
|
||||
else:
|
||||
result = self._call_api(provider, use_model, question, image_content)
|
||||
logger.info(f"[Vision] ✅ Success via {provider.name} (model={use_model})")
|
||||
return result
|
||||
except VisionAPIError as e:
|
||||
errors.append(f"[{provider.name}/{use_model}] {e}")
|
||||
logger.warning(f"[Vision] Provider '{provider.name}' failed: {e}")
|
||||
except requests.Timeout:
|
||||
errors.append(f"[{provider.name}/{use_model}] Request timed out after {DEFAULT_TIMEOUT}s")
|
||||
logger.warning(f"[Vision] Provider '{provider.name}' timed out")
|
||||
except requests.ConnectionError:
|
||||
errors.append(f"[{provider.name}/{use_model}] Connection failed")
|
||||
logger.warning(f"[Vision] Provider '{provider.name}' connection failed")
|
||||
except Exception as e:
|
||||
errors.append(f"[{provider.name}/{use_model}] {e}")
|
||||
logger.error(f"[Vision] Provider '{provider.name}' unexpected error: {e}", exc_info=True)
|
||||
|
||||
return ToolResult.fail(
|
||||
"Error: All Vision API providers failed.\n" + "\n".join(f" - {err}" for err in errors)
|
||||
)
|
||||
|
||||
def _resolve_providers(self) -> List[VisionProvider]:
|
||||
"""
|
||||
Build an ordered list of available providers.
|
||||
|
||||
Priority:
|
||||
- use_linkai=true → [LinkAI, MainModel, OtherModels…, OpenAI]
|
||||
- default → [MainModel, OtherModels…, OpenAI, LinkAI]
|
||||
|
||||
"OtherModels" are auto-discovered from configured API keys.
|
||||
The main model's bot_type is excluded from OtherModels to avoid
|
||||
duplicating the MainModel provider.
|
||||
"""
|
||||
use_linkai = conf().get("use_linkai", False) and conf().get("linkai_api_key")
|
||||
providers: List[VisionProvider] = []
|
||||
|
||||
if use_linkai:
|
||||
self._append_provider(providers, self._build_linkai_provider)
|
||||
self._append_provider(providers, self._build_main_model_provider)
|
||||
self._append_other_model_providers(providers)
|
||||
self._append_provider(providers, self._build_openai_provider)
|
||||
else:
|
||||
self._append_provider(providers, self._build_main_model_provider)
|
||||
self._append_other_model_providers(providers)
|
||||
self._append_provider(providers, self._build_openai_provider)
|
||||
self._append_provider(providers, self._build_linkai_provider)
|
||||
|
||||
return providers
|
||||
|
||||
@staticmethod
|
||||
def _append_provider(providers: List[VisionProvider], builder) -> None:
|
||||
p = builder()
|
||||
if p:
|
||||
providers.append(p)
|
||||
|
||||
def _append_other_model_providers(self, providers: List[VisionProvider]) -> None:
|
||||
"""
|
||||
Auto-discover other models whose API key is configured.
|
||||
Skip the main model's own bot_type (already covered by MainModel provider).
|
||||
Skip bot_types that already have a provider in the list (e.g. OpenAI).
|
||||
"""
|
||||
# Determine main model's bot_type so we can skip it
|
||||
main_bot_type = None
|
||||
if self.model and hasattr(self.model, '_resolve_bot_type'):
|
||||
main_bot_type = self.model._resolve_bot_type(conf().get("model", ""))
|
||||
|
||||
existing_names = {p.name for p in providers}
|
||||
|
||||
for config_key, bot_type, default_model, display_name in _DISCOVERABLE_MODELS:
|
||||
if display_name in existing_names:
|
||||
continue
|
||||
if bot_type == main_bot_type:
|
||||
continue
|
||||
api_key = conf().get(config_key, "")
|
||||
if not api_key or not api_key.strip():
|
||||
continue
|
||||
|
||||
# Create a bot instance and check if it supports call_vision
|
||||
try:
|
||||
from models.bot_factory import create_bot
|
||||
bot = create_bot(bot_type)
|
||||
if not hasattr(bot, 'call_vision'):
|
||||
continue
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
providers.append(VisionProvider(
|
||||
name=display_name,
|
||||
api_key="",
|
||||
api_base="",
|
||||
model_override=default_model,
|
||||
use_bot=True,
|
||||
fallback_bot=bot,
|
||||
))
|
||||
|
||||
def _resolve_vision_model(self) -> Optional[str]:
|
||||
"""
|
||||
Determine which model to use for vision.
|
||||
|
||||
1. User explicit config: tool.vision.model in config.json
|
||||
2. Fallback to the main configured model name
|
||||
"""
|
||||
tool_conf = conf().get("tool", {})
|
||||
user_vision_model = tool_conf.get("vision", {}).get("model") if isinstance(tool_conf, dict) else None
|
||||
if user_vision_model:
|
||||
return user_vision_model
|
||||
model_name = conf().get("model", "")
|
||||
return model_name or None
|
||||
|
||||
def _build_main_model_provider(self) -> Optional[VisionProvider]:
|
||||
"""
|
||||
Use the vendor's own model for vision via bot.call_vision.
|
||||
Only available when the bot class has call_vision.
|
||||
"""
|
||||
if not (self.model and hasattr(self.model, 'bot')):
|
||||
return None
|
||||
try:
|
||||
bot = self.model.bot
|
||||
if not hasattr(bot, 'call_vision'):
|
||||
return None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
vision_model = self._resolve_vision_model()
|
||||
|
||||
return VisionProvider(
|
||||
name=_MAIN_MODEL_PROVIDER_NAME,
|
||||
api_key="",
|
||||
api_base="",
|
||||
model_override=vision_model,
|
||||
use_bot=True,
|
||||
)
|
||||
|
||||
def _build_openai_provider(self) -> Optional[VisionProvider]:
|
||||
api_key = conf().get("open_ai_api_key") or os.environ.get("OPENAI_API_KEY")
|
||||
if not api_key:
|
||||
return None
|
||||
api_base = (conf().get("open_ai_api_base") or os.environ.get("OPENAI_API_BASE", "")).rstrip("/") \
|
||||
or "https://api.openai.com/v1"
|
||||
return VisionProvider(name="OpenAI", api_key=api_key, api_base=self._ensure_v1(api_base))
|
||||
|
||||
def _build_linkai_provider(self) -> Optional[VisionProvider]:
|
||||
api_key = conf().get("linkai_api_key") or os.environ.get("LINKAI_API_KEY")
|
||||
if not api_key:
|
||||
return None
|
||||
api_base = (conf().get("linkai_api_base") or os.environ.get("LINKAI_API_BASE", "")).rstrip("/") \
|
||||
or "https://api.link-ai.tech"
|
||||
from common.utils import get_cloud_headers
|
||||
extra = get_cloud_headers(api_key)
|
||||
extra.pop("Authorization", None)
|
||||
extra.pop("Content-Type", None)
|
||||
return VisionProvider(name="LinkAI", api_key=api_key, api_base=self._ensure_v1(api_base),
|
||||
extra_headers=extra)
|
||||
|
||||
def _call_via_bot(self, model: str, question: str, image_content: dict,
|
||||
provider: Optional[VisionProvider] = None) -> ToolResult:
|
||||
"""
|
||||
Call a model's call_vision with vendor-native API format.
|
||||
Uses the provider's _fallback_bot if set, otherwise the main model bot.
|
||||
Raises VisionAPIError on failure so fallback can proceed.
|
||||
"""
|
||||
try:
|
||||
bot = (provider and provider.fallback_bot) or self.model.bot
|
||||
except Exception as e:
|
||||
raise VisionAPIError(f"Cannot access bot: {e}")
|
||||
|
||||
# Extract the raw image URL from the OpenAI-format image_content block
|
||||
image_url = image_content.get("image_url", {}).get("url", "")
|
||||
if not image_url:
|
||||
raise VisionAPIError("No image URL in content block")
|
||||
|
||||
try:
|
||||
response = bot.call_vision(
|
||||
image_url=image_url,
|
||||
question=question,
|
||||
model=model,
|
||||
max_tokens=MAX_TOKENS,
|
||||
)
|
||||
except Exception as e:
|
||||
raise VisionAPIError(f"call_vision failed: {e}")
|
||||
|
||||
if response is NotImplemented:
|
||||
raise VisionAPIError("Bot does not support vision")
|
||||
|
||||
if isinstance(response, dict) and response.get("error"):
|
||||
raise VisionAPIError(f"API error - {response.get('message', 'Unknown')}")
|
||||
|
||||
content = response.get("content", "") if isinstance(response, dict) else ""
|
||||
if not content:
|
||||
raise VisionAPIError("Empty response from main model")
|
||||
|
||||
usage_info = response.get("usage", {}) if isinstance(response, dict) else {}
|
||||
|
||||
# Use the actual model name from the bot response if available
|
||||
actual_model = response.get("model", model) if isinstance(response, dict) else model
|
||||
provider_name = provider.name if provider else _MAIN_MODEL_PROVIDER_NAME
|
||||
return ToolResult.success({
|
||||
"model": actual_model,
|
||||
"provider": provider_name,
|
||||
"content": content,
|
||||
"usage": usage_info,
|
||||
})
|
||||
|
||||
@staticmethod
|
||||
def _ensure_v1(api_base: str) -> str:
|
||||
"""Append /v1 if the base URL doesn't already end with a versioned path."""
|
||||
if not api_base:
|
||||
return api_base
|
||||
# Already has /v1 or similar version suffix
|
||||
if api_base.rstrip("/").split("/")[-1].startswith("v"):
|
||||
return api_base
|
||||
return api_base.rstrip("/") + "/v1"
|
||||
|
||||
def _build_image_content(self, image: str) -> dict:
|
||||
"""
|
||||
Build the image_url content block.
|
||||
Both remote URLs and local files are converted to base64 data URLs
|
||||
so every bot backend can consume them without extra downloads.
|
||||
"""
|
||||
if image.startswith(("http://", "https://")):
|
||||
return self._download_to_data_url(image)
|
||||
|
||||
if not os.path.isfile(image):
|
||||
raise FileNotFoundError(f"Image file not found: {image}")
|
||||
|
||||
ext = image.rsplit(".", 1)[-1].lower() if "." in image else ""
|
||||
mime_type = SUPPORTED_EXTENSIONS.get(ext)
|
||||
if not mime_type:
|
||||
raise ValueError(
|
||||
f"Unsupported image format '.{ext}'. "
|
||||
f"Supported: {', '.join(SUPPORTED_EXTENSIONS.keys())}"
|
||||
)
|
||||
|
||||
file_path = self._maybe_compress(image)
|
||||
try:
|
||||
with open(file_path, "rb") as f:
|
||||
b64 = base64.b64encode(f.read()).decode("ascii")
|
||||
finally:
|
||||
if file_path != image and os.path.exists(file_path):
|
||||
os.remove(file_path)
|
||||
|
||||
data_url = f"data:{mime_type};base64,{b64}"
|
||||
return {"type": "image_url", "image_url": {"url": data_url}}
|
||||
|
||||
@staticmethod
|
||||
def _download_to_data_url(url: str) -> dict:
|
||||
"""Download a remote image and return it as a base64 data URL."""
|
||||
resp = requests.get(url, timeout=30)
|
||||
if resp.status_code != 200:
|
||||
raise VisionAPIError(f"Failed to download image: HTTP {resp.status_code}")
|
||||
content_type = resp.headers.get("Content-Type", "image/jpeg").split(";")[0].strip()
|
||||
if not content_type.startswith("image/"):
|
||||
content_type = "image/jpeg"
|
||||
b64 = base64.b64encode(resp.content).decode("ascii")
|
||||
data_url = f"data:{content_type};base64,{b64}"
|
||||
return {"type": "image_url", "image_url": {"url": data_url}}
|
||||
|
||||
@staticmethod
|
||||
def _maybe_compress(path: str) -> str:
|
||||
"""Compress image to under COMPRESS_THRESHOLD with max long-edge 1536px."""
|
||||
file_size = os.path.getsize(path)
|
||||
if file_size <= COMPRESS_THRESHOLD:
|
||||
return path
|
||||
|
||||
tmp = tempfile.NamedTemporaryFile(suffix=".jpg", delete=False)
|
||||
tmp.close()
|
||||
|
||||
def _try_sips(max_dim: str, quality: str) -> bool:
|
||||
try:
|
||||
subprocess.run(
|
||||
["sips", "-Z", max_dim, "-s", "formatOptions", quality,
|
||||
path, "--out", tmp.name],
|
||||
capture_output=True, check=True,
|
||||
)
|
||||
return True
|
||||
except (FileNotFoundError, subprocess.CalledProcessError):
|
||||
return False
|
||||
|
||||
def _try_convert(max_dim: str, quality: str) -> bool:
|
||||
try:
|
||||
subprocess.run(
|
||||
["convert", path, "-resize", f"{max_dim}x{max_dim}>",
|
||||
"-quality", quality, tmp.name],
|
||||
capture_output=True, check=True,
|
||||
)
|
||||
return True
|
||||
except (FileNotFoundError, subprocess.CalledProcessError):
|
||||
return False
|
||||
|
||||
attempts = [
|
||||
("1536", "85"),
|
||||
("1536", "70"),
|
||||
("1536", "50"),
|
||||
]
|
||||
|
||||
for max_dim, quality in attempts:
|
||||
ok = _try_sips(max_dim, quality) or _try_convert(max_dim, quality)
|
||||
if not ok:
|
||||
continue
|
||||
new_size = os.path.getsize(tmp.name)
|
||||
logger.debug(f"[Vision] Compressed image "
|
||||
f"({file_size // 1024}KB -> {new_size // 1024}KB, "
|
||||
f"max_dim={max_dim}, q={quality})")
|
||||
if new_size <= COMPRESS_THRESHOLD:
|
||||
return tmp.name
|
||||
|
||||
if os.path.exists(tmp.name) and os.path.getsize(tmp.name) > 0:
|
||||
return tmp.name
|
||||
|
||||
os.remove(tmp.name)
|
||||
return path
|
||||
|
||||
def _call_api(self, provider: VisionProvider, model: str,
|
||||
question: str, image_content: dict) -> ToolResult:
|
||||
"""
|
||||
Call a single provider's Vision API.
|
||||
Raises VisionAPIError on recoverable failures so the caller can try
|
||||
the next provider.
|
||||
"""
|
||||
payload = {
|
||||
"model": model,
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": question},
|
||||
image_content,
|
||||
],
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Bearer {provider.api_key}",
|
||||
"Content-Type": "application/json",
|
||||
**provider.extra_headers,
|
||||
}
|
||||
|
||||
resp = requests.post(
|
||||
f"{provider.api_base}/chat/completions",
|
||||
headers=headers,
|
||||
json=payload,
|
||||
timeout=DEFAULT_TIMEOUT,
|
||||
)
|
||||
|
||||
if resp.status_code != 200:
|
||||
raise VisionAPIError(f"HTTP {resp.status_code}: {resp.text[:200]}")
|
||||
|
||||
data = resp.json()
|
||||
|
||||
if "error" in data:
|
||||
msg = data["error"].get("message", "Unknown API error")
|
||||
raise VisionAPIError(f"API error - {msg}")
|
||||
|
||||
content = ""
|
||||
choices = data.get("choices", [])
|
||||
if choices:
|
||||
content = choices[0].get("message", {}).get("content", "")
|
||||
|
||||
usage = data.get("usage", {})
|
||||
result = {
|
||||
"model": model,
|
||||
"provider": provider.name,
|
||||
"content": content,
|
||||
"usage": {
|
||||
"prompt_tokens": usage.get("prompt_tokens", 0),
|
||||
"completion_tokens": usage.get("completion_tokens", 0),
|
||||
"total_tokens": usage.get("total_tokens", 0),
|
||||
},
|
||||
}
|
||||
return ToolResult.success(result)
|
||||
0
agent/tools/web_fetch/__init__.py
Normal file
0
agent/tools/web_fetch/__init__.py
Normal file
444
agent/tools/web_fetch/web_fetch.py
Normal file
444
agent/tools/web_fetch/web_fetch.py
Normal file
@@ -0,0 +1,444 @@
|
||||
"""
|
||||
Web Fetch tool - Fetch and extract readable content from web pages and remote files.
|
||||
|
||||
Supports:
|
||||
- HTML web pages: extracts readable text content
|
||||
- Document files (PDF, Word, TXT, Markdown, etc.): downloads to workspace/tmp and parses content
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
import uuid
|
||||
from typing import Dict, Any, Optional, Set
|
||||
from urllib.parse import urlparse, unquote
|
||||
|
||||
import requests
|
||||
|
||||
from agent.tools.base_tool import BaseTool, ToolResult
|
||||
from agent.tools.utils.truncate import truncate_head, format_size
|
||||
from common.log import logger
|
||||
|
||||
|
||||
DEFAULT_TIMEOUT = 30
|
||||
MAX_FILE_SIZE = 50 * 1024 * 1024 # 50MB
|
||||
|
||||
DEFAULT_HEADERS = {
|
||||
"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36",
|
||||
"Accept": "*/*",
|
||||
}
|
||||
|
||||
# Supported document file extensions
|
||||
PDF_SUFFIXES: Set[str] = {".pdf"}
|
||||
WORD_SUFFIXES: Set[str] = {".docx"}
|
||||
TEXT_SUFFIXES: Set[str] = {".txt", ".md", ".markdown", ".rst", ".csv", ".tsv", ".log"}
|
||||
SPREADSHEET_SUFFIXES: Set[str] = {".xls", ".xlsx"}
|
||||
PPT_SUFFIXES: Set[str] = {".ppt", ".pptx"}
|
||||
|
||||
ALL_DOC_SUFFIXES = PDF_SUFFIXES | WORD_SUFFIXES | TEXT_SUFFIXES | SPREADSHEET_SUFFIXES | PPT_SUFFIXES
|
||||
|
||||
_CHARSET_RE = re.compile(r'charset\s*=\s*["\']?\s*([\w\-]+)', re.IGNORECASE)
|
||||
_META_CHARSET_RE = re.compile(rb'<meta[^>]+charset\s*=\s*["\']?\s*([\w\-]+)', re.IGNORECASE)
|
||||
_META_HTTP_EQUIV_RE = re.compile(
|
||||
rb'<meta[^>]+http-equiv\s*=\s*["\']?Content-Type["\']?[^>]+content\s*=\s*["\'][^"\']*charset=([\w\-]+)',
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
|
||||
def _extract_charset_from_content_type(content_type: str) -> Optional[str]:
|
||||
"""Extract charset from Content-Type header value."""
|
||||
m = _CHARSET_RE.search(content_type)
|
||||
return m.group(1) if m else None
|
||||
|
||||
|
||||
def _extract_charset_from_html_meta(raw_bytes: bytes) -> Optional[str]:
|
||||
"""Extract charset from HTML <meta> tags in the first few KB of raw bytes."""
|
||||
m = _META_CHARSET_RE.search(raw_bytes)
|
||||
if m:
|
||||
return m.group(1).decode("ascii", errors="ignore")
|
||||
m = _META_HTTP_EQUIV_RE.search(raw_bytes)
|
||||
if m:
|
||||
return m.group(1).decode("ascii", errors="ignore")
|
||||
return None
|
||||
|
||||
|
||||
def _get_url_suffix(url: str) -> str:
|
||||
"""Extract file extension from URL path, ignoring query params."""
|
||||
path = urlparse(url).path
|
||||
return os.path.splitext(path)[-1].lower()
|
||||
|
||||
|
||||
def _is_document_url(url: str) -> bool:
|
||||
"""Check if URL points to a downloadable document file."""
|
||||
suffix = _get_url_suffix(url)
|
||||
return suffix in ALL_DOC_SUFFIXES
|
||||
|
||||
|
||||
class WebFetch(BaseTool):
|
||||
"""Tool for fetching web pages and remote document files"""
|
||||
|
||||
name: str = "web_fetch"
|
||||
description: str = (
|
||||
"Fetch content from a http/https URL. For web pages, extracts readable text. "
|
||||
"For document files (PDF, Word, TXT, Markdown, Excel, PPT), downloads and parses the file content. "
|
||||
"Supported file types: .pdf, .docx, .txt, .md, .csv, .xls, .xlsx, .ppt, .pptx"
|
||||
)
|
||||
|
||||
params: dict = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"url": {
|
||||
"type": "string",
|
||||
"description": "The HTTP/HTTPS URL to fetch (web page or document file link)"
|
||||
}
|
||||
},
|
||||
"required": ["url"]
|
||||
}
|
||||
|
||||
def __init__(self, config: dict = None):
|
||||
self.config = config or {}
|
||||
self.cwd = self.config.get("cwd", os.getcwd())
|
||||
|
||||
def execute(self, args: Dict[str, Any]) -> ToolResult:
|
||||
url = args.get("url", "").strip()
|
||||
if not url:
|
||||
return ToolResult.fail("Error: 'url' parameter is required")
|
||||
|
||||
parsed = urlparse(url)
|
||||
if parsed.scheme not in ("http", "https"):
|
||||
return ToolResult.fail("Error: Invalid URL (must start with http:// or https://)")
|
||||
|
||||
if _is_document_url(url):
|
||||
return self._fetch_document(url)
|
||||
|
||||
return self._fetch_webpage(url)
|
||||
|
||||
# ---- Web page fetching ----
|
||||
|
||||
def _fetch_webpage(self, url: str) -> ToolResult:
|
||||
"""Fetch and extract readable text from an HTML web page."""
|
||||
parsed = urlparse(url)
|
||||
try:
|
||||
response = requests.get(
|
||||
url,
|
||||
headers=DEFAULT_HEADERS,
|
||||
timeout=DEFAULT_TIMEOUT,
|
||||
allow_redirects=True,
|
||||
)
|
||||
response.raise_for_status()
|
||||
except requests.Timeout:
|
||||
return ToolResult.fail(f"Error: Request timed out after {DEFAULT_TIMEOUT}s")
|
||||
except requests.ConnectionError:
|
||||
return ToolResult.fail(f"Error: Failed to connect to {parsed.netloc}")
|
||||
except requests.HTTPError as e:
|
||||
return ToolResult.fail(f"Error: HTTP {e.response.status_code} for URL: {url}")
|
||||
except Exception as e:
|
||||
return ToolResult.fail(f"Error: Failed to fetch URL: {e}")
|
||||
|
||||
content_type = response.headers.get("Content-Type", "")
|
||||
if self._is_binary_content_type(content_type) and not _is_document_url(url):
|
||||
return self._handle_download_by_content_type(url, response, content_type)
|
||||
|
||||
response.encoding = self._detect_encoding(response)
|
||||
html = response.text
|
||||
title = self._extract_title(html)
|
||||
text = self._extract_text(html)
|
||||
|
||||
return ToolResult.success(f"Title: {title}\n\nContent:\n{text}")
|
||||
|
||||
# ---- Document fetching ----
|
||||
|
||||
def _fetch_document(self, url: str) -> ToolResult:
|
||||
"""Download a document file and extract its text content."""
|
||||
suffix = _get_url_suffix(url)
|
||||
parsed = urlparse(url)
|
||||
filename = self._extract_filename(url)
|
||||
tmp_dir = self._ensure_tmp_dir()
|
||||
|
||||
local_path = os.path.join(tmp_dir, filename)
|
||||
logger.info(f"[WebFetch] Downloading document: {url} -> {local_path}")
|
||||
|
||||
try:
|
||||
response = requests.get(
|
||||
url,
|
||||
headers=DEFAULT_HEADERS,
|
||||
timeout=DEFAULT_TIMEOUT,
|
||||
stream=True,
|
||||
allow_redirects=True,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
content_length = int(response.headers.get("Content-Length", 0))
|
||||
if content_length > MAX_FILE_SIZE:
|
||||
return ToolResult.fail(
|
||||
f"Error: File too large ({format_size(content_length)} > {format_size(MAX_FILE_SIZE)})"
|
||||
)
|
||||
|
||||
downloaded = 0
|
||||
with open(local_path, "wb") as f:
|
||||
for chunk in response.iter_content(chunk_size=8192):
|
||||
downloaded += len(chunk)
|
||||
if downloaded > MAX_FILE_SIZE:
|
||||
f.close()
|
||||
os.remove(local_path)
|
||||
return ToolResult.fail(
|
||||
f"Error: File too large (>{format_size(MAX_FILE_SIZE)}), download aborted"
|
||||
)
|
||||
f.write(chunk)
|
||||
|
||||
except requests.Timeout:
|
||||
return ToolResult.fail(f"Error: Download timed out after {DEFAULT_TIMEOUT}s")
|
||||
except requests.ConnectionError:
|
||||
return ToolResult.fail(f"Error: Failed to connect to {parsed.netloc}")
|
||||
except requests.HTTPError as e:
|
||||
return ToolResult.fail(f"Error: HTTP {e.response.status_code} for URL: {url}")
|
||||
except Exception as e:
|
||||
self._cleanup_file(local_path)
|
||||
return ToolResult.fail(f"Error: Failed to download file: {e}")
|
||||
|
||||
try:
|
||||
text = self._parse_document(local_path, suffix)
|
||||
except Exception as e:
|
||||
self._cleanup_file(local_path)
|
||||
return ToolResult.fail(f"Error: Failed to parse document: {e}")
|
||||
|
||||
if not text or not text.strip():
|
||||
file_size = os.path.getsize(local_path)
|
||||
return ToolResult.success(
|
||||
f"File downloaded to: {local_path} ({format_size(file_size)})\n"
|
||||
f"No text content could be extracted. The file may contain only images or be encrypted."
|
||||
)
|
||||
|
||||
truncation = truncate_head(text)
|
||||
result_text = truncation.content
|
||||
|
||||
file_size = os.path.getsize(local_path)
|
||||
header = f"[Document: {filename} | Size: {format_size(file_size)} | Saved to: {local_path}]\n\n"
|
||||
|
||||
if truncation.truncated:
|
||||
header += f"[Content truncated: showing {truncation.output_lines} of {truncation.total_lines} lines]\n\n"
|
||||
|
||||
return ToolResult.success(header + result_text)
|
||||
|
||||
def _parse_document(self, file_path: str, suffix: str) -> str:
|
||||
"""Parse document file and return extracted text."""
|
||||
if suffix in PDF_SUFFIXES:
|
||||
return self._parse_pdf(file_path)
|
||||
elif suffix in WORD_SUFFIXES:
|
||||
return self._parse_word(file_path)
|
||||
elif suffix in TEXT_SUFFIXES:
|
||||
return self._parse_text(file_path)
|
||||
elif suffix in SPREADSHEET_SUFFIXES:
|
||||
return self._parse_spreadsheet(file_path)
|
||||
elif suffix in PPT_SUFFIXES:
|
||||
return self._parse_ppt(file_path)
|
||||
else:
|
||||
return self._parse_text(file_path)
|
||||
|
||||
def _parse_pdf(self, file_path: str) -> str:
|
||||
"""Extract text from PDF using pypdf."""
|
||||
try:
|
||||
from pypdf import PdfReader
|
||||
except ImportError:
|
||||
raise ImportError("pypdf library is required for PDF parsing. Install with: pip install pypdf")
|
||||
|
||||
reader = PdfReader(file_path)
|
||||
text_parts = []
|
||||
for page_num, page in enumerate(reader.pages, 1):
|
||||
page_text = page.extract_text()
|
||||
if page_text and page_text.strip():
|
||||
text_parts.append(f"--- Page {page_num}/{len(reader.pages)} ---\n{page_text}")
|
||||
|
||||
return "\n\n".join(text_parts)
|
||||
|
||||
def _parse_word(self, file_path: str) -> str:
|
||||
"""Extract text from Word documents (.docx)."""
|
||||
try:
|
||||
from docx import Document
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"python-docx library is required for .docx parsing. Install with: pip install python-docx"
|
||||
)
|
||||
doc = Document(file_path)
|
||||
paragraphs = [p.text for p in doc.paragraphs if p.text.strip()]
|
||||
return "\n\n".join(paragraphs)
|
||||
|
||||
def _parse_text(self, file_path: str) -> str:
|
||||
"""Read plain text files (txt, md, csv, etc.)."""
|
||||
encodings = ["utf-8", "utf-8-sig", "gbk", "gb2312", "latin-1"]
|
||||
for enc in encodings:
|
||||
try:
|
||||
with open(file_path, "r", encoding=enc) as f:
|
||||
return f.read()
|
||||
except (UnicodeDecodeError, UnicodeError):
|
||||
continue
|
||||
raise ValueError(f"Unable to decode file with any supported encoding: {encodings}")
|
||||
|
||||
def _parse_spreadsheet(self, file_path: str) -> str:
|
||||
"""Extract text from Excel files (.xls/.xlsx)."""
|
||||
try:
|
||||
import openpyxl
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"openpyxl library is required for .xlsx parsing. Install with: pip install openpyxl"
|
||||
)
|
||||
|
||||
wb = openpyxl.load_workbook(file_path, read_only=True, data_only=True)
|
||||
result_parts = []
|
||||
|
||||
for sheet_name in wb.sheetnames:
|
||||
ws = wb[sheet_name]
|
||||
rows = []
|
||||
for row in ws.iter_rows(values_only=True):
|
||||
cells = [str(c) if c is not None else "" for c in row]
|
||||
if any(cells):
|
||||
rows.append(" | ".join(cells))
|
||||
if rows:
|
||||
result_parts.append(f"--- Sheet: {sheet_name} ---\n" + "\n".join(rows))
|
||||
|
||||
wb.close()
|
||||
return "\n\n".join(result_parts)
|
||||
|
||||
def _parse_ppt(self, file_path: str) -> str:
|
||||
"""Extract text from PowerPoint files (.ppt/.pptx)."""
|
||||
try:
|
||||
from pptx import Presentation
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"python-pptx library is required for .pptx parsing. Install with: pip install python-pptx"
|
||||
)
|
||||
|
||||
prs = Presentation(file_path)
|
||||
text_parts = []
|
||||
|
||||
for slide_num, slide in enumerate(prs.slides, 1):
|
||||
slide_texts = []
|
||||
for shape in slide.shapes:
|
||||
if shape.has_text_frame:
|
||||
for paragraph in shape.text_frame.paragraphs:
|
||||
text = paragraph.text.strip()
|
||||
if text:
|
||||
slide_texts.append(text)
|
||||
if slide_texts:
|
||||
text_parts.append(f"--- Slide {slide_num}/{len(prs.slides)} ---\n" + "\n".join(slide_texts))
|
||||
|
||||
return "\n\n".join(text_parts)
|
||||
|
||||
# ---- Encoding detection ----
|
||||
|
||||
@staticmethod
|
||||
def _detect_encoding(response: requests.Response) -> str:
|
||||
"""Detect response encoding with priority: Content-Type header > HTML meta > chardet > utf-8."""
|
||||
# 1. Check Content-Type header for explicit charset
|
||||
content_type = response.headers.get("Content-Type", "")
|
||||
charset = _extract_charset_from_content_type(content_type)
|
||||
if charset:
|
||||
return charset
|
||||
|
||||
# 2. Scan raw bytes for HTML meta charset declaration
|
||||
raw = response.content[:4096]
|
||||
charset = _extract_charset_from_html_meta(raw)
|
||||
if charset:
|
||||
return charset
|
||||
|
||||
# 3. Use apparent_encoding (chardet-based detection) if confident enough
|
||||
apparent = response.apparent_encoding
|
||||
if apparent:
|
||||
apparent_lower = apparent.lower()
|
||||
# Trust CJK / Windows encodings detected by chardet
|
||||
trusted_prefixes = ("utf", "gb", "big5", "euc", "shift_jis", "iso-2022", "windows", "ascii")
|
||||
if any(apparent_lower.startswith(p) for p in trusted_prefixes):
|
||||
return apparent
|
||||
|
||||
# 4. Fallback
|
||||
return "utf-8"
|
||||
|
||||
# ---- Helper methods ----
|
||||
|
||||
def _ensure_tmp_dir(self) -> str:
|
||||
"""Ensure workspace/tmp directory exists and return its path."""
|
||||
tmp_dir = os.path.join(self.cwd, "tmp")
|
||||
os.makedirs(tmp_dir, exist_ok=True)
|
||||
return tmp_dir
|
||||
|
||||
def _extract_filename(self, url: str) -> str:
|
||||
"""Extract a safe filename from URL, with a short UUID prefix to avoid collisions."""
|
||||
path = urlparse(url).path
|
||||
basename = os.path.basename(unquote(path))
|
||||
if not basename or basename == "/":
|
||||
basename = "downloaded_file"
|
||||
# Sanitize: keep only safe chars
|
||||
basename = re.sub(r'[^\w.\-]', '_', basename)
|
||||
short_id = uuid.uuid4().hex[:8]
|
||||
return f"{short_id}_{basename}"
|
||||
|
||||
@staticmethod
|
||||
def _cleanup_file(path: str):
|
||||
"""Remove a file if it exists, ignoring errors."""
|
||||
try:
|
||||
if os.path.exists(path):
|
||||
os.remove(path)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def _is_binary_content_type(content_type: str) -> bool:
|
||||
"""Check if Content-Type indicates a binary/document response."""
|
||||
binary_types = [
|
||||
"application/pdf",
|
||||
"application/vnd.openxmlformats",
|
||||
"application/vnd.ms-excel",
|
||||
"application/vnd.ms-powerpoint",
|
||||
"application/octet-stream",
|
||||
]
|
||||
ct_lower = content_type.lower()
|
||||
return any(bt in ct_lower for bt in binary_types)
|
||||
|
||||
def _handle_download_by_content_type(self, url: str, response: requests.Response, content_type: str) -> ToolResult:
|
||||
"""Handle a URL that returned binary content instead of HTML."""
|
||||
ct_lower = content_type.lower()
|
||||
suffix_map = {
|
||||
"application/pdf": ".pdf",
|
||||
"application/vnd.openxmlformats-officedocument.wordprocessingml": ".docx",
|
||||
"application/vnd.ms-excel": ".xls",
|
||||
"application/vnd.openxmlformats-officedocument.spreadsheetml": ".xlsx",
|
||||
"application/vnd.ms-powerpoint": ".ppt",
|
||||
"application/vnd.openxmlformats-officedocument.presentationml": ".pptx",
|
||||
}
|
||||
detected_suffix = None
|
||||
for ct_prefix, ext in suffix_map.items():
|
||||
if ct_prefix in ct_lower:
|
||||
detected_suffix = ext
|
||||
break
|
||||
|
||||
if detected_suffix and detected_suffix in ALL_DOC_SUFFIXES:
|
||||
# Re-fetch as document
|
||||
return self._fetch_document(url if _get_url_suffix(url) in ALL_DOC_SUFFIXES
|
||||
else self._rewrite_url_with_suffix(url, detected_suffix))
|
||||
return ToolResult.fail(f"Error: URL returned binary content ({content_type}), not a supported document type")
|
||||
|
||||
@staticmethod
|
||||
def _rewrite_url_with_suffix(url: str, suffix: str) -> str:
|
||||
"""Append a suffix to the URL path so _get_url_suffix works correctly."""
|
||||
parsed = urlparse(url)
|
||||
new_path = parsed.path.rstrip("/") + suffix
|
||||
return parsed._replace(path=new_path).geturl()
|
||||
|
||||
# ---- HTML extraction (unchanged) ----
|
||||
|
||||
@staticmethod
|
||||
def _extract_title(html: str) -> str:
|
||||
match = re.search(r"<title[^>]*>(.*?)</title>", html, re.IGNORECASE | re.DOTALL)
|
||||
return match.group(1).strip() if match else "Untitled"
|
||||
|
||||
@staticmethod
|
||||
def _extract_text(html: str) -> str:
|
||||
text = re.sub(r"<script[^>]*>.*?</script>", "", html, flags=re.IGNORECASE | re.DOTALL)
|
||||
text = re.sub(r"<style[^>]*>.*?</style>", "", text, flags=re.IGNORECASE | re.DOTALL)
|
||||
text = re.sub(r"<[^>]+>", "", text)
|
||||
text = text.replace("&", "&").replace("<", "<").replace(">", ">")
|
||||
text = text.replace(""", '"').replace("'", "'").replace(" ", " ")
|
||||
text = re.sub(r"[^\S\n]+", " ", text)
|
||||
text = re.sub(r"\n{3,}", "\n\n", text)
|
||||
lines = [line.strip() for line in text.splitlines()]
|
||||
text = "\n".join(lines)
|
||||
return text.strip()
|
||||
@@ -13,6 +13,7 @@ import requests
|
||||
|
||||
from agent.tools.base_tool import BaseTool, ToolResult
|
||||
from common.log import logger
|
||||
from config import conf
|
||||
|
||||
|
||||
# Default timeout for API requests (seconds)
|
||||
@@ -23,11 +24,7 @@ class WebSearch(BaseTool):
|
||||
"""Tool for searching the web using Bocha or LinkAI search API"""
|
||||
|
||||
name: str = "web_search"
|
||||
description: str = (
|
||||
"Search the web for current information, news, research topics, or any real-time data. "
|
||||
"Returns web page titles, URLs, snippets, and optional summaries. "
|
||||
"Use this when the user asks about recent events, needs fact-checking, or wants up-to-date information."
|
||||
)
|
||||
description: str = "Search the web for real-time information. Returns titles, URLs, and snippets."
|
||||
|
||||
params: dict = {
|
||||
"type": "object",
|
||||
@@ -225,12 +222,11 @@ class WebSearch(BaseTool):
|
||||
:return: Formatted search results
|
||||
"""
|
||||
api_key = os.environ.get("LINKAI_API_KEY", "")
|
||||
url = "https://api.link-ai.tech/v1/plugin/execute"
|
||||
api_base = conf().get("linkai_api_base", "https://api.link-ai.tech")
|
||||
url = f"{api_base.rstrip('/')}/v1/plugin/execute"
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {api_key}"
|
||||
}
|
||||
from common.utils import get_cloud_headers
|
||||
headers = get_cloud_headers(api_key)
|
||||
|
||||
payload = {
|
||||
"code": "web-search",
|
||||
|
||||
318
app.py
318
app.py
@@ -7,11 +7,260 @@ import time
|
||||
|
||||
from channel import channel_factory
|
||||
from common import const
|
||||
from config import load_config
|
||||
from common.log import logger
|
||||
from config import load_config, conf
|
||||
from plugins import *
|
||||
import threading
|
||||
|
||||
|
||||
_channel_mgr = None
|
||||
|
||||
|
||||
def get_channel_manager():
|
||||
return _channel_mgr
|
||||
|
||||
|
||||
def _parse_channel_type(raw) -> list:
|
||||
"""
|
||||
Parse channel_type config value into a list of channel names.
|
||||
Supports:
|
||||
- single string: "feishu"
|
||||
- comma-separated string: "feishu, dingtalk"
|
||||
- list: ["feishu", "dingtalk"]
|
||||
"""
|
||||
if isinstance(raw, list):
|
||||
return [ch.strip() for ch in raw if ch.strip()]
|
||||
if isinstance(raw, str):
|
||||
return [ch.strip() for ch in raw.split(",") if ch.strip()]
|
||||
return []
|
||||
|
||||
|
||||
class ChannelManager:
|
||||
"""
|
||||
Manage the lifecycle of multiple channels running concurrently.
|
||||
Each channel.startup() runs in its own daemon thread.
|
||||
The web channel is started as default console unless explicitly disabled.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._channels = {} # channel_name -> channel instance
|
||||
self._threads = {} # channel_name -> thread
|
||||
self._primary_channel = None
|
||||
self._lock = threading.Lock()
|
||||
self.cloud_mode = False # set to True when cloud client is active
|
||||
|
||||
@property
|
||||
def channel(self):
|
||||
"""Return the primary (first non-web) channel for backward compatibility."""
|
||||
return self._primary_channel
|
||||
|
||||
def get_channel(self, channel_name: str):
|
||||
return self._channels.get(channel_name)
|
||||
|
||||
def start(self, channel_names: list, first_start: bool = False):
|
||||
"""
|
||||
Create and start one or more channels in sub-threads.
|
||||
If first_start is True, plugins and linkai client will also be initialized.
|
||||
"""
|
||||
with self._lock:
|
||||
channels = []
|
||||
for name in channel_names:
|
||||
ch = channel_factory.create_channel(name)
|
||||
ch.cloud_mode = self.cloud_mode
|
||||
self._channels[name] = ch
|
||||
channels.append((name, ch))
|
||||
if self._primary_channel is None and name != "web":
|
||||
self._primary_channel = ch
|
||||
|
||||
if self._primary_channel is None and channels:
|
||||
self._primary_channel = channels[0][1]
|
||||
|
||||
if first_start:
|
||||
PluginManager().load_plugins()
|
||||
|
||||
# Cloud client is optional. It is only started when
|
||||
# use_linkai=True AND cloud_deployment_id is set.
|
||||
# By default neither is configured, so the app runs
|
||||
# entirely locally without any remote connection.
|
||||
if conf().get("use_linkai") and (
|
||||
os.environ.get("CLOUD_DEPLOYMENT_ID") or conf().get("cloud_deployment_id")
|
||||
):
|
||||
try:
|
||||
from common import cloud_client
|
||||
threading.Thread(
|
||||
target=cloud_client.start,
|
||||
args=(self._primary_channel, self),
|
||||
daemon=True,
|
||||
).start()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Start web console first so its logs print cleanly,
|
||||
# then start remaining channels after a brief pause.
|
||||
web_entry = None
|
||||
other_entries = []
|
||||
for entry in channels:
|
||||
if entry[0] == "web":
|
||||
web_entry = entry
|
||||
else:
|
||||
other_entries.append(entry)
|
||||
|
||||
ordered = ([web_entry] if web_entry else []) + other_entries
|
||||
for i, (name, ch) in enumerate(ordered):
|
||||
if i > 0 and name != "web":
|
||||
time.sleep(0.1)
|
||||
t = threading.Thread(target=self._run_channel, args=(name, ch), daemon=True)
|
||||
self._threads[name] = t
|
||||
t.start()
|
||||
logger.debug(f"[ChannelManager] Channel '{name}' started in sub-thread")
|
||||
|
||||
def _run_channel(self, name: str, channel):
|
||||
try:
|
||||
channel.startup()
|
||||
except Exception as e:
|
||||
logger.error(f"[ChannelManager] Channel '{name}' startup error: {e}")
|
||||
logger.exception(e)
|
||||
|
||||
def stop(self, channel_name: str = None):
|
||||
"""
|
||||
Stop channel(s). If channel_name is given, stop only that channel;
|
||||
otherwise stop all channels.
|
||||
"""
|
||||
# Pop under lock, then stop outside lock to avoid deadlock
|
||||
with self._lock:
|
||||
names = [channel_name] if channel_name else list(self._channels.keys())
|
||||
to_stop = []
|
||||
for name in names:
|
||||
ch = self._channels.pop(name, None)
|
||||
th = self._threads.pop(name, None)
|
||||
to_stop.append((name, ch, th))
|
||||
if channel_name and self._primary_channel is self._channels.get(channel_name):
|
||||
self._primary_channel = None
|
||||
|
||||
for name, ch, th in to_stop:
|
||||
if ch is None:
|
||||
logger.warning(f"[ChannelManager] Channel '{name}' not found in managed channels")
|
||||
if th and th.is_alive():
|
||||
self._interrupt_thread(th, name)
|
||||
continue
|
||||
logger.info(f"[ChannelManager] Stopping channel '{name}'...")
|
||||
graceful = False
|
||||
if hasattr(ch, 'stop'):
|
||||
try:
|
||||
ch.stop()
|
||||
graceful = True
|
||||
except Exception as e:
|
||||
logger.warning(f"[ChannelManager] Error during channel '{name}' stop: {e}")
|
||||
if th and th.is_alive():
|
||||
th.join(timeout=5)
|
||||
if th.is_alive():
|
||||
if graceful:
|
||||
logger.info(f"[ChannelManager] Channel '{name}' thread still alive after stop(), "
|
||||
"leaving daemon thread to finish on its own")
|
||||
else:
|
||||
logger.warning(f"[ChannelManager] Channel '{name}' thread did not exit in 5s, forcing interrupt")
|
||||
self._interrupt_thread(th, name)
|
||||
|
||||
@staticmethod
|
||||
def _interrupt_thread(th: threading.Thread, name: str):
|
||||
"""Raise SystemExit in target thread to break blocking loops like start_forever."""
|
||||
import ctypes
|
||||
try:
|
||||
tid = th.ident
|
||||
if tid is None:
|
||||
return
|
||||
res = ctypes.pythonapi.PyThreadState_SetAsyncExc(
|
||||
ctypes.c_ulong(tid), ctypes.py_object(SystemExit)
|
||||
)
|
||||
if res == 1:
|
||||
logger.info(f"[ChannelManager] Interrupted thread for channel '{name}'")
|
||||
elif res > 1:
|
||||
ctypes.pythonapi.PyThreadState_SetAsyncExc(ctypes.c_ulong(tid), None)
|
||||
logger.warning(f"[ChannelManager] Failed to interrupt thread for channel '{name}'")
|
||||
except Exception as e:
|
||||
logger.warning(f"[ChannelManager] Thread interrupt error for '{name}': {e}")
|
||||
|
||||
def restart(self, new_channel_name: str):
|
||||
"""
|
||||
Restart a single channel with a new channel type.
|
||||
Can be called from any thread (e.g. linkai config callback).
|
||||
"""
|
||||
logger.info(f"[ChannelManager] Restarting channel to '{new_channel_name}'...")
|
||||
self.stop(new_channel_name)
|
||||
_clear_singleton_cache(new_channel_name)
|
||||
time.sleep(1)
|
||||
self.start([new_channel_name], first_start=False)
|
||||
logger.info(f"[ChannelManager] Channel restarted to '{new_channel_name}' successfully")
|
||||
|
||||
def add_channel(self, channel_name: str):
|
||||
"""
|
||||
Dynamically add and start a new channel.
|
||||
If the channel is already running, restart it instead.
|
||||
"""
|
||||
with self._lock:
|
||||
if channel_name in self._channels:
|
||||
logger.info(f"[ChannelManager] Channel '{channel_name}' already exists, restarting")
|
||||
if self._channels.get(channel_name):
|
||||
self.restart(channel_name)
|
||||
return
|
||||
logger.info(f"[ChannelManager] Adding channel '{channel_name}'...")
|
||||
_clear_singleton_cache(channel_name)
|
||||
self.start([channel_name], first_start=False)
|
||||
logger.info(f"[ChannelManager] Channel '{channel_name}' added successfully")
|
||||
|
||||
def remove_channel(self, channel_name: str):
|
||||
"""
|
||||
Dynamically stop and remove a running channel.
|
||||
"""
|
||||
with self._lock:
|
||||
if channel_name not in self._channels:
|
||||
logger.warning(f"[ChannelManager] Channel '{channel_name}' not found, nothing to remove")
|
||||
return
|
||||
logger.info(f"[ChannelManager] Removing channel '{channel_name}'...")
|
||||
self.stop(channel_name)
|
||||
logger.info(f"[ChannelManager] Channel '{channel_name}' removed successfully")
|
||||
|
||||
|
||||
def _clear_singleton_cache(channel_name: str):
|
||||
"""
|
||||
Clear the singleton cache for the channel class so that
|
||||
a new instance can be created with updated config.
|
||||
"""
|
||||
cls_map = {
|
||||
"web": "channel.web.web_channel.WebChannel",
|
||||
"wechatmp": "channel.wechatmp.wechatmp_channel.WechatMPChannel",
|
||||
"wechatmp_service": "channel.wechatmp.wechatmp_channel.WechatMPChannel",
|
||||
"wechatcom_app": "channel.wechatcom.wechatcomapp_channel.WechatComAppChannel",
|
||||
const.FEISHU: "channel.feishu.feishu_channel.FeiShuChanel",
|
||||
const.DINGTALK: "channel.dingtalk.dingtalk_channel.DingTalkChanel",
|
||||
const.WECOM_BOT: "channel.wecom_bot.wecom_bot_channel.WecomBotChannel",
|
||||
const.QQ: "channel.qq.qq_channel.QQChannel",
|
||||
const.WEIXIN: "channel.weixin.weixin_channel.WeixinChannel",
|
||||
"wx": "channel.weixin.weixin_channel.WeixinChannel",
|
||||
}
|
||||
module_path = cls_map.get(channel_name)
|
||||
if not module_path:
|
||||
return
|
||||
try:
|
||||
parts = module_path.rsplit(".", 1)
|
||||
module_name, class_name = parts[0], parts[1]
|
||||
import importlib
|
||||
module = importlib.import_module(module_name)
|
||||
wrapper = getattr(module, class_name, None)
|
||||
if wrapper and hasattr(wrapper, '__closure__') and wrapper.__closure__:
|
||||
for cell in wrapper.__closure__:
|
||||
try:
|
||||
cell_contents = cell.cell_contents
|
||||
if isinstance(cell_contents, dict):
|
||||
cell_contents.clear()
|
||||
logger.debug(f"[ChannelManager] Cleared singleton cache for {class_name}")
|
||||
break
|
||||
except ValueError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.warning(f"[ChannelManager] Failed to clear singleton cache: {e}")
|
||||
|
||||
|
||||
def sigterm_handler_wrap(_signo):
|
||||
old_handler = signal.getsignal(_signo)
|
||||
|
||||
@@ -25,22 +274,41 @@ def sigterm_handler_wrap(_signo):
|
||||
signal.signal(_signo, func)
|
||||
|
||||
|
||||
def start_channel(channel_name: str):
|
||||
channel = channel_factory.create_channel(channel_name)
|
||||
if channel_name in ["wx", "wxy", "terminal", "wechatmp", "web", "wechatmp_service", "wechatcom_app", "wework",
|
||||
const.FEISHU, const.DINGTALK]:
|
||||
PluginManager().load_plugins()
|
||||
def _sync_builtin_skills():
|
||||
"""Sync builtin skills from project skills/ to workspace skills/ on startup."""
|
||||
import shutil
|
||||
try:
|
||||
workspace = conf().get("agent_workspace", "~/cow")
|
||||
workspace = os.path.expanduser(workspace)
|
||||
project_root = os.path.dirname(os.path.abspath(__file__))
|
||||
builtin_dir = os.path.join(project_root, "skills")
|
||||
custom_dir = os.path.join(workspace, "skills")
|
||||
|
||||
if conf().get("use_linkai"):
|
||||
try:
|
||||
from common import linkai_client
|
||||
threading.Thread(target=linkai_client.start, args=(channel,)).start()
|
||||
except Exception as e:
|
||||
pass
|
||||
channel.startup()
|
||||
if not os.path.isdir(builtin_dir):
|
||||
return
|
||||
|
||||
os.makedirs(custom_dir, exist_ok=True)
|
||||
synced = 0
|
||||
for name in os.listdir(builtin_dir):
|
||||
src = os.path.join(builtin_dir, name)
|
||||
if not os.path.isdir(src) or not os.path.isfile(os.path.join(src, "SKILL.md")):
|
||||
continue
|
||||
dst = os.path.join(custom_dir, name)
|
||||
try:
|
||||
if os.path.isdir(dst):
|
||||
shutil.rmtree(dst)
|
||||
shutil.copytree(src, dst)
|
||||
synced += 1
|
||||
except Exception as e:
|
||||
logger.warning(f"[App] Failed to sync builtin skill '{name}': {e}")
|
||||
if synced:
|
||||
logger.info(f"[App] Synced {synced} builtin skill(s) to workspace")
|
||||
except Exception as e:
|
||||
logger.warning(f"[App] Builtin skills sync failed: {e}")
|
||||
|
||||
|
||||
def run():
|
||||
global _channel_mgr
|
||||
try:
|
||||
# load config
|
||||
load_config()
|
||||
@@ -49,16 +317,28 @@ def run():
|
||||
# kill signal
|
||||
sigterm_handler_wrap(signal.SIGTERM)
|
||||
|
||||
# create channel
|
||||
channel_name = conf().get("channel_type", "wx")
|
||||
# Parse channel_type into a list
|
||||
raw_channel = conf().get("channel_type", "web")
|
||||
|
||||
if "--cmd" in sys.argv:
|
||||
channel_name = "terminal"
|
||||
channel_names = ["terminal"]
|
||||
else:
|
||||
channel_names = _parse_channel_type(raw_channel)
|
||||
if not channel_names:
|
||||
channel_names = ["web"]
|
||||
|
||||
if channel_name == "wxy":
|
||||
os.environ["WECHATY_LOG"] = "warn"
|
||||
# Auto-start web console unless explicitly disabled
|
||||
web_console_enabled = conf().get("web_console", True)
|
||||
if web_console_enabled and "web" not in channel_names:
|
||||
channel_names.append("web")
|
||||
|
||||
start_channel(channel_name)
|
||||
# Sync builtin skills to workspace before channels start
|
||||
_sync_builtin_skills()
|
||||
|
||||
logger.info(f"[App] Starting channels: {channel_names}")
|
||||
|
||||
_channel_mgr = ChannelManager()
|
||||
_channel_mgr.start(channel_names, first_start=True)
|
||||
|
||||
while True:
|
||||
time.sleep(1)
|
||||
|
||||
@@ -28,7 +28,7 @@ def add_openai_compatible_support(bot_instance):
|
||||
"""
|
||||
if hasattr(bot_instance, 'call_with_tools'):
|
||||
# Bot already has tool calling support (e.g., ZHIPUAIBot)
|
||||
logger.info(f"[AgentBridge] {type(bot_instance).__name__} already has native tool calling support")
|
||||
logger.debug(f"[AgentBridge] {type(bot_instance).__name__} already has native tool calling support")
|
||||
return bot_instance
|
||||
|
||||
# Create a temporary mixin class that combines the bot with OpenAI compatibility
|
||||
@@ -65,30 +65,74 @@ class AgentLLMModel(LLMModel):
|
||||
LLM Model adapter that uses COW's existing bot infrastructure
|
||||
"""
|
||||
|
||||
_MODEL_BOT_TYPE_MAP = {
|
||||
"wenxin": const.BAIDU, "wenxin-4": const.BAIDU,
|
||||
"xunfei": const.XUNFEI, const.QWEN: const.QWEN_DASHSCOPE,
|
||||
const.MODELSCOPE: const.MODELSCOPE,
|
||||
}
|
||||
_MODEL_PREFIX_MAP = [
|
||||
("qwen", const.QWEN_DASHSCOPE), ("qwq", const.QWEN_DASHSCOPE), ("qvq", const.QWEN_DASHSCOPE),
|
||||
("gemini", const.GEMINI), ("glm", const.ZHIPU_AI), ("claude", const.CLAUDEAPI),
|
||||
("moonshot", const.MOONSHOT), ("kimi", const.MOONSHOT),
|
||||
("doubao", const.DOUBAO), ("deepseek", const.DEEPSEEK),
|
||||
]
|
||||
|
||||
def __init__(self, bridge: Bridge, bot_type: str = "chat"):
|
||||
# Get model name directly from config
|
||||
from config import conf
|
||||
model_name = conf().get("model", const.GPT_41)
|
||||
super().__init__(model=model_name)
|
||||
super().__init__(model=conf().get("model", const.GPT_41))
|
||||
self.bridge = bridge
|
||||
self.bot_type = bot_type
|
||||
self._bot = None
|
||||
self._use_linkai = conf().get("use_linkai", False) and conf().get("linkai_api_key")
|
||||
|
||||
self._bot_model = None
|
||||
|
||||
@property
|
||||
def model(self):
|
||||
from config import conf
|
||||
return conf().get("model", const.GPT_41)
|
||||
|
||||
@model.setter
|
||||
def model(self, value):
|
||||
pass
|
||||
|
||||
def _resolve_bot_type(self, model_name: str) -> str:
|
||||
"""Resolve bot type from model name, matching Bridge.__init__ logic."""
|
||||
from config import conf
|
||||
|
||||
if conf().get("use_linkai", False) and conf().get("linkai_api_key"):
|
||||
return const.LINKAI
|
||||
# Support custom bot type configuration
|
||||
configured_bot_type = conf().get("bot_type")
|
||||
if configured_bot_type:
|
||||
return configured_bot_type
|
||||
|
||||
if not model_name or not isinstance(model_name, str):
|
||||
return const.OPENAI
|
||||
if model_name in self._MODEL_BOT_TYPE_MAP:
|
||||
return self._MODEL_BOT_TYPE_MAP[model_name]
|
||||
if model_name.lower().startswith("minimax") or model_name in ["abab6.5-chat"]:
|
||||
return const.MiniMax
|
||||
if model_name in [const.QWEN_TURBO, const.QWEN_PLUS, const.QWEN_MAX]:
|
||||
return const.QWEN_DASHSCOPE
|
||||
if model_name in [const.MOONSHOT, "moonshot-v1-8k", "moonshot-v1-32k", "moonshot-v1-128k"]:
|
||||
return const.MOONSHOT
|
||||
if conf().get("bot_type") == "modelscope":
|
||||
return const.MODELSCOPE
|
||||
for prefix, btype in self._MODEL_PREFIX_MAP:
|
||||
if model_name.startswith(prefix):
|
||||
return btype
|
||||
return const.OPENAI
|
||||
|
||||
@property
|
||||
def bot(self):
|
||||
"""Lazy load the bot and enhance it with tool calling if needed"""
|
||||
if self._bot is None:
|
||||
# If use_linkai is enabled, use LinkAI bot directly
|
||||
if self._use_linkai:
|
||||
self._bot = self.bridge.find_chat_bot(const.LINKAI)
|
||||
else:
|
||||
self._bot = self.bridge.get_bot(self.bot_type)
|
||||
# Automatically add tool calling support if not present
|
||||
self._bot = add_openai_compatible_support(self._bot)
|
||||
|
||||
# Log bot info
|
||||
bot_name = type(self._bot).__name__
|
||||
"""Lazy load the bot, re-create when model or bot_type changes"""
|
||||
from models.bot_factory import create_bot
|
||||
cur_model = self.model
|
||||
cur_bot_type = self._resolve_bot_type(cur_model)
|
||||
if self._bot is None or self._bot_model != cur_model or getattr(self, '_bot_type', None) != cur_bot_type:
|
||||
self._bot = create_bot(cur_bot_type)
|
||||
self._bot = add_openai_compatible_support(self._bot)
|
||||
self._bot_model = cur_model
|
||||
self._bot_type = cur_bot_type
|
||||
return self._bot
|
||||
|
||||
def call(self, request: LLMRequest):
|
||||
@@ -109,12 +153,28 @@ class AgentLLMModel(LLMModel):
|
||||
# Only pass max_tokens if it's explicitly set
|
||||
if request.max_tokens is not None:
|
||||
kwargs['max_tokens'] = request.max_tokens
|
||||
|
||||
|
||||
# Extract system prompt if present
|
||||
system_prompt = getattr(request, 'system', None)
|
||||
if system_prompt:
|
||||
kwargs['system'] = system_prompt
|
||||
|
||||
|
||||
# Pass context metadata to bot
|
||||
channel_type = getattr(self, 'channel_type', None) or ''
|
||||
if channel_type:
|
||||
kwargs['channel_type'] = channel_type
|
||||
session_id = getattr(self, 'session_id', None)
|
||||
if session_id:
|
||||
kwargs['session_id'] = session_id
|
||||
|
||||
# Determine thinking: respect global config, then channel_type
|
||||
from config import conf
|
||||
global_thinking = conf().get("enable_thinking", False)
|
||||
if not global_thinking:
|
||||
kwargs['thinking'] = {"type": "disabled"}
|
||||
else:
|
||||
kwargs['thinking'] = {"type": "enabled"} if channel_type == "web" else {"type": "disabled"}
|
||||
|
||||
response = self.bot.call_with_tools(**kwargs)
|
||||
return self._format_response(response)
|
||||
else:
|
||||
@@ -135,7 +195,7 @@ class AgentLLMModel(LLMModel):
|
||||
# Use tool-enabled streaming call if available
|
||||
# Extract system prompt if present
|
||||
system_prompt = getattr(request, 'system', None)
|
||||
|
||||
|
||||
# Build kwargs for call_with_tools
|
||||
kwargs = {
|
||||
'messages': request.messages,
|
||||
@@ -143,15 +203,31 @@ class AgentLLMModel(LLMModel):
|
||||
'stream': True,
|
||||
'model': self.model # Pass model parameter
|
||||
}
|
||||
|
||||
|
||||
# Only pass max_tokens if explicitly set, let the bot use its default
|
||||
if request.max_tokens is not None:
|
||||
kwargs['max_tokens'] = request.max_tokens
|
||||
|
||||
|
||||
# Add system prompt if present
|
||||
if system_prompt:
|
||||
kwargs['system'] = system_prompt
|
||||
|
||||
|
||||
# Pass context metadata to bot
|
||||
channel_type = getattr(self, 'channel_type', None) or ''
|
||||
if channel_type:
|
||||
kwargs['channel_type'] = channel_type
|
||||
session_id = getattr(self, 'session_id', None)
|
||||
if session_id:
|
||||
kwargs['session_id'] = session_id
|
||||
|
||||
# Determine thinking: respect global config, then channel_type
|
||||
from config import conf
|
||||
global_thinking = conf().get("enable_thinking", False)
|
||||
if not global_thinking:
|
||||
kwargs['thinking'] = {"type": "disabled"}
|
||||
else:
|
||||
kwargs['thinking'] = {"type": "enabled"} if channel_type == "web" else {"type": "disabled"}
|
||||
|
||||
stream = self.bot.call_with_tools(**kwargs)
|
||||
|
||||
# Convert stream format to our expected format
|
||||
@@ -214,10 +290,13 @@ class AgentBridge:
|
||||
tool_manager.load_tools()
|
||||
|
||||
tools = []
|
||||
workspace_dir = kwargs.get("workspace_dir")
|
||||
for tool_name in tool_manager.tool_classes.keys():
|
||||
try:
|
||||
tool = tool_manager.create_tool(tool_name)
|
||||
if tool:
|
||||
if workspace_dir and hasattr(tool, 'cwd'):
|
||||
tool.cwd = workspace_dir
|
||||
tools.append(tool)
|
||||
except Exception as e:
|
||||
logger.warning(f"[AgentBridge] Failed to load tool {tool_name}: {e}")
|
||||
@@ -230,12 +309,13 @@ class AgentBridge:
|
||||
tools=tools,
|
||||
max_steps=kwargs.get("max_steps", 15),
|
||||
output_mode=kwargs.get("output_mode", "logger"),
|
||||
workspace_dir=kwargs.get("workspace_dir"), # Pass workspace for skills loading
|
||||
enable_skills=kwargs.get("enable_skills", True), # Enable skills by default
|
||||
memory_manager=kwargs.get("memory_manager"), # Pass memory manager
|
||||
workspace_dir=kwargs.get("workspace_dir"),
|
||||
skill_manager=kwargs.get("skill_manager"),
|
||||
enable_skills=kwargs.get("enable_skills", True),
|
||||
memory_manager=kwargs.get("memory_manager"),
|
||||
max_context_tokens=kwargs.get("max_context_tokens"),
|
||||
context_reserve_tokens=kwargs.get("context_reserve_tokens"),
|
||||
runtime_info=kwargs.get("runtime_info") # Pass runtime_info for dynamic time updates
|
||||
runtime_info=kwargs.get("runtime_info"),
|
||||
)
|
||||
|
||||
# Log skill loading details
|
||||
@@ -290,9 +370,10 @@ class AgentBridge:
|
||||
Returns:
|
||||
Reply object
|
||||
"""
|
||||
session_id = None
|
||||
agent = None
|
||||
try:
|
||||
# Extract session_id from context for user isolation
|
||||
session_id = None
|
||||
if context:
|
||||
session_id = context.kwargs.get("session_id") or context.get("session_id")
|
||||
|
||||
@@ -325,6 +406,14 @@ class AgentBridge:
|
||||
logger.warning(f"[AgentBridge] Failed to attach context to scheduler: {e}")
|
||||
break
|
||||
|
||||
# Pass context metadata to model for downstream API requests
|
||||
if context and hasattr(agent, 'model'):
|
||||
agent.model.channel_type = context.get("channel_type", "")
|
||||
agent.model.session_id = session_id or ""
|
||||
|
||||
# Store session_id on agent so executor can clear DB on fatal errors
|
||||
agent._current_session_id = session_id
|
||||
|
||||
try:
|
||||
# Use agent's run_stream method with event handler
|
||||
response = agent.run_stream(
|
||||
@@ -336,11 +425,28 @@ class AgentBridge:
|
||||
# Restore original tools
|
||||
if context and context.get("is_scheduled_task"):
|
||||
agent.tools = original_tools
|
||||
|
||||
|
||||
# Log execution summary
|
||||
event_handler.log_summary()
|
||||
|
||||
# Persist new messages generated during this run
|
||||
if session_id:
|
||||
channel_type = (context.get("channel_type") or "") if context else ""
|
||||
new_messages = getattr(agent, '_last_run_new_messages', [])
|
||||
if new_messages:
|
||||
self._persist_messages(session_id, list(new_messages), channel_type)
|
||||
else:
|
||||
with agent.messages_lock:
|
||||
msg_count = len(agent.messages)
|
||||
if msg_count == 0:
|
||||
try:
|
||||
from agent.memory import get_conversation_store
|
||||
get_conversation_store().clear_session(session_id)
|
||||
logger.info(f"[AgentBridge] Cleared DB for recovered session: {session_id}")
|
||||
except Exception as e:
|
||||
logger.warning(f"[AgentBridge] Failed to clear DB after recovery: {e}")
|
||||
|
||||
# Check if there are files to send (from read tool)
|
||||
# Check if there are files to send (from send/read tool)
|
||||
if hasattr(agent, 'stream_executor') and hasattr(agent.stream_executor, 'files_to_send'):
|
||||
files_to_send = agent.stream_executor.files_to_send
|
||||
if files_to_send:
|
||||
@@ -358,6 +464,18 @@ class AgentBridge:
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Agent reply error: {e}")
|
||||
# If the agent cleared its messages due to format error / overflow,
|
||||
# also purge the DB so the next request starts clean.
|
||||
if session_id and agent:
|
||||
try:
|
||||
with agent.messages_lock:
|
||||
msg_count = len(agent.messages)
|
||||
if msg_count == 0:
|
||||
from agent.memory import get_conversation_store
|
||||
get_conversation_store().clear_session(session_id)
|
||||
logger.info(f"[AgentBridge] Cleared DB for session after error: {session_id}")
|
||||
except Exception as db_err:
|
||||
logger.warning(f"[AgentBridge] Failed to clear DB after error: {db_err}")
|
||||
return Reply(ReplyType.ERROR, f"Agent error: {str(e)}")
|
||||
|
||||
def _create_file_reply(self, file_info: dict, text_response: str, context: Context = None) -> Reply:
|
||||
@@ -397,22 +515,26 @@ class AgentBridge:
|
||||
reply.text_content = text_response
|
||||
return reply
|
||||
|
||||
# For other unknown file types, return text with file info
|
||||
message = text_response or file_info.get("message", "文件已准备")
|
||||
message += f"\n\n[文件: {file_info.get('file_name', file_path)}]"
|
||||
return Reply(ReplyType.TEXT, message)
|
||||
# For all other file types (tar.gz, zip, etc.), also use FILE type
|
||||
file_url = f"file://{file_path}"
|
||||
logger.info(f"[AgentBridge] Sending generic file: {file_url}")
|
||||
reply = Reply(ReplyType.FILE, file_url)
|
||||
reply.file_name = file_info.get("file_name", os.path.basename(file_path))
|
||||
if text_response:
|
||||
reply.text_content = text_response
|
||||
return reply
|
||||
|
||||
def _migrate_config_to_env(self, workspace_root: str):
|
||||
"""
|
||||
Migrate API keys from config.json to .env file if not already set
|
||||
|
||||
Sync API keys from config.json to .env file.
|
||||
Adds new keys and updates changed values on each startup.
|
||||
|
||||
Args:
|
||||
workspace_root: Workspace directory path (not used, kept for compatibility)
|
||||
"""
|
||||
from config import conf
|
||||
import os
|
||||
|
||||
# Mapping from config.json keys to environment variable names
|
||||
key_mapping = {
|
||||
"open_ai_api_key": "OPENAI_API_KEY",
|
||||
"open_ai_api_base": "OPENAI_API_BASE",
|
||||
@@ -421,10 +543,9 @@ class AgentBridge:
|
||||
"linkai_api_key": "LINKAI_API_KEY",
|
||||
}
|
||||
|
||||
# Use fixed secure location for .env file
|
||||
env_file = expand_path("~/.cow/.env")
|
||||
|
||||
# Read existing env vars from .env file
|
||||
# Read existing env vars (key -> value)
|
||||
existing_env_vars = {}
|
||||
if os.path.exists(env_file):
|
||||
try:
|
||||
@@ -432,49 +553,110 @@ class AgentBridge:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line and not line.startswith('#') and '=' in line:
|
||||
key, _ = line.split('=', 1)
|
||||
existing_env_vars[key.strip()] = True
|
||||
key, val = line.split('=', 1)
|
||||
existing_env_vars[key.strip()] = val.strip()
|
||||
except Exception as e:
|
||||
logger.warning(f"[AgentBridge] Failed to read .env file: {e}")
|
||||
|
||||
# Check which keys need to be migrated
|
||||
keys_to_migrate = {}
|
||||
# Sync config.json values into .env (add/update/remove)
|
||||
updated = False
|
||||
for config_key, env_key in key_mapping.items():
|
||||
# Skip if already in .env file
|
||||
if env_key in existing_env_vars:
|
||||
continue
|
||||
|
||||
# Get value from config.json
|
||||
value = conf().get(config_key, "")
|
||||
if value and value.strip(): # Only migrate non-empty values
|
||||
keys_to_migrate[env_key] = value.strip()
|
||||
|
||||
# Log summary if there are keys to skip
|
||||
if existing_env_vars:
|
||||
logger.debug(f"[AgentBridge] {len(existing_env_vars)} env vars already in .env")
|
||||
|
||||
# Write new keys to .env file
|
||||
if keys_to_migrate:
|
||||
raw = conf().get(config_key, "")
|
||||
value = raw.strip() if raw else ""
|
||||
old_value = existing_env_vars.get(env_key)
|
||||
|
||||
if value:
|
||||
if old_value == value:
|
||||
continue
|
||||
existing_env_vars[env_key] = value
|
||||
os.environ[env_key] = value
|
||||
updated = True
|
||||
else:
|
||||
if old_value is None:
|
||||
continue
|
||||
existing_env_vars.pop(env_key, None)
|
||||
os.environ.pop(env_key, None)
|
||||
updated = True
|
||||
updated = True
|
||||
|
||||
if updated:
|
||||
try:
|
||||
# Ensure ~/.cow directory and .env file exist
|
||||
env_dir = os.path.dirname(env_file)
|
||||
if not os.path.exists(env_dir):
|
||||
os.makedirs(env_dir, exist_ok=True)
|
||||
if not os.path.exists(env_file):
|
||||
open(env_file, 'a').close()
|
||||
|
||||
# Append new keys
|
||||
with open(env_file, 'a', encoding='utf-8') as f:
|
||||
f.write('\n# Auto-migrated from config.json\n')
|
||||
for key, value in keys_to_migrate.items():
|
||||
os.makedirs(env_dir, exist_ok=True)
|
||||
|
||||
with open(env_file, 'w', encoding='utf-8') as f:
|
||||
f.write('# Environment variables for agent\n')
|
||||
f.write('# Auto-managed - synced from config.json on startup\n\n')
|
||||
for key, value in sorted(existing_env_vars.items()):
|
||||
f.write(f'{key}={value}\n')
|
||||
# Also set in current process
|
||||
os.environ[key] = value
|
||||
|
||||
logger.info(f"[AgentBridge] Migrated {len(keys_to_migrate)} API keys from config.json to .env: {list(keys_to_migrate.keys())}")
|
||||
|
||||
logger.info(f"[AgentBridge] Synced API keys from config.json to .env")
|
||||
except Exception as e:
|
||||
logger.warning(f"[AgentBridge] Failed to migrate API keys: {e}")
|
||||
logger.warning(f"[AgentBridge] Failed to sync API keys: {e}")
|
||||
|
||||
def _persist_messages(
|
||||
self, session_id: str, new_messages: list, channel_type: str = ""
|
||||
) -> None:
|
||||
"""
|
||||
Persist new messages to the conversation store after each agent run.
|
||||
|
||||
Failures are logged but never propagate — they must not interrupt replies.
|
||||
"""
|
||||
if not new_messages:
|
||||
return
|
||||
try:
|
||||
from config import conf
|
||||
if not conf().get("conversation_persistence", True):
|
||||
return
|
||||
# When deep-thinking display is disabled, strip "thinking" content
|
||||
# blocks before persisting so they don't resurface on history reload.
|
||||
# The in-memory message list keeps them intact for this run's
|
||||
# multi-turn LLM context.
|
||||
thinking_enabled = bool(conf().get("enable_thinking", False))
|
||||
except Exception:
|
||||
thinking_enabled = False
|
||||
|
||||
messages_to_store = new_messages
|
||||
if not thinking_enabled:
|
||||
messages_to_store = self._strip_thinking_blocks(new_messages)
|
||||
|
||||
try:
|
||||
from agent.memory import get_conversation_store
|
||||
get_conversation_store().append_messages(
|
||||
session_id, messages_to_store, channel_type=channel_type
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"[AgentBridge] Failed to persist messages for session={session_id}: {e}"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _strip_thinking_blocks(messages: list) -> list:
|
||||
"""Return a shallow copy of messages with assistant "thinking" blocks removed."""
|
||||
cleaned = []
|
||||
for msg in messages:
|
||||
if not isinstance(msg, dict):
|
||||
cleaned.append(msg)
|
||||
continue
|
||||
if msg.get("role") != "assistant":
|
||||
cleaned.append(msg)
|
||||
continue
|
||||
content = msg.get("content")
|
||||
if not isinstance(content, list):
|
||||
cleaned.append(msg)
|
||||
continue
|
||||
filtered_blocks = [
|
||||
b for b in content
|
||||
if not (isinstance(b, dict) and b.get("type") == "thinking")
|
||||
]
|
||||
if len(filtered_blocks) == len(content):
|
||||
cleaned.append(msg)
|
||||
else:
|
||||
new_msg = dict(msg)
|
||||
new_msg["content"] = filtered_blocks
|
||||
cleaned.append(new_msg)
|
||||
return cleaned
|
||||
|
||||
def clear_session(self, session_id: str):
|
||||
"""
|
||||
Clear a specific session's agent and conversation history
|
||||
|
||||
@@ -26,8 +26,7 @@ class AgentEventHandler:
|
||||
if context:
|
||||
self.channel = context.kwargs.get("channel") if hasattr(context, "kwargs") else None
|
||||
|
||||
# Track current thinking for channel output
|
||||
self.current_thinking = ""
|
||||
self.current_content = ""
|
||||
self.turn_number = 0
|
||||
|
||||
def handle_event(self, event):
|
||||
@@ -47,6 +46,8 @@ class AgentEventHandler:
|
||||
self._handle_message_update(data)
|
||||
elif event_type == "message_end":
|
||||
self._handle_message_end(data)
|
||||
elif event_type == "reasoning_update":
|
||||
pass
|
||||
elif event_type == "tool_execution_start":
|
||||
self._handle_tool_execution_start(data)
|
||||
elif event_type == "tool_execution_end":
|
||||
@@ -59,30 +60,26 @@ class AgentEventHandler:
|
||||
def _handle_turn_start(self, data):
|
||||
"""Handle turn start event"""
|
||||
self.turn_number = data.get("turn", 0)
|
||||
self.has_tool_calls_in_turn = False
|
||||
self.current_thinking = ""
|
||||
self.current_content = ""
|
||||
|
||||
def _handle_message_update(self, data):
|
||||
"""Handle message update event (streaming text)"""
|
||||
"""Handle message update event (streaming content text)"""
|
||||
delta = data.get("delta", "")
|
||||
self.current_thinking += delta
|
||||
self.current_content += delta
|
||||
|
||||
def _handle_message_end(self, data):
|
||||
"""Handle message end event"""
|
||||
tool_calls = data.get("tool_calls", [])
|
||||
|
||||
# Only send thinking process if followed by tool calls
|
||||
if tool_calls:
|
||||
if self.current_thinking.strip():
|
||||
logger.debug(f"💭 {self.current_thinking.strip()[:200]}{'...' if len(self.current_thinking) > 200 else ''}")
|
||||
# Send thinking process to channel
|
||||
self._send_to_channel(f"{self.current_thinking.strip()}")
|
||||
if self.current_content.strip():
|
||||
logger.info(f"💭 {self.current_content.strip()[:200]}{'...' if len(self.current_content) > 200 else ''}")
|
||||
self._send_to_channel(self.current_content.strip())
|
||||
else:
|
||||
# No tool calls = final response (logged at agent_stream level)
|
||||
if self.current_thinking.strip():
|
||||
logger.debug(f"💬 {self.current_thinking.strip()[:200]}{'...' if len(self.current_thinking) > 200 else ''}")
|
||||
if self.current_content.strip():
|
||||
logger.debug(f"💬 {self.current_content.strip()[:200]}{'...' if len(self.current_content) > 200 else ''}")
|
||||
|
||||
self.current_thinking = ""
|
||||
self.current_content = ""
|
||||
|
||||
def _handle_tool_execution_start(self, data):
|
||||
"""Handle tool execution start event - logged by agent_stream.py"""
|
||||
@@ -94,15 +91,15 @@ class AgentEventHandler:
|
||||
|
||||
def _send_to_channel(self, message):
|
||||
"""
|
||||
Try to send message to channel
|
||||
|
||||
Args:
|
||||
message: Message to send
|
||||
Try to send intermediate message to channel.
|
||||
Skipped in SSE mode because thinking text is already streamed via on_event.
|
||||
"""
|
||||
if self.context and self.context.get("on_event"):
|
||||
return
|
||||
|
||||
if self.channel:
|
||||
try:
|
||||
from bridge.reply import Reply, ReplyType
|
||||
# Create a Reply object for the message
|
||||
reply = Reply(ReplyType.TEXT, message)
|
||||
self.channel._send(reply, self.context)
|
||||
except Exception as e:
|
||||
|
||||
@@ -77,10 +77,6 @@ class AgentInitializer:
|
||||
# Initialize skill manager
|
||||
skill_manager = self._initialize_skill_manager(workspace_root, session_id)
|
||||
|
||||
# Check if first conversation
|
||||
from agent.prompt.workspace import is_first_conversation, mark_conversation_started
|
||||
is_first = is_first_conversation(workspace_root)
|
||||
|
||||
# Build system prompt
|
||||
prompt_builder = PromptBuilder(workspace_dir=workspace_root, language="zh")
|
||||
runtime_info = self._get_runtime_info(workspace_root)
|
||||
@@ -91,12 +87,8 @@ class AgentInitializer:
|
||||
skill_manager=skill_manager,
|
||||
memory_manager=memory_manager,
|
||||
runtime_info=runtime_info,
|
||||
is_first_conversation=is_first
|
||||
)
|
||||
|
||||
if is_first:
|
||||
mark_conversation_started(workspace_root)
|
||||
|
||||
# Get cost control parameters
|
||||
from config import conf
|
||||
max_steps = conf().get("agent_max_steps", 20)
|
||||
@@ -115,11 +107,135 @@ class AgentInitializer:
|
||||
runtime_info=runtime_info # Pass runtime_info for dynamic time updates
|
||||
)
|
||||
|
||||
# Attach memory manager
|
||||
# Attach memory manager and share LLM model for summarization
|
||||
if memory_manager:
|
||||
agent.memory_manager = memory_manager
|
||||
|
||||
if hasattr(agent, 'model') and agent.model:
|
||||
memory_manager.flush_manager.llm_model = agent.model
|
||||
|
||||
# Restore persisted conversation history for this session
|
||||
if session_id:
|
||||
self._restore_conversation_history(agent, session_id)
|
||||
|
||||
# Start daily memory flush timer (once, on first agent init regardless of session)
|
||||
self._start_daily_flush_timer()
|
||||
|
||||
return agent
|
||||
|
||||
def _restore_conversation_history(self, agent, session_id: str) -> None:
|
||||
"""
|
||||
Load persisted conversation messages from SQLite and inject them
|
||||
into the agent's in-memory message list.
|
||||
|
||||
Only user text and assistant text are restored. Tool call chains
|
||||
(tool_use / tool_result) are stripped out because:
|
||||
1. They are intermediate process, the value is already in the final
|
||||
assistant text reply.
|
||||
2. They consume massive context tokens (often 80%+ of history).
|
||||
3. Different models have incompatible tool message formats, so
|
||||
restoring tool chains across model switches causes 400 errors.
|
||||
4. Eliminates the entire class of tool_use/tool_result pairing bugs.
|
||||
"""
|
||||
from config import conf
|
||||
if not conf().get("conversation_persistence", True):
|
||||
return
|
||||
|
||||
try:
|
||||
from agent.memory import get_conversation_store
|
||||
store = get_conversation_store()
|
||||
max_turns = conf().get("agent_max_context_turns", 20)
|
||||
restore_turns = max(3, max_turns // 6)
|
||||
saved = store.load_messages(session_id, max_turns=restore_turns)
|
||||
if saved:
|
||||
filtered = self._filter_text_only_messages(saved)
|
||||
if filtered:
|
||||
with agent.messages_lock:
|
||||
agent.messages = filtered
|
||||
logger.debug(
|
||||
f"[AgentInitializer] Restored {len(filtered)} text messages "
|
||||
f"(from {len(saved)} total, {restore_turns} turns cap) "
|
||||
f"for session={session_id}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"[AgentInitializer] Failed to restore conversation history for "
|
||||
f"session={session_id}: {e}"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _filter_text_only_messages(messages: list) -> list:
|
||||
"""
|
||||
Extract clean user/assistant turn pairs from raw message history.
|
||||
|
||||
Groups messages into turns (each starting with a real user query),
|
||||
then keeps only:
|
||||
- The first user text in each turn (the actual user input)
|
||||
- The last assistant text in each turn (the final answer)
|
||||
|
||||
All tool_use, tool_result, intermediate assistant thoughts, and
|
||||
internal hint messages injected by the agent loop are discarded.
|
||||
"""
|
||||
|
||||
def _extract_text(content) -> str:
|
||||
if isinstance(content, str):
|
||||
return content.strip()
|
||||
if isinstance(content, list):
|
||||
parts = [
|
||||
b.get("text", "")
|
||||
for b in content
|
||||
if isinstance(b, dict) and b.get("type") == "text"
|
||||
]
|
||||
return "\n".join(p for p in parts if p).strip()
|
||||
return ""
|
||||
|
||||
def _is_real_user_msg(msg: dict) -> bool:
|
||||
"""True for actual user input, False for tool_result or internal hints."""
|
||||
if msg.get("role") != "user":
|
||||
return False
|
||||
content = msg.get("content")
|
||||
if isinstance(content, list):
|
||||
has_tool_result = any(
|
||||
isinstance(b, dict) and b.get("type") == "tool_result"
|
||||
for b in content
|
||||
)
|
||||
if has_tool_result:
|
||||
return False
|
||||
text = _extract_text(content)
|
||||
return bool(text)
|
||||
|
||||
# Group into turns: each turn starts with a real user message
|
||||
turns = []
|
||||
current_turn = None
|
||||
for msg in messages:
|
||||
if _is_real_user_msg(msg):
|
||||
if current_turn is not None:
|
||||
turns.append(current_turn)
|
||||
current_turn = {"user": msg, "assistants": []}
|
||||
elif current_turn is not None and msg.get("role") == "assistant":
|
||||
text = _extract_text(msg.get("content"))
|
||||
if text:
|
||||
current_turn["assistants"].append(text)
|
||||
if current_turn is not None:
|
||||
turns.append(current_turn)
|
||||
|
||||
# Build result: one user msg + one assistant msg per turn
|
||||
filtered = []
|
||||
for turn in turns:
|
||||
user_text = _extract_text(turn["user"].get("content"))
|
||||
if not user_text:
|
||||
continue
|
||||
filtered.append({
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": user_text}]
|
||||
})
|
||||
if turn["assistants"]:
|
||||
final_reply = turn["assistants"][-1]
|
||||
filtered.append({
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "text": final_reply}]
|
||||
})
|
||||
|
||||
return filtered
|
||||
|
||||
def _load_env_file(self):
|
||||
"""Load environment variables from .env file"""
|
||||
@@ -148,12 +264,11 @@ class AgentInitializer:
|
||||
from agent.tools import MemorySearchTool, MemoryGetTool
|
||||
from config import conf
|
||||
|
||||
# Get OpenAI config
|
||||
# Initialize embedding provider (prefer OpenAI, fallback to LinkAI)
|
||||
embedding_provider = None
|
||||
|
||||
openai_api_key = conf().get("open_ai_api_key", "")
|
||||
openai_api_base = conf().get("open_ai_api_base", "")
|
||||
|
||||
# Initialize embedding provider
|
||||
embedding_provider = None
|
||||
if openai_api_key and openai_api_key not in ["", "YOUR API KEY", "YOUR_API_KEY"]:
|
||||
try:
|
||||
embedding_provider = create_embedding_provider(
|
||||
@@ -166,6 +281,22 @@ class AgentInitializer:
|
||||
logger.info("[AgentInitializer] OpenAI embedding initialized")
|
||||
except Exception as e:
|
||||
logger.warning(f"[AgentInitializer] OpenAI embedding failed: {e}")
|
||||
|
||||
if embedding_provider is None:
|
||||
linkai_api_key = conf().get("linkai_api_key", "") or os.environ.get("LINKAI_API_KEY", "")
|
||||
linkai_api_base = conf().get("linkai_api_base", "https://api.link-ai.tech")
|
||||
if linkai_api_key and linkai_api_key not in ["", "YOUR API KEY", "YOUR_API_KEY"]:
|
||||
try:
|
||||
embedding_provider = create_embedding_provider(
|
||||
provider="linkai",
|
||||
model="text-embedding-3-small",
|
||||
api_key=linkai_api_key,
|
||||
api_base=f"{linkai_api_base}/v1"
|
||||
)
|
||||
if session_id is None:
|
||||
logger.info("[AgentInitializer] LinkAI embedding initialized (fallback)")
|
||||
except Exception as e:
|
||||
logger.warning(f"[AgentInitializer] LinkAI embedding failed: {e}")
|
||||
|
||||
# Create memory manager
|
||||
memory_config = MemoryConfig(workspace_root=workspace_root)
|
||||
@@ -235,7 +366,7 @@ class AgentInitializer:
|
||||
|
||||
if tool:
|
||||
# Apply workspace config to file operation tools
|
||||
if tool_name in ['read', 'write', 'edit', 'bash', 'grep', 'find', 'ls']:
|
||||
if tool_name in ['read', 'write', 'edit', 'bash', 'grep', 'find', 'ls', 'web_fetch', 'send', 'browser']:
|
||||
tool.config = file_config
|
||||
tool.cwd = file_config.get("cwd", getattr(tool, 'cwd', None))
|
||||
if 'memory_manager' in file_config:
|
||||
@@ -283,7 +414,14 @@ class AgentInitializer:
|
||||
tool.scheduler_service = scheduler_service
|
||||
if not tool.config:
|
||||
tool.config = {}
|
||||
tool.config["channel_type"] = conf().get("channel_type", "unknown")
|
||||
raw_ct = conf().get("channel_type", "unknown")
|
||||
if isinstance(raw_ct, list):
|
||||
ct = raw_ct[0] if raw_ct else "unknown"
|
||||
elif isinstance(raw_ct, str) and "," in raw_ct:
|
||||
ct = raw_ct.split(",")[0].strip()
|
||||
else:
|
||||
ct = raw_ct
|
||||
tool.config["channel_type"] = ct
|
||||
except Exception as e:
|
||||
logger.warning(f"[AgentInitializer] Failed to inject scheduler dependencies: {e}")
|
||||
|
||||
@@ -291,7 +429,7 @@ class AgentInitializer:
|
||||
"""Initialize skill manager"""
|
||||
try:
|
||||
from agent.skills import SkillManager
|
||||
skill_manager = SkillManager(workspace_dir=workspace_root)
|
||||
skill_manager = SkillManager(custom_dir=os.path.join(workspace_root, "skills"))
|
||||
return skill_manager
|
||||
except Exception as e:
|
||||
logger.warning(f"[AgentInitializer] Failed to initialize SkillManager: {e}")
|
||||
@@ -327,10 +465,14 @@ class AgentInitializer:
|
||||
'timezone': timezone_name
|
||||
}
|
||||
|
||||
def get_model():
|
||||
"""Get current model name dynamically from config"""
|
||||
return conf().get("model", "unknown")
|
||||
|
||||
return {
|
||||
"model": conf().get("model", "unknown"),
|
||||
"_get_model": get_model,
|
||||
"workspace": workspace_root,
|
||||
"channel": conf().get("channel_type", "unknown"),
|
||||
"channel": ", ".join(conf().get("channel_type")) if isinstance(conf().get("channel_type"), list) else conf().get("channel_type", "unknown"),
|
||||
"_get_current_time": get_current_time # Dynamic time function
|
||||
}
|
||||
|
||||
@@ -348,7 +490,7 @@ class AgentInitializer:
|
||||
|
||||
env_file = expand_path("~/.cow/.env")
|
||||
|
||||
# Read existing env vars
|
||||
# Read existing env vars (key -> value)
|
||||
existing_env_vars = {}
|
||||
if os.path.exists(env_file):
|
||||
try:
|
||||
@@ -356,35 +498,123 @@ class AgentInitializer:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line and not line.startswith('#') and '=' in line:
|
||||
key, _ = line.split('=', 1)
|
||||
existing_env_vars[key.strip()] = True
|
||||
key, val = line.split('=', 1)
|
||||
existing_env_vars[key.strip()] = val.strip()
|
||||
except Exception as e:
|
||||
logger.warning(f"[AgentInitializer] Failed to read .env file: {e}")
|
||||
|
||||
# Check which keys need migration
|
||||
keys_to_migrate = {}
|
||||
# Sync config.json values into .env (add/update/remove)
|
||||
updated = False
|
||||
for config_key, env_key in key_mapping.items():
|
||||
if env_key in existing_env_vars:
|
||||
continue
|
||||
value = conf().get(config_key, "")
|
||||
if value and value.strip():
|
||||
keys_to_migrate[env_key] = value.strip()
|
||||
|
||||
# Write new keys
|
||||
if keys_to_migrate:
|
||||
raw = conf().get(config_key, "")
|
||||
value = raw.strip() if raw else ""
|
||||
old_value = existing_env_vars.get(env_key)
|
||||
|
||||
if value:
|
||||
if old_value == value:
|
||||
continue
|
||||
existing_env_vars[env_key] = value
|
||||
os.environ[env_key] = value
|
||||
updated = True
|
||||
else:
|
||||
if old_value is None:
|
||||
continue
|
||||
existing_env_vars.pop(env_key, None)
|
||||
os.environ.pop(env_key, None)
|
||||
updated = True
|
||||
|
||||
if updated:
|
||||
try:
|
||||
env_dir = os.path.dirname(env_file)
|
||||
if not os.path.exists(env_dir):
|
||||
os.makedirs(env_dir, exist_ok=True)
|
||||
if not os.path.exists(env_file):
|
||||
open(env_file, 'a').close()
|
||||
|
||||
with open(env_file, 'a', encoding='utf-8') as f:
|
||||
f.write('\n# Auto-migrated from config.json\n')
|
||||
for key, value in keys_to_migrate.items():
|
||||
os.makedirs(env_dir, exist_ok=True)
|
||||
|
||||
# Rewrite the entire .env file to ensure consistency
|
||||
with open(env_file, 'w', encoding='utf-8') as f:
|
||||
f.write('# Environment variables for agent\n')
|
||||
f.write('# Auto-managed - synced from config.json on startup\n\n')
|
||||
for key, value in sorted(existing_env_vars.items()):
|
||||
f.write(f'{key}={value}\n')
|
||||
os.environ[key] = value
|
||||
|
||||
logger.info(f"[AgentInitializer] Migrated {len(keys_to_migrate)} API keys to .env: {list(keys_to_migrate.keys())}")
|
||||
|
||||
logger.info(f"[AgentInitializer] Synced API keys from config.json to .env")
|
||||
except Exception as e:
|
||||
logger.warning(f"[AgentInitializer] Failed to migrate API keys: {e}")
|
||||
logger.warning(f"[AgentInitializer] Failed to sync API keys: {e}")
|
||||
|
||||
def _start_daily_flush_timer(self):
|
||||
"""Start a background thread that flushes all agents' memory daily at 23:55."""
|
||||
if getattr(self.agent_bridge, '_daily_flush_started', False):
|
||||
return
|
||||
self.agent_bridge._daily_flush_started = True
|
||||
|
||||
import threading
|
||||
|
||||
def _daily_flush_loop():
|
||||
import random
|
||||
while True:
|
||||
try:
|
||||
now = datetime.datetime.now()
|
||||
jitter_min = random.randint(50, 55)
|
||||
jitter_sec = random.randint(0, 59)
|
||||
target = now.replace(hour=23, minute=jitter_min, second=jitter_sec, microsecond=0)
|
||||
if target <= now:
|
||||
target += datetime.timedelta(days=1)
|
||||
wait_seconds = (target - now).total_seconds()
|
||||
logger.info(f"[DailyFlush] Next flush at {target.strftime('%Y-%m-%d %H:%M:%S')} (in {wait_seconds/3600:.1f}h)")
|
||||
time.sleep(wait_seconds)
|
||||
|
||||
self._flush_all_agents()
|
||||
except Exception as e:
|
||||
logger.warning(f"[DailyFlush] Error in daily flush loop: {e}")
|
||||
time.sleep(3600)
|
||||
|
||||
t = threading.Thread(target=_daily_flush_loop, daemon=True)
|
||||
t.start()
|
||||
|
||||
def _flush_all_agents(self):
|
||||
"""Flush memory for all active agent sessions, then run Deep Dream."""
|
||||
agents = []
|
||||
if self.agent_bridge.default_agent:
|
||||
agents.append(("default", self.agent_bridge.default_agent))
|
||||
for sid, agent in self.agent_bridge.agents.items():
|
||||
agents.append((sid, agent))
|
||||
|
||||
if not agents:
|
||||
return
|
||||
|
||||
# Phase 1: flush daily summaries
|
||||
flushed = 0
|
||||
flush_threads = []
|
||||
dream_candidate = None
|
||||
for label, agent in agents:
|
||||
try:
|
||||
if not agent.memory_manager:
|
||||
continue
|
||||
with agent.messages_lock:
|
||||
messages = list(agent.messages)
|
||||
if not messages:
|
||||
continue
|
||||
result = agent.memory_manager.flush_manager.create_daily_summary(messages)
|
||||
if result:
|
||||
flushed += 1
|
||||
t = agent.memory_manager.flush_manager._last_flush_thread
|
||||
if t:
|
||||
flush_threads.append(t)
|
||||
if dream_candidate is None:
|
||||
dream_candidate = agent.memory_manager.flush_manager
|
||||
except Exception as e:
|
||||
logger.warning(f"[DailyFlush] Failed for session {label}: {e}")
|
||||
|
||||
if flushed:
|
||||
logger.info(f"[DailyFlush] Flushed {flushed}/{len(agents)} agent session(s)")
|
||||
|
||||
# Wait for all flush threads to finish before dreaming
|
||||
for t in flush_threads:
|
||||
t.join(timeout=60)
|
||||
|
||||
# Phase 2: Deep Dream — distill daily memories → MEMORY.md + dream diary
|
||||
if dream_candidate:
|
||||
try:
|
||||
result = dream_candidate.deep_dream()
|
||||
if result:
|
||||
logger.info("[DeepDream] Memory distillation completed successfully")
|
||||
except Exception as e:
|
||||
logger.warning(f"[DeepDream] Failed: {e}")
|
||||
|
||||
@@ -13,7 +13,7 @@ from voice.factory import create_voice
|
||||
class Bridge(object):
|
||||
def __init__(self):
|
||||
self.btype = {
|
||||
"chat": const.CHATGPT,
|
||||
"chat": const.OPENAI,
|
||||
"voice_to_text": conf().get("voice_to_text", "openai"),
|
||||
"text_to_voice": conf().get("text_to_voice", "google"),
|
||||
"translate": conf().get("translate", "baidu"),
|
||||
@@ -39,11 +39,8 @@ class Bridge(object):
|
||||
self.btype["chat"] = const.BAIDU
|
||||
if model_type in ["xunfei"]:
|
||||
self.btype["chat"] = const.XUNFEI
|
||||
if model_type in [const.QWEN]:
|
||||
self.btype["chat"] = const.QWEN
|
||||
if model_type in [const.QWEN_TURBO, const.QWEN_PLUS, const.QWEN_MAX]:
|
||||
if model_type in [const.QWEN, const.QWEN_TURBO, const.QWEN_PLUS, const.QWEN_MAX]:
|
||||
self.btype["chat"] = const.QWEN_DASHSCOPE
|
||||
# Support Qwen3 and other DashScope models
|
||||
if model_type and (model_type.startswith("qwen") or model_type.startswith("qwq") or model_type.startswith("qvq")):
|
||||
self.btype["chat"] = const.QWEN_DASHSCOPE
|
||||
if model_type and model_type.startswith("gemini"):
|
||||
@@ -55,6 +52,14 @@ class Bridge(object):
|
||||
|
||||
if model_type in [const.MOONSHOT, "moonshot-v1-8k", "moonshot-v1-32k", "moonshot-v1-128k"]:
|
||||
self.btype["chat"] = const.MOONSHOT
|
||||
if model_type and model_type.startswith("kimi"):
|
||||
self.btype["chat"] = const.MOONSHOT
|
||||
|
||||
if model_type and model_type.startswith("doubao"):
|
||||
self.btype["chat"] = const.DOUBAO
|
||||
|
||||
if model_type and model_type.startswith("deepseek"):
|
||||
self.btype["chat"] = const.DEEPSEEK
|
||||
|
||||
if model_type in [const.MODELSCOPE]:
|
||||
self.btype["chat"] = const.MODELSCOPE
|
||||
|
||||
@@ -13,12 +13,44 @@ class Channel(object):
|
||||
channel_type = ""
|
||||
NOT_SUPPORT_REPLYTYPE = [ReplyType.VOICE, ReplyType.IMAGE]
|
||||
|
||||
def __init__(self):
|
||||
import threading
|
||||
self._startup_event = threading.Event()
|
||||
self._startup_error = None
|
||||
self.cloud_mode = False # set to True by ChannelManager when running with cloud client
|
||||
|
||||
def startup(self):
|
||||
"""
|
||||
init channel
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def report_startup_success(self):
|
||||
self._startup_error = None
|
||||
self._startup_event.set()
|
||||
|
||||
def report_startup_error(self, error: str):
|
||||
self._startup_error = error
|
||||
self._startup_event.set()
|
||||
|
||||
def wait_startup(self, timeout: float = 3) -> (bool, str):
|
||||
"""
|
||||
Wait for channel startup result.
|
||||
Returns (success: bool, error_msg: str).
|
||||
"""
|
||||
ready = self._startup_event.wait(timeout=timeout)
|
||||
if not ready:
|
||||
return True, ""
|
||||
if self._startup_error:
|
||||
return False, self._startup_error
|
||||
return True, ""
|
||||
|
||||
def stop(self):
|
||||
"""
|
||||
stop channel gracefully, called before restart
|
||||
"""
|
||||
pass
|
||||
|
||||
def handle_text(self, msg):
|
||||
"""
|
||||
process received msg
|
||||
@@ -51,11 +83,14 @@ class Channel(object):
|
||||
if context and "channel_type" not in context:
|
||||
context["channel_type"] = self.channel_type
|
||||
|
||||
# Read on_event callback injected by the channel (e.g. web SSE)
|
||||
on_event = context.get("on_event") if context else None
|
||||
|
||||
# Use agent bridge to handle the query
|
||||
return Bridge().fetch_agent_reply(
|
||||
query=query,
|
||||
context=context,
|
||||
on_event=None,
|
||||
on_event=on_event,
|
||||
clear_history=False
|
||||
)
|
||||
except Exception as e:
|
||||
|
||||
@@ -12,16 +12,7 @@ def create_channel(channel_type) -> Channel:
|
||||
:return: channel instance
|
||||
"""
|
||||
ch = Channel()
|
||||
if channel_type == "wx":
|
||||
from channel.wechat.wechat_channel import WechatChannel
|
||||
ch = WechatChannel()
|
||||
elif channel_type == "wxy":
|
||||
from channel.wechat.wechaty_channel import WechatyChannel
|
||||
ch = WechatyChannel()
|
||||
elif channel_type == "wcf":
|
||||
from channel.wechat.wcf_channel import WechatfChannel
|
||||
ch = WechatfChannel()
|
||||
elif channel_type == "terminal":
|
||||
if channel_type == "terminal":
|
||||
from channel.terminal.terminal_channel import TerminalChannel
|
||||
ch = TerminalChannel()
|
||||
elif channel_type == 'web':
|
||||
@@ -36,15 +27,22 @@ def create_channel(channel_type) -> Channel:
|
||||
elif channel_type == "wechatcom_app":
|
||||
from channel.wechatcom.wechatcomapp_channel import WechatComAppChannel
|
||||
ch = WechatComAppChannel()
|
||||
elif channel_type == "wework":
|
||||
from channel.wework.wework_channel import WeworkChannel
|
||||
ch = WeworkChannel()
|
||||
elif channel_type == const.FEISHU:
|
||||
from channel.feishu.feishu_channel import FeiShuChanel
|
||||
ch = FeiShuChanel()
|
||||
elif channel_type == const.DINGTALK:
|
||||
from channel.dingtalk.dingtalk_channel import DingTalkChanel
|
||||
ch = DingTalkChanel()
|
||||
elif channel_type == const.WECOM_BOT:
|
||||
from channel.wecom_bot.wecom_bot_channel import WecomBotChannel
|
||||
ch = WecomBotChannel()
|
||||
elif channel_type == const.QQ:
|
||||
from channel.qq.qq_channel import QQChannel
|
||||
ch = QQChannel()
|
||||
elif channel_type in (const.WEIXIN, "wx"):
|
||||
from channel.weixin.weixin_channel import WeixinChannel
|
||||
ch = WeixinChannel()
|
||||
channel_type = const.WEIXIN
|
||||
else:
|
||||
raise RuntimeError
|
||||
ch.channel_type = channel_type
|
||||
|
||||
@@ -24,11 +24,17 @@ handler_pool = ThreadPoolExecutor(max_workers=8) # 处理消息的线程池
|
||||
class ChatChannel(Channel):
|
||||
name = None # 登录的用户名
|
||||
user_id = None # 登录的用户id
|
||||
futures = {} # 记录每个session_id提交到线程池的future对象, 用于重置会话时把没执行的future取消掉,正在执行的不会被取消
|
||||
sessions = {} # 用于控制并发,每个session_id同时只能有一个context在处理
|
||||
lock = threading.Lock() # 用于控制对sessions的访问
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
# Instance-level attributes so each channel subclass has its own
|
||||
# independent session queue and lock. Previously these were class-level,
|
||||
# which caused contexts from one channel (e.g. Feishu) to be consumed
|
||||
# by another channel's consume() thread (e.g. Web), leading to errors
|
||||
# like "No request_id found in context".
|
||||
self.futures = {}
|
||||
self.sessions = {}
|
||||
self.lock = threading.Lock()
|
||||
_thread = threading.Thread(target=self.consume)
|
||||
_thread.setDaemon(True)
|
||||
_thread.start()
|
||||
@@ -37,9 +43,8 @@ class ChatChannel(Channel):
|
||||
def _compose_context(self, ctype: ContextType, content, **kwargs):
|
||||
context = Context(ctype, content)
|
||||
context.kwargs = kwargs
|
||||
# context首次传入时,origin_ctype是None,
|
||||
# 引入的起因是:当输入语音时,会嵌套生成两个context,第一步语音转文本,第二步通过文本生成文字回复。
|
||||
# origin_ctype用于第二步文本回复时,判断是否需要匹配前缀,如果是私聊的语音,就不需要匹配前缀
|
||||
if "channel_type" not in context:
|
||||
context["channel_type"] = self.channel_type
|
||||
if "origin_ctype" not in context:
|
||||
context["origin_ctype"] = ctype
|
||||
# context首次传入时,receiver是None,根据类型设置receiver
|
||||
@@ -292,8 +297,12 @@ class ChatChannel(Channel):
|
||||
logger.debug("[chat_channel] sending reply: {}, context: {}".format(reply, context))
|
||||
|
||||
# 如果是文本回复,尝试提取并发送图片
|
||||
if reply.type == ReplyType.TEXT:
|
||||
# Web channel renders images/videos inline via renderMarkdown,
|
||||
# so skip the extract-and-send step to avoid duplicate media.
|
||||
if reply.type == ReplyType.TEXT and context.get("channel_type") != "web":
|
||||
self._extract_and_send_images(reply, context)
|
||||
elif reply.type == ReplyType.TEXT:
|
||||
self._send(reply, context)
|
||||
# 如果是图片回复但带有文本内容,先发文本再发图片
|
||||
elif reply.type == ReplyType.IMAGE_URL and hasattr(reply, 'text_content') and reply.text_content:
|
||||
# 先发送文本
|
||||
@@ -342,38 +351,30 @@ class ChatChannel(Channel):
|
||||
if media_items:
|
||||
logger.info(f"[chat_channel] Extracted {len(media_items)} media item(s) from reply")
|
||||
|
||||
# 先发送文本(保持原文本不变)
|
||||
# Send text first (the frontend will embed video players via renderMarkdown).
|
||||
logger.info(f"[chat_channel] Sending text content before media: {reply.content[:100]}...")
|
||||
self._send(reply, context)
|
||||
logger.info(f"[chat_channel] Text sent, now sending {len(media_items)} media item(s)")
|
||||
|
||||
# 然后逐个发送媒体文件
|
||||
for i, (url, media_type) in enumerate(media_items):
|
||||
try:
|
||||
# 判断是本地文件还是URL
|
||||
# Determine whether it is a remote URL or a local file.
|
||||
if url.startswith(('http://', 'https://')):
|
||||
# 网络资源
|
||||
if media_type == 'video':
|
||||
# 视频使用 FILE 类型发送
|
||||
media_reply = Reply(ReplyType.FILE, url)
|
||||
media_reply.file_name = os.path.basename(url)
|
||||
else:
|
||||
# 图片使用 IMAGE_URL 类型
|
||||
media_reply = Reply(ReplyType.IMAGE_URL, url)
|
||||
elif os.path.exists(url):
|
||||
# 本地文件
|
||||
if media_type == 'video':
|
||||
# 视频使用 FILE 类型,转换为 file:// URL
|
||||
media_reply = Reply(ReplyType.FILE, f"file://{url}")
|
||||
media_reply.file_name = os.path.basename(url)
|
||||
else:
|
||||
# 图片使用 IMAGE_URL 类型,转换为 file:// URL
|
||||
media_reply = Reply(ReplyType.IMAGE_URL, f"file://{url}")
|
||||
else:
|
||||
logger.warning(f"[chat_channel] Media file not found or invalid URL: {url}")
|
||||
continue
|
||||
|
||||
# 发送媒体文件(添加小延迟避免频率限制)
|
||||
if i > 0:
|
||||
time.sleep(0.5)
|
||||
self._send(media_reply, context)
|
||||
@@ -426,7 +427,7 @@ class ChatChannel(Channel):
|
||||
if session_id not in self.sessions:
|
||||
self.sessions[session_id] = [
|
||||
Dequeue(),
|
||||
threading.BoundedSemaphore(conf().get("concurrency_in_session", 4)),
|
||||
threading.BoundedSemaphore(conf().get("concurrency_in_session", 1)),
|
||||
]
|
||||
if context.type == ContextType.TEXT and context.content.startswith("#"):
|
||||
self.sessions[session_id][0].putleft(context) # 优先处理管理命令
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
"""
|
||||
本类表示聊天消息,用于对itchat和wechaty的消息进行统一的封装。
|
||||
Unified chat message class for different channel implementations.
|
||||
|
||||
填好必填项(群聊6个,非群聊8个),即可接入ChatChannel,并支持插件,参考TerminalChannel
|
||||
|
||||
|
||||
@@ -90,13 +90,9 @@ class DingTalkChanel(ChatChannel, dingtalk_stream.ChatbotHandler):
|
||||
dingtalk_client_secret = conf().get('dingtalk_client_secret')
|
||||
|
||||
def setup_logger(self):
|
||||
logger = logging.getLogger()
|
||||
handler = logging.StreamHandler()
|
||||
handler.setFormatter(
|
||||
logging.Formatter('%(asctime)s %(name)-8s %(levelname)-8s %(message)s [%(filename)s:%(lineno)d]'))
|
||||
logger.addHandler(handler)
|
||||
logger.setLevel(logging.INFO)
|
||||
return logger
|
||||
# Suppress verbose logs from dingtalk_stream SDK
|
||||
logging.getLogger("dingtalk_stream").setLevel(logging.WARNING)
|
||||
return logging.getLogger("DingTalk")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
@@ -104,6 +100,9 @@ class DingTalkChanel(ChatChannel, dingtalk_stream.ChatbotHandler):
|
||||
self.logger = self.setup_logger()
|
||||
# 历史消息id暂存,用于幂等控制
|
||||
self.receivedMsgs = ExpiredDict(conf().get("expires_in_seconds", 3600))
|
||||
self._stream_client = None
|
||||
self._running = False
|
||||
self._event_loop = None
|
||||
logger.debug("[DingTalk] client_id={}, client_secret={} ".format(
|
||||
self.dingtalk_client_id, self.dingtalk_client_secret))
|
||||
# 无需群校验和前缀
|
||||
@@ -116,12 +115,130 @@ class DingTalkChanel(ChatChannel, dingtalk_stream.ChatbotHandler):
|
||||
# Robot code cache (extracted from incoming messages)
|
||||
self._robot_code = None
|
||||
|
||||
def _open_connection(self, client):
|
||||
"""
|
||||
Open a DingTalk stream connection directly, bypassing SDK's internal error-swallowing.
|
||||
Returns (connection_dict, error_str). On success error_str is empty; on failure
|
||||
connection_dict is None and error_str contains a human-readable message.
|
||||
"""
|
||||
try:
|
||||
resp = requests.post(
|
||||
"https://api.dingtalk.com/v1.0/gateway/connections/open",
|
||||
headers={"Content-Type": "application/json", "Accept": "application/json"},
|
||||
json={
|
||||
"clientId": client.credential.client_id,
|
||||
"clientSecret": client.credential.client_secret,
|
||||
"subscriptions": [{"type": "CALLBACK",
|
||||
"topic": dingtalk_stream.chatbot.ChatbotMessage.TOPIC}],
|
||||
"ua": "dingtalk-sdk-python/cow",
|
||||
"localIp": "",
|
||||
},
|
||||
timeout=10,
|
||||
)
|
||||
body = resp.json()
|
||||
if not resp.ok:
|
||||
code = body.get("code", resp.status_code)
|
||||
message = body.get("message", resp.reason)
|
||||
return None, f"open connection failed: [{code}] {message}"
|
||||
return body, ""
|
||||
except Exception as e:
|
||||
return None, f"open connection failed: {e}"
|
||||
|
||||
def startup(self):
|
||||
import asyncio
|
||||
self.dingtalk_client_id = conf().get('dingtalk_client_id')
|
||||
self.dingtalk_client_secret = conf().get('dingtalk_client_secret')
|
||||
self._running = True
|
||||
credential = dingtalk_stream.Credential(self.dingtalk_client_id, self.dingtalk_client_secret)
|
||||
client = dingtalk_stream.DingTalkStreamClient(credential)
|
||||
self._stream_client = client
|
||||
client.register_callback_handler(dingtalk_stream.chatbot.ChatbotMessage.TOPIC, self)
|
||||
logger.info("[DingTalk] ✅ Stream connected, ready to receive messages")
|
||||
client.start_forever()
|
||||
logger.info("[DingTalk] ✅ Stream client initialized, ready to receive messages")
|
||||
|
||||
# Run the connection loop ourselves instead of delegating to client.start(),
|
||||
# so we can get detailed error messages and respond to stop() quickly.
|
||||
import urllib.parse as _urlparse
|
||||
import websockets as _ws
|
||||
import json as _json
|
||||
client.pre_start()
|
||||
_first_connect = True
|
||||
while self._running:
|
||||
# Open connection using our own request so we get detailed error info.
|
||||
connection, err_msg = self._open_connection(client)
|
||||
|
||||
if connection is None:
|
||||
if _first_connect:
|
||||
logger.warning(f"[DingTalk] {err_msg}")
|
||||
self.report_startup_error(err_msg)
|
||||
_first_connect = False
|
||||
else:
|
||||
logger.warning(f"[DingTalk] {err_msg}, retrying in 10s...")
|
||||
|
||||
# Interruptible sleep: checks _running every 100ms.
|
||||
for _ in range(100):
|
||||
if not self._running:
|
||||
break
|
||||
time.sleep(0.1)
|
||||
continue
|
||||
|
||||
if _first_connect:
|
||||
logger.info("[DingTalk] ✅ Connected to DingTalk stream")
|
||||
self.report_startup_success()
|
||||
_first_connect = False
|
||||
else:
|
||||
logger.info("[DingTalk] Reconnected to DingTalk stream")
|
||||
|
||||
# Run the WebSocket session in an asyncio loop.
|
||||
uri = '%s?ticket=%s' % (
|
||||
connection['endpoint'],
|
||||
_urlparse.quote_plus(connection['ticket'])
|
||||
)
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
self._event_loop = loop
|
||||
try:
|
||||
async def _session():
|
||||
async with _ws.connect(uri) as websocket:
|
||||
client.websocket = websocket
|
||||
async for raw_message in websocket:
|
||||
json_message = _json.loads(raw_message)
|
||||
result = await client.route_message(json_message)
|
||||
if result == dingtalk_stream.DingTalkStreamClient.TAG_DISCONNECT:
|
||||
break
|
||||
|
||||
loop.run_until_complete(_session())
|
||||
except (KeyboardInterrupt, SystemExit):
|
||||
logger.info("[DingTalk] Session loop received stop signal, exiting")
|
||||
break
|
||||
except Exception as e:
|
||||
if not self._running:
|
||||
break
|
||||
logger.warning(f"[DingTalk] Stream session error: {e}, reconnecting in 3s...")
|
||||
for _ in range(30):
|
||||
if not self._running:
|
||||
break
|
||||
time.sleep(0.1)
|
||||
finally:
|
||||
self._event_loop = None
|
||||
try:
|
||||
loop.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
logger.info("[DingTalk] Startup loop exited")
|
||||
|
||||
def stop(self):
|
||||
logger.info("[DingTalk] stop() called, setting _running=False")
|
||||
self._running = False
|
||||
loop = self._event_loop
|
||||
if loop and not loop.is_closed():
|
||||
try:
|
||||
loop.call_soon_threadsafe(loop.stop)
|
||||
logger.info("[DingTalk] Sent stop signal to event loop")
|
||||
except Exception as e:
|
||||
logger.warning(f"[DingTalk] Error stopping event loop: {e}")
|
||||
self._stream_client = None
|
||||
logger.info("[DingTalk] stop() completed")
|
||||
|
||||
def get_access_token(self):
|
||||
"""
|
||||
@@ -458,23 +575,21 @@ class DingTalkChanel(ChatChannel, dingtalk_stream.ChatbotHandler):
|
||||
async def process(self, callback: dingtalk_stream.CallbackMessage):
|
||||
try:
|
||||
incoming_message = dingtalk_stream.ChatbotMessage.from_dict(callback.data)
|
||||
|
||||
|
||||
# 缓存 robot_code,用于后续图片下载
|
||||
if hasattr(incoming_message, 'robot_code'):
|
||||
self._robot_code_cache = incoming_message.robot_code
|
||||
|
||||
# Debug: 打印完整的 event 数据
|
||||
logger.debug(f"[DingTalk] ===== Incoming Message Debug =====")
|
||||
logger.debug(f"[DingTalk] callback.data keys: {callback.data.keys() if hasattr(callback.data, 'keys') else 'N/A'}")
|
||||
logger.debug(f"[DingTalk] incoming_message attributes: {dir(incoming_message)}")
|
||||
logger.debug(f"[DingTalk] robot_code: {getattr(incoming_message, 'robot_code', 'N/A')}")
|
||||
logger.debug(f"[DingTalk] chatbot_corp_id: {getattr(incoming_message, 'chatbot_corp_id', 'N/A')}")
|
||||
logger.debug(f"[DingTalk] chatbot_user_id: {getattr(incoming_message, 'chatbot_user_id', 'N/A')}")
|
||||
logger.debug(f"[DingTalk] conversation_id: {getattr(incoming_message, 'conversation_id', 'N/A')}")
|
||||
logger.debug(f"[DingTalk] Raw callback.data: {callback.data}")
|
||||
logger.debug(f"[DingTalk] =====================================")
|
||||
|
||||
image_download_handler = self # 传入方法所在的类实例
|
||||
|
||||
# Filter out stale messages from before channel startup (offline backlog)
|
||||
create_at = getattr(incoming_message, 'create_at', None)
|
||||
if create_at:
|
||||
msg_age_s = time.time() - int(create_at) / 1000
|
||||
if msg_age_s > 60:
|
||||
logger.warning(f"[DingTalk] stale msg filtered (age={msg_age_s:.0f}s), "
|
||||
f"msg_id={getattr(incoming_message, 'message_id', 'N/A')}")
|
||||
return AckMessage.STATUS_OK, 'OK'
|
||||
|
||||
image_download_handler = self
|
||||
dingtalk_msg = DingTalkMessage(incoming_message, image_download_handler)
|
||||
|
||||
if dingtalk_msg.is_group:
|
||||
@@ -483,8 +598,7 @@ class DingTalkChanel(ChatChannel, dingtalk_stream.ChatbotHandler):
|
||||
self.handle_single(dingtalk_msg)
|
||||
return AckMessage.STATUS_OK, 'OK'
|
||||
except Exception as e:
|
||||
logger.error(f"[DingTalk] process error: {e}")
|
||||
logger.exception(e) # 打印完整堆栈跟踪
|
||||
logger.error(f"[DingTalk] process error: {e}", exc_info=True)
|
||||
return AckMessage.STATUS_SYSTEM_EXCEPTION, 'ERROR'
|
||||
|
||||
@time_checker
|
||||
|
||||
@@ -11,7 +11,9 @@
|
||||
@Date 2023/11/19
|
||||
"""
|
||||
|
||||
import importlib.util
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import ssl
|
||||
import threading
|
||||
@@ -32,17 +34,25 @@ from common.log import logger
|
||||
from common.singleton import singleton
|
||||
from config import conf
|
||||
|
||||
# Suppress verbose logs from Lark SDK
|
||||
logging.getLogger("Lark").setLevel(logging.WARNING)
|
||||
|
||||
URL_VERIFICATION = "url_verification"
|
||||
|
||||
# 尝试导入飞书SDK,如果未安装则websocket模式不可用
|
||||
try:
|
||||
import lark_oapi as lark
|
||||
# Lazy-check for lark_oapi SDK availability without importing it at module level.
|
||||
# The full `import lark_oapi` pulls in 10k+ files and takes 4-10s, so we defer
|
||||
# the actual import to _startup_websocket() where it is needed.
|
||||
LARK_SDK_AVAILABLE = importlib.util.find_spec("lark_oapi") is not None
|
||||
lark = None # will be populated on first use via _ensure_lark_imported()
|
||||
|
||||
LARK_SDK_AVAILABLE = True
|
||||
except ImportError:
|
||||
LARK_SDK_AVAILABLE = False
|
||||
logger.warning(
|
||||
"[FeiShu] lark_oapi not installed, websocket mode is not available. Install with: pip install lark-oapi")
|
||||
|
||||
def _ensure_lark_imported():
|
||||
"""Import lark_oapi on first use (takes 4-10s due to 10k+ source files)."""
|
||||
global lark
|
||||
if lark is None:
|
||||
import lark_oapi as _lark
|
||||
lark = _lark
|
||||
return lark
|
||||
|
||||
|
||||
@singleton
|
||||
@@ -56,6 +66,10 @@ class FeiShuChanel(ChatChannel):
|
||||
super().__init__()
|
||||
# 历史消息id暂存,用于幂等控制
|
||||
self.receivedMsgs = ExpiredDict(60 * 60 * 7.1)
|
||||
self._http_server = None
|
||||
self._ws_client = None
|
||||
self._ws_thread = None
|
||||
self._bot_open_id = None # cached bot open_id for @-mention matching
|
||||
logger.debug("[FeiShu] app_id={}, app_secret={}, verification_token={}, event_mode={}".format(
|
||||
self.feishu_app_id, self.feishu_app_secret, self.feishu_token, self.feishu_event_mode))
|
||||
# 无需群校验和前缀
|
||||
@@ -68,11 +82,66 @@ class FeiShuChanel(ChatChannel):
|
||||
raise Exception("lark_oapi not installed")
|
||||
|
||||
def startup(self):
|
||||
self.feishu_app_id = conf().get('feishu_app_id')
|
||||
self.feishu_app_secret = conf().get('feishu_app_secret')
|
||||
self.feishu_token = conf().get('feishu_token')
|
||||
self.feishu_event_mode = conf().get('feishu_event_mode', 'websocket')
|
||||
self._fetch_bot_open_id()
|
||||
if self.feishu_event_mode == 'websocket':
|
||||
self._startup_websocket()
|
||||
else:
|
||||
self._startup_webhook()
|
||||
|
||||
def _fetch_bot_open_id(self):
|
||||
"""Fetch the bot's own open_id via API so we can match @-mentions without feishu_bot_name."""
|
||||
try:
|
||||
access_token = self.fetch_access_token()
|
||||
if not access_token:
|
||||
logger.warning("[FeiShu] Cannot fetch bot info: no access_token")
|
||||
return
|
||||
headers = {"Authorization": "Bearer " + access_token}
|
||||
resp = requests.get("https://open.feishu.cn/open-apis/bot/v3/info/", headers=headers, timeout=5)
|
||||
if resp.status_code == 200:
|
||||
data = resp.json()
|
||||
if data.get("code") == 0:
|
||||
self._bot_open_id = data.get("bot", {}).get("open_id")
|
||||
logger.info(f"[FeiShu] Bot open_id fetched: {self._bot_open_id}")
|
||||
else:
|
||||
logger.warning(f"[FeiShu] Fetch bot info failed: code={data.get('code')}, msg={data.get('msg')}")
|
||||
except Exception as e:
|
||||
logger.warning(f"[FeiShu] Fetch bot open_id error: {e}")
|
||||
|
||||
def stop(self):
|
||||
import ctypes
|
||||
logger.info("[FeiShu] stop() called")
|
||||
ws_client = self._ws_client
|
||||
self._ws_client = None
|
||||
ws_thread = self._ws_thread
|
||||
self._ws_thread = None
|
||||
# Interrupt the ws thread first so its blocking start() unblocks
|
||||
if ws_thread and ws_thread.is_alive():
|
||||
try:
|
||||
tid = ws_thread.ident
|
||||
if tid:
|
||||
res = ctypes.pythonapi.PyThreadState_SetAsyncExc(
|
||||
ctypes.c_ulong(tid), ctypes.py_object(SystemExit)
|
||||
)
|
||||
if res == 1:
|
||||
logger.info("[FeiShu] Interrupted ws thread via ctypes")
|
||||
elif res > 1:
|
||||
ctypes.pythonapi.PyThreadState_SetAsyncExc(ctypes.c_ulong(tid), None)
|
||||
except Exception as e:
|
||||
logger.warning(f"[FeiShu] Error interrupting ws thread: {e}")
|
||||
# lark.ws.Client has no stop() method; thread interruption above is sufficient
|
||||
if self._http_server:
|
||||
try:
|
||||
self._http_server.stop()
|
||||
logger.info("[FeiShu] HTTP server stopped")
|
||||
except Exception as e:
|
||||
logger.warning(f"[FeiShu] Error stopping HTTP server: {e}")
|
||||
self._http_server = None
|
||||
logger.info("[FeiShu] stop() completed")
|
||||
|
||||
def _startup_webhook(self):
|
||||
"""启动HTTP服务器接收事件(webhook模式)"""
|
||||
logger.debug("[FeiShu] Starting in webhook mode...")
|
||||
@@ -81,21 +150,33 @@ class FeiShuChanel(ChatChannel):
|
||||
)
|
||||
app = web.application(urls, globals(), autoreload=False)
|
||||
port = conf().get("feishu_port", 9891)
|
||||
web.httpserver.runsimple(app.wsgifunc(), ("0.0.0.0", port))
|
||||
func = web.httpserver.StaticMiddleware(app.wsgifunc())
|
||||
func = web.httpserver.LogMiddleware(func)
|
||||
server = web.httpserver.WSGIServer(("0.0.0.0", port), func)
|
||||
self._http_server = server
|
||||
try:
|
||||
server.start()
|
||||
except (KeyboardInterrupt, SystemExit):
|
||||
server.stop()
|
||||
|
||||
def _startup_websocket(self):
|
||||
"""启动长连接接收事件(websocket模式)"""
|
||||
_ensure_lark_imported()
|
||||
logger.debug("[FeiShu] Starting in websocket mode...")
|
||||
|
||||
# 创建事件处理器
|
||||
def handle_message_event(data: lark.im.v1.P2ImMessageReceiveV1) -> None:
|
||||
"""处理接收消息事件 v2.0"""
|
||||
try:
|
||||
logger.debug(f"[FeiShu] websocket receive event: {lark.JSON.marshal(data, indent=2)}")
|
||||
|
||||
# 转换为标准的event格式
|
||||
event_dict = json.loads(lark.JSON.marshal(data))
|
||||
event = event_dict.get("event", {})
|
||||
msg = event.get("message", {})
|
||||
|
||||
# Skip group messages that don't @-mention the bot (reduce log noise)
|
||||
if msg.get("chat_type") == "group" and not msg.get("mentions") and msg.get("message_type") == "text":
|
||||
return
|
||||
|
||||
logger.debug(f"[FeiShu] websocket receive event: {lark.JSON.marshal(data, indent=2)}")
|
||||
|
||||
# 处理消息
|
||||
self._handle_message_event(event)
|
||||
@@ -108,29 +189,36 @@ class FeiShuChanel(ChatChannel):
|
||||
.register_p2_im_message_receive_v1(handle_message_event) \
|
||||
.build()
|
||||
|
||||
# 尝试连接,如果遇到SSL错误则自动禁用证书验证
|
||||
def start_client_with_retry():
|
||||
"""启动websocket客户端,自动处理SSL证书错误"""
|
||||
# 全局禁用SSL证书验证(在导入lark_oapi之前设置)
|
||||
"""Run ws client in this thread with its own event loop to avoid conflicts."""
|
||||
import asyncio
|
||||
import ssl as ssl_module
|
||||
|
||||
# 保存原始的SSL上下文创建方法
|
||||
original_create_default_context = ssl_module.create_default_context
|
||||
|
||||
def create_unverified_context(*args, **kwargs):
|
||||
"""创建一个不验证证书的SSL上下文"""
|
||||
context = original_create_default_context(*args, **kwargs)
|
||||
context.check_hostname = False
|
||||
context.verify_mode = ssl.CERT_NONE
|
||||
return context
|
||||
|
||||
# 尝试正常连接,如果失败则禁用SSL验证
|
||||
# lark_oapi.ws.client captures the event loop at module-import time as a module-
|
||||
# level global variable. When a previous ws thread is force-killed via ctypes its
|
||||
# loop may still be marked as "running", which causes the next ws_client.start()
|
||||
# call (in this new thread) to raise "This event loop is already running".
|
||||
# Fix: replace the module-level loop with a brand-new, idle loop before starting.
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
import lark_oapi.ws.client as _lark_ws_client_mod
|
||||
_lark_ws_client_mod.loop = loop
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
startup_error = None
|
||||
for attempt in range(2):
|
||||
try:
|
||||
if attempt == 1:
|
||||
# 第二次尝试:禁用SSL验证
|
||||
logger.warning("[FeiShu] SSL certificate verification disabled due to certificate error. "
|
||||
"This may happen when using corporate proxy or self-signed certificates.")
|
||||
logger.warning("[FeiShu] Retrying with SSL verification disabled...")
|
||||
ssl_module.create_default_context = create_unverified_context
|
||||
ssl_module._create_unverified_context = create_unverified_context
|
||||
|
||||
@@ -138,41 +226,62 @@ class FeiShuChanel(ChatChannel):
|
||||
self.feishu_app_id,
|
||||
self.feishu_app_secret,
|
||||
event_handler=event_handler,
|
||||
log_level=lark.LogLevel.DEBUG if conf().get("debug") else lark.LogLevel.INFO
|
||||
log_level=lark.LogLevel.WARNING
|
||||
)
|
||||
|
||||
self._ws_client = ws_client
|
||||
logger.debug("[FeiShu] Websocket client starting...")
|
||||
ws_client.start()
|
||||
# 如果成功启动,跳出循环
|
||||
break
|
||||
|
||||
except (SystemExit, KeyboardInterrupt):
|
||||
logger.info("[FeiShu] Websocket thread received stop signal")
|
||||
break
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
# 检查是否是SSL证书验证错误
|
||||
is_ssl_error = "CERTIFICATE_VERIFY_FAILED" in error_msg or "certificate verify failed" in error_msg.lower()
|
||||
|
||||
is_ssl_error = ("CERTIFICATE_VERIFY_FAILED" in error_msg
|
||||
or "certificate verify failed" in error_msg.lower())
|
||||
if is_ssl_error and attempt == 0:
|
||||
# 第一次遇到SSL错误,记录日志并继续循环(下次会禁用验证)
|
||||
logger.warning(f"[FeiShu] SSL certificate verification failed: {error_msg}")
|
||||
logger.info("[FeiShu] Retrying connection with SSL verification disabled...")
|
||||
logger.warning(f"[FeiShu] SSL error: {error_msg}, retrying...")
|
||||
continue
|
||||
else:
|
||||
# 其他错误或禁用验证后仍失败,抛出异常
|
||||
logger.error(f"[FeiShu] Websocket client error: {e}", exc_info=True)
|
||||
# 恢复原始方法
|
||||
ssl_module.create_default_context = original_create_default_context
|
||||
raise
|
||||
logger.error(f"[FeiShu] Websocket client error: {e}", exc_info=True)
|
||||
startup_error = error_msg
|
||||
ssl_module.create_default_context = original_create_default_context
|
||||
break
|
||||
if startup_error:
|
||||
self.report_startup_error(startup_error)
|
||||
try:
|
||||
loop.close()
|
||||
except Exception:
|
||||
pass
|
||||
logger.info("[FeiShu] Websocket thread exited")
|
||||
|
||||
# 注意:不恢复原始方法,因为ws_client.start()会持续运行
|
||||
|
||||
# 在新线程中启动客户端,避免阻塞主线程
|
||||
ws_thread = threading.Thread(target=start_client_with_retry, daemon=True)
|
||||
self._ws_thread = ws_thread
|
||||
ws_thread.start()
|
||||
|
||||
# 保持主线程运行
|
||||
logger.info("[FeiShu] ✅ Websocket connected, ready to receive messages")
|
||||
logger.info("[FeiShu] ✅ Websocket thread started, ready to receive messages")
|
||||
ws_thread.join()
|
||||
|
||||
def _is_mention_bot(self, mentions: list) -> bool:
|
||||
"""Check whether any mention in the list refers to this bot.
|
||||
|
||||
Priority:
|
||||
1. Match by open_id (obtained from /bot/v3/info at startup, no config needed)
|
||||
2. Fallback to feishu_bot_name config for backward compatibility
|
||||
3. If neither is available, assume the first mention is the bot (Feishu only
|
||||
delivers group messages that @-mention the bot, so this is usually correct)
|
||||
"""
|
||||
if self._bot_open_id:
|
||||
return any(
|
||||
m.get("id", {}).get("open_id") == self._bot_open_id
|
||||
for m in mentions
|
||||
)
|
||||
bot_name = conf().get("feishu_bot_name")
|
||||
if bot_name:
|
||||
return any(m.get("name") == bot_name for m in mentions)
|
||||
# Feishu event subscription only delivers messages that @-mention the bot,
|
||||
# so reaching here means the bot was indeed mentioned.
|
||||
return True
|
||||
|
||||
def _handle_message_event(self, event: dict):
|
||||
"""
|
||||
处理消息事件的核心逻辑
|
||||
@@ -191,6 +300,15 @@ class FeiShuChanel(ChatChannel):
|
||||
return
|
||||
self.receivedMsgs[msg_id] = True
|
||||
|
||||
# Filter out stale messages from before channel startup (offline backlog)
|
||||
import time as _time
|
||||
create_time_ms = msg.get("create_time")
|
||||
if create_time_ms:
|
||||
msg_age_s = _time.time() - int(create_time_ms) / 1000
|
||||
if msg_age_s > 60:
|
||||
logger.warning(f"[FeiShu] stale msg filtered (age={msg_age_s:.0f}s), msg_id={msg_id}")
|
||||
return
|
||||
|
||||
is_group = False
|
||||
chat_type = msg.get("chat_type")
|
||||
|
||||
@@ -198,10 +316,9 @@ class FeiShuChanel(ChatChannel):
|
||||
if not msg.get("mentions") and msg.get("message_type") == "text":
|
||||
# 群聊中未@不响应
|
||||
return
|
||||
if msg.get("mentions") and msg.get("mentions")[0].get("name") != conf().get("feishu_bot_name") and msg.get(
|
||||
"message_type") == "text":
|
||||
# 不是@机器人,不响应
|
||||
return
|
||||
if msg.get("mentions") and msg.get("message_type") == "text":
|
||||
if not self._is_mention_bot(msg.get("mentions")):
|
||||
return
|
||||
# 群聊
|
||||
is_group = True
|
||||
receive_id_type = "chat_id"
|
||||
@@ -337,7 +454,7 @@ class FeiShuChanel(ChatChannel):
|
||||
can_reply = is_group and msg and hasattr(msg, 'msg_id') and msg.msg_id
|
||||
|
||||
# Build content JSON
|
||||
content_json = json.dumps(reply_content) if content_key is None else json.dumps({content_key: reply_content})
|
||||
content_json = json.dumps(reply_content, ensure_ascii=False) if content_key is None else json.dumps({content_key: reply_content}, ensure_ascii=False)
|
||||
logger.debug(f"[FeiShu] Sending message: msg_type={msg_type}, content={content_json[:200]}")
|
||||
|
||||
if can_reply:
|
||||
@@ -677,6 +794,8 @@ class FeiShuChanel(ChatChannel):
|
||||
def _compose_context(self, ctype: ContextType, content, **kwargs):
|
||||
context = Context(ctype, content)
|
||||
context.kwargs = kwargs
|
||||
if "channel_type" not in context:
|
||||
context["channel_type"] = self.channel_type
|
||||
if "origin_ctype" not in context:
|
||||
context["origin_ctype"] = ctype
|
||||
|
||||
|
||||
0
channel/qq/__init__.py
Normal file
0
channel/qq/__init__.py
Normal file
736
channel/qq/qq_channel.py
Normal file
736
channel/qq/qq_channel.py
Normal file
@@ -0,0 +1,736 @@
|
||||
"""
|
||||
QQ Bot channel via WebSocket long connection.
|
||||
|
||||
Supports:
|
||||
- Group chat (@bot), single chat (C2C), guild channel, guild DM
|
||||
- Text / image / file message send & receive
|
||||
- Heartbeat keep-alive and auto-reconnect with session resume
|
||||
"""
|
||||
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
|
||||
import requests
|
||||
import websocket
|
||||
|
||||
from bridge.context import Context, ContextType
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from channel.chat_channel import ChatChannel, check_prefix
|
||||
from channel.qq.qq_message import QQMessage
|
||||
from common.expired_dict import ExpiredDict
|
||||
from common.log import logger
|
||||
from common.singleton import singleton
|
||||
from common.ws_client_compat import websocket_app_run_forever
|
||||
from config import conf
|
||||
|
||||
# Rich media file_type constants
|
||||
QQ_FILE_TYPE_IMAGE = 1
|
||||
QQ_FILE_TYPE_VIDEO = 2
|
||||
QQ_FILE_TYPE_VOICE = 3
|
||||
QQ_FILE_TYPE_FILE = 4
|
||||
|
||||
QQ_API_BASE = "https://api.sgroup.qq.com"
|
||||
|
||||
# Intents: GROUP_AND_C2C_EVENT(1<<25) | PUBLIC_GUILD_MESSAGES(1<<30)
|
||||
DEFAULT_INTENTS = (1 << 25) | (1 << 30)
|
||||
|
||||
# OpCode constants
|
||||
OP_DISPATCH = 0
|
||||
OP_HEARTBEAT = 1
|
||||
OP_IDENTIFY = 2
|
||||
OP_RESUME = 6
|
||||
OP_RECONNECT = 7
|
||||
OP_INVALID_SESSION = 9
|
||||
OP_HELLO = 10
|
||||
OP_HEARTBEAT_ACK = 11
|
||||
|
||||
# Resumable error codes
|
||||
RESUMABLE_CLOSE_CODES = {4008, 4009}
|
||||
|
||||
|
||||
@singleton
|
||||
class QQChannel(ChatChannel):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.app_id = ""
|
||||
self.app_secret = ""
|
||||
|
||||
self._access_token = ""
|
||||
self._token_expires_at = 0
|
||||
|
||||
self._ws = None
|
||||
self._ws_thread = None
|
||||
self._heartbeat_thread = None
|
||||
self._connected = False
|
||||
self._stop_event = threading.Event()
|
||||
self._token_lock = threading.Lock()
|
||||
|
||||
self._session_id = None
|
||||
self._last_seq = None
|
||||
self._heartbeat_interval = 45000
|
||||
self._can_resume = False
|
||||
|
||||
self.received_msgs = ExpiredDict(60 * 60 * 7.1)
|
||||
self._msg_seq_counter = {}
|
||||
|
||||
conf()["group_name_white_list"] = ["ALL_GROUP"]
|
||||
conf()["single_chat_prefix"] = [""]
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Lifecycle
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def startup(self):
|
||||
self.app_id = conf().get("qq_app_id", "")
|
||||
self.app_secret = conf().get("qq_app_secret", "")
|
||||
|
||||
if not self.app_id or not self.app_secret:
|
||||
err = "[QQ] qq_app_id and qq_app_secret are required"
|
||||
logger.error(err)
|
||||
self.report_startup_error(err)
|
||||
return
|
||||
|
||||
self._refresh_access_token()
|
||||
if not self._access_token:
|
||||
err = "[QQ] Failed to get initial access_token"
|
||||
logger.error(err)
|
||||
self.report_startup_error(err)
|
||||
return
|
||||
|
||||
self._stop_event.clear()
|
||||
self._start_ws()
|
||||
|
||||
def stop(self):
|
||||
logger.info("[QQ] stop() called")
|
||||
self._stop_event.set()
|
||||
if self._ws:
|
||||
try:
|
||||
self._ws.close()
|
||||
except Exception:
|
||||
pass
|
||||
self._ws = None
|
||||
self._connected = False
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Access Token
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _refresh_access_token(self):
|
||||
try:
|
||||
resp = requests.post(
|
||||
"https://bots.qq.com/app/getAppAccessToken",
|
||||
json={"appId": self.app_id, "clientSecret": self.app_secret},
|
||||
timeout=10,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
self._access_token = data.get("access_token", "")
|
||||
expires_in = int(data.get("expires_in", 7200))
|
||||
self._token_expires_at = time.time() + expires_in - 60
|
||||
logger.debug(f"[QQ] Access token refreshed, expires_in={expires_in}s")
|
||||
except Exception as e:
|
||||
logger.error(f"[QQ] Failed to refresh access_token: {e}")
|
||||
|
||||
def _get_access_token(self) -> str:
|
||||
with self._token_lock:
|
||||
if time.time() >= self._token_expires_at:
|
||||
self._refresh_access_token()
|
||||
return self._access_token
|
||||
|
||||
def _get_auth_headers(self) -> dict:
|
||||
return {
|
||||
"Authorization": f"QQBot {self._get_access_token()}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# WebSocket connection
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _get_ws_url(self) -> str:
|
||||
try:
|
||||
resp = requests.get(
|
||||
f"{QQ_API_BASE}/gateway",
|
||||
headers=self._get_auth_headers(),
|
||||
timeout=10,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
url = resp.json().get("url", "")
|
||||
logger.debug(f"[QQ] Gateway URL: {url}")
|
||||
return url
|
||||
except Exception as e:
|
||||
logger.error(f"[QQ] Failed to get gateway URL: {e}")
|
||||
return ""
|
||||
|
||||
def _start_ws(self):
|
||||
ws_url = self._get_ws_url()
|
||||
if not ws_url:
|
||||
logger.error("[QQ] Cannot start WebSocket without gateway URL")
|
||||
self.report_startup_error("Failed to get gateway URL")
|
||||
return
|
||||
|
||||
def _on_open(ws):
|
||||
logger.debug("[QQ] WebSocket connected, waiting for Hello...")
|
||||
|
||||
def _on_message(ws, raw):
|
||||
try:
|
||||
data = json.loads(raw)
|
||||
self._handle_ws_message(data)
|
||||
except Exception as e:
|
||||
logger.error(f"[QQ] Failed to handle ws message: {e}", exc_info=True)
|
||||
|
||||
def _on_error(ws, error):
|
||||
logger.error(f"[QQ] WebSocket error: {error}")
|
||||
|
||||
def _on_close(ws, close_status_code, close_msg):
|
||||
logger.warning(f"[QQ] WebSocket closed: status={close_status_code}, msg={close_msg}")
|
||||
self._connected = False
|
||||
if not self._stop_event.is_set():
|
||||
if close_status_code in RESUMABLE_CLOSE_CODES and self._session_id:
|
||||
self._can_resume = True
|
||||
logger.info("[QQ] Will attempt resume in 3s...")
|
||||
time.sleep(3)
|
||||
else:
|
||||
self._can_resume = False
|
||||
logger.info("[QQ] Will reconnect in 5s...")
|
||||
time.sleep(5)
|
||||
if not self._stop_event.is_set():
|
||||
self._start_ws()
|
||||
|
||||
self._ws = websocket.WebSocketApp(
|
||||
ws_url,
|
||||
on_open=_on_open,
|
||||
on_message=_on_message,
|
||||
on_error=_on_error,
|
||||
on_close=_on_close,
|
||||
)
|
||||
|
||||
def run_forever():
|
||||
try:
|
||||
websocket_app_run_forever(self._ws, ping_interval=0, reconnect=0)
|
||||
except (SystemExit, KeyboardInterrupt):
|
||||
logger.info("[QQ] WebSocket thread interrupted")
|
||||
except Exception as e:
|
||||
logger.error(f"[QQ] WebSocket run_forever error: {e}")
|
||||
|
||||
self._ws_thread = threading.Thread(target=run_forever, daemon=True)
|
||||
self._ws_thread.start()
|
||||
self._ws_thread.join()
|
||||
|
||||
def _ws_send(self, data: dict):
|
||||
if self._ws:
|
||||
self._ws.send(json.dumps(data, ensure_ascii=False))
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Identify & Resume & Heartbeat
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _send_identify(self):
|
||||
self._ws_send({
|
||||
"op": OP_IDENTIFY,
|
||||
"d": {
|
||||
"token": f"QQBot {self._get_access_token()}",
|
||||
"intents": DEFAULT_INTENTS,
|
||||
"shard": [0, 1],
|
||||
"properties": {
|
||||
"$os": "linux",
|
||||
"$browser": "chatgpt-on-wechat",
|
||||
"$device": "chatgpt-on-wechat",
|
||||
},
|
||||
},
|
||||
})
|
||||
logger.debug(f"[QQ] Identify sent with intents={DEFAULT_INTENTS}")
|
||||
|
||||
def _send_resume(self):
|
||||
self._ws_send({
|
||||
"op": OP_RESUME,
|
||||
"d": {
|
||||
"token": f"QQBot {self._get_access_token()}",
|
||||
"session_id": self._session_id,
|
||||
"seq": self._last_seq,
|
||||
},
|
||||
})
|
||||
logger.debug(f"[QQ] Resume sent: session_id={self._session_id}, seq={self._last_seq}")
|
||||
|
||||
def _start_heartbeat(self, interval_ms: int):
|
||||
if self._heartbeat_thread and self._heartbeat_thread.is_alive():
|
||||
return
|
||||
self._heartbeat_interval = interval_ms
|
||||
interval_sec = interval_ms / 1000.0
|
||||
|
||||
def heartbeat_loop():
|
||||
while not self._stop_event.is_set() and self._connected:
|
||||
try:
|
||||
self._ws_send({
|
||||
"op": OP_HEARTBEAT,
|
||||
"d": self._last_seq,
|
||||
})
|
||||
except Exception as e:
|
||||
logger.warning(f"[QQ] Heartbeat send failed: {e}")
|
||||
break
|
||||
self._stop_event.wait(interval_sec)
|
||||
|
||||
self._heartbeat_thread = threading.Thread(target=heartbeat_loop, daemon=True)
|
||||
self._heartbeat_thread.start()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Incoming message dispatch
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _handle_ws_message(self, data: dict):
|
||||
op = data.get("op")
|
||||
d = data.get("d")
|
||||
t = data.get("t")
|
||||
s = data.get("s")
|
||||
|
||||
if s is not None:
|
||||
self._last_seq = s
|
||||
|
||||
if op == OP_HELLO:
|
||||
heartbeat_interval = d.get("heartbeat_interval", 45000) if d else 45000
|
||||
logger.debug(f"[QQ] Received Hello, heartbeat_interval={heartbeat_interval}ms")
|
||||
self._heartbeat_interval = heartbeat_interval
|
||||
if self._can_resume and self._session_id:
|
||||
self._send_resume()
|
||||
else:
|
||||
self._send_identify()
|
||||
|
||||
elif op == OP_HEARTBEAT_ACK:
|
||||
pass
|
||||
|
||||
elif op == OP_HEARTBEAT:
|
||||
self._ws_send({"op": OP_HEARTBEAT, "d": self._last_seq})
|
||||
|
||||
elif op == OP_RECONNECT:
|
||||
logger.warning("[QQ] Server requested reconnect")
|
||||
self._can_resume = True
|
||||
if self._ws:
|
||||
self._ws.close()
|
||||
|
||||
elif op == OP_INVALID_SESSION:
|
||||
logger.warning("[QQ] Invalid session, re-identifying...")
|
||||
self._session_id = None
|
||||
self._can_resume = False
|
||||
time.sleep(2)
|
||||
self._send_identify()
|
||||
|
||||
elif op == OP_DISPATCH:
|
||||
if t == "READY":
|
||||
self._session_id = d.get("session_id", "")
|
||||
user = d.get("user", {})
|
||||
bot_name = user.get('username', '')
|
||||
logger.info(f"[QQ] ✅ Connected successfully (bot={bot_name})")
|
||||
self._connected = True
|
||||
self._can_resume = False
|
||||
self._start_heartbeat(self._heartbeat_interval)
|
||||
self.report_startup_success()
|
||||
|
||||
elif t == "RESUMED":
|
||||
logger.info("[QQ] Session resumed successfully")
|
||||
self._connected = True
|
||||
self._can_resume = False
|
||||
self._start_heartbeat(self._heartbeat_interval)
|
||||
|
||||
elif t in ("GROUP_AT_MESSAGE_CREATE", "C2C_MESSAGE_CREATE",
|
||||
"AT_MESSAGE_CREATE", "DIRECT_MESSAGE_CREATE"):
|
||||
self._handle_msg_event(d, t)
|
||||
|
||||
elif t in ("GROUP_ADD_ROBOT", "FRIEND_ADD"):
|
||||
logger.info(f"[QQ] Event: {t}")
|
||||
|
||||
else:
|
||||
logger.debug(f"[QQ] Dispatch event: {t}")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Message event handling
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _handle_msg_event(self, event_data: dict, event_type: str):
|
||||
msg_id = event_data.get("id", "")
|
||||
if self.received_msgs.get(msg_id):
|
||||
logger.debug(f"[QQ] Duplicate msg filtered: {msg_id}")
|
||||
return
|
||||
self.received_msgs[msg_id] = True
|
||||
|
||||
try:
|
||||
qq_msg = QQMessage(event_data, event_type)
|
||||
except NotImplementedError as e:
|
||||
logger.warning(f"[QQ] {e}")
|
||||
return
|
||||
except Exception as e:
|
||||
logger.error(f"[QQ] Failed to parse message: {e}", exc_info=True)
|
||||
return
|
||||
|
||||
is_group = qq_msg.is_group
|
||||
|
||||
from channel.file_cache import get_file_cache
|
||||
file_cache = get_file_cache()
|
||||
|
||||
if is_group:
|
||||
session_id = qq_msg.other_user_id
|
||||
else:
|
||||
session_id = qq_msg.from_user_id
|
||||
|
||||
if qq_msg.ctype == ContextType.IMAGE:
|
||||
if hasattr(qq_msg, "image_path") and qq_msg.image_path:
|
||||
file_cache.add(session_id, qq_msg.image_path, file_type="image")
|
||||
logger.info(f"[QQ] Image cached for session {session_id}")
|
||||
return
|
||||
|
||||
if qq_msg.ctype == ContextType.TEXT:
|
||||
cached_files = file_cache.get(session_id)
|
||||
if cached_files:
|
||||
file_refs = []
|
||||
for fi in cached_files:
|
||||
ftype = fi["type"]
|
||||
fpath = fi["path"]
|
||||
if ftype == "image":
|
||||
file_refs.append(f"[图片: {fpath}]")
|
||||
elif ftype == "video":
|
||||
file_refs.append(f"[视频: {fpath}]")
|
||||
else:
|
||||
file_refs.append(f"[文件: {fpath}]")
|
||||
qq_msg.content = qq_msg.content + "\n" + "\n".join(file_refs)
|
||||
logger.info(f"[QQ] Attached {len(cached_files)} cached file(s)")
|
||||
file_cache.clear(session_id)
|
||||
|
||||
context = self._compose_context(
|
||||
qq_msg.ctype,
|
||||
qq_msg.content,
|
||||
isgroup=is_group,
|
||||
msg=qq_msg,
|
||||
no_need_at=True,
|
||||
)
|
||||
if context:
|
||||
self.produce(context)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# _compose_context
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _compose_context(self, ctype: ContextType, content, **kwargs):
|
||||
context = Context(ctype, content)
|
||||
context.kwargs = kwargs
|
||||
if "channel_type" not in context:
|
||||
context["channel_type"] = self.channel_type
|
||||
if "origin_ctype" not in context:
|
||||
context["origin_ctype"] = ctype
|
||||
|
||||
cmsg = context["msg"]
|
||||
|
||||
if cmsg.is_group:
|
||||
context["session_id"] = cmsg.other_user_id
|
||||
else:
|
||||
context["session_id"] = cmsg.from_user_id
|
||||
|
||||
context["receiver"] = cmsg.other_user_id
|
||||
|
||||
if ctype == ContextType.TEXT:
|
||||
img_match_prefix = check_prefix(content, conf().get("image_create_prefix"))
|
||||
if img_match_prefix:
|
||||
content = content.replace(img_match_prefix, "", 1)
|
||||
context.type = ContextType.IMAGE_CREATE
|
||||
else:
|
||||
context.type = ContextType.TEXT
|
||||
context.content = content.strip()
|
||||
|
||||
return context
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Send reply
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def send(self, reply: Reply, context: Context):
|
||||
msg = context.get("msg")
|
||||
is_group = context.get("isgroup", False)
|
||||
receiver = context.get("receiver", "")
|
||||
|
||||
if not msg:
|
||||
# Active send (e.g. scheduled tasks), no original message to reply to
|
||||
self._active_send_text(reply.content if reply.type == ReplyType.TEXT else str(reply.content),
|
||||
receiver, is_group)
|
||||
return
|
||||
|
||||
event_type = getattr(msg, "event_type", "")
|
||||
msg_id = getattr(msg, "msg_id", "")
|
||||
|
||||
if reply.type == ReplyType.TEXT:
|
||||
self._send_text(reply.content, msg, event_type, msg_id)
|
||||
elif reply.type in (ReplyType.IMAGE_URL, ReplyType.IMAGE):
|
||||
self._send_image(reply.content, msg, event_type, msg_id)
|
||||
elif reply.type == ReplyType.FILE:
|
||||
if hasattr(reply, "text_content") and reply.text_content:
|
||||
self._send_text(reply.text_content, msg, event_type, msg_id)
|
||||
time.sleep(0.3)
|
||||
self._send_file(reply.content, msg, event_type, msg_id)
|
||||
elif reply.type in (ReplyType.VIDEO, ReplyType.VIDEO_URL):
|
||||
self._send_media(reply.content, msg, event_type, msg_id, QQ_FILE_TYPE_VIDEO)
|
||||
else:
|
||||
logger.warning(f"[QQ] Unsupported reply type: {reply.type}, falling back to text")
|
||||
self._send_text(str(reply.content), msg, event_type, msg_id)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Send helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _get_next_msg_seq(self, msg_id: str) -> int:
|
||||
seq = self._msg_seq_counter.get(msg_id, 1)
|
||||
self._msg_seq_counter[msg_id] = seq + 1
|
||||
return seq
|
||||
|
||||
def _build_msg_url_and_base_body(self, msg: QQMessage, event_type: str, msg_id: str):
|
||||
"""Build the API URL and base body dict for sending a message."""
|
||||
if event_type == "GROUP_AT_MESSAGE_CREATE":
|
||||
group_openid = msg._rawmsg.get("group_openid", "")
|
||||
url = f"{QQ_API_BASE}/v2/groups/{group_openid}/messages"
|
||||
body = {
|
||||
"msg_id": msg_id,
|
||||
"msg_seq": self._get_next_msg_seq(msg_id),
|
||||
}
|
||||
return url, body, "group", group_openid
|
||||
|
||||
elif event_type == "C2C_MESSAGE_CREATE":
|
||||
user_openid = msg._rawmsg.get("author", {}).get("user_openid", "") or msg.from_user_id
|
||||
url = f"{QQ_API_BASE}/v2/users/{user_openid}/messages"
|
||||
body = {
|
||||
"msg_id": msg_id,
|
||||
"msg_seq": self._get_next_msg_seq(msg_id),
|
||||
}
|
||||
return url, body, "c2c", user_openid
|
||||
|
||||
elif event_type == "AT_MESSAGE_CREATE":
|
||||
channel_id = msg._rawmsg.get("channel_id", "")
|
||||
url = f"{QQ_API_BASE}/channels/{channel_id}/messages"
|
||||
body = {"msg_id": msg_id}
|
||||
return url, body, "channel", channel_id
|
||||
|
||||
elif event_type == "DIRECT_MESSAGE_CREATE":
|
||||
guild_id = msg._rawmsg.get("guild_id", "")
|
||||
url = f"{QQ_API_BASE}/dms/{guild_id}/messages"
|
||||
body = {"msg_id": msg_id}
|
||||
return url, body, "dm", guild_id
|
||||
|
||||
return None, None, None, None
|
||||
|
||||
def _post_message(self, url: str, body: dict, event_type: str):
|
||||
try:
|
||||
resp = requests.post(url, json=body, headers=self._get_auth_headers(), timeout=10)
|
||||
if resp.status_code in (200, 201, 202, 204):
|
||||
logger.info(f"[QQ] Message sent successfully: event_type={event_type}")
|
||||
else:
|
||||
logger.error(f"[QQ] Failed to send message: status={resp.status_code}, "
|
||||
f"body={resp.text}")
|
||||
except Exception as e:
|
||||
logger.error(f"[QQ] Send message error: {e}")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Active send (no original message, e.g. scheduled tasks)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _active_send_text(self, content: str, receiver: str, is_group: bool):
|
||||
"""Send text without an original message (active push). QQ limits active messages to 4/month per user."""
|
||||
if not receiver:
|
||||
logger.warning("[QQ] No receiver for active send")
|
||||
return
|
||||
if is_group:
|
||||
url = f"{QQ_API_BASE}/v2/groups/{receiver}/messages"
|
||||
else:
|
||||
url = f"{QQ_API_BASE}/v2/users/{receiver}/messages"
|
||||
body = {
|
||||
"content": content,
|
||||
"msg_type": 0,
|
||||
}
|
||||
event_label = "GROUP_ACTIVE" if is_group else "C2C_ACTIVE"
|
||||
self._post_message(url, body, event_label)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Send text
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _send_text(self, content: str, msg: QQMessage, event_type: str, msg_id: str):
|
||||
url, body, _, _ = self._build_msg_url_and_base_body(msg, event_type, msg_id)
|
||||
if not url:
|
||||
logger.warning(f"[QQ] Cannot send reply for event_type: {event_type}")
|
||||
return
|
||||
body["content"] = content
|
||||
body["msg_type"] = 0
|
||||
self._post_message(url, body, event_type)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Rich media upload & send (image / video / file)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _upload_rich_media(self, file_url: str, file_type: int, msg: QQMessage,
|
||||
event_type: str) -> str:
|
||||
"""
|
||||
Upload media via QQ rich media API and return file_info.
|
||||
For group: POST /v2/groups/{group_openid}/files
|
||||
For c2c: POST /v2/users/{openid}/files
|
||||
"""
|
||||
if event_type == "GROUP_AT_MESSAGE_CREATE":
|
||||
group_openid = msg._rawmsg.get("group_openid", "")
|
||||
upload_url = f"{QQ_API_BASE}/v2/groups/{group_openid}/files"
|
||||
elif event_type == "C2C_MESSAGE_CREATE":
|
||||
user_openid = (msg._rawmsg.get("author", {}).get("user_openid", "")
|
||||
or msg.from_user_id)
|
||||
upload_url = f"{QQ_API_BASE}/v2/users/{user_openid}/files"
|
||||
else:
|
||||
logger.warning(f"[QQ] Rich media upload not supported for event_type: {event_type}")
|
||||
return ""
|
||||
|
||||
upload_body = {
|
||||
"file_type": file_type,
|
||||
"url": file_url,
|
||||
"srv_send_msg": False,
|
||||
}
|
||||
|
||||
try:
|
||||
resp = requests.post(
|
||||
upload_url, json=upload_body,
|
||||
headers=self._get_auth_headers(), timeout=30,
|
||||
)
|
||||
if resp.status_code in (200, 201):
|
||||
data = resp.json()
|
||||
file_info = data.get("file_info", "")
|
||||
logger.info(f"[QQ] Rich media uploaded: file_type={file_type}, "
|
||||
f"file_uuid={data.get('file_uuid', '')}")
|
||||
return file_info
|
||||
else:
|
||||
logger.error(f"[QQ] Rich media upload failed: status={resp.status_code}, "
|
||||
f"body={resp.text}")
|
||||
return ""
|
||||
except Exception as e:
|
||||
logger.error(f"[QQ] Rich media upload error: {e}")
|
||||
return ""
|
||||
|
||||
def _upload_rich_media_base64(self, file_path: str, file_type: int, msg: QQMessage,
|
||||
event_type: str) -> str:
|
||||
"""Upload local file via base64 file_data field."""
|
||||
if event_type == "GROUP_AT_MESSAGE_CREATE":
|
||||
group_openid = msg._rawmsg.get("group_openid", "")
|
||||
upload_url = f"{QQ_API_BASE}/v2/groups/{group_openid}/files"
|
||||
elif event_type == "C2C_MESSAGE_CREATE":
|
||||
user_openid = (msg._rawmsg.get("author", {}).get("user_openid", "")
|
||||
or msg.from_user_id)
|
||||
upload_url = f"{QQ_API_BASE}/v2/users/{user_openid}/files"
|
||||
else:
|
||||
logger.warning(f"[QQ] Rich media upload not supported for event_type: {event_type}")
|
||||
return ""
|
||||
|
||||
try:
|
||||
with open(file_path, "rb") as f:
|
||||
file_data = base64.b64encode(f.read()).decode("utf-8")
|
||||
except Exception as e:
|
||||
logger.error(f"[QQ] Failed to read file for upload: {e}")
|
||||
return ""
|
||||
|
||||
upload_body = {
|
||||
"file_type": file_type,
|
||||
"file_data": file_data,
|
||||
"srv_send_msg": False,
|
||||
}
|
||||
|
||||
try:
|
||||
resp = requests.post(
|
||||
upload_url, json=upload_body,
|
||||
headers=self._get_auth_headers(), timeout=30,
|
||||
)
|
||||
if resp.status_code in (200, 201):
|
||||
data = resp.json()
|
||||
file_info = data.get("file_info", "")
|
||||
logger.info(f"[QQ] Rich media uploaded (base64): file_type={file_type}, "
|
||||
f"file_uuid={data.get('file_uuid', '')}")
|
||||
return file_info
|
||||
else:
|
||||
logger.error(f"[QQ] Rich media upload (base64) failed: status={resp.status_code}, "
|
||||
f"body={resp.text}")
|
||||
return ""
|
||||
except Exception as e:
|
||||
logger.error(f"[QQ] Rich media upload (base64) error: {e}")
|
||||
return ""
|
||||
|
||||
def _send_media_msg(self, file_info: str, msg: QQMessage, event_type: str, msg_id: str):
|
||||
"""Send a message with msg_type=7 (rich media) using file_info."""
|
||||
url, body, _, _ = self._build_msg_url_and_base_body(msg, event_type, msg_id)
|
||||
if not url:
|
||||
return
|
||||
body["msg_type"] = 7
|
||||
body["media"] = {"file_info": file_info}
|
||||
self._post_message(url, body, event_type)
|
||||
|
||||
def _send_image(self, img_path_or_url: str, msg: QQMessage, event_type: str, msg_id: str):
|
||||
"""Send image reply. Supports URL and local file path."""
|
||||
if event_type not in ("GROUP_AT_MESSAGE_CREATE", "C2C_MESSAGE_CREATE"):
|
||||
self._send_text(str(img_path_or_url), msg, event_type, msg_id)
|
||||
return
|
||||
|
||||
if img_path_or_url.startswith("file://"):
|
||||
img_path_or_url = img_path_or_url[7:]
|
||||
|
||||
if img_path_or_url.startswith(("http://", "https://")):
|
||||
file_info = self._upload_rich_media(
|
||||
img_path_or_url, QQ_FILE_TYPE_IMAGE, msg, event_type)
|
||||
elif os.path.exists(img_path_or_url):
|
||||
file_info = self._upload_rich_media_base64(
|
||||
img_path_or_url, QQ_FILE_TYPE_IMAGE, msg, event_type)
|
||||
else:
|
||||
logger.error(f"[QQ] Image not found: {img_path_or_url}")
|
||||
self._send_text("[Image send failed]", msg, event_type, msg_id)
|
||||
return
|
||||
|
||||
if file_info:
|
||||
self._send_media_msg(file_info, msg, event_type, msg_id)
|
||||
else:
|
||||
self._send_text("[Image upload failed]", msg, event_type, msg_id)
|
||||
|
||||
def _send_file(self, file_path_or_url: str, msg: QQMessage, event_type: str, msg_id: str):
|
||||
"""Send file reply."""
|
||||
if event_type not in ("GROUP_AT_MESSAGE_CREATE", "C2C_MESSAGE_CREATE"):
|
||||
self._send_text(str(file_path_or_url), msg, event_type, msg_id)
|
||||
return
|
||||
|
||||
if file_path_or_url.startswith("file://"):
|
||||
file_path_or_url = file_path_or_url[7:]
|
||||
|
||||
if file_path_or_url.startswith(("http://", "https://")):
|
||||
file_info = self._upload_rich_media(
|
||||
file_path_or_url, QQ_FILE_TYPE_FILE, msg, event_type)
|
||||
elif os.path.exists(file_path_or_url):
|
||||
file_info = self._upload_rich_media_base64(
|
||||
file_path_or_url, QQ_FILE_TYPE_FILE, msg, event_type)
|
||||
else:
|
||||
logger.error(f"[QQ] File not found: {file_path_or_url}")
|
||||
self._send_text("[File send failed]", msg, event_type, msg_id)
|
||||
return
|
||||
|
||||
if file_info:
|
||||
self._send_media_msg(file_info, msg, event_type, msg_id)
|
||||
else:
|
||||
self._send_text("[File upload failed]", msg, event_type, msg_id)
|
||||
|
||||
def _send_media(self, path_or_url: str, msg: QQMessage, event_type: str,
|
||||
msg_id: str, file_type: int):
|
||||
"""Generic media send for video/voice etc."""
|
||||
if event_type not in ("GROUP_AT_MESSAGE_CREATE", "C2C_MESSAGE_CREATE"):
|
||||
self._send_text(str(path_or_url), msg, event_type, msg_id)
|
||||
return
|
||||
|
||||
if path_or_url.startswith("file://"):
|
||||
path_or_url = path_or_url[7:]
|
||||
|
||||
if path_or_url.startswith(("http://", "https://")):
|
||||
file_info = self._upload_rich_media(path_or_url, file_type, msg, event_type)
|
||||
elif os.path.exists(path_or_url):
|
||||
file_info = self._upload_rich_media_base64(path_or_url, file_type, msg, event_type)
|
||||
else:
|
||||
logger.error(f"[QQ] Media not found: {path_or_url}")
|
||||
return
|
||||
|
||||
if file_info:
|
||||
self._send_media_msg(file_info, msg, event_type, msg_id)
|
||||
else:
|
||||
logger.error(f"[QQ] Media upload failed: {path_or_url}")
|
||||
123
channel/qq/qq_message.py
Normal file
123
channel/qq/qq_message.py
Normal file
@@ -0,0 +1,123 @@
|
||||
import os
|
||||
import requests
|
||||
|
||||
from bridge.context import ContextType
|
||||
from channel.chat_message import ChatMessage
|
||||
from common.log import logger
|
||||
from common.utils import expand_path
|
||||
from config import conf
|
||||
|
||||
|
||||
def _get_tmp_dir() -> str:
|
||||
"""Return the workspace tmp directory (absolute path), creating it if needed."""
|
||||
ws_root = expand_path(conf().get("agent_workspace", "~/cow"))
|
||||
tmp_dir = os.path.join(ws_root, "tmp")
|
||||
os.makedirs(tmp_dir, exist_ok=True)
|
||||
return tmp_dir
|
||||
|
||||
|
||||
class QQMessage(ChatMessage):
|
||||
"""Message wrapper for QQ Bot (websocket long-connection mode)."""
|
||||
|
||||
def __init__(self, event_data: dict, event_type: str):
|
||||
super().__init__(event_data)
|
||||
self.msg_id = event_data.get("id", "")
|
||||
self.create_time = event_data.get("timestamp", "")
|
||||
self.is_group = event_type in ("GROUP_AT_MESSAGE_CREATE",)
|
||||
self.event_type = event_type
|
||||
|
||||
author = event_data.get("author", {})
|
||||
from_user_id = author.get("member_openid", "") or author.get("id", "")
|
||||
group_openid = event_data.get("group_openid", "")
|
||||
|
||||
content = event_data.get("content", "").strip()
|
||||
|
||||
attachments = event_data.get("attachments", [])
|
||||
has_image = any(
|
||||
a.get("content_type", "").startswith("image/") for a in attachments
|
||||
) if attachments else False
|
||||
|
||||
if has_image and not content:
|
||||
self.ctype = ContextType.IMAGE
|
||||
img_attachment = next(
|
||||
a for a in attachments if a.get("content_type", "").startswith("image/")
|
||||
)
|
||||
img_url = img_attachment.get("url", "")
|
||||
if img_url and not img_url.startswith("http"):
|
||||
img_url = "https://" + img_url
|
||||
tmp_dir = _get_tmp_dir()
|
||||
image_path = os.path.join(tmp_dir, f"qq_{self.msg_id}.png")
|
||||
try:
|
||||
resp = requests.get(img_url, timeout=30)
|
||||
resp.raise_for_status()
|
||||
with open(image_path, "wb") as f:
|
||||
f.write(resp.content)
|
||||
self.content = image_path
|
||||
self.image_path = image_path
|
||||
logger.info(f"[QQ] Image downloaded: {image_path}")
|
||||
except Exception as e:
|
||||
logger.error(f"[QQ] Failed to download image: {e}")
|
||||
self.content = "[Image download failed]"
|
||||
self.image_path = None
|
||||
elif has_image and content:
|
||||
self.ctype = ContextType.TEXT
|
||||
image_paths = []
|
||||
tmp_dir = _get_tmp_dir()
|
||||
for idx, att in enumerate(attachments):
|
||||
if not att.get("content_type", "").startswith("image/"):
|
||||
continue
|
||||
img_url = att.get("url", "")
|
||||
if img_url and not img_url.startswith("http"):
|
||||
img_url = "https://" + img_url
|
||||
img_path = os.path.join(tmp_dir, f"qq_{self.msg_id}_{idx}.png")
|
||||
try:
|
||||
resp = requests.get(img_url, timeout=30)
|
||||
resp.raise_for_status()
|
||||
with open(img_path, "wb") as f:
|
||||
f.write(resp.content)
|
||||
image_paths.append(img_path)
|
||||
except Exception as e:
|
||||
logger.error(f"[QQ] Failed to download mixed image: {e}")
|
||||
content_parts = [content]
|
||||
for p in image_paths:
|
||||
content_parts.append(f"[图片: {p}]")
|
||||
self.content = "\n".join(content_parts)
|
||||
else:
|
||||
self.ctype = ContextType.TEXT
|
||||
self.content = content
|
||||
|
||||
if event_type == "GROUP_AT_MESSAGE_CREATE":
|
||||
self.from_user_id = from_user_id
|
||||
self.to_user_id = ""
|
||||
self.other_user_id = group_openid
|
||||
self.actual_user_id = from_user_id
|
||||
self.actual_user_nickname = from_user_id
|
||||
|
||||
elif event_type == "C2C_MESSAGE_CREATE":
|
||||
user_openid = author.get("user_openid", "") or from_user_id
|
||||
self.from_user_id = user_openid
|
||||
self.to_user_id = ""
|
||||
self.other_user_id = user_openid
|
||||
self.actual_user_id = user_openid
|
||||
|
||||
elif event_type == "AT_MESSAGE_CREATE":
|
||||
self.from_user_id = from_user_id
|
||||
self.to_user_id = ""
|
||||
channel_id = event_data.get("channel_id", "")
|
||||
self.other_user_id = channel_id
|
||||
self.actual_user_id = from_user_id
|
||||
self.actual_user_nickname = author.get("username", from_user_id)
|
||||
|
||||
elif event_type == "DIRECT_MESSAGE_CREATE":
|
||||
self.from_user_id = from_user_id
|
||||
self.to_user_id = ""
|
||||
guild_id = event_data.get("guild_id", "")
|
||||
self.other_user_id = f"dm_{guild_id}_{from_user_id}"
|
||||
self.actual_user_id = from_user_id
|
||||
self.actual_user_nickname = author.get("username", from_user_id)
|
||||
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported QQ event type: {event_type}")
|
||||
|
||||
logger.debug(f"[QQ] Message parsed: type={event_type}, ctype={self.ctype}, "
|
||||
f"from={self.from_user_id}, content_len={len(self.content)}")
|
||||
File diff suppressed because it is too large
Load Diff
1111
channel/web/static/css/console.css
Normal file
1111
channel/web/static/css/console.css
Normal file
File diff suppressed because it is too large
Load Diff
4356
channel/web/static/js/console.js
Normal file
4356
channel/web/static/js/console.js
Normal file
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -1,179 +0,0 @@
|
||||
# encoding:utf-8
|
||||
|
||||
"""
|
||||
wechat channel
|
||||
"""
|
||||
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
from queue import Empty
|
||||
from typing import Any
|
||||
|
||||
from bridge.context import *
|
||||
from bridge.reply import *
|
||||
from channel.chat_channel import ChatChannel
|
||||
from channel.wechat.wcf_message import WechatfMessage
|
||||
from common.log import logger
|
||||
from common.singleton import singleton
|
||||
from common.utils import *
|
||||
from config import conf, get_appdata_dir
|
||||
from wcferry import Wcf, WxMsg
|
||||
|
||||
|
||||
@singleton
|
||||
class WechatfChannel(ChatChannel):
|
||||
NOT_SUPPORT_REPLYTYPE = []
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.NOT_SUPPORT_REPLYTYPE = []
|
||||
# 使用字典存储最近消息,用于去重
|
||||
self.received_msgs = {}
|
||||
# 初始化wcferry客户端
|
||||
self.wcf = Wcf()
|
||||
self.wxid = None # 登录后会被设置为当前登录用户的wxid
|
||||
|
||||
def startup(self):
|
||||
"""
|
||||
启动通道
|
||||
"""
|
||||
try:
|
||||
# wcferry会自动唤起微信并登录
|
||||
self.wxid = self.wcf.get_self_wxid()
|
||||
self.name = self.wcf.get_user_info().get("name")
|
||||
logger.info(f"微信登录成功,当前用户ID: {self.wxid}, 用户名:{self.name}")
|
||||
self.contact_cache = ContactCache(self.wcf)
|
||||
self.contact_cache.update()
|
||||
# 启动消息接收
|
||||
self.wcf.enable_receiving_msg()
|
||||
# 创建消息处理线程
|
||||
t = threading.Thread(target=self._process_messages, name="WeChatThread", daemon=True)
|
||||
t.start()
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"微信通道启动失败: {e}")
|
||||
raise e
|
||||
|
||||
def _process_messages(self):
|
||||
"""
|
||||
处理消息队列
|
||||
"""
|
||||
while True:
|
||||
try:
|
||||
msg = self.wcf.get_msg()
|
||||
if msg:
|
||||
self._handle_message(msg)
|
||||
except Empty:
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.error(f"处理消息失败: {e}")
|
||||
continue
|
||||
|
||||
def _handle_message(self, msg: WxMsg):
|
||||
"""
|
||||
处理单条消息
|
||||
"""
|
||||
try:
|
||||
# 构造消息对象
|
||||
cmsg = WechatfMessage(self, msg)
|
||||
# 消息去重
|
||||
if cmsg.msg_id in self.received_msgs:
|
||||
return
|
||||
self.received_msgs[cmsg.msg_id] = time.time()
|
||||
# 清理过期消息ID
|
||||
self._clean_expired_msgs()
|
||||
|
||||
logger.debug(f"收到消息: {msg}")
|
||||
context = self._compose_context(cmsg.ctype, cmsg.content,
|
||||
isgroup=cmsg.is_group,
|
||||
msg=cmsg)
|
||||
if context:
|
||||
self.produce(context)
|
||||
except Exception as e:
|
||||
logger.error(f"处理消息失败: {e}")
|
||||
|
||||
def _clean_expired_msgs(self, expire_time: float = 60):
|
||||
"""
|
||||
清理过期的消息ID
|
||||
"""
|
||||
now = time.time()
|
||||
for msg_id in list(self.received_msgs.keys()):
|
||||
if now - self.received_msgs[msg_id] > expire_time:
|
||||
del self.received_msgs[msg_id]
|
||||
|
||||
def send(self, reply: Reply, context: Context):
|
||||
"""
|
||||
发送消息
|
||||
"""
|
||||
receiver = context["receiver"]
|
||||
if not receiver:
|
||||
logger.error("receiver is empty")
|
||||
return
|
||||
|
||||
try:
|
||||
if reply.type == ReplyType.TEXT:
|
||||
# 处理@信息
|
||||
at_list = []
|
||||
if context.get("isgroup"):
|
||||
if context["msg"].actual_user_id:
|
||||
at_list = [context["msg"].actual_user_id]
|
||||
at_str = ",".join(at_list) if at_list else ""
|
||||
self.wcf.send_text(reply.content, receiver, at_str)
|
||||
|
||||
elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO:
|
||||
self.wcf.send_text(reply.content, receiver)
|
||||
else:
|
||||
logger.error(f"暂不支持的消息类型: {reply.type}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"发送消息失败: {e}")
|
||||
|
||||
def close(self):
|
||||
"""
|
||||
关闭通道
|
||||
"""
|
||||
try:
|
||||
self.wcf.cleanup()
|
||||
except Exception as e:
|
||||
logger.error(f"关闭通道失败: {e}")
|
||||
|
||||
|
||||
class ContactCache:
|
||||
def __init__(self, wcf):
|
||||
"""
|
||||
wcf: 一个 wcfferry.client.Wcf 实例
|
||||
"""
|
||||
self.wcf = wcf
|
||||
self._contact_map = {} # 形如 {wxid: {完整联系人信息}}
|
||||
|
||||
def update(self):
|
||||
"""
|
||||
更新缓存:调用 get_contacts(),
|
||||
再把 wcf.contacts 构建成 {wxid: {完整信息}} 的字典
|
||||
"""
|
||||
self.wcf.get_contacts()
|
||||
self._contact_map.clear()
|
||||
for item in self.wcf.contacts:
|
||||
wxid = item.get('wxid')
|
||||
if wxid: # 确保有 wxid 字段
|
||||
self._contact_map[wxid] = item
|
||||
|
||||
def get_contact(self, wxid: str) -> dict:
|
||||
"""
|
||||
返回该 wxid 对应的完整联系人 dict,
|
||||
如果没找到就返回 None
|
||||
"""
|
||||
return self._contact_map.get(wxid)
|
||||
|
||||
def get_name_by_wxid(self, wxid: str) -> str:
|
||||
"""
|
||||
通过wxid,获取成员/群名称
|
||||
"""
|
||||
contact = self.get_contact(wxid)
|
||||
if contact:
|
||||
return contact.get('name', '')
|
||||
return ''
|
||||
@@ -1,58 +0,0 @@
|
||||
# encoding:utf-8
|
||||
|
||||
"""
|
||||
wechat channel message
|
||||
"""
|
||||
|
||||
from bridge.context import ContextType
|
||||
from channel.chat_message import ChatMessage
|
||||
from common.log import logger
|
||||
from wcferry import WxMsg
|
||||
|
||||
|
||||
class WechatfMessage(ChatMessage):
|
||||
"""
|
||||
微信消息封装类
|
||||
"""
|
||||
|
||||
def __init__(self, channel, wcf_msg: WxMsg, is_group=False):
|
||||
"""
|
||||
初始化消息对象
|
||||
:param wcf_msg: wcferry消息对象
|
||||
:param is_group: 是否是群消息
|
||||
"""
|
||||
super().__init__(wcf_msg)
|
||||
self.msg_id = wcf_msg.id
|
||||
self.create_time = wcf_msg.ts # 使用消息时间戳
|
||||
self.is_group = is_group or wcf_msg._is_group
|
||||
self.wxid = channel.wxid
|
||||
self.name = channel.name
|
||||
|
||||
# 解析消息类型
|
||||
if wcf_msg.is_text():
|
||||
self.ctype = ContextType.TEXT
|
||||
self.content = wcf_msg.content
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported message type: {wcf_msg.type}")
|
||||
|
||||
# 设置发送者和接收者信息
|
||||
self.from_user_id = self.wxid if wcf_msg.sender == self.wxid else wcf_msg.sender
|
||||
self.from_user_nickname = self.name if wcf_msg.sender == self.wxid else channel.contact_cache.get_name_by_wxid(wcf_msg.sender)
|
||||
self.to_user_id = self.wxid
|
||||
self.to_user_nickname = self.name
|
||||
self.other_user_id = wcf_msg.sender
|
||||
self.other_user_nickname = channel.contact_cache.get_name_by_wxid(wcf_msg.sender)
|
||||
|
||||
# 群消息特殊处理
|
||||
if self.is_group:
|
||||
self.other_user_id = wcf_msg.roomid
|
||||
self.other_user_nickname = channel.contact_cache.get_name_by_wxid(wcf_msg.roomid)
|
||||
self.actual_user_id = wcf_msg.sender
|
||||
self.actual_user_nickname = channel.wcf.get_alias_in_chatroom(wcf_msg.sender, wcf_msg.roomid)
|
||||
if not self.actual_user_nickname: # 群聊获取不到企微号成员昵称,这里尝试从联系人缓存去获取
|
||||
self.actual_user_nickname = channel.contact_cache.get_name_by_wxid(wcf_msg.sender)
|
||||
self.room_id = wcf_msg.roomid
|
||||
self.is_at = wcf_msg.is_at(self.wxid) # 是否被@当前登录用户
|
||||
|
||||
# 判断是否是自己发送的消息
|
||||
self.my_msg = wcf_msg.from_self()
|
||||
@@ -1,309 +0,0 @@
|
||||
# encoding:utf-8
|
||||
|
||||
"""
|
||||
wechat channel
|
||||
"""
|
||||
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
import requests
|
||||
|
||||
from bridge.context import *
|
||||
from bridge.reply import *
|
||||
from channel.chat_channel import ChatChannel
|
||||
from channel import chat_channel
|
||||
from channel.wechat.wechat_message import *
|
||||
from common.expired_dict import ExpiredDict
|
||||
from common.log import logger
|
||||
from common.singleton import singleton
|
||||
from common.time_check import time_checker
|
||||
from common.utils import convert_webp_to_png, remove_markdown_symbol
|
||||
from config import conf, get_appdata_dir
|
||||
from lib import itchat
|
||||
from lib.itchat.content import *
|
||||
|
||||
|
||||
@itchat.msg_register([TEXT, VOICE, PICTURE, NOTE, ATTACHMENT, SHARING])
|
||||
def handler_single_msg(msg):
|
||||
try:
|
||||
cmsg = WechatMessage(msg, False)
|
||||
except NotImplementedError as e:
|
||||
logger.debug("[WX]single message {} skipped: {}".format(msg["MsgId"], e))
|
||||
return None
|
||||
WechatChannel().handle_single(cmsg)
|
||||
return None
|
||||
|
||||
|
||||
@itchat.msg_register([TEXT, VOICE, PICTURE, NOTE, ATTACHMENT, SHARING], isGroupChat=True)
|
||||
def handler_group_msg(msg):
|
||||
try:
|
||||
cmsg = WechatMessage(msg, True)
|
||||
except NotImplementedError as e:
|
||||
logger.debug("[WX]group message {} skipped: {}".format(msg["MsgId"], e))
|
||||
return None
|
||||
WechatChannel().handle_group(cmsg)
|
||||
return None
|
||||
|
||||
|
||||
def _check(func):
|
||||
def wrapper(self, cmsg: ChatMessage):
|
||||
msgId = cmsg.msg_id
|
||||
if msgId in self.receivedMsgs:
|
||||
logger.info("Wechat message {} already received, ignore".format(msgId))
|
||||
return
|
||||
self.receivedMsgs[msgId] = True
|
||||
create_time = cmsg.create_time # 消息时间戳
|
||||
if conf().get("hot_reload") == True and int(create_time) < int(time.time()) - 60: # 跳过1分钟前的历史消息
|
||||
logger.debug("[WX]history message {} skipped".format(msgId))
|
||||
return
|
||||
if cmsg.my_msg and not cmsg.is_group:
|
||||
logger.debug("[WX]my message {} skipped".format(msgId))
|
||||
return
|
||||
return func(self, cmsg)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
# 可用的二维码生成接口
|
||||
# https://api.qrserver.com/v1/create-qr-code/?size=400×400&data=https://www.abc.com
|
||||
# https://api.isoyu.com/qr/?m=1&e=L&p=20&url=https://www.abc.com
|
||||
def qrCallback(uuid, status, qrcode):
|
||||
# logger.debug("qrCallback: {} {}".format(uuid,status))
|
||||
if status == "0":
|
||||
try:
|
||||
from PIL import Image
|
||||
|
||||
img = Image.open(io.BytesIO(qrcode))
|
||||
_thread = threading.Thread(target=img.show, args=("QRCode",))
|
||||
_thread.setDaemon(True)
|
||||
_thread.start()
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
import qrcode
|
||||
|
||||
url = f"https://login.weixin.qq.com/l/{uuid}"
|
||||
|
||||
qr_api1 = "https://api.isoyu.com/qr/?m=1&e=L&p=20&url={}".format(url)
|
||||
qr_api2 = "https://api.qrserver.com/v1/create-qr-code/?size=400×400&data={}".format(url)
|
||||
qr_api3 = "https://api.pwmqr.com/qrcode/create/?url={}".format(url)
|
||||
qr_api4 = "https://my.tv.sohu.com/user/a/wvideo/getQRCode.do?text={}".format(url)
|
||||
print("You can also scan QRCode in any website below:")
|
||||
print(qr_api3)
|
||||
print(qr_api4)
|
||||
print(qr_api2)
|
||||
print(qr_api1)
|
||||
_send_qr_code([qr_api3, qr_api4, qr_api2, qr_api1])
|
||||
qr = qrcode.QRCode(border=1)
|
||||
qr.add_data(url)
|
||||
qr.make(fit=True)
|
||||
try:
|
||||
qr.print_ascii(invert=True)
|
||||
except UnicodeEncodeError:
|
||||
print("ASCII QR code printing failed due to encoding issues.")
|
||||
|
||||
|
||||
@singleton
|
||||
class WechatChannel(ChatChannel):
|
||||
NOT_SUPPORT_REPLYTYPE = []
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.receivedMsgs = ExpiredDict(conf().get("expires_in_seconds", 3600))
|
||||
self.auto_login_times = 0
|
||||
|
||||
def startup(self):
|
||||
try:
|
||||
time.sleep(3)
|
||||
logger.error("""[WechatChannel] 当前channel暂不可用,目前支持的channel有:
|
||||
1. terminal: 终端
|
||||
2. wechatmp: 个人公众号
|
||||
3. wechatmp_service: 企业公众号
|
||||
4. wechatcom_app: 企微自建应用
|
||||
5. dingtalk: 钉钉
|
||||
6. feishu: 飞书
|
||||
7. web: 网页
|
||||
8. wcf: wechat (需Windows环境,参考 https://github.com/zhayujie/chatgpt-on-wechat/pull/2562 )
|
||||
可修改 config.json 配置文件的 channel_type 字段进行切换""")
|
||||
|
||||
# itchat.instance.receivingRetryCount = 600 # 修改断线超时时间
|
||||
# # login by scan QRCode
|
||||
# hotReload = conf().get("hot_reload", False)
|
||||
# status_path = os.path.join(get_appdata_dir(), "itchat.pkl")
|
||||
# itchat.auto_login(
|
||||
# enableCmdQR=2,
|
||||
# hotReload=hotReload,
|
||||
# statusStorageDir=status_path,
|
||||
# qrCallback=qrCallback,
|
||||
# exitCallback=self.exitCallback,
|
||||
# loginCallback=self.loginCallback
|
||||
# )
|
||||
# self.user_id = itchat.instance.storageClass.userName
|
||||
# self.name = itchat.instance.storageClass.nickName
|
||||
# logger.info("Wechat login success, user_id: {}, nickname: {}".format(self.user_id, self.name))
|
||||
# # start message listener
|
||||
# itchat.run()
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
|
||||
def exitCallback(self):
|
||||
try:
|
||||
from common.linkai_client import chat_client
|
||||
if chat_client.client_id and conf().get("use_linkai"):
|
||||
_send_logout()
|
||||
time.sleep(2)
|
||||
self.auto_login_times += 1
|
||||
if self.auto_login_times < 100:
|
||||
chat_channel.handler_pool._shutdown = False
|
||||
self.startup()
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
def loginCallback(self):
|
||||
logger.debug("Login success")
|
||||
_send_login_success()
|
||||
|
||||
# handle_* 系列函数处理收到的消息后构造Context,然后传入produce函数中处理Context和发送回复
|
||||
# Context包含了消息的所有信息,包括以下属性
|
||||
# type 消息类型, 包括TEXT、VOICE、IMAGE_CREATE
|
||||
# content 消息内容,如果是TEXT类型,content就是文本内容,如果是VOICE类型,content就是语音文件名,如果是IMAGE_CREATE类型,content就是图片生成命令
|
||||
# kwargs 附加参数字典,包含以下的key:
|
||||
# session_id: 会话id
|
||||
# isgroup: 是否是群聊
|
||||
# receiver: 需要回复的对象
|
||||
# msg: ChatMessage消息对象
|
||||
# origin_ctype: 原始消息类型,语音转文字后,私聊时如果匹配前缀失败,会根据初始消息是否是语音来放宽触发规则
|
||||
# desire_rtype: 希望回复类型,默认是文本回复,设置为ReplyType.VOICE是语音回复
|
||||
@time_checker
|
||||
@_check
|
||||
def handle_single(self, cmsg: ChatMessage):
|
||||
# filter system message
|
||||
if cmsg.other_user_id in ["weixin"]:
|
||||
return
|
||||
if cmsg.ctype == ContextType.VOICE:
|
||||
if conf().get("speech_recognition") != True:
|
||||
return
|
||||
logger.debug("[WX]receive voice msg: {}".format(cmsg.content))
|
||||
elif cmsg.ctype == ContextType.IMAGE:
|
||||
logger.debug("[WX]receive image msg: {}".format(cmsg.content))
|
||||
elif cmsg.ctype == ContextType.PATPAT:
|
||||
logger.debug("[WX]receive patpat msg: {}".format(cmsg.content))
|
||||
elif cmsg.ctype == ContextType.TEXT:
|
||||
logger.debug("[WX]receive text msg: {}, cmsg={}".format(json.dumps(cmsg._rawmsg, ensure_ascii=False), cmsg))
|
||||
else:
|
||||
logger.debug("[WX]receive msg: {}, cmsg={}".format(cmsg.content, cmsg))
|
||||
context = self._compose_context(cmsg.ctype, cmsg.content, isgroup=False, msg=cmsg)
|
||||
if context:
|
||||
self.produce(context)
|
||||
|
||||
@time_checker
|
||||
@_check
|
||||
def handle_group(self, cmsg: ChatMessage):
|
||||
if cmsg.ctype == ContextType.VOICE:
|
||||
if conf().get("group_speech_recognition") != True:
|
||||
return
|
||||
logger.debug("[WX]receive voice for group msg: {}".format(cmsg.content))
|
||||
elif cmsg.ctype == ContextType.IMAGE:
|
||||
logger.debug("[WX]receive image for group msg: {}".format(cmsg.content))
|
||||
elif cmsg.ctype in [ContextType.JOIN_GROUP, ContextType.PATPAT, ContextType.ACCEPT_FRIEND, ContextType.EXIT_GROUP]:
|
||||
logger.debug("[WX]receive note msg: {}".format(cmsg.content))
|
||||
elif cmsg.ctype == ContextType.TEXT:
|
||||
# logger.debug("[WX]receive group msg: {}, cmsg={}".format(json.dumps(cmsg._rawmsg, ensure_ascii=False), cmsg))
|
||||
pass
|
||||
elif cmsg.ctype == ContextType.FILE:
|
||||
logger.debug(f"[WX]receive attachment msg, file_name={cmsg.content}")
|
||||
else:
|
||||
logger.debug("[WX]receive group msg: {}".format(cmsg.content))
|
||||
context = self._compose_context(cmsg.ctype, cmsg.content, isgroup=True, msg=cmsg, no_need_at=conf().get("no_need_at", False))
|
||||
if context:
|
||||
self.produce(context)
|
||||
|
||||
# 统一的发送函数,每个Channel自行实现,根据reply的type字段发送不同类型的消息
|
||||
def send(self, reply: Reply, context: Context):
|
||||
receiver = context["receiver"]
|
||||
if reply.type == ReplyType.TEXT:
|
||||
reply.content = remove_markdown_symbol(reply.content)
|
||||
itchat.send(reply.content, toUserName=receiver)
|
||||
logger.info("[WX] sendMsg={}, receiver={}".format(reply, receiver))
|
||||
elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO:
|
||||
reply.content = remove_markdown_symbol(reply.content)
|
||||
itchat.send(reply.content, toUserName=receiver)
|
||||
logger.info("[WX] sendMsg={}, receiver={}".format(reply, receiver))
|
||||
elif reply.type == ReplyType.VOICE:
|
||||
itchat.send_file(reply.content, toUserName=receiver)
|
||||
logger.info("[WX] sendFile={}, receiver={}".format(reply.content, receiver))
|
||||
elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片
|
||||
img_url = reply.content
|
||||
logger.debug(f"[WX] start download image, img_url={img_url}")
|
||||
pic_res = requests.get(img_url, stream=True)
|
||||
image_storage = io.BytesIO()
|
||||
size = 0
|
||||
for block in pic_res.iter_content(1024):
|
||||
size += len(block)
|
||||
image_storage.write(block)
|
||||
logger.info(f"[WX] download image success, size={size}, img_url={img_url}")
|
||||
image_storage.seek(0)
|
||||
if ".webp" in img_url:
|
||||
try:
|
||||
image_storage = convert_webp_to_png(image_storage)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to convert image: {e}")
|
||||
return
|
||||
itchat.send_image(image_storage, toUserName=receiver)
|
||||
logger.info("[WX] sendImage url={}, receiver={}".format(img_url, receiver))
|
||||
elif reply.type == ReplyType.IMAGE: # 从文件读取图片
|
||||
image_storage = reply.content
|
||||
image_storage.seek(0)
|
||||
itchat.send_image(image_storage, toUserName=receiver)
|
||||
logger.info("[WX] sendImage, receiver={}".format(receiver))
|
||||
elif reply.type == ReplyType.FILE: # 新增文件回复类型
|
||||
file_storage = reply.content
|
||||
itchat.send_file(file_storage, toUserName=receiver)
|
||||
logger.info("[WX] sendFile, receiver={}".format(receiver))
|
||||
elif reply.type == ReplyType.VIDEO: # 新增视频回复类型
|
||||
video_storage = reply.content
|
||||
itchat.send_video(video_storage, toUserName=receiver)
|
||||
logger.info("[WX] sendFile, receiver={}".format(receiver))
|
||||
elif reply.type == ReplyType.VIDEO_URL: # 新增视频URL回复类型
|
||||
video_url = reply.content
|
||||
logger.debug(f"[WX] start download video, video_url={video_url}")
|
||||
video_res = requests.get(video_url, stream=True)
|
||||
video_storage = io.BytesIO()
|
||||
size = 0
|
||||
for block in video_res.iter_content(1024):
|
||||
size += len(block)
|
||||
video_storage.write(block)
|
||||
logger.info(f"[WX] download video success, size={size}, video_url={video_url}")
|
||||
video_storage.seek(0)
|
||||
itchat.send_video(video_storage, toUserName=receiver)
|
||||
logger.info("[WX] sendVideo url={}, receiver={}".format(video_url, receiver))
|
||||
|
||||
def _send_login_success():
|
||||
try:
|
||||
from common.linkai_client import chat_client
|
||||
if chat_client.client_id:
|
||||
chat_client.send_login_success()
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
|
||||
def _send_logout():
|
||||
try:
|
||||
from common.linkai_client import chat_client
|
||||
if chat_client.client_id:
|
||||
chat_client.send_logout()
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
|
||||
def _send_qr_code(qrcode_list: list):
|
||||
try:
|
||||
from common.linkai_client import chat_client
|
||||
if chat_client.client_id:
|
||||
chat_client.send_qrcode(qrcode_list)
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
@@ -1,124 +0,0 @@
|
||||
import re
|
||||
|
||||
from bridge.context import ContextType
|
||||
from channel.chat_message import ChatMessage
|
||||
from common.log import logger
|
||||
from common.tmp_dir import TmpDir
|
||||
from lib import itchat
|
||||
from lib.itchat.content import *
|
||||
|
||||
class WechatMessage(ChatMessage):
|
||||
def __init__(self, itchat_msg, is_group=False):
|
||||
super().__init__(itchat_msg)
|
||||
self.msg_id = itchat_msg["MsgId"]
|
||||
self.create_time = itchat_msg["CreateTime"]
|
||||
self.is_group = is_group
|
||||
|
||||
notes_join_group = ["加入群聊", "加入了群聊", "invited", "joined"] # 可通过添加对应语言的加入群聊通知中的关键词适配更多
|
||||
notes_bot_join_group = ["邀请你", "invited you", "You've joined", "你通过扫描"]
|
||||
notes_exit_group = ["移出了群聊", "removed"] # 可通过添加对应语言的踢出群聊通知中的关键词适配更多
|
||||
notes_patpat = ["拍了拍我", "tickled my", "tickled me"] # 可通过添加对应语言的拍一拍通知中的关键词适配更多
|
||||
|
||||
if itchat_msg["Type"] == TEXT:
|
||||
self.ctype = ContextType.TEXT
|
||||
self.content = itchat_msg["Text"]
|
||||
elif itchat_msg["Type"] == VOICE:
|
||||
self.ctype = ContextType.VOICE
|
||||
self.content = TmpDir().path() + itchat_msg["FileName"] # content直接存临时目录路径
|
||||
self._prepare_fn = lambda: itchat_msg.download(self.content)
|
||||
elif itchat_msg["Type"] == PICTURE and itchat_msg["MsgType"] == 3:
|
||||
self.ctype = ContextType.IMAGE
|
||||
self.content = TmpDir().path() + itchat_msg["FileName"] # content直接存临时目录路径
|
||||
self._prepare_fn = lambda: itchat_msg.download(self.content)
|
||||
elif itchat_msg["Type"] == NOTE and itchat_msg["MsgType"] == 10000:
|
||||
if is_group:
|
||||
if any(note_bot_join_group in itchat_msg["Content"] for note_bot_join_group in notes_bot_join_group): # 邀请机器人加入群聊
|
||||
logger.warn("机器人加入群聊消息,不处理~")
|
||||
pass
|
||||
elif any(note_join_group in itchat_msg["Content"] for note_join_group in notes_join_group): # 若有任何在notes_join_group列表中的字符串出现在NOTE中
|
||||
# 这里只能得到nickname, actual_user_id还是机器人的id
|
||||
if "加入群聊" not in itchat_msg["Content"]:
|
||||
self.ctype = ContextType.JOIN_GROUP
|
||||
self.content = itchat_msg["Content"]
|
||||
if "invited" in itchat_msg["Content"]: # 匹配英文信息
|
||||
self.actual_user_nickname = re.findall(r'invited\s+(.+?)\s+to\s+the\s+group\s+chat', itchat_msg["Content"])[0]
|
||||
elif "joined" in itchat_msg["Content"]: # 匹配通过二维码加入的英文信息
|
||||
self.actual_user_nickname = re.findall(r'"(.*?)" joined the group chat via the QR Code shared by', itchat_msg["Content"])[0]
|
||||
elif "加入了群聊" in itchat_msg["Content"]:
|
||||
self.actual_user_nickname = re.findall(r"\"(.*?)\"", itchat_msg["Content"])[-1]
|
||||
elif "加入群聊" in itchat_msg["Content"]:
|
||||
self.ctype = ContextType.JOIN_GROUP
|
||||
self.content = itchat_msg["Content"]
|
||||
self.actual_user_nickname = re.findall(r"\"(.*?)\"", itchat_msg["Content"])[0]
|
||||
|
||||
elif any(note_exit_group in itchat_msg["Content"] for note_exit_group in notes_exit_group): # 若有任何在notes_exit_group列表中的字符串出现在NOTE中
|
||||
self.ctype = ContextType.EXIT_GROUP
|
||||
self.content = itchat_msg["Content"]
|
||||
self.actual_user_nickname = re.findall(r"\"(.*?)\"", itchat_msg["Content"])[0]
|
||||
|
||||
elif any(note_patpat in itchat_msg["Content"] for note_patpat in notes_patpat): # 若有任何在notes_patpat列表中的字符串出现在NOTE中:
|
||||
self.ctype = ContextType.PATPAT
|
||||
self.content = itchat_msg["Content"]
|
||||
if "拍了拍我" in itchat_msg["Content"]: # 识别中文
|
||||
self.actual_user_nickname = re.findall(r"\"(.*?)\"", itchat_msg["Content"])[0]
|
||||
elif "tickled my" in itchat_msg["Content"] or "tickled me" in itchat_msg["Content"]:
|
||||
self.actual_user_nickname = re.findall(r'^(.*?)(?:tickled my|tickled me)', itchat_msg["Content"])[0]
|
||||
else:
|
||||
raise NotImplementedError("Unsupported note message: " + itchat_msg["Content"])
|
||||
|
||||
elif "你已添加了" in itchat_msg["Content"]: #通过好友请求
|
||||
self.ctype = ContextType.ACCEPT_FRIEND
|
||||
self.content = itchat_msg["Content"]
|
||||
elif any(note_patpat in itchat_msg["Content"] for note_patpat in notes_patpat): # 若有任何在notes_patpat列表中的字符串出现在NOTE中:
|
||||
self.ctype = ContextType.PATPAT
|
||||
self.content = itchat_msg["Content"]
|
||||
else:
|
||||
raise NotImplementedError("Unsupported note message: " + itchat_msg["Content"])
|
||||
elif itchat_msg["Type"] == ATTACHMENT:
|
||||
self.ctype = ContextType.FILE
|
||||
self.content = TmpDir().path() + itchat_msg["FileName"] # content直接存临时目录路径
|
||||
self._prepare_fn = lambda: itchat_msg.download(self.content)
|
||||
elif itchat_msg["Type"] == SHARING:
|
||||
self.ctype = ContextType.SHARING
|
||||
self.content = itchat_msg.get("Url")
|
||||
|
||||
else:
|
||||
raise NotImplementedError("Unsupported message type: Type:{} MsgType:{}".format(itchat_msg["Type"], itchat_msg["MsgType"]))
|
||||
|
||||
self.from_user_id = itchat_msg["FromUserName"]
|
||||
self.to_user_id = itchat_msg["ToUserName"]
|
||||
|
||||
user_id = itchat.instance.storageClass.userName
|
||||
nickname = itchat.instance.storageClass.nickName
|
||||
|
||||
# 虽然from_user_id和to_user_id用的少,但是为了保持一致性,还是要填充一下
|
||||
# 以下很繁琐,一句话总结:能填的都填了。
|
||||
if self.from_user_id == user_id:
|
||||
self.from_user_nickname = nickname
|
||||
if self.to_user_id == user_id:
|
||||
self.to_user_nickname = nickname
|
||||
try: # 陌生人时候, User字段可能不存在
|
||||
# my_msg 为True是表示是自己发送的消息
|
||||
self.my_msg = itchat_msg["ToUserName"] == itchat_msg["User"]["UserName"] and \
|
||||
itchat_msg["ToUserName"] != itchat_msg["FromUserName"]
|
||||
self.other_user_id = itchat_msg["User"]["UserName"]
|
||||
self.other_user_nickname = itchat_msg["User"]["NickName"]
|
||||
if self.other_user_id == self.from_user_id:
|
||||
self.from_user_nickname = self.other_user_nickname
|
||||
if self.other_user_id == self.to_user_id:
|
||||
self.to_user_nickname = self.other_user_nickname
|
||||
if itchat_msg["User"].get("Self"):
|
||||
# 自身的展示名,当设置了群昵称时,该字段表示群昵称
|
||||
self.self_display_name = itchat_msg["User"].get("Self").get("DisplayName")
|
||||
except KeyError as e: # 处理偶尔没有对方信息的情况
|
||||
logger.warn("[WX]get other_user_id failed: " + str(e))
|
||||
if self.from_user_id == user_id:
|
||||
self.other_user_id = self.to_user_id
|
||||
else:
|
||||
self.other_user_id = self.from_user_id
|
||||
|
||||
if self.is_group:
|
||||
self.is_at = itchat_msg["IsAt"]
|
||||
self.actual_user_id = itchat_msg["ActualUserName"]
|
||||
if self.ctype not in [ContextType.JOIN_GROUP, ContextType.PATPAT, ContextType.EXIT_GROUP]:
|
||||
self.actual_user_nickname = itchat_msg["ActualNickName"]
|
||||
@@ -1,129 +0,0 @@
|
||||
# encoding:utf-8
|
||||
|
||||
"""
|
||||
wechaty channel
|
||||
Python Wechaty - https://github.com/wechaty/python-wechaty
|
||||
"""
|
||||
import asyncio
|
||||
import base64
|
||||
import os
|
||||
import time
|
||||
|
||||
from wechaty import Contact, Wechaty
|
||||
from wechaty.user import Message
|
||||
from wechaty_puppet import FileBox
|
||||
|
||||
from bridge.context import *
|
||||
from bridge.context import Context
|
||||
from bridge.reply import *
|
||||
from channel.chat_channel import ChatChannel
|
||||
from channel.wechat.wechaty_message import WechatyMessage
|
||||
from common.log import logger
|
||||
from common.singleton import singleton
|
||||
from config import conf
|
||||
|
||||
try:
|
||||
from voice.audio_convert import any_to_sil
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
|
||||
@singleton
|
||||
class WechatyChannel(ChatChannel):
|
||||
NOT_SUPPORT_REPLYTYPE = []
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def startup(self):
|
||||
config = conf()
|
||||
token = config.get("wechaty_puppet_service_token")
|
||||
os.environ["WECHATY_PUPPET_SERVICE_TOKEN"] = token
|
||||
asyncio.run(self.main())
|
||||
|
||||
async def main(self):
|
||||
loop = asyncio.get_event_loop()
|
||||
# 将asyncio的loop传入处理线程
|
||||
self.handler_pool._initializer = lambda: asyncio.set_event_loop(loop)
|
||||
self.bot = Wechaty()
|
||||
self.bot.on("login", self.on_login)
|
||||
self.bot.on("message", self.on_message)
|
||||
await self.bot.start()
|
||||
|
||||
async def on_login(self, contact: Contact):
|
||||
self.user_id = contact.contact_id
|
||||
self.name = contact.name
|
||||
logger.info("[WX] login user={}".format(contact))
|
||||
|
||||
# 统一的发送函数,每个Channel自行实现,根据reply的type字段发送不同类型的消息
|
||||
def send(self, reply: Reply, context: Context):
|
||||
receiver_id = context["receiver"]
|
||||
loop = asyncio.get_event_loop()
|
||||
if context["isgroup"]:
|
||||
receiver = asyncio.run_coroutine_threadsafe(self.bot.Room.find(receiver_id), loop).result()
|
||||
else:
|
||||
receiver = asyncio.run_coroutine_threadsafe(self.bot.Contact.find(receiver_id), loop).result()
|
||||
msg = None
|
||||
if reply.type == ReplyType.TEXT:
|
||||
msg = reply.content
|
||||
asyncio.run_coroutine_threadsafe(receiver.say(msg), loop).result()
|
||||
logger.info("[WX] sendMsg={}, receiver={}".format(reply, receiver))
|
||||
elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO:
|
||||
msg = reply.content
|
||||
asyncio.run_coroutine_threadsafe(receiver.say(msg), loop).result()
|
||||
logger.info("[WX] sendMsg={}, receiver={}".format(reply, receiver))
|
||||
elif reply.type == ReplyType.VOICE:
|
||||
voiceLength = None
|
||||
file_path = reply.content
|
||||
sil_file = os.path.splitext(file_path)[0] + ".sil"
|
||||
voiceLength = int(any_to_sil(file_path, sil_file))
|
||||
if voiceLength >= 60000:
|
||||
voiceLength = 60000
|
||||
logger.info("[WX] voice too long, length={}, set to 60s".format(voiceLength))
|
||||
# 发送语音
|
||||
t = int(time.time())
|
||||
msg = FileBox.from_file(sil_file, name=str(t) + ".sil")
|
||||
if voiceLength is not None:
|
||||
msg.metadata["voiceLength"] = voiceLength
|
||||
asyncio.run_coroutine_threadsafe(receiver.say(msg), loop).result()
|
||||
try:
|
||||
os.remove(file_path)
|
||||
if sil_file != file_path:
|
||||
os.remove(sil_file)
|
||||
except Exception as e:
|
||||
pass
|
||||
logger.info("[WX] sendVoice={}, receiver={}".format(reply.content, receiver))
|
||||
elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片
|
||||
img_url = reply.content
|
||||
t = int(time.time())
|
||||
msg = FileBox.from_url(url=img_url, name=str(t) + ".png")
|
||||
asyncio.run_coroutine_threadsafe(receiver.say(msg), loop).result()
|
||||
logger.info("[WX] sendImage url={}, receiver={}".format(img_url, receiver))
|
||||
elif reply.type == ReplyType.IMAGE: # 从文件读取图片
|
||||
image_storage = reply.content
|
||||
image_storage.seek(0)
|
||||
t = int(time.time())
|
||||
msg = FileBox.from_base64(base64.b64encode(image_storage.read()), str(t) + ".png")
|
||||
asyncio.run_coroutine_threadsafe(receiver.say(msg), loop).result()
|
||||
logger.info("[WX] sendImage, receiver={}".format(receiver))
|
||||
|
||||
async def on_message(self, msg: Message):
|
||||
"""
|
||||
listen for message event
|
||||
"""
|
||||
try:
|
||||
cmsg = await WechatyMessage(msg)
|
||||
except NotImplementedError as e:
|
||||
logger.debug("[WX] {}".format(e))
|
||||
return
|
||||
except Exception as e:
|
||||
logger.exception("[WX] {}".format(e))
|
||||
return
|
||||
logger.debug("[WX] message:{}".format(cmsg))
|
||||
room = msg.room() # 获取消息来自的群聊. 如果消息不是来自群聊, 则返回None
|
||||
isgroup = room is not None
|
||||
ctype = cmsg.ctype
|
||||
context = self._compose_context(ctype, cmsg.content, isgroup=isgroup, msg=cmsg)
|
||||
if context:
|
||||
logger.info("[WX] receiveMsg={}, context={}".format(cmsg, context))
|
||||
self.produce(context)
|
||||
@@ -1,89 +0,0 @@
|
||||
import asyncio
|
||||
import re
|
||||
|
||||
from wechaty import MessageType
|
||||
from wechaty.user import Message
|
||||
|
||||
from bridge.context import ContextType
|
||||
from channel.chat_message import ChatMessage
|
||||
from common.log import logger
|
||||
from common.tmp_dir import TmpDir
|
||||
|
||||
|
||||
class aobject(object):
|
||||
"""Inheriting this class allows you to define an async __init__.
|
||||
|
||||
So you can create objects by doing something like `await MyClass(params)`
|
||||
"""
|
||||
|
||||
async def __new__(cls, *a, **kw):
|
||||
instance = super().__new__(cls)
|
||||
await instance.__init__(*a, **kw)
|
||||
return instance
|
||||
|
||||
async def __init__(self):
|
||||
pass
|
||||
|
||||
|
||||
class WechatyMessage(ChatMessage, aobject):
|
||||
async def __init__(self, wechaty_msg: Message):
|
||||
super().__init__(wechaty_msg)
|
||||
|
||||
room = wechaty_msg.room()
|
||||
|
||||
self.msg_id = wechaty_msg.message_id
|
||||
self.create_time = wechaty_msg.payload.timestamp
|
||||
self.is_group = room is not None
|
||||
|
||||
if wechaty_msg.type() == MessageType.MESSAGE_TYPE_TEXT:
|
||||
self.ctype = ContextType.TEXT
|
||||
self.content = wechaty_msg.text()
|
||||
elif wechaty_msg.type() == MessageType.MESSAGE_TYPE_AUDIO:
|
||||
self.ctype = ContextType.VOICE
|
||||
voice_file = await wechaty_msg.to_file_box()
|
||||
self.content = TmpDir().path() + voice_file.name # content直接存临时目录路径
|
||||
|
||||
def func():
|
||||
loop = asyncio.get_event_loop()
|
||||
asyncio.run_coroutine_threadsafe(voice_file.to_file(self.content), loop).result()
|
||||
|
||||
self._prepare_fn = func
|
||||
|
||||
else:
|
||||
raise NotImplementedError("Unsupported message type: {}".format(wechaty_msg.type()))
|
||||
|
||||
from_contact = wechaty_msg.talker() # 获取消息的发送者
|
||||
self.from_user_id = from_contact.contact_id
|
||||
self.from_user_nickname = from_contact.name
|
||||
|
||||
# group中的from和to,wechaty跟itchat含义不一样
|
||||
# wecahty: from是消息实际发送者, to:所在群
|
||||
# itchat: 如果是你发送群消息,from和to是你自己和所在群,如果是别人发群消息,from和to是所在群和你自己
|
||||
# 但这个差别不影响逻辑,group中只使用到:1.用from来判断是否是自己发的,2.actual_user_id来判断实际发送用户
|
||||
|
||||
if self.is_group:
|
||||
self.to_user_id = room.room_id
|
||||
self.to_user_nickname = await room.topic()
|
||||
else:
|
||||
to_contact = wechaty_msg.to()
|
||||
self.to_user_id = to_contact.contact_id
|
||||
self.to_user_nickname = to_contact.name
|
||||
|
||||
if self.is_group or wechaty_msg.is_self(): # 如果是群消息,other_user设置为群,如果是私聊消息,而且自己发的,就设置成对方。
|
||||
self.other_user_id = self.to_user_id
|
||||
self.other_user_nickname = self.to_user_nickname
|
||||
else:
|
||||
self.other_user_id = self.from_user_id
|
||||
self.other_user_nickname = self.from_user_nickname
|
||||
|
||||
if self.is_group: # wechaty群聊中,实际发送用户就是from_user
|
||||
self.is_at = await wechaty_msg.mention_self()
|
||||
if not self.is_at: # 有时候复制粘贴的消息,不算做@,但是内容里面会有@xxx,这里做一下兼容
|
||||
name = wechaty_msg.wechaty.user_self().name
|
||||
pattern = f"@{re.escape(name)}(\u2005|\u0020)"
|
||||
if re.search(pattern, self.content):
|
||||
logger.debug(f"wechaty message {self.msg_id} include at")
|
||||
self.is_at = True
|
||||
|
||||
self.actual_user_id = self.from_user_id
|
||||
self.actual_user_nickname = self.from_user_nickname
|
||||
@@ -36,6 +36,7 @@ class WechatComAppChannel(ChatChannel):
|
||||
self.agent_id = conf().get("wechatcomapp_agent_id")
|
||||
self.token = conf().get("wechatcomapp_token")
|
||||
self.aes_key = conf().get("wechatcomapp_aes_key")
|
||||
self._http_server = None
|
||||
logger.info(
|
||||
"[wechatcom] Initializing WeCom app channel, corp_id: {}, agent_id: {}".format(self.corp_id, self.agent_id)
|
||||
)
|
||||
@@ -51,13 +52,24 @@ class WechatComAppChannel(ChatChannel):
|
||||
logger.info("[wechatcom] 📡 Listening on http://0.0.0.0:{}/wxcomapp/".format(port))
|
||||
logger.info("[wechatcom] 🤖 Ready to receive messages")
|
||||
|
||||
# Suppress web.py's default server startup message
|
||||
old_stdout = sys.stdout
|
||||
sys.stdout = io.StringIO()
|
||||
# Build WSGI app with middleware (same as runsimple but without print)
|
||||
func = web.httpserver.StaticMiddleware(app.wsgifunc())
|
||||
func = web.httpserver.LogMiddleware(func)
|
||||
server = web.httpserver.WSGIServer(("0.0.0.0", port), func)
|
||||
self._http_server = server
|
||||
try:
|
||||
web.httpserver.runsimple(app.wsgifunc(), ("0.0.0.0", port))
|
||||
finally:
|
||||
sys.stdout = old_stdout
|
||||
server.start()
|
||||
except (KeyboardInterrupt, SystemExit):
|
||||
server.stop()
|
||||
|
||||
def stop(self):
|
||||
if self._http_server:
|
||||
try:
|
||||
self._http_server.stop()
|
||||
logger.info("[wechatcom] HTTP server stopped")
|
||||
except Exception as e:
|
||||
logger.warning(f"[wechatcom] Error stopping HTTP server: {e}")
|
||||
self._http_server = None
|
||||
|
||||
def send(self, reply: Reply, context: Context):
|
||||
receiver = context["receiver"]
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# 微信公众号channel
|
||||
|
||||
鉴于个人微信号在服务器上通过itchat登录有封号风险,这里新增了微信公众号channel,提供无风险的服务。
|
||||
微信公众号channel,提供稳定的服务。
|
||||
目前支持订阅号和服务号两种类型的公众号,它们都支持文本交互,语音和图片输入。其中个人主体的微信订阅号由于无法通过微信认证,存在回复时间限制,每天的图片和声音回复次数也有限制。
|
||||
|
||||
## 使用方法(订阅号,服务号类似)
|
||||
|
||||
@@ -41,6 +41,7 @@ class WechatMPChannel(ChatChannel):
|
||||
super().__init__()
|
||||
self.passive_reply = passive_reply
|
||||
self.NOT_SUPPORT_REPLYTYPE = []
|
||||
self._http_server = None
|
||||
appid = conf().get("wechatmp_app_id")
|
||||
secret = conf().get("wechatmp_app_secret")
|
||||
token = conf().get("wechatmp_token")
|
||||
@@ -69,7 +70,23 @@ class WechatMPChannel(ChatChannel):
|
||||
urls = ("/wx", "channel.wechatmp.active_reply.Query")
|
||||
app = web.application(urls, globals(), autoreload=False)
|
||||
port = conf().get("wechatmp_port", 8080)
|
||||
web.httpserver.runsimple(app.wsgifunc(), ("0.0.0.0", port))
|
||||
func = web.httpserver.StaticMiddleware(app.wsgifunc())
|
||||
func = web.httpserver.LogMiddleware(func)
|
||||
server = web.httpserver.WSGIServer(("0.0.0.0", port), func)
|
||||
self._http_server = server
|
||||
try:
|
||||
server.start()
|
||||
except (KeyboardInterrupt, SystemExit):
|
||||
server.stop()
|
||||
|
||||
def stop(self):
|
||||
if self._http_server:
|
||||
try:
|
||||
self._http_server.stop()
|
||||
logger.info("[wechatmp] HTTP server stopped")
|
||||
except Exception as e:
|
||||
logger.warning(f"[wechatmp] Error stopping HTTP server: {e}")
|
||||
self._http_server = None
|
||||
|
||||
def start_loop(self, loop):
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
0
channel/wecom_bot/__init__.py
Normal file
0
channel/wecom_bot/__init__.py
Normal file
788
channel/wecom_bot/wecom_bot_channel.py
Normal file
788
channel/wecom_bot/wecom_bot_channel.py
Normal file
@@ -0,0 +1,788 @@
|
||||
"""
|
||||
WeCom (企业微信) AI Bot channel via WebSocket long connection.
|
||||
|
||||
Supports:
|
||||
- Single chat and group chat (text / image / file input & output)
|
||||
- Scheduled task push via aibot_send_msg
|
||||
- Heartbeat keep-alive and auto-reconnect
|
||||
"""
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
|
||||
import requests
|
||||
import websocket
|
||||
|
||||
from bridge.context import Context, ContextType
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from channel.chat_channel import ChatChannel, check_prefix
|
||||
from channel.wecom_bot.wecom_bot_message import WecomBotMessage
|
||||
from common.expired_dict import ExpiredDict
|
||||
from common.log import logger
|
||||
from common.singleton import singleton
|
||||
from common.ws_client_compat import websocket_app_run_forever
|
||||
from config import conf
|
||||
|
||||
WECOM_WS_URL = "wss://openws.work.weixin.qq.com"
|
||||
HEARTBEAT_INTERVAL = 30
|
||||
MEDIA_CHUNK_SIZE = 512 * 1024 # 512KB per chunk (before base64 encoding)
|
||||
|
||||
|
||||
@singleton
|
||||
class WecomBotChannel(ChatChannel):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.bot_id = ""
|
||||
self.bot_secret = ""
|
||||
self.received_msgs = ExpiredDict(60 * 60 * 7.1)
|
||||
self._ws = None
|
||||
self._ws_thread = None
|
||||
self._heartbeat_thread = None
|
||||
self._connected = False
|
||||
self._stop_event = threading.Event()
|
||||
self._pending_responses = {} # req_id -> (threading.Event, result_holder)
|
||||
self._pending_lock = threading.Lock()
|
||||
self._stream_states = {} # req_id -> {"stream_id": str, "content": str}
|
||||
|
||||
conf()["group_name_white_list"] = ["ALL_GROUP"]
|
||||
conf()["single_chat_prefix"] = [""]
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Lifecycle
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def startup(self):
|
||||
self.bot_id = conf().get("wecom_bot_id", "")
|
||||
self.bot_secret = conf().get("wecom_bot_secret", "")
|
||||
|
||||
if not self.bot_id or not self.bot_secret:
|
||||
err = "[WecomBot] wecom_bot_id and wecom_bot_secret are required"
|
||||
logger.error(err)
|
||||
self.report_startup_error(err)
|
||||
return
|
||||
|
||||
self._stop_event.clear()
|
||||
self._start_ws()
|
||||
|
||||
def stop(self):
|
||||
logger.info("[WecomBot] stop() called")
|
||||
self._stop_event.set()
|
||||
if self._ws:
|
||||
try:
|
||||
self._ws.close()
|
||||
except Exception:
|
||||
pass
|
||||
self._ws = None
|
||||
self._connected = False
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# WebSocket connection
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _start_ws(self):
|
||||
def _on_open(ws):
|
||||
logger.info("[WecomBot] WebSocket connected, sending subscribe...")
|
||||
self._send_subscribe()
|
||||
|
||||
def _on_message(ws, raw):
|
||||
try:
|
||||
data = json.loads(raw)
|
||||
self._handle_ws_message(data)
|
||||
except Exception as e:
|
||||
logger.error(f"[WecomBot] Failed to handle ws message: {e}", exc_info=True)
|
||||
|
||||
def _on_error(ws, error):
|
||||
logger.error(f"[WecomBot] WebSocket error: {error}")
|
||||
|
||||
def _on_close(ws, close_status_code, close_msg):
|
||||
logger.warning(f"[WecomBot] WebSocket closed: status={close_status_code}, msg={close_msg}")
|
||||
self._connected = False
|
||||
if not self._stop_event.is_set():
|
||||
logger.info("[WecomBot] Will reconnect in 5s...")
|
||||
time.sleep(5)
|
||||
if not self._stop_event.is_set():
|
||||
self._start_ws()
|
||||
|
||||
self._ws = websocket.WebSocketApp(
|
||||
WECOM_WS_URL,
|
||||
on_open=_on_open,
|
||||
on_message=_on_message,
|
||||
on_error=_on_error,
|
||||
on_close=_on_close,
|
||||
)
|
||||
|
||||
def run_forever():
|
||||
try:
|
||||
websocket_app_run_forever(self._ws, ping_interval=0, reconnect=0)
|
||||
except (SystemExit, KeyboardInterrupt):
|
||||
logger.info("[WecomBot] WebSocket thread interrupted")
|
||||
except Exception as e:
|
||||
logger.error(f"[WecomBot] WebSocket run_forever error: {e}")
|
||||
|
||||
self._ws_thread = threading.Thread(target=run_forever, daemon=True)
|
||||
self._ws_thread.start()
|
||||
self._ws_thread.join()
|
||||
|
||||
def _ws_send(self, data: dict):
|
||||
if self._ws:
|
||||
self._ws.send(json.dumps(data, ensure_ascii=False))
|
||||
|
||||
def _gen_req_id(self) -> str:
|
||||
return uuid.uuid4().hex[:16]
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Subscribe & heartbeat
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _send_subscribe(self):
|
||||
self._ws_send({
|
||||
"cmd": "aibot_subscribe",
|
||||
"headers": {"req_id": self._gen_req_id()},
|
||||
"body": {
|
||||
"bot_id": self.bot_id,
|
||||
"secret": self.bot_secret,
|
||||
},
|
||||
})
|
||||
|
||||
def _start_heartbeat(self):
|
||||
if self._heartbeat_thread and self._heartbeat_thread.is_alive():
|
||||
return
|
||||
|
||||
def heartbeat_loop():
|
||||
while not self._stop_event.is_set() and self._connected:
|
||||
try:
|
||||
self._ws_send({
|
||||
"cmd": "ping",
|
||||
"headers": {"req_id": self._gen_req_id()},
|
||||
})
|
||||
except Exception as e:
|
||||
logger.warning(f"[WecomBot] Heartbeat send failed: {e}")
|
||||
break
|
||||
self._stop_event.wait(HEARTBEAT_INTERVAL)
|
||||
|
||||
self._heartbeat_thread = threading.Thread(target=heartbeat_loop, daemon=True)
|
||||
self._heartbeat_thread.start()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Incoming message dispatch
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _send_and_wait(self, data: dict, timeout: float = 15) -> dict:
|
||||
"""Send a ws message and wait for the matching response by req_id."""
|
||||
req_id = data.get("headers", {}).get("req_id", "")
|
||||
event = threading.Event()
|
||||
holder = {"data": None}
|
||||
with self._pending_lock:
|
||||
self._pending_responses[req_id] = (event, holder)
|
||||
self._ws_send(data)
|
||||
event.wait(timeout=timeout)
|
||||
with self._pending_lock:
|
||||
self._pending_responses.pop(req_id, None)
|
||||
return holder["data"] or {}
|
||||
|
||||
def _handle_ws_message(self, data: dict):
|
||||
cmd = data.get("cmd", "")
|
||||
errcode = data.get("errcode")
|
||||
req_id = data.get("headers", {}).get("req_id", "")
|
||||
|
||||
# Check if this is a response to a pending request
|
||||
if req_id:
|
||||
with self._pending_lock:
|
||||
pending = self._pending_responses.get(req_id)
|
||||
if pending:
|
||||
event, holder = pending
|
||||
holder["data"] = data
|
||||
event.set()
|
||||
return
|
||||
|
||||
# Subscribe response (only handle once before connected)
|
||||
if errcode is not None and cmd == "":
|
||||
if not self._connected:
|
||||
if errcode == 0:
|
||||
logger.info("[WecomBot] ✅ Subscribe success")
|
||||
self._connected = True
|
||||
self._start_heartbeat()
|
||||
self.report_startup_success()
|
||||
else:
|
||||
errmsg = data.get("errmsg", "unknown error")
|
||||
logger.error(f"[WecomBot] Subscribe failed: errcode={errcode}, errmsg={errmsg}")
|
||||
self.report_startup_error(errmsg)
|
||||
return
|
||||
|
||||
if cmd == "aibot_msg_callback":
|
||||
self._handle_msg_callback(data)
|
||||
elif cmd == "aibot_event_callback":
|
||||
self._handle_event_callback(data)
|
||||
elif cmd == "":
|
||||
if errcode and errcode != 0:
|
||||
logger.warning(f"[WecomBot] Response error: {data}")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Message callback
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _handle_msg_callback(self, data: dict):
|
||||
body = data.get("body", {})
|
||||
req_id = data.get("headers", {}).get("req_id", "")
|
||||
msg_id = body.get("msgid", "")
|
||||
|
||||
if self.received_msgs.get(msg_id):
|
||||
logger.debug(f"[WecomBot] Duplicate msg filtered: {msg_id}")
|
||||
return
|
||||
self.received_msgs[msg_id] = True
|
||||
|
||||
chattype = body.get("chattype", "single")
|
||||
is_group = chattype == "group"
|
||||
|
||||
try:
|
||||
wecom_msg = WecomBotMessage(body, is_group=is_group)
|
||||
except NotImplementedError as e:
|
||||
logger.warning(f"[WecomBot] {e}")
|
||||
return
|
||||
except Exception as e:
|
||||
logger.error(f"[WecomBot] Failed to parse message: {e}", exc_info=True)
|
||||
return
|
||||
|
||||
wecom_msg.req_id = req_id
|
||||
|
||||
# File cache logic (same pattern as feishu)
|
||||
from channel.file_cache import get_file_cache
|
||||
file_cache = get_file_cache()
|
||||
|
||||
if is_group:
|
||||
if conf().get("group_shared_session", True):
|
||||
session_id = body.get("chatid", "")
|
||||
else:
|
||||
session_id = wecom_msg.from_user_id + "_" + body.get("chatid", "")
|
||||
else:
|
||||
session_id = wecom_msg.from_user_id
|
||||
|
||||
if wecom_msg.ctype == ContextType.IMAGE:
|
||||
if hasattr(wecom_msg, "image_path") and wecom_msg.image_path:
|
||||
file_cache.add(session_id, wecom_msg.image_path, file_type="image")
|
||||
logger.info(f"[WecomBot] Image cached for session {session_id}")
|
||||
return
|
||||
|
||||
if wecom_msg.ctype == ContextType.FILE:
|
||||
wecom_msg.prepare()
|
||||
file_cache.add(session_id, wecom_msg.content, file_type="file")
|
||||
logger.info(f"[WecomBot] File cached for session {session_id}: {wecom_msg.content}")
|
||||
return
|
||||
|
||||
if wecom_msg.ctype == ContextType.TEXT:
|
||||
cached_files = file_cache.get(session_id)
|
||||
if cached_files:
|
||||
file_refs = []
|
||||
for fi in cached_files:
|
||||
ftype = fi["type"]
|
||||
fpath = fi["path"]
|
||||
if ftype == "image":
|
||||
file_refs.append(f"[图片: {fpath}]")
|
||||
elif ftype == "video":
|
||||
file_refs.append(f"[视频: {fpath}]")
|
||||
else:
|
||||
file_refs.append(f"[文件: {fpath}]")
|
||||
wecom_msg.content = wecom_msg.content + "\n" + "\n".join(file_refs)
|
||||
logger.info(f"[WecomBot] Attached {len(cached_files)} cached file(s)")
|
||||
file_cache.clear(session_id)
|
||||
|
||||
context = self._compose_context(
|
||||
wecom_msg.ctype,
|
||||
wecom_msg.content,
|
||||
isgroup=is_group,
|
||||
msg=wecom_msg,
|
||||
no_need_at=True,
|
||||
)
|
||||
if context:
|
||||
if req_id:
|
||||
context["on_event"] = self._make_stream_callback(req_id)
|
||||
self.produce(context)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Event callback
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _handle_event_callback(self, data: dict):
|
||||
body = data.get("body", {})
|
||||
event = body.get("event", {})
|
||||
event_type = event.get("eventtype", "")
|
||||
|
||||
if event_type == "enter_chat":
|
||||
logger.info(f"[WecomBot] User entered chat: {body.get('from', {}).get('userid')}")
|
||||
elif event_type == "disconnected_event":
|
||||
logger.warning("[WecomBot] Received disconnected_event, another connection took over")
|
||||
else:
|
||||
logger.debug(f"[WecomBot] Event: {event_type}")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Stream callback (for agent on_event)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _make_stream_callback(self, req_id: str):
|
||||
"""Build an on_event callback that pushes agent stream deltas to wecom via stream message.
|
||||
|
||||
All intermediate segments (thinking before tool calls) and the final answer
|
||||
are accumulated into a single stream message, separated by '---'.
|
||||
Throttles push to at most once per 100ms to avoid WebSocket congestion.
|
||||
"""
|
||||
stream_id = uuid.uuid4().hex[:16]
|
||||
self._stream_states[req_id] = {
|
||||
"stream_id": stream_id,
|
||||
"committed": "",
|
||||
"current": "",
|
||||
"last_push_time": 0,
|
||||
"last_push_len": 0,
|
||||
}
|
||||
|
||||
def _push_stream(state: dict, force: bool = False):
|
||||
"""Push current stream content to wecom (throttled unless forced)."""
|
||||
now = time.time()
|
||||
if not force and now - state["last_push_time"] < 0.1:
|
||||
return
|
||||
content = state["committed"] + state["current"]
|
||||
if len(content) == state["last_push_len"]:
|
||||
return
|
||||
state["last_push_time"] = now
|
||||
state["last_push_len"] = len(content)
|
||||
try:
|
||||
self._ws_send({
|
||||
"cmd": "aibot_respond_msg",
|
||||
"headers": {"req_id": req_id},
|
||||
"body": {
|
||||
"msgtype": "stream",
|
||||
"stream": {
|
||||
"id": state["stream_id"],
|
||||
"finish": False,
|
||||
"content": content,
|
||||
},
|
||||
},
|
||||
})
|
||||
except Exception as e:
|
||||
logger.warning(f"[WecomBot] Stream push failed: {e}")
|
||||
|
||||
def on_event(event: dict):
|
||||
event_type = event.get("type")
|
||||
data = event.get("data", {})
|
||||
state = self._stream_states.get(req_id)
|
||||
if not state:
|
||||
return
|
||||
|
||||
if event_type == "turn_start":
|
||||
state["current"] = ""
|
||||
|
||||
elif event_type == "message_update":
|
||||
delta = data.get("delta", "")
|
||||
if delta:
|
||||
state["current"] += delta
|
||||
_push_stream(state)
|
||||
|
||||
elif event_type == "message_end":
|
||||
tool_calls = data.get("tool_calls", [])
|
||||
if tool_calls:
|
||||
if state["current"].strip():
|
||||
state["committed"] += state["current"].strip() + "\n\n---\n\n"
|
||||
state["current"] = ""
|
||||
else:
|
||||
state["committed"] += state["current"]
|
||||
state["current"] = ""
|
||||
_push_stream(state, force=True)
|
||||
|
||||
return on_event
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# _compose_context (same pattern as feishu)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _compose_context(self, ctype: ContextType, content, **kwargs):
|
||||
context = Context(ctype, content)
|
||||
context.kwargs = kwargs
|
||||
if "channel_type" not in context:
|
||||
context["channel_type"] = self.channel_type
|
||||
if "origin_ctype" not in context:
|
||||
context["origin_ctype"] = ctype
|
||||
|
||||
cmsg = context["msg"]
|
||||
|
||||
if cmsg.is_group:
|
||||
if conf().get("group_shared_session", True):
|
||||
context["session_id"] = cmsg.other_user_id
|
||||
else:
|
||||
context["session_id"] = f"{cmsg.from_user_id}:{cmsg.other_user_id}"
|
||||
else:
|
||||
context["session_id"] = cmsg.from_user_id
|
||||
|
||||
context["receiver"] = cmsg.other_user_id
|
||||
|
||||
if ctype == ContextType.TEXT:
|
||||
img_match_prefix = check_prefix(content, conf().get("image_create_prefix"))
|
||||
if img_match_prefix:
|
||||
content = content.replace(img_match_prefix, "", 1)
|
||||
context.type = ContextType.IMAGE_CREATE
|
||||
else:
|
||||
context.type = ContextType.TEXT
|
||||
context.content = content.strip()
|
||||
|
||||
return context
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Send reply
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def send(self, reply: Reply, context: Context):
|
||||
msg = context.get("msg")
|
||||
is_group = context.get("isgroup", False)
|
||||
receiver = context.get("receiver", "")
|
||||
|
||||
# Determine req_id for responding or use send_msg for scheduled push
|
||||
req_id = getattr(msg, "req_id", None) if msg else None
|
||||
|
||||
if reply.type == ReplyType.TEXT:
|
||||
self._send_text(reply.content, receiver, is_group, req_id)
|
||||
elif reply.type in (ReplyType.IMAGE_URL, ReplyType.IMAGE):
|
||||
self._send_image(reply.content, receiver, is_group, req_id)
|
||||
elif reply.type == ReplyType.FILE:
|
||||
if hasattr(reply, "text_content") and reply.text_content:
|
||||
self._send_text(reply.text_content, receiver, is_group, req_id)
|
||||
time.sleep(0.3)
|
||||
self._send_file(reply.content, receiver, is_group, req_id)
|
||||
elif reply.type == ReplyType.VIDEO or reply.type == ReplyType.VIDEO_URL:
|
||||
self._send_file(reply.content, receiver, is_group, req_id, media_type="video")
|
||||
else:
|
||||
logger.warning(f"[WecomBot] Unsupported reply type: {reply.type}, falling back to text")
|
||||
self._send_text(str(reply.content), receiver, is_group, req_id)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Respond message (via websocket)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _send_text(self, content: str, receiver: str, is_group: bool, req_id: str = None):
|
||||
"""Send text/markdown reply. Reuses stream state if available (streaming mode)."""
|
||||
if req_id:
|
||||
state = self._stream_states.pop(req_id, None)
|
||||
if state:
|
||||
final_content = state["committed"] if state["committed"] else content
|
||||
stream_id = state["stream_id"]
|
||||
else:
|
||||
final_content = content
|
||||
stream_id = uuid.uuid4().hex[:16]
|
||||
|
||||
# Brief pause so the server finishes processing the last intermediate chunk
|
||||
# before receiving the finish packet
|
||||
time.sleep(0.15)
|
||||
|
||||
self._ws_send({
|
||||
"cmd": "aibot_respond_msg",
|
||||
"headers": {"req_id": req_id},
|
||||
"body": {
|
||||
"msgtype": "stream",
|
||||
"stream": {
|
||||
"id": stream_id,
|
||||
"finish": True,
|
||||
"content": final_content,
|
||||
},
|
||||
},
|
||||
})
|
||||
else:
|
||||
self._active_send_markdown(content, receiver, is_group)
|
||||
|
||||
def _send_image(self, img_path_or_url: str, receiver: str, is_group: bool, req_id: str = None):
|
||||
"""Send image reply. Converts to JPG/PNG and compresses if >2MB."""
|
||||
local_path = img_path_or_url
|
||||
if local_path.startswith("file://"):
|
||||
local_path = local_path[7:]
|
||||
|
||||
if local_path.startswith(("http://", "https://")):
|
||||
try:
|
||||
resp = requests.get(local_path, timeout=30)
|
||||
resp.raise_for_status()
|
||||
ct = resp.headers.get("Content-Type", "")
|
||||
if "jpeg" in ct or "jpg" in ct:
|
||||
ext = ".jpg"
|
||||
elif "webp" in ct:
|
||||
ext = ".webp"
|
||||
elif "gif" in ct:
|
||||
ext = ".gif"
|
||||
else:
|
||||
ext = ".png"
|
||||
tmp_path = f"/tmp/wecom_img_{uuid.uuid4().hex[:8]}{ext}"
|
||||
with open(tmp_path, "wb") as f:
|
||||
f.write(resp.content)
|
||||
logger.info(f"[WecomBot] Image downloaded: size={len(resp.content)}, "
|
||||
f"content-type={ct}, path={tmp_path}")
|
||||
local_path = tmp_path
|
||||
except Exception as e:
|
||||
logger.error(f"[WecomBot] Failed to download image for sending: {e}")
|
||||
self._send_text("[Image send failed]", receiver, is_group, req_id)
|
||||
return
|
||||
|
||||
if not os.path.exists(local_path):
|
||||
logger.error(f"[WecomBot] Image file not found: {local_path}")
|
||||
return
|
||||
|
||||
max_image_size = 2 * 1024 * 1024 # 2MB limit for image upload
|
||||
local_path = self._ensure_image_format(local_path)
|
||||
if not local_path:
|
||||
self._send_text("[Image format conversion failed]", receiver, is_group, req_id)
|
||||
return
|
||||
|
||||
if os.path.getsize(local_path) > max_image_size:
|
||||
local_path = self._compress_image(local_path, max_image_size)
|
||||
if not local_path:
|
||||
self._send_text("[Image too large]", receiver, is_group, req_id)
|
||||
return
|
||||
|
||||
file_size = os.path.getsize(local_path)
|
||||
logger.info(f"[WecomBot] Uploading image: path={local_path}, size={file_size} bytes")
|
||||
media_id = self._upload_media(local_path, "image")
|
||||
if not media_id:
|
||||
logger.error("[WecomBot] Failed to upload image")
|
||||
self._send_text("[Image upload failed]", receiver, is_group, req_id)
|
||||
return
|
||||
|
||||
if req_id:
|
||||
self._ws_send({
|
||||
"cmd": "aibot_respond_msg",
|
||||
"headers": {"req_id": req_id},
|
||||
"body": {
|
||||
"msgtype": "image",
|
||||
"image": {"media_id": media_id},
|
||||
},
|
||||
})
|
||||
else:
|
||||
self._ws_send({
|
||||
"cmd": "aibot_send_msg",
|
||||
"headers": {"req_id": self._gen_req_id()},
|
||||
"body": {
|
||||
"chatid": receiver,
|
||||
"chat_type": 2 if is_group else 1,
|
||||
"msgtype": "image",
|
||||
"image": {"media_id": media_id},
|
||||
},
|
||||
})
|
||||
|
||||
@staticmethod
|
||||
def _ensure_image_format(file_path: str) -> str:
|
||||
"""Ensure image is JPG or PNG (the only formats wecom supports). Convert if needed."""
|
||||
try:
|
||||
from PIL import Image
|
||||
img = Image.open(file_path)
|
||||
fmt = (img.format or "").upper()
|
||||
if fmt in ("JPEG", "PNG"):
|
||||
# Already a supported format, but make sure the filename extension matches
|
||||
ext = os.path.splitext(file_path)[1].lower()
|
||||
if fmt == "JPEG" and ext in (".jpg", ".jpeg"):
|
||||
return file_path
|
||||
if fmt == "PNG" and ext == ".png":
|
||||
return file_path
|
||||
# Extension doesn't match — rename/copy with correct extension
|
||||
correct_ext = ".jpg" if fmt == "JPEG" else ".png"
|
||||
out_path = f"/tmp/wecom_fmt_{uuid.uuid4().hex[:8]}{correct_ext}"
|
||||
img.save(out_path, fmt)
|
||||
logger.info(f"[WecomBot] Image renamed: {file_path} -> {out_path} ({fmt})")
|
||||
return out_path
|
||||
|
||||
# Unsupported format (WebP, GIF, BMP, etc.) — convert to PNG
|
||||
if img.mode == "RGBA":
|
||||
out_path = f"/tmp/wecom_fmt_{uuid.uuid4().hex[:8]}.png"
|
||||
img.save(out_path, "PNG")
|
||||
else:
|
||||
out_path = f"/tmp/wecom_fmt_{uuid.uuid4().hex[:8]}.jpg"
|
||||
img.convert("RGB").save(out_path, "JPEG", quality=90)
|
||||
logger.info(f"[WecomBot] Image converted from {fmt} -> {out_path}")
|
||||
return out_path
|
||||
except Exception as e:
|
||||
logger.error(f"[WecomBot] Image format check failed: {e}")
|
||||
return file_path
|
||||
|
||||
@staticmethod
|
||||
def _compress_image(file_path: str, max_bytes: int) -> str:
|
||||
"""Compress image to fit within max_bytes. Returns new path or empty string."""
|
||||
try:
|
||||
from PIL import Image
|
||||
img = Image.open(file_path)
|
||||
if img.mode == "RGBA":
|
||||
img = img.convert("RGB")
|
||||
|
||||
out_path = f"/tmp/wecom_compressed_{uuid.uuid4().hex[:8]}.jpg"
|
||||
quality = 85
|
||||
while quality >= 30:
|
||||
img.save(out_path, "JPEG", quality=quality, optimize=True)
|
||||
if os.path.getsize(out_path) <= max_bytes:
|
||||
logger.info(f"[WecomBot] Image compressed: quality={quality}, "
|
||||
f"size={os.path.getsize(out_path)} bytes")
|
||||
return out_path
|
||||
quality -= 10
|
||||
|
||||
# Still too large — resize
|
||||
ratio = (max_bytes / os.path.getsize(out_path)) ** 0.5
|
||||
new_size = (int(img.width * ratio), int(img.height * ratio))
|
||||
img = img.resize(new_size, Image.LANCZOS)
|
||||
img.save(out_path, "JPEG", quality=70, optimize=True)
|
||||
if os.path.getsize(out_path) <= max_bytes:
|
||||
logger.info(f"[WecomBot] Image compressed with resize: {new_size}, "
|
||||
f"size={os.path.getsize(out_path)} bytes")
|
||||
return out_path
|
||||
|
||||
logger.error(f"[WecomBot] Cannot compress image below {max_bytes} bytes")
|
||||
return ""
|
||||
except Exception as e:
|
||||
logger.error(f"[WecomBot] Image compression failed: {e}")
|
||||
return ""
|
||||
|
||||
def _send_file(self, file_path: str, receiver: str, is_group: bool,
|
||||
req_id: str = None, media_type: str = "file"):
|
||||
"""Send file/video reply by uploading media first."""
|
||||
local_path = file_path
|
||||
if local_path.startswith("file://"):
|
||||
local_path = local_path[7:]
|
||||
|
||||
if local_path.startswith(("http://", "https://")):
|
||||
try:
|
||||
resp = requests.get(local_path, timeout=60)
|
||||
resp.raise_for_status()
|
||||
ext = os.path.splitext(local_path)[1] or ".bin"
|
||||
tmp_path = f"/tmp/wecom_file_{uuid.uuid4().hex[:8]}{ext}"
|
||||
with open(tmp_path, "wb") as f:
|
||||
f.write(resp.content)
|
||||
local_path = tmp_path
|
||||
except Exception as e:
|
||||
logger.error(f"[WecomBot] Failed to download file for sending: {e}")
|
||||
return
|
||||
|
||||
if not os.path.exists(local_path):
|
||||
logger.error(f"[WecomBot] File not found: {local_path}")
|
||||
return
|
||||
|
||||
media_id = self._upload_media(local_path, media_type)
|
||||
if not media_id:
|
||||
logger.error(f"[WecomBot] Failed to upload {media_type}")
|
||||
return
|
||||
|
||||
if req_id:
|
||||
self._ws_send({
|
||||
"cmd": "aibot_respond_msg",
|
||||
"headers": {"req_id": req_id},
|
||||
"body": {
|
||||
"msgtype": media_type,
|
||||
media_type: {"media_id": media_id},
|
||||
},
|
||||
})
|
||||
else:
|
||||
self._ws_send({
|
||||
"cmd": "aibot_send_msg",
|
||||
"headers": {"req_id": self._gen_req_id()},
|
||||
"body": {
|
||||
"chatid": receiver,
|
||||
"chat_type": 2 if is_group else 1,
|
||||
"msgtype": media_type,
|
||||
media_type: {"media_id": media_id},
|
||||
},
|
||||
})
|
||||
|
||||
def _active_send_markdown(self, content: str, receiver: str, is_group: bool):
|
||||
"""Proactively send markdown message (for scheduled tasks, no req_id)."""
|
||||
self._ws_send({
|
||||
"cmd": "aibot_send_msg",
|
||||
"headers": {"req_id": self._gen_req_id()},
|
||||
"body": {
|
||||
"chatid": receiver,
|
||||
"chat_type": 2 if is_group else 1,
|
||||
"msgtype": "markdown",
|
||||
"markdown": {"content": content},
|
||||
},
|
||||
})
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Media upload (chunked)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _upload_media(self, file_path: str, media_type: str = "file") -> str:
|
||||
"""
|
||||
Upload a local file to wecom bot via chunked upload protocol.
|
||||
Returns media_id on success, empty string on failure.
|
||||
"""
|
||||
if not os.path.exists(file_path):
|
||||
logger.error(f"[WecomBot] Upload file not found: {file_path}")
|
||||
return ""
|
||||
|
||||
file_size = os.path.getsize(file_path)
|
||||
if file_size < 5:
|
||||
logger.error(f"[WecomBot] File too small: {file_size} bytes")
|
||||
return ""
|
||||
|
||||
filename = os.path.basename(file_path)
|
||||
total_chunks = math.ceil(file_size / MEDIA_CHUNK_SIZE)
|
||||
if total_chunks > 100:
|
||||
logger.error(f"[WecomBot] Too many chunks: {total_chunks} > 100")
|
||||
return ""
|
||||
|
||||
file_md5 = hashlib.md5()
|
||||
with open(file_path, "rb") as f:
|
||||
for block in iter(lambda: f.read(8192), b""):
|
||||
file_md5.update(block)
|
||||
md5_hex = file_md5.hexdigest()
|
||||
|
||||
# 1. Init upload
|
||||
init_resp = self._send_and_wait({
|
||||
"cmd": "aibot_upload_media_init",
|
||||
"headers": {"req_id": self._gen_req_id()},
|
||||
"body": {
|
||||
"type": media_type,
|
||||
"filename": filename,
|
||||
"total_size": file_size,
|
||||
"total_chunks": total_chunks,
|
||||
"md5": md5_hex,
|
||||
},
|
||||
}, timeout=15)
|
||||
|
||||
if init_resp.get("errcode") != 0:
|
||||
logger.error(f"[WecomBot] Upload init failed: {init_resp}")
|
||||
return ""
|
||||
|
||||
upload_id = init_resp.get("body", {}).get("upload_id")
|
||||
if not upload_id:
|
||||
logger.error("[WecomBot] Failed to get upload_id")
|
||||
return ""
|
||||
|
||||
# 2. Upload chunks
|
||||
with open(file_path, "rb") as f:
|
||||
for idx in range(total_chunks):
|
||||
chunk = f.read(MEDIA_CHUNK_SIZE)
|
||||
b64_data = base64.b64encode(chunk).decode("utf-8")
|
||||
chunk_resp = self._send_and_wait({
|
||||
"cmd": "aibot_upload_media_chunk",
|
||||
"headers": {"req_id": self._gen_req_id()},
|
||||
"body": {
|
||||
"upload_id": upload_id,
|
||||
"chunk_index": idx,
|
||||
"base64_data": b64_data,
|
||||
},
|
||||
}, timeout=30)
|
||||
if chunk_resp.get("errcode") != 0:
|
||||
logger.error(f"[WecomBot] Chunk {idx} upload failed: {chunk_resp}")
|
||||
return ""
|
||||
|
||||
# 3. Finish upload
|
||||
finish_resp = self._send_and_wait({
|
||||
"cmd": "aibot_upload_media_finish",
|
||||
"headers": {"req_id": self._gen_req_id()},
|
||||
"body": {"upload_id": upload_id},
|
||||
}, timeout=30)
|
||||
|
||||
if finish_resp.get("errcode") != 0:
|
||||
logger.error(f"[WecomBot] Upload finish failed: {finish_resp}")
|
||||
return ""
|
||||
|
||||
media_id = finish_resp.get("body", {}).get("media_id", "")
|
||||
if media_id:
|
||||
logger.info(f"[WecomBot] Media uploaded: media_id={media_id}")
|
||||
else:
|
||||
logger.error("[WecomBot] Failed to get media_id from finish response")
|
||||
return media_id
|
||||
216
channel/wecom_bot/wecom_bot_message.py
Normal file
216
channel/wecom_bot/wecom_bot_message.py
Normal file
@@ -0,0 +1,216 @@
|
||||
import os
|
||||
import re
|
||||
import base64
|
||||
import requests
|
||||
|
||||
from bridge.context import ContextType
|
||||
from channel.chat_message import ChatMessage
|
||||
from common.log import logger
|
||||
from common.utils import expand_path
|
||||
from config import conf
|
||||
from Crypto.Cipher import AES
|
||||
|
||||
|
||||
MAGIC_SIGNATURES = [
|
||||
(b"%PDF", ".pdf"),
|
||||
(b"\x89PNG\r\n\x1a\n", ".png"),
|
||||
(b"\xff\xd8\xff", ".jpg"),
|
||||
(b"GIF87a", ".gif"),
|
||||
(b"GIF89a", ".gif"),
|
||||
(b"RIFF", ".webp"), # RIFF....WEBP, further checked below
|
||||
(b"PK\x03\x04", ".zip"), # zip / docx / xlsx / pptx
|
||||
(b"\x1f\x8b", ".gz"),
|
||||
(b"Rar!\x1a\x07", ".rar"),
|
||||
(b"7z\xbc\xaf\x27\x1c", ".7z"),
|
||||
(b"\x00\x00\x00", ".mp4"), # ftyp box, further checked below
|
||||
(b"#!AMR", ".amr"),
|
||||
]
|
||||
|
||||
OFFICE_ZIP_MARKERS = {
|
||||
b"word/": ".docx",
|
||||
b"xl/": ".xlsx",
|
||||
b"ppt/": ".pptx",
|
||||
}
|
||||
|
||||
|
||||
def _guess_ext_from_bytes(data: bytes) -> str:
|
||||
"""Guess file extension from file content magic bytes."""
|
||||
if not data or len(data) < 8:
|
||||
return ""
|
||||
for sig, ext in MAGIC_SIGNATURES:
|
||||
if data[:len(sig)] == sig:
|
||||
if ext == ".webp" and data[8:12] != b"WEBP":
|
||||
continue
|
||||
if ext == ".mp4":
|
||||
if b"ftyp" not in data[4:12]:
|
||||
continue
|
||||
if ext == ".zip":
|
||||
for marker, office_ext in OFFICE_ZIP_MARKERS.items():
|
||||
if marker in data[:2000]:
|
||||
return office_ext
|
||||
return ".zip"
|
||||
return ext
|
||||
return ""
|
||||
|
||||
|
||||
def _decrypt_media(url: str, aeskey: str) -> bytes:
|
||||
"""
|
||||
Download and decrypt AES-256-CBC encrypted media from wecom bot.
|
||||
Returns decrypted bytes.
|
||||
"""
|
||||
resp = requests.get(url, timeout=30)
|
||||
resp.raise_for_status()
|
||||
encrypted = resp.content
|
||||
|
||||
key = base64.b64decode(aeskey + "=" * (-len(aeskey) % 4))
|
||||
if len(key) != 32:
|
||||
raise ValueError(f"Invalid AES key length: {len(key)}, expected 32")
|
||||
|
||||
iv = key[:16]
|
||||
cipher = AES.new(key, AES.MODE_CBC, iv)
|
||||
decrypted = cipher.decrypt(encrypted)
|
||||
|
||||
pad_len = decrypted[-1]
|
||||
if pad_len > 32:
|
||||
raise ValueError(f"Invalid PKCS7 padding length: {pad_len}")
|
||||
return decrypted[:-pad_len]
|
||||
|
||||
|
||||
def _get_tmp_dir() -> str:
|
||||
"""Return the workspace tmp directory (absolute path), creating it if needed."""
|
||||
ws_root = expand_path(conf().get("agent_workspace", "~/cow"))
|
||||
tmp_dir = os.path.join(ws_root, "tmp")
|
||||
os.makedirs(tmp_dir, exist_ok=True)
|
||||
return tmp_dir
|
||||
|
||||
|
||||
class WecomBotMessage(ChatMessage):
|
||||
"""Message wrapper for wecom bot (websocket long-connection mode)."""
|
||||
|
||||
def __init__(self, msg_body: dict, is_group: bool = False):
|
||||
super().__init__(msg_body)
|
||||
self.msg_id = msg_body.get("msgid")
|
||||
self.create_time = msg_body.get("create_time")
|
||||
self.is_group = is_group
|
||||
|
||||
msg_type = msg_body.get("msgtype")
|
||||
from_userid = msg_body.get("from", {}).get("userid", "")
|
||||
chat_id = msg_body.get("chatid", "")
|
||||
bot_id = msg_body.get("aibotid", "")
|
||||
|
||||
if msg_type == "text":
|
||||
self.ctype = ContextType.TEXT
|
||||
content = msg_body.get("text", {}).get("content", "")
|
||||
if is_group:
|
||||
content = re.sub(r"@\S+\s*", "", content).strip()
|
||||
self.content = content
|
||||
|
||||
elif msg_type == "voice":
|
||||
self.ctype = ContextType.TEXT
|
||||
self.content = msg_body.get("voice", {}).get("content", "")
|
||||
|
||||
elif msg_type == "image":
|
||||
self.ctype = ContextType.IMAGE
|
||||
image_info = msg_body.get("image", {})
|
||||
image_url = image_info.get("url", "")
|
||||
aeskey = image_info.get("aeskey", "")
|
||||
tmp_dir = _get_tmp_dir()
|
||||
image_path = os.path.join(tmp_dir, f"wecom_{self.msg_id}.png")
|
||||
|
||||
try:
|
||||
data = _decrypt_media(image_url, aeskey)
|
||||
with open(image_path, "wb") as f:
|
||||
f.write(data)
|
||||
self.content = image_path
|
||||
self.image_path = image_path
|
||||
logger.info(f"[WecomBot] Image downloaded: {image_path}")
|
||||
except Exception as e:
|
||||
logger.error(f"[WecomBot] Failed to download image: {e}")
|
||||
self.content = "[Image download failed]"
|
||||
self.image_path = None
|
||||
|
||||
elif msg_type == "mixed":
|
||||
self.ctype = ContextType.TEXT
|
||||
text_parts = []
|
||||
image_paths = []
|
||||
mixed_items = msg_body.get("mixed", {}).get("msg_item", [])
|
||||
tmp_dir = _get_tmp_dir()
|
||||
|
||||
for idx, item in enumerate(mixed_items):
|
||||
item_type = item.get("msgtype")
|
||||
if item_type == "text":
|
||||
txt = item.get("text", {}).get("content", "")
|
||||
if is_group:
|
||||
txt = re.sub(r"@\S+\s*", "", txt).strip()
|
||||
if txt:
|
||||
text_parts.append(txt)
|
||||
elif item_type == "image":
|
||||
img_info = item.get("image", {})
|
||||
img_url = img_info.get("url", "")
|
||||
img_aeskey = img_info.get("aeskey", "")
|
||||
img_path = os.path.join(tmp_dir, f"wecom_{self.msg_id}_{idx}.png")
|
||||
try:
|
||||
img_data = _decrypt_media(img_url, img_aeskey)
|
||||
with open(img_path, "wb") as f:
|
||||
f.write(img_data)
|
||||
image_paths.append(img_path)
|
||||
except Exception as e:
|
||||
logger.error(f"[WecomBot] Failed to download mixed image: {e}")
|
||||
|
||||
content_parts = text_parts[:]
|
||||
for p in image_paths:
|
||||
content_parts.append(f"[图片: {p}]")
|
||||
self.content = "\n".join(content_parts) if content_parts else "[Mixed message]"
|
||||
|
||||
elif msg_type == "file":
|
||||
self.ctype = ContextType.FILE
|
||||
file_info = msg_body.get("file", {})
|
||||
file_url = file_info.get("url", "")
|
||||
aeskey = file_info.get("aeskey", "")
|
||||
tmp_dir = _get_tmp_dir()
|
||||
base_path = os.path.join(tmp_dir, f"wecom_{self.msg_id}")
|
||||
self.content = base_path
|
||||
|
||||
def _download_file():
|
||||
try:
|
||||
data = _decrypt_media(file_url, aeskey)
|
||||
ext = _guess_ext_from_bytes(data)
|
||||
final_path = base_path + ext
|
||||
with open(final_path, "wb") as f:
|
||||
f.write(data)
|
||||
self.content = final_path
|
||||
logger.info(f"[WecomBot] File downloaded: {final_path}")
|
||||
except Exception as e:
|
||||
logger.error(f"[WecomBot] Failed to download file: {e}")
|
||||
self._prepare_fn = _download_file
|
||||
|
||||
elif msg_type == "video":
|
||||
self.ctype = ContextType.FILE
|
||||
video_info = msg_body.get("video", {})
|
||||
video_url = video_info.get("url", "")
|
||||
aeskey = video_info.get("aeskey", "")
|
||||
tmp_dir = _get_tmp_dir()
|
||||
self.content = os.path.join(tmp_dir, f"wecom_{self.msg_id}.mp4")
|
||||
|
||||
def _download_video():
|
||||
try:
|
||||
data = _decrypt_media(video_url, aeskey)
|
||||
with open(self.content, "wb") as f:
|
||||
f.write(data)
|
||||
logger.info(f"[WecomBot] Video downloaded: {self.content}")
|
||||
except Exception as e:
|
||||
logger.error(f"[WecomBot] Failed to download video: {e}")
|
||||
self._prepare_fn = _download_video
|
||||
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported message type: {msg_type}")
|
||||
|
||||
self.from_user_id = from_userid
|
||||
self.to_user_id = bot_id
|
||||
if is_group:
|
||||
self.other_user_id = chat_id
|
||||
self.actual_user_id = from_userid
|
||||
self.actual_user_nickname = from_userid
|
||||
else:
|
||||
self.other_user_id = from_userid
|
||||
self.actual_user_id = from_userid
|
||||
0
channel/weixin/__init__.py
Normal file
0
channel/weixin/__init__.py
Normal file
412
channel/weixin/weixin_api.py
Normal file
412
channel/weixin/weixin_api.py
Normal file
@@ -0,0 +1,412 @@
|
||||
"""
|
||||
Weixin HTTP JSON API client.
|
||||
|
||||
Implements the ilink bot protocol:
|
||||
- getUpdates (long-poll)
|
||||
- sendMessage
|
||||
- getUploadUrl
|
||||
- getConfig
|
||||
- sendTyping
|
||||
- QR login (get_bot_qrcode / get_qrcode_status)
|
||||
|
||||
CDN media upload with AES-128-ECB encryption.
|
||||
"""
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
import os
|
||||
import random
|
||||
import struct
|
||||
import time
|
||||
import uuid
|
||||
|
||||
import requests
|
||||
|
||||
from common.log import logger
|
||||
|
||||
DEFAULT_BASE_URL = "https://ilinkai.weixin.qq.com"
|
||||
CDN_BASE_URL = "https://novac2c.cdn.weixin.qq.com/c2c"
|
||||
DEFAULT_LONG_POLL_TIMEOUT = 35
|
||||
DEFAULT_API_TIMEOUT = 15
|
||||
QR_POLL_TIMEOUT = 35
|
||||
BOT_TYPE = "3"
|
||||
|
||||
|
||||
def _random_wechat_uin() -> str:
|
||||
val = random.randint(0, 0xFFFFFFFF)
|
||||
return base64.b64encode(str(val).encode("utf-8")).decode("utf-8")
|
||||
|
||||
|
||||
CHANNEL_VERSION = "2.0.0"
|
||||
# iLink-App-ClientVersion: uint32 encoded as major<<16 | minor<<8 | patch
|
||||
# 2.0.0 → 0x00020000 = 131072
|
||||
CLIENT_VERSION = "131072"
|
||||
|
||||
|
||||
def _build_headers(token: str = "") -> dict:
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"AuthorizationType": "ilink_bot_token",
|
||||
"X-WECHAT-UIN": _random_wechat_uin(),
|
||||
"iLink-App-Id": "bot",
|
||||
"iLink-App-ClientVersion": CLIENT_VERSION,
|
||||
}
|
||||
if token:
|
||||
headers["Authorization"] = f"Bearer {token}"
|
||||
return headers
|
||||
|
||||
|
||||
def _ensure_trailing_slash(url: str) -> str:
|
||||
return url if url.endswith("/") else url + "/"
|
||||
|
||||
|
||||
class WeixinApi:
|
||||
"""Stateless HTTP client for the Weixin ilink bot API."""
|
||||
|
||||
def __init__(self, base_url: str = DEFAULT_BASE_URL, token: str = "",
|
||||
cdn_base_url: str = CDN_BASE_URL):
|
||||
self.base_url = base_url
|
||||
self.token = token
|
||||
self.cdn_base_url = cdn_base_url
|
||||
|
||||
def _post(self, endpoint: str, body: dict, timeout: int = DEFAULT_API_TIMEOUT) -> dict:
|
||||
url = _ensure_trailing_slash(self.base_url) + endpoint
|
||||
headers = _build_headers(self.token)
|
||||
body.setdefault("base_info", {}).setdefault("channel_version", CHANNEL_VERSION)
|
||||
try:
|
||||
resp = requests.post(url, json=body, headers=headers, timeout=timeout)
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
except requests.exceptions.Timeout:
|
||||
logger.debug(f"[Weixin] API timeout: {endpoint}")
|
||||
return {"ret": 0, "msgs": []}
|
||||
except Exception as e:
|
||||
logger.error(f"[Weixin] API error {endpoint}: {e}")
|
||||
raise
|
||||
|
||||
# ── getUpdates (long-poll) ─────────────────────────────────────────
|
||||
|
||||
def get_updates(self, get_updates_buf: str = "", timeout: int = DEFAULT_LONG_POLL_TIMEOUT) -> dict:
|
||||
return self._post("ilink/bot/getupdates", {
|
||||
"get_updates_buf": get_updates_buf,
|
||||
}, timeout=timeout + 5)
|
||||
|
||||
# ── sendMessage ────────────────────────────────────────────────────
|
||||
|
||||
def send_text(self, to: str, text: str, context_token: str) -> dict:
|
||||
return self._post("ilink/bot/sendmessage", {
|
||||
"msg": {
|
||||
"from_user_id": "",
|
||||
"to_user_id": to,
|
||||
"client_id": uuid.uuid4().hex[:16],
|
||||
"message_type": 2, # BOT
|
||||
"message_state": 2, # FINISH
|
||||
"item_list": [{"type": 1, "text_item": {"text": text}}],
|
||||
"context_token": context_token,
|
||||
}
|
||||
})
|
||||
|
||||
def send_image_item(self, to: str, context_token: str,
|
||||
encrypt_query_param: str, aes_key_b64: str,
|
||||
ciphertext_size: int, text: str = "") -> dict:
|
||||
items = []
|
||||
if text:
|
||||
items.append({"type": 1, "text_item": {"text": text}})
|
||||
items.append({
|
||||
"type": 2,
|
||||
"image_item": {
|
||||
"media": {
|
||||
"encrypt_query_param": encrypt_query_param,
|
||||
"aes_key": aes_key_b64,
|
||||
"encrypt_type": 1,
|
||||
},
|
||||
"mid_size": ciphertext_size,
|
||||
}
|
||||
})
|
||||
return self._send_items(to, context_token, items)
|
||||
|
||||
def send_file_item(self, to: str, context_token: str,
|
||||
encrypt_query_param: str, aes_key_b64: str,
|
||||
file_name: str, file_size: int, text: str = "") -> dict:
|
||||
items = []
|
||||
if text:
|
||||
items.append({"type": 1, "text_item": {"text": text}})
|
||||
items.append({
|
||||
"type": 4,
|
||||
"file_item": {
|
||||
"media": {
|
||||
"encrypt_query_param": encrypt_query_param,
|
||||
"aes_key": aes_key_b64,
|
||||
"encrypt_type": 1,
|
||||
},
|
||||
"file_name": file_name,
|
||||
"len": str(file_size),
|
||||
}
|
||||
})
|
||||
return self._send_items(to, context_token, items)
|
||||
|
||||
def send_video_item(self, to: str, context_token: str,
|
||||
encrypt_query_param: str, aes_key_b64: str,
|
||||
ciphertext_size: int, text: str = "") -> dict:
|
||||
items = []
|
||||
if text:
|
||||
items.append({"type": 1, "text_item": {"text": text}})
|
||||
items.append({
|
||||
"type": 5,
|
||||
"video_item": {
|
||||
"media": {
|
||||
"encrypt_query_param": encrypt_query_param,
|
||||
"aes_key": aes_key_b64,
|
||||
"encrypt_type": 1,
|
||||
},
|
||||
"video_size": ciphertext_size,
|
||||
}
|
||||
})
|
||||
return self._send_items(to, context_token, items)
|
||||
|
||||
def _send_items(self, to: str, context_token: str, items: list) -> dict:
|
||||
return self._post("ilink/bot/sendmessage", {
|
||||
"msg": {
|
||||
"from_user_id": "",
|
||||
"to_user_id": to,
|
||||
"client_id": uuid.uuid4().hex[:16],
|
||||
"message_type": 2,
|
||||
"message_state": 2,
|
||||
"item_list": items,
|
||||
"context_token": context_token,
|
||||
}
|
||||
})
|
||||
|
||||
# ── getUploadUrl ───────────────────────────────────────────────────
|
||||
|
||||
def get_upload_url(self, filekey: str, media_type: int, to_user_id: str,
|
||||
rawsize: int, rawfilemd5: str, filesize: int,
|
||||
aeskey: str) -> dict:
|
||||
return self._post("ilink/bot/getuploadurl", {
|
||||
"filekey": filekey,
|
||||
"media_type": media_type,
|
||||
"to_user_id": to_user_id,
|
||||
"rawsize": rawsize,
|
||||
"rawfilemd5": rawfilemd5,
|
||||
"filesize": filesize,
|
||||
"aeskey": aeskey,
|
||||
"no_need_thumb": True,
|
||||
})
|
||||
|
||||
# ── getConfig / sendTyping ─────────────────────────────────────────
|
||||
|
||||
def get_config(self, user_id: str, context_token: str = "") -> dict:
|
||||
return self._post("ilink/bot/getconfig", {
|
||||
"ilink_user_id": user_id,
|
||||
"context_token": context_token,
|
||||
}, timeout=10)
|
||||
|
||||
def send_typing(self, user_id: str, typing_ticket: str, status: int = 1) -> dict:
|
||||
return self._post("ilink/bot/sendtyping", {
|
||||
"ilink_user_id": user_id,
|
||||
"typing_ticket": typing_ticket,
|
||||
"status": status,
|
||||
}, timeout=10)
|
||||
|
||||
# ── QR Login ───────────────────────────────────────────────────────
|
||||
|
||||
def fetch_qr_code(self) -> dict:
|
||||
url = _ensure_trailing_slash(self.base_url) + f"ilink/bot/get_bot_qrcode?bot_type={BOT_TYPE}"
|
||||
resp = requests.get(url, timeout=15)
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
|
||||
def poll_qr_status(self, qrcode: str, timeout: int = QR_POLL_TIMEOUT) -> dict:
|
||||
url = (_ensure_trailing_slash(self.base_url) +
|
||||
f"ilink/bot/get_qrcode_status?qrcode={requests.utils.quote(qrcode)}")
|
||||
headers = {
|
||||
"iLink-App-Id": "bot",
|
||||
"iLink-App-ClientVersion": CLIENT_VERSION,
|
||||
}
|
||||
try:
|
||||
resp = requests.get(url, headers=headers, timeout=timeout)
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
except requests.exceptions.Timeout:
|
||||
return {"status": "wait"}
|
||||
|
||||
|
||||
# ── AES-128-ECB helpers ─────────────────────────────────────────────
|
||||
|
||||
def _aes_ecb_encrypt(data: bytes, key: bytes) -> bytes:
|
||||
from Crypto.Cipher import AES
|
||||
pad_len = 16 - (len(data) % 16)
|
||||
padded = data + bytes([pad_len] * pad_len)
|
||||
cipher = AES.new(key, AES.MODE_ECB)
|
||||
return cipher.encrypt(padded)
|
||||
|
||||
|
||||
def _aes_ecb_decrypt(data: bytes, key: bytes) -> bytes:
|
||||
from Crypto.Cipher import AES
|
||||
cipher = AES.new(key, AES.MODE_ECB)
|
||||
decrypted = cipher.decrypt(data)
|
||||
pad_len = decrypted[-1]
|
||||
if pad_len > 16:
|
||||
return decrypted
|
||||
return decrypted[:-pad_len]
|
||||
|
||||
|
||||
def _file_md5(file_path: str) -> str:
|
||||
h = hashlib.md5()
|
||||
with open(file_path, "rb") as f:
|
||||
for chunk in iter(lambda: f.read(8192), b""):
|
||||
h.update(chunk)
|
||||
return h.hexdigest()
|
||||
|
||||
|
||||
def _md5_bytes(data: bytes) -> str:
|
||||
return hashlib.md5(data).hexdigest()
|
||||
|
||||
|
||||
def _aes_ecb_padded_size(plaintext_size: int) -> int:
|
||||
"""PKCS7 padded size for AES-128-ECB."""
|
||||
return ((plaintext_size + 1 + 15) // 16) * 16
|
||||
|
||||
|
||||
UPLOAD_MAX_RETRIES = 3
|
||||
|
||||
|
||||
def upload_media_to_cdn(api: WeixinApi, file_path: str, to_user_id: str,
|
||||
media_type: int) -> dict:
|
||||
"""
|
||||
Upload a local file to the Weixin CDN (matching official plugin protocol).
|
||||
|
||||
Args:
|
||||
api: WeixinApi instance
|
||||
file_path: local file path
|
||||
to_user_id: target user id
|
||||
media_type: 1=IMAGE, 2=VIDEO, 3=FILE
|
||||
|
||||
Returns:
|
||||
dict with keys: encrypt_query_param, aes_key_b64, ciphertext_size, raw_size
|
||||
"""
|
||||
aes_key = os.urandom(16)
|
||||
aes_key_hex = aes_key.hex()
|
||||
filekey = uuid.uuid4().hex
|
||||
|
||||
with open(file_path, "rb") as f:
|
||||
raw_data = f.read()
|
||||
|
||||
raw_size = len(raw_data)
|
||||
raw_md5 = _md5_bytes(raw_data)
|
||||
cipher_size = _aes_ecb_padded_size(raw_size)
|
||||
|
||||
encrypted = _aes_ecb_encrypt(raw_data, aes_key)
|
||||
|
||||
from urllib.parse import quote
|
||||
|
||||
download_param = None
|
||||
last_error = None
|
||||
for attempt in range(1, UPLOAD_MAX_RETRIES + 1):
|
||||
try:
|
||||
if attempt > 1:
|
||||
filekey = uuid.uuid4().hex
|
||||
resp = api.get_upload_url(
|
||||
filekey=filekey,
|
||||
media_type=media_type,
|
||||
to_user_id=to_user_id,
|
||||
rawsize=raw_size,
|
||||
rawfilemd5=raw_md5,
|
||||
filesize=cipher_size,
|
||||
aeskey=aes_key_hex,
|
||||
)
|
||||
|
||||
# API may return either upload_full_url (new) or upload_param (legacy)
|
||||
upload_full_url = resp.get("upload_full_url", "")
|
||||
upload_param = resp.get("upload_param", "")
|
||||
if upload_full_url:
|
||||
cdn_url = upload_full_url
|
||||
elif upload_param:
|
||||
cdn_url = (f"{api.cdn_base_url}/upload"
|
||||
f"?encrypted_query_param={quote(upload_param)}"
|
||||
f"&filekey={quote(filekey)}")
|
||||
else:
|
||||
raise RuntimeError(f"[Weixin] getUploadUrl returned neither upload_full_url nor upload_param: {resp}")
|
||||
|
||||
cdn_resp = requests.post(cdn_url, data=encrypted, headers={
|
||||
"Content-Type": "application/octet-stream",
|
||||
"Content-Length": str(len(encrypted)),
|
||||
}, timeout=120)
|
||||
if 400 <= cdn_resp.status_code < 500:
|
||||
err_msg = cdn_resp.headers.get("x-error-message", cdn_resp.text[:200])
|
||||
raise RuntimeError(f"CDN client error {cdn_resp.status_code}: {err_msg}")
|
||||
cdn_resp.raise_for_status()
|
||||
download_param = cdn_resp.headers.get("x-encrypted-param", "")
|
||||
if not download_param:
|
||||
raise RuntimeError("CDN response missing x-encrypted-param header")
|
||||
logger.debug(f"[Weixin] CDN upload success attempt={attempt} filekey={filekey}")
|
||||
break
|
||||
except Exception as e:
|
||||
last_error = e
|
||||
if "client error" in str(e):
|
||||
raise
|
||||
if attempt < UPLOAD_MAX_RETRIES:
|
||||
backoff = 2 ** attempt
|
||||
logger.warning(f"[Weixin] CDN upload attempt {attempt} failed, retrying in {backoff}s: {e}")
|
||||
time.sleep(backoff)
|
||||
else:
|
||||
logger.error(f"[Weixin] CDN upload failed after {UPLOAD_MAX_RETRIES} attempts: {e}")
|
||||
|
||||
if not download_param:
|
||||
raise last_error or RuntimeError("CDN upload failed")
|
||||
|
||||
aes_key_b64 = base64.b64encode(aes_key_hex.encode("utf-8")).decode("utf-8")
|
||||
|
||||
return {
|
||||
"encrypt_query_param": download_param,
|
||||
"aes_key_b64": aes_key_b64,
|
||||
"ciphertext_size": cipher_size,
|
||||
"raw_size": raw_size,
|
||||
}
|
||||
|
||||
|
||||
def download_media_from_cdn(cdn_base_url: str, encrypt_query_param: str,
|
||||
aes_key: str, save_path: str) -> str:
|
||||
"""
|
||||
Download and decrypt a media file from Weixin CDN.
|
||||
|
||||
Args:
|
||||
cdn_base_url: CDN base URL
|
||||
encrypt_query_param: encrypted query parameter from message
|
||||
aes_key: hex or base64 encoded AES key
|
||||
save_path: path to save decrypted file
|
||||
|
||||
Returns:
|
||||
save_path on success
|
||||
"""
|
||||
from urllib.parse import quote
|
||||
url = f"{cdn_base_url}/download?encrypted_query_param={quote(encrypt_query_param)}"
|
||||
resp = requests.get(url, timeout=60)
|
||||
resp.raise_for_status()
|
||||
|
||||
# Determine key format:
|
||||
# 1) 32-char hex string → 16 raw bytes
|
||||
# 2) base64 string → decode → if 32 bytes, treat as hex-encoded → 16 raw bytes
|
||||
# 3) base64 string → decode → 16 raw bytes directly
|
||||
try:
|
||||
key_bytes = bytes.fromhex(aes_key)
|
||||
if len(key_bytes) != 16:
|
||||
raise ValueError()
|
||||
except (ValueError, TypeError):
|
||||
decoded = base64.b64decode(aes_key)
|
||||
if len(decoded) == 32:
|
||||
try:
|
||||
key_bytes = bytes.fromhex(decoded.decode("ascii"))
|
||||
except (ValueError, UnicodeDecodeError):
|
||||
raise ValueError(f"Invalid AES key: 32 bytes but not valid hex")
|
||||
elif len(decoded) == 16:
|
||||
key_bytes = decoded
|
||||
else:
|
||||
raise ValueError(f"Invalid AES key length after base64 decode: {len(decoded)}")
|
||||
|
||||
decrypted = _aes_ecb_decrypt(resp.content, key_bytes)
|
||||
|
||||
os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
||||
with open(save_path, "wb") as f:
|
||||
f.write(decrypted)
|
||||
return save_path
|
||||
637
channel/weixin/weixin_channel.py
Normal file
637
channel/weixin/weixin_channel.py
Normal file
@@ -0,0 +1,637 @@
|
||||
"""
|
||||
Weixin channel implementation.
|
||||
|
||||
Uses HTTP long-poll (getUpdates) to receive messages and sendMessage to reply.
|
||||
Login via QR code scan through the ilink bot API.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
|
||||
import requests
|
||||
|
||||
from bridge.context import Context, ContextType
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from channel.chat_channel import ChatChannel, check_prefix
|
||||
from channel.weixin.weixin_api import (
|
||||
WeixinApi, upload_media_to_cdn,
|
||||
DEFAULT_BASE_URL, CDN_BASE_URL,
|
||||
)
|
||||
from channel.weixin.weixin_message import WeixinMessage
|
||||
from common.expired_dict import ExpiredDict
|
||||
from common.log import logger
|
||||
from common.singleton import singleton
|
||||
from config import conf
|
||||
|
||||
MAX_CONSECUTIVE_FAILURES = 3
|
||||
BACKOFF_DELAY = 30
|
||||
RETRY_DELAY = 2
|
||||
SESSION_EXPIRED_ERRCODE = -14
|
||||
TEXT_CHUNK_LIMIT = 4000
|
||||
QR_LOGIN_TIMEOUT_S = 480
|
||||
QR_MAX_REFRESHES = 10
|
||||
|
||||
|
||||
def _load_credentials(cred_path: str) -> dict:
|
||||
"""Load saved credentials from JSON file."""
|
||||
try:
|
||||
if os.path.exists(cred_path):
|
||||
with open(cred_path, "r") as f:
|
||||
return json.load(f)
|
||||
except Exception as e:
|
||||
logger.warning(f"[Weixin] Failed to load credentials: {e}")
|
||||
return {}
|
||||
|
||||
|
||||
def _save_credentials(cred_path: str, data: dict):
|
||||
"""Save credentials to JSON file."""
|
||||
os.makedirs(os.path.dirname(cred_path), exist_ok=True)
|
||||
with open(cred_path, "w") as f:
|
||||
json.dump(data, f, indent=2)
|
||||
try:
|
||||
os.chmod(cred_path, 0o600)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
@singleton
|
||||
class WeixinChannel(ChatChannel):
|
||||
|
||||
LOGIN_STATUS_IDLE = "idle"
|
||||
LOGIN_STATUS_WAITING = "waiting_scan"
|
||||
LOGIN_STATUS_SCANNED = "scanned"
|
||||
LOGIN_STATUS_OK = "logged_in"
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.api = None
|
||||
self._stop_event = threading.Event()
|
||||
self._poll_thread = None
|
||||
self._context_tokens = {} # user_id -> context_token
|
||||
self._received_msgs = ExpiredDict(60 * 60 * 7.1)
|
||||
self._get_updates_buf = ""
|
||||
self._credentials_path = ""
|
||||
self.login_status = self.LOGIN_STATUS_IDLE
|
||||
self._current_qr_url = ""
|
||||
|
||||
conf()["single_chat_prefix"] = [""]
|
||||
|
||||
# ── Lifecycle ──────────────────────────────────────────────────────
|
||||
|
||||
def startup(self):
|
||||
self._stop_event.clear()
|
||||
|
||||
base_url = conf().get("weixin_base_url", DEFAULT_BASE_URL)
|
||||
cdn_base_url = conf().get("weixin_cdn_base_url", CDN_BASE_URL)
|
||||
token = conf().get("weixin_token", "")
|
||||
|
||||
self._credentials_path = os.path.expanduser(
|
||||
conf().get("weixin_credentials_path", "~/.weixin_cow_credentials.json")
|
||||
)
|
||||
|
||||
if not token:
|
||||
creds = _load_credentials(self._credentials_path)
|
||||
token = creds.get("token", "")
|
||||
if creds.get("base_url"):
|
||||
base_url = creds["base_url"]
|
||||
|
||||
if not token:
|
||||
token, base_url = self._login_with_retry(base_url)
|
||||
if not token:
|
||||
return
|
||||
|
||||
self.api = WeixinApi(base_url=base_url, token=token, cdn_base_url=cdn_base_url)
|
||||
self.login_status = self.LOGIN_STATUS_OK
|
||||
|
||||
logger.info(f"[Weixin] 微信通道已启动,凭证保存在 {self._credentials_path},"
|
||||
f"如需重新扫码登录请删除该文件后重启")
|
||||
self.report_startup_success()
|
||||
|
||||
self._poll_loop()
|
||||
|
||||
def _login_with_retry(self, base_url: str) -> tuple:
|
||||
"""Attempt QR login, then wait for stop if failed.
|
||||
Returns (token, base_url) on success, or ("", "") if stopped."""
|
||||
logger.info("[Weixin] No token found, starting QR login...")
|
||||
self.login_status = self.LOGIN_STATUS_WAITING
|
||||
login_result = self._qr_login(base_url)
|
||||
if login_result:
|
||||
return login_result["token"], login_result.get("base_url", base_url)
|
||||
|
||||
self.login_status = self.LOGIN_STATUS_IDLE
|
||||
if not self._stop_event.is_set():
|
||||
logger.info("[Weixin] QR login timed out, waiting for stop or reconnect...")
|
||||
print(" 二维码登录超时,请通过控制台重新接入\n")
|
||||
self._stop_event.wait()
|
||||
|
||||
logger.info("[Weixin] Login cancelled by stop event")
|
||||
return "", ""
|
||||
|
||||
def stop(self):
|
||||
logger.info("[Weixin] stop() called")
|
||||
self._stop_event.set()
|
||||
|
||||
def _relogin(self) -> bool:
|
||||
"""Re-login after session expiry. Returns True on success."""
|
||||
base_url = self.api.base_url if self.api else DEFAULT_BASE_URL
|
||||
if os.path.exists(self._credentials_path):
|
||||
try:
|
||||
os.remove(self._credentials_path)
|
||||
except Exception:
|
||||
pass
|
||||
self.login_status = self.LOGIN_STATUS_WAITING
|
||||
result = self._qr_login(base_url)
|
||||
if not result:
|
||||
self.login_status = self.LOGIN_STATUS_IDLE
|
||||
return False
|
||||
self.api = WeixinApi(
|
||||
base_url=result.get("base_url", base_url),
|
||||
token=result["token"],
|
||||
cdn_base_url=self.api.cdn_base_url if self.api else CDN_BASE_URL,
|
||||
)
|
||||
self.login_status = self.LOGIN_STATUS_OK
|
||||
self._context_tokens.clear()
|
||||
return True
|
||||
|
||||
# ── QR Login ───────────────────────────────────────────────────────
|
||||
|
||||
@staticmethod
|
||||
def _print_qr(qrcode_url: str):
|
||||
"""Print QR code to terminal for scanning."""
|
||||
print("\n" + "=" * 60)
|
||||
print(" 请使用微信扫描二维码登录 (二维码约2分钟后过期)")
|
||||
print("=" * 60)
|
||||
try:
|
||||
import qrcode as qr_lib
|
||||
import io
|
||||
qr = qr_lib.QRCode(error_correction=qr_lib.constants.ERROR_CORRECT_L, box_size=1, border=1)
|
||||
qr.add_data(qrcode_url)
|
||||
qr.make(fit=True)
|
||||
buf = io.StringIO()
|
||||
qr.print_ascii(out=buf, invert=True)
|
||||
try:
|
||||
print(buf.getvalue())
|
||||
except UnicodeEncodeError:
|
||||
# Windows GBK terminals cannot render Unicode block characters
|
||||
print(f"\n (终端不支持显示二维码,请使用链接扫码)")
|
||||
print(f" 二维码链接: {qrcode_url}\n")
|
||||
except ImportError:
|
||||
print(f"\n 二维码链接: {qrcode_url}")
|
||||
print(" (安装 'qrcode' 包可在终端显示二维码)\n")
|
||||
|
||||
def _notify_cloud_qrcode(self, qrcode_url: str):
|
||||
"""Send QR code URL to cloud console when running in cloud mode."""
|
||||
if not self.cloud_mode:
|
||||
return
|
||||
try:
|
||||
from common import cloud_client
|
||||
client = getattr(cloud_client, "chat_client", None)
|
||||
if client and getattr(client, "client_id", None):
|
||||
client.send_channel_qrcode("weixin", qrcode_url)
|
||||
except Exception as e:
|
||||
logger.warning(f"[Weixin] Failed to notify cloud QR code: {e}")
|
||||
|
||||
def _notify_cloud_connected(self):
|
||||
"""Send connected status to cloud console when login succeeds."""
|
||||
if not self.cloud_mode:
|
||||
return
|
||||
try:
|
||||
from common import cloud_client
|
||||
client = getattr(cloud_client, "chat_client", None)
|
||||
if client and getattr(client, "client_id", None):
|
||||
client.send_channel_status("weixin", "connected")
|
||||
except Exception as e:
|
||||
logger.warning(f"[Weixin] Failed to notify cloud connected: {e}")
|
||||
|
||||
def _qr_login(self, base_url: str) -> dict:
|
||||
"""Perform interactive QR code login. Returns dict with token/base_url or empty dict."""
|
||||
api = WeixinApi(base_url=base_url)
|
||||
try:
|
||||
qr_resp = api.fetch_qr_code()
|
||||
except Exception as e:
|
||||
logger.error(f"[Weixin] Failed to fetch QR code: {e}")
|
||||
return {}
|
||||
|
||||
qrcode = qr_resp.get("qrcode", "")
|
||||
qrcode_url = qr_resp.get("qrcode_img_content", "")
|
||||
|
||||
if not qrcode:
|
||||
logger.error("[Weixin] No QR code returned from server")
|
||||
return {}
|
||||
|
||||
self._current_qr_url = qrcode_url
|
||||
logger.info(f"[Weixin] 微信二维码链接: {qrcode_url}")
|
||||
self._print_qr(qrcode_url)
|
||||
self._notify_cloud_qrcode(qrcode_url)
|
||||
print(" 等待扫码...\n")
|
||||
|
||||
scanned_printed = False
|
||||
refresh_count = 0
|
||||
deadline = time.time() + QR_LOGIN_TIMEOUT_S
|
||||
|
||||
while not self._stop_event.is_set():
|
||||
if time.time() >= deadline:
|
||||
logger.warning(f"[Weixin] QR login timed out after {QR_LOGIN_TIMEOUT_S}s")
|
||||
print(f"\n 二维码登录超时({QR_LOGIN_TIMEOUT_S}s),请重启后重试")
|
||||
break
|
||||
|
||||
try:
|
||||
status_resp = api.poll_qr_status(qrcode)
|
||||
except Exception as e:
|
||||
logger.error(f"[Weixin] QR status poll error: {e}")
|
||||
return {}
|
||||
|
||||
status = status_resp.get("status", "wait")
|
||||
|
||||
if status == "wait":
|
||||
pass
|
||||
elif status == "scaned":
|
||||
self.login_status = self.LOGIN_STATUS_SCANNED
|
||||
if not scanned_printed:
|
||||
print(" 已扫码,请在手机上确认...")
|
||||
scanned_printed = True
|
||||
elif status == "expired":
|
||||
refresh_count += 1
|
||||
if refresh_count >= QR_MAX_REFRESHES:
|
||||
logger.warning(f"[Weixin] QR code refreshed {QR_MAX_REFRESHES} times, giving up")
|
||||
print(f"\n 二维码已刷新 {QR_MAX_REFRESHES} 次仍未扫码,请重启后重试")
|
||||
break
|
||||
print(f" 二维码已过期,正在刷新({refresh_count}/{QR_MAX_REFRESHES})...")
|
||||
try:
|
||||
qr_resp = api.fetch_qr_code()
|
||||
qrcode = qr_resp.get("qrcode", "")
|
||||
qrcode_url = qr_resp.get("qrcode_img_content", "")
|
||||
scanned_printed = False
|
||||
self._current_qr_url = qrcode_url
|
||||
logger.info(f"[Weixin] 微信二维码链接 ({refresh_count}/{QR_MAX_REFRESHES}): {qrcode_url}")
|
||||
self._print_qr(qrcode_url)
|
||||
self._notify_cloud_qrcode(qrcode_url)
|
||||
except Exception as e:
|
||||
logger.error(f"[Weixin] QR refresh failed: {e}")
|
||||
return {}
|
||||
elif status == "confirmed":
|
||||
bot_token = status_resp.get("bot_token", "")
|
||||
bot_id = status_resp.get("ilink_bot_id", "")
|
||||
result_base_url = status_resp.get("baseurl", base_url)
|
||||
user_id = status_resp.get("ilink_user_id", "")
|
||||
|
||||
if not bot_token or not bot_id:
|
||||
logger.error("[Weixin] Login confirmed but missing token/bot_id")
|
||||
return {}
|
||||
|
||||
self._current_qr_url = ""
|
||||
print(f"\n ✅ 微信登录成功!bot_id={bot_id}")
|
||||
logger.info(f"[Weixin] Login confirmed: bot_id={bot_id}")
|
||||
self._notify_cloud_connected()
|
||||
|
||||
creds = {
|
||||
"token": bot_token,
|
||||
"base_url": result_base_url,
|
||||
"bot_id": bot_id,
|
||||
"user_id": user_id,
|
||||
}
|
||||
_save_credentials(self._credentials_path, creds)
|
||||
logger.info(f"[Weixin] Credentials saved to {self._credentials_path}")
|
||||
|
||||
return {"token": bot_token, "base_url": result_base_url}
|
||||
|
||||
self._stop_event.wait(1)
|
||||
|
||||
self._current_qr_url = ""
|
||||
if self._stop_event.is_set():
|
||||
logger.info("[Weixin] QR login cancelled by stop event")
|
||||
return {}
|
||||
|
||||
# ── Long-poll loop ─────────────────────────────────────────────────
|
||||
|
||||
def _poll_loop(self):
|
||||
"""Main long-poll loop: getUpdates -> parse -> produce."""
|
||||
logger.info("[Weixin] Starting long-poll loop")
|
||||
consecutive_failures = 0
|
||||
|
||||
while not self._stop_event.is_set():
|
||||
try:
|
||||
resp = self.api.get_updates(self._get_updates_buf)
|
||||
|
||||
ret = resp.get("ret", 0)
|
||||
errcode = resp.get("errcode", 0)
|
||||
|
||||
is_error = (ret != 0) or (errcode != 0)
|
||||
if is_error:
|
||||
if errcode == SESSION_EXPIRED_ERRCODE or ret == SESSION_EXPIRED_ERRCODE:
|
||||
logger.error("[Weixin] Session expired (errcode -14), starting re-login...")
|
||||
if self._relogin():
|
||||
logger.info("[Weixin] Re-login successful, resuming long-poll")
|
||||
self._get_updates_buf = ""
|
||||
consecutive_failures = 0
|
||||
continue
|
||||
else:
|
||||
logger.error("[Weixin] Re-login failed, will retry in 5 minutes")
|
||||
self._stop_event.wait(300)
|
||||
continue
|
||||
|
||||
consecutive_failures += 1
|
||||
errmsg = resp.get("errmsg", "")
|
||||
logger.error(f"[Weixin] getUpdates error: ret={ret} errcode={errcode} "
|
||||
f"errmsg={errmsg} ({consecutive_failures}/{MAX_CONSECUTIVE_FAILURES})")
|
||||
if consecutive_failures >= MAX_CONSECUTIVE_FAILURES:
|
||||
consecutive_failures = 0
|
||||
self._stop_event.wait(BACKOFF_DELAY)
|
||||
else:
|
||||
self._stop_event.wait(RETRY_DELAY)
|
||||
continue
|
||||
|
||||
consecutive_failures = 0
|
||||
|
||||
# Update sync cursor
|
||||
new_buf = resp.get("get_updates_buf", "")
|
||||
if new_buf:
|
||||
self._get_updates_buf = new_buf
|
||||
|
||||
# Process messages
|
||||
msgs = resp.get("msgs", [])
|
||||
for raw_msg in msgs:
|
||||
try:
|
||||
self._process_message(raw_msg)
|
||||
except Exception as e:
|
||||
logger.error(f"[Weixin] Failed to process message: {e}", exc_info=True)
|
||||
|
||||
except Exception as e:
|
||||
if self._stop_event.is_set():
|
||||
break
|
||||
consecutive_failures += 1
|
||||
logger.error(f"[Weixin] getUpdates exception: {e} "
|
||||
f"({consecutive_failures}/{MAX_CONSECUTIVE_FAILURES})")
|
||||
if consecutive_failures >= MAX_CONSECUTIVE_FAILURES:
|
||||
consecutive_failures = 0
|
||||
self._stop_event.wait(BACKOFF_DELAY)
|
||||
else:
|
||||
self._stop_event.wait(RETRY_DELAY)
|
||||
|
||||
logger.info("[Weixin] Long-poll loop ended")
|
||||
|
||||
def _process_message(self, raw_msg: dict):
|
||||
"""Parse a single inbound message and produce to the handling queue."""
|
||||
msg_type = raw_msg.get("message_type", 0)
|
||||
if msg_type != 1: # Only process USER messages (type=1)
|
||||
return
|
||||
|
||||
msg_id = str(raw_msg.get("message_id", raw_msg.get("seq", "")))
|
||||
if self._received_msgs.get(msg_id):
|
||||
return
|
||||
self._received_msgs[msg_id] = True
|
||||
|
||||
from_user = raw_msg.get("from_user_id", "")
|
||||
context_token = raw_msg.get("context_token", "")
|
||||
|
||||
if context_token and from_user:
|
||||
self._context_tokens[from_user] = context_token
|
||||
|
||||
cdn_base_url = self.api.cdn_base_url if self.api else CDN_BASE_URL
|
||||
try:
|
||||
wx_msg = WeixinMessage(raw_msg, cdn_base_url=cdn_base_url)
|
||||
except Exception as e:
|
||||
logger.error(f"[Weixin] Failed to parse WeixinMessage: {e}", exc_info=True)
|
||||
return
|
||||
|
||||
logger.info(f"[Weixin] Received: from={from_user} ctype={wx_msg.ctype} "
|
||||
f"content={str(wx_msg.content)[:50]}")
|
||||
|
||||
# File cache logic
|
||||
from channel.file_cache import get_file_cache
|
||||
file_cache = get_file_cache()
|
||||
session_id = from_user
|
||||
|
||||
if wx_msg.ctype == ContextType.IMAGE:
|
||||
if hasattr(wx_msg, "image_path") and wx_msg.image_path:
|
||||
file_cache.add(session_id, wx_msg.image_path, file_type="image")
|
||||
logger.info(f"[Weixin] Image cached for session {session_id}")
|
||||
return
|
||||
|
||||
if wx_msg.ctype == ContextType.FILE:
|
||||
wx_msg.prepare()
|
||||
file_cache.add(session_id, wx_msg.content, file_type="file")
|
||||
logger.info(f"[Weixin] File cached for session {session_id}: {wx_msg.content}")
|
||||
return
|
||||
|
||||
if wx_msg.ctype == ContextType.TEXT:
|
||||
cached_files = file_cache.get(session_id)
|
||||
if cached_files:
|
||||
refs = []
|
||||
for fi in cached_files:
|
||||
ftype, fpath = fi["type"], fi["path"]
|
||||
if ftype == "image":
|
||||
refs.append(f"[图片: {fpath}]")
|
||||
elif ftype == "video":
|
||||
refs.append(f"[视频: {fpath}]")
|
||||
else:
|
||||
refs.append(f"[文件: {fpath}]")
|
||||
wx_msg.content = wx_msg.content + "\n" + "\n".join(refs)
|
||||
file_cache.clear(session_id)
|
||||
|
||||
context = self._compose_context(
|
||||
wx_msg.ctype,
|
||||
wx_msg.content,
|
||||
isgroup=False,
|
||||
msg=wx_msg,
|
||||
no_need_at=True,
|
||||
)
|
||||
if context:
|
||||
self.produce(context)
|
||||
|
||||
# ── _compose_context ───────────────────────────────────────────────
|
||||
|
||||
def _compose_context(self, ctype: ContextType, content, **kwargs):
|
||||
context = Context(ctype, content)
|
||||
context.kwargs = kwargs
|
||||
if "channel_type" not in context:
|
||||
context["channel_type"] = self.channel_type
|
||||
if "origin_ctype" not in context:
|
||||
context["origin_ctype"] = ctype
|
||||
|
||||
cmsg = context["msg"]
|
||||
context["session_id"] = cmsg.from_user_id
|
||||
context["receiver"] = cmsg.other_user_id
|
||||
|
||||
if ctype == ContextType.TEXT:
|
||||
img_match_prefix = check_prefix(content, conf().get("image_create_prefix"))
|
||||
if img_match_prefix:
|
||||
content = content.replace(img_match_prefix, "", 1)
|
||||
context.type = ContextType.IMAGE_CREATE
|
||||
else:
|
||||
context.type = ContextType.TEXT
|
||||
context.content = content.strip()
|
||||
|
||||
return context
|
||||
|
||||
# ── Send reply ─────────────────────────────────────────────────────
|
||||
|
||||
def send(self, reply: Reply, context: Context):
|
||||
receiver = context.get("receiver", "")
|
||||
msg = context.get("msg")
|
||||
context_token = self._get_context_token(receiver, msg)
|
||||
|
||||
if not context_token:
|
||||
logger.error(f"[Weixin] No context_token for receiver={receiver}, cannot send")
|
||||
return
|
||||
|
||||
if reply.type == ReplyType.TEXT:
|
||||
self._send_text(reply.content, receiver, context_token)
|
||||
elif reply.type in (ReplyType.IMAGE_URL, ReplyType.IMAGE):
|
||||
self._send_image(reply.content, receiver, context_token)
|
||||
elif reply.type == ReplyType.FILE:
|
||||
self._send_file(reply.content, receiver, context_token)
|
||||
elif reply.type in (ReplyType.VIDEO, ReplyType.VIDEO_URL):
|
||||
self._send_video(reply.content, receiver, context_token)
|
||||
else:
|
||||
logger.warning(f"[Weixin] Unsupported reply type: {reply.type}, fallback to text")
|
||||
self._send_text(str(reply.content), receiver, context_token)
|
||||
|
||||
def _get_context_token(self, receiver: str, msg=None) -> str:
|
||||
"""Get the context_token for a receiver, required for all sends."""
|
||||
if msg and hasattr(msg, "context_token") and msg.context_token:
|
||||
return msg.context_token
|
||||
return self._context_tokens.get(receiver, "")
|
||||
|
||||
def _send_text(self, text: str, receiver: str, context_token: str):
|
||||
if len(text) <= TEXT_CHUNK_LIMIT:
|
||||
try:
|
||||
self.api.send_text(receiver, text, context_token)
|
||||
logger.debug(f"[Weixin] Text sent to {receiver}, len={len(text)}")
|
||||
except Exception as e:
|
||||
logger.error(f"[Weixin] Failed to send text: {e}")
|
||||
return
|
||||
|
||||
chunks = self._split_text(text, TEXT_CHUNK_LIMIT)
|
||||
for i, chunk in enumerate(chunks):
|
||||
try:
|
||||
self.api.send_text(receiver, chunk, context_token)
|
||||
logger.debug(f"[Weixin] Text chunk {i+1}/{len(chunks)} sent to {receiver}, len={len(chunk)}")
|
||||
except Exception as e:
|
||||
logger.error(f"[Weixin] Failed to send text chunk {i+1}/{len(chunks)}: {e}")
|
||||
break
|
||||
if i < len(chunks) - 1:
|
||||
time.sleep(0.5)
|
||||
|
||||
@staticmethod
|
||||
def _split_text(text: str, limit: int) -> list:
|
||||
"""Split text into chunks, preferring to break at paragraph or line boundaries."""
|
||||
if len(text) <= limit:
|
||||
return [text]
|
||||
chunks = []
|
||||
while text:
|
||||
if len(text) <= limit:
|
||||
chunks.append(text)
|
||||
break
|
||||
cut = text.rfind("\n\n", 0, limit)
|
||||
if cut <= 0:
|
||||
cut = text.rfind("\n", 0, limit)
|
||||
if cut <= 0:
|
||||
cut = limit
|
||||
chunks.append(text[:cut])
|
||||
text = text[cut:].lstrip("\n")
|
||||
return chunks
|
||||
|
||||
def _send_image(self, img_path_or_url: str, receiver: str, context_token: str):
|
||||
local_path = self._resolve_media_path(img_path_or_url)
|
||||
if not local_path:
|
||||
self._send_text("[Image send failed: file not found]", receiver, context_token)
|
||||
return
|
||||
try:
|
||||
result = upload_media_to_cdn(self.api, local_path, receiver, media_type=1)
|
||||
self.api.send_image_item(
|
||||
to=receiver,
|
||||
context_token=context_token,
|
||||
encrypt_query_param=result["encrypt_query_param"],
|
||||
aes_key_b64=result["aes_key_b64"],
|
||||
ciphertext_size=result["ciphertext_size"],
|
||||
)
|
||||
logger.info(f"[Weixin] Image sent to {receiver}")
|
||||
except Exception as e:
|
||||
logger.error(f"[Weixin] Image send failed: {e}")
|
||||
self._send_text("[Image send failed]", receiver, context_token)
|
||||
|
||||
def _send_file(self, file_path_or_url: str, receiver: str, context_token: str):
|
||||
local_path = self._resolve_media_path(file_path_or_url)
|
||||
if not local_path:
|
||||
self._send_text("[File send failed: file not found]", receiver, context_token)
|
||||
return
|
||||
try:
|
||||
result = upload_media_to_cdn(self.api, local_path, receiver, media_type=3)
|
||||
self.api.send_file_item(
|
||||
to=receiver,
|
||||
context_token=context_token,
|
||||
encrypt_query_param=result["encrypt_query_param"],
|
||||
aes_key_b64=result["aes_key_b64"],
|
||||
file_name=os.path.basename(local_path),
|
||||
file_size=result["raw_size"],
|
||||
)
|
||||
logger.info(f"[Weixin] File sent to {receiver}")
|
||||
except Exception as e:
|
||||
logger.error(f"[Weixin] File send failed: {e}")
|
||||
self._send_text("[File send failed]", receiver, context_token)
|
||||
|
||||
def _send_video(self, video_path_or_url: str, receiver: str, context_token: str):
|
||||
local_path = self._resolve_media_path(video_path_or_url)
|
||||
if not local_path:
|
||||
self._send_text("[Video send failed: file not found]", receiver, context_token)
|
||||
return
|
||||
try:
|
||||
result = upload_media_to_cdn(self.api, local_path, receiver, media_type=2)
|
||||
self.api.send_video_item(
|
||||
to=receiver,
|
||||
context_token=context_token,
|
||||
encrypt_query_param=result["encrypt_query_param"],
|
||||
aes_key_b64=result["aes_key_b64"],
|
||||
ciphertext_size=result["ciphertext_size"],
|
||||
)
|
||||
logger.info(f"[Weixin] Video sent to {receiver}")
|
||||
except Exception as e:
|
||||
logger.error(f"[Weixin] Video send failed: {e}")
|
||||
self._send_text("[Video send failed]", receiver, context_token)
|
||||
|
||||
@staticmethod
|
||||
def _resolve_media_path(path_or_url: str) -> str:
|
||||
"""Resolve a file path or URL to a local file path. Downloads if needed."""
|
||||
if not path_or_url:
|
||||
return ""
|
||||
|
||||
local_path = path_or_url
|
||||
if local_path.startswith("file://"):
|
||||
local_path = local_path[7:]
|
||||
|
||||
if local_path.startswith(("http://", "https://")):
|
||||
try:
|
||||
resp = requests.get(local_path, timeout=60)
|
||||
resp.raise_for_status()
|
||||
ct = resp.headers.get("Content-Type", "")
|
||||
ext = ".bin"
|
||||
if "jpeg" in ct or "jpg" in ct:
|
||||
ext = ".jpg"
|
||||
elif "png" in ct:
|
||||
ext = ".png"
|
||||
elif "gif" in ct:
|
||||
ext = ".gif"
|
||||
elif "webp" in ct:
|
||||
ext = ".webp"
|
||||
elif "mp4" in ct:
|
||||
ext = ".mp4"
|
||||
elif "pdf" in ct:
|
||||
ext = ".pdf"
|
||||
|
||||
tmp_path = f"/tmp/wx_media_{uuid.uuid4().hex[:8]}{ext}"
|
||||
with open(tmp_path, "wb") as f:
|
||||
f.write(resp.content)
|
||||
return tmp_path
|
||||
except Exception as e:
|
||||
logger.error(f"[Weixin] Failed to download media: {e}")
|
||||
return ""
|
||||
|
||||
if os.path.exists(local_path):
|
||||
return local_path
|
||||
|
||||
logger.warning(f"[Weixin] Media file not found: {local_path}")
|
||||
return ""
|
||||
204
channel/weixin/weixin_message.py
Normal file
204
channel/weixin/weixin_message.py
Normal file
@@ -0,0 +1,204 @@
|
||||
"""
|
||||
Weixin ChatMessage implementation.
|
||||
|
||||
Parses WeixinMessage from the getUpdates API into the unified ChatMessage format.
|
||||
"""
|
||||
|
||||
import os
|
||||
import uuid
|
||||
|
||||
from bridge.context import ContextType
|
||||
from channel.chat_message import ChatMessage
|
||||
from channel.weixin.weixin_api import download_media_from_cdn, CDN_BASE_URL
|
||||
from common.log import logger
|
||||
from common.utils import expand_path
|
||||
from config import conf
|
||||
|
||||
|
||||
# MessageItemType constants from the Weixin protocol
|
||||
ITEM_TEXT = 1
|
||||
ITEM_IMAGE = 2
|
||||
ITEM_VOICE = 3
|
||||
ITEM_FILE = 4
|
||||
ITEM_VIDEO = 5
|
||||
|
||||
|
||||
def _get_tmp_dir() -> str:
|
||||
ws_root = expand_path(conf().get("agent_workspace", "~/cow"))
|
||||
tmp_dir = os.path.join(ws_root, "tmp")
|
||||
os.makedirs(tmp_dir, exist_ok=True)
|
||||
return tmp_dir
|
||||
|
||||
|
||||
class WeixinMessage(ChatMessage):
|
||||
"""Message wrapper for Weixin channel."""
|
||||
|
||||
def __init__(self, msg: dict, cdn_base_url: str = CDN_BASE_URL):
|
||||
super().__init__(msg)
|
||||
|
||||
self.msg_id = str(msg.get("message_id", msg.get("seq", uuid.uuid4().hex[:8])))
|
||||
self.create_time = msg.get("create_time_ms", 0)
|
||||
self.context_token = msg.get("context_token", "")
|
||||
self.is_group = False # Weixin plugin only supports direct chat
|
||||
self.is_at = False
|
||||
|
||||
from_user_id = msg.get("from_user_id", "")
|
||||
to_user_id = msg.get("to_user_id", "")
|
||||
|
||||
self.from_user_id = from_user_id
|
||||
self.from_user_nickname = from_user_id
|
||||
self.to_user_id = to_user_id
|
||||
self.to_user_nickname = to_user_id
|
||||
self.other_user_id = from_user_id
|
||||
self.other_user_nickname = from_user_id
|
||||
self.actual_user_id = from_user_id
|
||||
self.actual_user_nickname = from_user_id
|
||||
|
||||
item_list = msg.get("item_list", [])
|
||||
|
||||
# Parse items: find text and media
|
||||
text_body = ""
|
||||
media_item = None
|
||||
media_type = None
|
||||
ref_text = ""
|
||||
|
||||
for item in item_list:
|
||||
itype = item.get("type", 0)
|
||||
|
||||
if itype == ITEM_TEXT:
|
||||
text_item = item.get("text_item", {})
|
||||
text_body = text_item.get("text", "")
|
||||
|
||||
ref = item.get("ref_msg")
|
||||
if ref:
|
||||
ref_title = ref.get("title", "")
|
||||
ref_mi = ref.get("message_item", {})
|
||||
ref_body = ""
|
||||
if ref_mi.get("type") == ITEM_TEXT:
|
||||
ref_body = ref_mi.get("text_item", {}).get("text", "")
|
||||
if ref_title or ref_body:
|
||||
parts = [p for p in [ref_title, ref_body] if p]
|
||||
ref_text = f"[引用: {' | '.join(parts)}]\n"
|
||||
# If ref is a media item, treat it as the media to download
|
||||
if ref_mi.get("type") in (ITEM_IMAGE, ITEM_VIDEO, ITEM_FILE):
|
||||
media_item = ref_mi
|
||||
media_type = ref_mi.get("type")
|
||||
|
||||
elif itype == ITEM_VOICE:
|
||||
voice_item = item.get("voice_item", {})
|
||||
voice_text = voice_item.get("text", "")
|
||||
if voice_text:
|
||||
text_body = voice_text
|
||||
else:
|
||||
# Voice without transcription - download the audio
|
||||
media_item = item
|
||||
media_type = ITEM_VOICE
|
||||
|
||||
elif itype in (ITEM_IMAGE, ITEM_VIDEO, ITEM_FILE):
|
||||
if not media_item:
|
||||
media_item = item
|
||||
media_type = itype
|
||||
|
||||
# Determine ctype and content
|
||||
if media_item and not text_body:
|
||||
self._setup_media(media_item, media_type, cdn_base_url)
|
||||
elif media_item and text_body:
|
||||
# Text + media: download media, attach as file ref in text
|
||||
self.ctype = ContextType.TEXT
|
||||
media_path = self._download_media(media_item, media_type, cdn_base_url)
|
||||
if media_path:
|
||||
if media_type == ITEM_IMAGE:
|
||||
text_body += f"\n[图片: {media_path}]"
|
||||
elif media_type == ITEM_VIDEO:
|
||||
text_body += f"\n[视频: {media_path}]"
|
||||
else:
|
||||
text_body += f"\n[文件: {media_path}]"
|
||||
self.content = ref_text + text_body
|
||||
else:
|
||||
self.ctype = ContextType.TEXT
|
||||
self.content = ref_text + text_body
|
||||
|
||||
def _setup_media(self, item: dict, media_type: int, cdn_base_url: str):
|
||||
"""Set up message as a media type, with lazy download via _prepare_fn."""
|
||||
if media_type == ITEM_IMAGE:
|
||||
self.ctype = ContextType.IMAGE
|
||||
image_path = self._download_media(item, ITEM_IMAGE, cdn_base_url)
|
||||
if image_path:
|
||||
self.content = image_path
|
||||
self.image_path = image_path
|
||||
else:
|
||||
self.ctype = ContextType.TEXT
|
||||
self.content = "[Image download failed]"
|
||||
|
||||
elif media_type == ITEM_VIDEO:
|
||||
self.ctype = ContextType.FILE
|
||||
save_path = os.path.join(_get_tmp_dir(), f"wx_{self.msg_id}.mp4")
|
||||
self.content = save_path
|
||||
|
||||
def _download():
|
||||
path = self._download_media(item, ITEM_VIDEO, cdn_base_url)
|
||||
if path:
|
||||
self.content = path
|
||||
self._prepare_fn = _download
|
||||
|
||||
elif media_type == ITEM_FILE:
|
||||
self.ctype = ContextType.FILE
|
||||
file_name = item.get("file_item", {}).get("file_name", f"wx_{self.msg_id}")
|
||||
save_path = os.path.join(_get_tmp_dir(), file_name)
|
||||
self.content = save_path
|
||||
|
||||
def _download():
|
||||
path = self._download_media(item, ITEM_FILE, cdn_base_url)
|
||||
if path:
|
||||
self.content = path
|
||||
self._prepare_fn = _download
|
||||
|
||||
elif media_type == ITEM_VOICE:
|
||||
self.ctype = ContextType.VOICE
|
||||
save_path = os.path.join(_get_tmp_dir(), f"wx_{self.msg_id}.silk")
|
||||
self.content = save_path
|
||||
|
||||
def _download():
|
||||
path = self._download_media(item, ITEM_VOICE, cdn_base_url)
|
||||
if path:
|
||||
self.content = path
|
||||
self._prepare_fn = _download
|
||||
|
||||
def _download_media(self, item: dict, media_type: int, cdn_base_url: str) -> str:
|
||||
"""Download media from CDN, returns local file path or empty string."""
|
||||
type_key_map = {
|
||||
ITEM_IMAGE: "image_item",
|
||||
ITEM_VIDEO: "video_item",
|
||||
ITEM_FILE: "file_item",
|
||||
ITEM_VOICE: "voice_item",
|
||||
}
|
||||
key = type_key_map.get(media_type, "")
|
||||
info = item.get(key, {})
|
||||
media = info.get("media", {})
|
||||
|
||||
encrypt_param = media.get("encrypt_query_param", "")
|
||||
# aes_key can be in image_item.aeskey (hex) or media.aes_key (b64)
|
||||
aes_key = info.get("aeskey", "") or media.get("aes_key", "")
|
||||
|
||||
if not encrypt_param or not aes_key:
|
||||
logger.warning(f"[Weixin] Missing CDN params for media download (type={media_type})")
|
||||
return ""
|
||||
|
||||
if media_type == ITEM_FILE:
|
||||
original_name = info.get("file_name", "")
|
||||
if original_name:
|
||||
save_path = os.path.join(_get_tmp_dir(), original_name)
|
||||
else:
|
||||
save_path = os.path.join(_get_tmp_dir(), f"wx_{self.msg_id}.bin")
|
||||
else:
|
||||
ext_map = {ITEM_IMAGE: ".jpg", ITEM_VIDEO: ".mp4", ITEM_VOICE: ".silk"}
|
||||
ext = ext_map.get(media_type, "")
|
||||
save_path = os.path.join(_get_tmp_dir(), f"wx_{self.msg_id}{ext}")
|
||||
|
||||
try:
|
||||
download_media_from_cdn(cdn_base_url, encrypt_param, aes_key, save_path)
|
||||
logger.info(f"[Weixin] Media downloaded: {save_path}")
|
||||
return save_path
|
||||
except Exception as e:
|
||||
logger.error(f"[Weixin] Media download failed: {e}")
|
||||
return ""
|
||||
@@ -1,17 +0,0 @@
|
||||
import os
|
||||
import time
|
||||
os.environ['ntwork_LOG'] = "ERROR"
|
||||
import ntwork
|
||||
|
||||
wework = ntwork.WeWork()
|
||||
|
||||
|
||||
def forever():
|
||||
try:
|
||||
while True:
|
||||
time.sleep(0.1)
|
||||
except KeyboardInterrupt:
|
||||
ntwork.exit_()
|
||||
os._exit(0)
|
||||
|
||||
|
||||
@@ -1,326 +0,0 @@
|
||||
import io
|
||||
import os
|
||||
import random
|
||||
import tempfile
|
||||
import threading
|
||||
os.environ['ntwork_LOG'] = "ERROR"
|
||||
import ntwork
|
||||
import requests
|
||||
import uuid
|
||||
|
||||
from bridge.context import *
|
||||
from bridge.reply import *
|
||||
from channel.chat_channel import ChatChannel
|
||||
from channel.wework.wework_message import *
|
||||
from channel.wework.wework_message import WeworkMessage
|
||||
from common.singleton import singleton
|
||||
from common.log import logger
|
||||
from common.time_check import time_checker
|
||||
from common.utils import compress_imgfile, fsize
|
||||
from config import conf
|
||||
from channel.wework.run import wework
|
||||
from channel.wework import run
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def get_wxid_by_name(room_members, group_wxid, name):
|
||||
if group_wxid in room_members:
|
||||
for member in room_members[group_wxid]['member_list']:
|
||||
if member['room_nickname'] == name or member['username'] == name:
|
||||
return member['user_id']
|
||||
return None # 如果没有找到对应的group_wxid或name,则返回None
|
||||
|
||||
|
||||
def download_and_compress_image(url, filename, quality=30):
|
||||
# 确定保存图片的目录
|
||||
directory = os.path.join(os.getcwd(), "tmp")
|
||||
# 如果目录不存在,则创建目录
|
||||
if not os.path.exists(directory):
|
||||
os.makedirs(directory)
|
||||
|
||||
# 下载图片
|
||||
pic_res = requests.get(url, stream=True)
|
||||
image_storage = io.BytesIO()
|
||||
for block in pic_res.iter_content(1024):
|
||||
image_storage.write(block)
|
||||
|
||||
# 检查图片大小并可能进行压缩
|
||||
sz = fsize(image_storage)
|
||||
if sz >= 10 * 1024 * 1024: # 如果图片大于 10 MB
|
||||
logger.info("[wework] image too large, ready to compress, sz={}".format(sz))
|
||||
image_storage = compress_imgfile(image_storage, 10 * 1024 * 1024 - 1)
|
||||
logger.info("[wework] image compressed, sz={}".format(fsize(image_storage)))
|
||||
|
||||
# 将内存缓冲区的指针重置到起始位置
|
||||
image_storage.seek(0)
|
||||
|
||||
# 读取并保存图片
|
||||
image = Image.open(image_storage)
|
||||
image_path = os.path.join(directory, f"{filename}.png")
|
||||
image.save(image_path, "png")
|
||||
|
||||
return image_path
|
||||
|
||||
|
||||
def download_video(url, filename):
|
||||
# 确定保存视频的目录
|
||||
directory = os.path.join(os.getcwd(), "tmp")
|
||||
# 如果目录不存在,则创建目录
|
||||
if not os.path.exists(directory):
|
||||
os.makedirs(directory)
|
||||
|
||||
# 下载视频
|
||||
response = requests.get(url, stream=True)
|
||||
total_size = 0
|
||||
|
||||
video_path = os.path.join(directory, f"{filename}.mp4")
|
||||
|
||||
with open(video_path, 'wb') as f:
|
||||
for block in response.iter_content(1024):
|
||||
total_size += len(block)
|
||||
|
||||
# 如果视频的总大小超过30MB (30 * 1024 * 1024 bytes),则停止下载并返回
|
||||
if total_size > 30 * 1024 * 1024:
|
||||
logger.info("[WX] Video is larger than 30MB, skipping...")
|
||||
return None
|
||||
|
||||
f.write(block)
|
||||
|
||||
return video_path
|
||||
|
||||
|
||||
def create_message(wework_instance, message, is_group):
|
||||
logger.debug(f"正在为{'群聊' if is_group else '单聊'}创建 WeworkMessage")
|
||||
cmsg = WeworkMessage(message, wework=wework_instance, is_group=is_group)
|
||||
logger.debug(f"cmsg:{cmsg}")
|
||||
return cmsg
|
||||
|
||||
|
||||
def handle_message(cmsg, is_group):
|
||||
logger.debug(f"准备用 WeworkChannel 处理{'群聊' if is_group else '单聊'}消息")
|
||||
if is_group:
|
||||
WeworkChannel().handle_group(cmsg)
|
||||
else:
|
||||
WeworkChannel().handle_single(cmsg)
|
||||
logger.debug(f"已用 WeworkChannel 处理完{'群聊' if is_group else '单聊'}消息")
|
||||
|
||||
|
||||
def _check(func):
|
||||
def wrapper(self, cmsg: ChatMessage):
|
||||
msgId = cmsg.msg_id
|
||||
create_time = cmsg.create_time # 消息时间戳
|
||||
if create_time is None:
|
||||
return func(self, cmsg)
|
||||
if int(create_time) < int(time.time()) - 60: # 跳过1分钟前的历史消息
|
||||
logger.debug("[WX]history message {} skipped".format(msgId))
|
||||
return
|
||||
return func(self, cmsg)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
@wework.msg_register(
|
||||
[ntwork.MT_RECV_TEXT_MSG, ntwork.MT_RECV_IMAGE_MSG, 11072, ntwork.MT_RECV_LINK_CARD_MSG,ntwork.MT_RECV_FILE_MSG, ntwork.MT_RECV_VOICE_MSG])
|
||||
def all_msg_handler(wework_instance: ntwork.WeWork, message):
|
||||
logger.debug(f"收到消息: {message}")
|
||||
if 'data' in message:
|
||||
# 首先查找conversation_id,如果没有找到,则查找room_conversation_id
|
||||
conversation_id = message['data'].get('conversation_id', message['data'].get('room_conversation_id'))
|
||||
if conversation_id is not None:
|
||||
is_group = "R:" in conversation_id
|
||||
try:
|
||||
cmsg = create_message(wework_instance=wework_instance, message=message, is_group=is_group)
|
||||
except NotImplementedError as e:
|
||||
logger.error(f"[WX]{message.get('MsgId', 'unknown')} 跳过: {e}")
|
||||
return None
|
||||
delay = random.randint(1, 2)
|
||||
timer = threading.Timer(delay, handle_message, args=(cmsg, is_group))
|
||||
timer.start()
|
||||
else:
|
||||
logger.debug("消息数据中无 conversation_id")
|
||||
return None
|
||||
return None
|
||||
|
||||
|
||||
def accept_friend_with_retries(wework_instance, user_id, corp_id):
|
||||
result = wework_instance.accept_friend(user_id, corp_id)
|
||||
logger.debug(f'result:{result}')
|
||||
|
||||
|
||||
# @wework.msg_register(ntwork.MT_RECV_FRIEND_MSG)
|
||||
# def friend(wework_instance: ntwork.WeWork, message):
|
||||
# data = message["data"]
|
||||
# user_id = data["user_id"]
|
||||
# corp_id = data["corp_id"]
|
||||
# logger.info(f"接收到好友请求,消息内容:{data}")
|
||||
# delay = random.randint(1, 180)
|
||||
# threading.Timer(delay, accept_friend_with_retries, args=(wework_instance, user_id, corp_id)).start()
|
||||
#
|
||||
# return None
|
||||
|
||||
|
||||
def get_with_retry(get_func, max_retries=5, delay=5):
|
||||
retries = 0
|
||||
result = None
|
||||
while retries < max_retries:
|
||||
result = get_func()
|
||||
if result:
|
||||
break
|
||||
logger.warning(f"获取数据失败,重试第{retries + 1}次······")
|
||||
retries += 1
|
||||
time.sleep(delay) # 等待一段时间后重试
|
||||
return result
|
||||
|
||||
|
||||
@singleton
|
||||
class WeworkChannel(ChatChannel):
|
||||
NOT_SUPPORT_REPLYTYPE = []
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def startup(self):
|
||||
smart = conf().get("wework_smart", True)
|
||||
wework.open(smart)
|
||||
logger.info("等待登录······")
|
||||
wework.wait_login()
|
||||
login_info = wework.get_login_info()
|
||||
self.user_id = login_info['user_id']
|
||||
self.name = login_info['nickname']
|
||||
logger.info(f"登录信息:>>>user_id:{self.user_id}>>>>>>>>name:{self.name}")
|
||||
logger.info("静默延迟60s,等待客户端刷新数据,请勿进行任何操作······")
|
||||
time.sleep(60)
|
||||
contacts = get_with_retry(wework.get_external_contacts)
|
||||
rooms = get_with_retry(wework.get_rooms)
|
||||
directory = os.path.join(os.getcwd(), "tmp")
|
||||
if not contacts or not rooms:
|
||||
logger.error("获取contacts或rooms失败,程序退出")
|
||||
ntwork.exit_()
|
||||
os.exit(0)
|
||||
if not os.path.exists(directory):
|
||||
os.makedirs(directory)
|
||||
# 将contacts保存到json文件中
|
||||
with open(os.path.join(directory, 'wework_contacts.json'), 'w', encoding='utf-8') as f:
|
||||
json.dump(contacts, f, ensure_ascii=False, indent=4)
|
||||
with open(os.path.join(directory, 'wework_rooms.json'), 'w', encoding='utf-8') as f:
|
||||
json.dump(rooms, f, ensure_ascii=False, indent=4)
|
||||
# 创建一个空字典来保存结果
|
||||
result = {}
|
||||
|
||||
# 遍历列表中的每个字典
|
||||
for room in rooms['room_list']:
|
||||
# 获取聊天室ID
|
||||
room_wxid = room['conversation_id']
|
||||
|
||||
# 获取聊天室成员
|
||||
room_members = wework.get_room_members(room_wxid)
|
||||
|
||||
# 将聊天室成员保存到结果字典中
|
||||
result[room_wxid] = room_members
|
||||
|
||||
# 将结果保存到json文件中
|
||||
with open(os.path.join(directory, 'wework_room_members.json'), 'w', encoding='utf-8') as f:
|
||||
json.dump(result, f, ensure_ascii=False, indent=4)
|
||||
logger.info("wework程序初始化完成········")
|
||||
run.forever()
|
||||
|
||||
@time_checker
|
||||
@_check
|
||||
def handle_single(self, cmsg: ChatMessage):
|
||||
if cmsg.from_user_id == cmsg.to_user_id:
|
||||
# ignore self reply
|
||||
return
|
||||
if cmsg.ctype == ContextType.VOICE:
|
||||
if not conf().get("speech_recognition"):
|
||||
return
|
||||
logger.debug("[WX]receive voice msg: {}".format(cmsg.content))
|
||||
elif cmsg.ctype == ContextType.IMAGE:
|
||||
logger.debug("[WX]receive image msg: {}".format(cmsg.content))
|
||||
elif cmsg.ctype == ContextType.PATPAT:
|
||||
logger.debug("[WX]receive patpat msg: {}".format(cmsg.content))
|
||||
elif cmsg.ctype == ContextType.TEXT:
|
||||
logger.debug("[WX]receive text msg: {}, cmsg={}".format(json.dumps(cmsg._rawmsg, ensure_ascii=False), cmsg))
|
||||
else:
|
||||
logger.debug("[WX]receive msg: {}, cmsg={}".format(cmsg.content, cmsg))
|
||||
context = self._compose_context(cmsg.ctype, cmsg.content, isgroup=False, msg=cmsg)
|
||||
if context:
|
||||
self.produce(context)
|
||||
|
||||
@time_checker
|
||||
@_check
|
||||
def handle_group(self, cmsg: ChatMessage):
|
||||
if cmsg.ctype == ContextType.VOICE:
|
||||
if not conf().get("speech_recognition"):
|
||||
return
|
||||
logger.debug("[WX]receive voice for group msg: {}".format(cmsg.content))
|
||||
elif cmsg.ctype == ContextType.IMAGE:
|
||||
logger.debug("[WX]receive image for group msg: {}".format(cmsg.content))
|
||||
elif cmsg.ctype in [ContextType.JOIN_GROUP, ContextType.PATPAT]:
|
||||
logger.debug("[WX]receive note msg: {}".format(cmsg.content))
|
||||
elif cmsg.ctype == ContextType.TEXT:
|
||||
pass
|
||||
else:
|
||||
logger.debug("[WX]receive group msg: {}".format(cmsg.content))
|
||||
context = self._compose_context(cmsg.ctype, cmsg.content, isgroup=True, msg=cmsg)
|
||||
if context:
|
||||
self.produce(context)
|
||||
|
||||
# 统一的发送函数,每个Channel自行实现,根据reply的type字段发送不同类型的消息
|
||||
def send(self, reply: Reply, context: Context):
|
||||
logger.debug(f"context: {context}")
|
||||
receiver = context["receiver"]
|
||||
actual_user_id = context["msg"].actual_user_id
|
||||
if reply.type == ReplyType.TEXT or reply.type == ReplyType.TEXT_:
|
||||
match = re.search(r"^@(.*?)\n", reply.content)
|
||||
logger.debug(f"match: {match}")
|
||||
if match:
|
||||
new_content = re.sub(r"^@(.*?)\n", "\n", reply.content)
|
||||
at_list = [actual_user_id]
|
||||
logger.debug(f"new_content: {new_content}")
|
||||
wework.send_room_at_msg(receiver, new_content, at_list)
|
||||
else:
|
||||
wework.send_text(receiver, reply.content)
|
||||
logger.info("[WX] sendMsg={}, receiver={}".format(reply, receiver))
|
||||
elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO:
|
||||
wework.send_text(receiver, reply.content)
|
||||
logger.info("[WX] sendMsg={}, receiver={}".format(reply, receiver))
|
||||
elif reply.type == ReplyType.IMAGE: # 从文件读取图片
|
||||
image_storage = reply.content
|
||||
image_storage.seek(0)
|
||||
# Read data from image_storage
|
||||
data = image_storage.read()
|
||||
# Create a temporary file
|
||||
with tempfile.NamedTemporaryFile(delete=False) as temp:
|
||||
temp_path = temp.name
|
||||
temp.write(data)
|
||||
# Send the image
|
||||
wework.send_image(receiver, temp_path)
|
||||
logger.info("[WX] sendImage, receiver={}".format(receiver))
|
||||
# Remove the temporary file
|
||||
os.remove(temp_path)
|
||||
elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片
|
||||
img_url = reply.content
|
||||
filename = str(uuid.uuid4())
|
||||
|
||||
# 调用你的函数,下载图片并保存为本地文件
|
||||
image_path = download_and_compress_image(img_url, filename)
|
||||
|
||||
wework.send_image(receiver, file_path=image_path)
|
||||
logger.info("[WX] sendImage url={}, receiver={}".format(img_url, receiver))
|
||||
elif reply.type == ReplyType.VIDEO_URL:
|
||||
video_url = reply.content
|
||||
filename = str(uuid.uuid4())
|
||||
video_path = download_video(video_url, filename)
|
||||
|
||||
if video_path is None:
|
||||
# 如果视频太大,下载可能会被跳过,此时 video_path 将为 None
|
||||
wework.send_text(receiver, "抱歉,视频太大了!!!")
|
||||
else:
|
||||
wework.send_video(receiver, video_path)
|
||||
logger.info("[WX] sendVideo, receiver={}".format(receiver))
|
||||
elif reply.type == ReplyType.VOICE:
|
||||
current_dir = os.getcwd()
|
||||
voice_file = reply.content.split("/")[-1]
|
||||
reply.content = os.path.join(current_dir, "tmp", voice_file)
|
||||
wework.send_file(receiver, reply.content)
|
||||
logger.info("[WX] sendFile={}, receiver={}".format(reply.content, receiver))
|
||||
@@ -1,227 +0,0 @@
|
||||
import datetime
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
import pilk
|
||||
|
||||
from bridge.context import ContextType
|
||||
from channel.chat_message import ChatMessage
|
||||
from common.log import logger
|
||||
from ntwork.const import send_type
|
||||
|
||||
|
||||
def get_with_retry(get_func, max_retries=5, delay=5):
|
||||
retries = 0
|
||||
result = None
|
||||
while retries < max_retries:
|
||||
result = get_func()
|
||||
if result:
|
||||
break
|
||||
logger.warning(f"获取数据失败,重试第{retries + 1}次······")
|
||||
retries += 1
|
||||
time.sleep(delay) # 等待一段时间后重试
|
||||
return result
|
||||
|
||||
|
||||
def get_room_info(wework, conversation_id):
|
||||
logger.debug(f"传入的 conversation_id: {conversation_id}")
|
||||
rooms = wework.get_rooms()
|
||||
if not rooms or 'room_list' not in rooms:
|
||||
logger.error(f"获取群聊信息失败: {rooms}")
|
||||
return None
|
||||
time.sleep(1)
|
||||
logger.debug(f"获取到的群聊信息: {rooms}")
|
||||
for room in rooms['room_list']:
|
||||
if room['conversation_id'] == conversation_id:
|
||||
return room
|
||||
return None
|
||||
|
||||
|
||||
def cdn_download(wework, message, file_name):
|
||||
data = message["data"]
|
||||
aes_key = data["cdn"]["aes_key"]
|
||||
file_size = data["cdn"]["size"]
|
||||
|
||||
# 获取当前工作目录,然后与文件名拼接得到保存路径
|
||||
current_dir = os.getcwd()
|
||||
save_path = os.path.join(current_dir, "tmp", file_name)
|
||||
|
||||
# 下载保存图片到本地
|
||||
if "url" in data["cdn"].keys() and "auth_key" in data["cdn"].keys():
|
||||
url = data["cdn"]["url"]
|
||||
auth_key = data["cdn"]["auth_key"]
|
||||
# result = wework.wx_cdn_download(url, auth_key, aes_key, file_size, save_path) # ntwork库本身接口有问题,缺失了aes_key这个参数
|
||||
"""
|
||||
下载wx类型的cdn文件,以https开头
|
||||
"""
|
||||
data = {
|
||||
'url': url,
|
||||
'auth_key': auth_key,
|
||||
'aes_key': aes_key,
|
||||
'size': file_size,
|
||||
'save_path': save_path
|
||||
}
|
||||
result = wework._WeWork__send_sync(send_type.MT_WXCDN_DOWNLOAD_MSG, data) # 直接用wx_cdn_download的接口内部实现来调用
|
||||
elif "file_id" in data["cdn"].keys():
|
||||
if message["type"] == 11042:
|
||||
file_type = 2
|
||||
elif message["type"] == 11045:
|
||||
file_type = 5
|
||||
file_id = data["cdn"]["file_id"]
|
||||
result = wework.c2c_cdn_download(file_id, aes_key, file_size, file_type, save_path)
|
||||
else:
|
||||
logger.error(f"something is wrong, data: {data}")
|
||||
return
|
||||
|
||||
# 输出下载结果
|
||||
logger.debug(f"result: {result}")
|
||||
|
||||
|
||||
def c2c_download_and_convert(wework, message, file_name):
|
||||
data = message["data"]
|
||||
aes_key = data["cdn"]["aes_key"]
|
||||
file_size = data["cdn"]["size"]
|
||||
file_type = 5
|
||||
file_id = data["cdn"]["file_id"]
|
||||
|
||||
current_dir = os.getcwd()
|
||||
save_path = os.path.join(current_dir, "tmp", file_name)
|
||||
result = wework.c2c_cdn_download(file_id, aes_key, file_size, file_type, save_path)
|
||||
logger.debug(result)
|
||||
|
||||
# 在下载完SILK文件之后,立即将其转换为WAV文件
|
||||
base_name, _ = os.path.splitext(save_path)
|
||||
wav_file = base_name + ".wav"
|
||||
pilk.silk_to_wav(save_path, wav_file, rate=24000)
|
||||
|
||||
# 删除SILK文件
|
||||
try:
|
||||
os.remove(save_path)
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
|
||||
class WeworkMessage(ChatMessage):
|
||||
def __init__(self, wework_msg, wework, is_group=False):
|
||||
try:
|
||||
super().__init__(wework_msg)
|
||||
self.msg_id = wework_msg['data'].get('conversation_id', wework_msg['data'].get('room_conversation_id'))
|
||||
# 使用.get()防止 'send_time' 键不存在时抛出错误
|
||||
self.create_time = wework_msg['data'].get("send_time")
|
||||
self.is_group = is_group
|
||||
self.wework = wework
|
||||
|
||||
if wework_msg["type"] == 11041: # 文本消息类型
|
||||
if any(substring in wework_msg['data']['content'] for substring in ("该消息类型暂不能展示", "不支持的消息类型")):
|
||||
return
|
||||
self.ctype = ContextType.TEXT
|
||||
self.content = wework_msg['data']['content']
|
||||
elif wework_msg["type"] == 11044: # 语音消息类型,需要缓存文件
|
||||
file_name = datetime.datetime.now().strftime('%Y%m%d%H%M%S') + ".silk"
|
||||
base_name, _ = os.path.splitext(file_name)
|
||||
file_name_2 = base_name + ".wav"
|
||||
current_dir = os.getcwd()
|
||||
self.ctype = ContextType.VOICE
|
||||
self.content = os.path.join(current_dir, "tmp", file_name_2)
|
||||
self._prepare_fn = lambda: c2c_download_and_convert(wework, wework_msg, file_name)
|
||||
elif wework_msg["type"] == 11042: # 图片消息类型,需要下载文件
|
||||
file_name = datetime.datetime.now().strftime('%Y%m%d%H%M%S') + ".jpg"
|
||||
current_dir = os.getcwd()
|
||||
self.ctype = ContextType.IMAGE
|
||||
self.content = os.path.join(current_dir, "tmp", file_name)
|
||||
self._prepare_fn = lambda: cdn_download(wework, wework_msg, file_name)
|
||||
elif wework_msg["type"] == 11045: # 文件消息
|
||||
print("文件消息")
|
||||
print(wework_msg)
|
||||
file_name = datetime.datetime.now().strftime('%Y%m%d%H%M%S')
|
||||
file_name = file_name + wework_msg['data']['cdn']['file_name']
|
||||
current_dir = os.getcwd()
|
||||
self.ctype = ContextType.FILE
|
||||
self.content = os.path.join(current_dir, "tmp", file_name)
|
||||
self._prepare_fn = lambda: cdn_download(wework, wework_msg, file_name)
|
||||
elif wework_msg["type"] == 11047: # 链接消息
|
||||
self.ctype = ContextType.SHARING
|
||||
self.content = wework_msg['data']['url']
|
||||
elif wework_msg["type"] == 11072: # 新成员入群通知
|
||||
self.ctype = ContextType.JOIN_GROUP
|
||||
member_list = wework_msg['data']['member_list']
|
||||
self.actual_user_nickname = member_list[0]['name']
|
||||
self.actual_user_id = member_list[0]['user_id']
|
||||
self.content = f"{self.actual_user_nickname}加入了群聊!"
|
||||
directory = os.path.join(os.getcwd(), "tmp")
|
||||
rooms = get_with_retry(wework.get_rooms)
|
||||
if not rooms:
|
||||
logger.error("更新群信息失败···")
|
||||
else:
|
||||
result = {}
|
||||
for room in rooms['room_list']:
|
||||
# 获取聊天室ID
|
||||
room_wxid = room['conversation_id']
|
||||
|
||||
# 获取聊天室成员
|
||||
room_members = wework.get_room_members(room_wxid)
|
||||
|
||||
# 将聊天室成员保存到结果字典中
|
||||
result[room_wxid] = room_members
|
||||
with open(os.path.join(directory, 'wework_room_members.json'), 'w', encoding='utf-8') as f:
|
||||
json.dump(result, f, ensure_ascii=False, indent=4)
|
||||
logger.info("有新成员加入,已自动更新群成员列表缓存!")
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"Unsupported message type: Type:{} MsgType:{}".format(wework_msg["type"], wework_msg["MsgType"]))
|
||||
|
||||
data = wework_msg['data']
|
||||
login_info = self.wework.get_login_info()
|
||||
logger.debug(f"login_info: {login_info}")
|
||||
nickname = f"{login_info['username']}({login_info['nickname']})" if login_info['nickname'] else login_info['username']
|
||||
user_id = login_info['user_id']
|
||||
|
||||
sender_id = data.get('sender')
|
||||
conversation_id = data.get('conversation_id')
|
||||
sender_name = data.get("sender_name")
|
||||
|
||||
self.from_user_id = user_id if sender_id == user_id else conversation_id
|
||||
self.from_user_nickname = nickname if sender_id == user_id else sender_name
|
||||
self.to_user_id = user_id
|
||||
self.to_user_nickname = nickname
|
||||
self.other_user_nickname = sender_name
|
||||
self.other_user_id = conversation_id
|
||||
|
||||
if self.is_group:
|
||||
conversation_id = data.get('conversation_id') or data.get('room_conversation_id')
|
||||
self.other_user_id = conversation_id
|
||||
if conversation_id:
|
||||
room_info = get_room_info(wework=wework, conversation_id=conversation_id)
|
||||
self.other_user_nickname = room_info.get('nickname', None) if room_info else None
|
||||
self.from_user_nickname = room_info.get('nickname', None) if room_info else None
|
||||
at_list = data.get('at_list', [])
|
||||
tmp_list = []
|
||||
for at in at_list:
|
||||
tmp_list.append(at['nickname'])
|
||||
at_list = tmp_list
|
||||
logger.debug(f"at_list: {at_list}")
|
||||
logger.debug(f"nickname: {nickname}")
|
||||
self.is_at = False
|
||||
if nickname in at_list or login_info['nickname'] in at_list or login_info['username'] in at_list:
|
||||
self.is_at = True
|
||||
self.at_list = at_list
|
||||
|
||||
# 检查消息内容是否包含@用户名。处理复制粘贴的消息,这类消息可能不会触发@通知,但内容中可能包含 "@用户名"。
|
||||
content = data.get('content', '')
|
||||
name = nickname
|
||||
pattern = f"@{re.escape(name)}(\u2005|\u0020)"
|
||||
if re.search(pattern, content):
|
||||
logger.debug(f"Wechaty message {self.msg_id} includes at")
|
||||
self.is_at = True
|
||||
|
||||
if not self.actual_user_id:
|
||||
self.actual_user_id = data.get("sender")
|
||||
self.actual_user_nickname = sender_name if self.ctype != ContextType.JOIN_GROUP else self.actual_user_nickname
|
||||
else:
|
||||
logger.error("群聊消息中没有找到 conversation_id 或 room_conversation_id")
|
||||
|
||||
logger.debug(f"WeworkMessage has been successfully instantiated with message id: {self.msg_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"在 WeworkMessage 的初始化过程中出现错误:{e}")
|
||||
raise e
|
||||
1
cli/VERSION
Normal file
1
cli/VERSION
Normal file
@@ -0,0 +1 @@
|
||||
2.0.6
|
||||
13
cli/__init__.py
Normal file
13
cli/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
"""CowAgent CLI - Manage your CowAgent from the command line."""
|
||||
|
||||
import os as _os
|
||||
|
||||
def _read_version():
|
||||
version_file = _os.path.join(_os.path.dirname(_os.path.abspath(__file__)), "VERSION")
|
||||
try:
|
||||
with open(version_file, "r") as f:
|
||||
return f.read().strip()
|
||||
except FileNotFoundError:
|
||||
return "0.0.0"
|
||||
|
||||
__version__ = _read_version()
|
||||
4
cli/__main__.py
Normal file
4
cli/__main__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
"""Allow running as: python -m cli"""
|
||||
from cli.cli import main
|
||||
|
||||
main()
|
||||
79
cli/cli.py
Normal file
79
cli/cli.py
Normal file
@@ -0,0 +1,79 @@
|
||||
"""CowAgent CLI entry point."""
|
||||
|
||||
import click
|
||||
from cli import __version__
|
||||
from cli.commands.skill import skill
|
||||
from cli.commands.process import start, stop, restart, update, status, logs
|
||||
from cli.commands.context import context
|
||||
from cli.commands.install import install_browser
|
||||
from cli.commands.knowledge import knowledge
|
||||
|
||||
|
||||
HELP_TEXT = """Usage: cow COMMAND [ARGS]...
|
||||
|
||||
CowAgent CLI - Manage your CowAgent instance.
|
||||
|
||||
Commands:
|
||||
help Show this message.
|
||||
version Show the version.
|
||||
start Start CowAgent.
|
||||
stop Stop CowAgent.
|
||||
restart Restart CowAgent.
|
||||
update Update CowAgent and restart.
|
||||
status Show CowAgent running status.
|
||||
logs View CowAgent logs.
|
||||
skill Manage CowAgent skills.
|
||||
knowledge Manage knowledge base.
|
||||
install-browser Install browser tool (Playwright + Chromium).
|
||||
|
||||
Tip: You can also send /help, /skill list, etc. in agent chat."""
|
||||
|
||||
|
||||
class CowCLI(click.Group):
|
||||
|
||||
def format_help(self, ctx, formatter):
|
||||
formatter.write(HELP_TEXT.strip())
|
||||
formatter.write("\n")
|
||||
|
||||
def parse_args(self, ctx, args):
|
||||
if args and args[0] == 'help':
|
||||
click.echo(HELP_TEXT.strip())
|
||||
ctx.exit(0)
|
||||
return super().parse_args(ctx, args)
|
||||
|
||||
|
||||
@click.group(cls=CowCLI, invoke_without_command=True, context_settings=dict(help_option_names=[]))
|
||||
@click.pass_context
|
||||
def main(ctx):
|
||||
"""CowAgent CLI - Manage your CowAgent instance."""
|
||||
if ctx.invoked_subcommand is None:
|
||||
click.echo(HELP_TEXT.strip())
|
||||
|
||||
|
||||
@main.command()
|
||||
def version():
|
||||
"""Show the version."""
|
||||
click.echo(f"cow {__version__}")
|
||||
|
||||
|
||||
@main.command(name='help')
|
||||
@click.pass_context
|
||||
def help_cmd(ctx):
|
||||
"""Show this message."""
|
||||
click.echo(HELP_TEXT.strip())
|
||||
|
||||
|
||||
main.add_command(skill)
|
||||
main.add_command(start)
|
||||
main.add_command(stop)
|
||||
main.add_command(restart)
|
||||
main.add_command(update)
|
||||
main.add_command(status)
|
||||
main.add_command(logs)
|
||||
main.add_command(context)
|
||||
main.add_command(knowledge)
|
||||
main.add_command(install_browser)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
0
cli/commands/__init__.py
Normal file
0
cli/commands/__init__.py
Normal file
29
cli/commands/context.py
Normal file
29
cli/commands/context.py
Normal file
@@ -0,0 +1,29 @@
|
||||
"""cow context - Context management commands."""
|
||||
|
||||
import click
|
||||
|
||||
|
||||
CHAT_HINT = (
|
||||
"Context commands operate on the running agent's memory.\n"
|
||||
"Please send the command in a chat conversation instead:\n\n"
|
||||
" /context - View current context info\n"
|
||||
" /context clear - Clear conversation context"
|
||||
)
|
||||
|
||||
|
||||
@click.group(invoke_without_command=True)
|
||||
@click.pass_context
|
||||
def context(ctx):
|
||||
"""View or manage conversation context.
|
||||
|
||||
Context commands need access to the running agent's memory.
|
||||
Use them in chat conversations: /context or /context clear
|
||||
"""
|
||||
if ctx.invoked_subcommand is None:
|
||||
click.echo(f"\n {CHAT_HINT}\n")
|
||||
|
||||
|
||||
@context.command()
|
||||
def clear():
|
||||
"""Clear conversation context (messages history)."""
|
||||
click.echo(f"\n {CHAT_HINT}\n")
|
||||
259
cli/commands/install.py
Normal file
259
cli/commands/install.py
Normal file
@@ -0,0 +1,259 @@
|
||||
"""cow install-browser - Install Playwright + Chromium for the browser tool."""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import subprocess
|
||||
from typing import Callable, Optional
|
||||
|
||||
import click
|
||||
|
||||
PLAYWRIGHT_VERSION = "1.52.0"
|
||||
PLAYWRIGHT_LEGACY_VERSION = "1.28.0"
|
||||
GLIBC_THRESHOLD = (2, 28)
|
||||
CHINA_MIRROR = "https://registry.npmmirror.com/-/binary/playwright"
|
||||
|
||||
# stream(msg, fg=None) — fg is "yellow" | "green" | "red" | None
|
||||
StreamFn = Callable[[str, Optional[str]], None]
|
||||
# on_phase(msg) — coarse-grained progress for chat channels (Chinese)
|
||||
PhaseFn = Callable[[str], None]
|
||||
|
||||
|
||||
def _phase(cb: Optional[PhaseFn], msg: str) -> None:
|
||||
if cb:
|
||||
cb(msg)
|
||||
|
||||
|
||||
def _has_display() -> bool:
|
||||
"""Check if a graphical display is available (Linux only)."""
|
||||
return bool(os.environ.get("DISPLAY") or os.environ.get("WAYLAND_DISPLAY"))
|
||||
|
||||
|
||||
def _is_headless_linux() -> bool:
|
||||
return sys.platform == "linux" and not _has_display()
|
||||
|
||||
|
||||
def _get_installed_version() -> str:
|
||||
try:
|
||||
out = subprocess.check_output(
|
||||
[sys.executable, "-c", "import playwright; print(playwright.__version__)"],
|
||||
stderr=subprocess.DEVNULL,
|
||||
)
|
||||
return out.decode().strip()
|
||||
except Exception:
|
||||
return ""
|
||||
|
||||
|
||||
def _version_tuple(v: str):
|
||||
try:
|
||||
return tuple(int(x) for x in v.split(".")[:3])
|
||||
except (ValueError, AttributeError):
|
||||
return (0, 0, 0)
|
||||
|
||||
|
||||
def _get_glibc_version():
|
||||
if sys.platform != "linux":
|
||||
return None
|
||||
try:
|
||||
import ctypes
|
||||
libc = ctypes.CDLL("libc.so.6")
|
||||
gnu_get_libc_version = libc.gnu_get_libc_version
|
||||
gnu_get_libc_version.restype = ctypes.c_char_p
|
||||
ver = gnu_get_libc_version().decode()
|
||||
parts = ver.split(".")
|
||||
return (int(parts[0]), int(parts[1]))
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def _is_china_network() -> bool:
|
||||
try:
|
||||
out = subprocess.check_output(
|
||||
[sys.executable, "-m", "pip", "config", "get", "global.index-url"],
|
||||
stderr=subprocess.DEVNULL,
|
||||
)
|
||||
url = out.decode().strip().lower()
|
||||
return any(kw in url for kw in ("tsinghua", "aliyun", "npmmirror", "douban", "ustc", "huawei", "tencentyun"))
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def _pip_install(package_spec: str, stream: StreamFn) -> int:
|
||||
"""Install a package, retrying with --user on permission failure."""
|
||||
python = sys.executable
|
||||
ret = subprocess.call([python, "-m", "pip", "install", package_spec])
|
||||
if ret != 0:
|
||||
stream(" Retrying with --user flag...", "yellow")
|
||||
ret = subprocess.call([python, "-m", "pip", "install", "--user", package_spec])
|
||||
return ret
|
||||
|
||||
|
||||
def _default_stream(msg: str, fg: Optional[str] = None) -> None:
|
||||
"""CLI: colored click output."""
|
||||
if fg == "yellow":
|
||||
click.echo(click.style(msg, fg="yellow"))
|
||||
elif fg == "green":
|
||||
click.echo(click.style(msg, fg="green"))
|
||||
elif fg == "red":
|
||||
click.echo(click.style(msg, fg="red"))
|
||||
else:
|
||||
click.echo(msg)
|
||||
|
||||
|
||||
def run_install_browser(
|
||||
stream: Optional[StreamFn] = None,
|
||||
on_phase: Optional[PhaseFn] = None,
|
||||
) -> int:
|
||||
"""
|
||||
Install Playwright Python package, optional Linux deps, and Chromium.
|
||||
|
||||
Reused by ``cow install-browser`` CLI and chat ``/install-browser``.
|
||||
|
||||
Args:
|
||||
stream: Optional callback ``(message, fg)`` for each line. ``fg`` is
|
||||
``yellow`` / ``green`` / ``red`` or None. Defaults to colored click output.
|
||||
on_phase: Optional callback for coarse progress (e.g. push to chat);
|
||||
messages are short Chinese status lines.
|
||||
|
||||
Returns:
|
||||
0 on success, 1 on fatal failure (pip or chromium install failed).
|
||||
"""
|
||||
stream = stream or _default_stream
|
||||
python = sys.executable
|
||||
legacy_mode = False
|
||||
|
||||
_phase(on_phase, "🔧 开始安装浏览器工具依赖(约几分钟,请耐心等待)…")
|
||||
|
||||
glibc = _get_glibc_version()
|
||||
if glibc and glibc < GLIBC_THRESHOLD:
|
||||
legacy_mode = True
|
||||
glibc_str = f"{glibc[0]}.{glibc[1]}"
|
||||
stream(
|
||||
f"glibc {glibc_str} detected (< 2.28). "
|
||||
f"Will install playwright {PLAYWRIGHT_LEGACY_VERSION} for compatibility.",
|
||||
"yellow",
|
||||
)
|
||||
stream(" Note: upgrade your OS for full browser tool support.", "yellow")
|
||||
stream("")
|
||||
_phase(
|
||||
on_phase,
|
||||
f"ℹ️ 检测到 glibc {glibc_str}(较旧),将安装兼容版 Playwright {PLAYWRIGHT_LEGACY_VERSION}。",
|
||||
)
|
||||
|
||||
target_version = PLAYWRIGHT_LEGACY_VERSION if legacy_mode else PLAYWRIGHT_VERSION
|
||||
|
||||
_phase(on_phase, "📦 [1/3] 正在安装 Playwright Python 包…")
|
||||
stream("[1/3] Installing playwright Python package...", "yellow")
|
||||
ret = _pip_install(f"playwright=={target_version}", stream)
|
||||
if ret != 0:
|
||||
stream("Failed to install playwright package.", "red")
|
||||
_phase(on_phase, "❌ [1/3] Playwright Python 包安装失败。")
|
||||
return 1
|
||||
|
||||
installed = _get_installed_version()
|
||||
if installed:
|
||||
stream(f" playwright {installed} installed.", "green")
|
||||
stream("")
|
||||
_phase(on_phase, f"✅ [1/3] Playwright 包已安装({installed or target_version})。")
|
||||
|
||||
if sys.platform == "linux":
|
||||
_phase(on_phase, "🔧 [2/3] 正在安装 Linux 系统依赖与轻量中文字体(文泉驿正黑,部分步骤可能需要 sudo)…")
|
||||
stream("[2/3] Installing system dependencies (Linux)...", "yellow")
|
||||
ret = subprocess.call([python, "-m", "playwright", "install-deps", "chromium"])
|
||||
if ret != 0:
|
||||
stream(
|
||||
" Could not auto-install system deps (may need sudo).\n"
|
||||
f" Run manually: sudo {python} -m playwright install-deps chromium",
|
||||
"yellow",
|
||||
)
|
||||
# Prefer fonts-wqy-zenhei only (~few MB). fonts-noto-cjk is much larger (~150MB+).
|
||||
stream(" Installing CJK font (fonts-wqy-zenhei, lightweight)...")
|
||||
font_ret = subprocess.call(
|
||||
["sudo", "apt-get", "install", "-y", "--no-install-recommends", "fonts-wqy-zenhei"],
|
||||
stderr=subprocess.DEVNULL,
|
||||
)
|
||||
if font_ret != 0:
|
||||
stream(
|
||||
" Could not auto-install CJK font.\n"
|
||||
" Run manually: sudo apt-get install -y fonts-wqy-zenhei\n"
|
||||
" (Optional, larger full coverage: sudo apt-get install -y fonts-noto-cjk)",
|
||||
"yellow",
|
||||
)
|
||||
else:
|
||||
subprocess.call(["fc-cache", "-fv"], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
|
||||
stream(" CJK font (wqy-zenhei) installed.", "green")
|
||||
_phase(
|
||||
on_phase,
|
||||
"✅ [2/3] Linux 依赖与字体步骤已执行(若有权限问题请查看服务器日志或手动执行提示命令)。",
|
||||
)
|
||||
else:
|
||||
stream(f"[2/3] Skipping system deps (not needed on {sys.platform}).", "yellow")
|
||||
_phase(on_phase, f"ℹ️ [2/3] 当前系统({sys.platform})跳过 Linux 专用依赖。")
|
||||
stream("")
|
||||
|
||||
_phase(on_phase, "🌐 [3/3] 正在下载并安装 Chromium(体积较大,请耐心等待)…")
|
||||
stream("[3/3] Installing Chromium browser...", "yellow")
|
||||
cmd = [python, "-m", "playwright", "install", "chromium"]
|
||||
|
||||
if _is_headless_linux() and not legacy_mode:
|
||||
ver = _version_tuple(installed or "")
|
||||
if ver >= (1, 57, 0):
|
||||
cmd.append("--only-shell")
|
||||
stream(" (headless shell for Linux server)", None)
|
||||
else:
|
||||
stream(" (full Chromium)", None)
|
||||
elif sys.platform == "linux" and _has_display():
|
||||
stream(" (full browser for Linux desktop)", None)
|
||||
|
||||
env = os.environ.copy()
|
||||
use_mirror = _is_china_network()
|
||||
if use_mirror:
|
||||
env["PLAYWRIGHT_DOWNLOAD_HOST"] = CHINA_MIRROR
|
||||
stream(f" (using China mirror: {CHINA_MIRROR})", None)
|
||||
_phase(on_phase, "📡 检测到国内 pip 源配置,Chromium 将优先走国内镜像下载。")
|
||||
|
||||
ret = subprocess.call(cmd, env=env)
|
||||
|
||||
if ret != 0 and use_mirror:
|
||||
stream(" Mirror download failed, retrying with official CDN...", "yellow")
|
||||
_phase(on_phase, "⚠️ 镜像下载失败,正在改用官方源重试…")
|
||||
env_no_mirror = os.environ.copy()
|
||||
env_no_mirror.pop("PLAYWRIGHT_DOWNLOAD_HOST", None)
|
||||
ret = subprocess.call(cmd, env=env_no_mirror)
|
||||
|
||||
if ret != 0:
|
||||
stream("Failed to install Chromium.", "red")
|
||||
_phase(on_phase, "❌ [3/3] Chromium 安装失败。")
|
||||
return 1
|
||||
|
||||
stream("")
|
||||
_phase(on_phase, "✅ [3/3] Chromium 已安装。")
|
||||
|
||||
stream("Verifying browser installation...", None)
|
||||
_phase(on_phase, "🔍 正在验证 Playwright 能否正常加载…")
|
||||
ret = subprocess.call(
|
||||
[python, "-c", "from playwright.sync_api import sync_playwright; print('OK')"],
|
||||
stderr=subprocess.DEVNULL,
|
||||
)
|
||||
if ret != 0:
|
||||
stream(
|
||||
" Warning: playwright import failed. Browser tool may not work on this system.\n"
|
||||
" Consider upgrading your OS or using Docker.",
|
||||
"yellow",
|
||||
)
|
||||
_phase(on_phase, "⚠️ 验证未完全通过:本机可能仍无法使用浏览器工具,请查看日志或升级系统。")
|
||||
else:
|
||||
stream(" Verification passed.", "green")
|
||||
_phase(on_phase, "✅ 验证通过。")
|
||||
|
||||
stream("")
|
||||
stream("Browser tool ready! Restart CowAgent to enable it.", "green")
|
||||
_phase(on_phase, "🎉 全部步骤结束。请重启 CowAgent 后使用 browser 工具。")
|
||||
return 0
|
||||
|
||||
|
||||
@click.command("install-browser")
|
||||
def install_browser():
|
||||
"""Install browser tool dependencies (Playwright + Chromium)."""
|
||||
code = run_install_browser()
|
||||
if code != 0:
|
||||
raise SystemExit(code)
|
||||
121
cli/commands/knowledge.py
Normal file
121
cli/commands/knowledge.py
Normal file
@@ -0,0 +1,121 @@
|
||||
"""cow knowledge - Knowledge base management commands."""
|
||||
|
||||
import os
|
||||
|
||||
import click
|
||||
|
||||
from cli.utils import get_project_root
|
||||
|
||||
|
||||
def _get_knowledge_dir():
|
||||
"""Resolve the knowledge directory path from config or default."""
|
||||
try:
|
||||
import sys
|
||||
sys.path.insert(0, get_project_root())
|
||||
from config import conf
|
||||
from common.utils import expand_path
|
||||
workspace = expand_path(conf().get("agent_workspace", "~/cow"))
|
||||
except Exception:
|
||||
workspace = os.path.expanduser("~/cow")
|
||||
return os.path.join(workspace, "knowledge")
|
||||
|
||||
|
||||
def _get_knowledge_enabled():
|
||||
try:
|
||||
import sys
|
||||
sys.path.insert(0, get_project_root())
|
||||
from config import conf
|
||||
return conf().get("knowledge", True)
|
||||
except Exception:
|
||||
return True
|
||||
|
||||
|
||||
@click.group(invoke_without_command=True)
|
||||
@click.pass_context
|
||||
def knowledge(ctx):
|
||||
"""Manage CowAgent knowledge base."""
|
||||
if ctx.invoked_subcommand is None:
|
||||
click.echo(_stats())
|
||||
|
||||
|
||||
@knowledge.command("list")
|
||||
def knowledge_list():
|
||||
"""Display knowledge base file tree."""
|
||||
click.echo(_tree())
|
||||
|
||||
|
||||
def _stats() -> str:
|
||||
knowledge_dir = _get_knowledge_dir()
|
||||
if not os.path.isdir(knowledge_dir):
|
||||
return "Knowledge base directory not found."
|
||||
|
||||
enabled = _get_knowledge_enabled()
|
||||
total_files = 0
|
||||
total_bytes = 0
|
||||
cat_count = {}
|
||||
|
||||
for root, dirs, files in os.walk(knowledge_dir):
|
||||
dirs[:] = [d for d in dirs if not d.startswith(".")]
|
||||
rel_root = os.path.relpath(root, knowledge_dir)
|
||||
category = rel_root.split(os.sep)[0] if rel_root != "." else "root"
|
||||
for f in files:
|
||||
if f.endswith(".md") and f not in ("index.md", "log.md"):
|
||||
total_files += 1
|
||||
total_bytes += os.path.getsize(os.path.join(root, f))
|
||||
cat_count[category] = cat_count.get(category, 0) + 1
|
||||
|
||||
status_icon = click.style("enabled", fg="green") if enabled else click.style("disabled", fg="red")
|
||||
lines = [
|
||||
f"\n Knowledge Base [{status_icon}]",
|
||||
"",
|
||||
f" Pages: {total_files}",
|
||||
f" Size: {total_bytes / 1024:.1f} KB",
|
||||
"",
|
||||
]
|
||||
if cat_count:
|
||||
lines.append(" Categories:")
|
||||
for cat in sorted(cat_count.keys()):
|
||||
lines.append(f" {cat}/ ({cat_count[cat]} pages)")
|
||||
lines.append("")
|
||||
|
||||
lines.append(f" Path: {knowledge_dir}")
|
||||
lines.append("")
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def _tree() -> str:
|
||||
knowledge_dir = _get_knowledge_dir()
|
||||
if not os.path.isdir(knowledge_dir):
|
||||
return "Knowledge base directory not found."
|
||||
|
||||
tree_lines = [" knowledge/"]
|
||||
|
||||
subdirs = sorted([
|
||||
d for d in os.listdir(knowledge_dir)
|
||||
if os.path.isdir(os.path.join(knowledge_dir, d)) and not d.startswith(".")
|
||||
])
|
||||
|
||||
for i, subdir in enumerate(subdirs):
|
||||
is_last_dir = (i == len(subdirs) - 1)
|
||||
branch = "└── " if is_last_dir else "├── "
|
||||
subdir_path = os.path.join(knowledge_dir, subdir)
|
||||
md_files = sorted([
|
||||
f for f in os.listdir(subdir_path)
|
||||
if f.endswith(".md") and not f.startswith(".")
|
||||
])
|
||||
tree_lines.append(f" {branch}{subdir}/ ({len(md_files)})")
|
||||
|
||||
child_prefix = " " if is_last_dir else " │ "
|
||||
max_show = 15
|
||||
for j, fname in enumerate(md_files[:max_show]):
|
||||
is_last_file = (j == len(md_files[:max_show]) - 1) and len(md_files) <= max_show
|
||||
fb = "└── " if is_last_file else "├── "
|
||||
name = fname.replace(".md", "")
|
||||
tree_lines.append(f"{child_prefix}{fb}{name}")
|
||||
if len(md_files) > max_show:
|
||||
tree_lines.append(f"{child_prefix}└── ... +{len(md_files) - max_show} more")
|
||||
|
||||
if not subdirs:
|
||||
tree_lines.append(" (empty)")
|
||||
|
||||
return "\n" + "\n".join(tree_lines) + "\n"
|
||||
317
cli/commands/process.py
Normal file
317
cli/commands/process.py
Normal file
@@ -0,0 +1,317 @@
|
||||
"""cow start/stop/restart/status/logs - Process management commands."""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import subprocess
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
import click
|
||||
|
||||
from cli.utils import get_project_root
|
||||
|
||||
_IS_WIN = sys.platform == "win32"
|
||||
|
||||
|
||||
def _get_pid_file():
|
||||
return os.path.join(get_project_root(), ".cow.pid")
|
||||
|
||||
|
||||
def _get_log_file():
|
||||
return os.path.join(get_project_root(), "nohup.out")
|
||||
|
||||
|
||||
def _is_pid_alive(pid: int) -> bool:
|
||||
"""Check whether a process is still running (cross-platform)."""
|
||||
if _IS_WIN:
|
||||
try:
|
||||
out = subprocess.check_output(
|
||||
["tasklist", "/FI", f"PID eq {pid}", "/NH"],
|
||||
stderr=subprocess.DEVNULL,
|
||||
)
|
||||
return str(pid) in out.decode(errors="ignore")
|
||||
except Exception:
|
||||
return False
|
||||
else:
|
||||
try:
|
||||
os.kill(pid, 0)
|
||||
return True
|
||||
except (ProcessLookupError, PermissionError):
|
||||
return False
|
||||
|
||||
|
||||
def _kill_pid(pid: int, force: bool = False):
|
||||
"""Terminate a process by PID (cross-platform)."""
|
||||
if _IS_WIN:
|
||||
flag = "/F" if force else ""
|
||||
cmd = ["taskkill"]
|
||||
if force:
|
||||
cmd.append("/F")
|
||||
cmd.extend(["/PID", str(pid)])
|
||||
subprocess.run(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
|
||||
else:
|
||||
import signal
|
||||
sig = signal.SIGKILL if force else signal.SIGTERM
|
||||
os.kill(pid, sig)
|
||||
|
||||
|
||||
def _read_pid() -> Optional[int]:
|
||||
pid_file = _get_pid_file()
|
||||
if not os.path.exists(pid_file):
|
||||
return None
|
||||
try:
|
||||
with open(pid_file, "r") as f:
|
||||
pid = int(f.read().strip())
|
||||
if _is_pid_alive(pid):
|
||||
return pid
|
||||
os.remove(pid_file)
|
||||
return None
|
||||
except (ValueError, OSError):
|
||||
try:
|
||||
os.remove(pid_file)
|
||||
except OSError:
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
def _write_pid(pid: int):
|
||||
with open(_get_pid_file(), "w") as f:
|
||||
f.write(str(pid))
|
||||
|
||||
|
||||
def _remove_pid():
|
||||
pid_file = _get_pid_file()
|
||||
if os.path.exists(pid_file):
|
||||
os.remove(pid_file)
|
||||
|
||||
|
||||
@click.command()
|
||||
@click.option("--foreground", "-f", is_flag=True, help="Run in foreground (don't daemonize)")
|
||||
@click.option("--no-logs", is_flag=True, help="Don't tail logs after starting")
|
||||
def start(foreground, no_logs):
|
||||
"""Start CowAgent."""
|
||||
pid = _read_pid()
|
||||
if pid:
|
||||
click.echo(f"CowAgent is already running (PID: {pid}).")
|
||||
return
|
||||
|
||||
root = get_project_root()
|
||||
app_py = os.path.join(root, "app.py")
|
||||
if not os.path.exists(app_py):
|
||||
click.echo("Error: app.py not found in project root.", err=True)
|
||||
sys.exit(1)
|
||||
|
||||
python = sys.executable
|
||||
|
||||
if foreground:
|
||||
click.echo("Starting CowAgent in foreground...")
|
||||
if _IS_WIN:
|
||||
sys.exit(subprocess.call([python, app_py], cwd=root))
|
||||
else:
|
||||
os.execv(python, [python, app_py])
|
||||
else:
|
||||
log_file = _get_log_file()
|
||||
click.echo("Starting CowAgent...")
|
||||
|
||||
popen_kwargs = dict(cwd=root)
|
||||
if _IS_WIN:
|
||||
CREATE_NO_WINDOW = 0x08000000
|
||||
popen_kwargs["creationflags"] = (
|
||||
subprocess.CREATE_NEW_PROCESS_GROUP | CREATE_NO_WINDOW
|
||||
)
|
||||
else:
|
||||
popen_kwargs["start_new_session"] = True
|
||||
|
||||
with open(log_file, "a") as log:
|
||||
proc = subprocess.Popen(
|
||||
[python, app_py],
|
||||
stdout=log,
|
||||
stderr=log,
|
||||
**popen_kwargs,
|
||||
)
|
||||
_write_pid(proc.pid)
|
||||
click.echo(click.style(f"✓ CowAgent started (PID: {proc.pid})", fg="green"))
|
||||
click.echo(f" Logs: {log_file}")
|
||||
|
||||
if not no_logs:
|
||||
click.echo(" Press Ctrl+C to stop tailing logs.\n")
|
||||
_tail_log(log_file)
|
||||
|
||||
|
||||
@click.command()
|
||||
def stop():
|
||||
"""Stop CowAgent."""
|
||||
pid = _read_pid()
|
||||
if not pid:
|
||||
click.echo("CowAgent is not running.")
|
||||
return
|
||||
|
||||
click.echo(f"Stopping CowAgent (PID: {pid})...")
|
||||
try:
|
||||
_kill_pid(pid)
|
||||
for _ in range(30):
|
||||
time.sleep(0.1)
|
||||
if not _is_pid_alive(pid):
|
||||
break
|
||||
else:
|
||||
_kill_pid(pid, force=True)
|
||||
except (ProcessLookupError, OSError):
|
||||
pass
|
||||
|
||||
_remove_pid()
|
||||
click.echo(click.style("✓ CowAgent stopped.", fg="green"))
|
||||
|
||||
|
||||
@click.command()
|
||||
@click.option("--no-logs", is_flag=True, help="Don't tail logs after restarting")
|
||||
@click.pass_context
|
||||
def restart(ctx, no_logs):
|
||||
"""Restart CowAgent."""
|
||||
ctx.invoke(stop)
|
||||
time.sleep(1)
|
||||
ctx.invoke(start, no_logs=no_logs)
|
||||
|
||||
|
||||
@click.command()
|
||||
@click.pass_context
|
||||
def update(ctx):
|
||||
"""Update CowAgent and restart."""
|
||||
root = get_project_root()
|
||||
|
||||
# 1. Stop service first so git pull won't conflict with running code
|
||||
ctx.invoke(stop)
|
||||
|
||||
# 2. Git pull
|
||||
if os.path.isdir(os.path.join(root, ".git")):
|
||||
click.echo("Pulling latest code...")
|
||||
ret = subprocess.call(["git", "pull"], cwd=root)
|
||||
if ret != 0:
|
||||
click.echo("Error: git pull failed.", err=True)
|
||||
sys.exit(1)
|
||||
else:
|
||||
click.echo("Not a git repository, skipping code update.")
|
||||
|
||||
python = sys.executable
|
||||
req_file = os.path.join(root, "requirements.txt")
|
||||
|
||||
if _IS_WIN:
|
||||
# On Windows, `cow.exe` (this process) locks the exe file, so
|
||||
# `pip install -e .` fails with WinError 5. Write a small .bat
|
||||
# helper that waits for cow.exe to exit, then installs & starts.
|
||||
bat = os.path.join(root, "_cow_update.bat")
|
||||
lines = [
|
||||
"@echo off",
|
||||
"chcp 65001 >nul",
|
||||
"echo Waiting for cow.exe to exit...",
|
||||
"timeout /t 3 /nobreak >nul",
|
||||
]
|
||||
if os.path.exists(req_file):
|
||||
lines.append(f'echo Installing dependencies...')
|
||||
lines.append(f'"{python}" -m pip install -r requirements.txt -q')
|
||||
lines += [
|
||||
"echo Reinstalling cow CLI...",
|
||||
f'"{python}" -m pip install -e . -q',
|
||||
"echo Starting CowAgent...",
|
||||
f'"{python}" -m cli.cli start --no-logs',
|
||||
"echo.",
|
||||
"echo Update complete. You can close this window.",
|
||||
"pause >nul",
|
||||
"del \"%~f0\"",
|
||||
]
|
||||
with open(bat, "w", encoding="utf-8") as f:
|
||||
f.write("\n".join(lines) + "\n")
|
||||
|
||||
subprocess.Popen(
|
||||
["cmd.exe", "/c", "start", "CowAgent Update", "/wait", bat],
|
||||
cwd=root,
|
||||
)
|
||||
click.echo(click.style(
|
||||
"✓ Update script launched. Please follow the new window for progress.",
|
||||
fg="green"))
|
||||
else:
|
||||
# 3. Install dependencies
|
||||
if os.path.exists(req_file):
|
||||
click.echo("Installing dependencies...")
|
||||
subprocess.call(
|
||||
[python, "-m", "pip", "install", "-r", "requirements.txt", "-q"],
|
||||
cwd=root,
|
||||
)
|
||||
click.echo("Reinstalling cow CLI...")
|
||||
subprocess.call(
|
||||
[python, "-m", "pip", "install", "-e", ".", "-q"],
|
||||
cwd=root,
|
||||
)
|
||||
|
||||
# 4. Start service
|
||||
click.echo("")
|
||||
time.sleep(1)
|
||||
ctx.invoke(start, no_logs=False)
|
||||
|
||||
|
||||
@click.command()
|
||||
def status():
|
||||
"""Show CowAgent running status."""
|
||||
from cli import __version__
|
||||
from cli.utils import load_config_json
|
||||
|
||||
pid = _read_pid()
|
||||
if pid:
|
||||
click.echo(click.style(f"● CowAgent is running (PID: {pid})", fg="green"))
|
||||
else:
|
||||
click.echo(click.style("● CowAgent is not running", fg="red"))
|
||||
|
||||
click.echo(f" 版本: v{__version__}")
|
||||
|
||||
cfg = load_config_json()
|
||||
if cfg:
|
||||
channel = cfg.get("channel_type", "unknown")
|
||||
if isinstance(channel, list):
|
||||
channel = ", ".join(channel)
|
||||
click.echo(f" 通道: {channel}")
|
||||
click.echo(f" 模型: {cfg.get('model', 'unknown')}")
|
||||
mode = "Agent" if cfg.get("agent") else "Chat"
|
||||
click.echo(f" 模式: {mode}")
|
||||
|
||||
|
||||
@click.command()
|
||||
@click.option("--follow", "-f", is_flag=True, help="Follow log output")
|
||||
@click.option("--lines", "-n", default=50, help="Number of lines to show")
|
||||
def logs(follow, lines):
|
||||
"""View CowAgent logs."""
|
||||
log_file = _get_log_file()
|
||||
if not os.path.exists(log_file):
|
||||
click.echo("No log file found.")
|
||||
return
|
||||
|
||||
if follow:
|
||||
_tail_log(log_file, lines)
|
||||
else:
|
||||
_print_last_lines(log_file, lines)
|
||||
|
||||
|
||||
def _print_last_lines(file_path: str, n: int = 50):
|
||||
"""Print the last N lines of a file (cross-platform)."""
|
||||
try:
|
||||
with open(file_path, "r", encoding="utf-8", errors="replace") as f:
|
||||
all_lines = f.readlines()
|
||||
for line in all_lines[-n:]:
|
||||
click.echo(line, nl=False)
|
||||
except Exception as e:
|
||||
click.echo(f"Error reading log file: {e}", err=True)
|
||||
|
||||
|
||||
def _tail_log(log_file: str, lines: int = 50):
|
||||
"""Follow log file output. Blocks until Ctrl+C (cross-platform)."""
|
||||
_print_last_lines(log_file, lines)
|
||||
|
||||
try:
|
||||
with open(log_file, "r", encoding="utf-8", errors="replace") as f:
|
||||
f.seek(0, 2)
|
||||
while True:
|
||||
line = f.readline()
|
||||
if line:
|
||||
click.echo(line, nl=False)
|
||||
else:
|
||||
time.sleep(0.3)
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
1483
cli/commands/skill.py
Normal file
1483
cli/commands/skill.py
Normal file
File diff suppressed because it is too large
Load Diff
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user