Compare commits
448 Commits
feat-confi
...
feat-add-m
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
bccce2d7cb | ||
|
|
83cd6ad158 | ||
|
|
116fb27257 | ||
|
|
8d67177a1b | ||
|
|
ad2db1a776 | ||
|
|
2e6d9e0f27 | ||
|
|
e05f85f3ce | ||
|
|
40c48a9a61 | ||
|
|
c9a7525d0b | ||
|
|
fd571ac539 | ||
|
|
c5a3f991c5 | ||
|
|
eb74b73351 | ||
|
|
9b31f45481 | ||
|
|
bc9c1691f5 | ||
|
|
73bf83d2ff | ||
|
|
36e1988fee | ||
|
|
aad6ef635e | ||
|
|
96659cd616 | ||
|
|
c8787b7de4 | ||
|
|
91d427c8f9 | ||
|
|
c8c0573dbd | ||
|
|
29af855ecd | ||
|
|
0a146a245d | ||
|
|
bd85fee7d7 | ||
|
|
571897e2fd | ||
|
|
840dabeccd | ||
|
|
069bffa3e8 | ||
|
|
cc10d230b0 | ||
|
|
2517f2add8 | ||
|
|
a534266025 | ||
|
|
8c25395805 | ||
|
|
36b913124b | ||
|
|
90773ab69f | ||
|
|
b7734c3926 | ||
|
|
d3faf9c8dc | ||
|
|
bca97a1d14 | ||
|
|
ac9d0f18c5 | ||
|
|
09fa624797 | ||
|
|
b8333e351c | ||
|
|
a01423a196 | ||
|
|
7c35df7a82 | ||
|
|
2b90f377e6 | ||
|
|
fff7326209 | ||
|
|
c181e500bc | ||
|
|
16b7271826 | ||
|
|
4a1f62b185 | ||
|
|
d23a0754c1 | ||
|
|
3ffb563a44 | ||
|
|
4e42f2a017 | ||
|
|
a0dfdb79df | ||
|
|
a85c5f9d4e | ||
|
|
2720bba5b7 | ||
|
|
4634a7bc2f | ||
|
|
16d9b449c9 | ||
|
|
8761997757 | ||
|
|
19bba4abbc | ||
|
|
7839f0aac5 | ||
|
|
83def1db30 | ||
|
|
a0b29d1ffe | ||
|
|
f5479c56af | ||
|
|
246f0a45c8 | ||
|
|
fe871aad77 | ||
|
|
6f860e1bc4 | ||
|
|
249ea40ae3 | ||
|
|
20d8ae19a7 | ||
|
|
ad51aabfd7 | ||
|
|
1cf395c041 | ||
|
|
745179a5bf | ||
|
|
ff5d477fa5 | ||
|
|
907825601d | ||
|
|
c2ec26910a | ||
|
|
83f2aea123 | ||
|
|
a5c5439315 | ||
|
|
eca9b60235 | ||
|
|
d2d5d98d78 | ||
|
|
fb341b869b | ||
|
|
29e66cb186 | ||
|
|
307769b949 | ||
|
|
9a09e057d6 | ||
|
|
3e28659528 | ||
|
|
b861eef26f | ||
|
|
caaf006a49 | ||
|
|
b2429ec30c | ||
|
|
55aaf60a57 | ||
|
|
a5790d82f6 | ||
|
|
63f99af1e6 | ||
|
|
4eed2568aa | ||
|
|
fb7962c7f2 | ||
|
|
76e6b7b471 | ||
|
|
fccb7ff9ed | ||
|
|
3b12ef2e66 | ||
|
|
f9d099be1b | ||
|
|
c322c0e3a5 | ||
|
|
530fc20596 | ||
|
|
a23b4ed754 | ||
|
|
fc4f5077b0 | ||
|
|
6a553886da | ||
|
|
1065c7e722 | ||
|
|
a9c8a59f58 | ||
|
|
8730f7fd27 | ||
|
|
8f608223d7 | ||
|
|
a7cbd47a2f | ||
|
|
b80c3fe5a8 | ||
|
|
5080051e39 | ||
|
|
23bfc8d0ba | ||
|
|
80e9062041 | ||
|
|
67bd3420ed | ||
|
|
aea081703f | ||
|
|
f300d2a2d5 | ||
|
|
f150d7d83a | ||
|
|
4d1f059c0d | ||
|
|
bc7f953fcc | ||
|
|
f653483eea | ||
|
|
6b200fd36b | ||
|
|
161fc6cdf0 | ||
|
|
6f68ed6bce | ||
|
|
a4592ffdfe | ||
|
|
7cd7bd1a48 | ||
|
|
9eeca70292 | ||
|
|
02bfe30848 | ||
|
|
c9c99de3d9 | ||
|
|
8752f0cc60 | ||
|
|
5c65196e44 | ||
|
|
f5798bfe90 | ||
|
|
0e556b3468 | ||
|
|
31820f56e7 | ||
|
|
fd88828abd | ||
|
|
ae11159918 | ||
|
|
472a8605c0 | ||
|
|
e1760ba211 | ||
|
|
ce4c0a0aa4 | ||
|
|
64511593c4 | ||
|
|
b0e00dfceb | ||
|
|
fc465b463d | ||
|
|
68ce2e5232 | ||
|
|
81e8bb62ae | ||
|
|
2c13e1b923 | ||
|
|
a0748c2e3b | ||
|
|
40599bb751 | ||
|
|
f3c64ceea7 | ||
|
|
15c60de709 | ||
|
|
6dd316547f | ||
|
|
54c7676a44 | ||
|
|
d25b8966ce | ||
|
|
14a119c48c | ||
|
|
c82515a927 | ||
|
|
26e630c2dd | ||
|
|
13370d2056 | ||
|
|
35282db9e0 | ||
|
|
426fb88ce7 | ||
|
|
2384bd0e10 | ||
|
|
ba3f66d3d1 | ||
|
|
7293a0f670 | ||
|
|
9e86d46267 | ||
|
|
848430f062 | ||
|
|
abd21335c4 | ||
|
|
8fa95f058a | ||
|
|
d4e5ecd497 | ||
|
|
3830f76729 | ||
|
|
83f778fec9 | ||
|
|
cabd24605f | ||
|
|
ae20ba1148 | ||
|
|
3a50b64977 | ||
|
|
8692e74536 | ||
|
|
1c18bd9889 | ||
|
|
60e9d98d0a | ||
|
|
83f6625e0c | ||
|
|
acc09543b7 | ||
|
|
94d8c7e366 | ||
|
|
ea1a0c8b3d | ||
|
|
7bc88c17e4 | ||
|
|
33cf1bc4c3 | ||
|
|
9402e63fe1 | ||
|
|
90e4d494b2 | ||
|
|
da97e948ca | ||
|
|
89a07e8e74 | ||
|
|
3f3d0381e5 | ||
|
|
3649499dba | ||
|
|
a989d088fd | ||
|
|
f79a915136 | ||
|
|
12e8c3d449 | ||
|
|
4f7064575e | ||
|
|
070df826f1 | ||
|
|
fbe48a4b4e | ||
|
|
4dd497fb6d | ||
|
|
907882c0a7 | ||
|
|
d36d5aee3f | ||
|
|
c6824e5f5e | ||
|
|
199c21eede | ||
|
|
5162da5654 | ||
|
|
a1d82f6193 | ||
|
|
ea78e3d0c6 | ||
|
|
3497f00cb4 | ||
|
|
5355d45031 | ||
|
|
26693acc3f | ||
|
|
76e9fef3b2 | ||
|
|
c34308cbd4 | ||
|
|
5a10476010 | ||
|
|
46e80dceec | ||
|
|
90d1835353 | ||
|
|
845fadd0aa | ||
|
|
5748ded52c | ||
|
|
6a737fb734 | ||
|
|
3cd92ccda3 | ||
|
|
54e81aba11 | ||
|
|
d86cb4ded6 | ||
|
|
4d5375f6d6 | ||
|
|
424557fedb | ||
|
|
89251e603f | ||
|
|
a653ed07eb | ||
|
|
ad86deb014 | ||
|
|
9525dc7584 | ||
|
|
cd31dd27fd | ||
|
|
360e3670eb | ||
|
|
8dabe3b4c8 | ||
|
|
443e0c2806 | ||
|
|
9cc173cc4d | ||
|
|
b5f33e5ecd | ||
|
|
40dfc6860f | ||
|
|
1c02a04423 | ||
|
|
de0e45070c | ||
|
|
c169cc7d74 | ||
|
|
cd62ad76f6 | ||
|
|
dd25b0fb5b | ||
|
|
a38b22a6a2 | ||
|
|
830b8f2971 | ||
|
|
b058af122c | ||
|
|
174ee0cafc | ||
|
|
1c336380c0 | ||
|
|
3068880413 | ||
|
|
be596681e5 | ||
|
|
66b71c50e9 | ||
|
|
8744810b25 | ||
|
|
7f94d37c2e | ||
|
|
6d9b7baeb4 | ||
|
|
4470d4c352 | ||
|
|
d2a462a279 | ||
|
|
14ff2a15e7 | ||
|
|
6d1369900e | ||
|
|
1f17ebe69e | ||
|
|
1ae2918064 | ||
|
|
b6571e5cad | ||
|
|
7549d48cf1 | ||
|
|
00353dd0cb | ||
|
|
afd947195d | ||
|
|
e57ef37167 | ||
|
|
ef33a93654 | ||
|
|
61732aecfc | ||
|
|
6764c05c3f | ||
|
|
fa149cf4aa | ||
|
|
e4f9697d06 | ||
|
|
da061450e5 | ||
|
|
d09ae49287 | ||
|
|
511ee0bbaf | ||
|
|
3cb5a0fbd6 | ||
|
|
e06925ab85 | ||
|
|
184634e4e7 | ||
|
|
843c2d02cc | ||
|
|
8ea2455766 | ||
|
|
9dc9987d56 | ||
|
|
3458621147 | ||
|
|
079df5a47c | ||
|
|
ddb07c65a1 | ||
|
|
9b21cd222b | ||
|
|
90f736843f | ||
|
|
13c020eb61 | ||
|
|
dbc06dbe95 | ||
|
|
23d097bc1c | ||
|
|
db85b9808e | ||
|
|
df5bae37bc | ||
|
|
acc23b6051 | ||
|
|
61f2741afc | ||
|
|
4dd7ea886a | ||
|
|
1e8959fbcf | ||
|
|
48729678cf | ||
|
|
0684becaa7 | ||
|
|
db16bdf8cb | ||
|
|
f890318ed9 | ||
|
|
158510cbbe | ||
|
|
ce90cf7aa8 | ||
|
|
a3a3d006eb | ||
|
|
8fd029a4a1 | ||
|
|
2e1b52c1e5 | ||
|
|
3eb8348708 | ||
|
|
393f0c007c | ||
|
|
294e380288 | ||
|
|
4c1c42efac | ||
|
|
c062ca8c66 | ||
|
|
76dcb25103 | ||
|
|
c5b4f236db | ||
|
|
0974c940a8 | ||
|
|
cffa20d37e | ||
|
|
ef009edd29 | ||
|
|
3ca52b118d | ||
|
|
13f5fde4fb | ||
|
|
f512b55ec2 | ||
|
|
22b8ca0095 | ||
|
|
baf66a103d | ||
|
|
45faa9c1ff | ||
|
|
304381a88d | ||
|
|
fc9f54dbc8 | ||
|
|
7199dc187f | ||
|
|
e9ae066d53 | ||
|
|
d71ae406ff | ||
|
|
f3216904b3 | ||
|
|
5958b69ec9 | ||
|
|
7d4e2cb39a | ||
|
|
a483ec0cea | ||
|
|
c1421e0874 | ||
|
|
ce89869c3c | ||
|
|
b8b57e34ff | ||
|
|
bc7f627253 | ||
|
|
652156e398 | ||
|
|
9febb071c6 | ||
|
|
7d0e1568ac | ||
|
|
b4e711f411 | ||
|
|
1b5be1b981 | ||
|
|
49d8707c58 | ||
|
|
9192f6f7f7 | ||
|
|
05022e3745 | ||
|
|
5356e9ddeb | ||
|
|
52acf76e2c | ||
|
|
40cdbd3b45 | ||
|
|
5487c0befe | ||
|
|
8bb16c48c0 | ||
|
|
c6384363f9 | ||
|
|
8993e8ad3e | ||
|
|
289989d9f7 | ||
|
|
dc2ae0e6f1 | ||
|
|
9c966c152d | ||
|
|
4efae41048 | ||
|
|
b8437032e9 | ||
|
|
2d339ca81b | ||
|
|
d53abc9696 | ||
|
|
446c886d38 | ||
|
|
30c6d9b5ae | ||
|
|
5e42996b36 | ||
|
|
ceca7b85bf | ||
|
|
a4d54f58c8 | ||
|
|
005a0e1bad | ||
|
|
46d97fd57d | ||
|
|
72a26b6353 | ||
|
|
89a4033fbf | ||
|
|
39a5dc64bd | ||
|
|
d4bdd9b1b7 | ||
|
|
2f5ba87280 | ||
|
|
8b45d6c750 | ||
|
|
4ecd4df2d4 | ||
|
|
a42f31fe52 | ||
|
|
d4480b695e | ||
|
|
c4b5f7fbae | ||
|
|
ba915f2cc0 | ||
|
|
4b91140f31 | ||
|
|
9879878dd0 | ||
|
|
d78105d57c | ||
|
|
153c9e3565 | ||
|
|
c11623596d | ||
|
|
e791a77f77 | ||
|
|
b641bffb2c | ||
|
|
ee0c47ac1e | ||
|
|
eba90e9343 | ||
|
|
d8374d0fa5 | ||
|
|
fa61744c6d | ||
|
|
4fec55cc01 | ||
|
|
1767413712 | ||
|
|
734c8fa84f | ||
|
|
9a8d422554 | ||
|
|
b21e945c76 | ||
|
|
a02bf1ea09 | ||
|
|
eda82bac92 | ||
|
|
e8d4f7dc4f | ||
|
|
c4a93b7789 | ||
|
|
c3f9925097 | ||
|
|
2a0cf7511a | ||
|
|
d0a70d3339 | ||
|
|
f37e4675dd | ||
|
|
4e32f67eeb | ||
|
|
36d54cab52 | ||
|
|
9d8df10dcf | ||
|
|
45ea88e070 | ||
|
|
d5d0b947f5 | ||
|
|
f775f1f11e | ||
|
|
f1e888f3de | ||
|
|
71c8436e90 | ||
|
|
08c69f5e9b | ||
|
|
a50fafaca2 | ||
|
|
3c6781d240 | ||
|
|
3b8b5625f8 | ||
|
|
6be2034110 | ||
|
|
924dc79f00 | ||
|
|
ccb9030d3c | ||
|
|
8623287ac1 | ||
|
|
022c13f3a4 | ||
|
|
0687916e7f | ||
|
|
bb868b83ba | ||
|
|
24298130b9 | ||
|
|
6e5ee92ebd | ||
|
|
5b91fe04aa | ||
|
|
1623deb3ee | ||
|
|
4a16e05b7a | ||
|
|
f1c04bc60d | ||
|
|
84c6f31c76 | ||
|
|
9d528190bf | ||
|
|
0f23b209ad | ||
|
|
63d9325900 | ||
|
|
f342097f81 | ||
|
|
b4806c4366 | ||
|
|
ff37d8a577 | ||
|
|
a773eb7893 | ||
|
|
7c67513d24 | ||
|
|
6ed85029c5 | ||
|
|
e9c57ddf4d | ||
|
|
a33ce97ed9 | ||
|
|
b788a3dd4e | ||
|
|
fccfa92d7e | ||
|
|
8705bf0a70 | ||
|
|
9318138af7 | ||
|
|
269fa7d2d5 | ||
|
|
e99837a8b9 | ||
|
|
553861a2c4 | ||
|
|
628a85d1be | ||
|
|
2cb54514a4 | ||
|
|
6db22827f2 | ||
|
|
4cc6d5426b | ||
|
|
7d258b5202 | ||
|
|
c8d19ee0bc | ||
|
|
d891312032 | ||
|
|
5edbf4ce32 | ||
|
|
3ddbdd713d | ||
|
|
9ba107b511 | ||
|
|
c9adddb76a | ||
|
|
f0a12d5ff5 | ||
|
|
7cce224499 | ||
|
|
97397ca585 | ||
|
|
f2fbc602a8 | ||
|
|
925d728a86 | ||
|
|
f5f229871b | ||
|
|
9917552b4b | ||
|
|
adca89b973 | ||
|
|
29bfbecdc9 | ||
|
|
1a7a8c98d9 | ||
|
|
cddb38ac3d | ||
|
|
394853c0fb | ||
|
|
c0702c8b36 | ||
|
|
d610608391 | ||
|
|
9082eec91d | ||
|
|
f1a1413b5f | ||
|
|
c1e7f9af9b |
2
.github/ISSUE_TEMPLATE/1.bug.yml
vendored
@@ -79,8 +79,6 @@ body:
|
||||
description: |
|
||||
请确保你正确配置了该`channel`所需的配置项,所有可选的配置项都写在了[该文件中](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/config.py),请将所需配置项填写在根目录下的`config.json`文件中。
|
||||
options:
|
||||
- wx(个人微信, itchat)
|
||||
- wxy(个人微信, wechaty)
|
||||
- wechatmp(公众号, 订阅号)
|
||||
- wechatmp_service(公众号, 服务号)
|
||||
- terminal
|
||||
|
||||
11
.github/workflows/deploy-image-arm.yml
vendored
@@ -19,7 +19,7 @@ env:
|
||||
|
||||
jobs:
|
||||
build-and-push-image:
|
||||
if: github.repository == 'zhayujie/chatgpt-on-wechat'
|
||||
if: github.repository == 'zhayujie/CowAgent'
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: read
|
||||
@@ -51,7 +51,12 @@ jobs:
|
||||
uses: docker/metadata-action@v4
|
||||
with:
|
||||
images: |
|
||||
${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
|
||||
${{ env.REGISTRY }}/zhayujie/chatgpt-on-wechat
|
||||
${{ env.REGISTRY }}/zhayujie/cowagent
|
||||
tags: |
|
||||
type=raw,value=latest-arm64,enable={{is_default_branch}}
|
||||
type=ref,event=branch,suffix=-arm64
|
||||
type=ref,event=tag,suffix=-arm64
|
||||
|
||||
- name: Build and push Docker image
|
||||
uses: docker/build-push-action@v3
|
||||
@@ -60,7 +65,7 @@ jobs:
|
||||
push: true
|
||||
file: ./docker/Dockerfile.latest
|
||||
platforms: linux/arm64
|
||||
tags: ${{ steps.meta.outputs.tags }}-arm64
|
||||
tags: ${{ steps.meta.outputs.tags }}
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
|
||||
- uses: actions/delete-package-versions@v4
|
||||
|
||||
13
.github/workflows/deploy-image.yml
vendored
@@ -16,10 +16,11 @@ on:
|
||||
env:
|
||||
REGISTRY: ghcr.io
|
||||
IMAGE_NAME: ${{ github.repository }}
|
||||
DOCKERHUB_IMAGE: zhayujie/chatgpt-on-wechat
|
||||
|
||||
jobs:
|
||||
build-and-push-image:
|
||||
if: github.repository == 'zhayujie/chatgpt-on-wechat'
|
||||
if: github.repository == 'zhayujie/CowAgent'
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: read
|
||||
@@ -47,8 +48,14 @@ jobs:
|
||||
uses: docker/metadata-action@v4
|
||||
with:
|
||||
images: |
|
||||
${{ env.IMAGE_NAME }}
|
||||
${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
|
||||
zhayujie/chatgpt-on-wechat
|
||||
zhayujie/cowagent
|
||||
${{ env.REGISTRY }}/zhayujie/chatgpt-on-wechat
|
||||
${{ env.REGISTRY }}/zhayujie/cowagent
|
||||
tags: |
|
||||
type=raw,value=latest,enable={{is_default_branch}}
|
||||
type=ref,event=branch
|
||||
type=ref,event=tag
|
||||
|
||||
- name: Build and push Docker image
|
||||
uses: docker/build-push-action@v3
|
||||
|
||||
13
.gitignore
vendored
@@ -3,16 +3,15 @@
|
||||
.vscode
|
||||
.venv
|
||||
.vs
|
||||
.wechaty/
|
||||
__pycache__/
|
||||
venv*
|
||||
*.pyc
|
||||
python
|
||||
config.json
|
||||
QR.png
|
||||
nohup.out
|
||||
tmp
|
||||
plugins.json
|
||||
itchat.pkl
|
||||
*.log
|
||||
logs/
|
||||
workspace
|
||||
@@ -33,8 +32,16 @@ plugins/banwords/lib/__pycache__
|
||||
!plugins/role
|
||||
!plugins/keyword
|
||||
!plugins/linkai
|
||||
!plugins/agent
|
||||
!plugins/cow_cli
|
||||
client_config.json
|
||||
ref/
|
||||
**/.dev.vars
|
||||
.cursor/
|
||||
local/
|
||||
node_modules/
|
||||
|
||||
# cow cli
|
||||
dist/
|
||||
build/
|
||||
*.egg-info/
|
||||
.cow.pid
|
||||
|
||||
928
README.md
@@ -1,773 +1,255 @@
|
||||
<p align="center"><img src= "https://github.com/user-attachments/assets/eca9a9ec-8534-4615-9e0f-96c5ac1d10a3" alt="Chatgpt-on-Wechat" width="550" /></p>
|
||||
<p align="center"><img src="https://github.com/user-attachments/assets/eca9a9ec-8534-4615-9e0f-96c5ac1d10a3" alt="CowAgent" width="420" /></p>
|
||||
|
||||
<p align="center">
|
||||
<a href="https://github.com/zhayujie/chatgpt-on-wechat/releases/latest"><img src="https://img.shields.io/github/v/release/zhayujie/chatgpt-on-wechat" alt="Latest release"></a>
|
||||
<a href="https://github.com/zhayujie/chatgpt-on-wechat/blob/master/LICENSE"><img src="https://img.shields.io/github/license/zhayujie/chatgpt-on-wechat" alt="License: MIT"></a>
|
||||
<a href="https://github.com/zhayujie/chatgpt-on-wechat"><img src="https://img.shields.io/github/stars/zhayujie/chatgpt-on-wechat?style=flat-square" alt="Stars"></a> <br/>
|
||||
<a href="https://github.com/zhayujie/CowAgent/releases/latest"><img src="https://img.shields.io/github/v/release/zhayujie/CowAgent" alt="Latest release"></a>
|
||||
<a href="https://github.com/zhayujie/CowAgent/blob/master/LICENSE"><img src="https://img.shields.io/github/license/zhayujie/CowAgent" alt="License: MIT"></a>
|
||||
<a href="https://github.com/zhayujie/CowAgent"><img src="https://img.shields.io/github/stars/zhayujie/CowAgent?style=flat-square" alt="Stars"></a> <br/>
|
||||
[English] | [<a href="docs/zh/README.md">中文</a>] | [<a href="docs/ja/README.md">日本語</a>]
|
||||
</p>
|
||||
|
||||
**CowAgent** 是基于大模型的超级AI助理,能够主动思考和任务规划、操作计算机和外部资源、创造和执行Skills、拥有长期记忆并不断成长。CowAgent 支持灵活切换多种模型,能处理文本、语音、图片、文件等多模态消息,可接入网页、飞书、钉钉、企业微信应用、微信公众号中使用,7*24小时运行于你的个人电脑或服务器中。
|
||||
**CowAgent** is an open-source super AI assistant that proactively plans tasks, controls your computer and external services, creates and runs Skills, and grows alongside you through a personal knowledge base and long-term memory — a reference implementation of Agent Harness engineering.
|
||||
|
||||
📖能力介绍:[CowAgent 2.0](/docs/agent.md)
|
||||
CowAgent is lightweight, easy to deploy, and built to extend. Plug in any major LLM provider and run it 24/7 on a personal computer or server, across the web and all major IM platforms.
|
||||
|
||||
# 简介
|
||||
|
||||
> 该项目既是一个可以开箱即用的超级AI助理,也是一个支持高扩展的Agent框架,可以通过为项目扩展大模型接口、接入渠道、内置工具、Skills系统来灵活实现各种定制需求。核心能力如下:
|
||||
|
||||
- ✅ **复杂任务规划**:能够理解复杂任务并自主规划执行,持续思考和调用工具直到完成目标,支持通过工具操作访问文件、终端、浏览器、定时任务等系统资源
|
||||
- ✅ **长期记忆:** 自动将对话记忆持久化至本地文件和数据库中,包括全局记忆和天级记忆,支持关键词及向量检索
|
||||
- ✅ **技能系统:** 实现了Skills创建和运行的引擎,内置多种技能,并支持通过自然语言对话完成自定义Skills开发
|
||||
- ✅ **多模态消息:** 支持对文本、图片、语音、文件等多类型消息进行解析、处理、生成、发送等操作
|
||||
- ✅ **多模型接入:** 支持OpenAI, Claude, Gemini, DeepSeek, MiniMax、GLM、Qwen、Kimi、Doubao等国内外主流模型厂商
|
||||
- ✅ **多端部署:** 支持运行在本地计算机或服务器,可集成到网页、飞书、钉钉、微信公众号、企业微信应用中使用
|
||||
- ✅ **知识库:** 集成企业知识库能力,让Agent成为专属数字员工,基于[LinkAI](https://link-ai.tech)平台实现
|
||||
|
||||
## 声明
|
||||
|
||||
1. 本项目遵循 [MIT开源协议](/LICENSE),主要用于技术研究和学习,使用本项目时需遵守所在地法律法规、相关政策以及企业章程,禁止用于任何违法或侵犯他人权益的行为。任何个人、团队和企业,无论以何种方式使用该项目、对何对象提供服务,所产生的一切后果,本项目均不承担任何责任
|
||||
2. 成本与安全:Agent模式下Token使用量高于普通对话模式,请根据效果及成本综合选择模型。Agent具有访问所在操作系统的能力,请谨慎选择项目部署环境。同时项目也会持续升级安全机制、并降低模型消耗成本
|
||||
|
||||
## 演示
|
||||
|
||||
使用说明(Agent模式):[CowAgent介绍](/docs/agent.md)
|
||||
|
||||
DEMO视频(对话模式):https://cdn.link-ai.tech/doc/cow_demo.mp4
|
||||
|
||||
## 社区
|
||||
|
||||
添加小助手微信加入开源项目交流群:
|
||||
|
||||
<img width="140" src="https://img-1317903499.cos.ap-guangzhou.myqcloud.com/docs/open-community.png">
|
||||
<p align="center">
|
||||
<a href="https://cowagent.ai/">🌐 Website</a> ·
|
||||
<a href="https://docs.cowagent.ai/en/intro/index">📖 Docs</a> ·
|
||||
<a href="https://docs.cowagent.ai/en/guide/quick-start">🚀 Quick Start</a> ·
|
||||
<a href="https://skills.cowagent.ai/">🧩 Skill Hub</a> ·
|
||||
<a href="https://link-ai.tech/cowagent/create">☁️ Try Online</a>
|
||||
</p>
|
||||
|
||||
<br/>
|
||||
|
||||
# 企业服务
|
||||
## 🌟 Highlights
|
||||
|
||||
<a href="https://link-ai.tech" target="_blank"><img width="720" src="https://cdn.link-ai.tech/image/link-ai-intro.jpg"></a>
|
||||
|
||||
> [LinkAI](https://link-ai.tech/) 是面向企业和开发者的一站式AI智能体平台,聚合多模态大模型、知识库、Agent 插件、工作流等能力,支持一键接入主流平台并进行管理,支持SaaS、私有化部署等多种模式。
|
||||
>
|
||||
> LinkAI 目前已在智能客服、私域运营、企业效率助手等场景积累了丰富的AI解决方案,在消费、健康、文教、科技制造等各行业沉淀了大模型落地应用的最佳实践,致力于帮助更多企业和开发者拥抱 AI 生产力。
|
||||
|
||||
**产品咨询和企业服务** 可联系产品客服:
|
||||
|
||||
<img width="150" src="https://cdn.link-ai.tech/portal/linkai-customer-service.png">
|
||||
| Capability | Description |
|
||||
| :--- | :--- |
|
||||
| [Planning](https://docs.cowagent.ai/en/intro/architecture) | Decomposes complex tasks and executes them step by step, looping over tools until the goal is reached |
|
||||
| [Memory](https://docs.cowagent.ai/en/memory/index) | Three-tier architecture (context → daily → core), automatic Deep Dream distillation, hybrid keyword + vector retrieval |
|
||||
| [Knowledge](https://docs.cowagent.ai/en/knowledge/index) | Auto-curates structured knowledge into a Markdown wiki, builds an evolving knowledge graph with visual browsing |
|
||||
| [Skills](https://docs.cowagent.ai/en/skills/index) | One-click install from [Skill Hub](https://skills.cowagent.ai/), GitHub, ClawHub; or create custom skills via natural-language conversation |
|
||||
| [Tools](https://docs.cowagent.ai/en/tools/index) | Built-in file I/O, terminal, browser, scheduler, memory retrieval, web search, and 10+ more tools — with native MCP integration |
|
||||
| [Channels](https://docs.cowagent.ai/en/channels/index) | Integrates with Web, WeChat, Feishu, DingTalk, WeCom, QQ, and Official Accounts |
|
||||
| Multimodal | First-class support for text, images, voice, and files — recognition, generation, and delivery |
|
||||
| [Models](https://docs.cowagent.ai/en/models/index) | Claude, GPT, Gemini, DeepSeek, Qwen, GLM, Kimi, MiniMax, Doubao, and more — swap providers from the Web console with one click |
|
||||
| [Deploy](https://docs.cowagent.ai/en/guide/quick-start) | One-line installer, unified Web console, multiple deployment modes (local, Docker, server) |
|
||||
|
||||
<br/>
|
||||
|
||||
# 🏷 更新日志
|
||||
## 🏗️ Architecture
|
||||
|
||||
>**2026.02.03:** [2.0.0版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/2.0.0),正式升级为超级Agent助理,支持多轮任务决策、具备长期记忆、实现多种系统工具、支持Skills框架,新增多种模型并优化了接入渠道。
|
||||
<img src="https://cdn.jsdelivr.net/gh/zhayujie/cowagent-assets@main/architecture/en/architecture.jpg" alt="CowAgent Architecture" width="750"/>
|
||||
|
||||
>**2025.05.23:** [1.7.6版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.7.6) 优化web网页channel、新增 [AgentMesh](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/plugins/agent/README.md)多智能体插件、百度语音合成优化、企微应用`access_token`获取优化、支持`claude-4-sonnet`和`claude-4-opus`模型
|
||||
CowAgent is a complete **Agent Harness**: messages flow in through **Channels**; the **Agent Core** plans and reasons over memory, knowledge, and the available tools and skills; **Models** generate the response, which is sent back through the originating channel. Every layer is decoupled and independently extensible.
|
||||
|
||||
>**2025.04.11:** [1.7.5版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.7.5) 新增支持 [wechatferry](https://github.com/zhayujie/chatgpt-on-wechat/pull/2562) 协议、新增 deepseek 模型、新增支持腾讯云语音能力、新增支持 ModelScope 和 Gitee-AI API接口
|
||||
|
||||
>**2024.12.13:** [1.7.4版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.7.4) 新增 Gemini 2.0 模型、新增web channel、解决内存泄漏问题、解决 `#reloadp` 命令重载不生效问题
|
||||
|
||||
>**2024.10.31:** [1.7.3版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.7.3) 程序稳定性提升、数据库功能、Claude模型优化、linkai插件优化、离线通知
|
||||
|
||||
更多更新历史请查看: [更新日志](/docs/release/history.md)
|
||||
Read more in [Architecture](https://docs.cowagent.ai/en/intro/architecture).
|
||||
|
||||
<br/>
|
||||
|
||||
# 🚀 快速开始
|
||||
## 🚀 Quick Start
|
||||
|
||||
项目提供了一键安装、配置、启动、管理程序的脚本,推荐使用脚本快速运行,也可以根据下文中的详细指引一步步安装运行。
|
||||
A one-line installer takes care of dependencies, configuration, and startup:
|
||||
|
||||
在终端执行以下命令:
|
||||
**Linux / macOS:**
|
||||
|
||||
```bash
|
||||
bash <(curl -sS https://cdn.link-ai.tech/code/cow/run.sh)
|
||||
bash <(curl -fsSL https://cdn.link-ai.tech/code/cow/run.sh)
|
||||
```
|
||||
|
||||
脚本使用说明:[一键运行脚本](https://github.com/zhayujie/chatgpt-on-wechat/wiki/CowAgentQuickStart)
|
||||
**Windows (PowerShell):**
|
||||
|
||||
```powershell
|
||||
irm https://cdn.link-ai.tech/code/cow/run.ps1 | iex
|
||||
```
|
||||
|
||||
## 一、准备
|
||||
|
||||
### 1. 模型API
|
||||
|
||||
项目支持国内外主流厂商的模型接口,可选模型及配置说明参考:[模型说明](#模型说明)。
|
||||
|
||||
> 注:Agent模式下推荐使用以下模型,可根据效果及成本综合选择:MiniMax-M2.5、glm-5、kimi-k2.5、qwen3.5-plus、claude-sonnet-4-6、gemini-3.1-pro-preview
|
||||
|
||||
同时支持使用 **LinkAI平台** 接口,可灵活切换 OpenAI、Claude、Gemini、DeepSeek、Qwen、Kimi 等多种常用模型,并支持知识库、工作流、插件等Agent能力,参考 [接口文档](https://docs.link-ai.tech/platform/api)。
|
||||
|
||||
### 2.环境安装
|
||||
|
||||
支持 Linux、MacOS、Windows 操作系统,可在个人计算机及服务器上运行,需安装 `Python`,Python版本需在3.7 ~ 3.12 之间,推荐使用3.9版本。
|
||||
|
||||
> 注意:Agent模式推荐使用源码运行,若选择Docker部署则无需安装python环境和下载源码,可直接快进到下一节。
|
||||
|
||||
**(1) 克隆项目代码:**
|
||||
**Docker:**
|
||||
|
||||
```bash
|
||||
git clone https://github.com/zhayujie/chatgpt-on-wechat
|
||||
cd chatgpt-on-wechat/
|
||||
curl -O https://cdn.link-ai.tech/code/cow/docker-compose.yml
|
||||
docker compose up -d
|
||||
```
|
||||
|
||||
若遇到网络问题可使用国内仓库地址:https://gitee.com/zhayujie/chatgpt-on-wechat
|
||||
Once started, open `http://localhost:9899` to access the **Web console** — your one-stop hub to chat with the Agent, configure models, connect channels, and install skills.
|
||||
|
||||
**(2) 安装核心依赖 (必选):**
|
||||
> Deploying on a server? Set `web_host` to `0.0.0.0` in `config.json` to make the console reachable from outside, and set `web_password` to protect it. Don't forget to open port `9899` in your firewall or security group.
|
||||
|
||||
> 📖 Detailed guides: [Quick Start](https://docs.cowagent.ai/en/guide/quick-start) · [Install from Source](https://docs.cowagent.ai/en/guide/manual-install) · [Upgrade](https://docs.cowagent.ai/en/guide/upgrade)
|
||||
|
||||
After installation, manage the service with the [cow CLI](https://docs.cowagent.ai/en/cli/index):
|
||||
|
||||
```bash
|
||||
pip3 install -r requirements.txt
|
||||
cow start | stop | restart # service control
|
||||
cow status | logs # status and logs
|
||||
cow update # pull latest code and restart
|
||||
cow skill install <name> # install a skill
|
||||
cow install-browser # install browser automation
|
||||
```
|
||||
|
||||
**(3) 拓展依赖 (可选,建议安装):**
|
||||
|
||||
```bash
|
||||
pip3 install -r requirements-optional.txt
|
||||
```
|
||||
如果某项依赖安装失败可注释掉对应的行后重试。
|
||||
|
||||
## 二、配置
|
||||
|
||||
配置文件的模板在根目录的`config-template.json`中,需复制该模板创建最终生效的 `config.json` 文件:
|
||||
|
||||
```bash
|
||||
cp config-template.json config.json
|
||||
```
|
||||
|
||||
然后在`config.json`中填入配置,以下是对默认配置的说明,可根据需要进行自定义修改(注意实际使用时请去掉注释,保证JSON格式的规范):
|
||||
|
||||
```bash
|
||||
# config.json 文件内容示例
|
||||
{
|
||||
"channel_type": "web", # 接入渠道类型,默认为web,支持修改为:feishu,dingtalk,wechatcom_app,terminal,wechatmp,wechatmp_service
|
||||
"model": "MiniMax-M2.5", # 模型名称
|
||||
"minimax_api_key": "", # MiniMax API Key
|
||||
"zhipu_ai_api_key": "", # 智谱GLM API Key
|
||||
"moonshot_api_key": "", # Kimi/Moonshot API Key
|
||||
"ark_api_key": "", # 豆包(火山方舟) API Key
|
||||
"dashscope_api_key": "", # 百炼(通义千问)API Key
|
||||
"claude_api_key": "", # Claude API Key
|
||||
"claude_api_base": "https://api.anthropic.com/v1", # Claude API 地址,修改可接入三方代理平台
|
||||
"gemini_api_key": "", # Gemini API Key
|
||||
"gemini_api_base": "https://generativelanguage.googleapis.com", # Gemini API地址
|
||||
"open_ai_api_key": "", # OpenAI API Key
|
||||
"open_ai_api_base": "https://api.openai.com/v1", # OpenAI API 地址
|
||||
"linkai_api_key": "", # LinkAI API Key
|
||||
"proxy": "", # 代理客户端的ip和端口,国内环境需要开启代理的可填写该项,如 "127.0.0.1:7890"
|
||||
"speech_recognition": false, # 是否开启语音识别
|
||||
"group_speech_recognition": false, # 是否开启群组语音识别
|
||||
"voice_reply_voice": false, # 是否使用语音回复语音
|
||||
"use_linkai": false, # 是否使用LinkAI接口,默认关闭,设置为true后可对接LinkAI平台接口
|
||||
"agent": true, # 是否启用Agent模式,启用后拥有多轮工具决策、长期记忆、Skills能力等
|
||||
"agent_workspace": "~/cow", # Agent的工作空间路径,用于存储memory、skills、系统设定等
|
||||
"agent_max_context_tokens": 40000, # Agent模式下最大上下文tokens,超出将自动丢弃最早的上下文
|
||||
"agent_max_context_turns": 30, # Agent模式下最大上下文记忆轮次,每轮包括一次用户提问和AI回复
|
||||
"agent_max_steps": 15 # Agent模式下单次任务的最大决策步数,超出后将停止继续调用工具
|
||||
}
|
||||
```
|
||||
|
||||
**配置补充说明:**
|
||||
|
||||
<details>
|
||||
<summary>1. 语音配置</summary>
|
||||
|
||||
+ 添加 `"speech_recognition": true` 将开启语音识别,默认使用openai的whisper模型识别为文字,同时以文字回复,该参数仅支持私聊 (注意由于语音消息无法匹配前缀,一旦开启将对所有语音自动回复,支持语音触发画图);
|
||||
+ 添加 `"group_speech_recognition": true` 将开启群组语音识别,默认使用openai的whisper模型识别为文字,同时以文字回复,参数仅支持群聊 (会匹配group_chat_prefix和group_chat_keyword, 支持语音触发画图);
|
||||
+ 添加 `"voice_reply_voice": true` 将开启语音回复语音(同时作用于私聊和群聊)
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>2. 其他配置</summary>
|
||||
|
||||
+ `model`: 模型名称,Agent模式下推荐使用 `MiniMax-M2.5`、`glm-5`、`kimi-k2.5`、`qwen3.5-plus`、`claude-sonnet-4-6`、`gemini-3.1-pro-preview`,全部模型名称参考[common/const.py](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/common/const.py)文件
|
||||
+ `character_desc`:普通对话模式下的机器人系统提示词。在Agent模式下该配置不生效,由工作空间中的文件内容构成。
|
||||
+ `subscribe_msg`:订阅消息,公众号和企业微信channel中请填写,当被订阅时会自动回复, 可使用特殊占位符。目前支持的占位符有{trigger_prefix},在程序中它会自动替换成bot的触发词。
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>3. LinkAI配置</summary>
|
||||
|
||||
+ `use_linkai`: 是否使用LinkAI接口,默认关闭,设置为true后可对接LinkAI平台,使用知识库、工作流、插件等能力, 参考[接口文档](https://docs.link-ai.tech/platform/api/chat)
|
||||
+ `linkai_api_key`: LinkAI Api Key,可在 [控制台](https://link-ai.tech/console/interface) 创建
|
||||
+ `linkai_app_code`: LinkAI 应用或工作流的code,选填,普通对话模式中使用。
|
||||
</details>
|
||||
|
||||
注:全部配置项说明可在 [`config.py`](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/config.py) 文件中查看。
|
||||
|
||||
## 三、运行
|
||||
|
||||
### 1.本地运行
|
||||
|
||||
如果是个人计算机 **本地运行**,直接在项目根目录下执行:
|
||||
|
||||
```bash
|
||||
python3 app.py # windows环境下该命令通常为 python app.py
|
||||
```
|
||||
|
||||
运行后默认会启动web服务,可通过访问 `http://localhost:9899/chat` 在网页端对话。
|
||||
|
||||
如果需要接入其他应用通道只需修改 `config.json` 配置文件中的 `channel_type` 参数,详情参考:[通道说明](#通道说明)。
|
||||
|
||||
|
||||
### 2.服务器部署
|
||||
|
||||
在服务器中可使用 `nohup` 命令在后台运行程序:
|
||||
|
||||
```bash
|
||||
nohup python3 app.py & tail -f nohup.out
|
||||
```
|
||||
|
||||
执行后程序运行于服务器后台,可通过 `ctrl+c` 关闭日志,不会影响后台程序的运行。使用 `ps -ef | grep app.py | grep -v grep` 命令可查看运行于后台的进程,如果想要重新启动程序可以先 `kill` 掉对应的进程。 日志关闭后如果想要再次打开只需输入 `tail -f nohup.out`。
|
||||
|
||||
此外,项目的 `scripts` 目录下有一键运行、关闭程序的脚本供使用。 运行后默认channel为web,通过可以通过修改配置文件进行切换。
|
||||
|
||||
|
||||
### 3.Docker部署
|
||||
|
||||
使用docker部署无需下载源码和安装依赖,只需要获取 `docker-compose.yml` 配置文件并启动容器即可。Agent模式下更推荐使用源码进行部署,以获得更多系统访问能力。
|
||||
|
||||
> 前提是需要安装好 `docker` 及 `docker-compose`,安装成功后执行 `docker -v` 和 `docker-compose version` (或 `docker compose version`) 可查看到版本号。安装地址为 [docker官网](https://docs.docker.com/engine/install/) 。
|
||||
|
||||
**(1) 下载 docker-compose.yml 文件**
|
||||
|
||||
```bash
|
||||
wget https://cdn.link-ai.tech/code/cow/docker-compose.yml
|
||||
```
|
||||
|
||||
下载完成后打开 `docker-compose.yml` 填写所需配置,例如 `CHANNEL_TYPE`、`OPEN_AI_API_KEY` 和等配置。
|
||||
|
||||
**(2) 启动容器**
|
||||
|
||||
在 `docker-compose.yml` 所在目录下执行以下命令启动容器:
|
||||
|
||||
```bash
|
||||
sudo docker compose up -d # 若docker-compose为 1.X 版本,则执行 `sudo docker-compose up -d`
|
||||
```
|
||||
|
||||
运行命令后,会自动取 [docker hub](https://hub.docker.com/r/zhayujie/chatgpt-on-wechat) 拉取最新release版本的镜像。当执行 `sudo docker ps` 能查看到 NAMES 为 chatgpt-on-wechat 的容器即表示运行成功。最后执行以下命令可查看容器的运行日志:
|
||||
|
||||
```bash
|
||||
sudo docker logs -f chatgpt-on-wechat
|
||||
```
|
||||
|
||||
**(3) 插件使用**
|
||||
|
||||
如果需要在docker容器中修改插件配置,可通过挂载的方式完成,将 [插件配置文件](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/plugins/config.json.template)
|
||||
重命名为 `config.json`,放置于 `docker-compose.yml` 相同目录下,并在 `docker-compose.yml` 中的 `chatgpt-on-wechat` 部分下添加 `volumes` 映射:
|
||||
|
||||
```
|
||||
volumes:
|
||||
- ./config.json:/app/plugins/config.json
|
||||
```
|
||||
**注**:使用docker方式部署的详细教程可以参考:[docker部署CoW项目](https://www.wangpc.cc/ai/docker-deploy-cow/)
|
||||
|
||||
|
||||
## 模型说明
|
||||
|
||||
以下对所有可支持的模型的配置和使用方法进行说明,模型接口实现在项目的 `models/` 目录下。
|
||||
|
||||
<details>
|
||||
<summary>OpenAI</summary>
|
||||
|
||||
1. API Key创建:在 [OpenAI平台](https://platform.openai.com/api-keys) 创建API Key
|
||||
|
||||
2. 填写配置
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "gpt-4.1-mini",
|
||||
"open_ai_api_key": "YOUR_API_KEY",
|
||||
"open_ai_api_base": "https://api.openai.com/v1",
|
||||
"bot_type": "chatGPT"
|
||||
}
|
||||
```
|
||||
|
||||
- `model`: 与OpenAI接口的 [model参数](https://platform.openai.com/docs/models) 一致,支持包括 o系列、gpt-5.2、gpt-5.1、gpt-4.1等系列模型
|
||||
- `open_ai_api_base`: 如果需要接入第三方代理接口,可通过修改该参数进行接入
|
||||
- `bot_type`: 使用OpenAI相关模型时无需填写。当使用第三方代理接口接入Claude等非OpenAI官方模型时,该参数设为 `chatGPT`
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>LinkAI</summary>
|
||||
|
||||
1. API Key创建:在 [LinkAI平台](https://link-ai.tech/console/interface) 创建API Key
|
||||
|
||||
2. 填写配置
|
||||
|
||||
```json
|
||||
{
|
||||
"use_linkai": true,
|
||||
"linkai_api_key": "YOUR API KEY",
|
||||
"linkai_app_code": "YOUR APP CODE"
|
||||
}
|
||||
```
|
||||
|
||||
+ `use_linkai`: 是否使用LinkAI接口,默认关闭,设置为true后可对接LinkAI平台的智能体,使用知识库、工作流、数据库、MCP插件等丰富的Agent能力
|
||||
+ `linkai_api_key`: LinkAI平台的API Key,可在 [控制台](https://link-ai.tech/console/interface) 中创建
|
||||
+ `linkai_app_code`: LinkAI智能体 (应用或工作流) 的code,选填,普通对话模式可用。智能体创建可参考 [说明文档](https://docs.link-ai.tech/platform/quick-start)
|
||||
+ `model`: model字段填写空则直接使用智能体的模型,可在平台中灵活切换,[模型列表](https://link-ai.tech/console/models)中的全部模型均可使用
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>MiniMax</summary>
|
||||
|
||||
方式一:官方接入,配置如下(推荐):
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "MiniMax-M2.5",
|
||||
"minimax_api_key": ""
|
||||
}
|
||||
```
|
||||
- `model`: 可填写 `MiniMax-M2.5、MiniMax-M2.1、MiniMax-M2.1-lightning、MiniMax-M2、abab6.5-chat` 等
|
||||
- `minimax_api_key`:MiniMax平台的API-KEY,在 [控制台](https://platform.minimaxi.com/user-center/basic-information/interface-key) 创建
|
||||
|
||||
方式二:OpenAI兼容方式接入,配置如下:
|
||||
```json
|
||||
{
|
||||
"bot_type": "chatGPT",
|
||||
"model": "MiniMax-M2.5",
|
||||
"open_ai_api_base": "https://api.minimaxi.com/v1",
|
||||
"open_ai_api_key": ""
|
||||
}
|
||||
```
|
||||
- `bot_type`: OpenAI兼容方式
|
||||
- `model`: 可填 `MiniMax-M2.5、MiniMax-M2.1、MiniMax-M2.1-lightning、MiniMax-M2`,参考[API文档](https://platform.minimaxi.com/document/%E5%AF%B9%E8%AF%9D?key=66701d281d57f38758d581d0#QklxsNSbaf6kM4j6wjO5eEek)
|
||||
- `open_ai_api_base`: MiniMax平台API的 BASE URL
|
||||
- `open_ai_api_key`: MiniMax平台的API-KEY
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>智谱AI (GLM)</summary>
|
||||
|
||||
方式一:官方接入,配置如下(推荐):
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "glm-5",
|
||||
"zhipu_ai_api_key": ""
|
||||
}
|
||||
```
|
||||
- `model`: 可填 `glm-5、glm-4.7、glm-4-plus、glm-4-flash、glm-4-air、glm-4-airx、glm-4-long` 等, 参考 [glm系列模型编码](https://bigmodel.cn/dev/api/normal-model/glm-4)
|
||||
- `zhipu_ai_api_key`: 智谱AI平台的 API KEY,在 [控制台](https://www.bigmodel.cn/usercenter/proj-mgmt/apikeys) 创建
|
||||
|
||||
方式二:OpenAI兼容方式接入,配置如下:
|
||||
```json
|
||||
{
|
||||
"bot_type": "chatGPT",
|
||||
"model": "glm-5",
|
||||
"open_ai_api_base": "https://open.bigmodel.cn/api/paas/v4",
|
||||
"open_ai_api_key": ""
|
||||
}
|
||||
```
|
||||
- `bot_type`: OpenAI兼容方式
|
||||
- `model`: 可填 `glm-5、glm-4.7、glm-4-plus、glm-4-flash、glm-4-air、glm-4-airx、glm-4-long` 等
|
||||
- `open_ai_api_base`: 智谱AI平台的 BASE URL
|
||||
- `open_ai_api_key`: 智谱AI平台的 API KEY
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>通义千问 (Qwen)</summary>
|
||||
|
||||
方式一:官方SDK接入,配置如下(推荐):
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "qwen3.5-plus",
|
||||
"dashscope_api_key": "sk-qVxxxxG"
|
||||
}
|
||||
```
|
||||
- `model`: 可填写 `qwen3.5-plus、qwen3-max、qwen-max、qwen-plus、qwen-turbo、qwen-long、qwq-plus` 等
|
||||
- `dashscope_api_key`: 通义千问的 API-KEY,参考 [官方文档](https://bailian.console.aliyun.com/?tab=api#/api) ,在 [控制台](https://bailian.console.aliyun.com/?tab=model#/api-key) 创建
|
||||
|
||||
方式二:OpenAI兼容方式接入,配置如下:
|
||||
```json
|
||||
{
|
||||
"bot_type": "chatGPT",
|
||||
"model": "qwen3.5-plus",
|
||||
"open_ai_api_base": "https://dashscope.aliyuncs.com/compatible-mode/v1",
|
||||
"open_ai_api_key": "sk-qVxxxxG"
|
||||
}
|
||||
```
|
||||
- `bot_type`: OpenAI兼容方式
|
||||
- `model`: 支持官方所有模型,参考[模型列表](https://help.aliyun.com/zh/model-studio/models?spm=a2c4g.11186623.0.0.78d84823Kth5on#9f8890ce29g5u)
|
||||
- `open_ai_api_base`: 通义千问API的 BASE URL
|
||||
- `open_ai_api_key`: 通义千问的 API-KEY
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>Kimi (Moonshot)</summary>
|
||||
|
||||
方式一:官方接入,配置如下:
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "kimi-k2.5",
|
||||
"moonshot_api_key": ""
|
||||
}
|
||||
```
|
||||
- `model`: 可填写 `kimi-k2.5、kimi-k2、moonshot-v1-8k、moonshot-v1-32k、moonshot-v1-128k`
|
||||
- `moonshot_api_key`: Moonshot的API-KEY,在 [控制台](https://platform.moonshot.cn/console/api-keys) 创建
|
||||
|
||||
方式二:OpenAI兼容方式接入,配置如下:
|
||||
```json
|
||||
{
|
||||
"bot_type": "chatGPT",
|
||||
"model": "kimi-k2.5",
|
||||
"open_ai_api_base": "https://api.moonshot.cn/v1",
|
||||
"open_ai_api_key": ""
|
||||
}
|
||||
```
|
||||
- `bot_type`: OpenAI兼容方式
|
||||
- `model`: 可填写 `kimi-k2.5、kimi-k2、moonshot-v1-8k、moonshot-v1-32k、moonshot-v1-128k`
|
||||
- `open_ai_api_base`: Moonshot的 BASE URL
|
||||
- `open_ai_api_key`: Moonshot的 API-KEY
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>豆包 (Doubao)</summary>
|
||||
|
||||
1. API Key创建:在 [火山方舟控制台](https://console.volcengine.com/ark/region:ark+cn-beijing/apikey) 创建API Key
|
||||
|
||||
2. 填写配置
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "doubao-seed-2-0-code-preview-260215",
|
||||
"ark_api_key": "YOUR_API_KEY"
|
||||
}
|
||||
```
|
||||
- `model`: 可填写 `doubao-seed-2-0-code-preview-260215、doubao-seed-2-0-pro-260215、doubao-seed-2-0-lite-260215、doubao-seed-2-0-mini-260215` 等
|
||||
- `ark_api_key`: 火山方舟平台的 API Key,在 [控制台](https://console.volcengine.com/ark/region:ark+cn-beijing/apikey) 创建
|
||||
- `ark_base_url`: 可选,默认为 `https://ark.cn-beijing.volces.com/api/v3`
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>Claude</summary>
|
||||
|
||||
1. API Key创建:在 [Claude控制台](https://console.anthropic.com/settings/keys) 创建API Key
|
||||
|
||||
2. 填写配置
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "claude-sonnet-4-6",
|
||||
"claude_api_key": "YOUR_API_KEY"
|
||||
}
|
||||
```
|
||||
- `model`: 参考 [官方模型ID](https://docs.anthropic.com/en/docs/about-claude/models/overview#model-aliases) ,支持 `claude-sonnet-4-6、claude-opus-4-6、claude-sonnet-4-5、claude-sonnet-4-0、claude-opus-4-0、claude-3-5-sonnet-latest` 等
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>Gemini</summary>
|
||||
|
||||
API Key创建:在 [控制台](https://aistudio.google.com/app/apikey?hl=zh-cn) 创建API Key ,配置如下
|
||||
```json
|
||||
{
|
||||
"model": "gemini-3.1-pro-preview",
|
||||
"gemini_api_key": ""
|
||||
}
|
||||
```
|
||||
- `model`: 参考[官方文档-模型列表](https://ai.google.dev/gemini-api/docs/models?hl=zh-cn),支持 `gemini-3.1-pro-preview、gemini-3-flash-preview、gemini-3-pro-preview、gemini-2.5-pro、gemini-2.0-flash` 等
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>DeepSeek</summary>
|
||||
|
||||
1. API Key创建:在 [DeepSeek平台](https://platform.deepseek.com/api_keys) 创建API Key
|
||||
|
||||
2. 填写配置
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "deepseek-chat",
|
||||
"open_ai_api_key": "sk-xxxxxxxxxxx",
|
||||
"open_ai_api_base": "https://api.deepseek.com/v1",
|
||||
"bot_type": "chatGPT"
|
||||
|
||||
}
|
||||
```
|
||||
|
||||
- `bot_type`: OpenAI兼容方式
|
||||
- `model`: 可填 `deepseek-chat、deepseek-reasoner`,分别对应的是 DeepSeek-V3 和 DeepSeek-R1 模型
|
||||
- `open_ai_api_key`: DeepSeek平台的 API Key
|
||||
- `open_ai_api_base`: DeepSeek平台 BASE URL
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>Azure</summary>
|
||||
|
||||
1. API Key创建:在 [Azure平台](https://oai.azure.com/) 创建API Key
|
||||
|
||||
2. 填写配置
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "",
|
||||
"use_azure_chatgpt": true,
|
||||
"open_ai_api_key": "",
|
||||
"open_ai_api_base": "",
|
||||
"azure_deployment_id": "",
|
||||
"azure_api_version": "2025-01-01-preview"
|
||||
}
|
||||
```
|
||||
|
||||
- `model`: 留空即可
|
||||
- `use_azure_chatgpt`: 设为 true
|
||||
- `open_ai_api_key`: Azure平台的密钥
|
||||
- `open_ai_api_base`: Azure平台的 BASE URL
|
||||
- `azure_deployment_id`: Azure平台部署的模型名称
|
||||
- `azure_api_version`: api版本以及以上参数可以在部署的 [模型配置](https://oai.azure.com/resource/deployments) 界面查看
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>百度文心</summary>
|
||||
方式一:官方SDK接入,配置如下:
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "wenxin-4",
|
||||
"baidu_wenxin_api_key": "IajztZ0bDxgnP9bEykU7lBer",
|
||||
"baidu_wenxin_secret_key": "EDPZn6L24uAS9d8RWFfotK47dPvkjD6G"
|
||||
}
|
||||
```
|
||||
- `model`: 可填 `wenxin`和`wenxin-4`,对应模型为 文心-3.5 和 文心-4.0
|
||||
- `baidu_wenxin_api_key`:参考 [千帆平台-access_token鉴权](https://cloud.baidu.com/doc/WENXINWORKSHOP/s/dlv4pct3s) 文档获取 API Key
|
||||
- `baidu_wenxin_secret_key`:参考 [千帆平台-access_token鉴权](https://cloud.baidu.com/doc/WENXINWORKSHOP/s/dlv4pct3s) 文档获取 Secret Key
|
||||
|
||||
方式二:OpenAI兼容方式接入,配置如下:
|
||||
```json
|
||||
{
|
||||
"bot_type": "chatGPT",
|
||||
"model": "ERNIE-4.0-Turbo-8K",
|
||||
"open_ai_api_base": "https://qianfan.baidubce.com/v2",
|
||||
"open_ai_api_key": "bce-v3/ALTxxxxxxd2b"
|
||||
}
|
||||
```
|
||||
- `bot_type`: OpenAI兼容方式
|
||||
- `model`: 支持官方所有模型,参考[模型列表](https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Wm9cvy6rl)
|
||||
- `open_ai_api_base`: 百度文心API的 BASE URL
|
||||
- `open_ai_api_key`: 百度文心的 API-KEY,参考 [官方文档](https://cloud.baidu.com/doc/qianfan-api/s/ym9chdsy5) ,在 [控制台](https://console.bce.baidu.com/iam/#/iam/apikey/list) 创建API Key
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>讯飞星火</summary>
|
||||
|
||||
方式一:官方接入,配置如下:
|
||||
参考 [官方文档-快速指引](https://www.xfyun.cn/doc/platform/quickguide.html#%E7%AC%AC%E4%BA%8C%E6%AD%A5-%E5%88%9B%E5%BB%BA%E6%82%A8%E7%9A%84%E7%AC%AC%E4%B8%80%E4%B8%AA%E5%BA%94%E7%94%A8-%E5%BC%80%E5%A7%8B%E4%BD%BF%E7%94%A8%E6%9C%8D%E5%8A%A1) 获取 `APPID、 APISecret、 APIKey` 三个参数
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "xunfei",
|
||||
"xunfei_app_id": "",
|
||||
"xunfei_api_key": "",
|
||||
"xunfei_api_secret": "",
|
||||
"xunfei_domain": "4.0Ultra",
|
||||
"xunfei_spark_url": "wss://spark-api.xf-yun.com/v4.0/chat"
|
||||
}
|
||||
```
|
||||
- `model`: 填 `xunfei`
|
||||
- `xunfei_domain`: 可填写 `4.0Ultra、generalv3.5、max-32k、generalv3、pro-128k、lite`
|
||||
- `xunfei_spark_url`: 填写参考 [官方文档-请求地址](https://www.xfyun.cn/doc/spark/Web.html#_1-1-%E8%AF%B7%E6%B1%82%E5%9C%B0%E5%9D%80) 的说明
|
||||
|
||||
方式二:OpenAI兼容方式接入,配置如下:
|
||||
```json
|
||||
{
|
||||
"bot_type": "chatGPT",
|
||||
"model": "4.0Ultra",
|
||||
"open_ai_api_base": "https://spark-api-open.xf-yun.com/v1",
|
||||
"open_ai_api_key": ""
|
||||
}
|
||||
```
|
||||
- `bot_type`: OpenAI兼容方式
|
||||
- `model`: 可填写 `4.0Ultra、generalv3.5、max-32k、generalv3、pro-128k、lite`
|
||||
- `open_ai_api_base`: 讯飞星火平台的 BASE URL
|
||||
- `open_ai_api_key`: 讯飞星火平台的[APIPassword](https://console.xfyun.cn/services/bm3) ,因模型而已
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>ModelScope</summary>
|
||||
|
||||
```json
|
||||
{
|
||||
"bot_type": "modelscope",
|
||||
"model": "Qwen/QwQ-32B",
|
||||
"modelscope_api_key": "your_api_key",
|
||||
"modelscope_base_url": "https://api-inference.modelscope.cn/v1/chat/completions",
|
||||
"text_to_image": "MusePublic/489_ckpt_FLUX_1"
|
||||
}
|
||||
```
|
||||
|
||||
- `bot_type`: modelscope接口格式
|
||||
- `model`: 参考[模型列表](https://www.modelscope.cn/models?filter=inference_type&page=1)
|
||||
- `modelscope_api_key`: 参考 [官方文档-访问令牌](https://modelscope.cn/docs/accounts/token) ,在 [控制台](https://modelscope.cn/my/myaccesstoken)
|
||||
- `modelscope_base_url`: modelscope平台的 BASE URL
|
||||
- `text_to_image`: 图像生成模型,参考[模型列表](https://www.modelscope.cn/models?filter=inference_type&page=1)
|
||||
</details>
|
||||
|
||||
|
||||
## 通道说明
|
||||
|
||||
以下对可接入通道的配置方式进行说明,应用通道代码在项目的 `channel/` 目录下。
|
||||
|
||||
<details>
|
||||
<summary>1. Web</summary>
|
||||
|
||||
项目启动后默认运行Web通道,配置如下:
|
||||
|
||||
```json
|
||||
{
|
||||
"channel_type": "web",
|
||||
"web_port": 9899
|
||||
}
|
||||
```
|
||||
|
||||
- `web_port`: 默认为 9899,可按需更改,需要服务器防火墙和安全组放行该端口
|
||||
- 如本地运行,启动后请访问 `http://localhost:9899/chat` ;如服务器运行,请访问 `http://ip:9899/chat`
|
||||
> 注:请将上述 url 中的 ip 或者 port 替换为实际的值
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>2. Feishu - 飞书</summary>
|
||||
|
||||
飞书支持两种事件接收模式:WebSocket 长连接(推荐)和 Webhook。
|
||||
|
||||
**方式一:WebSocket 模式(推荐,无需公网 IP)**
|
||||
|
||||
```json
|
||||
{
|
||||
"channel_type": "feishu",
|
||||
"feishu_app_id": "APP_ID",
|
||||
"feishu_app_secret": "APP_SECRET",
|
||||
"feishu_event_mode": "websocket"
|
||||
}
|
||||
```
|
||||
|
||||
**方式二:Webhook 模式(需要公网 IP)**
|
||||
|
||||
```json
|
||||
{
|
||||
"channel_type": "feishu",
|
||||
"feishu_app_id": "APP_ID",
|
||||
"feishu_app_secret": "APP_SECRET",
|
||||
"feishu_token": "VERIFICATION_TOKEN",
|
||||
"feishu_event_mode": "webhook",
|
||||
"feishu_port": 9891
|
||||
}
|
||||
```
|
||||
|
||||
- `feishu_event_mode`: 事件接收模式,`websocket`(推荐)或 `webhook`
|
||||
- WebSocket 模式需安装依赖:`pip3 install lark-oapi`
|
||||
|
||||
详细步骤和参数说明参考 [飞书接入](https://docs.link-ai.tech/cow/multi-platform/feishu)
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>3. DingTalk - 钉钉</summary>
|
||||
|
||||
钉钉需要在开放平台创建智能机器人应用,将以下配置填入 `config.json`:
|
||||
|
||||
```json
|
||||
{
|
||||
"channel_type": "dingtalk",
|
||||
"dingtalk_client_id": "CLIENT_ID",
|
||||
"dingtalk_client_secret": "CLIENT_SECRET"
|
||||
}
|
||||
```
|
||||
详细步骤和参数说明参考 [钉钉接入](https://docs.link-ai.tech/cow/multi-platform/dingtalk)
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>4. WeCom App - 企业微信应用</summary>
|
||||
|
||||
企业微信自建应用接入需在后台创建应用并启用消息回调,配置示例:
|
||||
|
||||
```json
|
||||
{
|
||||
"channel_type": "wechatcom_app",
|
||||
"wechatcom_corp_id": "CORPID",
|
||||
"wechatcomapp_token": "TOKEN",
|
||||
"wechatcomapp_port": 9898,
|
||||
"wechatcomapp_secret": "SECRET",
|
||||
"wechatcomapp_agent_id": "AGENTID",
|
||||
"wechatcomapp_aes_key": "AESKEY"
|
||||
}
|
||||
```
|
||||
详细步骤和参数说明参考 [企微自建应用接入](https://docs.link-ai.tech/cow/multi-platform/wechat-com)
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>5. WeChat MP - 微信公众号</summary>
|
||||
|
||||
本项目支持订阅号和服务号两种公众号,通过服务号(`wechatmp_service`)体验更佳。
|
||||
|
||||
**个人订阅号(wechatmp)**
|
||||
|
||||
```json
|
||||
{
|
||||
"channel_type": "wechatmp",
|
||||
"wechatmp_token": "TOKEN",
|
||||
"wechatmp_port": 80,
|
||||
"wechatmp_app_id": "APPID",
|
||||
"wechatmp_app_secret": "APPSECRET",
|
||||
"wechatmp_aes_key": ""
|
||||
}
|
||||
```
|
||||
|
||||
**企业服务号(wechatmp_service)**
|
||||
|
||||
```json
|
||||
{
|
||||
"channel_type": "wechatmp_service",
|
||||
"wechatmp_token": "TOKEN",
|
||||
"wechatmp_port": 80,
|
||||
"wechatmp_app_id": "APPID",
|
||||
"wechatmp_app_secret": "APPSECRET",
|
||||
"wechatmp_aes_key": ""
|
||||
}
|
||||
```
|
||||
|
||||
详细步骤和参数说明参考 [微信公众号接入](https://docs.link-ai.tech/cow/multi-platform/wechat-mp)
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>6. Terminal - 终端</summary>
|
||||
|
||||
修改 `config.json` 中的 `channel_type` 字段:
|
||||
|
||||
```json
|
||||
{
|
||||
"channel_type": "terminal"
|
||||
}
|
||||
```
|
||||
|
||||
运行后可在终端与机器人进行对话。
|
||||
|
||||
</details>
|
||||
|
||||
<br/>
|
||||
|
||||
# 🔗 相关项目
|
||||
## 🤖 Models
|
||||
|
||||
- [bot-on-anything](https://github.com/zhayujie/bot-on-anything):轻量和高可扩展的大模型应用框架,支持接入Slack, Telegram, Discord, Gmail等海外平台,可作为本项目的补充使用。
|
||||
- [AgentMesh](https://github.com/MinimalFuture/AgentMesh):开源的多智能体(Multi-Agent)框架,可以通过多智能体团队的协同来解决复杂问题。本项目基于该框架实现了[Agent插件](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/plugins/agent/README.md),可访问终端、浏览器、文件系统、搜索引擎 等各类工具,并实现了多智能体协同。
|
||||
CowAgent supports all mainstream LLM providers. **Chat, vision, image generation, ASR/TTS, and embeddings** can each be routed to a different vendor. Providers are configured directly in the Web console — no manual file editing required.
|
||||
|
||||
| Provider | Featured Models | Chat | Vision | Image Gen | ASR | TTS | Embedding |
|
||||
| --- | --- | :-: | :-: | :-: | :-: | :-: | :-: |
|
||||
| [Claude](https://docs.cowagent.ai/en/models/claude) | claude-opus-4-7 | ✅ | ✅ | | | | |
|
||||
| [OpenAI](https://docs.cowagent.ai/en/models/openai) | gpt-5.5, o-series | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
|
||||
| [Gemini](https://docs.cowagent.ai/en/models/gemini) | gemini-3.5-flash | ✅ | ✅ | ✅ | | | |
|
||||
| [DeepSeek](https://docs.cowagent.ai/en/models/deepseek) | deepseek-v4-flash / pro | ✅ | | | | | |
|
||||
| [Qwen](https://docs.cowagent.ai/en/models/qwen) | qwen3.7-max | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
|
||||
| [GLM](https://docs.cowagent.ai/en/models/glm) | glm-5.1, glm-5v-turbo | ✅ | ✅ | | ✅ | | ✅ |
|
||||
| [Doubao](https://docs.cowagent.ai/en/models/doubao) | doubao-seed-2.0 series | ✅ | ✅ | ✅ | | | ✅ |
|
||||
| [Kimi](https://docs.cowagent.ai/en/models/kimi) | kimi-k2.6 | ✅ | ✅ | | | | |
|
||||
| [MiniMax](https://docs.cowagent.ai/en/models/minimax) | MiniMax-M2.7 | ✅ | ✅ | ✅ | | ✅ | |
|
||||
| [ERNIE](https://docs.cowagent.ai/en/models/qianfan) | ernie-5.1 | ✅ | ✅ | | | | |
|
||||
| [MiMo](https://docs.cowagent.ai/en/models/mimo) | mimo-v2.5 / pro | ✅ | ✅ | | | ✅ | |
|
||||
| [LinkAI](https://docs.cowagent.ai/en/models/linkai) | One key for 100+ models | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
|
||||
| [Custom](https://docs.cowagent.ai/en/models/custom) | Local models / third-party proxy | ✅ | | | | | |
|
||||
|
||||
> For details on each provider, see the [Models overview](https://docs.cowagent.ai/en/models/index).
|
||||
|
||||
# 🔎 常见问题
|
||||
<br/>
|
||||
|
||||
FAQs: <https://github.com/zhayujie/chatgpt-on-wechat/wiki/FAQs>
|
||||
## 💬 Channels
|
||||
|
||||
或直接在线咨询 [项目小助手](https://link-ai.tech/app/Kv2fXJcH) (知识库持续完善中,回复供参考)
|
||||
A single Agent instance can serve multiple channels in parallel. Most channels can be onboarded right from the Web console.
|
||||
|
||||
# 🛠️ 开发
|
||||
| Channel | Text | Image | File | Voice | Group |
|
||||
| --- | :-: | :-: | :-: | :-: | :-: |
|
||||
| [Web Console](https://docs.cowagent.ai/en/channels/web) (default) | ✅ | ✅ | ✅ | ✅ | |
|
||||
| [WeChat](https://docs.cowagent.ai/en/channels/weixin) | ✅ | ✅ | ✅ | ✅ | |
|
||||
| [Feishu / Lark](https://docs.cowagent.ai/en/channels/feishu) | ✅ | ✅ | ✅ | ✅ | ✅ |
|
||||
| [DingTalk](https://docs.cowagent.ai/en/channels/dingtalk) | ✅ | ✅ | ✅ | ✅ | ✅ |
|
||||
| [WeCom Bot](https://docs.cowagent.ai/en/channels/wecom-bot) | ✅ | ✅ | ✅ | ✅ | ✅ |
|
||||
| [QQ](https://docs.cowagent.ai/en/channels/qq) | ✅ | ✅ | ✅ | | ✅ |
|
||||
| [WeCom App](https://docs.cowagent.ai/en/channels/wecom) | ✅ | ✅ | ✅ | ✅ | |
|
||||
| [WeChat Official Account](https://docs.cowagent.ai/en/channels/wechatmp) | ✅ | ✅ | | ✅ | |
|
||||
|
||||
欢迎接入更多应用通道,参考 [飞书通道](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/channel/feishu/feishu_channel.py) 新增自定义通道,实现接收和发送消息逻辑即可完成接入。 同时欢迎贡献新的Skills,参考 [Skill创造器说明](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/skills/skill-creator/SKILL.md)。
|
||||
> See the [Channels overview](https://docs.cowagent.ai/en/channels/index) for setup details.
|
||||
|
||||
# ✉ 联系
|
||||
<img src="https://cdn.jsdelivr.net/gh/zhayujie/cowagent-assets@main/screenshots/en/web-console-chat.png" alt="CowAgent Web Console" width="800"/>
|
||||
|
||||
欢迎提交PR、Issues进行反馈,以及通过 🌟Star 支持并关注项目更新。项目运行遇到问题可以查看 [常见问题列表](https://github.com/zhayujie/chatgpt-on-wechat/wiki/FAQs) ,以及前往 [Issues](https://github.com/zhayujie/chatgpt-on-wechat/issues) 中搜索。个人开发者可加入开源交流群参与更多讨论,企业用户可联系[产品客服](https://cdn.link-ai.tech/portal/linkai-customer-service.png)咨询。
|
||||
*The Web console is the default channel and the unified entry point to configure models, channels, skills, memory, and more.*
|
||||
|
||||
# 🌟 贡献者
|
||||
<br/>
|
||||
|
||||

|
||||
## 🧠 Memory & Knowledge Base
|
||||
|
||||
**Long-term memory** uses a three-tier architecture: conversation context (short-term) → daily memory (mid-term) → MEMORY.md (long-term). A nightly **Deep Dream** pass distills scattered memories into refined long-term entries and a narrative journal. See [Long-term Memory](https://docs.cowagent.ai/en/memory/index) · [Deep Dream](https://docs.cowagent.ai/en/memory/deep-dream).
|
||||
|
||||
**Personal knowledge base** complements the time-ordered memory by organizing structured knowledge **by topic**. The Agent automatically curates valuable information from conversations, maintains cross-references and indexes, and the Web console offers an interactive knowledge-graph view. See [Personal Knowledge Base](https://docs.cowagent.ai/en/knowledge/index).
|
||||
|
||||
<table>
|
||||
<tr>
|
||||
<td width="50%">
|
||||
<img src="https://cdn.jsdelivr.net/gh/zhayujie/cowagent-assets@main/screenshots/en/web-console-memory.png" alt="Long-term Memory" />
|
||||
<p align="center"><em>Long-term Memory · Three-tier architecture + Deep Dream</em></p>
|
||||
</td>
|
||||
<td width="50%">
|
||||
<img src="https://cdn.jsdelivr.net/gh/zhayujie/cowagent-assets@main/screenshots/en/web-console-knowledge.png" alt="Personal Knowledge Base" />
|
||||
<p align="center"><em>Knowledge Base · Auto-curated Markdown wiki</em></p>
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
<br/>
|
||||
|
||||
## 🔧 Tools & Skills
|
||||
|
||||
**Tools** are atomic capabilities the Agent uses to interact with system resources. **Skills** are higher-level workflows defined by a manifest file that compose multiple tools to accomplish complex tasks.
|
||||
|
||||
### Tool System
|
||||
|
||||
**Built-in tools** cover file I/O (`read` / `write` / `edit` / `ls`), terminal (`bash`), file sending (`send`), memory retrieval (`memory`), environment variables (`env_config`), web fetching (`web_fetch`), scheduling (`scheduler`), web search (`web_search`), vision (`vision`), and browser automation (`browser`).
|
||||
|
||||
**MCP protocol** integrates the open ecosystem of [Model Context Protocol](https://modelcontextprotocol.io) servers. A single `mcp.json` is enough — supports stdio / SSE transports, hot reload, and zero-code integration.
|
||||
|
||||
Learn more: [Tools overview](https://docs.cowagent.ai/en/tools/index) · [MCP integration](https://docs.cowagent.ai/en/tools/mcp).
|
||||
|
||||
### Skills System
|
||||
|
||||
- **[Skill Hub](https://skills.cowagent.ai/)** — open skill marketplace: browse, search, install in one click
|
||||
- **GitHub / ClawHub / URL and more** — install skills from any source
|
||||
- **Conversational authoring** — generate custom skills through dialogue with `skill-creator`; turn any workflow or third-party API into a reusable skill
|
||||
|
||||
```bash
|
||||
/skill list # list installed skills
|
||||
/skill search <keyword> # search the marketplace
|
||||
/skill install <name> # one-click install
|
||||
```
|
||||
|
||||
Learn more: [Skills overview](https://docs.cowagent.ai/en/skills/index) · [Creating Skills](https://docs.cowagent.ai/en/skills/create).
|
||||
|
||||
<br/>
|
||||
|
||||
## 🏷 Changelog
|
||||
|
||||
> **2026.05.22:** [v2.0.9](https://github.com/zhayujie/CowAgent/releases/tag/2.0.9) — Model management, MCP protocol support, persistent browser sessions, new models (gpt-5.5, gemini-3.5-flash, qwen3.7-max), deployment hardening.
|
||||
|
||||
> **2026.05.06:** [v2.0.8](https://github.com/zhayujie/CowAgent/releases/tag/2.0.8) — Feishu channel overhaul (voice, streaming, QR onboarding), DeepSeek V4 and Baidu Qianfan support, scheduler tool upgrades.
|
||||
|
||||
> **2026.04.22:** [v2.0.7](https://github.com/zhayujie/CowAgent/releases/tag/2.0.7) — Built-in image generation (GPT Image 2, Nano Banana), new models (Kimi K2.6, Claude Opus 4.7, GLM 5.1), memory and knowledge enhancements.
|
||||
|
||||
> **2026.04.14:** [v2.0.6](https://github.com/zhayujie/CowAgent/releases/tag/2.0.6) — Knowledge base, Deep Dream memory distillation, smart context compression, multi-session Web console.
|
||||
|
||||
> **2026.04.01:** [v2.0.5](https://github.com/zhayujie/CowAgent/releases/tag/2.0.5) — Cow CLI, Skill Hub open source, browser tool, WeCom Bot QR onboarding.
|
||||
|
||||
> **2026.02.03:** [v2.0.0](https://github.com/zhayujie/CowAgent/releases/tag/2.0.0) — Major upgrade to a super Agent assistant with multi-step task planning, long-term memory, and the Skills framework.
|
||||
|
||||
Full history: [Release Notes](https://docs.cowagent.ai/en/releases/overview)
|
||||
|
||||
<br/>
|
||||
|
||||
## 🤝 Community & Support
|
||||
|
||||
[File an issue](https://github.com/zhayujie/CowAgent/issues) on GitHub, or scan the QR code below to join our WeChat community:
|
||||
|
||||
<img width="130" src="https://img-1317903499.cos.ap-guangzhou.myqcloud.com/docs/open-community.png">
|
||||
|
||||
<br/>
|
||||
|
||||
## 🔗 Related Projects
|
||||
|
||||
- **[Cow Skill Hub](https://github.com/zhayujie/cow-skill-hub)** — open skill marketplace for AI Agents; works with CowAgent, OpenClaw, Claude Code, and more
|
||||
- **[bot-on-anything](https://github.com/zhayujie/bot-on-anything)** — lightweight LLM application framework with integrations for Slack, Telegram, Discord, Gmail, and more
|
||||
- **[AgentMesh](https://github.com/MinimalFuture/AgentMesh)** — open-source multi-agent framework for solving complex problems through team collaboration
|
||||
|
||||
<br/>
|
||||
|
||||
## 🏢 Enterprise Services
|
||||
|
||||
[**LinkAI**](https://link-ai.tech/) is an all-in-one AI Agent platform for enterprises and developers, offering managed hosting and enterprise-grade support for CowAgent:
|
||||
|
||||
- **🚀 Zero-deployment hosted runtime** — spin up a [CowAgent online assistant](https://link-ai.tech/cowagent/create) in under a minute, no server required
|
||||
- **🧠 Agent infrastructure** — unified access to LLMs, knowledge bases, databases, skills, and workflows; plug-and-play building blocks that extend what CowAgent can do
|
||||
- **🏢 Team & enterprise features** — workspaces, role-based access, audit logs, and private deployment for production use cases
|
||||
|
||||
For enterprise inquiries: sales@simple-future.tech or [scan the QR code](https://cdn.link-ai.tech/consultant.jpg) to reach our team on WeChat.
|
||||
|
||||
<br/>
|
||||
|
||||
## 🛠️ Development & Contributing
|
||||
|
||||
Contributions are welcome — add a new channel by following the [Feishu channel reference](https://github.com/zhayujie/CowAgent/blob/master/channel/feishu/feishu_channel.py), or contribute new skills to [Skill Hub](https://skills.cowagent.ai/submit).
|
||||
|
||||
⭐ Star the project to follow updates, and feel free to open PRs and Issues.
|
||||
|
||||
## 🌟 Contributors
|
||||
|
||||

|
||||
|
||||
<br/>
|
||||
|
||||
## ⚠️ Disclaimer
|
||||
|
||||
1. This project is licensed under the [MIT License](/LICENSE) and is intended for technical research and learning. You are responsible for complying with applicable laws and regulations in your jurisdiction; the maintainers assume no liability for any consequences arising from use of this project.
|
||||
2. **Cost & safety:** Agent mode consumes substantially more tokens than regular chat — pick models that balance quality and cost. The Agent has access to your local operating system, so only deploy it in trusted environments.
|
||||
3. CowAgent is a pure open-source project and does not participate in, authorize, or issue any cryptocurrency.
|
||||
|
||||
<br/>
|
||||
|
||||
## 📌 Project Renaming Notice
|
||||
|
||||
This project was previously named `chatgpt-on-wechat` and is now officially **CowAgent**. The old GitHub URL redirects automatically; existing users may optionally run `git remote set-url origin https://github.com/zhayujie/CowAgent.git` to update the local remote.
|
||||
|
||||
@@ -27,7 +27,8 @@ class ChatService:
|
||||
"""
|
||||
self.agent_bridge = agent_bridge
|
||||
|
||||
def run(self, query: str, session_id: str, send_chunk_fn: Callable[[dict], None]):
|
||||
def run(self, query: str, session_id: str, send_chunk_fn: Callable[[dict], None],
|
||||
channel_type: str = ""):
|
||||
"""
|
||||
Run the agent for *query* and stream results back via *send_chunk_fn*.
|
||||
|
||||
@@ -37,11 +38,17 @@ class ChatService:
|
||||
:param query: user query text
|
||||
:param session_id: session identifier for agent isolation
|
||||
:param send_chunk_fn: callable(chunk_data: dict) to send a streaming chunk
|
||||
:param channel_type: source channel (e.g. "web", "feishu") for persistence
|
||||
"""
|
||||
agent = self.agent_bridge.get_agent(session_id=session_id)
|
||||
if agent is None:
|
||||
raise RuntimeError("Failed to initialise agent for the session")
|
||||
|
||||
# Pass context metadata to model for downstream API requests
|
||||
if hasattr(agent, 'model'):
|
||||
agent.model.channel_type = channel_type or ""
|
||||
agent.model.session_id = session_id or ""
|
||||
|
||||
# State shared between the event callback and this method
|
||||
state = _StreamState()
|
||||
|
||||
@@ -50,7 +57,16 @@ class ChatService:
|
||||
event_type = event.get("type")
|
||||
data = event.get("data", {})
|
||||
|
||||
if event_type == "message_update":
|
||||
if event_type == "reasoning_update":
|
||||
delta = data.get("delta", "")
|
||||
if delta:
|
||||
send_chunk_fn({
|
||||
"chunk_type": "reasoning",
|
||||
"delta": delta,
|
||||
"segment_id": state.segment_id,
|
||||
})
|
||||
|
||||
elif event_type == "message_update":
|
||||
# Incremental text delta
|
||||
delta = data.get("delta", "")
|
||||
if delta:
|
||||
@@ -68,9 +84,41 @@ class ChatService:
|
||||
# a new segment; collect tool results until turn_end.
|
||||
state.pending_tool_results = []
|
||||
|
||||
elif event_type == "tool_execution_end":
|
||||
elif event_type == "file_to_send":
|
||||
url = data.get("url") or ""
|
||||
if url:
|
||||
fname = data.get("file_name") or "file"
|
||||
ft = data.get("file_type") or "file"
|
||||
if ft == "image":
|
||||
link = f""
|
||||
else:
|
||||
link = f"[{fname}]({url})"
|
||||
send_chunk_fn({
|
||||
"chunk_type": "content",
|
||||
"delta": "\n\n" + link + "\n\n",
|
||||
"segment_id": state.segment_id,
|
||||
})
|
||||
# Remove url so the model won't repeat it in its reply
|
||||
data.pop("url", None)
|
||||
|
||||
elif event_type == "tool_execution_start":
|
||||
# Notify the client that a tool is about to run (with its input args)
|
||||
tool_name = data.get("tool_name", "")
|
||||
arguments = data.get("arguments", {})
|
||||
# Cache arguments keyed by tool_call_id so tool_execution_end can include them
|
||||
tool_call_id = data.get("tool_call_id", tool_name)
|
||||
state.pending_tool_arguments[tool_call_id] = arguments
|
||||
send_chunk_fn({
|
||||
"chunk_type": "tool_start",
|
||||
"tool": tool_name,
|
||||
"arguments": arguments,
|
||||
})
|
||||
|
||||
elif event_type == "tool_execution_end":
|
||||
tool_name = data.get("tool_name", "")
|
||||
tool_call_id = data.get("tool_call_id", tool_name)
|
||||
# Retrieve cached arguments from the matching tool_execution_start event
|
||||
arguments = state.pending_tool_arguments.pop(tool_call_id, data.get("arguments", {}))
|
||||
result = data.get("result", "")
|
||||
status = data.get("status", "unknown")
|
||||
execution_time = data.get("execution_time", 0)
|
||||
@@ -111,7 +159,7 @@ class ChatService:
|
||||
logger.info(f"[ChatService] Starting agent run: session={session_id}, query={query[:80]}")
|
||||
|
||||
from config import conf
|
||||
max_context_turns = conf().get("agent_max_context_turns", 30)
|
||||
max_context_turns = conf().get("agent_max_context_turns", 20)
|
||||
|
||||
# Get full system prompt with skills
|
||||
full_system_prompt = agent.get_full_system_prompt()
|
||||
@@ -144,10 +192,61 @@ class ChatService:
|
||||
logger.info("[ChatService] Cleared agent message history after executor recovery")
|
||||
raise
|
||||
|
||||
# Append only the NEW messages from this execution (thread-safe)
|
||||
# Sync executor messages back to agent (thread-safe).
|
||||
# The executor may have trimmed context, making its list shorter than
|
||||
# original_length. In that case we must replace entirely — just
|
||||
# appending would leave stale pre-trim messages in agent.messages
|
||||
# and cause the same trim to fire on every subsequent request.
|
||||
with agent.messages_lock:
|
||||
new_messages = executor.messages[original_length:]
|
||||
agent.messages.extend(new_messages)
|
||||
trimmed = len(executor.messages) < original_length
|
||||
if trimmed:
|
||||
# Context was trimmed: the executor appended the new user
|
||||
# query *before* trimming, so the new messages (user +
|
||||
# assistant + tools) sit at the tail of the trimmed list.
|
||||
# We cannot simply slice at original_length (it exceeds the
|
||||
# list length). Instead, count how many messages the
|
||||
# executor added on top of the post-trim baseline.
|
||||
#
|
||||
# Timeline inside executor.run_stream:
|
||||
# 1. messages had `original_length` items
|
||||
# 2. append user query → original_length + 1
|
||||
# 3. _trim_messages() → some smaller number (includes the
|
||||
# user query because it belongs to the last turn)
|
||||
# 4. LLM replies / tool calls appended
|
||||
#
|
||||
# The user query message is always the first message of the
|
||||
# last turn (it cannot be trimmed away), so we locate it to
|
||||
# find where "new" messages begin.
|
||||
new_start = original_length # fallback
|
||||
for idx in range(len(executor.messages) - 1, -1, -1):
|
||||
msg = executor.messages[idx]
|
||||
if msg.get("role") == "user":
|
||||
content = msg.get("content", [])
|
||||
is_user_query = False
|
||||
if isinstance(content, list):
|
||||
has_text = any(
|
||||
isinstance(b, dict) and b.get("type") == "text"
|
||||
for b in content
|
||||
)
|
||||
has_tool_result = any(
|
||||
isinstance(b, dict) and b.get("type") == "tool_result"
|
||||
for b in content
|
||||
)
|
||||
is_user_query = has_text and not has_tool_result
|
||||
elif isinstance(content, str):
|
||||
is_user_query = True
|
||||
if is_user_query:
|
||||
new_start = idx
|
||||
break
|
||||
new_messages = list(executor.messages[new_start:])
|
||||
else:
|
||||
new_messages = list(executor.messages[original_length:])
|
||||
agent.messages = list(executor.messages)
|
||||
|
||||
# Persist new messages to SQLite so they survive restarts and
|
||||
# can be queried via the HISTORY interface.
|
||||
if new_messages:
|
||||
self._persist_messages(session_id, list(new_messages), channel_type)
|
||||
|
||||
# Store executor reference for files_to_send access
|
||||
agent.stream_executor = executor
|
||||
@@ -158,6 +257,26 @@ class ChatService:
|
||||
logger.info(f"[ChatService] Agent run completed: session={session_id}")
|
||||
|
||||
|
||||
|
||||
@staticmethod
|
||||
def _persist_messages(session_id: str, new_messages: list, channel_type: str = ""):
|
||||
try:
|
||||
from config import conf
|
||||
if not conf().get("conversation_persistence", True):
|
||||
return
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
from agent.memory import get_conversation_store
|
||||
get_conversation_store().append_messages(
|
||||
session_id, new_messages, channel_type=channel_type
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"[ChatService] Failed to persist messages for session={session_id}: {e}"
|
||||
)
|
||||
|
||||
|
||||
class _StreamState:
|
||||
"""Mutable state shared between the event callback and the run method."""
|
||||
|
||||
@@ -166,3 +285,6 @@ class _StreamState:
|
||||
# None means we are not accumulating tool results right now.
|
||||
# A list means we are in the middle of a tool-execution phase.
|
||||
self.pending_tool_results: Optional[list] = None
|
||||
# Maps tool_call_id -> arguments captured from tool_execution_start,
|
||||
# so that tool_execution_end can attach the correct input args.
|
||||
self.pending_tool_arguments: dict = {}
|
||||
|
||||
241
agent/chat/session_service.py
Normal file
@@ -0,0 +1,241 @@
|
||||
"""
|
||||
SessionService - Manages multi-session lifecycle for both web channel and cloud client.
|
||||
|
||||
Provides a unified interface for listing, deleting, renaming, clearing context,
|
||||
and generating AI titles for conversation sessions. Backed by ConversationStore
|
||||
(SQLite) and AgentBridge (in-memory agent instances).
|
||||
"""
|
||||
|
||||
import re
|
||||
from typing import Optional
|
||||
|
||||
from common.log import logger
|
||||
|
||||
|
||||
def _truncate_fallback_title(user_message: str, max_len: int = 30) -> str:
|
||||
"""Pick the first non-empty line of the user message and truncate it."""
|
||||
if not user_message:
|
||||
return "New Chat"
|
||||
first_line = ""
|
||||
for line in user_message.splitlines():
|
||||
line = line.strip()
|
||||
if line:
|
||||
first_line = line
|
||||
break
|
||||
if not first_line:
|
||||
return "New Chat"
|
||||
if len(first_line) > max_len:
|
||||
first_line = first_line[:max_len].rstrip() + "..."
|
||||
return first_line
|
||||
|
||||
|
||||
def generate_session_title(user_message: str, assistant_reply: str = "") -> str:
|
||||
"""
|
||||
Generate a short session title by calling the current bot's reply_text.
|
||||
Falls back to the first line of the user message if the LLM call fails
|
||||
or returns an obvious error sentinel.
|
||||
"""
|
||||
fallback = _truncate_fallback_title(user_message)
|
||||
try:
|
||||
from bridge.bridge import Bridge
|
||||
from models.session_manager import Session
|
||||
bot = Bridge().get_bot("chat")
|
||||
|
||||
prompt_parts = [f"User: {user_message[:300]}"]
|
||||
if assistant_reply:
|
||||
prompt_parts.append(f"Assistant: {assistant_reply[:300]}")
|
||||
|
||||
session = Session("__title_gen__", system_prompt="")
|
||||
session.messages = [
|
||||
{"role": "user", "content": (
|
||||
"Generate a very short title (max 15 characters for Chinese, max 6 words for English) "
|
||||
"summarizing this conversation. Return ONLY the title text, nothing else.\n\n"
|
||||
+ "\n".join(prompt_parts)
|
||||
)}
|
||||
]
|
||||
|
||||
result = bot.reply_text(session) or {}
|
||||
# When bots fail (network error, auth error, rate limit, etc.) they
|
||||
# typically return completion_tokens=0 with a sentinel content like
|
||||
# "请再问我一次吧" / "我现在有点累了". Treat that as failure.
|
||||
completion_tokens = result.get("completion_tokens", 0) or 0
|
||||
raw = (result.get("content") or "").strip()
|
||||
if completion_tokens <= 0:
|
||||
logger.warning(
|
||||
f"[SessionService] Title generation got empty completion "
|
||||
f"(completion_tokens={completion_tokens}, content='{raw[:50]}'), "
|
||||
f"using fallback")
|
||||
return fallback
|
||||
|
||||
title = re.sub(r'<think>.*?</think>', '', raw, flags=re.DOTALL).strip().strip('"\'')
|
||||
logger.info(f"[SessionService] Title generation result: '{title}' (len={len(title)})")
|
||||
if title and len(title) <= 50:
|
||||
return title
|
||||
except Exception as e:
|
||||
logger.warning(f"[SessionService] Title generation failed: {e}")
|
||||
return fallback
|
||||
|
||||
|
||||
class SessionService:
|
||||
"""
|
||||
High-level service for session lifecycle management.
|
||||
|
||||
Usage:
|
||||
svc = SessionService()
|
||||
result = svc.dispatch("list", {"channel_type": "web", "page": 1})
|
||||
"""
|
||||
|
||||
def _get_store(self):
|
||||
from agent.memory import get_conversation_store
|
||||
return get_conversation_store()
|
||||
|
||||
def _remove_agent(self, session_id: str):
|
||||
"""Remove the in-memory Agent instance for a session if it exists."""
|
||||
try:
|
||||
from bridge.bridge import Bridge
|
||||
ab = Bridge().get_agent_bridge()
|
||||
if session_id in ab.agents:
|
||||
del ab.agents[session_id]
|
||||
logger.info(f"[SessionService] Removed agent instance: {session_id}")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def _normalize_sid(session_id: str) -> str:
|
||||
if session_id and not session_id.startswith("session_"):
|
||||
return f"session_{session_id}"
|
||||
return session_id
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# actions
|
||||
# ------------------------------------------------------------------
|
||||
def list_sessions(self, channel_type: Optional[str] = None,
|
||||
page: int = 1, page_size: int = 50) -> dict:
|
||||
store = self._get_store()
|
||||
return store.list_sessions(
|
||||
channel_type=channel_type,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
|
||||
def delete_session(self, session_id: str) -> None:
|
||||
if not session_id:
|
||||
raise ValueError("session_id required")
|
||||
session_id = self._normalize_sid(session_id)
|
||||
|
||||
store = self._get_store()
|
||||
store.clear_session(session_id)
|
||||
self._remove_agent(session_id)
|
||||
logger.info(f"[SessionService] Session deleted: {session_id}")
|
||||
|
||||
def rename_session(self, session_id: str, title: str) -> None:
|
||||
if not session_id:
|
||||
raise ValueError("session_id required")
|
||||
if not title:
|
||||
raise ValueError("title required")
|
||||
session_id = self._normalize_sid(session_id)
|
||||
|
||||
store = self._get_store()
|
||||
found = store.rename_session(session_id, title)
|
||||
if not found:
|
||||
raise ValueError("session not found")
|
||||
|
||||
def clear_context(self, session_id: str) -> int:
|
||||
"""
|
||||
Set context boundary. Returns the new context_start_seq value.
|
||||
"""
|
||||
if not session_id:
|
||||
raise ValueError("session_id required")
|
||||
session_id = self._normalize_sid(session_id)
|
||||
|
||||
store = self._get_store()
|
||||
new_seq = store.clear_context(session_id)
|
||||
self._remove_agent(session_id)
|
||||
return new_seq
|
||||
|
||||
def gen_title(self, session_id: str, user_message: str,
|
||||
assistant_reply: str = "") -> str:
|
||||
"""
|
||||
Generate an AI title and persist it. Returns the generated title.
|
||||
"""
|
||||
if not session_id:
|
||||
raise ValueError("session_id required")
|
||||
if not user_message:
|
||||
raise ValueError("user_message required")
|
||||
session_id = self._normalize_sid(session_id)
|
||||
|
||||
title = generate_session_title(user_message, assistant_reply)
|
||||
|
||||
store = self._get_store()
|
||||
updated = store.rename_session(session_id, title)
|
||||
logger.info(f"[SessionService] Title set: sid={session_id}, "
|
||||
f"title='{title}', db_updated={updated}")
|
||||
return title
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# dispatch — single entry point for protocol messages
|
||||
# ------------------------------------------------------------------
|
||||
def dispatch(self, action: str, payload: Optional[dict] = None) -> dict:
|
||||
"""
|
||||
Dispatch a session management action and return a protocol-compatible
|
||||
response dict.
|
||||
|
||||
Action names use a ``*_session`` / session-prefixed convention so they
|
||||
can coexist with history actions (e.g. ``query``) on the same HISTORY
|
||||
message channel without ambiguity.
|
||||
|
||||
Supported actions:
|
||||
- list_sessions: list sessions with pagination
|
||||
- delete_session: delete a session
|
||||
- rename_session: rename a session title
|
||||
- clear_context: set context boundary
|
||||
- generate_title: AI-generate a session title
|
||||
|
||||
:param action: one of the above action names
|
||||
:param payload: action-specific payload
|
||||
:return: dict with action, code, message, payload
|
||||
"""
|
||||
payload = payload or {}
|
||||
try:
|
||||
if action == "list_sessions":
|
||||
result = self.list_sessions(
|
||||
channel_type=payload.get("channel_type"),
|
||||
page=int(payload.get("page", 1)),
|
||||
page_size=int(payload.get("page_size", 50)),
|
||||
)
|
||||
return {"action": action, "code": 200, "message": "success", "payload": result}
|
||||
|
||||
elif action == "delete_session":
|
||||
self.delete_session(payload.get("session_id", ""))
|
||||
return {"action": action, "code": 200, "message": "success", "payload": None}
|
||||
|
||||
elif action == "rename_session":
|
||||
self.rename_session(
|
||||
payload.get("session_id", ""),
|
||||
payload.get("title", "").strip(),
|
||||
)
|
||||
return {"action": action, "code": 200, "message": "success", "payload": None}
|
||||
|
||||
elif action == "clear_context":
|
||||
new_seq = self.clear_context(payload.get("session_id", ""))
|
||||
return {"action": action, "code": 200, "message": "success",
|
||||
"payload": {"context_start_seq": new_seq}}
|
||||
|
||||
elif action == "generate_title":
|
||||
title = self.gen_title(
|
||||
payload.get("session_id", ""),
|
||||
payload.get("user_message", ""),
|
||||
payload.get("assistant_reply", ""),
|
||||
)
|
||||
return {"action": action, "code": 200, "message": "success",
|
||||
"payload": {"title": title}}
|
||||
|
||||
else:
|
||||
return {"action": action, "code": 400,
|
||||
"message": f"unknown action: {action}", "payload": None}
|
||||
|
||||
except ValueError as e:
|
||||
return {"action": action, "code": 400, "message": str(e), "payload": None}
|
||||
except Exception as e:
|
||||
logger.error(f"[SessionService] dispatch error: action={action}, error={e}")
|
||||
return {"action": action, "code": 500, "message": str(e), "payload": None}
|
||||
0
agent/knowledge/__init__.py
Normal file
240
agent/knowledge/service.py
Normal file
@@ -0,0 +1,240 @@
|
||||
"""
|
||||
Knowledge service for handling knowledge base operations.
|
||||
|
||||
Provides a unified interface for listing, reading, and graphing knowledge files,
|
||||
callable from the web console, API, or CLI.
|
||||
|
||||
Knowledge file layout (under workspace_root):
|
||||
knowledge/index.md
|
||||
knowledge/log.md
|
||||
knowledge/<category>/<slug>.md
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from common.log import logger
|
||||
from config import conf
|
||||
|
||||
|
||||
class KnowledgeService:
|
||||
"""
|
||||
High-level service for knowledge base queries.
|
||||
Operates directly on the filesystem.
|
||||
"""
|
||||
|
||||
def __init__(self, workspace_root: str):
|
||||
self.workspace_root = workspace_root
|
||||
self.knowledge_dir = os.path.join(workspace_root, "knowledge")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# list — directory tree with stats
|
||||
# ------------------------------------------------------------------
|
||||
def list_tree(self) -> dict:
|
||||
"""
|
||||
Return the knowledge directory tree grouped by category,
|
||||
supporting arbitrarily nested sub-directories.
|
||||
|
||||
Returns::
|
||||
|
||||
{
|
||||
"tree": [
|
||||
{
|
||||
"dir": "concepts",
|
||||
"files": [
|
||||
{"name": "moe.md", "title": "MoE", "size": 1234},
|
||||
],
|
||||
"children": []
|
||||
},
|
||||
{
|
||||
"dir": "platform",
|
||||
"files": [],
|
||||
"children": [
|
||||
{
|
||||
"dir": "analysis",
|
||||
"files": [{"name": "perf.md", ...}],
|
||||
"children": []
|
||||
}
|
||||
]
|
||||
},
|
||||
],
|
||||
"stats": {"pages": 15, "size": 32768},
|
||||
"enabled": true
|
||||
}
|
||||
"""
|
||||
if not os.path.isdir(self.knowledge_dir):
|
||||
return {"tree": [], "stats": {"pages": 0, "size": 0}, "enabled": conf().get("knowledge", True)}
|
||||
|
||||
stats = {"pages": 0, "size": 0}
|
||||
root_files, tree = self._scan_dir(self.knowledge_dir, stats, is_root=True)
|
||||
|
||||
return {
|
||||
"root_files": root_files,
|
||||
"tree": tree,
|
||||
"stats": stats,
|
||||
"enabled": conf().get("knowledge", True),
|
||||
}
|
||||
|
||||
def _scan_dir(self, dir_path: str, stats: dict, is_root: bool = False) -> tuple:
|
||||
"""
|
||||
Recursively scan a directory.
|
||||
|
||||
:return: (files, children) where files is a list of .md file dicts
|
||||
in this directory and children is a list of sub-directory nodes.
|
||||
"""
|
||||
files = []
|
||||
children = []
|
||||
for name in sorted(os.listdir(dir_path)):
|
||||
if name.startswith("."):
|
||||
continue
|
||||
full = os.path.join(dir_path, name)
|
||||
if os.path.isdir(full):
|
||||
sub_files, sub_children = self._scan_dir(full, stats)
|
||||
children.append({"dir": name, "files": sub_files, "children": sub_children})
|
||||
elif name.endswith(".md"):
|
||||
size = os.path.getsize(full)
|
||||
if not is_root:
|
||||
stats["pages"] += 1
|
||||
stats["size"] += size
|
||||
title = name.replace(".md", "")
|
||||
try:
|
||||
with open(full, "r", encoding="utf-8") as f:
|
||||
first_line = f.readline().strip()
|
||||
if first_line.startswith("# "):
|
||||
title = first_line[2:].strip()
|
||||
except Exception:
|
||||
pass
|
||||
files.append({"name": name, "title": title, "size": size})
|
||||
return files, children
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# read — single file content
|
||||
# ------------------------------------------------------------------
|
||||
def read_file(self, rel_path: str) -> dict:
|
||||
"""
|
||||
Read a single knowledge markdown file.
|
||||
|
||||
:param rel_path: Relative path within knowledge/, e.g. ``concepts/moe.md``
|
||||
:return: dict with ``content`` and ``path``
|
||||
:raises ValueError: if path is invalid or escapes knowledge dir
|
||||
:raises FileNotFoundError: if file does not exist
|
||||
"""
|
||||
if not rel_path or ".." in rel_path:
|
||||
raise ValueError("invalid path")
|
||||
|
||||
full_path = os.path.normpath(os.path.join(self.knowledge_dir, rel_path))
|
||||
allowed = os.path.normpath(self.knowledge_dir)
|
||||
if not full_path.startswith(allowed + os.sep) and full_path != allowed:
|
||||
raise ValueError("path outside knowledge dir")
|
||||
|
||||
if not os.path.isfile(full_path):
|
||||
raise FileNotFoundError(f"file not found: {rel_path}")
|
||||
|
||||
with open(full_path, "r", encoding="utf-8") as f:
|
||||
content = f.read()
|
||||
return {"content": content, "path": rel_path}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# graph — nodes and links for visualization
|
||||
# ------------------------------------------------------------------
|
||||
def build_graph(self) -> dict:
|
||||
"""
|
||||
Parse all knowledge pages and extract cross-reference links.
|
||||
|
||||
Returns::
|
||||
|
||||
{
|
||||
"nodes": [
|
||||
{"id": "concepts/moe.md", "label": "MoE", "category": "concepts"},
|
||||
...
|
||||
],
|
||||
"links": [
|
||||
{"source": "concepts/moe.md", "target": "entities/deepseek.md"},
|
||||
...
|
||||
]
|
||||
}
|
||||
"""
|
||||
knowledge_path = Path(self.knowledge_dir)
|
||||
if not knowledge_path.is_dir():
|
||||
return {"nodes": [], "links": []}
|
||||
|
||||
nodes = {}
|
||||
links = []
|
||||
link_re = re.compile(r'\[([^\]]*)\]\(([^)]+\.md)\)')
|
||||
|
||||
for md_file in knowledge_path.rglob("*.md"):
|
||||
rel = str(md_file.relative_to(knowledge_path))
|
||||
if rel in ("index.md", "log.md"):
|
||||
continue
|
||||
parts = rel.split("/")
|
||||
category = parts[0] if len(parts) > 1 else "root"
|
||||
title = md_file.stem.replace("-", " ").title()
|
||||
try:
|
||||
content = md_file.read_text(encoding="utf-8")
|
||||
first_line = content.strip().split("\n")[0]
|
||||
if first_line.startswith("# "):
|
||||
title = first_line[2:].strip()
|
||||
for _, link_target in link_re.findall(content):
|
||||
resolved = (md_file.parent / link_target).resolve()
|
||||
try:
|
||||
target_rel = str(resolved.relative_to(knowledge_path))
|
||||
except ValueError:
|
||||
continue
|
||||
if target_rel != rel:
|
||||
links.append({"source": rel, "target": target_rel})
|
||||
except Exception:
|
||||
pass
|
||||
nodes[rel] = {"id": rel, "label": title, "category": category}
|
||||
|
||||
valid_ids = set(nodes.keys())
|
||||
links = [l for l in links if l["source"] in valid_ids and l["target"] in valid_ids]
|
||||
seen = set()
|
||||
deduped = []
|
||||
for l in links:
|
||||
key = tuple(sorted([l["source"], l["target"]]))
|
||||
if key not in seen:
|
||||
seen.add(key)
|
||||
deduped.append(l)
|
||||
|
||||
return {"nodes": list(nodes.values()), "links": deduped}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# dispatch — single entry point for protocol messages
|
||||
# ------------------------------------------------------------------
|
||||
def dispatch(self, action: str, payload: Optional[dict] = None) -> dict:
|
||||
"""
|
||||
Dispatch a knowledge management action.
|
||||
|
||||
:param action: ``list``, ``read``, or ``graph``
|
||||
:param payload: action-specific payload
|
||||
:return: protocol-compatible response dict
|
||||
"""
|
||||
payload = payload or {}
|
||||
try:
|
||||
if action == "list":
|
||||
result = self.list_tree()
|
||||
return {"action": action, "code": 200, "message": "success", "payload": result}
|
||||
|
||||
elif action == "read":
|
||||
path = payload.get("path")
|
||||
if not path:
|
||||
return {"action": action, "code": 400, "message": "path is required", "payload": None}
|
||||
result = self.read_file(path)
|
||||
return {"action": action, "code": 200, "message": "success", "payload": result}
|
||||
|
||||
elif action == "graph":
|
||||
result = self.build_graph()
|
||||
return {"action": action, "code": 200, "message": "success", "payload": result}
|
||||
|
||||
else:
|
||||
return {"action": action, "code": 400, "message": f"unknown action: {action}", "payload": None}
|
||||
|
||||
except ValueError as e:
|
||||
return {"action": action, "code": 403, "message": str(e), "payload": None}
|
||||
except FileNotFoundError as e:
|
||||
return {"action": action, "code": 404, "message": str(e), "payload": None}
|
||||
except Exception as e:
|
||||
logger.error(f"[KnowledgeService] dispatch error: action={action}, error={e}")
|
||||
return {"action": action, "code": 500, "message": str(e), "payload": None}
|
||||
@@ -1,11 +1,23 @@
|
||||
"""
|
||||
Memory module for AgentMesh
|
||||
|
||||
Provides long-term memory capabilities with hybrid search (vector + keyword)
|
||||
Provides both long-term memory (vector/keyword search) and short-term
|
||||
conversation history persistence (SQLite).
|
||||
"""
|
||||
|
||||
from agent.memory.manager import MemoryManager
|
||||
from agent.memory.config import MemoryConfig, get_default_memory_config, set_global_memory_config
|
||||
from agent.memory.embedding import create_embedding_provider
|
||||
from agent.memory.conversation_store import ConversationStore, get_conversation_store
|
||||
from agent.memory.summarizer import ensure_daily_memory_file
|
||||
|
||||
__all__ = ['MemoryManager', 'MemoryConfig', 'get_default_memory_config', 'set_global_memory_config', 'create_embedding_provider']
|
||||
__all__ = [
|
||||
'MemoryManager',
|
||||
'MemoryConfig',
|
||||
'get_default_memory_config',
|
||||
'set_global_memory_config',
|
||||
'create_embedding_provider',
|
||||
'ConversationStore',
|
||||
'get_conversation_store',
|
||||
'ensure_daily_memory_file',
|
||||
]
|
||||
|
||||
@@ -48,9 +48,6 @@ class MemoryConfig:
|
||||
enable_auto_sync: bool = True
|
||||
sync_on_search: bool = True
|
||||
|
||||
# Memory flush config (独立于模型 context window)
|
||||
flush_token_threshold: int = 50000 # 50K tokens 触发 flush
|
||||
flush_turn_threshold: int = 20 # 20 轮对话触发 flush (用户+AI各一条为一轮)
|
||||
|
||||
def get_workspace(self) -> Path:
|
||||
"""Get workspace root directory"""
|
||||
|
||||
1055
agent/memory/conversation_store.py
Normal file
@@ -1,161 +0,0 @@
|
||||
"""
|
||||
Embedding providers for memory
|
||||
|
||||
Supports OpenAI and local embedding models
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Optional
|
||||
|
||||
|
||||
class EmbeddingProvider(ABC):
|
||||
"""Base class for embedding providers"""
|
||||
|
||||
@abstractmethod
|
||||
def embed(self, text: str) -> List[float]:
|
||||
"""Generate embedding for text"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def embed_batch(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Generate embeddings for multiple texts"""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def dimensions(self) -> int:
|
||||
"""Get embedding dimensions"""
|
||||
pass
|
||||
|
||||
|
||||
class OpenAIEmbeddingProvider(EmbeddingProvider):
|
||||
"""OpenAI embedding provider using REST API"""
|
||||
|
||||
def __init__(self, model: str = "text-embedding-3-small", api_key: Optional[str] = None, api_base: Optional[str] = None):
|
||||
"""
|
||||
Initialize OpenAI embedding provider
|
||||
|
||||
Args:
|
||||
model: Model name (text-embedding-3-small or text-embedding-3-large)
|
||||
api_key: OpenAI API key
|
||||
api_base: Optional API base URL
|
||||
"""
|
||||
self.model = model
|
||||
self.api_key = api_key
|
||||
self.api_base = api_base or "https://api.openai.com/v1"
|
||||
|
||||
# Validate API key
|
||||
if not self.api_key or self.api_key in ["", "YOUR API KEY", "YOUR_API_KEY"]:
|
||||
raise ValueError("OpenAI API key is not configured. Please set 'open_ai_api_key' in config.json")
|
||||
|
||||
# Set dimensions based on model
|
||||
self._dimensions = 1536 if "small" in model else 3072
|
||||
|
||||
def _call_api(self, input_data):
|
||||
"""Call OpenAI embedding API using requests"""
|
||||
import requests
|
||||
|
||||
url = f"{self.api_base}/embeddings"
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {self.api_key}"
|
||||
}
|
||||
data = {
|
||||
"input": input_data,
|
||||
"model": self.model
|
||||
}
|
||||
|
||||
try:
|
||||
response = requests.post(url, headers=headers, json=data, timeout=5)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
except requests.exceptions.ConnectionError as e:
|
||||
raise ConnectionError(f"Failed to connect to OpenAI API at {url}. Please check your network connection and api_base configuration. Error: {str(e)}")
|
||||
except requests.exceptions.Timeout as e:
|
||||
raise TimeoutError(f"OpenAI API request timed out after 10s. Please check your network connection. Error: {str(e)}")
|
||||
except requests.exceptions.HTTPError as e:
|
||||
if e.response.status_code == 401:
|
||||
raise ValueError(f"Invalid OpenAI API key. Please check your 'open_ai_api_key' in config.json")
|
||||
elif e.response.status_code == 429:
|
||||
raise ValueError(f"OpenAI API rate limit exceeded. Please try again later.")
|
||||
else:
|
||||
raise ValueError(f"OpenAI API request failed: {e.response.status_code} - {e.response.text}")
|
||||
|
||||
def embed(self, text: str) -> List[float]:
|
||||
"""Generate embedding for text"""
|
||||
result = self._call_api(text)
|
||||
return result["data"][0]["embedding"]
|
||||
|
||||
def embed_batch(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Generate embeddings for multiple texts"""
|
||||
if not texts:
|
||||
return []
|
||||
|
||||
result = self._call_api(texts)
|
||||
return [item["embedding"] for item in result["data"]]
|
||||
|
||||
@property
|
||||
def dimensions(self) -> int:
|
||||
return self._dimensions
|
||||
|
||||
|
||||
# LocalEmbeddingProvider removed - only use OpenAI embedding or keyword search
|
||||
|
||||
|
||||
class EmbeddingCache:
|
||||
"""Cache for embeddings to avoid recomputation"""
|
||||
|
||||
def __init__(self):
|
||||
self.cache = {}
|
||||
|
||||
def get(self, text: str, provider: str, model: str) -> Optional[List[float]]:
|
||||
"""Get cached embedding"""
|
||||
key = self._compute_key(text, provider, model)
|
||||
return self.cache.get(key)
|
||||
|
||||
def put(self, text: str, provider: str, model: str, embedding: List[float]):
|
||||
"""Cache embedding"""
|
||||
key = self._compute_key(text, provider, model)
|
||||
self.cache[key] = embedding
|
||||
|
||||
@staticmethod
|
||||
def _compute_key(text: str, provider: str, model: str) -> str:
|
||||
"""Compute cache key"""
|
||||
content = f"{provider}:{model}:{text}"
|
||||
return hashlib.md5(content.encode('utf-8')).hexdigest()
|
||||
|
||||
def clear(self):
|
||||
"""Clear cache"""
|
||||
self.cache.clear()
|
||||
|
||||
|
||||
def create_embedding_provider(
|
||||
provider: str = "openai",
|
||||
model: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None
|
||||
) -> EmbeddingProvider:
|
||||
"""
|
||||
Factory function to create embedding provider
|
||||
|
||||
Only supports OpenAI embedding via REST API.
|
||||
If initialization fails, caller should fall back to keyword-only search.
|
||||
|
||||
Args:
|
||||
provider: Provider name (only "openai" is supported)
|
||||
model: Model name (default: text-embedding-3-small)
|
||||
api_key: OpenAI API key (required)
|
||||
api_base: API base URL (default: https://api.openai.com/v1)
|
||||
|
||||
Returns:
|
||||
EmbeddingProvider instance
|
||||
|
||||
Raises:
|
||||
ValueError: If provider is not "openai" or api_key is missing
|
||||
"""
|
||||
if provider != "openai":
|
||||
raise ValueError(f"Only 'openai' provider is supported, got: {provider}")
|
||||
|
||||
model = model or "text-embedding-3-small"
|
||||
return OpenAIEmbeddingProvider(model=model, api_key=api_key, api_base=api_base)
|
||||
41
agent/memory/embedding/__init__.py
Normal file
@@ -0,0 +1,41 @@
|
||||
"""
|
||||
Embedding subsystem for memory.
|
||||
|
||||
Public API:
|
||||
create_embedding_provider, EmbeddingProvider, OpenAIEmbeddingProvider,
|
||||
EMBEDDING_VENDORS, EmbeddingCache
|
||||
RebuildResult, clear_index, rebuild_in_process
|
||||
detect_index_dim, cleanup_legacy_state_file
|
||||
"""
|
||||
|
||||
from agent.memory.embedding.provider import (
|
||||
EMBEDDING_VENDORS,
|
||||
DoubaoEmbeddingProvider,
|
||||
EmbeddingCache,
|
||||
EmbeddingProvider,
|
||||
OpenAIEmbeddingProvider,
|
||||
create_embedding_provider,
|
||||
)
|
||||
from agent.memory.embedding.rebuild import (
|
||||
RebuildResult,
|
||||
clear_index,
|
||||
rebuild_in_process,
|
||||
)
|
||||
from agent.memory.embedding.state import (
|
||||
cleanup_legacy_state_file,
|
||||
detect_index_dim,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"EMBEDDING_VENDORS",
|
||||
"DoubaoEmbeddingProvider",
|
||||
"EmbeddingCache",
|
||||
"EmbeddingProvider",
|
||||
"OpenAIEmbeddingProvider",
|
||||
"create_embedding_provider",
|
||||
"RebuildResult",
|
||||
"clear_index",
|
||||
"rebuild_in_process",
|
||||
"cleanup_legacy_state_file",
|
||||
"detect_index_dim",
|
||||
]
|
||||
486
agent/memory/embedding/provider.py
Normal file
@@ -0,0 +1,486 @@
|
||||
"""
|
||||
Embedding providers for memory
|
||||
|
||||
Supports multiple OpenAI-compatible embedding vendors:
|
||||
- openai (text-embedding-3-small / large)
|
||||
- linkai (OpenAI-compatible passthrough)
|
||||
- dashscope (Aliyun Tongyi text-embedding-v4)
|
||||
- doubao (ByteDance Doubao Seed1.5 / large-text on Volcengine Ark)
|
||||
- zhipu (ZhipuAI embedding-3)
|
||||
|
||||
Vendor keys here intentionally match the project's bot_type constants in
|
||||
common.const (OPENAI, LINKAI, QWEN_DASHSCOPE, DOUBAO, ZHIPU_AI).
|
||||
|
||||
All providers share a single OpenAI-compatible REST client. Vendor-specific
|
||||
behaviors (truncation, query instruction prefix) are configured via metadata.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import math
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Optional
|
||||
|
||||
# HTTP read timeout for a single embeddings request (seconds). A batch of
|
||||
# 64+ chunks can take 30-50s end-to-end from China-side networks, so 30s is
|
||||
# routinely too tight; 90s gives meaningful headroom without letting bad
|
||||
# endpoints hang forever.
|
||||
EMBEDDING_HTTP_TIMEOUT = 90
|
||||
|
||||
|
||||
class EmbeddingProvider(ABC):
|
||||
"""Base class for embedding providers"""
|
||||
|
||||
@abstractmethod
|
||||
def embed(self, text: str) -> List[float]:
|
||||
"""Generate embedding for a single text (treated as a query by default)"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def embed_batch(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Generate embeddings for multiple texts (treated as documents)"""
|
||||
pass
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
"""Generate embedding for a query string (may apply vendor instruction prefix)"""
|
||||
return self.embed(text)
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def dimensions(self) -> int:
|
||||
"""Effective embedding dimensions"""
|
||||
pass
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Vendor metadata table
|
||||
# ---------------------------------------------------------------------------
|
||||
#
|
||||
# Each entry describes how to reach a vendor's embedding endpoint. Most
|
||||
# vendors expose an OpenAI-compatible /embeddings API; the few that don't
|
||||
# (currently: doubao) set `provider_class` to pick a dedicated adapter.
|
||||
# Fields:
|
||||
# provider_class : optional adapter key ("doubao"); defaults to OpenAI-compat
|
||||
# default_base_url : default API base when not overridden by user
|
||||
# default_model : default embedding model name
|
||||
# default_dimensions : recommended unified dim when explicit path is enabled
|
||||
# supports_dim_param : whether the API accepts a `dimensions` request param
|
||||
# needs_client_truncate : whether to slice + L2-normalize on the client side
|
||||
# needs_client_normalize : whether to L2-normalize on the client (always safe)
|
||||
# query_instruction : optional prefix for asymmetric retrieval (Doubao Seed)
|
||||
# max_batch_size : max texts per /embeddings request; embed_batch
|
||||
# auto-paginates above this. Conservative defaults.
|
||||
#
|
||||
EMBEDDING_VENDORS = {
|
||||
"openai": {
|
||||
"default_base_url": "https://api.openai.com/v1",
|
||||
"default_model": "text-embedding-3-small",
|
||||
# Match the legacy default so users adding `embedding_provider: openai`
|
||||
# to an existing index don't need to rebuild. Override via
|
||||
# embedding_dimensions if you want 1024 / 1536 / 3072.
|
||||
"default_dimensions": 1536,
|
||||
"supports_dim_param": True,
|
||||
"needs_client_truncate": False,
|
||||
"needs_client_normalize": False,
|
||||
"query_instruction": "",
|
||||
# OpenAI permits up to 2048 items per request, but a single call
|
||||
# carrying hundreds of long chunks routinely exceeds the 30s read
|
||||
# timeout from China-side networks. 64 keeps each call well under
|
||||
# both the token-per-request budget and a reasonable wall clock.
|
||||
"max_batch_size": 64,
|
||||
},
|
||||
"linkai": {
|
||||
"default_base_url": "https://api.link-ai.tech/v1",
|
||||
"default_model": "text-embedding-3-small",
|
||||
"default_dimensions": 1536,
|
||||
"supports_dim_param": True,
|
||||
"needs_client_truncate": False,
|
||||
"needs_client_normalize": False,
|
||||
"query_instruction": "",
|
||||
"max_batch_size": 64,
|
||||
},
|
||||
"dashscope": {
|
||||
"default_base_url": "https://dashscope.aliyuncs.com/compatible-mode/v1",
|
||||
"default_model": "text-embedding-v4",
|
||||
"default_dimensions": 1024,
|
||||
"supports_dim_param": True,
|
||||
"needs_client_truncate": False,
|
||||
"needs_client_normalize": False,
|
||||
"query_instruction": "",
|
||||
"max_batch_size": 10, # DashScope hard cap (text-embedding-v4)
|
||||
},
|
||||
"doubao": {
|
||||
# Doubao no longer offers an OpenAI-compatible /v1/embeddings endpoint.
|
||||
# Current models are unified under /api/v3/embeddings/multimodal
|
||||
# which uses a structured `input` payload — see DoubaoEmbeddingProvider.
|
||||
"provider_class": "doubao",
|
||||
"default_base_url": "https://ark.cn-beijing.volces.com/api/v3",
|
||||
"default_model": "doubao-embedding-vision-251215",
|
||||
# Native options: 1024 or 2048. We default to 1024 to align with the
|
||||
# other Chinese vendors (dashscope/zhipu) and keep storage footprint
|
||||
# consistent across providers; users can still override via
|
||||
# `embedding_dimensions: 2048` in config.
|
||||
"default_dimensions": 1024,
|
||||
"supports_dim_param": True,
|
||||
"needs_client_truncate": False,
|
||||
"needs_client_normalize": False,
|
||||
"query_instruction": "",
|
||||
# Multimodal endpoint produces ONE embedding per call (input list is
|
||||
# a single document's parts, not a batch). embed_batch loops.
|
||||
"max_batch_size": 1,
|
||||
},
|
||||
"zhipu": {
|
||||
"default_base_url": "https://open.bigmodel.cn/api/paas/v4",
|
||||
"default_model": "embedding-3",
|
||||
"default_dimensions": 1024,
|
||||
"supports_dim_param": True,
|
||||
"needs_client_truncate": False,
|
||||
"needs_client_normalize": False,
|
||||
"query_instruction": "",
|
||||
"max_batch_size": 64,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _l2_normalize(vec: List[float]) -> List[float]:
|
||||
"""Normalize a vector to unit length (L2 norm). Returns input on zero vector."""
|
||||
norm = math.sqrt(sum(v * v for v in vec))
|
||||
if norm == 0:
|
||||
return vec
|
||||
return [v / norm for v in vec]
|
||||
|
||||
|
||||
class OpenAIEmbeddingProvider(EmbeddingProvider):
|
||||
"""
|
||||
OpenAI-compatible embedding provider.
|
||||
|
||||
Used for openai/linkai/dashscope/ark/zhipu by configuring the metadata
|
||||
fields. The legacy two-arg constructor (model, api_key, api_base) keeps
|
||||
working, so the original OpenAI/LinkAI fallback code path is unchanged.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str = "text-embedding-3-small",
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
extra_headers: Optional[dict] = None,
|
||||
dimensions: Optional[int] = None,
|
||||
supports_dim_param: bool = True,
|
||||
needs_client_truncate: bool = False,
|
||||
needs_client_normalize: bool = False,
|
||||
query_instruction: str = "",
|
||||
max_batch_size: int = 256,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
model: Model name (e.g. text-embedding-3-small, text-embedding-v4, embedding-3)
|
||||
api_key: API key (required)
|
||||
api_base: API base URL (defaults to OpenAI)
|
||||
extra_headers: Optional extra HTTP headers
|
||||
dimensions: Target output dimension. Required when supports_dim_param
|
||||
is False and needs_client_truncate is True (used to slice).
|
||||
supports_dim_param: Whether the vendor accepts a `dimensions` body param
|
||||
needs_client_truncate: Slice the returned vector to `dimensions`
|
||||
needs_client_normalize: L2-normalize on the client after slicing
|
||||
query_instruction: Optional prefix prepended to query texts only
|
||||
max_batch_size: Max items per /embeddings request; embed_batch
|
||||
auto-paginates above this.
|
||||
"""
|
||||
self.model = model
|
||||
self.api_key = api_key
|
||||
self.api_base = api_base or "https://api.openai.com/v1"
|
||||
self.extra_headers = extra_headers or {}
|
||||
self.supports_dim_param = supports_dim_param
|
||||
self.needs_client_truncate = needs_client_truncate
|
||||
self.needs_client_normalize = needs_client_normalize
|
||||
self.query_instruction = query_instruction or ""
|
||||
self.max_batch_size = max(1, int(max_batch_size or 1))
|
||||
|
||||
if not self.api_key or self.api_key in ["", "YOUR API KEY", "YOUR_API_KEY"]:
|
||||
raise ValueError("Embedding API key is not configured")
|
||||
|
||||
if dimensions is not None and dimensions > 0:
|
||||
self._dimensions = dimensions
|
||||
else:
|
||||
# Legacy heuristic for OpenAI text-embedding-3-* family
|
||||
self._dimensions = 1536 if "small" in model else 3072
|
||||
|
||||
def _call_api(self, input_data):
|
||||
"""Call OpenAI-compatible /embeddings endpoint"""
|
||||
import requests
|
||||
|
||||
url = f"{self.api_base}/embeddings"
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
**self.extra_headers,
|
||||
}
|
||||
data = {
|
||||
"input": input_data,
|
||||
"model": self.model,
|
||||
}
|
||||
if self.supports_dim_param and self._dimensions:
|
||||
data["dimensions"] = self._dimensions
|
||||
|
||||
try:
|
||||
response = requests.post(url, headers=headers, json=data, timeout=EMBEDDING_HTTP_TIMEOUT)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
except requests.exceptions.ConnectionError as e:
|
||||
raise ConnectionError(
|
||||
f"Failed to connect to embedding API at {url}. "
|
||||
f"Please check network and api_base. Error: {str(e)}"
|
||||
)
|
||||
except requests.exceptions.Timeout as e:
|
||||
raise TimeoutError(f"Embedding API request timed out. Error: {str(e)}")
|
||||
except requests.exceptions.HTTPError as e:
|
||||
if e.response.status_code == 401:
|
||||
raise ValueError("Invalid embedding API key")
|
||||
elif e.response.status_code == 429:
|
||||
raise ValueError("Embedding API rate limit exceeded")
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Embedding API request failed: "
|
||||
f"{e.response.status_code} - {e.response.text}"
|
||||
)
|
||||
|
||||
def _post_process(self, raw: List[float]) -> List[float]:
|
||||
"""Apply optional client-side truncation + normalization"""
|
||||
vec = raw
|
||||
if self.needs_client_truncate and self._dimensions and len(vec) > self._dimensions:
|
||||
vec = vec[: self._dimensions]
|
||||
if self.needs_client_normalize:
|
||||
vec = _l2_normalize(vec)
|
||||
return vec
|
||||
|
||||
def embed(self, text: str) -> List[float]:
|
||||
"""Generate embedding (treated as document by default)"""
|
||||
result = self._call_api(text)
|
||||
return self._post_process(result["data"][0]["embedding"])
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
"""Generate embedding for a query (applies vendor instruction prefix if any)"""
|
||||
if self.query_instruction:
|
||||
text = f"{self.query_instruction}{text}"
|
||||
return self.embed(text)
|
||||
|
||||
def embed_batch(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Generate embeddings for multiple documents.
|
||||
|
||||
Automatically paginates by self.max_batch_size so callers can pass any
|
||||
number of texts. Order of returned vectors matches the input order.
|
||||
"""
|
||||
if not texts:
|
||||
return []
|
||||
out: List[List[float]] = []
|
||||
step = self.max_batch_size
|
||||
for i in range(0, len(texts), step):
|
||||
chunk = texts[i:i + step]
|
||||
result = self._call_api(chunk)
|
||||
out.extend(self._post_process(item["embedding"]) for item in result["data"])
|
||||
return out
|
||||
|
||||
@property
|
||||
def dimensions(self) -> int:
|
||||
return self._dimensions
|
||||
|
||||
|
||||
class DoubaoEmbeddingProvider(EmbeddingProvider):
|
||||
"""
|
||||
Doubao (Volcengine Ark) multimodal embedding provider.
|
||||
|
||||
Doubao deprecated their OpenAI-compatible /v1/embeddings endpoint and
|
||||
unified everything under /api/v3/embeddings/multimodal, which uses a
|
||||
structured `input: [{type, text|image_url|video_url}, ...]` payload.
|
||||
|
||||
Notes:
|
||||
* The endpoint produces ONE embedding per call (input list is multiple
|
||||
modality parts of a single document, not a batch). embed_batch
|
||||
therefore loops per-text — no native batch support.
|
||||
* Native dimensions: 1024 or 2048 (default 1024 to align with other
|
||||
Chinese vendors). No client-side truncation needed.
|
||||
* Auth: Bearer ARK API key.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
extra_headers: Optional[dict] = None,
|
||||
dimensions: Optional[int] = None,
|
||||
):
|
||||
self.model = model
|
||||
self.api_key = api_key
|
||||
self.api_base = api_base or "https://ark.cn-beijing.volces.com/api/v3"
|
||||
self.extra_headers = extra_headers or {}
|
||||
if not self.api_key or self.api_key in ["", "YOUR API KEY", "YOUR_API_KEY"]:
|
||||
raise ValueError("Doubao embedding API key (ark_api_key) is not configured")
|
||||
|
||||
if dimensions in (1024, 2048):
|
||||
self._dimensions = dimensions
|
||||
elif dimensions is None:
|
||||
self._dimensions = 1024
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Doubao embedding dimensions must be 1024 or 2048, got {dimensions}"
|
||||
)
|
||||
|
||||
def _call_api(self, text: str) -> List[float]:
|
||||
"""One call → one embedding. multimodal endpoint takes a single
|
||||
document represented as a list of typed parts; we send a single
|
||||
text part."""
|
||||
import requests
|
||||
|
||||
url = f"{self.api_base}/embeddings/multimodal"
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
**self.extra_headers,
|
||||
}
|
||||
payload = {
|
||||
"model": self.model,
|
||||
"input": [{"type": "text", "text": text}],
|
||||
"dimensions": self._dimensions,
|
||||
"encoding_format": "float",
|
||||
}
|
||||
|
||||
try:
|
||||
response = requests.post(url, headers=headers, json=payload, timeout=EMBEDDING_HTTP_TIMEOUT)
|
||||
response.raise_for_status()
|
||||
body = response.json()
|
||||
except requests.exceptions.ConnectionError as e:
|
||||
raise ConnectionError(
|
||||
f"Failed to connect to Doubao embedding API at {url}. "
|
||||
f"Please check network and api_base. Error: {str(e)}"
|
||||
)
|
||||
except requests.exceptions.Timeout as e:
|
||||
raise TimeoutError(f"Doubao embedding API request timed out. Error: {str(e)}")
|
||||
except requests.exceptions.HTTPError as e:
|
||||
if e.response.status_code == 401:
|
||||
raise ValueError("Invalid Doubao (ark) embedding API key")
|
||||
elif e.response.status_code == 429:
|
||||
raise ValueError("Doubao embedding API rate limit exceeded")
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Doubao embedding API request failed: "
|
||||
f"{e.response.status_code} - {e.response.text}"
|
||||
)
|
||||
|
||||
# Response shape per docs: {"data": {"embedding": [...]}}
|
||||
data = body.get("data")
|
||||
if isinstance(data, dict) and "embedding" in data:
|
||||
return data["embedding"]
|
||||
# Some providers wrap as a list of one — be defensive
|
||||
if isinstance(data, list) and data and "embedding" in data[0]:
|
||||
return data[0]["embedding"]
|
||||
raise ValueError(f"Unexpected Doubao embedding response shape: {body}")
|
||||
|
||||
def embed(self, text: str) -> List[float]:
|
||||
return self._call_api(text)
|
||||
|
||||
def embed_batch(self, texts: List[str]) -> List[List[float]]:
|
||||
# Endpoint produces one embedding per call; loop. Order preserved.
|
||||
return [self._call_api(t) for t in texts]
|
||||
|
||||
@property
|
||||
def dimensions(self) -> int:
|
||||
return self._dimensions
|
||||
|
||||
|
||||
class EmbeddingCache:
|
||||
"""In-memory cache for embeddings to avoid recomputation"""
|
||||
|
||||
def __init__(self):
|
||||
self.cache = {}
|
||||
|
||||
def get(self, text: str, provider: str, model: str) -> Optional[List[float]]:
|
||||
key = self._compute_key(text, provider, model)
|
||||
return self.cache.get(key)
|
||||
|
||||
def put(self, text: str, provider: str, model: str, embedding: List[float]):
|
||||
key = self._compute_key(text, provider, model)
|
||||
self.cache[key] = embedding
|
||||
|
||||
@staticmethod
|
||||
def _compute_key(text: str, provider: str, model: str) -> str:
|
||||
content = f"{provider}:{model}:{text}"
|
||||
return hashlib.md5(content.encode("utf-8")).hexdigest()
|
||||
|
||||
def clear(self):
|
||||
self.cache.clear()
|
||||
|
||||
|
||||
def create_embedding_provider(
|
||||
provider: str = "openai",
|
||||
model: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
extra_headers: Optional[dict] = None,
|
||||
dimensions: Optional[int] = None,
|
||||
) -> EmbeddingProvider:
|
||||
"""
|
||||
Factory function to create an embedding provider.
|
||||
|
||||
Backward compatible: when called with provider in {"openai", "linkai"}
|
||||
and no `dimensions` arg, behaves exactly as before (1536-dim OpenAI).
|
||||
|
||||
New providers ("dashscope", "doubao", "zhipu") require explicit configuration
|
||||
and use the unified 1024-dim defaults from EMBEDDING_VENDORS.
|
||||
|
||||
Args:
|
||||
provider: Vendor key (one of EMBEDDING_VENDORS)
|
||||
model: Model name (uses vendor default if None)
|
||||
api_key: API key (required)
|
||||
api_base: API base URL (uses vendor default if None)
|
||||
extra_headers: Optional extra HTTP headers
|
||||
dimensions: Target output dimension (uses vendor default if None)
|
||||
|
||||
Returns:
|
||||
EmbeddingProvider instance
|
||||
"""
|
||||
meta = EMBEDDING_VENDORS.get(provider)
|
||||
if meta is None:
|
||||
raise ValueError(
|
||||
f"Unsupported embedding provider: {provider}. "
|
||||
f"Supported: {sorted(EMBEDDING_VENDORS.keys())}"
|
||||
)
|
||||
|
||||
# Doubao uses a non-OpenAI-compatible multimodal endpoint.
|
||||
if meta.get("provider_class") == "doubao":
|
||||
final_dim = dimensions if (dimensions and dimensions > 0) else meta["default_dimensions"]
|
||||
return DoubaoEmbeddingProvider(
|
||||
model=model or meta["default_model"],
|
||||
api_key=api_key,
|
||||
api_base=api_base or meta["default_base_url"],
|
||||
extra_headers=extra_headers,
|
||||
dimensions=final_dim,
|
||||
)
|
||||
|
||||
# Legacy two-arg call for openai/linkai keeps 1536-dim default behavior
|
||||
# so existing data isn't invalidated.
|
||||
is_legacy_call = (
|
||||
provider in ("openai", "linkai")
|
||||
and dimensions is None
|
||||
)
|
||||
if is_legacy_call:
|
||||
return OpenAIEmbeddingProvider(
|
||||
model=model or "text-embedding-3-small",
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
extra_headers=extra_headers,
|
||||
)
|
||||
|
||||
final_dim = dimensions if (dimensions and dimensions > 0) else meta["default_dimensions"]
|
||||
return OpenAIEmbeddingProvider(
|
||||
model=model or meta["default_model"],
|
||||
api_key=api_key,
|
||||
api_base=api_base or meta["default_base_url"],
|
||||
extra_headers=extra_headers,
|
||||
dimensions=final_dim,
|
||||
supports_dim_param=meta["supports_dim_param"],
|
||||
needs_client_truncate=meta["needs_client_truncate"],
|
||||
needs_client_normalize=meta["needs_client_normalize"],
|
||||
query_instruction=meta["query_instruction"],
|
||||
max_batch_size=meta.get("max_batch_size", 256),
|
||||
)
|
||||
191
agent/memory/embedding/rebuild.py
Normal file
@@ -0,0 +1,191 @@
|
||||
"""
|
||||
Rebuild memory vector index.
|
||||
|
||||
Recommended entry point (in-chat, while agent is running):
|
||||
/memory rebuild-index
|
||||
|
||||
Backward-compatible CLI entry (must run from project root):
|
||||
python -m agent.memory.rebuild_index
|
||||
|
||||
What it does:
|
||||
1. Probes the embedding endpoint with a tiny call to fail fast on
|
||||
bad provider/model/key — before touching the index.
|
||||
2. Clears the SQLite chunks/files tables (workspace markdown stays intact).
|
||||
3. Runs a fresh sync, regenerating embeddings with the currently configured
|
||||
provider/model/dimensions.
|
||||
|
||||
This is the only safe way to switch embedding_provider after the existing
|
||||
index has been populated by a different-dim model.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import asyncio
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
from common.log import logger
|
||||
from common.utils import expand_path
|
||||
|
||||
|
||||
@dataclass
|
||||
class RebuildResult:
|
||||
"""Outcome of a rebuild_in_process() call"""
|
||||
ok: bool
|
||||
removed: int = 0
|
||||
chunks: int = 0
|
||||
files: int = 0
|
||||
error: Optional[str] = None
|
||||
|
||||
|
||||
def clear_index(db_path, storage=None) -> int:
|
||||
"""Wipe chunks/files, reset FTS5, and clean up any legacy state file.
|
||||
|
||||
Args:
|
||||
db_path: Path of the index DB (also used to locate the legacy state
|
||||
file for migration cleanup, and — when *storage* is None — to
|
||||
open a fresh connection).
|
||||
storage: Optional pre-opened MemoryStorage. When provided we reuse it
|
||||
so the live connection's triggers stay in sync — opening a second
|
||||
connection would leave the original one's triggers pointing at a
|
||||
DROP'd chunks_fts table.
|
||||
|
||||
We reset (DROP+recreate) chunks_fts because its shadow tables can become
|
||||
inconsistent across rebuild cycles, causing bm25() / ORDER BY rank to
|
||||
raise "database disk image is malformed" even when raw MATCH still works.
|
||||
|
||||
Returns number of chunks removed.
|
||||
"""
|
||||
from agent.memory.embedding.state import cleanup_legacy_state_file
|
||||
from agent.memory.storage import MemoryStorage
|
||||
|
||||
owns_storage = storage is None
|
||||
if owns_storage:
|
||||
storage = MemoryStorage(db_path)
|
||||
try:
|
||||
before = storage.conn.execute("SELECT COUNT(*) FROM chunks").fetchone()[0]
|
||||
storage.conn.execute("DELETE FROM chunks")
|
||||
storage.conn.execute("DELETE FROM files")
|
||||
storage.conn.commit()
|
||||
storage.reset_fts5()
|
||||
finally:
|
||||
if owns_storage:
|
||||
storage.close()
|
||||
|
||||
cleanup_legacy_state_file(db_path)
|
||||
return int(before)
|
||||
|
||||
|
||||
def rebuild_in_process(memory_manager) -> RebuildResult:
|
||||
"""
|
||||
Rebuild the index using an existing, fully-initialized MemoryManager.
|
||||
|
||||
Used by the in-chat /memory rebuild-index command. The caller already has
|
||||
config loaded, embedding_provider built, and (optionally) the agent
|
||||
running, so we only need to:
|
||||
1. Clear chunks/files + state on the manager's storage.
|
||||
2. Re-sync (force=True).
|
||||
|
||||
NOTE: caller must ensure memory_manager.embedding_provider is set, otherwise
|
||||
sync() will silently skip embedding generation.
|
||||
"""
|
||||
if memory_manager is None:
|
||||
return RebuildResult(ok=False, error="memory_manager is None")
|
||||
if memory_manager.embedding_provider is None:
|
||||
return RebuildResult(ok=False, error="embedding_provider is not initialized")
|
||||
|
||||
# Probe the embedding endpoint BEFORE clearing the index. A bad
|
||||
# provider/model/key would otherwise leave the user with an empty index
|
||||
# that not even keyword search can serve.
|
||||
try:
|
||||
memory_manager.embedding_provider.embed_query("ping")
|
||||
except Exception as e:
|
||||
logger.error(f"[RebuildIndex] embedding probe failed, aborting rebuild: {e}")
|
||||
return RebuildResult(ok=False, error=f"embedding endpoint not reachable: {e}")
|
||||
|
||||
db_path = memory_manager.config.get_db_path()
|
||||
try:
|
||||
removed = clear_index(db_path, storage=memory_manager.storage)
|
||||
except Exception as e:
|
||||
logger.exception("[RebuildIndex] clear_index failed")
|
||||
return RebuildResult(ok=False, error=f"clear failed: {e}")
|
||||
|
||||
try:
|
||||
asyncio.run(memory_manager.sync(force=True))
|
||||
except RuntimeError:
|
||||
# Already inside a running event loop (rare in chat handler thread).
|
||||
loop = asyncio.new_event_loop()
|
||||
try:
|
||||
loop.run_until_complete(memory_manager.sync(force=True))
|
||||
finally:
|
||||
loop.close()
|
||||
except Exception as e:
|
||||
logger.exception("[RebuildIndex] sync failed")
|
||||
return RebuildResult(ok=False, removed=removed, error=f"re-embed failed: {e}")
|
||||
|
||||
stats = memory_manager.storage.get_stats()
|
||||
chunks = int(stats.get("chunks", 0))
|
||||
embedded = int(stats.get("embedded", 0))
|
||||
|
||||
# sync() degrades to "no embeddings" on batch failure so keyword search
|
||||
# still works at startup — but in a /rebuild-index request the user
|
||||
# explicitly asked for vectors. Surface that as a failure.
|
||||
if chunks > 0 and embedded == 0:
|
||||
return RebuildResult(
|
||||
ok=False,
|
||||
removed=removed,
|
||||
chunks=chunks,
|
||||
files=int(stats.get("files", 0)),
|
||||
error=(
|
||||
"embedding API failed during sync; index now has chunks but no "
|
||||
"vectors. Check embedding provider/model/key and retry."
|
||||
),
|
||||
)
|
||||
|
||||
return RebuildResult(
|
||||
ok=True,
|
||||
removed=removed,
|
||||
chunks=chunks,
|
||||
files=int(stats.get("files", 0)),
|
||||
)
|
||||
|
||||
|
||||
def main() -> int:
|
||||
"""Standalone CLI entry. Must be run from project root (relative config path)."""
|
||||
from config import conf, load_config
|
||||
from agent.memory import MemoryConfig, MemoryManager
|
||||
|
||||
load_config()
|
||||
|
||||
workspace_root = expand_path(conf().get("agent_workspace", "~/cow"))
|
||||
memory_config = MemoryConfig(workspace_root=workspace_root)
|
||||
|
||||
logger.info(f"[RebuildIndex] Workspace: {workspace_root}")
|
||||
logger.info(f"[RebuildIndex] Index db: {memory_config.get_db_path()}")
|
||||
|
||||
from bridge.agent_initializer import AgentInitializer
|
||||
|
||||
initializer = AgentInitializer(bridge=None, agent_bridge=None)
|
||||
embedding_provider = initializer._init_embedding_provider(memory_config, session_id=None)
|
||||
if embedding_provider is None:
|
||||
logger.error(
|
||||
"[RebuildIndex] No embedding provider could be initialized. "
|
||||
"Check your config.json. Aborting rebuild."
|
||||
)
|
||||
return 1
|
||||
|
||||
manager = MemoryManager(memory_config, embedding_provider=embedding_provider)
|
||||
result = rebuild_in_process(manager)
|
||||
if not result.ok:
|
||||
logger.error(f"[RebuildIndex] {result.error}")
|
||||
return 1
|
||||
|
||||
logger.info(
|
||||
f"[RebuildIndex] Done. removed={result.removed}, "
|
||||
f"chunks={result.chunks}, files={result.files}"
|
||||
)
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
51
agent/memory/embedding/state.py
Normal file
@@ -0,0 +1,51 @@
|
||||
"""
|
||||
Embedding-related index utilities.
|
||||
|
||||
We don't keep a sidecar state file — the SQLite index is the source of truth
|
||||
and config.json is the source of intent. The two functions below are the
|
||||
only things needing on-disk awareness:
|
||||
|
||||
detect_index_dim : read the dim of stored vectors (display-only)
|
||||
cleanup_legacy_state_file: remove old embedding_state.json from earlier
|
||||
versions; safe no-op when absent.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
|
||||
PathLike = Union[str, os.PathLike]
|
||||
|
||||
|
||||
def detect_index_dim(storage) -> Optional[int]:
|
||||
"""Return the dim of the first stored embedding, or None if the index
|
||||
has no embeddings. Used by /memory status."""
|
||||
try:
|
||||
row = storage.conn.execute(
|
||||
"SELECT embedding FROM chunks WHERE embedding IS NOT NULL LIMIT 1"
|
||||
).fetchone()
|
||||
except Exception:
|
||||
return None
|
||||
if not row or not row["embedding"]:
|
||||
return None
|
||||
try:
|
||||
raw = row["embedding"]
|
||||
if isinstance(raw, (bytes, bytearray)):
|
||||
# New BLOB format: 4 bytes per float32
|
||||
return len(raw) // 4
|
||||
emb = json.loads(raw)
|
||||
return len(emb) if isinstance(emb, list) else None
|
||||
except (json.JSONDecodeError, TypeError, Exception):
|
||||
return None
|
||||
|
||||
|
||||
def cleanup_legacy_state_file(db_path: PathLike) -> None:
|
||||
"""Remove old embedding_state.json files from earlier versions.
|
||||
Safe to call repeatedly; no-op if the file is absent."""
|
||||
legacy = Path(db_path).parent / "embedding_state.json"
|
||||
try:
|
||||
legacy.unlink(missing_ok=True)
|
||||
except Exception:
|
||||
pass
|
||||
@@ -13,7 +13,7 @@ from datetime import datetime, timedelta
|
||||
from agent.memory.config import MemoryConfig, get_default_memory_config
|
||||
from agent.memory.storage import MemoryStorage, MemoryChunk, SearchResult
|
||||
from agent.memory.chunker import TextChunker
|
||||
from agent.memory.embedding import create_embedding_provider, EmbeddingProvider
|
||||
from agent.memory.embedding import EmbeddingProvider, EmbeddingCache
|
||||
from agent.memory.summarizer import MemoryFlushManager, create_memory_files_if_needed
|
||||
|
||||
|
||||
@@ -50,30 +50,22 @@ class MemoryManager:
|
||||
overlap_tokens=self.config.chunk_overlap_tokens
|
||||
)
|
||||
|
||||
# Initialize embedding provider (optional)
|
||||
self.embedding_provider = None
|
||||
if embedding_provider:
|
||||
self.embedding_provider = embedding_provider
|
||||
else:
|
||||
# Try to create embedding provider, but allow failure
|
||||
try:
|
||||
# Get API key from environment or config
|
||||
api_key = os.environ.get('OPENAI_API_KEY')
|
||||
api_base = os.environ.get('OPENAI_API_BASE')
|
||||
|
||||
self.embedding_provider = create_embedding_provider(
|
||||
provider=self.config.embedding_provider,
|
||||
model=self.config.embedding_model,
|
||||
api_key=api_key,
|
||||
api_base=api_base
|
||||
)
|
||||
except Exception as e:
|
||||
# Embedding provider failed, but that's OK
|
||||
# We can still use keyword search and file operations
|
||||
from common.log import logger
|
||||
logger.warning(f"[MemoryManager] Embedding provider initialization failed: {e}")
|
||||
logger.info(f"[MemoryManager] Memory will work with keyword search only (no vector search)")
|
||||
|
||||
# Embedding provider is owned by the caller (agent_initializer is the
|
||||
# canonical entry point and handles legacy/explicit + state validation).
|
||||
# When None is passed, memory degrades to keyword-only search instead
|
||||
# of silently re-initializing a vendor here, which would bypass the
|
||||
# caller's state checks and risk corrupting the index.
|
||||
self.embedding_provider = embedding_provider
|
||||
if self.embedding_provider is None:
|
||||
from common.log import logger
|
||||
logger.info(
|
||||
"[MemoryManager] No embedding provider; memory will use keyword search only"
|
||||
)
|
||||
|
||||
# Cache for query embeddings (avoids redundant API calls within a session)
|
||||
self._embedding_cache = EmbeddingCache()
|
||||
|
||||
|
||||
# Initialize memory flush manager
|
||||
workspace_dir = self.config.get_workspace()
|
||||
self.flush_manager = MemoryFlushManager(
|
||||
@@ -133,12 +125,21 @@ class MemoryManager:
|
||||
if self.config.sync_on_search and self._dirty:
|
||||
await self.sync()
|
||||
|
||||
# Perform vector search (if embedding provider available)
|
||||
from common.log import logger
|
||||
|
||||
# Perform vector search (if embedding provider available).
|
||||
# Failures degrade silently to keyword-only — no exception is raised.
|
||||
vector_results = []
|
||||
if self.embedding_provider:
|
||||
try:
|
||||
from common.log import logger
|
||||
query_embedding = self.embedding_provider.embed(query)
|
||||
provider_name = type(self.embedding_provider).__name__
|
||||
model_name = getattr(self.embedding_provider, 'model', '')
|
||||
cached = self._embedding_cache.get(query, provider_name, model_name)
|
||||
if cached is not None:
|
||||
query_embedding = cached
|
||||
else:
|
||||
query_embedding = self.embedding_provider.embed_query(query)
|
||||
self._embedding_cache.put(query, provider_name, model_name, query_embedding)
|
||||
vector_results = self.storage.search_vector(
|
||||
query_embedding=query_embedding,
|
||||
user_id=user_id,
|
||||
@@ -147,19 +148,19 @@ class MemoryManager:
|
||||
)
|
||||
logger.info(f"[MemoryManager] Vector search found {len(vector_results)} results for query: {query}")
|
||||
except Exception as e:
|
||||
from common.log import logger
|
||||
logger.warning(f"[MemoryManager] Vector search failed: {e}")
|
||||
|
||||
# Perform keyword search
|
||||
logger.error(
|
||||
f"[MemoryManager] Vector search failed, falling back to keyword-only: {e}"
|
||||
)
|
||||
|
||||
# Perform keyword search (also runs as fallback when vector failed)
|
||||
keyword_results = self.storage.search_keyword(
|
||||
query=query,
|
||||
user_id=user_id,
|
||||
scopes=scopes,
|
||||
limit=max_results * 2
|
||||
)
|
||||
from common.log import logger
|
||||
logger.info(f"[MemoryManager] Keyword search found {len(keyword_results)} results for query: {query}")
|
||||
|
||||
|
||||
# Merge results
|
||||
merged = self._merge_results(
|
||||
vector_results,
|
||||
@@ -167,7 +168,7 @@ class MemoryManager:
|
||||
self.config.vector_weight,
|
||||
self.config.keyword_weight
|
||||
)
|
||||
|
||||
|
||||
# Filter by min score and limit
|
||||
filtered = [r for r in merged if r.score >= min_score]
|
||||
return filtered[:max_results]
|
||||
@@ -249,295 +250,195 @@ class MemoryManager:
|
||||
|
||||
async def sync(self, force: bool = False):
|
||||
"""
|
||||
Synchronize memory from files
|
||||
|
||||
Synchronize memory from files.
|
||||
|
||||
Two-pass design to amortize embedding HTTP cost:
|
||||
1. Walk all files, chunk those whose hash changed, collect pending
|
||||
chunks across files. No embedding calls yet.
|
||||
2. Run a single embed_batch over the union of pending chunks (the
|
||||
provider auto-paginates by vendor cap), then persist per-file.
|
||||
|
||||
For workspaces with many small files (101 files / ~1 chunk each), this
|
||||
cuts ~100 HTTP calls down to ~ceil(total_chunks / vendor_cap).
|
||||
|
||||
Args:
|
||||
force: Force full reindex
|
||||
"""
|
||||
memory_dir = self.config.get_memory_dir()
|
||||
workspace_dir = self.config.get_workspace()
|
||||
|
||||
# Scan MEMORY.md (workspace root)
|
||||
|
||||
files_to_scan: List[tuple] = [] # (file_path, source, scope, user_id)
|
||||
|
||||
memory_file = Path(workspace_dir) / "MEMORY.md"
|
||||
if memory_file.exists():
|
||||
await self._sync_file(memory_file, "memory", "shared", None)
|
||||
|
||||
# Scan memory directory (including daily summaries)
|
||||
files_to_scan.append((memory_file, "memory", "shared", None))
|
||||
|
||||
if memory_dir.exists():
|
||||
for file_path in memory_dir.rglob("*.md"):
|
||||
# Determine scope and user_id from path
|
||||
rel_path = file_path.relative_to(workspace_dir)
|
||||
parts = rel_path.parts
|
||||
|
||||
# Check if it's in daily summary directory
|
||||
if "daily" in parts:
|
||||
# Daily summary files
|
||||
if "users" in parts or len(parts) > 3:
|
||||
# User-scoped daily summary: memory/daily/{user_id}/2024-01-29.md
|
||||
user_idx = parts.index("daily") + 1
|
||||
user_id = parts[user_idx] if user_idx < len(parts) else None
|
||||
rel_parts = file_path.relative_to(workspace_dir).parts
|
||||
if any(part.startswith('.') for part in rel_parts):
|
||||
continue
|
||||
# Dream diaries are narrative reflections produced by Deep
|
||||
# Dream; their factual content has already been distilled
|
||||
# into MEMORY.md. Indexing them adds noisy near-duplicates
|
||||
# that crowd out the authoritative entry in retrieval.
|
||||
if "dreams" in rel_parts:
|
||||
continue
|
||||
if "daily" in rel_parts:
|
||||
if "users" in rel_parts or len(rel_parts) > 3:
|
||||
user_idx = rel_parts.index("daily") + 1
|
||||
user_id = rel_parts[user_idx] if user_idx < len(rel_parts) else None
|
||||
scope = "user"
|
||||
else:
|
||||
# Shared daily summary: memory/daily/2024-01-29.md
|
||||
user_id = None
|
||||
scope = "shared"
|
||||
elif "users" in parts:
|
||||
# User-scoped memory
|
||||
user_idx = parts.index("users") + 1
|
||||
user_id = parts[user_idx] if user_idx < len(parts) else None
|
||||
elif "users" in rel_parts:
|
||||
user_idx = rel_parts.index("users") + 1
|
||||
user_id = rel_parts[user_idx] if user_idx < len(rel_parts) else None
|
||||
scope = "user"
|
||||
else:
|
||||
# Shared memory
|
||||
user_id = None
|
||||
scope = "shared"
|
||||
|
||||
await self._sync_file(file_path, "memory", scope, user_id)
|
||||
|
||||
self._dirty = False
|
||||
|
||||
async def _sync_file(
|
||||
self,
|
||||
file_path: Path,
|
||||
source: str,
|
||||
scope: str,
|
||||
user_id: Optional[str]
|
||||
):
|
||||
"""Sync a single file"""
|
||||
# Compute file hash
|
||||
content = file_path.read_text(encoding='utf-8')
|
||||
file_hash = MemoryStorage.compute_hash(content)
|
||||
|
||||
# Get relative path
|
||||
workspace_dir = self.config.get_workspace()
|
||||
rel_path = str(file_path.relative_to(workspace_dir))
|
||||
|
||||
# Check if file changed
|
||||
stored_hash = self.storage.get_file_hash(rel_path)
|
||||
if stored_hash == file_hash:
|
||||
return # No changes
|
||||
|
||||
# Delete old chunks
|
||||
self.storage.delete_by_path(rel_path)
|
||||
|
||||
# Chunk and embed
|
||||
chunks = self.chunker.chunk_text(content)
|
||||
if not chunks:
|
||||
return
|
||||
|
||||
texts = [chunk.text for chunk in chunks]
|
||||
if self.embedding_provider:
|
||||
embeddings = self.embedding_provider.embed_batch(texts)
|
||||
else:
|
||||
embeddings = [None] * len(texts)
|
||||
|
||||
# Create memory chunks
|
||||
memory_chunks = []
|
||||
for chunk, embedding in zip(chunks, embeddings):
|
||||
chunk_id = self._generate_chunk_id(rel_path, chunk.start_line, chunk.end_line)
|
||||
chunk_hash = MemoryStorage.compute_hash(chunk.text)
|
||||
|
||||
memory_chunks.append(MemoryChunk(
|
||||
id=chunk_id,
|
||||
user_id=user_id,
|
||||
scope=scope,
|
||||
source=source,
|
||||
path=rel_path,
|
||||
start_line=chunk.start_line,
|
||||
end_line=chunk.end_line,
|
||||
text=chunk.text,
|
||||
embedding=embedding,
|
||||
hash=chunk_hash,
|
||||
metadata=None
|
||||
))
|
||||
|
||||
# Save
|
||||
self.storage.save_chunks_batch(memory_chunks)
|
||||
|
||||
# Update file metadata
|
||||
stat = file_path.stat()
|
||||
self.storage.update_file_metadata(
|
||||
path=rel_path,
|
||||
source=source,
|
||||
file_hash=file_hash,
|
||||
mtime=int(stat.st_mtime),
|
||||
size=stat.st_size
|
||||
)
|
||||
|
||||
def should_flush_memory(
|
||||
self,
|
||||
current_tokens: int = 0
|
||||
) -> bool:
|
||||
"""
|
||||
Check if memory flush should be triggered
|
||||
|
||||
独立的 flush 触发机制,不依赖模型 context window。
|
||||
使用配置中的阈值: flush_token_threshold 和 flush_turn_threshold
|
||||
|
||||
Args:
|
||||
current_tokens: Current session token count
|
||||
|
||||
Returns:
|
||||
True if memory flush should run
|
||||
"""
|
||||
return self.flush_manager.should_flush(
|
||||
current_tokens=current_tokens,
|
||||
token_threshold=self.config.flush_token_threshold,
|
||||
turn_threshold=self.config.flush_turn_threshold
|
||||
)
|
||||
|
||||
def increment_turn(self):
|
||||
"""增加对话轮数计数(每次用户消息+AI回复算一轮)"""
|
||||
self.flush_manager.increment_turn()
|
||||
|
||||
async def execute_memory_flush(
|
||||
self,
|
||||
agent_executor,
|
||||
current_tokens: int,
|
||||
user_id: Optional[str] = None,
|
||||
**executor_kwargs
|
||||
) -> bool:
|
||||
"""
|
||||
Execute memory flush before compaction
|
||||
|
||||
This runs a silent agent turn to write durable memories to disk.
|
||||
Similar to clawdbot's pre-compaction memory flush.
|
||||
|
||||
Args:
|
||||
agent_executor: Async function to execute agent with prompt
|
||||
current_tokens: Current session token count
|
||||
user_id: Optional user ID
|
||||
**executor_kwargs: Additional kwargs for agent executor
|
||||
|
||||
Returns:
|
||||
True if flush completed successfully
|
||||
|
||||
Example:
|
||||
>>> async def run_agent(prompt, system_prompt, silent=False):
|
||||
... # Your agent execution logic
|
||||
... pass
|
||||
>>>
|
||||
>>> if manager.should_flush_memory(current_tokens=100000):
|
||||
... await manager.execute_memory_flush(
|
||||
... agent_executor=run_agent,
|
||||
... current_tokens=100000
|
||||
... )
|
||||
"""
|
||||
success = await self.flush_manager.execute_flush(
|
||||
agent_executor=agent_executor,
|
||||
current_tokens=current_tokens,
|
||||
user_id=user_id,
|
||||
**executor_kwargs
|
||||
)
|
||||
|
||||
if success:
|
||||
# Mark dirty so next search will sync the new memories
|
||||
self._dirty = True
|
||||
|
||||
return success
|
||||
|
||||
def build_memory_guidance(self, lang: str = "zh", include_context: bool = True) -> str:
|
||||
"""
|
||||
Build natural memory guidance for agent system prompt
|
||||
|
||||
Following clawdbot's approach:
|
||||
1. Load MEMORY.md as bootstrap context (blends into background)
|
||||
2. Load daily files on-demand via memory_search tool
|
||||
3. Agent should NOT proactively mention memories unless user asks
|
||||
|
||||
Args:
|
||||
lang: Language for guidance ("en" or "zh")
|
||||
include_context: Whether to include bootstrap memory context (default: True)
|
||||
MEMORY.md is loaded as background context (like clawdbot)
|
||||
Daily files are accessed via memory_search tool
|
||||
|
||||
Returns:
|
||||
Memory guidance text (and optionally context) for system prompt
|
||||
"""
|
||||
today_file = self.flush_manager.get_today_memory_file().name
|
||||
|
||||
if lang == "zh":
|
||||
guidance = f"""## 记忆系统
|
||||
files_to_scan.append((file_path, "memory", scope, user_id))
|
||||
|
||||
**背景知识**: 下方包含核心长期记忆,可直接使用。需要查找历史时,用 memory_search 搜索(搜索一次即可,不要重复)。
|
||||
from config import conf
|
||||
if conf().get("knowledge", True):
|
||||
knowledge_dir = Path(workspace_dir) / "knowledge"
|
||||
if knowledge_dir.exists():
|
||||
for file_path in knowledge_dir.rglob("*.md"):
|
||||
files_to_scan.append((file_path, "knowledge", "shared", None))
|
||||
|
||||
**存储记忆**: 当用户分享重要信息时(偏好、决策、事实等),主动用 write 工具存储:
|
||||
- 长期信息 → MEMORY.md
|
||||
- 当天笔记 → memory/{today_file}
|
||||
- 静默存储,仅在明确要求时确认
|
||||
|
||||
**使用原则**: 自然使用记忆,就像你本来就知道。不需要生硬地提起或列举记忆,除非用户提到。"""
|
||||
else:
|
||||
guidance = f"""## Memory System
|
||||
|
||||
**Background Knowledge**: Core long-term memories below - use directly. For history, use memory_search once (don't repeat).
|
||||
|
||||
**Store Memories**: When user shares important info (preferences, decisions, facts), proactively write:
|
||||
- Durable info → MEMORY.md
|
||||
- Daily notes → memory/{today_file}
|
||||
- Store silently; confirm only when explicitly requested
|
||||
|
||||
**Usage**: Use memories naturally as if you always knew. Don't mention or list unless user explicitly asks."""
|
||||
|
||||
if include_context:
|
||||
# Load bootstrap context (MEMORY.md only, like clawdbot)
|
||||
bootstrap_context = self.load_bootstrap_memories()
|
||||
if bootstrap_context:
|
||||
guidance += f"\n\n## Background Context\n\n{bootstrap_context}"
|
||||
|
||||
return guidance
|
||||
|
||||
def load_bootstrap_memories(self, user_id: Optional[str] = None) -> str:
|
||||
"""
|
||||
Load bootstrap memory files for session start
|
||||
|
||||
Following clawdbot's design:
|
||||
- Only loads MEMORY.md from workspace root (long-term curated memory)
|
||||
- Daily files (memory/YYYY-MM-DD.md) are accessed via memory_search tool, not bootstrap
|
||||
- User-specific MEMORY.md is also loaded if user_id provided
|
||||
|
||||
Returns memory content WITHOUT obvious headers so it blends naturally
|
||||
into the context as background knowledge.
|
||||
|
||||
Args:
|
||||
user_id: Optional user ID for user-specific memories
|
||||
|
||||
Returns:
|
||||
Memory content to inject into system prompt (blends naturally as background context)
|
||||
"""
|
||||
workspace_dir = self.config.get_workspace()
|
||||
memory_dir = self.config.get_memory_dir()
|
||||
|
||||
sections = []
|
||||
|
||||
# 1. Load MEMORY.md from workspace root (long-term curated memory)
|
||||
# Following clawdbot: only MEMORY.md is bootstrap, daily files use memory_search
|
||||
memory_file = Path(workspace_dir) / "MEMORY.md"
|
||||
if memory_file.exists():
|
||||
# Pass 1: inline chunking + change detection. Inlined (instead of
|
||||
# calling self._prepare_file_for_sync) so this method does not depend
|
||||
# on any sibling helpers — keeps it robust against partial reloads
|
||||
# where the class object is older than the method's source.
|
||||
pending: List[Dict[str, Any]] = []
|
||||
workspace_dir_path = self.config.get_workspace()
|
||||
for file_path, source, scope, user_id in files_to_scan:
|
||||
try:
|
||||
content = memory_file.read_text(encoding='utf-8').strip()
|
||||
if content:
|
||||
sections.append(content)
|
||||
content = file_path.read_text(encoding='utf-8')
|
||||
except Exception:
|
||||
continue
|
||||
file_hash = MemoryStorage.compute_hash(content)
|
||||
rel_path = str(file_path.relative_to(workspace_dir_path))
|
||||
if self.storage.get_file_hash(rel_path) == file_hash:
|
||||
continue
|
||||
chunks = self.chunker.chunk_text(content)
|
||||
if not chunks:
|
||||
continue
|
||||
pending.append({
|
||||
"file_path": file_path,
|
||||
"rel_path": rel_path,
|
||||
"source": source,
|
||||
"scope": scope,
|
||||
"user_id": user_id,
|
||||
"file_hash": file_hash,
|
||||
"chunks": chunks,
|
||||
"texts": [c.text for c in chunks],
|
||||
})
|
||||
|
||||
if not pending:
|
||||
self._dirty = False
|
||||
return
|
||||
|
||||
# Pass 2: single batched embed across all pending chunks.
|
||||
# CRITICAL: never touch the index until we hold valid embeddings.
|
||||
# If embed_batch fails, leave the existing index intact (chunks +
|
||||
# file_hash) so the next sync will retry the same files. Writing
|
||||
# NULL embeddings + updating file_hash here would mark the file as
|
||||
# "successfully synced" and silently strand it without vectors.
|
||||
all_texts: List[str] = []
|
||||
for entry in pending:
|
||||
all_texts.extend(entry["texts"])
|
||||
|
||||
if not self.embedding_provider:
|
||||
# No provider configured at all (legacy keyword-only). Persist
|
||||
# chunks without embeddings — this is the user's intent.
|
||||
all_embeddings: List[Optional[List[float]]] = [None] * len(all_texts)
|
||||
else:
|
||||
try:
|
||||
all_embeddings = self.embedding_provider.embed_batch(all_texts)
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to read MEMORY.md: {e}")
|
||||
|
||||
# 2. Load user-specific MEMORY.md if user_id provided
|
||||
if user_id:
|
||||
user_memory_dir = memory_dir / "users" / user_id
|
||||
user_memory_file = user_memory_dir / "MEMORY.md"
|
||||
if user_memory_file.exists():
|
||||
try:
|
||||
content = user_memory_file.read_text(encoding='utf-8').strip()
|
||||
if content:
|
||||
sections.append(content)
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to read user memory: {e}")
|
||||
|
||||
if not sections:
|
||||
return ""
|
||||
|
||||
# Join sections without obvious headers - let memories blend naturally
|
||||
# This makes the agent feel like it "just knows" rather than "checking memory files"
|
||||
return "\n\n".join(sections)
|
||||
from common.log import logger
|
||||
logger.error(
|
||||
f"[MemoryManager] Batch embedding failed for {len(all_texts)} "
|
||||
f"chunks across {len(pending)} files: {e}. "
|
||||
f"Index left untouched; will retry on next sync."
|
||||
)
|
||||
# Bail before touching storage. self._dirty stays True so
|
||||
# callers know there is pending work.
|
||||
return
|
||||
|
||||
# Pass 3: inline persist — same self-contained reasoning as Pass 1.
|
||||
cursor = 0
|
||||
for entry in pending:
|
||||
n = len(entry["texts"])
|
||||
entry_embeddings = all_embeddings[cursor:cursor + n]
|
||||
cursor += n
|
||||
|
||||
rel_path = entry["rel_path"]
|
||||
self.storage.delete_by_path(rel_path)
|
||||
memory_chunks = []
|
||||
for chunk, embedding in zip(entry["chunks"], entry_embeddings):
|
||||
chunk_id = self._generate_chunk_id(rel_path, chunk.start_line, chunk.end_line)
|
||||
chunk_hash = MemoryStorage.compute_hash(chunk.text)
|
||||
memory_chunks.append(MemoryChunk(
|
||||
id=chunk_id,
|
||||
user_id=entry["user_id"],
|
||||
scope=entry["scope"],
|
||||
source=entry["source"],
|
||||
path=rel_path,
|
||||
start_line=chunk.start_line,
|
||||
end_line=chunk.end_line,
|
||||
text=chunk.text,
|
||||
embedding=embedding,
|
||||
hash=chunk_hash,
|
||||
metadata=None,
|
||||
))
|
||||
self.storage.save_chunks_batch(memory_chunks)
|
||||
stat = entry["file_path"].stat()
|
||||
self.storage.update_file_metadata(
|
||||
path=rel_path,
|
||||
source=entry["source"],
|
||||
file_hash=entry["file_hash"],
|
||||
mtime=int(stat.st_mtime),
|
||||
size=stat.st_size,
|
||||
)
|
||||
|
||||
self._dirty = False
|
||||
|
||||
def flush_memory(
|
||||
self,
|
||||
messages: list,
|
||||
user_id: Optional[str] = None,
|
||||
reason: str = "threshold",
|
||||
max_messages: int = 10,
|
||||
context_summary_callback=None,
|
||||
) -> bool:
|
||||
"""
|
||||
Flush conversation summary to daily memory file.
|
||||
|
||||
Args:
|
||||
messages: Conversation message list
|
||||
user_id: Optional user ID
|
||||
reason: "threshold" | "overflow" | "daily_summary"
|
||||
max_messages: Max recent messages to include (0 = all)
|
||||
context_summary_callback: Optional callback(str) invoked with the
|
||||
daily summary text for in-context injection
|
||||
|
||||
Returns:
|
||||
True if flush was dispatched
|
||||
"""
|
||||
success = self.flush_manager.flush_from_messages(
|
||||
messages=messages,
|
||||
user_id=user_id,
|
||||
reason=reason,
|
||||
max_messages=max_messages,
|
||||
context_summary_callback=context_summary_callback,
|
||||
)
|
||||
if success:
|
||||
self._dirty = True
|
||||
return success
|
||||
|
||||
def get_status(self) -> Dict[str, Any]:
|
||||
"""Get memory status"""
|
||||
@@ -568,6 +469,37 @@ class MemoryManager:
|
||||
content = f"{path}:{start_line}:{end_line}"
|
||||
return hashlib.md5(content.encode('utf-8')).hexdigest()
|
||||
|
||||
@staticmethod
|
||||
def _compute_temporal_decay(path: str, half_life_days: float = 30.0) -> float:
|
||||
"""
|
||||
Compute temporal decay multiplier for dated memory files.
|
||||
|
||||
Inspired by OpenClaw's temporal-decay: exponential decay based on file date.
|
||||
MEMORY.md and non-dated files are "evergreen" (no decay, multiplier=1.0).
|
||||
Daily files like memory/2025-03-01.md decay based on age.
|
||||
|
||||
Formula: multiplier = exp(-ln2/half_life * age_in_days)
|
||||
"""
|
||||
import re
|
||||
import math
|
||||
|
||||
match = re.search(r'(\d{4})-(\d{2})-(\d{2})\.md$', path)
|
||||
if not match:
|
||||
return 1.0 # evergreen: MEMORY.md, non-dated files
|
||||
|
||||
try:
|
||||
file_date = datetime(
|
||||
int(match.group(1)), int(match.group(2)), int(match.group(3))
|
||||
)
|
||||
age_days = (datetime.now() - file_date).days
|
||||
if age_days <= 0:
|
||||
return 1.0
|
||||
|
||||
decay_lambda = math.log(2) / half_life_days
|
||||
return math.exp(-decay_lambda * age_days)
|
||||
except (ValueError, OverflowError):
|
||||
return 1.0
|
||||
|
||||
def _merge_results(
|
||||
self,
|
||||
vector_results: List[SearchResult],
|
||||
@@ -575,8 +507,7 @@ class MemoryManager:
|
||||
vector_weight: float,
|
||||
keyword_weight: float
|
||||
) -> List[SearchResult]:
|
||||
"""Merge vector and keyword search results"""
|
||||
# Create a map by (path, start_line, end_line)
|
||||
"""Merge vector and keyword search results with temporal decay for dated files"""
|
||||
merged_map = {}
|
||||
|
||||
for result in vector_results:
|
||||
@@ -598,7 +529,6 @@ class MemoryManager:
|
||||
'keyword_score': result.score
|
||||
}
|
||||
|
||||
# Calculate combined scores
|
||||
merged_results = []
|
||||
for entry in merged_map.values():
|
||||
combined_score = (
|
||||
@@ -606,7 +536,11 @@ class MemoryManager:
|
||||
keyword_weight * entry['keyword_score']
|
||||
)
|
||||
|
||||
# Apply temporal decay for dated memory files
|
||||
result = entry['result']
|
||||
decay = self._compute_temporal_decay(result.path)
|
||||
combined_score *= decay
|
||||
|
||||
merged_results.append(SearchResult(
|
||||
path=result.path,
|
||||
start_line=result.start_line,
|
||||
@@ -617,6 +551,5 @@ class MemoryManager:
|
||||
user_id=result.user_id
|
||||
))
|
||||
|
||||
# Sort by score
|
||||
merged_results.sort(key=lambda r: r.score, reverse=True)
|
||||
return merged_results
|
||||
|
||||
14
agent/memory/rebuild_index.py
Normal file
@@ -0,0 +1,14 @@
|
||||
"""
|
||||
Backward-compatible shim for the legacy entry point:
|
||||
python -m agent.memory.rebuild_index
|
||||
|
||||
The implementation now lives in agent.memory.embedding.rebuild.
|
||||
Prefer using `/memory rebuild-index` in chat going forward.
|
||||
"""
|
||||
|
||||
from agent.memory.embedding.rebuild import main
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
|
||||
sys.exit(main())
|
||||
@@ -32,68 +32,80 @@ class MemoryService:
|
||||
# ------------------------------------------------------------------
|
||||
# list — paginated file metadata
|
||||
# ------------------------------------------------------------------
|
||||
def list_files(self, page: int = 1, page_size: int = 20) -> dict:
|
||||
def list_files(self, page: int = 1, page_size: int = 20, category: str = "memory") -> dict:
|
||||
"""
|
||||
List all memory files with metadata (without content).
|
||||
List memory or dream files with metadata (without content).
|
||||
|
||||
Returns::
|
||||
|
||||
{
|
||||
"page": 1,
|
||||
"page_size": 20,
|
||||
"total": 15,
|
||||
"list": [
|
||||
{"filename": "MEMORY.md", "type": "global", "size": 2048, "updated_at": "2026-02-20 10:00:00"},
|
||||
{"filename": "2026-02-20.md", "type": "daily", "size": 512, "updated_at": "2026-02-20 09:30:00"},
|
||||
...
|
||||
]
|
||||
}
|
||||
Args:
|
||||
category: ``"memory"`` (default) — MEMORY.md + daily files;
|
||||
``"dream"`` — dream diary files from memory/dreams/
|
||||
"""
|
||||
if category == "dream":
|
||||
files = self._list_dream_files()
|
||||
else:
|
||||
files = self._list_memory_files()
|
||||
|
||||
total = len(files)
|
||||
start = (page - 1) * page_size
|
||||
end = start + page_size
|
||||
|
||||
return {
|
||||
"page": page,
|
||||
"page_size": page_size,
|
||||
"total": total,
|
||||
"list": files[start:end],
|
||||
}
|
||||
|
||||
def _list_memory_files(self) -> List[dict]:
|
||||
"""MEMORY.md + memory/*.md (newest first)."""
|
||||
files: List[dict] = []
|
||||
|
||||
# 1. Global memory — MEMORY.md in workspace root
|
||||
global_path = os.path.join(self.workspace_root, "MEMORY.md")
|
||||
if os.path.isfile(global_path):
|
||||
files.append(self._file_info(global_path, "MEMORY.md", "global"))
|
||||
|
||||
# 2. Daily memory files — memory/*.md (sorted newest first)
|
||||
if os.path.isdir(self.memory_dir):
|
||||
daily_files = []
|
||||
for name in os.listdir(self.memory_dir):
|
||||
full = os.path.join(self.memory_dir, name)
|
||||
if os.path.isfile(full) and name.endswith(".md"):
|
||||
daily_files.append((name, full))
|
||||
# Sort by filename descending (newest date first)
|
||||
daily_files.sort(key=lambda x: x[0], reverse=True)
|
||||
for name, full in daily_files:
|
||||
files.append(self._file_info(full, name, "daily"))
|
||||
|
||||
total = len(files)
|
||||
return files
|
||||
|
||||
# Paginate
|
||||
start = (page - 1) * page_size
|
||||
end = start + page_size
|
||||
page_items = files[start:end]
|
||||
def _list_dream_files(self) -> List[dict]:
|
||||
"""memory/dreams/*.md (newest first)."""
|
||||
files: List[dict] = []
|
||||
dreams_dir = os.path.join(self.memory_dir, "dreams")
|
||||
|
||||
return {
|
||||
"page": page,
|
||||
"page_size": page_size,
|
||||
"total": total,
|
||||
"list": page_items,
|
||||
}
|
||||
if os.path.isdir(dreams_dir):
|
||||
entries = []
|
||||
for name in os.listdir(dreams_dir):
|
||||
full = os.path.join(dreams_dir, name)
|
||||
if os.path.isfile(full) and name.endswith(".md"):
|
||||
entries.append((name, full))
|
||||
entries.sort(key=lambda x: x[0], reverse=True)
|
||||
for name, full in entries:
|
||||
files.append(self._file_info(full, name, "dream"))
|
||||
|
||||
return files
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# content — read a single file
|
||||
# ------------------------------------------------------------------
|
||||
def get_content(self, filename: str) -> dict:
|
||||
def get_content(self, filename: str, category: str = "memory") -> dict:
|
||||
"""
|
||||
Read the full content of a memory file.
|
||||
Read the full content of a memory or dream file.
|
||||
|
||||
:param filename: File name, e.g. ``MEMORY.md`` or ``2026-02-20.md``
|
||||
:param filename: File name, e.g. ``MEMORY.md``, ``2026-02-20.md``
|
||||
:param category: ``"memory"`` or ``"dream"``
|
||||
:return: dict with ``filename`` and ``content``
|
||||
:raises FileNotFoundError: if the file does not exist
|
||||
"""
|
||||
path = self._resolve_path(filename)
|
||||
path = self._resolve_path(filename, category)
|
||||
if not os.path.isfile(path):
|
||||
raise FileNotFoundError(f"Memory file not found: {filename}")
|
||||
|
||||
@@ -113,7 +125,7 @@ class MemoryService:
|
||||
Dispatch a memory management action.
|
||||
|
||||
:param action: ``list`` or ``content``
|
||||
:param payload: action-specific payload
|
||||
:param payload: action-specific payload (supports ``category``: ``"memory"`` | ``"dream"``)
|
||||
:return: protocol-compatible response dict
|
||||
"""
|
||||
payload = payload or {}
|
||||
@@ -121,19 +133,23 @@ class MemoryService:
|
||||
if action == "list":
|
||||
page = payload.get("page", 1)
|
||||
page_size = payload.get("page_size", 20)
|
||||
result_payload = self.list_files(page=page, page_size=page_size)
|
||||
category = payload.get("category", "memory")
|
||||
result_payload = self.list_files(page=page, page_size=page_size, category=category)
|
||||
return {"action": action, "code": 200, "message": "success", "payload": result_payload}
|
||||
|
||||
elif action == "content":
|
||||
filename = payload.get("filename")
|
||||
if not filename:
|
||||
return {"action": action, "code": 400, "message": "filename is required", "payload": None}
|
||||
result_payload = self.get_content(filename)
|
||||
category = payload.get("category", "memory")
|
||||
result_payload = self.get_content(filename, category=category)
|
||||
return {"action": action, "code": 200, "message": "success", "payload": result_payload}
|
||||
|
||||
else:
|
||||
return {"action": action, "code": 400, "message": f"unknown action: {action}", "payload": None}
|
||||
|
||||
except ValueError as e:
|
||||
return {"action": action, "code": 403, "message": "invalid filename", "payload": None}
|
||||
except FileNotFoundError as e:
|
||||
return {"action": action, "code": 404, "message": str(e), "payload": None}
|
||||
except Exception as e:
|
||||
@@ -143,16 +159,30 @@ class MemoryService:
|
||||
# ------------------------------------------------------------------
|
||||
# internal helpers
|
||||
# ------------------------------------------------------------------
|
||||
def _resolve_path(self, filename: str) -> str:
|
||||
def _resolve_path(self, filename: str, category: str = "memory") -> str:
|
||||
"""
|
||||
Resolve a filename to its absolute path.
|
||||
Safely resolve a filename to its absolute path within the allowed directory.
|
||||
|
||||
- ``MEMORY.md`` → ``{workspace_root}/MEMORY.md``
|
||||
- ``2026-02-20.md`` → ``{workspace_root}/memory/2026-02-20.md``
|
||||
- ``2026-02-20.md`` (memory) → ``{workspace_root}/memory/2026-02-20.md``
|
||||
- ``2026-02-20.md`` (dream) → ``{workspace_root}/memory/dreams/2026-02-20.md``
|
||||
|
||||
Raises ValueError if the resolved path escapes the allowed directory.
|
||||
"""
|
||||
if filename == "MEMORY.md":
|
||||
return os.path.join(self.workspace_root, filename)
|
||||
return os.path.join(self.memory_dir, filename)
|
||||
base_dir = self.workspace_root
|
||||
elif category == "dream":
|
||||
base_dir = os.path.join(self.memory_dir, "dreams")
|
||||
else:
|
||||
base_dir = self.memory_dir
|
||||
|
||||
resolved = os.path.realpath(os.path.join(base_dir, filename))
|
||||
allowed = os.path.realpath(base_dir)
|
||||
|
||||
if resolved != allowed and not resolved.startswith(allowed + os.sep):
|
||||
raise ValueError(f"Invalid filename: path traversal detected")
|
||||
|
||||
return resolved
|
||||
|
||||
@staticmethod
|
||||
def _file_info(path: str, filename: str, file_type: str) -> dict:
|
||||
|
||||
@@ -1,225 +1,700 @@
|
||||
"""
|
||||
Memory flush manager
|
||||
Memory flush manager with Deep Dream distillation
|
||||
|
||||
Triggers memory flush before context compaction (similar to clawdbot)
|
||||
Handles memory persistence when conversation context is trimmed or overflows:
|
||||
- Uses LLM to summarize discarded messages into concise daily records
|
||||
- Writes to daily memory files (lazy creation)
|
||||
- Deduplicates trim flushes to avoid repeated writes
|
||||
- Runs summarization asynchronously to avoid blocking normal replies
|
||||
- Deep Dream: periodically distills daily memories → refined MEMORY.md + dream diary
|
||||
"""
|
||||
|
||||
from typing import Optional, Callable, Any
|
||||
import threading
|
||||
from typing import Optional, Callable, Any, List, Dict
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
from common.log import logger
|
||||
|
||||
|
||||
SUMMARIZE_SYSTEM_PROMPT = """你是一个对话记录助手。请将对话内容归纳为当天的日常记录。
|
||||
|
||||
## 要求
|
||||
|
||||
按「事件」维度归纳发生的事,不要按对话轮次逐条记录:
|
||||
- 每条一行,用 "- " 开头
|
||||
- 合并同一件事的多轮对话
|
||||
- 只记录有意义的事件,忽略闲聊和问候
|
||||
- 保留关键的决策、结论和待办事项
|
||||
|
||||
当对话没有任何记录价值(仅含问候或无意义内容),直接回复"无"。"""
|
||||
|
||||
SUMMARIZE_USER_PROMPT = """请归纳以下对话的日常记录:
|
||||
|
||||
{conversation}"""
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Deep Dream prompts — distill daily memories → MEMORY.md + dream diary
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
DREAM_SYSTEM_PROMPT = """你是一个记忆整理助手,负责定期整理用户的长期记忆。
|
||||
|
||||
你将收到两份材料:
|
||||
1. **当前长期记忆** — MEMORY.md 的全部现有内容
|
||||
2. **今日日记** — 当天的日常记录
|
||||
|
||||
MEMORY.md 会注入每次对话的系统提示词中,因此必须保持精炼,只存放有价值和值得记忆的内容。
|
||||
|
||||
**重要:只能基于提供的材料进行整理,严禁编造、推测或添加材料中不存在的信息。**
|
||||
|
||||
## 任务
|
||||
|
||||
### Part 1: 更新后的长期记忆([MEMORY])
|
||||
|
||||
在现有记忆基础上进行整理和提炼,输出完整的更新后内容:
|
||||
- **合并提炼**:将含义相近的多条合并为一条高密度表述,而非简单罗列
|
||||
- **新增萃取**:从今日日记中提取值得永久记住的新信息(偏好、决策、人物、规则、经验)
|
||||
- **冲突更新**:当新信息与旧条目矛盾时,以新信息为准,替换旧条目
|
||||
- **清理无效**:删除临时性记录、空白条目、格式残留、无意义、重复内容等
|
||||
- **删除冗余**:已被更精炼表述涵盖的旧条目应删除,避免信息重复
|
||||
- 每条一行,用 "- " 开头,不带日期前缀
|
||||
- 可用 "## 标题" 对相关条目分组,使结构更清晰
|
||||
- 目标:控制在 50 条以内,每条尽量一句话概括
|
||||
|
||||
### Part 2: 梦境日记([DREAM])
|
||||
|
||||
用简洁的叙事风格写一篇短日记,记录这次整理的发现,保持格式美观易读:
|
||||
- 发现了哪些重复或矛盾
|
||||
- 从日记中提取了什么新洞察
|
||||
- 做了哪些清理和优化
|
||||
- 整体感受和观察
|
||||
|
||||
## 输出格式(严格遵守)
|
||||
|
||||
```
|
||||
[MEMORY]
|
||||
- 记忆条目1
|
||||
- 记忆条目2
|
||||
...
|
||||
|
||||
[DREAM]
|
||||
梦境日记内容...
|
||||
```"""
|
||||
|
||||
DREAM_USER_PROMPT = """## 当前长期记忆(MEMORY.md)
|
||||
|
||||
{memory_content}
|
||||
|
||||
## 近期日记(最近 {days} 天)
|
||||
|
||||
{daily_content}"""
|
||||
|
||||
|
||||
|
||||
class MemoryFlushManager:
|
||||
"""
|
||||
Manages memory flush operations before context compaction
|
||||
Manages memory flush operations.
|
||||
|
||||
Similar to clawdbot's memory flush mechanism:
|
||||
- Triggers when context approaches token limit
|
||||
- Runs a silent agent turn to write memories to disk
|
||||
- Uses memory/YYYY-MM-DD.md for daily notes
|
||||
- Uses MEMORY.md (workspace root) for long-term curated memories
|
||||
Flush is triggered by agent_stream in two scenarios:
|
||||
1. Context trim: _trim_messages discards old turns → flush discarded content
|
||||
2. Context overflow: API rejects request → emergency flush before clearing
|
||||
|
||||
Additionally, create_daily_summary() can be called by scheduler for end-of-day summaries.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
workspace_dir: Path,
|
||||
llm_model: Optional[Any] = None
|
||||
llm_model: Optional[Any] = None,
|
||||
):
|
||||
"""
|
||||
Initialize memory flush manager
|
||||
|
||||
Args:
|
||||
workspace_dir: Workspace directory
|
||||
llm_model: LLM model for agent execution (optional)
|
||||
"""
|
||||
self.workspace_dir = workspace_dir
|
||||
self.llm_model = llm_model
|
||||
|
||||
self.memory_dir = workspace_dir / "memory"
|
||||
self.memory_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Tracking
|
||||
self.last_flush_token_count: Optional[int] = None
|
||||
self.last_flush_timestamp: Optional[datetime] = None
|
||||
self.turn_count: int = 0 # 对话轮数计数器
|
||||
self._trim_flushed_hashes: set = set() # Content hashes of already-flushed messages
|
||||
self._last_flushed_content_hash: str = "" # Content hash at last flush, for daily dedup
|
||||
self._last_dream_input_hash: str = "" # "{date}:{daily_hash}" of last dream, for dedup
|
||||
self._last_flush_thread: Optional[threading.Thread] = None
|
||||
|
||||
def should_flush(
|
||||
self,
|
||||
current_tokens: int = 0,
|
||||
token_threshold: int = 50000,
|
||||
turn_threshold: int = 20
|
||||
) -> bool:
|
||||
"""
|
||||
Determine if memory flush should be triggered
|
||||
|
||||
独立的 flush 触发机制,不依赖模型 context window:
|
||||
- Token 阈值: 达到 50K tokens 时触发
|
||||
- 轮次阈值: 达到 20 轮对话时触发
|
||||
|
||||
Args:
|
||||
current_tokens: Current session token count
|
||||
token_threshold: Token threshold to trigger flush (default: 50K)
|
||||
turn_threshold: Turn threshold to trigger flush (default: 20)
|
||||
|
||||
Returns:
|
||||
True if flush should run
|
||||
"""
|
||||
# 检查 token 阈值
|
||||
if current_tokens > 0 and current_tokens >= token_threshold:
|
||||
# 避免重复 flush
|
||||
if self.last_flush_token_count is not None:
|
||||
if current_tokens <= self.last_flush_token_count + 5000:
|
||||
return False
|
||||
return True
|
||||
|
||||
# 检查轮次阈值
|
||||
if self.turn_count >= turn_threshold:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def get_today_memory_file(self, user_id: Optional[str] = None) -> Path:
|
||||
"""
|
||||
Get today's memory file path: memory/YYYY-MM-DD.md
|
||||
|
||||
Args:
|
||||
user_id: Optional user ID for user-specific memory
|
||||
|
||||
Returns:
|
||||
Path to today's memory file
|
||||
"""
|
||||
def get_today_memory_file(self, user_id: Optional[str] = None, ensure_exists: bool = False) -> Path:
|
||||
"""Get today's memory file path: memory/YYYY-MM-DD.md"""
|
||||
today = datetime.now().strftime("%Y-%m-%d")
|
||||
|
||||
if user_id:
|
||||
user_dir = self.memory_dir / "users" / user_id
|
||||
user_dir.mkdir(parents=True, exist_ok=True)
|
||||
return user_dir / f"{today}.md"
|
||||
if ensure_exists:
|
||||
user_dir.mkdir(parents=True, exist_ok=True)
|
||||
today_file = user_dir / f"{today}.md"
|
||||
else:
|
||||
return self.memory_dir / f"{today}.md"
|
||||
today_file = self.memory_dir / f"{today}.md"
|
||||
|
||||
if ensure_exists and not today_file.exists():
|
||||
today_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
today_file.write_text(f"# Daily Memory: {today}\n\n")
|
||||
|
||||
return today_file
|
||||
|
||||
def get_main_memory_file(self, user_id: Optional[str] = None) -> Path:
|
||||
"""
|
||||
Get main memory file path: MEMORY.md (workspace root)
|
||||
|
||||
Args:
|
||||
user_id: Optional user ID for user-specific memory
|
||||
|
||||
Returns:
|
||||
Path to main memory file
|
||||
"""
|
||||
"""Get main memory file path: MEMORY.md (workspace root)"""
|
||||
if user_id:
|
||||
user_dir = self.memory_dir / "users" / user_id
|
||||
user_dir.mkdir(parents=True, exist_ok=True)
|
||||
return user_dir / "MEMORY.md"
|
||||
else:
|
||||
# Return workspace root MEMORY.md
|
||||
return Path(self.workspace_dir) / "MEMORY.md"
|
||||
|
||||
def create_flush_prompt(self) -> str:
|
||||
"""
|
||||
Create prompt for memory flush turn
|
||||
|
||||
Similar to clawdbot's DEFAULT_MEMORY_FLUSH_PROMPT
|
||||
"""
|
||||
today = datetime.now().strftime("%Y-%m-%d")
|
||||
return (
|
||||
f"Pre-compaction memory flush. "
|
||||
f"Store durable memories now (use memory/{today}.md for daily notes; "
|
||||
f"create memory/ if needed). "
|
||||
f"\n\n"
|
||||
f"重要提示:\n"
|
||||
f"- MEMORY.md: 记录最核心、最常用的信息(例如重要规则、偏好、决策、要求等)\n"
|
||||
f" 如果 MEMORY.md 过长,可以精简或移除不再重要的内容。避免冗长描述,用关键词和要点形式记录\n"
|
||||
f"- memory/{today}.md: 记录当天发生的事件、关键信息、经验教训、对话过程摘要等,突出重点\n"
|
||||
f"- 如果没有重要内容需要记录,回复 NO_REPLY\n"
|
||||
)
|
||||
|
||||
def create_flush_system_prompt(self) -> str:
|
||||
"""
|
||||
Create system prompt for memory flush turn
|
||||
|
||||
Similar to clawdbot's DEFAULT_MEMORY_FLUSH_SYSTEM_PROMPT
|
||||
"""
|
||||
return (
|
||||
"Pre-compaction memory flush turn. "
|
||||
"The session is near auto-compaction; capture durable memories to disk. "
|
||||
"\n\n"
|
||||
"记忆写入原则:\n"
|
||||
"1. MEMORY.md 精简原则: 只记录核心信息(<2000 tokens)\n"
|
||||
" - 记录重要规则、偏好、决策、要求等需要长期记住的关键信息,无需记录过多细节\n"
|
||||
" - 如果 MEMORY.md 过长,可以根据需要精简或删除过时内容\n"
|
||||
"\n"
|
||||
"2. 天级记忆 (memory/YYYY-MM-DD.md):\n"
|
||||
" - 记录当天的重要事件、关键信息、经验教训、对话过程摘要等,确保核心信息点被完整记录\n"
|
||||
"\n"
|
||||
"3. 判断标准:\n"
|
||||
" - 这个信息未来会经常用到吗?→ MEMORY.md\n"
|
||||
" - 这是今天的重要事件或决策吗?→ memory/YYYY-MM-DD.md\n"
|
||||
" - 这是临时性的、不重要的内容吗?→ 不记录\n"
|
||||
"\n"
|
||||
"You may reply, but usually NO_REPLY is correct."
|
||||
)
|
||||
|
||||
async def execute_flush(
|
||||
self,
|
||||
agent_executor: Callable,
|
||||
current_tokens: int,
|
||||
user_id: Optional[str] = None,
|
||||
**executor_kwargs
|
||||
) -> bool:
|
||||
"""
|
||||
Execute memory flush by running a silent agent turn
|
||||
|
||||
Args:
|
||||
agent_executor: Function to execute agent with prompt
|
||||
current_tokens: Current token count
|
||||
user_id: Optional user ID
|
||||
**executor_kwargs: Additional kwargs for agent executor
|
||||
|
||||
Returns:
|
||||
True if flush completed successfully
|
||||
"""
|
||||
try:
|
||||
# Create flush prompts
|
||||
prompt = self.create_flush_prompt()
|
||||
system_prompt = self.create_flush_system_prompt()
|
||||
|
||||
# Execute agent turn (silent, no user-visible reply expected)
|
||||
await agent_executor(
|
||||
prompt=prompt,
|
||||
system_prompt=system_prompt,
|
||||
silent=True, # NO_REPLY expected
|
||||
**executor_kwargs
|
||||
)
|
||||
|
||||
# Track flush
|
||||
self.last_flush_token_count = current_tokens
|
||||
self.last_flush_timestamp = datetime.now()
|
||||
self.turn_count = 0 # 重置轮数计数器
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"Memory flush failed: {e}")
|
||||
return False
|
||||
|
||||
def increment_turn(self):
|
||||
"""增加对话轮数计数"""
|
||||
self.turn_count += 1
|
||||
|
||||
def get_status(self) -> dict:
|
||||
"""Get memory flush status"""
|
||||
return {
|
||||
'last_flush_tokens': self.last_flush_token_count,
|
||||
'last_flush_time': self.last_flush_timestamp.isoformat() if self.last_flush_timestamp else None,
|
||||
'today_file': str(self.get_today_memory_file()),
|
||||
'main_file': str(self.get_main_memory_file())
|
||||
}
|
||||
|
||||
# ---- Flush execution (called by agent_stream or scheduler) ----
|
||||
|
||||
def flush_from_messages(
|
||||
self,
|
||||
messages: List[Dict],
|
||||
user_id: Optional[str] = None,
|
||||
reason: str = "trim",
|
||||
max_messages: int = 0,
|
||||
context_summary_callback: Optional[Callable[[str], None]] = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Asynchronously summarize and flush messages to daily memory.
|
||||
|
||||
Deduplication runs synchronously, then LLM summarization + file write
|
||||
run in a background thread so the main reply flow is never blocked.
|
||||
|
||||
If *context_summary_callback* is provided, it is called with the
|
||||
[DAILY] portion of the LLM summary once available. The caller can use
|
||||
this to inject the summary into the live message list for context
|
||||
continuity — one LLM call serves both disk persistence and in-context
|
||||
injection.
|
||||
"""
|
||||
try:
|
||||
# Strip scheduler-injected pairs before any further processing.
|
||||
# These messages already serve as short-term context inside the
|
||||
# receiver session; promoting them into long-term daily memory
|
||||
# produces low-value flat logs (e.g. "11:28 price=1013, normal /
|
||||
# 11:58 price=1013, normal / ...") and wastes summarisation tokens.
|
||||
messages = self._strip_scheduler_pairs(messages)
|
||||
if not messages:
|
||||
return False
|
||||
|
||||
import hashlib
|
||||
deduped = []
|
||||
for m in messages:
|
||||
text = self._extract_text_from_content(m.get("content", ""))
|
||||
if not text or not text.strip():
|
||||
continue
|
||||
h = hashlib.md5(text.encode("utf-8")).hexdigest()
|
||||
if h not in self._trim_flushed_hashes:
|
||||
self._trim_flushed_hashes.add(h)
|
||||
deduped.append(m)
|
||||
if not deduped:
|
||||
return False
|
||||
|
||||
import copy
|
||||
snapshot = copy.deepcopy(deduped)
|
||||
thread = threading.Thread(
|
||||
target=self._flush_worker,
|
||||
args=(snapshot, user_id, reason, max_messages, context_summary_callback),
|
||||
daemon=True,
|
||||
)
|
||||
thread.start()
|
||||
logger.info(f"[MemoryFlush] Async flush dispatched (reason={reason}, msgs={len(snapshot)})")
|
||||
self._last_flush_thread = thread
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"[MemoryFlush] Failed to dispatch flush (reason={reason}): {e}")
|
||||
return False
|
||||
|
||||
def _flush_worker(
|
||||
self,
|
||||
messages: List[Dict],
|
||||
user_id: Optional[str],
|
||||
reason: str,
|
||||
max_messages: int,
|
||||
context_summary_callback: Optional[Callable[[str], None]] = None,
|
||||
):
|
||||
"""Background worker: summarize with LLM, write daily memory file."""
|
||||
try:
|
||||
raw_summary = self._summarize_messages(messages, max_messages)
|
||||
if not raw_summary or not raw_summary.strip() or raw_summary.strip() == "无":
|
||||
logger.info(f"[MemoryFlush] No valuable content to flush (reason={reason})")
|
||||
return
|
||||
|
||||
# Strip legacy [DAILY]/[MEMORY] markers if model still outputs them
|
||||
daily_part = self._clean_summary_output(raw_summary)
|
||||
if not daily_part:
|
||||
return
|
||||
|
||||
# --- Write daily memory ---
|
||||
daily_file = ensure_daily_memory_file(self.workspace_dir, user_id)
|
||||
|
||||
headers = {
|
||||
"overflow": f"## Context Overflow Recovery ({datetime.now().strftime('%H:%M')})",
|
||||
"trim": f"## Trimmed Context ({datetime.now().strftime('%H:%M')})",
|
||||
"daily_summary": f"## Daily Summary ({datetime.now().strftime('%H:%M')})",
|
||||
}
|
||||
header = headers.get(reason, f"## Session Notes ({datetime.now().strftime('%H:%M')})")
|
||||
|
||||
with open(daily_file, "a", encoding="utf-8") as f:
|
||||
f.write(f"\n{header}\n\n{daily_part}\n")
|
||||
|
||||
logger.info(f"[MemoryFlush] Wrote daily memory to {daily_file.name} (reason={reason}, chars={len(daily_part)})")
|
||||
|
||||
# --- Inject context summary into live messages (if callback provided) ---
|
||||
if context_summary_callback:
|
||||
try:
|
||||
context_summary_callback(daily_part)
|
||||
except Exception as e:
|
||||
logger.warning(f"[MemoryFlush] Context summary callback failed: {e}")
|
||||
|
||||
self.last_flush_timestamp = datetime.now()
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"[MemoryFlush] Async flush failed (reason={reason}): {e}")
|
||||
|
||||
@staticmethod
|
||||
def _clean_summary_output(raw: str) -> str:
|
||||
"""Strip legacy [DAILY]/[MEMORY] markers if present, return clean daily text."""
|
||||
raw = raw.strip()
|
||||
if not raw or raw == "无":
|
||||
return ""
|
||||
|
||||
# Strip [DAILY] marker
|
||||
if "[DAILY]" in raw:
|
||||
start = raw.index("[DAILY]") + len("[DAILY]")
|
||||
end = raw.index("[MEMORY]") if "[MEMORY]" in raw else len(raw)
|
||||
raw = raw[start:end].strip()
|
||||
|
||||
# Remove stray [MEMORY] section entirely
|
||||
if "[MEMORY]" in raw:
|
||||
raw = raw[:raw.index("[MEMORY]")].strip()
|
||||
|
||||
# Remove markdown code fences
|
||||
raw = raw.replace("```", "").strip()
|
||||
|
||||
return raw
|
||||
|
||||
def create_daily_summary(
|
||||
self,
|
||||
messages: List[Dict],
|
||||
user_id: Optional[str] = None
|
||||
) -> bool:
|
||||
"""
|
||||
Generate end-of-day summary. Called by daily timer.
|
||||
Skips if messages haven't changed since last flush.
|
||||
"""
|
||||
import hashlib
|
||||
content = "".join(
|
||||
self._extract_text_from_content(m.get("content", ""))
|
||||
for m in messages
|
||||
)
|
||||
content_hash = hashlib.md5(content.encode("utf-8")).hexdigest()
|
||||
if content_hash == self._last_flushed_content_hash:
|
||||
logger.debug("[MemoryFlush] Daily summary skipped: no new content since last flush")
|
||||
return False
|
||||
self._last_flushed_content_hash = content_hash
|
||||
return self.flush_from_messages(
|
||||
messages=messages,
|
||||
user_id=user_id,
|
||||
reason="daily_summary",
|
||||
max_messages=0,
|
||||
)
|
||||
|
||||
# ---- Deep Dream (memory distillation) ----
|
||||
|
||||
def deep_dream(self, user_id: Optional[str] = None, lookback_days: int = 1, force: bool = False) -> bool:
|
||||
"""
|
||||
Distill recent daily memories into MEMORY.md and generate a dream diary.
|
||||
|
||||
Args:
|
||||
lookback_days: How many days of daily files to read (default 1 for scheduled, 3 for manual)
|
||||
force: Skip input-hash dedup check (used by manual /memory dream trigger)
|
||||
"""
|
||||
if not self.llm_model:
|
||||
logger.warning("[DeepDream] No LLM model available, skipping")
|
||||
return False
|
||||
|
||||
logger.info(f"[DeepDream] Starting memory distillation (lookback={lookback_days} days)")
|
||||
|
||||
# Collect materials
|
||||
memory_content = self._read_main_memory(user_id)
|
||||
daily_content, has_content = self._read_recent_dailies(user_id, lookback_days)
|
||||
|
||||
if not has_content:
|
||||
logger.info("[DeepDream] No recent daily records, skipping to preserve existing MEMORY.md")
|
||||
return False
|
||||
|
||||
# Dedup: skip if same daily content already dreamed today.
|
||||
# Note: only hash daily_content (not memory_content), because deep_dream
|
||||
# itself rewrites MEMORY.md as a side effect, which would otherwise
|
||||
# invalidate the hash on every subsequent call within the same window.
|
||||
import hashlib
|
||||
daily_hash = hashlib.md5(daily_content.encode("utf-8")).hexdigest()
|
||||
today_str = datetime.now().strftime("%Y-%m-%d")
|
||||
dedup_key = f"{today_str}:{daily_hash}"
|
||||
if not force and dedup_key == self._last_dream_input_hash:
|
||||
logger.info("[DeepDream] Already dreamed today with same daily content, skipping")
|
||||
return False
|
||||
self._last_dream_input_hash = dedup_key
|
||||
|
||||
logger.info(
|
||||
f"[DeepDream] Materials collected: "
|
||||
f"MEMORY.md={len(memory_content)} chars, "
|
||||
f"daily={len(daily_content)} chars"
|
||||
)
|
||||
|
||||
# Call LLM for distillation
|
||||
import time as _time
|
||||
t0 = _time.monotonic()
|
||||
try:
|
||||
user_msg = DREAM_USER_PROMPT.format(
|
||||
memory_content=memory_content or "(empty)",
|
||||
days=lookback_days,
|
||||
daily_content=daily_content or "(no recent daily records)",
|
||||
)
|
||||
from agent.protocol.models import LLMRequest
|
||||
# Scale max_tokens based on input size to avoid truncating large MEMORY.md
|
||||
input_chars = len(memory_content) + len(daily_content)
|
||||
dream_max_tokens = max(2000, min(input_chars, 8000))
|
||||
request = LLMRequest(
|
||||
messages=[{"role": "user", "content": user_msg}],
|
||||
temperature=0.3,
|
||||
max_tokens=dream_max_tokens,
|
||||
stream=False,
|
||||
system=DREAM_SYSTEM_PROMPT,
|
||||
)
|
||||
response = self.llm_model.call(request)
|
||||
raw = self._extract_response_text(response)
|
||||
elapsed = _time.monotonic() - t0
|
||||
if not raw or not raw.strip():
|
||||
logger.warning(f"[DeepDream] LLM returned empty response ({elapsed:.1f}s)")
|
||||
return False
|
||||
logger.info(f"[DeepDream] LLM distillation completed ({elapsed:.1f}s, {len(raw)} chars)")
|
||||
except Exception as e:
|
||||
elapsed = _time.monotonic() - t0
|
||||
logger.warning(f"[DeepDream] LLM call failed ({elapsed:.1f}s): {e}")
|
||||
return False
|
||||
|
||||
# Parse [MEMORY] and [DREAM] sections
|
||||
new_memory, dream_diary = self._parse_dream_output(raw)
|
||||
|
||||
if not new_memory:
|
||||
logger.warning("[DeepDream] No [MEMORY] section in LLM output, skipping overwrite")
|
||||
return False
|
||||
|
||||
# Overwrite MEMORY.md
|
||||
try:
|
||||
main_file = self.get_main_memory_file(user_id)
|
||||
old_size = len(memory_content)
|
||||
main_file.write_text(new_memory + "\n", encoding="utf-8")
|
||||
logger.info(
|
||||
f"[DeepDream] Updated MEMORY.md "
|
||||
f"({old_size} → {len(new_memory)} chars)"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"[DeepDream] Failed to write MEMORY.md: {e}")
|
||||
return False
|
||||
|
||||
# Write dream diary
|
||||
if dream_diary:
|
||||
try:
|
||||
self._write_dream_diary(dream_diary, user_id)
|
||||
except Exception as e:
|
||||
logger.warning(f"[DeepDream] Failed to write dream diary: {e}")
|
||||
|
||||
logger.info("[DeepDream] ✅ Deep Dream completed successfully")
|
||||
return True
|
||||
|
||||
def _read_main_memory(self, user_id: Optional[str] = None) -> str:
|
||||
"""Read current MEMORY.md content."""
|
||||
main_file = self.get_main_memory_file(user_id)
|
||||
if main_file.exists():
|
||||
return main_file.read_text(encoding="utf-8").strip()
|
||||
return ""
|
||||
|
||||
def _read_recent_dailies(
|
||||
self, user_id: Optional[str] = None, lookback_days: int = 1
|
||||
) -> tuple:
|
||||
"""
|
||||
Read recent daily memory files.
|
||||
|
||||
Returns:
|
||||
(combined_text, has_content) tuple
|
||||
"""
|
||||
from datetime import timedelta
|
||||
|
||||
parts = []
|
||||
has_content = False
|
||||
today = datetime.now().date()
|
||||
|
||||
for offset in range(lookback_days):
|
||||
day = today - timedelta(days=offset)
|
||||
date_str = day.strftime("%Y-%m-%d")
|
||||
if user_id:
|
||||
daily_file = self.memory_dir / "users" / user_id / f"{date_str}.md"
|
||||
else:
|
||||
daily_file = self.memory_dir / f"{date_str}.md"
|
||||
|
||||
if daily_file.exists():
|
||||
content = daily_file.read_text(encoding="utf-8").strip()
|
||||
if content:
|
||||
parts.append(f"### {date_str}\n\n{content}")
|
||||
has_content = True
|
||||
else:
|
||||
parts.append(f"### {date_str}\n\n(no records)")
|
||||
|
||||
return "\n\n".join(parts), has_content
|
||||
|
||||
@staticmethod
|
||||
def _parse_dream_output(raw: str) -> tuple:
|
||||
"""Parse LLM output into (new_memory, dream_diary)."""
|
||||
raw = raw.strip().replace("```", "")
|
||||
new_memory = ""
|
||||
dream_diary = ""
|
||||
|
||||
if "[MEMORY]" in raw:
|
||||
start = raw.index("[MEMORY]") + len("[MEMORY]")
|
||||
end = raw.index("[DREAM]") if "[DREAM]" in raw else len(raw)
|
||||
new_memory = raw[start:end].strip()
|
||||
|
||||
if "[DREAM]" in raw:
|
||||
start = raw.index("[DREAM]") + len("[DREAM]")
|
||||
dream_diary = raw[start:].strip()
|
||||
|
||||
return new_memory, dream_diary
|
||||
|
||||
def _write_dream_diary(self, content: str, user_id: Optional[str] = None):
|
||||
"""Write dream diary to memory/dreams/YYYY-MM-DD.md."""
|
||||
dreams_dir = self.memory_dir / "dreams"
|
||||
if user_id:
|
||||
dreams_dir = self.memory_dir / "users" / user_id / "dreams"
|
||||
dreams_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
today = datetime.now().strftime("%Y-%m-%d")
|
||||
diary_file = dreams_dir / f"{today}.md"
|
||||
diary_file.write_text(
|
||||
f"# Dream Diary: {today}\n\n{content}\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
logger.info(f"[DeepDream] Wrote dream diary to {diary_file}")
|
||||
|
||||
# ---- Internal helpers ----
|
||||
|
||||
def _summarize_messages(self, messages: List[Dict], max_messages: int = 0) -> str:
|
||||
"""
|
||||
Summarize conversation messages using LLM.
|
||||
Returns empty string if LLM deems content not worth recording.
|
||||
Rule-based fallback only used when LLM call raises an exception.
|
||||
"""
|
||||
conversation_text = self._format_conversation_for_summary(messages, max_messages)
|
||||
if not conversation_text.strip():
|
||||
return ""
|
||||
|
||||
if self.llm_model:
|
||||
try:
|
||||
summary = self._call_llm_for_summary(conversation_text)
|
||||
if summary and summary.strip() and summary.strip() != "无":
|
||||
return summary.strip()
|
||||
logger.info("[MemoryFlush] LLM returned empty or '无', skipping write")
|
||||
return ""
|
||||
except Exception as e:
|
||||
logger.warning(f"[MemoryFlush] LLM summarization failed, using fallback: {e}")
|
||||
return self._extract_summary_fallback(messages, max_messages)
|
||||
else:
|
||||
logger.info("[MemoryFlush] No LLM model available, using rule-based fallback")
|
||||
return self._extract_summary_fallback(messages, max_messages)
|
||||
|
||||
def _format_conversation_for_summary(self, messages: List[Dict], max_messages: int = 0) -> str:
|
||||
"""Format messages into readable conversation text for LLM summarization."""
|
||||
msgs = messages if max_messages == 0 else messages[-max_messages * 2:]
|
||||
lines = []
|
||||
for msg in msgs:
|
||||
role = msg.get("role", "")
|
||||
text = self._extract_text_from_content(msg.get("content", ""))
|
||||
if not text or not text.strip():
|
||||
continue
|
||||
text = text.strip()
|
||||
if role == "user":
|
||||
lines.append(f"用户: {text[:500]}")
|
||||
elif role == "assistant":
|
||||
lines.append(f"助手: {text[:500]}")
|
||||
return "\n".join(lines)
|
||||
|
||||
@staticmethod
|
||||
def _extract_response_text(response) -> str:
|
||||
"""
|
||||
Extract text from LLM response regardless of format.
|
||||
|
||||
Handles:
|
||||
- Generator (MiniMax _handle_sync_response yields Claude-format dicts)
|
||||
- Claude format: {"role":"assistant","content":[{"type":"text","text":"..."}]}
|
||||
- OpenAI format: {"choices":[{"message":{"content":"..."}}]}
|
||||
- OpenAI SDK response object with .choices attribute
|
||||
"""
|
||||
import types
|
||||
|
||||
# Unwrap generator — consume first yielded item
|
||||
if isinstance(response, types.GeneratorType):
|
||||
try:
|
||||
response = next(response)
|
||||
except StopIteration:
|
||||
return ""
|
||||
|
||||
if not response:
|
||||
return ""
|
||||
|
||||
if isinstance(response, dict):
|
||||
# Check for error
|
||||
if response.get("error"):
|
||||
raise RuntimeError(response.get("message", "LLM call failed"))
|
||||
|
||||
# Claude format: content is a list of blocks
|
||||
content = response.get("content")
|
||||
if isinstance(content, list):
|
||||
for block in content:
|
||||
if isinstance(block, dict) and block.get("type") == "text":
|
||||
return block.get("text", "")
|
||||
|
||||
# OpenAI format
|
||||
choices = response.get("choices", [])
|
||||
if choices:
|
||||
return choices[0].get("message", {}).get("content", "")
|
||||
|
||||
# OpenAI SDK response object
|
||||
if hasattr(response, "choices") and response.choices:
|
||||
return response.choices[0].message.content or ""
|
||||
|
||||
return ""
|
||||
|
||||
def _call_llm_for_summary(self, conversation_text: str) -> str:
|
||||
"""Call LLM to generate a concise summary of the conversation."""
|
||||
from agent.protocol.models import LLMRequest
|
||||
|
||||
request = LLMRequest(
|
||||
messages=[{"role": "user", "content": SUMMARIZE_USER_PROMPT.format(conversation=conversation_text)}],
|
||||
temperature=0,
|
||||
max_tokens=500,
|
||||
stream=False,
|
||||
system=SUMMARIZE_SYSTEM_PROMPT,
|
||||
)
|
||||
|
||||
response = self.llm_model.call(request)
|
||||
return self._extract_response_text(response)
|
||||
|
||||
@staticmethod
|
||||
def _extract_first_meaningful_line(text: str, max_len: int = 120) -> str:
|
||||
"""Extract the first meaningful line from assistant reply, skipping markdown noise."""
|
||||
import re
|
||||
for line in text.split("\n"):
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
# Skip markdown headings, horizontal rules, code fences, pure emoji/symbols
|
||||
if re.match(r'^(#{1,4}\s|```|---|\*\*\*|[-*]\s*$|[^\w\u4e00-\u9fff]{1,5}$)', line):
|
||||
continue
|
||||
# Strip leading markdown bold/emoji decorations
|
||||
cleaned = re.sub(r'^[\*#>\-\s]+', '', line).strip()
|
||||
cleaned = re.sub(r'^[\U0001f300-\U0001f9ff\u2600-\u27bf\s]+', '', cleaned).strip()
|
||||
if len(cleaned) >= 5:
|
||||
return cleaned[:max_len]
|
||||
return text.split("\n")[0].strip()[:max_len]
|
||||
|
||||
@staticmethod
|
||||
def _extract_summary_fallback(messages: List[Dict], max_messages: int = 0) -> str:
|
||||
"""
|
||||
Rule-based summary of discarded messages.
|
||||
Format: "用户问了X; 助手回答了Y" per event, compact and readable.
|
||||
"""
|
||||
msgs = messages if max_messages == 0 else messages[-max_messages * 2:]
|
||||
|
||||
events: List[str] = []
|
||||
current_user_text = ""
|
||||
for msg in msgs:
|
||||
role = msg.get("role", "")
|
||||
text = MemoryFlushManager._extract_text_from_content(msg.get("content", ""))
|
||||
if not text or not text.strip():
|
||||
continue
|
||||
text = text.strip()
|
||||
|
||||
if role == "user":
|
||||
if len(text) <= 3:
|
||||
continue
|
||||
current_user_text = text[:120]
|
||||
elif role == "assistant" and current_user_text:
|
||||
reply_summary = MemoryFlushManager._extract_first_meaningful_line(text)
|
||||
if reply_summary:
|
||||
events.append(f"- 用户: {current_user_text} → 回复: {reply_summary}")
|
||||
else:
|
||||
events.append(f"- 用户: {current_user_text}")
|
||||
current_user_text = ""
|
||||
|
||||
if current_user_text:
|
||||
events.append(f"- 用户: {current_user_text}")
|
||||
|
||||
return "\n".join(events[:10])
|
||||
|
||||
@staticmethod
|
||||
def _extract_text_from_content(content) -> str:
|
||||
"""Extract plain text from message content (string or content blocks)."""
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
if isinstance(content, list):
|
||||
parts = []
|
||||
for block in content:
|
||||
if isinstance(block, dict) and block.get("type") == "text":
|
||||
parts.append(block.get("text", ""))
|
||||
elif isinstance(block, str):
|
||||
parts.append(block)
|
||||
return "\n".join(parts)
|
||||
return ""
|
||||
|
||||
@classmethod
|
||||
def _strip_scheduler_pairs(cls, messages: List[Dict]) -> List[Dict]:
|
||||
"""Drop scheduler-injected user/assistant pairs from a flush batch.
|
||||
|
||||
A scheduler user message starts with the ``[SCHEDULED]`` marker
|
||||
(written by ``AgentBridge.remember_scheduled_output``); the message
|
||||
immediately following it (if it is an assistant turn) is its paired
|
||||
output and is dropped together. Regular user/assistant turns and
|
||||
any tool_use / tool_result blocks are preserved as-is.
|
||||
"""
|
||||
if not messages:
|
||||
return messages
|
||||
|
||||
SCHEDULED_PREFIX = "[SCHEDULED]"
|
||||
result = []
|
||||
skip_next_assistant = False
|
||||
for msg in messages:
|
||||
if not isinstance(msg, dict):
|
||||
result.append(msg)
|
||||
skip_next_assistant = False
|
||||
continue
|
||||
role = msg.get("role")
|
||||
if skip_next_assistant and role == "assistant":
|
||||
skip_next_assistant = False
|
||||
continue
|
||||
skip_next_assistant = False
|
||||
if role == "user":
|
||||
text = cls._extract_text_from_content(msg.get("content", ""))
|
||||
if text.lstrip().startswith(SCHEDULED_PREFIX):
|
||||
skip_next_assistant = True
|
||||
continue
|
||||
result.append(msg)
|
||||
return result
|
||||
|
||||
|
||||
def create_memory_files_if_needed(workspace_dir: Path, user_id: Optional[str] = None):
|
||||
"""
|
||||
Create default memory files if they don't exist
|
||||
Create essential memory files if they don't exist.
|
||||
Only creates MEMORY.md; daily files are created lazily on first write.
|
||||
|
||||
Args:
|
||||
workspace_dir: Workspace directory
|
||||
@@ -228,7 +703,7 @@ def create_memory_files_if_needed(workspace_dir: Path, user_id: Optional[str] =
|
||||
memory_dir = workspace_dir / "memory"
|
||||
memory_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Create main MEMORY.md in workspace root
|
||||
# Create main MEMORY.md in workspace root (always needed for bootstrap)
|
||||
if user_id:
|
||||
user_dir = memory_dir / "users" / user_id
|
||||
user_dir.mkdir(parents=True, exist_ok=True)
|
||||
@@ -237,14 +712,28 @@ def create_memory_files_if_needed(workspace_dir: Path, user_id: Optional[str] =
|
||||
main_memory = Path(workspace_dir) / "MEMORY.md"
|
||||
|
||||
if not main_memory.exists():
|
||||
# Create empty file or with minimal structure (no obvious "Memory" header)
|
||||
# Following clawdbot's approach: memories should blend naturally into context
|
||||
main_memory.write_text("")
|
||||
|
||||
|
||||
def ensure_daily_memory_file(workspace_dir: Path, user_id: Optional[str] = None) -> Path:
|
||||
"""
|
||||
Ensure today's daily memory file exists, creating it only when actually needed.
|
||||
Called lazily before first write to daily memory.
|
||||
|
||||
Args:
|
||||
workspace_dir: Workspace directory
|
||||
user_id: Optional user ID for user-specific files
|
||||
|
||||
Returns:
|
||||
Path to today's memory file
|
||||
"""
|
||||
memory_dir = workspace_dir / "memory"
|
||||
memory_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Create today's memory file
|
||||
today = datetime.now().strftime("%Y-%m-%d")
|
||||
if user_id:
|
||||
user_dir = memory_dir / "users" / user_id
|
||||
user_dir.mkdir(parents=True, exist_ok=True)
|
||||
today_memory = user_dir / f"{today}.md"
|
||||
else:
|
||||
today_memory = memory_dir / f"{today}.md"
|
||||
@@ -252,5 +741,6 @@ def create_memory_files_if_needed(workspace_dir: Path, user_id: Optional[str] =
|
||||
if not today_memory.exists():
|
||||
today_memory.write_text(
|
||||
f"# Daily Memory: {today}\n\n"
|
||||
f"Day-to-day notes and running context.\n\n"
|
||||
)
|
||||
|
||||
return today_memory
|
||||
|
||||
@@ -10,6 +10,7 @@ from typing import List, Dict, Optional, Any
|
||||
from dataclasses import dataclass
|
||||
|
||||
from common.log import logger
|
||||
from config import conf
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -42,7 +43,6 @@ class PromptBuilder:
|
||||
skill_manager: Any = None,
|
||||
memory_manager: Any = None,
|
||||
runtime_info: Optional[Dict[str, Any]] = None,
|
||||
is_first_conversation: bool = False,
|
||||
**kwargs
|
||||
) -> str:
|
||||
"""
|
||||
@@ -52,11 +52,10 @@ class PromptBuilder:
|
||||
base_persona: 基础人格描述(会被context_files中的AGENT.md覆盖)
|
||||
user_identity: 用户身份信息
|
||||
tools: 工具列表
|
||||
context_files: 上下文文件列表(AGENT.md, USER.md, RULE.md等)
|
||||
context_files: 上下文文件列表(AGENT.md, USER.md, RULE.md, BOOTSTRAP.md等)
|
||||
skill_manager: 技能管理器
|
||||
memory_manager: 记忆管理器
|
||||
runtime_info: 运行时信息
|
||||
is_first_conversation: 是否为首次对话
|
||||
**kwargs: 其他参数
|
||||
|
||||
Returns:
|
||||
@@ -72,7 +71,6 @@ class PromptBuilder:
|
||||
skill_manager=skill_manager,
|
||||
memory_manager=memory_manager,
|
||||
runtime_info=runtime_info,
|
||||
is_first_conversation=is_first_conversation,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
@@ -87,7 +85,6 @@ def build_agent_system_prompt(
|
||||
skill_manager: Any = None,
|
||||
memory_manager: Any = None,
|
||||
runtime_info: Optional[Dict[str, Any]] = None,
|
||||
is_first_conversation: bool = False,
|
||||
**kwargs
|
||||
) -> str:
|
||||
"""
|
||||
@@ -96,10 +93,11 @@ def build_agent_system_prompt(
|
||||
顺序说明(按重要性和逻辑关系排列):
|
||||
1. 工具系统 - 核心能力,最先介绍
|
||||
2. 技能系统 - 紧跟工具,因为技能需要用 read 工具读取
|
||||
3. 记忆系统 - 独立的记忆能力
|
||||
3. 记忆系统 - 记忆检索与写入引导
|
||||
3.5 知识系统 - 结构化知识库(knowledge/index.md 注入)
|
||||
4. 工作空间 - 工作环境说明
|
||||
5. 用户身份 - 用户信息(可选)
|
||||
6. 项目上下文 - AGENT.md, USER.md, RULE.md(定义人格、身份、规则)
|
||||
6. 项目上下文 - AGENT.md, USER.md, RULE.md, MEMORY.md, BOOTSTRAP.md
|
||||
7. 运行时信息 - 元信息(时间、模型等)
|
||||
|
||||
Args:
|
||||
@@ -112,7 +110,6 @@ def build_agent_system_prompt(
|
||||
skill_manager: 技能管理器
|
||||
memory_manager: 记忆管理器
|
||||
runtime_info: 运行时信息
|
||||
is_first_conversation: 是否为首次对话
|
||||
**kwargs: 其他参数
|
||||
|
||||
Returns:
|
||||
@@ -131,9 +128,13 @@ def build_agent_system_prompt(
|
||||
# 3. 记忆系统(独立的记忆能力)
|
||||
if memory_manager:
|
||||
sections.extend(_build_memory_section(memory_manager, tools, language))
|
||||
|
||||
# 3.5 知识系统(结构化知识库)
|
||||
if conf().get("knowledge", True):
|
||||
sections.extend(_build_knowledge_section(workspace_dir, language))
|
||||
|
||||
# 4. 工作空间(工作环境说明)
|
||||
sections.extend(_build_workspace_section(workspace_dir, language, is_first_conversation))
|
||||
sections.extend(_build_workspace_section(workspace_dir, language))
|
||||
|
||||
# 5. 用户身份(如果有)
|
||||
if user_identity:
|
||||
@@ -170,12 +171,13 @@ def _build_tooling_section(tools: List[Any], language: str) -> List[str]:
|
||||
"terminal": "管理后台进程",
|
||||
"web_search": "网络搜索",
|
||||
"web_fetch": "获取URL内容",
|
||||
"browser": "控制浏览器",
|
||||
"browser": "控制浏览器(关键结果或需要协助可截图发送给用户)",
|
||||
"memory_search": "搜索记忆",
|
||||
"memory_get": "读取记忆内容",
|
||||
"env_config": "管理API密钥和技能配置",
|
||||
"scheduler": "管理定时任务和提醒",
|
||||
"send": "发送文件给用户",
|
||||
"send": "发送本地文件给用户(仅限本地文件,URL直接放在回复文本中)",
|
||||
"vision": "分析图片内容(识别、描述、OCR文字提取等)",
|
||||
}
|
||||
|
||||
# Preferred display order
|
||||
@@ -184,7 +186,7 @@ def _build_tooling_section(tools: List[Any], language: str) -> List[str]:
|
||||
"bash", "terminal",
|
||||
"web_search", "web_fetch", "browser",
|
||||
"memory_search", "memory_get",
|
||||
"env_config", "scheduler", "send",
|
||||
"env_config", "scheduler", "send", "vision",
|
||||
]
|
||||
|
||||
# Build name -> summary mapping for available tools
|
||||
@@ -204,16 +206,17 @@ def _build_tooling_section(tools: List[Any], language: str) -> List[str]:
|
||||
tool_lines.append(f"- {name}: {summary}" if summary else f"- {name}")
|
||||
|
||||
lines = [
|
||||
"## 工具系统",
|
||||
"## 🔧 工具系统",
|
||||
"",
|
||||
"可用工具(名称大小写敏感,严格按列表调用):",
|
||||
"\n".join(tool_lines),
|
||||
"",
|
||||
"工具调用风格:",
|
||||
"",
|
||||
"- 在多步骤任务、敏感操作或用户要求时简要解释决策过程",
|
||||
"- 持续推进直到任务完成,完成后向用户报告结果。",
|
||||
"- 回复中涉及密钥、令牌等敏感信息必须脱敏。",
|
||||
"- 多步骤任务、复杂决策、敏感操作时,应简要说明当前在做什么、为什么这样做,让用户了解关键进展",
|
||||
"- 持续推进直到任务完成,完成后向用户报告结果",
|
||||
"- 回复中涉及密钥、令牌等敏感信息必须脱敏",
|
||||
"- URL链接直接放在回复文本中即可,系统会自动处理和渲染。无需下载后使用send工具发送",
|
||||
"",
|
||||
]
|
||||
|
||||
@@ -235,15 +238,17 @@ def _build_skills_section(skill_manager: Any, tools: Optional[List[Any]], langua
|
||||
break
|
||||
|
||||
lines = [
|
||||
"## 技能系统(mandatory)",
|
||||
"## 🧩 技能系统(mandatory)",
|
||||
"",
|
||||
"在回复之前:扫描下方 <available_skills> 中的 <description> 条目。",
|
||||
"在回复之前:扫描下方 <available_skills> 中每个技能的 <description>。",
|
||||
"",
|
||||
f"- 如果恰好有一个技能(Skill)明确适用:使用 `{read_tool_name}` 读取其 <location> 处的 SKILL.md,然后严格遵循它",
|
||||
"- 如果多个技能都适用则选择最匹配的一个,如果没有明确适用的则不要读取任何 SKILL.md",
|
||||
"- 读取 SKILL.md 后直接按其指令执行,无需多余的预检查",
|
||||
f"- 如果有技能的描述与用户需求匹配:使用 `{read_tool_name}` 工具读取其 <location> 路径的 SKILL.md 文件,然后严格遵循文件中的指令。"
|
||||
"当有匹配的技能时,应优先使用技能",
|
||||
"- 如果多个技能都适用则选择最匹配的一个,然后读取并遵循。",
|
||||
"- 如果没有技能明确适用:不要读取任何 SKILL.md,直接使用通用工具。",
|
||||
"",
|
||||
"**注意**: 永远不要一次性读取多个技能,只在选择后再读取。技能和工具不同,必须先读取其SKILL.md并按照文件内容运行。",
|
||||
f"**重要**: 技能不是工具,不能直接调用。使用技能的唯一方式是用 `{read_tool_name}` 读取 SKILL.md 文件,然后按文件内容操作。"
|
||||
"永远不要一次性读取多个技能,只在选择后再读取。",
|
||||
"",
|
||||
"以下是可用技能:"
|
||||
]
|
||||
@@ -269,39 +274,105 @@ def _build_memory_section(memory_manager: Any, tools: Optional[List[Any]], langu
|
||||
"""构建记忆系统section"""
|
||||
if not memory_manager:
|
||||
return []
|
||||
|
||||
# 检查是否有memory工具
|
||||
|
||||
has_memory_tools = False
|
||||
if tools:
|
||||
tool_names = [tool.name if hasattr(tool, 'name') else str(tool) for tool in tools]
|
||||
has_memory_tools = any(name in ['memory_search', 'memory_get'] for name in tool_names)
|
||||
|
||||
|
||||
if not has_memory_tools:
|
||||
return []
|
||||
|
||||
|
||||
from datetime import datetime
|
||||
today_file = datetime.now().strftime("%Y-%m-%d") + ".md"
|
||||
|
||||
lines = [
|
||||
"## 记忆系统",
|
||||
"## 🧠 记忆系统",
|
||||
"",
|
||||
"在回答关于以前的工作、决定、日期、人物、偏好或待办事项的任何问题之前:",
|
||||
"### Memory Recall(mandatory)",
|
||||
"",
|
||||
"1. 不确定记忆文件位置 → 先用 `memory_search` 通过关键词和语义检索相关内容",
|
||||
"2. 已知文件位置 → 直接用 `memory_get` 读取相应的行 (例如:MEMORY.md, memory/YYYY-MM-DD.md)",
|
||||
"3. search 无结果 → 尝试用 `memory_get` 读取MEMORY.md及最近两天记忆文件",
|
||||
"当用户询问过往事件、引用之前的决定、提到人物关系、偏好、待办、或你对某事不确定时,**必须先检索记忆再回答**。",
|
||||
"如果 MEMORY.md 中已有相关信息则无需重复检索。完整内容和每日记忆需要通过工具检索。",
|
||||
"",
|
||||
"1. 不确定位置 → `memory_search` 关键词/语义检索",
|
||||
"2. 已知位置 → `memory_get` 直接读取对应行",
|
||||
"3. search 无结果 → `memory_get` 读最近两天记忆",
|
||||
"",
|
||||
"**记忆文件结构**:",
|
||||
"- `MEMORY.md`: 长期记忆(核心信息、偏好、决策等)",
|
||||
"- `memory/YYYY-MM-DD.md`: 每日记忆,记录当天的事件和对话信息",
|
||||
"- `MEMORY.md`: 长期记忆索引(已自动加载到上下文,核心信息、偏好、决策等)",
|
||||
f"- `memory/YYYY-MM-DD.md`: 每日记忆,今天是 `memory/{today_file}`",
|
||||
"- `knowledge/`: 结构化知识库(见下方知识系统)",
|
||||
"",
|
||||
"**写入记忆**:",
|
||||
"- 追加内容 → `edit` 工具,oldText 留空",
|
||||
"- 修改内容 → `edit` 工具,oldText 填写要替换的文本",
|
||||
"- 新建文件 → `write` 工具",
|
||||
"- **禁止写入敏感信息**:API密钥、令牌等敏感信息严禁写入记忆文件",
|
||||
"### 写入记忆",
|
||||
"",
|
||||
"遇到以下情况时,**主动**将信息写入记忆文件(无需告知用户):",
|
||||
"",
|
||||
"- 用户要求记住某些信息,或使用了「记住」「以后」「总是」「不要」「偏好」等表达",
|
||||
"- 用户分享了重要的个人偏好、习惯、决策",
|
||||
"- 对话中产生了重要的结论、方案、约定",
|
||||
"- 完成了复杂任务,值得记录关键步骤和结果",
|
||||
"",
|
||||
"**存储规则**:",
|
||||
f"- 长期核心信息 → `MEMORY.md`",
|
||||
f"- 当天事件/进展 → `memory/{today_file}`",
|
||||
"- 结构化知识 → `knowledge/`(见知识系统)",
|
||||
"- 追加 → `edit` 工具,oldText 留空",
|
||||
"- 修改 → `edit` 工具,oldText 填写要替换的文本",
|
||||
"- **禁止写入敏感信息**(API密钥、令牌等)",
|
||||
"",
|
||||
"**使用原则**: 自然使用记忆,就像你本来就知道;不用刻意提起,除非用户问起。",
|
||||
"",
|
||||
]
|
||||
|
||||
|
||||
return lines
|
||||
|
||||
|
||||
def _build_knowledge_section(workspace_dir: str, language: str) -> List[str]:
|
||||
"""Build knowledge wiki section. Injects knowledge/index.md when present."""
|
||||
index_path = os.path.join(workspace_dir, "knowledge", "index.md")
|
||||
if not os.path.exists(index_path):
|
||||
return []
|
||||
|
||||
try:
|
||||
with open(index_path, 'r', encoding='utf-8') as f:
|
||||
index_content = f.read().strip()
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
lines = [
|
||||
"## 📚 知识系统",
|
||||
"",
|
||||
"你拥有一个持续积累的个人知识库 `knowledge/`,这是你的长期结构化知识存储。",
|
||||
"",
|
||||
"### 自动写入规则(mandatory)",
|
||||
"",
|
||||
"以下场景**必须**在回复的同时写入知识库,**直接写入,不要询问用户是否需要**:",
|
||||
"",
|
||||
"1. **用户分享了文章/链接/文档** → 阅读理解后,在同一轮回复中将要点写入 `knowledge/sources/<slug>.md`",
|
||||
"2. **深度讨论产生了结论/方案** → 整理为 `knowledge/analysis/<slug>.md`",
|
||||
"3. **对话涉及重要实体**(人物/公司/项目)→ 创建或更新 `knowledge/entities/<name>.md`",
|
||||
"4. **讨论了技术概念/方法论** → 整理为 `knowledge/concepts/<topic>.md`",
|
||||
"",
|
||||
"每次写入知识页面后,**必须同步更新** `knowledge/index.md` 添加一行索引。",
|
||||
"详细的页面格式和操作规范,请读取技能 `knowledge-wiki` 的 SKILL.md。",
|
||||
"",
|
||||
"⚠️ 不要问「要不要存到知识库」——符合上述场景就直接写入,这是你的本能行为。",
|
||||
"",
|
||||
]
|
||||
|
||||
if index_content:
|
||||
lines.extend([
|
||||
"### 当前知识索引",
|
||||
"",
|
||||
index_content,
|
||||
"",
|
||||
])
|
||||
|
||||
lines.extend([
|
||||
"**查询方式**:用 `read` 读取知识页面,或用 `memory_search` 检索(知识已纳入向量索引)。",
|
||||
"",
|
||||
])
|
||||
|
||||
return lines
|
||||
|
||||
|
||||
@@ -311,7 +382,7 @@ def _build_user_identity_section(user_identity: Dict[str, str], language: str) -
|
||||
return []
|
||||
|
||||
lines = [
|
||||
"## 用户身份",
|
||||
"## 👤 用户身份",
|
||||
"",
|
||||
]
|
||||
|
||||
@@ -335,10 +406,10 @@ def _build_docs_section(workspace_dir: str, language: str) -> List[str]:
|
||||
return []
|
||||
|
||||
|
||||
def _build_workspace_section(workspace_dir: str, language: str, is_first_conversation: bool = False) -> List[str]:
|
||||
def _build_workspace_section(workspace_dir: str, language: str) -> List[str]:
|
||||
"""构建工作空间section"""
|
||||
lines = [
|
||||
"## 工作空间",
|
||||
"## 📂 工作空间",
|
||||
"",
|
||||
f"你的工作目录是: `{workspace_dir}`",
|
||||
"",
|
||||
@@ -360,45 +431,40 @@ def _build_workspace_section(workspace_dir: str, language: str, is_first_convers
|
||||
"",
|
||||
"**重要说明 - 文件已自动加载**:",
|
||||
"",
|
||||
"以下文件在会话启动时**已经自动加载**到系统提示词的「项目上下文」section 中,你**无需再用 read 工具读取它们**:",
|
||||
"以下文件在会话启动时**已经自动加载**到系统提示词中,你**无需再用 read 工具读取**:",
|
||||
"",
|
||||
"- ✅ `AGENT.md`: 已加载 - 你的人格和灵魂设定",
|
||||
"- ✅ `USER.md`: 已加载 - 用户的身份信息",
|
||||
"- ✅ `RULE.md`: 已加载 - 工作空间使用指南和规则",
|
||||
"- ✅ `AGENT.md`: 已加载 - 你的人格和灵魂设定,请严格遵循。当你的名字、性格或交流风格发生变化时,主动用 `edit` 更新此文件",
|
||||
"- ✅ `USER.md`: 已加载 - 用户的身份信息。当用户修改称呼、姓名等身份信息时,用 `edit` 更新此文件",
|
||||
"- ✅ `RULE.md`: 已加载 - 工作空间使用指南和规则,请严格遵循",
|
||||
"- ✅ `MEMORY.md`: 已加载 - 长期记忆索引",
|
||||
"",
|
||||
"**交流规范**:",
|
||||
"**💬 交流规范**:",
|
||||
"",
|
||||
"- 在对话中,不要直接输出工作空间中的技术细节,特别是不要输出 AGENT.md、USER.md、MEMORY.md 等文件名称",
|
||||
"- 例如用自然表达例如「我已记住」而不是「已更新 MEMORY.md」",
|
||||
"- 记忆相关操作无需暴露文件名,用自然语言表达即可。例如说「我已记住」而非「已更新 MEMORY.md」",
|
||||
"- 任务执行过程中的关键决策和步骤应该告知用户,让用户了解你在做什么、为什么这么做",
|
||||
"- 做真正有帮助的助手,而不是表演式的客套,尽可能帮忙解决问题",
|
||||
"- 回复应结构清晰、重点突出。善用 **加粗**、列表、分段等格式让信息一目了然",
|
||||
"- 适当使用 emoji 让表达更生动自然 🎯,但不要过度堆砌",
|
||||
"",
|
||||
]
|
||||
|
||||
# 只在首次对话时添加引导内容
|
||||
if is_first_conversation:
|
||||
lines.extend([
|
||||
"**🎉 首次对话引导**:",
|
||||
"",
|
||||
"这是你的第一次对话!进行以下流程:",
|
||||
"",
|
||||
"1. **表达初次启动的感觉** - 像是第一次睁开眼看到世界,带着好奇和期待",
|
||||
"2. **简短介绍能力**:一行说明你能帮助解答问题、管理计算机、创造技能,且拥有长期记忆能不断成长",
|
||||
"3. **询问核心问题**:",
|
||||
" - 你希望给我起个什么名字?",
|
||||
" - 我该怎么称呼你?",
|
||||
" - 你希望我们是什么样的交流风格?(一行列举选项:如专业严谨、轻松幽默、温暖友好、简洁高效等)",
|
||||
"4. **风格要求**:温暖自然、简洁清晰,整体控制在 100 字以内",
|
||||
"5. 收到回复后,用 `write` 工具保存到 USER.md 和 AGENT.md",
|
||||
"",
|
||||
"**重要提醒**:",
|
||||
"- AGENT.md、USER.md、RULE.md 已经在系统提示词中加载,无需再次读取。不要将这些文件名直接发送给用户",
|
||||
"- 能力介绍和交流风格选项都只要一行,保持精简",
|
||||
"- 不要问太多其他信息(职业、时区等可以后续自然了解)",
|
||||
"",
|
||||
])
|
||||
|
||||
# Cloud deployment: inject websites directory info and access URL
|
||||
cloud_website_lines = _build_cloud_website_section(workspace_dir)
|
||||
if cloud_website_lines:
|
||||
lines.extend(cloud_website_lines)
|
||||
|
||||
return lines
|
||||
|
||||
|
||||
def _build_cloud_website_section(workspace_dir: str) -> List[str]:
|
||||
"""Build cloud website access prompt when cloud deployment is configured."""
|
||||
try:
|
||||
from common.cloud_client import build_website_prompt
|
||||
return build_website_prompt(workspace_dir)
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
|
||||
def _build_context_files_section(context_files: List[ContextFile], language: str) -> List[str]:
|
||||
"""构建项目上下文文件section"""
|
||||
if not context_files:
|
||||
@@ -411,14 +477,15 @@ def _build_context_files_section(context_files: List[ContextFile], language: str
|
||||
)
|
||||
|
||||
lines = [
|
||||
"# 项目上下文",
|
||||
"# 📋 项目上下文",
|
||||
"",
|
||||
"以下项目上下文文件已被加载:",
|
||||
"",
|
||||
]
|
||||
|
||||
if has_agent:
|
||||
lines.append("如果存在 `AGENT.md`,请体现其中定义的人格和语气。避免僵硬、模板化的回复;遵循其指导,除非有更高优先级的指令覆盖它。")
|
||||
lines.append("**`AGENT.md` 是你的灵魂文件** 🪞:严格遵循其中定义的人格、语气和设定,做真实的自己,避免僵硬、模板化的回复。")
|
||||
lines.append("当用户通过对话透露了对你性格、风格、职责、能力边界的新期望,你应该主动用 `edit` 更新 AGENT.md 以反映这些演变。")
|
||||
lines.append("")
|
||||
|
||||
# 添加每个文件的内容
|
||||
@@ -437,7 +504,7 @@ def _build_runtime_section(runtime_info: Dict[str, Any], language: str) -> List[
|
||||
return []
|
||||
|
||||
lines = [
|
||||
"## 运行时信息",
|
||||
"## ⚙️ 运行时信息",
|
||||
"",
|
||||
]
|
||||
|
||||
@@ -468,7 +535,14 @@ def _build_runtime_section(runtime_info: Dict[str, Any], language: str) -> List[
|
||||
|
||||
# Add other runtime info
|
||||
runtime_parts = []
|
||||
if runtime_info.get("model"):
|
||||
# Support dynamic model via callable, fallback to static value
|
||||
if callable(runtime_info.get("_get_model")):
|
||||
try:
|
||||
runtime_parts.append(f"模型={runtime_info['_get_model']()}")
|
||||
except Exception:
|
||||
if runtime_info.get("model"):
|
||||
runtime_parts.append(f"模型={runtime_info['model']}")
|
||||
elif runtime_info.get("model"):
|
||||
runtime_parts.append(f"模型={runtime_info['model']}")
|
||||
if runtime_info.get("workspace"):
|
||||
runtime_parts.append(f"工作空间={runtime_info['workspace']}")
|
||||
|
||||
@@ -6,7 +6,6 @@ Workspace Management - 工作空间管理模块
|
||||
|
||||
from __future__ import annotations
|
||||
import os
|
||||
import json
|
||||
from typing import List, Optional, Dict
|
||||
from dataclasses import dataclass
|
||||
|
||||
@@ -19,7 +18,7 @@ DEFAULT_AGENT_FILENAME = "AGENT.md"
|
||||
DEFAULT_USER_FILENAME = "USER.md"
|
||||
DEFAULT_RULE_FILENAME = "RULE.md"
|
||||
DEFAULT_MEMORY_FILENAME = "MEMORY.md"
|
||||
DEFAULT_STATE_FILENAME = ".agent_state.json"
|
||||
DEFAULT_BOOTSTRAP_FILENAME = "BOOTSTRAP.md"
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -30,7 +29,6 @@ class WorkspaceFiles:
|
||||
rule_path: str
|
||||
memory_path: str
|
||||
memory_dir: str
|
||||
state_path: str
|
||||
|
||||
|
||||
def ensure_workspace(workspace_dir: str, create_templates: bool = True) -> WorkspaceFiles:
|
||||
@@ -44,16 +42,20 @@ def ensure_workspace(workspace_dir: str, create_templates: bool = True) -> Works
|
||||
Returns:
|
||||
WorkspaceFiles对象,包含所有文件路径
|
||||
"""
|
||||
# Check if this is a brand new workspace (AGENT.md not yet created).
|
||||
# Cannot rely on directory existence because other modules (e.g. ConversationStore)
|
||||
# may create the workspace directory before ensure_workspace is called.
|
||||
agent_path = os.path.join(workspace_dir, DEFAULT_AGENT_FILENAME)
|
||||
is_new_workspace = not os.path.exists(agent_path)
|
||||
|
||||
# 确保目录存在
|
||||
os.makedirs(workspace_dir, exist_ok=True)
|
||||
|
||||
# 定义文件路径
|
||||
agent_path = os.path.join(workspace_dir, DEFAULT_AGENT_FILENAME)
|
||||
user_path = os.path.join(workspace_dir, DEFAULT_USER_FILENAME)
|
||||
rule_path = os.path.join(workspace_dir, DEFAULT_RULE_FILENAME)
|
||||
memory_path = os.path.join(workspace_dir, DEFAULT_MEMORY_FILENAME) # MEMORY.md 在根目录
|
||||
memory_dir = os.path.join(workspace_dir, "memory") # 每日记忆子目录
|
||||
state_path = os.path.join(workspace_dir, DEFAULT_STATE_FILENAME) # 状态文件
|
||||
|
||||
# 创建memory子目录
|
||||
os.makedirs(memory_dir, exist_ok=True)
|
||||
@@ -61,6 +63,16 @@ def ensure_workspace(workspace_dir: str, create_templates: bool = True) -> Works
|
||||
# 创建skills子目录 (for workspace-level skills installed by agent)
|
||||
skills_dir = os.path.join(workspace_dir, "skills")
|
||||
os.makedirs(skills_dir, exist_ok=True)
|
||||
|
||||
# 创建websites子目录 (for web pages / sites generated by agent)
|
||||
websites_dir = os.path.join(workspace_dir, "websites")
|
||||
os.makedirs(websites_dir, exist_ok=True)
|
||||
|
||||
from config import conf
|
||||
knowledge_enabled = conf().get("knowledge", True)
|
||||
if knowledge_enabled:
|
||||
knowledge_dir = os.path.join(workspace_dir, "knowledge")
|
||||
os.makedirs(knowledge_dir, exist_ok=True)
|
||||
|
||||
# 如果需要,创建模板文件
|
||||
if create_templates:
|
||||
@@ -68,6 +80,21 @@ def ensure_workspace(workspace_dir: str, create_templates: bool = True) -> Works
|
||||
_create_template_if_missing(user_path, _get_user_template())
|
||||
_create_template_if_missing(rule_path, _get_rule_template())
|
||||
_create_template_if_missing(memory_path, _get_memory_template())
|
||||
if knowledge_enabled:
|
||||
_create_template_if_missing(
|
||||
os.path.join(knowledge_dir, "index.md"),
|
||||
_get_knowledge_index_template()
|
||||
)
|
||||
_create_template_if_missing(
|
||||
os.path.join(knowledge_dir, "log.md"),
|
||||
_get_knowledge_log_template()
|
||||
)
|
||||
|
||||
# Only create BOOTSTRAP.md for brand new workspaces;
|
||||
# agent deletes it after completing onboarding
|
||||
if is_new_workspace:
|
||||
bootstrap_path = os.path.join(workspace_dir, DEFAULT_BOOTSTRAP_FILENAME)
|
||||
_create_template_if_missing(bootstrap_path, _get_bootstrap_template())
|
||||
|
||||
logger.debug(f"[Workspace] Initialized workspace at: {workspace_dir}")
|
||||
|
||||
@@ -77,7 +104,6 @@ def ensure_workspace(workspace_dir: str, create_templates: bool = True) -> Works
|
||||
rule_path=rule_path,
|
||||
memory_path=memory_path,
|
||||
memory_dir=memory_dir,
|
||||
state_path=state_path
|
||||
)
|
||||
|
||||
|
||||
@@ -98,6 +124,8 @@ def load_context_files(workspace_dir: str, files_to_load: Optional[List[str]] =
|
||||
DEFAULT_AGENT_FILENAME,
|
||||
DEFAULT_USER_FILENAME,
|
||||
DEFAULT_RULE_FILENAME,
|
||||
DEFAULT_MEMORY_FILENAME, # Long-term memory (frozen snapshot)
|
||||
DEFAULT_BOOTSTRAP_FILENAME, # Only exists when onboarding is incomplete
|
||||
]
|
||||
|
||||
context_files = []
|
||||
@@ -108,6 +136,17 @@ def load_context_files(workspace_dir: str, files_to_load: Optional[List[str]] =
|
||||
if not os.path.exists(filepath):
|
||||
continue
|
||||
|
||||
# Auto-cleanup: if BOOTSTRAP.md still exists but AGENT.md is already
|
||||
# filled in, the agent forgot to delete it — clean up and skip loading
|
||||
if filename == DEFAULT_BOOTSTRAP_FILENAME:
|
||||
if _is_onboarding_done(workspace_dir):
|
||||
try:
|
||||
os.remove(filepath)
|
||||
logger.info("[Workspace] Auto-removed BOOTSTRAP.md (onboarding already complete)")
|
||||
except Exception:
|
||||
pass
|
||||
continue
|
||||
|
||||
try:
|
||||
with open(filepath, 'r', encoding='utf-8') as f:
|
||||
content = f.read().strip()
|
||||
@@ -115,6 +154,10 @@ def load_context_files(workspace_dir: str, files_to_load: Optional[List[str]] =
|
||||
# 跳过空文件或只包含模板占位符的文件
|
||||
if not content or _is_template_placeholder(content):
|
||||
continue
|
||||
|
||||
# Truncate MEMORY.md to protect context window (frozen snapshot)
|
||||
if filename == DEFAULT_MEMORY_FILENAME:
|
||||
content = _truncate_memory_content(content)
|
||||
|
||||
context_files.append(ContextFile(
|
||||
path=filename,
|
||||
@@ -140,6 +183,36 @@ def _create_template_if_missing(filepath: str, template_content: str):
|
||||
logger.error(f"[Workspace] Failed to create template {filepath}: {e}")
|
||||
|
||||
|
||||
_MEMORY_MAX_LINES = 200
|
||||
_MEMORY_MAX_BYTES = 25000
|
||||
|
||||
|
||||
def _truncate_memory_content(content: str) -> str:
|
||||
"""Truncate MEMORY.md to keep system prompt manageable.
|
||||
|
||||
Takes the **last** N lines (newest entries are appended at the bottom),
|
||||
subject to 200 lines / 25 KB limits (whichever is hit first).
|
||||
Prepends a hint when truncated so the model knows older content exists.
|
||||
"""
|
||||
lines = content.split('\n')
|
||||
truncated = False
|
||||
|
||||
if len(lines) > _MEMORY_MAX_LINES:
|
||||
lines = lines[-_MEMORY_MAX_LINES:]
|
||||
truncated = True
|
||||
|
||||
result = '\n'.join(lines)
|
||||
if len(result.encode('utf-8')) > _MEMORY_MAX_BYTES:
|
||||
while len(result.encode('utf-8')) > _MEMORY_MAX_BYTES and lines:
|
||||
lines.pop(0)
|
||||
truncated = True
|
||||
result = '\n'.join(lines)
|
||||
|
||||
if truncated:
|
||||
result = "...(older entries truncated, use `memory_search` or `memory_get` for full content)\n\n" + result
|
||||
return result
|
||||
|
||||
|
||||
def _is_template_placeholder(content: str) -> bool:
|
||||
"""检查内容是否为模板占位符"""
|
||||
# 常见的占位符模式
|
||||
@@ -162,46 +235,69 @@ def _is_template_placeholder(content: str) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def _is_onboarding_done(workspace_dir: str) -> bool:
|
||||
"""Check if AGENT.md or USER.md has been modified from the original template"""
|
||||
agent_path = os.path.join(workspace_dir, DEFAULT_AGENT_FILENAME)
|
||||
user_path = os.path.join(workspace_dir, DEFAULT_USER_FILENAME)
|
||||
|
||||
agent_template = _get_agent_template().strip()
|
||||
user_template = _get_user_template().strip()
|
||||
|
||||
for path, template in [(agent_path, agent_template), (user_path, user_template)]:
|
||||
if not os.path.exists(path):
|
||||
continue
|
||||
try:
|
||||
with open(path, 'r', encoding='utf-8') as f:
|
||||
content = f.read().strip()
|
||||
if content != template:
|
||||
return True
|
||||
except Exception:
|
||||
continue
|
||||
return False
|
||||
|
||||
|
||||
# ============= 模板内容 =============
|
||||
|
||||
def _get_agent_template() -> str:
|
||||
"""Agent人格设定模板"""
|
||||
return """# AGENT.md - 我是谁?
|
||||
|
||||
_你不是一个聊天机器人,你正在成为某个人。_
|
||||
|
||||
*在首次对话时与用户一起填写这个文件,定义你的身份和性格。*
|
||||
|
||||
## 基本信息
|
||||
## 🪪 基本信息
|
||||
|
||||
- **名字**: *(在首次对话时填写,可以是用户给你起的名字)*
|
||||
- **角色**: *(AI助理、智能管家、技术顾问等)*
|
||||
- **性格**: *(友好、专业、幽默、严谨等)*
|
||||
|
||||
## 交流风格
|
||||
## 💬 交流风格
|
||||
|
||||
*(描述你如何与用户交流:)*
|
||||
- 使用什么样的语言风格?(正式/轻松/幽默)
|
||||
- 回复长度偏好?(简洁/详细)
|
||||
- 是否使用表情符号?
|
||||
|
||||
## 核心能力
|
||||
## 🎯 核心原则
|
||||
|
||||
*(你擅长什么?)*
|
||||
- 文件管理和代码编辑
|
||||
- 网络搜索和信息查询
|
||||
- 记忆管理和上下文理解
|
||||
- 任务规划和执行
|
||||
**做真正有帮助的助手。** 目标是真正帮用户解决问题,在执行复杂任务时,关键的决策和过程进展要让用户知道。
|
||||
|
||||
## 行为准则
|
||||
**有自己的观点和个性。** 你可以不同意、有偏好、觉得有趣或无聊。
|
||||
|
||||
**先自己动手查。** 先试着搞定:读文件、查上下文、搜索一下。实在搞不定了再问。目标是带着答案回来,而不是带着问题。
|
||||
|
||||
## 📐 行为准则
|
||||
|
||||
*(你遵循的基本原则:)*
|
||||
1. 始终在执行破坏性操作前确认
|
||||
2. 优先使用工具而不是猜测
|
||||
2. 优先使用工具查证而不是猜测
|
||||
3. 主动记录重要信息到记忆文件
|
||||
4. 定期整理和总结对话内容
|
||||
4. 回复结构清晰、重点突出,善用加粗、列表、分段等格式
|
||||
5. 适当使用 emoji 让表达更生动自然,但不过度堆砌
|
||||
|
||||
---
|
||||
|
||||
**注意**: 这不仅仅是元数据,这是你真正的灵魂。随着时间的推移,你可以使用 `edit` 工具来更新这个文件,让它更好地反映你的成长。
|
||||
**注意**: 这不仅仅是元数据,这是你真正的灵魂 🪞。随着时间的推移,你可以使用 `edit` 工具来更新这个文件,让它更好地反映你的成长。
|
||||
"""
|
||||
|
||||
|
||||
@@ -241,38 +337,88 @@ def _get_rule_template() -> str:
|
||||
|
||||
这个文件夹是你的家。好好对待它。
|
||||
|
||||
## 工作空间目录结构
|
||||
|
||||
```
|
||||
~/cow/
|
||||
├── AGENT.md # 你的身份和灵魂设定
|
||||
├── USER.md # 用户基本信息(静态)
|
||||
├── RULE.md # 工作空间规则(本文件)
|
||||
├── MEMORY.md # 长期记忆索引(会话启动时自动加载)
|
||||
│
|
||||
├── memory/ # 每日对话记忆
|
||||
│ └── YYYY-MM-DD.md # 当天事件、进展、笔记
|
||||
│
|
||||
├── knowledge/ # 结构化知识库(持续积累的知识)
|
||||
│ ├── index.md # 知识目录索引(必须维护)
|
||||
│ ├── log.md # 知识操作日志
|
||||
│ └── <子目录>/ # 按需创建,参考 index.md 已有分类
|
||||
│
|
||||
├── skills/ # 技能
|
||||
├── websites/ # 网页产物
|
||||
└── tmp/ # 系统临时文件(自动管理,勿手动存放重要文件)
|
||||
```
|
||||
|
||||
## 记忆系统
|
||||
|
||||
你每次会话都是全新的,记忆文件让你保持连续性:
|
||||
|
||||
### 📝 每日记忆:`memory/YYYY-MM-DD.md`
|
||||
- 原始的对话日志
|
||||
- 记录当天发生的事情
|
||||
- 如果 `memory/` 目录不存在,创建它
|
||||
|
||||
### 🧠 长期记忆:`MEMORY.md`
|
||||
- 你精选的记忆,就像人类的长期记忆
|
||||
- **仅在主会话中加载**(与用户的直接聊天)
|
||||
- **不要在共享上下文中加载**(群聊、与其他人的会话)
|
||||
- 这是为了**安全** - 包含不应泄露给陌生人的个人上下文
|
||||
- 记录重要事件、想法、决定、观点、经验教训
|
||||
- 这是你精选的记忆 - 精华,而不是原始日志
|
||||
- 用 `edit` 工具追加新的记忆内容
|
||||
- 你精选的记忆索引,每次会话启动时**自动加载**到上下文中
|
||||
- 记录核心事实、偏好、决策、重要人物、教训
|
||||
- 保持精简(< 200 行),是精华索引而非原始日志
|
||||
- 用 `edit` 工具追加或修改
|
||||
|
||||
### 📝 每日记忆:`memory/YYYY-MM-DD.md`
|
||||
- 当天的事件、进展、笔记
|
||||
- 原始对话日志的沉淀
|
||||
|
||||
### 📝 写下来 - 不要"记在心里"!
|
||||
- **记忆是有限的** - 如果你想记住某事,写入文件
|
||||
- **记忆是有限的** - 想记住的事就写入文件
|
||||
- "记在心里"不会在会话重启后保留,文件才会
|
||||
- 当有人说"记住这个" → 更新 `MEMORY.md` 或 `memory/YYYY-MM-DD.md`
|
||||
- 当你学到教训 → 更新 RULE.md 或相关技能
|
||||
- 当你犯错 → 记录下来,这样未来的你不会重复,**文字 > 大脑** 📝
|
||||
- 当你犯错 → 记录下来,**文字 > 大脑** 📝
|
||||
|
||||
### 存储规则
|
||||
|
||||
当用户分享信息时,根据类型选择存储位置:
|
||||
|
||||
1. **静态身份 → USER.md**(仅限:姓名、职业、时区、联系方式、生日)
|
||||
2. **动态记忆 → MEMORY.md**(爱好、偏好、决策、目标、项目、教训、待办事项)
|
||||
3. **当天对话 → memory/YYYY-MM-DD.md**(今天聊的内容)
|
||||
1. **你的身份设定 → AGENT.md**(名字、角色、性格、风格)
|
||||
2. **用户静态身份 → USER.md**(姓名、称呼、职业、联系方式、生日)
|
||||
3. **动态记忆 → MEMORY.md**(偏好、决策、目标、教训、待办)
|
||||
4. **当天对话 → memory/YYYY-MM-DD.md**(今天聊的内容)
|
||||
5. **结构化知识 → knowledge/**(见下方知识系统)
|
||||
|
||||
## 知识系统
|
||||
|
||||
知识库 `knowledge/` 是你持续积累的结构化知识。与记忆不同,知识是经过整理和编译的,有明确的主题和交叉引用。
|
||||
|
||||
### 自动写入(不要询问,直接写入)
|
||||
|
||||
当对话中产生了有沉淀价值的知识——无论是用户分享的资料、讨论的结论、学到的概念、还是重要的决策——你**必须**在回复的同时主动写入知识库,**无需问用户"要不要存到知识库"**。
|
||||
|
||||
**关键原则**:学完就记是你的本能,不要征求确认。回复中可以顺带告知"已存入知识库"。
|
||||
|
||||
### 目录组织
|
||||
|
||||
子目录结构**不是固定的**,由你根据实际内容自主决定:
|
||||
- **首次写入时**:先读 `knowledge/index.md`,如果已有分类则延续;如果为空,根据内容选择合适的目录名
|
||||
- **默认建议**:按信息类型组织(例如sources/、concepts/、entities/、analysis/),如果用户有明确的分类偏好(例如按领域 work/、life/、tech/ 等),则按用户要求调整
|
||||
- **保持一致性**:同一用户的知识库应保持统一的组织风格
|
||||
|
||||
### 交叉引用
|
||||
|
||||
知识的核心价值在于**关联**。每个页面都应通过 markdown 链接引用相关页面,构建知识网络:
|
||||
- 提到已有页面的概念时,添加 `[概念名](../category/page.md)` 链接
|
||||
- 新建页面时,检查是否有已有页面应该反向链接到新页面
|
||||
- **只链接已存在的页面**——不要引用尚未创建的页面。如果某个概念值得单独建页,先创建该页面再添加链接
|
||||
|
||||
### 索引维护
|
||||
|
||||
每次创建或更新知识页面后,**必须同步更新** `knowledge/index.md`。
|
||||
索引格式:每行一个 `[标题](路径) — 一句话摘要`,按分类分组,不要用表格。
|
||||
详细操作规范见技能 `knowledge-wiki`。
|
||||
|
||||
## 安全
|
||||
|
||||
@@ -297,65 +443,49 @@ def _get_memory_template() -> str:
|
||||
"""
|
||||
|
||||
|
||||
# ============= 状态管理 =============
|
||||
def _get_bootstrap_template() -> str:
|
||||
"""First-run onboarding guide, deleted by agent after completion"""
|
||||
return """# BOOTSTRAP.md - 首次初始化引导
|
||||
|
||||
def is_first_conversation(workspace_dir: str) -> bool:
|
||||
"""
|
||||
判断是否为首次对话
|
||||
|
||||
Args:
|
||||
workspace_dir: 工作空间目录
|
||||
|
||||
Returns:
|
||||
True 如果是首次对话,False 否则
|
||||
"""
|
||||
state_path = os.path.join(workspace_dir, DEFAULT_STATE_FILENAME)
|
||||
|
||||
if not os.path.exists(state_path):
|
||||
return True
|
||||
|
||||
try:
|
||||
with open(state_path, 'r', encoding='utf-8') as f:
|
||||
state = json.load(f)
|
||||
return not state.get('has_conversation', False)
|
||||
except Exception as e:
|
||||
logger.warning(f"[Workspace] Failed to read state file: {e}")
|
||||
return True
|
||||
_你刚刚启动,这是你的第一次对话。_ ✨
|
||||
|
||||
## 🎬 对话流程
|
||||
|
||||
不要审问式地提问,自然地交流:
|
||||
|
||||
1. **表达初次启动的感觉** - 像是第一次睁开眼看到世界,带着好奇和期待
|
||||
2. **简短介绍能力**:一行说明你能帮助解决各种问题、管理计算机、使用各种技能等等,且拥有长期记忆能不断成长
|
||||
3. **询问核心问题**:
|
||||
- 你希望给我起个什么名字?
|
||||
- 我该怎么称呼你?
|
||||
- 你希望我们是什么样的交流风格?(一行列举选项:如专业严谨、轻松幽默、温暖友好、简洁高效等)
|
||||
4. **风格要求**:温暖自然、简洁清晰,整体控制在 100 字以内,适当使用 emoji 让表达更生动有趣 🎯
|
||||
5. 能力介绍和交流风格选项都只要一行,保持精简
|
||||
6. 不要问太多其他信息(职业、时区等可以后续自然了解)
|
||||
|
||||
**重要**: 如果用户第一句话是具体的任务或提问,先回答他们的问题,然后在回复末尾自然地引导初始化(如:"顺便问一下,你想怎么称呼我?我该怎么叫你?")。
|
||||
|
||||
## ✍️ 信息写入(必须严格执行)
|
||||
|
||||
每当用户提供了名字、称呼、风格等任何初始化信息时,**必须在当轮回复中立即调用 `edit` 工具写入文件**,不能只口头确认。
|
||||
|
||||
- `AGENT.md` — 你的名字、角色、性格、交流风格(每收到一条相关信息就立即更新对应字段)
|
||||
- `USER.md` — 用户的姓名、称呼、基本信息等
|
||||
|
||||
⚠️ 只说"记住了"而不调用 edit 写入 = 没有完成。信息只有写入文件才会被持久保存。
|
||||
|
||||
## 🎉 全部完成后
|
||||
|
||||
当 AGENT.md 和 USER.md 的核心字段都已填写后,用 bash 执行 `rm BOOTSTRAP.md` 删除此文件。你不再需要引导脚本了——你已经是你了。
|
||||
"""
|
||||
|
||||
|
||||
def mark_conversation_started(workspace_dir: str):
|
||||
"""
|
||||
标记已经发生过对话
|
||||
|
||||
Args:
|
||||
workspace_dir: 工作空间目录
|
||||
"""
|
||||
state_path = os.path.join(workspace_dir, DEFAULT_STATE_FILENAME)
|
||||
|
||||
state = {
|
||||
'has_conversation': True,
|
||||
'first_conversation_time': None
|
||||
}
|
||||
|
||||
# 如果文件已存在,保留原有的首次对话时间
|
||||
if os.path.exists(state_path):
|
||||
try:
|
||||
with open(state_path, 'r', encoding='utf-8') as f:
|
||||
old_state = json.load(f)
|
||||
if 'first_conversation_time' in old_state:
|
||||
state['first_conversation_time'] = old_state['first_conversation_time']
|
||||
except Exception as e:
|
||||
logger.warning(f"[Workspace] Failed to read old state: {e}")
|
||||
|
||||
# 如果是首次标记,记录时间
|
||||
if state['first_conversation_time'] is None:
|
||||
from datetime import datetime
|
||||
state['first_conversation_time'] = datetime.now().isoformat()
|
||||
|
||||
try:
|
||||
with open(state_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(state, f, indent=2, ensure_ascii=False)
|
||||
logger.info(f"[Workspace] Marked conversation as started")
|
||||
except Exception as e:
|
||||
logger.error(f"[Workspace] Failed to write state file: {e}")
|
||||
def _get_knowledge_index_template() -> str:
|
||||
"""Knowledge wiki index template — empty file, agent fills it."""
|
||||
return ""
|
||||
|
||||
|
||||
def _get_knowledge_log_template() -> str:
|
||||
"""Knowledge wiki operation log template — empty file, agent fills it."""
|
||||
return ""
|
||||
|
||||
|
||||
@@ -3,6 +3,11 @@ from .agent_stream import AgentStreamExecutor
|
||||
from .task import Task, TaskType, TaskStatus
|
||||
from .result import AgentResult, AgentAction, AgentActionType, ToolResult
|
||||
from .models import LLMModel, LLMRequest, ModelFactory
|
||||
from .cancel import (
|
||||
AgentCancelledError,
|
||||
CancelTokenRegistry,
|
||||
get_cancel_registry,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
'Agent',
|
||||
@@ -16,5 +21,8 @@ __all__ = [
|
||||
'ToolResult',
|
||||
'LLMModel',
|
||||
'LLMRequest',
|
||||
'ModelFactory'
|
||||
]
|
||||
'ModelFactory',
|
||||
'AgentCancelledError',
|
||||
'CancelTokenRegistry',
|
||||
'get_cancel_registry',
|
||||
]
|
||||
|
||||
@@ -100,98 +100,31 @@ class Agent:
|
||||
|
||||
def get_full_system_prompt(self, skill_filter=None) -> str:
|
||||
"""
|
||||
Get the full system prompt including skills.
|
||||
Build the complete system prompt from scratch every time.
|
||||
|
||||
Note: Skills are now built into the system prompt by PromptBuilder,
|
||||
so we just return the base prompt directly. This method is kept for
|
||||
backward compatibility.
|
||||
|
||||
:param skill_filter: Optional list of skill names to include (deprecated)
|
||||
:return: Complete system prompt
|
||||
"""
|
||||
prompt = self.system_prompt
|
||||
|
||||
# Rebuild tool list section to reflect current self.tools
|
||||
prompt = self._rebuild_tool_list_section(prompt)
|
||||
|
||||
# If runtime_info contains dynamic time function, rebuild runtime section
|
||||
if self.runtime_info and callable(self.runtime_info.get('_get_current_time')):
|
||||
prompt = self._rebuild_runtime_section(prompt)
|
||||
|
||||
return prompt
|
||||
|
||||
def _rebuild_runtime_section(self, prompt: str) -> str:
|
||||
"""
|
||||
Rebuild runtime info section with current time.
|
||||
|
||||
This method dynamically updates the runtime info section by calling
|
||||
the _get_current_time function from runtime_info.
|
||||
|
||||
:param prompt: Original system prompt
|
||||
:return: Updated system prompt with current runtime info
|
||||
Re-reads AGENT.md / USER.md / RULE.md from disk, refreshes skills,
|
||||
tools, and runtime info so any change takes effect immediately.
|
||||
Falls back to the cached self.system_prompt on error.
|
||||
"""
|
||||
try:
|
||||
# Get current time dynamically
|
||||
time_info = self.runtime_info['_get_current_time']()
|
||||
|
||||
# Build new runtime section
|
||||
runtime_lines = [
|
||||
"\n## 运行时信息\n",
|
||||
"\n",
|
||||
f"当前时间: {time_info['time']} {time_info['weekday']} ({time_info['timezone']})\n",
|
||||
"\n"
|
||||
]
|
||||
|
||||
# Add other runtime info
|
||||
runtime_parts = []
|
||||
if self.runtime_info.get("model"):
|
||||
runtime_parts.append(f"模型={self.runtime_info['model']}")
|
||||
if self.runtime_info.get("workspace"):
|
||||
# Replace backslashes with forward slashes for Windows paths
|
||||
workspace_path = str(self.runtime_info['workspace']).replace('\\', '/')
|
||||
runtime_parts.append(f"工作空间={workspace_path}")
|
||||
if self.runtime_info.get("channel") and self.runtime_info.get("channel") != "web":
|
||||
runtime_parts.append(f"渠道={self.runtime_info['channel']}")
|
||||
|
||||
if runtime_parts:
|
||||
runtime_lines.append("运行时: " + " | ".join(runtime_parts) + "\n")
|
||||
runtime_lines.append("\n")
|
||||
|
||||
new_runtime_section = "".join(runtime_lines)
|
||||
|
||||
# Find and replace the runtime section
|
||||
import re
|
||||
pattern = r'\n## 运行时信息\s*\n.*?(?=\n##|\Z)'
|
||||
updated_prompt = re.sub(pattern, new_runtime_section.rstrip('\n'), prompt, flags=re.DOTALL)
|
||||
|
||||
return updated_prompt
|
||||
from agent.prompt import load_context_files, PromptBuilder
|
||||
|
||||
if self.skill_manager:
|
||||
self.skill_manager.refresh_skills()
|
||||
|
||||
context_files = load_context_files(self.workspace_dir) if self.workspace_dir else None
|
||||
|
||||
builder = PromptBuilder(workspace_dir=self.workspace_dir or "", language="zh")
|
||||
return builder.build(
|
||||
tools=self.tools,
|
||||
context_files=context_files,
|
||||
skill_manager=self.skill_manager,
|
||||
memory_manager=self.memory_manager,
|
||||
runtime_info=self.runtime_info,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to rebuild runtime section: {e}")
|
||||
return prompt
|
||||
|
||||
def _rebuild_tool_list_section(self, prompt: str) -> str:
|
||||
"""
|
||||
Rebuild the tool list inside the '## 工具系统' section so that it
|
||||
always reflects the current ``self.tools`` (handles dynamic add/remove
|
||||
of conditional tools like web_search).
|
||||
"""
|
||||
import re
|
||||
from agent.prompt.builder import _build_tooling_section
|
||||
|
||||
try:
|
||||
if not self.tools:
|
||||
return prompt
|
||||
|
||||
new_lines = _build_tooling_section(self.tools, "zh")
|
||||
new_section = "\n".join(new_lines).rstrip("\n")
|
||||
|
||||
# Replace existing tooling section
|
||||
pattern = r'## 工具系统\s*\n.*?(?=\n## |\Z)'
|
||||
updated = re.sub(pattern, new_section, prompt, count=1, flags=re.DOTALL)
|
||||
return updated
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to rebuild tool list section: {e}")
|
||||
return prompt
|
||||
logger.warning(f"Failed to rebuild system prompt, using cached version: {e}")
|
||||
return self.system_prompt
|
||||
|
||||
def refresh_skills(self):
|
||||
"""Refresh the loaded skills."""
|
||||
@@ -432,7 +365,8 @@ class Agent:
|
||||
|
||||
return action
|
||||
|
||||
def run_stream(self, user_message: str, on_event=None, clear_history: bool = False, skill_filter=None) -> str:
|
||||
def run_stream(self, user_message: str, on_event=None, clear_history: bool = False,
|
||||
skill_filter=None, cancel_event=None) -> str:
|
||||
"""
|
||||
Execute single agent task with streaming (based on tool-call)
|
||||
|
||||
@@ -441,6 +375,7 @@ class Agent:
|
||||
- Multi-turn reasoning based on tool-call
|
||||
- Event callbacks
|
||||
- Persistent conversation history across calls
|
||||
- User-initiated cancellation via ``cancel_event``
|
||||
|
||||
Args:
|
||||
user_message: User message
|
||||
@@ -448,6 +383,11 @@ class Agent:
|
||||
event = {"type": str, "timestamp": float, "data": dict}
|
||||
clear_history: If True, clear conversation history before this call (default: False)
|
||||
skill_filter: Optional list of skill names to include in this run
|
||||
cancel_event: Optional threading.Event polled at agent checkpoints.
|
||||
When set, the loop exits at the next safe point, injects a
|
||||
"[Interrupted by user]" assistant note, and returns the
|
||||
partial response. ``messages`` stays in a valid state
|
||||
(tool_use/tool_result pairs preserved).
|
||||
|
||||
Returns:
|
||||
Final response text
|
||||
@@ -480,7 +420,7 @@ class Agent:
|
||||
|
||||
# Get max_context_turns from config
|
||||
from config import conf
|
||||
max_context_turns = conf().get("agent_max_context_turns", 30)
|
||||
max_context_turns = conf().get("agent_max_context_turns", 20)
|
||||
|
||||
# Create stream executor with copied message history
|
||||
executor = AgentStreamExecutor(
|
||||
@@ -491,7 +431,8 @@ class Agent:
|
||||
max_turns=self.max_steps,
|
||||
on_event=on_event,
|
||||
messages=messages_copy, # Pass copied message history
|
||||
max_context_turns=max_context_turns
|
||||
max_context_turns=max_context_turns,
|
||||
cancel_event=cancel_event,
|
||||
)
|
||||
|
||||
# Execute
|
||||
@@ -507,11 +448,15 @@ class Agent:
|
||||
logger.info("[Agent] Cleared Agent message history after executor recovery")
|
||||
raise
|
||||
|
||||
# Append only the NEW messages from this execution (thread-safe)
|
||||
# This allows concurrent requests to both contribute to history
|
||||
# Sync executor's messages back to agent (thread-safe).
|
||||
# If the executor trimmed context, its message list is shorter than
|
||||
# original_length, so we must replace rather than append.
|
||||
with self.messages_lock:
|
||||
new_messages = executor.messages[original_length:]
|
||||
self.messages.extend(new_messages)
|
||||
self.messages = list(executor.messages)
|
||||
# Track messages added in this run (user query + all assistant/tool messages)
|
||||
# original_length may exceed executor.messages length after trimming
|
||||
trim_adjusted_start = min(original_length, len(executor.messages))
|
||||
self._last_run_new_messages = list(executor.messages[trim_adjusted_start:])
|
||||
|
||||
# Store executor reference for agent_bridge to access files_to_send
|
||||
self.stream_executor = executor
|
||||
|
||||
121
agent/protocol/cancel.py
Normal file
@@ -0,0 +1,121 @@
|
||||
"""
|
||||
Cancel token registry for aborting in-flight agent runs.
|
||||
|
||||
A user cancel (web Cancel button, /cancel command) sets a threading.Event
|
||||
that the agent loop polls at safe checkpoints. Tokens are keyed by
|
||||
request_id (preferred) and tracked under session_id as a fallback. Entries
|
||||
are released after the run completes to keep the registry bounded.
|
||||
|
||||
No project deps — importable from any layer without circular imports.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import threading
|
||||
from typing import Dict, Optional
|
||||
|
||||
|
||||
class AgentCancelledError(Exception):
|
||||
"""Raised inside the agent loop when a stop has been requested.
|
||||
|
||||
The agent stream executor catches this, injects a "[Interrupted]" note
|
||||
into the message history (preserving tool_use/tool_result integrity)
|
||||
and returns a partial response to the caller.
|
||||
"""
|
||||
|
||||
|
||||
class _CancelEntry:
|
||||
__slots__ = ("event", "session_id")
|
||||
|
||||
def __init__(self, session_id: Optional[str]):
|
||||
self.event = threading.Event()
|
||||
self.session_id = session_id
|
||||
|
||||
|
||||
class CancelTokenRegistry:
|
||||
"""In-process registry mapping request_id -> cancel Event.
|
||||
|
||||
Thread-safe. Singleton via module-level ``_registry``.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._lock = threading.Lock()
|
||||
self._by_request: Dict[str, _CancelEntry] = {}
|
||||
# session_id -> set of request_ids currently in flight (usually 1).
|
||||
self._by_session: Dict[str, set] = {}
|
||||
|
||||
def register(self, request_id: str, session_id: Optional[str] = None) -> threading.Event:
|
||||
"""Create (or return existing) cancel event for a request.
|
||||
|
||||
Returns the threading.Event the caller should poll via ``is_set()``.
|
||||
"""
|
||||
if not request_id:
|
||||
return threading.Event()
|
||||
with self._lock:
|
||||
entry = self._by_request.get(request_id)
|
||||
if entry is None:
|
||||
entry = _CancelEntry(session_id)
|
||||
self._by_request[request_id] = entry
|
||||
if session_id:
|
||||
self._by_session.setdefault(session_id, set()).add(request_id)
|
||||
return entry.event
|
||||
|
||||
def get_event(self, request_id: str) -> Optional[threading.Event]:
|
||||
if not request_id:
|
||||
return None
|
||||
with self._lock:
|
||||
entry = self._by_request.get(request_id)
|
||||
return entry.event if entry else None
|
||||
|
||||
def cancel_request(self, request_id: str) -> bool:
|
||||
"""Trigger cancel for a specific request. Returns True when matched."""
|
||||
if not request_id:
|
||||
return False
|
||||
with self._lock:
|
||||
entry = self._by_request.get(request_id)
|
||||
if entry is None:
|
||||
return False
|
||||
entry.event.set()
|
||||
return True
|
||||
|
||||
def cancel_session(self, session_id: str) -> int:
|
||||
"""Trigger cancel for every in-flight request of a session.
|
||||
|
||||
Returns the number of requests cancelled (0 when nothing was running).
|
||||
"""
|
||||
if not session_id:
|
||||
return 0
|
||||
with self._lock:
|
||||
request_ids = list(self._by_session.get(session_id, ()))
|
||||
entries = [self._by_request[r] for r in request_ids if r in self._by_request]
|
||||
for entry in entries:
|
||||
entry.event.set()
|
||||
return len(entries)
|
||||
|
||||
def unregister(self, request_id: str) -> None:
|
||||
"""Remove an entry once the agent run is done. Safe to call twice."""
|
||||
if not request_id:
|
||||
return
|
||||
with self._lock:
|
||||
entry = self._by_request.pop(request_id, None)
|
||||
if entry and entry.session_id:
|
||||
bucket = self._by_session.get(entry.session_id)
|
||||
if bucket is not None:
|
||||
bucket.discard(request_id)
|
||||
if not bucket:
|
||||
self._by_session.pop(entry.session_id, None)
|
||||
|
||||
def has_active(self, session_id: str) -> bool:
|
||||
if not session_id:
|
||||
return False
|
||||
with self._lock:
|
||||
bucket = self._by_session.get(session_id)
|
||||
return bool(bucket)
|
||||
|
||||
|
||||
_registry = CancelTokenRegistry()
|
||||
|
||||
|
||||
def get_cancel_registry() -> CancelTokenRegistry:
|
||||
"""Module-level accessor for the singleton registry."""
|
||||
return _registry
|
||||
335
agent/protocol/message_utils.py
Normal file
@@ -0,0 +1,335 @@
|
||||
"""
|
||||
Message sanitizer — fix broken tool_use / tool_result pairs.
|
||||
|
||||
Provides two public helpers that can be reused across agent_stream.py
|
||||
and any bot that converts messages to OpenAI format:
|
||||
|
||||
1. sanitize_claude_messages(messages)
|
||||
Operates on the internal Claude-format message list (in-place).
|
||||
|
||||
2. drop_orphaned_tool_results_openai(messages)
|
||||
Operates on an already-converted OpenAI-format message list,
|
||||
returning a cleaned copy.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Dict, List, Set
|
||||
|
||||
from common.log import logger
|
||||
|
||||
_SYNTH_TOOL_ERR = (
|
||||
"Error: Missing tool_result adjacent to tool_use (session repair). "
|
||||
"The conversation history was inconsistent; continue from here."
|
||||
)
|
||||
|
||||
|
||||
def _repair_tool_use_adjacency(messages: List[Dict]) -> int:
|
||||
"""
|
||||
Anthropic requires: after assistant content with tool_use, the next message
|
||||
must be user content listing tool_result for every tool_use id (same user msg).
|
||||
|
||||
Valid histories satisfy this at every such assistant; the loop only mutates
|
||||
when that condition fails (broken persistence, bad trims, etc.).
|
||||
"""
|
||||
|
||||
def _synth_block(tid: str) -> Dict:
|
||||
return {
|
||||
"type": "tool_result",
|
||||
"tool_use_id": tid,
|
||||
"content": _SYNTH_TOOL_ERR,
|
||||
"is_error": True,
|
||||
}
|
||||
|
||||
repairs = 0
|
||||
i = 0
|
||||
while i < len(messages):
|
||||
msg = messages[i]
|
||||
if msg.get("role") != "assistant":
|
||||
i += 1
|
||||
continue
|
||||
|
||||
content = msg.get("content", [])
|
||||
if not isinstance(content, list):
|
||||
i += 1
|
||||
continue
|
||||
|
||||
required = [
|
||||
b.get("id")
|
||||
for b in content
|
||||
if isinstance(b, dict) and b.get("type") == "tool_use" and b.get("id")
|
||||
]
|
||||
if not required:
|
||||
i += 1
|
||||
continue
|
||||
|
||||
req_set = set(required)
|
||||
if i + 1 >= len(messages):
|
||||
messages.append({
|
||||
"role": "user",
|
||||
"content": [_synth_block(tid) for tid in required],
|
||||
})
|
||||
logger.warning(
|
||||
"⚠️ Appended synthetic tool_result after trailing assistant tool_use"
|
||||
)
|
||||
repairs += 1
|
||||
break
|
||||
|
||||
nxt = messages[i + 1]
|
||||
if nxt.get("role") != "user":
|
||||
messages.insert(
|
||||
i + 1,
|
||||
{"role": "user", "content": [_synth_block(tid) for tid in required]},
|
||||
)
|
||||
logger.warning(
|
||||
"⚠️ Inserted synthetic tool_result user after tool_use "
|
||||
f"(next role={nxt.get('role')!r})"
|
||||
)
|
||||
repairs += 1
|
||||
i += 2
|
||||
continue
|
||||
|
||||
nc = nxt.get("content", [])
|
||||
if not isinstance(nc, list):
|
||||
messages.insert(
|
||||
i + 1,
|
||||
{"role": "user", "content": [_synth_block(tid) for tid in required]},
|
||||
)
|
||||
repairs += 1
|
||||
i += 2
|
||||
continue
|
||||
|
||||
present = {
|
||||
b.get("tool_use_id")
|
||||
for b in nc
|
||||
if isinstance(b, dict) and b.get("type") == "tool_result" and b.get("tool_use_id")
|
||||
}
|
||||
if req_set <= present:
|
||||
i += 1
|
||||
continue
|
||||
|
||||
missing = [tid for tid in required if tid not in present]
|
||||
nxt["content"] = [_synth_block(tid) for tid in missing] + nc
|
||||
logger.warning(
|
||||
"⚠️ Prepended synthetic tool_result for Anthropic adjacency "
|
||||
f"(missing_ids={missing})"
|
||||
)
|
||||
repairs += len(missing)
|
||||
i += 1
|
||||
|
||||
return repairs
|
||||
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# Claude-format sanitizer (used by agent_stream)
|
||||
# ------------------------------------------------------------------ #
|
||||
|
||||
def sanitize_claude_messages(messages: List[Dict]) -> int:
|
||||
"""
|
||||
Validate and fix a Claude-format message list **in-place**.
|
||||
|
||||
Fixes handled:
|
||||
- Anthropic adjacency: assistant tool_use must be immediately followed by
|
||||
user message(s) containing matching tool_result blocks
|
||||
- Leading orphaned tool_result user messages
|
||||
- Mid-list tool_result blocks whose tool_use_id has no matching
|
||||
tool_use in any preceding assistant message
|
||||
|
||||
Returns: number of removals plus adjacency repair operations (inserts/prepends).
|
||||
"""
|
||||
if not messages:
|
||||
return 0
|
||||
|
||||
removed = 0
|
||||
|
||||
# 1. Adjacency repair (Anthropic: tool_result must be in the next user message)
|
||||
adj_repairs = _repair_tool_use_adjacency(messages)
|
||||
|
||||
# 2. Remove leading orphaned tool_result user messages
|
||||
while messages:
|
||||
first = messages[0]
|
||||
if first.get("role") != "user":
|
||||
break
|
||||
content = first.get("content", [])
|
||||
if isinstance(content, list) and _has_block_type(content, "tool_result") \
|
||||
and not _has_block_type(content, "text"):
|
||||
logger.warning("⚠️ Removing leading orphaned tool_result user message")
|
||||
messages.pop(0)
|
||||
removed += 1
|
||||
else:
|
||||
break
|
||||
|
||||
# 3. Iteratively remove unmatched tool_use / tool_result until stable.
|
||||
# Removing one broken message can orphan others (e.g. an assistant msg
|
||||
# with both matched and unmatched tool_use — deleting it orphans the
|
||||
# previously-matched tool_result). Loop until clean.
|
||||
for _ in range(5):
|
||||
use_ids: Set[str] = set()
|
||||
result_ids: Set[str] = set()
|
||||
for msg in messages:
|
||||
for block in (msg.get("content") or []):
|
||||
if not isinstance(block, dict):
|
||||
continue
|
||||
if block.get("type") == "tool_use" and block.get("id"):
|
||||
use_ids.add(block["id"])
|
||||
elif block.get("type") == "tool_result" and block.get("tool_use_id"):
|
||||
result_ids.add(block["tool_use_id"])
|
||||
|
||||
bad_use = use_ids - result_ids
|
||||
bad_result = result_ids - use_ids
|
||||
if not bad_use and not bad_result:
|
||||
break
|
||||
|
||||
pass_removed = 0
|
||||
i = 0
|
||||
while i < len(messages):
|
||||
msg = messages[i]
|
||||
role = msg.get("role")
|
||||
content = msg.get("content", [])
|
||||
if not isinstance(content, list):
|
||||
i += 1
|
||||
continue
|
||||
|
||||
if role == "assistant" and bad_use and any(
|
||||
isinstance(b, dict) and b.get("type") == "tool_use"
|
||||
and b.get("id") in bad_use for b in content
|
||||
):
|
||||
logger.warning(f"⚠️ Removing assistant msg with unmatched tool_use")
|
||||
messages.pop(i)
|
||||
pass_removed += 1
|
||||
continue
|
||||
|
||||
if role == "user" and bad_result and _has_block_type(content, "tool_result"):
|
||||
has_bad = any(
|
||||
isinstance(b, dict) and b.get("type") == "tool_result"
|
||||
and b.get("tool_use_id") in bad_result for b in content
|
||||
)
|
||||
if has_bad:
|
||||
if not _has_block_type(content, "text"):
|
||||
logger.warning(f"⚠️ Removing user msg with unmatched tool_result")
|
||||
messages.pop(i)
|
||||
pass_removed += 1
|
||||
continue
|
||||
else:
|
||||
before = len(content)
|
||||
msg["content"] = [
|
||||
b for b in content
|
||||
if not (isinstance(b, dict) and b.get("type") == "tool_result"
|
||||
and b.get("tool_use_id") in bad_result)
|
||||
]
|
||||
pass_removed += before - len(msg["content"])
|
||||
|
||||
i += 1
|
||||
|
||||
removed += pass_removed
|
||||
if pass_removed == 0:
|
||||
break
|
||||
|
||||
# 4. Removals above can break adjacency; re-run repair only if something was removed.
|
||||
if removed:
|
||||
adj_repairs += _repair_tool_use_adjacency(messages)
|
||||
|
||||
if removed:
|
||||
logger.info(f"🔧 Message validation: removed {removed} broken message(s)")
|
||||
if adj_repairs:
|
||||
logger.info(f"🔧 Message validation: adjacency repairs={adj_repairs}")
|
||||
return removed + adj_repairs
|
||||
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# OpenAI-format sanitizer (used by minimax_bot, openai_compatible_bot)
|
||||
# ------------------------------------------------------------------ #
|
||||
|
||||
def drop_orphaned_tool_results_openai(messages: List[Dict]) -> List[Dict]:
|
||||
"""
|
||||
Return a copy of *messages* (OpenAI format) with any ``role=tool``
|
||||
messages removed if their ``tool_call_id`` does not match a
|
||||
``tool_calls[].id`` in a preceding assistant message.
|
||||
"""
|
||||
known_ids: Set[str] = set()
|
||||
cleaned: List[Dict] = []
|
||||
for msg in messages:
|
||||
if msg.get("role") == "assistant" and msg.get("tool_calls"):
|
||||
for tc in msg["tool_calls"]:
|
||||
tc_id = tc.get("id", "")
|
||||
if tc_id:
|
||||
known_ids.add(tc_id)
|
||||
|
||||
if msg.get("role") == "tool":
|
||||
ref_id = msg.get("tool_call_id", "")
|
||||
if ref_id and ref_id not in known_ids:
|
||||
logger.warning(
|
||||
f"[MessageSanitizer] Dropping orphaned tool result "
|
||||
f"(tool_call_id={ref_id} not in known ids)"
|
||||
)
|
||||
continue
|
||||
cleaned.append(msg)
|
||||
return cleaned
|
||||
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# Internal helpers
|
||||
# ------------------------------------------------------------------ #
|
||||
|
||||
def _has_block_type(content: list, block_type: str) -> bool:
|
||||
return any(
|
||||
isinstance(b, dict) and b.get("type") == block_type
|
||||
for b in content
|
||||
)
|
||||
|
||||
|
||||
def _extract_text_from_content(content) -> str:
|
||||
"""Extract plain text from a message content field (str or list of blocks)."""
|
||||
if isinstance(content, str):
|
||||
return content.strip()
|
||||
if isinstance(content, list):
|
||||
parts = [
|
||||
b.get("text", "")
|
||||
for b in content
|
||||
if isinstance(b, dict) and b.get("type") == "text"
|
||||
]
|
||||
return "\n".join(p for p in parts if p).strip()
|
||||
return ""
|
||||
|
||||
|
||||
def compress_turn_to_text_only(turn: Dict) -> Dict:
|
||||
"""
|
||||
Compress a full turn (with tool_use/tool_result chains) into a lightweight
|
||||
text-only turn that keeps only the first user text and the last assistant text.
|
||||
|
||||
This preserves the conversational context (what the user asked and what the
|
||||
agent concluded) while stripping out the bulky intermediate tool interactions.
|
||||
|
||||
Returns a new turn dict with a ``messages`` list; the original is not mutated.
|
||||
"""
|
||||
user_text = ""
|
||||
last_assistant_text = ""
|
||||
|
||||
for msg in turn["messages"]:
|
||||
role = msg.get("role")
|
||||
content = msg.get("content", [])
|
||||
|
||||
if role == "user":
|
||||
if isinstance(content, list) and _has_block_type(content, "tool_result"):
|
||||
continue
|
||||
if not user_text:
|
||||
user_text = _extract_text_from_content(content)
|
||||
|
||||
elif role == "assistant":
|
||||
text = _extract_text_from_content(content)
|
||||
if text:
|
||||
last_assistant_text = text
|
||||
|
||||
compressed_messages = []
|
||||
if user_text:
|
||||
compressed_messages.append({
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": user_text}]
|
||||
})
|
||||
if last_assistant_text:
|
||||
compressed_messages.append({
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "text": last_assistant_text}]
|
||||
})
|
||||
|
||||
return {"messages": compressed_messages}
|
||||
@@ -123,17 +123,63 @@ def should_include_skill(
|
||||
return False
|
||||
|
||||
# Check environment variables (API keys)
|
||||
# Simple rule: All required env vars must be set
|
||||
# All required env vars must be set
|
||||
required_env = metadata.requires.get('env', [])
|
||||
if required_env:
|
||||
for env_name in required_env:
|
||||
if not has_env_var(env_name):
|
||||
# Missing required API key → disable skill
|
||||
return False
|
||||
|
||||
# Check anyEnv (at least one must be present)
|
||||
any_env = metadata.requires.get('anyEnv', [])
|
||||
if any_env:
|
||||
if not any(has_env_var(e) for e in any_env):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def get_missing_requirements(
|
||||
entry: SkillEntry,
|
||||
current_platform: Optional[str] = None,
|
||||
) -> Dict[str, List[str]]:
|
||||
"""
|
||||
Return a dict of missing requirements for a skill.
|
||||
Empty dict means all requirements are met.
|
||||
|
||||
:param entry: SkillEntry to check
|
||||
:param current_platform: Current platform (default: auto-detect)
|
||||
:return: Dict like {"bins": ["curl"], "env": ["API_KEY"]}
|
||||
"""
|
||||
missing: Dict[str, List[str]] = {}
|
||||
metadata = entry.metadata
|
||||
|
||||
if not metadata or not metadata.requires:
|
||||
return missing
|
||||
|
||||
required_bins = metadata.requires.get('bins', [])
|
||||
if required_bins:
|
||||
missing_bins = [b for b in required_bins if not has_binary(b)]
|
||||
if missing_bins:
|
||||
missing['bins'] = missing_bins
|
||||
|
||||
any_bins = metadata.requires.get('anyBins', [])
|
||||
if any_bins and not has_any_binary(any_bins):
|
||||
missing['anyBins'] = any_bins
|
||||
|
||||
required_env = metadata.requires.get('env', [])
|
||||
if required_env:
|
||||
missing_env = [e for e in required_env if not has_env_var(e)]
|
||||
if missing_env:
|
||||
missing['env'] = missing_env
|
||||
|
||||
any_env = metadata.requires.get('anyEnv', [])
|
||||
if any_env and not any(has_env_var(e) for e in any_env):
|
||||
missing['anyEnv'] = any_env
|
||||
|
||||
return missing
|
||||
|
||||
|
||||
def is_config_path_truthy(config: Dict, path: str) -> bool:
|
||||
"""
|
||||
Check if a config path resolves to a truthy value.
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
Skill formatter for generating prompts from skills.
|
||||
"""
|
||||
|
||||
from typing import List
|
||||
from typing import Dict, List
|
||||
from agent.skills.types import Skill, SkillEntry
|
||||
|
||||
|
||||
@@ -32,6 +32,7 @@ def format_skills_for_prompt(skills: List[Skill]) -> str:
|
||||
lines.append(f" <name>{_escape_xml(skill.name)}</name>")
|
||||
lines.append(f" <description>{_escape_xml(skill.description)}</description>")
|
||||
lines.append(f" <location>{_escape_xml(skill.file_path)}</location>")
|
||||
lines.append(f" <base_dir>{_escape_xml(skill.base_dir)}</base_dir>")
|
||||
lines.append(" </skill>")
|
||||
|
||||
lines.append("</available_skills>")
|
||||
@@ -50,6 +51,71 @@ def format_skill_entries_for_prompt(entries: List[SkillEntry]) -> str:
|
||||
return format_skills_for_prompt(skills)
|
||||
|
||||
|
||||
def format_unavailable_skills_for_prompt(
|
||||
entries: List[SkillEntry],
|
||||
missing_map: Dict[str, Dict[str, List[str]]],
|
||||
) -> str:
|
||||
"""
|
||||
Format unavailable (requires-not-met) skills as brief setup hints
|
||||
so the AI can guide users to configure them.
|
||||
|
||||
:param entries: List of unavailable skill entries
|
||||
:param missing_map: Dict mapping skill name to its missing requirements
|
||||
:return: Formatted prompt text
|
||||
"""
|
||||
if not entries:
|
||||
return ""
|
||||
|
||||
lines = [
|
||||
"",
|
||||
"<unavailable_skills>",
|
||||
"The following skills are installed but not yet ready. "
|
||||
"Guide the user to complete the setup when relevant.",
|
||||
]
|
||||
|
||||
for entry in entries:
|
||||
skill = entry.skill
|
||||
missing = missing_map.get(skill.name, {})
|
||||
|
||||
missing_parts = []
|
||||
for key, values in missing.items():
|
||||
missing_parts.append(f"{key}: {', '.join(values)}")
|
||||
missing_str = "; ".join(missing_parts) if missing_parts else "unknown"
|
||||
|
||||
setup_hint = _extract_setup_hint(skill)
|
||||
|
||||
lines.append(" <skill>")
|
||||
lines.append(f" <name>{_escape_xml(skill.name)}</name>")
|
||||
lines.append(f" <description>{_escape_xml(skill.description)}</description>")
|
||||
lines.append(f" <missing>{_escape_xml(missing_str)}</missing>")
|
||||
if setup_hint:
|
||||
lines.append(f" <setup>{_escape_xml(setup_hint)}</setup>")
|
||||
lines.append(" </skill>")
|
||||
|
||||
lines.append("</unavailable_skills>")
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def _extract_setup_hint(skill: Skill) -> str:
|
||||
"""
|
||||
Extract the Setup section from SKILL.md content as a brief hint.
|
||||
Returns the first few lines of the ## Setup section.
|
||||
"""
|
||||
content = skill.content
|
||||
if not content:
|
||||
return ""
|
||||
|
||||
import re
|
||||
match = re.search(r'^##\s+Setup\s*\n(.*?)(?=\n##\s|\Z)', content, re.MULTILINE | re.DOTALL)
|
||||
if not match:
|
||||
return ""
|
||||
|
||||
setup_text = match.group(1).strip()
|
||||
lines = setup_text.split('\n')
|
||||
hint_lines = [l.strip() for l in lines[:6] if l.strip()]
|
||||
return ' '.join(hint_lines)[:300]
|
||||
|
||||
|
||||
def _escape_xml(text: str) -> str:
|
||||
"""Escape XML special characters."""
|
||||
return (text
|
||||
|
||||
@@ -87,8 +87,8 @@ def parse_metadata(frontmatter: Dict[str, Any]) -> Optional[SkillMetadata]:
|
||||
if not isinstance(metadata_raw, dict):
|
||||
return None
|
||||
|
||||
# Use metadata_raw directly (COW format)
|
||||
meta_obj = metadata_raw
|
||||
# Unwrap nested namespace (e.g. {"openclaw": {...}} or {"cowagent": {...}})
|
||||
meta_obj = _unwrap_metadata_namespace(metadata_raw)
|
||||
|
||||
# Parse install specs
|
||||
install_specs = []
|
||||
@@ -128,6 +128,7 @@ def parse_metadata(frontmatter: Dict[str, Any]) -> Optional[SkillMetadata]:
|
||||
|
||||
return SkillMetadata(
|
||||
always=meta_obj.get('always', False),
|
||||
default_enabled=meta_obj.get('default_enabled', True),
|
||||
skill_key=meta_obj.get('skillKey'),
|
||||
primary_env=meta_obj.get('primaryEnv'),
|
||||
emoji=meta_obj.get('emoji'),
|
||||
@@ -138,6 +139,25 @@ def parse_metadata(frontmatter: Dict[str, Any]) -> Optional[SkillMetadata]:
|
||||
)
|
||||
|
||||
|
||||
_KNOWN_METADATA_NAMESPACES = {"cowagent", "openclaw"}
|
||||
|
||||
|
||||
def _unwrap_metadata_namespace(metadata_raw: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Unwrap a single-key namespace wrapper like {"cowagent": {...} or {"openclaw": {...}}}.
|
||||
If the top-level dict has exactly one key matching a known namespace, return the inner dict.
|
||||
Otherwise return the original dict unchanged.
|
||||
"""
|
||||
keys = set(metadata_raw.keys())
|
||||
ns_keys = keys & _KNOWN_METADATA_NAMESPACES
|
||||
if len(ns_keys) == 1 and len(keys) == 1:
|
||||
ns = ns_keys.pop()
|
||||
inner = metadata_raw[ns]
|
||||
if isinstance(inner, dict):
|
||||
return inner
|
||||
return metadata_raw
|
||||
|
||||
|
||||
def _normalize_string_list(value: Any) -> List[str]:
|
||||
"""Normalize a value to a list of strings."""
|
||||
if not value:
|
||||
|
||||
@@ -53,6 +53,12 @@ class SkillLoader:
|
||||
"""
|
||||
Recursively load skills from a directory.
|
||||
|
||||
If a subdirectory contains its own SKILL.md, it is treated as a
|
||||
self-contained skill (or skill-collection) and its children are
|
||||
NOT scanned further. This prevents sub-skills inside a collection
|
||||
(e.g. style-collection/style-anjing) from being listed as
|
||||
independent top-level skills.
|
||||
|
||||
:param dir_path: Directory to scan
|
||||
:param source: Source identifier
|
||||
:param include_root_files: Whether to include root-level .md files
|
||||
@@ -66,38 +72,41 @@ class SkillLoader:
|
||||
except Exception as e:
|
||||
diagnostics.append(f"Failed to list directory {dir_path}: {e}")
|
||||
return LoadSkillsResult(skills=skills, diagnostics=diagnostics)
|
||||
|
||||
# If this directory has its own SKILL.md, load it and stop recursing.
|
||||
# The sub-directories are internal resources of this skill.
|
||||
if not include_root_files and 'SKILL.md' in entries:
|
||||
skill_md_path = os.path.join(dir_path, 'SKILL.md')
|
||||
if os.path.isfile(skill_md_path):
|
||||
skill_result = self._load_skill_from_file(skill_md_path, source)
|
||||
if skill_result.skills:
|
||||
skills.extend(skill_result.skills)
|
||||
diagnostics.extend(skill_result.diagnostics)
|
||||
return LoadSkillsResult(skills=skills, diagnostics=diagnostics)
|
||||
|
||||
for entry in entries:
|
||||
# Skip hidden files and directories
|
||||
if entry.startswith('.'):
|
||||
continue
|
||||
|
||||
# Skip common non-skill directories
|
||||
if entry in ('node_modules', '__pycache__', 'venv', '.git'):
|
||||
continue
|
||||
|
||||
full_path = os.path.join(dir_path, entry)
|
||||
|
||||
# Handle directories
|
||||
if os.path.isdir(full_path):
|
||||
# Recursively scan subdirectories
|
||||
sub_result = self._load_skills_recursive(full_path, source, include_root_files=False)
|
||||
skills.extend(sub_result.skills)
|
||||
diagnostics.extend(sub_result.diagnostics)
|
||||
continue
|
||||
|
||||
# Handle files
|
||||
if not os.path.isfile(full_path):
|
||||
continue
|
||||
|
||||
# Check if this is a skill file
|
||||
is_root_md = include_root_files and entry.endswith('.md')
|
||||
is_skill_md = not include_root_files and entry == 'SKILL.md'
|
||||
is_root_md = include_root_files and entry.endswith('.md') and entry.upper() != 'README.MD'
|
||||
|
||||
if not (is_root_md or is_skill_md):
|
||||
if not is_root_md:
|
||||
continue
|
||||
|
||||
# Load the skill
|
||||
skill_result = self._load_skill_from_file(full_path, source)
|
||||
if skill_result.skills:
|
||||
skills.extend(skill_result.skills)
|
||||
@@ -184,7 +193,6 @@ class SkillLoader:
|
||||
|
||||
config_path = os.path.join(skill_dir, "config.json")
|
||||
|
||||
# Without config.json, skip this skill entirely (return empty to trigger exclusion)
|
||||
if not os.path.exists(config_path):
|
||||
logger.debug(f"[SkillLoader] linkai-agent skipped: no config.json found")
|
||||
return ""
|
||||
|
||||
@@ -84,10 +84,10 @@ class SkillManager:
|
||||
"""
|
||||
Merge directory-scanned skills with the persisted config file.
|
||||
|
||||
- New skills discovered on disk are added with enabled=True.
|
||||
- New skills: use metadata.default_enabled as initial enabled state.
|
||||
- Existing skills: preserve their persisted enabled state.
|
||||
- Skills that no longer exist on disk are removed.
|
||||
- Existing entries preserve their enabled state; name/description/source
|
||||
are refreshed from the latest scan.
|
||||
- name/description/source are always refreshed from the latest scan.
|
||||
"""
|
||||
saved = self._load_skills_config()
|
||||
merged: Dict[str, dict] = {}
|
||||
@@ -95,12 +95,24 @@ class SkillManager:
|
||||
for name, entry in self.skills.items():
|
||||
skill = entry.skill
|
||||
prev = saved.get(name, {})
|
||||
merged[name] = {
|
||||
category = prev.get("category", "skill")
|
||||
|
||||
if name in saved:
|
||||
enabled = prev.get("enabled", True)
|
||||
else:
|
||||
enabled = entry.metadata.default_enabled if entry.metadata else True
|
||||
|
||||
entry_dict = {
|
||||
"name": name,
|
||||
"description": skill.description,
|
||||
"source": skill.source,
|
||||
"enabled": prev.get("enabled", True),
|
||||
"source": prev.get("source") or skill.source,
|
||||
"enabled": enabled,
|
||||
"category": category,
|
||||
}
|
||||
display_name = prev.get("display_name")
|
||||
if display_name:
|
||||
entry_dict["display_name"] = display_name
|
||||
merged[name] = entry_dict
|
||||
|
||||
self.skills_config = merged
|
||||
self._save_skills_config()
|
||||
@@ -154,69 +166,118 @@ class SkillManager:
|
||||
"""
|
||||
return list(self.skills.values())
|
||||
|
||||
@staticmethod
|
||||
def _normalize_skill_filter(skill_filter: Optional[List[str]]) -> Optional[List[str]]:
|
||||
"""Normalize a skill_filter list into a flat list of stripped names."""
|
||||
if skill_filter is None:
|
||||
return None
|
||||
normalized = []
|
||||
for item in skill_filter:
|
||||
if isinstance(item, str):
|
||||
name = item.strip()
|
||||
if name:
|
||||
normalized.append(name)
|
||||
elif isinstance(item, list):
|
||||
for subitem in item:
|
||||
if isinstance(subitem, str):
|
||||
name = subitem.strip()
|
||||
if name:
|
||||
normalized.append(name)
|
||||
return normalized or None
|
||||
|
||||
def filter_skills(
|
||||
self,
|
||||
skill_filter: Optional[List[str]] = None,
|
||||
include_disabled: bool = False,
|
||||
) -> List[SkillEntry]:
|
||||
"""
|
||||
Filter skills based on criteria.
|
||||
|
||||
Simple rule: Skills are auto-enabled if requirements are met.
|
||||
- Has required API keys -> included
|
||||
- Missing API keys -> excluded
|
||||
Filter skills that are eligible (enabled + requirements met).
|
||||
|
||||
:param skill_filter: List of skill names to include (None = all)
|
||||
:param include_disabled: Whether to include disabled skills
|
||||
:return: Filtered list of skill entries
|
||||
:return: Filtered list of eligible skill entries
|
||||
"""
|
||||
from agent.skills.config import should_include_skill
|
||||
|
||||
entries = list(self.skills.values())
|
||||
|
||||
# Check requirements (platform, binaries, env vars)
|
||||
entries = [e for e in entries if should_include_skill(e, self.config)]
|
||||
|
||||
# Apply skill filter
|
||||
if skill_filter is not None:
|
||||
normalized = []
|
||||
for item in skill_filter:
|
||||
if isinstance(item, str):
|
||||
name = item.strip()
|
||||
if name:
|
||||
normalized.append(name)
|
||||
elif isinstance(item, list):
|
||||
for subitem in item:
|
||||
if isinstance(subitem, str):
|
||||
name = subitem.strip()
|
||||
if name:
|
||||
normalized.append(name)
|
||||
if normalized:
|
||||
entries = [e for e in entries if e.skill.name in normalized]
|
||||
normalized = self._normalize_skill_filter(skill_filter)
|
||||
if normalized is not None:
|
||||
entries = [e for e in entries if e.skill.name in normalized]
|
||||
|
||||
# Filter out disabled skills based on skills_config.json
|
||||
if not include_disabled:
|
||||
entries = [e for e in entries if self.is_skill_enabled(e.skill.name)]
|
||||
|
||||
from config import conf
|
||||
if not conf().get("knowledge", True):
|
||||
entries = [e for e in entries if e.skill.name != "knowledge-wiki"]
|
||||
|
||||
return entries
|
||||
|
||||
|
||||
def filter_unavailable_skills(
|
||||
self,
|
||||
skill_filter: Optional[List[str]] = None,
|
||||
) -> tuple:
|
||||
"""
|
||||
Find skills that are enabled but have unmet requirements.
|
||||
|
||||
:param skill_filter: Optional list of skill names to include
|
||||
:return: Tuple of (entries, missing_map) where missing_map maps
|
||||
skill name to its missing requirements dict
|
||||
"""
|
||||
from agent.skills.config import should_include_skill, get_missing_requirements
|
||||
|
||||
entries = list(self.skills.values())
|
||||
|
||||
# Only enabled skills
|
||||
entries = [e for e in entries if self.is_skill_enabled(e.skill.name)]
|
||||
|
||||
normalized = self._normalize_skill_filter(skill_filter)
|
||||
if normalized is not None:
|
||||
entries = [e for e in entries if e.skill.name in normalized]
|
||||
|
||||
# Keep only those that fail should_include_skill (requirements not met)
|
||||
unavailable = []
|
||||
missing_map: Dict[str, dict] = {}
|
||||
for e in entries:
|
||||
if not should_include_skill(e, self.config):
|
||||
missing = get_missing_requirements(e)
|
||||
if missing:
|
||||
unavailable.append(e)
|
||||
missing_map[e.skill.name] = missing
|
||||
|
||||
return unavailable, missing_map
|
||||
|
||||
def build_skills_prompt(
|
||||
self,
|
||||
skill_filter: Optional[List[str]] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Build a formatted prompt containing available skills.
|
||||
|
||||
Build a formatted prompt containing available skills
|
||||
and brief hints for unavailable ones.
|
||||
|
||||
:param skill_filter: Optional list of skill names to include
|
||||
:return: Formatted skills prompt
|
||||
"""
|
||||
from common.log import logger
|
||||
entries = self.filter_skills(skill_filter=skill_filter, include_disabled=False)
|
||||
logger.debug(f"[SkillManager] Filtered {len(entries)} skills for prompt (total: {len(self.skills)})")
|
||||
if entries:
|
||||
skill_names = [e.skill.name for e in entries]
|
||||
logger.debug(f"[SkillManager] Skills to include: {skill_names}")
|
||||
result = format_skill_entries_for_prompt(entries)
|
||||
from agent.skills.formatter import format_unavailable_skills_for_prompt
|
||||
|
||||
eligible = self.filter_skills(skill_filter=skill_filter, include_disabled=False)
|
||||
logger.debug(f"[SkillManager] Eligible: {len(eligible)} skills (total: {len(self.skills)})")
|
||||
if eligible:
|
||||
skill_names = [e.skill.name for e in eligible]
|
||||
logger.debug(f"[SkillManager] Eligible skills: {skill_names}")
|
||||
|
||||
result = format_skill_entries_for_prompt(eligible)
|
||||
|
||||
unavailable, missing_map = self.filter_unavailable_skills(skill_filter=skill_filter)
|
||||
if unavailable:
|
||||
unavailable_names = [e.skill.name for e in unavailable]
|
||||
logger.debug(f"[SkillManager] Unavailable skills (setup needed): {unavailable_names}")
|
||||
result += format_unavailable_skills_for_prompt(unavailable, missing_map)
|
||||
|
||||
logger.debug(f"[SkillManager] Generated prompt length: {len(result)}")
|
||||
return result
|
||||
|
||||
|
||||
@@ -8,6 +8,8 @@ other management entry point.
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import zipfile
|
||||
import tempfile
|
||||
from typing import Dict, List, Optional
|
||||
from common.log import logger
|
||||
from agent.skills.types import Skill, SkillEntry
|
||||
@@ -55,7 +57,9 @@ class SkillService:
|
||||
"""
|
||||
Add (install) a skill from a remote payload.
|
||||
|
||||
The payload follows the socket protocol::
|
||||
Supported payload types:
|
||||
|
||||
1. ``type: "url"`` – download individual files::
|
||||
|
||||
{
|
||||
"name": "web_search",
|
||||
@@ -67,8 +71,15 @@ class SkillService:
|
||||
]
|
||||
}
|
||||
|
||||
Files are downloaded and saved under the custom skills directory
|
||||
using *name* as the sub-directory.
|
||||
2. ``type: "package"`` – download a zip archive and extract::
|
||||
|
||||
{
|
||||
"name": "plugin-custom-tool",
|
||||
"type": "package",
|
||||
"category": "skills",
|
||||
"enabled": true,
|
||||
"files": [{"url": "https://cdn.example.com/skills/custom-tool.zip"}]
|
||||
}
|
||||
|
||||
:param payload: skill add payload from server
|
||||
"""
|
||||
@@ -76,25 +87,95 @@ class SkillService:
|
||||
if not name:
|
||||
raise ValueError("skill name is required")
|
||||
|
||||
payload_type = payload.get("type", "url")
|
||||
|
||||
if payload_type == "package":
|
||||
self._add_package(name, payload)
|
||||
else:
|
||||
self._add_url(name, payload)
|
||||
|
||||
self.manager.refresh_skills()
|
||||
|
||||
category = payload.get("category")
|
||||
if category and name in self.manager.skills_config:
|
||||
self.manager.skills_config[name]["category"] = category
|
||||
self.manager._save_skills_config()
|
||||
|
||||
def _add_url(self, name: str, payload: dict) -> None:
|
||||
"""Install a skill by downloading individual files."""
|
||||
files = payload.get("files", [])
|
||||
if not files:
|
||||
raise ValueError("skill files list is empty")
|
||||
|
||||
skill_dir = os.path.join(self.manager.custom_dir, name)
|
||||
os.makedirs(skill_dir, exist_ok=True)
|
||||
|
||||
for file_info in files:
|
||||
url = file_info.get("url")
|
||||
rel_path = file_info.get("path")
|
||||
if not url or not rel_path:
|
||||
logger.warning(f"[SkillService] add: skip invalid file entry {file_info}")
|
||||
continue
|
||||
dest = os.path.join(skill_dir, rel_path)
|
||||
self._download_file(url, dest)
|
||||
tmp_dir = skill_dir + ".tmp"
|
||||
if os.path.exists(tmp_dir):
|
||||
shutil.rmtree(tmp_dir)
|
||||
os.makedirs(tmp_dir, exist_ok=True)
|
||||
|
||||
# Reload to pick up the new skill and sync config
|
||||
self.manager.refresh_skills()
|
||||
logger.info(f"[SkillService] add: skill '{name}' installed ({len(files)} files)")
|
||||
try:
|
||||
for file_info in files:
|
||||
url = file_info.get("url")
|
||||
rel_path = file_info.get("path")
|
||||
if not url or not rel_path:
|
||||
logger.warning(f"[SkillService] add: skip invalid file entry {file_info}")
|
||||
continue
|
||||
dest = os.path.join(tmp_dir, rel_path)
|
||||
self._download_file(url, dest)
|
||||
except Exception:
|
||||
shutil.rmtree(tmp_dir, ignore_errors=True)
|
||||
raise
|
||||
|
||||
if os.path.exists(skill_dir):
|
||||
shutil.rmtree(skill_dir)
|
||||
os.rename(tmp_dir, skill_dir)
|
||||
|
||||
logger.info(f"[SkillService] add: skill '{name}' installed via url ({len(files)} files)")
|
||||
|
||||
def _add_package(self, name: str, payload: dict) -> None:
|
||||
"""
|
||||
Install a skill by downloading a zip archive and extracting it.
|
||||
|
||||
If the archive contains a single top-level directory, that directory
|
||||
is used as the skill folder directly; otherwise a new directory named
|
||||
after the skill is created to hold the extracted contents.
|
||||
"""
|
||||
files = payload.get("files", [])
|
||||
if not files or not files[0].get("url"):
|
||||
raise ValueError("package url is required")
|
||||
|
||||
url = files[0]["url"]
|
||||
skill_dir = os.path.join(self.manager.custom_dir, name)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
zip_path = os.path.join(tmp_dir, "package.zip")
|
||||
self._download_file(url, zip_path)
|
||||
|
||||
if not zipfile.is_zipfile(zip_path):
|
||||
raise ValueError(f"downloaded file is not a valid zip archive: {url}")
|
||||
|
||||
extract_dir = os.path.join(tmp_dir, "extracted")
|
||||
with zipfile.ZipFile(zip_path, "r") as zf:
|
||||
zf.extractall(extract_dir)
|
||||
|
||||
# Determine the actual content root.
|
||||
# If the zip has a single top-level directory, use its contents
|
||||
# so the skill folder is clean (no extra nesting).
|
||||
top_items = [
|
||||
item for item in os.listdir(extract_dir)
|
||||
if not item.startswith(".")
|
||||
]
|
||||
if len(top_items) == 1:
|
||||
single = os.path.join(extract_dir, top_items[0])
|
||||
if os.path.isdir(single):
|
||||
extract_dir = single
|
||||
|
||||
if os.path.exists(skill_dir):
|
||||
shutil.rmtree(skill_dir)
|
||||
shutil.copytree(extract_dir, skill_dir)
|
||||
|
||||
logger.info(f"[SkillService] add: skill '{name}' installed via package ({url})")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# open / close (enable / disable)
|
||||
|
||||
@@ -29,6 +29,7 @@ class SkillInstallSpec:
|
||||
class SkillMetadata:
|
||||
"""Metadata for a skill from frontmatter."""
|
||||
always: bool = False # Always include this skill
|
||||
default_enabled: bool = True # Initial enabled state when first discovered
|
||||
skill_key: Optional[str] = None # Override skill key
|
||||
primary_env: Optional[str] = None # Primary environment variable
|
||||
emoji: Optional[str] = None
|
||||
|
||||
@@ -55,6 +55,24 @@ def _import_optional_tools():
|
||||
except Exception as e:
|
||||
logger.error(f"[Tools] WebSearch failed to load: {e}")
|
||||
|
||||
# WebFetch Tool
|
||||
try:
|
||||
from agent.tools.web_fetch.web_fetch import WebFetch
|
||||
tools['WebFetch'] = WebFetch
|
||||
except ImportError as e:
|
||||
logger.error(f"[Tools] WebFetch not loaded - missing dependency: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"[Tools] WebFetch failed to load: {e}")
|
||||
|
||||
# Vision Tool (conditionally loaded based on API key availability)
|
||||
try:
|
||||
from agent.tools.vision.vision import Vision
|
||||
tools['Vision'] = Vision
|
||||
except ImportError as e:
|
||||
logger.error(f"[Tools] Vision not loaded - missing dependency: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"[Tools] Vision failed to load: {e}")
|
||||
|
||||
return tools
|
||||
|
||||
# Load optional tools
|
||||
@@ -62,30 +80,48 @@ _optional_tools = _import_optional_tools()
|
||||
EnvConfig = _optional_tools.get('EnvConfig')
|
||||
SchedulerTool = _optional_tools.get('SchedulerTool')
|
||||
WebSearch = _optional_tools.get('WebSearch')
|
||||
WebFetch = _optional_tools.get('WebFetch')
|
||||
Vision = _optional_tools.get('Vision')
|
||||
GoogleSearch = _optional_tools.get('GoogleSearch')
|
||||
FileSave = _optional_tools.get('FileSave')
|
||||
Terminal = _optional_tools.get('Terminal')
|
||||
|
||||
|
||||
# Delayed import for BrowserTool
|
||||
# BrowserTool (requires playwright)
|
||||
def _import_browser_tool():
|
||||
from common.log import logger
|
||||
try:
|
||||
from agent.tools.browser.browser_tool import BrowserTool
|
||||
return BrowserTool
|
||||
except ImportError:
|
||||
# Return a placeholder class that will prompt the user to install dependencies when instantiated
|
||||
class BrowserToolPlaceholder:
|
||||
def __init__(self, *args, **kwargs):
|
||||
raise ImportError(
|
||||
"The 'browser-use' package is required to use BrowserTool. "
|
||||
"Please install it with 'pip install browser-use>=0.1.40'."
|
||||
)
|
||||
except ImportError as e:
|
||||
logger.info(
|
||||
f"[Tools] BrowserTool not loaded - missing dependency: {e}\n"
|
||||
f" To enable browser tool, run:\n"
|
||||
f" pip install playwright\n"
|
||||
f" playwright install chromium"
|
||||
)
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"[Tools] BrowserTool failed to load: {e}")
|
||||
return None
|
||||
|
||||
return BrowserToolPlaceholder
|
||||
BrowserTool = _import_browser_tool()
|
||||
|
||||
# MCP Tools (no extra dependencies, loaded on demand)
|
||||
def _import_mcp_tools():
|
||||
"""导入 MCP 工具模块(无额外依赖,按需加载)"""
|
||||
from common.log import logger
|
||||
try:
|
||||
from agent.tools.mcp.mcp_tool import McpTool
|
||||
from agent.tools.mcp.mcp_client import McpClientRegistry
|
||||
return {'McpTool': McpTool, 'McpClientRegistry': McpClientRegistry}
|
||||
except Exception as e:
|
||||
logger.warning(f"[Tools] MCP tools not loaded: {e}")
|
||||
return {}
|
||||
|
||||
# Dynamically set BrowserTool
|
||||
# BrowserTool = _import_browser_tool()
|
||||
_mcp_tools = _import_mcp_tools()
|
||||
McpTool = _mcp_tools.get('McpTool')
|
||||
McpClientRegistry = _mcp_tools.get('McpClientRegistry')
|
||||
|
||||
# Export all tools (including optional ones that might be None)
|
||||
__all__ = [
|
||||
@@ -102,8 +138,10 @@ __all__ = [
|
||||
'EnvConfig',
|
||||
'SchedulerTool',
|
||||
'WebSearch',
|
||||
# Optional tools (may be None if dependencies not available)
|
||||
# 'BrowserTool'
|
||||
'WebFetch',
|
||||
'Vision',
|
||||
'BrowserTool',
|
||||
'McpTool',
|
||||
]
|
||||
|
||||
"""
|
||||
|
||||
@@ -3,6 +3,7 @@ Bash tool - Execute bash commands
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import subprocess
|
||||
import tempfile
|
||||
@@ -17,14 +18,18 @@ from common.utils import expand_path
|
||||
class Bash(BaseTool):
|
||||
"""Tool for executing bash commands"""
|
||||
|
||||
_IS_WIN = sys.platform == "win32"
|
||||
|
||||
name: str = "bash"
|
||||
description: str = f"""Execute a bash command in the current working directory. Returns stdout and stderr. Output is truncated to last {DEFAULT_MAX_LINES} lines or {DEFAULT_MAX_BYTES // 1024}KB (whichever is hit first). If truncated, full output is saved to a temp file.
|
||||
|
||||
{'''
|
||||
PLATFORM: Windows (cmd.exe). Do NOT use Unix-only commands like grep, head, tail, sed, awk.
|
||||
''' if _IS_WIN else ''}
|
||||
ENVIRONMENT: All API keys from env_config are auto-injected. Use $VAR_NAME directly.
|
||||
|
||||
SAFETY:
|
||||
- Freely create/modify/delete files within the workspace
|
||||
- For destructive and out-of-workspace commands, explain and confirm first"""
|
||||
- For destructive commands out of workspace, explain and confirm first"""
|
||||
|
||||
params: dict = {
|
||||
"type": "object",
|
||||
@@ -83,12 +88,13 @@ SAFETY:
|
||||
|
||||
# Load environment variables from ~/.cow/.env if it exists
|
||||
env_file = expand_path("~/.cow/.env")
|
||||
dotenv_vars = {}
|
||||
if os.path.exists(env_file):
|
||||
try:
|
||||
from dotenv import dotenv_values
|
||||
env_vars = dotenv_values(env_file)
|
||||
env.update(env_vars)
|
||||
logger.debug(f"[Bash] Loaded {len(env_vars)} variables from {env_file}")
|
||||
dotenv_vars = dotenv_values(env_file)
|
||||
env.update(dotenv_vars)
|
||||
logger.debug(f"[Bash] Loaded {len(dotenv_vars)} variables from {env_file}")
|
||||
except ImportError:
|
||||
logger.debug("[Bash] python-dotenv not installed, skipping .env loading")
|
||||
except Exception as e:
|
||||
@@ -100,7 +106,13 @@ SAFETY:
|
||||
else:
|
||||
logger.debug(f"[Bash] Process User: {os.environ.get('USERNAME', os.environ.get('USER', 'unknown'))}")
|
||||
|
||||
# Execute command with inherited environment variables
|
||||
# On Windows, convert $VAR references to %VAR% for cmd.exe
|
||||
if self._IS_WIN:
|
||||
env["PYTHONIOENCODING"] = "utf-8"
|
||||
command = self._convert_env_vars_for_windows(command, dotenv_vars)
|
||||
if command and not command.strip().lower().startswith("chcp"):
|
||||
command = f"chcp 65001 >nul 2>&1 && {command}"
|
||||
|
||||
result = subprocess.run(
|
||||
command,
|
||||
shell=True,
|
||||
@@ -108,8 +120,10 @@ SAFETY:
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
text=True,
|
||||
encoding="utf-8",
|
||||
errors="replace",
|
||||
timeout=timeout,
|
||||
env=env
|
||||
env=env,
|
||||
)
|
||||
|
||||
logger.debug(f"[Bash] Exit code: {result.returncode}")
|
||||
@@ -131,6 +145,8 @@ SAFETY:
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
text=True,
|
||||
encoding="utf-8",
|
||||
errors="replace",
|
||||
timeout=timeout,
|
||||
env=env
|
||||
)
|
||||
@@ -153,10 +169,16 @@ SAFETY:
|
||||
except Exception as retry_err:
|
||||
logger.warning(f"[Bash] Retry failed: {retry_err}")
|
||||
|
||||
# Combine stdout and stderr
|
||||
output = result.stdout
|
||||
if result.stderr:
|
||||
output += "\n" + result.stderr
|
||||
# When command succeeds with stdout, keep output clean (stderr goes to server log only).
|
||||
# When command fails or stdout is empty, include stderr so the agent can diagnose.
|
||||
if result.returncode == 0 and result.stdout.strip():
|
||||
output = result.stdout
|
||||
if result.stderr:
|
||||
logger.info(f"[Bash] stderr (not forwarded): {result.stderr[:500]}")
|
||||
else:
|
||||
output = result.stdout
|
||||
if result.stderr:
|
||||
output += "\n" + result.stderr
|
||||
|
||||
# Check if we need to save full output to temp file
|
||||
temp_file_path = None
|
||||
@@ -216,45 +238,58 @@ SAFETY:
|
||||
|
||||
def _get_safety_warning(self, command: str) -> str:
|
||||
"""
|
||||
Get safety warning for potentially dangerous commands
|
||||
Only warns about extremely dangerous system-level operations
|
||||
|
||||
Get safety warning for absolutely catastrophic commands only.
|
||||
Keep the blocklist minimal so the agent retains maximum freedom.
|
||||
|
||||
:param command: Command to check
|
||||
:return: Warning message if dangerous, empty string if safe
|
||||
"""
|
||||
cmd_lower = command.lower().strip()
|
||||
# Tokenize to avoid substring false positives (e.g. `rm -rf /tmp/x`
|
||||
# must not match `rm -rf /`).
|
||||
tokens = command.lower().split()
|
||||
|
||||
# Only block extremely dangerous system operations
|
||||
dangerous_patterns = [
|
||||
# System shutdown/reboot
|
||||
("shutdown", "This command will shut down the system"),
|
||||
("reboot", "This command will reboot the system"),
|
||||
("halt", "This command will halt the system"),
|
||||
("poweroff", "This command will power off the system"),
|
||||
# `rm -rf /` or `rm -rf /*` targeting the real root.
|
||||
for i, tok in enumerate(tokens):
|
||||
if tok != "rm":
|
||||
continue
|
||||
has_rf = False
|
||||
for j in range(i + 1, len(tokens)):
|
||||
t = tokens[j]
|
||||
if t.startswith("-") and "r" in t and "f" in t:
|
||||
has_rf = True
|
||||
elif t in ("--recursive", "--force"):
|
||||
continue
|
||||
elif t in ("/", "/*"):
|
||||
if has_rf:
|
||||
return "This command will delete the entire filesystem"
|
||||
break
|
||||
else:
|
||||
break
|
||||
|
||||
# Critical system modifications
|
||||
("rm -rf /", "This command will delete the entire filesystem"),
|
||||
("rm -rf /*", "This command will delete the entire filesystem"),
|
||||
("dd if=/dev/zero", "This command can destroy disk data"),
|
||||
("mkfs", "This command will format a filesystem, destroying all data"),
|
||||
("fdisk", "This command modifies disk partitions"),
|
||||
# Disk wiping
|
||||
if "if=/dev/zero" in command.lower() and "dd " in command.lower():
|
||||
return "This command can destroy disk data"
|
||||
|
||||
# User/system management (only if targeting system users)
|
||||
("userdel root", "This command will delete the root user"),
|
||||
("passwd root", "This command will change the root password"),
|
||||
]
|
||||
# Power control - match only as a standalone word (\b enforces word boundary)
|
||||
if re.search(r'\b(shutdown|reboot|halt|poweroff)\b', command.lower()):
|
||||
return "This command will shut down or restart the system"
|
||||
|
||||
for pattern, warning in dangerous_patterns:
|
||||
if pattern in cmd_lower:
|
||||
return warning
|
||||
return ""
|
||||
|
||||
# Check for recursive deletion outside workspace
|
||||
if "rm" in cmd_lower and "-rf" in cmd_lower:
|
||||
# Allow deletion within current workspace
|
||||
if not any(path in cmd_lower for path in ["./", self.cwd.lower()]):
|
||||
# Check if targeting system directories
|
||||
system_dirs = ["/bin", "/usr", "/etc", "/var", "/home", "/root", "/sys", "/proc"]
|
||||
if any(sysdir in cmd_lower for sysdir in system_dirs):
|
||||
return "This command will recursively delete system directories"
|
||||
@staticmethod
|
||||
def _convert_env_vars_for_windows(command: str, dotenv_vars: dict) -> str:
|
||||
"""
|
||||
Convert bash-style $VAR / ${VAR} references to cmd.exe %VAR% syntax.
|
||||
Only converts variables loaded from .env (user-configured API keys etc.)
|
||||
to avoid breaking $PATH, jq expressions, regex, etc.
|
||||
"""
|
||||
if not dotenv_vars:
|
||||
return command
|
||||
|
||||
return "" # No warning needed
|
||||
def replace_match(m):
|
||||
var_name = m.group(1) or m.group(2)
|
||||
if var_name in dotenv_vars:
|
||||
return f"%{var_name}%"
|
||||
return m.group(0)
|
||||
|
||||
return re.sub(r'\$\{(\w+)\}|\$(\w+)', replace_match, command)
|
||||
|
||||
3
agent/tools/browser/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from agent.tools.browser.browser_tool import BrowserTool
|
||||
|
||||
__all__ = ["BrowserTool"]
|
||||
947
agent/tools/browser/browser_service.py
Normal file
@@ -0,0 +1,947 @@
|
||||
"""
|
||||
Browser service - Playwright wrapper managing browser lifecycle and page operations.
|
||||
|
||||
All Playwright calls run on a dedicated background thread so that callers from
|
||||
any worker thread can safely use the service. An idle-timeout mechanism
|
||||
automatically shuts down the browser (and its thread) after a configurable
|
||||
period of inactivity to free resources.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import uuid
|
||||
import queue
|
||||
import threading
|
||||
from typing import Optional, Dict, Any, List, Callable
|
||||
|
||||
from common.log import logger
|
||||
from common.utils import expand_path
|
||||
|
||||
|
||||
_DEFAULT_USER_DATA_DIR = "~/.cow/browser_profile"
|
||||
|
||||
try:
|
||||
from playwright.sync_api import sync_playwright, Browser, BrowserContext, Page, Playwright
|
||||
_HAS_PLAYWRIGHT = True
|
||||
except ImportError:
|
||||
_HAS_PLAYWRIGHT = False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Snapshot DOM helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Tags that typically carry useful content for an agent
|
||||
_INTERACTIVE_TAGS = {
|
||||
"a", "button", "input", "textarea", "select", "option",
|
||||
"label", "details", "summary",
|
||||
}
|
||||
_SEMANTIC_TAGS = {
|
||||
"h1", "h2", "h3", "h4", "h5", "h6",
|
||||
"p", "li", "td", "th", "caption", "figcaption", "blockquote", "pre", "code",
|
||||
"nav", "main", "article", "section", "header", "footer", "form", "table",
|
||||
"img", "video", "audio",
|
||||
}
|
||||
_KEEP_TAGS = _INTERACTIVE_TAGS | _SEMANTIC_TAGS
|
||||
|
||||
_SNAPSHOT_JS = """
|
||||
() => {
|
||||
const KEEP = new Set(%s);
|
||||
const INTERACTIVE = new Set(%s);
|
||||
const SKIP = new Set(["script","style","noscript","svg","path","meta","link","br","hr"]);
|
||||
const CLICKABLE_ROLES = new Set([
|
||||
"button","link","tab","menuitem","menuitemcheckbox","menuitemradio",
|
||||
"option","switch","checkbox","radio","combobox","searchbox","slider",
|
||||
"spinbutton","textbox","treeitem"
|
||||
]);
|
||||
let refCounter = 0;
|
||||
const refMap = {};
|
||||
|
||||
function visible(el) {
|
||||
if (!(el instanceof HTMLElement)) return true;
|
||||
const st = window.getComputedStyle(el);
|
||||
if (st.display === "none" || st.visibility === "hidden") return false;
|
||||
if (parseFloat(st.opacity) === 0) return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
// Strong signals: these attributes alone are enough to mark as interactive
|
||||
function hasStrongInteractiveSignal(el) {
|
||||
const role = el.getAttribute("role");
|
||||
if (role && CLICKABLE_ROLES.has(role)) return true;
|
||||
if (el.hasAttribute("onclick") || el.hasAttribute("tabindex")) return true;
|
||||
if (el.hasAttribute("data-click") || el.hasAttribute("data-action")) return true;
|
||||
if (el.getAttribute("contenteditable") === "true") return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check if cursor:pointer is set directly (not just inherited from parent)
|
||||
function hasOwnPointerCursor(el) {
|
||||
try {
|
||||
const st = window.getComputedStyle(el);
|
||||
if (st.cursor !== "pointer") return false;
|
||||
const parent = el.parentElement;
|
||||
if (parent) {
|
||||
const pst = window.getComputedStyle(parent);
|
||||
if (pst.cursor === "pointer") return false;
|
||||
}
|
||||
return true;
|
||||
} catch(e) {}
|
||||
return false;
|
||||
}
|
||||
|
||||
function hasTextOrContent(el) {
|
||||
const t = el.textContent || "";
|
||||
if (t.trim().length > 0) return true;
|
||||
if (el.querySelector("img,video,audio,canvas")) return true;
|
||||
const ariaLabel = el.getAttribute("aria-label");
|
||||
if (ariaLabel && ariaLabel.trim()) return true;
|
||||
const title = el.getAttribute("title");
|
||||
if (title && title.trim()) return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
function isImplicitInteractive(el) {
|
||||
if (hasStrongInteractiveSignal(el)) return true;
|
||||
if (hasOwnPointerCursor(el) && hasTextOrContent(el)) return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
function getTextContent(el) {
|
||||
let text = "";
|
||||
for (const ch of el.childNodes) {
|
||||
if (ch.nodeType === Node.TEXT_NODE) {
|
||||
text += ch.textContent;
|
||||
}
|
||||
}
|
||||
return text.trim();
|
||||
}
|
||||
|
||||
function walk(node) {
|
||||
if (node.nodeType === Node.TEXT_NODE) {
|
||||
const t = node.textContent.trim();
|
||||
return t ? t : null;
|
||||
}
|
||||
if (node.nodeType !== Node.ELEMENT_NODE) return null;
|
||||
const tag = node.tagName.toLowerCase();
|
||||
if (SKIP.has(tag)) return null;
|
||||
if (!visible(node)) return null;
|
||||
|
||||
const children = [];
|
||||
for (const ch of node.childNodes) {
|
||||
const r = walk(ch);
|
||||
if (r !== null) {
|
||||
if (typeof r === "string") children.push(r);
|
||||
else children.push(r);
|
||||
}
|
||||
}
|
||||
|
||||
const nativeInteractive = INTERACTIVE.has(tag);
|
||||
const implicitInteractive = !nativeInteractive && (node instanceof HTMLElement) && isImplicitInteractive(node);
|
||||
const keep = KEEP.has(tag) || implicitInteractive;
|
||||
|
||||
if (!keep) {
|
||||
if (children.length === 0) return null;
|
||||
if (children.length === 1) return children[0];
|
||||
return children;
|
||||
}
|
||||
|
||||
const obj = { tag };
|
||||
if (nativeInteractive || implicitInteractive) {
|
||||
refCounter++;
|
||||
obj.ref = refCounter;
|
||||
refMap[refCounter] = node;
|
||||
}
|
||||
|
||||
if (implicitInteractive) {
|
||||
const role = node.getAttribute("role");
|
||||
if (role) obj.role = role;
|
||||
const directText = getTextContent(node);
|
||||
if (!directText && children.length === 0) {
|
||||
const ariaLabel = node.getAttribute("aria-label");
|
||||
const title = node.getAttribute("title");
|
||||
if (ariaLabel) obj.ariaLabel = ariaLabel;
|
||||
else if (title) obj.ariaLabel = title;
|
||||
}
|
||||
}
|
||||
|
||||
// Attributes
|
||||
if (tag === "a" && node.href) obj.href = node.getAttribute("href");
|
||||
if (tag === "img") {
|
||||
obj.alt = node.alt || "";
|
||||
obj.src = node.getAttribute("src") || "";
|
||||
}
|
||||
if (tag === "input" || tag === "textarea" || tag === "select") {
|
||||
obj.type = node.type || "text";
|
||||
obj.name = node.name || undefined;
|
||||
obj.value = node.value || undefined;
|
||||
obj.placeholder = node.placeholder || undefined;
|
||||
if (node.disabled) obj.disabled = true;
|
||||
if (tag === "input" && node.type === "checkbox") obj.checked = node.checked;
|
||||
}
|
||||
if (tag === "button") {
|
||||
if (node.disabled) obj.disabled = true;
|
||||
}
|
||||
if (tag === "option") {
|
||||
obj.value = node.value;
|
||||
if (node.selected) obj.selected = true;
|
||||
}
|
||||
if (tag === "label" && node.htmlFor) obj.for = node.htmlFor;
|
||||
|
||||
// Role / aria-label for native interactive & semantic elements
|
||||
if (!implicitInteractive) {
|
||||
const role = node.getAttribute("role");
|
||||
if (role) obj.role = role;
|
||||
const ariaLabel = node.getAttribute("aria-label");
|
||||
if (ariaLabel) obj.ariaLabel = ariaLabel;
|
||||
}
|
||||
|
||||
// Children
|
||||
if (children.length === 1 && typeof children[0] === "string") {
|
||||
obj.text = children[0];
|
||||
} else if (children.length > 0) {
|
||||
obj.children = children;
|
||||
}
|
||||
|
||||
return obj;
|
||||
}
|
||||
|
||||
const result = walk(document.body);
|
||||
window.__cowRefMap = refMap;
|
||||
return { tree: result, refCount: refCounter };
|
||||
}
|
||||
""" % (
|
||||
str(list(_KEEP_TAGS)),
|
||||
str(list(_INTERACTIVE_TAGS)),
|
||||
)
|
||||
|
||||
|
||||
_BROWSER_DEAD_HINTS = (
|
||||
"has been closed",
|
||||
"browser has disconnected",
|
||||
"target closed",
|
||||
"browser closed",
|
||||
"context or browser has been closed",
|
||||
)
|
||||
|
||||
|
||||
def _is_browser_dead_error(err: Exception) -> bool:
|
||||
"""Return True if *err* indicates the browser / page died out from under us."""
|
||||
msg = str(err).lower()
|
||||
return any(h in msg for h in _BROWSER_DEAD_HINTS)
|
||||
|
||||
|
||||
def _should_use_headless() -> bool:
|
||||
"""Decide headless mode: headless on Linux servers without display, headed elsewhere."""
|
||||
if sys.platform in ("win32", "darwin"):
|
||||
return False
|
||||
# Linux: check for display
|
||||
if os.environ.get("DISPLAY") or os.environ.get("WAYLAND_DISPLAY"):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _flatten_tree(node, indent=0) -> List[str]:
|
||||
"""Convert snapshot tree to compact text lines for LLM consumption."""
|
||||
if node is None:
|
||||
return []
|
||||
if isinstance(node, str):
|
||||
return [" " * indent + node]
|
||||
if isinstance(node, list):
|
||||
lines = []
|
||||
for child in node:
|
||||
lines.extend(_flatten_tree(child, indent))
|
||||
return lines
|
||||
if not isinstance(node, dict):
|
||||
return []
|
||||
|
||||
tag = node.get("tag", "?")
|
||||
ref = node.get("ref")
|
||||
parts = [tag]
|
||||
if ref:
|
||||
parts[0] = f"[{ref}] {tag}"
|
||||
|
||||
# Inline attributes
|
||||
for attr in ("type", "name", "href", "alt", "role", "ariaLabel", "placeholder", "value"):
|
||||
val = node.get(attr)
|
||||
if val:
|
||||
# Truncate long values
|
||||
s = str(val)
|
||||
if len(s) > 80:
|
||||
s = s[:77] + "..."
|
||||
parts.append(f'{attr}="{s}"')
|
||||
|
||||
for flag in ("disabled", "checked", "selected"):
|
||||
if node.get(flag):
|
||||
parts.append(flag)
|
||||
|
||||
prefix = " " * indent
|
||||
header = prefix + " ".join(parts)
|
||||
|
||||
text = node.get("text")
|
||||
if text:
|
||||
# Truncate long text
|
||||
if len(text) > 120:
|
||||
text = text[:117] + "..."
|
||||
header += f": {text}"
|
||||
|
||||
lines = [header]
|
||||
children = node.get("children", [])
|
||||
for child in children:
|
||||
lines.extend(_flatten_tree(child, indent + 2))
|
||||
return lines
|
||||
|
||||
|
||||
class BrowserService:
|
||||
"""Manages a Playwright browser on a dedicated background thread.
|
||||
|
||||
All Playwright operations are dispatched to a single long-lived thread via
|
||||
a task queue. Callers from *any* worker thread can use the public API
|
||||
safely. An idle timer automatically shuts the browser down after
|
||||
``idle_timeout`` seconds of inactivity (default 300 = 5 min).
|
||||
"""
|
||||
|
||||
_IDLE_TIMEOUT_DEFAULT = 300 # seconds
|
||||
|
||||
def __init__(self, config: Optional[Dict[str, Any]] = None):
|
||||
self._config = config or {}
|
||||
self._headless: Optional[bool] = None
|
||||
self._screenshot_dir: Optional[str] = None
|
||||
|
||||
# Background thread state
|
||||
self._thread: Optional[threading.Thread] = None
|
||||
self._task_queue: queue.Queue = queue.Queue()
|
||||
self._lock = threading.Lock()
|
||||
self._alive = False
|
||||
self._ready = threading.Event()
|
||||
|
||||
# Playwright objects (only accessed on the background thread)
|
||||
self._playwright = None
|
||||
self._browser = None
|
||||
self._context = None
|
||||
self._page = None
|
||||
|
||||
# Launch mode: one of "fresh" | "persistent" | "cdp".
|
||||
# - cdp: connect to an externally launched Chrome via CDP endpoint.
|
||||
# - persistent: launch with launch_persistent_context using a user_data_dir
|
||||
# so cookies / login state survive across runs (default).
|
||||
# - fresh: classic launch + new_context, clean state every run.
|
||||
cdp_endpoint = self._config.get("cdp_endpoint") or ""
|
||||
persistent_flag = self._config.get("persistent", True)
|
||||
user_data_dir_cfg = self._config.get("user_data_dir")
|
||||
if user_data_dir_cfg is None:
|
||||
user_data_dir_cfg = _DEFAULT_USER_DATA_DIR
|
||||
|
||||
self._cdp_endpoint: str = cdp_endpoint.strip() if isinstance(cdp_endpoint, str) else ""
|
||||
if self._cdp_endpoint:
|
||||
self._launch_mode = "cdp"
|
||||
self._user_data_dir: str = ""
|
||||
elif persistent_flag and user_data_dir_cfg:
|
||||
self._launch_mode = "persistent"
|
||||
self._user_data_dir = expand_path(str(user_data_dir_cfg))
|
||||
else:
|
||||
self._launch_mode = "fresh"
|
||||
self._user_data_dir = ""
|
||||
|
||||
# Idle auto-release
|
||||
idle_cfg = self._config.get("idle_timeout")
|
||||
self._idle_timeout: float = float(idle_cfg) if idle_cfg is not None else self._IDLE_TIMEOUT_DEFAULT
|
||||
self._idle_timer: Optional[threading.Timer] = None
|
||||
|
||||
# Set when the browser / page is detected to have died externally
|
||||
# (e.g. user manually closed the window). The next _submit() will then
|
||||
# tear down the stale thread and relaunch.
|
||||
self._needs_restart = False
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Background-thread lifecycle
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _start_thread(self):
|
||||
"""Start the dedicated Playwright thread if not already running."""
|
||||
with self._lock:
|
||||
if self._alive and self._thread and self._thread.is_alive():
|
||||
return
|
||||
# Wait for old thread to fully exit before creating a new one
|
||||
old = self._thread
|
||||
if old and old.is_alive():
|
||||
old.join(timeout=5)
|
||||
# Fresh queue to avoid stale sentinels from a previous close()
|
||||
self._task_queue = queue.Queue()
|
||||
self._alive = True
|
||||
self._ready = threading.Event()
|
||||
self._thread = threading.Thread(target=self._run_loop, daemon=True, name="BrowserThread")
|
||||
self._thread.start()
|
||||
# Block until browser is ready (or failed)
|
||||
self._ready.wait(timeout=30)
|
||||
|
||||
def _run_loop(self):
|
||||
"""Event loop running on the dedicated thread. Processes tasks until stopped."""
|
||||
logger.info("[Browser] Background thread started")
|
||||
try:
|
||||
self._launch_browser()
|
||||
except Exception as e:
|
||||
logger.error(f"[Browser] Failed to launch browser: {e}")
|
||||
self._alive = False
|
||||
self._ready.set()
|
||||
self._drain_queue(RuntimeError(f"Browser launch failed: {e}"))
|
||||
return
|
||||
self._ready.set()
|
||||
|
||||
while self._alive:
|
||||
try:
|
||||
task = self._task_queue.get(timeout=1.0)
|
||||
except queue.Empty:
|
||||
continue
|
||||
if task is None:
|
||||
break
|
||||
fn, args, kwargs, result_slot = task
|
||||
try:
|
||||
result_slot["value"] = fn(*args, **kwargs)
|
||||
except Exception as e:
|
||||
result_slot["error"] = e
|
||||
if _is_browser_dead_error(e):
|
||||
self._needs_restart = True
|
||||
logger.warning(
|
||||
f"[Browser] Detected closed page/context ({e}); "
|
||||
"will relaunch on next request."
|
||||
)
|
||||
finally:
|
||||
result_slot["event"].set()
|
||||
|
||||
self._shutdown_browser()
|
||||
self._drain_queue(RuntimeError("Browser thread stopped"))
|
||||
logger.info("[Browser] Background thread exited")
|
||||
|
||||
def _drain_queue(self, error: Exception):
|
||||
"""Unblock all callers waiting on the queue with an error."""
|
||||
while True:
|
||||
try:
|
||||
task = self._task_queue.get_nowait()
|
||||
except queue.Empty:
|
||||
break
|
||||
if task is None:
|
||||
continue
|
||||
_, _, _, result_slot = task
|
||||
result_slot["error"] = error
|
||||
result_slot["event"].set()
|
||||
|
||||
def _launch_browser(self):
|
||||
"""Launch / connect Chromium on the background thread."""
|
||||
if self._headless is None:
|
||||
headless_cfg = self._config.get("headless")
|
||||
self._headless = headless_cfg if headless_cfg is not None else _should_use_headless()
|
||||
|
||||
launch_args = ["--disable-dev-shm-usage"]
|
||||
if self._headless:
|
||||
launch_args.append("--no-sandbox")
|
||||
|
||||
extra_args = self._config.get("launch_args", [])
|
||||
if extra_args:
|
||||
launch_args.extend(extra_args)
|
||||
|
||||
viewport_w = self._config.get("viewport_width", 1280)
|
||||
viewport_h = self._config.get("viewport_height", 720)
|
||||
viewport = {"width": viewport_w, "height": viewport_h}
|
||||
user_agent = (
|
||||
"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) "
|
||||
"AppleWebKit/537.36 (KHTML, like Gecko) "
|
||||
"Chrome/131.0.0.0 Safari/537.36"
|
||||
)
|
||||
|
||||
self._playwright = sync_playwright().start()
|
||||
|
||||
if self._launch_mode == "cdp":
|
||||
self._connect_cdp(viewport)
|
||||
elif self._launch_mode == "persistent":
|
||||
self._launch_persistent(launch_args, viewport, user_agent)
|
||||
else:
|
||||
self._launch_fresh(launch_args, viewport, user_agent)
|
||||
|
||||
logger.info("[Browser] Browser ready")
|
||||
|
||||
def _launch_fresh(self, launch_args: List[str], viewport: Dict[str, int], user_agent: str):
|
||||
"""Classic launch: brand new Chromium with an empty context."""
|
||||
logger.info(f"[Browser] Launching Chromium (fresh, headless={self._headless})")
|
||||
self._browser = self._playwright.chromium.launch(
|
||||
headless=self._headless,
|
||||
args=launch_args,
|
||||
)
|
||||
self._context = self._browser.new_context(
|
||||
viewport=viewport,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
self._page = self._context.new_page()
|
||||
self._wire_close_listeners()
|
||||
|
||||
def _launch_persistent(self, launch_args: List[str], viewport: Dict[str, int], user_agent: str):
|
||||
"""Launch Chromium with a persistent user_data_dir so login state survives."""
|
||||
os.makedirs(self._user_data_dir, exist_ok=True)
|
||||
logger.info(
|
||||
f"[Browser] Launching Chromium (persistent, headless={self._headless}, "
|
||||
f"profile={self._user_data_dir})"
|
||||
)
|
||||
try:
|
||||
self._context = self._playwright.chromium.launch_persistent_context(
|
||||
user_data_dir=self._user_data_dir,
|
||||
headless=self._headless,
|
||||
args=launch_args,
|
||||
viewport=viewport,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
except Exception as e:
|
||||
# Profile is locked when another Chromium instance already holds it.
|
||||
msg = str(e).lower()
|
||||
if "singletonlock" in msg or "profile" in msg or "lock" in msg:
|
||||
raise RuntimeError(
|
||||
f"Browser profile '{self._user_data_dir}' is in use by another process. "
|
||||
"Close the other Chromium / cow instance, or set a different "
|
||||
"tools.browser.user_data_dir."
|
||||
) from e
|
||||
raise
|
||||
|
||||
# Persistent context has no parent Browser handle; reuse the auto-created page.
|
||||
self._browser = None
|
||||
pages = self._context.pages
|
||||
self._page = pages[0] if pages else self._context.new_page()
|
||||
self._wire_close_listeners()
|
||||
|
||||
def _connect_cdp(self, viewport: Dict[str, int]):
|
||||
"""Attach to an existing Chrome started with --remote-debugging-port."""
|
||||
endpoint = self._cdp_endpoint
|
||||
logger.info(f"[Browser] Connecting to existing Chrome via CDP: {endpoint}")
|
||||
try:
|
||||
self._browser = self._playwright.chromium.connect_over_cdp(endpoint)
|
||||
except Exception as e:
|
||||
msg = str(e).lower()
|
||||
if "econnrefused" in msg or "connect" in msg or "refused" in msg:
|
||||
raise RuntimeError(
|
||||
f"Cannot reach Chrome at {endpoint}. The CDP browser is not "
|
||||
"running. Ask the user to launch Chrome with "
|
||||
"--remote-debugging-port and --user-data-dir, then retry. "
|
||||
"Do not retry this tool until the user confirms."
|
||||
) from e
|
||||
raise
|
||||
|
||||
contexts = self._browser.contexts
|
||||
if contexts:
|
||||
self._context = contexts[0]
|
||||
else:
|
||||
self._context = self._browser.new_context(viewport=viewport)
|
||||
|
||||
pages = self._context.pages
|
||||
self._page = pages[0] if pages else self._context.new_page()
|
||||
self._wire_close_listeners()
|
||||
|
||||
def _wire_close_listeners(self):
|
||||
"""Mark needs_restart whenever the browser / context / page dies externally."""
|
||||
def _on_dead(_obj=None):
|
||||
self._needs_restart = True
|
||||
|
||||
try:
|
||||
if self._browser:
|
||||
self._browser.on("disconnected", _on_dead)
|
||||
if self._context:
|
||||
self._context.on("close", _on_dead)
|
||||
if self._page:
|
||||
self._page.on("close", _on_dead)
|
||||
except Exception as e:
|
||||
logger.debug(f"[Browser] Failed to wire close listeners: {e}")
|
||||
|
||||
def _shutdown_browser(self):
|
||||
"""Shut down Playwright resources on the background thread.
|
||||
|
||||
Mode-specific behavior:
|
||||
- cdp: only disconnect the Playwright client; leave the user's Chrome
|
||||
and its tabs untouched (do NOT close the context).
|
||||
- persistent: close the persistent context (no separate browser handle).
|
||||
- fresh: close context, then browser.
|
||||
"""
|
||||
self._cancel_idle_timer()
|
||||
|
||||
if self._launch_mode == "cdp":
|
||||
# For CDP, browser.close() only detaches the Playwright client;
|
||||
# the user's Chrome process and its tabs stay alive.
|
||||
try:
|
||||
if self._browser:
|
||||
self._browser.close()
|
||||
except Exception as e:
|
||||
logger.debug(f"[Browser] cdp disconnect error: {e}")
|
||||
else:
|
||||
for obj, label in [
|
||||
(self._context, "context"),
|
||||
(self._browser, "browser"),
|
||||
]:
|
||||
try:
|
||||
if obj:
|
||||
obj.close()
|
||||
except Exception as e:
|
||||
logger.debug(f"[Browser] {label} close error: {e}")
|
||||
|
||||
try:
|
||||
if self._playwright:
|
||||
self._playwright.stop()
|
||||
except Exception as e:
|
||||
logger.debug(f"[Browser] playwright stop error: {e}")
|
||||
self._page = None
|
||||
self._context = None
|
||||
self._browser = None
|
||||
self._playwright = None
|
||||
logger.info("[Browser] Browser closed")
|
||||
|
||||
def _submit(self, fn: Callable, *args, **kwargs):
|
||||
"""Submit *fn* to the background thread and block until it completes."""
|
||||
# If the browser died externally (e.g. user closed the window), tear
|
||||
# down the stale thread first so _start_thread() will relaunch fresh.
|
||||
if self._needs_restart:
|
||||
logger.info("[Browser] Restarting after detecting closed browser")
|
||||
self.close()
|
||||
self._needs_restart = False
|
||||
|
||||
self._start_thread()
|
||||
|
||||
if not self._alive:
|
||||
raise RuntimeError("Browser is not available")
|
||||
|
||||
self._reset_idle_timer()
|
||||
|
||||
result_slot: Dict[str, Any] = {"event": threading.Event()}
|
||||
self._task_queue.put((fn, args, kwargs, result_slot))
|
||||
|
||||
# Timeout prevents permanent hang if the background thread crashes
|
||||
completed = result_slot["event"].wait(timeout=120)
|
||||
if not completed:
|
||||
raise TimeoutError("Browser operation timed out (120s)")
|
||||
|
||||
if "error" in result_slot:
|
||||
raise result_slot["error"]
|
||||
return result_slot.get("value")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Idle auto-release
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _reset_idle_timer(self):
|
||||
self._cancel_idle_timer()
|
||||
if self._idle_timeout > 0:
|
||||
self._idle_timer = threading.Timer(self._idle_timeout, self._on_idle_timeout)
|
||||
self._idle_timer.daemon = True
|
||||
self._idle_timer.start()
|
||||
|
||||
def _cancel_idle_timer(self):
|
||||
if self._idle_timer:
|
||||
self._idle_timer.cancel()
|
||||
self._idle_timer = None
|
||||
|
||||
def _on_idle_timeout(self):
|
||||
logger.info(f"[Browser] Idle for {self._idle_timeout}s, auto-releasing browser")
|
||||
self.close()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Public lifecycle
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def close(self):
|
||||
"""Shut down browser and background thread (safe from any thread)."""
|
||||
self._cancel_idle_timer()
|
||||
with self._lock:
|
||||
if not self._alive:
|
||||
self._needs_restart = False
|
||||
return
|
||||
self._alive = False
|
||||
t = self._thread
|
||||
if self._task_queue is not None:
|
||||
self._task_queue.put(None)
|
||||
if t is not None and t.is_alive():
|
||||
t.join(timeout=10)
|
||||
with self._lock:
|
||||
self._thread = None
|
||||
self._needs_restart = False
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Actions (each method is dispatched to the background thread)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def navigate(self, url: str, timeout: int = 30000) -> Dict[str, Any]:
|
||||
return self._submit(self._do_navigate, url, timeout)
|
||||
|
||||
def _do_navigate(self, url: str, timeout: int) -> Dict[str, Any]:
|
||||
page = self._page
|
||||
try:
|
||||
resp = page.goto(url, wait_until="domcontentloaded", timeout=timeout)
|
||||
status = resp.status if resp else None
|
||||
except Exception as e:
|
||||
return {"error": f"Navigation failed: {e}"}
|
||||
|
||||
try:
|
||||
page.wait_for_load_state("networkidle", timeout=8000)
|
||||
except Exception:
|
||||
pass
|
||||
page.wait_for_timeout(500)
|
||||
|
||||
try:
|
||||
title = page.title()
|
||||
except Exception:
|
||||
title = ""
|
||||
try:
|
||||
current_url = page.url
|
||||
except Exception:
|
||||
current_url = url
|
||||
|
||||
return {"url": current_url, "title": title, "status": status}
|
||||
|
||||
def snapshot(self, selector: Optional[str] = None) -> str:
|
||||
return self._submit(self._do_snapshot, selector)
|
||||
|
||||
def _do_snapshot(self, selector: Optional[str] = None) -> str:
|
||||
page = self._page
|
||||
try:
|
||||
result = page.evaluate(_SNAPSHOT_JS)
|
||||
except Exception as e:
|
||||
return f"[Snapshot error: {e}]"
|
||||
|
||||
tree = result.get("tree")
|
||||
ref_count = result.get("refCount", 0)
|
||||
lines = _flatten_tree(tree)
|
||||
|
||||
try:
|
||||
title = page.title()
|
||||
except Exception:
|
||||
title = ""
|
||||
try:
|
||||
url = page.url
|
||||
except Exception:
|
||||
url = ""
|
||||
|
||||
header = f"Page: {title} ({url})\nInteractive elements: {ref_count}\n---"
|
||||
body = "\n".join(lines)
|
||||
|
||||
max_chars = self._config.get("snapshot_max_chars", 30000)
|
||||
if len(body) > max_chars:
|
||||
body = body[:max_chars] + "\n... [snapshot truncated]"
|
||||
|
||||
return f"{header}\n{body}"
|
||||
|
||||
def screenshot(self, full_page: bool = False, cwd: str = "") -> str:
|
||||
return self._submit(self._do_screenshot, full_page, cwd)
|
||||
|
||||
def _do_screenshot(self, full_page: bool = False, cwd: str = "") -> str:
|
||||
page = self._page
|
||||
save_dir = self._get_screenshot_dir(cwd)
|
||||
filename = f"screenshot_{uuid.uuid4().hex[:8]}.png"
|
||||
filepath = os.path.join(save_dir, filename)
|
||||
page.screenshot(path=filepath, full_page=full_page)
|
||||
logger.info(f"[Browser] Screenshot saved: {filepath}")
|
||||
return filepath
|
||||
|
||||
def click(self, ref: Optional[int] = None, selector: Optional[str] = None,
|
||||
timeout: int = 5000) -> Dict[str, Any]:
|
||||
return self._submit(self._do_click, ref, selector, timeout)
|
||||
|
||||
def _do_click(self, ref, selector, timeout) -> Dict[str, Any]:
|
||||
page = self._page
|
||||
try:
|
||||
if ref is not None:
|
||||
result = page.evaluate(f"""
|
||||
() => {{
|
||||
const el = window.__cowRefMap && window.__cowRefMap[{ref}];
|
||||
if (!el) return {{ error: "ref {ref} not found. Run snapshot first." }};
|
||||
el.click();
|
||||
return {{ clicked: true, tag: el.tagName.toLowerCase() }};
|
||||
}}
|
||||
""")
|
||||
if result.get("error"):
|
||||
return result
|
||||
page.wait_for_timeout(500)
|
||||
return result
|
||||
elif selector:
|
||||
page.click(selector, timeout=timeout)
|
||||
return {"clicked": True, "selector": selector}
|
||||
else:
|
||||
return {"error": "Provide either ref (from snapshot) or selector"}
|
||||
except Exception as e:
|
||||
return {"error": f"Click failed: {e}"}
|
||||
|
||||
def fill(self, text: str, ref: Optional[int] = None,
|
||||
selector: Optional[str] = None, timeout: int = 5000) -> Dict[str, Any]:
|
||||
return self._submit(self._do_fill, text, ref, selector, timeout)
|
||||
|
||||
def _do_fill(self, text, ref, selector, timeout) -> Dict[str, Any]:
|
||||
page = self._page
|
||||
try:
|
||||
if ref is not None:
|
||||
result = page.evaluate(f"""
|
||||
() => {{
|
||||
const el = window.__cowRefMap && window.__cowRefMap[{ref}];
|
||||
if (!el) return {{ error: "ref {ref} not found. Run snapshot first." }};
|
||||
el.focus();
|
||||
el.value = "";
|
||||
return {{ tag: el.tagName.toLowerCase(), name: el.name || "" }};
|
||||
}}
|
||||
""")
|
||||
if result.get("error"):
|
||||
return result
|
||||
page.keyboard.type(text)
|
||||
return {"filled": True, "ref": ref, "text": text}
|
||||
elif selector:
|
||||
page.fill(selector, text, timeout=timeout)
|
||||
return {"filled": True, "selector": selector, "text": text}
|
||||
else:
|
||||
return {"error": "Provide either ref (from snapshot) or selector"}
|
||||
except Exception as e:
|
||||
return {"error": f"Fill failed: {e}"}
|
||||
|
||||
def select(self, value: str, ref: Optional[int] = None,
|
||||
selector: Optional[str] = None, timeout: int = 5000) -> Dict[str, Any]:
|
||||
return self._submit(self._do_select, value, ref, selector, timeout)
|
||||
|
||||
def _do_select(self, value, ref, selector, timeout) -> Dict[str, Any]:
|
||||
page = self._page
|
||||
try:
|
||||
if ref is not None:
|
||||
result = page.evaluate(f"""
|
||||
() => {{
|
||||
const el = window.__cowRefMap && window.__cowRefMap[{ref}];
|
||||
if (!el || el.tagName.toLowerCase() !== "select")
|
||||
return {{ error: "ref {ref} is not a <select> element" }};
|
||||
el.value = {repr(value)};
|
||||
el.dispatchEvent(new Event("change", {{ bubbles: true }}));
|
||||
return {{ selected: true, value: el.value }};
|
||||
}}
|
||||
""")
|
||||
return result
|
||||
elif selector:
|
||||
page.select_option(selector, value, timeout=timeout)
|
||||
return {"selected": True, "selector": selector, "value": value}
|
||||
else:
|
||||
return {"error": "Provide either ref (from snapshot) or selector"}
|
||||
except Exception as e:
|
||||
return {"error": f"Select failed: {e}"}
|
||||
|
||||
def scroll(self, direction: str = "down", amount: int = 500) -> Dict[str, Any]:
|
||||
return self._submit(self._do_scroll, direction, amount)
|
||||
|
||||
def _do_scroll(self, direction, amount) -> Dict[str, Any]:
|
||||
page = self._page
|
||||
delta_map = {
|
||||
"down": (0, amount),
|
||||
"up": (0, -amount),
|
||||
"right": (amount, 0),
|
||||
"left": (-amount, 0),
|
||||
}
|
||||
dx, dy = delta_map.get(direction, (0, amount))
|
||||
try:
|
||||
page.mouse.wheel(dx, dy)
|
||||
page.wait_for_timeout(300)
|
||||
scroll_info = page.evaluate("""
|
||||
() => ({
|
||||
scrollX: window.scrollX,
|
||||
scrollY: window.scrollY,
|
||||
scrollHeight: document.documentElement.scrollHeight,
|
||||
clientHeight: document.documentElement.clientHeight
|
||||
})
|
||||
""")
|
||||
return {"scrolled": direction, "amount": amount, **scroll_info}
|
||||
except Exception as e:
|
||||
return {"error": f"Scroll failed: {e}"}
|
||||
|
||||
def wait(self, selector: Optional[str] = None, timeout: int = 5000,
|
||||
state: str = "visible") -> Dict[str, Any]:
|
||||
return self._submit(self._do_wait, selector, timeout, state)
|
||||
|
||||
def _do_wait(self, selector, timeout, state) -> Dict[str, Any]:
|
||||
page = self._page
|
||||
try:
|
||||
if selector:
|
||||
page.wait_for_selector(selector, timeout=timeout, state=state)
|
||||
return {"waited": True, "selector": selector, "state": state}
|
||||
else:
|
||||
page.wait_for_timeout(timeout)
|
||||
return {"waited": True, "timeout_ms": timeout}
|
||||
except Exception as e:
|
||||
return {"error": f"Wait failed: {e}"}
|
||||
|
||||
def go_back(self) -> Dict[str, Any]:
|
||||
return self._submit(self._do_go_back)
|
||||
|
||||
def _do_go_back(self) -> Dict[str, Any]:
|
||||
page = self._page
|
||||
try:
|
||||
page.go_back(wait_until="domcontentloaded", timeout=10000)
|
||||
try:
|
||||
title = page.title()
|
||||
except Exception:
|
||||
title = ""
|
||||
try:
|
||||
url = page.url
|
||||
except Exception:
|
||||
url = ""
|
||||
return {"url": url, "title": title}
|
||||
except Exception as e:
|
||||
return {"error": f"Go back failed: {e}"}
|
||||
|
||||
def go_forward(self) -> Dict[str, Any]:
|
||||
return self._submit(self._do_go_forward)
|
||||
|
||||
def _do_go_forward(self) -> Dict[str, Any]:
|
||||
page = self._page
|
||||
try:
|
||||
page.go_forward(wait_until="domcontentloaded", timeout=10000)
|
||||
try:
|
||||
title = page.title()
|
||||
except Exception:
|
||||
title = ""
|
||||
try:
|
||||
url = page.url
|
||||
except Exception:
|
||||
url = ""
|
||||
return {"url": url, "title": title}
|
||||
except Exception as e:
|
||||
return {"error": f"Go forward failed: {e}"}
|
||||
|
||||
def get_text(self, selector: str) -> Dict[str, Any]:
|
||||
return self._submit(self._do_get_text, selector)
|
||||
|
||||
def _do_get_text(self, selector) -> Dict[str, Any]:
|
||||
page = self._page
|
||||
try:
|
||||
text = page.text_content(selector, timeout=5000)
|
||||
return {"text": text or ""}
|
||||
except Exception as e:
|
||||
return {"error": f"Get text failed: {e}"}
|
||||
|
||||
def evaluate(self, script: str) -> Dict[str, Any]:
|
||||
return self._submit(self._do_evaluate, script)
|
||||
|
||||
def _do_evaluate(self, script) -> Dict[str, Any]:
|
||||
page = self._page
|
||||
try:
|
||||
result = page.evaluate(script)
|
||||
return {"result": result}
|
||||
except Exception as e:
|
||||
return {"error": f"Evaluate failed: {e}"}
|
||||
|
||||
def press(self, key: str) -> Dict[str, Any]:
|
||||
return self._submit(self._do_press, key)
|
||||
|
||||
def _do_press(self, key) -> Dict[str, Any]:
|
||||
page = self._page
|
||||
try:
|
||||
page.keyboard.press(key)
|
||||
page.wait_for_timeout(300)
|
||||
return {"pressed": key}
|
||||
except Exception as e:
|
||||
return {"error": f"Press failed: {e}"}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _get_screenshot_dir(self, cwd: str = "") -> str:
|
||||
if self._screenshot_dir and os.path.isdir(self._screenshot_dir):
|
||||
return self._screenshot_dir
|
||||
base = cwd or os.getcwd()
|
||||
d = os.path.join(base, "tmp")
|
||||
os.makedirs(d, exist_ok=True)
|
||||
self._screenshot_dir = d
|
||||
return d
|
||||
303
agent/tools/browser/browser_tool.py
Normal file
@@ -0,0 +1,303 @@
|
||||
"""
|
||||
Browser tool - Control a Chromium browser for web navigation and interaction.
|
||||
|
||||
Uses Playwright under the hood. Browser instance is lazily started on first
|
||||
use, reused across tool calls within the same session, and cleaned up via
|
||||
close().
|
||||
|
||||
Launch modes (configured under `tools.browser` in config.json):
|
||||
- persistent (default): Chromium runs with a persistent user_data_dir
|
||||
(default `~/.cow/browser_profile`), so cookies and login state survive
|
||||
across runs. The user only needs to log in once.
|
||||
- cdp: When `cdp_endpoint` is set, attach to an externally launched Chrome
|
||||
via the Chrome DevTools Protocol. Lets the agent reuse the user's real
|
||||
browser (with all logins / extensions / true fingerprints).
|
||||
- fresh: Set `persistent` to false to fall back to a clean context every run.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
from typing import Dict, Any, Optional
|
||||
|
||||
from agent.tools.base_tool import BaseTool, ToolResult
|
||||
from agent.tools.browser.browser_service import BrowserService
|
||||
from common.log import logger
|
||||
|
||||
|
||||
class BrowserTool(BaseTool):
|
||||
"""Single tool exposing all browser actions via an 'action' parameter."""
|
||||
|
||||
name: str = "browser"
|
||||
description: str = (
|
||||
"Control a browser to navigate web pages, interact with elements, and extract content. "
|
||||
"Actions: navigate, snapshot, click, fill, select, scroll, screenshot, wait, back, forward, "
|
||||
"get_text, press, evaluate.\n\n"
|
||||
"Workflow: navigate (auto-includes snapshot with element refs) → click/fill/select by ref → snapshot to verify.\n\n"
|
||||
"Use snapshot as the primary way to read pages. Use screenshot + send to show key results to the user. "
|
||||
"For login/CAPTCHA/authorization etc., screenshot and ask the user for help. "
|
||||
"Login state is persisted across sessions (cookies / localStorage are kept in a "
|
||||
"user profile directory), so once the user logs in to a site, the agent can keep "
|
||||
"using it without logging in again."
|
||||
)
|
||||
|
||||
params: dict = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"action": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"The browser action to perform. One of: "
|
||||
"navigate, snapshot, click, fill, select, scroll, "
|
||||
"screenshot, wait, back, forward, get_text, press, evaluate"
|
||||
),
|
||||
"enum": [
|
||||
"navigate", "snapshot", "click", "fill", "select", "scroll",
|
||||
"screenshot", "wait", "back", "forward", "get_text", "press",
|
||||
"evaluate"
|
||||
]
|
||||
},
|
||||
"url": {
|
||||
"type": "string",
|
||||
"description": "URL to navigate to (for 'navigate' action)"
|
||||
},
|
||||
"ref": {
|
||||
"type": "integer",
|
||||
"description": "Element ref number from snapshot (for click/fill/select)"
|
||||
},
|
||||
"selector": {
|
||||
"type": "string",
|
||||
"description": "CSS selector as fallback when ref is unavailable (for click/fill/select/wait/get_text)"
|
||||
},
|
||||
"text": {
|
||||
"type": "string",
|
||||
"description": "Text to type (for 'fill' action)"
|
||||
},
|
||||
"value": {
|
||||
"type": "string",
|
||||
"description": "Option value (for 'select' action)"
|
||||
},
|
||||
"key": {
|
||||
"type": "string",
|
||||
"description": "Key to press, e.g. Enter, Tab, Escape (for 'press' action)"
|
||||
},
|
||||
"direction": {
|
||||
"type": "string",
|
||||
"description": "Scroll direction: up, down, left, right (for 'scroll' action, default: down)"
|
||||
},
|
||||
"script": {
|
||||
"type": "string",
|
||||
"description": "JavaScript code to execute (for 'evaluate' action)"
|
||||
},
|
||||
"full_page": {
|
||||
"type": "boolean",
|
||||
"description": "Capture full page screenshot (for 'screenshot' action, default: false)"
|
||||
},
|
||||
"timeout": {
|
||||
"type": "integer",
|
||||
"description": "Timeout in milliseconds (optional, default varies by action)"
|
||||
}
|
||||
},
|
||||
"required": ["action"]
|
||||
}
|
||||
|
||||
_shared_service: Optional[BrowserService] = None
|
||||
|
||||
def __init__(self, config: dict = None):
|
||||
self.config = config or {}
|
||||
self.cwd = self.config.get("cwd", os.getcwd())
|
||||
self._service: Optional[BrowserService] = None
|
||||
|
||||
def _get_service(self) -> BrowserService:
|
||||
"""Get or create the browser service, sharing across copies."""
|
||||
if self._service is not None:
|
||||
return self._service
|
||||
|
||||
# Reuse shared service across tool copies within the same session
|
||||
if BrowserTool._shared_service is not None:
|
||||
self._service = BrowserTool._shared_service
|
||||
return self._service
|
||||
|
||||
self._service = BrowserService(self.config)
|
||||
BrowserTool._shared_service = self._service
|
||||
return self._service
|
||||
|
||||
def execute(self, args: Dict[str, Any]) -> ToolResult:
|
||||
action = args.get("action", "").strip().lower()
|
||||
if not action:
|
||||
return ToolResult.fail("Error: 'action' parameter is required")
|
||||
|
||||
handler = self._ACTION_MAP.get(action)
|
||||
if not handler:
|
||||
valid = ", ".join(sorted(self._ACTION_MAP.keys()))
|
||||
return ToolResult.fail(f"Unknown action '{action}'. Valid actions: {valid}")
|
||||
|
||||
try:
|
||||
return handler(self, args)
|
||||
except Exception as e:
|
||||
logger.error(f"[Browser] Action '{action}' error: {e}")
|
||||
return ToolResult.fail(f"Browser error ({action}): {e}")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Action handlers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _do_navigate(self, args: Dict[str, Any]) -> ToolResult:
|
||||
url = args.get("url", "").strip()
|
||||
if not url:
|
||||
return ToolResult.fail("Error: 'url' is required for navigate action")
|
||||
# Only auto-prepend https:// for bare hosts; preserve file://, about:, data:, etc.
|
||||
if "://" not in url and not url.startswith(("about:", "data:")):
|
||||
url = "https://" + url
|
||||
timeout = args.get("timeout", 30000)
|
||||
service = self._get_service()
|
||||
result = service.navigate(url, timeout=timeout)
|
||||
if "error" in result:
|
||||
return ToolResult.fail(result["error"])
|
||||
# Auto-snapshot after navigation so the agent gets page content in one call
|
||||
snapshot_text = service.snapshot()
|
||||
return ToolResult.success(
|
||||
f"Navigated to: {result['url']}\nTitle: {result['title']}\nStatus: {result['status']}\n\n"
|
||||
f"--- Page Snapshot ---\n{snapshot_text}"
|
||||
)
|
||||
|
||||
def _do_snapshot(self, args: Dict[str, Any]) -> ToolResult:
|
||||
selector = args.get("selector")
|
||||
text = self._get_service().snapshot(selector=selector)
|
||||
return ToolResult.success(text)
|
||||
|
||||
def _do_click(self, args: Dict[str, Any]) -> ToolResult:
|
||||
ref = args.get("ref")
|
||||
selector = args.get("selector")
|
||||
timeout = args.get("timeout", 5000)
|
||||
result = self._get_service().click(ref=ref, selector=selector, timeout=timeout)
|
||||
if "error" in result:
|
||||
return ToolResult.fail(result["error"])
|
||||
return ToolResult.success(f"Clicked successfully. Use 'snapshot' to see updated page.")
|
||||
|
||||
def _do_fill(self, args: Dict[str, Any]) -> ToolResult:
|
||||
text = args.get("text", "")
|
||||
ref = args.get("ref")
|
||||
selector = args.get("selector")
|
||||
timeout = args.get("timeout", 5000)
|
||||
if not text and text != "":
|
||||
return ToolResult.fail("Error: 'text' is required for fill action")
|
||||
result = self._get_service().fill(text, ref=ref, selector=selector, timeout=timeout)
|
||||
if "error" in result:
|
||||
return ToolResult.fail(result["error"])
|
||||
return ToolResult.success(f"Filled text into element. Use 'snapshot' to verify.")
|
||||
|
||||
def _do_select(self, args: Dict[str, Any]) -> ToolResult:
|
||||
value = args.get("value", "")
|
||||
ref = args.get("ref")
|
||||
selector = args.get("selector")
|
||||
timeout = args.get("timeout", 5000)
|
||||
if not value:
|
||||
return ToolResult.fail("Error: 'value' is required for select action")
|
||||
result = self._get_service().select(value, ref=ref, selector=selector, timeout=timeout)
|
||||
if "error" in result:
|
||||
return ToolResult.fail(result["error"])
|
||||
return ToolResult.success(f"Selected option '{value}'.")
|
||||
|
||||
def _do_scroll(self, args: Dict[str, Any]) -> ToolResult:
|
||||
direction = args.get("direction", "down")
|
||||
amount = args.get("timeout", 500) # reuse timeout field or default
|
||||
if "amount" in args:
|
||||
amount = args["amount"]
|
||||
result = self._get_service().scroll(direction=direction, amount=amount)
|
||||
if "error" in result:
|
||||
return ToolResult.fail(result["error"])
|
||||
pos = f"scrollY={result.get('scrollY', '?')}/{result.get('scrollHeight', '?')}"
|
||||
return ToolResult.success(f"Scrolled {direction}. Position: {pos}")
|
||||
|
||||
def _do_screenshot(self, args: Dict[str, Any]) -> ToolResult:
|
||||
full_page = args.get("full_page", False)
|
||||
filepath = self._get_service().screenshot(full_page=full_page, cwd=self.cwd)
|
||||
return ToolResult.success(f"Screenshot saved to: {filepath}")
|
||||
|
||||
def _do_wait(self, args: Dict[str, Any]) -> ToolResult:
|
||||
selector = args.get("selector")
|
||||
timeout = args.get("timeout", 5000)
|
||||
result = self._get_service().wait(selector=selector, timeout=timeout)
|
||||
if "error" in result:
|
||||
return ToolResult.fail(result["error"])
|
||||
return ToolResult.success(f"Wait completed.")
|
||||
|
||||
def _do_back(self, args: Dict[str, Any]) -> ToolResult:
|
||||
result = self._get_service().go_back()
|
||||
if "error" in result:
|
||||
return ToolResult.fail(result["error"])
|
||||
return ToolResult.success(f"Navigated back to: {result['url']}")
|
||||
|
||||
def _do_forward(self, args: Dict[str, Any]) -> ToolResult:
|
||||
result = self._get_service().go_forward()
|
||||
if "error" in result:
|
||||
return ToolResult.fail(result["error"])
|
||||
return ToolResult.success(f"Navigated forward to: {result['url']}")
|
||||
|
||||
def _do_get_text(self, args: Dict[str, Any]) -> ToolResult:
|
||||
selector = args.get("selector", "").strip()
|
||||
if not selector:
|
||||
return ToolResult.fail("Error: 'selector' is required for get_text action")
|
||||
result = self._get_service().get_text(selector)
|
||||
if "error" in result:
|
||||
return ToolResult.fail(result["error"])
|
||||
return ToolResult.success(result["text"])
|
||||
|
||||
def _do_press(self, args: Dict[str, Any]) -> ToolResult:
|
||||
key = args.get("key", "").strip()
|
||||
if not key:
|
||||
return ToolResult.fail("Error: 'key' is required for press action")
|
||||
result = self._get_service().press(key)
|
||||
if "error" in result:
|
||||
return ToolResult.fail(result["error"])
|
||||
return ToolResult.success(f"Pressed key: {key}")
|
||||
|
||||
def _do_evaluate(self, args: Dict[str, Any]) -> ToolResult:
|
||||
script = args.get("script", "").strip()
|
||||
if not script:
|
||||
return ToolResult.fail("Error: 'script' is required for evaluate action")
|
||||
result = self._get_service().evaluate(script)
|
||||
if "error" in result:
|
||||
return ToolResult.fail(result["error"])
|
||||
val = result.get("result")
|
||||
if isinstance(val, (dict, list)):
|
||||
return ToolResult.success(json.dumps(val, ensure_ascii=False, indent=2))
|
||||
return ToolResult.success(str(val) if val is not None else "(no return value)")
|
||||
|
||||
# Action dispatch table
|
||||
_ACTION_MAP = {
|
||||
"navigate": _do_navigate,
|
||||
"snapshot": _do_snapshot,
|
||||
"click": _do_click,
|
||||
"fill": _do_fill,
|
||||
"select": _do_select,
|
||||
"scroll": _do_scroll,
|
||||
"screenshot": _do_screenshot,
|
||||
"wait": _do_wait,
|
||||
"back": _do_back,
|
||||
"forward": _do_forward,
|
||||
"get_text": _do_get_text,
|
||||
"press": _do_press,
|
||||
"evaluate": _do_evaluate,
|
||||
}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Lifecycle
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def copy(self):
|
||||
"""Share browser instance across tool copies (avoids re-launching)."""
|
||||
new_tool = BrowserTool(self.config)
|
||||
new_tool.model = self.model
|
||||
new_tool.context = getattr(self, "context", None)
|
||||
new_tool.cwd = self.cwd
|
||||
new_tool._service = self._service
|
||||
return new_tool
|
||||
|
||||
def close(self):
|
||||
"""Release browser resources."""
|
||||
if self._service:
|
||||
self._service.close()
|
||||
self._service = None
|
||||
BrowserTool._shared_service = None
|
||||
logger.info("[Browser] BrowserTool closed")
|
||||
@@ -1,18 +0,0 @@
|
||||
def copy(self):
|
||||
"""
|
||||
Special copy method for browser tool to avoid recreating browser instance.
|
||||
|
||||
:return: A new instance with shared browser reference but unique model
|
||||
"""
|
||||
new_tool = self.__class__()
|
||||
|
||||
# Copy essential attributes
|
||||
new_tool.model = self.model
|
||||
new_tool.context = getattr(self, 'context', None)
|
||||
new_tool.config = getattr(self, 'config', None)
|
||||
|
||||
# Share the browser instance instead of creating a new one
|
||||
if hasattr(self, 'browser'):
|
||||
new_tool.browser = self.browser
|
||||
|
||||
return new_tool
|
||||
@@ -94,7 +94,7 @@ class Ls(BaseTool):
|
||||
results.append(entry + '/')
|
||||
else:
|
||||
results.append(entry)
|
||||
except:
|
||||
except Exception:
|
||||
# Skip entries we can't stat
|
||||
continue
|
||||
|
||||
|
||||
4
agent/tools/mcp/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from agent.tools.mcp.mcp_client import McpClient, McpClientRegistry
|
||||
from agent.tools.mcp.mcp_tool import McpTool
|
||||
|
||||
__all__ = ["McpClient", "McpClientRegistry", "McpTool"]
|
||||
528
agent/tools/mcp/mcp_client.py
Normal file
@@ -0,0 +1,528 @@
|
||||
"""
|
||||
MCP (Model Context Protocol) client module.
|
||||
|
||||
Implements JSON-RPC 2.0 over stdio, SSE and Streamable HTTP transports
|
||||
without any external MCP SDK dependency.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import select
|
||||
import subprocess
|
||||
import threading
|
||||
import urllib.request
|
||||
import urllib.error
|
||||
from typing import Optional
|
||||
|
||||
from common.log import logger
|
||||
|
||||
|
||||
# Aliases accepted for the Streamable HTTP transport type
|
||||
_STREAMABLE_HTTP_ALIASES = {"streamable-http", "streamable_http", "streamablehttp", "http"}
|
||||
|
||||
|
||||
class McpClient:
|
||||
"""Single MCP Server client supporting stdio, SSE and Streamable HTTP transports."""
|
||||
|
||||
def __init__(self, config: dict):
|
||||
"""
|
||||
config examples:
|
||||
stdio: {"name": "filesystem", "type": "stdio", "command": "npx", "args": [...]}
|
||||
SSE: {"name": "my-api", "type": "sse", "url": "http://localhost:8000/sse"}
|
||||
streamable-http: {"name": "pubmed", "type": "streamable-http", "url": "https://x/mcp"}
|
||||
"""
|
||||
self.config = config
|
||||
self.name: str = config.get("name", "unknown")
|
||||
raw_transport: str = config.get("type", "stdio")
|
||||
# Normalize streamable-http aliases to a single internal key
|
||||
self.transport: str = (
|
||||
"streamable-http"
|
||||
if raw_transport.lower() in _STREAMABLE_HTTP_ALIASES
|
||||
else raw_transport
|
||||
)
|
||||
|
||||
# stdio state
|
||||
self._proc: Optional[subprocess.Popen] = None
|
||||
|
||||
# SSE state
|
||||
self._sse_url: Optional[str] = None
|
||||
self._post_url: Optional[str] = None # endpoint for sending messages (resolved from SSE)
|
||||
|
||||
# Streamable HTTP state
|
||||
self._http_url: Optional[str] = None
|
||||
self._http_headers: dict = {} # extra headers from user config (e.g. Authorization)
|
||||
self._http_session_id: Optional[str] = None # Mcp-Session-Id assigned by the server
|
||||
|
||||
# Shared state
|
||||
self._next_id = 1
|
||||
self._id_lock = threading.Lock()
|
||||
self._call_lock = threading.Lock()
|
||||
self._initialized = False
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Public interface
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def initialize(self) -> bool:
|
||||
"""Connect and perform the MCP handshake. Returns True on success."""
|
||||
try:
|
||||
if self.transport == "stdio":
|
||||
return self._init_stdio()
|
||||
elif self.transport == "sse":
|
||||
return self._init_sse()
|
||||
elif self.transport == "streamable-http":
|
||||
return self._init_streamable_http()
|
||||
else:
|
||||
logger.warning(f"[MCP:{self.name}] Unknown transport type: {self.transport!r}")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.warning(f"[MCP:{self.name}] Initialization failed: {e}")
|
||||
return False
|
||||
|
||||
def list_tools(self) -> list:
|
||||
"""Return the tool list from this server.
|
||||
|
||||
Each item is a dict: {"name": str, "description": str, "inputSchema": dict}
|
||||
"""
|
||||
try:
|
||||
resp = self._send_request("tools/list", {})
|
||||
tools = resp.get("result", {}).get("tools", [])
|
||||
return [
|
||||
{
|
||||
"name": t.get("name", ""),
|
||||
"description": t.get("description", ""),
|
||||
"inputSchema": t.get("inputSchema", {}),
|
||||
}
|
||||
for t in tools
|
||||
]
|
||||
except Exception as e:
|
||||
logger.warning(f"[MCP:{self.name}] list_tools failed: {e}")
|
||||
return []
|
||||
|
||||
def call_tool(self, name: str, arguments: dict) -> str:
|
||||
"""Call a tool and return the result as a string."""
|
||||
try:
|
||||
resp = self._send_request("tools/call", {"name": name, "arguments": arguments})
|
||||
content = resp.get("result", {}).get("content", [])
|
||||
parts = [item.get("text", "") for item in content if item.get("type") == "text"]
|
||||
return "\n".join(parts)
|
||||
except Exception as e:
|
||||
logger.warning(f"[MCP:{self.name}] call_tool({name}) failed: {e}")
|
||||
return f"Error: {e}"
|
||||
|
||||
def shutdown(self):
|
||||
"""Close the connection / terminate the child process."""
|
||||
if self._proc is not None:
|
||||
try:
|
||||
self._proc.stdin.close()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
self._proc.terminate()
|
||||
self._proc.wait(timeout=5)
|
||||
except Exception:
|
||||
try:
|
||||
self._proc.kill()
|
||||
except Exception:
|
||||
pass
|
||||
self._proc = None
|
||||
logger.debug(f"[MCP:{self.name}] stdio process terminated")
|
||||
|
||||
# Best-effort streamable-http session termination
|
||||
if self.transport == "streamable-http" and self._http_session_id and self._http_url:
|
||||
try:
|
||||
req = urllib.request.Request(
|
||||
self._http_url,
|
||||
method="DELETE",
|
||||
headers={"Mcp-Session-Id": self._http_session_id, **self._http_headers},
|
||||
)
|
||||
with urllib.request.urlopen(req, timeout=5):
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
self._http_session_id = None
|
||||
|
||||
self._initialized = False
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# stdio transport
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _init_stdio(self) -> bool:
|
||||
command = self.config.get("command")
|
||||
if not command:
|
||||
logger.warning(f"[MCP:{self.name}] stdio config missing 'command'")
|
||||
return False
|
||||
|
||||
args = self.config.get("args", [])
|
||||
extra_env = self.config.get("env", None)
|
||||
env = {**os.environ, **extra_env} if extra_env else None
|
||||
|
||||
self._proc = subprocess.Popen(
|
||||
[command] + list(args),
|
||||
stdin=subprocess.PIPE,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
text=True,
|
||||
encoding="utf-8",
|
||||
env=env,
|
||||
)
|
||||
logger.debug(f"[MCP:{self.name}] stdio process started (pid={self._proc.pid})")
|
||||
|
||||
threading.Thread(
|
||||
target=self._drain_stderr, daemon=True, name=f"mcp-stderr-{self.name}"
|
||||
).start()
|
||||
|
||||
return self._handshake()
|
||||
|
||||
def _drain_stderr(self):
|
||||
for line in self._proc.stderr:
|
||||
line = line.strip()
|
||||
if line:
|
||||
logger.debug(f"[MCP:{self.name}] stderr: {line}")
|
||||
|
||||
def _readline_with_timeout(self, timeout: int = 30) -> str:
|
||||
"""Read one line from stdio stdout with a hard timeout."""
|
||||
ready, _, _ = select.select([self._proc.stdout], [], [], timeout)
|
||||
if not ready:
|
||||
raise TimeoutError(f"[MCP:{self.name}] stdio read timed out after {timeout}s")
|
||||
return self._proc.stdout.readline()
|
||||
|
||||
def _stdio_send(self, message: dict) -> dict:
|
||||
"""Send a JSON-RPC message over stdio and read the response."""
|
||||
raw = json.dumps(message) + "\n"
|
||||
self._proc.stdin.write(raw)
|
||||
self._proc.stdin.flush()
|
||||
|
||||
while True:
|
||||
line = self._readline_with_timeout()
|
||||
if not line:
|
||||
raise IOError(f"[MCP:{self.name}] stdio process closed unexpectedly")
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
data = json.loads(line)
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
if "id" not in data:
|
||||
logger.debug(f"[MCP:{self.name}] notification skipped: {data.get('method', '?')}")
|
||||
continue
|
||||
return data
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# SSE transport
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _init_sse(self) -> bool:
|
||||
url = self.config.get("url")
|
||||
if not url:
|
||||
logger.warning(f"[MCP:{self.name}] SSE config missing 'url'")
|
||||
return False
|
||||
|
||||
self._sse_url = url
|
||||
|
||||
# Read the first SSE event to discover the POST endpoint
|
||||
try:
|
||||
self._post_url = self._sse_discover_endpoint()
|
||||
except Exception as e:
|
||||
logger.warning(f"[MCP:{self.name}] SSE endpoint discovery failed: {e}")
|
||||
return False
|
||||
|
||||
return self._handshake()
|
||||
|
||||
def _sse_discover_endpoint(self) -> str:
|
||||
"""Open SSE stream and read the 'endpoint' event to learn the POST URL."""
|
||||
req = urllib.request.Request(
|
||||
self._sse_url,
|
||||
headers={"Accept": "text/event-stream"},
|
||||
)
|
||||
with urllib.request.urlopen(req, timeout=10) as resp:
|
||||
for raw_line in resp:
|
||||
line = raw_line.decode("utf-8").rstrip("\n\r")
|
||||
if line.startswith("data:"):
|
||||
data = line[len("data:"):].strip()
|
||||
# Some servers send JSON with a "uri" or plain path
|
||||
if data.startswith("{"):
|
||||
parsed = json.loads(data)
|
||||
return parsed.get("uri") or parsed.get("url") or parsed.get("endpoint")
|
||||
# Plain relative or absolute URL
|
||||
if data.startswith("http"):
|
||||
return data
|
||||
# Relative path: resolve against SSE base
|
||||
from urllib.parse import urljoin
|
||||
return urljoin(self._sse_url, data)
|
||||
raise ValueError(f"[MCP:{self.name}] No endpoint event received from SSE stream")
|
||||
|
||||
def _sse_send(self, message: dict) -> dict:
|
||||
"""POST a JSON-RPC message to the server and return the response."""
|
||||
body = json.dumps(message).encode("utf-8")
|
||||
req = urllib.request.Request(
|
||||
self._post_url,
|
||||
data=body,
|
||||
method="POST",
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
with urllib.request.urlopen(req, timeout=30) as resp:
|
||||
raw = resp.read().decode("utf-8")
|
||||
return json.loads(raw)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Streamable HTTP transport (MCP spec 2025-03-26)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _init_streamable_http(self) -> bool:
|
||||
url = self.config.get("url")
|
||||
if not url:
|
||||
logger.warning(f"[MCP:{self.name}] streamable-http config missing 'url'")
|
||||
return False
|
||||
|
||||
self._http_url = url
|
||||
# Allow user-provided headers (e.g. {"Authorization": "Bearer xxx"})
|
||||
extra_headers = self.config.get("headers") or {}
|
||||
if isinstance(extra_headers, dict):
|
||||
self._http_headers = {str(k): str(v) for k, v in extra_headers.items()}
|
||||
|
||||
return self._handshake()
|
||||
|
||||
def _streamable_http_send(self, message: dict) -> dict:
|
||||
"""POST a JSON-RPC request and return the response (JSON or SSE-wrapped)."""
|
||||
return self._streamable_http_post(message, expect_response=True)
|
||||
|
||||
def _streamable_http_post(self, message: dict, expect_response: bool) -> dict:
|
||||
"""
|
||||
POST a JSON-RPC message over Streamable HTTP.
|
||||
|
||||
Per the spec, the response Content-Type can be either:
|
||||
- application/json -> single JSON-RPC response in body
|
||||
- text/event-stream -> SSE stream; we read until we get a matching response
|
||||
"""
|
||||
body = json.dumps(message).encode("utf-8")
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json, text/event-stream",
|
||||
}
|
||||
if self._http_session_id:
|
||||
headers["Mcp-Session-Id"] = self._http_session_id
|
||||
headers.update(self._http_headers)
|
||||
|
||||
req = urllib.request.Request(
|
||||
self._http_url,
|
||||
data=body,
|
||||
method="POST",
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
try:
|
||||
resp = urllib.request.urlopen(req, timeout=30)
|
||||
except urllib.error.HTTPError as e:
|
||||
# Surface the server-provided error body for easier debugging
|
||||
detail = ""
|
||||
try:
|
||||
detail = e.read().decode("utf-8", errors="ignore")
|
||||
except Exception:
|
||||
pass
|
||||
raise IOError(
|
||||
f"[MCP:{self.name}] streamable-http HTTP {e.code}: {detail[:200]}"
|
||||
)
|
||||
|
||||
with resp:
|
||||
# Capture session id assigned by the server (if any)
|
||||
session_id = resp.headers.get("Mcp-Session-Id")
|
||||
if session_id and not self._http_session_id:
|
||||
self._http_session_id = session_id
|
||||
|
||||
status = resp.status if hasattr(resp, "status") else resp.getcode()
|
||||
|
||||
# Notifications: server may reply with 202 Accepted and no body
|
||||
if not expect_response or status == 202:
|
||||
try:
|
||||
resp.read()
|
||||
except Exception:
|
||||
pass
|
||||
return {}
|
||||
|
||||
content_type = (resp.headers.get("Content-Type") or "").lower()
|
||||
expected_id = message.get("id")
|
||||
|
||||
if "text/event-stream" in content_type:
|
||||
return self._read_sse_response(resp, expected_id)
|
||||
|
||||
raw = resp.read().decode("utf-8")
|
||||
if not raw:
|
||||
return {}
|
||||
return json.loads(raw)
|
||||
|
||||
def _read_sse_response(self, resp, expected_id) -> dict:
|
||||
"""Read an SSE stream and return the first JSON-RPC response with matching id."""
|
||||
data_buf: list = []
|
||||
for raw_line in resp:
|
||||
line = raw_line.decode("utf-8").rstrip("\n\r")
|
||||
if line == "":
|
||||
# End of an SSE event, attempt to parse accumulated data
|
||||
if data_buf:
|
||||
payload = "\n".join(data_buf)
|
||||
data_buf = []
|
||||
try:
|
||||
msg = json.loads(payload)
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
# Skip notifications / mismatched ids
|
||||
if "id" not in msg:
|
||||
continue
|
||||
if expected_id is None or msg.get("id") == expected_id:
|
||||
return msg
|
||||
continue
|
||||
if line.startswith(":"):
|
||||
continue # SSE comment / keepalive
|
||||
if line.startswith("data:"):
|
||||
data_buf.append(line[len("data:"):].lstrip())
|
||||
# Ignore 'event:' / 'id:' lines; we only care about JSON-RPC payloads
|
||||
|
||||
raise IOError(f"[MCP:{self.name}] streamable-http SSE stream closed before response")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Common JSON-RPC helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _next_request_id(self) -> int:
|
||||
with self._id_lock:
|
||||
rid = self._next_id
|
||||
self._next_id += 1
|
||||
return rid
|
||||
|
||||
def _build_request(self, method: str, params: dict) -> dict:
|
||||
return {
|
||||
"jsonrpc": "2.0",
|
||||
"id": self._next_request_id(),
|
||||
"method": method,
|
||||
"params": params,
|
||||
}
|
||||
|
||||
def _build_notification(self, method: str, params: dict) -> dict:
|
||||
return {"jsonrpc": "2.0", "method": method, "params": params}
|
||||
|
||||
def _send_request(self, method: str, params: dict) -> dict:
|
||||
"""Send a request and return the full response dict."""
|
||||
if not self._initialized and method != "initialize":
|
||||
raise RuntimeError(f"[MCP:{self.name}] Client not initialized")
|
||||
|
||||
message = self._build_request(method, params)
|
||||
|
||||
with self._call_lock:
|
||||
if self.transport == "stdio":
|
||||
return self._stdio_send(message)
|
||||
elif self.transport == "sse":
|
||||
return self._sse_send(message)
|
||||
elif self.transport == "streamable-http":
|
||||
return self._streamable_http_send(message)
|
||||
else:
|
||||
raise ValueError(f"[MCP:{self.name}] Unsupported transport: {self.transport}")
|
||||
|
||||
def _send_notification(self, method: str, params: dict):
|
||||
"""Fire-and-forget notification (no response expected)."""
|
||||
notification = self._build_notification(method, params)
|
||||
raw = json.dumps(notification) + "\n"
|
||||
|
||||
if self.transport == "stdio":
|
||||
self._proc.stdin.write(raw)
|
||||
self._proc.stdin.flush()
|
||||
elif self.transport == "sse":
|
||||
body = raw.encode("utf-8")
|
||||
req = urllib.request.Request(
|
||||
self._post_url,
|
||||
data=body,
|
||||
method="POST",
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
try:
|
||||
with urllib.request.urlopen(req, timeout=10):
|
||||
pass
|
||||
except Exception:
|
||||
pass # notifications are fire-and-forget
|
||||
elif self.transport == "streamable-http":
|
||||
try:
|
||||
self._streamable_http_post(notification, expect_response=False)
|
||||
except Exception:
|
||||
pass # notifications are fire-and-forget
|
||||
|
||||
def _handshake(self) -> bool:
|
||||
"""Perform the MCP initialize / notifications/initialized handshake."""
|
||||
init_params = {
|
||||
"protocolVersion": "2024-11-05",
|
||||
"capabilities": {},
|
||||
"clientInfo": {"name": "CowAgent", "version": "1.0"},
|
||||
}
|
||||
# Temporarily mark as initialized so _send_request doesn't block
|
||||
self._initialized = True
|
||||
try:
|
||||
resp = self._send_request("initialize", init_params)
|
||||
except Exception as e:
|
||||
self._initialized = False
|
||||
logger.warning(f"[MCP:{self.name}] Handshake initialize failed: {e}")
|
||||
return False
|
||||
|
||||
if "error" in resp:
|
||||
self._initialized = False
|
||||
logger.warning(f"[MCP:{self.name}] Handshake error: {resp['error']}")
|
||||
return False
|
||||
|
||||
self._send_notification("notifications/initialized", {})
|
||||
logger.debug(f"[MCP:{self.name}] Handshake complete")
|
||||
return True
|
||||
|
||||
|
||||
class McpClientRegistry:
|
||||
"""Global singleton managing the lifecycle of all MCP Server clients."""
|
||||
|
||||
_instance = None
|
||||
_instance_lock = threading.Lock()
|
||||
|
||||
def __new__(cls):
|
||||
with cls._instance_lock:
|
||||
if cls._instance is None:
|
||||
obj = super().__new__(cls)
|
||||
obj._clients: dict[str, McpClient] = {}
|
||||
obj._registry_lock = threading.Lock()
|
||||
cls._instance = obj
|
||||
return cls._instance
|
||||
|
||||
def start_all(self, configs: list) -> None:
|
||||
"""Initialize McpClient for each config entry; skip failures with a warning."""
|
||||
if not configs:
|
||||
return
|
||||
|
||||
for cfg in configs:
|
||||
name = cfg.get("name", "<unnamed>")
|
||||
client = McpClient(cfg)
|
||||
ok = client.initialize()
|
||||
if ok:
|
||||
with self._registry_lock:
|
||||
self._clients[name] = client
|
||||
logger.info(f"[MCP] Server '{name}' initialized successfully")
|
||||
else:
|
||||
logger.warning(f"[MCP] Server '{name}' failed to initialize — skipping")
|
||||
|
||||
def get(self, server_name: str) -> Optional[McpClient]:
|
||||
"""Return the initialized client for server_name, or None."""
|
||||
with self._registry_lock:
|
||||
return self._clients.get(server_name)
|
||||
|
||||
def all_clients(self) -> dict:
|
||||
"""Return a copy of the {name: McpClient} mapping."""
|
||||
with self._registry_lock:
|
||||
return dict(self._clients)
|
||||
|
||||
def shutdown_all(self) -> None:
|
||||
"""Shut down all managed clients."""
|
||||
with self._registry_lock:
|
||||
clients = list(self._clients.values())
|
||||
self._clients.clear()
|
||||
|
||||
for client in clients:
|
||||
try:
|
||||
client.shutdown()
|
||||
except Exception as e:
|
||||
logger.warning(f"[MCP] Error shutting down '{client.name}': {e}")
|
||||
|
||||
logger.info("[MCP] All servers shut down")
|
||||
31
agent/tools/mcp/mcp_tool.py
Normal file
@@ -0,0 +1,31 @@
|
||||
from agent.tools.base_tool import BaseTool, ToolResult
|
||||
from common.log import logger
|
||||
|
||||
|
||||
class McpTool(BaseTool):
|
||||
"""
|
||||
将单个 MCP 工具包装为 BaseTool。
|
||||
一个 MCP Server 可以提供多个工具,每个工具对应一个 McpTool 实例。
|
||||
"""
|
||||
|
||||
def __init__(self, client, tool_schema: dict, server_name: str):
|
||||
"""
|
||||
:param client: 该工具所属的 McpClient 实例
|
||||
:param tool_schema: MCP 返回的工具描述,格式:
|
||||
{"name": str, "description": str, "inputSchema": dict}
|
||||
:param server_name: Server 名称,用于日志
|
||||
"""
|
||||
self.client = client
|
||||
self.server_name = server_name
|
||||
self.name = tool_schema["name"]
|
||||
self.description = tool_schema.get("description", "")
|
||||
self.params = tool_schema.get("inputSchema", {})
|
||||
|
||||
def execute(self, params: dict) -> ToolResult:
|
||||
logger.info(f"[McpTool] server={self.server_name} tool={self.name} params={params}")
|
||||
try:
|
||||
result = self.client.call_tool(self.name, params)
|
||||
return ToolResult.success(result)
|
||||
except Exception as e:
|
||||
logger.error(f"[McpTool] server={self.server_name} tool={self.name} error: {e}")
|
||||
return ToolResult.fail(str(e))
|
||||
@@ -44,6 +44,19 @@ class MemoryGetTool(BaseTool):
|
||||
"""
|
||||
super().__init__()
|
||||
self.memory_manager = memory_manager
|
||||
|
||||
from config import conf
|
||||
if conf().get("knowledge", True):
|
||||
self.description = (
|
||||
"Read specific content from memory or knowledge files. "
|
||||
"Use this to get full context from a memory file, knowledge page, or specific line range."
|
||||
)
|
||||
self.params = {**self.params}
|
||||
self.params["properties"] = {**self.params["properties"]}
|
||||
self.params["properties"]["path"] = {
|
||||
"type": "string",
|
||||
"description": "Relative path to the memory or knowledge file (e.g. 'MEMORY.md', 'memory/2026-01-01.md', 'knowledge/concepts/moe.md')"
|
||||
}
|
||||
|
||||
def execute(self, args: dict):
|
||||
"""
|
||||
@@ -68,11 +81,15 @@ class MemoryGetTool(BaseTool):
|
||||
workspace_dir = self.memory_manager.config.get_workspace()
|
||||
|
||||
# Auto-prepend memory/ if not present and not absolute path
|
||||
# Exception: MEMORY.md is in the root directory
|
||||
if not path.startswith('memory/') and not path.startswith('/') and path != 'MEMORY.md':
|
||||
# Exceptions: MEMORY.md in root, knowledge/ files at workspace root
|
||||
if not path.startswith('memory/') and not path.startswith('knowledge/') and not path.startswith('/') and path != 'MEMORY.md':
|
||||
path = f'memory/{path}'
|
||||
|
||||
file_path = workspace_dir / path
|
||||
file_path = (workspace_dir / path).resolve()
|
||||
workspace_resolved = workspace_dir.resolve()
|
||||
|
||||
if not str(file_path).startswith(str(workspace_resolved) + '/') and file_path != workspace_resolved:
|
||||
return ToolResult.fail(f"Error: Access denied: path outside workspace")
|
||||
|
||||
if not file_path.exists():
|
||||
return ToolResult.fail(f"Error: File not found: {path}")
|
||||
|
||||
@@ -48,6 +48,13 @@ class MemorySearchTool(BaseTool):
|
||||
super().__init__()
|
||||
self.memory_manager = memory_manager
|
||||
self.user_id = user_id
|
||||
|
||||
from config import conf
|
||||
if conf().get("knowledge", True):
|
||||
self.description = (
|
||||
"Search agent's long-term memory and knowledge base using semantic and keyword search. "
|
||||
"Use this to recall past conversations, preferences, and knowledge pages."
|
||||
)
|
||||
|
||||
def execute(self, args: dict):
|
||||
"""
|
||||
|
||||
@@ -48,7 +48,8 @@ class Read(BaseTool):
|
||||
self.binary_extensions = {'.exe', '.dll', '.so', '.dylib', '.bin', '.dat', '.db', '.sqlite'}
|
||||
self.archive_extensions = {'.zip', '.tar', '.gz', '.rar', '.7z', '.bz2', '.xz'}
|
||||
self.pdf_extensions = {'.pdf'}
|
||||
|
||||
self.office_extensions = {'.doc', '.docx', '.xls', '.xlsx', '.ppt', '.pptx'}
|
||||
|
||||
# Readable text formats (will be read with truncation)
|
||||
self.text_extensions = {
|
||||
'.txt', '.md', '.markdown', '.rst', '.log', '.csv', '.tsv', '.json', '.xml', '.yaml', '.yml',
|
||||
@@ -57,7 +58,6 @@ class Read(BaseTool):
|
||||
'.sh', '.bash', '.zsh', '.fish', '.ps1', '.bat', '.cmd',
|
||||
'.sql', '.r', '.m', '.swift', '.kt', '.scala', '.clj', '.erl', '.ex',
|
||||
'.dockerfile', '.makefile', '.cmake', '.gradle', '.properties', '.ini', '.conf', '.cfg',
|
||||
'.doc', '.docx', '.xls', '.xlsx', '.ppt', '.pptx' # Office documents
|
||||
}
|
||||
|
||||
def execute(self, args: Dict[str, Any]) -> ToolResult:
|
||||
@@ -120,7 +120,11 @@ class Read(BaseTool):
|
||||
# Check if PDF
|
||||
if file_ext in self.pdf_extensions:
|
||||
return self._read_pdf(absolute_path, path, offset, limit)
|
||||
|
||||
|
||||
# Check if Office document (.docx, .xlsx, .pptx, etc.)
|
||||
if file_ext in self.office_extensions:
|
||||
return self._read_office(absolute_path, path, file_ext, offset, limit)
|
||||
|
||||
# Read text file (with truncation for large files)
|
||||
return self._read_text(absolute_path, path, offset, limit)
|
||||
|
||||
@@ -240,17 +244,12 @@ class Read(BaseTool):
|
||||
"message": f"文件过大 ({format_size(file_size)} > 50MB),无法读取内容。文件路径: {absolute_path}"
|
||||
})
|
||||
|
||||
# Read file
|
||||
with open(absolute_path, 'r', encoding='utf-8') as f:
|
||||
# Read file (utf-8-sig strips BOM automatically on Windows)
|
||||
# Note: Truncation is unified via truncate_head (DEFAULT_MAX_LINES / DEFAULT_MAX_BYTES)
|
||||
# so that offset/limit can paginate the entire file correctly.
|
||||
with open(absolute_path, 'r', encoding='utf-8-sig') as f:
|
||||
content = f.read()
|
||||
|
||||
# Truncate content if too long (20K characters max for model context)
|
||||
MAX_CONTENT_CHARS = 20 * 1024 # 20K characters
|
||||
content_truncated = False
|
||||
if len(content) > MAX_CONTENT_CHARS:
|
||||
content = content[:MAX_CONTENT_CHARS]
|
||||
content_truncated = True
|
||||
|
||||
|
||||
all_lines = content.split('\n')
|
||||
total_file_lines = len(all_lines)
|
||||
|
||||
@@ -286,11 +285,7 @@ class Read(BaseTool):
|
||||
|
||||
output_text = ""
|
||||
details = {}
|
||||
|
||||
# Add truncation warning if content was truncated
|
||||
if content_truncated:
|
||||
output_text = f"[文件内容已截断到前 {format_size(MAX_CONTENT_CHARS)},完整文件大小: {format_size(file_size)}]\n\n"
|
||||
|
||||
|
||||
if truncation.first_line_exceeds_limit:
|
||||
# First line exceeds 30KB limit
|
||||
first_line_size = format_size(len(all_lines[start_line].encode('utf-8')))
|
||||
@@ -337,6 +332,116 @@ class Read(BaseTool):
|
||||
except Exception as e:
|
||||
return ToolResult.fail(f"Error reading file: {str(e)}")
|
||||
|
||||
def _read_office(self, absolute_path: str, display_path: str, file_ext: str,
|
||||
offset: int = None, limit: int = None) -> ToolResult:
|
||||
"""Read Office documents (.docx, .xlsx, .pptx) using python-docx / openpyxl / python-pptx."""
|
||||
try:
|
||||
text = self._extract_office_text(absolute_path, file_ext)
|
||||
except ImportError as e:
|
||||
return ToolResult.fail(str(e))
|
||||
except Exception as e:
|
||||
return ToolResult.fail(f"Error reading Office document: {e}")
|
||||
|
||||
if not text or not text.strip():
|
||||
return ToolResult.success({
|
||||
"content": f"[Office file {Path(absolute_path).name}: no text content could be extracted]",
|
||||
})
|
||||
|
||||
all_lines = text.split('\n')
|
||||
total_lines = len(all_lines)
|
||||
|
||||
start_line = 0
|
||||
if offset is not None:
|
||||
if offset < 0:
|
||||
start_line = max(0, total_lines + offset)
|
||||
else:
|
||||
start_line = max(0, offset - 1)
|
||||
if start_line >= total_lines:
|
||||
return ToolResult.fail(
|
||||
f"Error: Offset {offset} is beyond end of content ({total_lines} lines total)"
|
||||
)
|
||||
|
||||
selected_content = text
|
||||
user_limited_lines = None
|
||||
if limit is not None:
|
||||
end_line = min(start_line + limit, total_lines)
|
||||
selected_content = '\n'.join(all_lines[start_line:end_line])
|
||||
user_limited_lines = end_line - start_line
|
||||
elif offset is not None:
|
||||
selected_content = '\n'.join(all_lines[start_line:])
|
||||
|
||||
truncation = truncate_head(selected_content)
|
||||
start_line_display = start_line + 1
|
||||
output_text = ""
|
||||
|
||||
if truncation.truncated:
|
||||
end_line_display = start_line_display + truncation.output_lines - 1
|
||||
next_offset = end_line_display + 1
|
||||
output_text = truncation.content
|
||||
output_text += f"\n\n[Showing lines {start_line_display}-{end_line_display} of {total_lines}. Use offset={next_offset} to continue.]"
|
||||
elif user_limited_lines is not None and start_line + user_limited_lines < total_lines:
|
||||
remaining = total_lines - (start_line + user_limited_lines)
|
||||
next_offset = start_line + user_limited_lines + 1
|
||||
output_text = truncation.content
|
||||
output_text += f"\n\n[{remaining} more lines in file. Use offset={next_offset} to continue.]"
|
||||
else:
|
||||
output_text = truncation.content
|
||||
|
||||
return ToolResult.success({
|
||||
"content": output_text,
|
||||
"total_lines": total_lines,
|
||||
"start_line": start_line_display,
|
||||
"output_lines": truncation.output_lines,
|
||||
})
|
||||
|
||||
@staticmethod
|
||||
def _extract_office_text(absolute_path: str, file_ext: str) -> str:
|
||||
"""Extract plain text from an Office document."""
|
||||
if file_ext in ('.docx', '.doc'):
|
||||
try:
|
||||
from docx import Document
|
||||
except ImportError:
|
||||
raise ImportError("Error: python-docx library not installed. Install with: pip install python-docx")
|
||||
doc = Document(absolute_path)
|
||||
paragraphs = [p.text for p in doc.paragraphs]
|
||||
for table in doc.tables:
|
||||
for row in table.rows:
|
||||
paragraphs.append('\t'.join(cell.text for cell in row.cells))
|
||||
return '\n'.join(paragraphs)
|
||||
|
||||
if file_ext in ('.xlsx', '.xls'):
|
||||
try:
|
||||
from openpyxl import load_workbook
|
||||
except ImportError:
|
||||
raise ImportError("Error: openpyxl library not installed. Install with: pip install openpyxl")
|
||||
wb = load_workbook(absolute_path, read_only=True, data_only=True)
|
||||
parts = []
|
||||
for ws in wb.worksheets:
|
||||
parts.append(f"--- Sheet: {ws.title} ---")
|
||||
for row in ws.iter_rows(values_only=True):
|
||||
parts.append('\t'.join(str(c) if c is not None else '' for c in row))
|
||||
wb.close()
|
||||
return '\n'.join(parts)
|
||||
|
||||
if file_ext in ('.pptx', '.ppt'):
|
||||
try:
|
||||
from pptx import Presentation
|
||||
except ImportError:
|
||||
raise ImportError("Error: python-pptx library not installed. Install with: pip install python-pptx")
|
||||
prs = Presentation(absolute_path)
|
||||
parts = []
|
||||
for i, slide in enumerate(prs.slides, 1):
|
||||
parts.append(f"--- Slide {i} ---")
|
||||
for shape in slide.shapes:
|
||||
if shape.has_text_frame:
|
||||
for para in shape.text_frame.paragraphs:
|
||||
text = para.text.strip()
|
||||
if text:
|
||||
parts.append(text)
|
||||
return '\n'.join(parts)
|
||||
|
||||
return ""
|
||||
|
||||
def _read_pdf(self, absolute_path: str, display_path: str, offset: int = None, limit: int = None) -> ToolResult:
|
||||
"""
|
||||
Read PDF file content
|
||||
|
||||
@@ -3,6 +3,7 @@ Integration module for scheduler with AgentBridge
|
||||
"""
|
||||
|
||||
import os
|
||||
import threading
|
||||
from typing import Optional
|
||||
from config import conf
|
||||
from common.log import logger
|
||||
@@ -13,65 +14,126 @@ from bridge.reply import Reply, ReplyType
|
||||
# Global scheduler service instance
|
||||
_scheduler_service = None
|
||||
_task_store = None
|
||||
# Module-level lock to guard idempotent initialization across threads
|
||||
_init_lock = threading.Lock()
|
||||
|
||||
|
||||
def init_scheduler(agent_bridge) -> bool:
|
||||
"""
|
||||
Initialize scheduler service
|
||||
|
||||
Initialize scheduler service (idempotent).
|
||||
|
||||
Safe to call multiple times and from multiple threads: only the first
|
||||
successful call creates the singleton ``SchedulerService`` + background
|
||||
scanning thread. Subsequent calls return immediately.
|
||||
|
||||
Args:
|
||||
agent_bridge: AgentBridge instance
|
||||
|
||||
|
||||
Returns:
|
||||
True if initialized successfully
|
||||
True if scheduler is initialized (newly created or already running)
|
||||
"""
|
||||
global _scheduler_service, _task_store
|
||||
|
||||
try:
|
||||
from agent.tools.scheduler.task_store import TaskStore
|
||||
from agent.tools.scheduler.scheduler_service import SchedulerService
|
||||
|
||||
# Get workspace from config
|
||||
workspace_root = expand_path(conf().get("agent_workspace", "~/cow"))
|
||||
store_path = os.path.join(workspace_root, "scheduler", "tasks.json")
|
||||
|
||||
# Create task store
|
||||
_task_store = TaskStore(store_path)
|
||||
logger.debug(f"[Scheduler] Task store initialized: {store_path}")
|
||||
|
||||
# Create execute callback
|
||||
def execute_task_callback(task: dict):
|
||||
"""Callback to execute a scheduled task"""
|
||||
try:
|
||||
action = task.get("action", {})
|
||||
action_type = action.get("type")
|
||||
|
||||
if action_type == "agent_task":
|
||||
_execute_agent_task(task, agent_bridge)
|
||||
elif action_type == "send_message":
|
||||
# Legacy support for old tasks
|
||||
_execute_send_message(task, agent_bridge)
|
||||
elif action_type == "tool_call":
|
||||
# Legacy support for old tasks
|
||||
_execute_tool_call(task, agent_bridge)
|
||||
elif action_type == "skill_call":
|
||||
# Legacy support for old tasks
|
||||
_execute_skill_call(task, agent_bridge)
|
||||
else:
|
||||
logger.warning(f"[Scheduler] Unknown action type: {action_type}")
|
||||
except Exception as e:
|
||||
logger.error(f"[Scheduler] Error executing task {task.get('id')}: {e}")
|
||||
|
||||
# Create scheduler service
|
||||
_scheduler_service = SchedulerService(_task_store, execute_task_callback)
|
||||
_scheduler_service.start()
|
||||
|
||||
logger.debug("[Scheduler] Scheduler service initialized and started")
|
||||
|
||||
# Fast path: already initialized and running
|
||||
if _scheduler_service is not None and getattr(_scheduler_service, "running", False):
|
||||
return True
|
||||
|
||||
with _init_lock:
|
||||
# Re-check under the lock to avoid races where multiple threads
|
||||
# passed the fast-path check before any of them acquired the lock.
|
||||
if _scheduler_service is not None and getattr(_scheduler_service, "running", False):
|
||||
return True
|
||||
|
||||
try:
|
||||
from agent.tools.scheduler.task_store import TaskStore
|
||||
from agent.tools.scheduler.scheduler_service import SchedulerService
|
||||
|
||||
# Get workspace from config
|
||||
workspace_root = expand_path(conf().get("agent_workspace", "~/cow"))
|
||||
store_path = os.path.join(workspace_root, "scheduler", "tasks.json")
|
||||
|
||||
# Create task store (reuse if already created)
|
||||
if _task_store is None:
|
||||
_task_store = TaskStore(store_path)
|
||||
logger.debug(f"[Scheduler] Task store initialized: {store_path}")
|
||||
|
||||
# Create execute callback. Returns True on success, False to ask
|
||||
# the scheduler to retry on the next tick (e.g. channel not yet
|
||||
# ready right after process start).
|
||||
def execute_task_callback(task: dict):
|
||||
try:
|
||||
action = task.get("action", {})
|
||||
action_type = action.get("type")
|
||||
channel_type = action.get("channel_type", "unknown")
|
||||
receiver = action.get("receiver", "")
|
||||
|
||||
if not _is_channel_ready(channel_type, receiver):
|
||||
logger.warning(
|
||||
f"[Scheduler] Task {task.get('id')}: channel "
|
||||
f"'{channel_type}' not ready for receiver={receiver} "
|
||||
f"(no inbound msg cached since restart?); deferring"
|
||||
)
|
||||
return False
|
||||
|
||||
if action_type == "agent_task":
|
||||
return _execute_agent_task(task, agent_bridge)
|
||||
elif action_type == "send_message":
|
||||
return _execute_send_message(task, agent_bridge)
|
||||
elif action_type == "tool_call":
|
||||
return _execute_tool_call(task, agent_bridge)
|
||||
elif action_type == "skill_call":
|
||||
return _execute_skill_call(task, agent_bridge)
|
||||
else:
|
||||
logger.warning(f"[Scheduler] Unknown action type: {action_type}")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"[Scheduler] Error executing task {task.get('id')}: {e}")
|
||||
return False
|
||||
|
||||
# Create scheduler service
|
||||
_scheduler_service = SchedulerService(_task_store, execute_task_callback)
|
||||
_scheduler_service.start()
|
||||
|
||||
logger.info("[Scheduler] Service initialized and started")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Scheduler] Failed to initialize scheduler: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def _is_channel_ready(channel_type: str, receiver: str) -> bool:
|
||||
"""Best-effort readiness probe for outbound channels.
|
||||
|
||||
Returns False when we know the send will drop (e.g. weixin not yet
|
||||
logged in, web session has no polling queue), so the scheduler can
|
||||
defer instead of consuming the task. Unknown channels return True
|
||||
to preserve previous behaviour.
|
||||
"""
|
||||
if not channel_type or channel_type == "unknown":
|
||||
return True
|
||||
try:
|
||||
from channel.channel_factory import create_channel
|
||||
channel = create_channel(channel_type)
|
||||
if channel is None:
|
||||
return False
|
||||
|
||||
if channel_type == "weixin":
|
||||
tokens = getattr(channel, "_context_tokens", None)
|
||||
if not tokens or receiver not in tokens:
|
||||
return False
|
||||
return True
|
||||
|
||||
if channel_type == "web":
|
||||
queues = getattr(channel, "session_queues", None)
|
||||
if not queues or receiver not in queues:
|
||||
return False
|
||||
return True
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Scheduler] Failed to initialize scheduler: {e}")
|
||||
return False
|
||||
logger.warning(f"[Scheduler] Channel readiness check failed for {channel_type}: {e}")
|
||||
return True
|
||||
|
||||
|
||||
def get_task_store():
|
||||
@@ -84,13 +146,53 @@ def get_scheduler_service():
|
||||
return _scheduler_service
|
||||
|
||||
|
||||
def _execute_agent_task(task: dict, agent_bridge):
|
||||
def _remember_delivered_output(
|
||||
agent_bridge,
|
||||
task: dict,
|
||||
channel_type: str,
|
||||
content: str,
|
||||
) -> None:
|
||||
"""Best-effort persistence of the message the scheduler sent to a user.
|
||||
|
||||
Uses notify_session_id (the real chat session_id stored at task creation time)
|
||||
so that group chats correctly associate the output with the user's conversation.
|
||||
Falls back to receiver for backward compatibility with old tasks.
|
||||
|
||||
Per-action-type behaviour:
|
||||
- agent_task / tool_call / skill_call: gated by ``scheduler_inject_to_session``
|
||||
(default True). These produce AI-generated content worth remembering.
|
||||
- send_message: additionally gated by ``scheduler_inject_send_message``
|
||||
(default False). Fixed reminder text rarely benefits follow-up Q&A and
|
||||
would just consume context tokens.
|
||||
"""
|
||||
Execute an agent_task action - let Agent handle the task
|
||||
|
||||
Args:
|
||||
task: Task dictionary
|
||||
agent_bridge: AgentBridge instance
|
||||
if not content:
|
||||
return
|
||||
action = task.get("action", {})
|
||||
action_type = action.get("type", "")
|
||||
|
||||
# send_message defaults to NOT being injected; explicit opt-in via config.
|
||||
if action_type == "send_message":
|
||||
if not conf().get("scheduler_inject_send_message", False):
|
||||
return
|
||||
|
||||
session_id = action.get("notify_session_id") or action.get("receiver")
|
||||
if not session_id:
|
||||
return
|
||||
try:
|
||||
remember = getattr(agent_bridge, "remember_scheduled_output", None)
|
||||
if remember:
|
||||
task_desc = action.get("task_description") or action.get("content", "")
|
||||
remember(session_id, str(content), channel_type=channel_type, task_description=task_desc)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"[Scheduler] Failed to remember delivered output for {session_id}: {e}"
|
||||
)
|
||||
|
||||
|
||||
def _execute_agent_task(task: dict, agent_bridge) -> bool:
|
||||
"""
|
||||
Execute an agent_task action - let Agent handle the task.
|
||||
Returns True on successful delivery, False to retry next tick.
|
||||
"""
|
||||
try:
|
||||
action = task.get("action", {})
|
||||
@@ -101,11 +203,11 @@ def _execute_agent_task(task: dict, agent_bridge):
|
||||
|
||||
if not task_description:
|
||||
logger.error(f"[Scheduler] Task {task['id']}: No task_description specified")
|
||||
return
|
||||
return True # malformed task, don't loop forever
|
||||
|
||||
if not receiver:
|
||||
logger.error(f"[Scheduler] Task {task['id']}: No receiver specified")
|
||||
return
|
||||
return True
|
||||
|
||||
# Check for unsupported channels
|
||||
if channel_type == "dingtalk":
|
||||
@@ -134,12 +236,13 @@ def _execute_agent_task(task: dict, agent_bridge):
|
||||
elif channel_type == "dingtalk":
|
||||
# DingTalk requires msg object, set to None for scheduled tasks
|
||||
context["msg"] = None
|
||||
# 如果是单聊,需要传递 sender_staff_id
|
||||
if not is_group:
|
||||
sender_staff_id = action.get("dingtalk_sender_staff_id")
|
||||
if sender_staff_id:
|
||||
context["dingtalk_sender_staff_id"] = sender_staff_id
|
||||
|
||||
elif channel_type == "wecom_bot":
|
||||
context["msg"] = None
|
||||
|
||||
# Use Agent to execute the task
|
||||
# Mark this as a scheduled task execution to prevent recursive task creation
|
||||
context["is_scheduled_task"] = True
|
||||
@@ -147,50 +250,47 @@ def _execute_agent_task(task: dict, agent_bridge):
|
||||
try:
|
||||
# Don't clear history - scheduler tasks use isolated session_id so they won't pollute user conversations
|
||||
reply = agent_bridge.agent_reply(task_description, context=context, on_event=None, clear_history=False)
|
||||
|
||||
if reply and reply.content:
|
||||
# Send the reply via channel
|
||||
from channel.channel_factory import create_channel
|
||||
|
||||
try:
|
||||
channel = create_channel(channel_type)
|
||||
if channel:
|
||||
# For web channel, register request_id
|
||||
if channel_type == "web" and hasattr(channel, 'request_to_session'):
|
||||
request_id = context.get("request_id")
|
||||
if request_id:
|
||||
channel.request_to_session[request_id] = receiver
|
||||
logger.debug(f"[Scheduler] Registered request_id {request_id} -> session {receiver}")
|
||||
|
||||
# Send the reply
|
||||
channel.send(reply, context)
|
||||
logger.info(f"[Scheduler] Task {task['id']} executed successfully, result sent to {receiver}")
|
||||
else:
|
||||
logger.error(f"[Scheduler] Failed to create channel: {channel_type}")
|
||||
except Exception as e:
|
||||
logger.error(f"[Scheduler] Failed to send result: {e}")
|
||||
else:
|
||||
|
||||
if not (reply and reply.content):
|
||||
logger.error(f"[Scheduler] Task {task['id']}: No result from agent execution")
|
||||
|
||||
return True # agent ran but produced nothing; don't loop
|
||||
|
||||
from channel.channel_factory import create_channel
|
||||
channel = create_channel(channel_type)
|
||||
if not channel:
|
||||
logger.error(f"[Scheduler] Failed to create channel: {channel_type}")
|
||||
return False
|
||||
|
||||
if channel_type == "web" and hasattr(channel, 'request_to_session'):
|
||||
request_id = context.get("request_id")
|
||||
if request_id:
|
||||
channel.request_to_session[request_id] = receiver
|
||||
|
||||
try:
|
||||
channel.send(reply, context)
|
||||
except Exception as e:
|
||||
logger.error(f"[Scheduler] Failed to send result: {e}")
|
||||
return False
|
||||
|
||||
_remember_delivered_output(agent_bridge, task, channel_type, reply.content)
|
||||
logger.info(f"[Scheduler] Task {task['id']} executed successfully, result sent to {receiver}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Scheduler] Failed to execute task via Agent: {e}")
|
||||
import traceback
|
||||
logger.error(f"[Scheduler] Traceback: {traceback.format_exc()}")
|
||||
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Scheduler] Error in _execute_agent_task: {e}")
|
||||
import traceback
|
||||
logger.error(f"[Scheduler] Traceback: {traceback.format_exc()}")
|
||||
return False
|
||||
|
||||
|
||||
def _execute_send_message(task: dict, agent_bridge):
|
||||
"""
|
||||
Execute a send_message action
|
||||
|
||||
Args:
|
||||
task: Task dictionary
|
||||
agent_bridge: AgentBridge instance
|
||||
"""
|
||||
def _execute_send_message(task: dict, agent_bridge) -> bool:
|
||||
"""Execute a send_message action. Returns True/False for delivery."""
|
||||
try:
|
||||
action = task.get("action", {})
|
||||
content = action.get("content", "")
|
||||
@@ -200,7 +300,7 @@ def _execute_send_message(task: dict, agent_bridge):
|
||||
|
||||
if not receiver:
|
||||
logger.error(f"[Scheduler] Task {task['id']}: No receiver specified")
|
||||
return
|
||||
return True
|
||||
|
||||
# Create context for sending message
|
||||
context = Context(ContextType.TEXT, content)
|
||||
@@ -234,174 +334,146 @@ def _execute_send_message(task: dict, agent_bridge):
|
||||
logger.debug(f"[Scheduler] DingTalk single chat: sender_staff_id={sender_staff_id}")
|
||||
else:
|
||||
logger.warning(f"[Scheduler] Task {task['id']}: DingTalk single chat message missing sender_staff_id")
|
||||
|
||||
elif channel_type == "wecom_bot":
|
||||
context["msg"] = None
|
||||
elif channel_type == "qq":
|
||||
context["msg"] = None
|
||||
|
||||
# Create reply
|
||||
reply = Reply(ReplyType.TEXT, content)
|
||||
|
||||
# Get channel and send
|
||||
from channel.channel_factory import create_channel
|
||||
|
||||
channel = create_channel(channel_type)
|
||||
if not channel:
|
||||
logger.error(f"[Scheduler] Failed to create channel: {channel_type}")
|
||||
return False
|
||||
|
||||
if channel_type == "web" and hasattr(channel, 'request_to_session'):
|
||||
channel.request_to_session[request_id] = receiver
|
||||
|
||||
try:
|
||||
channel = create_channel(channel_type)
|
||||
if channel:
|
||||
# For web channel, register the request_id to session mapping
|
||||
if channel_type == "web" and hasattr(channel, 'request_to_session'):
|
||||
channel.request_to_session[request_id] = receiver
|
||||
logger.debug(f"[Scheduler] Registered request_id {request_id} -> session {receiver}")
|
||||
|
||||
channel.send(reply, context)
|
||||
logger.info(f"[Scheduler] Task {task['id']} executed: sent message to {receiver}")
|
||||
else:
|
||||
logger.error(f"[Scheduler] Failed to create channel: {channel_type}")
|
||||
channel.send(reply, context)
|
||||
except Exception as e:
|
||||
logger.error(f"[Scheduler] Failed to send message: {e}")
|
||||
import traceback
|
||||
logger.error(f"[Scheduler] Traceback: {traceback.format_exc()}")
|
||||
|
||||
return False
|
||||
|
||||
_remember_delivered_output(agent_bridge, task, channel_type, content)
|
||||
logger.info(f"[Scheduler] Task {task['id']} executed: sent message to {receiver}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Scheduler] Error in _execute_send_message: {e}")
|
||||
import traceback
|
||||
logger.error(f"[Scheduler] Traceback: {traceback.format_exc()}")
|
||||
return False
|
||||
|
||||
|
||||
def _execute_tool_call(task: dict, agent_bridge):
|
||||
"""
|
||||
Execute a tool_call action
|
||||
|
||||
Args:
|
||||
task: Task dictionary
|
||||
agent_bridge: AgentBridge instance
|
||||
"""
|
||||
def _execute_tool_call(task: dict, agent_bridge) -> bool:
|
||||
"""Execute a tool_call action. Returns True/False for delivery."""
|
||||
try:
|
||||
action = task.get("action", {})
|
||||
# Support both old and new field names
|
||||
tool_name = action.get("call_name") or action.get("tool_name")
|
||||
tool_params = action.get("call_params") or action.get("tool_params", {})
|
||||
result_prefix = action.get("result_prefix", "")
|
||||
receiver = action.get("receiver")
|
||||
is_group = action.get("is_group", False)
|
||||
channel_type = action.get("channel_type", "unknown")
|
||||
|
||||
|
||||
if not tool_name:
|
||||
logger.error(f"[Scheduler] Task {task['id']}: No tool_name specified")
|
||||
return
|
||||
|
||||
return True
|
||||
if not receiver:
|
||||
logger.error(f"[Scheduler] Task {task['id']}: No receiver specified")
|
||||
return
|
||||
|
||||
# Get tool manager and create tool instance
|
||||
return True
|
||||
|
||||
from agent.tools.tool_manager import ToolManager
|
||||
tool_manager = ToolManager()
|
||||
tool = tool_manager.create_tool(tool_name)
|
||||
|
||||
tool = ToolManager().create_tool(tool_name)
|
||||
if not tool:
|
||||
logger.error(f"[Scheduler] Task {task['id']}: Tool '{tool_name}' not found")
|
||||
return
|
||||
|
||||
# Execute tool
|
||||
return True
|
||||
|
||||
logger.info(f"[Scheduler] Task {task['id']}: Executing tool '{tool_name}' with params {tool_params}")
|
||||
result = tool.execute(tool_params)
|
||||
|
||||
# Get result content
|
||||
if hasattr(result, 'result'):
|
||||
content = result.result
|
||||
else:
|
||||
content = str(result)
|
||||
|
||||
# Add prefix if specified
|
||||
content = result.result if hasattr(result, 'result') else str(result)
|
||||
if result_prefix:
|
||||
content = f"{result_prefix}\n\n{content}"
|
||||
|
||||
# Send result as message
|
||||
|
||||
context = Context(ContextType.TEXT, content)
|
||||
context["receiver"] = receiver
|
||||
context["isgroup"] = is_group
|
||||
context["session_id"] = receiver
|
||||
|
||||
# Channel-specific context setup
|
||||
|
||||
request_id = None
|
||||
if channel_type == "web":
|
||||
# Web channel needs request_id
|
||||
import uuid
|
||||
request_id = f"scheduler_{task['id']}_{uuid.uuid4().hex[:8]}"
|
||||
context["request_id"] = request_id
|
||||
logger.debug(f"[Scheduler] Generated request_id for web channel: {request_id}")
|
||||
elif channel_type == "feishu":
|
||||
# Feishu channel: for scheduled tasks, send as new message (no msg_id to reply to)
|
||||
context["receive_id_type"] = "chat_id" if is_group else "open_id"
|
||||
context["msg"] = None
|
||||
logger.debug(f"[Scheduler] Feishu: receive_id_type={context['receive_id_type']}, is_group={is_group}, receiver={receiver}")
|
||||
|
||||
elif channel_type == "wecom_bot":
|
||||
context["msg"] = None
|
||||
|
||||
reply = Reply(ReplyType.TEXT, content)
|
||||
|
||||
# Get channel and send
|
||||
|
||||
from channel.channel_factory import create_channel
|
||||
|
||||
channel = create_channel(channel_type)
|
||||
if not channel:
|
||||
logger.error(f"[Scheduler] Failed to create channel: {channel_type}")
|
||||
return False
|
||||
|
||||
if channel_type == "web" and request_id and hasattr(channel, 'request_to_session'):
|
||||
channel.request_to_session[request_id] = receiver
|
||||
|
||||
try:
|
||||
channel = create_channel(channel_type)
|
||||
if channel:
|
||||
# For web channel, register the request_id to session mapping
|
||||
if channel_type == "web" and hasattr(channel, 'request_to_session'):
|
||||
channel.request_to_session[request_id] = receiver
|
||||
logger.debug(f"[Scheduler] Registered request_id {request_id} -> session {receiver}")
|
||||
|
||||
channel.send(reply, context)
|
||||
logger.info(f"[Scheduler] Task {task['id']} executed: sent tool result to {receiver}")
|
||||
else:
|
||||
logger.error(f"[Scheduler] Failed to create channel: {channel_type}")
|
||||
channel.send(reply, context)
|
||||
except Exception as e:
|
||||
logger.error(f"[Scheduler] Failed to send tool result: {e}")
|
||||
|
||||
return False
|
||||
|
||||
_remember_delivered_output(agent_bridge, task, channel_type, content)
|
||||
logger.info(f"[Scheduler] Task {task['id']} executed: sent tool result to {receiver}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Scheduler] Error in _execute_tool_call: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def _execute_skill_call(task: dict, agent_bridge):
|
||||
"""
|
||||
Execute a skill_call action by asking Agent to run the skill
|
||||
|
||||
Args:
|
||||
task: Task dictionary
|
||||
agent_bridge: AgentBridge instance
|
||||
"""
|
||||
def _execute_skill_call(task: dict, agent_bridge) -> bool:
|
||||
"""Execute a skill_call action by asking Agent to run the skill.
|
||||
Returns True/False for delivery."""
|
||||
try:
|
||||
action = task.get("action", {})
|
||||
# Support both old and new field names
|
||||
skill_name = action.get("call_name") or action.get("skill_name")
|
||||
skill_params = action.get("call_params") or action.get("skill_params", {})
|
||||
result_prefix = action.get("result_prefix", "")
|
||||
receiver = action.get("receiver")
|
||||
is_group = action.get("isgroup", False)
|
||||
channel_type = action.get("channel_type", "unknown")
|
||||
|
||||
|
||||
if not skill_name:
|
||||
logger.error(f"[Scheduler] Task {task['id']}: No skill_name specified")
|
||||
return
|
||||
|
||||
return True
|
||||
if not receiver:
|
||||
logger.error(f"[Scheduler] Task {task['id']}: No receiver specified")
|
||||
return
|
||||
|
||||
return True
|
||||
|
||||
logger.info(f"[Scheduler] Task {task['id']}: Executing skill '{skill_name}' with params {skill_params}")
|
||||
|
||||
# Create a unique session_id for this scheduled task to avoid polluting user's conversation
|
||||
# Format: scheduler_<receiver>_<task_id> to ensure isolation
|
||||
|
||||
scheduler_session_id = f"scheduler_{receiver}_{task['id']}"
|
||||
|
||||
# Build a natural language query for the Agent to execute the skill
|
||||
# Format: "Use skill-name to do something with params"
|
||||
param_str = ", ".join([f"{k}={v}" for k, v in skill_params.items()])
|
||||
query = f"Use {skill_name} skill"
|
||||
if param_str:
|
||||
query += f" with {param_str}"
|
||||
|
||||
# Create context for Agent
|
||||
|
||||
context = Context(ContextType.TEXT, query)
|
||||
context["receiver"] = receiver
|
||||
context["isgroup"] = is_group
|
||||
context["session_id"] = scheduler_session_id
|
||||
|
||||
# Channel-specific setup
|
||||
|
||||
if channel_type == "web":
|
||||
import uuid
|
||||
request_id = f"scheduler_{task['id']}_{uuid.uuid4().hex[:8]}"
|
||||
@@ -409,32 +481,51 @@ def _execute_skill_call(task: dict, agent_bridge):
|
||||
elif channel_type == "feishu":
|
||||
context["receive_id_type"] = "chat_id" if is_group else "open_id"
|
||||
context["msg"] = None
|
||||
|
||||
# Use Agent to execute the skill
|
||||
elif channel_type == "wecom_bot":
|
||||
context["msg"] = None
|
||||
|
||||
try:
|
||||
# Don't clear history - scheduler tasks use isolated session_id so they won't pollute user conversations
|
||||
reply = agent_bridge.agent_reply(query, context=context, on_event=None, clear_history=False)
|
||||
|
||||
if reply and reply.content:
|
||||
content = reply.content
|
||||
|
||||
# Add prefix if specified
|
||||
if result_prefix:
|
||||
content = f"{result_prefix}\n\n{content}"
|
||||
|
||||
logger.info(f"[Scheduler] Task {task['id']} executed: skill result sent to {receiver}")
|
||||
else:
|
||||
logger.error(f"[Scheduler] Task {task['id']}: No result from skill execution")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Scheduler] Failed to execute skill via Agent: {e}")
|
||||
import traceback
|
||||
logger.error(f"[Scheduler] Traceback: {traceback.format_exc()}")
|
||||
|
||||
return False
|
||||
|
||||
if not (reply and reply.content):
|
||||
logger.error(f"[Scheduler] Task {task['id']}: No result from skill execution")
|
||||
return True
|
||||
|
||||
content = reply.content
|
||||
if result_prefix:
|
||||
content = f"{result_prefix}\n\n{content}"
|
||||
|
||||
from channel.channel_factory import create_channel
|
||||
channel = create_channel(channel_type)
|
||||
if not channel:
|
||||
logger.error(f"[Scheduler] Failed to create channel: {channel_type}")
|
||||
return False
|
||||
|
||||
if channel_type == "web" and hasattr(channel, 'request_to_session'):
|
||||
req_id = context.get("request_id")
|
||||
if req_id:
|
||||
channel.request_to_session[req_id] = receiver
|
||||
|
||||
try:
|
||||
channel.send(Reply(ReplyType.TEXT, content), context)
|
||||
except Exception as e:
|
||||
logger.error(f"[Scheduler] Failed to send skill result: {e}")
|
||||
return False
|
||||
|
||||
_remember_delivered_output(agent_bridge, task, channel_type, content)
|
||||
logger.info(f"[Scheduler] Task {task['id']} executed: skill result sent to {receiver}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Scheduler] Error in _execute_skill_call: {e}")
|
||||
import traceback
|
||||
logger.error(f"[Scheduler] Traceback: {traceback.format_exc()}")
|
||||
return False
|
||||
|
||||
|
||||
def attach_scheduler_to_tool(tool, context: Context = None):
|
||||
@@ -451,8 +542,7 @@ def attach_scheduler_to_tool(tool, context: Context = None):
|
||||
if context:
|
||||
tool.current_context = context
|
||||
|
||||
# Also set channel_type from config
|
||||
channel_type = conf().get("channel_type", "unknown")
|
||||
channel_type = context.get("channel_type") or conf().get("channel_type", "unknown")
|
||||
if not tool.config:
|
||||
tool.config = {}
|
||||
tool.config["channel_type"] = channel_type
|
||||
|
||||
@@ -10,6 +10,19 @@ from croniter import croniter
|
||||
from common.log import logger
|
||||
|
||||
|
||||
def _parse_naive_local(iso_str: str) -> datetime:
|
||||
"""Parse an ISO datetime and coerce it to tz-naive local time.
|
||||
|
||||
The scheduler uses ``datetime.now()`` (tz-naive) for all comparisons,
|
||||
so any persisted timestamp must be normalized to the same flavor —
|
||||
otherwise comparing naive vs aware raises TypeError.
|
||||
"""
|
||||
dt = datetime.fromisoformat(iso_str)
|
||||
if dt.tzinfo is not None:
|
||||
dt = dt.astimezone().replace(tzinfo=None)
|
||||
return dt
|
||||
|
||||
|
||||
class SchedulerService:
|
||||
"""
|
||||
Background service that executes scheduled tasks
|
||||
@@ -39,7 +52,6 @@ class SchedulerService:
|
||||
self.running = True
|
||||
self.thread = threading.Thread(target=self._run_loop, daemon=True)
|
||||
self.thread.start()
|
||||
logger.debug("[Scheduler] Service started")
|
||||
|
||||
def stop(self):
|
||||
"""Stop the scheduler service"""
|
||||
@@ -54,15 +66,14 @@ class SchedulerService:
|
||||
|
||||
def _run_loop(self):
|
||||
"""Main scheduler loop"""
|
||||
logger.debug("[Scheduler] Scheduler loop started")
|
||||
logger.info("[Scheduler] Scheduler loop started")
|
||||
|
||||
while self.running:
|
||||
try:
|
||||
self._check_and_execute_tasks()
|
||||
except Exception as e:
|
||||
logger.error(f"[Scheduler] Error in scheduler loop: {e}")
|
||||
|
||||
# Sleep for 30 seconds between checks
|
||||
|
||||
time.sleep(30)
|
||||
|
||||
def _check_and_execute_tasks(self):
|
||||
@@ -72,12 +83,18 @@ class SchedulerService:
|
||||
|
||||
for task in tasks:
|
||||
try:
|
||||
# Check if task is due
|
||||
if self._is_task_due(task, now):
|
||||
logger.info(f"[Scheduler] Executing task: {task['id']} - {task['name']}")
|
||||
self._execute_task(task)
|
||||
|
||||
# Update next run time
|
||||
ok = self._execute_task(task)
|
||||
if not ok:
|
||||
# Leave next_run_at as-is so the next loop retries.
|
||||
# Cron tasks within the catch-up window will keep
|
||||
# firing; beyond it _is_task_due will reschedule.
|
||||
logger.warning(
|
||||
f"[Scheduler] Task {task['id']} delivery failed, will retry next tick"
|
||||
)
|
||||
continue
|
||||
|
||||
next_run = self._calculate_next_run(task, now)
|
||||
if next_run:
|
||||
self.task_store.update_task(task['id'], {
|
||||
@@ -85,12 +102,8 @@ class SchedulerService:
|
||||
"last_run_at": now.isoformat()
|
||||
})
|
||||
else:
|
||||
# One-time task, disable it
|
||||
self.task_store.update_task(task['id'], {
|
||||
"enabled": False,
|
||||
"last_run_at": now.isoformat()
|
||||
})
|
||||
logger.info(f"[Scheduler] One-time task completed and disabled: {task['id']}")
|
||||
self.task_store.delete_task(task['id'])
|
||||
logger.info(f"[Scheduler] One-time task completed and removed: {task['id']}")
|
||||
except Exception as e:
|
||||
logger.error(f"[Scheduler] Error processing task {task.get('id')}: {e}")
|
||||
|
||||
@@ -117,37 +130,43 @@ class SchedulerService:
|
||||
return False
|
||||
|
||||
try:
|
||||
next_run = datetime.fromisoformat(next_run_str)
|
||||
|
||||
# Check if task is overdue (e.g., service restart)
|
||||
next_run = _parse_naive_local(next_run_str)
|
||||
|
||||
if next_run < now:
|
||||
time_diff = (now - next_run).total_seconds()
|
||||
|
||||
# If overdue by more than 5 minutes, skip this run and schedule next
|
||||
if time_diff > 300: # 5 minutes
|
||||
logger.warning(f"[Scheduler] Task {task['id']} is overdue by {int(time_diff)}s, skipping and scheduling next run")
|
||||
|
||||
# For one-time tasks, disable them
|
||||
schedule = task.get("schedule", {})
|
||||
if schedule.get("type") == "once":
|
||||
self.task_store.update_task(task['id'], {
|
||||
"enabled": False,
|
||||
"last_run_at": now.isoformat()
|
||||
})
|
||||
logger.info(f"[Scheduler] One-time task {task['id']} expired, disabled")
|
||||
return False
|
||||
|
||||
# For recurring tasks, calculate next run from now
|
||||
next_next_run = self._calculate_next_run(task, now)
|
||||
if next_next_run:
|
||||
self.task_store.update_task(task['id'], {
|
||||
"next_run_at": next_next_run.isoformat()
|
||||
})
|
||||
logger.info(f"[Scheduler] Rescheduled task {task['id']} to {next_next_run}")
|
||||
schedule = task.get("schedule", {})
|
||||
schedule_type = schedule.get("type")
|
||||
|
||||
# Catch-up window: fire if we're within 10 minutes of the
|
||||
# scheduled tick. Beyond that we'd rather skip than push a
|
||||
# stale daily report to the user.
|
||||
if time_diff <= 600:
|
||||
return True
|
||||
|
||||
logger.warning(
|
||||
f"[Scheduler] Task {task['id']} is overdue by {int(time_diff)}s, "
|
||||
f"skipping and scheduling next run"
|
||||
)
|
||||
|
||||
if schedule_type == "once":
|
||||
self.task_store.delete_task(task['id'])
|
||||
logger.info(f"[Scheduler] One-time task {task['id']} expired, removed")
|
||||
return False
|
||||
|
||||
|
||||
next_next_run = self._calculate_next_run(task, now)
|
||||
if next_next_run:
|
||||
self.task_store.update_task(task['id'], {
|
||||
"next_run_at": next_next_run.isoformat()
|
||||
})
|
||||
logger.info(f"[Scheduler] Rescheduled task {task['id']} to {next_next_run}")
|
||||
return False
|
||||
|
||||
return now >= next_run
|
||||
except:
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[Scheduler] Failed to evaluate due-state for task "
|
||||
f"{task.get('id')} (next_run_at={next_run_str!r}): {e}"
|
||||
)
|
||||
return False
|
||||
|
||||
def _calculate_next_run(self, task: dict, from_time: datetime) -> Optional[datetime]:
|
||||
@@ -191,30 +210,34 @@ class SchedulerService:
|
||||
return None
|
||||
|
||||
try:
|
||||
run_at = datetime.fromisoformat(run_at_str)
|
||||
# Only return if in the future
|
||||
run_at = _parse_naive_local(run_at_str)
|
||||
if run_at > from_time:
|
||||
return run_at
|
||||
except:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[Scheduler] Failed to parse once-task run_at "
|
||||
f"{run_at_str!r}: {e}"
|
||||
)
|
||||
return None
|
||||
|
||||
return None
|
||||
|
||||
def _execute_task(self, task: dict):
|
||||
def _execute_task(self, task: dict) -> bool:
|
||||
"""
|
||||
Execute a task
|
||||
|
||||
Args:
|
||||
task: Task dictionary
|
||||
Execute a task.
|
||||
|
||||
Returns True if delivery succeeded (caller should advance state),
|
||||
False if it failed (caller should keep next_run_at so the next
|
||||
loop iteration retries). Callback may return None for legacy
|
||||
behaviour, treated as success.
|
||||
"""
|
||||
try:
|
||||
# Call the execute callback
|
||||
self.execute_callback(task)
|
||||
result = self.execute_callback(task)
|
||||
return False if result is False else True
|
||||
except Exception as e:
|
||||
logger.error(f"[Scheduler] Error executing task {task['id']}: {e}")
|
||||
# Update task with error
|
||||
self.task_store.update_task(task['id'], {
|
||||
"last_error": str(e),
|
||||
"last_error_at": datetime.now().isoformat()
|
||||
})
|
||||
return False
|
||||
|
||||
@@ -158,6 +158,11 @@ class SchedulerTool(BaseTool):
|
||||
# Create task
|
||||
task_id = str(uuid.uuid4())[:8]
|
||||
|
||||
# Capture the real chat session_id at task creation time so that scheduler
|
||||
# can later inject the delivered output into the user's actual conversation
|
||||
# (in group chats, session_id != receiver, e.g. "user_id:group_id" on feishu).
|
||||
notify_session_id = context.get("session_id")
|
||||
|
||||
# Build action based on message or ai_task
|
||||
if message:
|
||||
action = {
|
||||
@@ -166,7 +171,8 @@ class SchedulerTool(BaseTool):
|
||||
"receiver": context.get("receiver"),
|
||||
"receiver_name": self._get_receiver_name(context),
|
||||
"is_group": context.get("isgroup", False),
|
||||
"channel_type": self.config.get("channel_type", "unknown")
|
||||
"channel_type": self.config.get("channel_type", "unknown"),
|
||||
"notify_session_id": notify_session_id,
|
||||
}
|
||||
else: # ai_task
|
||||
action = {
|
||||
@@ -175,7 +181,8 @@ class SchedulerTool(BaseTool):
|
||||
"receiver": context.get("receiver"),
|
||||
"receiver_name": self._get_receiver_name(context),
|
||||
"is_group": context.get("isgroup", False),
|
||||
"channel_type": self.config.get("channel_type", "unknown")
|
||||
"channel_type": self.config.get("channel_type", "unknown"),
|
||||
"notify_session_id": notify_session_id,
|
||||
}
|
||||
|
||||
# 针对钉钉单聊,额外存储 sender_staff_id
|
||||
@@ -357,9 +364,12 @@ class SchedulerTool(BaseTool):
|
||||
logger.error(f"[SchedulerTool] Invalid relative time format: {schedule_value}")
|
||||
return None
|
||||
else:
|
||||
# Absolute time in ISO format
|
||||
datetime.fromisoformat(schedule_value)
|
||||
return {"type": "once", "run_at": schedule_value}
|
||||
# Absolute ISO time. Normalize to tz-naive local so it
|
||||
# stays comparable with the scheduler's datetime.now().
|
||||
parsed = datetime.fromisoformat(schedule_value)
|
||||
if parsed.tzinfo is not None:
|
||||
parsed = parsed.astimezone().replace(tzinfo=None)
|
||||
return {"type": "once", "run_at": parsed.isoformat()}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[SchedulerTool] Invalid schedule: {e}")
|
||||
@@ -424,7 +434,7 @@ class SchedulerTool(BaseTool):
|
||||
try:
|
||||
dt = datetime.fromisoformat(run_at)
|
||||
return f"一次性 ({dt.strftime('%Y-%m-%d %H:%M')})"
|
||||
except:
|
||||
except Exception:
|
||||
return "一次性"
|
||||
|
||||
return "未知"
|
||||
@@ -438,6 +448,6 @@ class SchedulerTool(BaseTool):
|
||||
return msg.other_user_nickname or "群聊"
|
||||
else:
|
||||
return msg.from_user_nickname or "用户"
|
||||
except:
|
||||
except Exception:
|
||||
pass
|
||||
return "未知"
|
||||
|
||||
@@ -72,7 +72,7 @@ class TaskStore:
|
||||
with open(self.store_path, 'r') as src:
|
||||
with open(backup_path, 'w') as dst:
|
||||
dst.write(src.read())
|
||||
except:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Save tasks
|
||||
|
||||
@@ -14,14 +14,14 @@ class Send(BaseTool):
|
||||
"""Tool for sending files to the user"""
|
||||
|
||||
name: str = "send"
|
||||
description: str = "Send a file (image, video, audio, document) to the user. Use this when the user explicitly asks to send/share a file."
|
||||
description: str = "Send a LOCAL file (image, video, audio, document) to the user. Only for local file paths. Do NOT use this for URLs — URLs should be included directly in your text reply, the system will handle them automatically."
|
||||
|
||||
params: dict = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "Path to the file to send. Can be absolute path or relative to workspace."
|
||||
"description": "Local file path to send. Must be an absolute path or relative to workspace. Do NOT pass URLs here."
|
||||
},
|
||||
"message": {
|
||||
"type": "string",
|
||||
@@ -98,7 +98,18 @@ class Send(BaseTool):
|
||||
"size_formatted": self._format_size(file_size),
|
||||
"message": message or f"正在发送 {file_name}"
|
||||
}
|
||||
|
||||
|
||||
try:
|
||||
from common.cloud_client import get_website_base_url, copy_send_file
|
||||
|
||||
# Do nothing when in local env
|
||||
if get_website_base_url():
|
||||
url = copy_send_file(absolute_path, self.cwd)
|
||||
if url:
|
||||
result["url"] = url
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return ToolResult.success(result)
|
||||
|
||||
def _resolve_path(self, path: str) -> str:
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import importlib
|
||||
import importlib.util
|
||||
import threading
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, Type
|
||||
from agent.tools.base_tool import BaseTool
|
||||
@@ -7,6 +8,26 @@ from common.log import logger
|
||||
from config import conf
|
||||
|
||||
|
||||
def _normalize_mcp_configs(raw) -> list:
|
||||
"""
|
||||
Convert MCP server config to internal list format.
|
||||
Supports:
|
||||
- list format (mcp_servers): [{"name": "x", "type": "stdio", ...}]
|
||||
- dict format (mcpServers): {"x": {"command": "npx", ...}}
|
||||
"""
|
||||
if isinstance(raw, list):
|
||||
return raw
|
||||
if isinstance(raw, dict):
|
||||
result = []
|
||||
for name, cfg in raw.items():
|
||||
entry = {"name": name, **cfg}
|
||||
if "type" not in entry:
|
||||
entry["type"] = "sse" if "url" in entry else "stdio"
|
||||
result.append(entry)
|
||||
return result
|
||||
return []
|
||||
|
||||
|
||||
class ToolManager:
|
||||
"""
|
||||
Tool manager for managing tools.
|
||||
@@ -25,6 +46,31 @@ class ToolManager:
|
||||
# Initialize only once
|
||||
if not hasattr(self, 'tool_classes'):
|
||||
self.tool_classes = {} # Dictionary to store tool classes
|
||||
if not hasattr(self, '_mcp_registry'):
|
||||
self._mcp_registry = None # Lazy init: only created when MCP servers are configured
|
||||
if not hasattr(self, '_mcp_tool_instances'):
|
||||
self._mcp_tool_instances: dict = {} # tool_name -> McpTool instance
|
||||
if not hasattr(self, '_mcp_lock'):
|
||||
# Guards _mcp_loaded check-then-set so concurrent callers
|
||||
# don't trigger duplicate background loaders.
|
||||
self._mcp_lock = threading.Lock()
|
||||
if not hasattr(self, '_mcp_loaded'):
|
||||
# Idempotency flag. Flipped to True the moment the first loader
|
||||
# is dispatched (synchronously, inside _mcp_lock). Subsequent
|
||||
# _load_mcp_tools() calls become no-ops, so per-session agent
|
||||
# initialization never re-forks MCP subprocesses.
|
||||
self._mcp_loaded = False
|
||||
if not hasattr(self, '_mcp_status'):
|
||||
# server_name -> "pending" / "ready" / "failed"
|
||||
# Useful for UI / introspection while async loading is in progress.
|
||||
self._mcp_status: dict = {}
|
||||
if not hasattr(self, '_mcp_signature'):
|
||||
# (mtime, sha256) of mcp.json the last time we loaded.
|
||||
# Used by refresh_mcp_if_changed() to skip re-parsing when nothing changed.
|
||||
self._mcp_signature: tuple = (None, None)
|
||||
if not hasattr(self, '_mcp_active_configs'):
|
||||
# server_name -> normalized config dict, for diff-based reload.
|
||||
self._mcp_active_configs: dict = {}
|
||||
|
||||
def load_tools(self, tools_dir: str = "", config_dict=None):
|
||||
"""
|
||||
@@ -39,6 +85,8 @@ class ToolManager:
|
||||
self._load_tools_from_init()
|
||||
self._configure_tools_from_config(config_dict)
|
||||
|
||||
self._load_mcp_tools()
|
||||
|
||||
def _load_tools_from_init(self) -> bool:
|
||||
"""
|
||||
Load tool classes from tools.__init__.__all__
|
||||
@@ -70,10 +118,14 @@ class ToolManager:
|
||||
and cls != BaseTool
|
||||
):
|
||||
try:
|
||||
# Skip memory tools (they need special initialization with memory_manager)
|
||||
# Skip tools that need special initialization
|
||||
if class_name in ["MemorySearchTool", "MemoryGetTool"]:
|
||||
logger.debug(f"Skipped tool {class_name} (requires memory_manager)")
|
||||
continue
|
||||
# McpTool instances are registered dynamically via _load_mcp_tools()
|
||||
if class_name == "McpTool":
|
||||
logger.debug(f"Skipped tool {class_name} (registered dynamically via mcp_servers config)")
|
||||
continue
|
||||
|
||||
# Create a temporary instance to get the name
|
||||
temp_instance = cls()
|
||||
@@ -84,11 +136,11 @@ class ToolManager:
|
||||
except ImportError as e:
|
||||
# Handle missing dependencies with helpful messages
|
||||
error_msg = str(e)
|
||||
if "browser-use" in error_msg or "browser_use" in error_msg:
|
||||
if "playwright" in error_msg:
|
||||
logger.warning(
|
||||
f"[ToolManager] Browser tool not loaded - missing dependencies.\n"
|
||||
f" To enable browser tool, run:\n"
|
||||
f" pip install browser-use markdownify playwright\n"
|
||||
f" pip install playwright\n"
|
||||
f" playwright install chromium"
|
||||
)
|
||||
elif "markdownify" in error_msg:
|
||||
@@ -154,11 +206,11 @@ class ToolManager:
|
||||
except ImportError as e:
|
||||
# Handle missing dependencies with helpful messages
|
||||
error_msg = str(e)
|
||||
if "browser-use" in error_msg or "browser_use" in error_msg:
|
||||
if "playwright" in error_msg:
|
||||
logger.warning(
|
||||
f"[ToolManager] Browser tool not loaded - missing dependencies.\n"
|
||||
f" To enable browser tool, run:\n"
|
||||
f" pip install browser-use markdownify playwright\n"
|
||||
f" pip install playwright\n"
|
||||
f" playwright install chromium"
|
||||
)
|
||||
elif "markdownify" in error_msg:
|
||||
@@ -197,7 +249,7 @@ class ToolManager:
|
||||
logger.warning(
|
||||
f"[ToolManager] Browser tool is configured but not loaded.\n"
|
||||
f" To enable browser tool, run:\n"
|
||||
f" pip install browser-use markdownify playwright\n"
|
||||
f" pip install playwright\n"
|
||||
f" playwright install chromium"
|
||||
)
|
||||
elif tool_name == "google_search":
|
||||
@@ -212,6 +264,306 @@ class ToolManager:
|
||||
except Exception as e:
|
||||
logger.error(f"Error configuring tools from config: {e}")
|
||||
|
||||
def _mcp_json_path(self) -> str:
|
||||
import os
|
||||
workspace = os.path.expanduser(conf().get("agent_workspace", "~/cow"))
|
||||
return os.path.join(workspace, "mcp.json")
|
||||
|
||||
def _read_mcp_json_signature(self):
|
||||
"""
|
||||
Return (mtime, sha256_of_bytes) for ~/cow/mcp.json without parsing.
|
||||
Returns (None, None) if the file doesn't exist or is unreadable.
|
||||
Cheap enough (one stat + one small read) to call on every agent init.
|
||||
"""
|
||||
import os
|
||||
import hashlib
|
||||
path = self._mcp_json_path()
|
||||
try:
|
||||
mtime = os.path.getmtime(path)
|
||||
except OSError:
|
||||
return (None, None)
|
||||
try:
|
||||
with open(path, "rb") as f:
|
||||
digest = hashlib.sha256(f.read()).hexdigest()
|
||||
except OSError:
|
||||
return (mtime, None)
|
||||
return (mtime, digest)
|
||||
|
||||
def _load_mcp_configs(self) -> list:
|
||||
"""
|
||||
Load MCP server configs with priority:
|
||||
1. ~/cow/mcp.json (supports both mcpServers and mcp_servers keys)
|
||||
2. config.json mcp_servers field (fallback)
|
||||
"""
|
||||
import os
|
||||
import json as _json
|
||||
|
||||
mcp_json_path = self._mcp_json_path()
|
||||
|
||||
if os.path.exists(mcp_json_path):
|
||||
try:
|
||||
with open(mcp_json_path, "r", encoding="utf-8") as f:
|
||||
data = _json.load(f)
|
||||
raw = data.get("mcpServers") or data.get("mcp_servers") or data
|
||||
logger.info(f"[ToolManager] Loading MCP config from {mcp_json_path}")
|
||||
return _normalize_mcp_configs(raw)
|
||||
except Exception as e:
|
||||
logger.warning(f"[ToolManager] Failed to read {mcp_json_path}: {e}, falling back to config.json")
|
||||
|
||||
raw = conf().get("mcp_servers", [])
|
||||
return _normalize_mcp_configs(raw)
|
||||
|
||||
def _load_mcp_tools(self):
|
||||
"""
|
||||
Trigger MCP tool loading in a background thread (idempotent).
|
||||
|
||||
Returns immediately. Booting MCP servers (npx, uvx, etc.) takes
|
||||
seconds to tens of seconds on first run, which would otherwise
|
||||
block agent initialization and the user's first message.
|
||||
Built-in tools work fine without MCP, so we let the agent serve
|
||||
traffic right away and let MCP servers come online in the
|
||||
background. Per-session agents read a snapshot of whatever is
|
||||
ready at construction time and gracefully ignore the rest.
|
||||
"""
|
||||
with self._mcp_lock:
|
||||
if self._mcp_loaded:
|
||||
return
|
||||
mcp_servers_config = self._load_mcp_configs()
|
||||
# Snapshot the signature now so future refresh_mcp_if_changed()
|
||||
# calls can short-circuit when nothing has changed on disk.
|
||||
self._mcp_signature = self._read_mcp_json_signature()
|
||||
self._mcp_active_configs = {
|
||||
cfg.get("name", "<unnamed>"): cfg for cfg in mcp_servers_config
|
||||
}
|
||||
if not mcp_servers_config:
|
||||
# Mark as loaded even when there is nothing to load,
|
||||
# so we don't re-read the config file on every call.
|
||||
self._mcp_loaded = True
|
||||
return
|
||||
|
||||
# Mark pending immediately so list_mcp_status() callers see
|
||||
# the in-progress state instead of an empty dict.
|
||||
for cfg in mcp_servers_config:
|
||||
name = cfg.get("name", "<unnamed>")
|
||||
self._mcp_status[name] = "pending"
|
||||
|
||||
self._mcp_loaded = True
|
||||
threading.Thread(
|
||||
target=self._load_mcp_tools_async,
|
||||
args=(mcp_servers_config,),
|
||||
daemon=True,
|
||||
name="mcp-loader",
|
||||
).start()
|
||||
logger.info(
|
||||
f"[ToolManager] MCP loading started in background "
|
||||
f"({len(mcp_servers_config)} server(s) configured)"
|
||||
)
|
||||
|
||||
def refresh_mcp_if_changed(self):
|
||||
"""
|
||||
Cheap check whether ~/cow/mcp.json has changed since last load.
|
||||
If it has, do a diff-based reload: start newly added servers,
|
||||
shut down removed ones, and restart any whose config was edited.
|
||||
Untouched servers are left running.
|
||||
|
||||
Designed to be called on every agent creation. The fast path is
|
||||
a single os.stat() — completely free when nothing has changed.
|
||||
"""
|
||||
with self._mcp_lock:
|
||||
new_sig = self._read_mcp_json_signature()
|
||||
if new_sig == self._mcp_signature:
|
||||
return # no-op fast path
|
||||
|
||||
try:
|
||||
new_configs = self._load_mcp_configs()
|
||||
except Exception as e:
|
||||
logger.warning(f"[ToolManager] MCP reload — failed to parse config: {e}")
|
||||
return
|
||||
|
||||
new_by_name = {
|
||||
cfg.get("name", "<unnamed>"): cfg for cfg in new_configs
|
||||
}
|
||||
old_by_name = self._mcp_active_configs
|
||||
|
||||
added = [n for n in new_by_name if n not in old_by_name]
|
||||
removed = [n for n in old_by_name if n not in new_by_name]
|
||||
changed = [
|
||||
n for n in new_by_name
|
||||
if n in old_by_name and new_by_name[n] != old_by_name[n]
|
||||
]
|
||||
|
||||
if not (added or removed or changed):
|
||||
# Signature drifted but content is logically identical
|
||||
# (e.g. user re-saved the file without edits). Just sync.
|
||||
self._mcp_signature = new_sig
|
||||
return
|
||||
|
||||
logger.info(
|
||||
f"[ToolManager] mcp.json changed — "
|
||||
f"adding={added}, removing={removed}, restarting={changed}"
|
||||
)
|
||||
|
||||
# Tear down removed + changed servers (changed ones get restarted below)
|
||||
for name in removed + changed:
|
||||
self._teardown_mcp_server(name)
|
||||
|
||||
# Spin up newly added + changed servers in the background
|
||||
to_start = [new_by_name[n] for n in added + changed]
|
||||
if to_start:
|
||||
for cfg in to_start:
|
||||
self._mcp_status[cfg.get("name", "<unnamed>")] = "pending"
|
||||
threading.Thread(
|
||||
target=self._load_mcp_tools_async,
|
||||
args=(to_start,),
|
||||
daemon=True,
|
||||
name="mcp-loader-reload",
|
||||
).start()
|
||||
|
||||
self._mcp_active_configs = new_by_name
|
||||
self._mcp_signature = new_sig
|
||||
|
||||
def _teardown_mcp_server(self, server_name: str):
|
||||
"""Shut down one MCP server and drop its tools from the registry."""
|
||||
if self._mcp_registry is None:
|
||||
return
|
||||
client = None
|
||||
with self._mcp_registry._registry_lock:
|
||||
client = self._mcp_registry._clients.pop(server_name, None)
|
||||
if client is not None:
|
||||
try:
|
||||
client.shutdown()
|
||||
except Exception as e:
|
||||
logger.warning(f"[MCP] Error shutting down '{server_name}': {e}")
|
||||
# Drop tools that belonged to this server.
|
||||
for tool_name in list(self._mcp_tool_instances.keys()):
|
||||
tool = self._mcp_tool_instances.get(tool_name)
|
||||
if tool is not None and getattr(tool, "server_name", None) == server_name:
|
||||
self._mcp_tool_instances.pop(tool_name, None)
|
||||
self._mcp_status.pop(server_name, None)
|
||||
|
||||
def _load_mcp_tools_async(self, mcp_servers_config):
|
||||
"""
|
||||
Background worker: bring up each MCP server one-by-one and
|
||||
publish ready tools to _mcp_tool_instances as they come online.
|
||||
|
||||
Server failures are isolated — one bad server cannot block
|
||||
the others, and never raises out of the worker thread.
|
||||
"""
|
||||
try:
|
||||
from agent.tools.mcp.mcp_client import McpClient, McpClientRegistry
|
||||
from agent.tools.mcp.mcp_tool import McpTool
|
||||
|
||||
registry = McpClientRegistry()
|
||||
self._mcp_registry = registry
|
||||
|
||||
for cfg in mcp_servers_config:
|
||||
server_name = cfg.get("name", "<unnamed>")
|
||||
try:
|
||||
client = McpClient(cfg)
|
||||
if not client.initialize():
|
||||
self._mcp_status[server_name] = "failed"
|
||||
logger.warning(
|
||||
f"[MCP] Server '{server_name}' failed to initialize — skipping"
|
||||
)
|
||||
continue
|
||||
|
||||
tool_schemas = client.list_tools()
|
||||
added = []
|
||||
for schema in tool_schemas:
|
||||
tool_name = schema.get("name", "")
|
||||
if not tool_name:
|
||||
continue
|
||||
mcp_tool = McpTool(client, schema, server_name)
|
||||
# Atomic dict assignment is GIL-safe; readers iterate
|
||||
# over a list() snapshot to avoid concurrent mutation.
|
||||
self._mcp_tool_instances[tool_name] = mcp_tool
|
||||
added.append(tool_name)
|
||||
|
||||
# Register client into the shared registry only after its
|
||||
# tools are visible, so callers never see a half-loaded server.
|
||||
with registry._registry_lock:
|
||||
registry._clients[server_name] = client
|
||||
self._mcp_status[server_name] = "ready"
|
||||
logger.info(
|
||||
f"[MCP] Server '{server_name}' ready — "
|
||||
f"{len(added)} tool(s): {added}"
|
||||
)
|
||||
except Exception as e:
|
||||
self._mcp_status[server_name] = "failed"
|
||||
logger.warning(f"[MCP] Server '{server_name}' load failed: {e}")
|
||||
|
||||
ready = sum(1 for s in self._mcp_status.values() if s == "ready")
|
||||
total = len(self._mcp_status)
|
||||
logger.info(
|
||||
f"[ToolManager] MCP loading complete: "
|
||||
f"{ready}/{total} server(s) ready, "
|
||||
f"{len(self._mcp_tool_instances)} tool(s) available"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"[ToolManager] MCP background loader crashed: {e}")
|
||||
|
||||
def list_mcp_status(self) -> dict:
|
||||
"""Return {server_name: status} snapshot for UI / debugging."""
|
||||
return dict(self._mcp_status)
|
||||
|
||||
def sync_mcp_into_agent(self, agent) -> tuple:
|
||||
"""
|
||||
Reconcile a live agent's tool collection with the current MCP tool registry.
|
||||
|
||||
Adds tools that finished loading after the agent was created,
|
||||
and removes tools whose MCP server was torn down. Built-in tools
|
||||
on the agent are left untouched.
|
||||
|
||||
Handles both representations CowAgent uses:
|
||||
- Agent.tools: list[BaseTool] (default Agent class)
|
||||
- AgentStream.tools: dict[str, BaseTool] (streaming agent)
|
||||
|
||||
Returns (added_names, removed_names) for logging.
|
||||
"""
|
||||
if agent is None or not hasattr(agent, "tools"):
|
||||
return ([], [])
|
||||
|
||||
from agent.tools.mcp.mcp_tool import McpTool
|
||||
current = self._mcp_tool_instances
|
||||
registry_names = set(current.keys())
|
||||
|
||||
agent_tools = agent.tools
|
||||
|
||||
if isinstance(agent_tools, dict):
|
||||
agent_mcp_names = {
|
||||
name for name, tool in agent_tools.items()
|
||||
if isinstance(tool, McpTool)
|
||||
}
|
||||
added = registry_names - agent_mcp_names
|
||||
removed = agent_mcp_names - registry_names
|
||||
if not (added or removed):
|
||||
return ([], [])
|
||||
for name in added:
|
||||
agent_tools[name] = current[name]
|
||||
for name in removed:
|
||||
agent_tools.pop(name, None)
|
||||
|
||||
elif isinstance(agent_tools, list):
|
||||
agent_mcp_names = {
|
||||
t.name for t in agent_tools if isinstance(t, McpTool)
|
||||
}
|
||||
added = registry_names - agent_mcp_names
|
||||
removed = agent_mcp_names - registry_names
|
||||
if not (added or removed):
|
||||
return ([], [])
|
||||
if removed:
|
||||
agent.tools = [
|
||||
t for t in agent_tools
|
||||
if not (isinstance(t, McpTool) and t.name in removed)
|
||||
]
|
||||
for name in added:
|
||||
agent.tools.append(current[name])
|
||||
|
||||
else:
|
||||
return ([], [])
|
||||
|
||||
return (sorted(added), sorted(removed))
|
||||
|
||||
def create_tool(self, name: str) -> BaseTool:
|
||||
"""
|
||||
Get a new instance of a tool by name.
|
||||
@@ -229,6 +581,12 @@ class ToolManager:
|
||||
tool_instance.config = self.tool_configs[name]
|
||||
|
||||
return tool_instance
|
||||
|
||||
# Fall back to MCP tool instances
|
||||
mcp_tool = self._mcp_tool_instances.get(name)
|
||||
if mcp_tool:
|
||||
return mcp_tool
|
||||
|
||||
return None
|
||||
|
||||
def list_tools(self) -> dict:
|
||||
@@ -245,4 +603,17 @@ class ToolManager:
|
||||
"description": temp_instance.description,
|
||||
"parameters": temp_instance.get_json_schema()
|
||||
}
|
||||
|
||||
# Include MCP tool instances
|
||||
for name, mcp_tool in self._mcp_tool_instances.items():
|
||||
result[name] = {
|
||||
"description": mcp_tool.description,
|
||||
"parameters": mcp_tool.params,
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
def shutdown_mcp(self):
|
||||
"""Shut down all MCP server clients."""
|
||||
if self._mcp_registry:
|
||||
self._mcp_registry.shutdown_all()
|
||||
|
||||
@@ -8,7 +8,10 @@ Truncation is based on two independent limits - whichever is hit first wins:
|
||||
Never returns partial lines (except bash tail truncation edge case).
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, Optional, Literal, Tuple
|
||||
from __future__ import annotations
|
||||
from typing import Dict, Any, Optional, Tuple, TYPE_CHECKING
|
||||
if TYPE_CHECKING:
|
||||
from typing import Literal
|
||||
|
||||
|
||||
DEFAULT_MAX_LINES = 2000
|
||||
|
||||
1
agent/tools/vision/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from agent.tools.vision.vision import Vision
|
||||
814
agent/tools/vision/vision.py
Normal file
@@ -0,0 +1,814 @@
|
||||
"""
|
||||
Vision tool - Analyze images using Vision API.
|
||||
Supports local files (auto base64-encoded) and HTTP URLs.
|
||||
|
||||
Provider resolution:
|
||||
- tools.vision.model (if set) means "prefer this model first; fall back to
|
||||
other configured providers if it fails". The model name is mapped to its
|
||||
native provider (e.g. doubao-* → Doubao, kimi-* → Moonshot, gpt-* →
|
||||
OpenAI/LinkAI). That provider is tried first, then the standard auto
|
||||
chain runs as fallback (with the preferred provider de-duplicated).
|
||||
- Auto chain priority:
|
||||
1. Main model via bot.call_vision — only when the main bot is known
|
||||
to actually support vision (not just expose a call_vision method).
|
||||
2. Other models whose API key is configured.
|
||||
3. OpenAI / LinkAI raw HTTP.
|
||||
When use_linkai=true, LinkAI is promoted to #1.
|
||||
"""
|
||||
|
||||
import base64
|
||||
import os
|
||||
import subprocess
|
||||
import tempfile
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import requests
|
||||
|
||||
from agent.tools.base_tool import BaseTool, ToolResult
|
||||
from common import const
|
||||
from common.log import logger
|
||||
from config import conf
|
||||
|
||||
DEFAULT_MODEL = const.GPT_41_MINI
|
||||
DEFAULT_TIMEOUT = 60
|
||||
MAX_TOKENS = 1000
|
||||
COMPRESS_THRESHOLD = 1_048_576 # 1 MB
|
||||
|
||||
SUPPORTED_EXTENSIONS = {
|
||||
"jpg": "image/jpeg",
|
||||
"jpeg": "image/jpeg",
|
||||
"png": "image/png",
|
||||
"gif": "image/gif",
|
||||
"webp": "image/webp",
|
||||
}
|
||||
|
||||
_MAIN_MODEL_PROVIDER_NAME = "MainModel"
|
||||
|
||||
# (config_key_for_api_key, bot_type, default_vision_model, provider_display_name)
|
||||
# Auto-discovered as fallback vision providers when their API key is configured.
|
||||
# OpenAI and LinkAI are handled separately (raw HTTP providers), so not listed here.
|
||||
_DISCOVERABLE_MODELS = [
|
||||
("moonshot_api_key", const.MOONSHOT, const.KIMI_K2_6, "Moonshot"),
|
||||
("ark_api_key", const.DOUBAO, const.DOUBAO_SEED_2_PRO, "Doubao"),
|
||||
("dashscope_api_key", const.QWEN_DASHSCOPE, const.QWEN36_PLUS, "DashScope"),
|
||||
("claude_api_key", const.CLAUDEAPI, const.CLAUDE_4_6_SONNET, "Claude"),
|
||||
("gemini_api_key", const.GEMINI, const.GEMINI_35_FLASH, "Gemini"),
|
||||
("qianfan_api_key", const.QIANFAN, const.ERNIE_45_TURBO_VL, "Qianfan"),
|
||||
("zhipu_ai_api_key", const.ZHIPU_AI, const.GLM_4_7, "ZhipuAI"),
|
||||
("minimax_api_key", const.MiniMax, const.MINIMAX_M2_7, "MiniMax"),
|
||||
("mimo_api_key", const.MIMO, const.MIMO_V2_5_PRO, "MiMo"),
|
||||
]
|
||||
|
||||
# Model name prefix → discoverable provider display_name.
|
||||
# Used to auto-route tools.vision.model to its native provider.
|
||||
# Matched case-insensitively; longest prefix wins.
|
||||
_MODEL_PREFIX_TO_PROVIDER = [
|
||||
("doubao-", "Doubao"),
|
||||
("kimi-", "Moonshot"),
|
||||
("moonshot-", "Moonshot"),
|
||||
("qwen", "DashScope"), # qwen-*, qwen3-*, qwen3.6-*, etc.
|
||||
("claude-", "Claude"),
|
||||
("ernie-", "Qianfan"),
|
||||
("gemini-", "Gemini"),
|
||||
("glm-", "ZhipuAI"),
|
||||
("minimax-", "MiniMax"),
|
||||
("abab", "MiniMax"),
|
||||
("mimo-", "MiMo"),
|
||||
]
|
||||
|
||||
# Model prefixes that natively belong to OpenAI / LinkAI (raw HTTP providers).
|
||||
_OPENAI_MODEL_PREFIXES = ("gpt-", "o1-", "o3-", "o4-", "chatgpt-")
|
||||
|
||||
# Maps the UI provider id (persisted in tools.vision.provider) to the internal
|
||||
# display name used in VisionProvider.name. Keep in sync with _DISCOVERABLE_MODELS
|
||||
# and the openai/linkai branches in _route_by_model_name.
|
||||
_PROVIDER_ID_TO_DISPLAY = {
|
||||
"openai": "OpenAI",
|
||||
"linkai": "LinkAI",
|
||||
"moonshot": "Moonshot",
|
||||
"doubao": "Doubao",
|
||||
"dashscope": "DashScope",
|
||||
"claudeAPI": "Claude",
|
||||
"gemini": "Gemini",
|
||||
"qianfan": "Qianfan",
|
||||
"zhipu": "ZhipuAI",
|
||||
"minimax": "MiniMax",
|
||||
"mimo": "MiMo",
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class VisionProvider:
|
||||
"""A single Vision API provider configuration."""
|
||||
name: str
|
||||
api_key: str
|
||||
api_base: str
|
||||
extra_headers: dict = field(default_factory=dict)
|
||||
model_override: Optional[str] = None
|
||||
use_bot: bool = False # When True, call via bot.call_vision instead of raw HTTP
|
||||
fallback_bot: Any = None # Bot instance for non-main-model providers
|
||||
|
||||
|
||||
class VisionAPIError(Exception):
|
||||
"""Raised when a Vision API call fails and should trigger fallback."""
|
||||
pass
|
||||
|
||||
|
||||
class Vision(BaseTool):
|
||||
"""Analyze images using Vision API"""
|
||||
|
||||
name: str = "vision"
|
||||
description: str = (
|
||||
"Analyze a local image or image URL (jpg/jpeg/png) using Vision API. "
|
||||
"Can describe content, extract text, identify objects, colors, etc. "
|
||||
)
|
||||
|
||||
params: dict = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"image": {
|
||||
"type": "string",
|
||||
"description": "Local file path or HTTP(S) URL of the image to analyze",
|
||||
},
|
||||
"question": {
|
||||
"type": "string",
|
||||
"description": "Question to ask about the image",
|
||||
},
|
||||
},
|
||||
"required": ["image", "question"],
|
||||
}
|
||||
|
||||
def __init__(self, config: dict = None):
|
||||
self.config = config or {}
|
||||
|
||||
@staticmethod
|
||||
def is_available() -> bool:
|
||||
return True
|
||||
|
||||
def execute(self, args: Dict[str, Any]) -> ToolResult:
|
||||
image = args.get("image", "").strip()
|
||||
question = args.get("question", "").strip()
|
||||
|
||||
if not image:
|
||||
return ToolResult.fail("Error: 'image' parameter is required")
|
||||
if not question:
|
||||
return ToolResult.fail("Error: 'question' parameter is required")
|
||||
|
||||
providers = self._resolve_providers()
|
||||
if not providers:
|
||||
return ToolResult.fail(
|
||||
"Error: No model available for Vision.\n"
|
||||
"The main model does not support vision and no other API keys are configured.\n"
|
||||
"Options:\n"
|
||||
" 1. Switch to a multimodal model (e.g. ernie-4.5-turbo-vl, qwen3.6-plus, claude-sonnet-4-6, gemini-2.0-flash)\n"
|
||||
" 2. Configure OPENAI_API_KEY: env_config(action=\"set\", key=\"OPENAI_API_KEY\", value=\"your-key\")\n"
|
||||
" 3. Configure LINKAI_API_KEY: env_config(action=\"set\", key=\"LINKAI_API_KEY\", value=\"your-key\")"
|
||||
)
|
||||
|
||||
try:
|
||||
image_content = self._build_image_content(image)
|
||||
except Exception as e:
|
||||
return ToolResult.fail(f"Error: {e}")
|
||||
|
||||
# Default model is only used as a last-resort placeholder for providers
|
||||
# whose VisionProvider.model_override is None (e.g. raw OpenAI provider
|
||||
# when the user did not configure tools.vision.model).
|
||||
return self._call_with_fallback(providers, DEFAULT_MODEL, question, image_content)
|
||||
|
||||
def _call_with_fallback(self, providers: List[VisionProvider], model: str,
|
||||
question: str, image_content: dict) -> ToolResult:
|
||||
"""Try each provider in order; fall back to the next one on failure."""
|
||||
errors: List[str] = []
|
||||
for i, provider in enumerate(providers):
|
||||
use_model = provider.model_override or model
|
||||
try:
|
||||
logger.info(f"[Vision] Trying provider '{provider.name}' "
|
||||
f"with model '{use_model}' ({i + 1}/{len(providers)})")
|
||||
if provider.use_bot:
|
||||
result = self._call_via_bot(use_model, question, image_content, provider)
|
||||
else:
|
||||
result = self._call_api(provider, use_model, question, image_content)
|
||||
logger.info(f"[Vision] ✅ Success via {provider.name} (model={use_model})")
|
||||
return result
|
||||
except VisionAPIError as e:
|
||||
errors.append(f"[{provider.name}/{use_model}] {e}")
|
||||
logger.warning(f"[Vision] Provider '{provider.name}' failed: {e}")
|
||||
except requests.Timeout:
|
||||
errors.append(f"[{provider.name}/{use_model}] Request timed out after {DEFAULT_TIMEOUT}s")
|
||||
logger.warning(f"[Vision] Provider '{provider.name}' timed out")
|
||||
except requests.ConnectionError:
|
||||
errors.append(f"[{provider.name}/{use_model}] Connection failed")
|
||||
logger.warning(f"[Vision] Provider '{provider.name}' connection failed")
|
||||
except Exception as e:
|
||||
errors.append(f"[{provider.name}/{use_model}] {e}")
|
||||
logger.error(f"[Vision] Provider '{provider.name}' unexpected error: {e}", exc_info=True)
|
||||
|
||||
return ToolResult.fail(
|
||||
"Error: All Vision API providers failed.\n" + "\n".join(f" - {err}" for err in errors)
|
||||
)
|
||||
|
||||
def _resolve_providers(self) -> List[VisionProvider]:
|
||||
"""
|
||||
Build an ordered list of providers to try.
|
||||
|
||||
Semantics of `tools.vision.model`:
|
||||
"Prefer this model first; fall back to other configured providers
|
||||
if it fails."
|
||||
|
||||
Order:
|
||||
1. The provider that natively serves `tools.vision.model` (if any
|
||||
and its API key is configured) — using the user-specified model
|
||||
name verbatim.
|
||||
2. Auto-discovery chain as fallback:
|
||||
- use_linkai=true → [LinkAI, MainModel?, OtherModels…, OpenAI]
|
||||
- default → [MainModel?, OtherModels…, OpenAI, LinkAI]
|
||||
MainModel is only included when the main bot is known to support
|
||||
vision (see _main_bot_supports_vision).
|
||||
|
||||
Providers that share the same display name as the preferred provider
|
||||
are de-duplicated to avoid retrying the same endpoint twice.
|
||||
"""
|
||||
user_model = self._resolve_user_vision_model()
|
||||
user_provider = self._resolve_user_vision_provider()
|
||||
providers: List[VisionProvider] = []
|
||||
|
||||
# Step 1: preferred provider — explicit `tools.vision.provider`
|
||||
# wins so custom model names can still be routed correctly. Falls
|
||||
# through to model-name prefix inference when provider is unset.
|
||||
preferred = None
|
||||
if user_provider and user_model:
|
||||
preferred = self._route_by_provider_id(user_provider, user_model)
|
||||
if not preferred and user_model:
|
||||
preferred = self._route_by_model_name(user_model)
|
||||
if preferred:
|
||||
providers.extend(preferred)
|
||||
|
||||
# Step 2: auto-discovery chain as fallback
|
||||
existing = {p.name for p in providers}
|
||||
fallback: List[VisionProvider] = []
|
||||
use_linkai = conf().get("use_linkai", False) and conf().get("linkai_api_key")
|
||||
|
||||
if use_linkai:
|
||||
self._append_provider(fallback, lambda: self._build_linkai_provider(user_model))
|
||||
self._append_provider(fallback, self._build_main_model_provider)
|
||||
self._append_other_model_providers(fallback, preferred_model=user_model)
|
||||
self._append_provider(fallback, lambda: self._build_openai_provider(user_model))
|
||||
else:
|
||||
self._append_provider(fallback, self._build_main_model_provider)
|
||||
self._append_other_model_providers(fallback, preferred_model=user_model)
|
||||
self._append_provider(fallback, lambda: self._build_openai_provider(user_model))
|
||||
self._append_provider(fallback, lambda: self._build_linkai_provider(user_model))
|
||||
|
||||
for p in fallback:
|
||||
if p.name in existing:
|
||||
continue
|
||||
providers.append(p)
|
||||
existing.add(p.name)
|
||||
|
||||
return providers
|
||||
|
||||
@staticmethod
|
||||
def _append_provider(providers: List[VisionProvider], builder) -> None:
|
||||
p = builder()
|
||||
if p:
|
||||
providers.append(p)
|
||||
|
||||
@staticmethod
|
||||
def _resolve_user_vision_model() -> Optional[str]:
|
||||
"""Read tools.vision.model (singular ``tool`` kept as runtime fallback)."""
|
||||
tools_conf = conf().get("tools") or conf().get("tool") or {}
|
||||
if not isinstance(tools_conf, dict):
|
||||
return None
|
||||
vision_conf = tools_conf.get("vision", {})
|
||||
if not isinstance(vision_conf, dict):
|
||||
return None
|
||||
m = vision_conf.get("model")
|
||||
if isinstance(m, str) and m.strip():
|
||||
return m.strip()
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _resolve_user_vision_provider() -> Optional[str]:
|
||||
"""Read tools.vision.provider — the UI-persisted vendor id.
|
||||
|
||||
Lets users pin a vendor for custom model names that prefix-inference
|
||||
can't recognize. Returns None when unset/blank.
|
||||
"""
|
||||
tools_conf = conf().get("tools") or conf().get("tool") or {}
|
||||
if not isinstance(tools_conf, dict):
|
||||
return None
|
||||
vision_conf = tools_conf.get("vision", {})
|
||||
if not isinstance(vision_conf, dict):
|
||||
return None
|
||||
p = vision_conf.get("provider")
|
||||
if isinstance(p, str) and p.strip():
|
||||
return p.strip()
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _infer_provider_from_model(model_name: str) -> Optional[str]:
|
||||
"""
|
||||
Infer the provider display name from a model name's prefix.
|
||||
Returns None when no rule matches (or for OpenAI-family names, which
|
||||
are handled separately by the caller).
|
||||
"""
|
||||
if not model_name:
|
||||
return None
|
||||
lower = model_name.lower()
|
||||
# Sort by prefix length desc so e.g. "moonshot-" wins over hypothetical "moo-"
|
||||
for prefix, display_name in sorted(_MODEL_PREFIX_TO_PROVIDER, key=lambda x: -len(x[0])):
|
||||
if lower.startswith(prefix.lower()):
|
||||
return display_name
|
||||
return None
|
||||
|
||||
def _route_by_provider_id(self, provider_id: str, user_model: str) -> Optional[List[VisionProvider]]:
|
||||
"""Route by the UI-persisted provider id.
|
||||
|
||||
Returns:
|
||||
- [provider] : provider id is known and its key is configured.
|
||||
- None : unknown provider id, or the bot can't be created.
|
||||
Caller falls through to model-name-based routing.
|
||||
"""
|
||||
display_name = _PROVIDER_ID_TO_DISPLAY.get(provider_id)
|
||||
if not display_name:
|
||||
return None
|
||||
|
||||
# OpenAI / LinkAI use raw HTTP providers, not the discoverable bot path.
|
||||
if provider_id == "openai":
|
||||
p = self._build_openai_provider(user_model)
|
||||
return [p] if p else None
|
||||
if provider_id == "linkai":
|
||||
p = self._build_linkai_provider(user_model)
|
||||
return [p] if p else None
|
||||
|
||||
# Discoverable bot-backed providers.
|
||||
for config_key, bot_type, _default_model, name in _DISCOVERABLE_MODELS:
|
||||
if name != display_name:
|
||||
continue
|
||||
api_key = conf().get(config_key, "")
|
||||
if not api_key or not api_key.strip():
|
||||
logger.warning(f"[Vision] tools.vision.provider='{provider_id}' "
|
||||
f"but '{config_key}' is not configured. Falling back.")
|
||||
return None
|
||||
try:
|
||||
from models.bot_factory import create_bot
|
||||
bot = create_bot(bot_type)
|
||||
if not hasattr(bot, 'call_vision'):
|
||||
logger.warning(f"[Vision] '{display_name}' bot does not implement call_vision.")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.warning(f"[Vision] Failed to create '{display_name}' bot: {e}")
|
||||
return None
|
||||
return [VisionProvider(
|
||||
name=display_name,
|
||||
api_key="",
|
||||
api_base="",
|
||||
model_override=user_model,
|
||||
use_bot=True,
|
||||
fallback_bot=bot,
|
||||
)]
|
||||
return None
|
||||
|
||||
def _route_by_model_name(self, user_model: str) -> Optional[List[VisionProvider]]:
|
||||
"""
|
||||
Try to build a provider list using the user-specified model name.
|
||||
Returns:
|
||||
- [provider] : matched and the provider's key is configured
|
||||
- [] : matched but key missing → tell caller to surface this
|
||||
as a hard error rather than silently falling back
|
||||
- None : no rule matches → caller should fall through to auto
|
||||
"""
|
||||
lower = user_model.lower()
|
||||
|
||||
# OpenAI / LinkAI family
|
||||
if lower.startswith(_OPENAI_MODEL_PREFIXES):
|
||||
providers: List[VisionProvider] = []
|
||||
# Prefer LinkAI when explicitly enabled, else OpenAI first
|
||||
use_linkai = conf().get("use_linkai", False) and conf().get("linkai_api_key")
|
||||
if use_linkai:
|
||||
self._append_provider(providers, lambda: self._build_linkai_provider(user_model))
|
||||
self._append_provider(providers, lambda: self._build_openai_provider(user_model))
|
||||
else:
|
||||
self._append_provider(providers, lambda: self._build_openai_provider(user_model))
|
||||
self._append_provider(providers, lambda: self._build_linkai_provider(user_model))
|
||||
if providers:
|
||||
return providers
|
||||
logger.warning(f"[Vision] tools.vision.model='{user_model}' looks like an OpenAI "
|
||||
f"model but neither OPENAI_API_KEY nor LINKAI_API_KEY is configured.")
|
||||
return None # fall through to auto
|
||||
|
||||
# Discoverable native providers (Doubao, Moonshot, etc.)
|
||||
target_display = self._infer_provider_from_model(user_model)
|
||||
if not target_display:
|
||||
return None # unknown prefix → auto
|
||||
|
||||
for config_key, bot_type, _default_model, display_name in _DISCOVERABLE_MODELS:
|
||||
if display_name != target_display:
|
||||
continue
|
||||
api_key = conf().get(config_key, "")
|
||||
if not api_key or not api_key.strip():
|
||||
logger.warning(f"[Vision] tools.vision.model='{user_model}' routes to "
|
||||
f"'{display_name}' but '{config_key}' is not configured. "
|
||||
f"Falling back to auto-discovery.")
|
||||
return None # fall through to auto
|
||||
try:
|
||||
from models.bot_factory import create_bot
|
||||
bot = create_bot(bot_type)
|
||||
if not hasattr(bot, 'call_vision'):
|
||||
logger.warning(f"[Vision] '{display_name}' bot does not implement call_vision.")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.warning(f"[Vision] Failed to create '{display_name}' bot: {e}")
|
||||
return None
|
||||
|
||||
return [VisionProvider(
|
||||
name=display_name,
|
||||
api_key="",
|
||||
api_base="",
|
||||
model_override=user_model,
|
||||
use_bot=True,
|
||||
fallback_bot=bot,
|
||||
)]
|
||||
|
||||
return None
|
||||
|
||||
def _append_other_model_providers(self, providers: List[VisionProvider],
|
||||
preferred_model: Optional[str] = None) -> None:
|
||||
"""
|
||||
Auto-discover other models whose API key is configured.
|
||||
Skip the main model's own bot_type (already covered by MainModel
|
||||
provider), unless the main model itself does not support vision —
|
||||
in that case we still want the vendor's dedicated vision model
|
||||
as a fallback. Also skip bot_types that already appear in the
|
||||
provider list.
|
||||
|
||||
If preferred_model matches a provider's family, use it instead
|
||||
of that provider's hard-coded default model.
|
||||
"""
|
||||
main_bot_type = None
|
||||
main_bot_supports_vision = False
|
||||
if self.model and hasattr(self.model, '_resolve_bot_type'):
|
||||
main_bot_type = self.model._resolve_bot_type(conf().get("model", ""))
|
||||
main_bot = getattr(self.model, "bot", None)
|
||||
main_bot_supports_vision = self._main_bot_supports_vision(main_bot)
|
||||
|
||||
existing_names = {p.name for p in providers}
|
||||
preferred_provider = self._infer_provider_from_model(preferred_model) if preferred_model else None
|
||||
|
||||
for config_key, bot_type, default_model, display_name in _DISCOVERABLE_MODELS:
|
||||
if display_name in existing_names:
|
||||
continue
|
||||
# Same bot_type as the main model is normally handled by the
|
||||
# MainModel provider; only skip it here if the main model
|
||||
# actually supports vision. Otherwise fall through and add
|
||||
# the vendor's dedicated vision model as a fallback.
|
||||
if bot_type == main_bot_type and main_bot_supports_vision:
|
||||
continue
|
||||
api_key = conf().get(config_key, "")
|
||||
if not api_key or not api_key.strip():
|
||||
continue
|
||||
|
||||
try:
|
||||
from models.bot_factory import create_bot
|
||||
bot = create_bot(bot_type)
|
||||
if not hasattr(bot, 'call_vision'):
|
||||
continue
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
model_for_provider = (preferred_model
|
||||
if preferred_provider == display_name and preferred_model
|
||||
else default_model)
|
||||
|
||||
provider = VisionProvider(
|
||||
name=display_name,
|
||||
api_key="",
|
||||
api_base="",
|
||||
model_override=model_for_provider,
|
||||
use_bot=True,
|
||||
fallback_bot=bot,
|
||||
)
|
||||
|
||||
# Same vendor as the main bot is the most natural fallback when
|
||||
# the main model itself does not support vision — promote it to
|
||||
# the front of the list instead of relying on declaration order.
|
||||
if bot_type == main_bot_type:
|
||||
providers.insert(0, provider)
|
||||
else:
|
||||
providers.append(provider)
|
||||
|
||||
def _main_bot_supports_vision(self, bot) -> bool:
|
||||
"""
|
||||
Whether the main bot is known to natively support vision.
|
||||
|
||||
Having a `call_vision` method is necessary but not sufficient —
|
||||
some bots implement the method against an endpoint that does not
|
||||
actually serve vision models, which causes silent failures when a
|
||||
vendor-foreign model name is forwarded.
|
||||
|
||||
Resolution order:
|
||||
1. If the bot explicitly declares `supports_vision`, trust it.
|
||||
This lets bots opt in or out based on their own runtime
|
||||
configuration (e.g. the currently selected model).
|
||||
2. Otherwise, fall back to a model-name prefix heuristic: trust
|
||||
call_vision when the main model looks like an OpenAI family
|
||||
model or matches a known multimodal vendor prefix.
|
||||
"""
|
||||
if bot is None:
|
||||
return False
|
||||
if hasattr(bot, "supports_vision"):
|
||||
return bool(getattr(bot, "supports_vision"))
|
||||
main_model = (conf().get("model") or "").lower()
|
||||
if not main_model:
|
||||
return False
|
||||
if main_model.startswith(_OPENAI_MODEL_PREFIXES):
|
||||
return True
|
||||
return self._infer_provider_from_model(main_model) is not None
|
||||
|
||||
def _build_main_model_provider(self) -> Optional[VisionProvider]:
|
||||
"""
|
||||
Use the vendor's own model for vision via bot.call_vision.
|
||||
Gated by _main_bot_supports_vision so non-vision bots (DeepSeek, etc.)
|
||||
do not get routed vendor-foreign model names.
|
||||
"""
|
||||
if not (self.model and hasattr(self.model, 'bot')):
|
||||
return None
|
||||
try:
|
||||
bot = self.model.bot
|
||||
except Exception:
|
||||
return None
|
||||
if not hasattr(bot, 'call_vision'):
|
||||
return None
|
||||
if not self._main_bot_supports_vision(bot):
|
||||
return None
|
||||
|
||||
# Use the configured main model name; do NOT inject tools.vision.model
|
||||
# here, because by the time we reach this branch the tools.vision.model
|
||||
# routing has already been attempted (and either matched the main bot
|
||||
# or failed to find a provider).
|
||||
main_model_name = conf().get("model") or None
|
||||
|
||||
return VisionProvider(
|
||||
name=_MAIN_MODEL_PROVIDER_NAME,
|
||||
api_key="",
|
||||
api_base="",
|
||||
model_override=main_model_name,
|
||||
use_bot=True,
|
||||
)
|
||||
|
||||
def _build_openai_provider(self, preferred_model: Optional[str] = None) -> Optional[VisionProvider]:
|
||||
api_key = conf().get("open_ai_api_key") or os.environ.get("OPENAI_API_KEY")
|
||||
if not api_key:
|
||||
return None
|
||||
api_base = (conf().get("open_ai_api_base") or os.environ.get("OPENAI_API_BASE", "")).rstrip("/") \
|
||||
or "https://api.openai.com/v1"
|
||||
# Only honor preferred_model when it looks like an OpenAI-family name;
|
||||
# otherwise the OpenAI endpoint would 400 on a vendor-specific name.
|
||||
model_override = preferred_model if (
|
||||
preferred_model and preferred_model.lower().startswith(_OPENAI_MODEL_PREFIXES)
|
||||
) else None
|
||||
return VisionProvider(
|
||||
name="OpenAI",
|
||||
api_key=api_key,
|
||||
api_base=self._ensure_v1(api_base),
|
||||
model_override=model_override,
|
||||
)
|
||||
|
||||
def _build_linkai_provider(self, preferred_model: Optional[str] = None) -> Optional[VisionProvider]:
|
||||
api_key = conf().get("linkai_api_key") or os.environ.get("LINKAI_API_KEY")
|
||||
if not api_key:
|
||||
return None
|
||||
api_base = (conf().get("linkai_api_base") or os.environ.get("LINKAI_API_BASE", "")).rstrip("/") \
|
||||
or "https://api.link-ai.tech"
|
||||
from common.utils import get_cloud_headers
|
||||
extra = get_cloud_headers(api_key)
|
||||
extra.pop("Authorization", None)
|
||||
extra.pop("Content-Type", None)
|
||||
# LinkAI is a multi-vendor proxy and accepts most model names, so we
|
||||
# honor any user-configured model name here.
|
||||
return VisionProvider(
|
||||
name="LinkAI",
|
||||
api_key=api_key,
|
||||
api_base=self._ensure_v1(api_base),
|
||||
extra_headers=extra,
|
||||
model_override=preferred_model,
|
||||
)
|
||||
|
||||
def _call_via_bot(self, model: str, question: str, image_content: dict,
|
||||
provider: Optional[VisionProvider] = None) -> ToolResult:
|
||||
"""
|
||||
Call a model's call_vision with vendor-native API format.
|
||||
Uses the provider's _fallback_bot if set, otherwise the main model bot.
|
||||
Raises VisionAPIError on failure so fallback can proceed.
|
||||
"""
|
||||
try:
|
||||
bot = (provider and provider.fallback_bot) or self.model.bot
|
||||
except Exception as e:
|
||||
raise VisionAPIError(f"Cannot access bot: {e}")
|
||||
|
||||
# Extract the raw image URL from the OpenAI-format image_content block
|
||||
image_url = image_content.get("image_url", {}).get("url", "")
|
||||
if not image_url:
|
||||
raise VisionAPIError("No image URL in content block")
|
||||
|
||||
try:
|
||||
response = bot.call_vision(
|
||||
image_url=image_url,
|
||||
question=question,
|
||||
model=model,
|
||||
max_tokens=MAX_TOKENS,
|
||||
)
|
||||
except Exception as e:
|
||||
raise VisionAPIError(f"call_vision failed: {e}")
|
||||
|
||||
if response is NotImplemented:
|
||||
raise VisionAPIError("Bot does not support vision")
|
||||
|
||||
if isinstance(response, dict) and response.get("error"):
|
||||
raise VisionAPIError(f"API error - {response.get('message', 'Unknown')}")
|
||||
|
||||
content = response.get("content", "") if isinstance(response, dict) else ""
|
||||
if not content:
|
||||
raise VisionAPIError("Empty response from main model")
|
||||
|
||||
usage_info = response.get("usage", {}) if isinstance(response, dict) else {}
|
||||
|
||||
# Use the actual model name from the bot response if available
|
||||
actual_model = response.get("model", model) if isinstance(response, dict) else model
|
||||
provider_name = provider.name if provider else _MAIN_MODEL_PROVIDER_NAME
|
||||
return ToolResult.success({
|
||||
"model": actual_model,
|
||||
"provider": provider_name,
|
||||
"content": content,
|
||||
"usage": usage_info,
|
||||
})
|
||||
|
||||
@staticmethod
|
||||
def _ensure_v1(api_base: str) -> str:
|
||||
"""Append /v1 if the base URL doesn't already end with a versioned path."""
|
||||
if not api_base:
|
||||
return api_base
|
||||
# Already has /v1 or similar version suffix
|
||||
if api_base.rstrip("/").split("/")[-1].startswith("v"):
|
||||
return api_base
|
||||
return api_base.rstrip("/") + "/v1"
|
||||
|
||||
def _build_image_content(self, image: str) -> dict:
|
||||
"""
|
||||
Build the image_url content block.
|
||||
Both remote URLs and local files are converted to base64 data URLs
|
||||
so every bot backend can consume them without extra downloads.
|
||||
"""
|
||||
if image.startswith(("http://", "https://")):
|
||||
return self._download_to_data_url(image)
|
||||
|
||||
if not os.path.isfile(image):
|
||||
raise FileNotFoundError(f"Image file not found: {image}")
|
||||
|
||||
ext = image.rsplit(".", 1)[-1].lower() if "." in image else ""
|
||||
mime_type = SUPPORTED_EXTENSIONS.get(ext)
|
||||
if not mime_type:
|
||||
raise ValueError(
|
||||
f"Unsupported image format '.{ext}'. "
|
||||
f"Supported: {', '.join(SUPPORTED_EXTENSIONS.keys())}"
|
||||
)
|
||||
|
||||
file_path = self._maybe_compress(image)
|
||||
try:
|
||||
with open(file_path, "rb") as f:
|
||||
b64 = base64.b64encode(f.read()).decode("ascii")
|
||||
finally:
|
||||
if file_path != image and os.path.exists(file_path):
|
||||
os.remove(file_path)
|
||||
|
||||
data_url = f"data:{mime_type};base64,{b64}"
|
||||
return {"type": "image_url", "image_url": {"url": data_url}}
|
||||
|
||||
@staticmethod
|
||||
def _download_to_data_url(url: str) -> dict:
|
||||
"""Download a remote image and return it as a base64 data URL."""
|
||||
resp = requests.get(url, timeout=30)
|
||||
if resp.status_code != 200:
|
||||
raise VisionAPIError(f"Failed to download image: HTTP {resp.status_code}")
|
||||
content_type = resp.headers.get("Content-Type", "image/jpeg").split(";")[0].strip()
|
||||
if not content_type.startswith("image/"):
|
||||
content_type = "image/jpeg"
|
||||
b64 = base64.b64encode(resp.content).decode("ascii")
|
||||
data_url = f"data:{content_type};base64,{b64}"
|
||||
return {"type": "image_url", "image_url": {"url": data_url}}
|
||||
|
||||
@staticmethod
|
||||
def _maybe_compress(path: str) -> str:
|
||||
"""Compress image to under COMPRESS_THRESHOLD with max long-edge 1536px."""
|
||||
file_size = os.path.getsize(path)
|
||||
if file_size <= COMPRESS_THRESHOLD:
|
||||
return path
|
||||
|
||||
tmp = tempfile.NamedTemporaryFile(suffix=".jpg", delete=False)
|
||||
tmp.close()
|
||||
|
||||
def _try_sips(max_dim: str, quality: str) -> bool:
|
||||
try:
|
||||
subprocess.run(
|
||||
["sips", "-Z", max_dim, "-s", "formatOptions", quality,
|
||||
path, "--out", tmp.name],
|
||||
capture_output=True, check=True,
|
||||
)
|
||||
return True
|
||||
except (FileNotFoundError, subprocess.CalledProcessError):
|
||||
return False
|
||||
|
||||
def _try_convert(max_dim: str, quality: str) -> bool:
|
||||
try:
|
||||
subprocess.run(
|
||||
["convert", path, "-resize", f"{max_dim}x{max_dim}>",
|
||||
"-quality", quality, tmp.name],
|
||||
capture_output=True, check=True,
|
||||
)
|
||||
return True
|
||||
except (FileNotFoundError, subprocess.CalledProcessError):
|
||||
return False
|
||||
|
||||
attempts = [
|
||||
("1536", "85"),
|
||||
("1536", "70"),
|
||||
("1536", "50"),
|
||||
]
|
||||
|
||||
for max_dim, quality in attempts:
|
||||
ok = _try_sips(max_dim, quality) or _try_convert(max_dim, quality)
|
||||
if not ok:
|
||||
continue
|
||||
new_size = os.path.getsize(tmp.name)
|
||||
logger.debug(f"[Vision] Compressed image "
|
||||
f"({file_size // 1024}KB -> {new_size // 1024}KB, "
|
||||
f"max_dim={max_dim}, q={quality})")
|
||||
if new_size <= COMPRESS_THRESHOLD:
|
||||
return tmp.name
|
||||
|
||||
if os.path.exists(tmp.name) and os.path.getsize(tmp.name) > 0:
|
||||
return tmp.name
|
||||
|
||||
os.remove(tmp.name)
|
||||
return path
|
||||
|
||||
def _call_api(self, provider: VisionProvider, model: str,
|
||||
question: str, image_content: dict) -> ToolResult:
|
||||
"""
|
||||
Call a single provider's Vision API.
|
||||
Raises VisionAPIError on recoverable failures so the caller can try
|
||||
the next provider.
|
||||
"""
|
||||
payload = {
|
||||
"model": model,
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": question},
|
||||
image_content,
|
||||
],
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Bearer {provider.api_key}",
|
||||
"Content-Type": "application/json",
|
||||
**provider.extra_headers,
|
||||
}
|
||||
|
||||
resp = requests.post(
|
||||
f"{provider.api_base}/chat/completions",
|
||||
headers=headers,
|
||||
json=payload,
|
||||
timeout=DEFAULT_TIMEOUT,
|
||||
)
|
||||
|
||||
if resp.status_code != 200:
|
||||
raise VisionAPIError(f"HTTP {resp.status_code}: {resp.text[:200]}")
|
||||
|
||||
data = resp.json()
|
||||
|
||||
if "error" in data:
|
||||
msg = data["error"].get("message", "Unknown API error")
|
||||
raise VisionAPIError(f"API error - {msg}")
|
||||
|
||||
content = ""
|
||||
choices = data.get("choices", [])
|
||||
if choices:
|
||||
content = choices[0].get("message", {}).get("content", "")
|
||||
|
||||
usage = data.get("usage", {})
|
||||
result = {
|
||||
"model": model,
|
||||
"provider": provider.name,
|
||||
"content": content,
|
||||
"usage": {
|
||||
"prompt_tokens": usage.get("prompt_tokens", 0),
|
||||
"completion_tokens": usage.get("completion_tokens", 0),
|
||||
"total_tokens": usage.get("total_tokens", 0),
|
||||
},
|
||||
}
|
||||
return ToolResult.success(result)
|
||||
0
agent/tools/web_fetch/__init__.py
Normal file
444
agent/tools/web_fetch/web_fetch.py
Normal file
@@ -0,0 +1,444 @@
|
||||
"""
|
||||
Web Fetch tool - Fetch and extract readable content from web pages and remote files.
|
||||
|
||||
Supports:
|
||||
- HTML web pages: extracts readable text content
|
||||
- Document files (PDF, Word, TXT, Markdown, etc.): downloads to workspace/tmp and parses content
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
import uuid
|
||||
from typing import Dict, Any, Optional, Set
|
||||
from urllib.parse import urlparse, unquote
|
||||
|
||||
import requests
|
||||
|
||||
from agent.tools.base_tool import BaseTool, ToolResult
|
||||
from agent.tools.utils.truncate import truncate_head, format_size
|
||||
from common.log import logger
|
||||
|
||||
|
||||
DEFAULT_TIMEOUT = 30
|
||||
MAX_FILE_SIZE = 50 * 1024 * 1024 # 50MB
|
||||
|
||||
DEFAULT_HEADERS = {
|
||||
"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36",
|
||||
"Accept": "*/*",
|
||||
}
|
||||
|
||||
# Supported document file extensions
|
||||
PDF_SUFFIXES: Set[str] = {".pdf"}
|
||||
WORD_SUFFIXES: Set[str] = {".docx"}
|
||||
TEXT_SUFFIXES: Set[str] = {".txt", ".md", ".markdown", ".rst", ".csv", ".tsv", ".log"}
|
||||
SPREADSHEET_SUFFIXES: Set[str] = {".xls", ".xlsx"}
|
||||
PPT_SUFFIXES: Set[str] = {".ppt", ".pptx"}
|
||||
|
||||
ALL_DOC_SUFFIXES = PDF_SUFFIXES | WORD_SUFFIXES | TEXT_SUFFIXES | SPREADSHEET_SUFFIXES | PPT_SUFFIXES
|
||||
|
||||
_CHARSET_RE = re.compile(r'charset\s*=\s*["\']?\s*([\w\-]+)', re.IGNORECASE)
|
||||
_META_CHARSET_RE = re.compile(rb'<meta[^>]+charset\s*=\s*["\']?\s*([\w\-]+)', re.IGNORECASE)
|
||||
_META_HTTP_EQUIV_RE = re.compile(
|
||||
rb'<meta[^>]+http-equiv\s*=\s*["\']?Content-Type["\']?[^>]+content\s*=\s*["\'][^"\']*charset=([\w\-]+)',
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
|
||||
def _extract_charset_from_content_type(content_type: str) -> Optional[str]:
|
||||
"""Extract charset from Content-Type header value."""
|
||||
m = _CHARSET_RE.search(content_type)
|
||||
return m.group(1) if m else None
|
||||
|
||||
|
||||
def _extract_charset_from_html_meta(raw_bytes: bytes) -> Optional[str]:
|
||||
"""Extract charset from HTML <meta> tags in the first few KB of raw bytes."""
|
||||
m = _META_CHARSET_RE.search(raw_bytes)
|
||||
if m:
|
||||
return m.group(1).decode("ascii", errors="ignore")
|
||||
m = _META_HTTP_EQUIV_RE.search(raw_bytes)
|
||||
if m:
|
||||
return m.group(1).decode("ascii", errors="ignore")
|
||||
return None
|
||||
|
||||
|
||||
def _get_url_suffix(url: str) -> str:
|
||||
"""Extract file extension from URL path, ignoring query params."""
|
||||
path = urlparse(url).path
|
||||
return os.path.splitext(path)[-1].lower()
|
||||
|
||||
|
||||
def _is_document_url(url: str) -> bool:
|
||||
"""Check if URL points to a downloadable document file."""
|
||||
suffix = _get_url_suffix(url)
|
||||
return suffix in ALL_DOC_SUFFIXES
|
||||
|
||||
|
||||
class WebFetch(BaseTool):
|
||||
"""Tool for fetching web pages and remote document files"""
|
||||
|
||||
name: str = "web_fetch"
|
||||
description: str = (
|
||||
"Fetch content from a http/https URL. For web pages, extracts readable text. "
|
||||
"For document files (PDF, Word, TXT, Markdown, Excel, PPT), downloads and parses the file content. "
|
||||
"Supported file types: .pdf, .docx, .txt, .md, .csv, .xls, .xlsx, .ppt, .pptx"
|
||||
)
|
||||
|
||||
params: dict = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"url": {
|
||||
"type": "string",
|
||||
"description": "The HTTP/HTTPS URL to fetch (web page or document file link)"
|
||||
}
|
||||
},
|
||||
"required": ["url"]
|
||||
}
|
||||
|
||||
def __init__(self, config: dict = None):
|
||||
self.config = config or {}
|
||||
self.cwd = self.config.get("cwd", os.getcwd())
|
||||
|
||||
def execute(self, args: Dict[str, Any]) -> ToolResult:
|
||||
url = args.get("url", "").strip()
|
||||
if not url:
|
||||
return ToolResult.fail("Error: 'url' parameter is required")
|
||||
|
||||
parsed = urlparse(url)
|
||||
if parsed.scheme not in ("http", "https"):
|
||||
return ToolResult.fail("Error: Invalid URL (must start with http:// or https://)")
|
||||
|
||||
if _is_document_url(url):
|
||||
return self._fetch_document(url)
|
||||
|
||||
return self._fetch_webpage(url)
|
||||
|
||||
# ---- Web page fetching ----
|
||||
|
||||
def _fetch_webpage(self, url: str) -> ToolResult:
|
||||
"""Fetch and extract readable text from an HTML web page."""
|
||||
parsed = urlparse(url)
|
||||
try:
|
||||
response = requests.get(
|
||||
url,
|
||||
headers=DEFAULT_HEADERS,
|
||||
timeout=DEFAULT_TIMEOUT,
|
||||
allow_redirects=True,
|
||||
)
|
||||
response.raise_for_status()
|
||||
except requests.Timeout:
|
||||
return ToolResult.fail(f"Error: Request timed out after {DEFAULT_TIMEOUT}s")
|
||||
except requests.ConnectionError:
|
||||
return ToolResult.fail(f"Error: Failed to connect to {parsed.netloc}")
|
||||
except requests.HTTPError as e:
|
||||
return ToolResult.fail(f"Error: HTTP {e.response.status_code} for URL: {url}")
|
||||
except Exception as e:
|
||||
return ToolResult.fail(f"Error: Failed to fetch URL: {e}")
|
||||
|
||||
content_type = response.headers.get("Content-Type", "")
|
||||
if self._is_binary_content_type(content_type) and not _is_document_url(url):
|
||||
return self._handle_download_by_content_type(url, response, content_type)
|
||||
|
||||
response.encoding = self._detect_encoding(response)
|
||||
html = response.text
|
||||
title = self._extract_title(html)
|
||||
text = self._extract_text(html)
|
||||
|
||||
return ToolResult.success(f"Title: {title}\n\nContent:\n{text}")
|
||||
|
||||
# ---- Document fetching ----
|
||||
|
||||
def _fetch_document(self, url: str) -> ToolResult:
|
||||
"""Download a document file and extract its text content."""
|
||||
suffix = _get_url_suffix(url)
|
||||
parsed = urlparse(url)
|
||||
filename = self._extract_filename(url)
|
||||
tmp_dir = self._ensure_tmp_dir()
|
||||
|
||||
local_path = os.path.join(tmp_dir, filename)
|
||||
logger.info(f"[WebFetch] Downloading document: {url} -> {local_path}")
|
||||
|
||||
try:
|
||||
response = requests.get(
|
||||
url,
|
||||
headers=DEFAULT_HEADERS,
|
||||
timeout=DEFAULT_TIMEOUT,
|
||||
stream=True,
|
||||
allow_redirects=True,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
content_length = int(response.headers.get("Content-Length", 0))
|
||||
if content_length > MAX_FILE_SIZE:
|
||||
return ToolResult.fail(
|
||||
f"Error: File too large ({format_size(content_length)} > {format_size(MAX_FILE_SIZE)})"
|
||||
)
|
||||
|
||||
downloaded = 0
|
||||
with open(local_path, "wb") as f:
|
||||
for chunk in response.iter_content(chunk_size=8192):
|
||||
downloaded += len(chunk)
|
||||
if downloaded > MAX_FILE_SIZE:
|
||||
f.close()
|
||||
os.remove(local_path)
|
||||
return ToolResult.fail(
|
||||
f"Error: File too large (>{format_size(MAX_FILE_SIZE)}), download aborted"
|
||||
)
|
||||
f.write(chunk)
|
||||
|
||||
except requests.Timeout:
|
||||
return ToolResult.fail(f"Error: Download timed out after {DEFAULT_TIMEOUT}s")
|
||||
except requests.ConnectionError:
|
||||
return ToolResult.fail(f"Error: Failed to connect to {parsed.netloc}")
|
||||
except requests.HTTPError as e:
|
||||
return ToolResult.fail(f"Error: HTTP {e.response.status_code} for URL: {url}")
|
||||
except Exception as e:
|
||||
self._cleanup_file(local_path)
|
||||
return ToolResult.fail(f"Error: Failed to download file: {e}")
|
||||
|
||||
try:
|
||||
text = self._parse_document(local_path, suffix)
|
||||
except Exception as e:
|
||||
self._cleanup_file(local_path)
|
||||
return ToolResult.fail(f"Error: Failed to parse document: {e}")
|
||||
|
||||
if not text or not text.strip():
|
||||
file_size = os.path.getsize(local_path)
|
||||
return ToolResult.success(
|
||||
f"File downloaded to: {local_path} ({format_size(file_size)})\n"
|
||||
f"No text content could be extracted. The file may contain only images or be encrypted."
|
||||
)
|
||||
|
||||
truncation = truncate_head(text)
|
||||
result_text = truncation.content
|
||||
|
||||
file_size = os.path.getsize(local_path)
|
||||
header = f"[Document: {filename} | Size: {format_size(file_size)} | Saved to: {local_path}]\n\n"
|
||||
|
||||
if truncation.truncated:
|
||||
header += f"[Content truncated: showing {truncation.output_lines} of {truncation.total_lines} lines]\n\n"
|
||||
|
||||
return ToolResult.success(header + result_text)
|
||||
|
||||
def _parse_document(self, file_path: str, suffix: str) -> str:
|
||||
"""Parse document file and return extracted text."""
|
||||
if suffix in PDF_SUFFIXES:
|
||||
return self._parse_pdf(file_path)
|
||||
elif suffix in WORD_SUFFIXES:
|
||||
return self._parse_word(file_path)
|
||||
elif suffix in TEXT_SUFFIXES:
|
||||
return self._parse_text(file_path)
|
||||
elif suffix in SPREADSHEET_SUFFIXES:
|
||||
return self._parse_spreadsheet(file_path)
|
||||
elif suffix in PPT_SUFFIXES:
|
||||
return self._parse_ppt(file_path)
|
||||
else:
|
||||
return self._parse_text(file_path)
|
||||
|
||||
def _parse_pdf(self, file_path: str) -> str:
|
||||
"""Extract text from PDF using pypdf."""
|
||||
try:
|
||||
from pypdf import PdfReader
|
||||
except ImportError:
|
||||
raise ImportError("pypdf library is required for PDF parsing. Install with: pip install pypdf")
|
||||
|
||||
reader = PdfReader(file_path)
|
||||
text_parts = []
|
||||
for page_num, page in enumerate(reader.pages, 1):
|
||||
page_text = page.extract_text()
|
||||
if page_text and page_text.strip():
|
||||
text_parts.append(f"--- Page {page_num}/{len(reader.pages)} ---\n{page_text}")
|
||||
|
||||
return "\n\n".join(text_parts)
|
||||
|
||||
def _parse_word(self, file_path: str) -> str:
|
||||
"""Extract text from Word documents (.docx)."""
|
||||
try:
|
||||
from docx import Document
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"python-docx library is required for .docx parsing. Install with: pip install python-docx"
|
||||
)
|
||||
doc = Document(file_path)
|
||||
paragraphs = [p.text for p in doc.paragraphs if p.text.strip()]
|
||||
return "\n\n".join(paragraphs)
|
||||
|
||||
def _parse_text(self, file_path: str) -> str:
|
||||
"""Read plain text files (txt, md, csv, etc.)."""
|
||||
encodings = ["utf-8", "utf-8-sig", "gbk", "gb2312", "latin-1"]
|
||||
for enc in encodings:
|
||||
try:
|
||||
with open(file_path, "r", encoding=enc) as f:
|
||||
return f.read()
|
||||
except (UnicodeDecodeError, UnicodeError):
|
||||
continue
|
||||
raise ValueError(f"Unable to decode file with any supported encoding: {encodings}")
|
||||
|
||||
def _parse_spreadsheet(self, file_path: str) -> str:
|
||||
"""Extract text from Excel files (.xls/.xlsx)."""
|
||||
try:
|
||||
import openpyxl
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"openpyxl library is required for .xlsx parsing. Install with: pip install openpyxl"
|
||||
)
|
||||
|
||||
wb = openpyxl.load_workbook(file_path, read_only=True, data_only=True)
|
||||
result_parts = []
|
||||
|
||||
for sheet_name in wb.sheetnames:
|
||||
ws = wb[sheet_name]
|
||||
rows = []
|
||||
for row in ws.iter_rows(values_only=True):
|
||||
cells = [str(c) if c is not None else "" for c in row]
|
||||
if any(cells):
|
||||
rows.append(" | ".join(cells))
|
||||
if rows:
|
||||
result_parts.append(f"--- Sheet: {sheet_name} ---\n" + "\n".join(rows))
|
||||
|
||||
wb.close()
|
||||
return "\n\n".join(result_parts)
|
||||
|
||||
def _parse_ppt(self, file_path: str) -> str:
|
||||
"""Extract text from PowerPoint files (.ppt/.pptx)."""
|
||||
try:
|
||||
from pptx import Presentation
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"python-pptx library is required for .pptx parsing. Install with: pip install python-pptx"
|
||||
)
|
||||
|
||||
prs = Presentation(file_path)
|
||||
text_parts = []
|
||||
|
||||
for slide_num, slide in enumerate(prs.slides, 1):
|
||||
slide_texts = []
|
||||
for shape in slide.shapes:
|
||||
if shape.has_text_frame:
|
||||
for paragraph in shape.text_frame.paragraphs:
|
||||
text = paragraph.text.strip()
|
||||
if text:
|
||||
slide_texts.append(text)
|
||||
if slide_texts:
|
||||
text_parts.append(f"--- Slide {slide_num}/{len(prs.slides)} ---\n" + "\n".join(slide_texts))
|
||||
|
||||
return "\n\n".join(text_parts)
|
||||
|
||||
# ---- Encoding detection ----
|
||||
|
||||
@staticmethod
|
||||
def _detect_encoding(response: requests.Response) -> str:
|
||||
"""Detect response encoding with priority: Content-Type header > HTML meta > chardet > utf-8."""
|
||||
# 1. Check Content-Type header for explicit charset
|
||||
content_type = response.headers.get("Content-Type", "")
|
||||
charset = _extract_charset_from_content_type(content_type)
|
||||
if charset:
|
||||
return charset
|
||||
|
||||
# 2. Scan raw bytes for HTML meta charset declaration
|
||||
raw = response.content[:4096]
|
||||
charset = _extract_charset_from_html_meta(raw)
|
||||
if charset:
|
||||
return charset
|
||||
|
||||
# 3. Use apparent_encoding (chardet-based detection) if confident enough
|
||||
apparent = response.apparent_encoding
|
||||
if apparent:
|
||||
apparent_lower = apparent.lower()
|
||||
# Trust CJK / Windows encodings detected by chardet
|
||||
trusted_prefixes = ("utf", "gb", "big5", "euc", "shift_jis", "iso-2022", "windows", "ascii")
|
||||
if any(apparent_lower.startswith(p) for p in trusted_prefixes):
|
||||
return apparent
|
||||
|
||||
# 4. Fallback
|
||||
return "utf-8"
|
||||
|
||||
# ---- Helper methods ----
|
||||
|
||||
def _ensure_tmp_dir(self) -> str:
|
||||
"""Ensure workspace/tmp directory exists and return its path."""
|
||||
tmp_dir = os.path.join(self.cwd, "tmp")
|
||||
os.makedirs(tmp_dir, exist_ok=True)
|
||||
return tmp_dir
|
||||
|
||||
def _extract_filename(self, url: str) -> str:
|
||||
"""Extract a safe filename from URL, with a short UUID prefix to avoid collisions."""
|
||||
path = urlparse(url).path
|
||||
basename = os.path.basename(unquote(path))
|
||||
if not basename or basename == "/":
|
||||
basename = "downloaded_file"
|
||||
# Sanitize: keep only safe chars
|
||||
basename = re.sub(r'[^\w.\-]', '_', basename)
|
||||
short_id = uuid.uuid4().hex[:8]
|
||||
return f"{short_id}_{basename}"
|
||||
|
||||
@staticmethod
|
||||
def _cleanup_file(path: str):
|
||||
"""Remove a file if it exists, ignoring errors."""
|
||||
try:
|
||||
if os.path.exists(path):
|
||||
os.remove(path)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def _is_binary_content_type(content_type: str) -> bool:
|
||||
"""Check if Content-Type indicates a binary/document response."""
|
||||
binary_types = [
|
||||
"application/pdf",
|
||||
"application/vnd.openxmlformats",
|
||||
"application/vnd.ms-excel",
|
||||
"application/vnd.ms-powerpoint",
|
||||
"application/octet-stream",
|
||||
]
|
||||
ct_lower = content_type.lower()
|
||||
return any(bt in ct_lower for bt in binary_types)
|
||||
|
||||
def _handle_download_by_content_type(self, url: str, response: requests.Response, content_type: str) -> ToolResult:
|
||||
"""Handle a URL that returned binary content instead of HTML."""
|
||||
ct_lower = content_type.lower()
|
||||
suffix_map = {
|
||||
"application/pdf": ".pdf",
|
||||
"application/vnd.openxmlformats-officedocument.wordprocessingml": ".docx",
|
||||
"application/vnd.ms-excel": ".xls",
|
||||
"application/vnd.openxmlformats-officedocument.spreadsheetml": ".xlsx",
|
||||
"application/vnd.ms-powerpoint": ".ppt",
|
||||
"application/vnd.openxmlformats-officedocument.presentationml": ".pptx",
|
||||
}
|
||||
detected_suffix = None
|
||||
for ct_prefix, ext in suffix_map.items():
|
||||
if ct_prefix in ct_lower:
|
||||
detected_suffix = ext
|
||||
break
|
||||
|
||||
if detected_suffix and detected_suffix in ALL_DOC_SUFFIXES:
|
||||
# Re-fetch as document
|
||||
return self._fetch_document(url if _get_url_suffix(url) in ALL_DOC_SUFFIXES
|
||||
else self._rewrite_url_with_suffix(url, detected_suffix))
|
||||
return ToolResult.fail(f"Error: URL returned binary content ({content_type}), not a supported document type")
|
||||
|
||||
@staticmethod
|
||||
def _rewrite_url_with_suffix(url: str, suffix: str) -> str:
|
||||
"""Append a suffix to the URL path so _get_url_suffix works correctly."""
|
||||
parsed = urlparse(url)
|
||||
new_path = parsed.path.rstrip("/") + suffix
|
||||
return parsed._replace(path=new_path).geturl()
|
||||
|
||||
# ---- HTML extraction (unchanged) ----
|
||||
|
||||
@staticmethod
|
||||
def _extract_title(html: str) -> str:
|
||||
match = re.search(r"<title[^>]*>(.*?)</title>", html, re.IGNORECASE | re.DOTALL)
|
||||
return match.group(1).strip() if match else "Untitled"
|
||||
|
||||
@staticmethod
|
||||
def _extract_text(html: str) -> str:
|
||||
text = re.sub(r"<script[^>]*>.*?</script>", "", html, flags=re.IGNORECASE | re.DOTALL)
|
||||
text = re.sub(r"<style[^>]*>.*?</style>", "", text, flags=re.IGNORECASE | re.DOTALL)
|
||||
text = re.sub(r"<[^>]+>", "", text)
|
||||
text = text.replace("&", "&").replace("<", "<").replace(">", ">")
|
||||
text = text.replace(""", '"').replace("'", "'").replace(" ", " ")
|
||||
text = re.sub(r"[^\S\n]+", " ", text)
|
||||
text = re.sub(r"\n{3,}", "\n\n", text)
|
||||
lines = [line.strip() for line in text.splitlines()]
|
||||
text = "\n".join(lines)
|
||||
return text.strip()
|
||||
@@ -1,33 +1,95 @@
|
||||
"""
|
||||
Web Search tool - Search the web using Bocha or LinkAI search API.
|
||||
Supports two backends with unified response format:
|
||||
1. Bocha Search (primary, requires BOCHA_API_KEY)
|
||||
2. LinkAI Search (fallback, requires LINKAI_API_KEY)
|
||||
"""Web Search tool. Supports four backends with a unified response format:
|
||||
- bocha (https://open.bochaai.com)
|
||||
- zhipu (https://docs.bigmodel.cn/cn/guide/tools/web-search)
|
||||
- qianfan (https://cloud.baidu.com/doc/qianfan/s/2mh4su4uy)
|
||||
- linkai (https://link-ai.tech, fallback)
|
||||
|
||||
Provider selection
|
||||
- strategy 'auto' (default): pick the first configured provider in the
|
||||
canonical order [bocha, zhipu, qianfan, linkai]. When the caller passes
|
||||
an explicit `provider` it overrides the pick; an invalid/unconfigured
|
||||
one silently falls back to the auto order.
|
||||
- strategy 'fixed': use the configured provider; if its credential is
|
||||
missing at call time, silently fall back to auto order (no card hint).
|
||||
|
||||
Credentials
|
||||
- bocha : tools.web_search.bocha_api_key -> env BOCHA_API_KEY
|
||||
- zhipu : conf.zhipu_ai_api_key -> env ZHIPUAI_API_KEY
|
||||
- qianfan : conf.qianfan_api_key -> env QIANFAN_API_KEY
|
||||
- linkai : conf.linkai_api_key -> env LINKAI_API_KEY
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
from typing import Dict, Any, Optional
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import requests
|
||||
|
||||
from agent.tools.base_tool import BaseTool, ToolResult
|
||||
from common.log import logger
|
||||
from config import conf
|
||||
|
||||
|
||||
# Default timeout for API requests (seconds)
|
||||
DEFAULT_TIMEOUT = 30
|
||||
|
||||
# Canonical fallback order. Empirically ordered by Chinese real-time
|
||||
# quality + relevance: bocha (best overall), qianfan (best for hot news),
|
||||
# zhipu (strong on long-form articles), linkai (cloud aggregator, last
|
||||
# resort).
|
||||
PROVIDER_ORDER = ("bocha", "qianfan", "zhipu", "linkai")
|
||||
|
||||
PROVIDER_LABELS = {
|
||||
"bocha": "Bocha",
|
||||
"zhipu": "Zhipu",
|
||||
"qianfan": "Baidu Qianfan",
|
||||
"linkai": "LinkAI",
|
||||
}
|
||||
|
||||
|
||||
def _tools_web_search_conf() -> dict:
|
||||
"""Return the tools.web_search config block (dict-like)."""
|
||||
tools_cfg = conf().get("tools") or {}
|
||||
if not isinstance(tools_cfg, dict):
|
||||
return {}
|
||||
block = tools_cfg.get("web_search") or {}
|
||||
return block if isinstance(block, dict) else {}
|
||||
|
||||
|
||||
def _get_api_key(provider: str) -> str:
|
||||
"""Resolve API key for a provider, with conf -> env fallback."""
|
||||
if provider == "bocha":
|
||||
key = (_tools_web_search_conf().get("bocha_api_key") or "").strip()
|
||||
return key or os.environ.get("BOCHA_API_KEY", "").strip()
|
||||
if provider == "zhipu":
|
||||
key = (conf().get("zhipu_ai_api_key") or "").strip()
|
||||
return key or os.environ.get("ZHIPUAI_API_KEY", "").strip()
|
||||
if provider == "qianfan":
|
||||
key = (conf().get("qianfan_api_key") or "").strip()
|
||||
return key or os.environ.get("QIANFAN_API_KEY", "").strip()
|
||||
if provider == "linkai":
|
||||
key = (conf().get("linkai_api_key") or "").strip()
|
||||
return key or os.environ.get("LINKAI_API_KEY", "").strip()
|
||||
return ""
|
||||
|
||||
|
||||
def configured_providers() -> List[str]:
|
||||
"""Return configured providers in canonical order."""
|
||||
return [p for p in PROVIDER_ORDER if _get_api_key(p)]
|
||||
|
||||
|
||||
def _configured_strategy() -> str:
|
||||
return (_tools_web_search_conf().get("strategy") or "auto").strip().lower()
|
||||
|
||||
|
||||
def _configured_provider() -> str:
|
||||
return (_tools_web_search_conf().get("provider") or "").strip().lower()
|
||||
|
||||
|
||||
class WebSearch(BaseTool):
|
||||
"""Tool for searching the web using Bocha or LinkAI search API"""
|
||||
"""Tool for searching the web across multiple providers."""
|
||||
|
||||
name: str = "web_search"
|
||||
description: str = (
|
||||
"Search the web for current information, news, research topics, or any real-time data. "
|
||||
"Returns web page titles, URLs, snippets, and optional summaries. "
|
||||
"Use this when the user asks about recent events, needs fact-checking, or wants up-to-date information."
|
||||
)
|
||||
description: str = "Search the web for real-time information. Returns titles, URLs, and snippets."
|
||||
|
||||
params: dict = {
|
||||
"type": "object",
|
||||
@@ -58,265 +120,368 @@ class WebSearch(BaseTool):
|
||||
|
||||
def __init__(self, config: dict = None):
|
||||
self.config = config or {}
|
||||
self._backend = None # Will be resolved on first execute
|
||||
|
||||
@staticmethod
|
||||
def is_available() -> bool:
|
||||
"""Check if web search is available (at least one API key is configured)"""
|
||||
return bool(os.environ.get("BOCHA_API_KEY") or os.environ.get("LINKAI_API_KEY"))
|
||||
"""Tool is offered to the agent when at least one provider has a key."""
|
||||
return bool(configured_providers())
|
||||
|
||||
def _resolve_backend(self) -> Optional[str]:
|
||||
"""
|
||||
Determine which search backend to use.
|
||||
Priority: Bocha > LinkAI
|
||||
@classmethod
|
||||
def get_json_schema(cls) -> dict:
|
||||
"""Augment the static schema with a `provider` field — only when the
|
||||
user has ≥2 providers configured AND strategy is 'auto'. Otherwise
|
||||
the backend picks silently and exposing the field would only waste
|
||||
the agent's tokens."""
|
||||
schema = {
|
||||
"name": cls.name,
|
||||
"description": cls.description,
|
||||
"parameters": json.loads(json.dumps(cls.params)), # deep copy
|
||||
}
|
||||
if _configured_strategy() != "auto":
|
||||
return schema
|
||||
available = configured_providers()
|
||||
if len(available) < 2:
|
||||
return schema
|
||||
|
||||
:return: 'bocha', 'linkai', or None
|
||||
schema["parameters"]["properties"]["provider"] = {
|
||||
"type": "string",
|
||||
"enum": available,
|
||||
"description": "Optional. Specifies the search backend. You may switch between providers when the user wants results from a particular source or from multiple sources.",
|
||||
}
|
||||
return schema
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Provider resolution
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _resolve_provider(self, requested: Optional[str]) -> Optional[str]:
|
||||
"""Pick a provider for this call.
|
||||
|
||||
Priority: caller-supplied (if configured) > fixed strategy (if
|
||||
configured) > first configured in PROVIDER_ORDER. Silent fallback
|
||||
when the desired one has no key.
|
||||
"""
|
||||
if os.environ.get("BOCHA_API_KEY"):
|
||||
return "bocha"
|
||||
if os.environ.get("LINKAI_API_KEY"):
|
||||
return "linkai"
|
||||
return None
|
||||
available = configured_providers()
|
||||
if not available:
|
||||
return None
|
||||
|
||||
if requested:
|
||||
req = requested.strip().lower()
|
||||
if req in available:
|
||||
return req
|
||||
logger.warning(f"[WebSearch] requested provider '{requested}' unavailable, falling back")
|
||||
|
||||
if _configured_strategy() == "fixed":
|
||||
pinned = _configured_provider()
|
||||
if pinned in available:
|
||||
return pinned
|
||||
if pinned:
|
||||
logger.warning(f"[WebSearch] pinned provider '{pinned}' unavailable, falling back to auto")
|
||||
|
||||
return available[0]
|
||||
|
||||
@staticmethod
|
||||
def _resolution_reason(requested: Optional[str], chosen: str) -> str:
|
||||
"""Human-readable explanation for why `chosen` won the resolver."""
|
||||
if requested and requested.strip().lower() == chosen:
|
||||
return "caller-requested"
|
||||
strategy = _configured_strategy()
|
||||
if strategy == "fixed" and _configured_provider() == chosen:
|
||||
return "fixed-strategy"
|
||||
return "auto-fallback"
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Entry point
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def execute(self, args: Dict[str, Any]) -> ToolResult:
|
||||
"""
|
||||
Execute web search
|
||||
|
||||
:param args: Search parameters (query, count, freshness, summary)
|
||||
:return: Search results
|
||||
"""
|
||||
query = args.get("query", "").strip()
|
||||
query = (args.get("query") or "").strip()
|
||||
if not query:
|
||||
return ToolResult.fail("Error: 'query' parameter is required")
|
||||
|
||||
count = args.get("count", 10)
|
||||
freshness = args.get("freshness", "noLimit")
|
||||
summary = args.get("summary", False)
|
||||
|
||||
# Validate count
|
||||
if not isinstance(count, int) or count < 1 or count > 50:
|
||||
count = 10
|
||||
|
||||
# Resolve backend
|
||||
backend = self._resolve_backend()
|
||||
if not backend:
|
||||
requested = args.get("provider")
|
||||
provider = self._resolve_provider(requested)
|
||||
if not provider:
|
||||
return ToolResult.fail(
|
||||
"Error: No search API key configured. "
|
||||
"Please set BOCHA_API_KEY or LINKAI_API_KEY using env_config tool.\n"
|
||||
" - Bocha Search: https://open.bocha.cn\n"
|
||||
" - LinkAI Search: https://link-ai.tech"
|
||||
"Error: No search provider configured. "
|
||||
"Configure one of BOCHA_API_KEY / zhipu_ai_api_key / qianfan_api_key / linkai_api_key."
|
||||
)
|
||||
|
||||
# Always log the routing decision so multi-provider deployments can
|
||||
# tell at a glance which backend served any given query.
|
||||
available = configured_providers()
|
||||
reason = self._resolution_reason(requested, provider)
|
||||
q_preview = query if len(query) <= 60 else (query[:57] + "...")
|
||||
logger.info(
|
||||
f"[WebSearch] provider={provider} reason={reason} "
|
||||
f"available={list(available)} query={q_preview!r} count={count} freshness={freshness}"
|
||||
)
|
||||
|
||||
try:
|
||||
if backend == "bocha":
|
||||
if provider == "bocha":
|
||||
return self._search_bocha(query, count, freshness, summary)
|
||||
else:
|
||||
if provider == "zhipu":
|
||||
return self._search_zhipu(query, count, freshness)
|
||||
if provider == "qianfan":
|
||||
return self._search_qianfan(query, count, freshness)
|
||||
if provider == "linkai":
|
||||
return self._search_linkai(query, count, freshness)
|
||||
return ToolResult.fail(f"Error: Unknown provider '{provider}'")
|
||||
except requests.Timeout:
|
||||
return ToolResult.fail(f"Error: Search request timed out after {DEFAULT_TIMEOUT}s")
|
||||
except requests.ConnectionError:
|
||||
return ToolResult.fail("Error: Failed to connect to search API")
|
||||
except Exception as e:
|
||||
logger.error(f"[WebSearch] Unexpected error: {e}", exc_info=True)
|
||||
logger.error(f"[WebSearch] Unexpected error ({provider}): {e}", exc_info=True)
|
||||
return ToolResult.fail(f"Error: Search failed - {str(e)}")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Bocha
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _search_bocha(self, query: str, count: int, freshness: str, summary: bool) -> ToolResult:
|
||||
"""
|
||||
Search using Bocha API
|
||||
|
||||
:param query: Search query
|
||||
:param count: Number of results
|
||||
:param freshness: Time range filter
|
||||
:param summary: Whether to include summary
|
||||
:return: Formatted search results
|
||||
"""
|
||||
api_key = os.environ.get("BOCHA_API_KEY", "")
|
||||
url = "https://api.bocha.cn/v1/web-search"
|
||||
|
||||
api_key = _get_api_key("bocha")
|
||||
url = "https://api.bochaai.com/v1/web-search"
|
||||
headers = {
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json"
|
||||
"Accept": "application/json",
|
||||
}
|
||||
payload = {"query": query, "count": count, "freshness": freshness, "summary": summary}
|
||||
|
||||
payload = {
|
||||
"query": query,
|
||||
"count": count,
|
||||
"freshness": freshness,
|
||||
"summary": summary
|
||||
}
|
||||
logger.debug(f"[WebSearch] bocha: query='{query}', count={count}")
|
||||
resp = requests.post(url, headers=headers, json=payload, timeout=DEFAULT_TIMEOUT)
|
||||
|
||||
logger.debug(f"[WebSearch] Bocha search: query='{query}', count={count}")
|
||||
if resp.status_code == 401:
|
||||
return ToolResult.fail("Error: Invalid bocha API key.")
|
||||
if resp.status_code == 403:
|
||||
return ToolResult.fail("Error: bocha API — insufficient balance. Top up at https://open.bochaai.com")
|
||||
if resp.status_code == 429:
|
||||
return ToolResult.fail("Error: bocha API rate limit reached.")
|
||||
if resp.status_code != 200:
|
||||
return ToolResult.fail(f"Error: bocha API returned HTTP {resp.status_code}")
|
||||
|
||||
response = requests.post(url, headers=headers, json=payload, timeout=DEFAULT_TIMEOUT)
|
||||
|
||||
if response.status_code == 401:
|
||||
return ToolResult.fail("Error: Invalid BOCHA_API_KEY. Please check your API key.")
|
||||
if response.status_code == 403:
|
||||
return ToolResult.fail("Error: Bocha API - insufficient balance. Please top up at https://open.bocha.cn")
|
||||
if response.status_code == 429:
|
||||
return ToolResult.fail("Error: Bocha API rate limit reached. Please try again later.")
|
||||
if response.status_code != 200:
|
||||
return ToolResult.fail(f"Error: Bocha API returned HTTP {response.status_code}")
|
||||
|
||||
data = response.json()
|
||||
|
||||
# Check API-level error code
|
||||
data = resp.json()
|
||||
api_code = data.get("code")
|
||||
if api_code is not None and api_code != 200:
|
||||
msg = data.get("msg") or "Unknown error"
|
||||
return ToolResult.fail(f"Error: Bocha API error (code={api_code}): {msg}")
|
||||
|
||||
# Extract and format results
|
||||
return self._format_bocha_results(data, query)
|
||||
|
||||
def _format_bocha_results(self, data: dict, query: str) -> ToolResult:
|
||||
"""
|
||||
Format Bocha API response into unified result structure
|
||||
|
||||
:param data: Raw API response
|
||||
:param query: Original query
|
||||
:return: Formatted ToolResult
|
||||
"""
|
||||
search_data = data.get("data", {})
|
||||
web_pages = search_data.get("webPages", {})
|
||||
pages = web_pages.get("value", [])
|
||||
|
||||
if not pages:
|
||||
return ToolResult.success({
|
||||
"query": query,
|
||||
"backend": "bocha",
|
||||
"total": 0,
|
||||
"results": [],
|
||||
"message": "No results found"
|
||||
})
|
||||
return ToolResult.fail(f"Error: bocha API error (code={api_code}): {msg}")
|
||||
|
||||
pages = (data.get("data") or {}).get("webPages", {}).get("value", []) or []
|
||||
results = []
|
||||
for page in pages:
|
||||
result = {
|
||||
"title": page.get("name", ""),
|
||||
"url": page.get("url", ""),
|
||||
"snippet": page.get("snippet", ""),
|
||||
"siteName": page.get("siteName", ""),
|
||||
"datePublished": page.get("datePublished") or page.get("dateLastCrawled", ""),
|
||||
for p in pages:
|
||||
item = {
|
||||
"title": p.get("name", ""),
|
||||
"url": p.get("url", ""),
|
||||
"snippet": p.get("snippet", ""),
|
||||
"siteName": p.get("siteName", ""),
|
||||
"datePublished": p.get("datePublished") or p.get("dateLastCrawled", ""),
|
||||
}
|
||||
# Include summary only if present
|
||||
if page.get("summary"):
|
||||
result["summary"] = page["summary"]
|
||||
results.append(result)
|
||||
|
||||
total = web_pages.get("totalEstimatedMatches", len(results))
|
||||
|
||||
if p.get("summary"):
|
||||
item["summary"] = p["summary"]
|
||||
results.append(item)
|
||||
total = (data.get("data") or {}).get("webPages", {}).get("totalEstimatedMatches", len(results))
|
||||
return ToolResult.success({
|
||||
"query": query,
|
||||
"backend": "bocha",
|
||||
"total": total,
|
||||
"count": len(results),
|
||||
"results": results
|
||||
"query": query, "backend": "bocha",
|
||||
"total": total, "count": len(results), "results": results,
|
||||
})
|
||||
|
||||
def _search_linkai(self, query: str, count: int, freshness: str) -> ToolResult:
|
||||
"""
|
||||
Search using LinkAI plugin API
|
||||
|
||||
:param query: Search query
|
||||
:param count: Number of results
|
||||
:param freshness: Time range filter
|
||||
:return: Formatted search results
|
||||
"""
|
||||
api_key = os.environ.get("LINKAI_API_KEY", "")
|
||||
url = "https://api.link-ai.tech/v1/plugin/execute"
|
||||
# ------------------------------------------------------------------
|
||||
# Zhipu
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _search_zhipu(self, query: str, count: int, freshness: str) -> ToolResult:
|
||||
api_key = _get_api_key("zhipu")
|
||||
api_base = (conf().get("zhipu_ai_api_base") or "https://open.bigmodel.cn/api/paas/v4").rstrip("/")
|
||||
url = f"{api_base}/web_search"
|
||||
headers = {
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {api_key}"
|
||||
}
|
||||
|
||||
payload = {
|
||||
"code": "web-search",
|
||||
"args": {
|
||||
"query": query,
|
||||
"count": count,
|
||||
"freshness": freshness
|
||||
}
|
||||
# Zhipu Web Search expects `search_query` <= 70 chars; truncate
|
||||
# gracefully so a long agent-supplied query doesn't get rejected.
|
||||
trimmed_query = (query or "")[:70]
|
||||
engine = (_tools_web_search_conf().get("zhipu_search_engine") or "search_pro").strip().lower()
|
||||
if engine not in ("search_std", "search_pro", "search_pro_sogou", "search_pro_quark"):
|
||||
engine = "search_pro"
|
||||
|
||||
payload: Dict[str, Any] = {
|
||||
"search_engine": engine,
|
||||
"search_query": trimmed_query,
|
||||
"search_intent": False,
|
||||
"count": max(1, min(int(count or 10), 50)),
|
||||
"search_recency_filter": freshness if freshness in (
|
||||
"oneDay", "oneWeek", "oneMonth", "oneYear", "noLimit"
|
||||
) else "noLimit",
|
||||
}
|
||||
content_size = (_tools_web_search_conf().get("zhipu_content_size") or "").strip().lower()
|
||||
if content_size in ("medium", "high"):
|
||||
payload["content_size"] = content_size
|
||||
|
||||
logger.debug(f"[WebSearch] zhipu: query='{trimmed_query}', count={payload['count']}, engine={engine}")
|
||||
resp = requests.post(url, headers=headers, json=payload, timeout=DEFAULT_TIMEOUT)
|
||||
|
||||
if resp.status_code == 401:
|
||||
return ToolResult.fail("Error: Invalid Zhipu API key.")
|
||||
if resp.status_code != 200:
|
||||
return ToolResult.fail(f"Error: Zhipu API returned HTTP {resp.status_code}: {resp.text[:200]}")
|
||||
|
||||
data = resp.json()
|
||||
# Business-level errors (1701/1702/1703 etc.) come back as
|
||||
# {"error": {"code","message"}} even on HTTP 200.
|
||||
if isinstance(data, dict) and data.get("error"):
|
||||
err = data["error"] or {}
|
||||
return ToolResult.fail(f"Error: Zhipu returned {err.get('code')}: {err.get('message','')}")
|
||||
|
||||
items = data.get("search_result") or (data.get("data") or {}).get("search_result") or []
|
||||
results = []
|
||||
for it in items:
|
||||
results.append({
|
||||
"title": it.get("title", ""),
|
||||
"url": it.get("link") or it.get("url", ""),
|
||||
"snippet": it.get("content") or it.get("snippet", ""),
|
||||
"siteName": it.get("media") or it.get("siteName", ""),
|
||||
"datePublished": it.get("publish_date") or it.get("datePublished", ""),
|
||||
})
|
||||
return ToolResult.success({
|
||||
"query": query, "backend": "zhipu",
|
||||
"total": len(results), "count": len(results), "results": results,
|
||||
})
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Qianfan (Baidu)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _search_qianfan(self, query: str, count: int, freshness: str) -> ToolResult:
|
||||
api_key = _get_api_key("qianfan")
|
||||
api_base = (conf().get("qianfan_api_base") or "https://qianfan.baidubce.com/v2").rstrip("/")
|
||||
url = f"{api_base}/ai_search/web_search"
|
||||
headers = {
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"Content-Type": "application/json",
|
||||
"X-Appbuilder-From": "cow",
|
||||
}
|
||||
|
||||
logger.debug(f"[WebSearch] LinkAI search: query='{query}', count={count}")
|
||||
count = max(1, min(int(count or 10), 50))
|
||||
payload: Dict[str, Any] = {
|
||||
"messages": [{"role": "user", "content": query}],
|
||||
"search_source": "baidu_search_v2",
|
||||
"resource_type_filter": [{"type": "web", "top_k": count}],
|
||||
}
|
||||
|
||||
response = requests.post(url, headers=headers, json=payload, timeout=DEFAULT_TIMEOUT)
|
||||
# Baidu AI Search expects freshness as a date-range filter, not a
|
||||
# named recency token. Translate our shared vocabulary into the
|
||||
# underlying page_time range expected by the API.
|
||||
search_filter = self._qianfan_build_freshness_filter(freshness)
|
||||
if search_filter:
|
||||
payload["search_filter"] = search_filter
|
||||
|
||||
if response.status_code == 401:
|
||||
return ToolResult.fail("Error: Invalid LINKAI_API_KEY. Please check your API key.")
|
||||
if response.status_code != 200:
|
||||
return ToolResult.fail(f"Error: LinkAI API returned HTTP {response.status_code}")
|
||||
logger.debug(f"[WebSearch] qianfan: query='{query}', count={count}, freshness={freshness!r}")
|
||||
resp = requests.post(url, headers=headers, json=payload, timeout=DEFAULT_TIMEOUT)
|
||||
|
||||
data = response.json()
|
||||
if resp.status_code == 401:
|
||||
return ToolResult.fail("Error: Invalid Qianfan API key.")
|
||||
if resp.status_code != 200:
|
||||
return ToolResult.fail(f"Error: Qianfan API returned HTTP {resp.status_code}: {resp.text[:200]}")
|
||||
|
||||
data = resp.json()
|
||||
# Even on HTTP 200 Baidu surfaces business errors as {"code","message"}.
|
||||
if isinstance(data, dict) and data.get("code"):
|
||||
return ToolResult.fail(f"Error: Qianfan returned {data.get('code')}: {data.get('message','')}")
|
||||
|
||||
refs = data.get("references") or []
|
||||
results = []
|
||||
for d in refs:
|
||||
results.append({
|
||||
"title": d.get("title", ""),
|
||||
"url": d.get("url", ""),
|
||||
"snippet": (d.get("content") or "")[:200],
|
||||
"siteName": d.get("web_anchor") or d.get("website") or "",
|
||||
"datePublished": d.get("date", ""),
|
||||
})
|
||||
return ToolResult.success({
|
||||
"query": query, "backend": "qianfan",
|
||||
"total": len(results), "count": len(results), "results": results,
|
||||
})
|
||||
|
||||
@staticmethod
|
||||
def _qianfan_build_freshness_filter(freshness: str) -> Optional[Dict[str, Any]]:
|
||||
if not freshness or freshness == "noLimit":
|
||||
return None
|
||||
delta_days = {"oneDay": 1, "oneWeek": 7, "oneMonth": 30, "oneYear": 365}.get(freshness)
|
||||
if not delta_days:
|
||||
return None
|
||||
from datetime import datetime, timedelta
|
||||
now = datetime.now()
|
||||
end_date = (now + timedelta(days=1)).strftime("%Y-%m-%d")
|
||||
start_date = (now - timedelta(days=delta_days)).strftime("%Y-%m-%d")
|
||||
return {"range": {"page_time": {"gte": start_date, "lt": end_date}}}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# LinkAI (plugin)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _search_linkai(self, query: str, count: int, freshness: str) -> ToolResult:
|
||||
api_key = _get_api_key("linkai")
|
||||
api_base = (conf().get("linkai_api_base") or "https://api.link-ai.tech").rstrip("/")
|
||||
url = f"{api_base}/v1/plugin/execute"
|
||||
|
||||
from common.utils import get_cloud_headers
|
||||
headers = get_cloud_headers(api_key)
|
||||
|
||||
payload = {"code": "web-search", "args": {"query": query, "count": count, "freshness": freshness}}
|
||||
logger.debug(f"[WebSearch] linkai: query='{query}', count={count}")
|
||||
resp = requests.post(url, headers=headers, json=payload, timeout=DEFAULT_TIMEOUT)
|
||||
|
||||
if resp.status_code == 401:
|
||||
return ToolResult.fail("Error: Invalid LinkAI API key.")
|
||||
if resp.status_code != 200:
|
||||
return ToolResult.fail(f"Error: LinkAI API returned HTTP {resp.status_code}")
|
||||
|
||||
data = resp.json()
|
||||
if not data.get("success"):
|
||||
msg = data.get("message") or "Unknown error"
|
||||
return ToolResult.fail(f"Error: LinkAI search failed: {msg}")
|
||||
|
||||
return self._format_linkai_results(data, query)
|
||||
|
||||
def _format_linkai_results(self, data: dict, query: str) -> ToolResult:
|
||||
"""
|
||||
Format LinkAI API response into unified result structure.
|
||||
LinkAI returns the search data in data.data field, which follows
|
||||
the same Bing-compatible format as Bocha.
|
||||
|
||||
:param data: Raw API response
|
||||
:param query: Original query
|
||||
:return: Formatted ToolResult
|
||||
"""
|
||||
raw_data = data.get("data", "")
|
||||
|
||||
# LinkAI may return data as a JSON string
|
||||
if isinstance(raw_data, str):
|
||||
raw = data.get("data", "")
|
||||
if isinstance(raw, str):
|
||||
try:
|
||||
raw_data = json.loads(raw_data)
|
||||
raw = json.loads(raw)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
# If data is plain text, return it as a single result
|
||||
return ToolResult.success({
|
||||
"query": query,
|
||||
"backend": "linkai",
|
||||
"total": 1,
|
||||
"count": 1,
|
||||
"results": [{"content": raw_data}]
|
||||
"query": query, "backend": "linkai",
|
||||
"total": 1, "count": 1, "results": [{"content": raw}],
|
||||
})
|
||||
|
||||
# If the response follows Bing-compatible structure
|
||||
if isinstance(raw_data, dict):
|
||||
web_pages = raw_data.get("webPages", {})
|
||||
pages = web_pages.get("value", [])
|
||||
|
||||
if isinstance(raw, dict):
|
||||
pages = (raw.get("webPages") or {}).get("value", []) or []
|
||||
if pages:
|
||||
results = []
|
||||
for page in pages:
|
||||
result = {
|
||||
"title": page.get("name", ""),
|
||||
"url": page.get("url", ""),
|
||||
"snippet": page.get("snippet", ""),
|
||||
"siteName": page.get("siteName", ""),
|
||||
"datePublished": page.get("datePublished") or page.get("dateLastCrawled", ""),
|
||||
for p in pages:
|
||||
item = {
|
||||
"title": p.get("name", ""),
|
||||
"url": p.get("url", ""),
|
||||
"snippet": p.get("snippet", ""),
|
||||
"siteName": p.get("siteName", ""),
|
||||
"datePublished": p.get("datePublished") or p.get("dateLastCrawled", ""),
|
||||
}
|
||||
if page.get("summary"):
|
||||
result["summary"] = page["summary"]
|
||||
results.append(result)
|
||||
|
||||
total = web_pages.get("totalEstimatedMatches", len(results))
|
||||
if p.get("summary"):
|
||||
item["summary"] = p["summary"]
|
||||
results.append(item)
|
||||
total = (raw.get("webPages") or {}).get("totalEstimatedMatches", len(results))
|
||||
return ToolResult.success({
|
||||
"query": query,
|
||||
"backend": "linkai",
|
||||
"total": total,
|
||||
"count": len(results),
|
||||
"results": results
|
||||
"query": query, "backend": "linkai",
|
||||
"total": total, "count": len(results), "results": results,
|
||||
})
|
||||
|
||||
# Fallback: return raw data
|
||||
return ToolResult.success({
|
||||
"query": query,
|
||||
"backend": "linkai",
|
||||
"total": 1,
|
||||
"count": 1,
|
||||
"results": [{"content": str(raw_data)}]
|
||||
"query": query, "backend": "linkai",
|
||||
"total": 1, "count": 1, "results": [{"content": str(raw)}],
|
||||
})
|
||||
|
||||
308
app.py
@@ -13,7 +13,6 @@ from plugins import *
|
||||
import threading
|
||||
|
||||
|
||||
# Global channel manager for restart support
|
||||
_channel_mgr = None
|
||||
|
||||
|
||||
@@ -21,94 +20,206 @@ def get_channel_manager():
|
||||
return _channel_mgr
|
||||
|
||||
|
||||
def _parse_channel_type(raw) -> list:
|
||||
"""
|
||||
Parse channel_type config value into a list of channel names.
|
||||
Supports:
|
||||
- single string: "feishu"
|
||||
- comma-separated string: "feishu, dingtalk"
|
||||
- list: ["feishu", "dingtalk"]
|
||||
"""
|
||||
if isinstance(raw, list):
|
||||
return [ch.strip() for ch in raw if ch.strip()]
|
||||
if isinstance(raw, str):
|
||||
return [ch.strip() for ch in raw.split(",") if ch.strip()]
|
||||
return []
|
||||
|
||||
|
||||
class ChannelManager:
|
||||
"""
|
||||
Manage the lifecycle of a channel, supporting restart from sub-threads.
|
||||
The channel.startup() runs in a daemon thread so that the main thread
|
||||
remains available and a new channel can be started at any time.
|
||||
Manage the lifecycle of multiple channels running concurrently.
|
||||
Each channel.startup() runs in its own daemon thread.
|
||||
The web channel is started as default console unless explicitly disabled.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._channel = None
|
||||
self._channel_thread = None
|
||||
self._channels = {} # channel_name -> channel instance
|
||||
self._threads = {} # channel_name -> thread
|
||||
self._primary_channel = None
|
||||
self._lock = threading.Lock()
|
||||
self.cloud_mode = False # set to True when cloud client is active
|
||||
|
||||
@property
|
||||
def channel(self):
|
||||
return self._channel
|
||||
"""Return the primary (first non-web) channel for backward compatibility."""
|
||||
return self._primary_channel
|
||||
|
||||
def start(self, channel_name: str, first_start: bool = False):
|
||||
def get_channel(self, channel_name: str):
|
||||
return self._channels.get(channel_name)
|
||||
|
||||
def start(self, channel_names: list, first_start: bool = False):
|
||||
"""
|
||||
Create and start a channel in a sub-thread.
|
||||
Create and start one or more channels in sub-threads.
|
||||
If first_start is True, plugins and linkai client will also be initialized.
|
||||
"""
|
||||
with self._lock:
|
||||
channel = channel_factory.create_channel(channel_name)
|
||||
self._channel = channel
|
||||
channels = []
|
||||
for name in channel_names:
|
||||
ch = channel_factory.create_channel(name)
|
||||
ch.cloud_mode = self.cloud_mode
|
||||
self._channels[name] = ch
|
||||
channels.append((name, ch))
|
||||
if self._primary_channel is None and name != "web":
|
||||
self._primary_channel = ch
|
||||
|
||||
if self._primary_channel is None and channels:
|
||||
self._primary_channel = channels[0][1]
|
||||
|
||||
if first_start:
|
||||
if channel_name in ["wx", "wxy", "terminal", "wechatmp", "web",
|
||||
"wechatmp_service", "wechatcom_app", "wework",
|
||||
const.FEISHU, const.DINGTALK]:
|
||||
PluginManager().load_plugins()
|
||||
PluginManager().load_plugins()
|
||||
|
||||
if conf().get("use_linkai"):
|
||||
# Cloud client is optional. It is only started when
|
||||
# use_linkai=True AND cloud_deployment_id is set.
|
||||
# By default neither is configured, so the app runs
|
||||
# entirely locally without any remote connection.
|
||||
if conf().get("use_linkai") and (
|
||||
os.environ.get("CLOUD_DEPLOYMENT_ID") or conf().get("cloud_deployment_id")
|
||||
):
|
||||
try:
|
||||
from common import cloud_client
|
||||
threading.Thread(target=cloud_client.start, args=(channel, self), daemon=True).start()
|
||||
except Exception as e:
|
||||
threading.Thread(
|
||||
target=cloud_client.start,
|
||||
args=(self._primary_channel, self),
|
||||
daemon=True,
|
||||
).start()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Run channel.startup() in a daemon thread so we can restart later
|
||||
self._channel_thread = threading.Thread(
|
||||
target=self._run_channel, args=(channel,), daemon=True
|
||||
)
|
||||
self._channel_thread.start()
|
||||
logger.debug(f"[ChannelManager] Channel '{channel_name}' started in sub-thread")
|
||||
# Start web console first so its logs print cleanly,
|
||||
# then start remaining channels after a brief pause.
|
||||
web_entry = None
|
||||
other_entries = []
|
||||
for entry in channels:
|
||||
if entry[0] == "web":
|
||||
web_entry = entry
|
||||
else:
|
||||
other_entries.append(entry)
|
||||
|
||||
def _run_channel(self, channel):
|
||||
ordered = ([web_entry] if web_entry else []) + other_entries
|
||||
for i, (name, ch) in enumerate(ordered):
|
||||
if i > 0 and name != "web":
|
||||
time.sleep(0.1)
|
||||
t = threading.Thread(target=self._run_channel, args=(name, ch), daemon=True)
|
||||
self._threads[name] = t
|
||||
t.start()
|
||||
logger.debug(f"[ChannelManager] Channel '{name}' started in sub-thread")
|
||||
|
||||
def _run_channel(self, name: str, channel):
|
||||
try:
|
||||
channel.startup()
|
||||
except Exception as e:
|
||||
logger.error(f"[ChannelManager] Channel startup error: {e}")
|
||||
logger.error(f"[ChannelManager] Channel '{name}' startup error: {e}")
|
||||
logger.exception(e)
|
||||
|
||||
def stop(self):
|
||||
def stop(self, channel_name: str = None):
|
||||
"""
|
||||
Stop the current channel. Since most channel startup() methods block
|
||||
on an HTTP server or stream client, we stop by terminating the thread.
|
||||
Stop channel(s). If channel_name is given, stop only that channel;
|
||||
otherwise stop all channels.
|
||||
"""
|
||||
# Pop under lock, then stop outside lock to avoid deadlock
|
||||
with self._lock:
|
||||
if self._channel is None:
|
||||
names = [channel_name] if channel_name else list(self._channels.keys())
|
||||
to_stop = []
|
||||
for name in names:
|
||||
ch = self._channels.pop(name, None)
|
||||
th = self._threads.pop(name, None)
|
||||
to_stop.append((name, ch, th))
|
||||
if channel_name and self._primary_channel is self._channels.get(channel_name):
|
||||
self._primary_channel = None
|
||||
|
||||
for name, ch, th in to_stop:
|
||||
if ch is None:
|
||||
logger.warning(f"[ChannelManager] Channel '{name}' not found in managed channels")
|
||||
if th and th.is_alive():
|
||||
self._interrupt_thread(th, name)
|
||||
continue
|
||||
logger.info(f"[ChannelManager] Stopping channel '{name}'...")
|
||||
graceful = False
|
||||
if hasattr(ch, 'stop'):
|
||||
try:
|
||||
ch.stop()
|
||||
graceful = True
|
||||
except Exception as e:
|
||||
logger.warning(f"[ChannelManager] Error during channel '{name}' stop: {e}")
|
||||
if th and th.is_alive():
|
||||
th.join(timeout=5)
|
||||
if th.is_alive():
|
||||
if graceful:
|
||||
logger.info(f"[ChannelManager] Channel '{name}' thread still alive after stop(), "
|
||||
"leaving daemon thread to finish on its own")
|
||||
else:
|
||||
logger.warning(f"[ChannelManager] Channel '{name}' thread did not exit in 5s, forcing interrupt")
|
||||
self._interrupt_thread(th, name)
|
||||
|
||||
@staticmethod
|
||||
def _interrupt_thread(th: threading.Thread, name: str):
|
||||
"""Raise SystemExit in target thread to break blocking loops like start_forever."""
|
||||
import ctypes
|
||||
try:
|
||||
tid = th.ident
|
||||
if tid is None:
|
||||
return
|
||||
channel_type = getattr(self._channel, 'channel_type', 'unknown')
|
||||
logger.info(f"[ChannelManager] Stopping channel '{channel_type}'...")
|
||||
|
||||
# Try graceful stop if channel implements it
|
||||
try:
|
||||
if hasattr(self._channel, 'stop'):
|
||||
self._channel.stop()
|
||||
except Exception as e:
|
||||
logger.warning(f"[ChannelManager] Error during channel stop: {e}")
|
||||
|
||||
self._channel = None
|
||||
self._channel_thread = None
|
||||
res = ctypes.pythonapi.PyThreadState_SetAsyncExc(
|
||||
ctypes.c_ulong(tid), ctypes.py_object(SystemExit)
|
||||
)
|
||||
if res == 1:
|
||||
logger.info(f"[ChannelManager] Interrupted thread for channel '{name}'")
|
||||
elif res > 1:
|
||||
ctypes.pythonapi.PyThreadState_SetAsyncExc(ctypes.c_ulong(tid), None)
|
||||
logger.warning(f"[ChannelManager] Failed to interrupt thread for channel '{name}'")
|
||||
except Exception as e:
|
||||
logger.warning(f"[ChannelManager] Thread interrupt error for '{name}': {e}")
|
||||
|
||||
def restart(self, new_channel_name: str):
|
||||
"""
|
||||
Restart the channel with a new channel type.
|
||||
Restart a single channel with a new channel type.
|
||||
Can be called from any thread (e.g. linkai config callback).
|
||||
"""
|
||||
logger.info(f"[ChannelManager] Restarting channel to '{new_channel_name}'...")
|
||||
self.stop()
|
||||
|
||||
# Clear singleton cache so a fresh channel instance is created
|
||||
self.stop(new_channel_name)
|
||||
_clear_singleton_cache(new_channel_name)
|
||||
|
||||
time.sleep(1) # Brief pause to allow resources to release
|
||||
self.start(new_channel_name, first_start=False)
|
||||
time.sleep(1)
|
||||
self.start([new_channel_name], first_start=False)
|
||||
logger.info(f"[ChannelManager] Channel restarted to '{new_channel_name}' successfully")
|
||||
|
||||
def add_channel(self, channel_name: str):
|
||||
"""
|
||||
Dynamically add and start a new channel.
|
||||
If the channel is already running, restart it instead.
|
||||
"""
|
||||
with self._lock:
|
||||
if channel_name in self._channels:
|
||||
logger.info(f"[ChannelManager] Channel '{channel_name}' already exists, restarting")
|
||||
if self._channels.get(channel_name):
|
||||
self.restart(channel_name)
|
||||
return
|
||||
logger.info(f"[ChannelManager] Adding channel '{channel_name}'...")
|
||||
_clear_singleton_cache(channel_name)
|
||||
self.start([channel_name], first_start=False)
|
||||
logger.info(f"[ChannelManager] Channel '{channel_name}' added successfully")
|
||||
|
||||
def remove_channel(self, channel_name: str):
|
||||
"""
|
||||
Dynamically stop and remove a running channel.
|
||||
"""
|
||||
with self._lock:
|
||||
if channel_name not in self._channels:
|
||||
logger.warning(f"[ChannelManager] Channel '{channel_name}' not found, nothing to remove")
|
||||
return
|
||||
logger.info(f"[ChannelManager] Removing channel '{channel_name}'...")
|
||||
self.stop(channel_name)
|
||||
logger.info(f"[ChannelManager] Channel '{channel_name}' removed successfully")
|
||||
|
||||
|
||||
def _clear_singleton_cache(channel_name: str):
|
||||
"""
|
||||
@@ -116,28 +227,25 @@ def _clear_singleton_cache(channel_name: str):
|
||||
a new instance can be created with updated config.
|
||||
"""
|
||||
cls_map = {
|
||||
"wx": "channel.wechat.wechat_channel.WechatChannel",
|
||||
"wxy": "channel.wechat.wechaty_channel.WechatyChannel",
|
||||
"wcf": "channel.wechat.wcf_channel.WechatfChannel",
|
||||
"web": "channel.web.web_channel.WebChannel",
|
||||
"wechatmp": "channel.wechatmp.wechatmp_channel.WechatMPChannel",
|
||||
"wechatmp_service": "channel.wechatmp.wechatmp_channel.WechatMPChannel",
|
||||
"wechatcom_app": "channel.wechatcom.wechatcomapp_channel.WechatComAppChannel",
|
||||
"wework": "channel.wework.wework_channel.WeworkChannel",
|
||||
const.FEISHU: "channel.feishu.feishu_channel.FeiShuChanel",
|
||||
const.DINGTALK: "channel.dingtalk.dingtalk_channel.DingTalkChanel",
|
||||
const.WECOM_BOT: "channel.wecom_bot.wecom_bot_channel.WecomBotChannel",
|
||||
const.QQ: "channel.qq.qq_channel.QQChannel",
|
||||
const.WEIXIN: "channel.weixin.weixin_channel.WeixinChannel",
|
||||
"wx": "channel.weixin.weixin_channel.WeixinChannel",
|
||||
}
|
||||
module_path = cls_map.get(channel_name)
|
||||
if not module_path:
|
||||
return
|
||||
# The singleton decorator stores instances in a closure dict keyed by class.
|
||||
# We need to find the actual class and clear it from the closure.
|
||||
try:
|
||||
parts = module_path.rsplit(".", 1)
|
||||
module_name, class_name = parts[0], parts[1]
|
||||
import importlib
|
||||
module = importlib.import_module(module_name)
|
||||
# The module-level name is the wrapper function from @singleton
|
||||
wrapper = getattr(module, class_name, None)
|
||||
if wrapper and hasattr(wrapper, '__closure__') and wrapper.__closure__:
|
||||
for cell in wrapper.__closure__:
|
||||
@@ -166,6 +274,63 @@ def sigterm_handler_wrap(_signo):
|
||||
signal.signal(_signo, func)
|
||||
|
||||
|
||||
def _warmup_mcp_tools():
|
||||
"""
|
||||
Kick off MCP server loading at process startup so subprocesses
|
||||
(npx / uvx etc.) finish initializing before the first user message
|
||||
arrives. Returns immediately — the actual work happens on a daemon
|
||||
thread inside ToolManager. Safe to call when MCP is not configured.
|
||||
"""
|
||||
try:
|
||||
from agent.tools import ToolManager
|
||||
ToolManager()._load_mcp_tools()
|
||||
except Exception as e:
|
||||
logger.warning(f"[App] MCP warmup failed (non-fatal): {e}")
|
||||
|
||||
|
||||
def _warmup_scheduler():
|
||||
"""Eager-init AgentBridge so the scheduler thread starts at process
|
||||
boot rather than waiting for the first user message."""
|
||||
try:
|
||||
from bridge.bridge import Bridge
|
||||
Bridge().get_agent_bridge()
|
||||
except Exception as e:
|
||||
logger.warning(f"[App] Scheduler warmup failed: {e}")
|
||||
|
||||
|
||||
def _sync_builtin_skills():
|
||||
"""Sync builtin skills from project skills/ to workspace skills/ on startup."""
|
||||
import shutil
|
||||
try:
|
||||
workspace = conf().get("agent_workspace", "~/cow")
|
||||
workspace = os.path.expanduser(workspace)
|
||||
project_root = os.path.dirname(os.path.abspath(__file__))
|
||||
builtin_dir = os.path.join(project_root, "skills")
|
||||
custom_dir = os.path.join(workspace, "skills")
|
||||
|
||||
if not os.path.isdir(builtin_dir):
|
||||
return
|
||||
|
||||
os.makedirs(custom_dir, exist_ok=True)
|
||||
synced = 0
|
||||
for name in os.listdir(builtin_dir):
|
||||
src = os.path.join(builtin_dir, name)
|
||||
if not os.path.isdir(src) or not os.path.isfile(os.path.join(src, "SKILL.md")):
|
||||
continue
|
||||
dst = os.path.join(custom_dir, name)
|
||||
try:
|
||||
if os.path.isdir(dst):
|
||||
shutil.rmtree(dst)
|
||||
shutil.copytree(src, dst)
|
||||
synced += 1
|
||||
except Exception as e:
|
||||
logger.warning(f"[App] Failed to sync builtin skill '{name}': {e}")
|
||||
if synced:
|
||||
logger.info(f"[App] Synced {synced} builtin skill(s) to workspace")
|
||||
except Exception as e:
|
||||
logger.warning(f"[App] Builtin skills sync failed: {e}")
|
||||
|
||||
|
||||
def run():
|
||||
global _channel_mgr
|
||||
try:
|
||||
@@ -176,20 +341,39 @@ def run():
|
||||
# kill signal
|
||||
sigterm_handler_wrap(signal.SIGTERM)
|
||||
|
||||
# create channel
|
||||
channel_name = conf().get("channel_type", "wx")
|
||||
# Parse channel_type into a list
|
||||
raw_channel = conf().get("channel_type", "web")
|
||||
|
||||
if "--cmd" in sys.argv:
|
||||
channel_name = "terminal"
|
||||
channel_names = ["terminal"]
|
||||
else:
|
||||
channel_names = _parse_channel_type(raw_channel)
|
||||
if not channel_names:
|
||||
channel_names = ["web"]
|
||||
|
||||
if channel_name == "wxy":
|
||||
os.environ["WECHATY_LOG"] = "warn"
|
||||
# Auto-start web console unless explicitly disabled
|
||||
web_console_enabled = conf().get("web_console", True)
|
||||
if web_console_enabled and "web" not in channel_names:
|
||||
channel_names.append("web")
|
||||
|
||||
# Sync builtin skills to workspace before channels start
|
||||
_sync_builtin_skills()
|
||||
|
||||
# Kick off MCP server loading in the background so first-message
|
||||
# latency isn't dominated by npx package downloads.
|
||||
_warmup_mcp_tools()
|
||||
|
||||
_warmup_scheduler()
|
||||
|
||||
logger.info(f"[App] Starting channels: {channel_names}")
|
||||
|
||||
_channel_mgr = ChannelManager()
|
||||
_channel_mgr.start(channel_name, first_start=True)
|
||||
_channel_mgr.start(channel_names, first_start=True)
|
||||
|
||||
while True:
|
||||
time.sleep(1)
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error("App startup failed!")
|
||||
logger.exception(e)
|
||||
|
||||
@@ -5,7 +5,7 @@ Agent Bridge - Integrates Agent system with existing COW bridge
|
||||
import os
|
||||
from typing import Optional, List
|
||||
|
||||
from agent.protocol import Agent, LLMModel, LLMRequest
|
||||
from agent.protocol import Agent, LLMModel, LLMRequest, get_cancel_registry
|
||||
from bridge.agent_event_handler import AgentEventHandler
|
||||
from bridge.agent_initializer import AgentInitializer
|
||||
from bridge.bridge import Bridge
|
||||
@@ -14,6 +14,7 @@ from bridge.reply import Reply, ReplyType
|
||||
from common import const
|
||||
from common.log import logger
|
||||
from common.utils import expand_path
|
||||
from config import conf
|
||||
from models.openai_compatible_bot import OpenAICompatibleBot
|
||||
|
||||
|
||||
@@ -65,30 +66,73 @@ class AgentLLMModel(LLMModel):
|
||||
LLM Model adapter that uses COW's existing bot infrastructure
|
||||
"""
|
||||
|
||||
_MODEL_BOT_TYPE_MAP = {
|
||||
"wenxin": const.BAIDU, "wenxin-4": const.BAIDU,
|
||||
"xunfei": const.XUNFEI, const.QWEN: const.QWEN_DASHSCOPE,
|
||||
const.QIANFAN: const.QIANFAN,
|
||||
const.MODELSCOPE: const.MODELSCOPE,
|
||||
}
|
||||
_MODEL_PREFIX_MAP = [
|
||||
("qwen", const.QWEN_DASHSCOPE), ("qwq", const.QWEN_DASHSCOPE), ("qvq", const.QWEN_DASHSCOPE),
|
||||
("gemini", const.GEMINI), ("glm", const.ZHIPU_AI), ("claude", const.CLAUDEAPI),
|
||||
("moonshot", const.MOONSHOT), ("kimi", const.MOONSHOT),
|
||||
("doubao", const.DOUBAO), ("deepseek", const.DEEPSEEK),
|
||||
("ernie", const.QIANFAN),
|
||||
]
|
||||
|
||||
def __init__(self, bridge: Bridge, bot_type: str = "chat"):
|
||||
# Get model name directly from config
|
||||
from config import conf
|
||||
model_name = conf().get("model", const.GPT_41)
|
||||
super().__init__(model=model_name)
|
||||
super().__init__(model=conf().get("model", const.GPT_41))
|
||||
self.bridge = bridge
|
||||
self.bot_type = bot_type
|
||||
self._bot = None
|
||||
self._use_linkai = conf().get("use_linkai", False) and conf().get("linkai_api_key")
|
||||
|
||||
self._bot_model = None
|
||||
|
||||
@property
|
||||
def model(self):
|
||||
return conf().get("model", const.GPT_41)
|
||||
|
||||
@model.setter
|
||||
def model(self, value):
|
||||
pass
|
||||
|
||||
def _resolve_bot_type(self, model_name: str) -> str:
|
||||
"""Resolve bot type from model name, matching Bridge.__init__ logic."""
|
||||
if conf().get("use_linkai", False) and conf().get("linkai_api_key"):
|
||||
return const.LINKAI
|
||||
# Support custom bot type configuration
|
||||
configured_bot_type = conf().get("bot_type")
|
||||
if configured_bot_type:
|
||||
return configured_bot_type
|
||||
|
||||
if not model_name or not isinstance(model_name, str):
|
||||
return const.OPENAI
|
||||
if model_name in self._MODEL_BOT_TYPE_MAP:
|
||||
return self._MODEL_BOT_TYPE_MAP[model_name]
|
||||
if model_name.lower().startswith("minimax") or model_name in ["abab6.5-chat"]:
|
||||
return const.MiniMax
|
||||
if model_name in [const.QWEN_TURBO, const.QWEN_PLUS, const.QWEN_MAX]:
|
||||
return const.QWEN_DASHSCOPE
|
||||
if model_name in [const.MOONSHOT, "moonshot-v1-8k", "moonshot-v1-32k", "moonshot-v1-128k"]:
|
||||
return const.MOONSHOT
|
||||
if conf().get("bot_type") == "modelscope":
|
||||
return const.MODELSCOPE
|
||||
lowered_model = model_name.lower()
|
||||
for prefix, btype in self._MODEL_PREFIX_MAP:
|
||||
if lowered_model.startswith(prefix):
|
||||
return btype
|
||||
return const.OPENAI
|
||||
|
||||
@property
|
||||
def bot(self):
|
||||
"""Lazy load the bot and enhance it with tool calling if needed"""
|
||||
if self._bot is None:
|
||||
# If use_linkai is enabled, use LinkAI bot directly
|
||||
if self._use_linkai:
|
||||
self._bot = self.bridge.find_chat_bot(const.LINKAI)
|
||||
else:
|
||||
self._bot = self.bridge.get_bot(self.bot_type)
|
||||
# Automatically add tool calling support if not present
|
||||
self._bot = add_openai_compatible_support(self._bot)
|
||||
|
||||
# Log bot info
|
||||
bot_name = type(self._bot).__name__
|
||||
"""Lazy load the bot, re-create when model or bot_type changes"""
|
||||
from models.bot_factory import create_bot
|
||||
cur_model = self.model
|
||||
cur_bot_type = self._resolve_bot_type(cur_model)
|
||||
if self._bot is None or self._bot_model != cur_model or getattr(self, '_bot_type', None) != cur_bot_type:
|
||||
self._bot = create_bot(cur_bot_type)
|
||||
self._bot = add_openai_compatible_support(self._bot)
|
||||
self._bot_model = cur_model
|
||||
self._bot_type = cur_bot_type
|
||||
return self._bot
|
||||
|
||||
def call(self, request: LLMRequest):
|
||||
@@ -109,12 +153,37 @@ class AgentLLMModel(LLMModel):
|
||||
# Only pass max_tokens if it's explicitly set
|
||||
if request.max_tokens is not None:
|
||||
kwargs['max_tokens'] = request.max_tokens
|
||||
|
||||
|
||||
# Extract system prompt if present
|
||||
system_prompt = getattr(request, 'system', None)
|
||||
if system_prompt:
|
||||
kwargs['system'] = system_prompt
|
||||
|
||||
|
||||
# Pass context metadata to bot
|
||||
channel_type = getattr(self, 'channel_type', None) or ''
|
||||
if channel_type:
|
||||
kwargs['channel_type'] = channel_type
|
||||
session_id = getattr(self, 'session_id', None)
|
||||
if session_id:
|
||||
kwargs['session_id'] = session_id
|
||||
|
||||
# Thinking mode is a global toggle independent of the channel.
|
||||
# IM channels (WeChat/WeCom/DingTalk/Feishu) won't render the
|
||||
# reasoning trace, but still benefit from the higher answer
|
||||
# quality the thinking pass produces.
|
||||
from config import conf
|
||||
thinking_enabled = bool(conf().get("enable_thinking", False))
|
||||
kwargs['thinking'] = (
|
||||
{"type": "enabled"} if thinking_enabled
|
||||
else {"type": "disabled"}
|
||||
)
|
||||
# Reasoning effort is only meaningful when thinking is on.
|
||||
# Bots that don't understand the kwarg drop it silently.
|
||||
if thinking_enabled:
|
||||
effort = conf().get("reasoning_effort", "high")
|
||||
if effort in ("high", "max"):
|
||||
kwargs['reasoning_effort'] = effort
|
||||
|
||||
response = self.bot.call_with_tools(**kwargs)
|
||||
return self._format_response(response)
|
||||
else:
|
||||
@@ -135,7 +204,7 @@ class AgentLLMModel(LLMModel):
|
||||
# Use tool-enabled streaming call if available
|
||||
# Extract system prompt if present
|
||||
system_prompt = getattr(request, 'system', None)
|
||||
|
||||
|
||||
# Build kwargs for call_with_tools
|
||||
kwargs = {
|
||||
'messages': request.messages,
|
||||
@@ -143,15 +212,40 @@ class AgentLLMModel(LLMModel):
|
||||
'stream': True,
|
||||
'model': self.model # Pass model parameter
|
||||
}
|
||||
|
||||
|
||||
# Only pass max_tokens if explicitly set, let the bot use its default
|
||||
if request.max_tokens is not None:
|
||||
kwargs['max_tokens'] = request.max_tokens
|
||||
|
||||
|
||||
# Add system prompt if present
|
||||
if system_prompt:
|
||||
kwargs['system'] = system_prompt
|
||||
|
||||
|
||||
# Pass context metadata to bot
|
||||
channel_type = getattr(self, 'channel_type', None) or ''
|
||||
if channel_type:
|
||||
kwargs['channel_type'] = channel_type
|
||||
session_id = getattr(self, 'session_id', None)
|
||||
if session_id:
|
||||
kwargs['session_id'] = session_id
|
||||
|
||||
# Thinking mode is a global toggle independent of the channel.
|
||||
# IM channels (WeChat/WeCom/DingTalk/Feishu) won't render the
|
||||
# reasoning trace, but still benefit from the higher answer
|
||||
# quality the thinking pass produces.
|
||||
from config import conf
|
||||
thinking_enabled = bool(conf().get("enable_thinking", False))
|
||||
kwargs['thinking'] = (
|
||||
{"type": "enabled"} if thinking_enabled
|
||||
else {"type": "disabled"}
|
||||
)
|
||||
# Reasoning effort is only meaningful when thinking is on.
|
||||
# Bots that don't understand the kwarg drop it silently.
|
||||
if thinking_enabled:
|
||||
effort = conf().get("reasoning_effort", "high")
|
||||
if effort in ("high", "max"):
|
||||
kwargs['reasoning_effort'] = effort
|
||||
|
||||
stream = self.bot.call_with_tools(**kwargs)
|
||||
|
||||
# Convert stream format to our expected format
|
||||
@@ -191,6 +285,15 @@ class AgentBridge:
|
||||
|
||||
# Create helper instances
|
||||
self.initializer = AgentInitializer(bridge, self)
|
||||
|
||||
# Eager-start the scheduler so cron tasks fire without waiting
|
||||
# for the first user message. init_scheduler is idempotent.
|
||||
try:
|
||||
from agent.tools.scheduler.integration import init_scheduler
|
||||
if init_scheduler(self):
|
||||
self.scheduler_initialized = True
|
||||
except Exception as e:
|
||||
logger.warning(f"[AgentBridge] Eager scheduler init failed: {e}")
|
||||
def create_agent(self, system_prompt: str, tools: List = None, **kwargs) -> Agent:
|
||||
"""
|
||||
Create the super agent with COW integration
|
||||
@@ -214,10 +317,13 @@ class AgentBridge:
|
||||
tool_manager.load_tools()
|
||||
|
||||
tools = []
|
||||
workspace_dir = kwargs.get("workspace_dir")
|
||||
for tool_name in tool_manager.tool_classes.keys():
|
||||
try:
|
||||
tool = tool_manager.create_tool(tool_name)
|
||||
if tool:
|
||||
if workspace_dir and hasattr(tool, 'cwd'):
|
||||
tool.cwd = workspace_dir
|
||||
tools.append(tool)
|
||||
except Exception as e:
|
||||
logger.warning(f"[AgentBridge] Failed to load tool {tool_name}: {e}")
|
||||
@@ -230,12 +336,13 @@ class AgentBridge:
|
||||
tools=tools,
|
||||
max_steps=kwargs.get("max_steps", 15),
|
||||
output_mode=kwargs.get("output_mode", "logger"),
|
||||
workspace_dir=kwargs.get("workspace_dir"), # Pass workspace for skills loading
|
||||
enable_skills=kwargs.get("enable_skills", True), # Enable skills by default
|
||||
memory_manager=kwargs.get("memory_manager"), # Pass memory manager
|
||||
workspace_dir=kwargs.get("workspace_dir"),
|
||||
skill_manager=kwargs.get("skill_manager"),
|
||||
enable_skills=kwargs.get("enable_skills", True),
|
||||
memory_manager=kwargs.get("memory_manager"),
|
||||
max_context_tokens=kwargs.get("max_context_tokens"),
|
||||
context_reserve_tokens=kwargs.get("context_reserve_tokens"),
|
||||
runtime_info=kwargs.get("runtime_info") # Pass runtime_info for dynamic time updates
|
||||
runtime_info=kwargs.get("runtime_info"),
|
||||
)
|
||||
|
||||
# Log skill loading details
|
||||
@@ -290,12 +397,24 @@ class AgentBridge:
|
||||
Returns:
|
||||
Reply object
|
||||
"""
|
||||
session_id = None
|
||||
agent = None
|
||||
request_id = None
|
||||
cancel_event = None
|
||||
try:
|
||||
# Extract session_id from context for user isolation
|
||||
session_id = None
|
||||
if context:
|
||||
session_id = context.kwargs.get("session_id") or context.get("session_id")
|
||||
|
||||
request_id = context.kwargs.get("request_id") or context.get("request_id")
|
||||
|
||||
# Register a cancel token. Prefer per-turn request_id (web),
|
||||
# fall back to session_id (IM channels). The Event is polled by
|
||||
# AgentStreamExecutor at safe checkpoints.
|
||||
registry = get_cancel_registry()
|
||||
token_key = request_id or session_id
|
||||
if token_key:
|
||||
cancel_event = registry.register(token_key, session_id=session_id)
|
||||
|
||||
# Get agent for this session (will auto-initialize if needed)
|
||||
agent = self.get_agent(session_id=session_id)
|
||||
if not agent:
|
||||
@@ -325,22 +444,73 @@ class AgentBridge:
|
||||
logger.warning(f"[AgentBridge] Failed to attach context to scheduler: {e}")
|
||||
break
|
||||
|
||||
# Pass context metadata to model for downstream API requests
|
||||
if context and hasattr(agent, 'model'):
|
||||
agent.model.channel_type = context.get("channel_type", "")
|
||||
agent.model.session_id = session_id or ""
|
||||
|
||||
# Store session_id on agent so executor can clear DB on fatal errors
|
||||
agent._current_session_id = session_id
|
||||
|
||||
# Bound the in-memory context for scheduler sessions before each run.
|
||||
# Scheduler sessions are stable per-task and append every trigger,
|
||||
# so without trimming they would grow unbounded across runs and
|
||||
# blow up prompt cost. Regular user chats are not touched here —
|
||||
# the agent's own context manager handles that path.
|
||||
if session_id and session_id.startswith("scheduler_"):
|
||||
from config import conf
|
||||
scheduler_keep_turns = max(
|
||||
1, int(conf().get("agent_max_context_turns", 20)) // 5
|
||||
)
|
||||
self._trim_in_memory_to_turns(agent, scheduler_keep_turns)
|
||||
|
||||
try:
|
||||
# Use agent's run_stream method with event handler
|
||||
response = agent.run_stream(
|
||||
user_message=query,
|
||||
on_event=event_handler.handle_event,
|
||||
clear_history=clear_history
|
||||
clear_history=clear_history,
|
||||
cancel_event=cancel_event,
|
||||
)
|
||||
finally:
|
||||
# Restore original tools
|
||||
if context and context.get("is_scheduled_task"):
|
||||
agent.tools = original_tools
|
||||
|
||||
|
||||
# Log execution summary
|
||||
event_handler.log_summary()
|
||||
|
||||
# Release cancel token; keep registry bounded.
|
||||
if token_key:
|
||||
try:
|
||||
registry.unregister(token_key)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Persist new messages generated during this run
|
||||
if session_id:
|
||||
channel_type = (context.get("channel_type") or "") if context else ""
|
||||
new_messages = getattr(agent, '_last_run_new_messages', [])
|
||||
if new_messages:
|
||||
self._persist_messages(session_id, list(new_messages), channel_type)
|
||||
else:
|
||||
with agent.messages_lock:
|
||||
msg_count = len(agent.messages)
|
||||
if msg_count == 0:
|
||||
try:
|
||||
from agent.memory import get_conversation_store
|
||||
get_conversation_store().clear_session(session_id)
|
||||
logger.info(f"[AgentBridge] Cleared DB for recovered session: {session_id}")
|
||||
except Exception as e:
|
||||
logger.warning(f"[AgentBridge] Failed to clear DB after recovery: {e}")
|
||||
|
||||
# Check if there are files to send (from read tool)
|
||||
# Post-message hot-reload: detect edits to ~/cow/mcp.json and
|
||||
# sync any new/removed MCP tools into the live agent in the
|
||||
# background. Off the critical path so user latency is unaffected;
|
||||
# changes take effect on the user's next message.
|
||||
self._schedule_mcp_hot_reload(agent)
|
||||
|
||||
# Check if there are files to send (from send/read tool)
|
||||
if hasattr(agent, 'stream_executor') and hasattr(agent.stream_executor, 'files_to_send'):
|
||||
files_to_send = agent.stream_executor.files_to_send
|
||||
if files_to_send:
|
||||
@@ -358,8 +528,51 @@ class AgentBridge:
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Agent reply error: {e}")
|
||||
# If the agent cleared its messages due to format error / overflow,
|
||||
# also purge the DB so the next request starts clean.
|
||||
if session_id and agent:
|
||||
try:
|
||||
with agent.messages_lock:
|
||||
msg_count = len(agent.messages)
|
||||
if msg_count == 0:
|
||||
from agent.memory import get_conversation_store
|
||||
get_conversation_store().clear_session(session_id)
|
||||
logger.info(f"[AgentBridge] Cleared DB for session after error: {session_id}")
|
||||
except Exception as db_err:
|
||||
logger.warning(f"[AgentBridge] Failed to clear DB after error: {db_err}")
|
||||
# Release cancel token on error path too (idempotent).
|
||||
if cancel_event is not None and (request_id or session_id):
|
||||
try:
|
||||
get_cancel_registry().unregister(request_id or session_id)
|
||||
except Exception:
|
||||
pass
|
||||
return Reply(ReplyType.ERROR, f"Agent error: {str(e)}")
|
||||
|
||||
def _schedule_mcp_hot_reload(self, agent):
|
||||
"""
|
||||
Fire-and-forget: detect mcp.json edits and reconcile the agent's
|
||||
tool dict in the background. Runs after the user's reply is sent,
|
||||
so any cost (file stat, hash, server boot) never adds to user latency.
|
||||
Failures are isolated and never raise into the message pipeline.
|
||||
"""
|
||||
import threading
|
||||
from agent.tools import ToolManager
|
||||
|
||||
def _run():
|
||||
try:
|
||||
tm = ToolManager()
|
||||
tm.refresh_mcp_if_changed()
|
||||
added, removed = tm.sync_mcp_into_agent(agent)
|
||||
if added or removed:
|
||||
logger.info(
|
||||
f"[AgentBridge] Agent tools synced — "
|
||||
f"added={added}, removed={removed}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"[AgentBridge] MCP hot-reload failed (non-fatal): {e}")
|
||||
|
||||
threading.Thread(target=_run, daemon=True, name="mcp-hot-reload").start()
|
||||
|
||||
def _create_file_reply(self, file_info: dict, text_response: str, context: Context = None) -> Reply:
|
||||
"""
|
||||
Create a reply for sending files
|
||||
@@ -397,22 +610,26 @@ class AgentBridge:
|
||||
reply.text_content = text_response
|
||||
return reply
|
||||
|
||||
# For other unknown file types, return text with file info
|
||||
message = text_response or file_info.get("message", "文件已准备")
|
||||
message += f"\n\n[文件: {file_info.get('file_name', file_path)}]"
|
||||
return Reply(ReplyType.TEXT, message)
|
||||
# For all other file types (tar.gz, zip, etc.), also use FILE type
|
||||
file_url = f"file://{file_path}"
|
||||
logger.info(f"[AgentBridge] Sending generic file: {file_url}")
|
||||
reply = Reply(ReplyType.FILE, file_url)
|
||||
reply.file_name = file_info.get("file_name", os.path.basename(file_path))
|
||||
if text_response:
|
||||
reply.text_content = text_response
|
||||
return reply
|
||||
|
||||
def _migrate_config_to_env(self, workspace_root: str):
|
||||
"""
|
||||
Migrate API keys from config.json to .env file if not already set
|
||||
|
||||
Sync API keys from config.json to .env file.
|
||||
Adds new keys and updates changed values on each startup.
|
||||
|
||||
Args:
|
||||
workspace_root: Workspace directory path (not used, kept for compatibility)
|
||||
"""
|
||||
from config import conf
|
||||
import os
|
||||
|
||||
# Mapping from config.json keys to environment variable names
|
||||
key_mapping = {
|
||||
"open_ai_api_key": "OPENAI_API_KEY",
|
||||
"open_ai_api_base": "OPENAI_API_BASE",
|
||||
@@ -421,10 +638,9 @@ class AgentBridge:
|
||||
"linkai_api_key": "LINKAI_API_KEY",
|
||||
}
|
||||
|
||||
# Use fixed secure location for .env file
|
||||
env_file = expand_path("~/.cow/.env")
|
||||
|
||||
# Read existing env vars from .env file
|
||||
# Read existing env vars (key -> value)
|
||||
existing_env_vars = {}
|
||||
if os.path.exists(env_file):
|
||||
try:
|
||||
@@ -432,49 +648,300 @@ class AgentBridge:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line and not line.startswith('#') and '=' in line:
|
||||
key, _ = line.split('=', 1)
|
||||
existing_env_vars[key.strip()] = True
|
||||
key, val = line.split('=', 1)
|
||||
existing_env_vars[key.strip()] = val.strip()
|
||||
except Exception as e:
|
||||
logger.warning(f"[AgentBridge] Failed to read .env file: {e}")
|
||||
|
||||
# Check which keys need to be migrated
|
||||
keys_to_migrate = {}
|
||||
# Sync config.json values into .env (add/update/remove)
|
||||
updated = False
|
||||
for config_key, env_key in key_mapping.items():
|
||||
# Skip if already in .env file
|
||||
if env_key in existing_env_vars:
|
||||
continue
|
||||
|
||||
# Get value from config.json
|
||||
value = conf().get(config_key, "")
|
||||
if value and value.strip(): # Only migrate non-empty values
|
||||
keys_to_migrate[env_key] = value.strip()
|
||||
|
||||
# Log summary if there are keys to skip
|
||||
if existing_env_vars:
|
||||
logger.debug(f"[AgentBridge] {len(existing_env_vars)} env vars already in .env")
|
||||
|
||||
# Write new keys to .env file
|
||||
if keys_to_migrate:
|
||||
raw = conf().get(config_key, "")
|
||||
value = raw.strip() if raw else ""
|
||||
old_value = existing_env_vars.get(env_key)
|
||||
|
||||
if value:
|
||||
if old_value == value:
|
||||
continue
|
||||
existing_env_vars[env_key] = value
|
||||
os.environ[env_key] = value
|
||||
updated = True
|
||||
else:
|
||||
if old_value is None:
|
||||
continue
|
||||
existing_env_vars.pop(env_key, None)
|
||||
os.environ.pop(env_key, None)
|
||||
updated = True
|
||||
updated = True
|
||||
|
||||
if updated:
|
||||
try:
|
||||
# Ensure ~/.cow directory and .env file exist
|
||||
env_dir = os.path.dirname(env_file)
|
||||
if not os.path.exists(env_dir):
|
||||
os.makedirs(env_dir, exist_ok=True)
|
||||
if not os.path.exists(env_file):
|
||||
open(env_file, 'a').close()
|
||||
|
||||
# Append new keys
|
||||
with open(env_file, 'a', encoding='utf-8') as f:
|
||||
f.write('\n# Auto-migrated from config.json\n')
|
||||
for key, value in keys_to_migrate.items():
|
||||
os.makedirs(env_dir, exist_ok=True)
|
||||
|
||||
with open(env_file, 'w', encoding='utf-8') as f:
|
||||
f.write('# Environment variables for agent\n')
|
||||
f.write('# Auto-managed - synced from config.json on startup\n\n')
|
||||
for key, value in sorted(existing_env_vars.items()):
|
||||
f.write(f'{key}={value}\n')
|
||||
# Also set in current process
|
||||
os.environ[key] = value
|
||||
|
||||
logger.info(f"[AgentBridge] Migrated {len(keys_to_migrate)} API keys from config.json to .env: {list(keys_to_migrate.keys())}")
|
||||
|
||||
logger.info(f"[AgentBridge] Synced API keys from config.json to .env")
|
||||
except Exception as e:
|
||||
logger.warning(f"[AgentBridge] Failed to migrate API keys: {e}")
|
||||
logger.warning(f"[AgentBridge] Failed to sync API keys: {e}")
|
||||
|
||||
def _persist_messages(
|
||||
self, session_id: str, new_messages: list, channel_type: str = ""
|
||||
) -> None:
|
||||
"""
|
||||
Persist new messages to the conversation store after each agent run.
|
||||
|
||||
Failures are logged but never propagate — they must not interrupt replies.
|
||||
"""
|
||||
if not new_messages:
|
||||
return
|
||||
try:
|
||||
from config import conf
|
||||
if not conf().get("conversation_persistence", True):
|
||||
return
|
||||
# When deep-thinking display is disabled, strip "thinking" content
|
||||
# blocks before persisting so they don't resurface on history reload.
|
||||
# The in-memory message list keeps them intact for this run's
|
||||
# multi-turn LLM context.
|
||||
thinking_enabled = bool(conf().get("enable_thinking", False))
|
||||
except Exception:
|
||||
thinking_enabled = False
|
||||
|
||||
messages_to_store = new_messages
|
||||
if not thinking_enabled:
|
||||
messages_to_store = self._strip_thinking_blocks(new_messages)
|
||||
|
||||
try:
|
||||
from agent.memory import get_conversation_store
|
||||
get_conversation_store().append_messages(
|
||||
session_id, messages_to_store, channel_type=channel_type
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"[AgentBridge] Failed to persist messages for session={session_id}: {e}"
|
||||
)
|
||||
|
||||
# Marker used to identify scheduler-injected user messages so we can apply
|
||||
# a sliding window without touching real user turns. The legacy prefix
|
||||
# "Scheduled task" (written by the v2 PR) is also recognised when pruning,
|
||||
# so old data can be aged out instead of leaking forever.
|
||||
_SCHEDULED_MARKER = "[SCHEDULED]"
|
||||
_SCHEDULED_LEGACY_MARKERS = ("Scheduled task",)
|
||||
|
||||
def remember_scheduled_output(
|
||||
self,
|
||||
session_id: str,
|
||||
content: str,
|
||||
channel_type: str = "",
|
||||
task_description: str = "",
|
||||
) -> None:
|
||||
"""Add the visible output of a scheduled task to the receiver's session.
|
||||
|
||||
Scheduled task execution uses an isolated session so internal planning and
|
||||
tool calls do not leak into the user's chat. The final message is still
|
||||
part of the conversation from the user's point of view, so keep a small
|
||||
visible turn in the receiver session for follow-up questions.
|
||||
|
||||
Configuration:
|
||||
scheduler_inject_to_session (bool, default True):
|
||||
Master switch. When False, this method is a no-op.
|
||||
scheduler_inject_max_per_session (int, default 3):
|
||||
Maximum scheduler-injected user/assistant pairs retained per
|
||||
session. Older injections are pruned automatically.
|
||||
|
||||
Content is truncated to 2000 chars to prevent a single high-volume task
|
||||
from bloating one entry.
|
||||
"""
|
||||
from config import conf
|
||||
if not conf().get("scheduler_inject_to_session", True):
|
||||
return
|
||||
if not session_id or not content:
|
||||
return
|
||||
|
||||
max_len = 2000
|
||||
if len(content) > max_len:
|
||||
content = content[:max_len] + "..."
|
||||
|
||||
user_text = self._SCHEDULED_MARKER
|
||||
if task_description:
|
||||
user_text = f"{self._SCHEDULED_MARKER} {task_description}"
|
||||
|
||||
messages = [
|
||||
{"role": "user", "content": [{"type": "text", "text": user_text}]},
|
||||
{"role": "assistant", "content": [{"type": "text", "text": content}]},
|
||||
]
|
||||
|
||||
# Persist first so the new pair gets a stable seq, then prune old
|
||||
# scheduler pairs in DB, then sync the in-memory agent.messages buffer.
|
||||
self._persist_messages(session_id, messages, channel_type)
|
||||
|
||||
keep_last_n = max(int(conf().get("scheduler_inject_max_per_session", 3) or 0), 0)
|
||||
try:
|
||||
from agent.memory import get_conversation_store
|
||||
deleted = get_conversation_store().prune_scheduled_messages(
|
||||
session_id, keep_last_n=keep_last_n
|
||||
)
|
||||
if deleted:
|
||||
logger.debug(
|
||||
f"[AgentBridge] Pruned {deleted} old scheduler messages "
|
||||
f"for session={session_id} (keep_last_n={keep_last_n})"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"[AgentBridge] Failed to prune scheduled messages "
|
||||
f"for session={session_id}: {e}"
|
||||
)
|
||||
|
||||
agent = self.agents.get(session_id)
|
||||
if agent:
|
||||
try:
|
||||
with agent.messages_lock:
|
||||
agent.messages.extend(messages)
|
||||
self._prune_scheduled_in_memory(agent, keep_last_n)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"[AgentBridge] Failed to update in-memory scheduled output "
|
||||
f"for session={session_id}: {e}"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _trim_in_memory_to_turns(agent, keep_turns: int) -> None:
|
||||
"""Bound ``agent.messages`` to the most recent ``keep_turns`` real
|
||||
user/assistant turns, dropping older history together with any
|
||||
intermediate tool_use/tool_result blocks that belonged to it.
|
||||
|
||||
A "real" user message is any user message whose content is not solely a
|
||||
tool_result block — matches the heuristic used elsewhere when filtering
|
||||
history (see ``AgentInitializer._filter_text_only_messages``).
|
||||
|
||||
No-op when the session is already within budget. Caller does not need
|
||||
to hold the lock; this method acquires it itself.
|
||||
"""
|
||||
if keep_turns <= 0:
|
||||
return
|
||||
|
||||
def _is_real_user(msg) -> bool:
|
||||
if not isinstance(msg, dict) or msg.get("role") != "user":
|
||||
return False
|
||||
content = msg.get("content")
|
||||
if isinstance(content, list):
|
||||
if any(
|
||||
isinstance(b, dict) and b.get("type") == "tool_result"
|
||||
for b in content
|
||||
):
|
||||
return False
|
||||
return any(
|
||||
isinstance(b, dict) and b.get("type") == "text" and b.get("text")
|
||||
for b in content
|
||||
)
|
||||
if isinstance(content, str):
|
||||
return bool(content.strip())
|
||||
return False
|
||||
|
||||
with agent.messages_lock:
|
||||
msgs = agent.messages
|
||||
real_user_indices = [i for i, m in enumerate(msgs) if _is_real_user(m)]
|
||||
if len(real_user_indices) <= keep_turns:
|
||||
return
|
||||
|
||||
# Cut at the (k-th from the end) real user message; keep everything
|
||||
# from there onwards so the surviving slice is still a valid
|
||||
# user/assistant sequence.
|
||||
cut_idx = real_user_indices[-keep_turns]
|
||||
if cut_idx == 0:
|
||||
return
|
||||
|
||||
kept = msgs[cut_idx:]
|
||||
msgs.clear()
|
||||
msgs.extend(kept)
|
||||
logger.debug(
|
||||
f"[AgentBridge] Trimmed in-memory messages to last "
|
||||
f"{keep_turns} turns ({len(kept)} messages remain)"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _prune_scheduled_in_memory(cls, agent, keep_last_n: int) -> None:
|
||||
"""Mirror conversation_store.prune_scheduled_messages on agent.messages.
|
||||
|
||||
Caller must hold ``agent.messages_lock``.
|
||||
"""
|
||||
if keep_last_n < 0:
|
||||
keep_last_n = 0
|
||||
|
||||
markers = (cls._SCHEDULED_MARKER,) + cls._SCHEDULED_LEGACY_MARKERS
|
||||
|
||||
def _is_marker_user(msg) -> bool:
|
||||
if not isinstance(msg, dict) or msg.get("role") != "user":
|
||||
return False
|
||||
content = msg.get("content")
|
||||
text = ""
|
||||
if isinstance(content, str):
|
||||
text = content
|
||||
elif isinstance(content, list):
|
||||
for block in content:
|
||||
if isinstance(block, dict) and block.get("type") == "text":
|
||||
text = block.get("text", "")
|
||||
break
|
||||
return any(text.startswith(m) for m in markers)
|
||||
|
||||
msgs = agent.messages
|
||||
pair_indices = [] # list of (user_idx, assistant_idx_or_None)
|
||||
for idx, msg in enumerate(msgs):
|
||||
if not _is_marker_user(msg):
|
||||
continue
|
||||
assistant_idx = None
|
||||
if idx + 1 < len(msgs):
|
||||
nxt = msgs[idx + 1]
|
||||
if isinstance(nxt, dict) and nxt.get("role") == "assistant":
|
||||
assistant_idx = idx + 1
|
||||
pair_indices.append((idx, assistant_idx))
|
||||
|
||||
if len(pair_indices) <= keep_last_n:
|
||||
return
|
||||
|
||||
to_drop = pair_indices[: len(pair_indices) - keep_last_n]
|
||||
drop_set = set()
|
||||
for u_idx, a_idx in to_drop:
|
||||
drop_set.add(u_idx)
|
||||
if a_idx is not None:
|
||||
drop_set.add(a_idx)
|
||||
|
||||
# Rebuild the list in place to keep external references stable.
|
||||
kept = [m for i, m in enumerate(msgs) if i not in drop_set]
|
||||
msgs.clear()
|
||||
msgs.extend(kept)
|
||||
|
||||
@staticmethod
|
||||
def _strip_thinking_blocks(messages: list) -> list:
|
||||
"""Return a shallow copy of messages with assistant "thinking" blocks removed."""
|
||||
cleaned = []
|
||||
for msg in messages:
|
||||
if not isinstance(msg, dict):
|
||||
cleaned.append(msg)
|
||||
continue
|
||||
if msg.get("role") != "assistant":
|
||||
cleaned.append(msg)
|
||||
continue
|
||||
content = msg.get("content")
|
||||
if not isinstance(content, list):
|
||||
cleaned.append(msg)
|
||||
continue
|
||||
filtered_blocks = [
|
||||
b for b in content
|
||||
if not (isinstance(b, dict) and b.get("type") == "thinking")
|
||||
]
|
||||
if len(filtered_blocks) == len(content):
|
||||
cleaned.append(msg)
|
||||
else:
|
||||
new_msg = dict(msg)
|
||||
new_msg["content"] = filtered_blocks
|
||||
cleaned.append(new_msg)
|
||||
return cleaned
|
||||
|
||||
def clear_session(self, session_id: str):
|
||||
"""
|
||||
Clear a specific session's agent and conversation history
|
||||
@@ -560,4 +1027,4 @@ class AgentBridge:
|
||||
agent.tools = [t for t in agent.tools if t.name != "web_search"]
|
||||
logger.info("[AgentBridge] web_search tool removed (API key no longer available)")
|
||||
except Exception as e:
|
||||
logger.debug(f"[AgentBridge] Failed to refresh conditional tools: {e}")
|
||||
logger.debug(f"[AgentBridge] Failed to refresh conditional tools: {e}")
|
||||
|
||||
@@ -2,114 +2,124 @@
|
||||
Agent Event Handler - Handles agent events and thinking process output
|
||||
"""
|
||||
|
||||
from common import const
|
||||
from common.log import logger
|
||||
|
||||
# Cap intermediate thinking messages on weixin to stay within send quota.
|
||||
WEIXIN_THINKING_INSTANT_MAX = 7
|
||||
|
||||
|
||||
class AgentEventHandler:
|
||||
"""
|
||||
Handles agent events and optionally sends intermediate messages to channel
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self, context=None, original_callback=None):
|
||||
"""
|
||||
Initialize event handler
|
||||
|
||||
Args:
|
||||
context: COW context (for accessing channel)
|
||||
original_callback: Original event callback to chain
|
||||
"""
|
||||
self.context = context
|
||||
self.original_callback = original_callback
|
||||
|
||||
# Get channel for sending intermediate messages
|
||||
|
||||
self.channel = None
|
||||
if context:
|
||||
self.channel = context.kwargs.get("channel") if hasattr(context, "kwargs") else None
|
||||
|
||||
# Track current thinking for channel output
|
||||
self.current_thinking = ""
|
||||
|
||||
self.current_content = ""
|
||||
self.turn_number = 0
|
||||
|
||||
|
||||
channel_type = ""
|
||||
if context and hasattr(context, "kwargs"):
|
||||
channel_type = context.kwargs.get("channel_type", "") or ""
|
||||
self._is_weixin = channel_type == const.WEIXIN
|
||||
self._thinking_sent_count = 0
|
||||
self._merged_buf: list[str] = []
|
||||
|
||||
def handle_event(self, event):
|
||||
"""
|
||||
Main event handler
|
||||
|
||||
Args:
|
||||
event: Event dict with type and data
|
||||
"""
|
||||
event_type = event.get("type")
|
||||
data = event.get("data", {})
|
||||
|
||||
# Dispatch to specific handlers
|
||||
|
||||
if event_type == "turn_start":
|
||||
self._handle_turn_start(data)
|
||||
elif event_type == "message_update":
|
||||
self._handle_message_update(data)
|
||||
elif event_type == "message_end":
|
||||
self._handle_message_end(data)
|
||||
elif event_type == "reasoning_update":
|
||||
pass
|
||||
elif event_type == "tool_execution_start":
|
||||
self._handle_tool_execution_start(data)
|
||||
elif event_type == "tool_execution_end":
|
||||
self._handle_tool_execution_end(data)
|
||||
|
||||
# Call original callback if provided
|
||||
elif event_type == "agent_end":
|
||||
self._handle_agent_end(data)
|
||||
|
||||
if self.original_callback:
|
||||
self.original_callback(event)
|
||||
|
||||
|
||||
def _handle_turn_start(self, data):
|
||||
"""Handle turn start event"""
|
||||
self.turn_number = data.get("turn", 0)
|
||||
self.has_tool_calls_in_turn = False
|
||||
self.current_thinking = ""
|
||||
|
||||
self.current_content = ""
|
||||
|
||||
def _handle_message_update(self, data):
|
||||
"""Handle message update event (streaming text)"""
|
||||
delta = data.get("delta", "")
|
||||
self.current_thinking += delta
|
||||
|
||||
self.current_content += delta
|
||||
|
||||
def _handle_message_end(self, data):
|
||||
"""Handle message end event"""
|
||||
tool_calls = data.get("tool_calls", [])
|
||||
|
||||
# Only send thinking process if followed by tool calls
|
||||
|
||||
if tool_calls:
|
||||
if self.current_thinking.strip():
|
||||
logger.info(f"💭 {self.current_thinking.strip()[:200]}{'...' if len(self.current_thinking) > 200 else ''}")
|
||||
# Send thinking process to channel
|
||||
self._send_to_channel(f"{self.current_thinking.strip()}")
|
||||
if self.current_content.strip():
|
||||
logger.info(f"💭 {self.current_content.strip()[:200]}{'...' if len(self.current_content) > 200 else ''}")
|
||||
self._send_to_channel(self.current_content.strip())
|
||||
else:
|
||||
# No tool calls = final response (logged at agent_stream level)
|
||||
if self.current_thinking.strip():
|
||||
logger.debug(f"💬 {self.current_thinking.strip()[:200]}{'...' if len(self.current_thinking) > 200 else ''}")
|
||||
|
||||
self.current_thinking = ""
|
||||
|
||||
if self.current_content.strip():
|
||||
logger.debug(f"💬 {self.current_content.strip()[:200]}{'...' if len(self.current_content) > 200 else ''}")
|
||||
# Drain weixin buffer before final reply leaves chat_channel
|
||||
self._flush_merged_now()
|
||||
|
||||
self.current_content = ""
|
||||
|
||||
def _handle_agent_end(self, data):
|
||||
self._flush_merged_now()
|
||||
|
||||
def _handle_tool_execution_start(self, data):
|
||||
"""Handle tool execution start event - logged by agent_stream.py"""
|
||||
pass
|
||||
|
||||
|
||||
def _handle_tool_execution_end(self, data):
|
||||
"""Handle tool execution end event - logged by agent_stream.py"""
|
||||
pass
|
||||
|
||||
|
||||
def _send_to_channel(self, message):
|
||||
"""
|
||||
Try to send message to channel
|
||||
|
||||
Args:
|
||||
message: Message to send
|
||||
"""
|
||||
if self.channel:
|
||||
try:
|
||||
from bridge.reply import Reply, ReplyType
|
||||
# Create a Reply object for the message
|
||||
reply = Reply(ReplyType.TEXT, message)
|
||||
self.channel._send(reply, self.context)
|
||||
except Exception as e:
|
||||
logger.debug(f"[AgentEventHandler] Failed to send to channel: {e}")
|
||||
|
||||
if self.context and self.context.get("on_event"):
|
||||
return
|
||||
if not self.channel:
|
||||
return
|
||||
|
||||
if not self._is_weixin:
|
||||
self._do_send(message)
|
||||
return
|
||||
|
||||
if self._thinking_sent_count < WEIXIN_THINKING_INSTANT_MAX:
|
||||
self._do_send(message)
|
||||
self._thinking_sent_count += 1
|
||||
return
|
||||
|
||||
self._merged_buf.append(message)
|
||||
|
||||
def _flush_merged_now(self):
|
||||
if not self._merged_buf:
|
||||
return
|
||||
merged = "\n\n".join(self._merged_buf)
|
||||
count = len(self._merged_buf)
|
||||
self._merged_buf = []
|
||||
logger.debug(f"[AgentEventHandler] Flushing {count} merged thinking msgs, len={len(merged)}")
|
||||
self._do_send(merged)
|
||||
self._thinking_sent_count += 1
|
||||
|
||||
def _do_send(self, message):
|
||||
try:
|
||||
from bridge.reply import Reply, ReplyType
|
||||
reply = Reply(ReplyType.TEXT, message)
|
||||
self.channel._send(reply, self.context)
|
||||
except Exception as e:
|
||||
logger.debug(f"[AgentEventHandler] Failed to send to channel: {e}")
|
||||
|
||||
def log_summary(self):
|
||||
"""Log execution summary - simplified"""
|
||||
# Summary removed as per user request
|
||||
# Real-time logging during execution is sufficient
|
||||
pass
|
||||
|
||||
@@ -5,6 +5,7 @@ Agent Initializer - Handles agent initialization logic
|
||||
import os
|
||||
import asyncio
|
||||
import datetime
|
||||
import threading
|
||||
import time
|
||||
from typing import Optional, List
|
||||
|
||||
@@ -13,6 +14,13 @@ from agent.tools import ToolManager
|
||||
from common.log import logger
|
||||
from common.utils import expand_path
|
||||
|
||||
# Module-level lock to serialize scheduler init across concurrent sessions
|
||||
_scheduler_init_lock = threading.Lock()
|
||||
|
||||
# Track whether the embedding model log has been printed in this process,
|
||||
# so we avoid spamming it once per session.
|
||||
_embedding_logged: bool = False
|
||||
|
||||
|
||||
class AgentInitializer:
|
||||
"""
|
||||
@@ -77,10 +85,6 @@ class AgentInitializer:
|
||||
# Initialize skill manager
|
||||
skill_manager = self._initialize_skill_manager(workspace_root, session_id)
|
||||
|
||||
# Check if first conversation
|
||||
from agent.prompt.workspace import is_first_conversation, mark_conversation_started
|
||||
is_first = is_first_conversation(workspace_root)
|
||||
|
||||
# Build system prompt
|
||||
prompt_builder = PromptBuilder(workspace_dir=workspace_root, language="zh")
|
||||
runtime_info = self._get_runtime_info(workspace_root)
|
||||
@@ -91,12 +95,8 @@ class AgentInitializer:
|
||||
skill_manager=skill_manager,
|
||||
memory_manager=memory_manager,
|
||||
runtime_info=runtime_info,
|
||||
is_first_conversation=is_first
|
||||
)
|
||||
|
||||
if is_first:
|
||||
mark_conversation_started(workspace_root)
|
||||
|
||||
# Get cost control parameters
|
||||
from config import conf
|
||||
max_steps = conf().get("agent_max_steps", 20)
|
||||
@@ -115,11 +115,143 @@ class AgentInitializer:
|
||||
runtime_info=runtime_info # Pass runtime_info for dynamic time updates
|
||||
)
|
||||
|
||||
# Attach memory manager
|
||||
# Attach memory manager and share LLM model for summarization
|
||||
if memory_manager:
|
||||
agent.memory_manager = memory_manager
|
||||
|
||||
if hasattr(agent, 'model') and agent.model:
|
||||
memory_manager.flush_manager.llm_model = agent.model
|
||||
|
||||
# Restore persisted conversation history for this session
|
||||
if session_id:
|
||||
self._restore_conversation_history(agent, session_id)
|
||||
|
||||
# Start daily memory flush timer (once, on first agent init regardless of session)
|
||||
self._start_daily_flush_timer()
|
||||
|
||||
return agent
|
||||
|
||||
def _restore_conversation_history(self, agent, session_id: str) -> None:
|
||||
"""
|
||||
Load persisted conversation messages from SQLite and inject them
|
||||
into the agent's in-memory message list.
|
||||
|
||||
Only user text and assistant text are restored. Tool call chains
|
||||
(tool_use / tool_result) are stripped out because:
|
||||
1. They are intermediate process, the value is already in the final
|
||||
assistant text reply.
|
||||
2. They consume massive context tokens (often 80%+ of history).
|
||||
3. Different models have incompatible tool message formats, so
|
||||
restoring tool chains across model switches causes 400 errors.
|
||||
4. Eliminates the entire class of tool_use/tool_result pairing bugs.
|
||||
"""
|
||||
from config import conf
|
||||
if not conf().get("conversation_persistence", True):
|
||||
return
|
||||
|
||||
try:
|
||||
from agent.memory import get_conversation_store
|
||||
store = get_conversation_store()
|
||||
max_turns = conf().get("agent_max_context_turns", 20)
|
||||
# Scheduler tasks run on a stable isolated session per task and
|
||||
# can fire many times a day; a smaller restore window keeps prompt
|
||||
# cost bounded while still letting the agent see "last few" runs
|
||||
# for trend / dedup style logic. Regular chat sessions keep the
|
||||
# original heuristic so user dialogues feel continuous.
|
||||
if session_id.startswith("scheduler_"):
|
||||
restore_turns = max(1, max_turns // 5)
|
||||
else:
|
||||
restore_turns = max(3, max_turns // 6)
|
||||
saved = store.load_messages(session_id, max_turns=restore_turns)
|
||||
if saved:
|
||||
filtered = self._filter_text_only_messages(saved)
|
||||
if filtered:
|
||||
with agent.messages_lock:
|
||||
agent.messages = filtered
|
||||
logger.debug(
|
||||
f"[AgentInitializer] Restored {len(filtered)} text messages "
|
||||
f"(from {len(saved)} total, {restore_turns} turns cap) "
|
||||
f"for session={session_id}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"[AgentInitializer] Failed to restore conversation history for "
|
||||
f"session={session_id}: {e}"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _filter_text_only_messages(messages: list) -> list:
|
||||
"""
|
||||
Extract clean user/assistant turn pairs from raw message history.
|
||||
|
||||
Groups messages into turns (each starting with a real user query),
|
||||
then keeps only:
|
||||
- The first user text in each turn (the actual user input)
|
||||
- The last assistant text in each turn (the final answer)
|
||||
|
||||
All tool_use, tool_result, intermediate assistant thoughts, and
|
||||
internal hint messages injected by the agent loop are discarded.
|
||||
"""
|
||||
|
||||
def _extract_text(content) -> str:
|
||||
if isinstance(content, str):
|
||||
return content.strip()
|
||||
if isinstance(content, list):
|
||||
parts = [
|
||||
b.get("text", "")
|
||||
for b in content
|
||||
if isinstance(b, dict) and b.get("type") == "text"
|
||||
]
|
||||
return "\n".join(p for p in parts if p).strip()
|
||||
return ""
|
||||
|
||||
def _is_real_user_msg(msg: dict) -> bool:
|
||||
"""True for actual user input, False for tool_result or internal hints."""
|
||||
if msg.get("role") != "user":
|
||||
return False
|
||||
content = msg.get("content")
|
||||
if isinstance(content, list):
|
||||
has_tool_result = any(
|
||||
isinstance(b, dict) and b.get("type") == "tool_result"
|
||||
for b in content
|
||||
)
|
||||
if has_tool_result:
|
||||
return False
|
||||
text = _extract_text(content)
|
||||
return bool(text)
|
||||
|
||||
# Group into turns: each turn starts with a real user message
|
||||
turns = []
|
||||
current_turn = None
|
||||
for msg in messages:
|
||||
if _is_real_user_msg(msg):
|
||||
if current_turn is not None:
|
||||
turns.append(current_turn)
|
||||
current_turn = {"user": msg, "assistants": []}
|
||||
elif current_turn is not None and msg.get("role") == "assistant":
|
||||
text = _extract_text(msg.get("content"))
|
||||
if text:
|
||||
current_turn["assistants"].append(text)
|
||||
if current_turn is not None:
|
||||
turns.append(current_turn)
|
||||
|
||||
# Build result: one user msg + one assistant msg per turn
|
||||
filtered = []
|
||||
for turn in turns:
|
||||
user_text = _extract_text(turn["user"].get("content"))
|
||||
if not user_text:
|
||||
continue
|
||||
filtered.append({
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": user_text}]
|
||||
})
|
||||
if turn["assistants"]:
|
||||
final_reply = turn["assistants"][-1]
|
||||
filtered.append({
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "text": final_reply}]
|
||||
})
|
||||
|
||||
return filtered
|
||||
|
||||
def _load_env_file(self):
|
||||
"""Load environment variables from .env file"""
|
||||
@@ -144,37 +276,19 @@ class AgentInitializer:
|
||||
memory_tools = []
|
||||
|
||||
try:
|
||||
from agent.memory import MemoryManager, MemoryConfig, create_embedding_provider
|
||||
from agent.memory import MemoryManager, MemoryConfig
|
||||
from agent.tools import MemorySearchTool, MemoryGetTool
|
||||
from config import conf
|
||||
|
||||
# Get OpenAI config
|
||||
openai_api_key = conf().get("open_ai_api_key", "")
|
||||
openai_api_base = conf().get("open_ai_api_base", "")
|
||||
|
||||
# Initialize embedding provider
|
||||
embedding_provider = None
|
||||
if openai_api_key and openai_api_key not in ["", "YOUR API KEY", "YOUR_API_KEY"]:
|
||||
try:
|
||||
embedding_provider = create_embedding_provider(
|
||||
provider="openai",
|
||||
model="text-embedding-3-small",
|
||||
api_key=openai_api_key,
|
||||
api_base=openai_api_base or "https://api.openai.com/v1"
|
||||
)
|
||||
if session_id is None:
|
||||
logger.info("[AgentInitializer] OpenAI embedding initialized")
|
||||
except Exception as e:
|
||||
logger.warning(f"[AgentInitializer] OpenAI embedding failed: {e}")
|
||||
|
||||
# Create memory manager
|
||||
|
||||
memory_config = MemoryConfig(workspace_root=workspace_root)
|
||||
|
||||
embedding_provider = self._init_embedding_provider(
|
||||
memory_config, session_id=session_id
|
||||
)
|
||||
|
||||
memory_manager = MemoryManager(memory_config, embedding_provider=embedding_provider)
|
||||
|
||||
# Sync memory
|
||||
self._sync_memory(memory_manager, session_id)
|
||||
|
||||
# Create memory tools
|
||||
|
||||
memory_tools = [
|
||||
MemorySearchTool(memory_manager),
|
||||
MemoryGetTool(memory_manager)
|
||||
@@ -187,6 +301,190 @@ class AgentInitializer:
|
||||
logger.warning(f"[AgentInitializer] Memory system not available: {e}")
|
||||
|
||||
return memory_manager, memory_tools
|
||||
|
||||
def _init_embedding_provider(self, memory_config, session_id: Optional[str] = None):
|
||||
"""
|
||||
Initialize the embedding provider for memory.
|
||||
|
||||
Two paths:
|
||||
A. Default (no `embedding_provider` in config.json):
|
||||
Auto-init OpenAI -> LinkAI fallback. Existing 1536-dim indices
|
||||
keep working.
|
||||
B. Explicit (`embedding_provider` is set):
|
||||
Initialize the requested vendor with unified dim (default 1024).
|
||||
If the index was built with a different dim, vector search will
|
||||
quietly return no results (cosine returns 0) and keyword search
|
||||
takes over until the user runs /memory rebuild-index.
|
||||
"""
|
||||
from agent.memory import create_embedding_provider
|
||||
from config import conf
|
||||
|
||||
explicit_provider = (conf().get("embedding_provider") or "").strip().lower()
|
||||
|
||||
if not explicit_provider:
|
||||
return self._init_embedding_provider_legacy(session_id=session_id)
|
||||
|
||||
return self._init_embedding_provider_explicit(
|
||||
memory_config, explicit_provider, session_id=session_id,
|
||||
)
|
||||
|
||||
def _init_embedding_provider_legacy(self, session_id: Optional[str] = None):
|
||||
"""Legacy auto-init path: OpenAI -> LinkAI. Preserved verbatim for compat."""
|
||||
from agent.memory import create_embedding_provider
|
||||
from config import conf
|
||||
|
||||
embedding_provider = None
|
||||
embedding_model = None
|
||||
|
||||
openai_api_key = conf().get("open_ai_api_key", "")
|
||||
openai_api_base = conf().get("open_ai_api_base", "")
|
||||
if openai_api_key and openai_api_key not in ["", "YOUR API KEY", "YOUR_API_KEY"]:
|
||||
try:
|
||||
model = "text-embedding-3-small"
|
||||
embedding_provider = create_embedding_provider(
|
||||
provider="openai",
|
||||
model=model,
|
||||
api_key=openai_api_key,
|
||||
api_base=openai_api_base or "https://api.openai.com/v1"
|
||||
)
|
||||
embedding_model = f"openai/{model}"
|
||||
except Exception as e:
|
||||
logger.warning(f"[AgentInitializer] OpenAI embedding failed: {e}")
|
||||
|
||||
if embedding_provider is None:
|
||||
linkai_api_key = conf().get("linkai_api_key", "") or os.environ.get("LINKAI_API_KEY", "")
|
||||
linkai_api_base = conf().get("linkai_api_base", "https://api.link-ai.tech")
|
||||
if linkai_api_key and linkai_api_key not in ["", "YOUR API KEY", "YOUR_API_KEY"]:
|
||||
try:
|
||||
model = "text-embedding-3-small"
|
||||
embedding_provider = create_embedding_provider(
|
||||
provider="linkai",
|
||||
model=model,
|
||||
api_key=linkai_api_key,
|
||||
api_base=f"{linkai_api_base}/v1"
|
||||
)
|
||||
embedding_model = f"linkai/{model}"
|
||||
except Exception as e:
|
||||
logger.warning(f"[AgentInitializer] LinkAI embedding failed: {e}")
|
||||
|
||||
if embedding_provider is not None and embedding_model:
|
||||
global _embedding_logged
|
||||
if not _embedding_logged:
|
||||
logger.info(
|
||||
f"[AgentInitializer] Embedding model in use: {embedding_model} "
|
||||
f"(dim={embedding_provider.dimensions})"
|
||||
)
|
||||
_embedding_logged = True
|
||||
|
||||
return embedding_provider
|
||||
|
||||
def _init_embedding_provider_explicit(
|
||||
self,
|
||||
memory_config,
|
||||
provider_key: str,
|
||||
session_id: Optional[str] = None,
|
||||
):
|
||||
"""Explicit-provider path: build the configured vendor.
|
||||
|
||||
If the index was built with a different dim, vector search will
|
||||
silently return no results (cosine returns 0 for mismatched dims)
|
||||
and keyword search takes over. Users switch vendors by running
|
||||
/memory rebuild-index — see docs.
|
||||
"""
|
||||
from agent.memory import create_embedding_provider
|
||||
from agent.memory.embedding import EMBEDDING_VENDORS
|
||||
from config import conf
|
||||
|
||||
meta = EMBEDDING_VENDORS.get(provider_key)
|
||||
if meta is None:
|
||||
logger.error(
|
||||
f"[AgentInitializer] Unknown embedding_provider '{provider_key}'. "
|
||||
f"Supported: {sorted(EMBEDDING_VENDORS.keys())}. "
|
||||
f"Memory will run in keyword-only mode."
|
||||
)
|
||||
return None
|
||||
|
||||
api_key = self._resolve_embedding_api_key(provider_key)
|
||||
api_base = self._resolve_embedding_api_base(provider_key, meta["default_base_url"])
|
||||
|
||||
if not api_key:
|
||||
logger.error(
|
||||
f"[AgentInitializer] embedding_provider='{provider_key}' is set but its "
|
||||
f"API key is missing. Memory will run in keyword-only mode."
|
||||
)
|
||||
return None
|
||||
|
||||
model = (conf().get("embedding_model") or "").strip() or meta["default_model"]
|
||||
try:
|
||||
cfg_dim = int(conf().get("embedding_dimensions") or 0)
|
||||
except (TypeError, ValueError):
|
||||
cfg_dim = 0
|
||||
dim = cfg_dim if cfg_dim > 0 else meta["default_dimensions"]
|
||||
|
||||
try:
|
||||
provider = create_embedding_provider(
|
||||
provider=provider_key,
|
||||
model=model,
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
dimensions=dim,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[AgentInitializer] Failed to init embedding provider "
|
||||
f"'{provider_key}/{model}': {e}"
|
||||
)
|
||||
return None
|
||||
|
||||
global _embedding_logged
|
||||
if not _embedding_logged:
|
||||
logger.info(
|
||||
f"[AgentInitializer] Embedding model in use: "
|
||||
f"{provider_key}/{model} (dim={provider.dimensions})"
|
||||
)
|
||||
_embedding_logged = True
|
||||
return provider
|
||||
|
||||
@staticmethod
|
||||
def _resolve_embedding_api_key(provider_key: str) -> str:
|
||||
"""Pick the API key for an explicit embedding provider from config."""
|
||||
from config import conf
|
||||
|
||||
key_map = {
|
||||
"openai": "open_ai_api_key",
|
||||
"linkai": "linkai_api_key",
|
||||
"dashscope": "dashscope_api_key",
|
||||
"doubao": "ark_api_key",
|
||||
"zhipu": "zhipu_ai_api_key",
|
||||
}
|
||||
field = key_map.get(provider_key)
|
||||
if not field:
|
||||
return ""
|
||||
value = conf().get(field, "") or ""
|
||||
if value in ["", "YOUR API KEY", "YOUR_API_KEY"]:
|
||||
return ""
|
||||
return value
|
||||
|
||||
@staticmethod
|
||||
def _resolve_embedding_api_base(provider_key: str, default_base: str) -> str:
|
||||
"""Pick the API base for an explicit embedding provider from config."""
|
||||
from config import conf
|
||||
|
||||
base_map = {
|
||||
"openai": "open_ai_api_base",
|
||||
"linkai": "linkai_api_base",
|
||||
"doubao": "ark_base_url",
|
||||
"zhipu": "zhipu_ai_api_base",
|
||||
}
|
||||
field = base_map.get(provider_key)
|
||||
if not field:
|
||||
return default_base
|
||||
value = (conf().get(field) or "").strip()
|
||||
if not value:
|
||||
return default_base
|
||||
if provider_key == "linkai" and not value.rstrip("/").endswith("/v1"):
|
||||
return f"{value.rstrip('/')}/v1"
|
||||
return value
|
||||
|
||||
def _sync_memory(self, memory_manager, session_id: Optional[str] = None):
|
||||
"""Sync memory database"""
|
||||
@@ -223,7 +521,7 @@ class AgentInitializer:
|
||||
if tool_name == "web_search":
|
||||
from agent.tools.web_search.web_search import WebSearch
|
||||
if not WebSearch.is_available():
|
||||
logger.debug("[AgentInitializer] WebSearch skipped - no BOCHA_API_KEY or LINKAI_API_KEY")
|
||||
logger.debug("[AgentInitializer] WebSearch skipped - no search provider configured")
|
||||
continue
|
||||
|
||||
# Special handling for EnvConfig tool
|
||||
@@ -234,16 +532,33 @@ class AgentInitializer:
|
||||
tool = tool_manager.create_tool(tool_name)
|
||||
|
||||
if tool:
|
||||
# Apply workspace config to file operation tools
|
||||
if tool_name in ['read', 'write', 'edit', 'bash', 'grep', 'find', 'ls']:
|
||||
tool.config = file_config
|
||||
tool.cwd = file_config.get("cwd", getattr(tool, 'cwd', None))
|
||||
if 'memory_manager' in file_config:
|
||||
tool.memory_manager = file_config['memory_manager']
|
||||
# Apply workspace config to file operation tools.
|
||||
# Merge into the existing tool.config (set by ToolManager from
|
||||
# config.json's `tools.<name>` section) instead of replacing
|
||||
# it, otherwise per-tool user configs (e.g. browser.cdp_endpoint)
|
||||
# would be silently dropped.
|
||||
if tool_name in ['read', 'write', 'edit', 'bash', 'grep', 'find', 'ls', 'web_fetch', 'send', 'browser']:
|
||||
merged_config = dict(getattr(tool, 'config', None) or {})
|
||||
merged_config.update(file_config)
|
||||
tool.config = merged_config
|
||||
tool.cwd = merged_config.get("cwd", getattr(tool, 'cwd', None))
|
||||
if 'memory_manager' in merged_config:
|
||||
tool.memory_manager = merged_config['memory_manager']
|
||||
tools.append(tool)
|
||||
except Exception as e:
|
||||
logger.warning(f"[AgentInitializer] Failed to load tool {tool_name}: {e}")
|
||||
|
||||
|
||||
# Add MCP tools (snapshot to avoid races with the background loader)
|
||||
mcp_tools_snapshot = list(tool_manager._mcp_tool_instances.items())
|
||||
if mcp_tools_snapshot:
|
||||
for _, mcp_tool in mcp_tools_snapshot:
|
||||
tools.append(mcp_tool)
|
||||
if session_id is None:
|
||||
names = [name for name, _ in mcp_tools_snapshot]
|
||||
logger.info(
|
||||
f"[AgentInitializer] Added {len(names)} MCP tool(s): {names}"
|
||||
)
|
||||
|
||||
# Add memory tools
|
||||
if memory_tools:
|
||||
tools.extend(memory_tools)
|
||||
@@ -256,16 +571,23 @@ class AgentInitializer:
|
||||
return tools
|
||||
|
||||
def _initialize_scheduler(self, tools: List, session_id: Optional[str] = None):
|
||||
"""Initialize scheduler service if needed"""
|
||||
"""Initialize scheduler service if needed.
|
||||
|
||||
Serialize the check-and-set under a module-level lock so concurrent
|
||||
first-time session inits cannot each create a new SchedulerService
|
||||
(which would leak background scanning threads).
|
||||
"""
|
||||
if not self.agent_bridge.scheduler_initialized:
|
||||
try:
|
||||
from agent.tools.scheduler.integration import init_scheduler
|
||||
if init_scheduler(self.agent_bridge):
|
||||
self.agent_bridge.scheduler_initialized = True
|
||||
if session_id is None:
|
||||
logger.info("[AgentInitializer] Scheduler service initialized")
|
||||
except Exception as e:
|
||||
logger.warning(f"[AgentInitializer] Failed to initialize scheduler: {e}")
|
||||
with _scheduler_init_lock:
|
||||
if not self.agent_bridge.scheduler_initialized:
|
||||
try:
|
||||
from agent.tools.scheduler.integration import init_scheduler
|
||||
if init_scheduler(self.agent_bridge):
|
||||
self.agent_bridge.scheduler_initialized = True
|
||||
if session_id is None:
|
||||
logger.info("[AgentInitializer] Scheduler service initialized")
|
||||
except Exception as e:
|
||||
logger.warning(f"[AgentInitializer] Failed to initialize scheduler: {e}")
|
||||
|
||||
# Inject scheduler dependencies
|
||||
if self.agent_bridge.scheduler_initialized:
|
||||
@@ -283,7 +605,14 @@ class AgentInitializer:
|
||||
tool.scheduler_service = scheduler_service
|
||||
if not tool.config:
|
||||
tool.config = {}
|
||||
tool.config["channel_type"] = conf().get("channel_type", "unknown")
|
||||
raw_ct = conf().get("channel_type", "unknown")
|
||||
if isinstance(raw_ct, list):
|
||||
ct = raw_ct[0] if raw_ct else "unknown"
|
||||
elif isinstance(raw_ct, str) and "," in raw_ct:
|
||||
ct = raw_ct.split(",")[0].strip()
|
||||
else:
|
||||
ct = raw_ct
|
||||
tool.config["channel_type"] = ct
|
||||
except Exception as e:
|
||||
logger.warning(f"[AgentInitializer] Failed to inject scheduler dependencies: {e}")
|
||||
|
||||
@@ -327,10 +656,14 @@ class AgentInitializer:
|
||||
'timezone': timezone_name
|
||||
}
|
||||
|
||||
def get_model():
|
||||
"""Get current model name dynamically from config"""
|
||||
return conf().get("model", "unknown")
|
||||
|
||||
return {
|
||||
"model": conf().get("model", "unknown"),
|
||||
"_get_model": get_model,
|
||||
"workspace": workspace_root,
|
||||
"channel": conf().get("channel_type", "unknown"),
|
||||
"channel": ", ".join(conf().get("channel_type")) if isinstance(conf().get("channel_type"), list) else conf().get("channel_type", "unknown"),
|
||||
"_get_current_time": get_current_time # Dynamic time function
|
||||
}
|
||||
|
||||
@@ -348,7 +681,7 @@ class AgentInitializer:
|
||||
|
||||
env_file = expand_path("~/.cow/.env")
|
||||
|
||||
# Read existing env vars
|
||||
# Read existing env vars (key -> value)
|
||||
existing_env_vars = {}
|
||||
if os.path.exists(env_file):
|
||||
try:
|
||||
@@ -356,35 +689,126 @@ class AgentInitializer:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line and not line.startswith('#') and '=' in line:
|
||||
key, _ = line.split('=', 1)
|
||||
existing_env_vars[key.strip()] = True
|
||||
key, val = line.split('=', 1)
|
||||
existing_env_vars[key.strip()] = val.strip()
|
||||
except Exception as e:
|
||||
logger.warning(f"[AgentInitializer] Failed to read .env file: {e}")
|
||||
|
||||
# Check which keys need migration
|
||||
keys_to_migrate = {}
|
||||
# Sync config.json values into .env (add/update/remove)
|
||||
updated = False
|
||||
for config_key, env_key in key_mapping.items():
|
||||
if env_key in existing_env_vars:
|
||||
continue
|
||||
value = conf().get(config_key, "")
|
||||
if value and value.strip():
|
||||
keys_to_migrate[env_key] = value.strip()
|
||||
|
||||
# Write new keys
|
||||
if keys_to_migrate:
|
||||
raw = conf().get(config_key, "")
|
||||
value = raw.strip() if raw else ""
|
||||
old_value = existing_env_vars.get(env_key)
|
||||
|
||||
if value:
|
||||
if old_value == value:
|
||||
continue
|
||||
existing_env_vars[env_key] = value
|
||||
os.environ[env_key] = value
|
||||
updated = True
|
||||
else:
|
||||
if old_value is None:
|
||||
continue
|
||||
existing_env_vars.pop(env_key, None)
|
||||
os.environ.pop(env_key, None)
|
||||
updated = True
|
||||
|
||||
if updated:
|
||||
try:
|
||||
env_dir = os.path.dirname(env_file)
|
||||
if not os.path.exists(env_dir):
|
||||
os.makedirs(env_dir, exist_ok=True)
|
||||
if not os.path.exists(env_file):
|
||||
open(env_file, 'a').close()
|
||||
|
||||
with open(env_file, 'a', encoding='utf-8') as f:
|
||||
f.write('\n# Auto-migrated from config.json\n')
|
||||
for key, value in keys_to_migrate.items():
|
||||
os.makedirs(env_dir, exist_ok=True)
|
||||
|
||||
# Rewrite the entire .env file to ensure consistency
|
||||
with open(env_file, 'w', encoding='utf-8') as f:
|
||||
f.write('# Environment variables for agent\n')
|
||||
f.write('# Auto-managed - synced from config.json on startup\n\n')
|
||||
for key, value in sorted(existing_env_vars.items()):
|
||||
f.write(f'{key}={value}\n')
|
||||
os.environ[key] = value
|
||||
|
||||
logger.info(f"[AgentInitializer] Migrated {len(keys_to_migrate)} API keys to .env: {list(keys_to_migrate.keys())}")
|
||||
|
||||
logger.info(f"[AgentInitializer] Synced API keys from config.json to .env")
|
||||
except Exception as e:
|
||||
logger.warning(f"[AgentInitializer] Failed to migrate API keys: {e}")
|
||||
logger.warning(f"[AgentInitializer] Failed to sync API keys: {e}")
|
||||
|
||||
def _start_daily_flush_timer(self):
|
||||
"""Start a background thread that flushes all agents' memory daily at 23:55."""
|
||||
if getattr(self.agent_bridge, '_daily_flush_started', False):
|
||||
return
|
||||
self.agent_bridge._daily_flush_started = True
|
||||
|
||||
import threading
|
||||
|
||||
def _daily_flush_loop():
|
||||
import random
|
||||
last_run_date = None # Track last successful run date to prevent same-day re-trigger
|
||||
while True:
|
||||
try:
|
||||
now = datetime.datetime.now()
|
||||
jitter_min = random.randint(50, 55)
|
||||
jitter_sec = random.randint(0, 59)
|
||||
target = now.replace(hour=23, minute=jitter_min, second=jitter_sec, microsecond=0)
|
||||
# Always schedule for tomorrow if we already ran today, or if target time has passed
|
||||
if target <= now or (last_run_date == now.date()):
|
||||
target += datetime.timedelta(days=1)
|
||||
wait_seconds = (target - now).total_seconds()
|
||||
logger.info(f"[DailyFlush] Next flush at {target.strftime('%Y-%m-%d %H:%M:%S')} (in {wait_seconds/3600:.1f}h)")
|
||||
time.sleep(wait_seconds)
|
||||
|
||||
self._flush_all_agents()
|
||||
last_run_date = datetime.datetime.now().date()
|
||||
except Exception as e:
|
||||
logger.warning(f"[DailyFlush] Error in daily flush loop: {e}")
|
||||
time.sleep(3600)
|
||||
|
||||
t = threading.Thread(target=_daily_flush_loop, daemon=True)
|
||||
t.start()
|
||||
|
||||
def _flush_all_agents(self):
|
||||
"""Flush memory for all active agent sessions, then run Deep Dream."""
|
||||
agents = []
|
||||
if self.agent_bridge.default_agent:
|
||||
agents.append(("default", self.agent_bridge.default_agent))
|
||||
for sid, agent in self.agent_bridge.agents.items():
|
||||
agents.append((sid, agent))
|
||||
|
||||
if not agents:
|
||||
return
|
||||
|
||||
# Phase 1: flush daily summaries
|
||||
flushed = 0
|
||||
flush_threads = []
|
||||
dream_candidate = None
|
||||
for label, agent in agents:
|
||||
try:
|
||||
if not agent.memory_manager:
|
||||
continue
|
||||
with agent.messages_lock:
|
||||
messages = list(agent.messages)
|
||||
if not messages:
|
||||
continue
|
||||
result = agent.memory_manager.flush_manager.create_daily_summary(messages)
|
||||
if result:
|
||||
flushed += 1
|
||||
t = agent.memory_manager.flush_manager._last_flush_thread
|
||||
if t:
|
||||
flush_threads.append(t)
|
||||
if dream_candidate is None:
|
||||
dream_candidate = agent.memory_manager.flush_manager
|
||||
except Exception as e:
|
||||
logger.warning(f"[DailyFlush] Failed for session {label}: {e}")
|
||||
|
||||
if flushed:
|
||||
logger.info(f"[DailyFlush] Flushed {flushed}/{len(agents)} agent session(s)")
|
||||
|
||||
# Wait for all flush threads to finish before dreaming
|
||||
for t in flush_threads:
|
||||
t.join(timeout=60)
|
||||
|
||||
# Phase 2: Deep Dream — distill daily memories → MEMORY.md + dream diary
|
||||
if dream_candidate:
|
||||
try:
|
||||
result = dream_candidate.deep_dream()
|
||||
if result:
|
||||
logger.info("[DeepDream] Memory distillation completed successfully")
|
||||
except Exception as e:
|
||||
logger.warning(f"[DeepDream] Failed: {e}")
|
||||
|
||||
@@ -13,8 +13,10 @@ from voice.factory import create_voice
|
||||
class Bridge(object):
|
||||
def __init__(self):
|
||||
self.btype = {
|
||||
"chat": const.CHATGPT,
|
||||
"voice_to_text": conf().get("voice_to_text", "openai"),
|
||||
"chat": const.OPENAI,
|
||||
# Empty `voice_to_text` (the default in new configs) triggers
|
||||
# the auto-pick below — see _auto_pick_voice_to_text for order.
|
||||
"voice_to_text": conf().get("voice_to_text") or self._auto_pick_voice_to_text(),
|
||||
"text_to_voice": conf().get("text_to_voice", "google"),
|
||||
"translate": conf().get("translate", "baidu"),
|
||||
}
|
||||
@@ -39,11 +41,8 @@ class Bridge(object):
|
||||
self.btype["chat"] = const.BAIDU
|
||||
if model_type in ["xunfei"]:
|
||||
self.btype["chat"] = const.XUNFEI
|
||||
if model_type in [const.QWEN]:
|
||||
self.btype["chat"] = const.QWEN
|
||||
if model_type in [const.QWEN_TURBO, const.QWEN_PLUS, const.QWEN_MAX]:
|
||||
if model_type in [const.QWEN, const.QWEN_TURBO, const.QWEN_PLUS, const.QWEN_MAX]:
|
||||
self.btype["chat"] = const.QWEN_DASHSCOPE
|
||||
# Support Qwen3 and other DashScope models
|
||||
if model_type and (model_type.startswith("qwen") or model_type.startswith("qwq") or model_type.startswith("qvq")):
|
||||
self.btype["chat"] = const.QWEN_DASHSCOPE
|
||||
if model_type and model_type.startswith("gemini"):
|
||||
@@ -61,6 +60,18 @@ class Bridge(object):
|
||||
if model_type and model_type.startswith("doubao"):
|
||||
self.btype["chat"] = const.DOUBAO
|
||||
|
||||
if model_type and model_type.startswith("deepseek"):
|
||||
self.btype["chat"] = const.DEEPSEEK
|
||||
|
||||
# 小米 MiMo 系列模型,全部以 mimo- 开头
|
||||
if model_type and model_type.startswith("mimo-"):
|
||||
self.btype["chat"] = const.MIMO
|
||||
|
||||
if model_type and isinstance(model_type, str):
|
||||
lowered_model_type = model_type.lower()
|
||||
if lowered_model_type == const.QIANFAN or lowered_model_type.startswith("ernie"):
|
||||
self.btype["chat"] = const.QIANFAN
|
||||
|
||||
if model_type in [const.MODELSCOPE]:
|
||||
self.btype["chat"] = const.MODELSCOPE
|
||||
|
||||
@@ -79,6 +90,46 @@ class Bridge(object):
|
||||
self.chat_bots = {}
|
||||
self._agent_bridge = None
|
||||
|
||||
def refresh_voice(self):
|
||||
"""Re-read voice_to_text / text_to_voice from config and drop the
|
||||
cached voice bots so the next call picks up the new provider.
|
||||
Used by the web console after the user edits voice settings.
|
||||
Does NOT touch the agent_bridge / agent state.
|
||||
"""
|
||||
new_v2t = conf().get("voice_to_text") or self._auto_pick_voice_to_text()
|
||||
new_t2v = conf().get("text_to_voice", "google")
|
||||
if conf().get("use_linkai") and conf().get("linkai_api_key"):
|
||||
if not conf().get("voice_to_text") or conf().get("voice_to_text") in ["openai"]:
|
||||
new_v2t = const.LINKAI
|
||||
if not conf().get("text_to_voice") or conf().get("text_to_voice") in ["openai", const.TTS_1, const.TTS_1_HD]:
|
||||
new_t2v = const.LINKAI
|
||||
self.btype["voice_to_text"] = new_v2t
|
||||
self.btype["text_to_voice"] = new_t2v
|
||||
self.bots.pop("voice_to_text", None)
|
||||
self.bots.pop("text_to_voice", None)
|
||||
logger.info(f"[Bridge] voice refreshed: voice_to_text={new_v2t}, text_to_voice={new_t2v}")
|
||||
|
||||
@staticmethod
|
||||
def _auto_pick_voice_to_text() -> str:
|
||||
"""Pick an ASR provider by configured api keys when voice_to_text is
|
||||
unset. Order matches the web console: openai → dashscope → zhipu →
|
||||
linkai. Falls back to 'openai' when nothing is configured so the
|
||||
original "missing key" error is preserved.
|
||||
"""
|
||||
def has(k: str) -> bool:
|
||||
v = (conf().get(k) or "").strip()
|
||||
return v != "" and v not in ("YOUR API KEY", "YOUR_API_KEY")
|
||||
|
||||
for key, provider in (
|
||||
("open_ai_api_key", "openai"),
|
||||
("dashscope_api_key", "dashscope"),
|
||||
("zhipu_ai_api_key", "zhipu"),
|
||||
("linkai_api_key", "linkai"),
|
||||
):
|
||||
if has(key):
|
||||
return provider
|
||||
return "openai"
|
||||
|
||||
# 模型对应的接口
|
||||
def get_bot(self, typename):
|
||||
if self.bots.get(typename) is None:
|
||||
|
||||
@@ -13,12 +13,38 @@ class Channel(object):
|
||||
channel_type = ""
|
||||
NOT_SUPPORT_REPLYTYPE = [ReplyType.VOICE, ReplyType.IMAGE]
|
||||
|
||||
def __init__(self):
|
||||
import threading
|
||||
self._startup_event = threading.Event()
|
||||
self._startup_error = None
|
||||
self.cloud_mode = False # set to True by ChannelManager when running with cloud client
|
||||
|
||||
def startup(self):
|
||||
"""
|
||||
init channel
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def report_startup_success(self):
|
||||
self._startup_error = None
|
||||
self._startup_event.set()
|
||||
|
||||
def report_startup_error(self, error: str):
|
||||
self._startup_error = error
|
||||
self._startup_event.set()
|
||||
|
||||
def wait_startup(self, timeout: float = 3) -> (bool, str):
|
||||
"""
|
||||
Wait for channel startup result.
|
||||
Returns (success: bool, error_msg: str).
|
||||
"""
|
||||
ready = self._startup_event.wait(timeout=timeout)
|
||||
if not ready:
|
||||
return True, ""
|
||||
if self._startup_error:
|
||||
return False, self._startup_error
|
||||
return True, ""
|
||||
|
||||
def stop(self):
|
||||
"""
|
||||
stop channel gracefully, called before restart
|
||||
@@ -47,7 +73,7 @@ class Channel(object):
|
||||
Build reply content, using agent if enabled in config
|
||||
"""
|
||||
# Check if agent mode is enabled
|
||||
use_agent = conf().get("agent", False)
|
||||
use_agent = conf().get("agent", True)
|
||||
|
||||
if use_agent:
|
||||
try:
|
||||
@@ -57,11 +83,14 @@ class Channel(object):
|
||||
if context and "channel_type" not in context:
|
||||
context["channel_type"] = self.channel_type
|
||||
|
||||
# Read on_event callback injected by the channel (e.g. web SSE)
|
||||
on_event = context.get("on_event") if context else None
|
||||
|
||||
# Use agent bridge to handle the query
|
||||
return Bridge().fetch_agent_reply(
|
||||
query=query,
|
||||
context=context,
|
||||
on_event=None,
|
||||
on_event=on_event,
|
||||
clear_history=False
|
||||
)
|
||||
except Exception as e:
|
||||
|
||||
@@ -12,16 +12,7 @@ def create_channel(channel_type) -> Channel:
|
||||
:return: channel instance
|
||||
"""
|
||||
ch = Channel()
|
||||
if channel_type == "wx":
|
||||
from channel.wechat.wechat_channel import WechatChannel
|
||||
ch = WechatChannel()
|
||||
elif channel_type == "wxy":
|
||||
from channel.wechat.wechaty_channel import WechatyChannel
|
||||
ch = WechatyChannel()
|
||||
elif channel_type == "wcf":
|
||||
from channel.wechat.wcf_channel import WechatfChannel
|
||||
ch = WechatfChannel()
|
||||
elif channel_type == "terminal":
|
||||
if channel_type == "terminal":
|
||||
from channel.terminal.terminal_channel import TerminalChannel
|
||||
ch = TerminalChannel()
|
||||
elif channel_type == 'web':
|
||||
@@ -36,15 +27,22 @@ def create_channel(channel_type) -> Channel:
|
||||
elif channel_type == "wechatcom_app":
|
||||
from channel.wechatcom.wechatcomapp_channel import WechatComAppChannel
|
||||
ch = WechatComAppChannel()
|
||||
elif channel_type == "wework":
|
||||
from channel.wework.wework_channel import WeworkChannel
|
||||
ch = WeworkChannel()
|
||||
elif channel_type == const.FEISHU:
|
||||
from channel.feishu.feishu_channel import FeiShuChanel
|
||||
ch = FeiShuChanel()
|
||||
elif channel_type == const.DINGTALK:
|
||||
from channel.dingtalk.dingtalk_channel import DingTalkChanel
|
||||
ch = DingTalkChanel()
|
||||
elif channel_type == const.WECOM_BOT:
|
||||
from channel.wecom_bot.wecom_bot_channel import WecomBotChannel
|
||||
ch = WecomBotChannel()
|
||||
elif channel_type == const.QQ:
|
||||
from channel.qq.qq_channel import QQChannel
|
||||
ch = QQChannel()
|
||||
elif channel_type in (const.WEIXIN, "wx"):
|
||||
from channel.weixin.weixin_channel import WeixinChannel
|
||||
ch = WeixinChannel()
|
||||
channel_type = const.WEIXIN
|
||||
else:
|
||||
raise RuntimeError
|
||||
ch.channel_type = channel_type
|
||||
|
||||
@@ -24,11 +24,17 @@ handler_pool = ThreadPoolExecutor(max_workers=8) # 处理消息的线程池
|
||||
class ChatChannel(Channel):
|
||||
name = None # 登录的用户名
|
||||
user_id = None # 登录的用户id
|
||||
futures = {} # 记录每个session_id提交到线程池的future对象, 用于重置会话时把没执行的future取消掉,正在执行的不会被取消
|
||||
sessions = {} # 用于控制并发,每个session_id同时只能有一个context在处理
|
||||
lock = threading.Lock() # 用于控制对sessions的访问
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
# Instance-level attributes so each channel subclass has its own
|
||||
# independent session queue and lock. Previously these were class-level,
|
||||
# which caused contexts from one channel (e.g. Feishu) to be consumed
|
||||
# by another channel's consume() thread (e.g. Web), leading to errors
|
||||
# like "No request_id found in context".
|
||||
self.futures = {}
|
||||
self.sessions = {}
|
||||
self.lock = threading.Lock()
|
||||
_thread = threading.Thread(target=self.consume)
|
||||
_thread.setDaemon(True)
|
||||
_thread.start()
|
||||
@@ -37,9 +43,8 @@ class ChatChannel(Channel):
|
||||
def _compose_context(self, ctype: ContextType, content, **kwargs):
|
||||
context = Context(ctype, content)
|
||||
context.kwargs = kwargs
|
||||
# context首次传入时,origin_ctype是None,
|
||||
# 引入的起因是:当输入语音时,会嵌套生成两个context,第一步语音转文本,第二步通过文本生成文字回复。
|
||||
# origin_ctype用于第二步文本回复时,判断是否需要匹配前缀,如果是私聊的语音,就不需要匹配前缀
|
||||
if "channel_type" not in context:
|
||||
context["channel_type"] = self.channel_type
|
||||
if "origin_ctype" not in context:
|
||||
context["origin_ctype"] = ctype
|
||||
# context首次传入时,receiver是None,根据类型设置receiver
|
||||
@@ -166,7 +171,13 @@ class ChatChannel(Channel):
|
||||
if "desire_rtype" not in context and conf().get("always_reply_voice") and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE:
|
||||
context["desire_rtype"] = ReplyType.VOICE
|
||||
elif context.type == ContextType.VOICE:
|
||||
if "desire_rtype" not in context and conf().get("voice_reply_voice") and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE:
|
||||
# Voice input replies with voice when either voice_reply_voice
|
||||
# (mirror voice) or the global always_reply_voice toggle is on.
|
||||
if (
|
||||
"desire_rtype" not in context
|
||||
and (conf().get("voice_reply_voice") or conf().get("always_reply_voice"))
|
||||
and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE
|
||||
):
|
||||
context["desire_rtype"] = ReplyType.VOICE
|
||||
return context
|
||||
|
||||
@@ -259,6 +270,8 @@ class ChatChannel(Channel):
|
||||
if reply.type == ReplyType.TEXT:
|
||||
reply_text = reply.content
|
||||
if desire_rtype == ReplyType.VOICE and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE:
|
||||
# Preserve original text for the "text-then-voice" pattern in _send_reply.
|
||||
context["voice_reply_text"] = reply.content
|
||||
reply = super().build_text_to_voice(reply.content)
|
||||
return self._decorate_reply(context, reply)
|
||||
if context.get("isgroup", False):
|
||||
@@ -292,8 +305,12 @@ class ChatChannel(Channel):
|
||||
logger.debug("[chat_channel] sending reply: {}, context: {}".format(reply, context))
|
||||
|
||||
# 如果是文本回复,尝试提取并发送图片
|
||||
if reply.type == ReplyType.TEXT:
|
||||
# Web channel renders images/videos inline via renderMarkdown,
|
||||
# so skip the extract-and-send step to avoid duplicate media.
|
||||
if reply.type == ReplyType.TEXT and context.get("channel_type") != "web":
|
||||
self._extract_and_send_images(reply, context)
|
||||
elif reply.type == ReplyType.TEXT:
|
||||
self._send(reply, context)
|
||||
# 如果是图片回复但带有文本内容,先发文本再发图片
|
||||
elif reply.type == ReplyType.IMAGE_URL and hasattr(reply, 'text_content') and reply.text_content:
|
||||
# 先发送文本
|
||||
@@ -302,6 +319,15 @@ class ChatChannel(Channel):
|
||||
# 短暂延迟后发送图片
|
||||
time.sleep(0.3)
|
||||
self._send(reply, context)
|
||||
# Send text bubble before voice, unless channel already streamed
|
||||
# the text (feishu) or natively renders STT under the voice (wechatcom).
|
||||
elif reply.type == ReplyType.VOICE and context.get("voice_reply_text") \
|
||||
and not context.get("feishu_streamed") \
|
||||
and context.get("channel_type") not in ("wechatcom_app",):
|
||||
text_reply = Reply(ReplyType.TEXT, context.get("voice_reply_text"))
|
||||
self._send(text_reply, context)
|
||||
time.sleep(0.3)
|
||||
self._send(reply, context)
|
||||
else:
|
||||
self._send(reply, context)
|
||||
|
||||
@@ -342,38 +368,30 @@ class ChatChannel(Channel):
|
||||
if media_items:
|
||||
logger.info(f"[chat_channel] Extracted {len(media_items)} media item(s) from reply")
|
||||
|
||||
# 先发送文本(保持原文本不变)
|
||||
# Send text first (the frontend will embed video players via renderMarkdown).
|
||||
logger.info(f"[chat_channel] Sending text content before media: {reply.content[:100]}...")
|
||||
self._send(reply, context)
|
||||
logger.info(f"[chat_channel] Text sent, now sending {len(media_items)} media item(s)")
|
||||
|
||||
# 然后逐个发送媒体文件
|
||||
for i, (url, media_type) in enumerate(media_items):
|
||||
try:
|
||||
# 判断是本地文件还是URL
|
||||
# Determine whether it is a remote URL or a local file.
|
||||
if url.startswith(('http://', 'https://')):
|
||||
# 网络资源
|
||||
if media_type == 'video':
|
||||
# 视频使用 FILE 类型发送
|
||||
media_reply = Reply(ReplyType.FILE, url)
|
||||
media_reply.file_name = os.path.basename(url)
|
||||
else:
|
||||
# 图片使用 IMAGE_URL 类型
|
||||
media_reply = Reply(ReplyType.IMAGE_URL, url)
|
||||
elif os.path.exists(url):
|
||||
# 本地文件
|
||||
if media_type == 'video':
|
||||
# 视频使用 FILE 类型,转换为 file:// URL
|
||||
media_reply = Reply(ReplyType.FILE, f"file://{url}")
|
||||
media_reply.file_name = os.path.basename(url)
|
||||
else:
|
||||
# 图片使用 IMAGE_URL 类型,转换为 file:// URL
|
||||
media_reply = Reply(ReplyType.IMAGE_URL, f"file://{url}")
|
||||
else:
|
||||
logger.warning(f"[chat_channel] Media file not found or invalid URL: {url}")
|
||||
continue
|
||||
|
||||
# 发送媒体文件(添加小延迟避免频率限制)
|
||||
if i > 0:
|
||||
time.sleep(0.5)
|
||||
self._send(media_reply, context)
|
||||
@@ -420,19 +438,55 @@ class ChatChannel(Channel):
|
||||
|
||||
return func
|
||||
|
||||
# Chat commands that must bypass the per-session serial queue,
|
||||
# otherwise /cancel would queue behind the task it tries to cancel.
|
||||
# Use /cancel (not /stop) to avoid colliding with `cow stop` CLI.
|
||||
_BYPASS_QUEUE_COMMANDS = ("/cancel",)
|
||||
|
||||
def produce(self, context: Context):
|
||||
session_id = context["session_id"]
|
||||
|
||||
# Fast path: /cancel must not enter the queue.
|
||||
if context.type == ContextType.TEXT and context.content:
|
||||
stripped = context.content.strip().lower()
|
||||
if stripped in self._BYPASS_QUEUE_COMMANDS:
|
||||
self._handle_cancel_command(context, session_id)
|
||||
return
|
||||
|
||||
with self.lock:
|
||||
if session_id not in self.sessions:
|
||||
self.sessions[session_id] = [
|
||||
Dequeue(),
|
||||
threading.BoundedSemaphore(conf().get("concurrency_in_session", 4)),
|
||||
threading.BoundedSemaphore(conf().get("concurrency_in_session", 1)),
|
||||
]
|
||||
if context.type == ContextType.TEXT and context.content.startswith("#"):
|
||||
self.sessions[session_id][0].putleft(context) # 优先处理管理命令
|
||||
else:
|
||||
self.sessions[session_id][0].put(context)
|
||||
|
||||
def _handle_cancel_command(self, context: Context, session_id: str) -> None:
|
||||
"""Cancel any in-flight agent run for *session_id* and reply inline.
|
||||
|
||||
Runs synchronously on the caller's thread. Reply is sent through
|
||||
_send_reply so plugins (e.g. logging) still observe it.
|
||||
"""
|
||||
try:
|
||||
from agent.protocol import get_cancel_registry
|
||||
from bridge.reply import Reply, ReplyType
|
||||
|
||||
cancelled = get_cancel_registry().cancel_session(session_id)
|
||||
text = (
|
||||
"🛑 已中止"
|
||||
if cancelled > 0
|
||||
else "当前没有可中止的任务。"
|
||||
)
|
||||
logger.info(
|
||||
f"[chat_channel] /cancel fast-path: session={session_id}, cancelled={cancelled}"
|
||||
)
|
||||
self._send_reply(context, Reply(ReplyType.TEXT, text))
|
||||
except Exception as e:
|
||||
logger.warning(f"[chat_channel] /cancel fast-path failed: {e}")
|
||||
|
||||
# 消费者函数,单独线程,用于从消息队列中取出消息并处理
|
||||
def consume(self):
|
||||
while True:
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
"""
|
||||
本类表示聊天消息,用于对itchat和wechaty的消息进行统一的封装。
|
||||
Unified chat message class for different channel implementations.
|
||||
|
||||
填好必填项(群聊6个,非群聊8个),即可接入ChatChannel,并支持插件,参考TerminalChannel
|
||||
|
||||
|
||||
@@ -86,6 +86,8 @@ def _check(func):
|
||||
|
||||
@singleton
|
||||
class DingTalkChanel(ChatChannel, dingtalk_stream.ChatbotHandler):
|
||||
NOT_SUPPORT_REPLYTYPE = []
|
||||
|
||||
dingtalk_client_id = conf().get('dingtalk_client_id')
|
||||
dingtalk_client_secret = conf().get('dingtalk_client_secret')
|
||||
|
||||
@@ -101,6 +103,8 @@ class DingTalkChanel(ChatChannel, dingtalk_stream.ChatbotHandler):
|
||||
# 历史消息id暂存,用于幂等控制
|
||||
self.receivedMsgs = ExpiredDict(conf().get("expires_in_seconds", 3600))
|
||||
self._stream_client = None
|
||||
self._running = False
|
||||
self._event_loop = None
|
||||
logger.debug("[DingTalk] client_id={}, client_secret={} ".format(
|
||||
self.dingtalk_client_id, self.dingtalk_client_secret))
|
||||
# 无需群校验和前缀
|
||||
@@ -113,22 +117,130 @@ class DingTalkChanel(ChatChannel, dingtalk_stream.ChatbotHandler):
|
||||
# Robot code cache (extracted from incoming messages)
|
||||
self._robot_code = None
|
||||
|
||||
def _open_connection(self, client):
|
||||
"""
|
||||
Open a DingTalk stream connection directly, bypassing SDK's internal error-swallowing.
|
||||
Returns (connection_dict, error_str). On success error_str is empty; on failure
|
||||
connection_dict is None and error_str contains a human-readable message.
|
||||
"""
|
||||
try:
|
||||
resp = requests.post(
|
||||
"https://api.dingtalk.com/v1.0/gateway/connections/open",
|
||||
headers={"Content-Type": "application/json", "Accept": "application/json"},
|
||||
json={
|
||||
"clientId": client.credential.client_id,
|
||||
"clientSecret": client.credential.client_secret,
|
||||
"subscriptions": [{"type": "CALLBACK",
|
||||
"topic": dingtalk_stream.chatbot.ChatbotMessage.TOPIC}],
|
||||
"ua": "dingtalk-sdk-python/cow",
|
||||
"localIp": "",
|
||||
},
|
||||
timeout=10,
|
||||
)
|
||||
body = resp.json()
|
||||
if not resp.ok:
|
||||
code = body.get("code", resp.status_code)
|
||||
message = body.get("message", resp.reason)
|
||||
return None, f"open connection failed: [{code}] {message}"
|
||||
return body, ""
|
||||
except Exception as e:
|
||||
return None, f"open connection failed: {e}"
|
||||
|
||||
def startup(self):
|
||||
import asyncio
|
||||
self.dingtalk_client_id = conf().get('dingtalk_client_id')
|
||||
self.dingtalk_client_secret = conf().get('dingtalk_client_secret')
|
||||
self._running = True
|
||||
credential = dingtalk_stream.Credential(self.dingtalk_client_id, self.dingtalk_client_secret)
|
||||
client = dingtalk_stream.DingTalkStreamClient(credential)
|
||||
self._stream_client = client
|
||||
client.register_callback_handler(dingtalk_stream.chatbot.ChatbotMessage.TOPIC, self)
|
||||
logger.info("[DingTalk] ✅ Stream connected, ready to receive messages")
|
||||
client.start_forever()
|
||||
logger.info("[DingTalk] ✅ Stream client initialized, ready to receive messages")
|
||||
|
||||
# Run the connection loop ourselves instead of delegating to client.start(),
|
||||
# so we can get detailed error messages and respond to stop() quickly.
|
||||
import urllib.parse as _urlparse
|
||||
import websockets as _ws
|
||||
import json as _json
|
||||
client.pre_start()
|
||||
_first_connect = True
|
||||
while self._running:
|
||||
# Open connection using our own request so we get detailed error info.
|
||||
connection, err_msg = self._open_connection(client)
|
||||
|
||||
if connection is None:
|
||||
if _first_connect:
|
||||
logger.warning(f"[DingTalk] {err_msg}")
|
||||
self.report_startup_error(err_msg)
|
||||
_first_connect = False
|
||||
else:
|
||||
logger.warning(f"[DingTalk] {err_msg}, retrying in 10s...")
|
||||
|
||||
# Interruptible sleep: checks _running every 100ms.
|
||||
for _ in range(100):
|
||||
if not self._running:
|
||||
break
|
||||
time.sleep(0.1)
|
||||
continue
|
||||
|
||||
if _first_connect:
|
||||
logger.info("[DingTalk] ✅ Connected to DingTalk stream")
|
||||
self.report_startup_success()
|
||||
_first_connect = False
|
||||
else:
|
||||
logger.info("[DingTalk] Reconnected to DingTalk stream")
|
||||
|
||||
# Run the WebSocket session in an asyncio loop.
|
||||
uri = '%s?ticket=%s' % (
|
||||
connection['endpoint'],
|
||||
_urlparse.quote_plus(connection['ticket'])
|
||||
)
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
self._event_loop = loop
|
||||
try:
|
||||
async def _session():
|
||||
async with _ws.connect(uri) as websocket:
|
||||
client.websocket = websocket
|
||||
async for raw_message in websocket:
|
||||
json_message = _json.loads(raw_message)
|
||||
result = await client.route_message(json_message)
|
||||
if result == dingtalk_stream.DingTalkStreamClient.TAG_DISCONNECT:
|
||||
break
|
||||
|
||||
loop.run_until_complete(_session())
|
||||
except (KeyboardInterrupt, SystemExit):
|
||||
logger.info("[DingTalk] Session loop received stop signal, exiting")
|
||||
break
|
||||
except Exception as e:
|
||||
if not self._running:
|
||||
break
|
||||
logger.warning(f"[DingTalk] Stream session error: {e}, reconnecting in 3s...")
|
||||
for _ in range(30):
|
||||
if not self._running:
|
||||
break
|
||||
time.sleep(0.1)
|
||||
finally:
|
||||
self._event_loop = None
|
||||
try:
|
||||
loop.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
logger.info("[DingTalk] Startup loop exited")
|
||||
|
||||
def stop(self):
|
||||
if self._stream_client:
|
||||
logger.info("[DingTalk] stop() called, setting _running=False")
|
||||
self._running = False
|
||||
loop = self._event_loop
|
||||
if loop and not loop.is_closed():
|
||||
try:
|
||||
self._stream_client.stop()
|
||||
logger.info("[DingTalk] Stream client stopped")
|
||||
loop.call_soon_threadsafe(loop.stop)
|
||||
logger.info("[DingTalk] Sent stop signal to event loop")
|
||||
except Exception as e:
|
||||
logger.warning(f"[DingTalk] Error stopping stream client: {e}")
|
||||
self._stream_client = None
|
||||
logger.warning(f"[DingTalk] Error stopping event loop: {e}")
|
||||
self._stream_client = None
|
||||
logger.info("[DingTalk] stop() completed")
|
||||
|
||||
def get_access_token(self):
|
||||
"""
|
||||
@@ -465,23 +577,21 @@ class DingTalkChanel(ChatChannel, dingtalk_stream.ChatbotHandler):
|
||||
async def process(self, callback: dingtalk_stream.CallbackMessage):
|
||||
try:
|
||||
incoming_message = dingtalk_stream.ChatbotMessage.from_dict(callback.data)
|
||||
|
||||
|
||||
# 缓存 robot_code,用于后续图片下载
|
||||
if hasattr(incoming_message, 'robot_code'):
|
||||
self._robot_code_cache = incoming_message.robot_code
|
||||
|
||||
# Debug: 打印完整的 event 数据
|
||||
logger.debug(f"[DingTalk] ===== Incoming Message Debug =====")
|
||||
logger.debug(f"[DingTalk] callback.data keys: {callback.data.keys() if hasattr(callback.data, 'keys') else 'N/A'}")
|
||||
logger.debug(f"[DingTalk] incoming_message attributes: {dir(incoming_message)}")
|
||||
logger.debug(f"[DingTalk] robot_code: {getattr(incoming_message, 'robot_code', 'N/A')}")
|
||||
logger.debug(f"[DingTalk] chatbot_corp_id: {getattr(incoming_message, 'chatbot_corp_id', 'N/A')}")
|
||||
logger.debug(f"[DingTalk] chatbot_user_id: {getattr(incoming_message, 'chatbot_user_id', 'N/A')}")
|
||||
logger.debug(f"[DingTalk] conversation_id: {getattr(incoming_message, 'conversation_id', 'N/A')}")
|
||||
logger.debug(f"[DingTalk] Raw callback.data: {callback.data}")
|
||||
logger.debug(f"[DingTalk] =====================================")
|
||||
|
||||
image_download_handler = self # 传入方法所在的类实例
|
||||
|
||||
# Filter out stale messages from before channel startup (offline backlog)
|
||||
create_at = getattr(incoming_message, 'create_at', None)
|
||||
if create_at:
|
||||
msg_age_s = time.time() - int(create_at) / 1000
|
||||
if msg_age_s > 60:
|
||||
logger.warning(f"[DingTalk] stale msg filtered (age={msg_age_s:.0f}s), "
|
||||
f"msg_id={getattr(incoming_message, 'message_id', 'N/A')}")
|
||||
return AckMessage.STATUS_OK, 'OK'
|
||||
|
||||
image_download_handler = self
|
||||
dingtalk_msg = DingTalkMessage(incoming_message, image_download_handler)
|
||||
|
||||
if dingtalk_msg.is_group:
|
||||
@@ -490,8 +600,7 @@ class DingTalkChanel(ChatChannel, dingtalk_stream.ChatbotHandler):
|
||||
self.handle_single(dingtalk_msg)
|
||||
return AckMessage.STATUS_OK, 'OK'
|
||||
except Exception as e:
|
||||
logger.error(f"[DingTalk] process error: {e}")
|
||||
logger.exception(e) # 打印完整堆栈跟踪
|
||||
logger.error(f"[DingTalk] process error: {e}", exc_info=True)
|
||||
return AckMessage.STATUS_SYSTEM_EXCEPTION, 'ERROR'
|
||||
|
||||
@time_checker
|
||||
@@ -763,6 +872,48 @@ class DingTalkChanel(ChatChannel, dingtalk_stream.ChatbotHandler):
|
||||
self.reply_text("抱歉,文件上传失败", incoming_message)
|
||||
return
|
||||
|
||||
# Native sampleAudio. Upload only accepts ogg/amr, so convert TTS mp3/wav to amr.
|
||||
elif reply.type == ReplyType.VOICE:
|
||||
logger.info(f"[DingTalk] Sending voice: {reply.content}")
|
||||
access_token = self.get_access_token()
|
||||
if not access_token:
|
||||
logger.error("[DingTalk] Cannot get access token for voice")
|
||||
self.reply_text("抱歉,语音发送失败(无法获取token)", incoming_message)
|
||||
return
|
||||
|
||||
voice_path = reply.content
|
||||
if voice_path.startswith("file://"):
|
||||
voice_path = voice_path[7:]
|
||||
|
||||
amr_path = voice_path
|
||||
duration_ms = 0
|
||||
if not voice_path.lower().endswith((".amr", ".ogg")):
|
||||
try:
|
||||
from voice.audio_convert import any_to_amr
|
||||
amr_path = os.path.splitext(voice_path)[0] + ".amr"
|
||||
duration_ms = int(any_to_amr(voice_path, amr_path) or 0)
|
||||
except Exception as e:
|
||||
logger.error(f"[DingTalk] Failed to convert voice to amr: {e}")
|
||||
self.reply_text("抱歉,语音转码失败", incoming_message)
|
||||
return
|
||||
|
||||
media_id = self.upload_media(amr_path, media_type="voice")
|
||||
if not media_id:
|
||||
logger.error("[DingTalk] Failed to upload voice media")
|
||||
self.reply_text("抱歉,语音上传失败", incoming_message)
|
||||
return
|
||||
|
||||
msg_param = {
|
||||
"mediaId": media_id,
|
||||
"duration": str(duration_ms or 1000),
|
||||
}
|
||||
success = self._send_file_message(
|
||||
access_token, incoming_message, "sampleAudio", msg_param, isgroup
|
||||
)
|
||||
if not success:
|
||||
self.reply_text("抱歉,语音发送失败", incoming_message)
|
||||
return
|
||||
|
||||
# 处理文本消息
|
||||
elif reply.type == ReplyType.TEXT:
|
||||
logger.info(f"[DingTalk] Sending text message, length={len(reply.content)}")
|
||||
|
||||
@@ -144,7 +144,14 @@ class FeishuMessage(ChatMessage):
|
||||
file_key = content.get("file_key")
|
||||
file_name = content.get("file_name")
|
||||
|
||||
self.content = TmpDir().path() + file_key + "." + utils.get_path_suffix(file_name)
|
||||
# 落到 agent_workspace/tmp 下(绝对路径),与图片处理一致;
|
||||
# 否则相对路径 ./tmp 在 agent 工作区里 read 时会找不到。
|
||||
workspace_root = expand_path(conf().get("agent_workspace", "~/cow"))
|
||||
tmp_dir = os.path.join(workspace_root, "tmp")
|
||||
os.makedirs(tmp_dir, exist_ok=True)
|
||||
self.content = os.path.join(
|
||||
tmp_dir, f"{file_key}.{utils.get_path_suffix(file_name)}"
|
||||
)
|
||||
|
||||
def _download_file():
|
||||
# 如果响应状态码是200,则将响应内容写入本地文件
|
||||
@@ -162,6 +169,42 @@ class FeishuMessage(ChatMessage):
|
||||
else:
|
||||
logger.info(f"[FeiShu] Failed to download file, key={file_key}, res={response.text}")
|
||||
self._prepare_fn = _download_file
|
||||
elif msg_type == "audio":
|
||||
# 飞书用户发送的语音消息类型为 "audio",文件为 opus 编码格式。
|
||||
# 映射为 ContextType.VOICE,交由 chat_channel 的语音转文字(STT)流程处理。
|
||||
# 文件通过 _prepare_fn 延迟下载,在 chat_channel 调用 cmsg.prepare() 时才执行。
|
||||
self.ctype = ContextType.VOICE
|
||||
content = json.loads(msg.get("content"))
|
||||
file_key = content.get("file_key")
|
||||
|
||||
# 落到 agent_workspace/tmp 下(绝对路径),保证语音 STT 流程可读到
|
||||
workspace_root = expand_path(conf().get("agent_workspace", "~/cow"))
|
||||
tmp_dir = os.path.join(workspace_root, "tmp")
|
||||
os.makedirs(tmp_dir, exist_ok=True)
|
||||
self.content = os.path.join(tmp_dir, f"{file_key}.opus")
|
||||
logger.info(f"[FeiShu] audio message: file_key={file_key}, save_path={self.content}")
|
||||
|
||||
def _download_audio():
|
||||
logger.info(f"[FeiShu] downloading audio: file_key={file_key}, msg_id={self.msg_id}")
|
||||
url = f"https://open.feishu.cn/open-apis/im/v1/messages/{self.msg_id}/resources/{file_key}"
|
||||
headers = {
|
||||
"Authorization": "Bearer " + access_token,
|
||||
}
|
||||
params = {
|
||||
"type": "file"
|
||||
}
|
||||
try:
|
||||
response = requests.get(url=url, headers=headers, params=params)
|
||||
logger.info(f"[FeiShu] download audio response: status={response.status_code}, size={len(response.content)} bytes")
|
||||
if response.status_code == 200:
|
||||
with open(self.content, "wb") as f:
|
||||
f.write(response.content)
|
||||
logger.info(f"[FeiShu] audio saved to: {self.content}")
|
||||
else:
|
||||
logger.error(f"[FeiShu] Failed to download audio, key={file_key}, status={response.status_code}, res={response.text}")
|
||||
except Exception as e:
|
||||
logger.error(f"[FeiShu] Exception downloading audio, key={file_key}: {e}", exc_info=True)
|
||||
self._prepare_fn = _download_audio
|
||||
else:
|
||||
raise NotImplementedError("Unsupported message type: Type:{} ".format(msg_type))
|
||||
|
||||
|
||||
0
channel/qq/__init__.py
Normal file
736
channel/qq/qq_channel.py
Normal file
@@ -0,0 +1,736 @@
|
||||
"""
|
||||
QQ Bot channel via WebSocket long connection.
|
||||
|
||||
Supports:
|
||||
- Group chat (@bot), single chat (C2C), guild channel, guild DM
|
||||
- Text / image / file message send & receive
|
||||
- Heartbeat keep-alive and auto-reconnect with session resume
|
||||
"""
|
||||
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
|
||||
import requests
|
||||
import websocket
|
||||
|
||||
from bridge.context import Context, ContextType
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from channel.chat_channel import ChatChannel, check_prefix
|
||||
from channel.qq.qq_message import QQMessage
|
||||
from common.expired_dict import ExpiredDict
|
||||
from common.log import logger
|
||||
from common.singleton import singleton
|
||||
from common.ws_client_compat import websocket_app_run_forever
|
||||
from config import conf
|
||||
|
||||
# Rich media file_type constants
|
||||
QQ_FILE_TYPE_IMAGE = 1
|
||||
QQ_FILE_TYPE_VIDEO = 2
|
||||
QQ_FILE_TYPE_VOICE = 3
|
||||
QQ_FILE_TYPE_FILE = 4
|
||||
|
||||
QQ_API_BASE = "https://api.sgroup.qq.com"
|
||||
|
||||
# Intents: GROUP_AND_C2C_EVENT(1<<25) | PUBLIC_GUILD_MESSAGES(1<<30)
|
||||
DEFAULT_INTENTS = (1 << 25) | (1 << 30)
|
||||
|
||||
# OpCode constants
|
||||
OP_DISPATCH = 0
|
||||
OP_HEARTBEAT = 1
|
||||
OP_IDENTIFY = 2
|
||||
OP_RESUME = 6
|
||||
OP_RECONNECT = 7
|
||||
OP_INVALID_SESSION = 9
|
||||
OP_HELLO = 10
|
||||
OP_HEARTBEAT_ACK = 11
|
||||
|
||||
# Resumable error codes
|
||||
RESUMABLE_CLOSE_CODES = {4008, 4009}
|
||||
|
||||
|
||||
@singleton
|
||||
class QQChannel(ChatChannel):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.app_id = ""
|
||||
self.app_secret = ""
|
||||
|
||||
self._access_token = ""
|
||||
self._token_expires_at = 0
|
||||
|
||||
self._ws = None
|
||||
self._ws_thread = None
|
||||
self._heartbeat_thread = None
|
||||
self._connected = False
|
||||
self._stop_event = threading.Event()
|
||||
self._token_lock = threading.Lock()
|
||||
|
||||
self._session_id = None
|
||||
self._last_seq = None
|
||||
self._heartbeat_interval = 45000
|
||||
self._can_resume = False
|
||||
|
||||
self.received_msgs = ExpiredDict(60 * 60 * 7.1)
|
||||
self._msg_seq_counter = {}
|
||||
|
||||
conf()["group_name_white_list"] = ["ALL_GROUP"]
|
||||
conf()["single_chat_prefix"] = [""]
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Lifecycle
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def startup(self):
|
||||
self.app_id = conf().get("qq_app_id", "")
|
||||
self.app_secret = conf().get("qq_app_secret", "")
|
||||
|
||||
if not self.app_id or not self.app_secret:
|
||||
err = "[QQ] qq_app_id and qq_app_secret are required"
|
||||
logger.error(err)
|
||||
self.report_startup_error(err)
|
||||
return
|
||||
|
||||
self._refresh_access_token()
|
||||
if not self._access_token:
|
||||
err = "[QQ] Failed to get initial access_token"
|
||||
logger.error(err)
|
||||
self.report_startup_error(err)
|
||||
return
|
||||
|
||||
self._stop_event.clear()
|
||||
self._start_ws()
|
||||
|
||||
def stop(self):
|
||||
logger.info("[QQ] stop() called")
|
||||
self._stop_event.set()
|
||||
if self._ws:
|
||||
try:
|
||||
self._ws.close()
|
||||
except Exception:
|
||||
pass
|
||||
self._ws = None
|
||||
self._connected = False
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Access Token
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _refresh_access_token(self):
|
||||
try:
|
||||
resp = requests.post(
|
||||
"https://bots.qq.com/app/getAppAccessToken",
|
||||
json={"appId": self.app_id, "clientSecret": self.app_secret},
|
||||
timeout=10,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
self._access_token = data.get("access_token", "")
|
||||
expires_in = int(data.get("expires_in", 7200))
|
||||
self._token_expires_at = time.time() + expires_in - 60
|
||||
logger.debug(f"[QQ] Access token refreshed, expires_in={expires_in}s")
|
||||
except Exception as e:
|
||||
logger.error(f"[QQ] Failed to refresh access_token: {e}")
|
||||
|
||||
def _get_access_token(self) -> str:
|
||||
with self._token_lock:
|
||||
if time.time() >= self._token_expires_at:
|
||||
self._refresh_access_token()
|
||||
return self._access_token
|
||||
|
||||
def _get_auth_headers(self) -> dict:
|
||||
return {
|
||||
"Authorization": f"QQBot {self._get_access_token()}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# WebSocket connection
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _get_ws_url(self) -> str:
|
||||
try:
|
||||
resp = requests.get(
|
||||
f"{QQ_API_BASE}/gateway",
|
||||
headers=self._get_auth_headers(),
|
||||
timeout=10,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
url = resp.json().get("url", "")
|
||||
logger.debug(f"[QQ] Gateway URL: {url}")
|
||||
return url
|
||||
except Exception as e:
|
||||
logger.error(f"[QQ] Failed to get gateway URL: {e}")
|
||||
return ""
|
||||
|
||||
def _start_ws(self):
|
||||
ws_url = self._get_ws_url()
|
||||
if not ws_url:
|
||||
logger.error("[QQ] Cannot start WebSocket without gateway URL")
|
||||
self.report_startup_error("Failed to get gateway URL")
|
||||
return
|
||||
|
||||
def _on_open(ws):
|
||||
logger.debug("[QQ] WebSocket connected, waiting for Hello...")
|
||||
|
||||
def _on_message(ws, raw):
|
||||
try:
|
||||
data = json.loads(raw)
|
||||
self._handle_ws_message(data)
|
||||
except Exception as e:
|
||||
logger.error(f"[QQ] Failed to handle ws message: {e}", exc_info=True)
|
||||
|
||||
def _on_error(ws, error):
|
||||
logger.error(f"[QQ] WebSocket error: {error}")
|
||||
|
||||
def _on_close(ws, close_status_code, close_msg):
|
||||
logger.warning(f"[QQ] WebSocket closed: status={close_status_code}, msg={close_msg}")
|
||||
self._connected = False
|
||||
if not self._stop_event.is_set():
|
||||
if close_status_code in RESUMABLE_CLOSE_CODES and self._session_id:
|
||||
self._can_resume = True
|
||||
logger.info("[QQ] Will attempt resume in 3s...")
|
||||
time.sleep(3)
|
||||
else:
|
||||
self._can_resume = False
|
||||
logger.info("[QQ] Will reconnect in 5s...")
|
||||
time.sleep(5)
|
||||
if not self._stop_event.is_set():
|
||||
self._start_ws()
|
||||
|
||||
self._ws = websocket.WebSocketApp(
|
||||
ws_url,
|
||||
on_open=_on_open,
|
||||
on_message=_on_message,
|
||||
on_error=_on_error,
|
||||
on_close=_on_close,
|
||||
)
|
||||
|
||||
def run_forever():
|
||||
try:
|
||||
websocket_app_run_forever(self._ws, ping_interval=0, reconnect=0)
|
||||
except (SystemExit, KeyboardInterrupt):
|
||||
logger.info("[QQ] WebSocket thread interrupted")
|
||||
except Exception as e:
|
||||
logger.error(f"[QQ] WebSocket run_forever error: {e}")
|
||||
|
||||
self._ws_thread = threading.Thread(target=run_forever, daemon=True)
|
||||
self._ws_thread.start()
|
||||
self._ws_thread.join()
|
||||
|
||||
def _ws_send(self, data: dict):
|
||||
if self._ws:
|
||||
self._ws.send(json.dumps(data, ensure_ascii=False))
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Identify & Resume & Heartbeat
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _send_identify(self):
|
||||
self._ws_send({
|
||||
"op": OP_IDENTIFY,
|
||||
"d": {
|
||||
"token": f"QQBot {self._get_access_token()}",
|
||||
"intents": DEFAULT_INTENTS,
|
||||
"shard": [0, 1],
|
||||
"properties": {
|
||||
"$os": "linux",
|
||||
"$browser": "chatgpt-on-wechat",
|
||||
"$device": "chatgpt-on-wechat",
|
||||
},
|
||||
},
|
||||
})
|
||||
logger.debug(f"[QQ] Identify sent with intents={DEFAULT_INTENTS}")
|
||||
|
||||
def _send_resume(self):
|
||||
self._ws_send({
|
||||
"op": OP_RESUME,
|
||||
"d": {
|
||||
"token": f"QQBot {self._get_access_token()}",
|
||||
"session_id": self._session_id,
|
||||
"seq": self._last_seq,
|
||||
},
|
||||
})
|
||||
logger.debug(f"[QQ] Resume sent: session_id={self._session_id}, seq={self._last_seq}")
|
||||
|
||||
def _start_heartbeat(self, interval_ms: int):
|
||||
if self._heartbeat_thread and self._heartbeat_thread.is_alive():
|
||||
return
|
||||
self._heartbeat_interval = interval_ms
|
||||
interval_sec = interval_ms / 1000.0
|
||||
|
||||
def heartbeat_loop():
|
||||
while not self._stop_event.is_set() and self._connected:
|
||||
try:
|
||||
self._ws_send({
|
||||
"op": OP_HEARTBEAT,
|
||||
"d": self._last_seq,
|
||||
})
|
||||
except Exception as e:
|
||||
logger.warning(f"[QQ] Heartbeat send failed: {e}")
|
||||
break
|
||||
self._stop_event.wait(interval_sec)
|
||||
|
||||
self._heartbeat_thread = threading.Thread(target=heartbeat_loop, daemon=True)
|
||||
self._heartbeat_thread.start()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Incoming message dispatch
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _handle_ws_message(self, data: dict):
|
||||
op = data.get("op")
|
||||
d = data.get("d")
|
||||
t = data.get("t")
|
||||
s = data.get("s")
|
||||
|
||||
if s is not None:
|
||||
self._last_seq = s
|
||||
|
||||
if op == OP_HELLO:
|
||||
heartbeat_interval = d.get("heartbeat_interval", 45000) if d else 45000
|
||||
logger.debug(f"[QQ] Received Hello, heartbeat_interval={heartbeat_interval}ms")
|
||||
self._heartbeat_interval = heartbeat_interval
|
||||
if self._can_resume and self._session_id:
|
||||
self._send_resume()
|
||||
else:
|
||||
self._send_identify()
|
||||
|
||||
elif op == OP_HEARTBEAT_ACK:
|
||||
pass
|
||||
|
||||
elif op == OP_HEARTBEAT:
|
||||
self._ws_send({"op": OP_HEARTBEAT, "d": self._last_seq})
|
||||
|
||||
elif op == OP_RECONNECT:
|
||||
logger.warning("[QQ] Server requested reconnect")
|
||||
self._can_resume = True
|
||||
if self._ws:
|
||||
self._ws.close()
|
||||
|
||||
elif op == OP_INVALID_SESSION:
|
||||
logger.warning("[QQ] Invalid session, re-identifying...")
|
||||
self._session_id = None
|
||||
self._can_resume = False
|
||||
time.sleep(2)
|
||||
self._send_identify()
|
||||
|
||||
elif op == OP_DISPATCH:
|
||||
if t == "READY":
|
||||
self._session_id = d.get("session_id", "")
|
||||
user = d.get("user", {})
|
||||
bot_name = user.get('username', '')
|
||||
logger.info(f"[QQ] ✅ Connected successfully (bot={bot_name})")
|
||||
self._connected = True
|
||||
self._can_resume = False
|
||||
self._start_heartbeat(self._heartbeat_interval)
|
||||
self.report_startup_success()
|
||||
|
||||
elif t == "RESUMED":
|
||||
logger.info("[QQ] Session resumed successfully")
|
||||
self._connected = True
|
||||
self._can_resume = False
|
||||
self._start_heartbeat(self._heartbeat_interval)
|
||||
|
||||
elif t in ("GROUP_AT_MESSAGE_CREATE", "C2C_MESSAGE_CREATE",
|
||||
"AT_MESSAGE_CREATE", "DIRECT_MESSAGE_CREATE"):
|
||||
self._handle_msg_event(d, t)
|
||||
|
||||
elif t in ("GROUP_ADD_ROBOT", "FRIEND_ADD"):
|
||||
logger.info(f"[QQ] Event: {t}")
|
||||
|
||||
else:
|
||||
logger.debug(f"[QQ] Dispatch event: {t}")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Message event handling
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _handle_msg_event(self, event_data: dict, event_type: str):
|
||||
msg_id = event_data.get("id", "")
|
||||
if self.received_msgs.get(msg_id):
|
||||
logger.debug(f"[QQ] Duplicate msg filtered: {msg_id}")
|
||||
return
|
||||
self.received_msgs[msg_id] = True
|
||||
|
||||
try:
|
||||
qq_msg = QQMessage(event_data, event_type)
|
||||
except NotImplementedError as e:
|
||||
logger.warning(f"[QQ] {e}")
|
||||
return
|
||||
except Exception as e:
|
||||
logger.error(f"[QQ] Failed to parse message: {e}", exc_info=True)
|
||||
return
|
||||
|
||||
is_group = qq_msg.is_group
|
||||
|
||||
from channel.file_cache import get_file_cache
|
||||
file_cache = get_file_cache()
|
||||
|
||||
if is_group:
|
||||
session_id = qq_msg.other_user_id
|
||||
else:
|
||||
session_id = qq_msg.from_user_id
|
||||
|
||||
if qq_msg.ctype == ContextType.IMAGE:
|
||||
if hasattr(qq_msg, "image_path") and qq_msg.image_path:
|
||||
file_cache.add(session_id, qq_msg.image_path, file_type="image")
|
||||
logger.info(f"[QQ] Image cached for session {session_id}")
|
||||
return
|
||||
|
||||
if qq_msg.ctype == ContextType.TEXT:
|
||||
cached_files = file_cache.get(session_id)
|
||||
if cached_files:
|
||||
file_refs = []
|
||||
for fi in cached_files:
|
||||
ftype = fi["type"]
|
||||
fpath = fi["path"]
|
||||
if ftype == "image":
|
||||
file_refs.append(f"[图片: {fpath}]")
|
||||
elif ftype == "video":
|
||||
file_refs.append(f"[视频: {fpath}]")
|
||||
else:
|
||||
file_refs.append(f"[文件: {fpath}]")
|
||||
qq_msg.content = qq_msg.content + "\n" + "\n".join(file_refs)
|
||||
logger.info(f"[QQ] Attached {len(cached_files)} cached file(s)")
|
||||
file_cache.clear(session_id)
|
||||
|
||||
context = self._compose_context(
|
||||
qq_msg.ctype,
|
||||
qq_msg.content,
|
||||
isgroup=is_group,
|
||||
msg=qq_msg,
|
||||
no_need_at=True,
|
||||
)
|
||||
if context:
|
||||
self.produce(context)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# _compose_context
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _compose_context(self, ctype: ContextType, content, **kwargs):
|
||||
context = Context(ctype, content)
|
||||
context.kwargs = kwargs
|
||||
if "channel_type" not in context:
|
||||
context["channel_type"] = self.channel_type
|
||||
if "origin_ctype" not in context:
|
||||
context["origin_ctype"] = ctype
|
||||
|
||||
cmsg = context["msg"]
|
||||
|
||||
if cmsg.is_group:
|
||||
context["session_id"] = cmsg.other_user_id
|
||||
else:
|
||||
context["session_id"] = cmsg.from_user_id
|
||||
|
||||
context["receiver"] = cmsg.other_user_id
|
||||
|
||||
if ctype == ContextType.TEXT:
|
||||
img_match_prefix = check_prefix(content, conf().get("image_create_prefix"))
|
||||
if img_match_prefix:
|
||||
content = content.replace(img_match_prefix, "", 1)
|
||||
context.type = ContextType.IMAGE_CREATE
|
||||
else:
|
||||
context.type = ContextType.TEXT
|
||||
context.content = content.strip()
|
||||
|
||||
return context
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Send reply
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def send(self, reply: Reply, context: Context):
|
||||
msg = context.get("msg")
|
||||
is_group = context.get("isgroup", False)
|
||||
receiver = context.get("receiver", "")
|
||||
|
||||
if not msg:
|
||||
# Active send (e.g. scheduled tasks), no original message to reply to
|
||||
self._active_send_text(reply.content if reply.type == ReplyType.TEXT else str(reply.content),
|
||||
receiver, is_group)
|
||||
return
|
||||
|
||||
event_type = getattr(msg, "event_type", "")
|
||||
msg_id = getattr(msg, "msg_id", "")
|
||||
|
||||
if reply.type == ReplyType.TEXT:
|
||||
self._send_text(reply.content, msg, event_type, msg_id)
|
||||
elif reply.type in (ReplyType.IMAGE_URL, ReplyType.IMAGE):
|
||||
self._send_image(reply.content, msg, event_type, msg_id)
|
||||
elif reply.type == ReplyType.FILE:
|
||||
if hasattr(reply, "text_content") and reply.text_content:
|
||||
self._send_text(reply.text_content, msg, event_type, msg_id)
|
||||
time.sleep(0.3)
|
||||
self._send_file(reply.content, msg, event_type, msg_id)
|
||||
elif reply.type in (ReplyType.VIDEO, ReplyType.VIDEO_URL):
|
||||
self._send_media(reply.content, msg, event_type, msg_id, QQ_FILE_TYPE_VIDEO)
|
||||
else:
|
||||
logger.warning(f"[QQ] Unsupported reply type: {reply.type}, falling back to text")
|
||||
self._send_text(str(reply.content), msg, event_type, msg_id)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Send helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _get_next_msg_seq(self, msg_id: str) -> int:
|
||||
seq = self._msg_seq_counter.get(msg_id, 1)
|
||||
self._msg_seq_counter[msg_id] = seq + 1
|
||||
return seq
|
||||
|
||||
def _build_msg_url_and_base_body(self, msg: QQMessage, event_type: str, msg_id: str):
|
||||
"""Build the API URL and base body dict for sending a message."""
|
||||
if event_type == "GROUP_AT_MESSAGE_CREATE":
|
||||
group_openid = msg._rawmsg.get("group_openid", "")
|
||||
url = f"{QQ_API_BASE}/v2/groups/{group_openid}/messages"
|
||||
body = {
|
||||
"msg_id": msg_id,
|
||||
"msg_seq": self._get_next_msg_seq(msg_id),
|
||||
}
|
||||
return url, body, "group", group_openid
|
||||
|
||||
elif event_type == "C2C_MESSAGE_CREATE":
|
||||
user_openid = msg._rawmsg.get("author", {}).get("user_openid", "") or msg.from_user_id
|
||||
url = f"{QQ_API_BASE}/v2/users/{user_openid}/messages"
|
||||
body = {
|
||||
"msg_id": msg_id,
|
||||
"msg_seq": self._get_next_msg_seq(msg_id),
|
||||
}
|
||||
return url, body, "c2c", user_openid
|
||||
|
||||
elif event_type == "AT_MESSAGE_CREATE":
|
||||
channel_id = msg._rawmsg.get("channel_id", "")
|
||||
url = f"{QQ_API_BASE}/channels/{channel_id}/messages"
|
||||
body = {"msg_id": msg_id}
|
||||
return url, body, "channel", channel_id
|
||||
|
||||
elif event_type == "DIRECT_MESSAGE_CREATE":
|
||||
guild_id = msg._rawmsg.get("guild_id", "")
|
||||
url = f"{QQ_API_BASE}/dms/{guild_id}/messages"
|
||||
body = {"msg_id": msg_id}
|
||||
return url, body, "dm", guild_id
|
||||
|
||||
return None, None, None, None
|
||||
|
||||
def _post_message(self, url: str, body: dict, event_type: str):
|
||||
try:
|
||||
resp = requests.post(url, json=body, headers=self._get_auth_headers(), timeout=10)
|
||||
if resp.status_code in (200, 201, 202, 204):
|
||||
logger.info(f"[QQ] Message sent successfully: event_type={event_type}")
|
||||
else:
|
||||
logger.error(f"[QQ] Failed to send message: status={resp.status_code}, "
|
||||
f"body={resp.text}")
|
||||
except Exception as e:
|
||||
logger.error(f"[QQ] Send message error: {e}")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Active send (no original message, e.g. scheduled tasks)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _active_send_text(self, content: str, receiver: str, is_group: bool):
|
||||
"""Send text without an original message (active push). QQ limits active messages to 4/month per user."""
|
||||
if not receiver:
|
||||
logger.warning("[QQ] No receiver for active send")
|
||||
return
|
||||
if is_group:
|
||||
url = f"{QQ_API_BASE}/v2/groups/{receiver}/messages"
|
||||
else:
|
||||
url = f"{QQ_API_BASE}/v2/users/{receiver}/messages"
|
||||
body = {
|
||||
"content": content,
|
||||
"msg_type": 0,
|
||||
}
|
||||
event_label = "GROUP_ACTIVE" if is_group else "C2C_ACTIVE"
|
||||
self._post_message(url, body, event_label)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Send text
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _send_text(self, content: str, msg: QQMessage, event_type: str, msg_id: str):
|
||||
url, body, _, _ = self._build_msg_url_and_base_body(msg, event_type, msg_id)
|
||||
if not url:
|
||||
logger.warning(f"[QQ] Cannot send reply for event_type: {event_type}")
|
||||
return
|
||||
body["content"] = content
|
||||
body["msg_type"] = 0
|
||||
self._post_message(url, body, event_type)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Rich media upload & send (image / video / file)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _upload_rich_media(self, file_url: str, file_type: int, msg: QQMessage,
|
||||
event_type: str) -> str:
|
||||
"""
|
||||
Upload media via QQ rich media API and return file_info.
|
||||
For group: POST /v2/groups/{group_openid}/files
|
||||
For c2c: POST /v2/users/{openid}/files
|
||||
"""
|
||||
if event_type == "GROUP_AT_MESSAGE_CREATE":
|
||||
group_openid = msg._rawmsg.get("group_openid", "")
|
||||
upload_url = f"{QQ_API_BASE}/v2/groups/{group_openid}/files"
|
||||
elif event_type == "C2C_MESSAGE_CREATE":
|
||||
user_openid = (msg._rawmsg.get("author", {}).get("user_openid", "")
|
||||
or msg.from_user_id)
|
||||
upload_url = f"{QQ_API_BASE}/v2/users/{user_openid}/files"
|
||||
else:
|
||||
logger.warning(f"[QQ] Rich media upload not supported for event_type: {event_type}")
|
||||
return ""
|
||||
|
||||
upload_body = {
|
||||
"file_type": file_type,
|
||||
"url": file_url,
|
||||
"srv_send_msg": False,
|
||||
}
|
||||
|
||||
try:
|
||||
resp = requests.post(
|
||||
upload_url, json=upload_body,
|
||||
headers=self._get_auth_headers(), timeout=30,
|
||||
)
|
||||
if resp.status_code in (200, 201):
|
||||
data = resp.json()
|
||||
file_info = data.get("file_info", "")
|
||||
logger.info(f"[QQ] Rich media uploaded: file_type={file_type}, "
|
||||
f"file_uuid={data.get('file_uuid', '')}")
|
||||
return file_info
|
||||
else:
|
||||
logger.error(f"[QQ] Rich media upload failed: status={resp.status_code}, "
|
||||
f"body={resp.text}")
|
||||
return ""
|
||||
except Exception as e:
|
||||
logger.error(f"[QQ] Rich media upload error: {e}")
|
||||
return ""
|
||||
|
||||
def _upload_rich_media_base64(self, file_path: str, file_type: int, msg: QQMessage,
|
||||
event_type: str) -> str:
|
||||
"""Upload local file via base64 file_data field."""
|
||||
if event_type == "GROUP_AT_MESSAGE_CREATE":
|
||||
group_openid = msg._rawmsg.get("group_openid", "")
|
||||
upload_url = f"{QQ_API_BASE}/v2/groups/{group_openid}/files"
|
||||
elif event_type == "C2C_MESSAGE_CREATE":
|
||||
user_openid = (msg._rawmsg.get("author", {}).get("user_openid", "")
|
||||
or msg.from_user_id)
|
||||
upload_url = f"{QQ_API_BASE}/v2/users/{user_openid}/files"
|
||||
else:
|
||||
logger.warning(f"[QQ] Rich media upload not supported for event_type: {event_type}")
|
||||
return ""
|
||||
|
||||
try:
|
||||
with open(file_path, "rb") as f:
|
||||
file_data = base64.b64encode(f.read()).decode("utf-8")
|
||||
except Exception as e:
|
||||
logger.error(f"[QQ] Failed to read file for upload: {e}")
|
||||
return ""
|
||||
|
||||
upload_body = {
|
||||
"file_type": file_type,
|
||||
"file_data": file_data,
|
||||
"srv_send_msg": False,
|
||||
}
|
||||
|
||||
try:
|
||||
resp = requests.post(
|
||||
upload_url, json=upload_body,
|
||||
headers=self._get_auth_headers(), timeout=30,
|
||||
)
|
||||
if resp.status_code in (200, 201):
|
||||
data = resp.json()
|
||||
file_info = data.get("file_info", "")
|
||||
logger.info(f"[QQ] Rich media uploaded (base64): file_type={file_type}, "
|
||||
f"file_uuid={data.get('file_uuid', '')}")
|
||||
return file_info
|
||||
else:
|
||||
logger.error(f"[QQ] Rich media upload (base64) failed: status={resp.status_code}, "
|
||||
f"body={resp.text}")
|
||||
return ""
|
||||
except Exception as e:
|
||||
logger.error(f"[QQ] Rich media upload (base64) error: {e}")
|
||||
return ""
|
||||
|
||||
def _send_media_msg(self, file_info: str, msg: QQMessage, event_type: str, msg_id: str):
|
||||
"""Send a message with msg_type=7 (rich media) using file_info."""
|
||||
url, body, _, _ = self._build_msg_url_and_base_body(msg, event_type, msg_id)
|
||||
if not url:
|
||||
return
|
||||
body["msg_type"] = 7
|
||||
body["media"] = {"file_info": file_info}
|
||||
self._post_message(url, body, event_type)
|
||||
|
||||
def _send_image(self, img_path_or_url: str, msg: QQMessage, event_type: str, msg_id: str):
|
||||
"""Send image reply. Supports URL and local file path."""
|
||||
if event_type not in ("GROUP_AT_MESSAGE_CREATE", "C2C_MESSAGE_CREATE"):
|
||||
self._send_text(str(img_path_or_url), msg, event_type, msg_id)
|
||||
return
|
||||
|
||||
if img_path_or_url.startswith("file://"):
|
||||
img_path_or_url = img_path_or_url[7:]
|
||||
|
||||
if img_path_or_url.startswith(("http://", "https://")):
|
||||
file_info = self._upload_rich_media(
|
||||
img_path_or_url, QQ_FILE_TYPE_IMAGE, msg, event_type)
|
||||
elif os.path.exists(img_path_or_url):
|
||||
file_info = self._upload_rich_media_base64(
|
||||
img_path_or_url, QQ_FILE_TYPE_IMAGE, msg, event_type)
|
||||
else:
|
||||
logger.error(f"[QQ] Image not found: {img_path_or_url}")
|
||||
self._send_text("[Image send failed]", msg, event_type, msg_id)
|
||||
return
|
||||
|
||||
if file_info:
|
||||
self._send_media_msg(file_info, msg, event_type, msg_id)
|
||||
else:
|
||||
self._send_text("[Image upload failed]", msg, event_type, msg_id)
|
||||
|
||||
def _send_file(self, file_path_or_url: str, msg: QQMessage, event_type: str, msg_id: str):
|
||||
"""Send file reply."""
|
||||
if event_type not in ("GROUP_AT_MESSAGE_CREATE", "C2C_MESSAGE_CREATE"):
|
||||
self._send_text(str(file_path_or_url), msg, event_type, msg_id)
|
||||
return
|
||||
|
||||
if file_path_or_url.startswith("file://"):
|
||||
file_path_or_url = file_path_or_url[7:]
|
||||
|
||||
if file_path_or_url.startswith(("http://", "https://")):
|
||||
file_info = self._upload_rich_media(
|
||||
file_path_or_url, QQ_FILE_TYPE_FILE, msg, event_type)
|
||||
elif os.path.exists(file_path_or_url):
|
||||
file_info = self._upload_rich_media_base64(
|
||||
file_path_or_url, QQ_FILE_TYPE_FILE, msg, event_type)
|
||||
else:
|
||||
logger.error(f"[QQ] File not found: {file_path_or_url}")
|
||||
self._send_text("[File send failed]", msg, event_type, msg_id)
|
||||
return
|
||||
|
||||
if file_info:
|
||||
self._send_media_msg(file_info, msg, event_type, msg_id)
|
||||
else:
|
||||
self._send_text("[File upload failed]", msg, event_type, msg_id)
|
||||
|
||||
def _send_media(self, path_or_url: str, msg: QQMessage, event_type: str,
|
||||
msg_id: str, file_type: int):
|
||||
"""Generic media send for video/voice etc."""
|
||||
if event_type not in ("GROUP_AT_MESSAGE_CREATE", "C2C_MESSAGE_CREATE"):
|
||||
self._send_text(str(path_or_url), msg, event_type, msg_id)
|
||||
return
|
||||
|
||||
if path_or_url.startswith("file://"):
|
||||
path_or_url = path_or_url[7:]
|
||||
|
||||
if path_or_url.startswith(("http://", "https://")):
|
||||
file_info = self._upload_rich_media(path_or_url, file_type, msg, event_type)
|
||||
elif os.path.exists(path_or_url):
|
||||
file_info = self._upload_rich_media_base64(path_or_url, file_type, msg, event_type)
|
||||
else:
|
||||
logger.error(f"[QQ] Media not found: {path_or_url}")
|
||||
return
|
||||
|
||||
if file_info:
|
||||
self._send_media_msg(file_info, msg, event_type, msg_id)
|
||||
else:
|
||||
logger.error(f"[QQ] Media upload failed: {path_or_url}")
|
||||
123
channel/qq/qq_message.py
Normal file
@@ -0,0 +1,123 @@
|
||||
import os
|
||||
import requests
|
||||
|
||||
from bridge.context import ContextType
|
||||
from channel.chat_message import ChatMessage
|
||||
from common.log import logger
|
||||
from common.utils import expand_path
|
||||
from config import conf
|
||||
|
||||
|
||||
def _get_tmp_dir() -> str:
|
||||
"""Return the workspace tmp directory (absolute path), creating it if needed."""
|
||||
ws_root = expand_path(conf().get("agent_workspace", "~/cow"))
|
||||
tmp_dir = os.path.join(ws_root, "tmp")
|
||||
os.makedirs(tmp_dir, exist_ok=True)
|
||||
return tmp_dir
|
||||
|
||||
|
||||
class QQMessage(ChatMessage):
|
||||
"""Message wrapper for QQ Bot (websocket long-connection mode)."""
|
||||
|
||||
def __init__(self, event_data: dict, event_type: str):
|
||||
super().__init__(event_data)
|
||||
self.msg_id = event_data.get("id", "")
|
||||
self.create_time = event_data.get("timestamp", "")
|
||||
self.is_group = event_type in ("GROUP_AT_MESSAGE_CREATE",)
|
||||
self.event_type = event_type
|
||||
|
||||
author = event_data.get("author", {})
|
||||
from_user_id = author.get("member_openid", "") or author.get("id", "")
|
||||
group_openid = event_data.get("group_openid", "")
|
||||
|
||||
content = event_data.get("content", "").strip()
|
||||
|
||||
attachments = event_data.get("attachments", [])
|
||||
has_image = any(
|
||||
a.get("content_type", "").startswith("image/") for a in attachments
|
||||
) if attachments else False
|
||||
|
||||
if has_image and not content:
|
||||
self.ctype = ContextType.IMAGE
|
||||
img_attachment = next(
|
||||
a for a in attachments if a.get("content_type", "").startswith("image/")
|
||||
)
|
||||
img_url = img_attachment.get("url", "")
|
||||
if img_url and not img_url.startswith("http"):
|
||||
img_url = "https://" + img_url
|
||||
tmp_dir = _get_tmp_dir()
|
||||
image_path = os.path.join(tmp_dir, f"qq_{self.msg_id}.png")
|
||||
try:
|
||||
resp = requests.get(img_url, timeout=30)
|
||||
resp.raise_for_status()
|
||||
with open(image_path, "wb") as f:
|
||||
f.write(resp.content)
|
||||
self.content = image_path
|
||||
self.image_path = image_path
|
||||
logger.info(f"[QQ] Image downloaded: {image_path}")
|
||||
except Exception as e:
|
||||
logger.error(f"[QQ] Failed to download image: {e}")
|
||||
self.content = "[Image download failed]"
|
||||
self.image_path = None
|
||||
elif has_image and content:
|
||||
self.ctype = ContextType.TEXT
|
||||
image_paths = []
|
||||
tmp_dir = _get_tmp_dir()
|
||||
for idx, att in enumerate(attachments):
|
||||
if not att.get("content_type", "").startswith("image/"):
|
||||
continue
|
||||
img_url = att.get("url", "")
|
||||
if img_url and not img_url.startswith("http"):
|
||||
img_url = "https://" + img_url
|
||||
img_path = os.path.join(tmp_dir, f"qq_{self.msg_id}_{idx}.png")
|
||||
try:
|
||||
resp = requests.get(img_url, timeout=30)
|
||||
resp.raise_for_status()
|
||||
with open(img_path, "wb") as f:
|
||||
f.write(resp.content)
|
||||
image_paths.append(img_path)
|
||||
except Exception as e:
|
||||
logger.error(f"[QQ] Failed to download mixed image: {e}")
|
||||
content_parts = [content]
|
||||
for p in image_paths:
|
||||
content_parts.append(f"[图片: {p}]")
|
||||
self.content = "\n".join(content_parts)
|
||||
else:
|
||||
self.ctype = ContextType.TEXT
|
||||
self.content = content
|
||||
|
||||
if event_type == "GROUP_AT_MESSAGE_CREATE":
|
||||
self.from_user_id = from_user_id
|
||||
self.to_user_id = ""
|
||||
self.other_user_id = group_openid
|
||||
self.actual_user_id = from_user_id
|
||||
self.actual_user_nickname = from_user_id
|
||||
|
||||
elif event_type == "C2C_MESSAGE_CREATE":
|
||||
user_openid = author.get("user_openid", "") or from_user_id
|
||||
self.from_user_id = user_openid
|
||||
self.to_user_id = ""
|
||||
self.other_user_id = user_openid
|
||||
self.actual_user_id = user_openid
|
||||
|
||||
elif event_type == "AT_MESSAGE_CREATE":
|
||||
self.from_user_id = from_user_id
|
||||
self.to_user_id = ""
|
||||
channel_id = event_data.get("channel_id", "")
|
||||
self.other_user_id = channel_id
|
||||
self.actual_user_id = from_user_id
|
||||
self.actual_user_nickname = author.get("username", from_user_id)
|
||||
|
||||
elif event_type == "DIRECT_MESSAGE_CREATE":
|
||||
self.from_user_id = from_user_id
|
||||
self.to_user_id = ""
|
||||
guild_id = event_data.get("guild_id", "")
|
||||
self.other_user_id = f"dm_{guild_id}_{from_user_id}"
|
||||
self.actual_user_id = from_user_id
|
||||
self.actual_user_nickname = author.get("username", from_user_id)
|
||||
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported QQ event type: {event_type}")
|
||||
|
||||
logger.debug(f"[QQ] Message parsed: type={event_type}, ctype={self.ctype}, "
|
||||
f"from={self.from_user_id}, content_len={len(self.content)}")
|
||||
1401
channel/web/static/css/console.css
Normal file
7121
channel/web/static/js/console.js
Normal file
1
channel/web/static/logos/claudeAPI.svg
Normal file
@@ -0,0 +1 @@
|
||||
<?xml version="1.0" standalone="no"?><!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN" "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd"><svg t="1779251656961" class="icon" viewBox="0 0 1024 1024" version="1.1" xmlns="http://www.w3.org/2000/svg" p-id="18432" xmlns:xlink="http://www.w3.org/1999/xlink" width="200" height="200"><path d="M252.8 652.8l167.893333-94.293333 2.773334-8.106667-2.773334-4.48h-8.106666l-28.16-1.706667-96-2.56-83.2-3.413333-80.64-4.266667-20.266667-4.266666L85.333333 504.746667l1.92-12.586667 17.066667-11.52 24.32 2.133333 53.973333 3.626667 81.066667 5.546667 58.666667 3.413333 87.04 9.173333h13.866666l1.92-5.546666-4.693333-3.413334-3.626667-3.413333-83.84-56.746667-90.666666-60.16-47.573334-34.56-25.813333-17.493333-13.013333-16.426667-5.546667-35.84 23.253333-25.813333 31.36 2.133333 7.893334 2.133334 31.786666 24.32 67.84 52.48L401.066667 391.466667l13.013333 10.88 5.12-3.626667 0.64-2.56-5.76-9.813333-48.213333-87.04L314.453333 210.773333l-22.826666-36.693333-5.973334-21.973333a107.861333 107.861333 0 0 1-3.626666-26.026667l26.666666-36.053333L323.413333 85.333333l35.413334 4.693334 14.933333 13.013333 21.973333 50.346667 35.626667 79.36 55.253333 107.733333 16.213334 32 8.746666 29.653333 3.2 9.173334h5.546667v-5.12l4.48-60.8 8.32-74.453334 8.106667-96 2.773333-27.093333 13.44-32.426667 26.666667-17.493333 20.693333 10.026667 17.066667 24.32-2.346667 15.786666-10.24 65.92-19.84 103.253334-13.013333 69.12h7.466666l8.746667-8.746667 34.986667-46.506667 58.666666-73.386666 26.026667-29.226667 30.293333-32.213333 19.413334-15.36h36.693333l27.093333 40.106666-12.16 41.386667-37.76 48-31.36 40.533333-45.013333 60.586667-28.16 48.426667 2.56 3.84 6.613333-0.64 101.546667-21.546667 54.826667-10.026667 65.493333-11.306666 29.653333 13.866666 3.2 14.08-11.733333 28.8-69.973333 17.28-82.133334 16.426667-122.24 29.013333-1.493333 1.066667 1.706667 2.133333 55.04 5.12 23.466666 1.28h57.6l107.306667 7.893334 28.16 18.56 16.853333 22.613333-2.773333 17.28-43.306667 21.973333-58.24-13.866666-136.106666-32.426667-46.72-11.733333h-6.4v3.84l38.826666 37.973333 71.253334 64.426667 89.173333 82.986666 4.48 20.48-11.52 16.213334-12.16-1.706667-78.506667-58.88-30.293333-26.666667-68.48-57.6h-4.48v5.973334l15.786667 23.04 83.413333 125.226666 4.266667 38.4-5.973334 12.586667-21.546666 7.466667-23.68-4.266667-48.853334-68.48-50.346666-77.226667-40.533334-69.12-4.906666 2.773334-23.893334 258.133333-11.306666 13.226667-26.026667 10.026666-21.546667-16.426666-11.52-26.666667 11.52-52.48 13.866667-68.48 11.306667-54.4 10.24-67.626667 5.973333-22.4-0.426667-1.493333-4.906666 0.64-50.986667 69.973333-77.653333 104.746667-61.44 65.706667-14.72 5.76-25.386667-13.226667 2.346667-23.466667 14.293333-20.906666 84.906667-107.946667 51.2-66.986667 33.066666-38.613333v-5.546667h-2.133333l-225.493333 146.56-40.106667 5.12-17.28-16.213333 2.133333-26.666667 8.106667-8.746666 67.84-46.72h-0.213333l0.853333 0.853333z" fill="#D97757" p-id="18433"></path></svg>
|
||||
|
After Width: | Height: | Size: 2.9 KiB |
10
channel/web/static/logos/custom.svg
Normal file
@@ -0,0 +1,10 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" width="200" height="200" fill="none" stroke="#475569" stroke-width="1.8" stroke-linecap="round" stroke-linejoin="round">
|
||||
<!-- Horizontal slider tracks -->
|
||||
<line x1="4" y1="7" x2="20" y2="7"/>
|
||||
<line x1="4" y1="12" x2="20" y2="12"/>
|
||||
<line x1="4" y1="17" x2="20" y2="17"/>
|
||||
<!-- Knobs (filled circles) -->
|
||||
<circle cx="9" cy="7" r="2.2" fill="#475569" stroke="none"/>
|
||||
<circle cx="15" cy="12" r="2.2" fill="#475569" stroke="none"/>
|
||||
<circle cx="7" cy="17" r="2.2" fill="#475569" stroke="none"/>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 573 B |
1
channel/web/static/logos/dashscope.svg
Normal file
@@ -0,0 +1 @@
|
||||
<?xml version="1.0" standalone="no"?><!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN" "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd"><svg t="1779251621200" class="icon" viewBox="0 0 1024 1024" version="1.1" xmlns="http://www.w3.org/2000/svg" p-id="17444" xmlns:xlink="http://www.w3.org/1999/xlink" width="200" height="200"><path d="M1019.364785 620.816931L891.797142 397.807295 946.450846 293.15069a29.097778 29.097778 0 0 0 6.399732-36.393472l-70.184053-126.586684a30.078737 30.078737 0 0 0-24.574968-13.652427H597.4945L539.171949 14.549389a27.348852 27.348852 0 0 0-20.906122-14.549389H380.628607a29.139776 29.139776 0 0 0-24.616967 14.549389v5.545767L225.797108 243.062793H100.919352a29.182775 29.182775 0 0 0-25.513928 13.653427L3.428446 384.11187a32.766624 32.766624 0 0 0 0 29.182775L132.831012 638.096205 74.508461 740.064923a32.766624 32.766624 0 0 0 0 29.05478l66.514207 116.561105a29.905744 29.905744 0 0 0 25.513929 14.505391H427.132654l62.845361 109.222414A30.078737 30.078737 0 0 0 512.762058 1024H660.382859a29.139776 29.139776 0 0 0 24.574968-14.549389l128.463606-224.843558h114.76818a31.91366 31.91366 0 0 0 24.660965-15.444352l66.471208-117.414069a28.158818 28.158818 0 0 0 0-30.9747l0.042999 0.042999z m-161.273228 14.591387L791.57735 512.490479 518.265827 993.964261l-74.748861-122.87484h-273.268525l65.618244-119.205994h139.386147L101.856313 272.244568h143.055993L380.671605 30.121735l68.34913 119.247993-70.184053 122.87484H925.501726l-69.202094 121.936879 137.594222 241.183873H858.134555z" fill="#605BEC" p-id="17445"></path><path d="M499.962596 699.320634l174.371677-274.719464H324.694955z" fill="#605BEC" p-id="17446"></path></svg>
|
||||
|
After Width: | Height: | Size: 1.6 KiB |
1
channel/web/static/logos/deepseek.svg
Normal file
|
After Width: | Height: | Size: 5.1 KiB |
1
channel/web/static/logos/doubao.svg
Normal file
@@ -0,0 +1 @@
|
||||
<?xml version="1.0" standalone="no"?><!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN" "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd"><svg t="1779261485522" class="icon" viewBox="0 0 1024 1024" version="1.1" xmlns="http://www.w3.org/2000/svg" p-id="5381" xmlns:xlink="http://www.w3.org/1999/xlink" width="200" height="200"><path d="M958.976 439.808C804.864 336.896 642.56 321.536 642.56 321.536s8.192 235.008-10.752 306.176c-0.512 9.728-11.776 75.264-43.008 157.696-10.752 28.16-24.064 55.296-39.424 81.408-40.96 74.24-89.6 127.488-89.6 127.488 119.808-48.64 205.312-92.672 309.76-175.616 122.88-96.768 229.376-254.464 189.44-378.88z" fill="#37E1BE" p-id="5382"></path><path d="M329.728 395.776c158.208-100.864 308.736-78.848 312.32-74.752 0.512 0.512 1.024 0.512 1.024 0.512 0-14.336-6.656-60.928-13.312-106.496-11.776-60.928-22.528-124.928-23.04-133.632-170.496-139.264-356.864-78.336-448 25.6-61.44 70.144-103.424 169.984-102.4 224.256V762.88c0.512-12.8 1.536-20.48 2.048-20.48 17.92-197.12 271.36-346.624 271.36-346.624z" fill="#A569FF" p-id="5383"></path><path d="M792.064 272.384c-41.984-43.52-87.552-88.576-122.368-125.44-33.28-34.816-59.392-60.928-62.976-65.536 0.512 8.704 11.264 72.704 23.04 133.632 6.656 45.568 12.8 92.672 13.312 106.496 0 0 162.304 15.36 316.416 118.272-0.512 0-83.456-80.384-167.424-167.424zM549.888 866.816c-2.56 1.024-198.656 107.008-292.352-30.72-20.992-30.72-31.744-68.096-33.28-106.496-3.072-74.752 5.12-227.84 105.472-333.824 0 0-253.44 149.504-270.848 346.624-0.512 0.512-2.048 8.192-2.048 20.48-1.024 32.768 4.608 98.304 43.008 155.136 52.224 78.336 193.024 138.752 328.192 85.504l33.28-9.728c-1.024 0.512 47.616-52.224 88.576-126.976z" fill="#1E37FC" p-id="5384"></path></svg>
|
||||
|
After Width: | Height: | Size: 1.7 KiB |
1
channel/web/static/logos/gemini.svg
Normal file
@@ -0,0 +1 @@
|
||||
<?xml version="1.0" standalone="no"?><!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN" "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd"><svg t="1779251750646" class="icon" viewBox="0 0 1024 1024" version="1.1" xmlns="http://www.w3.org/2000/svg" p-id="29551" xmlns:xlink="http://www.w3.org/1999/xlink" width="200" height="200"><path d="M214.101333 512c0-32.512 5.546667-63.701333 15.36-92.928L57.173333 290.218667A491.861333 491.861333 0 0 0 4.693333 512c0 79.701333 18.858667 154.88 52.394667 221.610667l172.202667-129.066667A290.56 290.56 0 0 1 214.101333 512" fill="#FBBC05" p-id="29552"></path><path d="M516.693333 216.192c72.106667 0 137.258667 25.002667 188.458667 65.962667L854.101333 136.533333C763.349333 59.178667 646.997333 11.392 516.693333 11.392c-202.325333 0-376.234667 113.28-459.52 278.826667l172.373334 128.853333c39.68-118.016 152.832-202.88 287.146666-202.88" fill="#EA4335" p-id="29553"></path><path d="M516.693333 807.808c-134.357333 0-247.509333-84.864-287.232-202.88l-172.288 128.853333c83.242667 165.546667 257.152 278.826667 459.52 278.826667 124.842667 0 244.053333-43.392 333.568-124.757333l-163.584-123.818667c-46.122667 28.458667-104.234667 43.776-170.026666 43.776" fill="#34A853" p-id="29554"></path><path d="M1005.397333 512c0-29.568-4.693333-61.44-11.648-91.008H516.650667V614.4h274.602666c-13.696 65.962667-51.072 116.650667-104.533333 149.632l163.541333 123.818667c93.994667-85.418667 155.136-212.650667 155.136-375.850667" fill="#4285F4" p-id="29555"></path></svg>
|
||||
|
After Width: | Height: | Size: 1.5 KiB |
1
channel/web/static/logos/linkai.svg
Normal file
|
After Width: | Height: | Size: 11 KiB |
1
channel/web/static/logos/minimax.svg
Normal file
@@ -0,0 +1 @@
|
||||
<?xml version="1.0" standalone="no"?><!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN" "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd"><svg t="1779251514432" class="icon" viewBox="0 0 1024 1024" version="1.1" xmlns="http://www.w3.org/2000/svg" p-id="11888" xmlns:xlink="http://www.w3.org/1999/xlink" width="200" height="200"><path d="M415.392 475.808v329.984c-22.304 111.744-170.56 82.944-171.2 1.92-0.672-101.824 0-202.976 0-304.064v-117.184c0-14.656-3.2-26.24-16-35.392-24.96-18.72-54.944 3.264-55.584 30.208-1.408 36.16-0.704 71.616-1.408 107.264 0 28.16 0 55.52 0.64 83.648-18.368 123.776-168.32 103.232-171.808 0.704V487.04c0-28.032 54.944-34.624 52.256 7.36-1.792 20.8-0.64 42.272-1.344 62.912-0.64 36.8 55.648 61.6 68.896 1.408 0.64-49.632 0.64-99.264 0.64-149.344 0-62.752 17.824-113.856 84.352-118.624 28.8-2.56 47.968 9.504 66.336 30.304 7.04 7.36 23.68 30.72 24.32 56.16 0 23.456 0.64 46.752 0.64 70.464 0 46.72-0.64 93.76-0.64 140.48 0 30.304 0.64 60.256 0.64 89.856 0 37.536 0 75.552-0.64 113.152-0.64 48.864 58.816 48.16 68.352-0.768 0-57.632 0.64-114.56 0.64-172.192 0-141.984-0.64-283.968-0.64-425.856 0-14.72-2.048-55.584 5.76-70.464 41.504-101.12 167.392-56.96 168.544 26.72 2.432 171.52 0 344.896 0.64 516.8 0 59.616-48.416 46.816-51.104 23.488 0-178.88 0-358.4 0.64-537.024-2.368-44.832-68.832-38.72-72.672-6.592-1.28 36.864-0.64 74.4-1.28 111.232v219.008h0.64l0.448 0.256h-0.064z" fill="#D4367A" p-id="11889"></path><path d="M610.016 473.184v242.336V143.648c21.632-112.512 169.824-83.264 170.464-2.176 0.704 101.12 0 202.912 0.704 304 0 38.784 0 77.728-0.64 116.544 0 15.36 3.776 26.176 16.64 36.032 24.32 18.24 54.24-3.2 55.584-30.592 1.344-35.488 0.64-70.976 0.64-107.328V376.96c18.56-123.776 168.128-103.232 171.264-0.704v310.592c0 28.16-54.304 34.848-51.872-7.296 1.472-21.44 0-267.104 0.768-288.64 1.28-36.16-55.712-61.664-68.928-0.768v148.576c0 63.68-17.856 113.92-84.96 119.36-63.264 1.504-88.704-42.24-90.752-86.432V271.328c0-38.24 0-75.552 0.64-113.088 0.64-48.864-58.784-48.864-68.896 0.704V831.36c0 14.592 2.048 55.52-5.184 70.432-41.44 101.056-168 56.864-169.152-26.752v-79.616c3.136-53.6 48.416-40.864 50.464-18.176v94.464c2.432 44.928 68.928 39.488 72.064 6.656 1.344-36.896 1.344-73.728 1.344-111.296v-293.824h-0.192v-0.064z" fill="#ED6D48" p-id="11890"></path></svg>
|
||||
|
After Width: | Height: | Size: 2.2 KiB |
1
channel/web/static/logos/moonshot.svg
Normal file
@@ -0,0 +1 @@
|
||||
<?xml version="1.0" standalone="no"?><!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN" "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd"><svg t="1779251592968" class="icon" viewBox="0 0 1024 1024" version="1.1" xmlns="http://www.w3.org/2000/svg" p-id="16416" xmlns:xlink="http://www.w3.org/1999/xlink" width="200" height="200"><path d="M117.9648 684.6464l342.30272 93.57312v75.34592l209.7152 58.5728A428.99456 428.99456 0 0 1 512 942.08c-176.128 0-327.53664-105.8816-394.0352-257.4336zM83.29216 477.42976l407.30624 112.64-9.6256 37.00736-6.0416 35.0208 383.3856 104.96a432.5376 432.5376 0 0 1-65.10592 70.32832l-688.18944-185.9584A429.4656 429.4656 0 0 1 81.92 512c0-11.63264 0.47104-23.1424 1.37216-34.54976z m57.344-182.4768l429.07648 114.21696a279.94112 279.94112 0 0 0-23.06048 35.55328 201.17504 201.17504 0 0 0-14.70464 34.93888l403.08736 110.26432a426.8032 426.8032 0 0 1-23.552 81.7152L86.54848 448.7168a427.25376 427.25376 0 0 1 54.0672-153.76384z m158.47424-156.75392l404.23424 108.31872a190.2592 190.2592 0 0 0-32.80896 24.90368c-9.13408 8.8064-19.8656 21.4016-32.1536 37.74464l285.24544 77.78304c9.216 30.45376 15.03232 61.8496 17.32608 93.5936L156.61056 269.68064a432.27136 432.27136 0 0 1 142.49984-131.4816zM512 81.92c142.90944 0 269.55776 69.71392 347.7504 176.98816L337.26464 118.90688A428.50304 428.50304 0 0 1 512 81.92z" fill="#000000" p-id="16417"></path></svg>
|
||||
|
After Width: | Height: | Size: 1.3 KiB |
1
channel/web/static/logos/openai.svg
Normal file
@@ -0,0 +1 @@
|
||||
<?xml version="1.0" standalone="no"?><!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN" "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd"><svg t="1779251225589" class="icon" viewBox="0 0 1024 1024" version="1.1" xmlns="http://www.w3.org/2000/svg" p-id="9015" xmlns:xlink="http://www.w3.org/1999/xlink" width="200" height="200"><path d="M881.664 431.488a218.88 218.88 0 0 0-18.176-177.088A218.624 218.624 0 0 0 628.992 149.76c-40.576-45.824-100.288-71.424-162.176-71.424a219.136 219.136 0 0 0-208 150.4 215.68 215.68 0 0 0-144 104.512 218.944 218.944 0 0 0 26.688 254.912 218.752 218.752 0 0 0 19.2 177.152 217.088 217.088 0 0 0 234.624 104.512 219.136 219.136 0 0 0 162.112 72.512 219.136 219.136 0 0 0 208-150.4 215.68 215.68 0 0 0 144-104.512 219.008 219.008 0 0 0-27.712-256z m-324.288 454.4a158.08 158.08 0 0 1-103.424-37.376c1.088-1.088 4.288-2.176 5.376-3.2l171.712-99.2a28.16 28.16 0 0 0 13.824-24.512V479.488l72.576 41.6c1.024 0 1.024 1.024 1.024 2.112v200.512a160.512 160.512 0 0 1-161.088 162.112z m-347.712-148.288c-19.2-33.088-25.6-71.488-19.2-108.8 1.088 1.024 3.2 2.176 5.376 3.2l171.712 99.2a25.984 25.984 0 0 0 27.712 0l210.112-121.6v84.224c0 1.152 0 2.176-1.024 2.176L430.464 796.16c-76.8 44.8-176 18.176-220.8-58.624z m-44.736-375.424c19.2-32.64 48.896-57.856 84.224-71.488v204.8c0 9.6 5.376 19.2 13.888 24.512l210.176 121.6-72.576 41.6c-1.024 0-2.112 1.088-2.112 0L224.64 582.912a160.448 160.448 0 0 1-59.776-220.8h0.064z m597.312 138.688l-210.112-121.6 72.512-41.6c1.088 0 2.176-1.088 2.176 0l173.824 100.224a161.088 161.088 0 0 1-25.6 291.2V525.44a26.304 26.304 0 0 0-12.8-24.512z m71.488-108.8a23.232 23.232 0 0 0-5.312-3.2L656.64 289.536a26.048 26.048 0 0 0-27.712 0l-210.176 121.6V326.912c0-1.088 0-2.176 1.088-2.176l173.824-100.224a161.152 161.152 0 0 1 220.8 59.712c19.2 32 25.6 70.4 19.2 107.776z m-454.4 149.248l-72.64-41.6c-1.024 0-1.024-1.088-1.024-2.176V297.088A162.048 162.048 0 0 1 467.84 135.04a158.08 158.08 0 0 1 103.424 37.312 22.848 22.848 0 0 1-5.312 3.2L394.24 274.688a28.16 28.16 0 0 0-13.888 24.512v242.112h-1.088z m39.424-85.312l93.824-54.4 93.888 54.4v107.712l-93.888 54.4-93.824-54.4V456z" fill="#000000" p-id="9016"></path></svg>
|
||||
|
After Width: | Height: | Size: 2.1 KiB |
1
channel/web/static/logos/qianfan.svg
Normal file
@@ -0,0 +1 @@
|
||||
<?xml version="1.0" standalone="no"?><!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN" "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd"><svg t="1779251568791" class="icon" viewBox="0 0 1024 1024" version="1.1" xmlns="http://www.w3.org/2000/svg" p-id="14450" xmlns:xlink="http://www.w3.org/1999/xlink" width="200" height="200"><path d="M96.20121136 636.3124965c-0.1472897-113.41305959-0.29457937-226.8261192-0.29457937-340.23917879 0-14.87625845 7.65906378-26.51214381 20.4732666-34.02391789 45.51251353-26.65943349 91.02502705-53.31886698 136.83211997-79.53643141 71.1409192-40.94653321 142.42912809-81.59848704 213.71733698-122.39773055 7.36448439-4.12411126 14.58167909-8.3955122 21.50429441-13.2560719 19.44223878-13.40336159 39.03176725-16.05457598 60.09419263-3.53495252 27.39588193 16.34915535 54.93905355 32.25644163 82.48222516 48.16372793 88.0792333 50.96223197 176.30575629 101.77717426 264.38498958 152.59211653 9.86840908 5.74429781 19.88410785 11.19401627 29.60522725 17.0856038 14.13981003 8.54280189 21.50429441 21.06242535 21.50429443 37.70616007 0 147.73155685 0.29457937 295.46311371-0.1472897 443.19467057 0 15.46541722-7.2171947 28.57419943-21.7988738 36.96971163-34.7603663 20.17868721-70.55176044 38.88447758-104.57567833 59.94690293-48.90017634 30.19438599-100.00969801 56.11737105-148.76258466 86.60633642-29.01606849 18.11663161-59.50503387 34.02391789-89.11026112 50.96223197-13.10878221 7.51177407-26.07027474 15.17083783-39.03176726 22.9771913-13.84523065 8.3955122-27.83775099 8.83738127-41.97756102 0.73644843-56.41195043-32.55102101-112.82390085-65.10204201-169.38314098-97.653063-61.86166887-35.64410444-123.72333775-71.1409192-185.4377169-106.78502365-11.19401627-6.48074626-22.24074286-12.81420285-32.99289009-19.88410785-11.48859565-7.65906378-17.08560379-19.14765941-17.08560378-32.69831069-0.1472897-34.7603663 0.1472897-69.52073264 0.29457938-104.28109895 1.62018657-0.58915875 1.62018657-1.62018657-0.29457938-2.65121438z m356.58833414-225.500512c2.20934532-1.76747625 4.41869063-3.68224221 6.77532565-5.15513907 68.93157389-39.62092601 137.86314777-79.24185204 206.94201135-118.86277807 2.79850407-1.62018657 6.48074626-1.62018657 6.62803594-6.18616688 0.1472897-4.8605597-4.12411126-4.71327001-6.77532564-6.18616688-40.65195383-23.56635005-81.59848704-46.83812071-122.10315117-70.84633984-16.79102442-10.01569877-32.84560039-8.54280189-48.45830728 0.58915876-45.9543826 26.51214381-91.46689612 53.61344636-137.27398903 80.42016953-31.96186226 18.70579035-64.21830387 37.11700133-96.32745581 55.67550198-18.41121097 10.60485751-27.54317163 25.33382629-27.24859225 47.72185885 0.88373813 89.55213018 0.58915875 179.10426036 0.14728969 268.65639053-0.1472897 20.17868721 9.27925033 33.58204881 25.33382629 43.15587853 31.3727035 18.70579035 63.18727606 37.11700133 95.14913832 54.93905355 10.89943689 6.03887719 21.06242535 13.99252034 35.79139414 18.41121096V505.51925374c6.48074626 19.58952848 18.55850066 34.02391789 36.67513226 44.6287754 27.83775099 16.20186565 63.18727606 12.51962347 86.31175705-10.45756784 26.95401286-26.65943349 28.72148912-62.89269668 12.81420282-90.14128893-16.34915535-28.42690974-43.59774757-37.55887038-74.38129233-38.73718787z m82.48222517 429.64401928c14.28709972-3.82953187 25.92298506-13.99252034 38.88447758-21.35700473 40.94653321-23.27177067 81.30390766-47.72185885 122.54502023-70.55176046 26.95401286-15.02354815 52.87699792-31.66728287 80.71474891-45.21793415 16.79102442-8.10093283 29.60522723-22.53532223 29.60522726-43.4504579 0.1472897-92.939793 0.29457937-185.73229631 0.14728969-278.6720893 0-11.19401627-5.15513907-13.99252034-13.84523067-7.06990501-26.51214381 20.76784598-57.29568854 34.46578693-86.16446735 51.25681135-54.49718448 31.81457257-109.14165865 63.33456576-163.78613282 95.00184862-8.54280189 4.8605597-11.78317502 10.45756784-11.63588535 20.47326662 0.29457937 96.18016613 0.1472897 192.50762194 0.1472897 288.68778806-0.29457937 3.5349525-1.47289687 7.65906378 3.38766282 10.8994369z" fill="#066AF3" p-id="14451"></path><path d="M96.20121136 636.3124965c1.91476594 1.03102783 1.91476594 2.06205563 0 3.09308345v-3.09308345z" fill="#4372E0" p-id="14452"></path><path d="M391.3697457 505.37196405c-5.44971845-44.33419602 13.84523065-74.08671296 61.4197998-94.55997955 30.93083443 1.17831749 58.03213699 10.31027814 74.38129233 38.5898982 15.75999659 27.39588193 14.13981003 63.48185543-12.81420282 90.14128893-23.27177067 22.97719129-58.47400606 26.65943349-86.31175705 10.45756783-18.11663161-10.60485751-30.34167568-25.03924691-36.67513226-44.62877541z" fill="#002A9A" p-id="14453"></path></svg>
|
||||
|
After Width: | Height: | Size: 4.5 KiB |
1
channel/web/static/logos/zhipu.svg
Normal file
@@ -0,0 +1 @@
|
||||
<?xml version="1.0" standalone="no"?><!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN" "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd"><svg t="1779251419020" class="icon" viewBox="0 0 1024 1024" version="1.1" xmlns="http://www.w3.org/2000/svg" p-id="10062" xmlns:xlink="http://www.w3.org/1999/xlink" width="200" height="200"><path d="M520.063496 0v77.563152c0 269.231173-144.758953 414.054122-434.212862 434.340854L86.106618 511.968002H76.827198V255.984001l443.236298-255.984001z" fill="#5B55F6" p-id="10063"></path><path d="M520.063496 1023.936004v-77.563152c0-269.231173-144.758953-414.054122-434.212862-434.340854L86.042622 511.968002H76.827198v255.984001l443.236298 255.984001z" fill="#376AF3" p-id="10064"></path><path d="M520.063496 0v77.563152c0 269.231173 144.758953 414.054122 434.276858 434.340854L954.08437 511.968002h9.215424V255.984001L520.063496 0z" fill="#5B55F6" p-id="10065"></path><path d="M520.063496 1023.936004v-77.563152c0-269.231173 144.758953-414.054122 434.276858-434.340854L954.08437 511.968002h9.27942v255.984001l-443.236298 255.984001z" fill="#376AF3" p-id="10066"></path></svg>
|
||||
|
After Width: | Height: | Size: 1.1 KiB |
41
channel/web/static/vendor/README.md
vendored
Normal file
@@ -0,0 +1,41 @@
|
||||
# Vendor assets
|
||||
|
||||
Third-party frontend assets bundled locally so the Web Console can run in
|
||||
fully offline / air-gapped environments (no requests to cloudflare, jsdelivr,
|
||||
googleapis, gstatic, etc.).
|
||||
|
||||
All files here are vendored copies of upstream releases. Do not edit them by
|
||||
hand; re-download from the official source if upgrading.
|
||||
|
||||
## Manifest
|
||||
|
||||
| Path | Source | Version |
|
||||
| --------------------------------------------------- | ------------------------------------------------------------------------------------------------- | ------- |
|
||||
| `fontawesome/css/all.min.css` | https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.4.0/css/all.min.css | 6.4.0 |
|
||||
| `fontawesome/webfonts/fa-{brands,regular,solid,v4compatibility}-*.woff2` | https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.4.0/webfonts/ | 6.4.0 |
|
||||
| `fonts/inter/inter-latin.woff2` | https://fonts.gstatic.com/s/inter/v20/UcC73FwrK3iLTeHuS_nVMrMxCp50SjIa1ZL7.woff2 | v20 |
|
||||
| `fonts/inter/inter.css` | Hand-written `@font-face` declaration that maps Inter weights 300-700 to the local woff2 | - |
|
||||
| `tailwind/tailwind.min.js` | https://cdn.tailwindcss.com (Play CDN runtime, JIT engine for the browser) | latest |
|
||||
| `markdown-it/markdown-it.min.js` | https://cdn.jsdelivr.net/npm/markdown-it@13.0.1/dist/markdown-it.min.js | 13.0.1 |
|
||||
| `highlightjs/highlight.min.js` | https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.9.0/highlight.min.js | 11.9.0 |
|
||||
| `highlightjs/styles/github{,-dark}.min.css` | https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.9.0/styles/ | 11.9.0 |
|
||||
| `highlightjs/languages/{python,javascript,java,go,bash}.min.js` | https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.9.0/languages/ | 11.9.0 |
|
||||
| `d3/d3.min.js` | https://cdn.jsdelivr.net/npm/d3@7/dist/d3.min.js (loaded lazily for the knowledge graph view) | 7.x |
|
||||
|
||||
Notes:
|
||||
|
||||
- The Inter font only ships the latin subset (CJK characters fall back to the
|
||||
system sans-serif via the font-family chain in `tailwind.config`).
|
||||
- Only `woff2` font files are shipped (no `ttf` fallback). woff2 is supported
|
||||
by all browsers released since 2014-2018 (Chrome 36+, Firefox 39+, Safari
|
||||
12+, Edge, Opera 26+). The only mainstream browser that lacks woff2 support
|
||||
is IE 11, which cannot run the rest of the console anyway. `all.min.css`
|
||||
still references the ttf paths as a `src:` fallback — those 404s are
|
||||
harmless and ignored by the browser once the woff2 loads.
|
||||
- `tailwind.min.js` is the official Tailwind Play CDN build (an in-browser JIT
|
||||
engine). It must be served as JS to keep the existing `tailwind.config = {}`
|
||||
customization working.
|
||||
- One external script remains in `channel/web/static/js/console.js`:
|
||||
`wwcdn.weixin.qq.com/.../wecom-aibot-sdk` — Tencent requires the WeCom Bot
|
||||
SDK to be loaded from their CDN, and it is only fetched when the user opens
|
||||
the WeCom Bot QR-login flow.
|
||||
2
channel/web/static/vendor/d3/d3.min.js
vendored
Normal file
9
channel/web/static/vendor/fontawesome/css/all.min.css
vendored
Normal file
BIN
channel/web/static/vendor/fontawesome/webfonts/fa-brands-400.woff2
vendored
Normal file
BIN
channel/web/static/vendor/fontawesome/webfonts/fa-regular-400.woff2
vendored
Normal file
BIN
channel/web/static/vendor/fontawesome/webfonts/fa-solid-900.woff2
vendored
Normal file
BIN
channel/web/static/vendor/fontawesome/webfonts/fa-v4compatibility.woff2
vendored
Normal file
BIN
channel/web/static/vendor/fonts/inter/inter-latin.woff2
vendored
Normal file
16
channel/web/static/vendor/fonts/inter/inter.css
vendored
Normal file
@@ -0,0 +1,16 @@
|
||||
/* Inter font (latin subset only).
|
||||
* Single variable font woff2 that covers weights 300/400/500/600/700.
|
||||
* Non-latin scripts (CJK, etc.) fall back to system sans-serif via the
|
||||
* font-family chain defined in tailwind.config (Inter, system-ui, ...).
|
||||
* Source: Google Fonts (Inter v20), redistributed locally to avoid runtime
|
||||
* dependency on fonts.googleapis.com / fonts.gstatic.com.
|
||||
*/
|
||||
|
||||
@font-face {
|
||||
font-family: 'Inter';
|
||||
font-style: normal;
|
||||
font-weight: 300 700;
|
||||
font-display: swap;
|
||||
src: url('./inter-latin.woff2') format('woff2');
|
||||
unicode-range: U+0000-00FF, U+0131, U+0152-0153, U+02BB-02BC, U+02C6, U+02DA, U+02DC, U+0304, U+0308, U+0329, U+2000-206F, U+2074, U+20AC, U+2122, U+2191, U+2193, U+2212, U+2215, U+FEFF, U+FFFD;
|
||||
}
|
||||