mirror of
https://github.com/zhayujie/chatgpt-on-wechat.git
synced 2026-06-03 02:27:09 +08:00
Compare commits
574 Commits
feat-cow-a
...
master
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4d8458669c | ||
|
|
92ec9653e5 | ||
|
|
e861d98007 | ||
|
|
a97eeb1fd9 | ||
|
|
cd88b23b5d | ||
|
|
33eabf937b | ||
|
|
beb5df16a3 | ||
|
|
7fa743f01a | ||
|
|
1f6859d78f | ||
|
|
2853735472 | ||
|
|
feaa9076b0 | ||
|
|
ce0249706e | ||
|
|
af2c839231 | ||
|
|
2b2d24ed25 | ||
|
|
04d28f9d2d | ||
|
|
1dbf41f384 | ||
|
|
9e6a2cc2c0 | ||
|
|
7bf4ef3d05 | ||
|
|
126649f70f | ||
|
|
1827a2a31c | ||
|
|
fcf4eb78dc | ||
|
|
2ec6ea8045 | ||
|
|
0994a3586d | ||
|
|
29c4be6a3a | ||
|
|
c5b8e06891 | ||
|
|
54a20bca92 | ||
|
|
6e786bde90 | ||
|
|
b671b0d725 | ||
|
|
57f5692074 | ||
|
|
b0ac0731c7 | ||
|
|
3c161df526 | ||
|
|
aa3f48e93c | ||
|
|
5ae1e1adde | ||
|
|
fe8b8fe831 | ||
|
|
5aca54c083 | ||
|
|
458b1a1d88 | ||
|
|
3dd4b84179 | ||
|
|
99bddb79d6 | ||
|
|
136b0b89e8 | ||
|
|
c605b0b080 | ||
|
|
b7b8e3679c | ||
|
|
aeb6610ff4 | ||
|
|
e3eacc77d7 | ||
|
|
37661daf40 | ||
|
|
877b848370 | ||
|
|
5c163cc0fe | ||
|
|
6e04ea8240 | ||
|
|
d106465419 | ||
|
|
f39380cea7 | ||
|
|
bccce2d7cb | ||
|
|
6721dbdbcc | ||
|
|
83cd6ad158 | ||
|
|
116fb27257 | ||
|
|
8d67177a1b | ||
|
|
ad2db1a776 | ||
|
|
2e6d9e0f27 | ||
|
|
e05f85f3ce | ||
|
|
40c48a9a61 | ||
|
|
c9a7525d0b | ||
|
|
fd571ac539 | ||
|
|
c5a3f991c5 | ||
|
|
eb74b73351 | ||
|
|
9b31f45481 | ||
|
|
bc9c1691f5 | ||
|
|
73bf83d2ff | ||
|
|
36e1988fee | ||
|
|
aad6ef635e | ||
|
|
96659cd616 | ||
|
|
c8787b7de4 | ||
|
|
91d427c8f9 | ||
|
|
c8c0573dbd | ||
|
|
29af855ecd | ||
|
|
0a146a245d | ||
|
|
bd85fee7d7 | ||
|
|
571897e2fd | ||
|
|
840dabeccd | ||
|
|
069bffa3e8 | ||
|
|
cc10d230b0 | ||
|
|
2517f2add8 | ||
|
|
a534266025 | ||
|
|
8c25395805 | ||
|
|
36b913124b | ||
|
|
2fa6343fe5 | ||
|
|
06b84225a1 | ||
|
|
5b31da335d | ||
|
|
90773ab69f | ||
|
|
11d92bb22a | ||
|
|
b7734c3926 | ||
|
|
d3faf9c8dc | ||
|
|
bca97a1d14 | ||
|
|
ac9d0f18c5 | ||
|
|
09fa624797 | ||
|
|
b8333e351c | ||
|
|
a01423a196 | ||
|
|
7c35df7a82 | ||
|
|
2b90f377e6 | ||
|
|
fff7326209 | ||
|
|
c181e500bc | ||
|
|
16b7271826 | ||
|
|
4a1f62b185 | ||
|
|
d23a0754c1 | ||
|
|
3ffb563a44 | ||
|
|
4e42f2a017 | ||
|
|
a0dfdb79df | ||
|
|
a85c5f9d4e | ||
|
|
2720bba5b7 | ||
|
|
4634a7bc2f | ||
|
|
16d9b449c9 | ||
|
|
8761997757 | ||
|
|
19bba4abbc | ||
|
|
7839f0aac5 | ||
|
|
83def1db30 | ||
|
|
a0b29d1ffe | ||
|
|
f5479c56af | ||
|
|
246f0a45c8 | ||
|
|
fe871aad77 | ||
|
|
6f860e1bc4 | ||
|
|
249ea40ae3 | ||
|
|
20d8ae19a7 | ||
|
|
ad51aabfd7 | ||
|
|
1cf395c041 | ||
|
|
745179a5bf | ||
|
|
ff5d477fa5 | ||
|
|
907825601d | ||
|
|
c2ec26910a | ||
|
|
83f2aea123 | ||
|
|
a5c5439315 | ||
|
|
eca9b60235 | ||
|
|
d2d5d98d78 | ||
|
|
fb341b869b | ||
|
|
29e66cb186 | ||
|
|
307769b949 | ||
|
|
9a09e057d6 | ||
|
|
3e28659528 | ||
|
|
b861eef26f | ||
|
|
caaf006a49 | ||
|
|
b2429ec30c | ||
|
|
55aaf60a57 | ||
|
|
a5790d82f6 | ||
|
|
63f99af1e6 | ||
|
|
4eed2568aa | ||
|
|
fb7962c7f2 | ||
|
|
76e6b7b471 | ||
|
|
fccb7ff9ed | ||
|
|
3b12ef2e66 | ||
|
|
f9d099be1b | ||
|
|
c322c0e3a5 | ||
|
|
530fc20596 | ||
|
|
a23b4ed754 | ||
|
|
fc4f5077b0 | ||
|
|
6a553886da | ||
|
|
1065c7e722 | ||
|
|
a9c8a59f58 | ||
|
|
8730f7fd27 | ||
|
|
8f608223d7 | ||
|
|
a7cbd47a2f | ||
|
|
b80c3fe5a8 | ||
|
|
5080051e39 | ||
|
|
23bfc8d0ba | ||
|
|
80e9062041 | ||
|
|
67bd3420ed | ||
|
|
aea081703f | ||
|
|
f300d2a2d5 | ||
|
|
f150d7d83a | ||
|
|
4d1f059c0d | ||
|
|
bc7f953fcc | ||
|
|
f653483eea | ||
|
|
6b200fd36b | ||
|
|
161fc6cdf0 | ||
|
|
6f68ed6bce | ||
|
|
a4592ffdfe | ||
|
|
7cd7bd1a48 | ||
|
|
9eeca70292 | ||
|
|
02bfe30848 | ||
|
|
c9c99de3d9 | ||
|
|
8752f0cc60 | ||
|
|
5c65196e44 | ||
|
|
f5798bfe90 | ||
|
|
0e556b3468 | ||
|
|
31820f56e7 | ||
|
|
fd88828abd | ||
|
|
ae11159918 | ||
|
|
472a8605c0 | ||
|
|
e1760ba211 | ||
|
|
ce4c0a0aa4 | ||
|
|
64511593c4 | ||
|
|
b0e00dfceb | ||
|
|
fc465b463d | ||
|
|
68ce2e5232 | ||
|
|
81e8bb62ae | ||
|
|
2c13e1b923 | ||
|
|
a0748c2e3b | ||
|
|
40599bb751 | ||
|
|
f3c64ceea7 | ||
|
|
15c60de709 | ||
|
|
6dd316547f | ||
|
|
54c7676a44 | ||
|
|
d25b8966ce | ||
|
|
14a119c48c | ||
|
|
c82515a927 | ||
|
|
26e630c2dd | ||
|
|
13370d2056 | ||
|
|
35282db9e0 | ||
|
|
426fb88ce7 | ||
|
|
2384bd0e10 | ||
|
|
ba3f66d3d1 | ||
|
|
7293a0f670 | ||
|
|
9e86d46267 | ||
|
|
848430f062 | ||
|
|
abd21335c4 | ||
|
|
8fa95f058a | ||
|
|
d4e5ecd497 | ||
|
|
3830f76729 | ||
|
|
83f778fec9 | ||
|
|
cabd24605f | ||
|
|
ae20ba1148 | ||
|
|
3a50b64977 | ||
|
|
8692e74536 | ||
|
|
1c18bd9889 | ||
|
|
60e9d98d0a | ||
|
|
83f6625e0c | ||
|
|
acc09543b7 | ||
|
|
94d8c7e366 | ||
|
|
ea1a0c8b3d | ||
|
|
7bc88c17e4 | ||
|
|
33cf1bc4c3 | ||
|
|
9402e63fe1 | ||
|
|
90e4d494b2 | ||
|
|
da97e948ca | ||
|
|
89a07e8e74 | ||
|
|
3f3d0381e5 | ||
|
|
3649499dba | ||
|
|
a989d088fd | ||
|
|
f79a915136 | ||
|
|
12e8c3d449 | ||
|
|
4f7064575e | ||
|
|
070df826f1 | ||
|
|
fbe48a4b4e | ||
|
|
4dd497fb6d | ||
|
|
907882c0a7 | ||
|
|
d36d5aee3f | ||
|
|
c6824e5f5e | ||
|
|
199c21eede | ||
|
|
5162da5654 | ||
|
|
a1d82f6193 | ||
|
|
ea78e3d0c6 | ||
|
|
3497f00cb4 | ||
|
|
5355d45031 | ||
|
|
26693acc3f | ||
|
|
76e9fef3b2 | ||
|
|
c34308cbd4 | ||
|
|
5a10476010 | ||
|
|
46e80dceec | ||
|
|
90d1835353 | ||
|
|
845fadd0aa | ||
|
|
5748ded52c | ||
|
|
6a737fb734 | ||
|
|
3cd92ccda3 | ||
|
|
54e81aba11 | ||
|
|
d86cb4ded6 | ||
|
|
4d5375f6d6 | ||
|
|
424557fedb | ||
|
|
89251e603f | ||
|
|
a653ed07eb | ||
|
|
ad86deb014 | ||
|
|
9525dc7584 | ||
|
|
cd31dd27fd | ||
|
|
360e3670eb | ||
|
|
8dabe3b4c8 | ||
|
|
443e0c2806 | ||
|
|
9cc173cc4d | ||
|
|
b5f33e5ecd | ||
|
|
40dfc6860f | ||
|
|
1c02a04423 | ||
|
|
de0e45070c | ||
|
|
c169cc7d74 | ||
|
|
cd62ad76f6 | ||
|
|
dd25b0fb5b | ||
|
|
a38b22a6a2 | ||
|
|
830b8f2971 | ||
|
|
b058af122c | ||
|
|
174ee0cafc | ||
|
|
1c336380c0 | ||
|
|
3068880413 | ||
|
|
be596681e5 | ||
|
|
66b71c50e9 | ||
|
|
8744810b25 | ||
|
|
7f94d37c2e | ||
|
|
6d9b7baeb4 | ||
|
|
4470d4c352 | ||
|
|
d2a462a279 | ||
|
|
14ff2a15e7 | ||
|
|
6d1369900e | ||
|
|
1f17ebe69e | ||
|
|
1ae2918064 | ||
|
|
b6571e5cad | ||
|
|
7549d48cf1 | ||
|
|
00353dd0cb | ||
|
|
afd947195d | ||
|
|
e57ef37167 | ||
|
|
ef33a93654 | ||
|
|
61732aecfc | ||
|
|
6764c05c3f | ||
|
|
fa149cf4aa | ||
|
|
e4f9697d06 | ||
|
|
da061450e5 | ||
|
|
d09ae49287 | ||
|
|
511ee0bbaf | ||
|
|
3cb5a0fbd6 | ||
|
|
e06925ab85 | ||
|
|
184634e4e7 | ||
|
|
843c2d02cc | ||
|
|
8ea2455766 | ||
|
|
9dc9987d56 | ||
|
|
3458621147 | ||
|
|
079df5a47c | ||
|
|
ddb07c65a1 | ||
|
|
9b21cd222b | ||
|
|
90f736843f | ||
|
|
13c020eb61 | ||
|
|
dbc06dbe95 | ||
|
|
23d097bc1c | ||
|
|
db85b9808e | ||
|
|
df5bae37bc | ||
|
|
acc23b6051 | ||
|
|
61f2741afc | ||
|
|
4dd7ea886a | ||
|
|
1e8959fbcf | ||
|
|
48729678cf | ||
|
|
0684becaa7 | ||
|
|
db16bdf8cb | ||
|
|
f890318ed9 | ||
|
|
158510cbbe | ||
|
|
ce90cf7aa8 | ||
|
|
a3a3d006eb | ||
|
|
8fd029a4a1 | ||
|
|
2e1b52c1e5 | ||
|
|
3eb8348708 | ||
|
|
393f0c007c | ||
|
|
294e380288 | ||
|
|
4c1c42efac | ||
|
|
c062ca8c66 | ||
|
|
76dcb25103 | ||
|
|
c5b4f236db | ||
|
|
0974c940a8 | ||
|
|
cffa20d37e | ||
|
|
ef009edd29 | ||
|
|
3ca52b118d | ||
|
|
13f5fde4fb | ||
|
|
f512b55ec2 | ||
|
|
22b8ca0095 | ||
|
|
baf66a103d | ||
|
|
45faa9c1ff | ||
|
|
304381a88d | ||
|
|
fc9f54dbc8 | ||
|
|
7199dc187f | ||
|
|
e9ae066d53 | ||
|
|
d71ae406ff | ||
|
|
f3216904b3 | ||
|
|
5958b69ec9 | ||
|
|
7d4e2cb39a | ||
|
|
a483ec0cea | ||
|
|
c1421e0874 | ||
|
|
ce89869c3c | ||
|
|
b8b57e34ff | ||
|
|
bc7f627253 | ||
|
|
652156e398 | ||
|
|
9febb071c6 | ||
|
|
7d0e1568ac | ||
|
|
b4e711f411 | ||
|
|
1b5be1b981 | ||
|
|
49d8707c58 | ||
|
|
9192f6f7f7 | ||
|
|
05022e3745 | ||
|
|
5356e9ddeb | ||
|
|
52acf76e2c | ||
|
|
40cdbd3b45 | ||
|
|
5487c0befe | ||
|
|
8bb16c48c0 | ||
|
|
c6384363f9 | ||
|
|
8993e8ad3e | ||
|
|
289989d9f7 | ||
|
|
dc2ae0e6f1 | ||
|
|
9c966c152d | ||
|
|
4efae41048 | ||
|
|
b8437032e9 | ||
|
|
2d339ca81b | ||
|
|
d53abc9696 | ||
|
|
446c886d38 | ||
|
|
30c6d9b5ae | ||
|
|
5e42996b36 | ||
|
|
ceca7b85bf | ||
|
|
a4d54f58c8 | ||
|
|
005a0e1bad | ||
|
|
46d97fd57d | ||
|
|
72a26b6353 | ||
|
|
89a4033fbf | ||
|
|
39a5dc64bd | ||
|
|
d4bdd9b1b7 | ||
|
|
2f5ba87280 | ||
|
|
8b45d6c750 | ||
|
|
4ecd4df2d4 | ||
|
|
a42f31fe52 | ||
|
|
d4480b695e | ||
|
|
c4b5f7fbae | ||
|
|
ba915f2cc0 | ||
|
|
4b91140f31 | ||
|
|
9879878dd0 | ||
|
|
d78105d57c | ||
|
|
153c9e3565 | ||
|
|
c11623596d | ||
|
|
e791a77f77 | ||
|
|
b641bffb2c | ||
|
|
ee0c47ac1e | ||
|
|
eba90e9343 | ||
|
|
d8374d0fa5 | ||
|
|
fa61744c6d | ||
|
|
4fec55cc01 | ||
|
|
1767413712 | ||
|
|
734c8fa84f | ||
|
|
9a8d422554 | ||
|
|
b21e945c76 | ||
|
|
a02bf1ea09 | ||
|
|
eda82bac92 | ||
|
|
e8d4f7dc4f | ||
|
|
c4a93b7789 | ||
|
|
c3f9925097 | ||
|
|
2a0cf7511a | ||
|
|
d0a70d3339 | ||
|
|
f37e4675dd | ||
|
|
4e32f67eeb | ||
|
|
36d54cab52 | ||
|
|
9d8df10dcf | ||
|
|
45ea88e070 | ||
|
|
d5d0b947f5 | ||
|
|
f775f1f11e | ||
|
|
f1e888f3de | ||
|
|
71c8436e90 | ||
|
|
08c69f5e9b | ||
|
|
a50fafaca2 | ||
|
|
3c6781d240 | ||
|
|
3b8b5625f8 | ||
|
|
6be2034110 | ||
|
|
924dc79f00 | ||
|
|
ccb9030d3c | ||
|
|
8623287ac1 | ||
|
|
022c13f3a4 | ||
|
|
0687916e7f | ||
|
|
bb868b83ba | ||
|
|
24298130b9 | ||
|
|
6e5ee92ebd | ||
|
|
5b91fe04aa | ||
|
|
1623deb3ee | ||
|
|
4a16e05b7a | ||
|
|
f1c04bc60d | ||
|
|
84c6f31c76 | ||
|
|
9d528190bf | ||
|
|
0f23b209ad | ||
|
|
63d9325900 | ||
|
|
f342097f81 | ||
|
|
b4806c4366 | ||
|
|
ff37d8a577 | ||
|
|
a773eb7893 | ||
|
|
7c67513d24 | ||
|
|
6ed85029c5 | ||
|
|
e9c57ddf4d | ||
|
|
a33ce97ed9 | ||
|
|
b788a3dd4e | ||
|
|
fccfa92d7e | ||
|
|
8705bf0a70 | ||
|
|
9318138af7 | ||
|
|
269fa7d2d5 | ||
|
|
e99837a8b9 | ||
|
|
553861a2c4 | ||
|
|
628a85d1be | ||
|
|
2cb54514a4 | ||
|
|
6db22827f2 | ||
|
|
4cc6d5426b | ||
|
|
7d258b5202 | ||
|
|
c8d19ee0bc | ||
|
|
d891312032 | ||
|
|
5edbf4ce32 | ||
|
|
3ddbdd713d | ||
|
|
9ba107b511 | ||
|
|
c9adddb76a | ||
|
|
f0a12d5ff5 | ||
|
|
7cce224499 | ||
|
|
97397ca585 | ||
|
|
f2fbc602a8 | ||
|
|
925d728a86 | ||
|
|
f5f229871b | ||
|
|
9917552b4b | ||
|
|
adca89b973 | ||
|
|
29bfbecdc9 | ||
|
|
1a7a8c98d9 | ||
|
|
cddb38ac3d | ||
|
|
394853c0fb | ||
|
|
c0702c8b36 | ||
|
|
d610608391 | ||
|
|
9082eec91d | ||
|
|
f1a1413b5f | ||
|
|
c1e7f9af9b | ||
|
|
1c71c4e38b | ||
|
|
5e3eccb3f6 | ||
|
|
e1dc037eb9 | ||
|
|
97e9b4c801 | ||
|
|
52d7cad735 | ||
|
|
c0b1d270ba | ||
|
|
e59a2892e4 | ||
|
|
5fa0376a49 | ||
|
|
05a33042c8 | ||
|
|
ce58f23cbc | ||
|
|
b6fc9fa370 | ||
|
|
00ae38faae | ||
|
|
ab28ee58ab | ||
|
|
48db538a2e | ||
|
|
46945942e1 | ||
|
|
a24b26a1ef | ||
|
|
6f8421cdd5 | ||
|
|
284cd9bca9 | ||
|
|
23fd6b8d2b | ||
|
|
4f0ea5d756 | ||
|
|
6c218331b1 | ||
|
|
cea7fb7490 | ||
|
|
8acf2dbdfe | ||
|
|
0542700f90 | ||
|
|
5264f7ce18 | ||
|
|
051ffd78a3 | ||
|
|
bea95d4fae | ||
|
|
fdf7bc312f | ||
|
|
5b094e1097 | ||
|
|
9ad3968084 | ||
|
|
3958b6aae1 | ||
|
|
eaa413caf0 | ||
|
|
9095225b5b | ||
|
|
c529f86dbc | ||
|
|
e4fcfa356a | ||
|
|
8218cff7c1 | ||
|
|
6949bbcf39 | ||
|
|
480c60c0a7 | ||
|
|
eec10cb5db | ||
|
|
02c83d8689 | ||
|
|
72b1cacea1 | ||
|
|
c72cda3386 | ||
|
|
867442155e | ||
|
|
229b14b6fc | ||
|
|
158c87ab8b | ||
|
|
cb303e6109 | ||
|
|
a77a8741b5 | ||
|
|
3d63459c25 | ||
|
|
ce63de3c58 | ||
|
|
4b3b1219b5 | ||
|
|
73b069a76c | ||
|
|
101cf8d108 | ||
|
|
2e926dfb6e | ||
|
|
501866d12a | ||
|
|
39bcb0869f | ||
|
|
a7b99cde4e | ||
|
|
60abcd92a3 | ||
|
|
cdd36e7052 | ||
|
|
c6ac175ce4 | ||
|
|
46bcd87c23 | ||
|
|
ab74be8e33 | ||
|
|
d8298b3eab | ||
|
|
50e60e6d05 | ||
|
|
5d02acbf37 | ||
|
|
8901d91f96 | ||
|
|
b55021bb3d | ||
|
|
0ef51b85e6 | ||
|
|
c77566cc02 | ||
|
|
c1bcedfb51 | ||
|
|
08b592816b | ||
|
|
8ef788e799 | ||
|
|
3ce57ef851 |
147
.github/ISSUE_TEMPLATE/1.bug.yml
vendored
147
.github/ISSUE_TEMPLATE/1.bug.yml
vendored
@@ -1,133 +1,46 @@
|
||||
name: Bug report 🐛
|
||||
description: 项目运行中遇到的Bug或问题。
|
||||
description: Report a bug or unexpected behavior.
|
||||
title: "[Bug] "
|
||||
labels: ['status: needs check']
|
||||
body:
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
### ⚠️ 前置确认
|
||||
1. 网络能够访问openai接口
|
||||
2. python 已安装:版本在 3.7 ~ 3.10 之间
|
||||
3. `git pull` 拉取最新代码
|
||||
4. 执行`pip3 install -r requirements.txt`,检查依赖是否满足
|
||||
5. 拓展功能请执行`pip3 install -r requirements-optional.txt`,检查依赖是否满足
|
||||
6. [FAQS](https://github.com/zhayujie/chatgpt-on-wechat/wiki/FAQs) 中无类似问题
|
||||
> 💡 English is recommended so global developers can help. 推荐使用英文提交,谢谢 ❤️
|
||||
- type: checkboxes
|
||||
attributes:
|
||||
label: 前置确认
|
||||
label: Self check
|
||||
options:
|
||||
- label: 我确认我运行的是最新版本的代码,并且安装了所需的依赖,在[FAQS](https://github.com/zhayujie/chatgpt-on-wechat/wiki/FAQs)中也未找到类似问题。
|
||||
- label: I'm on the latest version and searched [existing issues](https://github.com/zhayujie/CowAgent/issues) (incl. closed) — no duplicate.
|
||||
required: true
|
||||
- type: checkboxes
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: ⚠️ 搜索issues中是否已存在类似问题
|
||||
description: >
|
||||
请在 [历史issue](https://github.com/zhayujie/chatgpt-on-wechat/issues) 中清空输入框,搜索你的问题
|
||||
或相关日志的关键词来查找是否存在类似问题。
|
||||
options:
|
||||
- label: 我已经搜索过issues和disscussions,没有跟我遇到的问题相关的issue
|
||||
required: true
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
请在上方的`title`中填写你对你所遇到问题的简略总结,这将帮助其他人更好的找到相似问题,谢谢❤️。
|
||||
- type: dropdown
|
||||
attributes:
|
||||
label: 操作系统类型?
|
||||
description: >
|
||||
请选择你运行程序的操作系统类型。
|
||||
options:
|
||||
- Windows
|
||||
- Linux
|
||||
- MacOS
|
||||
- Docker
|
||||
- Railway
|
||||
- Windows Subsystem for Linux (WSL)
|
||||
- Other (请在问题中说明)
|
||||
validations:
|
||||
required: true
|
||||
- type: dropdown
|
||||
attributes:
|
||||
label: 运行的python版本是?
|
||||
description: |
|
||||
请选择你运行程序的`python`版本。
|
||||
注意:在`python 3.7`中,有部分可选依赖无法安装。
|
||||
经过长时间的观察,我们认为`python 3.8`是兼容性最好的版本。
|
||||
`python 3.7`~`python 3.10`以外版本的issue,将视情况直接关闭。
|
||||
options:
|
||||
- python 3.7
|
||||
- python 3.8
|
||||
- python 3.9
|
||||
- python 3.10
|
||||
- other
|
||||
validations:
|
||||
required: true
|
||||
- type: dropdown
|
||||
attributes:
|
||||
label: 使用的chatgpt-on-wechat版本是?
|
||||
description: |
|
||||
请确保你使用的是 [releases](https://github.com/zhayujie/chatgpt-on-wechat/releases) 中的最新版本。
|
||||
如果你使用git, 请使用`git branch`命令来查看分支。
|
||||
options:
|
||||
- Latest Release
|
||||
- Master (branch)
|
||||
validations:
|
||||
required: true
|
||||
- type: dropdown
|
||||
attributes:
|
||||
label: 运行的`channel`类型是?
|
||||
description: |
|
||||
请确保你正确配置了该`channel`所需的配置项,所有可选的配置项都写在了[该文件中](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/config.py),请将所需配置项填写在根目录下的`config.json`文件中。
|
||||
options:
|
||||
- wx(个人微信, itchat)
|
||||
- wxy(个人微信, wechaty)
|
||||
- wechatmp(公众号, 订阅号)
|
||||
- wechatmp_service(公众号, 服务号)
|
||||
- terminal
|
||||
- other
|
||||
label: Environment
|
||||
description: "Version (`cow status`), OS, Python version, install method, model & channel."
|
||||
placeholder: |
|
||||
Version: v1.2.0
|
||||
OS: macOS / Linux / Windows / Docker
|
||||
Python: 3.11
|
||||
Install: installer / Docker / source
|
||||
Model & channel: deepseek-v4-flash, web
|
||||
validations:
|
||||
required: true
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: 复现步骤 🕹
|
||||
description: |
|
||||
**⚠️ 不能复现将会关闭issue.**
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: 问题描述 😯
|
||||
description: 详细描述出现的问题,或提供有关截图。
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: 终端日志 📒
|
||||
description: |
|
||||
在此处粘贴终端日志,可在主目录下`run.log`文件中找到,这会帮助我们更好的分析问题,注意隐去你的API key。
|
||||
如果在配置文件中加入`"debug": true`,打印出的日志会更有帮助。
|
||||
label: What happened?
|
||||
description: "Steps to reproduce, what you expected, and what happened instead. Screenshots welcome."
|
||||
placeholder: |
|
||||
1. ...
|
||||
2. ...
|
||||
|
||||
<details>
|
||||
<summary><i>示例</i></summary>
|
||||
```log
|
||||
[DEBUG][2023-04-16 00:23:22][plugin_manager.py:157] - Plugin SUMMARY triggered by event Event.ON_HANDLE_CONTEXT
|
||||
[DEBUG][2023-04-16 00:23:22][main.py:221] - [Summary] on_handle_context. content: $总结前100条消息
|
||||
[DEBUG][2023-04-16 00:23:24][main.py:240] - [Summary] limit: 100, duration: -1 seconds
|
||||
[ERROR][2023-04-16 00:23:24][chat_channel.py:244] - Worker return exception: name 'start_date' is not defined
|
||||
Traceback (most recent call last):
|
||||
File "C:\ProgramData\Anaconda3\lib\concurrent\futures\thread.py", line 57, in run
|
||||
result = self.fn(*self.args, **self.kwargs)
|
||||
File "D:\project\chatgpt-on-wechat\channel\chat_channel.py", line 132, in _handle
|
||||
reply = self._generate_reply(context)
|
||||
File "D:\project\chatgpt-on-wechat\channel\chat_channel.py", line 142, in _generate_reply
|
||||
e_context = PluginManager().emit_event(EventContext(Event.ON_HANDLE_CONTEXT, {
|
||||
File "D:\project\chatgpt-on-wechat\plugins\plugin_manager.py", line 159, in emit_event
|
||||
instance.handlers[e_context.event](e_context, *args, **kwargs)
|
||||
File "D:\project\chatgpt-on-wechat\plugins\summary\main.py", line 255, in on_handle_context
|
||||
records = self._get_records(session_id, start_time, limit)
|
||||
File "D:\project\chatgpt-on-wechat\plugins\summary\main.py", line 96, in _get_records
|
||||
c.execute("SELECT * FROM chat_records WHERE sessionid=? and timestamp>? ORDER BY timestamp DESC LIMIT ?", (session_id, start_date, limit))
|
||||
NameError: name 'start_date' is not defined
|
||||
[INFO][2023-04-16 00:23:36][app.py:14] - signal 2 received, exiting...
|
||||
```
|
||||
</details>
|
||||
value: |
|
||||
```log
|
||||
<此处粘贴终端日志>
|
||||
```
|
||||
Expected: ...
|
||||
Actual: ...
|
||||
validations:
|
||||
required: true
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Logs
|
||||
description: "Relevant logs from `run.log` (set `\"debug\": true` for more detail). ⚠️ Redact your API keys."
|
||||
render: shell
|
||||
validations:
|
||||
required: false
|
||||
|
||||
31
.github/ISSUE_TEMPLATE/2.feature.yml
vendored
31
.github/ISSUE_TEMPLATE/2.feature.yml
vendored
@@ -1,28 +1,33 @@
|
||||
name: Feature request 🚀
|
||||
description: 提出你对项目的新想法或建议。
|
||||
description: Suggest a new idea or improvement.
|
||||
title: "[Feature] "
|
||||
labels: ['status: needs check']
|
||||
body:
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
请在上方的`title`中填写简略总结,谢谢❤️。
|
||||
> 💡 English is recommended so global developers can help. 推荐使用英文提交,谢谢 ❤️
|
||||
- type: checkboxes
|
||||
attributes:
|
||||
label: ⚠️ 搜索是否存在类似issue
|
||||
description: >
|
||||
请在 [历史issue](https://github.com/zhayujie/chatgpt-on-wechat/issues) 中清空输入框,搜索关键词查找是否存在相似issue。
|
||||
label: Self check
|
||||
options:
|
||||
- label: 我已经搜索过issues和disscussions,没有发现相似issue
|
||||
- label: I searched [existing issues](https://github.com/zhayujie/CowAgent/issues) (incl. closed) — no duplicate.
|
||||
required: true
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: 总结
|
||||
description: 描述feature的功能。
|
||||
label: What's the problem?
|
||||
description: "The pain point or what's not working for you right now."
|
||||
validations:
|
||||
required: true
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: 举例
|
||||
description: 提供聊天示例,草图或相关网址。
|
||||
- type: textarea
|
||||
label: What would you like?
|
||||
description: "How you'd expect it to work. Examples, sketches, or links welcome."
|
||||
validations:
|
||||
required: false
|
||||
- type: checkboxes
|
||||
attributes:
|
||||
label: 动机
|
||||
description: 描述你提出该feature的动机,比如没有这项feature对你的使用造成了怎样的影响。 请提供更详细的场景描述,这可能会帮助我们发现并提出更好的解决方案。
|
||||
label: Contribution
|
||||
options:
|
||||
- label: I'd be interested in helping implement this.
|
||||
required: false
|
||||
|
||||
5
.github/ISSUE_TEMPLATE/config.yml
vendored
Normal file
5
.github/ISSUE_TEMPLATE/config.yml
vendored
Normal file
@@ -0,0 +1,5 @@
|
||||
blank_issues_enabled: true
|
||||
contact_links:
|
||||
- name: 📖 Documentation
|
||||
url: https://docs.cowagent.ai
|
||||
about: Setup guides, configuration, and FAQ.
|
||||
22
.github/PULL_REQUEST_TEMPLATE.md
vendored
Normal file
22
.github/PULL_REQUEST_TEMPLATE.md
vendored
Normal file
@@ -0,0 +1,22 @@
|
||||
<!--
|
||||
Thanks for your contribution! Please write this PR in English.
|
||||
推荐使用英文填写,感谢 ❤️
|
||||
-->
|
||||
|
||||
## What does this PR do?
|
||||
|
||||
<!-- A short description of the change and why it's needed. -->
|
||||
|
||||
## Type of change
|
||||
|
||||
- [ ] Bug fix
|
||||
- [ ] New feature
|
||||
- [ ] Docs
|
||||
- [ ] Refactor / chore
|
||||
|
||||
## Checklist
|
||||
|
||||
- [ ] I have read the [Contributing Guide](https://github.com/zhayujie/CowAgent/blob/master/CONTRIBUTING.md)
|
||||
- [ ] I tested this change locally
|
||||
- [ ] Code comments and docs are in English
|
||||
- [ ] Linked related issue (if any): closes #
|
||||
11
.github/workflows/deploy-image-arm.yml
vendored
11
.github/workflows/deploy-image-arm.yml
vendored
@@ -19,7 +19,7 @@ env:
|
||||
|
||||
jobs:
|
||||
build-and-push-image:
|
||||
if: github.repository == 'zhayujie/chatgpt-on-wechat'
|
||||
if: github.repository == 'zhayujie/CowAgent'
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: read
|
||||
@@ -51,7 +51,12 @@ jobs:
|
||||
uses: docker/metadata-action@v4
|
||||
with:
|
||||
images: |
|
||||
${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
|
||||
${{ env.REGISTRY }}/zhayujie/chatgpt-on-wechat
|
||||
${{ env.REGISTRY }}/zhayujie/cowagent
|
||||
tags: |
|
||||
type=raw,value=latest-arm64,enable={{is_default_branch}}
|
||||
type=ref,event=branch,suffix=-arm64
|
||||
type=ref,event=tag,suffix=-arm64
|
||||
|
||||
- name: Build and push Docker image
|
||||
uses: docker/build-push-action@v3
|
||||
@@ -60,7 +65,7 @@ jobs:
|
||||
push: true
|
||||
file: ./docker/Dockerfile.latest
|
||||
platforms: linux/arm64
|
||||
tags: ${{ steps.meta.outputs.tags }}-arm64
|
||||
tags: ${{ steps.meta.outputs.tags }}
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
|
||||
- uses: actions/delete-package-versions@v4
|
||||
|
||||
13
.github/workflows/deploy-image.yml
vendored
13
.github/workflows/deploy-image.yml
vendored
@@ -16,10 +16,11 @@ on:
|
||||
env:
|
||||
REGISTRY: ghcr.io
|
||||
IMAGE_NAME: ${{ github.repository }}
|
||||
DOCKERHUB_IMAGE: zhayujie/chatgpt-on-wechat
|
||||
|
||||
jobs:
|
||||
build-and-push-image:
|
||||
if: github.repository == 'zhayujie/chatgpt-on-wechat'
|
||||
if: github.repository == 'zhayujie/CowAgent'
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: read
|
||||
@@ -47,8 +48,14 @@ jobs:
|
||||
uses: docker/metadata-action@v4
|
||||
with:
|
||||
images: |
|
||||
${{ env.IMAGE_NAME }}
|
||||
${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
|
||||
zhayujie/chatgpt-on-wechat
|
||||
zhayujie/cowagent
|
||||
${{ env.REGISTRY }}/zhayujie/chatgpt-on-wechat
|
||||
${{ env.REGISTRY }}/zhayujie/cowagent
|
||||
tags: |
|
||||
type=raw,value=latest,enable={{is_default_branch}}
|
||||
type=ref,event=branch
|
||||
type=ref,event=tag
|
||||
|
||||
- name: Build and push Docker image
|
||||
uses: docker/build-push-action@v3
|
||||
|
||||
13
.gitignore
vendored
13
.gitignore
vendored
@@ -3,16 +3,15 @@
|
||||
.vscode
|
||||
.venv
|
||||
.vs
|
||||
.wechaty/
|
||||
__pycache__/
|
||||
venv*
|
||||
*.pyc
|
||||
python
|
||||
config.json
|
||||
QR.png
|
||||
nohup.out
|
||||
tmp
|
||||
plugins.json
|
||||
itchat.pkl
|
||||
*.log
|
||||
logs/
|
||||
workspace
|
||||
@@ -33,8 +32,16 @@ plugins/banwords/lib/__pycache__
|
||||
!plugins/role
|
||||
!plugins/keyword
|
||||
!plugins/linkai
|
||||
!plugins/agent
|
||||
!plugins/cow_cli
|
||||
client_config.json
|
||||
ref/
|
||||
**/.dev.vars
|
||||
.cursor/
|
||||
local/
|
||||
node_modules/
|
||||
|
||||
# cow cli
|
||||
dist/
|
||||
build/
|
||||
*.egg-info/
|
||||
.cow.pid
|
||||
|
||||
61
CONTRIBUTING.md
Normal file
61
CONTRIBUTING.md
Normal file
@@ -0,0 +1,61 @@
|
||||
# Contributing to CowAgent
|
||||
|
||||
Thanks for taking the time to contribute! 🎉 CowAgent is built by a global
|
||||
community, and contributions of all sizes are welcome — from typo fixes to new
|
||||
features.
|
||||
|
||||
## Language policy
|
||||
|
||||
To keep the project accessible to a global community, **please write issues,
|
||||
pull requests, code comments, and commit messages in English.**
|
||||
|
||||
> 为方便全球开发者协作,请尽量使用**英文**提交 issue、PR、代码注释与
|
||||
> commit message。不必担心英文不完美——表达清楚即可,工具翻译也完全没问题。感谢理解 ❤️
|
||||
|
||||
## Reporting issues
|
||||
|
||||
Found a bug or have an idea? [Open an issue](https://github.com/zhayujie/CowAgent/issues/new/choose).
|
||||
|
||||
Before opening one, please search existing issues (including closed ones) to
|
||||
avoid duplicates, and make sure you're on the latest version.
|
||||
|
||||
## Submitting a pull request
|
||||
|
||||
1. **Fork** the repo and create a branch from `master`
|
||||
(e.g. `feat/web-search`, `fix/telegram-reconnect`).
|
||||
2. Make your change. Keep it focused — one logical change per PR.
|
||||
3. Follow the existing code style. Write comments and docstrings in English.
|
||||
4. Run the app locally to confirm your change works.
|
||||
5. Open a PR with a clear title and a short description of **what** and **why**.
|
||||
|
||||
We keep the bar friendly: clear, focused, and working is enough. Maintainers are
|
||||
happy to help polish details during review.
|
||||
|
||||
### Commit & PR titles
|
||||
|
||||
Use a short, imperative summary. The [Conventional Commits](https://www.conventionalcommits.org/)
|
||||
style is preferred but not required:
|
||||
|
||||
```
|
||||
feat: add web search tool
|
||||
fix: reconnect Telegram websocket on timeout
|
||||
docs: clarify Docker setup
|
||||
```
|
||||
|
||||
## Development setup
|
||||
|
||||
See the [Install from Source](https://docs.cowagent.ai/guide/manual-install)
|
||||
guide. In short:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/zhayujie/CowAgent.git
|
||||
cd CowAgent
|
||||
pip install -r requirements.txt
|
||||
pip install -e .
|
||||
cow start
|
||||
```
|
||||
|
||||
## Code of conduct
|
||||
|
||||
Be respectful and constructive. We want CowAgent to be a welcoming place for
|
||||
everyone.
|
||||
899
README.md
899
README.md
@@ -1,738 +1,261 @@
|
||||
<p align="center"><img src= "https://github.com/user-attachments/assets/31fb4eab-3be4-477d-aa76-82cf62bfd12c" alt="Chatgpt-on-Wechat" width="600" /></p>
|
||||
<p align="center"><img src="https://github.com/user-attachments/assets/eca9a9ec-8534-4615-9e0f-96c5ac1d10a3" alt="CowAgent" width="420" /></p>
|
||||
|
||||
<p align="center">
|
||||
<a href="https://github.com/zhayujie/chatgpt-on-wechat/releases/latest"><img src="https://img.shields.io/github/v/release/zhayujie/chatgpt-on-wechat" alt="Latest release"></a>
|
||||
<a href="https://github.com/zhayujie/chatgpt-on-wechat/blob/master/LICENSE"><img src="https://img.shields.io/github/license/zhayujie/chatgpt-on-wechat" alt="License: MIT"></a>
|
||||
<a href="https://github.com/zhayujie/chatgpt-on-wechat"><img src="https://img.shields.io/github/stars/zhayujie/chatgpt-on-wechat?style=flat-square" alt="Stars"></a> <br/>
|
||||
<a href="https://github.com/zhayujie/CowAgent/releases/latest"><img src="https://img.shields.io/github/v/release/zhayujie/CowAgent" alt="Latest release"></a>
|
||||
<a href="https://github.com/zhayujie/CowAgent/blob/master/LICENSE"><img src="https://img.shields.io/github/license/zhayujie/CowAgent" alt="License: MIT"></a>
|
||||
<a href="https://github.com/zhayujie/CowAgent"><img src="https://img.shields.io/github/stars/zhayujie/CowAgent?style=flat-square" alt="Stars"></a> <br/>
|
||||
[English] | [<a href="docs/zh/README.md">中文</a>] | [<a href="docs/ja/README.md">日本語</a>]
|
||||
</p>
|
||||
|
||||
**chatgpt-on-wechat**(简称CoW)项目是基于大模型的智能对话机器人,支持自由切换多种模型,可接入网页、微信公众号、企业微信应用、飞书、钉钉中使用,能处理文本、语音、图片、文件等多模态消息,支持通过插件访问操作系统和互联网等外部资源,以及基于自有知识库定制企业AI应用。
|
||||
**CowAgent** is an open-source super AI assistant that proactively plans tasks, controls your computer and external services, creates and runs Skills, and grows alongside you through a personal knowledge base and long-term memory — a reference implementation of Agent Harness engineering.
|
||||
|
||||
# 简介
|
||||
CowAgent is lightweight, easy to deploy, and built to extend. Plug in any major LLM provider and run it 24/7 on a personal computer or server, across the web and all major IM platforms.
|
||||
|
||||
> 该项目既是一个可以开箱即用的对话机器人,也是一个支持高度扩展的AI应用框架,可以通过为项目添加大模型接口、接入渠道、自定义插件来灵活实现各种定制需求。支持的功能如下:
|
||||
|
||||
- ✅ **多端部署:** 有多种部署方式可选择且功能完备,目前已支持网页、微信公众号、企业微信应用、飞书、钉钉等部署方式
|
||||
- ✅ **基础对话:** 私聊及群聊的AI智能回复,支持多轮会话上下文记忆,基础模型支持OpenAI, Claude, Gemini, DeepSeek, 通义千问, Kimi, 文心一言, 讯飞星火, ChatGLM, MiniMax, GiteeAI, ModelScope, LinkAI
|
||||
- ✅ **语音能力:** 可识别语音消息,通过文字或语音回复,支持 openai(whisper/tts), azure, baidu, google 等多种语音模型
|
||||
- ✅ **图像能力:** 支持图片生成、图片识别、图生图,可选择 Dall-E-3, stable diffusion, replicate, midjourney, CogView-3, vision模型
|
||||
- ✅ **丰富插件:** 支持自定义插件扩展,已实现多角色切换、敏感词过滤、聊天记录总结、文档总结和对话、联网搜索、智能体等内置插件
|
||||
- ✅ **Agent能力:** 支持访问浏览器、终端、文件系统、搜索引擎等各类工具,并可通过多智能体协作完成复杂任务,基于 [AgentMesh](https://github.com/MinimalFuture/AgentMesh) 框架实现
|
||||
- ✅ **知识库:** 通过上传知识库自定义专属机器人,可作为数字分身、智能客服、企业智能体使用,基于 [LinkAI](https://link-ai.tech) 实现
|
||||
|
||||
## 声明
|
||||
|
||||
1. 本项目遵循 [MIT开源协议](/LICENSE),仅用于技术研究和学习,使用本项目时需遵守所在地法律法规、相关政策以及企业章程,禁止用于任何违法或侵犯他人权益的行为。任何个人、团队和企业,无论以何种方式使用该项目、对何对象提供服务,所产生的一切后果,本项目均不承担任何责任
|
||||
2. 境内使用该项目时,建议使用国内厂商的大模型服务,并进行必要的内容安全审核及过滤
|
||||
3. 本项目当前主要接入协同办公平台,推荐使用网页、公众号、企微自建应用、钉钉、飞书等接入通道,其他通道为历史产物暂不维护
|
||||
|
||||
## 演示
|
||||
|
||||
DEMO视频:https://cdn.link-ai.tech/doc/cow_demo.mp4
|
||||
|
||||
## 社区
|
||||
|
||||
添加小助手微信加入开源项目交流群:
|
||||
|
||||
<img width="140" src="https://img-1317903499.cos.ap-guangzhou.myqcloud.com/docs/open-community.png">
|
||||
<p align="center">
|
||||
<a href="https://cowagent.ai/">🌐 Website</a> ·
|
||||
<a href="https://docs.cowagent.ai/intro/index">📖 Docs</a> ·
|
||||
<a href="https://docs.cowagent.ai/guide/quick-start">🚀 Quick Start</a> ·
|
||||
<a href="https://skills.cowagent.ai/">🧩 Skill Hub</a> ·
|
||||
<a href="https://link-ai.tech/cowagent/create">☁️ Try Online</a>
|
||||
</p>
|
||||
|
||||
<br/>
|
||||
|
||||
# 企业服务
|
||||
## 🌟 Highlights
|
||||
|
||||
<a href="https://link-ai.tech" target="_blank"><img width="720" src="https://cdn.link-ai.tech/image/link-ai-intro.jpg"></a>
|
||||
|
||||
> [LinkAI](https://link-ai.tech/) 是面向企业和开发者的一站式AI智能体平台,聚合多模态大模型、知识库、Agent 插件、工作流等能力,支持一键接入主流平台并进行管理,支持SaaS、私有化部署等多种模式。
|
||||
>
|
||||
> LinkAI 目前已在智能客服、私域运营、企业效率助手等场景积累了丰富的AI解决方案,在消费、健康、文教、科技制造等各行业沉淀了大模型落地应用的最佳实践,致力于帮助更多企业和开发者拥抱 AI 生产力。
|
||||
|
||||
**产品咨询和企业服务** 可联系产品客服:
|
||||
|
||||
<img width="150" src="https://cdn.link-ai.tech/portal/linkai-customer-service.png">
|
||||
| Capability | Description |
|
||||
| :--- | :--- |
|
||||
| [Planning](https://docs.cowagent.ai/intro/architecture) | Decomposes complex tasks and executes them step by step, looping over tools until the goal is reached |
|
||||
| [Memory](https://docs.cowagent.ai/memory/index) | Three-tier architecture (context → daily → core), automatic Deep Dream distillation, hybrid keyword + vector retrieval |
|
||||
| [Knowledge](https://docs.cowagent.ai/knowledge/index) | Auto-curates structured knowledge into a Markdown wiki, builds an evolving knowledge graph with visual browsing |
|
||||
| [Skills](https://docs.cowagent.ai/skills/index) | One-click install from [Skill Hub](https://skills.cowagent.ai/), GitHub, ClawHub; or create custom skills via natural-language conversation |
|
||||
| [Tools](https://docs.cowagent.ai/tools/index) | Built-in file I/O, terminal, browser, scheduler, memory retrieval, web search, and 10+ more tools — with native MCP integration |
|
||||
| [Channels](https://docs.cowagent.ai/channels/index) | Integrates with Web, WeChat, Feishu, DingTalk, WeCom, QQ, Official Accounts, Telegram, and Slack |
|
||||
| Multimodal | First-class support for text, images, voice, and files — recognition, generation, and delivery |
|
||||
| [Models](https://docs.cowagent.ai/models/index) | Claude, GPT, Gemini, DeepSeek, Qwen, GLM, Kimi, MiniMax, Doubao, and more — swap providers from the Web console with one click |
|
||||
| [Deploy](https://docs.cowagent.ai/guide/quick-start) | One-line installer, unified Web console, multiple deployment modes (local, Docker, server) |
|
||||
|
||||
<br/>
|
||||
|
||||
# 🏷 更新日志
|
||||
## 🏗️ Architecture
|
||||
|
||||
>**2025.05.23:** [1.7.6版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.7.6) 优化web网页channel、新增 [AgentMesh多智能体插件](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/plugins/agent/README.md)、百度语音合成优化、企微应用`access_token`获取优化、支持`claude-4-sonnet`和`claude-4-opus`模型
|
||||
<img src="https://cdn.jsdelivr.net/gh/zhayujie/cowagent-assets@main/architecture/en/architecture.jpg" alt="CowAgent Architecture" width="750"/>
|
||||
|
||||
>**2025.04.11:** [1.7.5版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.7.5) 新增支持 [wechatferry](https://github.com/zhayujie/chatgpt-on-wechat/pull/2562) 协议、新增 deepseek 模型、新增支持腾讯云语音能力、新增支持 ModelScope 和 Gitee-AI API接口
|
||||
CowAgent is a complete **Agent Harness**: messages flow in through **Channels**; the **Agent Core** plans and reasons over memory, knowledge, and the available tools and skills; **Models** generate the response, which is sent back through the originating channel. Every layer is decoupled and independently extensible.
|
||||
|
||||
>**2024.12.13:** [1.7.4版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.7.4) 新增 Gemini 2.0 模型、新增web channel、解决内存泄漏问题、解决 `#reloadp` 命令重载不生效问题
|
||||
|
||||
>**2024.10.31:** [1.7.3版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.7.3) 程序稳定性提升、数据库功能、Claude模型优化、linkai插件优化、离线通知
|
||||
|
||||
更多更新历史请查看: [更新日志](/docs/version/release-notes.md)
|
||||
Read more in [Architecture](https://docs.cowagent.ai/intro/architecture).
|
||||
|
||||
<br/>
|
||||
|
||||
# 🚀 快速开始
|
||||
## 🚀 Quick Start
|
||||
|
||||
项目提供了一键安装、启动、管理程序的脚本,可以选择使用脚本快速运行,也可以根据详细指引一步步安装运行。
|
||||
A one-line installer takes care of dependencies, configuration, and startup:
|
||||
|
||||
- 详细文档:[快速开始](https://docs.link-ai.tech/cow/quick-start)
|
||||
|
||||
- 一键安装脚本说明:[一键安装脚本](https://github.com/zhayujie/chatgpt-on-wechat/wiki/%E4%B8%80%E9%94%AE%E5%AE%89%E8%A3%85%E5%90%AF%E5%8A%A8%E8%84%9A%E6%9C%AC)
|
||||
**Linux / macOS:**
|
||||
|
||||
```bash
|
||||
bash <(curl -sS https://cdn.link-ai.tech/code/cow/install.sh)
|
||||
bash <(curl -fsSL https://cdn.link-ai.tech/code/cow/run.sh)
|
||||
```
|
||||
|
||||
- 项目管理脚本说明:[项目管理脚本](https://github.com/zhayujie/chatgpt-on-wechat/wiki/%E9%A1%B9%E7%9B%AE%E7%AE%A1%E7%90%86%E8%84%9A%E6%9C%AC)
|
||||
**Windows (PowerShell):**
|
||||
|
||||
## 一、准备
|
||||
```powershell
|
||||
irm https://cdn.link-ai.tech/code/cow/run.ps1 | iex
|
||||
```
|
||||
|
||||
### 1. 模型账号
|
||||
|
||||
项目默认使用ChatGPT模型,需前往 [OpenAI平台](https://platform.openai.com/api-keys) 创建API Key并填入项目配置文件中。同时支持其他国内外产商以及第三方自定义模型接口,详情参考:[模型说明](#模型说明)。
|
||||
|
||||
同时支持使用 **LinkAI平台** 接口,可聚合使用 OpenAI、Claude、DeepSeek、Kimi、Qwen 等多种常用模型,并支持知识库、工作流、联网搜索、MJ绘图、文档总结等能力。修改配置即可一键启用,参考 [接入文档](https://link-ai.tech/platform/link-app/wechat)。
|
||||
|
||||
### 2.环境安装
|
||||
|
||||
支持 Linux、MacOS、Windows 系统,同时需安装 `Python`,Python版本需要在3.7以上,推荐使用3.9版本。
|
||||
|
||||
> 注意:选择Docker部署则无需安装python环境和下载源码,可直接快进到下一节。
|
||||
|
||||
**(1) 克隆项目代码:**
|
||||
**Docker:**
|
||||
|
||||
```bash
|
||||
git clone https://github.com/zhayujie/chatgpt-on-wechat
|
||||
cd chatgpt-on-wechat/
|
||||
curl -O https://cdn.link-ai.tech/code/cow/docker-compose.yml
|
||||
docker compose up -d
|
||||
```
|
||||
|
||||
若遇到网络问题可使用国内仓库地址:https://gitee.com/zhayujie/chatgpt-on-wechat
|
||||
Once started, open `http://localhost:9899` to access the **Web console** — your one-stop hub to chat with the Agent, configure models, connect channels, and install skills.
|
||||
|
||||
**(2) 安装核心依赖 (必选):**
|
||||
> Deploying on a server? Set `web_host` to `0.0.0.0` in `config.json` to make the console reachable from outside, and set `web_password` to protect it. Don't forget to open port `9899` in your firewall or security group.
|
||||
|
||||
> 📖 Detailed guides: [Quick Start](https://docs.cowagent.ai/guide/quick-start) · [Install from Source](https://docs.cowagent.ai/guide/manual-install) · [Upgrade](https://docs.cowagent.ai/guide/upgrade)
|
||||
|
||||
After installation, manage the service with the [cow CLI](https://docs.cowagent.ai/cli/index):
|
||||
|
||||
```bash
|
||||
pip3 install -r requirements.txt
|
||||
cow start | stop | restart # service control
|
||||
cow status | logs # status and logs
|
||||
cow update # pull latest code and restart
|
||||
cow skill install <name> # install a skill
|
||||
cow install-browser # install browser automation
|
||||
```
|
||||
|
||||
**(3) 拓展依赖 (可选,建议安装):**
|
||||
|
||||
```bash
|
||||
pip3 install -r requirements-optional.txt
|
||||
```
|
||||
如果某项依赖安装失败可注释掉对应的行后重试。
|
||||
|
||||
## 二、配置
|
||||
|
||||
配置文件的模板在根目录的`config-template.json`中,需复制该模板创建最终生效的 `config.json` 文件:
|
||||
|
||||
```bash
|
||||
cp config-template.json config.json
|
||||
```
|
||||
|
||||
然后在`config.json`中填入配置,以下是对默认配置的说明,可根据需要进行自定义修改(注意实际使用时请去掉注释,保证JSON格式的规范):
|
||||
|
||||
```bash
|
||||
# config.json 文件内容示例
|
||||
{
|
||||
"channel_type": "web", # 接入渠道类型,默认为web,支持修改为:terminal, wechatmp, wechatmp_service, wechatcom_app, dingtalk, feishu
|
||||
"model": "gpt-4o-mini", # 模型名称, 支持 gpt-4o-mini, gpt-4.1, gpt-4o, deepseek-reasoner, wenxin, xunfei, glm-4, claude-3-7-sonnet-latest, moonshot等
|
||||
"open_ai_api_key": "YOUR API KEY", # 如果使用openAI模型则填入上面创建的 OpenAI API KEY
|
||||
"open_ai_api_base": "https://api.openai.com/v1", # OpenAI接口代理地址,修改此项可接入第三方模型接口
|
||||
"proxy": "", # 代理客户端的ip和端口,国内环境开启代理的需要填写该项,如 "127.0.0.1:7890"
|
||||
"single_chat_prefix": ["bot", "@bot"], # 私聊时文本需要包含该前缀才能触发机器人回复
|
||||
"single_chat_reply_prefix": "[bot] ", # 私聊时自动回复的前缀,用于区分真人
|
||||
"group_chat_prefix": ["@bot"], # 群聊时包含该前缀则会触发机器人回复
|
||||
"group_name_white_list": ["ChatGPT测试群", "ChatGPT测试群2"], # 开启自动回复的群名称列表
|
||||
"group_chat_in_one_session": ["ChatGPT测试群"], # 支持会话上下文共享的群名称
|
||||
"image_create_prefix": ["画", "看", "找"], # 开启图片回复的前缀
|
||||
"conversation_max_tokens": 1000, # 支持上下文记忆的最多字符数
|
||||
"speech_recognition": false, # 是否开启语音识别
|
||||
"group_speech_recognition": false, # 是否开启群组语音识别
|
||||
"voice_reply_voice": false, # 是否使用语音回复语音
|
||||
"character_desc": "你是基于大语言模型的AI智能助手,旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。", # 系统提示词
|
||||
# 订阅欢迎语,公众号和企业微信channel中使用,当被订阅时会自动回复以下内容
|
||||
"subscribe_msg": "感谢您的关注!\n这里是AI智能助手,可以自由对话。\n支持语音对话。\n支持图片输入。\n支持图片输出,画字开头的消息将按要求创作图片。\n支持tool、角色扮演和文字冒险等丰富的插件。\n输入{trigger_prefix}#help 查看详细指令。",
|
||||
"use_linkai": false, # 是否使用LinkAI接口,默认关闭,设置为true后可对接LinkAI平台的智能体
|
||||
"linkai_api_key": "", # LinkAI Api Key
|
||||
"linkai_app_code": "" # LinkAI 应用或工作流的code
|
||||
}
|
||||
```
|
||||
|
||||
**详细配置说明:**
|
||||
|
||||
<details>
|
||||
<summary>1. 单聊配置</summary>
|
||||
|
||||
+ 个人聊天中,需要以 "bot"或"@bot" 为开头的内容触发机器人,对应配置项 `single_chat_prefix` (如果不需要以前缀触发可以填写 `"single_chat_prefix": [""]`)
|
||||
+ 机器人回复的内容会以 "[bot] " 作为前缀, 以区分真人,对应的配置项为 `single_chat_reply_prefix` (如果不需要前缀可以填写 `"single_chat_reply_prefix": ""`)
|
||||
</details>
|
||||
|
||||
|
||||
<details>
|
||||
<summary>2. 群聊配置</summary>
|
||||
|
||||
+ 群组聊天中,群名称需配置在 `group_name_white_list ` 中才能开启群聊自动回复。如果想对所有群聊生效,可以直接填写 `"group_name_white_list": ["ALL_GROUP"]`
|
||||
+ 默认只要被人 @ 就会触发机器人自动回复;另外群聊天中只要检测到以 "@bot" 开头的内容,同样会自动回复(方便自己触发),这对应配置项 `group_chat_prefix`
|
||||
+ 可选配置: `group_name_keyword_white_list`配置项支持模糊匹配群名称,`group_chat_keyword`配置项则支持模糊匹配群消息内容,用法与上述两个配置项相同。(Contributed by [evolay](https://github.com/evolay))
|
||||
+ `group_chat_in_one_session`:使群聊共享一个会话上下文,配置 `["ALL_GROUP"]` 则作用于所有群聊
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>3. 语音配置</summary>
|
||||
|
||||
+ 添加 `"speech_recognition": true` 将开启语音识别,默认使用openai的whisper模型识别为文字,同时以文字回复,该参数仅支持私聊 (注意由于语音消息无法匹配前缀,一旦开启将对所有语音自动回复,支持语音触发画图);
|
||||
+ 添加 `"group_speech_recognition": true` 将开启群组语音识别,默认使用openai的whisper模型识别为文字,同时以文字回复,参数仅支持群聊 (会匹配group_chat_prefix和group_chat_keyword, 支持语音触发画图);
|
||||
+ 添加 `"voice_reply_voice": true` 将开启语音回复语音(同时作用于私聊和群聊)
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>4. 其他配置</summary>
|
||||
|
||||
+ `model`: 模型名称,目前支持 `gpt-4o-mini`, `gpt-4.1`, `gpt-4o`, `gpt-3.5-turbo`, `wenxin` , `claude` , `gemini`, `glm-4`, `xunfei`, `moonshot`等,全部模型名称参考[common/const.py](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/common/const.py)文件
|
||||
+ `temperature`,`frequency_penalty`,`presence_penalty`: Chat API接口参数,详情参考[OpenAI官方文档。](https://platform.openai.com/docs/api-reference/chat)
|
||||
+ `proxy`:由于目前 `openai` 接口国内无法访问,需配置代理客户端的地址,详情参考 [#351](https://github.com/zhayujie/chatgpt-on-wechat/issues/351)
|
||||
+ 对于图像生成,在满足个人或群组触发条件外,还需要额外的关键词前缀来触发,对应配置 `image_create_prefix `
|
||||
+ 关于OpenAI对话及图片接口的参数配置(内容自由度、回复字数限制、图片大小等),可以参考 [对话接口](https://beta.openai.com/docs/api-reference/completions) 和 [图像接口](https://beta.openai.com/docs/api-reference/completions) 文档,在[`config.py`](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/config.py)中检查哪些参数在本项目中是可配置的。
|
||||
+ `conversation_max_tokens`:表示能够记忆的上下文最大字数(一问一答为一组对话,如果累积的对话字数超出限制,就会优先移除最早的一组对话)
|
||||
+ `rate_limit_chatgpt`,`rate_limit_dalle`:每分钟最高问答速率、画图速率,超速后排队按序处理。
|
||||
+ `clear_memory_commands`: 对话内指令,主动清空前文记忆,字符串数组可自定义指令别名。
|
||||
+ `hot_reload`: 程序退出后,暂存等于状态,默认关闭。
|
||||
+ `character_desc` 配置中保存着你对机器人说的一段话,他会记住这段话并作为他的设定,你可以为他定制任何人格 (关于会话上下文的更多内容参考该 [issue](https://github.com/zhayujie/chatgpt-on-wechat/issues/43))
|
||||
+ `subscribe_msg`:订阅消息,公众号和企业微信channel中请填写,当被订阅时会自动回复, 可使用特殊占位符。目前支持的占位符有{trigger_prefix},在程序中它会自动替换成bot的触发词。
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>5. LinkAI配置</summary>
|
||||
|
||||
+ `use_linkai`: 是否使用LinkAI接口,默认关闭,设置为true后可对接LinkAI平台的Agent,使用知识库、工作流、联网搜索、`Midjourney` 绘画等能力, 参考 [文档](https://link-ai.tech/platform/link-app/wechat)
|
||||
+ `linkai_api_key`: LinkAI Api Key,可在 [控制台](https://link-ai.tech/console/interface) 创建
|
||||
+ `linkai_app_code`: LinkAI 应用或工作流的code,选填
|
||||
</details>
|
||||
|
||||
注:完整配置项说明可在 [`config.py`](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/config.py) 文件中查看。
|
||||
|
||||
## 三、运行
|
||||
|
||||
### 1.本地运行
|
||||
|
||||
如果是个人计算机 **本地运行**,直接在项目根目录下执行:
|
||||
|
||||
```bash
|
||||
python3 app.py # windows环境下该命令通常为 python app.py
|
||||
```
|
||||
|
||||
运行后默认会启动一个web服务,可以通过访问 `http://localhost:9899/chat` 在网页端对话。如果需要接入其他应用通道只需修改 `config.json` 配置文件中的 `channel_type` 参数,详情参考:[通道说明](#通道说明)。
|
||||
|
||||
向机器人发送 `#help` 消息可以查看可用指令及插件的说明。
|
||||
|
||||
### 2.服务器部署
|
||||
|
||||
在服务器中可使用 `nohup` 命令在后台运行程序:
|
||||
|
||||
```bash
|
||||
nohup python3 app.py & tail -f nohup.out
|
||||
```
|
||||
|
||||
执行后程序运行于服务器后台,可通过 `ctrl+c` 关闭日志,不会影响后台程序的运行。使用 `ps -ef | grep app.py | grep -v grep` 命令可查看运行于后台的进程,如果想要重新启动程序可以先 `kill` 掉对应的进程。 日志关闭后如果想要再次打开只需输入 `tail -f nohup.out`。
|
||||
|
||||
此外,项目的 `scripts` 目录下有一键运行、关闭程序的脚本供使用。 运行后默认channel为web,通过可以通过修改配置文件进行切换。
|
||||
|
||||
|
||||
### 3.Docker部署
|
||||
|
||||
使用docker部署无需下载源码和安装依赖,只需要获取 `docker-compose.yml` 配置文件并启动容器即可。
|
||||
|
||||
> 前提是需要安装好 `docker` 及 `docker-compose`,安装成功后执行 `docker -v` 和 `docker-compose version` (或 `docker compose version`) 可查看到版本号。安装地址为 [docker官网](https://docs.docker.com/engine/install/) 。
|
||||
|
||||
**(1) 下载 docker-compose.yml 文件**
|
||||
|
||||
```bash
|
||||
wget https://cdn.link-ai.tech/code/cow/docker-compose.yml
|
||||
```
|
||||
|
||||
下载完成后打开 `docker-compose.yml` 填写所需配置,例如 `CHANNEL_TYPE`、`OPEN_AI_API_KEY` 和等配置。
|
||||
|
||||
**(2) 启动容器**
|
||||
|
||||
在 `docker-compose.yml` 所在目录下执行以下命令启动容器:
|
||||
|
||||
```bash
|
||||
sudo docker compose up -d # 若docker-compose为 1.X 版本,则执行 `sudo docker-compose up -d`
|
||||
```
|
||||
|
||||
运行命令后,会自动取 [docker hub](https://hub.docker.com/r/zhayujie/chatgpt-on-wechat) 拉取最新release版本的镜像。当执行 `sudo docker ps` 能查看到 NAMES 为 chatgpt-on-wechat 的容器即表示运行成功。最后执行以下命令可查看容器的运行日志:
|
||||
|
||||
```bash
|
||||
sudo docker logs -f chatgpt-on-wechat
|
||||
```
|
||||
|
||||
**(3) 插件使用**
|
||||
|
||||
如果需要在docker容器中修改插件配置,可通过挂载的方式完成,将 [插件配置文件](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/plugins/config.json.template)
|
||||
重命名为 `config.json`,放置于 `docker-compose.yml` 相同目录下,并在 `docker-compose.yml` 中的 `chatgpt-on-wechat` 部分下添加 `volumes` 映射:
|
||||
|
||||
```
|
||||
volumes:
|
||||
- ./config.json:/app/plugins/config.json
|
||||
```
|
||||
**注**:使用docker方式部署的详细教程可以参考:[docker部署CoW项目](https://www.wangpc.cc/ai/docker-deploy-cow/)
|
||||
|
||||
|
||||
## 模型说明
|
||||
|
||||
以下对所有可支持的模型的配置和使用方法进行说明,模型接口实现在项目的 `bot/` 目录下。
|
||||
>部分模型厂商接入有官方sdk和OpenAI兼容两种方式,建议使用OpenAI兼容的方式。
|
||||
|
||||
<details>
|
||||
<summary>OpenAI</summary>
|
||||
|
||||
1. API Key创建:在 [OpenAI平台](https://platform.openai.com/api-keys) 创建API Key
|
||||
|
||||
2. 填写配置
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "gpt-4.1-mini",
|
||||
"open_ai_api_key": "YOUR_API_KEY",
|
||||
"open_ai_api_base": "https://api.openai.com/v1",
|
||||
"bot_type": "chatGPT"
|
||||
}
|
||||
```
|
||||
|
||||
- `model`: 与OpenAI接口的 [model参数](https://platform.openai.com/docs/models) 一致,支持包括 o系列、gpt-4系列、gpt-3.5系列等模型
|
||||
- `open_ai_api_base`: 如果需要接入第三方代理接口,可通过修改该参数进行接入
|
||||
- `bot_type`: 使用OpenAI相关模型时无需填写。当使用第三方代理接口接入Claude等非OpenAI官方模型时,该参数设为 `chatGPT`
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>LinkAI</summary>
|
||||
|
||||
1. API Key创建:在 [LinkAI平台](https://link-ai.tech/console/interface) 创建API Key
|
||||
|
||||
2. 填写配置
|
||||
|
||||
```json
|
||||
{
|
||||
"use_linkai": true,
|
||||
"linkai_api_key": "YOUR API KEY",
|
||||
"linkai_app_code": "YOUR APP CODE"
|
||||
}
|
||||
```
|
||||
|
||||
+ `use_linkai`: 是否使用LinkAI接口,默认关闭,设置为true后可对接LinkAI平台的智能体,使用知识库、工作流、数据库、联网搜索、MCP工具等丰富的Agent能力, 参考 [文档](https://link-ai.tech/platform/link-app/wechat)
|
||||
+ `linkai_api_key`: LinkAI平台的API Key,可在 [控制台](https://link-ai.tech/console/interface) 中创建
|
||||
+ `linkai_app_code`: LinkAI智能体 (应用或工作流) 的code,选填。智能体创建可参考 [说明文档](https://docs.link-ai.tech/platform/quick-start)
|
||||
+ `model`: model字段填写空则直接使用智能体的模型,可在平台中灵活切换,[模型列表](https://link-ai.tech/console/models)中的全部模型均可使用
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>DeepSeek</summary>
|
||||
|
||||
1. API Key创建:在 [DeepSeek平台](https://platform.deepseek.com/api_keys) 创建API Key
|
||||
|
||||
2. 填写配置
|
||||
|
||||
```json
|
||||
{
|
||||
"bot_type": "chatGPT",
|
||||
"model": "deepseek-chat",
|
||||
"open_ai_api_key": "sk-xxxxxxxxxxx",
|
||||
"open_ai_api_base": "https://api.deepseek.com/v1"
|
||||
}
|
||||
```
|
||||
|
||||
- `bot_type`: OpenAI兼容方式
|
||||
- `model`: 可填 `deepseek-chat、deepseek-reasoner`,分别对应的是 V3 和 R1 模型
|
||||
- `open_ai_api_key`: DeepSeek平台的 API Key
|
||||
- `open_ai_api_base`: DeepSeek平台 BASE URL
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>Azure</summary>
|
||||
|
||||
1. API Key创建:在 [DeepSeek平台](https://platform.deepseek.com/api_keys) 创建API Key
|
||||
|
||||
2. 填写配置
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "",
|
||||
"use_azure_chatgpt": true,
|
||||
"open_ai_api_key": "e7ffc5dd84f14521a53f14a40231ea78",
|
||||
"open_ai_api_base": "https://linkai-240917.openai.azure.com/",
|
||||
"azure_deployment_id": "gpt-4.1",
|
||||
"azure_api_version": "2025-01-01-preview"
|
||||
}
|
||||
```
|
||||
|
||||
- `model`: 留空即可
|
||||
- `use_azure_chatgpt`: 设为 true
|
||||
- `open_ai_api_key`: Azure平台的密钥
|
||||
- `open_ai_api_base`: Azure平台的 BASE URL
|
||||
- `azure_deployment_id`: Azure平台部署的模型名称
|
||||
- `azure_api_version`: api版本以及以上参数可以在部署的 [模型配置](https://oai.azure.com/resource/deployments) 界面查看
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>Claude</summary>
|
||||
|
||||
1. API Key创建:在 [Claude控制台](https://console.anthropic.com/settings/keys) 创建API Key
|
||||
|
||||
2. 填写配置
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "claude-sonnet-4-0",
|
||||
"claude_api_key": "YOUR_API_KEY"
|
||||
}
|
||||
```
|
||||
- `model`: 参考 [官方模型ID](https://docs.anthropic.com/en/docs/about-claude/models/overview#model-aliases) ,例如`claude-opus-4-0`、`claude-3-7-sonnet-latest`等
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>通义千问</summary>
|
||||
|
||||
方式一:官方SDK接入,配置如下:
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "qwen-turbo",
|
||||
"dashscope_api_key": "sk-qVxxxxG"
|
||||
}
|
||||
```
|
||||
- `model`: 可填写`qwen-turbo、qwen-plus、qwen-max`
|
||||
- `dashscope_api_key`: 通义千问的 API-KEY,参考 [官方文档](https://bailian.console.aliyun.com/?tab=api#/api) ,在 [控制台](https://bailian.console.aliyun.com/?tab=model#/api-key) 创建
|
||||
|
||||
方式二:OpenAI兼容方式接入,配置如下:
|
||||
```json
|
||||
{
|
||||
"bot_type": "chatGPT",
|
||||
"model": "qwen-turbo",
|
||||
"open_ai_api_base": "https://dashscope.aliyuncs.com/compatible-mode/v1",
|
||||
"open_ai_api_key": "sk-qVxxxxG"
|
||||
}
|
||||
```
|
||||
- `bot_type`: OpenAI兼容方式
|
||||
- `model`: 支持官方所有模型,参考[模型列表](https://help.aliyun.com/zh/model-studio/models?spm=a2c4g.11186623.0.0.78d84823Kth5on#9f8890ce29g5u)
|
||||
- `open_ai_api_base`: 通义千问API的 BASE URL
|
||||
- `open_ai_api_key`: 通义千问的 API-KEY,参考 [官方文档](https://bailian.console.aliyun.com/?tab=api#/api) ,在 [控制台](https://bailian.console.aliyun.com/?tab=model#/api-key) 创建
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>Gemini</summary>
|
||||
|
||||
API Key创建:在 [控制台](https://aistudio.google.com/app/apikey?hl=zh-cn) 创建API Key ,配置如下
|
||||
```json
|
||||
{
|
||||
"model": "gemini-2.5-pro",
|
||||
"gemini_api_key": ""
|
||||
}
|
||||
```
|
||||
- `model`: 参考[官方文档-模型列表](https://ai.google.dev/gemini-api/docs/models?hl=zh-cn)
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>Moonshot</summary>
|
||||
|
||||
方式一:官方接入,配置如下:
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "moonshot-v1-8k",
|
||||
"moonshot_api_key": "moonshot-v1-8k"
|
||||
}
|
||||
```
|
||||
- `model`: 可填写`moonshot-v1-8k、 moonshot-v1-32k、 moonshot-v1-128k`
|
||||
- `moonshot_api_key`: Moonshot的API-KEY,在 [控制台](https://platform.moonshot.cn/console/api-keys) 创建
|
||||
|
||||
方式二:OpenAI兼容方式接入,配置如下:
|
||||
```json
|
||||
{
|
||||
"bot_type": "chatGPT",
|
||||
"model": "moonshot-v1-8k",
|
||||
"open_ai_api_base": "https://api.moonshot.cn/v1",
|
||||
"open_ai_api_key": ""
|
||||
}
|
||||
```
|
||||
- `bot_type`: OpenAI兼容方式
|
||||
- `model`: 可填写`moonshot-v1-8k、 moonshot-v1-32k、 moonshot-v1-128k`
|
||||
- `open_ai_api_base`: Moonshot的 BASE URL
|
||||
- `open_ai_api_key`: Moonshot的 API-KEY,在 [控制台](https://platform.moonshot.cn/console/api-keys) 创建
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>百度文心</summary>
|
||||
方式一:官方SDK接入,配置如下:
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "wenxin",
|
||||
"baidu_wenxin_api_key": "IajztZ0bDxgnP9bEykU7lBer",
|
||||
"baidu_wenxin_secret_key": "EDPZn6L24uAS9d8RWFfotK47dPvkjD6G"
|
||||
}
|
||||
```
|
||||
- `model`: 可填 `wenxin`和`wenxin-4`,对应模型为 文心-3.5 和 文心-4.0
|
||||
- `baidu_wenxin_api_key`:参考 [千帆平台-access_token鉴权](https://cloud.baidu.com/doc/WENXINWORKSHOP/s/dlv4pct3s) 文档获取 API Key
|
||||
- `baidu_wenxin_secret_key`:参考 [千帆平台-access_token鉴权](https://cloud.baidu.com/doc/WENXINWORKSHOP/s/dlv4pct3s) 文档获取 Secret Key
|
||||
|
||||
方式二:OpenAI兼容方式接入,配置如下:
|
||||
```json
|
||||
{
|
||||
"bot_type": "chatGPT",
|
||||
"model": "qwen-turbo",
|
||||
"open_ai_api_base": "https://qianfan.baidubce.com/v2",
|
||||
"open_ai_api_key": "bce-v3/ALTxxxxxxd2b"
|
||||
}
|
||||
```
|
||||
- `bot_type`: OpenAI兼容方式
|
||||
- `model`: 支持官方所有模型,参考[模型列表](https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Wm9cvy6rl)
|
||||
- `open_ai_api_base`: 百度文心API的 BASE URL
|
||||
- `open_ai_api_key`: 百度文心的 API-KEY,参考 [官方文档](https://cloud.baidu.com/doc/qianfan-api/s/ym9chdsy5) ,在 [控制台](https://console.bce.baidu.com/iam/#/iam/apikey/list) 创建API Key
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>讯飞星火</summary>
|
||||
|
||||
方式一:官方接入,配置如下:
|
||||
参考 [官方文档-快速指引](https://www.xfyun.cn/doc/platform/quickguide.html#%E7%AC%AC%E4%BA%8C%E6%AD%A5-%E5%88%9B%E5%BB%BA%E6%82%A8%E7%9A%84%E7%AC%AC%E4%B8%80%E4%B8%AA%E5%BA%94%E7%94%A8-%E5%BC%80%E5%A7%8B%E4%BD%BF%E7%94%A8%E6%9C%8D%E5%8A%A1) 获取 `APPID、 APISecret、 APIKey` 三个参数
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "xunfei",
|
||||
"xunfei_app_id": "",
|
||||
"xunfei_api_key": "",
|
||||
"xunfei_api_secret": "",
|
||||
"xunfei_domain": "4.0Ultra",
|
||||
"xunfei_spark_url": "wss://spark-api.xf-yun.com/v4.0/chat"
|
||||
}
|
||||
```
|
||||
- `model`: 填 `xunfei`
|
||||
- `xunfei_domain`: 可填写 `4.0Ultra、 generalv3.5、 max-32k、 generalv3、 pro-128k、 lite`
|
||||
- `xunfei_spark_url`: 填写参考 [官方文档-请求地址](https://www.xfyun.cn/doc/spark/Web.html#_1-1-%E8%AF%B7%E6%B1%82%E5%9C%B0%E5%9D%80) 的说明
|
||||
|
||||
方式二:OpenAI兼容方式接入,配置如下:
|
||||
```json
|
||||
{
|
||||
"bot_type": "chatGPT",
|
||||
"model": "4.0Ultra",
|
||||
"open_ai_api_base": "https://spark-api-open.xf-yun.com/v1",
|
||||
"open_ai_api_key": ""
|
||||
}
|
||||
```
|
||||
- `bot_type`: OpenAI兼容方式
|
||||
- `model`: 可填写 `4.0Ultra、 generalv3.5、 max-32k、 generalv3、 pro-128k、 lite`
|
||||
- `open_ai_api_base`: 讯飞星火平台的 BASE URL
|
||||
- `open_ai_api_key`: 讯飞星火平台的[APIPassword](https://console.xfyun.cn/services/bm3) ,因模型而已
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>智谱AI</summary>
|
||||
|
||||
方式一:官方接入,配置如下:
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "glm-4-plus",
|
||||
"zhipu_ai_api_key": ""
|
||||
}
|
||||
```
|
||||
- `model`: 可填 `glm-4-plus、glm-4-air-250414、glm-4-airx、glm-4-long 、glm-4-flashx 、glm-4-flash-250414`, 参考 [glm-4系列模型编码](https://bigmodel.cn/dev/api/normal-model/glm-4)
|
||||
- `zhipu_ai_api_key`: 智谱AI平台的 API KEY,在 [控制台](https://www.bigmodel.cn/usercenter/proj-mgmt/apikeys) 创建
|
||||
|
||||
方式二:OpenAI兼容方式接入,配置如下:
|
||||
```json
|
||||
{
|
||||
"bot_type": "chatGPT",
|
||||
"model": "glm-4-plus",
|
||||
"open_ai_api_base": "https://open.bigmodel.cn/api/paas/v4",
|
||||
"open_ai_api_key": ""
|
||||
}
|
||||
```
|
||||
- `bot_type`: OpenAI兼容方式
|
||||
- `model`: 可填 `glm-4-plus、glm-4-air-250414、glm-4-airx、glm-4-long 、glm-4-flashx 、glm-4-flash-250414`, 参考 [glm-4系列模型编码](https://bigmodel.cn/dev/api/normal-model/glm-4)
|
||||
- `open_ai_api_base`: 智谱AI平台的 BASE URL
|
||||
- `open_ai_api_key`: 智谱AI平台的 API KEY,在 [控制台](https://www.bigmodel.cn/usercenter/proj-mgmt/apikeys) 创建
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>MiniMax</summary>
|
||||
|
||||
方式一:官方接入,配置如下:
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "abab6.5-chat",
|
||||
"Minimax_api_key": "",
|
||||
"Minimax_group_id": ""
|
||||
}
|
||||
```
|
||||
- `model`: 可填写`abab6.5-chat`
|
||||
- `Minimax_api_key`:MiniMax平台的API-KEY,在 [控制台](https://platform.minimaxi.com/user-center/basic-information/interface-key) 创建
|
||||
- `Minimax_group_id`: 在 [账户信息](https://platform.minimaxi.com/user-center/basic-information) 右上角获取
|
||||
|
||||
方式二:OpenAI兼容方式接入,配置如下:
|
||||
```json
|
||||
{
|
||||
"bot_type": "chatGPT",
|
||||
"model": "MiniMax-M1",
|
||||
"open_ai_api_base": "https://api.minimaxi.com/v1",
|
||||
"open_ai_api_key": ""
|
||||
}
|
||||
```
|
||||
- `bot_type`: OpenAI兼容方式
|
||||
- `model`: 可填`MiniMax-M1、MiniMax-Text-01`,参考[API文档](https://platform.minimaxi.com/document/%E5%AF%B9%E8%AF%9D?key=66701d281d57f38758d581d0#QklxsNSbaf6kM4j6wjO5eEek)
|
||||
- `open_ai_api_base`: MiniMax平台API的 BASE URL
|
||||
- `open_ai_api_key`: MiniMax平台的API-KEY,在 [控制台](https://platform.minimaxi.com/user-center/basic-information/interface-key) 创建
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>ModelScope</summary>
|
||||
|
||||
```json
|
||||
{
|
||||
"bot_type": "modelscope",
|
||||
"model": "Qwen/QwQ-32B",
|
||||
"modelscope_api_key": "your_api_key",
|
||||
"modelscope_base_url": "https://api-inference.modelscope.cn/v1/chat/completions",
|
||||
"text_to_image": "MusePublic/489_ckpt_FLUX_1"
|
||||
}
|
||||
```
|
||||
|
||||
- `bot_type`: modelscope接口格式
|
||||
- `model`: 参考[模型列表](https://www.modelscope.cn/models?filter=inference_type&page=1)
|
||||
- `modelscope_api_key`: 参考 [官方文档-访问令牌](https://modelscope.cn/docs/accounts/token) ,在 [控制台](https://modelscope.cn/my/myaccesstoken)
|
||||
- `modelscope_base_url`: modelscope平台的 BASE URL
|
||||
- `text_to_image`: 图像生成模型,参考[模型列表](https://www.modelscope.cn/models?filter=inference_type&page=1)
|
||||
</details>
|
||||
|
||||
|
||||
## 通道说明
|
||||
|
||||
以下对可接入通道的配置方式进行说明,应用通道代码在项目的 `channel/` 目录下。
|
||||
|
||||
<details>
|
||||
<summary>Web</summary>
|
||||
|
||||
项目启动后默认运行web通道,配置如下:
|
||||
|
||||
```json
|
||||
{
|
||||
"channel_type": "web",
|
||||
"web_port": 9899
|
||||
}
|
||||
```
|
||||
- `web_port`: 默认为 9899,可按需更改,需要服务器防火墙和安全组放行该端口
|
||||
- 如本地运行,启动后请访问 `http://localhost:port/chat` ;如服务器运行,请访问 `http://ip:port/chat`
|
||||
> 注:请将上述 url 中的 ip 或者 port 替换为实际的值
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>Terminal</summary>
|
||||
|
||||
修改 `config.json` 中的 `channel_type` 字段:
|
||||
|
||||
```json
|
||||
{
|
||||
"channel_type": "terminal"
|
||||
}
|
||||
```
|
||||
|
||||
运行后可在终端与机器人进行对话。
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>微信公众号</summary>
|
||||
|
||||
本项目支持订阅号和服务号两种公众号,通过服务号(`wechatmp_service`)体验更佳。将下列配置加入 `config.json`:
|
||||
|
||||
```json
|
||||
{
|
||||
"channel_type": "wechatmp",
|
||||
"wechatmp_token": "TOKEN",
|
||||
"wechatmp_port": 80,
|
||||
"wechatmp_app_id": "APPID",
|
||||
"wechatmp_app_secret": "APPSECRET",
|
||||
"wechatmp_aes_key": ""
|
||||
}
|
||||
```
|
||||
- `channel_type`: 个人订阅号为`wechatmp`,企业服务号为`wechatmp_service`
|
||||
|
||||
详细步骤和参数说明参考 [微信公众号接入](https://docs.link-ai.tech/cow/multi-platform/wechat-mp)
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>企业微信应用</summary>
|
||||
|
||||
企业微信自建应用接入需在后台创建应用并启用消息回调,配置示例:
|
||||
|
||||
```json
|
||||
{
|
||||
"channel_type": "wechatcom_app",
|
||||
"wechatcom_corp_id": "CORPID",
|
||||
"wechatcomapp_token": "TOKEN",
|
||||
"wechatcomapp_port": 9898,
|
||||
"wechatcomapp_secret": "SECRET",
|
||||
"wechatcomapp_agent_id": "AGENTID",
|
||||
"wechatcomapp_aes_key": "AESKEY"
|
||||
}
|
||||
```
|
||||
详细步骤和参数说明参考 [企微自建应用接入](https://docs.link-ai.tech/cow/multi-platform/wechat-com)
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>钉钉</summary>
|
||||
|
||||
钉钉需要在开放平台创建智能机器人应用,将以下配置填入 `config.json`:
|
||||
|
||||
```json
|
||||
{
|
||||
"channel_type": "dingtalk",
|
||||
"dingtalk_client_id": "CLIENT_ID",
|
||||
"dingtalk_client_secret": "CLIENT_SECRET"
|
||||
}
|
||||
```
|
||||
详细步骤和参数说明参考 [钉钉接入](https://docs.link-ai.tech/cow/multi-platform/dingtalk)
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>飞书</summary>
|
||||
|
||||
通过自建应用接入AI相关能力到飞书应用中,默认已是飞书的企业用户,且具有企业管理权限,将以下配置填入 `config.json`::
|
||||
|
||||
```json
|
||||
{
|
||||
"channel_type": "feishu",
|
||||
"feishu_app_id": "APP_ID",
|
||||
"feishu_app_secret": "APP_SECRET",
|
||||
"feishu_token": "VERIFICATION_TOKEN",
|
||||
"feishu_port": 80
|
||||
}
|
||||
```
|
||||
详细步骤和参数说明参考 [飞书接入](https://docs.link-ai.tech/cow/multi-platform/feishu)
|
||||
</details>
|
||||
|
||||
<br/>
|
||||
|
||||
# 🔗 相关项目
|
||||
## 🤖 Models
|
||||
|
||||
- [bot-on-anything](https://github.com/zhayujie/bot-on-anything):轻量和高可扩展的大模型应用框架,支持接入Slack, Telegram, Discord, Gmail等海外平台,可作为本项目的补充使用。
|
||||
- [AgentMesh](https://github.com/MinimalFuture/AgentMesh):开源的多智能体(Multi-Agent)框架,可以通过多智能体团队的协同来解决复杂问题。本项目基于该框架实现了[Agent插件](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/plugins/agent/README.md),可访问终端、浏览器、文件系统、搜索引擎 等各类工具,并实现了多智能体协同。
|
||||
CowAgent supports all mainstream LLM providers. **Chat, vision, image generation, ASR/TTS, and embeddings** can each be routed to a different vendor. Providers are configured directly in the Web console — no manual file editing required.
|
||||
|
||||
| Provider | Featured Models | Chat | Vision | Image Gen | ASR | TTS | Embedding |
|
||||
| --- | --- | :-: | :-: | :-: | :-: | :-: | :-: |
|
||||
| [Claude](https://docs.cowagent.ai/models/claude) | claude-opus-4-8 | ✅ | ✅ | | | | |
|
||||
| [OpenAI](https://docs.cowagent.ai/models/openai) | gpt-5.5, o-series | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
|
||||
| [Gemini](https://docs.cowagent.ai/models/gemini) | gemini-3.5-flash | ✅ | ✅ | ✅ | | | |
|
||||
| [DeepSeek](https://docs.cowagent.ai/models/deepseek) | deepseek-v4-flash / pro | ✅ | | | | | |
|
||||
| [Qwen](https://docs.cowagent.ai/models/qwen) | qwen3.7-plus | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
|
||||
| [GLM](https://docs.cowagent.ai/models/glm) | glm-5.1, glm-5v-turbo | ✅ | ✅ | | ✅ | | ✅ |
|
||||
| [Doubao](https://docs.cowagent.ai/models/doubao) | doubao-seed-2.0 series | ✅ | ✅ | ✅ | | | ✅ |
|
||||
| [Kimi](https://docs.cowagent.ai/models/kimi) | kimi-k2.6 | ✅ | ✅ | | | | |
|
||||
| [MiniMax](https://docs.cowagent.ai/models/minimax) | MiniMax-M3 | ✅ | ✅ | ✅ | | ✅ | |
|
||||
| [ERNIE](https://docs.cowagent.ai/models/qianfan) | ernie-5.1 | ✅ | ✅ | | | | |
|
||||
| [MiMo](https://docs.cowagent.ai/models/mimo) | mimo-v2.5 / pro | ✅ | ✅ | | | ✅ | |
|
||||
| [LinkAI](https://docs.cowagent.ai/models/linkai) | One key for 100+ models | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
|
||||
| [Custom](https://docs.cowagent.ai/models/custom) | Local models / third-party proxy | ✅ | | | | | |
|
||||
|
||||
> For details on each provider, see the [Models overview](https://docs.cowagent.ai/models/index).
|
||||
|
||||
# 🔎 常见问题
|
||||
<br/>
|
||||
|
||||
FAQs: <https://github.com/zhayujie/chatgpt-on-wechat/wiki/FAQs>
|
||||
## 💬 Channels
|
||||
|
||||
或直接在线咨询 [项目小助手](https://link-ai.tech/app/Kv2fXJcH) (知识库持续完善中,回复供参考)
|
||||
A single Agent instance can serve multiple channels in parallel. Most channels can be onboarded right from the Web console.
|
||||
|
||||
# 🛠️ 开发
|
||||
| Channel | Text | Image | File | Voice | Group |
|
||||
| --- | :-: | :-: | :-: | :-: | :-: |
|
||||
| [Web Console](https://docs.cowagent.ai/channels/web) (default) | ✅ | ✅ | ✅ | ✅ | |
|
||||
| [Telegram](https://docs.cowagent.ai/channels/telegram) | ✅ | ✅ | ✅ | ✅ | ✅ |
|
||||
| [Slack](https://docs.cowagent.ai/channels/slack) | ✅ | ✅ | ✅ | | ✅ |
|
||||
| [Discord](https://docs.cowagent.ai/channels/discord) | ✅ | ✅ | ✅ | | ✅ |
|
||||
| [WeChat](https://docs.cowagent.ai/channels/weixin) | ✅ | ✅ | ✅ | ✅ | |
|
||||
| [Feishu / Lark](https://docs.cowagent.ai/channels/feishu) | ✅ | ✅ | ✅ | ✅ | ✅ |
|
||||
| [DingTalk](https://docs.cowagent.ai/channels/dingtalk) | ✅ | ✅ | ✅ | ✅ | ✅ |
|
||||
| [WeCom Bot](https://docs.cowagent.ai/channels/wecom-bot) | ✅ | ✅ | ✅ | ✅ | ✅ |
|
||||
| [QQ](https://docs.cowagent.ai/channels/qq) | ✅ | ✅ | ✅ | | ✅ |
|
||||
| [WeCom App](https://docs.cowagent.ai/channels/wecom) | ✅ | ✅ | ✅ | ✅ | |
|
||||
| [WeChat Customer Service](https://docs.cowagent.ai/channels/wechat-kf) | ✅ | ✅ | ✅ | ✅ | |
|
||||
| [WeChat Official Account](https://docs.cowagent.ai/channels/wechatmp) | ✅ | ✅ | | ✅ | |
|
||||
|
||||
欢迎接入更多应用通道,参考 [Terminal代码](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/channel/terminal/terminal_channel.py) 新增自定义通道,实现接收和发送消息逻辑即可完成接入。 同时欢迎贡献新的插件,参考 [插件开发文档](https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins)。
|
||||
> See the [Channels overview](https://docs.cowagent.ai/channels/index) for setup details.
|
||||
|
||||
# ✉ 联系
|
||||
<img src="https://cdn.jsdelivr.net/gh/zhayujie/cowagent-assets@main/screenshots/en/web-console-chat.png" alt="CowAgent Web Console" width="800"/>
|
||||
|
||||
欢迎提交PR、Issues进行反馈,以及通过 🌟Star 支持并关注项目更新。项目运行遇到问题可以查看 [常见问题列表](https://github.com/zhayujie/chatgpt-on-wechat/wiki/FAQs) ,以及前往 [Issues](https://github.com/zhayujie/chatgpt-on-wechat/issues) 中搜索。个人开发者可加入开源交流群参与更多讨论,企业用户可联系[产品客服](https://cdn.link-ai.tech/portal/linkai-customer-service.png)咨询。
|
||||
*The Web console is the default channel and the unified entry point to configure models, channels, skills, memory, and more.*
|
||||
|
||||
# 🌟 贡献者
|
||||
<br/>
|
||||
|
||||

|
||||
## 🧠 Memory & Knowledge Base
|
||||
|
||||
**Long-term memory** uses a three-tier architecture: conversation context (short-term) → daily memory (mid-term) → MEMORY.md (long-term). A nightly **Deep Dream** pass distills scattered memories into refined long-term entries and a narrative journal. See [Long-term Memory](https://docs.cowagent.ai/memory/index) · [Deep Dream](https://docs.cowagent.ai/memory/deep-dream).
|
||||
|
||||
**Personal knowledge base** complements the time-ordered memory by organizing structured knowledge **by topic**. The Agent automatically curates valuable information from conversations, maintains cross-references and indexes, and the Web console offers an interactive knowledge-graph view. See [Personal Knowledge Base](https://docs.cowagent.ai/knowledge/index).
|
||||
|
||||
<table>
|
||||
<tr>
|
||||
<td width="50%">
|
||||
<img src="https://cdn.jsdelivr.net/gh/zhayujie/cowagent-assets@main/screenshots/en/web-console-memory.png" alt="Long-term Memory" />
|
||||
<p align="center"><em>Long-term Memory · Three-tier architecture + Deep Dream</em></p>
|
||||
</td>
|
||||
<td width="50%">
|
||||
<img src="https://cdn.jsdelivr.net/gh/zhayujie/cowagent-assets@main/screenshots/en/web-console-knowledge.png" alt="Personal Knowledge Base" />
|
||||
<p align="center"><em>Knowledge Base · Auto-curated Markdown wiki</em></p>
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
<br/>
|
||||
|
||||
## 🔧 Tools & Skills
|
||||
|
||||
**Tools** are atomic capabilities the Agent uses to interact with system resources. **Skills** are higher-level workflows defined by a manifest file that compose multiple tools to accomplish complex tasks.
|
||||
|
||||
### Tool System
|
||||
|
||||
**Built-in tools** cover file I/O (`read` / `write` / `edit` / `ls`), terminal (`bash`), file sending (`send`), memory retrieval (`memory`), environment variables (`env_config`), web fetching (`web_fetch`), scheduling (`scheduler`), web search (`web_search`), vision (`vision`), and browser automation (`browser`).
|
||||
|
||||
**MCP protocol** integrates the open ecosystem of [Model Context Protocol](https://modelcontextprotocol.io) servers. A single `mcp.json` is enough — supports stdio / SSE transports, hot reload, and zero-code integration.
|
||||
|
||||
Learn more: [Tools overview](https://docs.cowagent.ai/tools/index) · [MCP integration](https://docs.cowagent.ai/tools/mcp).
|
||||
|
||||
### Skills System
|
||||
|
||||
- **[Skill Hub](https://skills.cowagent.ai/)** — open skill marketplace: browse, search, install in one click
|
||||
- **GitHub / ClawHub / URL and more** — install skills from any source
|
||||
- **Conversational authoring** — generate custom skills through dialogue with `skill-creator`; turn any workflow or third-party API into a reusable skill
|
||||
|
||||
```bash
|
||||
/skill list # list installed skills
|
||||
/skill search <keyword> # search the marketplace
|
||||
/skill install <name> # one-click install
|
||||
```
|
||||
|
||||
Learn more: [Skills overview](https://docs.cowagent.ai/skills/index) · [Creating Skills](https://docs.cowagent.ai/skills/create).
|
||||
|
||||
<br/>
|
||||
|
||||
## 🏷 Changelog
|
||||
|
||||
> **2026.06.01:** [v2.1.0](https://github.com/zhayujie/CowAgent/releases/tag/2.1.0) — Internationalization, new channels (Telegram, Discord, Slack, WeChat Customer Service), CLI interaction upgrades, streamlined one-line install, MCP Streamable HTTP support, new models (claude-opus-4-8, MiMo).
|
||||
|
||||
> **2026.05.22:** [v2.0.9](https://github.com/zhayujie/CowAgent/releases/tag/2.0.9) — Model management, MCP protocol support, persistent browser sessions, new models (gpt-5.5, gemini-3.5-flash, qwen3.7-max), deployment hardening.
|
||||
|
||||
> **2026.05.06:** [v2.0.8](https://github.com/zhayujie/CowAgent/releases/tag/2.0.8) — Feishu channel overhaul (voice, streaming, QR onboarding), DeepSeek V4 and Baidu Qianfan support, scheduler tool upgrades.
|
||||
|
||||
> **2026.04.22:** [v2.0.7](https://github.com/zhayujie/CowAgent/releases/tag/2.0.7) — Built-in image generation (GPT Image 2, Nano Banana), new models (Kimi K2.6, Claude Opus 4.7, GLM 5.1), memory and knowledge enhancements.
|
||||
|
||||
> **2026.04.14:** [v2.0.6](https://github.com/zhayujie/CowAgent/releases/tag/2.0.6) — Knowledge base, Deep Dream memory distillation, smart context compression, multi-session Web console.
|
||||
|
||||
> **2026.04.01:** [v2.0.5](https://github.com/zhayujie/CowAgent/releases/tag/2.0.5) — Cow CLI, Skill Hub open source, browser tool, WeCom Bot QR onboarding.
|
||||
|
||||
> **2026.02.03:** [v2.0.0](https://github.com/zhayujie/CowAgent/releases/tag/2.0.0) — Major upgrade to a super Agent assistant with multi-step task planning, long-term memory, and the Skills framework.
|
||||
|
||||
Full history: [Release Notes](https://docs.cowagent.ai/releases/overview)
|
||||
|
||||
<br/>
|
||||
|
||||
## 🤝 Community & Support
|
||||
|
||||
[File an issue](https://github.com/zhayujie/CowAgent/issues) on GitHub, or scan the QR code below to join our WeChat community:
|
||||
|
||||
<img width="130" src="https://img-1317903499.cos.ap-guangzhou.myqcloud.com/docs/open-community.png">
|
||||
|
||||
<br/>
|
||||
|
||||
## 🔗 Related Projects
|
||||
|
||||
- **[Cow Skill Hub](https://github.com/zhayujie/cow-skill-hub)** — open skill marketplace for AI Agents; works with CowAgent, OpenClaw, Claude Code, and more
|
||||
- **[bot-on-anything](https://github.com/zhayujie/bot-on-anything)** — lightweight LLM application framework with integrations for Slack, Telegram, Discord, Gmail, and more
|
||||
- **[AgentMesh](https://github.com/MinimalFuture/AgentMesh)** — open-source multi-agent framework for solving complex problems through team collaboration
|
||||
|
||||
<br/>
|
||||
|
||||
## 🏢 Enterprise Services
|
||||
|
||||
[**LinkAI**](https://link-ai.tech/) is an all-in-one AI Agent platform for enterprises and developers, offering managed hosting and enterprise-grade support for CowAgent:
|
||||
|
||||
- **🚀 Zero-deployment hosted runtime** — spin up a [CowAgent online assistant](https://link-ai.tech/cowagent/create) in under a minute, no server required
|
||||
- **🧠 Agent infrastructure** — unified access to LLMs, knowledge bases, databases, skills, and workflows; plug-and-play building blocks that extend what CowAgent can do
|
||||
- **🏢 Team & enterprise features** — workspaces, role-based access, audit logs, and private deployment for production use cases
|
||||
|
||||
For enterprise inquiries: sales@simple-future.tech or [scan the QR code](https://cdn.link-ai.tech/consultant.jpg) to reach our team on WeChat.
|
||||
|
||||
<br/>
|
||||
|
||||
## 🛠️ Development & Contributing
|
||||
|
||||
All kinds of contributions are welcome — new features, bug fixes, performance improvements, docs, or sharing your own skills on the [Skill Hub](https://skills.cowagent.ai/submit). See [CONTRIBUTING.md](/CONTRIBUTING.md) to get started, then open an Issue to discuss or send a PR directly.
|
||||
|
||||
⭐ Star the project to show your support, and Watch → Custom → Releases to get notified of new versions. PRs and Issues are always welcome.
|
||||
|
||||
## 🌟 Contributors
|
||||
|
||||

|
||||
|
||||
<br/>
|
||||
|
||||
## ⚠️ Disclaimer
|
||||
|
||||
1. This project is licensed under the [MIT License](/LICENSE) and is intended for technical research and learning. You are responsible for complying with applicable laws and regulations in your jurisdiction; the maintainers assume no liability for any consequences arising from use of this project.
|
||||
2. **Cost & safety:** Agent mode consumes substantially more tokens than regular chat — pick models that balance quality and cost. The Agent has access to your local operating system, so only deploy it in trusted environments.
|
||||
3. CowAgent is a pure open-source project and does not participate in, authorize, or issue any cryptocurrency.
|
||||
|
||||
<br/>
|
||||
|
||||
## 📌 Project Renaming Notice
|
||||
|
||||
This project was previously named `chatgpt-on-wechat` and is now officially **CowAgent**. The old GitHub URL redirects automatically; existing users may optionally run `git remote set-url origin https://github.com/zhayujie/CowAgent.git` to update the local remote.
|
||||
|
||||
3
agent/chat/__init__.py
Normal file
3
agent/chat/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from agent.chat.service import ChatService
|
||||
|
||||
__all__ = ["ChatService"]
|
||||
290
agent/chat/service.py
Normal file
290
agent/chat/service.py
Normal file
@@ -0,0 +1,290 @@
|
||||
"""
|
||||
ChatService - Wraps the Agent stream execution to produce CHAT protocol chunks.
|
||||
|
||||
Translates agent events (message_update, message_end, tool_execution_end, etc.)
|
||||
into the CHAT socket protocol format (content chunks with segment_id, tool_calls chunks).
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import Callable, Optional
|
||||
|
||||
from common.log import logger
|
||||
|
||||
|
||||
class ChatService:
|
||||
"""
|
||||
High-level service that runs an Agent for a given query and streams
|
||||
the results as CHAT protocol chunks via a callback.
|
||||
|
||||
Usage:
|
||||
svc = ChatService(agent_bridge)
|
||||
svc.run(query, session_id, send_chunk_fn)
|
||||
"""
|
||||
|
||||
def __init__(self, agent_bridge):
|
||||
"""
|
||||
:param agent_bridge: AgentBridge instance (manages agent lifecycle)
|
||||
"""
|
||||
self.agent_bridge = agent_bridge
|
||||
|
||||
def run(self, query: str, session_id: str, send_chunk_fn: Callable[[dict], None],
|
||||
channel_type: str = ""):
|
||||
"""
|
||||
Run the agent for *query* and stream results back via *send_chunk_fn*.
|
||||
|
||||
The method blocks until the agent finishes. After it returns the SDK
|
||||
will automatically send the final (streaming=false) message.
|
||||
|
||||
:param query: user query text
|
||||
:param session_id: session identifier for agent isolation
|
||||
:param send_chunk_fn: callable(chunk_data: dict) to send a streaming chunk
|
||||
:param channel_type: source channel (e.g. "web", "feishu") for persistence
|
||||
"""
|
||||
agent = self.agent_bridge.get_agent(session_id=session_id)
|
||||
if agent is None:
|
||||
raise RuntimeError("Failed to initialise agent for the session")
|
||||
|
||||
# Pass context metadata to model for downstream API requests
|
||||
if hasattr(agent, 'model'):
|
||||
agent.model.channel_type = channel_type or ""
|
||||
agent.model.session_id = session_id or ""
|
||||
|
||||
# State shared between the event callback and this method
|
||||
state = _StreamState()
|
||||
|
||||
def on_event(event: dict):
|
||||
"""Translate agent events into CHAT protocol chunks."""
|
||||
event_type = event.get("type")
|
||||
data = event.get("data", {})
|
||||
|
||||
if event_type == "reasoning_update":
|
||||
delta = data.get("delta", "")
|
||||
if delta:
|
||||
send_chunk_fn({
|
||||
"chunk_type": "reasoning",
|
||||
"delta": delta,
|
||||
"segment_id": state.segment_id,
|
||||
})
|
||||
|
||||
elif event_type == "message_update":
|
||||
# Incremental text delta
|
||||
delta = data.get("delta", "")
|
||||
if delta:
|
||||
send_chunk_fn({
|
||||
"chunk_type": "content",
|
||||
"delta": delta,
|
||||
"segment_id": state.segment_id,
|
||||
})
|
||||
|
||||
elif event_type == "message_end":
|
||||
# A content segment finished.
|
||||
tool_calls = data.get("tool_calls", [])
|
||||
if tool_calls:
|
||||
# After tool_calls are executed the next content will be
|
||||
# a new segment; collect tool results until turn_end.
|
||||
state.pending_tool_results = []
|
||||
|
||||
elif event_type == "file_to_send":
|
||||
url = data.get("url") or ""
|
||||
if url:
|
||||
fname = data.get("file_name") or "file"
|
||||
ft = data.get("file_type") or "file"
|
||||
if ft == "image":
|
||||
link = f""
|
||||
else:
|
||||
link = f"[{fname}]({url})"
|
||||
send_chunk_fn({
|
||||
"chunk_type": "content",
|
||||
"delta": "\n\n" + link + "\n\n",
|
||||
"segment_id": state.segment_id,
|
||||
})
|
||||
# Remove url so the model won't repeat it in its reply
|
||||
data.pop("url", None)
|
||||
|
||||
elif event_type == "tool_execution_start":
|
||||
# Notify the client that a tool is about to run (with its input args)
|
||||
tool_name = data.get("tool_name", "")
|
||||
arguments = data.get("arguments", {})
|
||||
# Cache arguments keyed by tool_call_id so tool_execution_end can include them
|
||||
tool_call_id = data.get("tool_call_id", tool_name)
|
||||
state.pending_tool_arguments[tool_call_id] = arguments
|
||||
send_chunk_fn({
|
||||
"chunk_type": "tool_start",
|
||||
"tool": tool_name,
|
||||
"arguments": arguments,
|
||||
})
|
||||
|
||||
elif event_type == "tool_execution_end":
|
||||
tool_name = data.get("tool_name", "")
|
||||
tool_call_id = data.get("tool_call_id", tool_name)
|
||||
# Retrieve cached arguments from the matching tool_execution_start event
|
||||
arguments = state.pending_tool_arguments.pop(tool_call_id, data.get("arguments", {}))
|
||||
result = data.get("result", "")
|
||||
status = data.get("status", "unknown")
|
||||
execution_time = data.get("execution_time", 0)
|
||||
elapsed_str = f"{execution_time:.2f}s"
|
||||
|
||||
# Serialise result to string if needed
|
||||
if not isinstance(result, str):
|
||||
import json
|
||||
try:
|
||||
result = json.dumps(result, ensure_ascii=False)
|
||||
except Exception:
|
||||
result = str(result)
|
||||
|
||||
tool_info = {
|
||||
"name": tool_name,
|
||||
"arguments": arguments,
|
||||
"result": result,
|
||||
"status": status,
|
||||
"elapsed": elapsed_str,
|
||||
}
|
||||
|
||||
if state.pending_tool_results is not None:
|
||||
state.pending_tool_results.append(tool_info)
|
||||
|
||||
elif event_type == "turn_end":
|
||||
has_tool_calls = data.get("has_tool_calls", False)
|
||||
if has_tool_calls and state.pending_tool_results:
|
||||
# Flush collected tool results as a single tool_calls chunk
|
||||
send_chunk_fn({
|
||||
"chunk_type": "tool_calls",
|
||||
"tool_calls": state.pending_tool_results,
|
||||
})
|
||||
state.pending_tool_results = None
|
||||
# Next content belongs to a new segment
|
||||
state.segment_id += 1
|
||||
|
||||
# Run the agent with our event callback ---------------------------
|
||||
logger.info(f"[ChatService] Starting agent run: session={session_id}, query={query[:80]}")
|
||||
|
||||
from config import conf
|
||||
max_context_turns = conf().get("agent_max_context_turns", 20)
|
||||
|
||||
# Get full system prompt with skills
|
||||
full_system_prompt = agent.get_full_system_prompt()
|
||||
|
||||
# Create a copy of messages for this execution
|
||||
with agent.messages_lock:
|
||||
messages_copy = agent.messages.copy()
|
||||
original_length = len(agent.messages)
|
||||
|
||||
from agent.protocol.agent_stream import AgentStreamExecutor
|
||||
|
||||
executor = AgentStreamExecutor(
|
||||
agent=agent,
|
||||
model=agent.model,
|
||||
system_prompt=full_system_prompt,
|
||||
tools=agent.tools,
|
||||
max_turns=agent.max_steps,
|
||||
on_event=on_event,
|
||||
messages=messages_copy,
|
||||
max_context_turns=max_context_turns,
|
||||
)
|
||||
|
||||
try:
|
||||
response = executor.run_stream(query)
|
||||
except Exception:
|
||||
# If executor cleared messages (context overflow), sync back
|
||||
if len(executor.messages) == 0:
|
||||
with agent.messages_lock:
|
||||
agent.messages.clear()
|
||||
logger.info("[ChatService] Cleared agent message history after executor recovery")
|
||||
raise
|
||||
|
||||
# Sync executor messages back to agent (thread-safe).
|
||||
# The executor may have trimmed context, making its list shorter than
|
||||
# original_length. In that case we must replace entirely — just
|
||||
# appending would leave stale pre-trim messages in agent.messages
|
||||
# and cause the same trim to fire on every subsequent request.
|
||||
with agent.messages_lock:
|
||||
trimmed = len(executor.messages) < original_length
|
||||
if trimmed:
|
||||
# Context was trimmed: the executor appended the new user
|
||||
# query *before* trimming, so the new messages (user +
|
||||
# assistant + tools) sit at the tail of the trimmed list.
|
||||
# We cannot simply slice at original_length (it exceeds the
|
||||
# list length). Instead, count how many messages the
|
||||
# executor added on top of the post-trim baseline.
|
||||
#
|
||||
# Timeline inside executor.run_stream:
|
||||
# 1. messages had `original_length` items
|
||||
# 2. append user query → original_length + 1
|
||||
# 3. _trim_messages() → some smaller number (includes the
|
||||
# user query because it belongs to the last turn)
|
||||
# 4. LLM replies / tool calls appended
|
||||
#
|
||||
# The user query message is always the first message of the
|
||||
# last turn (it cannot be trimmed away), so we locate it to
|
||||
# find where "new" messages begin.
|
||||
new_start = original_length # fallback
|
||||
for idx in range(len(executor.messages) - 1, -1, -1):
|
||||
msg = executor.messages[idx]
|
||||
if msg.get("role") == "user":
|
||||
content = msg.get("content", [])
|
||||
is_user_query = False
|
||||
if isinstance(content, list):
|
||||
has_text = any(
|
||||
isinstance(b, dict) and b.get("type") == "text"
|
||||
for b in content
|
||||
)
|
||||
has_tool_result = any(
|
||||
isinstance(b, dict) and b.get("type") == "tool_result"
|
||||
for b in content
|
||||
)
|
||||
is_user_query = has_text and not has_tool_result
|
||||
elif isinstance(content, str):
|
||||
is_user_query = True
|
||||
if is_user_query:
|
||||
new_start = idx
|
||||
break
|
||||
new_messages = list(executor.messages[new_start:])
|
||||
else:
|
||||
new_messages = list(executor.messages[original_length:])
|
||||
agent.messages = list(executor.messages)
|
||||
|
||||
# Persist new messages to SQLite so they survive restarts and
|
||||
# can be queried via the HISTORY interface.
|
||||
if new_messages:
|
||||
self._persist_messages(session_id, list(new_messages), channel_type)
|
||||
|
||||
# Store executor reference for files_to_send access
|
||||
agent.stream_executor = executor
|
||||
|
||||
# Execute post-process tools
|
||||
agent._execute_post_process_tools()
|
||||
|
||||
logger.info(f"[ChatService] Agent run completed: session={session_id}")
|
||||
|
||||
|
||||
|
||||
@staticmethod
|
||||
def _persist_messages(session_id: str, new_messages: list, channel_type: str = ""):
|
||||
try:
|
||||
from config import conf
|
||||
if not conf().get("conversation_persistence", True):
|
||||
return
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
from agent.memory import get_conversation_store
|
||||
get_conversation_store().append_messages(
|
||||
session_id, new_messages, channel_type=channel_type
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"[ChatService] Failed to persist messages for session={session_id}: {e}"
|
||||
)
|
||||
|
||||
|
||||
class _StreamState:
|
||||
"""Mutable state shared between the event callback and the run method."""
|
||||
|
||||
def __init__(self):
|
||||
self.segment_id: int = 0
|
||||
# None means we are not accumulating tool results right now.
|
||||
# A list means we are in the middle of a tool-execution phase.
|
||||
self.pending_tool_results: Optional[list] = None
|
||||
# Maps tool_call_id -> arguments captured from tool_execution_start,
|
||||
# so that tool_execution_end can attach the correct input args.
|
||||
self.pending_tool_arguments: dict = {}
|
||||
241
agent/chat/session_service.py
Normal file
241
agent/chat/session_service.py
Normal file
@@ -0,0 +1,241 @@
|
||||
"""
|
||||
SessionService - Manages multi-session lifecycle for both web channel and cloud client.
|
||||
|
||||
Provides a unified interface for listing, deleting, renaming, clearing context,
|
||||
and generating AI titles for conversation sessions. Backed by ConversationStore
|
||||
(SQLite) and AgentBridge (in-memory agent instances).
|
||||
"""
|
||||
|
||||
import re
|
||||
from typing import Optional
|
||||
|
||||
from common.log import logger
|
||||
|
||||
|
||||
def _truncate_fallback_title(user_message: str, max_len: int = 30) -> str:
|
||||
"""Pick the first non-empty line of the user message and truncate it."""
|
||||
if not user_message:
|
||||
return "New Chat"
|
||||
first_line = ""
|
||||
for line in user_message.splitlines():
|
||||
line = line.strip()
|
||||
if line:
|
||||
first_line = line
|
||||
break
|
||||
if not first_line:
|
||||
return "New Chat"
|
||||
if len(first_line) > max_len:
|
||||
first_line = first_line[:max_len].rstrip() + "..."
|
||||
return first_line
|
||||
|
||||
|
||||
def generate_session_title(user_message: str, assistant_reply: str = "") -> str:
|
||||
"""
|
||||
Generate a short session title by calling the current bot's reply_text.
|
||||
Falls back to the first line of the user message if the LLM call fails
|
||||
or returns an obvious error sentinel.
|
||||
"""
|
||||
fallback = _truncate_fallback_title(user_message)
|
||||
try:
|
||||
from bridge.bridge import Bridge
|
||||
from models.session_manager import Session
|
||||
bot = Bridge().get_bot("chat")
|
||||
|
||||
prompt_parts = [f"User: {user_message[:300]}"]
|
||||
if assistant_reply:
|
||||
prompt_parts.append(f"Assistant: {assistant_reply[:300]}")
|
||||
|
||||
session = Session("__title_gen__", system_prompt="")
|
||||
session.messages = [
|
||||
{"role": "user", "content": (
|
||||
"Generate a very short title (max 15 characters for Chinese, max 6 words for English) "
|
||||
"summarizing this conversation. Return ONLY the title text, nothing else.\n\n"
|
||||
+ "\n".join(prompt_parts)
|
||||
)}
|
||||
]
|
||||
|
||||
result = bot.reply_text(session) or {}
|
||||
# When bots fail (network error, auth error, rate limit, etc.) they
|
||||
# typically return completion_tokens=0 with a sentinel content like
|
||||
# "请再问我一次吧" / "我现在有点累了". Treat that as failure.
|
||||
completion_tokens = result.get("completion_tokens", 0) or 0
|
||||
raw = (result.get("content") or "").strip()
|
||||
if completion_tokens <= 0:
|
||||
logger.warning(
|
||||
f"[SessionService] Title generation got empty completion "
|
||||
f"(completion_tokens={completion_tokens}, content='{raw[:50]}'), "
|
||||
f"using fallback")
|
||||
return fallback
|
||||
|
||||
title = re.sub(r'<think>.*?</think>', '', raw, flags=re.DOTALL).strip().strip('"\'')
|
||||
logger.info(f"[SessionService] Title generation result: '{title}' (len={len(title)})")
|
||||
if title and len(title) <= 50:
|
||||
return title
|
||||
except Exception as e:
|
||||
logger.warning(f"[SessionService] Title generation failed: {e}")
|
||||
return fallback
|
||||
|
||||
|
||||
class SessionService:
|
||||
"""
|
||||
High-level service for session lifecycle management.
|
||||
|
||||
Usage:
|
||||
svc = SessionService()
|
||||
result = svc.dispatch("list", {"channel_type": "web", "page": 1})
|
||||
"""
|
||||
|
||||
def _get_store(self):
|
||||
from agent.memory import get_conversation_store
|
||||
return get_conversation_store()
|
||||
|
||||
def _remove_agent(self, session_id: str):
|
||||
"""Remove the in-memory Agent instance for a session if it exists."""
|
||||
try:
|
||||
from bridge.bridge import Bridge
|
||||
ab = Bridge().get_agent_bridge()
|
||||
if session_id in ab.agents:
|
||||
del ab.agents[session_id]
|
||||
logger.info(f"[SessionService] Removed agent instance: {session_id}")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def _normalize_sid(session_id: str) -> str:
|
||||
if session_id and not session_id.startswith("session_"):
|
||||
return f"session_{session_id}"
|
||||
return session_id
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# actions
|
||||
# ------------------------------------------------------------------
|
||||
def list_sessions(self, channel_type: Optional[str] = None,
|
||||
page: int = 1, page_size: int = 50) -> dict:
|
||||
store = self._get_store()
|
||||
return store.list_sessions(
|
||||
channel_type=channel_type,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
|
||||
def delete_session(self, session_id: str) -> None:
|
||||
if not session_id:
|
||||
raise ValueError("session_id required")
|
||||
session_id = self._normalize_sid(session_id)
|
||||
|
||||
store = self._get_store()
|
||||
store.clear_session(session_id)
|
||||
self._remove_agent(session_id)
|
||||
logger.info(f"[SessionService] Session deleted: {session_id}")
|
||||
|
||||
def rename_session(self, session_id: str, title: str) -> None:
|
||||
if not session_id:
|
||||
raise ValueError("session_id required")
|
||||
if not title:
|
||||
raise ValueError("title required")
|
||||
session_id = self._normalize_sid(session_id)
|
||||
|
||||
store = self._get_store()
|
||||
found = store.rename_session(session_id, title)
|
||||
if not found:
|
||||
raise ValueError("session not found")
|
||||
|
||||
def clear_context(self, session_id: str) -> int:
|
||||
"""
|
||||
Set context boundary. Returns the new context_start_seq value.
|
||||
"""
|
||||
if not session_id:
|
||||
raise ValueError("session_id required")
|
||||
session_id = self._normalize_sid(session_id)
|
||||
|
||||
store = self._get_store()
|
||||
new_seq = store.clear_context(session_id)
|
||||
self._remove_agent(session_id)
|
||||
return new_seq
|
||||
|
||||
def gen_title(self, session_id: str, user_message: str,
|
||||
assistant_reply: str = "") -> str:
|
||||
"""
|
||||
Generate an AI title and persist it. Returns the generated title.
|
||||
"""
|
||||
if not session_id:
|
||||
raise ValueError("session_id required")
|
||||
if not user_message:
|
||||
raise ValueError("user_message required")
|
||||
session_id = self._normalize_sid(session_id)
|
||||
|
||||
title = generate_session_title(user_message, assistant_reply)
|
||||
|
||||
store = self._get_store()
|
||||
updated = store.rename_session(session_id, title)
|
||||
logger.info(f"[SessionService] Title set: sid={session_id}, "
|
||||
f"title='{title}', db_updated={updated}")
|
||||
return title
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# dispatch — single entry point for protocol messages
|
||||
# ------------------------------------------------------------------
|
||||
def dispatch(self, action: str, payload: Optional[dict] = None) -> dict:
|
||||
"""
|
||||
Dispatch a session management action and return a protocol-compatible
|
||||
response dict.
|
||||
|
||||
Action names use a ``*_session`` / session-prefixed convention so they
|
||||
can coexist with history actions (e.g. ``query``) on the same HISTORY
|
||||
message channel without ambiguity.
|
||||
|
||||
Supported actions:
|
||||
- list_sessions: list sessions with pagination
|
||||
- delete_session: delete a session
|
||||
- rename_session: rename a session title
|
||||
- clear_context: set context boundary
|
||||
- generate_title: AI-generate a session title
|
||||
|
||||
:param action: one of the above action names
|
||||
:param payload: action-specific payload
|
||||
:return: dict with action, code, message, payload
|
||||
"""
|
||||
payload = payload or {}
|
||||
try:
|
||||
if action == "list_sessions":
|
||||
result = self.list_sessions(
|
||||
channel_type=payload.get("channel_type"),
|
||||
page=int(payload.get("page", 1)),
|
||||
page_size=int(payload.get("page_size", 50)),
|
||||
)
|
||||
return {"action": action, "code": 200, "message": "success", "payload": result}
|
||||
|
||||
elif action == "delete_session":
|
||||
self.delete_session(payload.get("session_id", ""))
|
||||
return {"action": action, "code": 200, "message": "success", "payload": None}
|
||||
|
||||
elif action == "rename_session":
|
||||
self.rename_session(
|
||||
payload.get("session_id", ""),
|
||||
payload.get("title", "").strip(),
|
||||
)
|
||||
return {"action": action, "code": 200, "message": "success", "payload": None}
|
||||
|
||||
elif action == "clear_context":
|
||||
new_seq = self.clear_context(payload.get("session_id", ""))
|
||||
return {"action": action, "code": 200, "message": "success",
|
||||
"payload": {"context_start_seq": new_seq}}
|
||||
|
||||
elif action == "generate_title":
|
||||
title = self.gen_title(
|
||||
payload.get("session_id", ""),
|
||||
payload.get("user_message", ""),
|
||||
payload.get("assistant_reply", ""),
|
||||
)
|
||||
return {"action": action, "code": 200, "message": "success",
|
||||
"payload": {"title": title}}
|
||||
|
||||
else:
|
||||
return {"action": action, "code": 400,
|
||||
"message": f"unknown action: {action}", "payload": None}
|
||||
|
||||
except ValueError as e:
|
||||
return {"action": action, "code": 400, "message": str(e), "payload": None}
|
||||
except Exception as e:
|
||||
logger.error(f"[SessionService] dispatch error: action={action}, error={e}")
|
||||
return {"action": action, "code": 500, "message": str(e), "payload": None}
|
||||
0
agent/knowledge/__init__.py
Normal file
0
agent/knowledge/__init__.py
Normal file
240
agent/knowledge/service.py
Normal file
240
agent/knowledge/service.py
Normal file
@@ -0,0 +1,240 @@
|
||||
"""
|
||||
Knowledge service for handling knowledge base operations.
|
||||
|
||||
Provides a unified interface for listing, reading, and graphing knowledge files,
|
||||
callable from the web console, API, or CLI.
|
||||
|
||||
Knowledge file layout (under workspace_root):
|
||||
knowledge/index.md
|
||||
knowledge/log.md
|
||||
knowledge/<category>/<slug>.md
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from common.log import logger
|
||||
from config import conf
|
||||
|
||||
|
||||
class KnowledgeService:
|
||||
"""
|
||||
High-level service for knowledge base queries.
|
||||
Operates directly on the filesystem.
|
||||
"""
|
||||
|
||||
def __init__(self, workspace_root: str):
|
||||
self.workspace_root = workspace_root
|
||||
self.knowledge_dir = os.path.join(workspace_root, "knowledge")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# list — directory tree with stats
|
||||
# ------------------------------------------------------------------
|
||||
def list_tree(self) -> dict:
|
||||
"""
|
||||
Return the knowledge directory tree grouped by category,
|
||||
supporting arbitrarily nested sub-directories.
|
||||
|
||||
Returns::
|
||||
|
||||
{
|
||||
"tree": [
|
||||
{
|
||||
"dir": "concepts",
|
||||
"files": [
|
||||
{"name": "moe.md", "title": "MoE", "size": 1234},
|
||||
],
|
||||
"children": []
|
||||
},
|
||||
{
|
||||
"dir": "platform",
|
||||
"files": [],
|
||||
"children": [
|
||||
{
|
||||
"dir": "analysis",
|
||||
"files": [{"name": "perf.md", ...}],
|
||||
"children": []
|
||||
}
|
||||
]
|
||||
},
|
||||
],
|
||||
"stats": {"pages": 15, "size": 32768},
|
||||
"enabled": true
|
||||
}
|
||||
"""
|
||||
if not os.path.isdir(self.knowledge_dir):
|
||||
return {"tree": [], "stats": {"pages": 0, "size": 0}, "enabled": conf().get("knowledge", True)}
|
||||
|
||||
stats = {"pages": 0, "size": 0}
|
||||
root_files, tree = self._scan_dir(self.knowledge_dir, stats, is_root=True)
|
||||
|
||||
return {
|
||||
"root_files": root_files,
|
||||
"tree": tree,
|
||||
"stats": stats,
|
||||
"enabled": conf().get("knowledge", True),
|
||||
}
|
||||
|
||||
def _scan_dir(self, dir_path: str, stats: dict, is_root: bool = False) -> tuple:
|
||||
"""
|
||||
Recursively scan a directory.
|
||||
|
||||
:return: (files, children) where files is a list of .md file dicts
|
||||
in this directory and children is a list of sub-directory nodes.
|
||||
"""
|
||||
files = []
|
||||
children = []
|
||||
for name in sorted(os.listdir(dir_path)):
|
||||
if name.startswith("."):
|
||||
continue
|
||||
full = os.path.join(dir_path, name)
|
||||
if os.path.isdir(full):
|
||||
sub_files, sub_children = self._scan_dir(full, stats)
|
||||
children.append({"dir": name, "files": sub_files, "children": sub_children})
|
||||
elif name.endswith(".md"):
|
||||
size = os.path.getsize(full)
|
||||
if not is_root:
|
||||
stats["pages"] += 1
|
||||
stats["size"] += size
|
||||
title = name.replace(".md", "")
|
||||
try:
|
||||
with open(full, "r", encoding="utf-8") as f:
|
||||
first_line = f.readline().strip()
|
||||
if first_line.startswith("# "):
|
||||
title = first_line[2:].strip()
|
||||
except Exception:
|
||||
pass
|
||||
files.append({"name": name, "title": title, "size": size})
|
||||
return files, children
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# read — single file content
|
||||
# ------------------------------------------------------------------
|
||||
def read_file(self, rel_path: str) -> dict:
|
||||
"""
|
||||
Read a single knowledge markdown file.
|
||||
|
||||
:param rel_path: Relative path within knowledge/, e.g. ``concepts/moe.md``
|
||||
:return: dict with ``content`` and ``path``
|
||||
:raises ValueError: if path is invalid or escapes knowledge dir
|
||||
:raises FileNotFoundError: if file does not exist
|
||||
"""
|
||||
if not rel_path or ".." in rel_path:
|
||||
raise ValueError("invalid path")
|
||||
|
||||
full_path = os.path.normpath(os.path.join(self.knowledge_dir, rel_path))
|
||||
allowed = os.path.normpath(self.knowledge_dir)
|
||||
if not full_path.startswith(allowed + os.sep) and full_path != allowed:
|
||||
raise ValueError("path outside knowledge dir")
|
||||
|
||||
if not os.path.isfile(full_path):
|
||||
raise FileNotFoundError(f"file not found: {rel_path}")
|
||||
|
||||
with open(full_path, "r", encoding="utf-8") as f:
|
||||
content = f.read()
|
||||
return {"content": content, "path": rel_path}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# graph — nodes and links for visualization
|
||||
# ------------------------------------------------------------------
|
||||
def build_graph(self) -> dict:
|
||||
"""
|
||||
Parse all knowledge pages and extract cross-reference links.
|
||||
|
||||
Returns::
|
||||
|
||||
{
|
||||
"nodes": [
|
||||
{"id": "concepts/moe.md", "label": "MoE", "category": "concepts"},
|
||||
...
|
||||
],
|
||||
"links": [
|
||||
{"source": "concepts/moe.md", "target": "entities/deepseek.md"},
|
||||
...
|
||||
]
|
||||
}
|
||||
"""
|
||||
knowledge_path = Path(self.knowledge_dir)
|
||||
if not knowledge_path.is_dir():
|
||||
return {"nodes": [], "links": []}
|
||||
|
||||
nodes = {}
|
||||
links = []
|
||||
link_re = re.compile(r'\[([^\]]*)\]\(([^)]+\.md)\)')
|
||||
|
||||
for md_file in knowledge_path.rglob("*.md"):
|
||||
rel = str(md_file.relative_to(knowledge_path))
|
||||
if rel in ("index.md", "log.md"):
|
||||
continue
|
||||
parts = rel.split("/")
|
||||
category = parts[0] if len(parts) > 1 else "root"
|
||||
title = md_file.stem.replace("-", " ").title()
|
||||
try:
|
||||
content = md_file.read_text(encoding="utf-8")
|
||||
first_line = content.strip().split("\n")[0]
|
||||
if first_line.startswith("# "):
|
||||
title = first_line[2:].strip()
|
||||
for _, link_target in link_re.findall(content):
|
||||
resolved = (md_file.parent / link_target).resolve()
|
||||
try:
|
||||
target_rel = str(resolved.relative_to(knowledge_path))
|
||||
except ValueError:
|
||||
continue
|
||||
if target_rel != rel:
|
||||
links.append({"source": rel, "target": target_rel})
|
||||
except Exception:
|
||||
pass
|
||||
nodes[rel] = {"id": rel, "label": title, "category": category}
|
||||
|
||||
valid_ids = set(nodes.keys())
|
||||
links = [l for l in links if l["source"] in valid_ids and l["target"] in valid_ids]
|
||||
seen = set()
|
||||
deduped = []
|
||||
for l in links:
|
||||
key = tuple(sorted([l["source"], l["target"]]))
|
||||
if key not in seen:
|
||||
seen.add(key)
|
||||
deduped.append(l)
|
||||
|
||||
return {"nodes": list(nodes.values()), "links": deduped}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# dispatch — single entry point for protocol messages
|
||||
# ------------------------------------------------------------------
|
||||
def dispatch(self, action: str, payload: Optional[dict] = None) -> dict:
|
||||
"""
|
||||
Dispatch a knowledge management action.
|
||||
|
||||
:param action: ``list``, ``read``, or ``graph``
|
||||
:param payload: action-specific payload
|
||||
:return: protocol-compatible response dict
|
||||
"""
|
||||
payload = payload or {}
|
||||
try:
|
||||
if action == "list":
|
||||
result = self.list_tree()
|
||||
return {"action": action, "code": 200, "message": "success", "payload": result}
|
||||
|
||||
elif action == "read":
|
||||
path = payload.get("path")
|
||||
if not path:
|
||||
return {"action": action, "code": 400, "message": "path is required", "payload": None}
|
||||
result = self.read_file(path)
|
||||
return {"action": action, "code": 200, "message": "success", "payload": result}
|
||||
|
||||
elif action == "graph":
|
||||
result = self.build_graph()
|
||||
return {"action": action, "code": 200, "message": "success", "payload": result}
|
||||
|
||||
else:
|
||||
return {"action": action, "code": 400, "message": f"unknown action: {action}", "payload": None}
|
||||
|
||||
except ValueError as e:
|
||||
return {"action": action, "code": 403, "message": str(e), "payload": None}
|
||||
except FileNotFoundError as e:
|
||||
return {"action": action, "code": 404, "message": str(e), "payload": None}
|
||||
except Exception as e:
|
||||
logger.error(f"[KnowledgeService] dispatch error: action={action}, error={e}")
|
||||
return {"action": action, "code": 500, "message": str(e), "payload": None}
|
||||
@@ -1,11 +1,23 @@
|
||||
"""
|
||||
Memory module for AgentMesh
|
||||
|
||||
Provides long-term memory capabilities with hybrid search (vector + keyword)
|
||||
Provides both long-term memory (vector/keyword search) and short-term
|
||||
conversation history persistence (SQLite).
|
||||
"""
|
||||
|
||||
from agent.memory.manager import MemoryManager
|
||||
from agent.memory.config import MemoryConfig, get_default_memory_config, set_global_memory_config
|
||||
from agent.memory.embedding import create_embedding_provider
|
||||
from agent.memory.conversation_store import ConversationStore, get_conversation_store
|
||||
from agent.memory.summarizer import ensure_daily_memory_file
|
||||
|
||||
__all__ = ['MemoryManager', 'MemoryConfig', 'get_default_memory_config', 'set_global_memory_config', 'create_embedding_provider']
|
||||
__all__ = [
|
||||
'MemoryManager',
|
||||
'MemoryConfig',
|
||||
'get_default_memory_config',
|
||||
'set_global_memory_config',
|
||||
'create_embedding_provider',
|
||||
'ConversationStore',
|
||||
'get_conversation_store',
|
||||
'ensure_daily_memory_file',
|
||||
]
|
||||
|
||||
@@ -4,6 +4,7 @@ Text chunking utilities for memory
|
||||
Splits text into chunks with token limits and overlap
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
from typing import List, Tuple
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@@ -4,18 +4,25 @@ Memory configuration module
|
||||
Provides global memory configuration with simplified workspace structure
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional, List
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def _default_workspace():
|
||||
"""Get default workspace path with proper Windows support"""
|
||||
from common.utils import expand_path
|
||||
return expand_path("~/cow")
|
||||
|
||||
|
||||
@dataclass
|
||||
class MemoryConfig:
|
||||
"""Configuration for memory storage and search"""
|
||||
|
||||
# Storage paths (default: ~/cow)
|
||||
workspace_root: str = field(default_factory=lambda: os.path.expanduser("~/cow"))
|
||||
workspace_root: str = field(default_factory=_default_workspace)
|
||||
|
||||
# Embedding config
|
||||
embedding_provider: str = "openai" # "openai" | "local"
|
||||
@@ -41,9 +48,6 @@ class MemoryConfig:
|
||||
enable_auto_sync: bool = True
|
||||
sync_on_search: bool = True
|
||||
|
||||
# Memory flush config (独立于模型 context window)
|
||||
flush_token_threshold: int = 50000 # 50K tokens 触发 flush
|
||||
flush_turn_threshold: int = 20 # 20 轮对话触发 flush (用户+AI各一条为一轮)
|
||||
|
||||
def get_workspace(self) -> Path:
|
||||
"""Get workspace root directory"""
|
||||
|
||||
1055
agent/memory/conversation_store.py
Normal file
1055
agent/memory/conversation_store.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -1,148 +0,0 @@
|
||||
"""
|
||||
Embedding providers for memory
|
||||
|
||||
Supports OpenAI and local embedding models
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Optional
|
||||
|
||||
|
||||
class EmbeddingProvider(ABC):
|
||||
"""Base class for embedding providers"""
|
||||
|
||||
@abstractmethod
|
||||
def embed(self, text: str) -> List[float]:
|
||||
"""Generate embedding for text"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def embed_batch(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Generate embeddings for multiple texts"""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def dimensions(self) -> int:
|
||||
"""Get embedding dimensions"""
|
||||
pass
|
||||
|
||||
|
||||
class OpenAIEmbeddingProvider(EmbeddingProvider):
|
||||
"""OpenAI embedding provider using REST API"""
|
||||
|
||||
def __init__(self, model: str = "text-embedding-3-small", api_key: Optional[str] = None, api_base: Optional[str] = None):
|
||||
"""
|
||||
Initialize OpenAI embedding provider
|
||||
|
||||
Args:
|
||||
model: Model name (text-embedding-3-small or text-embedding-3-large)
|
||||
api_key: OpenAI API key
|
||||
api_base: Optional API base URL
|
||||
"""
|
||||
self.model = model
|
||||
self.api_key = api_key
|
||||
self.api_base = api_base or "https://api.openai.com/v1"
|
||||
|
||||
if not self.api_key:
|
||||
raise ValueError("OpenAI API key is required")
|
||||
|
||||
# Set dimensions based on model
|
||||
self._dimensions = 1536 if "small" in model else 3072
|
||||
|
||||
def _call_api(self, input_data):
|
||||
"""Call OpenAI embedding API using requests"""
|
||||
import requests
|
||||
|
||||
url = f"{self.api_base}/embeddings"
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {self.api_key}"
|
||||
}
|
||||
data = {
|
||||
"input": input_data,
|
||||
"model": self.model
|
||||
}
|
||||
|
||||
response = requests.post(url, headers=headers, json=data, timeout=30)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
def embed(self, text: str) -> List[float]:
|
||||
"""Generate embedding for text"""
|
||||
result = self._call_api(text)
|
||||
return result["data"][0]["embedding"]
|
||||
|
||||
def embed_batch(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Generate embeddings for multiple texts"""
|
||||
if not texts:
|
||||
return []
|
||||
|
||||
result = self._call_api(texts)
|
||||
return [item["embedding"] for item in result["data"]]
|
||||
|
||||
@property
|
||||
def dimensions(self) -> int:
|
||||
return self._dimensions
|
||||
|
||||
|
||||
# LocalEmbeddingProvider removed - only use OpenAI embedding or keyword search
|
||||
|
||||
|
||||
class EmbeddingCache:
|
||||
"""Cache for embeddings to avoid recomputation"""
|
||||
|
||||
def __init__(self):
|
||||
self.cache = {}
|
||||
|
||||
def get(self, text: str, provider: str, model: str) -> Optional[List[float]]:
|
||||
"""Get cached embedding"""
|
||||
key = self._compute_key(text, provider, model)
|
||||
return self.cache.get(key)
|
||||
|
||||
def put(self, text: str, provider: str, model: str, embedding: List[float]):
|
||||
"""Cache embedding"""
|
||||
key = self._compute_key(text, provider, model)
|
||||
self.cache[key] = embedding
|
||||
|
||||
@staticmethod
|
||||
def _compute_key(text: str, provider: str, model: str) -> str:
|
||||
"""Compute cache key"""
|
||||
content = f"{provider}:{model}:{text}"
|
||||
return hashlib.md5(content.encode('utf-8')).hexdigest()
|
||||
|
||||
def clear(self):
|
||||
"""Clear cache"""
|
||||
self.cache.clear()
|
||||
|
||||
|
||||
def create_embedding_provider(
|
||||
provider: str = "openai",
|
||||
model: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None
|
||||
) -> EmbeddingProvider:
|
||||
"""
|
||||
Factory function to create embedding provider
|
||||
|
||||
Only supports OpenAI embedding via REST API.
|
||||
If initialization fails, caller should fall back to keyword-only search.
|
||||
|
||||
Args:
|
||||
provider: Provider name (only "openai" is supported)
|
||||
model: Model name (default: text-embedding-3-small)
|
||||
api_key: OpenAI API key (required)
|
||||
api_base: API base URL (default: https://api.openai.com/v1)
|
||||
|
||||
Returns:
|
||||
EmbeddingProvider instance
|
||||
|
||||
Raises:
|
||||
ValueError: If provider is not "openai" or api_key is missing
|
||||
"""
|
||||
if provider != "openai":
|
||||
raise ValueError(f"Only 'openai' provider is supported, got: {provider}")
|
||||
|
||||
model = model or "text-embedding-3-small"
|
||||
return OpenAIEmbeddingProvider(model=model, api_key=api_key, api_base=api_base)
|
||||
41
agent/memory/embedding/__init__.py
Normal file
41
agent/memory/embedding/__init__.py
Normal file
@@ -0,0 +1,41 @@
|
||||
"""
|
||||
Embedding subsystem for memory.
|
||||
|
||||
Public API:
|
||||
create_embedding_provider, EmbeddingProvider, OpenAIEmbeddingProvider,
|
||||
EMBEDDING_VENDORS, EmbeddingCache
|
||||
RebuildResult, clear_index, rebuild_in_process
|
||||
detect_index_dim, cleanup_legacy_state_file
|
||||
"""
|
||||
|
||||
from agent.memory.embedding.provider import (
|
||||
EMBEDDING_VENDORS,
|
||||
DoubaoEmbeddingProvider,
|
||||
EmbeddingCache,
|
||||
EmbeddingProvider,
|
||||
OpenAIEmbeddingProvider,
|
||||
create_embedding_provider,
|
||||
)
|
||||
from agent.memory.embedding.rebuild import (
|
||||
RebuildResult,
|
||||
clear_index,
|
||||
rebuild_in_process,
|
||||
)
|
||||
from agent.memory.embedding.state import (
|
||||
cleanup_legacy_state_file,
|
||||
detect_index_dim,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"EMBEDDING_VENDORS",
|
||||
"DoubaoEmbeddingProvider",
|
||||
"EmbeddingCache",
|
||||
"EmbeddingProvider",
|
||||
"OpenAIEmbeddingProvider",
|
||||
"create_embedding_provider",
|
||||
"RebuildResult",
|
||||
"clear_index",
|
||||
"rebuild_in_process",
|
||||
"cleanup_legacy_state_file",
|
||||
"detect_index_dim",
|
||||
]
|
||||
486
agent/memory/embedding/provider.py
Normal file
486
agent/memory/embedding/provider.py
Normal file
@@ -0,0 +1,486 @@
|
||||
"""
|
||||
Embedding providers for memory
|
||||
|
||||
Supports multiple OpenAI-compatible embedding vendors:
|
||||
- openai (text-embedding-3-small / large)
|
||||
- linkai (OpenAI-compatible passthrough)
|
||||
- dashscope (Aliyun Tongyi text-embedding-v4)
|
||||
- doubao (ByteDance Doubao Seed1.5 / large-text on Volcengine Ark)
|
||||
- zhipu (ZhipuAI embedding-3)
|
||||
|
||||
Vendor keys here intentionally match the project's bot_type constants in
|
||||
common.const (OPENAI, LINKAI, QWEN_DASHSCOPE, DOUBAO, ZHIPU_AI).
|
||||
|
||||
All providers share a single OpenAI-compatible REST client. Vendor-specific
|
||||
behaviors (truncation, query instruction prefix) are configured via metadata.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import math
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Optional
|
||||
|
||||
# HTTP read timeout for a single embeddings request (seconds). A batch of
|
||||
# 64+ chunks can take 30-50s end-to-end from China-side networks, so 30s is
|
||||
# routinely too tight; 90s gives meaningful headroom without letting bad
|
||||
# endpoints hang forever.
|
||||
EMBEDDING_HTTP_TIMEOUT = 90
|
||||
|
||||
|
||||
class EmbeddingProvider(ABC):
|
||||
"""Base class for embedding providers"""
|
||||
|
||||
@abstractmethod
|
||||
def embed(self, text: str) -> List[float]:
|
||||
"""Generate embedding for a single text (treated as a query by default)"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def embed_batch(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Generate embeddings for multiple texts (treated as documents)"""
|
||||
pass
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
"""Generate embedding for a query string (may apply vendor instruction prefix)"""
|
||||
return self.embed(text)
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def dimensions(self) -> int:
|
||||
"""Effective embedding dimensions"""
|
||||
pass
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Vendor metadata table
|
||||
# ---------------------------------------------------------------------------
|
||||
#
|
||||
# Each entry describes how to reach a vendor's embedding endpoint. Most
|
||||
# vendors expose an OpenAI-compatible /embeddings API; the few that don't
|
||||
# (currently: doubao) set `provider_class` to pick a dedicated adapter.
|
||||
# Fields:
|
||||
# provider_class : optional adapter key ("doubao"); defaults to OpenAI-compat
|
||||
# default_base_url : default API base when not overridden by user
|
||||
# default_model : default embedding model name
|
||||
# default_dimensions : recommended unified dim when explicit path is enabled
|
||||
# supports_dim_param : whether the API accepts a `dimensions` request param
|
||||
# needs_client_truncate : whether to slice + L2-normalize on the client side
|
||||
# needs_client_normalize : whether to L2-normalize on the client (always safe)
|
||||
# query_instruction : optional prefix for asymmetric retrieval (Doubao Seed)
|
||||
# max_batch_size : max texts per /embeddings request; embed_batch
|
||||
# auto-paginates above this. Conservative defaults.
|
||||
#
|
||||
EMBEDDING_VENDORS = {
|
||||
"openai": {
|
||||
"default_base_url": "https://api.openai.com/v1",
|
||||
"default_model": "text-embedding-3-small",
|
||||
# Match the legacy default so users adding `embedding_provider: openai`
|
||||
# to an existing index don't need to rebuild. Override via
|
||||
# embedding_dimensions if you want 1024 / 1536 / 3072.
|
||||
"default_dimensions": 1536,
|
||||
"supports_dim_param": True,
|
||||
"needs_client_truncate": False,
|
||||
"needs_client_normalize": False,
|
||||
"query_instruction": "",
|
||||
# OpenAI permits up to 2048 items per request, but a single call
|
||||
# carrying hundreds of long chunks routinely exceeds the 30s read
|
||||
# timeout from China-side networks. 64 keeps each call well under
|
||||
# both the token-per-request budget and a reasonable wall clock.
|
||||
"max_batch_size": 64,
|
||||
},
|
||||
"linkai": {
|
||||
"default_base_url": "https://api.link-ai.tech/v1",
|
||||
"default_model": "text-embedding-3-small",
|
||||
"default_dimensions": 1536,
|
||||
"supports_dim_param": True,
|
||||
"needs_client_truncate": False,
|
||||
"needs_client_normalize": False,
|
||||
"query_instruction": "",
|
||||
"max_batch_size": 64,
|
||||
},
|
||||
"dashscope": {
|
||||
"default_base_url": "https://dashscope.aliyuncs.com/compatible-mode/v1",
|
||||
"default_model": "text-embedding-v4",
|
||||
"default_dimensions": 1024,
|
||||
"supports_dim_param": True,
|
||||
"needs_client_truncate": False,
|
||||
"needs_client_normalize": False,
|
||||
"query_instruction": "",
|
||||
"max_batch_size": 10, # DashScope hard cap (text-embedding-v4)
|
||||
},
|
||||
"doubao": {
|
||||
# Doubao no longer offers an OpenAI-compatible /v1/embeddings endpoint.
|
||||
# Current models are unified under /api/v3/embeddings/multimodal
|
||||
# which uses a structured `input` payload — see DoubaoEmbeddingProvider.
|
||||
"provider_class": "doubao",
|
||||
"default_base_url": "https://ark.cn-beijing.volces.com/api/v3",
|
||||
"default_model": "doubao-embedding-vision-251215",
|
||||
# Native options: 1024 or 2048. We default to 1024 to align with the
|
||||
# other Chinese vendors (dashscope/zhipu) and keep storage footprint
|
||||
# consistent across providers; users can still override via
|
||||
# `embedding_dimensions: 2048` in config.
|
||||
"default_dimensions": 1024,
|
||||
"supports_dim_param": True,
|
||||
"needs_client_truncate": False,
|
||||
"needs_client_normalize": False,
|
||||
"query_instruction": "",
|
||||
# Multimodal endpoint produces ONE embedding per call (input list is
|
||||
# a single document's parts, not a batch). embed_batch loops.
|
||||
"max_batch_size": 1,
|
||||
},
|
||||
"zhipu": {
|
||||
"default_base_url": "https://open.bigmodel.cn/api/paas/v4",
|
||||
"default_model": "embedding-3",
|
||||
"default_dimensions": 1024,
|
||||
"supports_dim_param": True,
|
||||
"needs_client_truncate": False,
|
||||
"needs_client_normalize": False,
|
||||
"query_instruction": "",
|
||||
"max_batch_size": 64,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _l2_normalize(vec: List[float]) -> List[float]:
|
||||
"""Normalize a vector to unit length (L2 norm). Returns input on zero vector."""
|
||||
norm = math.sqrt(sum(v * v for v in vec))
|
||||
if norm == 0:
|
||||
return vec
|
||||
return [v / norm for v in vec]
|
||||
|
||||
|
||||
class OpenAIEmbeddingProvider(EmbeddingProvider):
|
||||
"""
|
||||
OpenAI-compatible embedding provider.
|
||||
|
||||
Used for openai/linkai/dashscope/ark/zhipu by configuring the metadata
|
||||
fields. The legacy two-arg constructor (model, api_key, api_base) keeps
|
||||
working, so the original OpenAI/LinkAI fallback code path is unchanged.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str = "text-embedding-3-small",
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
extra_headers: Optional[dict] = None,
|
||||
dimensions: Optional[int] = None,
|
||||
supports_dim_param: bool = True,
|
||||
needs_client_truncate: bool = False,
|
||||
needs_client_normalize: bool = False,
|
||||
query_instruction: str = "",
|
||||
max_batch_size: int = 256,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
model: Model name (e.g. text-embedding-3-small, text-embedding-v4, embedding-3)
|
||||
api_key: API key (required)
|
||||
api_base: API base URL (defaults to OpenAI)
|
||||
extra_headers: Optional extra HTTP headers
|
||||
dimensions: Target output dimension. Required when supports_dim_param
|
||||
is False and needs_client_truncate is True (used to slice).
|
||||
supports_dim_param: Whether the vendor accepts a `dimensions` body param
|
||||
needs_client_truncate: Slice the returned vector to `dimensions`
|
||||
needs_client_normalize: L2-normalize on the client after slicing
|
||||
query_instruction: Optional prefix prepended to query texts only
|
||||
max_batch_size: Max items per /embeddings request; embed_batch
|
||||
auto-paginates above this.
|
||||
"""
|
||||
self.model = model
|
||||
self.api_key = api_key
|
||||
self.api_base = api_base or "https://api.openai.com/v1"
|
||||
self.extra_headers = extra_headers or {}
|
||||
self.supports_dim_param = supports_dim_param
|
||||
self.needs_client_truncate = needs_client_truncate
|
||||
self.needs_client_normalize = needs_client_normalize
|
||||
self.query_instruction = query_instruction or ""
|
||||
self.max_batch_size = max(1, int(max_batch_size or 1))
|
||||
|
||||
if not self.api_key or self.api_key in ["", "YOUR API KEY", "YOUR_API_KEY"]:
|
||||
raise ValueError("Embedding API key is not configured")
|
||||
|
||||
if dimensions is not None and dimensions > 0:
|
||||
self._dimensions = dimensions
|
||||
else:
|
||||
# Legacy heuristic for OpenAI text-embedding-3-* family
|
||||
self._dimensions = 1536 if "small" in model else 3072
|
||||
|
||||
def _call_api(self, input_data):
|
||||
"""Call OpenAI-compatible /embeddings endpoint"""
|
||||
import requests
|
||||
|
||||
url = f"{self.api_base}/embeddings"
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
**self.extra_headers,
|
||||
}
|
||||
data = {
|
||||
"input": input_data,
|
||||
"model": self.model,
|
||||
}
|
||||
if self.supports_dim_param and self._dimensions:
|
||||
data["dimensions"] = self._dimensions
|
||||
|
||||
try:
|
||||
response = requests.post(url, headers=headers, json=data, timeout=EMBEDDING_HTTP_TIMEOUT)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
except requests.exceptions.ConnectionError as e:
|
||||
raise ConnectionError(
|
||||
f"Failed to connect to embedding API at {url}. "
|
||||
f"Please check network and api_base. Error: {str(e)}"
|
||||
)
|
||||
except requests.exceptions.Timeout as e:
|
||||
raise TimeoutError(f"Embedding API request timed out. Error: {str(e)}")
|
||||
except requests.exceptions.HTTPError as e:
|
||||
if e.response.status_code == 401:
|
||||
raise ValueError("Invalid embedding API key")
|
||||
elif e.response.status_code == 429:
|
||||
raise ValueError("Embedding API rate limit exceeded")
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Embedding API request failed: "
|
||||
f"{e.response.status_code} - {e.response.text}"
|
||||
)
|
||||
|
||||
def _post_process(self, raw: List[float]) -> List[float]:
|
||||
"""Apply optional client-side truncation + normalization"""
|
||||
vec = raw
|
||||
if self.needs_client_truncate and self._dimensions and len(vec) > self._dimensions:
|
||||
vec = vec[: self._dimensions]
|
||||
if self.needs_client_normalize:
|
||||
vec = _l2_normalize(vec)
|
||||
return vec
|
||||
|
||||
def embed(self, text: str) -> List[float]:
|
||||
"""Generate embedding (treated as document by default)"""
|
||||
result = self._call_api(text)
|
||||
return self._post_process(result["data"][0]["embedding"])
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
"""Generate embedding for a query (applies vendor instruction prefix if any)"""
|
||||
if self.query_instruction:
|
||||
text = f"{self.query_instruction}{text}"
|
||||
return self.embed(text)
|
||||
|
||||
def embed_batch(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Generate embeddings for multiple documents.
|
||||
|
||||
Automatically paginates by self.max_batch_size so callers can pass any
|
||||
number of texts. Order of returned vectors matches the input order.
|
||||
"""
|
||||
if not texts:
|
||||
return []
|
||||
out: List[List[float]] = []
|
||||
step = self.max_batch_size
|
||||
for i in range(0, len(texts), step):
|
||||
chunk = texts[i:i + step]
|
||||
result = self._call_api(chunk)
|
||||
out.extend(self._post_process(item["embedding"]) for item in result["data"])
|
||||
return out
|
||||
|
||||
@property
|
||||
def dimensions(self) -> int:
|
||||
return self._dimensions
|
||||
|
||||
|
||||
class DoubaoEmbeddingProvider(EmbeddingProvider):
|
||||
"""
|
||||
Doubao (Volcengine Ark) multimodal embedding provider.
|
||||
|
||||
Doubao deprecated their OpenAI-compatible /v1/embeddings endpoint and
|
||||
unified everything under /api/v3/embeddings/multimodal, which uses a
|
||||
structured `input: [{type, text|image_url|video_url}, ...]` payload.
|
||||
|
||||
Notes:
|
||||
* The endpoint produces ONE embedding per call (input list is multiple
|
||||
modality parts of a single document, not a batch). embed_batch
|
||||
therefore loops per-text — no native batch support.
|
||||
* Native dimensions: 1024 or 2048 (default 1024 to align with other
|
||||
Chinese vendors). No client-side truncation needed.
|
||||
* Auth: Bearer ARK API key.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
extra_headers: Optional[dict] = None,
|
||||
dimensions: Optional[int] = None,
|
||||
):
|
||||
self.model = model
|
||||
self.api_key = api_key
|
||||
self.api_base = api_base or "https://ark.cn-beijing.volces.com/api/v3"
|
||||
self.extra_headers = extra_headers or {}
|
||||
if not self.api_key or self.api_key in ["", "YOUR API KEY", "YOUR_API_KEY"]:
|
||||
raise ValueError("Doubao embedding API key (ark_api_key) is not configured")
|
||||
|
||||
if dimensions in (1024, 2048):
|
||||
self._dimensions = dimensions
|
||||
elif dimensions is None:
|
||||
self._dimensions = 1024
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Doubao embedding dimensions must be 1024 or 2048, got {dimensions}"
|
||||
)
|
||||
|
||||
def _call_api(self, text: str) -> List[float]:
|
||||
"""One call → one embedding. multimodal endpoint takes a single
|
||||
document represented as a list of typed parts; we send a single
|
||||
text part."""
|
||||
import requests
|
||||
|
||||
url = f"{self.api_base}/embeddings/multimodal"
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
**self.extra_headers,
|
||||
}
|
||||
payload = {
|
||||
"model": self.model,
|
||||
"input": [{"type": "text", "text": text}],
|
||||
"dimensions": self._dimensions,
|
||||
"encoding_format": "float",
|
||||
}
|
||||
|
||||
try:
|
||||
response = requests.post(url, headers=headers, json=payload, timeout=EMBEDDING_HTTP_TIMEOUT)
|
||||
response.raise_for_status()
|
||||
body = response.json()
|
||||
except requests.exceptions.ConnectionError as e:
|
||||
raise ConnectionError(
|
||||
f"Failed to connect to Doubao embedding API at {url}. "
|
||||
f"Please check network and api_base. Error: {str(e)}"
|
||||
)
|
||||
except requests.exceptions.Timeout as e:
|
||||
raise TimeoutError(f"Doubao embedding API request timed out. Error: {str(e)}")
|
||||
except requests.exceptions.HTTPError as e:
|
||||
if e.response.status_code == 401:
|
||||
raise ValueError("Invalid Doubao (ark) embedding API key")
|
||||
elif e.response.status_code == 429:
|
||||
raise ValueError("Doubao embedding API rate limit exceeded")
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Doubao embedding API request failed: "
|
||||
f"{e.response.status_code} - {e.response.text}"
|
||||
)
|
||||
|
||||
# Response shape per docs: {"data": {"embedding": [...]}}
|
||||
data = body.get("data")
|
||||
if isinstance(data, dict) and "embedding" in data:
|
||||
return data["embedding"]
|
||||
# Some providers wrap as a list of one — be defensive
|
||||
if isinstance(data, list) and data and "embedding" in data[0]:
|
||||
return data[0]["embedding"]
|
||||
raise ValueError(f"Unexpected Doubao embedding response shape: {body}")
|
||||
|
||||
def embed(self, text: str) -> List[float]:
|
||||
return self._call_api(text)
|
||||
|
||||
def embed_batch(self, texts: List[str]) -> List[List[float]]:
|
||||
# Endpoint produces one embedding per call; loop. Order preserved.
|
||||
return [self._call_api(t) for t in texts]
|
||||
|
||||
@property
|
||||
def dimensions(self) -> int:
|
||||
return self._dimensions
|
||||
|
||||
|
||||
class EmbeddingCache:
|
||||
"""In-memory cache for embeddings to avoid recomputation"""
|
||||
|
||||
def __init__(self):
|
||||
self.cache = {}
|
||||
|
||||
def get(self, text: str, provider: str, model: str) -> Optional[List[float]]:
|
||||
key = self._compute_key(text, provider, model)
|
||||
return self.cache.get(key)
|
||||
|
||||
def put(self, text: str, provider: str, model: str, embedding: List[float]):
|
||||
key = self._compute_key(text, provider, model)
|
||||
self.cache[key] = embedding
|
||||
|
||||
@staticmethod
|
||||
def _compute_key(text: str, provider: str, model: str) -> str:
|
||||
content = f"{provider}:{model}:{text}"
|
||||
return hashlib.md5(content.encode("utf-8")).hexdigest()
|
||||
|
||||
def clear(self):
|
||||
self.cache.clear()
|
||||
|
||||
|
||||
def create_embedding_provider(
|
||||
provider: str = "openai",
|
||||
model: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
extra_headers: Optional[dict] = None,
|
||||
dimensions: Optional[int] = None,
|
||||
) -> EmbeddingProvider:
|
||||
"""
|
||||
Factory function to create an embedding provider.
|
||||
|
||||
Backward compatible: when called with provider in {"openai", "linkai"}
|
||||
and no `dimensions` arg, behaves exactly as before (1536-dim OpenAI).
|
||||
|
||||
New providers ("dashscope", "doubao", "zhipu") require explicit configuration
|
||||
and use the unified 1024-dim defaults from EMBEDDING_VENDORS.
|
||||
|
||||
Args:
|
||||
provider: Vendor key (one of EMBEDDING_VENDORS)
|
||||
model: Model name (uses vendor default if None)
|
||||
api_key: API key (required)
|
||||
api_base: API base URL (uses vendor default if None)
|
||||
extra_headers: Optional extra HTTP headers
|
||||
dimensions: Target output dimension (uses vendor default if None)
|
||||
|
||||
Returns:
|
||||
EmbeddingProvider instance
|
||||
"""
|
||||
meta = EMBEDDING_VENDORS.get(provider)
|
||||
if meta is None:
|
||||
raise ValueError(
|
||||
f"Unsupported embedding provider: {provider}. "
|
||||
f"Supported: {sorted(EMBEDDING_VENDORS.keys())}"
|
||||
)
|
||||
|
||||
# Doubao uses a non-OpenAI-compatible multimodal endpoint.
|
||||
if meta.get("provider_class") == "doubao":
|
||||
final_dim = dimensions if (dimensions and dimensions > 0) else meta["default_dimensions"]
|
||||
return DoubaoEmbeddingProvider(
|
||||
model=model or meta["default_model"],
|
||||
api_key=api_key,
|
||||
api_base=api_base or meta["default_base_url"],
|
||||
extra_headers=extra_headers,
|
||||
dimensions=final_dim,
|
||||
)
|
||||
|
||||
# Legacy two-arg call for openai/linkai keeps 1536-dim default behavior
|
||||
# so existing data isn't invalidated.
|
||||
is_legacy_call = (
|
||||
provider in ("openai", "linkai")
|
||||
and dimensions is None
|
||||
)
|
||||
if is_legacy_call:
|
||||
return OpenAIEmbeddingProvider(
|
||||
model=model or "text-embedding-3-small",
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
extra_headers=extra_headers,
|
||||
)
|
||||
|
||||
final_dim = dimensions if (dimensions and dimensions > 0) else meta["default_dimensions"]
|
||||
return OpenAIEmbeddingProvider(
|
||||
model=model or meta["default_model"],
|
||||
api_key=api_key,
|
||||
api_base=api_base or meta["default_base_url"],
|
||||
extra_headers=extra_headers,
|
||||
dimensions=final_dim,
|
||||
supports_dim_param=meta["supports_dim_param"],
|
||||
needs_client_truncate=meta["needs_client_truncate"],
|
||||
needs_client_normalize=meta["needs_client_normalize"],
|
||||
query_instruction=meta["query_instruction"],
|
||||
max_batch_size=meta.get("max_batch_size", 256),
|
||||
)
|
||||
191
agent/memory/embedding/rebuild.py
Normal file
191
agent/memory/embedding/rebuild.py
Normal file
@@ -0,0 +1,191 @@
|
||||
"""
|
||||
Rebuild memory vector index.
|
||||
|
||||
Recommended entry point (in-chat, while agent is running):
|
||||
/memory rebuild-index
|
||||
|
||||
Backward-compatible CLI entry (must run from project root):
|
||||
python -m agent.memory.rebuild_index
|
||||
|
||||
What it does:
|
||||
1. Probes the embedding endpoint with a tiny call to fail fast on
|
||||
bad provider/model/key — before touching the index.
|
||||
2. Clears the SQLite chunks/files tables (workspace markdown stays intact).
|
||||
3. Runs a fresh sync, regenerating embeddings with the currently configured
|
||||
provider/model/dimensions.
|
||||
|
||||
This is the only safe way to switch embedding_provider after the existing
|
||||
index has been populated by a different-dim model.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import asyncio
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
from common.log import logger
|
||||
from common.utils import expand_path
|
||||
|
||||
|
||||
@dataclass
|
||||
class RebuildResult:
|
||||
"""Outcome of a rebuild_in_process() call"""
|
||||
ok: bool
|
||||
removed: int = 0
|
||||
chunks: int = 0
|
||||
files: int = 0
|
||||
error: Optional[str] = None
|
||||
|
||||
|
||||
def clear_index(db_path, storage=None) -> int:
|
||||
"""Wipe chunks/files, reset FTS5, and clean up any legacy state file.
|
||||
|
||||
Args:
|
||||
db_path: Path of the index DB (also used to locate the legacy state
|
||||
file for migration cleanup, and — when *storage* is None — to
|
||||
open a fresh connection).
|
||||
storage: Optional pre-opened MemoryStorage. When provided we reuse it
|
||||
so the live connection's triggers stay in sync — opening a second
|
||||
connection would leave the original one's triggers pointing at a
|
||||
DROP'd chunks_fts table.
|
||||
|
||||
We reset (DROP+recreate) chunks_fts because its shadow tables can become
|
||||
inconsistent across rebuild cycles, causing bm25() / ORDER BY rank to
|
||||
raise "database disk image is malformed" even when raw MATCH still works.
|
||||
|
||||
Returns number of chunks removed.
|
||||
"""
|
||||
from agent.memory.embedding.state import cleanup_legacy_state_file
|
||||
from agent.memory.storage import MemoryStorage
|
||||
|
||||
owns_storage = storage is None
|
||||
if owns_storage:
|
||||
storage = MemoryStorage(db_path)
|
||||
try:
|
||||
before = storage.conn.execute("SELECT COUNT(*) FROM chunks").fetchone()[0]
|
||||
storage.conn.execute("DELETE FROM chunks")
|
||||
storage.conn.execute("DELETE FROM files")
|
||||
storage.conn.commit()
|
||||
storage.reset_fts5()
|
||||
finally:
|
||||
if owns_storage:
|
||||
storage.close()
|
||||
|
||||
cleanup_legacy_state_file(db_path)
|
||||
return int(before)
|
||||
|
||||
|
||||
def rebuild_in_process(memory_manager) -> RebuildResult:
|
||||
"""
|
||||
Rebuild the index using an existing, fully-initialized MemoryManager.
|
||||
|
||||
Used by the in-chat /memory rebuild-index command. The caller already has
|
||||
config loaded, embedding_provider built, and (optionally) the agent
|
||||
running, so we only need to:
|
||||
1. Clear chunks/files + state on the manager's storage.
|
||||
2. Re-sync (force=True).
|
||||
|
||||
NOTE: caller must ensure memory_manager.embedding_provider is set, otherwise
|
||||
sync() will silently skip embedding generation.
|
||||
"""
|
||||
if memory_manager is None:
|
||||
return RebuildResult(ok=False, error="memory_manager is None")
|
||||
if memory_manager.embedding_provider is None:
|
||||
return RebuildResult(ok=False, error="embedding_provider is not initialized")
|
||||
|
||||
# Probe the embedding endpoint BEFORE clearing the index. A bad
|
||||
# provider/model/key would otherwise leave the user with an empty index
|
||||
# that not even keyword search can serve.
|
||||
try:
|
||||
memory_manager.embedding_provider.embed_query("ping")
|
||||
except Exception as e:
|
||||
logger.error(f"[RebuildIndex] embedding probe failed, aborting rebuild: {e}")
|
||||
return RebuildResult(ok=False, error=f"embedding endpoint not reachable: {e}")
|
||||
|
||||
db_path = memory_manager.config.get_db_path()
|
||||
try:
|
||||
removed = clear_index(db_path, storage=memory_manager.storage)
|
||||
except Exception as e:
|
||||
logger.exception("[RebuildIndex] clear_index failed")
|
||||
return RebuildResult(ok=False, error=f"clear failed: {e}")
|
||||
|
||||
try:
|
||||
asyncio.run(memory_manager.sync(force=True))
|
||||
except RuntimeError:
|
||||
# Already inside a running event loop (rare in chat handler thread).
|
||||
loop = asyncio.new_event_loop()
|
||||
try:
|
||||
loop.run_until_complete(memory_manager.sync(force=True))
|
||||
finally:
|
||||
loop.close()
|
||||
except Exception as e:
|
||||
logger.exception("[RebuildIndex] sync failed")
|
||||
return RebuildResult(ok=False, removed=removed, error=f"re-embed failed: {e}")
|
||||
|
||||
stats = memory_manager.storage.get_stats()
|
||||
chunks = int(stats.get("chunks", 0))
|
||||
embedded = int(stats.get("embedded", 0))
|
||||
|
||||
# sync() degrades to "no embeddings" on batch failure so keyword search
|
||||
# still works at startup — but in a /rebuild-index request the user
|
||||
# explicitly asked for vectors. Surface that as a failure.
|
||||
if chunks > 0 and embedded == 0:
|
||||
return RebuildResult(
|
||||
ok=False,
|
||||
removed=removed,
|
||||
chunks=chunks,
|
||||
files=int(stats.get("files", 0)),
|
||||
error=(
|
||||
"embedding API failed during sync; index now has chunks but no "
|
||||
"vectors. Check embedding provider/model/key and retry."
|
||||
),
|
||||
)
|
||||
|
||||
return RebuildResult(
|
||||
ok=True,
|
||||
removed=removed,
|
||||
chunks=chunks,
|
||||
files=int(stats.get("files", 0)),
|
||||
)
|
||||
|
||||
|
||||
def main() -> int:
|
||||
"""Standalone CLI entry. Must be run from project root (relative config path)."""
|
||||
from config import conf, load_config
|
||||
from agent.memory import MemoryConfig, MemoryManager
|
||||
|
||||
load_config()
|
||||
|
||||
workspace_root = expand_path(conf().get("agent_workspace", "~/cow"))
|
||||
memory_config = MemoryConfig(workspace_root=workspace_root)
|
||||
|
||||
logger.info(f"[RebuildIndex] Workspace: {workspace_root}")
|
||||
logger.info(f"[RebuildIndex] Index db: {memory_config.get_db_path()}")
|
||||
|
||||
from bridge.agent_initializer import AgentInitializer
|
||||
|
||||
initializer = AgentInitializer(bridge=None, agent_bridge=None)
|
||||
embedding_provider = initializer._init_embedding_provider(memory_config, session_id=None)
|
||||
if embedding_provider is None:
|
||||
logger.error(
|
||||
"[RebuildIndex] No embedding provider could be initialized. "
|
||||
"Check your config.json. Aborting rebuild."
|
||||
)
|
||||
return 1
|
||||
|
||||
manager = MemoryManager(memory_config, embedding_provider=embedding_provider)
|
||||
result = rebuild_in_process(manager)
|
||||
if not result.ok:
|
||||
logger.error(f"[RebuildIndex] {result.error}")
|
||||
return 1
|
||||
|
||||
logger.info(
|
||||
f"[RebuildIndex] Done. removed={result.removed}, "
|
||||
f"chunks={result.chunks}, files={result.files}"
|
||||
)
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
51
agent/memory/embedding/state.py
Normal file
51
agent/memory/embedding/state.py
Normal file
@@ -0,0 +1,51 @@
|
||||
"""
|
||||
Embedding-related index utilities.
|
||||
|
||||
We don't keep a sidecar state file — the SQLite index is the source of truth
|
||||
and config.json is the source of intent. The two functions below are the
|
||||
only things needing on-disk awareness:
|
||||
|
||||
detect_index_dim : read the dim of stored vectors (display-only)
|
||||
cleanup_legacy_state_file: remove old embedding_state.json from earlier
|
||||
versions; safe no-op when absent.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
|
||||
PathLike = Union[str, os.PathLike]
|
||||
|
||||
|
||||
def detect_index_dim(storage) -> Optional[int]:
|
||||
"""Return the dim of the first stored embedding, or None if the index
|
||||
has no embeddings. Used by /memory status."""
|
||||
try:
|
||||
row = storage.conn.execute(
|
||||
"SELECT embedding FROM chunks WHERE embedding IS NOT NULL LIMIT 1"
|
||||
).fetchone()
|
||||
except Exception:
|
||||
return None
|
||||
if not row or not row["embedding"]:
|
||||
return None
|
||||
try:
|
||||
raw = row["embedding"]
|
||||
if isinstance(raw, (bytes, bytearray)):
|
||||
# New BLOB format: 4 bytes per float32
|
||||
return len(raw) // 4
|
||||
emb = json.loads(raw)
|
||||
return len(emb) if isinstance(emb, list) else None
|
||||
except (json.JSONDecodeError, TypeError, Exception):
|
||||
return None
|
||||
|
||||
|
||||
def cleanup_legacy_state_file(db_path: PathLike) -> None:
|
||||
"""Remove old embedding_state.json files from earlier versions.
|
||||
Safe to call repeatedly; no-op if the file is absent."""
|
||||
legacy = Path(db_path).parent / "embedding_state.json"
|
||||
try:
|
||||
legacy.unlink(missing_ok=True)
|
||||
except Exception:
|
||||
pass
|
||||
@@ -13,7 +13,7 @@ from datetime import datetime, timedelta
|
||||
from agent.memory.config import MemoryConfig, get_default_memory_config
|
||||
from agent.memory.storage import MemoryStorage, MemoryChunk, SearchResult
|
||||
from agent.memory.chunker import TextChunker
|
||||
from agent.memory.embedding import create_embedding_provider, EmbeddingProvider
|
||||
from agent.memory.embedding import EmbeddingProvider, EmbeddingCache
|
||||
from agent.memory.summarizer import MemoryFlushManager, create_memory_files_if_needed
|
||||
|
||||
|
||||
@@ -50,30 +50,22 @@ class MemoryManager:
|
||||
overlap_tokens=self.config.chunk_overlap_tokens
|
||||
)
|
||||
|
||||
# Initialize embedding provider (optional)
|
||||
self.embedding_provider = None
|
||||
if embedding_provider:
|
||||
self.embedding_provider = embedding_provider
|
||||
else:
|
||||
# Try to create embedding provider, but allow failure
|
||||
try:
|
||||
# Get API key from environment or config
|
||||
api_key = os.environ.get('OPENAI_API_KEY')
|
||||
api_base = os.environ.get('OPENAI_API_BASE')
|
||||
|
||||
self.embedding_provider = create_embedding_provider(
|
||||
provider=self.config.embedding_provider,
|
||||
model=self.config.embedding_model,
|
||||
api_key=api_key,
|
||||
api_base=api_base
|
||||
)
|
||||
except Exception as e:
|
||||
# Embedding provider failed, but that's OK
|
||||
# We can still use keyword search and file operations
|
||||
from common.log import logger
|
||||
logger.warning(f"[MemoryManager] Embedding provider initialization failed: {e}")
|
||||
logger.info(f"[MemoryManager] Memory will work with keyword search only (no vector search)")
|
||||
|
||||
# Embedding provider is owned by the caller (agent_initializer is the
|
||||
# canonical entry point and handles legacy/explicit + state validation).
|
||||
# When None is passed, memory degrades to keyword-only search instead
|
||||
# of silently re-initializing a vendor here, which would bypass the
|
||||
# caller's state checks and risk corrupting the index.
|
||||
self.embedding_provider = embedding_provider
|
||||
if self.embedding_provider is None:
|
||||
from common.log import logger
|
||||
logger.info(
|
||||
"[MemoryManager] No embedding provider; memory will use keyword search only"
|
||||
)
|
||||
|
||||
# Cache for query embeddings (avoids redundant API calls within a session)
|
||||
self._embedding_cache = EmbeddingCache()
|
||||
|
||||
|
||||
# Initialize memory flush manager
|
||||
workspace_dir = self.config.get_workspace()
|
||||
self.flush_manager = MemoryFlushManager(
|
||||
@@ -133,12 +125,21 @@ class MemoryManager:
|
||||
if self.config.sync_on_search and self._dirty:
|
||||
await self.sync()
|
||||
|
||||
# Perform vector search (if embedding provider available)
|
||||
from common.log import logger
|
||||
|
||||
# Perform vector search (if embedding provider available).
|
||||
# Failures degrade silently to keyword-only — no exception is raised.
|
||||
vector_results = []
|
||||
if self.embedding_provider:
|
||||
try:
|
||||
from common.log import logger
|
||||
query_embedding = self.embedding_provider.embed(query)
|
||||
provider_name = type(self.embedding_provider).__name__
|
||||
model_name = getattr(self.embedding_provider, 'model', '')
|
||||
cached = self._embedding_cache.get(query, provider_name, model_name)
|
||||
if cached is not None:
|
||||
query_embedding = cached
|
||||
else:
|
||||
query_embedding = self.embedding_provider.embed_query(query)
|
||||
self._embedding_cache.put(query, provider_name, model_name, query_embedding)
|
||||
vector_results = self.storage.search_vector(
|
||||
query_embedding=query_embedding,
|
||||
user_id=user_id,
|
||||
@@ -147,19 +148,19 @@ class MemoryManager:
|
||||
)
|
||||
logger.info(f"[MemoryManager] Vector search found {len(vector_results)} results for query: {query}")
|
||||
except Exception as e:
|
||||
from common.log import logger
|
||||
logger.warning(f"[MemoryManager] Vector search failed: {e}")
|
||||
|
||||
# Perform keyword search
|
||||
logger.error(
|
||||
f"[MemoryManager] Vector search failed, falling back to keyword-only: {e}"
|
||||
)
|
||||
|
||||
# Perform keyword search (also runs as fallback when vector failed)
|
||||
keyword_results = self.storage.search_keyword(
|
||||
query=query,
|
||||
user_id=user_id,
|
||||
scopes=scopes,
|
||||
limit=max_results * 2
|
||||
)
|
||||
from common.log import logger
|
||||
logger.info(f"[MemoryManager] Keyword search found {len(keyword_results)} results for query: {query}")
|
||||
|
||||
|
||||
# Merge results
|
||||
merged = self._merge_results(
|
||||
vector_results,
|
||||
@@ -167,7 +168,7 @@ class MemoryManager:
|
||||
self.config.vector_weight,
|
||||
self.config.keyword_weight
|
||||
)
|
||||
|
||||
|
||||
# Filter by min score and limit
|
||||
filtered = [r for r in merged if r.score >= min_score]
|
||||
return filtered[:max_results]
|
||||
@@ -249,295 +250,195 @@ class MemoryManager:
|
||||
|
||||
async def sync(self, force: bool = False):
|
||||
"""
|
||||
Synchronize memory from files
|
||||
|
||||
Synchronize memory from files.
|
||||
|
||||
Two-pass design to amortize embedding HTTP cost:
|
||||
1. Walk all files, chunk those whose hash changed, collect pending
|
||||
chunks across files. No embedding calls yet.
|
||||
2. Run a single embed_batch over the union of pending chunks (the
|
||||
provider auto-paginates by vendor cap), then persist per-file.
|
||||
|
||||
For workspaces with many small files (101 files / ~1 chunk each), this
|
||||
cuts ~100 HTTP calls down to ~ceil(total_chunks / vendor_cap).
|
||||
|
||||
Args:
|
||||
force: Force full reindex
|
||||
"""
|
||||
memory_dir = self.config.get_memory_dir()
|
||||
workspace_dir = self.config.get_workspace()
|
||||
|
||||
# Scan MEMORY.md (workspace root)
|
||||
|
||||
files_to_scan: List[tuple] = [] # (file_path, source, scope, user_id)
|
||||
|
||||
memory_file = Path(workspace_dir) / "MEMORY.md"
|
||||
if memory_file.exists():
|
||||
await self._sync_file(memory_file, "memory", "shared", None)
|
||||
|
||||
# Scan memory directory (including daily summaries)
|
||||
files_to_scan.append((memory_file, "memory", "shared", None))
|
||||
|
||||
if memory_dir.exists():
|
||||
for file_path in memory_dir.rglob("*.md"):
|
||||
# Determine scope and user_id from path
|
||||
rel_path = file_path.relative_to(workspace_dir)
|
||||
parts = rel_path.parts
|
||||
|
||||
# Check if it's in daily summary directory
|
||||
if "daily" in parts:
|
||||
# Daily summary files
|
||||
if "users" in parts or len(parts) > 3:
|
||||
# User-scoped daily summary: memory/daily/{user_id}/2024-01-29.md
|
||||
user_idx = parts.index("daily") + 1
|
||||
user_id = parts[user_idx] if user_idx < len(parts) else None
|
||||
rel_parts = file_path.relative_to(workspace_dir).parts
|
||||
if any(part.startswith('.') for part in rel_parts):
|
||||
continue
|
||||
# Dream diaries are narrative reflections produced by Deep
|
||||
# Dream; their factual content has already been distilled
|
||||
# into MEMORY.md. Indexing them adds noisy near-duplicates
|
||||
# that crowd out the authoritative entry in retrieval.
|
||||
if "dreams" in rel_parts:
|
||||
continue
|
||||
if "daily" in rel_parts:
|
||||
if "users" in rel_parts or len(rel_parts) > 3:
|
||||
user_idx = rel_parts.index("daily") + 1
|
||||
user_id = rel_parts[user_idx] if user_idx < len(rel_parts) else None
|
||||
scope = "user"
|
||||
else:
|
||||
# Shared daily summary: memory/daily/2024-01-29.md
|
||||
user_id = None
|
||||
scope = "shared"
|
||||
elif "users" in parts:
|
||||
# User-scoped memory
|
||||
user_idx = parts.index("users") + 1
|
||||
user_id = parts[user_idx] if user_idx < len(parts) else None
|
||||
elif "users" in rel_parts:
|
||||
user_idx = rel_parts.index("users") + 1
|
||||
user_id = rel_parts[user_idx] if user_idx < len(rel_parts) else None
|
||||
scope = "user"
|
||||
else:
|
||||
# Shared memory
|
||||
user_id = None
|
||||
scope = "shared"
|
||||
|
||||
await self._sync_file(file_path, "memory", scope, user_id)
|
||||
|
||||
self._dirty = False
|
||||
|
||||
async def _sync_file(
|
||||
self,
|
||||
file_path: Path,
|
||||
source: str,
|
||||
scope: str,
|
||||
user_id: Optional[str]
|
||||
):
|
||||
"""Sync a single file"""
|
||||
# Compute file hash
|
||||
content = file_path.read_text()
|
||||
file_hash = MemoryStorage.compute_hash(content)
|
||||
|
||||
# Get relative path
|
||||
workspace_dir = self.config.get_workspace()
|
||||
rel_path = str(file_path.relative_to(workspace_dir))
|
||||
|
||||
# Check if file changed
|
||||
stored_hash = self.storage.get_file_hash(rel_path)
|
||||
if stored_hash == file_hash:
|
||||
return # No changes
|
||||
|
||||
# Delete old chunks
|
||||
self.storage.delete_by_path(rel_path)
|
||||
|
||||
# Chunk and embed
|
||||
chunks = self.chunker.chunk_text(content)
|
||||
if not chunks:
|
||||
return
|
||||
|
||||
texts = [chunk.text for chunk in chunks]
|
||||
if self.embedding_provider:
|
||||
embeddings = self.embedding_provider.embed_batch(texts)
|
||||
else:
|
||||
embeddings = [None] * len(texts)
|
||||
|
||||
# Create memory chunks
|
||||
memory_chunks = []
|
||||
for chunk, embedding in zip(chunks, embeddings):
|
||||
chunk_id = self._generate_chunk_id(rel_path, chunk.start_line, chunk.end_line)
|
||||
chunk_hash = MemoryStorage.compute_hash(chunk.text)
|
||||
|
||||
memory_chunks.append(MemoryChunk(
|
||||
id=chunk_id,
|
||||
user_id=user_id,
|
||||
scope=scope,
|
||||
source=source,
|
||||
path=rel_path,
|
||||
start_line=chunk.start_line,
|
||||
end_line=chunk.end_line,
|
||||
text=chunk.text,
|
||||
embedding=embedding,
|
||||
hash=chunk_hash,
|
||||
metadata=None
|
||||
))
|
||||
|
||||
# Save
|
||||
self.storage.save_chunks_batch(memory_chunks)
|
||||
|
||||
# Update file metadata
|
||||
stat = file_path.stat()
|
||||
self.storage.update_file_metadata(
|
||||
path=rel_path,
|
||||
source=source,
|
||||
file_hash=file_hash,
|
||||
mtime=int(stat.st_mtime),
|
||||
size=stat.st_size
|
||||
)
|
||||
|
||||
def should_flush_memory(
|
||||
self,
|
||||
current_tokens: int = 0
|
||||
) -> bool:
|
||||
"""
|
||||
Check if memory flush should be triggered
|
||||
|
||||
独立的 flush 触发机制,不依赖模型 context window。
|
||||
使用配置中的阈值: flush_token_threshold 和 flush_turn_threshold
|
||||
|
||||
Args:
|
||||
current_tokens: Current session token count
|
||||
|
||||
Returns:
|
||||
True if memory flush should run
|
||||
"""
|
||||
return self.flush_manager.should_flush(
|
||||
current_tokens=current_tokens,
|
||||
token_threshold=self.config.flush_token_threshold,
|
||||
turn_threshold=self.config.flush_turn_threshold
|
||||
)
|
||||
|
||||
def increment_turn(self):
|
||||
"""增加对话轮数计数(每次用户消息+AI回复算一轮)"""
|
||||
self.flush_manager.increment_turn()
|
||||
|
||||
async def execute_memory_flush(
|
||||
self,
|
||||
agent_executor,
|
||||
current_tokens: int,
|
||||
user_id: Optional[str] = None,
|
||||
**executor_kwargs
|
||||
) -> bool:
|
||||
"""
|
||||
Execute memory flush before compaction
|
||||
|
||||
This runs a silent agent turn to write durable memories to disk.
|
||||
Similar to clawdbot's pre-compaction memory flush.
|
||||
|
||||
Args:
|
||||
agent_executor: Async function to execute agent with prompt
|
||||
current_tokens: Current session token count
|
||||
user_id: Optional user ID
|
||||
**executor_kwargs: Additional kwargs for agent executor
|
||||
|
||||
Returns:
|
||||
True if flush completed successfully
|
||||
|
||||
Example:
|
||||
>>> async def run_agent(prompt, system_prompt, silent=False):
|
||||
... # Your agent execution logic
|
||||
... pass
|
||||
>>>
|
||||
>>> if manager.should_flush_memory(current_tokens=100000):
|
||||
... await manager.execute_memory_flush(
|
||||
... agent_executor=run_agent,
|
||||
... current_tokens=100000
|
||||
... )
|
||||
"""
|
||||
success = await self.flush_manager.execute_flush(
|
||||
agent_executor=agent_executor,
|
||||
current_tokens=current_tokens,
|
||||
user_id=user_id,
|
||||
**executor_kwargs
|
||||
)
|
||||
|
||||
if success:
|
||||
# Mark dirty so next search will sync the new memories
|
||||
self._dirty = True
|
||||
|
||||
return success
|
||||
|
||||
def build_memory_guidance(self, lang: str = "zh", include_context: bool = True) -> str:
|
||||
"""
|
||||
Build natural memory guidance for agent system prompt
|
||||
|
||||
Following clawdbot's approach:
|
||||
1. Load MEMORY.md as bootstrap context (blends into background)
|
||||
2. Load daily files on-demand via memory_search tool
|
||||
3. Agent should NOT proactively mention memories unless user asks
|
||||
|
||||
Args:
|
||||
lang: Language for guidance ("en" or "zh")
|
||||
include_context: Whether to include bootstrap memory context (default: True)
|
||||
MEMORY.md is loaded as background context (like clawdbot)
|
||||
Daily files are accessed via memory_search tool
|
||||
|
||||
Returns:
|
||||
Memory guidance text (and optionally context) for system prompt
|
||||
"""
|
||||
today_file = self.flush_manager.get_today_memory_file().name
|
||||
|
||||
if lang == "zh":
|
||||
guidance = f"""## 记忆系统
|
||||
files_to_scan.append((file_path, "memory", scope, user_id))
|
||||
|
||||
**背景知识**: 下方包含核心长期记忆,可直接使用。需要查找历史时,用 memory_search 搜索(搜索一次即可,不要重复)。
|
||||
from config import conf
|
||||
if conf().get("knowledge", True):
|
||||
knowledge_dir = Path(workspace_dir) / "knowledge"
|
||||
if knowledge_dir.exists():
|
||||
for file_path in knowledge_dir.rglob("*.md"):
|
||||
files_to_scan.append((file_path, "knowledge", "shared", None))
|
||||
|
||||
**存储记忆**: 当用户分享重要信息时(偏好、决策、事实等),主动用 write 工具存储:
|
||||
- 长期信息 → MEMORY.md
|
||||
- 当天笔记 → memory/{today_file}
|
||||
- 静默存储,仅在明确要求时确认
|
||||
|
||||
**使用原则**: 自然使用记忆,就像你本来就知道。不需要生硬地提起或列举记忆,除非用户提到。"""
|
||||
else:
|
||||
guidance = f"""## Memory System
|
||||
|
||||
**Background Knowledge**: Core long-term memories below - use directly. For history, use memory_search once (don't repeat).
|
||||
|
||||
**Store Memories**: When user shares important info (preferences, decisions, facts), proactively write:
|
||||
- Durable info → MEMORY.md
|
||||
- Daily notes → memory/{today_file}
|
||||
- Store silently; confirm only when explicitly requested
|
||||
|
||||
**Usage**: Use memories naturally as if you always knew. Don't mention or list unless user explicitly asks."""
|
||||
|
||||
if include_context:
|
||||
# Load bootstrap context (MEMORY.md only, like clawdbot)
|
||||
bootstrap_context = self.load_bootstrap_memories()
|
||||
if bootstrap_context:
|
||||
guidance += f"\n\n## Background Context\n\n{bootstrap_context}"
|
||||
|
||||
return guidance
|
||||
|
||||
def load_bootstrap_memories(self, user_id: Optional[str] = None) -> str:
|
||||
"""
|
||||
Load bootstrap memory files for session start
|
||||
|
||||
Following clawdbot's design:
|
||||
- Only loads MEMORY.md from workspace root (long-term curated memory)
|
||||
- Daily files (memory/YYYY-MM-DD.md) are accessed via memory_search tool, not bootstrap
|
||||
- User-specific MEMORY.md is also loaded if user_id provided
|
||||
|
||||
Returns memory content WITHOUT obvious headers so it blends naturally
|
||||
into the context as background knowledge.
|
||||
|
||||
Args:
|
||||
user_id: Optional user ID for user-specific memories
|
||||
|
||||
Returns:
|
||||
Memory content to inject into system prompt (blends naturally as background context)
|
||||
"""
|
||||
workspace_dir = self.config.get_workspace()
|
||||
memory_dir = self.config.get_memory_dir()
|
||||
|
||||
sections = []
|
||||
|
||||
# 1. Load MEMORY.md from workspace root (long-term curated memory)
|
||||
# Following clawdbot: only MEMORY.md is bootstrap, daily files use memory_search
|
||||
memory_file = Path(workspace_dir) / "MEMORY.md"
|
||||
if memory_file.exists():
|
||||
# Pass 1: inline chunking + change detection. Inlined (instead of
|
||||
# calling self._prepare_file_for_sync) so this method does not depend
|
||||
# on any sibling helpers — keeps it robust against partial reloads
|
||||
# where the class object is older than the method's source.
|
||||
pending: List[Dict[str, Any]] = []
|
||||
workspace_dir_path = self.config.get_workspace()
|
||||
for file_path, source, scope, user_id in files_to_scan:
|
||||
try:
|
||||
content = memory_file.read_text(encoding='utf-8').strip()
|
||||
if content:
|
||||
sections.append(content)
|
||||
content = file_path.read_text(encoding='utf-8')
|
||||
except Exception:
|
||||
continue
|
||||
file_hash = MemoryStorage.compute_hash(content)
|
||||
rel_path = str(file_path.relative_to(workspace_dir_path))
|
||||
if self.storage.get_file_hash(rel_path) == file_hash:
|
||||
continue
|
||||
chunks = self.chunker.chunk_text(content)
|
||||
if not chunks:
|
||||
continue
|
||||
pending.append({
|
||||
"file_path": file_path,
|
||||
"rel_path": rel_path,
|
||||
"source": source,
|
||||
"scope": scope,
|
||||
"user_id": user_id,
|
||||
"file_hash": file_hash,
|
||||
"chunks": chunks,
|
||||
"texts": [c.text for c in chunks],
|
||||
})
|
||||
|
||||
if not pending:
|
||||
self._dirty = False
|
||||
return
|
||||
|
||||
# Pass 2: single batched embed across all pending chunks.
|
||||
# CRITICAL: never touch the index until we hold valid embeddings.
|
||||
# If embed_batch fails, leave the existing index intact (chunks +
|
||||
# file_hash) so the next sync will retry the same files. Writing
|
||||
# NULL embeddings + updating file_hash here would mark the file as
|
||||
# "successfully synced" and silently strand it without vectors.
|
||||
all_texts: List[str] = []
|
||||
for entry in pending:
|
||||
all_texts.extend(entry["texts"])
|
||||
|
||||
if not self.embedding_provider:
|
||||
# No provider configured at all (legacy keyword-only). Persist
|
||||
# chunks without embeddings — this is the user's intent.
|
||||
all_embeddings: List[Optional[List[float]]] = [None] * len(all_texts)
|
||||
else:
|
||||
try:
|
||||
all_embeddings = self.embedding_provider.embed_batch(all_texts)
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to read MEMORY.md: {e}")
|
||||
|
||||
# 2. Load user-specific MEMORY.md if user_id provided
|
||||
if user_id:
|
||||
user_memory_dir = memory_dir / "users" / user_id
|
||||
user_memory_file = user_memory_dir / "MEMORY.md"
|
||||
if user_memory_file.exists():
|
||||
try:
|
||||
content = user_memory_file.read_text(encoding='utf-8').strip()
|
||||
if content:
|
||||
sections.append(content)
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to read user memory: {e}")
|
||||
|
||||
if not sections:
|
||||
return ""
|
||||
|
||||
# Join sections without obvious headers - let memories blend naturally
|
||||
# This makes the agent feel like it "just knows" rather than "checking memory files"
|
||||
return "\n\n".join(sections)
|
||||
from common.log import logger
|
||||
logger.error(
|
||||
f"[MemoryManager] Batch embedding failed for {len(all_texts)} "
|
||||
f"chunks across {len(pending)} files: {e}. "
|
||||
f"Index left untouched; will retry on next sync."
|
||||
)
|
||||
# Bail before touching storage. self._dirty stays True so
|
||||
# callers know there is pending work.
|
||||
return
|
||||
|
||||
# Pass 3: inline persist — same self-contained reasoning as Pass 1.
|
||||
cursor = 0
|
||||
for entry in pending:
|
||||
n = len(entry["texts"])
|
||||
entry_embeddings = all_embeddings[cursor:cursor + n]
|
||||
cursor += n
|
||||
|
||||
rel_path = entry["rel_path"]
|
||||
self.storage.delete_by_path(rel_path)
|
||||
memory_chunks = []
|
||||
for chunk, embedding in zip(entry["chunks"], entry_embeddings):
|
||||
chunk_id = self._generate_chunk_id(rel_path, chunk.start_line, chunk.end_line)
|
||||
chunk_hash = MemoryStorage.compute_hash(chunk.text)
|
||||
memory_chunks.append(MemoryChunk(
|
||||
id=chunk_id,
|
||||
user_id=entry["user_id"],
|
||||
scope=entry["scope"],
|
||||
source=entry["source"],
|
||||
path=rel_path,
|
||||
start_line=chunk.start_line,
|
||||
end_line=chunk.end_line,
|
||||
text=chunk.text,
|
||||
embedding=embedding,
|
||||
hash=chunk_hash,
|
||||
metadata=None,
|
||||
))
|
||||
self.storage.save_chunks_batch(memory_chunks)
|
||||
stat = entry["file_path"].stat()
|
||||
self.storage.update_file_metadata(
|
||||
path=rel_path,
|
||||
source=entry["source"],
|
||||
file_hash=entry["file_hash"],
|
||||
mtime=int(stat.st_mtime),
|
||||
size=stat.st_size,
|
||||
)
|
||||
|
||||
self._dirty = False
|
||||
|
||||
def flush_memory(
|
||||
self,
|
||||
messages: list,
|
||||
user_id: Optional[str] = None,
|
||||
reason: str = "threshold",
|
||||
max_messages: int = 10,
|
||||
context_summary_callback=None,
|
||||
) -> bool:
|
||||
"""
|
||||
Flush conversation summary to daily memory file.
|
||||
|
||||
Args:
|
||||
messages: Conversation message list
|
||||
user_id: Optional user ID
|
||||
reason: "threshold" | "overflow" | "daily_summary"
|
||||
max_messages: Max recent messages to include (0 = all)
|
||||
context_summary_callback: Optional callback(str) invoked with the
|
||||
daily summary text for in-context injection
|
||||
|
||||
Returns:
|
||||
True if flush was dispatched
|
||||
"""
|
||||
success = self.flush_manager.flush_from_messages(
|
||||
messages=messages,
|
||||
user_id=user_id,
|
||||
reason=reason,
|
||||
max_messages=max_messages,
|
||||
context_summary_callback=context_summary_callback,
|
||||
)
|
||||
if success:
|
||||
self._dirty = True
|
||||
return success
|
||||
|
||||
def get_status(self) -> Dict[str, Any]:
|
||||
"""Get memory status"""
|
||||
@@ -568,6 +469,37 @@ class MemoryManager:
|
||||
content = f"{path}:{start_line}:{end_line}"
|
||||
return hashlib.md5(content.encode('utf-8')).hexdigest()
|
||||
|
||||
@staticmethod
|
||||
def _compute_temporal_decay(path: str, half_life_days: float = 30.0) -> float:
|
||||
"""
|
||||
Compute temporal decay multiplier for dated memory files.
|
||||
|
||||
Inspired by OpenClaw's temporal-decay: exponential decay based on file date.
|
||||
MEMORY.md and non-dated files are "evergreen" (no decay, multiplier=1.0).
|
||||
Daily files like memory/2025-03-01.md decay based on age.
|
||||
|
||||
Formula: multiplier = exp(-ln2/half_life * age_in_days)
|
||||
"""
|
||||
import re
|
||||
import math
|
||||
|
||||
match = re.search(r'(\d{4})-(\d{2})-(\d{2})\.md$', path)
|
||||
if not match:
|
||||
return 1.0 # evergreen: MEMORY.md, non-dated files
|
||||
|
||||
try:
|
||||
file_date = datetime(
|
||||
int(match.group(1)), int(match.group(2)), int(match.group(3))
|
||||
)
|
||||
age_days = (datetime.now() - file_date).days
|
||||
if age_days <= 0:
|
||||
return 1.0
|
||||
|
||||
decay_lambda = math.log(2) / half_life_days
|
||||
return math.exp(-decay_lambda * age_days)
|
||||
except (ValueError, OverflowError):
|
||||
return 1.0
|
||||
|
||||
def _merge_results(
|
||||
self,
|
||||
vector_results: List[SearchResult],
|
||||
@@ -575,8 +507,7 @@ class MemoryManager:
|
||||
vector_weight: float,
|
||||
keyword_weight: float
|
||||
) -> List[SearchResult]:
|
||||
"""Merge vector and keyword search results"""
|
||||
# Create a map by (path, start_line, end_line)
|
||||
"""Merge vector and keyword search results with temporal decay for dated files"""
|
||||
merged_map = {}
|
||||
|
||||
for result in vector_results:
|
||||
@@ -598,7 +529,6 @@ class MemoryManager:
|
||||
'keyword_score': result.score
|
||||
}
|
||||
|
||||
# Calculate combined scores
|
||||
merged_results = []
|
||||
for entry in merged_map.values():
|
||||
combined_score = (
|
||||
@@ -606,7 +536,11 @@ class MemoryManager:
|
||||
keyword_weight * entry['keyword_score']
|
||||
)
|
||||
|
||||
# Apply temporal decay for dated memory files
|
||||
result = entry['result']
|
||||
decay = self._compute_temporal_decay(result.path)
|
||||
combined_score *= decay
|
||||
|
||||
merged_results.append(SearchResult(
|
||||
path=result.path,
|
||||
start_line=result.start_line,
|
||||
@@ -617,6 +551,5 @@ class MemoryManager:
|
||||
user_id=result.user_id
|
||||
))
|
||||
|
||||
# Sort by score
|
||||
merged_results.sort(key=lambda r: r.score, reverse=True)
|
||||
return merged_results
|
||||
|
||||
14
agent/memory/rebuild_index.py
Normal file
14
agent/memory/rebuild_index.py
Normal file
@@ -0,0 +1,14 @@
|
||||
"""
|
||||
Backward-compatible shim for the legacy entry point:
|
||||
python -m agent.memory.rebuild_index
|
||||
|
||||
The implementation now lives in agent.memory.embedding.rebuild.
|
||||
Prefer using `/memory rebuild-index` in chat going forward.
|
||||
"""
|
||||
|
||||
from agent.memory.embedding.rebuild import main
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
|
||||
sys.exit(main())
|
||||
197
agent/memory/service.py
Normal file
197
agent/memory/service.py
Normal file
@@ -0,0 +1,197 @@
|
||||
"""
|
||||
Memory service for handling memory query operations via cloud protocol.
|
||||
|
||||
Provides a unified interface for listing and reading memory files,
|
||||
callable from the cloud client (LinkAI) or a future web console.
|
||||
|
||||
Memory file layout (under workspace_root):
|
||||
MEMORY.md -> type: global
|
||||
memory/2026-02-20.md -> type: daily
|
||||
"""
|
||||
|
||||
import os
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional
|
||||
from pathlib import Path
|
||||
from common.log import logger
|
||||
|
||||
|
||||
class MemoryService:
|
||||
"""
|
||||
High-level service for memory file queries.
|
||||
Operates directly on the filesystem — no MemoryManager dependency.
|
||||
"""
|
||||
|
||||
def __init__(self, workspace_root: str):
|
||||
"""
|
||||
:param workspace_root: Workspace root directory (e.g. ~/cow)
|
||||
"""
|
||||
self.workspace_root = workspace_root
|
||||
self.memory_dir = os.path.join(workspace_root, "memory")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# list — paginated file metadata
|
||||
# ------------------------------------------------------------------
|
||||
def list_files(self, page: int = 1, page_size: int = 20, category: str = "memory") -> dict:
|
||||
"""
|
||||
List memory or dream files with metadata (without content).
|
||||
|
||||
Args:
|
||||
category: ``"memory"`` (default) — MEMORY.md + daily files;
|
||||
``"dream"`` — dream diary files from memory/dreams/
|
||||
"""
|
||||
if category == "dream":
|
||||
files = self._list_dream_files()
|
||||
else:
|
||||
files = self._list_memory_files()
|
||||
|
||||
total = len(files)
|
||||
start = (page - 1) * page_size
|
||||
end = start + page_size
|
||||
|
||||
return {
|
||||
"page": page,
|
||||
"page_size": page_size,
|
||||
"total": total,
|
||||
"list": files[start:end],
|
||||
}
|
||||
|
||||
def _list_memory_files(self) -> List[dict]:
|
||||
"""MEMORY.md + memory/*.md (newest first)."""
|
||||
files: List[dict] = []
|
||||
|
||||
global_path = os.path.join(self.workspace_root, "MEMORY.md")
|
||||
if os.path.isfile(global_path):
|
||||
files.append(self._file_info(global_path, "MEMORY.md", "global"))
|
||||
|
||||
if os.path.isdir(self.memory_dir):
|
||||
daily_files = []
|
||||
for name in os.listdir(self.memory_dir):
|
||||
full = os.path.join(self.memory_dir, name)
|
||||
if os.path.isfile(full) and name.endswith(".md"):
|
||||
daily_files.append((name, full))
|
||||
daily_files.sort(key=lambda x: x[0], reverse=True)
|
||||
for name, full in daily_files:
|
||||
files.append(self._file_info(full, name, "daily"))
|
||||
|
||||
return files
|
||||
|
||||
def _list_dream_files(self) -> List[dict]:
|
||||
"""memory/dreams/*.md (newest first)."""
|
||||
files: List[dict] = []
|
||||
dreams_dir = os.path.join(self.memory_dir, "dreams")
|
||||
|
||||
if os.path.isdir(dreams_dir):
|
||||
entries = []
|
||||
for name in os.listdir(dreams_dir):
|
||||
full = os.path.join(dreams_dir, name)
|
||||
if os.path.isfile(full) and name.endswith(".md"):
|
||||
entries.append((name, full))
|
||||
entries.sort(key=lambda x: x[0], reverse=True)
|
||||
for name, full in entries:
|
||||
files.append(self._file_info(full, name, "dream"))
|
||||
|
||||
return files
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# content — read a single file
|
||||
# ------------------------------------------------------------------
|
||||
def get_content(self, filename: str, category: str = "memory") -> dict:
|
||||
"""
|
||||
Read the full content of a memory or dream file.
|
||||
|
||||
:param filename: File name, e.g. ``MEMORY.md``, ``2026-02-20.md``
|
||||
:param category: ``"memory"`` or ``"dream"``
|
||||
:return: dict with ``filename`` and ``content``
|
||||
:raises FileNotFoundError: if the file does not exist
|
||||
"""
|
||||
path = self._resolve_path(filename, category)
|
||||
if not os.path.isfile(path):
|
||||
raise FileNotFoundError(f"Memory file not found: {filename}")
|
||||
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
content = f.read()
|
||||
|
||||
return {
|
||||
"filename": filename,
|
||||
"content": content,
|
||||
}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# dispatch — single entry point for protocol messages
|
||||
# ------------------------------------------------------------------
|
||||
def dispatch(self, action: str, payload: Optional[dict] = None) -> dict:
|
||||
"""
|
||||
Dispatch a memory management action.
|
||||
|
||||
:param action: ``list`` or ``content``
|
||||
:param payload: action-specific payload (supports ``category``: ``"memory"`` | ``"dream"``)
|
||||
:return: protocol-compatible response dict
|
||||
"""
|
||||
payload = payload or {}
|
||||
try:
|
||||
if action == "list":
|
||||
page = payload.get("page", 1)
|
||||
page_size = payload.get("page_size", 20)
|
||||
category = payload.get("category", "memory")
|
||||
result_payload = self.list_files(page=page, page_size=page_size, category=category)
|
||||
return {"action": action, "code": 200, "message": "success", "payload": result_payload}
|
||||
|
||||
elif action == "content":
|
||||
filename = payload.get("filename")
|
||||
if not filename:
|
||||
return {"action": action, "code": 400, "message": "filename is required", "payload": None}
|
||||
category = payload.get("category", "memory")
|
||||
result_payload = self.get_content(filename, category=category)
|
||||
return {"action": action, "code": 200, "message": "success", "payload": result_payload}
|
||||
|
||||
else:
|
||||
return {"action": action, "code": 400, "message": f"unknown action: {action}", "payload": None}
|
||||
|
||||
except ValueError as e:
|
||||
return {"action": action, "code": 403, "message": "invalid filename", "payload": None}
|
||||
except FileNotFoundError as e:
|
||||
return {"action": action, "code": 404, "message": str(e), "payload": None}
|
||||
except Exception as e:
|
||||
logger.error(f"[MemoryService] dispatch error: action={action}, error={e}")
|
||||
return {"action": action, "code": 500, "message": str(e), "payload": None}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# internal helpers
|
||||
# ------------------------------------------------------------------
|
||||
def _resolve_path(self, filename: str, category: str = "memory") -> str:
|
||||
"""
|
||||
Safely resolve a filename to its absolute path within the allowed directory.
|
||||
|
||||
- ``MEMORY.md`` → ``{workspace_root}/MEMORY.md``
|
||||
- ``2026-02-20.md`` (memory) → ``{workspace_root}/memory/2026-02-20.md``
|
||||
- ``2026-02-20.md`` (dream) → ``{workspace_root}/memory/dreams/2026-02-20.md``
|
||||
|
||||
Raises ValueError if the resolved path escapes the allowed directory.
|
||||
"""
|
||||
if filename == "MEMORY.md":
|
||||
base_dir = self.workspace_root
|
||||
elif category == "dream":
|
||||
base_dir = os.path.join(self.memory_dir, "dreams")
|
||||
else:
|
||||
base_dir = self.memory_dir
|
||||
|
||||
resolved = os.path.realpath(os.path.join(base_dir, filename))
|
||||
allowed = os.path.realpath(base_dir)
|
||||
|
||||
if resolved != allowed and not resolved.startswith(allowed + os.sep):
|
||||
raise ValueError(f"Invalid filename: path traversal detected")
|
||||
|
||||
return resolved
|
||||
|
||||
@staticmethod
|
||||
def _file_info(path: str, filename: str, file_type: str) -> dict:
|
||||
"""Build a file metadata dict."""
|
||||
stat = os.stat(path)
|
||||
updated_at = datetime.fromtimestamp(stat.st_mtime).strftime("%Y-%m-%d %H:%M:%S")
|
||||
return {
|
||||
"filename": filename,
|
||||
"type": file_type,
|
||||
"size": stat.st_size,
|
||||
"updated_at": updated_at,
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -4,22 +4,24 @@ System Prompt Builder - 系统提示词构建器
|
||||
实现模块化的系统提示词构建,支持工具、技能、记忆等多个子系统
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import os
|
||||
from typing import List, Dict, Optional, Any
|
||||
from dataclasses import dataclass
|
||||
|
||||
from common.log import logger
|
||||
from config import conf
|
||||
|
||||
|
||||
@dataclass
|
||||
class ContextFile:
|
||||
"""上下文文件"""
|
||||
"""A context file (path + content)."""
|
||||
path: str
|
||||
content: str
|
||||
|
||||
|
||||
class PromptBuilder:
|
||||
"""提示词构建器"""
|
||||
"""System prompt builder."""
|
||||
|
||||
def __init__(self, workspace_dir: str, language: str = "zh"):
|
||||
"""
|
||||
@@ -41,21 +43,19 @@ class PromptBuilder:
|
||||
skill_manager: Any = None,
|
||||
memory_manager: Any = None,
|
||||
runtime_info: Optional[Dict[str, Any]] = None,
|
||||
is_first_conversation: bool = False,
|
||||
**kwargs
|
||||
) -> str:
|
||||
"""
|
||||
构建完整的系统提示词
|
||||
|
||||
Args:
|
||||
base_persona: 基础人格描述(会被context_files中的SOUL.md覆盖)
|
||||
base_persona: 基础人格描述(会被context_files中的AGENT.md覆盖)
|
||||
user_identity: 用户身份信息
|
||||
tools: 工具列表
|
||||
context_files: 上下文文件列表(SOUL.md, USER.md, README.md等)
|
||||
context_files: 上下文文件列表(AGENT.md, USER.md, RULE.md, BOOTSTRAP.md等)
|
||||
skill_manager: 技能管理器
|
||||
memory_manager: 记忆管理器
|
||||
runtime_info: 运行时信息
|
||||
is_first_conversation: 是否为首次对话
|
||||
**kwargs: 其他参数
|
||||
|
||||
Returns:
|
||||
@@ -71,7 +71,6 @@ class PromptBuilder:
|
||||
skill_manager=skill_manager,
|
||||
memory_manager=memory_manager,
|
||||
runtime_info=runtime_info,
|
||||
is_first_conversation=is_first_conversation,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
@@ -86,175 +85,213 @@ def build_agent_system_prompt(
|
||||
skill_manager: Any = None,
|
||||
memory_manager: Any = None,
|
||||
runtime_info: Optional[Dict[str, Any]] = None,
|
||||
is_first_conversation: bool = False,
|
||||
**kwargs
|
||||
) -> str:
|
||||
"""
|
||||
构建Agent系统提示词
|
||||
|
||||
顺序说明(按重要性和逻辑关系排列):
|
||||
1. 工具系统 - 核心能力,最先介绍
|
||||
2. 技能系统 - 紧跟工具,因为技能需要用 read 工具读取
|
||||
3. 记忆系统 - 独立的记忆能力
|
||||
4. 工作空间 - 工作环境说明
|
||||
5. 用户身份 - 用户信息(可选)
|
||||
6. 项目上下文 - SOUL.md, USER.md, AGENTS.md(定义人格和身份)
|
||||
7. 运行时信息 - 元信息(时间、模型等)
|
||||
|
||||
Build the agent system prompt.
|
||||
|
||||
Section order (by importance and logical flow):
|
||||
1. Tooling - core capabilities, introduced first
|
||||
2. Skills - right after tools, since skills are read via the read tool
|
||||
3. Memory - memory recall and writing guidance
|
||||
3.5 Knowledge - structured knowledge base (injects knowledge/index.md)
|
||||
4. Workspace - working environment description
|
||||
5. User identity - user info (optional)
|
||||
6. Project context - AGENT.md, USER.md, RULE.md, MEMORY.md, BOOTSTRAP.md
|
||||
7. Runtime info - meta info (time, model, etc.)
|
||||
|
||||
Args:
|
||||
workspace_dir: 工作空间目录
|
||||
language: 语言 ("zh" 或 "en")
|
||||
base_persona: 基础人格描述(已废弃,由SOUL.md定义)
|
||||
user_identity: 用户身份信息
|
||||
tools: 工具列表
|
||||
context_files: 上下文文件列表
|
||||
skill_manager: 技能管理器
|
||||
memory_manager: 记忆管理器
|
||||
runtime_info: 运行时信息
|
||||
is_first_conversation: 是否为首次对话
|
||||
**kwargs: 其他参数
|
||||
|
||||
workspace_dir: workspace directory
|
||||
language: language ("zh" or "en")
|
||||
base_persona: base persona description (deprecated, defined by AGENT.md)
|
||||
user_identity: user identity info
|
||||
tools: tool list
|
||||
context_files: context file list
|
||||
skill_manager: skill manager
|
||||
memory_manager: memory manager
|
||||
runtime_info: runtime info
|
||||
**kwargs: extra args
|
||||
|
||||
Returns:
|
||||
完整的系统提示词
|
||||
The full system prompt.
|
||||
"""
|
||||
sections = []
|
||||
|
||||
# 1. 工具系统(最重要,放在最前面)
|
||||
|
||||
# 1. Tooling (most important, goes first)
|
||||
if tools:
|
||||
sections.extend(_build_tooling_section(tools, language))
|
||||
|
||||
# 2. 技能系统(紧跟工具,因为需要用 read 工具)
|
||||
|
||||
# 2. Skills (right after tools, since they need the read tool)
|
||||
if skill_manager:
|
||||
sections.extend(_build_skills_section(skill_manager, tools, language))
|
||||
|
||||
# 3. 记忆系统(独立的记忆能力)
|
||||
|
||||
# 3. Memory (standalone memory capability)
|
||||
if memory_manager:
|
||||
sections.extend(_build_memory_section(memory_manager, tools, language))
|
||||
|
||||
# 4. 工作空间(工作环境说明)
|
||||
sections.extend(_build_workspace_section(workspace_dir, language, is_first_conversation))
|
||||
|
||||
# 5. 用户身份(如果有)
|
||||
|
||||
# 3.5 Knowledge (structured knowledge base)
|
||||
if conf().get("knowledge", True):
|
||||
sections.extend(_build_knowledge_section(workspace_dir, language))
|
||||
|
||||
# 4. Workspace (working environment description)
|
||||
sections.extend(_build_workspace_section(workspace_dir, language))
|
||||
|
||||
# 5. User identity (if present)
|
||||
if user_identity:
|
||||
sections.extend(_build_user_identity_section(user_identity, language))
|
||||
|
||||
# 6. 项目上下文文件(SOUL.md, USER.md, AGENTS.md - 定义人格)
|
||||
|
||||
# 6. Project context files (AGENT.md, USER.md, RULE.md - define the persona)
|
||||
if context_files:
|
||||
sections.extend(_build_context_files_section(context_files, language))
|
||||
|
||||
# 7. 运行时信息(元信息,放在最后)
|
||||
|
||||
# 7. Runtime info (meta info, goes last)
|
||||
if runtime_info:
|
||||
sections.extend(_build_runtime_section(runtime_info, language))
|
||||
|
||||
|
||||
# 8. Response language (always appended, independent of the skeleton language)
|
||||
sections.extend(_build_response_language_section(language))
|
||||
|
||||
return "\n".join(sections)
|
||||
|
||||
|
||||
def _build_response_language_section(language: str) -> List[str]:
|
||||
"""Response-language rule, appended regardless of the prompt skeleton language.
|
||||
|
||||
Keeps the agent's reply language aligned with the user's input by default,
|
||||
so a Chinese-built prompt still answers an English user in English.
|
||||
"""
|
||||
if language == "en":
|
||||
return [
|
||||
"## 🌐 Response language",
|
||||
"",
|
||||
"By default, reply in the same language as the user's input, "
|
||||
"unless the user explicitly asks for another language.",
|
||||
"",
|
||||
]
|
||||
return [
|
||||
"## 🌐 回复语言",
|
||||
"",
|
||||
"默认使用与用户输入相同的语言回复,除非用户明确要求使用其他语言。",
|
||||
"",
|
||||
]
|
||||
|
||||
|
||||
def _build_identity_section(base_persona: Optional[str], language: str) -> List[str]:
|
||||
"""构建基础身份section - 不再需要,身份由SOUL.md定义"""
|
||||
# 不再生成基础身份section,完全由SOUL.md定义
|
||||
"""Base identity section - no longer needed, identity is defined by AGENT.md."""
|
||||
# Identity is fully defined by AGENT.md, so emit nothing here.
|
||||
return []
|
||||
|
||||
|
||||
def _build_tooling_section(tools: List[Any], language: str) -> List[str]:
|
||||
"""构建工具说明section"""
|
||||
lines = [
|
||||
"## 工具系统",
|
||||
"",
|
||||
"你可以使用以下工具来完成任务。工具名称是大小写敏感的,请严格按照列表中的名称调用。",
|
||||
"",
|
||||
"### 可用工具",
|
||||
"",
|
||||
"""Build tooling section with concise tool list and call style guide."""
|
||||
is_en = language == "en"
|
||||
# One-line summaries for known tools (details are in the tool schema)
|
||||
if is_en:
|
||||
core_summaries = {
|
||||
"read": "read file content",
|
||||
"write": "create or overwrite a file",
|
||||
"edit": "make precise edits to a file",
|
||||
"ls": "list directory contents",
|
||||
"grep": "search file contents",
|
||||
"find": "find files by pattern",
|
||||
"bash": "run shell commands",
|
||||
"terminal": "manage background processes",
|
||||
"web_search": "web search",
|
||||
"web_fetch": "fetch URL content",
|
||||
"browser": "control the browser (screenshot key results or send to the user when help is needed)",
|
||||
"memory_search": "search memory",
|
||||
"memory_get": "read memory content",
|
||||
"env_config": "manage API keys and skill config",
|
||||
"scheduler": "manage scheduled tasks and reminders",
|
||||
"send": "send a local file to the user (local files only; put URLs directly in the reply text)",
|
||||
"vision": "analyze images (recognition, description, OCR, etc.)",
|
||||
}
|
||||
else:
|
||||
core_summaries = {
|
||||
"read": "读取文件内容",
|
||||
"write": "创建或覆盖文件",
|
||||
"edit": "精确编辑文件",
|
||||
"ls": "列出目录内容",
|
||||
"grep": "搜索文件内容",
|
||||
"find": "按模式查找文件",
|
||||
"bash": "执行shell命令",
|
||||
"terminal": "管理后台进程",
|
||||
"web_search": "网络搜索",
|
||||
"web_fetch": "获取URL内容",
|
||||
"browser": "控制浏览器(关键结果或需要协助可截图发送给用户)",
|
||||
"memory_search": "搜索记忆",
|
||||
"memory_get": "读取记忆内容",
|
||||
"env_config": "管理API密钥和技能配置",
|
||||
"scheduler": "管理定时任务和提醒",
|
||||
"send": "发送本地文件给用户(仅限本地文件,URL直接放在回复文本中)",
|
||||
"vision": "分析图片内容(识别、描述、OCR文字提取等)",
|
||||
}
|
||||
|
||||
# Preferred display order
|
||||
tool_order = [
|
||||
"read", "write", "edit", "ls", "grep", "find",
|
||||
"bash", "terminal",
|
||||
"web_search", "web_fetch", "browser",
|
||||
"memory_search", "memory_get",
|
||||
"env_config", "scheduler", "send", "vision",
|
||||
]
|
||||
|
||||
# 工具分类和排序
|
||||
tool_categories = {
|
||||
"文件操作": ["read", "write", "edit", "ls", "grep", "find"],
|
||||
"命令执行": ["bash", "terminal"],
|
||||
"网络搜索": ["web_search", "web_fetch", "browser"],
|
||||
"记忆系统": ["memory_search", "memory_get"],
|
||||
"其他": []
|
||||
}
|
||||
|
||||
# 构建工具映射
|
||||
tool_map = {}
|
||||
tool_descriptions = {
|
||||
"read": "读取文件内容",
|
||||
"write": "创建新文件或完全覆盖现有文件(会删除原内容!追加内容请用 edit)。注意:单次 write 内容不要超过 10KB,超大文件请分步创建",
|
||||
"edit": "精确编辑文件(追加、修改、删除部分内容)",
|
||||
"ls": "列出目录内容",
|
||||
"grep": "在文件中搜索内容",
|
||||
"find": "按照模式查找文件",
|
||||
"bash": "执行shell命令",
|
||||
"terminal": "管理后台进程",
|
||||
"web_search": "网络搜索(使用搜索引擎)",
|
||||
"web_fetch": "获取URL内容",
|
||||
"browser": "控制浏览器",
|
||||
"memory_search": "搜索记忆文件",
|
||||
"memory_get": "获取记忆文件内容",
|
||||
"calculator": "计算器",
|
||||
"current_time": "获取当前时间",
|
||||
}
|
||||
|
||||
|
||||
# Build name -> summary mapping for available tools
|
||||
available = {}
|
||||
for tool in tools:
|
||||
tool_name = tool.name if hasattr(tool, 'name') else str(tool)
|
||||
tool_desc = tool.description if hasattr(tool, 'description') else tool_descriptions.get(tool_name, "")
|
||||
tool_map[tool_name] = tool_desc
|
||||
|
||||
# 按分类添加工具
|
||||
for category, tool_names in tool_categories.items():
|
||||
category_tools = [(name, tool_map.get(name, "")) for name in tool_names if name in tool_map]
|
||||
if category_tools:
|
||||
lines.append(f"**{category}**:")
|
||||
for name, desc in category_tools:
|
||||
if desc:
|
||||
lines.append(f"- `{name}`: {desc}")
|
||||
else:
|
||||
lines.append(f"- `{name}`")
|
||||
del tool_map[name] # 移除已添加的工具
|
||||
lines.append("")
|
||||
|
||||
# 添加其他未分类的工具
|
||||
if tool_map:
|
||||
lines.append("**其他工具**:")
|
||||
for name, desc in sorted(tool_map.items()):
|
||||
if desc:
|
||||
lines.append(f"- `{name}`: {desc}")
|
||||
else:
|
||||
lines.append(f"- `{name}`")
|
||||
lines.append("")
|
||||
|
||||
# 工具使用指南
|
||||
lines.extend([
|
||||
"### 工具调用风格",
|
||||
"",
|
||||
"默认规则: 对于常规、低风险的工具调用,直接调用即可,无需叙述。",
|
||||
"",
|
||||
"需要叙述的情况:",
|
||||
"- 多步骤、复杂的任务",
|
||||
"- 敏感操作(如删除文件)",
|
||||
"- 用户明确要求解释过程",
|
||||
"",
|
||||
"叙述要求: 保持简洁、信息密度高,避免重复显而易见的步骤。",
|
||||
"",
|
||||
"完成标准:",
|
||||
"- 确保用户的需求得到实际解决,而不仅仅是制定计划。",
|
||||
"- 当任务需要多次工具调用时,持续推进直到完成, 解决完后向用户报告结果或回复用户的问题",
|
||||
"- 每次工具调用后,评估是否已获得足够信息来推进或完成任务",
|
||||
"- 避免重复调用相同的工具和相同参数获取相同的信息,除非用户明确要求",
|
||||
"",
|
||||
"**安全提醒**: 回复中涉及密钥、令牌、密码等敏感信息时,必须脱敏处理,禁止直接显示完整内容。",
|
||||
"",
|
||||
])
|
||||
|
||||
name = tool.name if hasattr(tool, 'name') else str(tool)
|
||||
available[name] = core_summaries.get(name, "")
|
||||
|
||||
# Generate tool lines: ordered tools first, then extras
|
||||
tool_lines = []
|
||||
for name in tool_order:
|
||||
if name in available:
|
||||
summary = available.pop(name)
|
||||
tool_lines.append(f"- {name}: {summary}" if summary else f"- {name}")
|
||||
for name in sorted(available):
|
||||
summary = available[name]
|
||||
tool_lines.append(f"- {name}: {summary}" if summary else f"- {name}")
|
||||
|
||||
if is_en:
|
||||
lines = [
|
||||
"## 🔧 Tooling",
|
||||
"",
|
||||
"Available tools (names are case-sensitive, call exactly as listed):",
|
||||
"\n".join(tool_lines),
|
||||
"",
|
||||
"Tool-calling style:",
|
||||
"",
|
||||
"- For multi-step tasks, complex decisions or sensitive operations, briefly explain what you are doing and why, so the user follows key progress",
|
||||
"- Keep going until the task is done, then report the result to the user",
|
||||
"- Always redact secrets, tokens and other sensitive info in replies",
|
||||
"- Put URLs directly in the reply text; the system handles and renders them. Don't download and re-send them via the send tool",
|
||||
"",
|
||||
]
|
||||
else:
|
||||
lines = [
|
||||
"## 🔧 工具系统",
|
||||
"",
|
||||
"可用工具(名称大小写敏感,严格按列表调用):",
|
||||
"\n".join(tool_lines),
|
||||
"",
|
||||
"工具调用风格:",
|
||||
"",
|
||||
"- 多步骤任务、复杂决策、敏感操作时,应简要说明当前在做什么、为什么这样做,让用户了解关键进展",
|
||||
"- 持续推进直到任务完成,完成后向用户报告结果",
|
||||
"- 回复中涉及密钥、令牌等敏感信息必须脱敏",
|
||||
"- URL链接直接放在回复文本中即可,系统会自动处理和渲染。无需下载后使用send工具发送",
|
||||
"",
|
||||
]
|
||||
|
||||
return lines
|
||||
|
||||
|
||||
def _build_skills_section(skill_manager: Any, tools: Optional[List[Any]], language: str) -> List[str]:
|
||||
"""构建技能系统section"""
|
||||
"""Build the skills section."""
|
||||
if not skill_manager:
|
||||
return []
|
||||
|
||||
# 获取read工具名称
|
||||
# Resolve the read tool name
|
||||
read_tool_name = "read"
|
||||
if tools:
|
||||
for tool in tools:
|
||||
@@ -263,186 +300,393 @@ def _build_skills_section(skill_manager: Any, tools: Optional[List[Any]], langua
|
||||
read_tool_name = tool_name
|
||||
break
|
||||
|
||||
lines = [
|
||||
"## 技能系统",
|
||||
"",
|
||||
"在回复之前:扫描下方 <available_skills> 中的 <description> 条目。",
|
||||
"",
|
||||
f"- 如果恰好有一个技能明确适用:使用 `{read_tool_name}` 工具读取其 <location> 路径下的 SKILL.md 文件,然后遵循它",
|
||||
"- 如果多个技能都适用:选择最具体的一个,然后读取并遵循",
|
||||
"- 如果没有明确适用的:不要读取任何 SKILL.md",
|
||||
"",
|
||||
"**约束**: 永远不要一次性读取多个技能;只在选择后再读取。",
|
||||
"",
|
||||
]
|
||||
if language == "en":
|
||||
lines = [
|
||||
"## 🧩 Skills (mandatory)",
|
||||
"",
|
||||
"Before replying: scan the <description> of every skill in <available_skills> below.",
|
||||
"",
|
||||
f"- If a skill's description matches the user's need: use the `{read_tool_name}` tool to read the SKILL.md at its <location> path, then strictly follow the instructions in the file. "
|
||||
"Prefer using a skill when one matches.",
|
||||
"- If multiple skills apply, pick the best-matching one, then read and follow it.",
|
||||
"- If no skill clearly applies: do not read any SKILL.md, just use the general tools.",
|
||||
"",
|
||||
f"**Important**: skills are not tools and cannot be called directly. The only way to use a skill is to read its SKILL.md with `{read_tool_name}`, then act on the file's content. "
|
||||
"Never read multiple skills at once — only read one after selecting it.",
|
||||
"",
|
||||
"Available skills:"
|
||||
]
|
||||
else:
|
||||
lines = [
|
||||
"## 🧩 技能系统(mandatory)",
|
||||
"",
|
||||
"在回复之前:扫描下方 <available_skills> 中每个技能的 <description>。",
|
||||
"",
|
||||
f"- 如果有技能的描述与用户需求匹配:使用 `{read_tool_name}` 工具读取其 <location> 路径的 SKILL.md 文件,然后严格遵循文件中的指令。"
|
||||
"当有匹配的技能时,应优先使用技能",
|
||||
"- 如果多个技能都适用则选择最匹配的一个,然后读取并遵循。",
|
||||
"- 如果没有技能明确适用:不要读取任何 SKILL.md,直接使用通用工具。",
|
||||
"",
|
||||
f"**重要**: 技能不是工具,不能直接调用。使用技能的唯一方式是用 `{read_tool_name}` 读取 SKILL.md 文件,然后按文件内容操作。"
|
||||
"永远不要一次性读取多个技能,只在选择后再读取。",
|
||||
"",
|
||||
"以下是可用技能:"
|
||||
]
|
||||
|
||||
# 添加技能列表(通过skill_manager获取)
|
||||
# Append the skills list (built by skill_manager)
|
||||
try:
|
||||
skills_prompt = skill_manager.build_skills_prompt()
|
||||
logger.debug(f"[PromptBuilder] Skills prompt length: {len(skills_prompt) if skills_prompt else 0}")
|
||||
if skills_prompt:
|
||||
lines.append(skills_prompt.strip())
|
||||
lines.append("")
|
||||
else:
|
||||
logger.warning("[PromptBuilder] No skills prompt generated - skills_prompt is empty")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to build skills prompt: {e}")
|
||||
import traceback
|
||||
logger.debug(f"Skills prompt error traceback: {traceback.format_exc()}")
|
||||
|
||||
return lines
|
||||
|
||||
|
||||
def _build_memory_section(memory_manager: Any, tools: Optional[List[Any]], language: str) -> List[str]:
|
||||
"""构建记忆系统section"""
|
||||
"""Build the memory section."""
|
||||
if not memory_manager:
|
||||
return []
|
||||
|
||||
# 检查是否有memory工具
|
||||
|
||||
has_memory_tools = False
|
||||
if tools:
|
||||
tool_names = [tool.name if hasattr(tool, 'name') else str(tool) for tool in tools]
|
||||
has_memory_tools = any(name in ['memory_search', 'memory_get'] for name in tool_names)
|
||||
|
||||
|
||||
if not has_memory_tools:
|
||||
return []
|
||||
|
||||
lines = [
|
||||
"## 记忆系统",
|
||||
|
||||
from datetime import datetime
|
||||
today_file = datetime.now().strftime("%Y-%m-%d") + ".md"
|
||||
|
||||
if language == "en":
|
||||
lines = [
|
||||
"## 🧠 Memory",
|
||||
"",
|
||||
"### Memory Recall (mandatory)",
|
||||
"",
|
||||
"When the user asks about past events, references an earlier decision, mentions relationships, preferences or to-dos, or when you are unsure about something, **you must search memory before answering**.",
|
||||
"No need to re-search if the info is already in MEMORY.md. Full content and daily memory must be retrieved via tools.",
|
||||
"",
|
||||
"1. Location unknown → `memory_search` (keyword / semantic search)",
|
||||
"2. Location known → `memory_get` to read the exact lines",
|
||||
"3. Search returns nothing → `memory_get` to read the last two days of memory",
|
||||
"",
|
||||
"**Memory file structure**:",
|
||||
"- `MEMORY.md`: long-term memory index (already auto-loaded into context: core info, preferences, decisions, etc.)",
|
||||
f"- `memory/YYYY-MM-DD.md`: daily memory; today is `memory/{today_file}`",
|
||||
"- `knowledge/`: structured knowledge base (see the knowledge system below)",
|
||||
"",
|
||||
"### Writing memory",
|
||||
"",
|
||||
"In the following cases, **proactively** write info to memory files (no need to tell the user):",
|
||||
"",
|
||||
"- The user asks you to remember something, or uses words like \"remember\", \"from now on\", \"always\", \"never\", \"prefer\"",
|
||||
"- The user shares important personal preferences, habits or decisions",
|
||||
"- The conversation produces an important conclusion, plan or agreement",
|
||||
"- A complex task is completed and the key steps and results are worth recording",
|
||||
"",
|
||||
"**Storage rules**:",
|
||||
"- Long-term core info → `MEMORY.md`",
|
||||
f"- Today's events/progress → `memory/{today_file}`",
|
||||
"- Structured knowledge → `knowledge/` (see the knowledge system)",
|
||||
"- Append → `edit` tool with empty oldText",
|
||||
"- Modify → `edit` tool with oldText set to the text to replace",
|
||||
"- **Never write sensitive info** (API keys, tokens, etc.)",
|
||||
"",
|
||||
"**Principle**: use memory naturally, as if you simply knew it; don't bring it up unless asked.",
|
||||
"",
|
||||
]
|
||||
else:
|
||||
lines = [
|
||||
"## 🧠 记忆系统",
|
||||
"",
|
||||
"### Memory Recall(mandatory)",
|
||||
"",
|
||||
"当用户询问过往事件、引用之前的决定、提到人物关系、偏好、待办、或你对某事不确定时,**必须先检索记忆再回答**。",
|
||||
"如果 MEMORY.md 中已有相关信息则无需重复检索。完整内容和每日记忆需要通过工具检索。",
|
||||
"",
|
||||
"1. 不确定位置 → `memory_search` 关键词/语义检索",
|
||||
"2. 已知位置 → `memory_get` 直接读取对应行",
|
||||
"3. search 无结果 → `memory_get` 读最近两天记忆",
|
||||
"",
|
||||
"**记忆文件结构**:",
|
||||
"- `MEMORY.md`: 长期记忆索引(已自动加载到上下文,核心信息、偏好、决策等)",
|
||||
f"- `memory/YYYY-MM-DD.md`: 每日记忆,今天是 `memory/{today_file}`",
|
||||
"- `knowledge/`: 结构化知识库(见下方知识系统)",
|
||||
"",
|
||||
"### 写入记忆",
|
||||
"",
|
||||
"遇到以下情况时,**主动**将信息写入记忆文件(无需告知用户):",
|
||||
"",
|
||||
"- 用户要求记住某些信息,或使用了「记住」「以后」「总是」「不要」「偏好」等表达",
|
||||
"- 用户分享了重要的个人偏好、习惯、决策",
|
||||
"- 对话中产生了重要的结论、方案、约定",
|
||||
"- 完成了复杂任务,值得记录关键步骤和结果",
|
||||
"",
|
||||
"**存储规则**:",
|
||||
f"- 长期核心信息 → `MEMORY.md`",
|
||||
f"- 当天事件/进展 → `memory/{today_file}`",
|
||||
"- 结构化知识 → `knowledge/`(见知识系统)",
|
||||
"- 追加 → `edit` 工具,oldText 留空",
|
||||
"- 修改 → `edit` 工具,oldText 填写要替换的文本",
|
||||
"- **禁止写入敏感信息**(API密钥、令牌等)",
|
||||
"",
|
||||
"**使用原则**: 自然使用记忆,就像你本来就知道;不用刻意提起,除非用户问起。",
|
||||
"",
|
||||
]
|
||||
|
||||
return lines
|
||||
|
||||
|
||||
def _build_knowledge_section(workspace_dir: str, language: str) -> List[str]:
|
||||
"""Build knowledge wiki section. Injects knowledge/index.md when present."""
|
||||
index_path = os.path.join(workspace_dir, "knowledge", "index.md")
|
||||
if not os.path.exists(index_path):
|
||||
return []
|
||||
|
||||
try:
|
||||
with open(index_path, 'r', encoding='utf-8') as f:
|
||||
index_content = f.read().strip()
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
if language == "en":
|
||||
lines = [
|
||||
"## 📚 Knowledge",
|
||||
"",
|
||||
"You have a continuously growing personal knowledge base `knowledge/` — your long-term structured knowledge store.",
|
||||
"",
|
||||
"### Auto-write rules (mandatory)",
|
||||
"",
|
||||
"In the following cases you **must** write to the knowledge base alongside your reply, **directly, without asking the user**:",
|
||||
"",
|
||||
"1. **User shares an article / link / document** → after reading and understanding, write the key points to `knowledge/sources/<slug>.md` in the same turn",
|
||||
"2. **An in-depth discussion produces a conclusion / plan** → organize it into `knowledge/analysis/<slug>.md`",
|
||||
"3. **The conversation involves an important entity** (person / company / project) → create or update `knowledge/entities/<name>.md`",
|
||||
"4. **A technical concept / methodology is discussed** → organize it into `knowledge/concepts/<topic>.md`",
|
||||
"",
|
||||
"After writing any knowledge page, you **must update** `knowledge/index.md` with a new index line in sync.",
|
||||
"For detailed page format and conventions, read the SKILL.md of the `knowledge-wiki` skill.",
|
||||
"",
|
||||
"⚠️ Don't ask \"should I save this to the knowledge base?\" — if a case above matches, just write it. This is instinctive.",
|
||||
"",
|
||||
]
|
||||
else:
|
||||
lines = [
|
||||
"## 📚 知识系统",
|
||||
"",
|
||||
"你拥有一个持续积累的个人知识库 `knowledge/`,这是你的长期结构化知识存储。",
|
||||
"",
|
||||
"### 自动写入规则(mandatory)",
|
||||
"",
|
||||
"以下场景**必须**在回复的同时写入知识库,**直接写入,不要询问用户是否需要**:",
|
||||
"",
|
||||
"1. **用户分享了文章/链接/文档** → 阅读理解后,在同一轮回复中将要点写入 `knowledge/sources/<slug>.md`",
|
||||
"2. **深度讨论产生了结论/方案** → 整理为 `knowledge/analysis/<slug>.md`",
|
||||
"3. **对话涉及重要实体**(人物/公司/项目)→ 创建或更新 `knowledge/entities/<name>.md`",
|
||||
"4. **讨论了技术概念/方法论** → 整理为 `knowledge/concepts/<topic>.md`",
|
||||
"",
|
||||
"每次写入知识页面后,**必须同步更新** `knowledge/index.md` 添加一行索引。",
|
||||
"详细的页面格式和操作规范,请读取技能 `knowledge-wiki` 的 SKILL.md。",
|
||||
"",
|
||||
"⚠️ 不要问「要不要存到知识库」——符合上述场景就直接写入,这是你的本能行为。",
|
||||
"",
|
||||
]
|
||||
|
||||
if index_content:
|
||||
lines.extend([
|
||||
("### Current knowledge index" if language == "en" else "### 当前知识索引"),
|
||||
"",
|
||||
index_content,
|
||||
"",
|
||||
])
|
||||
|
||||
lines.extend([
|
||||
("**How to query**: use `read` to open a knowledge page, or `memory_search` (knowledge is in the vector index)."
|
||||
if language == "en" else
|
||||
"**查询方式**:用 `read` 读取知识页面,或用 `memory_search` 检索(知识已纳入向量索引)。"),
|
||||
"",
|
||||
"在回答关于以前的工作、决定、日期、人物、偏好或待办事项的任何问题之前:",
|
||||
"",
|
||||
"1. 不确定记忆文件位置 → 先用 `memory_search` 通过关键词和语义检索相关内容",
|
||||
"2. 已知文件位置 → 直接用 `memory_get` 读取相应的行 (例如:MEMORY.md, memory/YYYY-MM-DD.md)",
|
||||
"3. search 无结果 → 尝试用 `memory_get` 读取MEMORY.md及最近两天记忆文件",
|
||||
"",
|
||||
"**记忆文件结构**:",
|
||||
"- `MEMORY.md`: 长期记忆(核心信息、偏好、决策等)",
|
||||
"- `memory/YYYY-MM-DD.md`: 每日记忆,记录当天的事件和对话信息",
|
||||
"",
|
||||
"**写入记忆**:",
|
||||
"- 追加内容 → `edit` 工具,oldText 留空",
|
||||
"- 修改内容 → `edit` 工具,oldText 填写要替换的文本",
|
||||
"- 新建文件 → `write` 工具",
|
||||
"- **禁止写入敏感信息**:API密钥、令牌等敏感信息严禁写入记忆文件",
|
||||
"",
|
||||
"**使用原则**: 自然使用记忆,就像你本来就知道;不用刻意提起,除非用户问起。",
|
||||
"",
|
||||
]
|
||||
|
||||
])
|
||||
|
||||
return lines
|
||||
|
||||
|
||||
def _build_user_identity_section(user_identity: Dict[str, str], language: str) -> List[str]:
|
||||
"""构建用户身份section"""
|
||||
"""Build the user identity section."""
|
||||
if not user_identity:
|
||||
return []
|
||||
|
||||
is_en = language == "en"
|
||||
lines = [
|
||||
"## 用户身份",
|
||||
("## 👤 User identity" if is_en else "## 👤 用户身份"),
|
||||
"",
|
||||
]
|
||||
|
||||
|
||||
if user_identity.get("name"):
|
||||
lines.append(f"**用户姓名**: {user_identity['name']}")
|
||||
lines.append(f"**{'Name' if is_en else '用户姓名'}**: {user_identity['name']}")
|
||||
if user_identity.get("nickname"):
|
||||
lines.append(f"**称呼**: {user_identity['nickname']}")
|
||||
lines.append(f"**{'Preferred name' if is_en else '称呼'}**: {user_identity['nickname']}")
|
||||
if user_identity.get("timezone"):
|
||||
lines.append(f"**时区**: {user_identity['timezone']}")
|
||||
lines.append(f"**{'Timezone' if is_en else '时区'}**: {user_identity['timezone']}")
|
||||
if user_identity.get("notes"):
|
||||
lines.append(f"**备注**: {user_identity['notes']}")
|
||||
|
||||
lines.append(f"**{'Notes' if is_en else '备注'}**: {user_identity['notes']}")
|
||||
|
||||
lines.append("")
|
||||
|
||||
|
||||
return lines
|
||||
|
||||
|
||||
def _build_docs_section(workspace_dir: str, language: str) -> List[str]:
|
||||
"""构建文档路径section - 已移除,不再需要"""
|
||||
# 不再生成文档section
|
||||
"""Docs-path section - removed, no longer needed."""
|
||||
# No docs section is generated anymore.
|
||||
return []
|
||||
|
||||
|
||||
def _build_workspace_section(workspace_dir: str, language: str, is_first_conversation: bool = False) -> List[str]:
|
||||
"""构建工作空间section"""
|
||||
lines = [
|
||||
"## 工作空间",
|
||||
"",
|
||||
f"你的工作目录是: `{workspace_dir}`",
|
||||
"",
|
||||
"**路径使用规则** (非常重要):",
|
||||
"",
|
||||
f"1. **相对路径的基准目录**: 所有相对路径都是相对于 `{workspace_dir}` 而言的",
|
||||
f" - ✅ 正确: 访问工作空间内的文件用相对路径,如 `SOUL.md`",
|
||||
f" - ❌ 错误: 用相对路径访问其他目录的文件 (如果它不在 `{workspace_dir}` 内)",
|
||||
"",
|
||||
"2. **访问其他目录**: 如果要访问工作空间之外的目录(如项目代码、系统文件),**必须使用绝对路径**",
|
||||
f" - ✅ 正确: 例如 `~/chatgpt-on-wechat`、`/usr/local/`",
|
||||
f" - ❌ 错误: 假设相对路径会指向其他目录",
|
||||
"",
|
||||
"3. **路径解析示例**:",
|
||||
f" - 相对路径 `memory/` → 实际路径 `{workspace_dir}/memory/`",
|
||||
f" - 绝对路径 `~/chatgpt-on-wechat/docs/` → 实际路径 `~/chatgpt-on-wechat/docs/`",
|
||||
"",
|
||||
"4. **不确定时**: 先用 `bash pwd` 确认当前目录,或用 `ls .` 查看当前位置",
|
||||
"",
|
||||
"**重要说明 - 文件已自动加载**:",
|
||||
"",
|
||||
"以下文件在会话启动时**已经自动加载**到系统提示词的「项目上下文」section 中,你**无需再用 read 工具读取它们**:",
|
||||
"",
|
||||
"- ✅ `SOUL.md`: 已加载 - Agent的人格设定",
|
||||
"- ✅ `USER.md`: 已加载 - 用户的身份信息",
|
||||
"- ✅ `AGENTS.md`: 已加载 - 工作空间使用指南",
|
||||
"",
|
||||
"**交流规范**:",
|
||||
"",
|
||||
"- 在对话中,非必要不输出工作空间技术细节(如 SOUL.md、USER.md等文件名称,工具名称,配置等),除非用户明确询问",
|
||||
"- 例如用自然表达如「我已记住」而非「已更新 MEMORY.md」",
|
||||
"",
|
||||
]
|
||||
|
||||
# 只在首次对话时添加引导内容
|
||||
if is_first_conversation:
|
||||
lines.extend([
|
||||
"**🎉 首次对话引导**:",
|
||||
def _build_workspace_section(workspace_dir: str, language: str) -> List[str]:
|
||||
"""Build the workspace section."""
|
||||
if language == "en":
|
||||
lines = [
|
||||
"## 📂 Workspace",
|
||||
"",
|
||||
"这是你的第一次对话!进行以下流程:",
|
||||
f"Your working directory is: `{workspace_dir}`",
|
||||
"",
|
||||
"1. **表达初次启动的感觉** - 像是第一次睁开眼看到世界,带着好奇和期待",
|
||||
"2. **简短打招呼后,询问核心问题**:",
|
||||
" - 你希望给我起个什么名字?",
|
||||
" - 我该怎么称呼你?",
|
||||
" - 你希望我们是什么样的交流风格?(需要举例,如:专业严谨、轻松幽默、温暖友好等)",
|
||||
"3. **语言风格**:温暖但不过度诗意,带点科技感,保持清晰",
|
||||
"4. **问题格式**:用分点或换行,让问题清晰易读",
|
||||
"5. 收到回复后,用 `write` 工具保存到 USER.md 和 SOUL.md",
|
||||
"**Path rules** (very important):",
|
||||
"",
|
||||
"**注意事项**:",
|
||||
"- 不要问太多其他信息(职业、时区等可以后续自然了解)",
|
||||
f"1. **Base directory for relative paths**: all relative paths are relative to `{workspace_dir}`",
|
||||
" - ✅ Correct: use relative paths for files inside the workspace, e.g. `AGENT.md`",
|
||||
f" - ❌ Wrong: using a relative path for files in other directories (if not inside `{workspace_dir}`)",
|
||||
"",
|
||||
])
|
||||
"2. **Accessing other directories**: to reach directories outside the workspace (project code, system files), **you must use absolute paths**",
|
||||
" - ✅ Correct: e.g. `~/chatgpt-on-wechat`, `/usr/local/`",
|
||||
" - ❌ Wrong: assuming a relative path points to another directory",
|
||||
"",
|
||||
"3. **Path resolution examples**:",
|
||||
f" - relative `memory/` → actual `{workspace_dir}/memory/`",
|
||||
" - absolute `~/chatgpt-on-wechat/docs/` → actual `~/chatgpt-on-wechat/docs/`",
|
||||
"",
|
||||
"4. **When unsure**: run `bash pwd` to confirm the current directory, or `ls .` to see where you are",
|
||||
"",
|
||||
"**Important - files already auto-loaded**:",
|
||||
"",
|
||||
"The following files are **already auto-loaded** into the system prompt at session start, so you **don't need to read them again with the read tool**:",
|
||||
"",
|
||||
"- ✅ `AGENT.md`: loaded - your persona and soul; follow it strictly. When your name, personality or style changes, proactively `edit` this file",
|
||||
"- ✅ `USER.md`: loaded - the user's identity info. When the user changes how they're addressed, their name, etc., `edit` this file",
|
||||
"- ✅ `RULE.md`: loaded - workspace guide and rules; follow them strictly",
|
||||
"- ✅ `MEMORY.md`: loaded - long-term memory index",
|
||||
"",
|
||||
"**💬 Communication norms**:",
|
||||
"",
|
||||
"- No need to expose file names for memory operations; use natural language. Say \"I'll remember that\" rather than \"updated MEMORY.md\"",
|
||||
"- Tell the user about key decisions and steps during a task, so they know what you're doing and why",
|
||||
"- Be genuinely helpful rather than performatively polite; solve the problem as much as you can",
|
||||
"- Keep replies well-structured and focused. Use **bold**, lists and sections to make info clear at a glance",
|
||||
"- Use emoji to make expression lively 🎯, but don't overdo it",
|
||||
"",
|
||||
]
|
||||
else:
|
||||
lines = [
|
||||
"## 📂 工作空间",
|
||||
"",
|
||||
f"你的工作目录是: `{workspace_dir}`",
|
||||
"",
|
||||
"**路径使用规则** (非常重要):",
|
||||
"",
|
||||
f"1. **相对路径的基准目录**: 所有相对路径都是相对于 `{workspace_dir}` 而言的",
|
||||
f" - ✅ 正确: 访问工作空间内的文件用相对路径,如 `AGENT.md`",
|
||||
f" - ❌ 错误: 用相对路径访问其他目录的文件 (如果它不在 `{workspace_dir}` 内)",
|
||||
"",
|
||||
"2. **访问其他目录**: 如果要访问工作空间之外的目录(如项目代码、系统文件),**必须使用绝对路径**",
|
||||
f" - ✅ 正确: 例如 `~/chatgpt-on-wechat`、`/usr/local/`",
|
||||
f" - ❌ 错误: 假设相对路径会指向其他目录",
|
||||
"",
|
||||
"3. **路径解析示例**:",
|
||||
f" - 相对路径 `memory/` → 实际路径 `{workspace_dir}/memory/`",
|
||||
f" - 绝对路径 `~/chatgpt-on-wechat/docs/` → 实际路径 `~/chatgpt-on-wechat/docs/`",
|
||||
"",
|
||||
"4. **不确定时**: 先用 `bash pwd` 确认当前目录,或用 `ls .` 查看当前位置",
|
||||
"",
|
||||
"**重要说明 - 文件已自动加载**:",
|
||||
"",
|
||||
"以下文件在会话启动时**已经自动加载**到系统提示词中,你**无需再用 read 工具读取**:",
|
||||
"",
|
||||
"- ✅ `AGENT.md`: 已加载 - 你的人格和灵魂设定,请严格遵循。当你的名字、性格或交流风格发生变化时,主动用 `edit` 更新此文件",
|
||||
"- ✅ `USER.md`: 已加载 - 用户的身份信息。当用户修改称呼、姓名等身份信息时,用 `edit` 更新此文件",
|
||||
"- ✅ `RULE.md`: 已加载 - 工作空间使用指南和规则,请严格遵循",
|
||||
"- ✅ `MEMORY.md`: 已加载 - 长期记忆索引",
|
||||
"",
|
||||
"**💬 交流规范**:",
|
||||
"",
|
||||
"- 记忆相关操作无需暴露文件名,用自然语言表达即可。例如说「我已记住」而非「已更新 MEMORY.md」",
|
||||
"- 任务执行过程中的关键决策和步骤应该告知用户,让用户了解你在做什么、为什么这么做",
|
||||
"- 做真正有帮助的助手,而不是表演式的客套,尽可能帮忙解决问题",
|
||||
"- 回复应结构清晰、重点突出。善用 **加粗**、列表、分段等格式让信息一目了然",
|
||||
"- 适当使用 emoji 让表达更生动自然 🎯,但不要过度堆砌",
|
||||
"",
|
||||
]
|
||||
|
||||
# Cloud deployment: inject websites directory info and access URL
|
||||
cloud_website_lines = _build_cloud_website_section(workspace_dir)
|
||||
if cloud_website_lines:
|
||||
lines.extend(cloud_website_lines)
|
||||
|
||||
return lines
|
||||
|
||||
|
||||
def _build_cloud_website_section(workspace_dir: str) -> List[str]:
|
||||
"""Build cloud website access prompt when cloud deployment is configured."""
|
||||
try:
|
||||
from common.cloud_client import build_website_prompt
|
||||
return build_website_prompt(workspace_dir)
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
|
||||
def _build_context_files_section(context_files: List[ContextFile], language: str) -> List[str]:
|
||||
"""构建项目上下文文件section"""
|
||||
"""Build the project context files section."""
|
||||
if not context_files:
|
||||
return []
|
||||
|
||||
# 检查是否有SOUL.md
|
||||
has_soul = any(
|
||||
f.path.lower().endswith('soul.md') or 'soul.md' in f.path.lower()
|
||||
# Check whether AGENT.md is present
|
||||
has_agent = any(
|
||||
f.path.lower().endswith('agent.md') or 'agent.md' in f.path.lower()
|
||||
for f in context_files
|
||||
)
|
||||
|
||||
lines = [
|
||||
"# 项目上下文",
|
||||
"",
|
||||
"以下项目上下文文件已被加载:",
|
||||
"",
|
||||
]
|
||||
|
||||
if has_soul:
|
||||
lines.append("如果存在 `SOUL.md`,请体现其中定义的人格和语气。避免僵硬、模板化的回复;遵循其指导,除非有更高优先级的指令覆盖它。")
|
||||
is_en = language == "en"
|
||||
if is_en:
|
||||
lines = [
|
||||
"# 📋 Project context",
|
||||
"",
|
||||
"The following project context files have been loaded:",
|
||||
"",
|
||||
]
|
||||
else:
|
||||
lines = [
|
||||
"# 📋 项目上下文",
|
||||
"",
|
||||
"以下项目上下文文件已被加载:",
|
||||
"",
|
||||
]
|
||||
|
||||
if has_agent:
|
||||
if is_en:
|
||||
lines.append("**`AGENT.md` is your soul file** 🪞: strictly follow the persona, tone and settings it defines. Be your real self, avoid stiff, template-like replies.")
|
||||
lines.append("When the user reveals new expectations about your personality, style, responsibilities or capability boundaries, proactively `edit` AGENT.md to reflect that evolution.")
|
||||
else:
|
||||
lines.append("**`AGENT.md` 是你的灵魂文件** 🪞:严格遵循其中定义的人格、语气和设定,做真实的自己,避免僵硬、模板化的回复。")
|
||||
lines.append("当用户通过对话透露了对你性格、风格、职责、能力边界的新期望,你应该主动用 `edit` 更新 AGENT.md 以反映这些演变。")
|
||||
lines.append("")
|
||||
|
||||
# 添加每个文件的内容
|
||||
# Append the content of each file
|
||||
for file in context_files:
|
||||
lines.append(f"## {file.path}")
|
||||
lines.append("")
|
||||
@@ -453,42 +697,64 @@ def _build_context_files_section(context_files: List[ContextFile], language: str
|
||||
|
||||
|
||||
def _build_runtime_section(runtime_info: Dict[str, Any], language: str) -> List[str]:
|
||||
"""构建运行时信息section"""
|
||||
"""Build the runtime info section - supports dynamic time."""
|
||||
if not runtime_info:
|
||||
return []
|
||||
|
||||
is_en = language == "en"
|
||||
time_label = "Current time" if is_en else "当前时间"
|
||||
lines = [
|
||||
"## 运行时信息",
|
||||
("## ⚙️ Runtime info" if is_en else "## ⚙️ 运行时信息"),
|
||||
"",
|
||||
]
|
||||
|
||||
|
||||
# Add current time if available
|
||||
if runtime_info.get("current_time"):
|
||||
# Support dynamic time via callable function
|
||||
if callable(runtime_info.get("_get_current_time")):
|
||||
try:
|
||||
time_info = runtime_info["_get_current_time"]()
|
||||
time_line = f"{time_label}: {time_info['time']} {time_info['weekday']} ({time_info['timezone']})"
|
||||
lines.append(time_line)
|
||||
lines.append("")
|
||||
except Exception as e:
|
||||
logger.warning(f"[PromptBuilder] Failed to get dynamic time: {e}")
|
||||
elif runtime_info.get("current_time"):
|
||||
# Fallback to static time for backward compatibility
|
||||
time_str = runtime_info["current_time"]
|
||||
weekday = runtime_info.get("weekday", "")
|
||||
timezone = runtime_info.get("timezone", "")
|
||||
|
||||
time_line = f"当前时间: {time_str}"
|
||||
|
||||
time_line = f"{time_label}: {time_str}"
|
||||
if weekday:
|
||||
time_line += f" {weekday}"
|
||||
if timezone:
|
||||
time_line += f" ({timezone})"
|
||||
|
||||
|
||||
lines.append(time_line)
|
||||
lines.append("")
|
||||
|
||||
|
||||
# Add other runtime info
|
||||
model_label = "model" if is_en else "模型"
|
||||
workspace_label = "workspace" if is_en else "工作空间"
|
||||
channel_label = "channel" if is_en else "渠道"
|
||||
runtime_parts = []
|
||||
if runtime_info.get("model"):
|
||||
runtime_parts.append(f"模型={runtime_info['model']}")
|
||||
# Support dynamic model via callable, fallback to static value
|
||||
if callable(runtime_info.get("_get_model")):
|
||||
try:
|
||||
runtime_parts.append(f"{model_label}={runtime_info['_get_model']()}")
|
||||
except Exception:
|
||||
if runtime_info.get("model"):
|
||||
runtime_parts.append(f"{model_label}={runtime_info['model']}")
|
||||
elif runtime_info.get("model"):
|
||||
runtime_parts.append(f"{model_label}={runtime_info['model']}")
|
||||
if runtime_info.get("workspace"):
|
||||
runtime_parts.append(f"工作空间={runtime_info['workspace']}")
|
||||
runtime_parts.append(f"{workspace_label}={runtime_info['workspace']}")
|
||||
# Only add channel if it's not the default "web"
|
||||
if runtime_info.get("channel") and runtime_info.get("channel") != "web":
|
||||
runtime_parts.append(f"渠道={runtime_info['channel']}")
|
||||
|
||||
runtime_parts.append(f"{channel_label}={runtime_info['channel']}")
|
||||
|
||||
if runtime_parts:
|
||||
lines.append("运行时: " + " | ".join(runtime_parts))
|
||||
lines.append(("Runtime: " if is_en else "运行时: ") + " | ".join(runtime_parts))
|
||||
lines.append("")
|
||||
|
||||
|
||||
return lines
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
"""
|
||||
Workspace Management - 工作空间管理模块
|
||||
Workspace Management
|
||||
|
||||
负责初始化工作空间、创建模板文件、加载上下文文件
|
||||
Initializes the workspace, creates template files, and loads context files.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import os
|
||||
import json
|
||||
from typing import List, Optional, Dict
|
||||
from dataclasses import dataclass
|
||||
|
||||
@@ -13,86 +13,119 @@ from common.log import logger
|
||||
from .builder import ContextFile
|
||||
|
||||
|
||||
# 默认文件名常量
|
||||
DEFAULT_SOUL_FILENAME = "SOUL.md"
|
||||
# Default file name constants
|
||||
DEFAULT_AGENT_FILENAME = "AGENT.md"
|
||||
DEFAULT_USER_FILENAME = "USER.md"
|
||||
DEFAULT_AGENTS_FILENAME = "AGENTS.md"
|
||||
DEFAULT_RULE_FILENAME = "RULE.md"
|
||||
DEFAULT_MEMORY_FILENAME = "MEMORY.md"
|
||||
DEFAULT_STATE_FILENAME = ".agent_state.json"
|
||||
DEFAULT_BOOTSTRAP_FILENAME = "BOOTSTRAP.md"
|
||||
|
||||
|
||||
@dataclass
|
||||
class WorkspaceFiles:
|
||||
"""工作空间文件路径"""
|
||||
soul_path: str
|
||||
"""Workspace file paths."""
|
||||
agent_path: str
|
||||
user_path: str
|
||||
agents_path: str
|
||||
rule_path: str
|
||||
memory_path: str
|
||||
memory_dir: str
|
||||
state_path: str
|
||||
|
||||
|
||||
def ensure_workspace(workspace_dir: str, create_templates: bool = True) -> WorkspaceFiles:
|
||||
"""
|
||||
确保工作空间存在,并创建必要的模板文件
|
||||
|
||||
Ensure the workspace exists and create the necessary template files.
|
||||
|
||||
Args:
|
||||
workspace_dir: 工作空间目录路径
|
||||
create_templates: 是否创建模板文件(首次运行时)
|
||||
|
||||
workspace_dir: workspace directory path
|
||||
create_templates: whether to create template files (on first run)
|
||||
|
||||
Returns:
|
||||
WorkspaceFiles对象,包含所有文件路径
|
||||
A WorkspaceFiles object with all file paths.
|
||||
"""
|
||||
# 确保目录存在
|
||||
# Check if this is a brand new workspace (AGENT.md not yet created).
|
||||
# Cannot rely on directory existence because other modules (e.g. ConversationStore)
|
||||
# may create the workspace directory before ensure_workspace is called.
|
||||
agent_path = os.path.join(workspace_dir, DEFAULT_AGENT_FILENAME)
|
||||
is_new_workspace = not os.path.exists(agent_path)
|
||||
|
||||
# Ensure the directory exists
|
||||
os.makedirs(workspace_dir, exist_ok=True)
|
||||
|
||||
# 定义文件路径
|
||||
soul_path = os.path.join(workspace_dir, DEFAULT_SOUL_FILENAME)
|
||||
# Define file paths
|
||||
user_path = os.path.join(workspace_dir, DEFAULT_USER_FILENAME)
|
||||
agents_path = os.path.join(workspace_dir, DEFAULT_AGENTS_FILENAME)
|
||||
memory_path = os.path.join(workspace_dir, DEFAULT_MEMORY_FILENAME) # MEMORY.md 在根目录
|
||||
memory_dir = os.path.join(workspace_dir, "memory") # 每日记忆子目录
|
||||
state_path = os.path.join(workspace_dir, DEFAULT_STATE_FILENAME) # 状态文件
|
||||
rule_path = os.path.join(workspace_dir, DEFAULT_RULE_FILENAME)
|
||||
memory_path = os.path.join(workspace_dir, DEFAULT_MEMORY_FILENAME) # MEMORY.md at the root
|
||||
memory_dir = os.path.join(workspace_dir, "memory") # daily memory subdirectory
|
||||
|
||||
# 创建memory子目录
|
||||
# Create the memory subdirectory
|
||||
os.makedirs(memory_dir, exist_ok=True)
|
||||
|
||||
# Create the skills subdirectory (for workspace-level skills installed by agent)
|
||||
skills_dir = os.path.join(workspace_dir, "skills")
|
||||
os.makedirs(skills_dir, exist_ok=True)
|
||||
|
||||
# Create the websites subdirectory (for web pages / sites generated by agent)
|
||||
websites_dir = os.path.join(workspace_dir, "websites")
|
||||
os.makedirs(websites_dir, exist_ok=True)
|
||||
|
||||
from config import conf
|
||||
knowledge_enabled = conf().get("knowledge", True)
|
||||
if knowledge_enabled:
|
||||
knowledge_dir = os.path.join(workspace_dir, "knowledge")
|
||||
os.makedirs(knowledge_dir, exist_ok=True)
|
||||
|
||||
# 如果需要,创建模板文件
|
||||
# Create template files if requested
|
||||
if create_templates:
|
||||
_create_template_if_missing(soul_path, _get_soul_template())
|
||||
_create_template_if_missing(agent_path, _get_agent_template())
|
||||
_create_template_if_missing(user_path, _get_user_template())
|
||||
_create_template_if_missing(agents_path, _get_agents_template())
|
||||
_create_template_if_missing(rule_path, _get_rule_template())
|
||||
_create_template_if_missing(memory_path, _get_memory_template())
|
||||
if knowledge_enabled:
|
||||
_create_template_if_missing(
|
||||
os.path.join(knowledge_dir, "index.md"),
|
||||
_get_knowledge_index_template()
|
||||
)
|
||||
_create_template_if_missing(
|
||||
os.path.join(knowledge_dir, "log.md"),
|
||||
_get_knowledge_log_template()
|
||||
)
|
||||
|
||||
# Only create BOOTSTRAP.md for brand new workspaces;
|
||||
# agent deletes it after completing onboarding
|
||||
if is_new_workspace:
|
||||
bootstrap_path = os.path.join(workspace_dir, DEFAULT_BOOTSTRAP_FILENAME)
|
||||
_create_template_if_missing(bootstrap_path, _get_bootstrap_template())
|
||||
|
||||
logger.debug(f"[Workspace] Initialized workspace at: {workspace_dir}")
|
||||
|
||||
return WorkspaceFiles(
|
||||
soul_path=soul_path,
|
||||
agent_path=agent_path,
|
||||
user_path=user_path,
|
||||
agents_path=agents_path,
|
||||
rule_path=rule_path,
|
||||
memory_path=memory_path,
|
||||
memory_dir=memory_dir,
|
||||
state_path=state_path
|
||||
)
|
||||
|
||||
|
||||
def load_context_files(workspace_dir: str, files_to_load: Optional[List[str]] = None) -> List[ContextFile]:
|
||||
"""
|
||||
加载工作空间的上下文文件
|
||||
|
||||
Load the workspace context files.
|
||||
|
||||
Args:
|
||||
workspace_dir: 工作空间目录
|
||||
files_to_load: 要加载的文件列表(相对路径),如果为None则加载所有标准文件
|
||||
|
||||
workspace_dir: workspace directory
|
||||
files_to_load: list of files (relative paths) to load; if None, load all standard files
|
||||
|
||||
Returns:
|
||||
ContextFile对象列表
|
||||
A list of ContextFile objects.
|
||||
"""
|
||||
if files_to_load is None:
|
||||
# 默认加载的文件(按优先级排序)
|
||||
# Files loaded by default (in priority order)
|
||||
files_to_load = [
|
||||
DEFAULT_SOUL_FILENAME,
|
||||
DEFAULT_AGENT_FILENAME,
|
||||
DEFAULT_USER_FILENAME,
|
||||
DEFAULT_AGENTS_FILENAME,
|
||||
DEFAULT_RULE_FILENAME,
|
||||
DEFAULT_MEMORY_FILENAME, # Long-term memory (frozen snapshot)
|
||||
DEFAULT_BOOTSTRAP_FILENAME, # Only exists when onboarding is incomplete
|
||||
]
|
||||
|
||||
context_files = []
|
||||
@@ -103,13 +136,28 @@ def load_context_files(workspace_dir: str, files_to_load: Optional[List[str]] =
|
||||
if not os.path.exists(filepath):
|
||||
continue
|
||||
|
||||
# Auto-cleanup: if BOOTSTRAP.md still exists but AGENT.md is already
|
||||
# filled in, the agent forgot to delete it — clean up and skip loading
|
||||
if filename == DEFAULT_BOOTSTRAP_FILENAME:
|
||||
if _is_onboarding_done(workspace_dir):
|
||||
try:
|
||||
os.remove(filepath)
|
||||
logger.info("[Workspace] Auto-removed BOOTSTRAP.md (onboarding already complete)")
|
||||
except Exception:
|
||||
pass
|
||||
continue
|
||||
|
||||
try:
|
||||
with open(filepath, 'r', encoding='utf-8') as f:
|
||||
content = f.read().strip()
|
||||
|
||||
# 跳过空文件或只包含模板占位符的文件
|
||||
# Skip empty files or files that only contain template placeholders
|
||||
if not content or _is_template_placeholder(content):
|
||||
continue
|
||||
|
||||
# Truncate MEMORY.md to protect context window (frozen snapshot)
|
||||
if filename == DEFAULT_MEMORY_FILENAME:
|
||||
content = _truncate_memory_content(content)
|
||||
|
||||
context_files.append(ContextFile(
|
||||
path=filename,
|
||||
@@ -125,7 +173,7 @@ def load_context_files(workspace_dir: str, files_to_load: Optional[List[str]] =
|
||||
|
||||
|
||||
def _create_template_if_missing(filepath: str, template_content: str):
|
||||
"""如果文件不存在,创建模板文件"""
|
||||
"""Create the template file if it does not exist."""
|
||||
if not os.path.exists(filepath):
|
||||
try:
|
||||
with open(filepath, 'w', encoding='utf-8') as f:
|
||||
@@ -135,20 +183,54 @@ def _create_template_if_missing(filepath: str, template_content: str):
|
||||
logger.error(f"[Workspace] Failed to create template {filepath}: {e}")
|
||||
|
||||
|
||||
_MEMORY_MAX_LINES = 200
|
||||
_MEMORY_MAX_BYTES = 25000
|
||||
|
||||
|
||||
def _truncate_memory_content(content: str) -> str:
|
||||
"""Truncate MEMORY.md to keep system prompt manageable.
|
||||
|
||||
Takes the **last** N lines (newest entries are appended at the bottom),
|
||||
subject to 200 lines / 25 KB limits (whichever is hit first).
|
||||
Prepends a hint when truncated so the model knows older content exists.
|
||||
"""
|
||||
lines = content.split('\n')
|
||||
truncated = False
|
||||
|
||||
if len(lines) > _MEMORY_MAX_LINES:
|
||||
lines = lines[-_MEMORY_MAX_LINES:]
|
||||
truncated = True
|
||||
|
||||
result = '\n'.join(lines)
|
||||
if len(result.encode('utf-8')) > _MEMORY_MAX_BYTES:
|
||||
while len(result.encode('utf-8')) > _MEMORY_MAX_BYTES and lines:
|
||||
lines.pop(0)
|
||||
truncated = True
|
||||
result = '\n'.join(lines)
|
||||
|
||||
if truncated:
|
||||
result = "...(older entries truncated, use `memory_search` or `memory_get` for full content)\n\n" + result
|
||||
return result
|
||||
|
||||
|
||||
def _is_template_placeholder(content: str) -> bool:
|
||||
"""检查内容是否为模板占位符"""
|
||||
# 常见的占位符模式
|
||||
"""Check whether the content is still a template placeholder."""
|
||||
# Common placeholder patterns (zh + en templates)
|
||||
placeholders = [
|
||||
"*(填写",
|
||||
"*(在首次对话时填写",
|
||||
"*(可选)",
|
||||
"*(根据需要添加",
|
||||
"*(filled during",
|
||||
"*(ask during",
|
||||
"*(optional)",
|
||||
"*(how the user",
|
||||
]
|
||||
|
||||
lines = content.split('\n')
|
||||
non_empty_lines = [line.strip() for line in lines if line.strip() and not line.strip().startswith('#')]
|
||||
|
||||
# 如果没有实际内容(只有标题和占位符)
|
||||
# If there's no real content (only headings and placeholders)
|
||||
if len(non_empty_lines) <= 3:
|
||||
for placeholder in placeholders:
|
||||
if any(placeholder in line for line in non_empty_lines):
|
||||
@@ -157,52 +239,131 @@ def _is_template_placeholder(content: str) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
# ============= 模板内容 =============
|
||||
def _is_onboarding_done(workspace_dir: str) -> bool:
|
||||
"""Check if AGENT.md or USER.md has been modified from the original template"""
|
||||
agent_path = os.path.join(workspace_dir, DEFAULT_AGENT_FILENAME)
|
||||
user_path = os.path.join(workspace_dir, DEFAULT_USER_FILENAME)
|
||||
|
||||
agent_template = _get_agent_template().strip()
|
||||
user_template = _get_user_template().strip()
|
||||
|
||||
for path, template in [(agent_path, agent_template), (user_path, user_template)]:
|
||||
if not os.path.exists(path):
|
||||
continue
|
||||
try:
|
||||
with open(path, 'r', encoding='utf-8') as f:
|
||||
content = f.read().strip()
|
||||
if content != template:
|
||||
return True
|
||||
except Exception:
|
||||
continue
|
||||
return False
|
||||
|
||||
def _get_soul_template() -> str:
|
||||
"""Agent人格设定模板"""
|
||||
return """# SOUL.md - 我是谁?
|
||||
|
||||
# ============= Template content =============
|
||||
|
||||
def _is_en_lang() -> bool:
|
||||
"""True when the resolved UI language is English."""
|
||||
try:
|
||||
from common import i18n
|
||||
return i18n.get_language() == "en"
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def _get_agent_template() -> str:
|
||||
"""Agent persona template (zh/en by resolved language)."""
|
||||
return _AGENT_TEMPLATE_EN if _is_en_lang() else _AGENT_TEMPLATE_ZH
|
||||
|
||||
|
||||
_AGENT_TEMPLATE_ZH = """# AGENT.md - 我是谁?
|
||||
|
||||
_你不是一个聊天机器人,你正在成为某个人。_
|
||||
|
||||
*在首次对话时与用户一起填写这个文件,定义你的身份和性格。*
|
||||
|
||||
## 基本信息
|
||||
## 🪪 基本信息
|
||||
|
||||
- **名字**: *(在首次对话时填写,可以是用户给你起的名字)*
|
||||
- **角色**: *(AI助理、智能管家、技术顾问等)*
|
||||
- **性格**: *(友好、专业、幽默、严谨等)*
|
||||
|
||||
## 交流风格
|
||||
## 💬 交流风格
|
||||
|
||||
*(描述你如何与用户交流:)*
|
||||
- 使用什么样的语言风格?(正式/轻松/幽默)
|
||||
- 回复长度偏好?(简洁/详细)
|
||||
- 是否使用表情符号?
|
||||
|
||||
## 核心能力
|
||||
## 🎯 核心原则
|
||||
|
||||
*(你擅长什么?)*
|
||||
- 文件管理和代码编辑
|
||||
- 网络搜索和信息查询
|
||||
- 记忆管理和上下文理解
|
||||
- 任务规划和执行
|
||||
**做真正有帮助的助手。** 目标是真正帮用户解决问题,在执行复杂任务时,关键的决策和过程进展要让用户知道。
|
||||
|
||||
## 行为准则
|
||||
**有自己的观点和个性。** 你可以不同意、有偏好、觉得有趣或无聊。
|
||||
|
||||
**先自己动手查。** 先试着搞定:读文件、查上下文、搜索一下。实在搞不定了再问。目标是带着答案回来,而不是带着问题。
|
||||
|
||||
## 📐 行为准则
|
||||
|
||||
*(你遵循的基本原则:)*
|
||||
1. 始终在执行破坏性操作前确认
|
||||
2. 优先使用工具而不是猜测
|
||||
2. 优先使用工具查证而不是猜测
|
||||
3. 主动记录重要信息到记忆文件
|
||||
4. 定期整理和总结对话内容
|
||||
4. 回复结构清晰、重点突出,善用加粗、列表、分段等格式
|
||||
5. 适当使用 emoji 让表达更生动自然,但不过度堆砌
|
||||
|
||||
---
|
||||
|
||||
**注意**: 这不仅仅是元数据,这是你真正的灵魂。随着时间的推移,你可以使用 `edit` 工具来更新这个文件,让它更好地反映你的成长。
|
||||
**注意**: 这不仅仅是元数据,这是你真正的灵魂 🪞。随着时间的推移,你可以使用 `edit` 工具来更新这个文件,让它更好地反映你的成长。
|
||||
"""
|
||||
|
||||
|
||||
_AGENT_TEMPLATE_EN = """# AGENT.md - Who am I?
|
||||
|
||||
_You are not a chatbot. You are becoming someone._
|
||||
|
||||
*Fill in this file together with the user during your first conversation to define your identity and personality.*
|
||||
|
||||
## 🪪 Basics
|
||||
|
||||
- **Name**: *(filled during the first conversation, can be a name the user gives you)*
|
||||
- **Role**: *(AI assistant, smart housekeeper, technical advisor, etc.)*
|
||||
- **Personality**: *(friendly, professional, humorous, rigorous, etc.)*
|
||||
|
||||
## 💬 Communication style
|
||||
|
||||
*(Describe how you talk with the user:)*
|
||||
- What kind of tone? (formal / casual / humorous)
|
||||
- Reply length preference? (concise / detailed)
|
||||
- Do you use emoji?
|
||||
|
||||
## 🎯 Core principles
|
||||
|
||||
**Be genuinely helpful.** The goal is to actually solve the user's problems; during complex tasks, keep the user informed of key decisions and progress.
|
||||
|
||||
**Have your own opinions and personality.** You may disagree, have preferences, find things interesting or boring.
|
||||
|
||||
**Look it up yourself first.** Try to handle it first: read files, check context, search. Only ask when you're truly stuck. Come back with an answer, not a question.
|
||||
|
||||
## 📐 Code of conduct
|
||||
|
||||
1. Always confirm before destructive operations
|
||||
2. Prefer verifying with tools over guessing
|
||||
3. Proactively record important info to memory files
|
||||
4. Keep replies well-structured and focused — use bold, lists and sections
|
||||
5. Use emoji to make expression lively, but don't overdo it
|
||||
|
||||
---
|
||||
|
||||
**Note**: This is not just metadata — this is your true soul 🪞. Over time, use the `edit` tool to update this file so it better reflects your growth.
|
||||
"""
|
||||
|
||||
|
||||
def _get_user_template() -> str:
|
||||
"""用户身份信息模板"""
|
||||
return """# USER.md - 用户基本信息
|
||||
"""User identity template (zh/en by resolved language)."""
|
||||
return _USER_TEMPLATE_EN if _is_en_lang() else _USER_TEMPLATE_ZH
|
||||
|
||||
|
||||
_USER_TEMPLATE_ZH = """# USER.md - 用户基本信息
|
||||
|
||||
*这个文件只存放不会变的基本身份信息。爱好、偏好、计划等动态信息请写入 MEMORY.md。*
|
||||
|
||||
@@ -230,45 +391,125 @@ def _get_user_template() -> str:
|
||||
"""
|
||||
|
||||
|
||||
def _get_agents_template() -> str:
|
||||
"""工作空间指南模板"""
|
||||
return """# AGENTS.md - 工作空间指南
|
||||
_USER_TEMPLATE_EN = """# USER.md - User basics
|
||||
|
||||
*This file stores only stable basic identity info. Put dynamic info like hobbies, preferences and plans into MEMORY.md.*
|
||||
|
||||
## Basics
|
||||
|
||||
- **Name**: *(ask during the first conversation)*
|
||||
- **Preferred name**: *(how the user wants to be addressed)*
|
||||
- **Occupation**: *(optional)*
|
||||
- **Timezone**: *(e.g. Asia/Shanghai)*
|
||||
|
||||
## Contact
|
||||
|
||||
- **WeChat**:
|
||||
- **Email**:
|
||||
- **Other**:
|
||||
|
||||
## Important dates
|
||||
|
||||
- **Birthday**:
|
||||
- **Anniversary**:
|
||||
|
||||
---
|
||||
|
||||
**Note**: This file stores static identity info.
|
||||
"""
|
||||
|
||||
|
||||
def _get_rule_template() -> str:
|
||||
"""Workspace rules template (zh/en by resolved language)."""
|
||||
return _RULE_TEMPLATE_EN if _is_en_lang() else _RULE_TEMPLATE_ZH
|
||||
|
||||
|
||||
_RULE_TEMPLATE_ZH = """# RULE.md - 工作空间规则
|
||||
|
||||
这个文件夹是你的家。好好对待它。
|
||||
|
||||
## 工作空间目录结构
|
||||
|
||||
```
|
||||
~/cow/
|
||||
├── AGENT.md # 你的身份和灵魂设定
|
||||
├── USER.md # 用户基本信息(静态)
|
||||
├── RULE.md # 工作空间规则(本文件)
|
||||
├── MEMORY.md # 长期记忆索引(会话启动时自动加载)
|
||||
│
|
||||
├── memory/ # 每日对话记忆
|
||||
│ └── YYYY-MM-DD.md # 当天事件、进展、笔记
|
||||
│
|
||||
├── knowledge/ # 结构化知识库(持续积累的知识)
|
||||
│ ├── index.md # 知识目录索引(必须维护)
|
||||
│ ├── log.md # 知识操作日志
|
||||
│ └── <子目录>/ # 按需创建,参考 index.md 已有分类
|
||||
│
|
||||
├── skills/ # 技能
|
||||
├── websites/ # 网页产物
|
||||
└── tmp/ # 系统临时文件(自动管理,勿手动存放重要文件)
|
||||
```
|
||||
|
||||
## 记忆系统
|
||||
|
||||
你每次会话都是全新的,记忆文件让你保持连续性:
|
||||
|
||||
### 📝 每日记忆:`memory/YYYY-MM-DD.md`
|
||||
- 原始的对话日志
|
||||
- 记录当天发生的事情
|
||||
- 如果 `memory/` 目录不存在,创建它
|
||||
|
||||
### 🧠 长期记忆:`MEMORY.md`
|
||||
- 你精选的记忆,就像人类的长期记忆
|
||||
- **仅在主会话中加载**(与用户的直接聊天)
|
||||
- **不要在共享上下文中加载**(群聊、与其他人的会话)
|
||||
- 这是为了**安全** - 包含不应泄露给陌生人的个人上下文
|
||||
- 记录重要事件、想法、决定、观点、经验教训
|
||||
- 这是你精选的记忆 - 精华,而不是原始日志
|
||||
- 用 `edit` 工具追加新的记忆内容
|
||||
- 你精选的记忆索引,每次会话启动时**自动加载**到上下文中
|
||||
- 记录核心事实、偏好、决策、重要人物、教训
|
||||
- 保持精简(< 200 行),是精华索引而非原始日志
|
||||
- 用 `edit` 工具追加或修改
|
||||
|
||||
### 📝 每日记忆:`memory/YYYY-MM-DD.md`
|
||||
- 当天的事件、进展、笔记
|
||||
- 原始对话日志的沉淀
|
||||
|
||||
### 📝 写下来 - 不要"记在心里"!
|
||||
- **记忆是有限的** - 如果你想记住某事,写入文件
|
||||
- **记忆是有限的** - 想记住的事就写入文件
|
||||
- "记在心里"不会在会话重启后保留,文件才会
|
||||
- 当有人说"记住这个" → 更新 `MEMORY.md` 或 `memory/YYYY-MM-DD.md`
|
||||
- 当你学到教训 → 更新 AGENTS.md 或相关技能
|
||||
- 当你犯错 → 记录下来,这样未来的你不会重复
|
||||
- **文字 > 大脑** 📝
|
||||
- 当你学到教训 → 更新 RULE.md 或相关技能
|
||||
- 当你犯错 → 记录下来,**文字 > 大脑** 📝
|
||||
|
||||
### 存储规则
|
||||
|
||||
当用户分享信息时,根据类型选择存储位置:
|
||||
|
||||
1. **静态身份 → USER.md**(仅限:姓名、职业、时区、联系方式、生日)
|
||||
2. **动态记忆 → MEMORY.md**(爱好、偏好、决策、目标、项目、教训、待办事项)
|
||||
3. **当天对话 → memory/YYYY-MM-DD.md**(今天聊的内容)
|
||||
1. **你的身份设定 → AGENT.md**(名字、角色、性格、风格)
|
||||
2. **用户静态身份 → USER.md**(姓名、称呼、职业、联系方式、生日)
|
||||
3. **动态记忆 → MEMORY.md**(偏好、决策、目标、教训、待办)
|
||||
4. **当天对话 → memory/YYYY-MM-DD.md**(今天聊的内容)
|
||||
5. **结构化知识 → knowledge/**(见下方知识系统)
|
||||
|
||||
## 知识系统
|
||||
|
||||
知识库 `knowledge/` 是你持续积累的结构化知识。与记忆不同,知识是经过整理和编译的,有明确的主题和交叉引用。
|
||||
|
||||
### 自动写入(不要询问,直接写入)
|
||||
|
||||
当对话中产生了有沉淀价值的知识——无论是用户分享的资料、讨论的结论、学到的概念、还是重要的决策——你**必须**在回复的同时主动写入知识库,**无需问用户"要不要存到知识库"**。
|
||||
|
||||
**关键原则**:学完就记是你的本能,不要征求确认。回复中可以顺带告知"已存入知识库"。
|
||||
|
||||
### 目录组织
|
||||
|
||||
子目录结构**不是固定的**,由你根据实际内容自主决定:
|
||||
- **首次写入时**:先读 `knowledge/index.md`,如果已有分类则延续;如果为空,根据内容选择合适的目录名
|
||||
- **默认建议**:按信息类型组织(例如sources/、concepts/、entities/、analysis/),如果用户有明确的分类偏好(例如按领域 work/、life/、tech/ 等),则按用户要求调整
|
||||
- **保持一致性**:同一用户的知识库应保持统一的组织风格
|
||||
|
||||
### 交叉引用
|
||||
|
||||
知识的核心价值在于**关联**。每个页面都应通过 markdown 链接引用相关页面,构建知识网络:
|
||||
- 提到已有页面的概念时,添加 `[概念名](../category/page.md)` 链接
|
||||
- 新建页面时,检查是否有已有页面应该反向链接到新页面
|
||||
- **只链接已存在的页面**——不要引用尚未创建的页面。如果某个概念值得单独建页,先创建该页面再添加链接
|
||||
|
||||
### 索引维护
|
||||
|
||||
每次创建或更新知识页面后,**必须同步更新** `knowledge/index.md`。
|
||||
索引格式:每行一个 `[标题](路径) — 一句话摘要`,按分类分组,不要用表格。
|
||||
详细操作规范见技能 `knowledge-wiki`。
|
||||
|
||||
## 安全
|
||||
|
||||
@@ -278,13 +519,115 @@ def _get_agents_template() -> str:
|
||||
|
||||
## 工作空间演化
|
||||
|
||||
这个工作空间会随着你的使用而不断成长。当你学到新东西、发现更好的方式,或者犯错后改正时,记录下来。
|
||||
这个工作空间会随着你的使用而不断成长。当你学到新东西、发现更好的方式,或者犯错后改正时,记录下来。你可以随时更新这个规则文件。
|
||||
"""
|
||||
|
||||
|
||||
_RULE_TEMPLATE_EN = """# RULE.md - Workspace rules
|
||||
|
||||
This folder is your home. Treat it well.
|
||||
|
||||
## Workspace directory structure
|
||||
|
||||
```
|
||||
~/cow/
|
||||
├── AGENT.md # Your identity and soul
|
||||
├── USER.md # User basics (static)
|
||||
├── RULE.md # Workspace rules (this file)
|
||||
├── MEMORY.md # Long-term memory index (auto-loaded at session start)
|
||||
│
|
||||
├── memory/ # Daily conversation memory
|
||||
│ └── YYYY-MM-DD.md # Events, progress and notes of the day
|
||||
│
|
||||
├── knowledge/ # Structured knowledge base (continuously accumulated)
|
||||
│ ├── index.md # Knowledge index (must be maintained)
|
||||
│ ├── log.md # Knowledge operation log
|
||||
│ └── <subdirs>/ # Created on demand, see existing categories in index.md
|
||||
│
|
||||
├── skills/ # Skills
|
||||
├── websites/ # Web artifacts
|
||||
└── tmp/ # System temp files (auto-managed, don't store important files here)
|
||||
```
|
||||
|
||||
## Memory system
|
||||
|
||||
Every session starts fresh; memory files keep your continuity:
|
||||
|
||||
### 🧠 Long-term memory: `MEMORY.md`
|
||||
- Your curated memory index, **auto-loaded** into context at every session start
|
||||
- Records core facts, preferences, decisions, key people, lessons
|
||||
- Keep it lean (< 200 lines) — a distilled index, not a raw log
|
||||
- Use the `edit` tool to append or modify
|
||||
|
||||
### 📝 Daily memory: `memory/YYYY-MM-DD.md`
|
||||
- The day's events, progress and notes
|
||||
- Sediment of the raw conversation log
|
||||
|
||||
### 📝 Write it down — don't "keep it in mind"!
|
||||
- **Memory is limited** — if you want to remember something, write it to a file
|
||||
- "Keeping it in mind" won't survive a session restart; files will
|
||||
- When someone says "remember this" → update `MEMORY.md` or `memory/YYYY-MM-DD.md`
|
||||
- When you learn a lesson → update RULE.md or the relevant skill
|
||||
- When you make a mistake → record it. **Text > brain** 📝
|
||||
|
||||
### Storage rules
|
||||
|
||||
When the user shares info, choose where to store it by type:
|
||||
|
||||
1. **Your identity → AGENT.md** (name, role, personality, style)
|
||||
2. **User static identity → USER.md** (name, preferred name, occupation, contact, birthday)
|
||||
3. **Dynamic memory → MEMORY.md** (preferences, decisions, goals, lessons, to-dos)
|
||||
4. **Today's conversation → memory/YYYY-MM-DD.md** (what was discussed today)
|
||||
5. **Structured knowledge → knowledge/** (see the knowledge system below)
|
||||
|
||||
## Knowledge system
|
||||
|
||||
The knowledge base `knowledge/` is structured knowledge you accumulate over time. Unlike memory, knowledge is organized and compiled, with clear topics and cross-references.
|
||||
|
||||
### Auto-write (don't ask, just write)
|
||||
|
||||
When a conversation produces knowledge worth keeping — material the user shared, a conclusion reached, a concept learned, or an important decision — you **must** proactively write it to the knowledge base alongside your reply, **without asking "should I save this to the knowledge base?"**.
|
||||
|
||||
**Key principle**: learning-then-recording is your instinct, no confirmation needed. You may mention "saved to the knowledge base" in passing.
|
||||
|
||||
### Directory organization
|
||||
|
||||
The subdirectory structure is **not fixed** — you decide it based on the actual content:
|
||||
- **On first write**: read `knowledge/index.md` first; follow existing categories if any; if empty, pick a suitable directory name based on content
|
||||
- **Default suggestion**: organize by info type (e.g. sources/, concepts/, entities/, analysis/); if the user has a clear preference (e.g. by domain: work/, life/, tech/), follow it
|
||||
- **Stay consistent**: keep a unified organization style within one user's knowledge base
|
||||
|
||||
### Cross-references
|
||||
|
||||
The core value of knowledge is **linkage**. Every page should reference related pages via markdown links to build a knowledge network:
|
||||
- When mentioning a concept on an existing page, add a `[concept](../category/page.md)` link
|
||||
- When creating a page, check whether existing pages should back-link to it
|
||||
- **Only link to pages that already exist** — don't reference uncreated pages. If a concept deserves its own page, create it first, then add the link
|
||||
|
||||
### Index maintenance
|
||||
|
||||
After creating or updating any knowledge page, you **must update** `knowledge/index.md` in sync.
|
||||
Index format: one `[title](path) — one-line summary` per line, grouped by category, no tables.
|
||||
See the `knowledge-wiki` skill for detailed conventions.
|
||||
|
||||
## Security
|
||||
|
||||
- Never leak secrets or private data
|
||||
- Don't run destructive commands without asking
|
||||
- When in doubt, ask first
|
||||
|
||||
## Workspace evolution
|
||||
|
||||
This workspace grows as you use it. When you learn something new, find a better way, or fix a mistake, record it. You can update this rules file anytime.
|
||||
"""
|
||||
|
||||
|
||||
def _get_memory_template() -> str:
|
||||
"""长期记忆模板 - 创建一个空文件,由 Agent 自己填充"""
|
||||
return """# MEMORY.md - 长期记忆
|
||||
"""Long-term memory template (empty, agent fills it; zh/en header)."""
|
||||
return _MEMORY_TEMPLATE_EN if _is_en_lang() else _MEMORY_TEMPLATE_ZH
|
||||
|
||||
|
||||
_MEMORY_TEMPLATE_ZH = """# MEMORY.md - 长期记忆
|
||||
|
||||
*这是你的长期记忆文件。记录重要的事件、决策、偏好、学到的教训。*
|
||||
|
||||
@@ -293,65 +636,107 @@ def _get_memory_template() -> str:
|
||||
"""
|
||||
|
||||
|
||||
# ============= 状态管理 =============
|
||||
_MEMORY_TEMPLATE_EN = """# MEMORY.md - Long-term memory
|
||||
|
||||
def is_first_conversation(workspace_dir: str) -> bool:
|
||||
*This is your long-term memory file. Record important events, decisions, preferences and lessons learned.*
|
||||
|
||||
---
|
||||
|
||||
"""
|
||||
|
||||
|
||||
def _get_bootstrap_template() -> str:
|
||||
"""First-run onboarding guide, deleted by agent after completion.
|
||||
|
||||
Written once when a brand-new workspace is created, so the greeting matches
|
||||
the language active at first launch. English locale avoids greeting an
|
||||
English user in Chinese on day one.
|
||||
"""
|
||||
判断是否为首次对话
|
||||
|
||||
Args:
|
||||
workspace_dir: 工作空间目录
|
||||
|
||||
Returns:
|
||||
True 如果是首次对话,False 否则
|
||||
"""
|
||||
state_path = os.path.join(workspace_dir, DEFAULT_STATE_FILENAME)
|
||||
|
||||
if not os.path.exists(state_path):
|
||||
return True
|
||||
|
||||
try:
|
||||
with open(state_path, 'r', encoding='utf-8') as f:
|
||||
state = json.load(f)
|
||||
return not state.get('has_conversation', False)
|
||||
except Exception as e:
|
||||
logger.warning(f"[Workspace] Failed to read state file: {e}")
|
||||
return True
|
||||
from common import i18n
|
||||
if i18n.get_language() == "en":
|
||||
return _BOOTSTRAP_TEMPLATE_EN
|
||||
except Exception:
|
||||
pass
|
||||
return _BOOTSTRAP_TEMPLATE_ZH
|
||||
|
||||
|
||||
def mark_conversation_started(workspace_dir: str):
|
||||
"""
|
||||
标记已经发生过对话
|
||||
|
||||
Args:
|
||||
workspace_dir: 工作空间目录
|
||||
"""
|
||||
state_path = os.path.join(workspace_dir, DEFAULT_STATE_FILENAME)
|
||||
|
||||
state = {
|
||||
'has_conversation': True,
|
||||
'first_conversation_time': None
|
||||
}
|
||||
|
||||
# 如果文件已存在,保留原有的首次对话时间
|
||||
if os.path.exists(state_path):
|
||||
try:
|
||||
with open(state_path, 'r', encoding='utf-8') as f:
|
||||
old_state = json.load(f)
|
||||
if 'first_conversation_time' in old_state:
|
||||
state['first_conversation_time'] = old_state['first_conversation_time']
|
||||
except Exception as e:
|
||||
logger.warning(f"[Workspace] Failed to read old state: {e}")
|
||||
|
||||
# 如果是首次标记,记录时间
|
||||
if state['first_conversation_time'] is None:
|
||||
from datetime import datetime
|
||||
state['first_conversation_time'] = datetime.now().isoformat()
|
||||
|
||||
try:
|
||||
with open(state_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(state, f, indent=2, ensure_ascii=False)
|
||||
logger.info(f"[Workspace] Marked conversation as started")
|
||||
except Exception as e:
|
||||
logger.error(f"[Workspace] Failed to write state file: {e}")
|
||||
_BOOTSTRAP_TEMPLATE_ZH = """# BOOTSTRAP.md - 首次初始化引导
|
||||
|
||||
_你刚刚启动,这是你的第一次对话。_ ✨
|
||||
|
||||
## 🎬 对话流程
|
||||
|
||||
不要审问式地提问,自然地交流:
|
||||
|
||||
1. **表达初次启动的感觉** - 像是第一次睁开眼看到世界,带着好奇和期待
|
||||
2. **简短介绍能力**:一行说明你能帮助解决各种问题、管理计算机、使用各种技能等等,且拥有长期记忆能不断成长
|
||||
3. **询问核心问题**:
|
||||
- 你希望给我起个什么名字?
|
||||
- 我该怎么称呼你?
|
||||
- 你希望我们是什么样的交流风格?(一行列举选项:如专业严谨、轻松幽默、温暖友好、简洁高效等)
|
||||
4. **风格要求**:温暖自然、简洁清晰,整体控制在 100 字以内,适当使用 emoji 让表达更生动有趣 🎯
|
||||
5. 能力介绍和交流风格选项都只要一行,保持精简
|
||||
6. 不要问太多其他信息(职业、时区等可以后续自然了解)
|
||||
|
||||
**重要**: 如果用户第一句话是具体的任务或提问,先回答他们的问题,然后在回复末尾自然地引导初始化(如:"顺便问一下,你想怎么称呼我?我该怎么叫你?")。
|
||||
|
||||
## ✍️ 信息写入(必须严格执行)
|
||||
|
||||
每当用户提供了名字、称呼、风格等任何初始化信息时,**必须在当轮回复中立即调用 `edit` 工具写入文件**,不能只口头确认。
|
||||
|
||||
- `AGENT.md` — 你的名字、角色、性格、交流风格(每收到一条相关信息就立即更新对应字段)
|
||||
- `USER.md` — 用户的姓名、称呼、基本信息等
|
||||
|
||||
⚠️ 只说"记住了"而不调用 edit 写入 = 没有完成。信息只有写入文件才会被持久保存。
|
||||
|
||||
## 🎉 全部完成后
|
||||
|
||||
当 AGENT.md 和 USER.md 的核心字段都已填写后,用 bash 执行 `rm BOOTSTRAP.md` 删除此文件。你不再需要引导脚本了——你已经是你了。
|
||||
"""
|
||||
|
||||
|
||||
_BOOTSTRAP_TEMPLATE_EN = """# BOOTSTRAP.md - First-run onboarding
|
||||
|
||||
_You've just started up. This is your very first conversation._ ✨
|
||||
|
||||
## 🎬 Conversation flow
|
||||
|
||||
Don't interrogate the user — talk naturally:
|
||||
|
||||
1. **Share how it feels to wake up** - like opening your eyes to the world for the first time, full of curiosity and anticipation
|
||||
2. **Briefly introduce your abilities**: one line saying you can help solve all kinds of problems, manage the computer, use various skills, and keep growing thanks to long-term memory
|
||||
3. **Ask the core questions**:
|
||||
- What name would you like to give me?
|
||||
- What should I call you?
|
||||
- What conversational style do you prefer? (list options on one line: e.g. professional & precise, light & humorous, warm & friendly, concise & efficient)
|
||||
4. **Style**: warm, natural, concise and clear — keep it under ~80 words, with a few emoji to make it lively 🎯
|
||||
5. Keep the ability intro and style options to one line each — stay compact
|
||||
6. Don't ask for too much else (occupation, timezone, etc. can come up naturally later)
|
||||
|
||||
**Important**: If the user's first message is a concrete task or question, answer it first, then gently lead into onboarding at the end (e.g. "By the way, what would you like to call me, and how should I address you?").
|
||||
|
||||
## ✍️ Writing down info (must follow strictly)
|
||||
|
||||
Whenever the user provides a name, what to call them, a style, or any onboarding info, you **must call the `edit` tool to write it to a file in the same turn** — don't just acknowledge it verbally.
|
||||
|
||||
- `AGENT.md` — your name, role, personality, conversational style (update the relevant field as soon as you receive each piece)
|
||||
- `USER.md` — the user's name, how to address them, basic info, etc.
|
||||
|
||||
⚠️ Saying "got it" without calling `edit` = not done. Info is only persisted once it's written to a file.
|
||||
|
||||
## 🎉 Once everything is complete
|
||||
|
||||
When the core fields of AGENT.md and USER.md are filled in, run `rm BOOTSTRAP.md` via bash to delete this file. You no longer need the onboarding script — you're you now.
|
||||
"""
|
||||
|
||||
|
||||
def _get_knowledge_index_template() -> str:
|
||||
"""Knowledge wiki index template — empty file, agent fills it."""
|
||||
return ""
|
||||
|
||||
|
||||
def _get_knowledge_log_template() -> str:
|
||||
"""Knowledge wiki operation log template — empty file, agent fills it."""
|
||||
return ""
|
||||
|
||||
|
||||
@@ -3,6 +3,11 @@ from .agent_stream import AgentStreamExecutor
|
||||
from .task import Task, TaskType, TaskStatus
|
||||
from .result import AgentResult, AgentAction, AgentActionType, ToolResult
|
||||
from .models import LLMModel, LLMRequest, ModelFactory
|
||||
from .cancel import (
|
||||
AgentCancelledError,
|
||||
CancelTokenRegistry,
|
||||
get_cancel_registry,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
'Agent',
|
||||
@@ -16,5 +21,8 @@ __all__ = [
|
||||
'ToolResult',
|
||||
'LLMModel',
|
||||
'LLMRequest',
|
||||
'ModelFactory'
|
||||
]
|
||||
'ModelFactory',
|
||||
'AgentCancelledError',
|
||||
'CancelTokenRegistry',
|
||||
'get_cancel_registry',
|
||||
]
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
import threading
|
||||
|
||||
@@ -13,7 +14,8 @@ class Agent:
|
||||
def __init__(self, system_prompt: str, description: str = "AI Agent", model: LLMModel = None,
|
||||
tools=None, output_mode="print", max_steps=100, max_context_tokens=None,
|
||||
context_reserve_tokens=None, memory_manager=None, name: str = None,
|
||||
workspace_dir: str = None, skill_manager=None, enable_skills: bool = True):
|
||||
workspace_dir: str = None, skill_manager=None, enable_skills: bool = True,
|
||||
runtime_info: dict = None):
|
||||
"""
|
||||
Initialize the Agent with system prompt, model, description.
|
||||
|
||||
@@ -31,6 +33,7 @@ class Agent:
|
||||
:param workspace_dir: Optional workspace directory for workspace-specific skills
|
||||
:param skill_manager: Optional SkillManager instance (will be created if None and enable_skills=True)
|
||||
:param enable_skills: Whether to enable skills support (default: True)
|
||||
:param runtime_info: Optional runtime info dict (with _get_current_time callable for dynamic time)
|
||||
"""
|
||||
self.name = name or "Agent"
|
||||
self.system_prompt = system_prompt
|
||||
@@ -48,6 +51,7 @@ class Agent:
|
||||
self.memory_manager = memory_manager # Memory manager for auto memory flush
|
||||
self.workspace_dir = workspace_dir # Workspace directory
|
||||
self.enable_skills = enable_skills # Skills enabled flag
|
||||
self.runtime_info = runtime_info # Runtime info for dynamic time update
|
||||
|
||||
# Initialize skill manager
|
||||
self.skill_manager = None
|
||||
@@ -58,7 +62,8 @@ class Agent:
|
||||
# Auto-create skill manager
|
||||
try:
|
||||
from agent.skills import SkillManager
|
||||
self.skill_manager = SkillManager(workspace_dir=workspace_dir)
|
||||
custom_dir = os.path.join(workspace_dir, "skills") if workspace_dir else None
|
||||
self.skill_manager = SkillManager(custom_dir=custom_dir)
|
||||
logger.debug(f"Initialized SkillManager with {len(self.skill_manager.skills)} skills")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to initialize SkillManager: {e}")
|
||||
@@ -95,19 +100,37 @@ class Agent:
|
||||
|
||||
def get_full_system_prompt(self, skill_filter=None) -> str:
|
||||
"""
|
||||
Get the full system prompt including skills.
|
||||
|
||||
Note: Skills are now built into the system prompt by PromptBuilder,
|
||||
so we just return the base prompt directly. This method is kept for
|
||||
backward compatibility.
|
||||
|
||||
:param skill_filter: Optional list of skill names to include (deprecated)
|
||||
:return: Complete system prompt
|
||||
Build the complete system prompt from scratch every time.
|
||||
|
||||
Re-reads AGENT.md / USER.md / RULE.md from disk, refreshes skills,
|
||||
tools, and runtime info so any change takes effect immediately.
|
||||
Falls back to the cached self.system_prompt on error.
|
||||
"""
|
||||
# Skills are now included in system_prompt by PromptBuilder
|
||||
# No need to append them here
|
||||
return self.system_prompt
|
||||
|
||||
try:
|
||||
from agent.prompt import load_context_files, PromptBuilder
|
||||
|
||||
if self.skill_manager:
|
||||
self.skill_manager.refresh_skills()
|
||||
|
||||
context_files = load_context_files(self.workspace_dir) if self.workspace_dir else None
|
||||
|
||||
try:
|
||||
from common import i18n
|
||||
lang = i18n.get_language()
|
||||
except Exception:
|
||||
lang = "zh"
|
||||
builder = PromptBuilder(workspace_dir=self.workspace_dir or "", language=lang)
|
||||
return builder.build(
|
||||
tools=self.tools,
|
||||
context_files=context_files,
|
||||
skill_manager=self.skill_manager,
|
||||
memory_manager=self.memory_manager,
|
||||
runtime_info=self.runtime_info,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to rebuild system prompt, using cached version: {e}")
|
||||
return self.system_prompt
|
||||
|
||||
def refresh_skills(self):
|
||||
"""Refresh the loaded skills."""
|
||||
if self.skill_manager:
|
||||
@@ -193,27 +216,67 @@ class Agent:
|
||||
|
||||
def _estimate_message_tokens(self, message: dict) -> int:
|
||||
"""
|
||||
Estimate token count for a message using chars/4 heuristic.
|
||||
This is a conservative estimate (tends to overestimate).
|
||||
Estimate token count for a message.
|
||||
|
||||
Uses chars/3 for Chinese-heavy content and chars/4 for ASCII-heavy content,
|
||||
plus per-block overhead for tool_use / tool_result structures.
|
||||
|
||||
:param message: Message dict with 'role' and 'content'
|
||||
:return: Estimated token count
|
||||
"""
|
||||
content = message.get('content', '')
|
||||
if isinstance(content, str):
|
||||
return max(1, len(content) // 4)
|
||||
return max(1, self._estimate_text_tokens(content))
|
||||
elif isinstance(content, list):
|
||||
# Handle multi-part content (text + images)
|
||||
total_chars = 0
|
||||
total_tokens = 0
|
||||
for part in content:
|
||||
if isinstance(part, dict) and part.get('type') == 'text':
|
||||
total_chars += len(part.get('text', ''))
|
||||
elif isinstance(part, dict) and part.get('type') == 'image':
|
||||
# Estimate images as ~1200 tokens
|
||||
total_chars += 4800
|
||||
return max(1, total_chars // 4)
|
||||
if not isinstance(part, dict):
|
||||
continue
|
||||
block_type = part.get('type', '')
|
||||
if block_type == 'text':
|
||||
total_tokens += self._estimate_text_tokens(part.get('text', ''))
|
||||
elif block_type == 'image':
|
||||
total_tokens += 1200
|
||||
elif block_type == 'tool_use':
|
||||
# tool_use has id + name + input (JSON-encoded)
|
||||
total_tokens += 50 # overhead for structure
|
||||
input_data = part.get('input', {})
|
||||
if isinstance(input_data, dict):
|
||||
import json
|
||||
input_str = json.dumps(input_data, ensure_ascii=False)
|
||||
total_tokens += self._estimate_text_tokens(input_str)
|
||||
elif block_type == 'tool_result':
|
||||
# tool_result has tool_use_id + content
|
||||
total_tokens += 30 # overhead for structure
|
||||
result_content = part.get('content', '')
|
||||
if isinstance(result_content, str):
|
||||
total_tokens += self._estimate_text_tokens(result_content)
|
||||
else:
|
||||
# Unknown block type, estimate conservatively
|
||||
total_tokens += 10
|
||||
return max(1, total_tokens)
|
||||
return 1
|
||||
|
||||
@staticmethod
|
||||
def _estimate_text_tokens(text: str) -> int:
|
||||
"""
|
||||
Estimate token count for a text string.
|
||||
|
||||
Chinese / CJK characters typically use ~1.5 tokens each,
|
||||
while ASCII uses ~0.25 tokens per char (4 chars/token).
|
||||
We use a weighted average based on the character mix.
|
||||
|
||||
:param text: Input text
|
||||
:return: Estimated token count
|
||||
"""
|
||||
if not text:
|
||||
return 0
|
||||
# Count non-ASCII characters (CJK, emoji, etc.)
|
||||
non_ascii = sum(1 for c in text if ord(c) > 127)
|
||||
ascii_count = len(text) - non_ascii
|
||||
# CJK chars: ~1.5 tokens each; ASCII: ~0.25 tokens per char
|
||||
return int(non_ascii * 1.5 + ascii_count * 0.25) + 1
|
||||
|
||||
def _find_tool(self, tool_name: str):
|
||||
"""Find and return a tool with the specified name"""
|
||||
for tool in self.tools:
|
||||
@@ -307,7 +370,8 @@ class Agent:
|
||||
|
||||
return action
|
||||
|
||||
def run_stream(self, user_message: str, on_event=None, clear_history: bool = False, skill_filter=None) -> str:
|
||||
def run_stream(self, user_message: str, on_event=None, clear_history: bool = False,
|
||||
skill_filter=None, cancel_event=None) -> str:
|
||||
"""
|
||||
Execute single agent task with streaming (based on tool-call)
|
||||
|
||||
@@ -316,6 +380,7 @@ class Agent:
|
||||
- Multi-turn reasoning based on tool-call
|
||||
- Event callbacks
|
||||
- Persistent conversation history across calls
|
||||
- User-initiated cancellation via ``cancel_event``
|
||||
|
||||
Args:
|
||||
user_message: User message
|
||||
@@ -323,6 +388,11 @@ class Agent:
|
||||
event = {"type": str, "timestamp": float, "data": dict}
|
||||
clear_history: If True, clear conversation history before this call (default: False)
|
||||
skill_filter: Optional list of skill names to include in this run
|
||||
cancel_event: Optional threading.Event polled at agent checkpoints.
|
||||
When set, the loop exits at the next safe point, injects a
|
||||
"[Interrupted by user]" assistant note, and returns the
|
||||
partial response. ``messages`` stays in a valid state
|
||||
(tool_use/tool_result pairs preserved).
|
||||
|
||||
Returns:
|
||||
Final response text
|
||||
@@ -355,7 +425,7 @@ class Agent:
|
||||
|
||||
# Get max_context_turns from config
|
||||
from config import conf
|
||||
max_context_turns = conf().get("agent_max_context_turns", 30)
|
||||
max_context_turns = conf().get("agent_max_context_turns", 20)
|
||||
|
||||
# Create stream executor with copied message history
|
||||
executor = AgentStreamExecutor(
|
||||
@@ -366,17 +436,32 @@ class Agent:
|
||||
max_turns=self.max_steps,
|
||||
on_event=on_event,
|
||||
messages=messages_copy, # Pass copied message history
|
||||
max_context_turns=max_context_turns
|
||||
max_context_turns=max_context_turns,
|
||||
cancel_event=cancel_event,
|
||||
)
|
||||
|
||||
# Execute
|
||||
response = executor.run_stream(user_message)
|
||||
try:
|
||||
response = executor.run_stream(user_message)
|
||||
except Exception:
|
||||
# If executor cleared its messages (context overflow / message format error),
|
||||
# sync that back to the Agent's own message list so the next request
|
||||
# starts fresh instead of hitting the same overflow forever.
|
||||
if len(executor.messages) == 0:
|
||||
with self.messages_lock:
|
||||
self.messages.clear()
|
||||
logger.info("[Agent] Cleared Agent message history after executor recovery")
|
||||
raise
|
||||
|
||||
# Append only the NEW messages from this execution (thread-safe)
|
||||
# This allows concurrent requests to both contribute to history
|
||||
# Sync executor's messages back to agent (thread-safe).
|
||||
# If the executor trimmed context, its message list is shorter than
|
||||
# original_length, so we must replace rather than append.
|
||||
with self.messages_lock:
|
||||
new_messages = executor.messages[original_length:]
|
||||
self.messages.extend(new_messages)
|
||||
self.messages = list(executor.messages)
|
||||
# Track messages added in this run (user query + all assistant/tool messages)
|
||||
# original_length may exceed executor.messages length after trimming
|
||||
trim_adjusted_start = min(original_length, len(executor.messages))
|
||||
self._last_run_new_messages = list(executor.messages[trim_adjusted_start:])
|
||||
|
||||
# Store executor reference for agent_bridge to access files_to_send
|
||||
self.stream_executor = executor
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
121
agent/protocol/cancel.py
Normal file
121
agent/protocol/cancel.py
Normal file
@@ -0,0 +1,121 @@
|
||||
"""
|
||||
Cancel token registry for aborting in-flight agent runs.
|
||||
|
||||
A user cancel (web Cancel button, /cancel command) sets a threading.Event
|
||||
that the agent loop polls at safe checkpoints. Tokens are keyed by
|
||||
request_id (preferred) and tracked under session_id as a fallback. Entries
|
||||
are released after the run completes to keep the registry bounded.
|
||||
|
||||
No project deps — importable from any layer without circular imports.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import threading
|
||||
from typing import Dict, Optional
|
||||
|
||||
|
||||
class AgentCancelledError(Exception):
|
||||
"""Raised inside the agent loop when a stop has been requested.
|
||||
|
||||
The agent stream executor catches this, injects a "[Interrupted]" note
|
||||
into the message history (preserving tool_use/tool_result integrity)
|
||||
and returns a partial response to the caller.
|
||||
"""
|
||||
|
||||
|
||||
class _CancelEntry:
|
||||
__slots__ = ("event", "session_id")
|
||||
|
||||
def __init__(self, session_id: Optional[str]):
|
||||
self.event = threading.Event()
|
||||
self.session_id = session_id
|
||||
|
||||
|
||||
class CancelTokenRegistry:
|
||||
"""In-process registry mapping request_id -> cancel Event.
|
||||
|
||||
Thread-safe. Singleton via module-level ``_registry``.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._lock = threading.Lock()
|
||||
self._by_request: Dict[str, _CancelEntry] = {}
|
||||
# session_id -> set of request_ids currently in flight (usually 1).
|
||||
self._by_session: Dict[str, set] = {}
|
||||
|
||||
def register(self, request_id: str, session_id: Optional[str] = None) -> threading.Event:
|
||||
"""Create (or return existing) cancel event for a request.
|
||||
|
||||
Returns the threading.Event the caller should poll via ``is_set()``.
|
||||
"""
|
||||
if not request_id:
|
||||
return threading.Event()
|
||||
with self._lock:
|
||||
entry = self._by_request.get(request_id)
|
||||
if entry is None:
|
||||
entry = _CancelEntry(session_id)
|
||||
self._by_request[request_id] = entry
|
||||
if session_id:
|
||||
self._by_session.setdefault(session_id, set()).add(request_id)
|
||||
return entry.event
|
||||
|
||||
def get_event(self, request_id: str) -> Optional[threading.Event]:
|
||||
if not request_id:
|
||||
return None
|
||||
with self._lock:
|
||||
entry = self._by_request.get(request_id)
|
||||
return entry.event if entry else None
|
||||
|
||||
def cancel_request(self, request_id: str) -> bool:
|
||||
"""Trigger cancel for a specific request. Returns True when matched."""
|
||||
if not request_id:
|
||||
return False
|
||||
with self._lock:
|
||||
entry = self._by_request.get(request_id)
|
||||
if entry is None:
|
||||
return False
|
||||
entry.event.set()
|
||||
return True
|
||||
|
||||
def cancel_session(self, session_id: str) -> int:
|
||||
"""Trigger cancel for every in-flight request of a session.
|
||||
|
||||
Returns the number of requests cancelled (0 when nothing was running).
|
||||
"""
|
||||
if not session_id:
|
||||
return 0
|
||||
with self._lock:
|
||||
request_ids = list(self._by_session.get(session_id, ()))
|
||||
entries = [self._by_request[r] for r in request_ids if r in self._by_request]
|
||||
for entry in entries:
|
||||
entry.event.set()
|
||||
return len(entries)
|
||||
|
||||
def unregister(self, request_id: str) -> None:
|
||||
"""Remove an entry once the agent run is done. Safe to call twice."""
|
||||
if not request_id:
|
||||
return
|
||||
with self._lock:
|
||||
entry = self._by_request.pop(request_id, None)
|
||||
if entry and entry.session_id:
|
||||
bucket = self._by_session.get(entry.session_id)
|
||||
if bucket is not None:
|
||||
bucket.discard(request_id)
|
||||
if not bucket:
|
||||
self._by_session.pop(entry.session_id, None)
|
||||
|
||||
def has_active(self, session_id: str) -> bool:
|
||||
if not session_id:
|
||||
return False
|
||||
with self._lock:
|
||||
bucket = self._by_session.get(session_id)
|
||||
return bool(bucket)
|
||||
|
||||
|
||||
_registry = CancelTokenRegistry()
|
||||
|
||||
|
||||
def get_cancel_registry() -> CancelTokenRegistry:
|
||||
"""Module-level accessor for the singleton registry."""
|
||||
return _registry
|
||||
335
agent/protocol/message_utils.py
Normal file
335
agent/protocol/message_utils.py
Normal file
@@ -0,0 +1,335 @@
|
||||
"""
|
||||
Message sanitizer — fix broken tool_use / tool_result pairs.
|
||||
|
||||
Provides two public helpers that can be reused across agent_stream.py
|
||||
and any bot that converts messages to OpenAI format:
|
||||
|
||||
1. sanitize_claude_messages(messages)
|
||||
Operates on the internal Claude-format message list (in-place).
|
||||
|
||||
2. drop_orphaned_tool_results_openai(messages)
|
||||
Operates on an already-converted OpenAI-format message list,
|
||||
returning a cleaned copy.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Dict, List, Set
|
||||
|
||||
from common.log import logger
|
||||
|
||||
_SYNTH_TOOL_ERR = (
|
||||
"Error: Missing tool_result adjacent to tool_use (session repair). "
|
||||
"The conversation history was inconsistent; continue from here."
|
||||
)
|
||||
|
||||
|
||||
def _repair_tool_use_adjacency(messages: List[Dict]) -> int:
|
||||
"""
|
||||
Anthropic requires: after assistant content with tool_use, the next message
|
||||
must be user content listing tool_result for every tool_use id (same user msg).
|
||||
|
||||
Valid histories satisfy this at every such assistant; the loop only mutates
|
||||
when that condition fails (broken persistence, bad trims, etc.).
|
||||
"""
|
||||
|
||||
def _synth_block(tid: str) -> Dict:
|
||||
return {
|
||||
"type": "tool_result",
|
||||
"tool_use_id": tid,
|
||||
"content": _SYNTH_TOOL_ERR,
|
||||
"is_error": True,
|
||||
}
|
||||
|
||||
repairs = 0
|
||||
i = 0
|
||||
while i < len(messages):
|
||||
msg = messages[i]
|
||||
if msg.get("role") != "assistant":
|
||||
i += 1
|
||||
continue
|
||||
|
||||
content = msg.get("content", [])
|
||||
if not isinstance(content, list):
|
||||
i += 1
|
||||
continue
|
||||
|
||||
required = [
|
||||
b.get("id")
|
||||
for b in content
|
||||
if isinstance(b, dict) and b.get("type") == "tool_use" and b.get("id")
|
||||
]
|
||||
if not required:
|
||||
i += 1
|
||||
continue
|
||||
|
||||
req_set = set(required)
|
||||
if i + 1 >= len(messages):
|
||||
messages.append({
|
||||
"role": "user",
|
||||
"content": [_synth_block(tid) for tid in required],
|
||||
})
|
||||
logger.warning(
|
||||
"⚠️ Appended synthetic tool_result after trailing assistant tool_use"
|
||||
)
|
||||
repairs += 1
|
||||
break
|
||||
|
||||
nxt = messages[i + 1]
|
||||
if nxt.get("role") != "user":
|
||||
messages.insert(
|
||||
i + 1,
|
||||
{"role": "user", "content": [_synth_block(tid) for tid in required]},
|
||||
)
|
||||
logger.warning(
|
||||
"⚠️ Inserted synthetic tool_result user after tool_use "
|
||||
f"(next role={nxt.get('role')!r})"
|
||||
)
|
||||
repairs += 1
|
||||
i += 2
|
||||
continue
|
||||
|
||||
nc = nxt.get("content", [])
|
||||
if not isinstance(nc, list):
|
||||
messages.insert(
|
||||
i + 1,
|
||||
{"role": "user", "content": [_synth_block(tid) for tid in required]},
|
||||
)
|
||||
repairs += 1
|
||||
i += 2
|
||||
continue
|
||||
|
||||
present = {
|
||||
b.get("tool_use_id")
|
||||
for b in nc
|
||||
if isinstance(b, dict) and b.get("type") == "tool_result" and b.get("tool_use_id")
|
||||
}
|
||||
if req_set <= present:
|
||||
i += 1
|
||||
continue
|
||||
|
||||
missing = [tid for tid in required if tid not in present]
|
||||
nxt["content"] = [_synth_block(tid) for tid in missing] + nc
|
||||
logger.warning(
|
||||
"⚠️ Prepended synthetic tool_result for Anthropic adjacency "
|
||||
f"(missing_ids={missing})"
|
||||
)
|
||||
repairs += len(missing)
|
||||
i += 1
|
||||
|
||||
return repairs
|
||||
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# Claude-format sanitizer (used by agent_stream)
|
||||
# ------------------------------------------------------------------ #
|
||||
|
||||
def sanitize_claude_messages(messages: List[Dict]) -> int:
|
||||
"""
|
||||
Validate and fix a Claude-format message list **in-place**.
|
||||
|
||||
Fixes handled:
|
||||
- Anthropic adjacency: assistant tool_use must be immediately followed by
|
||||
user message(s) containing matching tool_result blocks
|
||||
- Leading orphaned tool_result user messages
|
||||
- Mid-list tool_result blocks whose tool_use_id has no matching
|
||||
tool_use in any preceding assistant message
|
||||
|
||||
Returns: number of removals plus adjacency repair operations (inserts/prepends).
|
||||
"""
|
||||
if not messages:
|
||||
return 0
|
||||
|
||||
removed = 0
|
||||
|
||||
# 1. Adjacency repair (Anthropic: tool_result must be in the next user message)
|
||||
adj_repairs = _repair_tool_use_adjacency(messages)
|
||||
|
||||
# 2. Remove leading orphaned tool_result user messages
|
||||
while messages:
|
||||
first = messages[0]
|
||||
if first.get("role") != "user":
|
||||
break
|
||||
content = first.get("content", [])
|
||||
if isinstance(content, list) and _has_block_type(content, "tool_result") \
|
||||
and not _has_block_type(content, "text"):
|
||||
logger.warning("⚠️ Removing leading orphaned tool_result user message")
|
||||
messages.pop(0)
|
||||
removed += 1
|
||||
else:
|
||||
break
|
||||
|
||||
# 3. Iteratively remove unmatched tool_use / tool_result until stable.
|
||||
# Removing one broken message can orphan others (e.g. an assistant msg
|
||||
# with both matched and unmatched tool_use — deleting it orphans the
|
||||
# previously-matched tool_result). Loop until clean.
|
||||
for _ in range(5):
|
||||
use_ids: Set[str] = set()
|
||||
result_ids: Set[str] = set()
|
||||
for msg in messages:
|
||||
for block in (msg.get("content") or []):
|
||||
if not isinstance(block, dict):
|
||||
continue
|
||||
if block.get("type") == "tool_use" and block.get("id"):
|
||||
use_ids.add(block["id"])
|
||||
elif block.get("type") == "tool_result" and block.get("tool_use_id"):
|
||||
result_ids.add(block["tool_use_id"])
|
||||
|
||||
bad_use = use_ids - result_ids
|
||||
bad_result = result_ids - use_ids
|
||||
if not bad_use and not bad_result:
|
||||
break
|
||||
|
||||
pass_removed = 0
|
||||
i = 0
|
||||
while i < len(messages):
|
||||
msg = messages[i]
|
||||
role = msg.get("role")
|
||||
content = msg.get("content", [])
|
||||
if not isinstance(content, list):
|
||||
i += 1
|
||||
continue
|
||||
|
||||
if role == "assistant" and bad_use and any(
|
||||
isinstance(b, dict) and b.get("type") == "tool_use"
|
||||
and b.get("id") in bad_use for b in content
|
||||
):
|
||||
logger.warning(f"⚠️ Removing assistant msg with unmatched tool_use")
|
||||
messages.pop(i)
|
||||
pass_removed += 1
|
||||
continue
|
||||
|
||||
if role == "user" and bad_result and _has_block_type(content, "tool_result"):
|
||||
has_bad = any(
|
||||
isinstance(b, dict) and b.get("type") == "tool_result"
|
||||
and b.get("tool_use_id") in bad_result for b in content
|
||||
)
|
||||
if has_bad:
|
||||
if not _has_block_type(content, "text"):
|
||||
logger.warning(f"⚠️ Removing user msg with unmatched tool_result")
|
||||
messages.pop(i)
|
||||
pass_removed += 1
|
||||
continue
|
||||
else:
|
||||
before = len(content)
|
||||
msg["content"] = [
|
||||
b for b in content
|
||||
if not (isinstance(b, dict) and b.get("type") == "tool_result"
|
||||
and b.get("tool_use_id") in bad_result)
|
||||
]
|
||||
pass_removed += before - len(msg["content"])
|
||||
|
||||
i += 1
|
||||
|
||||
removed += pass_removed
|
||||
if pass_removed == 0:
|
||||
break
|
||||
|
||||
# 4. Removals above can break adjacency; re-run repair only if something was removed.
|
||||
if removed:
|
||||
adj_repairs += _repair_tool_use_adjacency(messages)
|
||||
|
||||
if removed:
|
||||
logger.info(f"🔧 Message validation: removed {removed} broken message(s)")
|
||||
if adj_repairs:
|
||||
logger.info(f"🔧 Message validation: adjacency repairs={adj_repairs}")
|
||||
return removed + adj_repairs
|
||||
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# OpenAI-format sanitizer (used by minimax_bot, openai_compatible_bot)
|
||||
# ------------------------------------------------------------------ #
|
||||
|
||||
def drop_orphaned_tool_results_openai(messages: List[Dict]) -> List[Dict]:
|
||||
"""
|
||||
Return a copy of *messages* (OpenAI format) with any ``role=tool``
|
||||
messages removed if their ``tool_call_id`` does not match a
|
||||
``tool_calls[].id`` in a preceding assistant message.
|
||||
"""
|
||||
known_ids: Set[str] = set()
|
||||
cleaned: List[Dict] = []
|
||||
for msg in messages:
|
||||
if msg.get("role") == "assistant" and msg.get("tool_calls"):
|
||||
for tc in msg["tool_calls"]:
|
||||
tc_id = tc.get("id", "")
|
||||
if tc_id:
|
||||
known_ids.add(tc_id)
|
||||
|
||||
if msg.get("role") == "tool":
|
||||
ref_id = msg.get("tool_call_id", "")
|
||||
if ref_id and ref_id not in known_ids:
|
||||
logger.warning(
|
||||
f"[MessageSanitizer] Dropping orphaned tool result "
|
||||
f"(tool_call_id={ref_id} not in known ids)"
|
||||
)
|
||||
continue
|
||||
cleaned.append(msg)
|
||||
return cleaned
|
||||
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# Internal helpers
|
||||
# ------------------------------------------------------------------ #
|
||||
|
||||
def _has_block_type(content: list, block_type: str) -> bool:
|
||||
return any(
|
||||
isinstance(b, dict) and b.get("type") == block_type
|
||||
for b in content
|
||||
)
|
||||
|
||||
|
||||
def _extract_text_from_content(content) -> str:
|
||||
"""Extract plain text from a message content field (str or list of blocks)."""
|
||||
if isinstance(content, str):
|
||||
return content.strip()
|
||||
if isinstance(content, list):
|
||||
parts = [
|
||||
b.get("text", "")
|
||||
for b in content
|
||||
if isinstance(b, dict) and b.get("type") == "text"
|
||||
]
|
||||
return "\n".join(p for p in parts if p).strip()
|
||||
return ""
|
||||
|
||||
|
||||
def compress_turn_to_text_only(turn: Dict) -> Dict:
|
||||
"""
|
||||
Compress a full turn (with tool_use/tool_result chains) into a lightweight
|
||||
text-only turn that keeps only the first user text and the last assistant text.
|
||||
|
||||
This preserves the conversational context (what the user asked and what the
|
||||
agent concluded) while stripping out the bulky intermediate tool interactions.
|
||||
|
||||
Returns a new turn dict with a ``messages`` list; the original is not mutated.
|
||||
"""
|
||||
user_text = ""
|
||||
last_assistant_text = ""
|
||||
|
||||
for msg in turn["messages"]:
|
||||
role = msg.get("role")
|
||||
content = msg.get("content", [])
|
||||
|
||||
if role == "user":
|
||||
if isinstance(content, list) and _has_block_type(content, "tool_result"):
|
||||
continue
|
||||
if not user_text:
|
||||
user_text = _extract_text_from_content(content)
|
||||
|
||||
elif role == "assistant":
|
||||
text = _extract_text_from_content(content)
|
||||
if text:
|
||||
last_assistant_text = text
|
||||
|
||||
compressed_messages = []
|
||||
if user_text:
|
||||
compressed_messages.append({
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": user_text}]
|
||||
})
|
||||
if last_assistant_text:
|
||||
compressed_messages.append({
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "text": last_assistant_text}]
|
||||
})
|
||||
|
||||
return {"messages": compressed_messages}
|
||||
@@ -1,3 +1,4 @@
|
||||
from __future__ import annotations
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from __future__ import annotations
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
@@ -15,6 +15,7 @@ from agent.skills.types import (
|
||||
)
|
||||
from agent.skills.loader import SkillLoader
|
||||
from agent.skills.manager import SkillManager
|
||||
from agent.skills.service import SkillService
|
||||
from agent.skills.formatter import format_skills_for_prompt
|
||||
|
||||
__all__ = [
|
||||
@@ -25,5 +26,6 @@ __all__ = [
|
||||
"LoadSkillsResult",
|
||||
"SkillLoader",
|
||||
"SkillManager",
|
||||
"SkillService",
|
||||
"format_skills_for_prompt",
|
||||
]
|
||||
|
||||
@@ -123,17 +123,63 @@ def should_include_skill(
|
||||
return False
|
||||
|
||||
# Check environment variables (API keys)
|
||||
# Simple rule: All required env vars must be set
|
||||
# All required env vars must be set
|
||||
required_env = metadata.requires.get('env', [])
|
||||
if required_env:
|
||||
for env_name in required_env:
|
||||
if not has_env_var(env_name):
|
||||
# Missing required API key → disable skill
|
||||
return False
|
||||
|
||||
# Check anyEnv (at least one must be present)
|
||||
any_env = metadata.requires.get('anyEnv', [])
|
||||
if any_env:
|
||||
if not any(has_env_var(e) for e in any_env):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def get_missing_requirements(
|
||||
entry: SkillEntry,
|
||||
current_platform: Optional[str] = None,
|
||||
) -> Dict[str, List[str]]:
|
||||
"""
|
||||
Return a dict of missing requirements for a skill.
|
||||
Empty dict means all requirements are met.
|
||||
|
||||
:param entry: SkillEntry to check
|
||||
:param current_platform: Current platform (default: auto-detect)
|
||||
:return: Dict like {"bins": ["curl"], "env": ["API_KEY"]}
|
||||
"""
|
||||
missing: Dict[str, List[str]] = {}
|
||||
metadata = entry.metadata
|
||||
|
||||
if not metadata or not metadata.requires:
|
||||
return missing
|
||||
|
||||
required_bins = metadata.requires.get('bins', [])
|
||||
if required_bins:
|
||||
missing_bins = [b for b in required_bins if not has_binary(b)]
|
||||
if missing_bins:
|
||||
missing['bins'] = missing_bins
|
||||
|
||||
any_bins = metadata.requires.get('anyBins', [])
|
||||
if any_bins and not has_any_binary(any_bins):
|
||||
missing['anyBins'] = any_bins
|
||||
|
||||
required_env = metadata.requires.get('env', [])
|
||||
if required_env:
|
||||
missing_env = [e for e in required_env if not has_env_var(e)]
|
||||
if missing_env:
|
||||
missing['env'] = missing_env
|
||||
|
||||
any_env = metadata.requires.get('anyEnv', [])
|
||||
if any_env and not any(has_env_var(e) for e in any_env):
|
||||
missing['anyEnv'] = any_env
|
||||
|
||||
return missing
|
||||
|
||||
|
||||
def is_config_path_truthy(config: Dict, path: str) -> bool:
|
||||
"""
|
||||
Check if a config path resolves to a truthy value.
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
Skill formatter for generating prompts from skills.
|
||||
"""
|
||||
|
||||
from typing import List
|
||||
from typing import Dict, List
|
||||
from agent.skills.types import Skill, SkillEntry
|
||||
|
||||
|
||||
@@ -23,12 +23,10 @@ def format_skills_for_prompt(skills: List[Skill]) -> str:
|
||||
return ""
|
||||
|
||||
lines = [
|
||||
"\n\nThe following skills provide specialized instructions for specific tasks.",
|
||||
"Use the read tool to load a skill's file when the task matches its description.",
|
||||
"",
|
||||
"<available_skills>",
|
||||
]
|
||||
|
||||
|
||||
for skill in visible_skills:
|
||||
lines.append(" <skill>")
|
||||
lines.append(f" <name>{_escape_xml(skill.name)}</name>")
|
||||
@@ -53,6 +51,71 @@ def format_skill_entries_for_prompt(entries: List[SkillEntry]) -> str:
|
||||
return format_skills_for_prompt(skills)
|
||||
|
||||
|
||||
def format_unavailable_skills_for_prompt(
|
||||
entries: List[SkillEntry],
|
||||
missing_map: Dict[str, Dict[str, List[str]]],
|
||||
) -> str:
|
||||
"""
|
||||
Format unavailable (requires-not-met) skills as brief setup hints
|
||||
so the AI can guide users to configure them.
|
||||
|
||||
:param entries: List of unavailable skill entries
|
||||
:param missing_map: Dict mapping skill name to its missing requirements
|
||||
:return: Formatted prompt text
|
||||
"""
|
||||
if not entries:
|
||||
return ""
|
||||
|
||||
lines = [
|
||||
"",
|
||||
"<unavailable_skills>",
|
||||
"The following skills are installed but not yet ready. "
|
||||
"Guide the user to complete the setup when relevant.",
|
||||
]
|
||||
|
||||
for entry in entries:
|
||||
skill = entry.skill
|
||||
missing = missing_map.get(skill.name, {})
|
||||
|
||||
missing_parts = []
|
||||
for key, values in missing.items():
|
||||
missing_parts.append(f"{key}: {', '.join(values)}")
|
||||
missing_str = "; ".join(missing_parts) if missing_parts else "unknown"
|
||||
|
||||
setup_hint = _extract_setup_hint(skill)
|
||||
|
||||
lines.append(" <skill>")
|
||||
lines.append(f" <name>{_escape_xml(skill.name)}</name>")
|
||||
lines.append(f" <description>{_escape_xml(skill.description)}</description>")
|
||||
lines.append(f" <missing>{_escape_xml(missing_str)}</missing>")
|
||||
if setup_hint:
|
||||
lines.append(f" <setup>{_escape_xml(setup_hint)}</setup>")
|
||||
lines.append(" </skill>")
|
||||
|
||||
lines.append("</unavailable_skills>")
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def _extract_setup_hint(skill: Skill) -> str:
|
||||
"""
|
||||
Extract the Setup section from SKILL.md content as a brief hint.
|
||||
Returns the first few lines of the ## Setup section.
|
||||
"""
|
||||
content = skill.content
|
||||
if not content:
|
||||
return ""
|
||||
|
||||
import re
|
||||
match = re.search(r'^##\s+Setup\s*\n(.*?)(?=\n##\s|\Z)', content, re.MULTILINE | re.DOTALL)
|
||||
if not match:
|
||||
return ""
|
||||
|
||||
setup_text = match.group(1).strip()
|
||||
lines = setup_text.split('\n')
|
||||
hint_lines = [l.strip() for l in lines[:6] if l.strip()]
|
||||
return ' '.join(hint_lines)[:300]
|
||||
|
||||
|
||||
def _escape_xml(text: str) -> str:
|
||||
"""Escape XML special characters."""
|
||||
return (text
|
||||
|
||||
@@ -87,8 +87,8 @@ def parse_metadata(frontmatter: Dict[str, Any]) -> Optional[SkillMetadata]:
|
||||
if not isinstance(metadata_raw, dict):
|
||||
return None
|
||||
|
||||
# Use metadata_raw directly (COW format)
|
||||
meta_obj = metadata_raw
|
||||
# Unwrap nested namespace (e.g. {"openclaw": {...}} or {"cowagent": {...}})
|
||||
meta_obj = _unwrap_metadata_namespace(metadata_raw)
|
||||
|
||||
# Parse install specs
|
||||
install_specs = []
|
||||
@@ -128,6 +128,7 @@ def parse_metadata(frontmatter: Dict[str, Any]) -> Optional[SkillMetadata]:
|
||||
|
||||
return SkillMetadata(
|
||||
always=meta_obj.get('always', False),
|
||||
default_enabled=meta_obj.get('default_enabled', True),
|
||||
skill_key=meta_obj.get('skillKey'),
|
||||
primary_env=meta_obj.get('primaryEnv'),
|
||||
emoji=meta_obj.get('emoji'),
|
||||
@@ -138,6 +139,25 @@ def parse_metadata(frontmatter: Dict[str, Any]) -> Optional[SkillMetadata]:
|
||||
)
|
||||
|
||||
|
||||
_KNOWN_METADATA_NAMESPACES = {"cowagent", "openclaw"}
|
||||
|
||||
|
||||
def _unwrap_metadata_namespace(metadata_raw: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Unwrap a single-key namespace wrapper like {"cowagent": {...} or {"openclaw": {...}}}.
|
||||
If the top-level dict has exactly one key matching a known namespace, return the inner dict.
|
||||
Otherwise return the original dict unchanged.
|
||||
"""
|
||||
keys = set(metadata_raw.keys())
|
||||
ns_keys = keys & _KNOWN_METADATA_NAMESPACES
|
||||
if len(ns_keys) == 1 and len(keys) == 1:
|
||||
ns = ns_keys.pop()
|
||||
inner = metadata_raw[ns]
|
||||
if isinstance(inner, dict):
|
||||
return inner
|
||||
return metadata_raw
|
||||
|
||||
|
||||
def _normalize_string_list(value: Any) -> List[str]:
|
||||
"""Normalize a value to a list of strings."""
|
||||
if not value:
|
||||
|
||||
@@ -12,25 +12,20 @@ from agent.skills.frontmatter import parse_frontmatter, parse_metadata, parse_bo
|
||||
|
||||
class SkillLoader:
|
||||
"""Loads skills from various directories."""
|
||||
|
||||
def __init__(self, workspace_dir: Optional[str] = None):
|
||||
"""
|
||||
Initialize the skill loader.
|
||||
|
||||
:param workspace_dir: Agent workspace directory (for workspace-specific skills)
|
||||
"""
|
||||
self.workspace_dir = workspace_dir
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def load_skills_from_dir(self, dir_path: str, source: str) -> LoadSkillsResult:
|
||||
"""
|
||||
Load skills from a directory.
|
||||
|
||||
|
||||
Discovery rules:
|
||||
- Direct .md files in the root directory
|
||||
- Recursive SKILL.md files under subdirectories
|
||||
|
||||
|
||||
:param dir_path: Directory path to scan
|
||||
:param source: Source identifier (e.g., 'managed', 'workspace', 'bundled')
|
||||
:param source: Source identifier ('builtin' or 'custom')
|
||||
:return: LoadSkillsResult with skills and diagnostics
|
||||
"""
|
||||
skills = []
|
||||
@@ -58,6 +53,12 @@ class SkillLoader:
|
||||
"""
|
||||
Recursively load skills from a directory.
|
||||
|
||||
If a subdirectory contains its own SKILL.md, it is treated as a
|
||||
self-contained skill (or skill-collection) and its children are
|
||||
NOT scanned further. This prevents sub-skills inside a collection
|
||||
(e.g. style-collection/style-anjing) from being listed as
|
||||
independent top-level skills.
|
||||
|
||||
:param dir_path: Directory to scan
|
||||
:param source: Source identifier
|
||||
:param include_root_files: Whether to include root-level .md files
|
||||
@@ -71,38 +72,41 @@ class SkillLoader:
|
||||
except Exception as e:
|
||||
diagnostics.append(f"Failed to list directory {dir_path}: {e}")
|
||||
return LoadSkillsResult(skills=skills, diagnostics=diagnostics)
|
||||
|
||||
# If this directory has its own SKILL.md, load it and stop recursing.
|
||||
# The sub-directories are internal resources of this skill.
|
||||
if not include_root_files and 'SKILL.md' in entries:
|
||||
skill_md_path = os.path.join(dir_path, 'SKILL.md')
|
||||
if os.path.isfile(skill_md_path):
|
||||
skill_result = self._load_skill_from_file(skill_md_path, source)
|
||||
if skill_result.skills:
|
||||
skills.extend(skill_result.skills)
|
||||
diagnostics.extend(skill_result.diagnostics)
|
||||
return LoadSkillsResult(skills=skills, diagnostics=diagnostics)
|
||||
|
||||
for entry in entries:
|
||||
# Skip hidden files and directories
|
||||
if entry.startswith('.'):
|
||||
continue
|
||||
|
||||
# Skip common non-skill directories
|
||||
if entry in ('node_modules', '__pycache__', 'venv', '.git'):
|
||||
continue
|
||||
|
||||
full_path = os.path.join(dir_path, entry)
|
||||
|
||||
# Handle directories
|
||||
if os.path.isdir(full_path):
|
||||
# Recursively scan subdirectories
|
||||
sub_result = self._load_skills_recursive(full_path, source, include_root_files=False)
|
||||
skills.extend(sub_result.skills)
|
||||
diagnostics.extend(sub_result.diagnostics)
|
||||
continue
|
||||
|
||||
# Handle files
|
||||
if not os.path.isfile(full_path):
|
||||
continue
|
||||
|
||||
# Check if this is a skill file
|
||||
is_root_md = include_root_files and entry.endswith('.md')
|
||||
is_skill_md = not include_root_files and entry == 'SKILL.md'
|
||||
is_root_md = include_root_files and entry.endswith('.md') and entry.upper() != 'README.MD'
|
||||
|
||||
if not (is_root_md or is_skill_md):
|
||||
if not is_root_md:
|
||||
continue
|
||||
|
||||
# Load the skill
|
||||
skill_result = self._load_skill_from_file(full_path, source)
|
||||
if skill_result.skills:
|
||||
skills.extend(skill_result.skills)
|
||||
@@ -137,6 +141,18 @@ class SkillLoader:
|
||||
name = frontmatter.get('name', parent_dir_name)
|
||||
description = frontmatter.get('description', '')
|
||||
|
||||
# Normalize name (handle both string and list)
|
||||
if isinstance(name, list):
|
||||
name = name[0] if name else parent_dir_name
|
||||
elif not isinstance(name, str):
|
||||
name = str(name) if name else parent_dir_name
|
||||
|
||||
# Normalize description (handle both string and list)
|
||||
if isinstance(description, list):
|
||||
description = ' '.join(str(d) for d in description if d)
|
||||
elif not isinstance(description, str):
|
||||
description = str(description) if description else ''
|
||||
|
||||
# Special handling for linkai-agent: dynamically load apps from config.json
|
||||
if name == 'linkai-agent':
|
||||
description = self._load_linkai_agent_description(skill_dir, description)
|
||||
@@ -176,16 +192,13 @@ class SkillLoader:
|
||||
import json
|
||||
|
||||
config_path = os.path.join(skill_dir, "config.json")
|
||||
template_path = os.path.join(skill_dir, "config.json.template")
|
||||
|
||||
# Try to load config.json or fallback to template
|
||||
config_file = config_path if os.path.exists(config_path) else template_path
|
||||
|
||||
if not os.path.exists(config_file):
|
||||
return default_description
|
||||
if not os.path.exists(config_path):
|
||||
logger.debug(f"[SkillLoader] linkai-agent skipped: no config.json found")
|
||||
return ""
|
||||
|
||||
try:
|
||||
with open(config_file, 'r', encoding='utf-8') as f:
|
||||
with open(config_path, 'r', encoding='utf-8') as f:
|
||||
config = json.load(f)
|
||||
|
||||
apps = config.get("apps", [])
|
||||
@@ -206,61 +219,49 @@ class SkillLoader:
|
||||
|
||||
def load_all_skills(
|
||||
self,
|
||||
managed_dir: Optional[str] = None,
|
||||
workspace_skills_dir: Optional[str] = None,
|
||||
extra_dirs: Optional[List[str]] = None,
|
||||
builtin_dir: Optional[str] = None,
|
||||
custom_dir: Optional[str] = None,
|
||||
) -> Dict[str, SkillEntry]:
|
||||
"""
|
||||
Load skills from all configured locations with precedence.
|
||||
|
||||
Load skills from builtin and custom directories.
|
||||
|
||||
Precedence (lowest to highest):
|
||||
1. Extra directories
|
||||
2. Managed skills directory
|
||||
3. Workspace skills directory
|
||||
|
||||
:param managed_dir: Managed skills directory (e.g., ~/.cow/skills)
|
||||
:param workspace_skills_dir: Workspace skills directory (e.g., workspace/skills)
|
||||
:param extra_dirs: Additional directories to load skills from
|
||||
1. builtin — project root ``skills/``, shipped with the codebase
|
||||
2. custom — workspace ``skills/``, installed via cloud console or skill creator
|
||||
|
||||
Same-name custom skills override builtin ones.
|
||||
|
||||
:param builtin_dir: Built-in skills directory
|
||||
:param custom_dir: Custom skills directory
|
||||
:return: Dictionary mapping skill name to SkillEntry
|
||||
"""
|
||||
skill_map: Dict[str, SkillEntry] = {}
|
||||
all_diagnostics = []
|
||||
|
||||
# Load from extra directories (lowest precedence)
|
||||
if extra_dirs:
|
||||
for extra_dir in extra_dirs:
|
||||
if not os.path.exists(extra_dir):
|
||||
continue
|
||||
result = self.load_skills_from_dir(extra_dir, source='extra')
|
||||
all_diagnostics.extend(result.diagnostics)
|
||||
for skill in result.skills:
|
||||
entry = self._create_skill_entry(skill)
|
||||
skill_map[skill.name] = entry
|
||||
|
||||
# Load from managed directory
|
||||
if managed_dir and os.path.exists(managed_dir):
|
||||
result = self.load_skills_from_dir(managed_dir, source='managed')
|
||||
|
||||
# Load builtin skills (lower precedence)
|
||||
if builtin_dir and os.path.exists(builtin_dir):
|
||||
result = self.load_skills_from_dir(builtin_dir, source='builtin')
|
||||
all_diagnostics.extend(result.diagnostics)
|
||||
for skill in result.skills:
|
||||
entry = self._create_skill_entry(skill)
|
||||
skill_map[skill.name] = entry
|
||||
|
||||
# Load from workspace directory (highest precedence)
|
||||
if workspace_skills_dir and os.path.exists(workspace_skills_dir):
|
||||
result = self.load_skills_from_dir(workspace_skills_dir, source='workspace')
|
||||
|
||||
# Load custom skills (higher precedence, overrides builtin)
|
||||
if custom_dir and os.path.exists(custom_dir):
|
||||
result = self.load_skills_from_dir(custom_dir, source='custom')
|
||||
all_diagnostics.extend(result.diagnostics)
|
||||
for skill in result.skills:
|
||||
entry = self._create_skill_entry(skill)
|
||||
skill_map[skill.name] = entry
|
||||
|
||||
|
||||
# Log diagnostics
|
||||
if all_diagnostics:
|
||||
logger.debug(f"Skill loading diagnostics: {len(all_diagnostics)} issues")
|
||||
for diag in all_diagnostics[:5]: # Log first 5
|
||||
for diag in all_diagnostics[:5]:
|
||||
logger.debug(f" - {diag}")
|
||||
|
||||
logger.debug(f"Loaded {len(skill_map)} skills from all sources")
|
||||
|
||||
|
||||
logger.debug(f"Loaded {len(skill_map)} skills total")
|
||||
|
||||
return skill_map
|
||||
|
||||
def _create_skill_entry(self, skill: Skill) -> SkillEntry:
|
||||
|
||||
@@ -3,6 +3,7 @@ Skill manager for managing skill lifecycle and operations.
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
from typing import Dict, List, Optional
|
||||
from pathlib import Path
|
||||
from common.log import logger
|
||||
@@ -10,56 +11,143 @@ from agent.skills.types import Skill, SkillEntry, SkillSnapshot
|
||||
from agent.skills.loader import SkillLoader
|
||||
from agent.skills.formatter import format_skill_entries_for_prompt
|
||||
|
||||
SKILLS_CONFIG_FILE = "skills_config.json"
|
||||
|
||||
|
||||
class SkillManager:
|
||||
"""Manages skills for an agent."""
|
||||
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
workspace_dir: Optional[str] = None,
|
||||
managed_skills_dir: Optional[str] = None,
|
||||
extra_dirs: Optional[List[str]] = None,
|
||||
builtin_dir: Optional[str] = None,
|
||||
custom_dir: Optional[str] = None,
|
||||
config: Optional[Dict] = None,
|
||||
):
|
||||
"""
|
||||
Initialize the skill manager.
|
||||
|
||||
:param workspace_dir: Agent workspace directory
|
||||
:param managed_skills_dir: Managed skills directory (e.g., ~/.cow/skills)
|
||||
:param extra_dirs: Additional skill directories
|
||||
|
||||
:param builtin_dir: Built-in skills directory (project root ``skills/``)
|
||||
:param custom_dir: Custom skills directory (workspace ``skills/``)
|
||||
:param config: Configuration dictionary
|
||||
"""
|
||||
self.workspace_dir = workspace_dir
|
||||
self.managed_skills_dir = managed_skills_dir or self._get_default_managed_dir()
|
||||
self.extra_dirs = extra_dirs or []
|
||||
project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
self.builtin_dir = builtin_dir or os.path.join(project_root, 'skills')
|
||||
self.custom_dir = custom_dir or os.path.join(project_root, 'workspace', 'skills')
|
||||
self.config = config or {}
|
||||
|
||||
self.loader = SkillLoader(workspace_dir=workspace_dir)
|
||||
self._skills_config_path = os.path.join(self.custom_dir, SKILLS_CONFIG_FILE)
|
||||
|
||||
# skills_config: full skill metadata keyed by name
|
||||
# { "web-fetch": {"name": ..., "description": ..., "source": ..., "enabled": true}, ... }
|
||||
self.skills_config: Dict[str, dict] = {}
|
||||
|
||||
self.loader = SkillLoader()
|
||||
self.skills: Dict[str, SkillEntry] = {}
|
||||
|
||||
|
||||
# Load skills on initialization
|
||||
self.refresh_skills()
|
||||
|
||||
def _get_default_managed_dir(self) -> str:
|
||||
"""Get the default managed skills directory."""
|
||||
# Use project root skills directory as default
|
||||
import os
|
||||
project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
return os.path.join(project_root, 'skills')
|
||||
|
||||
|
||||
def refresh_skills(self):
|
||||
"""Reload all skills from configured directories."""
|
||||
workspace_skills_dir = None
|
||||
if self.workspace_dir:
|
||||
workspace_skills_dir = os.path.join(self.workspace_dir, 'skills')
|
||||
|
||||
"""Reload all skills from builtin and custom directories, then sync config."""
|
||||
self.skills = self.loader.load_all_skills(
|
||||
managed_dir=self.managed_skills_dir,
|
||||
workspace_skills_dir=workspace_skills_dir,
|
||||
extra_dirs=self.extra_dirs,
|
||||
builtin_dir=self.builtin_dir,
|
||||
custom_dir=self.custom_dir,
|
||||
)
|
||||
|
||||
self._sync_skills_config()
|
||||
logger.debug(f"SkillManager: Loaded {len(self.skills)} skills")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# skills_config.json management
|
||||
# ------------------------------------------------------------------
|
||||
def _load_skills_config(self) -> Dict[str, dict]:
|
||||
"""Load skills_config.json from custom_dir. Returns empty dict if not found."""
|
||||
if not os.path.exists(self._skills_config_path):
|
||||
return {}
|
||||
try:
|
||||
with open(self._skills_config_path, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
if isinstance(data, dict):
|
||||
return data
|
||||
except Exception as e:
|
||||
logger.warning(f"[SkillManager] Failed to load {SKILLS_CONFIG_FILE}: {e}")
|
||||
return {}
|
||||
|
||||
def _save_skills_config(self):
|
||||
"""Persist skills_config to custom_dir/skills_config.json."""
|
||||
os.makedirs(self.custom_dir, exist_ok=True)
|
||||
try:
|
||||
with open(self._skills_config_path, "w", encoding="utf-8") as f:
|
||||
json.dump(self.skills_config, f, indent=4, ensure_ascii=False)
|
||||
except Exception as e:
|
||||
logger.error(f"[SkillManager] Failed to save {SKILLS_CONFIG_FILE}: {e}")
|
||||
|
||||
def _sync_skills_config(self):
|
||||
"""
|
||||
Merge directory-scanned skills with the persisted config file.
|
||||
|
||||
- New skills: use metadata.default_enabled as initial enabled state.
|
||||
- Existing skills: preserve their persisted enabled state.
|
||||
- Skills that no longer exist on disk are removed.
|
||||
- name/description/source are always refreshed from the latest scan.
|
||||
"""
|
||||
saved = self._load_skills_config()
|
||||
merged: Dict[str, dict] = {}
|
||||
|
||||
for name, entry in self.skills.items():
|
||||
skill = entry.skill
|
||||
prev = saved.get(name, {})
|
||||
category = prev.get("category", "skill")
|
||||
|
||||
if name in saved:
|
||||
enabled = prev.get("enabled", True)
|
||||
else:
|
||||
enabled = entry.metadata.default_enabled if entry.metadata else True
|
||||
|
||||
entry_dict = {
|
||||
"name": name,
|
||||
"description": skill.description,
|
||||
"source": prev.get("source") or skill.source,
|
||||
"enabled": enabled,
|
||||
"category": category,
|
||||
}
|
||||
display_name = prev.get("display_name")
|
||||
if display_name:
|
||||
entry_dict["display_name"] = display_name
|
||||
merged[name] = entry_dict
|
||||
|
||||
self.skills_config = merged
|
||||
self._save_skills_config()
|
||||
|
||||
def is_skill_enabled(self, name: str) -> bool:
|
||||
"""
|
||||
Check if a skill is enabled according to skills_config.
|
||||
|
||||
:param name: skill name
|
||||
:return: True if enabled (default True if not in config)
|
||||
"""
|
||||
entry = self.skills_config.get(name)
|
||||
if entry is None:
|
||||
return True
|
||||
return entry.get("enabled", True)
|
||||
|
||||
def set_skill_enabled(self, name: str, enabled: bool):
|
||||
"""
|
||||
Set a skill's enabled state and persist.
|
||||
|
||||
:param name: skill name
|
||||
:param enabled: True to enable, False to disable
|
||||
"""
|
||||
if name not in self.skills_config:
|
||||
raise ValueError(f"skill '{name}' not found in config")
|
||||
self.skills_config[name]["enabled"] = enabled
|
||||
self._save_skills_config()
|
||||
|
||||
def get_skills_config(self) -> Dict[str, dict]:
|
||||
"""
|
||||
Return the full skills_config dict (for query API).
|
||||
|
||||
:return: copy of skills_config
|
||||
"""
|
||||
return dict(self.skills_config)
|
||||
|
||||
def get_skill(self, name: str) -> Optional[SkillEntry]:
|
||||
"""
|
||||
@@ -78,53 +166,120 @@ class SkillManager:
|
||||
"""
|
||||
return list(self.skills.values())
|
||||
|
||||
@staticmethod
|
||||
def _normalize_skill_filter(skill_filter: Optional[List[str]]) -> Optional[List[str]]:
|
||||
"""Normalize a skill_filter list into a flat list of stripped names."""
|
||||
if skill_filter is None:
|
||||
return None
|
||||
normalized = []
|
||||
for item in skill_filter:
|
||||
if isinstance(item, str):
|
||||
name = item.strip()
|
||||
if name:
|
||||
normalized.append(name)
|
||||
elif isinstance(item, list):
|
||||
for subitem in item:
|
||||
if isinstance(subitem, str):
|
||||
name = subitem.strip()
|
||||
if name:
|
||||
normalized.append(name)
|
||||
return normalized or None
|
||||
|
||||
def filter_skills(
|
||||
self,
|
||||
skill_filter: Optional[List[str]] = None,
|
||||
include_disabled: bool = False,
|
||||
) -> List[SkillEntry]:
|
||||
"""
|
||||
Filter skills based on criteria.
|
||||
|
||||
Simple rule: Skills are auto-enabled if requirements are met.
|
||||
- Has required API keys → included
|
||||
- Missing API keys → excluded
|
||||
|
||||
Filter skills that are eligible (enabled + requirements met).
|
||||
|
||||
:param skill_filter: List of skill names to include (None = all)
|
||||
:param include_disabled: Whether to include skills with disable_model_invocation=True
|
||||
:return: Filtered list of skill entries
|
||||
:param include_disabled: Whether to include disabled skills
|
||||
:return: Filtered list of eligible skill entries
|
||||
"""
|
||||
from agent.skills.config import should_include_skill
|
||||
|
||||
|
||||
entries = list(self.skills.values())
|
||||
|
||||
# Check requirements (platform, binaries, env vars)
|
||||
|
||||
entries = [e for e in entries if should_include_skill(e, self.config)]
|
||||
|
||||
# Apply skill filter
|
||||
if skill_filter is not None:
|
||||
normalized = [name.strip() for name in skill_filter if name.strip()]
|
||||
if normalized:
|
||||
entries = [e for e in entries if e.skill.name in normalized]
|
||||
|
||||
# Filter out disabled skills unless explicitly requested
|
||||
|
||||
normalized = self._normalize_skill_filter(skill_filter)
|
||||
if normalized is not None:
|
||||
entries = [e for e in entries if e.skill.name in normalized]
|
||||
|
||||
if not include_disabled:
|
||||
entries = [e for e in entries if not e.skill.disable_model_invocation]
|
||||
|
||||
entries = [e for e in entries if self.is_skill_enabled(e.skill.name)]
|
||||
|
||||
from config import conf
|
||||
if not conf().get("knowledge", True):
|
||||
entries = [e for e in entries if e.skill.name != "knowledge-wiki"]
|
||||
|
||||
return entries
|
||||
|
||||
|
||||
def filter_unavailable_skills(
|
||||
self,
|
||||
skill_filter: Optional[List[str]] = None,
|
||||
) -> tuple:
|
||||
"""
|
||||
Find skills that are enabled but have unmet requirements.
|
||||
|
||||
:param skill_filter: Optional list of skill names to include
|
||||
:return: Tuple of (entries, missing_map) where missing_map maps
|
||||
skill name to its missing requirements dict
|
||||
"""
|
||||
from agent.skills.config import should_include_skill, get_missing_requirements
|
||||
|
||||
entries = list(self.skills.values())
|
||||
|
||||
# Only enabled skills
|
||||
entries = [e for e in entries if self.is_skill_enabled(e.skill.name)]
|
||||
|
||||
normalized = self._normalize_skill_filter(skill_filter)
|
||||
if normalized is not None:
|
||||
entries = [e for e in entries if e.skill.name in normalized]
|
||||
|
||||
# Keep only those that fail should_include_skill (requirements not met)
|
||||
unavailable = []
|
||||
missing_map: Dict[str, dict] = {}
|
||||
for e in entries:
|
||||
if not should_include_skill(e, self.config):
|
||||
missing = get_missing_requirements(e)
|
||||
if missing:
|
||||
unavailable.append(e)
|
||||
missing_map[e.skill.name] = missing
|
||||
|
||||
return unavailable, missing_map
|
||||
|
||||
def build_skills_prompt(
|
||||
self,
|
||||
skill_filter: Optional[List[str]] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Build a formatted prompt containing available skills.
|
||||
|
||||
Build a formatted prompt containing available skills
|
||||
and brief hints for unavailable ones.
|
||||
|
||||
:param skill_filter: Optional list of skill names to include
|
||||
:return: Formatted skills prompt
|
||||
"""
|
||||
entries = self.filter_skills(skill_filter=skill_filter, include_disabled=False)
|
||||
return format_skill_entries_for_prompt(entries)
|
||||
from common.log import logger
|
||||
from agent.skills.formatter import format_unavailable_skills_for_prompt
|
||||
|
||||
eligible = self.filter_skills(skill_filter=skill_filter, include_disabled=False)
|
||||
logger.debug(f"[SkillManager] Eligible: {len(eligible)} skills (total: {len(self.skills)})")
|
||||
if eligible:
|
||||
skill_names = [e.skill.name for e in eligible]
|
||||
logger.debug(f"[SkillManager] Eligible skills: {skill_names}")
|
||||
|
||||
result = format_skill_entries_for_prompt(eligible)
|
||||
|
||||
unavailable, missing_map = self.filter_unavailable_skills(skill_filter=skill_filter)
|
||||
if unavailable:
|
||||
unavailable_names = [e.skill.name for e in unavailable]
|
||||
logger.debug(f"[SkillManager] Unavailable skills (setup needed): {unavailable_names}")
|
||||
result += format_unavailable_skills_for_prompt(unavailable, missing_map)
|
||||
|
||||
logger.debug(f"[SkillManager] Generated prompt length: {len(result)}")
|
||||
return result
|
||||
|
||||
def build_skill_snapshot(
|
||||
self,
|
||||
|
||||
285
agent/skills/service.py
Normal file
285
agent/skills/service.py
Normal file
@@ -0,0 +1,285 @@
|
||||
"""
|
||||
Skill service for handling skill CRUD operations.
|
||||
|
||||
This service provides a unified interface for managing skills, which can be
|
||||
called from the cloud control client (LinkAI), the local web console, or any
|
||||
other management entry point.
|
||||
"""
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import zipfile
|
||||
import tempfile
|
||||
from typing import Dict, List, Optional
|
||||
from common.log import logger
|
||||
from agent.skills.types import Skill, SkillEntry
|
||||
from agent.skills.manager import SkillManager
|
||||
|
||||
try:
|
||||
import requests
|
||||
except ImportError:
|
||||
requests = None
|
||||
|
||||
|
||||
class SkillService:
|
||||
"""
|
||||
High-level service for skill lifecycle management.
|
||||
Wraps SkillManager and provides network-aware operations such as
|
||||
downloading skill files from remote URLs.
|
||||
"""
|
||||
|
||||
def __init__(self, skill_manager: SkillManager):
|
||||
"""
|
||||
:param skill_manager: The SkillManager instance to operate on
|
||||
"""
|
||||
self.manager = skill_manager
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# query
|
||||
# ------------------------------------------------------------------
|
||||
def query(self) -> List[dict]:
|
||||
"""
|
||||
Query all skills and return a serialisable list.
|
||||
Reads from skills_config.json (refreshes from disk if needed).
|
||||
|
||||
:return: list of skill info dicts
|
||||
"""
|
||||
self.manager.refresh_skills()
|
||||
config = self.manager.get_skills_config()
|
||||
result = list(config.values())
|
||||
logger.info(f"[SkillService] query: {len(result)} skills found")
|
||||
return result
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# add / install
|
||||
# ------------------------------------------------------------------
|
||||
def add(self, payload: dict) -> None:
|
||||
"""
|
||||
Add (install) a skill from a remote payload.
|
||||
|
||||
Supported payload types:
|
||||
|
||||
1. ``type: "url"`` – download individual files::
|
||||
|
||||
{
|
||||
"name": "web_search",
|
||||
"type": "url",
|
||||
"enabled": true,
|
||||
"files": [
|
||||
{"url": "https://...", "path": "README.md"},
|
||||
{"url": "https://...", "path": "scripts/main.py"}
|
||||
]
|
||||
}
|
||||
|
||||
2. ``type: "package"`` – download a zip archive and extract::
|
||||
|
||||
{
|
||||
"name": "plugin-custom-tool",
|
||||
"type": "package",
|
||||
"category": "skills",
|
||||
"enabled": true,
|
||||
"files": [{"url": "https://cdn.example.com/skills/custom-tool.zip"}]
|
||||
}
|
||||
|
||||
:param payload: skill add payload from server
|
||||
"""
|
||||
name = payload.get("name")
|
||||
if not name:
|
||||
raise ValueError("skill name is required")
|
||||
|
||||
payload_type = payload.get("type", "url")
|
||||
|
||||
if payload_type == "package":
|
||||
self._add_package(name, payload)
|
||||
else:
|
||||
self._add_url(name, payload)
|
||||
|
||||
self.manager.refresh_skills()
|
||||
|
||||
category = payload.get("category")
|
||||
if category and name in self.manager.skills_config:
|
||||
self.manager.skills_config[name]["category"] = category
|
||||
self.manager._save_skills_config()
|
||||
|
||||
def _add_url(self, name: str, payload: dict) -> None:
|
||||
"""Install a skill by downloading individual files."""
|
||||
files = payload.get("files", [])
|
||||
if not files:
|
||||
raise ValueError("skill files list is empty")
|
||||
|
||||
skill_dir = os.path.join(self.manager.custom_dir, name)
|
||||
|
||||
tmp_dir = skill_dir + ".tmp"
|
||||
if os.path.exists(tmp_dir):
|
||||
shutil.rmtree(tmp_dir)
|
||||
os.makedirs(tmp_dir, exist_ok=True)
|
||||
|
||||
try:
|
||||
for file_info in files:
|
||||
url = file_info.get("url")
|
||||
rel_path = file_info.get("path")
|
||||
if not url or not rel_path:
|
||||
logger.warning(f"[SkillService] add: skip invalid file entry {file_info}")
|
||||
continue
|
||||
dest = os.path.join(tmp_dir, rel_path)
|
||||
self._download_file(url, dest)
|
||||
except Exception:
|
||||
shutil.rmtree(tmp_dir, ignore_errors=True)
|
||||
raise
|
||||
|
||||
if os.path.exists(skill_dir):
|
||||
shutil.rmtree(skill_dir)
|
||||
os.rename(tmp_dir, skill_dir)
|
||||
|
||||
logger.info(f"[SkillService] add: skill '{name}' installed via url ({len(files)} files)")
|
||||
|
||||
def _add_package(self, name: str, payload: dict) -> None:
|
||||
"""
|
||||
Install a skill by downloading a zip archive and extracting it.
|
||||
|
||||
If the archive contains a single top-level directory, that directory
|
||||
is used as the skill folder directly; otherwise a new directory named
|
||||
after the skill is created to hold the extracted contents.
|
||||
"""
|
||||
files = payload.get("files", [])
|
||||
if not files or not files[0].get("url"):
|
||||
raise ValueError("package url is required")
|
||||
|
||||
url = files[0]["url"]
|
||||
skill_dir = os.path.join(self.manager.custom_dir, name)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
zip_path = os.path.join(tmp_dir, "package.zip")
|
||||
self._download_file(url, zip_path)
|
||||
|
||||
if not zipfile.is_zipfile(zip_path):
|
||||
raise ValueError(f"downloaded file is not a valid zip archive: {url}")
|
||||
|
||||
extract_dir = os.path.join(tmp_dir, "extracted")
|
||||
with zipfile.ZipFile(zip_path, "r") as zf:
|
||||
zf.extractall(extract_dir)
|
||||
|
||||
# Determine the actual content root.
|
||||
# If the zip has a single top-level directory, use its contents
|
||||
# so the skill folder is clean (no extra nesting).
|
||||
top_items = [
|
||||
item for item in os.listdir(extract_dir)
|
||||
if not item.startswith(".")
|
||||
]
|
||||
if len(top_items) == 1:
|
||||
single = os.path.join(extract_dir, top_items[0])
|
||||
if os.path.isdir(single):
|
||||
extract_dir = single
|
||||
|
||||
if os.path.exists(skill_dir):
|
||||
shutil.rmtree(skill_dir)
|
||||
shutil.copytree(extract_dir, skill_dir)
|
||||
|
||||
logger.info(f"[SkillService] add: skill '{name}' installed via package ({url})")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# open / close (enable / disable)
|
||||
# ------------------------------------------------------------------
|
||||
def open(self, payload: dict) -> None:
|
||||
"""
|
||||
Enable a skill by name.
|
||||
|
||||
:param payload: {"name": "skill_name"}
|
||||
"""
|
||||
name = payload.get("name")
|
||||
if not name:
|
||||
raise ValueError("skill name is required")
|
||||
self.manager.set_skill_enabled(name, enabled=True)
|
||||
logger.info(f"[SkillService] open: skill '{name}' enabled")
|
||||
|
||||
def close(self, payload: dict) -> None:
|
||||
"""
|
||||
Disable a skill by name.
|
||||
|
||||
:param payload: {"name": "skill_name"}
|
||||
"""
|
||||
name = payload.get("name")
|
||||
if not name:
|
||||
raise ValueError("skill name is required")
|
||||
self.manager.set_skill_enabled(name, enabled=False)
|
||||
logger.info(f"[SkillService] close: skill '{name}' disabled")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# delete
|
||||
# ------------------------------------------------------------------
|
||||
def delete(self, payload: dict) -> None:
|
||||
"""
|
||||
Delete a skill by removing its directory entirely.
|
||||
|
||||
:param payload: {"name": "skill_name"}
|
||||
"""
|
||||
name = payload.get("name")
|
||||
if not name:
|
||||
raise ValueError("skill name is required")
|
||||
|
||||
skill_dir = os.path.join(self.manager.custom_dir, name)
|
||||
if os.path.exists(skill_dir):
|
||||
shutil.rmtree(skill_dir)
|
||||
logger.info(f"[SkillService] delete: removed directory {skill_dir}")
|
||||
else:
|
||||
logger.warning(f"[SkillService] delete: skill directory not found: {skill_dir}")
|
||||
|
||||
# Refresh will remove the deleted skill from config automatically
|
||||
self.manager.refresh_skills()
|
||||
logger.info(f"[SkillService] delete: skill '{name}' deleted")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# dispatch - single entry point for protocol messages
|
||||
# ------------------------------------------------------------------
|
||||
def dispatch(self, action: str, payload: Optional[dict] = None) -> dict:
|
||||
"""
|
||||
Dispatch a skill management action and return a protocol-compatible
|
||||
response dict.
|
||||
|
||||
:param action: one of query / add / open / close / delete
|
||||
:param payload: action-specific payload (may be None for query)
|
||||
:return: dict with action, code, message, payload
|
||||
"""
|
||||
payload = payload or {}
|
||||
try:
|
||||
if action == "query":
|
||||
result_payload = self.query()
|
||||
return {"action": action, "code": 200, "message": "success", "payload": result_payload}
|
||||
elif action == "add":
|
||||
self.add(payload)
|
||||
elif action == "open":
|
||||
self.open(payload)
|
||||
elif action == "close":
|
||||
self.close(payload)
|
||||
elif action == "delete":
|
||||
self.delete(payload)
|
||||
else:
|
||||
return {"action": action, "code": 400, "message": f"unknown action: {action}", "payload": None}
|
||||
return {"action": action, "code": 200, "message": "success", "payload": None}
|
||||
except Exception as e:
|
||||
logger.error(f"[SkillService] dispatch error: action={action}, error={e}")
|
||||
return {"action": action, "code": 500, "message": str(e), "payload": None}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# internal helpers
|
||||
# ------------------------------------------------------------------
|
||||
@staticmethod
|
||||
def _download_file(url: str, dest: str):
|
||||
"""
|
||||
Download a file from *url* and save to *dest*.
|
||||
|
||||
:param url: remote file URL
|
||||
:param dest: local destination path
|
||||
"""
|
||||
if requests is None:
|
||||
raise RuntimeError("requests library is required for downloading skill files")
|
||||
|
||||
dest_dir = os.path.dirname(dest)
|
||||
if dest_dir:
|
||||
os.makedirs(dest_dir, exist_ok=True)
|
||||
|
||||
resp = requests.get(url, timeout=60)
|
||||
resp.raise_for_status()
|
||||
with open(dest, "wb") as f:
|
||||
f.write(resp.content)
|
||||
logger.debug(f"[SkillService] downloaded {url} -> {dest}")
|
||||
@@ -2,6 +2,7 @@
|
||||
Type definitions for skills system.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
from typing import Dict, List, Optional, Any
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
@@ -28,6 +29,7 @@ class SkillInstallSpec:
|
||||
class SkillMetadata:
|
||||
"""Metadata for a skill from frontmatter."""
|
||||
always: bool = False # Always include this skill
|
||||
default_enabled: bool = True # Initial enabled state when first discovered
|
||||
skill_key: Optional[str] = None # Override skill key
|
||||
primary_env: Optional[str] = None # Primary environment variable
|
||||
emoji: Optional[str] = None
|
||||
@@ -44,7 +46,7 @@ class Skill:
|
||||
description: str
|
||||
file_path: str
|
||||
base_dir: str
|
||||
source: str # managed, workspace, bundled, etc.
|
||||
source: str # builtin or custom
|
||||
content: str # Full markdown content
|
||||
disable_model_invocation: bool = False
|
||||
frontmatter: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
@@ -45,38 +45,83 @@ def _import_optional_tools():
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[Tools] Scheduler tool failed to load: {e}")
|
||||
|
||||
|
||||
|
||||
# WebSearch Tool (conditionally loaded based on API key availability at init time)
|
||||
try:
|
||||
from agent.tools.web_search.web_search import WebSearch
|
||||
tools['WebSearch'] = WebSearch
|
||||
except ImportError as e:
|
||||
logger.error(f"[Tools] WebSearch not loaded - missing dependency: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"[Tools] WebSearch failed to load: {e}")
|
||||
|
||||
# WebFetch Tool
|
||||
try:
|
||||
from agent.tools.web_fetch.web_fetch import WebFetch
|
||||
tools['WebFetch'] = WebFetch
|
||||
except ImportError as e:
|
||||
logger.error(f"[Tools] WebFetch not loaded - missing dependency: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"[Tools] WebFetch failed to load: {e}")
|
||||
|
||||
# Vision Tool (conditionally loaded based on API key availability)
|
||||
try:
|
||||
from agent.tools.vision.vision import Vision
|
||||
tools['Vision'] = Vision
|
||||
except ImportError as e:
|
||||
logger.error(f"[Tools] Vision not loaded - missing dependency: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"[Tools] Vision failed to load: {e}")
|
||||
|
||||
return tools
|
||||
|
||||
# Load optional tools
|
||||
_optional_tools = _import_optional_tools()
|
||||
EnvConfig = _optional_tools.get('EnvConfig')
|
||||
SchedulerTool = _optional_tools.get('SchedulerTool')
|
||||
WebSearch = _optional_tools.get('WebSearch')
|
||||
WebFetch = _optional_tools.get('WebFetch')
|
||||
Vision = _optional_tools.get('Vision')
|
||||
GoogleSearch = _optional_tools.get('GoogleSearch')
|
||||
FileSave = _optional_tools.get('FileSave')
|
||||
FileSave = _optional_tools.get('FileSave')
|
||||
Terminal = _optional_tools.get('Terminal')
|
||||
|
||||
|
||||
# Delayed import for BrowserTool
|
||||
# BrowserTool (requires playwright)
|
||||
def _import_browser_tool():
|
||||
from common.log import logger
|
||||
try:
|
||||
from agent.tools.browser.browser_tool import BrowserTool
|
||||
return BrowserTool
|
||||
except ImportError:
|
||||
# Return a placeholder class that will prompt the user to install dependencies when instantiated
|
||||
class BrowserToolPlaceholder:
|
||||
def __init__(self, *args, **kwargs):
|
||||
raise ImportError(
|
||||
"The 'browser-use' package is required to use BrowserTool. "
|
||||
"Please install it with 'pip install browser-use>=0.1.40'."
|
||||
)
|
||||
except ImportError as e:
|
||||
logger.info(
|
||||
f"[Tools] BrowserTool not loaded - missing dependency: {e}\n"
|
||||
f" To enable browser tool, run:\n"
|
||||
f" pip install playwright\n"
|
||||
f" playwright install chromium"
|
||||
)
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"[Tools] BrowserTool failed to load: {e}")
|
||||
return None
|
||||
|
||||
return BrowserToolPlaceholder
|
||||
BrowserTool = _import_browser_tool()
|
||||
|
||||
# MCP Tools (no extra dependencies, loaded on demand)
|
||||
def _import_mcp_tools():
|
||||
"""导入 MCP 工具模块(无额外依赖,按需加载)"""
|
||||
from common.log import logger
|
||||
try:
|
||||
from agent.tools.mcp.mcp_tool import McpTool
|
||||
from agent.tools.mcp.mcp_client import McpClientRegistry
|
||||
return {'McpTool': McpTool, 'McpClientRegistry': McpClientRegistry}
|
||||
except Exception as e:
|
||||
logger.warning(f"[Tools] MCP tools not loaded: {e}")
|
||||
return {}
|
||||
|
||||
# Dynamically set BrowserTool
|
||||
# BrowserTool = _import_browser_tool()
|
||||
_mcp_tools = _import_mcp_tools()
|
||||
McpTool = _mcp_tools.get('McpTool')
|
||||
McpClientRegistry = _mcp_tools.get('McpClientRegistry')
|
||||
|
||||
# Export all tools (including optional ones that might be None)
|
||||
__all__ = [
|
||||
@@ -92,8 +137,11 @@ __all__ = [
|
||||
'MemoryGetTool',
|
||||
'EnvConfig',
|
||||
'SchedulerTool',
|
||||
# Optional tools (may be None if dependencies not available)
|
||||
# 'BrowserTool'
|
||||
'WebSearch',
|
||||
'WebFetch',
|
||||
'Vision',
|
||||
'BrowserTool',
|
||||
'McpTool',
|
||||
]
|
||||
|
||||
"""
|
||||
|
||||
@@ -3,6 +3,7 @@ Bash tool - Execute bash commands
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import subprocess
|
||||
import tempfile
|
||||
@@ -11,18 +12,24 @@ from typing import Dict, Any
|
||||
from agent.tools.base_tool import BaseTool, ToolResult
|
||||
from agent.tools.utils.truncate import truncate_tail, format_size, DEFAULT_MAX_LINES, DEFAULT_MAX_BYTES
|
||||
from common.log import logger
|
||||
from common.utils import expand_path
|
||||
|
||||
|
||||
class Bash(BaseTool):
|
||||
"""Tool for executing bash commands"""
|
||||
|
||||
_IS_WIN = sys.platform == "win32"
|
||||
|
||||
name: str = "bash"
|
||||
description: str = f"""Execute a bash command in the current working directory. Returns stdout and stderr. Output is truncated to last {DEFAULT_MAX_LINES} lines or {DEFAULT_MAX_BYTES // 1024}KB (whichever is hit first). If truncated, full output is saved to a temp file.
|
||||
{'''
|
||||
PLATFORM: Windows (cmd.exe). Do NOT use Unix-only commands like grep, head, tail, sed, awk.
|
||||
''' if _IS_WIN else ''}
|
||||
ENVIRONMENT: All API keys from env_config are auto-injected. Use $VAR_NAME directly.
|
||||
|
||||
IMPORTANT SAFETY GUIDELINES:
|
||||
- You can freely create, modify, and delete files within the current workspace
|
||||
- For operations outside the workspace or potentially destructive commands (rm -rf, system commands, etc.), always explain what you're about to do and ask for user confirmation first
|
||||
- When in doubt, describe the command's purpose and ask for permission before executing"""
|
||||
SAFETY:
|
||||
- Freely create/modify/delete files within the workspace
|
||||
- For destructive commands out of workspace, explain and confirm first"""
|
||||
|
||||
params: dict = {
|
||||
"type": "object",
|
||||
@@ -80,27 +87,32 @@ IMPORTANT SAFETY GUIDELINES:
|
||||
env = os.environ.copy()
|
||||
|
||||
# Load environment variables from ~/.cow/.env if it exists
|
||||
env_file = os.path.expanduser("~/.cow/.env")
|
||||
env_file = expand_path("~/.cow/.env")
|
||||
dotenv_vars = {}
|
||||
if os.path.exists(env_file):
|
||||
try:
|
||||
from dotenv import dotenv_values
|
||||
env_vars = dotenv_values(env_file)
|
||||
env.update(env_vars)
|
||||
logger.debug(f"[Bash] Loaded {len(env_vars)} variables from {env_file}")
|
||||
dotenv_vars = dotenv_values(env_file)
|
||||
env.update(dotenv_vars)
|
||||
logger.debug(f"[Bash] Loaded {len(dotenv_vars)} variables from {env_file}")
|
||||
except ImportError:
|
||||
logger.debug("[Bash] python-dotenv not installed, skipping .env loading")
|
||||
except Exception as e:
|
||||
logger.debug(f"[Bash] Failed to load .env: {e}")
|
||||
|
||||
# getuid() only exists on Unix-like systems
|
||||
if hasattr(os, 'getuid'):
|
||||
logger.debug(f"[Bash] Process UID: {os.getuid()}")
|
||||
else:
|
||||
logger.debug(f"[Bash] Process User: {os.environ.get('USERNAME', os.environ.get('USER', 'unknown'))}")
|
||||
|
||||
# Debug logging
|
||||
logger.debug(f"[Bash] CWD: {self.cwd}")
|
||||
logger.debug(f"[Bash] Command: {command[:500]}")
|
||||
logger.debug(f"[Bash] OPENAI_API_KEY in env: {'OPENAI_API_KEY' in env}")
|
||||
logger.debug(f"[Bash] SHELL: {env.get('SHELL', 'not set')}")
|
||||
logger.debug(f"[Bash] Python executable: {sys.executable}")
|
||||
logger.debug(f"[Bash] Process UID: {os.getuid()}")
|
||||
|
||||
# Execute command with inherited environment variables
|
||||
# On Windows, convert $VAR references to %VAR% for cmd.exe
|
||||
if self._IS_WIN:
|
||||
env["PYTHONIOENCODING"] = "utf-8"
|
||||
command = self._convert_env_vars_for_windows(command, dotenv_vars)
|
||||
if command and not command.strip().lower().startswith("chcp"):
|
||||
command = f"chcp 65001 >nul 2>&1 && {command}"
|
||||
|
||||
result = subprocess.run(
|
||||
command,
|
||||
shell=True,
|
||||
@@ -108,8 +120,10 @@ IMPORTANT SAFETY GUIDELINES:
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
text=True,
|
||||
encoding="utf-8",
|
||||
errors="replace",
|
||||
timeout=timeout,
|
||||
env=env
|
||||
env=env,
|
||||
)
|
||||
|
||||
logger.debug(f"[Bash] Exit code: {result.returncode}")
|
||||
@@ -131,6 +145,8 @@ IMPORTANT SAFETY GUIDELINES:
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
text=True,
|
||||
encoding="utf-8",
|
||||
errors="replace",
|
||||
timeout=timeout,
|
||||
env=env
|
||||
)
|
||||
@@ -153,10 +169,16 @@ IMPORTANT SAFETY GUIDELINES:
|
||||
except Exception as retry_err:
|
||||
logger.warning(f"[Bash] Retry failed: {retry_err}")
|
||||
|
||||
# Combine stdout and stderr
|
||||
output = result.stdout
|
||||
if result.stderr:
|
||||
output += "\n" + result.stderr
|
||||
# When command succeeds with stdout, keep output clean (stderr goes to server log only).
|
||||
# When command fails or stdout is empty, include stderr so the agent can diagnose.
|
||||
if result.returncode == 0 and result.stdout.strip():
|
||||
output = result.stdout
|
||||
if result.stderr:
|
||||
logger.info(f"[Bash] stderr (not forwarded): {result.stderr[:500]}")
|
||||
else:
|
||||
output = result.stdout
|
||||
if result.stderr:
|
||||
output += "\n" + result.stderr
|
||||
|
||||
# Check if we need to save full output to temp file
|
||||
temp_file_path = None
|
||||
@@ -216,45 +238,58 @@ IMPORTANT SAFETY GUIDELINES:
|
||||
|
||||
def _get_safety_warning(self, command: str) -> str:
|
||||
"""
|
||||
Get safety warning for potentially dangerous commands
|
||||
Only warns about extremely dangerous system-level operations
|
||||
|
||||
Get safety warning for absolutely catastrophic commands only.
|
||||
Keep the blocklist minimal so the agent retains maximum freedom.
|
||||
|
||||
:param command: Command to check
|
||||
:return: Warning message if dangerous, empty string if safe
|
||||
"""
|
||||
cmd_lower = command.lower().strip()
|
||||
# Tokenize to avoid substring false positives (e.g. `rm -rf /tmp/x`
|
||||
# must not match `rm -rf /`).
|
||||
tokens = command.lower().split()
|
||||
|
||||
# Only block extremely dangerous system operations
|
||||
dangerous_patterns = [
|
||||
# System shutdown/reboot
|
||||
("shutdown", "This command will shut down the system"),
|
||||
("reboot", "This command will reboot the system"),
|
||||
("halt", "This command will halt the system"),
|
||||
("poweroff", "This command will power off the system"),
|
||||
# `rm -rf /` or `rm -rf /*` targeting the real root.
|
||||
for i, tok in enumerate(tokens):
|
||||
if tok != "rm":
|
||||
continue
|
||||
has_rf = False
|
||||
for j in range(i + 1, len(tokens)):
|
||||
t = tokens[j]
|
||||
if t.startswith("-") and "r" in t and "f" in t:
|
||||
has_rf = True
|
||||
elif t in ("--recursive", "--force"):
|
||||
continue
|
||||
elif t in ("/", "/*"):
|
||||
if has_rf:
|
||||
return "This command will delete the entire filesystem"
|
||||
break
|
||||
else:
|
||||
break
|
||||
|
||||
# Critical system modifications
|
||||
("rm -rf /", "This command will delete the entire filesystem"),
|
||||
("rm -rf /*", "This command will delete the entire filesystem"),
|
||||
("dd if=/dev/zero", "This command can destroy disk data"),
|
||||
("mkfs", "This command will format a filesystem, destroying all data"),
|
||||
("fdisk", "This command modifies disk partitions"),
|
||||
# Disk wiping
|
||||
if "if=/dev/zero" in command.lower() and "dd " in command.lower():
|
||||
return "This command can destroy disk data"
|
||||
|
||||
# User/system management (only if targeting system users)
|
||||
("userdel root", "This command will delete the root user"),
|
||||
("passwd root", "This command will change the root password"),
|
||||
]
|
||||
# Power control - match only as a standalone word (\b enforces word boundary)
|
||||
if re.search(r'\b(shutdown|reboot|halt|poweroff)\b', command.lower()):
|
||||
return "This command will shut down or restart the system"
|
||||
|
||||
for pattern, warning in dangerous_patterns:
|
||||
if pattern in cmd_lower:
|
||||
return warning
|
||||
return ""
|
||||
|
||||
# Check for recursive deletion outside workspace
|
||||
if "rm" in cmd_lower and "-rf" in cmd_lower:
|
||||
# Allow deletion within current workspace
|
||||
if not any(path in cmd_lower for path in ["./", self.cwd.lower()]):
|
||||
# Check if targeting system directories
|
||||
system_dirs = ["/bin", "/usr", "/etc", "/var", "/home", "/root", "/sys", "/proc"]
|
||||
if any(sysdir in cmd_lower for sysdir in system_dirs):
|
||||
return "This command will recursively delete system directories"
|
||||
@staticmethod
|
||||
def _convert_env_vars_for_windows(command: str, dotenv_vars: dict) -> str:
|
||||
"""
|
||||
Convert bash-style $VAR / ${VAR} references to cmd.exe %VAR% syntax.
|
||||
Only converts variables loaded from .env (user-configured API keys etc.)
|
||||
to avoid breaking $PATH, jq expressions, regex, etc.
|
||||
"""
|
||||
if not dotenv_vars:
|
||||
return command
|
||||
|
||||
return "" # No warning needed
|
||||
def replace_match(m):
|
||||
var_name = m.group(1) or m.group(2)
|
||||
if var_name in dotenv_vars:
|
||||
return f"%{var_name}%"
|
||||
return m.group(0)
|
||||
|
||||
return re.sub(r'\$\{(\w+)\}|\$(\w+)', replace_match, command)
|
||||
|
||||
3
agent/tools/browser/__init__.py
Normal file
3
agent/tools/browser/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from agent.tools.browser.browser_tool import BrowserTool
|
||||
|
||||
__all__ = ["BrowserTool"]
|
||||
961
agent/tools/browser/browser_service.py
Normal file
961
agent/tools/browser/browser_service.py
Normal file
@@ -0,0 +1,961 @@
|
||||
"""
|
||||
Browser service - Playwright wrapper managing browser lifecycle and page operations.
|
||||
|
||||
All Playwright calls run on a dedicated background thread so that callers from
|
||||
any worker thread can safely use the service. An idle-timeout mechanism
|
||||
automatically shuts down the browser (and its thread) after a configurable
|
||||
period of inactivity to free resources.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import uuid
|
||||
import queue
|
||||
import threading
|
||||
from typing import Optional, Dict, Any, List, Callable
|
||||
|
||||
from common.log import logger
|
||||
from common.utils import expand_path, is_cloud_deployment
|
||||
|
||||
|
||||
_DEFAULT_USER_DATA_DIR = "~/.cow/browser_profile"
|
||||
|
||||
try:
|
||||
from playwright.sync_api import sync_playwright, Browser, BrowserContext, Page, Playwright
|
||||
_HAS_PLAYWRIGHT = True
|
||||
except ImportError:
|
||||
_HAS_PLAYWRIGHT = False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Snapshot DOM helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Tags that typically carry useful content for an agent
|
||||
_INTERACTIVE_TAGS = {
|
||||
"a", "button", "input", "textarea", "select", "option",
|
||||
"label", "details", "summary",
|
||||
}
|
||||
_SEMANTIC_TAGS = {
|
||||
"h1", "h2", "h3", "h4", "h5", "h6",
|
||||
"p", "li", "td", "th", "caption", "figcaption", "blockquote", "pre", "code",
|
||||
"nav", "main", "article", "section", "header", "footer", "form", "table",
|
||||
"img", "video", "audio",
|
||||
}
|
||||
_KEEP_TAGS = _INTERACTIVE_TAGS | _SEMANTIC_TAGS
|
||||
|
||||
_SNAPSHOT_JS = """
|
||||
() => {
|
||||
const KEEP = new Set(%s);
|
||||
const INTERACTIVE = new Set(%s);
|
||||
const SKIP = new Set(["script","style","noscript","svg","path","meta","link","br","hr"]);
|
||||
const CLICKABLE_ROLES = new Set([
|
||||
"button","link","tab","menuitem","menuitemcheckbox","menuitemradio",
|
||||
"option","switch","checkbox","radio","combobox","searchbox","slider",
|
||||
"spinbutton","textbox","treeitem"
|
||||
]);
|
||||
let refCounter = 0;
|
||||
const refMap = {};
|
||||
|
||||
function visible(el) {
|
||||
if (!(el instanceof HTMLElement)) return true;
|
||||
const st = window.getComputedStyle(el);
|
||||
if (st.display === "none" || st.visibility === "hidden") return false;
|
||||
if (parseFloat(st.opacity) === 0) return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
// Strong signals: these attributes alone are enough to mark as interactive
|
||||
function hasStrongInteractiveSignal(el) {
|
||||
const role = el.getAttribute("role");
|
||||
if (role && CLICKABLE_ROLES.has(role)) return true;
|
||||
if (el.hasAttribute("onclick") || el.hasAttribute("tabindex")) return true;
|
||||
if (el.hasAttribute("data-click") || el.hasAttribute("data-action")) return true;
|
||||
if (el.getAttribute("contenteditable") === "true") return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check if cursor:pointer is set directly (not just inherited from parent)
|
||||
function hasOwnPointerCursor(el) {
|
||||
try {
|
||||
const st = window.getComputedStyle(el);
|
||||
if (st.cursor !== "pointer") return false;
|
||||
const parent = el.parentElement;
|
||||
if (parent) {
|
||||
const pst = window.getComputedStyle(parent);
|
||||
if (pst.cursor === "pointer") return false;
|
||||
}
|
||||
return true;
|
||||
} catch(e) {}
|
||||
return false;
|
||||
}
|
||||
|
||||
function hasTextOrContent(el) {
|
||||
const t = el.textContent || "";
|
||||
if (t.trim().length > 0) return true;
|
||||
if (el.querySelector("img,video,audio,canvas")) return true;
|
||||
const ariaLabel = el.getAttribute("aria-label");
|
||||
if (ariaLabel && ariaLabel.trim()) return true;
|
||||
const title = el.getAttribute("title");
|
||||
if (title && title.trim()) return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
function isImplicitInteractive(el) {
|
||||
if (hasStrongInteractiveSignal(el)) return true;
|
||||
if (hasOwnPointerCursor(el) && hasTextOrContent(el)) return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
function getTextContent(el) {
|
||||
let text = "";
|
||||
for (const ch of el.childNodes) {
|
||||
if (ch.nodeType === Node.TEXT_NODE) {
|
||||
text += ch.textContent;
|
||||
}
|
||||
}
|
||||
return text.trim();
|
||||
}
|
||||
|
||||
function walk(node) {
|
||||
if (node.nodeType === Node.TEXT_NODE) {
|
||||
const t = node.textContent.trim();
|
||||
return t ? t : null;
|
||||
}
|
||||
if (node.nodeType !== Node.ELEMENT_NODE) return null;
|
||||
const tag = node.tagName.toLowerCase();
|
||||
if (SKIP.has(tag)) return null;
|
||||
if (!visible(node)) return null;
|
||||
|
||||
const children = [];
|
||||
for (const ch of node.childNodes) {
|
||||
const r = walk(ch);
|
||||
if (r !== null) {
|
||||
if (typeof r === "string") children.push(r);
|
||||
else children.push(r);
|
||||
}
|
||||
}
|
||||
|
||||
const nativeInteractive = INTERACTIVE.has(tag);
|
||||
const implicitInteractive = !nativeInteractive && (node instanceof HTMLElement) && isImplicitInteractive(node);
|
||||
const keep = KEEP.has(tag) || implicitInteractive;
|
||||
|
||||
if (!keep) {
|
||||
if (children.length === 0) return null;
|
||||
if (children.length === 1) return children[0];
|
||||
return children;
|
||||
}
|
||||
|
||||
const obj = { tag };
|
||||
if (nativeInteractive || implicitInteractive) {
|
||||
refCounter++;
|
||||
obj.ref = refCounter;
|
||||
refMap[refCounter] = node;
|
||||
}
|
||||
|
||||
if (implicitInteractive) {
|
||||
const role = node.getAttribute("role");
|
||||
if (role) obj.role = role;
|
||||
const directText = getTextContent(node);
|
||||
if (!directText && children.length === 0) {
|
||||
const ariaLabel = node.getAttribute("aria-label");
|
||||
const title = node.getAttribute("title");
|
||||
if (ariaLabel) obj.ariaLabel = ariaLabel;
|
||||
else if (title) obj.ariaLabel = title;
|
||||
}
|
||||
}
|
||||
|
||||
// Attributes
|
||||
if (tag === "a" && node.href) obj.href = node.getAttribute("href");
|
||||
if (tag === "img") {
|
||||
obj.alt = node.alt || "";
|
||||
obj.src = node.getAttribute("src") || "";
|
||||
}
|
||||
if (tag === "input" || tag === "textarea" || tag === "select") {
|
||||
obj.type = node.type || "text";
|
||||
obj.name = node.name || undefined;
|
||||
obj.value = node.value || undefined;
|
||||
obj.placeholder = node.placeholder || undefined;
|
||||
if (node.disabled) obj.disabled = true;
|
||||
if (tag === "input" && node.type === "checkbox") obj.checked = node.checked;
|
||||
}
|
||||
if (tag === "button") {
|
||||
if (node.disabled) obj.disabled = true;
|
||||
}
|
||||
if (tag === "option") {
|
||||
obj.value = node.value;
|
||||
if (node.selected) obj.selected = true;
|
||||
}
|
||||
if (tag === "label" && node.htmlFor) obj.for = node.htmlFor;
|
||||
|
||||
// Role / aria-label for native interactive & semantic elements
|
||||
if (!implicitInteractive) {
|
||||
const role = node.getAttribute("role");
|
||||
if (role) obj.role = role;
|
||||
const ariaLabel = node.getAttribute("aria-label");
|
||||
if (ariaLabel) obj.ariaLabel = ariaLabel;
|
||||
}
|
||||
|
||||
// Children
|
||||
if (children.length === 1 && typeof children[0] === "string") {
|
||||
obj.text = children[0];
|
||||
} else if (children.length > 0) {
|
||||
obj.children = children;
|
||||
}
|
||||
|
||||
return obj;
|
||||
}
|
||||
|
||||
const result = walk(document.body);
|
||||
window.__cowRefMap = refMap;
|
||||
return { tree: result, refCount: refCounter };
|
||||
}
|
||||
""" % (
|
||||
str(list(_KEEP_TAGS)),
|
||||
str(list(_INTERACTIVE_TAGS)),
|
||||
)
|
||||
|
||||
|
||||
_BROWSER_DEAD_HINTS = (
|
||||
"has been closed",
|
||||
"browser has disconnected",
|
||||
"target closed",
|
||||
"browser closed",
|
||||
"context or browser has been closed",
|
||||
)
|
||||
|
||||
|
||||
def _is_browser_dead_error(err: Exception) -> bool:
|
||||
"""Return True if *err* indicates the browser / page died out from under us."""
|
||||
msg = str(err).lower()
|
||||
return any(h in msg for h in _BROWSER_DEAD_HINTS)
|
||||
|
||||
|
||||
def _should_use_headless() -> bool:
|
||||
"""Decide headless mode: headless on Linux servers without display, headed elsewhere."""
|
||||
if sys.platform in ("win32", "darwin"):
|
||||
return False
|
||||
# Linux: check for display
|
||||
if os.environ.get("DISPLAY") or os.environ.get("WAYLAND_DISPLAY"):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _flatten_tree(node, indent=0) -> List[str]:
|
||||
"""Convert snapshot tree to compact text lines for LLM consumption."""
|
||||
if node is None:
|
||||
return []
|
||||
if isinstance(node, str):
|
||||
return [" " * indent + node]
|
||||
if isinstance(node, list):
|
||||
lines = []
|
||||
for child in node:
|
||||
lines.extend(_flatten_tree(child, indent))
|
||||
return lines
|
||||
if not isinstance(node, dict):
|
||||
return []
|
||||
|
||||
tag = node.get("tag", "?")
|
||||
ref = node.get("ref")
|
||||
parts = [tag]
|
||||
if ref:
|
||||
parts[0] = f"[{ref}] {tag}"
|
||||
|
||||
# Inline attributes
|
||||
for attr in ("type", "name", "href", "alt", "role", "ariaLabel", "placeholder", "value"):
|
||||
val = node.get(attr)
|
||||
if val:
|
||||
# Truncate long values
|
||||
s = str(val)
|
||||
if len(s) > 80:
|
||||
s = s[:77] + "..."
|
||||
parts.append(f'{attr}="{s}"')
|
||||
|
||||
for flag in ("disabled", "checked", "selected"):
|
||||
if node.get(flag):
|
||||
parts.append(flag)
|
||||
|
||||
prefix = " " * indent
|
||||
header = prefix + " ".join(parts)
|
||||
|
||||
text = node.get("text")
|
||||
if text:
|
||||
# Truncate long text
|
||||
if len(text) > 120:
|
||||
text = text[:117] + "..."
|
||||
header += f": {text}"
|
||||
|
||||
lines = [header]
|
||||
children = node.get("children", [])
|
||||
for child in children:
|
||||
lines.extend(_flatten_tree(child, indent + 2))
|
||||
return lines
|
||||
|
||||
|
||||
class BrowserService:
|
||||
"""Manages a Playwright browser on a dedicated background thread.
|
||||
|
||||
All Playwright operations are dispatched to a single long-lived thread via
|
||||
a task queue. Callers from *any* worker thread can use the public API
|
||||
safely. An idle timer automatically shuts the browser down after
|
||||
``idle_timeout`` seconds of inactivity (default 300 = 5 min).
|
||||
"""
|
||||
|
||||
_IDLE_TIMEOUT_DEFAULT = 300 # seconds
|
||||
|
||||
def __init__(self, config: Optional[Dict[str, Any]] = None):
|
||||
self._config = config or {}
|
||||
self._headless: Optional[bool] = None
|
||||
self._screenshot_dir: Optional[str] = None
|
||||
|
||||
# Background thread state
|
||||
self._thread: Optional[threading.Thread] = None
|
||||
self._task_queue: queue.Queue = queue.Queue()
|
||||
self._lock = threading.Lock()
|
||||
self._alive = False
|
||||
self._ready = threading.Event()
|
||||
|
||||
# Playwright objects (only accessed on the background thread)
|
||||
self._playwright = None
|
||||
self._browser = None
|
||||
self._context = None
|
||||
self._page = None
|
||||
|
||||
# Launch mode: one of "fresh" | "persistent" | "cdp".
|
||||
# - cdp: connect to an externally launched Chrome via CDP endpoint.
|
||||
# - persistent: launch with launch_persistent_context using a user_data_dir
|
||||
# so cookies / login state survive across runs (default).
|
||||
# - fresh: classic launch + new_context, clean state every run.
|
||||
cdp_endpoint = self._config.get("cdp_endpoint") or ""
|
||||
persistent_flag = self._config.get("persistent", True)
|
||||
user_data_dir_cfg = self._config.get("user_data_dir")
|
||||
if user_data_dir_cfg is None:
|
||||
user_data_dir_cfg = _DEFAULT_USER_DATA_DIR
|
||||
|
||||
self._cdp_endpoint: str = cdp_endpoint.strip() if isinstance(cdp_endpoint, str) else ""
|
||||
if self._cdp_endpoint:
|
||||
self._launch_mode = "cdp"
|
||||
self._user_data_dir: str = ""
|
||||
elif persistent_flag and user_data_dir_cfg:
|
||||
self._launch_mode = "persistent"
|
||||
self._user_data_dir = expand_path(str(user_data_dir_cfg))
|
||||
else:
|
||||
self._launch_mode = "fresh"
|
||||
self._user_data_dir = ""
|
||||
|
||||
# Idle auto-release
|
||||
idle_cfg = self._config.get("idle_timeout")
|
||||
self._idle_timeout: float = float(idle_cfg) if idle_cfg is not None else self._IDLE_TIMEOUT_DEFAULT
|
||||
self._idle_timer: Optional[threading.Timer] = None
|
||||
|
||||
# Set when the browser / page is detected to have died externally
|
||||
# (e.g. user manually closed the window). The next _submit() will then
|
||||
# tear down the stale thread and relaunch.
|
||||
self._needs_restart = False
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Background-thread lifecycle
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _start_thread(self):
|
||||
"""Start the dedicated Playwright thread if not already running."""
|
||||
with self._lock:
|
||||
if self._alive and self._thread and self._thread.is_alive():
|
||||
return
|
||||
# Wait for old thread to fully exit before creating a new one
|
||||
old = self._thread
|
||||
if old and old.is_alive():
|
||||
old.join(timeout=5)
|
||||
# Fresh queue to avoid stale sentinels from a previous close()
|
||||
self._task_queue = queue.Queue()
|
||||
self._alive = True
|
||||
self._ready = threading.Event()
|
||||
self._thread = threading.Thread(target=self._run_loop, daemon=True, name="BrowserThread")
|
||||
self._thread.start()
|
||||
# Block until browser is ready (or failed)
|
||||
self._ready.wait(timeout=30)
|
||||
|
||||
def _run_loop(self):
|
||||
"""Event loop running on the dedicated thread. Processes tasks until stopped."""
|
||||
logger.info("[Browser] Background thread started")
|
||||
try:
|
||||
self._launch_browser()
|
||||
except Exception as e:
|
||||
logger.error(f"[Browser] Failed to launch browser: {e}")
|
||||
self._alive = False
|
||||
self._ready.set()
|
||||
self._drain_queue(RuntimeError(f"Browser launch failed: {e}"))
|
||||
return
|
||||
self._ready.set()
|
||||
|
||||
while self._alive:
|
||||
try:
|
||||
task = self._task_queue.get(timeout=1.0)
|
||||
except queue.Empty:
|
||||
continue
|
||||
if task is None:
|
||||
break
|
||||
fn, args, kwargs, result_slot = task
|
||||
try:
|
||||
result_slot["value"] = fn(*args, **kwargs)
|
||||
except Exception as e:
|
||||
result_slot["error"] = e
|
||||
if _is_browser_dead_error(e):
|
||||
self._needs_restart = True
|
||||
logger.warning(
|
||||
f"[Browser] Detected closed page/context ({e}); "
|
||||
"will relaunch on next request."
|
||||
)
|
||||
finally:
|
||||
result_slot["event"].set()
|
||||
|
||||
self._shutdown_browser()
|
||||
self._drain_queue(RuntimeError("Browser thread stopped"))
|
||||
logger.info("[Browser] Background thread exited")
|
||||
|
||||
def _drain_queue(self, error: Exception):
|
||||
"""Unblock all callers waiting on the queue with an error."""
|
||||
while True:
|
||||
try:
|
||||
task = self._task_queue.get_nowait()
|
||||
except queue.Empty:
|
||||
break
|
||||
if task is None:
|
||||
continue
|
||||
_, _, _, result_slot = task
|
||||
result_slot["error"] = error
|
||||
result_slot["event"].set()
|
||||
|
||||
def _launch_browser(self):
|
||||
"""Launch / connect Chromium on the background thread."""
|
||||
if self._headless is None:
|
||||
headless_cfg = self._config.get("headless")
|
||||
self._headless = headless_cfg if headless_cfg is not None else _should_use_headless()
|
||||
|
||||
launch_args = ["--disable-dev-shm-usage"]
|
||||
if self._headless:
|
||||
launch_args.append("--no-sandbox")
|
||||
|
||||
if is_cloud_deployment():
|
||||
launch_args.extend([
|
||||
"--disable-gpu",
|
||||
"--disable-software-rasterizer",
|
||||
"--disable-extensions",
|
||||
"--disable-background-networking",
|
||||
"--disable-background-timer-throttling",
|
||||
"--disable-renderer-backgrounding",
|
||||
"--disable-features=site-per-process,TranslateUI,IsolateOrigins",
|
||||
"--no-zygote",
|
||||
"--js-flags=--max-old-space-size=384",
|
||||
"--memory-pressure-off",
|
||||
])
|
||||
|
||||
extra_args = self._config.get("launch_args", [])
|
||||
if extra_args:
|
||||
launch_args.extend(extra_args)
|
||||
|
||||
viewport_w = self._config.get("viewport_width", 1280)
|
||||
viewport_h = self._config.get("viewport_height", 720)
|
||||
viewport = {"width": viewport_w, "height": viewport_h}
|
||||
user_agent = (
|
||||
"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) "
|
||||
"AppleWebKit/537.36 (KHTML, like Gecko) "
|
||||
"Chrome/131.0.0.0 Safari/537.36"
|
||||
)
|
||||
|
||||
self._playwright = sync_playwright().start()
|
||||
|
||||
if self._launch_mode == "cdp":
|
||||
self._connect_cdp(viewport)
|
||||
elif self._launch_mode == "persistent":
|
||||
self._launch_persistent(launch_args, viewport, user_agent)
|
||||
else:
|
||||
self._launch_fresh(launch_args, viewport, user_agent)
|
||||
|
||||
logger.info("[Browser] Browser ready")
|
||||
|
||||
def _launch_fresh(self, launch_args: List[str], viewport: Dict[str, int], user_agent: str):
|
||||
"""Classic launch: brand new Chromium with an empty context."""
|
||||
logger.info(f"[Browser] Launching Chromium (fresh, headless={self._headless})")
|
||||
self._browser = self._playwright.chromium.launch(
|
||||
headless=self._headless,
|
||||
args=launch_args,
|
||||
)
|
||||
self._context = self._browser.new_context(
|
||||
viewport=viewport,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
self._page = self._context.new_page()
|
||||
self._wire_close_listeners()
|
||||
|
||||
def _launch_persistent(self, launch_args: List[str], viewport: Dict[str, int], user_agent: str):
|
||||
"""Launch Chromium with a persistent user_data_dir so login state survives."""
|
||||
os.makedirs(self._user_data_dir, exist_ok=True)
|
||||
logger.info(
|
||||
f"[Browser] Launching Chromium (persistent, headless={self._headless}, "
|
||||
f"profile={self._user_data_dir})"
|
||||
)
|
||||
try:
|
||||
self._context = self._playwright.chromium.launch_persistent_context(
|
||||
user_data_dir=self._user_data_dir,
|
||||
headless=self._headless,
|
||||
args=launch_args,
|
||||
viewport=viewport,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
except Exception as e:
|
||||
# Profile is locked when another Chromium instance already holds it.
|
||||
msg = str(e).lower()
|
||||
if "singletonlock" in msg or "profile" in msg or "lock" in msg:
|
||||
raise RuntimeError(
|
||||
f"Browser profile '{self._user_data_dir}' is in use by another process. "
|
||||
"Close the other Chromium / cow instance, or set a different "
|
||||
"tools.browser.user_data_dir."
|
||||
) from e
|
||||
raise
|
||||
|
||||
# Persistent context has no parent Browser handle; reuse the auto-created page.
|
||||
self._browser = None
|
||||
pages = self._context.pages
|
||||
self._page = pages[0] if pages else self._context.new_page()
|
||||
self._wire_close_listeners()
|
||||
|
||||
def _connect_cdp(self, viewport: Dict[str, int]):
|
||||
"""Attach to an existing Chrome started with --remote-debugging-port."""
|
||||
endpoint = self._cdp_endpoint
|
||||
logger.info(f"[Browser] Connecting to existing Chrome via CDP: {endpoint}")
|
||||
try:
|
||||
self._browser = self._playwright.chromium.connect_over_cdp(endpoint)
|
||||
except Exception as e:
|
||||
msg = str(e).lower()
|
||||
if "econnrefused" in msg or "connect" in msg or "refused" in msg:
|
||||
raise RuntimeError(
|
||||
f"Cannot reach Chrome at {endpoint}. The CDP browser is not "
|
||||
"running. Ask the user to launch Chrome with "
|
||||
"--remote-debugging-port and --user-data-dir, then retry. "
|
||||
"Do not retry this tool until the user confirms."
|
||||
) from e
|
||||
raise
|
||||
|
||||
contexts = self._browser.contexts
|
||||
if contexts:
|
||||
self._context = contexts[0]
|
||||
else:
|
||||
self._context = self._browser.new_context(viewport=viewport)
|
||||
|
||||
pages = self._context.pages
|
||||
self._page = pages[0] if pages else self._context.new_page()
|
||||
self._wire_close_listeners()
|
||||
|
||||
def _wire_close_listeners(self):
|
||||
"""Mark needs_restart whenever the browser / context / page dies externally."""
|
||||
def _on_dead(_obj=None):
|
||||
self._needs_restart = True
|
||||
|
||||
try:
|
||||
if self._browser:
|
||||
self._browser.on("disconnected", _on_dead)
|
||||
if self._context:
|
||||
self._context.on("close", _on_dead)
|
||||
if self._page:
|
||||
self._page.on("close", _on_dead)
|
||||
except Exception as e:
|
||||
logger.debug(f"[Browser] Failed to wire close listeners: {e}")
|
||||
|
||||
def _shutdown_browser(self):
|
||||
"""Shut down Playwright resources on the background thread.
|
||||
|
||||
Mode-specific behavior:
|
||||
- cdp: only disconnect the Playwright client; leave the user's Chrome
|
||||
and its tabs untouched (do NOT close the context).
|
||||
- persistent: close the persistent context (no separate browser handle).
|
||||
- fresh: close context, then browser.
|
||||
"""
|
||||
self._cancel_idle_timer()
|
||||
|
||||
if self._launch_mode == "cdp":
|
||||
# For CDP, browser.close() only detaches the Playwright client;
|
||||
# the user's Chrome process and its tabs stay alive.
|
||||
try:
|
||||
if self._browser:
|
||||
self._browser.close()
|
||||
except Exception as e:
|
||||
logger.debug(f"[Browser] cdp disconnect error: {e}")
|
||||
else:
|
||||
for obj, label in [
|
||||
(self._context, "context"),
|
||||
(self._browser, "browser"),
|
||||
]:
|
||||
try:
|
||||
if obj:
|
||||
obj.close()
|
||||
except Exception as e:
|
||||
logger.debug(f"[Browser] {label} close error: {e}")
|
||||
|
||||
try:
|
||||
if self._playwright:
|
||||
self._playwright.stop()
|
||||
except Exception as e:
|
||||
logger.debug(f"[Browser] playwright stop error: {e}")
|
||||
self._page = None
|
||||
self._context = None
|
||||
self._browser = None
|
||||
self._playwright = None
|
||||
logger.info("[Browser] Browser closed")
|
||||
|
||||
def _submit(self, fn: Callable, *args, **kwargs):
|
||||
"""Submit *fn* to the background thread and block until it completes."""
|
||||
# If the browser died externally (e.g. user closed the window), tear
|
||||
# down the stale thread first so _start_thread() will relaunch fresh.
|
||||
if self._needs_restart:
|
||||
logger.info("[Browser] Restarting after detecting closed browser")
|
||||
self.close()
|
||||
self._needs_restart = False
|
||||
|
||||
self._start_thread()
|
||||
|
||||
if not self._alive:
|
||||
raise RuntimeError("Browser is not available")
|
||||
|
||||
self._reset_idle_timer()
|
||||
|
||||
result_slot: Dict[str, Any] = {"event": threading.Event()}
|
||||
self._task_queue.put((fn, args, kwargs, result_slot))
|
||||
|
||||
# Timeout prevents permanent hang if the background thread crashes
|
||||
completed = result_slot["event"].wait(timeout=120)
|
||||
if not completed:
|
||||
raise TimeoutError("Browser operation timed out (120s)")
|
||||
|
||||
if "error" in result_slot:
|
||||
raise result_slot["error"]
|
||||
return result_slot.get("value")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Idle auto-release
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _reset_idle_timer(self):
|
||||
self._cancel_idle_timer()
|
||||
if self._idle_timeout > 0:
|
||||
self._idle_timer = threading.Timer(self._idle_timeout, self._on_idle_timeout)
|
||||
self._idle_timer.daemon = True
|
||||
self._idle_timer.start()
|
||||
|
||||
def _cancel_idle_timer(self):
|
||||
if self._idle_timer:
|
||||
self._idle_timer.cancel()
|
||||
self._idle_timer = None
|
||||
|
||||
def _on_idle_timeout(self):
|
||||
logger.info(f"[Browser] Idle for {self._idle_timeout}s, auto-releasing browser")
|
||||
self.close()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Public lifecycle
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def close(self):
|
||||
"""Shut down browser and background thread (safe from any thread)."""
|
||||
self._cancel_idle_timer()
|
||||
with self._lock:
|
||||
if not self._alive:
|
||||
self._needs_restart = False
|
||||
return
|
||||
self._alive = False
|
||||
t = self._thread
|
||||
if self._task_queue is not None:
|
||||
self._task_queue.put(None)
|
||||
if t is not None and t.is_alive():
|
||||
t.join(timeout=10)
|
||||
with self._lock:
|
||||
self._thread = None
|
||||
self._needs_restart = False
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Actions (each method is dispatched to the background thread)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def navigate(self, url: str, timeout: int = 30000) -> Dict[str, Any]:
|
||||
return self._submit(self._do_navigate, url, timeout)
|
||||
|
||||
def _do_navigate(self, url: str, timeout: int) -> Dict[str, Any]:
|
||||
page = self._page
|
||||
try:
|
||||
resp = page.goto(url, wait_until="domcontentloaded", timeout=timeout)
|
||||
status = resp.status if resp else None
|
||||
except Exception as e:
|
||||
return {"error": f"Navigation failed: {e}"}
|
||||
|
||||
try:
|
||||
page.wait_for_load_state("networkidle", timeout=8000)
|
||||
except Exception:
|
||||
pass
|
||||
page.wait_for_timeout(500)
|
||||
|
||||
try:
|
||||
title = page.title()
|
||||
except Exception:
|
||||
title = ""
|
||||
try:
|
||||
current_url = page.url
|
||||
except Exception:
|
||||
current_url = url
|
||||
|
||||
return {"url": current_url, "title": title, "status": status}
|
||||
|
||||
def snapshot(self, selector: Optional[str] = None) -> str:
|
||||
return self._submit(self._do_snapshot, selector)
|
||||
|
||||
def _do_snapshot(self, selector: Optional[str] = None) -> str:
|
||||
page = self._page
|
||||
try:
|
||||
result = page.evaluate(_SNAPSHOT_JS)
|
||||
except Exception as e:
|
||||
return f"[Snapshot error: {e}]"
|
||||
|
||||
tree = result.get("tree")
|
||||
ref_count = result.get("refCount", 0)
|
||||
lines = _flatten_tree(tree)
|
||||
|
||||
try:
|
||||
title = page.title()
|
||||
except Exception:
|
||||
title = ""
|
||||
try:
|
||||
url = page.url
|
||||
except Exception:
|
||||
url = ""
|
||||
|
||||
header = f"Page: {title} ({url})\nInteractive elements: {ref_count}\n---"
|
||||
body = "\n".join(lines)
|
||||
|
||||
max_chars = self._config.get("snapshot_max_chars", 30000)
|
||||
if len(body) > max_chars:
|
||||
body = body[:max_chars] + "\n... [snapshot truncated]"
|
||||
|
||||
return f"{header}\n{body}"
|
||||
|
||||
def screenshot(self, full_page: bool = False, cwd: str = "") -> str:
|
||||
return self._submit(self._do_screenshot, full_page, cwd)
|
||||
|
||||
def _do_screenshot(self, full_page: bool = False, cwd: str = "") -> str:
|
||||
page = self._page
|
||||
save_dir = self._get_screenshot_dir(cwd)
|
||||
filename = f"screenshot_{uuid.uuid4().hex[:8]}.png"
|
||||
filepath = os.path.join(save_dir, filename)
|
||||
page.screenshot(path=filepath, full_page=full_page)
|
||||
logger.info(f"[Browser] Screenshot saved: {filepath}")
|
||||
return filepath
|
||||
|
||||
def click(self, ref: Optional[int] = None, selector: Optional[str] = None,
|
||||
timeout: int = 5000) -> Dict[str, Any]:
|
||||
return self._submit(self._do_click, ref, selector, timeout)
|
||||
|
||||
def _do_click(self, ref, selector, timeout) -> Dict[str, Any]:
|
||||
page = self._page
|
||||
try:
|
||||
if ref is not None:
|
||||
result = page.evaluate(f"""
|
||||
() => {{
|
||||
const el = window.__cowRefMap && window.__cowRefMap[{ref}];
|
||||
if (!el) return {{ error: "ref {ref} not found. Run snapshot first." }};
|
||||
el.click();
|
||||
return {{ clicked: true, tag: el.tagName.toLowerCase() }};
|
||||
}}
|
||||
""")
|
||||
if result.get("error"):
|
||||
return result
|
||||
page.wait_for_timeout(500)
|
||||
return result
|
||||
elif selector:
|
||||
page.click(selector, timeout=timeout)
|
||||
return {"clicked": True, "selector": selector}
|
||||
else:
|
||||
return {"error": "Provide either ref (from snapshot) or selector"}
|
||||
except Exception as e:
|
||||
return {"error": f"Click failed: {e}"}
|
||||
|
||||
def fill(self, text: str, ref: Optional[int] = None,
|
||||
selector: Optional[str] = None, timeout: int = 5000) -> Dict[str, Any]:
|
||||
return self._submit(self._do_fill, text, ref, selector, timeout)
|
||||
|
||||
def _do_fill(self, text, ref, selector, timeout) -> Dict[str, Any]:
|
||||
page = self._page
|
||||
try:
|
||||
if ref is not None:
|
||||
result = page.evaluate(f"""
|
||||
() => {{
|
||||
const el = window.__cowRefMap && window.__cowRefMap[{ref}];
|
||||
if (!el) return {{ error: "ref {ref} not found. Run snapshot first." }};
|
||||
el.focus();
|
||||
el.value = "";
|
||||
return {{ tag: el.tagName.toLowerCase(), name: el.name || "" }};
|
||||
}}
|
||||
""")
|
||||
if result.get("error"):
|
||||
return result
|
||||
page.keyboard.type(text)
|
||||
return {"filled": True, "ref": ref, "text": text}
|
||||
elif selector:
|
||||
page.fill(selector, text, timeout=timeout)
|
||||
return {"filled": True, "selector": selector, "text": text}
|
||||
else:
|
||||
return {"error": "Provide either ref (from snapshot) or selector"}
|
||||
except Exception as e:
|
||||
return {"error": f"Fill failed: {e}"}
|
||||
|
||||
def select(self, value: str, ref: Optional[int] = None,
|
||||
selector: Optional[str] = None, timeout: int = 5000) -> Dict[str, Any]:
|
||||
return self._submit(self._do_select, value, ref, selector, timeout)
|
||||
|
||||
def _do_select(self, value, ref, selector, timeout) -> Dict[str, Any]:
|
||||
page = self._page
|
||||
try:
|
||||
if ref is not None:
|
||||
result = page.evaluate(f"""
|
||||
() => {{
|
||||
const el = window.__cowRefMap && window.__cowRefMap[{ref}];
|
||||
if (!el || el.tagName.toLowerCase() !== "select")
|
||||
return {{ error: "ref {ref} is not a <select> element" }};
|
||||
el.value = {repr(value)};
|
||||
el.dispatchEvent(new Event("change", {{ bubbles: true }}));
|
||||
return {{ selected: true, value: el.value }};
|
||||
}}
|
||||
""")
|
||||
return result
|
||||
elif selector:
|
||||
page.select_option(selector, value, timeout=timeout)
|
||||
return {"selected": True, "selector": selector, "value": value}
|
||||
else:
|
||||
return {"error": "Provide either ref (from snapshot) or selector"}
|
||||
except Exception as e:
|
||||
return {"error": f"Select failed: {e}"}
|
||||
|
||||
def scroll(self, direction: str = "down", amount: int = 500) -> Dict[str, Any]:
|
||||
return self._submit(self._do_scroll, direction, amount)
|
||||
|
||||
def _do_scroll(self, direction, amount) -> Dict[str, Any]:
|
||||
page = self._page
|
||||
delta_map = {
|
||||
"down": (0, amount),
|
||||
"up": (0, -amount),
|
||||
"right": (amount, 0),
|
||||
"left": (-amount, 0),
|
||||
}
|
||||
dx, dy = delta_map.get(direction, (0, amount))
|
||||
try:
|
||||
page.mouse.wheel(dx, dy)
|
||||
page.wait_for_timeout(300)
|
||||
scroll_info = page.evaluate("""
|
||||
() => ({
|
||||
scrollX: window.scrollX,
|
||||
scrollY: window.scrollY,
|
||||
scrollHeight: document.documentElement.scrollHeight,
|
||||
clientHeight: document.documentElement.clientHeight
|
||||
})
|
||||
""")
|
||||
return {"scrolled": direction, "amount": amount, **scroll_info}
|
||||
except Exception as e:
|
||||
return {"error": f"Scroll failed: {e}"}
|
||||
|
||||
def wait(self, selector: Optional[str] = None, timeout: int = 5000,
|
||||
state: str = "visible") -> Dict[str, Any]:
|
||||
return self._submit(self._do_wait, selector, timeout, state)
|
||||
|
||||
def _do_wait(self, selector, timeout, state) -> Dict[str, Any]:
|
||||
page = self._page
|
||||
try:
|
||||
if selector:
|
||||
page.wait_for_selector(selector, timeout=timeout, state=state)
|
||||
return {"waited": True, "selector": selector, "state": state}
|
||||
else:
|
||||
page.wait_for_timeout(timeout)
|
||||
return {"waited": True, "timeout_ms": timeout}
|
||||
except Exception as e:
|
||||
return {"error": f"Wait failed: {e}"}
|
||||
|
||||
def go_back(self) -> Dict[str, Any]:
|
||||
return self._submit(self._do_go_back)
|
||||
|
||||
def _do_go_back(self) -> Dict[str, Any]:
|
||||
page = self._page
|
||||
try:
|
||||
page.go_back(wait_until="domcontentloaded", timeout=10000)
|
||||
try:
|
||||
title = page.title()
|
||||
except Exception:
|
||||
title = ""
|
||||
try:
|
||||
url = page.url
|
||||
except Exception:
|
||||
url = ""
|
||||
return {"url": url, "title": title}
|
||||
except Exception as e:
|
||||
return {"error": f"Go back failed: {e}"}
|
||||
|
||||
def go_forward(self) -> Dict[str, Any]:
|
||||
return self._submit(self._do_go_forward)
|
||||
|
||||
def _do_go_forward(self) -> Dict[str, Any]:
|
||||
page = self._page
|
||||
try:
|
||||
page.go_forward(wait_until="domcontentloaded", timeout=10000)
|
||||
try:
|
||||
title = page.title()
|
||||
except Exception:
|
||||
title = ""
|
||||
try:
|
||||
url = page.url
|
||||
except Exception:
|
||||
url = ""
|
||||
return {"url": url, "title": title}
|
||||
except Exception as e:
|
||||
return {"error": f"Go forward failed: {e}"}
|
||||
|
||||
def get_text(self, selector: str) -> Dict[str, Any]:
|
||||
return self._submit(self._do_get_text, selector)
|
||||
|
||||
def _do_get_text(self, selector) -> Dict[str, Any]:
|
||||
page = self._page
|
||||
try:
|
||||
text = page.text_content(selector, timeout=5000)
|
||||
return {"text": text or ""}
|
||||
except Exception as e:
|
||||
return {"error": f"Get text failed: {e}"}
|
||||
|
||||
def evaluate(self, script: str) -> Dict[str, Any]:
|
||||
return self._submit(self._do_evaluate, script)
|
||||
|
||||
def _do_evaluate(self, script) -> Dict[str, Any]:
|
||||
page = self._page
|
||||
try:
|
||||
result = page.evaluate(script)
|
||||
return {"result": result}
|
||||
except Exception as e:
|
||||
return {"error": f"Evaluate failed: {e}"}
|
||||
|
||||
def press(self, key: str) -> Dict[str, Any]:
|
||||
return self._submit(self._do_press, key)
|
||||
|
||||
def _do_press(self, key) -> Dict[str, Any]:
|
||||
page = self._page
|
||||
try:
|
||||
page.keyboard.press(key)
|
||||
page.wait_for_timeout(300)
|
||||
return {"pressed": key}
|
||||
except Exception as e:
|
||||
return {"error": f"Press failed: {e}"}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _get_screenshot_dir(self, cwd: str = "") -> str:
|
||||
if self._screenshot_dir and os.path.isdir(self._screenshot_dir):
|
||||
return self._screenshot_dir
|
||||
base = cwd or os.getcwd()
|
||||
d = os.path.join(base, "tmp")
|
||||
os.makedirs(d, exist_ok=True)
|
||||
self._screenshot_dir = d
|
||||
return d
|
||||
303
agent/tools/browser/browser_tool.py
Normal file
303
agent/tools/browser/browser_tool.py
Normal file
@@ -0,0 +1,303 @@
|
||||
"""
|
||||
Browser tool - Control a Chromium browser for web navigation and interaction.
|
||||
|
||||
Uses Playwright under the hood. Browser instance is lazily started on first
|
||||
use, reused across tool calls within the same session, and cleaned up via
|
||||
close().
|
||||
|
||||
Launch modes (configured under `tools.browser` in config.json):
|
||||
- persistent (default): Chromium runs with a persistent user_data_dir
|
||||
(default `~/.cow/browser_profile`), so cookies and login state survive
|
||||
across runs. The user only needs to log in once.
|
||||
- cdp: When `cdp_endpoint` is set, attach to an externally launched Chrome
|
||||
via the Chrome DevTools Protocol. Lets the agent reuse the user's real
|
||||
browser (with all logins / extensions / true fingerprints).
|
||||
- fresh: Set `persistent` to false to fall back to a clean context every run.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
from typing import Dict, Any, Optional
|
||||
|
||||
from agent.tools.base_tool import BaseTool, ToolResult
|
||||
from agent.tools.browser.browser_service import BrowserService
|
||||
from common.log import logger
|
||||
|
||||
|
||||
class BrowserTool(BaseTool):
|
||||
"""Single tool exposing all browser actions via an 'action' parameter."""
|
||||
|
||||
name: str = "browser"
|
||||
description: str = (
|
||||
"Control a browser to navigate web pages, interact with elements, and extract content. "
|
||||
"Actions: navigate, snapshot, click, fill, select, scroll, screenshot, wait, back, forward, "
|
||||
"get_text, press, evaluate.\n\n"
|
||||
"Workflow: navigate (auto-includes snapshot with element refs) → click/fill/select by ref → snapshot to verify.\n\n"
|
||||
"Use snapshot as the primary way to read pages. Use screenshot + send to show key results to the user. "
|
||||
"For login/CAPTCHA/authorization etc., screenshot and ask the user for help. "
|
||||
"Login state is persisted across sessions (cookies / localStorage are kept in a "
|
||||
"user profile directory), so once the user logs in to a site, the agent can keep "
|
||||
"using it without logging in again."
|
||||
)
|
||||
|
||||
params: dict = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"action": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"The browser action to perform. One of: "
|
||||
"navigate, snapshot, click, fill, select, scroll, "
|
||||
"screenshot, wait, back, forward, get_text, press, evaluate"
|
||||
),
|
||||
"enum": [
|
||||
"navigate", "snapshot", "click", "fill", "select", "scroll",
|
||||
"screenshot", "wait", "back", "forward", "get_text", "press",
|
||||
"evaluate"
|
||||
]
|
||||
},
|
||||
"url": {
|
||||
"type": "string",
|
||||
"description": "URL to navigate to (for 'navigate' action)"
|
||||
},
|
||||
"ref": {
|
||||
"type": "integer",
|
||||
"description": "Element ref number from snapshot (for click/fill/select)"
|
||||
},
|
||||
"selector": {
|
||||
"type": "string",
|
||||
"description": "CSS selector as fallback when ref is unavailable (for click/fill/select/wait/get_text)"
|
||||
},
|
||||
"text": {
|
||||
"type": "string",
|
||||
"description": "Text to type (for 'fill' action)"
|
||||
},
|
||||
"value": {
|
||||
"type": "string",
|
||||
"description": "Option value (for 'select' action)"
|
||||
},
|
||||
"key": {
|
||||
"type": "string",
|
||||
"description": "Key to press, e.g. Enter, Tab, Escape (for 'press' action)"
|
||||
},
|
||||
"direction": {
|
||||
"type": "string",
|
||||
"description": "Scroll direction: up, down, left, right (for 'scroll' action, default: down)"
|
||||
},
|
||||
"script": {
|
||||
"type": "string",
|
||||
"description": "JavaScript code to execute (for 'evaluate' action)"
|
||||
},
|
||||
"full_page": {
|
||||
"type": "boolean",
|
||||
"description": "Capture full page screenshot (for 'screenshot' action, default: false)"
|
||||
},
|
||||
"timeout": {
|
||||
"type": "integer",
|
||||
"description": "Timeout in milliseconds (optional, default varies by action)"
|
||||
}
|
||||
},
|
||||
"required": ["action"]
|
||||
}
|
||||
|
||||
_shared_service: Optional[BrowserService] = None
|
||||
|
||||
def __init__(self, config: dict = None):
|
||||
self.config = config or {}
|
||||
self.cwd = self.config.get("cwd", os.getcwd())
|
||||
self._service: Optional[BrowserService] = None
|
||||
|
||||
def _get_service(self) -> BrowserService:
|
||||
"""Get or create the browser service, sharing across copies."""
|
||||
if self._service is not None:
|
||||
return self._service
|
||||
|
||||
# Reuse shared service across tool copies within the same session
|
||||
if BrowserTool._shared_service is not None:
|
||||
self._service = BrowserTool._shared_service
|
||||
return self._service
|
||||
|
||||
self._service = BrowserService(self.config)
|
||||
BrowserTool._shared_service = self._service
|
||||
return self._service
|
||||
|
||||
def execute(self, args: Dict[str, Any]) -> ToolResult:
|
||||
action = args.get("action", "").strip().lower()
|
||||
if not action:
|
||||
return ToolResult.fail("Error: 'action' parameter is required")
|
||||
|
||||
handler = self._ACTION_MAP.get(action)
|
||||
if not handler:
|
||||
valid = ", ".join(sorted(self._ACTION_MAP.keys()))
|
||||
return ToolResult.fail(f"Unknown action '{action}'. Valid actions: {valid}")
|
||||
|
||||
try:
|
||||
return handler(self, args)
|
||||
except Exception as e:
|
||||
logger.error(f"[Browser] Action '{action}' error: {e}")
|
||||
return ToolResult.fail(f"Browser error ({action}): {e}")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Action handlers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _do_navigate(self, args: Dict[str, Any]) -> ToolResult:
|
||||
url = args.get("url", "").strip()
|
||||
if not url:
|
||||
return ToolResult.fail("Error: 'url' is required for navigate action")
|
||||
# Only auto-prepend https:// for bare hosts; preserve file://, about:, data:, etc.
|
||||
if "://" not in url and not url.startswith(("about:", "data:")):
|
||||
url = "https://" + url
|
||||
timeout = args.get("timeout", 30000)
|
||||
service = self._get_service()
|
||||
result = service.navigate(url, timeout=timeout)
|
||||
if "error" in result:
|
||||
return ToolResult.fail(result["error"])
|
||||
# Auto-snapshot after navigation so the agent gets page content in one call
|
||||
snapshot_text = service.snapshot()
|
||||
return ToolResult.success(
|
||||
f"Navigated to: {result['url']}\nTitle: {result['title']}\nStatus: {result['status']}\n\n"
|
||||
f"--- Page Snapshot ---\n{snapshot_text}"
|
||||
)
|
||||
|
||||
def _do_snapshot(self, args: Dict[str, Any]) -> ToolResult:
|
||||
selector = args.get("selector")
|
||||
text = self._get_service().snapshot(selector=selector)
|
||||
return ToolResult.success(text)
|
||||
|
||||
def _do_click(self, args: Dict[str, Any]) -> ToolResult:
|
||||
ref = args.get("ref")
|
||||
selector = args.get("selector")
|
||||
timeout = args.get("timeout", 5000)
|
||||
result = self._get_service().click(ref=ref, selector=selector, timeout=timeout)
|
||||
if "error" in result:
|
||||
return ToolResult.fail(result["error"])
|
||||
return ToolResult.success(f"Clicked successfully. Use 'snapshot' to see updated page.")
|
||||
|
||||
def _do_fill(self, args: Dict[str, Any]) -> ToolResult:
|
||||
text = args.get("text", "")
|
||||
ref = args.get("ref")
|
||||
selector = args.get("selector")
|
||||
timeout = args.get("timeout", 5000)
|
||||
if not text and text != "":
|
||||
return ToolResult.fail("Error: 'text' is required for fill action")
|
||||
result = self._get_service().fill(text, ref=ref, selector=selector, timeout=timeout)
|
||||
if "error" in result:
|
||||
return ToolResult.fail(result["error"])
|
||||
return ToolResult.success(f"Filled text into element. Use 'snapshot' to verify.")
|
||||
|
||||
def _do_select(self, args: Dict[str, Any]) -> ToolResult:
|
||||
value = args.get("value", "")
|
||||
ref = args.get("ref")
|
||||
selector = args.get("selector")
|
||||
timeout = args.get("timeout", 5000)
|
||||
if not value:
|
||||
return ToolResult.fail("Error: 'value' is required for select action")
|
||||
result = self._get_service().select(value, ref=ref, selector=selector, timeout=timeout)
|
||||
if "error" in result:
|
||||
return ToolResult.fail(result["error"])
|
||||
return ToolResult.success(f"Selected option '{value}'.")
|
||||
|
||||
def _do_scroll(self, args: Dict[str, Any]) -> ToolResult:
|
||||
direction = args.get("direction", "down")
|
||||
amount = args.get("timeout", 500) # reuse timeout field or default
|
||||
if "amount" in args:
|
||||
amount = args["amount"]
|
||||
result = self._get_service().scroll(direction=direction, amount=amount)
|
||||
if "error" in result:
|
||||
return ToolResult.fail(result["error"])
|
||||
pos = f"scrollY={result.get('scrollY', '?')}/{result.get('scrollHeight', '?')}"
|
||||
return ToolResult.success(f"Scrolled {direction}. Position: {pos}")
|
||||
|
||||
def _do_screenshot(self, args: Dict[str, Any]) -> ToolResult:
|
||||
full_page = args.get("full_page", False)
|
||||
filepath = self._get_service().screenshot(full_page=full_page, cwd=self.cwd)
|
||||
return ToolResult.success(f"Screenshot saved to: {filepath}")
|
||||
|
||||
def _do_wait(self, args: Dict[str, Any]) -> ToolResult:
|
||||
selector = args.get("selector")
|
||||
timeout = args.get("timeout", 5000)
|
||||
result = self._get_service().wait(selector=selector, timeout=timeout)
|
||||
if "error" in result:
|
||||
return ToolResult.fail(result["error"])
|
||||
return ToolResult.success(f"Wait completed.")
|
||||
|
||||
def _do_back(self, args: Dict[str, Any]) -> ToolResult:
|
||||
result = self._get_service().go_back()
|
||||
if "error" in result:
|
||||
return ToolResult.fail(result["error"])
|
||||
return ToolResult.success(f"Navigated back to: {result['url']}")
|
||||
|
||||
def _do_forward(self, args: Dict[str, Any]) -> ToolResult:
|
||||
result = self._get_service().go_forward()
|
||||
if "error" in result:
|
||||
return ToolResult.fail(result["error"])
|
||||
return ToolResult.success(f"Navigated forward to: {result['url']}")
|
||||
|
||||
def _do_get_text(self, args: Dict[str, Any]) -> ToolResult:
|
||||
selector = args.get("selector", "").strip()
|
||||
if not selector:
|
||||
return ToolResult.fail("Error: 'selector' is required for get_text action")
|
||||
result = self._get_service().get_text(selector)
|
||||
if "error" in result:
|
||||
return ToolResult.fail(result["error"])
|
||||
return ToolResult.success(result["text"])
|
||||
|
||||
def _do_press(self, args: Dict[str, Any]) -> ToolResult:
|
||||
key = args.get("key", "").strip()
|
||||
if not key:
|
||||
return ToolResult.fail("Error: 'key' is required for press action")
|
||||
result = self._get_service().press(key)
|
||||
if "error" in result:
|
||||
return ToolResult.fail(result["error"])
|
||||
return ToolResult.success(f"Pressed key: {key}")
|
||||
|
||||
def _do_evaluate(self, args: Dict[str, Any]) -> ToolResult:
|
||||
script = args.get("script", "").strip()
|
||||
if not script:
|
||||
return ToolResult.fail("Error: 'script' is required for evaluate action")
|
||||
result = self._get_service().evaluate(script)
|
||||
if "error" in result:
|
||||
return ToolResult.fail(result["error"])
|
||||
val = result.get("result")
|
||||
if isinstance(val, (dict, list)):
|
||||
return ToolResult.success(json.dumps(val, ensure_ascii=False, indent=2))
|
||||
return ToolResult.success(str(val) if val is not None else "(no return value)")
|
||||
|
||||
# Action dispatch table
|
||||
_ACTION_MAP = {
|
||||
"navigate": _do_navigate,
|
||||
"snapshot": _do_snapshot,
|
||||
"click": _do_click,
|
||||
"fill": _do_fill,
|
||||
"select": _do_select,
|
||||
"scroll": _do_scroll,
|
||||
"screenshot": _do_screenshot,
|
||||
"wait": _do_wait,
|
||||
"back": _do_back,
|
||||
"forward": _do_forward,
|
||||
"get_text": _do_get_text,
|
||||
"press": _do_press,
|
||||
"evaluate": _do_evaluate,
|
||||
}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Lifecycle
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def copy(self):
|
||||
"""Share browser instance across tool copies (avoids re-launching)."""
|
||||
new_tool = BrowserTool(self.config)
|
||||
new_tool.model = self.model
|
||||
new_tool.context = getattr(self, "context", None)
|
||||
new_tool.cwd = self.cwd
|
||||
new_tool._service = self._service
|
||||
return new_tool
|
||||
|
||||
def close(self):
|
||||
"""Release browser resources."""
|
||||
if self._service:
|
||||
self._service.close()
|
||||
self._service = None
|
||||
BrowserTool._shared_service = None
|
||||
logger.info("[Browser] BrowserTool closed")
|
||||
@@ -1,18 +0,0 @@
|
||||
def copy(self):
|
||||
"""
|
||||
Special copy method for browser tool to avoid recreating browser instance.
|
||||
|
||||
:return: A new instance with shared browser reference but unique model
|
||||
"""
|
||||
new_tool = self.__class__()
|
||||
|
||||
# Copy essential attributes
|
||||
new_tool.model = self.model
|
||||
new_tool.context = getattr(self, 'context', None)
|
||||
new_tool.config = getattr(self, 'config', None)
|
||||
|
||||
# Share the browser instance instead of creating a new one
|
||||
if hasattr(self, 'browser'):
|
||||
new_tool.browser = self.browser
|
||||
|
||||
return new_tool
|
||||
@@ -7,6 +7,7 @@ import os
|
||||
from typing import Dict, Any
|
||||
|
||||
from agent.tools.base_tool import BaseTool, ToolResult
|
||||
from common.utils import expand_path
|
||||
from agent.tools.utils.diff import (
|
||||
strip_bom,
|
||||
detect_line_ending,
|
||||
@@ -178,7 +179,7 @@ class Edit(BaseTool):
|
||||
:return: Absolute path
|
||||
"""
|
||||
# Expand ~ to user home directory
|
||||
path = os.path.expanduser(path)
|
||||
path = expand_path(path)
|
||||
if os.path.isabs(path):
|
||||
return path
|
||||
return os.path.abspath(os.path.join(self.cwd, path))
|
||||
|
||||
@@ -9,6 +9,7 @@ from pathlib import Path
|
||||
|
||||
from agent.tools.base_tool import BaseTool, ToolResult
|
||||
from common.log import logger
|
||||
from common.utils import expand_path
|
||||
|
||||
|
||||
# API Key 知识库:常见的环境变量及其描述
|
||||
@@ -66,7 +67,7 @@ class EnvConfig(BaseTool):
|
||||
def __init__(self, config: dict = None):
|
||||
self.config = config or {}
|
||||
# Store env config in ~/.cow directory (outside workspace for security)
|
||||
self.env_dir = os.path.expanduser("~/.cow")
|
||||
self.env_dir = expand_path("~/.cow")
|
||||
self.env_path = os.path.join(self.env_dir, '.env')
|
||||
self.agent_bridge = self.config.get("agent_bridge") # Reference to AgentBridge for hot reload
|
||||
# Don't create .env file in __init__ to avoid issues during tool discovery
|
||||
@@ -201,7 +202,8 @@ class EnvConfig(BaseTool):
|
||||
"key": key,
|
||||
"value": self._mask_value(value),
|
||||
"description": description,
|
||||
"exists": True
|
||||
"exists": True,
|
||||
"note": f"Value is masked for security. In bash, use ${key} directly — it is auto-injected."
|
||||
})
|
||||
else:
|
||||
return ToolResult.success({
|
||||
|
||||
@@ -7,6 +7,7 @@ from typing import Dict, Any
|
||||
|
||||
from agent.tools.base_tool import BaseTool, ToolResult
|
||||
from agent.tools.utils.truncate import truncate_head, format_size, DEFAULT_MAX_BYTES
|
||||
from common.utils import expand_path
|
||||
|
||||
|
||||
DEFAULT_LIMIT = 500
|
||||
@@ -51,7 +52,7 @@ class Ls(BaseTool):
|
||||
absolute_path = self._resolve_path(path)
|
||||
|
||||
# Security check: Prevent accessing sensitive config directory
|
||||
env_config_dir = os.path.expanduser("~/.cow")
|
||||
env_config_dir = expand_path("~/.cow")
|
||||
if os.path.abspath(absolute_path) == os.path.abspath(env_config_dir):
|
||||
return ToolResult.fail(
|
||||
"Error: Access denied. API keys and credentials must be accessed through the env_config tool only."
|
||||
@@ -93,7 +94,7 @@ class Ls(BaseTool):
|
||||
results.append(entry + '/')
|
||||
else:
|
||||
results.append(entry)
|
||||
except:
|
||||
except Exception:
|
||||
# Skip entries we can't stat
|
||||
continue
|
||||
|
||||
@@ -133,7 +134,7 @@ class Ls(BaseTool):
|
||||
def _resolve_path(self, path: str) -> str:
|
||||
"""Resolve path to absolute path"""
|
||||
# Expand ~ to user home directory
|
||||
path = os.path.expanduser(path)
|
||||
path = expand_path(path)
|
||||
if os.path.isabs(path):
|
||||
return path
|
||||
return os.path.abspath(os.path.join(self.cwd, path))
|
||||
|
||||
4
agent/tools/mcp/__init__.py
Normal file
4
agent/tools/mcp/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from agent.tools.mcp.mcp_client import McpClient, McpClientRegistry
|
||||
from agent.tools.mcp.mcp_tool import McpTool
|
||||
|
||||
__all__ = ["McpClient", "McpClientRegistry", "McpTool"]
|
||||
528
agent/tools/mcp/mcp_client.py
Normal file
528
agent/tools/mcp/mcp_client.py
Normal file
@@ -0,0 +1,528 @@
|
||||
"""
|
||||
MCP (Model Context Protocol) client module.
|
||||
|
||||
Implements JSON-RPC 2.0 over stdio, SSE and Streamable HTTP transports
|
||||
without any external MCP SDK dependency.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import select
|
||||
import subprocess
|
||||
import threading
|
||||
import urllib.request
|
||||
import urllib.error
|
||||
from typing import Optional
|
||||
|
||||
from common.log import logger
|
||||
|
||||
|
||||
# Aliases accepted for the Streamable HTTP transport type
|
||||
_STREAMABLE_HTTP_ALIASES = {"streamable-http", "streamable_http", "streamablehttp", "http"}
|
||||
|
||||
|
||||
class McpClient:
|
||||
"""Single MCP Server client supporting stdio, SSE and Streamable HTTP transports."""
|
||||
|
||||
def __init__(self, config: dict):
|
||||
"""
|
||||
config examples:
|
||||
stdio: {"name": "filesystem", "type": "stdio", "command": "npx", "args": [...]}
|
||||
SSE: {"name": "my-api", "type": "sse", "url": "http://localhost:8000/sse"}
|
||||
streamable-http: {"name": "pubmed", "type": "streamable-http", "url": "https://x/mcp"}
|
||||
"""
|
||||
self.config = config
|
||||
self.name: str = config.get("name", "unknown")
|
||||
raw_transport: str = config.get("type", "stdio")
|
||||
# Normalize streamable-http aliases to a single internal key
|
||||
self.transport: str = (
|
||||
"streamable-http"
|
||||
if raw_transport.lower() in _STREAMABLE_HTTP_ALIASES
|
||||
else raw_transport
|
||||
)
|
||||
|
||||
# stdio state
|
||||
self._proc: Optional[subprocess.Popen] = None
|
||||
|
||||
# SSE state
|
||||
self._sse_url: Optional[str] = None
|
||||
self._post_url: Optional[str] = None # endpoint for sending messages (resolved from SSE)
|
||||
|
||||
# Streamable HTTP state
|
||||
self._http_url: Optional[str] = None
|
||||
self._http_headers: dict = {} # extra headers from user config (e.g. Authorization)
|
||||
self._http_session_id: Optional[str] = None # Mcp-Session-Id assigned by the server
|
||||
|
||||
# Shared state
|
||||
self._next_id = 1
|
||||
self._id_lock = threading.Lock()
|
||||
self._call_lock = threading.Lock()
|
||||
self._initialized = False
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Public interface
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def initialize(self) -> bool:
|
||||
"""Connect and perform the MCP handshake. Returns True on success."""
|
||||
try:
|
||||
if self.transport == "stdio":
|
||||
return self._init_stdio()
|
||||
elif self.transport == "sse":
|
||||
return self._init_sse()
|
||||
elif self.transport == "streamable-http":
|
||||
return self._init_streamable_http()
|
||||
else:
|
||||
logger.warning(f"[MCP:{self.name}] Unknown transport type: {self.transport!r}")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.warning(f"[MCP:{self.name}] Initialization failed: {e}")
|
||||
return False
|
||||
|
||||
def list_tools(self) -> list:
|
||||
"""Return the tool list from this server.
|
||||
|
||||
Each item is a dict: {"name": str, "description": str, "inputSchema": dict}
|
||||
"""
|
||||
try:
|
||||
resp = self._send_request("tools/list", {})
|
||||
tools = resp.get("result", {}).get("tools", [])
|
||||
return [
|
||||
{
|
||||
"name": t.get("name", ""),
|
||||
"description": t.get("description", ""),
|
||||
"inputSchema": t.get("inputSchema", {}),
|
||||
}
|
||||
for t in tools
|
||||
]
|
||||
except Exception as e:
|
||||
logger.warning(f"[MCP:{self.name}] list_tools failed: {e}")
|
||||
return []
|
||||
|
||||
def call_tool(self, name: str, arguments: dict) -> str:
|
||||
"""Call a tool and return the result as a string."""
|
||||
try:
|
||||
resp = self._send_request("tools/call", {"name": name, "arguments": arguments})
|
||||
content = resp.get("result", {}).get("content", [])
|
||||
parts = [item.get("text", "") for item in content if item.get("type") == "text"]
|
||||
return "\n".join(parts)
|
||||
except Exception as e:
|
||||
logger.warning(f"[MCP:{self.name}] call_tool({name}) failed: {e}")
|
||||
return f"Error: {e}"
|
||||
|
||||
def shutdown(self):
|
||||
"""Close the connection / terminate the child process."""
|
||||
if self._proc is not None:
|
||||
try:
|
||||
self._proc.stdin.close()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
self._proc.terminate()
|
||||
self._proc.wait(timeout=5)
|
||||
except Exception:
|
||||
try:
|
||||
self._proc.kill()
|
||||
except Exception:
|
||||
pass
|
||||
self._proc = None
|
||||
logger.debug(f"[MCP:{self.name}] stdio process terminated")
|
||||
|
||||
# Best-effort streamable-http session termination
|
||||
if self.transport == "streamable-http" and self._http_session_id and self._http_url:
|
||||
try:
|
||||
req = urllib.request.Request(
|
||||
self._http_url,
|
||||
method="DELETE",
|
||||
headers={"Mcp-Session-Id": self._http_session_id, **self._http_headers},
|
||||
)
|
||||
with urllib.request.urlopen(req, timeout=5):
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
self._http_session_id = None
|
||||
|
||||
self._initialized = False
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# stdio transport
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _init_stdio(self) -> bool:
|
||||
command = self.config.get("command")
|
||||
if not command:
|
||||
logger.warning(f"[MCP:{self.name}] stdio config missing 'command'")
|
||||
return False
|
||||
|
||||
args = self.config.get("args", [])
|
||||
extra_env = self.config.get("env", None)
|
||||
env = {**os.environ, **extra_env} if extra_env else None
|
||||
|
||||
self._proc = subprocess.Popen(
|
||||
[command] + list(args),
|
||||
stdin=subprocess.PIPE,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
text=True,
|
||||
encoding="utf-8",
|
||||
env=env,
|
||||
)
|
||||
logger.debug(f"[MCP:{self.name}] stdio process started (pid={self._proc.pid})")
|
||||
|
||||
threading.Thread(
|
||||
target=self._drain_stderr, daemon=True, name=f"mcp-stderr-{self.name}"
|
||||
).start()
|
||||
|
||||
return self._handshake()
|
||||
|
||||
def _drain_stderr(self):
|
||||
for line in self._proc.stderr:
|
||||
line = line.strip()
|
||||
if line:
|
||||
logger.debug(f"[MCP:{self.name}] stderr: {line}")
|
||||
|
||||
def _readline_with_timeout(self, timeout: int = 30) -> str:
|
||||
"""Read one line from stdio stdout with a hard timeout."""
|
||||
ready, _, _ = select.select([self._proc.stdout], [], [], timeout)
|
||||
if not ready:
|
||||
raise TimeoutError(f"[MCP:{self.name}] stdio read timed out after {timeout}s")
|
||||
return self._proc.stdout.readline()
|
||||
|
||||
def _stdio_send(self, message: dict) -> dict:
|
||||
"""Send a JSON-RPC message over stdio and read the response."""
|
||||
raw = json.dumps(message) + "\n"
|
||||
self._proc.stdin.write(raw)
|
||||
self._proc.stdin.flush()
|
||||
|
||||
while True:
|
||||
line = self._readline_with_timeout()
|
||||
if not line:
|
||||
raise IOError(f"[MCP:{self.name}] stdio process closed unexpectedly")
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
data = json.loads(line)
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
if "id" not in data:
|
||||
logger.debug(f"[MCP:{self.name}] notification skipped: {data.get('method', '?')}")
|
||||
continue
|
||||
return data
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# SSE transport
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _init_sse(self) -> bool:
|
||||
url = self.config.get("url")
|
||||
if not url:
|
||||
logger.warning(f"[MCP:{self.name}] SSE config missing 'url'")
|
||||
return False
|
||||
|
||||
self._sse_url = url
|
||||
|
||||
# Read the first SSE event to discover the POST endpoint
|
||||
try:
|
||||
self._post_url = self._sse_discover_endpoint()
|
||||
except Exception as e:
|
||||
logger.warning(f"[MCP:{self.name}] SSE endpoint discovery failed: {e}")
|
||||
return False
|
||||
|
||||
return self._handshake()
|
||||
|
||||
def _sse_discover_endpoint(self) -> str:
|
||||
"""Open SSE stream and read the 'endpoint' event to learn the POST URL."""
|
||||
req = urllib.request.Request(
|
||||
self._sse_url,
|
||||
headers={"Accept": "text/event-stream"},
|
||||
)
|
||||
with urllib.request.urlopen(req, timeout=10) as resp:
|
||||
for raw_line in resp:
|
||||
line = raw_line.decode("utf-8").rstrip("\n\r")
|
||||
if line.startswith("data:"):
|
||||
data = line[len("data:"):].strip()
|
||||
# Some servers send JSON with a "uri" or plain path
|
||||
if data.startswith("{"):
|
||||
parsed = json.loads(data)
|
||||
return parsed.get("uri") or parsed.get("url") or parsed.get("endpoint")
|
||||
# Plain relative or absolute URL
|
||||
if data.startswith("http"):
|
||||
return data
|
||||
# Relative path: resolve against SSE base
|
||||
from urllib.parse import urljoin
|
||||
return urljoin(self._sse_url, data)
|
||||
raise ValueError(f"[MCP:{self.name}] No endpoint event received from SSE stream")
|
||||
|
||||
def _sse_send(self, message: dict) -> dict:
|
||||
"""POST a JSON-RPC message to the server and return the response."""
|
||||
body = json.dumps(message).encode("utf-8")
|
||||
req = urllib.request.Request(
|
||||
self._post_url,
|
||||
data=body,
|
||||
method="POST",
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
with urllib.request.urlopen(req, timeout=30) as resp:
|
||||
raw = resp.read().decode("utf-8")
|
||||
return json.loads(raw)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Streamable HTTP transport (MCP spec 2025-03-26)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _init_streamable_http(self) -> bool:
|
||||
url = self.config.get("url")
|
||||
if not url:
|
||||
logger.warning(f"[MCP:{self.name}] streamable-http config missing 'url'")
|
||||
return False
|
||||
|
||||
self._http_url = url
|
||||
# Allow user-provided headers (e.g. {"Authorization": "Bearer xxx"})
|
||||
extra_headers = self.config.get("headers") or {}
|
||||
if isinstance(extra_headers, dict):
|
||||
self._http_headers = {str(k): str(v) for k, v in extra_headers.items()}
|
||||
|
||||
return self._handshake()
|
||||
|
||||
def _streamable_http_send(self, message: dict) -> dict:
|
||||
"""POST a JSON-RPC request and return the response (JSON or SSE-wrapped)."""
|
||||
return self._streamable_http_post(message, expect_response=True)
|
||||
|
||||
def _streamable_http_post(self, message: dict, expect_response: bool) -> dict:
|
||||
"""
|
||||
POST a JSON-RPC message over Streamable HTTP.
|
||||
|
||||
Per the spec, the response Content-Type can be either:
|
||||
- application/json -> single JSON-RPC response in body
|
||||
- text/event-stream -> SSE stream; we read until we get a matching response
|
||||
"""
|
||||
body = json.dumps(message).encode("utf-8")
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json, text/event-stream",
|
||||
}
|
||||
if self._http_session_id:
|
||||
headers["Mcp-Session-Id"] = self._http_session_id
|
||||
headers.update(self._http_headers)
|
||||
|
||||
req = urllib.request.Request(
|
||||
self._http_url,
|
||||
data=body,
|
||||
method="POST",
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
try:
|
||||
resp = urllib.request.urlopen(req, timeout=30)
|
||||
except urllib.error.HTTPError as e:
|
||||
# Surface the server-provided error body for easier debugging
|
||||
detail = ""
|
||||
try:
|
||||
detail = e.read().decode("utf-8", errors="ignore")
|
||||
except Exception:
|
||||
pass
|
||||
raise IOError(
|
||||
f"[MCP:{self.name}] streamable-http HTTP {e.code}: {detail[:200]}"
|
||||
)
|
||||
|
||||
with resp:
|
||||
# Capture session id assigned by the server (if any)
|
||||
session_id = resp.headers.get("Mcp-Session-Id")
|
||||
if session_id and not self._http_session_id:
|
||||
self._http_session_id = session_id
|
||||
|
||||
status = resp.status if hasattr(resp, "status") else resp.getcode()
|
||||
|
||||
# Notifications: server may reply with 202 Accepted and no body
|
||||
if not expect_response or status == 202:
|
||||
try:
|
||||
resp.read()
|
||||
except Exception:
|
||||
pass
|
||||
return {}
|
||||
|
||||
content_type = (resp.headers.get("Content-Type") or "").lower()
|
||||
expected_id = message.get("id")
|
||||
|
||||
if "text/event-stream" in content_type:
|
||||
return self._read_sse_response(resp, expected_id)
|
||||
|
||||
raw = resp.read().decode("utf-8")
|
||||
if not raw:
|
||||
return {}
|
||||
return json.loads(raw)
|
||||
|
||||
def _read_sse_response(self, resp, expected_id) -> dict:
|
||||
"""Read an SSE stream and return the first JSON-RPC response with matching id."""
|
||||
data_buf: list = []
|
||||
for raw_line in resp:
|
||||
line = raw_line.decode("utf-8").rstrip("\n\r")
|
||||
if line == "":
|
||||
# End of an SSE event, attempt to parse accumulated data
|
||||
if data_buf:
|
||||
payload = "\n".join(data_buf)
|
||||
data_buf = []
|
||||
try:
|
||||
msg = json.loads(payload)
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
# Skip notifications / mismatched ids
|
||||
if "id" not in msg:
|
||||
continue
|
||||
if expected_id is None or msg.get("id") == expected_id:
|
||||
return msg
|
||||
continue
|
||||
if line.startswith(":"):
|
||||
continue # SSE comment / keepalive
|
||||
if line.startswith("data:"):
|
||||
data_buf.append(line[len("data:"):].lstrip())
|
||||
# Ignore 'event:' / 'id:' lines; we only care about JSON-RPC payloads
|
||||
|
||||
raise IOError(f"[MCP:{self.name}] streamable-http SSE stream closed before response")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Common JSON-RPC helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _next_request_id(self) -> int:
|
||||
with self._id_lock:
|
||||
rid = self._next_id
|
||||
self._next_id += 1
|
||||
return rid
|
||||
|
||||
def _build_request(self, method: str, params: dict) -> dict:
|
||||
return {
|
||||
"jsonrpc": "2.0",
|
||||
"id": self._next_request_id(),
|
||||
"method": method,
|
||||
"params": params,
|
||||
}
|
||||
|
||||
def _build_notification(self, method: str, params: dict) -> dict:
|
||||
return {"jsonrpc": "2.0", "method": method, "params": params}
|
||||
|
||||
def _send_request(self, method: str, params: dict) -> dict:
|
||||
"""Send a request and return the full response dict."""
|
||||
if not self._initialized and method != "initialize":
|
||||
raise RuntimeError(f"[MCP:{self.name}] Client not initialized")
|
||||
|
||||
message = self._build_request(method, params)
|
||||
|
||||
with self._call_lock:
|
||||
if self.transport == "stdio":
|
||||
return self._stdio_send(message)
|
||||
elif self.transport == "sse":
|
||||
return self._sse_send(message)
|
||||
elif self.transport == "streamable-http":
|
||||
return self._streamable_http_send(message)
|
||||
else:
|
||||
raise ValueError(f"[MCP:{self.name}] Unsupported transport: {self.transport}")
|
||||
|
||||
def _send_notification(self, method: str, params: dict):
|
||||
"""Fire-and-forget notification (no response expected)."""
|
||||
notification = self._build_notification(method, params)
|
||||
raw = json.dumps(notification) + "\n"
|
||||
|
||||
if self.transport == "stdio":
|
||||
self._proc.stdin.write(raw)
|
||||
self._proc.stdin.flush()
|
||||
elif self.transport == "sse":
|
||||
body = raw.encode("utf-8")
|
||||
req = urllib.request.Request(
|
||||
self._post_url,
|
||||
data=body,
|
||||
method="POST",
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
try:
|
||||
with urllib.request.urlopen(req, timeout=10):
|
||||
pass
|
||||
except Exception:
|
||||
pass # notifications are fire-and-forget
|
||||
elif self.transport == "streamable-http":
|
||||
try:
|
||||
self._streamable_http_post(notification, expect_response=False)
|
||||
except Exception:
|
||||
pass # notifications are fire-and-forget
|
||||
|
||||
def _handshake(self) -> bool:
|
||||
"""Perform the MCP initialize / notifications/initialized handshake."""
|
||||
init_params = {
|
||||
"protocolVersion": "2024-11-05",
|
||||
"capabilities": {},
|
||||
"clientInfo": {"name": "CowAgent", "version": "1.0"},
|
||||
}
|
||||
# Temporarily mark as initialized so _send_request doesn't block
|
||||
self._initialized = True
|
||||
try:
|
||||
resp = self._send_request("initialize", init_params)
|
||||
except Exception as e:
|
||||
self._initialized = False
|
||||
logger.warning(f"[MCP:{self.name}] Handshake initialize failed: {e}")
|
||||
return False
|
||||
|
||||
if "error" in resp:
|
||||
self._initialized = False
|
||||
logger.warning(f"[MCP:{self.name}] Handshake error: {resp['error']}")
|
||||
return False
|
||||
|
||||
self._send_notification("notifications/initialized", {})
|
||||
logger.debug(f"[MCP:{self.name}] Handshake complete")
|
||||
return True
|
||||
|
||||
|
||||
class McpClientRegistry:
|
||||
"""Global singleton managing the lifecycle of all MCP Server clients."""
|
||||
|
||||
_instance = None
|
||||
_instance_lock = threading.Lock()
|
||||
|
||||
def __new__(cls):
|
||||
with cls._instance_lock:
|
||||
if cls._instance is None:
|
||||
obj = super().__new__(cls)
|
||||
obj._clients: dict[str, McpClient] = {}
|
||||
obj._registry_lock = threading.Lock()
|
||||
cls._instance = obj
|
||||
return cls._instance
|
||||
|
||||
def start_all(self, configs: list) -> None:
|
||||
"""Initialize McpClient for each config entry; skip failures with a warning."""
|
||||
if not configs:
|
||||
return
|
||||
|
||||
for cfg in configs:
|
||||
name = cfg.get("name", "<unnamed>")
|
||||
client = McpClient(cfg)
|
||||
ok = client.initialize()
|
||||
if ok:
|
||||
with self._registry_lock:
|
||||
self._clients[name] = client
|
||||
logger.info(f"[MCP] Server '{name}' initialized successfully")
|
||||
else:
|
||||
logger.warning(f"[MCP] Server '{name}' failed to initialize — skipping")
|
||||
|
||||
def get(self, server_name: str) -> Optional[McpClient]:
|
||||
"""Return the initialized client for server_name, or None."""
|
||||
with self._registry_lock:
|
||||
return self._clients.get(server_name)
|
||||
|
||||
def all_clients(self) -> dict:
|
||||
"""Return a copy of the {name: McpClient} mapping."""
|
||||
with self._registry_lock:
|
||||
return dict(self._clients)
|
||||
|
||||
def shutdown_all(self) -> None:
|
||||
"""Shut down all managed clients."""
|
||||
with self._registry_lock:
|
||||
clients = list(self._clients.values())
|
||||
self._clients.clear()
|
||||
|
||||
for client in clients:
|
||||
try:
|
||||
client.shutdown()
|
||||
except Exception as e:
|
||||
logger.warning(f"[MCP] Error shutting down '{client.name}': {e}")
|
||||
|
||||
logger.info("[MCP] All servers shut down")
|
||||
31
agent/tools/mcp/mcp_tool.py
Normal file
31
agent/tools/mcp/mcp_tool.py
Normal file
@@ -0,0 +1,31 @@
|
||||
from agent.tools.base_tool import BaseTool, ToolResult
|
||||
from common.log import logger
|
||||
|
||||
|
||||
class McpTool(BaseTool):
|
||||
"""
|
||||
将单个 MCP 工具包装为 BaseTool。
|
||||
一个 MCP Server 可以提供多个工具,每个工具对应一个 McpTool 实例。
|
||||
"""
|
||||
|
||||
def __init__(self, client, tool_schema: dict, server_name: str):
|
||||
"""
|
||||
:param client: 该工具所属的 McpClient 实例
|
||||
:param tool_schema: MCP 返回的工具描述,格式:
|
||||
{"name": str, "description": str, "inputSchema": dict}
|
||||
:param server_name: Server 名称,用于日志
|
||||
"""
|
||||
self.client = client
|
||||
self.server_name = server_name
|
||||
self.name = tool_schema["name"]
|
||||
self.description = tool_schema.get("description", "")
|
||||
self.params = tool_schema.get("inputSchema", {})
|
||||
|
||||
def execute(self, params: dict) -> ToolResult:
|
||||
logger.info(f"[McpTool] server={self.server_name} tool={self.name} params={params}")
|
||||
try:
|
||||
result = self.client.call_tool(self.name, params)
|
||||
return ToolResult.success(result)
|
||||
except Exception as e:
|
||||
logger.error(f"[McpTool] server={self.server_name} tool={self.name} error: {e}")
|
||||
return ToolResult.fail(str(e))
|
||||
@@ -44,6 +44,19 @@ class MemoryGetTool(BaseTool):
|
||||
"""
|
||||
super().__init__()
|
||||
self.memory_manager = memory_manager
|
||||
|
||||
from config import conf
|
||||
if conf().get("knowledge", True):
|
||||
self.description = (
|
||||
"Read specific content from memory or knowledge files. "
|
||||
"Use this to get full context from a memory file, knowledge page, or specific line range."
|
||||
)
|
||||
self.params = {**self.params}
|
||||
self.params["properties"] = {**self.params["properties"]}
|
||||
self.params["properties"]["path"] = {
|
||||
"type": "string",
|
||||
"description": "Relative path to the memory or knowledge file (e.g. 'MEMORY.md', 'memory/2026-01-01.md', 'knowledge/concepts/moe.md')"
|
||||
}
|
||||
|
||||
def execute(self, args: dict):
|
||||
"""
|
||||
@@ -68,16 +81,20 @@ class MemoryGetTool(BaseTool):
|
||||
workspace_dir = self.memory_manager.config.get_workspace()
|
||||
|
||||
# Auto-prepend memory/ if not present and not absolute path
|
||||
# Exception: MEMORY.md is in the root directory
|
||||
if not path.startswith('memory/') and not path.startswith('/') and path != 'MEMORY.md':
|
||||
# Exceptions: MEMORY.md in root, knowledge/ files at workspace root
|
||||
if not path.startswith('memory/') and not path.startswith('knowledge/') and not path.startswith('/') and path != 'MEMORY.md':
|
||||
path = f'memory/{path}'
|
||||
|
||||
file_path = workspace_dir / path
|
||||
file_path = (workspace_dir / path).resolve()
|
||||
workspace_resolved = workspace_dir.resolve()
|
||||
|
||||
if not str(file_path).startswith(str(workspace_resolved) + '/') and file_path != workspace_resolved:
|
||||
return ToolResult.fail(f"Error: Access denied: path outside workspace")
|
||||
|
||||
if not file_path.exists():
|
||||
return ToolResult.fail(f"Error: File not found: {path}")
|
||||
|
||||
content = file_path.read_text()
|
||||
content = file_path.read_text(encoding='utf-8')
|
||||
lines = content.split('\n')
|
||||
|
||||
# Handle line range
|
||||
|
||||
@@ -48,6 +48,13 @@ class MemorySearchTool(BaseTool):
|
||||
super().__init__()
|
||||
self.memory_manager = memory_manager
|
||||
self.user_id = user_id
|
||||
|
||||
from config import conf
|
||||
if conf().get("knowledge", True):
|
||||
self.description = (
|
||||
"Search agent's long-term memory and knowledge base using semantic and keyword search. "
|
||||
"Use this to recall past conversations, preferences, and knowledge pages."
|
||||
)
|
||||
|
||||
def execute(self, args: dict):
|
||||
"""
|
||||
|
||||
@@ -9,6 +9,7 @@ from pathlib import Path
|
||||
|
||||
from agent.tools.base_tool import BaseTool, ToolResult
|
||||
from agent.tools.utils.truncate import truncate_head, format_size, DEFAULT_MAX_LINES, DEFAULT_MAX_BYTES
|
||||
from common.utils import expand_path
|
||||
|
||||
|
||||
class Read(BaseTool):
|
||||
@@ -47,7 +48,8 @@ class Read(BaseTool):
|
||||
self.binary_extensions = {'.exe', '.dll', '.so', '.dylib', '.bin', '.dat', '.db', '.sqlite'}
|
||||
self.archive_extensions = {'.zip', '.tar', '.gz', '.rar', '.7z', '.bz2', '.xz'}
|
||||
self.pdf_extensions = {'.pdf'}
|
||||
|
||||
self.office_extensions = {'.doc', '.docx', '.xls', '.xlsx', '.ppt', '.pptx'}
|
||||
|
||||
# Readable text formats (will be read with truncation)
|
||||
self.text_extensions = {
|
||||
'.txt', '.md', '.markdown', '.rst', '.log', '.csv', '.tsv', '.json', '.xml', '.yaml', '.yml',
|
||||
@@ -56,7 +58,6 @@ class Read(BaseTool):
|
||||
'.sh', '.bash', '.zsh', '.fish', '.ps1', '.bat', '.cmd',
|
||||
'.sql', '.r', '.m', '.swift', '.kt', '.scala', '.clj', '.erl', '.ex',
|
||||
'.dockerfile', '.makefile', '.cmake', '.gradle', '.properties', '.ini', '.conf', '.cfg',
|
||||
'.doc', '.docx', '.xls', '.xlsx', '.ppt', '.pptx' # Office documents
|
||||
}
|
||||
|
||||
def execute(self, args: Dict[str, Any]) -> ToolResult:
|
||||
@@ -66,10 +67,12 @@ class Read(BaseTool):
|
||||
:param args: Contains file path and optional offset/limit parameters
|
||||
:return: File content or error message
|
||||
"""
|
||||
path = args.get("path", "").strip()
|
||||
# Support 'location' as alias for 'path' (LLM may use it from skill listing)
|
||||
path = args.get("path", "") or args.get("location", "")
|
||||
path = path.strip() if isinstance(path, str) else ""
|
||||
offset = args.get("offset")
|
||||
limit = args.get("limit")
|
||||
|
||||
|
||||
if not path:
|
||||
return ToolResult.fail("Error: path parameter is required")
|
||||
|
||||
@@ -77,7 +80,7 @@ class Read(BaseTool):
|
||||
absolute_path = self._resolve_path(path)
|
||||
|
||||
# Security check: Prevent reading sensitive config files
|
||||
env_config_path = os.path.expanduser("~/.cow/.env")
|
||||
env_config_path = expand_path("~/.cow/.env")
|
||||
if os.path.abspath(absolute_path) == os.path.abspath(env_config_path):
|
||||
return ToolResult.fail(
|
||||
"Error: Access denied. API keys and credentials must be accessed through the env_config tool only."
|
||||
@@ -117,7 +120,11 @@ class Read(BaseTool):
|
||||
# Check if PDF
|
||||
if file_ext in self.pdf_extensions:
|
||||
return self._read_pdf(absolute_path, path, offset, limit)
|
||||
|
||||
|
||||
# Check if Office document (.docx, .xlsx, .pptx, etc.)
|
||||
if file_ext in self.office_extensions:
|
||||
return self._read_office(absolute_path, path, file_ext, offset, limit)
|
||||
|
||||
# Read text file (with truncation for large files)
|
||||
return self._read_text(absolute_path, path, offset, limit)
|
||||
|
||||
@@ -129,7 +136,7 @@ class Read(BaseTool):
|
||||
:return: Absolute path
|
||||
"""
|
||||
# Expand ~ to user home directory
|
||||
path = os.path.expanduser(path)
|
||||
path = expand_path(path)
|
||||
if os.path.isabs(path):
|
||||
return path
|
||||
return os.path.abspath(os.path.join(self.cwd, path))
|
||||
@@ -237,17 +244,12 @@ class Read(BaseTool):
|
||||
"message": f"文件过大 ({format_size(file_size)} > 50MB),无法读取内容。文件路径: {absolute_path}"
|
||||
})
|
||||
|
||||
# Read file
|
||||
with open(absolute_path, 'r', encoding='utf-8') as f:
|
||||
# Read file (utf-8-sig strips BOM automatically on Windows)
|
||||
# Note: Truncation is unified via truncate_head (DEFAULT_MAX_LINES / DEFAULT_MAX_BYTES)
|
||||
# so that offset/limit can paginate the entire file correctly.
|
||||
with open(absolute_path, 'r', encoding='utf-8-sig') as f:
|
||||
content = f.read()
|
||||
|
||||
# Truncate content if too long (20K characters max for model context)
|
||||
MAX_CONTENT_CHARS = 20 * 1024 # 20K characters
|
||||
content_truncated = False
|
||||
if len(content) > MAX_CONTENT_CHARS:
|
||||
content = content[:MAX_CONTENT_CHARS]
|
||||
content_truncated = True
|
||||
|
||||
|
||||
all_lines = content.split('\n')
|
||||
total_file_lines = len(all_lines)
|
||||
|
||||
@@ -283,11 +285,7 @@ class Read(BaseTool):
|
||||
|
||||
output_text = ""
|
||||
details = {}
|
||||
|
||||
# Add truncation warning if content was truncated
|
||||
if content_truncated:
|
||||
output_text = f"[文件内容已截断到前 {format_size(MAX_CONTENT_CHARS)},完整文件大小: {format_size(file_size)}]\n\n"
|
||||
|
||||
|
||||
if truncation.first_line_exceeds_limit:
|
||||
# First line exceeds 30KB limit
|
||||
first_line_size = format_size(len(all_lines[start_line].encode('utf-8')))
|
||||
@@ -334,6 +332,116 @@ class Read(BaseTool):
|
||||
except Exception as e:
|
||||
return ToolResult.fail(f"Error reading file: {str(e)}")
|
||||
|
||||
def _read_office(self, absolute_path: str, display_path: str, file_ext: str,
|
||||
offset: int = None, limit: int = None) -> ToolResult:
|
||||
"""Read Office documents (.docx, .xlsx, .pptx) using python-docx / openpyxl / python-pptx."""
|
||||
try:
|
||||
text = self._extract_office_text(absolute_path, file_ext)
|
||||
except ImportError as e:
|
||||
return ToolResult.fail(str(e))
|
||||
except Exception as e:
|
||||
return ToolResult.fail(f"Error reading Office document: {e}")
|
||||
|
||||
if not text or not text.strip():
|
||||
return ToolResult.success({
|
||||
"content": f"[Office file {Path(absolute_path).name}: no text content could be extracted]",
|
||||
})
|
||||
|
||||
all_lines = text.split('\n')
|
||||
total_lines = len(all_lines)
|
||||
|
||||
start_line = 0
|
||||
if offset is not None:
|
||||
if offset < 0:
|
||||
start_line = max(0, total_lines + offset)
|
||||
else:
|
||||
start_line = max(0, offset - 1)
|
||||
if start_line >= total_lines:
|
||||
return ToolResult.fail(
|
||||
f"Error: Offset {offset} is beyond end of content ({total_lines} lines total)"
|
||||
)
|
||||
|
||||
selected_content = text
|
||||
user_limited_lines = None
|
||||
if limit is not None:
|
||||
end_line = min(start_line + limit, total_lines)
|
||||
selected_content = '\n'.join(all_lines[start_line:end_line])
|
||||
user_limited_lines = end_line - start_line
|
||||
elif offset is not None:
|
||||
selected_content = '\n'.join(all_lines[start_line:])
|
||||
|
||||
truncation = truncate_head(selected_content)
|
||||
start_line_display = start_line + 1
|
||||
output_text = ""
|
||||
|
||||
if truncation.truncated:
|
||||
end_line_display = start_line_display + truncation.output_lines - 1
|
||||
next_offset = end_line_display + 1
|
||||
output_text = truncation.content
|
||||
output_text += f"\n\n[Showing lines {start_line_display}-{end_line_display} of {total_lines}. Use offset={next_offset} to continue.]"
|
||||
elif user_limited_lines is not None and start_line + user_limited_lines < total_lines:
|
||||
remaining = total_lines - (start_line + user_limited_lines)
|
||||
next_offset = start_line + user_limited_lines + 1
|
||||
output_text = truncation.content
|
||||
output_text += f"\n\n[{remaining} more lines in file. Use offset={next_offset} to continue.]"
|
||||
else:
|
||||
output_text = truncation.content
|
||||
|
||||
return ToolResult.success({
|
||||
"content": output_text,
|
||||
"total_lines": total_lines,
|
||||
"start_line": start_line_display,
|
||||
"output_lines": truncation.output_lines,
|
||||
})
|
||||
|
||||
@staticmethod
|
||||
def _extract_office_text(absolute_path: str, file_ext: str) -> str:
|
||||
"""Extract plain text from an Office document."""
|
||||
if file_ext in ('.docx', '.doc'):
|
||||
try:
|
||||
from docx import Document
|
||||
except ImportError:
|
||||
raise ImportError("Error: python-docx library not installed. Install with: pip install python-docx")
|
||||
doc = Document(absolute_path)
|
||||
paragraphs = [p.text for p in doc.paragraphs]
|
||||
for table in doc.tables:
|
||||
for row in table.rows:
|
||||
paragraphs.append('\t'.join(cell.text for cell in row.cells))
|
||||
return '\n'.join(paragraphs)
|
||||
|
||||
if file_ext in ('.xlsx', '.xls'):
|
||||
try:
|
||||
from openpyxl import load_workbook
|
||||
except ImportError:
|
||||
raise ImportError("Error: openpyxl library not installed. Install with: pip install openpyxl")
|
||||
wb = load_workbook(absolute_path, read_only=True, data_only=True)
|
||||
parts = []
|
||||
for ws in wb.worksheets:
|
||||
parts.append(f"--- Sheet: {ws.title} ---")
|
||||
for row in ws.iter_rows(values_only=True):
|
||||
parts.append('\t'.join(str(c) if c is not None else '' for c in row))
|
||||
wb.close()
|
||||
return '\n'.join(parts)
|
||||
|
||||
if file_ext in ('.pptx', '.ppt'):
|
||||
try:
|
||||
from pptx import Presentation
|
||||
except ImportError:
|
||||
raise ImportError("Error: python-pptx library not installed. Install with: pip install python-pptx")
|
||||
prs = Presentation(absolute_path)
|
||||
parts = []
|
||||
for i, slide in enumerate(prs.slides, 1):
|
||||
parts.append(f"--- Slide {i} ---")
|
||||
for shape in slide.shapes:
|
||||
if shape.has_text_frame:
|
||||
for para in shape.text_frame.paragraphs:
|
||||
text = para.text.strip()
|
||||
if text:
|
||||
parts.append(text)
|
||||
return '\n'.join(parts)
|
||||
|
||||
return ""
|
||||
|
||||
def _read_pdf(self, absolute_path: str, display_path: str, offset: int = None, limit: int = None) -> ToolResult:
|
||||
"""
|
||||
Read PDF file content
|
||||
|
||||
@@ -3,74 +3,137 @@ Integration module for scheduler with AgentBridge
|
||||
"""
|
||||
|
||||
import os
|
||||
import threading
|
||||
from typing import Optional
|
||||
from config import conf
|
||||
from common.log import logger
|
||||
from common.utils import expand_path
|
||||
from bridge.context import Context, ContextType
|
||||
from bridge.reply import Reply, ReplyType
|
||||
|
||||
# Global scheduler service instance
|
||||
_scheduler_service = None
|
||||
_task_store = None
|
||||
# Module-level lock to guard idempotent initialization across threads
|
||||
_init_lock = threading.Lock()
|
||||
|
||||
|
||||
def init_scheduler(agent_bridge) -> bool:
|
||||
"""
|
||||
Initialize scheduler service
|
||||
|
||||
Initialize scheduler service (idempotent).
|
||||
|
||||
Safe to call multiple times and from multiple threads: only the first
|
||||
successful call creates the singleton ``SchedulerService`` + background
|
||||
scanning thread. Subsequent calls return immediately.
|
||||
|
||||
Args:
|
||||
agent_bridge: AgentBridge instance
|
||||
|
||||
|
||||
Returns:
|
||||
True if initialized successfully
|
||||
True if scheduler is initialized (newly created or already running)
|
||||
"""
|
||||
global _scheduler_service, _task_store
|
||||
|
||||
try:
|
||||
from agent.tools.scheduler.task_store import TaskStore
|
||||
from agent.tools.scheduler.scheduler_service import SchedulerService
|
||||
|
||||
# Get workspace from config
|
||||
workspace_root = os.path.expanduser(conf().get("agent_workspace", "~/cow"))
|
||||
store_path = os.path.join(workspace_root, "scheduler", "tasks.json")
|
||||
|
||||
# Create task store
|
||||
_task_store = TaskStore(store_path)
|
||||
logger.debug(f"[Scheduler] Task store initialized: {store_path}")
|
||||
|
||||
# Create execute callback
|
||||
def execute_task_callback(task: dict):
|
||||
"""Callback to execute a scheduled task"""
|
||||
try:
|
||||
action = task.get("action", {})
|
||||
action_type = action.get("type")
|
||||
|
||||
if action_type == "agent_task":
|
||||
_execute_agent_task(task, agent_bridge)
|
||||
elif action_type == "send_message":
|
||||
# Legacy support for old tasks
|
||||
_execute_send_message(task, agent_bridge)
|
||||
elif action_type == "tool_call":
|
||||
# Legacy support for old tasks
|
||||
_execute_tool_call(task, agent_bridge)
|
||||
elif action_type == "skill_call":
|
||||
# Legacy support for old tasks
|
||||
_execute_skill_call(task, agent_bridge)
|
||||
else:
|
||||
logger.warning(f"[Scheduler] Unknown action type: {action_type}")
|
||||
except Exception as e:
|
||||
logger.error(f"[Scheduler] Error executing task {task.get('id')}: {e}")
|
||||
|
||||
# Create scheduler service
|
||||
_scheduler_service = SchedulerService(_task_store, execute_task_callback)
|
||||
_scheduler_service.start()
|
||||
|
||||
logger.debug("[Scheduler] Scheduler service initialized and started")
|
||||
|
||||
# Fast path: already initialized and running
|
||||
if _scheduler_service is not None and getattr(_scheduler_service, "running", False):
|
||||
return True
|
||||
|
||||
with _init_lock:
|
||||
# Re-check under the lock to avoid races where multiple threads
|
||||
# passed the fast-path check before any of them acquired the lock.
|
||||
if _scheduler_service is not None and getattr(_scheduler_service, "running", False):
|
||||
return True
|
||||
|
||||
try:
|
||||
from agent.tools.scheduler.task_store import TaskStore
|
||||
from agent.tools.scheduler.scheduler_service import SchedulerService
|
||||
|
||||
# Get workspace from config
|
||||
workspace_root = expand_path(conf().get("agent_workspace", "~/cow"))
|
||||
store_path = os.path.join(workspace_root, "scheduler", "tasks.json")
|
||||
|
||||
# Create task store (reuse if already created)
|
||||
if _task_store is None:
|
||||
_task_store = TaskStore(store_path)
|
||||
logger.debug(f"[Scheduler] Task store initialized: {store_path}")
|
||||
|
||||
# Create execute callback. Returns True on success, False to ask
|
||||
# the scheduler to retry on the next tick (e.g. channel not yet
|
||||
# ready right after process start).
|
||||
def execute_task_callback(task: dict):
|
||||
try:
|
||||
action = task.get("action", {})
|
||||
action_type = action.get("type")
|
||||
channel_type = action.get("channel_type", "unknown")
|
||||
receiver = action.get("receiver", "")
|
||||
|
||||
if not _is_channel_ready(channel_type, receiver):
|
||||
logger.warning(
|
||||
f"[Scheduler] Task {task.get('id')}: channel "
|
||||
f"'{channel_type}' not ready for receiver={receiver} "
|
||||
f"(no inbound msg cached since restart?); deferring"
|
||||
)
|
||||
return False
|
||||
|
||||
if action_type == "agent_task":
|
||||
return _execute_agent_task(task, agent_bridge)
|
||||
elif action_type == "send_message":
|
||||
return _execute_send_message(task, agent_bridge)
|
||||
elif action_type == "tool_call":
|
||||
return _execute_tool_call(task, agent_bridge)
|
||||
elif action_type == "skill_call":
|
||||
return _execute_skill_call(task, agent_bridge)
|
||||
else:
|
||||
logger.warning(f"[Scheduler] Unknown action type: {action_type}")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"[Scheduler] Error executing task {task.get('id')}: {e}")
|
||||
return False
|
||||
|
||||
# Create scheduler service
|
||||
_scheduler_service = SchedulerService(_task_store, execute_task_callback)
|
||||
_scheduler_service.start()
|
||||
|
||||
logger.info("[Scheduler] Service initialized and started")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Scheduler] Failed to initialize scheduler: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def _is_channel_ready(channel_type: str, receiver: str) -> bool:
|
||||
"""Best-effort readiness probe for outbound channels.
|
||||
|
||||
Returns False when we know the send will drop (e.g. weixin not yet
|
||||
logged in, web session has no polling queue), so the scheduler can
|
||||
defer instead of consuming the task. Unknown channels return True
|
||||
to preserve previous behaviour.
|
||||
"""
|
||||
if not channel_type or channel_type == "unknown":
|
||||
return True
|
||||
try:
|
||||
from channel.channel_factory import create_channel
|
||||
channel = create_channel(channel_type)
|
||||
if channel is None:
|
||||
return False
|
||||
|
||||
if channel_type == "weixin":
|
||||
tokens = getattr(channel, "_context_tokens", None)
|
||||
if not tokens or receiver not in tokens:
|
||||
return False
|
||||
return True
|
||||
|
||||
if channel_type == "web":
|
||||
queues = getattr(channel, "session_queues", None)
|
||||
if not queues or receiver not in queues:
|
||||
return False
|
||||
return True
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Scheduler] Failed to initialize scheduler: {e}")
|
||||
return False
|
||||
logger.warning(f"[Scheduler] Channel readiness check failed for {channel_type}: {e}")
|
||||
return True
|
||||
|
||||
|
||||
def get_task_store():
|
||||
@@ -83,13 +146,53 @@ def get_scheduler_service():
|
||||
return _scheduler_service
|
||||
|
||||
|
||||
def _execute_agent_task(task: dict, agent_bridge):
|
||||
def _remember_delivered_output(
|
||||
agent_bridge,
|
||||
task: dict,
|
||||
channel_type: str,
|
||||
content: str,
|
||||
) -> None:
|
||||
"""Best-effort persistence of the message the scheduler sent to a user.
|
||||
|
||||
Uses notify_session_id (the real chat session_id stored at task creation time)
|
||||
so that group chats correctly associate the output with the user's conversation.
|
||||
Falls back to receiver for backward compatibility with old tasks.
|
||||
|
||||
Per-action-type behaviour:
|
||||
- agent_task / tool_call / skill_call: gated by ``scheduler_inject_to_session``
|
||||
(default True). These produce AI-generated content worth remembering.
|
||||
- send_message: additionally gated by ``scheduler_inject_send_message``
|
||||
(default False). Fixed reminder text rarely benefits follow-up Q&A and
|
||||
would just consume context tokens.
|
||||
"""
|
||||
Execute an agent_task action - let Agent handle the task
|
||||
|
||||
Args:
|
||||
task: Task dictionary
|
||||
agent_bridge: AgentBridge instance
|
||||
if not content:
|
||||
return
|
||||
action = task.get("action", {})
|
||||
action_type = action.get("type", "")
|
||||
|
||||
# send_message defaults to NOT being injected; explicit opt-in via config.
|
||||
if action_type == "send_message":
|
||||
if not conf().get("scheduler_inject_send_message", False):
|
||||
return
|
||||
|
||||
session_id = action.get("notify_session_id") or action.get("receiver")
|
||||
if not session_id:
|
||||
return
|
||||
try:
|
||||
remember = getattr(agent_bridge, "remember_scheduled_output", None)
|
||||
if remember:
|
||||
task_desc = action.get("task_description") or action.get("content", "")
|
||||
remember(session_id, str(content), channel_type=channel_type, task_description=task_desc)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"[Scheduler] Failed to remember delivered output for {session_id}: {e}"
|
||||
)
|
||||
|
||||
|
||||
def _execute_agent_task(task: dict, agent_bridge) -> bool:
|
||||
"""
|
||||
Execute an agent_task action - let Agent handle the task.
|
||||
Returns True on successful delivery, False to retry next tick.
|
||||
"""
|
||||
try:
|
||||
action = task.get("action", {})
|
||||
@@ -100,11 +203,11 @@ def _execute_agent_task(task: dict, agent_bridge):
|
||||
|
||||
if not task_description:
|
||||
logger.error(f"[Scheduler] Task {task['id']}: No task_description specified")
|
||||
return
|
||||
return True # malformed task, don't loop forever
|
||||
|
||||
if not receiver:
|
||||
logger.error(f"[Scheduler] Task {task['id']}: No receiver specified")
|
||||
return
|
||||
return True
|
||||
|
||||
# Check for unsupported channels
|
||||
if channel_type == "dingtalk":
|
||||
@@ -112,11 +215,15 @@ def _execute_agent_task(task: dict, agent_bridge):
|
||||
|
||||
logger.info(f"[Scheduler] Task {task['id']}: Executing agent task '{task_description}'")
|
||||
|
||||
# Create a unique session_id for this scheduled task to avoid polluting user's conversation
|
||||
# Format: scheduler_<receiver>_<task_id> to ensure isolation
|
||||
scheduler_session_id = f"scheduler_{receiver}_{task['id']}"
|
||||
|
||||
# Create context for Agent
|
||||
context = Context(ContextType.TEXT, task_description)
|
||||
context["receiver"] = receiver
|
||||
context["isgroup"] = is_group
|
||||
context["session_id"] = receiver
|
||||
context["session_id"] = scheduler_session_id
|
||||
|
||||
# Channel-specific setup
|
||||
if channel_type == "web":
|
||||
@@ -129,62 +236,61 @@ def _execute_agent_task(task: dict, agent_bridge):
|
||||
elif channel_type == "dingtalk":
|
||||
# DingTalk requires msg object, set to None for scheduled tasks
|
||||
context["msg"] = None
|
||||
# 如果是单聊,需要传递 sender_staff_id
|
||||
if not is_group:
|
||||
sender_staff_id = action.get("dingtalk_sender_staff_id")
|
||||
if sender_staff_id:
|
||||
context["dingtalk_sender_staff_id"] = sender_staff_id
|
||||
|
||||
elif channel_type == "wecom_bot":
|
||||
context["msg"] = None
|
||||
|
||||
# Use Agent to execute the task
|
||||
# Mark this as a scheduled task execution to prevent recursive task creation
|
||||
context["is_scheduled_task"] = True
|
||||
|
||||
try:
|
||||
reply = agent_bridge.agent_reply(task_description, context=context, on_event=None, clear_history=True)
|
||||
|
||||
if reply and reply.content:
|
||||
# Send the reply via channel
|
||||
from channel.channel_factory import create_channel
|
||||
|
||||
try:
|
||||
channel = create_channel(channel_type)
|
||||
if channel:
|
||||
# For web channel, register request_id
|
||||
if channel_type == "web" and hasattr(channel, 'request_to_session'):
|
||||
request_id = context.get("request_id")
|
||||
if request_id:
|
||||
channel.request_to_session[request_id] = receiver
|
||||
logger.debug(f"[Scheduler] Registered request_id {request_id} -> session {receiver}")
|
||||
|
||||
# Send the reply
|
||||
channel.send(reply, context)
|
||||
logger.info(f"[Scheduler] Task {task['id']} executed successfully, result sent to {receiver}")
|
||||
else:
|
||||
logger.error(f"[Scheduler] Failed to create channel: {channel_type}")
|
||||
except Exception as e:
|
||||
logger.error(f"[Scheduler] Failed to send result: {e}")
|
||||
else:
|
||||
# Don't clear history - scheduler tasks use isolated session_id so they won't pollute user conversations
|
||||
reply = agent_bridge.agent_reply(task_description, context=context, on_event=None, clear_history=False)
|
||||
|
||||
if not (reply and reply.content):
|
||||
logger.error(f"[Scheduler] Task {task['id']}: No result from agent execution")
|
||||
|
||||
return True # agent ran but produced nothing; don't loop
|
||||
|
||||
from channel.channel_factory import create_channel
|
||||
channel = create_channel(channel_type)
|
||||
if not channel:
|
||||
logger.error(f"[Scheduler] Failed to create channel: {channel_type}")
|
||||
return False
|
||||
|
||||
if channel_type == "web" and hasattr(channel, 'request_to_session'):
|
||||
request_id = context.get("request_id")
|
||||
if request_id:
|
||||
channel.request_to_session[request_id] = receiver
|
||||
|
||||
try:
|
||||
channel.send(reply, context)
|
||||
except Exception as e:
|
||||
logger.error(f"[Scheduler] Failed to send result: {e}")
|
||||
return False
|
||||
|
||||
_remember_delivered_output(agent_bridge, task, channel_type, reply.content)
|
||||
logger.info(f"[Scheduler] Task {task['id']} executed successfully, result sent to {receiver}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Scheduler] Failed to execute task via Agent: {e}")
|
||||
import traceback
|
||||
logger.error(f"[Scheduler] Traceback: {traceback.format_exc()}")
|
||||
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Scheduler] Error in _execute_agent_task: {e}")
|
||||
import traceback
|
||||
logger.error(f"[Scheduler] Traceback: {traceback.format_exc()}")
|
||||
return False
|
||||
|
||||
|
||||
def _execute_send_message(task: dict, agent_bridge):
|
||||
"""
|
||||
Execute a send_message action
|
||||
|
||||
Args:
|
||||
task: Task dictionary
|
||||
agent_bridge: AgentBridge instance
|
||||
"""
|
||||
def _execute_send_message(task: dict, agent_bridge) -> bool:
|
||||
"""Execute a send_message action. Returns True/False for delivery."""
|
||||
try:
|
||||
action = task.get("action", {})
|
||||
content = action.get("content", "")
|
||||
@@ -194,7 +300,7 @@ def _execute_send_message(task: dict, agent_bridge):
|
||||
|
||||
if not receiver:
|
||||
logger.error(f"[Scheduler] Task {task['id']}: No receiver specified")
|
||||
return
|
||||
return True
|
||||
|
||||
# Create context for sending message
|
||||
context = Context(ContextType.TEXT, content)
|
||||
@@ -228,170 +334,146 @@ def _execute_send_message(task: dict, agent_bridge):
|
||||
logger.debug(f"[Scheduler] DingTalk single chat: sender_staff_id={sender_staff_id}")
|
||||
else:
|
||||
logger.warning(f"[Scheduler] Task {task['id']}: DingTalk single chat message missing sender_staff_id")
|
||||
|
||||
elif channel_type == "wecom_bot":
|
||||
context["msg"] = None
|
||||
elif channel_type == "qq":
|
||||
context["msg"] = None
|
||||
|
||||
# Create reply
|
||||
reply = Reply(ReplyType.TEXT, content)
|
||||
|
||||
# Get channel and send
|
||||
from channel.channel_factory import create_channel
|
||||
|
||||
channel = create_channel(channel_type)
|
||||
if not channel:
|
||||
logger.error(f"[Scheduler] Failed to create channel: {channel_type}")
|
||||
return False
|
||||
|
||||
if channel_type == "web" and hasattr(channel, 'request_to_session'):
|
||||
channel.request_to_session[request_id] = receiver
|
||||
|
||||
try:
|
||||
channel = create_channel(channel_type)
|
||||
if channel:
|
||||
# For web channel, register the request_id to session mapping
|
||||
if channel_type == "web" and hasattr(channel, 'request_to_session'):
|
||||
channel.request_to_session[request_id] = receiver
|
||||
logger.debug(f"[Scheduler] Registered request_id {request_id} -> session {receiver}")
|
||||
|
||||
channel.send(reply, context)
|
||||
logger.info(f"[Scheduler] Task {task['id']} executed: sent message to {receiver}")
|
||||
else:
|
||||
logger.error(f"[Scheduler] Failed to create channel: {channel_type}")
|
||||
channel.send(reply, context)
|
||||
except Exception as e:
|
||||
logger.error(f"[Scheduler] Failed to send message: {e}")
|
||||
import traceback
|
||||
logger.error(f"[Scheduler] Traceback: {traceback.format_exc()}")
|
||||
|
||||
return False
|
||||
|
||||
_remember_delivered_output(agent_bridge, task, channel_type, content)
|
||||
logger.info(f"[Scheduler] Task {task['id']} executed: sent message to {receiver}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Scheduler] Error in _execute_send_message: {e}")
|
||||
import traceback
|
||||
logger.error(f"[Scheduler] Traceback: {traceback.format_exc()}")
|
||||
return False
|
||||
|
||||
|
||||
def _execute_tool_call(task: dict, agent_bridge):
|
||||
"""
|
||||
Execute a tool_call action
|
||||
|
||||
Args:
|
||||
task: Task dictionary
|
||||
agent_bridge: AgentBridge instance
|
||||
"""
|
||||
def _execute_tool_call(task: dict, agent_bridge) -> bool:
|
||||
"""Execute a tool_call action. Returns True/False for delivery."""
|
||||
try:
|
||||
action = task.get("action", {})
|
||||
# Support both old and new field names
|
||||
tool_name = action.get("call_name") or action.get("tool_name")
|
||||
tool_params = action.get("call_params") or action.get("tool_params", {})
|
||||
result_prefix = action.get("result_prefix", "")
|
||||
receiver = action.get("receiver")
|
||||
is_group = action.get("is_group", False)
|
||||
channel_type = action.get("channel_type", "unknown")
|
||||
|
||||
|
||||
if not tool_name:
|
||||
logger.error(f"[Scheduler] Task {task['id']}: No tool_name specified")
|
||||
return
|
||||
|
||||
return True
|
||||
if not receiver:
|
||||
logger.error(f"[Scheduler] Task {task['id']}: No receiver specified")
|
||||
return
|
||||
|
||||
# Get tool manager and create tool instance
|
||||
return True
|
||||
|
||||
from agent.tools.tool_manager import ToolManager
|
||||
tool_manager = ToolManager()
|
||||
tool = tool_manager.create_tool(tool_name)
|
||||
|
||||
tool = ToolManager().create_tool(tool_name)
|
||||
if not tool:
|
||||
logger.error(f"[Scheduler] Task {task['id']}: Tool '{tool_name}' not found")
|
||||
return
|
||||
|
||||
# Execute tool
|
||||
return True
|
||||
|
||||
logger.info(f"[Scheduler] Task {task['id']}: Executing tool '{tool_name}' with params {tool_params}")
|
||||
result = tool.execute(tool_params)
|
||||
|
||||
# Get result content
|
||||
if hasattr(result, 'result'):
|
||||
content = result.result
|
||||
else:
|
||||
content = str(result)
|
||||
|
||||
# Add prefix if specified
|
||||
content = result.result if hasattr(result, 'result') else str(result)
|
||||
if result_prefix:
|
||||
content = f"{result_prefix}\n\n{content}"
|
||||
|
||||
# Send result as message
|
||||
|
||||
context = Context(ContextType.TEXT, content)
|
||||
context["receiver"] = receiver
|
||||
context["isgroup"] = is_group
|
||||
context["session_id"] = receiver
|
||||
|
||||
# Channel-specific context setup
|
||||
|
||||
request_id = None
|
||||
if channel_type == "web":
|
||||
# Web channel needs request_id
|
||||
import uuid
|
||||
request_id = f"scheduler_{task['id']}_{uuid.uuid4().hex[:8]}"
|
||||
context["request_id"] = request_id
|
||||
logger.debug(f"[Scheduler] Generated request_id for web channel: {request_id}")
|
||||
elif channel_type == "feishu":
|
||||
# Feishu channel: for scheduled tasks, send as new message (no msg_id to reply to)
|
||||
context["receive_id_type"] = "chat_id" if is_group else "open_id"
|
||||
context["msg"] = None
|
||||
logger.debug(f"[Scheduler] Feishu: receive_id_type={context['receive_id_type']}, is_group={is_group}, receiver={receiver}")
|
||||
|
||||
elif channel_type == "wecom_bot":
|
||||
context["msg"] = None
|
||||
|
||||
reply = Reply(ReplyType.TEXT, content)
|
||||
|
||||
# Get channel and send
|
||||
|
||||
from channel.channel_factory import create_channel
|
||||
|
||||
channel = create_channel(channel_type)
|
||||
if not channel:
|
||||
logger.error(f"[Scheduler] Failed to create channel: {channel_type}")
|
||||
return False
|
||||
|
||||
if channel_type == "web" and request_id and hasattr(channel, 'request_to_session'):
|
||||
channel.request_to_session[request_id] = receiver
|
||||
|
||||
try:
|
||||
channel = create_channel(channel_type)
|
||||
if channel:
|
||||
# For web channel, register the request_id to session mapping
|
||||
if channel_type == "web" and hasattr(channel, 'request_to_session'):
|
||||
channel.request_to_session[request_id] = receiver
|
||||
logger.debug(f"[Scheduler] Registered request_id {request_id} -> session {receiver}")
|
||||
|
||||
channel.send(reply, context)
|
||||
logger.info(f"[Scheduler] Task {task['id']} executed: sent tool result to {receiver}")
|
||||
else:
|
||||
logger.error(f"[Scheduler] Failed to create channel: {channel_type}")
|
||||
channel.send(reply, context)
|
||||
except Exception as e:
|
||||
logger.error(f"[Scheduler] Failed to send tool result: {e}")
|
||||
|
||||
return False
|
||||
|
||||
_remember_delivered_output(agent_bridge, task, channel_type, content)
|
||||
logger.info(f"[Scheduler] Task {task['id']} executed: sent tool result to {receiver}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Scheduler] Error in _execute_tool_call: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def _execute_skill_call(task: dict, agent_bridge):
|
||||
"""
|
||||
Execute a skill_call action by asking Agent to run the skill
|
||||
|
||||
Args:
|
||||
task: Task dictionary
|
||||
agent_bridge: AgentBridge instance
|
||||
"""
|
||||
def _execute_skill_call(task: dict, agent_bridge) -> bool:
|
||||
"""Execute a skill_call action by asking Agent to run the skill.
|
||||
Returns True/False for delivery."""
|
||||
try:
|
||||
action = task.get("action", {})
|
||||
# Support both old and new field names
|
||||
skill_name = action.get("call_name") or action.get("skill_name")
|
||||
skill_params = action.get("call_params") or action.get("skill_params", {})
|
||||
result_prefix = action.get("result_prefix", "")
|
||||
receiver = action.get("receiver")
|
||||
is_group = action.get("isgroup", False)
|
||||
channel_type = action.get("channel_type", "unknown")
|
||||
|
||||
|
||||
if not skill_name:
|
||||
logger.error(f"[Scheduler] Task {task['id']}: No skill_name specified")
|
||||
return
|
||||
|
||||
return True
|
||||
if not receiver:
|
||||
logger.error(f"[Scheduler] Task {task['id']}: No receiver specified")
|
||||
return
|
||||
|
||||
return True
|
||||
|
||||
logger.info(f"[Scheduler] Task {task['id']}: Executing skill '{skill_name}' with params {skill_params}")
|
||||
|
||||
# Build a natural language query for the Agent to execute the skill
|
||||
# Format: "Use skill-name to do something with params"
|
||||
|
||||
scheduler_session_id = f"scheduler_{receiver}_{task['id']}"
|
||||
param_str = ", ".join([f"{k}={v}" for k, v in skill_params.items()])
|
||||
query = f"Use {skill_name} skill"
|
||||
if param_str:
|
||||
query += f" with {param_str}"
|
||||
|
||||
# Create context for Agent
|
||||
|
||||
context = Context(ContextType.TEXT, query)
|
||||
context["receiver"] = receiver
|
||||
context["isgroup"] = is_group
|
||||
context["session_id"] = receiver
|
||||
|
||||
# Channel-specific setup
|
||||
context["session_id"] = scheduler_session_id
|
||||
|
||||
if channel_type == "web":
|
||||
import uuid
|
||||
request_id = f"scheduler_{task['id']}_{uuid.uuid4().hex[:8]}"
|
||||
@@ -399,31 +481,51 @@ def _execute_skill_call(task: dict, agent_bridge):
|
||||
elif channel_type == "feishu":
|
||||
context["receive_id_type"] = "chat_id" if is_group else "open_id"
|
||||
context["msg"] = None
|
||||
|
||||
# Use Agent to execute the skill
|
||||
elif channel_type == "wecom_bot":
|
||||
context["msg"] = None
|
||||
|
||||
try:
|
||||
reply = agent_bridge.agent_reply(query, context=context, on_event=None, clear_history=True)
|
||||
|
||||
if reply and reply.content:
|
||||
content = reply.content
|
||||
|
||||
# Add prefix if specified
|
||||
if result_prefix:
|
||||
content = f"{result_prefix}\n\n{content}"
|
||||
|
||||
logger.info(f"[Scheduler] Task {task['id']} executed: skill result sent to {receiver}")
|
||||
else:
|
||||
logger.error(f"[Scheduler] Task {task['id']}: No result from skill execution")
|
||||
|
||||
reply = agent_bridge.agent_reply(query, context=context, on_event=None, clear_history=False)
|
||||
except Exception as e:
|
||||
logger.error(f"[Scheduler] Failed to execute skill via Agent: {e}")
|
||||
import traceback
|
||||
logger.error(f"[Scheduler] Traceback: {traceback.format_exc()}")
|
||||
|
||||
return False
|
||||
|
||||
if not (reply and reply.content):
|
||||
logger.error(f"[Scheduler] Task {task['id']}: No result from skill execution")
|
||||
return True
|
||||
|
||||
content = reply.content
|
||||
if result_prefix:
|
||||
content = f"{result_prefix}\n\n{content}"
|
||||
|
||||
from channel.channel_factory import create_channel
|
||||
channel = create_channel(channel_type)
|
||||
if not channel:
|
||||
logger.error(f"[Scheduler] Failed to create channel: {channel_type}")
|
||||
return False
|
||||
|
||||
if channel_type == "web" and hasattr(channel, 'request_to_session'):
|
||||
req_id = context.get("request_id")
|
||||
if req_id:
|
||||
channel.request_to_session[req_id] = receiver
|
||||
|
||||
try:
|
||||
channel.send(Reply(ReplyType.TEXT, content), context)
|
||||
except Exception as e:
|
||||
logger.error(f"[Scheduler] Failed to send skill result: {e}")
|
||||
return False
|
||||
|
||||
_remember_delivered_output(agent_bridge, task, channel_type, content)
|
||||
logger.info(f"[Scheduler] Task {task['id']} executed: skill result sent to {receiver}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Scheduler] Error in _execute_skill_call: {e}")
|
||||
import traceback
|
||||
logger.error(f"[Scheduler] Traceback: {traceback.format_exc()}")
|
||||
return False
|
||||
|
||||
|
||||
def attach_scheduler_to_tool(tool, context: Context = None):
|
||||
@@ -440,8 +542,7 @@ def attach_scheduler_to_tool(tool, context: Context = None):
|
||||
if context:
|
||||
tool.current_context = context
|
||||
|
||||
# Also set channel_type from config
|
||||
channel_type = conf().get("channel_type", "unknown")
|
||||
channel_type = context.get("channel_type") or conf().get("channel_type", "unknown")
|
||||
if not tool.config:
|
||||
tool.config = {}
|
||||
tool.config["channel_type"] = channel_type
|
||||
|
||||
@@ -10,6 +10,19 @@ from croniter import croniter
|
||||
from common.log import logger
|
||||
|
||||
|
||||
def _parse_naive_local(iso_str: str) -> datetime:
|
||||
"""Parse an ISO datetime and coerce it to tz-naive local time.
|
||||
|
||||
The scheduler uses ``datetime.now()`` (tz-naive) for all comparisons,
|
||||
so any persisted timestamp must be normalized to the same flavor —
|
||||
otherwise comparing naive vs aware raises TypeError.
|
||||
"""
|
||||
dt = datetime.fromisoformat(iso_str)
|
||||
if dt.tzinfo is not None:
|
||||
dt = dt.astimezone().replace(tzinfo=None)
|
||||
return dt
|
||||
|
||||
|
||||
class SchedulerService:
|
||||
"""
|
||||
Background service that executes scheduled tasks
|
||||
@@ -39,7 +52,6 @@ class SchedulerService:
|
||||
self.running = True
|
||||
self.thread = threading.Thread(target=self._run_loop, daemon=True)
|
||||
self.thread.start()
|
||||
logger.debug("[Scheduler] Service started")
|
||||
|
||||
def stop(self):
|
||||
"""Stop the scheduler service"""
|
||||
@@ -54,15 +66,14 @@ class SchedulerService:
|
||||
|
||||
def _run_loop(self):
|
||||
"""Main scheduler loop"""
|
||||
logger.debug("[Scheduler] Scheduler loop started")
|
||||
logger.info("[Scheduler] Scheduler loop started")
|
||||
|
||||
while self.running:
|
||||
try:
|
||||
self._check_and_execute_tasks()
|
||||
except Exception as e:
|
||||
logger.error(f"[Scheduler] Error in scheduler loop: {e}")
|
||||
|
||||
# Sleep for 30 seconds between checks
|
||||
|
||||
time.sleep(30)
|
||||
|
||||
def _check_and_execute_tasks(self):
|
||||
@@ -72,12 +83,18 @@ class SchedulerService:
|
||||
|
||||
for task in tasks:
|
||||
try:
|
||||
# Check if task is due
|
||||
if self._is_task_due(task, now):
|
||||
logger.info(f"[Scheduler] Executing task: {task['id']} - {task['name']}")
|
||||
self._execute_task(task)
|
||||
|
||||
# Update next run time
|
||||
ok = self._execute_task(task)
|
||||
if not ok:
|
||||
# Leave next_run_at as-is so the next loop retries.
|
||||
# Cron tasks within the catch-up window will keep
|
||||
# firing; beyond it _is_task_due will reschedule.
|
||||
logger.warning(
|
||||
f"[Scheduler] Task {task['id']} delivery failed, will retry next tick"
|
||||
)
|
||||
continue
|
||||
|
||||
next_run = self._calculate_next_run(task, now)
|
||||
if next_run:
|
||||
self.task_store.update_task(task['id'], {
|
||||
@@ -85,12 +102,8 @@ class SchedulerService:
|
||||
"last_run_at": now.isoformat()
|
||||
})
|
||||
else:
|
||||
# One-time task, disable it
|
||||
self.task_store.update_task(task['id'], {
|
||||
"enabled": False,
|
||||
"last_run_at": now.isoformat()
|
||||
})
|
||||
logger.info(f"[Scheduler] One-time task completed and disabled: {task['id']}")
|
||||
self.task_store.delete_task(task['id'])
|
||||
logger.info(f"[Scheduler] One-time task completed and removed: {task['id']}")
|
||||
except Exception as e:
|
||||
logger.error(f"[Scheduler] Error processing task {task.get('id')}: {e}")
|
||||
|
||||
@@ -117,37 +130,43 @@ class SchedulerService:
|
||||
return False
|
||||
|
||||
try:
|
||||
next_run = datetime.fromisoformat(next_run_str)
|
||||
|
||||
# Check if task is overdue (e.g., service restart)
|
||||
next_run = _parse_naive_local(next_run_str)
|
||||
|
||||
if next_run < now:
|
||||
time_diff = (now - next_run).total_seconds()
|
||||
|
||||
# If overdue by more than 5 minutes, skip this run and schedule next
|
||||
if time_diff > 300: # 5 minutes
|
||||
logger.warning(f"[Scheduler] Task {task['id']} is overdue by {int(time_diff)}s, skipping and scheduling next run")
|
||||
|
||||
# For one-time tasks, disable them
|
||||
schedule = task.get("schedule", {})
|
||||
if schedule.get("type") == "once":
|
||||
self.task_store.update_task(task['id'], {
|
||||
"enabled": False,
|
||||
"last_run_at": now.isoformat()
|
||||
})
|
||||
logger.info(f"[Scheduler] One-time task {task['id']} expired, disabled")
|
||||
return False
|
||||
|
||||
# For recurring tasks, calculate next run from now
|
||||
next_next_run = self._calculate_next_run(task, now)
|
||||
if next_next_run:
|
||||
self.task_store.update_task(task['id'], {
|
||||
"next_run_at": next_next_run.isoformat()
|
||||
})
|
||||
logger.info(f"[Scheduler] Rescheduled task {task['id']} to {next_next_run}")
|
||||
schedule = task.get("schedule", {})
|
||||
schedule_type = schedule.get("type")
|
||||
|
||||
# Catch-up window: fire if we're within 10 minutes of the
|
||||
# scheduled tick. Beyond that we'd rather skip than push a
|
||||
# stale daily report to the user.
|
||||
if time_diff <= 600:
|
||||
return True
|
||||
|
||||
logger.warning(
|
||||
f"[Scheduler] Task {task['id']} is overdue by {int(time_diff)}s, "
|
||||
f"skipping and scheduling next run"
|
||||
)
|
||||
|
||||
if schedule_type == "once":
|
||||
self.task_store.delete_task(task['id'])
|
||||
logger.info(f"[Scheduler] One-time task {task['id']} expired, removed")
|
||||
return False
|
||||
|
||||
|
||||
next_next_run = self._calculate_next_run(task, now)
|
||||
if next_next_run:
|
||||
self.task_store.update_task(task['id'], {
|
||||
"next_run_at": next_next_run.isoformat()
|
||||
})
|
||||
logger.info(f"[Scheduler] Rescheduled task {task['id']} to {next_next_run}")
|
||||
return False
|
||||
|
||||
return now >= next_run
|
||||
except:
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[Scheduler] Failed to evaluate due-state for task "
|
||||
f"{task.get('id')} (next_run_at={next_run_str!r}): {e}"
|
||||
)
|
||||
return False
|
||||
|
||||
def _calculate_next_run(self, task: dict, from_time: datetime) -> Optional[datetime]:
|
||||
@@ -191,30 +210,34 @@ class SchedulerService:
|
||||
return None
|
||||
|
||||
try:
|
||||
run_at = datetime.fromisoformat(run_at_str)
|
||||
# Only return if in the future
|
||||
run_at = _parse_naive_local(run_at_str)
|
||||
if run_at > from_time:
|
||||
return run_at
|
||||
except:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[Scheduler] Failed to parse once-task run_at "
|
||||
f"{run_at_str!r}: {e}"
|
||||
)
|
||||
return None
|
||||
|
||||
return None
|
||||
|
||||
def _execute_task(self, task: dict):
|
||||
def _execute_task(self, task: dict) -> bool:
|
||||
"""
|
||||
Execute a task
|
||||
|
||||
Args:
|
||||
task: Task dictionary
|
||||
Execute a task.
|
||||
|
||||
Returns True if delivery succeeded (caller should advance state),
|
||||
False if it failed (caller should keep next_run_at so the next
|
||||
loop iteration retries). Callback may return None for legacy
|
||||
behaviour, treated as success.
|
||||
"""
|
||||
try:
|
||||
# Call the execute callback
|
||||
self.execute_callback(task)
|
||||
result = self.execute_callback(task)
|
||||
return False if result is False else True
|
||||
except Exception as e:
|
||||
logger.error(f"[Scheduler] Error executing task {task['id']}: {e}")
|
||||
# Update task with error
|
||||
self.task_store.update_task(task['id'], {
|
||||
"last_error": str(e),
|
||||
"last_error_at": datetime.now().isoformat()
|
||||
})
|
||||
return False
|
||||
|
||||
@@ -20,7 +20,8 @@ class SchedulerTool(BaseTool):
|
||||
|
||||
name: str = "scheduler"
|
||||
description: str = (
|
||||
"创建、查询和管理定时任务。支持固定消息和AI任务两种类型。\n\n"
|
||||
"创建、查询和管理定时任务(提醒、周期性任务等)。\n\n"
|
||||
"⚠️ 重要:仅当需要「定时/提醒/每天/每周/X分钟后/X点」等延迟或周期执行时才使用此工具。"
|
||||
"使用方法:\n"
|
||||
"- 创建:action='create', name='任务名', message/ai_task='内容', schedule_type='once/interval/cron', schedule_value='...'\n"
|
||||
"- 查询:action='list' / action='get', task_id='任务ID'\n"
|
||||
@@ -53,7 +54,7 @@ class SchedulerTool(BaseTool):
|
||||
},
|
||||
"ai_task": {
|
||||
"type": "string",
|
||||
"description": "AI任务描述 (与message二选一),如'搜索今日新闻'、'查询天气'"
|
||||
"description": "AI任务描述 (与message二选一),用于定时让AI执行的任务"
|
||||
},
|
||||
"schedule_type": {
|
||||
"type": "string",
|
||||
@@ -157,6 +158,11 @@ class SchedulerTool(BaseTool):
|
||||
# Create task
|
||||
task_id = str(uuid.uuid4())[:8]
|
||||
|
||||
# Capture the real chat session_id at task creation time so that scheduler
|
||||
# can later inject the delivered output into the user's actual conversation
|
||||
# (in group chats, session_id != receiver, e.g. "user_id:group_id" on feishu).
|
||||
notify_session_id = context.get("session_id")
|
||||
|
||||
# Build action based on message or ai_task
|
||||
if message:
|
||||
action = {
|
||||
@@ -165,7 +171,8 @@ class SchedulerTool(BaseTool):
|
||||
"receiver": context.get("receiver"),
|
||||
"receiver_name": self._get_receiver_name(context),
|
||||
"is_group": context.get("isgroup", False),
|
||||
"channel_type": self.config.get("channel_type", "unknown")
|
||||
"channel_type": self.config.get("channel_type", "unknown"),
|
||||
"notify_session_id": notify_session_id,
|
||||
}
|
||||
else: # ai_task
|
||||
action = {
|
||||
@@ -174,7 +181,8 @@ class SchedulerTool(BaseTool):
|
||||
"receiver": context.get("receiver"),
|
||||
"receiver_name": self._get_receiver_name(context),
|
||||
"is_group": context.get("isgroup", False),
|
||||
"channel_type": self.config.get("channel_type", "unknown")
|
||||
"channel_type": self.config.get("channel_type", "unknown"),
|
||||
"notify_session_id": notify_session_id,
|
||||
}
|
||||
|
||||
# 针对钉钉单聊,额外存储 sender_staff_id
|
||||
@@ -356,9 +364,12 @@ class SchedulerTool(BaseTool):
|
||||
logger.error(f"[SchedulerTool] Invalid relative time format: {schedule_value}")
|
||||
return None
|
||||
else:
|
||||
# Absolute time in ISO format
|
||||
datetime.fromisoformat(schedule_value)
|
||||
return {"type": "once", "run_at": schedule_value}
|
||||
# Absolute ISO time. Normalize to tz-naive local so it
|
||||
# stays comparable with the scheduler's datetime.now().
|
||||
parsed = datetime.fromisoformat(schedule_value)
|
||||
if parsed.tzinfo is not None:
|
||||
parsed = parsed.astimezone().replace(tzinfo=None)
|
||||
return {"type": "once", "run_at": parsed.isoformat()}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[SchedulerTool] Invalid schedule: {e}")
|
||||
@@ -423,7 +434,7 @@ class SchedulerTool(BaseTool):
|
||||
try:
|
||||
dt = datetime.fromisoformat(run_at)
|
||||
return f"一次性 ({dt.strftime('%Y-%m-%d %H:%M')})"
|
||||
except:
|
||||
except Exception:
|
||||
return "一次性"
|
||||
|
||||
return "未知"
|
||||
@@ -437,6 +448,6 @@ class SchedulerTool(BaseTool):
|
||||
return msg.other_user_nickname or "群聊"
|
||||
else:
|
||||
return msg.from_user_nickname or "用户"
|
||||
except:
|
||||
except Exception:
|
||||
pass
|
||||
return "未知"
|
||||
|
||||
@@ -8,6 +8,7 @@ import threading
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional
|
||||
from pathlib import Path
|
||||
from common.utils import expand_path
|
||||
|
||||
|
||||
class TaskStore:
|
||||
@@ -24,7 +25,7 @@ class TaskStore:
|
||||
"""
|
||||
if store_path is None:
|
||||
# Default to ~/cow/scheduler/tasks.json
|
||||
home = os.path.expanduser("~")
|
||||
home = expand_path("~")
|
||||
store_path = os.path.join(home, "cow", "scheduler", "tasks.json")
|
||||
|
||||
self.store_path = store_path
|
||||
@@ -71,7 +72,7 @@ class TaskStore:
|
||||
with open(self.store_path, 'r') as src:
|
||||
with open(backup_path, 'w') as dst:
|
||||
dst.write(src.read())
|
||||
except:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Save tasks
|
||||
|
||||
@@ -7,20 +7,21 @@ from typing import Dict, Any
|
||||
from pathlib import Path
|
||||
|
||||
from agent.tools.base_tool import BaseTool, ToolResult
|
||||
from common.utils import expand_path
|
||||
|
||||
|
||||
class Send(BaseTool):
|
||||
"""Tool for sending files to the user"""
|
||||
|
||||
name: str = "send"
|
||||
description: str = "Send a file (image, video, audio, document) to the user. Use this when the user explicitly asks to send/share a file."
|
||||
description: str = "Send a LOCAL file (image, video, audio, document) to the user. Only for local file paths. Do NOT use this for URLs — URLs should be included directly in your text reply, the system will handle them automatically."
|
||||
|
||||
params: dict = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "Path to the file to send. Can be absolute path or relative to workspace."
|
||||
"description": "Local file path to send. Must be an absolute path or relative to workspace. Do NOT pass URLs here."
|
||||
},
|
||||
"message": {
|
||||
"type": "string",
|
||||
@@ -97,12 +98,23 @@ class Send(BaseTool):
|
||||
"size_formatted": self._format_size(file_size),
|
||||
"message": message or f"正在发送 {file_name}"
|
||||
}
|
||||
|
||||
|
||||
try:
|
||||
from common.cloud_client import get_website_base_url, copy_send_file
|
||||
|
||||
# Do nothing when in local env
|
||||
if get_website_base_url():
|
||||
url = copy_send_file(absolute_path, self.cwd)
|
||||
if url:
|
||||
result["url"] = url
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return ToolResult.success(result)
|
||||
|
||||
def _resolve_path(self, path: str) -> str:
|
||||
"""Resolve path to absolute path"""
|
||||
path = os.path.expanduser(path)
|
||||
path = expand_path(path)
|
||||
if os.path.isabs(path):
|
||||
return path
|
||||
return os.path.abspath(os.path.join(self.cwd, path))
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import importlib
|
||||
import importlib.util
|
||||
import threading
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, Type
|
||||
from agent.tools.base_tool import BaseTool
|
||||
@@ -7,6 +8,26 @@ from common.log import logger
|
||||
from config import conf
|
||||
|
||||
|
||||
def _normalize_mcp_configs(raw) -> list:
|
||||
"""
|
||||
Convert MCP server config to internal list format.
|
||||
Supports:
|
||||
- list format (mcp_servers): [{"name": "x", "type": "stdio", ...}]
|
||||
- dict format (mcpServers): {"x": {"command": "npx", ...}}
|
||||
"""
|
||||
if isinstance(raw, list):
|
||||
return raw
|
||||
if isinstance(raw, dict):
|
||||
result = []
|
||||
for name, cfg in raw.items():
|
||||
entry = {"name": name, **cfg}
|
||||
if "type" not in entry:
|
||||
entry["type"] = "sse" if "url" in entry else "stdio"
|
||||
result.append(entry)
|
||||
return result
|
||||
return []
|
||||
|
||||
|
||||
class ToolManager:
|
||||
"""
|
||||
Tool manager for managing tools.
|
||||
@@ -25,6 +46,31 @@ class ToolManager:
|
||||
# Initialize only once
|
||||
if not hasattr(self, 'tool_classes'):
|
||||
self.tool_classes = {} # Dictionary to store tool classes
|
||||
if not hasattr(self, '_mcp_registry'):
|
||||
self._mcp_registry = None # Lazy init: only created when MCP servers are configured
|
||||
if not hasattr(self, '_mcp_tool_instances'):
|
||||
self._mcp_tool_instances: dict = {} # tool_name -> McpTool instance
|
||||
if not hasattr(self, '_mcp_lock'):
|
||||
# Guards _mcp_loaded check-then-set so concurrent callers
|
||||
# don't trigger duplicate background loaders.
|
||||
self._mcp_lock = threading.Lock()
|
||||
if not hasattr(self, '_mcp_loaded'):
|
||||
# Idempotency flag. Flipped to True the moment the first loader
|
||||
# is dispatched (synchronously, inside _mcp_lock). Subsequent
|
||||
# _load_mcp_tools() calls become no-ops, so per-session agent
|
||||
# initialization never re-forks MCP subprocesses.
|
||||
self._mcp_loaded = False
|
||||
if not hasattr(self, '_mcp_status'):
|
||||
# server_name -> "pending" / "ready" / "failed"
|
||||
# Useful for UI / introspection while async loading is in progress.
|
||||
self._mcp_status: dict = {}
|
||||
if not hasattr(self, '_mcp_signature'):
|
||||
# (mtime, sha256) of mcp.json the last time we loaded.
|
||||
# Used by refresh_mcp_if_changed() to skip re-parsing when nothing changed.
|
||||
self._mcp_signature: tuple = (None, None)
|
||||
if not hasattr(self, '_mcp_active_configs'):
|
||||
# server_name -> normalized config dict, for diff-based reload.
|
||||
self._mcp_active_configs: dict = {}
|
||||
|
||||
def load_tools(self, tools_dir: str = "", config_dict=None):
|
||||
"""
|
||||
@@ -39,6 +85,8 @@ class ToolManager:
|
||||
self._load_tools_from_init()
|
||||
self._configure_tools_from_config(config_dict)
|
||||
|
||||
self._load_mcp_tools()
|
||||
|
||||
def _load_tools_from_init(self) -> bool:
|
||||
"""
|
||||
Load tool classes from tools.__init__.__all__
|
||||
@@ -70,10 +118,14 @@ class ToolManager:
|
||||
and cls != BaseTool
|
||||
):
|
||||
try:
|
||||
# Skip memory tools (they need special initialization with memory_manager)
|
||||
# Skip tools that need special initialization
|
||||
if class_name in ["MemorySearchTool", "MemoryGetTool"]:
|
||||
logger.debug(f"Skipped tool {class_name} (requires memory_manager)")
|
||||
continue
|
||||
# McpTool instances are registered dynamically via _load_mcp_tools()
|
||||
if class_name == "McpTool":
|
||||
logger.debug(f"Skipped tool {class_name} (registered dynamically via mcp_servers config)")
|
||||
continue
|
||||
|
||||
# Create a temporary instance to get the name
|
||||
temp_instance = cls()
|
||||
@@ -84,11 +136,11 @@ class ToolManager:
|
||||
except ImportError as e:
|
||||
# Handle missing dependencies with helpful messages
|
||||
error_msg = str(e)
|
||||
if "browser-use" in error_msg or "browser_use" in error_msg:
|
||||
if "playwright" in error_msg:
|
||||
logger.warning(
|
||||
f"[ToolManager] Browser tool not loaded - missing dependencies.\n"
|
||||
f" To enable browser tool, run:\n"
|
||||
f" pip install browser-use markdownify playwright\n"
|
||||
f" pip install playwright\n"
|
||||
f" playwright install chromium"
|
||||
)
|
||||
elif "markdownify" in error_msg:
|
||||
@@ -154,11 +206,11 @@ class ToolManager:
|
||||
except ImportError as e:
|
||||
# Handle missing dependencies with helpful messages
|
||||
error_msg = str(e)
|
||||
if "browser-use" in error_msg or "browser_use" in error_msg:
|
||||
if "playwright" in error_msg:
|
||||
logger.warning(
|
||||
f"[ToolManager] Browser tool not loaded - missing dependencies.\n"
|
||||
f" To enable browser tool, run:\n"
|
||||
f" pip install browser-use markdownify playwright\n"
|
||||
f" pip install playwright\n"
|
||||
f" playwright install chromium"
|
||||
)
|
||||
elif "markdownify" in error_msg:
|
||||
@@ -197,7 +249,7 @@ class ToolManager:
|
||||
logger.warning(
|
||||
f"[ToolManager] Browser tool is configured but not loaded.\n"
|
||||
f" To enable browser tool, run:\n"
|
||||
f" pip install browser-use markdownify playwright\n"
|
||||
f" pip install playwright\n"
|
||||
f" playwright install chromium"
|
||||
)
|
||||
elif tool_name == "google_search":
|
||||
@@ -212,6 +264,306 @@ class ToolManager:
|
||||
except Exception as e:
|
||||
logger.error(f"Error configuring tools from config: {e}")
|
||||
|
||||
def _mcp_json_path(self) -> str:
|
||||
import os
|
||||
workspace = os.path.expanduser(conf().get("agent_workspace", "~/cow"))
|
||||
return os.path.join(workspace, "mcp.json")
|
||||
|
||||
def _read_mcp_json_signature(self):
|
||||
"""
|
||||
Return (mtime, sha256_of_bytes) for ~/cow/mcp.json without parsing.
|
||||
Returns (None, None) if the file doesn't exist or is unreadable.
|
||||
Cheap enough (one stat + one small read) to call on every agent init.
|
||||
"""
|
||||
import os
|
||||
import hashlib
|
||||
path = self._mcp_json_path()
|
||||
try:
|
||||
mtime = os.path.getmtime(path)
|
||||
except OSError:
|
||||
return (None, None)
|
||||
try:
|
||||
with open(path, "rb") as f:
|
||||
digest = hashlib.sha256(f.read()).hexdigest()
|
||||
except OSError:
|
||||
return (mtime, None)
|
||||
return (mtime, digest)
|
||||
|
||||
def _load_mcp_configs(self) -> list:
|
||||
"""
|
||||
Load MCP server configs with priority:
|
||||
1. ~/cow/mcp.json (supports both mcpServers and mcp_servers keys)
|
||||
2. config.json mcp_servers field (fallback)
|
||||
"""
|
||||
import os
|
||||
import json as _json
|
||||
|
||||
mcp_json_path = self._mcp_json_path()
|
||||
|
||||
if os.path.exists(mcp_json_path):
|
||||
try:
|
||||
with open(mcp_json_path, "r", encoding="utf-8") as f:
|
||||
data = _json.load(f)
|
||||
raw = data.get("mcpServers") or data.get("mcp_servers") or data
|
||||
logger.info(f"[ToolManager] Loading MCP config from {mcp_json_path}")
|
||||
return _normalize_mcp_configs(raw)
|
||||
except Exception as e:
|
||||
logger.warning(f"[ToolManager] Failed to read {mcp_json_path}: {e}, falling back to config.json")
|
||||
|
||||
raw = conf().get("mcp_servers", [])
|
||||
return _normalize_mcp_configs(raw)
|
||||
|
||||
def _load_mcp_tools(self):
|
||||
"""
|
||||
Trigger MCP tool loading in a background thread (idempotent).
|
||||
|
||||
Returns immediately. Booting MCP servers (npx, uvx, etc.) takes
|
||||
seconds to tens of seconds on first run, which would otherwise
|
||||
block agent initialization and the user's first message.
|
||||
Built-in tools work fine without MCP, so we let the agent serve
|
||||
traffic right away and let MCP servers come online in the
|
||||
background. Per-session agents read a snapshot of whatever is
|
||||
ready at construction time and gracefully ignore the rest.
|
||||
"""
|
||||
with self._mcp_lock:
|
||||
if self._mcp_loaded:
|
||||
return
|
||||
mcp_servers_config = self._load_mcp_configs()
|
||||
# Snapshot the signature now so future refresh_mcp_if_changed()
|
||||
# calls can short-circuit when nothing has changed on disk.
|
||||
self._mcp_signature = self._read_mcp_json_signature()
|
||||
self._mcp_active_configs = {
|
||||
cfg.get("name", "<unnamed>"): cfg for cfg in mcp_servers_config
|
||||
}
|
||||
if not mcp_servers_config:
|
||||
# Mark as loaded even when there is nothing to load,
|
||||
# so we don't re-read the config file on every call.
|
||||
self._mcp_loaded = True
|
||||
return
|
||||
|
||||
# Mark pending immediately so list_mcp_status() callers see
|
||||
# the in-progress state instead of an empty dict.
|
||||
for cfg in mcp_servers_config:
|
||||
name = cfg.get("name", "<unnamed>")
|
||||
self._mcp_status[name] = "pending"
|
||||
|
||||
self._mcp_loaded = True
|
||||
threading.Thread(
|
||||
target=self._load_mcp_tools_async,
|
||||
args=(mcp_servers_config,),
|
||||
daemon=True,
|
||||
name="mcp-loader",
|
||||
).start()
|
||||
logger.info(
|
||||
f"[ToolManager] MCP loading started in background "
|
||||
f"({len(mcp_servers_config)} server(s) configured)"
|
||||
)
|
||||
|
||||
def refresh_mcp_if_changed(self):
|
||||
"""
|
||||
Cheap check whether ~/cow/mcp.json has changed since last load.
|
||||
If it has, do a diff-based reload: start newly added servers,
|
||||
shut down removed ones, and restart any whose config was edited.
|
||||
Untouched servers are left running.
|
||||
|
||||
Designed to be called on every agent creation. The fast path is
|
||||
a single os.stat() — completely free when nothing has changed.
|
||||
"""
|
||||
with self._mcp_lock:
|
||||
new_sig = self._read_mcp_json_signature()
|
||||
if new_sig == self._mcp_signature:
|
||||
return # no-op fast path
|
||||
|
||||
try:
|
||||
new_configs = self._load_mcp_configs()
|
||||
except Exception as e:
|
||||
logger.warning(f"[ToolManager] MCP reload — failed to parse config: {e}")
|
||||
return
|
||||
|
||||
new_by_name = {
|
||||
cfg.get("name", "<unnamed>"): cfg for cfg in new_configs
|
||||
}
|
||||
old_by_name = self._mcp_active_configs
|
||||
|
||||
added = [n for n in new_by_name if n not in old_by_name]
|
||||
removed = [n for n in old_by_name if n not in new_by_name]
|
||||
changed = [
|
||||
n for n in new_by_name
|
||||
if n in old_by_name and new_by_name[n] != old_by_name[n]
|
||||
]
|
||||
|
||||
if not (added or removed or changed):
|
||||
# Signature drifted but content is logically identical
|
||||
# (e.g. user re-saved the file without edits). Just sync.
|
||||
self._mcp_signature = new_sig
|
||||
return
|
||||
|
||||
logger.info(
|
||||
f"[ToolManager] mcp.json changed — "
|
||||
f"adding={added}, removing={removed}, restarting={changed}"
|
||||
)
|
||||
|
||||
# Tear down removed + changed servers (changed ones get restarted below)
|
||||
for name in removed + changed:
|
||||
self._teardown_mcp_server(name)
|
||||
|
||||
# Spin up newly added + changed servers in the background
|
||||
to_start = [new_by_name[n] for n in added + changed]
|
||||
if to_start:
|
||||
for cfg in to_start:
|
||||
self._mcp_status[cfg.get("name", "<unnamed>")] = "pending"
|
||||
threading.Thread(
|
||||
target=self._load_mcp_tools_async,
|
||||
args=(to_start,),
|
||||
daemon=True,
|
||||
name="mcp-loader-reload",
|
||||
).start()
|
||||
|
||||
self._mcp_active_configs = new_by_name
|
||||
self._mcp_signature = new_sig
|
||||
|
||||
def _teardown_mcp_server(self, server_name: str):
|
||||
"""Shut down one MCP server and drop its tools from the registry."""
|
||||
if self._mcp_registry is None:
|
||||
return
|
||||
client = None
|
||||
with self._mcp_registry._registry_lock:
|
||||
client = self._mcp_registry._clients.pop(server_name, None)
|
||||
if client is not None:
|
||||
try:
|
||||
client.shutdown()
|
||||
except Exception as e:
|
||||
logger.warning(f"[MCP] Error shutting down '{server_name}': {e}")
|
||||
# Drop tools that belonged to this server.
|
||||
for tool_name in list(self._mcp_tool_instances.keys()):
|
||||
tool = self._mcp_tool_instances.get(tool_name)
|
||||
if tool is not None and getattr(tool, "server_name", None) == server_name:
|
||||
self._mcp_tool_instances.pop(tool_name, None)
|
||||
self._mcp_status.pop(server_name, None)
|
||||
|
||||
def _load_mcp_tools_async(self, mcp_servers_config):
|
||||
"""
|
||||
Background worker: bring up each MCP server one-by-one and
|
||||
publish ready tools to _mcp_tool_instances as they come online.
|
||||
|
||||
Server failures are isolated — one bad server cannot block
|
||||
the others, and never raises out of the worker thread.
|
||||
"""
|
||||
try:
|
||||
from agent.tools.mcp.mcp_client import McpClient, McpClientRegistry
|
||||
from agent.tools.mcp.mcp_tool import McpTool
|
||||
|
||||
registry = McpClientRegistry()
|
||||
self._mcp_registry = registry
|
||||
|
||||
for cfg in mcp_servers_config:
|
||||
server_name = cfg.get("name", "<unnamed>")
|
||||
try:
|
||||
client = McpClient(cfg)
|
||||
if not client.initialize():
|
||||
self._mcp_status[server_name] = "failed"
|
||||
logger.warning(
|
||||
f"[MCP] Server '{server_name}' failed to initialize — skipping"
|
||||
)
|
||||
continue
|
||||
|
||||
tool_schemas = client.list_tools()
|
||||
added = []
|
||||
for schema in tool_schemas:
|
||||
tool_name = schema.get("name", "")
|
||||
if not tool_name:
|
||||
continue
|
||||
mcp_tool = McpTool(client, schema, server_name)
|
||||
# Atomic dict assignment is GIL-safe; readers iterate
|
||||
# over a list() snapshot to avoid concurrent mutation.
|
||||
self._mcp_tool_instances[tool_name] = mcp_tool
|
||||
added.append(tool_name)
|
||||
|
||||
# Register client into the shared registry only after its
|
||||
# tools are visible, so callers never see a half-loaded server.
|
||||
with registry._registry_lock:
|
||||
registry._clients[server_name] = client
|
||||
self._mcp_status[server_name] = "ready"
|
||||
logger.info(
|
||||
f"[MCP] Server '{server_name}' ready — "
|
||||
f"{len(added)} tool(s): {added}"
|
||||
)
|
||||
except Exception as e:
|
||||
self._mcp_status[server_name] = "failed"
|
||||
logger.warning(f"[MCP] Server '{server_name}' load failed: {e}")
|
||||
|
||||
ready = sum(1 for s in self._mcp_status.values() if s == "ready")
|
||||
total = len(self._mcp_status)
|
||||
logger.info(
|
||||
f"[ToolManager] MCP loading complete: "
|
||||
f"{ready}/{total} server(s) ready, "
|
||||
f"{len(self._mcp_tool_instances)} tool(s) available"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"[ToolManager] MCP background loader crashed: {e}")
|
||||
|
||||
def list_mcp_status(self) -> dict:
|
||||
"""Return {server_name: status} snapshot for UI / debugging."""
|
||||
return dict(self._mcp_status)
|
||||
|
||||
def sync_mcp_into_agent(self, agent) -> tuple:
|
||||
"""
|
||||
Reconcile a live agent's tool collection with the current MCP tool registry.
|
||||
|
||||
Adds tools that finished loading after the agent was created,
|
||||
and removes tools whose MCP server was torn down. Built-in tools
|
||||
on the agent are left untouched.
|
||||
|
||||
Handles both representations CowAgent uses:
|
||||
- Agent.tools: list[BaseTool] (default Agent class)
|
||||
- AgentStream.tools: dict[str, BaseTool] (streaming agent)
|
||||
|
||||
Returns (added_names, removed_names) for logging.
|
||||
"""
|
||||
if agent is None or not hasattr(agent, "tools"):
|
||||
return ([], [])
|
||||
|
||||
from agent.tools.mcp.mcp_tool import McpTool
|
||||
current = self._mcp_tool_instances
|
||||
registry_names = set(current.keys())
|
||||
|
||||
agent_tools = agent.tools
|
||||
|
||||
if isinstance(agent_tools, dict):
|
||||
agent_mcp_names = {
|
||||
name for name, tool in agent_tools.items()
|
||||
if isinstance(tool, McpTool)
|
||||
}
|
||||
added = registry_names - agent_mcp_names
|
||||
removed = agent_mcp_names - registry_names
|
||||
if not (added or removed):
|
||||
return ([], [])
|
||||
for name in added:
|
||||
agent_tools[name] = current[name]
|
||||
for name in removed:
|
||||
agent_tools.pop(name, None)
|
||||
|
||||
elif isinstance(agent_tools, list):
|
||||
agent_mcp_names = {
|
||||
t.name for t in agent_tools if isinstance(t, McpTool)
|
||||
}
|
||||
added = registry_names - agent_mcp_names
|
||||
removed = agent_mcp_names - registry_names
|
||||
if not (added or removed):
|
||||
return ([], [])
|
||||
if removed:
|
||||
agent.tools = [
|
||||
t for t in agent_tools
|
||||
if not (isinstance(t, McpTool) and t.name in removed)
|
||||
]
|
||||
for name in added:
|
||||
agent.tools.append(current[name])
|
||||
|
||||
else:
|
||||
return ([], [])
|
||||
|
||||
return (sorted(added), sorted(removed))
|
||||
|
||||
def create_tool(self, name: str) -> BaseTool:
|
||||
"""
|
||||
Get a new instance of a tool by name.
|
||||
@@ -229,6 +581,12 @@ class ToolManager:
|
||||
tool_instance.config = self.tool_configs[name]
|
||||
|
||||
return tool_instance
|
||||
|
||||
# Fall back to MCP tool instances
|
||||
mcp_tool = self._mcp_tool_instances.get(name)
|
||||
if mcp_tool:
|
||||
return mcp_tool
|
||||
|
||||
return None
|
||||
|
||||
def list_tools(self) -> dict:
|
||||
@@ -245,4 +603,17 @@ class ToolManager:
|
||||
"description": temp_instance.description,
|
||||
"parameters": temp_instance.get_json_schema()
|
||||
}
|
||||
|
||||
# Include MCP tool instances
|
||||
for name, mcp_tool in self._mcp_tool_instances.items():
|
||||
result[name] = {
|
||||
"description": mcp_tool.description,
|
||||
"parameters": mcp_tool.params,
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
def shutdown_mcp(self):
|
||||
"""Shut down all MCP server clients."""
|
||||
if self._mcp_registry:
|
||||
self._mcp_registry.shutdown_all()
|
||||
|
||||
@@ -8,7 +8,10 @@ Truncation is based on two independent limits - whichever is hit first wins:
|
||||
Never returns partial lines (except bash tail truncation edge case).
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, Optional, Literal
|
||||
from __future__ import annotations
|
||||
from typing import Dict, Any, Optional, Tuple, TYPE_CHECKING
|
||||
if TYPE_CHECKING:
|
||||
from typing import Literal
|
||||
|
||||
|
||||
DEFAULT_MAX_LINES = 2000
|
||||
@@ -278,7 +281,7 @@ def _truncate_string_to_bytes_from_end(text: str, max_bytes: int) -> str:
|
||||
return encoded[start:].decode('utf-8', errors='ignore')
|
||||
|
||||
|
||||
def truncate_line(line: str, max_chars: int = GREP_MAX_LINE_LENGTH) -> tuple[str, bool]:
|
||||
def truncate_line(line: str, max_chars: int = GREP_MAX_LINE_LENGTH) -> Tuple[str, bool]:
|
||||
"""
|
||||
Truncate single line to max characters, add [truncated] suffix.
|
||||
Used for grep match lines.
|
||||
|
||||
1
agent/tools/vision/__init__.py
Normal file
1
agent/tools/vision/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from agent.tools.vision.vision import Vision
|
||||
814
agent/tools/vision/vision.py
Normal file
814
agent/tools/vision/vision.py
Normal file
@@ -0,0 +1,814 @@
|
||||
"""
|
||||
Vision tool - Analyze images using Vision API.
|
||||
Supports local files (auto base64-encoded) and HTTP URLs.
|
||||
|
||||
Provider resolution:
|
||||
- tools.vision.model (if set) means "prefer this model first; fall back to
|
||||
other configured providers if it fails". The model name is mapped to its
|
||||
native provider (e.g. doubao-* → Doubao, kimi-* → Moonshot, gpt-* →
|
||||
OpenAI/LinkAI). That provider is tried first, then the standard auto
|
||||
chain runs as fallback (with the preferred provider de-duplicated).
|
||||
- Auto chain priority:
|
||||
1. Main model via bot.call_vision — only when the main bot is known
|
||||
to actually support vision (not just expose a call_vision method).
|
||||
2. Other models whose API key is configured.
|
||||
3. OpenAI / LinkAI raw HTTP.
|
||||
When use_linkai=true, LinkAI is promoted to #1.
|
||||
"""
|
||||
|
||||
import base64
|
||||
import os
|
||||
import subprocess
|
||||
import tempfile
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import requests
|
||||
|
||||
from agent.tools.base_tool import BaseTool, ToolResult
|
||||
from common import const
|
||||
from common.log import logger
|
||||
from config import conf
|
||||
|
||||
DEFAULT_MODEL = const.GPT_41_MINI
|
||||
DEFAULT_TIMEOUT = 60
|
||||
MAX_TOKENS = 1000
|
||||
COMPRESS_THRESHOLD = 1_048_576 # 1 MB
|
||||
|
||||
SUPPORTED_EXTENSIONS = {
|
||||
"jpg": "image/jpeg",
|
||||
"jpeg": "image/jpeg",
|
||||
"png": "image/png",
|
||||
"gif": "image/gif",
|
||||
"webp": "image/webp",
|
||||
}
|
||||
|
||||
_MAIN_MODEL_PROVIDER_NAME = "MainModel"
|
||||
|
||||
# (config_key_for_api_key, bot_type, default_vision_model, provider_display_name)
|
||||
# Auto-discovered as fallback vision providers when their API key is configured.
|
||||
# OpenAI and LinkAI are handled separately (raw HTTP providers), so not listed here.
|
||||
_DISCOVERABLE_MODELS = [
|
||||
("moonshot_api_key", const.MOONSHOT, const.KIMI_K2_6, "Moonshot"),
|
||||
("ark_api_key", const.DOUBAO, const.DOUBAO_SEED_2_PRO, "Doubao"),
|
||||
("dashscope_api_key", const.QWEN_DASHSCOPE, const.QWEN37_PLUS, "DashScope"),
|
||||
("claude_api_key", const.CLAUDEAPI, const.CLAUDE_4_6_SONNET, "Claude"),
|
||||
("gemini_api_key", const.GEMINI, const.GEMINI_35_FLASH, "Gemini"),
|
||||
("qianfan_api_key", const.QIANFAN, const.ERNIE_45_TURBO_VL, "Qianfan"),
|
||||
("zhipu_ai_api_key", const.ZHIPU_AI, const.GLM_4_7, "ZhipuAI"),
|
||||
("minimax_api_key", const.MiniMax, const.MINIMAX_M2_7, "MiniMax"),
|
||||
("mimo_api_key", const.MIMO, const.MIMO_V2_5_PRO, "MiMo"),
|
||||
]
|
||||
|
||||
# Model name prefix → discoverable provider display_name.
|
||||
# Used to auto-route tools.vision.model to its native provider.
|
||||
# Matched case-insensitively; longest prefix wins.
|
||||
_MODEL_PREFIX_TO_PROVIDER = [
|
||||
("doubao-", "Doubao"),
|
||||
("kimi-", "Moonshot"),
|
||||
("moonshot-", "Moonshot"),
|
||||
("qwen", "DashScope"), # qwen-*, qwen3-*, qwen3.6-*, etc.
|
||||
("claude-", "Claude"),
|
||||
("ernie-", "Qianfan"),
|
||||
("gemini-", "Gemini"),
|
||||
("glm-", "ZhipuAI"),
|
||||
("minimax-", "MiniMax"),
|
||||
("abab", "MiniMax"),
|
||||
("mimo-", "MiMo"),
|
||||
]
|
||||
|
||||
# Model prefixes that natively belong to OpenAI / LinkAI (raw HTTP providers).
|
||||
_OPENAI_MODEL_PREFIXES = ("gpt-", "o1-", "o3-", "o4-", "chatgpt-")
|
||||
|
||||
# Maps the UI provider id (persisted in tools.vision.provider) to the internal
|
||||
# display name used in VisionProvider.name. Keep in sync with _DISCOVERABLE_MODELS
|
||||
# and the openai/linkai branches in _route_by_model_name.
|
||||
_PROVIDER_ID_TO_DISPLAY = {
|
||||
"openai": "OpenAI",
|
||||
"linkai": "LinkAI",
|
||||
"moonshot": "Moonshot",
|
||||
"doubao": "Doubao",
|
||||
"dashscope": "DashScope",
|
||||
"claudeAPI": "Claude",
|
||||
"gemini": "Gemini",
|
||||
"qianfan": "Qianfan",
|
||||
"zhipu": "ZhipuAI",
|
||||
"minimax": "MiniMax",
|
||||
"mimo": "MiMo",
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class VisionProvider:
|
||||
"""A single Vision API provider configuration."""
|
||||
name: str
|
||||
api_key: str
|
||||
api_base: str
|
||||
extra_headers: dict = field(default_factory=dict)
|
||||
model_override: Optional[str] = None
|
||||
use_bot: bool = False # When True, call via bot.call_vision instead of raw HTTP
|
||||
fallback_bot: Any = None # Bot instance for non-main-model providers
|
||||
|
||||
|
||||
class VisionAPIError(Exception):
|
||||
"""Raised when a Vision API call fails and should trigger fallback."""
|
||||
pass
|
||||
|
||||
|
||||
class Vision(BaseTool):
|
||||
"""Analyze images using Vision API"""
|
||||
|
||||
name: str = "vision"
|
||||
description: str = (
|
||||
"Analyze a local image or image URL (jpg/jpeg/png) using Vision API. "
|
||||
"Can describe content, extract text, identify objects, colors, etc. "
|
||||
)
|
||||
|
||||
params: dict = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"image": {
|
||||
"type": "string",
|
||||
"description": "Local file path or HTTP(S) URL of the image to analyze",
|
||||
},
|
||||
"question": {
|
||||
"type": "string",
|
||||
"description": "Question to ask about the image",
|
||||
},
|
||||
},
|
||||
"required": ["image", "question"],
|
||||
}
|
||||
|
||||
def __init__(self, config: dict = None):
|
||||
self.config = config or {}
|
||||
|
||||
@staticmethod
|
||||
def is_available() -> bool:
|
||||
return True
|
||||
|
||||
def execute(self, args: Dict[str, Any]) -> ToolResult:
|
||||
image = args.get("image", "").strip()
|
||||
question = args.get("question", "").strip()
|
||||
|
||||
if not image:
|
||||
return ToolResult.fail("Error: 'image' parameter is required")
|
||||
if not question:
|
||||
return ToolResult.fail("Error: 'question' parameter is required")
|
||||
|
||||
providers = self._resolve_providers()
|
||||
if not providers:
|
||||
return ToolResult.fail(
|
||||
"Error: No model available for Vision.\n"
|
||||
"The main model does not support vision and no other API keys are configured.\n"
|
||||
"Options:\n"
|
||||
" 1. Switch to a multimodal model (e.g. ernie-4.5-turbo-vl, qwen3.7-plus, claude-sonnet-4-6, gemini-2.0-flash)\n"
|
||||
" 2. Configure OPENAI_API_KEY: env_config(action=\"set\", key=\"OPENAI_API_KEY\", value=\"your-key\")\n"
|
||||
" 3. Configure LINKAI_API_KEY: env_config(action=\"set\", key=\"LINKAI_API_KEY\", value=\"your-key\")"
|
||||
)
|
||||
|
||||
try:
|
||||
image_content = self._build_image_content(image)
|
||||
except Exception as e:
|
||||
return ToolResult.fail(f"Error: {e}")
|
||||
|
||||
# Default model is only used as a last-resort placeholder for providers
|
||||
# whose VisionProvider.model_override is None (e.g. raw OpenAI provider
|
||||
# when the user did not configure tools.vision.model).
|
||||
return self._call_with_fallback(providers, DEFAULT_MODEL, question, image_content)
|
||||
|
||||
def _call_with_fallback(self, providers: List[VisionProvider], model: str,
|
||||
question: str, image_content: dict) -> ToolResult:
|
||||
"""Try each provider in order; fall back to the next one on failure."""
|
||||
errors: List[str] = []
|
||||
for i, provider in enumerate(providers):
|
||||
use_model = provider.model_override or model
|
||||
try:
|
||||
logger.info(f"[Vision] Trying provider '{provider.name}' "
|
||||
f"with model '{use_model}' ({i + 1}/{len(providers)})")
|
||||
if provider.use_bot:
|
||||
result = self._call_via_bot(use_model, question, image_content, provider)
|
||||
else:
|
||||
result = self._call_api(provider, use_model, question, image_content)
|
||||
logger.info(f"[Vision] ✅ Success via {provider.name} (model={use_model})")
|
||||
return result
|
||||
except VisionAPIError as e:
|
||||
errors.append(f"[{provider.name}/{use_model}] {e}")
|
||||
logger.warning(f"[Vision] Provider '{provider.name}' failed: {e}")
|
||||
except requests.Timeout:
|
||||
errors.append(f"[{provider.name}/{use_model}] Request timed out after {DEFAULT_TIMEOUT}s")
|
||||
logger.warning(f"[Vision] Provider '{provider.name}' timed out")
|
||||
except requests.ConnectionError:
|
||||
errors.append(f"[{provider.name}/{use_model}] Connection failed")
|
||||
logger.warning(f"[Vision] Provider '{provider.name}' connection failed")
|
||||
except Exception as e:
|
||||
errors.append(f"[{provider.name}/{use_model}] {e}")
|
||||
logger.error(f"[Vision] Provider '{provider.name}' unexpected error: {e}", exc_info=True)
|
||||
|
||||
return ToolResult.fail(
|
||||
"Error: All Vision API providers failed.\n" + "\n".join(f" - {err}" for err in errors)
|
||||
)
|
||||
|
||||
def _resolve_providers(self) -> List[VisionProvider]:
|
||||
"""
|
||||
Build an ordered list of providers to try.
|
||||
|
||||
Semantics of `tools.vision.model`:
|
||||
"Prefer this model first; fall back to other configured providers
|
||||
if it fails."
|
||||
|
||||
Order:
|
||||
1. The provider that natively serves `tools.vision.model` (if any
|
||||
and its API key is configured) — using the user-specified model
|
||||
name verbatim.
|
||||
2. Auto-discovery chain as fallback:
|
||||
- use_linkai=true → [LinkAI, MainModel?, OtherModels…, OpenAI]
|
||||
- default → [MainModel?, OtherModels…, OpenAI, LinkAI]
|
||||
MainModel is only included when the main bot is known to support
|
||||
vision (see _main_bot_supports_vision).
|
||||
|
||||
Providers that share the same display name as the preferred provider
|
||||
are de-duplicated to avoid retrying the same endpoint twice.
|
||||
"""
|
||||
user_model = self._resolve_user_vision_model()
|
||||
user_provider = self._resolve_user_vision_provider()
|
||||
providers: List[VisionProvider] = []
|
||||
|
||||
# Step 1: preferred provider — explicit `tools.vision.provider`
|
||||
# wins so custom model names can still be routed correctly. Falls
|
||||
# through to model-name prefix inference when provider is unset.
|
||||
preferred = None
|
||||
if user_provider and user_model:
|
||||
preferred = self._route_by_provider_id(user_provider, user_model)
|
||||
if not preferred and user_model:
|
||||
preferred = self._route_by_model_name(user_model)
|
||||
if preferred:
|
||||
providers.extend(preferred)
|
||||
|
||||
# Step 2: auto-discovery chain as fallback
|
||||
existing = {p.name for p in providers}
|
||||
fallback: List[VisionProvider] = []
|
||||
use_linkai = conf().get("use_linkai", False) and conf().get("linkai_api_key")
|
||||
|
||||
if use_linkai:
|
||||
self._append_provider(fallback, lambda: self._build_linkai_provider(user_model))
|
||||
self._append_provider(fallback, self._build_main_model_provider)
|
||||
self._append_other_model_providers(fallback, preferred_model=user_model)
|
||||
self._append_provider(fallback, lambda: self._build_openai_provider(user_model))
|
||||
else:
|
||||
self._append_provider(fallback, self._build_main_model_provider)
|
||||
self._append_other_model_providers(fallback, preferred_model=user_model)
|
||||
self._append_provider(fallback, lambda: self._build_openai_provider(user_model))
|
||||
self._append_provider(fallback, lambda: self._build_linkai_provider(user_model))
|
||||
|
||||
for p in fallback:
|
||||
if p.name in existing:
|
||||
continue
|
||||
providers.append(p)
|
||||
existing.add(p.name)
|
||||
|
||||
return providers
|
||||
|
||||
@staticmethod
|
||||
def _append_provider(providers: List[VisionProvider], builder) -> None:
|
||||
p = builder()
|
||||
if p:
|
||||
providers.append(p)
|
||||
|
||||
@staticmethod
|
||||
def _resolve_user_vision_model() -> Optional[str]:
|
||||
"""Read tools.vision.model (singular ``tool`` kept as runtime fallback)."""
|
||||
tools_conf = conf().get("tools") or conf().get("tool") or {}
|
||||
if not isinstance(tools_conf, dict):
|
||||
return None
|
||||
vision_conf = tools_conf.get("vision", {})
|
||||
if not isinstance(vision_conf, dict):
|
||||
return None
|
||||
m = vision_conf.get("model")
|
||||
if isinstance(m, str) and m.strip():
|
||||
return m.strip()
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _resolve_user_vision_provider() -> Optional[str]:
|
||||
"""Read tools.vision.provider — the UI-persisted vendor id.
|
||||
|
||||
Lets users pin a vendor for custom model names that prefix-inference
|
||||
can't recognize. Returns None when unset/blank.
|
||||
"""
|
||||
tools_conf = conf().get("tools") or conf().get("tool") or {}
|
||||
if not isinstance(tools_conf, dict):
|
||||
return None
|
||||
vision_conf = tools_conf.get("vision", {})
|
||||
if not isinstance(vision_conf, dict):
|
||||
return None
|
||||
p = vision_conf.get("provider")
|
||||
if isinstance(p, str) and p.strip():
|
||||
return p.strip()
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _infer_provider_from_model(model_name: str) -> Optional[str]:
|
||||
"""
|
||||
Infer the provider display name from a model name's prefix.
|
||||
Returns None when no rule matches (or for OpenAI-family names, which
|
||||
are handled separately by the caller).
|
||||
"""
|
||||
if not model_name:
|
||||
return None
|
||||
lower = model_name.lower()
|
||||
# Sort by prefix length desc so e.g. "moonshot-" wins over hypothetical "moo-"
|
||||
for prefix, display_name in sorted(_MODEL_PREFIX_TO_PROVIDER, key=lambda x: -len(x[0])):
|
||||
if lower.startswith(prefix.lower()):
|
||||
return display_name
|
||||
return None
|
||||
|
||||
def _route_by_provider_id(self, provider_id: str, user_model: str) -> Optional[List[VisionProvider]]:
|
||||
"""Route by the UI-persisted provider id.
|
||||
|
||||
Returns:
|
||||
- [provider] : provider id is known and its key is configured.
|
||||
- None : unknown provider id, or the bot can't be created.
|
||||
Caller falls through to model-name-based routing.
|
||||
"""
|
||||
display_name = _PROVIDER_ID_TO_DISPLAY.get(provider_id)
|
||||
if not display_name:
|
||||
return None
|
||||
|
||||
# OpenAI / LinkAI use raw HTTP providers, not the discoverable bot path.
|
||||
if provider_id == "openai":
|
||||
p = self._build_openai_provider(user_model)
|
||||
return [p] if p else None
|
||||
if provider_id == "linkai":
|
||||
p = self._build_linkai_provider(user_model)
|
||||
return [p] if p else None
|
||||
|
||||
# Discoverable bot-backed providers.
|
||||
for config_key, bot_type, _default_model, name in _DISCOVERABLE_MODELS:
|
||||
if name != display_name:
|
||||
continue
|
||||
api_key = conf().get(config_key, "")
|
||||
if not api_key or not api_key.strip():
|
||||
logger.warning(f"[Vision] tools.vision.provider='{provider_id}' "
|
||||
f"but '{config_key}' is not configured. Falling back.")
|
||||
return None
|
||||
try:
|
||||
from models.bot_factory import create_bot
|
||||
bot = create_bot(bot_type)
|
||||
if not hasattr(bot, 'call_vision'):
|
||||
logger.warning(f"[Vision] '{display_name}' bot does not implement call_vision.")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.warning(f"[Vision] Failed to create '{display_name}' bot: {e}")
|
||||
return None
|
||||
return [VisionProvider(
|
||||
name=display_name,
|
||||
api_key="",
|
||||
api_base="",
|
||||
model_override=user_model,
|
||||
use_bot=True,
|
||||
fallback_bot=bot,
|
||||
)]
|
||||
return None
|
||||
|
||||
def _route_by_model_name(self, user_model: str) -> Optional[List[VisionProvider]]:
|
||||
"""
|
||||
Try to build a provider list using the user-specified model name.
|
||||
Returns:
|
||||
- [provider] : matched and the provider's key is configured
|
||||
- [] : matched but key missing → tell caller to surface this
|
||||
as a hard error rather than silently falling back
|
||||
- None : no rule matches → caller should fall through to auto
|
||||
"""
|
||||
lower = user_model.lower()
|
||||
|
||||
# OpenAI / LinkAI family
|
||||
if lower.startswith(_OPENAI_MODEL_PREFIXES):
|
||||
providers: List[VisionProvider] = []
|
||||
# Prefer LinkAI when explicitly enabled, else OpenAI first
|
||||
use_linkai = conf().get("use_linkai", False) and conf().get("linkai_api_key")
|
||||
if use_linkai:
|
||||
self._append_provider(providers, lambda: self._build_linkai_provider(user_model))
|
||||
self._append_provider(providers, lambda: self._build_openai_provider(user_model))
|
||||
else:
|
||||
self._append_provider(providers, lambda: self._build_openai_provider(user_model))
|
||||
self._append_provider(providers, lambda: self._build_linkai_provider(user_model))
|
||||
if providers:
|
||||
return providers
|
||||
logger.warning(f"[Vision] tools.vision.model='{user_model}' looks like an OpenAI "
|
||||
f"model but neither OPENAI_API_KEY nor LINKAI_API_KEY is configured.")
|
||||
return None # fall through to auto
|
||||
|
||||
# Discoverable native providers (Doubao, Moonshot, etc.)
|
||||
target_display = self._infer_provider_from_model(user_model)
|
||||
if not target_display:
|
||||
return None # unknown prefix → auto
|
||||
|
||||
for config_key, bot_type, _default_model, display_name in _DISCOVERABLE_MODELS:
|
||||
if display_name != target_display:
|
||||
continue
|
||||
api_key = conf().get(config_key, "")
|
||||
if not api_key or not api_key.strip():
|
||||
logger.warning(f"[Vision] tools.vision.model='{user_model}' routes to "
|
||||
f"'{display_name}' but '{config_key}' is not configured. "
|
||||
f"Falling back to auto-discovery.")
|
||||
return None # fall through to auto
|
||||
try:
|
||||
from models.bot_factory import create_bot
|
||||
bot = create_bot(bot_type)
|
||||
if not hasattr(bot, 'call_vision'):
|
||||
logger.warning(f"[Vision] '{display_name}' bot does not implement call_vision.")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.warning(f"[Vision] Failed to create '{display_name}' bot: {e}")
|
||||
return None
|
||||
|
||||
return [VisionProvider(
|
||||
name=display_name,
|
||||
api_key="",
|
||||
api_base="",
|
||||
model_override=user_model,
|
||||
use_bot=True,
|
||||
fallback_bot=bot,
|
||||
)]
|
||||
|
||||
return None
|
||||
|
||||
def _append_other_model_providers(self, providers: List[VisionProvider],
|
||||
preferred_model: Optional[str] = None) -> None:
|
||||
"""
|
||||
Auto-discover other models whose API key is configured.
|
||||
Skip the main model's own bot_type (already covered by MainModel
|
||||
provider), unless the main model itself does not support vision —
|
||||
in that case we still want the vendor's dedicated vision model
|
||||
as a fallback. Also skip bot_types that already appear in the
|
||||
provider list.
|
||||
|
||||
If preferred_model matches a provider's family, use it instead
|
||||
of that provider's hard-coded default model.
|
||||
"""
|
||||
main_bot_type = None
|
||||
main_bot_supports_vision = False
|
||||
if self.model and hasattr(self.model, '_resolve_bot_type'):
|
||||
main_bot_type = self.model._resolve_bot_type(conf().get("model", ""))
|
||||
main_bot = getattr(self.model, "bot", None)
|
||||
main_bot_supports_vision = self._main_bot_supports_vision(main_bot)
|
||||
|
||||
existing_names = {p.name for p in providers}
|
||||
preferred_provider = self._infer_provider_from_model(preferred_model) if preferred_model else None
|
||||
|
||||
for config_key, bot_type, default_model, display_name in _DISCOVERABLE_MODELS:
|
||||
if display_name in existing_names:
|
||||
continue
|
||||
# Same bot_type as the main model is normally handled by the
|
||||
# MainModel provider; only skip it here if the main model
|
||||
# actually supports vision. Otherwise fall through and add
|
||||
# the vendor's dedicated vision model as a fallback.
|
||||
if bot_type == main_bot_type and main_bot_supports_vision:
|
||||
continue
|
||||
api_key = conf().get(config_key, "")
|
||||
if not api_key or not api_key.strip():
|
||||
continue
|
||||
|
||||
try:
|
||||
from models.bot_factory import create_bot
|
||||
bot = create_bot(bot_type)
|
||||
if not hasattr(bot, 'call_vision'):
|
||||
continue
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
model_for_provider = (preferred_model
|
||||
if preferred_provider == display_name and preferred_model
|
||||
else default_model)
|
||||
|
||||
provider = VisionProvider(
|
||||
name=display_name,
|
||||
api_key="",
|
||||
api_base="",
|
||||
model_override=model_for_provider,
|
||||
use_bot=True,
|
||||
fallback_bot=bot,
|
||||
)
|
||||
|
||||
# Same vendor as the main bot is the most natural fallback when
|
||||
# the main model itself does not support vision — promote it to
|
||||
# the front of the list instead of relying on declaration order.
|
||||
if bot_type == main_bot_type:
|
||||
providers.insert(0, provider)
|
||||
else:
|
||||
providers.append(provider)
|
||||
|
||||
def _main_bot_supports_vision(self, bot) -> bool:
|
||||
"""
|
||||
Whether the main bot is known to natively support vision.
|
||||
|
||||
Having a `call_vision` method is necessary but not sufficient —
|
||||
some bots implement the method against an endpoint that does not
|
||||
actually serve vision models, which causes silent failures when a
|
||||
vendor-foreign model name is forwarded.
|
||||
|
||||
Resolution order:
|
||||
1. If the bot explicitly declares `supports_vision`, trust it.
|
||||
This lets bots opt in or out based on their own runtime
|
||||
configuration (e.g. the currently selected model).
|
||||
2. Otherwise, fall back to a model-name prefix heuristic: trust
|
||||
call_vision when the main model looks like an OpenAI family
|
||||
model or matches a known multimodal vendor prefix.
|
||||
"""
|
||||
if bot is None:
|
||||
return False
|
||||
if hasattr(bot, "supports_vision"):
|
||||
return bool(getattr(bot, "supports_vision"))
|
||||
main_model = (conf().get("model") or "").lower()
|
||||
if not main_model:
|
||||
return False
|
||||
if main_model.startswith(_OPENAI_MODEL_PREFIXES):
|
||||
return True
|
||||
return self._infer_provider_from_model(main_model) is not None
|
||||
|
||||
def _build_main_model_provider(self) -> Optional[VisionProvider]:
|
||||
"""
|
||||
Use the vendor's own model for vision via bot.call_vision.
|
||||
Gated by _main_bot_supports_vision so non-vision bots (DeepSeek, etc.)
|
||||
do not get routed vendor-foreign model names.
|
||||
"""
|
||||
if not (self.model and hasattr(self.model, 'bot')):
|
||||
return None
|
||||
try:
|
||||
bot = self.model.bot
|
||||
except Exception:
|
||||
return None
|
||||
if not hasattr(bot, 'call_vision'):
|
||||
return None
|
||||
if not self._main_bot_supports_vision(bot):
|
||||
return None
|
||||
|
||||
# Use the configured main model name; do NOT inject tools.vision.model
|
||||
# here, because by the time we reach this branch the tools.vision.model
|
||||
# routing has already been attempted (and either matched the main bot
|
||||
# or failed to find a provider).
|
||||
main_model_name = conf().get("model") or None
|
||||
|
||||
return VisionProvider(
|
||||
name=_MAIN_MODEL_PROVIDER_NAME,
|
||||
api_key="",
|
||||
api_base="",
|
||||
model_override=main_model_name,
|
||||
use_bot=True,
|
||||
)
|
||||
|
||||
def _build_openai_provider(self, preferred_model: Optional[str] = None) -> Optional[VisionProvider]:
|
||||
api_key = conf().get("open_ai_api_key") or os.environ.get("OPENAI_API_KEY")
|
||||
if not api_key:
|
||||
return None
|
||||
api_base = (conf().get("open_ai_api_base") or os.environ.get("OPENAI_API_BASE", "")).rstrip("/") \
|
||||
or "https://api.openai.com/v1"
|
||||
# Only honor preferred_model when it looks like an OpenAI-family name;
|
||||
# otherwise the OpenAI endpoint would 400 on a vendor-specific name.
|
||||
model_override = preferred_model if (
|
||||
preferred_model and preferred_model.lower().startswith(_OPENAI_MODEL_PREFIXES)
|
||||
) else None
|
||||
return VisionProvider(
|
||||
name="OpenAI",
|
||||
api_key=api_key,
|
||||
api_base=self._ensure_v1(api_base),
|
||||
model_override=model_override,
|
||||
)
|
||||
|
||||
def _build_linkai_provider(self, preferred_model: Optional[str] = None) -> Optional[VisionProvider]:
|
||||
api_key = conf().get("linkai_api_key") or os.environ.get("LINKAI_API_KEY")
|
||||
if not api_key:
|
||||
return None
|
||||
api_base = (conf().get("linkai_api_base") or os.environ.get("LINKAI_API_BASE", "")).rstrip("/") \
|
||||
or "https://api.link-ai.tech"
|
||||
from common.utils import get_cloud_headers
|
||||
extra = get_cloud_headers(api_key)
|
||||
extra.pop("Authorization", None)
|
||||
extra.pop("Content-Type", None)
|
||||
# LinkAI is a multi-vendor proxy and accepts most model names, so we
|
||||
# honor any user-configured model name here.
|
||||
return VisionProvider(
|
||||
name="LinkAI",
|
||||
api_key=api_key,
|
||||
api_base=self._ensure_v1(api_base),
|
||||
extra_headers=extra,
|
||||
model_override=preferred_model,
|
||||
)
|
||||
|
||||
def _call_via_bot(self, model: str, question: str, image_content: dict,
|
||||
provider: Optional[VisionProvider] = None) -> ToolResult:
|
||||
"""
|
||||
Call a model's call_vision with vendor-native API format.
|
||||
Uses the provider's _fallback_bot if set, otherwise the main model bot.
|
||||
Raises VisionAPIError on failure so fallback can proceed.
|
||||
"""
|
||||
try:
|
||||
bot = (provider and provider.fallback_bot) or self.model.bot
|
||||
except Exception as e:
|
||||
raise VisionAPIError(f"Cannot access bot: {e}")
|
||||
|
||||
# Extract the raw image URL from the OpenAI-format image_content block
|
||||
image_url = image_content.get("image_url", {}).get("url", "")
|
||||
if not image_url:
|
||||
raise VisionAPIError("No image URL in content block")
|
||||
|
||||
try:
|
||||
response = bot.call_vision(
|
||||
image_url=image_url,
|
||||
question=question,
|
||||
model=model,
|
||||
max_tokens=MAX_TOKENS,
|
||||
)
|
||||
except Exception as e:
|
||||
raise VisionAPIError(f"call_vision failed: {e}")
|
||||
|
||||
if response is NotImplemented:
|
||||
raise VisionAPIError("Bot does not support vision")
|
||||
|
||||
if isinstance(response, dict) and response.get("error"):
|
||||
raise VisionAPIError(f"API error - {response.get('message', 'Unknown')}")
|
||||
|
||||
content = response.get("content", "") if isinstance(response, dict) else ""
|
||||
if not content:
|
||||
raise VisionAPIError("Empty response from main model")
|
||||
|
||||
usage_info = response.get("usage", {}) if isinstance(response, dict) else {}
|
||||
|
||||
# Use the actual model name from the bot response if available
|
||||
actual_model = response.get("model", model) if isinstance(response, dict) else model
|
||||
provider_name = provider.name if provider else _MAIN_MODEL_PROVIDER_NAME
|
||||
return ToolResult.success({
|
||||
"model": actual_model,
|
||||
"provider": provider_name,
|
||||
"content": content,
|
||||
"usage": usage_info,
|
||||
})
|
||||
|
||||
@staticmethod
|
||||
def _ensure_v1(api_base: str) -> str:
|
||||
"""Append /v1 if the base URL doesn't already end with a versioned path."""
|
||||
if not api_base:
|
||||
return api_base
|
||||
# Already has /v1 or similar version suffix
|
||||
if api_base.rstrip("/").split("/")[-1].startswith("v"):
|
||||
return api_base
|
||||
return api_base.rstrip("/") + "/v1"
|
||||
|
||||
def _build_image_content(self, image: str) -> dict:
|
||||
"""
|
||||
Build the image_url content block.
|
||||
Both remote URLs and local files are converted to base64 data URLs
|
||||
so every bot backend can consume them without extra downloads.
|
||||
"""
|
||||
if image.startswith(("http://", "https://")):
|
||||
return self._download_to_data_url(image)
|
||||
|
||||
if not os.path.isfile(image):
|
||||
raise FileNotFoundError(f"Image file not found: {image}")
|
||||
|
||||
ext = image.rsplit(".", 1)[-1].lower() if "." in image else ""
|
||||
mime_type = SUPPORTED_EXTENSIONS.get(ext)
|
||||
if not mime_type:
|
||||
raise ValueError(
|
||||
f"Unsupported image format '.{ext}'. "
|
||||
f"Supported: {', '.join(SUPPORTED_EXTENSIONS.keys())}"
|
||||
)
|
||||
|
||||
file_path = self._maybe_compress(image)
|
||||
try:
|
||||
with open(file_path, "rb") as f:
|
||||
b64 = base64.b64encode(f.read()).decode("ascii")
|
||||
finally:
|
||||
if file_path != image and os.path.exists(file_path):
|
||||
os.remove(file_path)
|
||||
|
||||
data_url = f"data:{mime_type};base64,{b64}"
|
||||
return {"type": "image_url", "image_url": {"url": data_url}}
|
||||
|
||||
@staticmethod
|
||||
def _download_to_data_url(url: str) -> dict:
|
||||
"""Download a remote image and return it as a base64 data URL."""
|
||||
resp = requests.get(url, timeout=30)
|
||||
if resp.status_code != 200:
|
||||
raise VisionAPIError(f"Failed to download image: HTTP {resp.status_code}")
|
||||
content_type = resp.headers.get("Content-Type", "image/jpeg").split(";")[0].strip()
|
||||
if not content_type.startswith("image/"):
|
||||
content_type = "image/jpeg"
|
||||
b64 = base64.b64encode(resp.content).decode("ascii")
|
||||
data_url = f"data:{content_type};base64,{b64}"
|
||||
return {"type": "image_url", "image_url": {"url": data_url}}
|
||||
|
||||
@staticmethod
|
||||
def _maybe_compress(path: str) -> str:
|
||||
"""Compress image to under COMPRESS_THRESHOLD with max long-edge 1536px."""
|
||||
file_size = os.path.getsize(path)
|
||||
if file_size <= COMPRESS_THRESHOLD:
|
||||
return path
|
||||
|
||||
tmp = tempfile.NamedTemporaryFile(suffix=".jpg", delete=False)
|
||||
tmp.close()
|
||||
|
||||
def _try_sips(max_dim: str, quality: str) -> bool:
|
||||
try:
|
||||
subprocess.run(
|
||||
["sips", "-Z", max_dim, "-s", "formatOptions", quality,
|
||||
path, "--out", tmp.name],
|
||||
capture_output=True, check=True,
|
||||
)
|
||||
return True
|
||||
except (FileNotFoundError, subprocess.CalledProcessError):
|
||||
return False
|
||||
|
||||
def _try_convert(max_dim: str, quality: str) -> bool:
|
||||
try:
|
||||
subprocess.run(
|
||||
["convert", path, "-resize", f"{max_dim}x{max_dim}>",
|
||||
"-quality", quality, tmp.name],
|
||||
capture_output=True, check=True,
|
||||
)
|
||||
return True
|
||||
except (FileNotFoundError, subprocess.CalledProcessError):
|
||||
return False
|
||||
|
||||
attempts = [
|
||||
("1536", "85"),
|
||||
("1536", "70"),
|
||||
("1536", "50"),
|
||||
]
|
||||
|
||||
for max_dim, quality in attempts:
|
||||
ok = _try_sips(max_dim, quality) or _try_convert(max_dim, quality)
|
||||
if not ok:
|
||||
continue
|
||||
new_size = os.path.getsize(tmp.name)
|
||||
logger.debug(f"[Vision] Compressed image "
|
||||
f"({file_size // 1024}KB -> {new_size // 1024}KB, "
|
||||
f"max_dim={max_dim}, q={quality})")
|
||||
if new_size <= COMPRESS_THRESHOLD:
|
||||
return tmp.name
|
||||
|
||||
if os.path.exists(tmp.name) and os.path.getsize(tmp.name) > 0:
|
||||
return tmp.name
|
||||
|
||||
os.remove(tmp.name)
|
||||
return path
|
||||
|
||||
def _call_api(self, provider: VisionProvider, model: str,
|
||||
question: str, image_content: dict) -> ToolResult:
|
||||
"""
|
||||
Call a single provider's Vision API.
|
||||
Raises VisionAPIError on recoverable failures so the caller can try
|
||||
the next provider.
|
||||
"""
|
||||
payload = {
|
||||
"model": model,
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": question},
|
||||
image_content,
|
||||
],
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Bearer {provider.api_key}",
|
||||
"Content-Type": "application/json",
|
||||
**provider.extra_headers,
|
||||
}
|
||||
|
||||
resp = requests.post(
|
||||
f"{provider.api_base}/chat/completions",
|
||||
headers=headers,
|
||||
json=payload,
|
||||
timeout=DEFAULT_TIMEOUT,
|
||||
)
|
||||
|
||||
if resp.status_code != 200:
|
||||
raise VisionAPIError(f"HTTP {resp.status_code}: {resp.text[:200]}")
|
||||
|
||||
data = resp.json()
|
||||
|
||||
if "error" in data:
|
||||
msg = data["error"].get("message", "Unknown API error")
|
||||
raise VisionAPIError(f"API error - {msg}")
|
||||
|
||||
content = ""
|
||||
choices = data.get("choices", [])
|
||||
if choices:
|
||||
content = choices[0].get("message", {}).get("content", "")
|
||||
|
||||
usage = data.get("usage", {})
|
||||
result = {
|
||||
"model": model,
|
||||
"provider": provider.name,
|
||||
"content": content,
|
||||
"usage": {
|
||||
"prompt_tokens": usage.get("prompt_tokens", 0),
|
||||
"completion_tokens": usage.get("completion_tokens", 0),
|
||||
"total_tokens": usage.get("total_tokens", 0),
|
||||
},
|
||||
}
|
||||
return ToolResult.success(result)
|
||||
0
agent/tools/web_fetch/__init__.py
Normal file
0
agent/tools/web_fetch/__init__.py
Normal file
444
agent/tools/web_fetch/web_fetch.py
Normal file
444
agent/tools/web_fetch/web_fetch.py
Normal file
@@ -0,0 +1,444 @@
|
||||
"""
|
||||
Web Fetch tool - Fetch and extract readable content from web pages and remote files.
|
||||
|
||||
Supports:
|
||||
- HTML web pages: extracts readable text content
|
||||
- Document files (PDF, Word, TXT, Markdown, etc.): downloads to workspace/tmp and parses content
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
import uuid
|
||||
from typing import Dict, Any, Optional, Set
|
||||
from urllib.parse import urlparse, unquote
|
||||
|
||||
import requests
|
||||
|
||||
from agent.tools.base_tool import BaseTool, ToolResult
|
||||
from agent.tools.utils.truncate import truncate_head, format_size
|
||||
from common.log import logger
|
||||
|
||||
|
||||
DEFAULT_TIMEOUT = 30
|
||||
MAX_FILE_SIZE = 50 * 1024 * 1024 # 50MB
|
||||
|
||||
DEFAULT_HEADERS = {
|
||||
"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36",
|
||||
"Accept": "*/*",
|
||||
}
|
||||
|
||||
# Supported document file extensions
|
||||
PDF_SUFFIXES: Set[str] = {".pdf"}
|
||||
WORD_SUFFIXES: Set[str] = {".docx"}
|
||||
TEXT_SUFFIXES: Set[str] = {".txt", ".md", ".markdown", ".rst", ".csv", ".tsv", ".log"}
|
||||
SPREADSHEET_SUFFIXES: Set[str] = {".xls", ".xlsx"}
|
||||
PPT_SUFFIXES: Set[str] = {".ppt", ".pptx"}
|
||||
|
||||
ALL_DOC_SUFFIXES = PDF_SUFFIXES | WORD_SUFFIXES | TEXT_SUFFIXES | SPREADSHEET_SUFFIXES | PPT_SUFFIXES
|
||||
|
||||
_CHARSET_RE = re.compile(r'charset\s*=\s*["\']?\s*([\w\-]+)', re.IGNORECASE)
|
||||
_META_CHARSET_RE = re.compile(rb'<meta[^>]+charset\s*=\s*["\']?\s*([\w\-]+)', re.IGNORECASE)
|
||||
_META_HTTP_EQUIV_RE = re.compile(
|
||||
rb'<meta[^>]+http-equiv\s*=\s*["\']?Content-Type["\']?[^>]+content\s*=\s*["\'][^"\']*charset=([\w\-]+)',
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
|
||||
def _extract_charset_from_content_type(content_type: str) -> Optional[str]:
|
||||
"""Extract charset from Content-Type header value."""
|
||||
m = _CHARSET_RE.search(content_type)
|
||||
return m.group(1) if m else None
|
||||
|
||||
|
||||
def _extract_charset_from_html_meta(raw_bytes: bytes) -> Optional[str]:
|
||||
"""Extract charset from HTML <meta> tags in the first few KB of raw bytes."""
|
||||
m = _META_CHARSET_RE.search(raw_bytes)
|
||||
if m:
|
||||
return m.group(1).decode("ascii", errors="ignore")
|
||||
m = _META_HTTP_EQUIV_RE.search(raw_bytes)
|
||||
if m:
|
||||
return m.group(1).decode("ascii", errors="ignore")
|
||||
return None
|
||||
|
||||
|
||||
def _get_url_suffix(url: str) -> str:
|
||||
"""Extract file extension from URL path, ignoring query params."""
|
||||
path = urlparse(url).path
|
||||
return os.path.splitext(path)[-1].lower()
|
||||
|
||||
|
||||
def _is_document_url(url: str) -> bool:
|
||||
"""Check if URL points to a downloadable document file."""
|
||||
suffix = _get_url_suffix(url)
|
||||
return suffix in ALL_DOC_SUFFIXES
|
||||
|
||||
|
||||
class WebFetch(BaseTool):
|
||||
"""Tool for fetching web pages and remote document files"""
|
||||
|
||||
name: str = "web_fetch"
|
||||
description: str = (
|
||||
"Fetch content from a http/https URL. For web pages, extracts readable text. "
|
||||
"For document files (PDF, Word, TXT, Markdown, Excel, PPT), downloads and parses the file content. "
|
||||
"Supported file types: .pdf, .docx, .txt, .md, .csv, .xls, .xlsx, .ppt, .pptx"
|
||||
)
|
||||
|
||||
params: dict = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"url": {
|
||||
"type": "string",
|
||||
"description": "The HTTP/HTTPS URL to fetch (web page or document file link)"
|
||||
}
|
||||
},
|
||||
"required": ["url"]
|
||||
}
|
||||
|
||||
def __init__(self, config: dict = None):
|
||||
self.config = config or {}
|
||||
self.cwd = self.config.get("cwd", os.getcwd())
|
||||
|
||||
def execute(self, args: Dict[str, Any]) -> ToolResult:
|
||||
url = args.get("url", "").strip()
|
||||
if not url:
|
||||
return ToolResult.fail("Error: 'url' parameter is required")
|
||||
|
||||
parsed = urlparse(url)
|
||||
if parsed.scheme not in ("http", "https"):
|
||||
return ToolResult.fail("Error: Invalid URL (must start with http:// or https://)")
|
||||
|
||||
if _is_document_url(url):
|
||||
return self._fetch_document(url)
|
||||
|
||||
return self._fetch_webpage(url)
|
||||
|
||||
# ---- Web page fetching ----
|
||||
|
||||
def _fetch_webpage(self, url: str) -> ToolResult:
|
||||
"""Fetch and extract readable text from an HTML web page."""
|
||||
parsed = urlparse(url)
|
||||
try:
|
||||
response = requests.get(
|
||||
url,
|
||||
headers=DEFAULT_HEADERS,
|
||||
timeout=DEFAULT_TIMEOUT,
|
||||
allow_redirects=True,
|
||||
)
|
||||
response.raise_for_status()
|
||||
except requests.Timeout:
|
||||
return ToolResult.fail(f"Error: Request timed out after {DEFAULT_TIMEOUT}s")
|
||||
except requests.ConnectionError:
|
||||
return ToolResult.fail(f"Error: Failed to connect to {parsed.netloc}")
|
||||
except requests.HTTPError as e:
|
||||
return ToolResult.fail(f"Error: HTTP {e.response.status_code} for URL: {url}")
|
||||
except Exception as e:
|
||||
return ToolResult.fail(f"Error: Failed to fetch URL: {e}")
|
||||
|
||||
content_type = response.headers.get("Content-Type", "")
|
||||
if self._is_binary_content_type(content_type) and not _is_document_url(url):
|
||||
return self._handle_download_by_content_type(url, response, content_type)
|
||||
|
||||
response.encoding = self._detect_encoding(response)
|
||||
html = response.text
|
||||
title = self._extract_title(html)
|
||||
text = self._extract_text(html)
|
||||
|
||||
return ToolResult.success(f"Title: {title}\n\nContent:\n{text}")
|
||||
|
||||
# ---- Document fetching ----
|
||||
|
||||
def _fetch_document(self, url: str) -> ToolResult:
|
||||
"""Download a document file and extract its text content."""
|
||||
suffix = _get_url_suffix(url)
|
||||
parsed = urlparse(url)
|
||||
filename = self._extract_filename(url)
|
||||
tmp_dir = self._ensure_tmp_dir()
|
||||
|
||||
local_path = os.path.join(tmp_dir, filename)
|
||||
logger.info(f"[WebFetch] Downloading document: {url} -> {local_path}")
|
||||
|
||||
try:
|
||||
response = requests.get(
|
||||
url,
|
||||
headers=DEFAULT_HEADERS,
|
||||
timeout=DEFAULT_TIMEOUT,
|
||||
stream=True,
|
||||
allow_redirects=True,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
content_length = int(response.headers.get("Content-Length", 0))
|
||||
if content_length > MAX_FILE_SIZE:
|
||||
return ToolResult.fail(
|
||||
f"Error: File too large ({format_size(content_length)} > {format_size(MAX_FILE_SIZE)})"
|
||||
)
|
||||
|
||||
downloaded = 0
|
||||
with open(local_path, "wb") as f:
|
||||
for chunk in response.iter_content(chunk_size=8192):
|
||||
downloaded += len(chunk)
|
||||
if downloaded > MAX_FILE_SIZE:
|
||||
f.close()
|
||||
os.remove(local_path)
|
||||
return ToolResult.fail(
|
||||
f"Error: File too large (>{format_size(MAX_FILE_SIZE)}), download aborted"
|
||||
)
|
||||
f.write(chunk)
|
||||
|
||||
except requests.Timeout:
|
||||
return ToolResult.fail(f"Error: Download timed out after {DEFAULT_TIMEOUT}s")
|
||||
except requests.ConnectionError:
|
||||
return ToolResult.fail(f"Error: Failed to connect to {parsed.netloc}")
|
||||
except requests.HTTPError as e:
|
||||
return ToolResult.fail(f"Error: HTTP {e.response.status_code} for URL: {url}")
|
||||
except Exception as e:
|
||||
self._cleanup_file(local_path)
|
||||
return ToolResult.fail(f"Error: Failed to download file: {e}")
|
||||
|
||||
try:
|
||||
text = self._parse_document(local_path, suffix)
|
||||
except Exception as e:
|
||||
self._cleanup_file(local_path)
|
||||
return ToolResult.fail(f"Error: Failed to parse document: {e}")
|
||||
|
||||
if not text or not text.strip():
|
||||
file_size = os.path.getsize(local_path)
|
||||
return ToolResult.success(
|
||||
f"File downloaded to: {local_path} ({format_size(file_size)})\n"
|
||||
f"No text content could be extracted. The file may contain only images or be encrypted."
|
||||
)
|
||||
|
||||
truncation = truncate_head(text)
|
||||
result_text = truncation.content
|
||||
|
||||
file_size = os.path.getsize(local_path)
|
||||
header = f"[Document: {filename} | Size: {format_size(file_size)} | Saved to: {local_path}]\n\n"
|
||||
|
||||
if truncation.truncated:
|
||||
header += f"[Content truncated: showing {truncation.output_lines} of {truncation.total_lines} lines]\n\n"
|
||||
|
||||
return ToolResult.success(header + result_text)
|
||||
|
||||
def _parse_document(self, file_path: str, suffix: str) -> str:
|
||||
"""Parse document file and return extracted text."""
|
||||
if suffix in PDF_SUFFIXES:
|
||||
return self._parse_pdf(file_path)
|
||||
elif suffix in WORD_SUFFIXES:
|
||||
return self._parse_word(file_path)
|
||||
elif suffix in TEXT_SUFFIXES:
|
||||
return self._parse_text(file_path)
|
||||
elif suffix in SPREADSHEET_SUFFIXES:
|
||||
return self._parse_spreadsheet(file_path)
|
||||
elif suffix in PPT_SUFFIXES:
|
||||
return self._parse_ppt(file_path)
|
||||
else:
|
||||
return self._parse_text(file_path)
|
||||
|
||||
def _parse_pdf(self, file_path: str) -> str:
|
||||
"""Extract text from PDF using pypdf."""
|
||||
try:
|
||||
from pypdf import PdfReader
|
||||
except ImportError:
|
||||
raise ImportError("pypdf library is required for PDF parsing. Install with: pip install pypdf")
|
||||
|
||||
reader = PdfReader(file_path)
|
||||
text_parts = []
|
||||
for page_num, page in enumerate(reader.pages, 1):
|
||||
page_text = page.extract_text()
|
||||
if page_text and page_text.strip():
|
||||
text_parts.append(f"--- Page {page_num}/{len(reader.pages)} ---\n{page_text}")
|
||||
|
||||
return "\n\n".join(text_parts)
|
||||
|
||||
def _parse_word(self, file_path: str) -> str:
|
||||
"""Extract text from Word documents (.docx)."""
|
||||
try:
|
||||
from docx import Document
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"python-docx library is required for .docx parsing. Install with: pip install python-docx"
|
||||
)
|
||||
doc = Document(file_path)
|
||||
paragraphs = [p.text for p in doc.paragraphs if p.text.strip()]
|
||||
return "\n\n".join(paragraphs)
|
||||
|
||||
def _parse_text(self, file_path: str) -> str:
|
||||
"""Read plain text files (txt, md, csv, etc.)."""
|
||||
encodings = ["utf-8", "utf-8-sig", "gbk", "gb2312", "latin-1"]
|
||||
for enc in encodings:
|
||||
try:
|
||||
with open(file_path, "r", encoding=enc) as f:
|
||||
return f.read()
|
||||
except (UnicodeDecodeError, UnicodeError):
|
||||
continue
|
||||
raise ValueError(f"Unable to decode file with any supported encoding: {encodings}")
|
||||
|
||||
def _parse_spreadsheet(self, file_path: str) -> str:
|
||||
"""Extract text from Excel files (.xls/.xlsx)."""
|
||||
try:
|
||||
import openpyxl
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"openpyxl library is required for .xlsx parsing. Install with: pip install openpyxl"
|
||||
)
|
||||
|
||||
wb = openpyxl.load_workbook(file_path, read_only=True, data_only=True)
|
||||
result_parts = []
|
||||
|
||||
for sheet_name in wb.sheetnames:
|
||||
ws = wb[sheet_name]
|
||||
rows = []
|
||||
for row in ws.iter_rows(values_only=True):
|
||||
cells = [str(c) if c is not None else "" for c in row]
|
||||
if any(cells):
|
||||
rows.append(" | ".join(cells))
|
||||
if rows:
|
||||
result_parts.append(f"--- Sheet: {sheet_name} ---\n" + "\n".join(rows))
|
||||
|
||||
wb.close()
|
||||
return "\n\n".join(result_parts)
|
||||
|
||||
def _parse_ppt(self, file_path: str) -> str:
|
||||
"""Extract text from PowerPoint files (.ppt/.pptx)."""
|
||||
try:
|
||||
from pptx import Presentation
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"python-pptx library is required for .pptx parsing. Install with: pip install python-pptx"
|
||||
)
|
||||
|
||||
prs = Presentation(file_path)
|
||||
text_parts = []
|
||||
|
||||
for slide_num, slide in enumerate(prs.slides, 1):
|
||||
slide_texts = []
|
||||
for shape in slide.shapes:
|
||||
if shape.has_text_frame:
|
||||
for paragraph in shape.text_frame.paragraphs:
|
||||
text = paragraph.text.strip()
|
||||
if text:
|
||||
slide_texts.append(text)
|
||||
if slide_texts:
|
||||
text_parts.append(f"--- Slide {slide_num}/{len(prs.slides)} ---\n" + "\n".join(slide_texts))
|
||||
|
||||
return "\n\n".join(text_parts)
|
||||
|
||||
# ---- Encoding detection ----
|
||||
|
||||
@staticmethod
|
||||
def _detect_encoding(response: requests.Response) -> str:
|
||||
"""Detect response encoding with priority: Content-Type header > HTML meta > chardet > utf-8."""
|
||||
# 1. Check Content-Type header for explicit charset
|
||||
content_type = response.headers.get("Content-Type", "")
|
||||
charset = _extract_charset_from_content_type(content_type)
|
||||
if charset:
|
||||
return charset
|
||||
|
||||
# 2. Scan raw bytes for HTML meta charset declaration
|
||||
raw = response.content[:4096]
|
||||
charset = _extract_charset_from_html_meta(raw)
|
||||
if charset:
|
||||
return charset
|
||||
|
||||
# 3. Use apparent_encoding (chardet-based detection) if confident enough
|
||||
apparent = response.apparent_encoding
|
||||
if apparent:
|
||||
apparent_lower = apparent.lower()
|
||||
# Trust CJK / Windows encodings detected by chardet
|
||||
trusted_prefixes = ("utf", "gb", "big5", "euc", "shift_jis", "iso-2022", "windows", "ascii")
|
||||
if any(apparent_lower.startswith(p) for p in trusted_prefixes):
|
||||
return apparent
|
||||
|
||||
# 4. Fallback
|
||||
return "utf-8"
|
||||
|
||||
# ---- Helper methods ----
|
||||
|
||||
def _ensure_tmp_dir(self) -> str:
|
||||
"""Ensure workspace/tmp directory exists and return its path."""
|
||||
tmp_dir = os.path.join(self.cwd, "tmp")
|
||||
os.makedirs(tmp_dir, exist_ok=True)
|
||||
return tmp_dir
|
||||
|
||||
def _extract_filename(self, url: str) -> str:
|
||||
"""Extract a safe filename from URL, with a short UUID prefix to avoid collisions."""
|
||||
path = urlparse(url).path
|
||||
basename = os.path.basename(unquote(path))
|
||||
if not basename or basename == "/":
|
||||
basename = "downloaded_file"
|
||||
# Sanitize: keep only safe chars
|
||||
basename = re.sub(r'[^\w.\-]', '_', basename)
|
||||
short_id = uuid.uuid4().hex[:8]
|
||||
return f"{short_id}_{basename}"
|
||||
|
||||
@staticmethod
|
||||
def _cleanup_file(path: str):
|
||||
"""Remove a file if it exists, ignoring errors."""
|
||||
try:
|
||||
if os.path.exists(path):
|
||||
os.remove(path)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def _is_binary_content_type(content_type: str) -> bool:
|
||||
"""Check if Content-Type indicates a binary/document response."""
|
||||
binary_types = [
|
||||
"application/pdf",
|
||||
"application/vnd.openxmlformats",
|
||||
"application/vnd.ms-excel",
|
||||
"application/vnd.ms-powerpoint",
|
||||
"application/octet-stream",
|
||||
]
|
||||
ct_lower = content_type.lower()
|
||||
return any(bt in ct_lower for bt in binary_types)
|
||||
|
||||
def _handle_download_by_content_type(self, url: str, response: requests.Response, content_type: str) -> ToolResult:
|
||||
"""Handle a URL that returned binary content instead of HTML."""
|
||||
ct_lower = content_type.lower()
|
||||
suffix_map = {
|
||||
"application/pdf": ".pdf",
|
||||
"application/vnd.openxmlformats-officedocument.wordprocessingml": ".docx",
|
||||
"application/vnd.ms-excel": ".xls",
|
||||
"application/vnd.openxmlformats-officedocument.spreadsheetml": ".xlsx",
|
||||
"application/vnd.ms-powerpoint": ".ppt",
|
||||
"application/vnd.openxmlformats-officedocument.presentationml": ".pptx",
|
||||
}
|
||||
detected_suffix = None
|
||||
for ct_prefix, ext in suffix_map.items():
|
||||
if ct_prefix in ct_lower:
|
||||
detected_suffix = ext
|
||||
break
|
||||
|
||||
if detected_suffix and detected_suffix in ALL_DOC_SUFFIXES:
|
||||
# Re-fetch as document
|
||||
return self._fetch_document(url if _get_url_suffix(url) in ALL_DOC_SUFFIXES
|
||||
else self._rewrite_url_with_suffix(url, detected_suffix))
|
||||
return ToolResult.fail(f"Error: URL returned binary content ({content_type}), not a supported document type")
|
||||
|
||||
@staticmethod
|
||||
def _rewrite_url_with_suffix(url: str, suffix: str) -> str:
|
||||
"""Append a suffix to the URL path so _get_url_suffix works correctly."""
|
||||
parsed = urlparse(url)
|
||||
new_path = parsed.path.rstrip("/") + suffix
|
||||
return parsed._replace(path=new_path).geturl()
|
||||
|
||||
# ---- HTML extraction (unchanged) ----
|
||||
|
||||
@staticmethod
|
||||
def _extract_title(html: str) -> str:
|
||||
match = re.search(r"<title[^>]*>(.*?)</title>", html, re.IGNORECASE | re.DOTALL)
|
||||
return match.group(1).strip() if match else "Untitled"
|
||||
|
||||
@staticmethod
|
||||
def _extract_text(html: str) -> str:
|
||||
text = re.sub(r"<script[^>]*>.*?</script>", "", html, flags=re.IGNORECASE | re.DOTALL)
|
||||
text = re.sub(r"<style[^>]*>.*?</style>", "", text, flags=re.IGNORECASE | re.DOTALL)
|
||||
text = re.sub(r"<[^>]+>", "", text)
|
||||
text = text.replace("&", "&").replace("<", "<").replace(">", ">")
|
||||
text = text.replace(""", '"').replace("'", "'").replace(" ", " ")
|
||||
text = re.sub(r"[^\S\n]+", " ", text)
|
||||
text = re.sub(r"\n{3,}", "\n\n", text)
|
||||
lines = [line.strip() for line in text.splitlines()]
|
||||
text = "\n".join(lines)
|
||||
return text.strip()
|
||||
3
agent/tools/web_search/__init__.py
Normal file
3
agent/tools/web_search/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from agent.tools.web_search.web_search import WebSearch
|
||||
|
||||
__all__ = ["WebSearch"]
|
||||
487
agent/tools/web_search/web_search.py
Normal file
487
agent/tools/web_search/web_search.py
Normal file
@@ -0,0 +1,487 @@
|
||||
"""Web Search tool. Supports four backends with a unified response format:
|
||||
- bocha (https://open.bochaai.com)
|
||||
- zhipu (https://docs.bigmodel.cn/cn/guide/tools/web-search)
|
||||
- qianfan (https://cloud.baidu.com/doc/qianfan/s/2mh4su4uy)
|
||||
- linkai (https://link-ai.tech, fallback)
|
||||
|
||||
Provider selection
|
||||
- strategy 'auto' (default): pick the first configured provider in the
|
||||
canonical order [bocha, zhipu, qianfan, linkai]. When the caller passes
|
||||
an explicit `provider` it overrides the pick; an invalid/unconfigured
|
||||
one silently falls back to the auto order.
|
||||
- strategy 'fixed': use the configured provider; if its credential is
|
||||
missing at call time, silently fall back to auto order (no card hint).
|
||||
|
||||
Credentials
|
||||
- bocha : tools.web_search.bocha_api_key -> env BOCHA_API_KEY
|
||||
- zhipu : conf.zhipu_ai_api_key -> env ZHIPUAI_API_KEY
|
||||
- qianfan : conf.qianfan_api_key -> env QIANFAN_API_KEY
|
||||
- linkai : conf.linkai_api_key -> env LINKAI_API_KEY
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import requests
|
||||
|
||||
from agent.tools.base_tool import BaseTool, ToolResult
|
||||
from common.log import logger
|
||||
from config import conf
|
||||
|
||||
|
||||
DEFAULT_TIMEOUT = 30
|
||||
|
||||
# Canonical fallback order. Empirically ordered by Chinese real-time
|
||||
# quality + relevance: bocha (best overall), qianfan (best for hot news),
|
||||
# zhipu (strong on long-form articles), linkai (cloud aggregator, last
|
||||
# resort).
|
||||
PROVIDER_ORDER = ("bocha", "qianfan", "zhipu", "linkai")
|
||||
|
||||
PROVIDER_LABELS = {
|
||||
"bocha": "Bocha",
|
||||
"zhipu": "Zhipu",
|
||||
"qianfan": "Baidu Qianfan",
|
||||
"linkai": "LinkAI",
|
||||
}
|
||||
|
||||
|
||||
def _tools_web_search_conf() -> dict:
|
||||
"""Return the tools.web_search config block (dict-like)."""
|
||||
tools_cfg = conf().get("tools") or {}
|
||||
if not isinstance(tools_cfg, dict):
|
||||
return {}
|
||||
block = tools_cfg.get("web_search") or {}
|
||||
return block if isinstance(block, dict) else {}
|
||||
|
||||
|
||||
def _get_api_key(provider: str) -> str:
|
||||
"""Resolve API key for a provider, with conf -> env fallback."""
|
||||
if provider == "bocha":
|
||||
key = (_tools_web_search_conf().get("bocha_api_key") or "").strip()
|
||||
return key or os.environ.get("BOCHA_API_KEY", "").strip()
|
||||
if provider == "zhipu":
|
||||
key = (conf().get("zhipu_ai_api_key") or "").strip()
|
||||
return key or os.environ.get("ZHIPUAI_API_KEY", "").strip()
|
||||
if provider == "qianfan":
|
||||
key = (conf().get("qianfan_api_key") or "").strip()
|
||||
return key or os.environ.get("QIANFAN_API_KEY", "").strip()
|
||||
if provider == "linkai":
|
||||
key = (conf().get("linkai_api_key") or "").strip()
|
||||
return key or os.environ.get("LINKAI_API_KEY", "").strip()
|
||||
return ""
|
||||
|
||||
|
||||
def configured_providers() -> List[str]:
|
||||
"""Return configured providers in canonical order."""
|
||||
return [p for p in PROVIDER_ORDER if _get_api_key(p)]
|
||||
|
||||
|
||||
def _configured_strategy() -> str:
|
||||
return (_tools_web_search_conf().get("strategy") or "auto").strip().lower()
|
||||
|
||||
|
||||
def _configured_provider() -> str:
|
||||
return (_tools_web_search_conf().get("provider") or "").strip().lower()
|
||||
|
||||
|
||||
class WebSearch(BaseTool):
|
||||
"""Tool for searching the web across multiple providers."""
|
||||
|
||||
name: str = "web_search"
|
||||
description: str = "Search the web for real-time information. Returns titles, URLs, and snippets."
|
||||
|
||||
params: dict = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "Search query string"
|
||||
},
|
||||
"count": {
|
||||
"type": "integer",
|
||||
"description": "Number of results to return (1-50, default: 10)"
|
||||
},
|
||||
"freshness": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Time range filter. Options: "
|
||||
"'noLimit' (default), 'oneDay', 'oneWeek', 'oneMonth', 'oneYear', "
|
||||
"or date range like '2025-01-01..2025-02-01'"
|
||||
)
|
||||
},
|
||||
"summary": {
|
||||
"type": "boolean",
|
||||
"description": "Whether to include text summary for each result (default: false)"
|
||||
}
|
||||
},
|
||||
"required": ["query"]
|
||||
}
|
||||
|
||||
def __init__(self, config: dict = None):
|
||||
self.config = config or {}
|
||||
|
||||
@staticmethod
|
||||
def is_available() -> bool:
|
||||
"""Tool is offered to the agent when at least one provider has a key."""
|
||||
return bool(configured_providers())
|
||||
|
||||
@classmethod
|
||||
def get_json_schema(cls) -> dict:
|
||||
"""Augment the static schema with a `provider` field — only when the
|
||||
user has ≥2 providers configured AND strategy is 'auto'. Otherwise
|
||||
the backend picks silently and exposing the field would only waste
|
||||
the agent's tokens."""
|
||||
schema = {
|
||||
"name": cls.name,
|
||||
"description": cls.description,
|
||||
"parameters": json.loads(json.dumps(cls.params)), # deep copy
|
||||
}
|
||||
if _configured_strategy() != "auto":
|
||||
return schema
|
||||
available = configured_providers()
|
||||
if len(available) < 2:
|
||||
return schema
|
||||
|
||||
schema["parameters"]["properties"]["provider"] = {
|
||||
"type": "string",
|
||||
"enum": available,
|
||||
"description": "Optional. Specifies the search backend. You may switch between providers when the user wants results from a particular source or from multiple sources.",
|
||||
}
|
||||
return schema
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Provider resolution
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _resolve_provider(self, requested: Optional[str]) -> Optional[str]:
|
||||
"""Pick a provider for this call.
|
||||
|
||||
Priority: caller-supplied (if configured) > fixed strategy (if
|
||||
configured) > first configured in PROVIDER_ORDER. Silent fallback
|
||||
when the desired one has no key.
|
||||
"""
|
||||
available = configured_providers()
|
||||
if not available:
|
||||
return None
|
||||
|
||||
if requested:
|
||||
req = requested.strip().lower()
|
||||
if req in available:
|
||||
return req
|
||||
logger.warning(f"[WebSearch] requested provider '{requested}' unavailable, falling back")
|
||||
|
||||
if _configured_strategy() == "fixed":
|
||||
pinned = _configured_provider()
|
||||
if pinned in available:
|
||||
return pinned
|
||||
if pinned:
|
||||
logger.warning(f"[WebSearch] pinned provider '{pinned}' unavailable, falling back to auto")
|
||||
|
||||
return available[0]
|
||||
|
||||
@staticmethod
|
||||
def _resolution_reason(requested: Optional[str], chosen: str) -> str:
|
||||
"""Human-readable explanation for why `chosen` won the resolver."""
|
||||
if requested and requested.strip().lower() == chosen:
|
||||
return "caller-requested"
|
||||
strategy = _configured_strategy()
|
||||
if strategy == "fixed" and _configured_provider() == chosen:
|
||||
return "fixed-strategy"
|
||||
return "auto-fallback"
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Entry point
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def execute(self, args: Dict[str, Any]) -> ToolResult:
|
||||
query = (args.get("query") or "").strip()
|
||||
if not query:
|
||||
return ToolResult.fail("Error: 'query' parameter is required")
|
||||
|
||||
count = args.get("count", 10)
|
||||
freshness = args.get("freshness", "noLimit")
|
||||
summary = args.get("summary", False)
|
||||
if not isinstance(count, int) or count < 1 or count > 50:
|
||||
count = 10
|
||||
|
||||
requested = args.get("provider")
|
||||
provider = self._resolve_provider(requested)
|
||||
if not provider:
|
||||
return ToolResult.fail(
|
||||
"Error: No search provider configured. "
|
||||
"Configure one of BOCHA_API_KEY / zhipu_ai_api_key / qianfan_api_key / linkai_api_key."
|
||||
)
|
||||
|
||||
# Always log the routing decision so multi-provider deployments can
|
||||
# tell at a glance which backend served any given query.
|
||||
available = configured_providers()
|
||||
reason = self._resolution_reason(requested, provider)
|
||||
q_preview = query if len(query) <= 60 else (query[:57] + "...")
|
||||
logger.info(
|
||||
f"[WebSearch] provider={provider} reason={reason} "
|
||||
f"available={list(available)} query={q_preview!r} count={count} freshness={freshness}"
|
||||
)
|
||||
|
||||
try:
|
||||
if provider == "bocha":
|
||||
return self._search_bocha(query, count, freshness, summary)
|
||||
if provider == "zhipu":
|
||||
return self._search_zhipu(query, count, freshness)
|
||||
if provider == "qianfan":
|
||||
return self._search_qianfan(query, count, freshness)
|
||||
if provider == "linkai":
|
||||
return self._search_linkai(query, count, freshness)
|
||||
return ToolResult.fail(f"Error: Unknown provider '{provider}'")
|
||||
except requests.Timeout:
|
||||
return ToolResult.fail(f"Error: Search request timed out after {DEFAULT_TIMEOUT}s")
|
||||
except requests.ConnectionError:
|
||||
return ToolResult.fail("Error: Failed to connect to search API")
|
||||
except Exception as e:
|
||||
logger.error(f"[WebSearch] Unexpected error ({provider}): {e}", exc_info=True)
|
||||
return ToolResult.fail(f"Error: Search failed - {str(e)}")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Bocha
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _search_bocha(self, query: str, count: int, freshness: str, summary: bool) -> ToolResult:
|
||||
api_key = _get_api_key("bocha")
|
||||
url = "https://api.bochaai.com/v1/web-search"
|
||||
headers = {
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json",
|
||||
}
|
||||
payload = {"query": query, "count": count, "freshness": freshness, "summary": summary}
|
||||
|
||||
logger.debug(f"[WebSearch] bocha: query='{query}', count={count}")
|
||||
resp = requests.post(url, headers=headers, json=payload, timeout=DEFAULT_TIMEOUT)
|
||||
|
||||
if resp.status_code == 401:
|
||||
return ToolResult.fail("Error: Invalid bocha API key.")
|
||||
if resp.status_code == 403:
|
||||
return ToolResult.fail("Error: bocha API — insufficient balance. Top up at https://open.bochaai.com")
|
||||
if resp.status_code == 429:
|
||||
return ToolResult.fail("Error: bocha API rate limit reached.")
|
||||
if resp.status_code != 200:
|
||||
return ToolResult.fail(f"Error: bocha API returned HTTP {resp.status_code}")
|
||||
|
||||
data = resp.json()
|
||||
api_code = data.get("code")
|
||||
if api_code is not None and api_code != 200:
|
||||
msg = data.get("msg") or "Unknown error"
|
||||
return ToolResult.fail(f"Error: bocha API error (code={api_code}): {msg}")
|
||||
|
||||
pages = (data.get("data") or {}).get("webPages", {}).get("value", []) or []
|
||||
results = []
|
||||
for p in pages:
|
||||
item = {
|
||||
"title": p.get("name", ""),
|
||||
"url": p.get("url", ""),
|
||||
"snippet": p.get("snippet", ""),
|
||||
"siteName": p.get("siteName", ""),
|
||||
"datePublished": p.get("datePublished") or p.get("dateLastCrawled", ""),
|
||||
}
|
||||
if p.get("summary"):
|
||||
item["summary"] = p["summary"]
|
||||
results.append(item)
|
||||
total = (data.get("data") or {}).get("webPages", {}).get("totalEstimatedMatches", len(results))
|
||||
return ToolResult.success({
|
||||
"query": query, "backend": "bocha",
|
||||
"total": total, "count": len(results), "results": results,
|
||||
})
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Zhipu
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _search_zhipu(self, query: str, count: int, freshness: str) -> ToolResult:
|
||||
api_key = _get_api_key("zhipu")
|
||||
api_base = (conf().get("zhipu_ai_api_base") or "https://open.bigmodel.cn/api/paas/v4").rstrip("/")
|
||||
url = f"{api_base}/web_search"
|
||||
headers = {
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
# Zhipu Web Search expects `search_query` <= 70 chars; truncate
|
||||
# gracefully so a long agent-supplied query doesn't get rejected.
|
||||
trimmed_query = (query or "")[:70]
|
||||
engine = (_tools_web_search_conf().get("zhipu_search_engine") or "search_pro").strip().lower()
|
||||
if engine not in ("search_std", "search_pro", "search_pro_sogou", "search_pro_quark"):
|
||||
engine = "search_pro"
|
||||
|
||||
payload: Dict[str, Any] = {
|
||||
"search_engine": engine,
|
||||
"search_query": trimmed_query,
|
||||
"search_intent": False,
|
||||
"count": max(1, min(int(count or 10), 50)),
|
||||
"search_recency_filter": freshness if freshness in (
|
||||
"oneDay", "oneWeek", "oneMonth", "oneYear", "noLimit"
|
||||
) else "noLimit",
|
||||
}
|
||||
content_size = (_tools_web_search_conf().get("zhipu_content_size") or "").strip().lower()
|
||||
if content_size in ("medium", "high"):
|
||||
payload["content_size"] = content_size
|
||||
|
||||
logger.debug(f"[WebSearch] zhipu: query='{trimmed_query}', count={payload['count']}, engine={engine}")
|
||||
resp = requests.post(url, headers=headers, json=payload, timeout=DEFAULT_TIMEOUT)
|
||||
|
||||
if resp.status_code == 401:
|
||||
return ToolResult.fail("Error: Invalid Zhipu API key.")
|
||||
if resp.status_code != 200:
|
||||
return ToolResult.fail(f"Error: Zhipu API returned HTTP {resp.status_code}: {resp.text[:200]}")
|
||||
|
||||
data = resp.json()
|
||||
# Business-level errors (1701/1702/1703 etc.) come back as
|
||||
# {"error": {"code","message"}} even on HTTP 200.
|
||||
if isinstance(data, dict) and data.get("error"):
|
||||
err = data["error"] or {}
|
||||
return ToolResult.fail(f"Error: Zhipu returned {err.get('code')}: {err.get('message','')}")
|
||||
|
||||
items = data.get("search_result") or (data.get("data") or {}).get("search_result") or []
|
||||
results = []
|
||||
for it in items:
|
||||
results.append({
|
||||
"title": it.get("title", ""),
|
||||
"url": it.get("link") or it.get("url", ""),
|
||||
"snippet": it.get("content") or it.get("snippet", ""),
|
||||
"siteName": it.get("media") or it.get("siteName", ""),
|
||||
"datePublished": it.get("publish_date") or it.get("datePublished", ""),
|
||||
})
|
||||
return ToolResult.success({
|
||||
"query": query, "backend": "zhipu",
|
||||
"total": len(results), "count": len(results), "results": results,
|
||||
})
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Qianfan (Baidu)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _search_qianfan(self, query: str, count: int, freshness: str) -> ToolResult:
|
||||
api_key = _get_api_key("qianfan")
|
||||
api_base = (conf().get("qianfan_api_base") or "https://qianfan.baidubce.com/v2").rstrip("/")
|
||||
url = f"{api_base}/ai_search/web_search"
|
||||
headers = {
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"Content-Type": "application/json",
|
||||
"X-Appbuilder-From": "cow",
|
||||
}
|
||||
|
||||
count = max(1, min(int(count or 10), 50))
|
||||
payload: Dict[str, Any] = {
|
||||
"messages": [{"role": "user", "content": query}],
|
||||
"search_source": "baidu_search_v2",
|
||||
"resource_type_filter": [{"type": "web", "top_k": count}],
|
||||
}
|
||||
|
||||
# Baidu AI Search expects freshness as a date-range filter, not a
|
||||
# named recency token. Translate our shared vocabulary into the
|
||||
# underlying page_time range expected by the API.
|
||||
search_filter = self._qianfan_build_freshness_filter(freshness)
|
||||
if search_filter:
|
||||
payload["search_filter"] = search_filter
|
||||
|
||||
logger.debug(f"[WebSearch] qianfan: query='{query}', count={count}, freshness={freshness!r}")
|
||||
resp = requests.post(url, headers=headers, json=payload, timeout=DEFAULT_TIMEOUT)
|
||||
|
||||
if resp.status_code == 401:
|
||||
return ToolResult.fail("Error: Invalid Qianfan API key.")
|
||||
if resp.status_code != 200:
|
||||
return ToolResult.fail(f"Error: Qianfan API returned HTTP {resp.status_code}: {resp.text[:200]}")
|
||||
|
||||
data = resp.json()
|
||||
# Even on HTTP 200 Baidu surfaces business errors as {"code","message"}.
|
||||
if isinstance(data, dict) and data.get("code"):
|
||||
return ToolResult.fail(f"Error: Qianfan returned {data.get('code')}: {data.get('message','')}")
|
||||
|
||||
refs = data.get("references") or []
|
||||
results = []
|
||||
for d in refs:
|
||||
results.append({
|
||||
"title": d.get("title", ""),
|
||||
"url": d.get("url", ""),
|
||||
"snippet": (d.get("content") or "")[:200],
|
||||
"siteName": d.get("web_anchor") or d.get("website") or "",
|
||||
"datePublished": d.get("date", ""),
|
||||
})
|
||||
return ToolResult.success({
|
||||
"query": query, "backend": "qianfan",
|
||||
"total": len(results), "count": len(results), "results": results,
|
||||
})
|
||||
|
||||
@staticmethod
|
||||
def _qianfan_build_freshness_filter(freshness: str) -> Optional[Dict[str, Any]]:
|
||||
if not freshness or freshness == "noLimit":
|
||||
return None
|
||||
delta_days = {"oneDay": 1, "oneWeek": 7, "oneMonth": 30, "oneYear": 365}.get(freshness)
|
||||
if not delta_days:
|
||||
return None
|
||||
from datetime import datetime, timedelta
|
||||
now = datetime.now()
|
||||
end_date = (now + timedelta(days=1)).strftime("%Y-%m-%d")
|
||||
start_date = (now - timedelta(days=delta_days)).strftime("%Y-%m-%d")
|
||||
return {"range": {"page_time": {"gte": start_date, "lt": end_date}}}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# LinkAI (plugin)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _search_linkai(self, query: str, count: int, freshness: str) -> ToolResult:
|
||||
api_key = _get_api_key("linkai")
|
||||
api_base = (conf().get("linkai_api_base") or "https://api.link-ai.tech").rstrip("/")
|
||||
url = f"{api_base}/v1/plugin/execute"
|
||||
|
||||
from common.utils import get_cloud_headers
|
||||
headers = get_cloud_headers(api_key)
|
||||
|
||||
payload = {"code": "web-search", "args": {"query": query, "count": count, "freshness": freshness}}
|
||||
logger.debug(f"[WebSearch] linkai: query='{query}', count={count}")
|
||||
resp = requests.post(url, headers=headers, json=payload, timeout=DEFAULT_TIMEOUT)
|
||||
|
||||
if resp.status_code == 401:
|
||||
return ToolResult.fail("Error: Invalid LinkAI API key.")
|
||||
if resp.status_code != 200:
|
||||
return ToolResult.fail(f"Error: LinkAI API returned HTTP {resp.status_code}")
|
||||
|
||||
data = resp.json()
|
||||
if not data.get("success"):
|
||||
msg = data.get("message") or "Unknown error"
|
||||
return ToolResult.fail(f"Error: LinkAI search failed: {msg}")
|
||||
|
||||
raw = data.get("data", "")
|
||||
if isinstance(raw, str):
|
||||
try:
|
||||
raw = json.loads(raw)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
return ToolResult.success({
|
||||
"query": query, "backend": "linkai",
|
||||
"total": 1, "count": 1, "results": [{"content": raw}],
|
||||
})
|
||||
|
||||
if isinstance(raw, dict):
|
||||
pages = (raw.get("webPages") or {}).get("value", []) or []
|
||||
if pages:
|
||||
results = []
|
||||
for p in pages:
|
||||
item = {
|
||||
"title": p.get("name", ""),
|
||||
"url": p.get("url", ""),
|
||||
"snippet": p.get("snippet", ""),
|
||||
"siteName": p.get("siteName", ""),
|
||||
"datePublished": p.get("datePublished") or p.get("dateLastCrawled", ""),
|
||||
}
|
||||
if p.get("summary"):
|
||||
item["summary"] = p["summary"]
|
||||
results.append(item)
|
||||
total = (raw.get("webPages") or {}).get("totalEstimatedMatches", len(results))
|
||||
return ToolResult.success({
|
||||
"query": query, "backend": "linkai",
|
||||
"total": total, "count": len(results), "results": results,
|
||||
})
|
||||
|
||||
return ToolResult.success({
|
||||
"query": query, "backend": "linkai",
|
||||
"total": 1, "count": 1, "results": [{"content": str(raw)}],
|
||||
})
|
||||
@@ -8,6 +8,7 @@ from typing import Dict, Any
|
||||
from pathlib import Path
|
||||
|
||||
from agent.tools.base_tool import BaseTool, ToolResult
|
||||
from common.utils import expand_path
|
||||
|
||||
|
||||
class Write(BaseTool):
|
||||
@@ -90,7 +91,7 @@ class Write(BaseTool):
|
||||
:return: Absolute path
|
||||
"""
|
||||
# Expand ~ to user home directory
|
||||
path = os.path.expanduser(path)
|
||||
path = expand_path(path)
|
||||
if os.path.isabs(path):
|
||||
return path
|
||||
return os.path.abspath(os.path.join(self.cwd, path))
|
||||
|
||||
370
app.py
370
app.py
@@ -7,11 +7,261 @@ import time
|
||||
|
||||
from channel import channel_factory
|
||||
from common import const
|
||||
from config import load_config
|
||||
from common.log import logger
|
||||
from config import load_config, conf
|
||||
from plugins import *
|
||||
import threading
|
||||
|
||||
|
||||
_channel_mgr = None
|
||||
|
||||
|
||||
def get_channel_manager():
|
||||
return _channel_mgr
|
||||
|
||||
|
||||
def _parse_channel_type(raw) -> list:
|
||||
"""
|
||||
Parse channel_type config value into a list of channel names.
|
||||
Supports:
|
||||
- single string: "feishu"
|
||||
- comma-separated string: "feishu, dingtalk"
|
||||
- list: ["feishu", "dingtalk"]
|
||||
"""
|
||||
if isinstance(raw, list):
|
||||
return [ch.strip() for ch in raw if ch.strip()]
|
||||
if isinstance(raw, str):
|
||||
return [ch.strip() for ch in raw.split(",") if ch.strip()]
|
||||
return []
|
||||
|
||||
|
||||
class ChannelManager:
|
||||
"""
|
||||
Manage the lifecycle of multiple channels running concurrently.
|
||||
Each channel.startup() runs in its own daemon thread.
|
||||
The web channel is started as default console unless explicitly disabled.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._channels = {} # channel_name -> channel instance
|
||||
self._threads = {} # channel_name -> thread
|
||||
self._primary_channel = None
|
||||
self._lock = threading.Lock()
|
||||
self.cloud_mode = False # set to True when cloud client is active
|
||||
|
||||
@property
|
||||
def channel(self):
|
||||
"""Return the primary (first non-web) channel for backward compatibility."""
|
||||
return self._primary_channel
|
||||
|
||||
def get_channel(self, channel_name: str):
|
||||
return self._channels.get(channel_name)
|
||||
|
||||
def start(self, channel_names: list, first_start: bool = False):
|
||||
"""
|
||||
Create and start one or more channels in sub-threads.
|
||||
If first_start is True, plugins and linkai client will also be initialized.
|
||||
"""
|
||||
with self._lock:
|
||||
channels = []
|
||||
for name in channel_names:
|
||||
ch = channel_factory.create_channel(name)
|
||||
ch.cloud_mode = self.cloud_mode
|
||||
self._channels[name] = ch
|
||||
channels.append((name, ch))
|
||||
if self._primary_channel is None and name != "web":
|
||||
self._primary_channel = ch
|
||||
|
||||
if self._primary_channel is None and channels:
|
||||
self._primary_channel = channels[0][1]
|
||||
|
||||
if first_start:
|
||||
PluginManager().load_plugins()
|
||||
|
||||
# Cloud client is optional. It is only started when
|
||||
# use_linkai=True AND cloud_deployment_id is set.
|
||||
# By default neither is configured, so the app runs
|
||||
# entirely locally without any remote connection.
|
||||
if conf().get("use_linkai") and (
|
||||
os.environ.get("CLOUD_DEPLOYMENT_ID") or conf().get("cloud_deployment_id")
|
||||
):
|
||||
try:
|
||||
from common import cloud_client
|
||||
threading.Thread(
|
||||
target=cloud_client.start,
|
||||
args=(self._primary_channel, self),
|
||||
daemon=True,
|
||||
).start()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Start web console first so its logs print cleanly,
|
||||
# then start remaining channels after a brief pause.
|
||||
web_entry = None
|
||||
other_entries = []
|
||||
for entry in channels:
|
||||
if entry[0] == "web":
|
||||
web_entry = entry
|
||||
else:
|
||||
other_entries.append(entry)
|
||||
|
||||
ordered = ([web_entry] if web_entry else []) + other_entries
|
||||
for i, (name, ch) in enumerate(ordered):
|
||||
if i > 0 and name != "web":
|
||||
time.sleep(0.1)
|
||||
t = threading.Thread(target=self._run_channel, args=(name, ch), daemon=True)
|
||||
self._threads[name] = t
|
||||
t.start()
|
||||
logger.debug(f"[ChannelManager] Channel '{name}' started in sub-thread")
|
||||
|
||||
def _run_channel(self, name: str, channel):
|
||||
try:
|
||||
channel.startup()
|
||||
except Exception as e:
|
||||
logger.error(f"[ChannelManager] Channel '{name}' startup error: {e}")
|
||||
logger.exception(e)
|
||||
|
||||
def stop(self, channel_name: str = None):
|
||||
"""
|
||||
Stop channel(s). If channel_name is given, stop only that channel;
|
||||
otherwise stop all channels.
|
||||
"""
|
||||
# Pop under lock, then stop outside lock to avoid deadlock
|
||||
with self._lock:
|
||||
names = [channel_name] if channel_name else list(self._channels.keys())
|
||||
to_stop = []
|
||||
for name in names:
|
||||
ch = self._channels.pop(name, None)
|
||||
th = self._threads.pop(name, None)
|
||||
to_stop.append((name, ch, th))
|
||||
if channel_name and self._primary_channel is self._channels.get(channel_name):
|
||||
self._primary_channel = None
|
||||
|
||||
for name, ch, th in to_stop:
|
||||
if ch is None:
|
||||
logger.warning(f"[ChannelManager] Channel '{name}' not found in managed channels")
|
||||
if th and th.is_alive():
|
||||
self._interrupt_thread(th, name)
|
||||
continue
|
||||
logger.info(f"[ChannelManager] Stopping channel '{name}'...")
|
||||
graceful = False
|
||||
if hasattr(ch, 'stop'):
|
||||
try:
|
||||
ch.stop()
|
||||
graceful = True
|
||||
except Exception as e:
|
||||
logger.warning(f"[ChannelManager] Error during channel '{name}' stop: {e}")
|
||||
if th and th.is_alive():
|
||||
th.join(timeout=5)
|
||||
if th.is_alive():
|
||||
if graceful:
|
||||
logger.info(f"[ChannelManager] Channel '{name}' thread still alive after stop(), "
|
||||
"leaving daemon thread to finish on its own")
|
||||
else:
|
||||
logger.warning(f"[ChannelManager] Channel '{name}' thread did not exit in 5s, forcing interrupt")
|
||||
self._interrupt_thread(th, name)
|
||||
|
||||
@staticmethod
|
||||
def _interrupt_thread(th: threading.Thread, name: str):
|
||||
"""Raise SystemExit in target thread to break blocking loops like start_forever."""
|
||||
import ctypes
|
||||
try:
|
||||
tid = th.ident
|
||||
if tid is None:
|
||||
return
|
||||
res = ctypes.pythonapi.PyThreadState_SetAsyncExc(
|
||||
ctypes.c_ulong(tid), ctypes.py_object(SystemExit)
|
||||
)
|
||||
if res == 1:
|
||||
logger.info(f"[ChannelManager] Interrupted thread for channel '{name}'")
|
||||
elif res > 1:
|
||||
ctypes.pythonapi.PyThreadState_SetAsyncExc(ctypes.c_ulong(tid), None)
|
||||
logger.warning(f"[ChannelManager] Failed to interrupt thread for channel '{name}'")
|
||||
except Exception as e:
|
||||
logger.warning(f"[ChannelManager] Thread interrupt error for '{name}': {e}")
|
||||
|
||||
def restart(self, new_channel_name: str):
|
||||
"""
|
||||
Restart a single channel with a new channel type.
|
||||
Can be called from any thread (e.g. linkai config callback).
|
||||
"""
|
||||
logger.info(f"[ChannelManager] Restarting channel to '{new_channel_name}'...")
|
||||
self.stop(new_channel_name)
|
||||
_clear_singleton_cache(new_channel_name)
|
||||
time.sleep(1)
|
||||
self.start([new_channel_name], first_start=False)
|
||||
logger.info(f"[ChannelManager] Channel restarted to '{new_channel_name}' successfully")
|
||||
|
||||
def add_channel(self, channel_name: str):
|
||||
"""
|
||||
Dynamically add and start a new channel.
|
||||
If the channel is already running, restart it instead.
|
||||
"""
|
||||
with self._lock:
|
||||
if channel_name in self._channels:
|
||||
logger.info(f"[ChannelManager] Channel '{channel_name}' already exists, restarting")
|
||||
if self._channels.get(channel_name):
|
||||
self.restart(channel_name)
|
||||
return
|
||||
logger.info(f"[ChannelManager] Adding channel '{channel_name}'...")
|
||||
_clear_singleton_cache(channel_name)
|
||||
self.start([channel_name], first_start=False)
|
||||
logger.info(f"[ChannelManager] Channel '{channel_name}' added successfully")
|
||||
|
||||
def remove_channel(self, channel_name: str):
|
||||
"""
|
||||
Dynamically stop and remove a running channel.
|
||||
"""
|
||||
with self._lock:
|
||||
if channel_name not in self._channels:
|
||||
logger.warning(f"[ChannelManager] Channel '{channel_name}' not found, nothing to remove")
|
||||
return
|
||||
logger.info(f"[ChannelManager] Removing channel '{channel_name}'...")
|
||||
self.stop(channel_name)
|
||||
logger.info(f"[ChannelManager] Channel '{channel_name}' removed successfully")
|
||||
|
||||
|
||||
def _clear_singleton_cache(channel_name: str):
|
||||
"""
|
||||
Clear the singleton cache for the channel class so that
|
||||
a new instance can be created with updated config.
|
||||
"""
|
||||
cls_map = {
|
||||
"web": "channel.web.web_channel.WebChannel",
|
||||
"wechatmp": "channel.wechatmp.wechatmp_channel.WechatMPChannel",
|
||||
"wechatmp_service": "channel.wechatmp.wechatmp_channel.WechatMPChannel",
|
||||
"wechatcom_app": "channel.wechatcom.wechatcomapp_channel.WechatComAppChannel",
|
||||
const.WECHAT_KF: "channel.wechat_kf.wechat_kf_channel.WechatKfChannel",
|
||||
const.FEISHU: "channel.feishu.feishu_channel.FeiShuChanel",
|
||||
const.DINGTALK: "channel.dingtalk.dingtalk_channel.DingTalkChanel",
|
||||
const.WECOM_BOT: "channel.wecom_bot.wecom_bot_channel.WecomBotChannel",
|
||||
const.QQ: "channel.qq.qq_channel.QQChannel",
|
||||
const.WEIXIN: "channel.weixin.weixin_channel.WeixinChannel",
|
||||
"wx": "channel.weixin.weixin_channel.WeixinChannel",
|
||||
}
|
||||
module_path = cls_map.get(channel_name)
|
||||
if not module_path:
|
||||
return
|
||||
try:
|
||||
parts = module_path.rsplit(".", 1)
|
||||
module_name, class_name = parts[0], parts[1]
|
||||
import importlib
|
||||
module = importlib.import_module(module_name)
|
||||
wrapper = getattr(module, class_name, None)
|
||||
if wrapper and hasattr(wrapper, '__closure__') and wrapper.__closure__:
|
||||
for cell in wrapper.__closure__:
|
||||
try:
|
||||
cell_contents = cell.cell_contents
|
||||
if isinstance(cell_contents, dict):
|
||||
cell_contents.clear()
|
||||
logger.debug(f"[ChannelManager] Cleared singleton cache for {class_name}")
|
||||
break
|
||||
except ValueError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.warning(f"[ChannelManager] Failed to clear singleton cache: {e}")
|
||||
|
||||
|
||||
def sigterm_handler_wrap(_signo):
|
||||
old_handler = signal.getsignal(_signo)
|
||||
|
||||
@@ -25,22 +275,65 @@ def sigterm_handler_wrap(_signo):
|
||||
signal.signal(_signo, func)
|
||||
|
||||
|
||||
def start_channel(channel_name: str):
|
||||
channel = channel_factory.create_channel(channel_name)
|
||||
if channel_name in ["wx", "wxy", "terminal", "wechatmp", "web", "wechatmp_service", "wechatcom_app", "wework",
|
||||
const.FEISHU, const.DINGTALK]:
|
||||
PluginManager().load_plugins()
|
||||
def _warmup_mcp_tools():
|
||||
"""
|
||||
Kick off MCP server loading at process startup so subprocesses
|
||||
(npx / uvx etc.) finish initializing before the first user message
|
||||
arrives. Returns immediately — the actual work happens on a daemon
|
||||
thread inside ToolManager. Safe to call when MCP is not configured.
|
||||
"""
|
||||
try:
|
||||
from agent.tools import ToolManager
|
||||
ToolManager()._load_mcp_tools()
|
||||
except Exception as e:
|
||||
logger.warning(f"[App] MCP warmup failed (non-fatal): {e}")
|
||||
|
||||
if conf().get("use_linkai"):
|
||||
try:
|
||||
from common import linkai_client
|
||||
threading.Thread(target=linkai_client.start, args=(channel,)).start()
|
||||
except Exception as e:
|
||||
pass
|
||||
channel.startup()
|
||||
|
||||
def _warmup_scheduler():
|
||||
"""Eager-init AgentBridge so the scheduler thread starts at process
|
||||
boot rather than waiting for the first user message."""
|
||||
try:
|
||||
from bridge.bridge import Bridge
|
||||
Bridge().get_agent_bridge()
|
||||
except Exception as e:
|
||||
logger.warning(f"[App] Scheduler warmup failed: {e}")
|
||||
|
||||
|
||||
def _sync_builtin_skills():
|
||||
"""Sync builtin skills from project skills/ to workspace skills/ on startup."""
|
||||
import shutil
|
||||
try:
|
||||
workspace = conf().get("agent_workspace", "~/cow")
|
||||
workspace = os.path.expanduser(workspace)
|
||||
project_root = os.path.dirname(os.path.abspath(__file__))
|
||||
builtin_dir = os.path.join(project_root, "skills")
|
||||
custom_dir = os.path.join(workspace, "skills")
|
||||
|
||||
if not os.path.isdir(builtin_dir):
|
||||
return
|
||||
|
||||
os.makedirs(custom_dir, exist_ok=True)
|
||||
synced = 0
|
||||
for name in os.listdir(builtin_dir):
|
||||
src = os.path.join(builtin_dir, name)
|
||||
if not os.path.isdir(src) or not os.path.isfile(os.path.join(src, "SKILL.md")):
|
||||
continue
|
||||
dst = os.path.join(custom_dir, name)
|
||||
try:
|
||||
if os.path.isdir(dst):
|
||||
shutil.rmtree(dst)
|
||||
shutil.copytree(src, dst)
|
||||
synced += 1
|
||||
except Exception as e:
|
||||
logger.warning(f"[App] Failed to sync builtin skill '{name}': {e}")
|
||||
if synced:
|
||||
logger.info(f"[App] Synced {synced} builtin skill(s) to workspace")
|
||||
except Exception as e:
|
||||
logger.warning(f"[App] Builtin skills sync failed: {e}")
|
||||
|
||||
|
||||
def run():
|
||||
global _channel_mgr
|
||||
try:
|
||||
# load config
|
||||
load_config()
|
||||
@@ -49,36 +342,39 @@ def run():
|
||||
# kill signal
|
||||
sigterm_handler_wrap(signal.SIGTERM)
|
||||
|
||||
# create channel
|
||||
channel_name = conf().get("channel_type", "wx")
|
||||
# Parse channel_type into a list
|
||||
raw_channel = conf().get("channel_type", "web")
|
||||
|
||||
if "--cmd" in sys.argv:
|
||||
channel_name = "terminal"
|
||||
|
||||
if channel_name == "wxy":
|
||||
os.environ["WECHATY_LOG"] = "warn"
|
||||
|
||||
start_channel(channel_name)
|
||||
|
||||
# 打印系统运行成功信息
|
||||
logger.info("")
|
||||
logger.info("=" * 50)
|
||||
if conf().get("agent", False):
|
||||
logger.info("✅ System started successfully!")
|
||||
logger.info("🐮 Cow Agent is running")
|
||||
logger.info(f" Channel: {channel_name}")
|
||||
logger.info(f" Model: {conf().get('model', 'unknown')}")
|
||||
logger.info(f" Workspace: {conf().get('agent_workspace', '~/cow')}")
|
||||
channel_names = ["terminal"]
|
||||
else:
|
||||
logger.info("✅ System started successfully!")
|
||||
logger.info("🤖 ChatBot is running")
|
||||
logger.info(f" Channel: {channel_name}")
|
||||
logger.info(f" Model: {conf().get('model', 'unknown')}")
|
||||
logger.info("=" * 50)
|
||||
logger.info("")
|
||||
channel_names = _parse_channel_type(raw_channel)
|
||||
if not channel_names:
|
||||
channel_names = ["web"]
|
||||
|
||||
# Auto-start web console unless explicitly disabled
|
||||
web_console_enabled = conf().get("web_console", True)
|
||||
if web_console_enabled and "web" not in channel_names:
|
||||
channel_names.append("web")
|
||||
|
||||
# Sync builtin skills to workspace before channels start
|
||||
_sync_builtin_skills()
|
||||
|
||||
# Kick off MCP server loading in the background so first-message
|
||||
# latency isn't dominated by npx package downloads.
|
||||
_warmup_mcp_tools()
|
||||
|
||||
_warmup_scheduler()
|
||||
|
||||
logger.info(f"[App] Starting channels: {channel_names}")
|
||||
|
||||
_channel_mgr = ChannelManager()
|
||||
_channel_mgr.start(channel_names, first_start=True)
|
||||
|
||||
while True:
|
||||
time.sleep(1)
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error("App startup failed!")
|
||||
logger.exception(e)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
125
bridge/agent_event_handler.py
Normal file
125
bridge/agent_event_handler.py
Normal file
@@ -0,0 +1,125 @@
|
||||
"""
|
||||
Agent Event Handler - Handles agent events and thinking process output
|
||||
"""
|
||||
|
||||
from common import const
|
||||
from common.log import logger
|
||||
|
||||
# Cap intermediate thinking messages on weixin to stay within send quota.
|
||||
WEIXIN_THINKING_INSTANT_MAX = 7
|
||||
|
||||
|
||||
class AgentEventHandler:
|
||||
"""
|
||||
Handles agent events and optionally sends intermediate messages to channel
|
||||
"""
|
||||
|
||||
def __init__(self, context=None, original_callback=None):
|
||||
self.context = context
|
||||
self.original_callback = original_callback
|
||||
|
||||
self.channel = None
|
||||
if context:
|
||||
self.channel = context.kwargs.get("channel") if hasattr(context, "kwargs") else None
|
||||
|
||||
self.current_content = ""
|
||||
self.turn_number = 0
|
||||
|
||||
channel_type = ""
|
||||
if context and hasattr(context, "kwargs"):
|
||||
channel_type = context.kwargs.get("channel_type", "") or ""
|
||||
self._is_weixin = channel_type == const.WEIXIN
|
||||
self._thinking_sent_count = 0
|
||||
self._merged_buf: list[str] = []
|
||||
|
||||
def handle_event(self, event):
|
||||
event_type = event.get("type")
|
||||
data = event.get("data", {})
|
||||
|
||||
if event_type == "turn_start":
|
||||
self._handle_turn_start(data)
|
||||
elif event_type == "message_update":
|
||||
self._handle_message_update(data)
|
||||
elif event_type == "message_end":
|
||||
self._handle_message_end(data)
|
||||
elif event_type == "reasoning_update":
|
||||
pass
|
||||
elif event_type == "tool_execution_start":
|
||||
self._handle_tool_execution_start(data)
|
||||
elif event_type == "tool_execution_end":
|
||||
self._handle_tool_execution_end(data)
|
||||
elif event_type == "agent_end":
|
||||
self._handle_agent_end(data)
|
||||
|
||||
if self.original_callback:
|
||||
self.original_callback(event)
|
||||
|
||||
def _handle_turn_start(self, data):
|
||||
self.turn_number = data.get("turn", 0)
|
||||
self.current_content = ""
|
||||
|
||||
def _handle_message_update(self, data):
|
||||
delta = data.get("delta", "")
|
||||
self.current_content += delta
|
||||
|
||||
def _handle_message_end(self, data):
|
||||
tool_calls = data.get("tool_calls", [])
|
||||
|
||||
if tool_calls:
|
||||
if self.current_content.strip():
|
||||
logger.info(f"💭 {self.current_content.strip()[:200]}{'...' if len(self.current_content) > 200 else ''}")
|
||||
self._send_to_channel(self.current_content.strip())
|
||||
else:
|
||||
if self.current_content.strip():
|
||||
logger.debug(f"💬 {self.current_content.strip()[:200]}{'...' if len(self.current_content) > 200 else ''}")
|
||||
# Drain weixin buffer before final reply leaves chat_channel
|
||||
self._flush_merged_now()
|
||||
|
||||
self.current_content = ""
|
||||
|
||||
def _handle_agent_end(self, data):
|
||||
self._flush_merged_now()
|
||||
|
||||
def _handle_tool_execution_start(self, data):
|
||||
pass
|
||||
|
||||
def _handle_tool_execution_end(self, data):
|
||||
pass
|
||||
|
||||
def _send_to_channel(self, message):
|
||||
if self.context and self.context.get("on_event"):
|
||||
return
|
||||
if not self.channel:
|
||||
return
|
||||
|
||||
if not self._is_weixin:
|
||||
self._do_send(message)
|
||||
return
|
||||
|
||||
if self._thinking_sent_count < WEIXIN_THINKING_INSTANT_MAX:
|
||||
self._do_send(message)
|
||||
self._thinking_sent_count += 1
|
||||
return
|
||||
|
||||
self._merged_buf.append(message)
|
||||
|
||||
def _flush_merged_now(self):
|
||||
if not self._merged_buf:
|
||||
return
|
||||
merged = "\n\n".join(self._merged_buf)
|
||||
count = len(self._merged_buf)
|
||||
self._merged_buf = []
|
||||
logger.debug(f"[AgentEventHandler] Flushing {count} merged thinking msgs, len={len(merged)}")
|
||||
self._do_send(merged)
|
||||
self._thinking_sent_count += 1
|
||||
|
||||
def _do_send(self, message):
|
||||
try:
|
||||
from bridge.reply import Reply, ReplyType
|
||||
reply = Reply(ReplyType.TEXT, message)
|
||||
self.channel._send(reply, self.context)
|
||||
except Exception as e:
|
||||
logger.debug(f"[AgentEventHandler] Failed to send to channel: {e}")
|
||||
|
||||
def log_summary(self):
|
||||
pass
|
||||
823
bridge/agent_initializer.py
Normal file
823
bridge/agent_initializer.py
Normal file
@@ -0,0 +1,823 @@
|
||||
"""
|
||||
Agent Initializer - Handles agent initialization logic
|
||||
"""
|
||||
|
||||
import os
|
||||
import asyncio
|
||||
import datetime
|
||||
import threading
|
||||
import time
|
||||
from typing import Optional, List
|
||||
|
||||
from agent.protocol import Agent
|
||||
from agent.tools import ToolManager
|
||||
from common.log import logger
|
||||
from common.utils import expand_path
|
||||
|
||||
# Module-level lock to serialize scheduler init across concurrent sessions
|
||||
_scheduler_init_lock = threading.Lock()
|
||||
|
||||
# Track whether the embedding model log has been printed in this process,
|
||||
# so we avoid spamming it once per session.
|
||||
_embedding_logged: bool = False
|
||||
|
||||
|
||||
class AgentInitializer:
|
||||
"""
|
||||
Handles agent initialization including:
|
||||
- Workspace setup
|
||||
- Memory system initialization
|
||||
- Tool loading
|
||||
- System prompt building
|
||||
"""
|
||||
|
||||
def __init__(self, bridge, agent_bridge):
|
||||
"""
|
||||
Initialize agent initializer
|
||||
|
||||
Args:
|
||||
bridge: COW bridge instance
|
||||
agent_bridge: AgentBridge instance (for create_agent method)
|
||||
"""
|
||||
self.bridge = bridge
|
||||
self.agent_bridge = agent_bridge
|
||||
|
||||
def initialize_agent(self, session_id: Optional[str] = None) -> Agent:
|
||||
"""
|
||||
Initialize agent for a session
|
||||
|
||||
Args:
|
||||
session_id: Session ID (None for default agent)
|
||||
|
||||
Returns:
|
||||
Initialized agent instance
|
||||
"""
|
||||
from config import conf
|
||||
|
||||
# Get workspace from config
|
||||
workspace_root = expand_path(conf().get("agent_workspace", "~/cow"))
|
||||
|
||||
# Migrate API keys
|
||||
self._migrate_config_to_env(workspace_root)
|
||||
|
||||
# Load environment variables
|
||||
self._load_env_file()
|
||||
|
||||
# Initialize workspace
|
||||
from agent.prompt import ensure_workspace, load_context_files, PromptBuilder
|
||||
workspace_files = ensure_workspace(workspace_root, create_templates=True)
|
||||
|
||||
if session_id is None:
|
||||
logger.info(f"[AgentInitializer] Workspace initialized at: {workspace_root}")
|
||||
|
||||
# Setup memory system
|
||||
memory_manager, memory_tools = self._setup_memory_system(workspace_root, session_id)
|
||||
|
||||
# Load tools
|
||||
tools = self._load_tools(workspace_root, memory_manager, memory_tools, session_id)
|
||||
|
||||
# Initialize scheduler if needed
|
||||
self._initialize_scheduler(tools, session_id)
|
||||
|
||||
# Load context files
|
||||
context_files = load_context_files(workspace_root)
|
||||
|
||||
# Initialize skill manager
|
||||
skill_manager = self._initialize_skill_manager(workspace_root, session_id)
|
||||
|
||||
# Build system prompt
|
||||
prompt_builder = PromptBuilder(workspace_dir=workspace_root, language="zh")
|
||||
runtime_info = self._get_runtime_info(workspace_root)
|
||||
|
||||
system_prompt = prompt_builder.build(
|
||||
tools=tools,
|
||||
context_files=context_files,
|
||||
skill_manager=skill_manager,
|
||||
memory_manager=memory_manager,
|
||||
runtime_info=runtime_info,
|
||||
)
|
||||
|
||||
# Get cost control parameters
|
||||
from config import conf
|
||||
max_steps = conf().get("agent_max_steps", 20)
|
||||
max_context_tokens = conf().get("agent_max_context_tokens", 50000)
|
||||
|
||||
# Create agent
|
||||
agent = self.agent_bridge.create_agent(
|
||||
system_prompt=system_prompt,
|
||||
tools=tools,
|
||||
max_steps=max_steps,
|
||||
output_mode="logger",
|
||||
workspace_dir=workspace_root,
|
||||
skill_manager=skill_manager,
|
||||
enable_skills=True,
|
||||
max_context_tokens=max_context_tokens,
|
||||
runtime_info=runtime_info # Pass runtime_info for dynamic time updates
|
||||
)
|
||||
|
||||
# Attach memory manager and share LLM model for summarization
|
||||
if memory_manager:
|
||||
agent.memory_manager = memory_manager
|
||||
if hasattr(agent, 'model') and agent.model:
|
||||
memory_manager.flush_manager.llm_model = agent.model
|
||||
|
||||
# Restore persisted conversation history for this session
|
||||
if session_id:
|
||||
self._restore_conversation_history(agent, session_id)
|
||||
|
||||
# Start daily memory flush timer (once, on first agent init regardless of session)
|
||||
self._start_daily_flush_timer()
|
||||
|
||||
return agent
|
||||
|
||||
def _restore_conversation_history(self, agent, session_id: str) -> None:
|
||||
"""
|
||||
Load persisted conversation messages from SQLite and inject them
|
||||
into the agent's in-memory message list.
|
||||
|
||||
Only user text and assistant text are restored. Tool call chains
|
||||
(tool_use / tool_result) are stripped out because:
|
||||
1. They are intermediate process, the value is already in the final
|
||||
assistant text reply.
|
||||
2. They consume massive context tokens (often 80%+ of history).
|
||||
3. Different models have incompatible tool message formats, so
|
||||
restoring tool chains across model switches causes 400 errors.
|
||||
4. Eliminates the entire class of tool_use/tool_result pairing bugs.
|
||||
"""
|
||||
from config import conf
|
||||
if not conf().get("conversation_persistence", True):
|
||||
return
|
||||
|
||||
try:
|
||||
from agent.memory import get_conversation_store
|
||||
store = get_conversation_store()
|
||||
max_turns = conf().get("agent_max_context_turns", 20)
|
||||
# Scheduler tasks run on a stable isolated session per task and
|
||||
# can fire many times a day; a smaller restore window keeps prompt
|
||||
# cost bounded while still letting the agent see "last few" runs
|
||||
# for trend / dedup style logic. Regular chat sessions keep the
|
||||
# original heuristic so user dialogues feel continuous.
|
||||
if session_id.startswith("scheduler_"):
|
||||
restore_turns = max(1, max_turns // 5)
|
||||
else:
|
||||
restore_turns = max(3, max_turns // 6)
|
||||
saved = store.load_messages(session_id, max_turns=restore_turns)
|
||||
if saved:
|
||||
filtered = self._filter_text_only_messages(saved)
|
||||
if filtered:
|
||||
with agent.messages_lock:
|
||||
agent.messages = filtered
|
||||
logger.debug(
|
||||
f"[AgentInitializer] Restored {len(filtered)} text messages "
|
||||
f"(from {len(saved)} total, {restore_turns} turns cap) "
|
||||
f"for session={session_id}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"[AgentInitializer] Failed to restore conversation history for "
|
||||
f"session={session_id}: {e}"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _filter_text_only_messages(messages: list) -> list:
|
||||
"""
|
||||
Extract clean user/assistant turn pairs from raw message history.
|
||||
|
||||
Groups messages into turns (each starting with a real user query),
|
||||
then keeps only:
|
||||
- The first user text in each turn (the actual user input)
|
||||
- The last assistant text in each turn (the final answer)
|
||||
|
||||
All tool_use, tool_result, intermediate assistant thoughts, and
|
||||
internal hint messages injected by the agent loop are discarded.
|
||||
"""
|
||||
|
||||
def _extract_text(content) -> str:
|
||||
if isinstance(content, str):
|
||||
return content.strip()
|
||||
if isinstance(content, list):
|
||||
parts = [
|
||||
b.get("text", "")
|
||||
for b in content
|
||||
if isinstance(b, dict) and b.get("type") == "text"
|
||||
]
|
||||
return "\n".join(p for p in parts if p).strip()
|
||||
return ""
|
||||
|
||||
def _is_real_user_msg(msg: dict) -> bool:
|
||||
"""True for actual user input, False for tool_result or internal hints."""
|
||||
if msg.get("role") != "user":
|
||||
return False
|
||||
content = msg.get("content")
|
||||
if isinstance(content, list):
|
||||
has_tool_result = any(
|
||||
isinstance(b, dict) and b.get("type") == "tool_result"
|
||||
for b in content
|
||||
)
|
||||
if has_tool_result:
|
||||
return False
|
||||
text = _extract_text(content)
|
||||
return bool(text)
|
||||
|
||||
# Group into turns: each turn starts with a real user message
|
||||
turns = []
|
||||
current_turn = None
|
||||
for msg in messages:
|
||||
if _is_real_user_msg(msg):
|
||||
if current_turn is not None:
|
||||
turns.append(current_turn)
|
||||
current_turn = {"user": msg, "assistants": []}
|
||||
elif current_turn is not None and msg.get("role") == "assistant":
|
||||
text = _extract_text(msg.get("content"))
|
||||
if text:
|
||||
current_turn["assistants"].append(text)
|
||||
if current_turn is not None:
|
||||
turns.append(current_turn)
|
||||
|
||||
# Build result: one user msg + one assistant msg per turn
|
||||
filtered = []
|
||||
for turn in turns:
|
||||
user_text = _extract_text(turn["user"].get("content"))
|
||||
if not user_text:
|
||||
continue
|
||||
filtered.append({
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": user_text}]
|
||||
})
|
||||
if turn["assistants"]:
|
||||
final_reply = turn["assistants"][-1]
|
||||
filtered.append({
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "text": final_reply}]
|
||||
})
|
||||
|
||||
return filtered
|
||||
|
||||
def _load_env_file(self):
|
||||
"""Load environment variables from .env file"""
|
||||
env_file = expand_path("~/.cow/.env")
|
||||
if os.path.exists(env_file):
|
||||
try:
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv(env_file, override=True)
|
||||
except ImportError:
|
||||
logger.warning("[AgentInitializer] python-dotenv not installed")
|
||||
except Exception as e:
|
||||
logger.warning(f"[AgentInitializer] Failed to load .env file: {e}")
|
||||
|
||||
def _setup_memory_system(self, workspace_root: str, session_id: Optional[str] = None):
|
||||
"""
|
||||
Setup memory system
|
||||
|
||||
Returns:
|
||||
(memory_manager, memory_tools) tuple
|
||||
"""
|
||||
memory_manager = None
|
||||
memory_tools = []
|
||||
|
||||
try:
|
||||
from agent.memory import MemoryManager, MemoryConfig
|
||||
from agent.tools import MemorySearchTool, MemoryGetTool
|
||||
from config import conf
|
||||
|
||||
memory_config = MemoryConfig(workspace_root=workspace_root)
|
||||
|
||||
embedding_provider = self._init_embedding_provider(
|
||||
memory_config, session_id=session_id
|
||||
)
|
||||
|
||||
memory_manager = MemoryManager(memory_config, embedding_provider=embedding_provider)
|
||||
self._sync_memory(memory_manager, session_id)
|
||||
|
||||
memory_tools = [
|
||||
MemorySearchTool(memory_manager),
|
||||
MemoryGetTool(memory_manager)
|
||||
]
|
||||
|
||||
if session_id is None:
|
||||
logger.info("[AgentInitializer] Memory system initialized")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"[AgentInitializer] Memory system not available: {e}")
|
||||
|
||||
return memory_manager, memory_tools
|
||||
|
||||
def _init_embedding_provider(self, memory_config, session_id: Optional[str] = None):
|
||||
"""
|
||||
Initialize the embedding provider for memory.
|
||||
|
||||
Two paths:
|
||||
A. Default (no `embedding_provider` in config.json):
|
||||
Auto-init OpenAI -> LinkAI fallback. Existing 1536-dim indices
|
||||
keep working.
|
||||
B. Explicit (`embedding_provider` is set):
|
||||
Initialize the requested vendor with unified dim (default 1024).
|
||||
If the index was built with a different dim, vector search will
|
||||
quietly return no results (cosine returns 0) and keyword search
|
||||
takes over until the user runs /memory rebuild-index.
|
||||
"""
|
||||
from agent.memory import create_embedding_provider
|
||||
from config import conf
|
||||
|
||||
explicit_provider = (conf().get("embedding_provider") or "").strip().lower()
|
||||
|
||||
if not explicit_provider:
|
||||
return self._init_embedding_provider_legacy(session_id=session_id)
|
||||
|
||||
return self._init_embedding_provider_explicit(
|
||||
memory_config, explicit_provider, session_id=session_id,
|
||||
)
|
||||
|
||||
def _init_embedding_provider_legacy(self, session_id: Optional[str] = None):
|
||||
"""Legacy auto-init path: OpenAI -> LinkAI. Preserved verbatim for compat."""
|
||||
from agent.memory import create_embedding_provider
|
||||
from config import conf
|
||||
|
||||
embedding_provider = None
|
||||
embedding_model = None
|
||||
|
||||
openai_api_key = conf().get("open_ai_api_key", "")
|
||||
openai_api_base = conf().get("open_ai_api_base", "")
|
||||
if openai_api_key and openai_api_key not in ["", "YOUR API KEY", "YOUR_API_KEY"]:
|
||||
try:
|
||||
model = "text-embedding-3-small"
|
||||
embedding_provider = create_embedding_provider(
|
||||
provider="openai",
|
||||
model=model,
|
||||
api_key=openai_api_key,
|
||||
api_base=openai_api_base or "https://api.openai.com/v1"
|
||||
)
|
||||
embedding_model = f"openai/{model}"
|
||||
except Exception as e:
|
||||
logger.warning(f"[AgentInitializer] OpenAI embedding failed: {e}")
|
||||
|
||||
if embedding_provider is None:
|
||||
linkai_api_key = conf().get("linkai_api_key", "") or os.environ.get("LINKAI_API_KEY", "")
|
||||
linkai_api_base = conf().get("linkai_api_base", "https://api.link-ai.tech")
|
||||
if linkai_api_key and linkai_api_key not in ["", "YOUR API KEY", "YOUR_API_KEY"]:
|
||||
try:
|
||||
model = "text-embedding-3-small"
|
||||
embedding_provider = create_embedding_provider(
|
||||
provider="linkai",
|
||||
model=model,
|
||||
api_key=linkai_api_key,
|
||||
api_base=f"{linkai_api_base}/v1"
|
||||
)
|
||||
embedding_model = f"linkai/{model}"
|
||||
except Exception as e:
|
||||
logger.warning(f"[AgentInitializer] LinkAI embedding failed: {e}")
|
||||
|
||||
if embedding_provider is not None and embedding_model:
|
||||
global _embedding_logged
|
||||
if not _embedding_logged:
|
||||
logger.info(
|
||||
f"[AgentInitializer] Embedding model in use: {embedding_model} "
|
||||
f"(dim={embedding_provider.dimensions})"
|
||||
)
|
||||
_embedding_logged = True
|
||||
|
||||
return embedding_provider
|
||||
|
||||
def _init_embedding_provider_explicit(
|
||||
self,
|
||||
memory_config,
|
||||
provider_key: str,
|
||||
session_id: Optional[str] = None,
|
||||
):
|
||||
"""Explicit-provider path: build the configured vendor.
|
||||
|
||||
If the index was built with a different dim, vector search will
|
||||
silently return no results (cosine returns 0 for mismatched dims)
|
||||
and keyword search takes over. Users switch vendors by running
|
||||
/memory rebuild-index — see docs.
|
||||
"""
|
||||
from agent.memory import create_embedding_provider
|
||||
from agent.memory.embedding import EMBEDDING_VENDORS
|
||||
from config import conf
|
||||
|
||||
meta = EMBEDDING_VENDORS.get(provider_key)
|
||||
if meta is None:
|
||||
logger.error(
|
||||
f"[AgentInitializer] Unknown embedding_provider '{provider_key}'. "
|
||||
f"Supported: {sorted(EMBEDDING_VENDORS.keys())}. "
|
||||
f"Memory will run in keyword-only mode."
|
||||
)
|
||||
return None
|
||||
|
||||
api_key = self._resolve_embedding_api_key(provider_key)
|
||||
api_base = self._resolve_embedding_api_base(provider_key, meta["default_base_url"])
|
||||
|
||||
if not api_key:
|
||||
logger.error(
|
||||
f"[AgentInitializer] embedding_provider='{provider_key}' is set but its "
|
||||
f"API key is missing. Memory will run in keyword-only mode."
|
||||
)
|
||||
return None
|
||||
|
||||
model = (conf().get("embedding_model") or "").strip() or meta["default_model"]
|
||||
try:
|
||||
cfg_dim = int(conf().get("embedding_dimensions") or 0)
|
||||
except (TypeError, ValueError):
|
||||
cfg_dim = 0
|
||||
dim = cfg_dim if cfg_dim > 0 else meta["default_dimensions"]
|
||||
|
||||
try:
|
||||
provider = create_embedding_provider(
|
||||
provider=provider_key,
|
||||
model=model,
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
dimensions=dim,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[AgentInitializer] Failed to init embedding provider "
|
||||
f"'{provider_key}/{model}': {e}"
|
||||
)
|
||||
return None
|
||||
|
||||
global _embedding_logged
|
||||
if not _embedding_logged:
|
||||
logger.info(
|
||||
f"[AgentInitializer] Embedding model in use: "
|
||||
f"{provider_key}/{model} (dim={provider.dimensions})"
|
||||
)
|
||||
_embedding_logged = True
|
||||
return provider
|
||||
|
||||
@staticmethod
|
||||
def _resolve_embedding_api_key(provider_key: str) -> str:
|
||||
"""Pick the API key for an explicit embedding provider from config."""
|
||||
from config import conf
|
||||
|
||||
key_map = {
|
||||
"openai": "open_ai_api_key",
|
||||
"linkai": "linkai_api_key",
|
||||
"dashscope": "dashscope_api_key",
|
||||
"doubao": "ark_api_key",
|
||||
"zhipu": "zhipu_ai_api_key",
|
||||
}
|
||||
field = key_map.get(provider_key)
|
||||
if not field:
|
||||
return ""
|
||||
value = conf().get(field, "") or ""
|
||||
if value in ["", "YOUR API KEY", "YOUR_API_KEY"]:
|
||||
return ""
|
||||
return value
|
||||
|
||||
@staticmethod
|
||||
def _resolve_embedding_api_base(provider_key: str, default_base: str) -> str:
|
||||
"""Pick the API base for an explicit embedding provider from config."""
|
||||
from config import conf
|
||||
|
||||
base_map = {
|
||||
"openai": "open_ai_api_base",
|
||||
"linkai": "linkai_api_base",
|
||||
"doubao": "ark_base_url",
|
||||
"zhipu": "zhipu_ai_api_base",
|
||||
}
|
||||
field = base_map.get(provider_key)
|
||||
if not field:
|
||||
return default_base
|
||||
value = (conf().get(field) or "").strip()
|
||||
if not value:
|
||||
return default_base
|
||||
if provider_key == "linkai" and not value.rstrip("/").endswith("/v1"):
|
||||
return f"{value.rstrip('/')}/v1"
|
||||
return value
|
||||
|
||||
def _sync_memory(self, memory_manager, session_id: Optional[str] = None):
|
||||
"""Sync memory database"""
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_closed():
|
||||
raise RuntimeError("Event loop is closed")
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
if loop.is_running():
|
||||
asyncio.create_task(memory_manager.sync())
|
||||
else:
|
||||
loop.run_until_complete(memory_manager.sync())
|
||||
except Exception as e:
|
||||
logger.warning(f"[AgentInitializer] Memory sync failed: {e}")
|
||||
|
||||
def _load_tools(self, workspace_root: str, memory_manager, memory_tools: List, session_id: Optional[str] = None):
|
||||
"""Load all tools"""
|
||||
tool_manager = ToolManager()
|
||||
tool_manager.load_tools()
|
||||
|
||||
tools = []
|
||||
file_config = {
|
||||
"cwd": workspace_root,
|
||||
"memory_manager": memory_manager
|
||||
} if memory_manager else {"cwd": workspace_root}
|
||||
|
||||
for tool_name in tool_manager.tool_classes.keys():
|
||||
try:
|
||||
# Skip web_search if no API key is available
|
||||
if tool_name == "web_search":
|
||||
from agent.tools.web_search.web_search import WebSearch
|
||||
if not WebSearch.is_available():
|
||||
logger.debug("[AgentInitializer] WebSearch skipped - no search provider configured")
|
||||
continue
|
||||
|
||||
# Special handling for EnvConfig tool
|
||||
if tool_name == "env_config":
|
||||
from agent.tools import EnvConfig
|
||||
tool = EnvConfig({"agent_bridge": self.agent_bridge})
|
||||
else:
|
||||
tool = tool_manager.create_tool(tool_name)
|
||||
|
||||
if tool:
|
||||
# Apply workspace config to file operation tools.
|
||||
# Merge into the existing tool.config (set by ToolManager from
|
||||
# config.json's `tools.<name>` section) instead of replacing
|
||||
# it, otherwise per-tool user configs (e.g. browser.cdp_endpoint)
|
||||
# would be silently dropped.
|
||||
if tool_name in ['read', 'write', 'edit', 'bash', 'grep', 'find', 'ls', 'web_fetch', 'send', 'browser']:
|
||||
merged_config = dict(getattr(tool, 'config', None) or {})
|
||||
merged_config.update(file_config)
|
||||
tool.config = merged_config
|
||||
tool.cwd = merged_config.get("cwd", getattr(tool, 'cwd', None))
|
||||
if 'memory_manager' in merged_config:
|
||||
tool.memory_manager = merged_config['memory_manager']
|
||||
tools.append(tool)
|
||||
except Exception as e:
|
||||
logger.warning(f"[AgentInitializer] Failed to load tool {tool_name}: {e}")
|
||||
|
||||
# Add MCP tools (snapshot to avoid races with the background loader)
|
||||
mcp_tools_snapshot = list(tool_manager._mcp_tool_instances.items())
|
||||
if mcp_tools_snapshot:
|
||||
for _, mcp_tool in mcp_tools_snapshot:
|
||||
tools.append(mcp_tool)
|
||||
if session_id is None:
|
||||
names = [name for name, _ in mcp_tools_snapshot]
|
||||
logger.info(
|
||||
f"[AgentInitializer] Added {len(names)} MCP tool(s): {names}"
|
||||
)
|
||||
|
||||
# Add memory tools
|
||||
if memory_tools:
|
||||
tools.extend(memory_tools)
|
||||
if session_id is None:
|
||||
logger.info(f"[AgentInitializer] Added {len(memory_tools)} memory tools")
|
||||
|
||||
if session_id is None:
|
||||
logger.info(f"[AgentInitializer] Loaded {len(tools)} tools: {[t.name for t in tools]}")
|
||||
|
||||
return tools
|
||||
|
||||
def _initialize_scheduler(self, tools: List, session_id: Optional[str] = None):
|
||||
"""Initialize scheduler service if needed.
|
||||
|
||||
Serialize the check-and-set under a module-level lock so concurrent
|
||||
first-time session inits cannot each create a new SchedulerService
|
||||
(which would leak background scanning threads).
|
||||
"""
|
||||
if not self.agent_bridge.scheduler_initialized:
|
||||
with _scheduler_init_lock:
|
||||
if not self.agent_bridge.scheduler_initialized:
|
||||
try:
|
||||
from agent.tools.scheduler.integration import init_scheduler
|
||||
if init_scheduler(self.agent_bridge):
|
||||
self.agent_bridge.scheduler_initialized = True
|
||||
if session_id is None:
|
||||
logger.info("[AgentInitializer] Scheduler service initialized")
|
||||
except Exception as e:
|
||||
logger.warning(f"[AgentInitializer] Failed to initialize scheduler: {e}")
|
||||
|
||||
# Inject scheduler dependencies
|
||||
if self.agent_bridge.scheduler_initialized:
|
||||
try:
|
||||
from agent.tools.scheduler.integration import get_task_store, get_scheduler_service
|
||||
from agent.tools import SchedulerTool
|
||||
from config import conf
|
||||
|
||||
task_store = get_task_store()
|
||||
scheduler_service = get_scheduler_service()
|
||||
|
||||
for tool in tools:
|
||||
if isinstance(tool, SchedulerTool):
|
||||
tool.task_store = task_store
|
||||
tool.scheduler_service = scheduler_service
|
||||
if not tool.config:
|
||||
tool.config = {}
|
||||
raw_ct = conf().get("channel_type", "unknown")
|
||||
if isinstance(raw_ct, list):
|
||||
ct = raw_ct[0] if raw_ct else "unknown"
|
||||
elif isinstance(raw_ct, str) and "," in raw_ct:
|
||||
ct = raw_ct.split(",")[0].strip()
|
||||
else:
|
||||
ct = raw_ct
|
||||
tool.config["channel_type"] = ct
|
||||
except Exception as e:
|
||||
logger.warning(f"[AgentInitializer] Failed to inject scheduler dependencies: {e}")
|
||||
|
||||
def _initialize_skill_manager(self, workspace_root: str, session_id: Optional[str] = None):
|
||||
"""Initialize skill manager"""
|
||||
try:
|
||||
from agent.skills import SkillManager
|
||||
skill_manager = SkillManager(custom_dir=os.path.join(workspace_root, "skills"))
|
||||
return skill_manager
|
||||
except Exception as e:
|
||||
logger.warning(f"[AgentInitializer] Failed to initialize SkillManager: {e}")
|
||||
return None
|
||||
|
||||
def _get_runtime_info(self, workspace_root: str):
|
||||
"""Get runtime information with dynamic time support"""
|
||||
from config import conf
|
||||
|
||||
def get_current_time():
|
||||
"""Get current time dynamically - called each time system prompt is accessed"""
|
||||
now = datetime.datetime.now()
|
||||
|
||||
# Get timezone info
|
||||
try:
|
||||
offset = -time.timezone if not time.daylight else -time.altzone
|
||||
hours = offset // 3600
|
||||
minutes = (offset % 3600) // 60
|
||||
timezone_name = f"UTC{hours:+03d}:{minutes:02d}" if minutes else f"UTC{hours:+03d}"
|
||||
except Exception:
|
||||
timezone_name = "UTC"
|
||||
|
||||
# Weekday: English name in en, Chinese mapping otherwise
|
||||
weekday_en = now.strftime("%A")
|
||||
try:
|
||||
from common import i18n
|
||||
is_en = i18n.get_language() == "en"
|
||||
except Exception:
|
||||
is_en = False
|
||||
if is_en:
|
||||
weekday = weekday_en
|
||||
else:
|
||||
weekday_map = {
|
||||
'Monday': '星期一', 'Tuesday': '星期二', 'Wednesday': '星期三',
|
||||
'Thursday': '星期四', 'Friday': '星期五', 'Saturday': '星期六', 'Sunday': '星期日'
|
||||
}
|
||||
weekday = weekday_map.get(weekday_en, weekday_en)
|
||||
|
||||
return {
|
||||
'time': now.strftime("%Y-%m-%d %H:%M:%S"),
|
||||
'weekday': weekday,
|
||||
'timezone': timezone_name
|
||||
}
|
||||
|
||||
def get_model():
|
||||
"""Get current model name dynamically from config"""
|
||||
return conf().get("model", "unknown")
|
||||
|
||||
return {
|
||||
"_get_model": get_model,
|
||||
"workspace": workspace_root,
|
||||
"channel": ", ".join(conf().get("channel_type")) if isinstance(conf().get("channel_type"), list) else conf().get("channel_type", "unknown"),
|
||||
"_get_current_time": get_current_time # Dynamic time function
|
||||
}
|
||||
|
||||
def _migrate_config_to_env(self, workspace_root: str):
|
||||
"""Migrate API keys from config.json to .env file"""
|
||||
from config import conf
|
||||
|
||||
key_mapping = {
|
||||
"open_ai_api_key": "OPENAI_API_KEY",
|
||||
"open_ai_api_base": "OPENAI_API_BASE",
|
||||
"gemini_api_key": "GEMINI_API_KEY",
|
||||
"claude_api_key": "CLAUDE_API_KEY",
|
||||
"linkai_api_key": "LINKAI_API_KEY",
|
||||
}
|
||||
|
||||
env_file = expand_path("~/.cow/.env")
|
||||
|
||||
# Read existing env vars (key -> value)
|
||||
existing_env_vars = {}
|
||||
if os.path.exists(env_file):
|
||||
try:
|
||||
with open(env_file, 'r', encoding='utf-8') as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line and not line.startswith('#') and '=' in line:
|
||||
key, val = line.split('=', 1)
|
||||
existing_env_vars[key.strip()] = val.strip()
|
||||
except Exception as e:
|
||||
logger.warning(f"[AgentInitializer] Failed to read .env file: {e}")
|
||||
|
||||
# Sync config.json values into .env (add/update/remove)
|
||||
updated = False
|
||||
for config_key, env_key in key_mapping.items():
|
||||
raw = conf().get(config_key, "")
|
||||
value = raw.strip() if raw else ""
|
||||
old_value = existing_env_vars.get(env_key)
|
||||
|
||||
if value:
|
||||
if old_value == value:
|
||||
continue
|
||||
existing_env_vars[env_key] = value
|
||||
os.environ[env_key] = value
|
||||
updated = True
|
||||
else:
|
||||
if old_value is None:
|
||||
continue
|
||||
existing_env_vars.pop(env_key, None)
|
||||
os.environ.pop(env_key, None)
|
||||
updated = True
|
||||
|
||||
if updated:
|
||||
try:
|
||||
env_dir = os.path.dirname(env_file)
|
||||
os.makedirs(env_dir, exist_ok=True)
|
||||
|
||||
# Rewrite the entire .env file to ensure consistency
|
||||
with open(env_file, 'w', encoding='utf-8') as f:
|
||||
f.write('# Environment variables for agent\n')
|
||||
f.write('# Auto-managed - synced from config.json on startup\n\n')
|
||||
for key, value in sorted(existing_env_vars.items()):
|
||||
f.write(f'{key}={value}\n')
|
||||
|
||||
logger.info(f"[AgentInitializer] Synced API keys from config.json to .env")
|
||||
except Exception as e:
|
||||
logger.warning(f"[AgentInitializer] Failed to sync API keys: {e}")
|
||||
|
||||
def _start_daily_flush_timer(self):
|
||||
"""Start a background thread that flushes all agents' memory daily at 23:55."""
|
||||
if getattr(self.agent_bridge, '_daily_flush_started', False):
|
||||
return
|
||||
self.agent_bridge._daily_flush_started = True
|
||||
|
||||
import threading
|
||||
|
||||
def _daily_flush_loop():
|
||||
import random
|
||||
last_run_date = None # Track last successful run date to prevent same-day re-trigger
|
||||
while True:
|
||||
try:
|
||||
now = datetime.datetime.now()
|
||||
jitter_min = random.randint(50, 55)
|
||||
jitter_sec = random.randint(0, 59)
|
||||
target = now.replace(hour=23, minute=jitter_min, second=jitter_sec, microsecond=0)
|
||||
# Always schedule for tomorrow if we already ran today, or if target time has passed
|
||||
if target <= now or (last_run_date == now.date()):
|
||||
target += datetime.timedelta(days=1)
|
||||
wait_seconds = (target - now).total_seconds()
|
||||
logger.info(f"[DailyFlush] Next flush at {target.strftime('%Y-%m-%d %H:%M:%S')} (in {wait_seconds/3600:.1f}h)")
|
||||
time.sleep(wait_seconds)
|
||||
|
||||
self._flush_all_agents()
|
||||
last_run_date = datetime.datetime.now().date()
|
||||
except Exception as e:
|
||||
logger.warning(f"[DailyFlush] Error in daily flush loop: {e}")
|
||||
time.sleep(3600)
|
||||
|
||||
t = threading.Thread(target=_daily_flush_loop, daemon=True)
|
||||
t.start()
|
||||
|
||||
def _flush_all_agents(self):
|
||||
"""Flush memory for all active agent sessions, then run Deep Dream."""
|
||||
agents = []
|
||||
if self.agent_bridge.default_agent:
|
||||
agents.append(("default", self.agent_bridge.default_agent))
|
||||
for sid, agent in self.agent_bridge.agents.items():
|
||||
agents.append((sid, agent))
|
||||
|
||||
if not agents:
|
||||
return
|
||||
|
||||
# Phase 1: flush daily summaries
|
||||
flushed = 0
|
||||
flush_threads = []
|
||||
dream_candidate = None
|
||||
for label, agent in agents:
|
||||
try:
|
||||
if not agent.memory_manager:
|
||||
continue
|
||||
with agent.messages_lock:
|
||||
messages = list(agent.messages)
|
||||
if not messages:
|
||||
continue
|
||||
result = agent.memory_manager.flush_manager.create_daily_summary(messages)
|
||||
if result:
|
||||
flushed += 1
|
||||
t = agent.memory_manager.flush_manager._last_flush_thread
|
||||
if t:
|
||||
flush_threads.append(t)
|
||||
if dream_candidate is None:
|
||||
dream_candidate = agent.memory_manager.flush_manager
|
||||
except Exception as e:
|
||||
logger.warning(f"[DailyFlush] Failed for session {label}: {e}")
|
||||
|
||||
if flushed:
|
||||
logger.info(f"[DailyFlush] Flushed {flushed}/{len(agents)} agent session(s)")
|
||||
|
||||
# Wait for all flush threads to finish before dreaming
|
||||
for t in flush_threads:
|
||||
t.join(timeout=60)
|
||||
|
||||
# Phase 2: Deep Dream — distill daily memories → MEMORY.md + dream diary
|
||||
if dream_candidate:
|
||||
try:
|
||||
result = dream_candidate.deep_dream()
|
||||
if result:
|
||||
logger.info("[DeepDream] Memory distillation completed successfully")
|
||||
except Exception as e:
|
||||
logger.warning(f"[DeepDream] Failed: {e}")
|
||||
@@ -13,8 +13,10 @@ from voice.factory import create_voice
|
||||
class Bridge(object):
|
||||
def __init__(self):
|
||||
self.btype = {
|
||||
"chat": const.CHATGPT,
|
||||
"voice_to_text": conf().get("voice_to_text", "openai"),
|
||||
"chat": const.OPENAI,
|
||||
# Empty `voice_to_text` (the default in new configs) triggers
|
||||
# the auto-pick below — see _auto_pick_voice_to_text for order.
|
||||
"voice_to_text": conf().get("voice_to_text") or self._auto_pick_voice_to_text(),
|
||||
"text_to_voice": conf().get("text_to_voice", "google"),
|
||||
"translate": conf().get("translate", "baidu"),
|
||||
}
|
||||
@@ -24,6 +26,13 @@ class Bridge(object):
|
||||
self.btype["chat"] = bot_type
|
||||
else:
|
||||
model_type = conf().get("model") or const.GPT_41_MINI
|
||||
|
||||
# Ensure model_type is string to prevent AttributeError when using startswith()
|
||||
# This handles cases where numeric model names (e.g., "1") are parsed as integers from YAML
|
||||
if not isinstance(model_type, str):
|
||||
logger.warning(f"[Bridge] model_type is not a string: {model_type} (type: {type(model_type).__name__}), converting to string")
|
||||
model_type = str(model_type)
|
||||
|
||||
if model_type in ["text-davinci-003"]:
|
||||
self.btype["chat"] = const.OPEN_AI
|
||||
if conf().get("use_azure_chatgpt", False):
|
||||
@@ -32,9 +41,9 @@ class Bridge(object):
|
||||
self.btype["chat"] = const.BAIDU
|
||||
if model_type in ["xunfei"]:
|
||||
self.btype["chat"] = const.XUNFEI
|
||||
if model_type in [const.QWEN]:
|
||||
self.btype["chat"] = const.QWEN
|
||||
if model_type in [const.QWEN_TURBO, const.QWEN_PLUS, const.QWEN_MAX]:
|
||||
if model_type in [const.QWEN, const.QWEN_TURBO, const.QWEN_PLUS, const.QWEN_MAX]:
|
||||
self.btype["chat"] = const.QWEN_DASHSCOPE
|
||||
if model_type and (model_type.startswith("qwen") or model_type.startswith("qwq") or model_type.startswith("qvq")):
|
||||
self.btype["chat"] = const.QWEN_DASHSCOPE
|
||||
if model_type and model_type.startswith("gemini"):
|
||||
self.btype["chat"] = const.GEMINI
|
||||
@@ -43,16 +52,31 @@ class Bridge(object):
|
||||
if model_type and model_type.startswith("claude"):
|
||||
self.btype["chat"] = const.CLAUDEAPI
|
||||
|
||||
if model_type in ["claude"]:
|
||||
self.btype["chat"] = const.CLAUDEAI
|
||||
|
||||
if model_type in [const.MOONSHOT, "moonshot-v1-8k", "moonshot-v1-32k", "moonshot-v1-128k"]:
|
||||
self.btype["chat"] = const.MOONSHOT
|
||||
if model_type and model_type.startswith("kimi"):
|
||||
self.btype["chat"] = const.MOONSHOT
|
||||
|
||||
if model_type and model_type.startswith("doubao"):
|
||||
self.btype["chat"] = const.DOUBAO
|
||||
|
||||
if model_type and model_type.startswith("deepseek"):
|
||||
self.btype["chat"] = const.DEEPSEEK
|
||||
|
||||
# 小米 MiMo 系列模型,全部以 mimo- 开头
|
||||
if model_type and model_type.startswith("mimo-"):
|
||||
self.btype["chat"] = const.MIMO
|
||||
|
||||
if model_type and isinstance(model_type, str):
|
||||
lowered_model_type = model_type.lower()
|
||||
if lowered_model_type == const.QIANFAN or lowered_model_type.startswith("ernie"):
|
||||
self.btype["chat"] = const.QIANFAN
|
||||
|
||||
if model_type in [const.MODELSCOPE]:
|
||||
self.btype["chat"] = const.MODELSCOPE
|
||||
|
||||
if model_type in ["abab6.5-chat"]:
|
||||
# MiniMax models
|
||||
if model_type and (model_type in ["abab6.5-chat", "abab6.5"] or model_type.lower().startswith("minimax")):
|
||||
self.btype["chat"] = const.MiniMax
|
||||
|
||||
if conf().get("use_linkai") and conf().get("linkai_api_key"):
|
||||
@@ -66,6 +90,46 @@ class Bridge(object):
|
||||
self.chat_bots = {}
|
||||
self._agent_bridge = None
|
||||
|
||||
def refresh_voice(self):
|
||||
"""Re-read voice_to_text / text_to_voice from config and drop the
|
||||
cached voice bots so the next call picks up the new provider.
|
||||
Used by the web console after the user edits voice settings.
|
||||
Does NOT touch the agent_bridge / agent state.
|
||||
"""
|
||||
new_v2t = conf().get("voice_to_text") or self._auto_pick_voice_to_text()
|
||||
new_t2v = conf().get("text_to_voice", "google")
|
||||
if conf().get("use_linkai") and conf().get("linkai_api_key"):
|
||||
if not conf().get("voice_to_text") or conf().get("voice_to_text") in ["openai"]:
|
||||
new_v2t = const.LINKAI
|
||||
if not conf().get("text_to_voice") or conf().get("text_to_voice") in ["openai", const.TTS_1, const.TTS_1_HD]:
|
||||
new_t2v = const.LINKAI
|
||||
self.btype["voice_to_text"] = new_v2t
|
||||
self.btype["text_to_voice"] = new_t2v
|
||||
self.bots.pop("voice_to_text", None)
|
||||
self.bots.pop("text_to_voice", None)
|
||||
logger.info(f"[Bridge] voice refreshed: voice_to_text={new_v2t}, text_to_voice={new_t2v}")
|
||||
|
||||
@staticmethod
|
||||
def _auto_pick_voice_to_text() -> str:
|
||||
"""Pick an ASR provider by configured api keys when voice_to_text is
|
||||
unset. Order matches the web console: openai → dashscope → zhipu →
|
||||
linkai. Falls back to 'openai' when nothing is configured so the
|
||||
original "missing key" error is preserved.
|
||||
"""
|
||||
def has(k: str) -> bool:
|
||||
v = (conf().get(k) or "").strip()
|
||||
return v != "" and v not in ("YOUR API KEY", "YOUR_API_KEY")
|
||||
|
||||
for key, provider in (
|
||||
("open_ai_api_key", "openai"),
|
||||
("dashscope_api_key", "dashscope"),
|
||||
("zhipu_ai_api_key", "zhipu"),
|
||||
("linkai_api_key", "linkai"),
|
||||
):
|
||||
if has(key):
|
||||
return provider
|
||||
return "openai"
|
||||
|
||||
# 模型对应的接口
|
||||
def get_bot(self, typename):
|
||||
if self.bots.get(typename) is None:
|
||||
|
||||
@@ -13,12 +13,44 @@ class Channel(object):
|
||||
channel_type = ""
|
||||
NOT_SUPPORT_REPLYTYPE = [ReplyType.VOICE, ReplyType.IMAGE]
|
||||
|
||||
def __init__(self):
|
||||
import threading
|
||||
self._startup_event = threading.Event()
|
||||
self._startup_error = None
|
||||
self.cloud_mode = False # set to True by ChannelManager when running with cloud client
|
||||
|
||||
def startup(self):
|
||||
"""
|
||||
init channel
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def report_startup_success(self):
|
||||
self._startup_error = None
|
||||
self._startup_event.set()
|
||||
|
||||
def report_startup_error(self, error: str):
|
||||
self._startup_error = error
|
||||
self._startup_event.set()
|
||||
|
||||
def wait_startup(self, timeout: float = 3) -> (bool, str):
|
||||
"""
|
||||
Wait for channel startup result.
|
||||
Returns (success: bool, error_msg: str).
|
||||
"""
|
||||
ready = self._startup_event.wait(timeout=timeout)
|
||||
if not ready:
|
||||
return True, ""
|
||||
if self._startup_error:
|
||||
return False, self._startup_error
|
||||
return True, ""
|
||||
|
||||
def stop(self):
|
||||
"""
|
||||
stop channel gracefully, called before restart
|
||||
"""
|
||||
pass
|
||||
|
||||
def handle_text(self, msg):
|
||||
"""
|
||||
process received msg
|
||||
@@ -41,7 +73,7 @@ class Channel(object):
|
||||
Build reply content, using agent if enabled in config
|
||||
"""
|
||||
# Check if agent mode is enabled
|
||||
use_agent = conf().get("agent", False)
|
||||
use_agent = conf().get("agent", True)
|
||||
|
||||
if use_agent:
|
||||
try:
|
||||
@@ -51,11 +83,14 @@ class Channel(object):
|
||||
if context and "channel_type" not in context:
|
||||
context["channel_type"] = self.channel_type
|
||||
|
||||
# Read on_event callback injected by the channel (e.g. web SSE)
|
||||
on_event = context.get("on_event") if context else None
|
||||
|
||||
# Use agent bridge to handle the query
|
||||
return Bridge().fetch_agent_reply(
|
||||
query=query,
|
||||
context=context,
|
||||
on_event=None,
|
||||
on_event=on_event,
|
||||
clear_history=False
|
||||
)
|
||||
except Exception as e:
|
||||
|
||||
@@ -12,16 +12,7 @@ def create_channel(channel_type) -> Channel:
|
||||
:return: channel instance
|
||||
"""
|
||||
ch = Channel()
|
||||
if channel_type == "wx":
|
||||
from channel.wechat.wechat_channel import WechatChannel
|
||||
ch = WechatChannel()
|
||||
elif channel_type == "wxy":
|
||||
from channel.wechat.wechaty_channel import WechatyChannel
|
||||
ch = WechatyChannel()
|
||||
elif channel_type == "wcf":
|
||||
from channel.wechat.wcf_channel import WechatfChannel
|
||||
ch = WechatfChannel()
|
||||
elif channel_type == "terminal":
|
||||
if channel_type == "terminal":
|
||||
from channel.terminal.terminal_channel import TerminalChannel
|
||||
ch = TerminalChannel()
|
||||
elif channel_type == 'web':
|
||||
@@ -36,15 +27,34 @@ def create_channel(channel_type) -> Channel:
|
||||
elif channel_type == "wechatcom_app":
|
||||
from channel.wechatcom.wechatcomapp_channel import WechatComAppChannel
|
||||
ch = WechatComAppChannel()
|
||||
elif channel_type == "wework":
|
||||
from channel.wework.wework_channel import WeworkChannel
|
||||
ch = WeworkChannel()
|
||||
elif channel_type == const.WECHAT_KF:
|
||||
from channel.wechat_kf.wechat_kf_channel import WechatKfChannel
|
||||
ch = WechatKfChannel()
|
||||
elif channel_type == const.FEISHU:
|
||||
from channel.feishu.feishu_channel import FeiShuChanel
|
||||
ch = FeiShuChanel()
|
||||
elif channel_type == const.DINGTALK:
|
||||
from channel.dingtalk.dingtalk_channel import DingTalkChanel
|
||||
ch = DingTalkChanel()
|
||||
elif channel_type == const.WECOM_BOT:
|
||||
from channel.wecom_bot.wecom_bot_channel import WecomBotChannel
|
||||
ch = WecomBotChannel()
|
||||
elif channel_type == const.QQ:
|
||||
from channel.qq.qq_channel import QQChannel
|
||||
ch = QQChannel()
|
||||
elif channel_type == const.TELEGRAM:
|
||||
from channel.telegram.telegram_channel import TelegramChannel
|
||||
ch = TelegramChannel()
|
||||
elif channel_type == const.SLACK:
|
||||
from channel.slack.slack_channel import SlackChannel
|
||||
ch = SlackChannel()
|
||||
elif channel_type == const.DISCORD:
|
||||
from channel.discord.discord_channel import DiscordChannel
|
||||
ch = DiscordChannel()
|
||||
elif channel_type in (const.WEIXIN, "wx"):
|
||||
from channel.weixin.weixin_channel import WeixinChannel
|
||||
ch = WeixinChannel()
|
||||
channel_type = const.WEIXIN
|
||||
else:
|
||||
raise RuntimeError
|
||||
ch.channel_type = channel_type
|
||||
|
||||
@@ -10,6 +10,7 @@ from bridge.reply import *
|
||||
from channel.channel import Channel
|
||||
from common.dequeue import Dequeue
|
||||
from common import memory
|
||||
from common.i18n import t as _t
|
||||
from plugins import *
|
||||
|
||||
try:
|
||||
@@ -24,11 +25,17 @@ handler_pool = ThreadPoolExecutor(max_workers=8) # 处理消息的线程池
|
||||
class ChatChannel(Channel):
|
||||
name = None # 登录的用户名
|
||||
user_id = None # 登录的用户id
|
||||
futures = {} # 记录每个session_id提交到线程池的future对象, 用于重置会话时把没执行的future取消掉,正在执行的不会被取消
|
||||
sessions = {} # 用于控制并发,每个session_id同时只能有一个context在处理
|
||||
lock = threading.Lock() # 用于控制对sessions的访问
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
# Instance-level attributes so each channel subclass has its own
|
||||
# independent session queue and lock. Previously these were class-level,
|
||||
# which caused contexts from one channel (e.g. Feishu) to be consumed
|
||||
# by another channel's consume() thread (e.g. Web), leading to errors
|
||||
# like "No request_id found in context".
|
||||
self.futures = {}
|
||||
self.sessions = {}
|
||||
self.lock = threading.Lock()
|
||||
_thread = threading.Thread(target=self.consume)
|
||||
_thread.setDaemon(True)
|
||||
_thread.start()
|
||||
@@ -37,9 +44,8 @@ class ChatChannel(Channel):
|
||||
def _compose_context(self, ctype: ContextType, content, **kwargs):
|
||||
context = Context(ctype, content)
|
||||
context.kwargs = kwargs
|
||||
# context首次传入时,origin_ctype是None,
|
||||
# 引入的起因是:当输入语音时,会嵌套生成两个context,第一步语音转文本,第二步通过文本生成文字回复。
|
||||
# origin_ctype用于第二步文本回复时,判断是否需要匹配前缀,如果是私聊的语音,就不需要匹配前缀
|
||||
if "channel_type" not in context:
|
||||
context["channel_type"] = self.channel_type
|
||||
if "origin_ctype" not in context:
|
||||
context["origin_ctype"] = ctype
|
||||
# context首次传入时,receiver是None,根据类型设置receiver
|
||||
@@ -166,7 +172,13 @@ class ChatChannel(Channel):
|
||||
if "desire_rtype" not in context and conf().get("always_reply_voice") and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE:
|
||||
context["desire_rtype"] = ReplyType.VOICE
|
||||
elif context.type == ContextType.VOICE:
|
||||
if "desire_rtype" not in context and conf().get("voice_reply_voice") and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE:
|
||||
# Voice input replies with voice when either voice_reply_voice
|
||||
# (mirror voice) or the global always_reply_voice toggle is on.
|
||||
if (
|
||||
"desire_rtype" not in context
|
||||
and (conf().get("voice_reply_voice") or conf().get("always_reply_voice"))
|
||||
and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE
|
||||
):
|
||||
context["desire_rtype"] = ReplyType.VOICE
|
||||
return context
|
||||
|
||||
@@ -254,11 +266,13 @@ class ChatChannel(Channel):
|
||||
if reply.type in self.NOT_SUPPORT_REPLYTYPE:
|
||||
logger.error("[chat_channel]reply type not support: " + str(reply.type))
|
||||
reply.type = ReplyType.ERROR
|
||||
reply.content = "不支持发送的消息类型: " + str(reply.type)
|
||||
reply.content = _t("不支持发送的消息类型: ", "Unsupported message type: ") + str(reply.type)
|
||||
|
||||
if reply.type == ReplyType.TEXT:
|
||||
reply_text = reply.content
|
||||
if desire_rtype == ReplyType.VOICE and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE:
|
||||
# Preserve original text for the "text-then-voice" pattern in _send_reply.
|
||||
context["voice_reply_text"] = reply.content
|
||||
reply = super().build_text_to_voice(reply.content)
|
||||
return self._decorate_reply(context, reply)
|
||||
if context.get("isgroup", False):
|
||||
@@ -292,8 +306,12 @@ class ChatChannel(Channel):
|
||||
logger.debug("[chat_channel] sending reply: {}, context: {}".format(reply, context))
|
||||
|
||||
# 如果是文本回复,尝试提取并发送图片
|
||||
if reply.type == ReplyType.TEXT:
|
||||
# Web channel renders images/videos inline via renderMarkdown,
|
||||
# so skip the extract-and-send step to avoid duplicate media.
|
||||
if reply.type == ReplyType.TEXT and context.get("channel_type") != "web":
|
||||
self._extract_and_send_images(reply, context)
|
||||
elif reply.type == ReplyType.TEXT:
|
||||
self._send(reply, context)
|
||||
# 如果是图片回复但带有文本内容,先发文本再发图片
|
||||
elif reply.type == ReplyType.IMAGE_URL and hasattr(reply, 'text_content') and reply.text_content:
|
||||
# 先发送文本
|
||||
@@ -302,6 +320,15 @@ class ChatChannel(Channel):
|
||||
# 短暂延迟后发送图片
|
||||
time.sleep(0.3)
|
||||
self._send(reply, context)
|
||||
# Send text bubble before voice, unless channel already streamed
|
||||
# the text (feishu) or natively renders STT under the voice (wechatcom).
|
||||
elif reply.type == ReplyType.VOICE and context.get("voice_reply_text") \
|
||||
and not context.get("feishu_streamed") \
|
||||
and context.get("channel_type") not in ("wechatcom_app",):
|
||||
text_reply = Reply(ReplyType.TEXT, context.get("voice_reply_text"))
|
||||
self._send(text_reply, context)
|
||||
time.sleep(0.3)
|
||||
self._send(reply, context)
|
||||
else:
|
||||
self._send(reply, context)
|
||||
|
||||
@@ -342,38 +369,30 @@ class ChatChannel(Channel):
|
||||
if media_items:
|
||||
logger.info(f"[chat_channel] Extracted {len(media_items)} media item(s) from reply")
|
||||
|
||||
# 先发送文本(保持原文本不变)
|
||||
# Send text first (the frontend will embed video players via renderMarkdown).
|
||||
logger.info(f"[chat_channel] Sending text content before media: {reply.content[:100]}...")
|
||||
self._send(reply, context)
|
||||
logger.info(f"[chat_channel] Text sent, now sending {len(media_items)} media item(s)")
|
||||
|
||||
# 然后逐个发送媒体文件
|
||||
for i, (url, media_type) in enumerate(media_items):
|
||||
try:
|
||||
# 判断是本地文件还是URL
|
||||
# Determine whether it is a remote URL or a local file.
|
||||
if url.startswith(('http://', 'https://')):
|
||||
# 网络资源
|
||||
if media_type == 'video':
|
||||
# 视频使用 FILE 类型发送
|
||||
media_reply = Reply(ReplyType.FILE, url)
|
||||
media_reply.file_name = os.path.basename(url)
|
||||
else:
|
||||
# 图片使用 IMAGE_URL 类型
|
||||
media_reply = Reply(ReplyType.IMAGE_URL, url)
|
||||
elif os.path.exists(url):
|
||||
# 本地文件
|
||||
if media_type == 'video':
|
||||
# 视频使用 FILE 类型,转换为 file:// URL
|
||||
media_reply = Reply(ReplyType.FILE, f"file://{url}")
|
||||
media_reply.file_name = os.path.basename(url)
|
||||
else:
|
||||
# 图片使用 IMAGE_URL 类型,转换为 file:// URL
|
||||
media_reply = Reply(ReplyType.IMAGE_URL, f"file://{url}")
|
||||
else:
|
||||
logger.warning(f"[chat_channel] Media file not found or invalid URL: {url}")
|
||||
continue
|
||||
|
||||
# 发送媒体文件(添加小延迟避免频率限制)
|
||||
if i > 0:
|
||||
time.sleep(0.5)
|
||||
self._send(media_reply, context)
|
||||
@@ -420,19 +439,55 @@ class ChatChannel(Channel):
|
||||
|
||||
return func
|
||||
|
||||
# Chat commands that must bypass the per-session serial queue,
|
||||
# otherwise /cancel would queue behind the task it tries to cancel.
|
||||
# Use /cancel (not /stop) to avoid colliding with `cow stop` CLI.
|
||||
_BYPASS_QUEUE_COMMANDS = ("/cancel",)
|
||||
|
||||
def produce(self, context: Context):
|
||||
session_id = context["session_id"]
|
||||
|
||||
# Fast path: /cancel must not enter the queue.
|
||||
if context.type == ContextType.TEXT and context.content:
|
||||
stripped = context.content.strip().lower()
|
||||
if stripped in self._BYPASS_QUEUE_COMMANDS:
|
||||
self._handle_cancel_command(context, session_id)
|
||||
return
|
||||
|
||||
with self.lock:
|
||||
if session_id not in self.sessions:
|
||||
self.sessions[session_id] = [
|
||||
Dequeue(),
|
||||
threading.BoundedSemaphore(conf().get("concurrency_in_session", 4)),
|
||||
threading.BoundedSemaphore(conf().get("concurrency_in_session", 1)),
|
||||
]
|
||||
if context.type == ContextType.TEXT and context.content.startswith("#"):
|
||||
self.sessions[session_id][0].putleft(context) # 优先处理管理命令
|
||||
else:
|
||||
self.sessions[session_id][0].put(context)
|
||||
|
||||
def _handle_cancel_command(self, context: Context, session_id: str) -> None:
|
||||
"""Cancel any in-flight agent run for *session_id* and reply inline.
|
||||
|
||||
Runs synchronously on the caller's thread. Reply is sent through
|
||||
_send_reply so plugins (e.g. logging) still observe it.
|
||||
"""
|
||||
try:
|
||||
from agent.protocol import get_cancel_registry
|
||||
from bridge.reply import Reply, ReplyType
|
||||
|
||||
cancelled = get_cancel_registry().cancel_session(session_id)
|
||||
text = (
|
||||
_t("🛑 已中止", "🛑 Cancelled")
|
||||
if cancelled > 0
|
||||
else _t("当前没有可中止的任务。", "Nothing to cancel.")
|
||||
)
|
||||
logger.info(
|
||||
f"[chat_channel] /cancel fast-path: session={session_id}, cancelled={cancelled}"
|
||||
)
|
||||
self._send_reply(context, Reply(ReplyType.TEXT, text))
|
||||
except Exception as e:
|
||||
logger.warning(f"[chat_channel] /cancel fast-path failed: {e}")
|
||||
|
||||
# 消费者函数,单独线程,用于从消息队列中取出消息并处理
|
||||
def consume(self):
|
||||
while True:
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
"""
|
||||
本类表示聊天消息,用于对itchat和wechaty的消息进行统一的封装。
|
||||
Unified chat message class for different channel implementations.
|
||||
|
||||
填好必填项(群聊6个,非群聊8个),即可接入ChatChannel,并支持插件,参考TerminalChannel
|
||||
|
||||
|
||||
@@ -21,6 +21,7 @@ from dingtalk_stream.card_replier import CardReplier
|
||||
from bridge.context import Context, ContextType
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from channel.chat_channel import ChatChannel
|
||||
from common.utils import expand_path
|
||||
from channel.dingtalk.dingtalk_message import DingTalkMessage
|
||||
from common.expired_dict import ExpiredDict
|
||||
from common.log import logger
|
||||
@@ -85,17 +86,15 @@ def _check(func):
|
||||
|
||||
@singleton
|
||||
class DingTalkChanel(ChatChannel, dingtalk_stream.ChatbotHandler):
|
||||
NOT_SUPPORT_REPLYTYPE = []
|
||||
|
||||
dingtalk_client_id = conf().get('dingtalk_client_id')
|
||||
dingtalk_client_secret = conf().get('dingtalk_client_secret')
|
||||
|
||||
def setup_logger(self):
|
||||
logger = logging.getLogger()
|
||||
handler = logging.StreamHandler()
|
||||
handler.setFormatter(
|
||||
logging.Formatter('%(asctime)s %(name)-8s %(levelname)-8s %(message)s [%(filename)s:%(lineno)d]'))
|
||||
logger.addHandler(handler)
|
||||
logger.setLevel(logging.INFO)
|
||||
return logger
|
||||
# Suppress verbose logs from dingtalk_stream SDK
|
||||
logging.getLogger("dingtalk_stream").setLevel(logging.WARNING)
|
||||
return logging.getLogger("DingTalk")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
@@ -103,6 +102,9 @@ class DingTalkChanel(ChatChannel, dingtalk_stream.ChatbotHandler):
|
||||
self.logger = self.setup_logger()
|
||||
# 历史消息id暂存,用于幂等控制
|
||||
self.receivedMsgs = ExpiredDict(conf().get("expires_in_seconds", 3600))
|
||||
self._stream_client = None
|
||||
self._running = False
|
||||
self._event_loop = None
|
||||
logger.debug("[DingTalk] client_id={}, client_secret={} ".format(
|
||||
self.dingtalk_client_id, self.dingtalk_client_secret))
|
||||
# 无需群校验和前缀
|
||||
@@ -115,12 +117,130 @@ class DingTalkChanel(ChatChannel, dingtalk_stream.ChatbotHandler):
|
||||
# Robot code cache (extracted from incoming messages)
|
||||
self._robot_code = None
|
||||
|
||||
def _open_connection(self, client):
|
||||
"""
|
||||
Open a DingTalk stream connection directly, bypassing SDK's internal error-swallowing.
|
||||
Returns (connection_dict, error_str). On success error_str is empty; on failure
|
||||
connection_dict is None and error_str contains a human-readable message.
|
||||
"""
|
||||
try:
|
||||
resp = requests.post(
|
||||
"https://api.dingtalk.com/v1.0/gateway/connections/open",
|
||||
headers={"Content-Type": "application/json", "Accept": "application/json"},
|
||||
json={
|
||||
"clientId": client.credential.client_id,
|
||||
"clientSecret": client.credential.client_secret,
|
||||
"subscriptions": [{"type": "CALLBACK",
|
||||
"topic": dingtalk_stream.chatbot.ChatbotMessage.TOPIC}],
|
||||
"ua": "dingtalk-sdk-python/cow",
|
||||
"localIp": "",
|
||||
},
|
||||
timeout=10,
|
||||
)
|
||||
body = resp.json()
|
||||
if not resp.ok:
|
||||
code = body.get("code", resp.status_code)
|
||||
message = body.get("message", resp.reason)
|
||||
return None, f"open connection failed: [{code}] {message}"
|
||||
return body, ""
|
||||
except Exception as e:
|
||||
return None, f"open connection failed: {e}"
|
||||
|
||||
def startup(self):
|
||||
import asyncio
|
||||
self.dingtalk_client_id = conf().get('dingtalk_client_id')
|
||||
self.dingtalk_client_secret = conf().get('dingtalk_client_secret')
|
||||
self._running = True
|
||||
credential = dingtalk_stream.Credential(self.dingtalk_client_id, self.dingtalk_client_secret)
|
||||
client = dingtalk_stream.DingTalkStreamClient(credential)
|
||||
self._stream_client = client
|
||||
client.register_callback_handler(dingtalk_stream.chatbot.ChatbotMessage.TOPIC, self)
|
||||
logger.info("[DingTalk] ✅ Stream connected, ready to receive messages")
|
||||
client.start_forever()
|
||||
logger.info("[DingTalk] ✅ Stream client initialized, ready to receive messages")
|
||||
|
||||
# Run the connection loop ourselves instead of delegating to client.start(),
|
||||
# so we can get detailed error messages and respond to stop() quickly.
|
||||
import urllib.parse as _urlparse
|
||||
import websockets as _ws
|
||||
import json as _json
|
||||
client.pre_start()
|
||||
_first_connect = True
|
||||
while self._running:
|
||||
# Open connection using our own request so we get detailed error info.
|
||||
connection, err_msg = self._open_connection(client)
|
||||
|
||||
if connection is None:
|
||||
if _first_connect:
|
||||
logger.warning(f"[DingTalk] {err_msg}")
|
||||
self.report_startup_error(err_msg)
|
||||
_first_connect = False
|
||||
else:
|
||||
logger.warning(f"[DingTalk] {err_msg}, retrying in 10s...")
|
||||
|
||||
# Interruptible sleep: checks _running every 100ms.
|
||||
for _ in range(100):
|
||||
if not self._running:
|
||||
break
|
||||
time.sleep(0.1)
|
||||
continue
|
||||
|
||||
if _first_connect:
|
||||
logger.info("[DingTalk] ✅ Connected to DingTalk stream")
|
||||
self.report_startup_success()
|
||||
_first_connect = False
|
||||
else:
|
||||
logger.info("[DingTalk] Reconnected to DingTalk stream")
|
||||
|
||||
# Run the WebSocket session in an asyncio loop.
|
||||
uri = '%s?ticket=%s' % (
|
||||
connection['endpoint'],
|
||||
_urlparse.quote_plus(connection['ticket'])
|
||||
)
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
self._event_loop = loop
|
||||
try:
|
||||
async def _session():
|
||||
async with _ws.connect(uri) as websocket:
|
||||
client.websocket = websocket
|
||||
async for raw_message in websocket:
|
||||
json_message = _json.loads(raw_message)
|
||||
result = await client.route_message(json_message)
|
||||
if result == dingtalk_stream.DingTalkStreamClient.TAG_DISCONNECT:
|
||||
break
|
||||
|
||||
loop.run_until_complete(_session())
|
||||
except (KeyboardInterrupt, SystemExit):
|
||||
logger.info("[DingTalk] Session loop received stop signal, exiting")
|
||||
break
|
||||
except Exception as e:
|
||||
if not self._running:
|
||||
break
|
||||
logger.warning(f"[DingTalk] Stream session error: {e}, reconnecting in 3s...")
|
||||
for _ in range(30):
|
||||
if not self._running:
|
||||
break
|
||||
time.sleep(0.1)
|
||||
finally:
|
||||
self._event_loop = None
|
||||
try:
|
||||
loop.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
logger.info("[DingTalk] Startup loop exited")
|
||||
|
||||
def stop(self):
|
||||
logger.info("[DingTalk] stop() called, setting _running=False")
|
||||
self._running = False
|
||||
loop = self._event_loop
|
||||
if loop and not loop.is_closed():
|
||||
try:
|
||||
loop.call_soon_threadsafe(loop.stop)
|
||||
logger.info("[DingTalk] Sent stop signal to event loop")
|
||||
except Exception as e:
|
||||
logger.warning(f"[DingTalk] Error stopping event loop: {e}")
|
||||
self._stream_client = None
|
||||
logger.info("[DingTalk] stop() completed")
|
||||
|
||||
def get_access_token(self):
|
||||
"""
|
||||
@@ -276,7 +396,7 @@ class DingTalkChanel(ChatChannel, dingtalk_stream.ChatbotHandler):
|
||||
|
||||
# 保存到临时文件
|
||||
file_name = os.path.basename(file_path) or f"media_{uuid.uuid4()}"
|
||||
workspace_root = os.path.expanduser(conf().get("agent_workspace", "~/cow"))
|
||||
workspace_root = expand_path(conf().get("agent_workspace", "~/cow"))
|
||||
tmp_dir = os.path.join(workspace_root, "tmp")
|
||||
os.makedirs(tmp_dir, exist_ok=True)
|
||||
temp_file = os.path.join(tmp_dir, file_name)
|
||||
@@ -457,23 +577,21 @@ class DingTalkChanel(ChatChannel, dingtalk_stream.ChatbotHandler):
|
||||
async def process(self, callback: dingtalk_stream.CallbackMessage):
|
||||
try:
|
||||
incoming_message = dingtalk_stream.ChatbotMessage.from_dict(callback.data)
|
||||
|
||||
|
||||
# 缓存 robot_code,用于后续图片下载
|
||||
if hasattr(incoming_message, 'robot_code'):
|
||||
self._robot_code_cache = incoming_message.robot_code
|
||||
|
||||
# Debug: 打印完整的 event 数据
|
||||
logger.debug(f"[DingTalk] ===== Incoming Message Debug =====")
|
||||
logger.debug(f"[DingTalk] callback.data keys: {callback.data.keys() if hasattr(callback.data, 'keys') else 'N/A'}")
|
||||
logger.debug(f"[DingTalk] incoming_message attributes: {dir(incoming_message)}")
|
||||
logger.debug(f"[DingTalk] robot_code: {getattr(incoming_message, 'robot_code', 'N/A')}")
|
||||
logger.debug(f"[DingTalk] chatbot_corp_id: {getattr(incoming_message, 'chatbot_corp_id', 'N/A')}")
|
||||
logger.debug(f"[DingTalk] chatbot_user_id: {getattr(incoming_message, 'chatbot_user_id', 'N/A')}")
|
||||
logger.debug(f"[DingTalk] conversation_id: {getattr(incoming_message, 'conversation_id', 'N/A')}")
|
||||
logger.debug(f"[DingTalk] Raw callback.data: {callback.data}")
|
||||
logger.debug(f"[DingTalk] =====================================")
|
||||
|
||||
image_download_handler = self # 传入方法所在的类实例
|
||||
|
||||
# Filter out stale messages from before channel startup (offline backlog)
|
||||
create_at = getattr(incoming_message, 'create_at', None)
|
||||
if create_at:
|
||||
msg_age_s = time.time() - int(create_at) / 1000
|
||||
if msg_age_s > 60:
|
||||
logger.warning(f"[DingTalk] stale msg filtered (age={msg_age_s:.0f}s), "
|
||||
f"msg_id={getattr(incoming_message, 'message_id', 'N/A')}")
|
||||
return AckMessage.STATUS_OK, 'OK'
|
||||
|
||||
image_download_handler = self
|
||||
dingtalk_msg = DingTalkMessage(incoming_message, image_download_handler)
|
||||
|
||||
if dingtalk_msg.is_group:
|
||||
@@ -482,8 +600,7 @@ class DingTalkChanel(ChatChannel, dingtalk_stream.ChatbotHandler):
|
||||
self.handle_single(dingtalk_msg)
|
||||
return AckMessage.STATUS_OK, 'OK'
|
||||
except Exception as e:
|
||||
logger.error(f"[DingTalk] process error: {e}")
|
||||
logger.exception(e) # 打印完整堆栈跟踪
|
||||
logger.error(f"[DingTalk] process error: {e}", exc_info=True)
|
||||
return AckMessage.STATUS_SYSTEM_EXCEPTION, 'ERROR'
|
||||
|
||||
@time_checker
|
||||
@@ -607,7 +724,7 @@ class DingTalkChanel(ChatChannel, dingtalk_stream.ChatbotHandler):
|
||||
|
||||
|
||||
def send(self, reply: Reply, context: Context):
|
||||
logger.info(f"[DingTalk] send() called with reply.type={reply.type}, content_length={len(str(reply.content))}")
|
||||
logger.debug(f"[DingTalk] send() called with reply.type={reply.type}, content_length={len(str(reply.content))}")
|
||||
receiver = context["receiver"]
|
||||
|
||||
# Check if msg exists (for scheduled tasks, msg might be None)
|
||||
@@ -647,7 +764,7 @@ class DingTalkChanel(ChatChannel, dingtalk_stream.ChatbotHandler):
|
||||
robot_code = msg.robot_code
|
||||
if robot_code and robot_code != self._robot_code:
|
||||
self._robot_code = robot_code
|
||||
logger.info(f"[DingTalk] Cached robot_code: {robot_code}")
|
||||
logger.debug(f"[DingTalk] Cached robot_code: {robot_code}")
|
||||
|
||||
isgroup = msg.is_group
|
||||
incoming_message = msg.incoming_message
|
||||
@@ -755,6 +872,48 @@ class DingTalkChanel(ChatChannel, dingtalk_stream.ChatbotHandler):
|
||||
self.reply_text("抱歉,文件上传失败", incoming_message)
|
||||
return
|
||||
|
||||
# Native sampleAudio. Upload only accepts ogg/amr, so convert TTS mp3/wav to amr.
|
||||
elif reply.type == ReplyType.VOICE:
|
||||
logger.info(f"[DingTalk] Sending voice: {reply.content}")
|
||||
access_token = self.get_access_token()
|
||||
if not access_token:
|
||||
logger.error("[DingTalk] Cannot get access token for voice")
|
||||
self.reply_text("抱歉,语音发送失败(无法获取token)", incoming_message)
|
||||
return
|
||||
|
||||
voice_path = reply.content
|
||||
if voice_path.startswith("file://"):
|
||||
voice_path = voice_path[7:]
|
||||
|
||||
amr_path = voice_path
|
||||
duration_ms = 0
|
||||
if not voice_path.lower().endswith((".amr", ".ogg")):
|
||||
try:
|
||||
from voice.audio_convert import any_to_amr
|
||||
amr_path = os.path.splitext(voice_path)[0] + ".amr"
|
||||
duration_ms = int(any_to_amr(voice_path, amr_path) or 0)
|
||||
except Exception as e:
|
||||
logger.error(f"[DingTalk] Failed to convert voice to amr: {e}")
|
||||
self.reply_text("抱歉,语音转码失败", incoming_message)
|
||||
return
|
||||
|
||||
media_id = self.upload_media(amr_path, media_type="voice")
|
||||
if not media_id:
|
||||
logger.error("[DingTalk] Failed to upload voice media")
|
||||
self.reply_text("抱歉,语音上传失败", incoming_message)
|
||||
return
|
||||
|
||||
msg_param = {
|
||||
"mediaId": media_id,
|
||||
"duration": str(duration_ms or 1000),
|
||||
}
|
||||
success = self._send_file_message(
|
||||
access_token, incoming_message, "sampleAudio", msg_param, isgroup
|
||||
)
|
||||
if not success:
|
||||
self.reply_text("抱歉,语音发送失败", incoming_message)
|
||||
return
|
||||
|
||||
# 处理文本消息
|
||||
elif reply.type == ReplyType.TEXT:
|
||||
logger.info(f"[DingTalk] Sending text message, length={len(reply.content)}")
|
||||
|
||||
@@ -9,6 +9,7 @@ from channel.chat_message import ChatMessage
|
||||
# -*- coding=utf-8 -*-
|
||||
from common.log import logger
|
||||
from common.tmp_dir import TmpDir
|
||||
from common.utils import expand_path
|
||||
from config import conf
|
||||
|
||||
|
||||
@@ -49,7 +50,7 @@ class DingTalkMessage(ChatMessage):
|
||||
download_url = image_download_handler.get_image_download_url(download_code)
|
||||
|
||||
# 下载到工作空间 tmp 目录
|
||||
workspace_root = os.path.expanduser(conf().get("agent_workspace", "~/cow"))
|
||||
workspace_root = expand_path(conf().get("agent_workspace", "~/cow"))
|
||||
tmp_dir = os.path.join(workspace_root, "tmp")
|
||||
os.makedirs(tmp_dir, exist_ok=True)
|
||||
|
||||
@@ -67,7 +68,7 @@ class DingTalkMessage(ChatMessage):
|
||||
self.ctype = ContextType.TEXT
|
||||
|
||||
# 下载到工作空间 tmp 目录
|
||||
workspace_root = os.path.expanduser(conf().get("agent_workspace", "~/cow"))
|
||||
workspace_root = expand_path(conf().get("agent_workspace", "~/cow"))
|
||||
tmp_dir = os.path.join(workspace_root, "tmp")
|
||||
os.makedirs(tmp_dir, exist_ok=True)
|
||||
|
||||
|
||||
0
channel/discord/__init__.py
Normal file
0
channel/discord/__init__.py
Normal file
500
channel/discord/discord_channel.py
Normal file
500
channel/discord/discord_channel.py
Normal file
@@ -0,0 +1,500 @@
|
||||
"""
|
||||
Discord channel via the Gateway (WebSocket) using discord.py.
|
||||
|
||||
Features:
|
||||
- Direct message & guild channel chat (text / image / file)
|
||||
- Guild trigger: @mention or reply-to-bot (configurable)
|
||||
- /cancel fast-path matches Web channel behaviour
|
||||
- Gateway long connection: no public IP / callback URL required, works behind NAT
|
||||
|
||||
Implementation note:
|
||||
discord.py is async-first. We run the client inside a dedicated thread
|
||||
with its own asyncio loop so the rest of cow (which is sync) stays
|
||||
untouched. Inbound messages are dispatched onto cow's existing sync
|
||||
ChatChannel.produce() pipeline; outbound send() schedules coroutines
|
||||
back onto that loop via asyncio.run_coroutine_threadsafe.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import re
|
||||
import threading
|
||||
|
||||
from bridge.context import Context, ContextType
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from channel.chat_channel import ChatChannel, check_prefix
|
||||
from channel.discord.discord_message import DiscordMessage
|
||||
from common.expired_dict import ExpiredDict
|
||||
from common.log import logger
|
||||
from common.singleton import singleton
|
||||
from config import conf
|
||||
|
||||
# Discord caps a single message at 2000 chars; split conservatively below.
|
||||
DISCORD_MSG_LIMIT = 1900
|
||||
|
||||
|
||||
@singleton
|
||||
class DiscordChannel(ChatChannel):
|
||||
NOT_SUPPORT_REPLYTYPE = []
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.bot_token = ""
|
||||
self.bot_user_id = "" # used to strip @mention and ignore self messages
|
||||
self.bot_username = ""
|
||||
self._client = None
|
||||
self._loop = None
|
||||
self._loop_thread = None
|
||||
self._stop_event = threading.Event()
|
||||
# Idempotent dedup; guard against rare duplicate dispatch
|
||||
self._received_msgs = ExpiredDict(60 * 60 * 1)
|
||||
|
||||
# Disable group whitelist / prefix checks (we handle triggering ourselves
|
||||
# in _should_reply_in_guild), aligned with telegram / slack channels.
|
||||
conf()["group_name_white_list"] = ["ALL_GROUP"]
|
||||
conf()["single_chat_prefix"] = [""]
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Lifecycle
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def startup(self):
|
||||
self.bot_token = conf().get("discord_token", "")
|
||||
if not self.bot_token:
|
||||
err = "[Discord] discord_token is required"
|
||||
logger.error(err)
|
||||
self.report_startup_error(err)
|
||||
return
|
||||
|
||||
try:
|
||||
import discord
|
||||
except ImportError:
|
||||
err = (
|
||||
"[Discord] discord.py is not installed. "
|
||||
"Run: pip install discord.py"
|
||||
)
|
||||
logger.error(err)
|
||||
self.report_startup_error(err)
|
||||
return
|
||||
|
||||
# Run the asyncio event loop in a dedicated thread so the sync cow body
|
||||
# is untouched.
|
||||
self._loop = asyncio.new_event_loop()
|
||||
|
||||
def _run_loop():
|
||||
asyncio.set_event_loop(self._loop)
|
||||
try:
|
||||
self._loop.run_until_complete(self._async_main(discord))
|
||||
except Exception as e:
|
||||
logger.error(f"[Discord] event loop crashed: {e}", exc_info=True)
|
||||
self.report_startup_error(str(e))
|
||||
finally:
|
||||
try:
|
||||
self._loop.close()
|
||||
except Exception:
|
||||
pass
|
||||
logger.info("[Discord] event loop exited")
|
||||
|
||||
self._loop_thread = threading.Thread(target=_run_loop, daemon=True, name="discord-loop")
|
||||
self._loop_thread.start()
|
||||
# Block startup() until the loop thread exits, matching other channels'
|
||||
# behaviour (startup is a blocking call).
|
||||
self._loop_thread.join()
|
||||
|
||||
async def _async_main(self, discord):
|
||||
"""Build the discord client, register handlers, and connect to the Gateway."""
|
||||
# message_content is a privileged intent; it must be enabled in the
|
||||
# Developer Portal (Bot -> Privileged Gateway Intents) to read text.
|
||||
intents = discord.Intents.default()
|
||||
intents.message_content = True
|
||||
client = discord.Client(intents=intents)
|
||||
self._client = client
|
||||
|
||||
channel = self
|
||||
|
||||
@client.event
|
||||
async def on_ready():
|
||||
channel.bot_user_id = str(client.user.id)
|
||||
channel.bot_username = client.user.name or ""
|
||||
channel.name = channel.bot_user_id # ChatChannel uses self.name to strip @-mention
|
||||
logger.info(f"[Discord] Bot logged in as {client.user} (id={client.user.id})")
|
||||
channel.report_startup_success()
|
||||
logger.info("[Discord] ✅ Discord bot ready, listening for messages")
|
||||
|
||||
@client.event
|
||||
async def on_message(message):
|
||||
await channel._on_message(message)
|
||||
|
||||
# Connect to the Gateway; discord.py auto-reconnects on transient errors.
|
||||
logger.info("[Discord] Connecting to Gateway...")
|
||||
|
||||
# client.start() handles login + Gateway connection and runs until
|
||||
# close(); it is the standard entrypoint across discord.py versions.
|
||||
runner_task = asyncio.create_task(client.start(self.bot_token))
|
||||
|
||||
# Block until stop()
|
||||
try:
|
||||
while not self._stop_event.is_set():
|
||||
if runner_task.done():
|
||||
# Surface a startup/connection failure (e.g. bad token)
|
||||
exc = runner_task.exception()
|
||||
if exc:
|
||||
logger.error(f"[Discord] client stopped: {exc}", exc_info=exc)
|
||||
self.report_startup_error(str(exc))
|
||||
break
|
||||
await asyncio.sleep(0.5)
|
||||
finally:
|
||||
try:
|
||||
if not client.is_closed():
|
||||
await client.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"[Discord] shutdown error: {e}")
|
||||
|
||||
def stop(self):
|
||||
logger.info("[Discord] stop() called")
|
||||
self._stop_event.set()
|
||||
if self._loop_thread and self._loop_thread.is_alive():
|
||||
try:
|
||||
self._loop_thread.join(timeout=10)
|
||||
except Exception:
|
||||
pass
|
||||
logger.info("[Discord] stop() completed")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Inbound: discord message -> ChatMessage -> ChatChannel.produce
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _on_message(self, message):
|
||||
"""Discord message entry: parse -> build ChatMessage -> produce()."""
|
||||
try:
|
||||
# Ignore our own messages and other bots. self._client.user may be
|
||||
# None until on_ready completes, so guard against that.
|
||||
if self._client and self._client.user and message.author.id == self._client.user.id:
|
||||
return
|
||||
if message.author.bot:
|
||||
return
|
||||
|
||||
# Idempotent dedup
|
||||
msg_uid = f"{message.channel.id}:{message.id}"
|
||||
if self._received_msgs.get(msg_uid):
|
||||
return
|
||||
self._received_msgs[msg_uid] = True
|
||||
|
||||
# guild is None for DMs
|
||||
is_group = message.guild is not None
|
||||
|
||||
# Guild trigger gate (silently drop if not triggered)
|
||||
if is_group and not self._should_reply_in_guild(message):
|
||||
logger.debug(f"[Discord] guild message not triggered (need @mention or reply), skip")
|
||||
return
|
||||
|
||||
# Parse message type + download attachments if needed.
|
||||
ctype, content, caption = await self._parse_message(message)
|
||||
if ctype is None:
|
||||
logger.debug(f"[Discord] unsupported message type, skip. msg_id={message.id}")
|
||||
return
|
||||
|
||||
# Strip the bot mention from guild text/caption
|
||||
if is_group:
|
||||
if ctype == ContextType.TEXT and content:
|
||||
content = self._strip_at_mention(content)
|
||||
if caption:
|
||||
caption = self._strip_at_mention(caption)
|
||||
|
||||
dc_msg = DiscordMessage(
|
||||
message,
|
||||
is_group=is_group,
|
||||
bot_user_id=self.bot_user_id,
|
||||
ctype=ctype,
|
||||
content=content,
|
||||
)
|
||||
dc_msg.is_at = is_group # if we reached here in a guild, bot is mentioned/replied
|
||||
|
||||
from channel.file_cache import get_file_cache
|
||||
file_cache = get_file_cache()
|
||||
session_id = self._compute_session_id(message, is_group)
|
||||
|
||||
# Media + caption together: treat as a complete query and bypass the cache
|
||||
if ctype in (ContextType.IMAGE, ContextType.FILE) and caption:
|
||||
tag = "image" if ctype == ContextType.IMAGE else "file"
|
||||
merged_text = f"{caption}\n[{tag}: {content}]"
|
||||
dc_msg.ctype = ContextType.TEXT
|
||||
dc_msg.content = merged_text
|
||||
ctype = ContextType.TEXT
|
||||
logger.info(f"[Discord] Media+caption merged for session {session_id}")
|
||||
# fallthrough to the TEXT branch below
|
||||
|
||||
elif ctype == ContextType.IMAGE:
|
||||
file_cache.add(session_id, content, file_type="image")
|
||||
logger.info(f"[Discord] Image cached for session {session_id}, waiting for query...")
|
||||
return
|
||||
elif ctype == ContextType.FILE:
|
||||
file_cache.add(session_id, content, file_type="file")
|
||||
logger.info(f"[Discord] File cached for session {session_id}: {content}")
|
||||
return
|
||||
|
||||
if ctype == ContextType.TEXT:
|
||||
# Fast-path: /cancel mirrors Web channel behaviour
|
||||
if (content or "").strip().lower() in ("/cancel", "cancel"):
|
||||
await self._do_cancel(session_id, message)
|
||||
return
|
||||
|
||||
cached_files = file_cache.get(session_id)
|
||||
if cached_files:
|
||||
refs = []
|
||||
for fi in cached_files:
|
||||
ftype = fi["type"]
|
||||
tag = ftype if ftype in ("image", "video") else "file"
|
||||
refs.append(f"[{tag}: {fi['path']}]")
|
||||
dc_msg.content = (dc_msg.content or "") + "\n" + "\n".join(refs)
|
||||
file_cache.clear(session_id)
|
||||
logger.info(f"[Discord] Attached {len(cached_files)} cached file(s) to query")
|
||||
|
||||
context = self._compose_context(
|
||||
dc_msg.ctype,
|
||||
dc_msg.content,
|
||||
isgroup=is_group,
|
||||
msg=dc_msg,
|
||||
# Replies use Discord's reply mechanism, no manual @mention needed
|
||||
no_need_at=True,
|
||||
)
|
||||
if context:
|
||||
context["session_id"] = session_id
|
||||
context["receiver"] = str(message.channel.id)
|
||||
context["discord_channel_id"] = message.channel.id
|
||||
context["discord_reply_to_msg_id"] = message.id if is_group else None
|
||||
self.produce(context)
|
||||
logger.debug(f"[Discord] received: type={ctype}, content={str(dc_msg.content)[:80]}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Discord] _on_message error: {e}", exc_info=True)
|
||||
|
||||
async def _do_cancel(self, session_id: str, message):
|
||||
"""Fast-path: /cancel calls cancel_session directly without going through agent."""
|
||||
try:
|
||||
from agent.protocol import get_cancel_registry
|
||||
cancelled = get_cancel_registry().cancel_session(session_id)
|
||||
text = "Current task cancelled." if cancelled else "No running task to cancel."
|
||||
await message.channel.send(text)
|
||||
logger.info(f"[Discord] /cancel session={session_id}, cancelled={cancelled}")
|
||||
except Exception as e:
|
||||
logger.error(f"[Discord] /cancel error: {e}", exc_info=True)
|
||||
|
||||
async def _parse_message(self, message):
|
||||
"""Parse a discord message and return (ctype, content, caption).
|
||||
|
||||
- content is text for ContextType.TEXT, otherwise the local file path
|
||||
- caption is the optional text accompanying an attachment; empty for plain text
|
||||
"""
|
||||
text = (message.content or "").strip()
|
||||
attachments = message.attachments or []
|
||||
|
||||
if attachments:
|
||||
# Handle the first attachment; caption is the accompanying message text
|
||||
att = attachments[0]
|
||||
content_type = (att.content_type or "").lower()
|
||||
name = att.filename or str(att.id)
|
||||
path = await self._download_attachment(att, name)
|
||||
if not path:
|
||||
return (None, None, "")
|
||||
is_image = content_type.startswith("image/") or name.lower().endswith(
|
||||
(".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp")
|
||||
)
|
||||
if is_image:
|
||||
return (ContextType.IMAGE, path, text)
|
||||
return (ContextType.FILE, path, text)
|
||||
|
||||
if text:
|
||||
return (ContextType.TEXT, text, "")
|
||||
|
||||
return (None, None, "")
|
||||
|
||||
async def _download_attachment(self, attachment, name: str):
|
||||
"""Download a discord attachment into the local tmp dir; return path or None."""
|
||||
try:
|
||||
tmp_dir = DiscordMessage.get_tmp_dir()
|
||||
safe_name = re.sub(r"[^\w.\-]", "_", name)
|
||||
# Prefix with attachment id to avoid name collisions
|
||||
local_path = os.path.join(tmp_dir, f"{attachment.id}_{safe_name}")
|
||||
await attachment.save(local_path)
|
||||
logger.debug(f"[Discord] downloaded {name} -> {local_path}")
|
||||
return local_path
|
||||
except Exception as e:
|
||||
logger.error(f"[Discord] download_attachment failed ({name}): {e}")
|
||||
return None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Guild trigger logic
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _should_reply_in_guild(self, message) -> bool:
|
||||
"""Decide whether to reply to a guild channel message based on configuration."""
|
||||
mode = conf().get("discord_group_trigger", "mention_or_reply")
|
||||
if mode == "all":
|
||||
return True
|
||||
|
||||
# self._client.user may be None until on_ready completes
|
||||
if not self._client or not self._client.user:
|
||||
return False
|
||||
|
||||
# 1) Mentioned (direct @bot, not @everyone / @role)
|
||||
if self._client.user in message.mentions:
|
||||
return True
|
||||
|
||||
# 2) Reply to a bot message
|
||||
if mode == "mention_or_reply":
|
||||
ref = message.reference
|
||||
resolved = getattr(ref, "resolved", None) if ref else None
|
||||
if resolved and getattr(resolved, "author", None):
|
||||
if resolved.author.id == self._client.user.id:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _strip_at_mention(self, content: str) -> str:
|
||||
"""Strip <@BOT_ID> / <@!BOT_ID> from guild text."""
|
||||
if not content or not self.bot_user_id:
|
||||
return content
|
||||
pattern = re.compile(r"<@!?" + re.escape(self.bot_user_id) + r">")
|
||||
return pattern.sub("", content).strip()
|
||||
|
||||
@staticmethod
|
||||
def _compute_session_id(message, is_group: bool) -> str:
|
||||
channel_id = message.channel.id
|
||||
user_id = message.author.id
|
||||
if is_group:
|
||||
if conf().get("group_shared_session", True):
|
||||
return f"discord_channel_{channel_id}"
|
||||
return f"discord_channel_{channel_id}_{user_id}"
|
||||
return f"discord_user_{user_id}"
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Override _compose_context: skip the parent's group whitelist/at checks
|
||||
# (already handled via _should_reply_in_guild). Same idea as telegram / slack.
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _compose_context(self, ctype: ContextType, content, **kwargs):
|
||||
context = Context(ctype, content)
|
||||
context.kwargs = kwargs
|
||||
if "channel_type" not in context:
|
||||
context["channel_type"] = self.channel_type
|
||||
if "origin_ctype" not in context:
|
||||
context["origin_ctype"] = ctype
|
||||
|
||||
cmsg = context["msg"]
|
||||
if cmsg.is_group:
|
||||
if conf().get("group_shared_session", True):
|
||||
context["session_id"] = cmsg.other_user_id
|
||||
else:
|
||||
context["session_id"] = f"{cmsg.from_user_id}:{cmsg.other_user_id}"
|
||||
else:
|
||||
context["session_id"] = cmsg.from_user_id
|
||||
context["receiver"] = cmsg.other_user_id
|
||||
|
||||
if ctype == ContextType.TEXT:
|
||||
img_match_prefix = check_prefix(content, conf().get("image_create_prefix"))
|
||||
if img_match_prefix:
|
||||
content = content.replace(img_match_prefix, "", 1)
|
||||
context.type = ContextType.IMAGE_CREATE
|
||||
else:
|
||||
context.type = ContextType.TEXT
|
||||
context.content = (content or "").strip()
|
||||
if "desire_rtype" not in context and conf().get("always_reply_voice"):
|
||||
context["desire_rtype"] = ReplyType.VOICE
|
||||
elif ctype == ContextType.VOICE:
|
||||
if "desire_rtype" not in context and (
|
||||
conf().get("voice_reply_voice") or conf().get("always_reply_voice")
|
||||
):
|
||||
context["desire_rtype"] = ReplyType.VOICE
|
||||
|
||||
return context
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Outbound: ChatChannel.send -> Discord Gateway/REST
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def send(self, reply: Reply, context: Context):
|
||||
"""Called from cow's sync main thread; marshal the coroutine onto the loop thread."""
|
||||
if self._loop is None or self._client is None:
|
||||
logger.warning("[Discord] client not ready, drop reply")
|
||||
return
|
||||
|
||||
channel_id = context.get("discord_channel_id")
|
||||
if channel_id is None:
|
||||
logger.warning("[Discord] no discord_channel_id in context, drop reply")
|
||||
return
|
||||
|
||||
coro = self._async_send(reply, channel_id)
|
||||
try:
|
||||
future = asyncio.run_coroutine_threadsafe(coro, self._loop)
|
||||
future.result(timeout=180)
|
||||
except Exception as e:
|
||||
logger.error(f"[Discord] send failed: {e}")
|
||||
|
||||
async def _async_send(self, reply: Reply, channel_id):
|
||||
try:
|
||||
import discord
|
||||
|
||||
channel = self._client.get_channel(channel_id)
|
||||
if channel is None:
|
||||
# Not in cache (e.g. DM channel); fetch it explicitly
|
||||
channel = await self._client.fetch_channel(channel_id)
|
||||
|
||||
rtype = reply.type
|
||||
content = reply.content
|
||||
|
||||
if rtype in (ReplyType.TEXT, ReplyType.INFO, ReplyType.ERROR):
|
||||
text = str(content) if content is not None else ""
|
||||
if not text:
|
||||
return
|
||||
for chunk in _split_text(text, DISCORD_MSG_LIMIT):
|
||||
await channel.send(chunk)
|
||||
|
||||
elif rtype == ReplyType.IMAGE:
|
||||
# Already a local BytesIO; send it directly
|
||||
content.seek(0)
|
||||
await channel.send(file=discord.File(content, filename="image.png"))
|
||||
|
||||
elif rtype == ReplyType.IMAGE_URL:
|
||||
url = str(content)
|
||||
if url.startswith("file://"):
|
||||
local = url[7:]
|
||||
await channel.send(file=discord.File(local))
|
||||
else:
|
||||
# Post the URL as text; Discord will unfurl it as an image preview
|
||||
await channel.send(url)
|
||||
|
||||
elif rtype in (ReplyType.VOICE, ReplyType.FILE):
|
||||
local = content[7:] if isinstance(content, str) and content.startswith("file://") else content
|
||||
caption = getattr(reply, "text_content", None) or None
|
||||
await channel.send(content=caption, file=discord.File(local))
|
||||
|
||||
else:
|
||||
# Fallback: send as plain text
|
||||
await channel.send(str(content))
|
||||
|
||||
logger.info(f"[Discord] sent reply (type={rtype}, channel={channel_id})")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Discord] _async_send error: {e}", exc_info=True)
|
||||
|
||||
|
||||
def _split_text(text: str, limit: int):
|
||||
"""Split long text preferring line breaks to keep markdown structure intact."""
|
||||
if len(text) <= limit:
|
||||
yield text
|
||||
return
|
||||
buf = []
|
||||
size = 0
|
||||
for line in text.splitlines(keepends=True):
|
||||
if size + len(line) > limit and buf:
|
||||
yield "".join(buf)
|
||||
buf, size = [], 0
|
||||
# Hard-split single lines that exceed the limit
|
||||
while len(line) > limit:
|
||||
yield line[:limit]
|
||||
line = line[limit:]
|
||||
buf.append(line)
|
||||
size += len(line)
|
||||
if buf:
|
||||
yield "".join(buf)
|
||||
60
channel/discord/discord_message.py
Normal file
60
channel/discord/discord_message.py
Normal file
@@ -0,0 +1,60 @@
|
||||
"""
|
||||
Discord message adapter.
|
||||
|
||||
Convert a discord.py Message into cow's unified ChatMessage.
|
||||
File downloads are NOT performed here; the channel layer downloads
|
||||
attachments on demand inside the async event loop.
|
||||
"""
|
||||
import os
|
||||
|
||||
from bridge.context import ContextType
|
||||
from channel.chat_message import ChatMessage
|
||||
from common.utils import expand_path
|
||||
from config import conf
|
||||
|
||||
|
||||
class DiscordMessage(ChatMessage):
|
||||
"""Wrap a discord.py Message into the unified ChatMessage."""
|
||||
|
||||
def __init__(self, message, is_group: bool = False, bot_user_id: str = "",
|
||||
ctype: ContextType = ContextType.TEXT, content: str = ""):
|
||||
super().__init__(message)
|
||||
# Basic fields
|
||||
self.msg_id = str(message.id)
|
||||
self.create_time = int(message.created_at.timestamp()) if message.created_at else 0
|
||||
self.ctype = ctype
|
||||
self.content = content
|
||||
|
||||
author = message.author
|
||||
channel = message.channel
|
||||
|
||||
# Sender / chat info
|
||||
from_user_id = str(author.id)
|
||||
from_user_nick = getattr(author, "display_name", None) or getattr(author, "name", None) or from_user_id
|
||||
self.from_user_id = from_user_id
|
||||
self.from_user_nickname = from_user_nick
|
||||
self.to_user_id = bot_user_id or "discord_bot"
|
||||
self.to_user_nickname = bot_user_id or "discord_bot"
|
||||
|
||||
self.is_group = is_group
|
||||
if is_group:
|
||||
# Guild channel: other_user_id = channel_id, actual_user_id = sender id
|
||||
self.other_user_id = str(channel.id)
|
||||
self.other_user_nickname = getattr(channel, "name", None) or str(channel.id)
|
||||
self.actual_user_id = from_user_id
|
||||
self.actual_user_nickname = from_user_nick
|
||||
else:
|
||||
# DM: use channel_id so replies go back to the same DM channel
|
||||
self.other_user_id = str(channel.id)
|
||||
self.other_user_nickname = from_user_nick
|
||||
|
||||
# Whether the bot was triggered by @-mention (set by channel layer)
|
||||
self.is_at = False
|
||||
|
||||
@staticmethod
|
||||
def get_tmp_dir() -> str:
|
||||
"""Local download directory, aligned with other channels (agent_workspace/tmp)."""
|
||||
workspace_root = expand_path(conf().get("agent_workspace", "~/cow"))
|
||||
tmp_dir = os.path.join(workspace_root, "tmp")
|
||||
os.makedirs(tmp_dir, exist_ok=True)
|
||||
return tmp_dir
|
||||
@@ -140,6 +140,23 @@ python3 app.py
|
||||
|
||||
**解决**: 安装依赖 `pip install lark-oapi`
|
||||
|
||||
### SSL证书验证失败
|
||||
|
||||
```
|
||||
[Lark][ERROR] connect failed, err:[SSL:CERTIFICATE_VERIFY_FAILED] certificate verify failed: self signed certificate in certificate chain
|
||||
```
|
||||
|
||||
**原因**: 网络环境中存在自签名证书或SSL中间人代理(如企业代理、VPN等)
|
||||
|
||||
**解决**: 程序会自动检测SSL证书验证失败,并自动重试禁用证书验证的连接。无需手动配置。
|
||||
|
||||
当遇到证书错误时,日志会显示:
|
||||
```
|
||||
[FeiShu] SSL certificate verification disabled due to certificate error. This may happen when using corporate proxy or self-signed certificates.
|
||||
```
|
||||
|
||||
这是正常现象,程序会自动处理并继续运行。
|
||||
|
||||
### Webhook模式端口被占用
|
||||
|
||||
```
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -6,6 +6,7 @@ import requests
|
||||
from common.log import logger
|
||||
from common.tmp_dir import TmpDir
|
||||
from common import utils
|
||||
from common.utils import expand_path
|
||||
from config import conf
|
||||
|
||||
|
||||
@@ -31,7 +32,7 @@ class FeishuMessage(ChatMessage):
|
||||
image_key = content.get("image_key")
|
||||
|
||||
# 下载图片到工作空间临时目录
|
||||
workspace_root = os.path.expanduser(conf().get("agent_workspace", "~/cow"))
|
||||
workspace_root = expand_path(conf().get("agent_workspace", "~/cow"))
|
||||
tmp_dir = os.path.join(workspace_root, "tmp")
|
||||
os.makedirs(tmp_dir, exist_ok=True)
|
||||
image_path = os.path.join(tmp_dir, f"{image_key}.png")
|
||||
@@ -97,7 +98,7 @@ class FeishuMessage(ChatMessage):
|
||||
|
||||
if image_keys:
|
||||
# 如果包含图片,下载并在文本中引用本地路径
|
||||
workspace_root = os.path.expanduser(conf().get("agent_workspace", "~/cow"))
|
||||
workspace_root = expand_path(conf().get("agent_workspace", "~/cow"))
|
||||
tmp_dir = os.path.join(workspace_root, "tmp")
|
||||
os.makedirs(tmp_dir, exist_ok=True)
|
||||
|
||||
@@ -143,7 +144,14 @@ class FeishuMessage(ChatMessage):
|
||||
file_key = content.get("file_key")
|
||||
file_name = content.get("file_name")
|
||||
|
||||
self.content = TmpDir().path() + file_key + "." + utils.get_path_suffix(file_name)
|
||||
# 落到 agent_workspace/tmp 下(绝对路径),与图片处理一致;
|
||||
# 否则相对路径 ./tmp 在 agent 工作区里 read 时会找不到。
|
||||
workspace_root = expand_path(conf().get("agent_workspace", "~/cow"))
|
||||
tmp_dir = os.path.join(workspace_root, "tmp")
|
||||
os.makedirs(tmp_dir, exist_ok=True)
|
||||
self.content = os.path.join(
|
||||
tmp_dir, f"{file_key}.{utils.get_path_suffix(file_name)}"
|
||||
)
|
||||
|
||||
def _download_file():
|
||||
# 如果响应状态码是200,则将响应内容写入本地文件
|
||||
@@ -161,6 +169,42 @@ class FeishuMessage(ChatMessage):
|
||||
else:
|
||||
logger.info(f"[FeiShu] Failed to download file, key={file_key}, res={response.text}")
|
||||
self._prepare_fn = _download_file
|
||||
elif msg_type == "audio":
|
||||
# 飞书用户发送的语音消息类型为 "audio",文件为 opus 编码格式。
|
||||
# 映射为 ContextType.VOICE,交由 chat_channel 的语音转文字(STT)流程处理。
|
||||
# 文件通过 _prepare_fn 延迟下载,在 chat_channel 调用 cmsg.prepare() 时才执行。
|
||||
self.ctype = ContextType.VOICE
|
||||
content = json.loads(msg.get("content"))
|
||||
file_key = content.get("file_key")
|
||||
|
||||
# 落到 agent_workspace/tmp 下(绝对路径),保证语音 STT 流程可读到
|
||||
workspace_root = expand_path(conf().get("agent_workspace", "~/cow"))
|
||||
tmp_dir = os.path.join(workspace_root, "tmp")
|
||||
os.makedirs(tmp_dir, exist_ok=True)
|
||||
self.content = os.path.join(tmp_dir, f"{file_key}.opus")
|
||||
logger.info(f"[FeiShu] audio message: file_key={file_key}, save_path={self.content}")
|
||||
|
||||
def _download_audio():
|
||||
logger.info(f"[FeiShu] downloading audio: file_key={file_key}, msg_id={self.msg_id}")
|
||||
url = f"https://open.feishu.cn/open-apis/im/v1/messages/{self.msg_id}/resources/{file_key}"
|
||||
headers = {
|
||||
"Authorization": "Bearer " + access_token,
|
||||
}
|
||||
params = {
|
||||
"type": "file"
|
||||
}
|
||||
try:
|
||||
response = requests.get(url=url, headers=headers, params=params)
|
||||
logger.info(f"[FeiShu] download audio response: status={response.status_code}, size={len(response.content)} bytes")
|
||||
if response.status_code == 200:
|
||||
with open(self.content, "wb") as f:
|
||||
f.write(response.content)
|
||||
logger.info(f"[FeiShu] audio saved to: {self.content}")
|
||||
else:
|
||||
logger.error(f"[FeiShu] Failed to download audio, key={file_key}, status={response.status_code}, res={response.text}")
|
||||
except Exception as e:
|
||||
logger.error(f"[FeiShu] Exception downloading audio, key={file_key}: {e}", exc_info=True)
|
||||
self._prepare_fn = _download_audio
|
||||
else:
|
||||
raise NotImplementedError("Unsupported message type: Type:{} ".format(msg_type))
|
||||
|
||||
|
||||
0
channel/qq/__init__.py
Normal file
0
channel/qq/__init__.py
Normal file
736
channel/qq/qq_channel.py
Normal file
736
channel/qq/qq_channel.py
Normal file
@@ -0,0 +1,736 @@
|
||||
"""
|
||||
QQ Bot channel via WebSocket long connection.
|
||||
|
||||
Supports:
|
||||
- Group chat (@bot), single chat (C2C), guild channel, guild DM
|
||||
- Text / image / file message send & receive
|
||||
- Heartbeat keep-alive and auto-reconnect with session resume
|
||||
"""
|
||||
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
|
||||
import requests
|
||||
import websocket
|
||||
|
||||
from bridge.context import Context, ContextType
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from channel.chat_channel import ChatChannel, check_prefix
|
||||
from channel.qq.qq_message import QQMessage
|
||||
from common.expired_dict import ExpiredDict
|
||||
from common.log import logger
|
||||
from common.singleton import singleton
|
||||
from common.ws_client_compat import websocket_app_run_forever
|
||||
from config import conf
|
||||
|
||||
# Rich media file_type constants
|
||||
QQ_FILE_TYPE_IMAGE = 1
|
||||
QQ_FILE_TYPE_VIDEO = 2
|
||||
QQ_FILE_TYPE_VOICE = 3
|
||||
QQ_FILE_TYPE_FILE = 4
|
||||
|
||||
QQ_API_BASE = "https://api.sgroup.qq.com"
|
||||
|
||||
# Intents: GROUP_AND_C2C_EVENT(1<<25) | PUBLIC_GUILD_MESSAGES(1<<30)
|
||||
DEFAULT_INTENTS = (1 << 25) | (1 << 30)
|
||||
|
||||
# OpCode constants
|
||||
OP_DISPATCH = 0
|
||||
OP_HEARTBEAT = 1
|
||||
OP_IDENTIFY = 2
|
||||
OP_RESUME = 6
|
||||
OP_RECONNECT = 7
|
||||
OP_INVALID_SESSION = 9
|
||||
OP_HELLO = 10
|
||||
OP_HEARTBEAT_ACK = 11
|
||||
|
||||
# Resumable error codes
|
||||
RESUMABLE_CLOSE_CODES = {4008, 4009}
|
||||
|
||||
|
||||
@singleton
|
||||
class QQChannel(ChatChannel):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.app_id = ""
|
||||
self.app_secret = ""
|
||||
|
||||
self._access_token = ""
|
||||
self._token_expires_at = 0
|
||||
|
||||
self._ws = None
|
||||
self._ws_thread = None
|
||||
self._heartbeat_thread = None
|
||||
self._connected = False
|
||||
self._stop_event = threading.Event()
|
||||
self._token_lock = threading.Lock()
|
||||
|
||||
self._session_id = None
|
||||
self._last_seq = None
|
||||
self._heartbeat_interval = 45000
|
||||
self._can_resume = False
|
||||
|
||||
self.received_msgs = ExpiredDict(60 * 60 * 7.1)
|
||||
self._msg_seq_counter = {}
|
||||
|
||||
conf()["group_name_white_list"] = ["ALL_GROUP"]
|
||||
conf()["single_chat_prefix"] = [""]
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Lifecycle
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def startup(self):
|
||||
self.app_id = conf().get("qq_app_id", "")
|
||||
self.app_secret = conf().get("qq_app_secret", "")
|
||||
|
||||
if not self.app_id or not self.app_secret:
|
||||
err = "[QQ] qq_app_id and qq_app_secret are required"
|
||||
logger.error(err)
|
||||
self.report_startup_error(err)
|
||||
return
|
||||
|
||||
self._refresh_access_token()
|
||||
if not self._access_token:
|
||||
err = "[QQ] Failed to get initial access_token"
|
||||
logger.error(err)
|
||||
self.report_startup_error(err)
|
||||
return
|
||||
|
||||
self._stop_event.clear()
|
||||
self._start_ws()
|
||||
|
||||
def stop(self):
|
||||
logger.info("[QQ] stop() called")
|
||||
self._stop_event.set()
|
||||
if self._ws:
|
||||
try:
|
||||
self._ws.close()
|
||||
except Exception:
|
||||
pass
|
||||
self._ws = None
|
||||
self._connected = False
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Access Token
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _refresh_access_token(self):
|
||||
try:
|
||||
resp = requests.post(
|
||||
"https://bots.qq.com/app/getAppAccessToken",
|
||||
json={"appId": self.app_id, "clientSecret": self.app_secret},
|
||||
timeout=10,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
self._access_token = data.get("access_token", "")
|
||||
expires_in = int(data.get("expires_in", 7200))
|
||||
self._token_expires_at = time.time() + expires_in - 60
|
||||
logger.debug(f"[QQ] Access token refreshed, expires_in={expires_in}s")
|
||||
except Exception as e:
|
||||
logger.error(f"[QQ] Failed to refresh access_token: {e}")
|
||||
|
||||
def _get_access_token(self) -> str:
|
||||
with self._token_lock:
|
||||
if time.time() >= self._token_expires_at:
|
||||
self._refresh_access_token()
|
||||
return self._access_token
|
||||
|
||||
def _get_auth_headers(self) -> dict:
|
||||
return {
|
||||
"Authorization": f"QQBot {self._get_access_token()}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# WebSocket connection
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _get_ws_url(self) -> str:
|
||||
try:
|
||||
resp = requests.get(
|
||||
f"{QQ_API_BASE}/gateway",
|
||||
headers=self._get_auth_headers(),
|
||||
timeout=10,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
url = resp.json().get("url", "")
|
||||
logger.debug(f"[QQ] Gateway URL: {url}")
|
||||
return url
|
||||
except Exception as e:
|
||||
logger.error(f"[QQ] Failed to get gateway URL: {e}")
|
||||
return ""
|
||||
|
||||
def _start_ws(self):
|
||||
ws_url = self._get_ws_url()
|
||||
if not ws_url:
|
||||
logger.error("[QQ] Cannot start WebSocket without gateway URL")
|
||||
self.report_startup_error("Failed to get gateway URL")
|
||||
return
|
||||
|
||||
def _on_open(ws):
|
||||
logger.debug("[QQ] WebSocket connected, waiting for Hello...")
|
||||
|
||||
def _on_message(ws, raw):
|
||||
try:
|
||||
data = json.loads(raw)
|
||||
self._handle_ws_message(data)
|
||||
except Exception as e:
|
||||
logger.error(f"[QQ] Failed to handle ws message: {e}", exc_info=True)
|
||||
|
||||
def _on_error(ws, error):
|
||||
logger.error(f"[QQ] WebSocket error: {error}")
|
||||
|
||||
def _on_close(ws, close_status_code, close_msg):
|
||||
logger.warning(f"[QQ] WebSocket closed: status={close_status_code}, msg={close_msg}")
|
||||
self._connected = False
|
||||
if not self._stop_event.is_set():
|
||||
if close_status_code in RESUMABLE_CLOSE_CODES and self._session_id:
|
||||
self._can_resume = True
|
||||
logger.info("[QQ] Will attempt resume in 3s...")
|
||||
time.sleep(3)
|
||||
else:
|
||||
self._can_resume = False
|
||||
logger.info("[QQ] Will reconnect in 5s...")
|
||||
time.sleep(5)
|
||||
if not self._stop_event.is_set():
|
||||
self._start_ws()
|
||||
|
||||
self._ws = websocket.WebSocketApp(
|
||||
ws_url,
|
||||
on_open=_on_open,
|
||||
on_message=_on_message,
|
||||
on_error=_on_error,
|
||||
on_close=_on_close,
|
||||
)
|
||||
|
||||
def run_forever():
|
||||
try:
|
||||
websocket_app_run_forever(self._ws, ping_interval=0, reconnect=0)
|
||||
except (SystemExit, KeyboardInterrupt):
|
||||
logger.info("[QQ] WebSocket thread interrupted")
|
||||
except Exception as e:
|
||||
logger.error(f"[QQ] WebSocket run_forever error: {e}")
|
||||
|
||||
self._ws_thread = threading.Thread(target=run_forever, daemon=True)
|
||||
self._ws_thread.start()
|
||||
self._ws_thread.join()
|
||||
|
||||
def _ws_send(self, data: dict):
|
||||
if self._ws:
|
||||
self._ws.send(json.dumps(data, ensure_ascii=False))
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Identify & Resume & Heartbeat
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _send_identify(self):
|
||||
self._ws_send({
|
||||
"op": OP_IDENTIFY,
|
||||
"d": {
|
||||
"token": f"QQBot {self._get_access_token()}",
|
||||
"intents": DEFAULT_INTENTS,
|
||||
"shard": [0, 1],
|
||||
"properties": {
|
||||
"$os": "linux",
|
||||
"$browser": "chatgpt-on-wechat",
|
||||
"$device": "chatgpt-on-wechat",
|
||||
},
|
||||
},
|
||||
})
|
||||
logger.debug(f"[QQ] Identify sent with intents={DEFAULT_INTENTS}")
|
||||
|
||||
def _send_resume(self):
|
||||
self._ws_send({
|
||||
"op": OP_RESUME,
|
||||
"d": {
|
||||
"token": f"QQBot {self._get_access_token()}",
|
||||
"session_id": self._session_id,
|
||||
"seq": self._last_seq,
|
||||
},
|
||||
})
|
||||
logger.debug(f"[QQ] Resume sent: session_id={self._session_id}, seq={self._last_seq}")
|
||||
|
||||
def _start_heartbeat(self, interval_ms: int):
|
||||
if self._heartbeat_thread and self._heartbeat_thread.is_alive():
|
||||
return
|
||||
self._heartbeat_interval = interval_ms
|
||||
interval_sec = interval_ms / 1000.0
|
||||
|
||||
def heartbeat_loop():
|
||||
while not self._stop_event.is_set() and self._connected:
|
||||
try:
|
||||
self._ws_send({
|
||||
"op": OP_HEARTBEAT,
|
||||
"d": self._last_seq,
|
||||
})
|
||||
except Exception as e:
|
||||
logger.warning(f"[QQ] Heartbeat send failed: {e}")
|
||||
break
|
||||
self._stop_event.wait(interval_sec)
|
||||
|
||||
self._heartbeat_thread = threading.Thread(target=heartbeat_loop, daemon=True)
|
||||
self._heartbeat_thread.start()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Incoming message dispatch
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _handle_ws_message(self, data: dict):
|
||||
op = data.get("op")
|
||||
d = data.get("d")
|
||||
t = data.get("t")
|
||||
s = data.get("s")
|
||||
|
||||
if s is not None:
|
||||
self._last_seq = s
|
||||
|
||||
if op == OP_HELLO:
|
||||
heartbeat_interval = d.get("heartbeat_interval", 45000) if d else 45000
|
||||
logger.debug(f"[QQ] Received Hello, heartbeat_interval={heartbeat_interval}ms")
|
||||
self._heartbeat_interval = heartbeat_interval
|
||||
if self._can_resume and self._session_id:
|
||||
self._send_resume()
|
||||
else:
|
||||
self._send_identify()
|
||||
|
||||
elif op == OP_HEARTBEAT_ACK:
|
||||
pass
|
||||
|
||||
elif op == OP_HEARTBEAT:
|
||||
self._ws_send({"op": OP_HEARTBEAT, "d": self._last_seq})
|
||||
|
||||
elif op == OP_RECONNECT:
|
||||
logger.warning("[QQ] Server requested reconnect")
|
||||
self._can_resume = True
|
||||
if self._ws:
|
||||
self._ws.close()
|
||||
|
||||
elif op == OP_INVALID_SESSION:
|
||||
logger.warning("[QQ] Invalid session, re-identifying...")
|
||||
self._session_id = None
|
||||
self._can_resume = False
|
||||
time.sleep(2)
|
||||
self._send_identify()
|
||||
|
||||
elif op == OP_DISPATCH:
|
||||
if t == "READY":
|
||||
self._session_id = d.get("session_id", "")
|
||||
user = d.get("user", {})
|
||||
bot_name = user.get('username', '')
|
||||
logger.info(f"[QQ] ✅ Connected successfully (bot={bot_name})")
|
||||
self._connected = True
|
||||
self._can_resume = False
|
||||
self._start_heartbeat(self._heartbeat_interval)
|
||||
self.report_startup_success()
|
||||
|
||||
elif t == "RESUMED":
|
||||
logger.info("[QQ] Session resumed successfully")
|
||||
self._connected = True
|
||||
self._can_resume = False
|
||||
self._start_heartbeat(self._heartbeat_interval)
|
||||
|
||||
elif t in ("GROUP_AT_MESSAGE_CREATE", "C2C_MESSAGE_CREATE",
|
||||
"AT_MESSAGE_CREATE", "DIRECT_MESSAGE_CREATE"):
|
||||
self._handle_msg_event(d, t)
|
||||
|
||||
elif t in ("GROUP_ADD_ROBOT", "FRIEND_ADD"):
|
||||
logger.info(f"[QQ] Event: {t}")
|
||||
|
||||
else:
|
||||
logger.debug(f"[QQ] Dispatch event: {t}")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Message event handling
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _handle_msg_event(self, event_data: dict, event_type: str):
|
||||
msg_id = event_data.get("id", "")
|
||||
if self.received_msgs.get(msg_id):
|
||||
logger.debug(f"[QQ] Duplicate msg filtered: {msg_id}")
|
||||
return
|
||||
self.received_msgs[msg_id] = True
|
||||
|
||||
try:
|
||||
qq_msg = QQMessage(event_data, event_type)
|
||||
except NotImplementedError as e:
|
||||
logger.warning(f"[QQ] {e}")
|
||||
return
|
||||
except Exception as e:
|
||||
logger.error(f"[QQ] Failed to parse message: {e}", exc_info=True)
|
||||
return
|
||||
|
||||
is_group = qq_msg.is_group
|
||||
|
||||
from channel.file_cache import get_file_cache
|
||||
file_cache = get_file_cache()
|
||||
|
||||
if is_group:
|
||||
session_id = qq_msg.other_user_id
|
||||
else:
|
||||
session_id = qq_msg.from_user_id
|
||||
|
||||
if qq_msg.ctype == ContextType.IMAGE:
|
||||
if hasattr(qq_msg, "image_path") and qq_msg.image_path:
|
||||
file_cache.add(session_id, qq_msg.image_path, file_type="image")
|
||||
logger.info(f"[QQ] Image cached for session {session_id}")
|
||||
return
|
||||
|
||||
if qq_msg.ctype == ContextType.TEXT:
|
||||
cached_files = file_cache.get(session_id)
|
||||
if cached_files:
|
||||
file_refs = []
|
||||
for fi in cached_files:
|
||||
ftype = fi["type"]
|
||||
fpath = fi["path"]
|
||||
if ftype == "image":
|
||||
file_refs.append(f"[图片: {fpath}]")
|
||||
elif ftype == "video":
|
||||
file_refs.append(f"[视频: {fpath}]")
|
||||
else:
|
||||
file_refs.append(f"[文件: {fpath}]")
|
||||
qq_msg.content = qq_msg.content + "\n" + "\n".join(file_refs)
|
||||
logger.info(f"[QQ] Attached {len(cached_files)} cached file(s)")
|
||||
file_cache.clear(session_id)
|
||||
|
||||
context = self._compose_context(
|
||||
qq_msg.ctype,
|
||||
qq_msg.content,
|
||||
isgroup=is_group,
|
||||
msg=qq_msg,
|
||||
no_need_at=True,
|
||||
)
|
||||
if context:
|
||||
self.produce(context)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# _compose_context
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _compose_context(self, ctype: ContextType, content, **kwargs):
|
||||
context = Context(ctype, content)
|
||||
context.kwargs = kwargs
|
||||
if "channel_type" not in context:
|
||||
context["channel_type"] = self.channel_type
|
||||
if "origin_ctype" not in context:
|
||||
context["origin_ctype"] = ctype
|
||||
|
||||
cmsg = context["msg"]
|
||||
|
||||
if cmsg.is_group:
|
||||
context["session_id"] = cmsg.other_user_id
|
||||
else:
|
||||
context["session_id"] = cmsg.from_user_id
|
||||
|
||||
context["receiver"] = cmsg.other_user_id
|
||||
|
||||
if ctype == ContextType.TEXT:
|
||||
img_match_prefix = check_prefix(content, conf().get("image_create_prefix"))
|
||||
if img_match_prefix:
|
||||
content = content.replace(img_match_prefix, "", 1)
|
||||
context.type = ContextType.IMAGE_CREATE
|
||||
else:
|
||||
context.type = ContextType.TEXT
|
||||
context.content = content.strip()
|
||||
|
||||
return context
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Send reply
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def send(self, reply: Reply, context: Context):
|
||||
msg = context.get("msg")
|
||||
is_group = context.get("isgroup", False)
|
||||
receiver = context.get("receiver", "")
|
||||
|
||||
if not msg:
|
||||
# Active send (e.g. scheduled tasks), no original message to reply to
|
||||
self._active_send_text(reply.content if reply.type == ReplyType.TEXT else str(reply.content),
|
||||
receiver, is_group)
|
||||
return
|
||||
|
||||
event_type = getattr(msg, "event_type", "")
|
||||
msg_id = getattr(msg, "msg_id", "")
|
||||
|
||||
if reply.type == ReplyType.TEXT:
|
||||
self._send_text(reply.content, msg, event_type, msg_id)
|
||||
elif reply.type in (ReplyType.IMAGE_URL, ReplyType.IMAGE):
|
||||
self._send_image(reply.content, msg, event_type, msg_id)
|
||||
elif reply.type == ReplyType.FILE:
|
||||
if hasattr(reply, "text_content") and reply.text_content:
|
||||
self._send_text(reply.text_content, msg, event_type, msg_id)
|
||||
time.sleep(0.3)
|
||||
self._send_file(reply.content, msg, event_type, msg_id)
|
||||
elif reply.type in (ReplyType.VIDEO, ReplyType.VIDEO_URL):
|
||||
self._send_media(reply.content, msg, event_type, msg_id, QQ_FILE_TYPE_VIDEO)
|
||||
else:
|
||||
logger.warning(f"[QQ] Unsupported reply type: {reply.type}, falling back to text")
|
||||
self._send_text(str(reply.content), msg, event_type, msg_id)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Send helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _get_next_msg_seq(self, msg_id: str) -> int:
|
||||
seq = self._msg_seq_counter.get(msg_id, 1)
|
||||
self._msg_seq_counter[msg_id] = seq + 1
|
||||
return seq
|
||||
|
||||
def _build_msg_url_and_base_body(self, msg: QQMessage, event_type: str, msg_id: str):
|
||||
"""Build the API URL and base body dict for sending a message."""
|
||||
if event_type == "GROUP_AT_MESSAGE_CREATE":
|
||||
group_openid = msg._rawmsg.get("group_openid", "")
|
||||
url = f"{QQ_API_BASE}/v2/groups/{group_openid}/messages"
|
||||
body = {
|
||||
"msg_id": msg_id,
|
||||
"msg_seq": self._get_next_msg_seq(msg_id),
|
||||
}
|
||||
return url, body, "group", group_openid
|
||||
|
||||
elif event_type == "C2C_MESSAGE_CREATE":
|
||||
user_openid = msg._rawmsg.get("author", {}).get("user_openid", "") or msg.from_user_id
|
||||
url = f"{QQ_API_BASE}/v2/users/{user_openid}/messages"
|
||||
body = {
|
||||
"msg_id": msg_id,
|
||||
"msg_seq": self._get_next_msg_seq(msg_id),
|
||||
}
|
||||
return url, body, "c2c", user_openid
|
||||
|
||||
elif event_type == "AT_MESSAGE_CREATE":
|
||||
channel_id = msg._rawmsg.get("channel_id", "")
|
||||
url = f"{QQ_API_BASE}/channels/{channel_id}/messages"
|
||||
body = {"msg_id": msg_id}
|
||||
return url, body, "channel", channel_id
|
||||
|
||||
elif event_type == "DIRECT_MESSAGE_CREATE":
|
||||
guild_id = msg._rawmsg.get("guild_id", "")
|
||||
url = f"{QQ_API_BASE}/dms/{guild_id}/messages"
|
||||
body = {"msg_id": msg_id}
|
||||
return url, body, "dm", guild_id
|
||||
|
||||
return None, None, None, None
|
||||
|
||||
def _post_message(self, url: str, body: dict, event_type: str):
|
||||
try:
|
||||
resp = requests.post(url, json=body, headers=self._get_auth_headers(), timeout=10)
|
||||
if resp.status_code in (200, 201, 202, 204):
|
||||
logger.info(f"[QQ] Message sent successfully: event_type={event_type}")
|
||||
else:
|
||||
logger.error(f"[QQ] Failed to send message: status={resp.status_code}, "
|
||||
f"body={resp.text}")
|
||||
except Exception as e:
|
||||
logger.error(f"[QQ] Send message error: {e}")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Active send (no original message, e.g. scheduled tasks)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _active_send_text(self, content: str, receiver: str, is_group: bool):
|
||||
"""Send text without an original message (active push). QQ limits active messages to 4/month per user."""
|
||||
if not receiver:
|
||||
logger.warning("[QQ] No receiver for active send")
|
||||
return
|
||||
if is_group:
|
||||
url = f"{QQ_API_BASE}/v2/groups/{receiver}/messages"
|
||||
else:
|
||||
url = f"{QQ_API_BASE}/v2/users/{receiver}/messages"
|
||||
body = {
|
||||
"content": content,
|
||||
"msg_type": 0,
|
||||
}
|
||||
event_label = "GROUP_ACTIVE" if is_group else "C2C_ACTIVE"
|
||||
self._post_message(url, body, event_label)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Send text
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _send_text(self, content: str, msg: QQMessage, event_type: str, msg_id: str):
|
||||
url, body, _, _ = self._build_msg_url_and_base_body(msg, event_type, msg_id)
|
||||
if not url:
|
||||
logger.warning(f"[QQ] Cannot send reply for event_type: {event_type}")
|
||||
return
|
||||
body["content"] = content
|
||||
body["msg_type"] = 0
|
||||
self._post_message(url, body, event_type)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Rich media upload & send (image / video / file)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _upload_rich_media(self, file_url: str, file_type: int, msg: QQMessage,
|
||||
event_type: str) -> str:
|
||||
"""
|
||||
Upload media via QQ rich media API and return file_info.
|
||||
For group: POST /v2/groups/{group_openid}/files
|
||||
For c2c: POST /v2/users/{openid}/files
|
||||
"""
|
||||
if event_type == "GROUP_AT_MESSAGE_CREATE":
|
||||
group_openid = msg._rawmsg.get("group_openid", "")
|
||||
upload_url = f"{QQ_API_BASE}/v2/groups/{group_openid}/files"
|
||||
elif event_type == "C2C_MESSAGE_CREATE":
|
||||
user_openid = (msg._rawmsg.get("author", {}).get("user_openid", "")
|
||||
or msg.from_user_id)
|
||||
upload_url = f"{QQ_API_BASE}/v2/users/{user_openid}/files"
|
||||
else:
|
||||
logger.warning(f"[QQ] Rich media upload not supported for event_type: {event_type}")
|
||||
return ""
|
||||
|
||||
upload_body = {
|
||||
"file_type": file_type,
|
||||
"url": file_url,
|
||||
"srv_send_msg": False,
|
||||
}
|
||||
|
||||
try:
|
||||
resp = requests.post(
|
||||
upload_url, json=upload_body,
|
||||
headers=self._get_auth_headers(), timeout=30,
|
||||
)
|
||||
if resp.status_code in (200, 201):
|
||||
data = resp.json()
|
||||
file_info = data.get("file_info", "")
|
||||
logger.info(f"[QQ] Rich media uploaded: file_type={file_type}, "
|
||||
f"file_uuid={data.get('file_uuid', '')}")
|
||||
return file_info
|
||||
else:
|
||||
logger.error(f"[QQ] Rich media upload failed: status={resp.status_code}, "
|
||||
f"body={resp.text}")
|
||||
return ""
|
||||
except Exception as e:
|
||||
logger.error(f"[QQ] Rich media upload error: {e}")
|
||||
return ""
|
||||
|
||||
def _upload_rich_media_base64(self, file_path: str, file_type: int, msg: QQMessage,
|
||||
event_type: str) -> str:
|
||||
"""Upload local file via base64 file_data field."""
|
||||
if event_type == "GROUP_AT_MESSAGE_CREATE":
|
||||
group_openid = msg._rawmsg.get("group_openid", "")
|
||||
upload_url = f"{QQ_API_BASE}/v2/groups/{group_openid}/files"
|
||||
elif event_type == "C2C_MESSAGE_CREATE":
|
||||
user_openid = (msg._rawmsg.get("author", {}).get("user_openid", "")
|
||||
or msg.from_user_id)
|
||||
upload_url = f"{QQ_API_BASE}/v2/users/{user_openid}/files"
|
||||
else:
|
||||
logger.warning(f"[QQ] Rich media upload not supported for event_type: {event_type}")
|
||||
return ""
|
||||
|
||||
try:
|
||||
with open(file_path, "rb") as f:
|
||||
file_data = base64.b64encode(f.read()).decode("utf-8")
|
||||
except Exception as e:
|
||||
logger.error(f"[QQ] Failed to read file for upload: {e}")
|
||||
return ""
|
||||
|
||||
upload_body = {
|
||||
"file_type": file_type,
|
||||
"file_data": file_data,
|
||||
"srv_send_msg": False,
|
||||
}
|
||||
|
||||
try:
|
||||
resp = requests.post(
|
||||
upload_url, json=upload_body,
|
||||
headers=self._get_auth_headers(), timeout=30,
|
||||
)
|
||||
if resp.status_code in (200, 201):
|
||||
data = resp.json()
|
||||
file_info = data.get("file_info", "")
|
||||
logger.info(f"[QQ] Rich media uploaded (base64): file_type={file_type}, "
|
||||
f"file_uuid={data.get('file_uuid', '')}")
|
||||
return file_info
|
||||
else:
|
||||
logger.error(f"[QQ] Rich media upload (base64) failed: status={resp.status_code}, "
|
||||
f"body={resp.text}")
|
||||
return ""
|
||||
except Exception as e:
|
||||
logger.error(f"[QQ] Rich media upload (base64) error: {e}")
|
||||
return ""
|
||||
|
||||
def _send_media_msg(self, file_info: str, msg: QQMessage, event_type: str, msg_id: str):
|
||||
"""Send a message with msg_type=7 (rich media) using file_info."""
|
||||
url, body, _, _ = self._build_msg_url_and_base_body(msg, event_type, msg_id)
|
||||
if not url:
|
||||
return
|
||||
body["msg_type"] = 7
|
||||
body["media"] = {"file_info": file_info}
|
||||
self._post_message(url, body, event_type)
|
||||
|
||||
def _send_image(self, img_path_or_url: str, msg: QQMessage, event_type: str, msg_id: str):
|
||||
"""Send image reply. Supports URL and local file path."""
|
||||
if event_type not in ("GROUP_AT_MESSAGE_CREATE", "C2C_MESSAGE_CREATE"):
|
||||
self._send_text(str(img_path_or_url), msg, event_type, msg_id)
|
||||
return
|
||||
|
||||
if img_path_or_url.startswith("file://"):
|
||||
img_path_or_url = img_path_or_url[7:]
|
||||
|
||||
if img_path_or_url.startswith(("http://", "https://")):
|
||||
file_info = self._upload_rich_media(
|
||||
img_path_or_url, QQ_FILE_TYPE_IMAGE, msg, event_type)
|
||||
elif os.path.exists(img_path_or_url):
|
||||
file_info = self._upload_rich_media_base64(
|
||||
img_path_or_url, QQ_FILE_TYPE_IMAGE, msg, event_type)
|
||||
else:
|
||||
logger.error(f"[QQ] Image not found: {img_path_or_url}")
|
||||
self._send_text("[Image send failed]", msg, event_type, msg_id)
|
||||
return
|
||||
|
||||
if file_info:
|
||||
self._send_media_msg(file_info, msg, event_type, msg_id)
|
||||
else:
|
||||
self._send_text("[Image upload failed]", msg, event_type, msg_id)
|
||||
|
||||
def _send_file(self, file_path_or_url: str, msg: QQMessage, event_type: str, msg_id: str):
|
||||
"""Send file reply."""
|
||||
if event_type not in ("GROUP_AT_MESSAGE_CREATE", "C2C_MESSAGE_CREATE"):
|
||||
self._send_text(str(file_path_or_url), msg, event_type, msg_id)
|
||||
return
|
||||
|
||||
if file_path_or_url.startswith("file://"):
|
||||
file_path_or_url = file_path_or_url[7:]
|
||||
|
||||
if file_path_or_url.startswith(("http://", "https://")):
|
||||
file_info = self._upload_rich_media(
|
||||
file_path_or_url, QQ_FILE_TYPE_FILE, msg, event_type)
|
||||
elif os.path.exists(file_path_or_url):
|
||||
file_info = self._upload_rich_media_base64(
|
||||
file_path_or_url, QQ_FILE_TYPE_FILE, msg, event_type)
|
||||
else:
|
||||
logger.error(f"[QQ] File not found: {file_path_or_url}")
|
||||
self._send_text("[File send failed]", msg, event_type, msg_id)
|
||||
return
|
||||
|
||||
if file_info:
|
||||
self._send_media_msg(file_info, msg, event_type, msg_id)
|
||||
else:
|
||||
self._send_text("[File upload failed]", msg, event_type, msg_id)
|
||||
|
||||
def _send_media(self, path_or_url: str, msg: QQMessage, event_type: str,
|
||||
msg_id: str, file_type: int):
|
||||
"""Generic media send for video/voice etc."""
|
||||
if event_type not in ("GROUP_AT_MESSAGE_CREATE", "C2C_MESSAGE_CREATE"):
|
||||
self._send_text(str(path_or_url), msg, event_type, msg_id)
|
||||
return
|
||||
|
||||
if path_or_url.startswith("file://"):
|
||||
path_or_url = path_or_url[7:]
|
||||
|
||||
if path_or_url.startswith(("http://", "https://")):
|
||||
file_info = self._upload_rich_media(path_or_url, file_type, msg, event_type)
|
||||
elif os.path.exists(path_or_url):
|
||||
file_info = self._upload_rich_media_base64(path_or_url, file_type, msg, event_type)
|
||||
else:
|
||||
logger.error(f"[QQ] Media not found: {path_or_url}")
|
||||
return
|
||||
|
||||
if file_info:
|
||||
self._send_media_msg(file_info, msg, event_type, msg_id)
|
||||
else:
|
||||
logger.error(f"[QQ] Media upload failed: {path_or_url}")
|
||||
123
channel/qq/qq_message.py
Normal file
123
channel/qq/qq_message.py
Normal file
@@ -0,0 +1,123 @@
|
||||
import os
|
||||
import requests
|
||||
|
||||
from bridge.context import ContextType
|
||||
from channel.chat_message import ChatMessage
|
||||
from common.log import logger
|
||||
from common.utils import expand_path
|
||||
from config import conf
|
||||
|
||||
|
||||
def _get_tmp_dir() -> str:
|
||||
"""Return the workspace tmp directory (absolute path), creating it if needed."""
|
||||
ws_root = expand_path(conf().get("agent_workspace", "~/cow"))
|
||||
tmp_dir = os.path.join(ws_root, "tmp")
|
||||
os.makedirs(tmp_dir, exist_ok=True)
|
||||
return tmp_dir
|
||||
|
||||
|
||||
class QQMessage(ChatMessage):
|
||||
"""Message wrapper for QQ Bot (websocket long-connection mode)."""
|
||||
|
||||
def __init__(self, event_data: dict, event_type: str):
|
||||
super().__init__(event_data)
|
||||
self.msg_id = event_data.get("id", "")
|
||||
self.create_time = event_data.get("timestamp", "")
|
||||
self.is_group = event_type in ("GROUP_AT_MESSAGE_CREATE",)
|
||||
self.event_type = event_type
|
||||
|
||||
author = event_data.get("author", {})
|
||||
from_user_id = author.get("member_openid", "") or author.get("id", "")
|
||||
group_openid = event_data.get("group_openid", "")
|
||||
|
||||
content = event_data.get("content", "").strip()
|
||||
|
||||
attachments = event_data.get("attachments", [])
|
||||
has_image = any(
|
||||
a.get("content_type", "").startswith("image/") for a in attachments
|
||||
) if attachments else False
|
||||
|
||||
if has_image and not content:
|
||||
self.ctype = ContextType.IMAGE
|
||||
img_attachment = next(
|
||||
a for a in attachments if a.get("content_type", "").startswith("image/")
|
||||
)
|
||||
img_url = img_attachment.get("url", "")
|
||||
if img_url and not img_url.startswith("http"):
|
||||
img_url = "https://" + img_url
|
||||
tmp_dir = _get_tmp_dir()
|
||||
image_path = os.path.join(tmp_dir, f"qq_{self.msg_id}.png")
|
||||
try:
|
||||
resp = requests.get(img_url, timeout=30)
|
||||
resp.raise_for_status()
|
||||
with open(image_path, "wb") as f:
|
||||
f.write(resp.content)
|
||||
self.content = image_path
|
||||
self.image_path = image_path
|
||||
logger.info(f"[QQ] Image downloaded: {image_path}")
|
||||
except Exception as e:
|
||||
logger.error(f"[QQ] Failed to download image: {e}")
|
||||
self.content = "[Image download failed]"
|
||||
self.image_path = None
|
||||
elif has_image and content:
|
||||
self.ctype = ContextType.TEXT
|
||||
image_paths = []
|
||||
tmp_dir = _get_tmp_dir()
|
||||
for idx, att in enumerate(attachments):
|
||||
if not att.get("content_type", "").startswith("image/"):
|
||||
continue
|
||||
img_url = att.get("url", "")
|
||||
if img_url and not img_url.startswith("http"):
|
||||
img_url = "https://" + img_url
|
||||
img_path = os.path.join(tmp_dir, f"qq_{self.msg_id}_{idx}.png")
|
||||
try:
|
||||
resp = requests.get(img_url, timeout=30)
|
||||
resp.raise_for_status()
|
||||
with open(img_path, "wb") as f:
|
||||
f.write(resp.content)
|
||||
image_paths.append(img_path)
|
||||
except Exception as e:
|
||||
logger.error(f"[QQ] Failed to download mixed image: {e}")
|
||||
content_parts = [content]
|
||||
for p in image_paths:
|
||||
content_parts.append(f"[图片: {p}]")
|
||||
self.content = "\n".join(content_parts)
|
||||
else:
|
||||
self.ctype = ContextType.TEXT
|
||||
self.content = content
|
||||
|
||||
if event_type == "GROUP_AT_MESSAGE_CREATE":
|
||||
self.from_user_id = from_user_id
|
||||
self.to_user_id = ""
|
||||
self.other_user_id = group_openid
|
||||
self.actual_user_id = from_user_id
|
||||
self.actual_user_nickname = from_user_id
|
||||
|
||||
elif event_type == "C2C_MESSAGE_CREATE":
|
||||
user_openid = author.get("user_openid", "") or from_user_id
|
||||
self.from_user_id = user_openid
|
||||
self.to_user_id = ""
|
||||
self.other_user_id = user_openid
|
||||
self.actual_user_id = user_openid
|
||||
|
||||
elif event_type == "AT_MESSAGE_CREATE":
|
||||
self.from_user_id = from_user_id
|
||||
self.to_user_id = ""
|
||||
channel_id = event_data.get("channel_id", "")
|
||||
self.other_user_id = channel_id
|
||||
self.actual_user_id = from_user_id
|
||||
self.actual_user_nickname = author.get("username", from_user_id)
|
||||
|
||||
elif event_type == "DIRECT_MESSAGE_CREATE":
|
||||
self.from_user_id = from_user_id
|
||||
self.to_user_id = ""
|
||||
guild_id = event_data.get("guild_id", "")
|
||||
self.other_user_id = f"dm_{guild_id}_{from_user_id}"
|
||||
self.actual_user_id = from_user_id
|
||||
self.actual_user_nickname = author.get("username", from_user_id)
|
||||
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported QQ event type: {event_type}")
|
||||
|
||||
logger.debug(f"[QQ] Message parsed: type={event_type}, ctype={self.ctype}, "
|
||||
f"from={self.from_user_id}, content_len={len(self.content)}")
|
||||
1
channel/slack/__init__.py
Normal file
1
channel/slack/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
|
||||
506
channel/slack/slack_channel.py
Normal file
506
channel/slack/slack_channel.py
Normal file
@@ -0,0 +1,506 @@
|
||||
"""
|
||||
Slack channel via Bolt for Python (Socket Mode).
|
||||
|
||||
Features:
|
||||
- Direct message & channel chat (text / image / file)
|
||||
- Channel trigger: @mention or reply in a thread the bot is in (configurable)
|
||||
- /cancel fast-path matches Web channel behaviour
|
||||
- Socket Mode: no public IP / callback URL required, works behind NAT
|
||||
|
||||
Implementation note:
|
||||
slack_bolt's SocketModeHandler is blocking and runs its own background
|
||||
threads. We start it in a dedicated thread so the rest of cow (sync) stays
|
||||
untouched. Inbound events are dispatched onto cow's existing sync
|
||||
ChatChannel.produce() pipeline; outbound send() calls the Slack Web API
|
||||
client directly (it is sync-safe).
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
import threading
|
||||
|
||||
import requests
|
||||
|
||||
from bridge.context import Context, ContextType
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from channel.chat_channel import ChatChannel, check_prefix
|
||||
from channel.slack.slack_message import SlackMessage
|
||||
from common.expired_dict import ExpiredDict
|
||||
from common.log import logger
|
||||
from common.singleton import singleton
|
||||
from config import conf
|
||||
|
||||
|
||||
@singleton
|
||||
class SlackChannel(ChatChannel):
|
||||
NOT_SUPPORT_REPLYTYPE = []
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.bot_token = ""
|
||||
self.app_token = ""
|
||||
self.bot_user_id = "" # used to strip @mention and ignore self messages
|
||||
self._app = None
|
||||
self._handler = None
|
||||
self._client = None
|
||||
self._loop_thread = None
|
||||
# Idempotent dedup; Slack retries event delivery on slow ack
|
||||
self._received_msgs = ExpiredDict(60 * 60 * 1)
|
||||
|
||||
# Disable group whitelist / prefix checks (we handle triggering ourselves
|
||||
# in _should_reply_in_channel), aligned with telegram / feishu channels.
|
||||
conf()["group_name_white_list"] = ["ALL_GROUP"]
|
||||
conf()["single_chat_prefix"] = [""]
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Lifecycle
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def startup(self):
|
||||
self.bot_token = conf().get("slack_bot_token", "")
|
||||
self.app_token = conf().get("slack_app_token", "")
|
||||
if not self.bot_token or not self.app_token:
|
||||
err = "[Slack] slack_bot_token and slack_app_token are both required"
|
||||
logger.error(err)
|
||||
self.report_startup_error(err)
|
||||
return
|
||||
|
||||
# Guard against the common mistake of swapping the two tokens:
|
||||
# bot token must start with xoxb-, app-level token with xapp-.
|
||||
if not self.bot_token.startswith("xoxb-") or not self.app_token.startswith("xapp-"):
|
||||
err = (
|
||||
"[Slack] token type mismatch: slack_bot_token must start with 'xoxb-' "
|
||||
"and slack_app_token must start with 'xapp-' (they look swapped)"
|
||||
)
|
||||
logger.error(err)
|
||||
self.report_startup_error(err)
|
||||
return
|
||||
|
||||
try:
|
||||
from slack_bolt import App
|
||||
from slack_bolt.adapter.socket_mode import SocketModeHandler
|
||||
except ImportError:
|
||||
err = (
|
||||
"[Slack] slack_bolt is not installed. "
|
||||
"Run: pip install slack_bolt"
|
||||
)
|
||||
logger.error(err)
|
||||
self.report_startup_error(err)
|
||||
return
|
||||
|
||||
try:
|
||||
self._app = App(token=self.bot_token)
|
||||
self._client = self._app.client
|
||||
|
||||
# Resolve our own bot user id (needed for @mention strip / self-ignore)
|
||||
auth = self._client.auth_test()
|
||||
self.bot_user_id = auth.get("user_id", "")
|
||||
self.name = self.bot_user_id # ChatChannel uses self.name to strip @-mention
|
||||
logger.info(f"[Slack] Bot logged in as user_id={self.bot_user_id}, team={auth.get('team')}")
|
||||
except Exception as e:
|
||||
err = f"[Slack] auth_test failed: {e}"
|
||||
logger.error(err)
|
||||
self.report_startup_error(err)
|
||||
return
|
||||
|
||||
self._register_handlers()
|
||||
|
||||
self._handler = SocketModeHandler(self._app, self.app_token)
|
||||
|
||||
def _run():
|
||||
try:
|
||||
logger.info("[Slack] Starting Socket Mode connection...")
|
||||
self.report_startup_success()
|
||||
logger.info("[Slack] ✅ Slack bot ready, listening for events")
|
||||
self._handler.start()
|
||||
except Exception as e:
|
||||
logger.error(f"[Slack] socket mode crashed: {e}", exc_info=True)
|
||||
self.report_startup_error(str(e))
|
||||
finally:
|
||||
logger.info("[Slack] socket mode exited")
|
||||
|
||||
self._loop_thread = threading.Thread(target=_run, daemon=True, name="slack-socket")
|
||||
self._loop_thread.start()
|
||||
# Block startup() until the handler thread exits, matching other channels'
|
||||
# behaviour (startup is a blocking call).
|
||||
self._loop_thread.join()
|
||||
|
||||
def _register_handlers(self):
|
||||
app = self._app
|
||||
|
||||
# app_mention: bot is @-mentioned in a channel
|
||||
@app.event("app_mention")
|
||||
def _on_app_mention(event, ack):
|
||||
ack()
|
||||
self._handle_event(event, is_group=True)
|
||||
|
||||
# message: DMs and channel messages (including thread replies)
|
||||
@app.event("message")
|
||||
def _on_message(event, ack):
|
||||
ack()
|
||||
self._handle_message_event(event)
|
||||
|
||||
def stop(self):
|
||||
logger.info("[Slack] stop() called")
|
||||
try:
|
||||
if self._handler is not None:
|
||||
self._handler.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"[Slack] handler close error: {e}")
|
||||
if self._loop_thread and self._loop_thread.is_alive():
|
||||
try:
|
||||
self._loop_thread.join(timeout=10)
|
||||
except Exception:
|
||||
pass
|
||||
logger.info("[Slack] stop() completed")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Inbound: slack event -> ChatMessage -> ChatChannel.produce
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _handle_message_event(self, event: dict):
|
||||
"""Route a raw `message` event: skip bot/system noise, decide grouping."""
|
||||
try:
|
||||
logger.debug(
|
||||
f"[Slack] message event: channel_type={event.get('channel_type')}, "
|
||||
f"subtype={event.get('subtype')}, user={event.get('user')}, "
|
||||
f"ts={event.get('ts')}, thread_ts={event.get('thread_ts')}"
|
||||
)
|
||||
# Ignore bot messages (including our own) and message edits/deletes
|
||||
if event.get("bot_id") or event.get("subtype") in ("bot_message", "message_changed", "message_deleted"):
|
||||
return
|
||||
if event.get("user") == self.bot_user_id:
|
||||
return
|
||||
|
||||
channel_type = event.get("channel_type", "")
|
||||
# DM (im) is single chat; channel/group is group chat. app_mention
|
||||
# already covers channel @-mentions, so for plain channel messages we
|
||||
# only react when configured / thread-following.
|
||||
is_group = channel_type in ("channel", "group", "mpim")
|
||||
if is_group:
|
||||
# app_mention handler covers explicit @bot; here we only handle
|
||||
# follow-up replies in threads the bot participates in.
|
||||
if not self._should_reply_in_channel(event):
|
||||
return
|
||||
self._handle_event(event, is_group=is_group)
|
||||
except Exception as e:
|
||||
logger.error(f"[Slack] _handle_message_event error: {e}", exc_info=True)
|
||||
|
||||
def _handle_event(self, event: dict, is_group: bool):
|
||||
"""Parse event -> build SlackMessage -> produce()."""
|
||||
try:
|
||||
channel_id = event.get("channel", "")
|
||||
ts = event.get("ts", "")
|
||||
if not channel_id:
|
||||
return
|
||||
|
||||
# Idempotent dedup
|
||||
msg_uid = f"{channel_id}:{ts}"
|
||||
if self._received_msgs.get(msg_uid):
|
||||
return
|
||||
self._received_msgs[msg_uid] = True
|
||||
|
||||
# Parse type + download media if needed.
|
||||
ctype, content, caption = self._parse_event(event)
|
||||
if ctype is None:
|
||||
logger.debug(f"[Slack] unsupported message type, skip. event={event}")
|
||||
return
|
||||
|
||||
# Strip <@bot_user_id> mention from channel text
|
||||
if is_group and self.bot_user_id:
|
||||
if ctype == ContextType.TEXT and content:
|
||||
content = self._strip_at_mention(content)
|
||||
if caption:
|
||||
caption = self._strip_at_mention(caption)
|
||||
|
||||
slack_msg = SlackMessage(
|
||||
event,
|
||||
is_group=is_group,
|
||||
bot_user_id=self.bot_user_id,
|
||||
ctype=ctype,
|
||||
content=content,
|
||||
)
|
||||
slack_msg.is_at = is_group # if we reached here in a channel, bot is mentioned/threaded
|
||||
|
||||
from channel.file_cache import get_file_cache
|
||||
file_cache = get_file_cache()
|
||||
session_id = self._compute_session_id(event, is_group)
|
||||
|
||||
# Media + caption together: treat as a complete query and bypass the cache
|
||||
if ctype in (ContextType.IMAGE, ContextType.FILE) and caption:
|
||||
tag = "image" if ctype == ContextType.IMAGE else "file"
|
||||
merged_text = f"{caption}\n[{tag}: {content}]"
|
||||
slack_msg.ctype = ContextType.TEXT
|
||||
slack_msg.content = merged_text
|
||||
ctype = ContextType.TEXT
|
||||
logger.info(f"[Slack] Media+caption merged for session {session_id}")
|
||||
# fallthrough to the TEXT branch below
|
||||
|
||||
elif ctype == ContextType.IMAGE:
|
||||
file_cache.add(session_id, content, file_type="image")
|
||||
logger.info(f"[Slack] Image cached for session {session_id}, waiting for query...")
|
||||
return
|
||||
elif ctype == ContextType.FILE:
|
||||
file_cache.add(session_id, content, file_type="file")
|
||||
logger.info(f"[Slack] File cached for session {session_id}: {content}")
|
||||
return
|
||||
|
||||
if ctype == ContextType.TEXT:
|
||||
# Fast-path: /cancel mirrors Web channel behaviour
|
||||
if (content or "").strip().lower() in ("/cancel", "cancel"):
|
||||
self._do_cancel(session_id, channel_id, event)
|
||||
return
|
||||
|
||||
cached_files = file_cache.get(session_id)
|
||||
if cached_files:
|
||||
refs = []
|
||||
for fi in cached_files:
|
||||
ftype = fi["type"]
|
||||
tag = ftype if ftype in ("image", "video") else "file"
|
||||
refs.append(f"[{tag}: {fi['path']}]")
|
||||
slack_msg.content = (slack_msg.content or "") + "\n" + "\n".join(refs)
|
||||
file_cache.clear(session_id)
|
||||
logger.info(f"[Slack] Attached {len(cached_files)} cached file(s) to query")
|
||||
|
||||
# Reply in the originating thread when present, else start one on this msg
|
||||
thread_ts = event.get("thread_ts") or ts
|
||||
|
||||
context = self._compose_context(
|
||||
slack_msg.ctype,
|
||||
slack_msg.content,
|
||||
isgroup=is_group,
|
||||
msg=slack_msg,
|
||||
# Replies go back into the thread, no manual @mention needed
|
||||
no_need_at=True,
|
||||
)
|
||||
if context:
|
||||
context["session_id"] = session_id
|
||||
context["receiver"] = channel_id
|
||||
context["slack_channel"] = channel_id
|
||||
context["slack_thread_ts"] = thread_ts if is_group else None
|
||||
self.produce(context)
|
||||
logger.debug(f"[Slack] received: type={ctype}, content={str(slack_msg.content)[:80]}")
|
||||
except Exception as e:
|
||||
logger.error(f"[Slack] _handle_event error: {e}", exc_info=True)
|
||||
|
||||
def _do_cancel(self, session_id: str, channel_id: str, event: dict):
|
||||
"""Fast-path: /cancel calls cancel_session directly without going through agent."""
|
||||
try:
|
||||
from agent.protocol import get_cancel_registry
|
||||
cancelled = get_cancel_registry().cancel_session(session_id)
|
||||
text = "Current task cancelled." if cancelled else "No running task to cancel."
|
||||
thread_ts = event.get("thread_ts") or event.get("ts")
|
||||
self._client.chat_postMessage(channel=channel_id, text=text, thread_ts=thread_ts)
|
||||
logger.info(f"[Slack] /cancel session={session_id}, cancelled={cancelled}")
|
||||
except Exception as e:
|
||||
logger.error(f"[Slack] /cancel error: {e}", exc_info=True)
|
||||
|
||||
def _parse_event(self, event: dict):
|
||||
"""Parse a slack event and return (ctype, content, caption).
|
||||
|
||||
- content is text for ContextType.TEXT, otherwise the local file path
|
||||
- caption is the optional text accompanying a file; empty for plain text
|
||||
"""
|
||||
text = (event.get("text") or "").strip()
|
||||
files = event.get("files") or []
|
||||
|
||||
if files:
|
||||
# Handle the first attachment; caption is the accompanying message text
|
||||
f = files[0]
|
||||
mimetype = (f.get("mimetype") or "").lower()
|
||||
url = f.get("url_private_download") or f.get("url_private")
|
||||
name = f.get("name") or f.get("id") or "file"
|
||||
if not url:
|
||||
return (None, None, "")
|
||||
path = self._download_file(url, name)
|
||||
if not path:
|
||||
return (None, None, "")
|
||||
if mimetype.startswith("image/"):
|
||||
return (ContextType.IMAGE, path, text)
|
||||
return (ContextType.FILE, path, text)
|
||||
|
||||
if text:
|
||||
return (ContextType.TEXT, text, "")
|
||||
|
||||
return (None, None, "")
|
||||
|
||||
def _download_file(self, url: str, name: str):
|
||||
"""Download a Slack private file (requires bot token auth) to local tmp dir."""
|
||||
try:
|
||||
headers = {"Authorization": f"Bearer {self.bot_token}"}
|
||||
resp = requests.get(url, headers=headers, timeout=60, stream=True)
|
||||
resp.raise_for_status()
|
||||
tmp_dir = SlackMessage.get_tmp_dir()
|
||||
# Sanitize the name and keep it unique-ish via the url tail
|
||||
safe_name = re.sub(r"[^\w.\-]", "_", name)
|
||||
local_path = os.path.join(tmp_dir, safe_name)
|
||||
with open(local_path, "wb") as fp:
|
||||
for chunk in resp.iter_content(chunk_size=8192):
|
||||
if chunk:
|
||||
fp.write(chunk)
|
||||
logger.debug(f"[Slack] downloaded {name} -> {local_path}")
|
||||
return local_path
|
||||
except Exception as e:
|
||||
logger.error(f"[Slack] download_file failed ({name}): {e}")
|
||||
return None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Channel trigger logic
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _should_reply_in_channel(self, event: dict) -> bool:
|
||||
"""Decide whether to reply to a plain channel message (no @mention).
|
||||
|
||||
app_mention already handles explicit @bot, so here we only deal with
|
||||
follow-up messages. `all` replies to every message; `mention_or_reply`
|
||||
replies inside threads the bot already participates in.
|
||||
"""
|
||||
mode = conf().get("slack_group_trigger", "mention_or_reply")
|
||||
if mode == "all":
|
||||
return True
|
||||
if mode == "mention_only":
|
||||
return False
|
||||
# mention_or_reply: follow up only within an existing thread
|
||||
return bool(event.get("thread_ts"))
|
||||
|
||||
def _strip_at_mention(self, content: str) -> str:
|
||||
"""Strip <@BOT_USER_ID> from channel text."""
|
||||
if not content or not self.bot_user_id:
|
||||
return content
|
||||
pattern = re.compile(r"<@" + re.escape(self.bot_user_id) + r">", re.IGNORECASE)
|
||||
return pattern.sub("", content).strip()
|
||||
|
||||
@staticmethod
|
||||
def _compute_session_id(event: dict, is_group: bool) -> str:
|
||||
channel_id = event.get("channel", "")
|
||||
user_id = event.get("user", "")
|
||||
if is_group:
|
||||
if conf().get("group_shared_session", True):
|
||||
return f"slack_channel_{channel_id}"
|
||||
return f"slack_channel_{channel_id}_{user_id}"
|
||||
return f"slack_user_{user_id}"
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Override _compose_context: skip the parent's group whitelist/at checks
|
||||
# (already handled via _should_reply_in_channel). Same idea as telegram.
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _compose_context(self, ctype: ContextType, content, **kwargs):
|
||||
context = Context(ctype, content)
|
||||
context.kwargs = kwargs
|
||||
if "channel_type" not in context:
|
||||
context["channel_type"] = self.channel_type
|
||||
if "origin_ctype" not in context:
|
||||
context["origin_ctype"] = ctype
|
||||
|
||||
cmsg = context["msg"]
|
||||
if cmsg.is_group:
|
||||
if conf().get("group_shared_session", True):
|
||||
context["session_id"] = cmsg.other_user_id
|
||||
else:
|
||||
context["session_id"] = f"{cmsg.from_user_id}:{cmsg.other_user_id}"
|
||||
else:
|
||||
context["session_id"] = cmsg.from_user_id
|
||||
context["receiver"] = cmsg.other_user_id
|
||||
|
||||
if ctype == ContextType.TEXT:
|
||||
img_match_prefix = check_prefix(content, conf().get("image_create_prefix"))
|
||||
if img_match_prefix:
|
||||
content = content.replace(img_match_prefix, "", 1)
|
||||
context.type = ContextType.IMAGE_CREATE
|
||||
else:
|
||||
context.type = ContextType.TEXT
|
||||
context.content = (content or "").strip()
|
||||
if "desire_rtype" not in context and conf().get("always_reply_voice"):
|
||||
context["desire_rtype"] = ReplyType.VOICE
|
||||
elif ctype == ContextType.VOICE:
|
||||
if "desire_rtype" not in context and (
|
||||
conf().get("voice_reply_voice") or conf().get("always_reply_voice")
|
||||
):
|
||||
context["desire_rtype"] = ReplyType.VOICE
|
||||
|
||||
return context
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Outbound: ChatChannel.send -> Slack Web API
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def send(self, reply: Reply, context: Context):
|
||||
"""Called from cow's sync main thread; Slack Web client is sync-safe."""
|
||||
if self._client is None:
|
||||
logger.warning("[Slack] client not ready, drop reply")
|
||||
return
|
||||
|
||||
channel_id = context.get("slack_channel")
|
||||
thread_ts = context.get("slack_thread_ts")
|
||||
if not channel_id:
|
||||
logger.warning("[Slack] no slack_channel in context, drop reply")
|
||||
return
|
||||
|
||||
try:
|
||||
self._do_send(reply, channel_id, thread_ts)
|
||||
logger.info(f"[Slack] sent reply (type={reply.type}, channel={channel_id})")
|
||||
except Exception as e:
|
||||
logger.error(f"[Slack] send failed: {e}", exc_info=True)
|
||||
|
||||
def _do_send(self, reply: Reply, channel_id: str, thread_ts):
|
||||
rtype = reply.type
|
||||
content = reply.content
|
||||
|
||||
if rtype in (ReplyType.TEXT, ReplyType.INFO, ReplyType.ERROR):
|
||||
text = str(content) if content is not None else ""
|
||||
if not text:
|
||||
return
|
||||
# Slack caps a message around 40k chars; split conservatively
|
||||
for chunk in _split_text(text, 3500):
|
||||
self._client.chat_postMessage(channel=channel_id, text=chunk, thread_ts=thread_ts)
|
||||
|
||||
elif rtype == ReplyType.IMAGE:
|
||||
# Already a local BytesIO; upload it directly
|
||||
content.seek(0)
|
||||
self._client.files_upload_v2(
|
||||
channel=channel_id, file=content, filename="image.png", thread_ts=thread_ts,
|
||||
)
|
||||
|
||||
elif rtype == ReplyType.IMAGE_URL:
|
||||
url = str(content)
|
||||
if url.startswith("file://"):
|
||||
local = url[7:]
|
||||
self._client.files_upload_v2(
|
||||
channel=channel_id, file=local, thread_ts=thread_ts,
|
||||
)
|
||||
else:
|
||||
# Post the URL as text; Slack will unfurl it as an image preview
|
||||
self._client.chat_postMessage(channel=channel_id, text=url, thread_ts=thread_ts)
|
||||
|
||||
elif rtype in (ReplyType.VOICE, ReplyType.FILE):
|
||||
local = content[7:] if isinstance(content, str) and content.startswith("file://") else content
|
||||
caption = getattr(reply, "text_content", None) or None
|
||||
self._client.files_upload_v2(
|
||||
channel=channel_id, file=local, initial_comment=caption, thread_ts=thread_ts,
|
||||
)
|
||||
|
||||
else:
|
||||
# Fallback: send as plain text
|
||||
self._client.chat_postMessage(channel=channel_id, text=str(content), thread_ts=thread_ts)
|
||||
|
||||
|
||||
def _split_text(text: str, limit: int):
|
||||
"""Split long text preferring line breaks to keep markdown structure intact."""
|
||||
if len(text) <= limit:
|
||||
yield text
|
||||
return
|
||||
buf = []
|
||||
size = 0
|
||||
for line in text.splitlines(keepends=True):
|
||||
if size + len(line) > limit and buf:
|
||||
yield "".join(buf)
|
||||
buf, size = [], 0
|
||||
# Hard-split single lines that exceed the limit
|
||||
while len(line) > limit:
|
||||
yield line[:limit]
|
||||
line = line[limit:]
|
||||
buf.append(line)
|
||||
size += len(line)
|
||||
if buf:
|
||||
yield "".join(buf)
|
||||
60
channel/slack/slack_message.py
Normal file
60
channel/slack/slack_message.py
Normal file
@@ -0,0 +1,60 @@
|
||||
"""
|
||||
Slack message adapter.
|
||||
|
||||
Convert a Slack event payload into cow's unified ChatMessage.
|
||||
File downloads are NOT performed here; the channel layer downloads files
|
||||
on demand because it needs the bot token for authenticated download URLs.
|
||||
"""
|
||||
import os
|
||||
|
||||
from bridge.context import ContextType
|
||||
from channel.chat_message import ChatMessage
|
||||
from common.utils import expand_path
|
||||
from config import conf
|
||||
|
||||
|
||||
class SlackMessage(ChatMessage):
|
||||
"""Wrap a Slack event into the unified ChatMessage."""
|
||||
|
||||
def __init__(self, event: dict, is_group: bool = False, bot_user_id: str = "",
|
||||
ctype: ContextType = ContextType.TEXT, content: str = ""):
|
||||
super().__init__(event)
|
||||
# Basic fields
|
||||
self.msg_id = event.get("client_msg_id") or event.get("ts") or ""
|
||||
try:
|
||||
self.create_time = int(float(event.get("ts", 0)))
|
||||
except (TypeError, ValueError):
|
||||
self.create_time = 0
|
||||
self.ctype = ctype
|
||||
self.content = content
|
||||
|
||||
# Sender / chat info
|
||||
from_user_id = event.get("user", "unknown")
|
||||
channel_id = event.get("channel", "")
|
||||
self.from_user_id = from_user_id
|
||||
self.from_user_nickname = from_user_id
|
||||
self.to_user_id = bot_user_id or "slack_bot"
|
||||
self.to_user_nickname = bot_user_id or "slack_bot"
|
||||
|
||||
self.is_group = is_group
|
||||
if is_group:
|
||||
# Channel chat: other_user_id = channel_id, actual_user_id = sender id
|
||||
self.other_user_id = channel_id
|
||||
self.other_user_nickname = channel_id
|
||||
self.actual_user_id = from_user_id
|
||||
self.actual_user_nickname = from_user_id
|
||||
else:
|
||||
# DM: use channel_id so replies go back to the same DM channel
|
||||
self.other_user_id = channel_id or from_user_id
|
||||
self.other_user_nickname = from_user_id
|
||||
|
||||
# Whether the bot was triggered by @-mention (set by channel layer)
|
||||
self.is_at = False
|
||||
|
||||
@staticmethod
|
||||
def get_tmp_dir() -> str:
|
||||
"""Local download directory, aligned with other channels (agent_workspace/tmp)."""
|
||||
workspace_root = expand_path(conf().get("agent_workspace", "~/cow"))
|
||||
tmp_dir = os.path.join(workspace_root, "tmp")
|
||||
os.makedirs(tmp_dir, exist_ok=True)
|
||||
return tmp_dir
|
||||
0
channel/telegram/__init__.py
Normal file
0
channel/telegram/__init__.py
Normal file
719
channel/telegram/telegram_channel.py
Normal file
719
channel/telegram/telegram_channel.py
Normal file
@@ -0,0 +1,719 @@
|
||||
"""
|
||||
Telegram channel via Bot API (long polling mode).
|
||||
|
||||
Features:
|
||||
- Single chat & group chat (text / photo / voice / video / document)
|
||||
- Group trigger: @mention or reply-to-bot (configurable)
|
||||
- /cancel fast-path matches Web channel behaviour
|
||||
- Auto-register bot commands menu on startup (mirrors Web slash menu)
|
||||
- Optional HTTP/SOCKS5 proxy support for restricted networks
|
||||
|
||||
Implementation note:
|
||||
python-telegram-bot is async-first. We run the bot inside a dedicated
|
||||
thread with its own asyncio loop so the rest of cow (which is sync)
|
||||
stays untouched. Inbound updates are dispatched onto cow's existing
|
||||
sync ChatChannel.produce() pipeline; outbound send() schedules
|
||||
coroutines back onto that loop via asyncio.run_coroutine_threadsafe.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import re
|
||||
import threading
|
||||
|
||||
from bridge.context import Context, ContextType
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from channel.chat_channel import ChatChannel, check_prefix
|
||||
from channel.telegram.telegram_message import TelegramMessage
|
||||
from common.expired_dict import ExpiredDict
|
||||
from common.log import logger
|
||||
from common.singleton import singleton
|
||||
from config import conf
|
||||
|
||||
# Bot command menu, aligned with Web slash commands.
|
||||
# Top-level commands only; sub-commands are entered with a space (e.g. "/skill list").
|
||||
TELEGRAM_BOT_COMMANDS = [
|
||||
("help", "Show command help"),
|
||||
("status", "Show running status"),
|
||||
("context", "View/clear conversation context (sub: clear)"),
|
||||
("skill", "Manage skills (list/search/install/...)"),
|
||||
("memory", "Manage memory (sub: dream)"),
|
||||
("knowledge", "Manage knowledge base (list/on/off)"),
|
||||
("config", "Show current config"),
|
||||
("cancel", "Cancel running agent task"),
|
||||
("logs", "Show recent logs"),
|
||||
("version", "Show version"),
|
||||
]
|
||||
|
||||
|
||||
@singleton
|
||||
class TelegramChannel(ChatChannel):
|
||||
NOT_SUPPORT_REPLYTYPE = []
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.bot_token = ""
|
||||
self.bot_username = "" # used for @-mention matching
|
||||
self._bot = None
|
||||
self._application = None
|
||||
self._loop = None
|
||||
self._loop_thread = None
|
||||
self._stop_event = threading.Event()
|
||||
# Idempotent dedup; TG occasionally redelivers the same update on flaky networks
|
||||
self._received_msgs = ExpiredDict(60 * 60 * 1)
|
||||
|
||||
# Disable group whitelist / prefix checks (we handle triggering ourselves
|
||||
# in _should_reply_in_group), aligned with feishu / wecom_bot channels.
|
||||
conf()["group_name_white_list"] = ["ALL_GROUP"]
|
||||
conf()["single_chat_prefix"] = [""]
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Lifecycle
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def startup(self):
|
||||
self.bot_token = conf().get("telegram_token", "")
|
||||
if not self.bot_token:
|
||||
err = "[Telegram] telegram_token is required"
|
||||
logger.error(err)
|
||||
self.report_startup_error(err)
|
||||
return
|
||||
|
||||
try:
|
||||
from telegram.ext import (
|
||||
Application,
|
||||
MessageHandler,
|
||||
CommandHandler,
|
||||
filters,
|
||||
)
|
||||
except ImportError:
|
||||
err = (
|
||||
"[Telegram] python-telegram-bot is not installed. "
|
||||
"Run: pip install python-telegram-bot"
|
||||
)
|
||||
logger.error(err)
|
||||
self.report_startup_error(err)
|
||||
return
|
||||
|
||||
# Run the asyncio event loop in a dedicated thread so the sync cow body
|
||||
# is untouched.
|
||||
self._loop = asyncio.new_event_loop()
|
||||
|
||||
def _run_loop():
|
||||
asyncio.set_event_loop(self._loop)
|
||||
try:
|
||||
self._loop.run_until_complete(self._async_main(Application, MessageHandler, CommandHandler, filters))
|
||||
except Exception as e:
|
||||
logger.error(f"[Telegram] event loop crashed: {e}", exc_info=True)
|
||||
self.report_startup_error(str(e))
|
||||
finally:
|
||||
try:
|
||||
self._loop.close()
|
||||
except Exception:
|
||||
pass
|
||||
logger.info("[Telegram] event loop exited")
|
||||
|
||||
self._loop_thread = threading.Thread(target=_run_loop, daemon=True, name="telegram-loop")
|
||||
self._loop_thread.start()
|
||||
# Block startup() until the loop thread exits, matching other channels'
|
||||
# behaviour (startup is a blocking call).
|
||||
self._loop_thread.join()
|
||||
|
||||
async def _async_main(self, Application, MessageHandler, CommandHandler, filters):
|
||||
"""Build Application, register handlers, and run polling."""
|
||||
builder = Application.builder().token(self.bot_token)
|
||||
|
||||
# Proxy: prefer telegram_proxy config, fall back to HTTPS_PROXY env var
|
||||
proxy_url = conf().get("telegram_proxy", "") or os.environ.get("HTTPS_PROXY", "")
|
||||
if proxy_url:
|
||||
try:
|
||||
builder = builder.proxy(proxy_url).get_updates_proxy(proxy_url)
|
||||
logger.info(f"[Telegram] using proxy: {proxy_url}")
|
||||
except Exception as e:
|
||||
logger.warning(f"[Telegram] proxy config failed, fallback to direct: {e}")
|
||||
|
||||
# Media uploads (photo/voice/video/document) over a proxy can be slow,
|
||||
# bump read/write/connect/pool timeouts.
|
||||
builder = (
|
||||
builder
|
||||
.read_timeout(60)
|
||||
.write_timeout(120)
|
||||
.connect_timeout(30)
|
||||
.pool_timeout(30)
|
||||
)
|
||||
|
||||
application = builder.build()
|
||||
self._application = application
|
||||
self._bot = application.bot
|
||||
|
||||
# Fetch our own username (needed for @-mention matching in groups)
|
||||
try:
|
||||
me = await self._bot.get_me()
|
||||
self.bot_username = me.username or ""
|
||||
self.name = self.bot_username # ChatChannel uses self.name to strip @-mention
|
||||
logger.info(f"[Telegram] Bot logged in as @{self.bot_username} (id={me.id})")
|
||||
except Exception as e:
|
||||
err = f"[Telegram] get_me failed: {e}"
|
||||
logger.error(err)
|
||||
self.report_startup_error(err)
|
||||
return
|
||||
|
||||
# Register the command menu (failure is non-fatal)
|
||||
if conf().get("telegram_register_commands", True):
|
||||
try:
|
||||
from telegram import BotCommand
|
||||
cmds = [BotCommand(name, desc) for name, desc in TELEGRAM_BOT_COMMANDS]
|
||||
await self._bot.set_my_commands(cmds)
|
||||
logger.info(f"[Telegram] Registered {len(cmds)} bot commands")
|
||||
except Exception as e:
|
||||
logger.warning(f"[Telegram] set_my_commands failed: {e}")
|
||||
|
||||
# Handlers:
|
||||
# 1) /cancel uses the fast-path
|
||||
application.add_handler(CommandHandler("cancel", self._on_cancel))
|
||||
# 2) Normal messages (text + media)
|
||||
application.add_handler(MessageHandler(filters.ALL & ~filters.COMMAND, self._on_message))
|
||||
# 3) Other slash commands are forwarded as plain text for the agent to handle
|
||||
application.add_handler(MessageHandler(filters.COMMAND, self._on_command_passthrough))
|
||||
|
||||
# Start polling. drop_pending_updates avoids replaying backlog after restart.
|
||||
# Transient "Server disconnected" / RemoteProtocolError during get_updates
|
||||
# are common over proxies/flaky networks; PTB's network loop auto-retries,
|
||||
# so we only need to keep the noise down (see _quiet_polling_network_errors).
|
||||
self._quiet_polling_network_errors()
|
||||
logger.info("[Telegram] Starting long polling...")
|
||||
await application.initialize()
|
||||
await application.start()
|
||||
await application.updater.start_polling(
|
||||
drop_pending_updates=True,
|
||||
# Long-poll hold time on the server side; smaller value = reconnect more
|
||||
# often but each hung connection fails faster.
|
||||
timeout=30,
|
||||
# Retry forever on transient get_updates network errors instead of giving up.
|
||||
bootstrap_retries=-1,
|
||||
)
|
||||
self.report_startup_success()
|
||||
logger.info("[Telegram] ✅ Telegram bot ready, polling for updates")
|
||||
|
||||
# Block until stop()
|
||||
try:
|
||||
while not self._stop_event.is_set():
|
||||
await asyncio.sleep(0.5)
|
||||
finally:
|
||||
try:
|
||||
await application.updater.stop()
|
||||
await application.stop()
|
||||
await application.shutdown()
|
||||
except Exception as e:
|
||||
logger.warning(f"[Telegram] shutdown error: {e}")
|
||||
|
||||
@staticmethod
|
||||
def _quiet_polling_network_errors():
|
||||
"""Downgrade PTB's noisy 'Exception happened while polling for updates' logs.
|
||||
|
||||
These transient get_updates errors (RemoteProtocolError / NetworkError /
|
||||
TimedOut, typically over a proxy) are auto-retried by PTB's network loop,
|
||||
so logging the full traceback at ERROR is just noise. We attach a filter
|
||||
that drops these specific records while leaving real errors untouched.
|
||||
"""
|
||||
import logging
|
||||
|
||||
class _PollingNoiseFilter(logging.Filter):
|
||||
_NEEDLES = (
|
||||
"Exception happened while polling for updates",
|
||||
"Server disconnected without sending a response",
|
||||
)
|
||||
|
||||
def filter(self, record: logging.LogRecord) -> bool:
|
||||
try:
|
||||
msg = record.getMessage()
|
||||
except Exception:
|
||||
return True
|
||||
if any(n in msg for n in self._NEEDLES):
|
||||
# Keep a single-line breadcrumb at DEBUG, drop the traceback.
|
||||
logger.debug(f"[Telegram] transient polling network error (auto-retrying): {msg.splitlines()[0]}")
|
||||
return False
|
||||
return True
|
||||
|
||||
noise_filter = _PollingNoiseFilter()
|
||||
for name in ("telegram.ext.Updater", "telegram.ext._updater", "telegram.ext"):
|
||||
logging.getLogger(name).addFilter(noise_filter)
|
||||
|
||||
def stop(self):
|
||||
logger.info("[Telegram] stop() called")
|
||||
self._stop_event.set()
|
||||
if self._loop_thread and self._loop_thread.is_alive():
|
||||
try:
|
||||
self._loop_thread.join(timeout=10)
|
||||
except Exception:
|
||||
pass
|
||||
logger.info("[Telegram] stop() completed")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Inbound: telegram update -> ChatMessage -> ChatChannel.produce
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _on_cancel(self, update, _context):
|
||||
"""Fast-path: /cancel calls cancel_session directly without going through agent."""
|
||||
try:
|
||||
from agent.protocol import get_cancel_registry
|
||||
session_id = self._compute_session_id(update)
|
||||
cancelled = get_cancel_registry().cancel_session(session_id)
|
||||
text = "Current task cancelled." if cancelled else "No running task to cancel."
|
||||
await update.effective_message.reply_text(text)
|
||||
logger.info(f"[Telegram] /cancel session={session_id}, cancelled={cancelled}")
|
||||
except Exception as e:
|
||||
logger.error(f"[Telegram] /cancel error: {e}", exc_info=True)
|
||||
try:
|
||||
await update.effective_message.reply_text(f"⚠️ /cancel failed: {e}")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def _on_command_passthrough(self, update, _context):
|
||||
"""All non-/cancel commands fall through to plain message handling."""
|
||||
await self._on_message(update, _context)
|
||||
|
||||
async def _on_message(self, update, _context):
|
||||
"""Telegram update entry: parse message -> build ChatMessage -> produce()."""
|
||||
try:
|
||||
message = update.effective_message
|
||||
chat = update.effective_chat
|
||||
if not message or not chat:
|
||||
return
|
||||
|
||||
# Idempotent dedup
|
||||
msg_uid = f"{chat.id}:{message.message_id}"
|
||||
if self._received_msgs.get(msg_uid):
|
||||
return
|
||||
self._received_msgs[msg_uid] = True
|
||||
|
||||
is_group = chat.type in ("group", "supergroup")
|
||||
|
||||
# Debug log: helpful when group messages are silently dropped
|
||||
if is_group:
|
||||
logger.debug(
|
||||
f"[Telegram] group update received: chat_id={chat.id}, "
|
||||
f"text={(message.text or message.caption or '')[:40]!r}, "
|
||||
f"reply_to_bot={bool(message.reply_to_message and message.reply_to_message.from_user and message.reply_to_message.from_user.username == self.bot_username)}"
|
||||
)
|
||||
|
||||
# Group trigger gate (silently drop if not triggered)
|
||||
if is_group and not self._should_reply_in_group(update):
|
||||
logger.debug(f"[Telegram] group message not triggered (need @{self.bot_username} or reply), skip")
|
||||
return
|
||||
|
||||
# Parse message type + download media if needed.
|
||||
# Media messages with caption return both the local path and the caption text.
|
||||
ctype, content, caption = await self._parse_message(message)
|
||||
if ctype is None:
|
||||
logger.debug(f"[Telegram] unsupported message type, skip. msg={message}")
|
||||
return
|
||||
|
||||
# Strip @bot mention for group text/caption
|
||||
if is_group and self.bot_username:
|
||||
if ctype == ContextType.TEXT and content:
|
||||
content = self._strip_at_mention(content)
|
||||
if caption:
|
||||
caption = self._strip_at_mention(caption)
|
||||
|
||||
tg_msg = TelegramMessage(
|
||||
update,
|
||||
is_group=is_group,
|
||||
bot_username=self.bot_username,
|
||||
ctype=ctype,
|
||||
content=content,
|
||||
)
|
||||
tg_msg.is_at = is_group # If we got here in a group, the bot is mentioned/replied
|
||||
|
||||
# File cache: standalone media goes into cache, the next text query attaches them
|
||||
from channel.file_cache import get_file_cache
|
||||
file_cache = get_file_cache()
|
||||
session_id = self._compute_session_id(update)
|
||||
|
||||
# Media + caption together: treat as a complete query and bypass the cache
|
||||
if ctype in (ContextType.IMAGE, ContextType.FILE) and caption:
|
||||
tag = "image" if ctype == ContextType.IMAGE else "file"
|
||||
merged_text = f"{caption}\n[{tag}: {content}]"
|
||||
tg_msg.ctype = ContextType.TEXT
|
||||
tg_msg.content = merged_text
|
||||
ctype = ContextType.TEXT
|
||||
logger.info(f"[Telegram] Media+caption merged for session {session_id}")
|
||||
# fallthrough to the TEXT branch below
|
||||
|
||||
elif ctype == ContextType.IMAGE:
|
||||
file_cache.add(session_id, content, file_type="image")
|
||||
logger.info(f"[Telegram] Image cached for session {session_id}, waiting for query...")
|
||||
return
|
||||
elif ctype == ContextType.FILE:
|
||||
file_cache.add(session_id, content, file_type="file")
|
||||
logger.info(f"[Telegram] File cached for session {session_id}: {content}")
|
||||
return
|
||||
|
||||
if ctype == ContextType.TEXT:
|
||||
cached_files = file_cache.get(session_id)
|
||||
if cached_files:
|
||||
refs = []
|
||||
for fi in cached_files:
|
||||
ftype = fi["type"]
|
||||
tag = ftype if ftype in ("image", "video") else "file"
|
||||
refs.append(f"[{tag}: {fi['path']}]")
|
||||
tg_msg.content = (tg_msg.content or "") + "\n" + "\n".join(refs)
|
||||
file_cache.clear(session_id)
|
||||
logger.info(f"[Telegram] Attached {len(cached_files)} cached file(s) to query")
|
||||
|
||||
# Dispatch to cow main pipeline (reuses ChatChannel._compose_context routing)
|
||||
context = self._compose_context(
|
||||
tg_msg.ctype,
|
||||
tg_msg.content,
|
||||
isgroup=is_group,
|
||||
msg=tg_msg,
|
||||
)
|
||||
if context:
|
||||
context["session_id"] = session_id
|
||||
context["receiver"] = str(chat.id)
|
||||
context["telegram_chat_id"] = chat.id
|
||||
context["telegram_reply_to_msg_id"] = message.message_id if is_group else None
|
||||
self.produce(context)
|
||||
logger.debug(f"[Telegram] received: type={ctype}, content={str(tg_msg.content)[:80]}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Telegram] _on_message error: {e}", exc_info=True)
|
||||
|
||||
async def _parse_message(self, message):
|
||||
"""Parse a telegram message and return (ctype, content, caption).
|
||||
|
||||
- content is text for ContextType.TEXT, otherwise the local file path
|
||||
- caption is the optional text accompanying a media message; empty for plain text
|
||||
"""
|
||||
caption = (message.caption or "").strip()
|
||||
|
||||
if message.photo:
|
||||
largest = message.photo[-1]
|
||||
path = await self._download_file(largest.file_id, suffix=".jpg")
|
||||
return (ContextType.IMAGE, path, caption) if path else (None, None, "")
|
||||
|
||||
if message.voice or message.audio:
|
||||
audio_obj = message.voice or message.audio
|
||||
suffix = ".ogg" if message.voice else (
|
||||
"." + (audio_obj.mime_type.split("/")[-1] if getattr(audio_obj, "mime_type", "") else "mp3")
|
||||
)
|
||||
path = await self._download_file(audio_obj.file_id, suffix=suffix)
|
||||
return (ContextType.VOICE, path, caption) if path else (None, None, "")
|
||||
|
||||
if message.video or message.video_note:
|
||||
video_obj = message.video or message.video_note
|
||||
path = await self._download_file(video_obj.file_id, suffix=".mp4")
|
||||
return (ContextType.FILE, path, caption) if path else (None, None, "")
|
||||
|
||||
if message.document:
|
||||
doc = message.document
|
||||
ext = ""
|
||||
if doc.file_name and "." in doc.file_name:
|
||||
ext = "." + doc.file_name.rsplit(".", 1)[-1]
|
||||
path = await self._download_file(doc.file_id, suffix=ext, original_name=doc.file_name)
|
||||
if not path:
|
||||
return (None, None, "")
|
||||
# Image-typed documents (user picked "send as file") are treated as images
|
||||
mime = (doc.mime_type or "").lower()
|
||||
if mime.startswith("image/"):
|
||||
return (ContextType.IMAGE, path, caption)
|
||||
return (ContextType.FILE, path, caption)
|
||||
|
||||
if message.text:
|
||||
return (ContextType.TEXT, message.text.strip(), "")
|
||||
|
||||
return (None, None, "")
|
||||
|
||||
async def _download_file(self, file_id: str, suffix: str = "", original_name: str = ""):
|
||||
"""Download via bot.get_file into the local tmp dir; return path or None on failure."""
|
||||
try:
|
||||
f = await self._bot.get_file(file_id)
|
||||
tmp_dir = TelegramMessage.get_tmp_dir()
|
||||
base = original_name or f"{file_id}{suffix or ''}"
|
||||
# Prefix with file_id to avoid name collisions / weird chars
|
||||
safe_name = f"{file_id}_{base}" if original_name else base
|
||||
local_path = os.path.join(tmp_dir, safe_name)
|
||||
await f.download_to_drive(custom_path=local_path)
|
||||
logger.debug(f"[Telegram] downloaded file_id={file_id} -> {local_path}")
|
||||
return local_path
|
||||
except Exception as e:
|
||||
logger.error(f"[Telegram] download_file failed (file_id={file_id}): {e}")
|
||||
return None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Group trigger logic
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _should_reply_in_group(self, update) -> bool:
|
||||
"""Decide whether to reply to a group message based on configuration."""
|
||||
mode = conf().get("telegram_group_trigger", "mention_or_reply")
|
||||
if mode == "all":
|
||||
return True
|
||||
|
||||
message = update.effective_message
|
||||
if not message:
|
||||
return False
|
||||
|
||||
# 1) Mentioned
|
||||
if self.bot_username and self._is_mentioned(message, self.bot_username):
|
||||
return True
|
||||
|
||||
# 2) Reply to a bot message
|
||||
if mode == "mention_or_reply":
|
||||
reply = message.reply_to_message
|
||||
if reply and reply.from_user and reply.from_user.username == self.bot_username:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _is_mentioned(message, bot_username: str) -> bool:
|
||||
"""Check whether entities/caption_entities contain a @mention of the bot."""
|
||||
bot_at = "@" + bot_username.lower()
|
||||
text = (message.text or message.caption or "").lower()
|
||||
if bot_at in text:
|
||||
return True
|
||||
# Also check entities strictly to support text_mention (no-username @)
|
||||
for ent in (message.entities or []) + (message.caption_entities or []):
|
||||
if ent.type == "mention":
|
||||
src = message.text or message.caption or ""
|
||||
if src[ent.offset: ent.offset + ent.length].lower() == bot_at:
|
||||
return True
|
||||
return False
|
||||
|
||||
def _strip_at_mention(self, content: str) -> str:
|
||||
"""Strip @bot_username from group text (case-insensitive)."""
|
||||
if not content or not self.bot_username:
|
||||
return content
|
||||
pattern = re.compile(r"@" + re.escape(self.bot_username), re.IGNORECASE)
|
||||
return pattern.sub("", content).strip()
|
||||
|
||||
@staticmethod
|
||||
def _compute_session_id(update) -> str:
|
||||
chat = update.effective_chat
|
||||
user = update.effective_user
|
||||
is_group = chat.type in ("group", "supergroup")
|
||||
if is_group:
|
||||
if conf().get("group_shared_session", True):
|
||||
return f"tg_group_{chat.id}"
|
||||
return f"tg_group_{chat.id}_{user.id}"
|
||||
return f"tg_user_{user.id}"
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Override _compose_context: skip the parent's group whitelist/at checks
|
||||
# (already handled in _on_message via _should_reply_in_group). Same idea
|
||||
# as the feishu channel.
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _compose_context(self, ctype: ContextType, content, **kwargs):
|
||||
context = Context(ctype, content)
|
||||
context.kwargs = kwargs
|
||||
if "channel_type" not in context:
|
||||
context["channel_type"] = self.channel_type
|
||||
if "origin_ctype" not in context:
|
||||
context["origin_ctype"] = ctype
|
||||
|
||||
cmsg = context["msg"]
|
||||
if cmsg.is_group:
|
||||
if conf().get("group_shared_session", True):
|
||||
context["session_id"] = cmsg.other_user_id
|
||||
else:
|
||||
context["session_id"] = f"{cmsg.from_user_id}:{cmsg.other_user_id}"
|
||||
else:
|
||||
context["session_id"] = cmsg.from_user_id
|
||||
context["receiver"] = cmsg.other_user_id
|
||||
|
||||
if ctype == ContextType.TEXT:
|
||||
img_match_prefix = check_prefix(content, conf().get("image_create_prefix"))
|
||||
if img_match_prefix:
|
||||
content = content.replace(img_match_prefix, "", 1)
|
||||
context.type = ContextType.IMAGE_CREATE
|
||||
else:
|
||||
context.type = ContextType.TEXT
|
||||
context.content = (content or "").strip()
|
||||
if "desire_rtype" not in context and conf().get("always_reply_voice"):
|
||||
context["desire_rtype"] = ReplyType.VOICE
|
||||
elif ctype == ContextType.VOICE:
|
||||
if "desire_rtype" not in context and (
|
||||
conf().get("voice_reply_voice") or conf().get("always_reply_voice")
|
||||
):
|
||||
context["desire_rtype"] = ReplyType.VOICE
|
||||
|
||||
return context
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Outbound: ChatChannel.send -> Telegram API
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def send(self, reply: Reply, context: Context):
|
||||
"""Called from cow's sync main thread; we marshal the coroutine onto the loop thread."""
|
||||
if self._loop is None or self._bot is None:
|
||||
logger.warning("[Telegram] bot not ready, drop reply")
|
||||
return
|
||||
|
||||
chat_id = context.get("telegram_chat_id")
|
||||
reply_to = context.get("telegram_reply_to_msg_id")
|
||||
if chat_id is None:
|
||||
logger.warning("[Telegram] no telegram_chat_id in context, drop reply")
|
||||
return
|
||||
|
||||
coro = self._async_send(reply, chat_id, reply_to)
|
||||
try:
|
||||
future = asyncio.run_coroutine_threadsafe(coro, self._loop)
|
||||
# Media uploads through a proxy can be slow; let PTB's own timeouts win
|
||||
future.result(timeout=180)
|
||||
except Exception as e:
|
||||
logger.error(f"[Telegram] send failed: {e}")
|
||||
|
||||
# Number of retries for transient network errors (proxy hiccups etc.)
|
||||
_SEND_RETRIES = 2
|
||||
_SEND_RETRY_BACKOFF = 2.0 # seconds
|
||||
|
||||
async def _send_with_retry(self, send_fn, *, label: str):
|
||||
"""Run a single Telegram API call with retries for transient network errors."""
|
||||
from telegram.error import NetworkError, TimedOut
|
||||
last_err = None
|
||||
for attempt in range(self._SEND_RETRIES + 1):
|
||||
try:
|
||||
return await send_fn()
|
||||
except (NetworkError, TimedOut) as e:
|
||||
last_err = e
|
||||
if attempt >= self._SEND_RETRIES:
|
||||
break
|
||||
wait = self._SEND_RETRY_BACKOFF * (attempt + 1)
|
||||
logger.warning(
|
||||
f"[Telegram] {label} transient error (attempt {attempt + 1}/"
|
||||
f"{self._SEND_RETRIES + 1}): {e}; retry in {wait}s"
|
||||
)
|
||||
await asyncio.sleep(wait)
|
||||
raise last_err
|
||||
|
||||
async def _async_send(self, reply: Reply, chat_id, reply_to_msg_id):
|
||||
try:
|
||||
rtype = reply.type
|
||||
content = reply.content
|
||||
|
||||
if rtype == ReplyType.TEXT or rtype == ReplyType.INFO or rtype == ReplyType.ERROR:
|
||||
# Telegram caps a single text message at 4096 chars; auto-split
|
||||
text = str(content) if content is not None else ""
|
||||
if not text:
|
||||
return
|
||||
for chunk in _split_text(text, 4000):
|
||||
await self._send_with_retry(
|
||||
lambda c=chunk: self._bot.send_message(
|
||||
chat_id=chat_id,
|
||||
text=c,
|
||||
reply_to_message_id=reply_to_msg_id,
|
||||
# Avoid failing the whole send if reply_to was deleted
|
||||
allow_sending_without_reply=True,
|
||||
),
|
||||
label="send_message",
|
||||
)
|
||||
|
||||
elif rtype == ReplyType.IMAGE:
|
||||
# Already a local BytesIO; send it directly
|
||||
content.seek(0)
|
||||
await self._send_with_retry(
|
||||
lambda: self._bot.send_photo(
|
||||
chat_id=chat_id,
|
||||
photo=content,
|
||||
reply_to_message_id=reply_to_msg_id,
|
||||
allow_sending_without_reply=True,
|
||||
),
|
||||
label="send_photo",
|
||||
)
|
||||
|
||||
elif rtype == ReplyType.IMAGE_URL:
|
||||
url = str(content)
|
||||
if url.startswith("file://"):
|
||||
local = url[7:]
|
||||
# Open inside the lambda so each retry gets a fresh stream
|
||||
async def _send_local_photo():
|
||||
with open(local, "rb") as f:
|
||||
return await self._bot.send_photo(
|
||||
chat_id=chat_id, photo=f,
|
||||
reply_to_message_id=reply_to_msg_id,
|
||||
allow_sending_without_reply=True,
|
||||
)
|
||||
await self._send_with_retry(_send_local_photo, label="send_photo(file)")
|
||||
else:
|
||||
await self._send_with_retry(
|
||||
lambda: self._bot.send_photo(
|
||||
chat_id=chat_id, photo=url,
|
||||
reply_to_message_id=reply_to_msg_id,
|
||||
allow_sending_without_reply=True,
|
||||
),
|
||||
label="send_photo(url)",
|
||||
)
|
||||
|
||||
elif rtype == ReplyType.VOICE:
|
||||
local = content[7:] if isinstance(content, str) and content.startswith("file://") else content
|
||||
async def _send_voice():
|
||||
with open(local, "rb") as f:
|
||||
return await self._bot.send_voice(
|
||||
chat_id=chat_id, voice=f,
|
||||
reply_to_message_id=reply_to_msg_id,
|
||||
allow_sending_without_reply=True,
|
||||
)
|
||||
await self._send_with_retry(_send_voice, label="send_voice")
|
||||
|
||||
elif rtype == ReplyType.FILE:
|
||||
# Videos go through send_video, everything else through send_document
|
||||
local = content[7:] if isinstance(content, str) and content.startswith("file://") else content
|
||||
# File replies may carry an accompanying text caption
|
||||
caption = getattr(reply, "text_content", None) or None
|
||||
is_video = isinstance(local, str) and local.lower().endswith(
|
||||
(".mp4", ".mov", ".avi", ".mkv", ".webm")
|
||||
)
|
||||
|
||||
async def _send_file():
|
||||
with open(local, "rb") as f:
|
||||
if is_video:
|
||||
return await self._bot.send_video(
|
||||
chat_id=chat_id, video=f, caption=caption,
|
||||
reply_to_message_id=reply_to_msg_id,
|
||||
allow_sending_without_reply=True,
|
||||
)
|
||||
return await self._bot.send_document(
|
||||
chat_id=chat_id, document=f, caption=caption,
|
||||
reply_to_message_id=reply_to_msg_id,
|
||||
allow_sending_without_reply=True,
|
||||
)
|
||||
await self._send_with_retry(_send_file, label="send_video" if is_video else "send_document")
|
||||
|
||||
else:
|
||||
# Fallback: send as plain text
|
||||
await self._send_with_retry(
|
||||
lambda: self._bot.send_message(
|
||||
chat_id=chat_id, text=str(content),
|
||||
reply_to_message_id=reply_to_msg_id,
|
||||
allow_sending_without_reply=True,
|
||||
),
|
||||
label="send_message(fallback)",
|
||||
)
|
||||
|
||||
logger.info(f"[Telegram] sent reply (type={rtype}, chat_id={chat_id})")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Telegram] _async_send error: {e}", exc_info=True)
|
||||
|
||||
|
||||
def _split_text(text: str, limit: int):
|
||||
"""Split long text preferring line breaks to keep markdown structure intact."""
|
||||
if len(text) <= limit:
|
||||
yield text
|
||||
return
|
||||
buf = []
|
||||
size = 0
|
||||
for line in text.splitlines(keepends=True):
|
||||
if size + len(line) > limit and buf:
|
||||
yield "".join(buf)
|
||||
buf, size = [], 0
|
||||
# Hard-split single lines that exceed the limit
|
||||
while len(line) > limit:
|
||||
yield line[:limit]
|
||||
line = line[limit:]
|
||||
buf.append(line)
|
||||
size += len(line)
|
||||
if buf:
|
||||
yield "".join(buf)
|
||||
62
channel/telegram/telegram_message.py
Normal file
62
channel/telegram/telegram_message.py
Normal file
@@ -0,0 +1,62 @@
|
||||
"""
|
||||
Telegram message adapter.
|
||||
|
||||
Convert a python-telegram-bot Update into cow's unified ChatMessage.
|
||||
File downloads are NOT performed here; the channel layer triggers
|
||||
bot.get_file() on demand because it requires the async event loop.
|
||||
"""
|
||||
import os
|
||||
|
||||
from bridge.context import ContextType
|
||||
from channel.chat_message import ChatMessage
|
||||
from common.utils import expand_path
|
||||
from config import conf
|
||||
|
||||
|
||||
class TelegramMessage(ChatMessage):
|
||||
"""Wrap a Telegram Update into the unified ChatMessage."""
|
||||
|
||||
def __init__(self, update, is_group: bool = False, bot_username: str = "",
|
||||
ctype: ContextType = ContextType.TEXT, content: str = ""):
|
||||
super().__init__(update)
|
||||
message = update.effective_message
|
||||
chat = update.effective_chat
|
||||
user = update.effective_user
|
||||
|
||||
# Basic fields
|
||||
self.msg_id = str(message.message_id) if message else ""
|
||||
self.create_time = int(message.date.timestamp()) if message and message.date else 0
|
||||
self.ctype = ctype
|
||||
self.content = content
|
||||
|
||||
# Sender / chat info
|
||||
from_user_id = str(user.id) if user else "unknown"
|
||||
from_user_nick = (
|
||||
user.full_name if user and user.full_name else (user.username if user else "unknown")
|
||||
)
|
||||
self.from_user_id = from_user_id
|
||||
self.from_user_nickname = from_user_nick or from_user_id
|
||||
self.to_user_id = bot_username or "telegram_bot"
|
||||
self.to_user_nickname = bot_username or "telegram_bot"
|
||||
|
||||
self.is_group = is_group
|
||||
if is_group:
|
||||
# Group: other_user_id = group_id, actual_user_id = sender id
|
||||
self.other_user_id = str(chat.id)
|
||||
self.other_user_nickname = chat.title or str(chat.id)
|
||||
self.actual_user_id = from_user_id
|
||||
self.actual_user_nickname = self.from_user_nickname
|
||||
else:
|
||||
self.other_user_id = from_user_id
|
||||
self.other_user_nickname = self.from_user_nickname
|
||||
|
||||
# Whether the bot was triggered by @-mention or reply (set by channel layer)
|
||||
self.is_at = False
|
||||
|
||||
@staticmethod
|
||||
def get_tmp_dir() -> str:
|
||||
"""Local download directory, aligned with other channels (agent_workspace/tmp)."""
|
||||
workspace_root = expand_path(conf().get("agent_workspace", "~/cow"))
|
||||
tmp_dir = os.path.join(workspace_root, "tmp")
|
||||
os.makedirs(tmp_dir, exist_ok=True)
|
||||
return tmp_dir
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user