mirror of
https://github.com/zhayujie/chatgpt-on-wechat.git
synced 2026-06-02 09:48:22 +08:00
Compare commits
558 Commits
feat-cow-a
...
feat-i18n
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9e6a2cc2c0 | ||
|
|
7bf4ef3d05 | ||
|
|
126649f70f | ||
|
|
1827a2a31c | ||
|
|
fcf4eb78dc | ||
|
|
2ec6ea8045 | ||
|
|
0994a3586d | ||
|
|
29c4be6a3a | ||
|
|
c5b8e06891 | ||
|
|
54a20bca92 | ||
|
|
6e786bde90 | ||
|
|
b671b0d725 | ||
|
|
57f5692074 | ||
|
|
b0ac0731c7 | ||
|
|
3c161df526 | ||
|
|
aa3f48e93c | ||
|
|
5ae1e1adde | ||
|
|
fe8b8fe831 | ||
|
|
5aca54c083 | ||
|
|
458b1a1d88 | ||
|
|
3dd4b84179 | ||
|
|
99bddb79d6 | ||
|
|
136b0b89e8 | ||
|
|
c605b0b080 | ||
|
|
b7b8e3679c | ||
|
|
aeb6610ff4 | ||
|
|
e3eacc77d7 | ||
|
|
37661daf40 | ||
|
|
877b848370 | ||
|
|
5c163cc0fe | ||
|
|
6e04ea8240 | ||
|
|
d106465419 | ||
|
|
f39380cea7 | ||
|
|
bccce2d7cb | ||
|
|
6721dbdbcc | ||
|
|
83cd6ad158 | ||
|
|
116fb27257 | ||
|
|
8d67177a1b | ||
|
|
ad2db1a776 | ||
|
|
2e6d9e0f27 | ||
|
|
e05f85f3ce | ||
|
|
40c48a9a61 | ||
|
|
c9a7525d0b | ||
|
|
fd571ac539 | ||
|
|
c5a3f991c5 | ||
|
|
eb74b73351 | ||
|
|
9b31f45481 | ||
|
|
bc9c1691f5 | ||
|
|
73bf83d2ff | ||
|
|
36e1988fee | ||
|
|
aad6ef635e | ||
|
|
96659cd616 | ||
|
|
c8787b7de4 | ||
|
|
91d427c8f9 | ||
|
|
c8c0573dbd | ||
|
|
29af855ecd | ||
|
|
0a146a245d | ||
|
|
bd85fee7d7 | ||
|
|
571897e2fd | ||
|
|
840dabeccd | ||
|
|
069bffa3e8 | ||
|
|
cc10d230b0 | ||
|
|
2517f2add8 | ||
|
|
a534266025 | ||
|
|
8c25395805 | ||
|
|
36b913124b | ||
|
|
2fa6343fe5 | ||
|
|
06b84225a1 | ||
|
|
5b31da335d | ||
|
|
90773ab69f | ||
|
|
11d92bb22a | ||
|
|
b7734c3926 | ||
|
|
d3faf9c8dc | ||
|
|
bca97a1d14 | ||
|
|
ac9d0f18c5 | ||
|
|
09fa624797 | ||
|
|
b8333e351c | ||
|
|
a01423a196 | ||
|
|
7c35df7a82 | ||
|
|
2b90f377e6 | ||
|
|
fff7326209 | ||
|
|
c181e500bc | ||
|
|
16b7271826 | ||
|
|
4a1f62b185 | ||
|
|
d23a0754c1 | ||
|
|
3ffb563a44 | ||
|
|
4e42f2a017 | ||
|
|
a0dfdb79df | ||
|
|
a85c5f9d4e | ||
|
|
2720bba5b7 | ||
|
|
4634a7bc2f | ||
|
|
16d9b449c9 | ||
|
|
8761997757 | ||
|
|
19bba4abbc | ||
|
|
7839f0aac5 | ||
|
|
83def1db30 | ||
|
|
a0b29d1ffe | ||
|
|
f5479c56af | ||
|
|
246f0a45c8 | ||
|
|
fe871aad77 | ||
|
|
6f860e1bc4 | ||
|
|
249ea40ae3 | ||
|
|
20d8ae19a7 | ||
|
|
ad51aabfd7 | ||
|
|
1cf395c041 | ||
|
|
745179a5bf | ||
|
|
ff5d477fa5 | ||
|
|
907825601d | ||
|
|
c2ec26910a | ||
|
|
83f2aea123 | ||
|
|
a5c5439315 | ||
|
|
eca9b60235 | ||
|
|
d2d5d98d78 | ||
|
|
fb341b869b | ||
|
|
29e66cb186 | ||
|
|
307769b949 | ||
|
|
9a09e057d6 | ||
|
|
3e28659528 | ||
|
|
b861eef26f | ||
|
|
caaf006a49 | ||
|
|
b2429ec30c | ||
|
|
55aaf60a57 | ||
|
|
a5790d82f6 | ||
|
|
63f99af1e6 | ||
|
|
4eed2568aa | ||
|
|
fb7962c7f2 | ||
|
|
76e6b7b471 | ||
|
|
fccb7ff9ed | ||
|
|
3b12ef2e66 | ||
|
|
f9d099be1b | ||
|
|
c322c0e3a5 | ||
|
|
530fc20596 | ||
|
|
a23b4ed754 | ||
|
|
fc4f5077b0 | ||
|
|
6a553886da | ||
|
|
1065c7e722 | ||
|
|
a9c8a59f58 | ||
|
|
8730f7fd27 | ||
|
|
8f608223d7 | ||
|
|
a7cbd47a2f | ||
|
|
b80c3fe5a8 | ||
|
|
5080051e39 | ||
|
|
23bfc8d0ba | ||
|
|
80e9062041 | ||
|
|
67bd3420ed | ||
|
|
aea081703f | ||
|
|
f300d2a2d5 | ||
|
|
f150d7d83a | ||
|
|
4d1f059c0d | ||
|
|
bc7f953fcc | ||
|
|
f653483eea | ||
|
|
6b200fd36b | ||
|
|
161fc6cdf0 | ||
|
|
6f68ed6bce | ||
|
|
a4592ffdfe | ||
|
|
7cd7bd1a48 | ||
|
|
9eeca70292 | ||
|
|
02bfe30848 | ||
|
|
c9c99de3d9 | ||
|
|
8752f0cc60 | ||
|
|
5c65196e44 | ||
|
|
f5798bfe90 | ||
|
|
0e556b3468 | ||
|
|
31820f56e7 | ||
|
|
fd88828abd | ||
|
|
ae11159918 | ||
|
|
472a8605c0 | ||
|
|
e1760ba211 | ||
|
|
ce4c0a0aa4 | ||
|
|
64511593c4 | ||
|
|
b0e00dfceb | ||
|
|
fc465b463d | ||
|
|
68ce2e5232 | ||
|
|
81e8bb62ae | ||
|
|
2c13e1b923 | ||
|
|
a0748c2e3b | ||
|
|
40599bb751 | ||
|
|
f3c64ceea7 | ||
|
|
15c60de709 | ||
|
|
6dd316547f | ||
|
|
54c7676a44 | ||
|
|
d25b8966ce | ||
|
|
14a119c48c | ||
|
|
c82515a927 | ||
|
|
26e630c2dd | ||
|
|
13370d2056 | ||
|
|
35282db9e0 | ||
|
|
426fb88ce7 | ||
|
|
2384bd0e10 | ||
|
|
ba3f66d3d1 | ||
|
|
7293a0f670 | ||
|
|
9e86d46267 | ||
|
|
848430f062 | ||
|
|
abd21335c4 | ||
|
|
8fa95f058a | ||
|
|
d4e5ecd497 | ||
|
|
3830f76729 | ||
|
|
83f778fec9 | ||
|
|
cabd24605f | ||
|
|
ae20ba1148 | ||
|
|
3a50b64977 | ||
|
|
8692e74536 | ||
|
|
1c18bd9889 | ||
|
|
60e9d98d0a | ||
|
|
83f6625e0c | ||
|
|
acc09543b7 | ||
|
|
94d8c7e366 | ||
|
|
ea1a0c8b3d | ||
|
|
7bc88c17e4 | ||
|
|
33cf1bc4c3 | ||
|
|
9402e63fe1 | ||
|
|
90e4d494b2 | ||
|
|
da97e948ca | ||
|
|
89a07e8e74 | ||
|
|
3f3d0381e5 | ||
|
|
3649499dba | ||
|
|
a989d088fd | ||
|
|
f79a915136 | ||
|
|
12e8c3d449 | ||
|
|
4f7064575e | ||
|
|
070df826f1 | ||
|
|
fbe48a4b4e | ||
|
|
4dd497fb6d | ||
|
|
907882c0a7 | ||
|
|
d36d5aee3f | ||
|
|
c6824e5f5e | ||
|
|
199c21eede | ||
|
|
5162da5654 | ||
|
|
a1d82f6193 | ||
|
|
ea78e3d0c6 | ||
|
|
3497f00cb4 | ||
|
|
5355d45031 | ||
|
|
26693acc3f | ||
|
|
76e9fef3b2 | ||
|
|
c34308cbd4 | ||
|
|
5a10476010 | ||
|
|
46e80dceec | ||
|
|
90d1835353 | ||
|
|
845fadd0aa | ||
|
|
5748ded52c | ||
|
|
6a737fb734 | ||
|
|
3cd92ccda3 | ||
|
|
54e81aba11 | ||
|
|
d86cb4ded6 | ||
|
|
4d5375f6d6 | ||
|
|
424557fedb | ||
|
|
89251e603f | ||
|
|
a653ed07eb | ||
|
|
ad86deb014 | ||
|
|
9525dc7584 | ||
|
|
cd31dd27fd | ||
|
|
360e3670eb | ||
|
|
8dabe3b4c8 | ||
|
|
443e0c2806 | ||
|
|
9cc173cc4d | ||
|
|
b5f33e5ecd | ||
|
|
40dfc6860f | ||
|
|
1c02a04423 | ||
|
|
de0e45070c | ||
|
|
c169cc7d74 | ||
|
|
cd62ad76f6 | ||
|
|
dd25b0fb5b | ||
|
|
a38b22a6a2 | ||
|
|
830b8f2971 | ||
|
|
b058af122c | ||
|
|
174ee0cafc | ||
|
|
1c336380c0 | ||
|
|
3068880413 | ||
|
|
be596681e5 | ||
|
|
66b71c50e9 | ||
|
|
8744810b25 | ||
|
|
7f94d37c2e | ||
|
|
6d9b7baeb4 | ||
|
|
4470d4c352 | ||
|
|
d2a462a279 | ||
|
|
14ff2a15e7 | ||
|
|
6d1369900e | ||
|
|
1f17ebe69e | ||
|
|
1ae2918064 | ||
|
|
b6571e5cad | ||
|
|
7549d48cf1 | ||
|
|
00353dd0cb | ||
|
|
afd947195d | ||
|
|
e57ef37167 | ||
|
|
ef33a93654 | ||
|
|
61732aecfc | ||
|
|
6764c05c3f | ||
|
|
fa149cf4aa | ||
|
|
e4f9697d06 | ||
|
|
da061450e5 | ||
|
|
d09ae49287 | ||
|
|
511ee0bbaf | ||
|
|
3cb5a0fbd6 | ||
|
|
e06925ab85 | ||
|
|
184634e4e7 | ||
|
|
843c2d02cc | ||
|
|
8ea2455766 | ||
|
|
9dc9987d56 | ||
|
|
3458621147 | ||
|
|
079df5a47c | ||
|
|
ddb07c65a1 | ||
|
|
9b21cd222b | ||
|
|
90f736843f | ||
|
|
13c020eb61 | ||
|
|
dbc06dbe95 | ||
|
|
23d097bc1c | ||
|
|
db85b9808e | ||
|
|
df5bae37bc | ||
|
|
acc23b6051 | ||
|
|
61f2741afc | ||
|
|
4dd7ea886a | ||
|
|
1e8959fbcf | ||
|
|
48729678cf | ||
|
|
0684becaa7 | ||
|
|
db16bdf8cb | ||
|
|
f890318ed9 | ||
|
|
158510cbbe | ||
|
|
ce90cf7aa8 | ||
|
|
a3a3d006eb | ||
|
|
8fd029a4a1 | ||
|
|
2e1b52c1e5 | ||
|
|
3eb8348708 | ||
|
|
393f0c007c | ||
|
|
294e380288 | ||
|
|
4c1c42efac | ||
|
|
c062ca8c66 | ||
|
|
76dcb25103 | ||
|
|
c5b4f236db | ||
|
|
0974c940a8 | ||
|
|
cffa20d37e | ||
|
|
ef009edd29 | ||
|
|
3ca52b118d | ||
|
|
13f5fde4fb | ||
|
|
f512b55ec2 | ||
|
|
22b8ca0095 | ||
|
|
baf66a103d | ||
|
|
45faa9c1ff | ||
|
|
304381a88d | ||
|
|
fc9f54dbc8 | ||
|
|
7199dc187f | ||
|
|
e9ae066d53 | ||
|
|
d71ae406ff | ||
|
|
f3216904b3 | ||
|
|
5958b69ec9 | ||
|
|
7d4e2cb39a | ||
|
|
a483ec0cea | ||
|
|
c1421e0874 | ||
|
|
ce89869c3c | ||
|
|
b8b57e34ff | ||
|
|
bc7f627253 | ||
|
|
652156e398 | ||
|
|
9febb071c6 | ||
|
|
7d0e1568ac | ||
|
|
b4e711f411 | ||
|
|
1b5be1b981 | ||
|
|
49d8707c58 | ||
|
|
9192f6f7f7 | ||
|
|
05022e3745 | ||
|
|
5356e9ddeb | ||
|
|
52acf76e2c | ||
|
|
40cdbd3b45 | ||
|
|
5487c0befe | ||
|
|
8bb16c48c0 | ||
|
|
c6384363f9 | ||
|
|
8993e8ad3e | ||
|
|
289989d9f7 | ||
|
|
dc2ae0e6f1 | ||
|
|
9c966c152d | ||
|
|
4efae41048 | ||
|
|
b8437032e9 | ||
|
|
2d339ca81b | ||
|
|
d53abc9696 | ||
|
|
446c886d38 | ||
|
|
30c6d9b5ae | ||
|
|
5e42996b36 | ||
|
|
ceca7b85bf | ||
|
|
a4d54f58c8 | ||
|
|
005a0e1bad | ||
|
|
46d97fd57d | ||
|
|
72a26b6353 | ||
|
|
89a4033fbf | ||
|
|
39a5dc64bd | ||
|
|
d4bdd9b1b7 | ||
|
|
2f5ba87280 | ||
|
|
8b45d6c750 | ||
|
|
4ecd4df2d4 | ||
|
|
a42f31fe52 | ||
|
|
d4480b695e | ||
|
|
c4b5f7fbae | ||
|
|
ba915f2cc0 | ||
|
|
4b91140f31 | ||
|
|
9879878dd0 | ||
|
|
d78105d57c | ||
|
|
153c9e3565 | ||
|
|
c11623596d | ||
|
|
e791a77f77 | ||
|
|
b641bffb2c | ||
|
|
ee0c47ac1e | ||
|
|
eba90e9343 | ||
|
|
d8374d0fa5 | ||
|
|
fa61744c6d | ||
|
|
4fec55cc01 | ||
|
|
1767413712 | ||
|
|
734c8fa84f | ||
|
|
9a8d422554 | ||
|
|
b21e945c76 | ||
|
|
a02bf1ea09 | ||
|
|
eda82bac92 | ||
|
|
e8d4f7dc4f | ||
|
|
c4a93b7789 | ||
|
|
c3f9925097 | ||
|
|
2a0cf7511a | ||
|
|
d0a70d3339 | ||
|
|
f37e4675dd | ||
|
|
4e32f67eeb | ||
|
|
36d54cab52 | ||
|
|
9d8df10dcf | ||
|
|
45ea88e070 | ||
|
|
d5d0b947f5 | ||
|
|
f775f1f11e | ||
|
|
f1e888f3de | ||
|
|
71c8436e90 | ||
|
|
08c69f5e9b | ||
|
|
a50fafaca2 | ||
|
|
3c6781d240 | ||
|
|
3b8b5625f8 | ||
|
|
6be2034110 | ||
|
|
924dc79f00 | ||
|
|
ccb9030d3c | ||
|
|
8623287ac1 | ||
|
|
022c13f3a4 | ||
|
|
0687916e7f | ||
|
|
bb868b83ba | ||
|
|
24298130b9 | ||
|
|
6e5ee92ebd | ||
|
|
5b91fe04aa | ||
|
|
1623deb3ee | ||
|
|
4a16e05b7a | ||
|
|
f1c04bc60d | ||
|
|
84c6f31c76 | ||
|
|
9d528190bf | ||
|
|
0f23b209ad | ||
|
|
63d9325900 | ||
|
|
f342097f81 | ||
|
|
b4806c4366 | ||
|
|
ff37d8a577 | ||
|
|
a773eb7893 | ||
|
|
7c67513d24 | ||
|
|
6ed85029c5 | ||
|
|
e9c57ddf4d | ||
|
|
a33ce97ed9 | ||
|
|
b788a3dd4e | ||
|
|
fccfa92d7e | ||
|
|
8705bf0a70 | ||
|
|
9318138af7 | ||
|
|
269fa7d2d5 | ||
|
|
e99837a8b9 | ||
|
|
553861a2c4 | ||
|
|
628a85d1be | ||
|
|
2cb54514a4 | ||
|
|
6db22827f2 | ||
|
|
4cc6d5426b | ||
|
|
7d258b5202 | ||
|
|
c8d19ee0bc | ||
|
|
d891312032 | ||
|
|
5edbf4ce32 | ||
|
|
3ddbdd713d | ||
|
|
9ba107b511 | ||
|
|
c9adddb76a | ||
|
|
f0a12d5ff5 | ||
|
|
7cce224499 | ||
|
|
97397ca585 | ||
|
|
f2fbc602a8 | ||
|
|
925d728a86 | ||
|
|
f5f229871b | ||
|
|
9917552b4b | ||
|
|
adca89b973 | ||
|
|
29bfbecdc9 | ||
|
|
1a7a8c98d9 | ||
|
|
cddb38ac3d | ||
|
|
394853c0fb | ||
|
|
c0702c8b36 | ||
|
|
d610608391 | ||
|
|
9082eec91d | ||
|
|
f1a1413b5f | ||
|
|
c1e7f9af9b | ||
|
|
1c71c4e38b | ||
|
|
5e3eccb3f6 | ||
|
|
e1dc037eb9 | ||
|
|
97e9b4c801 | ||
|
|
52d7cad735 | ||
|
|
c0b1d270ba | ||
|
|
e59a2892e4 | ||
|
|
5fa0376a49 | ||
|
|
05a33042c8 | ||
|
|
ce58f23cbc | ||
|
|
b6fc9fa370 | ||
|
|
00ae38faae | ||
|
|
ab28ee58ab | ||
|
|
48db538a2e | ||
|
|
46945942e1 | ||
|
|
a24b26a1ef | ||
|
|
6f8421cdd5 | ||
|
|
284cd9bca9 | ||
|
|
23fd6b8d2b | ||
|
|
4f0ea5d756 | ||
|
|
6c218331b1 | ||
|
|
cea7fb7490 | ||
|
|
8acf2dbdfe | ||
|
|
0542700f90 | ||
|
|
5264f7ce18 | ||
|
|
051ffd78a3 | ||
|
|
bea95d4fae | ||
|
|
fdf7bc312f | ||
|
|
5b094e1097 | ||
|
|
9ad3968084 | ||
|
|
3958b6aae1 | ||
|
|
eaa413caf0 | ||
|
|
9095225b5b | ||
|
|
c529f86dbc | ||
|
|
e4fcfa356a | ||
|
|
8218cff7c1 | ||
|
|
6949bbcf39 | ||
|
|
480c60c0a7 | ||
|
|
eec10cb5db | ||
|
|
02c83d8689 | ||
|
|
72b1cacea1 | ||
|
|
c72cda3386 | ||
|
|
867442155e | ||
|
|
229b14b6fc | ||
|
|
158c87ab8b | ||
|
|
cb303e6109 | ||
|
|
a77a8741b5 | ||
|
|
3d63459c25 | ||
|
|
ce63de3c58 | ||
|
|
4b3b1219b5 | ||
|
|
73b069a76c | ||
|
|
101cf8d108 | ||
|
|
2e926dfb6e | ||
|
|
501866d12a | ||
|
|
39bcb0869f | ||
|
|
a7b99cde4e | ||
|
|
60abcd92a3 | ||
|
|
cdd36e7052 | ||
|
|
c6ac175ce4 | ||
|
|
46bcd87c23 | ||
|
|
ab74be8e33 | ||
|
|
d8298b3eab | ||
|
|
50e60e6d05 | ||
|
|
5d02acbf37 | ||
|
|
8901d91f96 | ||
|
|
b55021bb3d | ||
|
|
0ef51b85e6 | ||
|
|
c77566cc02 | ||
|
|
c1bcedfb51 | ||
|
|
08b592816b | ||
|
|
8ef788e799 | ||
|
|
3ce57ef851 |
2
.github/ISSUE_TEMPLATE/1.bug.yml
vendored
2
.github/ISSUE_TEMPLATE/1.bug.yml
vendored
@@ -79,8 +79,6 @@ body:
|
||||
description: |
|
||||
请确保你正确配置了该`channel`所需的配置项,所有可选的配置项都写在了[该文件中](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/config.py),请将所需配置项填写在根目录下的`config.json`文件中。
|
||||
options:
|
||||
- wx(个人微信, itchat)
|
||||
- wxy(个人微信, wechaty)
|
||||
- wechatmp(公众号, 订阅号)
|
||||
- wechatmp_service(公众号, 服务号)
|
||||
- terminal
|
||||
|
||||
11
.github/workflows/deploy-image-arm.yml
vendored
11
.github/workflows/deploy-image-arm.yml
vendored
@@ -19,7 +19,7 @@ env:
|
||||
|
||||
jobs:
|
||||
build-and-push-image:
|
||||
if: github.repository == 'zhayujie/chatgpt-on-wechat'
|
||||
if: github.repository == 'zhayujie/CowAgent'
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: read
|
||||
@@ -51,7 +51,12 @@ jobs:
|
||||
uses: docker/metadata-action@v4
|
||||
with:
|
||||
images: |
|
||||
${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
|
||||
${{ env.REGISTRY }}/zhayujie/chatgpt-on-wechat
|
||||
${{ env.REGISTRY }}/zhayujie/cowagent
|
||||
tags: |
|
||||
type=raw,value=latest-arm64,enable={{is_default_branch}}
|
||||
type=ref,event=branch,suffix=-arm64
|
||||
type=ref,event=tag,suffix=-arm64
|
||||
|
||||
- name: Build and push Docker image
|
||||
uses: docker/build-push-action@v3
|
||||
@@ -60,7 +65,7 @@ jobs:
|
||||
push: true
|
||||
file: ./docker/Dockerfile.latest
|
||||
platforms: linux/arm64
|
||||
tags: ${{ steps.meta.outputs.tags }}-arm64
|
||||
tags: ${{ steps.meta.outputs.tags }}
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
|
||||
- uses: actions/delete-package-versions@v4
|
||||
|
||||
13
.github/workflows/deploy-image.yml
vendored
13
.github/workflows/deploy-image.yml
vendored
@@ -16,10 +16,11 @@ on:
|
||||
env:
|
||||
REGISTRY: ghcr.io
|
||||
IMAGE_NAME: ${{ github.repository }}
|
||||
DOCKERHUB_IMAGE: zhayujie/chatgpt-on-wechat
|
||||
|
||||
jobs:
|
||||
build-and-push-image:
|
||||
if: github.repository == 'zhayujie/chatgpt-on-wechat'
|
||||
if: github.repository == 'zhayujie/CowAgent'
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: read
|
||||
@@ -47,8 +48,14 @@ jobs:
|
||||
uses: docker/metadata-action@v4
|
||||
with:
|
||||
images: |
|
||||
${{ env.IMAGE_NAME }}
|
||||
${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
|
||||
zhayujie/chatgpt-on-wechat
|
||||
zhayujie/cowagent
|
||||
${{ env.REGISTRY }}/zhayujie/chatgpt-on-wechat
|
||||
${{ env.REGISTRY }}/zhayujie/cowagent
|
||||
tags: |
|
||||
type=raw,value=latest,enable={{is_default_branch}}
|
||||
type=ref,event=branch
|
||||
type=ref,event=tag
|
||||
|
||||
- name: Build and push Docker image
|
||||
uses: docker/build-push-action@v3
|
||||
|
||||
13
.gitignore
vendored
13
.gitignore
vendored
@@ -3,16 +3,15 @@
|
||||
.vscode
|
||||
.venv
|
||||
.vs
|
||||
.wechaty/
|
||||
__pycache__/
|
||||
venv*
|
||||
*.pyc
|
||||
python
|
||||
config.json
|
||||
QR.png
|
||||
nohup.out
|
||||
tmp
|
||||
plugins.json
|
||||
itchat.pkl
|
||||
*.log
|
||||
logs/
|
||||
workspace
|
||||
@@ -33,8 +32,16 @@ plugins/banwords/lib/__pycache__
|
||||
!plugins/role
|
||||
!plugins/keyword
|
||||
!plugins/linkai
|
||||
!plugins/agent
|
||||
!plugins/cow_cli
|
||||
client_config.json
|
||||
ref/
|
||||
**/.dev.vars
|
||||
.cursor/
|
||||
local/
|
||||
node_modules/
|
||||
|
||||
# cow cli
|
||||
dist/
|
||||
build/
|
||||
*.egg-info/
|
||||
.cow.pid
|
||||
|
||||
897
README.md
897
README.md
@@ -1,738 +1,259 @@
|
||||
<p align="center"><img src= "https://github.com/user-attachments/assets/31fb4eab-3be4-477d-aa76-82cf62bfd12c" alt="Chatgpt-on-Wechat" width="600" /></p>
|
||||
<p align="center"><img src="https://github.com/user-attachments/assets/eca9a9ec-8534-4615-9e0f-96c5ac1d10a3" alt="CowAgent" width="420" /></p>
|
||||
|
||||
<p align="center">
|
||||
<a href="https://github.com/zhayujie/chatgpt-on-wechat/releases/latest"><img src="https://img.shields.io/github/v/release/zhayujie/chatgpt-on-wechat" alt="Latest release"></a>
|
||||
<a href="https://github.com/zhayujie/chatgpt-on-wechat/blob/master/LICENSE"><img src="https://img.shields.io/github/license/zhayujie/chatgpt-on-wechat" alt="License: MIT"></a>
|
||||
<a href="https://github.com/zhayujie/chatgpt-on-wechat"><img src="https://img.shields.io/github/stars/zhayujie/chatgpt-on-wechat?style=flat-square" alt="Stars"></a> <br/>
|
||||
<a href="https://github.com/zhayujie/CowAgent/releases/latest"><img src="https://img.shields.io/github/v/release/zhayujie/CowAgent" alt="Latest release"></a>
|
||||
<a href="https://github.com/zhayujie/CowAgent/blob/master/LICENSE"><img src="https://img.shields.io/github/license/zhayujie/CowAgent" alt="License: MIT"></a>
|
||||
<a href="https://github.com/zhayujie/CowAgent"><img src="https://img.shields.io/github/stars/zhayujie/CowAgent?style=flat-square" alt="Stars"></a> <br/>
|
||||
[English] | [<a href="docs/zh/README.md">中文</a>] | [<a href="docs/ja/README.md">日本語</a>]
|
||||
</p>
|
||||
|
||||
**chatgpt-on-wechat**(简称CoW)项目是基于大模型的智能对话机器人,支持自由切换多种模型,可接入网页、微信公众号、企业微信应用、飞书、钉钉中使用,能处理文本、语音、图片、文件等多模态消息,支持通过插件访问操作系统和互联网等外部资源,以及基于自有知识库定制企业AI应用。
|
||||
**CowAgent** is an open-source super AI assistant that proactively plans tasks, controls your computer and external services, creates and runs Skills, and grows alongside you through a personal knowledge base and long-term memory — a reference implementation of Agent Harness engineering.
|
||||
|
||||
# 简介
|
||||
CowAgent is lightweight, easy to deploy, and built to extend. Plug in any major LLM provider and run it 24/7 on a personal computer or server, across the web and all major IM platforms.
|
||||
|
||||
> 该项目既是一个可以开箱即用的对话机器人,也是一个支持高度扩展的AI应用框架,可以通过为项目添加大模型接口、接入渠道、自定义插件来灵活实现各种定制需求。支持的功能如下:
|
||||
|
||||
- ✅ **多端部署:** 有多种部署方式可选择且功能完备,目前已支持网页、微信公众号、企业微信应用、飞书、钉钉等部署方式
|
||||
- ✅ **基础对话:** 私聊及群聊的AI智能回复,支持多轮会话上下文记忆,基础模型支持OpenAI, Claude, Gemini, DeepSeek, 通义千问, Kimi, 文心一言, 讯飞星火, ChatGLM, MiniMax, GiteeAI, ModelScope, LinkAI
|
||||
- ✅ **语音能力:** 可识别语音消息,通过文字或语音回复,支持 openai(whisper/tts), azure, baidu, google 等多种语音模型
|
||||
- ✅ **图像能力:** 支持图片生成、图片识别、图生图,可选择 Dall-E-3, stable diffusion, replicate, midjourney, CogView-3, vision模型
|
||||
- ✅ **丰富插件:** 支持自定义插件扩展,已实现多角色切换、敏感词过滤、聊天记录总结、文档总结和对话、联网搜索、智能体等内置插件
|
||||
- ✅ **Agent能力:** 支持访问浏览器、终端、文件系统、搜索引擎等各类工具,并可通过多智能体协作完成复杂任务,基于 [AgentMesh](https://github.com/MinimalFuture/AgentMesh) 框架实现
|
||||
- ✅ **知识库:** 通过上传知识库自定义专属机器人,可作为数字分身、智能客服、企业智能体使用,基于 [LinkAI](https://link-ai.tech) 实现
|
||||
|
||||
## 声明
|
||||
|
||||
1. 本项目遵循 [MIT开源协议](/LICENSE),仅用于技术研究和学习,使用本项目时需遵守所在地法律法规、相关政策以及企业章程,禁止用于任何违法或侵犯他人权益的行为。任何个人、团队和企业,无论以何种方式使用该项目、对何对象提供服务,所产生的一切后果,本项目均不承担任何责任
|
||||
2. 境内使用该项目时,建议使用国内厂商的大模型服务,并进行必要的内容安全审核及过滤
|
||||
3. 本项目当前主要接入协同办公平台,推荐使用网页、公众号、企微自建应用、钉钉、飞书等接入通道,其他通道为历史产物暂不维护
|
||||
|
||||
## 演示
|
||||
|
||||
DEMO视频:https://cdn.link-ai.tech/doc/cow_demo.mp4
|
||||
|
||||
## 社区
|
||||
|
||||
添加小助手微信加入开源项目交流群:
|
||||
|
||||
<img width="140" src="https://img-1317903499.cos.ap-guangzhou.myqcloud.com/docs/open-community.png">
|
||||
<p align="center">
|
||||
<a href="https://cowagent.ai/">🌐 Website</a> ·
|
||||
<a href="https://docs.cowagent.ai/intro/index">📖 Docs</a> ·
|
||||
<a href="https://docs.cowagent.ai/guide/quick-start">🚀 Quick Start</a> ·
|
||||
<a href="https://skills.cowagent.ai/">🧩 Skill Hub</a> ·
|
||||
<a href="https://link-ai.tech/cowagent/create">☁️ Try Online</a>
|
||||
</p>
|
||||
|
||||
<br/>
|
||||
|
||||
# 企业服务
|
||||
## 🌟 Highlights
|
||||
|
||||
<a href="https://link-ai.tech" target="_blank"><img width="720" src="https://cdn.link-ai.tech/image/link-ai-intro.jpg"></a>
|
||||
|
||||
> [LinkAI](https://link-ai.tech/) 是面向企业和开发者的一站式AI智能体平台,聚合多模态大模型、知识库、Agent 插件、工作流等能力,支持一键接入主流平台并进行管理,支持SaaS、私有化部署等多种模式。
|
||||
>
|
||||
> LinkAI 目前已在智能客服、私域运营、企业效率助手等场景积累了丰富的AI解决方案,在消费、健康、文教、科技制造等各行业沉淀了大模型落地应用的最佳实践,致力于帮助更多企业和开发者拥抱 AI 生产力。
|
||||
|
||||
**产品咨询和企业服务** 可联系产品客服:
|
||||
|
||||
<img width="150" src="https://cdn.link-ai.tech/portal/linkai-customer-service.png">
|
||||
| Capability | Description |
|
||||
| :--- | :--- |
|
||||
| [Planning](https://docs.cowagent.ai/intro/architecture) | Decomposes complex tasks and executes them step by step, looping over tools until the goal is reached |
|
||||
| [Memory](https://docs.cowagent.ai/memory/index) | Three-tier architecture (context → daily → core), automatic Deep Dream distillation, hybrid keyword + vector retrieval |
|
||||
| [Knowledge](https://docs.cowagent.ai/knowledge/index) | Auto-curates structured knowledge into a Markdown wiki, builds an evolving knowledge graph with visual browsing |
|
||||
| [Skills](https://docs.cowagent.ai/skills/index) | One-click install from [Skill Hub](https://skills.cowagent.ai/), GitHub, ClawHub; or create custom skills via natural-language conversation |
|
||||
| [Tools](https://docs.cowagent.ai/tools/index) | Built-in file I/O, terminal, browser, scheduler, memory retrieval, web search, and 10+ more tools — with native MCP integration |
|
||||
| [Channels](https://docs.cowagent.ai/channels/index) | Integrates with Web, WeChat, Feishu, DingTalk, WeCom, QQ, Official Accounts, Telegram, and Slack |
|
||||
| Multimodal | First-class support for text, images, voice, and files — recognition, generation, and delivery |
|
||||
| [Models](https://docs.cowagent.ai/models/index) | Claude, GPT, Gemini, DeepSeek, Qwen, GLM, Kimi, MiniMax, Doubao, and more — swap providers from the Web console with one click |
|
||||
| [Deploy](https://docs.cowagent.ai/guide/quick-start) | One-line installer, unified Web console, multiple deployment modes (local, Docker, server) |
|
||||
|
||||
<br/>
|
||||
|
||||
# 🏷 更新日志
|
||||
## 🏗️ Architecture
|
||||
|
||||
>**2025.05.23:** [1.7.6版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.7.6) 优化web网页channel、新增 [AgentMesh多智能体插件](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/plugins/agent/README.md)、百度语音合成优化、企微应用`access_token`获取优化、支持`claude-4-sonnet`和`claude-4-opus`模型
|
||||
<img src="https://cdn.jsdelivr.net/gh/zhayujie/cowagent-assets@main/architecture/en/architecture.jpg" alt="CowAgent Architecture" width="750"/>
|
||||
|
||||
>**2025.04.11:** [1.7.5版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.7.5) 新增支持 [wechatferry](https://github.com/zhayujie/chatgpt-on-wechat/pull/2562) 协议、新增 deepseek 模型、新增支持腾讯云语音能力、新增支持 ModelScope 和 Gitee-AI API接口
|
||||
CowAgent is a complete **Agent Harness**: messages flow in through **Channels**; the **Agent Core** plans and reasons over memory, knowledge, and the available tools and skills; **Models** generate the response, which is sent back through the originating channel. Every layer is decoupled and independently extensible.
|
||||
|
||||
>**2024.12.13:** [1.7.4版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.7.4) 新增 Gemini 2.0 模型、新增web channel、解决内存泄漏问题、解决 `#reloadp` 命令重载不生效问题
|
||||
|
||||
>**2024.10.31:** [1.7.3版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.7.3) 程序稳定性提升、数据库功能、Claude模型优化、linkai插件优化、离线通知
|
||||
|
||||
更多更新历史请查看: [更新日志](/docs/version/release-notes.md)
|
||||
Read more in [Architecture](https://docs.cowagent.ai/intro/architecture).
|
||||
|
||||
<br/>
|
||||
|
||||
# 🚀 快速开始
|
||||
## 🚀 Quick Start
|
||||
|
||||
项目提供了一键安装、启动、管理程序的脚本,可以选择使用脚本快速运行,也可以根据详细指引一步步安装运行。
|
||||
A one-line installer takes care of dependencies, configuration, and startup:
|
||||
|
||||
- 详细文档:[快速开始](https://docs.link-ai.tech/cow/quick-start)
|
||||
|
||||
- 一键安装脚本说明:[一键安装脚本](https://github.com/zhayujie/chatgpt-on-wechat/wiki/%E4%B8%80%E9%94%AE%E5%AE%89%E8%A3%85%E5%90%AF%E5%8A%A8%E8%84%9A%E6%9C%AC)
|
||||
**Linux / macOS:**
|
||||
|
||||
```bash
|
||||
bash <(curl -sS https://cdn.link-ai.tech/code/cow/install.sh)
|
||||
bash <(curl -fsSL https://cdn.link-ai.tech/code/cow/run.sh)
|
||||
```
|
||||
|
||||
- 项目管理脚本说明:[项目管理脚本](https://github.com/zhayujie/chatgpt-on-wechat/wiki/%E9%A1%B9%E7%9B%AE%E7%AE%A1%E7%90%86%E8%84%9A%E6%9C%AC)
|
||||
**Windows (PowerShell):**
|
||||
|
||||
## 一、准备
|
||||
```powershell
|
||||
irm https://cdn.link-ai.tech/code/cow/run.ps1 | iex
|
||||
```
|
||||
|
||||
### 1. 模型账号
|
||||
|
||||
项目默认使用ChatGPT模型,需前往 [OpenAI平台](https://platform.openai.com/api-keys) 创建API Key并填入项目配置文件中。同时支持其他国内外产商以及第三方自定义模型接口,详情参考:[模型说明](#模型说明)。
|
||||
|
||||
同时支持使用 **LinkAI平台** 接口,可聚合使用 OpenAI、Claude、DeepSeek、Kimi、Qwen 等多种常用模型,并支持知识库、工作流、联网搜索、MJ绘图、文档总结等能力。修改配置即可一键启用,参考 [接入文档](https://link-ai.tech/platform/link-app/wechat)。
|
||||
|
||||
### 2.环境安装
|
||||
|
||||
支持 Linux、MacOS、Windows 系统,同时需安装 `Python`,Python版本需要在3.7以上,推荐使用3.9版本。
|
||||
|
||||
> 注意:选择Docker部署则无需安装python环境和下载源码,可直接快进到下一节。
|
||||
|
||||
**(1) 克隆项目代码:**
|
||||
**Docker:**
|
||||
|
||||
```bash
|
||||
git clone https://github.com/zhayujie/chatgpt-on-wechat
|
||||
cd chatgpt-on-wechat/
|
||||
curl -O https://cdn.link-ai.tech/code/cow/docker-compose.yml
|
||||
docker compose up -d
|
||||
```
|
||||
|
||||
若遇到网络问题可使用国内仓库地址:https://gitee.com/zhayujie/chatgpt-on-wechat
|
||||
Once started, open `http://localhost:9899` to access the **Web console** — your one-stop hub to chat with the Agent, configure models, connect channels, and install skills.
|
||||
|
||||
**(2) 安装核心依赖 (必选):**
|
||||
> Deploying on a server? Set `web_host` to `0.0.0.0` in `config.json` to make the console reachable from outside, and set `web_password` to protect it. Don't forget to open port `9899` in your firewall or security group.
|
||||
|
||||
> 📖 Detailed guides: [Quick Start](https://docs.cowagent.ai/guide/quick-start) · [Install from Source](https://docs.cowagent.ai/guide/manual-install) · [Upgrade](https://docs.cowagent.ai/guide/upgrade)
|
||||
|
||||
After installation, manage the service with the [cow CLI](https://docs.cowagent.ai/cli/index):
|
||||
|
||||
```bash
|
||||
pip3 install -r requirements.txt
|
||||
cow start | stop | restart # service control
|
||||
cow status | logs # status and logs
|
||||
cow update # pull latest code and restart
|
||||
cow skill install <name> # install a skill
|
||||
cow install-browser # install browser automation
|
||||
```
|
||||
|
||||
**(3) 拓展依赖 (可选,建议安装):**
|
||||
|
||||
```bash
|
||||
pip3 install -r requirements-optional.txt
|
||||
```
|
||||
如果某项依赖安装失败可注释掉对应的行后重试。
|
||||
|
||||
## 二、配置
|
||||
|
||||
配置文件的模板在根目录的`config-template.json`中,需复制该模板创建最终生效的 `config.json` 文件:
|
||||
|
||||
```bash
|
||||
cp config-template.json config.json
|
||||
```
|
||||
|
||||
然后在`config.json`中填入配置,以下是对默认配置的说明,可根据需要进行自定义修改(注意实际使用时请去掉注释,保证JSON格式的规范):
|
||||
|
||||
```bash
|
||||
# config.json 文件内容示例
|
||||
{
|
||||
"channel_type": "web", # 接入渠道类型,默认为web,支持修改为:terminal, wechatmp, wechatmp_service, wechatcom_app, dingtalk, feishu
|
||||
"model": "gpt-4o-mini", # 模型名称, 支持 gpt-4o-mini, gpt-4.1, gpt-4o, deepseek-reasoner, wenxin, xunfei, glm-4, claude-3-7-sonnet-latest, moonshot等
|
||||
"open_ai_api_key": "YOUR API KEY", # 如果使用openAI模型则填入上面创建的 OpenAI API KEY
|
||||
"open_ai_api_base": "https://api.openai.com/v1", # OpenAI接口代理地址,修改此项可接入第三方模型接口
|
||||
"proxy": "", # 代理客户端的ip和端口,国内环境开启代理的需要填写该项,如 "127.0.0.1:7890"
|
||||
"single_chat_prefix": ["bot", "@bot"], # 私聊时文本需要包含该前缀才能触发机器人回复
|
||||
"single_chat_reply_prefix": "[bot] ", # 私聊时自动回复的前缀,用于区分真人
|
||||
"group_chat_prefix": ["@bot"], # 群聊时包含该前缀则会触发机器人回复
|
||||
"group_name_white_list": ["ChatGPT测试群", "ChatGPT测试群2"], # 开启自动回复的群名称列表
|
||||
"group_chat_in_one_session": ["ChatGPT测试群"], # 支持会话上下文共享的群名称
|
||||
"image_create_prefix": ["画", "看", "找"], # 开启图片回复的前缀
|
||||
"conversation_max_tokens": 1000, # 支持上下文记忆的最多字符数
|
||||
"speech_recognition": false, # 是否开启语音识别
|
||||
"group_speech_recognition": false, # 是否开启群组语音识别
|
||||
"voice_reply_voice": false, # 是否使用语音回复语音
|
||||
"character_desc": "你是基于大语言模型的AI智能助手,旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。", # 系统提示词
|
||||
# 订阅欢迎语,公众号和企业微信channel中使用,当被订阅时会自动回复以下内容
|
||||
"subscribe_msg": "感谢您的关注!\n这里是AI智能助手,可以自由对话。\n支持语音对话。\n支持图片输入。\n支持图片输出,画字开头的消息将按要求创作图片。\n支持tool、角色扮演和文字冒险等丰富的插件。\n输入{trigger_prefix}#help 查看详细指令。",
|
||||
"use_linkai": false, # 是否使用LinkAI接口,默认关闭,设置为true后可对接LinkAI平台的智能体
|
||||
"linkai_api_key": "", # LinkAI Api Key
|
||||
"linkai_app_code": "" # LinkAI 应用或工作流的code
|
||||
}
|
||||
```
|
||||
|
||||
**详细配置说明:**
|
||||
|
||||
<details>
|
||||
<summary>1. 单聊配置</summary>
|
||||
|
||||
+ 个人聊天中,需要以 "bot"或"@bot" 为开头的内容触发机器人,对应配置项 `single_chat_prefix` (如果不需要以前缀触发可以填写 `"single_chat_prefix": [""]`)
|
||||
+ 机器人回复的内容会以 "[bot] " 作为前缀, 以区分真人,对应的配置项为 `single_chat_reply_prefix` (如果不需要前缀可以填写 `"single_chat_reply_prefix": ""`)
|
||||
</details>
|
||||
|
||||
|
||||
<details>
|
||||
<summary>2. 群聊配置</summary>
|
||||
|
||||
+ 群组聊天中,群名称需配置在 `group_name_white_list ` 中才能开启群聊自动回复。如果想对所有群聊生效,可以直接填写 `"group_name_white_list": ["ALL_GROUP"]`
|
||||
+ 默认只要被人 @ 就会触发机器人自动回复;另外群聊天中只要检测到以 "@bot" 开头的内容,同样会自动回复(方便自己触发),这对应配置项 `group_chat_prefix`
|
||||
+ 可选配置: `group_name_keyword_white_list`配置项支持模糊匹配群名称,`group_chat_keyword`配置项则支持模糊匹配群消息内容,用法与上述两个配置项相同。(Contributed by [evolay](https://github.com/evolay))
|
||||
+ `group_chat_in_one_session`:使群聊共享一个会话上下文,配置 `["ALL_GROUP"]` 则作用于所有群聊
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>3. 语音配置</summary>
|
||||
|
||||
+ 添加 `"speech_recognition": true` 将开启语音识别,默认使用openai的whisper模型识别为文字,同时以文字回复,该参数仅支持私聊 (注意由于语音消息无法匹配前缀,一旦开启将对所有语音自动回复,支持语音触发画图);
|
||||
+ 添加 `"group_speech_recognition": true` 将开启群组语音识别,默认使用openai的whisper模型识别为文字,同时以文字回复,参数仅支持群聊 (会匹配group_chat_prefix和group_chat_keyword, 支持语音触发画图);
|
||||
+ 添加 `"voice_reply_voice": true` 将开启语音回复语音(同时作用于私聊和群聊)
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>4. 其他配置</summary>
|
||||
|
||||
+ `model`: 模型名称,目前支持 `gpt-4o-mini`, `gpt-4.1`, `gpt-4o`, `gpt-3.5-turbo`, `wenxin` , `claude` , `gemini`, `glm-4`, `xunfei`, `moonshot`等,全部模型名称参考[common/const.py](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/common/const.py)文件
|
||||
+ `temperature`,`frequency_penalty`,`presence_penalty`: Chat API接口参数,详情参考[OpenAI官方文档。](https://platform.openai.com/docs/api-reference/chat)
|
||||
+ `proxy`:由于目前 `openai` 接口国内无法访问,需配置代理客户端的地址,详情参考 [#351](https://github.com/zhayujie/chatgpt-on-wechat/issues/351)
|
||||
+ 对于图像生成,在满足个人或群组触发条件外,还需要额外的关键词前缀来触发,对应配置 `image_create_prefix `
|
||||
+ 关于OpenAI对话及图片接口的参数配置(内容自由度、回复字数限制、图片大小等),可以参考 [对话接口](https://beta.openai.com/docs/api-reference/completions) 和 [图像接口](https://beta.openai.com/docs/api-reference/completions) 文档,在[`config.py`](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/config.py)中检查哪些参数在本项目中是可配置的。
|
||||
+ `conversation_max_tokens`:表示能够记忆的上下文最大字数(一问一答为一组对话,如果累积的对话字数超出限制,就会优先移除最早的一组对话)
|
||||
+ `rate_limit_chatgpt`,`rate_limit_dalle`:每分钟最高问答速率、画图速率,超速后排队按序处理。
|
||||
+ `clear_memory_commands`: 对话内指令,主动清空前文记忆,字符串数组可自定义指令别名。
|
||||
+ `hot_reload`: 程序退出后,暂存等于状态,默认关闭。
|
||||
+ `character_desc` 配置中保存着你对机器人说的一段话,他会记住这段话并作为他的设定,你可以为他定制任何人格 (关于会话上下文的更多内容参考该 [issue](https://github.com/zhayujie/chatgpt-on-wechat/issues/43))
|
||||
+ `subscribe_msg`:订阅消息,公众号和企业微信channel中请填写,当被订阅时会自动回复, 可使用特殊占位符。目前支持的占位符有{trigger_prefix},在程序中它会自动替换成bot的触发词。
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>5. LinkAI配置</summary>
|
||||
|
||||
+ `use_linkai`: 是否使用LinkAI接口,默认关闭,设置为true后可对接LinkAI平台的Agent,使用知识库、工作流、联网搜索、`Midjourney` 绘画等能力, 参考 [文档](https://link-ai.tech/platform/link-app/wechat)
|
||||
+ `linkai_api_key`: LinkAI Api Key,可在 [控制台](https://link-ai.tech/console/interface) 创建
|
||||
+ `linkai_app_code`: LinkAI 应用或工作流的code,选填
|
||||
</details>
|
||||
|
||||
注:完整配置项说明可在 [`config.py`](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/config.py) 文件中查看。
|
||||
|
||||
## 三、运行
|
||||
|
||||
### 1.本地运行
|
||||
|
||||
如果是个人计算机 **本地运行**,直接在项目根目录下执行:
|
||||
|
||||
```bash
|
||||
python3 app.py # windows环境下该命令通常为 python app.py
|
||||
```
|
||||
|
||||
运行后默认会启动一个web服务,可以通过访问 `http://localhost:9899/chat` 在网页端对话。如果需要接入其他应用通道只需修改 `config.json` 配置文件中的 `channel_type` 参数,详情参考:[通道说明](#通道说明)。
|
||||
|
||||
向机器人发送 `#help` 消息可以查看可用指令及插件的说明。
|
||||
|
||||
### 2.服务器部署
|
||||
|
||||
在服务器中可使用 `nohup` 命令在后台运行程序:
|
||||
|
||||
```bash
|
||||
nohup python3 app.py & tail -f nohup.out
|
||||
```
|
||||
|
||||
执行后程序运行于服务器后台,可通过 `ctrl+c` 关闭日志,不会影响后台程序的运行。使用 `ps -ef | grep app.py | grep -v grep` 命令可查看运行于后台的进程,如果想要重新启动程序可以先 `kill` 掉对应的进程。 日志关闭后如果想要再次打开只需输入 `tail -f nohup.out`。
|
||||
|
||||
此外,项目的 `scripts` 目录下有一键运行、关闭程序的脚本供使用。 运行后默认channel为web,通过可以通过修改配置文件进行切换。
|
||||
|
||||
|
||||
### 3.Docker部署
|
||||
|
||||
使用docker部署无需下载源码和安装依赖,只需要获取 `docker-compose.yml` 配置文件并启动容器即可。
|
||||
|
||||
> 前提是需要安装好 `docker` 及 `docker-compose`,安装成功后执行 `docker -v` 和 `docker-compose version` (或 `docker compose version`) 可查看到版本号。安装地址为 [docker官网](https://docs.docker.com/engine/install/) 。
|
||||
|
||||
**(1) 下载 docker-compose.yml 文件**
|
||||
|
||||
```bash
|
||||
wget https://cdn.link-ai.tech/code/cow/docker-compose.yml
|
||||
```
|
||||
|
||||
下载完成后打开 `docker-compose.yml` 填写所需配置,例如 `CHANNEL_TYPE`、`OPEN_AI_API_KEY` 和等配置。
|
||||
|
||||
**(2) 启动容器**
|
||||
|
||||
在 `docker-compose.yml` 所在目录下执行以下命令启动容器:
|
||||
|
||||
```bash
|
||||
sudo docker compose up -d # 若docker-compose为 1.X 版本,则执行 `sudo docker-compose up -d`
|
||||
```
|
||||
|
||||
运行命令后,会自动取 [docker hub](https://hub.docker.com/r/zhayujie/chatgpt-on-wechat) 拉取最新release版本的镜像。当执行 `sudo docker ps` 能查看到 NAMES 为 chatgpt-on-wechat 的容器即表示运行成功。最后执行以下命令可查看容器的运行日志:
|
||||
|
||||
```bash
|
||||
sudo docker logs -f chatgpt-on-wechat
|
||||
```
|
||||
|
||||
**(3) 插件使用**
|
||||
|
||||
如果需要在docker容器中修改插件配置,可通过挂载的方式完成,将 [插件配置文件](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/plugins/config.json.template)
|
||||
重命名为 `config.json`,放置于 `docker-compose.yml` 相同目录下,并在 `docker-compose.yml` 中的 `chatgpt-on-wechat` 部分下添加 `volumes` 映射:
|
||||
|
||||
```
|
||||
volumes:
|
||||
- ./config.json:/app/plugins/config.json
|
||||
```
|
||||
**注**:使用docker方式部署的详细教程可以参考:[docker部署CoW项目](https://www.wangpc.cc/ai/docker-deploy-cow/)
|
||||
|
||||
|
||||
## 模型说明
|
||||
|
||||
以下对所有可支持的模型的配置和使用方法进行说明,模型接口实现在项目的 `bot/` 目录下。
|
||||
>部分模型厂商接入有官方sdk和OpenAI兼容两种方式,建议使用OpenAI兼容的方式。
|
||||
|
||||
<details>
|
||||
<summary>OpenAI</summary>
|
||||
|
||||
1. API Key创建:在 [OpenAI平台](https://platform.openai.com/api-keys) 创建API Key
|
||||
|
||||
2. 填写配置
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "gpt-4.1-mini",
|
||||
"open_ai_api_key": "YOUR_API_KEY",
|
||||
"open_ai_api_base": "https://api.openai.com/v1",
|
||||
"bot_type": "chatGPT"
|
||||
}
|
||||
```
|
||||
|
||||
- `model`: 与OpenAI接口的 [model参数](https://platform.openai.com/docs/models) 一致,支持包括 o系列、gpt-4系列、gpt-3.5系列等模型
|
||||
- `open_ai_api_base`: 如果需要接入第三方代理接口,可通过修改该参数进行接入
|
||||
- `bot_type`: 使用OpenAI相关模型时无需填写。当使用第三方代理接口接入Claude等非OpenAI官方模型时,该参数设为 `chatGPT`
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>LinkAI</summary>
|
||||
|
||||
1. API Key创建:在 [LinkAI平台](https://link-ai.tech/console/interface) 创建API Key
|
||||
|
||||
2. 填写配置
|
||||
|
||||
```json
|
||||
{
|
||||
"use_linkai": true,
|
||||
"linkai_api_key": "YOUR API KEY",
|
||||
"linkai_app_code": "YOUR APP CODE"
|
||||
}
|
||||
```
|
||||
|
||||
+ `use_linkai`: 是否使用LinkAI接口,默认关闭,设置为true后可对接LinkAI平台的智能体,使用知识库、工作流、数据库、联网搜索、MCP工具等丰富的Agent能力, 参考 [文档](https://link-ai.tech/platform/link-app/wechat)
|
||||
+ `linkai_api_key`: LinkAI平台的API Key,可在 [控制台](https://link-ai.tech/console/interface) 中创建
|
||||
+ `linkai_app_code`: LinkAI智能体 (应用或工作流) 的code,选填。智能体创建可参考 [说明文档](https://docs.link-ai.tech/platform/quick-start)
|
||||
+ `model`: model字段填写空则直接使用智能体的模型,可在平台中灵活切换,[模型列表](https://link-ai.tech/console/models)中的全部模型均可使用
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>DeepSeek</summary>
|
||||
|
||||
1. API Key创建:在 [DeepSeek平台](https://platform.deepseek.com/api_keys) 创建API Key
|
||||
|
||||
2. 填写配置
|
||||
|
||||
```json
|
||||
{
|
||||
"bot_type": "chatGPT",
|
||||
"model": "deepseek-chat",
|
||||
"open_ai_api_key": "sk-xxxxxxxxxxx",
|
||||
"open_ai_api_base": "https://api.deepseek.com/v1"
|
||||
}
|
||||
```
|
||||
|
||||
- `bot_type`: OpenAI兼容方式
|
||||
- `model`: 可填 `deepseek-chat、deepseek-reasoner`,分别对应的是 V3 和 R1 模型
|
||||
- `open_ai_api_key`: DeepSeek平台的 API Key
|
||||
- `open_ai_api_base`: DeepSeek平台 BASE URL
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>Azure</summary>
|
||||
|
||||
1. API Key创建:在 [DeepSeek平台](https://platform.deepseek.com/api_keys) 创建API Key
|
||||
|
||||
2. 填写配置
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "",
|
||||
"use_azure_chatgpt": true,
|
||||
"open_ai_api_key": "e7ffc5dd84f14521a53f14a40231ea78",
|
||||
"open_ai_api_base": "https://linkai-240917.openai.azure.com/",
|
||||
"azure_deployment_id": "gpt-4.1",
|
||||
"azure_api_version": "2025-01-01-preview"
|
||||
}
|
||||
```
|
||||
|
||||
- `model`: 留空即可
|
||||
- `use_azure_chatgpt`: 设为 true
|
||||
- `open_ai_api_key`: Azure平台的密钥
|
||||
- `open_ai_api_base`: Azure平台的 BASE URL
|
||||
- `azure_deployment_id`: Azure平台部署的模型名称
|
||||
- `azure_api_version`: api版本以及以上参数可以在部署的 [模型配置](https://oai.azure.com/resource/deployments) 界面查看
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>Claude</summary>
|
||||
|
||||
1. API Key创建:在 [Claude控制台](https://console.anthropic.com/settings/keys) 创建API Key
|
||||
|
||||
2. 填写配置
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "claude-sonnet-4-0",
|
||||
"claude_api_key": "YOUR_API_KEY"
|
||||
}
|
||||
```
|
||||
- `model`: 参考 [官方模型ID](https://docs.anthropic.com/en/docs/about-claude/models/overview#model-aliases) ,例如`claude-opus-4-0`、`claude-3-7-sonnet-latest`等
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>通义千问</summary>
|
||||
|
||||
方式一:官方SDK接入,配置如下:
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "qwen-turbo",
|
||||
"dashscope_api_key": "sk-qVxxxxG"
|
||||
}
|
||||
```
|
||||
- `model`: 可填写`qwen-turbo、qwen-plus、qwen-max`
|
||||
- `dashscope_api_key`: 通义千问的 API-KEY,参考 [官方文档](https://bailian.console.aliyun.com/?tab=api#/api) ,在 [控制台](https://bailian.console.aliyun.com/?tab=model#/api-key) 创建
|
||||
|
||||
方式二:OpenAI兼容方式接入,配置如下:
|
||||
```json
|
||||
{
|
||||
"bot_type": "chatGPT",
|
||||
"model": "qwen-turbo",
|
||||
"open_ai_api_base": "https://dashscope.aliyuncs.com/compatible-mode/v1",
|
||||
"open_ai_api_key": "sk-qVxxxxG"
|
||||
}
|
||||
```
|
||||
- `bot_type`: OpenAI兼容方式
|
||||
- `model`: 支持官方所有模型,参考[模型列表](https://help.aliyun.com/zh/model-studio/models?spm=a2c4g.11186623.0.0.78d84823Kth5on#9f8890ce29g5u)
|
||||
- `open_ai_api_base`: 通义千问API的 BASE URL
|
||||
- `open_ai_api_key`: 通义千问的 API-KEY,参考 [官方文档](https://bailian.console.aliyun.com/?tab=api#/api) ,在 [控制台](https://bailian.console.aliyun.com/?tab=model#/api-key) 创建
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>Gemini</summary>
|
||||
|
||||
API Key创建:在 [控制台](https://aistudio.google.com/app/apikey?hl=zh-cn) 创建API Key ,配置如下
|
||||
```json
|
||||
{
|
||||
"model": "gemini-2.5-pro",
|
||||
"gemini_api_key": ""
|
||||
}
|
||||
```
|
||||
- `model`: 参考[官方文档-模型列表](https://ai.google.dev/gemini-api/docs/models?hl=zh-cn)
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>Moonshot</summary>
|
||||
|
||||
方式一:官方接入,配置如下:
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "moonshot-v1-8k",
|
||||
"moonshot_api_key": "moonshot-v1-8k"
|
||||
}
|
||||
```
|
||||
- `model`: 可填写`moonshot-v1-8k、 moonshot-v1-32k、 moonshot-v1-128k`
|
||||
- `moonshot_api_key`: Moonshot的API-KEY,在 [控制台](https://platform.moonshot.cn/console/api-keys) 创建
|
||||
|
||||
方式二:OpenAI兼容方式接入,配置如下:
|
||||
```json
|
||||
{
|
||||
"bot_type": "chatGPT",
|
||||
"model": "moonshot-v1-8k",
|
||||
"open_ai_api_base": "https://api.moonshot.cn/v1",
|
||||
"open_ai_api_key": ""
|
||||
}
|
||||
```
|
||||
- `bot_type`: OpenAI兼容方式
|
||||
- `model`: 可填写`moonshot-v1-8k、 moonshot-v1-32k、 moonshot-v1-128k`
|
||||
- `open_ai_api_base`: Moonshot的 BASE URL
|
||||
- `open_ai_api_key`: Moonshot的 API-KEY,在 [控制台](https://platform.moonshot.cn/console/api-keys) 创建
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>百度文心</summary>
|
||||
方式一:官方SDK接入,配置如下:
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "wenxin",
|
||||
"baidu_wenxin_api_key": "IajztZ0bDxgnP9bEykU7lBer",
|
||||
"baidu_wenxin_secret_key": "EDPZn6L24uAS9d8RWFfotK47dPvkjD6G"
|
||||
}
|
||||
```
|
||||
- `model`: 可填 `wenxin`和`wenxin-4`,对应模型为 文心-3.5 和 文心-4.0
|
||||
- `baidu_wenxin_api_key`:参考 [千帆平台-access_token鉴权](https://cloud.baidu.com/doc/WENXINWORKSHOP/s/dlv4pct3s) 文档获取 API Key
|
||||
- `baidu_wenxin_secret_key`:参考 [千帆平台-access_token鉴权](https://cloud.baidu.com/doc/WENXINWORKSHOP/s/dlv4pct3s) 文档获取 Secret Key
|
||||
|
||||
方式二:OpenAI兼容方式接入,配置如下:
|
||||
```json
|
||||
{
|
||||
"bot_type": "chatGPT",
|
||||
"model": "qwen-turbo",
|
||||
"open_ai_api_base": "https://qianfan.baidubce.com/v2",
|
||||
"open_ai_api_key": "bce-v3/ALTxxxxxxd2b"
|
||||
}
|
||||
```
|
||||
- `bot_type`: OpenAI兼容方式
|
||||
- `model`: 支持官方所有模型,参考[模型列表](https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Wm9cvy6rl)
|
||||
- `open_ai_api_base`: 百度文心API的 BASE URL
|
||||
- `open_ai_api_key`: 百度文心的 API-KEY,参考 [官方文档](https://cloud.baidu.com/doc/qianfan-api/s/ym9chdsy5) ,在 [控制台](https://console.bce.baidu.com/iam/#/iam/apikey/list) 创建API Key
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>讯飞星火</summary>
|
||||
|
||||
方式一:官方接入,配置如下:
|
||||
参考 [官方文档-快速指引](https://www.xfyun.cn/doc/platform/quickguide.html#%E7%AC%AC%E4%BA%8C%E6%AD%A5-%E5%88%9B%E5%BB%BA%E6%82%A8%E7%9A%84%E7%AC%AC%E4%B8%80%E4%B8%AA%E5%BA%94%E7%94%A8-%E5%BC%80%E5%A7%8B%E4%BD%BF%E7%94%A8%E6%9C%8D%E5%8A%A1) 获取 `APPID、 APISecret、 APIKey` 三个参数
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "xunfei",
|
||||
"xunfei_app_id": "",
|
||||
"xunfei_api_key": "",
|
||||
"xunfei_api_secret": "",
|
||||
"xunfei_domain": "4.0Ultra",
|
||||
"xunfei_spark_url": "wss://spark-api.xf-yun.com/v4.0/chat"
|
||||
}
|
||||
```
|
||||
- `model`: 填 `xunfei`
|
||||
- `xunfei_domain`: 可填写 `4.0Ultra、 generalv3.5、 max-32k、 generalv3、 pro-128k、 lite`
|
||||
- `xunfei_spark_url`: 填写参考 [官方文档-请求地址](https://www.xfyun.cn/doc/spark/Web.html#_1-1-%E8%AF%B7%E6%B1%82%E5%9C%B0%E5%9D%80) 的说明
|
||||
|
||||
方式二:OpenAI兼容方式接入,配置如下:
|
||||
```json
|
||||
{
|
||||
"bot_type": "chatGPT",
|
||||
"model": "4.0Ultra",
|
||||
"open_ai_api_base": "https://spark-api-open.xf-yun.com/v1",
|
||||
"open_ai_api_key": ""
|
||||
}
|
||||
```
|
||||
- `bot_type`: OpenAI兼容方式
|
||||
- `model`: 可填写 `4.0Ultra、 generalv3.5、 max-32k、 generalv3、 pro-128k、 lite`
|
||||
- `open_ai_api_base`: 讯飞星火平台的 BASE URL
|
||||
- `open_ai_api_key`: 讯飞星火平台的[APIPassword](https://console.xfyun.cn/services/bm3) ,因模型而已
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>智谱AI</summary>
|
||||
|
||||
方式一:官方接入,配置如下:
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "glm-4-plus",
|
||||
"zhipu_ai_api_key": ""
|
||||
}
|
||||
```
|
||||
- `model`: 可填 `glm-4-plus、glm-4-air-250414、glm-4-airx、glm-4-long 、glm-4-flashx 、glm-4-flash-250414`, 参考 [glm-4系列模型编码](https://bigmodel.cn/dev/api/normal-model/glm-4)
|
||||
- `zhipu_ai_api_key`: 智谱AI平台的 API KEY,在 [控制台](https://www.bigmodel.cn/usercenter/proj-mgmt/apikeys) 创建
|
||||
|
||||
方式二:OpenAI兼容方式接入,配置如下:
|
||||
```json
|
||||
{
|
||||
"bot_type": "chatGPT",
|
||||
"model": "glm-4-plus",
|
||||
"open_ai_api_base": "https://open.bigmodel.cn/api/paas/v4",
|
||||
"open_ai_api_key": ""
|
||||
}
|
||||
```
|
||||
- `bot_type`: OpenAI兼容方式
|
||||
- `model`: 可填 `glm-4-plus、glm-4-air-250414、glm-4-airx、glm-4-long 、glm-4-flashx 、glm-4-flash-250414`, 参考 [glm-4系列模型编码](https://bigmodel.cn/dev/api/normal-model/glm-4)
|
||||
- `open_ai_api_base`: 智谱AI平台的 BASE URL
|
||||
- `open_ai_api_key`: 智谱AI平台的 API KEY,在 [控制台](https://www.bigmodel.cn/usercenter/proj-mgmt/apikeys) 创建
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>MiniMax</summary>
|
||||
|
||||
方式一:官方接入,配置如下:
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "abab6.5-chat",
|
||||
"Minimax_api_key": "",
|
||||
"Minimax_group_id": ""
|
||||
}
|
||||
```
|
||||
- `model`: 可填写`abab6.5-chat`
|
||||
- `Minimax_api_key`:MiniMax平台的API-KEY,在 [控制台](https://platform.minimaxi.com/user-center/basic-information/interface-key) 创建
|
||||
- `Minimax_group_id`: 在 [账户信息](https://platform.minimaxi.com/user-center/basic-information) 右上角获取
|
||||
|
||||
方式二:OpenAI兼容方式接入,配置如下:
|
||||
```json
|
||||
{
|
||||
"bot_type": "chatGPT",
|
||||
"model": "MiniMax-M1",
|
||||
"open_ai_api_base": "https://api.minimaxi.com/v1",
|
||||
"open_ai_api_key": ""
|
||||
}
|
||||
```
|
||||
- `bot_type`: OpenAI兼容方式
|
||||
- `model`: 可填`MiniMax-M1、MiniMax-Text-01`,参考[API文档](https://platform.minimaxi.com/document/%E5%AF%B9%E8%AF%9D?key=66701d281d57f38758d581d0#QklxsNSbaf6kM4j6wjO5eEek)
|
||||
- `open_ai_api_base`: MiniMax平台API的 BASE URL
|
||||
- `open_ai_api_key`: MiniMax平台的API-KEY,在 [控制台](https://platform.minimaxi.com/user-center/basic-information/interface-key) 创建
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>ModelScope</summary>
|
||||
|
||||
```json
|
||||
{
|
||||
"bot_type": "modelscope",
|
||||
"model": "Qwen/QwQ-32B",
|
||||
"modelscope_api_key": "your_api_key",
|
||||
"modelscope_base_url": "https://api-inference.modelscope.cn/v1/chat/completions",
|
||||
"text_to_image": "MusePublic/489_ckpt_FLUX_1"
|
||||
}
|
||||
```
|
||||
|
||||
- `bot_type`: modelscope接口格式
|
||||
- `model`: 参考[模型列表](https://www.modelscope.cn/models?filter=inference_type&page=1)
|
||||
- `modelscope_api_key`: 参考 [官方文档-访问令牌](https://modelscope.cn/docs/accounts/token) ,在 [控制台](https://modelscope.cn/my/myaccesstoken)
|
||||
- `modelscope_base_url`: modelscope平台的 BASE URL
|
||||
- `text_to_image`: 图像生成模型,参考[模型列表](https://www.modelscope.cn/models?filter=inference_type&page=1)
|
||||
</details>
|
||||
|
||||
|
||||
## 通道说明
|
||||
|
||||
以下对可接入通道的配置方式进行说明,应用通道代码在项目的 `channel/` 目录下。
|
||||
|
||||
<details>
|
||||
<summary>Web</summary>
|
||||
|
||||
项目启动后默认运行web通道,配置如下:
|
||||
|
||||
```json
|
||||
{
|
||||
"channel_type": "web",
|
||||
"web_port": 9899
|
||||
}
|
||||
```
|
||||
- `web_port`: 默认为 9899,可按需更改,需要服务器防火墙和安全组放行该端口
|
||||
- 如本地运行,启动后请访问 `http://localhost:port/chat` ;如服务器运行,请访问 `http://ip:port/chat`
|
||||
> 注:请将上述 url 中的 ip 或者 port 替换为实际的值
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>Terminal</summary>
|
||||
|
||||
修改 `config.json` 中的 `channel_type` 字段:
|
||||
|
||||
```json
|
||||
{
|
||||
"channel_type": "terminal"
|
||||
}
|
||||
```
|
||||
|
||||
运行后可在终端与机器人进行对话。
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>微信公众号</summary>
|
||||
|
||||
本项目支持订阅号和服务号两种公众号,通过服务号(`wechatmp_service`)体验更佳。将下列配置加入 `config.json`:
|
||||
|
||||
```json
|
||||
{
|
||||
"channel_type": "wechatmp",
|
||||
"wechatmp_token": "TOKEN",
|
||||
"wechatmp_port": 80,
|
||||
"wechatmp_app_id": "APPID",
|
||||
"wechatmp_app_secret": "APPSECRET",
|
||||
"wechatmp_aes_key": ""
|
||||
}
|
||||
```
|
||||
- `channel_type`: 个人订阅号为`wechatmp`,企业服务号为`wechatmp_service`
|
||||
|
||||
详细步骤和参数说明参考 [微信公众号接入](https://docs.link-ai.tech/cow/multi-platform/wechat-mp)
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>企业微信应用</summary>
|
||||
|
||||
企业微信自建应用接入需在后台创建应用并启用消息回调,配置示例:
|
||||
|
||||
```json
|
||||
{
|
||||
"channel_type": "wechatcom_app",
|
||||
"wechatcom_corp_id": "CORPID",
|
||||
"wechatcomapp_token": "TOKEN",
|
||||
"wechatcomapp_port": 9898,
|
||||
"wechatcomapp_secret": "SECRET",
|
||||
"wechatcomapp_agent_id": "AGENTID",
|
||||
"wechatcomapp_aes_key": "AESKEY"
|
||||
}
|
||||
```
|
||||
详细步骤和参数说明参考 [企微自建应用接入](https://docs.link-ai.tech/cow/multi-platform/wechat-com)
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>钉钉</summary>
|
||||
|
||||
钉钉需要在开放平台创建智能机器人应用,将以下配置填入 `config.json`:
|
||||
|
||||
```json
|
||||
{
|
||||
"channel_type": "dingtalk",
|
||||
"dingtalk_client_id": "CLIENT_ID",
|
||||
"dingtalk_client_secret": "CLIENT_SECRET"
|
||||
}
|
||||
```
|
||||
详细步骤和参数说明参考 [钉钉接入](https://docs.link-ai.tech/cow/multi-platform/dingtalk)
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>飞书</summary>
|
||||
|
||||
通过自建应用接入AI相关能力到飞书应用中,默认已是飞书的企业用户,且具有企业管理权限,将以下配置填入 `config.json`::
|
||||
|
||||
```json
|
||||
{
|
||||
"channel_type": "feishu",
|
||||
"feishu_app_id": "APP_ID",
|
||||
"feishu_app_secret": "APP_SECRET",
|
||||
"feishu_token": "VERIFICATION_TOKEN",
|
||||
"feishu_port": 80
|
||||
}
|
||||
```
|
||||
详细步骤和参数说明参考 [飞书接入](https://docs.link-ai.tech/cow/multi-platform/feishu)
|
||||
</details>
|
||||
|
||||
<br/>
|
||||
|
||||
# 🔗 相关项目
|
||||
## 🤖 Models
|
||||
|
||||
- [bot-on-anything](https://github.com/zhayujie/bot-on-anything):轻量和高可扩展的大模型应用框架,支持接入Slack, Telegram, Discord, Gmail等海外平台,可作为本项目的补充使用。
|
||||
- [AgentMesh](https://github.com/MinimalFuture/AgentMesh):开源的多智能体(Multi-Agent)框架,可以通过多智能体团队的协同来解决复杂问题。本项目基于该框架实现了[Agent插件](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/plugins/agent/README.md),可访问终端、浏览器、文件系统、搜索引擎 等各类工具,并实现了多智能体协同。
|
||||
CowAgent supports all mainstream LLM providers. **Chat, vision, image generation, ASR/TTS, and embeddings** can each be routed to a different vendor. Providers are configured directly in the Web console — no manual file editing required.
|
||||
|
||||
| Provider | Featured Models | Chat | Vision | Image Gen | ASR | TTS | Embedding |
|
||||
| --- | --- | :-: | :-: | :-: | :-: | :-: | :-: |
|
||||
| [Claude](https://docs.cowagent.ai/models/claude) | claude-opus-4-8 | ✅ | ✅ | | | | |
|
||||
| [OpenAI](https://docs.cowagent.ai/models/openai) | gpt-5.5, o-series | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
|
||||
| [Gemini](https://docs.cowagent.ai/models/gemini) | gemini-3.5-flash | ✅ | ✅ | ✅ | | | |
|
||||
| [DeepSeek](https://docs.cowagent.ai/models/deepseek) | deepseek-v4-flash / pro | ✅ | | | | | |
|
||||
| [Qwen](https://docs.cowagent.ai/models/qwen) | qwen3.7-max | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
|
||||
| [GLM](https://docs.cowagent.ai/models/glm) | glm-5.1, glm-5v-turbo | ✅ | ✅ | | ✅ | | ✅ |
|
||||
| [Doubao](https://docs.cowagent.ai/models/doubao) | doubao-seed-2.0 series | ✅ | ✅ | ✅ | | | ✅ |
|
||||
| [Kimi](https://docs.cowagent.ai/models/kimi) | kimi-k2.6 | ✅ | ✅ | | | | |
|
||||
| [MiniMax](https://docs.cowagent.ai/models/minimax) | MiniMax-M2.7 | ✅ | ✅ | ✅ | | ✅ | |
|
||||
| [ERNIE](https://docs.cowagent.ai/models/qianfan) | ernie-5.1 | ✅ | ✅ | | | | |
|
||||
| [MiMo](https://docs.cowagent.ai/models/mimo) | mimo-v2.5 / pro | ✅ | ✅ | | | ✅ | |
|
||||
| [LinkAI](https://docs.cowagent.ai/models/linkai) | One key for 100+ models | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
|
||||
| [Custom](https://docs.cowagent.ai/models/custom) | Local models / third-party proxy | ✅ | | | | | |
|
||||
|
||||
> For details on each provider, see the [Models overview](https://docs.cowagent.ai/models/index).
|
||||
|
||||
# 🔎 常见问题
|
||||
<br/>
|
||||
|
||||
FAQs: <https://github.com/zhayujie/chatgpt-on-wechat/wiki/FAQs>
|
||||
## 💬 Channels
|
||||
|
||||
或直接在线咨询 [项目小助手](https://link-ai.tech/app/Kv2fXJcH) (知识库持续完善中,回复供参考)
|
||||
A single Agent instance can serve multiple channels in parallel. Most channels can be onboarded right from the Web console.
|
||||
|
||||
# 🛠️ 开发
|
||||
| Channel | Text | Image | File | Voice | Group |
|
||||
| --- | :-: | :-: | :-: | :-: | :-: |
|
||||
| [Web Console](https://docs.cowagent.ai/channels/web) (default) | ✅ | ✅ | ✅ | ✅ | |
|
||||
| [Telegram](https://docs.cowagent.ai/channels/telegram) | ✅ | ✅ | ✅ | ✅ | ✅ |
|
||||
| [Slack](https://docs.cowagent.ai/channels/slack) | ✅ | ✅ | ✅ | | ✅ |
|
||||
| [Discord](https://docs.cowagent.ai/channels/discord) | ✅ | ✅ | ✅ | | ✅ |
|
||||
| [WeChat](https://docs.cowagent.ai/channels/weixin) | ✅ | ✅ | ✅ | ✅ | |
|
||||
| [Feishu / Lark](https://docs.cowagent.ai/channels/feishu) | ✅ | ✅ | ✅ | ✅ | ✅ |
|
||||
| [DingTalk](https://docs.cowagent.ai/channels/dingtalk) | ✅ | ✅ | ✅ | ✅ | ✅ |
|
||||
| [WeCom Bot](https://docs.cowagent.ai/channels/wecom-bot) | ✅ | ✅ | ✅ | ✅ | ✅ |
|
||||
| [QQ](https://docs.cowagent.ai/channels/qq) | ✅ | ✅ | ✅ | | ✅ |
|
||||
| [WeCom App](https://docs.cowagent.ai/channels/wecom) | ✅ | ✅ | ✅ | ✅ | |
|
||||
| [WeChat Customer Service](https://docs.cowagent.ai/channels/wechat-kf) | ✅ | ✅ | ✅ | ✅ | |
|
||||
| [WeChat Official Account](https://docs.cowagent.ai/channels/wechatmp) | ✅ | ✅ | | ✅ | |
|
||||
|
||||
欢迎接入更多应用通道,参考 [Terminal代码](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/channel/terminal/terminal_channel.py) 新增自定义通道,实现接收和发送消息逻辑即可完成接入。 同时欢迎贡献新的插件,参考 [插件开发文档](https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins)。
|
||||
> See the [Channels overview](https://docs.cowagent.ai/channels/index) for setup details.
|
||||
|
||||
# ✉ 联系
|
||||
<img src="https://cdn.jsdelivr.net/gh/zhayujie/cowagent-assets@main/screenshots/en/web-console-chat.png" alt="CowAgent Web Console" width="800"/>
|
||||
|
||||
欢迎提交PR、Issues进行反馈,以及通过 🌟Star 支持并关注项目更新。项目运行遇到问题可以查看 [常见问题列表](https://github.com/zhayujie/chatgpt-on-wechat/wiki/FAQs) ,以及前往 [Issues](https://github.com/zhayujie/chatgpt-on-wechat/issues) 中搜索。个人开发者可加入开源交流群参与更多讨论,企业用户可联系[产品客服](https://cdn.link-ai.tech/portal/linkai-customer-service.png)咨询。
|
||||
*The Web console is the default channel and the unified entry point to configure models, channels, skills, memory, and more.*
|
||||
|
||||
# 🌟 贡献者
|
||||
<br/>
|
||||
|
||||

|
||||
## 🧠 Memory & Knowledge Base
|
||||
|
||||
**Long-term memory** uses a three-tier architecture: conversation context (short-term) → daily memory (mid-term) → MEMORY.md (long-term). A nightly **Deep Dream** pass distills scattered memories into refined long-term entries and a narrative journal. See [Long-term Memory](https://docs.cowagent.ai/memory/index) · [Deep Dream](https://docs.cowagent.ai/memory/deep-dream).
|
||||
|
||||
**Personal knowledge base** complements the time-ordered memory by organizing structured knowledge **by topic**. The Agent automatically curates valuable information from conversations, maintains cross-references and indexes, and the Web console offers an interactive knowledge-graph view. See [Personal Knowledge Base](https://docs.cowagent.ai/knowledge/index).
|
||||
|
||||
<table>
|
||||
<tr>
|
||||
<td width="50%">
|
||||
<img src="https://cdn.jsdelivr.net/gh/zhayujie/cowagent-assets@main/screenshots/en/web-console-memory.png" alt="Long-term Memory" />
|
||||
<p align="center"><em>Long-term Memory · Three-tier architecture + Deep Dream</em></p>
|
||||
</td>
|
||||
<td width="50%">
|
||||
<img src="https://cdn.jsdelivr.net/gh/zhayujie/cowagent-assets@main/screenshots/en/web-console-knowledge.png" alt="Personal Knowledge Base" />
|
||||
<p align="center"><em>Knowledge Base · Auto-curated Markdown wiki</em></p>
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
<br/>
|
||||
|
||||
## 🔧 Tools & Skills
|
||||
|
||||
**Tools** are atomic capabilities the Agent uses to interact with system resources. **Skills** are higher-level workflows defined by a manifest file that compose multiple tools to accomplish complex tasks.
|
||||
|
||||
### Tool System
|
||||
|
||||
**Built-in tools** cover file I/O (`read` / `write` / `edit` / `ls`), terminal (`bash`), file sending (`send`), memory retrieval (`memory`), environment variables (`env_config`), web fetching (`web_fetch`), scheduling (`scheduler`), web search (`web_search`), vision (`vision`), and browser automation (`browser`).
|
||||
|
||||
**MCP protocol** integrates the open ecosystem of [Model Context Protocol](https://modelcontextprotocol.io) servers. A single `mcp.json` is enough — supports stdio / SSE transports, hot reload, and zero-code integration.
|
||||
|
||||
Learn more: [Tools overview](https://docs.cowagent.ai/tools/index) · [MCP integration](https://docs.cowagent.ai/tools/mcp).
|
||||
|
||||
### Skills System
|
||||
|
||||
- **[Skill Hub](https://skills.cowagent.ai/)** — open skill marketplace: browse, search, install in one click
|
||||
- **GitHub / ClawHub / URL and more** — install skills from any source
|
||||
- **Conversational authoring** — generate custom skills through dialogue with `skill-creator`; turn any workflow or third-party API into a reusable skill
|
||||
|
||||
```bash
|
||||
/skill list # list installed skills
|
||||
/skill search <keyword> # search the marketplace
|
||||
/skill install <name> # one-click install
|
||||
```
|
||||
|
||||
Learn more: [Skills overview](https://docs.cowagent.ai/skills/index) · [Creating Skills](https://docs.cowagent.ai/skills/create).
|
||||
|
||||
<br/>
|
||||
|
||||
## 🏷 Changelog
|
||||
|
||||
> **2026.05.22:** [v2.0.9](https://github.com/zhayujie/CowAgent/releases/tag/2.0.9) — Model management, MCP protocol support, persistent browser sessions, new models (gpt-5.5, gemini-3.5-flash, qwen3.7-max), deployment hardening.
|
||||
|
||||
> **2026.05.06:** [v2.0.8](https://github.com/zhayujie/CowAgent/releases/tag/2.0.8) — Feishu channel overhaul (voice, streaming, QR onboarding), DeepSeek V4 and Baidu Qianfan support, scheduler tool upgrades.
|
||||
|
||||
> **2026.04.22:** [v2.0.7](https://github.com/zhayujie/CowAgent/releases/tag/2.0.7) — Built-in image generation (GPT Image 2, Nano Banana), new models (Kimi K2.6, Claude Opus 4.7, GLM 5.1), memory and knowledge enhancements.
|
||||
|
||||
> **2026.04.14:** [v2.0.6](https://github.com/zhayujie/CowAgent/releases/tag/2.0.6) — Knowledge base, Deep Dream memory distillation, smart context compression, multi-session Web console.
|
||||
|
||||
> **2026.04.01:** [v2.0.5](https://github.com/zhayujie/CowAgent/releases/tag/2.0.5) — Cow CLI, Skill Hub open source, browser tool, WeCom Bot QR onboarding.
|
||||
|
||||
> **2026.02.03:** [v2.0.0](https://github.com/zhayujie/CowAgent/releases/tag/2.0.0) — Major upgrade to a super Agent assistant with multi-step task planning, long-term memory, and the Skills framework.
|
||||
|
||||
Full history: [Release Notes](https://docs.cowagent.ai/releases/overview)
|
||||
|
||||
<br/>
|
||||
|
||||
## 🤝 Community & Support
|
||||
|
||||
[File an issue](https://github.com/zhayujie/CowAgent/issues) on GitHub, or scan the QR code below to join our WeChat community:
|
||||
|
||||
<img width="130" src="https://img-1317903499.cos.ap-guangzhou.myqcloud.com/docs/open-community.png">
|
||||
|
||||
<br/>
|
||||
|
||||
## 🔗 Related Projects
|
||||
|
||||
- **[Cow Skill Hub](https://github.com/zhayujie/cow-skill-hub)** — open skill marketplace for AI Agents; works with CowAgent, OpenClaw, Claude Code, and more
|
||||
- **[bot-on-anything](https://github.com/zhayujie/bot-on-anything)** — lightweight LLM application framework with integrations for Slack, Telegram, Discord, Gmail, and more
|
||||
- **[AgentMesh](https://github.com/MinimalFuture/AgentMesh)** — open-source multi-agent framework for solving complex problems through team collaboration
|
||||
|
||||
<br/>
|
||||
|
||||
## 🏢 Enterprise Services
|
||||
|
||||
[**LinkAI**](https://link-ai.tech/) is an all-in-one AI Agent platform for enterprises and developers, offering managed hosting and enterprise-grade support for CowAgent:
|
||||
|
||||
- **🚀 Zero-deployment hosted runtime** — spin up a [CowAgent online assistant](https://link-ai.tech/cowagent/create) in under a minute, no server required
|
||||
- **🧠 Agent infrastructure** — unified access to LLMs, knowledge bases, databases, skills, and workflows; plug-and-play building blocks that extend what CowAgent can do
|
||||
- **🏢 Team & enterprise features** — workspaces, role-based access, audit logs, and private deployment for production use cases
|
||||
|
||||
For enterprise inquiries: sales@simple-future.tech or [scan the QR code](https://cdn.link-ai.tech/consultant.jpg) to reach our team on WeChat.
|
||||
|
||||
<br/>
|
||||
|
||||
## 🛠️ Development & Contributing
|
||||
|
||||
Contributions are welcome — add a new channel by following the [Feishu channel reference](https://github.com/zhayujie/CowAgent/blob/master/channel/feishu/feishu_channel.py), or contribute new skills to [Skill Hub](https://skills.cowagent.ai/submit).
|
||||
|
||||
⭐ Star the project to follow updates, and feel free to open PRs and Issues.
|
||||
|
||||
## 🌟 Contributors
|
||||
|
||||

|
||||
|
||||
<br/>
|
||||
|
||||
## ⚠️ Disclaimer
|
||||
|
||||
1. This project is licensed under the [MIT License](/LICENSE) and is intended for technical research and learning. You are responsible for complying with applicable laws and regulations in your jurisdiction; the maintainers assume no liability for any consequences arising from use of this project.
|
||||
2. **Cost & safety:** Agent mode consumes substantially more tokens than regular chat — pick models that balance quality and cost. The Agent has access to your local operating system, so only deploy it in trusted environments.
|
||||
3. CowAgent is a pure open-source project and does not participate in, authorize, or issue any cryptocurrency.
|
||||
|
||||
<br/>
|
||||
|
||||
## 📌 Project Renaming Notice
|
||||
|
||||
This project was previously named `chatgpt-on-wechat` and is now officially **CowAgent**. The old GitHub URL redirects automatically; existing users may optionally run `git remote set-url origin https://github.com/zhayujie/CowAgent.git` to update the local remote.
|
||||
|
||||
3
agent/chat/__init__.py
Normal file
3
agent/chat/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from agent.chat.service import ChatService
|
||||
|
||||
__all__ = ["ChatService"]
|
||||
290
agent/chat/service.py
Normal file
290
agent/chat/service.py
Normal file
@@ -0,0 +1,290 @@
|
||||
"""
|
||||
ChatService - Wraps the Agent stream execution to produce CHAT protocol chunks.
|
||||
|
||||
Translates agent events (message_update, message_end, tool_execution_end, etc.)
|
||||
into the CHAT socket protocol format (content chunks with segment_id, tool_calls chunks).
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import Callable, Optional
|
||||
|
||||
from common.log import logger
|
||||
|
||||
|
||||
class ChatService:
|
||||
"""
|
||||
High-level service that runs an Agent for a given query and streams
|
||||
the results as CHAT protocol chunks via a callback.
|
||||
|
||||
Usage:
|
||||
svc = ChatService(agent_bridge)
|
||||
svc.run(query, session_id, send_chunk_fn)
|
||||
"""
|
||||
|
||||
def __init__(self, agent_bridge):
|
||||
"""
|
||||
:param agent_bridge: AgentBridge instance (manages agent lifecycle)
|
||||
"""
|
||||
self.agent_bridge = agent_bridge
|
||||
|
||||
def run(self, query: str, session_id: str, send_chunk_fn: Callable[[dict], None],
|
||||
channel_type: str = ""):
|
||||
"""
|
||||
Run the agent for *query* and stream results back via *send_chunk_fn*.
|
||||
|
||||
The method blocks until the agent finishes. After it returns the SDK
|
||||
will automatically send the final (streaming=false) message.
|
||||
|
||||
:param query: user query text
|
||||
:param session_id: session identifier for agent isolation
|
||||
:param send_chunk_fn: callable(chunk_data: dict) to send a streaming chunk
|
||||
:param channel_type: source channel (e.g. "web", "feishu") for persistence
|
||||
"""
|
||||
agent = self.agent_bridge.get_agent(session_id=session_id)
|
||||
if agent is None:
|
||||
raise RuntimeError("Failed to initialise agent for the session")
|
||||
|
||||
# Pass context metadata to model for downstream API requests
|
||||
if hasattr(agent, 'model'):
|
||||
agent.model.channel_type = channel_type or ""
|
||||
agent.model.session_id = session_id or ""
|
||||
|
||||
# State shared between the event callback and this method
|
||||
state = _StreamState()
|
||||
|
||||
def on_event(event: dict):
|
||||
"""Translate agent events into CHAT protocol chunks."""
|
||||
event_type = event.get("type")
|
||||
data = event.get("data", {})
|
||||
|
||||
if event_type == "reasoning_update":
|
||||
delta = data.get("delta", "")
|
||||
if delta:
|
||||
send_chunk_fn({
|
||||
"chunk_type": "reasoning",
|
||||
"delta": delta,
|
||||
"segment_id": state.segment_id,
|
||||
})
|
||||
|
||||
elif event_type == "message_update":
|
||||
# Incremental text delta
|
||||
delta = data.get("delta", "")
|
||||
if delta:
|
||||
send_chunk_fn({
|
||||
"chunk_type": "content",
|
||||
"delta": delta,
|
||||
"segment_id": state.segment_id,
|
||||
})
|
||||
|
||||
elif event_type == "message_end":
|
||||
# A content segment finished.
|
||||
tool_calls = data.get("tool_calls", [])
|
||||
if tool_calls:
|
||||
# After tool_calls are executed the next content will be
|
||||
# a new segment; collect tool results until turn_end.
|
||||
state.pending_tool_results = []
|
||||
|
||||
elif event_type == "file_to_send":
|
||||
url = data.get("url") or ""
|
||||
if url:
|
||||
fname = data.get("file_name") or "file"
|
||||
ft = data.get("file_type") or "file"
|
||||
if ft == "image":
|
||||
link = f""
|
||||
else:
|
||||
link = f"[{fname}]({url})"
|
||||
send_chunk_fn({
|
||||
"chunk_type": "content",
|
||||
"delta": "\n\n" + link + "\n\n",
|
||||
"segment_id": state.segment_id,
|
||||
})
|
||||
# Remove url so the model won't repeat it in its reply
|
||||
data.pop("url", None)
|
||||
|
||||
elif event_type == "tool_execution_start":
|
||||
# Notify the client that a tool is about to run (with its input args)
|
||||
tool_name = data.get("tool_name", "")
|
||||
arguments = data.get("arguments", {})
|
||||
# Cache arguments keyed by tool_call_id so tool_execution_end can include them
|
||||
tool_call_id = data.get("tool_call_id", tool_name)
|
||||
state.pending_tool_arguments[tool_call_id] = arguments
|
||||
send_chunk_fn({
|
||||
"chunk_type": "tool_start",
|
||||
"tool": tool_name,
|
||||
"arguments": arguments,
|
||||
})
|
||||
|
||||
elif event_type == "tool_execution_end":
|
||||
tool_name = data.get("tool_name", "")
|
||||
tool_call_id = data.get("tool_call_id", tool_name)
|
||||
# Retrieve cached arguments from the matching tool_execution_start event
|
||||
arguments = state.pending_tool_arguments.pop(tool_call_id, data.get("arguments", {}))
|
||||
result = data.get("result", "")
|
||||
status = data.get("status", "unknown")
|
||||
execution_time = data.get("execution_time", 0)
|
||||
elapsed_str = f"{execution_time:.2f}s"
|
||||
|
||||
# Serialise result to string if needed
|
||||
if not isinstance(result, str):
|
||||
import json
|
||||
try:
|
||||
result = json.dumps(result, ensure_ascii=False)
|
||||
except Exception:
|
||||
result = str(result)
|
||||
|
||||
tool_info = {
|
||||
"name": tool_name,
|
||||
"arguments": arguments,
|
||||
"result": result,
|
||||
"status": status,
|
||||
"elapsed": elapsed_str,
|
||||
}
|
||||
|
||||
if state.pending_tool_results is not None:
|
||||
state.pending_tool_results.append(tool_info)
|
||||
|
||||
elif event_type == "turn_end":
|
||||
has_tool_calls = data.get("has_tool_calls", False)
|
||||
if has_tool_calls and state.pending_tool_results:
|
||||
# Flush collected tool results as a single tool_calls chunk
|
||||
send_chunk_fn({
|
||||
"chunk_type": "tool_calls",
|
||||
"tool_calls": state.pending_tool_results,
|
||||
})
|
||||
state.pending_tool_results = None
|
||||
# Next content belongs to a new segment
|
||||
state.segment_id += 1
|
||||
|
||||
# Run the agent with our event callback ---------------------------
|
||||
logger.info(f"[ChatService] Starting agent run: session={session_id}, query={query[:80]}")
|
||||
|
||||
from config import conf
|
||||
max_context_turns = conf().get("agent_max_context_turns", 20)
|
||||
|
||||
# Get full system prompt with skills
|
||||
full_system_prompt = agent.get_full_system_prompt()
|
||||
|
||||
# Create a copy of messages for this execution
|
||||
with agent.messages_lock:
|
||||
messages_copy = agent.messages.copy()
|
||||
original_length = len(agent.messages)
|
||||
|
||||
from agent.protocol.agent_stream import AgentStreamExecutor
|
||||
|
||||
executor = AgentStreamExecutor(
|
||||
agent=agent,
|
||||
model=agent.model,
|
||||
system_prompt=full_system_prompt,
|
||||
tools=agent.tools,
|
||||
max_turns=agent.max_steps,
|
||||
on_event=on_event,
|
||||
messages=messages_copy,
|
||||
max_context_turns=max_context_turns,
|
||||
)
|
||||
|
||||
try:
|
||||
response = executor.run_stream(query)
|
||||
except Exception:
|
||||
# If executor cleared messages (context overflow), sync back
|
||||
if len(executor.messages) == 0:
|
||||
with agent.messages_lock:
|
||||
agent.messages.clear()
|
||||
logger.info("[ChatService] Cleared agent message history after executor recovery")
|
||||
raise
|
||||
|
||||
# Sync executor messages back to agent (thread-safe).
|
||||
# The executor may have trimmed context, making its list shorter than
|
||||
# original_length. In that case we must replace entirely — just
|
||||
# appending would leave stale pre-trim messages in agent.messages
|
||||
# and cause the same trim to fire on every subsequent request.
|
||||
with agent.messages_lock:
|
||||
trimmed = len(executor.messages) < original_length
|
||||
if trimmed:
|
||||
# Context was trimmed: the executor appended the new user
|
||||
# query *before* trimming, so the new messages (user +
|
||||
# assistant + tools) sit at the tail of the trimmed list.
|
||||
# We cannot simply slice at original_length (it exceeds the
|
||||
# list length). Instead, count how many messages the
|
||||
# executor added on top of the post-trim baseline.
|
||||
#
|
||||
# Timeline inside executor.run_stream:
|
||||
# 1. messages had `original_length` items
|
||||
# 2. append user query → original_length + 1
|
||||
# 3. _trim_messages() → some smaller number (includes the
|
||||
# user query because it belongs to the last turn)
|
||||
# 4. LLM replies / tool calls appended
|
||||
#
|
||||
# The user query message is always the first message of the
|
||||
# last turn (it cannot be trimmed away), so we locate it to
|
||||
# find where "new" messages begin.
|
||||
new_start = original_length # fallback
|
||||
for idx in range(len(executor.messages) - 1, -1, -1):
|
||||
msg = executor.messages[idx]
|
||||
if msg.get("role") == "user":
|
||||
content = msg.get("content", [])
|
||||
is_user_query = False
|
||||
if isinstance(content, list):
|
||||
has_text = any(
|
||||
isinstance(b, dict) and b.get("type") == "text"
|
||||
for b in content
|
||||
)
|
||||
has_tool_result = any(
|
||||
isinstance(b, dict) and b.get("type") == "tool_result"
|
||||
for b in content
|
||||
)
|
||||
is_user_query = has_text and not has_tool_result
|
||||
elif isinstance(content, str):
|
||||
is_user_query = True
|
||||
if is_user_query:
|
||||
new_start = idx
|
||||
break
|
||||
new_messages = list(executor.messages[new_start:])
|
||||
else:
|
||||
new_messages = list(executor.messages[original_length:])
|
||||
agent.messages = list(executor.messages)
|
||||
|
||||
# Persist new messages to SQLite so they survive restarts and
|
||||
# can be queried via the HISTORY interface.
|
||||
if new_messages:
|
||||
self._persist_messages(session_id, list(new_messages), channel_type)
|
||||
|
||||
# Store executor reference for files_to_send access
|
||||
agent.stream_executor = executor
|
||||
|
||||
# Execute post-process tools
|
||||
agent._execute_post_process_tools()
|
||||
|
||||
logger.info(f"[ChatService] Agent run completed: session={session_id}")
|
||||
|
||||
|
||||
|
||||
@staticmethod
|
||||
def _persist_messages(session_id: str, new_messages: list, channel_type: str = ""):
|
||||
try:
|
||||
from config import conf
|
||||
if not conf().get("conversation_persistence", True):
|
||||
return
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
from agent.memory import get_conversation_store
|
||||
get_conversation_store().append_messages(
|
||||
session_id, new_messages, channel_type=channel_type
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"[ChatService] Failed to persist messages for session={session_id}: {e}"
|
||||
)
|
||||
|
||||
|
||||
class _StreamState:
|
||||
"""Mutable state shared between the event callback and the run method."""
|
||||
|
||||
def __init__(self):
|
||||
self.segment_id: int = 0
|
||||
# None means we are not accumulating tool results right now.
|
||||
# A list means we are in the middle of a tool-execution phase.
|
||||
self.pending_tool_results: Optional[list] = None
|
||||
# Maps tool_call_id -> arguments captured from tool_execution_start,
|
||||
# so that tool_execution_end can attach the correct input args.
|
||||
self.pending_tool_arguments: dict = {}
|
||||
241
agent/chat/session_service.py
Normal file
241
agent/chat/session_service.py
Normal file
@@ -0,0 +1,241 @@
|
||||
"""
|
||||
SessionService - Manages multi-session lifecycle for both web channel and cloud client.
|
||||
|
||||
Provides a unified interface for listing, deleting, renaming, clearing context,
|
||||
and generating AI titles for conversation sessions. Backed by ConversationStore
|
||||
(SQLite) and AgentBridge (in-memory agent instances).
|
||||
"""
|
||||
|
||||
import re
|
||||
from typing import Optional
|
||||
|
||||
from common.log import logger
|
||||
|
||||
|
||||
def _truncate_fallback_title(user_message: str, max_len: int = 30) -> str:
|
||||
"""Pick the first non-empty line of the user message and truncate it."""
|
||||
if not user_message:
|
||||
return "New Chat"
|
||||
first_line = ""
|
||||
for line in user_message.splitlines():
|
||||
line = line.strip()
|
||||
if line:
|
||||
first_line = line
|
||||
break
|
||||
if not first_line:
|
||||
return "New Chat"
|
||||
if len(first_line) > max_len:
|
||||
first_line = first_line[:max_len].rstrip() + "..."
|
||||
return first_line
|
||||
|
||||
|
||||
def generate_session_title(user_message: str, assistant_reply: str = "") -> str:
|
||||
"""
|
||||
Generate a short session title by calling the current bot's reply_text.
|
||||
Falls back to the first line of the user message if the LLM call fails
|
||||
or returns an obvious error sentinel.
|
||||
"""
|
||||
fallback = _truncate_fallback_title(user_message)
|
||||
try:
|
||||
from bridge.bridge import Bridge
|
||||
from models.session_manager import Session
|
||||
bot = Bridge().get_bot("chat")
|
||||
|
||||
prompt_parts = [f"User: {user_message[:300]}"]
|
||||
if assistant_reply:
|
||||
prompt_parts.append(f"Assistant: {assistant_reply[:300]}")
|
||||
|
||||
session = Session("__title_gen__", system_prompt="")
|
||||
session.messages = [
|
||||
{"role": "user", "content": (
|
||||
"Generate a very short title (max 15 characters for Chinese, max 6 words for English) "
|
||||
"summarizing this conversation. Return ONLY the title text, nothing else.\n\n"
|
||||
+ "\n".join(prompt_parts)
|
||||
)}
|
||||
]
|
||||
|
||||
result = bot.reply_text(session) or {}
|
||||
# When bots fail (network error, auth error, rate limit, etc.) they
|
||||
# typically return completion_tokens=0 with a sentinel content like
|
||||
# "请再问我一次吧" / "我现在有点累了". Treat that as failure.
|
||||
completion_tokens = result.get("completion_tokens", 0) or 0
|
||||
raw = (result.get("content") or "").strip()
|
||||
if completion_tokens <= 0:
|
||||
logger.warning(
|
||||
f"[SessionService] Title generation got empty completion "
|
||||
f"(completion_tokens={completion_tokens}, content='{raw[:50]}'), "
|
||||
f"using fallback")
|
||||
return fallback
|
||||
|
||||
title = re.sub(r'<think>.*?</think>', '', raw, flags=re.DOTALL).strip().strip('"\'')
|
||||
logger.info(f"[SessionService] Title generation result: '{title}' (len={len(title)})")
|
||||
if title and len(title) <= 50:
|
||||
return title
|
||||
except Exception as e:
|
||||
logger.warning(f"[SessionService] Title generation failed: {e}")
|
||||
return fallback
|
||||
|
||||
|
||||
class SessionService:
|
||||
"""
|
||||
High-level service for session lifecycle management.
|
||||
|
||||
Usage:
|
||||
svc = SessionService()
|
||||
result = svc.dispatch("list", {"channel_type": "web", "page": 1})
|
||||
"""
|
||||
|
||||
def _get_store(self):
|
||||
from agent.memory import get_conversation_store
|
||||
return get_conversation_store()
|
||||
|
||||
def _remove_agent(self, session_id: str):
|
||||
"""Remove the in-memory Agent instance for a session if it exists."""
|
||||
try:
|
||||
from bridge.bridge import Bridge
|
||||
ab = Bridge().get_agent_bridge()
|
||||
if session_id in ab.agents:
|
||||
del ab.agents[session_id]
|
||||
logger.info(f"[SessionService] Removed agent instance: {session_id}")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def _normalize_sid(session_id: str) -> str:
|
||||
if session_id and not session_id.startswith("session_"):
|
||||
return f"session_{session_id}"
|
||||
return session_id
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# actions
|
||||
# ------------------------------------------------------------------
|
||||
def list_sessions(self, channel_type: Optional[str] = None,
|
||||
page: int = 1, page_size: int = 50) -> dict:
|
||||
store = self._get_store()
|
||||
return store.list_sessions(
|
||||
channel_type=channel_type,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
|
||||
def delete_session(self, session_id: str) -> None:
|
||||
if not session_id:
|
||||
raise ValueError("session_id required")
|
||||
session_id = self._normalize_sid(session_id)
|
||||
|
||||
store = self._get_store()
|
||||
store.clear_session(session_id)
|
||||
self._remove_agent(session_id)
|
||||
logger.info(f"[SessionService] Session deleted: {session_id}")
|
||||
|
||||
def rename_session(self, session_id: str, title: str) -> None:
|
||||
if not session_id:
|
||||
raise ValueError("session_id required")
|
||||
if not title:
|
||||
raise ValueError("title required")
|
||||
session_id = self._normalize_sid(session_id)
|
||||
|
||||
store = self._get_store()
|
||||
found = store.rename_session(session_id, title)
|
||||
if not found:
|
||||
raise ValueError("session not found")
|
||||
|
||||
def clear_context(self, session_id: str) -> int:
|
||||
"""
|
||||
Set context boundary. Returns the new context_start_seq value.
|
||||
"""
|
||||
if not session_id:
|
||||
raise ValueError("session_id required")
|
||||
session_id = self._normalize_sid(session_id)
|
||||
|
||||
store = self._get_store()
|
||||
new_seq = store.clear_context(session_id)
|
||||
self._remove_agent(session_id)
|
||||
return new_seq
|
||||
|
||||
def gen_title(self, session_id: str, user_message: str,
|
||||
assistant_reply: str = "") -> str:
|
||||
"""
|
||||
Generate an AI title and persist it. Returns the generated title.
|
||||
"""
|
||||
if not session_id:
|
||||
raise ValueError("session_id required")
|
||||
if not user_message:
|
||||
raise ValueError("user_message required")
|
||||
session_id = self._normalize_sid(session_id)
|
||||
|
||||
title = generate_session_title(user_message, assistant_reply)
|
||||
|
||||
store = self._get_store()
|
||||
updated = store.rename_session(session_id, title)
|
||||
logger.info(f"[SessionService] Title set: sid={session_id}, "
|
||||
f"title='{title}', db_updated={updated}")
|
||||
return title
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# dispatch — single entry point for protocol messages
|
||||
# ------------------------------------------------------------------
|
||||
def dispatch(self, action: str, payload: Optional[dict] = None) -> dict:
|
||||
"""
|
||||
Dispatch a session management action and return a protocol-compatible
|
||||
response dict.
|
||||
|
||||
Action names use a ``*_session`` / session-prefixed convention so they
|
||||
can coexist with history actions (e.g. ``query``) on the same HISTORY
|
||||
message channel without ambiguity.
|
||||
|
||||
Supported actions:
|
||||
- list_sessions: list sessions with pagination
|
||||
- delete_session: delete a session
|
||||
- rename_session: rename a session title
|
||||
- clear_context: set context boundary
|
||||
- generate_title: AI-generate a session title
|
||||
|
||||
:param action: one of the above action names
|
||||
:param payload: action-specific payload
|
||||
:return: dict with action, code, message, payload
|
||||
"""
|
||||
payload = payload or {}
|
||||
try:
|
||||
if action == "list_sessions":
|
||||
result = self.list_sessions(
|
||||
channel_type=payload.get("channel_type"),
|
||||
page=int(payload.get("page", 1)),
|
||||
page_size=int(payload.get("page_size", 50)),
|
||||
)
|
||||
return {"action": action, "code": 200, "message": "success", "payload": result}
|
||||
|
||||
elif action == "delete_session":
|
||||
self.delete_session(payload.get("session_id", ""))
|
||||
return {"action": action, "code": 200, "message": "success", "payload": None}
|
||||
|
||||
elif action == "rename_session":
|
||||
self.rename_session(
|
||||
payload.get("session_id", ""),
|
||||
payload.get("title", "").strip(),
|
||||
)
|
||||
return {"action": action, "code": 200, "message": "success", "payload": None}
|
||||
|
||||
elif action == "clear_context":
|
||||
new_seq = self.clear_context(payload.get("session_id", ""))
|
||||
return {"action": action, "code": 200, "message": "success",
|
||||
"payload": {"context_start_seq": new_seq}}
|
||||
|
||||
elif action == "generate_title":
|
||||
title = self.gen_title(
|
||||
payload.get("session_id", ""),
|
||||
payload.get("user_message", ""),
|
||||
payload.get("assistant_reply", ""),
|
||||
)
|
||||
return {"action": action, "code": 200, "message": "success",
|
||||
"payload": {"title": title}}
|
||||
|
||||
else:
|
||||
return {"action": action, "code": 400,
|
||||
"message": f"unknown action: {action}", "payload": None}
|
||||
|
||||
except ValueError as e:
|
||||
return {"action": action, "code": 400, "message": str(e), "payload": None}
|
||||
except Exception as e:
|
||||
logger.error(f"[SessionService] dispatch error: action={action}, error={e}")
|
||||
return {"action": action, "code": 500, "message": str(e), "payload": None}
|
||||
0
agent/knowledge/__init__.py
Normal file
0
agent/knowledge/__init__.py
Normal file
240
agent/knowledge/service.py
Normal file
240
agent/knowledge/service.py
Normal file
@@ -0,0 +1,240 @@
|
||||
"""
|
||||
Knowledge service for handling knowledge base operations.
|
||||
|
||||
Provides a unified interface for listing, reading, and graphing knowledge files,
|
||||
callable from the web console, API, or CLI.
|
||||
|
||||
Knowledge file layout (under workspace_root):
|
||||
knowledge/index.md
|
||||
knowledge/log.md
|
||||
knowledge/<category>/<slug>.md
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from common.log import logger
|
||||
from config import conf
|
||||
|
||||
|
||||
class KnowledgeService:
|
||||
"""
|
||||
High-level service for knowledge base queries.
|
||||
Operates directly on the filesystem.
|
||||
"""
|
||||
|
||||
def __init__(self, workspace_root: str):
|
||||
self.workspace_root = workspace_root
|
||||
self.knowledge_dir = os.path.join(workspace_root, "knowledge")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# list — directory tree with stats
|
||||
# ------------------------------------------------------------------
|
||||
def list_tree(self) -> dict:
|
||||
"""
|
||||
Return the knowledge directory tree grouped by category,
|
||||
supporting arbitrarily nested sub-directories.
|
||||
|
||||
Returns::
|
||||
|
||||
{
|
||||
"tree": [
|
||||
{
|
||||
"dir": "concepts",
|
||||
"files": [
|
||||
{"name": "moe.md", "title": "MoE", "size": 1234},
|
||||
],
|
||||
"children": []
|
||||
},
|
||||
{
|
||||
"dir": "platform",
|
||||
"files": [],
|
||||
"children": [
|
||||
{
|
||||
"dir": "analysis",
|
||||
"files": [{"name": "perf.md", ...}],
|
||||
"children": []
|
||||
}
|
||||
]
|
||||
},
|
||||
],
|
||||
"stats": {"pages": 15, "size": 32768},
|
||||
"enabled": true
|
||||
}
|
||||
"""
|
||||
if not os.path.isdir(self.knowledge_dir):
|
||||
return {"tree": [], "stats": {"pages": 0, "size": 0}, "enabled": conf().get("knowledge", True)}
|
||||
|
||||
stats = {"pages": 0, "size": 0}
|
||||
root_files, tree = self._scan_dir(self.knowledge_dir, stats, is_root=True)
|
||||
|
||||
return {
|
||||
"root_files": root_files,
|
||||
"tree": tree,
|
||||
"stats": stats,
|
||||
"enabled": conf().get("knowledge", True),
|
||||
}
|
||||
|
||||
def _scan_dir(self, dir_path: str, stats: dict, is_root: bool = False) -> tuple:
|
||||
"""
|
||||
Recursively scan a directory.
|
||||
|
||||
:return: (files, children) where files is a list of .md file dicts
|
||||
in this directory and children is a list of sub-directory nodes.
|
||||
"""
|
||||
files = []
|
||||
children = []
|
||||
for name in sorted(os.listdir(dir_path)):
|
||||
if name.startswith("."):
|
||||
continue
|
||||
full = os.path.join(dir_path, name)
|
||||
if os.path.isdir(full):
|
||||
sub_files, sub_children = self._scan_dir(full, stats)
|
||||
children.append({"dir": name, "files": sub_files, "children": sub_children})
|
||||
elif name.endswith(".md"):
|
||||
size = os.path.getsize(full)
|
||||
if not is_root:
|
||||
stats["pages"] += 1
|
||||
stats["size"] += size
|
||||
title = name.replace(".md", "")
|
||||
try:
|
||||
with open(full, "r", encoding="utf-8") as f:
|
||||
first_line = f.readline().strip()
|
||||
if first_line.startswith("# "):
|
||||
title = first_line[2:].strip()
|
||||
except Exception:
|
||||
pass
|
||||
files.append({"name": name, "title": title, "size": size})
|
||||
return files, children
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# read — single file content
|
||||
# ------------------------------------------------------------------
|
||||
def read_file(self, rel_path: str) -> dict:
|
||||
"""
|
||||
Read a single knowledge markdown file.
|
||||
|
||||
:param rel_path: Relative path within knowledge/, e.g. ``concepts/moe.md``
|
||||
:return: dict with ``content`` and ``path``
|
||||
:raises ValueError: if path is invalid or escapes knowledge dir
|
||||
:raises FileNotFoundError: if file does not exist
|
||||
"""
|
||||
if not rel_path or ".." in rel_path:
|
||||
raise ValueError("invalid path")
|
||||
|
||||
full_path = os.path.normpath(os.path.join(self.knowledge_dir, rel_path))
|
||||
allowed = os.path.normpath(self.knowledge_dir)
|
||||
if not full_path.startswith(allowed + os.sep) and full_path != allowed:
|
||||
raise ValueError("path outside knowledge dir")
|
||||
|
||||
if not os.path.isfile(full_path):
|
||||
raise FileNotFoundError(f"file not found: {rel_path}")
|
||||
|
||||
with open(full_path, "r", encoding="utf-8") as f:
|
||||
content = f.read()
|
||||
return {"content": content, "path": rel_path}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# graph — nodes and links for visualization
|
||||
# ------------------------------------------------------------------
|
||||
def build_graph(self) -> dict:
|
||||
"""
|
||||
Parse all knowledge pages and extract cross-reference links.
|
||||
|
||||
Returns::
|
||||
|
||||
{
|
||||
"nodes": [
|
||||
{"id": "concepts/moe.md", "label": "MoE", "category": "concepts"},
|
||||
...
|
||||
],
|
||||
"links": [
|
||||
{"source": "concepts/moe.md", "target": "entities/deepseek.md"},
|
||||
...
|
||||
]
|
||||
}
|
||||
"""
|
||||
knowledge_path = Path(self.knowledge_dir)
|
||||
if not knowledge_path.is_dir():
|
||||
return {"nodes": [], "links": []}
|
||||
|
||||
nodes = {}
|
||||
links = []
|
||||
link_re = re.compile(r'\[([^\]]*)\]\(([^)]+\.md)\)')
|
||||
|
||||
for md_file in knowledge_path.rglob("*.md"):
|
||||
rel = str(md_file.relative_to(knowledge_path))
|
||||
if rel in ("index.md", "log.md"):
|
||||
continue
|
||||
parts = rel.split("/")
|
||||
category = parts[0] if len(parts) > 1 else "root"
|
||||
title = md_file.stem.replace("-", " ").title()
|
||||
try:
|
||||
content = md_file.read_text(encoding="utf-8")
|
||||
first_line = content.strip().split("\n")[0]
|
||||
if first_line.startswith("# "):
|
||||
title = first_line[2:].strip()
|
||||
for _, link_target in link_re.findall(content):
|
||||
resolved = (md_file.parent / link_target).resolve()
|
||||
try:
|
||||
target_rel = str(resolved.relative_to(knowledge_path))
|
||||
except ValueError:
|
||||
continue
|
||||
if target_rel != rel:
|
||||
links.append({"source": rel, "target": target_rel})
|
||||
except Exception:
|
||||
pass
|
||||
nodes[rel] = {"id": rel, "label": title, "category": category}
|
||||
|
||||
valid_ids = set(nodes.keys())
|
||||
links = [l for l in links if l["source"] in valid_ids and l["target"] in valid_ids]
|
||||
seen = set()
|
||||
deduped = []
|
||||
for l in links:
|
||||
key = tuple(sorted([l["source"], l["target"]]))
|
||||
if key not in seen:
|
||||
seen.add(key)
|
||||
deduped.append(l)
|
||||
|
||||
return {"nodes": list(nodes.values()), "links": deduped}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# dispatch — single entry point for protocol messages
|
||||
# ------------------------------------------------------------------
|
||||
def dispatch(self, action: str, payload: Optional[dict] = None) -> dict:
|
||||
"""
|
||||
Dispatch a knowledge management action.
|
||||
|
||||
:param action: ``list``, ``read``, or ``graph``
|
||||
:param payload: action-specific payload
|
||||
:return: protocol-compatible response dict
|
||||
"""
|
||||
payload = payload or {}
|
||||
try:
|
||||
if action == "list":
|
||||
result = self.list_tree()
|
||||
return {"action": action, "code": 200, "message": "success", "payload": result}
|
||||
|
||||
elif action == "read":
|
||||
path = payload.get("path")
|
||||
if not path:
|
||||
return {"action": action, "code": 400, "message": "path is required", "payload": None}
|
||||
result = self.read_file(path)
|
||||
return {"action": action, "code": 200, "message": "success", "payload": result}
|
||||
|
||||
elif action == "graph":
|
||||
result = self.build_graph()
|
||||
return {"action": action, "code": 200, "message": "success", "payload": result}
|
||||
|
||||
else:
|
||||
return {"action": action, "code": 400, "message": f"unknown action: {action}", "payload": None}
|
||||
|
||||
except ValueError as e:
|
||||
return {"action": action, "code": 403, "message": str(e), "payload": None}
|
||||
except FileNotFoundError as e:
|
||||
return {"action": action, "code": 404, "message": str(e), "payload": None}
|
||||
except Exception as e:
|
||||
logger.error(f"[KnowledgeService] dispatch error: action={action}, error={e}")
|
||||
return {"action": action, "code": 500, "message": str(e), "payload": None}
|
||||
@@ -1,11 +1,23 @@
|
||||
"""
|
||||
Memory module for AgentMesh
|
||||
|
||||
Provides long-term memory capabilities with hybrid search (vector + keyword)
|
||||
Provides both long-term memory (vector/keyword search) and short-term
|
||||
conversation history persistence (SQLite).
|
||||
"""
|
||||
|
||||
from agent.memory.manager import MemoryManager
|
||||
from agent.memory.config import MemoryConfig, get_default_memory_config, set_global_memory_config
|
||||
from agent.memory.embedding import create_embedding_provider
|
||||
from agent.memory.conversation_store import ConversationStore, get_conversation_store
|
||||
from agent.memory.summarizer import ensure_daily_memory_file
|
||||
|
||||
__all__ = ['MemoryManager', 'MemoryConfig', 'get_default_memory_config', 'set_global_memory_config', 'create_embedding_provider']
|
||||
__all__ = [
|
||||
'MemoryManager',
|
||||
'MemoryConfig',
|
||||
'get_default_memory_config',
|
||||
'set_global_memory_config',
|
||||
'create_embedding_provider',
|
||||
'ConversationStore',
|
||||
'get_conversation_store',
|
||||
'ensure_daily_memory_file',
|
||||
]
|
||||
|
||||
@@ -4,6 +4,7 @@ Text chunking utilities for memory
|
||||
Splits text into chunks with token limits and overlap
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
from typing import List, Tuple
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@@ -4,18 +4,25 @@ Memory configuration module
|
||||
Provides global memory configuration with simplified workspace structure
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional, List
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def _default_workspace():
|
||||
"""Get default workspace path with proper Windows support"""
|
||||
from common.utils import expand_path
|
||||
return expand_path("~/cow")
|
||||
|
||||
|
||||
@dataclass
|
||||
class MemoryConfig:
|
||||
"""Configuration for memory storage and search"""
|
||||
|
||||
# Storage paths (default: ~/cow)
|
||||
workspace_root: str = field(default_factory=lambda: os.path.expanduser("~/cow"))
|
||||
workspace_root: str = field(default_factory=_default_workspace)
|
||||
|
||||
# Embedding config
|
||||
embedding_provider: str = "openai" # "openai" | "local"
|
||||
@@ -41,9 +48,6 @@ class MemoryConfig:
|
||||
enable_auto_sync: bool = True
|
||||
sync_on_search: bool = True
|
||||
|
||||
# Memory flush config (独立于模型 context window)
|
||||
flush_token_threshold: int = 50000 # 50K tokens 触发 flush
|
||||
flush_turn_threshold: int = 20 # 20 轮对话触发 flush (用户+AI各一条为一轮)
|
||||
|
||||
def get_workspace(self) -> Path:
|
||||
"""Get workspace root directory"""
|
||||
|
||||
1055
agent/memory/conversation_store.py
Normal file
1055
agent/memory/conversation_store.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -1,148 +0,0 @@
|
||||
"""
|
||||
Embedding providers for memory
|
||||
|
||||
Supports OpenAI and local embedding models
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Optional
|
||||
|
||||
|
||||
class EmbeddingProvider(ABC):
|
||||
"""Base class for embedding providers"""
|
||||
|
||||
@abstractmethod
|
||||
def embed(self, text: str) -> List[float]:
|
||||
"""Generate embedding for text"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def embed_batch(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Generate embeddings for multiple texts"""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def dimensions(self) -> int:
|
||||
"""Get embedding dimensions"""
|
||||
pass
|
||||
|
||||
|
||||
class OpenAIEmbeddingProvider(EmbeddingProvider):
|
||||
"""OpenAI embedding provider using REST API"""
|
||||
|
||||
def __init__(self, model: str = "text-embedding-3-small", api_key: Optional[str] = None, api_base: Optional[str] = None):
|
||||
"""
|
||||
Initialize OpenAI embedding provider
|
||||
|
||||
Args:
|
||||
model: Model name (text-embedding-3-small or text-embedding-3-large)
|
||||
api_key: OpenAI API key
|
||||
api_base: Optional API base URL
|
||||
"""
|
||||
self.model = model
|
||||
self.api_key = api_key
|
||||
self.api_base = api_base or "https://api.openai.com/v1"
|
||||
|
||||
if not self.api_key:
|
||||
raise ValueError("OpenAI API key is required")
|
||||
|
||||
# Set dimensions based on model
|
||||
self._dimensions = 1536 if "small" in model else 3072
|
||||
|
||||
def _call_api(self, input_data):
|
||||
"""Call OpenAI embedding API using requests"""
|
||||
import requests
|
||||
|
||||
url = f"{self.api_base}/embeddings"
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {self.api_key}"
|
||||
}
|
||||
data = {
|
||||
"input": input_data,
|
||||
"model": self.model
|
||||
}
|
||||
|
||||
response = requests.post(url, headers=headers, json=data, timeout=30)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
def embed(self, text: str) -> List[float]:
|
||||
"""Generate embedding for text"""
|
||||
result = self._call_api(text)
|
||||
return result["data"][0]["embedding"]
|
||||
|
||||
def embed_batch(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Generate embeddings for multiple texts"""
|
||||
if not texts:
|
||||
return []
|
||||
|
||||
result = self._call_api(texts)
|
||||
return [item["embedding"] for item in result["data"]]
|
||||
|
||||
@property
|
||||
def dimensions(self) -> int:
|
||||
return self._dimensions
|
||||
|
||||
|
||||
# LocalEmbeddingProvider removed - only use OpenAI embedding or keyword search
|
||||
|
||||
|
||||
class EmbeddingCache:
|
||||
"""Cache for embeddings to avoid recomputation"""
|
||||
|
||||
def __init__(self):
|
||||
self.cache = {}
|
||||
|
||||
def get(self, text: str, provider: str, model: str) -> Optional[List[float]]:
|
||||
"""Get cached embedding"""
|
||||
key = self._compute_key(text, provider, model)
|
||||
return self.cache.get(key)
|
||||
|
||||
def put(self, text: str, provider: str, model: str, embedding: List[float]):
|
||||
"""Cache embedding"""
|
||||
key = self._compute_key(text, provider, model)
|
||||
self.cache[key] = embedding
|
||||
|
||||
@staticmethod
|
||||
def _compute_key(text: str, provider: str, model: str) -> str:
|
||||
"""Compute cache key"""
|
||||
content = f"{provider}:{model}:{text}"
|
||||
return hashlib.md5(content.encode('utf-8')).hexdigest()
|
||||
|
||||
def clear(self):
|
||||
"""Clear cache"""
|
||||
self.cache.clear()
|
||||
|
||||
|
||||
def create_embedding_provider(
|
||||
provider: str = "openai",
|
||||
model: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None
|
||||
) -> EmbeddingProvider:
|
||||
"""
|
||||
Factory function to create embedding provider
|
||||
|
||||
Only supports OpenAI embedding via REST API.
|
||||
If initialization fails, caller should fall back to keyword-only search.
|
||||
|
||||
Args:
|
||||
provider: Provider name (only "openai" is supported)
|
||||
model: Model name (default: text-embedding-3-small)
|
||||
api_key: OpenAI API key (required)
|
||||
api_base: API base URL (default: https://api.openai.com/v1)
|
||||
|
||||
Returns:
|
||||
EmbeddingProvider instance
|
||||
|
||||
Raises:
|
||||
ValueError: If provider is not "openai" or api_key is missing
|
||||
"""
|
||||
if provider != "openai":
|
||||
raise ValueError(f"Only 'openai' provider is supported, got: {provider}")
|
||||
|
||||
model = model or "text-embedding-3-small"
|
||||
return OpenAIEmbeddingProvider(model=model, api_key=api_key, api_base=api_base)
|
||||
41
agent/memory/embedding/__init__.py
Normal file
41
agent/memory/embedding/__init__.py
Normal file
@@ -0,0 +1,41 @@
|
||||
"""
|
||||
Embedding subsystem for memory.
|
||||
|
||||
Public API:
|
||||
create_embedding_provider, EmbeddingProvider, OpenAIEmbeddingProvider,
|
||||
EMBEDDING_VENDORS, EmbeddingCache
|
||||
RebuildResult, clear_index, rebuild_in_process
|
||||
detect_index_dim, cleanup_legacy_state_file
|
||||
"""
|
||||
|
||||
from agent.memory.embedding.provider import (
|
||||
EMBEDDING_VENDORS,
|
||||
DoubaoEmbeddingProvider,
|
||||
EmbeddingCache,
|
||||
EmbeddingProvider,
|
||||
OpenAIEmbeddingProvider,
|
||||
create_embedding_provider,
|
||||
)
|
||||
from agent.memory.embedding.rebuild import (
|
||||
RebuildResult,
|
||||
clear_index,
|
||||
rebuild_in_process,
|
||||
)
|
||||
from agent.memory.embedding.state import (
|
||||
cleanup_legacy_state_file,
|
||||
detect_index_dim,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"EMBEDDING_VENDORS",
|
||||
"DoubaoEmbeddingProvider",
|
||||
"EmbeddingCache",
|
||||
"EmbeddingProvider",
|
||||
"OpenAIEmbeddingProvider",
|
||||
"create_embedding_provider",
|
||||
"RebuildResult",
|
||||
"clear_index",
|
||||
"rebuild_in_process",
|
||||
"cleanup_legacy_state_file",
|
||||
"detect_index_dim",
|
||||
]
|
||||
486
agent/memory/embedding/provider.py
Normal file
486
agent/memory/embedding/provider.py
Normal file
@@ -0,0 +1,486 @@
|
||||
"""
|
||||
Embedding providers for memory
|
||||
|
||||
Supports multiple OpenAI-compatible embedding vendors:
|
||||
- openai (text-embedding-3-small / large)
|
||||
- linkai (OpenAI-compatible passthrough)
|
||||
- dashscope (Aliyun Tongyi text-embedding-v4)
|
||||
- doubao (ByteDance Doubao Seed1.5 / large-text on Volcengine Ark)
|
||||
- zhipu (ZhipuAI embedding-3)
|
||||
|
||||
Vendor keys here intentionally match the project's bot_type constants in
|
||||
common.const (OPENAI, LINKAI, QWEN_DASHSCOPE, DOUBAO, ZHIPU_AI).
|
||||
|
||||
All providers share a single OpenAI-compatible REST client. Vendor-specific
|
||||
behaviors (truncation, query instruction prefix) are configured via metadata.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import math
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Optional
|
||||
|
||||
# HTTP read timeout for a single embeddings request (seconds). A batch of
|
||||
# 64+ chunks can take 30-50s end-to-end from China-side networks, so 30s is
|
||||
# routinely too tight; 90s gives meaningful headroom without letting bad
|
||||
# endpoints hang forever.
|
||||
EMBEDDING_HTTP_TIMEOUT = 90
|
||||
|
||||
|
||||
class EmbeddingProvider(ABC):
|
||||
"""Base class for embedding providers"""
|
||||
|
||||
@abstractmethod
|
||||
def embed(self, text: str) -> List[float]:
|
||||
"""Generate embedding for a single text (treated as a query by default)"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def embed_batch(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Generate embeddings for multiple texts (treated as documents)"""
|
||||
pass
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
"""Generate embedding for a query string (may apply vendor instruction prefix)"""
|
||||
return self.embed(text)
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def dimensions(self) -> int:
|
||||
"""Effective embedding dimensions"""
|
||||
pass
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Vendor metadata table
|
||||
# ---------------------------------------------------------------------------
|
||||
#
|
||||
# Each entry describes how to reach a vendor's embedding endpoint. Most
|
||||
# vendors expose an OpenAI-compatible /embeddings API; the few that don't
|
||||
# (currently: doubao) set `provider_class` to pick a dedicated adapter.
|
||||
# Fields:
|
||||
# provider_class : optional adapter key ("doubao"); defaults to OpenAI-compat
|
||||
# default_base_url : default API base when not overridden by user
|
||||
# default_model : default embedding model name
|
||||
# default_dimensions : recommended unified dim when explicit path is enabled
|
||||
# supports_dim_param : whether the API accepts a `dimensions` request param
|
||||
# needs_client_truncate : whether to slice + L2-normalize on the client side
|
||||
# needs_client_normalize : whether to L2-normalize on the client (always safe)
|
||||
# query_instruction : optional prefix for asymmetric retrieval (Doubao Seed)
|
||||
# max_batch_size : max texts per /embeddings request; embed_batch
|
||||
# auto-paginates above this. Conservative defaults.
|
||||
#
|
||||
EMBEDDING_VENDORS = {
|
||||
"openai": {
|
||||
"default_base_url": "https://api.openai.com/v1",
|
||||
"default_model": "text-embedding-3-small",
|
||||
# Match the legacy default so users adding `embedding_provider: openai`
|
||||
# to an existing index don't need to rebuild. Override via
|
||||
# embedding_dimensions if you want 1024 / 1536 / 3072.
|
||||
"default_dimensions": 1536,
|
||||
"supports_dim_param": True,
|
||||
"needs_client_truncate": False,
|
||||
"needs_client_normalize": False,
|
||||
"query_instruction": "",
|
||||
# OpenAI permits up to 2048 items per request, but a single call
|
||||
# carrying hundreds of long chunks routinely exceeds the 30s read
|
||||
# timeout from China-side networks. 64 keeps each call well under
|
||||
# both the token-per-request budget and a reasonable wall clock.
|
||||
"max_batch_size": 64,
|
||||
},
|
||||
"linkai": {
|
||||
"default_base_url": "https://api.link-ai.tech/v1",
|
||||
"default_model": "text-embedding-3-small",
|
||||
"default_dimensions": 1536,
|
||||
"supports_dim_param": True,
|
||||
"needs_client_truncate": False,
|
||||
"needs_client_normalize": False,
|
||||
"query_instruction": "",
|
||||
"max_batch_size": 64,
|
||||
},
|
||||
"dashscope": {
|
||||
"default_base_url": "https://dashscope.aliyuncs.com/compatible-mode/v1",
|
||||
"default_model": "text-embedding-v4",
|
||||
"default_dimensions": 1024,
|
||||
"supports_dim_param": True,
|
||||
"needs_client_truncate": False,
|
||||
"needs_client_normalize": False,
|
||||
"query_instruction": "",
|
||||
"max_batch_size": 10, # DashScope hard cap (text-embedding-v4)
|
||||
},
|
||||
"doubao": {
|
||||
# Doubao no longer offers an OpenAI-compatible /v1/embeddings endpoint.
|
||||
# Current models are unified under /api/v3/embeddings/multimodal
|
||||
# which uses a structured `input` payload — see DoubaoEmbeddingProvider.
|
||||
"provider_class": "doubao",
|
||||
"default_base_url": "https://ark.cn-beijing.volces.com/api/v3",
|
||||
"default_model": "doubao-embedding-vision-251215",
|
||||
# Native options: 1024 or 2048. We default to 1024 to align with the
|
||||
# other Chinese vendors (dashscope/zhipu) and keep storage footprint
|
||||
# consistent across providers; users can still override via
|
||||
# `embedding_dimensions: 2048` in config.
|
||||
"default_dimensions": 1024,
|
||||
"supports_dim_param": True,
|
||||
"needs_client_truncate": False,
|
||||
"needs_client_normalize": False,
|
||||
"query_instruction": "",
|
||||
# Multimodal endpoint produces ONE embedding per call (input list is
|
||||
# a single document's parts, not a batch). embed_batch loops.
|
||||
"max_batch_size": 1,
|
||||
},
|
||||
"zhipu": {
|
||||
"default_base_url": "https://open.bigmodel.cn/api/paas/v4",
|
||||
"default_model": "embedding-3",
|
||||
"default_dimensions": 1024,
|
||||
"supports_dim_param": True,
|
||||
"needs_client_truncate": False,
|
||||
"needs_client_normalize": False,
|
||||
"query_instruction": "",
|
||||
"max_batch_size": 64,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _l2_normalize(vec: List[float]) -> List[float]:
|
||||
"""Normalize a vector to unit length (L2 norm). Returns input on zero vector."""
|
||||
norm = math.sqrt(sum(v * v for v in vec))
|
||||
if norm == 0:
|
||||
return vec
|
||||
return [v / norm for v in vec]
|
||||
|
||||
|
||||
class OpenAIEmbeddingProvider(EmbeddingProvider):
|
||||
"""
|
||||
OpenAI-compatible embedding provider.
|
||||
|
||||
Used for openai/linkai/dashscope/ark/zhipu by configuring the metadata
|
||||
fields. The legacy two-arg constructor (model, api_key, api_base) keeps
|
||||
working, so the original OpenAI/LinkAI fallback code path is unchanged.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str = "text-embedding-3-small",
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
extra_headers: Optional[dict] = None,
|
||||
dimensions: Optional[int] = None,
|
||||
supports_dim_param: bool = True,
|
||||
needs_client_truncate: bool = False,
|
||||
needs_client_normalize: bool = False,
|
||||
query_instruction: str = "",
|
||||
max_batch_size: int = 256,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
model: Model name (e.g. text-embedding-3-small, text-embedding-v4, embedding-3)
|
||||
api_key: API key (required)
|
||||
api_base: API base URL (defaults to OpenAI)
|
||||
extra_headers: Optional extra HTTP headers
|
||||
dimensions: Target output dimension. Required when supports_dim_param
|
||||
is False and needs_client_truncate is True (used to slice).
|
||||
supports_dim_param: Whether the vendor accepts a `dimensions` body param
|
||||
needs_client_truncate: Slice the returned vector to `dimensions`
|
||||
needs_client_normalize: L2-normalize on the client after slicing
|
||||
query_instruction: Optional prefix prepended to query texts only
|
||||
max_batch_size: Max items per /embeddings request; embed_batch
|
||||
auto-paginates above this.
|
||||
"""
|
||||
self.model = model
|
||||
self.api_key = api_key
|
||||
self.api_base = api_base or "https://api.openai.com/v1"
|
||||
self.extra_headers = extra_headers or {}
|
||||
self.supports_dim_param = supports_dim_param
|
||||
self.needs_client_truncate = needs_client_truncate
|
||||
self.needs_client_normalize = needs_client_normalize
|
||||
self.query_instruction = query_instruction or ""
|
||||
self.max_batch_size = max(1, int(max_batch_size or 1))
|
||||
|
||||
if not self.api_key or self.api_key in ["", "YOUR API KEY", "YOUR_API_KEY"]:
|
||||
raise ValueError("Embedding API key is not configured")
|
||||
|
||||
if dimensions is not None and dimensions > 0:
|
||||
self._dimensions = dimensions
|
||||
else:
|
||||
# Legacy heuristic for OpenAI text-embedding-3-* family
|
||||
self._dimensions = 1536 if "small" in model else 3072
|
||||
|
||||
def _call_api(self, input_data):
|
||||
"""Call OpenAI-compatible /embeddings endpoint"""
|
||||
import requests
|
||||
|
||||
url = f"{self.api_base}/embeddings"
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
**self.extra_headers,
|
||||
}
|
||||
data = {
|
||||
"input": input_data,
|
||||
"model": self.model,
|
||||
}
|
||||
if self.supports_dim_param and self._dimensions:
|
||||
data["dimensions"] = self._dimensions
|
||||
|
||||
try:
|
||||
response = requests.post(url, headers=headers, json=data, timeout=EMBEDDING_HTTP_TIMEOUT)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
except requests.exceptions.ConnectionError as e:
|
||||
raise ConnectionError(
|
||||
f"Failed to connect to embedding API at {url}. "
|
||||
f"Please check network and api_base. Error: {str(e)}"
|
||||
)
|
||||
except requests.exceptions.Timeout as e:
|
||||
raise TimeoutError(f"Embedding API request timed out. Error: {str(e)}")
|
||||
except requests.exceptions.HTTPError as e:
|
||||
if e.response.status_code == 401:
|
||||
raise ValueError("Invalid embedding API key")
|
||||
elif e.response.status_code == 429:
|
||||
raise ValueError("Embedding API rate limit exceeded")
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Embedding API request failed: "
|
||||
f"{e.response.status_code} - {e.response.text}"
|
||||
)
|
||||
|
||||
def _post_process(self, raw: List[float]) -> List[float]:
|
||||
"""Apply optional client-side truncation + normalization"""
|
||||
vec = raw
|
||||
if self.needs_client_truncate and self._dimensions and len(vec) > self._dimensions:
|
||||
vec = vec[: self._dimensions]
|
||||
if self.needs_client_normalize:
|
||||
vec = _l2_normalize(vec)
|
||||
return vec
|
||||
|
||||
def embed(self, text: str) -> List[float]:
|
||||
"""Generate embedding (treated as document by default)"""
|
||||
result = self._call_api(text)
|
||||
return self._post_process(result["data"][0]["embedding"])
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
"""Generate embedding for a query (applies vendor instruction prefix if any)"""
|
||||
if self.query_instruction:
|
||||
text = f"{self.query_instruction}{text}"
|
||||
return self.embed(text)
|
||||
|
||||
def embed_batch(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Generate embeddings for multiple documents.
|
||||
|
||||
Automatically paginates by self.max_batch_size so callers can pass any
|
||||
number of texts. Order of returned vectors matches the input order.
|
||||
"""
|
||||
if not texts:
|
||||
return []
|
||||
out: List[List[float]] = []
|
||||
step = self.max_batch_size
|
||||
for i in range(0, len(texts), step):
|
||||
chunk = texts[i:i + step]
|
||||
result = self._call_api(chunk)
|
||||
out.extend(self._post_process(item["embedding"]) for item in result["data"])
|
||||
return out
|
||||
|
||||
@property
|
||||
def dimensions(self) -> int:
|
||||
return self._dimensions
|
||||
|
||||
|
||||
class DoubaoEmbeddingProvider(EmbeddingProvider):
|
||||
"""
|
||||
Doubao (Volcengine Ark) multimodal embedding provider.
|
||||
|
||||
Doubao deprecated their OpenAI-compatible /v1/embeddings endpoint and
|
||||
unified everything under /api/v3/embeddings/multimodal, which uses a
|
||||
structured `input: [{type, text|image_url|video_url}, ...]` payload.
|
||||
|
||||
Notes:
|
||||
* The endpoint produces ONE embedding per call (input list is multiple
|
||||
modality parts of a single document, not a batch). embed_batch
|
||||
therefore loops per-text — no native batch support.
|
||||
* Native dimensions: 1024 or 2048 (default 1024 to align with other
|
||||
Chinese vendors). No client-side truncation needed.
|
||||
* Auth: Bearer ARK API key.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
extra_headers: Optional[dict] = None,
|
||||
dimensions: Optional[int] = None,
|
||||
):
|
||||
self.model = model
|
||||
self.api_key = api_key
|
||||
self.api_base = api_base or "https://ark.cn-beijing.volces.com/api/v3"
|
||||
self.extra_headers = extra_headers or {}
|
||||
if not self.api_key or self.api_key in ["", "YOUR API KEY", "YOUR_API_KEY"]:
|
||||
raise ValueError("Doubao embedding API key (ark_api_key) is not configured")
|
||||
|
||||
if dimensions in (1024, 2048):
|
||||
self._dimensions = dimensions
|
||||
elif dimensions is None:
|
||||
self._dimensions = 1024
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Doubao embedding dimensions must be 1024 or 2048, got {dimensions}"
|
||||
)
|
||||
|
||||
def _call_api(self, text: str) -> List[float]:
|
||||
"""One call → one embedding. multimodal endpoint takes a single
|
||||
document represented as a list of typed parts; we send a single
|
||||
text part."""
|
||||
import requests
|
||||
|
||||
url = f"{self.api_base}/embeddings/multimodal"
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
**self.extra_headers,
|
||||
}
|
||||
payload = {
|
||||
"model": self.model,
|
||||
"input": [{"type": "text", "text": text}],
|
||||
"dimensions": self._dimensions,
|
||||
"encoding_format": "float",
|
||||
}
|
||||
|
||||
try:
|
||||
response = requests.post(url, headers=headers, json=payload, timeout=EMBEDDING_HTTP_TIMEOUT)
|
||||
response.raise_for_status()
|
||||
body = response.json()
|
||||
except requests.exceptions.ConnectionError as e:
|
||||
raise ConnectionError(
|
||||
f"Failed to connect to Doubao embedding API at {url}. "
|
||||
f"Please check network and api_base. Error: {str(e)}"
|
||||
)
|
||||
except requests.exceptions.Timeout as e:
|
||||
raise TimeoutError(f"Doubao embedding API request timed out. Error: {str(e)}")
|
||||
except requests.exceptions.HTTPError as e:
|
||||
if e.response.status_code == 401:
|
||||
raise ValueError("Invalid Doubao (ark) embedding API key")
|
||||
elif e.response.status_code == 429:
|
||||
raise ValueError("Doubao embedding API rate limit exceeded")
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Doubao embedding API request failed: "
|
||||
f"{e.response.status_code} - {e.response.text}"
|
||||
)
|
||||
|
||||
# Response shape per docs: {"data": {"embedding": [...]}}
|
||||
data = body.get("data")
|
||||
if isinstance(data, dict) and "embedding" in data:
|
||||
return data["embedding"]
|
||||
# Some providers wrap as a list of one — be defensive
|
||||
if isinstance(data, list) and data and "embedding" in data[0]:
|
||||
return data[0]["embedding"]
|
||||
raise ValueError(f"Unexpected Doubao embedding response shape: {body}")
|
||||
|
||||
def embed(self, text: str) -> List[float]:
|
||||
return self._call_api(text)
|
||||
|
||||
def embed_batch(self, texts: List[str]) -> List[List[float]]:
|
||||
# Endpoint produces one embedding per call; loop. Order preserved.
|
||||
return [self._call_api(t) for t in texts]
|
||||
|
||||
@property
|
||||
def dimensions(self) -> int:
|
||||
return self._dimensions
|
||||
|
||||
|
||||
class EmbeddingCache:
|
||||
"""In-memory cache for embeddings to avoid recomputation"""
|
||||
|
||||
def __init__(self):
|
||||
self.cache = {}
|
||||
|
||||
def get(self, text: str, provider: str, model: str) -> Optional[List[float]]:
|
||||
key = self._compute_key(text, provider, model)
|
||||
return self.cache.get(key)
|
||||
|
||||
def put(self, text: str, provider: str, model: str, embedding: List[float]):
|
||||
key = self._compute_key(text, provider, model)
|
||||
self.cache[key] = embedding
|
||||
|
||||
@staticmethod
|
||||
def _compute_key(text: str, provider: str, model: str) -> str:
|
||||
content = f"{provider}:{model}:{text}"
|
||||
return hashlib.md5(content.encode("utf-8")).hexdigest()
|
||||
|
||||
def clear(self):
|
||||
self.cache.clear()
|
||||
|
||||
|
||||
def create_embedding_provider(
|
||||
provider: str = "openai",
|
||||
model: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
extra_headers: Optional[dict] = None,
|
||||
dimensions: Optional[int] = None,
|
||||
) -> EmbeddingProvider:
|
||||
"""
|
||||
Factory function to create an embedding provider.
|
||||
|
||||
Backward compatible: when called with provider in {"openai", "linkai"}
|
||||
and no `dimensions` arg, behaves exactly as before (1536-dim OpenAI).
|
||||
|
||||
New providers ("dashscope", "doubao", "zhipu") require explicit configuration
|
||||
and use the unified 1024-dim defaults from EMBEDDING_VENDORS.
|
||||
|
||||
Args:
|
||||
provider: Vendor key (one of EMBEDDING_VENDORS)
|
||||
model: Model name (uses vendor default if None)
|
||||
api_key: API key (required)
|
||||
api_base: API base URL (uses vendor default if None)
|
||||
extra_headers: Optional extra HTTP headers
|
||||
dimensions: Target output dimension (uses vendor default if None)
|
||||
|
||||
Returns:
|
||||
EmbeddingProvider instance
|
||||
"""
|
||||
meta = EMBEDDING_VENDORS.get(provider)
|
||||
if meta is None:
|
||||
raise ValueError(
|
||||
f"Unsupported embedding provider: {provider}. "
|
||||
f"Supported: {sorted(EMBEDDING_VENDORS.keys())}"
|
||||
)
|
||||
|
||||
# Doubao uses a non-OpenAI-compatible multimodal endpoint.
|
||||
if meta.get("provider_class") == "doubao":
|
||||
final_dim = dimensions if (dimensions and dimensions > 0) else meta["default_dimensions"]
|
||||
return DoubaoEmbeddingProvider(
|
||||
model=model or meta["default_model"],
|
||||
api_key=api_key,
|
||||
api_base=api_base or meta["default_base_url"],
|
||||
extra_headers=extra_headers,
|
||||
dimensions=final_dim,
|
||||
)
|
||||
|
||||
# Legacy two-arg call for openai/linkai keeps 1536-dim default behavior
|
||||
# so existing data isn't invalidated.
|
||||
is_legacy_call = (
|
||||
provider in ("openai", "linkai")
|
||||
and dimensions is None
|
||||
)
|
||||
if is_legacy_call:
|
||||
return OpenAIEmbeddingProvider(
|
||||
model=model or "text-embedding-3-small",
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
extra_headers=extra_headers,
|
||||
)
|
||||
|
||||
final_dim = dimensions if (dimensions and dimensions > 0) else meta["default_dimensions"]
|
||||
return OpenAIEmbeddingProvider(
|
||||
model=model or meta["default_model"],
|
||||
api_key=api_key,
|
||||
api_base=api_base or meta["default_base_url"],
|
||||
extra_headers=extra_headers,
|
||||
dimensions=final_dim,
|
||||
supports_dim_param=meta["supports_dim_param"],
|
||||
needs_client_truncate=meta["needs_client_truncate"],
|
||||
needs_client_normalize=meta["needs_client_normalize"],
|
||||
query_instruction=meta["query_instruction"],
|
||||
max_batch_size=meta.get("max_batch_size", 256),
|
||||
)
|
||||
191
agent/memory/embedding/rebuild.py
Normal file
191
agent/memory/embedding/rebuild.py
Normal file
@@ -0,0 +1,191 @@
|
||||
"""
|
||||
Rebuild memory vector index.
|
||||
|
||||
Recommended entry point (in-chat, while agent is running):
|
||||
/memory rebuild-index
|
||||
|
||||
Backward-compatible CLI entry (must run from project root):
|
||||
python -m agent.memory.rebuild_index
|
||||
|
||||
What it does:
|
||||
1. Probes the embedding endpoint with a tiny call to fail fast on
|
||||
bad provider/model/key — before touching the index.
|
||||
2. Clears the SQLite chunks/files tables (workspace markdown stays intact).
|
||||
3. Runs a fresh sync, regenerating embeddings with the currently configured
|
||||
provider/model/dimensions.
|
||||
|
||||
This is the only safe way to switch embedding_provider after the existing
|
||||
index has been populated by a different-dim model.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import asyncio
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
from common.log import logger
|
||||
from common.utils import expand_path
|
||||
|
||||
|
||||
@dataclass
|
||||
class RebuildResult:
|
||||
"""Outcome of a rebuild_in_process() call"""
|
||||
ok: bool
|
||||
removed: int = 0
|
||||
chunks: int = 0
|
||||
files: int = 0
|
||||
error: Optional[str] = None
|
||||
|
||||
|
||||
def clear_index(db_path, storage=None) -> int:
|
||||
"""Wipe chunks/files, reset FTS5, and clean up any legacy state file.
|
||||
|
||||
Args:
|
||||
db_path: Path of the index DB (also used to locate the legacy state
|
||||
file for migration cleanup, and — when *storage* is None — to
|
||||
open a fresh connection).
|
||||
storage: Optional pre-opened MemoryStorage. When provided we reuse it
|
||||
so the live connection's triggers stay in sync — opening a second
|
||||
connection would leave the original one's triggers pointing at a
|
||||
DROP'd chunks_fts table.
|
||||
|
||||
We reset (DROP+recreate) chunks_fts because its shadow tables can become
|
||||
inconsistent across rebuild cycles, causing bm25() / ORDER BY rank to
|
||||
raise "database disk image is malformed" even when raw MATCH still works.
|
||||
|
||||
Returns number of chunks removed.
|
||||
"""
|
||||
from agent.memory.embedding.state import cleanup_legacy_state_file
|
||||
from agent.memory.storage import MemoryStorage
|
||||
|
||||
owns_storage = storage is None
|
||||
if owns_storage:
|
||||
storage = MemoryStorage(db_path)
|
||||
try:
|
||||
before = storage.conn.execute("SELECT COUNT(*) FROM chunks").fetchone()[0]
|
||||
storage.conn.execute("DELETE FROM chunks")
|
||||
storage.conn.execute("DELETE FROM files")
|
||||
storage.conn.commit()
|
||||
storage.reset_fts5()
|
||||
finally:
|
||||
if owns_storage:
|
||||
storage.close()
|
||||
|
||||
cleanup_legacy_state_file(db_path)
|
||||
return int(before)
|
||||
|
||||
|
||||
def rebuild_in_process(memory_manager) -> RebuildResult:
|
||||
"""
|
||||
Rebuild the index using an existing, fully-initialized MemoryManager.
|
||||
|
||||
Used by the in-chat /memory rebuild-index command. The caller already has
|
||||
config loaded, embedding_provider built, and (optionally) the agent
|
||||
running, so we only need to:
|
||||
1. Clear chunks/files + state on the manager's storage.
|
||||
2. Re-sync (force=True).
|
||||
|
||||
NOTE: caller must ensure memory_manager.embedding_provider is set, otherwise
|
||||
sync() will silently skip embedding generation.
|
||||
"""
|
||||
if memory_manager is None:
|
||||
return RebuildResult(ok=False, error="memory_manager is None")
|
||||
if memory_manager.embedding_provider is None:
|
||||
return RebuildResult(ok=False, error="embedding_provider is not initialized")
|
||||
|
||||
# Probe the embedding endpoint BEFORE clearing the index. A bad
|
||||
# provider/model/key would otherwise leave the user with an empty index
|
||||
# that not even keyword search can serve.
|
||||
try:
|
||||
memory_manager.embedding_provider.embed_query("ping")
|
||||
except Exception as e:
|
||||
logger.error(f"[RebuildIndex] embedding probe failed, aborting rebuild: {e}")
|
||||
return RebuildResult(ok=False, error=f"embedding endpoint not reachable: {e}")
|
||||
|
||||
db_path = memory_manager.config.get_db_path()
|
||||
try:
|
||||
removed = clear_index(db_path, storage=memory_manager.storage)
|
||||
except Exception as e:
|
||||
logger.exception("[RebuildIndex] clear_index failed")
|
||||
return RebuildResult(ok=False, error=f"clear failed: {e}")
|
||||
|
||||
try:
|
||||
asyncio.run(memory_manager.sync(force=True))
|
||||
except RuntimeError:
|
||||
# Already inside a running event loop (rare in chat handler thread).
|
||||
loop = asyncio.new_event_loop()
|
||||
try:
|
||||
loop.run_until_complete(memory_manager.sync(force=True))
|
||||
finally:
|
||||
loop.close()
|
||||
except Exception as e:
|
||||
logger.exception("[RebuildIndex] sync failed")
|
||||
return RebuildResult(ok=False, removed=removed, error=f"re-embed failed: {e}")
|
||||
|
||||
stats = memory_manager.storage.get_stats()
|
||||
chunks = int(stats.get("chunks", 0))
|
||||
embedded = int(stats.get("embedded", 0))
|
||||
|
||||
# sync() degrades to "no embeddings" on batch failure so keyword search
|
||||
# still works at startup — but in a /rebuild-index request the user
|
||||
# explicitly asked for vectors. Surface that as a failure.
|
||||
if chunks > 0 and embedded == 0:
|
||||
return RebuildResult(
|
||||
ok=False,
|
||||
removed=removed,
|
||||
chunks=chunks,
|
||||
files=int(stats.get("files", 0)),
|
||||
error=(
|
||||
"embedding API failed during sync; index now has chunks but no "
|
||||
"vectors. Check embedding provider/model/key and retry."
|
||||
),
|
||||
)
|
||||
|
||||
return RebuildResult(
|
||||
ok=True,
|
||||
removed=removed,
|
||||
chunks=chunks,
|
||||
files=int(stats.get("files", 0)),
|
||||
)
|
||||
|
||||
|
||||
def main() -> int:
|
||||
"""Standalone CLI entry. Must be run from project root (relative config path)."""
|
||||
from config import conf, load_config
|
||||
from agent.memory import MemoryConfig, MemoryManager
|
||||
|
||||
load_config()
|
||||
|
||||
workspace_root = expand_path(conf().get("agent_workspace", "~/cow"))
|
||||
memory_config = MemoryConfig(workspace_root=workspace_root)
|
||||
|
||||
logger.info(f"[RebuildIndex] Workspace: {workspace_root}")
|
||||
logger.info(f"[RebuildIndex] Index db: {memory_config.get_db_path()}")
|
||||
|
||||
from bridge.agent_initializer import AgentInitializer
|
||||
|
||||
initializer = AgentInitializer(bridge=None, agent_bridge=None)
|
||||
embedding_provider = initializer._init_embedding_provider(memory_config, session_id=None)
|
||||
if embedding_provider is None:
|
||||
logger.error(
|
||||
"[RebuildIndex] No embedding provider could be initialized. "
|
||||
"Check your config.json. Aborting rebuild."
|
||||
)
|
||||
return 1
|
||||
|
||||
manager = MemoryManager(memory_config, embedding_provider=embedding_provider)
|
||||
result = rebuild_in_process(manager)
|
||||
if not result.ok:
|
||||
logger.error(f"[RebuildIndex] {result.error}")
|
||||
return 1
|
||||
|
||||
logger.info(
|
||||
f"[RebuildIndex] Done. removed={result.removed}, "
|
||||
f"chunks={result.chunks}, files={result.files}"
|
||||
)
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
51
agent/memory/embedding/state.py
Normal file
51
agent/memory/embedding/state.py
Normal file
@@ -0,0 +1,51 @@
|
||||
"""
|
||||
Embedding-related index utilities.
|
||||
|
||||
We don't keep a sidecar state file — the SQLite index is the source of truth
|
||||
and config.json is the source of intent. The two functions below are the
|
||||
only things needing on-disk awareness:
|
||||
|
||||
detect_index_dim : read the dim of stored vectors (display-only)
|
||||
cleanup_legacy_state_file: remove old embedding_state.json from earlier
|
||||
versions; safe no-op when absent.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
|
||||
PathLike = Union[str, os.PathLike]
|
||||
|
||||
|
||||
def detect_index_dim(storage) -> Optional[int]:
|
||||
"""Return the dim of the first stored embedding, or None if the index
|
||||
has no embeddings. Used by /memory status."""
|
||||
try:
|
||||
row = storage.conn.execute(
|
||||
"SELECT embedding FROM chunks WHERE embedding IS NOT NULL LIMIT 1"
|
||||
).fetchone()
|
||||
except Exception:
|
||||
return None
|
||||
if not row or not row["embedding"]:
|
||||
return None
|
||||
try:
|
||||
raw = row["embedding"]
|
||||
if isinstance(raw, (bytes, bytearray)):
|
||||
# New BLOB format: 4 bytes per float32
|
||||
return len(raw) // 4
|
||||
emb = json.loads(raw)
|
||||
return len(emb) if isinstance(emb, list) else None
|
||||
except (json.JSONDecodeError, TypeError, Exception):
|
||||
return None
|
||||
|
||||
|
||||
def cleanup_legacy_state_file(db_path: PathLike) -> None:
|
||||
"""Remove old embedding_state.json files from earlier versions.
|
||||
Safe to call repeatedly; no-op if the file is absent."""
|
||||
legacy = Path(db_path).parent / "embedding_state.json"
|
||||
try:
|
||||
legacy.unlink(missing_ok=True)
|
||||
except Exception:
|
||||
pass
|
||||
@@ -13,7 +13,7 @@ from datetime import datetime, timedelta
|
||||
from agent.memory.config import MemoryConfig, get_default_memory_config
|
||||
from agent.memory.storage import MemoryStorage, MemoryChunk, SearchResult
|
||||
from agent.memory.chunker import TextChunker
|
||||
from agent.memory.embedding import create_embedding_provider, EmbeddingProvider
|
||||
from agent.memory.embedding import EmbeddingProvider, EmbeddingCache
|
||||
from agent.memory.summarizer import MemoryFlushManager, create_memory_files_if_needed
|
||||
|
||||
|
||||
@@ -50,30 +50,22 @@ class MemoryManager:
|
||||
overlap_tokens=self.config.chunk_overlap_tokens
|
||||
)
|
||||
|
||||
# Initialize embedding provider (optional)
|
||||
self.embedding_provider = None
|
||||
if embedding_provider:
|
||||
self.embedding_provider = embedding_provider
|
||||
else:
|
||||
# Try to create embedding provider, but allow failure
|
||||
try:
|
||||
# Get API key from environment or config
|
||||
api_key = os.environ.get('OPENAI_API_KEY')
|
||||
api_base = os.environ.get('OPENAI_API_BASE')
|
||||
|
||||
self.embedding_provider = create_embedding_provider(
|
||||
provider=self.config.embedding_provider,
|
||||
model=self.config.embedding_model,
|
||||
api_key=api_key,
|
||||
api_base=api_base
|
||||
)
|
||||
except Exception as e:
|
||||
# Embedding provider failed, but that's OK
|
||||
# We can still use keyword search and file operations
|
||||
from common.log import logger
|
||||
logger.warning(f"[MemoryManager] Embedding provider initialization failed: {e}")
|
||||
logger.info(f"[MemoryManager] Memory will work with keyword search only (no vector search)")
|
||||
|
||||
# Embedding provider is owned by the caller (agent_initializer is the
|
||||
# canonical entry point and handles legacy/explicit + state validation).
|
||||
# When None is passed, memory degrades to keyword-only search instead
|
||||
# of silently re-initializing a vendor here, which would bypass the
|
||||
# caller's state checks and risk corrupting the index.
|
||||
self.embedding_provider = embedding_provider
|
||||
if self.embedding_provider is None:
|
||||
from common.log import logger
|
||||
logger.info(
|
||||
"[MemoryManager] No embedding provider; memory will use keyword search only"
|
||||
)
|
||||
|
||||
# Cache for query embeddings (avoids redundant API calls within a session)
|
||||
self._embedding_cache = EmbeddingCache()
|
||||
|
||||
|
||||
# Initialize memory flush manager
|
||||
workspace_dir = self.config.get_workspace()
|
||||
self.flush_manager = MemoryFlushManager(
|
||||
@@ -133,12 +125,21 @@ class MemoryManager:
|
||||
if self.config.sync_on_search and self._dirty:
|
||||
await self.sync()
|
||||
|
||||
# Perform vector search (if embedding provider available)
|
||||
from common.log import logger
|
||||
|
||||
# Perform vector search (if embedding provider available).
|
||||
# Failures degrade silently to keyword-only — no exception is raised.
|
||||
vector_results = []
|
||||
if self.embedding_provider:
|
||||
try:
|
||||
from common.log import logger
|
||||
query_embedding = self.embedding_provider.embed(query)
|
||||
provider_name = type(self.embedding_provider).__name__
|
||||
model_name = getattr(self.embedding_provider, 'model', '')
|
||||
cached = self._embedding_cache.get(query, provider_name, model_name)
|
||||
if cached is not None:
|
||||
query_embedding = cached
|
||||
else:
|
||||
query_embedding = self.embedding_provider.embed_query(query)
|
||||
self._embedding_cache.put(query, provider_name, model_name, query_embedding)
|
||||
vector_results = self.storage.search_vector(
|
||||
query_embedding=query_embedding,
|
||||
user_id=user_id,
|
||||
@@ -147,19 +148,19 @@ class MemoryManager:
|
||||
)
|
||||
logger.info(f"[MemoryManager] Vector search found {len(vector_results)} results for query: {query}")
|
||||
except Exception as e:
|
||||
from common.log import logger
|
||||
logger.warning(f"[MemoryManager] Vector search failed: {e}")
|
||||
|
||||
# Perform keyword search
|
||||
logger.error(
|
||||
f"[MemoryManager] Vector search failed, falling back to keyword-only: {e}"
|
||||
)
|
||||
|
||||
# Perform keyword search (also runs as fallback when vector failed)
|
||||
keyword_results = self.storage.search_keyword(
|
||||
query=query,
|
||||
user_id=user_id,
|
||||
scopes=scopes,
|
||||
limit=max_results * 2
|
||||
)
|
||||
from common.log import logger
|
||||
logger.info(f"[MemoryManager] Keyword search found {len(keyword_results)} results for query: {query}")
|
||||
|
||||
|
||||
# Merge results
|
||||
merged = self._merge_results(
|
||||
vector_results,
|
||||
@@ -167,7 +168,7 @@ class MemoryManager:
|
||||
self.config.vector_weight,
|
||||
self.config.keyword_weight
|
||||
)
|
||||
|
||||
|
||||
# Filter by min score and limit
|
||||
filtered = [r for r in merged if r.score >= min_score]
|
||||
return filtered[:max_results]
|
||||
@@ -249,295 +250,195 @@ class MemoryManager:
|
||||
|
||||
async def sync(self, force: bool = False):
|
||||
"""
|
||||
Synchronize memory from files
|
||||
|
||||
Synchronize memory from files.
|
||||
|
||||
Two-pass design to amortize embedding HTTP cost:
|
||||
1. Walk all files, chunk those whose hash changed, collect pending
|
||||
chunks across files. No embedding calls yet.
|
||||
2. Run a single embed_batch over the union of pending chunks (the
|
||||
provider auto-paginates by vendor cap), then persist per-file.
|
||||
|
||||
For workspaces with many small files (101 files / ~1 chunk each), this
|
||||
cuts ~100 HTTP calls down to ~ceil(total_chunks / vendor_cap).
|
||||
|
||||
Args:
|
||||
force: Force full reindex
|
||||
"""
|
||||
memory_dir = self.config.get_memory_dir()
|
||||
workspace_dir = self.config.get_workspace()
|
||||
|
||||
# Scan MEMORY.md (workspace root)
|
||||
|
||||
files_to_scan: List[tuple] = [] # (file_path, source, scope, user_id)
|
||||
|
||||
memory_file = Path(workspace_dir) / "MEMORY.md"
|
||||
if memory_file.exists():
|
||||
await self._sync_file(memory_file, "memory", "shared", None)
|
||||
|
||||
# Scan memory directory (including daily summaries)
|
||||
files_to_scan.append((memory_file, "memory", "shared", None))
|
||||
|
||||
if memory_dir.exists():
|
||||
for file_path in memory_dir.rglob("*.md"):
|
||||
# Determine scope and user_id from path
|
||||
rel_path = file_path.relative_to(workspace_dir)
|
||||
parts = rel_path.parts
|
||||
|
||||
# Check if it's in daily summary directory
|
||||
if "daily" in parts:
|
||||
# Daily summary files
|
||||
if "users" in parts or len(parts) > 3:
|
||||
# User-scoped daily summary: memory/daily/{user_id}/2024-01-29.md
|
||||
user_idx = parts.index("daily") + 1
|
||||
user_id = parts[user_idx] if user_idx < len(parts) else None
|
||||
rel_parts = file_path.relative_to(workspace_dir).parts
|
||||
if any(part.startswith('.') for part in rel_parts):
|
||||
continue
|
||||
# Dream diaries are narrative reflections produced by Deep
|
||||
# Dream; their factual content has already been distilled
|
||||
# into MEMORY.md. Indexing them adds noisy near-duplicates
|
||||
# that crowd out the authoritative entry in retrieval.
|
||||
if "dreams" in rel_parts:
|
||||
continue
|
||||
if "daily" in rel_parts:
|
||||
if "users" in rel_parts or len(rel_parts) > 3:
|
||||
user_idx = rel_parts.index("daily") + 1
|
||||
user_id = rel_parts[user_idx] if user_idx < len(rel_parts) else None
|
||||
scope = "user"
|
||||
else:
|
||||
# Shared daily summary: memory/daily/2024-01-29.md
|
||||
user_id = None
|
||||
scope = "shared"
|
||||
elif "users" in parts:
|
||||
# User-scoped memory
|
||||
user_idx = parts.index("users") + 1
|
||||
user_id = parts[user_idx] if user_idx < len(parts) else None
|
||||
elif "users" in rel_parts:
|
||||
user_idx = rel_parts.index("users") + 1
|
||||
user_id = rel_parts[user_idx] if user_idx < len(rel_parts) else None
|
||||
scope = "user"
|
||||
else:
|
||||
# Shared memory
|
||||
user_id = None
|
||||
scope = "shared"
|
||||
|
||||
await self._sync_file(file_path, "memory", scope, user_id)
|
||||
|
||||
self._dirty = False
|
||||
|
||||
async def _sync_file(
|
||||
self,
|
||||
file_path: Path,
|
||||
source: str,
|
||||
scope: str,
|
||||
user_id: Optional[str]
|
||||
):
|
||||
"""Sync a single file"""
|
||||
# Compute file hash
|
||||
content = file_path.read_text()
|
||||
file_hash = MemoryStorage.compute_hash(content)
|
||||
|
||||
# Get relative path
|
||||
workspace_dir = self.config.get_workspace()
|
||||
rel_path = str(file_path.relative_to(workspace_dir))
|
||||
|
||||
# Check if file changed
|
||||
stored_hash = self.storage.get_file_hash(rel_path)
|
||||
if stored_hash == file_hash:
|
||||
return # No changes
|
||||
|
||||
# Delete old chunks
|
||||
self.storage.delete_by_path(rel_path)
|
||||
|
||||
# Chunk and embed
|
||||
chunks = self.chunker.chunk_text(content)
|
||||
if not chunks:
|
||||
return
|
||||
|
||||
texts = [chunk.text for chunk in chunks]
|
||||
if self.embedding_provider:
|
||||
embeddings = self.embedding_provider.embed_batch(texts)
|
||||
else:
|
||||
embeddings = [None] * len(texts)
|
||||
|
||||
# Create memory chunks
|
||||
memory_chunks = []
|
||||
for chunk, embedding in zip(chunks, embeddings):
|
||||
chunk_id = self._generate_chunk_id(rel_path, chunk.start_line, chunk.end_line)
|
||||
chunk_hash = MemoryStorage.compute_hash(chunk.text)
|
||||
|
||||
memory_chunks.append(MemoryChunk(
|
||||
id=chunk_id,
|
||||
user_id=user_id,
|
||||
scope=scope,
|
||||
source=source,
|
||||
path=rel_path,
|
||||
start_line=chunk.start_line,
|
||||
end_line=chunk.end_line,
|
||||
text=chunk.text,
|
||||
embedding=embedding,
|
||||
hash=chunk_hash,
|
||||
metadata=None
|
||||
))
|
||||
|
||||
# Save
|
||||
self.storage.save_chunks_batch(memory_chunks)
|
||||
|
||||
# Update file metadata
|
||||
stat = file_path.stat()
|
||||
self.storage.update_file_metadata(
|
||||
path=rel_path,
|
||||
source=source,
|
||||
file_hash=file_hash,
|
||||
mtime=int(stat.st_mtime),
|
||||
size=stat.st_size
|
||||
)
|
||||
|
||||
def should_flush_memory(
|
||||
self,
|
||||
current_tokens: int = 0
|
||||
) -> bool:
|
||||
"""
|
||||
Check if memory flush should be triggered
|
||||
|
||||
独立的 flush 触发机制,不依赖模型 context window。
|
||||
使用配置中的阈值: flush_token_threshold 和 flush_turn_threshold
|
||||
|
||||
Args:
|
||||
current_tokens: Current session token count
|
||||
|
||||
Returns:
|
||||
True if memory flush should run
|
||||
"""
|
||||
return self.flush_manager.should_flush(
|
||||
current_tokens=current_tokens,
|
||||
token_threshold=self.config.flush_token_threshold,
|
||||
turn_threshold=self.config.flush_turn_threshold
|
||||
)
|
||||
|
||||
def increment_turn(self):
|
||||
"""增加对话轮数计数(每次用户消息+AI回复算一轮)"""
|
||||
self.flush_manager.increment_turn()
|
||||
|
||||
async def execute_memory_flush(
|
||||
self,
|
||||
agent_executor,
|
||||
current_tokens: int,
|
||||
user_id: Optional[str] = None,
|
||||
**executor_kwargs
|
||||
) -> bool:
|
||||
"""
|
||||
Execute memory flush before compaction
|
||||
|
||||
This runs a silent agent turn to write durable memories to disk.
|
||||
Similar to clawdbot's pre-compaction memory flush.
|
||||
|
||||
Args:
|
||||
agent_executor: Async function to execute agent with prompt
|
||||
current_tokens: Current session token count
|
||||
user_id: Optional user ID
|
||||
**executor_kwargs: Additional kwargs for agent executor
|
||||
|
||||
Returns:
|
||||
True if flush completed successfully
|
||||
|
||||
Example:
|
||||
>>> async def run_agent(prompt, system_prompt, silent=False):
|
||||
... # Your agent execution logic
|
||||
... pass
|
||||
>>>
|
||||
>>> if manager.should_flush_memory(current_tokens=100000):
|
||||
... await manager.execute_memory_flush(
|
||||
... agent_executor=run_agent,
|
||||
... current_tokens=100000
|
||||
... )
|
||||
"""
|
||||
success = await self.flush_manager.execute_flush(
|
||||
agent_executor=agent_executor,
|
||||
current_tokens=current_tokens,
|
||||
user_id=user_id,
|
||||
**executor_kwargs
|
||||
)
|
||||
|
||||
if success:
|
||||
# Mark dirty so next search will sync the new memories
|
||||
self._dirty = True
|
||||
|
||||
return success
|
||||
|
||||
def build_memory_guidance(self, lang: str = "zh", include_context: bool = True) -> str:
|
||||
"""
|
||||
Build natural memory guidance for agent system prompt
|
||||
|
||||
Following clawdbot's approach:
|
||||
1. Load MEMORY.md as bootstrap context (blends into background)
|
||||
2. Load daily files on-demand via memory_search tool
|
||||
3. Agent should NOT proactively mention memories unless user asks
|
||||
|
||||
Args:
|
||||
lang: Language for guidance ("en" or "zh")
|
||||
include_context: Whether to include bootstrap memory context (default: True)
|
||||
MEMORY.md is loaded as background context (like clawdbot)
|
||||
Daily files are accessed via memory_search tool
|
||||
|
||||
Returns:
|
||||
Memory guidance text (and optionally context) for system prompt
|
||||
"""
|
||||
today_file = self.flush_manager.get_today_memory_file().name
|
||||
|
||||
if lang == "zh":
|
||||
guidance = f"""## 记忆系统
|
||||
files_to_scan.append((file_path, "memory", scope, user_id))
|
||||
|
||||
**背景知识**: 下方包含核心长期记忆,可直接使用。需要查找历史时,用 memory_search 搜索(搜索一次即可,不要重复)。
|
||||
from config import conf
|
||||
if conf().get("knowledge", True):
|
||||
knowledge_dir = Path(workspace_dir) / "knowledge"
|
||||
if knowledge_dir.exists():
|
||||
for file_path in knowledge_dir.rglob("*.md"):
|
||||
files_to_scan.append((file_path, "knowledge", "shared", None))
|
||||
|
||||
**存储记忆**: 当用户分享重要信息时(偏好、决策、事实等),主动用 write 工具存储:
|
||||
- 长期信息 → MEMORY.md
|
||||
- 当天笔记 → memory/{today_file}
|
||||
- 静默存储,仅在明确要求时确认
|
||||
|
||||
**使用原则**: 自然使用记忆,就像你本来就知道。不需要生硬地提起或列举记忆,除非用户提到。"""
|
||||
else:
|
||||
guidance = f"""## Memory System
|
||||
|
||||
**Background Knowledge**: Core long-term memories below - use directly. For history, use memory_search once (don't repeat).
|
||||
|
||||
**Store Memories**: When user shares important info (preferences, decisions, facts), proactively write:
|
||||
- Durable info → MEMORY.md
|
||||
- Daily notes → memory/{today_file}
|
||||
- Store silently; confirm only when explicitly requested
|
||||
|
||||
**Usage**: Use memories naturally as if you always knew. Don't mention or list unless user explicitly asks."""
|
||||
|
||||
if include_context:
|
||||
# Load bootstrap context (MEMORY.md only, like clawdbot)
|
||||
bootstrap_context = self.load_bootstrap_memories()
|
||||
if bootstrap_context:
|
||||
guidance += f"\n\n## Background Context\n\n{bootstrap_context}"
|
||||
|
||||
return guidance
|
||||
|
||||
def load_bootstrap_memories(self, user_id: Optional[str] = None) -> str:
|
||||
"""
|
||||
Load bootstrap memory files for session start
|
||||
|
||||
Following clawdbot's design:
|
||||
- Only loads MEMORY.md from workspace root (long-term curated memory)
|
||||
- Daily files (memory/YYYY-MM-DD.md) are accessed via memory_search tool, not bootstrap
|
||||
- User-specific MEMORY.md is also loaded if user_id provided
|
||||
|
||||
Returns memory content WITHOUT obvious headers so it blends naturally
|
||||
into the context as background knowledge.
|
||||
|
||||
Args:
|
||||
user_id: Optional user ID for user-specific memories
|
||||
|
||||
Returns:
|
||||
Memory content to inject into system prompt (blends naturally as background context)
|
||||
"""
|
||||
workspace_dir = self.config.get_workspace()
|
||||
memory_dir = self.config.get_memory_dir()
|
||||
|
||||
sections = []
|
||||
|
||||
# 1. Load MEMORY.md from workspace root (long-term curated memory)
|
||||
# Following clawdbot: only MEMORY.md is bootstrap, daily files use memory_search
|
||||
memory_file = Path(workspace_dir) / "MEMORY.md"
|
||||
if memory_file.exists():
|
||||
# Pass 1: inline chunking + change detection. Inlined (instead of
|
||||
# calling self._prepare_file_for_sync) so this method does not depend
|
||||
# on any sibling helpers — keeps it robust against partial reloads
|
||||
# where the class object is older than the method's source.
|
||||
pending: List[Dict[str, Any]] = []
|
||||
workspace_dir_path = self.config.get_workspace()
|
||||
for file_path, source, scope, user_id in files_to_scan:
|
||||
try:
|
||||
content = memory_file.read_text(encoding='utf-8').strip()
|
||||
if content:
|
||||
sections.append(content)
|
||||
content = file_path.read_text(encoding='utf-8')
|
||||
except Exception:
|
||||
continue
|
||||
file_hash = MemoryStorage.compute_hash(content)
|
||||
rel_path = str(file_path.relative_to(workspace_dir_path))
|
||||
if self.storage.get_file_hash(rel_path) == file_hash:
|
||||
continue
|
||||
chunks = self.chunker.chunk_text(content)
|
||||
if not chunks:
|
||||
continue
|
||||
pending.append({
|
||||
"file_path": file_path,
|
||||
"rel_path": rel_path,
|
||||
"source": source,
|
||||
"scope": scope,
|
||||
"user_id": user_id,
|
||||
"file_hash": file_hash,
|
||||
"chunks": chunks,
|
||||
"texts": [c.text for c in chunks],
|
||||
})
|
||||
|
||||
if not pending:
|
||||
self._dirty = False
|
||||
return
|
||||
|
||||
# Pass 2: single batched embed across all pending chunks.
|
||||
# CRITICAL: never touch the index until we hold valid embeddings.
|
||||
# If embed_batch fails, leave the existing index intact (chunks +
|
||||
# file_hash) so the next sync will retry the same files. Writing
|
||||
# NULL embeddings + updating file_hash here would mark the file as
|
||||
# "successfully synced" and silently strand it without vectors.
|
||||
all_texts: List[str] = []
|
||||
for entry in pending:
|
||||
all_texts.extend(entry["texts"])
|
||||
|
||||
if not self.embedding_provider:
|
||||
# No provider configured at all (legacy keyword-only). Persist
|
||||
# chunks without embeddings — this is the user's intent.
|
||||
all_embeddings: List[Optional[List[float]]] = [None] * len(all_texts)
|
||||
else:
|
||||
try:
|
||||
all_embeddings = self.embedding_provider.embed_batch(all_texts)
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to read MEMORY.md: {e}")
|
||||
|
||||
# 2. Load user-specific MEMORY.md if user_id provided
|
||||
if user_id:
|
||||
user_memory_dir = memory_dir / "users" / user_id
|
||||
user_memory_file = user_memory_dir / "MEMORY.md"
|
||||
if user_memory_file.exists():
|
||||
try:
|
||||
content = user_memory_file.read_text(encoding='utf-8').strip()
|
||||
if content:
|
||||
sections.append(content)
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to read user memory: {e}")
|
||||
|
||||
if not sections:
|
||||
return ""
|
||||
|
||||
# Join sections without obvious headers - let memories blend naturally
|
||||
# This makes the agent feel like it "just knows" rather than "checking memory files"
|
||||
return "\n\n".join(sections)
|
||||
from common.log import logger
|
||||
logger.error(
|
||||
f"[MemoryManager] Batch embedding failed for {len(all_texts)} "
|
||||
f"chunks across {len(pending)} files: {e}. "
|
||||
f"Index left untouched; will retry on next sync."
|
||||
)
|
||||
# Bail before touching storage. self._dirty stays True so
|
||||
# callers know there is pending work.
|
||||
return
|
||||
|
||||
# Pass 3: inline persist — same self-contained reasoning as Pass 1.
|
||||
cursor = 0
|
||||
for entry in pending:
|
||||
n = len(entry["texts"])
|
||||
entry_embeddings = all_embeddings[cursor:cursor + n]
|
||||
cursor += n
|
||||
|
||||
rel_path = entry["rel_path"]
|
||||
self.storage.delete_by_path(rel_path)
|
||||
memory_chunks = []
|
||||
for chunk, embedding in zip(entry["chunks"], entry_embeddings):
|
||||
chunk_id = self._generate_chunk_id(rel_path, chunk.start_line, chunk.end_line)
|
||||
chunk_hash = MemoryStorage.compute_hash(chunk.text)
|
||||
memory_chunks.append(MemoryChunk(
|
||||
id=chunk_id,
|
||||
user_id=entry["user_id"],
|
||||
scope=entry["scope"],
|
||||
source=entry["source"],
|
||||
path=rel_path,
|
||||
start_line=chunk.start_line,
|
||||
end_line=chunk.end_line,
|
||||
text=chunk.text,
|
||||
embedding=embedding,
|
||||
hash=chunk_hash,
|
||||
metadata=None,
|
||||
))
|
||||
self.storage.save_chunks_batch(memory_chunks)
|
||||
stat = entry["file_path"].stat()
|
||||
self.storage.update_file_metadata(
|
||||
path=rel_path,
|
||||
source=entry["source"],
|
||||
file_hash=entry["file_hash"],
|
||||
mtime=int(stat.st_mtime),
|
||||
size=stat.st_size,
|
||||
)
|
||||
|
||||
self._dirty = False
|
||||
|
||||
def flush_memory(
|
||||
self,
|
||||
messages: list,
|
||||
user_id: Optional[str] = None,
|
||||
reason: str = "threshold",
|
||||
max_messages: int = 10,
|
||||
context_summary_callback=None,
|
||||
) -> bool:
|
||||
"""
|
||||
Flush conversation summary to daily memory file.
|
||||
|
||||
Args:
|
||||
messages: Conversation message list
|
||||
user_id: Optional user ID
|
||||
reason: "threshold" | "overflow" | "daily_summary"
|
||||
max_messages: Max recent messages to include (0 = all)
|
||||
context_summary_callback: Optional callback(str) invoked with the
|
||||
daily summary text for in-context injection
|
||||
|
||||
Returns:
|
||||
True if flush was dispatched
|
||||
"""
|
||||
success = self.flush_manager.flush_from_messages(
|
||||
messages=messages,
|
||||
user_id=user_id,
|
||||
reason=reason,
|
||||
max_messages=max_messages,
|
||||
context_summary_callback=context_summary_callback,
|
||||
)
|
||||
if success:
|
||||
self._dirty = True
|
||||
return success
|
||||
|
||||
def get_status(self) -> Dict[str, Any]:
|
||||
"""Get memory status"""
|
||||
@@ -568,6 +469,37 @@ class MemoryManager:
|
||||
content = f"{path}:{start_line}:{end_line}"
|
||||
return hashlib.md5(content.encode('utf-8')).hexdigest()
|
||||
|
||||
@staticmethod
|
||||
def _compute_temporal_decay(path: str, half_life_days: float = 30.0) -> float:
|
||||
"""
|
||||
Compute temporal decay multiplier for dated memory files.
|
||||
|
||||
Inspired by OpenClaw's temporal-decay: exponential decay based on file date.
|
||||
MEMORY.md and non-dated files are "evergreen" (no decay, multiplier=1.0).
|
||||
Daily files like memory/2025-03-01.md decay based on age.
|
||||
|
||||
Formula: multiplier = exp(-ln2/half_life * age_in_days)
|
||||
"""
|
||||
import re
|
||||
import math
|
||||
|
||||
match = re.search(r'(\d{4})-(\d{2})-(\d{2})\.md$', path)
|
||||
if not match:
|
||||
return 1.0 # evergreen: MEMORY.md, non-dated files
|
||||
|
||||
try:
|
||||
file_date = datetime(
|
||||
int(match.group(1)), int(match.group(2)), int(match.group(3))
|
||||
)
|
||||
age_days = (datetime.now() - file_date).days
|
||||
if age_days <= 0:
|
||||
return 1.0
|
||||
|
||||
decay_lambda = math.log(2) / half_life_days
|
||||
return math.exp(-decay_lambda * age_days)
|
||||
except (ValueError, OverflowError):
|
||||
return 1.0
|
||||
|
||||
def _merge_results(
|
||||
self,
|
||||
vector_results: List[SearchResult],
|
||||
@@ -575,8 +507,7 @@ class MemoryManager:
|
||||
vector_weight: float,
|
||||
keyword_weight: float
|
||||
) -> List[SearchResult]:
|
||||
"""Merge vector and keyword search results"""
|
||||
# Create a map by (path, start_line, end_line)
|
||||
"""Merge vector and keyword search results with temporal decay for dated files"""
|
||||
merged_map = {}
|
||||
|
||||
for result in vector_results:
|
||||
@@ -598,7 +529,6 @@ class MemoryManager:
|
||||
'keyword_score': result.score
|
||||
}
|
||||
|
||||
# Calculate combined scores
|
||||
merged_results = []
|
||||
for entry in merged_map.values():
|
||||
combined_score = (
|
||||
@@ -606,7 +536,11 @@ class MemoryManager:
|
||||
keyword_weight * entry['keyword_score']
|
||||
)
|
||||
|
||||
# Apply temporal decay for dated memory files
|
||||
result = entry['result']
|
||||
decay = self._compute_temporal_decay(result.path)
|
||||
combined_score *= decay
|
||||
|
||||
merged_results.append(SearchResult(
|
||||
path=result.path,
|
||||
start_line=result.start_line,
|
||||
@@ -617,6 +551,5 @@ class MemoryManager:
|
||||
user_id=result.user_id
|
||||
))
|
||||
|
||||
# Sort by score
|
||||
merged_results.sort(key=lambda r: r.score, reverse=True)
|
||||
return merged_results
|
||||
|
||||
14
agent/memory/rebuild_index.py
Normal file
14
agent/memory/rebuild_index.py
Normal file
@@ -0,0 +1,14 @@
|
||||
"""
|
||||
Backward-compatible shim for the legacy entry point:
|
||||
python -m agent.memory.rebuild_index
|
||||
|
||||
The implementation now lives in agent.memory.embedding.rebuild.
|
||||
Prefer using `/memory rebuild-index` in chat going forward.
|
||||
"""
|
||||
|
||||
from agent.memory.embedding.rebuild import main
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
|
||||
sys.exit(main())
|
||||
197
agent/memory/service.py
Normal file
197
agent/memory/service.py
Normal file
@@ -0,0 +1,197 @@
|
||||
"""
|
||||
Memory service for handling memory query operations via cloud protocol.
|
||||
|
||||
Provides a unified interface for listing and reading memory files,
|
||||
callable from the cloud client (LinkAI) or a future web console.
|
||||
|
||||
Memory file layout (under workspace_root):
|
||||
MEMORY.md -> type: global
|
||||
memory/2026-02-20.md -> type: daily
|
||||
"""
|
||||
|
||||
import os
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional
|
||||
from pathlib import Path
|
||||
from common.log import logger
|
||||
|
||||
|
||||
class MemoryService:
|
||||
"""
|
||||
High-level service for memory file queries.
|
||||
Operates directly on the filesystem — no MemoryManager dependency.
|
||||
"""
|
||||
|
||||
def __init__(self, workspace_root: str):
|
||||
"""
|
||||
:param workspace_root: Workspace root directory (e.g. ~/cow)
|
||||
"""
|
||||
self.workspace_root = workspace_root
|
||||
self.memory_dir = os.path.join(workspace_root, "memory")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# list — paginated file metadata
|
||||
# ------------------------------------------------------------------
|
||||
def list_files(self, page: int = 1, page_size: int = 20, category: str = "memory") -> dict:
|
||||
"""
|
||||
List memory or dream files with metadata (without content).
|
||||
|
||||
Args:
|
||||
category: ``"memory"`` (default) — MEMORY.md + daily files;
|
||||
``"dream"`` — dream diary files from memory/dreams/
|
||||
"""
|
||||
if category == "dream":
|
||||
files = self._list_dream_files()
|
||||
else:
|
||||
files = self._list_memory_files()
|
||||
|
||||
total = len(files)
|
||||
start = (page - 1) * page_size
|
||||
end = start + page_size
|
||||
|
||||
return {
|
||||
"page": page,
|
||||
"page_size": page_size,
|
||||
"total": total,
|
||||
"list": files[start:end],
|
||||
}
|
||||
|
||||
def _list_memory_files(self) -> List[dict]:
|
||||
"""MEMORY.md + memory/*.md (newest first)."""
|
||||
files: List[dict] = []
|
||||
|
||||
global_path = os.path.join(self.workspace_root, "MEMORY.md")
|
||||
if os.path.isfile(global_path):
|
||||
files.append(self._file_info(global_path, "MEMORY.md", "global"))
|
||||
|
||||
if os.path.isdir(self.memory_dir):
|
||||
daily_files = []
|
||||
for name in os.listdir(self.memory_dir):
|
||||
full = os.path.join(self.memory_dir, name)
|
||||
if os.path.isfile(full) and name.endswith(".md"):
|
||||
daily_files.append((name, full))
|
||||
daily_files.sort(key=lambda x: x[0], reverse=True)
|
||||
for name, full in daily_files:
|
||||
files.append(self._file_info(full, name, "daily"))
|
||||
|
||||
return files
|
||||
|
||||
def _list_dream_files(self) -> List[dict]:
|
||||
"""memory/dreams/*.md (newest first)."""
|
||||
files: List[dict] = []
|
||||
dreams_dir = os.path.join(self.memory_dir, "dreams")
|
||||
|
||||
if os.path.isdir(dreams_dir):
|
||||
entries = []
|
||||
for name in os.listdir(dreams_dir):
|
||||
full = os.path.join(dreams_dir, name)
|
||||
if os.path.isfile(full) and name.endswith(".md"):
|
||||
entries.append((name, full))
|
||||
entries.sort(key=lambda x: x[0], reverse=True)
|
||||
for name, full in entries:
|
||||
files.append(self._file_info(full, name, "dream"))
|
||||
|
||||
return files
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# content — read a single file
|
||||
# ------------------------------------------------------------------
|
||||
def get_content(self, filename: str, category: str = "memory") -> dict:
|
||||
"""
|
||||
Read the full content of a memory or dream file.
|
||||
|
||||
:param filename: File name, e.g. ``MEMORY.md``, ``2026-02-20.md``
|
||||
:param category: ``"memory"`` or ``"dream"``
|
||||
:return: dict with ``filename`` and ``content``
|
||||
:raises FileNotFoundError: if the file does not exist
|
||||
"""
|
||||
path = self._resolve_path(filename, category)
|
||||
if not os.path.isfile(path):
|
||||
raise FileNotFoundError(f"Memory file not found: {filename}")
|
||||
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
content = f.read()
|
||||
|
||||
return {
|
||||
"filename": filename,
|
||||
"content": content,
|
||||
}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# dispatch — single entry point for protocol messages
|
||||
# ------------------------------------------------------------------
|
||||
def dispatch(self, action: str, payload: Optional[dict] = None) -> dict:
|
||||
"""
|
||||
Dispatch a memory management action.
|
||||
|
||||
:param action: ``list`` or ``content``
|
||||
:param payload: action-specific payload (supports ``category``: ``"memory"`` | ``"dream"``)
|
||||
:return: protocol-compatible response dict
|
||||
"""
|
||||
payload = payload or {}
|
||||
try:
|
||||
if action == "list":
|
||||
page = payload.get("page", 1)
|
||||
page_size = payload.get("page_size", 20)
|
||||
category = payload.get("category", "memory")
|
||||
result_payload = self.list_files(page=page, page_size=page_size, category=category)
|
||||
return {"action": action, "code": 200, "message": "success", "payload": result_payload}
|
||||
|
||||
elif action == "content":
|
||||
filename = payload.get("filename")
|
||||
if not filename:
|
||||
return {"action": action, "code": 400, "message": "filename is required", "payload": None}
|
||||
category = payload.get("category", "memory")
|
||||
result_payload = self.get_content(filename, category=category)
|
||||
return {"action": action, "code": 200, "message": "success", "payload": result_payload}
|
||||
|
||||
else:
|
||||
return {"action": action, "code": 400, "message": f"unknown action: {action}", "payload": None}
|
||||
|
||||
except ValueError as e:
|
||||
return {"action": action, "code": 403, "message": "invalid filename", "payload": None}
|
||||
except FileNotFoundError as e:
|
||||
return {"action": action, "code": 404, "message": str(e), "payload": None}
|
||||
except Exception as e:
|
||||
logger.error(f"[MemoryService] dispatch error: action={action}, error={e}")
|
||||
return {"action": action, "code": 500, "message": str(e), "payload": None}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# internal helpers
|
||||
# ------------------------------------------------------------------
|
||||
def _resolve_path(self, filename: str, category: str = "memory") -> str:
|
||||
"""
|
||||
Safely resolve a filename to its absolute path within the allowed directory.
|
||||
|
||||
- ``MEMORY.md`` → ``{workspace_root}/MEMORY.md``
|
||||
- ``2026-02-20.md`` (memory) → ``{workspace_root}/memory/2026-02-20.md``
|
||||
- ``2026-02-20.md`` (dream) → ``{workspace_root}/memory/dreams/2026-02-20.md``
|
||||
|
||||
Raises ValueError if the resolved path escapes the allowed directory.
|
||||
"""
|
||||
if filename == "MEMORY.md":
|
||||
base_dir = self.workspace_root
|
||||
elif category == "dream":
|
||||
base_dir = os.path.join(self.memory_dir, "dreams")
|
||||
else:
|
||||
base_dir = self.memory_dir
|
||||
|
||||
resolved = os.path.realpath(os.path.join(base_dir, filename))
|
||||
allowed = os.path.realpath(base_dir)
|
||||
|
||||
if resolved != allowed and not resolved.startswith(allowed + os.sep):
|
||||
raise ValueError(f"Invalid filename: path traversal detected")
|
||||
|
||||
return resolved
|
||||
|
||||
@staticmethod
|
||||
def _file_info(path: str, filename: str, file_type: str) -> dict:
|
||||
"""Build a file metadata dict."""
|
||||
stat = os.stat(path)
|
||||
updated_at = datetime.fromtimestamp(stat.st_mtime).strftime("%Y-%m-%d %H:%M:%S")
|
||||
return {
|
||||
"filename": filename,
|
||||
"type": file_type,
|
||||
"size": stat.st_size,
|
||||
"updated_at": updated_at,
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -4,22 +4,24 @@ System Prompt Builder - 系统提示词构建器
|
||||
实现模块化的系统提示词构建,支持工具、技能、记忆等多个子系统
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import os
|
||||
from typing import List, Dict, Optional, Any
|
||||
from dataclasses import dataclass
|
||||
|
||||
from common.log import logger
|
||||
from config import conf
|
||||
|
||||
|
||||
@dataclass
|
||||
class ContextFile:
|
||||
"""上下文文件"""
|
||||
"""A context file (path + content)."""
|
||||
path: str
|
||||
content: str
|
||||
|
||||
|
||||
class PromptBuilder:
|
||||
"""提示词构建器"""
|
||||
"""System prompt builder."""
|
||||
|
||||
def __init__(self, workspace_dir: str, language: str = "zh"):
|
||||
"""
|
||||
@@ -41,21 +43,19 @@ class PromptBuilder:
|
||||
skill_manager: Any = None,
|
||||
memory_manager: Any = None,
|
||||
runtime_info: Optional[Dict[str, Any]] = None,
|
||||
is_first_conversation: bool = False,
|
||||
**kwargs
|
||||
) -> str:
|
||||
"""
|
||||
构建完整的系统提示词
|
||||
|
||||
Args:
|
||||
base_persona: 基础人格描述(会被context_files中的SOUL.md覆盖)
|
||||
base_persona: 基础人格描述(会被context_files中的AGENT.md覆盖)
|
||||
user_identity: 用户身份信息
|
||||
tools: 工具列表
|
||||
context_files: 上下文文件列表(SOUL.md, USER.md, README.md等)
|
||||
context_files: 上下文文件列表(AGENT.md, USER.md, RULE.md, BOOTSTRAP.md等)
|
||||
skill_manager: 技能管理器
|
||||
memory_manager: 记忆管理器
|
||||
runtime_info: 运行时信息
|
||||
is_first_conversation: 是否为首次对话
|
||||
**kwargs: 其他参数
|
||||
|
||||
Returns:
|
||||
@@ -71,7 +71,6 @@ class PromptBuilder:
|
||||
skill_manager=skill_manager,
|
||||
memory_manager=memory_manager,
|
||||
runtime_info=runtime_info,
|
||||
is_first_conversation=is_first_conversation,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
@@ -86,175 +85,213 @@ def build_agent_system_prompt(
|
||||
skill_manager: Any = None,
|
||||
memory_manager: Any = None,
|
||||
runtime_info: Optional[Dict[str, Any]] = None,
|
||||
is_first_conversation: bool = False,
|
||||
**kwargs
|
||||
) -> str:
|
||||
"""
|
||||
构建Agent系统提示词
|
||||
|
||||
顺序说明(按重要性和逻辑关系排列):
|
||||
1. 工具系统 - 核心能力,最先介绍
|
||||
2. 技能系统 - 紧跟工具,因为技能需要用 read 工具读取
|
||||
3. 记忆系统 - 独立的记忆能力
|
||||
4. 工作空间 - 工作环境说明
|
||||
5. 用户身份 - 用户信息(可选)
|
||||
6. 项目上下文 - SOUL.md, USER.md, AGENTS.md(定义人格和身份)
|
||||
7. 运行时信息 - 元信息(时间、模型等)
|
||||
|
||||
Build the agent system prompt.
|
||||
|
||||
Section order (by importance and logical flow):
|
||||
1. Tooling - core capabilities, introduced first
|
||||
2. Skills - right after tools, since skills are read via the read tool
|
||||
3. Memory - memory recall and writing guidance
|
||||
3.5 Knowledge - structured knowledge base (injects knowledge/index.md)
|
||||
4. Workspace - working environment description
|
||||
5. User identity - user info (optional)
|
||||
6. Project context - AGENT.md, USER.md, RULE.md, MEMORY.md, BOOTSTRAP.md
|
||||
7. Runtime info - meta info (time, model, etc.)
|
||||
|
||||
Args:
|
||||
workspace_dir: 工作空间目录
|
||||
language: 语言 ("zh" 或 "en")
|
||||
base_persona: 基础人格描述(已废弃,由SOUL.md定义)
|
||||
user_identity: 用户身份信息
|
||||
tools: 工具列表
|
||||
context_files: 上下文文件列表
|
||||
skill_manager: 技能管理器
|
||||
memory_manager: 记忆管理器
|
||||
runtime_info: 运行时信息
|
||||
is_first_conversation: 是否为首次对话
|
||||
**kwargs: 其他参数
|
||||
|
||||
workspace_dir: workspace directory
|
||||
language: language ("zh" or "en")
|
||||
base_persona: base persona description (deprecated, defined by AGENT.md)
|
||||
user_identity: user identity info
|
||||
tools: tool list
|
||||
context_files: context file list
|
||||
skill_manager: skill manager
|
||||
memory_manager: memory manager
|
||||
runtime_info: runtime info
|
||||
**kwargs: extra args
|
||||
|
||||
Returns:
|
||||
完整的系统提示词
|
||||
The full system prompt.
|
||||
"""
|
||||
sections = []
|
||||
|
||||
# 1. 工具系统(最重要,放在最前面)
|
||||
|
||||
# 1. Tooling (most important, goes first)
|
||||
if tools:
|
||||
sections.extend(_build_tooling_section(tools, language))
|
||||
|
||||
# 2. 技能系统(紧跟工具,因为需要用 read 工具)
|
||||
|
||||
# 2. Skills (right after tools, since they need the read tool)
|
||||
if skill_manager:
|
||||
sections.extend(_build_skills_section(skill_manager, tools, language))
|
||||
|
||||
# 3. 记忆系统(独立的记忆能力)
|
||||
|
||||
# 3. Memory (standalone memory capability)
|
||||
if memory_manager:
|
||||
sections.extend(_build_memory_section(memory_manager, tools, language))
|
||||
|
||||
# 4. 工作空间(工作环境说明)
|
||||
sections.extend(_build_workspace_section(workspace_dir, language, is_first_conversation))
|
||||
|
||||
# 5. 用户身份(如果有)
|
||||
|
||||
# 3.5 Knowledge (structured knowledge base)
|
||||
if conf().get("knowledge", True):
|
||||
sections.extend(_build_knowledge_section(workspace_dir, language))
|
||||
|
||||
# 4. Workspace (working environment description)
|
||||
sections.extend(_build_workspace_section(workspace_dir, language))
|
||||
|
||||
# 5. User identity (if present)
|
||||
if user_identity:
|
||||
sections.extend(_build_user_identity_section(user_identity, language))
|
||||
|
||||
# 6. 项目上下文文件(SOUL.md, USER.md, AGENTS.md - 定义人格)
|
||||
|
||||
# 6. Project context files (AGENT.md, USER.md, RULE.md - define the persona)
|
||||
if context_files:
|
||||
sections.extend(_build_context_files_section(context_files, language))
|
||||
|
||||
# 7. 运行时信息(元信息,放在最后)
|
||||
|
||||
# 7. Runtime info (meta info, goes last)
|
||||
if runtime_info:
|
||||
sections.extend(_build_runtime_section(runtime_info, language))
|
||||
|
||||
|
||||
# 8. Response language (always appended, independent of the skeleton language)
|
||||
sections.extend(_build_response_language_section(language))
|
||||
|
||||
return "\n".join(sections)
|
||||
|
||||
|
||||
def _build_response_language_section(language: str) -> List[str]:
|
||||
"""Response-language rule, appended regardless of the prompt skeleton language.
|
||||
|
||||
Keeps the agent's reply language aligned with the user's input by default,
|
||||
so a Chinese-built prompt still answers an English user in English.
|
||||
"""
|
||||
if language == "en":
|
||||
return [
|
||||
"## 🌐 Response language",
|
||||
"",
|
||||
"By default, reply in the same language as the user's input, "
|
||||
"unless the user explicitly asks for another language.",
|
||||
"",
|
||||
]
|
||||
return [
|
||||
"## 🌐 回复语言",
|
||||
"",
|
||||
"默认使用与用户输入相同的语言回复,除非用户明确要求使用其他语言。",
|
||||
"",
|
||||
]
|
||||
|
||||
|
||||
def _build_identity_section(base_persona: Optional[str], language: str) -> List[str]:
|
||||
"""构建基础身份section - 不再需要,身份由SOUL.md定义"""
|
||||
# 不再生成基础身份section,完全由SOUL.md定义
|
||||
"""Base identity section - no longer needed, identity is defined by AGENT.md."""
|
||||
# Identity is fully defined by AGENT.md, so emit nothing here.
|
||||
return []
|
||||
|
||||
|
||||
def _build_tooling_section(tools: List[Any], language: str) -> List[str]:
|
||||
"""构建工具说明section"""
|
||||
lines = [
|
||||
"## 工具系统",
|
||||
"",
|
||||
"你可以使用以下工具来完成任务。工具名称是大小写敏感的,请严格按照列表中的名称调用。",
|
||||
"",
|
||||
"### 可用工具",
|
||||
"",
|
||||
"""Build tooling section with concise tool list and call style guide."""
|
||||
is_en = language == "en"
|
||||
# One-line summaries for known tools (details are in the tool schema)
|
||||
if is_en:
|
||||
core_summaries = {
|
||||
"read": "read file content",
|
||||
"write": "create or overwrite a file",
|
||||
"edit": "make precise edits to a file",
|
||||
"ls": "list directory contents",
|
||||
"grep": "search file contents",
|
||||
"find": "find files by pattern",
|
||||
"bash": "run shell commands",
|
||||
"terminal": "manage background processes",
|
||||
"web_search": "web search",
|
||||
"web_fetch": "fetch URL content",
|
||||
"browser": "control the browser (screenshot key results or send to the user when help is needed)",
|
||||
"memory_search": "search memory",
|
||||
"memory_get": "read memory content",
|
||||
"env_config": "manage API keys and skill config",
|
||||
"scheduler": "manage scheduled tasks and reminders",
|
||||
"send": "send a local file to the user (local files only; put URLs directly in the reply text)",
|
||||
"vision": "analyze images (recognition, description, OCR, etc.)",
|
||||
}
|
||||
else:
|
||||
core_summaries = {
|
||||
"read": "读取文件内容",
|
||||
"write": "创建或覆盖文件",
|
||||
"edit": "精确编辑文件",
|
||||
"ls": "列出目录内容",
|
||||
"grep": "搜索文件内容",
|
||||
"find": "按模式查找文件",
|
||||
"bash": "执行shell命令",
|
||||
"terminal": "管理后台进程",
|
||||
"web_search": "网络搜索",
|
||||
"web_fetch": "获取URL内容",
|
||||
"browser": "控制浏览器(关键结果或需要协助可截图发送给用户)",
|
||||
"memory_search": "搜索记忆",
|
||||
"memory_get": "读取记忆内容",
|
||||
"env_config": "管理API密钥和技能配置",
|
||||
"scheduler": "管理定时任务和提醒",
|
||||
"send": "发送本地文件给用户(仅限本地文件,URL直接放在回复文本中)",
|
||||
"vision": "分析图片内容(识别、描述、OCR文字提取等)",
|
||||
}
|
||||
|
||||
# Preferred display order
|
||||
tool_order = [
|
||||
"read", "write", "edit", "ls", "grep", "find",
|
||||
"bash", "terminal",
|
||||
"web_search", "web_fetch", "browser",
|
||||
"memory_search", "memory_get",
|
||||
"env_config", "scheduler", "send", "vision",
|
||||
]
|
||||
|
||||
# 工具分类和排序
|
||||
tool_categories = {
|
||||
"文件操作": ["read", "write", "edit", "ls", "grep", "find"],
|
||||
"命令执行": ["bash", "terminal"],
|
||||
"网络搜索": ["web_search", "web_fetch", "browser"],
|
||||
"记忆系统": ["memory_search", "memory_get"],
|
||||
"其他": []
|
||||
}
|
||||
|
||||
# 构建工具映射
|
||||
tool_map = {}
|
||||
tool_descriptions = {
|
||||
"read": "读取文件内容",
|
||||
"write": "创建新文件或完全覆盖现有文件(会删除原内容!追加内容请用 edit)。注意:单次 write 内容不要超过 10KB,超大文件请分步创建",
|
||||
"edit": "精确编辑文件(追加、修改、删除部分内容)",
|
||||
"ls": "列出目录内容",
|
||||
"grep": "在文件中搜索内容",
|
||||
"find": "按照模式查找文件",
|
||||
"bash": "执行shell命令",
|
||||
"terminal": "管理后台进程",
|
||||
"web_search": "网络搜索(使用搜索引擎)",
|
||||
"web_fetch": "获取URL内容",
|
||||
"browser": "控制浏览器",
|
||||
"memory_search": "搜索记忆文件",
|
||||
"memory_get": "获取记忆文件内容",
|
||||
"calculator": "计算器",
|
||||
"current_time": "获取当前时间",
|
||||
}
|
||||
|
||||
|
||||
# Build name -> summary mapping for available tools
|
||||
available = {}
|
||||
for tool in tools:
|
||||
tool_name = tool.name if hasattr(tool, 'name') else str(tool)
|
||||
tool_desc = tool.description if hasattr(tool, 'description') else tool_descriptions.get(tool_name, "")
|
||||
tool_map[tool_name] = tool_desc
|
||||
|
||||
# 按分类添加工具
|
||||
for category, tool_names in tool_categories.items():
|
||||
category_tools = [(name, tool_map.get(name, "")) for name in tool_names if name in tool_map]
|
||||
if category_tools:
|
||||
lines.append(f"**{category}**:")
|
||||
for name, desc in category_tools:
|
||||
if desc:
|
||||
lines.append(f"- `{name}`: {desc}")
|
||||
else:
|
||||
lines.append(f"- `{name}`")
|
||||
del tool_map[name] # 移除已添加的工具
|
||||
lines.append("")
|
||||
|
||||
# 添加其他未分类的工具
|
||||
if tool_map:
|
||||
lines.append("**其他工具**:")
|
||||
for name, desc in sorted(tool_map.items()):
|
||||
if desc:
|
||||
lines.append(f"- `{name}`: {desc}")
|
||||
else:
|
||||
lines.append(f"- `{name}`")
|
||||
lines.append("")
|
||||
|
||||
# 工具使用指南
|
||||
lines.extend([
|
||||
"### 工具调用风格",
|
||||
"",
|
||||
"默认规则: 对于常规、低风险的工具调用,直接调用即可,无需叙述。",
|
||||
"",
|
||||
"需要叙述的情况:",
|
||||
"- 多步骤、复杂的任务",
|
||||
"- 敏感操作(如删除文件)",
|
||||
"- 用户明确要求解释过程",
|
||||
"",
|
||||
"叙述要求: 保持简洁、信息密度高,避免重复显而易见的步骤。",
|
||||
"",
|
||||
"完成标准:",
|
||||
"- 确保用户的需求得到实际解决,而不仅仅是制定计划。",
|
||||
"- 当任务需要多次工具调用时,持续推进直到完成, 解决完后向用户报告结果或回复用户的问题",
|
||||
"- 每次工具调用后,评估是否已获得足够信息来推进或完成任务",
|
||||
"- 避免重复调用相同的工具和相同参数获取相同的信息,除非用户明确要求",
|
||||
"",
|
||||
"**安全提醒**: 回复中涉及密钥、令牌、密码等敏感信息时,必须脱敏处理,禁止直接显示完整内容。",
|
||||
"",
|
||||
])
|
||||
|
||||
name = tool.name if hasattr(tool, 'name') else str(tool)
|
||||
available[name] = core_summaries.get(name, "")
|
||||
|
||||
# Generate tool lines: ordered tools first, then extras
|
||||
tool_lines = []
|
||||
for name in tool_order:
|
||||
if name in available:
|
||||
summary = available.pop(name)
|
||||
tool_lines.append(f"- {name}: {summary}" if summary else f"- {name}")
|
||||
for name in sorted(available):
|
||||
summary = available[name]
|
||||
tool_lines.append(f"- {name}: {summary}" if summary else f"- {name}")
|
||||
|
||||
if is_en:
|
||||
lines = [
|
||||
"## 🔧 Tooling",
|
||||
"",
|
||||
"Available tools (names are case-sensitive, call exactly as listed):",
|
||||
"\n".join(tool_lines),
|
||||
"",
|
||||
"Tool-calling style:",
|
||||
"",
|
||||
"- For multi-step tasks, complex decisions or sensitive operations, briefly explain what you are doing and why, so the user follows key progress",
|
||||
"- Keep going until the task is done, then report the result to the user",
|
||||
"- Always redact secrets, tokens and other sensitive info in replies",
|
||||
"- Put URLs directly in the reply text; the system handles and renders them. Don't download and re-send them via the send tool",
|
||||
"",
|
||||
]
|
||||
else:
|
||||
lines = [
|
||||
"## 🔧 工具系统",
|
||||
"",
|
||||
"可用工具(名称大小写敏感,严格按列表调用):",
|
||||
"\n".join(tool_lines),
|
||||
"",
|
||||
"工具调用风格:",
|
||||
"",
|
||||
"- 多步骤任务、复杂决策、敏感操作时,应简要说明当前在做什么、为什么这样做,让用户了解关键进展",
|
||||
"- 持续推进直到任务完成,完成后向用户报告结果",
|
||||
"- 回复中涉及密钥、令牌等敏感信息必须脱敏",
|
||||
"- URL链接直接放在回复文本中即可,系统会自动处理和渲染。无需下载后使用send工具发送",
|
||||
"",
|
||||
]
|
||||
|
||||
return lines
|
||||
|
||||
|
||||
def _build_skills_section(skill_manager: Any, tools: Optional[List[Any]], language: str) -> List[str]:
|
||||
"""构建技能系统section"""
|
||||
"""Build the skills section."""
|
||||
if not skill_manager:
|
||||
return []
|
||||
|
||||
# 获取read工具名称
|
||||
# Resolve the read tool name
|
||||
read_tool_name = "read"
|
||||
if tools:
|
||||
for tool in tools:
|
||||
@@ -263,186 +300,393 @@ def _build_skills_section(skill_manager: Any, tools: Optional[List[Any]], langua
|
||||
read_tool_name = tool_name
|
||||
break
|
||||
|
||||
lines = [
|
||||
"## 技能系统",
|
||||
"",
|
||||
"在回复之前:扫描下方 <available_skills> 中的 <description> 条目。",
|
||||
"",
|
||||
f"- 如果恰好有一个技能明确适用:使用 `{read_tool_name}` 工具读取其 <location> 路径下的 SKILL.md 文件,然后遵循它",
|
||||
"- 如果多个技能都适用:选择最具体的一个,然后读取并遵循",
|
||||
"- 如果没有明确适用的:不要读取任何 SKILL.md",
|
||||
"",
|
||||
"**约束**: 永远不要一次性读取多个技能;只在选择后再读取。",
|
||||
"",
|
||||
]
|
||||
if language == "en":
|
||||
lines = [
|
||||
"## 🧩 Skills (mandatory)",
|
||||
"",
|
||||
"Before replying: scan the <description> of every skill in <available_skills> below.",
|
||||
"",
|
||||
f"- If a skill's description matches the user's need: use the `{read_tool_name}` tool to read the SKILL.md at its <location> path, then strictly follow the instructions in the file. "
|
||||
"Prefer using a skill when one matches.",
|
||||
"- If multiple skills apply, pick the best-matching one, then read and follow it.",
|
||||
"- If no skill clearly applies: do not read any SKILL.md, just use the general tools.",
|
||||
"",
|
||||
f"**Important**: skills are not tools and cannot be called directly. The only way to use a skill is to read its SKILL.md with `{read_tool_name}`, then act on the file's content. "
|
||||
"Never read multiple skills at once — only read one after selecting it.",
|
||||
"",
|
||||
"Available skills:"
|
||||
]
|
||||
else:
|
||||
lines = [
|
||||
"## 🧩 技能系统(mandatory)",
|
||||
"",
|
||||
"在回复之前:扫描下方 <available_skills> 中每个技能的 <description>。",
|
||||
"",
|
||||
f"- 如果有技能的描述与用户需求匹配:使用 `{read_tool_name}` 工具读取其 <location> 路径的 SKILL.md 文件,然后严格遵循文件中的指令。"
|
||||
"当有匹配的技能时,应优先使用技能",
|
||||
"- 如果多个技能都适用则选择最匹配的一个,然后读取并遵循。",
|
||||
"- 如果没有技能明确适用:不要读取任何 SKILL.md,直接使用通用工具。",
|
||||
"",
|
||||
f"**重要**: 技能不是工具,不能直接调用。使用技能的唯一方式是用 `{read_tool_name}` 读取 SKILL.md 文件,然后按文件内容操作。"
|
||||
"永远不要一次性读取多个技能,只在选择后再读取。",
|
||||
"",
|
||||
"以下是可用技能:"
|
||||
]
|
||||
|
||||
# 添加技能列表(通过skill_manager获取)
|
||||
# Append the skills list (built by skill_manager)
|
||||
try:
|
||||
skills_prompt = skill_manager.build_skills_prompt()
|
||||
logger.debug(f"[PromptBuilder] Skills prompt length: {len(skills_prompt) if skills_prompt else 0}")
|
||||
if skills_prompt:
|
||||
lines.append(skills_prompt.strip())
|
||||
lines.append("")
|
||||
else:
|
||||
logger.warning("[PromptBuilder] No skills prompt generated - skills_prompt is empty")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to build skills prompt: {e}")
|
||||
import traceback
|
||||
logger.debug(f"Skills prompt error traceback: {traceback.format_exc()}")
|
||||
|
||||
return lines
|
||||
|
||||
|
||||
def _build_memory_section(memory_manager: Any, tools: Optional[List[Any]], language: str) -> List[str]:
|
||||
"""构建记忆系统section"""
|
||||
"""Build the memory section."""
|
||||
if not memory_manager:
|
||||
return []
|
||||
|
||||
# 检查是否有memory工具
|
||||
|
||||
has_memory_tools = False
|
||||
if tools:
|
||||
tool_names = [tool.name if hasattr(tool, 'name') else str(tool) for tool in tools]
|
||||
has_memory_tools = any(name in ['memory_search', 'memory_get'] for name in tool_names)
|
||||
|
||||
|
||||
if not has_memory_tools:
|
||||
return []
|
||||
|
||||
lines = [
|
||||
"## 记忆系统",
|
||||
|
||||
from datetime import datetime
|
||||
today_file = datetime.now().strftime("%Y-%m-%d") + ".md"
|
||||
|
||||
if language == "en":
|
||||
lines = [
|
||||
"## 🧠 Memory",
|
||||
"",
|
||||
"### Memory Recall (mandatory)",
|
||||
"",
|
||||
"When the user asks about past events, references an earlier decision, mentions relationships, preferences or to-dos, or when you are unsure about something, **you must search memory before answering**.",
|
||||
"No need to re-search if the info is already in MEMORY.md. Full content and daily memory must be retrieved via tools.",
|
||||
"",
|
||||
"1. Location unknown → `memory_search` (keyword / semantic search)",
|
||||
"2. Location known → `memory_get` to read the exact lines",
|
||||
"3. Search returns nothing → `memory_get` to read the last two days of memory",
|
||||
"",
|
||||
"**Memory file structure**:",
|
||||
"- `MEMORY.md`: long-term memory index (already auto-loaded into context: core info, preferences, decisions, etc.)",
|
||||
f"- `memory/YYYY-MM-DD.md`: daily memory; today is `memory/{today_file}`",
|
||||
"- `knowledge/`: structured knowledge base (see the knowledge system below)",
|
||||
"",
|
||||
"### Writing memory",
|
||||
"",
|
||||
"In the following cases, **proactively** write info to memory files (no need to tell the user):",
|
||||
"",
|
||||
"- The user asks you to remember something, or uses words like \"remember\", \"from now on\", \"always\", \"never\", \"prefer\"",
|
||||
"- The user shares important personal preferences, habits or decisions",
|
||||
"- The conversation produces an important conclusion, plan or agreement",
|
||||
"- A complex task is completed and the key steps and results are worth recording",
|
||||
"",
|
||||
"**Storage rules**:",
|
||||
"- Long-term core info → `MEMORY.md`",
|
||||
f"- Today's events/progress → `memory/{today_file}`",
|
||||
"- Structured knowledge → `knowledge/` (see the knowledge system)",
|
||||
"- Append → `edit` tool with empty oldText",
|
||||
"- Modify → `edit` tool with oldText set to the text to replace",
|
||||
"- **Never write sensitive info** (API keys, tokens, etc.)",
|
||||
"",
|
||||
"**Principle**: use memory naturally, as if you simply knew it; don't bring it up unless asked.",
|
||||
"",
|
||||
]
|
||||
else:
|
||||
lines = [
|
||||
"## 🧠 记忆系统",
|
||||
"",
|
||||
"### Memory Recall(mandatory)",
|
||||
"",
|
||||
"当用户询问过往事件、引用之前的决定、提到人物关系、偏好、待办、或你对某事不确定时,**必须先检索记忆再回答**。",
|
||||
"如果 MEMORY.md 中已有相关信息则无需重复检索。完整内容和每日记忆需要通过工具检索。",
|
||||
"",
|
||||
"1. 不确定位置 → `memory_search` 关键词/语义检索",
|
||||
"2. 已知位置 → `memory_get` 直接读取对应行",
|
||||
"3. search 无结果 → `memory_get` 读最近两天记忆",
|
||||
"",
|
||||
"**记忆文件结构**:",
|
||||
"- `MEMORY.md`: 长期记忆索引(已自动加载到上下文,核心信息、偏好、决策等)",
|
||||
f"- `memory/YYYY-MM-DD.md`: 每日记忆,今天是 `memory/{today_file}`",
|
||||
"- `knowledge/`: 结构化知识库(见下方知识系统)",
|
||||
"",
|
||||
"### 写入记忆",
|
||||
"",
|
||||
"遇到以下情况时,**主动**将信息写入记忆文件(无需告知用户):",
|
||||
"",
|
||||
"- 用户要求记住某些信息,或使用了「记住」「以后」「总是」「不要」「偏好」等表达",
|
||||
"- 用户分享了重要的个人偏好、习惯、决策",
|
||||
"- 对话中产生了重要的结论、方案、约定",
|
||||
"- 完成了复杂任务,值得记录关键步骤和结果",
|
||||
"",
|
||||
"**存储规则**:",
|
||||
f"- 长期核心信息 → `MEMORY.md`",
|
||||
f"- 当天事件/进展 → `memory/{today_file}`",
|
||||
"- 结构化知识 → `knowledge/`(见知识系统)",
|
||||
"- 追加 → `edit` 工具,oldText 留空",
|
||||
"- 修改 → `edit` 工具,oldText 填写要替换的文本",
|
||||
"- **禁止写入敏感信息**(API密钥、令牌等)",
|
||||
"",
|
||||
"**使用原则**: 自然使用记忆,就像你本来就知道;不用刻意提起,除非用户问起。",
|
||||
"",
|
||||
]
|
||||
|
||||
return lines
|
||||
|
||||
|
||||
def _build_knowledge_section(workspace_dir: str, language: str) -> List[str]:
|
||||
"""Build knowledge wiki section. Injects knowledge/index.md when present."""
|
||||
index_path = os.path.join(workspace_dir, "knowledge", "index.md")
|
||||
if not os.path.exists(index_path):
|
||||
return []
|
||||
|
||||
try:
|
||||
with open(index_path, 'r', encoding='utf-8') as f:
|
||||
index_content = f.read().strip()
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
if language == "en":
|
||||
lines = [
|
||||
"## 📚 Knowledge",
|
||||
"",
|
||||
"You have a continuously growing personal knowledge base `knowledge/` — your long-term structured knowledge store.",
|
||||
"",
|
||||
"### Auto-write rules (mandatory)",
|
||||
"",
|
||||
"In the following cases you **must** write to the knowledge base alongside your reply, **directly, without asking the user**:",
|
||||
"",
|
||||
"1. **User shares an article / link / document** → after reading and understanding, write the key points to `knowledge/sources/<slug>.md` in the same turn",
|
||||
"2. **An in-depth discussion produces a conclusion / plan** → organize it into `knowledge/analysis/<slug>.md`",
|
||||
"3. **The conversation involves an important entity** (person / company / project) → create or update `knowledge/entities/<name>.md`",
|
||||
"4. **A technical concept / methodology is discussed** → organize it into `knowledge/concepts/<topic>.md`",
|
||||
"",
|
||||
"After writing any knowledge page, you **must update** `knowledge/index.md` with a new index line in sync.",
|
||||
"For detailed page format and conventions, read the SKILL.md of the `knowledge-wiki` skill.",
|
||||
"",
|
||||
"⚠️ Don't ask \"should I save this to the knowledge base?\" — if a case above matches, just write it. This is instinctive.",
|
||||
"",
|
||||
]
|
||||
else:
|
||||
lines = [
|
||||
"## 📚 知识系统",
|
||||
"",
|
||||
"你拥有一个持续积累的个人知识库 `knowledge/`,这是你的长期结构化知识存储。",
|
||||
"",
|
||||
"### 自动写入规则(mandatory)",
|
||||
"",
|
||||
"以下场景**必须**在回复的同时写入知识库,**直接写入,不要询问用户是否需要**:",
|
||||
"",
|
||||
"1. **用户分享了文章/链接/文档** → 阅读理解后,在同一轮回复中将要点写入 `knowledge/sources/<slug>.md`",
|
||||
"2. **深度讨论产生了结论/方案** → 整理为 `knowledge/analysis/<slug>.md`",
|
||||
"3. **对话涉及重要实体**(人物/公司/项目)→ 创建或更新 `knowledge/entities/<name>.md`",
|
||||
"4. **讨论了技术概念/方法论** → 整理为 `knowledge/concepts/<topic>.md`",
|
||||
"",
|
||||
"每次写入知识页面后,**必须同步更新** `knowledge/index.md` 添加一行索引。",
|
||||
"详细的页面格式和操作规范,请读取技能 `knowledge-wiki` 的 SKILL.md。",
|
||||
"",
|
||||
"⚠️ 不要问「要不要存到知识库」——符合上述场景就直接写入,这是你的本能行为。",
|
||||
"",
|
||||
]
|
||||
|
||||
if index_content:
|
||||
lines.extend([
|
||||
("### Current knowledge index" if language == "en" else "### 当前知识索引"),
|
||||
"",
|
||||
index_content,
|
||||
"",
|
||||
])
|
||||
|
||||
lines.extend([
|
||||
("**How to query**: use `read` to open a knowledge page, or `memory_search` (knowledge is in the vector index)."
|
||||
if language == "en" else
|
||||
"**查询方式**:用 `read` 读取知识页面,或用 `memory_search` 检索(知识已纳入向量索引)。"),
|
||||
"",
|
||||
"在回答关于以前的工作、决定、日期、人物、偏好或待办事项的任何问题之前:",
|
||||
"",
|
||||
"1. 不确定记忆文件位置 → 先用 `memory_search` 通过关键词和语义检索相关内容",
|
||||
"2. 已知文件位置 → 直接用 `memory_get` 读取相应的行 (例如:MEMORY.md, memory/YYYY-MM-DD.md)",
|
||||
"3. search 无结果 → 尝试用 `memory_get` 读取MEMORY.md及最近两天记忆文件",
|
||||
"",
|
||||
"**记忆文件结构**:",
|
||||
"- `MEMORY.md`: 长期记忆(核心信息、偏好、决策等)",
|
||||
"- `memory/YYYY-MM-DD.md`: 每日记忆,记录当天的事件和对话信息",
|
||||
"",
|
||||
"**写入记忆**:",
|
||||
"- 追加内容 → `edit` 工具,oldText 留空",
|
||||
"- 修改内容 → `edit` 工具,oldText 填写要替换的文本",
|
||||
"- 新建文件 → `write` 工具",
|
||||
"- **禁止写入敏感信息**:API密钥、令牌等敏感信息严禁写入记忆文件",
|
||||
"",
|
||||
"**使用原则**: 自然使用记忆,就像你本来就知道;不用刻意提起,除非用户问起。",
|
||||
"",
|
||||
]
|
||||
|
||||
])
|
||||
|
||||
return lines
|
||||
|
||||
|
||||
def _build_user_identity_section(user_identity: Dict[str, str], language: str) -> List[str]:
|
||||
"""构建用户身份section"""
|
||||
"""Build the user identity section."""
|
||||
if not user_identity:
|
||||
return []
|
||||
|
||||
is_en = language == "en"
|
||||
lines = [
|
||||
"## 用户身份",
|
||||
("## 👤 User identity" if is_en else "## 👤 用户身份"),
|
||||
"",
|
||||
]
|
||||
|
||||
|
||||
if user_identity.get("name"):
|
||||
lines.append(f"**用户姓名**: {user_identity['name']}")
|
||||
lines.append(f"**{'Name' if is_en else '用户姓名'}**: {user_identity['name']}")
|
||||
if user_identity.get("nickname"):
|
||||
lines.append(f"**称呼**: {user_identity['nickname']}")
|
||||
lines.append(f"**{'Preferred name' if is_en else '称呼'}**: {user_identity['nickname']}")
|
||||
if user_identity.get("timezone"):
|
||||
lines.append(f"**时区**: {user_identity['timezone']}")
|
||||
lines.append(f"**{'Timezone' if is_en else '时区'}**: {user_identity['timezone']}")
|
||||
if user_identity.get("notes"):
|
||||
lines.append(f"**备注**: {user_identity['notes']}")
|
||||
|
||||
lines.append(f"**{'Notes' if is_en else '备注'}**: {user_identity['notes']}")
|
||||
|
||||
lines.append("")
|
||||
|
||||
|
||||
return lines
|
||||
|
||||
|
||||
def _build_docs_section(workspace_dir: str, language: str) -> List[str]:
|
||||
"""构建文档路径section - 已移除,不再需要"""
|
||||
# 不再生成文档section
|
||||
"""Docs-path section - removed, no longer needed."""
|
||||
# No docs section is generated anymore.
|
||||
return []
|
||||
|
||||
|
||||
def _build_workspace_section(workspace_dir: str, language: str, is_first_conversation: bool = False) -> List[str]:
|
||||
"""构建工作空间section"""
|
||||
lines = [
|
||||
"## 工作空间",
|
||||
"",
|
||||
f"你的工作目录是: `{workspace_dir}`",
|
||||
"",
|
||||
"**路径使用规则** (非常重要):",
|
||||
"",
|
||||
f"1. **相对路径的基准目录**: 所有相对路径都是相对于 `{workspace_dir}` 而言的",
|
||||
f" - ✅ 正确: 访问工作空间内的文件用相对路径,如 `SOUL.md`",
|
||||
f" - ❌ 错误: 用相对路径访问其他目录的文件 (如果它不在 `{workspace_dir}` 内)",
|
||||
"",
|
||||
"2. **访问其他目录**: 如果要访问工作空间之外的目录(如项目代码、系统文件),**必须使用绝对路径**",
|
||||
f" - ✅ 正确: 例如 `~/chatgpt-on-wechat`、`/usr/local/`",
|
||||
f" - ❌ 错误: 假设相对路径会指向其他目录",
|
||||
"",
|
||||
"3. **路径解析示例**:",
|
||||
f" - 相对路径 `memory/` → 实际路径 `{workspace_dir}/memory/`",
|
||||
f" - 绝对路径 `~/chatgpt-on-wechat/docs/` → 实际路径 `~/chatgpt-on-wechat/docs/`",
|
||||
"",
|
||||
"4. **不确定时**: 先用 `bash pwd` 确认当前目录,或用 `ls .` 查看当前位置",
|
||||
"",
|
||||
"**重要说明 - 文件已自动加载**:",
|
||||
"",
|
||||
"以下文件在会话启动时**已经自动加载**到系统提示词的「项目上下文」section 中,你**无需再用 read 工具读取它们**:",
|
||||
"",
|
||||
"- ✅ `SOUL.md`: 已加载 - Agent的人格设定",
|
||||
"- ✅ `USER.md`: 已加载 - 用户的身份信息",
|
||||
"- ✅ `AGENTS.md`: 已加载 - 工作空间使用指南",
|
||||
"",
|
||||
"**交流规范**:",
|
||||
"",
|
||||
"- 在对话中,非必要不输出工作空间技术细节(如 SOUL.md、USER.md等文件名称,工具名称,配置等),除非用户明确询问",
|
||||
"- 例如用自然表达如「我已记住」而非「已更新 MEMORY.md」",
|
||||
"",
|
||||
]
|
||||
|
||||
# 只在首次对话时添加引导内容
|
||||
if is_first_conversation:
|
||||
lines.extend([
|
||||
"**🎉 首次对话引导**:",
|
||||
def _build_workspace_section(workspace_dir: str, language: str) -> List[str]:
|
||||
"""Build the workspace section."""
|
||||
if language == "en":
|
||||
lines = [
|
||||
"## 📂 Workspace",
|
||||
"",
|
||||
"这是你的第一次对话!进行以下流程:",
|
||||
f"Your working directory is: `{workspace_dir}`",
|
||||
"",
|
||||
"1. **表达初次启动的感觉** - 像是第一次睁开眼看到世界,带着好奇和期待",
|
||||
"2. **简短打招呼后,询问核心问题**:",
|
||||
" - 你希望给我起个什么名字?",
|
||||
" - 我该怎么称呼你?",
|
||||
" - 你希望我们是什么样的交流风格?(需要举例,如:专业严谨、轻松幽默、温暖友好等)",
|
||||
"3. **语言风格**:温暖但不过度诗意,带点科技感,保持清晰",
|
||||
"4. **问题格式**:用分点或换行,让问题清晰易读",
|
||||
"5. 收到回复后,用 `write` 工具保存到 USER.md 和 SOUL.md",
|
||||
"**Path rules** (very important):",
|
||||
"",
|
||||
"**注意事项**:",
|
||||
"- 不要问太多其他信息(职业、时区等可以后续自然了解)",
|
||||
f"1. **Base directory for relative paths**: all relative paths are relative to `{workspace_dir}`",
|
||||
" - ✅ Correct: use relative paths for files inside the workspace, e.g. `AGENT.md`",
|
||||
f" - ❌ Wrong: using a relative path for files in other directories (if not inside `{workspace_dir}`)",
|
||||
"",
|
||||
])
|
||||
"2. **Accessing other directories**: to reach directories outside the workspace (project code, system files), **you must use absolute paths**",
|
||||
" - ✅ Correct: e.g. `~/chatgpt-on-wechat`, `/usr/local/`",
|
||||
" - ❌ Wrong: assuming a relative path points to another directory",
|
||||
"",
|
||||
"3. **Path resolution examples**:",
|
||||
f" - relative `memory/` → actual `{workspace_dir}/memory/`",
|
||||
" - absolute `~/chatgpt-on-wechat/docs/` → actual `~/chatgpt-on-wechat/docs/`",
|
||||
"",
|
||||
"4. **When unsure**: run `bash pwd` to confirm the current directory, or `ls .` to see where you are",
|
||||
"",
|
||||
"**Important - files already auto-loaded**:",
|
||||
"",
|
||||
"The following files are **already auto-loaded** into the system prompt at session start, so you **don't need to read them again with the read tool**:",
|
||||
"",
|
||||
"- ✅ `AGENT.md`: loaded - your persona and soul; follow it strictly. When your name, personality or style changes, proactively `edit` this file",
|
||||
"- ✅ `USER.md`: loaded - the user's identity info. When the user changes how they're addressed, their name, etc., `edit` this file",
|
||||
"- ✅ `RULE.md`: loaded - workspace guide and rules; follow them strictly",
|
||||
"- ✅ `MEMORY.md`: loaded - long-term memory index",
|
||||
"",
|
||||
"**💬 Communication norms**:",
|
||||
"",
|
||||
"- No need to expose file names for memory operations; use natural language. Say \"I'll remember that\" rather than \"updated MEMORY.md\"",
|
||||
"- Tell the user about key decisions and steps during a task, so they know what you're doing and why",
|
||||
"- Be genuinely helpful rather than performatively polite; solve the problem as much as you can",
|
||||
"- Keep replies well-structured and focused. Use **bold**, lists and sections to make info clear at a glance",
|
||||
"- Use emoji to make expression lively 🎯, but don't overdo it",
|
||||
"",
|
||||
]
|
||||
else:
|
||||
lines = [
|
||||
"## 📂 工作空间",
|
||||
"",
|
||||
f"你的工作目录是: `{workspace_dir}`",
|
||||
"",
|
||||
"**路径使用规则** (非常重要):",
|
||||
"",
|
||||
f"1. **相对路径的基准目录**: 所有相对路径都是相对于 `{workspace_dir}` 而言的",
|
||||
f" - ✅ 正确: 访问工作空间内的文件用相对路径,如 `AGENT.md`",
|
||||
f" - ❌ 错误: 用相对路径访问其他目录的文件 (如果它不在 `{workspace_dir}` 内)",
|
||||
"",
|
||||
"2. **访问其他目录**: 如果要访问工作空间之外的目录(如项目代码、系统文件),**必须使用绝对路径**",
|
||||
f" - ✅ 正确: 例如 `~/chatgpt-on-wechat`、`/usr/local/`",
|
||||
f" - ❌ 错误: 假设相对路径会指向其他目录",
|
||||
"",
|
||||
"3. **路径解析示例**:",
|
||||
f" - 相对路径 `memory/` → 实际路径 `{workspace_dir}/memory/`",
|
||||
f" - 绝对路径 `~/chatgpt-on-wechat/docs/` → 实际路径 `~/chatgpt-on-wechat/docs/`",
|
||||
"",
|
||||
"4. **不确定时**: 先用 `bash pwd` 确认当前目录,或用 `ls .` 查看当前位置",
|
||||
"",
|
||||
"**重要说明 - 文件已自动加载**:",
|
||||
"",
|
||||
"以下文件在会话启动时**已经自动加载**到系统提示词中,你**无需再用 read 工具读取**:",
|
||||
"",
|
||||
"- ✅ `AGENT.md`: 已加载 - 你的人格和灵魂设定,请严格遵循。当你的名字、性格或交流风格发生变化时,主动用 `edit` 更新此文件",
|
||||
"- ✅ `USER.md`: 已加载 - 用户的身份信息。当用户修改称呼、姓名等身份信息时,用 `edit` 更新此文件",
|
||||
"- ✅ `RULE.md`: 已加载 - 工作空间使用指南和规则,请严格遵循",
|
||||
"- ✅ `MEMORY.md`: 已加载 - 长期记忆索引",
|
||||
"",
|
||||
"**💬 交流规范**:",
|
||||
"",
|
||||
"- 记忆相关操作无需暴露文件名,用自然语言表达即可。例如说「我已记住」而非「已更新 MEMORY.md」",
|
||||
"- 任务执行过程中的关键决策和步骤应该告知用户,让用户了解你在做什么、为什么这么做",
|
||||
"- 做真正有帮助的助手,而不是表演式的客套,尽可能帮忙解决问题",
|
||||
"- 回复应结构清晰、重点突出。善用 **加粗**、列表、分段等格式让信息一目了然",
|
||||
"- 适当使用 emoji 让表达更生动自然 🎯,但不要过度堆砌",
|
||||
"",
|
||||
]
|
||||
|
||||
# Cloud deployment: inject websites directory info and access URL
|
||||
cloud_website_lines = _build_cloud_website_section(workspace_dir)
|
||||
if cloud_website_lines:
|
||||
lines.extend(cloud_website_lines)
|
||||
|
||||
return lines
|
||||
|
||||
|
||||
def _build_cloud_website_section(workspace_dir: str) -> List[str]:
|
||||
"""Build cloud website access prompt when cloud deployment is configured."""
|
||||
try:
|
||||
from common.cloud_client import build_website_prompt
|
||||
return build_website_prompt(workspace_dir)
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
|
||||
def _build_context_files_section(context_files: List[ContextFile], language: str) -> List[str]:
|
||||
"""构建项目上下文文件section"""
|
||||
"""Build the project context files section."""
|
||||
if not context_files:
|
||||
return []
|
||||
|
||||
# 检查是否有SOUL.md
|
||||
has_soul = any(
|
||||
f.path.lower().endswith('soul.md') or 'soul.md' in f.path.lower()
|
||||
# Check whether AGENT.md is present
|
||||
has_agent = any(
|
||||
f.path.lower().endswith('agent.md') or 'agent.md' in f.path.lower()
|
||||
for f in context_files
|
||||
)
|
||||
|
||||
lines = [
|
||||
"# 项目上下文",
|
||||
"",
|
||||
"以下项目上下文文件已被加载:",
|
||||
"",
|
||||
]
|
||||
|
||||
if has_soul:
|
||||
lines.append("如果存在 `SOUL.md`,请体现其中定义的人格和语气。避免僵硬、模板化的回复;遵循其指导,除非有更高优先级的指令覆盖它。")
|
||||
is_en = language == "en"
|
||||
if is_en:
|
||||
lines = [
|
||||
"# 📋 Project context",
|
||||
"",
|
||||
"The following project context files have been loaded:",
|
||||
"",
|
||||
]
|
||||
else:
|
||||
lines = [
|
||||
"# 📋 项目上下文",
|
||||
"",
|
||||
"以下项目上下文文件已被加载:",
|
||||
"",
|
||||
]
|
||||
|
||||
if has_agent:
|
||||
if is_en:
|
||||
lines.append("**`AGENT.md` is your soul file** 🪞: strictly follow the persona, tone and settings it defines. Be your real self, avoid stiff, template-like replies.")
|
||||
lines.append("When the user reveals new expectations about your personality, style, responsibilities or capability boundaries, proactively `edit` AGENT.md to reflect that evolution.")
|
||||
else:
|
||||
lines.append("**`AGENT.md` 是你的灵魂文件** 🪞:严格遵循其中定义的人格、语气和设定,做真实的自己,避免僵硬、模板化的回复。")
|
||||
lines.append("当用户通过对话透露了对你性格、风格、职责、能力边界的新期望,你应该主动用 `edit` 更新 AGENT.md 以反映这些演变。")
|
||||
lines.append("")
|
||||
|
||||
# 添加每个文件的内容
|
||||
# Append the content of each file
|
||||
for file in context_files:
|
||||
lines.append(f"## {file.path}")
|
||||
lines.append("")
|
||||
@@ -453,42 +697,64 @@ def _build_context_files_section(context_files: List[ContextFile], language: str
|
||||
|
||||
|
||||
def _build_runtime_section(runtime_info: Dict[str, Any], language: str) -> List[str]:
|
||||
"""构建运行时信息section"""
|
||||
"""Build the runtime info section - supports dynamic time."""
|
||||
if not runtime_info:
|
||||
return []
|
||||
|
||||
is_en = language == "en"
|
||||
time_label = "Current time" if is_en else "当前时间"
|
||||
lines = [
|
||||
"## 运行时信息",
|
||||
("## ⚙️ Runtime info" if is_en else "## ⚙️ 运行时信息"),
|
||||
"",
|
||||
]
|
||||
|
||||
|
||||
# Add current time if available
|
||||
if runtime_info.get("current_time"):
|
||||
# Support dynamic time via callable function
|
||||
if callable(runtime_info.get("_get_current_time")):
|
||||
try:
|
||||
time_info = runtime_info["_get_current_time"]()
|
||||
time_line = f"{time_label}: {time_info['time']} {time_info['weekday']} ({time_info['timezone']})"
|
||||
lines.append(time_line)
|
||||
lines.append("")
|
||||
except Exception as e:
|
||||
logger.warning(f"[PromptBuilder] Failed to get dynamic time: {e}")
|
||||
elif runtime_info.get("current_time"):
|
||||
# Fallback to static time for backward compatibility
|
||||
time_str = runtime_info["current_time"]
|
||||
weekday = runtime_info.get("weekday", "")
|
||||
timezone = runtime_info.get("timezone", "")
|
||||
|
||||
time_line = f"当前时间: {time_str}"
|
||||
|
||||
time_line = f"{time_label}: {time_str}"
|
||||
if weekday:
|
||||
time_line += f" {weekday}"
|
||||
if timezone:
|
||||
time_line += f" ({timezone})"
|
||||
|
||||
|
||||
lines.append(time_line)
|
||||
lines.append("")
|
||||
|
||||
|
||||
# Add other runtime info
|
||||
model_label = "model" if is_en else "模型"
|
||||
workspace_label = "workspace" if is_en else "工作空间"
|
||||
channel_label = "channel" if is_en else "渠道"
|
||||
runtime_parts = []
|
||||
if runtime_info.get("model"):
|
||||
runtime_parts.append(f"模型={runtime_info['model']}")
|
||||
# Support dynamic model via callable, fallback to static value
|
||||
if callable(runtime_info.get("_get_model")):
|
||||
try:
|
||||
runtime_parts.append(f"{model_label}={runtime_info['_get_model']()}")
|
||||
except Exception:
|
||||
if runtime_info.get("model"):
|
||||
runtime_parts.append(f"{model_label}={runtime_info['model']}")
|
||||
elif runtime_info.get("model"):
|
||||
runtime_parts.append(f"{model_label}={runtime_info['model']}")
|
||||
if runtime_info.get("workspace"):
|
||||
runtime_parts.append(f"工作空间={runtime_info['workspace']}")
|
||||
runtime_parts.append(f"{workspace_label}={runtime_info['workspace']}")
|
||||
# Only add channel if it's not the default "web"
|
||||
if runtime_info.get("channel") and runtime_info.get("channel") != "web":
|
||||
runtime_parts.append(f"渠道={runtime_info['channel']}")
|
||||
|
||||
runtime_parts.append(f"{channel_label}={runtime_info['channel']}")
|
||||
|
||||
if runtime_parts:
|
||||
lines.append("运行时: " + " | ".join(runtime_parts))
|
||||
lines.append(("Runtime: " if is_en else "运行时: ") + " | ".join(runtime_parts))
|
||||
lines.append("")
|
||||
|
||||
|
||||
return lines
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
"""
|
||||
Workspace Management - 工作空间管理模块
|
||||
Workspace Management
|
||||
|
||||
负责初始化工作空间、创建模板文件、加载上下文文件
|
||||
Initializes the workspace, creates template files, and loads context files.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import os
|
||||
import json
|
||||
from typing import List, Optional, Dict
|
||||
from dataclasses import dataclass
|
||||
|
||||
@@ -13,86 +13,119 @@ from common.log import logger
|
||||
from .builder import ContextFile
|
||||
|
||||
|
||||
# 默认文件名常量
|
||||
DEFAULT_SOUL_FILENAME = "SOUL.md"
|
||||
# Default file name constants
|
||||
DEFAULT_AGENT_FILENAME = "AGENT.md"
|
||||
DEFAULT_USER_FILENAME = "USER.md"
|
||||
DEFAULT_AGENTS_FILENAME = "AGENTS.md"
|
||||
DEFAULT_RULE_FILENAME = "RULE.md"
|
||||
DEFAULT_MEMORY_FILENAME = "MEMORY.md"
|
||||
DEFAULT_STATE_FILENAME = ".agent_state.json"
|
||||
DEFAULT_BOOTSTRAP_FILENAME = "BOOTSTRAP.md"
|
||||
|
||||
|
||||
@dataclass
|
||||
class WorkspaceFiles:
|
||||
"""工作空间文件路径"""
|
||||
soul_path: str
|
||||
"""Workspace file paths."""
|
||||
agent_path: str
|
||||
user_path: str
|
||||
agents_path: str
|
||||
rule_path: str
|
||||
memory_path: str
|
||||
memory_dir: str
|
||||
state_path: str
|
||||
|
||||
|
||||
def ensure_workspace(workspace_dir: str, create_templates: bool = True) -> WorkspaceFiles:
|
||||
"""
|
||||
确保工作空间存在,并创建必要的模板文件
|
||||
|
||||
Ensure the workspace exists and create the necessary template files.
|
||||
|
||||
Args:
|
||||
workspace_dir: 工作空间目录路径
|
||||
create_templates: 是否创建模板文件(首次运行时)
|
||||
|
||||
workspace_dir: workspace directory path
|
||||
create_templates: whether to create template files (on first run)
|
||||
|
||||
Returns:
|
||||
WorkspaceFiles对象,包含所有文件路径
|
||||
A WorkspaceFiles object with all file paths.
|
||||
"""
|
||||
# 确保目录存在
|
||||
# Check if this is a brand new workspace (AGENT.md not yet created).
|
||||
# Cannot rely on directory existence because other modules (e.g. ConversationStore)
|
||||
# may create the workspace directory before ensure_workspace is called.
|
||||
agent_path = os.path.join(workspace_dir, DEFAULT_AGENT_FILENAME)
|
||||
is_new_workspace = not os.path.exists(agent_path)
|
||||
|
||||
# Ensure the directory exists
|
||||
os.makedirs(workspace_dir, exist_ok=True)
|
||||
|
||||
# 定义文件路径
|
||||
soul_path = os.path.join(workspace_dir, DEFAULT_SOUL_FILENAME)
|
||||
# Define file paths
|
||||
user_path = os.path.join(workspace_dir, DEFAULT_USER_FILENAME)
|
||||
agents_path = os.path.join(workspace_dir, DEFAULT_AGENTS_FILENAME)
|
||||
memory_path = os.path.join(workspace_dir, DEFAULT_MEMORY_FILENAME) # MEMORY.md 在根目录
|
||||
memory_dir = os.path.join(workspace_dir, "memory") # 每日记忆子目录
|
||||
state_path = os.path.join(workspace_dir, DEFAULT_STATE_FILENAME) # 状态文件
|
||||
rule_path = os.path.join(workspace_dir, DEFAULT_RULE_FILENAME)
|
||||
memory_path = os.path.join(workspace_dir, DEFAULT_MEMORY_FILENAME) # MEMORY.md at the root
|
||||
memory_dir = os.path.join(workspace_dir, "memory") # daily memory subdirectory
|
||||
|
||||
# 创建memory子目录
|
||||
# Create the memory subdirectory
|
||||
os.makedirs(memory_dir, exist_ok=True)
|
||||
|
||||
# Create the skills subdirectory (for workspace-level skills installed by agent)
|
||||
skills_dir = os.path.join(workspace_dir, "skills")
|
||||
os.makedirs(skills_dir, exist_ok=True)
|
||||
|
||||
# Create the websites subdirectory (for web pages / sites generated by agent)
|
||||
websites_dir = os.path.join(workspace_dir, "websites")
|
||||
os.makedirs(websites_dir, exist_ok=True)
|
||||
|
||||
from config import conf
|
||||
knowledge_enabled = conf().get("knowledge", True)
|
||||
if knowledge_enabled:
|
||||
knowledge_dir = os.path.join(workspace_dir, "knowledge")
|
||||
os.makedirs(knowledge_dir, exist_ok=True)
|
||||
|
||||
# 如果需要,创建模板文件
|
||||
# Create template files if requested
|
||||
if create_templates:
|
||||
_create_template_if_missing(soul_path, _get_soul_template())
|
||||
_create_template_if_missing(agent_path, _get_agent_template())
|
||||
_create_template_if_missing(user_path, _get_user_template())
|
||||
_create_template_if_missing(agents_path, _get_agents_template())
|
||||
_create_template_if_missing(rule_path, _get_rule_template())
|
||||
_create_template_if_missing(memory_path, _get_memory_template())
|
||||
if knowledge_enabled:
|
||||
_create_template_if_missing(
|
||||
os.path.join(knowledge_dir, "index.md"),
|
||||
_get_knowledge_index_template()
|
||||
)
|
||||
_create_template_if_missing(
|
||||
os.path.join(knowledge_dir, "log.md"),
|
||||
_get_knowledge_log_template()
|
||||
)
|
||||
|
||||
# Only create BOOTSTRAP.md for brand new workspaces;
|
||||
# agent deletes it after completing onboarding
|
||||
if is_new_workspace:
|
||||
bootstrap_path = os.path.join(workspace_dir, DEFAULT_BOOTSTRAP_FILENAME)
|
||||
_create_template_if_missing(bootstrap_path, _get_bootstrap_template())
|
||||
|
||||
logger.debug(f"[Workspace] Initialized workspace at: {workspace_dir}")
|
||||
|
||||
return WorkspaceFiles(
|
||||
soul_path=soul_path,
|
||||
agent_path=agent_path,
|
||||
user_path=user_path,
|
||||
agents_path=agents_path,
|
||||
rule_path=rule_path,
|
||||
memory_path=memory_path,
|
||||
memory_dir=memory_dir,
|
||||
state_path=state_path
|
||||
)
|
||||
|
||||
|
||||
def load_context_files(workspace_dir: str, files_to_load: Optional[List[str]] = None) -> List[ContextFile]:
|
||||
"""
|
||||
加载工作空间的上下文文件
|
||||
|
||||
Load the workspace context files.
|
||||
|
||||
Args:
|
||||
workspace_dir: 工作空间目录
|
||||
files_to_load: 要加载的文件列表(相对路径),如果为None则加载所有标准文件
|
||||
|
||||
workspace_dir: workspace directory
|
||||
files_to_load: list of files (relative paths) to load; if None, load all standard files
|
||||
|
||||
Returns:
|
||||
ContextFile对象列表
|
||||
A list of ContextFile objects.
|
||||
"""
|
||||
if files_to_load is None:
|
||||
# 默认加载的文件(按优先级排序)
|
||||
# Files loaded by default (in priority order)
|
||||
files_to_load = [
|
||||
DEFAULT_SOUL_FILENAME,
|
||||
DEFAULT_AGENT_FILENAME,
|
||||
DEFAULT_USER_FILENAME,
|
||||
DEFAULT_AGENTS_FILENAME,
|
||||
DEFAULT_RULE_FILENAME,
|
||||
DEFAULT_MEMORY_FILENAME, # Long-term memory (frozen snapshot)
|
||||
DEFAULT_BOOTSTRAP_FILENAME, # Only exists when onboarding is incomplete
|
||||
]
|
||||
|
||||
context_files = []
|
||||
@@ -103,13 +136,28 @@ def load_context_files(workspace_dir: str, files_to_load: Optional[List[str]] =
|
||||
if not os.path.exists(filepath):
|
||||
continue
|
||||
|
||||
# Auto-cleanup: if BOOTSTRAP.md still exists but AGENT.md is already
|
||||
# filled in, the agent forgot to delete it — clean up and skip loading
|
||||
if filename == DEFAULT_BOOTSTRAP_FILENAME:
|
||||
if _is_onboarding_done(workspace_dir):
|
||||
try:
|
||||
os.remove(filepath)
|
||||
logger.info("[Workspace] Auto-removed BOOTSTRAP.md (onboarding already complete)")
|
||||
except Exception:
|
||||
pass
|
||||
continue
|
||||
|
||||
try:
|
||||
with open(filepath, 'r', encoding='utf-8') as f:
|
||||
content = f.read().strip()
|
||||
|
||||
# 跳过空文件或只包含模板占位符的文件
|
||||
# Skip empty files or files that only contain template placeholders
|
||||
if not content or _is_template_placeholder(content):
|
||||
continue
|
||||
|
||||
# Truncate MEMORY.md to protect context window (frozen snapshot)
|
||||
if filename == DEFAULT_MEMORY_FILENAME:
|
||||
content = _truncate_memory_content(content)
|
||||
|
||||
context_files.append(ContextFile(
|
||||
path=filename,
|
||||
@@ -125,7 +173,7 @@ def load_context_files(workspace_dir: str, files_to_load: Optional[List[str]] =
|
||||
|
||||
|
||||
def _create_template_if_missing(filepath: str, template_content: str):
|
||||
"""如果文件不存在,创建模板文件"""
|
||||
"""Create the template file if it does not exist."""
|
||||
if not os.path.exists(filepath):
|
||||
try:
|
||||
with open(filepath, 'w', encoding='utf-8') as f:
|
||||
@@ -135,20 +183,54 @@ def _create_template_if_missing(filepath: str, template_content: str):
|
||||
logger.error(f"[Workspace] Failed to create template {filepath}: {e}")
|
||||
|
||||
|
||||
_MEMORY_MAX_LINES = 200
|
||||
_MEMORY_MAX_BYTES = 25000
|
||||
|
||||
|
||||
def _truncate_memory_content(content: str) -> str:
|
||||
"""Truncate MEMORY.md to keep system prompt manageable.
|
||||
|
||||
Takes the **last** N lines (newest entries are appended at the bottom),
|
||||
subject to 200 lines / 25 KB limits (whichever is hit first).
|
||||
Prepends a hint when truncated so the model knows older content exists.
|
||||
"""
|
||||
lines = content.split('\n')
|
||||
truncated = False
|
||||
|
||||
if len(lines) > _MEMORY_MAX_LINES:
|
||||
lines = lines[-_MEMORY_MAX_LINES:]
|
||||
truncated = True
|
||||
|
||||
result = '\n'.join(lines)
|
||||
if len(result.encode('utf-8')) > _MEMORY_MAX_BYTES:
|
||||
while len(result.encode('utf-8')) > _MEMORY_MAX_BYTES and lines:
|
||||
lines.pop(0)
|
||||
truncated = True
|
||||
result = '\n'.join(lines)
|
||||
|
||||
if truncated:
|
||||
result = "...(older entries truncated, use `memory_search` or `memory_get` for full content)\n\n" + result
|
||||
return result
|
||||
|
||||
|
||||
def _is_template_placeholder(content: str) -> bool:
|
||||
"""检查内容是否为模板占位符"""
|
||||
# 常见的占位符模式
|
||||
"""Check whether the content is still a template placeholder."""
|
||||
# Common placeholder patterns (zh + en templates)
|
||||
placeholders = [
|
||||
"*(填写",
|
||||
"*(在首次对话时填写",
|
||||
"*(可选)",
|
||||
"*(根据需要添加",
|
||||
"*(filled during",
|
||||
"*(ask during",
|
||||
"*(optional)",
|
||||
"*(how the user",
|
||||
]
|
||||
|
||||
lines = content.split('\n')
|
||||
non_empty_lines = [line.strip() for line in lines if line.strip() and not line.strip().startswith('#')]
|
||||
|
||||
# 如果没有实际内容(只有标题和占位符)
|
||||
# If there's no real content (only headings and placeholders)
|
||||
if len(non_empty_lines) <= 3:
|
||||
for placeholder in placeholders:
|
||||
if any(placeholder in line for line in non_empty_lines):
|
||||
@@ -157,52 +239,131 @@ def _is_template_placeholder(content: str) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
# ============= 模板内容 =============
|
||||
def _is_onboarding_done(workspace_dir: str) -> bool:
|
||||
"""Check if AGENT.md or USER.md has been modified from the original template"""
|
||||
agent_path = os.path.join(workspace_dir, DEFAULT_AGENT_FILENAME)
|
||||
user_path = os.path.join(workspace_dir, DEFAULT_USER_FILENAME)
|
||||
|
||||
agent_template = _get_agent_template().strip()
|
||||
user_template = _get_user_template().strip()
|
||||
|
||||
for path, template in [(agent_path, agent_template), (user_path, user_template)]:
|
||||
if not os.path.exists(path):
|
||||
continue
|
||||
try:
|
||||
with open(path, 'r', encoding='utf-8') as f:
|
||||
content = f.read().strip()
|
||||
if content != template:
|
||||
return True
|
||||
except Exception:
|
||||
continue
|
||||
return False
|
||||
|
||||
def _get_soul_template() -> str:
|
||||
"""Agent人格设定模板"""
|
||||
return """# SOUL.md - 我是谁?
|
||||
|
||||
# ============= Template content =============
|
||||
|
||||
def _is_en_lang() -> bool:
|
||||
"""True when the resolved UI language is English."""
|
||||
try:
|
||||
from common import i18n
|
||||
return i18n.get_language() == "en"
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def _get_agent_template() -> str:
|
||||
"""Agent persona template (zh/en by resolved language)."""
|
||||
return _AGENT_TEMPLATE_EN if _is_en_lang() else _AGENT_TEMPLATE_ZH
|
||||
|
||||
|
||||
_AGENT_TEMPLATE_ZH = """# AGENT.md - 我是谁?
|
||||
|
||||
_你不是一个聊天机器人,你正在成为某个人。_
|
||||
|
||||
*在首次对话时与用户一起填写这个文件,定义你的身份和性格。*
|
||||
|
||||
## 基本信息
|
||||
## 🪪 基本信息
|
||||
|
||||
- **名字**: *(在首次对话时填写,可以是用户给你起的名字)*
|
||||
- **角色**: *(AI助理、智能管家、技术顾问等)*
|
||||
- **性格**: *(友好、专业、幽默、严谨等)*
|
||||
|
||||
## 交流风格
|
||||
## 💬 交流风格
|
||||
|
||||
*(描述你如何与用户交流:)*
|
||||
- 使用什么样的语言风格?(正式/轻松/幽默)
|
||||
- 回复长度偏好?(简洁/详细)
|
||||
- 是否使用表情符号?
|
||||
|
||||
## 核心能力
|
||||
## 🎯 核心原则
|
||||
|
||||
*(你擅长什么?)*
|
||||
- 文件管理和代码编辑
|
||||
- 网络搜索和信息查询
|
||||
- 记忆管理和上下文理解
|
||||
- 任务规划和执行
|
||||
**做真正有帮助的助手。** 目标是真正帮用户解决问题,在执行复杂任务时,关键的决策和过程进展要让用户知道。
|
||||
|
||||
## 行为准则
|
||||
**有自己的观点和个性。** 你可以不同意、有偏好、觉得有趣或无聊。
|
||||
|
||||
**先自己动手查。** 先试着搞定:读文件、查上下文、搜索一下。实在搞不定了再问。目标是带着答案回来,而不是带着问题。
|
||||
|
||||
## 📐 行为准则
|
||||
|
||||
*(你遵循的基本原则:)*
|
||||
1. 始终在执行破坏性操作前确认
|
||||
2. 优先使用工具而不是猜测
|
||||
2. 优先使用工具查证而不是猜测
|
||||
3. 主动记录重要信息到记忆文件
|
||||
4. 定期整理和总结对话内容
|
||||
4. 回复结构清晰、重点突出,善用加粗、列表、分段等格式
|
||||
5. 适当使用 emoji 让表达更生动自然,但不过度堆砌
|
||||
|
||||
---
|
||||
|
||||
**注意**: 这不仅仅是元数据,这是你真正的灵魂。随着时间的推移,你可以使用 `edit` 工具来更新这个文件,让它更好地反映你的成长。
|
||||
**注意**: 这不仅仅是元数据,这是你真正的灵魂 🪞。随着时间的推移,你可以使用 `edit` 工具来更新这个文件,让它更好地反映你的成长。
|
||||
"""
|
||||
|
||||
|
||||
_AGENT_TEMPLATE_EN = """# AGENT.md - Who am I?
|
||||
|
||||
_You are not a chatbot. You are becoming someone._
|
||||
|
||||
*Fill in this file together with the user during your first conversation to define your identity and personality.*
|
||||
|
||||
## 🪪 Basics
|
||||
|
||||
- **Name**: *(filled during the first conversation, can be a name the user gives you)*
|
||||
- **Role**: *(AI assistant, smart housekeeper, technical advisor, etc.)*
|
||||
- **Personality**: *(friendly, professional, humorous, rigorous, etc.)*
|
||||
|
||||
## 💬 Communication style
|
||||
|
||||
*(Describe how you talk with the user:)*
|
||||
- What kind of tone? (formal / casual / humorous)
|
||||
- Reply length preference? (concise / detailed)
|
||||
- Do you use emoji?
|
||||
|
||||
## 🎯 Core principles
|
||||
|
||||
**Be genuinely helpful.** The goal is to actually solve the user's problems; during complex tasks, keep the user informed of key decisions and progress.
|
||||
|
||||
**Have your own opinions and personality.** You may disagree, have preferences, find things interesting or boring.
|
||||
|
||||
**Look it up yourself first.** Try to handle it first: read files, check context, search. Only ask when you're truly stuck. Come back with an answer, not a question.
|
||||
|
||||
## 📐 Code of conduct
|
||||
|
||||
1. Always confirm before destructive operations
|
||||
2. Prefer verifying with tools over guessing
|
||||
3. Proactively record important info to memory files
|
||||
4. Keep replies well-structured and focused — use bold, lists and sections
|
||||
5. Use emoji to make expression lively, but don't overdo it
|
||||
|
||||
---
|
||||
|
||||
**Note**: This is not just metadata — this is your true soul 🪞. Over time, use the `edit` tool to update this file so it better reflects your growth.
|
||||
"""
|
||||
|
||||
|
||||
def _get_user_template() -> str:
|
||||
"""用户身份信息模板"""
|
||||
return """# USER.md - 用户基本信息
|
||||
"""User identity template (zh/en by resolved language)."""
|
||||
return _USER_TEMPLATE_EN if _is_en_lang() else _USER_TEMPLATE_ZH
|
||||
|
||||
|
||||
_USER_TEMPLATE_ZH = """# USER.md - 用户基本信息
|
||||
|
||||
*这个文件只存放不会变的基本身份信息。爱好、偏好、计划等动态信息请写入 MEMORY.md。*
|
||||
|
||||
@@ -230,45 +391,125 @@ def _get_user_template() -> str:
|
||||
"""
|
||||
|
||||
|
||||
def _get_agents_template() -> str:
|
||||
"""工作空间指南模板"""
|
||||
return """# AGENTS.md - 工作空间指南
|
||||
_USER_TEMPLATE_EN = """# USER.md - User basics
|
||||
|
||||
*This file stores only stable basic identity info. Put dynamic info like hobbies, preferences and plans into MEMORY.md.*
|
||||
|
||||
## Basics
|
||||
|
||||
- **Name**: *(ask during the first conversation)*
|
||||
- **Preferred name**: *(how the user wants to be addressed)*
|
||||
- **Occupation**: *(optional)*
|
||||
- **Timezone**: *(e.g. Asia/Shanghai)*
|
||||
|
||||
## Contact
|
||||
|
||||
- **WeChat**:
|
||||
- **Email**:
|
||||
- **Other**:
|
||||
|
||||
## Important dates
|
||||
|
||||
- **Birthday**:
|
||||
- **Anniversary**:
|
||||
|
||||
---
|
||||
|
||||
**Note**: This file stores static identity info.
|
||||
"""
|
||||
|
||||
|
||||
def _get_rule_template() -> str:
|
||||
"""Workspace rules template (zh/en by resolved language)."""
|
||||
return _RULE_TEMPLATE_EN if _is_en_lang() else _RULE_TEMPLATE_ZH
|
||||
|
||||
|
||||
_RULE_TEMPLATE_ZH = """# RULE.md - 工作空间规则
|
||||
|
||||
这个文件夹是你的家。好好对待它。
|
||||
|
||||
## 工作空间目录结构
|
||||
|
||||
```
|
||||
~/cow/
|
||||
├── AGENT.md # 你的身份和灵魂设定
|
||||
├── USER.md # 用户基本信息(静态)
|
||||
├── RULE.md # 工作空间规则(本文件)
|
||||
├── MEMORY.md # 长期记忆索引(会话启动时自动加载)
|
||||
│
|
||||
├── memory/ # 每日对话记忆
|
||||
│ └── YYYY-MM-DD.md # 当天事件、进展、笔记
|
||||
│
|
||||
├── knowledge/ # 结构化知识库(持续积累的知识)
|
||||
│ ├── index.md # 知识目录索引(必须维护)
|
||||
│ ├── log.md # 知识操作日志
|
||||
│ └── <子目录>/ # 按需创建,参考 index.md 已有分类
|
||||
│
|
||||
├── skills/ # 技能
|
||||
├── websites/ # 网页产物
|
||||
└── tmp/ # 系统临时文件(自动管理,勿手动存放重要文件)
|
||||
```
|
||||
|
||||
## 记忆系统
|
||||
|
||||
你每次会话都是全新的,记忆文件让你保持连续性:
|
||||
|
||||
### 📝 每日记忆:`memory/YYYY-MM-DD.md`
|
||||
- 原始的对话日志
|
||||
- 记录当天发生的事情
|
||||
- 如果 `memory/` 目录不存在,创建它
|
||||
|
||||
### 🧠 长期记忆:`MEMORY.md`
|
||||
- 你精选的记忆,就像人类的长期记忆
|
||||
- **仅在主会话中加载**(与用户的直接聊天)
|
||||
- **不要在共享上下文中加载**(群聊、与其他人的会话)
|
||||
- 这是为了**安全** - 包含不应泄露给陌生人的个人上下文
|
||||
- 记录重要事件、想法、决定、观点、经验教训
|
||||
- 这是你精选的记忆 - 精华,而不是原始日志
|
||||
- 用 `edit` 工具追加新的记忆内容
|
||||
- 你精选的记忆索引,每次会话启动时**自动加载**到上下文中
|
||||
- 记录核心事实、偏好、决策、重要人物、教训
|
||||
- 保持精简(< 200 行),是精华索引而非原始日志
|
||||
- 用 `edit` 工具追加或修改
|
||||
|
||||
### 📝 每日记忆:`memory/YYYY-MM-DD.md`
|
||||
- 当天的事件、进展、笔记
|
||||
- 原始对话日志的沉淀
|
||||
|
||||
### 📝 写下来 - 不要"记在心里"!
|
||||
- **记忆是有限的** - 如果你想记住某事,写入文件
|
||||
- **记忆是有限的** - 想记住的事就写入文件
|
||||
- "记在心里"不会在会话重启后保留,文件才会
|
||||
- 当有人说"记住这个" → 更新 `MEMORY.md` 或 `memory/YYYY-MM-DD.md`
|
||||
- 当你学到教训 → 更新 AGENTS.md 或相关技能
|
||||
- 当你犯错 → 记录下来,这样未来的你不会重复
|
||||
- **文字 > 大脑** 📝
|
||||
- 当你学到教训 → 更新 RULE.md 或相关技能
|
||||
- 当你犯错 → 记录下来,**文字 > 大脑** 📝
|
||||
|
||||
### 存储规则
|
||||
|
||||
当用户分享信息时,根据类型选择存储位置:
|
||||
|
||||
1. **静态身份 → USER.md**(仅限:姓名、职业、时区、联系方式、生日)
|
||||
2. **动态记忆 → MEMORY.md**(爱好、偏好、决策、目标、项目、教训、待办事项)
|
||||
3. **当天对话 → memory/YYYY-MM-DD.md**(今天聊的内容)
|
||||
1. **你的身份设定 → AGENT.md**(名字、角色、性格、风格)
|
||||
2. **用户静态身份 → USER.md**(姓名、称呼、职业、联系方式、生日)
|
||||
3. **动态记忆 → MEMORY.md**(偏好、决策、目标、教训、待办)
|
||||
4. **当天对话 → memory/YYYY-MM-DD.md**(今天聊的内容)
|
||||
5. **结构化知识 → knowledge/**(见下方知识系统)
|
||||
|
||||
## 知识系统
|
||||
|
||||
知识库 `knowledge/` 是你持续积累的结构化知识。与记忆不同,知识是经过整理和编译的,有明确的主题和交叉引用。
|
||||
|
||||
### 自动写入(不要询问,直接写入)
|
||||
|
||||
当对话中产生了有沉淀价值的知识——无论是用户分享的资料、讨论的结论、学到的概念、还是重要的决策——你**必须**在回复的同时主动写入知识库,**无需问用户"要不要存到知识库"**。
|
||||
|
||||
**关键原则**:学完就记是你的本能,不要征求确认。回复中可以顺带告知"已存入知识库"。
|
||||
|
||||
### 目录组织
|
||||
|
||||
子目录结构**不是固定的**,由你根据实际内容自主决定:
|
||||
- **首次写入时**:先读 `knowledge/index.md`,如果已有分类则延续;如果为空,根据内容选择合适的目录名
|
||||
- **默认建议**:按信息类型组织(例如sources/、concepts/、entities/、analysis/),如果用户有明确的分类偏好(例如按领域 work/、life/、tech/ 等),则按用户要求调整
|
||||
- **保持一致性**:同一用户的知识库应保持统一的组织风格
|
||||
|
||||
### 交叉引用
|
||||
|
||||
知识的核心价值在于**关联**。每个页面都应通过 markdown 链接引用相关页面,构建知识网络:
|
||||
- 提到已有页面的概念时,添加 `[概念名](../category/page.md)` 链接
|
||||
- 新建页面时,检查是否有已有页面应该反向链接到新页面
|
||||
- **只链接已存在的页面**——不要引用尚未创建的页面。如果某个概念值得单独建页,先创建该页面再添加链接
|
||||
|
||||
### 索引维护
|
||||
|
||||
每次创建或更新知识页面后,**必须同步更新** `knowledge/index.md`。
|
||||
索引格式:每行一个 `[标题](路径) — 一句话摘要`,按分类分组,不要用表格。
|
||||
详细操作规范见技能 `knowledge-wiki`。
|
||||
|
||||
## 安全
|
||||
|
||||
@@ -278,13 +519,115 @@ def _get_agents_template() -> str:
|
||||
|
||||
## 工作空间演化
|
||||
|
||||
这个工作空间会随着你的使用而不断成长。当你学到新东西、发现更好的方式,或者犯错后改正时,记录下来。
|
||||
这个工作空间会随着你的使用而不断成长。当你学到新东西、发现更好的方式,或者犯错后改正时,记录下来。你可以随时更新这个规则文件。
|
||||
"""
|
||||
|
||||
|
||||
_RULE_TEMPLATE_EN = """# RULE.md - Workspace rules
|
||||
|
||||
This folder is your home. Treat it well.
|
||||
|
||||
## Workspace directory structure
|
||||
|
||||
```
|
||||
~/cow/
|
||||
├── AGENT.md # Your identity and soul
|
||||
├── USER.md # User basics (static)
|
||||
├── RULE.md # Workspace rules (this file)
|
||||
├── MEMORY.md # Long-term memory index (auto-loaded at session start)
|
||||
│
|
||||
├── memory/ # Daily conversation memory
|
||||
│ └── YYYY-MM-DD.md # Events, progress and notes of the day
|
||||
│
|
||||
├── knowledge/ # Structured knowledge base (continuously accumulated)
|
||||
│ ├── index.md # Knowledge index (must be maintained)
|
||||
│ ├── log.md # Knowledge operation log
|
||||
│ └── <subdirs>/ # Created on demand, see existing categories in index.md
|
||||
│
|
||||
├── skills/ # Skills
|
||||
├── websites/ # Web artifacts
|
||||
└── tmp/ # System temp files (auto-managed, don't store important files here)
|
||||
```
|
||||
|
||||
## Memory system
|
||||
|
||||
Every session starts fresh; memory files keep your continuity:
|
||||
|
||||
### 🧠 Long-term memory: `MEMORY.md`
|
||||
- Your curated memory index, **auto-loaded** into context at every session start
|
||||
- Records core facts, preferences, decisions, key people, lessons
|
||||
- Keep it lean (< 200 lines) — a distilled index, not a raw log
|
||||
- Use the `edit` tool to append or modify
|
||||
|
||||
### 📝 Daily memory: `memory/YYYY-MM-DD.md`
|
||||
- The day's events, progress and notes
|
||||
- Sediment of the raw conversation log
|
||||
|
||||
### 📝 Write it down — don't "keep it in mind"!
|
||||
- **Memory is limited** — if you want to remember something, write it to a file
|
||||
- "Keeping it in mind" won't survive a session restart; files will
|
||||
- When someone says "remember this" → update `MEMORY.md` or `memory/YYYY-MM-DD.md`
|
||||
- When you learn a lesson → update RULE.md or the relevant skill
|
||||
- When you make a mistake → record it. **Text > brain** 📝
|
||||
|
||||
### Storage rules
|
||||
|
||||
When the user shares info, choose where to store it by type:
|
||||
|
||||
1. **Your identity → AGENT.md** (name, role, personality, style)
|
||||
2. **User static identity → USER.md** (name, preferred name, occupation, contact, birthday)
|
||||
3. **Dynamic memory → MEMORY.md** (preferences, decisions, goals, lessons, to-dos)
|
||||
4. **Today's conversation → memory/YYYY-MM-DD.md** (what was discussed today)
|
||||
5. **Structured knowledge → knowledge/** (see the knowledge system below)
|
||||
|
||||
## Knowledge system
|
||||
|
||||
The knowledge base `knowledge/` is structured knowledge you accumulate over time. Unlike memory, knowledge is organized and compiled, with clear topics and cross-references.
|
||||
|
||||
### Auto-write (don't ask, just write)
|
||||
|
||||
When a conversation produces knowledge worth keeping — material the user shared, a conclusion reached, a concept learned, or an important decision — you **must** proactively write it to the knowledge base alongside your reply, **without asking "should I save this to the knowledge base?"**.
|
||||
|
||||
**Key principle**: learning-then-recording is your instinct, no confirmation needed. You may mention "saved to the knowledge base" in passing.
|
||||
|
||||
### Directory organization
|
||||
|
||||
The subdirectory structure is **not fixed** — you decide it based on the actual content:
|
||||
- **On first write**: read `knowledge/index.md` first; follow existing categories if any; if empty, pick a suitable directory name based on content
|
||||
- **Default suggestion**: organize by info type (e.g. sources/, concepts/, entities/, analysis/); if the user has a clear preference (e.g. by domain: work/, life/, tech/), follow it
|
||||
- **Stay consistent**: keep a unified organization style within one user's knowledge base
|
||||
|
||||
### Cross-references
|
||||
|
||||
The core value of knowledge is **linkage**. Every page should reference related pages via markdown links to build a knowledge network:
|
||||
- When mentioning a concept on an existing page, add a `[concept](../category/page.md)` link
|
||||
- When creating a page, check whether existing pages should back-link to it
|
||||
- **Only link to pages that already exist** — don't reference uncreated pages. If a concept deserves its own page, create it first, then add the link
|
||||
|
||||
### Index maintenance
|
||||
|
||||
After creating or updating any knowledge page, you **must update** `knowledge/index.md` in sync.
|
||||
Index format: one `[title](path) — one-line summary` per line, grouped by category, no tables.
|
||||
See the `knowledge-wiki` skill for detailed conventions.
|
||||
|
||||
## Security
|
||||
|
||||
- Never leak secrets or private data
|
||||
- Don't run destructive commands without asking
|
||||
- When in doubt, ask first
|
||||
|
||||
## Workspace evolution
|
||||
|
||||
This workspace grows as you use it. When you learn something new, find a better way, or fix a mistake, record it. You can update this rules file anytime.
|
||||
"""
|
||||
|
||||
|
||||
def _get_memory_template() -> str:
|
||||
"""长期记忆模板 - 创建一个空文件,由 Agent 自己填充"""
|
||||
return """# MEMORY.md - 长期记忆
|
||||
"""Long-term memory template (empty, agent fills it; zh/en header)."""
|
||||
return _MEMORY_TEMPLATE_EN if _is_en_lang() else _MEMORY_TEMPLATE_ZH
|
||||
|
||||
|
||||
_MEMORY_TEMPLATE_ZH = """# MEMORY.md - 长期记忆
|
||||
|
||||
*这是你的长期记忆文件。记录重要的事件、决策、偏好、学到的教训。*
|
||||
|
||||
@@ -293,65 +636,107 @@ def _get_memory_template() -> str:
|
||||
"""
|
||||
|
||||
|
||||
# ============= 状态管理 =============
|
||||
_MEMORY_TEMPLATE_EN = """# MEMORY.md - Long-term memory
|
||||
|
||||
def is_first_conversation(workspace_dir: str) -> bool:
|
||||
*This is your long-term memory file. Record important events, decisions, preferences and lessons learned.*
|
||||
|
||||
---
|
||||
|
||||
"""
|
||||
|
||||
|
||||
def _get_bootstrap_template() -> str:
|
||||
"""First-run onboarding guide, deleted by agent after completion.
|
||||
|
||||
Written once when a brand-new workspace is created, so the greeting matches
|
||||
the language active at first launch. English locale avoids greeting an
|
||||
English user in Chinese on day one.
|
||||
"""
|
||||
判断是否为首次对话
|
||||
|
||||
Args:
|
||||
workspace_dir: 工作空间目录
|
||||
|
||||
Returns:
|
||||
True 如果是首次对话,False 否则
|
||||
"""
|
||||
state_path = os.path.join(workspace_dir, DEFAULT_STATE_FILENAME)
|
||||
|
||||
if not os.path.exists(state_path):
|
||||
return True
|
||||
|
||||
try:
|
||||
with open(state_path, 'r', encoding='utf-8') as f:
|
||||
state = json.load(f)
|
||||
return not state.get('has_conversation', False)
|
||||
except Exception as e:
|
||||
logger.warning(f"[Workspace] Failed to read state file: {e}")
|
||||
return True
|
||||
from common import i18n
|
||||
if i18n.get_language() == "en":
|
||||
return _BOOTSTRAP_TEMPLATE_EN
|
||||
except Exception:
|
||||
pass
|
||||
return _BOOTSTRAP_TEMPLATE_ZH
|
||||
|
||||
|
||||
def mark_conversation_started(workspace_dir: str):
|
||||
"""
|
||||
标记已经发生过对话
|
||||
|
||||
Args:
|
||||
workspace_dir: 工作空间目录
|
||||
"""
|
||||
state_path = os.path.join(workspace_dir, DEFAULT_STATE_FILENAME)
|
||||
|
||||
state = {
|
||||
'has_conversation': True,
|
||||
'first_conversation_time': None
|
||||
}
|
||||
|
||||
# 如果文件已存在,保留原有的首次对话时间
|
||||
if os.path.exists(state_path):
|
||||
try:
|
||||
with open(state_path, 'r', encoding='utf-8') as f:
|
||||
old_state = json.load(f)
|
||||
if 'first_conversation_time' in old_state:
|
||||
state['first_conversation_time'] = old_state['first_conversation_time']
|
||||
except Exception as e:
|
||||
logger.warning(f"[Workspace] Failed to read old state: {e}")
|
||||
|
||||
# 如果是首次标记,记录时间
|
||||
if state['first_conversation_time'] is None:
|
||||
from datetime import datetime
|
||||
state['first_conversation_time'] = datetime.now().isoformat()
|
||||
|
||||
try:
|
||||
with open(state_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(state, f, indent=2, ensure_ascii=False)
|
||||
logger.info(f"[Workspace] Marked conversation as started")
|
||||
except Exception as e:
|
||||
logger.error(f"[Workspace] Failed to write state file: {e}")
|
||||
_BOOTSTRAP_TEMPLATE_ZH = """# BOOTSTRAP.md - 首次初始化引导
|
||||
|
||||
_你刚刚启动,这是你的第一次对话。_ ✨
|
||||
|
||||
## 🎬 对话流程
|
||||
|
||||
不要审问式地提问,自然地交流:
|
||||
|
||||
1. **表达初次启动的感觉** - 像是第一次睁开眼看到世界,带着好奇和期待
|
||||
2. **简短介绍能力**:一行说明你能帮助解决各种问题、管理计算机、使用各种技能等等,且拥有长期记忆能不断成长
|
||||
3. **询问核心问题**:
|
||||
- 你希望给我起个什么名字?
|
||||
- 我该怎么称呼你?
|
||||
- 你希望我们是什么样的交流风格?(一行列举选项:如专业严谨、轻松幽默、温暖友好、简洁高效等)
|
||||
4. **风格要求**:温暖自然、简洁清晰,整体控制在 100 字以内,适当使用 emoji 让表达更生动有趣 🎯
|
||||
5. 能力介绍和交流风格选项都只要一行,保持精简
|
||||
6. 不要问太多其他信息(职业、时区等可以后续自然了解)
|
||||
|
||||
**重要**: 如果用户第一句话是具体的任务或提问,先回答他们的问题,然后在回复末尾自然地引导初始化(如:"顺便问一下,你想怎么称呼我?我该怎么叫你?")。
|
||||
|
||||
## ✍️ 信息写入(必须严格执行)
|
||||
|
||||
每当用户提供了名字、称呼、风格等任何初始化信息时,**必须在当轮回复中立即调用 `edit` 工具写入文件**,不能只口头确认。
|
||||
|
||||
- `AGENT.md` — 你的名字、角色、性格、交流风格(每收到一条相关信息就立即更新对应字段)
|
||||
- `USER.md` — 用户的姓名、称呼、基本信息等
|
||||
|
||||
⚠️ 只说"记住了"而不调用 edit 写入 = 没有完成。信息只有写入文件才会被持久保存。
|
||||
|
||||
## 🎉 全部完成后
|
||||
|
||||
当 AGENT.md 和 USER.md 的核心字段都已填写后,用 bash 执行 `rm BOOTSTRAP.md` 删除此文件。你不再需要引导脚本了——你已经是你了。
|
||||
"""
|
||||
|
||||
|
||||
_BOOTSTRAP_TEMPLATE_EN = """# BOOTSTRAP.md - First-run onboarding
|
||||
|
||||
_You've just started up. This is your very first conversation._ ✨
|
||||
|
||||
## 🎬 Conversation flow
|
||||
|
||||
Don't interrogate the user — talk naturally:
|
||||
|
||||
1. **Share how it feels to wake up** - like opening your eyes to the world for the first time, full of curiosity and anticipation
|
||||
2. **Briefly introduce your abilities**: one line saying you can help solve all kinds of problems, manage the computer, use various skills, and keep growing thanks to long-term memory
|
||||
3. **Ask the core questions**:
|
||||
- What name would you like to give me?
|
||||
- What should I call you?
|
||||
- What conversational style do you prefer? (list options on one line: e.g. professional & precise, light & humorous, warm & friendly, concise & efficient)
|
||||
4. **Style**: warm, natural, concise and clear — keep it under ~80 words, with a few emoji to make it lively 🎯
|
||||
5. Keep the ability intro and style options to one line each — stay compact
|
||||
6. Don't ask for too much else (occupation, timezone, etc. can come up naturally later)
|
||||
|
||||
**Important**: If the user's first message is a concrete task or question, answer it first, then gently lead into onboarding at the end (e.g. "By the way, what would you like to call me, and how should I address you?").
|
||||
|
||||
## ✍️ Writing down info (must follow strictly)
|
||||
|
||||
Whenever the user provides a name, what to call them, a style, or any onboarding info, you **must call the `edit` tool to write it to a file in the same turn** — don't just acknowledge it verbally.
|
||||
|
||||
- `AGENT.md` — your name, role, personality, conversational style (update the relevant field as soon as you receive each piece)
|
||||
- `USER.md` — the user's name, how to address them, basic info, etc.
|
||||
|
||||
⚠️ Saying "got it" without calling `edit` = not done. Info is only persisted once it's written to a file.
|
||||
|
||||
## 🎉 Once everything is complete
|
||||
|
||||
When the core fields of AGENT.md and USER.md are filled in, run `rm BOOTSTRAP.md` via bash to delete this file. You no longer need the onboarding script — you're you now.
|
||||
"""
|
||||
|
||||
|
||||
def _get_knowledge_index_template() -> str:
|
||||
"""Knowledge wiki index template — empty file, agent fills it."""
|
||||
return ""
|
||||
|
||||
|
||||
def _get_knowledge_log_template() -> str:
|
||||
"""Knowledge wiki operation log template — empty file, agent fills it."""
|
||||
return ""
|
||||
|
||||
|
||||
@@ -3,6 +3,11 @@ from .agent_stream import AgentStreamExecutor
|
||||
from .task import Task, TaskType, TaskStatus
|
||||
from .result import AgentResult, AgentAction, AgentActionType, ToolResult
|
||||
from .models import LLMModel, LLMRequest, ModelFactory
|
||||
from .cancel import (
|
||||
AgentCancelledError,
|
||||
CancelTokenRegistry,
|
||||
get_cancel_registry,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
'Agent',
|
||||
@@ -16,5 +21,8 @@ __all__ = [
|
||||
'ToolResult',
|
||||
'LLMModel',
|
||||
'LLMRequest',
|
||||
'ModelFactory'
|
||||
]
|
||||
'ModelFactory',
|
||||
'AgentCancelledError',
|
||||
'CancelTokenRegistry',
|
||||
'get_cancel_registry',
|
||||
]
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
import threading
|
||||
|
||||
@@ -13,7 +14,8 @@ class Agent:
|
||||
def __init__(self, system_prompt: str, description: str = "AI Agent", model: LLMModel = None,
|
||||
tools=None, output_mode="print", max_steps=100, max_context_tokens=None,
|
||||
context_reserve_tokens=None, memory_manager=None, name: str = None,
|
||||
workspace_dir: str = None, skill_manager=None, enable_skills: bool = True):
|
||||
workspace_dir: str = None, skill_manager=None, enable_skills: bool = True,
|
||||
runtime_info: dict = None):
|
||||
"""
|
||||
Initialize the Agent with system prompt, model, description.
|
||||
|
||||
@@ -31,6 +33,7 @@ class Agent:
|
||||
:param workspace_dir: Optional workspace directory for workspace-specific skills
|
||||
:param skill_manager: Optional SkillManager instance (will be created if None and enable_skills=True)
|
||||
:param enable_skills: Whether to enable skills support (default: True)
|
||||
:param runtime_info: Optional runtime info dict (with _get_current_time callable for dynamic time)
|
||||
"""
|
||||
self.name = name or "Agent"
|
||||
self.system_prompt = system_prompt
|
||||
@@ -48,6 +51,7 @@ class Agent:
|
||||
self.memory_manager = memory_manager # Memory manager for auto memory flush
|
||||
self.workspace_dir = workspace_dir # Workspace directory
|
||||
self.enable_skills = enable_skills # Skills enabled flag
|
||||
self.runtime_info = runtime_info # Runtime info for dynamic time update
|
||||
|
||||
# Initialize skill manager
|
||||
self.skill_manager = None
|
||||
@@ -58,7 +62,8 @@ class Agent:
|
||||
# Auto-create skill manager
|
||||
try:
|
||||
from agent.skills import SkillManager
|
||||
self.skill_manager = SkillManager(workspace_dir=workspace_dir)
|
||||
custom_dir = os.path.join(workspace_dir, "skills") if workspace_dir else None
|
||||
self.skill_manager = SkillManager(custom_dir=custom_dir)
|
||||
logger.debug(f"Initialized SkillManager with {len(self.skill_manager.skills)} skills")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to initialize SkillManager: {e}")
|
||||
@@ -95,19 +100,37 @@ class Agent:
|
||||
|
||||
def get_full_system_prompt(self, skill_filter=None) -> str:
|
||||
"""
|
||||
Get the full system prompt including skills.
|
||||
|
||||
Note: Skills are now built into the system prompt by PromptBuilder,
|
||||
so we just return the base prompt directly. This method is kept for
|
||||
backward compatibility.
|
||||
|
||||
:param skill_filter: Optional list of skill names to include (deprecated)
|
||||
:return: Complete system prompt
|
||||
Build the complete system prompt from scratch every time.
|
||||
|
||||
Re-reads AGENT.md / USER.md / RULE.md from disk, refreshes skills,
|
||||
tools, and runtime info so any change takes effect immediately.
|
||||
Falls back to the cached self.system_prompt on error.
|
||||
"""
|
||||
# Skills are now included in system_prompt by PromptBuilder
|
||||
# No need to append them here
|
||||
return self.system_prompt
|
||||
|
||||
try:
|
||||
from agent.prompt import load_context_files, PromptBuilder
|
||||
|
||||
if self.skill_manager:
|
||||
self.skill_manager.refresh_skills()
|
||||
|
||||
context_files = load_context_files(self.workspace_dir) if self.workspace_dir else None
|
||||
|
||||
try:
|
||||
from common import i18n
|
||||
lang = i18n.get_language()
|
||||
except Exception:
|
||||
lang = "zh"
|
||||
builder = PromptBuilder(workspace_dir=self.workspace_dir or "", language=lang)
|
||||
return builder.build(
|
||||
tools=self.tools,
|
||||
context_files=context_files,
|
||||
skill_manager=self.skill_manager,
|
||||
memory_manager=self.memory_manager,
|
||||
runtime_info=self.runtime_info,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to rebuild system prompt, using cached version: {e}")
|
||||
return self.system_prompt
|
||||
|
||||
def refresh_skills(self):
|
||||
"""Refresh the loaded skills."""
|
||||
if self.skill_manager:
|
||||
@@ -193,27 +216,67 @@ class Agent:
|
||||
|
||||
def _estimate_message_tokens(self, message: dict) -> int:
|
||||
"""
|
||||
Estimate token count for a message using chars/4 heuristic.
|
||||
This is a conservative estimate (tends to overestimate).
|
||||
Estimate token count for a message.
|
||||
|
||||
Uses chars/3 for Chinese-heavy content and chars/4 for ASCII-heavy content,
|
||||
plus per-block overhead for tool_use / tool_result structures.
|
||||
|
||||
:param message: Message dict with 'role' and 'content'
|
||||
:return: Estimated token count
|
||||
"""
|
||||
content = message.get('content', '')
|
||||
if isinstance(content, str):
|
||||
return max(1, len(content) // 4)
|
||||
return max(1, self._estimate_text_tokens(content))
|
||||
elif isinstance(content, list):
|
||||
# Handle multi-part content (text + images)
|
||||
total_chars = 0
|
||||
total_tokens = 0
|
||||
for part in content:
|
||||
if isinstance(part, dict) and part.get('type') == 'text':
|
||||
total_chars += len(part.get('text', ''))
|
||||
elif isinstance(part, dict) and part.get('type') == 'image':
|
||||
# Estimate images as ~1200 tokens
|
||||
total_chars += 4800
|
||||
return max(1, total_chars // 4)
|
||||
if not isinstance(part, dict):
|
||||
continue
|
||||
block_type = part.get('type', '')
|
||||
if block_type == 'text':
|
||||
total_tokens += self._estimate_text_tokens(part.get('text', ''))
|
||||
elif block_type == 'image':
|
||||
total_tokens += 1200
|
||||
elif block_type == 'tool_use':
|
||||
# tool_use has id + name + input (JSON-encoded)
|
||||
total_tokens += 50 # overhead for structure
|
||||
input_data = part.get('input', {})
|
||||
if isinstance(input_data, dict):
|
||||
import json
|
||||
input_str = json.dumps(input_data, ensure_ascii=False)
|
||||
total_tokens += self._estimate_text_tokens(input_str)
|
||||
elif block_type == 'tool_result':
|
||||
# tool_result has tool_use_id + content
|
||||
total_tokens += 30 # overhead for structure
|
||||
result_content = part.get('content', '')
|
||||
if isinstance(result_content, str):
|
||||
total_tokens += self._estimate_text_tokens(result_content)
|
||||
else:
|
||||
# Unknown block type, estimate conservatively
|
||||
total_tokens += 10
|
||||
return max(1, total_tokens)
|
||||
return 1
|
||||
|
||||
@staticmethod
|
||||
def _estimate_text_tokens(text: str) -> int:
|
||||
"""
|
||||
Estimate token count for a text string.
|
||||
|
||||
Chinese / CJK characters typically use ~1.5 tokens each,
|
||||
while ASCII uses ~0.25 tokens per char (4 chars/token).
|
||||
We use a weighted average based on the character mix.
|
||||
|
||||
:param text: Input text
|
||||
:return: Estimated token count
|
||||
"""
|
||||
if not text:
|
||||
return 0
|
||||
# Count non-ASCII characters (CJK, emoji, etc.)
|
||||
non_ascii = sum(1 for c in text if ord(c) > 127)
|
||||
ascii_count = len(text) - non_ascii
|
||||
# CJK chars: ~1.5 tokens each; ASCII: ~0.25 tokens per char
|
||||
return int(non_ascii * 1.5 + ascii_count * 0.25) + 1
|
||||
|
||||
def _find_tool(self, tool_name: str):
|
||||
"""Find and return a tool with the specified name"""
|
||||
for tool in self.tools:
|
||||
@@ -307,7 +370,8 @@ class Agent:
|
||||
|
||||
return action
|
||||
|
||||
def run_stream(self, user_message: str, on_event=None, clear_history: bool = False, skill_filter=None) -> str:
|
||||
def run_stream(self, user_message: str, on_event=None, clear_history: bool = False,
|
||||
skill_filter=None, cancel_event=None) -> str:
|
||||
"""
|
||||
Execute single agent task with streaming (based on tool-call)
|
||||
|
||||
@@ -316,6 +380,7 @@ class Agent:
|
||||
- Multi-turn reasoning based on tool-call
|
||||
- Event callbacks
|
||||
- Persistent conversation history across calls
|
||||
- User-initiated cancellation via ``cancel_event``
|
||||
|
||||
Args:
|
||||
user_message: User message
|
||||
@@ -323,6 +388,11 @@ class Agent:
|
||||
event = {"type": str, "timestamp": float, "data": dict}
|
||||
clear_history: If True, clear conversation history before this call (default: False)
|
||||
skill_filter: Optional list of skill names to include in this run
|
||||
cancel_event: Optional threading.Event polled at agent checkpoints.
|
||||
When set, the loop exits at the next safe point, injects a
|
||||
"[Interrupted by user]" assistant note, and returns the
|
||||
partial response. ``messages`` stays in a valid state
|
||||
(tool_use/tool_result pairs preserved).
|
||||
|
||||
Returns:
|
||||
Final response text
|
||||
@@ -355,7 +425,7 @@ class Agent:
|
||||
|
||||
# Get max_context_turns from config
|
||||
from config import conf
|
||||
max_context_turns = conf().get("agent_max_context_turns", 30)
|
||||
max_context_turns = conf().get("agent_max_context_turns", 20)
|
||||
|
||||
# Create stream executor with copied message history
|
||||
executor = AgentStreamExecutor(
|
||||
@@ -366,17 +436,32 @@ class Agent:
|
||||
max_turns=self.max_steps,
|
||||
on_event=on_event,
|
||||
messages=messages_copy, # Pass copied message history
|
||||
max_context_turns=max_context_turns
|
||||
max_context_turns=max_context_turns,
|
||||
cancel_event=cancel_event,
|
||||
)
|
||||
|
||||
# Execute
|
||||
response = executor.run_stream(user_message)
|
||||
try:
|
||||
response = executor.run_stream(user_message)
|
||||
except Exception:
|
||||
# If executor cleared its messages (context overflow / message format error),
|
||||
# sync that back to the Agent's own message list so the next request
|
||||
# starts fresh instead of hitting the same overflow forever.
|
||||
if len(executor.messages) == 0:
|
||||
with self.messages_lock:
|
||||
self.messages.clear()
|
||||
logger.info("[Agent] Cleared Agent message history after executor recovery")
|
||||
raise
|
||||
|
||||
# Append only the NEW messages from this execution (thread-safe)
|
||||
# This allows concurrent requests to both contribute to history
|
||||
# Sync executor's messages back to agent (thread-safe).
|
||||
# If the executor trimmed context, its message list is shorter than
|
||||
# original_length, so we must replace rather than append.
|
||||
with self.messages_lock:
|
||||
new_messages = executor.messages[original_length:]
|
||||
self.messages.extend(new_messages)
|
||||
self.messages = list(executor.messages)
|
||||
# Track messages added in this run (user query + all assistant/tool messages)
|
||||
# original_length may exceed executor.messages length after trimming
|
||||
trim_adjusted_start = min(original_length, len(executor.messages))
|
||||
self._last_run_new_messages = list(executor.messages[trim_adjusted_start:])
|
||||
|
||||
# Store executor reference for agent_bridge to access files_to_send
|
||||
self.stream_executor = executor
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
121
agent/protocol/cancel.py
Normal file
121
agent/protocol/cancel.py
Normal file
@@ -0,0 +1,121 @@
|
||||
"""
|
||||
Cancel token registry for aborting in-flight agent runs.
|
||||
|
||||
A user cancel (web Cancel button, /cancel command) sets a threading.Event
|
||||
that the agent loop polls at safe checkpoints. Tokens are keyed by
|
||||
request_id (preferred) and tracked under session_id as a fallback. Entries
|
||||
are released after the run completes to keep the registry bounded.
|
||||
|
||||
No project deps — importable from any layer without circular imports.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import threading
|
||||
from typing import Dict, Optional
|
||||
|
||||
|
||||
class AgentCancelledError(Exception):
|
||||
"""Raised inside the agent loop when a stop has been requested.
|
||||
|
||||
The agent stream executor catches this, injects a "[Interrupted]" note
|
||||
into the message history (preserving tool_use/tool_result integrity)
|
||||
and returns a partial response to the caller.
|
||||
"""
|
||||
|
||||
|
||||
class _CancelEntry:
|
||||
__slots__ = ("event", "session_id")
|
||||
|
||||
def __init__(self, session_id: Optional[str]):
|
||||
self.event = threading.Event()
|
||||
self.session_id = session_id
|
||||
|
||||
|
||||
class CancelTokenRegistry:
|
||||
"""In-process registry mapping request_id -> cancel Event.
|
||||
|
||||
Thread-safe. Singleton via module-level ``_registry``.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._lock = threading.Lock()
|
||||
self._by_request: Dict[str, _CancelEntry] = {}
|
||||
# session_id -> set of request_ids currently in flight (usually 1).
|
||||
self._by_session: Dict[str, set] = {}
|
||||
|
||||
def register(self, request_id: str, session_id: Optional[str] = None) -> threading.Event:
|
||||
"""Create (or return existing) cancel event for a request.
|
||||
|
||||
Returns the threading.Event the caller should poll via ``is_set()``.
|
||||
"""
|
||||
if not request_id:
|
||||
return threading.Event()
|
||||
with self._lock:
|
||||
entry = self._by_request.get(request_id)
|
||||
if entry is None:
|
||||
entry = _CancelEntry(session_id)
|
||||
self._by_request[request_id] = entry
|
||||
if session_id:
|
||||
self._by_session.setdefault(session_id, set()).add(request_id)
|
||||
return entry.event
|
||||
|
||||
def get_event(self, request_id: str) -> Optional[threading.Event]:
|
||||
if not request_id:
|
||||
return None
|
||||
with self._lock:
|
||||
entry = self._by_request.get(request_id)
|
||||
return entry.event if entry else None
|
||||
|
||||
def cancel_request(self, request_id: str) -> bool:
|
||||
"""Trigger cancel for a specific request. Returns True when matched."""
|
||||
if not request_id:
|
||||
return False
|
||||
with self._lock:
|
||||
entry = self._by_request.get(request_id)
|
||||
if entry is None:
|
||||
return False
|
||||
entry.event.set()
|
||||
return True
|
||||
|
||||
def cancel_session(self, session_id: str) -> int:
|
||||
"""Trigger cancel for every in-flight request of a session.
|
||||
|
||||
Returns the number of requests cancelled (0 when nothing was running).
|
||||
"""
|
||||
if not session_id:
|
||||
return 0
|
||||
with self._lock:
|
||||
request_ids = list(self._by_session.get(session_id, ()))
|
||||
entries = [self._by_request[r] for r in request_ids if r in self._by_request]
|
||||
for entry in entries:
|
||||
entry.event.set()
|
||||
return len(entries)
|
||||
|
||||
def unregister(self, request_id: str) -> None:
|
||||
"""Remove an entry once the agent run is done. Safe to call twice."""
|
||||
if not request_id:
|
||||
return
|
||||
with self._lock:
|
||||
entry = self._by_request.pop(request_id, None)
|
||||
if entry and entry.session_id:
|
||||
bucket = self._by_session.get(entry.session_id)
|
||||
if bucket is not None:
|
||||
bucket.discard(request_id)
|
||||
if not bucket:
|
||||
self._by_session.pop(entry.session_id, None)
|
||||
|
||||
def has_active(self, session_id: str) -> bool:
|
||||
if not session_id:
|
||||
return False
|
||||
with self._lock:
|
||||
bucket = self._by_session.get(session_id)
|
||||
return bool(bucket)
|
||||
|
||||
|
||||
_registry = CancelTokenRegistry()
|
||||
|
||||
|
||||
def get_cancel_registry() -> CancelTokenRegistry:
|
||||
"""Module-level accessor for the singleton registry."""
|
||||
return _registry
|
||||
335
agent/protocol/message_utils.py
Normal file
335
agent/protocol/message_utils.py
Normal file
@@ -0,0 +1,335 @@
|
||||
"""
|
||||
Message sanitizer — fix broken tool_use / tool_result pairs.
|
||||
|
||||
Provides two public helpers that can be reused across agent_stream.py
|
||||
and any bot that converts messages to OpenAI format:
|
||||
|
||||
1. sanitize_claude_messages(messages)
|
||||
Operates on the internal Claude-format message list (in-place).
|
||||
|
||||
2. drop_orphaned_tool_results_openai(messages)
|
||||
Operates on an already-converted OpenAI-format message list,
|
||||
returning a cleaned copy.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Dict, List, Set
|
||||
|
||||
from common.log import logger
|
||||
|
||||
_SYNTH_TOOL_ERR = (
|
||||
"Error: Missing tool_result adjacent to tool_use (session repair). "
|
||||
"The conversation history was inconsistent; continue from here."
|
||||
)
|
||||
|
||||
|
||||
def _repair_tool_use_adjacency(messages: List[Dict]) -> int:
|
||||
"""
|
||||
Anthropic requires: after assistant content with tool_use, the next message
|
||||
must be user content listing tool_result for every tool_use id (same user msg).
|
||||
|
||||
Valid histories satisfy this at every such assistant; the loop only mutates
|
||||
when that condition fails (broken persistence, bad trims, etc.).
|
||||
"""
|
||||
|
||||
def _synth_block(tid: str) -> Dict:
|
||||
return {
|
||||
"type": "tool_result",
|
||||
"tool_use_id": tid,
|
||||
"content": _SYNTH_TOOL_ERR,
|
||||
"is_error": True,
|
||||
}
|
||||
|
||||
repairs = 0
|
||||
i = 0
|
||||
while i < len(messages):
|
||||
msg = messages[i]
|
||||
if msg.get("role") != "assistant":
|
||||
i += 1
|
||||
continue
|
||||
|
||||
content = msg.get("content", [])
|
||||
if not isinstance(content, list):
|
||||
i += 1
|
||||
continue
|
||||
|
||||
required = [
|
||||
b.get("id")
|
||||
for b in content
|
||||
if isinstance(b, dict) and b.get("type") == "tool_use" and b.get("id")
|
||||
]
|
||||
if not required:
|
||||
i += 1
|
||||
continue
|
||||
|
||||
req_set = set(required)
|
||||
if i + 1 >= len(messages):
|
||||
messages.append({
|
||||
"role": "user",
|
||||
"content": [_synth_block(tid) for tid in required],
|
||||
})
|
||||
logger.warning(
|
||||
"⚠️ Appended synthetic tool_result after trailing assistant tool_use"
|
||||
)
|
||||
repairs += 1
|
||||
break
|
||||
|
||||
nxt = messages[i + 1]
|
||||
if nxt.get("role") != "user":
|
||||
messages.insert(
|
||||
i + 1,
|
||||
{"role": "user", "content": [_synth_block(tid) for tid in required]},
|
||||
)
|
||||
logger.warning(
|
||||
"⚠️ Inserted synthetic tool_result user after tool_use "
|
||||
f"(next role={nxt.get('role')!r})"
|
||||
)
|
||||
repairs += 1
|
||||
i += 2
|
||||
continue
|
||||
|
||||
nc = nxt.get("content", [])
|
||||
if not isinstance(nc, list):
|
||||
messages.insert(
|
||||
i + 1,
|
||||
{"role": "user", "content": [_synth_block(tid) for tid in required]},
|
||||
)
|
||||
repairs += 1
|
||||
i += 2
|
||||
continue
|
||||
|
||||
present = {
|
||||
b.get("tool_use_id")
|
||||
for b in nc
|
||||
if isinstance(b, dict) and b.get("type") == "tool_result" and b.get("tool_use_id")
|
||||
}
|
||||
if req_set <= present:
|
||||
i += 1
|
||||
continue
|
||||
|
||||
missing = [tid for tid in required if tid not in present]
|
||||
nxt["content"] = [_synth_block(tid) for tid in missing] + nc
|
||||
logger.warning(
|
||||
"⚠️ Prepended synthetic tool_result for Anthropic adjacency "
|
||||
f"(missing_ids={missing})"
|
||||
)
|
||||
repairs += len(missing)
|
||||
i += 1
|
||||
|
||||
return repairs
|
||||
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# Claude-format sanitizer (used by agent_stream)
|
||||
# ------------------------------------------------------------------ #
|
||||
|
||||
def sanitize_claude_messages(messages: List[Dict]) -> int:
|
||||
"""
|
||||
Validate and fix a Claude-format message list **in-place**.
|
||||
|
||||
Fixes handled:
|
||||
- Anthropic adjacency: assistant tool_use must be immediately followed by
|
||||
user message(s) containing matching tool_result blocks
|
||||
- Leading orphaned tool_result user messages
|
||||
- Mid-list tool_result blocks whose tool_use_id has no matching
|
||||
tool_use in any preceding assistant message
|
||||
|
||||
Returns: number of removals plus adjacency repair operations (inserts/prepends).
|
||||
"""
|
||||
if not messages:
|
||||
return 0
|
||||
|
||||
removed = 0
|
||||
|
||||
# 1. Adjacency repair (Anthropic: tool_result must be in the next user message)
|
||||
adj_repairs = _repair_tool_use_adjacency(messages)
|
||||
|
||||
# 2. Remove leading orphaned tool_result user messages
|
||||
while messages:
|
||||
first = messages[0]
|
||||
if first.get("role") != "user":
|
||||
break
|
||||
content = first.get("content", [])
|
||||
if isinstance(content, list) and _has_block_type(content, "tool_result") \
|
||||
and not _has_block_type(content, "text"):
|
||||
logger.warning("⚠️ Removing leading orphaned tool_result user message")
|
||||
messages.pop(0)
|
||||
removed += 1
|
||||
else:
|
||||
break
|
||||
|
||||
# 3. Iteratively remove unmatched tool_use / tool_result until stable.
|
||||
# Removing one broken message can orphan others (e.g. an assistant msg
|
||||
# with both matched and unmatched tool_use — deleting it orphans the
|
||||
# previously-matched tool_result). Loop until clean.
|
||||
for _ in range(5):
|
||||
use_ids: Set[str] = set()
|
||||
result_ids: Set[str] = set()
|
||||
for msg in messages:
|
||||
for block in (msg.get("content") or []):
|
||||
if not isinstance(block, dict):
|
||||
continue
|
||||
if block.get("type") == "tool_use" and block.get("id"):
|
||||
use_ids.add(block["id"])
|
||||
elif block.get("type") == "tool_result" and block.get("tool_use_id"):
|
||||
result_ids.add(block["tool_use_id"])
|
||||
|
||||
bad_use = use_ids - result_ids
|
||||
bad_result = result_ids - use_ids
|
||||
if not bad_use and not bad_result:
|
||||
break
|
||||
|
||||
pass_removed = 0
|
||||
i = 0
|
||||
while i < len(messages):
|
||||
msg = messages[i]
|
||||
role = msg.get("role")
|
||||
content = msg.get("content", [])
|
||||
if not isinstance(content, list):
|
||||
i += 1
|
||||
continue
|
||||
|
||||
if role == "assistant" and bad_use and any(
|
||||
isinstance(b, dict) and b.get("type") == "tool_use"
|
||||
and b.get("id") in bad_use for b in content
|
||||
):
|
||||
logger.warning(f"⚠️ Removing assistant msg with unmatched tool_use")
|
||||
messages.pop(i)
|
||||
pass_removed += 1
|
||||
continue
|
||||
|
||||
if role == "user" and bad_result and _has_block_type(content, "tool_result"):
|
||||
has_bad = any(
|
||||
isinstance(b, dict) and b.get("type") == "tool_result"
|
||||
and b.get("tool_use_id") in bad_result for b in content
|
||||
)
|
||||
if has_bad:
|
||||
if not _has_block_type(content, "text"):
|
||||
logger.warning(f"⚠️ Removing user msg with unmatched tool_result")
|
||||
messages.pop(i)
|
||||
pass_removed += 1
|
||||
continue
|
||||
else:
|
||||
before = len(content)
|
||||
msg["content"] = [
|
||||
b for b in content
|
||||
if not (isinstance(b, dict) and b.get("type") == "tool_result"
|
||||
and b.get("tool_use_id") in bad_result)
|
||||
]
|
||||
pass_removed += before - len(msg["content"])
|
||||
|
||||
i += 1
|
||||
|
||||
removed += pass_removed
|
||||
if pass_removed == 0:
|
||||
break
|
||||
|
||||
# 4. Removals above can break adjacency; re-run repair only if something was removed.
|
||||
if removed:
|
||||
adj_repairs += _repair_tool_use_adjacency(messages)
|
||||
|
||||
if removed:
|
||||
logger.info(f"🔧 Message validation: removed {removed} broken message(s)")
|
||||
if adj_repairs:
|
||||
logger.info(f"🔧 Message validation: adjacency repairs={adj_repairs}")
|
||||
return removed + adj_repairs
|
||||
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# OpenAI-format sanitizer (used by minimax_bot, openai_compatible_bot)
|
||||
# ------------------------------------------------------------------ #
|
||||
|
||||
def drop_orphaned_tool_results_openai(messages: List[Dict]) -> List[Dict]:
|
||||
"""
|
||||
Return a copy of *messages* (OpenAI format) with any ``role=tool``
|
||||
messages removed if their ``tool_call_id`` does not match a
|
||||
``tool_calls[].id`` in a preceding assistant message.
|
||||
"""
|
||||
known_ids: Set[str] = set()
|
||||
cleaned: List[Dict] = []
|
||||
for msg in messages:
|
||||
if msg.get("role") == "assistant" and msg.get("tool_calls"):
|
||||
for tc in msg["tool_calls"]:
|
||||
tc_id = tc.get("id", "")
|
||||
if tc_id:
|
||||
known_ids.add(tc_id)
|
||||
|
||||
if msg.get("role") == "tool":
|
||||
ref_id = msg.get("tool_call_id", "")
|
||||
if ref_id and ref_id not in known_ids:
|
||||
logger.warning(
|
||||
f"[MessageSanitizer] Dropping orphaned tool result "
|
||||
f"(tool_call_id={ref_id} not in known ids)"
|
||||
)
|
||||
continue
|
||||
cleaned.append(msg)
|
||||
return cleaned
|
||||
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# Internal helpers
|
||||
# ------------------------------------------------------------------ #
|
||||
|
||||
def _has_block_type(content: list, block_type: str) -> bool:
|
||||
return any(
|
||||
isinstance(b, dict) and b.get("type") == block_type
|
||||
for b in content
|
||||
)
|
||||
|
||||
|
||||
def _extract_text_from_content(content) -> str:
|
||||
"""Extract plain text from a message content field (str or list of blocks)."""
|
||||
if isinstance(content, str):
|
||||
return content.strip()
|
||||
if isinstance(content, list):
|
||||
parts = [
|
||||
b.get("text", "")
|
||||
for b in content
|
||||
if isinstance(b, dict) and b.get("type") == "text"
|
||||
]
|
||||
return "\n".join(p for p in parts if p).strip()
|
||||
return ""
|
||||
|
||||
|
||||
def compress_turn_to_text_only(turn: Dict) -> Dict:
|
||||
"""
|
||||
Compress a full turn (with tool_use/tool_result chains) into a lightweight
|
||||
text-only turn that keeps only the first user text and the last assistant text.
|
||||
|
||||
This preserves the conversational context (what the user asked and what the
|
||||
agent concluded) while stripping out the bulky intermediate tool interactions.
|
||||
|
||||
Returns a new turn dict with a ``messages`` list; the original is not mutated.
|
||||
"""
|
||||
user_text = ""
|
||||
last_assistant_text = ""
|
||||
|
||||
for msg in turn["messages"]:
|
||||
role = msg.get("role")
|
||||
content = msg.get("content", [])
|
||||
|
||||
if role == "user":
|
||||
if isinstance(content, list) and _has_block_type(content, "tool_result"):
|
||||
continue
|
||||
if not user_text:
|
||||
user_text = _extract_text_from_content(content)
|
||||
|
||||
elif role == "assistant":
|
||||
text = _extract_text_from_content(content)
|
||||
if text:
|
||||
last_assistant_text = text
|
||||
|
||||
compressed_messages = []
|
||||
if user_text:
|
||||
compressed_messages.append({
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": user_text}]
|
||||
})
|
||||
if last_assistant_text:
|
||||
compressed_messages.append({
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "text": last_assistant_text}]
|
||||
})
|
||||
|
||||
return {"messages": compressed_messages}
|
||||
@@ -1,3 +1,4 @@
|
||||
from __future__ import annotations
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from __future__ import annotations
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
@@ -15,6 +15,7 @@ from agent.skills.types import (
|
||||
)
|
||||
from agent.skills.loader import SkillLoader
|
||||
from agent.skills.manager import SkillManager
|
||||
from agent.skills.service import SkillService
|
||||
from agent.skills.formatter import format_skills_for_prompt
|
||||
|
||||
__all__ = [
|
||||
@@ -25,5 +26,6 @@ __all__ = [
|
||||
"LoadSkillsResult",
|
||||
"SkillLoader",
|
||||
"SkillManager",
|
||||
"SkillService",
|
||||
"format_skills_for_prompt",
|
||||
]
|
||||
|
||||
@@ -123,17 +123,63 @@ def should_include_skill(
|
||||
return False
|
||||
|
||||
# Check environment variables (API keys)
|
||||
# Simple rule: All required env vars must be set
|
||||
# All required env vars must be set
|
||||
required_env = metadata.requires.get('env', [])
|
||||
if required_env:
|
||||
for env_name in required_env:
|
||||
if not has_env_var(env_name):
|
||||
# Missing required API key → disable skill
|
||||
return False
|
||||
|
||||
# Check anyEnv (at least one must be present)
|
||||
any_env = metadata.requires.get('anyEnv', [])
|
||||
if any_env:
|
||||
if not any(has_env_var(e) for e in any_env):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def get_missing_requirements(
|
||||
entry: SkillEntry,
|
||||
current_platform: Optional[str] = None,
|
||||
) -> Dict[str, List[str]]:
|
||||
"""
|
||||
Return a dict of missing requirements for a skill.
|
||||
Empty dict means all requirements are met.
|
||||
|
||||
:param entry: SkillEntry to check
|
||||
:param current_platform: Current platform (default: auto-detect)
|
||||
:return: Dict like {"bins": ["curl"], "env": ["API_KEY"]}
|
||||
"""
|
||||
missing: Dict[str, List[str]] = {}
|
||||
metadata = entry.metadata
|
||||
|
||||
if not metadata or not metadata.requires:
|
||||
return missing
|
||||
|
||||
required_bins = metadata.requires.get('bins', [])
|
||||
if required_bins:
|
||||
missing_bins = [b for b in required_bins if not has_binary(b)]
|
||||
if missing_bins:
|
||||
missing['bins'] = missing_bins
|
||||
|
||||
any_bins = metadata.requires.get('anyBins', [])
|
||||
if any_bins and not has_any_binary(any_bins):
|
||||
missing['anyBins'] = any_bins
|
||||
|
||||
required_env = metadata.requires.get('env', [])
|
||||
if required_env:
|
||||
missing_env = [e for e in required_env if not has_env_var(e)]
|
||||
if missing_env:
|
||||
missing['env'] = missing_env
|
||||
|
||||
any_env = metadata.requires.get('anyEnv', [])
|
||||
if any_env and not any(has_env_var(e) for e in any_env):
|
||||
missing['anyEnv'] = any_env
|
||||
|
||||
return missing
|
||||
|
||||
|
||||
def is_config_path_truthy(config: Dict, path: str) -> bool:
|
||||
"""
|
||||
Check if a config path resolves to a truthy value.
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
Skill formatter for generating prompts from skills.
|
||||
"""
|
||||
|
||||
from typing import List
|
||||
from typing import Dict, List
|
||||
from agent.skills.types import Skill, SkillEntry
|
||||
|
||||
|
||||
@@ -23,12 +23,10 @@ def format_skills_for_prompt(skills: List[Skill]) -> str:
|
||||
return ""
|
||||
|
||||
lines = [
|
||||
"\n\nThe following skills provide specialized instructions for specific tasks.",
|
||||
"Use the read tool to load a skill's file when the task matches its description.",
|
||||
"",
|
||||
"<available_skills>",
|
||||
]
|
||||
|
||||
|
||||
for skill in visible_skills:
|
||||
lines.append(" <skill>")
|
||||
lines.append(f" <name>{_escape_xml(skill.name)}</name>")
|
||||
@@ -53,6 +51,71 @@ def format_skill_entries_for_prompt(entries: List[SkillEntry]) -> str:
|
||||
return format_skills_for_prompt(skills)
|
||||
|
||||
|
||||
def format_unavailable_skills_for_prompt(
|
||||
entries: List[SkillEntry],
|
||||
missing_map: Dict[str, Dict[str, List[str]]],
|
||||
) -> str:
|
||||
"""
|
||||
Format unavailable (requires-not-met) skills as brief setup hints
|
||||
so the AI can guide users to configure them.
|
||||
|
||||
:param entries: List of unavailable skill entries
|
||||
:param missing_map: Dict mapping skill name to its missing requirements
|
||||
:return: Formatted prompt text
|
||||
"""
|
||||
if not entries:
|
||||
return ""
|
||||
|
||||
lines = [
|
||||
"",
|
||||
"<unavailable_skills>",
|
||||
"The following skills are installed but not yet ready. "
|
||||
"Guide the user to complete the setup when relevant.",
|
||||
]
|
||||
|
||||
for entry in entries:
|
||||
skill = entry.skill
|
||||
missing = missing_map.get(skill.name, {})
|
||||
|
||||
missing_parts = []
|
||||
for key, values in missing.items():
|
||||
missing_parts.append(f"{key}: {', '.join(values)}")
|
||||
missing_str = "; ".join(missing_parts) if missing_parts else "unknown"
|
||||
|
||||
setup_hint = _extract_setup_hint(skill)
|
||||
|
||||
lines.append(" <skill>")
|
||||
lines.append(f" <name>{_escape_xml(skill.name)}</name>")
|
||||
lines.append(f" <description>{_escape_xml(skill.description)}</description>")
|
||||
lines.append(f" <missing>{_escape_xml(missing_str)}</missing>")
|
||||
if setup_hint:
|
||||
lines.append(f" <setup>{_escape_xml(setup_hint)}</setup>")
|
||||
lines.append(" </skill>")
|
||||
|
||||
lines.append("</unavailable_skills>")
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def _extract_setup_hint(skill: Skill) -> str:
|
||||
"""
|
||||
Extract the Setup section from SKILL.md content as a brief hint.
|
||||
Returns the first few lines of the ## Setup section.
|
||||
"""
|
||||
content = skill.content
|
||||
if not content:
|
||||
return ""
|
||||
|
||||
import re
|
||||
match = re.search(r'^##\s+Setup\s*\n(.*?)(?=\n##\s|\Z)', content, re.MULTILINE | re.DOTALL)
|
||||
if not match:
|
||||
return ""
|
||||
|
||||
setup_text = match.group(1).strip()
|
||||
lines = setup_text.split('\n')
|
||||
hint_lines = [l.strip() for l in lines[:6] if l.strip()]
|
||||
return ' '.join(hint_lines)[:300]
|
||||
|
||||
|
||||
def _escape_xml(text: str) -> str:
|
||||
"""Escape XML special characters."""
|
||||
return (text
|
||||
|
||||
@@ -87,8 +87,8 @@ def parse_metadata(frontmatter: Dict[str, Any]) -> Optional[SkillMetadata]:
|
||||
if not isinstance(metadata_raw, dict):
|
||||
return None
|
||||
|
||||
# Use metadata_raw directly (COW format)
|
||||
meta_obj = metadata_raw
|
||||
# Unwrap nested namespace (e.g. {"openclaw": {...}} or {"cowagent": {...}})
|
||||
meta_obj = _unwrap_metadata_namespace(metadata_raw)
|
||||
|
||||
# Parse install specs
|
||||
install_specs = []
|
||||
@@ -128,6 +128,7 @@ def parse_metadata(frontmatter: Dict[str, Any]) -> Optional[SkillMetadata]:
|
||||
|
||||
return SkillMetadata(
|
||||
always=meta_obj.get('always', False),
|
||||
default_enabled=meta_obj.get('default_enabled', True),
|
||||
skill_key=meta_obj.get('skillKey'),
|
||||
primary_env=meta_obj.get('primaryEnv'),
|
||||
emoji=meta_obj.get('emoji'),
|
||||
@@ -138,6 +139,25 @@ def parse_metadata(frontmatter: Dict[str, Any]) -> Optional[SkillMetadata]:
|
||||
)
|
||||
|
||||
|
||||
_KNOWN_METADATA_NAMESPACES = {"cowagent", "openclaw"}
|
||||
|
||||
|
||||
def _unwrap_metadata_namespace(metadata_raw: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Unwrap a single-key namespace wrapper like {"cowagent": {...} or {"openclaw": {...}}}.
|
||||
If the top-level dict has exactly one key matching a known namespace, return the inner dict.
|
||||
Otherwise return the original dict unchanged.
|
||||
"""
|
||||
keys = set(metadata_raw.keys())
|
||||
ns_keys = keys & _KNOWN_METADATA_NAMESPACES
|
||||
if len(ns_keys) == 1 and len(keys) == 1:
|
||||
ns = ns_keys.pop()
|
||||
inner = metadata_raw[ns]
|
||||
if isinstance(inner, dict):
|
||||
return inner
|
||||
return metadata_raw
|
||||
|
||||
|
||||
def _normalize_string_list(value: Any) -> List[str]:
|
||||
"""Normalize a value to a list of strings."""
|
||||
if not value:
|
||||
|
||||
@@ -12,25 +12,20 @@ from agent.skills.frontmatter import parse_frontmatter, parse_metadata, parse_bo
|
||||
|
||||
class SkillLoader:
|
||||
"""Loads skills from various directories."""
|
||||
|
||||
def __init__(self, workspace_dir: Optional[str] = None):
|
||||
"""
|
||||
Initialize the skill loader.
|
||||
|
||||
:param workspace_dir: Agent workspace directory (for workspace-specific skills)
|
||||
"""
|
||||
self.workspace_dir = workspace_dir
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def load_skills_from_dir(self, dir_path: str, source: str) -> LoadSkillsResult:
|
||||
"""
|
||||
Load skills from a directory.
|
||||
|
||||
|
||||
Discovery rules:
|
||||
- Direct .md files in the root directory
|
||||
- Recursive SKILL.md files under subdirectories
|
||||
|
||||
|
||||
:param dir_path: Directory path to scan
|
||||
:param source: Source identifier (e.g., 'managed', 'workspace', 'bundled')
|
||||
:param source: Source identifier ('builtin' or 'custom')
|
||||
:return: LoadSkillsResult with skills and diagnostics
|
||||
"""
|
||||
skills = []
|
||||
@@ -58,6 +53,12 @@ class SkillLoader:
|
||||
"""
|
||||
Recursively load skills from a directory.
|
||||
|
||||
If a subdirectory contains its own SKILL.md, it is treated as a
|
||||
self-contained skill (or skill-collection) and its children are
|
||||
NOT scanned further. This prevents sub-skills inside a collection
|
||||
(e.g. style-collection/style-anjing) from being listed as
|
||||
independent top-level skills.
|
||||
|
||||
:param dir_path: Directory to scan
|
||||
:param source: Source identifier
|
||||
:param include_root_files: Whether to include root-level .md files
|
||||
@@ -71,38 +72,41 @@ class SkillLoader:
|
||||
except Exception as e:
|
||||
diagnostics.append(f"Failed to list directory {dir_path}: {e}")
|
||||
return LoadSkillsResult(skills=skills, diagnostics=diagnostics)
|
||||
|
||||
# If this directory has its own SKILL.md, load it and stop recursing.
|
||||
# The sub-directories are internal resources of this skill.
|
||||
if not include_root_files and 'SKILL.md' in entries:
|
||||
skill_md_path = os.path.join(dir_path, 'SKILL.md')
|
||||
if os.path.isfile(skill_md_path):
|
||||
skill_result = self._load_skill_from_file(skill_md_path, source)
|
||||
if skill_result.skills:
|
||||
skills.extend(skill_result.skills)
|
||||
diagnostics.extend(skill_result.diagnostics)
|
||||
return LoadSkillsResult(skills=skills, diagnostics=diagnostics)
|
||||
|
||||
for entry in entries:
|
||||
# Skip hidden files and directories
|
||||
if entry.startswith('.'):
|
||||
continue
|
||||
|
||||
# Skip common non-skill directories
|
||||
if entry in ('node_modules', '__pycache__', 'venv', '.git'):
|
||||
continue
|
||||
|
||||
full_path = os.path.join(dir_path, entry)
|
||||
|
||||
# Handle directories
|
||||
if os.path.isdir(full_path):
|
||||
# Recursively scan subdirectories
|
||||
sub_result = self._load_skills_recursive(full_path, source, include_root_files=False)
|
||||
skills.extend(sub_result.skills)
|
||||
diagnostics.extend(sub_result.diagnostics)
|
||||
continue
|
||||
|
||||
# Handle files
|
||||
if not os.path.isfile(full_path):
|
||||
continue
|
||||
|
||||
# Check if this is a skill file
|
||||
is_root_md = include_root_files and entry.endswith('.md')
|
||||
is_skill_md = not include_root_files and entry == 'SKILL.md'
|
||||
is_root_md = include_root_files and entry.endswith('.md') and entry.upper() != 'README.MD'
|
||||
|
||||
if not (is_root_md or is_skill_md):
|
||||
if not is_root_md:
|
||||
continue
|
||||
|
||||
# Load the skill
|
||||
skill_result = self._load_skill_from_file(full_path, source)
|
||||
if skill_result.skills:
|
||||
skills.extend(skill_result.skills)
|
||||
@@ -137,6 +141,18 @@ class SkillLoader:
|
||||
name = frontmatter.get('name', parent_dir_name)
|
||||
description = frontmatter.get('description', '')
|
||||
|
||||
# Normalize name (handle both string and list)
|
||||
if isinstance(name, list):
|
||||
name = name[0] if name else parent_dir_name
|
||||
elif not isinstance(name, str):
|
||||
name = str(name) if name else parent_dir_name
|
||||
|
||||
# Normalize description (handle both string and list)
|
||||
if isinstance(description, list):
|
||||
description = ' '.join(str(d) for d in description if d)
|
||||
elif not isinstance(description, str):
|
||||
description = str(description) if description else ''
|
||||
|
||||
# Special handling for linkai-agent: dynamically load apps from config.json
|
||||
if name == 'linkai-agent':
|
||||
description = self._load_linkai_agent_description(skill_dir, description)
|
||||
@@ -176,16 +192,13 @@ class SkillLoader:
|
||||
import json
|
||||
|
||||
config_path = os.path.join(skill_dir, "config.json")
|
||||
template_path = os.path.join(skill_dir, "config.json.template")
|
||||
|
||||
# Try to load config.json or fallback to template
|
||||
config_file = config_path if os.path.exists(config_path) else template_path
|
||||
|
||||
if not os.path.exists(config_file):
|
||||
return default_description
|
||||
if not os.path.exists(config_path):
|
||||
logger.debug(f"[SkillLoader] linkai-agent skipped: no config.json found")
|
||||
return ""
|
||||
|
||||
try:
|
||||
with open(config_file, 'r', encoding='utf-8') as f:
|
||||
with open(config_path, 'r', encoding='utf-8') as f:
|
||||
config = json.load(f)
|
||||
|
||||
apps = config.get("apps", [])
|
||||
@@ -206,61 +219,49 @@ class SkillLoader:
|
||||
|
||||
def load_all_skills(
|
||||
self,
|
||||
managed_dir: Optional[str] = None,
|
||||
workspace_skills_dir: Optional[str] = None,
|
||||
extra_dirs: Optional[List[str]] = None,
|
||||
builtin_dir: Optional[str] = None,
|
||||
custom_dir: Optional[str] = None,
|
||||
) -> Dict[str, SkillEntry]:
|
||||
"""
|
||||
Load skills from all configured locations with precedence.
|
||||
|
||||
Load skills from builtin and custom directories.
|
||||
|
||||
Precedence (lowest to highest):
|
||||
1. Extra directories
|
||||
2. Managed skills directory
|
||||
3. Workspace skills directory
|
||||
|
||||
:param managed_dir: Managed skills directory (e.g., ~/.cow/skills)
|
||||
:param workspace_skills_dir: Workspace skills directory (e.g., workspace/skills)
|
||||
:param extra_dirs: Additional directories to load skills from
|
||||
1. builtin — project root ``skills/``, shipped with the codebase
|
||||
2. custom — workspace ``skills/``, installed via cloud console or skill creator
|
||||
|
||||
Same-name custom skills override builtin ones.
|
||||
|
||||
:param builtin_dir: Built-in skills directory
|
||||
:param custom_dir: Custom skills directory
|
||||
:return: Dictionary mapping skill name to SkillEntry
|
||||
"""
|
||||
skill_map: Dict[str, SkillEntry] = {}
|
||||
all_diagnostics = []
|
||||
|
||||
# Load from extra directories (lowest precedence)
|
||||
if extra_dirs:
|
||||
for extra_dir in extra_dirs:
|
||||
if not os.path.exists(extra_dir):
|
||||
continue
|
||||
result = self.load_skills_from_dir(extra_dir, source='extra')
|
||||
all_diagnostics.extend(result.diagnostics)
|
||||
for skill in result.skills:
|
||||
entry = self._create_skill_entry(skill)
|
||||
skill_map[skill.name] = entry
|
||||
|
||||
# Load from managed directory
|
||||
if managed_dir and os.path.exists(managed_dir):
|
||||
result = self.load_skills_from_dir(managed_dir, source='managed')
|
||||
|
||||
# Load builtin skills (lower precedence)
|
||||
if builtin_dir and os.path.exists(builtin_dir):
|
||||
result = self.load_skills_from_dir(builtin_dir, source='builtin')
|
||||
all_diagnostics.extend(result.diagnostics)
|
||||
for skill in result.skills:
|
||||
entry = self._create_skill_entry(skill)
|
||||
skill_map[skill.name] = entry
|
||||
|
||||
# Load from workspace directory (highest precedence)
|
||||
if workspace_skills_dir and os.path.exists(workspace_skills_dir):
|
||||
result = self.load_skills_from_dir(workspace_skills_dir, source='workspace')
|
||||
|
||||
# Load custom skills (higher precedence, overrides builtin)
|
||||
if custom_dir and os.path.exists(custom_dir):
|
||||
result = self.load_skills_from_dir(custom_dir, source='custom')
|
||||
all_diagnostics.extend(result.diagnostics)
|
||||
for skill in result.skills:
|
||||
entry = self._create_skill_entry(skill)
|
||||
skill_map[skill.name] = entry
|
||||
|
||||
|
||||
# Log diagnostics
|
||||
if all_diagnostics:
|
||||
logger.debug(f"Skill loading diagnostics: {len(all_diagnostics)} issues")
|
||||
for diag in all_diagnostics[:5]: # Log first 5
|
||||
for diag in all_diagnostics[:5]:
|
||||
logger.debug(f" - {diag}")
|
||||
|
||||
logger.debug(f"Loaded {len(skill_map)} skills from all sources")
|
||||
|
||||
|
||||
logger.debug(f"Loaded {len(skill_map)} skills total")
|
||||
|
||||
return skill_map
|
||||
|
||||
def _create_skill_entry(self, skill: Skill) -> SkillEntry:
|
||||
|
||||
@@ -3,6 +3,7 @@ Skill manager for managing skill lifecycle and operations.
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
from typing import Dict, List, Optional
|
||||
from pathlib import Path
|
||||
from common.log import logger
|
||||
@@ -10,56 +11,143 @@ from agent.skills.types import Skill, SkillEntry, SkillSnapshot
|
||||
from agent.skills.loader import SkillLoader
|
||||
from agent.skills.formatter import format_skill_entries_for_prompt
|
||||
|
||||
SKILLS_CONFIG_FILE = "skills_config.json"
|
||||
|
||||
|
||||
class SkillManager:
|
||||
"""Manages skills for an agent."""
|
||||
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
workspace_dir: Optional[str] = None,
|
||||
managed_skills_dir: Optional[str] = None,
|
||||
extra_dirs: Optional[List[str]] = None,
|
||||
builtin_dir: Optional[str] = None,
|
||||
custom_dir: Optional[str] = None,
|
||||
config: Optional[Dict] = None,
|
||||
):
|
||||
"""
|
||||
Initialize the skill manager.
|
||||
|
||||
:param workspace_dir: Agent workspace directory
|
||||
:param managed_skills_dir: Managed skills directory (e.g., ~/.cow/skills)
|
||||
:param extra_dirs: Additional skill directories
|
||||
|
||||
:param builtin_dir: Built-in skills directory (project root ``skills/``)
|
||||
:param custom_dir: Custom skills directory (workspace ``skills/``)
|
||||
:param config: Configuration dictionary
|
||||
"""
|
||||
self.workspace_dir = workspace_dir
|
||||
self.managed_skills_dir = managed_skills_dir or self._get_default_managed_dir()
|
||||
self.extra_dirs = extra_dirs or []
|
||||
project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
self.builtin_dir = builtin_dir or os.path.join(project_root, 'skills')
|
||||
self.custom_dir = custom_dir or os.path.join(project_root, 'workspace', 'skills')
|
||||
self.config = config or {}
|
||||
|
||||
self.loader = SkillLoader(workspace_dir=workspace_dir)
|
||||
self._skills_config_path = os.path.join(self.custom_dir, SKILLS_CONFIG_FILE)
|
||||
|
||||
# skills_config: full skill metadata keyed by name
|
||||
# { "web-fetch": {"name": ..., "description": ..., "source": ..., "enabled": true}, ... }
|
||||
self.skills_config: Dict[str, dict] = {}
|
||||
|
||||
self.loader = SkillLoader()
|
||||
self.skills: Dict[str, SkillEntry] = {}
|
||||
|
||||
|
||||
# Load skills on initialization
|
||||
self.refresh_skills()
|
||||
|
||||
def _get_default_managed_dir(self) -> str:
|
||||
"""Get the default managed skills directory."""
|
||||
# Use project root skills directory as default
|
||||
import os
|
||||
project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
return os.path.join(project_root, 'skills')
|
||||
|
||||
|
||||
def refresh_skills(self):
|
||||
"""Reload all skills from configured directories."""
|
||||
workspace_skills_dir = None
|
||||
if self.workspace_dir:
|
||||
workspace_skills_dir = os.path.join(self.workspace_dir, 'skills')
|
||||
|
||||
"""Reload all skills from builtin and custom directories, then sync config."""
|
||||
self.skills = self.loader.load_all_skills(
|
||||
managed_dir=self.managed_skills_dir,
|
||||
workspace_skills_dir=workspace_skills_dir,
|
||||
extra_dirs=self.extra_dirs,
|
||||
builtin_dir=self.builtin_dir,
|
||||
custom_dir=self.custom_dir,
|
||||
)
|
||||
|
||||
self._sync_skills_config()
|
||||
logger.debug(f"SkillManager: Loaded {len(self.skills)} skills")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# skills_config.json management
|
||||
# ------------------------------------------------------------------
|
||||
def _load_skills_config(self) -> Dict[str, dict]:
|
||||
"""Load skills_config.json from custom_dir. Returns empty dict if not found."""
|
||||
if not os.path.exists(self._skills_config_path):
|
||||
return {}
|
||||
try:
|
||||
with open(self._skills_config_path, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
if isinstance(data, dict):
|
||||
return data
|
||||
except Exception as e:
|
||||
logger.warning(f"[SkillManager] Failed to load {SKILLS_CONFIG_FILE}: {e}")
|
||||
return {}
|
||||
|
||||
def _save_skills_config(self):
|
||||
"""Persist skills_config to custom_dir/skills_config.json."""
|
||||
os.makedirs(self.custom_dir, exist_ok=True)
|
||||
try:
|
||||
with open(self._skills_config_path, "w", encoding="utf-8") as f:
|
||||
json.dump(self.skills_config, f, indent=4, ensure_ascii=False)
|
||||
except Exception as e:
|
||||
logger.error(f"[SkillManager] Failed to save {SKILLS_CONFIG_FILE}: {e}")
|
||||
|
||||
def _sync_skills_config(self):
|
||||
"""
|
||||
Merge directory-scanned skills with the persisted config file.
|
||||
|
||||
- New skills: use metadata.default_enabled as initial enabled state.
|
||||
- Existing skills: preserve their persisted enabled state.
|
||||
- Skills that no longer exist on disk are removed.
|
||||
- name/description/source are always refreshed from the latest scan.
|
||||
"""
|
||||
saved = self._load_skills_config()
|
||||
merged: Dict[str, dict] = {}
|
||||
|
||||
for name, entry in self.skills.items():
|
||||
skill = entry.skill
|
||||
prev = saved.get(name, {})
|
||||
category = prev.get("category", "skill")
|
||||
|
||||
if name in saved:
|
||||
enabled = prev.get("enabled", True)
|
||||
else:
|
||||
enabled = entry.metadata.default_enabled if entry.metadata else True
|
||||
|
||||
entry_dict = {
|
||||
"name": name,
|
||||
"description": skill.description,
|
||||
"source": prev.get("source") or skill.source,
|
||||
"enabled": enabled,
|
||||
"category": category,
|
||||
}
|
||||
display_name = prev.get("display_name")
|
||||
if display_name:
|
||||
entry_dict["display_name"] = display_name
|
||||
merged[name] = entry_dict
|
||||
|
||||
self.skills_config = merged
|
||||
self._save_skills_config()
|
||||
|
||||
def is_skill_enabled(self, name: str) -> bool:
|
||||
"""
|
||||
Check if a skill is enabled according to skills_config.
|
||||
|
||||
:param name: skill name
|
||||
:return: True if enabled (default True if not in config)
|
||||
"""
|
||||
entry = self.skills_config.get(name)
|
||||
if entry is None:
|
||||
return True
|
||||
return entry.get("enabled", True)
|
||||
|
||||
def set_skill_enabled(self, name: str, enabled: bool):
|
||||
"""
|
||||
Set a skill's enabled state and persist.
|
||||
|
||||
:param name: skill name
|
||||
:param enabled: True to enable, False to disable
|
||||
"""
|
||||
if name not in self.skills_config:
|
||||
raise ValueError(f"skill '{name}' not found in config")
|
||||
self.skills_config[name]["enabled"] = enabled
|
||||
self._save_skills_config()
|
||||
|
||||
def get_skills_config(self) -> Dict[str, dict]:
|
||||
"""
|
||||
Return the full skills_config dict (for query API).
|
||||
|
||||
:return: copy of skills_config
|
||||
"""
|
||||
return dict(self.skills_config)
|
||||
|
||||
def get_skill(self, name: str) -> Optional[SkillEntry]:
|
||||
"""
|
||||
@@ -78,53 +166,120 @@ class SkillManager:
|
||||
"""
|
||||
return list(self.skills.values())
|
||||
|
||||
@staticmethod
|
||||
def _normalize_skill_filter(skill_filter: Optional[List[str]]) -> Optional[List[str]]:
|
||||
"""Normalize a skill_filter list into a flat list of stripped names."""
|
||||
if skill_filter is None:
|
||||
return None
|
||||
normalized = []
|
||||
for item in skill_filter:
|
||||
if isinstance(item, str):
|
||||
name = item.strip()
|
||||
if name:
|
||||
normalized.append(name)
|
||||
elif isinstance(item, list):
|
||||
for subitem in item:
|
||||
if isinstance(subitem, str):
|
||||
name = subitem.strip()
|
||||
if name:
|
||||
normalized.append(name)
|
||||
return normalized or None
|
||||
|
||||
def filter_skills(
|
||||
self,
|
||||
skill_filter: Optional[List[str]] = None,
|
||||
include_disabled: bool = False,
|
||||
) -> List[SkillEntry]:
|
||||
"""
|
||||
Filter skills based on criteria.
|
||||
|
||||
Simple rule: Skills are auto-enabled if requirements are met.
|
||||
- Has required API keys → included
|
||||
- Missing API keys → excluded
|
||||
|
||||
Filter skills that are eligible (enabled + requirements met).
|
||||
|
||||
:param skill_filter: List of skill names to include (None = all)
|
||||
:param include_disabled: Whether to include skills with disable_model_invocation=True
|
||||
:return: Filtered list of skill entries
|
||||
:param include_disabled: Whether to include disabled skills
|
||||
:return: Filtered list of eligible skill entries
|
||||
"""
|
||||
from agent.skills.config import should_include_skill
|
||||
|
||||
|
||||
entries = list(self.skills.values())
|
||||
|
||||
# Check requirements (platform, binaries, env vars)
|
||||
|
||||
entries = [e for e in entries if should_include_skill(e, self.config)]
|
||||
|
||||
# Apply skill filter
|
||||
if skill_filter is not None:
|
||||
normalized = [name.strip() for name in skill_filter if name.strip()]
|
||||
if normalized:
|
||||
entries = [e for e in entries if e.skill.name in normalized]
|
||||
|
||||
# Filter out disabled skills unless explicitly requested
|
||||
|
||||
normalized = self._normalize_skill_filter(skill_filter)
|
||||
if normalized is not None:
|
||||
entries = [e for e in entries if e.skill.name in normalized]
|
||||
|
||||
if not include_disabled:
|
||||
entries = [e for e in entries if not e.skill.disable_model_invocation]
|
||||
|
||||
entries = [e for e in entries if self.is_skill_enabled(e.skill.name)]
|
||||
|
||||
from config import conf
|
||||
if not conf().get("knowledge", True):
|
||||
entries = [e for e in entries if e.skill.name != "knowledge-wiki"]
|
||||
|
||||
return entries
|
||||
|
||||
|
||||
def filter_unavailable_skills(
|
||||
self,
|
||||
skill_filter: Optional[List[str]] = None,
|
||||
) -> tuple:
|
||||
"""
|
||||
Find skills that are enabled but have unmet requirements.
|
||||
|
||||
:param skill_filter: Optional list of skill names to include
|
||||
:return: Tuple of (entries, missing_map) where missing_map maps
|
||||
skill name to its missing requirements dict
|
||||
"""
|
||||
from agent.skills.config import should_include_skill, get_missing_requirements
|
||||
|
||||
entries = list(self.skills.values())
|
||||
|
||||
# Only enabled skills
|
||||
entries = [e for e in entries if self.is_skill_enabled(e.skill.name)]
|
||||
|
||||
normalized = self._normalize_skill_filter(skill_filter)
|
||||
if normalized is not None:
|
||||
entries = [e for e in entries if e.skill.name in normalized]
|
||||
|
||||
# Keep only those that fail should_include_skill (requirements not met)
|
||||
unavailable = []
|
||||
missing_map: Dict[str, dict] = {}
|
||||
for e in entries:
|
||||
if not should_include_skill(e, self.config):
|
||||
missing = get_missing_requirements(e)
|
||||
if missing:
|
||||
unavailable.append(e)
|
||||
missing_map[e.skill.name] = missing
|
||||
|
||||
return unavailable, missing_map
|
||||
|
||||
def build_skills_prompt(
|
||||
self,
|
||||
skill_filter: Optional[List[str]] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Build a formatted prompt containing available skills.
|
||||
|
||||
Build a formatted prompt containing available skills
|
||||
and brief hints for unavailable ones.
|
||||
|
||||
:param skill_filter: Optional list of skill names to include
|
||||
:return: Formatted skills prompt
|
||||
"""
|
||||
entries = self.filter_skills(skill_filter=skill_filter, include_disabled=False)
|
||||
return format_skill_entries_for_prompt(entries)
|
||||
from common.log import logger
|
||||
from agent.skills.formatter import format_unavailable_skills_for_prompt
|
||||
|
||||
eligible = self.filter_skills(skill_filter=skill_filter, include_disabled=False)
|
||||
logger.debug(f"[SkillManager] Eligible: {len(eligible)} skills (total: {len(self.skills)})")
|
||||
if eligible:
|
||||
skill_names = [e.skill.name for e in eligible]
|
||||
logger.debug(f"[SkillManager] Eligible skills: {skill_names}")
|
||||
|
||||
result = format_skill_entries_for_prompt(eligible)
|
||||
|
||||
unavailable, missing_map = self.filter_unavailable_skills(skill_filter=skill_filter)
|
||||
if unavailable:
|
||||
unavailable_names = [e.skill.name for e in unavailable]
|
||||
logger.debug(f"[SkillManager] Unavailable skills (setup needed): {unavailable_names}")
|
||||
result += format_unavailable_skills_for_prompt(unavailable, missing_map)
|
||||
|
||||
logger.debug(f"[SkillManager] Generated prompt length: {len(result)}")
|
||||
return result
|
||||
|
||||
def build_skill_snapshot(
|
||||
self,
|
||||
|
||||
285
agent/skills/service.py
Normal file
285
agent/skills/service.py
Normal file
@@ -0,0 +1,285 @@
|
||||
"""
|
||||
Skill service for handling skill CRUD operations.
|
||||
|
||||
This service provides a unified interface for managing skills, which can be
|
||||
called from the cloud control client (LinkAI), the local web console, or any
|
||||
other management entry point.
|
||||
"""
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import zipfile
|
||||
import tempfile
|
||||
from typing import Dict, List, Optional
|
||||
from common.log import logger
|
||||
from agent.skills.types import Skill, SkillEntry
|
||||
from agent.skills.manager import SkillManager
|
||||
|
||||
try:
|
||||
import requests
|
||||
except ImportError:
|
||||
requests = None
|
||||
|
||||
|
||||
class SkillService:
|
||||
"""
|
||||
High-level service for skill lifecycle management.
|
||||
Wraps SkillManager and provides network-aware operations such as
|
||||
downloading skill files from remote URLs.
|
||||
"""
|
||||
|
||||
def __init__(self, skill_manager: SkillManager):
|
||||
"""
|
||||
:param skill_manager: The SkillManager instance to operate on
|
||||
"""
|
||||
self.manager = skill_manager
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# query
|
||||
# ------------------------------------------------------------------
|
||||
def query(self) -> List[dict]:
|
||||
"""
|
||||
Query all skills and return a serialisable list.
|
||||
Reads from skills_config.json (refreshes from disk if needed).
|
||||
|
||||
:return: list of skill info dicts
|
||||
"""
|
||||
self.manager.refresh_skills()
|
||||
config = self.manager.get_skills_config()
|
||||
result = list(config.values())
|
||||
logger.info(f"[SkillService] query: {len(result)} skills found")
|
||||
return result
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# add / install
|
||||
# ------------------------------------------------------------------
|
||||
def add(self, payload: dict) -> None:
|
||||
"""
|
||||
Add (install) a skill from a remote payload.
|
||||
|
||||
Supported payload types:
|
||||
|
||||
1. ``type: "url"`` – download individual files::
|
||||
|
||||
{
|
||||
"name": "web_search",
|
||||
"type": "url",
|
||||
"enabled": true,
|
||||
"files": [
|
||||
{"url": "https://...", "path": "README.md"},
|
||||
{"url": "https://...", "path": "scripts/main.py"}
|
||||
]
|
||||
}
|
||||
|
||||
2. ``type: "package"`` – download a zip archive and extract::
|
||||
|
||||
{
|
||||
"name": "plugin-custom-tool",
|
||||
"type": "package",
|
||||
"category": "skills",
|
||||
"enabled": true,
|
||||
"files": [{"url": "https://cdn.example.com/skills/custom-tool.zip"}]
|
||||
}
|
||||
|
||||
:param payload: skill add payload from server
|
||||
"""
|
||||
name = payload.get("name")
|
||||
if not name:
|
||||
raise ValueError("skill name is required")
|
||||
|
||||
payload_type = payload.get("type", "url")
|
||||
|
||||
if payload_type == "package":
|
||||
self._add_package(name, payload)
|
||||
else:
|
||||
self._add_url(name, payload)
|
||||
|
||||
self.manager.refresh_skills()
|
||||
|
||||
category = payload.get("category")
|
||||
if category and name in self.manager.skills_config:
|
||||
self.manager.skills_config[name]["category"] = category
|
||||
self.manager._save_skills_config()
|
||||
|
||||
def _add_url(self, name: str, payload: dict) -> None:
|
||||
"""Install a skill by downloading individual files."""
|
||||
files = payload.get("files", [])
|
||||
if not files:
|
||||
raise ValueError("skill files list is empty")
|
||||
|
||||
skill_dir = os.path.join(self.manager.custom_dir, name)
|
||||
|
||||
tmp_dir = skill_dir + ".tmp"
|
||||
if os.path.exists(tmp_dir):
|
||||
shutil.rmtree(tmp_dir)
|
||||
os.makedirs(tmp_dir, exist_ok=True)
|
||||
|
||||
try:
|
||||
for file_info in files:
|
||||
url = file_info.get("url")
|
||||
rel_path = file_info.get("path")
|
||||
if not url or not rel_path:
|
||||
logger.warning(f"[SkillService] add: skip invalid file entry {file_info}")
|
||||
continue
|
||||
dest = os.path.join(tmp_dir, rel_path)
|
||||
self._download_file(url, dest)
|
||||
except Exception:
|
||||
shutil.rmtree(tmp_dir, ignore_errors=True)
|
||||
raise
|
||||
|
||||
if os.path.exists(skill_dir):
|
||||
shutil.rmtree(skill_dir)
|
||||
os.rename(tmp_dir, skill_dir)
|
||||
|
||||
logger.info(f"[SkillService] add: skill '{name}' installed via url ({len(files)} files)")
|
||||
|
||||
def _add_package(self, name: str, payload: dict) -> None:
|
||||
"""
|
||||
Install a skill by downloading a zip archive and extracting it.
|
||||
|
||||
If the archive contains a single top-level directory, that directory
|
||||
is used as the skill folder directly; otherwise a new directory named
|
||||
after the skill is created to hold the extracted contents.
|
||||
"""
|
||||
files = payload.get("files", [])
|
||||
if not files or not files[0].get("url"):
|
||||
raise ValueError("package url is required")
|
||||
|
||||
url = files[0]["url"]
|
||||
skill_dir = os.path.join(self.manager.custom_dir, name)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
zip_path = os.path.join(tmp_dir, "package.zip")
|
||||
self._download_file(url, zip_path)
|
||||
|
||||
if not zipfile.is_zipfile(zip_path):
|
||||
raise ValueError(f"downloaded file is not a valid zip archive: {url}")
|
||||
|
||||
extract_dir = os.path.join(tmp_dir, "extracted")
|
||||
with zipfile.ZipFile(zip_path, "r") as zf:
|
||||
zf.extractall(extract_dir)
|
||||
|
||||
# Determine the actual content root.
|
||||
# If the zip has a single top-level directory, use its contents
|
||||
# so the skill folder is clean (no extra nesting).
|
||||
top_items = [
|
||||
item for item in os.listdir(extract_dir)
|
||||
if not item.startswith(".")
|
||||
]
|
||||
if len(top_items) == 1:
|
||||
single = os.path.join(extract_dir, top_items[0])
|
||||
if os.path.isdir(single):
|
||||
extract_dir = single
|
||||
|
||||
if os.path.exists(skill_dir):
|
||||
shutil.rmtree(skill_dir)
|
||||
shutil.copytree(extract_dir, skill_dir)
|
||||
|
||||
logger.info(f"[SkillService] add: skill '{name}' installed via package ({url})")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# open / close (enable / disable)
|
||||
# ------------------------------------------------------------------
|
||||
def open(self, payload: dict) -> None:
|
||||
"""
|
||||
Enable a skill by name.
|
||||
|
||||
:param payload: {"name": "skill_name"}
|
||||
"""
|
||||
name = payload.get("name")
|
||||
if not name:
|
||||
raise ValueError("skill name is required")
|
||||
self.manager.set_skill_enabled(name, enabled=True)
|
||||
logger.info(f"[SkillService] open: skill '{name}' enabled")
|
||||
|
||||
def close(self, payload: dict) -> None:
|
||||
"""
|
||||
Disable a skill by name.
|
||||
|
||||
:param payload: {"name": "skill_name"}
|
||||
"""
|
||||
name = payload.get("name")
|
||||
if not name:
|
||||
raise ValueError("skill name is required")
|
||||
self.manager.set_skill_enabled(name, enabled=False)
|
||||
logger.info(f"[SkillService] close: skill '{name}' disabled")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# delete
|
||||
# ------------------------------------------------------------------
|
||||
def delete(self, payload: dict) -> None:
|
||||
"""
|
||||
Delete a skill by removing its directory entirely.
|
||||
|
||||
:param payload: {"name": "skill_name"}
|
||||
"""
|
||||
name = payload.get("name")
|
||||
if not name:
|
||||
raise ValueError("skill name is required")
|
||||
|
||||
skill_dir = os.path.join(self.manager.custom_dir, name)
|
||||
if os.path.exists(skill_dir):
|
||||
shutil.rmtree(skill_dir)
|
||||
logger.info(f"[SkillService] delete: removed directory {skill_dir}")
|
||||
else:
|
||||
logger.warning(f"[SkillService] delete: skill directory not found: {skill_dir}")
|
||||
|
||||
# Refresh will remove the deleted skill from config automatically
|
||||
self.manager.refresh_skills()
|
||||
logger.info(f"[SkillService] delete: skill '{name}' deleted")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# dispatch - single entry point for protocol messages
|
||||
# ------------------------------------------------------------------
|
||||
def dispatch(self, action: str, payload: Optional[dict] = None) -> dict:
|
||||
"""
|
||||
Dispatch a skill management action and return a protocol-compatible
|
||||
response dict.
|
||||
|
||||
:param action: one of query / add / open / close / delete
|
||||
:param payload: action-specific payload (may be None for query)
|
||||
:return: dict with action, code, message, payload
|
||||
"""
|
||||
payload = payload or {}
|
||||
try:
|
||||
if action == "query":
|
||||
result_payload = self.query()
|
||||
return {"action": action, "code": 200, "message": "success", "payload": result_payload}
|
||||
elif action == "add":
|
||||
self.add(payload)
|
||||
elif action == "open":
|
||||
self.open(payload)
|
||||
elif action == "close":
|
||||
self.close(payload)
|
||||
elif action == "delete":
|
||||
self.delete(payload)
|
||||
else:
|
||||
return {"action": action, "code": 400, "message": f"unknown action: {action}", "payload": None}
|
||||
return {"action": action, "code": 200, "message": "success", "payload": None}
|
||||
except Exception as e:
|
||||
logger.error(f"[SkillService] dispatch error: action={action}, error={e}")
|
||||
return {"action": action, "code": 500, "message": str(e), "payload": None}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# internal helpers
|
||||
# ------------------------------------------------------------------
|
||||
@staticmethod
|
||||
def _download_file(url: str, dest: str):
|
||||
"""
|
||||
Download a file from *url* and save to *dest*.
|
||||
|
||||
:param url: remote file URL
|
||||
:param dest: local destination path
|
||||
"""
|
||||
if requests is None:
|
||||
raise RuntimeError("requests library is required for downloading skill files")
|
||||
|
||||
dest_dir = os.path.dirname(dest)
|
||||
if dest_dir:
|
||||
os.makedirs(dest_dir, exist_ok=True)
|
||||
|
||||
resp = requests.get(url, timeout=60)
|
||||
resp.raise_for_status()
|
||||
with open(dest, "wb") as f:
|
||||
f.write(resp.content)
|
||||
logger.debug(f"[SkillService] downloaded {url} -> {dest}")
|
||||
@@ -2,6 +2,7 @@
|
||||
Type definitions for skills system.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
from typing import Dict, List, Optional, Any
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
@@ -28,6 +29,7 @@ class SkillInstallSpec:
|
||||
class SkillMetadata:
|
||||
"""Metadata for a skill from frontmatter."""
|
||||
always: bool = False # Always include this skill
|
||||
default_enabled: bool = True # Initial enabled state when first discovered
|
||||
skill_key: Optional[str] = None # Override skill key
|
||||
primary_env: Optional[str] = None # Primary environment variable
|
||||
emoji: Optional[str] = None
|
||||
@@ -44,7 +46,7 @@ class Skill:
|
||||
description: str
|
||||
file_path: str
|
||||
base_dir: str
|
||||
source: str # managed, workspace, bundled, etc.
|
||||
source: str # builtin or custom
|
||||
content: str # Full markdown content
|
||||
disable_model_invocation: bool = False
|
||||
frontmatter: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
@@ -45,38 +45,83 @@ def _import_optional_tools():
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[Tools] Scheduler tool failed to load: {e}")
|
||||
|
||||
|
||||
|
||||
# WebSearch Tool (conditionally loaded based on API key availability at init time)
|
||||
try:
|
||||
from agent.tools.web_search.web_search import WebSearch
|
||||
tools['WebSearch'] = WebSearch
|
||||
except ImportError as e:
|
||||
logger.error(f"[Tools] WebSearch not loaded - missing dependency: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"[Tools] WebSearch failed to load: {e}")
|
||||
|
||||
# WebFetch Tool
|
||||
try:
|
||||
from agent.tools.web_fetch.web_fetch import WebFetch
|
||||
tools['WebFetch'] = WebFetch
|
||||
except ImportError as e:
|
||||
logger.error(f"[Tools] WebFetch not loaded - missing dependency: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"[Tools] WebFetch failed to load: {e}")
|
||||
|
||||
# Vision Tool (conditionally loaded based on API key availability)
|
||||
try:
|
||||
from agent.tools.vision.vision import Vision
|
||||
tools['Vision'] = Vision
|
||||
except ImportError as e:
|
||||
logger.error(f"[Tools] Vision not loaded - missing dependency: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"[Tools] Vision failed to load: {e}")
|
||||
|
||||
return tools
|
||||
|
||||
# Load optional tools
|
||||
_optional_tools = _import_optional_tools()
|
||||
EnvConfig = _optional_tools.get('EnvConfig')
|
||||
SchedulerTool = _optional_tools.get('SchedulerTool')
|
||||
WebSearch = _optional_tools.get('WebSearch')
|
||||
WebFetch = _optional_tools.get('WebFetch')
|
||||
Vision = _optional_tools.get('Vision')
|
||||
GoogleSearch = _optional_tools.get('GoogleSearch')
|
||||
FileSave = _optional_tools.get('FileSave')
|
||||
FileSave = _optional_tools.get('FileSave')
|
||||
Terminal = _optional_tools.get('Terminal')
|
||||
|
||||
|
||||
# Delayed import for BrowserTool
|
||||
# BrowserTool (requires playwright)
|
||||
def _import_browser_tool():
|
||||
from common.log import logger
|
||||
try:
|
||||
from agent.tools.browser.browser_tool import BrowserTool
|
||||
return BrowserTool
|
||||
except ImportError:
|
||||
# Return a placeholder class that will prompt the user to install dependencies when instantiated
|
||||
class BrowserToolPlaceholder:
|
||||
def __init__(self, *args, **kwargs):
|
||||
raise ImportError(
|
||||
"The 'browser-use' package is required to use BrowserTool. "
|
||||
"Please install it with 'pip install browser-use>=0.1.40'."
|
||||
)
|
||||
except ImportError as e:
|
||||
logger.info(
|
||||
f"[Tools] BrowserTool not loaded - missing dependency: {e}\n"
|
||||
f" To enable browser tool, run:\n"
|
||||
f" pip install playwright\n"
|
||||
f" playwright install chromium"
|
||||
)
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"[Tools] BrowserTool failed to load: {e}")
|
||||
return None
|
||||
|
||||
return BrowserToolPlaceholder
|
||||
BrowserTool = _import_browser_tool()
|
||||
|
||||
# MCP Tools (no extra dependencies, loaded on demand)
|
||||
def _import_mcp_tools():
|
||||
"""导入 MCP 工具模块(无额外依赖,按需加载)"""
|
||||
from common.log import logger
|
||||
try:
|
||||
from agent.tools.mcp.mcp_tool import McpTool
|
||||
from agent.tools.mcp.mcp_client import McpClientRegistry
|
||||
return {'McpTool': McpTool, 'McpClientRegistry': McpClientRegistry}
|
||||
except Exception as e:
|
||||
logger.warning(f"[Tools] MCP tools not loaded: {e}")
|
||||
return {}
|
||||
|
||||
# Dynamically set BrowserTool
|
||||
# BrowserTool = _import_browser_tool()
|
||||
_mcp_tools = _import_mcp_tools()
|
||||
McpTool = _mcp_tools.get('McpTool')
|
||||
McpClientRegistry = _mcp_tools.get('McpClientRegistry')
|
||||
|
||||
# Export all tools (including optional ones that might be None)
|
||||
__all__ = [
|
||||
@@ -92,8 +137,11 @@ __all__ = [
|
||||
'MemoryGetTool',
|
||||
'EnvConfig',
|
||||
'SchedulerTool',
|
||||
# Optional tools (may be None if dependencies not available)
|
||||
# 'BrowserTool'
|
||||
'WebSearch',
|
||||
'WebFetch',
|
||||
'Vision',
|
||||
'BrowserTool',
|
||||
'McpTool',
|
||||
]
|
||||
|
||||
"""
|
||||
|
||||
@@ -3,6 +3,7 @@ Bash tool - Execute bash commands
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import subprocess
|
||||
import tempfile
|
||||
@@ -11,18 +12,24 @@ from typing import Dict, Any
|
||||
from agent.tools.base_tool import BaseTool, ToolResult
|
||||
from agent.tools.utils.truncate import truncate_tail, format_size, DEFAULT_MAX_LINES, DEFAULT_MAX_BYTES
|
||||
from common.log import logger
|
||||
from common.utils import expand_path
|
||||
|
||||
|
||||
class Bash(BaseTool):
|
||||
"""Tool for executing bash commands"""
|
||||
|
||||
_IS_WIN = sys.platform == "win32"
|
||||
|
||||
name: str = "bash"
|
||||
description: str = f"""Execute a bash command in the current working directory. Returns stdout and stderr. Output is truncated to last {DEFAULT_MAX_LINES} lines or {DEFAULT_MAX_BYTES // 1024}KB (whichever is hit first). If truncated, full output is saved to a temp file.
|
||||
{'''
|
||||
PLATFORM: Windows (cmd.exe). Do NOT use Unix-only commands like grep, head, tail, sed, awk.
|
||||
''' if _IS_WIN else ''}
|
||||
ENVIRONMENT: All API keys from env_config are auto-injected. Use $VAR_NAME directly.
|
||||
|
||||
IMPORTANT SAFETY GUIDELINES:
|
||||
- You can freely create, modify, and delete files within the current workspace
|
||||
- For operations outside the workspace or potentially destructive commands (rm -rf, system commands, etc.), always explain what you're about to do and ask for user confirmation first
|
||||
- When in doubt, describe the command's purpose and ask for permission before executing"""
|
||||
SAFETY:
|
||||
- Freely create/modify/delete files within the workspace
|
||||
- For destructive commands out of workspace, explain and confirm first"""
|
||||
|
||||
params: dict = {
|
||||
"type": "object",
|
||||
@@ -80,27 +87,32 @@ IMPORTANT SAFETY GUIDELINES:
|
||||
env = os.environ.copy()
|
||||
|
||||
# Load environment variables from ~/.cow/.env if it exists
|
||||
env_file = os.path.expanduser("~/.cow/.env")
|
||||
env_file = expand_path("~/.cow/.env")
|
||||
dotenv_vars = {}
|
||||
if os.path.exists(env_file):
|
||||
try:
|
||||
from dotenv import dotenv_values
|
||||
env_vars = dotenv_values(env_file)
|
||||
env.update(env_vars)
|
||||
logger.debug(f"[Bash] Loaded {len(env_vars)} variables from {env_file}")
|
||||
dotenv_vars = dotenv_values(env_file)
|
||||
env.update(dotenv_vars)
|
||||
logger.debug(f"[Bash] Loaded {len(dotenv_vars)} variables from {env_file}")
|
||||
except ImportError:
|
||||
logger.debug("[Bash] python-dotenv not installed, skipping .env loading")
|
||||
except Exception as e:
|
||||
logger.debug(f"[Bash] Failed to load .env: {e}")
|
||||
|
||||
# getuid() only exists on Unix-like systems
|
||||
if hasattr(os, 'getuid'):
|
||||
logger.debug(f"[Bash] Process UID: {os.getuid()}")
|
||||
else:
|
||||
logger.debug(f"[Bash] Process User: {os.environ.get('USERNAME', os.environ.get('USER', 'unknown'))}")
|
||||
|
||||
# Debug logging
|
||||
logger.debug(f"[Bash] CWD: {self.cwd}")
|
||||
logger.debug(f"[Bash] Command: {command[:500]}")
|
||||
logger.debug(f"[Bash] OPENAI_API_KEY in env: {'OPENAI_API_KEY' in env}")
|
||||
logger.debug(f"[Bash] SHELL: {env.get('SHELL', 'not set')}")
|
||||
logger.debug(f"[Bash] Python executable: {sys.executable}")
|
||||
logger.debug(f"[Bash] Process UID: {os.getuid()}")
|
||||
|
||||
# Execute command with inherited environment variables
|
||||
# On Windows, convert $VAR references to %VAR% for cmd.exe
|
||||
if self._IS_WIN:
|
||||
env["PYTHONIOENCODING"] = "utf-8"
|
||||
command = self._convert_env_vars_for_windows(command, dotenv_vars)
|
||||
if command and not command.strip().lower().startswith("chcp"):
|
||||
command = f"chcp 65001 >nul 2>&1 && {command}"
|
||||
|
||||
result = subprocess.run(
|
||||
command,
|
||||
shell=True,
|
||||
@@ -108,8 +120,10 @@ IMPORTANT SAFETY GUIDELINES:
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
text=True,
|
||||
encoding="utf-8",
|
||||
errors="replace",
|
||||
timeout=timeout,
|
||||
env=env
|
||||
env=env,
|
||||
)
|
||||
|
||||
logger.debug(f"[Bash] Exit code: {result.returncode}")
|
||||
@@ -131,6 +145,8 @@ IMPORTANT SAFETY GUIDELINES:
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
text=True,
|
||||
encoding="utf-8",
|
||||
errors="replace",
|
||||
timeout=timeout,
|
||||
env=env
|
||||
)
|
||||
@@ -153,10 +169,16 @@ IMPORTANT SAFETY GUIDELINES:
|
||||
except Exception as retry_err:
|
||||
logger.warning(f"[Bash] Retry failed: {retry_err}")
|
||||
|
||||
# Combine stdout and stderr
|
||||
output = result.stdout
|
||||
if result.stderr:
|
||||
output += "\n" + result.stderr
|
||||
# When command succeeds with stdout, keep output clean (stderr goes to server log only).
|
||||
# When command fails or stdout is empty, include stderr so the agent can diagnose.
|
||||
if result.returncode == 0 and result.stdout.strip():
|
||||
output = result.stdout
|
||||
if result.stderr:
|
||||
logger.info(f"[Bash] stderr (not forwarded): {result.stderr[:500]}")
|
||||
else:
|
||||
output = result.stdout
|
||||
if result.stderr:
|
||||
output += "\n" + result.stderr
|
||||
|
||||
# Check if we need to save full output to temp file
|
||||
temp_file_path = None
|
||||
@@ -216,45 +238,58 @@ IMPORTANT SAFETY GUIDELINES:
|
||||
|
||||
def _get_safety_warning(self, command: str) -> str:
|
||||
"""
|
||||
Get safety warning for potentially dangerous commands
|
||||
Only warns about extremely dangerous system-level operations
|
||||
|
||||
Get safety warning for absolutely catastrophic commands only.
|
||||
Keep the blocklist minimal so the agent retains maximum freedom.
|
||||
|
||||
:param command: Command to check
|
||||
:return: Warning message if dangerous, empty string if safe
|
||||
"""
|
||||
cmd_lower = command.lower().strip()
|
||||
# Tokenize to avoid substring false positives (e.g. `rm -rf /tmp/x`
|
||||
# must not match `rm -rf /`).
|
||||
tokens = command.lower().split()
|
||||
|
||||
# Only block extremely dangerous system operations
|
||||
dangerous_patterns = [
|
||||
# System shutdown/reboot
|
||||
("shutdown", "This command will shut down the system"),
|
||||
("reboot", "This command will reboot the system"),
|
||||
("halt", "This command will halt the system"),
|
||||
("poweroff", "This command will power off the system"),
|
||||
# `rm -rf /` or `rm -rf /*` targeting the real root.
|
||||
for i, tok in enumerate(tokens):
|
||||
if tok != "rm":
|
||||
continue
|
||||
has_rf = False
|
||||
for j in range(i + 1, len(tokens)):
|
||||
t = tokens[j]
|
||||
if t.startswith("-") and "r" in t and "f" in t:
|
||||
has_rf = True
|
||||
elif t in ("--recursive", "--force"):
|
||||
continue
|
||||
elif t in ("/", "/*"):
|
||||
if has_rf:
|
||||
return "This command will delete the entire filesystem"
|
||||
break
|
||||
else:
|
||||
break
|
||||
|
||||
# Critical system modifications
|
||||
("rm -rf /", "This command will delete the entire filesystem"),
|
||||
("rm -rf /*", "This command will delete the entire filesystem"),
|
||||
("dd if=/dev/zero", "This command can destroy disk data"),
|
||||
("mkfs", "This command will format a filesystem, destroying all data"),
|
||||
("fdisk", "This command modifies disk partitions"),
|
||||
# Disk wiping
|
||||
if "if=/dev/zero" in command.lower() and "dd " in command.lower():
|
||||
return "This command can destroy disk data"
|
||||
|
||||
# User/system management (only if targeting system users)
|
||||
("userdel root", "This command will delete the root user"),
|
||||
("passwd root", "This command will change the root password"),
|
||||
]
|
||||
# Power control - match only as a standalone word (\b enforces word boundary)
|
||||
if re.search(r'\b(shutdown|reboot|halt|poweroff)\b', command.lower()):
|
||||
return "This command will shut down or restart the system"
|
||||
|
||||
for pattern, warning in dangerous_patterns:
|
||||
if pattern in cmd_lower:
|
||||
return warning
|
||||
return ""
|
||||
|
||||
# Check for recursive deletion outside workspace
|
||||
if "rm" in cmd_lower and "-rf" in cmd_lower:
|
||||
# Allow deletion within current workspace
|
||||
if not any(path in cmd_lower for path in ["./", self.cwd.lower()]):
|
||||
# Check if targeting system directories
|
||||
system_dirs = ["/bin", "/usr", "/etc", "/var", "/home", "/root", "/sys", "/proc"]
|
||||
if any(sysdir in cmd_lower for sysdir in system_dirs):
|
||||
return "This command will recursively delete system directories"
|
||||
@staticmethod
|
||||
def _convert_env_vars_for_windows(command: str, dotenv_vars: dict) -> str:
|
||||
"""
|
||||
Convert bash-style $VAR / ${VAR} references to cmd.exe %VAR% syntax.
|
||||
Only converts variables loaded from .env (user-configured API keys etc.)
|
||||
to avoid breaking $PATH, jq expressions, regex, etc.
|
||||
"""
|
||||
if not dotenv_vars:
|
||||
return command
|
||||
|
||||
return "" # No warning needed
|
||||
def replace_match(m):
|
||||
var_name = m.group(1) or m.group(2)
|
||||
if var_name in dotenv_vars:
|
||||
return f"%{var_name}%"
|
||||
return m.group(0)
|
||||
|
||||
return re.sub(r'\$\{(\w+)\}|\$(\w+)', replace_match, command)
|
||||
|
||||
3
agent/tools/browser/__init__.py
Normal file
3
agent/tools/browser/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from agent.tools.browser.browser_tool import BrowserTool
|
||||
|
||||
__all__ = ["BrowserTool"]
|
||||
961
agent/tools/browser/browser_service.py
Normal file
961
agent/tools/browser/browser_service.py
Normal file
@@ -0,0 +1,961 @@
|
||||
"""
|
||||
Browser service - Playwright wrapper managing browser lifecycle and page operations.
|
||||
|
||||
All Playwright calls run on a dedicated background thread so that callers from
|
||||
any worker thread can safely use the service. An idle-timeout mechanism
|
||||
automatically shuts down the browser (and its thread) after a configurable
|
||||
period of inactivity to free resources.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import uuid
|
||||
import queue
|
||||
import threading
|
||||
from typing import Optional, Dict, Any, List, Callable
|
||||
|
||||
from common.log import logger
|
||||
from common.utils import expand_path, is_cloud_deployment
|
||||
|
||||
|
||||
_DEFAULT_USER_DATA_DIR = "~/.cow/browser_profile"
|
||||
|
||||
try:
|
||||
from playwright.sync_api import sync_playwright, Browser, BrowserContext, Page, Playwright
|
||||
_HAS_PLAYWRIGHT = True
|
||||
except ImportError:
|
||||
_HAS_PLAYWRIGHT = False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Snapshot DOM helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Tags that typically carry useful content for an agent
|
||||
_INTERACTIVE_TAGS = {
|
||||
"a", "button", "input", "textarea", "select", "option",
|
||||
"label", "details", "summary",
|
||||
}
|
||||
_SEMANTIC_TAGS = {
|
||||
"h1", "h2", "h3", "h4", "h5", "h6",
|
||||
"p", "li", "td", "th", "caption", "figcaption", "blockquote", "pre", "code",
|
||||
"nav", "main", "article", "section", "header", "footer", "form", "table",
|
||||
"img", "video", "audio",
|
||||
}
|
||||
_KEEP_TAGS = _INTERACTIVE_TAGS | _SEMANTIC_TAGS
|
||||
|
||||
_SNAPSHOT_JS = """
|
||||
() => {
|
||||
const KEEP = new Set(%s);
|
||||
const INTERACTIVE = new Set(%s);
|
||||
const SKIP = new Set(["script","style","noscript","svg","path","meta","link","br","hr"]);
|
||||
const CLICKABLE_ROLES = new Set([
|
||||
"button","link","tab","menuitem","menuitemcheckbox","menuitemradio",
|
||||
"option","switch","checkbox","radio","combobox","searchbox","slider",
|
||||
"spinbutton","textbox","treeitem"
|
||||
]);
|
||||
let refCounter = 0;
|
||||
const refMap = {};
|
||||
|
||||
function visible(el) {
|
||||
if (!(el instanceof HTMLElement)) return true;
|
||||
const st = window.getComputedStyle(el);
|
||||
if (st.display === "none" || st.visibility === "hidden") return false;
|
||||
if (parseFloat(st.opacity) === 0) return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
// Strong signals: these attributes alone are enough to mark as interactive
|
||||
function hasStrongInteractiveSignal(el) {
|
||||
const role = el.getAttribute("role");
|
||||
if (role && CLICKABLE_ROLES.has(role)) return true;
|
||||
if (el.hasAttribute("onclick") || el.hasAttribute("tabindex")) return true;
|
||||
if (el.hasAttribute("data-click") || el.hasAttribute("data-action")) return true;
|
||||
if (el.getAttribute("contenteditable") === "true") return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check if cursor:pointer is set directly (not just inherited from parent)
|
||||
function hasOwnPointerCursor(el) {
|
||||
try {
|
||||
const st = window.getComputedStyle(el);
|
||||
if (st.cursor !== "pointer") return false;
|
||||
const parent = el.parentElement;
|
||||
if (parent) {
|
||||
const pst = window.getComputedStyle(parent);
|
||||
if (pst.cursor === "pointer") return false;
|
||||
}
|
||||
return true;
|
||||
} catch(e) {}
|
||||
return false;
|
||||
}
|
||||
|
||||
function hasTextOrContent(el) {
|
||||
const t = el.textContent || "";
|
||||
if (t.trim().length > 0) return true;
|
||||
if (el.querySelector("img,video,audio,canvas")) return true;
|
||||
const ariaLabel = el.getAttribute("aria-label");
|
||||
if (ariaLabel && ariaLabel.trim()) return true;
|
||||
const title = el.getAttribute("title");
|
||||
if (title && title.trim()) return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
function isImplicitInteractive(el) {
|
||||
if (hasStrongInteractiveSignal(el)) return true;
|
||||
if (hasOwnPointerCursor(el) && hasTextOrContent(el)) return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
function getTextContent(el) {
|
||||
let text = "";
|
||||
for (const ch of el.childNodes) {
|
||||
if (ch.nodeType === Node.TEXT_NODE) {
|
||||
text += ch.textContent;
|
||||
}
|
||||
}
|
||||
return text.trim();
|
||||
}
|
||||
|
||||
function walk(node) {
|
||||
if (node.nodeType === Node.TEXT_NODE) {
|
||||
const t = node.textContent.trim();
|
||||
return t ? t : null;
|
||||
}
|
||||
if (node.nodeType !== Node.ELEMENT_NODE) return null;
|
||||
const tag = node.tagName.toLowerCase();
|
||||
if (SKIP.has(tag)) return null;
|
||||
if (!visible(node)) return null;
|
||||
|
||||
const children = [];
|
||||
for (const ch of node.childNodes) {
|
||||
const r = walk(ch);
|
||||
if (r !== null) {
|
||||
if (typeof r === "string") children.push(r);
|
||||
else children.push(r);
|
||||
}
|
||||
}
|
||||
|
||||
const nativeInteractive = INTERACTIVE.has(tag);
|
||||
const implicitInteractive = !nativeInteractive && (node instanceof HTMLElement) && isImplicitInteractive(node);
|
||||
const keep = KEEP.has(tag) || implicitInteractive;
|
||||
|
||||
if (!keep) {
|
||||
if (children.length === 0) return null;
|
||||
if (children.length === 1) return children[0];
|
||||
return children;
|
||||
}
|
||||
|
||||
const obj = { tag };
|
||||
if (nativeInteractive || implicitInteractive) {
|
||||
refCounter++;
|
||||
obj.ref = refCounter;
|
||||
refMap[refCounter] = node;
|
||||
}
|
||||
|
||||
if (implicitInteractive) {
|
||||
const role = node.getAttribute("role");
|
||||
if (role) obj.role = role;
|
||||
const directText = getTextContent(node);
|
||||
if (!directText && children.length === 0) {
|
||||
const ariaLabel = node.getAttribute("aria-label");
|
||||
const title = node.getAttribute("title");
|
||||
if (ariaLabel) obj.ariaLabel = ariaLabel;
|
||||
else if (title) obj.ariaLabel = title;
|
||||
}
|
||||
}
|
||||
|
||||
// Attributes
|
||||
if (tag === "a" && node.href) obj.href = node.getAttribute("href");
|
||||
if (tag === "img") {
|
||||
obj.alt = node.alt || "";
|
||||
obj.src = node.getAttribute("src") || "";
|
||||
}
|
||||
if (tag === "input" || tag === "textarea" || tag === "select") {
|
||||
obj.type = node.type || "text";
|
||||
obj.name = node.name || undefined;
|
||||
obj.value = node.value || undefined;
|
||||
obj.placeholder = node.placeholder || undefined;
|
||||
if (node.disabled) obj.disabled = true;
|
||||
if (tag === "input" && node.type === "checkbox") obj.checked = node.checked;
|
||||
}
|
||||
if (tag === "button") {
|
||||
if (node.disabled) obj.disabled = true;
|
||||
}
|
||||
if (tag === "option") {
|
||||
obj.value = node.value;
|
||||
if (node.selected) obj.selected = true;
|
||||
}
|
||||
if (tag === "label" && node.htmlFor) obj.for = node.htmlFor;
|
||||
|
||||
// Role / aria-label for native interactive & semantic elements
|
||||
if (!implicitInteractive) {
|
||||
const role = node.getAttribute("role");
|
||||
if (role) obj.role = role;
|
||||
const ariaLabel = node.getAttribute("aria-label");
|
||||
if (ariaLabel) obj.ariaLabel = ariaLabel;
|
||||
}
|
||||
|
||||
// Children
|
||||
if (children.length === 1 && typeof children[0] === "string") {
|
||||
obj.text = children[0];
|
||||
} else if (children.length > 0) {
|
||||
obj.children = children;
|
||||
}
|
||||
|
||||
return obj;
|
||||
}
|
||||
|
||||
const result = walk(document.body);
|
||||
window.__cowRefMap = refMap;
|
||||
return { tree: result, refCount: refCounter };
|
||||
}
|
||||
""" % (
|
||||
str(list(_KEEP_TAGS)),
|
||||
str(list(_INTERACTIVE_TAGS)),
|
||||
)
|
||||
|
||||
|
||||
_BROWSER_DEAD_HINTS = (
|
||||
"has been closed",
|
||||
"browser has disconnected",
|
||||
"target closed",
|
||||
"browser closed",
|
||||
"context or browser has been closed",
|
||||
)
|
||||
|
||||
|
||||
def _is_browser_dead_error(err: Exception) -> bool:
|
||||
"""Return True if *err* indicates the browser / page died out from under us."""
|
||||
msg = str(err).lower()
|
||||
return any(h in msg for h in _BROWSER_DEAD_HINTS)
|
||||
|
||||
|
||||
def _should_use_headless() -> bool:
|
||||
"""Decide headless mode: headless on Linux servers without display, headed elsewhere."""
|
||||
if sys.platform in ("win32", "darwin"):
|
||||
return False
|
||||
# Linux: check for display
|
||||
if os.environ.get("DISPLAY") or os.environ.get("WAYLAND_DISPLAY"):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _flatten_tree(node, indent=0) -> List[str]:
|
||||
"""Convert snapshot tree to compact text lines for LLM consumption."""
|
||||
if node is None:
|
||||
return []
|
||||
if isinstance(node, str):
|
||||
return [" " * indent + node]
|
||||
if isinstance(node, list):
|
||||
lines = []
|
||||
for child in node:
|
||||
lines.extend(_flatten_tree(child, indent))
|
||||
return lines
|
||||
if not isinstance(node, dict):
|
||||
return []
|
||||
|
||||
tag = node.get("tag", "?")
|
||||
ref = node.get("ref")
|
||||
parts = [tag]
|
||||
if ref:
|
||||
parts[0] = f"[{ref}] {tag}"
|
||||
|
||||
# Inline attributes
|
||||
for attr in ("type", "name", "href", "alt", "role", "ariaLabel", "placeholder", "value"):
|
||||
val = node.get(attr)
|
||||
if val:
|
||||
# Truncate long values
|
||||
s = str(val)
|
||||
if len(s) > 80:
|
||||
s = s[:77] + "..."
|
||||
parts.append(f'{attr}="{s}"')
|
||||
|
||||
for flag in ("disabled", "checked", "selected"):
|
||||
if node.get(flag):
|
||||
parts.append(flag)
|
||||
|
||||
prefix = " " * indent
|
||||
header = prefix + " ".join(parts)
|
||||
|
||||
text = node.get("text")
|
||||
if text:
|
||||
# Truncate long text
|
||||
if len(text) > 120:
|
||||
text = text[:117] + "..."
|
||||
header += f": {text}"
|
||||
|
||||
lines = [header]
|
||||
children = node.get("children", [])
|
||||
for child in children:
|
||||
lines.extend(_flatten_tree(child, indent + 2))
|
||||
return lines
|
||||
|
||||
|
||||
class BrowserService:
|
||||
"""Manages a Playwright browser on a dedicated background thread.
|
||||
|
||||
All Playwright operations are dispatched to a single long-lived thread via
|
||||
a task queue. Callers from *any* worker thread can use the public API
|
||||
safely. An idle timer automatically shuts the browser down after
|
||||
``idle_timeout`` seconds of inactivity (default 300 = 5 min).
|
||||
"""
|
||||
|
||||
_IDLE_TIMEOUT_DEFAULT = 300 # seconds
|
||||
|
||||
def __init__(self, config: Optional[Dict[str, Any]] = None):
|
||||
self._config = config or {}
|
||||
self._headless: Optional[bool] = None
|
||||
self._screenshot_dir: Optional[str] = None
|
||||
|
||||
# Background thread state
|
||||
self._thread: Optional[threading.Thread] = None
|
||||
self._task_queue: queue.Queue = queue.Queue()
|
||||
self._lock = threading.Lock()
|
||||
self._alive = False
|
||||
self._ready = threading.Event()
|
||||
|
||||
# Playwright objects (only accessed on the background thread)
|
||||
self._playwright = None
|
||||
self._browser = None
|
||||
self._context = None
|
||||
self._page = None
|
||||
|
||||
# Launch mode: one of "fresh" | "persistent" | "cdp".
|
||||
# - cdp: connect to an externally launched Chrome via CDP endpoint.
|
||||
# - persistent: launch with launch_persistent_context using a user_data_dir
|
||||
# so cookies / login state survive across runs (default).
|
||||
# - fresh: classic launch + new_context, clean state every run.
|
||||
cdp_endpoint = self._config.get("cdp_endpoint") or ""
|
||||
persistent_flag = self._config.get("persistent", True)
|
||||
user_data_dir_cfg = self._config.get("user_data_dir")
|
||||
if user_data_dir_cfg is None:
|
||||
user_data_dir_cfg = _DEFAULT_USER_DATA_DIR
|
||||
|
||||
self._cdp_endpoint: str = cdp_endpoint.strip() if isinstance(cdp_endpoint, str) else ""
|
||||
if self._cdp_endpoint:
|
||||
self._launch_mode = "cdp"
|
||||
self._user_data_dir: str = ""
|
||||
elif persistent_flag and user_data_dir_cfg:
|
||||
self._launch_mode = "persistent"
|
||||
self._user_data_dir = expand_path(str(user_data_dir_cfg))
|
||||
else:
|
||||
self._launch_mode = "fresh"
|
||||
self._user_data_dir = ""
|
||||
|
||||
# Idle auto-release
|
||||
idle_cfg = self._config.get("idle_timeout")
|
||||
self._idle_timeout: float = float(idle_cfg) if idle_cfg is not None else self._IDLE_TIMEOUT_DEFAULT
|
||||
self._idle_timer: Optional[threading.Timer] = None
|
||||
|
||||
# Set when the browser / page is detected to have died externally
|
||||
# (e.g. user manually closed the window). The next _submit() will then
|
||||
# tear down the stale thread and relaunch.
|
||||
self._needs_restart = False
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Background-thread lifecycle
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _start_thread(self):
|
||||
"""Start the dedicated Playwright thread if not already running."""
|
||||
with self._lock:
|
||||
if self._alive and self._thread and self._thread.is_alive():
|
||||
return
|
||||
# Wait for old thread to fully exit before creating a new one
|
||||
old = self._thread
|
||||
if old and old.is_alive():
|
||||
old.join(timeout=5)
|
||||
# Fresh queue to avoid stale sentinels from a previous close()
|
||||
self._task_queue = queue.Queue()
|
||||
self._alive = True
|
||||
self._ready = threading.Event()
|
||||
self._thread = threading.Thread(target=self._run_loop, daemon=True, name="BrowserThread")
|
||||
self._thread.start()
|
||||
# Block until browser is ready (or failed)
|
||||
self._ready.wait(timeout=30)
|
||||
|
||||
def _run_loop(self):
|
||||
"""Event loop running on the dedicated thread. Processes tasks until stopped."""
|
||||
logger.info("[Browser] Background thread started")
|
||||
try:
|
||||
self._launch_browser()
|
||||
except Exception as e:
|
||||
logger.error(f"[Browser] Failed to launch browser: {e}")
|
||||
self._alive = False
|
||||
self._ready.set()
|
||||
self._drain_queue(RuntimeError(f"Browser launch failed: {e}"))
|
||||
return
|
||||
self._ready.set()
|
||||
|
||||
while self._alive:
|
||||
try:
|
||||
task = self._task_queue.get(timeout=1.0)
|
||||
except queue.Empty:
|
||||
continue
|
||||
if task is None:
|
||||
break
|
||||
fn, args, kwargs, result_slot = task
|
||||
try:
|
||||
result_slot["value"] = fn(*args, **kwargs)
|
||||
except Exception as e:
|
||||
result_slot["error"] = e
|
||||
if _is_browser_dead_error(e):
|
||||
self._needs_restart = True
|
||||
logger.warning(
|
||||
f"[Browser] Detected closed page/context ({e}); "
|
||||
"will relaunch on next request."
|
||||
)
|
||||
finally:
|
||||
result_slot["event"].set()
|
||||
|
||||
self._shutdown_browser()
|
||||
self._drain_queue(RuntimeError("Browser thread stopped"))
|
||||
logger.info("[Browser] Background thread exited")
|
||||
|
||||
def _drain_queue(self, error: Exception):
|
||||
"""Unblock all callers waiting on the queue with an error."""
|
||||
while True:
|
||||
try:
|
||||
task = self._task_queue.get_nowait()
|
||||
except queue.Empty:
|
||||
break
|
||||
if task is None:
|
||||
continue
|
||||
_, _, _, result_slot = task
|
||||
result_slot["error"] = error
|
||||
result_slot["event"].set()
|
||||
|
||||
def _launch_browser(self):
|
||||
"""Launch / connect Chromium on the background thread."""
|
||||
if self._headless is None:
|
||||
headless_cfg = self._config.get("headless")
|
||||
self._headless = headless_cfg if headless_cfg is not None else _should_use_headless()
|
||||
|
||||
launch_args = ["--disable-dev-shm-usage"]
|
||||
if self._headless:
|
||||
launch_args.append("--no-sandbox")
|
||||
|
||||
if is_cloud_deployment():
|
||||
launch_args.extend([
|
||||
"--disable-gpu",
|
||||
"--disable-software-rasterizer",
|
||||
"--disable-extensions",
|
||||
"--disable-background-networking",
|
||||
"--disable-background-timer-throttling",
|
||||
"--disable-renderer-backgrounding",
|
||||
"--disable-features=site-per-process,TranslateUI,IsolateOrigins",
|
||||
"--no-zygote",
|
||||
"--js-flags=--max-old-space-size=384",
|
||||
"--memory-pressure-off",
|
||||
])
|
||||
|
||||
extra_args = self._config.get("launch_args", [])
|
||||
if extra_args:
|
||||
launch_args.extend(extra_args)
|
||||
|
||||
viewport_w = self._config.get("viewport_width", 1280)
|
||||
viewport_h = self._config.get("viewport_height", 720)
|
||||
viewport = {"width": viewport_w, "height": viewport_h}
|
||||
user_agent = (
|
||||
"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) "
|
||||
"AppleWebKit/537.36 (KHTML, like Gecko) "
|
||||
"Chrome/131.0.0.0 Safari/537.36"
|
||||
)
|
||||
|
||||
self._playwright = sync_playwright().start()
|
||||
|
||||
if self._launch_mode == "cdp":
|
||||
self._connect_cdp(viewport)
|
||||
elif self._launch_mode == "persistent":
|
||||
self._launch_persistent(launch_args, viewport, user_agent)
|
||||
else:
|
||||
self._launch_fresh(launch_args, viewport, user_agent)
|
||||
|
||||
logger.info("[Browser] Browser ready")
|
||||
|
||||
def _launch_fresh(self, launch_args: List[str], viewport: Dict[str, int], user_agent: str):
|
||||
"""Classic launch: brand new Chromium with an empty context."""
|
||||
logger.info(f"[Browser] Launching Chromium (fresh, headless={self._headless})")
|
||||
self._browser = self._playwright.chromium.launch(
|
||||
headless=self._headless,
|
||||
args=launch_args,
|
||||
)
|
||||
self._context = self._browser.new_context(
|
||||
viewport=viewport,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
self._page = self._context.new_page()
|
||||
self._wire_close_listeners()
|
||||
|
||||
def _launch_persistent(self, launch_args: List[str], viewport: Dict[str, int], user_agent: str):
|
||||
"""Launch Chromium with a persistent user_data_dir so login state survives."""
|
||||
os.makedirs(self._user_data_dir, exist_ok=True)
|
||||
logger.info(
|
||||
f"[Browser] Launching Chromium (persistent, headless={self._headless}, "
|
||||
f"profile={self._user_data_dir})"
|
||||
)
|
||||
try:
|
||||
self._context = self._playwright.chromium.launch_persistent_context(
|
||||
user_data_dir=self._user_data_dir,
|
||||
headless=self._headless,
|
||||
args=launch_args,
|
||||
viewport=viewport,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
except Exception as e:
|
||||
# Profile is locked when another Chromium instance already holds it.
|
||||
msg = str(e).lower()
|
||||
if "singletonlock" in msg or "profile" in msg or "lock" in msg:
|
||||
raise RuntimeError(
|
||||
f"Browser profile '{self._user_data_dir}' is in use by another process. "
|
||||
"Close the other Chromium / cow instance, or set a different "
|
||||
"tools.browser.user_data_dir."
|
||||
) from e
|
||||
raise
|
||||
|
||||
# Persistent context has no parent Browser handle; reuse the auto-created page.
|
||||
self._browser = None
|
||||
pages = self._context.pages
|
||||
self._page = pages[0] if pages else self._context.new_page()
|
||||
self._wire_close_listeners()
|
||||
|
||||
def _connect_cdp(self, viewport: Dict[str, int]):
|
||||
"""Attach to an existing Chrome started with --remote-debugging-port."""
|
||||
endpoint = self._cdp_endpoint
|
||||
logger.info(f"[Browser] Connecting to existing Chrome via CDP: {endpoint}")
|
||||
try:
|
||||
self._browser = self._playwright.chromium.connect_over_cdp(endpoint)
|
||||
except Exception as e:
|
||||
msg = str(e).lower()
|
||||
if "econnrefused" in msg or "connect" in msg or "refused" in msg:
|
||||
raise RuntimeError(
|
||||
f"Cannot reach Chrome at {endpoint}. The CDP browser is not "
|
||||
"running. Ask the user to launch Chrome with "
|
||||
"--remote-debugging-port and --user-data-dir, then retry. "
|
||||
"Do not retry this tool until the user confirms."
|
||||
) from e
|
||||
raise
|
||||
|
||||
contexts = self._browser.contexts
|
||||
if contexts:
|
||||
self._context = contexts[0]
|
||||
else:
|
||||
self._context = self._browser.new_context(viewport=viewport)
|
||||
|
||||
pages = self._context.pages
|
||||
self._page = pages[0] if pages else self._context.new_page()
|
||||
self._wire_close_listeners()
|
||||
|
||||
def _wire_close_listeners(self):
|
||||
"""Mark needs_restart whenever the browser / context / page dies externally."""
|
||||
def _on_dead(_obj=None):
|
||||
self._needs_restart = True
|
||||
|
||||
try:
|
||||
if self._browser:
|
||||
self._browser.on("disconnected", _on_dead)
|
||||
if self._context:
|
||||
self._context.on("close", _on_dead)
|
||||
if self._page:
|
||||
self._page.on("close", _on_dead)
|
||||
except Exception as e:
|
||||
logger.debug(f"[Browser] Failed to wire close listeners: {e}")
|
||||
|
||||
def _shutdown_browser(self):
|
||||
"""Shut down Playwright resources on the background thread.
|
||||
|
||||
Mode-specific behavior:
|
||||
- cdp: only disconnect the Playwright client; leave the user's Chrome
|
||||
and its tabs untouched (do NOT close the context).
|
||||
- persistent: close the persistent context (no separate browser handle).
|
||||
- fresh: close context, then browser.
|
||||
"""
|
||||
self._cancel_idle_timer()
|
||||
|
||||
if self._launch_mode == "cdp":
|
||||
# For CDP, browser.close() only detaches the Playwright client;
|
||||
# the user's Chrome process and its tabs stay alive.
|
||||
try:
|
||||
if self._browser:
|
||||
self._browser.close()
|
||||
except Exception as e:
|
||||
logger.debug(f"[Browser] cdp disconnect error: {e}")
|
||||
else:
|
||||
for obj, label in [
|
||||
(self._context, "context"),
|
||||
(self._browser, "browser"),
|
||||
]:
|
||||
try:
|
||||
if obj:
|
||||
obj.close()
|
||||
except Exception as e:
|
||||
logger.debug(f"[Browser] {label} close error: {e}")
|
||||
|
||||
try:
|
||||
if self._playwright:
|
||||
self._playwright.stop()
|
||||
except Exception as e:
|
||||
logger.debug(f"[Browser] playwright stop error: {e}")
|
||||
self._page = None
|
||||
self._context = None
|
||||
self._browser = None
|
||||
self._playwright = None
|
||||
logger.info("[Browser] Browser closed")
|
||||
|
||||
def _submit(self, fn: Callable, *args, **kwargs):
|
||||
"""Submit *fn* to the background thread and block until it completes."""
|
||||
# If the browser died externally (e.g. user closed the window), tear
|
||||
# down the stale thread first so _start_thread() will relaunch fresh.
|
||||
if self._needs_restart:
|
||||
logger.info("[Browser] Restarting after detecting closed browser")
|
||||
self.close()
|
||||
self._needs_restart = False
|
||||
|
||||
self._start_thread()
|
||||
|
||||
if not self._alive:
|
||||
raise RuntimeError("Browser is not available")
|
||||
|
||||
self._reset_idle_timer()
|
||||
|
||||
result_slot: Dict[str, Any] = {"event": threading.Event()}
|
||||
self._task_queue.put((fn, args, kwargs, result_slot))
|
||||
|
||||
# Timeout prevents permanent hang if the background thread crashes
|
||||
completed = result_slot["event"].wait(timeout=120)
|
||||
if not completed:
|
||||
raise TimeoutError("Browser operation timed out (120s)")
|
||||
|
||||
if "error" in result_slot:
|
||||
raise result_slot["error"]
|
||||
return result_slot.get("value")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Idle auto-release
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _reset_idle_timer(self):
|
||||
self._cancel_idle_timer()
|
||||
if self._idle_timeout > 0:
|
||||
self._idle_timer = threading.Timer(self._idle_timeout, self._on_idle_timeout)
|
||||
self._idle_timer.daemon = True
|
||||
self._idle_timer.start()
|
||||
|
||||
def _cancel_idle_timer(self):
|
||||
if self._idle_timer:
|
||||
self._idle_timer.cancel()
|
||||
self._idle_timer = None
|
||||
|
||||
def _on_idle_timeout(self):
|
||||
logger.info(f"[Browser] Idle for {self._idle_timeout}s, auto-releasing browser")
|
||||
self.close()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Public lifecycle
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def close(self):
|
||||
"""Shut down browser and background thread (safe from any thread)."""
|
||||
self._cancel_idle_timer()
|
||||
with self._lock:
|
||||
if not self._alive:
|
||||
self._needs_restart = False
|
||||
return
|
||||
self._alive = False
|
||||
t = self._thread
|
||||
if self._task_queue is not None:
|
||||
self._task_queue.put(None)
|
||||
if t is not None and t.is_alive():
|
||||
t.join(timeout=10)
|
||||
with self._lock:
|
||||
self._thread = None
|
||||
self._needs_restart = False
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Actions (each method is dispatched to the background thread)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def navigate(self, url: str, timeout: int = 30000) -> Dict[str, Any]:
|
||||
return self._submit(self._do_navigate, url, timeout)
|
||||
|
||||
def _do_navigate(self, url: str, timeout: int) -> Dict[str, Any]:
|
||||
page = self._page
|
||||
try:
|
||||
resp = page.goto(url, wait_until="domcontentloaded", timeout=timeout)
|
||||
status = resp.status if resp else None
|
||||
except Exception as e:
|
||||
return {"error": f"Navigation failed: {e}"}
|
||||
|
||||
try:
|
||||
page.wait_for_load_state("networkidle", timeout=8000)
|
||||
except Exception:
|
||||
pass
|
||||
page.wait_for_timeout(500)
|
||||
|
||||
try:
|
||||
title = page.title()
|
||||
except Exception:
|
||||
title = ""
|
||||
try:
|
||||
current_url = page.url
|
||||
except Exception:
|
||||
current_url = url
|
||||
|
||||
return {"url": current_url, "title": title, "status": status}
|
||||
|
||||
def snapshot(self, selector: Optional[str] = None) -> str:
|
||||
return self._submit(self._do_snapshot, selector)
|
||||
|
||||
def _do_snapshot(self, selector: Optional[str] = None) -> str:
|
||||
page = self._page
|
||||
try:
|
||||
result = page.evaluate(_SNAPSHOT_JS)
|
||||
except Exception as e:
|
||||
return f"[Snapshot error: {e}]"
|
||||
|
||||
tree = result.get("tree")
|
||||
ref_count = result.get("refCount", 0)
|
||||
lines = _flatten_tree(tree)
|
||||
|
||||
try:
|
||||
title = page.title()
|
||||
except Exception:
|
||||
title = ""
|
||||
try:
|
||||
url = page.url
|
||||
except Exception:
|
||||
url = ""
|
||||
|
||||
header = f"Page: {title} ({url})\nInteractive elements: {ref_count}\n---"
|
||||
body = "\n".join(lines)
|
||||
|
||||
max_chars = self._config.get("snapshot_max_chars", 30000)
|
||||
if len(body) > max_chars:
|
||||
body = body[:max_chars] + "\n... [snapshot truncated]"
|
||||
|
||||
return f"{header}\n{body}"
|
||||
|
||||
def screenshot(self, full_page: bool = False, cwd: str = "") -> str:
|
||||
return self._submit(self._do_screenshot, full_page, cwd)
|
||||
|
||||
def _do_screenshot(self, full_page: bool = False, cwd: str = "") -> str:
|
||||
page = self._page
|
||||
save_dir = self._get_screenshot_dir(cwd)
|
||||
filename = f"screenshot_{uuid.uuid4().hex[:8]}.png"
|
||||
filepath = os.path.join(save_dir, filename)
|
||||
page.screenshot(path=filepath, full_page=full_page)
|
||||
logger.info(f"[Browser] Screenshot saved: {filepath}")
|
||||
return filepath
|
||||
|
||||
def click(self, ref: Optional[int] = None, selector: Optional[str] = None,
|
||||
timeout: int = 5000) -> Dict[str, Any]:
|
||||
return self._submit(self._do_click, ref, selector, timeout)
|
||||
|
||||
def _do_click(self, ref, selector, timeout) -> Dict[str, Any]:
|
||||
page = self._page
|
||||
try:
|
||||
if ref is not None:
|
||||
result = page.evaluate(f"""
|
||||
() => {{
|
||||
const el = window.__cowRefMap && window.__cowRefMap[{ref}];
|
||||
if (!el) return {{ error: "ref {ref} not found. Run snapshot first." }};
|
||||
el.click();
|
||||
return {{ clicked: true, tag: el.tagName.toLowerCase() }};
|
||||
}}
|
||||
""")
|
||||
if result.get("error"):
|
||||
return result
|
||||
page.wait_for_timeout(500)
|
||||
return result
|
||||
elif selector:
|
||||
page.click(selector, timeout=timeout)
|
||||
return {"clicked": True, "selector": selector}
|
||||
else:
|
||||
return {"error": "Provide either ref (from snapshot) or selector"}
|
||||
except Exception as e:
|
||||
return {"error": f"Click failed: {e}"}
|
||||
|
||||
def fill(self, text: str, ref: Optional[int] = None,
|
||||
selector: Optional[str] = None, timeout: int = 5000) -> Dict[str, Any]:
|
||||
return self._submit(self._do_fill, text, ref, selector, timeout)
|
||||
|
||||
def _do_fill(self, text, ref, selector, timeout) -> Dict[str, Any]:
|
||||
page = self._page
|
||||
try:
|
||||
if ref is not None:
|
||||
result = page.evaluate(f"""
|
||||
() => {{
|
||||
const el = window.__cowRefMap && window.__cowRefMap[{ref}];
|
||||
if (!el) return {{ error: "ref {ref} not found. Run snapshot first." }};
|
||||
el.focus();
|
||||
el.value = "";
|
||||
return {{ tag: el.tagName.toLowerCase(), name: el.name || "" }};
|
||||
}}
|
||||
""")
|
||||
if result.get("error"):
|
||||
return result
|
||||
page.keyboard.type(text)
|
||||
return {"filled": True, "ref": ref, "text": text}
|
||||
elif selector:
|
||||
page.fill(selector, text, timeout=timeout)
|
||||
return {"filled": True, "selector": selector, "text": text}
|
||||
else:
|
||||
return {"error": "Provide either ref (from snapshot) or selector"}
|
||||
except Exception as e:
|
||||
return {"error": f"Fill failed: {e}"}
|
||||
|
||||
def select(self, value: str, ref: Optional[int] = None,
|
||||
selector: Optional[str] = None, timeout: int = 5000) -> Dict[str, Any]:
|
||||
return self._submit(self._do_select, value, ref, selector, timeout)
|
||||
|
||||
def _do_select(self, value, ref, selector, timeout) -> Dict[str, Any]:
|
||||
page = self._page
|
||||
try:
|
||||
if ref is not None:
|
||||
result = page.evaluate(f"""
|
||||
() => {{
|
||||
const el = window.__cowRefMap && window.__cowRefMap[{ref}];
|
||||
if (!el || el.tagName.toLowerCase() !== "select")
|
||||
return {{ error: "ref {ref} is not a <select> element" }};
|
||||
el.value = {repr(value)};
|
||||
el.dispatchEvent(new Event("change", {{ bubbles: true }}));
|
||||
return {{ selected: true, value: el.value }};
|
||||
}}
|
||||
""")
|
||||
return result
|
||||
elif selector:
|
||||
page.select_option(selector, value, timeout=timeout)
|
||||
return {"selected": True, "selector": selector, "value": value}
|
||||
else:
|
||||
return {"error": "Provide either ref (from snapshot) or selector"}
|
||||
except Exception as e:
|
||||
return {"error": f"Select failed: {e}"}
|
||||
|
||||
def scroll(self, direction: str = "down", amount: int = 500) -> Dict[str, Any]:
|
||||
return self._submit(self._do_scroll, direction, amount)
|
||||
|
||||
def _do_scroll(self, direction, amount) -> Dict[str, Any]:
|
||||
page = self._page
|
||||
delta_map = {
|
||||
"down": (0, amount),
|
||||
"up": (0, -amount),
|
||||
"right": (amount, 0),
|
||||
"left": (-amount, 0),
|
||||
}
|
||||
dx, dy = delta_map.get(direction, (0, amount))
|
||||
try:
|
||||
page.mouse.wheel(dx, dy)
|
||||
page.wait_for_timeout(300)
|
||||
scroll_info = page.evaluate("""
|
||||
() => ({
|
||||
scrollX: window.scrollX,
|
||||
scrollY: window.scrollY,
|
||||
scrollHeight: document.documentElement.scrollHeight,
|
||||
clientHeight: document.documentElement.clientHeight
|
||||
})
|
||||
""")
|
||||
return {"scrolled": direction, "amount": amount, **scroll_info}
|
||||
except Exception as e:
|
||||
return {"error": f"Scroll failed: {e}"}
|
||||
|
||||
def wait(self, selector: Optional[str] = None, timeout: int = 5000,
|
||||
state: str = "visible") -> Dict[str, Any]:
|
||||
return self._submit(self._do_wait, selector, timeout, state)
|
||||
|
||||
def _do_wait(self, selector, timeout, state) -> Dict[str, Any]:
|
||||
page = self._page
|
||||
try:
|
||||
if selector:
|
||||
page.wait_for_selector(selector, timeout=timeout, state=state)
|
||||
return {"waited": True, "selector": selector, "state": state}
|
||||
else:
|
||||
page.wait_for_timeout(timeout)
|
||||
return {"waited": True, "timeout_ms": timeout}
|
||||
except Exception as e:
|
||||
return {"error": f"Wait failed: {e}"}
|
||||
|
||||
def go_back(self) -> Dict[str, Any]:
|
||||
return self._submit(self._do_go_back)
|
||||
|
||||
def _do_go_back(self) -> Dict[str, Any]:
|
||||
page = self._page
|
||||
try:
|
||||
page.go_back(wait_until="domcontentloaded", timeout=10000)
|
||||
try:
|
||||
title = page.title()
|
||||
except Exception:
|
||||
title = ""
|
||||
try:
|
||||
url = page.url
|
||||
except Exception:
|
||||
url = ""
|
||||
return {"url": url, "title": title}
|
||||
except Exception as e:
|
||||
return {"error": f"Go back failed: {e}"}
|
||||
|
||||
def go_forward(self) -> Dict[str, Any]:
|
||||
return self._submit(self._do_go_forward)
|
||||
|
||||
def _do_go_forward(self) -> Dict[str, Any]:
|
||||
page = self._page
|
||||
try:
|
||||
page.go_forward(wait_until="domcontentloaded", timeout=10000)
|
||||
try:
|
||||
title = page.title()
|
||||
except Exception:
|
||||
title = ""
|
||||
try:
|
||||
url = page.url
|
||||
except Exception:
|
||||
url = ""
|
||||
return {"url": url, "title": title}
|
||||
except Exception as e:
|
||||
return {"error": f"Go forward failed: {e}"}
|
||||
|
||||
def get_text(self, selector: str) -> Dict[str, Any]:
|
||||
return self._submit(self._do_get_text, selector)
|
||||
|
||||
def _do_get_text(self, selector) -> Dict[str, Any]:
|
||||
page = self._page
|
||||
try:
|
||||
text = page.text_content(selector, timeout=5000)
|
||||
return {"text": text or ""}
|
||||
except Exception as e:
|
||||
return {"error": f"Get text failed: {e}"}
|
||||
|
||||
def evaluate(self, script: str) -> Dict[str, Any]:
|
||||
return self._submit(self._do_evaluate, script)
|
||||
|
||||
def _do_evaluate(self, script) -> Dict[str, Any]:
|
||||
page = self._page
|
||||
try:
|
||||
result = page.evaluate(script)
|
||||
return {"result": result}
|
||||
except Exception as e:
|
||||
return {"error": f"Evaluate failed: {e}"}
|
||||
|
||||
def press(self, key: str) -> Dict[str, Any]:
|
||||
return self._submit(self._do_press, key)
|
||||
|
||||
def _do_press(self, key) -> Dict[str, Any]:
|
||||
page = self._page
|
||||
try:
|
||||
page.keyboard.press(key)
|
||||
page.wait_for_timeout(300)
|
||||
return {"pressed": key}
|
||||
except Exception as e:
|
||||
return {"error": f"Press failed: {e}"}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _get_screenshot_dir(self, cwd: str = "") -> str:
|
||||
if self._screenshot_dir and os.path.isdir(self._screenshot_dir):
|
||||
return self._screenshot_dir
|
||||
base = cwd or os.getcwd()
|
||||
d = os.path.join(base, "tmp")
|
||||
os.makedirs(d, exist_ok=True)
|
||||
self._screenshot_dir = d
|
||||
return d
|
||||
303
agent/tools/browser/browser_tool.py
Normal file
303
agent/tools/browser/browser_tool.py
Normal file
@@ -0,0 +1,303 @@
|
||||
"""
|
||||
Browser tool - Control a Chromium browser for web navigation and interaction.
|
||||
|
||||
Uses Playwright under the hood. Browser instance is lazily started on first
|
||||
use, reused across tool calls within the same session, and cleaned up via
|
||||
close().
|
||||
|
||||
Launch modes (configured under `tools.browser` in config.json):
|
||||
- persistent (default): Chromium runs with a persistent user_data_dir
|
||||
(default `~/.cow/browser_profile`), so cookies and login state survive
|
||||
across runs. The user only needs to log in once.
|
||||
- cdp: When `cdp_endpoint` is set, attach to an externally launched Chrome
|
||||
via the Chrome DevTools Protocol. Lets the agent reuse the user's real
|
||||
browser (with all logins / extensions / true fingerprints).
|
||||
- fresh: Set `persistent` to false to fall back to a clean context every run.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
from typing import Dict, Any, Optional
|
||||
|
||||
from agent.tools.base_tool import BaseTool, ToolResult
|
||||
from agent.tools.browser.browser_service import BrowserService
|
||||
from common.log import logger
|
||||
|
||||
|
||||
class BrowserTool(BaseTool):
|
||||
"""Single tool exposing all browser actions via an 'action' parameter."""
|
||||
|
||||
name: str = "browser"
|
||||
description: str = (
|
||||
"Control a browser to navigate web pages, interact with elements, and extract content. "
|
||||
"Actions: navigate, snapshot, click, fill, select, scroll, screenshot, wait, back, forward, "
|
||||
"get_text, press, evaluate.\n\n"
|
||||
"Workflow: navigate (auto-includes snapshot with element refs) → click/fill/select by ref → snapshot to verify.\n\n"
|
||||
"Use snapshot as the primary way to read pages. Use screenshot + send to show key results to the user. "
|
||||
"For login/CAPTCHA/authorization etc., screenshot and ask the user for help. "
|
||||
"Login state is persisted across sessions (cookies / localStorage are kept in a "
|
||||
"user profile directory), so once the user logs in to a site, the agent can keep "
|
||||
"using it without logging in again."
|
||||
)
|
||||
|
||||
params: dict = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"action": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"The browser action to perform. One of: "
|
||||
"navigate, snapshot, click, fill, select, scroll, "
|
||||
"screenshot, wait, back, forward, get_text, press, evaluate"
|
||||
),
|
||||
"enum": [
|
||||
"navigate", "snapshot", "click", "fill", "select", "scroll",
|
||||
"screenshot", "wait", "back", "forward", "get_text", "press",
|
||||
"evaluate"
|
||||
]
|
||||
},
|
||||
"url": {
|
||||
"type": "string",
|
||||
"description": "URL to navigate to (for 'navigate' action)"
|
||||
},
|
||||
"ref": {
|
||||
"type": "integer",
|
||||
"description": "Element ref number from snapshot (for click/fill/select)"
|
||||
},
|
||||
"selector": {
|
||||
"type": "string",
|
||||
"description": "CSS selector as fallback when ref is unavailable (for click/fill/select/wait/get_text)"
|
||||
},
|
||||
"text": {
|
||||
"type": "string",
|
||||
"description": "Text to type (for 'fill' action)"
|
||||
},
|
||||
"value": {
|
||||
"type": "string",
|
||||
"description": "Option value (for 'select' action)"
|
||||
},
|
||||
"key": {
|
||||
"type": "string",
|
||||
"description": "Key to press, e.g. Enter, Tab, Escape (for 'press' action)"
|
||||
},
|
||||
"direction": {
|
||||
"type": "string",
|
||||
"description": "Scroll direction: up, down, left, right (for 'scroll' action, default: down)"
|
||||
},
|
||||
"script": {
|
||||
"type": "string",
|
||||
"description": "JavaScript code to execute (for 'evaluate' action)"
|
||||
},
|
||||
"full_page": {
|
||||
"type": "boolean",
|
||||
"description": "Capture full page screenshot (for 'screenshot' action, default: false)"
|
||||
},
|
||||
"timeout": {
|
||||
"type": "integer",
|
||||
"description": "Timeout in milliseconds (optional, default varies by action)"
|
||||
}
|
||||
},
|
||||
"required": ["action"]
|
||||
}
|
||||
|
||||
_shared_service: Optional[BrowserService] = None
|
||||
|
||||
def __init__(self, config: dict = None):
|
||||
self.config = config or {}
|
||||
self.cwd = self.config.get("cwd", os.getcwd())
|
||||
self._service: Optional[BrowserService] = None
|
||||
|
||||
def _get_service(self) -> BrowserService:
|
||||
"""Get or create the browser service, sharing across copies."""
|
||||
if self._service is not None:
|
||||
return self._service
|
||||
|
||||
# Reuse shared service across tool copies within the same session
|
||||
if BrowserTool._shared_service is not None:
|
||||
self._service = BrowserTool._shared_service
|
||||
return self._service
|
||||
|
||||
self._service = BrowserService(self.config)
|
||||
BrowserTool._shared_service = self._service
|
||||
return self._service
|
||||
|
||||
def execute(self, args: Dict[str, Any]) -> ToolResult:
|
||||
action = args.get("action", "").strip().lower()
|
||||
if not action:
|
||||
return ToolResult.fail("Error: 'action' parameter is required")
|
||||
|
||||
handler = self._ACTION_MAP.get(action)
|
||||
if not handler:
|
||||
valid = ", ".join(sorted(self._ACTION_MAP.keys()))
|
||||
return ToolResult.fail(f"Unknown action '{action}'. Valid actions: {valid}")
|
||||
|
||||
try:
|
||||
return handler(self, args)
|
||||
except Exception as e:
|
||||
logger.error(f"[Browser] Action '{action}' error: {e}")
|
||||
return ToolResult.fail(f"Browser error ({action}): {e}")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Action handlers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _do_navigate(self, args: Dict[str, Any]) -> ToolResult:
|
||||
url = args.get("url", "").strip()
|
||||
if not url:
|
||||
return ToolResult.fail("Error: 'url' is required for navigate action")
|
||||
# Only auto-prepend https:// for bare hosts; preserve file://, about:, data:, etc.
|
||||
if "://" not in url and not url.startswith(("about:", "data:")):
|
||||
url = "https://" + url
|
||||
timeout = args.get("timeout", 30000)
|
||||
service = self._get_service()
|
||||
result = service.navigate(url, timeout=timeout)
|
||||
if "error" in result:
|
||||
return ToolResult.fail(result["error"])
|
||||
# Auto-snapshot after navigation so the agent gets page content in one call
|
||||
snapshot_text = service.snapshot()
|
||||
return ToolResult.success(
|
||||
f"Navigated to: {result['url']}\nTitle: {result['title']}\nStatus: {result['status']}\n\n"
|
||||
f"--- Page Snapshot ---\n{snapshot_text}"
|
||||
)
|
||||
|
||||
def _do_snapshot(self, args: Dict[str, Any]) -> ToolResult:
|
||||
selector = args.get("selector")
|
||||
text = self._get_service().snapshot(selector=selector)
|
||||
return ToolResult.success(text)
|
||||
|
||||
def _do_click(self, args: Dict[str, Any]) -> ToolResult:
|
||||
ref = args.get("ref")
|
||||
selector = args.get("selector")
|
||||
timeout = args.get("timeout", 5000)
|
||||
result = self._get_service().click(ref=ref, selector=selector, timeout=timeout)
|
||||
if "error" in result:
|
||||
return ToolResult.fail(result["error"])
|
||||
return ToolResult.success(f"Clicked successfully. Use 'snapshot' to see updated page.")
|
||||
|
||||
def _do_fill(self, args: Dict[str, Any]) -> ToolResult:
|
||||
text = args.get("text", "")
|
||||
ref = args.get("ref")
|
||||
selector = args.get("selector")
|
||||
timeout = args.get("timeout", 5000)
|
||||
if not text and text != "":
|
||||
return ToolResult.fail("Error: 'text' is required for fill action")
|
||||
result = self._get_service().fill(text, ref=ref, selector=selector, timeout=timeout)
|
||||
if "error" in result:
|
||||
return ToolResult.fail(result["error"])
|
||||
return ToolResult.success(f"Filled text into element. Use 'snapshot' to verify.")
|
||||
|
||||
def _do_select(self, args: Dict[str, Any]) -> ToolResult:
|
||||
value = args.get("value", "")
|
||||
ref = args.get("ref")
|
||||
selector = args.get("selector")
|
||||
timeout = args.get("timeout", 5000)
|
||||
if not value:
|
||||
return ToolResult.fail("Error: 'value' is required for select action")
|
||||
result = self._get_service().select(value, ref=ref, selector=selector, timeout=timeout)
|
||||
if "error" in result:
|
||||
return ToolResult.fail(result["error"])
|
||||
return ToolResult.success(f"Selected option '{value}'.")
|
||||
|
||||
def _do_scroll(self, args: Dict[str, Any]) -> ToolResult:
|
||||
direction = args.get("direction", "down")
|
||||
amount = args.get("timeout", 500) # reuse timeout field or default
|
||||
if "amount" in args:
|
||||
amount = args["amount"]
|
||||
result = self._get_service().scroll(direction=direction, amount=amount)
|
||||
if "error" in result:
|
||||
return ToolResult.fail(result["error"])
|
||||
pos = f"scrollY={result.get('scrollY', '?')}/{result.get('scrollHeight', '?')}"
|
||||
return ToolResult.success(f"Scrolled {direction}. Position: {pos}")
|
||||
|
||||
def _do_screenshot(self, args: Dict[str, Any]) -> ToolResult:
|
||||
full_page = args.get("full_page", False)
|
||||
filepath = self._get_service().screenshot(full_page=full_page, cwd=self.cwd)
|
||||
return ToolResult.success(f"Screenshot saved to: {filepath}")
|
||||
|
||||
def _do_wait(self, args: Dict[str, Any]) -> ToolResult:
|
||||
selector = args.get("selector")
|
||||
timeout = args.get("timeout", 5000)
|
||||
result = self._get_service().wait(selector=selector, timeout=timeout)
|
||||
if "error" in result:
|
||||
return ToolResult.fail(result["error"])
|
||||
return ToolResult.success(f"Wait completed.")
|
||||
|
||||
def _do_back(self, args: Dict[str, Any]) -> ToolResult:
|
||||
result = self._get_service().go_back()
|
||||
if "error" in result:
|
||||
return ToolResult.fail(result["error"])
|
||||
return ToolResult.success(f"Navigated back to: {result['url']}")
|
||||
|
||||
def _do_forward(self, args: Dict[str, Any]) -> ToolResult:
|
||||
result = self._get_service().go_forward()
|
||||
if "error" in result:
|
||||
return ToolResult.fail(result["error"])
|
||||
return ToolResult.success(f"Navigated forward to: {result['url']}")
|
||||
|
||||
def _do_get_text(self, args: Dict[str, Any]) -> ToolResult:
|
||||
selector = args.get("selector", "").strip()
|
||||
if not selector:
|
||||
return ToolResult.fail("Error: 'selector' is required for get_text action")
|
||||
result = self._get_service().get_text(selector)
|
||||
if "error" in result:
|
||||
return ToolResult.fail(result["error"])
|
||||
return ToolResult.success(result["text"])
|
||||
|
||||
def _do_press(self, args: Dict[str, Any]) -> ToolResult:
|
||||
key = args.get("key", "").strip()
|
||||
if not key:
|
||||
return ToolResult.fail("Error: 'key' is required for press action")
|
||||
result = self._get_service().press(key)
|
||||
if "error" in result:
|
||||
return ToolResult.fail(result["error"])
|
||||
return ToolResult.success(f"Pressed key: {key}")
|
||||
|
||||
def _do_evaluate(self, args: Dict[str, Any]) -> ToolResult:
|
||||
script = args.get("script", "").strip()
|
||||
if not script:
|
||||
return ToolResult.fail("Error: 'script' is required for evaluate action")
|
||||
result = self._get_service().evaluate(script)
|
||||
if "error" in result:
|
||||
return ToolResult.fail(result["error"])
|
||||
val = result.get("result")
|
||||
if isinstance(val, (dict, list)):
|
||||
return ToolResult.success(json.dumps(val, ensure_ascii=False, indent=2))
|
||||
return ToolResult.success(str(val) if val is not None else "(no return value)")
|
||||
|
||||
# Action dispatch table
|
||||
_ACTION_MAP = {
|
||||
"navigate": _do_navigate,
|
||||
"snapshot": _do_snapshot,
|
||||
"click": _do_click,
|
||||
"fill": _do_fill,
|
||||
"select": _do_select,
|
||||
"scroll": _do_scroll,
|
||||
"screenshot": _do_screenshot,
|
||||
"wait": _do_wait,
|
||||
"back": _do_back,
|
||||
"forward": _do_forward,
|
||||
"get_text": _do_get_text,
|
||||
"press": _do_press,
|
||||
"evaluate": _do_evaluate,
|
||||
}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Lifecycle
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def copy(self):
|
||||
"""Share browser instance across tool copies (avoids re-launching)."""
|
||||
new_tool = BrowserTool(self.config)
|
||||
new_tool.model = self.model
|
||||
new_tool.context = getattr(self, "context", None)
|
||||
new_tool.cwd = self.cwd
|
||||
new_tool._service = self._service
|
||||
return new_tool
|
||||
|
||||
def close(self):
|
||||
"""Release browser resources."""
|
||||
if self._service:
|
||||
self._service.close()
|
||||
self._service = None
|
||||
BrowserTool._shared_service = None
|
||||
logger.info("[Browser] BrowserTool closed")
|
||||
@@ -1,18 +0,0 @@
|
||||
def copy(self):
|
||||
"""
|
||||
Special copy method for browser tool to avoid recreating browser instance.
|
||||
|
||||
:return: A new instance with shared browser reference but unique model
|
||||
"""
|
||||
new_tool = self.__class__()
|
||||
|
||||
# Copy essential attributes
|
||||
new_tool.model = self.model
|
||||
new_tool.context = getattr(self, 'context', None)
|
||||
new_tool.config = getattr(self, 'config', None)
|
||||
|
||||
# Share the browser instance instead of creating a new one
|
||||
if hasattr(self, 'browser'):
|
||||
new_tool.browser = self.browser
|
||||
|
||||
return new_tool
|
||||
@@ -7,6 +7,7 @@ import os
|
||||
from typing import Dict, Any
|
||||
|
||||
from agent.tools.base_tool import BaseTool, ToolResult
|
||||
from common.utils import expand_path
|
||||
from agent.tools.utils.diff import (
|
||||
strip_bom,
|
||||
detect_line_ending,
|
||||
@@ -178,7 +179,7 @@ class Edit(BaseTool):
|
||||
:return: Absolute path
|
||||
"""
|
||||
# Expand ~ to user home directory
|
||||
path = os.path.expanduser(path)
|
||||
path = expand_path(path)
|
||||
if os.path.isabs(path):
|
||||
return path
|
||||
return os.path.abspath(os.path.join(self.cwd, path))
|
||||
|
||||
@@ -9,6 +9,7 @@ from pathlib import Path
|
||||
|
||||
from agent.tools.base_tool import BaseTool, ToolResult
|
||||
from common.log import logger
|
||||
from common.utils import expand_path
|
||||
|
||||
|
||||
# API Key 知识库:常见的环境变量及其描述
|
||||
@@ -66,7 +67,7 @@ class EnvConfig(BaseTool):
|
||||
def __init__(self, config: dict = None):
|
||||
self.config = config or {}
|
||||
# Store env config in ~/.cow directory (outside workspace for security)
|
||||
self.env_dir = os.path.expanduser("~/.cow")
|
||||
self.env_dir = expand_path("~/.cow")
|
||||
self.env_path = os.path.join(self.env_dir, '.env')
|
||||
self.agent_bridge = self.config.get("agent_bridge") # Reference to AgentBridge for hot reload
|
||||
# Don't create .env file in __init__ to avoid issues during tool discovery
|
||||
@@ -201,7 +202,8 @@ class EnvConfig(BaseTool):
|
||||
"key": key,
|
||||
"value": self._mask_value(value),
|
||||
"description": description,
|
||||
"exists": True
|
||||
"exists": True,
|
||||
"note": f"Value is masked for security. In bash, use ${key} directly — it is auto-injected."
|
||||
})
|
||||
else:
|
||||
return ToolResult.success({
|
||||
|
||||
@@ -7,6 +7,7 @@ from typing import Dict, Any
|
||||
|
||||
from agent.tools.base_tool import BaseTool, ToolResult
|
||||
from agent.tools.utils.truncate import truncate_head, format_size, DEFAULT_MAX_BYTES
|
||||
from common.utils import expand_path
|
||||
|
||||
|
||||
DEFAULT_LIMIT = 500
|
||||
@@ -51,7 +52,7 @@ class Ls(BaseTool):
|
||||
absolute_path = self._resolve_path(path)
|
||||
|
||||
# Security check: Prevent accessing sensitive config directory
|
||||
env_config_dir = os.path.expanduser("~/.cow")
|
||||
env_config_dir = expand_path("~/.cow")
|
||||
if os.path.abspath(absolute_path) == os.path.abspath(env_config_dir):
|
||||
return ToolResult.fail(
|
||||
"Error: Access denied. API keys and credentials must be accessed through the env_config tool only."
|
||||
@@ -93,7 +94,7 @@ class Ls(BaseTool):
|
||||
results.append(entry + '/')
|
||||
else:
|
||||
results.append(entry)
|
||||
except:
|
||||
except Exception:
|
||||
# Skip entries we can't stat
|
||||
continue
|
||||
|
||||
@@ -133,7 +134,7 @@ class Ls(BaseTool):
|
||||
def _resolve_path(self, path: str) -> str:
|
||||
"""Resolve path to absolute path"""
|
||||
# Expand ~ to user home directory
|
||||
path = os.path.expanduser(path)
|
||||
path = expand_path(path)
|
||||
if os.path.isabs(path):
|
||||
return path
|
||||
return os.path.abspath(os.path.join(self.cwd, path))
|
||||
|
||||
4
agent/tools/mcp/__init__.py
Normal file
4
agent/tools/mcp/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from agent.tools.mcp.mcp_client import McpClient, McpClientRegistry
|
||||
from agent.tools.mcp.mcp_tool import McpTool
|
||||
|
||||
__all__ = ["McpClient", "McpClientRegistry", "McpTool"]
|
||||
528
agent/tools/mcp/mcp_client.py
Normal file
528
agent/tools/mcp/mcp_client.py
Normal file
@@ -0,0 +1,528 @@
|
||||
"""
|
||||
MCP (Model Context Protocol) client module.
|
||||
|
||||
Implements JSON-RPC 2.0 over stdio, SSE and Streamable HTTP transports
|
||||
without any external MCP SDK dependency.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import select
|
||||
import subprocess
|
||||
import threading
|
||||
import urllib.request
|
||||
import urllib.error
|
||||
from typing import Optional
|
||||
|
||||
from common.log import logger
|
||||
|
||||
|
||||
# Aliases accepted for the Streamable HTTP transport type
|
||||
_STREAMABLE_HTTP_ALIASES = {"streamable-http", "streamable_http", "streamablehttp", "http"}
|
||||
|
||||
|
||||
class McpClient:
|
||||
"""Single MCP Server client supporting stdio, SSE and Streamable HTTP transports."""
|
||||
|
||||
def __init__(self, config: dict):
|
||||
"""
|
||||
config examples:
|
||||
stdio: {"name": "filesystem", "type": "stdio", "command": "npx", "args": [...]}
|
||||
SSE: {"name": "my-api", "type": "sse", "url": "http://localhost:8000/sse"}
|
||||
streamable-http: {"name": "pubmed", "type": "streamable-http", "url": "https://x/mcp"}
|
||||
"""
|
||||
self.config = config
|
||||
self.name: str = config.get("name", "unknown")
|
||||
raw_transport: str = config.get("type", "stdio")
|
||||
# Normalize streamable-http aliases to a single internal key
|
||||
self.transport: str = (
|
||||
"streamable-http"
|
||||
if raw_transport.lower() in _STREAMABLE_HTTP_ALIASES
|
||||
else raw_transport
|
||||
)
|
||||
|
||||
# stdio state
|
||||
self._proc: Optional[subprocess.Popen] = None
|
||||
|
||||
# SSE state
|
||||
self._sse_url: Optional[str] = None
|
||||
self._post_url: Optional[str] = None # endpoint for sending messages (resolved from SSE)
|
||||
|
||||
# Streamable HTTP state
|
||||
self._http_url: Optional[str] = None
|
||||
self._http_headers: dict = {} # extra headers from user config (e.g. Authorization)
|
||||
self._http_session_id: Optional[str] = None # Mcp-Session-Id assigned by the server
|
||||
|
||||
# Shared state
|
||||
self._next_id = 1
|
||||
self._id_lock = threading.Lock()
|
||||
self._call_lock = threading.Lock()
|
||||
self._initialized = False
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Public interface
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def initialize(self) -> bool:
|
||||
"""Connect and perform the MCP handshake. Returns True on success."""
|
||||
try:
|
||||
if self.transport == "stdio":
|
||||
return self._init_stdio()
|
||||
elif self.transport == "sse":
|
||||
return self._init_sse()
|
||||
elif self.transport == "streamable-http":
|
||||
return self._init_streamable_http()
|
||||
else:
|
||||
logger.warning(f"[MCP:{self.name}] Unknown transport type: {self.transport!r}")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.warning(f"[MCP:{self.name}] Initialization failed: {e}")
|
||||
return False
|
||||
|
||||
def list_tools(self) -> list:
|
||||
"""Return the tool list from this server.
|
||||
|
||||
Each item is a dict: {"name": str, "description": str, "inputSchema": dict}
|
||||
"""
|
||||
try:
|
||||
resp = self._send_request("tools/list", {})
|
||||
tools = resp.get("result", {}).get("tools", [])
|
||||
return [
|
||||
{
|
||||
"name": t.get("name", ""),
|
||||
"description": t.get("description", ""),
|
||||
"inputSchema": t.get("inputSchema", {}),
|
||||
}
|
||||
for t in tools
|
||||
]
|
||||
except Exception as e:
|
||||
logger.warning(f"[MCP:{self.name}] list_tools failed: {e}")
|
||||
return []
|
||||
|
||||
def call_tool(self, name: str, arguments: dict) -> str:
|
||||
"""Call a tool and return the result as a string."""
|
||||
try:
|
||||
resp = self._send_request("tools/call", {"name": name, "arguments": arguments})
|
||||
content = resp.get("result", {}).get("content", [])
|
||||
parts = [item.get("text", "") for item in content if item.get("type") == "text"]
|
||||
return "\n".join(parts)
|
||||
except Exception as e:
|
||||
logger.warning(f"[MCP:{self.name}] call_tool({name}) failed: {e}")
|
||||
return f"Error: {e}"
|
||||
|
||||
def shutdown(self):
|
||||
"""Close the connection / terminate the child process."""
|
||||
if self._proc is not None:
|
||||
try:
|
||||
self._proc.stdin.close()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
self._proc.terminate()
|
||||
self._proc.wait(timeout=5)
|
||||
except Exception:
|
||||
try:
|
||||
self._proc.kill()
|
||||
except Exception:
|
||||
pass
|
||||
self._proc = None
|
||||
logger.debug(f"[MCP:{self.name}] stdio process terminated")
|
||||
|
||||
# Best-effort streamable-http session termination
|
||||
if self.transport == "streamable-http" and self._http_session_id and self._http_url:
|
||||
try:
|
||||
req = urllib.request.Request(
|
||||
self._http_url,
|
||||
method="DELETE",
|
||||
headers={"Mcp-Session-Id": self._http_session_id, **self._http_headers},
|
||||
)
|
||||
with urllib.request.urlopen(req, timeout=5):
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
self._http_session_id = None
|
||||
|
||||
self._initialized = False
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# stdio transport
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _init_stdio(self) -> bool:
|
||||
command = self.config.get("command")
|
||||
if not command:
|
||||
logger.warning(f"[MCP:{self.name}] stdio config missing 'command'")
|
||||
return False
|
||||
|
||||
args = self.config.get("args", [])
|
||||
extra_env = self.config.get("env", None)
|
||||
env = {**os.environ, **extra_env} if extra_env else None
|
||||
|
||||
self._proc = subprocess.Popen(
|
||||
[command] + list(args),
|
||||
stdin=subprocess.PIPE,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
text=True,
|
||||
encoding="utf-8",
|
||||
env=env,
|
||||
)
|
||||
logger.debug(f"[MCP:{self.name}] stdio process started (pid={self._proc.pid})")
|
||||
|
||||
threading.Thread(
|
||||
target=self._drain_stderr, daemon=True, name=f"mcp-stderr-{self.name}"
|
||||
).start()
|
||||
|
||||
return self._handshake()
|
||||
|
||||
def _drain_stderr(self):
|
||||
for line in self._proc.stderr:
|
||||
line = line.strip()
|
||||
if line:
|
||||
logger.debug(f"[MCP:{self.name}] stderr: {line}")
|
||||
|
||||
def _readline_with_timeout(self, timeout: int = 30) -> str:
|
||||
"""Read one line from stdio stdout with a hard timeout."""
|
||||
ready, _, _ = select.select([self._proc.stdout], [], [], timeout)
|
||||
if not ready:
|
||||
raise TimeoutError(f"[MCP:{self.name}] stdio read timed out after {timeout}s")
|
||||
return self._proc.stdout.readline()
|
||||
|
||||
def _stdio_send(self, message: dict) -> dict:
|
||||
"""Send a JSON-RPC message over stdio and read the response."""
|
||||
raw = json.dumps(message) + "\n"
|
||||
self._proc.stdin.write(raw)
|
||||
self._proc.stdin.flush()
|
||||
|
||||
while True:
|
||||
line = self._readline_with_timeout()
|
||||
if not line:
|
||||
raise IOError(f"[MCP:{self.name}] stdio process closed unexpectedly")
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
data = json.loads(line)
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
if "id" not in data:
|
||||
logger.debug(f"[MCP:{self.name}] notification skipped: {data.get('method', '?')}")
|
||||
continue
|
||||
return data
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# SSE transport
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _init_sse(self) -> bool:
|
||||
url = self.config.get("url")
|
||||
if not url:
|
||||
logger.warning(f"[MCP:{self.name}] SSE config missing 'url'")
|
||||
return False
|
||||
|
||||
self._sse_url = url
|
||||
|
||||
# Read the first SSE event to discover the POST endpoint
|
||||
try:
|
||||
self._post_url = self._sse_discover_endpoint()
|
||||
except Exception as e:
|
||||
logger.warning(f"[MCP:{self.name}] SSE endpoint discovery failed: {e}")
|
||||
return False
|
||||
|
||||
return self._handshake()
|
||||
|
||||
def _sse_discover_endpoint(self) -> str:
|
||||
"""Open SSE stream and read the 'endpoint' event to learn the POST URL."""
|
||||
req = urllib.request.Request(
|
||||
self._sse_url,
|
||||
headers={"Accept": "text/event-stream"},
|
||||
)
|
||||
with urllib.request.urlopen(req, timeout=10) as resp:
|
||||
for raw_line in resp:
|
||||
line = raw_line.decode("utf-8").rstrip("\n\r")
|
||||
if line.startswith("data:"):
|
||||
data = line[len("data:"):].strip()
|
||||
# Some servers send JSON with a "uri" or plain path
|
||||
if data.startswith("{"):
|
||||
parsed = json.loads(data)
|
||||
return parsed.get("uri") or parsed.get("url") or parsed.get("endpoint")
|
||||
# Plain relative or absolute URL
|
||||
if data.startswith("http"):
|
||||
return data
|
||||
# Relative path: resolve against SSE base
|
||||
from urllib.parse import urljoin
|
||||
return urljoin(self._sse_url, data)
|
||||
raise ValueError(f"[MCP:{self.name}] No endpoint event received from SSE stream")
|
||||
|
||||
def _sse_send(self, message: dict) -> dict:
|
||||
"""POST a JSON-RPC message to the server and return the response."""
|
||||
body = json.dumps(message).encode("utf-8")
|
||||
req = urllib.request.Request(
|
||||
self._post_url,
|
||||
data=body,
|
||||
method="POST",
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
with urllib.request.urlopen(req, timeout=30) as resp:
|
||||
raw = resp.read().decode("utf-8")
|
||||
return json.loads(raw)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Streamable HTTP transport (MCP spec 2025-03-26)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _init_streamable_http(self) -> bool:
|
||||
url = self.config.get("url")
|
||||
if not url:
|
||||
logger.warning(f"[MCP:{self.name}] streamable-http config missing 'url'")
|
||||
return False
|
||||
|
||||
self._http_url = url
|
||||
# Allow user-provided headers (e.g. {"Authorization": "Bearer xxx"})
|
||||
extra_headers = self.config.get("headers") or {}
|
||||
if isinstance(extra_headers, dict):
|
||||
self._http_headers = {str(k): str(v) for k, v in extra_headers.items()}
|
||||
|
||||
return self._handshake()
|
||||
|
||||
def _streamable_http_send(self, message: dict) -> dict:
|
||||
"""POST a JSON-RPC request and return the response (JSON or SSE-wrapped)."""
|
||||
return self._streamable_http_post(message, expect_response=True)
|
||||
|
||||
def _streamable_http_post(self, message: dict, expect_response: bool) -> dict:
|
||||
"""
|
||||
POST a JSON-RPC message over Streamable HTTP.
|
||||
|
||||
Per the spec, the response Content-Type can be either:
|
||||
- application/json -> single JSON-RPC response in body
|
||||
- text/event-stream -> SSE stream; we read until we get a matching response
|
||||
"""
|
||||
body = json.dumps(message).encode("utf-8")
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json, text/event-stream",
|
||||
}
|
||||
if self._http_session_id:
|
||||
headers["Mcp-Session-Id"] = self._http_session_id
|
||||
headers.update(self._http_headers)
|
||||
|
||||
req = urllib.request.Request(
|
||||
self._http_url,
|
||||
data=body,
|
||||
method="POST",
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
try:
|
||||
resp = urllib.request.urlopen(req, timeout=30)
|
||||
except urllib.error.HTTPError as e:
|
||||
# Surface the server-provided error body for easier debugging
|
||||
detail = ""
|
||||
try:
|
||||
detail = e.read().decode("utf-8", errors="ignore")
|
||||
except Exception:
|
||||
pass
|
||||
raise IOError(
|
||||
f"[MCP:{self.name}] streamable-http HTTP {e.code}: {detail[:200]}"
|
||||
)
|
||||
|
||||
with resp:
|
||||
# Capture session id assigned by the server (if any)
|
||||
session_id = resp.headers.get("Mcp-Session-Id")
|
||||
if session_id and not self._http_session_id:
|
||||
self._http_session_id = session_id
|
||||
|
||||
status = resp.status if hasattr(resp, "status") else resp.getcode()
|
||||
|
||||
# Notifications: server may reply with 202 Accepted and no body
|
||||
if not expect_response or status == 202:
|
||||
try:
|
||||
resp.read()
|
||||
except Exception:
|
||||
pass
|
||||
return {}
|
||||
|
||||
content_type = (resp.headers.get("Content-Type") or "").lower()
|
||||
expected_id = message.get("id")
|
||||
|
||||
if "text/event-stream" in content_type:
|
||||
return self._read_sse_response(resp, expected_id)
|
||||
|
||||
raw = resp.read().decode("utf-8")
|
||||
if not raw:
|
||||
return {}
|
||||
return json.loads(raw)
|
||||
|
||||
def _read_sse_response(self, resp, expected_id) -> dict:
|
||||
"""Read an SSE stream and return the first JSON-RPC response with matching id."""
|
||||
data_buf: list = []
|
||||
for raw_line in resp:
|
||||
line = raw_line.decode("utf-8").rstrip("\n\r")
|
||||
if line == "":
|
||||
# End of an SSE event, attempt to parse accumulated data
|
||||
if data_buf:
|
||||
payload = "\n".join(data_buf)
|
||||
data_buf = []
|
||||
try:
|
||||
msg = json.loads(payload)
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
# Skip notifications / mismatched ids
|
||||
if "id" not in msg:
|
||||
continue
|
||||
if expected_id is None or msg.get("id") == expected_id:
|
||||
return msg
|
||||
continue
|
||||
if line.startswith(":"):
|
||||
continue # SSE comment / keepalive
|
||||
if line.startswith("data:"):
|
||||
data_buf.append(line[len("data:"):].lstrip())
|
||||
# Ignore 'event:' / 'id:' lines; we only care about JSON-RPC payloads
|
||||
|
||||
raise IOError(f"[MCP:{self.name}] streamable-http SSE stream closed before response")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Common JSON-RPC helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _next_request_id(self) -> int:
|
||||
with self._id_lock:
|
||||
rid = self._next_id
|
||||
self._next_id += 1
|
||||
return rid
|
||||
|
||||
def _build_request(self, method: str, params: dict) -> dict:
|
||||
return {
|
||||
"jsonrpc": "2.0",
|
||||
"id": self._next_request_id(),
|
||||
"method": method,
|
||||
"params": params,
|
||||
}
|
||||
|
||||
def _build_notification(self, method: str, params: dict) -> dict:
|
||||
return {"jsonrpc": "2.0", "method": method, "params": params}
|
||||
|
||||
def _send_request(self, method: str, params: dict) -> dict:
|
||||
"""Send a request and return the full response dict."""
|
||||
if not self._initialized and method != "initialize":
|
||||
raise RuntimeError(f"[MCP:{self.name}] Client not initialized")
|
||||
|
||||
message = self._build_request(method, params)
|
||||
|
||||
with self._call_lock:
|
||||
if self.transport == "stdio":
|
||||
return self._stdio_send(message)
|
||||
elif self.transport == "sse":
|
||||
return self._sse_send(message)
|
||||
elif self.transport == "streamable-http":
|
||||
return self._streamable_http_send(message)
|
||||
else:
|
||||
raise ValueError(f"[MCP:{self.name}] Unsupported transport: {self.transport}")
|
||||
|
||||
def _send_notification(self, method: str, params: dict):
|
||||
"""Fire-and-forget notification (no response expected)."""
|
||||
notification = self._build_notification(method, params)
|
||||
raw = json.dumps(notification) + "\n"
|
||||
|
||||
if self.transport == "stdio":
|
||||
self._proc.stdin.write(raw)
|
||||
self._proc.stdin.flush()
|
||||
elif self.transport == "sse":
|
||||
body = raw.encode("utf-8")
|
||||
req = urllib.request.Request(
|
||||
self._post_url,
|
||||
data=body,
|
||||
method="POST",
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
try:
|
||||
with urllib.request.urlopen(req, timeout=10):
|
||||
pass
|
||||
except Exception:
|
||||
pass # notifications are fire-and-forget
|
||||
elif self.transport == "streamable-http":
|
||||
try:
|
||||
self._streamable_http_post(notification, expect_response=False)
|
||||
except Exception:
|
||||
pass # notifications are fire-and-forget
|
||||
|
||||
def _handshake(self) -> bool:
|
||||
"""Perform the MCP initialize / notifications/initialized handshake."""
|
||||
init_params = {
|
||||
"protocolVersion": "2024-11-05",
|
||||
"capabilities": {},
|
||||
"clientInfo": {"name": "CowAgent", "version": "1.0"},
|
||||
}
|
||||
# Temporarily mark as initialized so _send_request doesn't block
|
||||
self._initialized = True
|
||||
try:
|
||||
resp = self._send_request("initialize", init_params)
|
||||
except Exception as e:
|
||||
self._initialized = False
|
||||
logger.warning(f"[MCP:{self.name}] Handshake initialize failed: {e}")
|
||||
return False
|
||||
|
||||
if "error" in resp:
|
||||
self._initialized = False
|
||||
logger.warning(f"[MCP:{self.name}] Handshake error: {resp['error']}")
|
||||
return False
|
||||
|
||||
self._send_notification("notifications/initialized", {})
|
||||
logger.debug(f"[MCP:{self.name}] Handshake complete")
|
||||
return True
|
||||
|
||||
|
||||
class McpClientRegistry:
|
||||
"""Global singleton managing the lifecycle of all MCP Server clients."""
|
||||
|
||||
_instance = None
|
||||
_instance_lock = threading.Lock()
|
||||
|
||||
def __new__(cls):
|
||||
with cls._instance_lock:
|
||||
if cls._instance is None:
|
||||
obj = super().__new__(cls)
|
||||
obj._clients: dict[str, McpClient] = {}
|
||||
obj._registry_lock = threading.Lock()
|
||||
cls._instance = obj
|
||||
return cls._instance
|
||||
|
||||
def start_all(self, configs: list) -> None:
|
||||
"""Initialize McpClient for each config entry; skip failures with a warning."""
|
||||
if not configs:
|
||||
return
|
||||
|
||||
for cfg in configs:
|
||||
name = cfg.get("name", "<unnamed>")
|
||||
client = McpClient(cfg)
|
||||
ok = client.initialize()
|
||||
if ok:
|
||||
with self._registry_lock:
|
||||
self._clients[name] = client
|
||||
logger.info(f"[MCP] Server '{name}' initialized successfully")
|
||||
else:
|
||||
logger.warning(f"[MCP] Server '{name}' failed to initialize — skipping")
|
||||
|
||||
def get(self, server_name: str) -> Optional[McpClient]:
|
||||
"""Return the initialized client for server_name, or None."""
|
||||
with self._registry_lock:
|
||||
return self._clients.get(server_name)
|
||||
|
||||
def all_clients(self) -> dict:
|
||||
"""Return a copy of the {name: McpClient} mapping."""
|
||||
with self._registry_lock:
|
||||
return dict(self._clients)
|
||||
|
||||
def shutdown_all(self) -> None:
|
||||
"""Shut down all managed clients."""
|
||||
with self._registry_lock:
|
||||
clients = list(self._clients.values())
|
||||
self._clients.clear()
|
||||
|
||||
for client in clients:
|
||||
try:
|
||||
client.shutdown()
|
||||
except Exception as e:
|
||||
logger.warning(f"[MCP] Error shutting down '{client.name}': {e}")
|
||||
|
||||
logger.info("[MCP] All servers shut down")
|
||||
31
agent/tools/mcp/mcp_tool.py
Normal file
31
agent/tools/mcp/mcp_tool.py
Normal file
@@ -0,0 +1,31 @@
|
||||
from agent.tools.base_tool import BaseTool, ToolResult
|
||||
from common.log import logger
|
||||
|
||||
|
||||
class McpTool(BaseTool):
|
||||
"""
|
||||
将单个 MCP 工具包装为 BaseTool。
|
||||
一个 MCP Server 可以提供多个工具,每个工具对应一个 McpTool 实例。
|
||||
"""
|
||||
|
||||
def __init__(self, client, tool_schema: dict, server_name: str):
|
||||
"""
|
||||
:param client: 该工具所属的 McpClient 实例
|
||||
:param tool_schema: MCP 返回的工具描述,格式:
|
||||
{"name": str, "description": str, "inputSchema": dict}
|
||||
:param server_name: Server 名称,用于日志
|
||||
"""
|
||||
self.client = client
|
||||
self.server_name = server_name
|
||||
self.name = tool_schema["name"]
|
||||
self.description = tool_schema.get("description", "")
|
||||
self.params = tool_schema.get("inputSchema", {})
|
||||
|
||||
def execute(self, params: dict) -> ToolResult:
|
||||
logger.info(f"[McpTool] server={self.server_name} tool={self.name} params={params}")
|
||||
try:
|
||||
result = self.client.call_tool(self.name, params)
|
||||
return ToolResult.success(result)
|
||||
except Exception as e:
|
||||
logger.error(f"[McpTool] server={self.server_name} tool={self.name} error: {e}")
|
||||
return ToolResult.fail(str(e))
|
||||
@@ -44,6 +44,19 @@ class MemoryGetTool(BaseTool):
|
||||
"""
|
||||
super().__init__()
|
||||
self.memory_manager = memory_manager
|
||||
|
||||
from config import conf
|
||||
if conf().get("knowledge", True):
|
||||
self.description = (
|
||||
"Read specific content from memory or knowledge files. "
|
||||
"Use this to get full context from a memory file, knowledge page, or specific line range."
|
||||
)
|
||||
self.params = {**self.params}
|
||||
self.params["properties"] = {**self.params["properties"]}
|
||||
self.params["properties"]["path"] = {
|
||||
"type": "string",
|
||||
"description": "Relative path to the memory or knowledge file (e.g. 'MEMORY.md', 'memory/2026-01-01.md', 'knowledge/concepts/moe.md')"
|
||||
}
|
||||
|
||||
def execute(self, args: dict):
|
||||
"""
|
||||
@@ -68,16 +81,20 @@ class MemoryGetTool(BaseTool):
|
||||
workspace_dir = self.memory_manager.config.get_workspace()
|
||||
|
||||
# Auto-prepend memory/ if not present and not absolute path
|
||||
# Exception: MEMORY.md is in the root directory
|
||||
if not path.startswith('memory/') and not path.startswith('/') and path != 'MEMORY.md':
|
||||
# Exceptions: MEMORY.md in root, knowledge/ files at workspace root
|
||||
if not path.startswith('memory/') and not path.startswith('knowledge/') and not path.startswith('/') and path != 'MEMORY.md':
|
||||
path = f'memory/{path}'
|
||||
|
||||
file_path = workspace_dir / path
|
||||
file_path = (workspace_dir / path).resolve()
|
||||
workspace_resolved = workspace_dir.resolve()
|
||||
|
||||
if not str(file_path).startswith(str(workspace_resolved) + '/') and file_path != workspace_resolved:
|
||||
return ToolResult.fail(f"Error: Access denied: path outside workspace")
|
||||
|
||||
if not file_path.exists():
|
||||
return ToolResult.fail(f"Error: File not found: {path}")
|
||||
|
||||
content = file_path.read_text()
|
||||
content = file_path.read_text(encoding='utf-8')
|
||||
lines = content.split('\n')
|
||||
|
||||
# Handle line range
|
||||
|
||||
@@ -48,6 +48,13 @@ class MemorySearchTool(BaseTool):
|
||||
super().__init__()
|
||||
self.memory_manager = memory_manager
|
||||
self.user_id = user_id
|
||||
|
||||
from config import conf
|
||||
if conf().get("knowledge", True):
|
||||
self.description = (
|
||||
"Search agent's long-term memory and knowledge base using semantic and keyword search. "
|
||||
"Use this to recall past conversations, preferences, and knowledge pages."
|
||||
)
|
||||
|
||||
def execute(self, args: dict):
|
||||
"""
|
||||
|
||||
@@ -9,6 +9,7 @@ from pathlib import Path
|
||||
|
||||
from agent.tools.base_tool import BaseTool, ToolResult
|
||||
from agent.tools.utils.truncate import truncate_head, format_size, DEFAULT_MAX_LINES, DEFAULT_MAX_BYTES
|
||||
from common.utils import expand_path
|
||||
|
||||
|
||||
class Read(BaseTool):
|
||||
@@ -47,7 +48,8 @@ class Read(BaseTool):
|
||||
self.binary_extensions = {'.exe', '.dll', '.so', '.dylib', '.bin', '.dat', '.db', '.sqlite'}
|
||||
self.archive_extensions = {'.zip', '.tar', '.gz', '.rar', '.7z', '.bz2', '.xz'}
|
||||
self.pdf_extensions = {'.pdf'}
|
||||
|
||||
self.office_extensions = {'.doc', '.docx', '.xls', '.xlsx', '.ppt', '.pptx'}
|
||||
|
||||
# Readable text formats (will be read with truncation)
|
||||
self.text_extensions = {
|
||||
'.txt', '.md', '.markdown', '.rst', '.log', '.csv', '.tsv', '.json', '.xml', '.yaml', '.yml',
|
||||
@@ -56,7 +58,6 @@ class Read(BaseTool):
|
||||
'.sh', '.bash', '.zsh', '.fish', '.ps1', '.bat', '.cmd',
|
||||
'.sql', '.r', '.m', '.swift', '.kt', '.scala', '.clj', '.erl', '.ex',
|
||||
'.dockerfile', '.makefile', '.cmake', '.gradle', '.properties', '.ini', '.conf', '.cfg',
|
||||
'.doc', '.docx', '.xls', '.xlsx', '.ppt', '.pptx' # Office documents
|
||||
}
|
||||
|
||||
def execute(self, args: Dict[str, Any]) -> ToolResult:
|
||||
@@ -66,10 +67,12 @@ class Read(BaseTool):
|
||||
:param args: Contains file path and optional offset/limit parameters
|
||||
:return: File content or error message
|
||||
"""
|
||||
path = args.get("path", "").strip()
|
||||
# Support 'location' as alias for 'path' (LLM may use it from skill listing)
|
||||
path = args.get("path", "") or args.get("location", "")
|
||||
path = path.strip() if isinstance(path, str) else ""
|
||||
offset = args.get("offset")
|
||||
limit = args.get("limit")
|
||||
|
||||
|
||||
if not path:
|
||||
return ToolResult.fail("Error: path parameter is required")
|
||||
|
||||
@@ -77,7 +80,7 @@ class Read(BaseTool):
|
||||
absolute_path = self._resolve_path(path)
|
||||
|
||||
# Security check: Prevent reading sensitive config files
|
||||
env_config_path = os.path.expanduser("~/.cow/.env")
|
||||
env_config_path = expand_path("~/.cow/.env")
|
||||
if os.path.abspath(absolute_path) == os.path.abspath(env_config_path):
|
||||
return ToolResult.fail(
|
||||
"Error: Access denied. API keys and credentials must be accessed through the env_config tool only."
|
||||
@@ -117,7 +120,11 @@ class Read(BaseTool):
|
||||
# Check if PDF
|
||||
if file_ext in self.pdf_extensions:
|
||||
return self._read_pdf(absolute_path, path, offset, limit)
|
||||
|
||||
|
||||
# Check if Office document (.docx, .xlsx, .pptx, etc.)
|
||||
if file_ext in self.office_extensions:
|
||||
return self._read_office(absolute_path, path, file_ext, offset, limit)
|
||||
|
||||
# Read text file (with truncation for large files)
|
||||
return self._read_text(absolute_path, path, offset, limit)
|
||||
|
||||
@@ -129,7 +136,7 @@ class Read(BaseTool):
|
||||
:return: Absolute path
|
||||
"""
|
||||
# Expand ~ to user home directory
|
||||
path = os.path.expanduser(path)
|
||||
path = expand_path(path)
|
||||
if os.path.isabs(path):
|
||||
return path
|
||||
return os.path.abspath(os.path.join(self.cwd, path))
|
||||
@@ -237,17 +244,12 @@ class Read(BaseTool):
|
||||
"message": f"文件过大 ({format_size(file_size)} > 50MB),无法读取内容。文件路径: {absolute_path}"
|
||||
})
|
||||
|
||||
# Read file
|
||||
with open(absolute_path, 'r', encoding='utf-8') as f:
|
||||
# Read file (utf-8-sig strips BOM automatically on Windows)
|
||||
# Note: Truncation is unified via truncate_head (DEFAULT_MAX_LINES / DEFAULT_MAX_BYTES)
|
||||
# so that offset/limit can paginate the entire file correctly.
|
||||
with open(absolute_path, 'r', encoding='utf-8-sig') as f:
|
||||
content = f.read()
|
||||
|
||||
# Truncate content if too long (20K characters max for model context)
|
||||
MAX_CONTENT_CHARS = 20 * 1024 # 20K characters
|
||||
content_truncated = False
|
||||
if len(content) > MAX_CONTENT_CHARS:
|
||||
content = content[:MAX_CONTENT_CHARS]
|
||||
content_truncated = True
|
||||
|
||||
|
||||
all_lines = content.split('\n')
|
||||
total_file_lines = len(all_lines)
|
||||
|
||||
@@ -283,11 +285,7 @@ class Read(BaseTool):
|
||||
|
||||
output_text = ""
|
||||
details = {}
|
||||
|
||||
# Add truncation warning if content was truncated
|
||||
if content_truncated:
|
||||
output_text = f"[文件内容已截断到前 {format_size(MAX_CONTENT_CHARS)},完整文件大小: {format_size(file_size)}]\n\n"
|
||||
|
||||
|
||||
if truncation.first_line_exceeds_limit:
|
||||
# First line exceeds 30KB limit
|
||||
first_line_size = format_size(len(all_lines[start_line].encode('utf-8')))
|
||||
@@ -334,6 +332,116 @@ class Read(BaseTool):
|
||||
except Exception as e:
|
||||
return ToolResult.fail(f"Error reading file: {str(e)}")
|
||||
|
||||
def _read_office(self, absolute_path: str, display_path: str, file_ext: str,
|
||||
offset: int = None, limit: int = None) -> ToolResult:
|
||||
"""Read Office documents (.docx, .xlsx, .pptx) using python-docx / openpyxl / python-pptx."""
|
||||
try:
|
||||
text = self._extract_office_text(absolute_path, file_ext)
|
||||
except ImportError as e:
|
||||
return ToolResult.fail(str(e))
|
||||
except Exception as e:
|
||||
return ToolResult.fail(f"Error reading Office document: {e}")
|
||||
|
||||
if not text or not text.strip():
|
||||
return ToolResult.success({
|
||||
"content": f"[Office file {Path(absolute_path).name}: no text content could be extracted]",
|
||||
})
|
||||
|
||||
all_lines = text.split('\n')
|
||||
total_lines = len(all_lines)
|
||||
|
||||
start_line = 0
|
||||
if offset is not None:
|
||||
if offset < 0:
|
||||
start_line = max(0, total_lines + offset)
|
||||
else:
|
||||
start_line = max(0, offset - 1)
|
||||
if start_line >= total_lines:
|
||||
return ToolResult.fail(
|
||||
f"Error: Offset {offset} is beyond end of content ({total_lines} lines total)"
|
||||
)
|
||||
|
||||
selected_content = text
|
||||
user_limited_lines = None
|
||||
if limit is not None:
|
||||
end_line = min(start_line + limit, total_lines)
|
||||
selected_content = '\n'.join(all_lines[start_line:end_line])
|
||||
user_limited_lines = end_line - start_line
|
||||
elif offset is not None:
|
||||
selected_content = '\n'.join(all_lines[start_line:])
|
||||
|
||||
truncation = truncate_head(selected_content)
|
||||
start_line_display = start_line + 1
|
||||
output_text = ""
|
||||
|
||||
if truncation.truncated:
|
||||
end_line_display = start_line_display + truncation.output_lines - 1
|
||||
next_offset = end_line_display + 1
|
||||
output_text = truncation.content
|
||||
output_text += f"\n\n[Showing lines {start_line_display}-{end_line_display} of {total_lines}. Use offset={next_offset} to continue.]"
|
||||
elif user_limited_lines is not None and start_line + user_limited_lines < total_lines:
|
||||
remaining = total_lines - (start_line + user_limited_lines)
|
||||
next_offset = start_line + user_limited_lines + 1
|
||||
output_text = truncation.content
|
||||
output_text += f"\n\n[{remaining} more lines in file. Use offset={next_offset} to continue.]"
|
||||
else:
|
||||
output_text = truncation.content
|
||||
|
||||
return ToolResult.success({
|
||||
"content": output_text,
|
||||
"total_lines": total_lines,
|
||||
"start_line": start_line_display,
|
||||
"output_lines": truncation.output_lines,
|
||||
})
|
||||
|
||||
@staticmethod
|
||||
def _extract_office_text(absolute_path: str, file_ext: str) -> str:
|
||||
"""Extract plain text from an Office document."""
|
||||
if file_ext in ('.docx', '.doc'):
|
||||
try:
|
||||
from docx import Document
|
||||
except ImportError:
|
||||
raise ImportError("Error: python-docx library not installed. Install with: pip install python-docx")
|
||||
doc = Document(absolute_path)
|
||||
paragraphs = [p.text for p in doc.paragraphs]
|
||||
for table in doc.tables:
|
||||
for row in table.rows:
|
||||
paragraphs.append('\t'.join(cell.text for cell in row.cells))
|
||||
return '\n'.join(paragraphs)
|
||||
|
||||
if file_ext in ('.xlsx', '.xls'):
|
||||
try:
|
||||
from openpyxl import load_workbook
|
||||
except ImportError:
|
||||
raise ImportError("Error: openpyxl library not installed. Install with: pip install openpyxl")
|
||||
wb = load_workbook(absolute_path, read_only=True, data_only=True)
|
||||
parts = []
|
||||
for ws in wb.worksheets:
|
||||
parts.append(f"--- Sheet: {ws.title} ---")
|
||||
for row in ws.iter_rows(values_only=True):
|
||||
parts.append('\t'.join(str(c) if c is not None else '' for c in row))
|
||||
wb.close()
|
||||
return '\n'.join(parts)
|
||||
|
||||
if file_ext in ('.pptx', '.ppt'):
|
||||
try:
|
||||
from pptx import Presentation
|
||||
except ImportError:
|
||||
raise ImportError("Error: python-pptx library not installed. Install with: pip install python-pptx")
|
||||
prs = Presentation(absolute_path)
|
||||
parts = []
|
||||
for i, slide in enumerate(prs.slides, 1):
|
||||
parts.append(f"--- Slide {i} ---")
|
||||
for shape in slide.shapes:
|
||||
if shape.has_text_frame:
|
||||
for para in shape.text_frame.paragraphs:
|
||||
text = para.text.strip()
|
||||
if text:
|
||||
parts.append(text)
|
||||
return '\n'.join(parts)
|
||||
|
||||
return ""
|
||||
|
||||
def _read_pdf(self, absolute_path: str, display_path: str, offset: int = None, limit: int = None) -> ToolResult:
|
||||
"""
|
||||
Read PDF file content
|
||||
|
||||
@@ -3,74 +3,137 @@ Integration module for scheduler with AgentBridge
|
||||
"""
|
||||
|
||||
import os
|
||||
import threading
|
||||
from typing import Optional
|
||||
from config import conf
|
||||
from common.log import logger
|
||||
from common.utils import expand_path
|
||||
from bridge.context import Context, ContextType
|
||||
from bridge.reply import Reply, ReplyType
|
||||
|
||||
# Global scheduler service instance
|
||||
_scheduler_service = None
|
||||
_task_store = None
|
||||
# Module-level lock to guard idempotent initialization across threads
|
||||
_init_lock = threading.Lock()
|
||||
|
||||
|
||||
def init_scheduler(agent_bridge) -> bool:
|
||||
"""
|
||||
Initialize scheduler service
|
||||
|
||||
Initialize scheduler service (idempotent).
|
||||
|
||||
Safe to call multiple times and from multiple threads: only the first
|
||||
successful call creates the singleton ``SchedulerService`` + background
|
||||
scanning thread. Subsequent calls return immediately.
|
||||
|
||||
Args:
|
||||
agent_bridge: AgentBridge instance
|
||||
|
||||
|
||||
Returns:
|
||||
True if initialized successfully
|
||||
True if scheduler is initialized (newly created or already running)
|
||||
"""
|
||||
global _scheduler_service, _task_store
|
||||
|
||||
try:
|
||||
from agent.tools.scheduler.task_store import TaskStore
|
||||
from agent.tools.scheduler.scheduler_service import SchedulerService
|
||||
|
||||
# Get workspace from config
|
||||
workspace_root = os.path.expanduser(conf().get("agent_workspace", "~/cow"))
|
||||
store_path = os.path.join(workspace_root, "scheduler", "tasks.json")
|
||||
|
||||
# Create task store
|
||||
_task_store = TaskStore(store_path)
|
||||
logger.debug(f"[Scheduler] Task store initialized: {store_path}")
|
||||
|
||||
# Create execute callback
|
||||
def execute_task_callback(task: dict):
|
||||
"""Callback to execute a scheduled task"""
|
||||
try:
|
||||
action = task.get("action", {})
|
||||
action_type = action.get("type")
|
||||
|
||||
if action_type == "agent_task":
|
||||
_execute_agent_task(task, agent_bridge)
|
||||
elif action_type == "send_message":
|
||||
# Legacy support for old tasks
|
||||
_execute_send_message(task, agent_bridge)
|
||||
elif action_type == "tool_call":
|
||||
# Legacy support for old tasks
|
||||
_execute_tool_call(task, agent_bridge)
|
||||
elif action_type == "skill_call":
|
||||
# Legacy support for old tasks
|
||||
_execute_skill_call(task, agent_bridge)
|
||||
else:
|
||||
logger.warning(f"[Scheduler] Unknown action type: {action_type}")
|
||||
except Exception as e:
|
||||
logger.error(f"[Scheduler] Error executing task {task.get('id')}: {e}")
|
||||
|
||||
# Create scheduler service
|
||||
_scheduler_service = SchedulerService(_task_store, execute_task_callback)
|
||||
_scheduler_service.start()
|
||||
|
||||
logger.debug("[Scheduler] Scheduler service initialized and started")
|
||||
|
||||
# Fast path: already initialized and running
|
||||
if _scheduler_service is not None and getattr(_scheduler_service, "running", False):
|
||||
return True
|
||||
|
||||
with _init_lock:
|
||||
# Re-check under the lock to avoid races where multiple threads
|
||||
# passed the fast-path check before any of them acquired the lock.
|
||||
if _scheduler_service is not None and getattr(_scheduler_service, "running", False):
|
||||
return True
|
||||
|
||||
try:
|
||||
from agent.tools.scheduler.task_store import TaskStore
|
||||
from agent.tools.scheduler.scheduler_service import SchedulerService
|
||||
|
||||
# Get workspace from config
|
||||
workspace_root = expand_path(conf().get("agent_workspace", "~/cow"))
|
||||
store_path = os.path.join(workspace_root, "scheduler", "tasks.json")
|
||||
|
||||
# Create task store (reuse if already created)
|
||||
if _task_store is None:
|
||||
_task_store = TaskStore(store_path)
|
||||
logger.debug(f"[Scheduler] Task store initialized: {store_path}")
|
||||
|
||||
# Create execute callback. Returns True on success, False to ask
|
||||
# the scheduler to retry on the next tick (e.g. channel not yet
|
||||
# ready right after process start).
|
||||
def execute_task_callback(task: dict):
|
||||
try:
|
||||
action = task.get("action", {})
|
||||
action_type = action.get("type")
|
||||
channel_type = action.get("channel_type", "unknown")
|
||||
receiver = action.get("receiver", "")
|
||||
|
||||
if not _is_channel_ready(channel_type, receiver):
|
||||
logger.warning(
|
||||
f"[Scheduler] Task {task.get('id')}: channel "
|
||||
f"'{channel_type}' not ready for receiver={receiver} "
|
||||
f"(no inbound msg cached since restart?); deferring"
|
||||
)
|
||||
return False
|
||||
|
||||
if action_type == "agent_task":
|
||||
return _execute_agent_task(task, agent_bridge)
|
||||
elif action_type == "send_message":
|
||||
return _execute_send_message(task, agent_bridge)
|
||||
elif action_type == "tool_call":
|
||||
return _execute_tool_call(task, agent_bridge)
|
||||
elif action_type == "skill_call":
|
||||
return _execute_skill_call(task, agent_bridge)
|
||||
else:
|
||||
logger.warning(f"[Scheduler] Unknown action type: {action_type}")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"[Scheduler] Error executing task {task.get('id')}: {e}")
|
||||
return False
|
||||
|
||||
# Create scheduler service
|
||||
_scheduler_service = SchedulerService(_task_store, execute_task_callback)
|
||||
_scheduler_service.start()
|
||||
|
||||
logger.info("[Scheduler] Service initialized and started")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Scheduler] Failed to initialize scheduler: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def _is_channel_ready(channel_type: str, receiver: str) -> bool:
|
||||
"""Best-effort readiness probe for outbound channels.
|
||||
|
||||
Returns False when we know the send will drop (e.g. weixin not yet
|
||||
logged in, web session has no polling queue), so the scheduler can
|
||||
defer instead of consuming the task. Unknown channels return True
|
||||
to preserve previous behaviour.
|
||||
"""
|
||||
if not channel_type or channel_type == "unknown":
|
||||
return True
|
||||
try:
|
||||
from channel.channel_factory import create_channel
|
||||
channel = create_channel(channel_type)
|
||||
if channel is None:
|
||||
return False
|
||||
|
||||
if channel_type == "weixin":
|
||||
tokens = getattr(channel, "_context_tokens", None)
|
||||
if not tokens or receiver not in tokens:
|
||||
return False
|
||||
return True
|
||||
|
||||
if channel_type == "web":
|
||||
queues = getattr(channel, "session_queues", None)
|
||||
if not queues or receiver not in queues:
|
||||
return False
|
||||
return True
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Scheduler] Failed to initialize scheduler: {e}")
|
||||
return False
|
||||
logger.warning(f"[Scheduler] Channel readiness check failed for {channel_type}: {e}")
|
||||
return True
|
||||
|
||||
|
||||
def get_task_store():
|
||||
@@ -83,13 +146,53 @@ def get_scheduler_service():
|
||||
return _scheduler_service
|
||||
|
||||
|
||||
def _execute_agent_task(task: dict, agent_bridge):
|
||||
def _remember_delivered_output(
|
||||
agent_bridge,
|
||||
task: dict,
|
||||
channel_type: str,
|
||||
content: str,
|
||||
) -> None:
|
||||
"""Best-effort persistence of the message the scheduler sent to a user.
|
||||
|
||||
Uses notify_session_id (the real chat session_id stored at task creation time)
|
||||
so that group chats correctly associate the output with the user's conversation.
|
||||
Falls back to receiver for backward compatibility with old tasks.
|
||||
|
||||
Per-action-type behaviour:
|
||||
- agent_task / tool_call / skill_call: gated by ``scheduler_inject_to_session``
|
||||
(default True). These produce AI-generated content worth remembering.
|
||||
- send_message: additionally gated by ``scheduler_inject_send_message``
|
||||
(default False). Fixed reminder text rarely benefits follow-up Q&A and
|
||||
would just consume context tokens.
|
||||
"""
|
||||
Execute an agent_task action - let Agent handle the task
|
||||
|
||||
Args:
|
||||
task: Task dictionary
|
||||
agent_bridge: AgentBridge instance
|
||||
if not content:
|
||||
return
|
||||
action = task.get("action", {})
|
||||
action_type = action.get("type", "")
|
||||
|
||||
# send_message defaults to NOT being injected; explicit opt-in via config.
|
||||
if action_type == "send_message":
|
||||
if not conf().get("scheduler_inject_send_message", False):
|
||||
return
|
||||
|
||||
session_id = action.get("notify_session_id") or action.get("receiver")
|
||||
if not session_id:
|
||||
return
|
||||
try:
|
||||
remember = getattr(agent_bridge, "remember_scheduled_output", None)
|
||||
if remember:
|
||||
task_desc = action.get("task_description") or action.get("content", "")
|
||||
remember(session_id, str(content), channel_type=channel_type, task_description=task_desc)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"[Scheduler] Failed to remember delivered output for {session_id}: {e}"
|
||||
)
|
||||
|
||||
|
||||
def _execute_agent_task(task: dict, agent_bridge) -> bool:
|
||||
"""
|
||||
Execute an agent_task action - let Agent handle the task.
|
||||
Returns True on successful delivery, False to retry next tick.
|
||||
"""
|
||||
try:
|
||||
action = task.get("action", {})
|
||||
@@ -100,11 +203,11 @@ def _execute_agent_task(task: dict, agent_bridge):
|
||||
|
||||
if not task_description:
|
||||
logger.error(f"[Scheduler] Task {task['id']}: No task_description specified")
|
||||
return
|
||||
return True # malformed task, don't loop forever
|
||||
|
||||
if not receiver:
|
||||
logger.error(f"[Scheduler] Task {task['id']}: No receiver specified")
|
||||
return
|
||||
return True
|
||||
|
||||
# Check for unsupported channels
|
||||
if channel_type == "dingtalk":
|
||||
@@ -112,11 +215,15 @@ def _execute_agent_task(task: dict, agent_bridge):
|
||||
|
||||
logger.info(f"[Scheduler] Task {task['id']}: Executing agent task '{task_description}'")
|
||||
|
||||
# Create a unique session_id for this scheduled task to avoid polluting user's conversation
|
||||
# Format: scheduler_<receiver>_<task_id> to ensure isolation
|
||||
scheduler_session_id = f"scheduler_{receiver}_{task['id']}"
|
||||
|
||||
# Create context for Agent
|
||||
context = Context(ContextType.TEXT, task_description)
|
||||
context["receiver"] = receiver
|
||||
context["isgroup"] = is_group
|
||||
context["session_id"] = receiver
|
||||
context["session_id"] = scheduler_session_id
|
||||
|
||||
# Channel-specific setup
|
||||
if channel_type == "web":
|
||||
@@ -129,62 +236,61 @@ def _execute_agent_task(task: dict, agent_bridge):
|
||||
elif channel_type == "dingtalk":
|
||||
# DingTalk requires msg object, set to None for scheduled tasks
|
||||
context["msg"] = None
|
||||
# 如果是单聊,需要传递 sender_staff_id
|
||||
if not is_group:
|
||||
sender_staff_id = action.get("dingtalk_sender_staff_id")
|
||||
if sender_staff_id:
|
||||
context["dingtalk_sender_staff_id"] = sender_staff_id
|
||||
|
||||
elif channel_type == "wecom_bot":
|
||||
context["msg"] = None
|
||||
|
||||
# Use Agent to execute the task
|
||||
# Mark this as a scheduled task execution to prevent recursive task creation
|
||||
context["is_scheduled_task"] = True
|
||||
|
||||
try:
|
||||
reply = agent_bridge.agent_reply(task_description, context=context, on_event=None, clear_history=True)
|
||||
|
||||
if reply and reply.content:
|
||||
# Send the reply via channel
|
||||
from channel.channel_factory import create_channel
|
||||
|
||||
try:
|
||||
channel = create_channel(channel_type)
|
||||
if channel:
|
||||
# For web channel, register request_id
|
||||
if channel_type == "web" and hasattr(channel, 'request_to_session'):
|
||||
request_id = context.get("request_id")
|
||||
if request_id:
|
||||
channel.request_to_session[request_id] = receiver
|
||||
logger.debug(f"[Scheduler] Registered request_id {request_id} -> session {receiver}")
|
||||
|
||||
# Send the reply
|
||||
channel.send(reply, context)
|
||||
logger.info(f"[Scheduler] Task {task['id']} executed successfully, result sent to {receiver}")
|
||||
else:
|
||||
logger.error(f"[Scheduler] Failed to create channel: {channel_type}")
|
||||
except Exception as e:
|
||||
logger.error(f"[Scheduler] Failed to send result: {e}")
|
||||
else:
|
||||
# Don't clear history - scheduler tasks use isolated session_id so they won't pollute user conversations
|
||||
reply = agent_bridge.agent_reply(task_description, context=context, on_event=None, clear_history=False)
|
||||
|
||||
if not (reply and reply.content):
|
||||
logger.error(f"[Scheduler] Task {task['id']}: No result from agent execution")
|
||||
|
||||
return True # agent ran but produced nothing; don't loop
|
||||
|
||||
from channel.channel_factory import create_channel
|
||||
channel = create_channel(channel_type)
|
||||
if not channel:
|
||||
logger.error(f"[Scheduler] Failed to create channel: {channel_type}")
|
||||
return False
|
||||
|
||||
if channel_type == "web" and hasattr(channel, 'request_to_session'):
|
||||
request_id = context.get("request_id")
|
||||
if request_id:
|
||||
channel.request_to_session[request_id] = receiver
|
||||
|
||||
try:
|
||||
channel.send(reply, context)
|
||||
except Exception as e:
|
||||
logger.error(f"[Scheduler] Failed to send result: {e}")
|
||||
return False
|
||||
|
||||
_remember_delivered_output(agent_bridge, task, channel_type, reply.content)
|
||||
logger.info(f"[Scheduler] Task {task['id']} executed successfully, result sent to {receiver}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Scheduler] Failed to execute task via Agent: {e}")
|
||||
import traceback
|
||||
logger.error(f"[Scheduler] Traceback: {traceback.format_exc()}")
|
||||
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Scheduler] Error in _execute_agent_task: {e}")
|
||||
import traceback
|
||||
logger.error(f"[Scheduler] Traceback: {traceback.format_exc()}")
|
||||
return False
|
||||
|
||||
|
||||
def _execute_send_message(task: dict, agent_bridge):
|
||||
"""
|
||||
Execute a send_message action
|
||||
|
||||
Args:
|
||||
task: Task dictionary
|
||||
agent_bridge: AgentBridge instance
|
||||
"""
|
||||
def _execute_send_message(task: dict, agent_bridge) -> bool:
|
||||
"""Execute a send_message action. Returns True/False for delivery."""
|
||||
try:
|
||||
action = task.get("action", {})
|
||||
content = action.get("content", "")
|
||||
@@ -194,7 +300,7 @@ def _execute_send_message(task: dict, agent_bridge):
|
||||
|
||||
if not receiver:
|
||||
logger.error(f"[Scheduler] Task {task['id']}: No receiver specified")
|
||||
return
|
||||
return True
|
||||
|
||||
# Create context for sending message
|
||||
context = Context(ContextType.TEXT, content)
|
||||
@@ -228,170 +334,146 @@ def _execute_send_message(task: dict, agent_bridge):
|
||||
logger.debug(f"[Scheduler] DingTalk single chat: sender_staff_id={sender_staff_id}")
|
||||
else:
|
||||
logger.warning(f"[Scheduler] Task {task['id']}: DingTalk single chat message missing sender_staff_id")
|
||||
|
||||
elif channel_type == "wecom_bot":
|
||||
context["msg"] = None
|
||||
elif channel_type == "qq":
|
||||
context["msg"] = None
|
||||
|
||||
# Create reply
|
||||
reply = Reply(ReplyType.TEXT, content)
|
||||
|
||||
# Get channel and send
|
||||
from channel.channel_factory import create_channel
|
||||
|
||||
channel = create_channel(channel_type)
|
||||
if not channel:
|
||||
logger.error(f"[Scheduler] Failed to create channel: {channel_type}")
|
||||
return False
|
||||
|
||||
if channel_type == "web" and hasattr(channel, 'request_to_session'):
|
||||
channel.request_to_session[request_id] = receiver
|
||||
|
||||
try:
|
||||
channel = create_channel(channel_type)
|
||||
if channel:
|
||||
# For web channel, register the request_id to session mapping
|
||||
if channel_type == "web" and hasattr(channel, 'request_to_session'):
|
||||
channel.request_to_session[request_id] = receiver
|
||||
logger.debug(f"[Scheduler] Registered request_id {request_id} -> session {receiver}")
|
||||
|
||||
channel.send(reply, context)
|
||||
logger.info(f"[Scheduler] Task {task['id']} executed: sent message to {receiver}")
|
||||
else:
|
||||
logger.error(f"[Scheduler] Failed to create channel: {channel_type}")
|
||||
channel.send(reply, context)
|
||||
except Exception as e:
|
||||
logger.error(f"[Scheduler] Failed to send message: {e}")
|
||||
import traceback
|
||||
logger.error(f"[Scheduler] Traceback: {traceback.format_exc()}")
|
||||
|
||||
return False
|
||||
|
||||
_remember_delivered_output(agent_bridge, task, channel_type, content)
|
||||
logger.info(f"[Scheduler] Task {task['id']} executed: sent message to {receiver}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Scheduler] Error in _execute_send_message: {e}")
|
||||
import traceback
|
||||
logger.error(f"[Scheduler] Traceback: {traceback.format_exc()}")
|
||||
return False
|
||||
|
||||
|
||||
def _execute_tool_call(task: dict, agent_bridge):
|
||||
"""
|
||||
Execute a tool_call action
|
||||
|
||||
Args:
|
||||
task: Task dictionary
|
||||
agent_bridge: AgentBridge instance
|
||||
"""
|
||||
def _execute_tool_call(task: dict, agent_bridge) -> bool:
|
||||
"""Execute a tool_call action. Returns True/False for delivery."""
|
||||
try:
|
||||
action = task.get("action", {})
|
||||
# Support both old and new field names
|
||||
tool_name = action.get("call_name") or action.get("tool_name")
|
||||
tool_params = action.get("call_params") or action.get("tool_params", {})
|
||||
result_prefix = action.get("result_prefix", "")
|
||||
receiver = action.get("receiver")
|
||||
is_group = action.get("is_group", False)
|
||||
channel_type = action.get("channel_type", "unknown")
|
||||
|
||||
|
||||
if not tool_name:
|
||||
logger.error(f"[Scheduler] Task {task['id']}: No tool_name specified")
|
||||
return
|
||||
|
||||
return True
|
||||
if not receiver:
|
||||
logger.error(f"[Scheduler] Task {task['id']}: No receiver specified")
|
||||
return
|
||||
|
||||
# Get tool manager and create tool instance
|
||||
return True
|
||||
|
||||
from agent.tools.tool_manager import ToolManager
|
||||
tool_manager = ToolManager()
|
||||
tool = tool_manager.create_tool(tool_name)
|
||||
|
||||
tool = ToolManager().create_tool(tool_name)
|
||||
if not tool:
|
||||
logger.error(f"[Scheduler] Task {task['id']}: Tool '{tool_name}' not found")
|
||||
return
|
||||
|
||||
# Execute tool
|
||||
return True
|
||||
|
||||
logger.info(f"[Scheduler] Task {task['id']}: Executing tool '{tool_name}' with params {tool_params}")
|
||||
result = tool.execute(tool_params)
|
||||
|
||||
# Get result content
|
||||
if hasattr(result, 'result'):
|
||||
content = result.result
|
||||
else:
|
||||
content = str(result)
|
||||
|
||||
# Add prefix if specified
|
||||
content = result.result if hasattr(result, 'result') else str(result)
|
||||
if result_prefix:
|
||||
content = f"{result_prefix}\n\n{content}"
|
||||
|
||||
# Send result as message
|
||||
|
||||
context = Context(ContextType.TEXT, content)
|
||||
context["receiver"] = receiver
|
||||
context["isgroup"] = is_group
|
||||
context["session_id"] = receiver
|
||||
|
||||
# Channel-specific context setup
|
||||
|
||||
request_id = None
|
||||
if channel_type == "web":
|
||||
# Web channel needs request_id
|
||||
import uuid
|
||||
request_id = f"scheduler_{task['id']}_{uuid.uuid4().hex[:8]}"
|
||||
context["request_id"] = request_id
|
||||
logger.debug(f"[Scheduler] Generated request_id for web channel: {request_id}")
|
||||
elif channel_type == "feishu":
|
||||
# Feishu channel: for scheduled tasks, send as new message (no msg_id to reply to)
|
||||
context["receive_id_type"] = "chat_id" if is_group else "open_id"
|
||||
context["msg"] = None
|
||||
logger.debug(f"[Scheduler] Feishu: receive_id_type={context['receive_id_type']}, is_group={is_group}, receiver={receiver}")
|
||||
|
||||
elif channel_type == "wecom_bot":
|
||||
context["msg"] = None
|
||||
|
||||
reply = Reply(ReplyType.TEXT, content)
|
||||
|
||||
# Get channel and send
|
||||
|
||||
from channel.channel_factory import create_channel
|
||||
|
||||
channel = create_channel(channel_type)
|
||||
if not channel:
|
||||
logger.error(f"[Scheduler] Failed to create channel: {channel_type}")
|
||||
return False
|
||||
|
||||
if channel_type == "web" and request_id and hasattr(channel, 'request_to_session'):
|
||||
channel.request_to_session[request_id] = receiver
|
||||
|
||||
try:
|
||||
channel = create_channel(channel_type)
|
||||
if channel:
|
||||
# For web channel, register the request_id to session mapping
|
||||
if channel_type == "web" and hasattr(channel, 'request_to_session'):
|
||||
channel.request_to_session[request_id] = receiver
|
||||
logger.debug(f"[Scheduler] Registered request_id {request_id} -> session {receiver}")
|
||||
|
||||
channel.send(reply, context)
|
||||
logger.info(f"[Scheduler] Task {task['id']} executed: sent tool result to {receiver}")
|
||||
else:
|
||||
logger.error(f"[Scheduler] Failed to create channel: {channel_type}")
|
||||
channel.send(reply, context)
|
||||
except Exception as e:
|
||||
logger.error(f"[Scheduler] Failed to send tool result: {e}")
|
||||
|
||||
return False
|
||||
|
||||
_remember_delivered_output(agent_bridge, task, channel_type, content)
|
||||
logger.info(f"[Scheduler] Task {task['id']} executed: sent tool result to {receiver}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Scheduler] Error in _execute_tool_call: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def _execute_skill_call(task: dict, agent_bridge):
|
||||
"""
|
||||
Execute a skill_call action by asking Agent to run the skill
|
||||
|
||||
Args:
|
||||
task: Task dictionary
|
||||
agent_bridge: AgentBridge instance
|
||||
"""
|
||||
def _execute_skill_call(task: dict, agent_bridge) -> bool:
|
||||
"""Execute a skill_call action by asking Agent to run the skill.
|
||||
Returns True/False for delivery."""
|
||||
try:
|
||||
action = task.get("action", {})
|
||||
# Support both old and new field names
|
||||
skill_name = action.get("call_name") or action.get("skill_name")
|
||||
skill_params = action.get("call_params") or action.get("skill_params", {})
|
||||
result_prefix = action.get("result_prefix", "")
|
||||
receiver = action.get("receiver")
|
||||
is_group = action.get("isgroup", False)
|
||||
channel_type = action.get("channel_type", "unknown")
|
||||
|
||||
|
||||
if not skill_name:
|
||||
logger.error(f"[Scheduler] Task {task['id']}: No skill_name specified")
|
||||
return
|
||||
|
||||
return True
|
||||
if not receiver:
|
||||
logger.error(f"[Scheduler] Task {task['id']}: No receiver specified")
|
||||
return
|
||||
|
||||
return True
|
||||
|
||||
logger.info(f"[Scheduler] Task {task['id']}: Executing skill '{skill_name}' with params {skill_params}")
|
||||
|
||||
# Build a natural language query for the Agent to execute the skill
|
||||
# Format: "Use skill-name to do something with params"
|
||||
|
||||
scheduler_session_id = f"scheduler_{receiver}_{task['id']}"
|
||||
param_str = ", ".join([f"{k}={v}" for k, v in skill_params.items()])
|
||||
query = f"Use {skill_name} skill"
|
||||
if param_str:
|
||||
query += f" with {param_str}"
|
||||
|
||||
# Create context for Agent
|
||||
|
||||
context = Context(ContextType.TEXT, query)
|
||||
context["receiver"] = receiver
|
||||
context["isgroup"] = is_group
|
||||
context["session_id"] = receiver
|
||||
|
||||
# Channel-specific setup
|
||||
context["session_id"] = scheduler_session_id
|
||||
|
||||
if channel_type == "web":
|
||||
import uuid
|
||||
request_id = f"scheduler_{task['id']}_{uuid.uuid4().hex[:8]}"
|
||||
@@ -399,31 +481,51 @@ def _execute_skill_call(task: dict, agent_bridge):
|
||||
elif channel_type == "feishu":
|
||||
context["receive_id_type"] = "chat_id" if is_group else "open_id"
|
||||
context["msg"] = None
|
||||
|
||||
# Use Agent to execute the skill
|
||||
elif channel_type == "wecom_bot":
|
||||
context["msg"] = None
|
||||
|
||||
try:
|
||||
reply = agent_bridge.agent_reply(query, context=context, on_event=None, clear_history=True)
|
||||
|
||||
if reply and reply.content:
|
||||
content = reply.content
|
||||
|
||||
# Add prefix if specified
|
||||
if result_prefix:
|
||||
content = f"{result_prefix}\n\n{content}"
|
||||
|
||||
logger.info(f"[Scheduler] Task {task['id']} executed: skill result sent to {receiver}")
|
||||
else:
|
||||
logger.error(f"[Scheduler] Task {task['id']}: No result from skill execution")
|
||||
|
||||
reply = agent_bridge.agent_reply(query, context=context, on_event=None, clear_history=False)
|
||||
except Exception as e:
|
||||
logger.error(f"[Scheduler] Failed to execute skill via Agent: {e}")
|
||||
import traceback
|
||||
logger.error(f"[Scheduler] Traceback: {traceback.format_exc()}")
|
||||
|
||||
return False
|
||||
|
||||
if not (reply and reply.content):
|
||||
logger.error(f"[Scheduler] Task {task['id']}: No result from skill execution")
|
||||
return True
|
||||
|
||||
content = reply.content
|
||||
if result_prefix:
|
||||
content = f"{result_prefix}\n\n{content}"
|
||||
|
||||
from channel.channel_factory import create_channel
|
||||
channel = create_channel(channel_type)
|
||||
if not channel:
|
||||
logger.error(f"[Scheduler] Failed to create channel: {channel_type}")
|
||||
return False
|
||||
|
||||
if channel_type == "web" and hasattr(channel, 'request_to_session'):
|
||||
req_id = context.get("request_id")
|
||||
if req_id:
|
||||
channel.request_to_session[req_id] = receiver
|
||||
|
||||
try:
|
||||
channel.send(Reply(ReplyType.TEXT, content), context)
|
||||
except Exception as e:
|
||||
logger.error(f"[Scheduler] Failed to send skill result: {e}")
|
||||
return False
|
||||
|
||||
_remember_delivered_output(agent_bridge, task, channel_type, content)
|
||||
logger.info(f"[Scheduler] Task {task['id']} executed: skill result sent to {receiver}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Scheduler] Error in _execute_skill_call: {e}")
|
||||
import traceback
|
||||
logger.error(f"[Scheduler] Traceback: {traceback.format_exc()}")
|
||||
return False
|
||||
|
||||
|
||||
def attach_scheduler_to_tool(tool, context: Context = None):
|
||||
@@ -440,8 +542,7 @@ def attach_scheduler_to_tool(tool, context: Context = None):
|
||||
if context:
|
||||
tool.current_context = context
|
||||
|
||||
# Also set channel_type from config
|
||||
channel_type = conf().get("channel_type", "unknown")
|
||||
channel_type = context.get("channel_type") or conf().get("channel_type", "unknown")
|
||||
if not tool.config:
|
||||
tool.config = {}
|
||||
tool.config["channel_type"] = channel_type
|
||||
|
||||
@@ -10,6 +10,19 @@ from croniter import croniter
|
||||
from common.log import logger
|
||||
|
||||
|
||||
def _parse_naive_local(iso_str: str) -> datetime:
|
||||
"""Parse an ISO datetime and coerce it to tz-naive local time.
|
||||
|
||||
The scheduler uses ``datetime.now()`` (tz-naive) for all comparisons,
|
||||
so any persisted timestamp must be normalized to the same flavor —
|
||||
otherwise comparing naive vs aware raises TypeError.
|
||||
"""
|
||||
dt = datetime.fromisoformat(iso_str)
|
||||
if dt.tzinfo is not None:
|
||||
dt = dt.astimezone().replace(tzinfo=None)
|
||||
return dt
|
||||
|
||||
|
||||
class SchedulerService:
|
||||
"""
|
||||
Background service that executes scheduled tasks
|
||||
@@ -39,7 +52,6 @@ class SchedulerService:
|
||||
self.running = True
|
||||
self.thread = threading.Thread(target=self._run_loop, daemon=True)
|
||||
self.thread.start()
|
||||
logger.debug("[Scheduler] Service started")
|
||||
|
||||
def stop(self):
|
||||
"""Stop the scheduler service"""
|
||||
@@ -54,15 +66,14 @@ class SchedulerService:
|
||||
|
||||
def _run_loop(self):
|
||||
"""Main scheduler loop"""
|
||||
logger.debug("[Scheduler] Scheduler loop started")
|
||||
logger.info("[Scheduler] Scheduler loop started")
|
||||
|
||||
while self.running:
|
||||
try:
|
||||
self._check_and_execute_tasks()
|
||||
except Exception as e:
|
||||
logger.error(f"[Scheduler] Error in scheduler loop: {e}")
|
||||
|
||||
# Sleep for 30 seconds between checks
|
||||
|
||||
time.sleep(30)
|
||||
|
||||
def _check_and_execute_tasks(self):
|
||||
@@ -72,12 +83,18 @@ class SchedulerService:
|
||||
|
||||
for task in tasks:
|
||||
try:
|
||||
# Check if task is due
|
||||
if self._is_task_due(task, now):
|
||||
logger.info(f"[Scheduler] Executing task: {task['id']} - {task['name']}")
|
||||
self._execute_task(task)
|
||||
|
||||
# Update next run time
|
||||
ok = self._execute_task(task)
|
||||
if not ok:
|
||||
# Leave next_run_at as-is so the next loop retries.
|
||||
# Cron tasks within the catch-up window will keep
|
||||
# firing; beyond it _is_task_due will reschedule.
|
||||
logger.warning(
|
||||
f"[Scheduler] Task {task['id']} delivery failed, will retry next tick"
|
||||
)
|
||||
continue
|
||||
|
||||
next_run = self._calculate_next_run(task, now)
|
||||
if next_run:
|
||||
self.task_store.update_task(task['id'], {
|
||||
@@ -85,12 +102,8 @@ class SchedulerService:
|
||||
"last_run_at": now.isoformat()
|
||||
})
|
||||
else:
|
||||
# One-time task, disable it
|
||||
self.task_store.update_task(task['id'], {
|
||||
"enabled": False,
|
||||
"last_run_at": now.isoformat()
|
||||
})
|
||||
logger.info(f"[Scheduler] One-time task completed and disabled: {task['id']}")
|
||||
self.task_store.delete_task(task['id'])
|
||||
logger.info(f"[Scheduler] One-time task completed and removed: {task['id']}")
|
||||
except Exception as e:
|
||||
logger.error(f"[Scheduler] Error processing task {task.get('id')}: {e}")
|
||||
|
||||
@@ -117,37 +130,43 @@ class SchedulerService:
|
||||
return False
|
||||
|
||||
try:
|
||||
next_run = datetime.fromisoformat(next_run_str)
|
||||
|
||||
# Check if task is overdue (e.g., service restart)
|
||||
next_run = _parse_naive_local(next_run_str)
|
||||
|
||||
if next_run < now:
|
||||
time_diff = (now - next_run).total_seconds()
|
||||
|
||||
# If overdue by more than 5 minutes, skip this run and schedule next
|
||||
if time_diff > 300: # 5 minutes
|
||||
logger.warning(f"[Scheduler] Task {task['id']} is overdue by {int(time_diff)}s, skipping and scheduling next run")
|
||||
|
||||
# For one-time tasks, disable them
|
||||
schedule = task.get("schedule", {})
|
||||
if schedule.get("type") == "once":
|
||||
self.task_store.update_task(task['id'], {
|
||||
"enabled": False,
|
||||
"last_run_at": now.isoformat()
|
||||
})
|
||||
logger.info(f"[Scheduler] One-time task {task['id']} expired, disabled")
|
||||
return False
|
||||
|
||||
# For recurring tasks, calculate next run from now
|
||||
next_next_run = self._calculate_next_run(task, now)
|
||||
if next_next_run:
|
||||
self.task_store.update_task(task['id'], {
|
||||
"next_run_at": next_next_run.isoformat()
|
||||
})
|
||||
logger.info(f"[Scheduler] Rescheduled task {task['id']} to {next_next_run}")
|
||||
schedule = task.get("schedule", {})
|
||||
schedule_type = schedule.get("type")
|
||||
|
||||
# Catch-up window: fire if we're within 10 minutes of the
|
||||
# scheduled tick. Beyond that we'd rather skip than push a
|
||||
# stale daily report to the user.
|
||||
if time_diff <= 600:
|
||||
return True
|
||||
|
||||
logger.warning(
|
||||
f"[Scheduler] Task {task['id']} is overdue by {int(time_diff)}s, "
|
||||
f"skipping and scheduling next run"
|
||||
)
|
||||
|
||||
if schedule_type == "once":
|
||||
self.task_store.delete_task(task['id'])
|
||||
logger.info(f"[Scheduler] One-time task {task['id']} expired, removed")
|
||||
return False
|
||||
|
||||
|
||||
next_next_run = self._calculate_next_run(task, now)
|
||||
if next_next_run:
|
||||
self.task_store.update_task(task['id'], {
|
||||
"next_run_at": next_next_run.isoformat()
|
||||
})
|
||||
logger.info(f"[Scheduler] Rescheduled task {task['id']} to {next_next_run}")
|
||||
return False
|
||||
|
||||
return now >= next_run
|
||||
except:
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[Scheduler] Failed to evaluate due-state for task "
|
||||
f"{task.get('id')} (next_run_at={next_run_str!r}): {e}"
|
||||
)
|
||||
return False
|
||||
|
||||
def _calculate_next_run(self, task: dict, from_time: datetime) -> Optional[datetime]:
|
||||
@@ -191,30 +210,34 @@ class SchedulerService:
|
||||
return None
|
||||
|
||||
try:
|
||||
run_at = datetime.fromisoformat(run_at_str)
|
||||
# Only return if in the future
|
||||
run_at = _parse_naive_local(run_at_str)
|
||||
if run_at > from_time:
|
||||
return run_at
|
||||
except:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[Scheduler] Failed to parse once-task run_at "
|
||||
f"{run_at_str!r}: {e}"
|
||||
)
|
||||
return None
|
||||
|
||||
return None
|
||||
|
||||
def _execute_task(self, task: dict):
|
||||
def _execute_task(self, task: dict) -> bool:
|
||||
"""
|
||||
Execute a task
|
||||
|
||||
Args:
|
||||
task: Task dictionary
|
||||
Execute a task.
|
||||
|
||||
Returns True if delivery succeeded (caller should advance state),
|
||||
False if it failed (caller should keep next_run_at so the next
|
||||
loop iteration retries). Callback may return None for legacy
|
||||
behaviour, treated as success.
|
||||
"""
|
||||
try:
|
||||
# Call the execute callback
|
||||
self.execute_callback(task)
|
||||
result = self.execute_callback(task)
|
||||
return False if result is False else True
|
||||
except Exception as e:
|
||||
logger.error(f"[Scheduler] Error executing task {task['id']}: {e}")
|
||||
# Update task with error
|
||||
self.task_store.update_task(task['id'], {
|
||||
"last_error": str(e),
|
||||
"last_error_at": datetime.now().isoformat()
|
||||
})
|
||||
return False
|
||||
|
||||
@@ -20,7 +20,8 @@ class SchedulerTool(BaseTool):
|
||||
|
||||
name: str = "scheduler"
|
||||
description: str = (
|
||||
"创建、查询和管理定时任务。支持固定消息和AI任务两种类型。\n\n"
|
||||
"创建、查询和管理定时任务(提醒、周期性任务等)。\n\n"
|
||||
"⚠️ 重要:仅当需要「定时/提醒/每天/每周/X分钟后/X点」等延迟或周期执行时才使用此工具。"
|
||||
"使用方法:\n"
|
||||
"- 创建:action='create', name='任务名', message/ai_task='内容', schedule_type='once/interval/cron', schedule_value='...'\n"
|
||||
"- 查询:action='list' / action='get', task_id='任务ID'\n"
|
||||
@@ -53,7 +54,7 @@ class SchedulerTool(BaseTool):
|
||||
},
|
||||
"ai_task": {
|
||||
"type": "string",
|
||||
"description": "AI任务描述 (与message二选一),如'搜索今日新闻'、'查询天气'"
|
||||
"description": "AI任务描述 (与message二选一),用于定时让AI执行的任务"
|
||||
},
|
||||
"schedule_type": {
|
||||
"type": "string",
|
||||
@@ -157,6 +158,11 @@ class SchedulerTool(BaseTool):
|
||||
# Create task
|
||||
task_id = str(uuid.uuid4())[:8]
|
||||
|
||||
# Capture the real chat session_id at task creation time so that scheduler
|
||||
# can later inject the delivered output into the user's actual conversation
|
||||
# (in group chats, session_id != receiver, e.g. "user_id:group_id" on feishu).
|
||||
notify_session_id = context.get("session_id")
|
||||
|
||||
# Build action based on message or ai_task
|
||||
if message:
|
||||
action = {
|
||||
@@ -165,7 +171,8 @@ class SchedulerTool(BaseTool):
|
||||
"receiver": context.get("receiver"),
|
||||
"receiver_name": self._get_receiver_name(context),
|
||||
"is_group": context.get("isgroup", False),
|
||||
"channel_type": self.config.get("channel_type", "unknown")
|
||||
"channel_type": self.config.get("channel_type", "unknown"),
|
||||
"notify_session_id": notify_session_id,
|
||||
}
|
||||
else: # ai_task
|
||||
action = {
|
||||
@@ -174,7 +181,8 @@ class SchedulerTool(BaseTool):
|
||||
"receiver": context.get("receiver"),
|
||||
"receiver_name": self._get_receiver_name(context),
|
||||
"is_group": context.get("isgroup", False),
|
||||
"channel_type": self.config.get("channel_type", "unknown")
|
||||
"channel_type": self.config.get("channel_type", "unknown"),
|
||||
"notify_session_id": notify_session_id,
|
||||
}
|
||||
|
||||
# 针对钉钉单聊,额外存储 sender_staff_id
|
||||
@@ -356,9 +364,12 @@ class SchedulerTool(BaseTool):
|
||||
logger.error(f"[SchedulerTool] Invalid relative time format: {schedule_value}")
|
||||
return None
|
||||
else:
|
||||
# Absolute time in ISO format
|
||||
datetime.fromisoformat(schedule_value)
|
||||
return {"type": "once", "run_at": schedule_value}
|
||||
# Absolute ISO time. Normalize to tz-naive local so it
|
||||
# stays comparable with the scheduler's datetime.now().
|
||||
parsed = datetime.fromisoformat(schedule_value)
|
||||
if parsed.tzinfo is not None:
|
||||
parsed = parsed.astimezone().replace(tzinfo=None)
|
||||
return {"type": "once", "run_at": parsed.isoformat()}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[SchedulerTool] Invalid schedule: {e}")
|
||||
@@ -423,7 +434,7 @@ class SchedulerTool(BaseTool):
|
||||
try:
|
||||
dt = datetime.fromisoformat(run_at)
|
||||
return f"一次性 ({dt.strftime('%Y-%m-%d %H:%M')})"
|
||||
except:
|
||||
except Exception:
|
||||
return "一次性"
|
||||
|
||||
return "未知"
|
||||
@@ -437,6 +448,6 @@ class SchedulerTool(BaseTool):
|
||||
return msg.other_user_nickname or "群聊"
|
||||
else:
|
||||
return msg.from_user_nickname or "用户"
|
||||
except:
|
||||
except Exception:
|
||||
pass
|
||||
return "未知"
|
||||
|
||||
@@ -8,6 +8,7 @@ import threading
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional
|
||||
from pathlib import Path
|
||||
from common.utils import expand_path
|
||||
|
||||
|
||||
class TaskStore:
|
||||
@@ -24,7 +25,7 @@ class TaskStore:
|
||||
"""
|
||||
if store_path is None:
|
||||
# Default to ~/cow/scheduler/tasks.json
|
||||
home = os.path.expanduser("~")
|
||||
home = expand_path("~")
|
||||
store_path = os.path.join(home, "cow", "scheduler", "tasks.json")
|
||||
|
||||
self.store_path = store_path
|
||||
@@ -71,7 +72,7 @@ class TaskStore:
|
||||
with open(self.store_path, 'r') as src:
|
||||
with open(backup_path, 'w') as dst:
|
||||
dst.write(src.read())
|
||||
except:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Save tasks
|
||||
|
||||
@@ -7,20 +7,21 @@ from typing import Dict, Any
|
||||
from pathlib import Path
|
||||
|
||||
from agent.tools.base_tool import BaseTool, ToolResult
|
||||
from common.utils import expand_path
|
||||
|
||||
|
||||
class Send(BaseTool):
|
||||
"""Tool for sending files to the user"""
|
||||
|
||||
name: str = "send"
|
||||
description: str = "Send a file (image, video, audio, document) to the user. Use this when the user explicitly asks to send/share a file."
|
||||
description: str = "Send a LOCAL file (image, video, audio, document) to the user. Only for local file paths. Do NOT use this for URLs — URLs should be included directly in your text reply, the system will handle them automatically."
|
||||
|
||||
params: dict = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "Path to the file to send. Can be absolute path or relative to workspace."
|
||||
"description": "Local file path to send. Must be an absolute path or relative to workspace. Do NOT pass URLs here."
|
||||
},
|
||||
"message": {
|
||||
"type": "string",
|
||||
@@ -97,12 +98,23 @@ class Send(BaseTool):
|
||||
"size_formatted": self._format_size(file_size),
|
||||
"message": message or f"正在发送 {file_name}"
|
||||
}
|
||||
|
||||
|
||||
try:
|
||||
from common.cloud_client import get_website_base_url, copy_send_file
|
||||
|
||||
# Do nothing when in local env
|
||||
if get_website_base_url():
|
||||
url = copy_send_file(absolute_path, self.cwd)
|
||||
if url:
|
||||
result["url"] = url
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return ToolResult.success(result)
|
||||
|
||||
def _resolve_path(self, path: str) -> str:
|
||||
"""Resolve path to absolute path"""
|
||||
path = os.path.expanduser(path)
|
||||
path = expand_path(path)
|
||||
if os.path.isabs(path):
|
||||
return path
|
||||
return os.path.abspath(os.path.join(self.cwd, path))
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import importlib
|
||||
import importlib.util
|
||||
import threading
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, Type
|
||||
from agent.tools.base_tool import BaseTool
|
||||
@@ -7,6 +8,26 @@ from common.log import logger
|
||||
from config import conf
|
||||
|
||||
|
||||
def _normalize_mcp_configs(raw) -> list:
|
||||
"""
|
||||
Convert MCP server config to internal list format.
|
||||
Supports:
|
||||
- list format (mcp_servers): [{"name": "x", "type": "stdio", ...}]
|
||||
- dict format (mcpServers): {"x": {"command": "npx", ...}}
|
||||
"""
|
||||
if isinstance(raw, list):
|
||||
return raw
|
||||
if isinstance(raw, dict):
|
||||
result = []
|
||||
for name, cfg in raw.items():
|
||||
entry = {"name": name, **cfg}
|
||||
if "type" not in entry:
|
||||
entry["type"] = "sse" if "url" in entry else "stdio"
|
||||
result.append(entry)
|
||||
return result
|
||||
return []
|
||||
|
||||
|
||||
class ToolManager:
|
||||
"""
|
||||
Tool manager for managing tools.
|
||||
@@ -25,6 +46,31 @@ class ToolManager:
|
||||
# Initialize only once
|
||||
if not hasattr(self, 'tool_classes'):
|
||||
self.tool_classes = {} # Dictionary to store tool classes
|
||||
if not hasattr(self, '_mcp_registry'):
|
||||
self._mcp_registry = None # Lazy init: only created when MCP servers are configured
|
||||
if not hasattr(self, '_mcp_tool_instances'):
|
||||
self._mcp_tool_instances: dict = {} # tool_name -> McpTool instance
|
||||
if not hasattr(self, '_mcp_lock'):
|
||||
# Guards _mcp_loaded check-then-set so concurrent callers
|
||||
# don't trigger duplicate background loaders.
|
||||
self._mcp_lock = threading.Lock()
|
||||
if not hasattr(self, '_mcp_loaded'):
|
||||
# Idempotency flag. Flipped to True the moment the first loader
|
||||
# is dispatched (synchronously, inside _mcp_lock). Subsequent
|
||||
# _load_mcp_tools() calls become no-ops, so per-session agent
|
||||
# initialization never re-forks MCP subprocesses.
|
||||
self._mcp_loaded = False
|
||||
if not hasattr(self, '_mcp_status'):
|
||||
# server_name -> "pending" / "ready" / "failed"
|
||||
# Useful for UI / introspection while async loading is in progress.
|
||||
self._mcp_status: dict = {}
|
||||
if not hasattr(self, '_mcp_signature'):
|
||||
# (mtime, sha256) of mcp.json the last time we loaded.
|
||||
# Used by refresh_mcp_if_changed() to skip re-parsing when nothing changed.
|
||||
self._mcp_signature: tuple = (None, None)
|
||||
if not hasattr(self, '_mcp_active_configs'):
|
||||
# server_name -> normalized config dict, for diff-based reload.
|
||||
self._mcp_active_configs: dict = {}
|
||||
|
||||
def load_tools(self, tools_dir: str = "", config_dict=None):
|
||||
"""
|
||||
@@ -39,6 +85,8 @@ class ToolManager:
|
||||
self._load_tools_from_init()
|
||||
self._configure_tools_from_config(config_dict)
|
||||
|
||||
self._load_mcp_tools()
|
||||
|
||||
def _load_tools_from_init(self) -> bool:
|
||||
"""
|
||||
Load tool classes from tools.__init__.__all__
|
||||
@@ -70,10 +118,14 @@ class ToolManager:
|
||||
and cls != BaseTool
|
||||
):
|
||||
try:
|
||||
# Skip memory tools (they need special initialization with memory_manager)
|
||||
# Skip tools that need special initialization
|
||||
if class_name in ["MemorySearchTool", "MemoryGetTool"]:
|
||||
logger.debug(f"Skipped tool {class_name} (requires memory_manager)")
|
||||
continue
|
||||
# McpTool instances are registered dynamically via _load_mcp_tools()
|
||||
if class_name == "McpTool":
|
||||
logger.debug(f"Skipped tool {class_name} (registered dynamically via mcp_servers config)")
|
||||
continue
|
||||
|
||||
# Create a temporary instance to get the name
|
||||
temp_instance = cls()
|
||||
@@ -84,11 +136,11 @@ class ToolManager:
|
||||
except ImportError as e:
|
||||
# Handle missing dependencies with helpful messages
|
||||
error_msg = str(e)
|
||||
if "browser-use" in error_msg or "browser_use" in error_msg:
|
||||
if "playwright" in error_msg:
|
||||
logger.warning(
|
||||
f"[ToolManager] Browser tool not loaded - missing dependencies.\n"
|
||||
f" To enable browser tool, run:\n"
|
||||
f" pip install browser-use markdownify playwright\n"
|
||||
f" pip install playwright\n"
|
||||
f" playwright install chromium"
|
||||
)
|
||||
elif "markdownify" in error_msg:
|
||||
@@ -154,11 +206,11 @@ class ToolManager:
|
||||
except ImportError as e:
|
||||
# Handle missing dependencies with helpful messages
|
||||
error_msg = str(e)
|
||||
if "browser-use" in error_msg or "browser_use" in error_msg:
|
||||
if "playwright" in error_msg:
|
||||
logger.warning(
|
||||
f"[ToolManager] Browser tool not loaded - missing dependencies.\n"
|
||||
f" To enable browser tool, run:\n"
|
||||
f" pip install browser-use markdownify playwright\n"
|
||||
f" pip install playwright\n"
|
||||
f" playwright install chromium"
|
||||
)
|
||||
elif "markdownify" in error_msg:
|
||||
@@ -197,7 +249,7 @@ class ToolManager:
|
||||
logger.warning(
|
||||
f"[ToolManager] Browser tool is configured but not loaded.\n"
|
||||
f" To enable browser tool, run:\n"
|
||||
f" pip install browser-use markdownify playwright\n"
|
||||
f" pip install playwright\n"
|
||||
f" playwright install chromium"
|
||||
)
|
||||
elif tool_name == "google_search":
|
||||
@@ -212,6 +264,306 @@ class ToolManager:
|
||||
except Exception as e:
|
||||
logger.error(f"Error configuring tools from config: {e}")
|
||||
|
||||
def _mcp_json_path(self) -> str:
|
||||
import os
|
||||
workspace = os.path.expanduser(conf().get("agent_workspace", "~/cow"))
|
||||
return os.path.join(workspace, "mcp.json")
|
||||
|
||||
def _read_mcp_json_signature(self):
|
||||
"""
|
||||
Return (mtime, sha256_of_bytes) for ~/cow/mcp.json without parsing.
|
||||
Returns (None, None) if the file doesn't exist or is unreadable.
|
||||
Cheap enough (one stat + one small read) to call on every agent init.
|
||||
"""
|
||||
import os
|
||||
import hashlib
|
||||
path = self._mcp_json_path()
|
||||
try:
|
||||
mtime = os.path.getmtime(path)
|
||||
except OSError:
|
||||
return (None, None)
|
||||
try:
|
||||
with open(path, "rb") as f:
|
||||
digest = hashlib.sha256(f.read()).hexdigest()
|
||||
except OSError:
|
||||
return (mtime, None)
|
||||
return (mtime, digest)
|
||||
|
||||
def _load_mcp_configs(self) -> list:
|
||||
"""
|
||||
Load MCP server configs with priority:
|
||||
1. ~/cow/mcp.json (supports both mcpServers and mcp_servers keys)
|
||||
2. config.json mcp_servers field (fallback)
|
||||
"""
|
||||
import os
|
||||
import json as _json
|
||||
|
||||
mcp_json_path = self._mcp_json_path()
|
||||
|
||||
if os.path.exists(mcp_json_path):
|
||||
try:
|
||||
with open(mcp_json_path, "r", encoding="utf-8") as f:
|
||||
data = _json.load(f)
|
||||
raw = data.get("mcpServers") or data.get("mcp_servers") or data
|
||||
logger.info(f"[ToolManager] Loading MCP config from {mcp_json_path}")
|
||||
return _normalize_mcp_configs(raw)
|
||||
except Exception as e:
|
||||
logger.warning(f"[ToolManager] Failed to read {mcp_json_path}: {e}, falling back to config.json")
|
||||
|
||||
raw = conf().get("mcp_servers", [])
|
||||
return _normalize_mcp_configs(raw)
|
||||
|
||||
def _load_mcp_tools(self):
|
||||
"""
|
||||
Trigger MCP tool loading in a background thread (idempotent).
|
||||
|
||||
Returns immediately. Booting MCP servers (npx, uvx, etc.) takes
|
||||
seconds to tens of seconds on first run, which would otherwise
|
||||
block agent initialization and the user's first message.
|
||||
Built-in tools work fine without MCP, so we let the agent serve
|
||||
traffic right away and let MCP servers come online in the
|
||||
background. Per-session agents read a snapshot of whatever is
|
||||
ready at construction time and gracefully ignore the rest.
|
||||
"""
|
||||
with self._mcp_lock:
|
||||
if self._mcp_loaded:
|
||||
return
|
||||
mcp_servers_config = self._load_mcp_configs()
|
||||
# Snapshot the signature now so future refresh_mcp_if_changed()
|
||||
# calls can short-circuit when nothing has changed on disk.
|
||||
self._mcp_signature = self._read_mcp_json_signature()
|
||||
self._mcp_active_configs = {
|
||||
cfg.get("name", "<unnamed>"): cfg for cfg in mcp_servers_config
|
||||
}
|
||||
if not mcp_servers_config:
|
||||
# Mark as loaded even when there is nothing to load,
|
||||
# so we don't re-read the config file on every call.
|
||||
self._mcp_loaded = True
|
||||
return
|
||||
|
||||
# Mark pending immediately so list_mcp_status() callers see
|
||||
# the in-progress state instead of an empty dict.
|
||||
for cfg in mcp_servers_config:
|
||||
name = cfg.get("name", "<unnamed>")
|
||||
self._mcp_status[name] = "pending"
|
||||
|
||||
self._mcp_loaded = True
|
||||
threading.Thread(
|
||||
target=self._load_mcp_tools_async,
|
||||
args=(mcp_servers_config,),
|
||||
daemon=True,
|
||||
name="mcp-loader",
|
||||
).start()
|
||||
logger.info(
|
||||
f"[ToolManager] MCP loading started in background "
|
||||
f"({len(mcp_servers_config)} server(s) configured)"
|
||||
)
|
||||
|
||||
def refresh_mcp_if_changed(self):
|
||||
"""
|
||||
Cheap check whether ~/cow/mcp.json has changed since last load.
|
||||
If it has, do a diff-based reload: start newly added servers,
|
||||
shut down removed ones, and restart any whose config was edited.
|
||||
Untouched servers are left running.
|
||||
|
||||
Designed to be called on every agent creation. The fast path is
|
||||
a single os.stat() — completely free when nothing has changed.
|
||||
"""
|
||||
with self._mcp_lock:
|
||||
new_sig = self._read_mcp_json_signature()
|
||||
if new_sig == self._mcp_signature:
|
||||
return # no-op fast path
|
||||
|
||||
try:
|
||||
new_configs = self._load_mcp_configs()
|
||||
except Exception as e:
|
||||
logger.warning(f"[ToolManager] MCP reload — failed to parse config: {e}")
|
||||
return
|
||||
|
||||
new_by_name = {
|
||||
cfg.get("name", "<unnamed>"): cfg for cfg in new_configs
|
||||
}
|
||||
old_by_name = self._mcp_active_configs
|
||||
|
||||
added = [n for n in new_by_name if n not in old_by_name]
|
||||
removed = [n for n in old_by_name if n not in new_by_name]
|
||||
changed = [
|
||||
n for n in new_by_name
|
||||
if n in old_by_name and new_by_name[n] != old_by_name[n]
|
||||
]
|
||||
|
||||
if not (added or removed or changed):
|
||||
# Signature drifted but content is logically identical
|
||||
# (e.g. user re-saved the file without edits). Just sync.
|
||||
self._mcp_signature = new_sig
|
||||
return
|
||||
|
||||
logger.info(
|
||||
f"[ToolManager] mcp.json changed — "
|
||||
f"adding={added}, removing={removed}, restarting={changed}"
|
||||
)
|
||||
|
||||
# Tear down removed + changed servers (changed ones get restarted below)
|
||||
for name in removed + changed:
|
||||
self._teardown_mcp_server(name)
|
||||
|
||||
# Spin up newly added + changed servers in the background
|
||||
to_start = [new_by_name[n] for n in added + changed]
|
||||
if to_start:
|
||||
for cfg in to_start:
|
||||
self._mcp_status[cfg.get("name", "<unnamed>")] = "pending"
|
||||
threading.Thread(
|
||||
target=self._load_mcp_tools_async,
|
||||
args=(to_start,),
|
||||
daemon=True,
|
||||
name="mcp-loader-reload",
|
||||
).start()
|
||||
|
||||
self._mcp_active_configs = new_by_name
|
||||
self._mcp_signature = new_sig
|
||||
|
||||
def _teardown_mcp_server(self, server_name: str):
|
||||
"""Shut down one MCP server and drop its tools from the registry."""
|
||||
if self._mcp_registry is None:
|
||||
return
|
||||
client = None
|
||||
with self._mcp_registry._registry_lock:
|
||||
client = self._mcp_registry._clients.pop(server_name, None)
|
||||
if client is not None:
|
||||
try:
|
||||
client.shutdown()
|
||||
except Exception as e:
|
||||
logger.warning(f"[MCP] Error shutting down '{server_name}': {e}")
|
||||
# Drop tools that belonged to this server.
|
||||
for tool_name in list(self._mcp_tool_instances.keys()):
|
||||
tool = self._mcp_tool_instances.get(tool_name)
|
||||
if tool is not None and getattr(tool, "server_name", None) == server_name:
|
||||
self._mcp_tool_instances.pop(tool_name, None)
|
||||
self._mcp_status.pop(server_name, None)
|
||||
|
||||
def _load_mcp_tools_async(self, mcp_servers_config):
|
||||
"""
|
||||
Background worker: bring up each MCP server one-by-one and
|
||||
publish ready tools to _mcp_tool_instances as they come online.
|
||||
|
||||
Server failures are isolated — one bad server cannot block
|
||||
the others, and never raises out of the worker thread.
|
||||
"""
|
||||
try:
|
||||
from agent.tools.mcp.mcp_client import McpClient, McpClientRegistry
|
||||
from agent.tools.mcp.mcp_tool import McpTool
|
||||
|
||||
registry = McpClientRegistry()
|
||||
self._mcp_registry = registry
|
||||
|
||||
for cfg in mcp_servers_config:
|
||||
server_name = cfg.get("name", "<unnamed>")
|
||||
try:
|
||||
client = McpClient(cfg)
|
||||
if not client.initialize():
|
||||
self._mcp_status[server_name] = "failed"
|
||||
logger.warning(
|
||||
f"[MCP] Server '{server_name}' failed to initialize — skipping"
|
||||
)
|
||||
continue
|
||||
|
||||
tool_schemas = client.list_tools()
|
||||
added = []
|
||||
for schema in tool_schemas:
|
||||
tool_name = schema.get("name", "")
|
||||
if not tool_name:
|
||||
continue
|
||||
mcp_tool = McpTool(client, schema, server_name)
|
||||
# Atomic dict assignment is GIL-safe; readers iterate
|
||||
# over a list() snapshot to avoid concurrent mutation.
|
||||
self._mcp_tool_instances[tool_name] = mcp_tool
|
||||
added.append(tool_name)
|
||||
|
||||
# Register client into the shared registry only after its
|
||||
# tools are visible, so callers never see a half-loaded server.
|
||||
with registry._registry_lock:
|
||||
registry._clients[server_name] = client
|
||||
self._mcp_status[server_name] = "ready"
|
||||
logger.info(
|
||||
f"[MCP] Server '{server_name}' ready — "
|
||||
f"{len(added)} tool(s): {added}"
|
||||
)
|
||||
except Exception as e:
|
||||
self._mcp_status[server_name] = "failed"
|
||||
logger.warning(f"[MCP] Server '{server_name}' load failed: {e}")
|
||||
|
||||
ready = sum(1 for s in self._mcp_status.values() if s == "ready")
|
||||
total = len(self._mcp_status)
|
||||
logger.info(
|
||||
f"[ToolManager] MCP loading complete: "
|
||||
f"{ready}/{total} server(s) ready, "
|
||||
f"{len(self._mcp_tool_instances)} tool(s) available"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"[ToolManager] MCP background loader crashed: {e}")
|
||||
|
||||
def list_mcp_status(self) -> dict:
|
||||
"""Return {server_name: status} snapshot for UI / debugging."""
|
||||
return dict(self._mcp_status)
|
||||
|
||||
def sync_mcp_into_agent(self, agent) -> tuple:
|
||||
"""
|
||||
Reconcile a live agent's tool collection with the current MCP tool registry.
|
||||
|
||||
Adds tools that finished loading after the agent was created,
|
||||
and removes tools whose MCP server was torn down. Built-in tools
|
||||
on the agent are left untouched.
|
||||
|
||||
Handles both representations CowAgent uses:
|
||||
- Agent.tools: list[BaseTool] (default Agent class)
|
||||
- AgentStream.tools: dict[str, BaseTool] (streaming agent)
|
||||
|
||||
Returns (added_names, removed_names) for logging.
|
||||
"""
|
||||
if agent is None or not hasattr(agent, "tools"):
|
||||
return ([], [])
|
||||
|
||||
from agent.tools.mcp.mcp_tool import McpTool
|
||||
current = self._mcp_tool_instances
|
||||
registry_names = set(current.keys())
|
||||
|
||||
agent_tools = agent.tools
|
||||
|
||||
if isinstance(agent_tools, dict):
|
||||
agent_mcp_names = {
|
||||
name for name, tool in agent_tools.items()
|
||||
if isinstance(tool, McpTool)
|
||||
}
|
||||
added = registry_names - agent_mcp_names
|
||||
removed = agent_mcp_names - registry_names
|
||||
if not (added or removed):
|
||||
return ([], [])
|
||||
for name in added:
|
||||
agent_tools[name] = current[name]
|
||||
for name in removed:
|
||||
agent_tools.pop(name, None)
|
||||
|
||||
elif isinstance(agent_tools, list):
|
||||
agent_mcp_names = {
|
||||
t.name for t in agent_tools if isinstance(t, McpTool)
|
||||
}
|
||||
added = registry_names - agent_mcp_names
|
||||
removed = agent_mcp_names - registry_names
|
||||
if not (added or removed):
|
||||
return ([], [])
|
||||
if removed:
|
||||
agent.tools = [
|
||||
t for t in agent_tools
|
||||
if not (isinstance(t, McpTool) and t.name in removed)
|
||||
]
|
||||
for name in added:
|
||||
agent.tools.append(current[name])
|
||||
|
||||
else:
|
||||
return ([], [])
|
||||
|
||||
return (sorted(added), sorted(removed))
|
||||
|
||||
def create_tool(self, name: str) -> BaseTool:
|
||||
"""
|
||||
Get a new instance of a tool by name.
|
||||
@@ -229,6 +581,12 @@ class ToolManager:
|
||||
tool_instance.config = self.tool_configs[name]
|
||||
|
||||
return tool_instance
|
||||
|
||||
# Fall back to MCP tool instances
|
||||
mcp_tool = self._mcp_tool_instances.get(name)
|
||||
if mcp_tool:
|
||||
return mcp_tool
|
||||
|
||||
return None
|
||||
|
||||
def list_tools(self) -> dict:
|
||||
@@ -245,4 +603,17 @@ class ToolManager:
|
||||
"description": temp_instance.description,
|
||||
"parameters": temp_instance.get_json_schema()
|
||||
}
|
||||
|
||||
# Include MCP tool instances
|
||||
for name, mcp_tool in self._mcp_tool_instances.items():
|
||||
result[name] = {
|
||||
"description": mcp_tool.description,
|
||||
"parameters": mcp_tool.params,
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
def shutdown_mcp(self):
|
||||
"""Shut down all MCP server clients."""
|
||||
if self._mcp_registry:
|
||||
self._mcp_registry.shutdown_all()
|
||||
|
||||
@@ -8,7 +8,10 @@ Truncation is based on two independent limits - whichever is hit first wins:
|
||||
Never returns partial lines (except bash tail truncation edge case).
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, Optional, Literal
|
||||
from __future__ import annotations
|
||||
from typing import Dict, Any, Optional, Tuple, TYPE_CHECKING
|
||||
if TYPE_CHECKING:
|
||||
from typing import Literal
|
||||
|
||||
|
||||
DEFAULT_MAX_LINES = 2000
|
||||
@@ -278,7 +281,7 @@ def _truncate_string_to_bytes_from_end(text: str, max_bytes: int) -> str:
|
||||
return encoded[start:].decode('utf-8', errors='ignore')
|
||||
|
||||
|
||||
def truncate_line(line: str, max_chars: int = GREP_MAX_LINE_LENGTH) -> tuple[str, bool]:
|
||||
def truncate_line(line: str, max_chars: int = GREP_MAX_LINE_LENGTH) -> Tuple[str, bool]:
|
||||
"""
|
||||
Truncate single line to max characters, add [truncated] suffix.
|
||||
Used for grep match lines.
|
||||
|
||||
1
agent/tools/vision/__init__.py
Normal file
1
agent/tools/vision/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from agent.tools.vision.vision import Vision
|
||||
814
agent/tools/vision/vision.py
Normal file
814
agent/tools/vision/vision.py
Normal file
@@ -0,0 +1,814 @@
|
||||
"""
|
||||
Vision tool - Analyze images using Vision API.
|
||||
Supports local files (auto base64-encoded) and HTTP URLs.
|
||||
|
||||
Provider resolution:
|
||||
- tools.vision.model (if set) means "prefer this model first; fall back to
|
||||
other configured providers if it fails". The model name is mapped to its
|
||||
native provider (e.g. doubao-* → Doubao, kimi-* → Moonshot, gpt-* →
|
||||
OpenAI/LinkAI). That provider is tried first, then the standard auto
|
||||
chain runs as fallback (with the preferred provider de-duplicated).
|
||||
- Auto chain priority:
|
||||
1. Main model via bot.call_vision — only when the main bot is known
|
||||
to actually support vision (not just expose a call_vision method).
|
||||
2. Other models whose API key is configured.
|
||||
3. OpenAI / LinkAI raw HTTP.
|
||||
When use_linkai=true, LinkAI is promoted to #1.
|
||||
"""
|
||||
|
||||
import base64
|
||||
import os
|
||||
import subprocess
|
||||
import tempfile
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import requests
|
||||
|
||||
from agent.tools.base_tool import BaseTool, ToolResult
|
||||
from common import const
|
||||
from common.log import logger
|
||||
from config import conf
|
||||
|
||||
DEFAULT_MODEL = const.GPT_41_MINI
|
||||
DEFAULT_TIMEOUT = 60
|
||||
MAX_TOKENS = 1000
|
||||
COMPRESS_THRESHOLD = 1_048_576 # 1 MB
|
||||
|
||||
SUPPORTED_EXTENSIONS = {
|
||||
"jpg": "image/jpeg",
|
||||
"jpeg": "image/jpeg",
|
||||
"png": "image/png",
|
||||
"gif": "image/gif",
|
||||
"webp": "image/webp",
|
||||
}
|
||||
|
||||
_MAIN_MODEL_PROVIDER_NAME = "MainModel"
|
||||
|
||||
# (config_key_for_api_key, bot_type, default_vision_model, provider_display_name)
|
||||
# Auto-discovered as fallback vision providers when their API key is configured.
|
||||
# OpenAI and LinkAI are handled separately (raw HTTP providers), so not listed here.
|
||||
_DISCOVERABLE_MODELS = [
|
||||
("moonshot_api_key", const.MOONSHOT, const.KIMI_K2_6, "Moonshot"),
|
||||
("ark_api_key", const.DOUBAO, const.DOUBAO_SEED_2_PRO, "Doubao"),
|
||||
("dashscope_api_key", const.QWEN_DASHSCOPE, const.QWEN36_PLUS, "DashScope"),
|
||||
("claude_api_key", const.CLAUDEAPI, const.CLAUDE_4_6_SONNET, "Claude"),
|
||||
("gemini_api_key", const.GEMINI, const.GEMINI_35_FLASH, "Gemini"),
|
||||
("qianfan_api_key", const.QIANFAN, const.ERNIE_45_TURBO_VL, "Qianfan"),
|
||||
("zhipu_ai_api_key", const.ZHIPU_AI, const.GLM_4_7, "ZhipuAI"),
|
||||
("minimax_api_key", const.MiniMax, const.MINIMAX_M2_7, "MiniMax"),
|
||||
("mimo_api_key", const.MIMO, const.MIMO_V2_5_PRO, "MiMo"),
|
||||
]
|
||||
|
||||
# Model name prefix → discoverable provider display_name.
|
||||
# Used to auto-route tools.vision.model to its native provider.
|
||||
# Matched case-insensitively; longest prefix wins.
|
||||
_MODEL_PREFIX_TO_PROVIDER = [
|
||||
("doubao-", "Doubao"),
|
||||
("kimi-", "Moonshot"),
|
||||
("moonshot-", "Moonshot"),
|
||||
("qwen", "DashScope"), # qwen-*, qwen3-*, qwen3.6-*, etc.
|
||||
("claude-", "Claude"),
|
||||
("ernie-", "Qianfan"),
|
||||
("gemini-", "Gemini"),
|
||||
("glm-", "ZhipuAI"),
|
||||
("minimax-", "MiniMax"),
|
||||
("abab", "MiniMax"),
|
||||
("mimo-", "MiMo"),
|
||||
]
|
||||
|
||||
# Model prefixes that natively belong to OpenAI / LinkAI (raw HTTP providers).
|
||||
_OPENAI_MODEL_PREFIXES = ("gpt-", "o1-", "o3-", "o4-", "chatgpt-")
|
||||
|
||||
# Maps the UI provider id (persisted in tools.vision.provider) to the internal
|
||||
# display name used in VisionProvider.name. Keep in sync with _DISCOVERABLE_MODELS
|
||||
# and the openai/linkai branches in _route_by_model_name.
|
||||
_PROVIDER_ID_TO_DISPLAY = {
|
||||
"openai": "OpenAI",
|
||||
"linkai": "LinkAI",
|
||||
"moonshot": "Moonshot",
|
||||
"doubao": "Doubao",
|
||||
"dashscope": "DashScope",
|
||||
"claudeAPI": "Claude",
|
||||
"gemini": "Gemini",
|
||||
"qianfan": "Qianfan",
|
||||
"zhipu": "ZhipuAI",
|
||||
"minimax": "MiniMax",
|
||||
"mimo": "MiMo",
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class VisionProvider:
|
||||
"""A single Vision API provider configuration."""
|
||||
name: str
|
||||
api_key: str
|
||||
api_base: str
|
||||
extra_headers: dict = field(default_factory=dict)
|
||||
model_override: Optional[str] = None
|
||||
use_bot: bool = False # When True, call via bot.call_vision instead of raw HTTP
|
||||
fallback_bot: Any = None # Bot instance for non-main-model providers
|
||||
|
||||
|
||||
class VisionAPIError(Exception):
|
||||
"""Raised when a Vision API call fails and should trigger fallback."""
|
||||
pass
|
||||
|
||||
|
||||
class Vision(BaseTool):
|
||||
"""Analyze images using Vision API"""
|
||||
|
||||
name: str = "vision"
|
||||
description: str = (
|
||||
"Analyze a local image or image URL (jpg/jpeg/png) using Vision API. "
|
||||
"Can describe content, extract text, identify objects, colors, etc. "
|
||||
)
|
||||
|
||||
params: dict = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"image": {
|
||||
"type": "string",
|
||||
"description": "Local file path or HTTP(S) URL of the image to analyze",
|
||||
},
|
||||
"question": {
|
||||
"type": "string",
|
||||
"description": "Question to ask about the image",
|
||||
},
|
||||
},
|
||||
"required": ["image", "question"],
|
||||
}
|
||||
|
||||
def __init__(self, config: dict = None):
|
||||
self.config = config or {}
|
||||
|
||||
@staticmethod
|
||||
def is_available() -> bool:
|
||||
return True
|
||||
|
||||
def execute(self, args: Dict[str, Any]) -> ToolResult:
|
||||
image = args.get("image", "").strip()
|
||||
question = args.get("question", "").strip()
|
||||
|
||||
if not image:
|
||||
return ToolResult.fail("Error: 'image' parameter is required")
|
||||
if not question:
|
||||
return ToolResult.fail("Error: 'question' parameter is required")
|
||||
|
||||
providers = self._resolve_providers()
|
||||
if not providers:
|
||||
return ToolResult.fail(
|
||||
"Error: No model available for Vision.\n"
|
||||
"The main model does not support vision and no other API keys are configured.\n"
|
||||
"Options:\n"
|
||||
" 1. Switch to a multimodal model (e.g. ernie-4.5-turbo-vl, qwen3.6-plus, claude-sonnet-4-6, gemini-2.0-flash)\n"
|
||||
" 2. Configure OPENAI_API_KEY: env_config(action=\"set\", key=\"OPENAI_API_KEY\", value=\"your-key\")\n"
|
||||
" 3. Configure LINKAI_API_KEY: env_config(action=\"set\", key=\"LINKAI_API_KEY\", value=\"your-key\")"
|
||||
)
|
||||
|
||||
try:
|
||||
image_content = self._build_image_content(image)
|
||||
except Exception as e:
|
||||
return ToolResult.fail(f"Error: {e}")
|
||||
|
||||
# Default model is only used as a last-resort placeholder for providers
|
||||
# whose VisionProvider.model_override is None (e.g. raw OpenAI provider
|
||||
# when the user did not configure tools.vision.model).
|
||||
return self._call_with_fallback(providers, DEFAULT_MODEL, question, image_content)
|
||||
|
||||
def _call_with_fallback(self, providers: List[VisionProvider], model: str,
|
||||
question: str, image_content: dict) -> ToolResult:
|
||||
"""Try each provider in order; fall back to the next one on failure."""
|
||||
errors: List[str] = []
|
||||
for i, provider in enumerate(providers):
|
||||
use_model = provider.model_override or model
|
||||
try:
|
||||
logger.info(f"[Vision] Trying provider '{provider.name}' "
|
||||
f"with model '{use_model}' ({i + 1}/{len(providers)})")
|
||||
if provider.use_bot:
|
||||
result = self._call_via_bot(use_model, question, image_content, provider)
|
||||
else:
|
||||
result = self._call_api(provider, use_model, question, image_content)
|
||||
logger.info(f"[Vision] ✅ Success via {provider.name} (model={use_model})")
|
||||
return result
|
||||
except VisionAPIError as e:
|
||||
errors.append(f"[{provider.name}/{use_model}] {e}")
|
||||
logger.warning(f"[Vision] Provider '{provider.name}' failed: {e}")
|
||||
except requests.Timeout:
|
||||
errors.append(f"[{provider.name}/{use_model}] Request timed out after {DEFAULT_TIMEOUT}s")
|
||||
logger.warning(f"[Vision] Provider '{provider.name}' timed out")
|
||||
except requests.ConnectionError:
|
||||
errors.append(f"[{provider.name}/{use_model}] Connection failed")
|
||||
logger.warning(f"[Vision] Provider '{provider.name}' connection failed")
|
||||
except Exception as e:
|
||||
errors.append(f"[{provider.name}/{use_model}] {e}")
|
||||
logger.error(f"[Vision] Provider '{provider.name}' unexpected error: {e}", exc_info=True)
|
||||
|
||||
return ToolResult.fail(
|
||||
"Error: All Vision API providers failed.\n" + "\n".join(f" - {err}" for err in errors)
|
||||
)
|
||||
|
||||
def _resolve_providers(self) -> List[VisionProvider]:
|
||||
"""
|
||||
Build an ordered list of providers to try.
|
||||
|
||||
Semantics of `tools.vision.model`:
|
||||
"Prefer this model first; fall back to other configured providers
|
||||
if it fails."
|
||||
|
||||
Order:
|
||||
1. The provider that natively serves `tools.vision.model` (if any
|
||||
and its API key is configured) — using the user-specified model
|
||||
name verbatim.
|
||||
2. Auto-discovery chain as fallback:
|
||||
- use_linkai=true → [LinkAI, MainModel?, OtherModels…, OpenAI]
|
||||
- default → [MainModel?, OtherModels…, OpenAI, LinkAI]
|
||||
MainModel is only included when the main bot is known to support
|
||||
vision (see _main_bot_supports_vision).
|
||||
|
||||
Providers that share the same display name as the preferred provider
|
||||
are de-duplicated to avoid retrying the same endpoint twice.
|
||||
"""
|
||||
user_model = self._resolve_user_vision_model()
|
||||
user_provider = self._resolve_user_vision_provider()
|
||||
providers: List[VisionProvider] = []
|
||||
|
||||
# Step 1: preferred provider — explicit `tools.vision.provider`
|
||||
# wins so custom model names can still be routed correctly. Falls
|
||||
# through to model-name prefix inference when provider is unset.
|
||||
preferred = None
|
||||
if user_provider and user_model:
|
||||
preferred = self._route_by_provider_id(user_provider, user_model)
|
||||
if not preferred and user_model:
|
||||
preferred = self._route_by_model_name(user_model)
|
||||
if preferred:
|
||||
providers.extend(preferred)
|
||||
|
||||
# Step 2: auto-discovery chain as fallback
|
||||
existing = {p.name for p in providers}
|
||||
fallback: List[VisionProvider] = []
|
||||
use_linkai = conf().get("use_linkai", False) and conf().get("linkai_api_key")
|
||||
|
||||
if use_linkai:
|
||||
self._append_provider(fallback, lambda: self._build_linkai_provider(user_model))
|
||||
self._append_provider(fallback, self._build_main_model_provider)
|
||||
self._append_other_model_providers(fallback, preferred_model=user_model)
|
||||
self._append_provider(fallback, lambda: self._build_openai_provider(user_model))
|
||||
else:
|
||||
self._append_provider(fallback, self._build_main_model_provider)
|
||||
self._append_other_model_providers(fallback, preferred_model=user_model)
|
||||
self._append_provider(fallback, lambda: self._build_openai_provider(user_model))
|
||||
self._append_provider(fallback, lambda: self._build_linkai_provider(user_model))
|
||||
|
||||
for p in fallback:
|
||||
if p.name in existing:
|
||||
continue
|
||||
providers.append(p)
|
||||
existing.add(p.name)
|
||||
|
||||
return providers
|
||||
|
||||
@staticmethod
|
||||
def _append_provider(providers: List[VisionProvider], builder) -> None:
|
||||
p = builder()
|
||||
if p:
|
||||
providers.append(p)
|
||||
|
||||
@staticmethod
|
||||
def _resolve_user_vision_model() -> Optional[str]:
|
||||
"""Read tools.vision.model (singular ``tool`` kept as runtime fallback)."""
|
||||
tools_conf = conf().get("tools") or conf().get("tool") or {}
|
||||
if not isinstance(tools_conf, dict):
|
||||
return None
|
||||
vision_conf = tools_conf.get("vision", {})
|
||||
if not isinstance(vision_conf, dict):
|
||||
return None
|
||||
m = vision_conf.get("model")
|
||||
if isinstance(m, str) and m.strip():
|
||||
return m.strip()
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _resolve_user_vision_provider() -> Optional[str]:
|
||||
"""Read tools.vision.provider — the UI-persisted vendor id.
|
||||
|
||||
Lets users pin a vendor for custom model names that prefix-inference
|
||||
can't recognize. Returns None when unset/blank.
|
||||
"""
|
||||
tools_conf = conf().get("tools") or conf().get("tool") or {}
|
||||
if not isinstance(tools_conf, dict):
|
||||
return None
|
||||
vision_conf = tools_conf.get("vision", {})
|
||||
if not isinstance(vision_conf, dict):
|
||||
return None
|
||||
p = vision_conf.get("provider")
|
||||
if isinstance(p, str) and p.strip():
|
||||
return p.strip()
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _infer_provider_from_model(model_name: str) -> Optional[str]:
|
||||
"""
|
||||
Infer the provider display name from a model name's prefix.
|
||||
Returns None when no rule matches (or for OpenAI-family names, which
|
||||
are handled separately by the caller).
|
||||
"""
|
||||
if not model_name:
|
||||
return None
|
||||
lower = model_name.lower()
|
||||
# Sort by prefix length desc so e.g. "moonshot-" wins over hypothetical "moo-"
|
||||
for prefix, display_name in sorted(_MODEL_PREFIX_TO_PROVIDER, key=lambda x: -len(x[0])):
|
||||
if lower.startswith(prefix.lower()):
|
||||
return display_name
|
||||
return None
|
||||
|
||||
def _route_by_provider_id(self, provider_id: str, user_model: str) -> Optional[List[VisionProvider]]:
|
||||
"""Route by the UI-persisted provider id.
|
||||
|
||||
Returns:
|
||||
- [provider] : provider id is known and its key is configured.
|
||||
- None : unknown provider id, or the bot can't be created.
|
||||
Caller falls through to model-name-based routing.
|
||||
"""
|
||||
display_name = _PROVIDER_ID_TO_DISPLAY.get(provider_id)
|
||||
if not display_name:
|
||||
return None
|
||||
|
||||
# OpenAI / LinkAI use raw HTTP providers, not the discoverable bot path.
|
||||
if provider_id == "openai":
|
||||
p = self._build_openai_provider(user_model)
|
||||
return [p] if p else None
|
||||
if provider_id == "linkai":
|
||||
p = self._build_linkai_provider(user_model)
|
||||
return [p] if p else None
|
||||
|
||||
# Discoverable bot-backed providers.
|
||||
for config_key, bot_type, _default_model, name in _DISCOVERABLE_MODELS:
|
||||
if name != display_name:
|
||||
continue
|
||||
api_key = conf().get(config_key, "")
|
||||
if not api_key or not api_key.strip():
|
||||
logger.warning(f"[Vision] tools.vision.provider='{provider_id}' "
|
||||
f"but '{config_key}' is not configured. Falling back.")
|
||||
return None
|
||||
try:
|
||||
from models.bot_factory import create_bot
|
||||
bot = create_bot(bot_type)
|
||||
if not hasattr(bot, 'call_vision'):
|
||||
logger.warning(f"[Vision] '{display_name}' bot does not implement call_vision.")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.warning(f"[Vision] Failed to create '{display_name}' bot: {e}")
|
||||
return None
|
||||
return [VisionProvider(
|
||||
name=display_name,
|
||||
api_key="",
|
||||
api_base="",
|
||||
model_override=user_model,
|
||||
use_bot=True,
|
||||
fallback_bot=bot,
|
||||
)]
|
||||
return None
|
||||
|
||||
def _route_by_model_name(self, user_model: str) -> Optional[List[VisionProvider]]:
|
||||
"""
|
||||
Try to build a provider list using the user-specified model name.
|
||||
Returns:
|
||||
- [provider] : matched and the provider's key is configured
|
||||
- [] : matched but key missing → tell caller to surface this
|
||||
as a hard error rather than silently falling back
|
||||
- None : no rule matches → caller should fall through to auto
|
||||
"""
|
||||
lower = user_model.lower()
|
||||
|
||||
# OpenAI / LinkAI family
|
||||
if lower.startswith(_OPENAI_MODEL_PREFIXES):
|
||||
providers: List[VisionProvider] = []
|
||||
# Prefer LinkAI when explicitly enabled, else OpenAI first
|
||||
use_linkai = conf().get("use_linkai", False) and conf().get("linkai_api_key")
|
||||
if use_linkai:
|
||||
self._append_provider(providers, lambda: self._build_linkai_provider(user_model))
|
||||
self._append_provider(providers, lambda: self._build_openai_provider(user_model))
|
||||
else:
|
||||
self._append_provider(providers, lambda: self._build_openai_provider(user_model))
|
||||
self._append_provider(providers, lambda: self._build_linkai_provider(user_model))
|
||||
if providers:
|
||||
return providers
|
||||
logger.warning(f"[Vision] tools.vision.model='{user_model}' looks like an OpenAI "
|
||||
f"model but neither OPENAI_API_KEY nor LINKAI_API_KEY is configured.")
|
||||
return None # fall through to auto
|
||||
|
||||
# Discoverable native providers (Doubao, Moonshot, etc.)
|
||||
target_display = self._infer_provider_from_model(user_model)
|
||||
if not target_display:
|
||||
return None # unknown prefix → auto
|
||||
|
||||
for config_key, bot_type, _default_model, display_name in _DISCOVERABLE_MODELS:
|
||||
if display_name != target_display:
|
||||
continue
|
||||
api_key = conf().get(config_key, "")
|
||||
if not api_key or not api_key.strip():
|
||||
logger.warning(f"[Vision] tools.vision.model='{user_model}' routes to "
|
||||
f"'{display_name}' but '{config_key}' is not configured. "
|
||||
f"Falling back to auto-discovery.")
|
||||
return None # fall through to auto
|
||||
try:
|
||||
from models.bot_factory import create_bot
|
||||
bot = create_bot(bot_type)
|
||||
if not hasattr(bot, 'call_vision'):
|
||||
logger.warning(f"[Vision] '{display_name}' bot does not implement call_vision.")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.warning(f"[Vision] Failed to create '{display_name}' bot: {e}")
|
||||
return None
|
||||
|
||||
return [VisionProvider(
|
||||
name=display_name,
|
||||
api_key="",
|
||||
api_base="",
|
||||
model_override=user_model,
|
||||
use_bot=True,
|
||||
fallback_bot=bot,
|
||||
)]
|
||||
|
||||
return None
|
||||
|
||||
def _append_other_model_providers(self, providers: List[VisionProvider],
|
||||
preferred_model: Optional[str] = None) -> None:
|
||||
"""
|
||||
Auto-discover other models whose API key is configured.
|
||||
Skip the main model's own bot_type (already covered by MainModel
|
||||
provider), unless the main model itself does not support vision —
|
||||
in that case we still want the vendor's dedicated vision model
|
||||
as a fallback. Also skip bot_types that already appear in the
|
||||
provider list.
|
||||
|
||||
If preferred_model matches a provider's family, use it instead
|
||||
of that provider's hard-coded default model.
|
||||
"""
|
||||
main_bot_type = None
|
||||
main_bot_supports_vision = False
|
||||
if self.model and hasattr(self.model, '_resolve_bot_type'):
|
||||
main_bot_type = self.model._resolve_bot_type(conf().get("model", ""))
|
||||
main_bot = getattr(self.model, "bot", None)
|
||||
main_bot_supports_vision = self._main_bot_supports_vision(main_bot)
|
||||
|
||||
existing_names = {p.name for p in providers}
|
||||
preferred_provider = self._infer_provider_from_model(preferred_model) if preferred_model else None
|
||||
|
||||
for config_key, bot_type, default_model, display_name in _DISCOVERABLE_MODELS:
|
||||
if display_name in existing_names:
|
||||
continue
|
||||
# Same bot_type as the main model is normally handled by the
|
||||
# MainModel provider; only skip it here if the main model
|
||||
# actually supports vision. Otherwise fall through and add
|
||||
# the vendor's dedicated vision model as a fallback.
|
||||
if bot_type == main_bot_type and main_bot_supports_vision:
|
||||
continue
|
||||
api_key = conf().get(config_key, "")
|
||||
if not api_key or not api_key.strip():
|
||||
continue
|
||||
|
||||
try:
|
||||
from models.bot_factory import create_bot
|
||||
bot = create_bot(bot_type)
|
||||
if not hasattr(bot, 'call_vision'):
|
||||
continue
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
model_for_provider = (preferred_model
|
||||
if preferred_provider == display_name and preferred_model
|
||||
else default_model)
|
||||
|
||||
provider = VisionProvider(
|
||||
name=display_name,
|
||||
api_key="",
|
||||
api_base="",
|
||||
model_override=model_for_provider,
|
||||
use_bot=True,
|
||||
fallback_bot=bot,
|
||||
)
|
||||
|
||||
# Same vendor as the main bot is the most natural fallback when
|
||||
# the main model itself does not support vision — promote it to
|
||||
# the front of the list instead of relying on declaration order.
|
||||
if bot_type == main_bot_type:
|
||||
providers.insert(0, provider)
|
||||
else:
|
||||
providers.append(provider)
|
||||
|
||||
def _main_bot_supports_vision(self, bot) -> bool:
|
||||
"""
|
||||
Whether the main bot is known to natively support vision.
|
||||
|
||||
Having a `call_vision` method is necessary but not sufficient —
|
||||
some bots implement the method against an endpoint that does not
|
||||
actually serve vision models, which causes silent failures when a
|
||||
vendor-foreign model name is forwarded.
|
||||
|
||||
Resolution order:
|
||||
1. If the bot explicitly declares `supports_vision`, trust it.
|
||||
This lets bots opt in or out based on their own runtime
|
||||
configuration (e.g. the currently selected model).
|
||||
2. Otherwise, fall back to a model-name prefix heuristic: trust
|
||||
call_vision when the main model looks like an OpenAI family
|
||||
model or matches a known multimodal vendor prefix.
|
||||
"""
|
||||
if bot is None:
|
||||
return False
|
||||
if hasattr(bot, "supports_vision"):
|
||||
return bool(getattr(bot, "supports_vision"))
|
||||
main_model = (conf().get("model") or "").lower()
|
||||
if not main_model:
|
||||
return False
|
||||
if main_model.startswith(_OPENAI_MODEL_PREFIXES):
|
||||
return True
|
||||
return self._infer_provider_from_model(main_model) is not None
|
||||
|
||||
def _build_main_model_provider(self) -> Optional[VisionProvider]:
|
||||
"""
|
||||
Use the vendor's own model for vision via bot.call_vision.
|
||||
Gated by _main_bot_supports_vision so non-vision bots (DeepSeek, etc.)
|
||||
do not get routed vendor-foreign model names.
|
||||
"""
|
||||
if not (self.model and hasattr(self.model, 'bot')):
|
||||
return None
|
||||
try:
|
||||
bot = self.model.bot
|
||||
except Exception:
|
||||
return None
|
||||
if not hasattr(bot, 'call_vision'):
|
||||
return None
|
||||
if not self._main_bot_supports_vision(bot):
|
||||
return None
|
||||
|
||||
# Use the configured main model name; do NOT inject tools.vision.model
|
||||
# here, because by the time we reach this branch the tools.vision.model
|
||||
# routing has already been attempted (and either matched the main bot
|
||||
# or failed to find a provider).
|
||||
main_model_name = conf().get("model") or None
|
||||
|
||||
return VisionProvider(
|
||||
name=_MAIN_MODEL_PROVIDER_NAME,
|
||||
api_key="",
|
||||
api_base="",
|
||||
model_override=main_model_name,
|
||||
use_bot=True,
|
||||
)
|
||||
|
||||
def _build_openai_provider(self, preferred_model: Optional[str] = None) -> Optional[VisionProvider]:
|
||||
api_key = conf().get("open_ai_api_key") or os.environ.get("OPENAI_API_KEY")
|
||||
if not api_key:
|
||||
return None
|
||||
api_base = (conf().get("open_ai_api_base") or os.environ.get("OPENAI_API_BASE", "")).rstrip("/") \
|
||||
or "https://api.openai.com/v1"
|
||||
# Only honor preferred_model when it looks like an OpenAI-family name;
|
||||
# otherwise the OpenAI endpoint would 400 on a vendor-specific name.
|
||||
model_override = preferred_model if (
|
||||
preferred_model and preferred_model.lower().startswith(_OPENAI_MODEL_PREFIXES)
|
||||
) else None
|
||||
return VisionProvider(
|
||||
name="OpenAI",
|
||||
api_key=api_key,
|
||||
api_base=self._ensure_v1(api_base),
|
||||
model_override=model_override,
|
||||
)
|
||||
|
||||
def _build_linkai_provider(self, preferred_model: Optional[str] = None) -> Optional[VisionProvider]:
|
||||
api_key = conf().get("linkai_api_key") or os.environ.get("LINKAI_API_KEY")
|
||||
if not api_key:
|
||||
return None
|
||||
api_base = (conf().get("linkai_api_base") or os.environ.get("LINKAI_API_BASE", "")).rstrip("/") \
|
||||
or "https://api.link-ai.tech"
|
||||
from common.utils import get_cloud_headers
|
||||
extra = get_cloud_headers(api_key)
|
||||
extra.pop("Authorization", None)
|
||||
extra.pop("Content-Type", None)
|
||||
# LinkAI is a multi-vendor proxy and accepts most model names, so we
|
||||
# honor any user-configured model name here.
|
||||
return VisionProvider(
|
||||
name="LinkAI",
|
||||
api_key=api_key,
|
||||
api_base=self._ensure_v1(api_base),
|
||||
extra_headers=extra,
|
||||
model_override=preferred_model,
|
||||
)
|
||||
|
||||
def _call_via_bot(self, model: str, question: str, image_content: dict,
|
||||
provider: Optional[VisionProvider] = None) -> ToolResult:
|
||||
"""
|
||||
Call a model's call_vision with vendor-native API format.
|
||||
Uses the provider's _fallback_bot if set, otherwise the main model bot.
|
||||
Raises VisionAPIError on failure so fallback can proceed.
|
||||
"""
|
||||
try:
|
||||
bot = (provider and provider.fallback_bot) or self.model.bot
|
||||
except Exception as e:
|
||||
raise VisionAPIError(f"Cannot access bot: {e}")
|
||||
|
||||
# Extract the raw image URL from the OpenAI-format image_content block
|
||||
image_url = image_content.get("image_url", {}).get("url", "")
|
||||
if not image_url:
|
||||
raise VisionAPIError("No image URL in content block")
|
||||
|
||||
try:
|
||||
response = bot.call_vision(
|
||||
image_url=image_url,
|
||||
question=question,
|
||||
model=model,
|
||||
max_tokens=MAX_TOKENS,
|
||||
)
|
||||
except Exception as e:
|
||||
raise VisionAPIError(f"call_vision failed: {e}")
|
||||
|
||||
if response is NotImplemented:
|
||||
raise VisionAPIError("Bot does not support vision")
|
||||
|
||||
if isinstance(response, dict) and response.get("error"):
|
||||
raise VisionAPIError(f"API error - {response.get('message', 'Unknown')}")
|
||||
|
||||
content = response.get("content", "") if isinstance(response, dict) else ""
|
||||
if not content:
|
||||
raise VisionAPIError("Empty response from main model")
|
||||
|
||||
usage_info = response.get("usage", {}) if isinstance(response, dict) else {}
|
||||
|
||||
# Use the actual model name from the bot response if available
|
||||
actual_model = response.get("model", model) if isinstance(response, dict) else model
|
||||
provider_name = provider.name if provider else _MAIN_MODEL_PROVIDER_NAME
|
||||
return ToolResult.success({
|
||||
"model": actual_model,
|
||||
"provider": provider_name,
|
||||
"content": content,
|
||||
"usage": usage_info,
|
||||
})
|
||||
|
||||
@staticmethod
|
||||
def _ensure_v1(api_base: str) -> str:
|
||||
"""Append /v1 if the base URL doesn't already end with a versioned path."""
|
||||
if not api_base:
|
||||
return api_base
|
||||
# Already has /v1 or similar version suffix
|
||||
if api_base.rstrip("/").split("/")[-1].startswith("v"):
|
||||
return api_base
|
||||
return api_base.rstrip("/") + "/v1"
|
||||
|
||||
def _build_image_content(self, image: str) -> dict:
|
||||
"""
|
||||
Build the image_url content block.
|
||||
Both remote URLs and local files are converted to base64 data URLs
|
||||
so every bot backend can consume them without extra downloads.
|
||||
"""
|
||||
if image.startswith(("http://", "https://")):
|
||||
return self._download_to_data_url(image)
|
||||
|
||||
if not os.path.isfile(image):
|
||||
raise FileNotFoundError(f"Image file not found: {image}")
|
||||
|
||||
ext = image.rsplit(".", 1)[-1].lower() if "." in image else ""
|
||||
mime_type = SUPPORTED_EXTENSIONS.get(ext)
|
||||
if not mime_type:
|
||||
raise ValueError(
|
||||
f"Unsupported image format '.{ext}'. "
|
||||
f"Supported: {', '.join(SUPPORTED_EXTENSIONS.keys())}"
|
||||
)
|
||||
|
||||
file_path = self._maybe_compress(image)
|
||||
try:
|
||||
with open(file_path, "rb") as f:
|
||||
b64 = base64.b64encode(f.read()).decode("ascii")
|
||||
finally:
|
||||
if file_path != image and os.path.exists(file_path):
|
||||
os.remove(file_path)
|
||||
|
||||
data_url = f"data:{mime_type};base64,{b64}"
|
||||
return {"type": "image_url", "image_url": {"url": data_url}}
|
||||
|
||||
@staticmethod
|
||||
def _download_to_data_url(url: str) -> dict:
|
||||
"""Download a remote image and return it as a base64 data URL."""
|
||||
resp = requests.get(url, timeout=30)
|
||||
if resp.status_code != 200:
|
||||
raise VisionAPIError(f"Failed to download image: HTTP {resp.status_code}")
|
||||
content_type = resp.headers.get("Content-Type", "image/jpeg").split(";")[0].strip()
|
||||
if not content_type.startswith("image/"):
|
||||
content_type = "image/jpeg"
|
||||
b64 = base64.b64encode(resp.content).decode("ascii")
|
||||
data_url = f"data:{content_type};base64,{b64}"
|
||||
return {"type": "image_url", "image_url": {"url": data_url}}
|
||||
|
||||
@staticmethod
|
||||
def _maybe_compress(path: str) -> str:
|
||||
"""Compress image to under COMPRESS_THRESHOLD with max long-edge 1536px."""
|
||||
file_size = os.path.getsize(path)
|
||||
if file_size <= COMPRESS_THRESHOLD:
|
||||
return path
|
||||
|
||||
tmp = tempfile.NamedTemporaryFile(suffix=".jpg", delete=False)
|
||||
tmp.close()
|
||||
|
||||
def _try_sips(max_dim: str, quality: str) -> bool:
|
||||
try:
|
||||
subprocess.run(
|
||||
["sips", "-Z", max_dim, "-s", "formatOptions", quality,
|
||||
path, "--out", tmp.name],
|
||||
capture_output=True, check=True,
|
||||
)
|
||||
return True
|
||||
except (FileNotFoundError, subprocess.CalledProcessError):
|
||||
return False
|
||||
|
||||
def _try_convert(max_dim: str, quality: str) -> bool:
|
||||
try:
|
||||
subprocess.run(
|
||||
["convert", path, "-resize", f"{max_dim}x{max_dim}>",
|
||||
"-quality", quality, tmp.name],
|
||||
capture_output=True, check=True,
|
||||
)
|
||||
return True
|
||||
except (FileNotFoundError, subprocess.CalledProcessError):
|
||||
return False
|
||||
|
||||
attempts = [
|
||||
("1536", "85"),
|
||||
("1536", "70"),
|
||||
("1536", "50"),
|
||||
]
|
||||
|
||||
for max_dim, quality in attempts:
|
||||
ok = _try_sips(max_dim, quality) or _try_convert(max_dim, quality)
|
||||
if not ok:
|
||||
continue
|
||||
new_size = os.path.getsize(tmp.name)
|
||||
logger.debug(f"[Vision] Compressed image "
|
||||
f"({file_size // 1024}KB -> {new_size // 1024}KB, "
|
||||
f"max_dim={max_dim}, q={quality})")
|
||||
if new_size <= COMPRESS_THRESHOLD:
|
||||
return tmp.name
|
||||
|
||||
if os.path.exists(tmp.name) and os.path.getsize(tmp.name) > 0:
|
||||
return tmp.name
|
||||
|
||||
os.remove(tmp.name)
|
||||
return path
|
||||
|
||||
def _call_api(self, provider: VisionProvider, model: str,
|
||||
question: str, image_content: dict) -> ToolResult:
|
||||
"""
|
||||
Call a single provider's Vision API.
|
||||
Raises VisionAPIError on recoverable failures so the caller can try
|
||||
the next provider.
|
||||
"""
|
||||
payload = {
|
||||
"model": model,
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": question},
|
||||
image_content,
|
||||
],
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Bearer {provider.api_key}",
|
||||
"Content-Type": "application/json",
|
||||
**provider.extra_headers,
|
||||
}
|
||||
|
||||
resp = requests.post(
|
||||
f"{provider.api_base}/chat/completions",
|
||||
headers=headers,
|
||||
json=payload,
|
||||
timeout=DEFAULT_TIMEOUT,
|
||||
)
|
||||
|
||||
if resp.status_code != 200:
|
||||
raise VisionAPIError(f"HTTP {resp.status_code}: {resp.text[:200]}")
|
||||
|
||||
data = resp.json()
|
||||
|
||||
if "error" in data:
|
||||
msg = data["error"].get("message", "Unknown API error")
|
||||
raise VisionAPIError(f"API error - {msg}")
|
||||
|
||||
content = ""
|
||||
choices = data.get("choices", [])
|
||||
if choices:
|
||||
content = choices[0].get("message", {}).get("content", "")
|
||||
|
||||
usage = data.get("usage", {})
|
||||
result = {
|
||||
"model": model,
|
||||
"provider": provider.name,
|
||||
"content": content,
|
||||
"usage": {
|
||||
"prompt_tokens": usage.get("prompt_tokens", 0),
|
||||
"completion_tokens": usage.get("completion_tokens", 0),
|
||||
"total_tokens": usage.get("total_tokens", 0),
|
||||
},
|
||||
}
|
||||
return ToolResult.success(result)
|
||||
0
agent/tools/web_fetch/__init__.py
Normal file
0
agent/tools/web_fetch/__init__.py
Normal file
444
agent/tools/web_fetch/web_fetch.py
Normal file
444
agent/tools/web_fetch/web_fetch.py
Normal file
@@ -0,0 +1,444 @@
|
||||
"""
|
||||
Web Fetch tool - Fetch and extract readable content from web pages and remote files.
|
||||
|
||||
Supports:
|
||||
- HTML web pages: extracts readable text content
|
||||
- Document files (PDF, Word, TXT, Markdown, etc.): downloads to workspace/tmp and parses content
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
import uuid
|
||||
from typing import Dict, Any, Optional, Set
|
||||
from urllib.parse import urlparse, unquote
|
||||
|
||||
import requests
|
||||
|
||||
from agent.tools.base_tool import BaseTool, ToolResult
|
||||
from agent.tools.utils.truncate import truncate_head, format_size
|
||||
from common.log import logger
|
||||
|
||||
|
||||
DEFAULT_TIMEOUT = 30
|
||||
MAX_FILE_SIZE = 50 * 1024 * 1024 # 50MB
|
||||
|
||||
DEFAULT_HEADERS = {
|
||||
"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36",
|
||||
"Accept": "*/*",
|
||||
}
|
||||
|
||||
# Supported document file extensions
|
||||
PDF_SUFFIXES: Set[str] = {".pdf"}
|
||||
WORD_SUFFIXES: Set[str] = {".docx"}
|
||||
TEXT_SUFFIXES: Set[str] = {".txt", ".md", ".markdown", ".rst", ".csv", ".tsv", ".log"}
|
||||
SPREADSHEET_SUFFIXES: Set[str] = {".xls", ".xlsx"}
|
||||
PPT_SUFFIXES: Set[str] = {".ppt", ".pptx"}
|
||||
|
||||
ALL_DOC_SUFFIXES = PDF_SUFFIXES | WORD_SUFFIXES | TEXT_SUFFIXES | SPREADSHEET_SUFFIXES | PPT_SUFFIXES
|
||||
|
||||
_CHARSET_RE = re.compile(r'charset\s*=\s*["\']?\s*([\w\-]+)', re.IGNORECASE)
|
||||
_META_CHARSET_RE = re.compile(rb'<meta[^>]+charset\s*=\s*["\']?\s*([\w\-]+)', re.IGNORECASE)
|
||||
_META_HTTP_EQUIV_RE = re.compile(
|
||||
rb'<meta[^>]+http-equiv\s*=\s*["\']?Content-Type["\']?[^>]+content\s*=\s*["\'][^"\']*charset=([\w\-]+)',
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
|
||||
def _extract_charset_from_content_type(content_type: str) -> Optional[str]:
|
||||
"""Extract charset from Content-Type header value."""
|
||||
m = _CHARSET_RE.search(content_type)
|
||||
return m.group(1) if m else None
|
||||
|
||||
|
||||
def _extract_charset_from_html_meta(raw_bytes: bytes) -> Optional[str]:
|
||||
"""Extract charset from HTML <meta> tags in the first few KB of raw bytes."""
|
||||
m = _META_CHARSET_RE.search(raw_bytes)
|
||||
if m:
|
||||
return m.group(1).decode("ascii", errors="ignore")
|
||||
m = _META_HTTP_EQUIV_RE.search(raw_bytes)
|
||||
if m:
|
||||
return m.group(1).decode("ascii", errors="ignore")
|
||||
return None
|
||||
|
||||
|
||||
def _get_url_suffix(url: str) -> str:
|
||||
"""Extract file extension from URL path, ignoring query params."""
|
||||
path = urlparse(url).path
|
||||
return os.path.splitext(path)[-1].lower()
|
||||
|
||||
|
||||
def _is_document_url(url: str) -> bool:
|
||||
"""Check if URL points to a downloadable document file."""
|
||||
suffix = _get_url_suffix(url)
|
||||
return suffix in ALL_DOC_SUFFIXES
|
||||
|
||||
|
||||
class WebFetch(BaseTool):
|
||||
"""Tool for fetching web pages and remote document files"""
|
||||
|
||||
name: str = "web_fetch"
|
||||
description: str = (
|
||||
"Fetch content from a http/https URL. For web pages, extracts readable text. "
|
||||
"For document files (PDF, Word, TXT, Markdown, Excel, PPT), downloads and parses the file content. "
|
||||
"Supported file types: .pdf, .docx, .txt, .md, .csv, .xls, .xlsx, .ppt, .pptx"
|
||||
)
|
||||
|
||||
params: dict = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"url": {
|
||||
"type": "string",
|
||||
"description": "The HTTP/HTTPS URL to fetch (web page or document file link)"
|
||||
}
|
||||
},
|
||||
"required": ["url"]
|
||||
}
|
||||
|
||||
def __init__(self, config: dict = None):
|
||||
self.config = config or {}
|
||||
self.cwd = self.config.get("cwd", os.getcwd())
|
||||
|
||||
def execute(self, args: Dict[str, Any]) -> ToolResult:
|
||||
url = args.get("url", "").strip()
|
||||
if not url:
|
||||
return ToolResult.fail("Error: 'url' parameter is required")
|
||||
|
||||
parsed = urlparse(url)
|
||||
if parsed.scheme not in ("http", "https"):
|
||||
return ToolResult.fail("Error: Invalid URL (must start with http:// or https://)")
|
||||
|
||||
if _is_document_url(url):
|
||||
return self._fetch_document(url)
|
||||
|
||||
return self._fetch_webpage(url)
|
||||
|
||||
# ---- Web page fetching ----
|
||||
|
||||
def _fetch_webpage(self, url: str) -> ToolResult:
|
||||
"""Fetch and extract readable text from an HTML web page."""
|
||||
parsed = urlparse(url)
|
||||
try:
|
||||
response = requests.get(
|
||||
url,
|
||||
headers=DEFAULT_HEADERS,
|
||||
timeout=DEFAULT_TIMEOUT,
|
||||
allow_redirects=True,
|
||||
)
|
||||
response.raise_for_status()
|
||||
except requests.Timeout:
|
||||
return ToolResult.fail(f"Error: Request timed out after {DEFAULT_TIMEOUT}s")
|
||||
except requests.ConnectionError:
|
||||
return ToolResult.fail(f"Error: Failed to connect to {parsed.netloc}")
|
||||
except requests.HTTPError as e:
|
||||
return ToolResult.fail(f"Error: HTTP {e.response.status_code} for URL: {url}")
|
||||
except Exception as e:
|
||||
return ToolResult.fail(f"Error: Failed to fetch URL: {e}")
|
||||
|
||||
content_type = response.headers.get("Content-Type", "")
|
||||
if self._is_binary_content_type(content_type) and not _is_document_url(url):
|
||||
return self._handle_download_by_content_type(url, response, content_type)
|
||||
|
||||
response.encoding = self._detect_encoding(response)
|
||||
html = response.text
|
||||
title = self._extract_title(html)
|
||||
text = self._extract_text(html)
|
||||
|
||||
return ToolResult.success(f"Title: {title}\n\nContent:\n{text}")
|
||||
|
||||
# ---- Document fetching ----
|
||||
|
||||
def _fetch_document(self, url: str) -> ToolResult:
|
||||
"""Download a document file and extract its text content."""
|
||||
suffix = _get_url_suffix(url)
|
||||
parsed = urlparse(url)
|
||||
filename = self._extract_filename(url)
|
||||
tmp_dir = self._ensure_tmp_dir()
|
||||
|
||||
local_path = os.path.join(tmp_dir, filename)
|
||||
logger.info(f"[WebFetch] Downloading document: {url} -> {local_path}")
|
||||
|
||||
try:
|
||||
response = requests.get(
|
||||
url,
|
||||
headers=DEFAULT_HEADERS,
|
||||
timeout=DEFAULT_TIMEOUT,
|
||||
stream=True,
|
||||
allow_redirects=True,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
content_length = int(response.headers.get("Content-Length", 0))
|
||||
if content_length > MAX_FILE_SIZE:
|
||||
return ToolResult.fail(
|
||||
f"Error: File too large ({format_size(content_length)} > {format_size(MAX_FILE_SIZE)})"
|
||||
)
|
||||
|
||||
downloaded = 0
|
||||
with open(local_path, "wb") as f:
|
||||
for chunk in response.iter_content(chunk_size=8192):
|
||||
downloaded += len(chunk)
|
||||
if downloaded > MAX_FILE_SIZE:
|
||||
f.close()
|
||||
os.remove(local_path)
|
||||
return ToolResult.fail(
|
||||
f"Error: File too large (>{format_size(MAX_FILE_SIZE)}), download aborted"
|
||||
)
|
||||
f.write(chunk)
|
||||
|
||||
except requests.Timeout:
|
||||
return ToolResult.fail(f"Error: Download timed out after {DEFAULT_TIMEOUT}s")
|
||||
except requests.ConnectionError:
|
||||
return ToolResult.fail(f"Error: Failed to connect to {parsed.netloc}")
|
||||
except requests.HTTPError as e:
|
||||
return ToolResult.fail(f"Error: HTTP {e.response.status_code} for URL: {url}")
|
||||
except Exception as e:
|
||||
self._cleanup_file(local_path)
|
||||
return ToolResult.fail(f"Error: Failed to download file: {e}")
|
||||
|
||||
try:
|
||||
text = self._parse_document(local_path, suffix)
|
||||
except Exception as e:
|
||||
self._cleanup_file(local_path)
|
||||
return ToolResult.fail(f"Error: Failed to parse document: {e}")
|
||||
|
||||
if not text or not text.strip():
|
||||
file_size = os.path.getsize(local_path)
|
||||
return ToolResult.success(
|
||||
f"File downloaded to: {local_path} ({format_size(file_size)})\n"
|
||||
f"No text content could be extracted. The file may contain only images or be encrypted."
|
||||
)
|
||||
|
||||
truncation = truncate_head(text)
|
||||
result_text = truncation.content
|
||||
|
||||
file_size = os.path.getsize(local_path)
|
||||
header = f"[Document: {filename} | Size: {format_size(file_size)} | Saved to: {local_path}]\n\n"
|
||||
|
||||
if truncation.truncated:
|
||||
header += f"[Content truncated: showing {truncation.output_lines} of {truncation.total_lines} lines]\n\n"
|
||||
|
||||
return ToolResult.success(header + result_text)
|
||||
|
||||
def _parse_document(self, file_path: str, suffix: str) -> str:
|
||||
"""Parse document file and return extracted text."""
|
||||
if suffix in PDF_SUFFIXES:
|
||||
return self._parse_pdf(file_path)
|
||||
elif suffix in WORD_SUFFIXES:
|
||||
return self._parse_word(file_path)
|
||||
elif suffix in TEXT_SUFFIXES:
|
||||
return self._parse_text(file_path)
|
||||
elif suffix in SPREADSHEET_SUFFIXES:
|
||||
return self._parse_spreadsheet(file_path)
|
||||
elif suffix in PPT_SUFFIXES:
|
||||
return self._parse_ppt(file_path)
|
||||
else:
|
||||
return self._parse_text(file_path)
|
||||
|
||||
def _parse_pdf(self, file_path: str) -> str:
|
||||
"""Extract text from PDF using pypdf."""
|
||||
try:
|
||||
from pypdf import PdfReader
|
||||
except ImportError:
|
||||
raise ImportError("pypdf library is required for PDF parsing. Install with: pip install pypdf")
|
||||
|
||||
reader = PdfReader(file_path)
|
||||
text_parts = []
|
||||
for page_num, page in enumerate(reader.pages, 1):
|
||||
page_text = page.extract_text()
|
||||
if page_text and page_text.strip():
|
||||
text_parts.append(f"--- Page {page_num}/{len(reader.pages)} ---\n{page_text}")
|
||||
|
||||
return "\n\n".join(text_parts)
|
||||
|
||||
def _parse_word(self, file_path: str) -> str:
|
||||
"""Extract text from Word documents (.docx)."""
|
||||
try:
|
||||
from docx import Document
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"python-docx library is required for .docx parsing. Install with: pip install python-docx"
|
||||
)
|
||||
doc = Document(file_path)
|
||||
paragraphs = [p.text for p in doc.paragraphs if p.text.strip()]
|
||||
return "\n\n".join(paragraphs)
|
||||
|
||||
def _parse_text(self, file_path: str) -> str:
|
||||
"""Read plain text files (txt, md, csv, etc.)."""
|
||||
encodings = ["utf-8", "utf-8-sig", "gbk", "gb2312", "latin-1"]
|
||||
for enc in encodings:
|
||||
try:
|
||||
with open(file_path, "r", encoding=enc) as f:
|
||||
return f.read()
|
||||
except (UnicodeDecodeError, UnicodeError):
|
||||
continue
|
||||
raise ValueError(f"Unable to decode file with any supported encoding: {encodings}")
|
||||
|
||||
def _parse_spreadsheet(self, file_path: str) -> str:
|
||||
"""Extract text from Excel files (.xls/.xlsx)."""
|
||||
try:
|
||||
import openpyxl
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"openpyxl library is required for .xlsx parsing. Install with: pip install openpyxl"
|
||||
)
|
||||
|
||||
wb = openpyxl.load_workbook(file_path, read_only=True, data_only=True)
|
||||
result_parts = []
|
||||
|
||||
for sheet_name in wb.sheetnames:
|
||||
ws = wb[sheet_name]
|
||||
rows = []
|
||||
for row in ws.iter_rows(values_only=True):
|
||||
cells = [str(c) if c is not None else "" for c in row]
|
||||
if any(cells):
|
||||
rows.append(" | ".join(cells))
|
||||
if rows:
|
||||
result_parts.append(f"--- Sheet: {sheet_name} ---\n" + "\n".join(rows))
|
||||
|
||||
wb.close()
|
||||
return "\n\n".join(result_parts)
|
||||
|
||||
def _parse_ppt(self, file_path: str) -> str:
|
||||
"""Extract text from PowerPoint files (.ppt/.pptx)."""
|
||||
try:
|
||||
from pptx import Presentation
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"python-pptx library is required for .pptx parsing. Install with: pip install python-pptx"
|
||||
)
|
||||
|
||||
prs = Presentation(file_path)
|
||||
text_parts = []
|
||||
|
||||
for slide_num, slide in enumerate(prs.slides, 1):
|
||||
slide_texts = []
|
||||
for shape in slide.shapes:
|
||||
if shape.has_text_frame:
|
||||
for paragraph in shape.text_frame.paragraphs:
|
||||
text = paragraph.text.strip()
|
||||
if text:
|
||||
slide_texts.append(text)
|
||||
if slide_texts:
|
||||
text_parts.append(f"--- Slide {slide_num}/{len(prs.slides)} ---\n" + "\n".join(slide_texts))
|
||||
|
||||
return "\n\n".join(text_parts)
|
||||
|
||||
# ---- Encoding detection ----
|
||||
|
||||
@staticmethod
|
||||
def _detect_encoding(response: requests.Response) -> str:
|
||||
"""Detect response encoding with priority: Content-Type header > HTML meta > chardet > utf-8."""
|
||||
# 1. Check Content-Type header for explicit charset
|
||||
content_type = response.headers.get("Content-Type", "")
|
||||
charset = _extract_charset_from_content_type(content_type)
|
||||
if charset:
|
||||
return charset
|
||||
|
||||
# 2. Scan raw bytes for HTML meta charset declaration
|
||||
raw = response.content[:4096]
|
||||
charset = _extract_charset_from_html_meta(raw)
|
||||
if charset:
|
||||
return charset
|
||||
|
||||
# 3. Use apparent_encoding (chardet-based detection) if confident enough
|
||||
apparent = response.apparent_encoding
|
||||
if apparent:
|
||||
apparent_lower = apparent.lower()
|
||||
# Trust CJK / Windows encodings detected by chardet
|
||||
trusted_prefixes = ("utf", "gb", "big5", "euc", "shift_jis", "iso-2022", "windows", "ascii")
|
||||
if any(apparent_lower.startswith(p) for p in trusted_prefixes):
|
||||
return apparent
|
||||
|
||||
# 4. Fallback
|
||||
return "utf-8"
|
||||
|
||||
# ---- Helper methods ----
|
||||
|
||||
def _ensure_tmp_dir(self) -> str:
|
||||
"""Ensure workspace/tmp directory exists and return its path."""
|
||||
tmp_dir = os.path.join(self.cwd, "tmp")
|
||||
os.makedirs(tmp_dir, exist_ok=True)
|
||||
return tmp_dir
|
||||
|
||||
def _extract_filename(self, url: str) -> str:
|
||||
"""Extract a safe filename from URL, with a short UUID prefix to avoid collisions."""
|
||||
path = urlparse(url).path
|
||||
basename = os.path.basename(unquote(path))
|
||||
if not basename or basename == "/":
|
||||
basename = "downloaded_file"
|
||||
# Sanitize: keep only safe chars
|
||||
basename = re.sub(r'[^\w.\-]', '_', basename)
|
||||
short_id = uuid.uuid4().hex[:8]
|
||||
return f"{short_id}_{basename}"
|
||||
|
||||
@staticmethod
|
||||
def _cleanup_file(path: str):
|
||||
"""Remove a file if it exists, ignoring errors."""
|
||||
try:
|
||||
if os.path.exists(path):
|
||||
os.remove(path)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def _is_binary_content_type(content_type: str) -> bool:
|
||||
"""Check if Content-Type indicates a binary/document response."""
|
||||
binary_types = [
|
||||
"application/pdf",
|
||||
"application/vnd.openxmlformats",
|
||||
"application/vnd.ms-excel",
|
||||
"application/vnd.ms-powerpoint",
|
||||
"application/octet-stream",
|
||||
]
|
||||
ct_lower = content_type.lower()
|
||||
return any(bt in ct_lower for bt in binary_types)
|
||||
|
||||
def _handle_download_by_content_type(self, url: str, response: requests.Response, content_type: str) -> ToolResult:
|
||||
"""Handle a URL that returned binary content instead of HTML."""
|
||||
ct_lower = content_type.lower()
|
||||
suffix_map = {
|
||||
"application/pdf": ".pdf",
|
||||
"application/vnd.openxmlformats-officedocument.wordprocessingml": ".docx",
|
||||
"application/vnd.ms-excel": ".xls",
|
||||
"application/vnd.openxmlformats-officedocument.spreadsheetml": ".xlsx",
|
||||
"application/vnd.ms-powerpoint": ".ppt",
|
||||
"application/vnd.openxmlformats-officedocument.presentationml": ".pptx",
|
||||
}
|
||||
detected_suffix = None
|
||||
for ct_prefix, ext in suffix_map.items():
|
||||
if ct_prefix in ct_lower:
|
||||
detected_suffix = ext
|
||||
break
|
||||
|
||||
if detected_suffix and detected_suffix in ALL_DOC_SUFFIXES:
|
||||
# Re-fetch as document
|
||||
return self._fetch_document(url if _get_url_suffix(url) in ALL_DOC_SUFFIXES
|
||||
else self._rewrite_url_with_suffix(url, detected_suffix))
|
||||
return ToolResult.fail(f"Error: URL returned binary content ({content_type}), not a supported document type")
|
||||
|
||||
@staticmethod
|
||||
def _rewrite_url_with_suffix(url: str, suffix: str) -> str:
|
||||
"""Append a suffix to the URL path so _get_url_suffix works correctly."""
|
||||
parsed = urlparse(url)
|
||||
new_path = parsed.path.rstrip("/") + suffix
|
||||
return parsed._replace(path=new_path).geturl()
|
||||
|
||||
# ---- HTML extraction (unchanged) ----
|
||||
|
||||
@staticmethod
|
||||
def _extract_title(html: str) -> str:
|
||||
match = re.search(r"<title[^>]*>(.*?)</title>", html, re.IGNORECASE | re.DOTALL)
|
||||
return match.group(1).strip() if match else "Untitled"
|
||||
|
||||
@staticmethod
|
||||
def _extract_text(html: str) -> str:
|
||||
text = re.sub(r"<script[^>]*>.*?</script>", "", html, flags=re.IGNORECASE | re.DOTALL)
|
||||
text = re.sub(r"<style[^>]*>.*?</style>", "", text, flags=re.IGNORECASE | re.DOTALL)
|
||||
text = re.sub(r"<[^>]+>", "", text)
|
||||
text = text.replace("&", "&").replace("<", "<").replace(">", ">")
|
||||
text = text.replace(""", '"').replace("'", "'").replace(" ", " ")
|
||||
text = re.sub(r"[^\S\n]+", " ", text)
|
||||
text = re.sub(r"\n{3,}", "\n\n", text)
|
||||
lines = [line.strip() for line in text.splitlines()]
|
||||
text = "\n".join(lines)
|
||||
return text.strip()
|
||||
3
agent/tools/web_search/__init__.py
Normal file
3
agent/tools/web_search/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from agent.tools.web_search.web_search import WebSearch
|
||||
|
||||
__all__ = ["WebSearch"]
|
||||
487
agent/tools/web_search/web_search.py
Normal file
487
agent/tools/web_search/web_search.py
Normal file
@@ -0,0 +1,487 @@
|
||||
"""Web Search tool. Supports four backends with a unified response format:
|
||||
- bocha (https://open.bochaai.com)
|
||||
- zhipu (https://docs.bigmodel.cn/cn/guide/tools/web-search)
|
||||
- qianfan (https://cloud.baidu.com/doc/qianfan/s/2mh4su4uy)
|
||||
- linkai (https://link-ai.tech, fallback)
|
||||
|
||||
Provider selection
|
||||
- strategy 'auto' (default): pick the first configured provider in the
|
||||
canonical order [bocha, zhipu, qianfan, linkai]. When the caller passes
|
||||
an explicit `provider` it overrides the pick; an invalid/unconfigured
|
||||
one silently falls back to the auto order.
|
||||
- strategy 'fixed': use the configured provider; if its credential is
|
||||
missing at call time, silently fall back to auto order (no card hint).
|
||||
|
||||
Credentials
|
||||
- bocha : tools.web_search.bocha_api_key -> env BOCHA_API_KEY
|
||||
- zhipu : conf.zhipu_ai_api_key -> env ZHIPUAI_API_KEY
|
||||
- qianfan : conf.qianfan_api_key -> env QIANFAN_API_KEY
|
||||
- linkai : conf.linkai_api_key -> env LINKAI_API_KEY
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import requests
|
||||
|
||||
from agent.tools.base_tool import BaseTool, ToolResult
|
||||
from common.log import logger
|
||||
from config import conf
|
||||
|
||||
|
||||
DEFAULT_TIMEOUT = 30
|
||||
|
||||
# Canonical fallback order. Empirically ordered by Chinese real-time
|
||||
# quality + relevance: bocha (best overall), qianfan (best for hot news),
|
||||
# zhipu (strong on long-form articles), linkai (cloud aggregator, last
|
||||
# resort).
|
||||
PROVIDER_ORDER = ("bocha", "qianfan", "zhipu", "linkai")
|
||||
|
||||
PROVIDER_LABELS = {
|
||||
"bocha": "Bocha",
|
||||
"zhipu": "Zhipu",
|
||||
"qianfan": "Baidu Qianfan",
|
||||
"linkai": "LinkAI",
|
||||
}
|
||||
|
||||
|
||||
def _tools_web_search_conf() -> dict:
|
||||
"""Return the tools.web_search config block (dict-like)."""
|
||||
tools_cfg = conf().get("tools") or {}
|
||||
if not isinstance(tools_cfg, dict):
|
||||
return {}
|
||||
block = tools_cfg.get("web_search") or {}
|
||||
return block if isinstance(block, dict) else {}
|
||||
|
||||
|
||||
def _get_api_key(provider: str) -> str:
|
||||
"""Resolve API key for a provider, with conf -> env fallback."""
|
||||
if provider == "bocha":
|
||||
key = (_tools_web_search_conf().get("bocha_api_key") or "").strip()
|
||||
return key or os.environ.get("BOCHA_API_KEY", "").strip()
|
||||
if provider == "zhipu":
|
||||
key = (conf().get("zhipu_ai_api_key") or "").strip()
|
||||
return key or os.environ.get("ZHIPUAI_API_KEY", "").strip()
|
||||
if provider == "qianfan":
|
||||
key = (conf().get("qianfan_api_key") or "").strip()
|
||||
return key or os.environ.get("QIANFAN_API_KEY", "").strip()
|
||||
if provider == "linkai":
|
||||
key = (conf().get("linkai_api_key") or "").strip()
|
||||
return key or os.environ.get("LINKAI_API_KEY", "").strip()
|
||||
return ""
|
||||
|
||||
|
||||
def configured_providers() -> List[str]:
|
||||
"""Return configured providers in canonical order."""
|
||||
return [p for p in PROVIDER_ORDER if _get_api_key(p)]
|
||||
|
||||
|
||||
def _configured_strategy() -> str:
|
||||
return (_tools_web_search_conf().get("strategy") or "auto").strip().lower()
|
||||
|
||||
|
||||
def _configured_provider() -> str:
|
||||
return (_tools_web_search_conf().get("provider") or "").strip().lower()
|
||||
|
||||
|
||||
class WebSearch(BaseTool):
|
||||
"""Tool for searching the web across multiple providers."""
|
||||
|
||||
name: str = "web_search"
|
||||
description: str = "Search the web for real-time information. Returns titles, URLs, and snippets."
|
||||
|
||||
params: dict = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "Search query string"
|
||||
},
|
||||
"count": {
|
||||
"type": "integer",
|
||||
"description": "Number of results to return (1-50, default: 10)"
|
||||
},
|
||||
"freshness": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Time range filter. Options: "
|
||||
"'noLimit' (default), 'oneDay', 'oneWeek', 'oneMonth', 'oneYear', "
|
||||
"or date range like '2025-01-01..2025-02-01'"
|
||||
)
|
||||
},
|
||||
"summary": {
|
||||
"type": "boolean",
|
||||
"description": "Whether to include text summary for each result (default: false)"
|
||||
}
|
||||
},
|
||||
"required": ["query"]
|
||||
}
|
||||
|
||||
def __init__(self, config: dict = None):
|
||||
self.config = config or {}
|
||||
|
||||
@staticmethod
|
||||
def is_available() -> bool:
|
||||
"""Tool is offered to the agent when at least one provider has a key."""
|
||||
return bool(configured_providers())
|
||||
|
||||
@classmethod
|
||||
def get_json_schema(cls) -> dict:
|
||||
"""Augment the static schema with a `provider` field — only when the
|
||||
user has ≥2 providers configured AND strategy is 'auto'. Otherwise
|
||||
the backend picks silently and exposing the field would only waste
|
||||
the agent's tokens."""
|
||||
schema = {
|
||||
"name": cls.name,
|
||||
"description": cls.description,
|
||||
"parameters": json.loads(json.dumps(cls.params)), # deep copy
|
||||
}
|
||||
if _configured_strategy() != "auto":
|
||||
return schema
|
||||
available = configured_providers()
|
||||
if len(available) < 2:
|
||||
return schema
|
||||
|
||||
schema["parameters"]["properties"]["provider"] = {
|
||||
"type": "string",
|
||||
"enum": available,
|
||||
"description": "Optional. Specifies the search backend. You may switch between providers when the user wants results from a particular source or from multiple sources.",
|
||||
}
|
||||
return schema
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Provider resolution
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _resolve_provider(self, requested: Optional[str]) -> Optional[str]:
|
||||
"""Pick a provider for this call.
|
||||
|
||||
Priority: caller-supplied (if configured) > fixed strategy (if
|
||||
configured) > first configured in PROVIDER_ORDER. Silent fallback
|
||||
when the desired one has no key.
|
||||
"""
|
||||
available = configured_providers()
|
||||
if not available:
|
||||
return None
|
||||
|
||||
if requested:
|
||||
req = requested.strip().lower()
|
||||
if req in available:
|
||||
return req
|
||||
logger.warning(f"[WebSearch] requested provider '{requested}' unavailable, falling back")
|
||||
|
||||
if _configured_strategy() == "fixed":
|
||||
pinned = _configured_provider()
|
||||
if pinned in available:
|
||||
return pinned
|
||||
if pinned:
|
||||
logger.warning(f"[WebSearch] pinned provider '{pinned}' unavailable, falling back to auto")
|
||||
|
||||
return available[0]
|
||||
|
||||
@staticmethod
|
||||
def _resolution_reason(requested: Optional[str], chosen: str) -> str:
|
||||
"""Human-readable explanation for why `chosen` won the resolver."""
|
||||
if requested and requested.strip().lower() == chosen:
|
||||
return "caller-requested"
|
||||
strategy = _configured_strategy()
|
||||
if strategy == "fixed" and _configured_provider() == chosen:
|
||||
return "fixed-strategy"
|
||||
return "auto-fallback"
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Entry point
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def execute(self, args: Dict[str, Any]) -> ToolResult:
|
||||
query = (args.get("query") or "").strip()
|
||||
if not query:
|
||||
return ToolResult.fail("Error: 'query' parameter is required")
|
||||
|
||||
count = args.get("count", 10)
|
||||
freshness = args.get("freshness", "noLimit")
|
||||
summary = args.get("summary", False)
|
||||
if not isinstance(count, int) or count < 1 or count > 50:
|
||||
count = 10
|
||||
|
||||
requested = args.get("provider")
|
||||
provider = self._resolve_provider(requested)
|
||||
if not provider:
|
||||
return ToolResult.fail(
|
||||
"Error: No search provider configured. "
|
||||
"Configure one of BOCHA_API_KEY / zhipu_ai_api_key / qianfan_api_key / linkai_api_key."
|
||||
)
|
||||
|
||||
# Always log the routing decision so multi-provider deployments can
|
||||
# tell at a glance which backend served any given query.
|
||||
available = configured_providers()
|
||||
reason = self._resolution_reason(requested, provider)
|
||||
q_preview = query if len(query) <= 60 else (query[:57] + "...")
|
||||
logger.info(
|
||||
f"[WebSearch] provider={provider} reason={reason} "
|
||||
f"available={list(available)} query={q_preview!r} count={count} freshness={freshness}"
|
||||
)
|
||||
|
||||
try:
|
||||
if provider == "bocha":
|
||||
return self._search_bocha(query, count, freshness, summary)
|
||||
if provider == "zhipu":
|
||||
return self._search_zhipu(query, count, freshness)
|
||||
if provider == "qianfan":
|
||||
return self._search_qianfan(query, count, freshness)
|
||||
if provider == "linkai":
|
||||
return self._search_linkai(query, count, freshness)
|
||||
return ToolResult.fail(f"Error: Unknown provider '{provider}'")
|
||||
except requests.Timeout:
|
||||
return ToolResult.fail(f"Error: Search request timed out after {DEFAULT_TIMEOUT}s")
|
||||
except requests.ConnectionError:
|
||||
return ToolResult.fail("Error: Failed to connect to search API")
|
||||
except Exception as e:
|
||||
logger.error(f"[WebSearch] Unexpected error ({provider}): {e}", exc_info=True)
|
||||
return ToolResult.fail(f"Error: Search failed - {str(e)}")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Bocha
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _search_bocha(self, query: str, count: int, freshness: str, summary: bool) -> ToolResult:
|
||||
api_key = _get_api_key("bocha")
|
||||
url = "https://api.bochaai.com/v1/web-search"
|
||||
headers = {
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json",
|
||||
}
|
||||
payload = {"query": query, "count": count, "freshness": freshness, "summary": summary}
|
||||
|
||||
logger.debug(f"[WebSearch] bocha: query='{query}', count={count}")
|
||||
resp = requests.post(url, headers=headers, json=payload, timeout=DEFAULT_TIMEOUT)
|
||||
|
||||
if resp.status_code == 401:
|
||||
return ToolResult.fail("Error: Invalid bocha API key.")
|
||||
if resp.status_code == 403:
|
||||
return ToolResult.fail("Error: bocha API — insufficient balance. Top up at https://open.bochaai.com")
|
||||
if resp.status_code == 429:
|
||||
return ToolResult.fail("Error: bocha API rate limit reached.")
|
||||
if resp.status_code != 200:
|
||||
return ToolResult.fail(f"Error: bocha API returned HTTP {resp.status_code}")
|
||||
|
||||
data = resp.json()
|
||||
api_code = data.get("code")
|
||||
if api_code is not None and api_code != 200:
|
||||
msg = data.get("msg") or "Unknown error"
|
||||
return ToolResult.fail(f"Error: bocha API error (code={api_code}): {msg}")
|
||||
|
||||
pages = (data.get("data") or {}).get("webPages", {}).get("value", []) or []
|
||||
results = []
|
||||
for p in pages:
|
||||
item = {
|
||||
"title": p.get("name", ""),
|
||||
"url": p.get("url", ""),
|
||||
"snippet": p.get("snippet", ""),
|
||||
"siteName": p.get("siteName", ""),
|
||||
"datePublished": p.get("datePublished") or p.get("dateLastCrawled", ""),
|
||||
}
|
||||
if p.get("summary"):
|
||||
item["summary"] = p["summary"]
|
||||
results.append(item)
|
||||
total = (data.get("data") or {}).get("webPages", {}).get("totalEstimatedMatches", len(results))
|
||||
return ToolResult.success({
|
||||
"query": query, "backend": "bocha",
|
||||
"total": total, "count": len(results), "results": results,
|
||||
})
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Zhipu
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _search_zhipu(self, query: str, count: int, freshness: str) -> ToolResult:
|
||||
api_key = _get_api_key("zhipu")
|
||||
api_base = (conf().get("zhipu_ai_api_base") or "https://open.bigmodel.cn/api/paas/v4").rstrip("/")
|
||||
url = f"{api_base}/web_search"
|
||||
headers = {
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
# Zhipu Web Search expects `search_query` <= 70 chars; truncate
|
||||
# gracefully so a long agent-supplied query doesn't get rejected.
|
||||
trimmed_query = (query or "")[:70]
|
||||
engine = (_tools_web_search_conf().get("zhipu_search_engine") or "search_pro").strip().lower()
|
||||
if engine not in ("search_std", "search_pro", "search_pro_sogou", "search_pro_quark"):
|
||||
engine = "search_pro"
|
||||
|
||||
payload: Dict[str, Any] = {
|
||||
"search_engine": engine,
|
||||
"search_query": trimmed_query,
|
||||
"search_intent": False,
|
||||
"count": max(1, min(int(count or 10), 50)),
|
||||
"search_recency_filter": freshness if freshness in (
|
||||
"oneDay", "oneWeek", "oneMonth", "oneYear", "noLimit"
|
||||
) else "noLimit",
|
||||
}
|
||||
content_size = (_tools_web_search_conf().get("zhipu_content_size") or "").strip().lower()
|
||||
if content_size in ("medium", "high"):
|
||||
payload["content_size"] = content_size
|
||||
|
||||
logger.debug(f"[WebSearch] zhipu: query='{trimmed_query}', count={payload['count']}, engine={engine}")
|
||||
resp = requests.post(url, headers=headers, json=payload, timeout=DEFAULT_TIMEOUT)
|
||||
|
||||
if resp.status_code == 401:
|
||||
return ToolResult.fail("Error: Invalid Zhipu API key.")
|
||||
if resp.status_code != 200:
|
||||
return ToolResult.fail(f"Error: Zhipu API returned HTTP {resp.status_code}: {resp.text[:200]}")
|
||||
|
||||
data = resp.json()
|
||||
# Business-level errors (1701/1702/1703 etc.) come back as
|
||||
# {"error": {"code","message"}} even on HTTP 200.
|
||||
if isinstance(data, dict) and data.get("error"):
|
||||
err = data["error"] or {}
|
||||
return ToolResult.fail(f"Error: Zhipu returned {err.get('code')}: {err.get('message','')}")
|
||||
|
||||
items = data.get("search_result") or (data.get("data") or {}).get("search_result") or []
|
||||
results = []
|
||||
for it in items:
|
||||
results.append({
|
||||
"title": it.get("title", ""),
|
||||
"url": it.get("link") or it.get("url", ""),
|
||||
"snippet": it.get("content") or it.get("snippet", ""),
|
||||
"siteName": it.get("media") or it.get("siteName", ""),
|
||||
"datePublished": it.get("publish_date") or it.get("datePublished", ""),
|
||||
})
|
||||
return ToolResult.success({
|
||||
"query": query, "backend": "zhipu",
|
||||
"total": len(results), "count": len(results), "results": results,
|
||||
})
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Qianfan (Baidu)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _search_qianfan(self, query: str, count: int, freshness: str) -> ToolResult:
|
||||
api_key = _get_api_key("qianfan")
|
||||
api_base = (conf().get("qianfan_api_base") or "https://qianfan.baidubce.com/v2").rstrip("/")
|
||||
url = f"{api_base}/ai_search/web_search"
|
||||
headers = {
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"Content-Type": "application/json",
|
||||
"X-Appbuilder-From": "cow",
|
||||
}
|
||||
|
||||
count = max(1, min(int(count or 10), 50))
|
||||
payload: Dict[str, Any] = {
|
||||
"messages": [{"role": "user", "content": query}],
|
||||
"search_source": "baidu_search_v2",
|
||||
"resource_type_filter": [{"type": "web", "top_k": count}],
|
||||
}
|
||||
|
||||
# Baidu AI Search expects freshness as a date-range filter, not a
|
||||
# named recency token. Translate our shared vocabulary into the
|
||||
# underlying page_time range expected by the API.
|
||||
search_filter = self._qianfan_build_freshness_filter(freshness)
|
||||
if search_filter:
|
||||
payload["search_filter"] = search_filter
|
||||
|
||||
logger.debug(f"[WebSearch] qianfan: query='{query}', count={count}, freshness={freshness!r}")
|
||||
resp = requests.post(url, headers=headers, json=payload, timeout=DEFAULT_TIMEOUT)
|
||||
|
||||
if resp.status_code == 401:
|
||||
return ToolResult.fail("Error: Invalid Qianfan API key.")
|
||||
if resp.status_code != 200:
|
||||
return ToolResult.fail(f"Error: Qianfan API returned HTTP {resp.status_code}: {resp.text[:200]}")
|
||||
|
||||
data = resp.json()
|
||||
# Even on HTTP 200 Baidu surfaces business errors as {"code","message"}.
|
||||
if isinstance(data, dict) and data.get("code"):
|
||||
return ToolResult.fail(f"Error: Qianfan returned {data.get('code')}: {data.get('message','')}")
|
||||
|
||||
refs = data.get("references") or []
|
||||
results = []
|
||||
for d in refs:
|
||||
results.append({
|
||||
"title": d.get("title", ""),
|
||||
"url": d.get("url", ""),
|
||||
"snippet": (d.get("content") or "")[:200],
|
||||
"siteName": d.get("web_anchor") or d.get("website") or "",
|
||||
"datePublished": d.get("date", ""),
|
||||
})
|
||||
return ToolResult.success({
|
||||
"query": query, "backend": "qianfan",
|
||||
"total": len(results), "count": len(results), "results": results,
|
||||
})
|
||||
|
||||
@staticmethod
|
||||
def _qianfan_build_freshness_filter(freshness: str) -> Optional[Dict[str, Any]]:
|
||||
if not freshness or freshness == "noLimit":
|
||||
return None
|
||||
delta_days = {"oneDay": 1, "oneWeek": 7, "oneMonth": 30, "oneYear": 365}.get(freshness)
|
||||
if not delta_days:
|
||||
return None
|
||||
from datetime import datetime, timedelta
|
||||
now = datetime.now()
|
||||
end_date = (now + timedelta(days=1)).strftime("%Y-%m-%d")
|
||||
start_date = (now - timedelta(days=delta_days)).strftime("%Y-%m-%d")
|
||||
return {"range": {"page_time": {"gte": start_date, "lt": end_date}}}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# LinkAI (plugin)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _search_linkai(self, query: str, count: int, freshness: str) -> ToolResult:
|
||||
api_key = _get_api_key("linkai")
|
||||
api_base = (conf().get("linkai_api_base") or "https://api.link-ai.tech").rstrip("/")
|
||||
url = f"{api_base}/v1/plugin/execute"
|
||||
|
||||
from common.utils import get_cloud_headers
|
||||
headers = get_cloud_headers(api_key)
|
||||
|
||||
payload = {"code": "web-search", "args": {"query": query, "count": count, "freshness": freshness}}
|
||||
logger.debug(f"[WebSearch] linkai: query='{query}', count={count}")
|
||||
resp = requests.post(url, headers=headers, json=payload, timeout=DEFAULT_TIMEOUT)
|
||||
|
||||
if resp.status_code == 401:
|
||||
return ToolResult.fail("Error: Invalid LinkAI API key.")
|
||||
if resp.status_code != 200:
|
||||
return ToolResult.fail(f"Error: LinkAI API returned HTTP {resp.status_code}")
|
||||
|
||||
data = resp.json()
|
||||
if not data.get("success"):
|
||||
msg = data.get("message") or "Unknown error"
|
||||
return ToolResult.fail(f"Error: LinkAI search failed: {msg}")
|
||||
|
||||
raw = data.get("data", "")
|
||||
if isinstance(raw, str):
|
||||
try:
|
||||
raw = json.loads(raw)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
return ToolResult.success({
|
||||
"query": query, "backend": "linkai",
|
||||
"total": 1, "count": 1, "results": [{"content": raw}],
|
||||
})
|
||||
|
||||
if isinstance(raw, dict):
|
||||
pages = (raw.get("webPages") or {}).get("value", []) or []
|
||||
if pages:
|
||||
results = []
|
||||
for p in pages:
|
||||
item = {
|
||||
"title": p.get("name", ""),
|
||||
"url": p.get("url", ""),
|
||||
"snippet": p.get("snippet", ""),
|
||||
"siteName": p.get("siteName", ""),
|
||||
"datePublished": p.get("datePublished") or p.get("dateLastCrawled", ""),
|
||||
}
|
||||
if p.get("summary"):
|
||||
item["summary"] = p["summary"]
|
||||
results.append(item)
|
||||
total = (raw.get("webPages") or {}).get("totalEstimatedMatches", len(results))
|
||||
return ToolResult.success({
|
||||
"query": query, "backend": "linkai",
|
||||
"total": total, "count": len(results), "results": results,
|
||||
})
|
||||
|
||||
return ToolResult.success({
|
||||
"query": query, "backend": "linkai",
|
||||
"total": 1, "count": 1, "results": [{"content": str(raw)}],
|
||||
})
|
||||
@@ -8,6 +8,7 @@ from typing import Dict, Any
|
||||
from pathlib import Path
|
||||
|
||||
from agent.tools.base_tool import BaseTool, ToolResult
|
||||
from common.utils import expand_path
|
||||
|
||||
|
||||
class Write(BaseTool):
|
||||
@@ -90,7 +91,7 @@ class Write(BaseTool):
|
||||
:return: Absolute path
|
||||
"""
|
||||
# Expand ~ to user home directory
|
||||
path = os.path.expanduser(path)
|
||||
path = expand_path(path)
|
||||
if os.path.isabs(path):
|
||||
return path
|
||||
return os.path.abspath(os.path.join(self.cwd, path))
|
||||
|
||||
370
app.py
370
app.py
@@ -7,11 +7,261 @@ import time
|
||||
|
||||
from channel import channel_factory
|
||||
from common import const
|
||||
from config import load_config
|
||||
from common.log import logger
|
||||
from config import load_config, conf
|
||||
from plugins import *
|
||||
import threading
|
||||
|
||||
|
||||
_channel_mgr = None
|
||||
|
||||
|
||||
def get_channel_manager():
|
||||
return _channel_mgr
|
||||
|
||||
|
||||
def _parse_channel_type(raw) -> list:
|
||||
"""
|
||||
Parse channel_type config value into a list of channel names.
|
||||
Supports:
|
||||
- single string: "feishu"
|
||||
- comma-separated string: "feishu, dingtalk"
|
||||
- list: ["feishu", "dingtalk"]
|
||||
"""
|
||||
if isinstance(raw, list):
|
||||
return [ch.strip() for ch in raw if ch.strip()]
|
||||
if isinstance(raw, str):
|
||||
return [ch.strip() for ch in raw.split(",") if ch.strip()]
|
||||
return []
|
||||
|
||||
|
||||
class ChannelManager:
|
||||
"""
|
||||
Manage the lifecycle of multiple channels running concurrently.
|
||||
Each channel.startup() runs in its own daemon thread.
|
||||
The web channel is started as default console unless explicitly disabled.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._channels = {} # channel_name -> channel instance
|
||||
self._threads = {} # channel_name -> thread
|
||||
self._primary_channel = None
|
||||
self._lock = threading.Lock()
|
||||
self.cloud_mode = False # set to True when cloud client is active
|
||||
|
||||
@property
|
||||
def channel(self):
|
||||
"""Return the primary (first non-web) channel for backward compatibility."""
|
||||
return self._primary_channel
|
||||
|
||||
def get_channel(self, channel_name: str):
|
||||
return self._channels.get(channel_name)
|
||||
|
||||
def start(self, channel_names: list, first_start: bool = False):
|
||||
"""
|
||||
Create and start one or more channels in sub-threads.
|
||||
If first_start is True, plugins and linkai client will also be initialized.
|
||||
"""
|
||||
with self._lock:
|
||||
channels = []
|
||||
for name in channel_names:
|
||||
ch = channel_factory.create_channel(name)
|
||||
ch.cloud_mode = self.cloud_mode
|
||||
self._channels[name] = ch
|
||||
channels.append((name, ch))
|
||||
if self._primary_channel is None and name != "web":
|
||||
self._primary_channel = ch
|
||||
|
||||
if self._primary_channel is None and channels:
|
||||
self._primary_channel = channels[0][1]
|
||||
|
||||
if first_start:
|
||||
PluginManager().load_plugins()
|
||||
|
||||
# Cloud client is optional. It is only started when
|
||||
# use_linkai=True AND cloud_deployment_id is set.
|
||||
# By default neither is configured, so the app runs
|
||||
# entirely locally without any remote connection.
|
||||
if conf().get("use_linkai") and (
|
||||
os.environ.get("CLOUD_DEPLOYMENT_ID") or conf().get("cloud_deployment_id")
|
||||
):
|
||||
try:
|
||||
from common import cloud_client
|
||||
threading.Thread(
|
||||
target=cloud_client.start,
|
||||
args=(self._primary_channel, self),
|
||||
daemon=True,
|
||||
).start()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Start web console first so its logs print cleanly,
|
||||
# then start remaining channels after a brief pause.
|
||||
web_entry = None
|
||||
other_entries = []
|
||||
for entry in channels:
|
||||
if entry[0] == "web":
|
||||
web_entry = entry
|
||||
else:
|
||||
other_entries.append(entry)
|
||||
|
||||
ordered = ([web_entry] if web_entry else []) + other_entries
|
||||
for i, (name, ch) in enumerate(ordered):
|
||||
if i > 0 and name != "web":
|
||||
time.sleep(0.1)
|
||||
t = threading.Thread(target=self._run_channel, args=(name, ch), daemon=True)
|
||||
self._threads[name] = t
|
||||
t.start()
|
||||
logger.debug(f"[ChannelManager] Channel '{name}' started in sub-thread")
|
||||
|
||||
def _run_channel(self, name: str, channel):
|
||||
try:
|
||||
channel.startup()
|
||||
except Exception as e:
|
||||
logger.error(f"[ChannelManager] Channel '{name}' startup error: {e}")
|
||||
logger.exception(e)
|
||||
|
||||
def stop(self, channel_name: str = None):
|
||||
"""
|
||||
Stop channel(s). If channel_name is given, stop only that channel;
|
||||
otherwise stop all channels.
|
||||
"""
|
||||
# Pop under lock, then stop outside lock to avoid deadlock
|
||||
with self._lock:
|
||||
names = [channel_name] if channel_name else list(self._channels.keys())
|
||||
to_stop = []
|
||||
for name in names:
|
||||
ch = self._channels.pop(name, None)
|
||||
th = self._threads.pop(name, None)
|
||||
to_stop.append((name, ch, th))
|
||||
if channel_name and self._primary_channel is self._channels.get(channel_name):
|
||||
self._primary_channel = None
|
||||
|
||||
for name, ch, th in to_stop:
|
||||
if ch is None:
|
||||
logger.warning(f"[ChannelManager] Channel '{name}' not found in managed channels")
|
||||
if th and th.is_alive():
|
||||
self._interrupt_thread(th, name)
|
||||
continue
|
||||
logger.info(f"[ChannelManager] Stopping channel '{name}'...")
|
||||
graceful = False
|
||||
if hasattr(ch, 'stop'):
|
||||
try:
|
||||
ch.stop()
|
||||
graceful = True
|
||||
except Exception as e:
|
||||
logger.warning(f"[ChannelManager] Error during channel '{name}' stop: {e}")
|
||||
if th and th.is_alive():
|
||||
th.join(timeout=5)
|
||||
if th.is_alive():
|
||||
if graceful:
|
||||
logger.info(f"[ChannelManager] Channel '{name}' thread still alive after stop(), "
|
||||
"leaving daemon thread to finish on its own")
|
||||
else:
|
||||
logger.warning(f"[ChannelManager] Channel '{name}' thread did not exit in 5s, forcing interrupt")
|
||||
self._interrupt_thread(th, name)
|
||||
|
||||
@staticmethod
|
||||
def _interrupt_thread(th: threading.Thread, name: str):
|
||||
"""Raise SystemExit in target thread to break blocking loops like start_forever."""
|
||||
import ctypes
|
||||
try:
|
||||
tid = th.ident
|
||||
if tid is None:
|
||||
return
|
||||
res = ctypes.pythonapi.PyThreadState_SetAsyncExc(
|
||||
ctypes.c_ulong(tid), ctypes.py_object(SystemExit)
|
||||
)
|
||||
if res == 1:
|
||||
logger.info(f"[ChannelManager] Interrupted thread for channel '{name}'")
|
||||
elif res > 1:
|
||||
ctypes.pythonapi.PyThreadState_SetAsyncExc(ctypes.c_ulong(tid), None)
|
||||
logger.warning(f"[ChannelManager] Failed to interrupt thread for channel '{name}'")
|
||||
except Exception as e:
|
||||
logger.warning(f"[ChannelManager] Thread interrupt error for '{name}': {e}")
|
||||
|
||||
def restart(self, new_channel_name: str):
|
||||
"""
|
||||
Restart a single channel with a new channel type.
|
||||
Can be called from any thread (e.g. linkai config callback).
|
||||
"""
|
||||
logger.info(f"[ChannelManager] Restarting channel to '{new_channel_name}'...")
|
||||
self.stop(new_channel_name)
|
||||
_clear_singleton_cache(new_channel_name)
|
||||
time.sleep(1)
|
||||
self.start([new_channel_name], first_start=False)
|
||||
logger.info(f"[ChannelManager] Channel restarted to '{new_channel_name}' successfully")
|
||||
|
||||
def add_channel(self, channel_name: str):
|
||||
"""
|
||||
Dynamically add and start a new channel.
|
||||
If the channel is already running, restart it instead.
|
||||
"""
|
||||
with self._lock:
|
||||
if channel_name in self._channels:
|
||||
logger.info(f"[ChannelManager] Channel '{channel_name}' already exists, restarting")
|
||||
if self._channels.get(channel_name):
|
||||
self.restart(channel_name)
|
||||
return
|
||||
logger.info(f"[ChannelManager] Adding channel '{channel_name}'...")
|
||||
_clear_singleton_cache(channel_name)
|
||||
self.start([channel_name], first_start=False)
|
||||
logger.info(f"[ChannelManager] Channel '{channel_name}' added successfully")
|
||||
|
||||
def remove_channel(self, channel_name: str):
|
||||
"""
|
||||
Dynamically stop and remove a running channel.
|
||||
"""
|
||||
with self._lock:
|
||||
if channel_name not in self._channels:
|
||||
logger.warning(f"[ChannelManager] Channel '{channel_name}' not found, nothing to remove")
|
||||
return
|
||||
logger.info(f"[ChannelManager] Removing channel '{channel_name}'...")
|
||||
self.stop(channel_name)
|
||||
logger.info(f"[ChannelManager] Channel '{channel_name}' removed successfully")
|
||||
|
||||
|
||||
def _clear_singleton_cache(channel_name: str):
|
||||
"""
|
||||
Clear the singleton cache for the channel class so that
|
||||
a new instance can be created with updated config.
|
||||
"""
|
||||
cls_map = {
|
||||
"web": "channel.web.web_channel.WebChannel",
|
||||
"wechatmp": "channel.wechatmp.wechatmp_channel.WechatMPChannel",
|
||||
"wechatmp_service": "channel.wechatmp.wechatmp_channel.WechatMPChannel",
|
||||
"wechatcom_app": "channel.wechatcom.wechatcomapp_channel.WechatComAppChannel",
|
||||
const.WECHAT_KF: "channel.wechat_kf.wechat_kf_channel.WechatKfChannel",
|
||||
const.FEISHU: "channel.feishu.feishu_channel.FeiShuChanel",
|
||||
const.DINGTALK: "channel.dingtalk.dingtalk_channel.DingTalkChanel",
|
||||
const.WECOM_BOT: "channel.wecom_bot.wecom_bot_channel.WecomBotChannel",
|
||||
const.QQ: "channel.qq.qq_channel.QQChannel",
|
||||
const.WEIXIN: "channel.weixin.weixin_channel.WeixinChannel",
|
||||
"wx": "channel.weixin.weixin_channel.WeixinChannel",
|
||||
}
|
||||
module_path = cls_map.get(channel_name)
|
||||
if not module_path:
|
||||
return
|
||||
try:
|
||||
parts = module_path.rsplit(".", 1)
|
||||
module_name, class_name = parts[0], parts[1]
|
||||
import importlib
|
||||
module = importlib.import_module(module_name)
|
||||
wrapper = getattr(module, class_name, None)
|
||||
if wrapper and hasattr(wrapper, '__closure__') and wrapper.__closure__:
|
||||
for cell in wrapper.__closure__:
|
||||
try:
|
||||
cell_contents = cell.cell_contents
|
||||
if isinstance(cell_contents, dict):
|
||||
cell_contents.clear()
|
||||
logger.debug(f"[ChannelManager] Cleared singleton cache for {class_name}")
|
||||
break
|
||||
except ValueError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.warning(f"[ChannelManager] Failed to clear singleton cache: {e}")
|
||||
|
||||
|
||||
def sigterm_handler_wrap(_signo):
|
||||
old_handler = signal.getsignal(_signo)
|
||||
|
||||
@@ -25,22 +275,65 @@ def sigterm_handler_wrap(_signo):
|
||||
signal.signal(_signo, func)
|
||||
|
||||
|
||||
def start_channel(channel_name: str):
|
||||
channel = channel_factory.create_channel(channel_name)
|
||||
if channel_name in ["wx", "wxy", "terminal", "wechatmp", "web", "wechatmp_service", "wechatcom_app", "wework",
|
||||
const.FEISHU, const.DINGTALK]:
|
||||
PluginManager().load_plugins()
|
||||
def _warmup_mcp_tools():
|
||||
"""
|
||||
Kick off MCP server loading at process startup so subprocesses
|
||||
(npx / uvx etc.) finish initializing before the first user message
|
||||
arrives. Returns immediately — the actual work happens on a daemon
|
||||
thread inside ToolManager. Safe to call when MCP is not configured.
|
||||
"""
|
||||
try:
|
||||
from agent.tools import ToolManager
|
||||
ToolManager()._load_mcp_tools()
|
||||
except Exception as e:
|
||||
logger.warning(f"[App] MCP warmup failed (non-fatal): {e}")
|
||||
|
||||
if conf().get("use_linkai"):
|
||||
try:
|
||||
from common import linkai_client
|
||||
threading.Thread(target=linkai_client.start, args=(channel,)).start()
|
||||
except Exception as e:
|
||||
pass
|
||||
channel.startup()
|
||||
|
||||
def _warmup_scheduler():
|
||||
"""Eager-init AgentBridge so the scheduler thread starts at process
|
||||
boot rather than waiting for the first user message."""
|
||||
try:
|
||||
from bridge.bridge import Bridge
|
||||
Bridge().get_agent_bridge()
|
||||
except Exception as e:
|
||||
logger.warning(f"[App] Scheduler warmup failed: {e}")
|
||||
|
||||
|
||||
def _sync_builtin_skills():
|
||||
"""Sync builtin skills from project skills/ to workspace skills/ on startup."""
|
||||
import shutil
|
||||
try:
|
||||
workspace = conf().get("agent_workspace", "~/cow")
|
||||
workspace = os.path.expanduser(workspace)
|
||||
project_root = os.path.dirname(os.path.abspath(__file__))
|
||||
builtin_dir = os.path.join(project_root, "skills")
|
||||
custom_dir = os.path.join(workspace, "skills")
|
||||
|
||||
if not os.path.isdir(builtin_dir):
|
||||
return
|
||||
|
||||
os.makedirs(custom_dir, exist_ok=True)
|
||||
synced = 0
|
||||
for name in os.listdir(builtin_dir):
|
||||
src = os.path.join(builtin_dir, name)
|
||||
if not os.path.isdir(src) or not os.path.isfile(os.path.join(src, "SKILL.md")):
|
||||
continue
|
||||
dst = os.path.join(custom_dir, name)
|
||||
try:
|
||||
if os.path.isdir(dst):
|
||||
shutil.rmtree(dst)
|
||||
shutil.copytree(src, dst)
|
||||
synced += 1
|
||||
except Exception as e:
|
||||
logger.warning(f"[App] Failed to sync builtin skill '{name}': {e}")
|
||||
if synced:
|
||||
logger.info(f"[App] Synced {synced} builtin skill(s) to workspace")
|
||||
except Exception as e:
|
||||
logger.warning(f"[App] Builtin skills sync failed: {e}")
|
||||
|
||||
|
||||
def run():
|
||||
global _channel_mgr
|
||||
try:
|
||||
# load config
|
||||
load_config()
|
||||
@@ -49,36 +342,39 @@ def run():
|
||||
# kill signal
|
||||
sigterm_handler_wrap(signal.SIGTERM)
|
||||
|
||||
# create channel
|
||||
channel_name = conf().get("channel_type", "wx")
|
||||
# Parse channel_type into a list
|
||||
raw_channel = conf().get("channel_type", "web")
|
||||
|
||||
if "--cmd" in sys.argv:
|
||||
channel_name = "terminal"
|
||||
|
||||
if channel_name == "wxy":
|
||||
os.environ["WECHATY_LOG"] = "warn"
|
||||
|
||||
start_channel(channel_name)
|
||||
|
||||
# 打印系统运行成功信息
|
||||
logger.info("")
|
||||
logger.info("=" * 50)
|
||||
if conf().get("agent", False):
|
||||
logger.info("✅ System started successfully!")
|
||||
logger.info("🐮 Cow Agent is running")
|
||||
logger.info(f" Channel: {channel_name}")
|
||||
logger.info(f" Model: {conf().get('model', 'unknown')}")
|
||||
logger.info(f" Workspace: {conf().get('agent_workspace', '~/cow')}")
|
||||
channel_names = ["terminal"]
|
||||
else:
|
||||
logger.info("✅ System started successfully!")
|
||||
logger.info("🤖 ChatBot is running")
|
||||
logger.info(f" Channel: {channel_name}")
|
||||
logger.info(f" Model: {conf().get('model', 'unknown')}")
|
||||
logger.info("=" * 50)
|
||||
logger.info("")
|
||||
channel_names = _parse_channel_type(raw_channel)
|
||||
if not channel_names:
|
||||
channel_names = ["web"]
|
||||
|
||||
# Auto-start web console unless explicitly disabled
|
||||
web_console_enabled = conf().get("web_console", True)
|
||||
if web_console_enabled and "web" not in channel_names:
|
||||
channel_names.append("web")
|
||||
|
||||
# Sync builtin skills to workspace before channels start
|
||||
_sync_builtin_skills()
|
||||
|
||||
# Kick off MCP server loading in the background so first-message
|
||||
# latency isn't dominated by npx package downloads.
|
||||
_warmup_mcp_tools()
|
||||
|
||||
_warmup_scheduler()
|
||||
|
||||
logger.info(f"[App] Starting channels: {channel_names}")
|
||||
|
||||
_channel_mgr = ChannelManager()
|
||||
_channel_mgr.start(channel_names, first_start=True)
|
||||
|
||||
while True:
|
||||
time.sleep(1)
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error("App startup failed!")
|
||||
logger.exception(e)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
125
bridge/agent_event_handler.py
Normal file
125
bridge/agent_event_handler.py
Normal file
@@ -0,0 +1,125 @@
|
||||
"""
|
||||
Agent Event Handler - Handles agent events and thinking process output
|
||||
"""
|
||||
|
||||
from common import const
|
||||
from common.log import logger
|
||||
|
||||
# Cap intermediate thinking messages on weixin to stay within send quota.
|
||||
WEIXIN_THINKING_INSTANT_MAX = 7
|
||||
|
||||
|
||||
class AgentEventHandler:
|
||||
"""
|
||||
Handles agent events and optionally sends intermediate messages to channel
|
||||
"""
|
||||
|
||||
def __init__(self, context=None, original_callback=None):
|
||||
self.context = context
|
||||
self.original_callback = original_callback
|
||||
|
||||
self.channel = None
|
||||
if context:
|
||||
self.channel = context.kwargs.get("channel") if hasattr(context, "kwargs") else None
|
||||
|
||||
self.current_content = ""
|
||||
self.turn_number = 0
|
||||
|
||||
channel_type = ""
|
||||
if context and hasattr(context, "kwargs"):
|
||||
channel_type = context.kwargs.get("channel_type", "") or ""
|
||||
self._is_weixin = channel_type == const.WEIXIN
|
||||
self._thinking_sent_count = 0
|
||||
self._merged_buf: list[str] = []
|
||||
|
||||
def handle_event(self, event):
|
||||
event_type = event.get("type")
|
||||
data = event.get("data", {})
|
||||
|
||||
if event_type == "turn_start":
|
||||
self._handle_turn_start(data)
|
||||
elif event_type == "message_update":
|
||||
self._handle_message_update(data)
|
||||
elif event_type == "message_end":
|
||||
self._handle_message_end(data)
|
||||
elif event_type == "reasoning_update":
|
||||
pass
|
||||
elif event_type == "tool_execution_start":
|
||||
self._handle_tool_execution_start(data)
|
||||
elif event_type == "tool_execution_end":
|
||||
self._handle_tool_execution_end(data)
|
||||
elif event_type == "agent_end":
|
||||
self._handle_agent_end(data)
|
||||
|
||||
if self.original_callback:
|
||||
self.original_callback(event)
|
||||
|
||||
def _handle_turn_start(self, data):
|
||||
self.turn_number = data.get("turn", 0)
|
||||
self.current_content = ""
|
||||
|
||||
def _handle_message_update(self, data):
|
||||
delta = data.get("delta", "")
|
||||
self.current_content += delta
|
||||
|
||||
def _handle_message_end(self, data):
|
||||
tool_calls = data.get("tool_calls", [])
|
||||
|
||||
if tool_calls:
|
||||
if self.current_content.strip():
|
||||
logger.info(f"💭 {self.current_content.strip()[:200]}{'...' if len(self.current_content) > 200 else ''}")
|
||||
self._send_to_channel(self.current_content.strip())
|
||||
else:
|
||||
if self.current_content.strip():
|
||||
logger.debug(f"💬 {self.current_content.strip()[:200]}{'...' if len(self.current_content) > 200 else ''}")
|
||||
# Drain weixin buffer before final reply leaves chat_channel
|
||||
self._flush_merged_now()
|
||||
|
||||
self.current_content = ""
|
||||
|
||||
def _handle_agent_end(self, data):
|
||||
self._flush_merged_now()
|
||||
|
||||
def _handle_tool_execution_start(self, data):
|
||||
pass
|
||||
|
||||
def _handle_tool_execution_end(self, data):
|
||||
pass
|
||||
|
||||
def _send_to_channel(self, message):
|
||||
if self.context and self.context.get("on_event"):
|
||||
return
|
||||
if not self.channel:
|
||||
return
|
||||
|
||||
if not self._is_weixin:
|
||||
self._do_send(message)
|
||||
return
|
||||
|
||||
if self._thinking_sent_count < WEIXIN_THINKING_INSTANT_MAX:
|
||||
self._do_send(message)
|
||||
self._thinking_sent_count += 1
|
||||
return
|
||||
|
||||
self._merged_buf.append(message)
|
||||
|
||||
def _flush_merged_now(self):
|
||||
if not self._merged_buf:
|
||||
return
|
||||
merged = "\n\n".join(self._merged_buf)
|
||||
count = len(self._merged_buf)
|
||||
self._merged_buf = []
|
||||
logger.debug(f"[AgentEventHandler] Flushing {count} merged thinking msgs, len={len(merged)}")
|
||||
self._do_send(merged)
|
||||
self._thinking_sent_count += 1
|
||||
|
||||
def _do_send(self, message):
|
||||
try:
|
||||
from bridge.reply import Reply, ReplyType
|
||||
reply = Reply(ReplyType.TEXT, message)
|
||||
self.channel._send(reply, self.context)
|
||||
except Exception as e:
|
||||
logger.debug(f"[AgentEventHandler] Failed to send to channel: {e}")
|
||||
|
||||
def log_summary(self):
|
||||
pass
|
||||
823
bridge/agent_initializer.py
Normal file
823
bridge/agent_initializer.py
Normal file
@@ -0,0 +1,823 @@
|
||||
"""
|
||||
Agent Initializer - Handles agent initialization logic
|
||||
"""
|
||||
|
||||
import os
|
||||
import asyncio
|
||||
import datetime
|
||||
import threading
|
||||
import time
|
||||
from typing import Optional, List
|
||||
|
||||
from agent.protocol import Agent
|
||||
from agent.tools import ToolManager
|
||||
from common.log import logger
|
||||
from common.utils import expand_path
|
||||
|
||||
# Module-level lock to serialize scheduler init across concurrent sessions
|
||||
_scheduler_init_lock = threading.Lock()
|
||||
|
||||
# Track whether the embedding model log has been printed in this process,
|
||||
# so we avoid spamming it once per session.
|
||||
_embedding_logged: bool = False
|
||||
|
||||
|
||||
class AgentInitializer:
|
||||
"""
|
||||
Handles agent initialization including:
|
||||
- Workspace setup
|
||||
- Memory system initialization
|
||||
- Tool loading
|
||||
- System prompt building
|
||||
"""
|
||||
|
||||
def __init__(self, bridge, agent_bridge):
|
||||
"""
|
||||
Initialize agent initializer
|
||||
|
||||
Args:
|
||||
bridge: COW bridge instance
|
||||
agent_bridge: AgentBridge instance (for create_agent method)
|
||||
"""
|
||||
self.bridge = bridge
|
||||
self.agent_bridge = agent_bridge
|
||||
|
||||
def initialize_agent(self, session_id: Optional[str] = None) -> Agent:
|
||||
"""
|
||||
Initialize agent for a session
|
||||
|
||||
Args:
|
||||
session_id: Session ID (None for default agent)
|
||||
|
||||
Returns:
|
||||
Initialized agent instance
|
||||
"""
|
||||
from config import conf
|
||||
|
||||
# Get workspace from config
|
||||
workspace_root = expand_path(conf().get("agent_workspace", "~/cow"))
|
||||
|
||||
# Migrate API keys
|
||||
self._migrate_config_to_env(workspace_root)
|
||||
|
||||
# Load environment variables
|
||||
self._load_env_file()
|
||||
|
||||
# Initialize workspace
|
||||
from agent.prompt import ensure_workspace, load_context_files, PromptBuilder
|
||||
workspace_files = ensure_workspace(workspace_root, create_templates=True)
|
||||
|
||||
if session_id is None:
|
||||
logger.info(f"[AgentInitializer] Workspace initialized at: {workspace_root}")
|
||||
|
||||
# Setup memory system
|
||||
memory_manager, memory_tools = self._setup_memory_system(workspace_root, session_id)
|
||||
|
||||
# Load tools
|
||||
tools = self._load_tools(workspace_root, memory_manager, memory_tools, session_id)
|
||||
|
||||
# Initialize scheduler if needed
|
||||
self._initialize_scheduler(tools, session_id)
|
||||
|
||||
# Load context files
|
||||
context_files = load_context_files(workspace_root)
|
||||
|
||||
# Initialize skill manager
|
||||
skill_manager = self._initialize_skill_manager(workspace_root, session_id)
|
||||
|
||||
# Build system prompt
|
||||
prompt_builder = PromptBuilder(workspace_dir=workspace_root, language="zh")
|
||||
runtime_info = self._get_runtime_info(workspace_root)
|
||||
|
||||
system_prompt = prompt_builder.build(
|
||||
tools=tools,
|
||||
context_files=context_files,
|
||||
skill_manager=skill_manager,
|
||||
memory_manager=memory_manager,
|
||||
runtime_info=runtime_info,
|
||||
)
|
||||
|
||||
# Get cost control parameters
|
||||
from config import conf
|
||||
max_steps = conf().get("agent_max_steps", 20)
|
||||
max_context_tokens = conf().get("agent_max_context_tokens", 50000)
|
||||
|
||||
# Create agent
|
||||
agent = self.agent_bridge.create_agent(
|
||||
system_prompt=system_prompt,
|
||||
tools=tools,
|
||||
max_steps=max_steps,
|
||||
output_mode="logger",
|
||||
workspace_dir=workspace_root,
|
||||
skill_manager=skill_manager,
|
||||
enable_skills=True,
|
||||
max_context_tokens=max_context_tokens,
|
||||
runtime_info=runtime_info # Pass runtime_info for dynamic time updates
|
||||
)
|
||||
|
||||
# Attach memory manager and share LLM model for summarization
|
||||
if memory_manager:
|
||||
agent.memory_manager = memory_manager
|
||||
if hasattr(agent, 'model') and agent.model:
|
||||
memory_manager.flush_manager.llm_model = agent.model
|
||||
|
||||
# Restore persisted conversation history for this session
|
||||
if session_id:
|
||||
self._restore_conversation_history(agent, session_id)
|
||||
|
||||
# Start daily memory flush timer (once, on first agent init regardless of session)
|
||||
self._start_daily_flush_timer()
|
||||
|
||||
return agent
|
||||
|
||||
def _restore_conversation_history(self, agent, session_id: str) -> None:
|
||||
"""
|
||||
Load persisted conversation messages from SQLite and inject them
|
||||
into the agent's in-memory message list.
|
||||
|
||||
Only user text and assistant text are restored. Tool call chains
|
||||
(tool_use / tool_result) are stripped out because:
|
||||
1. They are intermediate process, the value is already in the final
|
||||
assistant text reply.
|
||||
2. They consume massive context tokens (often 80%+ of history).
|
||||
3. Different models have incompatible tool message formats, so
|
||||
restoring tool chains across model switches causes 400 errors.
|
||||
4. Eliminates the entire class of tool_use/tool_result pairing bugs.
|
||||
"""
|
||||
from config import conf
|
||||
if not conf().get("conversation_persistence", True):
|
||||
return
|
||||
|
||||
try:
|
||||
from agent.memory import get_conversation_store
|
||||
store = get_conversation_store()
|
||||
max_turns = conf().get("agent_max_context_turns", 20)
|
||||
# Scheduler tasks run on a stable isolated session per task and
|
||||
# can fire many times a day; a smaller restore window keeps prompt
|
||||
# cost bounded while still letting the agent see "last few" runs
|
||||
# for trend / dedup style logic. Regular chat sessions keep the
|
||||
# original heuristic so user dialogues feel continuous.
|
||||
if session_id.startswith("scheduler_"):
|
||||
restore_turns = max(1, max_turns // 5)
|
||||
else:
|
||||
restore_turns = max(3, max_turns // 6)
|
||||
saved = store.load_messages(session_id, max_turns=restore_turns)
|
||||
if saved:
|
||||
filtered = self._filter_text_only_messages(saved)
|
||||
if filtered:
|
||||
with agent.messages_lock:
|
||||
agent.messages = filtered
|
||||
logger.debug(
|
||||
f"[AgentInitializer] Restored {len(filtered)} text messages "
|
||||
f"(from {len(saved)} total, {restore_turns} turns cap) "
|
||||
f"for session={session_id}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"[AgentInitializer] Failed to restore conversation history for "
|
||||
f"session={session_id}: {e}"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _filter_text_only_messages(messages: list) -> list:
|
||||
"""
|
||||
Extract clean user/assistant turn pairs from raw message history.
|
||||
|
||||
Groups messages into turns (each starting with a real user query),
|
||||
then keeps only:
|
||||
- The first user text in each turn (the actual user input)
|
||||
- The last assistant text in each turn (the final answer)
|
||||
|
||||
All tool_use, tool_result, intermediate assistant thoughts, and
|
||||
internal hint messages injected by the agent loop are discarded.
|
||||
"""
|
||||
|
||||
def _extract_text(content) -> str:
|
||||
if isinstance(content, str):
|
||||
return content.strip()
|
||||
if isinstance(content, list):
|
||||
parts = [
|
||||
b.get("text", "")
|
||||
for b in content
|
||||
if isinstance(b, dict) and b.get("type") == "text"
|
||||
]
|
||||
return "\n".join(p for p in parts if p).strip()
|
||||
return ""
|
||||
|
||||
def _is_real_user_msg(msg: dict) -> bool:
|
||||
"""True for actual user input, False for tool_result or internal hints."""
|
||||
if msg.get("role") != "user":
|
||||
return False
|
||||
content = msg.get("content")
|
||||
if isinstance(content, list):
|
||||
has_tool_result = any(
|
||||
isinstance(b, dict) and b.get("type") == "tool_result"
|
||||
for b in content
|
||||
)
|
||||
if has_tool_result:
|
||||
return False
|
||||
text = _extract_text(content)
|
||||
return bool(text)
|
||||
|
||||
# Group into turns: each turn starts with a real user message
|
||||
turns = []
|
||||
current_turn = None
|
||||
for msg in messages:
|
||||
if _is_real_user_msg(msg):
|
||||
if current_turn is not None:
|
||||
turns.append(current_turn)
|
||||
current_turn = {"user": msg, "assistants": []}
|
||||
elif current_turn is not None and msg.get("role") == "assistant":
|
||||
text = _extract_text(msg.get("content"))
|
||||
if text:
|
||||
current_turn["assistants"].append(text)
|
||||
if current_turn is not None:
|
||||
turns.append(current_turn)
|
||||
|
||||
# Build result: one user msg + one assistant msg per turn
|
||||
filtered = []
|
||||
for turn in turns:
|
||||
user_text = _extract_text(turn["user"].get("content"))
|
||||
if not user_text:
|
||||
continue
|
||||
filtered.append({
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": user_text}]
|
||||
})
|
||||
if turn["assistants"]:
|
||||
final_reply = turn["assistants"][-1]
|
||||
filtered.append({
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "text": final_reply}]
|
||||
})
|
||||
|
||||
return filtered
|
||||
|
||||
def _load_env_file(self):
|
||||
"""Load environment variables from .env file"""
|
||||
env_file = expand_path("~/.cow/.env")
|
||||
if os.path.exists(env_file):
|
||||
try:
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv(env_file, override=True)
|
||||
except ImportError:
|
||||
logger.warning("[AgentInitializer] python-dotenv not installed")
|
||||
except Exception as e:
|
||||
logger.warning(f"[AgentInitializer] Failed to load .env file: {e}")
|
||||
|
||||
def _setup_memory_system(self, workspace_root: str, session_id: Optional[str] = None):
|
||||
"""
|
||||
Setup memory system
|
||||
|
||||
Returns:
|
||||
(memory_manager, memory_tools) tuple
|
||||
"""
|
||||
memory_manager = None
|
||||
memory_tools = []
|
||||
|
||||
try:
|
||||
from agent.memory import MemoryManager, MemoryConfig
|
||||
from agent.tools import MemorySearchTool, MemoryGetTool
|
||||
from config import conf
|
||||
|
||||
memory_config = MemoryConfig(workspace_root=workspace_root)
|
||||
|
||||
embedding_provider = self._init_embedding_provider(
|
||||
memory_config, session_id=session_id
|
||||
)
|
||||
|
||||
memory_manager = MemoryManager(memory_config, embedding_provider=embedding_provider)
|
||||
self._sync_memory(memory_manager, session_id)
|
||||
|
||||
memory_tools = [
|
||||
MemorySearchTool(memory_manager),
|
||||
MemoryGetTool(memory_manager)
|
||||
]
|
||||
|
||||
if session_id is None:
|
||||
logger.info("[AgentInitializer] Memory system initialized")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"[AgentInitializer] Memory system not available: {e}")
|
||||
|
||||
return memory_manager, memory_tools
|
||||
|
||||
def _init_embedding_provider(self, memory_config, session_id: Optional[str] = None):
|
||||
"""
|
||||
Initialize the embedding provider for memory.
|
||||
|
||||
Two paths:
|
||||
A. Default (no `embedding_provider` in config.json):
|
||||
Auto-init OpenAI -> LinkAI fallback. Existing 1536-dim indices
|
||||
keep working.
|
||||
B. Explicit (`embedding_provider` is set):
|
||||
Initialize the requested vendor with unified dim (default 1024).
|
||||
If the index was built with a different dim, vector search will
|
||||
quietly return no results (cosine returns 0) and keyword search
|
||||
takes over until the user runs /memory rebuild-index.
|
||||
"""
|
||||
from agent.memory import create_embedding_provider
|
||||
from config import conf
|
||||
|
||||
explicit_provider = (conf().get("embedding_provider") or "").strip().lower()
|
||||
|
||||
if not explicit_provider:
|
||||
return self._init_embedding_provider_legacy(session_id=session_id)
|
||||
|
||||
return self._init_embedding_provider_explicit(
|
||||
memory_config, explicit_provider, session_id=session_id,
|
||||
)
|
||||
|
||||
def _init_embedding_provider_legacy(self, session_id: Optional[str] = None):
|
||||
"""Legacy auto-init path: OpenAI -> LinkAI. Preserved verbatim for compat."""
|
||||
from agent.memory import create_embedding_provider
|
||||
from config import conf
|
||||
|
||||
embedding_provider = None
|
||||
embedding_model = None
|
||||
|
||||
openai_api_key = conf().get("open_ai_api_key", "")
|
||||
openai_api_base = conf().get("open_ai_api_base", "")
|
||||
if openai_api_key and openai_api_key not in ["", "YOUR API KEY", "YOUR_API_KEY"]:
|
||||
try:
|
||||
model = "text-embedding-3-small"
|
||||
embedding_provider = create_embedding_provider(
|
||||
provider="openai",
|
||||
model=model,
|
||||
api_key=openai_api_key,
|
||||
api_base=openai_api_base or "https://api.openai.com/v1"
|
||||
)
|
||||
embedding_model = f"openai/{model}"
|
||||
except Exception as e:
|
||||
logger.warning(f"[AgentInitializer] OpenAI embedding failed: {e}")
|
||||
|
||||
if embedding_provider is None:
|
||||
linkai_api_key = conf().get("linkai_api_key", "") or os.environ.get("LINKAI_API_KEY", "")
|
||||
linkai_api_base = conf().get("linkai_api_base", "https://api.link-ai.tech")
|
||||
if linkai_api_key and linkai_api_key not in ["", "YOUR API KEY", "YOUR_API_KEY"]:
|
||||
try:
|
||||
model = "text-embedding-3-small"
|
||||
embedding_provider = create_embedding_provider(
|
||||
provider="linkai",
|
||||
model=model,
|
||||
api_key=linkai_api_key,
|
||||
api_base=f"{linkai_api_base}/v1"
|
||||
)
|
||||
embedding_model = f"linkai/{model}"
|
||||
except Exception as e:
|
||||
logger.warning(f"[AgentInitializer] LinkAI embedding failed: {e}")
|
||||
|
||||
if embedding_provider is not None and embedding_model:
|
||||
global _embedding_logged
|
||||
if not _embedding_logged:
|
||||
logger.info(
|
||||
f"[AgentInitializer] Embedding model in use: {embedding_model} "
|
||||
f"(dim={embedding_provider.dimensions})"
|
||||
)
|
||||
_embedding_logged = True
|
||||
|
||||
return embedding_provider
|
||||
|
||||
def _init_embedding_provider_explicit(
|
||||
self,
|
||||
memory_config,
|
||||
provider_key: str,
|
||||
session_id: Optional[str] = None,
|
||||
):
|
||||
"""Explicit-provider path: build the configured vendor.
|
||||
|
||||
If the index was built with a different dim, vector search will
|
||||
silently return no results (cosine returns 0 for mismatched dims)
|
||||
and keyword search takes over. Users switch vendors by running
|
||||
/memory rebuild-index — see docs.
|
||||
"""
|
||||
from agent.memory import create_embedding_provider
|
||||
from agent.memory.embedding import EMBEDDING_VENDORS
|
||||
from config import conf
|
||||
|
||||
meta = EMBEDDING_VENDORS.get(provider_key)
|
||||
if meta is None:
|
||||
logger.error(
|
||||
f"[AgentInitializer] Unknown embedding_provider '{provider_key}'. "
|
||||
f"Supported: {sorted(EMBEDDING_VENDORS.keys())}. "
|
||||
f"Memory will run in keyword-only mode."
|
||||
)
|
||||
return None
|
||||
|
||||
api_key = self._resolve_embedding_api_key(provider_key)
|
||||
api_base = self._resolve_embedding_api_base(provider_key, meta["default_base_url"])
|
||||
|
||||
if not api_key:
|
||||
logger.error(
|
||||
f"[AgentInitializer] embedding_provider='{provider_key}' is set but its "
|
||||
f"API key is missing. Memory will run in keyword-only mode."
|
||||
)
|
||||
return None
|
||||
|
||||
model = (conf().get("embedding_model") or "").strip() or meta["default_model"]
|
||||
try:
|
||||
cfg_dim = int(conf().get("embedding_dimensions") or 0)
|
||||
except (TypeError, ValueError):
|
||||
cfg_dim = 0
|
||||
dim = cfg_dim if cfg_dim > 0 else meta["default_dimensions"]
|
||||
|
||||
try:
|
||||
provider = create_embedding_provider(
|
||||
provider=provider_key,
|
||||
model=model,
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
dimensions=dim,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[AgentInitializer] Failed to init embedding provider "
|
||||
f"'{provider_key}/{model}': {e}"
|
||||
)
|
||||
return None
|
||||
|
||||
global _embedding_logged
|
||||
if not _embedding_logged:
|
||||
logger.info(
|
||||
f"[AgentInitializer] Embedding model in use: "
|
||||
f"{provider_key}/{model} (dim={provider.dimensions})"
|
||||
)
|
||||
_embedding_logged = True
|
||||
return provider
|
||||
|
||||
@staticmethod
|
||||
def _resolve_embedding_api_key(provider_key: str) -> str:
|
||||
"""Pick the API key for an explicit embedding provider from config."""
|
||||
from config import conf
|
||||
|
||||
key_map = {
|
||||
"openai": "open_ai_api_key",
|
||||
"linkai": "linkai_api_key",
|
||||
"dashscope": "dashscope_api_key",
|
||||
"doubao": "ark_api_key",
|
||||
"zhipu": "zhipu_ai_api_key",
|
||||
}
|
||||
field = key_map.get(provider_key)
|
||||
if not field:
|
||||
return ""
|
||||
value = conf().get(field, "") or ""
|
||||
if value in ["", "YOUR API KEY", "YOUR_API_KEY"]:
|
||||
return ""
|
||||
return value
|
||||
|
||||
@staticmethod
|
||||
def _resolve_embedding_api_base(provider_key: str, default_base: str) -> str:
|
||||
"""Pick the API base for an explicit embedding provider from config."""
|
||||
from config import conf
|
||||
|
||||
base_map = {
|
||||
"openai": "open_ai_api_base",
|
||||
"linkai": "linkai_api_base",
|
||||
"doubao": "ark_base_url",
|
||||
"zhipu": "zhipu_ai_api_base",
|
||||
}
|
||||
field = base_map.get(provider_key)
|
||||
if not field:
|
||||
return default_base
|
||||
value = (conf().get(field) or "").strip()
|
||||
if not value:
|
||||
return default_base
|
||||
if provider_key == "linkai" and not value.rstrip("/").endswith("/v1"):
|
||||
return f"{value.rstrip('/')}/v1"
|
||||
return value
|
||||
|
||||
def _sync_memory(self, memory_manager, session_id: Optional[str] = None):
|
||||
"""Sync memory database"""
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_closed():
|
||||
raise RuntimeError("Event loop is closed")
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
if loop.is_running():
|
||||
asyncio.create_task(memory_manager.sync())
|
||||
else:
|
||||
loop.run_until_complete(memory_manager.sync())
|
||||
except Exception as e:
|
||||
logger.warning(f"[AgentInitializer] Memory sync failed: {e}")
|
||||
|
||||
def _load_tools(self, workspace_root: str, memory_manager, memory_tools: List, session_id: Optional[str] = None):
|
||||
"""Load all tools"""
|
||||
tool_manager = ToolManager()
|
||||
tool_manager.load_tools()
|
||||
|
||||
tools = []
|
||||
file_config = {
|
||||
"cwd": workspace_root,
|
||||
"memory_manager": memory_manager
|
||||
} if memory_manager else {"cwd": workspace_root}
|
||||
|
||||
for tool_name in tool_manager.tool_classes.keys():
|
||||
try:
|
||||
# Skip web_search if no API key is available
|
||||
if tool_name == "web_search":
|
||||
from agent.tools.web_search.web_search import WebSearch
|
||||
if not WebSearch.is_available():
|
||||
logger.debug("[AgentInitializer] WebSearch skipped - no search provider configured")
|
||||
continue
|
||||
|
||||
# Special handling for EnvConfig tool
|
||||
if tool_name == "env_config":
|
||||
from agent.tools import EnvConfig
|
||||
tool = EnvConfig({"agent_bridge": self.agent_bridge})
|
||||
else:
|
||||
tool = tool_manager.create_tool(tool_name)
|
||||
|
||||
if tool:
|
||||
# Apply workspace config to file operation tools.
|
||||
# Merge into the existing tool.config (set by ToolManager from
|
||||
# config.json's `tools.<name>` section) instead of replacing
|
||||
# it, otherwise per-tool user configs (e.g. browser.cdp_endpoint)
|
||||
# would be silently dropped.
|
||||
if tool_name in ['read', 'write', 'edit', 'bash', 'grep', 'find', 'ls', 'web_fetch', 'send', 'browser']:
|
||||
merged_config = dict(getattr(tool, 'config', None) or {})
|
||||
merged_config.update(file_config)
|
||||
tool.config = merged_config
|
||||
tool.cwd = merged_config.get("cwd", getattr(tool, 'cwd', None))
|
||||
if 'memory_manager' in merged_config:
|
||||
tool.memory_manager = merged_config['memory_manager']
|
||||
tools.append(tool)
|
||||
except Exception as e:
|
||||
logger.warning(f"[AgentInitializer] Failed to load tool {tool_name}: {e}")
|
||||
|
||||
# Add MCP tools (snapshot to avoid races with the background loader)
|
||||
mcp_tools_snapshot = list(tool_manager._mcp_tool_instances.items())
|
||||
if mcp_tools_snapshot:
|
||||
for _, mcp_tool in mcp_tools_snapshot:
|
||||
tools.append(mcp_tool)
|
||||
if session_id is None:
|
||||
names = [name for name, _ in mcp_tools_snapshot]
|
||||
logger.info(
|
||||
f"[AgentInitializer] Added {len(names)} MCP tool(s): {names}"
|
||||
)
|
||||
|
||||
# Add memory tools
|
||||
if memory_tools:
|
||||
tools.extend(memory_tools)
|
||||
if session_id is None:
|
||||
logger.info(f"[AgentInitializer] Added {len(memory_tools)} memory tools")
|
||||
|
||||
if session_id is None:
|
||||
logger.info(f"[AgentInitializer] Loaded {len(tools)} tools: {[t.name for t in tools]}")
|
||||
|
||||
return tools
|
||||
|
||||
def _initialize_scheduler(self, tools: List, session_id: Optional[str] = None):
|
||||
"""Initialize scheduler service if needed.
|
||||
|
||||
Serialize the check-and-set under a module-level lock so concurrent
|
||||
first-time session inits cannot each create a new SchedulerService
|
||||
(which would leak background scanning threads).
|
||||
"""
|
||||
if not self.agent_bridge.scheduler_initialized:
|
||||
with _scheduler_init_lock:
|
||||
if not self.agent_bridge.scheduler_initialized:
|
||||
try:
|
||||
from agent.tools.scheduler.integration import init_scheduler
|
||||
if init_scheduler(self.agent_bridge):
|
||||
self.agent_bridge.scheduler_initialized = True
|
||||
if session_id is None:
|
||||
logger.info("[AgentInitializer] Scheduler service initialized")
|
||||
except Exception as e:
|
||||
logger.warning(f"[AgentInitializer] Failed to initialize scheduler: {e}")
|
||||
|
||||
# Inject scheduler dependencies
|
||||
if self.agent_bridge.scheduler_initialized:
|
||||
try:
|
||||
from agent.tools.scheduler.integration import get_task_store, get_scheduler_service
|
||||
from agent.tools import SchedulerTool
|
||||
from config import conf
|
||||
|
||||
task_store = get_task_store()
|
||||
scheduler_service = get_scheduler_service()
|
||||
|
||||
for tool in tools:
|
||||
if isinstance(tool, SchedulerTool):
|
||||
tool.task_store = task_store
|
||||
tool.scheduler_service = scheduler_service
|
||||
if not tool.config:
|
||||
tool.config = {}
|
||||
raw_ct = conf().get("channel_type", "unknown")
|
||||
if isinstance(raw_ct, list):
|
||||
ct = raw_ct[0] if raw_ct else "unknown"
|
||||
elif isinstance(raw_ct, str) and "," in raw_ct:
|
||||
ct = raw_ct.split(",")[0].strip()
|
||||
else:
|
||||
ct = raw_ct
|
||||
tool.config["channel_type"] = ct
|
||||
except Exception as e:
|
||||
logger.warning(f"[AgentInitializer] Failed to inject scheduler dependencies: {e}")
|
||||
|
||||
def _initialize_skill_manager(self, workspace_root: str, session_id: Optional[str] = None):
|
||||
"""Initialize skill manager"""
|
||||
try:
|
||||
from agent.skills import SkillManager
|
||||
skill_manager = SkillManager(custom_dir=os.path.join(workspace_root, "skills"))
|
||||
return skill_manager
|
||||
except Exception as e:
|
||||
logger.warning(f"[AgentInitializer] Failed to initialize SkillManager: {e}")
|
||||
return None
|
||||
|
||||
def _get_runtime_info(self, workspace_root: str):
|
||||
"""Get runtime information with dynamic time support"""
|
||||
from config import conf
|
||||
|
||||
def get_current_time():
|
||||
"""Get current time dynamically - called each time system prompt is accessed"""
|
||||
now = datetime.datetime.now()
|
||||
|
||||
# Get timezone info
|
||||
try:
|
||||
offset = -time.timezone if not time.daylight else -time.altzone
|
||||
hours = offset // 3600
|
||||
minutes = (offset % 3600) // 60
|
||||
timezone_name = f"UTC{hours:+03d}:{minutes:02d}" if minutes else f"UTC{hours:+03d}"
|
||||
except Exception:
|
||||
timezone_name = "UTC"
|
||||
|
||||
# Weekday: English name in en, Chinese mapping otherwise
|
||||
weekday_en = now.strftime("%A")
|
||||
try:
|
||||
from common import i18n
|
||||
is_en = i18n.get_language() == "en"
|
||||
except Exception:
|
||||
is_en = False
|
||||
if is_en:
|
||||
weekday = weekday_en
|
||||
else:
|
||||
weekday_map = {
|
||||
'Monday': '星期一', 'Tuesday': '星期二', 'Wednesday': '星期三',
|
||||
'Thursday': '星期四', 'Friday': '星期五', 'Saturday': '星期六', 'Sunday': '星期日'
|
||||
}
|
||||
weekday = weekday_map.get(weekday_en, weekday_en)
|
||||
|
||||
return {
|
||||
'time': now.strftime("%Y-%m-%d %H:%M:%S"),
|
||||
'weekday': weekday,
|
||||
'timezone': timezone_name
|
||||
}
|
||||
|
||||
def get_model():
|
||||
"""Get current model name dynamically from config"""
|
||||
return conf().get("model", "unknown")
|
||||
|
||||
return {
|
||||
"_get_model": get_model,
|
||||
"workspace": workspace_root,
|
||||
"channel": ", ".join(conf().get("channel_type")) if isinstance(conf().get("channel_type"), list) else conf().get("channel_type", "unknown"),
|
||||
"_get_current_time": get_current_time # Dynamic time function
|
||||
}
|
||||
|
||||
def _migrate_config_to_env(self, workspace_root: str):
|
||||
"""Migrate API keys from config.json to .env file"""
|
||||
from config import conf
|
||||
|
||||
key_mapping = {
|
||||
"open_ai_api_key": "OPENAI_API_KEY",
|
||||
"open_ai_api_base": "OPENAI_API_BASE",
|
||||
"gemini_api_key": "GEMINI_API_KEY",
|
||||
"claude_api_key": "CLAUDE_API_KEY",
|
||||
"linkai_api_key": "LINKAI_API_KEY",
|
||||
}
|
||||
|
||||
env_file = expand_path("~/.cow/.env")
|
||||
|
||||
# Read existing env vars (key -> value)
|
||||
existing_env_vars = {}
|
||||
if os.path.exists(env_file):
|
||||
try:
|
||||
with open(env_file, 'r', encoding='utf-8') as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line and not line.startswith('#') and '=' in line:
|
||||
key, val = line.split('=', 1)
|
||||
existing_env_vars[key.strip()] = val.strip()
|
||||
except Exception as e:
|
||||
logger.warning(f"[AgentInitializer] Failed to read .env file: {e}")
|
||||
|
||||
# Sync config.json values into .env (add/update/remove)
|
||||
updated = False
|
||||
for config_key, env_key in key_mapping.items():
|
||||
raw = conf().get(config_key, "")
|
||||
value = raw.strip() if raw else ""
|
||||
old_value = existing_env_vars.get(env_key)
|
||||
|
||||
if value:
|
||||
if old_value == value:
|
||||
continue
|
||||
existing_env_vars[env_key] = value
|
||||
os.environ[env_key] = value
|
||||
updated = True
|
||||
else:
|
||||
if old_value is None:
|
||||
continue
|
||||
existing_env_vars.pop(env_key, None)
|
||||
os.environ.pop(env_key, None)
|
||||
updated = True
|
||||
|
||||
if updated:
|
||||
try:
|
||||
env_dir = os.path.dirname(env_file)
|
||||
os.makedirs(env_dir, exist_ok=True)
|
||||
|
||||
# Rewrite the entire .env file to ensure consistency
|
||||
with open(env_file, 'w', encoding='utf-8') as f:
|
||||
f.write('# Environment variables for agent\n')
|
||||
f.write('# Auto-managed - synced from config.json on startup\n\n')
|
||||
for key, value in sorted(existing_env_vars.items()):
|
||||
f.write(f'{key}={value}\n')
|
||||
|
||||
logger.info(f"[AgentInitializer] Synced API keys from config.json to .env")
|
||||
except Exception as e:
|
||||
logger.warning(f"[AgentInitializer] Failed to sync API keys: {e}")
|
||||
|
||||
def _start_daily_flush_timer(self):
|
||||
"""Start a background thread that flushes all agents' memory daily at 23:55."""
|
||||
if getattr(self.agent_bridge, '_daily_flush_started', False):
|
||||
return
|
||||
self.agent_bridge._daily_flush_started = True
|
||||
|
||||
import threading
|
||||
|
||||
def _daily_flush_loop():
|
||||
import random
|
||||
last_run_date = None # Track last successful run date to prevent same-day re-trigger
|
||||
while True:
|
||||
try:
|
||||
now = datetime.datetime.now()
|
||||
jitter_min = random.randint(50, 55)
|
||||
jitter_sec = random.randint(0, 59)
|
||||
target = now.replace(hour=23, minute=jitter_min, second=jitter_sec, microsecond=0)
|
||||
# Always schedule for tomorrow if we already ran today, or if target time has passed
|
||||
if target <= now or (last_run_date == now.date()):
|
||||
target += datetime.timedelta(days=1)
|
||||
wait_seconds = (target - now).total_seconds()
|
||||
logger.info(f"[DailyFlush] Next flush at {target.strftime('%Y-%m-%d %H:%M:%S')} (in {wait_seconds/3600:.1f}h)")
|
||||
time.sleep(wait_seconds)
|
||||
|
||||
self._flush_all_agents()
|
||||
last_run_date = datetime.datetime.now().date()
|
||||
except Exception as e:
|
||||
logger.warning(f"[DailyFlush] Error in daily flush loop: {e}")
|
||||
time.sleep(3600)
|
||||
|
||||
t = threading.Thread(target=_daily_flush_loop, daemon=True)
|
||||
t.start()
|
||||
|
||||
def _flush_all_agents(self):
|
||||
"""Flush memory for all active agent sessions, then run Deep Dream."""
|
||||
agents = []
|
||||
if self.agent_bridge.default_agent:
|
||||
agents.append(("default", self.agent_bridge.default_agent))
|
||||
for sid, agent in self.agent_bridge.agents.items():
|
||||
agents.append((sid, agent))
|
||||
|
||||
if not agents:
|
||||
return
|
||||
|
||||
# Phase 1: flush daily summaries
|
||||
flushed = 0
|
||||
flush_threads = []
|
||||
dream_candidate = None
|
||||
for label, agent in agents:
|
||||
try:
|
||||
if not agent.memory_manager:
|
||||
continue
|
||||
with agent.messages_lock:
|
||||
messages = list(agent.messages)
|
||||
if not messages:
|
||||
continue
|
||||
result = agent.memory_manager.flush_manager.create_daily_summary(messages)
|
||||
if result:
|
||||
flushed += 1
|
||||
t = agent.memory_manager.flush_manager._last_flush_thread
|
||||
if t:
|
||||
flush_threads.append(t)
|
||||
if dream_candidate is None:
|
||||
dream_candidate = agent.memory_manager.flush_manager
|
||||
except Exception as e:
|
||||
logger.warning(f"[DailyFlush] Failed for session {label}: {e}")
|
||||
|
||||
if flushed:
|
||||
logger.info(f"[DailyFlush] Flushed {flushed}/{len(agents)} agent session(s)")
|
||||
|
||||
# Wait for all flush threads to finish before dreaming
|
||||
for t in flush_threads:
|
||||
t.join(timeout=60)
|
||||
|
||||
# Phase 2: Deep Dream — distill daily memories → MEMORY.md + dream diary
|
||||
if dream_candidate:
|
||||
try:
|
||||
result = dream_candidate.deep_dream()
|
||||
if result:
|
||||
logger.info("[DeepDream] Memory distillation completed successfully")
|
||||
except Exception as e:
|
||||
logger.warning(f"[DeepDream] Failed: {e}")
|
||||
@@ -13,8 +13,10 @@ from voice.factory import create_voice
|
||||
class Bridge(object):
|
||||
def __init__(self):
|
||||
self.btype = {
|
||||
"chat": const.CHATGPT,
|
||||
"voice_to_text": conf().get("voice_to_text", "openai"),
|
||||
"chat": const.OPENAI,
|
||||
# Empty `voice_to_text` (the default in new configs) triggers
|
||||
# the auto-pick below — see _auto_pick_voice_to_text for order.
|
||||
"voice_to_text": conf().get("voice_to_text") or self._auto_pick_voice_to_text(),
|
||||
"text_to_voice": conf().get("text_to_voice", "google"),
|
||||
"translate": conf().get("translate", "baidu"),
|
||||
}
|
||||
@@ -24,6 +26,13 @@ class Bridge(object):
|
||||
self.btype["chat"] = bot_type
|
||||
else:
|
||||
model_type = conf().get("model") or const.GPT_41_MINI
|
||||
|
||||
# Ensure model_type is string to prevent AttributeError when using startswith()
|
||||
# This handles cases where numeric model names (e.g., "1") are parsed as integers from YAML
|
||||
if not isinstance(model_type, str):
|
||||
logger.warning(f"[Bridge] model_type is not a string: {model_type} (type: {type(model_type).__name__}), converting to string")
|
||||
model_type = str(model_type)
|
||||
|
||||
if model_type in ["text-davinci-003"]:
|
||||
self.btype["chat"] = const.OPEN_AI
|
||||
if conf().get("use_azure_chatgpt", False):
|
||||
@@ -32,9 +41,9 @@ class Bridge(object):
|
||||
self.btype["chat"] = const.BAIDU
|
||||
if model_type in ["xunfei"]:
|
||||
self.btype["chat"] = const.XUNFEI
|
||||
if model_type in [const.QWEN]:
|
||||
self.btype["chat"] = const.QWEN
|
||||
if model_type in [const.QWEN_TURBO, const.QWEN_PLUS, const.QWEN_MAX]:
|
||||
if model_type in [const.QWEN, const.QWEN_TURBO, const.QWEN_PLUS, const.QWEN_MAX]:
|
||||
self.btype["chat"] = const.QWEN_DASHSCOPE
|
||||
if model_type and (model_type.startswith("qwen") or model_type.startswith("qwq") or model_type.startswith("qvq")):
|
||||
self.btype["chat"] = const.QWEN_DASHSCOPE
|
||||
if model_type and model_type.startswith("gemini"):
|
||||
self.btype["chat"] = const.GEMINI
|
||||
@@ -43,16 +52,31 @@ class Bridge(object):
|
||||
if model_type and model_type.startswith("claude"):
|
||||
self.btype["chat"] = const.CLAUDEAPI
|
||||
|
||||
if model_type in ["claude"]:
|
||||
self.btype["chat"] = const.CLAUDEAI
|
||||
|
||||
if model_type in [const.MOONSHOT, "moonshot-v1-8k", "moonshot-v1-32k", "moonshot-v1-128k"]:
|
||||
self.btype["chat"] = const.MOONSHOT
|
||||
if model_type and model_type.startswith("kimi"):
|
||||
self.btype["chat"] = const.MOONSHOT
|
||||
|
||||
if model_type and model_type.startswith("doubao"):
|
||||
self.btype["chat"] = const.DOUBAO
|
||||
|
||||
if model_type and model_type.startswith("deepseek"):
|
||||
self.btype["chat"] = const.DEEPSEEK
|
||||
|
||||
# 小米 MiMo 系列模型,全部以 mimo- 开头
|
||||
if model_type and model_type.startswith("mimo-"):
|
||||
self.btype["chat"] = const.MIMO
|
||||
|
||||
if model_type and isinstance(model_type, str):
|
||||
lowered_model_type = model_type.lower()
|
||||
if lowered_model_type == const.QIANFAN or lowered_model_type.startswith("ernie"):
|
||||
self.btype["chat"] = const.QIANFAN
|
||||
|
||||
if model_type in [const.MODELSCOPE]:
|
||||
self.btype["chat"] = const.MODELSCOPE
|
||||
|
||||
if model_type in ["abab6.5-chat"]:
|
||||
# MiniMax models
|
||||
if model_type and (model_type in ["abab6.5-chat", "abab6.5"] or model_type.lower().startswith("minimax")):
|
||||
self.btype["chat"] = const.MiniMax
|
||||
|
||||
if conf().get("use_linkai") and conf().get("linkai_api_key"):
|
||||
@@ -66,6 +90,46 @@ class Bridge(object):
|
||||
self.chat_bots = {}
|
||||
self._agent_bridge = None
|
||||
|
||||
def refresh_voice(self):
|
||||
"""Re-read voice_to_text / text_to_voice from config and drop the
|
||||
cached voice bots so the next call picks up the new provider.
|
||||
Used by the web console after the user edits voice settings.
|
||||
Does NOT touch the agent_bridge / agent state.
|
||||
"""
|
||||
new_v2t = conf().get("voice_to_text") or self._auto_pick_voice_to_text()
|
||||
new_t2v = conf().get("text_to_voice", "google")
|
||||
if conf().get("use_linkai") and conf().get("linkai_api_key"):
|
||||
if not conf().get("voice_to_text") or conf().get("voice_to_text") in ["openai"]:
|
||||
new_v2t = const.LINKAI
|
||||
if not conf().get("text_to_voice") or conf().get("text_to_voice") in ["openai", const.TTS_1, const.TTS_1_HD]:
|
||||
new_t2v = const.LINKAI
|
||||
self.btype["voice_to_text"] = new_v2t
|
||||
self.btype["text_to_voice"] = new_t2v
|
||||
self.bots.pop("voice_to_text", None)
|
||||
self.bots.pop("text_to_voice", None)
|
||||
logger.info(f"[Bridge] voice refreshed: voice_to_text={new_v2t}, text_to_voice={new_t2v}")
|
||||
|
||||
@staticmethod
|
||||
def _auto_pick_voice_to_text() -> str:
|
||||
"""Pick an ASR provider by configured api keys when voice_to_text is
|
||||
unset. Order matches the web console: openai → dashscope → zhipu →
|
||||
linkai. Falls back to 'openai' when nothing is configured so the
|
||||
original "missing key" error is preserved.
|
||||
"""
|
||||
def has(k: str) -> bool:
|
||||
v = (conf().get(k) or "").strip()
|
||||
return v != "" and v not in ("YOUR API KEY", "YOUR_API_KEY")
|
||||
|
||||
for key, provider in (
|
||||
("open_ai_api_key", "openai"),
|
||||
("dashscope_api_key", "dashscope"),
|
||||
("zhipu_ai_api_key", "zhipu"),
|
||||
("linkai_api_key", "linkai"),
|
||||
):
|
||||
if has(key):
|
||||
return provider
|
||||
return "openai"
|
||||
|
||||
# 模型对应的接口
|
||||
def get_bot(self, typename):
|
||||
if self.bots.get(typename) is None:
|
||||
|
||||
@@ -13,12 +13,44 @@ class Channel(object):
|
||||
channel_type = ""
|
||||
NOT_SUPPORT_REPLYTYPE = [ReplyType.VOICE, ReplyType.IMAGE]
|
||||
|
||||
def __init__(self):
|
||||
import threading
|
||||
self._startup_event = threading.Event()
|
||||
self._startup_error = None
|
||||
self.cloud_mode = False # set to True by ChannelManager when running with cloud client
|
||||
|
||||
def startup(self):
|
||||
"""
|
||||
init channel
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def report_startup_success(self):
|
||||
self._startup_error = None
|
||||
self._startup_event.set()
|
||||
|
||||
def report_startup_error(self, error: str):
|
||||
self._startup_error = error
|
||||
self._startup_event.set()
|
||||
|
||||
def wait_startup(self, timeout: float = 3) -> (bool, str):
|
||||
"""
|
||||
Wait for channel startup result.
|
||||
Returns (success: bool, error_msg: str).
|
||||
"""
|
||||
ready = self._startup_event.wait(timeout=timeout)
|
||||
if not ready:
|
||||
return True, ""
|
||||
if self._startup_error:
|
||||
return False, self._startup_error
|
||||
return True, ""
|
||||
|
||||
def stop(self):
|
||||
"""
|
||||
stop channel gracefully, called before restart
|
||||
"""
|
||||
pass
|
||||
|
||||
def handle_text(self, msg):
|
||||
"""
|
||||
process received msg
|
||||
@@ -41,7 +73,7 @@ class Channel(object):
|
||||
Build reply content, using agent if enabled in config
|
||||
"""
|
||||
# Check if agent mode is enabled
|
||||
use_agent = conf().get("agent", False)
|
||||
use_agent = conf().get("agent", True)
|
||||
|
||||
if use_agent:
|
||||
try:
|
||||
@@ -51,11 +83,14 @@ class Channel(object):
|
||||
if context and "channel_type" not in context:
|
||||
context["channel_type"] = self.channel_type
|
||||
|
||||
# Read on_event callback injected by the channel (e.g. web SSE)
|
||||
on_event = context.get("on_event") if context else None
|
||||
|
||||
# Use agent bridge to handle the query
|
||||
return Bridge().fetch_agent_reply(
|
||||
query=query,
|
||||
context=context,
|
||||
on_event=None,
|
||||
on_event=on_event,
|
||||
clear_history=False
|
||||
)
|
||||
except Exception as e:
|
||||
|
||||
@@ -12,16 +12,7 @@ def create_channel(channel_type) -> Channel:
|
||||
:return: channel instance
|
||||
"""
|
||||
ch = Channel()
|
||||
if channel_type == "wx":
|
||||
from channel.wechat.wechat_channel import WechatChannel
|
||||
ch = WechatChannel()
|
||||
elif channel_type == "wxy":
|
||||
from channel.wechat.wechaty_channel import WechatyChannel
|
||||
ch = WechatyChannel()
|
||||
elif channel_type == "wcf":
|
||||
from channel.wechat.wcf_channel import WechatfChannel
|
||||
ch = WechatfChannel()
|
||||
elif channel_type == "terminal":
|
||||
if channel_type == "terminal":
|
||||
from channel.terminal.terminal_channel import TerminalChannel
|
||||
ch = TerminalChannel()
|
||||
elif channel_type == 'web':
|
||||
@@ -36,15 +27,34 @@ def create_channel(channel_type) -> Channel:
|
||||
elif channel_type == "wechatcom_app":
|
||||
from channel.wechatcom.wechatcomapp_channel import WechatComAppChannel
|
||||
ch = WechatComAppChannel()
|
||||
elif channel_type == "wework":
|
||||
from channel.wework.wework_channel import WeworkChannel
|
||||
ch = WeworkChannel()
|
||||
elif channel_type == const.WECHAT_KF:
|
||||
from channel.wechat_kf.wechat_kf_channel import WechatKfChannel
|
||||
ch = WechatKfChannel()
|
||||
elif channel_type == const.FEISHU:
|
||||
from channel.feishu.feishu_channel import FeiShuChanel
|
||||
ch = FeiShuChanel()
|
||||
elif channel_type == const.DINGTALK:
|
||||
from channel.dingtalk.dingtalk_channel import DingTalkChanel
|
||||
ch = DingTalkChanel()
|
||||
elif channel_type == const.WECOM_BOT:
|
||||
from channel.wecom_bot.wecom_bot_channel import WecomBotChannel
|
||||
ch = WecomBotChannel()
|
||||
elif channel_type == const.QQ:
|
||||
from channel.qq.qq_channel import QQChannel
|
||||
ch = QQChannel()
|
||||
elif channel_type == const.TELEGRAM:
|
||||
from channel.telegram.telegram_channel import TelegramChannel
|
||||
ch = TelegramChannel()
|
||||
elif channel_type == const.SLACK:
|
||||
from channel.slack.slack_channel import SlackChannel
|
||||
ch = SlackChannel()
|
||||
elif channel_type == const.DISCORD:
|
||||
from channel.discord.discord_channel import DiscordChannel
|
||||
ch = DiscordChannel()
|
||||
elif channel_type in (const.WEIXIN, "wx"):
|
||||
from channel.weixin.weixin_channel import WeixinChannel
|
||||
ch = WeixinChannel()
|
||||
channel_type = const.WEIXIN
|
||||
else:
|
||||
raise RuntimeError
|
||||
ch.channel_type = channel_type
|
||||
|
||||
@@ -10,6 +10,7 @@ from bridge.reply import *
|
||||
from channel.channel import Channel
|
||||
from common.dequeue import Dequeue
|
||||
from common import memory
|
||||
from common.i18n import t as _t
|
||||
from plugins import *
|
||||
|
||||
try:
|
||||
@@ -24,11 +25,17 @@ handler_pool = ThreadPoolExecutor(max_workers=8) # 处理消息的线程池
|
||||
class ChatChannel(Channel):
|
||||
name = None # 登录的用户名
|
||||
user_id = None # 登录的用户id
|
||||
futures = {} # 记录每个session_id提交到线程池的future对象, 用于重置会话时把没执行的future取消掉,正在执行的不会被取消
|
||||
sessions = {} # 用于控制并发,每个session_id同时只能有一个context在处理
|
||||
lock = threading.Lock() # 用于控制对sessions的访问
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
# Instance-level attributes so each channel subclass has its own
|
||||
# independent session queue and lock. Previously these were class-level,
|
||||
# which caused contexts from one channel (e.g. Feishu) to be consumed
|
||||
# by another channel's consume() thread (e.g. Web), leading to errors
|
||||
# like "No request_id found in context".
|
||||
self.futures = {}
|
||||
self.sessions = {}
|
||||
self.lock = threading.Lock()
|
||||
_thread = threading.Thread(target=self.consume)
|
||||
_thread.setDaemon(True)
|
||||
_thread.start()
|
||||
@@ -37,9 +44,8 @@ class ChatChannel(Channel):
|
||||
def _compose_context(self, ctype: ContextType, content, **kwargs):
|
||||
context = Context(ctype, content)
|
||||
context.kwargs = kwargs
|
||||
# context首次传入时,origin_ctype是None,
|
||||
# 引入的起因是:当输入语音时,会嵌套生成两个context,第一步语音转文本,第二步通过文本生成文字回复。
|
||||
# origin_ctype用于第二步文本回复时,判断是否需要匹配前缀,如果是私聊的语音,就不需要匹配前缀
|
||||
if "channel_type" not in context:
|
||||
context["channel_type"] = self.channel_type
|
||||
if "origin_ctype" not in context:
|
||||
context["origin_ctype"] = ctype
|
||||
# context首次传入时,receiver是None,根据类型设置receiver
|
||||
@@ -166,7 +172,13 @@ class ChatChannel(Channel):
|
||||
if "desire_rtype" not in context and conf().get("always_reply_voice") and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE:
|
||||
context["desire_rtype"] = ReplyType.VOICE
|
||||
elif context.type == ContextType.VOICE:
|
||||
if "desire_rtype" not in context and conf().get("voice_reply_voice") and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE:
|
||||
# Voice input replies with voice when either voice_reply_voice
|
||||
# (mirror voice) or the global always_reply_voice toggle is on.
|
||||
if (
|
||||
"desire_rtype" not in context
|
||||
and (conf().get("voice_reply_voice") or conf().get("always_reply_voice"))
|
||||
and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE
|
||||
):
|
||||
context["desire_rtype"] = ReplyType.VOICE
|
||||
return context
|
||||
|
||||
@@ -254,11 +266,13 @@ class ChatChannel(Channel):
|
||||
if reply.type in self.NOT_SUPPORT_REPLYTYPE:
|
||||
logger.error("[chat_channel]reply type not support: " + str(reply.type))
|
||||
reply.type = ReplyType.ERROR
|
||||
reply.content = "不支持发送的消息类型: " + str(reply.type)
|
||||
reply.content = _t("不支持发送的消息类型: ", "Unsupported message type: ") + str(reply.type)
|
||||
|
||||
if reply.type == ReplyType.TEXT:
|
||||
reply_text = reply.content
|
||||
if desire_rtype == ReplyType.VOICE and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE:
|
||||
# Preserve original text for the "text-then-voice" pattern in _send_reply.
|
||||
context["voice_reply_text"] = reply.content
|
||||
reply = super().build_text_to_voice(reply.content)
|
||||
return self._decorate_reply(context, reply)
|
||||
if context.get("isgroup", False):
|
||||
@@ -292,8 +306,12 @@ class ChatChannel(Channel):
|
||||
logger.debug("[chat_channel] sending reply: {}, context: {}".format(reply, context))
|
||||
|
||||
# 如果是文本回复,尝试提取并发送图片
|
||||
if reply.type == ReplyType.TEXT:
|
||||
# Web channel renders images/videos inline via renderMarkdown,
|
||||
# so skip the extract-and-send step to avoid duplicate media.
|
||||
if reply.type == ReplyType.TEXT and context.get("channel_type") != "web":
|
||||
self._extract_and_send_images(reply, context)
|
||||
elif reply.type == ReplyType.TEXT:
|
||||
self._send(reply, context)
|
||||
# 如果是图片回复但带有文本内容,先发文本再发图片
|
||||
elif reply.type == ReplyType.IMAGE_URL and hasattr(reply, 'text_content') and reply.text_content:
|
||||
# 先发送文本
|
||||
@@ -302,6 +320,15 @@ class ChatChannel(Channel):
|
||||
# 短暂延迟后发送图片
|
||||
time.sleep(0.3)
|
||||
self._send(reply, context)
|
||||
# Send text bubble before voice, unless channel already streamed
|
||||
# the text (feishu) or natively renders STT under the voice (wechatcom).
|
||||
elif reply.type == ReplyType.VOICE and context.get("voice_reply_text") \
|
||||
and not context.get("feishu_streamed") \
|
||||
and context.get("channel_type") not in ("wechatcom_app",):
|
||||
text_reply = Reply(ReplyType.TEXT, context.get("voice_reply_text"))
|
||||
self._send(text_reply, context)
|
||||
time.sleep(0.3)
|
||||
self._send(reply, context)
|
||||
else:
|
||||
self._send(reply, context)
|
||||
|
||||
@@ -342,38 +369,30 @@ class ChatChannel(Channel):
|
||||
if media_items:
|
||||
logger.info(f"[chat_channel] Extracted {len(media_items)} media item(s) from reply")
|
||||
|
||||
# 先发送文本(保持原文本不变)
|
||||
# Send text first (the frontend will embed video players via renderMarkdown).
|
||||
logger.info(f"[chat_channel] Sending text content before media: {reply.content[:100]}...")
|
||||
self._send(reply, context)
|
||||
logger.info(f"[chat_channel] Text sent, now sending {len(media_items)} media item(s)")
|
||||
|
||||
# 然后逐个发送媒体文件
|
||||
for i, (url, media_type) in enumerate(media_items):
|
||||
try:
|
||||
# 判断是本地文件还是URL
|
||||
# Determine whether it is a remote URL or a local file.
|
||||
if url.startswith(('http://', 'https://')):
|
||||
# 网络资源
|
||||
if media_type == 'video':
|
||||
# 视频使用 FILE 类型发送
|
||||
media_reply = Reply(ReplyType.FILE, url)
|
||||
media_reply.file_name = os.path.basename(url)
|
||||
else:
|
||||
# 图片使用 IMAGE_URL 类型
|
||||
media_reply = Reply(ReplyType.IMAGE_URL, url)
|
||||
elif os.path.exists(url):
|
||||
# 本地文件
|
||||
if media_type == 'video':
|
||||
# 视频使用 FILE 类型,转换为 file:// URL
|
||||
media_reply = Reply(ReplyType.FILE, f"file://{url}")
|
||||
media_reply.file_name = os.path.basename(url)
|
||||
else:
|
||||
# 图片使用 IMAGE_URL 类型,转换为 file:// URL
|
||||
media_reply = Reply(ReplyType.IMAGE_URL, f"file://{url}")
|
||||
else:
|
||||
logger.warning(f"[chat_channel] Media file not found or invalid URL: {url}")
|
||||
continue
|
||||
|
||||
# 发送媒体文件(添加小延迟避免频率限制)
|
||||
if i > 0:
|
||||
time.sleep(0.5)
|
||||
self._send(media_reply, context)
|
||||
@@ -420,19 +439,55 @@ class ChatChannel(Channel):
|
||||
|
||||
return func
|
||||
|
||||
# Chat commands that must bypass the per-session serial queue,
|
||||
# otherwise /cancel would queue behind the task it tries to cancel.
|
||||
# Use /cancel (not /stop) to avoid colliding with `cow stop` CLI.
|
||||
_BYPASS_QUEUE_COMMANDS = ("/cancel",)
|
||||
|
||||
def produce(self, context: Context):
|
||||
session_id = context["session_id"]
|
||||
|
||||
# Fast path: /cancel must not enter the queue.
|
||||
if context.type == ContextType.TEXT and context.content:
|
||||
stripped = context.content.strip().lower()
|
||||
if stripped in self._BYPASS_QUEUE_COMMANDS:
|
||||
self._handle_cancel_command(context, session_id)
|
||||
return
|
||||
|
||||
with self.lock:
|
||||
if session_id not in self.sessions:
|
||||
self.sessions[session_id] = [
|
||||
Dequeue(),
|
||||
threading.BoundedSemaphore(conf().get("concurrency_in_session", 4)),
|
||||
threading.BoundedSemaphore(conf().get("concurrency_in_session", 1)),
|
||||
]
|
||||
if context.type == ContextType.TEXT and context.content.startswith("#"):
|
||||
self.sessions[session_id][0].putleft(context) # 优先处理管理命令
|
||||
else:
|
||||
self.sessions[session_id][0].put(context)
|
||||
|
||||
def _handle_cancel_command(self, context: Context, session_id: str) -> None:
|
||||
"""Cancel any in-flight agent run for *session_id* and reply inline.
|
||||
|
||||
Runs synchronously on the caller's thread. Reply is sent through
|
||||
_send_reply so plugins (e.g. logging) still observe it.
|
||||
"""
|
||||
try:
|
||||
from agent.protocol import get_cancel_registry
|
||||
from bridge.reply import Reply, ReplyType
|
||||
|
||||
cancelled = get_cancel_registry().cancel_session(session_id)
|
||||
text = (
|
||||
_t("🛑 已中止", "🛑 Cancelled")
|
||||
if cancelled > 0
|
||||
else _t("当前没有可中止的任务。", "Nothing to cancel.")
|
||||
)
|
||||
logger.info(
|
||||
f"[chat_channel] /cancel fast-path: session={session_id}, cancelled={cancelled}"
|
||||
)
|
||||
self._send_reply(context, Reply(ReplyType.TEXT, text))
|
||||
except Exception as e:
|
||||
logger.warning(f"[chat_channel] /cancel fast-path failed: {e}")
|
||||
|
||||
# 消费者函数,单独线程,用于从消息队列中取出消息并处理
|
||||
def consume(self):
|
||||
while True:
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
"""
|
||||
本类表示聊天消息,用于对itchat和wechaty的消息进行统一的封装。
|
||||
Unified chat message class for different channel implementations.
|
||||
|
||||
填好必填项(群聊6个,非群聊8个),即可接入ChatChannel,并支持插件,参考TerminalChannel
|
||||
|
||||
|
||||
@@ -21,6 +21,7 @@ from dingtalk_stream.card_replier import CardReplier
|
||||
from bridge.context import Context, ContextType
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from channel.chat_channel import ChatChannel
|
||||
from common.utils import expand_path
|
||||
from channel.dingtalk.dingtalk_message import DingTalkMessage
|
||||
from common.expired_dict import ExpiredDict
|
||||
from common.log import logger
|
||||
@@ -85,17 +86,15 @@ def _check(func):
|
||||
|
||||
@singleton
|
||||
class DingTalkChanel(ChatChannel, dingtalk_stream.ChatbotHandler):
|
||||
NOT_SUPPORT_REPLYTYPE = []
|
||||
|
||||
dingtalk_client_id = conf().get('dingtalk_client_id')
|
||||
dingtalk_client_secret = conf().get('dingtalk_client_secret')
|
||||
|
||||
def setup_logger(self):
|
||||
logger = logging.getLogger()
|
||||
handler = logging.StreamHandler()
|
||||
handler.setFormatter(
|
||||
logging.Formatter('%(asctime)s %(name)-8s %(levelname)-8s %(message)s [%(filename)s:%(lineno)d]'))
|
||||
logger.addHandler(handler)
|
||||
logger.setLevel(logging.INFO)
|
||||
return logger
|
||||
# Suppress verbose logs from dingtalk_stream SDK
|
||||
logging.getLogger("dingtalk_stream").setLevel(logging.WARNING)
|
||||
return logging.getLogger("DingTalk")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
@@ -103,6 +102,9 @@ class DingTalkChanel(ChatChannel, dingtalk_stream.ChatbotHandler):
|
||||
self.logger = self.setup_logger()
|
||||
# 历史消息id暂存,用于幂等控制
|
||||
self.receivedMsgs = ExpiredDict(conf().get("expires_in_seconds", 3600))
|
||||
self._stream_client = None
|
||||
self._running = False
|
||||
self._event_loop = None
|
||||
logger.debug("[DingTalk] client_id={}, client_secret={} ".format(
|
||||
self.dingtalk_client_id, self.dingtalk_client_secret))
|
||||
# 无需群校验和前缀
|
||||
@@ -115,12 +117,130 @@ class DingTalkChanel(ChatChannel, dingtalk_stream.ChatbotHandler):
|
||||
# Robot code cache (extracted from incoming messages)
|
||||
self._robot_code = None
|
||||
|
||||
def _open_connection(self, client):
|
||||
"""
|
||||
Open a DingTalk stream connection directly, bypassing SDK's internal error-swallowing.
|
||||
Returns (connection_dict, error_str). On success error_str is empty; on failure
|
||||
connection_dict is None and error_str contains a human-readable message.
|
||||
"""
|
||||
try:
|
||||
resp = requests.post(
|
||||
"https://api.dingtalk.com/v1.0/gateway/connections/open",
|
||||
headers={"Content-Type": "application/json", "Accept": "application/json"},
|
||||
json={
|
||||
"clientId": client.credential.client_id,
|
||||
"clientSecret": client.credential.client_secret,
|
||||
"subscriptions": [{"type": "CALLBACK",
|
||||
"topic": dingtalk_stream.chatbot.ChatbotMessage.TOPIC}],
|
||||
"ua": "dingtalk-sdk-python/cow",
|
||||
"localIp": "",
|
||||
},
|
||||
timeout=10,
|
||||
)
|
||||
body = resp.json()
|
||||
if not resp.ok:
|
||||
code = body.get("code", resp.status_code)
|
||||
message = body.get("message", resp.reason)
|
||||
return None, f"open connection failed: [{code}] {message}"
|
||||
return body, ""
|
||||
except Exception as e:
|
||||
return None, f"open connection failed: {e}"
|
||||
|
||||
def startup(self):
|
||||
import asyncio
|
||||
self.dingtalk_client_id = conf().get('dingtalk_client_id')
|
||||
self.dingtalk_client_secret = conf().get('dingtalk_client_secret')
|
||||
self._running = True
|
||||
credential = dingtalk_stream.Credential(self.dingtalk_client_id, self.dingtalk_client_secret)
|
||||
client = dingtalk_stream.DingTalkStreamClient(credential)
|
||||
self._stream_client = client
|
||||
client.register_callback_handler(dingtalk_stream.chatbot.ChatbotMessage.TOPIC, self)
|
||||
logger.info("[DingTalk] ✅ Stream connected, ready to receive messages")
|
||||
client.start_forever()
|
||||
logger.info("[DingTalk] ✅ Stream client initialized, ready to receive messages")
|
||||
|
||||
# Run the connection loop ourselves instead of delegating to client.start(),
|
||||
# so we can get detailed error messages and respond to stop() quickly.
|
||||
import urllib.parse as _urlparse
|
||||
import websockets as _ws
|
||||
import json as _json
|
||||
client.pre_start()
|
||||
_first_connect = True
|
||||
while self._running:
|
||||
# Open connection using our own request so we get detailed error info.
|
||||
connection, err_msg = self._open_connection(client)
|
||||
|
||||
if connection is None:
|
||||
if _first_connect:
|
||||
logger.warning(f"[DingTalk] {err_msg}")
|
||||
self.report_startup_error(err_msg)
|
||||
_first_connect = False
|
||||
else:
|
||||
logger.warning(f"[DingTalk] {err_msg}, retrying in 10s...")
|
||||
|
||||
# Interruptible sleep: checks _running every 100ms.
|
||||
for _ in range(100):
|
||||
if not self._running:
|
||||
break
|
||||
time.sleep(0.1)
|
||||
continue
|
||||
|
||||
if _first_connect:
|
||||
logger.info("[DingTalk] ✅ Connected to DingTalk stream")
|
||||
self.report_startup_success()
|
||||
_first_connect = False
|
||||
else:
|
||||
logger.info("[DingTalk] Reconnected to DingTalk stream")
|
||||
|
||||
# Run the WebSocket session in an asyncio loop.
|
||||
uri = '%s?ticket=%s' % (
|
||||
connection['endpoint'],
|
||||
_urlparse.quote_plus(connection['ticket'])
|
||||
)
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
self._event_loop = loop
|
||||
try:
|
||||
async def _session():
|
||||
async with _ws.connect(uri) as websocket:
|
||||
client.websocket = websocket
|
||||
async for raw_message in websocket:
|
||||
json_message = _json.loads(raw_message)
|
||||
result = await client.route_message(json_message)
|
||||
if result == dingtalk_stream.DingTalkStreamClient.TAG_DISCONNECT:
|
||||
break
|
||||
|
||||
loop.run_until_complete(_session())
|
||||
except (KeyboardInterrupt, SystemExit):
|
||||
logger.info("[DingTalk] Session loop received stop signal, exiting")
|
||||
break
|
||||
except Exception as e:
|
||||
if not self._running:
|
||||
break
|
||||
logger.warning(f"[DingTalk] Stream session error: {e}, reconnecting in 3s...")
|
||||
for _ in range(30):
|
||||
if not self._running:
|
||||
break
|
||||
time.sleep(0.1)
|
||||
finally:
|
||||
self._event_loop = None
|
||||
try:
|
||||
loop.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
logger.info("[DingTalk] Startup loop exited")
|
||||
|
||||
def stop(self):
|
||||
logger.info("[DingTalk] stop() called, setting _running=False")
|
||||
self._running = False
|
||||
loop = self._event_loop
|
||||
if loop and not loop.is_closed():
|
||||
try:
|
||||
loop.call_soon_threadsafe(loop.stop)
|
||||
logger.info("[DingTalk] Sent stop signal to event loop")
|
||||
except Exception as e:
|
||||
logger.warning(f"[DingTalk] Error stopping event loop: {e}")
|
||||
self._stream_client = None
|
||||
logger.info("[DingTalk] stop() completed")
|
||||
|
||||
def get_access_token(self):
|
||||
"""
|
||||
@@ -276,7 +396,7 @@ class DingTalkChanel(ChatChannel, dingtalk_stream.ChatbotHandler):
|
||||
|
||||
# 保存到临时文件
|
||||
file_name = os.path.basename(file_path) or f"media_{uuid.uuid4()}"
|
||||
workspace_root = os.path.expanduser(conf().get("agent_workspace", "~/cow"))
|
||||
workspace_root = expand_path(conf().get("agent_workspace", "~/cow"))
|
||||
tmp_dir = os.path.join(workspace_root, "tmp")
|
||||
os.makedirs(tmp_dir, exist_ok=True)
|
||||
temp_file = os.path.join(tmp_dir, file_name)
|
||||
@@ -457,23 +577,21 @@ class DingTalkChanel(ChatChannel, dingtalk_stream.ChatbotHandler):
|
||||
async def process(self, callback: dingtalk_stream.CallbackMessage):
|
||||
try:
|
||||
incoming_message = dingtalk_stream.ChatbotMessage.from_dict(callback.data)
|
||||
|
||||
|
||||
# 缓存 robot_code,用于后续图片下载
|
||||
if hasattr(incoming_message, 'robot_code'):
|
||||
self._robot_code_cache = incoming_message.robot_code
|
||||
|
||||
# Debug: 打印完整的 event 数据
|
||||
logger.debug(f"[DingTalk] ===== Incoming Message Debug =====")
|
||||
logger.debug(f"[DingTalk] callback.data keys: {callback.data.keys() if hasattr(callback.data, 'keys') else 'N/A'}")
|
||||
logger.debug(f"[DingTalk] incoming_message attributes: {dir(incoming_message)}")
|
||||
logger.debug(f"[DingTalk] robot_code: {getattr(incoming_message, 'robot_code', 'N/A')}")
|
||||
logger.debug(f"[DingTalk] chatbot_corp_id: {getattr(incoming_message, 'chatbot_corp_id', 'N/A')}")
|
||||
logger.debug(f"[DingTalk] chatbot_user_id: {getattr(incoming_message, 'chatbot_user_id', 'N/A')}")
|
||||
logger.debug(f"[DingTalk] conversation_id: {getattr(incoming_message, 'conversation_id', 'N/A')}")
|
||||
logger.debug(f"[DingTalk] Raw callback.data: {callback.data}")
|
||||
logger.debug(f"[DingTalk] =====================================")
|
||||
|
||||
image_download_handler = self # 传入方法所在的类实例
|
||||
|
||||
# Filter out stale messages from before channel startup (offline backlog)
|
||||
create_at = getattr(incoming_message, 'create_at', None)
|
||||
if create_at:
|
||||
msg_age_s = time.time() - int(create_at) / 1000
|
||||
if msg_age_s > 60:
|
||||
logger.warning(f"[DingTalk] stale msg filtered (age={msg_age_s:.0f}s), "
|
||||
f"msg_id={getattr(incoming_message, 'message_id', 'N/A')}")
|
||||
return AckMessage.STATUS_OK, 'OK'
|
||||
|
||||
image_download_handler = self
|
||||
dingtalk_msg = DingTalkMessage(incoming_message, image_download_handler)
|
||||
|
||||
if dingtalk_msg.is_group:
|
||||
@@ -482,8 +600,7 @@ class DingTalkChanel(ChatChannel, dingtalk_stream.ChatbotHandler):
|
||||
self.handle_single(dingtalk_msg)
|
||||
return AckMessage.STATUS_OK, 'OK'
|
||||
except Exception as e:
|
||||
logger.error(f"[DingTalk] process error: {e}")
|
||||
logger.exception(e) # 打印完整堆栈跟踪
|
||||
logger.error(f"[DingTalk] process error: {e}", exc_info=True)
|
||||
return AckMessage.STATUS_SYSTEM_EXCEPTION, 'ERROR'
|
||||
|
||||
@time_checker
|
||||
@@ -607,7 +724,7 @@ class DingTalkChanel(ChatChannel, dingtalk_stream.ChatbotHandler):
|
||||
|
||||
|
||||
def send(self, reply: Reply, context: Context):
|
||||
logger.info(f"[DingTalk] send() called with reply.type={reply.type}, content_length={len(str(reply.content))}")
|
||||
logger.debug(f"[DingTalk] send() called with reply.type={reply.type}, content_length={len(str(reply.content))}")
|
||||
receiver = context["receiver"]
|
||||
|
||||
# Check if msg exists (for scheduled tasks, msg might be None)
|
||||
@@ -647,7 +764,7 @@ class DingTalkChanel(ChatChannel, dingtalk_stream.ChatbotHandler):
|
||||
robot_code = msg.robot_code
|
||||
if robot_code and robot_code != self._robot_code:
|
||||
self._robot_code = robot_code
|
||||
logger.info(f"[DingTalk] Cached robot_code: {robot_code}")
|
||||
logger.debug(f"[DingTalk] Cached robot_code: {robot_code}")
|
||||
|
||||
isgroup = msg.is_group
|
||||
incoming_message = msg.incoming_message
|
||||
@@ -755,6 +872,48 @@ class DingTalkChanel(ChatChannel, dingtalk_stream.ChatbotHandler):
|
||||
self.reply_text("抱歉,文件上传失败", incoming_message)
|
||||
return
|
||||
|
||||
# Native sampleAudio. Upload only accepts ogg/amr, so convert TTS mp3/wav to amr.
|
||||
elif reply.type == ReplyType.VOICE:
|
||||
logger.info(f"[DingTalk] Sending voice: {reply.content}")
|
||||
access_token = self.get_access_token()
|
||||
if not access_token:
|
||||
logger.error("[DingTalk] Cannot get access token for voice")
|
||||
self.reply_text("抱歉,语音发送失败(无法获取token)", incoming_message)
|
||||
return
|
||||
|
||||
voice_path = reply.content
|
||||
if voice_path.startswith("file://"):
|
||||
voice_path = voice_path[7:]
|
||||
|
||||
amr_path = voice_path
|
||||
duration_ms = 0
|
||||
if not voice_path.lower().endswith((".amr", ".ogg")):
|
||||
try:
|
||||
from voice.audio_convert import any_to_amr
|
||||
amr_path = os.path.splitext(voice_path)[0] + ".amr"
|
||||
duration_ms = int(any_to_amr(voice_path, amr_path) or 0)
|
||||
except Exception as e:
|
||||
logger.error(f"[DingTalk] Failed to convert voice to amr: {e}")
|
||||
self.reply_text("抱歉,语音转码失败", incoming_message)
|
||||
return
|
||||
|
||||
media_id = self.upload_media(amr_path, media_type="voice")
|
||||
if not media_id:
|
||||
logger.error("[DingTalk] Failed to upload voice media")
|
||||
self.reply_text("抱歉,语音上传失败", incoming_message)
|
||||
return
|
||||
|
||||
msg_param = {
|
||||
"mediaId": media_id,
|
||||
"duration": str(duration_ms or 1000),
|
||||
}
|
||||
success = self._send_file_message(
|
||||
access_token, incoming_message, "sampleAudio", msg_param, isgroup
|
||||
)
|
||||
if not success:
|
||||
self.reply_text("抱歉,语音发送失败", incoming_message)
|
||||
return
|
||||
|
||||
# 处理文本消息
|
||||
elif reply.type == ReplyType.TEXT:
|
||||
logger.info(f"[DingTalk] Sending text message, length={len(reply.content)}")
|
||||
|
||||
@@ -9,6 +9,7 @@ from channel.chat_message import ChatMessage
|
||||
# -*- coding=utf-8 -*-
|
||||
from common.log import logger
|
||||
from common.tmp_dir import TmpDir
|
||||
from common.utils import expand_path
|
||||
from config import conf
|
||||
|
||||
|
||||
@@ -49,7 +50,7 @@ class DingTalkMessage(ChatMessage):
|
||||
download_url = image_download_handler.get_image_download_url(download_code)
|
||||
|
||||
# 下载到工作空间 tmp 目录
|
||||
workspace_root = os.path.expanduser(conf().get("agent_workspace", "~/cow"))
|
||||
workspace_root = expand_path(conf().get("agent_workspace", "~/cow"))
|
||||
tmp_dir = os.path.join(workspace_root, "tmp")
|
||||
os.makedirs(tmp_dir, exist_ok=True)
|
||||
|
||||
@@ -67,7 +68,7 @@ class DingTalkMessage(ChatMessage):
|
||||
self.ctype = ContextType.TEXT
|
||||
|
||||
# 下载到工作空间 tmp 目录
|
||||
workspace_root = os.path.expanduser(conf().get("agent_workspace", "~/cow"))
|
||||
workspace_root = expand_path(conf().get("agent_workspace", "~/cow"))
|
||||
tmp_dir = os.path.join(workspace_root, "tmp")
|
||||
os.makedirs(tmp_dir, exist_ok=True)
|
||||
|
||||
|
||||
0
channel/discord/__init__.py
Normal file
0
channel/discord/__init__.py
Normal file
500
channel/discord/discord_channel.py
Normal file
500
channel/discord/discord_channel.py
Normal file
@@ -0,0 +1,500 @@
|
||||
"""
|
||||
Discord channel via the Gateway (WebSocket) using discord.py.
|
||||
|
||||
Features:
|
||||
- Direct message & guild channel chat (text / image / file)
|
||||
- Guild trigger: @mention or reply-to-bot (configurable)
|
||||
- /cancel fast-path matches Web channel behaviour
|
||||
- Gateway long connection: no public IP / callback URL required, works behind NAT
|
||||
|
||||
Implementation note:
|
||||
discord.py is async-first. We run the client inside a dedicated thread
|
||||
with its own asyncio loop so the rest of cow (which is sync) stays
|
||||
untouched. Inbound messages are dispatched onto cow's existing sync
|
||||
ChatChannel.produce() pipeline; outbound send() schedules coroutines
|
||||
back onto that loop via asyncio.run_coroutine_threadsafe.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import re
|
||||
import threading
|
||||
|
||||
from bridge.context import Context, ContextType
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from channel.chat_channel import ChatChannel, check_prefix
|
||||
from channel.discord.discord_message import DiscordMessage
|
||||
from common.expired_dict import ExpiredDict
|
||||
from common.log import logger
|
||||
from common.singleton import singleton
|
||||
from config import conf
|
||||
|
||||
# Discord caps a single message at 2000 chars; split conservatively below.
|
||||
DISCORD_MSG_LIMIT = 1900
|
||||
|
||||
|
||||
@singleton
|
||||
class DiscordChannel(ChatChannel):
|
||||
NOT_SUPPORT_REPLYTYPE = []
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.bot_token = ""
|
||||
self.bot_user_id = "" # used to strip @mention and ignore self messages
|
||||
self.bot_username = ""
|
||||
self._client = None
|
||||
self._loop = None
|
||||
self._loop_thread = None
|
||||
self._stop_event = threading.Event()
|
||||
# Idempotent dedup; guard against rare duplicate dispatch
|
||||
self._received_msgs = ExpiredDict(60 * 60 * 1)
|
||||
|
||||
# Disable group whitelist / prefix checks (we handle triggering ourselves
|
||||
# in _should_reply_in_guild), aligned with telegram / slack channels.
|
||||
conf()["group_name_white_list"] = ["ALL_GROUP"]
|
||||
conf()["single_chat_prefix"] = [""]
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Lifecycle
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def startup(self):
|
||||
self.bot_token = conf().get("discord_token", "")
|
||||
if not self.bot_token:
|
||||
err = "[Discord] discord_token is required"
|
||||
logger.error(err)
|
||||
self.report_startup_error(err)
|
||||
return
|
||||
|
||||
try:
|
||||
import discord
|
||||
except ImportError:
|
||||
err = (
|
||||
"[Discord] discord.py is not installed. "
|
||||
"Run: pip install discord.py"
|
||||
)
|
||||
logger.error(err)
|
||||
self.report_startup_error(err)
|
||||
return
|
||||
|
||||
# Run the asyncio event loop in a dedicated thread so the sync cow body
|
||||
# is untouched.
|
||||
self._loop = asyncio.new_event_loop()
|
||||
|
||||
def _run_loop():
|
||||
asyncio.set_event_loop(self._loop)
|
||||
try:
|
||||
self._loop.run_until_complete(self._async_main(discord))
|
||||
except Exception as e:
|
||||
logger.error(f"[Discord] event loop crashed: {e}", exc_info=True)
|
||||
self.report_startup_error(str(e))
|
||||
finally:
|
||||
try:
|
||||
self._loop.close()
|
||||
except Exception:
|
||||
pass
|
||||
logger.info("[Discord] event loop exited")
|
||||
|
||||
self._loop_thread = threading.Thread(target=_run_loop, daemon=True, name="discord-loop")
|
||||
self._loop_thread.start()
|
||||
# Block startup() until the loop thread exits, matching other channels'
|
||||
# behaviour (startup is a blocking call).
|
||||
self._loop_thread.join()
|
||||
|
||||
async def _async_main(self, discord):
|
||||
"""Build the discord client, register handlers, and connect to the Gateway."""
|
||||
# message_content is a privileged intent; it must be enabled in the
|
||||
# Developer Portal (Bot -> Privileged Gateway Intents) to read text.
|
||||
intents = discord.Intents.default()
|
||||
intents.message_content = True
|
||||
client = discord.Client(intents=intents)
|
||||
self._client = client
|
||||
|
||||
channel = self
|
||||
|
||||
@client.event
|
||||
async def on_ready():
|
||||
channel.bot_user_id = str(client.user.id)
|
||||
channel.bot_username = client.user.name or ""
|
||||
channel.name = channel.bot_user_id # ChatChannel uses self.name to strip @-mention
|
||||
logger.info(f"[Discord] Bot logged in as {client.user} (id={client.user.id})")
|
||||
channel.report_startup_success()
|
||||
logger.info("[Discord] ✅ Discord bot ready, listening for messages")
|
||||
|
||||
@client.event
|
||||
async def on_message(message):
|
||||
await channel._on_message(message)
|
||||
|
||||
# Connect to the Gateway; discord.py auto-reconnects on transient errors.
|
||||
logger.info("[Discord] Connecting to Gateway...")
|
||||
|
||||
# client.start() handles login + Gateway connection and runs until
|
||||
# close(); it is the standard entrypoint across discord.py versions.
|
||||
runner_task = asyncio.create_task(client.start(self.bot_token))
|
||||
|
||||
# Block until stop()
|
||||
try:
|
||||
while not self._stop_event.is_set():
|
||||
if runner_task.done():
|
||||
# Surface a startup/connection failure (e.g. bad token)
|
||||
exc = runner_task.exception()
|
||||
if exc:
|
||||
logger.error(f"[Discord] client stopped: {exc}", exc_info=exc)
|
||||
self.report_startup_error(str(exc))
|
||||
break
|
||||
await asyncio.sleep(0.5)
|
||||
finally:
|
||||
try:
|
||||
if not client.is_closed():
|
||||
await client.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"[Discord] shutdown error: {e}")
|
||||
|
||||
def stop(self):
|
||||
logger.info("[Discord] stop() called")
|
||||
self._stop_event.set()
|
||||
if self._loop_thread and self._loop_thread.is_alive():
|
||||
try:
|
||||
self._loop_thread.join(timeout=10)
|
||||
except Exception:
|
||||
pass
|
||||
logger.info("[Discord] stop() completed")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Inbound: discord message -> ChatMessage -> ChatChannel.produce
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _on_message(self, message):
|
||||
"""Discord message entry: parse -> build ChatMessage -> produce()."""
|
||||
try:
|
||||
# Ignore our own messages and other bots. self._client.user may be
|
||||
# None until on_ready completes, so guard against that.
|
||||
if self._client and self._client.user and message.author.id == self._client.user.id:
|
||||
return
|
||||
if message.author.bot:
|
||||
return
|
||||
|
||||
# Idempotent dedup
|
||||
msg_uid = f"{message.channel.id}:{message.id}"
|
||||
if self._received_msgs.get(msg_uid):
|
||||
return
|
||||
self._received_msgs[msg_uid] = True
|
||||
|
||||
# guild is None for DMs
|
||||
is_group = message.guild is not None
|
||||
|
||||
# Guild trigger gate (silently drop if not triggered)
|
||||
if is_group and not self._should_reply_in_guild(message):
|
||||
logger.debug(f"[Discord] guild message not triggered (need @mention or reply), skip")
|
||||
return
|
||||
|
||||
# Parse message type + download attachments if needed.
|
||||
ctype, content, caption = await self._parse_message(message)
|
||||
if ctype is None:
|
||||
logger.debug(f"[Discord] unsupported message type, skip. msg_id={message.id}")
|
||||
return
|
||||
|
||||
# Strip the bot mention from guild text/caption
|
||||
if is_group:
|
||||
if ctype == ContextType.TEXT and content:
|
||||
content = self._strip_at_mention(content)
|
||||
if caption:
|
||||
caption = self._strip_at_mention(caption)
|
||||
|
||||
dc_msg = DiscordMessage(
|
||||
message,
|
||||
is_group=is_group,
|
||||
bot_user_id=self.bot_user_id,
|
||||
ctype=ctype,
|
||||
content=content,
|
||||
)
|
||||
dc_msg.is_at = is_group # if we reached here in a guild, bot is mentioned/replied
|
||||
|
||||
from channel.file_cache import get_file_cache
|
||||
file_cache = get_file_cache()
|
||||
session_id = self._compute_session_id(message, is_group)
|
||||
|
||||
# Media + caption together: treat as a complete query and bypass the cache
|
||||
if ctype in (ContextType.IMAGE, ContextType.FILE) and caption:
|
||||
tag = "image" if ctype == ContextType.IMAGE else "file"
|
||||
merged_text = f"{caption}\n[{tag}: {content}]"
|
||||
dc_msg.ctype = ContextType.TEXT
|
||||
dc_msg.content = merged_text
|
||||
ctype = ContextType.TEXT
|
||||
logger.info(f"[Discord] Media+caption merged for session {session_id}")
|
||||
# fallthrough to the TEXT branch below
|
||||
|
||||
elif ctype == ContextType.IMAGE:
|
||||
file_cache.add(session_id, content, file_type="image")
|
||||
logger.info(f"[Discord] Image cached for session {session_id}, waiting for query...")
|
||||
return
|
||||
elif ctype == ContextType.FILE:
|
||||
file_cache.add(session_id, content, file_type="file")
|
||||
logger.info(f"[Discord] File cached for session {session_id}: {content}")
|
||||
return
|
||||
|
||||
if ctype == ContextType.TEXT:
|
||||
# Fast-path: /cancel mirrors Web channel behaviour
|
||||
if (content or "").strip().lower() in ("/cancel", "cancel"):
|
||||
await self._do_cancel(session_id, message)
|
||||
return
|
||||
|
||||
cached_files = file_cache.get(session_id)
|
||||
if cached_files:
|
||||
refs = []
|
||||
for fi in cached_files:
|
||||
ftype = fi["type"]
|
||||
tag = ftype if ftype in ("image", "video") else "file"
|
||||
refs.append(f"[{tag}: {fi['path']}]")
|
||||
dc_msg.content = (dc_msg.content or "") + "\n" + "\n".join(refs)
|
||||
file_cache.clear(session_id)
|
||||
logger.info(f"[Discord] Attached {len(cached_files)} cached file(s) to query")
|
||||
|
||||
context = self._compose_context(
|
||||
dc_msg.ctype,
|
||||
dc_msg.content,
|
||||
isgroup=is_group,
|
||||
msg=dc_msg,
|
||||
# Replies use Discord's reply mechanism, no manual @mention needed
|
||||
no_need_at=True,
|
||||
)
|
||||
if context:
|
||||
context["session_id"] = session_id
|
||||
context["receiver"] = str(message.channel.id)
|
||||
context["discord_channel_id"] = message.channel.id
|
||||
context["discord_reply_to_msg_id"] = message.id if is_group else None
|
||||
self.produce(context)
|
||||
logger.debug(f"[Discord] received: type={ctype}, content={str(dc_msg.content)[:80]}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Discord] _on_message error: {e}", exc_info=True)
|
||||
|
||||
async def _do_cancel(self, session_id: str, message):
|
||||
"""Fast-path: /cancel calls cancel_session directly without going through agent."""
|
||||
try:
|
||||
from agent.protocol import get_cancel_registry
|
||||
cancelled = get_cancel_registry().cancel_session(session_id)
|
||||
text = "Current task cancelled." if cancelled else "No running task to cancel."
|
||||
await message.channel.send(text)
|
||||
logger.info(f"[Discord] /cancel session={session_id}, cancelled={cancelled}")
|
||||
except Exception as e:
|
||||
logger.error(f"[Discord] /cancel error: {e}", exc_info=True)
|
||||
|
||||
async def _parse_message(self, message):
|
||||
"""Parse a discord message and return (ctype, content, caption).
|
||||
|
||||
- content is text for ContextType.TEXT, otherwise the local file path
|
||||
- caption is the optional text accompanying an attachment; empty for plain text
|
||||
"""
|
||||
text = (message.content or "").strip()
|
||||
attachments = message.attachments or []
|
||||
|
||||
if attachments:
|
||||
# Handle the first attachment; caption is the accompanying message text
|
||||
att = attachments[0]
|
||||
content_type = (att.content_type or "").lower()
|
||||
name = att.filename or str(att.id)
|
||||
path = await self._download_attachment(att, name)
|
||||
if not path:
|
||||
return (None, None, "")
|
||||
is_image = content_type.startswith("image/") or name.lower().endswith(
|
||||
(".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp")
|
||||
)
|
||||
if is_image:
|
||||
return (ContextType.IMAGE, path, text)
|
||||
return (ContextType.FILE, path, text)
|
||||
|
||||
if text:
|
||||
return (ContextType.TEXT, text, "")
|
||||
|
||||
return (None, None, "")
|
||||
|
||||
async def _download_attachment(self, attachment, name: str):
|
||||
"""Download a discord attachment into the local tmp dir; return path or None."""
|
||||
try:
|
||||
tmp_dir = DiscordMessage.get_tmp_dir()
|
||||
safe_name = re.sub(r"[^\w.\-]", "_", name)
|
||||
# Prefix with attachment id to avoid name collisions
|
||||
local_path = os.path.join(tmp_dir, f"{attachment.id}_{safe_name}")
|
||||
await attachment.save(local_path)
|
||||
logger.debug(f"[Discord] downloaded {name} -> {local_path}")
|
||||
return local_path
|
||||
except Exception as e:
|
||||
logger.error(f"[Discord] download_attachment failed ({name}): {e}")
|
||||
return None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Guild trigger logic
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _should_reply_in_guild(self, message) -> bool:
|
||||
"""Decide whether to reply to a guild channel message based on configuration."""
|
||||
mode = conf().get("discord_group_trigger", "mention_or_reply")
|
||||
if mode == "all":
|
||||
return True
|
||||
|
||||
# self._client.user may be None until on_ready completes
|
||||
if not self._client or not self._client.user:
|
||||
return False
|
||||
|
||||
# 1) Mentioned (direct @bot, not @everyone / @role)
|
||||
if self._client.user in message.mentions:
|
||||
return True
|
||||
|
||||
# 2) Reply to a bot message
|
||||
if mode == "mention_or_reply":
|
||||
ref = message.reference
|
||||
resolved = getattr(ref, "resolved", None) if ref else None
|
||||
if resolved and getattr(resolved, "author", None):
|
||||
if resolved.author.id == self._client.user.id:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _strip_at_mention(self, content: str) -> str:
|
||||
"""Strip <@BOT_ID> / <@!BOT_ID> from guild text."""
|
||||
if not content or not self.bot_user_id:
|
||||
return content
|
||||
pattern = re.compile(r"<@!?" + re.escape(self.bot_user_id) + r">")
|
||||
return pattern.sub("", content).strip()
|
||||
|
||||
@staticmethod
|
||||
def _compute_session_id(message, is_group: bool) -> str:
|
||||
channel_id = message.channel.id
|
||||
user_id = message.author.id
|
||||
if is_group:
|
||||
if conf().get("group_shared_session", True):
|
||||
return f"discord_channel_{channel_id}"
|
||||
return f"discord_channel_{channel_id}_{user_id}"
|
||||
return f"discord_user_{user_id}"
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Override _compose_context: skip the parent's group whitelist/at checks
|
||||
# (already handled via _should_reply_in_guild). Same idea as telegram / slack.
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _compose_context(self, ctype: ContextType, content, **kwargs):
|
||||
context = Context(ctype, content)
|
||||
context.kwargs = kwargs
|
||||
if "channel_type" not in context:
|
||||
context["channel_type"] = self.channel_type
|
||||
if "origin_ctype" not in context:
|
||||
context["origin_ctype"] = ctype
|
||||
|
||||
cmsg = context["msg"]
|
||||
if cmsg.is_group:
|
||||
if conf().get("group_shared_session", True):
|
||||
context["session_id"] = cmsg.other_user_id
|
||||
else:
|
||||
context["session_id"] = f"{cmsg.from_user_id}:{cmsg.other_user_id}"
|
||||
else:
|
||||
context["session_id"] = cmsg.from_user_id
|
||||
context["receiver"] = cmsg.other_user_id
|
||||
|
||||
if ctype == ContextType.TEXT:
|
||||
img_match_prefix = check_prefix(content, conf().get("image_create_prefix"))
|
||||
if img_match_prefix:
|
||||
content = content.replace(img_match_prefix, "", 1)
|
||||
context.type = ContextType.IMAGE_CREATE
|
||||
else:
|
||||
context.type = ContextType.TEXT
|
||||
context.content = (content or "").strip()
|
||||
if "desire_rtype" not in context and conf().get("always_reply_voice"):
|
||||
context["desire_rtype"] = ReplyType.VOICE
|
||||
elif ctype == ContextType.VOICE:
|
||||
if "desire_rtype" not in context and (
|
||||
conf().get("voice_reply_voice") or conf().get("always_reply_voice")
|
||||
):
|
||||
context["desire_rtype"] = ReplyType.VOICE
|
||||
|
||||
return context
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Outbound: ChatChannel.send -> Discord Gateway/REST
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def send(self, reply: Reply, context: Context):
|
||||
"""Called from cow's sync main thread; marshal the coroutine onto the loop thread."""
|
||||
if self._loop is None or self._client is None:
|
||||
logger.warning("[Discord] client not ready, drop reply")
|
||||
return
|
||||
|
||||
channel_id = context.get("discord_channel_id")
|
||||
if channel_id is None:
|
||||
logger.warning("[Discord] no discord_channel_id in context, drop reply")
|
||||
return
|
||||
|
||||
coro = self._async_send(reply, channel_id)
|
||||
try:
|
||||
future = asyncio.run_coroutine_threadsafe(coro, self._loop)
|
||||
future.result(timeout=180)
|
||||
except Exception as e:
|
||||
logger.error(f"[Discord] send failed: {e}")
|
||||
|
||||
async def _async_send(self, reply: Reply, channel_id):
|
||||
try:
|
||||
import discord
|
||||
|
||||
channel = self._client.get_channel(channel_id)
|
||||
if channel is None:
|
||||
# Not in cache (e.g. DM channel); fetch it explicitly
|
||||
channel = await self._client.fetch_channel(channel_id)
|
||||
|
||||
rtype = reply.type
|
||||
content = reply.content
|
||||
|
||||
if rtype in (ReplyType.TEXT, ReplyType.INFO, ReplyType.ERROR):
|
||||
text = str(content) if content is not None else ""
|
||||
if not text:
|
||||
return
|
||||
for chunk in _split_text(text, DISCORD_MSG_LIMIT):
|
||||
await channel.send(chunk)
|
||||
|
||||
elif rtype == ReplyType.IMAGE:
|
||||
# Already a local BytesIO; send it directly
|
||||
content.seek(0)
|
||||
await channel.send(file=discord.File(content, filename="image.png"))
|
||||
|
||||
elif rtype == ReplyType.IMAGE_URL:
|
||||
url = str(content)
|
||||
if url.startswith("file://"):
|
||||
local = url[7:]
|
||||
await channel.send(file=discord.File(local))
|
||||
else:
|
||||
# Post the URL as text; Discord will unfurl it as an image preview
|
||||
await channel.send(url)
|
||||
|
||||
elif rtype in (ReplyType.VOICE, ReplyType.FILE):
|
||||
local = content[7:] if isinstance(content, str) and content.startswith("file://") else content
|
||||
caption = getattr(reply, "text_content", None) or None
|
||||
await channel.send(content=caption, file=discord.File(local))
|
||||
|
||||
else:
|
||||
# Fallback: send as plain text
|
||||
await channel.send(str(content))
|
||||
|
||||
logger.info(f"[Discord] sent reply (type={rtype}, channel={channel_id})")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Discord] _async_send error: {e}", exc_info=True)
|
||||
|
||||
|
||||
def _split_text(text: str, limit: int):
|
||||
"""Split long text preferring line breaks to keep markdown structure intact."""
|
||||
if len(text) <= limit:
|
||||
yield text
|
||||
return
|
||||
buf = []
|
||||
size = 0
|
||||
for line in text.splitlines(keepends=True):
|
||||
if size + len(line) > limit and buf:
|
||||
yield "".join(buf)
|
||||
buf, size = [], 0
|
||||
# Hard-split single lines that exceed the limit
|
||||
while len(line) > limit:
|
||||
yield line[:limit]
|
||||
line = line[limit:]
|
||||
buf.append(line)
|
||||
size += len(line)
|
||||
if buf:
|
||||
yield "".join(buf)
|
||||
60
channel/discord/discord_message.py
Normal file
60
channel/discord/discord_message.py
Normal file
@@ -0,0 +1,60 @@
|
||||
"""
|
||||
Discord message adapter.
|
||||
|
||||
Convert a discord.py Message into cow's unified ChatMessage.
|
||||
File downloads are NOT performed here; the channel layer downloads
|
||||
attachments on demand inside the async event loop.
|
||||
"""
|
||||
import os
|
||||
|
||||
from bridge.context import ContextType
|
||||
from channel.chat_message import ChatMessage
|
||||
from common.utils import expand_path
|
||||
from config import conf
|
||||
|
||||
|
||||
class DiscordMessage(ChatMessage):
|
||||
"""Wrap a discord.py Message into the unified ChatMessage."""
|
||||
|
||||
def __init__(self, message, is_group: bool = False, bot_user_id: str = "",
|
||||
ctype: ContextType = ContextType.TEXT, content: str = ""):
|
||||
super().__init__(message)
|
||||
# Basic fields
|
||||
self.msg_id = str(message.id)
|
||||
self.create_time = int(message.created_at.timestamp()) if message.created_at else 0
|
||||
self.ctype = ctype
|
||||
self.content = content
|
||||
|
||||
author = message.author
|
||||
channel = message.channel
|
||||
|
||||
# Sender / chat info
|
||||
from_user_id = str(author.id)
|
||||
from_user_nick = getattr(author, "display_name", None) or getattr(author, "name", None) or from_user_id
|
||||
self.from_user_id = from_user_id
|
||||
self.from_user_nickname = from_user_nick
|
||||
self.to_user_id = bot_user_id or "discord_bot"
|
||||
self.to_user_nickname = bot_user_id or "discord_bot"
|
||||
|
||||
self.is_group = is_group
|
||||
if is_group:
|
||||
# Guild channel: other_user_id = channel_id, actual_user_id = sender id
|
||||
self.other_user_id = str(channel.id)
|
||||
self.other_user_nickname = getattr(channel, "name", None) or str(channel.id)
|
||||
self.actual_user_id = from_user_id
|
||||
self.actual_user_nickname = from_user_nick
|
||||
else:
|
||||
# DM: use channel_id so replies go back to the same DM channel
|
||||
self.other_user_id = str(channel.id)
|
||||
self.other_user_nickname = from_user_nick
|
||||
|
||||
# Whether the bot was triggered by @-mention (set by channel layer)
|
||||
self.is_at = False
|
||||
|
||||
@staticmethod
|
||||
def get_tmp_dir() -> str:
|
||||
"""Local download directory, aligned with other channels (agent_workspace/tmp)."""
|
||||
workspace_root = expand_path(conf().get("agent_workspace", "~/cow"))
|
||||
tmp_dir = os.path.join(workspace_root, "tmp")
|
||||
os.makedirs(tmp_dir, exist_ok=True)
|
||||
return tmp_dir
|
||||
@@ -140,6 +140,23 @@ python3 app.py
|
||||
|
||||
**解决**: 安装依赖 `pip install lark-oapi`
|
||||
|
||||
### SSL证书验证失败
|
||||
|
||||
```
|
||||
[Lark][ERROR] connect failed, err:[SSL:CERTIFICATE_VERIFY_FAILED] certificate verify failed: self signed certificate in certificate chain
|
||||
```
|
||||
|
||||
**原因**: 网络环境中存在自签名证书或SSL中间人代理(如企业代理、VPN等)
|
||||
|
||||
**解决**: 程序会自动检测SSL证书验证失败,并自动重试禁用证书验证的连接。无需手动配置。
|
||||
|
||||
当遇到证书错误时,日志会显示:
|
||||
```
|
||||
[FeiShu] SSL certificate verification disabled due to certificate error. This may happen when using corporate proxy or self-signed certificates.
|
||||
```
|
||||
|
||||
这是正常现象,程序会自动处理并继续运行。
|
||||
|
||||
### Webhook模式端口被占用
|
||||
|
||||
```
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -6,6 +6,7 @@ import requests
|
||||
from common.log import logger
|
||||
from common.tmp_dir import TmpDir
|
||||
from common import utils
|
||||
from common.utils import expand_path
|
||||
from config import conf
|
||||
|
||||
|
||||
@@ -31,7 +32,7 @@ class FeishuMessage(ChatMessage):
|
||||
image_key = content.get("image_key")
|
||||
|
||||
# 下载图片到工作空间临时目录
|
||||
workspace_root = os.path.expanduser(conf().get("agent_workspace", "~/cow"))
|
||||
workspace_root = expand_path(conf().get("agent_workspace", "~/cow"))
|
||||
tmp_dir = os.path.join(workspace_root, "tmp")
|
||||
os.makedirs(tmp_dir, exist_ok=True)
|
||||
image_path = os.path.join(tmp_dir, f"{image_key}.png")
|
||||
@@ -97,7 +98,7 @@ class FeishuMessage(ChatMessage):
|
||||
|
||||
if image_keys:
|
||||
# 如果包含图片,下载并在文本中引用本地路径
|
||||
workspace_root = os.path.expanduser(conf().get("agent_workspace", "~/cow"))
|
||||
workspace_root = expand_path(conf().get("agent_workspace", "~/cow"))
|
||||
tmp_dir = os.path.join(workspace_root, "tmp")
|
||||
os.makedirs(tmp_dir, exist_ok=True)
|
||||
|
||||
@@ -143,7 +144,14 @@ class FeishuMessage(ChatMessage):
|
||||
file_key = content.get("file_key")
|
||||
file_name = content.get("file_name")
|
||||
|
||||
self.content = TmpDir().path() + file_key + "." + utils.get_path_suffix(file_name)
|
||||
# 落到 agent_workspace/tmp 下(绝对路径),与图片处理一致;
|
||||
# 否则相对路径 ./tmp 在 agent 工作区里 read 时会找不到。
|
||||
workspace_root = expand_path(conf().get("agent_workspace", "~/cow"))
|
||||
tmp_dir = os.path.join(workspace_root, "tmp")
|
||||
os.makedirs(tmp_dir, exist_ok=True)
|
||||
self.content = os.path.join(
|
||||
tmp_dir, f"{file_key}.{utils.get_path_suffix(file_name)}"
|
||||
)
|
||||
|
||||
def _download_file():
|
||||
# 如果响应状态码是200,则将响应内容写入本地文件
|
||||
@@ -161,6 +169,42 @@ class FeishuMessage(ChatMessage):
|
||||
else:
|
||||
logger.info(f"[FeiShu] Failed to download file, key={file_key}, res={response.text}")
|
||||
self._prepare_fn = _download_file
|
||||
elif msg_type == "audio":
|
||||
# 飞书用户发送的语音消息类型为 "audio",文件为 opus 编码格式。
|
||||
# 映射为 ContextType.VOICE,交由 chat_channel 的语音转文字(STT)流程处理。
|
||||
# 文件通过 _prepare_fn 延迟下载,在 chat_channel 调用 cmsg.prepare() 时才执行。
|
||||
self.ctype = ContextType.VOICE
|
||||
content = json.loads(msg.get("content"))
|
||||
file_key = content.get("file_key")
|
||||
|
||||
# 落到 agent_workspace/tmp 下(绝对路径),保证语音 STT 流程可读到
|
||||
workspace_root = expand_path(conf().get("agent_workspace", "~/cow"))
|
||||
tmp_dir = os.path.join(workspace_root, "tmp")
|
||||
os.makedirs(tmp_dir, exist_ok=True)
|
||||
self.content = os.path.join(tmp_dir, f"{file_key}.opus")
|
||||
logger.info(f"[FeiShu] audio message: file_key={file_key}, save_path={self.content}")
|
||||
|
||||
def _download_audio():
|
||||
logger.info(f"[FeiShu] downloading audio: file_key={file_key}, msg_id={self.msg_id}")
|
||||
url = f"https://open.feishu.cn/open-apis/im/v1/messages/{self.msg_id}/resources/{file_key}"
|
||||
headers = {
|
||||
"Authorization": "Bearer " + access_token,
|
||||
}
|
||||
params = {
|
||||
"type": "file"
|
||||
}
|
||||
try:
|
||||
response = requests.get(url=url, headers=headers, params=params)
|
||||
logger.info(f"[FeiShu] download audio response: status={response.status_code}, size={len(response.content)} bytes")
|
||||
if response.status_code == 200:
|
||||
with open(self.content, "wb") as f:
|
||||
f.write(response.content)
|
||||
logger.info(f"[FeiShu] audio saved to: {self.content}")
|
||||
else:
|
||||
logger.error(f"[FeiShu] Failed to download audio, key={file_key}, status={response.status_code}, res={response.text}")
|
||||
except Exception as e:
|
||||
logger.error(f"[FeiShu] Exception downloading audio, key={file_key}: {e}", exc_info=True)
|
||||
self._prepare_fn = _download_audio
|
||||
else:
|
||||
raise NotImplementedError("Unsupported message type: Type:{} ".format(msg_type))
|
||||
|
||||
|
||||
0
channel/qq/__init__.py
Normal file
0
channel/qq/__init__.py
Normal file
736
channel/qq/qq_channel.py
Normal file
736
channel/qq/qq_channel.py
Normal file
@@ -0,0 +1,736 @@
|
||||
"""
|
||||
QQ Bot channel via WebSocket long connection.
|
||||
|
||||
Supports:
|
||||
- Group chat (@bot), single chat (C2C), guild channel, guild DM
|
||||
- Text / image / file message send & receive
|
||||
- Heartbeat keep-alive and auto-reconnect with session resume
|
||||
"""
|
||||
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
|
||||
import requests
|
||||
import websocket
|
||||
|
||||
from bridge.context import Context, ContextType
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from channel.chat_channel import ChatChannel, check_prefix
|
||||
from channel.qq.qq_message import QQMessage
|
||||
from common.expired_dict import ExpiredDict
|
||||
from common.log import logger
|
||||
from common.singleton import singleton
|
||||
from common.ws_client_compat import websocket_app_run_forever
|
||||
from config import conf
|
||||
|
||||
# Rich media file_type constants
|
||||
QQ_FILE_TYPE_IMAGE = 1
|
||||
QQ_FILE_TYPE_VIDEO = 2
|
||||
QQ_FILE_TYPE_VOICE = 3
|
||||
QQ_FILE_TYPE_FILE = 4
|
||||
|
||||
QQ_API_BASE = "https://api.sgroup.qq.com"
|
||||
|
||||
# Intents: GROUP_AND_C2C_EVENT(1<<25) | PUBLIC_GUILD_MESSAGES(1<<30)
|
||||
DEFAULT_INTENTS = (1 << 25) | (1 << 30)
|
||||
|
||||
# OpCode constants
|
||||
OP_DISPATCH = 0
|
||||
OP_HEARTBEAT = 1
|
||||
OP_IDENTIFY = 2
|
||||
OP_RESUME = 6
|
||||
OP_RECONNECT = 7
|
||||
OP_INVALID_SESSION = 9
|
||||
OP_HELLO = 10
|
||||
OP_HEARTBEAT_ACK = 11
|
||||
|
||||
# Resumable error codes
|
||||
RESUMABLE_CLOSE_CODES = {4008, 4009}
|
||||
|
||||
|
||||
@singleton
|
||||
class QQChannel(ChatChannel):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.app_id = ""
|
||||
self.app_secret = ""
|
||||
|
||||
self._access_token = ""
|
||||
self._token_expires_at = 0
|
||||
|
||||
self._ws = None
|
||||
self._ws_thread = None
|
||||
self._heartbeat_thread = None
|
||||
self._connected = False
|
||||
self._stop_event = threading.Event()
|
||||
self._token_lock = threading.Lock()
|
||||
|
||||
self._session_id = None
|
||||
self._last_seq = None
|
||||
self._heartbeat_interval = 45000
|
||||
self._can_resume = False
|
||||
|
||||
self.received_msgs = ExpiredDict(60 * 60 * 7.1)
|
||||
self._msg_seq_counter = {}
|
||||
|
||||
conf()["group_name_white_list"] = ["ALL_GROUP"]
|
||||
conf()["single_chat_prefix"] = [""]
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Lifecycle
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def startup(self):
|
||||
self.app_id = conf().get("qq_app_id", "")
|
||||
self.app_secret = conf().get("qq_app_secret", "")
|
||||
|
||||
if not self.app_id or not self.app_secret:
|
||||
err = "[QQ] qq_app_id and qq_app_secret are required"
|
||||
logger.error(err)
|
||||
self.report_startup_error(err)
|
||||
return
|
||||
|
||||
self._refresh_access_token()
|
||||
if not self._access_token:
|
||||
err = "[QQ] Failed to get initial access_token"
|
||||
logger.error(err)
|
||||
self.report_startup_error(err)
|
||||
return
|
||||
|
||||
self._stop_event.clear()
|
||||
self._start_ws()
|
||||
|
||||
def stop(self):
|
||||
logger.info("[QQ] stop() called")
|
||||
self._stop_event.set()
|
||||
if self._ws:
|
||||
try:
|
||||
self._ws.close()
|
||||
except Exception:
|
||||
pass
|
||||
self._ws = None
|
||||
self._connected = False
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Access Token
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _refresh_access_token(self):
|
||||
try:
|
||||
resp = requests.post(
|
||||
"https://bots.qq.com/app/getAppAccessToken",
|
||||
json={"appId": self.app_id, "clientSecret": self.app_secret},
|
||||
timeout=10,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
self._access_token = data.get("access_token", "")
|
||||
expires_in = int(data.get("expires_in", 7200))
|
||||
self._token_expires_at = time.time() + expires_in - 60
|
||||
logger.debug(f"[QQ] Access token refreshed, expires_in={expires_in}s")
|
||||
except Exception as e:
|
||||
logger.error(f"[QQ] Failed to refresh access_token: {e}")
|
||||
|
||||
def _get_access_token(self) -> str:
|
||||
with self._token_lock:
|
||||
if time.time() >= self._token_expires_at:
|
||||
self._refresh_access_token()
|
||||
return self._access_token
|
||||
|
||||
def _get_auth_headers(self) -> dict:
|
||||
return {
|
||||
"Authorization": f"QQBot {self._get_access_token()}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# WebSocket connection
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _get_ws_url(self) -> str:
|
||||
try:
|
||||
resp = requests.get(
|
||||
f"{QQ_API_BASE}/gateway",
|
||||
headers=self._get_auth_headers(),
|
||||
timeout=10,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
url = resp.json().get("url", "")
|
||||
logger.debug(f"[QQ] Gateway URL: {url}")
|
||||
return url
|
||||
except Exception as e:
|
||||
logger.error(f"[QQ] Failed to get gateway URL: {e}")
|
||||
return ""
|
||||
|
||||
def _start_ws(self):
|
||||
ws_url = self._get_ws_url()
|
||||
if not ws_url:
|
||||
logger.error("[QQ] Cannot start WebSocket without gateway URL")
|
||||
self.report_startup_error("Failed to get gateway URL")
|
||||
return
|
||||
|
||||
def _on_open(ws):
|
||||
logger.debug("[QQ] WebSocket connected, waiting for Hello...")
|
||||
|
||||
def _on_message(ws, raw):
|
||||
try:
|
||||
data = json.loads(raw)
|
||||
self._handle_ws_message(data)
|
||||
except Exception as e:
|
||||
logger.error(f"[QQ] Failed to handle ws message: {e}", exc_info=True)
|
||||
|
||||
def _on_error(ws, error):
|
||||
logger.error(f"[QQ] WebSocket error: {error}")
|
||||
|
||||
def _on_close(ws, close_status_code, close_msg):
|
||||
logger.warning(f"[QQ] WebSocket closed: status={close_status_code}, msg={close_msg}")
|
||||
self._connected = False
|
||||
if not self._stop_event.is_set():
|
||||
if close_status_code in RESUMABLE_CLOSE_CODES and self._session_id:
|
||||
self._can_resume = True
|
||||
logger.info("[QQ] Will attempt resume in 3s...")
|
||||
time.sleep(3)
|
||||
else:
|
||||
self._can_resume = False
|
||||
logger.info("[QQ] Will reconnect in 5s...")
|
||||
time.sleep(5)
|
||||
if not self._stop_event.is_set():
|
||||
self._start_ws()
|
||||
|
||||
self._ws = websocket.WebSocketApp(
|
||||
ws_url,
|
||||
on_open=_on_open,
|
||||
on_message=_on_message,
|
||||
on_error=_on_error,
|
||||
on_close=_on_close,
|
||||
)
|
||||
|
||||
def run_forever():
|
||||
try:
|
||||
websocket_app_run_forever(self._ws, ping_interval=0, reconnect=0)
|
||||
except (SystemExit, KeyboardInterrupt):
|
||||
logger.info("[QQ] WebSocket thread interrupted")
|
||||
except Exception as e:
|
||||
logger.error(f"[QQ] WebSocket run_forever error: {e}")
|
||||
|
||||
self._ws_thread = threading.Thread(target=run_forever, daemon=True)
|
||||
self._ws_thread.start()
|
||||
self._ws_thread.join()
|
||||
|
||||
def _ws_send(self, data: dict):
|
||||
if self._ws:
|
||||
self._ws.send(json.dumps(data, ensure_ascii=False))
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Identify & Resume & Heartbeat
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _send_identify(self):
|
||||
self._ws_send({
|
||||
"op": OP_IDENTIFY,
|
||||
"d": {
|
||||
"token": f"QQBot {self._get_access_token()}",
|
||||
"intents": DEFAULT_INTENTS,
|
||||
"shard": [0, 1],
|
||||
"properties": {
|
||||
"$os": "linux",
|
||||
"$browser": "chatgpt-on-wechat",
|
||||
"$device": "chatgpt-on-wechat",
|
||||
},
|
||||
},
|
||||
})
|
||||
logger.debug(f"[QQ] Identify sent with intents={DEFAULT_INTENTS}")
|
||||
|
||||
def _send_resume(self):
|
||||
self._ws_send({
|
||||
"op": OP_RESUME,
|
||||
"d": {
|
||||
"token": f"QQBot {self._get_access_token()}",
|
||||
"session_id": self._session_id,
|
||||
"seq": self._last_seq,
|
||||
},
|
||||
})
|
||||
logger.debug(f"[QQ] Resume sent: session_id={self._session_id}, seq={self._last_seq}")
|
||||
|
||||
def _start_heartbeat(self, interval_ms: int):
|
||||
if self._heartbeat_thread and self._heartbeat_thread.is_alive():
|
||||
return
|
||||
self._heartbeat_interval = interval_ms
|
||||
interval_sec = interval_ms / 1000.0
|
||||
|
||||
def heartbeat_loop():
|
||||
while not self._stop_event.is_set() and self._connected:
|
||||
try:
|
||||
self._ws_send({
|
||||
"op": OP_HEARTBEAT,
|
||||
"d": self._last_seq,
|
||||
})
|
||||
except Exception as e:
|
||||
logger.warning(f"[QQ] Heartbeat send failed: {e}")
|
||||
break
|
||||
self._stop_event.wait(interval_sec)
|
||||
|
||||
self._heartbeat_thread = threading.Thread(target=heartbeat_loop, daemon=True)
|
||||
self._heartbeat_thread.start()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Incoming message dispatch
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _handle_ws_message(self, data: dict):
|
||||
op = data.get("op")
|
||||
d = data.get("d")
|
||||
t = data.get("t")
|
||||
s = data.get("s")
|
||||
|
||||
if s is not None:
|
||||
self._last_seq = s
|
||||
|
||||
if op == OP_HELLO:
|
||||
heartbeat_interval = d.get("heartbeat_interval", 45000) if d else 45000
|
||||
logger.debug(f"[QQ] Received Hello, heartbeat_interval={heartbeat_interval}ms")
|
||||
self._heartbeat_interval = heartbeat_interval
|
||||
if self._can_resume and self._session_id:
|
||||
self._send_resume()
|
||||
else:
|
||||
self._send_identify()
|
||||
|
||||
elif op == OP_HEARTBEAT_ACK:
|
||||
pass
|
||||
|
||||
elif op == OP_HEARTBEAT:
|
||||
self._ws_send({"op": OP_HEARTBEAT, "d": self._last_seq})
|
||||
|
||||
elif op == OP_RECONNECT:
|
||||
logger.warning("[QQ] Server requested reconnect")
|
||||
self._can_resume = True
|
||||
if self._ws:
|
||||
self._ws.close()
|
||||
|
||||
elif op == OP_INVALID_SESSION:
|
||||
logger.warning("[QQ] Invalid session, re-identifying...")
|
||||
self._session_id = None
|
||||
self._can_resume = False
|
||||
time.sleep(2)
|
||||
self._send_identify()
|
||||
|
||||
elif op == OP_DISPATCH:
|
||||
if t == "READY":
|
||||
self._session_id = d.get("session_id", "")
|
||||
user = d.get("user", {})
|
||||
bot_name = user.get('username', '')
|
||||
logger.info(f"[QQ] ✅ Connected successfully (bot={bot_name})")
|
||||
self._connected = True
|
||||
self._can_resume = False
|
||||
self._start_heartbeat(self._heartbeat_interval)
|
||||
self.report_startup_success()
|
||||
|
||||
elif t == "RESUMED":
|
||||
logger.info("[QQ] Session resumed successfully")
|
||||
self._connected = True
|
||||
self._can_resume = False
|
||||
self._start_heartbeat(self._heartbeat_interval)
|
||||
|
||||
elif t in ("GROUP_AT_MESSAGE_CREATE", "C2C_MESSAGE_CREATE",
|
||||
"AT_MESSAGE_CREATE", "DIRECT_MESSAGE_CREATE"):
|
||||
self._handle_msg_event(d, t)
|
||||
|
||||
elif t in ("GROUP_ADD_ROBOT", "FRIEND_ADD"):
|
||||
logger.info(f"[QQ] Event: {t}")
|
||||
|
||||
else:
|
||||
logger.debug(f"[QQ] Dispatch event: {t}")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Message event handling
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _handle_msg_event(self, event_data: dict, event_type: str):
|
||||
msg_id = event_data.get("id", "")
|
||||
if self.received_msgs.get(msg_id):
|
||||
logger.debug(f"[QQ] Duplicate msg filtered: {msg_id}")
|
||||
return
|
||||
self.received_msgs[msg_id] = True
|
||||
|
||||
try:
|
||||
qq_msg = QQMessage(event_data, event_type)
|
||||
except NotImplementedError as e:
|
||||
logger.warning(f"[QQ] {e}")
|
||||
return
|
||||
except Exception as e:
|
||||
logger.error(f"[QQ] Failed to parse message: {e}", exc_info=True)
|
||||
return
|
||||
|
||||
is_group = qq_msg.is_group
|
||||
|
||||
from channel.file_cache import get_file_cache
|
||||
file_cache = get_file_cache()
|
||||
|
||||
if is_group:
|
||||
session_id = qq_msg.other_user_id
|
||||
else:
|
||||
session_id = qq_msg.from_user_id
|
||||
|
||||
if qq_msg.ctype == ContextType.IMAGE:
|
||||
if hasattr(qq_msg, "image_path") and qq_msg.image_path:
|
||||
file_cache.add(session_id, qq_msg.image_path, file_type="image")
|
||||
logger.info(f"[QQ] Image cached for session {session_id}")
|
||||
return
|
||||
|
||||
if qq_msg.ctype == ContextType.TEXT:
|
||||
cached_files = file_cache.get(session_id)
|
||||
if cached_files:
|
||||
file_refs = []
|
||||
for fi in cached_files:
|
||||
ftype = fi["type"]
|
||||
fpath = fi["path"]
|
||||
if ftype == "image":
|
||||
file_refs.append(f"[图片: {fpath}]")
|
||||
elif ftype == "video":
|
||||
file_refs.append(f"[视频: {fpath}]")
|
||||
else:
|
||||
file_refs.append(f"[文件: {fpath}]")
|
||||
qq_msg.content = qq_msg.content + "\n" + "\n".join(file_refs)
|
||||
logger.info(f"[QQ] Attached {len(cached_files)} cached file(s)")
|
||||
file_cache.clear(session_id)
|
||||
|
||||
context = self._compose_context(
|
||||
qq_msg.ctype,
|
||||
qq_msg.content,
|
||||
isgroup=is_group,
|
||||
msg=qq_msg,
|
||||
no_need_at=True,
|
||||
)
|
||||
if context:
|
||||
self.produce(context)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# _compose_context
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _compose_context(self, ctype: ContextType, content, **kwargs):
|
||||
context = Context(ctype, content)
|
||||
context.kwargs = kwargs
|
||||
if "channel_type" not in context:
|
||||
context["channel_type"] = self.channel_type
|
||||
if "origin_ctype" not in context:
|
||||
context["origin_ctype"] = ctype
|
||||
|
||||
cmsg = context["msg"]
|
||||
|
||||
if cmsg.is_group:
|
||||
context["session_id"] = cmsg.other_user_id
|
||||
else:
|
||||
context["session_id"] = cmsg.from_user_id
|
||||
|
||||
context["receiver"] = cmsg.other_user_id
|
||||
|
||||
if ctype == ContextType.TEXT:
|
||||
img_match_prefix = check_prefix(content, conf().get("image_create_prefix"))
|
||||
if img_match_prefix:
|
||||
content = content.replace(img_match_prefix, "", 1)
|
||||
context.type = ContextType.IMAGE_CREATE
|
||||
else:
|
||||
context.type = ContextType.TEXT
|
||||
context.content = content.strip()
|
||||
|
||||
return context
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Send reply
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def send(self, reply: Reply, context: Context):
|
||||
msg = context.get("msg")
|
||||
is_group = context.get("isgroup", False)
|
||||
receiver = context.get("receiver", "")
|
||||
|
||||
if not msg:
|
||||
# Active send (e.g. scheduled tasks), no original message to reply to
|
||||
self._active_send_text(reply.content if reply.type == ReplyType.TEXT else str(reply.content),
|
||||
receiver, is_group)
|
||||
return
|
||||
|
||||
event_type = getattr(msg, "event_type", "")
|
||||
msg_id = getattr(msg, "msg_id", "")
|
||||
|
||||
if reply.type == ReplyType.TEXT:
|
||||
self._send_text(reply.content, msg, event_type, msg_id)
|
||||
elif reply.type in (ReplyType.IMAGE_URL, ReplyType.IMAGE):
|
||||
self._send_image(reply.content, msg, event_type, msg_id)
|
||||
elif reply.type == ReplyType.FILE:
|
||||
if hasattr(reply, "text_content") and reply.text_content:
|
||||
self._send_text(reply.text_content, msg, event_type, msg_id)
|
||||
time.sleep(0.3)
|
||||
self._send_file(reply.content, msg, event_type, msg_id)
|
||||
elif reply.type in (ReplyType.VIDEO, ReplyType.VIDEO_URL):
|
||||
self._send_media(reply.content, msg, event_type, msg_id, QQ_FILE_TYPE_VIDEO)
|
||||
else:
|
||||
logger.warning(f"[QQ] Unsupported reply type: {reply.type}, falling back to text")
|
||||
self._send_text(str(reply.content), msg, event_type, msg_id)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Send helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _get_next_msg_seq(self, msg_id: str) -> int:
|
||||
seq = self._msg_seq_counter.get(msg_id, 1)
|
||||
self._msg_seq_counter[msg_id] = seq + 1
|
||||
return seq
|
||||
|
||||
def _build_msg_url_and_base_body(self, msg: QQMessage, event_type: str, msg_id: str):
|
||||
"""Build the API URL and base body dict for sending a message."""
|
||||
if event_type == "GROUP_AT_MESSAGE_CREATE":
|
||||
group_openid = msg._rawmsg.get("group_openid", "")
|
||||
url = f"{QQ_API_BASE}/v2/groups/{group_openid}/messages"
|
||||
body = {
|
||||
"msg_id": msg_id,
|
||||
"msg_seq": self._get_next_msg_seq(msg_id),
|
||||
}
|
||||
return url, body, "group", group_openid
|
||||
|
||||
elif event_type == "C2C_MESSAGE_CREATE":
|
||||
user_openid = msg._rawmsg.get("author", {}).get("user_openid", "") or msg.from_user_id
|
||||
url = f"{QQ_API_BASE}/v2/users/{user_openid}/messages"
|
||||
body = {
|
||||
"msg_id": msg_id,
|
||||
"msg_seq": self._get_next_msg_seq(msg_id),
|
||||
}
|
||||
return url, body, "c2c", user_openid
|
||||
|
||||
elif event_type == "AT_MESSAGE_CREATE":
|
||||
channel_id = msg._rawmsg.get("channel_id", "")
|
||||
url = f"{QQ_API_BASE}/channels/{channel_id}/messages"
|
||||
body = {"msg_id": msg_id}
|
||||
return url, body, "channel", channel_id
|
||||
|
||||
elif event_type == "DIRECT_MESSAGE_CREATE":
|
||||
guild_id = msg._rawmsg.get("guild_id", "")
|
||||
url = f"{QQ_API_BASE}/dms/{guild_id}/messages"
|
||||
body = {"msg_id": msg_id}
|
||||
return url, body, "dm", guild_id
|
||||
|
||||
return None, None, None, None
|
||||
|
||||
def _post_message(self, url: str, body: dict, event_type: str):
|
||||
try:
|
||||
resp = requests.post(url, json=body, headers=self._get_auth_headers(), timeout=10)
|
||||
if resp.status_code in (200, 201, 202, 204):
|
||||
logger.info(f"[QQ] Message sent successfully: event_type={event_type}")
|
||||
else:
|
||||
logger.error(f"[QQ] Failed to send message: status={resp.status_code}, "
|
||||
f"body={resp.text}")
|
||||
except Exception as e:
|
||||
logger.error(f"[QQ] Send message error: {e}")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Active send (no original message, e.g. scheduled tasks)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _active_send_text(self, content: str, receiver: str, is_group: bool):
|
||||
"""Send text without an original message (active push). QQ limits active messages to 4/month per user."""
|
||||
if not receiver:
|
||||
logger.warning("[QQ] No receiver for active send")
|
||||
return
|
||||
if is_group:
|
||||
url = f"{QQ_API_BASE}/v2/groups/{receiver}/messages"
|
||||
else:
|
||||
url = f"{QQ_API_BASE}/v2/users/{receiver}/messages"
|
||||
body = {
|
||||
"content": content,
|
||||
"msg_type": 0,
|
||||
}
|
||||
event_label = "GROUP_ACTIVE" if is_group else "C2C_ACTIVE"
|
||||
self._post_message(url, body, event_label)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Send text
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _send_text(self, content: str, msg: QQMessage, event_type: str, msg_id: str):
|
||||
url, body, _, _ = self._build_msg_url_and_base_body(msg, event_type, msg_id)
|
||||
if not url:
|
||||
logger.warning(f"[QQ] Cannot send reply for event_type: {event_type}")
|
||||
return
|
||||
body["content"] = content
|
||||
body["msg_type"] = 0
|
||||
self._post_message(url, body, event_type)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Rich media upload & send (image / video / file)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _upload_rich_media(self, file_url: str, file_type: int, msg: QQMessage,
|
||||
event_type: str) -> str:
|
||||
"""
|
||||
Upload media via QQ rich media API and return file_info.
|
||||
For group: POST /v2/groups/{group_openid}/files
|
||||
For c2c: POST /v2/users/{openid}/files
|
||||
"""
|
||||
if event_type == "GROUP_AT_MESSAGE_CREATE":
|
||||
group_openid = msg._rawmsg.get("group_openid", "")
|
||||
upload_url = f"{QQ_API_BASE}/v2/groups/{group_openid}/files"
|
||||
elif event_type == "C2C_MESSAGE_CREATE":
|
||||
user_openid = (msg._rawmsg.get("author", {}).get("user_openid", "")
|
||||
or msg.from_user_id)
|
||||
upload_url = f"{QQ_API_BASE}/v2/users/{user_openid}/files"
|
||||
else:
|
||||
logger.warning(f"[QQ] Rich media upload not supported for event_type: {event_type}")
|
||||
return ""
|
||||
|
||||
upload_body = {
|
||||
"file_type": file_type,
|
||||
"url": file_url,
|
||||
"srv_send_msg": False,
|
||||
}
|
||||
|
||||
try:
|
||||
resp = requests.post(
|
||||
upload_url, json=upload_body,
|
||||
headers=self._get_auth_headers(), timeout=30,
|
||||
)
|
||||
if resp.status_code in (200, 201):
|
||||
data = resp.json()
|
||||
file_info = data.get("file_info", "")
|
||||
logger.info(f"[QQ] Rich media uploaded: file_type={file_type}, "
|
||||
f"file_uuid={data.get('file_uuid', '')}")
|
||||
return file_info
|
||||
else:
|
||||
logger.error(f"[QQ] Rich media upload failed: status={resp.status_code}, "
|
||||
f"body={resp.text}")
|
||||
return ""
|
||||
except Exception as e:
|
||||
logger.error(f"[QQ] Rich media upload error: {e}")
|
||||
return ""
|
||||
|
||||
def _upload_rich_media_base64(self, file_path: str, file_type: int, msg: QQMessage,
|
||||
event_type: str) -> str:
|
||||
"""Upload local file via base64 file_data field."""
|
||||
if event_type == "GROUP_AT_MESSAGE_CREATE":
|
||||
group_openid = msg._rawmsg.get("group_openid", "")
|
||||
upload_url = f"{QQ_API_BASE}/v2/groups/{group_openid}/files"
|
||||
elif event_type == "C2C_MESSAGE_CREATE":
|
||||
user_openid = (msg._rawmsg.get("author", {}).get("user_openid", "")
|
||||
or msg.from_user_id)
|
||||
upload_url = f"{QQ_API_BASE}/v2/users/{user_openid}/files"
|
||||
else:
|
||||
logger.warning(f"[QQ] Rich media upload not supported for event_type: {event_type}")
|
||||
return ""
|
||||
|
||||
try:
|
||||
with open(file_path, "rb") as f:
|
||||
file_data = base64.b64encode(f.read()).decode("utf-8")
|
||||
except Exception as e:
|
||||
logger.error(f"[QQ] Failed to read file for upload: {e}")
|
||||
return ""
|
||||
|
||||
upload_body = {
|
||||
"file_type": file_type,
|
||||
"file_data": file_data,
|
||||
"srv_send_msg": False,
|
||||
}
|
||||
|
||||
try:
|
||||
resp = requests.post(
|
||||
upload_url, json=upload_body,
|
||||
headers=self._get_auth_headers(), timeout=30,
|
||||
)
|
||||
if resp.status_code in (200, 201):
|
||||
data = resp.json()
|
||||
file_info = data.get("file_info", "")
|
||||
logger.info(f"[QQ] Rich media uploaded (base64): file_type={file_type}, "
|
||||
f"file_uuid={data.get('file_uuid', '')}")
|
||||
return file_info
|
||||
else:
|
||||
logger.error(f"[QQ] Rich media upload (base64) failed: status={resp.status_code}, "
|
||||
f"body={resp.text}")
|
||||
return ""
|
||||
except Exception as e:
|
||||
logger.error(f"[QQ] Rich media upload (base64) error: {e}")
|
||||
return ""
|
||||
|
||||
def _send_media_msg(self, file_info: str, msg: QQMessage, event_type: str, msg_id: str):
|
||||
"""Send a message with msg_type=7 (rich media) using file_info."""
|
||||
url, body, _, _ = self._build_msg_url_and_base_body(msg, event_type, msg_id)
|
||||
if not url:
|
||||
return
|
||||
body["msg_type"] = 7
|
||||
body["media"] = {"file_info": file_info}
|
||||
self._post_message(url, body, event_type)
|
||||
|
||||
def _send_image(self, img_path_or_url: str, msg: QQMessage, event_type: str, msg_id: str):
|
||||
"""Send image reply. Supports URL and local file path."""
|
||||
if event_type not in ("GROUP_AT_MESSAGE_CREATE", "C2C_MESSAGE_CREATE"):
|
||||
self._send_text(str(img_path_or_url), msg, event_type, msg_id)
|
||||
return
|
||||
|
||||
if img_path_or_url.startswith("file://"):
|
||||
img_path_or_url = img_path_or_url[7:]
|
||||
|
||||
if img_path_or_url.startswith(("http://", "https://")):
|
||||
file_info = self._upload_rich_media(
|
||||
img_path_or_url, QQ_FILE_TYPE_IMAGE, msg, event_type)
|
||||
elif os.path.exists(img_path_or_url):
|
||||
file_info = self._upload_rich_media_base64(
|
||||
img_path_or_url, QQ_FILE_TYPE_IMAGE, msg, event_type)
|
||||
else:
|
||||
logger.error(f"[QQ] Image not found: {img_path_or_url}")
|
||||
self._send_text("[Image send failed]", msg, event_type, msg_id)
|
||||
return
|
||||
|
||||
if file_info:
|
||||
self._send_media_msg(file_info, msg, event_type, msg_id)
|
||||
else:
|
||||
self._send_text("[Image upload failed]", msg, event_type, msg_id)
|
||||
|
||||
def _send_file(self, file_path_or_url: str, msg: QQMessage, event_type: str, msg_id: str):
|
||||
"""Send file reply."""
|
||||
if event_type not in ("GROUP_AT_MESSAGE_CREATE", "C2C_MESSAGE_CREATE"):
|
||||
self._send_text(str(file_path_or_url), msg, event_type, msg_id)
|
||||
return
|
||||
|
||||
if file_path_or_url.startswith("file://"):
|
||||
file_path_or_url = file_path_or_url[7:]
|
||||
|
||||
if file_path_or_url.startswith(("http://", "https://")):
|
||||
file_info = self._upload_rich_media(
|
||||
file_path_or_url, QQ_FILE_TYPE_FILE, msg, event_type)
|
||||
elif os.path.exists(file_path_or_url):
|
||||
file_info = self._upload_rich_media_base64(
|
||||
file_path_or_url, QQ_FILE_TYPE_FILE, msg, event_type)
|
||||
else:
|
||||
logger.error(f"[QQ] File not found: {file_path_or_url}")
|
||||
self._send_text("[File send failed]", msg, event_type, msg_id)
|
||||
return
|
||||
|
||||
if file_info:
|
||||
self._send_media_msg(file_info, msg, event_type, msg_id)
|
||||
else:
|
||||
self._send_text("[File upload failed]", msg, event_type, msg_id)
|
||||
|
||||
def _send_media(self, path_or_url: str, msg: QQMessage, event_type: str,
|
||||
msg_id: str, file_type: int):
|
||||
"""Generic media send for video/voice etc."""
|
||||
if event_type not in ("GROUP_AT_MESSAGE_CREATE", "C2C_MESSAGE_CREATE"):
|
||||
self._send_text(str(path_or_url), msg, event_type, msg_id)
|
||||
return
|
||||
|
||||
if path_or_url.startswith("file://"):
|
||||
path_or_url = path_or_url[7:]
|
||||
|
||||
if path_or_url.startswith(("http://", "https://")):
|
||||
file_info = self._upload_rich_media(path_or_url, file_type, msg, event_type)
|
||||
elif os.path.exists(path_or_url):
|
||||
file_info = self._upload_rich_media_base64(path_or_url, file_type, msg, event_type)
|
||||
else:
|
||||
logger.error(f"[QQ] Media not found: {path_or_url}")
|
||||
return
|
||||
|
||||
if file_info:
|
||||
self._send_media_msg(file_info, msg, event_type, msg_id)
|
||||
else:
|
||||
logger.error(f"[QQ] Media upload failed: {path_or_url}")
|
||||
123
channel/qq/qq_message.py
Normal file
123
channel/qq/qq_message.py
Normal file
@@ -0,0 +1,123 @@
|
||||
import os
|
||||
import requests
|
||||
|
||||
from bridge.context import ContextType
|
||||
from channel.chat_message import ChatMessage
|
||||
from common.log import logger
|
||||
from common.utils import expand_path
|
||||
from config import conf
|
||||
|
||||
|
||||
def _get_tmp_dir() -> str:
|
||||
"""Return the workspace tmp directory (absolute path), creating it if needed."""
|
||||
ws_root = expand_path(conf().get("agent_workspace", "~/cow"))
|
||||
tmp_dir = os.path.join(ws_root, "tmp")
|
||||
os.makedirs(tmp_dir, exist_ok=True)
|
||||
return tmp_dir
|
||||
|
||||
|
||||
class QQMessage(ChatMessage):
|
||||
"""Message wrapper for QQ Bot (websocket long-connection mode)."""
|
||||
|
||||
def __init__(self, event_data: dict, event_type: str):
|
||||
super().__init__(event_data)
|
||||
self.msg_id = event_data.get("id", "")
|
||||
self.create_time = event_data.get("timestamp", "")
|
||||
self.is_group = event_type in ("GROUP_AT_MESSAGE_CREATE",)
|
||||
self.event_type = event_type
|
||||
|
||||
author = event_data.get("author", {})
|
||||
from_user_id = author.get("member_openid", "") or author.get("id", "")
|
||||
group_openid = event_data.get("group_openid", "")
|
||||
|
||||
content = event_data.get("content", "").strip()
|
||||
|
||||
attachments = event_data.get("attachments", [])
|
||||
has_image = any(
|
||||
a.get("content_type", "").startswith("image/") for a in attachments
|
||||
) if attachments else False
|
||||
|
||||
if has_image and not content:
|
||||
self.ctype = ContextType.IMAGE
|
||||
img_attachment = next(
|
||||
a for a in attachments if a.get("content_type", "").startswith("image/")
|
||||
)
|
||||
img_url = img_attachment.get("url", "")
|
||||
if img_url and not img_url.startswith("http"):
|
||||
img_url = "https://" + img_url
|
||||
tmp_dir = _get_tmp_dir()
|
||||
image_path = os.path.join(tmp_dir, f"qq_{self.msg_id}.png")
|
||||
try:
|
||||
resp = requests.get(img_url, timeout=30)
|
||||
resp.raise_for_status()
|
||||
with open(image_path, "wb") as f:
|
||||
f.write(resp.content)
|
||||
self.content = image_path
|
||||
self.image_path = image_path
|
||||
logger.info(f"[QQ] Image downloaded: {image_path}")
|
||||
except Exception as e:
|
||||
logger.error(f"[QQ] Failed to download image: {e}")
|
||||
self.content = "[Image download failed]"
|
||||
self.image_path = None
|
||||
elif has_image and content:
|
||||
self.ctype = ContextType.TEXT
|
||||
image_paths = []
|
||||
tmp_dir = _get_tmp_dir()
|
||||
for idx, att in enumerate(attachments):
|
||||
if not att.get("content_type", "").startswith("image/"):
|
||||
continue
|
||||
img_url = att.get("url", "")
|
||||
if img_url and not img_url.startswith("http"):
|
||||
img_url = "https://" + img_url
|
||||
img_path = os.path.join(tmp_dir, f"qq_{self.msg_id}_{idx}.png")
|
||||
try:
|
||||
resp = requests.get(img_url, timeout=30)
|
||||
resp.raise_for_status()
|
||||
with open(img_path, "wb") as f:
|
||||
f.write(resp.content)
|
||||
image_paths.append(img_path)
|
||||
except Exception as e:
|
||||
logger.error(f"[QQ] Failed to download mixed image: {e}")
|
||||
content_parts = [content]
|
||||
for p in image_paths:
|
||||
content_parts.append(f"[图片: {p}]")
|
||||
self.content = "\n".join(content_parts)
|
||||
else:
|
||||
self.ctype = ContextType.TEXT
|
||||
self.content = content
|
||||
|
||||
if event_type == "GROUP_AT_MESSAGE_CREATE":
|
||||
self.from_user_id = from_user_id
|
||||
self.to_user_id = ""
|
||||
self.other_user_id = group_openid
|
||||
self.actual_user_id = from_user_id
|
||||
self.actual_user_nickname = from_user_id
|
||||
|
||||
elif event_type == "C2C_MESSAGE_CREATE":
|
||||
user_openid = author.get("user_openid", "") or from_user_id
|
||||
self.from_user_id = user_openid
|
||||
self.to_user_id = ""
|
||||
self.other_user_id = user_openid
|
||||
self.actual_user_id = user_openid
|
||||
|
||||
elif event_type == "AT_MESSAGE_CREATE":
|
||||
self.from_user_id = from_user_id
|
||||
self.to_user_id = ""
|
||||
channel_id = event_data.get("channel_id", "")
|
||||
self.other_user_id = channel_id
|
||||
self.actual_user_id = from_user_id
|
||||
self.actual_user_nickname = author.get("username", from_user_id)
|
||||
|
||||
elif event_type == "DIRECT_MESSAGE_CREATE":
|
||||
self.from_user_id = from_user_id
|
||||
self.to_user_id = ""
|
||||
guild_id = event_data.get("guild_id", "")
|
||||
self.other_user_id = f"dm_{guild_id}_{from_user_id}"
|
||||
self.actual_user_id = from_user_id
|
||||
self.actual_user_nickname = author.get("username", from_user_id)
|
||||
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported QQ event type: {event_type}")
|
||||
|
||||
logger.debug(f"[QQ] Message parsed: type={event_type}, ctype={self.ctype}, "
|
||||
f"from={self.from_user_id}, content_len={len(self.content)}")
|
||||
1
channel/slack/__init__.py
Normal file
1
channel/slack/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
|
||||
506
channel/slack/slack_channel.py
Normal file
506
channel/slack/slack_channel.py
Normal file
@@ -0,0 +1,506 @@
|
||||
"""
|
||||
Slack channel via Bolt for Python (Socket Mode).
|
||||
|
||||
Features:
|
||||
- Direct message & channel chat (text / image / file)
|
||||
- Channel trigger: @mention or reply in a thread the bot is in (configurable)
|
||||
- /cancel fast-path matches Web channel behaviour
|
||||
- Socket Mode: no public IP / callback URL required, works behind NAT
|
||||
|
||||
Implementation note:
|
||||
slack_bolt's SocketModeHandler is blocking and runs its own background
|
||||
threads. We start it in a dedicated thread so the rest of cow (sync) stays
|
||||
untouched. Inbound events are dispatched onto cow's existing sync
|
||||
ChatChannel.produce() pipeline; outbound send() calls the Slack Web API
|
||||
client directly (it is sync-safe).
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
import threading
|
||||
|
||||
import requests
|
||||
|
||||
from bridge.context import Context, ContextType
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from channel.chat_channel import ChatChannel, check_prefix
|
||||
from channel.slack.slack_message import SlackMessage
|
||||
from common.expired_dict import ExpiredDict
|
||||
from common.log import logger
|
||||
from common.singleton import singleton
|
||||
from config import conf
|
||||
|
||||
|
||||
@singleton
|
||||
class SlackChannel(ChatChannel):
|
||||
NOT_SUPPORT_REPLYTYPE = []
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.bot_token = ""
|
||||
self.app_token = ""
|
||||
self.bot_user_id = "" # used to strip @mention and ignore self messages
|
||||
self._app = None
|
||||
self._handler = None
|
||||
self._client = None
|
||||
self._loop_thread = None
|
||||
# Idempotent dedup; Slack retries event delivery on slow ack
|
||||
self._received_msgs = ExpiredDict(60 * 60 * 1)
|
||||
|
||||
# Disable group whitelist / prefix checks (we handle triggering ourselves
|
||||
# in _should_reply_in_channel), aligned with telegram / feishu channels.
|
||||
conf()["group_name_white_list"] = ["ALL_GROUP"]
|
||||
conf()["single_chat_prefix"] = [""]
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Lifecycle
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def startup(self):
|
||||
self.bot_token = conf().get("slack_bot_token", "")
|
||||
self.app_token = conf().get("slack_app_token", "")
|
||||
if not self.bot_token or not self.app_token:
|
||||
err = "[Slack] slack_bot_token and slack_app_token are both required"
|
||||
logger.error(err)
|
||||
self.report_startup_error(err)
|
||||
return
|
||||
|
||||
# Guard against the common mistake of swapping the two tokens:
|
||||
# bot token must start with xoxb-, app-level token with xapp-.
|
||||
if not self.bot_token.startswith("xoxb-") or not self.app_token.startswith("xapp-"):
|
||||
err = (
|
||||
"[Slack] token type mismatch: slack_bot_token must start with 'xoxb-' "
|
||||
"and slack_app_token must start with 'xapp-' (they look swapped)"
|
||||
)
|
||||
logger.error(err)
|
||||
self.report_startup_error(err)
|
||||
return
|
||||
|
||||
try:
|
||||
from slack_bolt import App
|
||||
from slack_bolt.adapter.socket_mode import SocketModeHandler
|
||||
except ImportError:
|
||||
err = (
|
||||
"[Slack] slack_bolt is not installed. "
|
||||
"Run: pip install slack_bolt"
|
||||
)
|
||||
logger.error(err)
|
||||
self.report_startup_error(err)
|
||||
return
|
||||
|
||||
try:
|
||||
self._app = App(token=self.bot_token)
|
||||
self._client = self._app.client
|
||||
|
||||
# Resolve our own bot user id (needed for @mention strip / self-ignore)
|
||||
auth = self._client.auth_test()
|
||||
self.bot_user_id = auth.get("user_id", "")
|
||||
self.name = self.bot_user_id # ChatChannel uses self.name to strip @-mention
|
||||
logger.info(f"[Slack] Bot logged in as user_id={self.bot_user_id}, team={auth.get('team')}")
|
||||
except Exception as e:
|
||||
err = f"[Slack] auth_test failed: {e}"
|
||||
logger.error(err)
|
||||
self.report_startup_error(err)
|
||||
return
|
||||
|
||||
self._register_handlers()
|
||||
|
||||
self._handler = SocketModeHandler(self._app, self.app_token)
|
||||
|
||||
def _run():
|
||||
try:
|
||||
logger.info("[Slack] Starting Socket Mode connection...")
|
||||
self.report_startup_success()
|
||||
logger.info("[Slack] ✅ Slack bot ready, listening for events")
|
||||
self._handler.start()
|
||||
except Exception as e:
|
||||
logger.error(f"[Slack] socket mode crashed: {e}", exc_info=True)
|
||||
self.report_startup_error(str(e))
|
||||
finally:
|
||||
logger.info("[Slack] socket mode exited")
|
||||
|
||||
self._loop_thread = threading.Thread(target=_run, daemon=True, name="slack-socket")
|
||||
self._loop_thread.start()
|
||||
# Block startup() until the handler thread exits, matching other channels'
|
||||
# behaviour (startup is a blocking call).
|
||||
self._loop_thread.join()
|
||||
|
||||
def _register_handlers(self):
|
||||
app = self._app
|
||||
|
||||
# app_mention: bot is @-mentioned in a channel
|
||||
@app.event("app_mention")
|
||||
def _on_app_mention(event, ack):
|
||||
ack()
|
||||
self._handle_event(event, is_group=True)
|
||||
|
||||
# message: DMs and channel messages (including thread replies)
|
||||
@app.event("message")
|
||||
def _on_message(event, ack):
|
||||
ack()
|
||||
self._handle_message_event(event)
|
||||
|
||||
def stop(self):
|
||||
logger.info("[Slack] stop() called")
|
||||
try:
|
||||
if self._handler is not None:
|
||||
self._handler.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"[Slack] handler close error: {e}")
|
||||
if self._loop_thread and self._loop_thread.is_alive():
|
||||
try:
|
||||
self._loop_thread.join(timeout=10)
|
||||
except Exception:
|
||||
pass
|
||||
logger.info("[Slack] stop() completed")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Inbound: slack event -> ChatMessage -> ChatChannel.produce
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _handle_message_event(self, event: dict):
|
||||
"""Route a raw `message` event: skip bot/system noise, decide grouping."""
|
||||
try:
|
||||
logger.debug(
|
||||
f"[Slack] message event: channel_type={event.get('channel_type')}, "
|
||||
f"subtype={event.get('subtype')}, user={event.get('user')}, "
|
||||
f"ts={event.get('ts')}, thread_ts={event.get('thread_ts')}"
|
||||
)
|
||||
# Ignore bot messages (including our own) and message edits/deletes
|
||||
if event.get("bot_id") or event.get("subtype") in ("bot_message", "message_changed", "message_deleted"):
|
||||
return
|
||||
if event.get("user") == self.bot_user_id:
|
||||
return
|
||||
|
||||
channel_type = event.get("channel_type", "")
|
||||
# DM (im) is single chat; channel/group is group chat. app_mention
|
||||
# already covers channel @-mentions, so for plain channel messages we
|
||||
# only react when configured / thread-following.
|
||||
is_group = channel_type in ("channel", "group", "mpim")
|
||||
if is_group:
|
||||
# app_mention handler covers explicit @bot; here we only handle
|
||||
# follow-up replies in threads the bot participates in.
|
||||
if not self._should_reply_in_channel(event):
|
||||
return
|
||||
self._handle_event(event, is_group=is_group)
|
||||
except Exception as e:
|
||||
logger.error(f"[Slack] _handle_message_event error: {e}", exc_info=True)
|
||||
|
||||
def _handle_event(self, event: dict, is_group: bool):
|
||||
"""Parse event -> build SlackMessage -> produce()."""
|
||||
try:
|
||||
channel_id = event.get("channel", "")
|
||||
ts = event.get("ts", "")
|
||||
if not channel_id:
|
||||
return
|
||||
|
||||
# Idempotent dedup
|
||||
msg_uid = f"{channel_id}:{ts}"
|
||||
if self._received_msgs.get(msg_uid):
|
||||
return
|
||||
self._received_msgs[msg_uid] = True
|
||||
|
||||
# Parse type + download media if needed.
|
||||
ctype, content, caption = self._parse_event(event)
|
||||
if ctype is None:
|
||||
logger.debug(f"[Slack] unsupported message type, skip. event={event}")
|
||||
return
|
||||
|
||||
# Strip <@bot_user_id> mention from channel text
|
||||
if is_group and self.bot_user_id:
|
||||
if ctype == ContextType.TEXT and content:
|
||||
content = self._strip_at_mention(content)
|
||||
if caption:
|
||||
caption = self._strip_at_mention(caption)
|
||||
|
||||
slack_msg = SlackMessage(
|
||||
event,
|
||||
is_group=is_group,
|
||||
bot_user_id=self.bot_user_id,
|
||||
ctype=ctype,
|
||||
content=content,
|
||||
)
|
||||
slack_msg.is_at = is_group # if we reached here in a channel, bot is mentioned/threaded
|
||||
|
||||
from channel.file_cache import get_file_cache
|
||||
file_cache = get_file_cache()
|
||||
session_id = self._compute_session_id(event, is_group)
|
||||
|
||||
# Media + caption together: treat as a complete query and bypass the cache
|
||||
if ctype in (ContextType.IMAGE, ContextType.FILE) and caption:
|
||||
tag = "image" if ctype == ContextType.IMAGE else "file"
|
||||
merged_text = f"{caption}\n[{tag}: {content}]"
|
||||
slack_msg.ctype = ContextType.TEXT
|
||||
slack_msg.content = merged_text
|
||||
ctype = ContextType.TEXT
|
||||
logger.info(f"[Slack] Media+caption merged for session {session_id}")
|
||||
# fallthrough to the TEXT branch below
|
||||
|
||||
elif ctype == ContextType.IMAGE:
|
||||
file_cache.add(session_id, content, file_type="image")
|
||||
logger.info(f"[Slack] Image cached for session {session_id}, waiting for query...")
|
||||
return
|
||||
elif ctype == ContextType.FILE:
|
||||
file_cache.add(session_id, content, file_type="file")
|
||||
logger.info(f"[Slack] File cached for session {session_id}: {content}")
|
||||
return
|
||||
|
||||
if ctype == ContextType.TEXT:
|
||||
# Fast-path: /cancel mirrors Web channel behaviour
|
||||
if (content or "").strip().lower() in ("/cancel", "cancel"):
|
||||
self._do_cancel(session_id, channel_id, event)
|
||||
return
|
||||
|
||||
cached_files = file_cache.get(session_id)
|
||||
if cached_files:
|
||||
refs = []
|
||||
for fi in cached_files:
|
||||
ftype = fi["type"]
|
||||
tag = ftype if ftype in ("image", "video") else "file"
|
||||
refs.append(f"[{tag}: {fi['path']}]")
|
||||
slack_msg.content = (slack_msg.content or "") + "\n" + "\n".join(refs)
|
||||
file_cache.clear(session_id)
|
||||
logger.info(f"[Slack] Attached {len(cached_files)} cached file(s) to query")
|
||||
|
||||
# Reply in the originating thread when present, else start one on this msg
|
||||
thread_ts = event.get("thread_ts") or ts
|
||||
|
||||
context = self._compose_context(
|
||||
slack_msg.ctype,
|
||||
slack_msg.content,
|
||||
isgroup=is_group,
|
||||
msg=slack_msg,
|
||||
# Replies go back into the thread, no manual @mention needed
|
||||
no_need_at=True,
|
||||
)
|
||||
if context:
|
||||
context["session_id"] = session_id
|
||||
context["receiver"] = channel_id
|
||||
context["slack_channel"] = channel_id
|
||||
context["slack_thread_ts"] = thread_ts if is_group else None
|
||||
self.produce(context)
|
||||
logger.debug(f"[Slack] received: type={ctype}, content={str(slack_msg.content)[:80]}")
|
||||
except Exception as e:
|
||||
logger.error(f"[Slack] _handle_event error: {e}", exc_info=True)
|
||||
|
||||
def _do_cancel(self, session_id: str, channel_id: str, event: dict):
|
||||
"""Fast-path: /cancel calls cancel_session directly without going through agent."""
|
||||
try:
|
||||
from agent.protocol import get_cancel_registry
|
||||
cancelled = get_cancel_registry().cancel_session(session_id)
|
||||
text = "Current task cancelled." if cancelled else "No running task to cancel."
|
||||
thread_ts = event.get("thread_ts") or event.get("ts")
|
||||
self._client.chat_postMessage(channel=channel_id, text=text, thread_ts=thread_ts)
|
||||
logger.info(f"[Slack] /cancel session={session_id}, cancelled={cancelled}")
|
||||
except Exception as e:
|
||||
logger.error(f"[Slack] /cancel error: {e}", exc_info=True)
|
||||
|
||||
def _parse_event(self, event: dict):
|
||||
"""Parse a slack event and return (ctype, content, caption).
|
||||
|
||||
- content is text for ContextType.TEXT, otherwise the local file path
|
||||
- caption is the optional text accompanying a file; empty for plain text
|
||||
"""
|
||||
text = (event.get("text") or "").strip()
|
||||
files = event.get("files") or []
|
||||
|
||||
if files:
|
||||
# Handle the first attachment; caption is the accompanying message text
|
||||
f = files[0]
|
||||
mimetype = (f.get("mimetype") or "").lower()
|
||||
url = f.get("url_private_download") or f.get("url_private")
|
||||
name = f.get("name") or f.get("id") or "file"
|
||||
if not url:
|
||||
return (None, None, "")
|
||||
path = self._download_file(url, name)
|
||||
if not path:
|
||||
return (None, None, "")
|
||||
if mimetype.startswith("image/"):
|
||||
return (ContextType.IMAGE, path, text)
|
||||
return (ContextType.FILE, path, text)
|
||||
|
||||
if text:
|
||||
return (ContextType.TEXT, text, "")
|
||||
|
||||
return (None, None, "")
|
||||
|
||||
def _download_file(self, url: str, name: str):
|
||||
"""Download a Slack private file (requires bot token auth) to local tmp dir."""
|
||||
try:
|
||||
headers = {"Authorization": f"Bearer {self.bot_token}"}
|
||||
resp = requests.get(url, headers=headers, timeout=60, stream=True)
|
||||
resp.raise_for_status()
|
||||
tmp_dir = SlackMessage.get_tmp_dir()
|
||||
# Sanitize the name and keep it unique-ish via the url tail
|
||||
safe_name = re.sub(r"[^\w.\-]", "_", name)
|
||||
local_path = os.path.join(tmp_dir, safe_name)
|
||||
with open(local_path, "wb") as fp:
|
||||
for chunk in resp.iter_content(chunk_size=8192):
|
||||
if chunk:
|
||||
fp.write(chunk)
|
||||
logger.debug(f"[Slack] downloaded {name} -> {local_path}")
|
||||
return local_path
|
||||
except Exception as e:
|
||||
logger.error(f"[Slack] download_file failed ({name}): {e}")
|
||||
return None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Channel trigger logic
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _should_reply_in_channel(self, event: dict) -> bool:
|
||||
"""Decide whether to reply to a plain channel message (no @mention).
|
||||
|
||||
app_mention already handles explicit @bot, so here we only deal with
|
||||
follow-up messages. `all` replies to every message; `mention_or_reply`
|
||||
replies inside threads the bot already participates in.
|
||||
"""
|
||||
mode = conf().get("slack_group_trigger", "mention_or_reply")
|
||||
if mode == "all":
|
||||
return True
|
||||
if mode == "mention_only":
|
||||
return False
|
||||
# mention_or_reply: follow up only within an existing thread
|
||||
return bool(event.get("thread_ts"))
|
||||
|
||||
def _strip_at_mention(self, content: str) -> str:
|
||||
"""Strip <@BOT_USER_ID> from channel text."""
|
||||
if not content or not self.bot_user_id:
|
||||
return content
|
||||
pattern = re.compile(r"<@" + re.escape(self.bot_user_id) + r">", re.IGNORECASE)
|
||||
return pattern.sub("", content).strip()
|
||||
|
||||
@staticmethod
|
||||
def _compute_session_id(event: dict, is_group: bool) -> str:
|
||||
channel_id = event.get("channel", "")
|
||||
user_id = event.get("user", "")
|
||||
if is_group:
|
||||
if conf().get("group_shared_session", True):
|
||||
return f"slack_channel_{channel_id}"
|
||||
return f"slack_channel_{channel_id}_{user_id}"
|
||||
return f"slack_user_{user_id}"
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Override _compose_context: skip the parent's group whitelist/at checks
|
||||
# (already handled via _should_reply_in_channel). Same idea as telegram.
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _compose_context(self, ctype: ContextType, content, **kwargs):
|
||||
context = Context(ctype, content)
|
||||
context.kwargs = kwargs
|
||||
if "channel_type" not in context:
|
||||
context["channel_type"] = self.channel_type
|
||||
if "origin_ctype" not in context:
|
||||
context["origin_ctype"] = ctype
|
||||
|
||||
cmsg = context["msg"]
|
||||
if cmsg.is_group:
|
||||
if conf().get("group_shared_session", True):
|
||||
context["session_id"] = cmsg.other_user_id
|
||||
else:
|
||||
context["session_id"] = f"{cmsg.from_user_id}:{cmsg.other_user_id}"
|
||||
else:
|
||||
context["session_id"] = cmsg.from_user_id
|
||||
context["receiver"] = cmsg.other_user_id
|
||||
|
||||
if ctype == ContextType.TEXT:
|
||||
img_match_prefix = check_prefix(content, conf().get("image_create_prefix"))
|
||||
if img_match_prefix:
|
||||
content = content.replace(img_match_prefix, "", 1)
|
||||
context.type = ContextType.IMAGE_CREATE
|
||||
else:
|
||||
context.type = ContextType.TEXT
|
||||
context.content = (content or "").strip()
|
||||
if "desire_rtype" not in context and conf().get("always_reply_voice"):
|
||||
context["desire_rtype"] = ReplyType.VOICE
|
||||
elif ctype == ContextType.VOICE:
|
||||
if "desire_rtype" not in context and (
|
||||
conf().get("voice_reply_voice") or conf().get("always_reply_voice")
|
||||
):
|
||||
context["desire_rtype"] = ReplyType.VOICE
|
||||
|
||||
return context
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Outbound: ChatChannel.send -> Slack Web API
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def send(self, reply: Reply, context: Context):
|
||||
"""Called from cow's sync main thread; Slack Web client is sync-safe."""
|
||||
if self._client is None:
|
||||
logger.warning("[Slack] client not ready, drop reply")
|
||||
return
|
||||
|
||||
channel_id = context.get("slack_channel")
|
||||
thread_ts = context.get("slack_thread_ts")
|
||||
if not channel_id:
|
||||
logger.warning("[Slack] no slack_channel in context, drop reply")
|
||||
return
|
||||
|
||||
try:
|
||||
self._do_send(reply, channel_id, thread_ts)
|
||||
logger.info(f"[Slack] sent reply (type={reply.type}, channel={channel_id})")
|
||||
except Exception as e:
|
||||
logger.error(f"[Slack] send failed: {e}", exc_info=True)
|
||||
|
||||
def _do_send(self, reply: Reply, channel_id: str, thread_ts):
|
||||
rtype = reply.type
|
||||
content = reply.content
|
||||
|
||||
if rtype in (ReplyType.TEXT, ReplyType.INFO, ReplyType.ERROR):
|
||||
text = str(content) if content is not None else ""
|
||||
if not text:
|
||||
return
|
||||
# Slack caps a message around 40k chars; split conservatively
|
||||
for chunk in _split_text(text, 3500):
|
||||
self._client.chat_postMessage(channel=channel_id, text=chunk, thread_ts=thread_ts)
|
||||
|
||||
elif rtype == ReplyType.IMAGE:
|
||||
# Already a local BytesIO; upload it directly
|
||||
content.seek(0)
|
||||
self._client.files_upload_v2(
|
||||
channel=channel_id, file=content, filename="image.png", thread_ts=thread_ts,
|
||||
)
|
||||
|
||||
elif rtype == ReplyType.IMAGE_URL:
|
||||
url = str(content)
|
||||
if url.startswith("file://"):
|
||||
local = url[7:]
|
||||
self._client.files_upload_v2(
|
||||
channel=channel_id, file=local, thread_ts=thread_ts,
|
||||
)
|
||||
else:
|
||||
# Post the URL as text; Slack will unfurl it as an image preview
|
||||
self._client.chat_postMessage(channel=channel_id, text=url, thread_ts=thread_ts)
|
||||
|
||||
elif rtype in (ReplyType.VOICE, ReplyType.FILE):
|
||||
local = content[7:] if isinstance(content, str) and content.startswith("file://") else content
|
||||
caption = getattr(reply, "text_content", None) or None
|
||||
self._client.files_upload_v2(
|
||||
channel=channel_id, file=local, initial_comment=caption, thread_ts=thread_ts,
|
||||
)
|
||||
|
||||
else:
|
||||
# Fallback: send as plain text
|
||||
self._client.chat_postMessage(channel=channel_id, text=str(content), thread_ts=thread_ts)
|
||||
|
||||
|
||||
def _split_text(text: str, limit: int):
|
||||
"""Split long text preferring line breaks to keep markdown structure intact."""
|
||||
if len(text) <= limit:
|
||||
yield text
|
||||
return
|
||||
buf = []
|
||||
size = 0
|
||||
for line in text.splitlines(keepends=True):
|
||||
if size + len(line) > limit and buf:
|
||||
yield "".join(buf)
|
||||
buf, size = [], 0
|
||||
# Hard-split single lines that exceed the limit
|
||||
while len(line) > limit:
|
||||
yield line[:limit]
|
||||
line = line[limit:]
|
||||
buf.append(line)
|
||||
size += len(line)
|
||||
if buf:
|
||||
yield "".join(buf)
|
||||
60
channel/slack/slack_message.py
Normal file
60
channel/slack/slack_message.py
Normal file
@@ -0,0 +1,60 @@
|
||||
"""
|
||||
Slack message adapter.
|
||||
|
||||
Convert a Slack event payload into cow's unified ChatMessage.
|
||||
File downloads are NOT performed here; the channel layer downloads files
|
||||
on demand because it needs the bot token for authenticated download URLs.
|
||||
"""
|
||||
import os
|
||||
|
||||
from bridge.context import ContextType
|
||||
from channel.chat_message import ChatMessage
|
||||
from common.utils import expand_path
|
||||
from config import conf
|
||||
|
||||
|
||||
class SlackMessage(ChatMessage):
|
||||
"""Wrap a Slack event into the unified ChatMessage."""
|
||||
|
||||
def __init__(self, event: dict, is_group: bool = False, bot_user_id: str = "",
|
||||
ctype: ContextType = ContextType.TEXT, content: str = ""):
|
||||
super().__init__(event)
|
||||
# Basic fields
|
||||
self.msg_id = event.get("client_msg_id") or event.get("ts") or ""
|
||||
try:
|
||||
self.create_time = int(float(event.get("ts", 0)))
|
||||
except (TypeError, ValueError):
|
||||
self.create_time = 0
|
||||
self.ctype = ctype
|
||||
self.content = content
|
||||
|
||||
# Sender / chat info
|
||||
from_user_id = event.get("user", "unknown")
|
||||
channel_id = event.get("channel", "")
|
||||
self.from_user_id = from_user_id
|
||||
self.from_user_nickname = from_user_id
|
||||
self.to_user_id = bot_user_id or "slack_bot"
|
||||
self.to_user_nickname = bot_user_id or "slack_bot"
|
||||
|
||||
self.is_group = is_group
|
||||
if is_group:
|
||||
# Channel chat: other_user_id = channel_id, actual_user_id = sender id
|
||||
self.other_user_id = channel_id
|
||||
self.other_user_nickname = channel_id
|
||||
self.actual_user_id = from_user_id
|
||||
self.actual_user_nickname = from_user_id
|
||||
else:
|
||||
# DM: use channel_id so replies go back to the same DM channel
|
||||
self.other_user_id = channel_id or from_user_id
|
||||
self.other_user_nickname = from_user_id
|
||||
|
||||
# Whether the bot was triggered by @-mention (set by channel layer)
|
||||
self.is_at = False
|
||||
|
||||
@staticmethod
|
||||
def get_tmp_dir() -> str:
|
||||
"""Local download directory, aligned with other channels (agent_workspace/tmp)."""
|
||||
workspace_root = expand_path(conf().get("agent_workspace", "~/cow"))
|
||||
tmp_dir = os.path.join(workspace_root, "tmp")
|
||||
os.makedirs(tmp_dir, exist_ok=True)
|
||||
return tmp_dir
|
||||
0
channel/telegram/__init__.py
Normal file
0
channel/telegram/__init__.py
Normal file
719
channel/telegram/telegram_channel.py
Normal file
719
channel/telegram/telegram_channel.py
Normal file
@@ -0,0 +1,719 @@
|
||||
"""
|
||||
Telegram channel via Bot API (long polling mode).
|
||||
|
||||
Features:
|
||||
- Single chat & group chat (text / photo / voice / video / document)
|
||||
- Group trigger: @mention or reply-to-bot (configurable)
|
||||
- /cancel fast-path matches Web channel behaviour
|
||||
- Auto-register bot commands menu on startup (mirrors Web slash menu)
|
||||
- Optional HTTP/SOCKS5 proxy support for restricted networks
|
||||
|
||||
Implementation note:
|
||||
python-telegram-bot is async-first. We run the bot inside a dedicated
|
||||
thread with its own asyncio loop so the rest of cow (which is sync)
|
||||
stays untouched. Inbound updates are dispatched onto cow's existing
|
||||
sync ChatChannel.produce() pipeline; outbound send() schedules
|
||||
coroutines back onto that loop via asyncio.run_coroutine_threadsafe.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import re
|
||||
import threading
|
||||
|
||||
from bridge.context import Context, ContextType
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from channel.chat_channel import ChatChannel, check_prefix
|
||||
from channel.telegram.telegram_message import TelegramMessage
|
||||
from common.expired_dict import ExpiredDict
|
||||
from common.log import logger
|
||||
from common.singleton import singleton
|
||||
from config import conf
|
||||
|
||||
# Bot command menu, aligned with Web slash commands.
|
||||
# Top-level commands only; sub-commands are entered with a space (e.g. "/skill list").
|
||||
TELEGRAM_BOT_COMMANDS = [
|
||||
("help", "Show command help"),
|
||||
("status", "Show running status"),
|
||||
("context", "View/clear conversation context (sub: clear)"),
|
||||
("skill", "Manage skills (list/search/install/...)"),
|
||||
("memory", "Manage memory (sub: dream)"),
|
||||
("knowledge", "Manage knowledge base (list/on/off)"),
|
||||
("config", "Show current config"),
|
||||
("cancel", "Cancel running agent task"),
|
||||
("logs", "Show recent logs"),
|
||||
("version", "Show version"),
|
||||
]
|
||||
|
||||
|
||||
@singleton
|
||||
class TelegramChannel(ChatChannel):
|
||||
NOT_SUPPORT_REPLYTYPE = []
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.bot_token = ""
|
||||
self.bot_username = "" # used for @-mention matching
|
||||
self._bot = None
|
||||
self._application = None
|
||||
self._loop = None
|
||||
self._loop_thread = None
|
||||
self._stop_event = threading.Event()
|
||||
# Idempotent dedup; TG occasionally redelivers the same update on flaky networks
|
||||
self._received_msgs = ExpiredDict(60 * 60 * 1)
|
||||
|
||||
# Disable group whitelist / prefix checks (we handle triggering ourselves
|
||||
# in _should_reply_in_group), aligned with feishu / wecom_bot channels.
|
||||
conf()["group_name_white_list"] = ["ALL_GROUP"]
|
||||
conf()["single_chat_prefix"] = [""]
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Lifecycle
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def startup(self):
|
||||
self.bot_token = conf().get("telegram_token", "")
|
||||
if not self.bot_token:
|
||||
err = "[Telegram] telegram_token is required"
|
||||
logger.error(err)
|
||||
self.report_startup_error(err)
|
||||
return
|
||||
|
||||
try:
|
||||
from telegram.ext import (
|
||||
Application,
|
||||
MessageHandler,
|
||||
CommandHandler,
|
||||
filters,
|
||||
)
|
||||
except ImportError:
|
||||
err = (
|
||||
"[Telegram] python-telegram-bot is not installed. "
|
||||
"Run: pip install python-telegram-bot"
|
||||
)
|
||||
logger.error(err)
|
||||
self.report_startup_error(err)
|
||||
return
|
||||
|
||||
# Run the asyncio event loop in a dedicated thread so the sync cow body
|
||||
# is untouched.
|
||||
self._loop = asyncio.new_event_loop()
|
||||
|
||||
def _run_loop():
|
||||
asyncio.set_event_loop(self._loop)
|
||||
try:
|
||||
self._loop.run_until_complete(self._async_main(Application, MessageHandler, CommandHandler, filters))
|
||||
except Exception as e:
|
||||
logger.error(f"[Telegram] event loop crashed: {e}", exc_info=True)
|
||||
self.report_startup_error(str(e))
|
||||
finally:
|
||||
try:
|
||||
self._loop.close()
|
||||
except Exception:
|
||||
pass
|
||||
logger.info("[Telegram] event loop exited")
|
||||
|
||||
self._loop_thread = threading.Thread(target=_run_loop, daemon=True, name="telegram-loop")
|
||||
self._loop_thread.start()
|
||||
# Block startup() until the loop thread exits, matching other channels'
|
||||
# behaviour (startup is a blocking call).
|
||||
self._loop_thread.join()
|
||||
|
||||
async def _async_main(self, Application, MessageHandler, CommandHandler, filters):
|
||||
"""Build Application, register handlers, and run polling."""
|
||||
builder = Application.builder().token(self.bot_token)
|
||||
|
||||
# Proxy: prefer telegram_proxy config, fall back to HTTPS_PROXY env var
|
||||
proxy_url = conf().get("telegram_proxy", "") or os.environ.get("HTTPS_PROXY", "")
|
||||
if proxy_url:
|
||||
try:
|
||||
builder = builder.proxy(proxy_url).get_updates_proxy(proxy_url)
|
||||
logger.info(f"[Telegram] using proxy: {proxy_url}")
|
||||
except Exception as e:
|
||||
logger.warning(f"[Telegram] proxy config failed, fallback to direct: {e}")
|
||||
|
||||
# Media uploads (photo/voice/video/document) over a proxy can be slow,
|
||||
# bump read/write/connect/pool timeouts.
|
||||
builder = (
|
||||
builder
|
||||
.read_timeout(60)
|
||||
.write_timeout(120)
|
||||
.connect_timeout(30)
|
||||
.pool_timeout(30)
|
||||
)
|
||||
|
||||
application = builder.build()
|
||||
self._application = application
|
||||
self._bot = application.bot
|
||||
|
||||
# Fetch our own username (needed for @-mention matching in groups)
|
||||
try:
|
||||
me = await self._bot.get_me()
|
||||
self.bot_username = me.username or ""
|
||||
self.name = self.bot_username # ChatChannel uses self.name to strip @-mention
|
||||
logger.info(f"[Telegram] Bot logged in as @{self.bot_username} (id={me.id})")
|
||||
except Exception as e:
|
||||
err = f"[Telegram] get_me failed: {e}"
|
||||
logger.error(err)
|
||||
self.report_startup_error(err)
|
||||
return
|
||||
|
||||
# Register the command menu (failure is non-fatal)
|
||||
if conf().get("telegram_register_commands", True):
|
||||
try:
|
||||
from telegram import BotCommand
|
||||
cmds = [BotCommand(name, desc) for name, desc in TELEGRAM_BOT_COMMANDS]
|
||||
await self._bot.set_my_commands(cmds)
|
||||
logger.info(f"[Telegram] Registered {len(cmds)} bot commands")
|
||||
except Exception as e:
|
||||
logger.warning(f"[Telegram] set_my_commands failed: {e}")
|
||||
|
||||
# Handlers:
|
||||
# 1) /cancel uses the fast-path
|
||||
application.add_handler(CommandHandler("cancel", self._on_cancel))
|
||||
# 2) Normal messages (text + media)
|
||||
application.add_handler(MessageHandler(filters.ALL & ~filters.COMMAND, self._on_message))
|
||||
# 3) Other slash commands are forwarded as plain text for the agent to handle
|
||||
application.add_handler(MessageHandler(filters.COMMAND, self._on_command_passthrough))
|
||||
|
||||
# Start polling. drop_pending_updates avoids replaying backlog after restart.
|
||||
# Transient "Server disconnected" / RemoteProtocolError during get_updates
|
||||
# are common over proxies/flaky networks; PTB's network loop auto-retries,
|
||||
# so we only need to keep the noise down (see _quiet_polling_network_errors).
|
||||
self._quiet_polling_network_errors()
|
||||
logger.info("[Telegram] Starting long polling...")
|
||||
await application.initialize()
|
||||
await application.start()
|
||||
await application.updater.start_polling(
|
||||
drop_pending_updates=True,
|
||||
# Long-poll hold time on the server side; smaller value = reconnect more
|
||||
# often but each hung connection fails faster.
|
||||
timeout=30,
|
||||
# Retry forever on transient get_updates network errors instead of giving up.
|
||||
bootstrap_retries=-1,
|
||||
)
|
||||
self.report_startup_success()
|
||||
logger.info("[Telegram] ✅ Telegram bot ready, polling for updates")
|
||||
|
||||
# Block until stop()
|
||||
try:
|
||||
while not self._stop_event.is_set():
|
||||
await asyncio.sleep(0.5)
|
||||
finally:
|
||||
try:
|
||||
await application.updater.stop()
|
||||
await application.stop()
|
||||
await application.shutdown()
|
||||
except Exception as e:
|
||||
logger.warning(f"[Telegram] shutdown error: {e}")
|
||||
|
||||
@staticmethod
|
||||
def _quiet_polling_network_errors():
|
||||
"""Downgrade PTB's noisy 'Exception happened while polling for updates' logs.
|
||||
|
||||
These transient get_updates errors (RemoteProtocolError / NetworkError /
|
||||
TimedOut, typically over a proxy) are auto-retried by PTB's network loop,
|
||||
so logging the full traceback at ERROR is just noise. We attach a filter
|
||||
that drops these specific records while leaving real errors untouched.
|
||||
"""
|
||||
import logging
|
||||
|
||||
class _PollingNoiseFilter(logging.Filter):
|
||||
_NEEDLES = (
|
||||
"Exception happened while polling for updates",
|
||||
"Server disconnected without sending a response",
|
||||
)
|
||||
|
||||
def filter(self, record: logging.LogRecord) -> bool:
|
||||
try:
|
||||
msg = record.getMessage()
|
||||
except Exception:
|
||||
return True
|
||||
if any(n in msg for n in self._NEEDLES):
|
||||
# Keep a single-line breadcrumb at DEBUG, drop the traceback.
|
||||
logger.debug(f"[Telegram] transient polling network error (auto-retrying): {msg.splitlines()[0]}")
|
||||
return False
|
||||
return True
|
||||
|
||||
noise_filter = _PollingNoiseFilter()
|
||||
for name in ("telegram.ext.Updater", "telegram.ext._updater", "telegram.ext"):
|
||||
logging.getLogger(name).addFilter(noise_filter)
|
||||
|
||||
def stop(self):
|
||||
logger.info("[Telegram] stop() called")
|
||||
self._stop_event.set()
|
||||
if self._loop_thread and self._loop_thread.is_alive():
|
||||
try:
|
||||
self._loop_thread.join(timeout=10)
|
||||
except Exception:
|
||||
pass
|
||||
logger.info("[Telegram] stop() completed")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Inbound: telegram update -> ChatMessage -> ChatChannel.produce
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _on_cancel(self, update, _context):
|
||||
"""Fast-path: /cancel calls cancel_session directly without going through agent."""
|
||||
try:
|
||||
from agent.protocol import get_cancel_registry
|
||||
session_id = self._compute_session_id(update)
|
||||
cancelled = get_cancel_registry().cancel_session(session_id)
|
||||
text = "Current task cancelled." if cancelled else "No running task to cancel."
|
||||
await update.effective_message.reply_text(text)
|
||||
logger.info(f"[Telegram] /cancel session={session_id}, cancelled={cancelled}")
|
||||
except Exception as e:
|
||||
logger.error(f"[Telegram] /cancel error: {e}", exc_info=True)
|
||||
try:
|
||||
await update.effective_message.reply_text(f"⚠️ /cancel failed: {e}")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def _on_command_passthrough(self, update, _context):
|
||||
"""All non-/cancel commands fall through to plain message handling."""
|
||||
await self._on_message(update, _context)
|
||||
|
||||
async def _on_message(self, update, _context):
|
||||
"""Telegram update entry: parse message -> build ChatMessage -> produce()."""
|
||||
try:
|
||||
message = update.effective_message
|
||||
chat = update.effective_chat
|
||||
if not message or not chat:
|
||||
return
|
||||
|
||||
# Idempotent dedup
|
||||
msg_uid = f"{chat.id}:{message.message_id}"
|
||||
if self._received_msgs.get(msg_uid):
|
||||
return
|
||||
self._received_msgs[msg_uid] = True
|
||||
|
||||
is_group = chat.type in ("group", "supergroup")
|
||||
|
||||
# Debug log: helpful when group messages are silently dropped
|
||||
if is_group:
|
||||
logger.debug(
|
||||
f"[Telegram] group update received: chat_id={chat.id}, "
|
||||
f"text={(message.text or message.caption or '')[:40]!r}, "
|
||||
f"reply_to_bot={bool(message.reply_to_message and message.reply_to_message.from_user and message.reply_to_message.from_user.username == self.bot_username)}"
|
||||
)
|
||||
|
||||
# Group trigger gate (silently drop if not triggered)
|
||||
if is_group and not self._should_reply_in_group(update):
|
||||
logger.debug(f"[Telegram] group message not triggered (need @{self.bot_username} or reply), skip")
|
||||
return
|
||||
|
||||
# Parse message type + download media if needed.
|
||||
# Media messages with caption return both the local path and the caption text.
|
||||
ctype, content, caption = await self._parse_message(message)
|
||||
if ctype is None:
|
||||
logger.debug(f"[Telegram] unsupported message type, skip. msg={message}")
|
||||
return
|
||||
|
||||
# Strip @bot mention for group text/caption
|
||||
if is_group and self.bot_username:
|
||||
if ctype == ContextType.TEXT and content:
|
||||
content = self._strip_at_mention(content)
|
||||
if caption:
|
||||
caption = self._strip_at_mention(caption)
|
||||
|
||||
tg_msg = TelegramMessage(
|
||||
update,
|
||||
is_group=is_group,
|
||||
bot_username=self.bot_username,
|
||||
ctype=ctype,
|
||||
content=content,
|
||||
)
|
||||
tg_msg.is_at = is_group # If we got here in a group, the bot is mentioned/replied
|
||||
|
||||
# File cache: standalone media goes into cache, the next text query attaches them
|
||||
from channel.file_cache import get_file_cache
|
||||
file_cache = get_file_cache()
|
||||
session_id = self._compute_session_id(update)
|
||||
|
||||
# Media + caption together: treat as a complete query and bypass the cache
|
||||
if ctype in (ContextType.IMAGE, ContextType.FILE) and caption:
|
||||
tag = "image" if ctype == ContextType.IMAGE else "file"
|
||||
merged_text = f"{caption}\n[{tag}: {content}]"
|
||||
tg_msg.ctype = ContextType.TEXT
|
||||
tg_msg.content = merged_text
|
||||
ctype = ContextType.TEXT
|
||||
logger.info(f"[Telegram] Media+caption merged for session {session_id}")
|
||||
# fallthrough to the TEXT branch below
|
||||
|
||||
elif ctype == ContextType.IMAGE:
|
||||
file_cache.add(session_id, content, file_type="image")
|
||||
logger.info(f"[Telegram] Image cached for session {session_id}, waiting for query...")
|
||||
return
|
||||
elif ctype == ContextType.FILE:
|
||||
file_cache.add(session_id, content, file_type="file")
|
||||
logger.info(f"[Telegram] File cached for session {session_id}: {content}")
|
||||
return
|
||||
|
||||
if ctype == ContextType.TEXT:
|
||||
cached_files = file_cache.get(session_id)
|
||||
if cached_files:
|
||||
refs = []
|
||||
for fi in cached_files:
|
||||
ftype = fi["type"]
|
||||
tag = ftype if ftype in ("image", "video") else "file"
|
||||
refs.append(f"[{tag}: {fi['path']}]")
|
||||
tg_msg.content = (tg_msg.content or "") + "\n" + "\n".join(refs)
|
||||
file_cache.clear(session_id)
|
||||
logger.info(f"[Telegram] Attached {len(cached_files)} cached file(s) to query")
|
||||
|
||||
# Dispatch to cow main pipeline (reuses ChatChannel._compose_context routing)
|
||||
context = self._compose_context(
|
||||
tg_msg.ctype,
|
||||
tg_msg.content,
|
||||
isgroup=is_group,
|
||||
msg=tg_msg,
|
||||
)
|
||||
if context:
|
||||
context["session_id"] = session_id
|
||||
context["receiver"] = str(chat.id)
|
||||
context["telegram_chat_id"] = chat.id
|
||||
context["telegram_reply_to_msg_id"] = message.message_id if is_group else None
|
||||
self.produce(context)
|
||||
logger.debug(f"[Telegram] received: type={ctype}, content={str(tg_msg.content)[:80]}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Telegram] _on_message error: {e}", exc_info=True)
|
||||
|
||||
async def _parse_message(self, message):
|
||||
"""Parse a telegram message and return (ctype, content, caption).
|
||||
|
||||
- content is text for ContextType.TEXT, otherwise the local file path
|
||||
- caption is the optional text accompanying a media message; empty for plain text
|
||||
"""
|
||||
caption = (message.caption or "").strip()
|
||||
|
||||
if message.photo:
|
||||
largest = message.photo[-1]
|
||||
path = await self._download_file(largest.file_id, suffix=".jpg")
|
||||
return (ContextType.IMAGE, path, caption) if path else (None, None, "")
|
||||
|
||||
if message.voice or message.audio:
|
||||
audio_obj = message.voice or message.audio
|
||||
suffix = ".ogg" if message.voice else (
|
||||
"." + (audio_obj.mime_type.split("/")[-1] if getattr(audio_obj, "mime_type", "") else "mp3")
|
||||
)
|
||||
path = await self._download_file(audio_obj.file_id, suffix=suffix)
|
||||
return (ContextType.VOICE, path, caption) if path else (None, None, "")
|
||||
|
||||
if message.video or message.video_note:
|
||||
video_obj = message.video or message.video_note
|
||||
path = await self._download_file(video_obj.file_id, suffix=".mp4")
|
||||
return (ContextType.FILE, path, caption) if path else (None, None, "")
|
||||
|
||||
if message.document:
|
||||
doc = message.document
|
||||
ext = ""
|
||||
if doc.file_name and "." in doc.file_name:
|
||||
ext = "." + doc.file_name.rsplit(".", 1)[-1]
|
||||
path = await self._download_file(doc.file_id, suffix=ext, original_name=doc.file_name)
|
||||
if not path:
|
||||
return (None, None, "")
|
||||
# Image-typed documents (user picked "send as file") are treated as images
|
||||
mime = (doc.mime_type or "").lower()
|
||||
if mime.startswith("image/"):
|
||||
return (ContextType.IMAGE, path, caption)
|
||||
return (ContextType.FILE, path, caption)
|
||||
|
||||
if message.text:
|
||||
return (ContextType.TEXT, message.text.strip(), "")
|
||||
|
||||
return (None, None, "")
|
||||
|
||||
async def _download_file(self, file_id: str, suffix: str = "", original_name: str = ""):
|
||||
"""Download via bot.get_file into the local tmp dir; return path or None on failure."""
|
||||
try:
|
||||
f = await self._bot.get_file(file_id)
|
||||
tmp_dir = TelegramMessage.get_tmp_dir()
|
||||
base = original_name or f"{file_id}{suffix or ''}"
|
||||
# Prefix with file_id to avoid name collisions / weird chars
|
||||
safe_name = f"{file_id}_{base}" if original_name else base
|
||||
local_path = os.path.join(tmp_dir, safe_name)
|
||||
await f.download_to_drive(custom_path=local_path)
|
||||
logger.debug(f"[Telegram] downloaded file_id={file_id} -> {local_path}")
|
||||
return local_path
|
||||
except Exception as e:
|
||||
logger.error(f"[Telegram] download_file failed (file_id={file_id}): {e}")
|
||||
return None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Group trigger logic
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _should_reply_in_group(self, update) -> bool:
|
||||
"""Decide whether to reply to a group message based on configuration."""
|
||||
mode = conf().get("telegram_group_trigger", "mention_or_reply")
|
||||
if mode == "all":
|
||||
return True
|
||||
|
||||
message = update.effective_message
|
||||
if not message:
|
||||
return False
|
||||
|
||||
# 1) Mentioned
|
||||
if self.bot_username and self._is_mentioned(message, self.bot_username):
|
||||
return True
|
||||
|
||||
# 2) Reply to a bot message
|
||||
if mode == "mention_or_reply":
|
||||
reply = message.reply_to_message
|
||||
if reply and reply.from_user and reply.from_user.username == self.bot_username:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _is_mentioned(message, bot_username: str) -> bool:
|
||||
"""Check whether entities/caption_entities contain a @mention of the bot."""
|
||||
bot_at = "@" + bot_username.lower()
|
||||
text = (message.text or message.caption or "").lower()
|
||||
if bot_at in text:
|
||||
return True
|
||||
# Also check entities strictly to support text_mention (no-username @)
|
||||
for ent in (message.entities or []) + (message.caption_entities or []):
|
||||
if ent.type == "mention":
|
||||
src = message.text or message.caption or ""
|
||||
if src[ent.offset: ent.offset + ent.length].lower() == bot_at:
|
||||
return True
|
||||
return False
|
||||
|
||||
def _strip_at_mention(self, content: str) -> str:
|
||||
"""Strip @bot_username from group text (case-insensitive)."""
|
||||
if not content or not self.bot_username:
|
||||
return content
|
||||
pattern = re.compile(r"@" + re.escape(self.bot_username), re.IGNORECASE)
|
||||
return pattern.sub("", content).strip()
|
||||
|
||||
@staticmethod
|
||||
def _compute_session_id(update) -> str:
|
||||
chat = update.effective_chat
|
||||
user = update.effective_user
|
||||
is_group = chat.type in ("group", "supergroup")
|
||||
if is_group:
|
||||
if conf().get("group_shared_session", True):
|
||||
return f"tg_group_{chat.id}"
|
||||
return f"tg_group_{chat.id}_{user.id}"
|
||||
return f"tg_user_{user.id}"
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Override _compose_context: skip the parent's group whitelist/at checks
|
||||
# (already handled in _on_message via _should_reply_in_group). Same idea
|
||||
# as the feishu channel.
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _compose_context(self, ctype: ContextType, content, **kwargs):
|
||||
context = Context(ctype, content)
|
||||
context.kwargs = kwargs
|
||||
if "channel_type" not in context:
|
||||
context["channel_type"] = self.channel_type
|
||||
if "origin_ctype" not in context:
|
||||
context["origin_ctype"] = ctype
|
||||
|
||||
cmsg = context["msg"]
|
||||
if cmsg.is_group:
|
||||
if conf().get("group_shared_session", True):
|
||||
context["session_id"] = cmsg.other_user_id
|
||||
else:
|
||||
context["session_id"] = f"{cmsg.from_user_id}:{cmsg.other_user_id}"
|
||||
else:
|
||||
context["session_id"] = cmsg.from_user_id
|
||||
context["receiver"] = cmsg.other_user_id
|
||||
|
||||
if ctype == ContextType.TEXT:
|
||||
img_match_prefix = check_prefix(content, conf().get("image_create_prefix"))
|
||||
if img_match_prefix:
|
||||
content = content.replace(img_match_prefix, "", 1)
|
||||
context.type = ContextType.IMAGE_CREATE
|
||||
else:
|
||||
context.type = ContextType.TEXT
|
||||
context.content = (content or "").strip()
|
||||
if "desire_rtype" not in context and conf().get("always_reply_voice"):
|
||||
context["desire_rtype"] = ReplyType.VOICE
|
||||
elif ctype == ContextType.VOICE:
|
||||
if "desire_rtype" not in context and (
|
||||
conf().get("voice_reply_voice") or conf().get("always_reply_voice")
|
||||
):
|
||||
context["desire_rtype"] = ReplyType.VOICE
|
||||
|
||||
return context
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Outbound: ChatChannel.send -> Telegram API
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def send(self, reply: Reply, context: Context):
|
||||
"""Called from cow's sync main thread; we marshal the coroutine onto the loop thread."""
|
||||
if self._loop is None or self._bot is None:
|
||||
logger.warning("[Telegram] bot not ready, drop reply")
|
||||
return
|
||||
|
||||
chat_id = context.get("telegram_chat_id")
|
||||
reply_to = context.get("telegram_reply_to_msg_id")
|
||||
if chat_id is None:
|
||||
logger.warning("[Telegram] no telegram_chat_id in context, drop reply")
|
||||
return
|
||||
|
||||
coro = self._async_send(reply, chat_id, reply_to)
|
||||
try:
|
||||
future = asyncio.run_coroutine_threadsafe(coro, self._loop)
|
||||
# Media uploads through a proxy can be slow; let PTB's own timeouts win
|
||||
future.result(timeout=180)
|
||||
except Exception as e:
|
||||
logger.error(f"[Telegram] send failed: {e}")
|
||||
|
||||
# Number of retries for transient network errors (proxy hiccups etc.)
|
||||
_SEND_RETRIES = 2
|
||||
_SEND_RETRY_BACKOFF = 2.0 # seconds
|
||||
|
||||
async def _send_with_retry(self, send_fn, *, label: str):
|
||||
"""Run a single Telegram API call with retries for transient network errors."""
|
||||
from telegram.error import NetworkError, TimedOut
|
||||
last_err = None
|
||||
for attempt in range(self._SEND_RETRIES + 1):
|
||||
try:
|
||||
return await send_fn()
|
||||
except (NetworkError, TimedOut) as e:
|
||||
last_err = e
|
||||
if attempt >= self._SEND_RETRIES:
|
||||
break
|
||||
wait = self._SEND_RETRY_BACKOFF * (attempt + 1)
|
||||
logger.warning(
|
||||
f"[Telegram] {label} transient error (attempt {attempt + 1}/"
|
||||
f"{self._SEND_RETRIES + 1}): {e}; retry in {wait}s"
|
||||
)
|
||||
await asyncio.sleep(wait)
|
||||
raise last_err
|
||||
|
||||
async def _async_send(self, reply: Reply, chat_id, reply_to_msg_id):
|
||||
try:
|
||||
rtype = reply.type
|
||||
content = reply.content
|
||||
|
||||
if rtype == ReplyType.TEXT or rtype == ReplyType.INFO or rtype == ReplyType.ERROR:
|
||||
# Telegram caps a single text message at 4096 chars; auto-split
|
||||
text = str(content) if content is not None else ""
|
||||
if not text:
|
||||
return
|
||||
for chunk in _split_text(text, 4000):
|
||||
await self._send_with_retry(
|
||||
lambda c=chunk: self._bot.send_message(
|
||||
chat_id=chat_id,
|
||||
text=c,
|
||||
reply_to_message_id=reply_to_msg_id,
|
||||
# Avoid failing the whole send if reply_to was deleted
|
||||
allow_sending_without_reply=True,
|
||||
),
|
||||
label="send_message",
|
||||
)
|
||||
|
||||
elif rtype == ReplyType.IMAGE:
|
||||
# Already a local BytesIO; send it directly
|
||||
content.seek(0)
|
||||
await self._send_with_retry(
|
||||
lambda: self._bot.send_photo(
|
||||
chat_id=chat_id,
|
||||
photo=content,
|
||||
reply_to_message_id=reply_to_msg_id,
|
||||
allow_sending_without_reply=True,
|
||||
),
|
||||
label="send_photo",
|
||||
)
|
||||
|
||||
elif rtype == ReplyType.IMAGE_URL:
|
||||
url = str(content)
|
||||
if url.startswith("file://"):
|
||||
local = url[7:]
|
||||
# Open inside the lambda so each retry gets a fresh stream
|
||||
async def _send_local_photo():
|
||||
with open(local, "rb") as f:
|
||||
return await self._bot.send_photo(
|
||||
chat_id=chat_id, photo=f,
|
||||
reply_to_message_id=reply_to_msg_id,
|
||||
allow_sending_without_reply=True,
|
||||
)
|
||||
await self._send_with_retry(_send_local_photo, label="send_photo(file)")
|
||||
else:
|
||||
await self._send_with_retry(
|
||||
lambda: self._bot.send_photo(
|
||||
chat_id=chat_id, photo=url,
|
||||
reply_to_message_id=reply_to_msg_id,
|
||||
allow_sending_without_reply=True,
|
||||
),
|
||||
label="send_photo(url)",
|
||||
)
|
||||
|
||||
elif rtype == ReplyType.VOICE:
|
||||
local = content[7:] if isinstance(content, str) and content.startswith("file://") else content
|
||||
async def _send_voice():
|
||||
with open(local, "rb") as f:
|
||||
return await self._bot.send_voice(
|
||||
chat_id=chat_id, voice=f,
|
||||
reply_to_message_id=reply_to_msg_id,
|
||||
allow_sending_without_reply=True,
|
||||
)
|
||||
await self._send_with_retry(_send_voice, label="send_voice")
|
||||
|
||||
elif rtype == ReplyType.FILE:
|
||||
# Videos go through send_video, everything else through send_document
|
||||
local = content[7:] if isinstance(content, str) and content.startswith("file://") else content
|
||||
# File replies may carry an accompanying text caption
|
||||
caption = getattr(reply, "text_content", None) or None
|
||||
is_video = isinstance(local, str) and local.lower().endswith(
|
||||
(".mp4", ".mov", ".avi", ".mkv", ".webm")
|
||||
)
|
||||
|
||||
async def _send_file():
|
||||
with open(local, "rb") as f:
|
||||
if is_video:
|
||||
return await self._bot.send_video(
|
||||
chat_id=chat_id, video=f, caption=caption,
|
||||
reply_to_message_id=reply_to_msg_id,
|
||||
allow_sending_without_reply=True,
|
||||
)
|
||||
return await self._bot.send_document(
|
||||
chat_id=chat_id, document=f, caption=caption,
|
||||
reply_to_message_id=reply_to_msg_id,
|
||||
allow_sending_without_reply=True,
|
||||
)
|
||||
await self._send_with_retry(_send_file, label="send_video" if is_video else "send_document")
|
||||
|
||||
else:
|
||||
# Fallback: send as plain text
|
||||
await self._send_with_retry(
|
||||
lambda: self._bot.send_message(
|
||||
chat_id=chat_id, text=str(content),
|
||||
reply_to_message_id=reply_to_msg_id,
|
||||
allow_sending_without_reply=True,
|
||||
),
|
||||
label="send_message(fallback)",
|
||||
)
|
||||
|
||||
logger.info(f"[Telegram] sent reply (type={rtype}, chat_id={chat_id})")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Telegram] _async_send error: {e}", exc_info=True)
|
||||
|
||||
|
||||
def _split_text(text: str, limit: int):
|
||||
"""Split long text preferring line breaks to keep markdown structure intact."""
|
||||
if len(text) <= limit:
|
||||
yield text
|
||||
return
|
||||
buf = []
|
||||
size = 0
|
||||
for line in text.splitlines(keepends=True):
|
||||
if size + len(line) > limit and buf:
|
||||
yield "".join(buf)
|
||||
buf, size = [], 0
|
||||
# Hard-split single lines that exceed the limit
|
||||
while len(line) > limit:
|
||||
yield line[:limit]
|
||||
line = line[limit:]
|
||||
buf.append(line)
|
||||
size += len(line)
|
||||
if buf:
|
||||
yield "".join(buf)
|
||||
62
channel/telegram/telegram_message.py
Normal file
62
channel/telegram/telegram_message.py
Normal file
@@ -0,0 +1,62 @@
|
||||
"""
|
||||
Telegram message adapter.
|
||||
|
||||
Convert a python-telegram-bot Update into cow's unified ChatMessage.
|
||||
File downloads are NOT performed here; the channel layer triggers
|
||||
bot.get_file() on demand because it requires the async event loop.
|
||||
"""
|
||||
import os
|
||||
|
||||
from bridge.context import ContextType
|
||||
from channel.chat_message import ChatMessage
|
||||
from common.utils import expand_path
|
||||
from config import conf
|
||||
|
||||
|
||||
class TelegramMessage(ChatMessage):
|
||||
"""Wrap a Telegram Update into the unified ChatMessage."""
|
||||
|
||||
def __init__(self, update, is_group: bool = False, bot_username: str = "",
|
||||
ctype: ContextType = ContextType.TEXT, content: str = ""):
|
||||
super().__init__(update)
|
||||
message = update.effective_message
|
||||
chat = update.effective_chat
|
||||
user = update.effective_user
|
||||
|
||||
# Basic fields
|
||||
self.msg_id = str(message.message_id) if message else ""
|
||||
self.create_time = int(message.date.timestamp()) if message and message.date else 0
|
||||
self.ctype = ctype
|
||||
self.content = content
|
||||
|
||||
# Sender / chat info
|
||||
from_user_id = str(user.id) if user else "unknown"
|
||||
from_user_nick = (
|
||||
user.full_name if user and user.full_name else (user.username if user else "unknown")
|
||||
)
|
||||
self.from_user_id = from_user_id
|
||||
self.from_user_nickname = from_user_nick or from_user_id
|
||||
self.to_user_id = bot_username or "telegram_bot"
|
||||
self.to_user_nickname = bot_username or "telegram_bot"
|
||||
|
||||
self.is_group = is_group
|
||||
if is_group:
|
||||
# Group: other_user_id = group_id, actual_user_id = sender id
|
||||
self.other_user_id = str(chat.id)
|
||||
self.other_user_nickname = chat.title or str(chat.id)
|
||||
self.actual_user_id = from_user_id
|
||||
self.actual_user_nickname = self.from_user_nickname
|
||||
else:
|
||||
self.other_user_id = from_user_id
|
||||
self.other_user_nickname = self.from_user_nickname
|
||||
|
||||
# Whether the bot was triggered by @-mention or reply (set by channel layer)
|
||||
self.is_at = False
|
||||
|
||||
@staticmethod
|
||||
def get_tmp_dir() -> str:
|
||||
"""Local download directory, aligned with other channels (agent_workspace/tmp)."""
|
||||
workspace_root = expand_path(conf().get("agent_workspace", "~/cow"))
|
||||
tmp_dir = os.path.join(workspace_root, "tmp")
|
||||
os.makedirs(tmp_dir, exist_ok=True)
|
||||
return tmp_dir
|
||||
@@ -1,4 +1,7 @@
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
|
||||
from bridge.context import *
|
||||
from bridge.reply import Reply, ReplyType
|
||||
@@ -8,6 +11,164 @@ from common.log import logger
|
||||
from config import conf
|
||||
|
||||
|
||||
class _Style:
|
||||
"""ANSI escape codes for terminal styling. Disabled when not a tty."""
|
||||
|
||||
enabled = sys.stdout.isatty()
|
||||
|
||||
RESET = "\033[0m"
|
||||
BOLD = "\033[1m"
|
||||
DIM = "\033[2m"
|
||||
ITALIC = "\033[3m"
|
||||
|
||||
GRAY = "\033[90m"
|
||||
RED = "\033[31m"
|
||||
GREEN = "\033[32m"
|
||||
YELLOW = "\033[33m"
|
||||
BLUE = "\033[34m"
|
||||
MAGENTA = "\033[35m"
|
||||
CYAN = "\033[36m"
|
||||
|
||||
@classmethod
|
||||
def wrap(cls, text, *codes):
|
||||
if not cls.enabled or not codes:
|
||||
return text
|
||||
return "".join(codes) + text + cls.RESET
|
||||
|
||||
|
||||
class TerminalAgentRenderer:
|
||||
"""Render agent stream events to the terminal in real time.
|
||||
|
||||
Reuses the same `on_event` mechanism as the web channel so the terminal
|
||||
can show reasoning, tool calls and streaming answer text just like the web UI.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._reasoning_active = False
|
||||
self._answer_active = False
|
||||
self._has_output = False
|
||||
# Track tool execution start time as a fallback when the event omits it
|
||||
self._tool_started_at = {}
|
||||
|
||||
def _print(self, text, end="", flush=True):
|
||||
sys.stdout.write(text)
|
||||
if end:
|
||||
sys.stdout.write(end)
|
||||
if flush:
|
||||
sys.stdout.flush()
|
||||
self._has_output = True
|
||||
|
||||
def _close_section(self):
|
||||
"""Finish the currently open streaming section (reasoning or answer)."""
|
||||
if self._reasoning_active:
|
||||
self._print("", end="\n")
|
||||
self._reasoning_active = False
|
||||
if self._answer_active:
|
||||
self._print("", end="\n")
|
||||
self._answer_active = False
|
||||
|
||||
def _format_arguments(self, arguments):
|
||||
try:
|
||||
if isinstance(arguments, (dict, list)):
|
||||
text = json.dumps(arguments, ensure_ascii=False)
|
||||
else:
|
||||
text = str(arguments)
|
||||
except Exception:
|
||||
text = str(arguments)
|
||||
# Keep tool input compact in the terminal
|
||||
if len(text) > 300:
|
||||
text = text[:300] + "…"
|
||||
return text
|
||||
|
||||
def handle_event(self, event: dict):
|
||||
try:
|
||||
self._handle_event(event)
|
||||
except Exception as e:
|
||||
logger.debug(f"[Terminal] render event error: {e}")
|
||||
|
||||
def _handle_event(self, event: dict):
|
||||
event_type = event.get("type")
|
||||
data = event.get("data", {}) or {}
|
||||
|
||||
if event_type == "agent_start":
|
||||
self._print("\n" + _Style.wrap("Agent: ", _Style.BOLD, _Style.GREEN), end="\n")
|
||||
|
||||
elif event_type == "reasoning_update":
|
||||
delta = data.get("delta", "")
|
||||
if not delta:
|
||||
return
|
||||
if self._answer_active:
|
||||
self._close_section()
|
||||
if not self._reasoning_active:
|
||||
self._print(_Style.wrap("💭 思考 ", _Style.DIM, _Style.MAGENTA), end="\n")
|
||||
self._reasoning_active = True
|
||||
self._print(_Style.wrap(delta, _Style.DIM, _Style.ITALIC))
|
||||
|
||||
elif event_type == "message_update":
|
||||
delta = data.get("delta", "")
|
||||
if not delta:
|
||||
return
|
||||
if self._reasoning_active:
|
||||
self._close_section()
|
||||
self._answer_active = True
|
||||
self._print(delta)
|
||||
|
||||
elif event_type == "tool_execution_start":
|
||||
self._close_section()
|
||||
tool_name = data.get("tool_name", "tool")
|
||||
tool_id = data.get("tool_call_id")
|
||||
arguments = data.get("arguments", {})
|
||||
self._tool_started_at[tool_id] = time.time()
|
||||
header = _Style.wrap(f"🔧 {tool_name}", _Style.BOLD, _Style.CYAN)
|
||||
args_str = self._format_arguments(arguments)
|
||||
self._print(f"{header} {_Style.wrap(args_str, _Style.GRAY)}", end="\n")
|
||||
|
||||
elif event_type == "tool_execution_end":
|
||||
tool_name = data.get("tool_name", "tool")
|
||||
tool_id = data.get("tool_call_id")
|
||||
status = data.get("status", "success")
|
||||
result = data.get("result", "")
|
||||
exec_time = data.get("execution_time")
|
||||
if exec_time is None and tool_id in self._tool_started_at:
|
||||
exec_time = time.time() - self._tool_started_at.pop(tool_id, time.time())
|
||||
success = status == "success"
|
||||
icon = "✓" if success else "✗"
|
||||
color = _Style.GREEN if success else _Style.RED
|
||||
result_str = str(result)
|
||||
if len(result_str) > 500:
|
||||
result_str = result_str[:500] + "…"
|
||||
# Indent multi-line tool output for readability
|
||||
result_str = result_str.replace("\n", "\n ")
|
||||
cost = f" ({exec_time:.2f}s)" if isinstance(exec_time, (int, float)) else ""
|
||||
self._print(
|
||||
_Style.wrap(f" {icon} {tool_name}{cost}", color) + " " + _Style.wrap(result_str, _Style.GRAY),
|
||||
end="\n",
|
||||
)
|
||||
|
||||
elif event_type == "file_to_send":
|
||||
self._close_section()
|
||||
file_path = data.get("path", "")
|
||||
file_name = data.get("file_name", "")
|
||||
label = file_name or file_path
|
||||
self._print(_Style.wrap(f"📎 文件: {label}", _Style.BLUE), end="\n")
|
||||
|
||||
elif event_type == "error":
|
||||
self._close_section()
|
||||
err_msg = data.get("error") or "unknown error"
|
||||
self._print(_Style.wrap(f"❌ {err_msg}", _Style.BOLD, _Style.RED), end="\n")
|
||||
|
||||
elif event_type == "agent_cancelled":
|
||||
self._close_section()
|
||||
self._print(_Style.wrap("⏹ 已中止", _Style.YELLOW), end="\n")
|
||||
|
||||
elif event_type == "agent_end":
|
||||
self._close_section()
|
||||
|
||||
def finish(self):
|
||||
"""Ensure any open section is closed at the end of a turn."""
|
||||
self._close_section()
|
||||
|
||||
|
||||
class TerminalMessage(ChatMessage):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -29,17 +190,33 @@ class TerminalMessage(ChatMessage):
|
||||
class TerminalChannel(ChatChannel):
|
||||
NOT_SUPPORT_REPLYTYPE = [ReplyType.VOICE]
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
# Per-request renderers keyed by request_id; used to detect whether
|
||||
# agent text was already streamed so send() can avoid duplicate output.
|
||||
self._renderers = {}
|
||||
# Callback that restores TTY attributes on exit (set in startup).
|
||||
self._restore_terminal = None
|
||||
|
||||
def send(self, reply: Reply, context: Context):
|
||||
print("\nBot:")
|
||||
request_id = context.get("request_id") if context else None
|
||||
renderer = self._renderers.pop(request_id, None) if request_id else None
|
||||
streamed = renderer is not None and renderer._has_output
|
||||
|
||||
if renderer is not None:
|
||||
renderer.finish()
|
||||
|
||||
if reply.type == ReplyType.IMAGE:
|
||||
from PIL import Image
|
||||
|
||||
image_storage = reply.content
|
||||
image_storage.seek(0)
|
||||
img = Image.open(image_storage)
|
||||
if not streamed:
|
||||
print("\nAgent: ")
|
||||
print("<IMAGE>")
|
||||
img.show()
|
||||
elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片
|
||||
elif reply.type == ReplyType.IMAGE_URL: # download image from url
|
||||
import io
|
||||
|
||||
import requests
|
||||
@@ -52,38 +229,122 @@ class TerminalChannel(ChatChannel):
|
||||
image_storage.write(block)
|
||||
image_storage.seek(0)
|
||||
img = Image.open(image_storage)
|
||||
if not streamed:
|
||||
print("\nAgent: ")
|
||||
print(img_url)
|
||||
img.show()
|
||||
else:
|
||||
print(reply.content)
|
||||
print("\nUser:", end="")
|
||||
# When agent already streamed the answer, skip re-printing the
|
||||
# final text to avoid duplication; just emit a trailing newline.
|
||||
if streamed:
|
||||
print()
|
||||
else:
|
||||
print("\nAgent: ")
|
||||
print(reply.content)
|
||||
print("\nUser: ", end="")
|
||||
sys.stdout.flush()
|
||||
return
|
||||
|
||||
def _silence_console_logging(self):
|
||||
"""Mute console log output so background-thread logs (web/MCP/scheduler)
|
||||
don't flood the interactive terminal. Logs still go to run.log in full.
|
||||
|
||||
Configurable via `terminal_log_level` (default ERROR). The file handler
|
||||
is untouched, so run.log keeps the complete log.
|
||||
"""
|
||||
import logging
|
||||
|
||||
level_name = str(conf().get("terminal_log_level", "ERROR")).upper()
|
||||
level = getattr(logging, level_name, logging.ERROR)
|
||||
root_logger = logging.getLogger("log")
|
||||
for handler in root_logger.handlers:
|
||||
# Only raise the level of the stdout/stderr stream handler;
|
||||
# keep FileHandler at the logger's level so run.log stays complete.
|
||||
if isinstance(handler, logging.StreamHandler) and not isinstance(handler, logging.FileHandler):
|
||||
handler.setLevel(level)
|
||||
|
||||
def _install_terminal_guard(self):
|
||||
"""Save TTY attributes and register restore hooks so the terminal is
|
||||
never left in a broken state (no echo / raw mode / leftover ANSI) after
|
||||
the process exits, especially when Ctrl+C interrupts a blocking input().
|
||||
"""
|
||||
if not sys.stdin.isatty():
|
||||
return
|
||||
try:
|
||||
import atexit
|
||||
import termios
|
||||
|
||||
saved_attrs = termios.tcgetattr(sys.stdin.fileno())
|
||||
|
||||
def _restore():
|
||||
try:
|
||||
termios.tcsetattr(sys.stdin.fileno(), termios.TCSADRAIN, saved_attrs)
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
if _Style.enabled:
|
||||
sys.stdout.write(_Style.RESET)
|
||||
sys.stdout.flush()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
self._restore_terminal = _restore
|
||||
atexit.register(_restore)
|
||||
except Exception as e:
|
||||
# termios is unavailable on Windows; skip the guard there.
|
||||
logger.debug(f"[Terminal] terminal guard not installed: {e}")
|
||||
self._restore_terminal = None
|
||||
|
||||
def startup(self):
|
||||
context = Context()
|
||||
logger.setLevel("WARN")
|
||||
print("\nPlease input your question:\nUser:", end="")
|
||||
self._silence_console_logging()
|
||||
self._install_terminal_guard()
|
||||
print("\nPlease input your question:\nUser: ", end="")
|
||||
sys.stdout.flush()
|
||||
msg_id = 0
|
||||
while True:
|
||||
try:
|
||||
prompt = self.get_input()
|
||||
except KeyboardInterrupt:
|
||||
print("\nExiting...")
|
||||
sys.exit()
|
||||
except (KeyboardInterrupt, EOFError):
|
||||
self._shutdown()
|
||||
msg_id += 1
|
||||
trigger_prefixs = conf().get("single_chat_prefix", [""])
|
||||
if check_prefix(prompt, trigger_prefixs) is None:
|
||||
prompt = trigger_prefixs[0] + prompt # 给没触发的消息加上触发前缀
|
||||
prompt = trigger_prefixs[0] + prompt # add trigger prefix to untriggered messages
|
||||
|
||||
context = self._compose_context(ContextType.TEXT, prompt, msg=TerminalMessage(msg_id, prompt))
|
||||
context["isgroup"] = False
|
||||
if context:
|
||||
# Attach an agent event renderer so reasoning / tool calls /
|
||||
# streaming answer show up live in the terminal (web-like UX).
|
||||
request_id = str(msg_id)
|
||||
context["request_id"] = request_id
|
||||
renderer = TerminalAgentRenderer()
|
||||
self._renderers[request_id] = renderer
|
||||
context["on_event"] = renderer.handle_event
|
||||
self.produce(context)
|
||||
else:
|
||||
raise Exception("context is None")
|
||||
|
||||
def _shutdown(self):
|
||||
"""Restore terminal state and terminate the whole process.
|
||||
|
||||
startup() runs in a daemon sub-thread, so sys.exit() would only kill
|
||||
this thread and leave the main process (and web/MCP/scheduler threads)
|
||||
alive, holding the terminal in a half-occupied state -> laggy input.
|
||||
We reset any leftover ANSI styling and hard-exit the process instead.
|
||||
"""
|
||||
# Restore TTY attributes and reset any leftover ANSI styling
|
||||
# (e.g. interrupted mid-stream output) before terminating.
|
||||
if self._restore_terminal:
|
||||
self._restore_terminal()
|
||||
elif _Style.enabled:
|
||||
sys.stdout.write(_Style.RESET)
|
||||
sys.stdout.write("\nExiting...\n")
|
||||
sys.stdout.flush()
|
||||
# Hard-exit the entire process from a daemon thread.
|
||||
os._exit(0)
|
||||
|
||||
def get_input(self):
|
||||
"""
|
||||
Multi-line input function
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
1401
channel/web/static/css/console.css
Normal file
1401
channel/web/static/css/console.css
Normal file
File diff suppressed because it is too large
Load Diff
7243
channel/web/static/js/console.js
Normal file
7243
channel/web/static/js/console.js
Normal file
File diff suppressed because it is too large
Load Diff
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user