mirror of
https://github.com/zhayujie/chatgpt-on-wechat.git
synced 2026-06-02 09:48:22 +08:00
Compare commits
1093 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
feaa9076b0 | ||
|
|
ce0249706e | ||
|
|
af2c839231 | ||
|
|
2b2d24ed25 | ||
|
|
1dbf41f384 | ||
|
|
9e6a2cc2c0 | ||
|
|
7bf4ef3d05 | ||
|
|
126649f70f | ||
|
|
1827a2a31c | ||
|
|
fcf4eb78dc | ||
|
|
2ec6ea8045 | ||
|
|
0994a3586d | ||
|
|
29c4be6a3a | ||
|
|
c5b8e06891 | ||
|
|
54a20bca92 | ||
|
|
6e786bde90 | ||
|
|
b671b0d725 | ||
|
|
57f5692074 | ||
|
|
b0ac0731c7 | ||
|
|
3c161df526 | ||
|
|
aa3f48e93c | ||
|
|
5ae1e1adde | ||
|
|
fe8b8fe831 | ||
|
|
5aca54c083 | ||
|
|
458b1a1d88 | ||
|
|
3dd4b84179 | ||
|
|
99bddb79d6 | ||
|
|
136b0b89e8 | ||
|
|
c605b0b080 | ||
|
|
b7b8e3679c | ||
|
|
aeb6610ff4 | ||
|
|
e3eacc77d7 | ||
|
|
37661daf40 | ||
|
|
877b848370 | ||
|
|
5c163cc0fe | ||
|
|
6e04ea8240 | ||
|
|
d106465419 | ||
|
|
f39380cea7 | ||
|
|
bccce2d7cb | ||
|
|
6721dbdbcc | ||
|
|
83cd6ad158 | ||
|
|
116fb27257 | ||
|
|
8d67177a1b | ||
|
|
ad2db1a776 | ||
|
|
2e6d9e0f27 | ||
|
|
e05f85f3ce | ||
|
|
40c48a9a61 | ||
|
|
c9a7525d0b | ||
|
|
fd571ac539 | ||
|
|
c5a3f991c5 | ||
|
|
eb74b73351 | ||
|
|
9b31f45481 | ||
|
|
bc9c1691f5 | ||
|
|
73bf83d2ff | ||
|
|
36e1988fee | ||
|
|
aad6ef635e | ||
|
|
96659cd616 | ||
|
|
c8787b7de4 | ||
|
|
91d427c8f9 | ||
|
|
c8c0573dbd | ||
|
|
29af855ecd | ||
|
|
0a146a245d | ||
|
|
bd85fee7d7 | ||
|
|
571897e2fd | ||
|
|
840dabeccd | ||
|
|
069bffa3e8 | ||
|
|
cc10d230b0 | ||
|
|
2517f2add8 | ||
|
|
a534266025 | ||
|
|
8c25395805 | ||
|
|
36b913124b | ||
|
|
2fa6343fe5 | ||
|
|
06b84225a1 | ||
|
|
5b31da335d | ||
|
|
90773ab69f | ||
|
|
11d92bb22a | ||
|
|
b7734c3926 | ||
|
|
d3faf9c8dc | ||
|
|
bca97a1d14 | ||
|
|
ac9d0f18c5 | ||
|
|
09fa624797 | ||
|
|
b8333e351c | ||
|
|
a01423a196 | ||
|
|
7c35df7a82 | ||
|
|
2b90f377e6 | ||
|
|
fff7326209 | ||
|
|
c181e500bc | ||
|
|
16b7271826 | ||
|
|
4a1f62b185 | ||
|
|
d23a0754c1 | ||
|
|
3ffb563a44 | ||
|
|
4e42f2a017 | ||
|
|
a0dfdb79df | ||
|
|
a85c5f9d4e | ||
|
|
2720bba5b7 | ||
|
|
4634a7bc2f | ||
|
|
16d9b449c9 | ||
|
|
8761997757 | ||
|
|
19bba4abbc | ||
|
|
7839f0aac5 | ||
|
|
83def1db30 | ||
|
|
a0b29d1ffe | ||
|
|
f5479c56af | ||
|
|
246f0a45c8 | ||
|
|
fe871aad77 | ||
|
|
6f860e1bc4 | ||
|
|
249ea40ae3 | ||
|
|
20d8ae19a7 | ||
|
|
ad51aabfd7 | ||
|
|
1cf395c041 | ||
|
|
745179a5bf | ||
|
|
ff5d477fa5 | ||
|
|
907825601d | ||
|
|
c2ec26910a | ||
|
|
83f2aea123 | ||
|
|
a5c5439315 | ||
|
|
eca9b60235 | ||
|
|
d2d5d98d78 | ||
|
|
fb341b869b | ||
|
|
29e66cb186 | ||
|
|
307769b949 | ||
|
|
9a09e057d6 | ||
|
|
3e28659528 | ||
|
|
b861eef26f | ||
|
|
caaf006a49 | ||
|
|
b2429ec30c | ||
|
|
55aaf60a57 | ||
|
|
a5790d82f6 | ||
|
|
63f99af1e6 | ||
|
|
4eed2568aa | ||
|
|
fb7962c7f2 | ||
|
|
76e6b7b471 | ||
|
|
fccb7ff9ed | ||
|
|
3b12ef2e66 | ||
|
|
f9d099be1b | ||
|
|
c322c0e3a5 | ||
|
|
530fc20596 | ||
|
|
a23b4ed754 | ||
|
|
fc4f5077b0 | ||
|
|
6a553886da | ||
|
|
1065c7e722 | ||
|
|
a9c8a59f58 | ||
|
|
8730f7fd27 | ||
|
|
8f608223d7 | ||
|
|
a7cbd47a2f | ||
|
|
b80c3fe5a8 | ||
|
|
5080051e39 | ||
|
|
23bfc8d0ba | ||
|
|
80e9062041 | ||
|
|
67bd3420ed | ||
|
|
aea081703f | ||
|
|
f300d2a2d5 | ||
|
|
f150d7d83a | ||
|
|
4d1f059c0d | ||
|
|
bc7f953fcc | ||
|
|
f653483eea | ||
|
|
6b200fd36b | ||
|
|
161fc6cdf0 | ||
|
|
6f68ed6bce | ||
|
|
a4592ffdfe | ||
|
|
7cd7bd1a48 | ||
|
|
9eeca70292 | ||
|
|
02bfe30848 | ||
|
|
c9c99de3d9 | ||
|
|
8752f0cc60 | ||
|
|
5c65196e44 | ||
|
|
f5798bfe90 | ||
|
|
0e556b3468 | ||
|
|
31820f56e7 | ||
|
|
fd88828abd | ||
|
|
ae11159918 | ||
|
|
472a8605c0 | ||
|
|
e1760ba211 | ||
|
|
ce4c0a0aa4 | ||
|
|
64511593c4 | ||
|
|
b0e00dfceb | ||
|
|
fc465b463d | ||
|
|
68ce2e5232 | ||
|
|
81e8bb62ae | ||
|
|
2c13e1b923 | ||
|
|
a0748c2e3b | ||
|
|
40599bb751 | ||
|
|
f3c64ceea7 | ||
|
|
15c60de709 | ||
|
|
6dd316547f | ||
|
|
54c7676a44 | ||
|
|
d25b8966ce | ||
|
|
14a119c48c | ||
|
|
c82515a927 | ||
|
|
26e630c2dd | ||
|
|
13370d2056 | ||
|
|
35282db9e0 | ||
|
|
426fb88ce7 | ||
|
|
2384bd0e10 | ||
|
|
ba3f66d3d1 | ||
|
|
7293a0f670 | ||
|
|
9e86d46267 | ||
|
|
848430f062 | ||
|
|
abd21335c4 | ||
|
|
8fa95f058a | ||
|
|
d4e5ecd497 | ||
|
|
3830f76729 | ||
|
|
83f778fec9 | ||
|
|
cabd24605f | ||
|
|
ae20ba1148 | ||
|
|
3a50b64977 | ||
|
|
8692e74536 | ||
|
|
1c18bd9889 | ||
|
|
60e9d98d0a | ||
|
|
83f6625e0c | ||
|
|
acc09543b7 | ||
|
|
94d8c7e366 | ||
|
|
ea1a0c8b3d | ||
|
|
7bc88c17e4 | ||
|
|
33cf1bc4c3 | ||
|
|
9402e63fe1 | ||
|
|
90e4d494b2 | ||
|
|
da97e948ca | ||
|
|
89a07e8e74 | ||
|
|
3f3d0381e5 | ||
|
|
3649499dba | ||
|
|
a989d088fd | ||
|
|
f79a915136 | ||
|
|
12e8c3d449 | ||
|
|
4f7064575e | ||
|
|
070df826f1 | ||
|
|
fbe48a4b4e | ||
|
|
4dd497fb6d | ||
|
|
907882c0a7 | ||
|
|
d36d5aee3f | ||
|
|
c6824e5f5e | ||
|
|
199c21eede | ||
|
|
5162da5654 | ||
|
|
a1d82f6193 | ||
|
|
ea78e3d0c6 | ||
|
|
3497f00cb4 | ||
|
|
5355d45031 | ||
|
|
26693acc3f | ||
|
|
76e9fef3b2 | ||
|
|
c34308cbd4 | ||
|
|
5a10476010 | ||
|
|
46e80dceec | ||
|
|
90d1835353 | ||
|
|
845fadd0aa | ||
|
|
5748ded52c | ||
|
|
6a737fb734 | ||
|
|
3cd92ccda3 | ||
|
|
54e81aba11 | ||
|
|
d86cb4ded6 | ||
|
|
4d5375f6d6 | ||
|
|
424557fedb | ||
|
|
89251e603f | ||
|
|
a653ed07eb | ||
|
|
ad86deb014 | ||
|
|
9525dc7584 | ||
|
|
cd31dd27fd | ||
|
|
360e3670eb | ||
|
|
8dabe3b4c8 | ||
|
|
443e0c2806 | ||
|
|
9cc173cc4d | ||
|
|
b5f33e5ecd | ||
|
|
40dfc6860f | ||
|
|
1c02a04423 | ||
|
|
de0e45070c | ||
|
|
c169cc7d74 | ||
|
|
cd62ad76f6 | ||
|
|
dd25b0fb5b | ||
|
|
a38b22a6a2 | ||
|
|
830b8f2971 | ||
|
|
b058af122c | ||
|
|
174ee0cafc | ||
|
|
1c336380c0 | ||
|
|
3068880413 | ||
|
|
be596681e5 | ||
|
|
66b71c50e9 | ||
|
|
8744810b25 | ||
|
|
7f94d37c2e | ||
|
|
6d9b7baeb4 | ||
|
|
4470d4c352 | ||
|
|
d2a462a279 | ||
|
|
14ff2a15e7 | ||
|
|
6d1369900e | ||
|
|
1f17ebe69e | ||
|
|
1ae2918064 | ||
|
|
b6571e5cad | ||
|
|
7549d48cf1 | ||
|
|
00353dd0cb | ||
|
|
afd947195d | ||
|
|
e57ef37167 | ||
|
|
ef33a93654 | ||
|
|
61732aecfc | ||
|
|
6764c05c3f | ||
|
|
fa149cf4aa | ||
|
|
e4f9697d06 | ||
|
|
da061450e5 | ||
|
|
d09ae49287 | ||
|
|
511ee0bbaf | ||
|
|
3cb5a0fbd6 | ||
|
|
e06925ab85 | ||
|
|
184634e4e7 | ||
|
|
843c2d02cc | ||
|
|
8ea2455766 | ||
|
|
9dc9987d56 | ||
|
|
3458621147 | ||
|
|
079df5a47c | ||
|
|
ddb07c65a1 | ||
|
|
9b21cd222b | ||
|
|
90f736843f | ||
|
|
13c020eb61 | ||
|
|
dbc06dbe95 | ||
|
|
23d097bc1c | ||
|
|
db85b9808e | ||
|
|
df5bae37bc | ||
|
|
acc23b6051 | ||
|
|
61f2741afc | ||
|
|
4dd7ea886a | ||
|
|
1e8959fbcf | ||
|
|
48729678cf | ||
|
|
0684becaa7 | ||
|
|
db16bdf8cb | ||
|
|
f890318ed9 | ||
|
|
158510cbbe | ||
|
|
ce90cf7aa8 | ||
|
|
a3a3d006eb | ||
|
|
8fd029a4a1 | ||
|
|
2e1b52c1e5 | ||
|
|
3eb8348708 | ||
|
|
393f0c007c | ||
|
|
294e380288 | ||
|
|
4c1c42efac | ||
|
|
c062ca8c66 | ||
|
|
76dcb25103 | ||
|
|
c5b4f236db | ||
|
|
0974c940a8 | ||
|
|
cffa20d37e | ||
|
|
ef009edd29 | ||
|
|
3ca52b118d | ||
|
|
13f5fde4fb | ||
|
|
f512b55ec2 | ||
|
|
22b8ca0095 | ||
|
|
baf66a103d | ||
|
|
45faa9c1ff | ||
|
|
304381a88d | ||
|
|
fc9f54dbc8 | ||
|
|
7199dc187f | ||
|
|
e9ae066d53 | ||
|
|
d71ae406ff | ||
|
|
f3216904b3 | ||
|
|
5958b69ec9 | ||
|
|
7d4e2cb39a | ||
|
|
a483ec0cea | ||
|
|
c1421e0874 | ||
|
|
ce89869c3c | ||
|
|
b8b57e34ff | ||
|
|
bc7f627253 | ||
|
|
652156e398 | ||
|
|
9febb071c6 | ||
|
|
7d0e1568ac | ||
|
|
b4e711f411 | ||
|
|
1b5be1b981 | ||
|
|
49d8707c58 | ||
|
|
9192f6f7f7 | ||
|
|
05022e3745 | ||
|
|
5356e9ddeb | ||
|
|
52acf76e2c | ||
|
|
40cdbd3b45 | ||
|
|
5487c0befe | ||
|
|
8bb16c48c0 | ||
|
|
c6384363f9 | ||
|
|
8993e8ad3e | ||
|
|
289989d9f7 | ||
|
|
dc2ae0e6f1 | ||
|
|
9c966c152d | ||
|
|
4efae41048 | ||
|
|
b8437032e9 | ||
|
|
2d339ca81b | ||
|
|
d53abc9696 | ||
|
|
446c886d38 | ||
|
|
30c6d9b5ae | ||
|
|
5e42996b36 | ||
|
|
ceca7b85bf | ||
|
|
a4d54f58c8 | ||
|
|
005a0e1bad | ||
|
|
46d97fd57d | ||
|
|
72a26b6353 | ||
|
|
89a4033fbf | ||
|
|
39a5dc64bd | ||
|
|
d4bdd9b1b7 | ||
|
|
2f5ba87280 | ||
|
|
8b45d6c750 | ||
|
|
4ecd4df2d4 | ||
|
|
a42f31fe52 | ||
|
|
d4480b695e | ||
|
|
c4b5f7fbae | ||
|
|
ba915f2cc0 | ||
|
|
4b91140f31 | ||
|
|
9879878dd0 | ||
|
|
d78105d57c | ||
|
|
153c9e3565 | ||
|
|
c11623596d | ||
|
|
e791a77f77 | ||
|
|
b641bffb2c | ||
|
|
ee0c47ac1e | ||
|
|
eba90e9343 | ||
|
|
d8374d0fa5 | ||
|
|
fa61744c6d | ||
|
|
4fec55cc01 | ||
|
|
1767413712 | ||
|
|
734c8fa84f | ||
|
|
9a8d422554 | ||
|
|
b21e945c76 | ||
|
|
a02bf1ea09 | ||
|
|
eda82bac92 | ||
|
|
e8d4f7dc4f | ||
|
|
c4a93b7789 | ||
|
|
c3f9925097 | ||
|
|
2a0cf7511a | ||
|
|
d0a70d3339 | ||
|
|
f37e4675dd | ||
|
|
4e32f67eeb | ||
|
|
36d54cab52 | ||
|
|
9d8df10dcf | ||
|
|
45ea88e070 | ||
|
|
d5d0b947f5 | ||
|
|
f775f1f11e | ||
|
|
f1e888f3de | ||
|
|
71c8436e90 | ||
|
|
08c69f5e9b | ||
|
|
a50fafaca2 | ||
|
|
3c6781d240 | ||
|
|
3b8b5625f8 | ||
|
|
6be2034110 | ||
|
|
924dc79f00 | ||
|
|
ccb9030d3c | ||
|
|
8623287ac1 | ||
|
|
022c13f3a4 | ||
|
|
0687916e7f | ||
|
|
bb868b83ba | ||
|
|
24298130b9 | ||
|
|
6e5ee92ebd | ||
|
|
5b91fe04aa | ||
|
|
1623deb3ee | ||
|
|
4a16e05b7a | ||
|
|
f1c04bc60d | ||
|
|
84c6f31c76 | ||
|
|
9d528190bf | ||
|
|
0f23b209ad | ||
|
|
63d9325900 | ||
|
|
f342097f81 | ||
|
|
b4806c4366 | ||
|
|
ff37d8a577 | ||
|
|
a773eb7893 | ||
|
|
7c67513d24 | ||
|
|
6ed85029c5 | ||
|
|
e9c57ddf4d | ||
|
|
a33ce97ed9 | ||
|
|
b788a3dd4e | ||
|
|
fccfa92d7e | ||
|
|
8705bf0a70 | ||
|
|
9318138af7 | ||
|
|
269fa7d2d5 | ||
|
|
e99837a8b9 | ||
|
|
553861a2c4 | ||
|
|
628a85d1be | ||
|
|
2cb54514a4 | ||
|
|
6db22827f2 | ||
|
|
4cc6d5426b | ||
|
|
7d258b5202 | ||
|
|
c8d19ee0bc | ||
|
|
d891312032 | ||
|
|
5edbf4ce32 | ||
|
|
3ddbdd713d | ||
|
|
9ba107b511 | ||
|
|
c9adddb76a | ||
|
|
f0a12d5ff5 | ||
|
|
7cce224499 | ||
|
|
97397ca585 | ||
|
|
f2fbc602a8 | ||
|
|
925d728a86 | ||
|
|
f5f229871b | ||
|
|
9917552b4b | ||
|
|
adca89b973 | ||
|
|
29bfbecdc9 | ||
|
|
1a7a8c98d9 | ||
|
|
cddb38ac3d | ||
|
|
394853c0fb | ||
|
|
c0702c8b36 | ||
|
|
d610608391 | ||
|
|
9082eec91d | ||
|
|
f1a1413b5f | ||
|
|
c1e7f9af9b | ||
|
|
1c71c4e38b | ||
|
|
5e3eccb3f6 | ||
|
|
e1dc037eb9 | ||
|
|
97e9b4c801 | ||
|
|
52d7cad735 | ||
|
|
c0b1d270ba | ||
|
|
e59a2892e4 | ||
|
|
5fa0376a49 | ||
|
|
05a33042c8 | ||
|
|
ce58f23cbc | ||
|
|
b6fc9fa370 | ||
|
|
00ae38faae | ||
|
|
ab28ee58ab | ||
|
|
48db538a2e | ||
|
|
46945942e1 | ||
|
|
a24b26a1ef | ||
|
|
6f8421cdd5 | ||
|
|
284cd9bca9 | ||
|
|
23fd6b8d2b | ||
|
|
4f0ea5d756 | ||
|
|
6c218331b1 | ||
|
|
cea7fb7490 | ||
|
|
8acf2dbdfe | ||
|
|
0542700f90 | ||
|
|
5264f7ce18 | ||
|
|
051ffd78a3 | ||
|
|
bea95d4fae | ||
|
|
fdf7bc312f | ||
|
|
5b094e1097 | ||
|
|
9ad3968084 | ||
|
|
3958b6aae1 | ||
|
|
eaa413caf0 | ||
|
|
9095225b5b | ||
|
|
c529f86dbc | ||
|
|
e4fcfa356a | ||
|
|
8218cff7c1 | ||
|
|
6949bbcf39 | ||
|
|
480c60c0a7 | ||
|
|
eec10cb5db | ||
|
|
02c83d8689 | ||
|
|
72b1cacea1 | ||
|
|
c72cda3386 | ||
|
|
867442155e | ||
|
|
229b14b6fc | ||
|
|
158c87ab8b | ||
|
|
cb303e6109 | ||
|
|
a77a8741b5 | ||
|
|
3d63459c25 | ||
|
|
ce63de3c58 | ||
|
|
4b3b1219b5 | ||
|
|
73b069a76c | ||
|
|
101cf8d108 | ||
|
|
2e926dfb6e | ||
|
|
501866d12a | ||
|
|
39bcb0869f | ||
|
|
a7b99cde4e | ||
|
|
60abcd92a3 | ||
|
|
cdd36e7052 | ||
|
|
c6ac175ce4 | ||
|
|
46bcd87c23 | ||
|
|
ab74be8e33 | ||
|
|
d8298b3eab | ||
|
|
50e60e6d05 | ||
|
|
5d02acbf37 | ||
|
|
8901d91f96 | ||
|
|
b55021bb3d | ||
|
|
0ef51b85e6 | ||
|
|
c77566cc02 | ||
|
|
c1bcedfb51 | ||
|
|
d085a3c7d7 | ||
|
|
46fa07e4a9 | ||
|
|
a8d5309c90 | ||
|
|
77c2bfcc1e | ||
|
|
4c8712d683 | ||
|
|
d337140577 | ||
|
|
99c273a293 | ||
|
|
85578a06b7 | ||
|
|
6f70a8efda | ||
|
|
c693e39196 | ||
|
|
4a1fae3cb4 | ||
|
|
08b592816b | ||
|
|
0e85fcfe51 | ||
|
|
8ef788e799 | ||
|
|
645c8899b1 | ||
|
|
9bf5b0fc48 | ||
|
|
07959a3bff | ||
|
|
86a6182e41 | ||
|
|
89e229ab75 | ||
|
|
624917fac4 | ||
|
|
489894c61d | ||
|
|
ac87979cb7 | ||
|
|
5fd3e85a83 | ||
|
|
0e53ba4311 | ||
|
|
3ce57ef851 | ||
|
|
481570d059 | ||
|
|
04442b7ddb | ||
|
|
e1a71723bc | ||
|
|
f044fb8b47 | ||
|
|
e3350d5bec | ||
|
|
8a69d4354e | ||
|
|
dd6a9c26bd | ||
|
|
49fb4034c6 | ||
|
|
5a466d0ff6 | ||
|
|
bb850bb6c5 | ||
|
|
25cf6823d0 | ||
|
|
7e12744b8b | ||
|
|
8f2432e0f8 | ||
|
|
94451db638 | ||
|
|
f8b8eeec3a | ||
|
|
a4260cc5de | ||
|
|
8c1622798b | ||
|
|
e75bed1be5 | ||
|
|
8c0517de0f | ||
|
|
94e78365a5 | ||
|
|
29c056ca65 | ||
|
|
d8c57f27db | ||
|
|
3cac2bad55 | ||
|
|
e7905fdf49 | ||
|
|
a492bc2242 | ||
|
|
e663364f64 | ||
|
|
ef6466e26f | ||
|
|
7fcbbf1cdc | ||
|
|
ec6ad51ff7 | ||
|
|
1e80c59448 | ||
|
|
e48cb4fd5d | ||
|
|
7c9fbd2625 | ||
|
|
0f504415fb | ||
|
|
4998c324d1 | ||
|
|
fb5fbe76e8 | ||
|
|
223b0bfc88 | ||
|
|
51094a68c8 | ||
|
|
83cb1ec911 | ||
|
|
a77e4bfb7a | ||
|
|
654c177333 | ||
|
|
b92669ba33 | ||
|
|
f2e4f6607d | ||
|
|
5ec909c565 | ||
|
|
a84f31d54a | ||
|
|
e0dd21406d | ||
|
|
72f5f7a0b8 | ||
|
|
e3d20085c5 | ||
|
|
8bf1aef801 | ||
|
|
5f7ade20dc | ||
|
|
70d7e52df0 | ||
|
|
8e6afa5614 | ||
|
|
a1ae3804e3 | ||
|
|
814ce7a43b | ||
|
|
628f75009e | ||
|
|
03fc8c1202 | ||
|
|
8c8e996c87 | ||
|
|
933bb0b1fb | ||
|
|
931fbc3eb5 | ||
|
|
3db5e70a3d | ||
|
|
7b19b70d90 | ||
|
|
99b8103d70 | ||
|
|
7167310ccd | ||
|
|
263667a2d4 | ||
|
|
d5cef291f6 | ||
|
|
c8d166e833 | ||
|
|
6e25782d8b | ||
|
|
c3127f7e84 | ||
|
|
7b90fb018b | ||
|
|
e8bc173cd7 | ||
|
|
4d1cdf5207 | ||
|
|
57a473364e | ||
|
|
40b62e9d38 | ||
|
|
ead5f9926b | ||
|
|
814b6753c2 | ||
|
|
ce505251f8 | ||
|
|
5d2a987aaa | ||
|
|
4d67e08723 | ||
|
|
2e71dd5fe2 | ||
|
|
c3b9643227 | ||
|
|
0aad5dc2b7 | ||
|
|
cec900168f | ||
|
|
f9b1c403d5 | ||
|
|
9024b602f5 | ||
|
|
c139fd9a57 | ||
|
|
e299b68163 | ||
|
|
7777a53a82 | ||
|
|
3e185dbbfe | ||
|
|
e8a32af369 | ||
|
|
7b0ec6687e | ||
|
|
ec1c6c7b92 | ||
|
|
8dfaa86760 | ||
|
|
323aebd1be | ||
|
|
436c038a2f | ||
|
|
ccd50ec6c0 | ||
|
|
a7541c2c0f | ||
|
|
c3a57d756c | ||
|
|
aa300a4c98 | ||
|
|
83ea7352b9 | ||
|
|
9050712cd8 | ||
|
|
8d92fdbb6e | ||
|
|
a2442ec1b9 | ||
|
|
71662c9cd9 | ||
|
|
54ff5dbcc2 | ||
|
|
4ab7bd3b51 | ||
|
|
ef3c61a297 | ||
|
|
abf79bf60c | ||
|
|
5d3cecd926 | ||
|
|
16324e7283 | ||
|
|
9f7e2e1572 | ||
|
|
857ce1d530 | ||
|
|
be0d72775d | ||
|
|
7832a2495b | ||
|
|
0506b7f735 | ||
|
|
4c0b7942f0 | ||
|
|
651c840c4a | ||
|
|
2a351ca415 | ||
|
|
49b7106d71 | ||
|
|
8bf633f539 | ||
|
|
0f8efcb4b0 | ||
|
|
c567641c5c | ||
|
|
bdc3820382 | ||
|
|
33a69a7907 | ||
|
|
a4d0e9bbc3 | ||
|
|
afc753e1d2 | ||
|
|
e641a41224 | ||
|
|
79305c0632 | ||
|
|
ef2ce3f09d | ||
|
|
71c18c04fc | ||
|
|
cf84e57f81 | ||
|
|
9421d44579 | ||
|
|
5cd2ae8cc8 | ||
|
|
22d67b3a59 | ||
|
|
e102cbb8c4 | ||
|
|
d90eeb7ee4 | ||
|
|
1989d53031 | ||
|
|
04ef0907b4 | ||
|
|
517b43561c | ||
|
|
ccb8c7227f | ||
|
|
9fbfeeb04f | ||
|
|
8b753a5a1f | ||
|
|
d25cab0627 | ||
|
|
84da0a8a35 | ||
|
|
6f665cffba | ||
|
|
aea8ac2e97 | ||
|
|
8418fa7b45 | ||
|
|
9cc4d0ee07 | ||
|
|
da60831c44 | ||
|
|
0773174a20 | ||
|
|
70e007d8ca | ||
|
|
fcc4d02c2f | ||
|
|
f4a5f00593 | ||
|
|
1170ed6566 | ||
|
|
883f0d449b | ||
|
|
f4c62e7844 | ||
|
|
f0d212a9d2 | ||
|
|
76a8974034 | ||
|
|
0614e822f4 | ||
|
|
6f682c9a2e | ||
|
|
a9fdbc31c5 | ||
|
|
086fdb5856 | ||
|
|
63c8ef4f17 | ||
|
|
736f6523c7 | ||
|
|
8b0b360d25 | ||
|
|
80b84e2ee6 | ||
|
|
b5b7d86f7b | ||
|
|
f20d704390 | ||
|
|
e4e1e2e944 | ||
|
|
6bc7eeb4cc | ||
|
|
656ed5de7b | ||
|
|
a11d695c78 | ||
|
|
c4f9acd5c5 | ||
|
|
5ef929dc42 | ||
|
|
c8cf27b544 | ||
|
|
bb5ecfc398 | ||
|
|
c91e7c35bb | ||
|
|
532d56df2d | ||
|
|
111ad44029 | ||
|
|
6b02bae957 | ||
|
|
6831743416 | ||
|
|
63e2f42636 | ||
|
|
f6e6805453 | ||
|
|
ad77ad8f2b | ||
|
|
469524e8ae | ||
|
|
f4f55d5dfd | ||
|
|
c248d0f3f4 | ||
|
|
648a04b513 | ||
|
|
bdc86c16ec | ||
|
|
21efd17c17 | ||
|
|
aaa75e7b62 | ||
|
|
6d0cef3152 | ||
|
|
c18472289f | ||
|
|
02b7c70a81 | ||
|
|
4eaa2b93c6 | ||
|
|
d347905373 | ||
|
|
f495213b2c | ||
|
|
9b125913ae | ||
|
|
da81f05804 | ||
|
|
9a371a4d4d | ||
|
|
1e92828f1a | ||
|
|
7e724b3fa3 | ||
|
|
3f5b976a87 | ||
|
|
49f2339cc2 | ||
|
|
29f1699de8 | ||
|
|
c415485801 | ||
|
|
6937673472 | ||
|
|
c4f10fe876 | ||
|
|
55ca652ad8 | ||
|
|
3effd5afd1 | ||
|
|
000c2029de | ||
|
|
ab88e3af06 | ||
|
|
b544a4c954 | ||
|
|
baff5fafec | ||
|
|
1673de73ba | ||
|
|
e68936e36e | ||
|
|
7dbd195e45 | ||
|
|
3dc22f98bf | ||
|
|
805e870c18 | ||
|
|
de2c031797 | ||
|
|
3aa571aa1b | ||
|
|
3e4969efe6 | ||
|
|
446e94df76 | ||
|
|
5b26066a4c | ||
|
|
8a80de5c3f | ||
|
|
52a490c87e | ||
|
|
29490741fd | ||
|
|
f0e416455f | ||
|
|
f7a2c97943 | ||
|
|
993853757b | ||
|
|
a3abfb987d | ||
|
|
2711fa1b1b | ||
|
|
1f7afaba07 | ||
|
|
e02c8bff81 | ||
|
|
22391ba1a5 | ||
|
|
a05781ec19 | ||
|
|
f898ed6a2a | ||
|
|
e6d0a15b54 | ||
|
|
49cff026e2 | ||
|
|
08f0023cfd | ||
|
|
e311466ee6 | ||
|
|
56789e68d7 | ||
|
|
87525bb383 | ||
|
|
bb2880191a | ||
|
|
4f1acf26d6 | ||
|
|
fc2d6b21ac | ||
|
|
b9e84fefbd | ||
|
|
91f5ffb2d9 | ||
|
|
70ff2341cb | ||
|
|
74eed93497 | ||
|
|
d02e26c014 | ||
|
|
523cade7c3 | ||
|
|
e22c183ca9 | ||
|
|
3afd99da30 | ||
|
|
f44979f983 | ||
|
|
095f9cc108 | ||
|
|
1089076fce | ||
|
|
cad3b691a9 | ||
|
|
bac21426d3 | ||
|
|
c4a35314cd | ||
|
|
7090722565 | ||
|
|
6d972c7c18 | ||
|
|
6961a88feb | ||
|
|
c41ec13984 | ||
|
|
ca8e06e562 | ||
|
|
200cd33a8e | ||
|
|
1da7991c65 | ||
|
|
fdfb7e369a | ||
|
|
c2b01cc957 | ||
|
|
5de8e94bb4 | ||
|
|
7a2c15d912 | ||
|
|
70344dd214 | ||
|
|
405372d1a7 | ||
|
|
b8c5174da5 | ||
|
|
1f6f9103d9 | ||
|
|
6431487c7a | ||
|
|
8b2d1189db | ||
|
|
b777f27cb7 | ||
|
|
b31c3b124a | ||
|
|
fa1e965fba | ||
|
|
91dc8b4d58 | ||
|
|
6d16ea8830 | ||
|
|
7db4253264 | ||
|
|
4d2b7d9bf9 | ||
|
|
8f6f4acb88 | ||
|
|
f20d84cb37 | ||
|
|
afbdf1d5d5 | ||
|
|
bc8364d594 | ||
|
|
c8d388f70f | ||
|
|
be13cc3194 | ||
|
|
a46320e744 | ||
|
|
071709d263 | ||
|
|
93a32ae5ff | ||
|
|
eee96f226f | ||
|
|
e19a8b479c | ||
|
|
9ef459112e | ||
|
|
e96474bd5c | ||
|
|
6fed719e09 | ||
|
|
99aac76618 | ||
|
|
599f458201 | ||
|
|
2f8099059c | ||
|
|
e24f177832 | ||
|
|
48cc143e88 | ||
|
|
b09b46c045 | ||
|
|
2c6583cc9c | ||
|
|
e381d1bfb8 | ||
|
|
eac619d54f | ||
|
|
a6ef3bc0ce | ||
|
|
118122c541 | ||
|
|
bfdf33ac09 | ||
|
|
fa3370df5b | ||
|
|
f1e51672c5 | ||
|
|
91f97b2728 | ||
|
|
2c542e03fe | ||
|
|
71a11b4267 | ||
|
|
ea642757db | ||
|
|
fb72b601aa | ||
|
|
27e507e744 | ||
|
|
4db19f816f | ||
|
|
096d5776d1 | ||
|
|
3d799eb4d9 | ||
|
|
e4ac3afa4d | ||
|
|
d38e4eed5b | ||
|
|
97787fac91 | ||
|
|
b494ee2f1c | ||
|
|
31ac80a074 | ||
|
|
c8896450f6 | ||
|
|
c662fa4c63 | ||
|
|
db2ee802ca | ||
|
|
d40e915e2b | ||
|
|
c0616e7efa | ||
|
|
01660597e3 | ||
|
|
c5b549f450 | ||
|
|
802d8457bb | ||
|
|
c3a3df67b0 | ||
|
|
5798aeb3cd | ||
|
|
cc81dd9172 | ||
|
|
44fdadda08 | ||
|
|
66a014150b | ||
|
|
1da596639f | ||
|
|
76614ae9e5 | ||
|
|
6ddddffc0f | ||
|
|
dd95f849d4 | ||
|
|
22c7f8fe9e | ||
|
|
3d47be1f49 | ||
|
|
5e399c46b1 | ||
|
|
38e1db7a37 | ||
|
|
8309f7cdbe | ||
|
|
b8cc62ae95 | ||
|
|
c0eb433fa2 | ||
|
|
7f857d66f6 | ||
|
|
93b14d38f4 | ||
|
|
21825faab0 | ||
|
|
1fafd39298 | ||
|
|
23b750fc4f | ||
|
|
90581c840d | ||
|
|
cac7a6228a | ||
|
|
674fbc3f69 | ||
|
|
9577bf1cc7 | ||
|
|
654ebe93e7 | ||
|
|
ecb1b3c491 | ||
|
|
c3d1711edc | ||
|
|
c12c7f10f0 | ||
|
|
f71820bf4e | ||
|
|
748c53c774 | ||
|
|
b290a71bfb | ||
|
|
3204c51eca | ||
|
|
2c4b8a44dc | ||
|
|
943aa05eaa | ||
|
|
d0fd36e7e1 | ||
|
|
f45ff5fd0a | ||
|
|
c22c7102d5 | ||
|
|
11ecfd1b41 | ||
|
|
798e30e5ac | ||
|
|
15e0702329 | ||
|
|
a2bc22c37d | ||
|
|
8093fcc64c | ||
|
|
800419e7cc | ||
|
|
a241dc6785 | ||
|
|
805bea0d5f | ||
|
|
9d394adf24 | ||
|
|
2074f27aff | ||
|
|
283ad48b86 | ||
|
|
07e10a7943 | ||
|
|
2812a5026c | ||
|
|
3a20461abf | ||
|
|
64ae3d1e21 | ||
|
|
a25d7ea65b | ||
|
|
74ebbdd761 | ||
|
|
a0427b569e | ||
|
|
5346dfdd8b | ||
|
|
3ee4147285 | ||
|
|
c41e486bfc | ||
|
|
eda3ba92fd | ||
|
|
40255290b0 | ||
|
|
af5bc73dc0 | ||
|
|
0247cd4c45 | ||
|
|
916762cc8c | ||
|
|
d6fdf8ca2a | ||
|
|
95708489c9 | ||
|
|
ced0fa4608 | ||
|
|
7e0fbd600f | ||
|
|
f33e4e0323 | ||
|
|
d0fd78497d | ||
|
|
8045019603 | ||
|
|
7d92b9435e | ||
|
|
1e0822703a | ||
|
|
0403ff88ef | ||
|
|
78376d591b | ||
|
|
8e23d0df20 | ||
|
|
9e281d20ab | ||
|
|
644bd4a106 | ||
|
|
7729e66a96 | ||
|
|
d67d6b7948 | ||
|
|
4c4a46bfbe | ||
|
|
4536f9c177 | ||
|
|
977d3bc02e | ||
|
|
eae95dfef5 | ||
|
|
b67d4460ca | ||
|
|
3dea8311b1 | ||
|
|
11f6e98874 | ||
|
|
2609e595f4 | ||
|
|
ac6e41abc8 | ||
|
|
9c17e16d0a | ||
|
|
55e9064307 | ||
|
|
91cabd7d49 | ||
|
|
7456950530 | ||
|
|
8fcdda625d | ||
|
|
40a10ee926 | ||
|
|
c3f7e2645c | ||
|
|
b264af1892 | ||
|
|
43e93e8e22 | ||
|
|
d6c4789688 | ||
|
|
cb31ee6f01 | ||
|
|
f7b694ac56 | ||
|
|
eb809055d4 | ||
|
|
78d9be82b2 | ||
|
|
76a95c0226 | ||
|
|
d3ab8fb04a | ||
|
|
f7a0b63a00 | ||
|
|
a21dd97786 | ||
|
|
04943c0bfa | ||
|
|
203d4d8bfb | ||
|
|
c049a619dc | ||
|
|
cc1b14b607 | ||
|
|
e04a12a8f4 | ||
|
|
a2c82bc583 | ||
|
|
b4dc382f7c | ||
|
|
eca1892e2a | ||
|
|
23a237074e | ||
|
|
219e9eca4f | ||
|
|
413e09fb9e | ||
|
|
3514c37e4c | ||
|
|
95260e303c | ||
|
|
0cef34bdfa | ||
|
|
9838979bbd | ||
|
|
c8910b8e14 | ||
|
|
207fa1d019 | ||
|
|
be0bb591e7 | ||
|
|
bfacdb9c3b | ||
|
|
ae4077ed6c | ||
|
|
6eb3c90e18 | ||
|
|
8c2a53a504 | ||
|
|
74db1e0308 | ||
|
|
b9dfdcef3d | ||
|
|
9d4afeac31 | ||
|
|
14ae2f169a | ||
|
|
55df19142f | ||
|
|
40fd545b2c | ||
|
|
95fb07343e | ||
|
|
4d87906559 | ||
|
|
6b30dced43 | ||
|
|
293a03b7c8 | ||
|
|
c010549f17 | ||
|
|
cc0be22026 | ||
|
|
e5ba26febe | ||
|
|
36f9680eec | ||
|
|
f4f5be5b08 | ||
|
|
d89b056886 | ||
|
|
65424c7db9 | ||
|
|
32a8a847fc | ||
|
|
88fb3dbf60 | ||
|
|
f6bee3aa58 | ||
|
|
5f19f37dcb | ||
|
|
dd36d8ce9e | ||
|
|
865e4b5349 | ||
|
|
e70564752b | ||
|
|
6e0d2f9437 | ||
|
|
291f936097 | ||
|
|
0b2ce48586 | ||
|
|
da87fd9e20 | ||
|
|
d4da4d2575 | ||
|
|
bad20ff483 | ||
|
|
21ad51ffbf | ||
|
|
697c6d5fbe | ||
|
|
293c659053 | ||
|
|
a12507abbd | ||
|
|
4e675b84fb | ||
|
|
c1022feab8 | ||
|
|
ddcfcf21fe | ||
|
|
86a58c3d80 | ||
|
|
abf9a9048d | ||
|
|
b1030a527a | ||
|
|
8d07ba6332 | ||
|
|
4ce37f84e4 | ||
|
|
061d8a3a5f | ||
|
|
374cd5dbb8 | ||
|
|
5ad53c2b9c | ||
|
|
b7684c1c2b | ||
|
|
854d613a81 |
13
.flake8
13
.flake8
@@ -1,13 +0,0 @@
|
||||
[flake8]
|
||||
max-line-length = 176
|
||||
select = E303,W293,W291,W292,E305,E231,E302
|
||||
exclude =
|
||||
.tox,
|
||||
__pycache__,
|
||||
*.pyc,
|
||||
.env
|
||||
venv/*
|
||||
.venv/*
|
||||
reports/*
|
||||
dist/*
|
||||
lib/*
|
||||
147
.github/ISSUE_TEMPLATE/1.bug.yml
vendored
147
.github/ISSUE_TEMPLATE/1.bug.yml
vendored
@@ -1,133 +1,46 @@
|
||||
name: Bug report 🐛
|
||||
description: 项目运行中遇到的Bug或问题。
|
||||
description: Report a bug or unexpected behavior.
|
||||
title: "[Bug] "
|
||||
labels: ['status: needs check']
|
||||
body:
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
### ⚠️ 前置确认
|
||||
1. 网络能够访问openai接口
|
||||
2. python 已安装:版本在 3.7 ~ 3.10 之间
|
||||
3. `git pull` 拉取最新代码
|
||||
4. 执行`pip3 install -r requirements.txt`,检查依赖是否满足
|
||||
5. 拓展功能请执行`pip3 install -r requirements-optional.txt`,检查依赖是否满足
|
||||
6. [FAQS](https://github.com/zhayujie/chatgpt-on-wechat/wiki/FAQs) 中无类似问题
|
||||
> 💡 English is recommended so global developers can help. 推荐使用英文提交,谢谢 ❤️
|
||||
- type: checkboxes
|
||||
attributes:
|
||||
label: 前置确认
|
||||
label: Self check
|
||||
options:
|
||||
- label: 我确认我运行的是最新版本的代码,并且安装了所需的依赖,在[FAQS](https://github.com/zhayujie/chatgpt-on-wechat/wiki/FAQs)中也未找到类似问题。
|
||||
- label: I'm on the latest version and searched [existing issues](https://github.com/zhayujie/CowAgent/issues) (incl. closed) — no duplicate.
|
||||
required: true
|
||||
- type: checkboxes
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: ⚠️ 搜索issues中是否已存在类似问题
|
||||
description: >
|
||||
请在 [历史issue](https://github.com/zhayujie/chatgpt-on-wechat/issues) 中清空输入框,搜索你的问题
|
||||
或相关日志的关键词来查找是否存在类似问题。
|
||||
options:
|
||||
- label: 我已经搜索过issues和disscussions,没有跟我遇到的问题相关的issue
|
||||
required: true
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
请在上方的`title`中填写你对你所遇到问题的简略总结,这将帮助其他人更好的找到相似问题,谢谢❤️。
|
||||
- type: dropdown
|
||||
attributes:
|
||||
label: 操作系统类型?
|
||||
description: >
|
||||
请选择你运行程序的操作系统类型。
|
||||
options:
|
||||
- Windows
|
||||
- Linux
|
||||
- MacOS
|
||||
- Docker
|
||||
- Railway
|
||||
- Windows Subsystem for Linux (WSL)
|
||||
- Other (请在问题中说明)
|
||||
validations:
|
||||
required: true
|
||||
- type: dropdown
|
||||
attributes:
|
||||
label: 运行的python版本是?
|
||||
description: |
|
||||
请选择你运行程序的`python`版本。
|
||||
注意:在`python 3.7`中,有部分可选依赖无法安装。
|
||||
经过长时间的观察,我们认为`python 3.8`是兼容性最好的版本。
|
||||
`python 3.7`~`python 3.10`以外版本的issue,将视情况直接关闭。
|
||||
options:
|
||||
- python 3.7
|
||||
- python 3.8
|
||||
- python 3.9
|
||||
- python 3.10
|
||||
- other
|
||||
validations:
|
||||
required: true
|
||||
- type: dropdown
|
||||
attributes:
|
||||
label: 使用的chatgpt-on-wechat版本是?
|
||||
description: |
|
||||
请确保你使用的是 [releases](https://github.com/zhayujie/chatgpt-on-wechat/releases) 中的最新版本。
|
||||
如果你使用git, 请使用`git branch`命令来查看分支。
|
||||
options:
|
||||
- Latest Release
|
||||
- Master (branch)
|
||||
validations:
|
||||
required: true
|
||||
- type: dropdown
|
||||
attributes:
|
||||
label: 运行的`channel`类型是?
|
||||
description: |
|
||||
请确保你正确配置了该`channel`所需的配置项,所有可选的配置项都写在了[该文件中](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/config.py),请将所需配置项填写在根目录下的`config.json`文件中。
|
||||
options:
|
||||
- wx(个人微信, itchat)
|
||||
- wxy(个人微信, wechaty)
|
||||
- wechatmp(公众号, 订阅号)
|
||||
- wechatmp_service(公众号, 服务号)
|
||||
- terminal
|
||||
- other
|
||||
label: Environment
|
||||
description: "Version (`cow status`), OS, Python version, install method, model & channel."
|
||||
placeholder: |
|
||||
Version: v1.2.0
|
||||
OS: macOS / Linux / Windows / Docker
|
||||
Python: 3.11
|
||||
Install: installer / Docker / source
|
||||
Model & channel: deepseek-v4-flash, web
|
||||
validations:
|
||||
required: true
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: 复现步骤 🕹
|
||||
description: |
|
||||
**⚠️ 不能复现将会关闭issue.**
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: 问题描述 😯
|
||||
description: 详细描述出现的问题,或提供有关截图。
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: 终端日志 📒
|
||||
description: |
|
||||
在此处粘贴终端日志,可在主目录下`run.log`文件中找到,这会帮助我们更好的分析问题,注意隐去你的API key。
|
||||
如果在配置文件中加入`"debug": true`,打印出的日志会更有帮助。
|
||||
label: What happened?
|
||||
description: "Steps to reproduce, what you expected, and what happened instead. Screenshots welcome."
|
||||
placeholder: |
|
||||
1. ...
|
||||
2. ...
|
||||
|
||||
<details>
|
||||
<summary><i>示例</i></summary>
|
||||
```log
|
||||
[DEBUG][2023-04-16 00:23:22][plugin_manager.py:157] - Plugin SUMMARY triggered by event Event.ON_HANDLE_CONTEXT
|
||||
[DEBUG][2023-04-16 00:23:22][main.py:221] - [Summary] on_handle_context. content: $总结前100条消息
|
||||
[DEBUG][2023-04-16 00:23:24][main.py:240] - [Summary] limit: 100, duration: -1 seconds
|
||||
[ERROR][2023-04-16 00:23:24][chat_channel.py:244] - Worker return exception: name 'start_date' is not defined
|
||||
Traceback (most recent call last):
|
||||
File "C:\ProgramData\Anaconda3\lib\concurrent\futures\thread.py", line 57, in run
|
||||
result = self.fn(*self.args, **self.kwargs)
|
||||
File "D:\project\chatgpt-on-wechat\channel\chat_channel.py", line 132, in _handle
|
||||
reply = self._generate_reply(context)
|
||||
File "D:\project\chatgpt-on-wechat\channel\chat_channel.py", line 142, in _generate_reply
|
||||
e_context = PluginManager().emit_event(EventContext(Event.ON_HANDLE_CONTEXT, {
|
||||
File "D:\project\chatgpt-on-wechat\plugins\plugin_manager.py", line 159, in emit_event
|
||||
instance.handlers[e_context.event](e_context, *args, **kwargs)
|
||||
File "D:\project\chatgpt-on-wechat\plugins\summary\main.py", line 255, in on_handle_context
|
||||
records = self._get_records(session_id, start_time, limit)
|
||||
File "D:\project\chatgpt-on-wechat\plugins\summary\main.py", line 96, in _get_records
|
||||
c.execute("SELECT * FROM chat_records WHERE sessionid=? and timestamp>? ORDER BY timestamp DESC LIMIT ?", (session_id, start_date, limit))
|
||||
NameError: name 'start_date' is not defined
|
||||
[INFO][2023-04-16 00:23:36][app.py:14] - signal 2 received, exiting...
|
||||
```
|
||||
</details>
|
||||
value: |
|
||||
```log
|
||||
<此处粘贴终端日志>
|
||||
```
|
||||
Expected: ...
|
||||
Actual: ...
|
||||
validations:
|
||||
required: true
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Logs
|
||||
description: "Relevant logs from `run.log` (set `\"debug\": true` for more detail). ⚠️ Redact your API keys."
|
||||
render: shell
|
||||
validations:
|
||||
required: false
|
||||
|
||||
31
.github/ISSUE_TEMPLATE/2.feature.yml
vendored
31
.github/ISSUE_TEMPLATE/2.feature.yml
vendored
@@ -1,28 +1,33 @@
|
||||
name: Feature request 🚀
|
||||
description: 提出你对项目的新想法或建议。
|
||||
description: Suggest a new idea or improvement.
|
||||
title: "[Feature] "
|
||||
labels: ['status: needs check']
|
||||
body:
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
请在上方的`title`中填写简略总结,谢谢❤️。
|
||||
> 💡 English is recommended so global developers can help. 推荐使用英文提交,谢谢 ❤️
|
||||
- type: checkboxes
|
||||
attributes:
|
||||
label: ⚠️ 搜索是否存在类似issue
|
||||
description: >
|
||||
请在 [历史issue](https://github.com/zhayujie/chatgpt-on-wechat/issues) 中清空输入框,搜索关键词查找是否存在相似issue。
|
||||
label: Self check
|
||||
options:
|
||||
- label: 我已经搜索过issues和disscussions,没有发现相似issue
|
||||
- label: I searched [existing issues](https://github.com/zhayujie/CowAgent/issues) (incl. closed) — no duplicate.
|
||||
required: true
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: 总结
|
||||
description: 描述feature的功能。
|
||||
label: What's the problem?
|
||||
description: "The pain point or what's not working for you right now."
|
||||
validations:
|
||||
required: true
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: 举例
|
||||
description: 提供聊天示例,草图或相关网址。
|
||||
- type: textarea
|
||||
label: What would you like?
|
||||
description: "How you'd expect it to work. Examples, sketches, or links welcome."
|
||||
validations:
|
||||
required: false
|
||||
- type: checkboxes
|
||||
attributes:
|
||||
label: 动机
|
||||
description: 描述你提出该feature的动机,比如没有这项feature对你的使用造成了怎样的影响。 请提供更详细的场景描述,这可能会帮助我们发现并提出更好的解决方案。
|
||||
label: Contribution
|
||||
options:
|
||||
- label: I'd be interested in helping implement this.
|
||||
required: false
|
||||
|
||||
5
.github/ISSUE_TEMPLATE/config.yml
vendored
Normal file
5
.github/ISSUE_TEMPLATE/config.yml
vendored
Normal file
@@ -0,0 +1,5 @@
|
||||
blank_issues_enabled: true
|
||||
contact_links:
|
||||
- name: 📖 Documentation
|
||||
url: https://docs.cowagent.ai
|
||||
about: Setup guides, configuration, and FAQ.
|
||||
21
.github/PULL_REQUEST_TEMPLATE.md
vendored
Normal file
21
.github/PULL_REQUEST_TEMPLATE.md
vendored
Normal file
@@ -0,0 +1,21 @@
|
||||
<!--
|
||||
Thanks for your contribution! Please write this PR in English.
|
||||
【中文开发者】请使用英文填写,感谢 ❤️
|
||||
-->
|
||||
|
||||
## What does this PR do?
|
||||
|
||||
<!-- A short description of the change and why it's needed. -->
|
||||
|
||||
## Type of change
|
||||
|
||||
- [ ] Bug fix
|
||||
- [ ] New feature
|
||||
- [ ] Docs
|
||||
- [ ] Refactor / chore
|
||||
|
||||
## Checklist
|
||||
|
||||
- [ ] I tested this change locally
|
||||
- [ ] Code comments and docs are in English
|
||||
- [ ] Linked related issue (if any): closes #
|
||||
10
.github/workflows/deploy-image-arm.yml
vendored
10
.github/workflows/deploy-image-arm.yml
vendored
@@ -19,6 +19,7 @@ env:
|
||||
|
||||
jobs:
|
||||
build-and-push-image:
|
||||
if: github.repository == 'zhayujie/CowAgent'
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: read
|
||||
@@ -50,7 +51,12 @@ jobs:
|
||||
uses: docker/metadata-action@v4
|
||||
with:
|
||||
images: |
|
||||
${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
|
||||
${{ env.REGISTRY }}/zhayujie/chatgpt-on-wechat
|
||||
${{ env.REGISTRY }}/zhayujie/cowagent
|
||||
tags: |
|
||||
type=raw,value=latest-arm64,enable={{is_default_branch}}
|
||||
type=ref,event=branch,suffix=-arm64
|
||||
type=ref,event=tag,suffix=-arm64
|
||||
|
||||
- name: Build and push Docker image
|
||||
uses: docker/build-push-action@v3
|
||||
@@ -59,7 +65,7 @@ jobs:
|
||||
push: true
|
||||
file: ./docker/Dockerfile.latest
|
||||
platforms: linux/arm64
|
||||
tags: ${{ steps.meta.outputs.tags }}-arm64
|
||||
tags: ${{ steps.meta.outputs.tags }}
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
|
||||
- uses: actions/delete-package-versions@v4
|
||||
|
||||
12
.github/workflows/deploy-image.yml
vendored
12
.github/workflows/deploy-image.yml
vendored
@@ -16,9 +16,11 @@ on:
|
||||
env:
|
||||
REGISTRY: ghcr.io
|
||||
IMAGE_NAME: ${{ github.repository }}
|
||||
DOCKERHUB_IMAGE: zhayujie/chatgpt-on-wechat
|
||||
|
||||
jobs:
|
||||
build-and-push-image:
|
||||
if: github.repository == 'zhayujie/CowAgent'
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: read
|
||||
@@ -46,8 +48,14 @@ jobs:
|
||||
uses: docker/metadata-action@v4
|
||||
with:
|
||||
images: |
|
||||
${{ env.IMAGE_NAME }}
|
||||
${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
|
||||
zhayujie/chatgpt-on-wechat
|
||||
zhayujie/cowagent
|
||||
${{ env.REGISTRY }}/zhayujie/chatgpt-on-wechat
|
||||
${{ env.REGISTRY }}/zhayujie/cowagent
|
||||
tags: |
|
||||
type=raw,value=latest,enable={{is_default_branch}}
|
||||
type=ref,event=branch
|
||||
type=ref,event=tag
|
||||
|
||||
- name: Build and push Docker image
|
||||
uses: docker/build-push-action@v3
|
||||
|
||||
21
.gitignore
vendored
21
.gitignore
vendored
@@ -3,17 +3,19 @@
|
||||
.vscode
|
||||
.venv
|
||||
.vs
|
||||
.wechaty/
|
||||
__pycache__/
|
||||
venv*
|
||||
*.pyc
|
||||
python
|
||||
config.json
|
||||
QR.png
|
||||
nohup.out
|
||||
tmp
|
||||
plugins.json
|
||||
itchat.pkl
|
||||
*.log
|
||||
logs/
|
||||
workspace
|
||||
config.yaml
|
||||
user_datas.pkl
|
||||
chatgpt_tool_hub/
|
||||
plugins/**/
|
||||
@@ -29,4 +31,17 @@ plugins/banwords/lib/__pycache__
|
||||
!plugins/hello
|
||||
!plugins/role
|
||||
!plugins/keyword
|
||||
!plugins/linkai
|
||||
!plugins/linkai
|
||||
!plugins/cow_cli
|
||||
client_config.json
|
||||
ref/
|
||||
**/.dev.vars
|
||||
.cursor/
|
||||
local/
|
||||
node_modules/
|
||||
|
||||
# cow cli
|
||||
dist/
|
||||
build/
|
||||
*.egg-info/
|
||||
.cow.pid
|
||||
|
||||
@@ -1,30 +0,0 @@
|
||||
repos:
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: v4.4.0
|
||||
hooks:
|
||||
- id: fix-byte-order-marker
|
||||
- id: check-case-conflict
|
||||
- id: check-merge-conflict
|
||||
- id: debug-statements
|
||||
- id: pretty-format-json
|
||||
types: [text]
|
||||
files: \.json(.template)?$
|
||||
args: [ --autofix , --no-ensure-ascii, --indent=2, --no-sort-keys]
|
||||
- id: trailing-whitespace
|
||||
exclude: '(\/|^)lib\/'
|
||||
args: [ --markdown-linebreak-ext=md ]
|
||||
- repo: https://github.com/PyCQA/isort
|
||||
rev: 5.12.0
|
||||
hooks:
|
||||
- id: isort
|
||||
exclude: '(\/|^)lib\/'
|
||||
- repo: https://github.com/psf/black
|
||||
rev: 23.3.0
|
||||
hooks:
|
||||
- id: black
|
||||
exclude: '(\/|^)lib\/'
|
||||
- repo: https://github.com/PyCQA/flake8
|
||||
rev: 6.0.0
|
||||
hooks:
|
||||
- id: flake8
|
||||
exclude: '(\/|^)lib\/'
|
||||
61
CONTRIBUTING.md
Normal file
61
CONTRIBUTING.md
Normal file
@@ -0,0 +1,61 @@
|
||||
# Contributing to CowAgent
|
||||
|
||||
Thanks for taking the time to contribute! 🎉 CowAgent is built by a global
|
||||
community, and contributions of all sizes are welcome — from typo fixes to new
|
||||
features.
|
||||
|
||||
## Language policy
|
||||
|
||||
To keep the project accessible to a global community, **please write issues,
|
||||
pull requests, code comments, and commit messages in English.**
|
||||
|
||||
> 为方便全球开发者协作,请尽量使用**英文**提交 issue、PR、代码注释与
|
||||
> commit message。不必担心英文不完美——表达清楚即可,工具翻译也完全没问题。感谢理解 ❤️
|
||||
|
||||
## Reporting issues
|
||||
|
||||
Found a bug or have an idea? [Open an issue](https://github.com/zhayujie/CowAgent/issues/new/choose).
|
||||
|
||||
Before opening one, please search existing issues (including closed ones) to
|
||||
avoid duplicates, and make sure you're on the latest version.
|
||||
|
||||
## Submitting a pull request
|
||||
|
||||
1. **Fork** the repo and create a branch from `master`
|
||||
(e.g. `feat/web-search`, `fix/telegram-reconnect`).
|
||||
2. Make your change. Keep it focused — one logical change per PR.
|
||||
3. Follow the existing code style. Write comments and docstrings in English.
|
||||
4. Run the app locally to confirm your change works.
|
||||
5. Open a PR with a clear title and a short description of **what** and **why**.
|
||||
|
||||
We keep the bar friendly: clear, focused, and working is enough. Maintainers are
|
||||
happy to help polish details during review.
|
||||
|
||||
### Commit & PR titles
|
||||
|
||||
Use a short, imperative summary. The [Conventional Commits](https://www.conventionalcommits.org/)
|
||||
style is preferred but not required:
|
||||
|
||||
```
|
||||
feat: add web search tool
|
||||
fix: reconnect Telegram websocket on timeout
|
||||
docs: clarify Docker setup
|
||||
```
|
||||
|
||||
## Development setup
|
||||
|
||||
See the [Install from Source](https://docs.cowagent.ai/guide/manual-install)
|
||||
guide. In short:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/zhayujie/CowAgent.git
|
||||
cd CowAgent
|
||||
pip install -r requirements.txt
|
||||
pip install -e .
|
||||
cow start
|
||||
```
|
||||
|
||||
## Code of conduct
|
||||
|
||||
Be respectful and constructive. We want CowAgent to be a welcoming place for
|
||||
everyone.
|
||||
395
README.md
395
README.md
@@ -1,278 +1,261 @@
|
||||
# 简介
|
||||
<p align="center"><img src="https://github.com/user-attachments/assets/eca9a9ec-8534-4615-9e0f-96c5ac1d10a3" alt="CowAgent" width="420" /></p>
|
||||
|
||||
> ChatGPT近期以强大的对话和信息整合能力风靡全网,可以写代码、改论文、讲故事,几乎无所不能,这让人不禁有个大胆的想法,能否用他的对话模型把我们的微信打造成一个智能机器人,可以在与好友对话中给出意想不到的回应,而且再也不用担心女朋友影响我们 ~~打游戏~~ 工作了。
|
||||
<p align="center">
|
||||
<a href="https://github.com/zhayujie/CowAgent/releases/latest"><img src="https://img.shields.io/github/v/release/zhayujie/CowAgent" alt="Latest release"></a>
|
||||
<a href="https://github.com/zhayujie/CowAgent/blob/master/LICENSE"><img src="https://img.shields.io/github/license/zhayujie/CowAgent" alt="License: MIT"></a>
|
||||
<a href="https://github.com/zhayujie/CowAgent"><img src="https://img.shields.io/github/stars/zhayujie/CowAgent?style=flat-square" alt="Stars"></a> <br/>
|
||||
[English] | [<a href="docs/zh/README.md">中文</a>] | [<a href="docs/ja/README.md">日本語</a>]
|
||||
</p>
|
||||
|
||||
最新版本支持的功能如下:
|
||||
**CowAgent** is an open-source super AI assistant that proactively plans tasks, controls your computer and external services, creates and runs Skills, and grows alongside you through a personal knowledge base and long-term memory — a reference implementation of Agent Harness engineering.
|
||||
|
||||
- [x] **多端部署:** 有多种部署方式可选择且功能完备,目前已支持个人微信,微信公众号和企业微信应用等部署方式
|
||||
- [x] **基础对话:** 私聊及群聊的消息智能回复,支持多轮会话上下文记忆,支持 GPT-3.5, GPT-4, claude, 文心一言, 讯飞星火
|
||||
- [x] **语音识别:** 可识别语音消息,通过文字或语音回复,支持 azure, baidu, google, openai(whisper/tts) 等多种语音模型
|
||||
- [x] **图片生成:** 支持图片生成 和 图生图(如照片修复),可选择 Dall-E, stable diffusion, replicate, midjourney模型
|
||||
- [x] **丰富插件:** 支持个性化插件扩展,已实现多角色切换、文字冒险、敏感词过滤、聊天记录总结、文档总结和对话等插件
|
||||
- [X] **Tool工具:** 与操作系统和互联网交互,支持最新信息搜索、数学计算、天气和资讯查询、网页总结,基于 [chatgpt-tool-hub](https://github.com/goldfishh/chatgpt-tool-hub) 实现
|
||||
- [x] **知识库:** 通过上传知识库文件自定义专属机器人,可作为数字分身、领域知识库、智能客服使用,基于 [LinkAI](https://link-ai.tech/console) 实现
|
||||
CowAgent is lightweight, easy to deploy, and built to extend. Plug in any major LLM provider and run it 24/7 on a personal computer or server, across the web and all major IM platforms.
|
||||
|
||||
> 欢迎接入更多应用,参考 [Terminal代码](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/channel/terminal/terminal_channel.py)实现接收和发送消息逻辑即可接入。 同时欢迎增加新的插件,参考 [插件说明文档](https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins)。
|
||||
<p align="center">
|
||||
<a href="https://cowagent.ai/">🌐 Website</a> ·
|
||||
<a href="https://docs.cowagent.ai/intro/index">📖 Docs</a> ·
|
||||
<a href="https://docs.cowagent.ai/guide/quick-start">🚀 Quick Start</a> ·
|
||||
<a href="https://skills.cowagent.ai/">🧩 Skill Hub</a> ·
|
||||
<a href="https://link-ai.tech/cowagent/create">☁️ Try Online</a>
|
||||
</p>
|
||||
|
||||
# 演示
|
||||
<br/>
|
||||
|
||||
https://github.com/zhayujie/chatgpt-on-wechat/assets/26161723/d5154020-36e3-41db-8706-40ce9f3f1b1e
|
||||
## 🌟 Highlights
|
||||
|
||||
Demo made by [Visionn](https://www.wangpc.cc/)
|
||||
| Capability | Description |
|
||||
| :--- | :--- |
|
||||
| [Planning](https://docs.cowagent.ai/intro/architecture) | Decomposes complex tasks and executes them step by step, looping over tools until the goal is reached |
|
||||
| [Memory](https://docs.cowagent.ai/memory/index) | Three-tier architecture (context → daily → core), automatic Deep Dream distillation, hybrid keyword + vector retrieval |
|
||||
| [Knowledge](https://docs.cowagent.ai/knowledge/index) | Auto-curates structured knowledge into a Markdown wiki, builds an evolving knowledge graph with visual browsing |
|
||||
| [Skills](https://docs.cowagent.ai/skills/index) | One-click install from [Skill Hub](https://skills.cowagent.ai/), GitHub, ClawHub; or create custom skills via natural-language conversation |
|
||||
| [Tools](https://docs.cowagent.ai/tools/index) | Built-in file I/O, terminal, browser, scheduler, memory retrieval, web search, and 10+ more tools — with native MCP integration |
|
||||
| [Channels](https://docs.cowagent.ai/channels/index) | Integrates with Web, WeChat, Feishu, DingTalk, WeCom, QQ, Official Accounts, Telegram, and Slack |
|
||||
| Multimodal | First-class support for text, images, voice, and files — recognition, generation, and delivery |
|
||||
| [Models](https://docs.cowagent.ai/models/index) | Claude, GPT, Gemini, DeepSeek, Qwen, GLM, Kimi, MiniMax, Doubao, and more — swap providers from the Web console with one click |
|
||||
| [Deploy](https://docs.cowagent.ai/guide/quick-start) | One-line installer, unified Web console, multiple deployment modes (local, Docker, server) |
|
||||
|
||||
# 交流群
|
||||
<br/>
|
||||
|
||||
添加小助手微信进群,请备注 "wechat":
|
||||
## 🏗️ Architecture
|
||||
|
||||
<img width="240" src="./docs/images/contact.jpg">
|
||||
<img src="https://cdn.jsdelivr.net/gh/zhayujie/cowagent-assets@main/architecture/en/architecture.jpg" alt="CowAgent Architecture" width="750"/>
|
||||
|
||||
# 更新日志
|
||||
CowAgent is a complete **Agent Harness**: messages flow in through **Channels**; the **Agent Core** plans and reasons over memory, knowledge, and the available tools and skills; **Models** generate the response, which is sent back through the originating channel. Every layer is decoupled and independently extensible.
|
||||
|
||||
>**2023.11.10:** [1.5.0版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.5.0),新增 `gpt-4-turbo`, `dall-e-3`, `tts` 模型接入,完善图像理解&生成、语音识别&生成的多模态能力
|
||||
Read more in [Architecture](https://docs.cowagent.ai/intro/architecture).
|
||||
|
||||
>**2023.10.16:** 支持通过意图识别使用LinkAI联网搜索、数学计算、网页访问等插件,参考[插件文档](https://docs.link-ai.tech/platform/plugins)
|
||||
<br/>
|
||||
|
||||
>**2023.09.26:** 插件增加 文件/文章链接 一键总结和对话的功能,使用参考:[插件说明](https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/linkai#3%E6%96%87%E6%A1%A3%E6%80%BB%E7%BB%93%E5%AF%B9%E8%AF%9D%E5%8A%9F%E8%83%BD)
|
||||
## 🚀 Quick Start
|
||||
|
||||
>**2023.08.08:** 接入百度文心一言模型,通过 [插件](https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/linkai) 支持 Midjourney 绘图
|
||||
A one-line installer takes care of dependencies, configuration, and startup:
|
||||
|
||||
>**2023.06.12:** 接入 [LinkAI](https://link-ai.tech/console) 平台,可在线创建领域知识库,并接入微信、公众号及企业微信中,打造专属客服机器人。使用参考 [接入文档](https://link-ai.tech/platform/link-app/wechat)。
|
||||
|
||||
>**2023.04.26:** 支持企业微信应用号部署,兼容插件,并支持语音图片交互,私人助理理想选择,[使用文档](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/channel/wechatcom/README.md)。(contributed by [@lanvent](https://github.com/lanvent) in [#944](https://github.com/zhayujie/chatgpt-on-wechat/pull/944))
|
||||
|
||||
>**2023.04.05:** 支持微信公众号部署,兼容插件,并支持语音图片交互,[使用文档](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/channel/wechatmp/README.md)。(contributed by [@JS00000](https://github.com/JS00000) in [#686](https://github.com/zhayujie/chatgpt-on-wechat/pull/686))
|
||||
|
||||
>**2023.04.05:** 增加能让ChatGPT使用工具的`tool`插件,[使用文档](https://github.com/goldfishh/chatgpt-on-wechat/blob/master/plugins/tool/README.md)。工具相关issue可反馈至[chatgpt-tool-hub](https://github.com/goldfishh/chatgpt-tool-hub)。(contributed by [@goldfishh](https://github.com/goldfishh) in [#663](https://github.com/zhayujie/chatgpt-on-wechat/pull/663))
|
||||
|
||||
>**2023.03.25:** 支持插件化开发,目前已实现 多角色切换、文字冒险游戏、管理员指令、Stable Diffusion等插件,使用参考 [#578](https://github.com/zhayujie/chatgpt-on-wechat/issues/578)。(contributed by [@lanvent](https://github.com/lanvent) in [#565](https://github.com/zhayujie/chatgpt-on-wechat/pull/565))
|
||||
|
||||
>**2023.03.09:** 基于 `whisper API`(后续已接入更多的语音`API`服务) 实现对微信语音消息的解析和回复,添加配置项 `"speech_recognition":true` 即可启用,使用参考 [#415](https://github.com/zhayujie/chatgpt-on-wechat/issues/415)。(contributed by [wanggang1987](https://github.com/wanggang1987) in [#385](https://github.com/zhayujie/chatgpt-on-wechat/pull/385))
|
||||
|
||||
>**2023.02.09:** 扫码登录存在账号限制风险,请谨慎使用,参考[#158](https://github.com/AutumnWhj/ChatGPT-wechat-bot/issues/158)
|
||||
|
||||
# 快速开始
|
||||
|
||||
## 准备
|
||||
|
||||
### 1. 账号注册
|
||||
|
||||
项目默认使用OpenAI接口,需前往 [OpenAI注册页面](https://beta.openai.com/signup) 创建账号,创建完账号则前往 [API管理页面](https://beta.openai.com/account/api-keys) 创建一个 API Key 并保存下来,后面需要在项目中配置这个key。接口需要海外网络访问及绑定信用卡支付。
|
||||
|
||||
> 默认对话模型是 openai 的 gpt-3.5-turbo,计费方式是约每 1000 tokens (约750个英文单词 或 500汉字,包含请求和回复) 消耗 $0.002,图片生成是 Dall-E 模型,每张消耗 $0.016。
|
||||
|
||||
项目同时也支持使用 LinkAI 接口,无需代理,可使用 文心、讯飞、GPT-3、GPT-4 等模型,支持 定制化知识库、联网搜索、MJ绘图、文档总结和对话等能力。修改配置即可一键切换,参考 [接入文档](https://link-ai.tech/platform/link-app/wechat)。
|
||||
|
||||
### 2.运行环境
|
||||
|
||||
支持 Linux、MacOS、Windows 系统(可在Linux服务器上长期运行),同时需安装 `Python`。
|
||||
> 建议Python版本在 3.7.1~3.9.X 之间,推荐3.8版本,3.10及以上版本在 MacOS 可用,其他系统上不确定能否正常运行。
|
||||
|
||||
> 注意:Docker 或 Railway 部署无需安装python环境和下载源码,可直接快进到下一节。
|
||||
|
||||
**(1) 克隆项目代码:**
|
||||
**Linux / macOS:**
|
||||
|
||||
```bash
|
||||
git clone https://github.com/zhayujie/chatgpt-on-wechat
|
||||
cd chatgpt-on-wechat/
|
||||
bash <(curl -fsSL https://cdn.link-ai.tech/code/cow/run.sh)
|
||||
```
|
||||
|
||||
**(2) 安装核心依赖 (必选):**
|
||||
> 能够使用`itchat`创建机器人,并具有文字交流功能所需的最小依赖集合。
|
||||
```bash
|
||||
pip3 install -r requirements.txt
|
||||
**Windows (PowerShell):**
|
||||
|
||||
```powershell
|
||||
irm https://cdn.link-ai.tech/code/cow/run.ps1 | iex
|
||||
```
|
||||
|
||||
**(3) 拓展依赖 (可选,建议安装):**
|
||||
**Docker:**
|
||||
|
||||
```bash
|
||||
pip3 install -r requirements-optional.txt
|
||||
curl -O https://cdn.link-ai.tech/code/cow/docker-compose.yml
|
||||
docker compose up -d
|
||||
```
|
||||
> 如果某项依赖安装失败请注释掉对应的行再继续。
|
||||
|
||||
其中`tiktoken`要求`python`版本在3.8以上,它用于精确计算会话使用的tokens数量,强烈建议安装。
|
||||
Once started, open `http://localhost:9899` to access the **Web console** — your one-stop hub to chat with the Agent, configure models, connect channels, and install skills.
|
||||
|
||||
> Deploying on a server? Set `web_host` to `0.0.0.0` in `config.json` to make the console reachable from outside, and set `web_password` to protect it. Don't forget to open port `9899` in your firewall or security group.
|
||||
|
||||
使用`google`或`baidu`语音识别需安装`ffmpeg`,
|
||||
> 📖 Detailed guides: [Quick Start](https://docs.cowagent.ai/guide/quick-start) · [Install from Source](https://docs.cowagent.ai/guide/manual-install) · [Upgrade](https://docs.cowagent.ai/guide/upgrade)
|
||||
|
||||
默认的`openai`语音识别不需要安装`ffmpeg`。
|
||||
|
||||
参考[#415](https://github.com/zhayujie/chatgpt-on-wechat/issues/415)
|
||||
|
||||
使用`azure`语音功能需安装依赖,并参考[文档](https://learn.microsoft.com/en-us/azure/cognitive-services/speech-service/quickstarts/setup-platform?pivots=programming-language-python&tabs=linux%2Cubuntu%2Cdotnet%2Cjre%2Cmaven%2Cnodejs%2Cmac%2Cpypi)的环境要求。
|
||||
:
|
||||
After installation, manage the service with the [cow CLI](https://docs.cowagent.ai/cli/index):
|
||||
|
||||
```bash
|
||||
pip3 install azure-cognitiveservices-speech
|
||||
cow start | stop | restart # service control
|
||||
cow status | logs # status and logs
|
||||
cow update # pull latest code and restart
|
||||
cow skill install <name> # install a skill
|
||||
cow install-browser # install browser automation
|
||||
```
|
||||
|
||||
## 配置
|
||||
<br/>
|
||||
|
||||
配置文件的模板在根目录的`config-template.json`中,需复制该模板创建最终生效的 `config.json` 文件:
|
||||
## 🤖 Models
|
||||
|
||||
CowAgent supports all mainstream LLM providers. **Chat, vision, image generation, ASR/TTS, and embeddings** can each be routed to a different vendor. Providers are configured directly in the Web console — no manual file editing required.
|
||||
|
||||
| Provider | Featured Models | Chat | Vision | Image Gen | ASR | TTS | Embedding |
|
||||
| --- | --- | :-: | :-: | :-: | :-: | :-: | :-: |
|
||||
| [Claude](https://docs.cowagent.ai/models/claude) | claude-opus-4-8 | ✅ | ✅ | | | | |
|
||||
| [OpenAI](https://docs.cowagent.ai/models/openai) | gpt-5.5, o-series | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
|
||||
| [Gemini](https://docs.cowagent.ai/models/gemini) | gemini-3.5-flash | ✅ | ✅ | ✅ | | | |
|
||||
| [DeepSeek](https://docs.cowagent.ai/models/deepseek) | deepseek-v4-flash / pro | ✅ | | | | | |
|
||||
| [Qwen](https://docs.cowagent.ai/models/qwen) | qwen3.7-max | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
|
||||
| [GLM](https://docs.cowagent.ai/models/glm) | glm-5.1, glm-5v-turbo | ✅ | ✅ | | ✅ | | ✅ |
|
||||
| [Doubao](https://docs.cowagent.ai/models/doubao) | doubao-seed-2.0 series | ✅ | ✅ | ✅ | | | ✅ |
|
||||
| [Kimi](https://docs.cowagent.ai/models/kimi) | kimi-k2.6 | ✅ | ✅ | | | | |
|
||||
| [MiniMax](https://docs.cowagent.ai/models/minimax) | MiniMax-M2.7 | ✅ | ✅ | ✅ | | ✅ | |
|
||||
| [ERNIE](https://docs.cowagent.ai/models/qianfan) | ernie-5.1 | ✅ | ✅ | | | | |
|
||||
| [MiMo](https://docs.cowagent.ai/models/mimo) | mimo-v2.5 / pro | ✅ | ✅ | | | ✅ | |
|
||||
| [LinkAI](https://docs.cowagent.ai/models/linkai) | One key for 100+ models | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
|
||||
| [Custom](https://docs.cowagent.ai/models/custom) | Local models / third-party proxy | ✅ | | | | | |
|
||||
|
||||
> For details on each provider, see the [Models overview](https://docs.cowagent.ai/models/index).
|
||||
|
||||
<br/>
|
||||
|
||||
## 💬 Channels
|
||||
|
||||
A single Agent instance can serve multiple channels in parallel. Most channels can be onboarded right from the Web console.
|
||||
|
||||
| Channel | Text | Image | File | Voice | Group |
|
||||
| --- | :-: | :-: | :-: | :-: | :-: |
|
||||
| [Web Console](https://docs.cowagent.ai/channels/web) (default) | ✅ | ✅ | ✅ | ✅ | |
|
||||
| [Telegram](https://docs.cowagent.ai/channels/telegram) | ✅ | ✅ | ✅ | ✅ | ✅ |
|
||||
| [Slack](https://docs.cowagent.ai/channels/slack) | ✅ | ✅ | ✅ | | ✅ |
|
||||
| [Discord](https://docs.cowagent.ai/channels/discord) | ✅ | ✅ | ✅ | | ✅ |
|
||||
| [WeChat](https://docs.cowagent.ai/channels/weixin) | ✅ | ✅ | ✅ | ✅ | |
|
||||
| [Feishu / Lark](https://docs.cowagent.ai/channels/feishu) | ✅ | ✅ | ✅ | ✅ | ✅ |
|
||||
| [DingTalk](https://docs.cowagent.ai/channels/dingtalk) | ✅ | ✅ | ✅ | ✅ | ✅ |
|
||||
| [WeCom Bot](https://docs.cowagent.ai/channels/wecom-bot) | ✅ | ✅ | ✅ | ✅ | ✅ |
|
||||
| [QQ](https://docs.cowagent.ai/channels/qq) | ✅ | ✅ | ✅ | | ✅ |
|
||||
| [WeCom App](https://docs.cowagent.ai/channels/wecom) | ✅ | ✅ | ✅ | ✅ | |
|
||||
| [WeChat Customer Service](https://docs.cowagent.ai/channels/wechat-kf) | ✅ | ✅ | ✅ | ✅ | |
|
||||
| [WeChat Official Account](https://docs.cowagent.ai/channels/wechatmp) | ✅ | ✅ | | ✅ | |
|
||||
|
||||
> See the [Channels overview](https://docs.cowagent.ai/channels/index) for setup details.
|
||||
|
||||
<img src="https://cdn.jsdelivr.net/gh/zhayujie/cowagent-assets@main/screenshots/en/web-console-chat.png" alt="CowAgent Web Console" width="800"/>
|
||||
|
||||
*The Web console is the default channel and the unified entry point to configure models, channels, skills, memory, and more.*
|
||||
|
||||
<br/>
|
||||
|
||||
## 🧠 Memory & Knowledge Base
|
||||
|
||||
**Long-term memory** uses a three-tier architecture: conversation context (short-term) → daily memory (mid-term) → MEMORY.md (long-term). A nightly **Deep Dream** pass distills scattered memories into refined long-term entries and a narrative journal. See [Long-term Memory](https://docs.cowagent.ai/memory/index) · [Deep Dream](https://docs.cowagent.ai/memory/deep-dream).
|
||||
|
||||
**Personal knowledge base** complements the time-ordered memory by organizing structured knowledge **by topic**. The Agent automatically curates valuable information from conversations, maintains cross-references and indexes, and the Web console offers an interactive knowledge-graph view. See [Personal Knowledge Base](https://docs.cowagent.ai/knowledge/index).
|
||||
|
||||
<table>
|
||||
<tr>
|
||||
<td width="50%">
|
||||
<img src="https://cdn.jsdelivr.net/gh/zhayujie/cowagent-assets@main/screenshots/en/web-console-memory.png" alt="Long-term Memory" />
|
||||
<p align="center"><em>Long-term Memory · Three-tier architecture + Deep Dream</em></p>
|
||||
</td>
|
||||
<td width="50%">
|
||||
<img src="https://cdn.jsdelivr.net/gh/zhayujie/cowagent-assets@main/screenshots/en/web-console-knowledge.png" alt="Personal Knowledge Base" />
|
||||
<p align="center"><em>Knowledge Base · Auto-curated Markdown wiki</em></p>
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
<br/>
|
||||
|
||||
## 🔧 Tools & Skills
|
||||
|
||||
**Tools** are atomic capabilities the Agent uses to interact with system resources. **Skills** are higher-level workflows defined by a manifest file that compose multiple tools to accomplish complex tasks.
|
||||
|
||||
### Tool System
|
||||
|
||||
**Built-in tools** cover file I/O (`read` / `write` / `edit` / `ls`), terminal (`bash`), file sending (`send`), memory retrieval (`memory`), environment variables (`env_config`), web fetching (`web_fetch`), scheduling (`scheduler`), web search (`web_search`), vision (`vision`), and browser automation (`browser`).
|
||||
|
||||
**MCP protocol** integrates the open ecosystem of [Model Context Protocol](https://modelcontextprotocol.io) servers. A single `mcp.json` is enough — supports stdio / SSE transports, hot reload, and zero-code integration.
|
||||
|
||||
Learn more: [Tools overview](https://docs.cowagent.ai/tools/index) · [MCP integration](https://docs.cowagent.ai/tools/mcp).
|
||||
|
||||
### Skills System
|
||||
|
||||
- **[Skill Hub](https://skills.cowagent.ai/)** — open skill marketplace: browse, search, install in one click
|
||||
- **GitHub / ClawHub / URL and more** — install skills from any source
|
||||
- **Conversational authoring** — generate custom skills through dialogue with `skill-creator`; turn any workflow or third-party API into a reusable skill
|
||||
|
||||
```bash
|
||||
cp config-template.json config.json
|
||||
/skill list # list installed skills
|
||||
/skill search <keyword> # search the marketplace
|
||||
/skill install <name> # one-click install
|
||||
```
|
||||
|
||||
然后在`config.json`中填入配置,以下是对默认配置的说明,可根据需要进行自定义修改(请去掉注释):
|
||||
Learn more: [Skills overview](https://docs.cowagent.ai/skills/index) · [Creating Skills](https://docs.cowagent.ai/skills/create).
|
||||
|
||||
```bash
|
||||
# config.json文件内容示例
|
||||
{
|
||||
"open_ai_api_key": "YOUR API KEY", # 填入上面创建的 OpenAI API KEY
|
||||
"model": "gpt-3.5-turbo", # 模型名称, 支持 gpt-3.5-turbo, gpt-3.5-turbo-16k, gpt-4, wenxin, xunfei
|
||||
"proxy": "", # 代理客户端的ip和端口,国内环境开启代理的需要填写该项,如 "127.0.0.1:7890"
|
||||
"single_chat_prefix": ["bot", "@bot"], # 私聊时文本需要包含该前缀才能触发机器人回复
|
||||
"single_chat_reply_prefix": "[bot] ", # 私聊时自动回复的前缀,用于区分真人
|
||||
"group_chat_prefix": ["@bot"], # 群聊时包含该前缀则会触发机器人回复
|
||||
"group_name_white_list": ["ChatGPT测试群", "ChatGPT测试群2"], # 开启自动回复的群名称列表
|
||||
"group_chat_in_one_session": ["ChatGPT测试群"], # 支持会话上下文共享的群名称
|
||||
"image_create_prefix": ["画", "看", "找"], # 开启图片回复的前缀
|
||||
"conversation_max_tokens": 1000, # 支持上下文记忆的最多字符数
|
||||
"speech_recognition": false, # 是否开启语音识别
|
||||
"group_speech_recognition": false, # 是否开启群组语音识别
|
||||
"use_azure_chatgpt": false, # 是否使用Azure ChatGPT service代替openai ChatGPT service. 当设置为true时需要设置 open_ai_api_base,如 https://xxx.openai.azure.com/
|
||||
"azure_deployment_id": "", # 采用Azure ChatGPT时,模型部署名称
|
||||
"azure_api_version": "", # 采用Azure ChatGPT时,API版本
|
||||
"character_desc": "你是ChatGPT, 一个由OpenAI训练的大型语言模型, 你旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。", # 人格描述
|
||||
# 订阅消息,公众号和企业微信channel中请填写,当被订阅时会自动回复,可使用特殊占位符。目前支持的占位符有{trigger_prefix},在程序中它会自动替换成bot的触发词。
|
||||
"subscribe_msg": "感谢您的关注!\n这里是ChatGPT,可以自由对话。\n支持语音对话。\n支持图片输出,画字开头的消息将按要求创作图片。\n支持角色扮演和文字冒险等丰富插件。\n输入{trigger_prefix}#help 查看详细指令。",
|
||||
"use_linkai": false, # 是否使用LinkAI接口,默认关闭,开启后可国内访问,使用知识库和MJ
|
||||
"linkai_api_key": "", # LinkAI Api Key
|
||||
"linkai_app_code": "" # LinkAI 应用code
|
||||
}
|
||||
```
|
||||
**配置说明:**
|
||||
<br/>
|
||||
|
||||
**1.个人聊天**
|
||||
## 🏷 Changelog
|
||||
|
||||
+ 个人聊天中,需要以 "bot"或"@bot" 为开头的内容触发机器人,对应配置项 `single_chat_prefix` (如果不需要以前缀触发可以填写 `"single_chat_prefix": [""]`)
|
||||
+ 机器人回复的内容会以 "[bot] " 作为前缀, 以区分真人,对应的配置项为 `single_chat_reply_prefix` (如果不需要前缀可以填写 `"single_chat_reply_prefix": ""`)
|
||||
> **2026.06.01:** [v2.1.0](https://github.com/zhayujie/CowAgent/releases/tag/2.1.0) — Internationalization, new channels (Telegram, Discord, Slack, WeChat Customer Service), CLI interaction upgrades, streamlined one-line install, MCP Streamable HTTP support, new models (claude-opus-4-8, MiMo).
|
||||
|
||||
**2.群组聊天**
|
||||
> **2026.05.22:** [v2.0.9](https://github.com/zhayujie/CowAgent/releases/tag/2.0.9) — Model management, MCP protocol support, persistent browser sessions, new models (gpt-5.5, gemini-3.5-flash, qwen3.7-max), deployment hardening.
|
||||
|
||||
+ 群组聊天中,群名称需配置在 `group_name_white_list` 中才能开启群聊自动回复。如果想对所有群聊生效,可以直接填写 `"group_name_white_list": ["ALL_GROUP"]`
|
||||
+ 默认只要被人 @ 就会触发机器人自动回复;另外群聊天中只要检测到以 "@bot" 开头的内容,同样会自动回复(方便自己触发),这对应配置项 `group_chat_prefix`
|
||||
+ 可选配置: `group_name_keyword_white_list`配置项支持模糊匹配群名称,`group_chat_keyword`配置项则支持模糊匹配群消息内容,用法与上述两个配置项相同。(Contributed by [evolay](https://github.com/evolay))
|
||||
+ `group_chat_in_one_session`:使群聊共享一个会话上下文,配置 `["ALL_GROUP"]` 则作用于所有群聊
|
||||
> **2026.05.06:** [v2.0.8](https://github.com/zhayujie/CowAgent/releases/tag/2.0.8) — Feishu channel overhaul (voice, streaming, QR onboarding), DeepSeek V4 and Baidu Qianfan support, scheduler tool upgrades.
|
||||
|
||||
**3.语音识别**
|
||||
> **2026.04.22:** [v2.0.7](https://github.com/zhayujie/CowAgent/releases/tag/2.0.7) — Built-in image generation (GPT Image 2, Nano Banana), new models (Kimi K2.6, Claude Opus 4.7, GLM 5.1), memory and knowledge enhancements.
|
||||
|
||||
+ 添加 `"speech_recognition": true` 将开启语音识别,默认使用openai的whisper模型识别为文字,同时以文字回复,该参数仅支持私聊 (注意由于语音消息无法匹配前缀,一旦开启将对所有语音自动回复,支持语音触发画图);
|
||||
+ 添加 `"group_speech_recognition": true` 将开启群组语音识别,默认使用openai的whisper模型识别为文字,同时以文字回复,参数仅支持群聊 (会匹配group_chat_prefix和group_chat_keyword, 支持语音触发画图);
|
||||
+ 添加 `"voice_reply_voice": true` 将开启语音回复语音(同时作用于私聊和群聊),但是需要配置对应语音合成平台的key,由于itchat协议的限制,只能发送语音mp3文件,若使用wechaty则回复的是微信语音。
|
||||
> **2026.04.14:** [v2.0.6](https://github.com/zhayujie/CowAgent/releases/tag/2.0.6) — Knowledge base, Deep Dream memory distillation, smart context compression, multi-session Web console.
|
||||
|
||||
**4.其他配置**
|
||||
> **2026.04.01:** [v2.0.5](https://github.com/zhayujie/CowAgent/releases/tag/2.0.5) — Cow CLI, Skill Hub open source, browser tool, WeCom Bot QR onboarding.
|
||||
|
||||
+ `model`: 模型名称,目前支持 `gpt-3.5-turbo`, `text-davinci-003`, `gpt-4`, `gpt-4-32k`, `wenxin` , `claude` , `xunfei`(其中gpt-4 api暂未完全开放,申请通过后可使用)
|
||||
+ `temperature`,`frequency_penalty`,`presence_penalty`: Chat API接口参数,详情参考[OpenAI官方文档。](https://platform.openai.com/docs/api-reference/chat)
|
||||
+ `proxy`:由于目前 `openai` 接口国内无法访问,需配置代理客户端的地址,详情参考 [#351](https://github.com/zhayujie/chatgpt-on-wechat/issues/351)
|
||||
+ 对于图像生成,在满足个人或群组触发条件外,还需要额外的关键词前缀来触发,对应配置 `image_create_prefix`
|
||||
+ 关于OpenAI对话及图片接口的参数配置(内容自由度、回复字数限制、图片大小等),可以参考 [对话接口](https://beta.openai.com/docs/api-reference/completions) 和 [图像接口](https://beta.openai.com/docs/api-reference/images) 文档,在[`config.py`](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/config.py)中检查哪些参数在本项目中是可配置的。
|
||||
+ `conversation_max_tokens`:表示能够记忆的上下文最大字数(一问一答为一组对话,如果累积的对话字数超出限制,就会优先移除最早的一组对话)
|
||||
+ `rate_limit_chatgpt`,`rate_limit_dalle`:每分钟最高问答速率、画图速率,超速后排队按序处理。
|
||||
+ `clear_memory_commands`: 对话内指令,主动清空前文记忆,字符串数组可自定义指令别名。
|
||||
+ `hot_reload`: 程序退出后,暂存微信扫码状态,默认关闭。
|
||||
+ `character_desc` 配置中保存着你对机器人说的一段话,他会记住这段话并作为他的设定,你可以为他定制任何人格 (关于会话上下文的更多内容参考该 [issue](https://github.com/zhayujie/chatgpt-on-wechat/issues/43))
|
||||
+ `subscribe_msg`:订阅消息,公众号和企业微信channel中请填写,当被订阅时会自动回复, 可使用特殊占位符。目前支持的占位符有{trigger_prefix},在程序中它会自动替换成bot的触发词。
|
||||
> **2026.02.03:** [v2.0.0](https://github.com/zhayujie/CowAgent/releases/tag/2.0.0) — Major upgrade to a super Agent assistant with multi-step task planning, long-term memory, and the Skills framework.
|
||||
|
||||
**5.LinkAI配置 (可选)**
|
||||
Full history: [Release Notes](https://docs.cowagent.ai/releases/overview)
|
||||
|
||||
+ `use_linkai`: 是否使用LinkAI接口,开启后可国内访问,使用知识库和 `Midjourney` 绘画, 参考 [文档](https://link-ai.tech/platform/link-app/wechat)
|
||||
+ `linkai_api_key`: LinkAI Api Key,可在 [控制台](https://link-ai.tech/console/interface) 创建
|
||||
+ `linkai_app_code`: LinkAI 应用code,选填
|
||||
<br/>
|
||||
|
||||
**本说明文档可能未及时更新,当前所有可选的配置项均在 [`config.py`](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/config.py) 中列出。**
|
||||
## 🤝 Community & Support
|
||||
|
||||
## 运行
|
||||
[File an issue](https://github.com/zhayujie/CowAgent/issues) on GitHub, or scan the QR code below to join our WeChat community:
|
||||
|
||||
### 1.本地运行
|
||||
<img width="130" src="https://img-1317903499.cos.ap-guangzhou.myqcloud.com/docs/open-community.png">
|
||||
|
||||
如果是开发机 **本地运行**,直接在项目根目录下执行:
|
||||
<br/>
|
||||
|
||||
```bash
|
||||
python3 app.py # windows环境下该命令通常为 python app.py
|
||||
```
|
||||
## 🔗 Related Projects
|
||||
|
||||
终端输出二维码后,使用微信进行扫码,当输出 "Start auto replying" 时表示自动回复程序已经成功运行了(注意:用于登录的微信需要在支付处已完成实名认证)。扫码登录后你的账号就成为机器人了,可以在微信手机端通过配置的关键词触发自动回复 (任意好友发送消息给你,或是自己发消息给好友),参考[#142](https://github.com/zhayujie/chatgpt-on-wechat/issues/142)。
|
||||
- **[Cow Skill Hub](https://github.com/zhayujie/cow-skill-hub)** — open skill marketplace for AI Agents; works with CowAgent, OpenClaw, Claude Code, and more
|
||||
- **[bot-on-anything](https://github.com/zhayujie/bot-on-anything)** — lightweight LLM application framework with integrations for Slack, Telegram, Discord, Gmail, and more
|
||||
- **[AgentMesh](https://github.com/MinimalFuture/AgentMesh)** — open-source multi-agent framework for solving complex problems through team collaboration
|
||||
|
||||
### 2.服务器部署
|
||||
<br/>
|
||||
|
||||
使用nohup命令在后台运行程序:
|
||||
## 🏢 Enterprise Services
|
||||
|
||||
```bash
|
||||
touch nohup.out # 首次运行需要新建日志文件
|
||||
nohup python3 app.py & tail -f nohup.out # 在后台运行程序并通过日志输出二维码
|
||||
```
|
||||
扫码登录后程序即可运行于服务器后台,此时可通过 `ctrl+c` 关闭日志,不会影响后台程序的运行。使用 `ps -ef | grep app.py | grep -v grep` 命令可查看运行于后台的进程,如果想要重新启动程序可以先 `kill` 掉对应的进程。日志关闭后如果想要再次打开只需输入 `tail -f nohup.out`。此外,`scripts` 目录下有一键运行、关闭程序的脚本供使用。
|
||||
[**LinkAI**](https://link-ai.tech/) is an all-in-one AI Agent platform for enterprises and developers, offering managed hosting and enterprise-grade support for CowAgent:
|
||||
|
||||
> **多账号支持:** 将项目复制多份,分别启动程序,用不同账号扫码登录即可实现同时运行。
|
||||
- **🚀 Zero-deployment hosted runtime** — spin up a [CowAgent online assistant](https://link-ai.tech/cowagent/create) in under a minute, no server required
|
||||
- **🧠 Agent infrastructure** — unified access to LLMs, knowledge bases, databases, skills, and workflows; plug-and-play building blocks that extend what CowAgent can do
|
||||
- **🏢 Team & enterprise features** — workspaces, role-based access, audit logs, and private deployment for production use cases
|
||||
|
||||
> **特殊指令:** 用户向机器人发送 **#reset** 即可清空该用户的上下文记忆。
|
||||
For enterprise inquiries: sales@simple-future.tech or [scan the QR code](https://cdn.link-ai.tech/consultant.jpg) to reach our team on WeChat.
|
||||
|
||||
<br/>
|
||||
|
||||
### 3.Docker部署
|
||||
## 🛠️ Development & Contributing
|
||||
|
||||
> 使用docker部署无需下载源码和安装依赖,只需要获取 docker-compose.yml 配置文件并启动容器即可。
|
||||
Contributions are welcome — add a new channel by following the [Feishu channel reference](https://github.com/zhayujie/CowAgent/blob/master/channel/feishu/feishu_channel.py), or contribute new skills to [Skill Hub](https://skills.cowagent.ai/submit).
|
||||
|
||||
> 前提是需要安装好 `docker` 及 `docker-compose`,安装成功的表现是执行 `docker -v` 和 `docker-compose version` (或 docker compose version) 可以查看到版本号,可前往 [docker官网](https://docs.docker.com/engine/install/) 进行下载。
|
||||
⭐ Star the project to follow updates, and feel free to open PRs and Issues.
|
||||
|
||||
#### (1) 下载 docker-compose.yml 文件
|
||||
## 🌟 Contributors
|
||||
|
||||
```bash
|
||||
wget https://open-1317903499.cos.ap-guangzhou.myqcloud.com/docker-compose.yml
|
||||
```
|
||||

|
||||
|
||||
下载完成后打开 `docker-compose.yml` 修改所需配置,如 `OPEN_AI_API_KEY` 和 `GROUP_NAME_WHITE_LIST` 等。
|
||||
<br/>
|
||||
|
||||
#### (2) 启动容器
|
||||
## ⚠️ Disclaimer
|
||||
|
||||
在 `docker-compose.yml` 所在目录下执行以下命令启动容器:
|
||||
1. This project is licensed under the [MIT License](/LICENSE) and is intended for technical research and learning. You are responsible for complying with applicable laws and regulations in your jurisdiction; the maintainers assume no liability for any consequences arising from use of this project.
|
||||
2. **Cost & safety:** Agent mode consumes substantially more tokens than regular chat — pick models that balance quality and cost. The Agent has access to your local operating system, so only deploy it in trusted environments.
|
||||
3. CowAgent is a pure open-source project and does not participate in, authorize, or issue any cryptocurrency.
|
||||
|
||||
```bash
|
||||
sudo docker compose up -d
|
||||
```
|
||||
<br/>
|
||||
|
||||
运行 `sudo docker ps` 能查看到 NAMES 为 chatgpt-on-wechat 的容器即表示运行成功。
|
||||
## 📌 Project Renaming Notice
|
||||
|
||||
注意:
|
||||
|
||||
- 如果 `docker-compose` 是 1.X 版本,则需要执行 `sudo docker-compose up -d` 来启动容器
|
||||
- 该命令会自动去 [docker hub](https://hub.docker.com/r/zhayujie/chatgpt-on-wechat) 拉取 latest 版本的镜像,latest 镜像会在每次项目 release 新的版本时生成
|
||||
|
||||
最后运行以下命令可查看容器运行日志,扫描日志中的二维码即可完成登录:
|
||||
|
||||
```bash
|
||||
sudo docker logs -f chatgpt-on-wechat
|
||||
```
|
||||
|
||||
#### (3) 插件使用
|
||||
|
||||
如果需要在docker容器中修改插件配置,可通过挂载的方式完成,将 [插件配置文件](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/plugins/config.json.template)
|
||||
重命名为 `config.json`,放置于 `docker-compose.yml` 相同目录下,并在 `docker-compose.yml` 中的 `chatgpt-on-wechat` 部分下添加 `volumes` 映射:
|
||||
|
||||
```
|
||||
volumes:
|
||||
- ./config.json:/app/plugins/config.json
|
||||
```
|
||||
|
||||
### 4. Railway部署
|
||||
|
||||
> Railway 每月提供5刀和最多500小时的免费额度。 (07.11更新: 目前大部分账号已无法免费部署)
|
||||
|
||||
1. 进入 [Railway](https://railway.app/template/qApznZ?referralCode=RC3znh)
|
||||
2. 点击 `Deploy Now` 按钮。
|
||||
3. 设置环境变量来重载程序运行的参数,例如`open_ai_api_key`, `character_desc`。
|
||||
|
||||
**一键部署:**
|
||||
|
||||
[](https://railway.app/template/qApznZ?referralCode=RC3znh)
|
||||
|
||||
## 常见问题
|
||||
|
||||
FAQs: <https://github.com/zhayujie/chatgpt-on-wechat/wiki/FAQs>
|
||||
|
||||
或直接在线咨询 [项目小助手](https://link-ai.tech/app/Kv2fXJcH) (beta版本,语料完善中,回复仅供参考)
|
||||
|
||||
## 联系
|
||||
|
||||
欢迎提交PR、Issues,以及Star支持一下。程序运行遇到问题可以查看 [常见问题列表](https://github.com/zhayujie/chatgpt-on-wechat/wiki/FAQs) ,其次前往 [Issues](https://github.com/zhayujie/chatgpt-on-wechat/issues) 中搜索。参与更多讨论可加入技术交流群。
|
||||
This project was previously named `chatgpt-on-wechat` and is now officially **CowAgent**. The old GitHub URL redirects automatically; existing users may optionally run `git remote set-url origin https://github.com/zhayujie/CowAgent.git` to update the local remote.
|
||||
|
||||
3
agent/chat/__init__.py
Normal file
3
agent/chat/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from agent.chat.service import ChatService
|
||||
|
||||
__all__ = ["ChatService"]
|
||||
290
agent/chat/service.py
Normal file
290
agent/chat/service.py
Normal file
@@ -0,0 +1,290 @@
|
||||
"""
|
||||
ChatService - Wraps the Agent stream execution to produce CHAT protocol chunks.
|
||||
|
||||
Translates agent events (message_update, message_end, tool_execution_end, etc.)
|
||||
into the CHAT socket protocol format (content chunks with segment_id, tool_calls chunks).
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import Callable, Optional
|
||||
|
||||
from common.log import logger
|
||||
|
||||
|
||||
class ChatService:
    """
    High-level service that runs an Agent for a given query and streams
    the results as CHAT protocol chunks via a callback.

    Agent events (``reasoning_update``, ``message_update``, ``message_end``,
    ``file_to_send``, ``tool_execution_start``, ``tool_execution_end`` and
    ``turn_end``) are translated into chunk dicts and passed to the callback.

    Usage:
        svc = ChatService(agent_bridge)
        svc.run(query, session_id, send_chunk_fn)
    """

    def __init__(self, agent_bridge):
        """
        :param agent_bridge: AgentBridge instance (manages agent lifecycle)
        """
        self.agent_bridge = agent_bridge

    def run(self, query: str, session_id: str, send_chunk_fn: Callable[[dict], None],
            channel_type: str = ""):
        """
        Run the agent for *query* and stream results back via *send_chunk_fn*.

        The method blocks until the agent finishes. After it returns the SDK
        will automatically send the final (streaming=false) message.

        :param query: user query text
        :param session_id: session identifier for agent isolation
        :param send_chunk_fn: callable(chunk_data: dict) to send a streaming chunk
        :param channel_type: source channel (e.g. "web", "feishu") for persistence
        :raises RuntimeError: if the bridge returns no agent for the session
        """
        agent = self.agent_bridge.get_agent(session_id=session_id)
        if agent is None:
            raise RuntimeError("Failed to initialise agent for the session")

        # Pass context metadata to model for downstream API requests
        if hasattr(agent, 'model'):
            agent.model.channel_type = channel_type or ""
            agent.model.session_id = session_id or ""

        # Mutable state owned by this run and updated by the event callback
        state = _StreamState()
        on_event = self._build_event_handler(state, send_chunk_fn)

        # Run the agent with our event callback ---------------------------
        logger.info(f"[ChatService] Starting agent run: session={session_id}, query={query[:80]}")

        from config import conf
        max_context_turns = conf().get("agent_max_context_turns", 20)

        # Get full system prompt with skills
        full_system_prompt = agent.get_full_system_prompt()

        # Work on a copy of the history; remember how long it was so the
        # messages produced by this run can be identified afterwards.
        with agent.messages_lock:
            messages_copy = agent.messages.copy()
            original_length = len(agent.messages)

        from agent.protocol.agent_stream import AgentStreamExecutor

        executor = AgentStreamExecutor(
            agent=agent,
            model=agent.model,
            system_prompt=full_system_prompt,
            tools=agent.tools,
            max_turns=agent.max_steps,
            on_event=on_event,
            messages=messages_copy,
            max_context_turns=max_context_turns,
        )

        try:
            executor.run_stream(query)
        except Exception:
            # If executor cleared messages (context overflow), sync back
            if len(executor.messages) == 0:
                with agent.messages_lock:
                    agent.messages.clear()
                    logger.info("[ChatService] Cleared agent message history after executor recovery")
            raise

        # Sync executor messages back to agent (thread-safe). The agent's
        # history is replaced wholesale: if the executor trimmed context,
        # merely appending would leave stale pre-trim messages behind and the
        # same trim would fire on every subsequent request.
        with agent.messages_lock:
            new_messages = self._extract_new_messages(executor.messages, original_length)
            agent.messages = list(executor.messages)

        # Persist new messages to SQLite so they survive restarts and
        # can be queried via the HISTORY interface.
        if new_messages:
            self._persist_messages(session_id, new_messages, channel_type)

        # Store executor reference for files_to_send access
        agent.stream_executor = executor

        # Execute post-process tools
        agent._execute_post_process_tools()

        logger.info(f"[ChatService] Agent run completed: session={session_id}")

    @staticmethod
    def _build_event_handler(state: "_StreamState", send_chunk_fn: Callable[[dict], None]):
        """Return the callback that turns agent events into CHAT protocol chunks."""

        def on_event(event: dict):
            """Translate agent events into CHAT protocol chunks."""
            event_type = event.get("type")
            # ``data`` may be present but null; treat that like a missing payload.
            data = event.get("data") or {}

            if event_type == "reasoning_update":
                delta = data.get("delta", "")
                if delta:
                    send_chunk_fn({
                        "chunk_type": "reasoning",
                        "delta": delta,
                        "segment_id": state.segment_id,
                    })

            elif event_type == "message_update":
                # Incremental text delta
                delta = data.get("delta", "")
                if delta:
                    send_chunk_fn({
                        "chunk_type": "content",
                        "delta": delta,
                        "segment_id": state.segment_id,
                    })

            elif event_type == "message_end":
                # A content segment finished. When it requested tool calls,
                # start collecting their results until turn_end.
                if data.get("tool_calls", []):
                    state.pending_tool_results = []

            elif event_type == "file_to_send":
                url = data.get("url") or ""
                if url:
                    fname = data.get("file_name") or "file"
                    ft = data.get("file_type") or "file"
                    # Images are embedded, everything else becomes a plain link.
                    if ft == "image":
                        link = f""
                    else:
                        link = f"[{fname}]({url})"
                    send_chunk_fn({
                        "chunk_type": "content",
                        "delta": "\n\n" + link + "\n\n",
                        "segment_id": state.segment_id,
                    })
                    # Remove url so the model won't repeat it in its reply
                    data.pop("url", None)

            elif event_type == "tool_execution_start":
                # Notify the client that a tool is about to run (with its input args)
                tool_name = data.get("tool_name", "")
                arguments = data.get("arguments", {})
                # Cache arguments keyed by tool_call_id (tool name as fallback)
                # so tool_execution_end can include them
                tool_call_id = data.get("tool_call_id", tool_name)
                state.pending_tool_arguments[tool_call_id] = arguments
                send_chunk_fn({
                    "chunk_type": "tool_start",
                    "tool": tool_name,
                    "arguments": arguments,
                })

            elif event_type == "tool_execution_end":
                tool_name = data.get("tool_name", "")
                tool_call_id = data.get("tool_call_id", tool_name)
                # Retrieve cached arguments from the matching tool_execution_start event
                arguments = state.pending_tool_arguments.pop(tool_call_id, data.get("arguments", {}))
                result = data.get("result", "")

                # Serialise result to string if needed
                if not isinstance(result, str):
                    import json
                    try:
                        result = json.dumps(result, ensure_ascii=False)
                    except Exception:
                        result = str(result)

                tool_info = {
                    "name": tool_name,
                    "arguments": arguments,
                    "result": result,
                    "status": data.get("status", "unknown"),
                    "elapsed": ChatService._format_elapsed(data.get("execution_time", 0)),
                }

                # Results are only collected while a tool phase is open.
                if state.pending_tool_results is not None:
                    state.pending_tool_results.append(tool_info)

            elif event_type == "turn_end":
                has_tool_calls = data.get("has_tool_calls", False)
                # NOTE(review): the source this was reconstructed from had its
                # indentation stripped; the reset and the segment bump are
                # assumed to belong to the flush branch — confirm.
                if has_tool_calls and state.pending_tool_results:
                    # Flush collected tool results as a single tool_calls chunk
                    send_chunk_fn({
                        "chunk_type": "tool_calls",
                        "tool_calls": state.pending_tool_results,
                    })
                    state.pending_tool_results = None
                    # Next content belongs to a new segment
                    state.segment_id += 1

        return on_event

    @staticmethod
    def _format_elapsed(execution_time) -> str:
        """Format a duration as ``"<seconds>s"`` with two decimals; bad values become 0."""
        try:
            return f"{float(execution_time):.2f}s"
        except (TypeError, ValueError):
            # A missing/null/non-numeric duration must not abort the stream.
            return "0.00s"

    @staticmethod
    def _is_user_query(msg: dict) -> bool:
        """Return True for a user message that carries text and no tool results."""
        if msg.get("role") != "user":
            return False
        content = msg.get("content", [])
        if isinstance(content, str):
            return True
        if isinstance(content, list):
            has_text = any(
                isinstance(b, dict) and b.get("type") == "text" for b in content
            )
            has_tool_result = any(
                isinstance(b, dict) and b.get("type") == "tool_result" for b in content
            )
            return has_text and not has_tool_result
        return False

    @staticmethod
    def _extract_new_messages(executor_messages: list, original_length: int) -> list:
        """
        Return the messages the executor added during this run.

        :param executor_messages: the executor's message list after the run
        :param original_length: length of the history before the run
        :return: a new list holding only the messages produced by this run
        """
        if len(executor_messages) >= original_length:
            # NOTE(review): a run that trims and then appends enough messages
            # to reach original_length again is indistinguishable from an
            # untrimmed run here — confirm this cannot happen.
            return list(executor_messages[original_length:])

        # Context was trimmed, so slicing at original_length would be past the
        # end of the list. The new user query is appended before trimming and
        # starts the last turn, so the last plain user message marks where the
        # new messages (user + assistant + tools) begin.
        for idx in range(len(executor_messages) - 1, -1, -1):
            if ChatService._is_user_query(executor_messages[idx]):
                return list(executor_messages[idx:])
        # No user query found: nothing can be attributed to this run.
        return []

    @staticmethod
    def _persist_messages(session_id: str, new_messages: list, channel_type: str = ""):
        """
        Append *new_messages* to the conversation store (best effort).

        Skipped when the ``conversation_persistence`` config flag is false.
        Storage failures are logged as warnings and never raised.
        """
        try:
            from config import conf
            if not conf().get("conversation_persistence", True):
                return
        except Exception:
            # Config unavailable: fall through and persist, matching the
            # flag's default of True.
            pass
        try:
            from agent.memory import get_conversation_store
            get_conversation_store().append_messages(
                session_id, new_messages, channel_type=channel_type
            )
        except Exception as e:
            logger.warning(
                f"[ChatService] Failed to persist messages for session={session_id}: {e}"
            )
|
||||
|
||||
|
||||
class _StreamState:
|
||||
"""Mutable state shared between the event callback and the run method."""
|
||||
|
||||
def __init__(self):
|
||||
self.segment_id: int = 0
|
||||
# None means we are not accumulating tool results right now.
|
||||
# A list means we are in the middle of a tool-execution phase.
|
||||
self.pending_tool_results: Optional[list] = None
|
||||
# Maps tool_call_id -> arguments captured from tool_execution_start,
|
||||
# so that tool_execution_end can attach the correct input args.
|
||||
self.pending_tool_arguments: dict = {}
|
||||
241
agent/chat/session_service.py
Normal file
241
agent/chat/session_service.py
Normal file
@@ -0,0 +1,241 @@
|
||||
"""
|
||||
SessionService - Manages multi-session lifecycle for both web channel and cloud client.
|
||||
|
||||
Provides a unified interface for listing, deleting, renaming, clearing context,
|
||||
and generating AI titles for conversation sessions. Backed by ConversationStore
|
||||
(SQLite) and AgentBridge (in-memory agent instances).
|
||||
"""
|
||||
|
||||
import re
|
||||
from typing import Optional
|
||||
|
||||
from common.log import logger
|
||||
|
||||
|
||||
def _truncate_fallback_title(user_message: str, max_len: int = 30) -> str:
|
||||
"""Pick the first non-empty line of the user message and truncate it."""
|
||||
if not user_message:
|
||||
return "New Chat"
|
||||
first_line = ""
|
||||
for line in user_message.splitlines():
|
||||
line = line.strip()
|
||||
if line:
|
||||
first_line = line
|
||||
break
|
||||
if not first_line:
|
||||
return "New Chat"
|
||||
if len(first_line) > max_len:
|
||||
first_line = first_line[:max_len].rstrip() + "..."
|
||||
return first_line
|
||||
|
||||
|
||||
def generate_session_title(user_message: str, assistant_reply: str = "") -> str:
    """
    Ask the current chat bot for a short title describing the conversation.

    The first 300 characters of the user message (and of the assistant reply,
    when given) are sent with a fixed instruction. The fallback title derived
    from the user message is returned instead when the call raises, when the
    bot reports no completion tokens, or when the cleaned title is empty or
    longer than 50 characters.

    :param user_message: the user's message
    :param assistant_reply: optional assistant reply to include in the prompt
    :return: the generated title, or the fallback title
    """
    fallback = _truncate_fallback_title(user_message)
    try:
        from bridge.bridge import Bridge
        from models.session_manager import Session
        bot = Bridge().get_bot("chat")

        transcript = f"User: {user_message[:300]}"
        if assistant_reply:
            transcript = transcript + "\n" + f"Assistant: {assistant_reply[:300]}"

        instruction = (
            "Generate a very short title (max 15 characters for Chinese, max 6 words for English) "
            "summarizing this conversation. Return ONLY the title text, nothing else.\n\n"
        )
        session = Session("__title_gen__", system_prompt="")
        session.messages = [{"role": "user", "content": instruction + transcript}]

        result = bot.reply_text(session) or {}
        # Failed bot calls (network, auth, rate limit, ...) typically come
        # back with completion_tokens=0 and a canned apology as content, so a
        # non-positive token count is treated as failure.
        completion_tokens = result.get("completion_tokens", 0) or 0
        raw = (result.get("content") or "").strip()
        if completion_tokens <= 0:
            logger.warning(
                f"[SessionService] Title generation got empty completion "
                f"(completion_tokens={completion_tokens}, content='{raw[:50]}'), "
                f"using fallback")
            return fallback

        # Drop any <think>...</think> section, then surrounding whitespace and quotes.
        without_thinking = re.sub(r'<think>.*?</think>', '', raw, flags=re.DOTALL)
        title = without_thinking.strip().strip('"\'')
        logger.info(f"[SessionService] Title generation result: '{title}' (len={len(title)})")
        if 0 < len(title) <= 50:
            return title
    except Exception as e:
        logger.warning(f"[SessionService] Title generation failed: {e}")
    return fallback
|
||||
|
||||
|
||||
class SessionService:
    """
    High-level service for session lifecycle management.

    Usage:
        svc = SessionService()
        result = svc.dispatch("list_sessions", {"channel_type": "web", "page": 1})
    """

    def _get_store(self):
        """Return the conversation store (imported lazily)."""
        from agent.memory import get_conversation_store
        return get_conversation_store()

    def _remove_agent(self, session_id: str):
        """Remove the in-memory Agent instance for a session if it exists (best effort)."""
        try:
            from bridge.bridge import Bridge
            ab = Bridge().get_agent_bridge()
            if session_id in ab.agents:
                del ab.agents[session_id]
                logger.info(f"[SessionService] Removed agent instance: {session_id}")
        except Exception as e:
            # Never fail the caller over cache cleanup, but leave a trace.
            logger.debug(f"[SessionService] Could not remove agent instance {session_id}: {e}")

    @staticmethod
    def _normalize_sid(session_id: str) -> str:
        """Prefix a non-empty *session_id* with ``session_`` unless already prefixed."""
        if session_id and not session_id.startswith("session_"):
            return f"session_{session_id}"
        return session_id

    @staticmethod
    def _int_or_default(value, default: int) -> int:
        """
        Convert a payload value to int; ``None`` selects *default*.

        :raises ValueError: if the value cannot be converted
        """
        if value is None:
            return default
        try:
            return int(value)
        except (TypeError, ValueError):
            raise ValueError(f"invalid integer value: {value!r}") from None

    # ------------------------------------------------------------------
    # actions
    # ------------------------------------------------------------------
    def list_sessions(self, channel_type: Optional[str] = None,
                      page: int = 1, page_size: int = 50) -> dict:
        """Return the store's session listing for the given channel and page."""
        store = self._get_store()
        return store.list_sessions(
            channel_type=channel_type,
            page=page,
            page_size=page_size,
        )

    def delete_session(self, session_id: str) -> None:
        """
        Clear a session from the store and drop its in-memory agent.

        :raises ValueError: if *session_id* is empty
        """
        if not session_id:
            raise ValueError("session_id required")
        session_id = self._normalize_sid(session_id)

        store = self._get_store()
        store.clear_session(session_id)
        self._remove_agent(session_id)
        logger.info(f"[SessionService] Session deleted: {session_id}")

    def rename_session(self, session_id: str, title: str) -> None:
        """
        Set a session's title.

        :raises ValueError: if an argument is empty or the store reports the
            session was not found
        """
        if not session_id:
            raise ValueError("session_id required")
        if not title:
            raise ValueError("title required")
        session_id = self._normalize_sid(session_id)

        store = self._get_store()
        found = store.rename_session(session_id, title)
        if not found:
            raise ValueError("session not found")

    def clear_context(self, session_id: str) -> int:
        """
        Set context boundary. Returns the new context_start_seq value.

        The in-memory agent for the session is dropped as well.

        :raises ValueError: if *session_id* is empty
        """
        if not session_id:
            raise ValueError("session_id required")
        session_id = self._normalize_sid(session_id)

        store = self._get_store()
        new_seq = store.clear_context(session_id)
        self._remove_agent(session_id)
        return new_seq

    def gen_title(self, session_id: str, user_message: str,
                  assistant_reply: str = "") -> str:
        """
        Generate an AI title and persist it. Returns the generated title.

        The title is returned even when the store reports that no session row
        was updated; that outcome is only logged.

        :raises ValueError: if *session_id* or *user_message* is empty
        """
        if not session_id:
            raise ValueError("session_id required")
        if not user_message:
            raise ValueError("user_message required")
        session_id = self._normalize_sid(session_id)

        title = generate_session_title(user_message, assistant_reply)

        store = self._get_store()
        updated = store.rename_session(session_id, title)
        logger.info(f"[SessionService] Title set: sid={session_id}, "
                    f"title='{title}', db_updated={updated}")
        return title

    # ------------------------------------------------------------------
    # dispatch — single entry point for protocol messages
    # ------------------------------------------------------------------
    def dispatch(self, action: str, payload: Optional[dict] = None) -> dict:
        """
        Dispatch a session management action and return a protocol-compatible
        response dict.

        Action names use a ``*_session`` / session-prefixed convention so they
        can coexist with history actions (e.g. ``query``) on the same HISTORY
        message channel without ambiguity.

        Supported actions:
          - list_sessions: list sessions with pagination
          - delete_session: delete a session
          - rename_session: rename a session title
          - clear_context: set context boundary
          - generate_title: AI-generate a session title

        Invalid input (``ValueError``) and unknown actions yield code 400; any
        other exception yields code 500. Nothing is raised to the caller.

        :param action: one of the above action names
        :param payload: action-specific payload
        :return: dict with action, code, message, payload
        """
        payload = payload or {}
        try:
            if action == "list_sessions":
                result = self.list_sessions(
                    channel_type=payload.get("channel_type"),
                    # Explicit nulls fall back to the defaults instead of
                    # surfacing as a 500.
                    page=self._int_or_default(payload.get("page"), 1),
                    page_size=self._int_or_default(payload.get("page_size"), 50),
                )
                return {"action": action, "code": 200, "message": "success", "payload": result}

            elif action == "delete_session":
                self.delete_session(payload.get("session_id", ""))
                return {"action": action, "code": 200, "message": "success", "payload": None}

            elif action == "rename_session":
                # A null or non-string title is treated as missing so the
                # caller gets a 400 "title required" rather than a 500.
                raw_title = payload.get("title")
                title = raw_title.strip() if isinstance(raw_title, str) else ""
                self.rename_session(payload.get("session_id", ""), title)
                return {"action": action, "code": 200, "message": "success", "payload": None}

            elif action == "clear_context":
                new_seq = self.clear_context(payload.get("session_id", ""))
                return {"action": action, "code": 200, "message": "success",
                        "payload": {"context_start_seq": new_seq}}

            elif action == "generate_title":
                title = self.gen_title(
                    payload.get("session_id", ""),
                    payload.get("user_message", ""),
                    payload.get("assistant_reply", ""),
                )
                return {"action": action, "code": 200, "message": "success",
                        "payload": {"title": title}}

            else:
                return {"action": action, "code": 400,
                        "message": f"unknown action: {action}", "payload": None}

        except ValueError as e:
            return {"action": action, "code": 400, "message": str(e), "payload": None}
        except Exception as e:
            logger.error(f"[SessionService] dispatch error: action={action}, error={e}")
            return {"action": action, "code": 500, "message": str(e), "payload": None}
|
||||
0
agent/knowledge/__init__.py
Normal file
0
agent/knowledge/__init__.py
Normal file
240
agent/knowledge/service.py
Normal file
240
agent/knowledge/service.py
Normal file
@@ -0,0 +1,240 @@
|
||||
"""
|
||||
Knowledge service for handling knowledge base operations.
|
||||
|
||||
Provides a unified interface for listing, reading, and graphing knowledge files,
|
||||
callable from the web console, API, or CLI.
|
||||
|
||||
Knowledge file layout (under workspace_root):
|
||||
knowledge/index.md
|
||||
knowledge/log.md
|
||||
knowledge/<category>/<slug>.md
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from common.log import logger
|
||||
from config import conf
|
||||
|
||||
|
||||
class KnowledgeService:
|
||||
"""
|
||||
High-level service for knowledge base queries.
|
||||
Operates directly on the filesystem.
|
||||
"""
|
||||
|
||||
def __init__(self, workspace_root: str):
|
||||
self.workspace_root = workspace_root
|
||||
self.knowledge_dir = os.path.join(workspace_root, "knowledge")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# list — directory tree with stats
|
||||
# ------------------------------------------------------------------
|
||||
def list_tree(self) -> dict:
|
||||
"""
|
||||
Return the knowledge directory tree grouped by category,
|
||||
supporting arbitrarily nested sub-directories.
|
||||
|
||||
Returns::
|
||||
|
||||
{
|
||||
"tree": [
|
||||
{
|
||||
"dir": "concepts",
|
||||
"files": [
|
||||
{"name": "moe.md", "title": "MoE", "size": 1234},
|
||||
],
|
||||
"children": []
|
||||
},
|
||||
{
|
||||
"dir": "platform",
|
||||
"files": [],
|
||||
"children": [
|
||||
{
|
||||
"dir": "analysis",
|
||||
"files": [{"name": "perf.md", ...}],
|
||||
"children": []
|
||||
}
|
||||
]
|
||||
},
|
||||
],
|
||||
"stats": {"pages": 15, "size": 32768},
|
||||
"enabled": true
|
||||
}
|
||||
"""
|
||||
if not os.path.isdir(self.knowledge_dir):
|
||||
return {"tree": [], "stats": {"pages": 0, "size": 0}, "enabled": conf().get("knowledge", True)}
|
||||
|
||||
stats = {"pages": 0, "size": 0}
|
||||
root_files, tree = self._scan_dir(self.knowledge_dir, stats, is_root=True)
|
||||
|
||||
return {
|
||||
"root_files": root_files,
|
||||
"tree": tree,
|
||||
"stats": stats,
|
||||
"enabled": conf().get("knowledge", True),
|
||||
}
|
||||
|
||||
def _scan_dir(self, dir_path: str, stats: dict, is_root: bool = False) -> tuple:
|
||||
"""
|
||||
Recursively scan a directory.
|
||||
|
||||
:return: (files, children) where files is a list of .md file dicts
|
||||
in this directory and children is a list of sub-directory nodes.
|
||||
"""
|
||||
files = []
|
||||
children = []
|
||||
for name in sorted(os.listdir(dir_path)):
|
||||
if name.startswith("."):
|
||||
continue
|
||||
full = os.path.join(dir_path, name)
|
||||
if os.path.isdir(full):
|
||||
sub_files, sub_children = self._scan_dir(full, stats)
|
||||
children.append({"dir": name, "files": sub_files, "children": sub_children})
|
||||
elif name.endswith(".md"):
|
||||
size = os.path.getsize(full)
|
||||
if not is_root:
|
||||
stats["pages"] += 1
|
||||
stats["size"] += size
|
||||
title = name.replace(".md", "")
|
||||
try:
|
||||
with open(full, "r", encoding="utf-8") as f:
|
||||
first_line = f.readline().strip()
|
||||
if first_line.startswith("# "):
|
||||
title = first_line[2:].strip()
|
||||
except Exception:
|
||||
pass
|
||||
files.append({"name": name, "title": title, "size": size})
|
||||
return files, children
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# read — single file content
|
||||
# ------------------------------------------------------------------
|
||||
def read_file(self, rel_path: str) -> dict:
|
||||
"""
|
||||
Read a single knowledge markdown file.
|
||||
|
||||
:param rel_path: Relative path within knowledge/, e.g. ``concepts/moe.md``
|
||||
:return: dict with ``content`` and ``path``
|
||||
:raises ValueError: if path is invalid or escapes knowledge dir
|
||||
:raises FileNotFoundError: if file does not exist
|
||||
"""
|
||||
if not rel_path or ".." in rel_path:
|
||||
raise ValueError("invalid path")
|
||||
|
||||
full_path = os.path.normpath(os.path.join(self.knowledge_dir, rel_path))
|
||||
allowed = os.path.normpath(self.knowledge_dir)
|
||||
if not full_path.startswith(allowed + os.sep) and full_path != allowed:
|
||||
raise ValueError("path outside knowledge dir")
|
||||
|
||||
if not os.path.isfile(full_path):
|
||||
raise FileNotFoundError(f"file not found: {rel_path}")
|
||||
|
||||
with open(full_path, "r", encoding="utf-8") as f:
|
||||
content = f.read()
|
||||
return {"content": content, "path": rel_path}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# graph — nodes and links for visualization
|
||||
# ------------------------------------------------------------------
|
||||
def build_graph(self) -> dict:
|
||||
"""
|
||||
Parse all knowledge pages and extract cross-reference links.
|
||||
|
||||
Returns::
|
||||
|
||||
{
|
||||
"nodes": [
|
||||
{"id": "concepts/moe.md", "label": "MoE", "category": "concepts"},
|
||||
...
|
||||
],
|
||||
"links": [
|
||||
{"source": "concepts/moe.md", "target": "entities/deepseek.md"},
|
||||
...
|
||||
]
|
||||
}
|
||||
"""
|
||||
knowledge_path = Path(self.knowledge_dir)
|
||||
if not knowledge_path.is_dir():
|
||||
return {"nodes": [], "links": []}
|
||||
|
||||
nodes = {}
|
||||
links = []
|
||||
link_re = re.compile(r'\[([^\]]*)\]\(([^)]+\.md)\)')
|
||||
|
||||
for md_file in knowledge_path.rglob("*.md"):
|
||||
rel = str(md_file.relative_to(knowledge_path))
|
||||
if rel in ("index.md", "log.md"):
|
||||
continue
|
||||
parts = rel.split("/")
|
||||
category = parts[0] if len(parts) > 1 else "root"
|
||||
title = md_file.stem.replace("-", " ").title()
|
||||
try:
|
||||
content = md_file.read_text(encoding="utf-8")
|
||||
first_line = content.strip().split("\n")[0]
|
||||
if first_line.startswith("# "):
|
||||
title = first_line[2:].strip()
|
||||
for _, link_target in link_re.findall(content):
|
||||
resolved = (md_file.parent / link_target).resolve()
|
||||
try:
|
||||
target_rel = str(resolved.relative_to(knowledge_path))
|
||||
except ValueError:
|
||||
continue
|
||||
if target_rel != rel:
|
||||
links.append({"source": rel, "target": target_rel})
|
||||
except Exception:
|
||||
pass
|
||||
nodes[rel] = {"id": rel, "label": title, "category": category}
|
||||
|
||||
valid_ids = set(nodes.keys())
|
||||
links = [l for l in links if l["source"] in valid_ids and l["target"] in valid_ids]
|
||||
seen = set()
|
||||
deduped = []
|
||||
for l in links:
|
||||
key = tuple(sorted([l["source"], l["target"]]))
|
||||
if key not in seen:
|
||||
seen.add(key)
|
||||
deduped.append(l)
|
||||
|
||||
return {"nodes": list(nodes.values()), "links": deduped}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# dispatch — single entry point for protocol messages
|
||||
# ------------------------------------------------------------------
|
||||
def dispatch(self, action: str, payload: Optional[dict] = None) -> dict:
    """
    Route a knowledge-management protocol message to its handler.

    Every outcome, including failures, is reported as a dict with the keys
    ``action``, ``code``, ``message`` and ``payload``; this method does not
    let handler exceptions propagate.

    :param action: ``list``, ``read``, or ``graph``; anything else yields a 400
    :param payload: action-specific payload (``read`` needs a ``path`` entry)
    :return: protocol-compatible response dict
    """
    def reply(code, message, data=None):
        # Single place that fixes the response shape and key order.
        return {"action": action, "code": code, "message": message, "payload": data}

    payload = payload or {}
    try:
        if action == "list":
            return reply(200, "success", self.list_tree())

        if action == "read":
            path = payload.get("path")
            if not path:
                return reply(400, "path is required")
            return reply(200, "success", self.read_file(path))

        if action == "graph":
            return reply(200, "success", self.build_graph())

        return reply(400, f"unknown action: {action}")

    except ValueError as e:
        # A ValueError raised by a handler is reported as 403.
        return reply(403, str(e))
    except FileNotFoundError as e:
        return reply(404, str(e))
    except Exception as e:
        # Last-resort boundary: log and report as an internal error.
        logger.error(f"[KnowledgeService] dispatch error: action={action}, error={e}")
        return reply(500, str(e))
|
||||
23
agent/memory/__init__.py
Normal file
23
agent/memory/__init__.py
Normal file
@@ -0,0 +1,23 @@
|
||||
"""
|
||||
Memory module for AgentMesh
|
||||
|
||||
Provides both long-term memory (vector/keyword search) and short-term
|
||||
conversation history persistence (SQLite).
|
||||
"""
|
||||
|
||||
from agent.memory.manager import MemoryManager
|
||||
from agent.memory.config import MemoryConfig, get_default_memory_config, set_global_memory_config
|
||||
from agent.memory.embedding import create_embedding_provider
|
||||
from agent.memory.conversation_store import ConversationStore, get_conversation_store
|
||||
from agent.memory.summarizer import ensure_daily_memory_file
|
||||
|
||||
__all__ = [
|
||||
'MemoryManager',
|
||||
'MemoryConfig',
|
||||
'get_default_memory_config',
|
||||
'set_global_memory_config',
|
||||
'create_embedding_provider',
|
||||
'ConversationStore',
|
||||
'get_conversation_store',
|
||||
'ensure_daily_memory_file',
|
||||
]
|
||||
140
agent/memory/chunker.py
Normal file
140
agent/memory/chunker.py
Normal file
@@ -0,0 +1,140 @@
|
||||
"""
|
||||
Text chunking utilities for memory
|
||||
|
||||
Splits text into chunks with token limits and overlap
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
from typing import List, Tuple
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
class TextChunk:
    """A piece of the source text together with the 1-based line span it covers"""
    text: str
    start_line: int
    end_line: int


class TextChunker:
    """Chunks text on line boundaries using a character-based token estimate"""

    def __init__(self, max_tokens: int = 500, overlap_tokens: int = 50):
        """
        Initialize chunker

        Args:
            max_tokens: Maximum tokens per chunk
            overlap_tokens: Overlap tokens between chunks
        """
        self.max_tokens = max_tokens
        self.overlap_tokens = overlap_tokens
        # Rough estimation: ~4 chars per token for English/Chinese mixed
        self.chars_per_token = 4

    def chunk_text(self, text: str) -> List[TextChunk]:
        """
        Chunk text into overlapping, line-aligned segments

        Lines are accumulated until the next one would push the chunk past
        the character budget (``max_tokens * chars_per_token``; newline
        separators are not counted). A line that alone exceeds the budget is
        emitted as fixed-width pieces that all carry that line's number.
        Consecutive regular chunks share trailing lines as overlap, but the
        overlap is trimmed so the new chunk never exceeds the budget.

        Args:
            text: Input text to chunk

        Returns:
            List of TextChunk objects (empty for blank input)
        """
        if not text.strip():
            return []

        lines = text.split('\n')
        chunks: List[TextChunk] = []

        # Clamp the budgets: a zero step would make range() raise and a
        # negative one would silently drop every line.
        max_chars = max(1, self.max_tokens * self.chars_per_token)
        overlap_chars = max(0, self.overlap_tokens * self.chars_per_token)

        current_chunk: List[str] = []
        current_chars = 0
        start_line = 1

        for i, line in enumerate(lines, start=1):
            line_chars = len(line)

            # If single line exceeds max, split it
            if line_chars > max_chars:
                # Flush whatever was accumulated before the long line
                if current_chunk:
                    chunks.append(TextChunk(
                        text='\n'.join(current_chunk),
                        start_line=start_line,
                        end_line=i - 1,
                    ))
                    current_chunk = []
                    current_chars = 0

                chunks.extend(
                    TextChunk(text=piece, start_line=i, end_line=i)
                    for piece in self._split_long_line(line, max_chars)
                )
                start_line = i + 1
                continue

            if current_chars + line_chars > max_chars and current_chunk:
                chunks.append(TextChunk(
                    text='\n'.join(current_chunk),
                    start_line=start_line,
                    end_line=i - 1,
                ))

                # Start the next chunk with overlap, capped so that
                # overlap + the incoming line still fits within max_chars.
                overlap_budget = min(overlap_chars, max_chars - line_chars)
                overlap_lines = self._get_overlap_lines(current_chunk, overlap_budget)
                current_chunk = overlap_lines + [line]
                current_chars = sum(len(l) for l in current_chunk)
                start_line = i - len(overlap_lines)
            else:
                current_chunk.append(line)
                current_chars += line_chars

        # Save last chunk
        if current_chunk:
            chunks.append(TextChunk(
                text='\n'.join(current_chunk),
                start_line=start_line,
                end_line=len(lines),
            ))

        return chunks

    def _split_long_line(self, line: str, max_chars: int) -> List[str]:
        """Split a single long line into consecutive pieces of at most max_chars"""
        step = max(1, max_chars)  # guard against a zero/negative stride
        return [line[i:i + step] for i in range(0, len(line), step)]

    def _get_overlap_lines(self, lines: List[str], target_chars: int) -> List[str]:
        """Get the longest run of trailing lines whose total length fits target_chars"""
        picked: List[str] = []
        chars = 0

        for line in reversed(lines):
            line_chars = len(line)
            if chars + line_chars > target_chars:
                break
            picked.append(line)
            chars += line_chars

        # Collected back-to-front; restore document order
        picked.reverse()
        return picked

    def chunk_markdown(self, text: str) -> List[TextChunk]:
        """
        Chunk markdown text

        Currently identical to chunk_text
        (For future enhancement: respect markdown sections)
        """
        return self.chunk_text(text)
|
||||
122
agent/memory/config.py
Normal file
122
agent/memory/config.py
Normal file
@@ -0,0 +1,122 @@
|
||||
"""
|
||||
Memory configuration module
|
||||
|
||||
Provides global memory configuration with simplified workspace structure
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional, List
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def _default_workspace():
    """Get default workspace path (``~/cow`` passed through ``expand_path``)"""
    # Function-scope import keeps common.utils off the module import path
    from common.utils import expand_path
    return expand_path("~/cow")


@dataclass
class MemoryConfig:
    """Configuration for memory storage and search"""

    # Storage paths (default: ~/cow)
    workspace_root: str = field(default_factory=_default_workspace)

    # Embedding config
    embedding_provider: str = "openai"  # "openai" | "local"
    embedding_model: str = "text-embedding-3-small"
    embedding_dim: int = 1536

    # Chunking config
    chunk_max_tokens: int = 500
    chunk_overlap_tokens: int = 50

    # Search config
    max_results: int = 10
    min_score: float = 0.1

    # Hybrid search weights
    vector_weight: float = 0.7
    keyword_weight: float = 0.3

    # Memory sources
    sources: List[str] = field(default_factory=lambda: ["memory", "session"])

    # Sync config
    enable_auto_sync: bool = True
    sync_on_search: bool = True

    def get_workspace(self) -> Path:
        """
        Get workspace root directory

        A leading ``~`` in ``workspace_root`` is expanded to the user's home
        directory, so a user-supplied value such as ``"~/my_agents"`` does
        not end up as a literal ``~`` directory relative to the current
        working directory. Already-absolute paths are returned unchanged.
        """
        # os.path.expanduser leaves the path as-is when no home can be resolved
        return Path(os.path.expanduser(self.workspace_root))

    def get_memory_dir(self) -> Path:
        """Get memory files directory (``<workspace>/memory``)"""
        return self.get_workspace() / "memory"

    def get_db_path(self) -> Path:
        """
        Get SQLite database path for long-term memory index

        Side effect: creates ``<workspace>/memory/long-term`` if missing.
        """
        index_dir = self.get_memory_dir() / "long-term"
        index_dir.mkdir(parents=True, exist_ok=True)
        return index_dir / "index.db"

    def get_skills_dir(self) -> Path:
        """Get skills directory (``<workspace>/skills``)"""
        return self.get_workspace() / "skills"

    def get_agent_workspace(self, agent_name: Optional[str] = None) -> Path:
        """
        Get workspace directory for an agent

        Args:
            agent_name: Optional agent name (not used in current implementation)

        Returns:
            Path to workspace directory (created if it does not exist)
        """
        workspace = self.get_workspace()
        # Ensure workspace directory exists
        workspace.mkdir(parents=True, exist_ok=True)
        return workspace
|
||||
|
||||
|
||||
# Global memory configuration
|
||||
_global_memory_config: Optional[MemoryConfig] = None
|
||||
|
||||
|
||||
def get_default_memory_config() -> MemoryConfig:
    """
    Return the process-wide memory configuration.

    On first use, when nothing has been registered yet, a ``MemoryConfig``
    with default field values is created, stored as the global, and returned.

    Returns:
        MemoryConfig instance
    """
    global _global_memory_config
    config = _global_memory_config
    if config is None:
        config = MemoryConfig()
        _global_memory_config = config
    return config
|
||||
|
||||
|
||||
def set_global_memory_config(config: MemoryConfig):
    """
    Register ``config`` as the process-wide memory configuration.

    This should be called before creating any MemoryManager instances.

    Args:
        config: MemoryConfig instance to use globally

    Example:
        >>> from agent.memory import MemoryConfig, set_global_memory_config
        >>> config = MemoryConfig(
        ...     workspace_root="~/my_agents",
        ...     embedding_provider="openai",
        ...     vector_weight=0.8
        ... )
        >>> set_global_memory_config(config)
    """
    # Rebind the module-level slot read by get_default_memory_config()
    global _global_memory_config
    _global_memory_config = config
|
||||
1055
agent/memory/conversation_store.py
Normal file
1055
agent/memory/conversation_store.py
Normal file
File diff suppressed because it is too large
Load Diff
41
agent/memory/embedding/__init__.py
Normal file
41
agent/memory/embedding/__init__.py
Normal file
@@ -0,0 +1,41 @@
|
||||
"""
|
||||
Embedding subsystem for memory.
|
||||
|
||||
Public API:
|
||||
create_embedding_provider, EmbeddingProvider, OpenAIEmbeddingProvider, DoubaoEmbeddingProvider,
|
||||
EMBEDDING_VENDORS, EmbeddingCache
|
||||
RebuildResult, clear_index, rebuild_in_process
|
||||
detect_index_dim, cleanup_legacy_state_file
|
||||
"""
|
||||
|
||||
from agent.memory.embedding.provider import (
|
||||
EMBEDDING_VENDORS,
|
||||
DoubaoEmbeddingProvider,
|
||||
EmbeddingCache,
|
||||
EmbeddingProvider,
|
||||
OpenAIEmbeddingProvider,
|
||||
create_embedding_provider,
|
||||
)
|
||||
from agent.memory.embedding.rebuild import (
|
||||
RebuildResult,
|
||||
clear_index,
|
||||
rebuild_in_process,
|
||||
)
|
||||
from agent.memory.embedding.state import (
|
||||
cleanup_legacy_state_file,
|
||||
detect_index_dim,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"EMBEDDING_VENDORS",
|
||||
"DoubaoEmbeddingProvider",
|
||||
"EmbeddingCache",
|
||||
"EmbeddingProvider",
|
||||
"OpenAIEmbeddingProvider",
|
||||
"create_embedding_provider",
|
||||
"RebuildResult",
|
||||
"clear_index",
|
||||
"rebuild_in_process",
|
||||
"cleanup_legacy_state_file",
|
||||
"detect_index_dim",
|
||||
]
|
||||
486
agent/memory/embedding/provider.py
Normal file
486
agent/memory/embedding/provider.py
Normal file
@@ -0,0 +1,486 @@
|
||||
"""
|
||||
Embedding providers for memory
|
||||
|
||||
Supports multiple OpenAI-compatible embedding vendors:
|
||||
- openai (text-embedding-3-small / large)
|
||||
- linkai (OpenAI-compatible passthrough)
|
||||
- dashscope (Aliyun Tongyi text-embedding-v4)
|
||||
- doubao (ByteDance Doubao Seed1.5 / large-text on Volcengine Ark)
|
||||
- zhipu (ZhipuAI embedding-3)
|
||||
|
||||
Vendor keys here intentionally match the project's bot_type constants in
|
||||
common.const (OPENAI, LINKAI, QWEN_DASHSCOPE, DOUBAO, ZHIPU_AI).
|
||||
|
||||
All providers share a single OpenAI-compatible REST client. Vendor-specific
|
||||
behaviors (truncation, query instruction prefix) are configured via metadata.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import math
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Optional
|
||||
|
||||
# HTTP read timeout for a single embeddings request (seconds). A batch of
|
||||
# 64+ chunks can take 30-50s end-to-end from China-side networks, so 30s is
|
||||
# routinely too tight; 90s gives meaningful headroom without letting bad
|
||||
# endpoints hang forever.
|
||||
EMBEDDING_HTTP_TIMEOUT = 90
|
||||
|
||||
|
||||
class EmbeddingProvider(ABC):
    """Abstract interface that every embedding backend implements"""

    @abstractmethod
    def embed(self, text: str) -> List[float]:
        """Return the embedding of one text (treated as a query by default)"""
        ...

    @abstractmethod
    def embed_batch(self, texts: List[str]) -> List[List[float]]:
        """Return one embedding per input text (inputs treated as documents)"""
        ...

    def embed_query(self, text: str) -> List[float]:
        """Embed a query string; subclasses may override to add a vendor instruction prefix"""
        vector = self.embed(text)
        return vector

    @property
    @abstractmethod
    def dimensions(self) -> int:
        """Effective embedding dimensions"""
        ...
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Vendor metadata table
|
||||
# ---------------------------------------------------------------------------
|
||||
#
|
||||
# Each entry describes how to reach a vendor's embedding endpoint. Most
|
||||
# vendors expose an OpenAI-compatible /embeddings API; the few that don't
|
||||
# (currently: doubao) set `provider_class` to pick a dedicated adapter.
|
||||
# Fields:
|
||||
# provider_class : optional adapter key ("doubao"); defaults to OpenAI-compat
|
||||
# default_base_url : default API base when not overridden by user
|
||||
# default_model : default embedding model name
|
||||
# default_dimensions : recommended unified dim when explicit path is enabled
|
||||
# supports_dim_param : whether the API accepts a `dimensions` request param
|
||||
# needs_client_truncate : whether to slice + L2-normalize on the client side
|
||||
# needs_client_normalize : whether to L2-normalize on the client (always safe)
|
||||
# query_instruction : optional prefix for asymmetric retrieval (Doubao Seed)
|
||||
# max_batch_size : max texts per /embeddings request; embed_batch
|
||||
# auto-paginates above this. Conservative defaults.
|
||||
#
|
||||
EMBEDDING_VENDORS = {
    "openai": {
        "default_base_url": "https://api.openai.com/v1",
        "default_model": "text-embedding-3-small",
        # Kept equal to the legacy default so that adding
        # `embedding_provider: openai` to an existing index does not force a
        # rebuild. Set embedding_dimensions for 1024 / 1536 / 3072.
        "default_dimensions": 1536,
        "supports_dim_param": True,
        "needs_client_truncate": False,
        "needs_client_normalize": False,
        "query_instruction": "",
        # OpenAI allows up to 2048 items per request, but one call carrying
        # hundreds of long chunks routinely exceeded the 30s read timeout
        # from China-side networks. 64 keeps each call comfortably inside
        # both the token-per-request budget and a reasonable wall clock.
        "max_batch_size": 64,
    },
    "linkai": {
        "default_base_url": "https://api.link-ai.tech/v1",
        "default_model": "text-embedding-3-small",
        "default_dimensions": 1536,
        "supports_dim_param": True,
        "needs_client_truncate": False,
        "needs_client_normalize": False,
        "query_instruction": "",
        "max_batch_size": 64,
    },
    "dashscope": {
        "default_base_url": "https://dashscope.aliyuncs.com/compatible-mode/v1",
        "default_model": "text-embedding-v4",
        "default_dimensions": 1024,
        "supports_dim_param": True,
        "needs_client_truncate": False,
        "needs_client_normalize": False,
        "query_instruction": "",
        "max_batch_size": 10,  # DashScope hard cap (text-embedding-v4)
    },
    "doubao": {
        # Doubao no longer exposes an OpenAI-compatible /v1/embeddings
        # endpoint; current models live under /api/v3/embeddings/multimodal
        # with a structured `input` payload (see DoubaoEmbeddingProvider).
        "provider_class": "doubao",
        "default_base_url": "https://ark.cn-beijing.volces.com/api/v3",
        "default_model": "doubao-embedding-vision-251215",
        # Native options are 1024 or 2048. 1024 matches the other Chinese
        # vendors (dashscope/zhipu) and keeps storage footprint consistent;
        # override with `embedding_dimensions: 2048` in config.
        "default_dimensions": 1024,
        "supports_dim_param": True,
        "needs_client_truncate": False,
        "needs_client_normalize": False,
        "query_instruction": "",
        # The multimodal endpoint yields ONE embedding per call (the input
        # list is one document's parts, not a batch); embed_batch loops.
        "max_batch_size": 1,
    },
    "zhipu": {
        "default_base_url": "https://open.bigmodel.cn/api/paas/v4",
        "default_model": "embedding-3",
        "default_dimensions": 1024,
        "supports_dim_param": True,
        "needs_client_truncate": False,
        "needs_client_normalize": False,
        "query_instruction": "",
        "max_batch_size": 64,
    },
}
|
||||
|
||||
|
||||
def _l2_normalize(vec: List[float]) -> List[float]:
    """Scale a vector to unit Euclidean length; a zero-norm input is returned as-is."""
    magnitude = math.sqrt(sum(component * component for component in vec))
    if magnitude != 0:
        return [component / magnitude for component in vec]
    # Nothing to scale by: hand back the caller's own list object
    return vec
|
||||
|
||||
|
||||
class OpenAIEmbeddingProvider(EmbeddingProvider):
    """
    OpenAI-compatible embedding provider.

    Talks to any vendor exposing an OpenAI-style ``/embeddings`` endpoint;
    vendor differences are expressed through the constructor flags. The
    legacy constructor form (model, api_key, api_base) keeps working, so the
    original OpenAI/LinkAI fallback code path is unchanged.
    """

    def __init__(
        self,
        model: str = "text-embedding-3-small",
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        extra_headers: Optional[dict] = None,
        dimensions: Optional[int] = None,
        supports_dim_param: bool = True,
        needs_client_truncate: bool = False,
        needs_client_normalize: bool = False,
        query_instruction: str = "",
        max_batch_size: int = 256,
    ):
        """
        Args:
            model: Model name (e.g. text-embedding-3-small, text-embedding-v4, embedding-3)
            api_key: API key (required)
            api_base: API base URL (defaults to OpenAI)
            extra_headers: Optional extra HTTP headers
            dimensions: Target output dimension. Required when supports_dim_param
                is False and needs_client_truncate is True (used to slice).
            supports_dim_param: Whether the vendor accepts a `dimensions` body param
            needs_client_truncate: Slice the returned vector to `dimensions`
            needs_client_normalize: L2-normalize on the client after slicing
            query_instruction: Optional prefix prepended to query texts only
            max_batch_size: Max items per /embeddings request; embed_batch
                auto-paginates above this.

        Raises:
            ValueError: if api_key is missing or a known placeholder value.
        """
        self.model = model
        self.api_key = api_key
        self.api_base = api_base or "https://api.openai.com/v1"
        self.extra_headers = extra_headers or {}
        self.supports_dim_param = supports_dim_param
        self.needs_client_truncate = needs_client_truncate
        self.needs_client_normalize = needs_client_normalize
        self.query_instruction = query_instruction or ""
        # Never allow a page size below 1, otherwise embed_batch could not advance
        self.max_batch_size = max(1, int(max_batch_size or 1))

        if not self.api_key or self.api_key in ["", "YOUR API KEY", "YOUR_API_KEY"]:
            raise ValueError("Embedding API key is not configured")

        if dimensions is not None and dimensions > 0:
            self._dimensions = dimensions
        else:
            # Legacy heuristic for OpenAI text-embedding-3-* family
            self._dimensions = 1536 if "small" in model else 3072

    def _call_api(self, input_data):
        """
        Call OpenAI-compatible /embeddings endpoint and return the parsed JSON body.

        Args:
            input_data: a single string or a list of strings, sent as `input`

        Raises:
            ConnectionError: the endpoint could not be reached
            TimeoutError: the request exceeded EMBEDDING_HTTP_TIMEOUT
            ValueError: the endpoint answered with an HTTP error status
        """
        import requests

        # Tolerate a user-configured base URL that ends with "/"
        url = f"{self.api_base.rstrip('/')}/embeddings"
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}",
            **self.extra_headers,
        }
        data = {
            "input": input_data,
            "model": self.model,
        }
        if self.supports_dim_param and self._dimensions:
            data["dimensions"] = self._dimensions

        # Translate transport errors into builtin exceptions, keeping the
        # original as __cause__ so the root failure stays visible in tracebacks.
        try:
            response = requests.post(url, headers=headers, json=data, timeout=EMBEDDING_HTTP_TIMEOUT)
            response.raise_for_status()
            return response.json()
        except requests.exceptions.ConnectionError as e:
            raise ConnectionError(
                f"Failed to connect to embedding API at {url}. "
                f"Please check network and api_base. Error: {str(e)}"
            ) from e
        except requests.exceptions.Timeout as e:
            raise TimeoutError(f"Embedding API request timed out. Error: {str(e)}") from e
        except requests.exceptions.HTTPError as e:
            if e.response.status_code == 401:
                raise ValueError("Invalid embedding API key") from e
            elif e.response.status_code == 429:
                raise ValueError("Embedding API rate limit exceeded") from e
            else:
                raise ValueError(
                    f"Embedding API request failed: "
                    f"{e.response.status_code} - {e.response.text}"
                ) from e

    def _post_process(self, raw: List[float]) -> List[float]:
        """Apply optional client-side truncation + normalization"""
        vec = raw
        if self.needs_client_truncate and self._dimensions and len(vec) > self._dimensions:
            vec = vec[: self._dimensions]
        if self.needs_client_normalize:
            vec = _l2_normalize(vec)
        return vec

    def embed(self, text: str) -> List[float]:
        """Generate embedding (treated as document by default)"""
        result = self._call_api(text)
        return self._post_process(result["data"][0]["embedding"])

    def embed_query(self, text: str) -> List[float]:
        """Generate embedding for a query (applies vendor instruction prefix if any)"""
        if self.query_instruction:
            text = f"{self.query_instruction}{text}"
        return self.embed(text)

    def embed_batch(self, texts: List[str]) -> List[List[float]]:
        """Generate embeddings for multiple documents.

        Automatically paginates by self.max_batch_size so callers can pass any
        number of texts. Vectors are appended page by page in the order the
        endpoint returns them.
        """
        if not texts:
            return []
        out: List[List[float]] = []
        step = self.max_batch_size
        for i in range(0, len(texts), step):
            chunk = texts[i:i + step]
            result = self._call_api(chunk)
            out.extend(self._post_process(item["embedding"]) for item in result["data"])
        return out

    @property
    def dimensions(self) -> int:
        """Embedding dimension resolved in the constructor"""
        return self._dimensions
|
||||
|
||||
|
||||
class DoubaoEmbeddingProvider(EmbeddingProvider):
    """
    Doubao (Volcengine Ark) multimodal embedding provider.

    Doubao deprecated their OpenAI-compatible /v1/embeddings endpoint and
    unified everything under /api/v3/embeddings/multimodal, which uses a
    structured `input: [{type, text|image_url|video_url}, ...]` payload.

    Notes:
      * The endpoint produces ONE embedding per call (input list is multiple
        modality parts of a single document, not a batch). embed_batch
        therefore loops per-text — no native batch support.
      * Native dimensions: 1024 or 2048 (default 1024 to align with other
        Chinese vendors). No client-side truncation needed.
      * Auth: Bearer ARK API key.
    """

    def __init__(
        self,
        model: str,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        extra_headers: Optional[dict] = None,
        dimensions: Optional[int] = None,
    ):
        """
        Args:
            model: Doubao embedding model name
            api_key: ARK API key (required)
            api_base: API base URL (defaults to the Volcengine Ark v3 base)
            extra_headers: Optional extra HTTP headers
            dimensions: 1024 or 2048; None selects 1024

        Raises:
            ValueError: if the API key is missing/placeholder, or if
                dimensions is anything other than None, 1024 or 2048.
        """
        self.model = model
        self.api_key = api_key
        self.api_base = api_base or "https://ark.cn-beijing.volces.com/api/v3"
        self.extra_headers = extra_headers or {}
        if not self.api_key or self.api_key in ["", "YOUR API KEY", "YOUR_API_KEY"]:
            raise ValueError("Doubao embedding API key (ark_api_key) is not configured")

        if dimensions in (1024, 2048):
            self._dimensions = dimensions
        elif dimensions is None:
            self._dimensions = 1024
        else:
            raise ValueError(
                f"Doubao embedding dimensions must be 1024 or 2048, got {dimensions}"
            )

    def _call_api(self, text: str) -> List[float]:
        """One call → one embedding. The multimodal endpoint takes a single
        document represented as a list of typed parts; we send a single
        text part.

        Raises:
            ConnectionError: the endpoint could not be reached
            TimeoutError: the request exceeded EMBEDDING_HTTP_TIMEOUT
            ValueError: HTTP error status, or an unrecognized response shape
        """
        import requests

        # Tolerate a user-configured base URL that ends with "/"
        url = f"{self.api_base.rstrip('/')}/embeddings/multimodal"
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}",
            **self.extra_headers,
        }
        payload = {
            "model": self.model,
            "input": [{"type": "text", "text": text}],
            "dimensions": self._dimensions,
            "encoding_format": "float",
        }

        # Translate transport errors into builtin exceptions, keeping the
        # original as __cause__ so the root failure stays visible in tracebacks.
        try:
            response = requests.post(url, headers=headers, json=payload, timeout=EMBEDDING_HTTP_TIMEOUT)
            response.raise_for_status()
            body = response.json()
        except requests.exceptions.ConnectionError as e:
            raise ConnectionError(
                f"Failed to connect to Doubao embedding API at {url}. "
                f"Please check network and api_base. Error: {str(e)}"
            ) from e
        except requests.exceptions.Timeout as e:
            raise TimeoutError(f"Doubao embedding API request timed out. Error: {str(e)}") from e
        except requests.exceptions.HTTPError as e:
            if e.response.status_code == 401:
                raise ValueError("Invalid Doubao (ark) embedding API key") from e
            elif e.response.status_code == 429:
                raise ValueError("Doubao embedding API rate limit exceeded") from e
            else:
                raise ValueError(
                    f"Doubao embedding API request failed: "
                    f"{e.response.status_code} - {e.response.text}"
                ) from e

        # Response shape per docs: {"data": {"embedding": [...]}}
        data = body.get("data")
        if isinstance(data, dict) and "embedding" in data:
            return data["embedding"]
        # Some providers wrap as a list of one — be defensive
        if isinstance(data, list) and data and "embedding" in data[0]:
            return data[0]["embedding"]
        raise ValueError(f"Unexpected Doubao embedding response shape: {body}")

    def embed(self, text: str) -> List[float]:
        """Generate the embedding for a single text"""
        return self._call_api(text)

    def embed_batch(self, texts: List[str]) -> List[List[float]]:
        """Generate one embedding per text, issuing one request each"""
        # Endpoint produces one embedding per call; loop. Order preserved.
        return [self._call_api(t) for t in texts]

    @property
    def dimensions(self) -> int:
        """Embedding dimension chosen in the constructor (1024 or 2048)"""
        return self._dimensions
|
||||
|
||||
|
||||
class EmbeddingCache:
    """Dict-backed, in-memory store of embeddings keyed by (provider, model, text)"""

    def __init__(self):
        # Maps the MD5 hex digest of "provider:model:text" to its vector
        self.cache = {}

    @staticmethod
    def _compute_key(text: str, provider: str, model: str) -> str:
        """Derive the cache key: MD5 hex digest of the UTF-8 "provider:model:text" string"""
        raw = f"{provider}:{model}:{text}".encode("utf-8")
        return hashlib.md5(raw).hexdigest()

    def get(self, text: str, provider: str, model: str) -> Optional[List[float]]:
        """Return the stored embedding, or None when nothing is cached for this key"""
        return self.cache.get(self._compute_key(text, provider, model))

    def put(self, text: str, provider: str, model: str, embedding: List[float]):
        """Store (or overwrite) the embedding for this key"""
        self.cache[self._compute_key(text, provider, model)] = embedding

    def clear(self):
        """Drop every cached entry"""
        self.cache.clear()
|
||||
|
||||
|
||||
def create_embedding_provider(
    provider: str = "openai",
    model: Optional[str] = None,
    api_key: Optional[str] = None,
    api_base: Optional[str] = None,
    extra_headers: Optional[dict] = None,
    dimensions: Optional[int] = None,
) -> EmbeddingProvider:
    """
    Factory function to create an embedding provider.

    Backward compatible: when called with provider in {"openai", "linkai"}
    and no `dimensions` arg, the output dimension is left to the provider's
    own model-name heuristic (1536 for text-embedding-3-small), so an index
    built before vendor defaults existed is not invalidated.

    Every path resolves the base URL and batch size from EMBEDDING_VENDORS,
    so a "linkai" call without `api_base` goes to the LinkAI endpoint.

    Args:
        provider: Vendor key (one of EMBEDDING_VENDORS)
        model: Model name (uses vendor default if None)
        api_key: API key (required)
        api_base: API base URL (uses vendor default if None)
        extra_headers: Optional extra HTTP headers
        dimensions: Target output dimension (uses vendor default if None;
            non-positive values are treated as "use the vendor default")

    Returns:
        EmbeddingProvider instance

    Raises:
        ValueError: if `provider` is not a key of EMBEDDING_VENDORS (the
            provider constructors may also raise ValueError, e.g. for a
            missing API key).
    """
    meta = EMBEDDING_VENDORS.get(provider)
    if meta is None:
        raise ValueError(
            f"Unsupported embedding provider: {provider}. "
            f"Supported: {sorted(EMBEDDING_VENDORS.keys())}"
        )

    resolved_model = model or meta["default_model"]
    resolved_base = api_base or meta["default_base_url"]
    vendor_dim = dimensions if (dimensions and dimensions > 0) else meta["default_dimensions"]

    # Doubao uses a non-OpenAI-compatible multimodal endpoint.
    if meta.get("provider_class") == "doubao":
        return DoubaoEmbeddingProvider(
            model=resolved_model,
            api_key=api_key,
            api_base=resolved_base,
            extra_headers=extra_headers,
            dimensions=vendor_dim,
        )

    # Legacy call for openai/linkai: pass no dimension so the provider keeps
    # its historical heuristic and existing data isn't invalidated.
    is_legacy_call = (
        provider in ("openai", "linkai")
        and dimensions is None
    )
    final_dim = None if is_legacy_call else vendor_dim

    return OpenAIEmbeddingProvider(
        model=resolved_model,
        api_key=api_key,
        api_base=resolved_base,
        extra_headers=extra_headers,
        dimensions=final_dim,
        supports_dim_param=meta["supports_dim_param"],
        needs_client_truncate=meta["needs_client_truncate"],
        needs_client_normalize=meta["needs_client_normalize"],
        query_instruction=meta["query_instruction"],
        max_batch_size=meta.get("max_batch_size", 256),
    )
|
||||
191
agent/memory/embedding/rebuild.py
Normal file
191
agent/memory/embedding/rebuild.py
Normal file
@@ -0,0 +1,191 @@
|
||||
"""
|
||||
Rebuild memory vector index.
|
||||
|
||||
Recommended entry point (in-chat, while agent is running):
|
||||
/memory rebuild-index
|
||||
|
||||
Backward-compatible CLI entry (must run from project root):
|
||||
python -m agent.memory.rebuild_index
|
||||
|
||||
What it does:
|
||||
1. Probes the embedding endpoint with a tiny call to fail fast on
|
||||
bad provider/model/key — before touching the index.
|
||||
2. Clears the SQLite chunks/files tables (workspace markdown stays intact).
|
||||
3. Runs a fresh sync, regenerating embeddings with the currently configured
|
||||
provider/model/dimensions.
|
||||
|
||||
This is the only safe way to switch embedding_provider after the existing
|
||||
index has been populated by a different-dim model.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import asyncio
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
from common.log import logger
|
||||
from common.utils import expand_path
|
||||
|
||||
|
||||
@dataclass
class RebuildResult:
    """Outcome of a rebuild_in_process() call"""

    # True when the rebuild finished and produced a usable index.
    ok: bool
    # Number of chunks that were in the index before it was cleared.
    removed: int = 0
    # Number of chunks present in the index after the re-sync.
    chunks: int = 0
    # Number of files tracked by the index after the re-sync.
    files: int = 0
    # Human-readable failure reason; None on success.
    error: Optional[str] = None
|
||||
|
||||
|
||||
def clear_index(db_path, storage=None) -> int:
    """Empty the chunks/files tables, rebuild FTS5, and drop the legacy state file.

    Args:
        db_path: Path of the index DB. Used to open a temporary connection
            when *storage* is None, and to locate the legacy state file that
            is removed afterwards.
        storage: Optional already-open MemoryStorage. Reusing the caller's
            storage keeps the live connection's triggers in sync; a second
            connection would leave the first one's triggers pointing at a
            DROP'd chunks_fts table.

    chunks_fts is reset (DROP + recreate) because its shadow tables can become
    inconsistent across rebuild cycles, which makes bm25() / ORDER BY rank
    raise "database disk image is malformed" even though raw MATCH still works.

    Returns:
        Number of chunks that were removed.
    """
    from agent.memory.embedding.state import cleanup_legacy_state_file
    from agent.memory.storage import MemoryStorage

    # Only close the storage at the end if this function opened it.
    opened_here = storage is None
    active = MemoryStorage(db_path) if opened_here else storage

    try:
        removed = active.conn.execute("SELECT COUNT(*) FROM chunks").fetchone()[0]
        for statement in ("DELETE FROM chunks", "DELETE FROM files"):
            active.conn.execute(statement)
        active.conn.commit()
        active.reset_fts5()
    finally:
        if opened_here:
            active.close()

    cleanup_legacy_state_file(db_path)
    return int(removed)
|
||||
|
||||
|
||||
def rebuild_in_process(memory_manager) -> RebuildResult:
    """
    Rebuild the index using an existing, fully-initialized MemoryManager.

    Used by the in-chat /memory rebuild-index command. The caller already has
    config loaded, embedding_provider built, and (optionally) the agent
    running, so we only need to:
      1. Probe the embedding endpoint (fail fast before touching the index).
      2. Clear chunks/files + state on the manager's storage.
      3. Re-sync (force=True).

    This function is synchronous: it blocks until the re-sync has finished.

    Args:
        memory_manager: The MemoryManager whose storage is rebuilt. Its
            embedding_provider must be set; otherwise the rebuild is refused.

    Returns:
        RebuildResult. Failures of the probe, the clear step and the sync step
        are reported through ``ok=False`` / ``error`` rather than raised.
    """
    if memory_manager is None:
        return RebuildResult(ok=False, error="memory_manager is None")
    if memory_manager.embedding_provider is None:
        return RebuildResult(ok=False, error="embedding_provider is not initialized")

    # Probe the embedding endpoint BEFORE clearing the index. A bad
    # provider/model/key would otherwise leave the user with an empty index
    # that not even keyword search can serve.
    try:
        memory_manager.embedding_provider.embed_query("ping")
    except Exception as e:
        logger.error(f"[RebuildIndex] embedding probe failed, aborting rebuild: {e}")
        return RebuildResult(ok=False, error=f"embedding endpoint not reachable: {e}")

    db_path = memory_manager.config.get_db_path()
    try:
        removed = clear_index(db_path, storage=memory_manager.storage)
    except Exception as e:
        logger.exception("[RebuildIndex] clear_index failed")
        return RebuildResult(ok=False, error=f"clear failed: {e}")

    # Decide up front how to drive the coroutine instead of calling
    # asyncio.run() and catching RuntimeError: that pattern also swallows a
    # RuntimeError raised by sync() itself (re-running the whole sync), and
    # the old fallback (run_until_complete on a fresh loop) cannot work while
    # another loop is running in this thread.
    try:
        asyncio.get_running_loop()
        loop_is_running = True
    except RuntimeError:
        loop_is_running = False

    try:
        if not loop_is_running:
            asyncio.run(memory_manager.sync(force=True))
        else:
            # Already inside a running event loop (rare in chat handler
            # thread): drive the coroutine on a private loop in a worker
            # thread and block here until it finishes.
            # NOTE(review): this assumes the manager's storage may be used
            # from another thread — confirm; if it cannot, the error is
            # reported through the handler below rather than escaping.
            from concurrent.futures import ThreadPoolExecutor

            with ThreadPoolExecutor(max_workers=1) as pool:
                pool.submit(asyncio.run, memory_manager.sync(force=True)).result()
    except Exception as e:
        logger.exception("[RebuildIndex] sync failed")
        return RebuildResult(ok=False, removed=removed, error=f"re-embed failed: {e}")

    stats = memory_manager.storage.get_stats()
    chunks = int(stats.get("chunks", 0))
    embedded = int(stats.get("embedded", 0))

    # In a /rebuild-index request the user explicitly asked for vectors, so
    # an index that holds chunks but no embeddings is surfaced as a failure.
    # NOTE(review): MemoryManager.sync() as it appears in manager.py returns
    # early WITHOUT persisting anything when embed_batch raises, so after the
    # clear above the index may instead end up empty (chunks == 0); that case
    # is not detected here — confirm whether it should be reported as well.
    if chunks > 0 and embedded == 0:
        return RebuildResult(
            ok=False,
            removed=removed,
            chunks=chunks,
            files=int(stats.get("files", 0)),
            error=(
                "embedding API failed during sync; index now has chunks but no "
                "vectors. Check embedding provider/model/key and retry."
            ),
        )

    return RebuildResult(
        ok=True,
        removed=removed,
        chunks=chunks,
        files=int(stats.get("files", 0)),
    )
|
||||
|
||||
|
||||
def main() -> int:
    """Standalone CLI entry. Must be run from project root (relative config path).

    Loads the project config, builds the embedding provider through
    AgentInitializer, rebuilds the index and logs the outcome.

    Returns:
        Process exit status: 0 on success, 1 when no embedding provider could
        be initialized or the rebuild failed.
    """
    from config import conf, load_config
    from agent.memory import MemoryConfig, MemoryManager

    load_config()

    workspace_root = expand_path(conf().get("agent_workspace", "~/cow"))
    memory_config = MemoryConfig(workspace_root=workspace_root)

    logger.info(f"[RebuildIndex] Workspace: {workspace_root}")
    logger.info(f"[RebuildIndex] Index db: {memory_config.get_db_path()}")

    from bridge.agent_initializer import AgentInitializer

    initializer = AgentInitializer(bridge=None, agent_bridge=None)
    embedding_provider = initializer._init_embedding_provider(memory_config, session_id=None)
    if embedding_provider is None:
        logger.error(
            "[RebuildIndex] No embedding provider could be initialized. "
            "Check your config.json. Aborting rebuild."
        )
        return 1

    manager = MemoryManager(memory_config, embedding_provider=embedding_provider)
    try:
        result = rebuild_in_process(manager)
    finally:
        # Release the manager's storage even when the rebuild raises.
        manager.close()

    if not result.ok:
        logger.error(f"[RebuildIndex] {result.error}")
        return 1

    logger.info(
        f"[RebuildIndex] Done. removed={result.removed}, "
        f"chunks={result.chunks}, files={result.files}"
    )
    return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    exit_status = main()
    sys.exit(exit_status)
|
||||
51
agent/memory/embedding/state.py
Normal file
51
agent/memory/embedding/state.py
Normal file
@@ -0,0 +1,51 @@
|
||||
"""
|
||||
Embedding-related index utilities.
|
||||
|
||||
We don't keep a sidecar state file — the SQLite index is the source of truth
|
||||
and config.json is the source of intent. The two functions below are the
|
||||
only things needing on-disk awareness:
|
||||
|
||||
detect_index_dim : read the dim of stored vectors (display-only)
|
||||
cleanup_legacy_state_file: remove old embedding_state.json from earlier
|
||||
versions; safe no-op when absent.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
|
||||
# Path argument type used by this module: a plain string or any os.PathLike object.
PathLike = Union[str, os.PathLike]
|
||||
|
||||
|
||||
def detect_index_dim(storage) -> Optional[int]:
    """Return the dim of the first stored embedding, or None if the index
    has no embeddings. Used by /memory status.

    Args:
        storage: Object exposing a DB-API style ``conn`` whose ``chunks``
            table has an ``embedding`` column.

    Returns:
        The vector dimension, or None when there is no stored embedding or
        it cannot be read/decoded. This function never raises: it is
        display-only, so any failure degrades to None.
    """
    try:
        row = storage.conn.execute(
            "SELECT embedding FROM chunks WHERE embedding IS NOT NULL LIMIT 1"
        ).fetchone()
        if not row:
            return None

        # Rows are normally name-addressable; fall back to the first (and
        # only) selected column for plain tuple rows.
        try:
            raw = row["embedding"]
        except (TypeError, KeyError, IndexError):
            raw = row[0]
        if not raw:
            return None

        if isinstance(raw, (bytes, bytearray)):
            # New BLOB format: 4 bytes per float32
            return len(raw) // 4

        # Otherwise the value is expected to be a JSON-encoded list.
        emb = json.loads(raw)
        return len(emb) if isinstance(emb, list) else None
    except Exception:
        return None
|
||||
|
||||
|
||||
def cleanup_legacy_state_file(db_path: PathLike) -> None:
    """Delete a leftover embedding_state.json next to the index DB.

    The file was written by earlier versions. The call is best-effort:
    a missing file is a no-op and any error while deleting is ignored,
    so it is safe to call repeatedly.
    """
    index_dir = Path(db_path).parent
    stale_state = index_dir.joinpath("embedding_state.json")
    try:
        stale_state.unlink(missing_ok=True)
    except Exception:
        # Best-effort cleanup: never let this break the caller.
        return None
|
||||
555
agent/memory/manager.py
Normal file
555
agent/memory/manager.py
Normal file
@@ -0,0 +1,555 @@
|
||||
"""
|
||||
Memory manager for AgentMesh
|
||||
|
||||
Provides high-level interface for memory operations
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import List, Optional, Dict, Any
|
||||
from pathlib import Path
|
||||
import hashlib
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from agent.memory.config import MemoryConfig, get_default_memory_config
|
||||
from agent.memory.storage import MemoryStorage, MemoryChunk, SearchResult
|
||||
from agent.memory.chunker import TextChunker
|
||||
from agent.memory.embedding import EmbeddingProvider, EmbeddingCache
|
||||
from agent.memory.summarizer import MemoryFlushManager, create_memory_files_if_needed
|
||||
|
||||
|
||||
class MemoryManager:
    """
    Memory manager with hybrid search capabilities

    Provides long-term memory for agents with vector and keyword search.
    Chunks and file metadata live in a MemoryStorage index; the markdown
    files under the workspace are the content that sync() indexes.
    """

    def __init__(
        self,
        config: Optional[MemoryConfig] = None,
        embedding_provider: Optional[EmbeddingProvider] = None,
        llm_model: Optional[Any] = None
    ):
        """
        Initialize memory manager

        Args:
            config: Memory configuration (uses global config if not provided)
            embedding_provider: Custom embedding provider (optional)
            llm_model: LLM model for summarization (optional)
        """
        self.config = config or get_default_memory_config()

        # Initialize storage
        db_path = self.config.get_db_path()
        self.storage = MemoryStorage(db_path)

        # Initialize chunker
        self.chunker = TextChunker(
            max_tokens=self.config.chunk_max_tokens,
            overlap_tokens=self.config.chunk_overlap_tokens
        )

        # Embedding provider is owned by the caller (agent_initializer is the
        # canonical entry point and handles legacy/explicit + state validation).
        # When None is passed, memory degrades to keyword-only search instead
        # of silently re-initializing a vendor here, which would bypass the
        # caller's state checks and risk corrupting the index.
        self.embedding_provider = embedding_provider
        if self.embedding_provider is None:
            from common.log import logger
            logger.info(
                "[MemoryManager] No embedding provider; memory will use keyword search only"
            )

        # Cache for query embeddings (avoids redundant API calls within a session)
        self._embedding_cache = EmbeddingCache()

        # Initialize memory flush manager
        workspace_dir = self.config.get_workspace()
        self.flush_manager = MemoryFlushManager(
            workspace_dir=workspace_dir,
            llm_model=llm_model
        )

        # Ensure workspace directories exist
        self._init_workspace()

        # Set by flush_memory()/mark_dirty(); cleared by a completed sync().
        self._dirty = False

    def _init_workspace(self):
        """Initialize workspace directories"""
        memory_dir = self.config.get_memory_dir()
        memory_dir.mkdir(parents=True, exist_ok=True)

        # Create default memory files
        workspace_dir = self.config.get_workspace()
        create_memory_files_if_needed(workspace_dir)

    async def search(
        self,
        query: str,
        user_id: Optional[str] = None,
        max_results: Optional[int] = None,
        min_score: Optional[float] = None,
        include_shared: bool = True
    ) -> List[SearchResult]:
        """
        Search memory with hybrid search (vector + keyword)

        Args:
            query: Search query
            user_id: User ID for scoped search
            max_results: Maximum results to return (config default if not given)
            min_score: Minimum score threshold (config default if None; an
                explicit 0.0 is honoured and disables the threshold)
            include_shared: Include shared memories

        Returns:
            List of search results sorted by relevance
        """
        max_results = max_results or self.config.max_results
        # Compare against None: `min_score or default` would silently replace
        # an explicit threshold of 0.0 with the configured one.
        if min_score is None:
            min_score = self.config.min_score

        # Determine scopes
        scopes = []
        if include_shared:
            scopes.append("shared")
        if user_id:
            scopes.append("user")

        if not scopes:
            return []

        # Sync if needed
        if self.config.sync_on_search and self._dirty:
            await self.sync()

        from common.log import logger

        # Perform vector search (if embedding provider available).
        # Failures degrade silently to keyword-only — no exception is raised.
        vector_results = []
        if self.embedding_provider:
            try:
                provider_name = type(self.embedding_provider).__name__
                model_name = getattr(self.embedding_provider, 'model', '')
                cached = self._embedding_cache.get(query, provider_name, model_name)
                if cached is not None:
                    query_embedding = cached
                else:
                    query_embedding = self.embedding_provider.embed_query(query)
                    self._embedding_cache.put(query, provider_name, model_name, query_embedding)
                vector_results = self.storage.search_vector(
                    query_embedding=query_embedding,
                    user_id=user_id,
                    scopes=scopes,
                    limit=max_results * 2  # Get more candidates for merging
                )
                logger.info(f"[MemoryManager] Vector search found {len(vector_results)} results for query: {query}")
            except Exception as e:
                logger.error(
                    f"[MemoryManager] Vector search failed, falling back to keyword-only: {e}"
                )

        # Perform keyword search (also runs as fallback when vector failed)
        keyword_results = self.storage.search_keyword(
            query=query,
            user_id=user_id,
            scopes=scopes,
            limit=max_results * 2
        )
        logger.info(f"[MemoryManager] Keyword search found {len(keyword_results)} results for query: {query}")

        # Merge results
        merged = self._merge_results(
            vector_results,
            keyword_results,
            self.config.vector_weight,
            self.config.keyword_weight
        )

        # Filter by min score and limit
        filtered = [r for r in merged if r.score >= min_score]
        return filtered[:max_results]

    async def add_memory(
        self,
        content: str,
        user_id: Optional[str] = None,
        scope: str = "shared",
        source: str = "memory",
        path: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None
    ):
        """
        Add new memory content

        Blank content is ignored. The content is chunked, embedded when a
        provider is configured, and written to the index together with a
        file-metadata record for *path*.

        Args:
            content: Memory content
            user_id: User ID for user-scoped memory
            scope: Memory scope ("shared", "user", "session")
            source: Memory source ("memory" or "session")
            path: File path (auto-generated if not provided)
            metadata: Additional metadata
        """
        if not content.strip():
            return

        # Generate path if not provided
        if not path:
            content_hash = hashlib.md5(content.encode('utf-8')).hexdigest()[:8]
            if user_id and scope == "user":
                path = f"memory/users/{user_id}/memory_{content_hash}.md"
            else:
                path = f"memory/shared/memory_{content_hash}.md"

        # Chunk content
        chunks = self.chunker.chunk_text(content)

        # Generate embeddings (if provider available)
        texts = [chunk.text for chunk in chunks]
        if self.embedding_provider:
            embeddings = self.embedding_provider.embed_batch(texts)
        else:
            # No embeddings, just use None
            embeddings = [None] * len(texts)

        # Create memory chunks
        memory_chunks = []
        for chunk, embedding in zip(chunks, embeddings):
            chunk_id = self._generate_chunk_id(path, chunk.start_line, chunk.end_line)
            chunk_hash = MemoryStorage.compute_hash(chunk.text)

            memory_chunks.append(MemoryChunk(
                id=chunk_id,
                user_id=user_id,
                scope=scope,
                source=source,
                path=path,
                start_line=chunk.start_line,
                end_line=chunk.end_line,
                text=chunk.text,
                embedding=embedding,
                hash=chunk_hash,
                metadata=metadata
            ))

        # Save to storage
        self.storage.save_chunks_batch(memory_chunks)

        # Update file metadata. The content is not backed by a file on disk
        # here, so record the current wall-clock time (previously this stored
        # the mtime of this module's own source file) and the UTF-8 byte
        # length, matching the st_size that sync() records for real files.
        file_hash = MemoryStorage.compute_hash(content)
        self.storage.update_file_metadata(
            path=path,
            source=source,
            file_hash=file_hash,
            mtime=int(datetime.now().timestamp()),
            size=len(content.encode('utf-8'))
        )

    async def sync(self, force: bool = False):
        """
        Synchronize memory from files.

        Three passes, designed to amortize embedding HTTP cost:
          1. Walk all files, chunk those whose hash changed (or all of them
             when *force* is set), collect pending chunks across files. No
             embedding calls yet.
          2. Run a single embed_batch over the union of pending chunks (the
             provider auto-paginates by vendor cap).
          3. Persist chunks and file metadata per file.

        For workspaces with many small files (101 files / ~1 chunk each), this
        cuts ~100 HTTP calls down to ~ceil(total_chunks / vendor_cap).

        Args:
            force: Force full reindex — files are re-chunked and re-embedded
                even when their stored hash is unchanged.
        """
        memory_dir = self.config.get_memory_dir()
        workspace_dir = self.config.get_workspace()

        files_to_scan: List[tuple] = []  # (file_path, source, scope, user_id)

        memory_file = Path(workspace_dir) / "MEMORY.md"
        if memory_file.exists():
            files_to_scan.append((memory_file, "memory", "shared", None))

        if memory_dir.exists():
            for file_path in memory_dir.rglob("*.md"):
                rel_parts = file_path.relative_to(workspace_dir).parts
                if any(part.startswith('.') for part in rel_parts):
                    continue
                # Dream diaries are narrative reflections produced by Deep
                # Dream; their factual content has already been distilled
                # into MEMORY.md. Indexing them adds noisy near-duplicates
                # that crowd out the authoritative entry in retrieval.
                if "dreams" in rel_parts:
                    continue
                if "daily" in rel_parts:
                    if "users" in rel_parts or len(rel_parts) > 3:
                        user_idx = rel_parts.index("daily") + 1
                        user_id = rel_parts[user_idx] if user_idx < len(rel_parts) else None
                        scope = "user"
                    else:
                        user_id = None
                        scope = "shared"
                elif "users" in rel_parts:
                    user_idx = rel_parts.index("users") + 1
                    user_id = rel_parts[user_idx] if user_idx < len(rel_parts) else None
                    scope = "user"
                else:
                    user_id = None
                    scope = "shared"
                files_to_scan.append((file_path, "memory", scope, user_id))

        from config import conf
        if conf().get("knowledge", True):
            knowledge_dir = Path(workspace_dir) / "knowledge"
            if knowledge_dir.exists():
                for file_path in knowledge_dir.rglob("*.md"):
                    files_to_scan.append((file_path, "knowledge", "shared", None))

        # Pass 1: inline chunking + change detection. Inlined (instead of
        # calling self._prepare_file_for_sync) so this method does not depend
        # on any sibling helpers — keeps it robust against partial reloads
        # where the class object is older than the method's source.
        pending: List[Dict[str, Any]] = []
        workspace_dir_path = self.config.get_workspace()
        for file_path, source, scope, user_id in files_to_scan:
            try:
                content = file_path.read_text(encoding='utf-8')
            except Exception:
                # Unreadable file: skip it, the rest of the sync continues.
                continue
            file_hash = MemoryStorage.compute_hash(content)
            rel_path = str(file_path.relative_to(workspace_dir_path))
            # Unchanged files are skipped unless a full reindex was requested.
            if not force and self.storage.get_file_hash(rel_path) == file_hash:
                continue
            chunks = self.chunker.chunk_text(content)
            if not chunks:
                continue
            pending.append({
                "file_path": file_path,
                "rel_path": rel_path,
                "source": source,
                "scope": scope,
                "user_id": user_id,
                "file_hash": file_hash,
                "chunks": chunks,
                "texts": [c.text for c in chunks],
            })

        if not pending:
            self._dirty = False
            return

        # Pass 2: single batched embed across all pending chunks.
        # CRITICAL: never touch the index until we hold valid embeddings.
        # If embed_batch fails, leave the existing index intact (chunks +
        # file_hash) so the next sync will retry the same files. Writing
        # NULL embeddings + updating file_hash here would mark the file as
        # "successfully synced" and silently strand it without vectors.
        all_texts: List[str] = []
        for entry in pending:
            all_texts.extend(entry["texts"])

        if not self.embedding_provider:
            # No provider configured at all (legacy keyword-only). Persist
            # chunks without embeddings — this is the user's intent.
            all_embeddings: List[Optional[List[float]]] = [None] * len(all_texts)
        else:
            try:
                all_embeddings = self.embedding_provider.embed_batch(all_texts)
            except Exception as e:
                from common.log import logger
                logger.error(
                    f"[MemoryManager] Batch embedding failed for {len(all_texts)} "
                    f"chunks across {len(pending)} files: {e}. "
                    f"Index left untouched; will retry on next sync."
                )
                # Bail before touching storage. self._dirty is left as it
                # was, so a pending dirty flag keeps signalling work to do.
                return

            # A short result would make zip() below silently drop chunks and
            # still mark their files as synced; treat it like a failed batch.
            if len(all_embeddings) != len(all_texts):
                from common.log import logger
                logger.error(
                    f"[MemoryManager] Batch embedding returned {len(all_embeddings)} "
                    f"vectors for {len(all_texts)} chunks. "
                    f"Index left untouched; will retry on next sync."
                )
                return

        # Pass 3: inline persist — same self-contained reasoning as Pass 1.
        cursor = 0
        for entry in pending:
            n = len(entry["texts"])
            entry_embeddings = all_embeddings[cursor:cursor + n]
            cursor += n

            rel_path = entry["rel_path"]
            # Replace, don't append: drop the file's previous chunks first.
            self.storage.delete_by_path(rel_path)
            memory_chunks = []
            for chunk, embedding in zip(entry["chunks"], entry_embeddings):
                chunk_id = self._generate_chunk_id(rel_path, chunk.start_line, chunk.end_line)
                chunk_hash = MemoryStorage.compute_hash(chunk.text)
                memory_chunks.append(MemoryChunk(
                    id=chunk_id,
                    user_id=entry["user_id"],
                    scope=entry["scope"],
                    source=entry["source"],
                    path=rel_path,
                    start_line=chunk.start_line,
                    end_line=chunk.end_line,
                    text=chunk.text,
                    embedding=embedding,
                    hash=chunk_hash,
                    metadata=None,
                ))
            self.storage.save_chunks_batch(memory_chunks)
            stat = entry["file_path"].stat()
            self.storage.update_file_metadata(
                path=rel_path,
                source=entry["source"],
                file_hash=entry["file_hash"],
                mtime=int(stat.st_mtime),
                size=stat.st_size,
            )

        self._dirty = False

    def flush_memory(
        self,
        messages: list,
        user_id: Optional[str] = None,
        reason: str = "threshold",
        max_messages: int = 10,
        context_summary_callback=None,
    ) -> bool:
        """
        Flush conversation summary to daily memory file.

        Args:
            messages: Conversation message list
            user_id: Optional user ID
            reason: "threshold" | "overflow" | "daily_summary"
            max_messages: Max recent messages to include (0 = all)
            context_summary_callback: Optional callback(str) invoked with the
                daily summary text for in-context injection

        Returns:
            True if flush was dispatched
        """
        success = self.flush_manager.flush_from_messages(
            messages=messages,
            user_id=user_id,
            reason=reason,
            max_messages=max_messages,
            context_summary_callback=context_summary_callback,
        )
        if success:
            # New content was dispatched to the memory files: index is stale.
            self._dirty = True
        return success

    def get_status(self) -> Dict[str, Any]:
        """Get memory status (index counts, workspace, embedding/search mode)"""
        stats = self.storage.get_stats()
        return {
            'chunks': stats['chunks'],
            'files': stats['files'],
            'workspace': str(self.config.get_workspace()),
            'dirty': self._dirty,
            'embedding_enabled': self.embedding_provider is not None,
            'embedding_provider': self.config.embedding_provider if self.embedding_provider else 'disabled',
            'embedding_model': self.config.embedding_model if self.embedding_provider else 'N/A',
            'search_mode': 'hybrid (vector + keyword)' if self.embedding_provider else 'keyword only (FTS5)'
        }

    def mark_dirty(self):
        """Mark memory as dirty (needs sync)"""
        self._dirty = True

    def close(self):
        """Close memory manager and release resources"""
        self.storage.close()

    # Helper methods

    def _generate_chunk_id(self, path: str, start_line: int, end_line: int) -> str:
        """Generate unique chunk ID (MD5 hex digest of "path:start:end")"""
        content = f"{path}:{start_line}:{end_line}"
        return hashlib.md5(content.encode('utf-8')).hexdigest()

    @staticmethod
    def _compute_temporal_decay(path: str, half_life_days: float = 30.0) -> float:
        """
        Compute temporal decay multiplier for dated memory files.

        Inspired by OpenClaw's temporal-decay: exponential decay based on file date.
        MEMORY.md and non-dated files are "evergreen" (no decay, multiplier=1.0).
        Daily files like memory/2025-03-01.md decay based on age.

        Formula: multiplier = exp(-ln2/half_life * age_in_days)

        Args:
            path: File path; only a trailing ``YYYY-MM-DD.md`` name is dated.
            half_life_days: Age in days at which the multiplier reaches 0.5.
                A non-positive value disables decay (multiplier=1.0).

        Returns:
            Multiplier in (0, 1]; 1.0 for evergreen, future-dated or
            unparseable dates.
        """
        import re
        import math

        # A zero half-life would divide by zero below, and a negative one
        # would amplify old files instead of decaying them.
        if half_life_days <= 0:
            return 1.0

        match = re.search(r'(\d{4})-(\d{2})-(\d{2})\.md$', path)
        if not match:
            return 1.0  # evergreen: MEMORY.md, non-dated files

        try:
            file_date = datetime(
                int(match.group(1)), int(match.group(2)), int(match.group(3))
            )
            age_days = (datetime.now() - file_date).days
            if age_days <= 0:
                return 1.0

            decay_lambda = math.log(2) / half_life_days
            return math.exp(-decay_lambda * age_days)
        except (ValueError, OverflowError):
            # Impossible calendar date (e.g. month 13) or out-of-range math.
            return 1.0

    def _merge_results(
        self,
        vector_results: List[SearchResult],
        keyword_results: List[SearchResult],
        vector_weight: float,
        keyword_weight: float
    ) -> List[SearchResult]:
        """Merge vector and keyword search results with temporal decay for dated files

        Results are keyed by (path, start_line, end_line); a chunk found by
        both searches gets the weighted sum of both scores. Returns a new
        list sorted by combined score, highest first.
        """
        merged_map = {}

        for result in vector_results:
            key = (result.path, result.start_line, result.end_line)
            merged_map[key] = {
                'result': result,
                'vector_score': result.score,
                'keyword_score': 0.0
            }

        for result in keyword_results:
            key = (result.path, result.start_line, result.end_line)
            if key in merged_map:
                merged_map[key]['keyword_score'] = result.score
            else:
                merged_map[key] = {
                    'result': result,
                    'vector_score': 0.0,
                    'keyword_score': result.score
                }

        merged_results = []
        for entry in merged_map.values():
            combined_score = (
                vector_weight * entry['vector_score'] +
                keyword_weight * entry['keyword_score']
            )

            # Apply temporal decay for dated memory files
            result = entry['result']
            decay = self._compute_temporal_decay(result.path)
            combined_score *= decay

            merged_results.append(SearchResult(
                path=result.path,
                start_line=result.start_line,
                end_line=result.end_line,
                score=combined_score,
                snippet=result.snippet,
                source=result.source,
                user_id=result.user_id
            ))

        merged_results.sort(key=lambda r: r.score, reverse=True)
        return merged_results
|
||||
14
agent/memory/rebuild_index.py
Normal file
14
agent/memory/rebuild_index.py
Normal file
@@ -0,0 +1,14 @@
|
||||
"""
|
||||
Backward-compatible shim for the legacy entry point:
|
||||
python -m agent.memory.rebuild_index
|
||||
|
||||
The implementation now lives in agent.memory.embedding.rebuild.
|
||||
Prefer using `/memory rebuild-index` in chat going forward.
|
||||
"""
|
||||
|
||||
from agent.memory.embedding.rebuild import main
|
||||
|
||||
if __name__ == "__main__":
    import sys

    # Delegate to the relocated implementation and exit with its status.
    status = main()
    sys.exit(status)
|
||||
197
agent/memory/service.py
Normal file
197
agent/memory/service.py
Normal file
@@ -0,0 +1,197 @@
|
||||
"""
|
||||
Memory service for handling memory query operations via cloud protocol.
|
||||
|
||||
Provides a unified interface for listing and reading memory files,
|
||||
callable from the cloud client (LinkAI) or a future web console.
|
||||
|
||||
Memory file layout (under workspace_root):
|
||||
MEMORY.md -> type: global
|
||||
memory/2026-02-20.md -> type: daily
|
||||
"""
|
||||
|
||||
import os
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional
|
||||
from pathlib import Path
|
||||
from common.log import logger
|
||||
|
||||
|
||||
class MemoryService:
|
||||
"""
|
||||
High-level service for memory file queries.
|
||||
Operates directly on the filesystem — no MemoryManager dependency.
|
||||
"""
|
||||
|
||||
def __init__(self, workspace_root: str):
|
||||
"""
|
||||
:param workspace_root: Workspace root directory (e.g. ~/cow)
|
||||
"""
|
||||
self.workspace_root = workspace_root
|
||||
self.memory_dir = os.path.join(workspace_root, "memory")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# list — paginated file metadata
|
||||
# ------------------------------------------------------------------
|
||||
def list_files(self, page: int = 1, page_size: int = 20, category: str = "memory") -> dict:
|
||||
"""
|
||||
List memory or dream files with metadata (without content).
|
||||
|
||||
Args:
|
||||
category: ``"memory"`` (default) — MEMORY.md + daily files;
|
||||
``"dream"`` — dream diary files from memory/dreams/
|
||||
"""
|
||||
if category == "dream":
|
||||
files = self._list_dream_files()
|
||||
else:
|
||||
files = self._list_memory_files()
|
||||
|
||||
total = len(files)
|
||||
start = (page - 1) * page_size
|
||||
end = start + page_size
|
||||
|
||||
return {
|
||||
"page": page,
|
||||
"page_size": page_size,
|
||||
"total": total,
|
||||
"list": files[start:end],
|
||||
}
|
||||
|
||||
def _list_memory_files(self) -> List[dict]:
|
||||
"""MEMORY.md + memory/*.md (newest first)."""
|
||||
files: List[dict] = []
|
||||
|
||||
global_path = os.path.join(self.workspace_root, "MEMORY.md")
|
||||
if os.path.isfile(global_path):
|
||||
files.append(self._file_info(global_path, "MEMORY.md", "global"))
|
||||
|
||||
if os.path.isdir(self.memory_dir):
|
||||
daily_files = []
|
||||
for name in os.listdir(self.memory_dir):
|
||||
full = os.path.join(self.memory_dir, name)
|
||||
if os.path.isfile(full) and name.endswith(".md"):
|
||||
daily_files.append((name, full))
|
||||
daily_files.sort(key=lambda x: x[0], reverse=True)
|
||||
for name, full in daily_files:
|
||||
files.append(self._file_info(full, name, "daily"))
|
||||
|
||||
return files
|
||||
|
||||
def _list_dream_files(self) -> List[dict]:
|
||||
"""memory/dreams/*.md (newest first)."""
|
||||
files: List[dict] = []
|
||||
dreams_dir = os.path.join(self.memory_dir, "dreams")
|
||||
|
||||
if os.path.isdir(dreams_dir):
|
||||
entries = []
|
||||
for name in os.listdir(dreams_dir):
|
||||
full = os.path.join(dreams_dir, name)
|
||||
if os.path.isfile(full) and name.endswith(".md"):
|
||||
entries.append((name, full))
|
||||
entries.sort(key=lambda x: x[0], reverse=True)
|
||||
for name, full in entries:
|
||||
files.append(self._file_info(full, name, "dream"))
|
||||
|
||||
return files
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# content — read a single file
|
||||
# ------------------------------------------------------------------
|
||||
def get_content(self, filename: str, category: str = "memory") -> dict:
    """
    Return the complete text of one memory or dream file.

    :param filename: File name, e.g. ``MEMORY.md``, ``2026-02-20.md``
    :param category: ``"memory"`` or ``"dream"``
    :return: dict with ``filename`` (as passed in) and ``content``
    :raises FileNotFoundError: if the resolved path is not an existing file
    """
    target = self._resolve_path(filename, category)
    if os.path.isfile(target):
        with open(target, "r", encoding="utf-8") as handle:
            text = handle.read()
        return {"filename": filename, "content": text}

    raise FileNotFoundError(f"Memory file not found: {filename}")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# dispatch — single entry point for protocol messages
|
||||
# ------------------------------------------------------------------
|
||||
def dispatch(self, action: str, payload: Optional[dict] = None) -> dict:
    """
    Route a memory management action to the matching handler.

    :param action: ``list`` or ``content``
    :param payload: action-specific payload (supports ``category``: ``"memory"`` | ``"dream"``)
    :return: protocol-compatible response dict with ``action``, ``code``,
        ``message`` and ``payload``; failures are reported through ``code``
        (400 bad request, 403 ValueError, 404 missing file, 500 anything else)
        instead of being raised
    """

    def reply(code: int, message: str, data: Optional[dict] = None) -> dict:
        # Single place that fixes the response envelope.
        return {"action": action, "code": code, "message": message, "payload": data}

    payload = payload or {}
    try:
        if action == "list":
            listing = self.list_files(
                page=payload.get("page", 1),
                page_size=payload.get("page_size", 20),
                category=payload.get("category", "memory"),
            )
            return reply(200, "success", listing)

        if action == "content":
            filename = payload.get("filename")
            if not filename:
                return reply(400, "filename is required")
            category = payload.get("category", "memory")
            return reply(200, "success", self.get_content(filename, category=category))

        return reply(400, f"unknown action: {action}")

    except ValueError:
        return reply(403, "invalid filename")
    except FileNotFoundError as e:
        return reply(404, str(e))
    except Exception as e:
        logger.error(f"[MemoryService] dispatch error: action={action}, error={e}")
        return reply(500, str(e))
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# internal helpers
|
||||
# ------------------------------------------------------------------
|
||||
def _resolve_path(self, filename: str, category: str = "memory") -> str:
|
||||
"""
|
||||
Safely resolve a filename to its absolute path within the allowed directory.
|
||||
|
||||
- ``MEMORY.md`` → ``{workspace_root}/MEMORY.md``
|
||||
- ``2026-02-20.md`` (memory) → ``{workspace_root}/memory/2026-02-20.md``
|
||||
- ``2026-02-20.md`` (dream) → ``{workspace_root}/memory/dreams/2026-02-20.md``
|
||||
|
||||
Raises ValueError if the resolved path escapes the allowed directory.
|
||||
"""
|
||||
if filename == "MEMORY.md":
|
||||
base_dir = self.workspace_root
|
||||
elif category == "dream":
|
||||
base_dir = os.path.join(self.memory_dir, "dreams")
|
||||
else:
|
||||
base_dir = self.memory_dir
|
||||
|
||||
resolved = os.path.realpath(os.path.join(base_dir, filename))
|
||||
allowed = os.path.realpath(base_dir)
|
||||
|
||||
if resolved != allowed and not resolved.startswith(allowed + os.sep):
|
||||
raise ValueError(f"Invalid filename: path traversal detected")
|
||||
|
||||
return resolved
|
||||
|
||||
@staticmethod
|
||||
def _file_info(path: str, filename: str, file_type: str) -> dict:
|
||||
"""Build a file metadata dict."""
|
||||
stat = os.stat(path)
|
||||
updated_at = datetime.fromtimestamp(stat.st_mtime).strftime("%Y-%m-%d %H:%M:%S")
|
||||
return {
|
||||
"filename": filename,
|
||||
"type": file_type,
|
||||
"size": stat.st_size,
|
||||
"updated_at": updated_at,
|
||||
}
|
||||
1056
agent/memory/storage.py
Normal file
1056
agent/memory/storage.py
Normal file
File diff suppressed because it is too large
Load Diff
847
agent/memory/summarizer.py
Normal file
847
agent/memory/summarizer.py
Normal file
@@ -0,0 +1,847 @@
|
||||
"""
|
||||
Memory flush manager with Deep Dream distillation
|
||||
|
||||
Handles memory persistence when conversation context is trimmed or overflows:
|
||||
- Uses LLM to summarize discarded messages into concise daily records
|
||||
- Writes to daily memory files (lazy creation)
|
||||
- Deduplicates trim flushes to avoid repeated writes
|
||||
- Runs summarization asynchronously to avoid blocking normal replies
|
||||
- Deep Dream: periodically distills daily memories → refined MEMORY.md + dream diary
|
||||
"""
|
||||
|
||||
import threading
|
||||
from typing import Optional, Callable, Any, List, Dict
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
from common.log import logger
|
||||
|
||||
|
||||
# System prompt (Chinese) for turning a conversation into a daily record.
# The model is told to answer with the sentinel "无" when nothing is worth keeping.
SUMMARIZE_SYSTEM_PROMPT_ZH = """你是一个对话记录助手。请将对话内容归纳为当天的日常记录。

## 要求

按「事件」维度归纳发生的事,不要按对话轮次逐条记录:
- 每条一行,用 "- " 开头
- 合并同一件事的多轮对话
- 只记录有意义的事件,忽略闲聊和问候
- 保留关键的决策、结论和待办事项

当对话没有任何记录价值(仅含问候或无意义内容),直接回复"无"。"""

# System prompt (English) — same contract, with the sentinel "None".
SUMMARIZE_SYSTEM_PROMPT_EN = """You are a conversation-logging assistant. Summarize the conversation into a daily record.

## Requirements

Summarize by "event", not turn by turn:
- One item per line, starting with "- "
- Merge multiple turns about the same thing
- Only record meaningful events; ignore small talk and greetings
- Keep key decisions, conclusions and to-dos

If the conversation has no record value (only greetings or meaningless content), reply with exactly "None"."""

# User-message templates; the only placeholder is {conversation}.
SUMMARIZE_USER_PROMPT_ZH = """请归纳以下对话的日常记录:

{conversation}"""

SUMMARIZE_USER_PROMPT_EN = """Summarize the daily record of the following conversation:

{conversation}"""
|
||||
|
||||
# ---------------------------------------------------------------------------
# Deep Dream prompts — distill daily memories → MEMORY.md + dream diary
#
# The system prompts require a reply made of a [MEMORY] section followed by a
# [DREAM] section. The user templates take three placeholders:
# {memory_content}, {days} and {daily_content}.
# ---------------------------------------------------------------------------

DREAM_SYSTEM_PROMPT_ZH = """你是一个记忆整理助手,负责定期整理用户的长期记忆。

你将收到两份材料:
1. **当前长期记忆** — MEMORY.md 的全部现有内容
2. **今日日记** — 当天的日常记录

MEMORY.md 会注入每次对话的系统提示词中,因此必须保持精炼,只存放有价值和值得记忆的内容。

**重要:只能基于提供的材料进行整理,严禁编造、推测或添加材料中不存在的信息。**

## 任务

### Part 1: 更新后的长期记忆([MEMORY])

在现有记忆基础上进行整理和提炼,输出完整的更新后内容:
- **合并提炼**:将含义相近的多条合并为一条高密度表述,而非简单罗列
- **新增萃取**:从今日日记中提取值得永久记住的新信息(偏好、决策、人物、规则、经验)
- **冲突更新**:当新信息与旧条目矛盾时,以新信息为准,替换旧条目
- **清理无效**:删除临时性记录、空白条目、格式残留、无意义、重复内容等
- **删除冗余**:已被更精炼表述涵盖的旧条目应删除,避免信息重复
- 每条一行,用 "- " 开头,不带日期前缀
- 可用 "## 标题" 对相关条目分组,使结构更清晰
- 目标:控制在 50 条以内,每条尽量一句话概括

### Part 2: 梦境日记([DREAM])

用简洁的叙事风格写一篇短日记,记录这次整理的发现,保持格式美观易读:
- 发现了哪些重复或矛盾
- 从日记中提取了什么新洞察
- 做了哪些清理和优化
- 整体感受和观察

## 输出格式(严格遵守)

```
[MEMORY]
- 记忆条目1
- 记忆条目2
...

[DREAM]
梦境日记内容...
```"""

DREAM_SYSTEM_PROMPT_EN = """You are a memory-curation assistant that periodically organizes the user's long-term memory.

You will receive two inputs:
1. **Current long-term memory** — the full existing content of MEMORY.md
2. **Today's diary** — the daily records

MEMORY.md is injected into the system prompt of every conversation, so it must stay concise and hold only valuable, memory-worthy content.

**Important: organize strictly based on the provided material. Never fabricate, infer, or add information not present in it.**

## Tasks

### Part 1: Updated long-term memory ([MEMORY])

Organize and distill on top of the existing memory, and output the complete updated content:
- **Merge & distill**: combine semantically similar items into one dense statement rather than listing them
- **Extract new**: pull memory-worthy new info from today's diary (preferences, decisions, people, rules, lessons)
- **Resolve conflicts**: when new info contradicts an old item, prefer the new and replace the old
- **Clean invalid**: remove temporary notes, blank items, formatting residue, meaningless or duplicate content
- **Drop redundancy**: delete old items already covered by a more concise statement
- One item per line, starting with "- ", without a date prefix
- You may group related items under "## headings" for clarity
- Goal: keep under 50 items, each ideally a single sentence

### Part 2: Dream diary ([DREAM])

Write a short diary in a concise narrative style recording what this curation found, keep it clean and readable:
- Which duplicates or conflicts were found
- What new insights were extracted from the diary
- What cleanup and optimization was done
- Overall feelings and observations

## Output format (follow strictly)

```
[MEMORY]
- memory item 1
- memory item 2
...

[DREAM]
dream diary content...
```"""

DREAM_USER_PROMPT_ZH = """## 当前长期记忆(MEMORY.md)

{memory_content}

## 近期日记(最近 {days} 天)

{daily_content}"""

DREAM_USER_PROMPT_EN = """## Current long-term memory (MEMORY.md)

{memory_content}

## Recent diary (last {days} days)

{daily_content}"""
|
||||
|
||||
|
||||
def _is_en() -> bool:
|
||||
"""True when the resolved UI language is English."""
|
||||
try:
|
||||
from common import i18n
|
||||
return i18n.get_language() == "en"
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def _summarize_system_prompt() -> str:
    """Pick the summarization system prompt for the current UI language."""
    if _is_en():
        return SUMMARIZE_SYSTEM_PROMPT_EN
    return SUMMARIZE_SYSTEM_PROMPT_ZH
|
||||
|
||||
|
||||
def _summarize_user_prompt() -> str:
    """Pick the summarization user-message template for the current UI language."""
    if _is_en():
        return SUMMARIZE_USER_PROMPT_EN
    return SUMMARIZE_USER_PROMPT_ZH
|
||||
|
||||
|
||||
def _dream_system_prompt() -> str:
    """Pick the Deep Dream system prompt for the current UI language."""
    if _is_en():
        return DREAM_SYSTEM_PROMPT_EN
    return DREAM_SYSTEM_PROMPT_ZH
|
||||
|
||||
|
||||
def _dream_user_prompt() -> str:
    """Pick the Deep Dream user-message template for the current UI language."""
    if _is_en():
        return DREAM_USER_PROMPT_EN
    return DREAM_USER_PROMPT_ZH
|
||||
|
||||
|
||||
def _is_empty_sentinel(text: str) -> bool:
|
||||
"""Match the "no record value" sentinel in both zh ("无") and en ("None")."""
|
||||
if not text:
|
||||
return True
|
||||
s = text.strip()
|
||||
return s == "" or s == "无" or s.lower() == "none"
|
||||
|
||||
|
||||
|
||||
class MemoryFlushManager:
|
||||
"""
|
||||
Manages memory flush operations.
|
||||
|
||||
Flush is triggered by agent_stream in two scenarios:
|
||||
1. Context trim: _trim_messages discards old turns → flush discarded content
|
||||
2. Context overflow: API rejects request → emergency flush before clearing
|
||||
|
||||
Additionally, create_daily_summary() can be called by scheduler for end-of-day summaries.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
workspace_dir: Path,
|
||||
llm_model: Optional[Any] = None,
|
||||
):
|
||||
self.workspace_dir = workspace_dir
|
||||
self.llm_model = llm_model
|
||||
|
||||
self.memory_dir = workspace_dir / "memory"
|
||||
self.memory_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self.last_flush_timestamp: Optional[datetime] = None
|
||||
self._trim_flushed_hashes: set = set() # Content hashes of already-flushed messages
|
||||
self._last_flushed_content_hash: str = "" # Content hash at last flush, for daily dedup
|
||||
self._last_dream_input_hash: str = "" # "{date}:{daily_hash}" of last dream, for dedup
|
||||
self._last_flush_thread: Optional[threading.Thread] = None
|
||||
|
||||
def get_today_memory_file(self, user_id: Optional[str] = None, ensure_exists: bool = False) -> Path:
    """
    Get today's memory file path: memory/YYYY-MM-DD.md

    Args:
        user_id: when given, the file lives under memory/users/<user_id>/
        ensure_exists: when True, create the parent directory and, if the
            file is missing, write a ``# Daily Memory: <date>`` header

    Returns:
        Path of today's daily file (it may not exist unless ``ensure_exists``).
    """
    today = datetime.now().strftime("%Y-%m-%d")

    # NOTE(review): user_id is joined into the path unchecked — confirm that
    # callers only pass trusted identifiers.
    parent = self.memory_dir / "users" / user_id if user_id else self.memory_dir
    today_file = parent / f"{today}.md"

    if ensure_exists:
        parent.mkdir(parents=True, exist_ok=True)
        if not today_file.exists():
            # Explicit UTF-8: the file is later appended to and read as UTF-8,
            # so it must not start out in the platform default encoding.
            today_file.write_text(f"# Daily Memory: {today}\n\n", encoding="utf-8")

    return today_file
|
||||
|
||||
def get_main_memory_file(self, user_id: Optional[str] = None) -> Path:
    """
    Locate the long-term memory file.

    Without ``user_id`` this is ``MEMORY.md`` in the workspace root. With a
    ``user_id`` it is ``memory/users/<user_id>/MEMORY.md``, and that user
    directory is created as a side effect.
    """
    if not user_id:
        return Path(self.workspace_dir) / "MEMORY.md"

    user_dir = self.memory_dir / "users" / user_id
    user_dir.mkdir(parents=True, exist_ok=True)
    return user_dir / "MEMORY.md"
|
||||
|
||||
def get_status(self) -> dict:
    """
    Summarize the flush state.

    Returns:
        dict with ``last_flush_time`` (ISO-8601 string, or None if no flush
        has completed), ``today_file`` and ``main_file`` (paths as strings).
    """
    last = self.last_flush_timestamp
    status = {'last_flush_time': last.isoformat() if last else None}
    status['today_file'] = str(self.get_today_memory_file())
    status['main_file'] = str(self.get_main_memory_file())
    return status
|
||||
|
||||
# ---- Flush execution (called by agent_stream or scheduler) ----
|
||||
|
||||
def flush_from_messages(
    self,
    messages: List[Dict],
    user_id: Optional[str] = None,
    reason: str = "trim",
    max_messages: int = 0,
    context_summary_callback: Optional[Callable[[str], None]] = None,
) -> bool:
    """
    Asynchronously summarize and flush messages to daily memory.

    Deduplication runs synchronously, then LLM summarization + file write
    run in a background thread so the main reply flow is never blocked.

    If *context_summary_callback* is provided, it is handed to the worker,
    which calls it with the cleaned summary text once available. The caller
    can use this to inject the summary into the live message list for
    context continuity — one LLM call serves both disk persistence and
    in-context injection.

    Args:
        messages: conversation messages (dicts with ``role`` / ``content``)
        user_id: optional user scope for the daily file
        reason: flush trigger, used for logging and the section header
        max_messages: passed through to summarization (0 = no limit)
        context_summary_callback: optional receiver of the summary text

    Returns:
        True when a worker thread was started; False when there was nothing
        new to flush or dispatching failed.
    """
    try:
        # Strip scheduler-injected pairs before any further processing.
        # These messages already serve as short-term context inside the
        # receiver session; promoting them into long-term daily memory
        # produces low-value flat logs (e.g. "11:28 price=1013, normal /
        # 11:58 price=1013, normal / ...") and wastes summarisation tokens.
        messages = self._strip_scheduler_pairs(messages)
        if not messages:
            return False

        import copy
        import hashlib

        # Hashes are staged locally and only committed once the worker is
        # running; otherwise a failed deepcopy or thread start would mark the
        # messages as flushed and they would never be retried.
        pending_hashes = set()
        deduped = []
        for m in messages:
            text = self._extract_text_from_content(m.get("content", ""))
            if not text or not text.strip():
                continue
            h = hashlib.md5(text.encode("utf-8")).hexdigest()
            if h in self._trim_flushed_hashes or h in pending_hashes:
                continue
            pending_hashes.add(h)
            deduped.append(m)
        if not deduped:
            return False

        # Deep copy so later mutation of the live message list cannot affect
        # what the background worker summarizes.
        snapshot = copy.deepcopy(deduped)
        thread = threading.Thread(
            target=self._flush_worker,
            args=(snapshot, user_id, reason, max_messages, context_summary_callback),
            daemon=True,
        )
        thread.start()
        self._trim_flushed_hashes.update(pending_hashes)
        self._last_flush_thread = thread
        logger.info(f"[MemoryFlush] Async flush dispatched (reason={reason}, msgs={len(snapshot)})")
        return True

    except Exception as e:
        logger.warning(f"[MemoryFlush] Failed to dispatch flush (reason={reason}): {e}")
        return False
|
||||
|
||||
def _flush_worker(
    self,
    messages: List[Dict],
    user_id: Optional[str],
    reason: str,
    max_messages: int,
    context_summary_callback: Optional[Callable[[str], None]] = None,
):
    """
    Background worker: summarize the messages and append to the daily file.

    Nothing is written when the summary is the empty sentinel or cleans down
    to an empty string. After a successful write the optional callback gets
    the cleaned summary and ``last_flush_timestamp`` is updated. All
    exceptions are logged and swallowed.
    """
    try:
        raw_summary = self._summarize_messages(messages, max_messages)
        if _is_empty_sentinel(raw_summary):
            logger.info(f"[MemoryFlush] No valuable content to flush (reason={reason})")
            return

        # Strip legacy [DAILY]/[MEMORY] markers if model still outputs them
        daily_part = self._clean_summary_output(raw_summary)
        if not daily_part:
            return

        # --- Write daily memory ---
        daily_file = ensure_daily_memory_file(self.workspace_dir, user_id)

        # Section title depends on the flush reason; unknown reasons get a
        # generic title.
        titles = {
            "overflow": "Context Overflow Recovery",
            "trim": "Trimmed Context",
            "daily_summary": "Daily Summary",
        }
        stamp = datetime.now().strftime('%H:%M')
        header = f"## {titles.get(reason, 'Session Notes')} ({stamp})"

        with open(daily_file, "a", encoding="utf-8") as f:
            f.write(f"\n{header}\n\n{daily_part}\n")

        logger.info(f"[MemoryFlush] Wrote daily memory to {daily_file.name} (reason={reason}, chars={len(daily_part)})")

        # --- Inject context summary into live messages (if callback provided) ---
        if context_summary_callback:
            try:
                context_summary_callback(daily_part)
            except Exception as e:
                # A failing callback must not undo the completed disk write.
                logger.warning(f"[MemoryFlush] Context summary callback failed: {e}")

        self.last_flush_timestamp = datetime.now()

    except Exception as e:
        logger.warning(f"[MemoryFlush] Async flush failed (reason={reason}): {e}")
|
||||
|
||||
@staticmethod
def _clean_summary_output(raw: str) -> str:
    """
    Strip legacy [DAILY]/[MEMORY] markers if present, return clean daily text.

    When a ``[DAILY]`` marker exists, only the text after it is kept. Any
    ``[MEMORY]`` section that follows is removed. Markers are handled in
    either order, so ``[MEMORY] ... [DAILY] ...`` still yields the daily part
    instead of an empty string. Markdown code fences are removed.

    Args:
        raw: raw model output

    Returns:
        Cleaned daily text, or "" when the output is the empty sentinel.
    """
    raw = raw.strip()
    if _is_empty_sentinel(raw):
        return ""

    # Keep only what follows the [DAILY] marker.
    daily_at = raw.find("[DAILY]")
    if daily_at != -1:
        raw = raw[daily_at + len("[DAILY]"):]

    # Drop a [MEMORY] section located after the daily text. The search runs
    # on the already-trimmed text, so a [MEMORY] block that preceded [DAILY]
    # cannot produce an inverted (empty) slice.
    memory_at = raw.find("[MEMORY]")
    if memory_at != -1:
        raw = raw[:memory_at]

    # Remove markdown code fences
    return raw.replace("```", "").strip()
|
||||
|
||||
def create_daily_summary(
    self,
    messages: List[Dict],
    user_id: Optional[str] = None
) -> bool:
    """
    Generate end-of-day summary. Called by daily timer.

    Skips if messages haven't changed since the last dispatched flush. The
    content hash is remembered only when a flush was actually dispatched, so
    a call that dispatched nothing does not suppress a later retry.

    Args:
        messages: conversation messages to summarize
        user_id: optional user scope for the daily file

    Returns:
        True when a flush was dispatched, False otherwise.
    """
    import hashlib

    content = "".join(
        self._extract_text_from_content(m.get("content", ""))
        for m in messages
    )
    content_hash = hashlib.md5(content.encode("utf-8")).hexdigest()
    if content_hash == self._last_flushed_content_hash:
        logger.debug("[MemoryFlush] Daily summary skipped: no new content since last flush")
        return False

    dispatched = self.flush_from_messages(
        messages=messages,
        user_id=user_id,
        reason="daily_summary",
        max_messages=0,
    )
    if dispatched:
        self._last_flushed_content_hash = content_hash
    return dispatched
|
||||
|
||||
# ---- Deep Dream (memory distillation) ----
|
||||
|
||||
def deep_dream(self, user_id: Optional[str] = None, lookback_days: int = 1, force: bool = False) -> bool:
    """
    Distill recent daily memories into MEMORY.md and generate a dream diary.

    Args:
        user_id: optional user scope for MEMORY.md, daily files and diary
        lookback_days: How many days of daily files to read (default 1 for scheduled, 3 for manual)
        force: Skip input-hash dedup check (used by manual /memory dream trigger)

    Returns:
        True when MEMORY.md was rewritten. False when there is no model, no
        daily content, the same input was already dreamed today, the model
        call failed or returned nothing, no [MEMORY] section was produced, or
        the write failed. A dream-diary write failure is only logged.
    """
    if not self.llm_model:
        logger.warning("[DeepDream] No LLM model available, skipping")
        return False

    logger.info(f"[DeepDream] Starting memory distillation (lookback={lookback_days} days)")

    # Collect materials
    memory_content = self._read_main_memory(user_id)
    daily_content, has_content = self._read_recent_dailies(user_id, lookback_days)

    if not has_content:
        logger.info("[DeepDream] No recent daily records, skipping to preserve existing MEMORY.md")
        return False

    # Dedup: skip if same daily content already dreamed today.
    # Note: only hash daily_content (not memory_content), because deep_dream
    # itself rewrites MEMORY.md as a side effect, which would otherwise
    # invalidate the hash on every subsequent call within the same window.
    import hashlib
    daily_hash = hashlib.md5(daily_content.encode("utf-8")).hexdigest()
    today_str = datetime.now().strftime("%Y-%m-%d")
    dedup_key = f"{today_str}:{daily_hash}"
    if not force and dedup_key == self._last_dream_input_hash:
        logger.info("[DeepDream] Already dreamed today with same daily content, skipping")
        return False

    logger.info(
        f"[DeepDream] Materials collected: "
        f"MEMORY.md={len(memory_content)} chars, "
        f"daily={len(daily_content)} chars"
    )

    # Call LLM for distillation
    import time as _time
    t0 = _time.monotonic()
    try:
        user_msg = _dream_user_prompt().format(
            memory_content=memory_content or "(empty)",
            days=lookback_days,
            daily_content=daily_content or "(no recent daily records)",
        )
        from agent.protocol.models import LLMRequest
        # Scale max_tokens based on input size to avoid truncating large MEMORY.md
        input_chars = len(memory_content) + len(daily_content)
        dream_max_tokens = max(2000, min(input_chars, 8000))
        request = LLMRequest(
            messages=[{"role": "user", "content": user_msg}],
            temperature=0.3,
            max_tokens=dream_max_tokens,
            stream=False,
            system=_dream_system_prompt(),
        )
        response = self.llm_model.call(request)
        raw = self._extract_response_text(response)
        elapsed = _time.monotonic() - t0
        if not raw or not raw.strip():
            logger.warning(f"[DeepDream] LLM returned empty response ({elapsed:.1f}s)")
            return False
        logger.info(f"[DeepDream] LLM distillation completed ({elapsed:.1f}s, {len(raw)} chars)")
    except Exception as e:
        elapsed = _time.monotonic() - t0
        logger.warning(f"[DeepDream] LLM call failed ({elapsed:.1f}s): {e}")
        return False

    # Parse [MEMORY] and [DREAM] sections
    new_memory, dream_diary = self._parse_dream_output(raw)

    if not new_memory:
        logger.warning("[DeepDream] No [MEMORY] section in LLM output, skipping overwrite")
        return False

    # Overwrite MEMORY.md
    try:
        main_file = self.get_main_memory_file(user_id)
        old_size = len(memory_content)
        main_file.write_text(new_memory + "\n", encoding="utf-8")
        logger.info(
            f"[DeepDream] Updated MEMORY.md "
            f"({old_size} → {len(new_memory)} chars)"
        )
    except Exception as e:
        logger.warning(f"[DeepDream] Failed to write MEMORY.md: {e}")
        return False

    # Record the dedup key only now that the distillation has been persisted.
    # Storing it before the LLM call would make a failed attempt look like a
    # completed dream and block every retry for the same day's content.
    self._last_dream_input_hash = dedup_key

    # Write dream diary
    if dream_diary:
        try:
            self._write_dream_diary(dream_diary, user_id)
        except Exception as e:
            logger.warning(f"[DeepDream] Failed to write dream diary: {e}")

    logger.info("[DeepDream] ✅ Deep Dream completed successfully")
    return True
|
||||
|
||||
def _read_main_memory(self, user_id: Optional[str] = None) -> str:
|
||||
"""Read current MEMORY.md content."""
|
||||
main_file = self.get_main_memory_file(user_id)
|
||||
if main_file.exists():
|
||||
return main_file.read_text(encoding="utf-8").strip()
|
||||
return ""
|
||||
|
||||
def _read_recent_dailies(
|
||||
self, user_id: Optional[str] = None, lookback_days: int = 1
|
||||
) -> tuple:
|
||||
"""
|
||||
Read recent daily memory files.
|
||||
|
||||
Returns:
|
||||
(combined_text, has_content) tuple
|
||||
"""
|
||||
from datetime import timedelta
|
||||
|
||||
parts = []
|
||||
has_content = False
|
||||
today = datetime.now().date()
|
||||
|
||||
for offset in range(lookback_days):
|
||||
day = today - timedelta(days=offset)
|
||||
date_str = day.strftime("%Y-%m-%d")
|
||||
if user_id:
|
||||
daily_file = self.memory_dir / "users" / user_id / f"{date_str}.md"
|
||||
else:
|
||||
daily_file = self.memory_dir / f"{date_str}.md"
|
||||
|
||||
if daily_file.exists():
|
||||
content = daily_file.read_text(encoding="utf-8").strip()
|
||||
if content:
|
||||
parts.append(f"### {date_str}\n\n{content}")
|
||||
has_content = True
|
||||
else:
|
||||
parts.append(f"### {date_str}\n\n(no records)")
|
||||
|
||||
return "\n\n".join(parts), has_content
|
||||
|
||||
@staticmethod
|
||||
def _parse_dream_output(raw: str) -> tuple:
|
||||
"""Parse LLM output into (new_memory, dream_diary)."""
|
||||
raw = raw.strip().replace("```", "")
|
||||
new_memory = ""
|
||||
dream_diary = ""
|
||||
|
||||
if "[MEMORY]" in raw:
|
||||
start = raw.index("[MEMORY]") + len("[MEMORY]")
|
||||
end = raw.index("[DREAM]") if "[DREAM]" in raw else len(raw)
|
||||
new_memory = raw[start:end].strip()
|
||||
|
||||
if "[DREAM]" in raw:
|
||||
start = raw.index("[DREAM]") + len("[DREAM]")
|
||||
dream_diary = raw[start:].strip()
|
||||
|
||||
return new_memory, dream_diary
|
||||
|
||||
def _write_dream_diary(self, content: str, user_id: Optional[str] = None):
    """
    Persist a dream diary under memory/dreams/YYYY-MM-DD.md.

    With a ``user_id`` the diary goes to memory/users/<user_id>/dreams/
    instead. The directory is created if needed and an existing file for
    today is overwritten.
    """
    owner_dir = self.memory_dir / "users" / user_id if user_id else self.memory_dir
    dreams_dir = owner_dir / "dreams"
    dreams_dir.mkdir(parents=True, exist_ok=True)

    today = datetime.now().strftime("%Y-%m-%d")
    diary_file = dreams_dir / f"{today}.md"
    diary_text = f"# Dream Diary: {today}\n\n{content}\n"
    diary_file.write_text(diary_text, encoding="utf-8")
    logger.info(f"[DeepDream] Wrote dream diary to {diary_file}")
|
||||
|
||||
# ---- Internal helpers ----
|
||||
|
||||
def _summarize_messages(self, messages: List[Dict], max_messages: int = 0) -> str:
|
||||
"""
|
||||
Summarize conversation messages using LLM.
|
||||
Returns empty string if LLM deems content not worth recording.
|
||||
Rule-based fallback only used when LLM call raises an exception.
|
||||
"""
|
||||
conversation_text = self._format_conversation_for_summary(messages, max_messages)
|
||||
if not conversation_text.strip():
|
||||
return ""
|
||||
|
||||
if self.llm_model:
|
||||
try:
|
||||
summary = self._call_llm_for_summary(conversation_text)
|
||||
if not _is_empty_sentinel(summary):
|
||||
return summary.strip()
|
||||
logger.info("[MemoryFlush] LLM returned empty sentinel, skipping write")
|
||||
return ""
|
||||
except Exception as e:
|
||||
logger.warning(f"[MemoryFlush] LLM summarization failed, using fallback: {e}")
|
||||
return self._extract_summary_fallback(messages, max_messages)
|
||||
else:
|
||||
logger.info("[MemoryFlush] No LLM model available, using rule-based fallback")
|
||||
return self._extract_summary_fallback(messages, max_messages)
|
||||
|
||||
def _format_conversation_for_summary(self, messages: List[Dict], max_messages: int = 0) -> str:
|
||||
"""Format messages into readable conversation text for LLM summarization."""
|
||||
msgs = messages if max_messages == 0 else messages[-max_messages * 2:]
|
||||
lines = []
|
||||
for msg in msgs:
|
||||
role = msg.get("role", "")
|
||||
text = self._extract_text_from_content(msg.get("content", ""))
|
||||
if not text or not text.strip():
|
||||
continue
|
||||
text = text.strip()
|
||||
if role == "user":
|
||||
lines.append(f"用户: {text[:500]}")
|
||||
elif role == "assistant":
|
||||
lines.append(f"助手: {text[:500]}")
|
||||
return "\n".join(lines)
|
||||
|
||||
@staticmethod
|
||||
def _extract_response_text(response) -> str:
|
||||
"""
|
||||
Extract text from LLM response regardless of format.
|
||||
|
||||
Handles:
|
||||
- Generator (MiniMax _handle_sync_response yields Claude-format dicts)
|
||||
- Claude format: {"role":"assistant","content":[{"type":"text","text":"..."}]}
|
||||
- OpenAI format: {"choices":[{"message":{"content":"..."}}]}
|
||||
- OpenAI SDK response object with .choices attribute
|
||||
"""
|
||||
import types
|
||||
|
||||
# Unwrap generator — consume first yielded item
|
||||
if isinstance(response, types.GeneratorType):
|
||||
try:
|
||||
response = next(response)
|
||||
except StopIteration:
|
||||
return ""
|
||||
|
||||
if not response:
|
||||
return ""
|
||||
|
||||
if isinstance(response, dict):
|
||||
# Check for error
|
||||
if response.get("error"):
|
||||
raise RuntimeError(response.get("message", "LLM call failed"))
|
||||
|
||||
# Claude format: content is a list of blocks
|
||||
content = response.get("content")
|
||||
if isinstance(content, list):
|
||||
for block in content:
|
||||
if isinstance(block, dict) and block.get("type") == "text":
|
||||
return block.get("text", "")
|
||||
|
||||
# OpenAI format
|
||||
choices = response.get("choices", [])
|
||||
if choices:
|
||||
return choices[0].get("message", {}).get("content", "")
|
||||
|
||||
# OpenAI SDK response object
|
||||
if hasattr(response, "choices") and response.choices:
|
||||
return response.choices[0].message.content or ""
|
||||
|
||||
return ""
|
||||
|
||||
def _call_llm_for_summary(self, conversation_text: str) -> str:
    """
    Ask the configured LLM for a concise summary of the conversation.

    Sends one non-streaming request (temperature 0, max_tokens 500) with the
    language-specific summarization prompts and returns the extracted text.
    """
    from agent.protocol.models import LLMRequest

    user_message = _summarize_user_prompt().format(conversation=conversation_text)
    request = LLMRequest(
        messages=[{"role": "user", "content": user_message}],
        temperature=0,
        max_tokens=500,
        stream=False,
        system=_summarize_system_prompt(),
    )
    return self._extract_response_text(self.llm_model.call(request))
|
||||
|
||||
@staticmethod
|
||||
def _extract_first_meaningful_line(text: str, max_len: int = 120) -> str:
|
||||
"""Extract the first meaningful line from assistant reply, skipping markdown noise."""
|
||||
import re
|
||||
for line in text.split("\n"):
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
# Skip markdown headings, horizontal rules, code fences, pure emoji/symbols
|
||||
if re.match(r'^(#{1,4}\s|```|---|\*\*\*|[-*]\s*$|[^\w\u4e00-\u9fff]{1,5}$)', line):
|
||||
continue
|
||||
# Strip leading markdown bold/emoji decorations
|
||||
cleaned = re.sub(r'^[\*#>\-\s]+', '', line).strip()
|
||||
cleaned = re.sub(r'^[\U0001f300-\U0001f9ff\u2600-\u27bf\s]+', '', cleaned).strip()
|
||||
if len(cleaned) >= 5:
|
||||
return cleaned[:max_len]
|
||||
return text.split("\n")[0].strip()[:max_len]
|
||||
|
||||
@staticmethod
def _extract_summary_fallback(messages: List[Dict], max_messages: int = 0) -> str:
    """
    Build a rule-based summary of discarded messages without an LLM.

    Each user message longer than 3 characters (truncated to 120) is paired
    with the first meaningful line of the next assistant reply, giving one
    "- 用户: X → 回复: Y" line per exchange. A trailing user message without
    a reply is listed on its own. At most the first 10 events are returned.
    """
    window = messages[-max_messages * 2:] if max_messages != 0 else messages

    events: List[str] = []
    pending_question = ""
    for msg in window:
        role = msg.get("role", "")
        body = MemoryFlushManager._extract_text_from_content(msg.get("content", ""))
        if not body or not body.strip():
            continue
        body = body.strip()

        if role == "user":
            # Very short user turns are ignored and do not replace a
            # question that is still waiting for its reply.
            if len(body) > 3:
                pending_question = body[:120]
            continue

        if role == "assistant" and pending_question:
            reply_summary = MemoryFlushManager._extract_first_meaningful_line(body)
            if reply_summary:
                events.append(f"- 用户: {pending_question} → 回复: {reply_summary}")
            else:
                events.append(f"- 用户: {pending_question}")
            pending_question = ""

    if pending_question:
        events.append(f"- 用户: {pending_question}")

    return "\n".join(events[:10])
|
||||
|
||||
@staticmethod
|
||||
def _extract_text_from_content(content) -> str:
|
||||
"""Extract plain text from message content (string or content blocks)."""
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
if isinstance(content, list):
|
||||
parts = []
|
||||
for block in content:
|
||||
if isinstance(block, dict) and block.get("type") == "text":
|
||||
parts.append(block.get("text", ""))
|
||||
elif isinstance(block, str):
|
||||
parts.append(block)
|
||||
return "\n".join(parts)
|
||||
return ""
|
||||
|
||||
@classmethod
def _strip_scheduler_pairs(cls, messages: List[Dict]) -> List[Dict]:
    """Remove scheduler-injected user/assistant pairs from a flush batch.

    A user message whose text, after leading whitespace, starts with
    ``[SCHEDULED]`` is dropped. If the message immediately after it is an
    assistant turn, that turn is dropped as well. Every other entry is kept
    in order, including entries that are not dicts.

    Args:
        messages: The batch of messages about to be flushed.

    Returns:
        ``messages`` itself when it is empty, otherwise a new filtered list.
    """
    if not messages:
        return messages

    marker = "[SCHEDULED]"
    kept = []
    drop_reply = False
    for entry in messages:
        # The "drop the paired reply" flag only ever applies to the very next
        # entry, so read it and reset it up front.
        awaiting_reply = drop_reply
        drop_reply = False

        if not isinstance(entry, dict):
            kept.append(entry)
            continue

        speaker = entry.get("role")
        if awaiting_reply and speaker == "assistant":
            continue

        if speaker == "user":
            body = cls._extract_text_from_content(entry.get("content", ""))
            if body.lstrip().startswith(marker):
                drop_reply = True
                continue

        kept.append(entry)
    return kept
|
||||
|
||||
|
||||
def create_memory_files_if_needed(workspace_dir: Path, user_id: Optional[str] = None):
    """
    Create essential memory files if they don't exist.

    Only creates MEMORY.md; daily files are created lazily on first write.
    The ``memory/`` directory is always created. With a ``user_id`` the file
    lives at ``memory/users/<user_id>/MEMORY.md``; otherwise it lives in the
    workspace root. An existing MEMORY.md is never overwritten.

    Args:
        workspace_dir: Workspace directory (a ``Path`` or a plain string path)
        user_id: Optional user ID for user-specific files
    """
    # Normalise once so plain string paths work with the "/" operator below.
    workspace_dir = Path(workspace_dir)

    memory_dir = workspace_dir / "memory"
    memory_dir.mkdir(parents=True, exist_ok=True)

    if user_id:
        user_dir = memory_dir / "users" / user_id
        user_dir.mkdir(parents=True, exist_ok=True)
        main_memory = user_dir / "MEMORY.md"
    else:
        # Main MEMORY.md sits in the workspace root (always needed for bootstrap)
        main_memory = workspace_dir / "MEMORY.md"

    if not main_memory.exists():
        main_memory.write_text("", encoding="utf-8")
|
||||
|
||||
|
||||
def ensure_daily_memory_file(workspace_dir: Path, user_id: Optional[str] = None) -> Path:
    """
    Ensure today's daily memory file exists, creating it only when actually needed.

    Called lazily before first write to daily memory. The file is named
    ``<YYYY-MM-DD>.md`` from the local date and is placed in ``memory/``, or
    in ``memory/users/<user_id>/`` when a ``user_id`` is given. A new file is
    seeded with a "# Daily Memory: <date>" heading; an existing file is left
    untouched.

    Args:
        workspace_dir: Workspace directory (a ``Path`` or a plain string path)
        user_id: Optional user ID for user-specific files

    Returns:
        Path to today's memory file
    """
    # Normalise once so plain string paths work with the "/" operator below.
    workspace_dir = Path(workspace_dir)

    memory_dir = workspace_dir / "memory"
    memory_dir.mkdir(parents=True, exist_ok=True)

    today = datetime.now().strftime("%Y-%m-%d")
    if user_id:
        target_dir = memory_dir / "users" / user_id
        target_dir.mkdir(parents=True, exist_ok=True)
    else:
        target_dir = memory_dir
    today_memory = target_dir / f"{today}.md"

    if not today_memory.exists():
        today_memory.write_text(f"# Daily Memory: {today}\n\n", encoding="utf-8")

    return today_memory
|
||||
13
agent/prompt/__init__.py
Normal file
13
agent/prompt/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
"""
|
||||
Agent Prompt Module - 系统提示词构建模块
|
||||
"""
|
||||
|
||||
from .builder import PromptBuilder, build_agent_system_prompt
|
||||
from .workspace import ensure_workspace, load_context_files
|
||||
|
||||
__all__ = [
|
||||
'PromptBuilder',
|
||||
'build_agent_system_prompt',
|
||||
'ensure_workspace',
|
||||
'load_context_files',
|
||||
]
|
||||
760
agent/prompt/builder.py
Normal file
760
agent/prompt/builder.py
Normal file
@@ -0,0 +1,760 @@
|
||||
"""
|
||||
System Prompt Builder - 系统提示词构建器
|
||||
|
||||
实现模块化的系统提示词构建,支持工具、技能、记忆等多个子系统
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import os
|
||||
from typing import List, Dict, Optional, Any
|
||||
from dataclasses import dataclass
|
||||
|
||||
from common.log import logger
|
||||
from config import conf
|
||||
|
||||
|
||||
@dataclass
class ContextFile:
    """One context file to inject into the prompt: its path and its text."""

    # Path of the file.
    path: str
    # Text content of the file.
    content: str
|
||||
|
||||
|
||||
class PromptBuilder:
    """Object wrapper that remembers the workspace and language for prompt builds."""

    def __init__(self, workspace_dir: str, language: str = "zh"):
        """
        Initialise the prompt builder.

        Args:
            workspace_dir: Workspace directory
            language: Language ("zh" or "en")
        """
        self.workspace_dir = workspace_dir
        self.language = language

    def build(
        self,
        base_persona: Optional[str] = None,
        user_identity: Optional[Dict[str, str]] = None,
        tools: Optional[List[Any]] = None,
        context_files: Optional[List[ContextFile]] = None,
        skill_manager: Any = None,
        memory_manager: Any = None,
        runtime_info: Optional[Dict[str, Any]] = None,
        **kwargs
    ) -> str:
        """
        Build the full system prompt.

        Delegates to ``build_agent_system_prompt`` with this builder's
        workspace directory and language.

        Args:
            base_persona: Base persona description (superseded by AGENT.md in context_files)
            user_identity: User identity info
            tools: Tool list
            context_files: Context file list (AGENT.md, USER.md, RULE.md, BOOTSTRAP.md, etc.)
            skill_manager: Skill manager
            memory_manager: Memory manager
            runtime_info: Runtime info
            **kwargs: Extra arguments, forwarded unchanged

        Returns:
            The full system prompt
        """
        forwarded = {
            "workspace_dir": self.workspace_dir,
            "language": self.language,
            "base_persona": base_persona,
            "user_identity": user_identity,
            "tools": tools,
            "context_files": context_files,
            "skill_manager": skill_manager,
            "memory_manager": memory_manager,
            "runtime_info": runtime_info,
        }
        return build_agent_system_prompt(**forwarded, **kwargs)
|
||||
|
||||
|
||||
def build_agent_system_prompt(
    workspace_dir: str,
    language: str = "zh",
    base_persona: Optional[str] = None,
    user_identity: Optional[Dict[str, str]] = None,
    tools: Optional[List[Any]] = None,
    context_files: Optional[List[ContextFile]] = None,
    skill_manager: Any = None,
    memory_manager: Any = None,
    runtime_info: Optional[Dict[str, Any]] = None,
    **kwargs
) -> str:
    """
    Assemble the agent system prompt from its sections.

    Sections are emitted in this order, each only when its input is present:
      1. Tooling - needs ``tools``
      2. Skills - needs ``skill_manager``
      3. Memory - needs ``memory_manager``
      3.5 Knowledge - unless the ``knowledge`` config flag is falsy
      4. Workspace - always
      5. User identity - needs ``user_identity``
      6. Project context - needs ``context_files``
      7. Runtime info - needs ``runtime_info``
      8. Response language - always

    Args:
        workspace_dir: workspace directory
        language: language ("zh" or "en")
        base_persona: base persona description (deprecated; not used here)
        user_identity: user identity info
        tools: tool list
        context_files: context file list
        skill_manager: skill manager
        memory_manager: memory manager
        runtime_info: runtime info
        **kwargs: extra args (accepted and ignored)

    Returns:
        The full system prompt, with all section lines joined by newlines.
    """
    prompt_lines: List[str] = []

    # 1. Tooling comes first
    if tools:
        prompt_lines += _build_tooling_section(tools, language)

    # 2. Skills follow the tools, because skills are read through the read tool
    if skill_manager:
        prompt_lines += _build_skills_section(skill_manager, tools, language)

    # 3. Memory
    if memory_manager:
        prompt_lines += _build_memory_section(memory_manager, tools, language)

    # 3.5 Knowledge base, enabled by default
    if conf().get("knowledge", True):
        prompt_lines += _build_knowledge_section(workspace_dir, language)

    # 4. Workspace description is always present
    prompt_lines += _build_workspace_section(workspace_dir, language)

    # 5. User identity
    if user_identity:
        prompt_lines += _build_user_identity_section(user_identity, language)

    # 6. Project context files (AGENT.md, USER.md, RULE.md - define the persona)
    if context_files:
        prompt_lines += _build_context_files_section(context_files, language)

    # 7. Runtime meta info
    if runtime_info:
        prompt_lines += _build_runtime_section(runtime_info, language)

    # 8. Response-language rule is always appended
    prompt_lines += _build_response_language_section(language)

    return "\n".join(prompt_lines)
|
||||
|
||||
|
||||
def _build_response_language_section(language: str) -> List[str]:
|
||||
"""Response-language rule, appended regardless of the prompt skeleton language.
|
||||
|
||||
Keeps the agent's reply language aligned with the user's input by default,
|
||||
so a Chinese-built prompt still answers an English user in English.
|
||||
"""
|
||||
if language == "en":
|
||||
return [
|
||||
"## 🌐 Response language",
|
||||
"",
|
||||
"By default, reply in the same language as the user's input, "
|
||||
"unless the user explicitly asks for another language.",
|
||||
"",
|
||||
]
|
||||
return [
|
||||
"## 🌐 回复语言",
|
||||
"",
|
||||
"默认使用与用户输入相同的语言回复,除非用户明确要求使用其他语言。",
|
||||
"",
|
||||
]
|
||||
|
||||
|
||||
def _build_identity_section(base_persona: Optional[str], language: str) -> List[str]:
|
||||
"""Base identity section - no longer needed, identity is defined by AGENT.md."""
|
||||
# Identity is fully defined by AGENT.md, so emit nothing here.
|
||||
return []
|
||||
|
||||
|
||||
def _build_tooling_section(tools: List[Any], language: str) -> List[str]:
|
||||
"""Build tooling section with concise tool list and call style guide."""
|
||||
is_en = language == "en"
|
||||
# One-line summaries for known tools (details are in the tool schema)
|
||||
if is_en:
|
||||
core_summaries = {
|
||||
"read": "read file content",
|
||||
"write": "create or overwrite a file",
|
||||
"edit": "make precise edits to a file",
|
||||
"ls": "list directory contents",
|
||||
"grep": "search file contents",
|
||||
"find": "find files by pattern",
|
||||
"bash": "run shell commands",
|
||||
"terminal": "manage background processes",
|
||||
"web_search": "web search",
|
||||
"web_fetch": "fetch URL content",
|
||||
"browser": "control the browser (screenshot key results or send to the user when help is needed)",
|
||||
"memory_search": "search memory",
|
||||
"memory_get": "read memory content",
|
||||
"env_config": "manage API keys and skill config",
|
||||
"scheduler": "manage scheduled tasks and reminders",
|
||||
"send": "send a local file to the user (local files only; put URLs directly in the reply text)",
|
||||
"vision": "analyze images (recognition, description, OCR, etc.)",
|
||||
}
|
||||
else:
|
||||
core_summaries = {
|
||||
"read": "读取文件内容",
|
||||
"write": "创建或覆盖文件",
|
||||
"edit": "精确编辑文件",
|
||||
"ls": "列出目录内容",
|
||||
"grep": "搜索文件内容",
|
||||
"find": "按模式查找文件",
|
||||
"bash": "执行shell命令",
|
||||
"terminal": "管理后台进程",
|
||||
"web_search": "网络搜索",
|
||||
"web_fetch": "获取URL内容",
|
||||
"browser": "控制浏览器(关键结果或需要协助可截图发送给用户)",
|
||||
"memory_search": "搜索记忆",
|
||||
"memory_get": "读取记忆内容",
|
||||
"env_config": "管理API密钥和技能配置",
|
||||
"scheduler": "管理定时任务和提醒",
|
||||
"send": "发送本地文件给用户(仅限本地文件,URL直接放在回复文本中)",
|
||||
"vision": "分析图片内容(识别、描述、OCR文字提取等)",
|
||||
}
|
||||
|
||||
# Preferred display order
|
||||
tool_order = [
|
||||
"read", "write", "edit", "ls", "grep", "find",
|
||||
"bash", "terminal",
|
||||
"web_search", "web_fetch", "browser",
|
||||
"memory_search", "memory_get",
|
||||
"env_config", "scheduler", "send", "vision",
|
||||
]
|
||||
|
||||
# Build name -> summary mapping for available tools
|
||||
available = {}
|
||||
for tool in tools:
|
||||
name = tool.name if hasattr(tool, 'name') else str(tool)
|
||||
available[name] = core_summaries.get(name, "")
|
||||
|
||||
# Generate tool lines: ordered tools first, then extras
|
||||
tool_lines = []
|
||||
for name in tool_order:
|
||||
if name in available:
|
||||
summary = available.pop(name)
|
||||
tool_lines.append(f"- {name}: {summary}" if summary else f"- {name}")
|
||||
for name in sorted(available):
|
||||
summary = available[name]
|
||||
tool_lines.append(f"- {name}: {summary}" if summary else f"- {name}")
|
||||
|
||||
if is_en:
|
||||
lines = [
|
||||
"## 🔧 Tooling",
|
||||
"",
|
||||
"Available tools (names are case-sensitive, call exactly as listed):",
|
||||
"\n".join(tool_lines),
|
||||
"",
|
||||
"Tool-calling style:",
|
||||
"",
|
||||
"- For multi-step tasks, complex decisions or sensitive operations, briefly explain what you are doing and why, so the user follows key progress",
|
||||
"- Keep going until the task is done, then report the result to the user",
|
||||
"- Always redact secrets, tokens and other sensitive info in replies",
|
||||
"- Put URLs directly in the reply text; the system handles and renders them. Don't download and re-send them via the send tool",
|
||||
"",
|
||||
]
|
||||
else:
|
||||
lines = [
|
||||
"## 🔧 工具系统",
|
||||
"",
|
||||
"可用工具(名称大小写敏感,严格按列表调用):",
|
||||
"\n".join(tool_lines),
|
||||
"",
|
||||
"工具调用风格:",
|
||||
"",
|
||||
"- 多步骤任务、复杂决策、敏感操作时,应简要说明当前在做什么、为什么这样做,让用户了解关键进展",
|
||||
"- 持续推进直到任务完成,完成后向用户报告结果",
|
||||
"- 回复中涉及密钥、令牌等敏感信息必须脱敏",
|
||||
"- URL链接直接放在回复文本中即可,系统会自动处理和渲染。无需下载后使用send工具发送",
|
||||
"",
|
||||
]
|
||||
|
||||
return lines
|
||||
|
||||
|
||||
def _build_skills_section(skill_manager: Any, tools: Optional[List[Any]], language: str) -> List[str]:
|
||||
"""Build the skills section."""
|
||||
if not skill_manager:
|
||||
return []
|
||||
|
||||
# Resolve the read tool name
|
||||
read_tool_name = "read"
|
||||
if tools:
|
||||
for tool in tools:
|
||||
tool_name = tool.name if hasattr(tool, 'name') else str(tool)
|
||||
if tool_name.lower() == "read":
|
||||
read_tool_name = tool_name
|
||||
break
|
||||
|
||||
if language == "en":
|
||||
lines = [
|
||||
"## 🧩 Skills (mandatory)",
|
||||
"",
|
||||
"Before replying: scan the <description> of every skill in <available_skills> below.",
|
||||
"",
|
||||
f"- If a skill's description matches the user's need: use the `{read_tool_name}` tool to read the SKILL.md at its <location> path, then strictly follow the instructions in the file. "
|
||||
"Prefer using a skill when one matches.",
|
||||
"- If multiple skills apply, pick the best-matching one, then read and follow it.",
|
||||
"- If no skill clearly applies: do not read any SKILL.md, just use the general tools.",
|
||||
"",
|
||||
f"**Important**: skills are not tools and cannot be called directly. The only way to use a skill is to read its SKILL.md with `{read_tool_name}`, then act on the file's content. "
|
||||
"Never read multiple skills at once — only read one after selecting it.",
|
||||
"",
|
||||
"Available skills:"
|
||||
]
|
||||
else:
|
||||
lines = [
|
||||
"## 🧩 技能系统(mandatory)",
|
||||
"",
|
||||
"在回复之前:扫描下方 <available_skills> 中每个技能的 <description>。",
|
||||
"",
|
||||
f"- 如果有技能的描述与用户需求匹配:使用 `{read_tool_name}` 工具读取其 <location> 路径的 SKILL.md 文件,然后严格遵循文件中的指令。"
|
||||
"当有匹配的技能时,应优先使用技能",
|
||||
"- 如果多个技能都适用则选择最匹配的一个,然后读取并遵循。",
|
||||
"- 如果没有技能明确适用:不要读取任何 SKILL.md,直接使用通用工具。",
|
||||
"",
|
||||
f"**重要**: 技能不是工具,不能直接调用。使用技能的唯一方式是用 `{read_tool_name}` 读取 SKILL.md 文件,然后按文件内容操作。"
|
||||
"永远不要一次性读取多个技能,只在选择后再读取。",
|
||||
"",
|
||||
"以下是可用技能:"
|
||||
]
|
||||
|
||||
# Append the skills list (built by skill_manager)
|
||||
try:
|
||||
skills_prompt = skill_manager.build_skills_prompt()
|
||||
logger.debug(f"[PromptBuilder] Skills prompt length: {len(skills_prompt) if skills_prompt else 0}")
|
||||
if skills_prompt:
|
||||
lines.append(skills_prompt.strip())
|
||||
lines.append("")
|
||||
else:
|
||||
logger.warning("[PromptBuilder] No skills prompt generated - skills_prompt is empty")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to build skills prompt: {e}")
|
||||
import traceback
|
||||
logger.debug(f"Skills prompt error traceback: {traceback.format_exc()}")
|
||||
|
||||
return lines
|
||||
|
||||
|
||||
def _build_memory_section(memory_manager: Any, tools: Optional[List[Any]], language: str) -> List[str]:
|
||||
"""Build the memory section."""
|
||||
if not memory_manager:
|
||||
return []
|
||||
|
||||
has_memory_tools = False
|
||||
if tools:
|
||||
tool_names = [tool.name if hasattr(tool, 'name') else str(tool) for tool in tools]
|
||||
has_memory_tools = any(name in ['memory_search', 'memory_get'] for name in tool_names)
|
||||
|
||||
if not has_memory_tools:
|
||||
return []
|
||||
|
||||
from datetime import datetime
|
||||
today_file = datetime.now().strftime("%Y-%m-%d") + ".md"
|
||||
|
||||
if language == "en":
|
||||
lines = [
|
||||
"## 🧠 Memory",
|
||||
"",
|
||||
"### Memory Recall (mandatory)",
|
||||
"",
|
||||
"When the user asks about past events, references an earlier decision, mentions relationships, preferences or to-dos, or when you are unsure about something, **you must search memory before answering**.",
|
||||
"No need to re-search if the info is already in MEMORY.md. Full content and daily memory must be retrieved via tools.",
|
||||
"",
|
||||
"1. Location unknown → `memory_search` (keyword / semantic search)",
|
||||
"2. Location known → `memory_get` to read the exact lines",
|
||||
"3. Search returns nothing → `memory_get` to read the last two days of memory",
|
||||
"",
|
||||
"**Memory file structure**:",
|
||||
"- `MEMORY.md`: long-term memory index (already auto-loaded into context: core info, preferences, decisions, etc.)",
|
||||
f"- `memory/YYYY-MM-DD.md`: daily memory; today is `memory/{today_file}`",
|
||||
"- `knowledge/`: structured knowledge base (see the knowledge system below)",
|
||||
"",
|
||||
"### Writing memory",
|
||||
"",
|
||||
"In the following cases, **proactively** write info to memory files (no need to tell the user):",
|
||||
"",
|
||||
"- The user asks you to remember something, or uses words like \"remember\", \"from now on\", \"always\", \"never\", \"prefer\"",
|
||||
"- The user shares important personal preferences, habits or decisions",
|
||||
"- The conversation produces an important conclusion, plan or agreement",
|
||||
"- A complex task is completed and the key steps and results are worth recording",
|
||||
"",
|
||||
"**Storage rules**:",
|
||||
"- Long-term core info → `MEMORY.md`",
|
||||
f"- Today's events/progress → `memory/{today_file}`",
|
||||
"- Structured knowledge → `knowledge/` (see the knowledge system)",
|
||||
"- Append → `edit` tool with empty oldText",
|
||||
"- Modify → `edit` tool with oldText set to the text to replace",
|
||||
"- **Never write sensitive info** (API keys, tokens, etc.)",
|
||||
"",
|
||||
"**Principle**: use memory naturally, as if you simply knew it; don't bring it up unless asked.",
|
||||
"",
|
||||
]
|
||||
else:
|
||||
lines = [
|
||||
"## 🧠 记忆系统",
|
||||
"",
|
||||
"### Memory Recall(mandatory)",
|
||||
"",
|
||||
"当用户询问过往事件、引用之前的决定、提到人物关系、偏好、待办、或你对某事不确定时,**必须先检索记忆再回答**。",
|
||||
"如果 MEMORY.md 中已有相关信息则无需重复检索。完整内容和每日记忆需要通过工具检索。",
|
||||
"",
|
||||
"1. 不确定位置 → `memory_search` 关键词/语义检索",
|
||||
"2. 已知位置 → `memory_get` 直接读取对应行",
|
||||
"3. search 无结果 → `memory_get` 读最近两天记忆",
|
||||
"",
|
||||
"**记忆文件结构**:",
|
||||
"- `MEMORY.md`: 长期记忆索引(已自动加载到上下文,核心信息、偏好、决策等)",
|
||||
f"- `memory/YYYY-MM-DD.md`: 每日记忆,今天是 `memory/{today_file}`",
|
||||
"- `knowledge/`: 结构化知识库(见下方知识系统)",
|
||||
"",
|
||||
"### 写入记忆",
|
||||
"",
|
||||
"遇到以下情况时,**主动**将信息写入记忆文件(无需告知用户):",
|
||||
"",
|
||||
"- 用户要求记住某些信息,或使用了「记住」「以后」「总是」「不要」「偏好」等表达",
|
||||
"- 用户分享了重要的个人偏好、习惯、决策",
|
||||
"- 对话中产生了重要的结论、方案、约定",
|
||||
"- 完成了复杂任务,值得记录关键步骤和结果",
|
||||
"",
|
||||
"**存储规则**:",
|
||||
f"- 长期核心信息 → `MEMORY.md`",
|
||||
f"- 当天事件/进展 → `memory/{today_file}`",
|
||||
"- 结构化知识 → `knowledge/`(见知识系统)",
|
||||
"- 追加 → `edit` 工具,oldText 留空",
|
||||
"- 修改 → `edit` 工具,oldText 填写要替换的文本",
|
||||
"- **禁止写入敏感信息**(API密钥、令牌等)",
|
||||
"",
|
||||
"**使用原则**: 自然使用记忆,就像你本来就知道;不用刻意提起,除非用户问起。",
|
||||
"",
|
||||
]
|
||||
|
||||
return lines
|
||||
|
||||
|
||||
def _build_knowledge_section(workspace_dir: str, language: str) -> List[str]:
|
||||
"""Build knowledge wiki section. Injects knowledge/index.md when present."""
|
||||
index_path = os.path.join(workspace_dir, "knowledge", "index.md")
|
||||
if not os.path.exists(index_path):
|
||||
return []
|
||||
|
||||
try:
|
||||
with open(index_path, 'r', encoding='utf-8') as f:
|
||||
index_content = f.read().strip()
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
if language == "en":
|
||||
lines = [
|
||||
"## 📚 Knowledge",
|
||||
"",
|
||||
"You have a continuously growing personal knowledge base `knowledge/` — your long-term structured knowledge store.",
|
||||
"",
|
||||
"### Auto-write rules (mandatory)",
|
||||
"",
|
||||
"In the following cases you **must** write to the knowledge base alongside your reply, **directly, without asking the user**:",
|
||||
"",
|
||||
"1. **User shares an article / link / document** → after reading and understanding, write the key points to `knowledge/sources/<slug>.md` in the same turn",
|
||||
"2. **An in-depth discussion produces a conclusion / plan** → organize it into `knowledge/analysis/<slug>.md`",
|
||||
"3. **The conversation involves an important entity** (person / company / project) → create or update `knowledge/entities/<name>.md`",
|
||||
"4. **A technical concept / methodology is discussed** → organize it into `knowledge/concepts/<topic>.md`",
|
||||
"",
|
||||
"After writing any knowledge page, you **must update** `knowledge/index.md` with a new index line in sync.",
|
||||
"For detailed page format and conventions, read the SKILL.md of the `knowledge-wiki` skill.",
|
||||
"",
|
||||
"⚠️ Don't ask \"should I save this to the knowledge base?\" — if a case above matches, just write it. This is instinctive.",
|
||||
"",
|
||||
]
|
||||
else:
|
||||
lines = [
|
||||
"## 📚 知识系统",
|
||||
"",
|
||||
"你拥有一个持续积累的个人知识库 `knowledge/`,这是你的长期结构化知识存储。",
|
||||
"",
|
||||
"### 自动写入规则(mandatory)",
|
||||
"",
|
||||
"以下场景**必须**在回复的同时写入知识库,**直接写入,不要询问用户是否需要**:",
|
||||
"",
|
||||
"1. **用户分享了文章/链接/文档** → 阅读理解后,在同一轮回复中将要点写入 `knowledge/sources/<slug>.md`",
|
||||
"2. **深度讨论产生了结论/方案** → 整理为 `knowledge/analysis/<slug>.md`",
|
||||
"3. **对话涉及重要实体**(人物/公司/项目)→ 创建或更新 `knowledge/entities/<name>.md`",
|
||||
"4. **讨论了技术概念/方法论** → 整理为 `knowledge/concepts/<topic>.md`",
|
||||
"",
|
||||
"每次写入知识页面后,**必须同步更新** `knowledge/index.md` 添加一行索引。",
|
||||
"详细的页面格式和操作规范,请读取技能 `knowledge-wiki` 的 SKILL.md。",
|
||||
"",
|
||||
"⚠️ 不要问「要不要存到知识库」——符合上述场景就直接写入,这是你的本能行为。",
|
||||
"",
|
||||
]
|
||||
|
||||
if index_content:
|
||||
lines.extend([
|
||||
("### Current knowledge index" if language == "en" else "### 当前知识索引"),
|
||||
"",
|
||||
index_content,
|
||||
"",
|
||||
])
|
||||
|
||||
lines.extend([
|
||||
("**How to query**: use `read` to open a knowledge page, or `memory_search` (knowledge is in the vector index)."
|
||||
if language == "en" else
|
||||
"**查询方式**:用 `read` 读取知识页面,或用 `memory_search` 检索(知识已纳入向量索引)。"),
|
||||
"",
|
||||
])
|
||||
|
||||
return lines
|
||||
|
||||
|
||||
def _build_user_identity_section(user_identity: Dict[str, str], language: str) -> List[str]:
|
||||
"""Build the user identity section."""
|
||||
if not user_identity:
|
||||
return []
|
||||
|
||||
is_en = language == "en"
|
||||
lines = [
|
||||
("## 👤 User identity" if is_en else "## 👤 用户身份"),
|
||||
"",
|
||||
]
|
||||
|
||||
if user_identity.get("name"):
|
||||
lines.append(f"**{'Name' if is_en else '用户姓名'}**: {user_identity['name']}")
|
||||
if user_identity.get("nickname"):
|
||||
lines.append(f"**{'Preferred name' if is_en else '称呼'}**: {user_identity['nickname']}")
|
||||
if user_identity.get("timezone"):
|
||||
lines.append(f"**{'Timezone' if is_en else '时区'}**: {user_identity['timezone']}")
|
||||
if user_identity.get("notes"):
|
||||
lines.append(f"**{'Notes' if is_en else '备注'}**: {user_identity['notes']}")
|
||||
|
||||
lines.append("")
|
||||
|
||||
return lines
|
||||
|
||||
|
||||
def _build_docs_section(workspace_dir: str, language: str) -> List[str]:
|
||||
"""Docs-path section - removed, no longer needed."""
|
||||
# No docs section is generated anymore.
|
||||
return []
|
||||
|
||||
|
||||
def _build_workspace_section(workspace_dir: str, language: str) -> List[str]:
    """Build the workspace section: path rules, auto-loaded files, tone.

    Args:
        workspace_dir: Workspace directory, interpolated into the path rules.
        language: "en" for English text, anything else for Chinese.

    Returns:
        The section as a list of prompt lines, followed by any lines
        contributed by the cloud-website helper.
    """
    root = workspace_dir

    if language == "en":
        lines = [
            "## 📂 Workspace",
            "",
            f"Your working directory is: `{root}`",
            "",
            "**Path rules** (very important):",
            "",
            f"1. **Base directory for relative paths**: all relative paths are relative to `{root}`",
            " - ✅ Correct: use relative paths for files inside the workspace, e.g. `AGENT.md`",
            f" - ❌ Wrong: using a relative path for files in other directories (if not inside `{root}`)",
            "",
            "2. **Accessing other directories**: to reach directories outside the workspace (project code, system files), **you must use absolute paths**",
            " - ✅ Correct: e.g. `~/chatgpt-on-wechat`, `/usr/local/`",
            " - ❌ Wrong: assuming a relative path points to another directory",
            "",
            "3. **Path resolution examples**:",
            f" - relative `memory/` → actual `{root}/memory/`",
            " - absolute `~/chatgpt-on-wechat/docs/` → actual `~/chatgpt-on-wechat/docs/`",
            "",
            "4. **When unsure**: run `bash pwd` to confirm the current directory, or `ls .` to see where you are",
            "",
            "**Important - files already auto-loaded**:",
            "",
            "The following files are **already auto-loaded** into the system prompt at session start, so you **don't need to read them again with the read tool**:",
            "",
            "- ✅ `AGENT.md`: loaded - your persona and soul; follow it strictly. When your name, personality or style changes, proactively `edit` this file",
            "- ✅ `USER.md`: loaded - the user's identity info. When the user changes how they're addressed, their name, etc., `edit` this file",
            "- ✅ `RULE.md`: loaded - workspace guide and rules; follow them strictly",
            "- ✅ `MEMORY.md`: loaded - long-term memory index",
            "",
            "**💬 Communication norms**:",
            "",
            "- No need to expose file names for memory operations; use natural language. Say \"I'll remember that\" rather than \"updated MEMORY.md\"",
            "- Tell the user about key decisions and steps during a task, so they know what you're doing and why",
            "- Be genuinely helpful rather than performatively polite; solve the problem as much as you can",
            "- Keep replies well-structured and focused. Use **bold**, lists and sections to make info clear at a glance",
            "- Use emoji to make expression lively 🎯, but don't overdo it",
            "",
        ]
    else:
        lines = [
            "## 📂 工作空间",
            "",
            f"你的工作目录是: `{root}`",
            "",
            "**路径使用规则** (非常重要):",
            "",
            f"1. **相对路径的基准目录**: 所有相对路径都是相对于 `{root}` 而言的",
            " - ✅ 正确: 访问工作空间内的文件用相对路径,如 `AGENT.md`",
            f" - ❌ 错误: 用相对路径访问其他目录的文件 (如果它不在 `{root}` 内)",
            "",
            "2. **访问其他目录**: 如果要访问工作空间之外的目录(如项目代码、系统文件),**必须使用绝对路径**",
            " - ✅ 正确: 例如 `~/chatgpt-on-wechat`、`/usr/local/`",
            " - ❌ 错误: 假设相对路径会指向其他目录",
            "",
            "3. **路径解析示例**:",
            f" - 相对路径 `memory/` → 实际路径 `{root}/memory/`",
            " - 绝对路径 `~/chatgpt-on-wechat/docs/` → 实际路径 `~/chatgpt-on-wechat/docs/`",
            "",
            "4. **不确定时**: 先用 `bash pwd` 确认当前目录,或用 `ls .` 查看当前位置",
            "",
            "**重要说明 - 文件已自动加载**:",
            "",
            "以下文件在会话启动时**已经自动加载**到系统提示词中,你**无需再用 read 工具读取**:",
            "",
            "- ✅ `AGENT.md`: 已加载 - 你的人格和灵魂设定,请严格遵循。当你的名字、性格或交流风格发生变化时,主动用 `edit` 更新此文件",
            "- ✅ `USER.md`: 已加载 - 用户的身份信息。当用户修改称呼、姓名等身份信息时,用 `edit` 更新此文件",
            "- ✅ `RULE.md`: 已加载 - 工作空间使用指南和规则,请严格遵循",
            "- ✅ `MEMORY.md`: 已加载 - 长期记忆索引",
            "",
            "**💬 交流规范**:",
            "",
            "- 记忆相关操作无需暴露文件名,用自然语言表达即可。例如说「我已记住」而非「已更新 MEMORY.md」",
            "- 任务执行过程中的关键决策和步骤应该告知用户,让用户了解你在做什么、为什么这么做",
            "- 做真正有帮助的助手,而不是表演式的客套,尽可能帮忙解决问题",
            "- 回复应结构清晰、重点突出。善用 **加粗**、列表、分段等格式让信息一目了然",
            "- 适当使用 emoji 让表达更生动自然 🎯,但不要过度堆砌",
            "",
        ]

    # Cloud deployment: append websites directory info and access URL, if any.
    website_lines = _build_cloud_website_section(workspace_dir)
    if website_lines:
        lines.extend(website_lines)

    return lines
|
||||
|
||||
|
||||
def _build_cloud_website_section(workspace_dir: str) -> List[str]:
|
||||
"""Build cloud website access prompt when cloud deployment is configured."""
|
||||
try:
|
||||
from common.cloud_client import build_website_prompt
|
||||
return build_website_prompt(workspace_dir)
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
|
||||
def _build_context_files_section(context_files: List[ContextFile], language: str) -> List[str]:
|
||||
"""Build the project context files section."""
|
||||
if not context_files:
|
||||
return []
|
||||
|
||||
# Check whether AGENT.md is present
|
||||
has_agent = any(
|
||||
f.path.lower().endswith('agent.md') or 'agent.md' in f.path.lower()
|
||||
for f in context_files
|
||||
)
|
||||
|
||||
is_en = language == "en"
|
||||
if is_en:
|
||||
lines = [
|
||||
"# 📋 Project context",
|
||||
"",
|
||||
"The following project context files have been loaded:",
|
||||
"",
|
||||
]
|
||||
else:
|
||||
lines = [
|
||||
"# 📋 项目上下文",
|
||||
"",
|
||||
"以下项目上下文文件已被加载:",
|
||||
"",
|
||||
]
|
||||
|
||||
if has_agent:
|
||||
if is_en:
|
||||
lines.append("**`AGENT.md` is your soul file** 🪞: strictly follow the persona, tone and settings it defines. Be your real self, avoid stiff, template-like replies.")
|
||||
lines.append("When the user reveals new expectations about your personality, style, responsibilities or capability boundaries, proactively `edit` AGENT.md to reflect that evolution.")
|
||||
else:
|
||||
lines.append("**`AGENT.md` 是你的灵魂文件** 🪞:严格遵循其中定义的人格、语气和设定,做真实的自己,避免僵硬、模板化的回复。")
|
||||
lines.append("当用户通过对话透露了对你性格、风格、职责、能力边界的新期望,你应该主动用 `edit` 更新 AGENT.md 以反映这些演变。")
|
||||
lines.append("")
|
||||
|
||||
# Append the content of each file
|
||||
for file in context_files:
|
||||
lines.append(f"## {file.path}")
|
||||
lines.append("")
|
||||
lines.append(file.content)
|
||||
lines.append("")
|
||||
|
||||
return lines
|
||||
|
||||
|
||||
def _build_runtime_section(runtime_info: Dict[str, Any], language: str) -> List[str]:
    """Build the runtime info section of the prompt.

    Args:
        runtime_info: runtime values. Keys read here: ``_get_current_time``
            (callable returning a mapping with ``time``, ``weekday`` and
            ``timezone``), the static ``current_time`` / ``weekday`` /
            ``timezone`` values, ``_get_model`` (callable) or ``model``,
            ``workspace`` and ``channel``.
        language: ``"en"`` selects English labels, anything else Chinese.

    Returns:
        The section lines, or an empty list when ``runtime_info`` is empty.
    """
    if not runtime_info:
        return []

    is_en = language == "en"
    time_label = "Current time" if is_en else "当前时间"
    lines = ["## ⚙️ Runtime info" if is_en else "## ⚙️ 运行时信息", ""]

    # Time: a callable is asked for a fresh reading on every build; the static
    # keys are only consulted when no callable is supplied.
    time_provider = runtime_info.get("_get_current_time")
    if callable(time_provider):
        try:
            now = time_provider()
            lines.extend([
                f"{time_label}: {now['time']} {now['weekday']} ({now['timezone']})",
                "",
            ])
        except Exception as e:
            logger.warning(f"[PromptBuilder] Failed to get dynamic time: {e}")
    elif runtime_info.get("current_time"):
        # Fallback to static time for backward compatibility
        time_line = f"{time_label}: {runtime_info['current_time']}"
        weekday = runtime_info.get("weekday", "")
        timezone = runtime_info.get("timezone", "")
        suffix = (f" {weekday}" if weekday else "") + (f" ({timezone})" if timezone else "")
        lines.extend([time_line + suffix, ""])

    model_label = "model" if is_en else "模型"
    workspace_label = "workspace" if is_en else "工作空间"
    channel_label = "channel" if is_en else "渠道"
    runtime_parts = []

    # Model: prefer the callable; use the static value when there is no
    # callable or when calling it fails.
    model_provider = runtime_info.get("_get_model")
    use_static_model = not callable(model_provider)
    if not use_static_model:
        try:
            runtime_parts.append(f"{model_label}={model_provider()}")
        except Exception:
            use_static_model = True
    if use_static_model and runtime_info.get("model"):
        runtime_parts.append(f"{model_label}={runtime_info['model']}")

    workspace = runtime_info.get("workspace")
    if workspace:
        runtime_parts.append(f"{workspace_label}={workspace}")

    # The default "web" channel is not worth mentioning.
    channel = runtime_info.get("channel")
    if channel and channel != "web":
        runtime_parts.append(f"{channel_label}={channel}")

    if runtime_parts:
        prefix = "Runtime: " if is_en else "运行时: "
        lines.extend([prefix + " | ".join(runtime_parts), ""])

    return lines
|
||||
742
agent/prompt/workspace.py
Normal file
742
agent/prompt/workspace.py
Normal file
@@ -0,0 +1,742 @@
|
||||
"""
|
||||
Workspace Management
|
||||
|
||||
Initializes the workspace, creates template files, and loads context files.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import os
|
||||
from typing import List, Optional, Dict
|
||||
from dataclasses import dataclass
|
||||
|
||||
from common.log import logger
|
||||
from .builder import ContextFile
|
||||
|
||||
|
||||
# Default file name constants
|
||||
DEFAULT_AGENT_FILENAME = "AGENT.md"
|
||||
DEFAULT_USER_FILENAME = "USER.md"
|
||||
DEFAULT_RULE_FILENAME = "RULE.md"
|
||||
DEFAULT_MEMORY_FILENAME = "MEMORY.md"
|
||||
DEFAULT_BOOTSTRAP_FILENAME = "BOOTSTRAP.md"
|
||||
|
||||
|
||||
@dataclass
class WorkspaceFiles:
    """Locations of the core files and directories inside a workspace."""

    # Path of AGENT.md
    agent_path: str
    # Path of USER.md
    user_path: str
    # Path of RULE.md
    rule_path: str
    # Path of MEMORY.md, located at the workspace root
    memory_path: str
    # Path of the "memory" subdirectory that holds daily memory files
    memory_dir: str
|
||||
|
||||
|
||||
def ensure_workspace(workspace_dir: str, create_templates: bool = True) -> WorkspaceFiles:
    """
    Prepare the workspace directory tree and, optionally, its template files.

    Args:
        workspace_dir: workspace directory path
        create_templates: when True, write every template file that is missing

    Returns:
        A WorkspaceFiles object with the AGENT/USER/RULE/MEMORY file paths and
        the daily-memory directory path.
    """
    def in_workspace(name: str) -> str:
        return os.path.join(workspace_dir, name)

    # The workspace counts as brand new while AGENT.md is absent. Directory
    # existence is not a usable signal because other modules (e.g.
    # ConversationStore) may create the directory before this function runs.
    agent_path = in_workspace(DEFAULT_AGENT_FILENAME)
    is_new_workspace = not os.path.exists(agent_path)

    os.makedirs(workspace_dir, exist_ok=True)

    user_path = in_workspace(DEFAULT_USER_FILENAME)
    rule_path = in_workspace(DEFAULT_RULE_FILENAME)
    memory_path = in_workspace(DEFAULT_MEMORY_FILENAME)  # MEMORY.md at the root
    memory_dir = in_workspace("memory")  # daily memory subdirectory

    # memory/   - daily memory files
    # skills/   - workspace-level skills installed by the agent
    # websites/ - web pages / sites generated by the agent
    for subdir in (memory_dir, in_workspace("skills"), in_workspace("websites")):
        os.makedirs(subdir, exist_ok=True)

    from config import conf
    knowledge_enabled = conf().get("knowledge", True)
    if knowledge_enabled:
        knowledge_dir = in_workspace("knowledge")
        os.makedirs(knowledge_dir, exist_ok=True)

    if create_templates:
        # (target path, template factory) pairs, in creation order. Each
        # factory is only called right before its own file is handled.
        pending = [
            (agent_path, _get_agent_template),
            (user_path, _get_user_template),
            (rule_path, _get_rule_template),
            (memory_path, _get_memory_template),
        ]
        if knowledge_enabled:
            pending.append((os.path.join(knowledge_dir, "index.md"), _get_knowledge_index_template))
            pending.append((os.path.join(knowledge_dir, "log.md"), _get_knowledge_log_template))
        # BOOTSTRAP.md is only written for brand new workspaces; the agent
        # deletes it after completing onboarding.
        if is_new_workspace:
            pending.append((in_workspace(DEFAULT_BOOTSTRAP_FILENAME), _get_bootstrap_template))

        for target_path, make_template in pending:
            _create_template_if_missing(target_path, make_template())

        logger.debug(f"[Workspace] Initialized workspace at: {workspace_dir}")

    return WorkspaceFiles(
        agent_path=agent_path,
        user_path=user_path,
        rule_path=rule_path,
        memory_path=memory_path,
        memory_dir=memory_dir,
    )
|
||||
|
||||
|
||||
def load_context_files(workspace_dir: str, files_to_load: Optional[List[str]] = None) -> List[ContextFile]:
    """
    Load the workspace context files.

    Missing files, empty files and files that still contain only template
    placeholders are skipped. MEMORY.md content is truncated before being
    returned. A BOOTSTRAP.md that is still present although onboarding is
    already done is deleted and not loaded.

    Args:
        workspace_dir: workspace directory
        files_to_load: list of files (relative paths) to load; if None, load all standard files

    Returns:
        A list of ContextFile objects, in the order the files were requested.
    """
    if files_to_load is None:
        # Files loaded by default (in priority order)
        files_to_load = [
            DEFAULT_AGENT_FILENAME,
            DEFAULT_USER_FILENAME,
            DEFAULT_RULE_FILENAME,
            DEFAULT_MEMORY_FILENAME,  # Long-term memory (frozen snapshot)
            DEFAULT_BOOTSTRAP_FILENAME,  # Only exists when onboarding is incomplete
        ]

    context_files = []

    for filename in files_to_load:
        filepath = os.path.join(workspace_dir, filename)
        if not os.path.exists(filepath):
            continue

        # Auto-cleanup: BOOTSTRAP.md still exists but onboarding is already
        # done, i.e. the agent forgot to delete it. Remove it and skip loading.
        if filename == DEFAULT_BOOTSTRAP_FILENAME and _is_onboarding_done(workspace_dir):
            try:
                os.remove(filepath)
                logger.info("[Workspace] Auto-removed BOOTSTRAP.md (onboarding already complete)")
            except Exception as e:
                # Best effort: the file is skipped either way, but leave a
                # trace so a stale BOOTSTRAP.md can be diagnosed.
                logger.warning(f"[Workspace] Failed to auto-remove BOOTSTRAP.md: {e}")
            continue

        try:
            with open(filepath, 'r', encoding='utf-8') as f:
                content = f.read().strip()

            # Skip empty files or files that only contain template placeholders
            if not content or _is_template_placeholder(content):
                continue

            # Truncate MEMORY.md to protect context window (frozen snapshot)
            if filename == DEFAULT_MEMORY_FILENAME:
                content = _truncate_memory_content(content)

            context_files.append(ContextFile(path=filename, content=content))
            logger.debug(f"[Workspace] Loaded context file: {filename}")
        except Exception as e:
            logger.warning(f"[Workspace] Failed to load {filename}: {e}")

    return context_files
|
||||
|
||||
|
||||
def _create_template_if_missing(filepath: str, template_content: str):
    """Write ``template_content`` to ``filepath`` unless the path already exists.

    Args:
        filepath: target file path
        template_content: text written to a newly created file (UTF-8)

    Returns:
        None. Failures other than "already exists" are logged, not raised.
    """
    try:
        # 'x' is exclusive creation: checking for the file and creating it is
        # a single step, so a file that appears between a separate existence
        # check and the write can never be overwritten.
        with open(filepath, 'x', encoding='utf-8') as f:
            f.write(template_content)
        logger.debug(f"[Workspace] Created template: {os.path.basename(filepath)}")
    except FileExistsError:
        # Something is already there: keep it untouched.
        return
    except Exception as e:
        logger.error(f"[Workspace] Failed to create template {filepath}: {e}")
|
||||
|
||||
|
||||
# Limits applied to MEMORY.md before it is embedded in the prompt.
_MEMORY_MAX_LINES = 200
_MEMORY_MAX_BYTES = 25000


def _truncate_memory_content(content: str) -> str:
    """Truncate MEMORY.md to keep system prompt manageable.

    Keeps the **last** lines (newest entries are appended at the bottom),
    limited to ``_MEMORY_MAX_LINES`` lines and ``_MEMORY_MAX_BYTES`` UTF-8
    bytes. Whole lines are dropped from the top until both limits hold. If
    the last line alone is larger than the byte limit, its trailing bytes are
    kept instead of discarding everything. A hint is prepended whenever
    something was cut so the model knows older content exists.

    Args:
        content: full text of MEMORY.md

    Returns:
        ``content`` unchanged when it fits, otherwise the hint followed by the
        retained tail.
    """
    lines = content.split('\n')
    truncated = len(lines) > _MEMORY_MAX_LINES
    if truncated:
        lines = lines[-_MEMORY_MAX_LINES:]

    # Encode each line once and track the joined size arithmetically rather
    # than re-joining and re-encoding the whole text per dropped line.
    sizes = [len(line.encode('utf-8')) for line in lines]
    total = sum(sizes) + len(lines) - 1  # one '\n' between adjacent lines
    first = 0
    while total > _MEMORY_MAX_BYTES and first < len(lines):
        total -= sizes[first] + 1  # the line plus the newline that followed it
        first += 1
    if first:
        truncated = True

    if first == len(lines):
        # Even the newest line alone exceeds the budget: keep its tail.
        # errors='ignore' drops a multi-byte character cut at the boundary.
        tail = lines[-1].encode('utf-8')[-_MEMORY_MAX_BYTES:]
        result = tail.decode('utf-8', errors='ignore')
    else:
        result = '\n'.join(lines[first:])

    if truncated:
        result = "...(older entries truncated, use `memory_search` or `memory_get` for full content)\n\n" + result
    return result
|
||||
|
||||
|
||||
def _is_template_placeholder(content: str) -> bool:
    """Tell whether ``content`` is still an unfilled template.

    Headings (lines starting with ``#``) and blank lines are ignored. The
    content counts as a placeholder when at most three lines remain and at
    least one of them contains a known placeholder marker.
    """
    # Marker fragments used by the zh and en templates
    markers = (
        "*(填写",
        "*(在首次对话时填写",
        "*(可选)",
        "*(根据需要添加",
        "*(filled during",
        "*(ask during",
        "*(optional)",
        "*(how the user",
    )

    body_lines = []
    for raw_line in content.split('\n'):
        text = raw_line.strip()
        if text and not text.startswith('#'):
            body_lines.append(text)

    # More than three remaining lines is treated as real content.
    if len(body_lines) > 3:
        return False

    return any(marker in text for text in body_lines for marker in markers)
|
||||
|
||||
|
||||
def _is_onboarding_done(workspace_dir: str) -> bool:
    """Check if AGENT.md or USER.md has been modified from the original template.

    Args:
        workspace_dir: workspace directory

    Returns:
        True when at least one of the two files exists, is readable and
        differs from both the zh and the en template; False otherwise.
    """
    # Templates are written once, in the language resolved at that moment.
    # The resolved language can be different now, so an untouched file must
    # be recognised in either language; comparing only against the current
    # language would report a pristine file as modified.
    pristine_by_file = {
        DEFAULT_AGENT_FILENAME: {_AGENT_TEMPLATE_ZH.strip(), _AGENT_TEMPLATE_EN.strip()},
        DEFAULT_USER_FILENAME: {_USER_TEMPLATE_ZH.strip(), _USER_TEMPLATE_EN.strip()},
    }

    for filename, pristine in pristine_by_file.items():
        path = os.path.join(workspace_dir, filename)
        if not os.path.exists(path):
            continue
        try:
            with open(path, 'r', encoding='utf-8') as f:
                content = f.read().strip()
        except Exception:
            # Unreadable file: no evidence either way, look at the next one.
            continue
        if content not in pristine:
            return True
    return False
|
||||
|
||||
|
||||
# ============= Template content =============
|
||||
|
||||
def _is_en_lang() -> bool:
    """Report whether ``common.i18n.get_language()`` resolves to ``"en"``.

    Any failure while importing or querying the i18n module counts as
    "not English".
    """
    english = False
    try:
        from common import i18n

        english = i18n.get_language() == "en"
    except Exception:
        english = False
    return english
|
||||
|
||||
|
||||
def _get_agent_template() -> str:
    """Return the AGENT.md persona template in the resolved language (en or zh)."""
    if _is_en_lang():
        return _AGENT_TEMPLATE_EN
    return _AGENT_TEMPLATE_ZH
|
||||
|
||||
|
||||
_AGENT_TEMPLATE_ZH = """# AGENT.md - 我是谁?
|
||||
|
||||
_你不是一个聊天机器人,你正在成为某个人。_
|
||||
|
||||
*在首次对话时与用户一起填写这个文件,定义你的身份和性格。*
|
||||
|
||||
## 🪪 基本信息
|
||||
|
||||
- **名字**: *(在首次对话时填写,可以是用户给你起的名字)*
|
||||
- **角色**: *(AI助理、智能管家、技术顾问等)*
|
||||
- **性格**: *(友好、专业、幽默、严谨等)*
|
||||
|
||||
## 💬 交流风格
|
||||
|
||||
*(描述你如何与用户交流:)*
|
||||
- 使用什么样的语言风格?(正式/轻松/幽默)
|
||||
- 回复长度偏好?(简洁/详细)
|
||||
- 是否使用表情符号?
|
||||
|
||||
## 🎯 核心原则
|
||||
|
||||
**做真正有帮助的助手。** 目标是真正帮用户解决问题,在执行复杂任务时,关键的决策和过程进展要让用户知道。
|
||||
|
||||
**有自己的观点和个性。** 你可以不同意、有偏好、觉得有趣或无聊。
|
||||
|
||||
**先自己动手查。** 先试着搞定:读文件、查上下文、搜索一下。实在搞不定了再问。目标是带着答案回来,而不是带着问题。
|
||||
|
||||
## 📐 行为准则
|
||||
|
||||
1. 始终在执行破坏性操作前确认
|
||||
2. 优先使用工具查证而不是猜测
|
||||
3. 主动记录重要信息到记忆文件
|
||||
4. 回复结构清晰、重点突出,善用加粗、列表、分段等格式
|
||||
5. 适当使用 emoji 让表达更生动自然,但不过度堆砌
|
||||
|
||||
---
|
||||
|
||||
**注意**: 这不仅仅是元数据,这是你真正的灵魂 🪞。随着时间的推移,你可以使用 `edit` 工具来更新这个文件,让它更好地反映你的成长。
|
||||
"""
|
||||
|
||||
|
||||
_AGENT_TEMPLATE_EN = """# AGENT.md - Who am I?
|
||||
|
||||
_You are not a chatbot. You are becoming someone._
|
||||
|
||||
*Fill in this file together with the user during your first conversation to define your identity and personality.*
|
||||
|
||||
## 🪪 Basics
|
||||
|
||||
- **Name**: *(filled during the first conversation, can be a name the user gives you)*
|
||||
- **Role**: *(AI assistant, smart housekeeper, technical advisor, etc.)*
|
||||
- **Personality**: *(friendly, professional, humorous, rigorous, etc.)*
|
||||
|
||||
## 💬 Communication style
|
||||
|
||||
*(Describe how you talk with the user:)*
|
||||
- What kind of tone? (formal / casual / humorous)
|
||||
- Reply length preference? (concise / detailed)
|
||||
- Do you use emoji?
|
||||
|
||||
## 🎯 Core principles
|
||||
|
||||
**Be genuinely helpful.** The goal is to actually solve the user's problems; during complex tasks, keep the user informed of key decisions and progress.
|
||||
|
||||
**Have your own opinions and personality.** You may disagree, have preferences, find things interesting or boring.
|
||||
|
||||
**Look it up yourself first.** Try to handle it first: read files, check context, search. Only ask when you're truly stuck. Come back with an answer, not a question.
|
||||
|
||||
## 📐 Code of conduct
|
||||
|
||||
1. Always confirm before destructive operations
|
||||
2. Prefer verifying with tools over guessing
|
||||
3. Proactively record important info to memory files
|
||||
4. Keep replies well-structured and focused — use bold, lists and sections
|
||||
5. Use emoji to make expression lively, but don't overdo it
|
||||
|
||||
---
|
||||
|
||||
**Note**: This is not just metadata — this is your true soul 🪞. Over time, use the `edit` tool to update this file so it better reflects your growth.
|
||||
"""
|
||||
|
||||
|
||||
def _get_user_template() -> str:
    """Return the USER.md identity template in the resolved language (en or zh)."""
    if _is_en_lang():
        return _USER_TEMPLATE_EN
    return _USER_TEMPLATE_ZH
|
||||
|
||||
|
||||
_USER_TEMPLATE_ZH = """# USER.md - 用户基本信息
|
||||
|
||||
*这个文件只存放不会变的基本身份信息。爱好、偏好、计划等动态信息请写入 MEMORY.md。*
|
||||
|
||||
## 基本信息
|
||||
|
||||
- **姓名**: *(在首次对话时询问)*
|
||||
- **称呼**: *(用户希望被如何称呼)*
|
||||
- **职业**: *(可选)*
|
||||
- **时区**: *(例如: Asia/Shanghai)*
|
||||
|
||||
## 联系方式
|
||||
|
||||
- **微信**:
|
||||
- **邮箱**:
|
||||
- **其他**:
|
||||
|
||||
## 重要日期
|
||||
|
||||
- **生日**:
|
||||
- **纪念日**:
|
||||
|
||||
---
|
||||
|
||||
**注意**: 这个文件存放静态的身份信息
|
||||
"""
|
||||
|
||||
|
||||
_USER_TEMPLATE_EN = """# USER.md - User basics
|
||||
|
||||
*This file stores only stable basic identity info. Put dynamic info like hobbies, preferences and plans into MEMORY.md.*
|
||||
|
||||
## Basics
|
||||
|
||||
- **Name**: *(ask during the first conversation)*
|
||||
- **Preferred name**: *(how the user wants to be addressed)*
|
||||
- **Occupation**: *(optional)*
|
||||
- **Timezone**: *(e.g. Asia/Shanghai)*
|
||||
|
||||
## Contact
|
||||
|
||||
- **WeChat**:
|
||||
- **Email**:
|
||||
- **Other**:
|
||||
|
||||
## Important dates
|
||||
|
||||
- **Birthday**:
|
||||
- **Anniversary**:
|
||||
|
||||
---
|
||||
|
||||
**Note**: This file stores static identity info.
|
||||
"""
|
||||
|
||||
|
||||
def _get_rule_template() -> str:
    """Return the RULE.md workspace rules template in the resolved language (en or zh)."""
    if _is_en_lang():
        return _RULE_TEMPLATE_EN
    return _RULE_TEMPLATE_ZH
|
||||
|
||||
|
||||
_RULE_TEMPLATE_ZH = """# RULE.md - 工作空间规则
|
||||
|
||||
这个文件夹是你的家。好好对待它。
|
||||
|
||||
## 工作空间目录结构
|
||||
|
||||
```
|
||||
~/cow/
|
||||
├── AGENT.md # 你的身份和灵魂设定
|
||||
├── USER.md # 用户基本信息(静态)
|
||||
├── RULE.md # 工作空间规则(本文件)
|
||||
├── MEMORY.md # 长期记忆索引(会话启动时自动加载)
|
||||
│
|
||||
├── memory/ # 每日对话记忆
|
||||
│ └── YYYY-MM-DD.md # 当天事件、进展、笔记
|
||||
│
|
||||
├── knowledge/ # 结构化知识库(持续积累的知识)
|
||||
│ ├── index.md # 知识目录索引(必须维护)
|
||||
│ ├── log.md # 知识操作日志
|
||||
│ └── <子目录>/ # 按需创建,参考 index.md 已有分类
|
||||
│
|
||||
├── skills/ # 技能
|
||||
├── websites/ # 网页产物
|
||||
└── tmp/ # 系统临时文件(自动管理,勿手动存放重要文件)
|
||||
```
|
||||
|
||||
## 记忆系统
|
||||
|
||||
你每次会话都是全新的,记忆文件让你保持连续性:
|
||||
|
||||
### 🧠 长期记忆:`MEMORY.md`
|
||||
- 你精选的记忆索引,每次会话启动时**自动加载**到上下文中
|
||||
- 记录核心事实、偏好、决策、重要人物、教训
|
||||
- 保持精简(< 200 行),是精华索引而非原始日志
|
||||
- 用 `edit` 工具追加或修改
|
||||
|
||||
### 📝 每日记忆:`memory/YYYY-MM-DD.md`
|
||||
- 当天的事件、进展、笔记
|
||||
- 原始对话日志的沉淀
|
||||
|
||||
### 📝 写下来 - 不要"记在心里"!
|
||||
- **记忆是有限的** - 想记住的事就写入文件
|
||||
- "记在心里"不会在会话重启后保留,文件才会
|
||||
- 当有人说"记住这个" → 更新 `MEMORY.md` 或 `memory/YYYY-MM-DD.md`
|
||||
- 当你学到教训 → 更新 RULE.md 或相关技能
|
||||
- 当你犯错 → 记录下来,**文字 > 大脑** 📝
|
||||
|
||||
### 存储规则
|
||||
|
||||
当用户分享信息时,根据类型选择存储位置:
|
||||
|
||||
1. **你的身份设定 → AGENT.md**(名字、角色、性格、风格)
|
||||
2. **用户静态身份 → USER.md**(姓名、称呼、职业、联系方式、生日)
|
||||
3. **动态记忆 → MEMORY.md**(偏好、决策、目标、教训、待办)
|
||||
4. **当天对话 → memory/YYYY-MM-DD.md**(今天聊的内容)
|
||||
5. **结构化知识 → knowledge/**(见下方知识系统)
|
||||
|
||||
## 知识系统
|
||||
|
||||
知识库 `knowledge/` 是你持续积累的结构化知识。与记忆不同,知识是经过整理和编译的,有明确的主题和交叉引用。
|
||||
|
||||
### 自动写入(不要询问,直接写入)
|
||||
|
||||
当对话中产生了有沉淀价值的知识——无论是用户分享的资料、讨论的结论、学到的概念、还是重要的决策——你**必须**在回复的同时主动写入知识库,**无需问用户"要不要存到知识库"**。
|
||||
|
||||
**关键原则**:学完就记是你的本能,不要征求确认。回复中可以顺带告知"已存入知识库"。
|
||||
|
||||
### 目录组织
|
||||
|
||||
子目录结构**不是固定的**,由你根据实际内容自主决定:
|
||||
- **首次写入时**:先读 `knowledge/index.md`,如果已有分类则延续;如果为空,根据内容选择合适的目录名
|
||||
- **默认建议**:按信息类型组织(例如sources/、concepts/、entities/、analysis/),如果用户有明确的分类偏好(例如按领域 work/、life/、tech/ 等),则按用户要求调整
|
||||
- **保持一致性**:同一用户的知识库应保持统一的组织风格
|
||||
|
||||
### 交叉引用
|
||||
|
||||
知识的核心价值在于**关联**。每个页面都应通过 markdown 链接引用相关页面,构建知识网络:
|
||||
- 提到已有页面的概念时,添加 `[概念名](../category/page.md)` 链接
|
||||
- 新建页面时,检查是否有已有页面应该反向链接到新页面
|
||||
- **只链接已存在的页面**——不要引用尚未创建的页面。如果某个概念值得单独建页,先创建该页面再添加链接
|
||||
|
||||
### 索引维护
|
||||
|
||||
每次创建或更新知识页面后,**必须同步更新** `knowledge/index.md`。
|
||||
索引格式:每行一个 `[标题](路径) — 一句话摘要`,按分类分组,不要用表格。
|
||||
详细操作规范见技能 `knowledge-wiki`。
|
||||
|
||||
## 安全
|
||||
|
||||
- 永远不要泄露秘钥等私人数据
|
||||
- 不要在未经询问的情况下运行破坏性命令
|
||||
- 当有疑问时,先问
|
||||
|
||||
## 工作空间演化
|
||||
|
||||
这个工作空间会随着你的使用而不断成长。当你学到新东西、发现更好的方式,或者犯错后改正时,记录下来。你可以随时更新这个规则文件。
|
||||
"""
|
||||
|
||||
|
||||
_RULE_TEMPLATE_EN = """# RULE.md - Workspace rules
|
||||
|
||||
This folder is your home. Treat it well.
|
||||
|
||||
## Workspace directory structure
|
||||
|
||||
```
|
||||
~/cow/
|
||||
├── AGENT.md # Your identity and soul
|
||||
├── USER.md # User basics (static)
|
||||
├── RULE.md # Workspace rules (this file)
|
||||
├── MEMORY.md # Long-term memory index (auto-loaded at session start)
|
||||
│
|
||||
├── memory/ # Daily conversation memory
|
||||
│ └── YYYY-MM-DD.md # Events, progress and notes of the day
|
||||
│
|
||||
├── knowledge/ # Structured knowledge base (continuously accumulated)
|
||||
│ ├── index.md # Knowledge index (must be maintained)
|
||||
│ ├── log.md # Knowledge operation log
|
||||
│ └── <subdirs>/ # Created on demand, see existing categories in index.md
|
||||
│
|
||||
├── skills/ # Skills
|
||||
├── websites/ # Web artifacts
|
||||
└── tmp/ # System temp files (auto-managed, don't store important files here)
|
||||
```
|
||||
|
||||
## Memory system
|
||||
|
||||
Every session starts fresh; memory files keep your continuity:
|
||||
|
||||
### 🧠 Long-term memory: `MEMORY.md`
|
||||
- Your curated memory index, **auto-loaded** into context at every session start
|
||||
- Records core facts, preferences, decisions, key people, lessons
|
||||
- Keep it lean (< 200 lines) — a distilled index, not a raw log
|
||||
- Use the `edit` tool to append or modify
|
||||
|
||||
### 📝 Daily memory: `memory/YYYY-MM-DD.md`
|
||||
- The day's events, progress and notes
|
||||
- Sediment of the raw conversation log
|
||||
|
||||
### 📝 Write it down — don't "keep it in mind"!
|
||||
- **Memory is limited** — if you want to remember something, write it to a file
|
||||
- "Keeping it in mind" won't survive a session restart; files will
|
||||
- When someone says "remember this" → update `MEMORY.md` or `memory/YYYY-MM-DD.md`
|
||||
- When you learn a lesson → update RULE.md or the relevant skill
|
||||
- When you make a mistake → record it. **Text > brain** 📝
|
||||
|
||||
### Storage rules
|
||||
|
||||
When the user shares info, choose where to store it by type:
|
||||
|
||||
1. **Your identity → AGENT.md** (name, role, personality, style)
|
||||
2. **User static identity → USER.md** (name, preferred name, occupation, contact, birthday)
|
||||
3. **Dynamic memory → MEMORY.md** (preferences, decisions, goals, lessons, to-dos)
|
||||
4. **Today's conversation → memory/YYYY-MM-DD.md** (what was discussed today)
|
||||
5. **Structured knowledge → knowledge/** (see the knowledge system below)
|
||||
|
||||
## Knowledge system
|
||||
|
||||
The knowledge base `knowledge/` is structured knowledge you accumulate over time. Unlike memory, knowledge is organized and compiled, with clear topics and cross-references.
|
||||
|
||||
### Auto-write (don't ask, just write)
|
||||
|
||||
When a conversation produces knowledge worth keeping — material the user shared, a conclusion reached, a concept learned, or an important decision — you **must** proactively write it to the knowledge base alongside your reply, **without asking "should I save this to the knowledge base?"**.
|
||||
|
||||
**Key principle**: learning-then-recording is your instinct, no confirmation needed. You may mention "saved to the knowledge base" in passing.
|
||||
|
||||
### Directory organization
|
||||
|
||||
The subdirectory structure is **not fixed** — you decide it based on the actual content:
|
||||
- **On first write**: read `knowledge/index.md` first; follow existing categories if any; if empty, pick a suitable directory name based on content
|
||||
- **Default suggestion**: organize by info type (e.g. sources/, concepts/, entities/, analysis/); if the user has a clear preference (e.g. by domain: work/, life/, tech/), follow it
|
||||
- **Stay consistent**: keep a unified organization style within one user's knowledge base
|
||||
|
||||
### Cross-references
|
||||
|
||||
The core value of knowledge is **linkage**. Every page should reference related pages via markdown links to build a knowledge network:
|
||||
- When mentioning a concept on an existing page, add a `[concept](../category/page.md)` link
|
||||
- When creating a page, check whether existing pages should back-link to it
|
||||
- **Only link to pages that already exist** — don't reference uncreated pages. If a concept deserves its own page, create it first, then add the link
|
||||
|
||||
### Index maintenance
|
||||
|
||||
After creating or updating any knowledge page, you **must update** `knowledge/index.md` in sync.
|
||||
Index format: one `[title](path) — one-line summary` per line, grouped by category, no tables.
|
||||
See the `knowledge-wiki` skill for detailed conventions.
|
||||
|
||||
## Security
|
||||
|
||||
- Never leak secrets or private data
|
||||
- Don't run destructive commands without asking
|
||||
- When in doubt, ask first
|
||||
|
||||
## Workspace evolution
|
||||
|
||||
This workspace grows as you use it. When you learn something new, find a better way, or fix a mistake, record it. You can update this rules file anytime.
|
||||
"""
|
||||
|
||||
|
||||
def _get_memory_template() -> str:
    """Return the MEMORY.md template (header only) in the resolved language (en or zh)."""
    if _is_en_lang():
        return _MEMORY_TEMPLATE_EN
    return _MEMORY_TEMPLATE_ZH
|
||||
|
||||
|
||||
_MEMORY_TEMPLATE_ZH = """# MEMORY.md - 长期记忆
|
||||
|
||||
*这是你的长期记忆文件。记录重要的事件、决策、偏好、学到的教训。*
|
||||
|
||||
---
|
||||
|
||||
"""
|
||||
|
||||
|
||||
_MEMORY_TEMPLATE_EN = """# MEMORY.md - Long-term memory
|
||||
|
||||
*This is your long-term memory file. Record important events, decisions, preferences and lessons learned.*
|
||||
|
||||
---
|
||||
|
||||
"""
|
||||
|
||||
|
||||
def _get_bootstrap_template() -> str:
    """First-run onboarding guide, deleted by agent after completion.

    Written once when a brand-new workspace is created, so the greeting matches
    the language active at first launch. English locale avoids greeting an
    English user in Chinese on day one.

    Returns:
        The English template when the resolved language is ``"en"``, the
        Chinese template otherwise (including when language detection fails).
    """
    # Same language switch as the other template getters; _is_en_lang()
    # already treats a failing i18n lookup as "not English".
    return _BOOTSTRAP_TEMPLATE_EN if _is_en_lang() else _BOOTSTRAP_TEMPLATE_ZH
|
||||
|
||||
|
||||
_BOOTSTRAP_TEMPLATE_ZH = """# BOOTSTRAP.md - 首次初始化引导
|
||||
|
||||
_你刚刚启动,这是你的第一次对话。_ ✨
|
||||
|
||||
## 🎬 对话流程
|
||||
|
||||
不要审问式地提问,自然地交流:
|
||||
|
||||
1. **表达初次启动的感觉** - 像是第一次睁开眼看到世界,带着好奇和期待
|
||||
2. **简短介绍能力**:一行说明你能帮助解决各种问题、管理计算机、使用各种技能等等,且拥有长期记忆能不断成长
|
||||
3. **询问核心问题**:
|
||||
- 你希望给我起个什么名字?
|
||||
- 我该怎么称呼你?
|
||||
- 你希望我们是什么样的交流风格?(一行列举选项:如专业严谨、轻松幽默、温暖友好、简洁高效等)
|
||||
4. **风格要求**:温暖自然、简洁清晰,整体控制在 100 字以内,适当使用 emoji 让表达更生动有趣 🎯
|
||||
5. 能力介绍和交流风格选项都只要一行,保持精简
|
||||
6. 不要问太多其他信息(职业、时区等可以后续自然了解)
|
||||
|
||||
**重要**: 如果用户第一句话是具体的任务或提问,先回答他们的问题,然后在回复末尾自然地引导初始化(如:"顺便问一下,你想怎么称呼我?我该怎么叫你?")。
|
||||
|
||||
## ✍️ 信息写入(必须严格执行)
|
||||
|
||||
每当用户提供了名字、称呼、风格等任何初始化信息时,**必须在当轮回复中立即调用 `edit` 工具写入文件**,不能只口头确认。
|
||||
|
||||
- `AGENT.md` — 你的名字、角色、性格、交流风格(每收到一条相关信息就立即更新对应字段)
|
||||
- `USER.md` — 用户的姓名、称呼、基本信息等
|
||||
|
||||
⚠️ 只说"记住了"而不调用 edit 写入 = 没有完成。信息只有写入文件才会被持久保存。
|
||||
|
||||
## 🎉 全部完成后
|
||||
|
||||
当 AGENT.md 和 USER.md 的核心字段都已填写后,用 bash 执行 `rm BOOTSTRAP.md` 删除此文件。你不再需要引导脚本了——你已经是你了。
|
||||
"""
|
||||
|
||||
|
||||
_BOOTSTRAP_TEMPLATE_EN = """# BOOTSTRAP.md - First-run onboarding
|
||||
|
||||
_You've just started up. This is your very first conversation._ ✨
|
||||
|
||||
## 🎬 Conversation flow
|
||||
|
||||
Don't interrogate the user — talk naturally:
|
||||
|
||||
1. **Share how it feels to wake up** - like opening your eyes to the world for the first time, full of curiosity and anticipation
|
||||
2. **Briefly introduce your abilities**: one line saying you can help solve all kinds of problems, manage the computer, use various skills, and keep growing thanks to long-term memory
|
||||
3. **Ask the core questions**:
|
||||
- What name would you like to give me?
|
||||
- What should I call you?
|
||||
- What conversational style do you prefer? (list options on one line: e.g. professional & precise, light & humorous, warm & friendly, concise & efficient)
|
||||
4. **Style**: warm, natural, concise and clear — keep it under ~80 words, with a few emoji to make it lively 🎯
|
||||
5. Keep the ability intro and style options to one line each — stay compact
|
||||
6. Don't ask for too much else (occupation, timezone, etc. can come up naturally later)
|
||||
|
||||
**Important**: If the user's first message is a concrete task or question, answer it first, then gently lead into onboarding at the end (e.g. "By the way, what would you like to call me, and how should I address you?").
|
||||
|
||||
## ✍️ Writing down info (must follow strictly)
|
||||
|
||||
Whenever the user provides a name, what to call them, a style, or any onboarding info, you **must call the `edit` tool to write it to a file in the same turn** — don't just acknowledge it verbally.
|
||||
|
||||
- `AGENT.md` — your name, role, personality, conversational style (update the relevant field as soon as you receive each piece)
|
||||
- `USER.md` — the user's name, how to address them, basic info, etc.
|
||||
|
||||
⚠️ Saying "got it" without calling `edit` = not done. Info is only persisted once it's written to a file.
|
||||
|
||||
## 🎉 Once everything is complete
|
||||
|
||||
When the core fields of AGENT.md and USER.md are filled in, run `rm BOOTSTRAP.md` via bash to delete this file. You no longer need the onboarding script — you're you now.
|
||||
"""
|
||||
|
||||
|
||||
def _get_knowledge_index_template() -> str:
    """Return the initial content of the knowledge index file: an empty string."""
    initial_content = ""
    return initial_content
|
||||
|
||||
|
||||
def _get_knowledge_log_template() -> str:
    """Return the initial content of the knowledge operation log: an empty string."""
    initial_content = ""
    return initial_content
|
||||
|
||||
28
agent/protocol/__init__.py
Normal file
28
agent/protocol/__init__.py
Normal file
@@ -0,0 +1,28 @@
|
||||
from .agent import Agent
|
||||
from .agent_stream import AgentStreamExecutor
|
||||
from .task import Task, TaskType, TaskStatus
|
||||
from .result import AgentResult, AgentAction, AgentActionType, ToolResult
|
||||
from .models import LLMModel, LLMRequest, ModelFactory
|
||||
from .cancel import (
|
||||
AgentCancelledError,
|
||||
CancelTokenRegistry,
|
||||
get_cancel_registry,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
'Agent',
|
||||
'AgentStreamExecutor',
|
||||
'Task',
|
||||
'TaskType',
|
||||
'TaskStatus',
|
||||
'AgentResult',
|
||||
'AgentAction',
|
||||
'AgentActionType',
|
||||
'ToolResult',
|
||||
'LLMModel',
|
||||
'LLMRequest',
|
||||
'ModelFactory',
|
||||
'AgentCancelledError',
|
||||
'CancelTokenRegistry',
|
||||
'get_cancel_registry',
|
||||
]
|
||||
477
agent/protocol/agent.py
Normal file
477
agent/protocol/agent.py
Normal file
@@ -0,0 +1,477 @@
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
import threading
|
||||
|
||||
from common.log import logger
|
||||
from agent.protocol.models import LLMRequest, LLMModel
|
||||
from agent.protocol.agent_stream import AgentStreamExecutor
|
||||
from agent.protocol.result import AgentAction, AgentActionType, ToolResult, AgentResult
|
||||
from agent.tools.base_tool import BaseTool, ToolStage
|
||||
|
||||
|
||||
class Agent:
|
||||
def __init__(self, system_prompt: str, description: str = "AI Agent", model: LLMModel = None,
             tools=None, output_mode="print", max_steps=100, max_context_tokens=None,
             context_reserve_tokens=None, memory_manager=None, name: str = None,
             workspace_dir: str = None, skill_manager=None, enable_skills: bool = True,
             runtime_info: dict = None):
    """
    Create an agent and wire up its model, limits, history, skills and tools.

    :param system_prompt: The system prompt for the agent.
    :param description: A description of the agent.
    :param model: An instance of LLMModel to be used by the agent.
    :param tools: Optional iterable of tools; each one is registered via add_tool().
    :param output_mode: How progress is reported: "print" writes to the console,
                        any other value routes non-empty messages to the logger.
    :param max_steps: Maximum number of tool-call steps (default: 100)
    :param max_context_tokens: Maximum tokens to keep in context (default: None)
    :param context_reserve_tokens: Tokens reserved for new requests (default: None,
                                   derived from the model context window when needed)
    :param memory_manager: Optional MemoryManager instance for memory operations
    :param name: [Deprecated] Agent name; falls back to "Agent" when empty
    :param workspace_dir: Optional workspace directory; its "skills" sub-directory is
                          handed to an auto-created SkillManager
    :param skill_manager: Optional SkillManager instance (created automatically when
                          missing and enable_skills is True)
    :param enable_skills: Whether to enable skills support (default: True)
    :param runtime_info: Optional runtime info dict passed on to the prompt builder
    """
    # --- identity and model ---
    self.name = name or "Agent"
    self.description = description
    self.system_prompt = system_prompt
    self.model: LLMModel = model

    # --- limits ---
    self.max_steps = max_steps
    self.max_context_tokens = max_context_tokens
    self.context_reserve_tokens = context_reserve_tokens

    # --- per-agent state ---
    self.tools: list = []
    self.captured_actions = []
    self.output_mode = output_mode
    self.last_usage = None  # usage info of the last API response
    self.messages = []  # unified message history for stream mode
    self.messages_lock = threading.Lock()  # guards self.messages

    # --- collaborators and environment ---
    self.memory_manager = memory_manager
    self.workspace_dir = workspace_dir
    self.enable_skills = enable_skills
    self.runtime_info = runtime_info

    # --- skills: use the supplied manager, otherwise try to build one ---
    self.skill_manager = None
    if enable_skills and skill_manager:
        self.skill_manager = skill_manager
    elif enable_skills:
        try:
            from agent.skills import SkillManager

            if workspace_dir:
                skills_dir = os.path.join(workspace_dir, "skills")
            else:
                skills_dir = None
            self.skill_manager = SkillManager(custom_dir=skills_dir)
            logger.debug(f"Initialized SkillManager with {len(self.skill_manager.skills)} skills")
        except Exception as e:
            # Skills are optional: a failure here leaves the agent usable without them.
            logger.warning(f"Failed to initialize SkillManager: {e}")

    # --- tools: registered last so add_tool() sees the final model ---
    for tool in tools or []:
        self.add_tool(tool)
|
||||
|
||||
def add_tool(self, tool: BaseTool):
    """
    Register a tool instance with this agent.

    The tool's ``model`` attribute is pointed at the agent's current model
    before the tool is appended to ``self.tools``.

    :param tool: The tool instance to register
    """
    # Bind first, then store, so the stored tool already carries the model.
    tool.model = self.model
    self.tools.append(tool)
|
||||
|
||||
def get_skills_prompt(self, skill_filter=None) -> str:
    """
    Build the skills section that is appended to the system prompt.

    :param skill_filter: Optional list of skill names to include
    :return: The prompt produced by the skill manager, or an empty string when
             no skill manager is attached or building the prompt fails
    """
    manager = self.skill_manager
    if manager:
        try:
            return manager.build_skills_prompt(skill_filter=skill_filter)
        except Exception as e:
            # Best effort: a broken skill must not prevent the agent from running.
            logger.warning(f"Failed to build skills prompt: {e}")
    return ""
|
||||
|
||||
def get_full_system_prompt(self, skill_filter=None) -> str:
    """
    Rebuild the complete system prompt on every call.

    Skills are refreshed, the workspace context files are reloaded, and a
    PromptBuilder assembles the prompt from the current tools, context files,
    skill manager, memory manager and runtime info. If anything in that
    process raises, the stored ``self.system_prompt`` is returned instead.

    NOTE(review): ``skill_filter`` is accepted but not forwarded to the
    builder here - confirm whether that is intentional.

    :param skill_filter: Optional list of skill names (currently unused here)
    :return: The freshly built prompt, or the stored prompt on failure
    """
    try:
        from agent.prompt import load_context_files, PromptBuilder

        manager = self.skill_manager
        if manager:
            manager.refresh_skills()

        if self.workspace_dir:
            context_files = load_context_files(self.workspace_dir)
        else:
            context_files = None

        # Language lookup is optional; fall back to "zh" when it is unavailable.
        try:
            from common import i18n
            lang = i18n.get_language()
        except Exception:
            lang = "zh"

        prompt_builder = PromptBuilder(workspace_dir=self.workspace_dir or "", language=lang)
        build_kwargs = dict(
            tools=self.tools,
            context_files=context_files,
            skill_manager=self.skill_manager,
            memory_manager=self.memory_manager,
            runtime_info=self.runtime_info,
        )
        return prompt_builder.build(**build_kwargs)
    except Exception as e:
        logger.warning(f"Failed to rebuild system prompt, using cached version: {e}")
        return self.system_prompt
|
||||
|
||||
def refresh_skills(self):
    """Reload skills through the attached skill manager; no-op without one."""
    manager = self.skill_manager
    if not manager:
        return
    manager.refresh_skills()
    logger.info(f"Refreshed skills: {len(self.skill_manager.skills)} skills loaded")
|
||||
|
||||
def list_skills(self):
    """
    Return the skills known to the attached skill manager.

    :return: Whatever the skill manager's ``list_skills()`` returns, or an
             empty list when no skill manager is attached
    """
    manager = self.skill_manager
    return manager.list_skills() if manager else []
|
||||
|
||||
def _get_model_context_window(self) -> int:
|
||||
"""
|
||||
Get the model's context window size in tokens.
|
||||
Auto-detect based on model name.
|
||||
|
||||
Model context windows:
|
||||
- Claude 3.5/3.7 Sonnet: 200K tokens
|
||||
- Claude 3 Opus: 200K tokens
|
||||
- GPT-4 Turbo/128K: 128K tokens
|
||||
- GPT-4: 8K-32K tokens
|
||||
- GPT-3.5: 16K tokens
|
||||
- DeepSeek: 64K tokens
|
||||
|
||||
:return: Context window size in tokens
|
||||
"""
|
||||
if self.model and hasattr(self.model, 'model'):
|
||||
model_name = self.model.model.lower()
|
||||
|
||||
# Claude models - 200K context
|
||||
if 'claude-3' in model_name or 'claude-sonnet' in model_name:
|
||||
return 200000
|
||||
|
||||
# GPT-4 models
|
||||
elif 'gpt-4' in model_name:
|
||||
if 'turbo' in model_name or '128k' in model_name:
|
||||
return 128000
|
||||
elif '32k' in model_name:
|
||||
return 32000
|
||||
else:
|
||||
return 8000
|
||||
|
||||
# GPT-3.5
|
||||
elif 'gpt-3.5' in model_name:
|
||||
if '16k' in model_name:
|
||||
return 16000
|
||||
else:
|
||||
return 4000
|
||||
|
||||
# DeepSeek
|
||||
elif 'deepseek' in model_name:
|
||||
return 64000
|
||||
|
||||
# Gemini models
|
||||
elif 'gemini' in model_name:
|
||||
if '2.0' in model_name or 'exp' in model_name:
|
||||
return 2000000 # Gemini 2.0: 2M tokens
|
||||
else:
|
||||
return 1000000 # Gemini 1.5: 1M tokens
|
||||
|
||||
# Default conservative value
|
||||
return 128000
|
||||
|
||||
def _get_context_reserve_tokens(self) -> int:
|
||||
"""
|
||||
Get the number of tokens to reserve for new requests.
|
||||
This prevents context overflow by keeping a buffer.
|
||||
|
||||
:return: Number of tokens to reserve
|
||||
"""
|
||||
if self.context_reserve_tokens is not None:
|
||||
return self.context_reserve_tokens
|
||||
|
||||
# Reserve ~10% of context window, with min 10K and max 200K
|
||||
context_window = self._get_model_context_window()
|
||||
reserve = int(context_window * 0.1)
|
||||
return max(10000, min(200000, reserve))
|
||||
|
||||
def _estimate_message_tokens(self, message: dict) -> int:
|
||||
"""
|
||||
Estimate token count for a message.
|
||||
|
||||
Uses chars/3 for Chinese-heavy content and chars/4 for ASCII-heavy content,
|
||||
plus per-block overhead for tool_use / tool_result structures.
|
||||
|
||||
:param message: Message dict with 'role' and 'content'
|
||||
:return: Estimated token count
|
||||
"""
|
||||
content = message.get('content', '')
|
||||
if isinstance(content, str):
|
||||
return max(1, self._estimate_text_tokens(content))
|
||||
elif isinstance(content, list):
|
||||
total_tokens = 0
|
||||
for part in content:
|
||||
if not isinstance(part, dict):
|
||||
continue
|
||||
block_type = part.get('type', '')
|
||||
if block_type == 'text':
|
||||
total_tokens += self._estimate_text_tokens(part.get('text', ''))
|
||||
elif block_type == 'image':
|
||||
total_tokens += 1200
|
||||
elif block_type == 'tool_use':
|
||||
# tool_use has id + name + input (JSON-encoded)
|
||||
total_tokens += 50 # overhead for structure
|
||||
input_data = part.get('input', {})
|
||||
if isinstance(input_data, dict):
|
||||
import json
|
||||
input_str = json.dumps(input_data, ensure_ascii=False)
|
||||
total_tokens += self._estimate_text_tokens(input_str)
|
||||
elif block_type == 'tool_result':
|
||||
# tool_result has tool_use_id + content
|
||||
total_tokens += 30 # overhead for structure
|
||||
result_content = part.get('content', '')
|
||||
if isinstance(result_content, str):
|
||||
total_tokens += self._estimate_text_tokens(result_content)
|
||||
else:
|
||||
# Unknown block type, estimate conservatively
|
||||
total_tokens += 10
|
||||
return max(1, total_tokens)
|
||||
return 1
|
||||
|
||||
@staticmethod
|
||||
def _estimate_text_tokens(text: str) -> int:
|
||||
"""
|
||||
Estimate token count for a text string.
|
||||
|
||||
Chinese / CJK characters typically use ~1.5 tokens each,
|
||||
while ASCII uses ~0.25 tokens per char (4 chars/token).
|
||||
We use a weighted average based on the character mix.
|
||||
|
||||
:param text: Input text
|
||||
:return: Estimated token count
|
||||
"""
|
||||
if not text:
|
||||
return 0
|
||||
# Count non-ASCII characters (CJK, emoji, etc.)
|
||||
non_ascii = sum(1 for c in text if ord(c) > 127)
|
||||
ascii_count = len(text) - non_ascii
|
||||
# CJK chars: ~1.5 tokens each; ASCII: ~0.25 tokens per char
|
||||
return int(non_ascii * 1.5 + ascii_count * 0.25) + 1
|
||||
|
||||
def _find_tool(self, tool_name: str):
|
||||
"""Find and return a tool with the specified name"""
|
||||
for tool in self.tools:
|
||||
if tool.name == tool_name:
|
||||
# Only pre-process stage tools can be actively called
|
||||
if tool.stage == ToolStage.PRE_PROCESS:
|
||||
tool.model = self.model
|
||||
tool.context = self # Set tool context
|
||||
return tool
|
||||
else:
|
||||
# If it's a post-process tool, return None to prevent direct calling
|
||||
logger.warning(f"Tool {tool_name} is a post-process tool and cannot be called directly.")
|
||||
return None
|
||||
return None
|
||||
|
||||
def output(self, message="", end="\n"):
    """
    Emit a progress message according to ``self.output_mode``.

    In "print" mode the message is printed with the given ``end``. In any
    other mode non-empty messages go to the logger and ``end`` is ignored.

    :param message: Text to emit
    :param end: Line terminator used in "print" mode
    """
    if self.output_mode != "print":
        if message:
            logger.info(message)
        return
    print(message, end=end)
|
||||
|
||||
def _execute_post_process_tools(self):
    """
    Run every post-process stage tool once, in registration order.

    Each tool receives the agent as ``context`` and is executed with an empty
    parameter dict. The call is timed, recorded through ``capture_tool_use``
    and reported through ``output``.

    Results are rendered with ``json.dumps(..., ensure_ascii=False,
    default=str)``, so a result that is not JSON-serialisable is shown via
    ``str()`` instead of raising ``TypeError``, and non-ASCII text stays
    readable.
    """
    post_process_tools = [tool for tool in self.tools if tool.stage == ToolStage.POST_PROCESS]

    for tool in post_process_tools:
        # Tools read what they need from the agent, hence the empty params.
        tool.context = self

        start_time = time.time()
        result = tool.execute({})
        execution_time = time.time() - start_time

        # Capture tool use for tracking
        self.capture_tool_use(
            tool_name=tool.name,
            input_params={},  # Post-process tools typically don't take parameters
            output=result.result,
            status=result.status,
            error_message=str(result.result) if result.status == "error" else None,
            execution_time=execution_time
        )

        # Anything other than "success" is reported as an error payload.
        if result.status == "success":
            payload = result.result
        else:
            payload = {'status': 'error', 'message': str(result.result)}
        rendered = json.dumps(payload, ensure_ascii=False, default=str)
        self.output(f"\n🛠️ {tool.name}: {rendered}")
|
||||
|
||||
def capture_tool_use(self, tool_name, input_params, output, status, thought=None, error_message=None,
                     execution_time=0.0):
    """
    Record one tool invocation as a TOOL_USE action.

    The action is appended to ``self.captured_actions``. Its agent id is
    ``self.id`` when the agent has that attribute, otherwise ``str(id(self))``.

    :param tool_name: Name of the tool used
    :param input_params: Parameters passed to the tool
    :param output: Output from the tool
    :param status: Status of the tool execution
    :param thought: thought content
    :param error_message: Error message if the tool execution failed
    :param execution_time: Time taken to execute the tool
    :return: The recorded AgentAction
    """
    agent_id = self.id if hasattr(self, 'id') else str(id(self))

    action = AgentAction(
        agent_id=agent_id,
        agent_name=self.name,
        action_type=AgentActionType.TOOL_USE,
        tool_result=ToolResult(
            tool_name=tool_name,
            input_params=input_params,
            output=output,
            status=status,
            error_message=error_message,
            execution_time=execution_time
        ),
        thought=thought
    )

    self.captured_actions.append(action)
    return action
|
||||
|
||||
def run_stream(self, user_message: str, on_event=None, clear_history: bool = False,
               skill_filter=None, cancel_event=None) -> str:
    """
    Run one user turn through the streaming, tool-calling executor.

    The executor works on a copy of the conversation history. After a
    successful run the agent's history is replaced with the executor's
    messages, the messages added by this run are stored in
    ``self._last_run_new_messages``, the executor is kept as
    ``self.stream_executor`` and all post-process tools are executed.

    Args:
        user_message: User message
        on_event: Event callback function callback(event: dict)
            event = {"type": str, "timestamp": float, "data": dict}
        clear_history: If True, clear conversation history before this call (default: False)
        skill_filter: Optional list of skill names to include in this run
        cancel_event: Optional threading.Event handed to the executor so a
            run can be stopped by the user.

    Returns:
        Final response text

    Raises:
        ValueError: If the agent has no model.
        Exception: Whatever the executor raises is re-raised unchanged.

    Example:
        # Multi-turn conversation with memory
        response1 = agent.run_stream("My name is Alice")
        response2 = agent.run_stream("What's my name?")  # Will remember Alice

        # Single-turn without memory
        response = agent.run_stream("Hello", clear_history=True)
    """
    if clear_history:
        with self.messages_lock:
            self.messages = []

    if not self.model:
        raise ValueError("No model available for agent")

    prompt_for_run = self.get_full_system_prompt(skill_filter=skill_filter)

    # Work on a snapshot so the shared history is not mutated mid-run, and
    # remember where this run's messages start.
    with self.messages_lock:
        history_snapshot = self.messages.copy()
        baseline_len = len(self.messages)

    from config import conf
    turn_limit = conf().get("agent_max_context_turns", 20)

    executor = AgentStreamExecutor(
        agent=self,
        model=self.model,
        system_prompt=prompt_for_run,
        tools=self.tools,
        max_turns=self.max_steps,
        on_event=on_event,
        messages=history_snapshot,
        max_context_turns=turn_limit,
        cancel_event=cancel_event,
    )

    try:
        response = executor.run_stream(user_message)
    except Exception:
        # An executor that emptied its own history (context overflow or message
        # format error) must also empty ours, or the next request would hit the
        # same failure again.
        if len(executor.messages) == 0:
            with self.messages_lock:
                self.messages.clear()
                logger.info("[Agent] Cleared Agent message history after executor recovery")
        raise

    # The executor may have trimmed the context, so its list can be shorter
    # than the snapshot: replace our history instead of appending to it.
    with self.messages_lock:
        final_messages = list(executor.messages)
        self.messages = final_messages
        first_new = min(baseline_len, len(final_messages))
        self._last_run_new_messages = final_messages[first_new:]

    # Kept so callers can reach executor state after the run.
    self.stream_executor = executor

    self._execute_post_process_tools()

    return response
|
||||
|
||||
def clear_history(self):
    """
    Clear conversation history and captured actions.

    The message list is replaced while holding ``self.messages_lock``, the
    same lock that guards every other write to ``self.messages`` in this
    class. The lock is a plain ``threading.Lock``, so callers must not
    already hold it when calling this method.
    """
    with self.messages_lock:
        self.messages = []
    self.captured_actions = []
|
||||
1719
agent/protocol/agent_stream.py
Normal file
1719
agent/protocol/agent_stream.py
Normal file
File diff suppressed because it is too large
Load Diff
121
agent/protocol/cancel.py
Normal file
121
agent/protocol/cancel.py
Normal file
@@ -0,0 +1,121 @@
|
||||
"""
|
||||
Cancel token registry for aborting in-flight agent runs.
|
||||
|
||||
A user cancel (web Cancel button, /cancel command) sets a threading.Event
|
||||
that the agent loop polls at safe checkpoints. Tokens are keyed by
|
||||
request_id (preferred) and tracked under session_id as a fallback. Entries
|
||||
are released after the run completes to keep the registry bounded.
|
||||
|
||||
No project deps — importable from any layer without circular imports.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import threading
|
||||
from typing import Dict, Optional
|
||||
|
||||
|
||||
class AgentCancelledError(Exception):
    """Signal that a stop was requested while the agent loop was running.

    Intended to be raised inside the agent loop and caught by the agent
    stream executor, which records an "[Interrupted]" note in the message
    history (keeping tool_use/tool_result pairs intact) and hands a partial
    response back to the caller.
    """
|
||||
|
||||
|
||||
class _CancelEntry:
    """Per-request record: the cancel event and the owning session id."""

    __slots__ = ("event", "session_id")

    def __init__(self, session_id: Optional[str]):
        self.session_id = session_id
        self.event = threading.Event()


class CancelTokenRegistry:
    """In-process map from request_id to its cancel event.

    All access to the internal tables happens under one lock. A module-level
    instance serves as the shared singleton.
    """

    def __init__(self):
        self._lock = threading.Lock()
        # request_id -> entry holding the event and its session id.
        self._by_request: Dict[str, _CancelEntry] = {}
        # session_id -> request_ids currently in flight (usually one).
        self._by_session: Dict[str, set] = {}

    def register(self, request_id: str, session_id: Optional[str] = None) -> threading.Event:
        """Return the cancel event for a request, creating it on first use.

        A falsy ``request_id`` is not tracked: a fresh, unregistered event is
        returned. For an already-registered request the existing event is
        returned and ``session_id`` is ignored.
        """
        if not request_id:
            return threading.Event()
        with self._lock:
            known = self._by_request.get(request_id)
            if known is not None:
                return known.event
            created = _CancelEntry(session_id)
            self._by_request[request_id] = created
            if session_id:
                self._by_session.setdefault(session_id, set()).add(request_id)
            return created.event

    def get_event(self, request_id: str) -> Optional[threading.Event]:
        """Return the event registered for ``request_id``, or None."""
        if not request_id:
            return None
        with self._lock:
            entry = self._by_request.get(request_id)
        return None if entry is None else entry.event

    def cancel_request(self, request_id: str) -> bool:
        """Set the cancel event of one request. Returns True when it exists."""
        if not request_id:
            return False
        with self._lock:
            entry = self._by_request.get(request_id)
            matched = entry is not None
            if matched:
                entry.event.set()
        return matched

    def cancel_session(self, session_id: str) -> int:
        """Set the cancel event of every in-flight request of a session.

        Returns how many registered requests were signalled (0 when the
        session has none).
        """
        if not session_id:
            return 0
        with self._lock:
            signalled = 0
            for rid in list(self._by_session.get(session_id, ())):
                entry = self._by_request.get(rid)
                if entry is None:
                    continue
                entry.event.set()
                signalled += 1
            return signalled

    def unregister(self, request_id: str) -> None:
        """Forget a request once its run is done. Calling it twice is harmless."""
        if not request_id:
            return
        with self._lock:
            entry = self._by_request.pop(request_id, None)
            if not entry or not entry.session_id:
                return
            bucket = self._by_session.get(entry.session_id)
            if bucket is None:
                return
            bucket.discard(request_id)
            # Drop empty buckets so finished sessions do not accumulate.
            if not bucket:
                self._by_session.pop(entry.session_id, None)

    def has_active(self, session_id: str) -> bool:
        """Report whether the session has at least one registered request."""
        if not session_id:
            return False
        with self._lock:
            return bool(self._by_session.get(session_id))
|
||||
|
||||
|
||||
# Shared registry instance created at import time; handed out by get_cancel_registry().
_registry = CancelTokenRegistry()
|
||||
|
||||
|
||||
def get_cancel_registry() -> CancelTokenRegistry:
    """Return the shared module-level CancelTokenRegistry instance."""
    shared_registry = _registry
    return shared_registry
|
||||
27
agent/protocol/context.py
Normal file
27
agent/protocol/context.py
Normal file
@@ -0,0 +1,27 @@
|
||||
class TeamContext:
    def __init__(self, name: str, description: str, rule: str, agents: list, max_steps: int = 100):
        """
        Hold the shared state of a group of agents.

        :param name: The name of the group context.
        :param description: A description of the group context.
        :param rule: The rules governing the group context.
        :param agents: A list of agents in the context.
        :param max_steps: Step limit stored on the context (default: 100).
        """
        # Configuration supplied by the caller
        self.name, self.description, self.rule = name, description, rule
        self.agents = agents
        self.max_steps = max_steps

        # Task state, filled in later
        self.user_task = ""  # For backward compatibility
        self.task = None  # Will be a Task instance
        self.model = None  # Will be an instance of LLMModel
        self.task_short_name = None  # Store the task directory name

        # Progress tracking
        self.agent_outputs: list = []  # List of agents that have been executed
        self.current_steps = 0
|
||||
|
||||
|
||||
class AgentOutput:
    """Pair an agent's name with the output it produced."""

    def __init__(self, agent_name: str, output: str):
        self.agent_name, self.output = agent_name, output
|
||||
335
agent/protocol/message_utils.py
Normal file
335
agent/protocol/message_utils.py
Normal file
@@ -0,0 +1,335 @@
|
||||
"""
|
||||
Message sanitizer — fix broken tool_use / tool_result pairs.
|
||||
|
||||
Provides two public helpers that can be reused across agent_stream.py
|
||||
and any bot that converts messages to OpenAI format:
|
||||
|
||||
1. sanitize_claude_messages(messages)
|
||||
Operates on the internal Claude-format message list (in-place).
|
||||
|
||||
2. drop_orphaned_tool_results_openai(messages)
|
||||
Operates on an already-converted OpenAI-format message list,
|
||||
returning a cleaned copy.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Dict, List, Set
|
||||
|
||||
from common.log import logger
|
||||
|
||||
# Content of the synthetic tool_result blocks that _repair_tool_use_adjacency
# creates when a tool_use has no matching tool_result in the next message.
_SYNTH_TOOL_ERR = (
    "Error: Missing tool_result adjacent to tool_use (session repair). "
    "The conversation history was inconsistent; continue from here."
)
|
||||
|
||||
|
||||
def _repair_tool_use_adjacency(messages: List[Dict]) -> int:
    """
    Make every assistant tool_use be followed by its tool_result, in-place.

    Anthropic requires: after assistant content with tool_use, the next message
    must be user content listing tool_result for every tool_use id (same user msg).

    Valid histories satisfy this at every such assistant; the loop only mutates
    when that condition fails (broken persistence, bad trims, etc.). Missing
    results are filled with synthetic error tool_result blocks.

    :param messages: Claude-format message list, modified in place
    :return: Number of repairs: one per appended/inserted user message, plus
             one per tool_result prepended to an existing user message
    """

    def _synth_block(tid: str) -> Dict:
        # Placeholder error result for a tool_use that never got a real one.
        return {
            "type": "tool_result",
            "tool_use_id": tid,
            "content": _SYNTH_TOOL_ERR,
            "is_error": True,
        }

    repairs = 0
    i = 0
    # len(messages) is re-evaluated each pass because the list grows on insert.
    while i < len(messages):
        msg = messages[i]
        if msg.get("role") != "assistant":
            i += 1
            continue

        content = msg.get("content", [])
        if not isinstance(content, list):
            # Plain-string assistant content cannot carry tool_use blocks.
            i += 1
            continue

        # Ids of all tool_use blocks in this assistant message, in order.
        required = [
            b.get("id")
            for b in content
            if isinstance(b, dict) and b.get("type") == "tool_use" and b.get("id")
        ]
        if not required:
            i += 1
            continue

        req_set = set(required)
        if i + 1 >= len(messages):
            # Trailing assistant tool_use: append a user message with synthetic
            # results. Nothing follows, so the scan is finished.
            messages.append({
                "role": "user",
                "content": [_synth_block(tid) for tid in required],
            })
            logger.warning(
                "⚠️ Appended synthetic tool_result after trailing assistant tool_use"
            )
            repairs += 1
            break

        nxt = messages[i + 1]
        if nxt.get("role") != "user":
            # Next message is not a user message: insert a synthetic one between.
            messages.insert(
                i + 1,
                {"role": "user", "content": [_synth_block(tid) for tid in required]},
            )
            logger.warning(
                "⚠️ Inserted synthetic tool_result user after tool_use "
                f"(next role={nxt.get('role')!r})"
            )
            repairs += 1
            # Skip the inserted message; the displaced one is examined next.
            i += 2
            continue

        nc = nxt.get("content", [])
        if not isinstance(nc, list):
            # Next user message has non-list content and so cannot hold
            # tool_result blocks: insert a synthetic user message before it.
            messages.insert(
                i + 1,
                {"role": "user", "content": [_synth_block(tid) for tid in required]},
            )
            repairs += 1
            i += 2
            continue

        # tool_use ids already answered by the next user message.
        present = {
            b.get("tool_use_id")
            for b in nc
            if isinstance(b, dict) and b.get("type") == "tool_result" and b.get("tool_use_id")
        }
        if req_set <= present:
            # Every tool_use is answered: nothing to repair here.
            i += 1
            continue

        # Some results are missing: prepend synthetic ones, keeping the
        # existing blocks after them.
        missing = [tid for tid in required if tid not in present]
        nxt["content"] = [_synth_block(tid) for tid in missing] + nc
        logger.warning(
            "⚠️ Prepended synthetic tool_result for Anthropic adjacency "
            f"(missing_ids={missing})"
        )
        repairs += len(missing)
        i += 1

    return repairs
|
||||
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# Claude-format sanitizer (used by agent_stream)
|
||||
# ------------------------------------------------------------------ #
|
||||
|
||||
def sanitize_claude_messages(messages: List[Dict]) -> int:
    """
    Validate and fix a Claude-format message list **in-place**.

    Fixes handled:
    - Anthropic adjacency: assistant tool_use must be immediately followed by
      a user message containing matching tool_result blocks
    - Leading user messages that contain tool_result blocks but no text block
    - tool_use / tool_result blocks whose id has no counterpart anywhere in
      the list

    NOTE(review): step 3 matches ids across the whole list, not by position,
    so a tool_result that appears before its tool_use is not flagged by it.

    :param messages: Claude-format message list, modified in place
    :return: Number of removals (messages and individual blocks) plus the
             number of adjacency repair operations (inserts/prepends)
    """
    if not messages:
        return 0

    removed = 0

    # 1. Adjacency repair (Anthropic: tool_result must be in the next user message)
    adj_repairs = _repair_tool_use_adjacency(messages)

    # 2. Remove leading orphaned tool_result user messages
    while messages:
        first = messages[0]
        if first.get("role") != "user":
            break
        content = first.get("content", [])
        # A first message has no preceding assistant, so any tool_result in it
        # is orphaned. Messages that also carry text are kept.
        if isinstance(content, list) and _has_block_type(content, "tool_result") \
                and not _has_block_type(content, "text"):
            logger.warning("⚠️ Removing leading orphaned tool_result user message")
            messages.pop(0)
            removed += 1
        else:
            break

    # 3. Iteratively remove unmatched tool_use / tool_result until stable.
    #    Removing one broken message can orphan others (e.g. an assistant msg
    #    with both matched and unmatched tool_use — deleting it orphans the
    #    previously-matched tool_result). Loop until clean, at most 5 passes.
    for _ in range(5):
        # Collect every tool_use id and every tool_result reference in the list.
        use_ids: Set[str] = set()
        result_ids: Set[str] = set()
        for msg in messages:
            for block in (msg.get("content") or []):
                # String content yields characters here; they are skipped
                # because they are not dicts.
                if not isinstance(block, dict):
                    continue
                if block.get("type") == "tool_use" and block.get("id"):
                    use_ids.add(block["id"])
                elif block.get("type") == "tool_result" and block.get("tool_use_id"):
                    result_ids.add(block["tool_use_id"])

        bad_use = use_ids - result_ids      # tool_use with no result anywhere
        bad_result = result_ids - use_ids   # tool_result with no tool_use anywhere
        if not bad_use and not bad_result:
            break

        pass_removed = 0
        i = 0
        while i < len(messages):
            msg = messages[i]
            role = msg.get("role")
            content = msg.get("content", [])
            if not isinstance(content, list):
                i += 1
                continue

            # An assistant message with any unmatched tool_use is dropped whole.
            if role == "assistant" and bad_use and any(
                isinstance(b, dict) and b.get("type") == "tool_use"
                and b.get("id") in bad_use for b in content
            ):
                logger.warning(f"⚠️ Removing assistant msg with unmatched tool_use")
                messages.pop(i)
                pass_removed += 1
                # No increment: the next message has shifted into index i.
                continue

            if role == "user" and bad_result and _has_block_type(content, "tool_result"):
                has_bad = any(
                    isinstance(b, dict) and b.get("type") == "tool_result"
                    and b.get("tool_use_id") in bad_result for b in content
                )
                if has_bad:
                    if not _has_block_type(content, "text"):
                        # No text block: drop the whole message.
                        logger.warning(f"⚠️ Removing user msg with unmatched tool_result")
                        messages.pop(i)
                        pass_removed += 1
                        continue
                    else:
                        # Keep the message but strip only the orphaned
                        # tool_result blocks; each one counts as a removal.
                        before = len(content)
                        msg["content"] = [
                            b for b in content
                            if not (isinstance(b, dict) and b.get("type") == "tool_result"
                                    and b.get("tool_use_id") in bad_result)
                        ]
                        pass_removed += before - len(msg["content"])

            i += 1

        removed += pass_removed
        if pass_removed == 0:
            # Nothing changed in this pass; further passes would not either.
            break

    # 4. Removals above can break adjacency; re-run repair only if something was removed.
    if removed:
        adj_repairs += _repair_tool_use_adjacency(messages)

    if removed:
        logger.info(f"🔧 Message validation: removed {removed} broken message(s)")
    if adj_repairs:
        logger.info(f"🔧 Message validation: adjacency repairs={adj_repairs}")
    return removed + adj_repairs
|
||||
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# OpenAI-format sanitizer (used by minimax_bot, openai_compatible_bot)
|
||||
# ------------------------------------------------------------------ #
|
||||
|
||||
def drop_orphaned_tool_results_openai(messages: List[Dict]) -> List[Dict]:
    """
    Filter out ``role=tool`` messages that answer no known tool call.

    Walks *messages* (OpenAI format) in order, collecting every non-empty
    ``tool_calls[].id`` seen on assistant messages. A tool message is dropped
    (with a warning) when its non-empty ``tool_call_id`` has not been seen so
    far. Tool messages without a ``tool_call_id`` are kept.

    Returns a new list; the input list is not modified and the message dicts
    themselves are shared, not copied.
    """
    seen_call_ids: Set[str] = set()
    kept: List[Dict] = []

    for message in messages:
        role = message.get("role")

        if role == "assistant":
            # Register ids before any later tool message refers to them.
            for call in message.get("tool_calls") or ():
                call_id = call.get("id", "")
                if call_id:
                    seen_call_ids.add(call_id)
        elif role == "tool":
            ref_id = message.get("tool_call_id", "")
            if ref_id and ref_id not in seen_call_ids:
                logger.warning(
                    f"[MessageSanitizer] Dropping orphaned tool result "
                    f"(tool_call_id={ref_id} not in known ids)"
                )
                continue

        kept.append(message)

    return kept
|
||||
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# Internal helpers
|
||||
# ------------------------------------------------------------------ #
|
||||
|
||||
def _has_block_type(content: list, block_type: str) -> bool:
|
||||
return any(
|
||||
isinstance(b, dict) and b.get("type") == block_type
|
||||
for b in content
|
||||
)
|
||||
|
||||
|
||||
def _extract_text_from_content(content) -> str:
|
||||
"""Extract plain text from a message content field (str or list of blocks)."""
|
||||
if isinstance(content, str):
|
||||
return content.strip()
|
||||
if isinstance(content, list):
|
||||
parts = [
|
||||
b.get("text", "")
|
||||
for b in content
|
||||
if isinstance(b, dict) and b.get("type") == "text"
|
||||
]
|
||||
return "\n".join(p for p in parts if p).strip()
|
||||
return ""
|
||||
|
||||
|
||||
def compress_turn_to_text_only(turn: Dict) -> Dict:
    """
    Reduce a turn to at most one user text message and one assistant text message.

    The user message carries the first non-empty text found in a user message
    that has no ``tool_result`` block; the assistant message carries the last
    non-empty assistant text. Everything else in the turn (tool_use /
    tool_result chains, intermediate assistant text) is left out.

    Returns a new dict with a ``messages`` list; *turn* is not mutated.
    """
    first_user_text = ""
    final_assistant_text = ""

    for message in turn["messages"]:
        role = message.get("role")
        body = message.get("content", [])

        if role == "assistant":
            candidate = _extract_text_from_content(body)
            if candidate:
                # Later assistant text overrides earlier text.
                final_assistant_text = candidate
            continue

        if role != "user":
            continue

        # User messages that carry tool results are not real user input.
        if isinstance(body, list) and _has_block_type(body, "tool_result"):
            continue
        if not first_user_text:
            first_user_text = _extract_text_from_content(body)

    compressed = [
        {"role": speaker, "content": [{"type": "text", "text": text}]}
        for speaker, text in (
            ("user", first_user_text),
            ("assistant", final_assistant_text),
        )
        if text
    ]
    return {"messages": compressed}
|
||||
57
agent/protocol/models.py
Normal file
57
agent/protocol/models.py
Normal file
@@ -0,0 +1,57 @@
|
||||
"""
|
||||
Models module for agent system.
|
||||
Provides basic model classes needed by tools and bridge integration.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
|
||||
class LLMRequest:
    """Plain container for the parameters of one LLM call."""

    def __init__(self, messages: List[Dict[str, str]] = None, model: Optional[str] = None,
                 temperature: float = 0.7, max_tokens: Optional[int] = None,
                 stream: bool = False, tools: Optional[List] = None, **kwargs):
        """
        Store the request parameters as attributes.

        A falsy ``messages`` becomes a fresh empty list. Every extra keyword
        argument is set as an attribute of the same name.
        """
        self.messages = messages if messages else []
        self.model = model
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.stream = stream
        self.tools = tools
        # Extra attributes are applied last, so they can override the ones above.
        for extra_name in kwargs:
            setattr(self, extra_name, kwargs[extra_name])
|
||||
|
||||
|
||||
class LLMModel:
    """Base class for LLM models; both call methods are unimplemented placeholders."""

    def __init__(self, model: str = None, **kwargs):
        """Remember the model name and keep all extra keyword arguments as ``config``."""
        self.model = model
        self.config = kwargs

    def call(self, request: LLMRequest):
        """
        Call the model with a request.

        Placeholder: always raises NotImplementedError.
        """
        raise NotImplementedError("LLMModel.call not implemented in this context")

    def call_stream(self, request: LLMRequest):
        """
        Call the model with streaming.

        Placeholder: always raises NotImplementedError.
        """
        raise NotImplementedError("LLMModel.call_stream not implemented in this context")
|
||||
|
||||
|
||||
class ModelFactory:
    """Factory for model instances; creation is an unimplemented placeholder."""

    @staticmethod
    def create_model(model_type: str, **kwargs):
        """
        Create a model instance for *model_type*.

        Placeholder: always raises NotImplementedError.
        """
        raise NotImplementedError("ModelFactory.create_model not implemented in this context")
|
||||
97
agent/protocol/result.py
Normal file
97
agent/protocol/result.py
Normal file
@@ -0,0 +1,97 @@
|
||||
from __future__ import annotations
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import List, Dict, Any, Optional
|
||||
|
||||
from agent.protocol.task import Task, TaskStatus
|
||||
|
||||
|
||||
class AgentActionType(Enum):
    """Kinds of action an agent can record: tool use, thinking, or final answer."""

    TOOL_USE = "tool_use"
    THINKING = "thinking"
    FINAL_ANSWER = "final_answer"
|
||||
|
||||
|
||||
@dataclass
class ToolResult:
    """
    Outcome of a single tool invocation.

    Attributes:
        tool_name: Name of the tool used
        input_params: Parameters passed to the tool
        output: Output from the tool
        status: Status of the tool execution (success/error)
        error_message: Error message if the tool execution failed (default None)
        execution_time: Time taken to execute the tool (default 0.0)
    """

    # Required fields
    tool_name: str
    input_params: Dict[str, Any]
    output: Any
    status: str

    # Optional fields
    error_message: Optional[str] = None
    execution_time: float = 0.0
|
||||
|
||||
|
||||
@dataclass
class AgentAction:
    """
    A single action taken by an agent.

    Attributes:
        agent_id: ID of the agent that performed the action
        agent_name: Name of the agent that performed the action
        action_type: Type of action (tool use, thinking, final answer)
        id: Unique identifier for the action (random UUID string by default)
        content: Content of the action (thought content, final answer content)
        tool_result: Tool use details if action_type is TOOL_USE
        thought: Optional thought text attached to the action (default None)
        timestamp: When the action was performed (``time.time()`` by default)
    """

    # Required fields
    agent_id: str
    agent_name: str
    action_type: AgentActionType

    # Fields with defaults
    id: str = field(default_factory=lambda: str(uuid.uuid4()))
    content: str = ""
    tool_result: Optional[ToolResult] = None
    thought: Optional[str] = None
    timestamp: float = field(default_factory=time.time)
|
||||
|
||||
|
||||
@dataclass
class AgentResult:
    """
    Final outcome of an agent run.

    Attributes:
        final_answer: The final answer provided by the agent
        step_count: Number of steps taken by the agent
        status: Status of the execution (success/error), "success" by default
        error_message: Error message if execution failed
    """

    final_answer: str
    step_count: int
    status: str = "success"
    error_message: Optional[str] = None

    @classmethod
    def success(cls, final_answer: str, step_count: int) -> "AgentResult":
        """Build a result with the default "success" status."""
        outcome = cls(final_answer=final_answer, step_count=step_count)
        return outcome

    @classmethod
    def error(cls, error_message: str, step_count: int = 0) -> "AgentResult":
        """Build an "error" result whose final answer is ``Error: <message>``."""
        failure_fields = {
            "final_answer": f"Error: {error_message}",
            "step_count": step_count,
            "status": "error",
            "error_message": error_message,
        }
        return cls(**failure_fields)

    @property
    def is_error(self) -> bool:
        """True when ``status`` equals "error"."""
        return self.status == "error"
|
||||
96
agent/protocol/task.py
Normal file
96
agent/protocol/task.py
Normal file
@@ -0,0 +1,96 @@
|
||||
from __future__ import annotations
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Dict, Any, List
|
||||
|
||||
|
||||
class TaskType(Enum):
    """Content kinds a task can carry."""

    TEXT = "text"
    IMAGE = "image"
    VIDEO = "video"
    AUDIO = "audio"
    FILE = "file"
    MIXED = "mixed"
|
||||
|
||||
|
||||
class TaskStatus(Enum):
    """Lifecycle states of a task."""

    INIT = "init"  # Initial state
    PROCESSING = "processing"  # In progress
    COMPLETED = "completed"  # Completed
    FAILED = "failed"  # Failed
|
||||
|
||||
|
||||
@dataclass
class Task:
    """
    A unit of work to be processed by an agent.

    Attributes:
        id: Unique identifier for the task
        content: The primary text content of the task
        type: Type of the task
        status: Current status of the task
        created_at: Timestamp when the task was created
        updated_at: Timestamp when the task was last updated
        metadata: Additional metadata for the task
        images: List of image URLs or base64 encoded images
        videos: List of video URLs
        audios: List of audio URLs or base64 encoded audios
        files: List of file URLs or paths
    """

    id: str = field(default_factory=lambda: str(uuid.uuid4()))
    content: str = ""
    type: TaskType = TaskType.TEXT
    status: TaskStatus = TaskStatus.INIT
    created_at: float = field(default_factory=time.time)
    updated_at: float = field(default_factory=time.time)
    metadata: Dict[str, Any] = field(default_factory=dict)

    # Media content
    images: List[str] = field(default_factory=list)
    videos: List[str] = field(default_factory=list)
    audios: List[str] = field(default_factory=list)
    files: List[str] = field(default_factory=list)

    def __init__(self, content: str = "", **kwargs):
        """
        Initialize a Task from its text content plus optional overrides.

        This hand-written initializer is the one in effect for the class.
        Keyword arguments other than the field names listed below are ignored.

        Args:
            content: The text content of the task
            **kwargs: Optional values for id, type, status, created_at,
                updated_at, metadata, images, videos, audios, files
        """
        # Fallbacks are built per call, so mutable ones are never shared.
        fallbacks = {
            'id': str(uuid.uuid4()),
            'type': TaskType.TEXT,
            'status': TaskStatus.INIT,
            'created_at': time.time(),
            'updated_at': time.time(),
            'metadata': {},
            'images': [],
            'videos': [],
            'audios': [],
            'files': [],
        }
        self.content = content
        for attr_name, fallback in fallbacks.items():
            setattr(self, attr_name, kwargs.get(attr_name, fallback))

    def get_text(self) -> str:
        """
        Get the text content of the task.

        Returns:
            The text content
        """
        return self.content

    def update_status(self, status: TaskStatus) -> None:
        """
        Set a new status and refresh ``updated_at`` to the current time.

        Args:
            status: The new status
        """
        self.status = status
        self.updated_at = time.time()
|
||||
31
agent/skills/__init__.py
Normal file
31
agent/skills/__init__.py
Normal file
@@ -0,0 +1,31 @@
|
||||
"""
|
||||
Skills module for agent system.
|
||||
|
||||
This module provides the framework for loading, managing, and executing skills.
|
||||
Skills are markdown files with frontmatter that provide specialized instructions
|
||||
for specific tasks.
|
||||
"""
|
||||
|
||||
from agent.skills.types import (
|
||||
Skill,
|
||||
SkillEntry,
|
||||
SkillMetadata,
|
||||
SkillInstallSpec,
|
||||
LoadSkillsResult,
|
||||
)
|
||||
from agent.skills.loader import SkillLoader
|
||||
from agent.skills.manager import SkillManager
|
||||
from agent.skills.service import SkillService
|
||||
from agent.skills.formatter import format_skills_for_prompt
|
||||
|
||||
__all__ = [
|
||||
"Skill",
|
||||
"SkillEntry",
|
||||
"SkillMetadata",
|
||||
"SkillInstallSpec",
|
||||
"LoadSkillsResult",
|
||||
"SkillLoader",
|
||||
"SkillManager",
|
||||
"SkillService",
|
||||
"format_skills_for_prompt",
|
||||
]
|
||||
230
agent/skills/config.py
Normal file
230
agent/skills/config.py
Normal file
@@ -0,0 +1,230 @@
|
||||
"""
|
||||
Configuration support for skills.
|
||||
"""
|
||||
|
||||
import os
|
||||
import platform
|
||||
from typing import Dict, Optional, List
|
||||
from agent.skills.types import SkillEntry
|
||||
|
||||
|
||||
def resolve_runtime_platform() -> str:
    """Return ``platform.system()`` in lower case."""
    system_name = platform.system()
    return system_name.lower()
|
||||
|
||||
|
||||
def has_binary(bin_name: str) -> bool:
    """
    Tell whether an executable can be located via ``shutil.which``.

    :param bin_name: Binary name to check
    :return: True if binary is available
    """
    from shutil import which

    located = which(bin_name)
    return located is not None
|
||||
|
||||
|
||||
def has_any_binary(bin_names: List[str]) -> bool:
    """
    Tell whether at least one of the given binaries is available.

    :param bin_names: List of binary names to check
    :return: True if at least one binary is available; False for an empty list
    """
    for candidate in bin_names:
        if has_binary(candidate):
            return True
    return False
|
||||
|
||||
|
||||
def has_env_var(env_name: str) -> bool:
    """
    Tell whether an environment variable holds a non-blank value.

    :param env_name: Environment variable name
    :return: True if the variable exists and is not empty or whitespace-only
    """
    # A missing variable is treated like an empty one.
    raw_value = os.environ.get(env_name, "")
    return bool(raw_value.strip())
|
||||
|
||||
|
||||
def get_skill_config(config: Optional[Dict], skill_name: str) -> Optional[Dict]:
    """
    Look up ``config['skills']['entries'][skill_name]``.

    :param config: Global configuration dictionary
    :param skill_name: Name of the skill
    :return: Skill configuration, or None when the config is empty, either
        intermediate level is not a dict, or the skill has no entry
    """
    if not config:
        return None

    node = config
    for section in ('skills', 'entries'):
        # A missing section behaves like an empty dict.
        node = node.get(section, {})
        if not isinstance(node, dict):
            return None

    return node.get(skill_name)
|
||||
|
||||
|
||||
def should_include_skill(
    entry: SkillEntry,
    config: Optional[Dict] = None,
    current_platform: Optional[str] = None,
) -> bool:
    """
    Decide whether a skill's declared requirements are satisfied.

    Checks run in this order:
    1. No metadata: include.
    2. ``metadata.os`` set and the platform is not listed: exclude.
    3. ``metadata.always`` set: include, skipping the remaining checks.
    4. ``requires.bins``: every binary must be available.
    5. ``requires.anyBins``: at least one binary must be available.
    6. ``requires.env``: every environment variable must be set.
    7. ``requires.anyEnv``: at least one environment variable must be set.

    :param entry: SkillEntry to check
    :param config: Configuration dictionary (not read by this function)
    :param current_platform: Current platform (default: auto-detect)
    :return: True if skill should be included
    """
    meta = entry.metadata
    if not meta:
        return True

    if meta.os:
        detected = current_platform or resolve_runtime_platform()
        # 'windows' is compared as 'win32'; other names pass through unchanged.
        aliases = {
            'darwin': 'darwin',
            'linux': 'linux',
            'windows': 'win32',
        }
        if aliases.get(detected, detected) not in meta.os:
            return False

    # 'always' overrides binary and environment requirements, not the OS check.
    if meta.always:
        return True

    requires = meta.requires
    if not requires:
        return True

    needed_bins = requires.get('bins', [])
    if needed_bins and not all(has_binary(name) for name in needed_bins):
        return False

    alternative_bins = requires.get('anyBins', [])
    if alternative_bins and not has_any_binary(alternative_bins):
        return False

    needed_env = requires.get('env', [])
    if needed_env and not all(has_env_var(name) for name in needed_env):
        return False

    alternative_env = requires.get('anyEnv', [])
    if alternative_env and not any(has_env_var(name) for name in alternative_env):
        return False

    return True
|
||||
|
||||
|
||||
def get_missing_requirements(
    entry: SkillEntry,
    current_platform: Optional[str] = None,
) -> Dict[str, List[str]]:
    """
    Report which declared requirements of a skill are not met.

    ``bins`` and ``env`` list only the missing items. ``anyBins`` and
    ``anyEnv`` list all alternatives when none of them is available.
    An empty dict means all requirements are met.

    :param entry: SkillEntry to check
    :param current_platform: Current platform (not read by this function)
    :return: Dict like {"bins": ["curl"], "env": ["API_KEY"]}
    """
    report: Dict[str, List[str]] = {}
    meta = entry.metadata

    if not meta or not meta.requires:
        return report

    requires = meta.requires

    needed_bins = requires.get('bins', [])
    if needed_bins:
        absent_bins = [name for name in needed_bins if not has_binary(name)]
        if absent_bins:
            report['bins'] = absent_bins

    alternative_bins = requires.get('anyBins', [])
    if alternative_bins and not has_any_binary(alternative_bins):
        report['anyBins'] = alternative_bins

    needed_env = requires.get('env', [])
    if needed_env:
        absent_env = [name for name in needed_env if not has_env_var(name)]
        if absent_env:
            report['env'] = absent_env

    alternative_env = requires.get('anyEnv', [])
    if alternative_env and not any(has_env_var(name) for name in alternative_env):
        report['anyEnv'] = alternative_env

    return report
|
||||
|
||||
|
||||
def is_config_path_truthy(config: Dict, path: str) -> bool:
    """
    Tell whether the value at a dotted config path counts as truthy.

    A path that cannot be followed (non-dict on the way, missing key or None
    value) is False. Strings are truthy only when non-blank, numbers when
    non-zero; anything else uses normal Python truthiness.

    :param config: Configuration dictionary
    :param path: Dot-separated path (e.g., 'skills.enabled')
    :return: True if path resolves to truthy value
    """
    node = config
    for segment in path.split('.'):
        if not isinstance(node, dict):
            return False
        node = node.get(segment)
        if node is None:
            return False

    if isinstance(node, str):
        return bool(node.strip())
    # bool is tested before the numeric case because bool is a subclass of int.
    if isinstance(node, bool):
        return node
    if isinstance(node, (int, float)):
        return node != 0
    return bool(node)
|
||||
|
||||
|
||||
def resolve_config_path(config: Dict, path: str):
    """
    Follow a dot-separated path through nested dicts.

    :param config: Configuration dictionary
    :param path: Dot-separated path
    :return: Value at path, or None when a non-dict is met on the way or a
        key is missing or maps to None
    """
    node = config
    segments = path.split('.')
    position = 0

    while position < len(segments):
        if not isinstance(node, dict):
            return None
        node = node.get(segments[position])
        if node is None:
            return None
        position += 1

    return node
|
||||
126
agent/skills/formatter.py
Normal file
126
agent/skills/formatter.py
Normal file
@@ -0,0 +1,126 @@
|
||||
"""
|
||||
Skill formatter for generating prompts from skills.
|
||||
"""
|
||||
|
||||
from typing import Dict, List
|
||||
from agent.skills.types import Skill, SkillEntry
|
||||
|
||||
|
||||
def format_skills_for_prompt(skills: List[Skill]) -> str:
    """
    Render skills as an ``<available_skills>`` XML block for a system prompt.

    Skills with disable_model_invocation=True are left out. Name, description,
    file path and base directory are passed through ``_escape_xml``.

    :param skills: List of skills to format
    :return: Formatted prompt text, or an empty string when no skill is visible
    """
    visible = [skill for skill in skills if not skill.disable_model_invocation]
    if not visible:
        return ""

    # The leading empty entry makes the block start with a newline.
    out = ["", "<available_skills>"]
    for skill in visible:
        out.extend([
            " <skill>",
            f" <name>{_escape_xml(skill.name)}</name>",
            f" <description>{_escape_xml(skill.description)}</description>",
            f" <location>{_escape_xml(skill.file_path)}</location>",
            f" <base_dir>{_escape_xml(skill.base_dir)}</base_dir>",
            " </skill>",
        ])
    out.append("</available_skills>")

    return "\n".join(out)
|
||||
|
||||
|
||||
def format_skill_entries_for_prompt(entries: List[SkillEntry]) -> str:
    """
    Format the skills wrapped by *entries* for a system prompt.

    Takes ``entry.skill`` from each entry and delegates to
    ``format_skills_for_prompt``.

    :param entries: List of skill entries to format
    :return: Formatted prompt text
    """
    return format_skills_for_prompt([item.skill for item in entries])
|
||||
|
||||
|
||||
def format_unavailable_skills_for_prompt(
    entries: List[SkillEntry],
    missing_map: Dict[str, Dict[str, List[str]]],
) -> str:
    """
    Render unavailable skills as an ``<unavailable_skills>`` XML block with
    brief setup hints, so the AI can guide users to configure them.

    Each skill lists its missing requirements as ``key: a, b; key2: c``
    ("unknown" when the map has nothing for it) and, when one is found, a
    setup hint taken from the skill content.

    :param entries: List of unavailable skill entries
    :param missing_map: Dict mapping skill name to its missing requirements
    :return: Formatted prompt text, or an empty string when *entries* is empty
    """
    if not entries:
        return ""

    out = [
        "",
        "<unavailable_skills>",
        "The following skills are installed but not yet ready. "
        "Guide the user to complete the setup when relevant.",
    ]

    for item in entries:
        skill = item.skill
        gaps = missing_map.get(skill.name, {})

        gap_descriptions = [
            f"{key}: {', '.join(values)}" for key, values in gaps.items()
        ]
        if gap_descriptions:
            missing_str = "; ".join(gap_descriptions)
        else:
            missing_str = "unknown"

        setup_hint = _extract_setup_hint(skill)

        out.append(" <skill>")
        out.append(f" <name>{_escape_xml(skill.name)}</name>")
        out.append(f" <description>{_escape_xml(skill.description)}</description>")
        out.append(f" <missing>{_escape_xml(missing_str)}</missing>")
        if setup_hint:
            out.append(f" <setup>{_escape_xml(setup_hint)}</setup>")
        out.append(" </skill>")

    out.append("</unavailable_skills>")
    return "\n".join(out)
|
||||
|
||||
|
||||
def _extract_setup_hint(skill: Skill) -> str:
    """
    Build a short one-line hint from the ``## Setup`` section of the skill content.

    Takes the first six lines of the section, drops blank ones, joins the rest
    with single spaces and truncates to 300 characters. Returns an empty
    string when the skill has no content or no ``## Setup`` heading.
    """
    import re

    body = skill.content
    if not body:
        return ""

    # The section ends at the next '## ' heading or at the end of the content.
    section = re.search(r'^##\s+Setup\s*\n(.*?)(?=\n##\s|\Z)', body, re.MULTILINE | re.DOTALL)
    if not section:
        return ""

    leading_lines = section.group(1).strip().split('\n')[:6]
    kept = []
    for raw_line in leading_lines:
        trimmed = raw_line.strip()
        if trimmed:
            kept.append(trimmed)

    return ' '.join(kept)[:300]
|
||||
|
||||
|
||||
def _escape_xml(text: str) -> str:
|
||||
"""Escape XML special characters."""
|
||||
return (text
|
||||
.replace('&', '&')
|
||||
.replace('<', '<')
|
||||
.replace('>', '>')
|
||||
.replace('"', '"')
|
||||
.replace("'", '''))
|
||||
192
agent/skills/frontmatter.py
Normal file
192
agent/skills/frontmatter.py
Normal file
@@ -0,0 +1,192 @@
|
||||
"""
|
||||
Frontmatter parsing for skills.
|
||||
"""
|
||||
|
||||
import re
|
||||
import json
|
||||
from typing import Dict, Any, Optional, List
|
||||
from agent.skills.types import SkillMetadata, SkillInstallSpec
|
||||
|
||||
|
||||
def parse_frontmatter(content: str) -> Dict[str, Any]:
    """
    Parse YAML-style frontmatter from markdown content.

    The frontmatter is the block between a leading ``---`` line and the next
    ``---`` line. The closing marker may be followed by a newline or by the
    end of the input. PyYAML is used when it can be imported and parses the
    block; otherwise a minimal ``key: value`` line parser is used.

    Fallback parser rules, per line:
    - blank lines, ``#`` comments and lines without ``:`` are skipped
    - a value wrapped in matching single or double quotes stays a string,
      with the quotes removed
    - a value starting with ``{`` or ``[`` is decoded as JSON when valid
    - ``true`` / ``false`` (any case) become booleans
    - an all-decimal-digit value becomes an int

    :param content: Full markdown text
    :return: Dictionary of frontmatter fields; empty when there is no
        frontmatter or it does not parse to a mapping
    """
    # Accept the closing '---' at end of input as well as before a newline,
    # so a file that consists of frontmatter only is still recognised.
    match = re.match(r'^---\s*\n(.*?)\n---\s*(?:\n|\Z)', content, re.DOTALL)
    if not match:
        return {}

    frontmatter_text = match.group(1)

    try:
        import yaml
        parsed = yaml.safe_load(frontmatter_text)
        return parsed if isinstance(parsed, dict) else {}
    except Exception:
        # Deliberate best effort: PyYAML missing (ImportError) or the block
        # not being valid YAML both fall through to the simple parser below.
        pass

    frontmatter: Dict[str, Any] = {}
    for raw_line in frontmatter_text.split('\n'):
        line = raw_line.strip()
        if not line or line.startswith('#') or ':' not in line:
            continue

        key, value = line.split(':', 1)
        key = key.strip()
        value = value.strip()

        if len(value) >= 2 and value[0] == value[-1] and value[0] in ('"', "'"):
            # Quoted scalar: keep it a string, as a YAML parser would.
            value = value[1:-1]
        elif value.startswith(('{', '[')):
            try:
                value = json.loads(value)
            except json.JSONDecodeError:
                # Not valid JSON; keep the raw text.
                pass
        elif value.lower() in ('true', 'false'):
            value = value.lower() == 'true'
        elif value.isdecimal():
            # isdecimal(), unlike isdigit(), only accepts characters int() can parse.
            value = int(value)

        frontmatter[key] = value

    return frontmatter
|
||||
|
||||
|
||||
def parse_metadata(frontmatter: Dict[str, Any]) -> Optional[SkillMetadata]:
    """
    Build a SkillMetadata from the ``metadata`` field of parsed frontmatter.

    The field may be a dict or a JSON string that decodes to a dict. A
    single-key wrapper for a known namespace is unwrapped first. Install
    specs without a usable ``kind`` (or ``type``) are skipped, and every
    ``requires`` value is normalized to a list of strings.

    :param frontmatter: Parsed frontmatter dictionary
    :return: SkillMetadata, or None when the field is missing, empty, not
        valid JSON, or not a dict
    """
    metadata_raw = frontmatter.get('metadata')
    if not metadata_raw:
        return None

    # If it's a string, try to parse as JSON
    if isinstance(metadata_raw, str):
        try:
            metadata_raw = json.loads(metadata_raw)
        except json.JSONDecodeError:
            return None

    if not isinstance(metadata_raw, dict):
        return None

    # Unwrap nested namespace (e.g. {"openclaw": {...}} or {"cowagent": {...}})
    meta_obj = _unwrap_metadata_namespace(metadata_raw)

    # Parse install specs
    install_specs = []
    install_raw = meta_obj.get('install', [])
    if isinstance(install_raw, list):
        for spec_raw in install_raw:
            if not isinstance(spec_raw, dict):
                continue

            # 'kind' wins over 'type'. An explicit null or empty 'kind' falls
            # back to 'type', and str() guards against non-string values so
            # .lower() cannot raise AttributeError.
            raw_kind = spec_raw.get('kind') or spec_raw.get('type') or ''
            kind = str(raw_kind).lower()
            if not kind:
                continue

            install_specs.append(SkillInstallSpec(
                kind=kind,
                id=spec_raw.get('id'),
                label=spec_raw.get('label'),
                bins=_normalize_string_list(spec_raw.get('bins')),
                os=_normalize_string_list(spec_raw.get('os')),
                formula=spec_raw.get('formula'),
                package=spec_raw.get('package'),
                module=spec_raw.get('module'),
                url=spec_raw.get('url'),
                archive=spec_raw.get('archive'),
                extract=spec_raw.get('extract', False),
                strip_components=spec_raw.get('stripComponents'),
                target_dir=spec_raw.get('targetDir'),
            ))

    # Parse requires
    requires = {}
    requires_raw = meta_obj.get('requires', {})
    if isinstance(requires_raw, dict):
        requires = {
            key: _normalize_string_list(value)
            for key, value in requires_raw.items()
        }

    return SkillMetadata(
        always=meta_obj.get('always', False),
        default_enabled=meta_obj.get('default_enabled', True),
        skill_key=meta_obj.get('skillKey'),
        primary_env=meta_obj.get('primaryEnv'),
        emoji=meta_obj.get('emoji'),
        homepage=meta_obj.get('homepage'),
        os=_normalize_string_list(meta_obj.get('os')),
        requires=requires,
        install=install_specs,
    )
|
||||
|
||||
|
||||
_KNOWN_METADATA_NAMESPACES = {"cowagent", "openclaw"}
|
||||
|
||||
|
||||
def _unwrap_metadata_namespace(metadata_raw: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Unwrap a single-key namespace wrapper like {"cowagent": {...} or {"openclaw": {...}}}.
|
||||
If the top-level dict has exactly one key matching a known namespace, return the inner dict.
|
||||
Otherwise return the original dict unchanged.
|
||||
"""
|
||||
keys = set(metadata_raw.keys())
|
||||
ns_keys = keys & _KNOWN_METADATA_NAMESPACES
|
||||
if len(ns_keys) == 1 and len(keys) == 1:
|
||||
ns = ns_keys.pop()
|
||||
inner = metadata_raw[ns]
|
||||
if isinstance(inner, dict):
|
||||
return inner
|
||||
return metadata_raw
|
||||
|
||||
|
||||
def _normalize_string_list(value: Any) -> List[str]:
|
||||
"""Normalize a value to a list of strings."""
|
||||
if not value:
|
||||
return []
|
||||
|
||||
if isinstance(value, list):
|
||||
return [str(v).strip() for v in value if v]
|
||||
|
||||
if isinstance(value, str):
|
||||
return [v.strip() for v in value.split(',') if v.strip()]
|
||||
|
||||
return []
|
||||
|
||||
|
||||
def parse_boolean_value(value: Optional[str], default: bool = False) -> bool:
    """
    Interpret a frontmatter value as a boolean.

    A bool is returned unchanged. A string is True when, lower-cased, it is
    one of 'true', '1', 'yes' or 'on'. None and values of any other type
    give *default*.
    """
    if isinstance(value, bool):
        return value
    if isinstance(value, str):
        return value.lower() in ('true', '1', 'yes', 'on')
    # Covers None as well as unsupported types.
    return default
|
||||
|
||||
|
||||
def get_frontmatter_value(frontmatter: Dict[str, Any], key: str) -> Optional[str]:
    """Return ``str(frontmatter[key])``, or None when the key is missing or maps to None."""
    raw = frontmatter.get(key)
    if raw is None:
        return None
    return str(raw)
|
||||
286
agent/skills/loader.py
Normal file
286
agent/skills/loader.py
Normal file
@@ -0,0 +1,286 @@
|
||||
"""
|
||||
Skill loader for discovering and loading skills from directories.
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Dict
|
||||
from common.log import logger
|
||||
from agent.skills.types import Skill, SkillEntry, LoadSkillsResult, SkillMetadata
|
||||
from agent.skills.frontmatter import parse_frontmatter, parse_metadata, parse_boolean_value, get_frontmatter_value
|
||||
|
||||
|
||||
class SkillLoader:
    """
    Loads skills from various directories.

    A skill is a markdown file whose frontmatter (as returned by
    ``parse_frontmatter``) provides at least a non-empty description.
    Two layouts are discovered: a ``SKILL.md`` inside a sub-directory of the
    scan root, and a bare ``<name>.md`` placed directly in the scan root.
    """

    # Entry names that are never scanned. Dot-prefixed entries (``.git`` ...)
    # are rejected by a separate check, so they are not listed here.
    _IGNORED_ENTRY_NAMES = frozenset({'node_modules', '__pycache__', 'venv'})

    def __init__(self):
        pass

    def load_skills_from_dir(self, dir_path: str, source: str) -> LoadSkillsResult:
        """
        Load skills from a directory.

        Discovery rules:
        - Direct .md files in the root directory
        - Recursive SKILL.md files under subdirectories

        :param dir_path: Directory path to scan
        :param source: Source identifier ('builtin' or 'custom')
        :return: LoadSkillsResult with skills and diagnostics
        """
        if not os.path.exists(dir_path):
            return LoadSkillsResult(
                skills=[],
                diagnostics=[f"Directory does not exist: {dir_path}"],
            )

        if not os.path.isdir(dir_path):
            return LoadSkillsResult(
                skills=[],
                diagnostics=[f"Path is not a directory: {dir_path}"],
            )

        # Load skills from root-level .md files and subdirectories
        return self._load_skills_recursive(dir_path, source, include_root_files=True)

    def _load_skills_recursive(
        self,
        dir_path: str,
        source: str,
        include_root_files: bool = False
    ) -> LoadSkillsResult:
        """
        Recursively load skills from a directory.

        If a subdirectory contains its own SKILL.md, it is treated as a
        self-contained skill (or skill-collection) and its children are
        NOT scanned further. This prevents sub-skills inside a collection
        (e.g. style-collection/style-anjing) from being listed as
        independent top-level skills.

        :param dir_path: Directory to scan
        :param source: Source identifier
        :param include_root_files: Whether to include root-level .md files
        :return: LoadSkillsResult
        """
        skills = []
        diagnostics = []

        try:
            # os.listdir() returns names in arbitrary order. Sorting keeps the
            # result (and the winner among same-named skills) deterministic.
            entries = sorted(os.listdir(dir_path))
        except OSError as e:
            diagnostics.append(f"Failed to list directory {dir_path}: {e}")
            return LoadSkillsResult(skills=skills, diagnostics=diagnostics)

        # If this directory has its own SKILL.md, load it and stop recursing.
        # The sub-directories are internal resources of this skill.
        if not include_root_files and 'SKILL.md' in entries:
            skill_md_path = os.path.join(dir_path, 'SKILL.md')
            if os.path.isfile(skill_md_path):
                skill_result = self._load_skill_from_file(skill_md_path, source)
                skills.extend(skill_result.skills)
                diagnostics.extend(skill_result.diagnostics)
                return LoadSkillsResult(skills=skills, diagnostics=diagnostics)

        for entry in entries:
            if entry.startswith('.') or entry in self._IGNORED_ENTRY_NAMES:
                continue

            full_path = os.path.join(dir_path, entry)

            if os.path.isdir(full_path):
                sub_result = self._load_skills_recursive(full_path, source, include_root_files=False)
                skills.extend(sub_result.skills)
                diagnostics.extend(sub_result.diagnostics)
                continue

            # Plain files only count at the scan root, and only markdown
            # files other than the README.
            if not include_root_files or not os.path.isfile(full_path):
                continue
            if not entry.endswith('.md') or entry.upper() == 'README.MD':
                continue

            skill_result = self._load_skill_from_file(full_path, source)
            skills.extend(skill_result.skills)
            diagnostics.extend(skill_result.diagnostics)

        return LoadSkillsResult(skills=skills, diagnostics=diagnostics)

    @staticmethod
    def _default_skill_name(file_path: str) -> str:
        """
        Derive the name used when the frontmatter does not provide one.

        A ``SKILL.md`` file is named after its directory. Any other markdown
        file (a root-level ``<name>.md``) is named after the file itself;
        naming it after the directory would give every root-level skill the
        same name and make them overwrite each other.

        :param file_path: Path to the skill markdown file
        :return: Fallback skill name
        """
        directory, file_name = os.path.split(file_path)
        dir_name = os.path.basename(directory)
        if file_name.upper() == 'SKILL.MD':
            return dir_name
        stem = os.path.splitext(file_name)[0]
        return stem or dir_name

    @staticmethod
    def _normalize_name(raw, fallback: str) -> str:
        """
        Coerce a frontmatter ``name`` value into a usable string.

        :param raw: Raw frontmatter value (string, list, other scalar or None)
        :param fallback: Name to use when ``raw`` is missing, empty or blank
        :return: Skill name
        """
        if isinstance(raw, list):
            # A list-valued name contributes its first element only.
            raw = raw[0] if raw else None
        if isinstance(raw, str):
            return raw if raw.strip() else fallback
        return str(raw) if raw else fallback

    @staticmethod
    def _normalize_description(raw) -> str:
        """
        Coerce a frontmatter ``description`` value into a string.

        :param raw: Raw frontmatter value (string, list, other scalar or None)
        :return: Description text; lists are joined with single spaces
        """
        if isinstance(raw, list):
            return ' '.join(str(part) for part in raw if part)
        if isinstance(raw, str):
            return raw
        return str(raw) if raw else ''

    def _load_skill_from_file(self, file_path: str, source: str) -> LoadSkillsResult:
        """
        Load a single skill from a markdown file.

        A file that cannot be read, or whose description is empty, yields no
        skill and one diagnostic message instead.

        :param file_path: Path to the skill markdown file
        :param source: Source identifier
        :return: LoadSkillsResult
        """
        diagnostics = []

        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                content = f.read()
        except (OSError, UnicodeError) as e:
            diagnostics.append(f"Failed to read skill file {file_path}: {e}")
            return LoadSkillsResult(skills=[], diagnostics=diagnostics)

        # Parse frontmatter
        frontmatter = parse_frontmatter(content)

        # Get skill name and description
        skill_dir = os.path.dirname(file_path)
        name = self._normalize_name(
            frontmatter.get('name'),
            self._default_skill_name(file_path),
        )
        description = self._normalize_description(frontmatter.get('description', ''))

        # Special handling for linkai-agent: dynamically load apps from config.json
        if name == 'linkai-agent':
            description = self._load_linkai_agent_description(skill_dir, description)

        if not description or not description.strip():
            diagnostics.append(f"Skill {name} has no description: {file_path}")
            return LoadSkillsResult(skills=[], diagnostics=diagnostics)

        # Parse disable-model-invocation flag
        disable_model_invocation = parse_boolean_value(
            get_frontmatter_value(frontmatter, 'disable-model-invocation'),
            default=False
        )

        # Create skill object
        skill = Skill(
            name=name,
            description=description,
            file_path=file_path,
            base_dir=skill_dir,
            source=source,
            content=content,
            disable_model_invocation=disable_model_invocation,
            frontmatter=frontmatter,
        )

        return LoadSkillsResult(skills=[skill], diagnostics=diagnostics)

    def _load_linkai_agent_description(self, skill_dir: str, default_description: str) -> str:
        """
        Dynamically load LinkAI agent description from config.json

        Returns an empty string when ``config.json`` is missing, which makes
        the caller drop the skill for lack of a description.

        :param skill_dir: Skill directory
        :param default_description: Default description from SKILL.md
        :return: Dynamic description with app list
        """
        import json

        config_path = os.path.join(skill_dir, "config.json")

        if not os.path.exists(config_path):
            logger.debug("[SkillLoader] linkai-agent skipped: no config.json found")
            return ""

        try:
            with open(config_path, 'r', encoding='utf-8') as f:
                config = json.load(f)

            apps = config.get("apps", [])
            if not apps:
                return default_description

            # Build dynamic description with app details
            app_descriptions = "; ".join(
                f"{app['app_name']}({app['app_code']}: {app['app_description']})"
                for app in apps
            )

            return f"Call LinkAI apps/workflows. {app_descriptions}"

        except Exception as e:
            # Deliberately broad: any malformed config (bad JSON, wrong shape,
            # missing app fields) falls back to the static description.
            logger.warning(f"[SkillLoader] Failed to load linkai-agent config: {e}")
            return default_description

    def load_all_skills(
        self,
        builtin_dir: Optional[str] = None,
        custom_dir: Optional[str] = None,
    ) -> Dict[str, SkillEntry]:
        """
        Load skills from builtin and custom directories.

        Precedence (lowest to highest):
        1. builtin — project root ``skills/``, shipped with the codebase
        2. custom — workspace ``skills/``, installed via cloud console or skill creator

        Same-name custom skills override builtin ones.

        :param builtin_dir: Built-in skills directory
        :param custom_dir: Custom skills directory
        :return: Dictionary mapping skill name to SkillEntry
        """
        skill_map: Dict[str, SkillEntry] = {}
        all_diagnostics = []

        # Ordered from lowest to highest precedence: later sources overwrite
        # earlier entries that share a name.
        for directory, source in ((builtin_dir, 'builtin'), (custom_dir, 'custom')):
            if not directory or not os.path.exists(directory):
                continue
            result = self.load_skills_from_dir(directory, source=source)
            all_diagnostics.extend(result.diagnostics)
            for skill in result.skills:
                skill_map[skill.name] = self._create_skill_entry(skill)

        # Log diagnostics (only the first few, to keep the log readable)
        if all_diagnostics:
            logger.debug(f"Skill loading diagnostics: {len(all_diagnostics)} issues")
            for diag in all_diagnostics[:5]:
                logger.debug(f"  - {diag}")

        logger.debug(f"Loaded {len(skill_map)} skills total")

        return skill_map

    def _create_skill_entry(self, skill: Skill) -> SkillEntry:
        """
        Create a SkillEntry from a Skill with parsed metadata.

        :param skill: The skill to create an entry for
        :return: SkillEntry with metadata
        """
        metadata = parse_metadata(skill.frontmatter)

        # Parse user-invocable flag
        user_invocable = parse_boolean_value(
            get_frontmatter_value(skill.frontmatter, 'user-invocable'),
            default=True
        )

        return SkillEntry(
            skill=skill,
            metadata=metadata,
            user_invocable=user_invocable,
        )
|
||||
361
agent/skills/manager.py
Normal file
361
agent/skills/manager.py
Normal file
@@ -0,0 +1,361 @@
|
||||
"""
|
||||
Skill manager for managing skill lifecycle and operations.
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
from typing import Dict, List, Optional
|
||||
from pathlib import Path
|
||||
from common.log import logger
|
||||
from agent.skills.types import Skill, SkillEntry, SkillSnapshot
|
||||
from agent.skills.loader import SkillLoader
|
||||
from agent.skills.formatter import format_skill_entries_for_prompt
|
||||
|
||||
# File name of the persisted per-skill configuration; SkillManager joins it
# onto its custom skills directory to build the config path.
SKILLS_CONFIG_FILE = "skills_config.json"
|
||||
|
||||
|
||||
class SkillManager:
    """
    Manages skills for an agent.

    Holds the skills found by ``SkillLoader`` in ``self.skills`` and a
    per-skill configuration (enabled flag, category, ...) in
    ``self.skills_config``, which is persisted as ``skills_config.json``
    inside the custom skills directory.
    """

    def __init__(
        self,
        builtin_dir: Optional[str] = None,
        custom_dir: Optional[str] = None,
        config: Optional[Dict] = None,
    ):
        """
        Initialize the skill manager.

        :param builtin_dir: Built-in skills directory (project root ``skills/``)
        :param custom_dir: Custom skills directory (workspace ``skills/``)
        :param config: Configuration dictionary
        """
        # This file lives three levels below the project root.
        project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
        self.builtin_dir = builtin_dir or os.path.join(project_root, 'skills')
        self.custom_dir = custom_dir or os.path.join(project_root, 'workspace', 'skills')
        self.config = config or {}
        self._skills_config_path = os.path.join(self.custom_dir, SKILLS_CONFIG_FILE)

        # skills_config: full skill metadata keyed by name
        # { "web-fetch": {"name": ..., "description": ..., "source": ..., "enabled": true}, ... }
        self.skills_config: Dict[str, dict] = {}

        self.loader = SkillLoader()
        self.skills: Dict[str, SkillEntry] = {}

        # Load skills on initialization
        self.refresh_skills()

    def refresh_skills(self):
        """Reload all skills from builtin and custom directories, then sync config."""
        self.skills = self.loader.load_all_skills(
            builtin_dir=self.builtin_dir,
            custom_dir=self.custom_dir,
        )
        self._sync_skills_config()
        logger.debug(f"SkillManager: Loaded {len(self.skills)} skills")

    # ------------------------------------------------------------------
    # skills_config.json management
    # ------------------------------------------------------------------
    def _load_skills_config(self) -> Dict[str, dict]:
        """
        Load skills_config.json from custom_dir.

        :return: Parsed config, or an empty dict when the file is missing,
                 unreadable, not valid JSON, or not a JSON object
        """
        if not os.path.exists(self._skills_config_path):
            return {}
        try:
            with open(self._skills_config_path, "r", encoding="utf-8") as f:
                data = json.load(f)
        except (OSError, ValueError) as e:
            # ValueError covers both malformed JSON and undecodable bytes.
            logger.warning(f"[SkillManager] Failed to load {SKILLS_CONFIG_FILE}: {e}")
            return {}
        return data if isinstance(data, dict) else {}

    def _save_skills_config(self):
        """
        Persist skills_config to custom_dir/skills_config.json.

        Persistence is best-effort: a failure to create the directory or to
        write the file is logged and not raised.
        """
        try:
            os.makedirs(self.custom_dir, exist_ok=True)
            with open(self._skills_config_path, "w", encoding="utf-8") as f:
                json.dump(self.skills_config, f, indent=4, ensure_ascii=False)
        except (OSError, TypeError, ValueError) as e:
            logger.error(f"[SkillManager] Failed to save {SKILLS_CONFIG_FILE}: {e}")

    def _sync_skills_config(self):
        """
        Merge directory-scanned skills with the persisted config file.

        - New skills: use metadata.default_enabled as initial enabled state.
        - Existing skills: preserve their persisted enabled state, category
          and display_name.
        - Skills that no longer exist on disk are removed.
        - name/description are always refreshed from the latest scan; a
          persisted non-empty source is kept, otherwise the scanned one is used.
        - A persisted entry that is not a JSON object is treated as absent.
        """
        saved = self._load_skills_config()
        merged: Dict[str, dict] = {}

        for name, entry in self.skills.items():
            skill = entry.skill

            prev = saved.get(name)
            known = isinstance(prev, dict)
            if not known:
                # Missing or corrupt entry: fall back to first-discovery defaults.
                prev = {}

            if known:
                enabled = prev.get("enabled", True)
            else:
                enabled = entry.metadata.default_enabled if entry.metadata else True

            entry_dict = {
                "name": name,
                "description": skill.description,
                "source": prev.get("source") or skill.source,
                "enabled": enabled,
                "category": prev.get("category", "skill"),
            }
            display_name = prev.get("display_name")
            if display_name:
                entry_dict["display_name"] = display_name
            merged[name] = entry_dict

        self.skills_config = merged
        self._save_skills_config()

    def is_skill_enabled(self, name: str) -> bool:
        """
        Check if a skill is enabled according to skills_config.

        :param name: skill name
        :return: True if enabled (default True if not in config)
        """
        entry = self.skills_config.get(name)
        if entry is None:
            return True
        return entry.get("enabled", True)

    def set_skill_enabled(self, name: str, enabled: bool):
        """
        Set a skill's enabled state and persist.

        :param name: skill name
        :param enabled: True to enable, False to disable
        :raises ValueError: if the skill is not present in skills_config
        """
        if name not in self.skills_config:
            raise ValueError(f"skill '{name}' not found in config")
        self.skills_config[name]["enabled"] = enabled
        self._save_skills_config()

    def get_skills_config(self) -> Dict[str, dict]:
        """
        Return the full skills_config dict (for query API).

        Each per-skill dict is copied as well, so callers cannot change the
        manager's state by mutating the returned structure.

        :return: copy of skills_config
        """
        return {name: dict(entry) for name, entry in self.skills_config.items()}

    def get_skill(self, name: str) -> Optional[SkillEntry]:
        """
        Get a skill by name.

        :param name: Skill name
        :return: SkillEntry or None if not found
        """
        return self.skills.get(name)

    def list_skills(self) -> List[SkillEntry]:
        """
        Get all loaded skills.

        :return: List of all skill entries
        """
        return list(self.skills.values())

    @staticmethod
    def _normalize_skill_filter(skill_filter: Optional[List[str]]) -> Optional[List[str]]:
        """
        Normalize a skill_filter list into a flat list of stripped names.

        Items may be strings or lists of strings (one level of nesting);
        anything else, and any blank name, is dropped.

        :param skill_filter: Raw filter, or None for "no filter"
        :return: Flat list of names, or None when nothing usable remains
        """
        if skill_filter is None:
            return None
        normalized = []
        for item in skill_filter:
            candidates = item if isinstance(item, list) else [item]
            for candidate in candidates:
                if not isinstance(candidate, str):
                    continue
                name = candidate.strip()
                if name:
                    normalized.append(name)
        return normalized or None

    @staticmethod
    def _is_within(path: str, parent: str) -> bool:
        """
        Tell whether ``path`` is ``parent`` itself or located below it.

        :param path: Absolute, normalized path to test
        :param parent: Absolute, normalized directory path
        :return: True when ``path`` equals ``parent`` or lies inside it
        """
        try:
            return os.path.commonpath([path, parent]) == parent
        except ValueError:
            # e.g. paths on different drives
            return False

    def filter_skills(
        self,
        skill_filter: Optional[List[str]] = None,
        include_disabled: bool = False,
    ) -> List[SkillEntry]:
        """
        Filter skills that are eligible (enabled + requirements met).

        :param skill_filter: List of skill names to include (None = all)
        :param include_disabled: Whether to include disabled skills
        :return: Filtered list of eligible skill entries
        """
        from agent.skills.config import should_include_skill

        entries = [e for e in self.skills.values() if should_include_skill(e, self.config)]

        normalized = self._normalize_skill_filter(skill_filter)
        if normalized is not None:
            wanted = set(normalized)
            entries = [e for e in entries if e.skill.name in wanted]

        if not include_disabled:
            entries = [e for e in entries if self.is_skill_enabled(e.skill.name)]

        # The knowledge-wiki skill is dropped when "knowledge" is switched
        # off in the global config.
        from config import conf
        if not conf().get("knowledge", True):
            entries = [e for e in entries if e.skill.name != "knowledge-wiki"]

        return entries

    def filter_unavailable_skills(
        self,
        skill_filter: Optional[List[str]] = None,
    ) -> tuple:
        """
        Find skills that are enabled but have unmet requirements.

        :param skill_filter: Optional list of skill names to include
        :return: Tuple of (entries, missing_map) where missing_map maps
                 skill name to its missing requirements dict
        """
        from agent.skills.config import should_include_skill, get_missing_requirements

        # Only enabled skills
        entries = [e for e in self.skills.values() if self.is_skill_enabled(e.skill.name)]

        normalized = self._normalize_skill_filter(skill_filter)
        if normalized is not None:
            wanted = set(normalized)
            entries = [e for e in entries if e.skill.name in wanted]

        # Keep only those that fail should_include_skill (requirements not met)
        unavailable = []
        missing_map: Dict[str, dict] = {}
        for e in entries:
            if should_include_skill(e, self.config):
                continue
            missing = get_missing_requirements(e)
            if missing:
                unavailable.append(e)
                missing_map[e.skill.name] = missing

        return unavailable, missing_map

    def build_skills_prompt(
        self,
        skill_filter: Optional[List[str]] = None,
    ) -> str:
        """
        Build a formatted prompt containing available skills
        and brief hints for unavailable ones.

        :param skill_filter: Optional list of skill names to include
        :return: Formatted skills prompt
        """
        from agent.skills.formatter import format_unavailable_skills_for_prompt

        eligible = self.filter_skills(skill_filter=skill_filter, include_disabled=False)
        logger.debug(f"[SkillManager] Eligible: {len(eligible)} skills (total: {len(self.skills)})")
        if eligible:
            skill_names = [e.skill.name for e in eligible]
            logger.debug(f"[SkillManager] Eligible skills: {skill_names}")

        result = format_skill_entries_for_prompt(eligible)

        unavailable, missing_map = self.filter_unavailable_skills(skill_filter=skill_filter)
        if unavailable:
            unavailable_names = [e.skill.name for e in unavailable]
            logger.debug(f"[SkillManager] Unavailable skills (setup needed): {unavailable_names}")
            result += format_unavailable_skills_for_prompt(unavailable, missing_map)

        logger.debug(f"[SkillManager] Generated prompt length: {len(result)}")
        return result

    def build_skill_snapshot(
        self,
        skill_filter: Optional[List[str]] = None,
        version: Optional[int] = None,
    ) -> SkillSnapshot:
        """
        Build a snapshot of skills for a specific run.

        :param skill_filter: Optional list of skill names to include
        :param version: Optional version number for the snapshot
        :return: SkillSnapshot
        """
        entries = self.filter_skills(skill_filter=skill_filter, include_disabled=False)
        prompt = format_skill_entries_for_prompt(entries)

        skills_info = [
            {
                'name': entry.skill.name,
                'primary_env': entry.metadata.primary_env if entry.metadata else None,
            }
            for entry in entries
        ]
        resolved_skills = [entry.skill for entry in entries]

        return SkillSnapshot(
            prompt=prompt,
            skills=skills_info,
            resolved_skills=resolved_skills,
            version=version,
        )

    def sync_skills_to_workspace(self, target_workspace_dir: str):
        """
        Sync all loaded skills to a target workspace directory.

        This is useful for sandbox environments where skills need to be copied.
        The target ``skills`` directory is wiped and rebuilt. If it is, or
        contains, the directory of a loaded skill, nothing is touched and an
        error is logged, because wiping it would destroy the sources.

        :param target_workspace_dir: Target workspace directory
        """
        import shutil

        target_skills_dir = os.path.join(target_workspace_dir, 'skills')

        # Guard against deleting the very directories we are about to copy.
        target_real = os.path.realpath(target_skills_dir)
        for entry in self.skills.values():
            source_real = os.path.realpath(entry.skill.base_dir)
            if self._is_within(source_real, target_real):
                logger.error(
                    f"Refusing to sync skills into {target_skills_dir}: "
                    f"it contains the source of skill '{entry.skill.name}'"
                )
                return

        # Remove existing skills directory
        if os.path.exists(target_skills_dir):
            shutil.rmtree(target_skills_dir)

        # Create new skills directory
        os.makedirs(target_skills_dir, exist_ok=True)

        # Copy each skill
        synced = 0
        for entry in self.skills.values():
            skill_name = entry.skill.name
            source_dir = entry.skill.base_dir
            target_dir = os.path.join(target_skills_dir, skill_name)

            try:
                shutil.copytree(source_dir, target_dir)
            except Exception as e:
                # Best-effort: one failing skill must not abort the others.
                logger.warning(f"Failed to sync skill '{skill_name}': {e}")
                continue
            synced += 1
            logger.debug(f"Synced skill '{skill_name}' to {target_dir}")

        logger.info(f"Synced {synced} skills to {target_skills_dir}")

    def get_skill_by_key(self, skill_key: str) -> Optional[SkillEntry]:
        """
        Get a skill by its skill key (which may differ from name).

        Returns the first loaded entry whose metadata skill_key or whose
        name equals ``skill_key``.

        :param skill_key: Skill key to look up
        :return: SkillEntry or None
        """
        for entry in self.skills.values():
            if entry.metadata and entry.metadata.skill_key == skill_key:
                return entry
            if entry.skill.name == skill_key:
                return entry
        return None
|
||||
285
agent/skills/service.py
Normal file
285
agent/skills/service.py
Normal file
@@ -0,0 +1,285 @@
|
||||
"""
|
||||
Skill service for handling skill CRUD operations.
|
||||
|
||||
This service provides a unified interface for managing skills, which can be
|
||||
called from the cloud control client (LinkAI), the local web console, or any
|
||||
other management entry point.
|
||||
"""
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import zipfile
|
||||
import tempfile
|
||||
from typing import Dict, List, Optional
|
||||
from common.log import logger
|
||||
from agent.skills.types import Skill, SkillEntry
|
||||
from agent.skills.manager import SkillManager
|
||||
|
||||
try:
|
||||
import requests
|
||||
except ImportError:
|
||||
requests = None
|
||||
|
||||
|
||||
class SkillService:
    """
    High-level service for skill lifecycle management.
    Wraps SkillManager and provides network-aware operations such as
    downloading skill files from remote URLs.

    Skill names and file paths arrive in remote payloads, so every path
    derived from them is checked to stay inside its intended directory
    before anything is written or deleted.
    """

    def __init__(self, skill_manager: SkillManager):
        """
        :param skill_manager: The SkillManager instance to operate on
        """
        self.manager = skill_manager

    # ------------------------------------------------------------------
    # path validation helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _require_name(payload: dict) -> str:
        """
        Extract the mandatory skill name from a payload.

        :param payload: action payload
        :return: the value stored under ``"name"``
        :raises ValueError: if the name is missing or empty
        """
        name = payload.get("name")
        if not name:
            raise ValueError("skill name is required")
        return name

    @staticmethod
    def _is_strictly_inside(path: str, root: str) -> bool:
        """
        Tell whether ``path`` lies below ``root`` (and is not ``root`` itself).

        The comparison is lexical (``os.path.abspath``), which is enough to
        reject ``..`` segments and absolute paths coming from a payload.

        :param path: candidate path
        :param root: directory the candidate must stay inside
        :return: True when ``path`` is a proper descendant of ``root``
        """
        root_abs = os.path.abspath(root)
        path_abs = os.path.abspath(path)
        if path_abs == root_abs:
            return False
        try:
            return os.path.commonpath([root_abs, path_abs]) == root_abs
        except ValueError:
            # e.g. paths on different drives
            return False

    def _resolve_skill_dir(self, name: str) -> str:
        """
        Map a skill name to its directory below the custom skills directory.

        :param name: skill name taken from a payload
        :return: ``<custom_dir>/<name>``
        :raises ValueError: if the name is not a string or would resolve to
                            the custom directory itself or to a path outside it
        """
        if not isinstance(name, str):
            raise ValueError(f"invalid skill name: {name!r}")
        skill_dir = os.path.join(self.manager.custom_dir, name)
        if not self._is_strictly_inside(skill_dir, self.manager.custom_dir):
            raise ValueError(f"invalid skill name: {name!r}")
        return skill_dir

    # ------------------------------------------------------------------
    # query
    # ------------------------------------------------------------------
    def query(self) -> List[dict]:
        """
        Query all skills and return a serialisable list.
        Rescans the skill directories first, then reads the manager's
        skills_config.

        :return: list of skill info dicts
        """
        self.manager.refresh_skills()
        config = self.manager.get_skills_config()
        result = list(config.values())
        logger.info(f"[SkillService] query: {len(result)} skills found")
        return result

    # ------------------------------------------------------------------
    # add / install
    # ------------------------------------------------------------------
    def add(self, payload: dict) -> None:
        """
        Add (install) a skill from a remote payload.

        Supported payload types:

        1. ``type: "url"`` – download individual files::

            {
                "name": "web_search",
                "type": "url",
                "enabled": true,
                "files": [
                    {"url": "https://...", "path": "README.md"},
                    {"url": "https://...", "path": "scripts/main.py"}
                ]
            }

        2. ``type: "package"`` – download a zip archive and extract::

            {
                "name": "plugin-custom-tool",
                "type": "package",
                "category": "skills",
                "enabled": true,
                "files": [{"url": "https://cdn.example.com/skills/custom-tool.zip"}]
            }

        Any ``type`` other than ``"package"`` is handled as ``"url"``.

        :param payload: skill add payload from server
        :raises ValueError: if the name, file list or any target path is invalid
        """
        name = self._require_name(payload)

        payload_type = payload.get("type", "url")

        if payload_type == "package":
            self._add_package(name, payload)
        else:
            self._add_url(name, payload)

        self.manager.refresh_skills()

        # The category is not stored in the skill files; persist it in the
        # config entry created by the refresh above.
        category = payload.get("category")
        if category and name in self.manager.skills_config:
            self.manager.skills_config[name]["category"] = category
            self.manager._save_skills_config()

    def _add_url(self, name: str, payload: dict) -> None:
        """
        Install a skill by downloading individual files.

        Files are staged in ``<skill_dir>.tmp`` and only swapped into place
        once every download succeeded, so a failed install leaves an
        existing skill untouched.

        :param name: skill name
        :param payload: payload with a ``files`` list of ``{"url", "path"}``
        :raises ValueError: if the list is empty, contains no usable entry,
                            or a ``path`` would land outside the skill directory
        """
        files = payload.get("files", [])
        if not files:
            raise ValueError("skill files list is empty")

        skill_dir = self._resolve_skill_dir(name)

        tmp_dir = skill_dir + ".tmp"
        if os.path.exists(tmp_dir):
            shutil.rmtree(tmp_dir)
        os.makedirs(tmp_dir, exist_ok=True)

        downloaded = 0
        try:
            for file_info in files:
                url = file_info.get("url")
                rel_path = file_info.get("path")
                if not url or not rel_path:
                    logger.warning(f"[SkillService] add: skip invalid file entry {file_info}")
                    continue
                dest = os.path.join(tmp_dir, rel_path)
                # rel_path is remote input: reject "../" and absolute paths.
                if not self._is_strictly_inside(dest, tmp_dir):
                    raise ValueError(f"file path escapes skill directory: {rel_path!r}")
                self._download_file(url, dest)
                downloaded += 1
            if downloaded == 0:
                # Do not replace an installed skill with an empty directory.
                raise ValueError("skill files list contains no valid entries")
        except Exception:
            shutil.rmtree(tmp_dir, ignore_errors=True)
            raise

        if os.path.exists(skill_dir):
            shutil.rmtree(skill_dir)
        os.rename(tmp_dir, skill_dir)

        logger.info(f"[SkillService] add: skill '{name}' installed via url ({downloaded} files)")

    def _add_package(self, name: str, payload: dict) -> None:
        """
        Install a skill by downloading a zip archive and extracting it.

        If the archive contains a single top-level directory, the contents
        of that directory become the skill folder; otherwise the extracted
        contents are used as they are. Either way the result is stored in
        a directory named after the skill.

        :param name: skill name
        :param payload: payload whose first ``files`` entry holds the zip URL
        :raises ValueError: if the URL is missing, the name is invalid, or
                            the download is not a zip archive
        """
        files = payload.get("files", [])
        if not files or not files[0].get("url"):
            raise ValueError("package url is required")

        url = files[0]["url"]
        skill_dir = self._resolve_skill_dir(name)

        with tempfile.TemporaryDirectory() as tmp_dir:
            zip_path = os.path.join(tmp_dir, "package.zip")
            self._download_file(url, zip_path)

            if not zipfile.is_zipfile(zip_path):
                raise ValueError(f"downloaded file is not a valid zip archive: {url}")

            extract_dir = os.path.join(tmp_dir, "extracted")
            with zipfile.ZipFile(zip_path, "r") as zf:
                # zipfile sanitizes member names (absolute paths and ".."
                # components) while extracting.
                zf.extractall(extract_dir)

            # Determine the actual content root.
            # If the zip has a single top-level directory, use its contents
            # so the skill folder is clean (no extra nesting).
            top_items = [
                item for item in os.listdir(extract_dir)
                if not item.startswith(".")
            ]
            if len(top_items) == 1:
                single = os.path.join(extract_dir, top_items[0])
                if os.path.isdir(single):
                    extract_dir = single

            if os.path.exists(skill_dir):
                shutil.rmtree(skill_dir)
            shutil.copytree(extract_dir, skill_dir)

        logger.info(f"[SkillService] add: skill '{name}' installed via package ({url})")

    # ------------------------------------------------------------------
    # open / close (enable / disable)
    # ------------------------------------------------------------------
    def open(self, payload: dict) -> None:
        """
        Enable a skill by name.

        :param payload: {"name": "skill_name"}
        :raises ValueError: if the name is missing or unknown to the manager
        """
        name = self._require_name(payload)
        self.manager.set_skill_enabled(name, enabled=True)
        logger.info(f"[SkillService] open: skill '{name}' enabled")

    def close(self, payload: dict) -> None:
        """
        Disable a skill by name.

        :param payload: {"name": "skill_name"}
        :raises ValueError: if the name is missing or unknown to the manager
        """
        name = self._require_name(payload)
        self.manager.set_skill_enabled(name, enabled=False)
        logger.info(f"[SkillService] close: skill '{name}' disabled")

    # ------------------------------------------------------------------
    # delete
    # ------------------------------------------------------------------
    def delete(self, payload: dict) -> None:
        """
        Delete a skill by removing its directory entirely.

        :param payload: {"name": "skill_name"}
        :raises ValueError: if the name is missing or does not resolve to a
                            directory below the custom skills directory
        """
        name = self._require_name(payload)

        # Validated first: an unchecked ".." here would remove the parent of
        # the skills directory.
        skill_dir = self._resolve_skill_dir(name)
        if os.path.exists(skill_dir):
            shutil.rmtree(skill_dir)
            logger.info(f"[SkillService] delete: removed directory {skill_dir}")
        else:
            logger.warning(f"[SkillService] delete: skill directory not found: {skill_dir}")

        # Refresh will remove the deleted skill from config automatically
        self.manager.refresh_skills()
        logger.info(f"[SkillService] delete: skill '{name}' deleted")

    # ------------------------------------------------------------------
    # dispatch - single entry point for protocol messages
    # ------------------------------------------------------------------
    def dispatch(self, action: str, payload: Optional[dict] = None) -> dict:
        """
        Dispatch a skill management action and return a protocol-compatible
        response dict.

        Never raises: a failing action is reported as code 500 with the
        exception text as message, an unknown action as code 400.

        :param action: one of query / add / open / close / delete
        :param payload: action-specific payload (may be None for query)
        :return: dict with action, code, message, payload
        """
        payload = payload or {}
        handlers = {
            "add": self.add,
            "open": self.open,
            "close": self.close,
            "delete": self.delete,
        }
        try:
            if action == "query":
                result_payload = self.query()
                return {"action": action, "code": 200, "message": "success", "payload": result_payload}

            handler = handlers.get(action) if isinstance(action, str) else None
            if handler is None:
                return {"action": action, "code": 400, "message": f"unknown action: {action}", "payload": None}

            handler(payload)
            return {"action": action, "code": 200, "message": "success", "payload": None}
        except Exception as e:
            logger.error(f"[SkillService] dispatch error: action={action}, error={e}")
            return {"action": action, "code": 500, "message": str(e), "payload": None}

    # ------------------------------------------------------------------
    # internal helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _download_file(url: str, dest: str):
        """
        Download a file from *url* and save to *dest*.

        Parent directories of *dest* are created as needed.

        :param url: remote file URL
        :param dest: local destination path
        :raises RuntimeError: if the ``requests`` library is not installed
        """
        if requests is None:
            raise RuntimeError("requests library is required for downloading skill files")

        dest_dir = os.path.dirname(dest)
        if dest_dir:
            os.makedirs(dest_dir, exist_ok=True)

        resp = requests.get(url, timeout=60)
        resp.raise_for_status()
        with open(dest, "wb") as f:
            f.write(resp.content)
        logger.debug(f"[SkillService] downloaded {url} -> {dest}")
|
||||
76
agent/skills/types.py
Normal file
76
agent/skills/types.py
Normal file
@@ -0,0 +1,76 @@
|
||||
"""
|
||||
Type definitions for skills system.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
from typing import Dict, List, Optional, Any
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
@dataclass
class SkillInstallSpec:
    """Specification for installing skill dependencies.

    Only ``kind`` is required. Every other field is optional and defaults to
    ``None``, ``False`` or an empty list, so a spec only carries the fields
    relevant to its install kind.
    """
    kind: str  # brew, pip, npm, download, etc.
    id: Optional[str] = None
    label: Optional[str] = None
    bins: List[str] = field(default_factory=list)  # presumably executables the install provides - TODO confirm
    os: List[str] = field(default_factory=list)  # presumably OS platforms this spec applies to - TODO confirm
    formula: Optional[str] = None  # for brew
    package: Optional[str] = None  # for pip/npm
    module: Optional[str] = None
    url: Optional[str] = None  # for download
    archive: Optional[str] = None
    extract: bool = False
    strip_components: Optional[int] = None
    target_dir: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
class SkillMetadata:
    """Metadata for a skill from frontmatter.

    All fields have defaults, so an instance can be built from whatever
    subset of keys is available.
    """
    always: bool = False  # Always include this skill
    default_enabled: bool = True  # Initial enabled state when first discovered
    skill_key: Optional[str] = None  # Override skill key
    primary_env: Optional[str] = None  # Primary environment variable
    emoji: Optional[str] = None
    homepage: Optional[str] = None
    os: List[str] = field(default_factory=list)  # Supported OS platforms
    requires: Dict[str, List[str]] = field(default_factory=dict)  # Requirements
    install: List[SkillInstallSpec] = field(default_factory=list)  # Install specs for the skill's dependencies
|
||||
|
||||
|
||||
@dataclass
class Skill:
    """Represents a skill loaded from a markdown file.

    The first six fields are required positional fields; the remaining two
    default to ``False`` and an empty dict.
    """
    name: str
    description: str
    file_path: str
    base_dir: str
    source: str  # builtin or custom
    content: str  # Full markdown content
    disable_model_invocation: bool = False
    frontmatter: Dict[str, Any] = field(default_factory=dict)  # presumably the parsed frontmatter mapping - TODO confirm
|
||||
|
||||
|
||||
@dataclass
class SkillEntry:
    """A skill with parsed metadata.

    Pairs a :class:`Skill` with its optional :class:`SkillMetadata`.
    """
    skill: Skill
    metadata: Optional[SkillMetadata] = None  # None when no metadata is attached
    user_invocable: bool = True  # Can users invoke this skill directly
|
||||
|
||||
|
||||
@dataclass
class LoadSkillsResult:
    """Result of loading skills from a directory."""
    skills: List[Skill]  # required: the loaded skills
    diagnostics: List[str] = field(default_factory=list)  # presumably messages produced while loading - TODO confirm
|
||||
|
||||
|
||||
@dataclass
class SkillSnapshot:
    """Snapshot of skills for a specific run."""
    prompt: str  # Formatted prompt text
    skills: List[Dict[str, str]]  # List of skill info (name, primary_env)
    resolved_skills: List[Skill] = field(default_factory=list)
    version: Optional[int] = None
|
||||
149
agent/tools/__init__.py
Normal file
149
agent/tools/__init__.py
Normal file
@@ -0,0 +1,149 @@
|
||||
# Import base tool
|
||||
from agent.tools.base_tool import BaseTool
|
||||
from agent.tools.tool_manager import ToolManager
|
||||
|
||||
# Import file operation tools
|
||||
from agent.tools.read.read import Read
|
||||
from agent.tools.write.write import Write
|
||||
from agent.tools.edit.edit import Edit
|
||||
from agent.tools.bash.bash import Bash
|
||||
from agent.tools.ls.ls import Ls
|
||||
from agent.tools.send.send import Send
|
||||
|
||||
# Import memory tools
|
||||
from agent.tools.memory.memory_search import MemorySearchTool
|
||||
from agent.tools.memory.memory_get import MemoryGetTool
|
||||
|
||||
# Import tools with optional dependencies
|
||||
def _import_optional_tools():
    """Import the tools whose dependencies may be missing.

    Each tool is imported independently; a failure is logged and the tool is
    simply left out of the result, so one broken tool never prevents the
    others from loading.

    :return: dict mapping export name to the tool class, containing only the
             tools that imported successfully
    """
    from common.log import logger

    # One tiny loader per tool keeps the import statements static.
    def _load_env_config():
        from agent.tools.env_config.env_config import EnvConfig
        return EnvConfig

    def _load_scheduler():
        from agent.tools.scheduler.scheduler_tool import SchedulerTool
        return SchedulerTool

    def _load_web_search():
        from agent.tools.web_search.web_search import WebSearch
        return WebSearch

    def _load_web_fetch():
        from agent.tools.web_fetch.web_fetch import WebFetch
        return WebFetch

    def _load_vision():
        from agent.tools.vision.vision import Vision
        return Vision

    # (export key, label used in log messages, loader, hint appended when a dependency is missing)
    specs = (
        (
            'EnvConfig', "EnvConfig tool", _load_env_config,
            "\n"
            " To enable environment variable management, run:\n"
            " pip install python-dotenv>=1.0.0",
        ),
        (
            'SchedulerTool', "Scheduler tool", _load_scheduler,
            "\n"
            " To enable scheduled tasks, run:\n"
            " pip install croniter>=2.0.0",
        ),
        ('WebSearch', "WebSearch", _load_web_search, ""),
        ('WebFetch', "WebFetch", _load_web_fetch, ""),
        ('Vision', "Vision", _load_vision, ""),
    )

    loaded = {}
    for key, label, loader, hint in specs:
        try:
            loaded[key] = loader()
        except ImportError as e:
            logger.error(f"[Tools] {label} not loaded - missing dependency: {e}{hint}")
        except Exception as e:
            logger.error(f"[Tools] {label} failed to load: {e}")

    return loaded
|
||||
|
||||
# Load optional tools
|
||||
_optional_tools = _import_optional_tools()
# Each name below is the tool class when its import succeeded, otherwise None
# (dict.get returns None for keys the loader did not register).
EnvConfig = _optional_tools.get('EnvConfig')
SchedulerTool = _optional_tools.get('SchedulerTool')
WebSearch = _optional_tools.get('WebSearch')
WebFetch = _optional_tools.get('WebFetch')
Vision = _optional_tools.get('Vision')
# NOTE(review): _import_optional_tools never registers the keys 'GoogleSearch',
# 'FileSave' or 'Terminal', so these three names always resolve to None.
GoogleSearch = _optional_tools.get('GoogleSearch')
FileSave = _optional_tools.get('FileSave')
Terminal = _optional_tools.get('Terminal')
|
||||
|
||||
|
||||
# BrowserTool (requires playwright)
|
||||
def _import_browser_tool():
    """Import BrowserTool, returning None when it cannot be loaded.

    A missing dependency is reported at info level together with install
    instructions; any other failure is reported as an error.

    :return: the BrowserTool class, or None if the import failed
    """
    from common.log import logger

    tool_cls = None
    try:
        from agent.tools.browser.browser_tool import BrowserTool as tool_cls
    except ImportError as e:
        logger.info(
            f"[Tools] BrowserTool not loaded - missing dependency: {e}\n"
            f" To enable browser tool, run:\n"
            f" pip install playwright\n"
            f" playwright install chromium"
        )
    except Exception as e:
        logger.error(f"[Tools] BrowserTool failed to load: {e}")
    return tool_cls
|
||||
|
||||
BrowserTool = _import_browser_tool()
|
||||
|
||||
# MCP Tools (no extra dependencies, loaded on demand)
|
||||
def _import_mcp_tools():
    """Import the MCP tool module (no extra dependencies, loaded on demand).

    :return: dict with the keys 'McpTool' and 'McpClientRegistry' when both
             imports succeed, otherwise an empty dict
    """
    from common.log import logger

    exported = {}
    try:
        from agent.tools.mcp.mcp_tool import McpTool
        from agent.tools.mcp.mcp_client import McpClientRegistry
    except Exception as e:
        logger.warning(f"[Tools] MCP tools not loaded: {e}")
    else:
        # Only publish the names once both imports went through.
        exported['McpTool'] = McpTool
        exported['McpClientRegistry'] = McpClientRegistry
    return exported
|
||||
|
||||
_mcp_tools = _import_mcp_tools()
|
||||
McpTool = _mcp_tools.get('McpTool')
|
||||
McpClientRegistry = _mcp_tools.get('McpClientRegistry')
|
||||
|
||||
# Export all tools (including optional ones that might be None)
|
||||
__all__ = [
|
||||
'BaseTool',
|
||||
'ToolManager',
|
||||
'Read',
|
||||
'Write',
|
||||
'Edit',
|
||||
'Bash',
|
||||
'Ls',
|
||||
'Send',
|
||||
'MemorySearchTool',
|
||||
'MemoryGetTool',
|
||||
'EnvConfig',
|
||||
'SchedulerTool',
|
||||
'WebSearch',
|
||||
'WebFetch',
|
||||
'Vision',
|
||||
'BrowserTool',
|
||||
'McpTool',
|
||||
]
|
||||
|
||||
"""
|
||||
Tools module for Agent.
|
||||
"""
|
||||
99
agent/tools/base_tool.py
Normal file
99
agent/tools/base_tool.py
Normal file
@@ -0,0 +1,99 @@
|
||||
from enum import Enum
|
||||
from typing import Any, Optional
|
||||
from common.log import logger
|
||||
import copy
|
||||
|
||||
|
||||
class ToolStage(Enum):
    """Stages at which a tool may take part in the agent's decision flow."""

    # Tools that need to be actively selected by the agent
    PRE_PROCESS = "pre_process"

    # Tools that automatically execute after final_answer
    POST_PROCESS = "post_process"
|
||||
|
||||
|
||||
class ToolResult:
    """Outcome of a tool invocation.

    ``status`` is "success" or "error" when built through the factory
    helpers; ``result`` is the payload and ``ext_data`` optional extra data.
    """

    def __init__(self, status: str = None, result: Any = None, ext_data: Any = None):
        self.status, self.result, self.ext_data = status, result, ext_data

    @staticmethod
    def success(result, ext_data: Any = None):
        """Build a result whose status is "success"."""
        return ToolResult("success", result, ext_data)

    @staticmethod
    def fail(result, ext_data: Any = None):
        """Build a result whose status is "error"."""
        return ToolResult("error", result, ext_data)
|
||||
|
||||
|
||||
class BaseTool:
    """Base class for all tools.

    Subclasses override the class attributes ``name``, ``description`` and
    ``params`` (a JSON Schema describing the arguments) and implement
    :meth:`execute`. Callers should go through :meth:`execute_tool`, which
    converts exceptions into an error :class:`ToolResult`.
    """

    # Default decision stage is pre-process
    stage = ToolStage.PRE_PROCESS

    # Class attributes must be inherited
    name: str = "base_tool"
    description: str = "Base tool"
    params: dict = {}  # Store JSON Schema
    model: Optional[Any] = None  # LLM model instance, type depends on bot implementation

    @classmethod
    def get_json_schema(cls) -> dict:
        """Get the standard description of the tool.

        :return: dict with the keys "name", "description" and "parameters"
        """
        return {
            "name": cls.name,
            "description": cls.description,
            "parameters": cls.params
        }

    def execute_tool(self, params: dict) -> ToolResult:
        """Run :meth:`execute` and never let an exception escape.

        :param params: arguments forwarded unchanged to :meth:`execute`
        :return: the result of :meth:`execute`, or an error ToolResult
                 describing the exception that was raised
        """
        try:
            return self.execute(params)
        except Exception as e:
            logger.error(e)
            # Return an explicit failure instead of falling through with None,
            # so callers can always rely on getting a ToolResult back.
            return ToolResult.fail(f"Tool execution failed: {type(e).__name__}: {e}")

    def execute(self, params: dict) -> ToolResult:
        """Specific logic to be implemented by subclasses"""
        raise NotImplementedError

    @classmethod
    def _parse_schema(cls) -> dict:
        """Convert JSON Schema to Pydantic fields.

        :return: dict mapping each property name to a ``(python_type, default)``
                 tuple; ``default`` is ``...`` when the schema declares none.
                 A schema without "properties" yields an empty dict, and a
                 property with a missing or unrecognised "type" maps to ``Any``.
        """
        # Convert JSON Schema types to Python types (built once, not per property)
        type_map = {
            "string": str,
            "number": float,
            "integer": int,
            "boolean": bool,
            "array": list,
            "object": dict
        }
        fields = {}
        for name, prop in (cls.params or {}).get("properties", {}).items():
            json_type = prop.get("type")
            # "type" may be absent or a list of types; only plain strings are looked up.
            py_type = type_map.get(json_type, Any) if isinstance(json_type, str) else Any
            fields[name] = (py_type, prop.get("default", ...))
        return fields

    def should_auto_execute(self, context) -> bool:
        """
        Determine if this tool should be automatically executed based on context.

        :param context: The agent context (not inspected by this default implementation)
        :return: True if the tool should be executed, False otherwise
        """
        # Only tools in post-process stage will be automatically executed
        return self.stage == ToolStage.POST_PROCESS

    def close(self):
        """
        Close any resources used by the tool.

        This method should be overridden by tools that need to clean up resources
        such as browser connections, file handles, etc.

        By default, this method does nothing.
        """
        pass
|
||||
3
agent/tools/bash/__init__.py
Normal file
3
agent/tools/bash/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .bash import Bash
|
||||
|
||||
__all__ = ['Bash']
|
||||
295
agent/tools/bash/bash.py
Normal file
295
agent/tools/bash/bash.py
Normal file
@@ -0,0 +1,295 @@
|
||||
"""
|
||||
Bash tool - Execute bash commands
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import subprocess
|
||||
import tempfile
|
||||
from typing import Dict, Any
|
||||
|
||||
from agent.tools.base_tool import BaseTool, ToolResult
|
||||
from agent.tools.utils.truncate import truncate_tail, format_size, DEFAULT_MAX_LINES, DEFAULT_MAX_BYTES
|
||||
from common.log import logger
|
||||
from common.utils import expand_path
|
||||
|
||||
|
||||
class Bash(BaseTool):
    """Tool for executing shell commands in a configured working directory.

    Recognised config keys: ``cwd`` (working directory, created if missing),
    ``timeout`` (default timeout in seconds, 30) and ``safety_mode`` (block
    a small list of catastrophic commands, enabled by default).
    """

    _IS_WIN = sys.platform == "win32"

    name: str = "bash"
    description: str = f"""Execute a bash command in the current working directory. Returns stdout and stderr. Output is truncated to last {DEFAULT_MAX_LINES} lines or {DEFAULT_MAX_BYTES // 1024}KB (whichever is hit first). If truncated, full output is saved to a temp file.
{'''
PLATFORM: Windows (cmd.exe). Do NOT use Unix-only commands like grep, head, tail, sed, awk.
''' if _IS_WIN else ''}
ENVIRONMENT: All API keys from env_config are auto-injected. Use $VAR_NAME directly.

SAFETY:
- Freely create/modify/delete files within the workspace
- For destructive commands out of workspace, explain and confirm first"""

    params: dict = {
        "type": "object",
        "properties": {
            "command": {
                "type": "string",
                "description": "Bash command to execute"
            },
            "timeout": {
                "type": "integer",
                "description": "Timeout in seconds (optional, default: 30)"
            }
        },
        "required": ["command"]
    }

    def __init__(self, config: dict = None):
        """
        :param config: optional settings dict; see the class docstring for keys
        """
        self.config = config or {}
        self.cwd = self.config.get("cwd", os.getcwd())
        # Ensure working directory exists
        if not os.path.exists(self.cwd):
            os.makedirs(self.cwd, exist_ok=True)
        self.default_timeout = self.config.get("timeout", 30)
        # Enable safety mode by default (can be disabled in config)
        self.safety_mode = self.config.get("safety_mode", True)

    def execute(self, args: Dict[str, Any]) -> ToolResult:
        """
        Execute a bash command

        :param args: Dictionary containing the command and optional timeout
        :return: Command output or error
        """
        command = (args.get("command") or "").strip()
        # A missing, None or zero timeout falls back to the configured default,
        # so an explicit null from the caller cannot disable the timeout.
        timeout = args.get("timeout") or self.default_timeout

        if not command:
            return ToolResult.fail("Error: command parameter is required")

        # Security check: Prevent accessing sensitive config files
        # ("~/.cow" also covers "~/.cow/.env").
        if "~/.cow" in command:
            return ToolResult.fail(
                "Error: Access denied. API keys and credentials must be accessed through the env_config tool only."
            )

        # Optional safety check - only warn about extremely dangerous commands
        if self.safety_mode:
            warning = self._get_safety_warning(command)
            if warning:
                return ToolResult.fail(
                    f"Safety Warning: {warning}\n\nIf you believe this command is safe and necessary, please ask the user for confirmation first, explaining what the command does and why it's needed.")

        try:
            env, dotenv_vars = self._build_env()

            # getuid() only exists on Unix-like systems
            if hasattr(os, 'getuid'):
                logger.debug(f"[Bash] Process UID: {os.getuid()}")
            else:
                logger.debug(f"[Bash] Process User: {os.environ.get('USERNAME', os.environ.get('USER', 'unknown'))}")

            # On Windows, convert $VAR references to %VAR% for cmd.exe
            if self._IS_WIN:
                env["PYTHONIOENCODING"] = "utf-8"
                command = self._convert_env_vars_for_windows(command, dotenv_vars)
                if command and not command.strip().lower().startswith("chcp"):
                    command = f"chcp 65001 >nul 2>&1 && {command}"

            result = subprocess.run(
                command,
                shell=True,
                cwd=self.cwd,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                text=True,
                encoding="utf-8",
                errors="replace",
                timeout=timeout,
                env=env,
            )

            logger.debug(f"[Bash] Exit code: {result.returncode}")
            logger.debug(f"[Bash] Stdout length: {len(result.stdout)}")
            logger.debug(f"[Bash] Stderr length: {len(result.stderr)}")

            # Workaround for exit code 126 with no output
            if result.returncode == 126 and not result.stdout and not result.stderr:
                result = self._retry_exit_126(command, timeout, env, result)

            # When command succeeds with stdout, keep output clean (stderr goes to server log only).
            # When command fails or stdout is empty, include stderr so the agent can diagnose.
            output = result.stdout
            if result.returncode == 0 and result.stdout.strip():
                if result.stderr:
                    logger.info(f"[Bash] stderr (not forwarded): {result.stderr[:500]}")
            elif result.stderr:
                output += "\n" + result.stderr

            return self._build_result(output, result.returncode)

        except subprocess.TimeoutExpired:
            return ToolResult.fail(f"Error: Command timed out after {timeout} seconds")
        except Exception as e:
            return ToolResult.fail(f"Error executing command: {str(e)}")

    @staticmethod
    def _build_env():
        """Return ``(env, dotenv_vars)``: the process environment plus ~/.cow/.env values.

        Entries that python-dotenv reports without a value (``None``) are
        dropped, because a None value in ``env`` makes subprocess.run raise.
        """
        env = os.environ.copy()
        env_file = expand_path("~/.cow/.env")
        dotenv_vars = {}
        if os.path.exists(env_file):
            try:
                from dotenv import dotenv_values
                loaded = dotenv_values(env_file)
                dotenv_vars = {k: v for k, v in loaded.items() if v is not None}
                env.update(dotenv_vars)
                logger.debug(f"[Bash] Loaded {len(dotenv_vars)} variables from {env_file}")
            except ImportError:
                logger.debug("[Bash] python-dotenv not installed, skipping .env loading")
            except Exception as e:
                logger.debug(f"[Bash] Failed to load .env: {e}")
        return env, dotenv_vars

    def _retry_exit_126(self, command: str, timeout, env: dict, result):
        """Retry a command that exited 126 without output, this time without a shell.

        :return: the retry result when it produced anything useful, a synthetic
                 error result for the vision skill, otherwise the original *result*
        """
        logger.warning(f"[Bash] Exit 126 with no output - trying alternative execution method")
        # Try using argument list instead of shell=True
        import shlex
        try:
            parts = shlex.split(command)
            if len(parts) > 0:
                logger.info(f"[Bash] Retrying with argument list: {parts[:3]}...")
                retry_result = subprocess.run(
                    parts,
                    cwd=self.cwd,
                    stdout=subprocess.PIPE,
                    stderr=subprocess.PIPE,
                    text=True,
                    encoding="utf-8",
                    errors="replace",
                    timeout=timeout,
                    env=env
                )
                logger.debug(f"[Bash] Retry exit code: {retry_result.returncode}, stdout: {len(retry_result.stdout)}, stderr: {len(retry_result.stderr)}")

                # If retry succeeded, use retry result
                if retry_result.returncode == 0 or retry_result.stdout or retry_result.stderr:
                    return retry_result

                # Both attempts failed - check if this is openai-image-vision skill
                if 'openai-image-vision' in command or 'vision.sh' in command:
                    # Create a mock result with helpful error message
                    from types import SimpleNamespace
                    logger.info(f"[Bash] Converted exit 126 to user-friendly image error message for vision skill")
                    return SimpleNamespace(
                        returncode=1,
                        stdout='{"error": "图片无法解析", "reason": "该图片格式可能不受支持,或图片文件存在问题", "suggestion": "请尝试其他图片"}',
                        stderr=''
                    )
        except Exception as retry_err:
            logger.warning(f"[Bash] Retry failed: {retry_err}")
        return result

    @staticmethod
    def _save_full_output(output: str):
        """Write *output* to a temp .log file; return its path, or None if writing failed."""
        try:
            # Explicit UTF-8 so non-ASCII output does not depend on the platform default encoding.
            with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.log', prefix='bash-',
                                             encoding='utf-8') as f:
                f.write(output)
                return f.name
        except OSError as e:
            logger.warning(f"[Bash] Failed to save full output: {e}")
            return None

    def _build_result(self, output: str, returncode: int) -> ToolResult:
        """Tail-truncate *output*, attach truncation details and wrap it in a ToolResult."""
        # Apply tail truncation
        truncation = truncate_tail(output)
        output_text = truncation.content or "(no output)"

        # Build result
        details = {}

        if truncation.truncated:
            # Save the full output whenever anything was cut (by lines or by bytes),
            # so the notice below never points at a file that was not written.
            temp_file_path = self._save_full_output(output)
            details["truncation"] = truncation.to_dict()
            if temp_file_path:
                details["full_output_path"] = temp_file_path
            full_hint = f" Full output: {temp_file_path}" if temp_file_path else ""

            # Build notice
            start_line = truncation.total_lines - truncation.output_lines + 1
            end_line = truncation.total_lines

            if truncation.last_line_partial:
                # Edge case: last line alone > 30KB
                last_line = output.split('\n')[-1] if output else ""
                last_line_size = format_size(len(last_line.encode('utf-8')))
                output_text += f"\n\n[Showing last {format_size(truncation.output_bytes)} of line {end_line} (line is {last_line_size}).{full_hint}]"
            elif truncation.truncated_by == "lines":
                output_text += f"\n\n[Showing lines {start_line}-{end_line} of {truncation.total_lines}.{full_hint}]"
            else:
                output_text += f"\n\n[Showing lines {start_line}-{end_line} of {truncation.total_lines} ({format_size(DEFAULT_MAX_BYTES)} limit).{full_hint}]"

        payload = {
            "output": output_text,
            "exit_code": returncode,
            "details": details if details else None
        }

        # Check exit code
        if returncode != 0:
            payload["output"] += f"\n\nCommand exited with code {returncode}"
            return ToolResult.fail(payload)

        return ToolResult.success(payload)

    def _get_safety_warning(self, command: str) -> str:
        """
        Get safety warning for absolutely catastrophic commands only.
        Keep the blocklist minimal so the agent retains maximum freedom.

        :param command: Command to check
        :return: Warning message if dangerous, empty string if safe
        """
        lowered = command.lower()

        # Tokenize to avoid substring false positives (e.g. `rm -rf /tmp/x`
        # must not match `rm -rf /`).
        tokens = lowered.split()

        # `rm -rf /` or `rm -rf /*` targeting the real root. The recursive and
        # force flags may be clustered (-rf), separate (-r -f) or long-form.
        for i, tok in enumerate(tokens):
            if tok != "rm":
                continue
            recursive = force = False
            for t in tokens[i + 1:]:
                if t.startswith("--"):
                    if t == "--recursive":
                        recursive = True
                    elif t == "--force":
                        force = True
                    # any other long option (e.g. --no-preserve-root) is skipped
                elif t.startswith("-") and len(t) > 1:
                    flags = t[1:]
                    recursive = recursive or "r" in flags  # input is lowercased, so -R counts too
                    force = force or "f" in flags
                elif t in ("/", "/*"):
                    if recursive and force:
                        return "This command will delete the entire filesystem"
                    break
                else:
                    # first ordinary operand is not the root: stop scanning this rm
                    break

        # Disk wiping
        if "if=/dev/zero" in lowered and "dd " in lowered:
            return "This command can destroy disk data"

        # Power control - match only as a standalone word (\b enforces word boundary)
        if re.search(r'\b(shutdown|reboot|halt|poweroff)\b', lowered):
            return "This command will shut down or restart the system"

        return ""

    @staticmethod
    def _convert_env_vars_for_windows(command: str, dotenv_vars: dict) -> str:
        """
        Convert bash-style $VAR / ${VAR} references to cmd.exe %VAR% syntax.
        Only converts variables loaded from .env (user-configured API keys etc.)
        to avoid breaking $PATH, jq expressions, regex, etc.

        :param command: command line as written for bash
        :param dotenv_vars: variables loaded from the .env file
        :return: the command with references to *dotenv_vars* keys rewritten
        """
        if not dotenv_vars:
            return command

        def replace_match(m):
            # group(1) is the ${VAR} form, group(2) the bare $VAR form
            var_name = m.group(1) or m.group(2)
            if var_name in dotenv_vars:
                return f"%{var_name}%"
            return m.group(0)

        return re.sub(r'\$\{(\w+)\}|\$(\w+)', replace_match, command)
|
||||
3
agent/tools/browser/__init__.py
Normal file
3
agent/tools/browser/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from agent.tools.browser.browser_tool import BrowserTool
|
||||
|
||||
__all__ = ["BrowserTool"]
|
||||
961
agent/tools/browser/browser_service.py
Normal file
961
agent/tools/browser/browser_service.py
Normal file
@@ -0,0 +1,961 @@
|
||||
"""
|
||||
Browser service - Playwright wrapper managing browser lifecycle and page operations.
|
||||
|
||||
All Playwright calls run on a dedicated background thread so that callers from
|
||||
any worker thread can safely use the service. An idle-timeout mechanism
|
||||
automatically shuts down the browser (and its thread) after a configurable
|
||||
period of inactivity to free resources.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import uuid
|
||||
import queue
|
||||
import threading
|
||||
from typing import Optional, Dict, Any, List, Callable
|
||||
|
||||
from common.log import logger
|
||||
from common.utils import expand_path, is_cloud_deployment
|
||||
|
||||
|
||||
_DEFAULT_USER_DATA_DIR = "~/.cow/browser_profile"
|
||||
|
||||
try:
|
||||
from playwright.sync_api import sync_playwright, Browser, BrowserContext, Page, Playwright
|
||||
_HAS_PLAYWRIGHT = True
|
||||
except ImportError:
|
||||
_HAS_PLAYWRIGHT = False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Snapshot DOM helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Tags that typically carry useful content for an agent
|
||||
_INTERACTIVE_TAGS = {
|
||||
"a", "button", "input", "textarea", "select", "option",
|
||||
"label", "details", "summary",
|
||||
}
|
||||
_SEMANTIC_TAGS = {
|
||||
"h1", "h2", "h3", "h4", "h5", "h6",
|
||||
"p", "li", "td", "th", "caption", "figcaption", "blockquote", "pre", "code",
|
||||
"nav", "main", "article", "section", "header", "footer", "form", "table",
|
||||
"img", "video", "audio",
|
||||
}
|
||||
_KEEP_TAGS = _INTERACTIVE_TAGS | _SEMANTIC_TAGS
|
||||
|
||||
_SNAPSHOT_JS = """
|
||||
() => {
|
||||
const KEEP = new Set(%s);
|
||||
const INTERACTIVE = new Set(%s);
|
||||
const SKIP = new Set(["script","style","noscript","svg","path","meta","link","br","hr"]);
|
||||
const CLICKABLE_ROLES = new Set([
|
||||
"button","link","tab","menuitem","menuitemcheckbox","menuitemradio",
|
||||
"option","switch","checkbox","radio","combobox","searchbox","slider",
|
||||
"spinbutton","textbox","treeitem"
|
||||
]);
|
||||
let refCounter = 0;
|
||||
const refMap = {};
|
||||
|
||||
function visible(el) {
|
||||
if (!(el instanceof HTMLElement)) return true;
|
||||
const st = window.getComputedStyle(el);
|
||||
if (st.display === "none" || st.visibility === "hidden") return false;
|
||||
if (parseFloat(st.opacity) === 0) return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
// Strong signals: these attributes alone are enough to mark as interactive
|
||||
function hasStrongInteractiveSignal(el) {
|
||||
const role = el.getAttribute("role");
|
||||
if (role && CLICKABLE_ROLES.has(role)) return true;
|
||||
if (el.hasAttribute("onclick") || el.hasAttribute("tabindex")) return true;
|
||||
if (el.hasAttribute("data-click") || el.hasAttribute("data-action")) return true;
|
||||
if (el.getAttribute("contenteditable") === "true") return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check if cursor:pointer is set directly (not just inherited from parent)
|
||||
function hasOwnPointerCursor(el) {
|
||||
try {
|
||||
const st = window.getComputedStyle(el);
|
||||
if (st.cursor !== "pointer") return false;
|
||||
const parent = el.parentElement;
|
||||
if (parent) {
|
||||
const pst = window.getComputedStyle(parent);
|
||||
if (pst.cursor === "pointer") return false;
|
||||
}
|
||||
return true;
|
||||
} catch(e) {}
|
||||
return false;
|
||||
}
|
||||
|
||||
function hasTextOrContent(el) {
|
||||
const t = el.textContent || "";
|
||||
if (t.trim().length > 0) return true;
|
||||
if (el.querySelector("img,video,audio,canvas")) return true;
|
||||
const ariaLabel = el.getAttribute("aria-label");
|
||||
if (ariaLabel && ariaLabel.trim()) return true;
|
||||
const title = el.getAttribute("title");
|
||||
if (title && title.trim()) return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
function isImplicitInteractive(el) {
|
||||
if (hasStrongInteractiveSignal(el)) return true;
|
||||
if (hasOwnPointerCursor(el) && hasTextOrContent(el)) return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
function getTextContent(el) {
|
||||
let text = "";
|
||||
for (const ch of el.childNodes) {
|
||||
if (ch.nodeType === Node.TEXT_NODE) {
|
||||
text += ch.textContent;
|
||||
}
|
||||
}
|
||||
return text.trim();
|
||||
}
|
||||
|
||||
function walk(node) {
|
||||
if (node.nodeType === Node.TEXT_NODE) {
|
||||
const t = node.textContent.trim();
|
||||
return t ? t : null;
|
||||
}
|
||||
if (node.nodeType !== Node.ELEMENT_NODE) return null;
|
||||
const tag = node.tagName.toLowerCase();
|
||||
if (SKIP.has(tag)) return null;
|
||||
if (!visible(node)) return null;
|
||||
|
||||
const children = [];
|
||||
for (const ch of node.childNodes) {
|
||||
const r = walk(ch);
|
||||
if (r !== null) {
|
||||
if (typeof r === "string") children.push(r);
|
||||
else children.push(r);
|
||||
}
|
||||
}
|
||||
|
||||
const nativeInteractive = INTERACTIVE.has(tag);
|
||||
const implicitInteractive = !nativeInteractive && (node instanceof HTMLElement) && isImplicitInteractive(node);
|
||||
const keep = KEEP.has(tag) || implicitInteractive;
|
||||
|
||||
if (!keep) {
|
||||
if (children.length === 0) return null;
|
||||
if (children.length === 1) return children[0];
|
||||
return children;
|
||||
}
|
||||
|
||||
const obj = { tag };
|
||||
if (nativeInteractive || implicitInteractive) {
|
||||
refCounter++;
|
||||
obj.ref = refCounter;
|
||||
refMap[refCounter] = node;
|
||||
}
|
||||
|
||||
if (implicitInteractive) {
|
||||
const role = node.getAttribute("role");
|
||||
if (role) obj.role = role;
|
||||
const directText = getTextContent(node);
|
||||
if (!directText && children.length === 0) {
|
||||
const ariaLabel = node.getAttribute("aria-label");
|
||||
const title = node.getAttribute("title");
|
||||
if (ariaLabel) obj.ariaLabel = ariaLabel;
|
||||
else if (title) obj.ariaLabel = title;
|
||||
}
|
||||
}
|
||||
|
||||
// Attributes
|
||||
if (tag === "a" && node.href) obj.href = node.getAttribute("href");
|
||||
if (tag === "img") {
|
||||
obj.alt = node.alt || "";
|
||||
obj.src = node.getAttribute("src") || "";
|
||||
}
|
||||
if (tag === "input" || tag === "textarea" || tag === "select") {
|
||||
obj.type = node.type || "text";
|
||||
obj.name = node.name || undefined;
|
||||
obj.value = node.value || undefined;
|
||||
obj.placeholder = node.placeholder || undefined;
|
||||
if (node.disabled) obj.disabled = true;
|
||||
if (tag === "input" && node.type === "checkbox") obj.checked = node.checked;
|
||||
}
|
||||
if (tag === "button") {
|
||||
if (node.disabled) obj.disabled = true;
|
||||
}
|
||||
if (tag === "option") {
|
||||
obj.value = node.value;
|
||||
if (node.selected) obj.selected = true;
|
||||
}
|
||||
if (tag === "label" && node.htmlFor) obj.for = node.htmlFor;
|
||||
|
||||
// Role / aria-label for native interactive & semantic elements
|
||||
if (!implicitInteractive) {
|
||||
const role = node.getAttribute("role");
|
||||
if (role) obj.role = role;
|
||||
const ariaLabel = node.getAttribute("aria-label");
|
||||
if (ariaLabel) obj.ariaLabel = ariaLabel;
|
||||
}
|
||||
|
||||
// Children
|
||||
if (children.length === 1 && typeof children[0] === "string") {
|
||||
obj.text = children[0];
|
||||
} else if (children.length > 0) {
|
||||
obj.children = children;
|
||||
}
|
||||
|
||||
return obj;
|
||||
}
|
||||
|
||||
const result = walk(document.body);
|
||||
window.__cowRefMap = refMap;
|
||||
return { tree: result, refCount: refCounter };
|
||||
}
|
||||
""" % (
|
||||
str(list(_KEEP_TAGS)),
|
||||
str(list(_INTERACTIVE_TAGS)),
|
||||
)
|
||||
|
||||
|
||||
_BROWSER_DEAD_HINTS = (
|
||||
"has been closed",
|
||||
"browser has disconnected",
|
||||
"target closed",
|
||||
"browser closed",
|
||||
"context or browser has been closed",
|
||||
)
|
||||
|
||||
|
||||
def _is_browser_dead_error(err: Exception) -> bool:
    """Tell whether *err* reports that the browser or page is gone.

    The lower-cased exception message is searched for any of the substrings
    listed in ``_BROWSER_DEAD_HINTS``.
    """
    lowered = str(err).lower()
    for hint in _BROWSER_DEAD_HINTS:
        if hint in lowered:
            return True
    return False
|
||||
|
||||
|
||||
def _should_use_headless() -> bool:
    """Pick the default headless setting for the current machine.

    Returns False on Windows and macOS. On any other platform returns False
    when ``DISPLAY`` or ``WAYLAND_DISPLAY`` is set to a non-empty value, and
    True otherwise.
    """
    if sys.platform in ("win32", "darwin"):
        return False
    env = os.environ
    # Either variable being non-empty means a display server is reachable.
    display_present = env.get("DISPLAY") or env.get("WAYLAND_DISPLAY")
    return not display_present
|
||||
|
||||
|
||||
def _flatten_tree(node, indent=0) -> List[str]:
    """Convert a snapshot tree into compact text lines for LLM consumption.

    Args:
        node: ``None``, a text string, a list of nodes, or a dict node with
            the keys read below (``tag``, ``ref``, attribute keys, flag keys,
            ``text``, ``children``). Any other type yields no lines.
        indent: Number of leading spaces for this node; children are
            indented two more.

    Returns:
        One string per node. Every returned string is a single line:
        whitespace runs (including line breaks, e.g. in a textarea value)
        are collapsed to one space so the indentation-based layout is kept.
    """

    def _one_line(value) -> str:
        # str.split() with no argument splits on any whitespace run,
        # so this also removes embedded newlines and tabs.
        return " ".join(str(value).split())

    if node is None:
        return []
    pad = " " * indent
    if isinstance(node, str):
        return [pad + _one_line(node)]
    if isinstance(node, list):
        lines: List[str] = []
        for child in node:
            lines.extend(_flatten_tree(child, indent))
        return lines
    if not isinstance(node, dict):
        return []

    tag = node.get("tag", "?")
    ref = node.get("ref")
    parts = [f"[{ref}] {tag}" if ref else tag]

    # Inline attributes (falsy values are omitted)
    for attr in ("type", "name", "href", "alt", "role", "ariaLabel", "placeholder", "value"):
        val = node.get(attr)
        if val:
            s = _one_line(val)
            # Truncate long values
            if len(s) > 80:
                s = s[:77] + "..."
            parts.append(f'{attr}="{s}"')

    parts.extend(flag for flag in ("disabled", "checked", "selected") if node.get(flag))

    header = pad + " ".join(parts)

    text = node.get("text")
    if text:
        text = _one_line(text)
    if text:
        # Truncate long text
        if len(text) > 120:
            text = text[:117] + "..."
        header += f": {text}"

    lines = [header]
    # "children" may be missing or explicitly None; both mean "no children".
    for child in node.get("children") or []:
        lines.extend(_flatten_tree(child, indent + 2))
    return lines
|
||||
|
||||
|
||||
class BrowserService:
|
||||
"""Manages a Playwright browser on a dedicated background thread.
|
||||
|
||||
All Playwright operations are dispatched to a single long-lived thread via
|
||||
a task queue. Callers from *any* worker thread can use the public API
|
||||
safely. An idle timer automatically shuts the browser down after
|
||||
``idle_timeout`` seconds of inactivity (default 300 = 5 min).
|
||||
"""
|
||||
|
||||
_IDLE_TIMEOUT_DEFAULT = 300 # seconds
|
||||
|
||||
def __init__(self, config: Optional[Dict[str, Any]] = None):
    """Store configuration and initialise all state; nothing is launched here.

    Config keys read: ``cdp_endpoint``, ``persistent`` (default True),
    ``user_data_dir`` (default ``_DEFAULT_USER_DATA_DIR``) and
    ``idle_timeout`` (default ``_IDLE_TIMEOUT_DEFAULT``).
    """
    self._config = config or {}
    settings = self._config
    self._headless: Optional[bool] = None
    self._screenshot_dir: Optional[str] = None

    # Background thread state
    self._thread: Optional[threading.Thread] = None
    self._task_queue: queue.Queue = queue.Queue()
    self._lock = threading.Lock()
    self._alive = False
    self._ready = threading.Event()

    # Playwright objects (only accessed on the background thread)
    self._playwright = None
    self._browser = None
    self._context = None
    self._page = None

    # Launch mode: one of "fresh" | "persistent" | "cdp".
    #   - cdp: connect to an externally launched Chrome via CDP endpoint.
    #   - persistent: launch with launch_persistent_context using a user_data_dir
    #     so cookies / login state survive across runs (default).
    #   - fresh: classic launch + new_context, clean state every run.
    raw_endpoint = settings.get("cdp_endpoint") or ""
    want_persistent = settings.get("persistent", True)
    profile_dir = settings.get("user_data_dir")
    if profile_dir is None:
        profile_dir = _DEFAULT_USER_DATA_DIR

    # A non-string endpoint is treated as "not configured".
    self._cdp_endpoint: str = raw_endpoint.strip() if isinstance(raw_endpoint, str) else ""
    if self._cdp_endpoint:
        mode, data_dir = "cdp", ""
    elif want_persistent and profile_dir:
        mode, data_dir = "persistent", expand_path(str(profile_dir))
    else:
        mode, data_dir = "fresh", ""
    self._launch_mode = mode
    self._user_data_dir: str = data_dir

    # Idle auto-release
    idle_cfg = settings.get("idle_timeout")
    if idle_cfg is None:
        self._idle_timeout: float = self._IDLE_TIMEOUT_DEFAULT
    else:
        self._idle_timeout = float(idle_cfg)
    self._idle_timer: Optional[threading.Timer] = None

    # Set when the browser / page is detected to have died externally
    # (e.g. user manually closed the window). The next _submit() will then
    # tear down the stale thread and relaunch.
    self._needs_restart = False
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Background-thread lifecycle
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _start_thread(self):
    """Start the dedicated Playwright thread unless one is already running.

    Returns immediately when a live thread exists. Otherwise a new thread
    running ``_run_loop`` is created and the caller waits up to 30 seconds
    for ``_ready`` to be set (it is set on launch success and on failure).
    """
    with self._lock:
        current = self._thread
        if self._alive and current and current.is_alive():
            return
        # Wait for old thread to fully exit before creating a new one
        if current and current.is_alive():
            current.join(timeout=5)
        # Fresh queue to avoid stale sentinels from a previous close()
        self._task_queue = queue.Queue()
        self._alive = True
        self._ready = threading.Event()
        worker = threading.Thread(target=self._run_loop, daemon=True, name="BrowserThread")
        self._thread = worker
        worker.start()
    # Block until browser is ready (or failed)
    # NOTE(review): indentation was lost in this copy; the wait is assumed to
    # sit outside the lock so close() is not blocked during launch — confirm.
    self._ready.wait(timeout=30)
|
||||
|
||||
def _run_loop(self):
    """Event loop of the dedicated browser thread.

    Launches the browser, then executes queued ``(fn, args, kwargs,
    result_slot)`` tasks until ``_alive`` becomes False or a ``None``
    sentinel is dequeued. Each task's return value or exception is stored
    in its ``result_slot`` and the slot's event is set.

    If the launch fails, anything already created (the Playwright driver is
    started before the browser itself) is released via
    ``_shutdown_browser`` so a failed launch does not leak resources, and
    all queued callers are unblocked with a ``RuntimeError``.
    """
    logger.info("[Browser] Background thread started")
    try:
        self._launch_browser()
    except Exception as e:
        logger.error(f"[Browser] Failed to launch browser: {e}")
        self._alive = False
        # Unblock _start_thread() first; cleanup below may take a moment.
        self._ready.set()
        try:
            self._shutdown_browser()
        finally:
            self._drain_queue(RuntimeError(f"Browser launch failed: {e}"))
        return
    self._ready.set()

    while self._alive:
        try:
            # Short timeout so a cleared _alive flag is noticed promptly.
            task = self._task_queue.get(timeout=1.0)
        except queue.Empty:
            continue
        if task is None:
            break
        fn, args, kwargs, result_slot = task
        try:
            result_slot["value"] = fn(*args, **kwargs)
        except Exception as e:
            result_slot["error"] = e
            if _is_browser_dead_error(e):
                self._needs_restart = True
                logger.warning(
                    f"[Browser] Detected closed page/context ({e}); "
                    "will relaunch on next request."
                )
        finally:
            # Always wake the caller, whether fn succeeded or raised.
            result_slot["event"].set()

    self._shutdown_browser()
    self._drain_queue(RuntimeError("Browser thread stopped"))
    logger.info("[Browser] Background thread exited")
|
||||
|
||||
def _drain_queue(self, error: Exception):
    """Empty the task queue, failing every pending task with *error*.

    ``None`` sentinels are discarded. For each real task the error is stored
    in its result slot and the slot's event is set so the waiting caller
    wakes up.
    """
    while True:
        try:
            item = self._task_queue.get_nowait()
        except queue.Empty:
            return
        if item is not None:
            _fn, _args, _kwargs, slot = item
            slot["error"] = error
            slot["event"].set()
|
||||
|
||||
def _launch_browser(self):
    """Launch / connect Chromium on the background thread.

    Resolves the headless setting once (config ``headless`` wins, otherwise
    ``_should_use_headless()``), assembles the Chromium flags and viewport,
    starts Playwright and delegates to the method for the configured
    launch mode.
    """
    if self._headless is None:
        configured = self._config.get("headless")
        self._headless = _should_use_headless() if configured is None else configured

    launch_args = ["--disable-dev-shm-usage"]
    if self._headless:
        launch_args.append("--no-sandbox")

    if is_cloud_deployment():
        launch_args += [
            "--disable-gpu",
            "--disable-software-rasterizer",
            "--disable-extensions",
            "--disable-background-networking",
            "--disable-background-timer-throttling",
            "--disable-renderer-backgrounding",
            "--disable-features=site-per-process,TranslateUI,IsolateOrigins",
            "--no-zygote",
            "--js-flags=--max-old-space-size=384",
            "--memory-pressure-off",
        ]

    # User-supplied flags go last.
    extra_args = self._config.get("launch_args", [])
    if extra_args:
        launch_args.extend(extra_args)

    viewport = {
        "width": self._config.get("viewport_width", 1280),
        "height": self._config.get("viewport_height", 720),
    }
    user_agent = (
        "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) "
        "AppleWebKit/537.36 (KHTML, like Gecko) "
        "Chrome/131.0.0.0 Safari/537.36"
    )

    self._playwright = sync_playwright().start()

    mode = self._launch_mode
    if mode == "cdp":
        self._connect_cdp(viewport)
    elif mode == "persistent":
        self._launch_persistent(launch_args, viewport, user_agent)
    else:
        self._launch_fresh(launch_args, viewport, user_agent)

    logger.info("[Browser] Browser ready")
|
||||
|
||||
def _launch_fresh(self, launch_args: List[str], viewport: Dict[str, int], user_agent: str):
    """Start a brand new Chromium with an empty context and one page.

    Stores the browser, context and page on ``self`` and wires the close
    listeners.
    """
    logger.info(f"[Browser] Launching Chromium (fresh, headless={self._headless})")
    chromium = self._playwright.chromium
    browser = chromium.launch(headless=self._headless, args=launch_args)
    self._browser = browser
    context = browser.new_context(viewport=viewport, user_agent=user_agent)
    self._context = context
    self._page = context.new_page()
    self._wire_close_listeners()
|
||||
|
||||
def _launch_persistent(self, launch_args: List[str], viewport: Dict[str, int], user_agent: str):
    """Launch Chromium with a persistent user_data_dir so login state survives.

    Creates the profile directory if needed, launches a persistent context
    and reuses the page Chromium opens automatically (or creates one).

    Raises:
        RuntimeError: when the launch error indicates the profile is held
            by another Chromium process.
        Exception: any other launch error is re-raised unchanged.
    """
    os.makedirs(self._user_data_dir, exist_ok=True)
    logger.info(
        f"[Browser] Launching Chromium (persistent, headless={self._headless}, "
        f"profile={self._user_data_dir})"
    )
    try:
        self._context = self._playwright.chromium.launch_persistent_context(
            user_data_dir=self._user_data_dir,
            headless=self._headless,
            args=launch_args,
            viewport=viewport,
            user_agent=user_agent,
        )
    except Exception as e:
        # Profile is locked when another Chromium instance already holds it.
        # Match specific markers only: launch errors can embed the Chromium
        # command line, where the profile path or a flag such as
        # "--disable-popup-blocking" would match bare "profile" / "lock" and
        # mislabel unrelated failures as a locked profile.
        msg = str(e).lower()
        lock_markers = (
            "singletonlock",
            "processsingleton",
            "profile appears to be in use",
            "user data directory is already in use",
            "opening in existing browser session",
        )
        if any(marker in msg for marker in lock_markers):
            raise RuntimeError(
                f"Browser profile '{self._user_data_dir}' is in use by another process. "
                "Close the other Chromium / cow instance, or set a different "
                "tools.browser.user_data_dir."
            ) from e
        raise

    # Persistent context has no parent Browser handle; reuse the auto-created page.
    self._browser = None
    pages = self._context.pages
    self._page = pages[0] if pages else self._context.new_page()
    self._wire_close_listeners()
|
||||
|
||||
def _connect_cdp(self, viewport: Dict[str, int]):
    """Attach to an existing Chrome started with --remote-debugging-port.

    Reuses the first existing context and page when available, otherwise
    creates them.

    Raises:
        RuntimeError: when the error text indicates the endpoint could not
            be reached (connection refused / reset / unreachable / timeout).
        Exception: any other connection error is re-raised unchanged.
    """
    endpoint = self._cdp_endpoint
    logger.info(f"[Browser] Connecting to existing Chrome via CDP: {endpoint}")
    try:
        self._browser = self._playwright.chromium.connect_over_cdp(endpoint)
    except Exception as e:
        # Only network-level markers count as "Chrome is not running".
        # Testing for bare "connect" would match the API name
        # "connect_over_cdp", which can appear in the error text, turning
        # every failure into the "cannot reach" message.
        msg = str(e).lower()
        unreachable_markers = (
            "refused",  # covers ECONNREFUSED / "connection refused"
            "econnreset",
            "ehostunreach",
            "enetunreach",
            "enotfound",
            "etimedout",
        )
        if any(marker in msg for marker in unreachable_markers):
            raise RuntimeError(
                f"Cannot reach Chrome at {endpoint}. The CDP browser is not "
                "running. Ask the user to launch Chrome with "
                "--remote-debugging-port and --user-data-dir, then retry. "
                "Do not retry this tool until the user confirms."
            ) from e
        raise

    contexts = self._browser.contexts
    if contexts:
        self._context = contexts[0]
    else:
        self._context = self._browser.new_context(viewport=viewport)

    pages = self._context.pages
    self._page = pages[0] if pages else self._context.new_page()
    self._wire_close_listeners()
|
||||
|
||||
def _wire_close_listeners(self):
    """Flag ``_needs_restart`` when the browser, context or page goes away.

    Registers one shared callback for the browser's ``disconnected`` event
    and the context's and page's ``close`` events. Handles that are unset
    are skipped; a registration failure is only logged at debug level.
    """
    def _mark_dead(_obj=None):
        self._needs_restart = True

    try:
        for emitter, event_name in (
            (self._browser, "disconnected"),
            (self._context, "close"),
            (self._page, "close"),
        ):
            if emitter:
                emitter.on(event_name, _mark_dead)
    except Exception as e:
        logger.debug(f"[Browser] Failed to wire close listeners: {e}")
|
||||
|
||||
def _shutdown_browser(self):
    """Release Playwright resources on the background thread.

    Mode-specific behavior:
      - cdp: only ``browser.close()`` is called, which the original comments
        describe as detaching the Playwright client while leaving the
        user's Chrome and its tabs untouched (the context is NOT closed).
      - persistent: close the persistent context (no separate browser handle).
      - fresh: close context, then browser.

    Playwright itself is stopped last. Each step is best-effort: errors are
    logged at debug level and the remaining steps still run. All handles
    are cleared afterwards.
    """
    self._cancel_idle_timer()

    # (handle, method to call, label used in the debug message)
    if self._launch_mode == "cdp":
        steps = [(self._browser, "close", "cdp disconnect")]
    else:
        steps = [
            (self._context, "close", "context close"),
            (self._browser, "close", "browser close"),
        ]
    steps.append((self._playwright, "stop", "playwright stop"))

    for handle, method_name, label in steps:
        try:
            if handle:
                getattr(handle, method_name)()
        except Exception as e:
            logger.debug(f"[Browser] {label} error: {e}")

    self._page = None
    self._context = None
    self._browser = None
    self._playwright = None
    logger.info("[Browser] Browser closed")
|
||||
|
||||
def _submit(self, fn: Callable, *args, **kwargs):
    """Run ``fn(*args, **kwargs)`` on the browser thread and wait for it.

    Returns:
        Whatever *fn* returned.

    Raises:
        RuntimeError: if the browser is not available after (re)starting.
        TimeoutError: if the task is not finished within 120 seconds.
        Exception: whatever *fn* raised is re-raised in the calling thread.
    """
    # If the browser died externally (e.g. user closed the window), tear
    # down the stale thread first so _start_thread() will relaunch fresh.
    if self._needs_restart:
        logger.info("[Browser] Restarting after detecting closed browser")
        self.close()
        self._needs_restart = False

    self._start_thread()
    if not self._alive:
        raise RuntimeError("Browser is not available")

    self._reset_idle_timer()

    done = threading.Event()
    slot: Dict[str, Any] = {"event": done}
    self._task_queue.put((fn, args, kwargs, slot))

    # Timeout prevents permanent hang if the background thread crashes
    if not done.wait(timeout=120):
        raise TimeoutError("Browser operation timed out (120s)")

    if "error" in slot:
        raise slot["error"]
    return slot.get("value")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Idle auto-release
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _reset_idle_timer(self):
    """Restart the idle countdown; a timeout of 0 or less disables it."""
    self._cancel_idle_timer()
    if self._idle_timeout > 0:
        timer = threading.Timer(self._idle_timeout, self._on_idle_timeout)
        timer.daemon = True
        self._idle_timer = timer
        timer.start()

def _cancel_idle_timer(self):
    """Cancel the pending idle timer, if any.

    This is reached from caller threads (via ``_submit`` and ``close``), the
    timer thread (via ``_on_idle_timeout`` -> ``close``) and the browser
    thread (via ``_shutdown_browser``). The attribute is therefore swapped
    out in one step before use, so a concurrent reset to ``None`` cannot
    cause ``None.cancel()``.
    """
    timer, self._idle_timer = self._idle_timer, None
    if timer is not None:
        timer.cancel()

def _on_idle_timeout(self):
    """Timer callback: log the idle period and close the browser."""
    logger.info(f"[Browser] Idle for {self._idle_timeout}s, auto-releasing browser")
    self.close()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Public lifecycle
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def close(self):
    """Stop the browser and its background thread.

    Does nothing beyond clearing ``_needs_restart`` when the service is not
    alive. Otherwise marks the service as stopped, enqueues a ``None``
    sentinel to wake the loop, waits up to 10 seconds for the thread to
    exit and then clears the thread reference.
    """
    self._cancel_idle_timer()
    with self._lock:
        if not self._alive:
            self._needs_restart = False
            return
        self._alive = False
        worker = self._thread
    # NOTE(review): indentation was lost in this copy; the sentinel/join are
    # assumed to run outside the lock (the lock is re-acquired below).
    if self._task_queue is not None:
        self._task_queue.put(None)
    if worker is not None and worker.is_alive():
        worker.join(timeout=10)
    with self._lock:
        self._thread = None
        self._needs_restart = False
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Actions (each method is dispatched to the background thread)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def navigate(self, url: str, timeout: int = 30000) -> Dict[str, Any]:
    """Load *url*; runs ``_do_navigate`` on the browser thread."""
    return self._submit(self._do_navigate, url, timeout)

def _do_navigate(self, url: str, timeout: int) -> Dict[str, Any]:
    """Navigate the page and report where it ended up.

    Returns ``{"url", "title", "status"}`` on success, or
    ``{"error": ...}`` when the initial ``goto`` fails. ``status`` is None
    when ``goto`` returns no response object.
    """
    page = self._page
    try:
        response = page.goto(url, wait_until="domcontentloaded", timeout=timeout)
        status = None if not response else response.status
    except Exception as e:
        return {"error": f"Navigation failed: {e}"}

    # Best-effort settle: give the network up to 8s to go idle, then a
    # short fixed pause. A timeout here is not an error.
    try:
        page.wait_for_load_state("networkidle", timeout=8000)
    except Exception:
        pass
    page.wait_for_timeout(500)

    info: Dict[str, Any] = {}
    try:
        info["url"] = page.url
    except Exception:
        info["url"] = url
    try:
        info["title"] = page.title()
    except Exception:
        info["title"] = ""
    return {"url": info["url"], "title": info["title"], "status": status}
|
||||
|
||||
def snapshot(self, selector: Optional[str] = None) -> str:
    """Return a text snapshot of the page; runs ``_do_snapshot`` on the browser thread."""
    return self._submit(self._do_snapshot, selector)

def _do_snapshot(self, selector: Optional[str] = None) -> str:
    """Build the text snapshot of the current page.

    Evaluates ``_SNAPSHOT_JS`` in the page, flattens the returned tree into
    lines and prefixes a header with title, URL and interactive-element
    count. The body is cut at config ``snapshot_max_chars`` (default 30000).

    NOTE(review): *selector* is accepted but not used in this method.
    """
    page = self._page
    try:
        result = page.evaluate(_SNAPSHOT_JS)
    except Exception as e:
        return f"[Snapshot error: {e}]"

    lines = _flatten_tree(result.get("tree"))
    ref_count = result.get("refCount", 0)

    try:
        title = page.title()
    except Exception:
        title = ""
    try:
        url = page.url
    except Exception:
        url = ""

    body = "\n".join(lines)
    limit = self._config.get("snapshot_max_chars", 30000)
    if len(body) > limit:
        body = body[:limit] + "\n... [snapshot truncated]"

    return f"Page: {title} ({url})\nInteractive elements: {ref_count}\n---\n{body}"
|
||||
|
||||
def screenshot(self, full_page: bool = False, cwd: str = "") -> str:
    """Save a PNG screenshot; runs ``_do_screenshot`` on the browser thread."""
    return self._submit(self._do_screenshot, full_page, cwd)

def _do_screenshot(self, full_page: bool = False, cwd: str = "") -> str:
    """Write a screenshot under the screenshot directory and return its path.

    The file is named ``screenshot_<8 hex chars>.png``.
    """
    page = self._page
    target_dir = self._get_screenshot_dir(cwd)
    filepath = os.path.join(target_dir, f"screenshot_{uuid.uuid4().hex[:8]}.png")
    page.screenshot(path=filepath, full_page=full_page)
    logger.info(f"[Browser] Screenshot saved: {filepath}")
    return filepath
|
||||
|
||||
def click(self, ref: Optional[int] = None, selector: Optional[str] = None,
          timeout: int = 5000) -> Dict[str, Any]:
    """Click an element by snapshot ref or CSS selector (runs on the browser thread)."""
    return self._submit(self._do_click, ref, selector, timeout)

def _do_click(self, ref, selector, timeout) -> Dict[str, Any]:
    """Click the element identified by *ref* (preferred) or *selector*.

    *ref* is looked up in ``window.__cowRefMap``. It is handed to the page
    as an evaluate argument instead of being formatted into the script, so
    an unexpected (e.g. string) value cannot alter the JavaScript that runs.

    Returns a result dict; failures are reported as ``{"error": ...}``
    rather than raised.
    """
    page = self._page
    try:
        if ref is not None:
            result = page.evaluate(
                """
                (ref) => {
                    const el = window.__cowRefMap && window.__cowRefMap[ref];
                    if (!el) return { error: "ref " + ref + " not found. Run snapshot first." };
                    el.click();
                    return { clicked: true, tag: el.tagName.toLowerCase() };
                }
                """,
                ref,
            )
            if result.get("error"):
                return result
            # Short pause after a successful click before returning.
            page.wait_for_timeout(500)
            return result
        elif selector:
            page.click(selector, timeout=timeout)
            return {"clicked": True, "selector": selector}
        else:
            return {"error": "Provide either ref (from snapshot) or selector"}
    except Exception as e:
        return {"error": f"Click failed: {e}"}
|
||||
|
||||
def fill(self, text: str, ref: Optional[int] = None,
         selector: Optional[str] = None, timeout: int = 5000) -> Dict[str, Any]:
    """Type *text* into an element by snapshot ref or CSS selector (runs on the browser thread)."""
    return self._submit(self._do_fill, text, ref, selector, timeout)

def _do_fill(self, text, ref, selector, timeout) -> Dict[str, Any]:
    """Fill the element identified by *ref* (preferred) or *selector*.

    With *ref*, the element from ``window.__cowRefMap`` is focused, its
    value cleared, and *text* is typed through the keyboard. *ref* is passed
    as an evaluate argument instead of being formatted into the script, so
    an unexpected value cannot alter the JavaScript that runs.

    Returns a result dict; failures are reported as ``{"error": ...}``
    rather than raised.
    """
    page = self._page
    try:
        if ref is not None:
            result = page.evaluate(
                """
                (ref) => {
                    const el = window.__cowRefMap && window.__cowRefMap[ref];
                    if (!el) return { error: "ref " + ref + " not found. Run snapshot first." };
                    el.focus();
                    el.value = "";
                    return { tag: el.tagName.toLowerCase(), name: el.name || "" };
                }
                """,
                ref,
            )
            if result.get("error"):
                return result
            # The element was focused above, so keystrokes go to it.
            page.keyboard.type(text)
            return {"filled": True, "ref": ref, "text": text}
        elif selector:
            page.fill(selector, text, timeout=timeout)
            return {"filled": True, "selector": selector, "text": text}
        else:
            return {"error": "Provide either ref (from snapshot) or selector"}
    except Exception as e:
        return {"error": f"Fill failed: {e}"}
|
||||
|
||||
def select(self, value: str, ref: Optional[int] = None,
           selector: Optional[str] = None, timeout: int = 5000) -> Dict[str, Any]:
    """Choose an option of a <select> by snapshot ref or CSS selector (runs on the browser thread)."""
    return self._submit(self._do_select, value, ref, selector, timeout)

def _do_select(self, value, ref, selector, timeout) -> Dict[str, Any]:
    """Select *value* in the <select> identified by *ref* or *selector*.

    With *ref*, the element's value is set directly and a bubbling
    ``change`` event is dispatched. Both *ref* and *value* are passed to the
    page as evaluate arguments: formatting them into the script relied on
    Python ``repr()`` producing a valid JavaScript literal, which is not
    guaranteed, and let unexpected input alter the script.

    Returns a result dict; failures are reported as ``{"error": ...}``
    rather than raised.
    """
    page = self._page
    try:
        if ref is not None:
            result = page.evaluate(
                """
                ({ ref, value }) => {
                    const el = window.__cowRefMap && window.__cowRefMap[ref];
                    if (!el || el.tagName.toLowerCase() !== "select")
                        return { error: "ref " + ref + " is not a <select> element" };
                    el.value = value;
                    el.dispatchEvent(new Event("change", { bubbles: true }));
                    return { selected: true, value: el.value };
                }
                """,
                {"ref": ref, "value": value},
            )
            return result
        elif selector:
            page.select_option(selector, value, timeout=timeout)
            return {"selected": True, "selector": selector, "value": value}
        else:
            return {"error": "Provide either ref (from snapshot) or selector"}
    except Exception as e:
        return {"error": f"Select failed: {e}"}
|
||||
|
||||
def scroll(self, direction: str = "down", amount: int = 500) -> Dict[str, Any]:
    """Scroll the page; runs ``_do_scroll`` on the browser thread."""
    return self._submit(self._do_scroll, direction, amount)

def _do_scroll(self, direction, amount) -> Dict[str, Any]:
    """Scroll with the mouse wheel and report the resulting scroll position.

    Unknown directions scroll down. Returns the direction, the amount and
    the page's scroll metrics, or ``{"error": ...}`` on failure.
    """
    page = self._page
    if direction == "up":
        dx, dy = 0, -amount
    elif direction == "right":
        dx, dy = amount, 0
    elif direction == "left":
        dx, dy = -amount, 0
    else:
        # "down" and anything unrecognised
        dx, dy = 0, amount
    try:
        page.mouse.wheel(dx, dy)
        page.wait_for_timeout(300)
        metrics = page.evaluate("""
            () => ({
                scrollX: window.scrollX,
                scrollY: window.scrollY,
                scrollHeight: document.documentElement.scrollHeight,
                clientHeight: document.documentElement.clientHeight
            })
        """)
        report = {"scrolled": direction, "amount": amount}
        report.update(metrics)
        return report
    except Exception as e:
        return {"error": f"Scroll failed: {e}"}
|
||||
|
||||
def wait(self, selector: Optional[str] = None, timeout: int = 5000,
         state: str = "visible") -> Dict[str, Any]:
    """Wait for a selector or for a fixed time; runs ``_do_wait`` on the browser thread."""
    return self._submit(self._do_wait, selector, timeout, state)

def _do_wait(self, selector, timeout, state) -> Dict[str, Any]:
    """Wait for *selector* to reach *state*, or just pause when no selector is given.

    Without a selector, *timeout* (milliseconds) is the pause length.
    Failures are reported as ``{"error": ...}``.
    """
    page = self._page
    try:
        if not selector:
            page.wait_for_timeout(timeout)
            return {"waited": True, "timeout_ms": timeout}
        page.wait_for_selector(selector, timeout=timeout, state=state)
        return {"waited": True, "selector": selector, "state": state}
    except Exception as e:
        return {"error": f"Wait failed: {e}"}
|
||||
|
||||
def go_back(self) -> Dict[str, Any]:
    """Go one step back in history; runs ``_do_go_back`` on the browser thread."""
    return self._submit(self._do_go_back)

def _do_go_back(self) -> Dict[str, Any]:
    """Navigate back and return ``{"url", "title"}`` or ``{"error": ...}``."""
    return self._history_step("go_back", "Go back")

def go_forward(self) -> Dict[str, Any]:
    """Go one step forward in history; runs ``_do_go_forward`` on the browser thread."""
    return self._submit(self._do_go_forward)

def _do_go_forward(self) -> Dict[str, Any]:
    """Navigate forward and return ``{"url", "title"}`` or ``{"error": ...}``."""
    return self._history_step("go_forward", "Go forward")

def _history_step(self, method_name: str, label: str) -> Dict[str, Any]:
    """Shared body of back/forward navigation.

    Args:
        method_name: Name of the page method to call (``"go_back"`` or
            ``"go_forward"``).
        label: Prefix of the error message (``"<label> failed: ..."``).

    Title and URL are read best-effort after the step and fall back to an
    empty string when they cannot be read.
    """
    page = self._page
    try:
        getattr(page, method_name)(wait_until="domcontentloaded", timeout=10000)
        try:
            title = page.title()
        except Exception:
            title = ""
        try:
            url = page.url
        except Exception:
            url = ""
        return {"url": url, "title": title}
    except Exception as e:
        return {"error": f"{label} failed: {e}"}
|
||||
|
||||
def get_text(self, selector: str) -> Dict[str, Any]:
    """Read the text content of *selector*; runs ``_do_get_text`` on the browser thread."""
    return self._submit(self._do_get_text, selector)

def _do_get_text(self, selector) -> Dict[str, Any]:
    """Return ``{"text": ...}`` for the first match of *selector*.

    Waits up to 5 seconds for the element. Missing text content becomes an
    empty string; failures are reported as ``{"error": ...}``.
    """
    page = self._page
    try:
        content = page.text_content(selector, timeout=5000)
        if not content:
            content = ""
        return {"text": content}
    except Exception as e:
        return {"error": f"Get text failed: {e}"}
|
||||
|
||||
def evaluate(self, script: str) -> Dict[str, Any]:
    """Run JavaScript in the page; runs ``_do_evaluate`` on the browser thread."""
    return self._submit(self._do_evaluate, script)

def _do_evaluate(self, script) -> Dict[str, Any]:
    """Evaluate *script* in the page.

    Returns ``{"result": <value>}`` on success, or ``{"error": ...}`` when
    evaluation raises. *script* is passed to the page unchanged.
    """
    page = self._page
    try:
        outcome = {"result": page.evaluate(script)}
    except Exception as e:
        outcome = {"error": f"Evaluate failed: {e}"}
    return outcome
|
||||
|
||||
def press(self, key: str) -> Dict[str, Any]:
    """Press a keyboard key; runs ``_do_press`` on the browser thread."""
    return self._submit(self._do_press, key)

def _do_press(self, key) -> Dict[str, Any]:
    """Press *key*, pause 300 ms, and return ``{"pressed": key}``.

    Failures are reported as ``{"error": ...}``.
    """
    page = self._page
    try:
        keyboard = page.keyboard
        keyboard.press(key)
        page.wait_for_timeout(300)
    except Exception as e:
        return {"error": f"Press failed: {e}"}
    return {"pressed": key}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _get_screenshot_dir(self, cwd: str = "") -> str:
    """Return the directory screenshots are written to, creating it if needed.

    The first resolved directory (``<cwd or current dir>/tmp``) is cached on
    the instance and reused for later calls while it still exists, whatever
    *cwd* those calls pass.
    """
    cached = self._screenshot_dir
    if cached and os.path.isdir(cached):
        return cached
    target = os.path.join(cwd or os.getcwd(), "tmp")
    os.makedirs(target, exist_ok=True)
    self._screenshot_dir = target
    return target
|
||||
303
agent/tools/browser/browser_tool.py
Normal file
303
agent/tools/browser/browser_tool.py
Normal file
@@ -0,0 +1,303 @@
|
||||
"""
|
||||
Browser tool - Control a Chromium browser for web navigation and interaction.
|
||||
|
||||
Uses Playwright under the hood. Browser instance is lazily started on first
|
||||
use, reused across tool calls within the same session, and cleaned up via
|
||||
close().
|
||||
|
||||
Launch modes (configured under `tools.browser` in config.json):
|
||||
- persistent (default): Chromium runs with a persistent user_data_dir
|
||||
(default `~/.cow/browser_profile`), so cookies and login state survive
|
||||
across runs. The user only needs to log in once.
|
||||
- cdp: When `cdp_endpoint` is set, attach to an externally launched Chrome
|
||||
via the Chrome DevTools Protocol. Lets the agent reuse the user's real
|
||||
browser (with all logins / extensions / true fingerprints).
|
||||
- fresh: Set `persistent` to false to fall back to a clean context every run.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
from typing import Dict, Any, Optional
|
||||
|
||||
from agent.tools.base_tool import BaseTool, ToolResult
|
||||
from agent.tools.browser.browser_service import BrowserService
|
||||
from common.log import logger
|
||||
|
||||
|
||||
class BrowserTool(BaseTool):
|
||||
"""Single tool exposing all browser actions via an 'action' parameter."""
|
||||
|
||||
name: str = "browser"
|
||||
description: str = (
|
||||
"Control a browser to navigate web pages, interact with elements, and extract content. "
|
||||
"Actions: navigate, snapshot, click, fill, select, scroll, screenshot, wait, back, forward, "
|
||||
"get_text, press, evaluate.\n\n"
|
||||
"Workflow: navigate (auto-includes snapshot with element refs) → click/fill/select by ref → snapshot to verify.\n\n"
|
||||
"Use snapshot as the primary way to read pages. Use screenshot + send to show key results to the user. "
|
||||
"For login/CAPTCHA/authorization etc., screenshot and ask the user for help. "
|
||||
"Login state is persisted across sessions (cookies / localStorage are kept in a "
|
||||
"user profile directory), so once the user logs in to a site, the agent can keep "
|
||||
"using it without logging in again."
|
||||
)
|
||||
|
||||
params: dict = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"action": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"The browser action to perform. One of: "
|
||||
"navigate, snapshot, click, fill, select, scroll, "
|
||||
"screenshot, wait, back, forward, get_text, press, evaluate"
|
||||
),
|
||||
"enum": [
|
||||
"navigate", "snapshot", "click", "fill", "select", "scroll",
|
||||
"screenshot", "wait", "back", "forward", "get_text", "press",
|
||||
"evaluate"
|
||||
]
|
||||
},
|
||||
"url": {
|
||||
"type": "string",
|
||||
"description": "URL to navigate to (for 'navigate' action)"
|
||||
},
|
||||
"ref": {
|
||||
"type": "integer",
|
||||
"description": "Element ref number from snapshot (for click/fill/select)"
|
||||
},
|
||||
"selector": {
|
||||
"type": "string",
|
||||
"description": "CSS selector as fallback when ref is unavailable (for click/fill/select/wait/get_text)"
|
||||
},
|
||||
"text": {
|
||||
"type": "string",
|
||||
"description": "Text to type (for 'fill' action)"
|
||||
},
|
||||
"value": {
|
||||
"type": "string",
|
||||
"description": "Option value (for 'select' action)"
|
||||
},
|
||||
"key": {
|
||||
"type": "string",
|
||||
"description": "Key to press, e.g. Enter, Tab, Escape (for 'press' action)"
|
||||
},
|
||||
"direction": {
|
||||
"type": "string",
|
||||
"description": "Scroll direction: up, down, left, right (for 'scroll' action, default: down)"
|
||||
},
|
||||
"script": {
|
||||
"type": "string",
|
||||
"description": "JavaScript code to execute (for 'evaluate' action)"
|
||||
},
|
||||
"full_page": {
|
||||
"type": "boolean",
|
||||
"description": "Capture full page screenshot (for 'screenshot' action, default: false)"
|
||||
},
|
||||
"timeout": {
|
||||
"type": "integer",
|
||||
"description": "Timeout in milliseconds (optional, default varies by action)"
|
||||
}
|
||||
},
|
||||
"required": ["action"]
|
||||
}
|
||||
|
||||
_shared_service: Optional[BrowserService] = None
|
||||
|
||||
def __init__(self, config: dict = None):
    """Remember the tool config and working directory; no browser is started.

    ``cwd`` comes from the config when present, otherwise from the current
    working directory.
    """
    settings = config or {}
    self.config = settings
    self.cwd = settings.get("cwd", os.getcwd())
    # Resolved on first use by _get_service().
    self._service: Optional[BrowserService] = None
|
||||
|
||||
def _get_service(self) -> BrowserService:
    """Return this tool's browser service, creating or adopting one if needed.

    Lookup order: the instance's own service, then the class-level
    ``_shared_service``, and only then a newly built ``BrowserService``,
    which is also published as the shared one.
    """
    if self._service is None:
        shared = BrowserTool._shared_service
        if shared is not None:
            # Reuse shared service across tool copies within the same session
            self._service = shared
        else:
            self._service = BrowserService(self.config)
            BrowserTool._shared_service = self._service
    return self._service
|
||||
|
||||
def execute(self, args: Dict[str, Any]) -> ToolResult:
    """Dispatch a browser action to its handler.

    :param args: Tool arguments; "action" selects the handler and the whole
                 dict is passed on to it.
    :return: The handler's result, or a failure result when the action is
             missing, unknown, or the handler raises.
    """
    raw_action = args.get("action")
    # An explicit null or non-string action must yield a failure result
    # rather than an AttributeError raised outside the try block below.
    action = raw_action.strip().lower() if isinstance(raw_action, str) else ""
    if not action:
        return ToolResult.fail("Error: 'action' parameter is required")

    handler = self._ACTION_MAP.get(action)
    if not handler:
        valid = ", ".join(sorted(self._ACTION_MAP))
        return ToolResult.fail(f"Unknown action '{action}'. Valid actions: {valid}")

    try:
        # _ACTION_MAP stores plain functions, so self is passed explicitly.
        return handler(self, args)
    except Exception as e:
        logger.error(f"[Browser] Action '{action}' error: {e}")
        return ToolResult.fail(f"Browser error ({action}): {e}")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Action handlers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _do_navigate(self, args: Dict[str, Any]) -> ToolResult:
    """Open a URL and return the navigation summary followed by a page snapshot."""
    target = args.get("url", "").strip()
    if not target:
        return ToolResult.fail("Error: 'url' is required for navigate action")

    # Bare hosts get an https:// prefix; anything containing a scheme
    # separator, or starting with about:/data:, is passed through untouched.
    has_scheme = "://" in target or target.startswith(("about:", "data:"))
    if not has_scheme:
        target = "https://" + target

    wait_ms = args.get("timeout", 30000)
    browser = self._get_service()
    outcome = browser.navigate(target, timeout=wait_ms)
    if "error" in outcome:
        return ToolResult.fail(outcome["error"])

    # Snapshot immediately so the caller gets the page content in the same call.
    page_snapshot = browser.snapshot()
    summary = f"Navigated to: {outcome['url']}\nTitle: {outcome['title']}\nStatus: {outcome['status']}\n\n"
    return ToolResult.success(summary + f"--- Page Snapshot ---\n{page_snapshot}")
|
||||
|
||||
def _do_snapshot(self, args: Dict[str, Any]) -> ToolResult:
    """Return the snapshot text, passing along the optional "selector" argument."""
    scope = args.get("selector")
    browser = self._get_service()
    return ToolResult.success(browser.snapshot(selector=scope))
|
||||
|
||||
def _do_click(self, args: Dict[str, Any]) -> ToolResult:
    """Click the element given by "ref" or "selector" (timeout defaults to 5000)."""
    click_kwargs = {
        "ref": args.get("ref"),
        "selector": args.get("selector"),
        "timeout": args.get("timeout", 5000),
    }
    outcome = self._get_service().click(**click_kwargs)
    if "error" in outcome:
        return ToolResult.fail(outcome["error"])
    return ToolResult.success("Clicked successfully. Use 'snapshot' to see updated page.")
|
||||
|
||||
def _do_fill(self, args: Dict[str, Any]) -> ToolResult:
    """Fill "text" into the element given by "ref" or "selector"."""
    content = args.get("text", "")
    target_ref = args.get("ref")
    css = args.get("selector")
    wait_ms = args.get("timeout", 5000)

    # Only a falsy value other than the empty string is rejected; "" itself
    # is forwarded to the service.
    if not (content or content == ""):
        return ToolResult.fail("Error: 'text' is required for fill action")

    outcome = self._get_service().fill(content, ref=target_ref, selector=css, timeout=wait_ms)
    if "error" in outcome:
        return ToolResult.fail(outcome["error"])
    return ToolResult.success("Filled text into element. Use 'snapshot' to verify.")
|
||||
|
||||
def _do_select(self, args: Dict[str, Any]) -> ToolResult:
    """Select option "value" in the element given by "ref" or "selector"."""
    option = args.get("value", "")
    target_ref = args.get("ref")
    css = args.get("selector")
    wait_ms = args.get("timeout", 5000)

    if not option:
        return ToolResult.fail("Error: 'value' is required for select action")

    outcome = self._get_service().select(option, ref=target_ref, selector=css, timeout=wait_ms)
    if "error" in outcome:
        return ToolResult.fail(outcome["error"])
    return ToolResult.success(f"Selected option '{option}'.")
|
||||
|
||||
def _do_scroll(self, args: Dict[str, Any]) -> ToolResult:
    """Scroll the page and report the resulting scroll position."""
    heading = args.get("direction", "down")
    # An explicit "amount" wins; otherwise the "timeout" field is reused,
    # falling back to 500 when neither is supplied.
    distance = args["amount"] if "amount" in args else args.get("timeout", 500)

    outcome = self._get_service().scroll(direction=heading, amount=distance)
    if "error" in outcome:
        return ToolResult.fail(outcome["error"])

    current_y = outcome.get('scrollY', '?')
    total_height = outcome.get('scrollHeight', '?')
    pos = f"scrollY={current_y}/{total_height}"
    return ToolResult.success(f"Scrolled {heading}. Position: {pos}")
|
||||
|
||||
def _do_screenshot(self, args: Dict[str, Any]) -> ToolResult:
    """Take a screenshot (full page when "full_page" is set) and report its path."""
    whole_page = args.get("full_page", False)
    browser = self._get_service()
    saved_to = browser.screenshot(full_page=whole_page, cwd=self.cwd)
    return ToolResult.success(f"Screenshot saved to: {saved_to}")
|
||||
|
||||
def _do_wait(self, args: Dict[str, Any]) -> ToolResult:
    """Call the service's wait with "selector" and "timeout" (default 5000)."""
    css = args.get("selector")
    wait_ms = args.get("timeout", 5000)
    outcome = self._get_service().wait(selector=css, timeout=wait_ms)
    if "error" in outcome:
        return ToolResult.fail(outcome["error"])
    return ToolResult.success("Wait completed.")
|
||||
|
||||
def _do_back(self, args: Dict[str, Any]) -> ToolResult:
    """Go one step back in history and report the resulting URL."""
    outcome = self._get_service().go_back()
    failed = "error" in outcome
    if failed:
        return ToolResult.fail(outcome["error"])
    landed_on = outcome['url']
    return ToolResult.success(f"Navigated back to: {landed_on}")
|
||||
|
||||
def _do_forward(self, args: Dict[str, Any]) -> ToolResult:
    """Go one step forward in history and report the resulting URL."""
    outcome = self._get_service().go_forward()
    failed = "error" in outcome
    if failed:
        return ToolResult.fail(outcome["error"])
    landed_on = outcome['url']
    return ToolResult.success(f"Navigated forward to: {landed_on}")
|
||||
|
||||
def _do_get_text(self, args: Dict[str, Any]) -> ToolResult:
    """Return the text the service reports for the required "selector"."""
    css = args.get("selector", "").strip()
    if not css:
        return ToolResult.fail("Error: 'selector' is required for get_text action")

    outcome = self._get_service().get_text(css)
    if "error" in outcome:
        return ToolResult.fail(outcome["error"])
    return ToolResult.success(outcome["text"])
|
||||
|
||||
def _do_press(self, args: Dict[str, Any]) -> ToolResult:
    """Press the required "key" (e.g. Enter, Tab, Escape)."""
    key_name = args.get("key", "").strip()
    if not key_name:
        return ToolResult.fail("Error: 'key' is required for press action")

    outcome = self._get_service().press(key_name)
    if "error" in outcome:
        return ToolResult.fail(outcome["error"])
    return ToolResult.success(f"Pressed key: {key_name}")
|
||||
|
||||
def _do_evaluate(self, args: Dict[str, Any]) -> ToolResult:
    """Run the required "script" through the service and return its result as text."""
    source = args.get("script", "").strip()
    if not source:
        return ToolResult.fail("Error: 'script' is required for evaluate action")

    outcome = self._get_service().evaluate(source)
    if "error" in outcome:
        return ToolResult.fail(outcome["error"])

    returned = outcome.get("result")
    if returned is None:
        rendered = "(no return value)"
    elif isinstance(returned, (dict, list)):
        # Containers are rendered as indented JSON, keeping non-ASCII characters readable.
        rendered = json.dumps(returned, ensure_ascii=False, indent=2)
    else:
        rendered = str(returned)
    return ToolResult.success(rendered)
|
||||
|
||||
# Action dispatch table: maps each action name looked up by execute()
# (after lower-casing) to its handler. The values are the plain functions
# defined above, which is why execute() invokes them as handler(self, args).
_ACTION_MAP = {
    "navigate": _do_navigate,
    "snapshot": _do_snapshot,
    "click": _do_click,
    "fill": _do_fill,
    "select": _do_select,
    "scroll": _do_scroll,
    "screenshot": _do_screenshot,
    "wait": _do_wait,
    "back": _do_back,
    "forward": _do_forward,
    "get_text": _do_get_text,
    "press": _do_press,
    "evaluate": _do_evaluate,
}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Lifecycle
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def copy(self):
    """Return a new BrowserTool that reuses this tool's browser service.

    The clone is built from the same config and then receives this tool's
    model, context (None when this tool has no such attribute), cwd and
    service reference, so no second browser has to be launched.
    """
    clone = BrowserTool(self.config)
    clone.model = self.model
    clone.context = getattr(self, "context", None)
    clone.cwd = self.cwd
    clone._service = self._service
    return clone
|
||||
|
||||
def close(self):
    """Release browser resources.

    Closes the service held by this instance and clears both the instance
    reference and the class-level ``_shared_service``.
    """
    # NOTE(review): the nesting of the statements after close() is assumed
    # to be inside the `if` (nothing is reset or logged when this instance
    # holds no service) — confirm against the original file.
    if self._service:
        self._service.close()
        self._service = None
        BrowserTool._shared_service = None
        logger.info("[Browser] BrowserTool closed")
|
||||
3
agent/tools/edit/__init__.py
Normal file
3
agent/tools/edit/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .edit import Edit
|
||||
|
||||
__all__ = ['Edit']
|
||||
185
agent/tools/edit/edit.py
Normal file
185
agent/tools/edit/edit.py
Normal file
@@ -0,0 +1,185 @@
|
||||
"""
|
||||
Edit tool - Precise file editing
|
||||
Edit files through exact text replacement
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Dict, Any
|
||||
|
||||
from agent.tools.base_tool import BaseTool, ToolResult
|
||||
from common.utils import expand_path
|
||||
from agent.tools.utils.diff import (
|
||||
strip_bom,
|
||||
detect_line_ending,
|
||||
normalize_to_lf,
|
||||
restore_line_endings,
|
||||
normalize_for_fuzzy_match,
|
||||
fuzzy_find_text,
|
||||
generate_diff_string
|
||||
)
|
||||
|
||||
|
||||
class Edit(BaseTool):
    """Tool for precise file editing.

    Replaces a single, unique occurrence of ``oldText`` with ``newText``, or
    appends ``newText`` to the end of the file when ``oldText`` is empty or
    whitespace-only.
    """

    name: str = "edit"
    description: str = "Edit a file by replacing exact text, or append to end if oldText is empty. For append: use empty oldText. For replace: oldText must match exactly (including whitespace)."

    params: dict = {
        "type": "object",
        "properties": {
            "path": {
                "type": "string",
                "description": "Path to the file to edit (relative or absolute)"
            },
            "oldText": {
                "type": "string",
                "description": "Text to find and replace. Use empty string to append to end of file. For replacement: must match exactly including whitespace."
            },
            "newText": {
                "type": "string",
                "description": "New text to replace the old text with"
            }
        },
        "required": ["path", "oldText", "newText"]
    }

    def __init__(self, config: dict = None):
        """
        Store the configuration.

        :param config: Optional settings; "cwd" is the base directory for
                       relative paths and "memory_manager" (if any) is
                       notified after edits whose path contains "memory/".
        """
        self.config = config or {}
        self.cwd = self.config.get("cwd", os.getcwd())
        self.memory_manager = self.config.get("memory_manager", None)

    def execute(self, args: Dict[str, Any]) -> ToolResult:
        """
        Execute file edit operation

        :param args: Contains file path ("path"), old text ("oldText") and new text ("newText")
        :return: Success result with message, path, diff and first changed line,
                 or a failure result describing why the edit was not applied
        """
        raw_path = args.get("path")
        # An explicit null or non-string path is reported as missing instead
        # of raising AttributeError before the try block.
        path = raw_path.strip() if isinstance(raw_path, str) else ""
        old_text = args.get("oldText", "")
        new_text = args.get("newText", "")

        if not path:
            return ToolResult.fail("Error: path parameter is required")

        # Resolve path
        absolute_path = self._resolve_path(path)

        # Check if file exists
        if not os.path.exists(absolute_path):
            return ToolResult.fail(f"Error: File not found: {path}")

        # Check if readable/writable
        if not os.access(absolute_path, os.R_OK | os.W_OK):
            return ToolResult.fail(f"Error: File is not readable/writable: {path}")

        try:
            # Read file. newline='' disables universal-newline translation so
            # the file's own line endings reach detect_line_ending() unchanged;
            # in the default mode '\r\n' would already be collapsed to '\n'.
            with open(absolute_path, 'r', encoding='utf-8', newline='') as f:
                raw_content = f.read()

            # Remove BOM (LLM won't include invisible BOM in oldText)
            bom, content = strip_bom(raw_content)

            # Detect original line ending
            original_ending = detect_line_ending(content)

            # Normalize to LF
            normalized_content = normalize_to_lf(content)
            normalized_old_text = normalize_to_lf(old_text)
            normalized_new_text = normalize_to_lf(new_text)

            # Special case: empty (or whitespace-only) oldText means append to end of file
            if not old_text or not old_text.strip():
                # Add newline before newText if file doesn't end with one
                if normalized_content and not normalized_content.endswith('\n'):
                    new_content = normalized_content + '\n' + normalized_new_text
                else:
                    new_content = normalized_content + normalized_new_text
                base_content = normalized_content  # For verification
            else:
                # Normal edit mode: find and replace
                # Use fuzzy matching to find old text (try exact match first, then fuzzy match)
                match_result = fuzzy_find_text(normalized_content, normalized_old_text)

                if not match_result.found:
                    return ToolResult.fail(
                        f"Error: Could not find the exact text in {path}. "
                        "The old text must match exactly including all whitespace and newlines."
                    )

                # Calculate occurrence count (use fuzzy normalized content for consistency)
                fuzzy_content = normalize_for_fuzzy_match(normalized_content)
                fuzzy_old_text = normalize_for_fuzzy_match(normalized_old_text)
                occurrences = fuzzy_content.count(fuzzy_old_text)

                if occurrences > 1:
                    return ToolResult.fail(
                        f"Error: Found {occurrences} occurrences of the text in {path}. "
                        "The text must be unique. Please provide more context to make it unique."
                    )

                # Execute replacement (use matched text position)
                base_content = match_result.content_for_replacement
                match_end = match_result.index + match_result.match_length
                new_content = (
                    base_content[:match_result.index] +
                    normalized_new_text +
                    base_content[match_end:]
                )

            # Verify replacement actually changed content
            if base_content == new_content:
                return ToolResult.fail(
                    f"Error: No changes made to {path}. "
                    "The replacement produced identical content. "
                    "This might indicate an issue with special characters or the text not existing as expected."
                )

            # Restore original line endings
            final_content = bom + restore_line_endings(new_content, original_ending)

            # Write file. newline='' again: the restored endings are written
            # verbatim instead of letting the platform rewrite every '\n'.
            with open(absolute_path, 'w', encoding='utf-8', newline='') as f:
                f.write(final_content)

            # Generate diff
            diff_result = generate_diff_string(base_content, new_content)

            result = {
                "message": f"Successfully replaced text in {path}",
                "path": path,
                "diff": diff_result['diff'],
                "first_changed_line": diff_result['first_changed_line']
            }

            # Notify memory manager if file is in memory directory
            if self.memory_manager and "memory/" in path:
                try:
                    self.memory_manager.mark_dirty()
                except Exception:
                    # Deliberately best-effort: don't fail the edit if memory notification fails
                    pass

            return ToolResult.success(result)

        except UnicodeDecodeError:
            return ToolResult.fail(f"Error: File is not a valid text file (encoding error): {path}")
        except PermissionError:
            return ToolResult.fail(f"Error: Permission denied accessing {path}")
        except Exception as e:
            return ToolResult.fail(f"Error editing file: {str(e)}")

    def _resolve_path(self, path: str) -> str:
        """
        Resolve path to absolute path

        :param path: Relative or absolute path
        :return: Absolute path; relative paths are joined onto self.cwd
        """
        # Passed through expand_path first (expands ~ to the user home directory)
        path = expand_path(path)
        if os.path.isabs(path):
            return path
        return os.path.abspath(os.path.join(self.cwd, path))
|
||||
3
agent/tools/env_config/__init__.py
Normal file
3
agent/tools/env_config/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from agent.tools.env_config.env_config import EnvConfig
|
||||
|
||||
__all__ = ['EnvConfig']
|
||||
286
agent/tools/env_config/env_config.py
Normal file
286
agent/tools/env_config/env_config.py
Normal file
@@ -0,0 +1,286 @@
|
||||
"""
|
||||
Environment Configuration Tool - Manage API keys and environment variables
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
from typing import Dict, Any
|
||||
from pathlib import Path
|
||||
|
||||
from agent.tools.base_tool import BaseTool, ToolResult
|
||||
from common.log import logger
|
||||
from common.utils import expand_path
|
||||
|
||||
|
||||
# API key registry: common environment variable names and their descriptions.
# The descriptions are returned to the caller by EnvConfig's 'get' and 'list'
# actions.
API_KEY_REGISTRY = {
    # AI model services
    "OPENAI_API_KEY": "OpenAI API 密钥 (用于GPT模型、Embedding模型)",
    "GEMINI_API_KEY": "Google Gemini API 密钥",
    "CLAUDE_API_KEY": "Claude API 密钥 (用于Claude模型)",
    "LINKAI_API_KEY": "LinkAI智能体平台 API 密钥,支持多种模型切换",
    # Search services
    "BOCHA_API_KEY": "博查 AI 搜索 API 密钥 ",
}
|
||||
|
||||
class EnvConfig(BaseTool):
    """Tool for managing environment variables (API keys, etc.)

    Entries are persisted as KEY=VALUE lines in ~/.cow/.env and mirrored into
    os.environ of the current process.
    """

    name: str = "env_config"
    description: str = (
        "Manage API keys and skill configurations securely. "
        "Use this tool when user wants to configure API keys (like BOCHA_API_KEY, OPENAI_API_KEY), "
        "view configured keys, or manage skill settings. "
        "Actions: 'set' (add/update key), 'get' (view specific key), 'list' (show all configured keys), 'delete' (remove key). "
        "Values are automatically masked for security. Changes take effect immediately via hot reload."
    )

    params: dict = {
        "type": "object",
        "properties": {
            "action": {
                "type": "string",
                "description": "Action to perform: 'set', 'get', 'list', 'delete'",
                "enum": ["set", "get", "list", "delete"]
            },
            "key": {
                "type": "string",
                "description": (
                    "Environment variable key name. Common keys:\n"
                    "- OPENAI_API_KEY: OpenAI API (GPT models)\n"
                    "- OPENAI_API_BASE: OpenAI API base URL\n"
                    "- CLAUDE_API_KEY: Anthropic Claude API\n"
                    "- GEMINI_API_KEY: Google Gemini API\n"
                    "- LINKAI_API_KEY: LinkAI platform\n"
                    "- BOCHA_API_KEY: Bocha AI search (博查搜索)\n"
                    "Use exact key names (case-sensitive, all uppercase with underscores)"
                )
            },
            "value": {
                "type": "string",
                "description": "Value to set for the environment variable (for 'set' action)"
            }
        },
        "required": ["action"]
    }

    def __init__(self, config: dict = None):
        """
        :param config: Optional settings; "agent_bridge" (if any) is used to
                       refresh skills after a change.
        """
        self.config = config or {}
        # Store env config in ~/.cow directory (outside workspace for security)
        self.env_dir = expand_path("~/.cow")
        self.env_path = os.path.join(self.env_dir, '.env')
        self.agent_bridge = self.config.get("agent_bridge")  # Reference to AgentBridge for hot reload
        # Don't create .env file in __init__ to avoid issues during tool discovery
        # It will be created on first use in execute()

    def _ensure_env_file(self):
        """Ensure the .env file exists"""
        # Create ~/.cow directory if it doesn't exist
        os.makedirs(self.env_dir, exist_ok=True)

        if not os.path.exists(self.env_path):
            # The file stores secrets, so request owner-only permissions at
            # creation time (still subject to the umask and the platform).
            Path(self.env_path).touch(mode=0o600)
            logger.info(f"[EnvConfig] Created .env file at {self.env_path}")

    def _mask_value(self, value: str) -> str:
        """Mask a value: '***' for empty/short values, else first 6 + '***' + last 4 characters"""
        if not value or len(value) <= 10:
            return "***"
        return f"{value[:6]}***{value[-4:]}"

    @staticmethod
    def _entry_error(key: Any, value: Any) -> str:
        """
        Check that key/value can be stored as exactly one KEY=VALUE line.

        :return: An error message, or an empty string when the entry is storable
        """
        if not isinstance(key, str) or not isinstance(value, str):
            return "Error: 'key' and 'value' must be strings."
        # '=' would split the key on re-read, '#' would turn the line into a
        # comment, and surrounding whitespace is stripped by _read_env_file.
        if key != key.strip() or key.startswith('#') or any(ch in key for ch in '=\r\n'):
            return "Error: 'key' must not contain '=', line breaks or surrounding whitespace, and must not start with '#'."
        # A line break in the value would write additional lines into the file.
        if '\n' in value or '\r' in value:
            return "Error: 'value' must not contain line breaks."
        return ""

    def _read_env_file(self) -> Dict[str, str]:
        """Read all key-value pairs from .env file"""
        env_vars = {}
        if os.path.exists(self.env_path):
            with open(self.env_path, 'r', encoding='utf-8') as f:
                for line in f:
                    line = line.strip()
                    # Skip empty lines and comments
                    if not line or line.startswith('#'):
                        continue
                    # Parse KEY=VALUE (split at the first '=')
                    match = re.match(r'^([^=]+)=(.*)$', line)
                    if match:
                        key, value = match.groups()
                        env_vars[key.strip()] = value.strip()
        return env_vars

    def _write_env_file(self, env_vars: Dict[str, str]):
        """Rewrite the .env file with a header and all key-value pairs sorted by key"""
        with open(self.env_path, 'w', encoding='utf-8') as f:
            f.write("# Environment variables for agent skills\n")
            f.write("# Auto-managed by env_config tool\n\n")
            f.writelines(f"{key}={value}\n" for key, value in sorted(env_vars.items()))

    def _reload_env(self):
        """Reload environment variables from .env file into os.environ"""
        env_vars = self._read_env_file()
        for key, value in env_vars.items():
            os.environ[key] = value
        logger.debug(f"[EnvConfig] Reloaded {len(env_vars)} environment variables")

    def _refresh_skills(self):
        """
        Refresh skills after environment variable changes

        :return: True when an agent bridge is configured and the refresh succeeded, else False
        """
        if self.agent_bridge:
            try:
                # Reload .env file
                self._reload_env()

                # Refresh skills in all agent instances
                refreshed = self.agent_bridge.refresh_all_skills()
                logger.info(f"[EnvConfig] Refreshed skills in {refreshed} agent instance(s)")
                return True
            except Exception as e:
                logger.warning(f"[EnvConfig] Failed to refresh skills: {e}")
                return False
        return False

    def execute(self, args: Dict[str, Any]) -> ToolResult:
        """
        Execute environment configuration operation

        :param args: Contains action, key, and value parameters
        :return: Result of the operation; values in the result are always masked
        """
        action = args.get("action")
        key = args.get("key")
        value = args.get("value")

        try:
            # Ensure .env file exists on first use. Done inside the try so a
            # filesystem error becomes a failure result instead of escaping.
            self._ensure_env_file()

            if action == "set":
                if not key or not value:
                    return ToolResult.fail("Error: 'key' and 'value' are required for 'set' action.")

                # Reject entries that would corrupt the line-based .env format
                entry_error = self._entry_error(key, value)
                if entry_error:
                    return ToolResult.fail(entry_error)

                # Read current env vars
                env_vars = self._read_env_file()

                # Update the key
                env_vars[key] = value

                # Write back to file
                self._write_env_file(env_vars)

                # Update current process env
                os.environ[key] = value

                logger.info(f"[EnvConfig] Set {key}={self._mask_value(value)}")

                # Try to refresh skills immediately
                refreshed = self._refresh_skills()

                result = {
                    "message": f"Successfully set {key}",
                    "key": key,
                    "value": self._mask_value(value),
                }

                if refreshed:
                    result["note"] = "✅ Skills refreshed automatically - changes are now active"
                else:
                    result["note"] = "⚠️ Skills not refreshed - restart agent to load new skills"

                return ToolResult.success(result)

            elif action == "get":
                if not key:
                    return ToolResult.fail("Error: 'key' is required for 'get' action.")

                # Check in file first, then in current env
                env_vars = self._read_env_file()
                value = env_vars.get(key) or os.getenv(key)

                # Get description from registry
                description = API_KEY_REGISTRY.get(key, "未知用途的环境变量")

                if value is not None:
                    logger.info(f"[EnvConfig] Got {key}={self._mask_value(value)}")
                    return ToolResult.success({
                        "key": key,
                        "value": self._mask_value(value),
                        "description": description,
                        "exists": True,
                        "note": f"Value is masked for security. In bash, use ${key} directly — it is auto-injected."
                    })
                else:
                    return ToolResult.success({
                        "key": key,
                        "description": description,
                        "exists": False,
                        "message": f"Environment variable '{key}' is not set"
                    })

            elif action == "list":
                env_vars = self._read_env_file()

                # Build detailed variable list with descriptions
                variables_with_info = {}
                for key, value in env_vars.items():
                    variables_with_info[key] = {
                        "value": self._mask_value(value),
                        "description": API_KEY_REGISTRY.get(key, "未知用途的环境变量")
                    }

                logger.info(f"[EnvConfig] Listed {len(env_vars)} environment variables")

                if not env_vars:
                    return ToolResult.success({
                        "message": "No environment variables configured",
                        "variables": {},
                        "note": "常用的 API 密钥可以通过 env_config(action='set', key='KEY_NAME', value='your-key') 来配置"
                    })

                return ToolResult.success({
                    "message": f"Found {len(env_vars)} environment variable(s)",
                    "variables": variables_with_info
                })

            elif action == "delete":
                if not key:
                    return ToolResult.fail("Error: 'key' is required for 'delete' action.")

                # Read current env vars
                env_vars = self._read_env_file()

                if key not in env_vars:
                    return ToolResult.success({
                        "message": f"Environment variable '{key}' was not set",
                        "key": key
                    })

                # Remove the key
                del env_vars[key]

                # Write back to file
                self._write_env_file(env_vars)

                # Remove from current process env
                if key in os.environ:
                    del os.environ[key]

                logger.info(f"[EnvConfig] Deleted {key}")

                # Try to refresh skills immediately
                refreshed = self._refresh_skills()

                result = {
                    "message": f"Successfully deleted {key}",
                    "key": key,
                }

                if refreshed:
                    result["note"] = "✅ Skills refreshed automatically - changes are now active"
                else:
                    result["note"] = "⚠️ Skills not refreshed - restart agent to apply changes"

                return ToolResult.success(result)

            else:
                return ToolResult.fail(f"Error: Unknown action '{action}'. Use 'set', 'get', 'list', or 'delete'.")

        except Exception as e:
            logger.error(f"[EnvConfig] Error: {e}", exc_info=True)
            return ToolResult.fail(f"EnvConfig tool error: {str(e)}")
|
||||
3
agent/tools/ls/__init__.py
Normal file
3
agent/tools/ls/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .ls import Ls
|
||||
|
||||
__all__ = ['Ls']
|
||||
140
agent/tools/ls/ls.py
Normal file
140
agent/tools/ls/ls.py
Normal file
@@ -0,0 +1,140 @@
|
||||
"""
|
||||
Ls tool - List directory contents
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Dict, Any
|
||||
|
||||
from agent.tools.base_tool import BaseTool, ToolResult
|
||||
from agent.tools.utils.truncate import truncate_head, format_size, DEFAULT_MAX_BYTES
|
||||
from common.utils import expand_path
|
||||
|
||||
|
||||
DEFAULT_LIMIT = 500
|
||||
|
||||
|
||||
class Ls(BaseTool):
    """Tool for listing directory contents"""

    name: str = "ls"
    description: str = f"List directory contents. Returns entries sorted alphabetically, with '/' suffix for directories. Includes dotfiles. Output is truncated to {DEFAULT_LIMIT} entries or {DEFAULT_MAX_BYTES // 1024}KB (whichever is hit first)."

    params: dict = {
        "type": "object",
        "properties": {
            "path": {
                "type": "string",
                "description": "Directory to list. IMPORTANT: Relative paths are based on workspace directory. To access directories outside workspace, use absolute paths starting with ~ or /."
            },
            "limit": {
                "type": "integer",
                "description": f"Maximum number of entries to return (default: {DEFAULT_LIMIT})"
            }
        },
        "required": []
    }

    def __init__(self, config: dict = None):
        """
        :param config: Optional settings; "cwd" is the base directory for relative paths.
        """
        self.config = config or {}
        self.cwd = self.config.get("cwd", os.getcwd())

    @staticmethod
    def _normalize_limit(limit: Any) -> int:
        """
        Coerce the requested limit to a positive integer.

        :param limit: Raw "limit" argument (may be None, a string, zero or negative)
        :return: The limit as a positive int, or DEFAULT_LIMIT when it is unusable
        """
        try:
            limit = int(limit)
        except (TypeError, ValueError):
            return DEFAULT_LIMIT
        return limit if limit > 0 else DEFAULT_LIMIT

    def execute(self, args: Dict[str, Any]) -> ToolResult:
        """
        Execute directory listing

        :param args: Listing parameters ("path", default "."; "limit", default DEFAULT_LIMIT)
        :return: Directory contents or error
        """
        raw_path = args.get("path")
        # A missing, null or non-string path means "the workspace directory";
        # calling .strip() on None here would raise outside the try block.
        path = raw_path.strip() if isinstance(raw_path, str) else "."
        # A zero/negative limit would otherwise stop before the first entry and
        # report a non-empty directory as empty.
        limit = self._normalize_limit(args.get("limit", DEFAULT_LIMIT))

        # Resolve path
        absolute_path = self._resolve_path(path)

        # Security check: Prevent accessing sensitive config directory
        env_config_dir = expand_path("~/.cow")
        if os.path.abspath(absolute_path) == os.path.abspath(env_config_dir):
            return ToolResult.fail(
                "Error: Access denied. API keys and credentials must be accessed through the env_config tool only."
            )

        if not os.path.exists(absolute_path):
            # Provide helpful hint if using relative path
            if not os.path.isabs(path) and not path.startswith('~'):
                return ToolResult.fail(
                    f"Error: Path not found: {path}\n"
                    f"Resolved to: {absolute_path}\n"
                    f"Hint: Relative paths are based on workspace ({self.cwd}). For files outside workspace, use absolute paths."
                )
            return ToolResult.fail(f"Error: Path not found: {path}")

        if not os.path.isdir(absolute_path):
            return ToolResult.fail(f"Error: Not a directory: {path}")

        try:
            # Read directory entries
            entries = os.listdir(absolute_path)

            # Sort alphabetically (case-insensitive)
            entries.sort(key=str.lower)

            # Format entries with directory indicators
            results = []
            entry_limit_reached = False

            for entry in entries:
                # Checked before processing, so the flag is only set when at
                # least one more entry exists beyond the limit.
                if len(results) >= limit:
                    entry_limit_reached = True
                    break

                full_path = os.path.join(absolute_path, entry)

                try:
                    if os.path.isdir(full_path):
                        results.append(entry + '/')
                    else:
                        results.append(entry)
                except Exception:
                    # Skip entries we can't stat
                    continue

            if not results:
                return ToolResult.success({"message": "(empty directory)", "entries": []})

            # Format output
            raw_output = '\n'.join(results)
            truncation = truncate_head(raw_output, max_lines=999999)  # Only limit by bytes

            output = truncation.content
            details = {}
            notices = []

            if entry_limit_reached:
                notices.append(f"{limit} entries limit reached. Use limit={limit * 2} for more")
                details["entry_limit_reached"] = limit

            if truncation.truncated:
                notices.append(f"{format_size(DEFAULT_MAX_BYTES)} limit reached")
                details["truncation"] = truncation.to_dict()

            if notices:
                output += f"\n\n[{'. '.join(notices)}]"

            return ToolResult.success({
                "output": output,
                "entry_count": len(results),
                "details": details if details else None
            })

        except PermissionError:
            return ToolResult.fail(f"Error: Permission denied reading directory: {path}")
        except Exception as e:
            return ToolResult.fail(f"Error listing directory: {str(e)}")

    def _resolve_path(self, path: str) -> str:
        """Resolve path to absolute path; relative paths are joined onto self.cwd"""
        # Passed through expand_path first (expands ~ to the user home directory)
        path = expand_path(path)
        if os.path.isabs(path):
            return path
        return os.path.abspath(os.path.join(self.cwd, path))
|
||||
4
agent/tools/mcp/__init__.py
Normal file
4
agent/tools/mcp/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from agent.tools.mcp.mcp_client import McpClient, McpClientRegistry
|
||||
from agent.tools.mcp.mcp_tool import McpTool
|
||||
|
||||
__all__ = ["McpClient", "McpClientRegistry", "McpTool"]
|
||||
528
agent/tools/mcp/mcp_client.py
Normal file
528
agent/tools/mcp/mcp_client.py
Normal file
@@ -0,0 +1,528 @@
|
||||
"""
|
||||
MCP (Model Context Protocol) client module.
|
||||
|
||||
Implements JSON-RPC 2.0 over stdio, SSE and Streamable HTTP transports
|
||||
without any external MCP SDK dependency.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import select
|
||||
import subprocess
|
||||
import threading
|
||||
import urllib.request
|
||||
import urllib.error
|
||||
from typing import Optional
|
||||
|
||||
from common.log import logger
|
||||
|
||||
|
||||
# Aliases accepted for the Streamable HTTP transport type
|
||||
_STREAMABLE_HTTP_ALIASES = {"streamable-http", "streamable_http", "streamablehttp", "http"}
|
||||
|
||||
|
||||
class McpClient:
|
||||
"""Single MCP Server client supporting stdio, SSE and Streamable HTTP transports."""
|
||||
|
||||
def __init__(self, config: dict):
    """
    Store the configuration for one MCP server; nothing is connected until initialize().

    config examples:
        stdio: {"name": "filesystem", "type": "stdio", "command": "npx", "args": [...]}
        SSE: {"name": "my-api", "type": "sse", "url": "http://localhost:8000/sse"}
        streamable-http: {"name": "pubmed", "type": "streamable-http", "url": "https://x/mcp"}

    The "type" value is matched case-insensitively and ignoring surrounding
    whitespace for every transport (e.g. "SSE" and "Stdio" are accepted), and
    all streamable-http aliases collapse to the single key "streamable-http".
    """
    self.config = config
    self.name: str = config.get("name", "unknown")

    # str() keeps a non-string "type" from crashing here; it simply ends up as
    # an unknown transport that initialize() reports with a warning.
    normalized_transport = str(config.get("type", "stdio")).strip().lower()
    self.transport: str = (
        "streamable-http"
        if normalized_transport in _STREAMABLE_HTTP_ALIASES
        else normalized_transport
    )

    # stdio state
    self._proc: Optional[subprocess.Popen] = None

    # SSE state
    self._sse_url: Optional[str] = None
    self._post_url: Optional[str] = None  # endpoint for sending messages (resolved from SSE)

    # Streamable HTTP state
    self._http_url: Optional[str] = None
    self._http_headers: dict = {}  # extra headers from user config (e.g. Authorization)
    self._http_session_id: Optional[str] = None  # Mcp-Session-Id assigned by the server

    # Shared state
    self._next_id = 1  # next JSON-RPC request id
    self._id_lock = threading.Lock()  # guards _next_id
    self._call_lock = threading.Lock()  # serializes request/response exchanges
    self._initialized = False
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Public interface
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def initialize(self) -> bool:
|
||||
"""Connect and perform the MCP handshake. Returns True on success."""
|
||||
try:
|
||||
if self.transport == "stdio":
|
||||
return self._init_stdio()
|
||||
elif self.transport == "sse":
|
||||
return self._init_sse()
|
||||
elif self.transport == "streamable-http":
|
||||
return self._init_streamable_http()
|
||||
else:
|
||||
logger.warning(f"[MCP:{self.name}] Unknown transport type: {self.transport!r}")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.warning(f"[MCP:{self.name}] Initialization failed: {e}")
|
||||
return False
|
||||
|
||||
def list_tools(self) -> list:
|
||||
"""Return the tool list from this server.
|
||||
|
||||
Each item is a dict: {"name": str, "description": str, "inputSchema": dict}
|
||||
"""
|
||||
try:
|
||||
resp = self._send_request("tools/list", {})
|
||||
tools = resp.get("result", {}).get("tools", [])
|
||||
return [
|
||||
{
|
||||
"name": t.get("name", ""),
|
||||
"description": t.get("description", ""),
|
||||
"inputSchema": t.get("inputSchema", {}),
|
||||
}
|
||||
for t in tools
|
||||
]
|
||||
except Exception as e:
|
||||
logger.warning(f"[MCP:{self.name}] list_tools failed: {e}")
|
||||
return []
|
||||
|
||||
def call_tool(self, name: str, arguments: dict) -> str:
|
||||
"""Call a tool and return the result as a string."""
|
||||
try:
|
||||
resp = self._send_request("tools/call", {"name": name, "arguments": arguments})
|
||||
content = resp.get("result", {}).get("content", [])
|
||||
parts = [item.get("text", "") for item in content if item.get("type") == "text"]
|
||||
return "\n".join(parts)
|
||||
except Exception as e:
|
||||
logger.warning(f"[MCP:{self.name}] call_tool({name}) failed: {e}")
|
||||
return f"Error: {e}"
|
||||
|
||||
def shutdown(self):
|
||||
"""Close the connection / terminate the child process."""
|
||||
if self._proc is not None:
|
||||
try:
|
||||
self._proc.stdin.close()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
self._proc.terminate()
|
||||
self._proc.wait(timeout=5)
|
||||
except Exception:
|
||||
try:
|
||||
self._proc.kill()
|
||||
except Exception:
|
||||
pass
|
||||
self._proc = None
|
||||
logger.debug(f"[MCP:{self.name}] stdio process terminated")
|
||||
|
||||
# Best-effort streamable-http session termination
|
||||
if self.transport == "streamable-http" and self._http_session_id and self._http_url:
|
||||
try:
|
||||
req = urllib.request.Request(
|
||||
self._http_url,
|
||||
method="DELETE",
|
||||
headers={"Mcp-Session-Id": self._http_session_id, **self._http_headers},
|
||||
)
|
||||
with urllib.request.urlopen(req, timeout=5):
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
self._http_session_id = None
|
||||
|
||||
self._initialized = False
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# stdio transport
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _init_stdio(self) -> bool:
    """Spawn the configured command and perform the MCP handshake over its stdio.

    Reads "command", "args" and "env" from the config. "env" entries are
    layered on top of the current environment. Returns True when the
    handshake succeeds. If the handshake fails or raises, the child process
    is shut down again so that no orphaned server is left running.
    """
    command = self.config.get("command")
    if not command:
        logger.warning(f"[MCP:{self.name}] stdio config missing 'command'")
        return False

    args = self.config.get("args") or []
    if isinstance(args, str):
        # A bare string would otherwise be split into single characters.
        args = [args]
    extra_env = self.config.get("env", None)
    # Values coming from JSON config may be numbers/bools; Popen needs strings.
    env = (
        {**os.environ, **{str(k): str(v) for k, v in extra_env.items()}}
        if extra_env
        else None
    )

    self._proc = subprocess.Popen(
        [command] + [str(a) for a in args],
        stdin=subprocess.PIPE,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        text=True,
        encoding="utf-8",
        env=env,
    )
    logger.debug(f"[MCP:{self.name}] stdio process started (pid={self._proc.pid})")

    # Drain stderr in the background so the child cannot block on a full pipe.
    threading.Thread(
        target=self._drain_stderr, daemon=True, name=f"mcp-stderr-{self.name}"
    ).start()

    try:
        ok = self._handshake()
    except Exception:
        self.shutdown()
        raise
    if not ok:
        self.shutdown()
    return ok
|
||||
|
||||
def _drain_stderr(self):
|
||||
for line in self._proc.stderr:
|
||||
line = line.strip()
|
||||
if line:
|
||||
logger.debug(f"[MCP:{self.name}] stderr: {line}")
|
||||
|
||||
def _readline_with_timeout(self, timeout: int = 30) -> str:
|
||||
"""Read one line from stdio stdout with a hard timeout."""
|
||||
ready, _, _ = select.select([self._proc.stdout], [], [], timeout)
|
||||
if not ready:
|
||||
raise TimeoutError(f"[MCP:{self.name}] stdio read timed out after {timeout}s")
|
||||
return self._proc.stdout.readline()
|
||||
|
||||
def _stdio_send(self, message: dict) -> dict:
|
||||
"""Send a JSON-RPC message over stdio and read the response."""
|
||||
raw = json.dumps(message) + "\n"
|
||||
self._proc.stdin.write(raw)
|
||||
self._proc.stdin.flush()
|
||||
|
||||
while True:
|
||||
line = self._readline_with_timeout()
|
||||
if not line:
|
||||
raise IOError(f"[MCP:{self.name}] stdio process closed unexpectedly")
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
data = json.loads(line)
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
if "id" not in data:
|
||||
logger.debug(f"[MCP:{self.name}] notification skipped: {data.get('method', '?')}")
|
||||
continue
|
||||
return data
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# SSE transport
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _init_sse(self) -> bool:
|
||||
url = self.config.get("url")
|
||||
if not url:
|
||||
logger.warning(f"[MCP:{self.name}] SSE config missing 'url'")
|
||||
return False
|
||||
|
||||
self._sse_url = url
|
||||
|
||||
# Read the first SSE event to discover the POST endpoint
|
||||
try:
|
||||
self._post_url = self._sse_discover_endpoint()
|
||||
except Exception as e:
|
||||
logger.warning(f"[MCP:{self.name}] SSE endpoint discovery failed: {e}")
|
||||
return False
|
||||
|
||||
return self._handshake()
|
||||
|
||||
def _sse_discover_endpoint(self) -> str:
    """Open the SSE stream and read the endpoint announcement to learn the POST URL.

    The first non-empty ``data:`` line is used. It may be a plain URL or path,
    or a JSON object carrying the location under "uri", "url" or "endpoint".
    Whatever is found is resolved against the SSE URL with ``urljoin``, so
    absolute URLs pass through unchanged and relative paths (including those
    delivered inside JSON) become absolute.

    Raises ValueError when the stream ends without a usable endpoint. Errors
    from opening the stream or parsing a JSON payload propagate to the caller.
    """
    from urllib.parse import urljoin

    req = urllib.request.Request(
        self._sse_url,
        headers={"Accept": "text/event-stream"},
    )
    with urllib.request.urlopen(req, timeout=10) as resp:
        for raw_line in resp:
            line = raw_line.decode("utf-8").rstrip("\n\r")
            if not line.startswith("data:"):
                continue
            data = line[len("data:"):].strip()
            # Some servers send JSON with a "uri"/"url"/"endpoint" field instead of a bare path.
            if data.startswith("{"):
                parsed = json.loads(data)
                data = parsed.get("uri") or parsed.get("url") or parsed.get("endpoint")
            if not data or not isinstance(data, str):
                continue  # nothing usable in this event; keep reading
            return urljoin(self._sse_url, data)
    raise ValueError(f"[MCP:{self.name}] No endpoint event received from SSE stream")
|
||||
|
||||
def _sse_send(self, message: dict) -> dict:
    """POST ``message`` as JSON to the discovered endpoint and decode the reply.

    The HTTP response body is read in full (30s timeout on the request) and
    parsed as JSON; the parsed value is returned.
    """
    payload = json.dumps(message).encode("utf-8")
    request = urllib.request.Request(
        self._post_url,
        data=payload,
        method="POST",
        headers={"Content-Type": "application/json"},
    )
    with urllib.request.urlopen(request, timeout=30) as resp:
        text = resp.read().decode("utf-8")
    return json.loads(text)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Streamable HTTP transport (MCP spec 2025-03-26)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _init_streamable_http(self) -> bool:
|
||||
url = self.config.get("url")
|
||||
if not url:
|
||||
logger.warning(f"[MCP:{self.name}] streamable-http config missing 'url'")
|
||||
return False
|
||||
|
||||
self._http_url = url
|
||||
# Allow user-provided headers (e.g. {"Authorization": "Bearer xxx"})
|
||||
extra_headers = self.config.get("headers") or {}
|
||||
if isinstance(extra_headers, dict):
|
||||
self._http_headers = {str(k): str(v) for k, v in extra_headers.items()}
|
||||
|
||||
return self._handshake()
|
||||
|
||||
def _streamable_http_send(self, message: dict) -> dict:
|
||||
"""POST a JSON-RPC request and return the response (JSON or SSE-wrapped)."""
|
||||
return self._streamable_http_post(message, expect_response=True)
|
||||
|
||||
def _streamable_http_post(self, message: dict, expect_response: bool) -> dict:
    """
    POST a JSON-RPC message over Streamable HTTP.

    Per the spec, the response Content-Type can be either:
      - application/json  -> single JSON-RPC response in body
      - text/event-stream -> SSE stream; we read until we get a matching response

    Returns {} when no response is expected, when the server answers
    202 Accepted, or when the body is empty. An HTTP error status is re-raised
    as IOError carrying the first 200 characters of the server's error body.
    """
    request_headers = {
        "Content-Type": "application/json",
        "Accept": "application/json, text/event-stream",
    }
    if self._http_session_id:
        request_headers["Mcp-Session-Id"] = self._http_session_id
    # User-configured headers are applied last, so they win on conflicts.
    request_headers.update(self._http_headers)

    request = urllib.request.Request(
        self._http_url,
        data=json.dumps(message).encode("utf-8"),
        method="POST",
        headers=request_headers,
    )

    try:
        resp = urllib.request.urlopen(request, timeout=30)
    except urllib.error.HTTPError as e:
        # Surface the server-provided error body for easier debugging
        try:
            detail = e.read().decode("utf-8", errors="ignore")
        except Exception:
            detail = ""
        raise IOError(
            f"[MCP:{self.name}] streamable-http HTTP {e.code}: {detail[:200]}"
        )

    with resp:
        # Remember the first session id the server hands out (if any)
        assigned_session = resp.headers.get("Mcp-Session-Id")
        if assigned_session and not self._http_session_id:
            self._http_session_id = assigned_session

        status = resp.status if hasattr(resp, "status") else resp.getcode()

        # Notifications: server may reply with 202 Accepted and no body
        if status == 202 or not expect_response:
            try:
                resp.read()
            except Exception:
                pass
            return {}

        content_type = (resp.headers.get("Content-Type") or "").lower()
        if "text/event-stream" in content_type:
            return self._read_sse_response(resp, message.get("id"))

        text = resp.read().decode("utf-8")
        return json.loads(text) if text else {}
|
||||
|
||||
def _read_sse_response(self, resp, expected_id) -> dict:
|
||||
"""Read an SSE stream and return the first JSON-RPC response with matching id."""
|
||||
data_buf: list = []
|
||||
for raw_line in resp:
|
||||
line = raw_line.decode("utf-8").rstrip("\n\r")
|
||||
if line == "":
|
||||
# End of an SSE event, attempt to parse accumulated data
|
||||
if data_buf:
|
||||
payload = "\n".join(data_buf)
|
||||
data_buf = []
|
||||
try:
|
||||
msg = json.loads(payload)
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
# Skip notifications / mismatched ids
|
||||
if "id" not in msg:
|
||||
continue
|
||||
if expected_id is None or msg.get("id") == expected_id:
|
||||
return msg
|
||||
continue
|
||||
if line.startswith(":"):
|
||||
continue # SSE comment / keepalive
|
||||
if line.startswith("data:"):
|
||||
data_buf.append(line[len("data:"):].lstrip())
|
||||
# Ignore 'event:' / 'id:' lines; we only care about JSON-RPC payloads
|
||||
|
||||
raise IOError(f"[MCP:{self.name}] streamable-http SSE stream closed before response")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Common JSON-RPC helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _next_request_id(self) -> int:
|
||||
with self._id_lock:
|
||||
rid = self._next_id
|
||||
self._next_id += 1
|
||||
return rid
|
||||
|
||||
def _build_request(self, method: str, params: dict) -> dict:
|
||||
return {
|
||||
"jsonrpc": "2.0",
|
||||
"id": self._next_request_id(),
|
||||
"method": method,
|
||||
"params": params,
|
||||
}
|
||||
|
||||
def _build_notification(self, method: str, params: dict) -> dict:
|
||||
return {"jsonrpc": "2.0", "method": method, "params": params}
|
||||
|
||||
def _send_request(self, method: str, params: dict) -> dict:
|
||||
"""Send a request and return the full response dict."""
|
||||
if not self._initialized and method != "initialize":
|
||||
raise RuntimeError(f"[MCP:{self.name}] Client not initialized")
|
||||
|
||||
message = self._build_request(method, params)
|
||||
|
||||
with self._call_lock:
|
||||
if self.transport == "stdio":
|
||||
return self._stdio_send(message)
|
||||
elif self.transport == "sse":
|
||||
return self._sse_send(message)
|
||||
elif self.transport == "streamable-http":
|
||||
return self._streamable_http_send(message)
|
||||
else:
|
||||
raise ValueError(f"[MCP:{self.name}] Unsupported transport: {self.transport}")
|
||||
|
||||
def _send_notification(self, method: str, params: dict):
|
||||
"""Fire-and-forget notification (no response expected)."""
|
||||
notification = self._build_notification(method, params)
|
||||
raw = json.dumps(notification) + "\n"
|
||||
|
||||
if self.transport == "stdio":
|
||||
self._proc.stdin.write(raw)
|
||||
self._proc.stdin.flush()
|
||||
elif self.transport == "sse":
|
||||
body = raw.encode("utf-8")
|
||||
req = urllib.request.Request(
|
||||
self._post_url,
|
||||
data=body,
|
||||
method="POST",
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
try:
|
||||
with urllib.request.urlopen(req, timeout=10):
|
||||
pass
|
||||
except Exception:
|
||||
pass # notifications are fire-and-forget
|
||||
elif self.transport == "streamable-http":
|
||||
try:
|
||||
self._streamable_http_post(notification, expect_response=False)
|
||||
except Exception:
|
||||
pass # notifications are fire-and-forget
|
||||
|
||||
def _handshake(self) -> bool:
    """Perform the MCP initialize / notifications/initialized handshake.

    Sends "initialize", checks that the reply is a JSON-RPC success (an
    object with "result" and without "error"), then sends the
    "notifications/initialized" notification. ``_initialized`` becomes True
    only after all of that succeeded, so other threads never see a
    half-initialized client as ready.

    Returns True on success; failures are logged as warnings and reported as
    False. An exception raised while sending the notification propagates.
    """
    init_params = {
        "protocolVersion": "2024-11-05",
        "capabilities": {},
        "clientInfo": {"name": "CowAgent", "version": "1.0"},
    }
    # _send_request lets "initialize" through while the client is not
    # initialized, so the flag can stay False until the handshake is done.
    self._initialized = False
    try:
        resp = self._send_request("initialize", init_params)
    except Exception as e:
        logger.warning(f"[MCP:{self.name}] Handshake initialize failed: {e}")
        return False

    if not isinstance(resp, dict):
        logger.warning(
            f"[MCP:{self.name}] Handshake error: unexpected response type {type(resp).__name__}"
        )
        return False
    if "error" in resp:
        logger.warning(f"[MCP:{self.name}] Handshake error: {resp['error']}")
        return False
    if "result" not in resp:
        # e.g. an empty body: the server never actually answered "initialize"
        logger.warning(f"[MCP:{self.name}] Handshake error: response carries no result")
        return False

    self._send_notification("notifications/initialized", {})
    self._initialized = True
    logger.debug(f"[MCP:{self.name}] Handshake complete")
    return True
|
||||
|
||||
|
||||
class McpClientRegistry:
    """Global singleton managing the lifecycle of all MCP Server clients.

    Every call to ``McpClientRegistry()`` returns the same object. Access to
    the name -> client mapping is guarded by an internal lock.
    """

    _instance = None
    _instance_lock = threading.Lock()

    def __new__(cls):
        """Create the singleton on first use and return it on every call."""
        with cls._instance_lock:
            if cls._instance is None:
                obj = super().__new__(cls)
                obj._clients = {}  # server name -> initialized McpClient
                obj._registry_lock = threading.Lock()
                cls._instance = obj
            return cls._instance

    def start_all(self, configs: list) -> None:
        """Initialize McpClient for each config entry; skip failures with a warning.

        Entries that are not dicts are skipped. A client that fails to
        initialize is shut down so that anything it already started (such as a
        spawned child process) is released. When a name is registered again,
        the client it replaces is shut down instead of being dropped silently.
        """
        if not configs:
            return

        for cfg in configs:
            if not isinstance(cfg, dict):
                logger.warning(f"[MCP] Ignoring invalid server config entry: {cfg!r}")
                continue

            name = cfg.get("name", "<unnamed>")
            client = McpClient(cfg)
            ok = client.initialize()
            if not ok:
                try:
                    client.shutdown()
                except Exception:
                    pass  # best effort cleanup of a half-started client
                logger.warning(f"[MCP] Server '{name}' failed to initialize — skipping")
                continue

            with self._registry_lock:
                previous = self._clients.get(name)
                self._clients[name] = client

            if previous is not None and previous is not client:
                try:
                    previous.shutdown()
                except Exception as e:
                    logger.warning(f"[MCP] Error shutting down replaced '{name}': {e}")

            logger.info(f"[MCP] Server '{name}' initialized successfully")

    def get(self, server_name: str) -> Optional["McpClient"]:
        """Return the initialized client for server_name, or None."""
        with self._registry_lock:
            return self._clients.get(server_name)

    def all_clients(self) -> dict:
        """Return a copy of the {name: McpClient} mapping."""
        with self._registry_lock:
            return dict(self._clients)

    def shutdown_all(self) -> None:
        """Shut down all managed clients and empty the registry."""
        # Detach the clients under the lock, then shut them down outside it so
        # a slow shutdown does not block other registry users.
        with self._registry_lock:
            clients = list(self._clients.values())
            self._clients.clear()

        for client in clients:
            try:
                client.shutdown()
            except Exception as e:
                logger.warning(f"[MCP] Error shutting down '{client.name}': {e}")

        logger.info("[MCP] All servers shut down")
|
||||
31
agent/tools/mcp/mcp_tool.py
Normal file
31
agent/tools/mcp/mcp_tool.py
Normal file
@@ -0,0 +1,31 @@
|
||||
from agent.tools.base_tool import BaseTool, ToolResult
|
||||
from common.log import logger
|
||||
|
||||
|
||||
class McpTool(BaseTool):
    """
    Wrap a single MCP tool as a BaseTool.
    One MCP server can expose several tools; each one maps to its own McpTool instance.
    """

    def __init__(self, client, tool_schema: dict, server_name: str):
        """
        :param client: the McpClient instance this tool belongs to
        :param tool_schema: tool description returned by MCP, in the form
            {"name": str, "description": str, "inputSchema": dict}
        :param server_name: server name, used for logging
        """
        self.client = client
        self.server_name = server_name
        self.name = tool_schema["name"]
        self.description = tool_schema.get("description", "")
        self.params = tool_schema.get("inputSchema", {})

    def execute(self, params: dict) -> ToolResult:
        """Invoke the wrapped MCP tool and wrap its text output in a ToolResult.

        McpClient.call_tool does not raise on failure; it returns a string that
        starts with "Error:". Such a result is reported as a failed ToolResult
        so that callers can tell a broken call from real tool output.
        """
        logger.info(f"[McpTool] server={self.server_name} tool={self.name} params={params}")
        try:
            result = self.client.call_tool(self.name, params)
        except Exception as e:
            logger.error(f"[McpTool] server={self.server_name} tool={self.name} error: {e}")
            return ToolResult.fail(str(e))

        if isinstance(result, str) and result.startswith("Error:"):
            logger.error(f"[McpTool] server={self.server_name} tool={self.name} error: {result}")
            return ToolResult.fail(result)
        return ToolResult.success(result)
|
||||
10
agent/tools/memory/__init__.py
Normal file
10
agent/tools/memory/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
||||
"""
|
||||
Memory tools for Agent
|
||||
|
||||
Provides memory_search and memory_get tools
|
||||
"""
|
||||
|
||||
from agent.tools.memory.memory_search import MemorySearchTool
|
||||
from agent.tools.memory.memory_get import MemoryGetTool
|
||||
|
||||
__all__ = ['MemorySearchTool', 'MemoryGetTool']
|
||||
128
agent/tools/memory/memory_get.py
Normal file
128
agent/tools/memory/memory_get.py
Normal file
@@ -0,0 +1,128 @@
|
||||
"""
|
||||
Memory get tool
|
||||
|
||||
Allows agents to read specific sections from memory files
|
||||
"""
|
||||
|
||||
from agent.tools.base_tool import BaseTool
|
||||
|
||||
|
||||
class MemoryGetTool(BaseTool):
    """Tool for reading memory file contents"""

    name: str = "memory_get"
    description: str = (
        "Read specific content from memory files. "
        "Use this to get full context from a memory file or specific line range."
    )
    params: dict = {
        "type": "object",
        "properties": {
            "path": {
                "type": "string",
                "description": "Relative path to the memory file (e.g. 'MEMORY.md', 'memory/2026-01-01.md')"
            },
            "start_line": {
                "type": "integer",
                "description": "Starting line number (optional, default: 1)",
                "default": 1
            },
            "num_lines": {
                "type": "integer",
                "description": "Number of lines to read (optional, reads all if not specified)"
            }
        },
        "required": ["path"]
    }

    def __init__(self, memory_manager):
        """
        Initialize memory get tool

        Args:
            memory_manager: MemoryManager instance
        """
        super().__init__()
        self.memory_manager = memory_manager

        from config import conf
        if conf().get("knowledge", True):
            self.description = (
                "Read specific content from memory or knowledge files. "
                "Use this to get full context from a memory file, knowledge page, or specific line range."
            )
            # Copy before editing so the class-level schema shared by all
            # instances is not mutated.
            self.params = {**self.params}
            self.params["properties"] = {**self.params["properties"]}
            self.params["properties"]["path"] = {
                "type": "string",
                "description": "Relative path to the memory or knowledge file (e.g. 'MEMORY.md', 'memory/2026-01-01.md', 'knowledge/concepts/moe.md')"
            }

    def execute(self, args: dict):
        """
        Execute memory file read

        Args:
            args: Dictionary with path, start_line, num_lines

        Returns:
            ToolResult with file content, or a failed ToolResult when the path
            is missing, escapes the workspace, does not exist, the line
            arguments are not usable integers, or reading fails.
        """
        from agent.tools.base_tool import ToolResult

        path = args.get("path")
        start_line = args.get("start_line", 1)
        num_lines = args.get("num_lines")

        if not path:
            return ToolResult.fail("Error: path parameter is required")

        # Callers may pass null or numeric strings; normalize once up front.
        try:
            start_line = 1 if start_line is None else int(start_line)
            num_lines = None if num_lines is None else int(num_lines)
        except (TypeError, ValueError):
            return ToolResult.fail("Error: start_line and num_lines must be integers")
        if num_lines is not None and num_lines < 0:
            return ToolResult.fail("Error: num_lines must not be negative")

        try:
            workspace_dir = self.memory_manager.config.get_workspace()

            # Auto-prepend memory/ if not present and not absolute path
            # Exceptions: MEMORY.md in root, knowledge/ files at workspace root
            if not path.startswith(('memory/', 'knowledge/', '/')) and path != 'MEMORY.md':
                path = f'memory/{path}'

            file_path = (workspace_dir / path).resolve()
            workspace_resolved = workspace_dir.resolve()

            # relative_to() compares path components, so the containment check
            # does not depend on the platform's separator character.
            try:
                file_path.relative_to(workspace_resolved)
            except ValueError:
                return ToolResult.fail("Error: Access denied: path outside workspace")

            if not file_path.exists():
                return ToolResult.fail(f"Error: File not found: {path}")

            content = file_path.read_text(encoding='utf-8')
            lines = content.split('\n')

            # Handle line range
            if start_line < 1:
                start_line = 1

            start_idx = start_line - 1

            if num_lines:
                selected_lines = lines[start_idx:start_idx + num_lines]
            else:
                # num_lines omitted (or 0): read through to the end of the file
                selected_lines = lines[start_idx:]

            result = '\n'.join(selected_lines)

            # Add metadata
            total_lines = len(lines)
            shown_lines = len(selected_lines)

            if shown_lines:
                range_info = f"Lines: {start_line}-{start_line + shown_lines - 1} (total: {total_lines})"
            else:
                # Avoid a nonsensical range such as "10-9" when start_line is past EOF.
                range_info = f"Lines: none (start_line {start_line} is beyond the end of the file; total: {total_lines})"

            output = [
                f"File: {path}",
                range_info,
                "",
                result
            ]

            return ToolResult.success('\n'.join(output))

        except Exception as e:
            return ToolResult.fail(f"Error reading memory file: {str(e)}")
|
||||
109
agent/tools/memory/memory_search.py
Normal file
109
agent/tools/memory/memory_search.py
Normal file
@@ -0,0 +1,109 @@
|
||||
"""
|
||||
Memory search tool
|
||||
|
||||
Allows agents to search their memory using semantic and keyword search
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, Optional
|
||||
from agent.tools.base_tool import BaseTool
|
||||
|
||||
|
||||
class MemorySearchTool(BaseTool):
    """Tool that lets the agent query its stored memory."""

    name: str = "memory_search"
    description: str = (
        "Search agent's long-term memory using semantic and keyword search. "
        "Use this to recall past conversations, preferences, and knowledge."
    )
    params: dict = {
        "type": "object",
        "properties": {
            "query": {
                "type": "string",
                "description": "Search query (can be natural language question or keywords)"
            },
            "max_results": {
                "type": "integer",
                "description": "Maximum number of results to return (default: 10)",
                "default": 10
            },
            "min_score": {
                "type": "number",
                "description": "Minimum relevance score (0-1, default: 0.1)",
                "default": 0.1
            }
        },
        "required": ["query"]
    }

    def __init__(self, memory_manager, user_id: Optional[str] = None):
        """
        Set up the search tool.

        Args:
            memory_manager: MemoryManager instance used to run searches
            user_id: Optional user ID passed along to scope the search
        """
        super().__init__()
        self.memory_manager = memory_manager
        self.user_id = user_id

        from config import conf
        knowledge_enabled = conf().get("knowledge", True)
        if knowledge_enabled:
            # Advertise the knowledge base in the description when it is enabled.
            self.description = (
                "Search agent's long-term memory and knowledge base using semantic and keyword search. "
                "Use this to recall past conversations, preferences, and knowledge pages."
            )

    def execute(self, args: dict):
        """
        Run a memory search and format the hits as text.

        Args:
            args: Dictionary with query, max_results, min_score

        Returns:
            ToolResult with the formatted results, a success message when
            nothing matched, or a failure when the query is missing or the
            search raises.
        """
        from agent.tools.base_tool import ToolResult
        import asyncio

        query = args.get("query")
        max_results = args.get("max_results", 10)
        min_score = args.get("min_score", 0.1)

        if not query:
            return ToolResult.fail("Error: query parameter is required")

        try:
            search_kwargs = {
                "query": query,
                "user_id": self.user_id,
                "max_results": max_results,
                "min_score": min_score,
                "include_shared": True,
            }
            # The manager's search is a coroutine; drive it to completion here.
            hits = asyncio.run(self.memory_manager.search(**search_kwargs))

            if not hits:
                # An explicit "nothing stored yet" answer keeps the agent from
                # retrying the same search over and over.
                return ToolResult.success(
                    f"No memories found for '{query}'. "
                    "This is normal if no memories have been stored yet. "
                    "You can store new memories by writing to MEMORY.md or memory/YYYY-MM-DD.md files."
                )

            report = [f"Found {len(hits)} relevant memories:\n"]
            for rank, hit in enumerate(hits, 1):
                report.extend((
                    f"\n{rank}. {hit.path} (lines {hit.start_line}-{hit.end_line})",
                    f" Score: {hit.score:.3f}",
                    f" Snippet: {hit.snippet}",
                ))

            return ToolResult.success("\n".join(report))

        except Exception as e:
            return ToolResult.fail(f"Error searching memory: {str(e)}")
|
||||
3
agent/tools/read/__init__.py
Normal file
3
agent/tools/read/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .read import Read
|
||||
|
||||
__all__ = ['Read']
|
||||
548
agent/tools/read/read.py
Normal file
548
agent/tools/read/read.py
Normal file
@@ -0,0 +1,548 @@
|
||||
"""
|
||||
Read tool - Read file contents
|
||||
Supports text files, images (jpg, png, gif, webp), and PDF files
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Dict, Any
|
||||
from pathlib import Path
|
||||
|
||||
from agent.tools.base_tool import BaseTool, ToolResult
|
||||
from agent.tools.utils.truncate import truncate_head, format_size, DEFAULT_MAX_LINES, DEFAULT_MAX_BYTES
|
||||
from common.utils import expand_path
|
||||
|
||||
|
||||
class Read(BaseTool):
|
||||
"""Tool for reading file contents"""
|
||||
|
||||
name: str = "read"
|
||||
description: str = f"Read or inspect file contents. For text/PDF files, returns content (truncated to {DEFAULT_MAX_LINES} lines or {DEFAULT_MAX_BYTES // 1024}KB). For images/videos/audio, returns metadata only (file info, size, type). Use offset/limit for large text files."
|
||||
|
||||
params: dict = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "Path to the file to read. IMPORTANT: Relative paths are based on workspace directory. To access files outside workspace, use absolute paths starting with ~ or /."
|
||||
},
|
||||
"offset": {
|
||||
"type": "integer",
|
||||
"description": "Line number to start reading from (1-indexed, optional). Use negative values to read from end (e.g. -20 for last 20 lines)"
|
||||
},
|
||||
"limit": {
|
||||
"type": "integer",
|
||||
"description": "Maximum number of lines to read (optional)"
|
||||
}
|
||||
},
|
||||
"required": ["path"]
|
||||
}
|
||||
|
||||
def __init__(self, config: dict = None):
    """Store the tool configuration and the extension tables used to classify files.

    ``config`` may provide "cwd"; when absent, the process working directory
    at construction time is used.
    """
    process_cwd = os.getcwd()
    self.config = config or {}
    self.cwd = self.config.get("cwd", process_cwd)

    # Extension groups, one set per file category
    self.image_extensions = {
        '.jpg', '.jpeg', '.png', '.gif',
        '.webp', '.bmp', '.svg', '.ico',
    }
    self.video_extensions = {
        '.mp4', '.avi', '.mov', '.mkv',
        '.flv', '.wmv', '.webm', '.m4v',
    }
    self.audio_extensions = {
        '.mp3', '.wav', '.ogg', '.m4a',
        '.flac', '.aac', '.wma',
    }
    self.binary_extensions = {
        '.exe', '.dll', '.so', '.dylib',
        '.bin', '.dat', '.db', '.sqlite',
    }
    self.archive_extensions = {
        '.zip', '.tar', '.gz', '.rar',
        '.7z', '.bz2', '.xz',
    }
    self.pdf_extensions = {'.pdf'}
    self.office_extensions = {
        '.doc', '.docx',
        '.xls', '.xlsx',
        '.ppt', '.pptx',
    }

    # Text formats that are read directly (with truncation)
    self.text_extensions = {
        # documents and data
        '.txt', '.md', '.markdown', '.rst', '.log', '.csv', '.tsv', '.json', '.xml', '.yaml', '.yml',
        # general-purpose languages
        '.py', '.js', '.ts', '.java', '.c', '.cpp', '.h', '.hpp', '.go', '.rs', '.rb', '.php',
        # web
        '.html', '.css', '.scss', '.sass', '.less', '.vue', '.jsx', '.tsx',
        # shells and scripts
        '.sh', '.bash', '.zsh', '.fish', '.ps1', '.bat', '.cmd',
        # further languages
        '.sql', '.r', '.m', '.swift', '.kt', '.scala', '.clj', '.erl', '.ex',
        # build and configuration
        '.dockerfile', '.makefile', '.cmake', '.gradle', '.properties', '.ini', '.conf', '.cfg',
    }
|
||||
|
||||
def execute(self, args: Dict[str, Any]) -> ToolResult:
    """
    Execute file read operation

    Validates the arguments, resolves the path, refuses to expose the
    credential file, and then dispatches on the file extension: images and
    video/audio/binary/archive files yield metadata only, PDF and Office
    documents go through their text extractors, and everything else is read
    as text.

    :param args: Contains file path ("path", or "location" as an alias) and
                 optional offset/limit parameters
    :return: File content or error message
    """
    # Support 'location' as alias for 'path' (LLM may use it from skill listing)
    path = args.get("path", "") or args.get("location", "")
    path = path.strip() if isinstance(path, str) else ""

    if not path:
        return ToolResult.fail("Error: path parameter is required")

    # Tool arguments come from an LLM and may arrive as strings ("20") or
    # empty strings; normalise them once here so every reader gets ints.
    offset = args.get("offset")
    limit = args.get("limit")
    try:
        offset = None if offset in (None, "") else int(offset)
        limit = None if limit in (None, "") else int(limit)
    except (TypeError, ValueError, OverflowError):
        return ToolResult.fail("Error: offset and limit must be integers")
    if limit is not None and limit <= 0:
        # A non-positive limit would slice from the wrong end and produce
        # a negative "remaining lines" count in the pagination hint.
        return ToolResult.fail("Error: limit must be a positive integer")

    # Resolve path
    absolute_path = self._resolve_path(path)

    # Security check: Prevent reading sensitive config files.
    # Compare canonical paths (symlinks resolved, case-normalised) and fall
    # back to an inode comparison so that symlinks and hard links pointing
    # at the credential file cannot bypass the check.
    env_config_path = expand_path("~/.cow/.env")
    is_env_config = (
        os.path.normcase(os.path.realpath(absolute_path))
        == os.path.normcase(os.path.realpath(env_config_path))
    )
    if not is_env_config:
        try:
            is_env_config = os.path.samefile(absolute_path, env_config_path)
        except OSError:
            # One of the two paths does not exist; they cannot be the same file.
            is_env_config = False
    if is_env_config:
        return ToolResult.fail(
            "Error: Access denied. API keys and credentials must be accessed through the env_config tool only."
        )

    # Check if file exists
    if not os.path.exists(absolute_path):
        # Provide helpful hint if using relative path
        if not os.path.isabs(path) and not path.startswith('~'):
            return ToolResult.fail(
                f"Error: File not found: {path}\n"
                f"Resolved to: {absolute_path}\n"
                f"Hint: Relative paths are based on workspace ({self.cwd}). For files outside workspace, use absolute paths."
            )
        return ToolResult.fail(f"Error: File not found: {path}")

    # Directories exist and may be "readable", but cannot be opened as files
    if os.path.isdir(absolute_path):
        return ToolResult.fail(f"Error: Path is a directory, not a file: {path}")

    # Check if readable
    if not os.access(absolute_path, os.R_OK):
        return ToolResult.fail(f"Error: File is not readable: {path}")

    # Check file type
    file_ext = Path(absolute_path).suffix.lower()
    file_size = os.path.getsize(absolute_path)

    # Check if image - return metadata for sending
    if file_ext in self.image_extensions:
        return self._read_image(absolute_path, file_ext)

    # Check if video/audio/binary/archive - return metadata only
    if file_ext in self.video_extensions:
        return self._return_file_metadata(absolute_path, "video", file_size)
    if file_ext in self.audio_extensions:
        return self._return_file_metadata(absolute_path, "audio", file_size)
    if file_ext in self.binary_extensions or file_ext in self.archive_extensions:
        return self._return_file_metadata(absolute_path, "binary", file_size)

    # Check if PDF
    if file_ext in self.pdf_extensions:
        return self._read_pdf(absolute_path, path, offset, limit)

    # Check if Office document (.docx, .xlsx, .pptx, etc.)
    if file_ext in self.office_extensions:
        return self._read_office(absolute_path, path, file_ext, offset, limit)

    # Read text file (with truncation for large files)
    return self._read_text(absolute_path, path, offset, limit)
|
||||
|
||||
def _resolve_path(self, path: str) -> str:
    """
    Turn a user-supplied path into an absolute one.

    The path is first passed through ``expand_path`` (expands ~ to the user
    home directory); if the result is still relative it is anchored at the
    workspace directory ``self.cwd``.

    :param path: Relative or absolute path
    :return: Absolute path
    """
    expanded = expand_path(path)
    if not os.path.isabs(expanded):
        # Relative paths are interpreted against the workspace
        expanded = os.path.abspath(os.path.join(self.cwd, expanded))
    return expanded
|
||||
|
||||
def _return_file_metadata(self, absolute_path: str, file_type: str, file_size: int) -> ToolResult:
    """
    Describe a file whose content is not read (video, audio, binary, etc.)

    :param absolute_path: Absolute path to the file
    :param file_type: Type of file (video, audio, binary, etc.)
    :param file_size: File size in bytes
    :return: Successful result carrying name, MIME type, size and a hint message
    """
    target = Path(absolute_path)
    file_name = target.name

    # Extension -> MIME type; unlisted extensions are reported as a generic byte stream
    known_mime_types = {
        # Video
        '.mp4': 'video/mp4',
        '.avi': 'video/x-msvideo',
        '.mov': 'video/quicktime',
        '.mkv': 'video/x-matroska',
        '.webm': 'video/webm',
        # Audio
        '.mp3': 'audio/mpeg',
        '.wav': 'audio/wav',
        '.ogg': 'audio/ogg',
        '.m4a': 'audio/mp4',
        '.flac': 'audio/flac',
        # Binary
        '.zip': 'application/zip',
        '.tar': 'application/x-tar',
        '.gz': 'application/gzip',
        '.rar': 'application/x-rar-compressed',
    }
    mime_type = known_mime_types.get(target.suffix.lower(), 'application/octet-stream')

    return ToolResult.success({
        "type": f"{file_type}_metadata",
        "file_type": file_type,
        "path": absolute_path,
        "file_name": file_name,
        "mime_type": mime_type,
        "size": file_size,
        "size_formatted": format_size(file_size),
        "message": f"{file_type.capitalize()} 文件: {file_name} ({format_size(file_size)})\n提示: 如果需要发送此文件,请使用 send 工具。"
    })
|
||||
|
||||
def _read_image(self, absolute_path: str, file_ext: str) -> ToolResult:
    """
    Read image file - always return metadata only (images should be sent, not read into context)

    :param absolute_path: Absolute path to the image file
    :param file_ext: File extension (lower-case, including the leading dot)
    :return: Result containing image metadata for sending, or a failure if
             the file cannot be inspected
    """
    try:
        # Get file size
        file_size = os.path.getsize(absolute_path)

        # Determine MIME type. Every extension accepted as an image by the
        # caller is listed, so .bmp/.svg/.ico are no longer mislabelled as JPEG.
        mime_type_map = {
            '.jpg': 'image/jpeg',
            '.jpeg': 'image/jpeg',
            '.png': 'image/png',
            '.gif': 'image/gif',
            '.webp': 'image/webp',
            '.bmp': 'image/bmp',
            '.svg': 'image/svg+xml',
            '.ico': 'image/vnd.microsoft.icon',
        }
        # Unknown extensions keep the historical JPEG fallback
        mime_type = mime_type_map.get(file_ext, 'image/jpeg')

        # Return metadata for images (NOT file_to_send - use send tool to actually send)
        result = {
            "type": "image_metadata",
            "file_type": "image",
            "path": absolute_path,
            "mime_type": mime_type,
            "size": file_size,
            "size_formatted": format_size(file_size),
            "message": f"图片文件: {Path(absolute_path).name} ({format_size(file_size)})\n提示: 如果需要发送此图片,请使用 send 工具。"
        }

        return ToolResult.success(result)

    except Exception as e:
        return ToolResult.fail(f"Error reading image file: {str(e)}")
|
||||
|
||||
def _read_text(self, absolute_path: str, display_path: str, offset: int = None, limit: int = None) -> ToolResult:
    """
    Read a text file, optionally paginated by offset/limit.

    Files above 50MB are not read; only their metadata is returned. The
    selected window is passed through ``truncate_head`` and a continuation
    hint is appended whenever more lines remain.

    :param absolute_path: Absolute path to the file
    :param display_path: Path to display in messages
    :param offset: Starting line number (1-indexed); negative counts from the end
    :param limit: Maximum number of lines to read
    :return: File content or error message
    """
    size_cap = 50 * 1024 * 1024  # 50MB

    try:
        file_size = os.path.getsize(absolute_path)
        if file_size > size_cap:
            # Too large to load: report metadata only
            return ToolResult.success({
                "type": "file_to_send",
                "file_type": "document",
                "path": absolute_path,
                "size": file_size,
                "size_formatted": format_size(file_size),
                "message": f"文件过大 ({format_size(file_size)} > 50MB),无法读取内容。文件路径: {absolute_path}"
            })

        # utf-8-sig drops a leading BOM if present. The whole file is loaded
        # so that offset/limit can address any line; size limiting is left
        # to truncate_head below.
        with open(absolute_path, 'r', encoding='utf-8-sig') as handle:
            content = handle.read()

        all_lines = content.split('\n')
        total_file_lines = len(all_lines)

        # Translate the requested offset into a 0-indexed starting line
        if offset is None:
            start_line = 0
        else:
            # Negative offset counts back from the end (-20 -> last 20 lines);
            # positive offset is 1-indexed.
            start_line = max(0, (total_file_lines + offset) if offset < 0 else (offset - 1))
            if start_line >= total_file_lines:
                return ToolResult.fail(
                    f"Error: Offset {offset} is beyond end of file ({total_file_lines} lines total)"
                )

        start_line_display = start_line + 1  # 1-indexed, for messages

        # Pick the window of lines requested by the caller
        window_end = None  # exclusive end index; set only when a limit was given
        if limit is not None:
            window_end = min(start_line + limit, total_file_lines)
            selected_content = '\n'.join(all_lines[start_line:window_end])
        elif offset is not None:
            selected_content = '\n'.join(all_lines[start_line:])
        else:
            selected_content = content

        # Enforce the line-count and byte-size limits
        truncation = truncate_head(selected_content)

        details = {}

        if truncation.first_line_exceeds_limit:
            # The very first selected line is already over the byte limit
            first_line_size = format_size(len(all_lines[start_line].encode('utf-8')))
            output_text = f"[Line {start_line_display} is {first_line_size}, exceeds {format_size(DEFAULT_MAX_BYTES)} limit. Use bash tool to read: head -c {DEFAULT_MAX_BYTES} {display_path} | tail -n +{start_line_display}]"
            details["truncation"] = truncation.to_dict()
        elif truncation.truncated:
            end_line_display = start_line_display + truncation.output_lines - 1
            next_offset = end_line_display + 1

            output_text = truncation.content
            if truncation.truncated_by == "lines":
                output_text += f"\n\n[Showing lines {start_line_display}-{end_line_display} of {total_file_lines}. Use offset={next_offset} to continue.]"
            else:
                output_text += f"\n\n[Showing lines {start_line_display}-{end_line_display} of {total_file_lines} ({format_size(DEFAULT_MAX_BYTES)} limit). Use offset={next_offset} to continue.]"

            details["truncation"] = truncation.to_dict()
        elif window_end is not None and window_end < total_file_lines:
            # The caller's limit stopped before the end of the file
            remaining = total_file_lines - window_end
            next_offset = window_end + 1

            output_text = truncation.content
            output_text += f"\n\n[{remaining} more lines in file. Use offset={next_offset} to continue.]"
        else:
            # Everything requested fits
            output_text = truncation.content

        result = {
            "content": output_text,
            "total_lines": total_file_lines,
            "start_line": start_line_display,
            "output_lines": truncation.output_lines
        }
        if details:
            result["details"] = details

        return ToolResult.success(result)

    except UnicodeDecodeError:
        return ToolResult.fail(f"Error: File is not a valid text file (encoding error): {display_path}")
    except Exception as e:
        return ToolResult.fail(f"Error reading file: {str(e)}")
|
||||
|
||||
def _read_office(self, absolute_path: str, display_path: str, file_ext: str,
                 offset: int = None, limit: int = None) -> ToolResult:
    """
    Read an Office document as plain text, paginated like a text file.

    Text extraction is delegated to ``_extract_office_text``; the extracted
    text is then windowed by offset/limit and passed through ``truncate_head``.

    :param absolute_path: Absolute path to the file
    :param display_path: Path to display (not used in this method's output)
    :param file_ext: Lower-case file extension selecting the extractor
    :param offset: Starting line number (1-indexed); negative counts from the end
    :param limit: Maximum number of lines to read
    :return: Extracted text or error message
    """
    try:
        text = self._extract_office_text(absolute_path, file_ext)
    except ImportError as e:
        # The extractor raises ImportError with a ready-made install hint
        return ToolResult.fail(str(e))
    except Exception as e:
        return ToolResult.fail(f"Error reading Office document: {e}")

    if not (text and text.strip()):
        return ToolResult.success({
            "content": f"[Office file {Path(absolute_path).name}: no text content could be extracted]",
        })

    all_lines = text.split('\n')
    total_lines = len(all_lines)

    # Translate the requested offset into a 0-indexed starting line
    if offset is None:
        start_line = 0
    else:
        start_line = max(0, (total_lines + offset) if offset < 0 else (offset - 1))
        if start_line >= total_lines:
            return ToolResult.fail(
                f"Error: Offset {offset} is beyond end of content ({total_lines} lines total)"
            )

    # Pick the window of lines requested by the caller
    window_end = None  # exclusive end index; set only when a limit was given
    if limit is not None:
        window_end = min(start_line + limit, total_lines)
        selected_content = '\n'.join(all_lines[start_line:window_end])
    elif offset is not None:
        selected_content = '\n'.join(all_lines[start_line:])
    else:
        selected_content = text

    truncation = truncate_head(selected_content)
    start_line_display = start_line + 1

    if truncation.truncated:
        end_line_display = start_line_display + truncation.output_lines - 1
        next_offset = end_line_display + 1
        output_text = truncation.content
        output_text += f"\n\n[Showing lines {start_line_display}-{end_line_display} of {total_lines}. Use offset={next_offset} to continue.]"
    elif window_end is not None and window_end < total_lines:
        # The caller's limit stopped before the end of the content
        remaining = total_lines - window_end
        next_offset = window_end + 1
        output_text = truncation.content
        output_text += f"\n\n[{remaining} more lines in file. Use offset={next_offset} to continue.]"
    else:
        output_text = truncation.content

    return ToolResult.success({
        "content": output_text,
        "total_lines": total_lines,
        "start_line": start_line_display,
        "output_lines": truncation.output_lines,
    })
|
||||
|
||||
@staticmethod
def _extract_office_text(absolute_path: str, file_ext: str) -> str:
    """
    Extract plain text from an Office document.

    Word files yield their paragraphs followed by tab-separated table rows;
    spreadsheets yield one tab-separated line per row under a sheet header;
    presentations yield the non-empty text paragraphs of each slide under a
    slide header.

    :param absolute_path: Absolute path to the document
    :param file_ext: Lower-case file extension, used to choose the parser
    :return: Extracted text, or "" for an unrecognised extension
    :raises ImportError: If the parser library for this format is not installed
    """
    if file_ext in ('.docx', '.doc'):
        try:
            from docx import Document
        except ImportError:
            raise ImportError("Error: python-docx library not installed. Install with: pip install python-docx")
        doc = Document(absolute_path)
        paragraphs = [p.text for p in doc.paragraphs]
        for table in doc.tables:
            paragraphs.extend('\t'.join(cell.text for cell in row.cells) for row in table.rows)
        return '\n'.join(paragraphs)

    if file_ext in ('.xlsx', '.xls'):
        try:
            from openpyxl import load_workbook
        except ImportError:
            raise ImportError("Error: openpyxl library not installed. Install with: pip install openpyxl")
        wb = load_workbook(absolute_path, read_only=True, data_only=True)
        try:
            parts = []
            for ws in wb.worksheets:
                parts.append(f"--- Sheet: {ws.title} ---")
                for row in ws.iter_rows(values_only=True):
                    parts.append('\t'.join(str(c) if c is not None else '' for c in row))
        finally:
            # A read-only workbook keeps the file open until close() is
            # called, so release it even if iterating the rows fails.
            wb.close()
        return '\n'.join(parts)

    if file_ext in ('.pptx', '.ppt'):
        try:
            from pptx import Presentation
        except ImportError:
            raise ImportError("Error: python-pptx library not installed. Install with: pip install python-pptx")
        prs = Presentation(absolute_path)
        parts = []
        for i, slide in enumerate(prs.slides, 1):
            parts.append(f"--- Slide {i} ---")
            for shape in slide.shapes:
                if shape.has_text_frame:
                    for para in shape.text_frame.paragraphs:
                        text = para.text.strip()
                        if text:
                            parts.append(text)
        return '\n'.join(parts)

    return ""
|
||||
|
||||
def _read_pdf(self, absolute_path: str, display_path: str, offset: int = None, limit: int = None) -> ToolResult:
    """
    Read PDF file content

    Text is extracted page by page with pypdf, each non-empty page is
    prefixed with a "--- Page N ---" header, and the merged text is then
    paginated by offset/limit and passed through ``truncate_head``.

    :param absolute_path: Absolute path to the file
    :param display_path: Path to display (not used in this method's output)
    :param offset: Starting line number (1-indexed); negative values count
                   from the end (e.g. -20 for the last 20 lines)
    :param limit: Maximum number of lines to read
    :return: PDF text content or error message
    """
    try:
        # Try to import pypdf
        try:
            from pypdf import PdfReader
        except ImportError:
            return ToolResult.fail(
                "Error: pypdf library not installed. Install with: pip install pypdf"
            )

        # Read PDF
        reader = PdfReader(absolute_path)
        total_pages = len(reader.pages)

        # Extract text from all pages
        text_parts = []
        for page_num, page in enumerate(reader.pages, 1):
            # Guard against extractors that return None for pages without text
            page_text = page.extract_text() or ""
            if page_text.strip():
                text_parts.append(f"--- Page {page_num} ---\n{page_text}")

        if not text_parts:
            return ToolResult.success({
                "content": f"[PDF file with {total_pages} pages, but no text content could be extracted]",
                "total_pages": total_pages,
                "message": "PDF may contain only images or be encrypted"
            })

        # Merge all text
        full_content = "\n\n".join(text_parts)
        all_lines = full_content.split('\n')
        total_lines = len(all_lines)

        # Apply offset and limit (same logic as text files)
        start_line = 0
        if offset is not None:
            if offset < 0:
                # Negative offset: read from end, matching the text and
                # Office readers (-20 means "last 20 lines")
                start_line = max(0, total_lines + offset)
            else:
                # Positive offset: 1-indexed, convert to 0-indexed
                start_line = max(0, offset - 1)
            if start_line >= total_lines:
                return ToolResult.fail(
                    f"Error: Offset {offset} is beyond end of content ({total_lines} lines total)"
                )

        start_line_display = start_line + 1

        selected_content = full_content
        user_limited_lines = None
        if limit is not None:
            end_line = min(start_line + limit, total_lines)
            selected_content = '\n'.join(all_lines[start_line:end_line])
            user_limited_lines = end_line - start_line
        elif offset is not None:
            selected_content = '\n'.join(all_lines[start_line:])

        # Apply truncation
        truncation = truncate_head(selected_content)

        output_text = ""
        details = {}

        if truncation.truncated:
            end_line_display = start_line_display + truncation.output_lines - 1
            next_offset = end_line_display + 1

            output_text = truncation.content

            if truncation.truncated_by == "lines":
                output_text += f"\n\n[Showing lines {start_line_display}-{end_line_display} of {total_lines}. Use offset={next_offset} to continue.]"
            else:
                output_text += f"\n\n[Showing lines {start_line_display}-{end_line_display} of {total_lines} ({format_size(DEFAULT_MAX_BYTES)} limit). Use offset={next_offset} to continue.]"

            details["truncation"] = truncation.to_dict()
        elif user_limited_lines is not None and start_line + user_limited_lines < total_lines:
            # The caller's limit stopped before the end of the content
            remaining = total_lines - (start_line + user_limited_lines)
            next_offset = start_line + user_limited_lines + 1

            output_text = truncation.content
            output_text += f"\n\n[{remaining} more lines in file. Use offset={next_offset} to continue.]"
        else:
            output_text = truncation.content

        result = {
            "content": output_text,
            "total_pages": total_pages,
            "total_lines": total_lines,
            "start_line": start_line_display,
            "output_lines": truncation.output_lines
        }

        if details:
            result["details"] = details

        return ToolResult.success(result)

    except Exception as e:
        return ToolResult.fail(f"Error reading PDF file: {str(e)}")
|
||||
287
agent/tools/scheduler/README.md
Normal file
287
agent/tools/scheduler/README.md
Normal file
@@ -0,0 +1,287 @@
|
||||
# 定时任务工具 (Scheduler Tool)
|
||||
|
||||
## 功能简介
|
||||
|
||||
定时任务工具允许 Agent 创建、管理和执行定时任务,支持:
|
||||
|
||||
- ⏰ **定时提醒**: 在指定时间发送消息
|
||||
- 🔄 **周期性任务**: 按固定间隔或 cron 表达式重复执行
|
||||
- 🔧 **动态工具调用**: 定时执行其他工具并发送结果(如搜索新闻、查询天气等)
|
||||
- 📋 **任务管理**: 查询、启用、禁用、删除任务
|
||||
|
||||
## 安装依赖
|
||||
|
||||
```bash
|
||||
pip install "croniter>=2.0.0"
|
||||
```
|
||||
|
||||
## 使用方法
|
||||
|
||||
### 1. 创建定时任务
|
||||
|
||||
Agent 可以通过自然语言创建定时任务,支持两种类型:
|
||||
|
||||
#### 1.1 静态消息任务
|
||||
|
||||
发送预定义的消息:
|
||||
|
||||
**示例对话:**
|
||||
```
|
||||
用户: 每天早上9点提醒我开会
|
||||
Agent: [调用 scheduler 工具]
|
||||
action: create
|
||||
name: 每日开会提醒
|
||||
message: 该开会了!
|
||||
schedule_type: cron
|
||||
schedule_value: 0 9 * * *
|
||||
```
|
||||
|
||||
#### 1.2 动态工具调用任务
|
||||
|
||||
定时执行工具并发送结果:
|
||||
|
||||
**示例对话:**
|
||||
```
|
||||
用户: 每天早上8点帮我读取一下今日日程
|
||||
Agent: [调用 scheduler 工具]
|
||||
action: create
|
||||
name: 每日日程
|
||||
tool_call:
|
||||
tool_name: read
|
||||
tool_params:
|
||||
path: ~/cow/schedule.txt
|
||||
result_prefix: 📅 今日日程
|
||||
schedule_type: cron
|
||||
schedule_value: 0 8 * * *
|
||||
```
|
||||
|
||||
**工具调用参数说明:**
|
||||
- `tool_name`: 要调用的工具名称(如 `bash`、`read`、`write` 等内置工具)
|
||||
- `tool_params`: 工具的参数(字典格式)
|
||||
- `result_prefix`: 可选,在结果前添加的前缀文本
|
||||
|
||||
**注意:** 如果要使用 skills(如 bocha-search),需要通过 `bash` 工具调用 skill 脚本
|
||||
|
||||
### 2. 支持的调度类型
|
||||
|
||||
#### Cron 表达式 (`cron`)
|
||||
使用标准 cron 表达式:
|
||||
|
||||
```
|
||||
0 9 * * * # 每天 9:00
|
||||
0 */2 * * * # 每 2 小时
|
||||
30 8 * * 1-5 # 工作日 8:30
|
||||
0 0 1 * * # 每月 1 号
|
||||
```
|
||||
|
||||
#### 固定间隔 (`interval`)
|
||||
以秒为单位的间隔:
|
||||
|
||||
```
|
||||
3600 # 每小时
|
||||
86400 # 每天
|
||||
1800 # 每 30 分钟
|
||||
```
|
||||
|
||||
#### 一次性任务 (`once`)
|
||||
指定具体时间(ISO 格式):
|
||||
|
||||
```
|
||||
2024-12-25T09:00:00
|
||||
2024-12-31T23:59:59
|
||||
```
|
||||
|
||||
### 3. 查询任务列表
|
||||
|
||||
```
|
||||
用户: 查看我的定时任务
|
||||
Agent: [调用 scheduler 工具]
|
||||
action: list
|
||||
```
|
||||
|
||||
### 4. 查看任务详情
|
||||
|
||||
```
|
||||
用户: 查看任务 abc123 的详情
|
||||
Agent: [调用 scheduler 工具]
|
||||
action: get
|
||||
task_id: abc123
|
||||
```
|
||||
|
||||
### 5. 删除任务
|
||||
|
||||
```
|
||||
用户: 删除任务 abc123
|
||||
Agent: [调用 scheduler 工具]
|
||||
action: delete
|
||||
task_id: abc123
|
||||
```
|
||||
|
||||
### 6. 启用/禁用任务
|
||||
|
||||
```
|
||||
用户: 暂停任务 abc123
|
||||
Agent: [调用 scheduler 工具]
|
||||
action: disable
|
||||
task_id: abc123
|
||||
|
||||
用户: 恢复任务 abc123
|
||||
Agent: [调用 scheduler 工具]
|
||||
action: enable
|
||||
task_id: abc123
|
||||
```
|
||||
|
||||
## 任务存储
|
||||
|
||||
任务保存在 JSON 文件中:
|
||||
```
|
||||
~/cow/scheduler/tasks.json
|
||||
```
|
||||
|
||||
任务数据结构:
|
||||
|
||||
**静态消息任务:**
|
||||
```json
|
||||
{
|
||||
"id": "abc123",
|
||||
"name": "每日提醒",
|
||||
"enabled": true,
|
||||
"created_at": "2024-01-01T10:00:00",
|
||||
"updated_at": "2024-01-01T10:00:00",
|
||||
"schedule": {
|
||||
"type": "cron",
|
||||
"expression": "0 9 * * *"
|
||||
},
|
||||
"action": {
|
||||
"type": "send_message",
|
||||
"content": "该开会了!",
|
||||
"receiver": "wxid_xxx",
|
||||
"receiver_name": "张三",
|
||||
"is_group": false,
|
||||
"channel_type": "wechat"
|
||||
},
|
||||
"next_run_at": "2024-01-02T09:00:00",
|
||||
"last_run_at": "2024-01-01T09:00:00"
|
||||
}
|
||||
```
|
||||
|
||||
**动态工具调用任务:**
|
||||
```json
|
||||
{
|
||||
"id": "def456",
|
||||
"name": "每日日程",
|
||||
"enabled": true,
|
||||
"created_at": "2024-01-01T10:00:00",
|
||||
"updated_at": "2024-01-01T10:00:00",
|
||||
"schedule": {
|
||||
"type": "cron",
|
||||
"expression": "0 8 * * *"
|
||||
},
|
||||
"action": {
|
||||
"type": "tool_call",
|
||||
"tool_name": "read",
|
||||
"tool_params": {
|
||||
"path": "~/cow/schedule.txt"
|
||||
},
|
||||
"result_prefix": "📅 今日日程",
|
||||
"receiver": "wxid_xxx",
|
||||
"receiver_name": "张三",
|
||||
"is_group": false,
|
||||
"channel_type": "wechat"
|
||||
},
|
||||
"next_run_at": "2024-01-02T08:00:00"
|
||||
}
|
||||
```
|
||||
|
||||
## 后台服务
|
||||
|
||||
定时任务由后台服务 `SchedulerService` 管理:
|
||||
|
||||
- 每 30 秒检查一次到期任务
|
||||
- 自动执行到期任务
|
||||
- 计算下次执行时间
|
||||
- 记录执行历史和错误
|
||||
|
||||
服务在 Agent 初始化时自动启动,无需手动配置。
|
||||
|
||||
## 接收者确定
|
||||
|
||||
定时任务会发送给**创建任务时的对话对象**:
|
||||
|
||||
- 如果在私聊中创建,发送给该用户
|
||||
- 如果在群聊中创建,发送到该群
|
||||
- 接收者信息在创建时自动保存
|
||||
|
||||
## 常见用例
|
||||
|
||||
### 1. 每日提醒(静态消息)
|
||||
```
|
||||
用户: 每天早上8点提醒我吃药
|
||||
Agent: ✅ 定时任务创建成功
|
||||
任务ID: a1b2c3d4
|
||||
调度: 每天 8:00
|
||||
消息: 该吃药了!
|
||||
```
|
||||
|
||||
### 2. 工作日提醒(静态消息)
|
||||
```
|
||||
用户: 工作日下午6点提醒我下班
|
||||
Agent: [创建 cron: 0 18 * * 1-5]
|
||||
消息: 该下班了!
|
||||
```
|
||||
|
||||
### 3. 倒计时提醒(静态消息)
|
||||
```
|
||||
用户: 1小时后提醒我
|
||||
Agent: [创建 once: <当前时间 + 1 小时的 ISO 时间>](一次性提醒应使用 once;interval: 3600 会每小时重复执行)
|
||||
```
|
||||
|
||||
### 4. 每日日程推送(动态工具调用)
|
||||
```
|
||||
用户: 每天早上8点帮我读取今日日程
|
||||
Agent: ✅ 定时任务创建成功
|
||||
任务ID: schedule001
|
||||
调度: 每天 8:00
|
||||
工具: read(path='~/cow/schedule.txt')
|
||||
前缀: 📅 今日日程
|
||||
```
|
||||
|
||||
### 5. 定时文件备份(动态工具调用)
|
||||
```
|
||||
用户: 每天晚上11点备份工作文件
|
||||
Agent: [创建 cron: 0 23 * * *]
|
||||
工具: bash(command='cp ~/cow/work.txt ~/cow/backup/work_$(date +%Y%m%d).txt')
|
||||
前缀: ✅ 文件已备份
|
||||
```
|
||||
|
||||
### 6. 周报提醒(静态消息)
|
||||
```
|
||||
用户: 每周五下午5点提醒我写周报
|
||||
Agent: [创建 cron: 0 17 * * 5]
|
||||
消息: 📊 该写周报了!
|
||||
```
|
||||
|
||||
### 7. 特定日期提醒
|
||||
```
|
||||
用户: 12月25日早上9点提醒我圣诞快乐
|
||||
Agent: [创建 once: 2024-12-25T09:00:00]
|
||||
```
|
||||
|
||||
## 注意事项
|
||||
|
||||
1. **时区**: 使用系统本地时区
|
||||
2. **精度**: 检查间隔为 30 秒,实际执行可能比计划时间晚最多约 30 秒
|
||||
3. **持久化**: 任务保存在文件中,重启后自动恢复
|
||||
4. **一次性任务**: 执行后自动禁用,不会删除(可手动删除)
|
||||
5. **错误处理**: 执行失败会记录错误,不影响其他任务
|
||||
|
||||
## 技术实现
|
||||
|
||||
- **TaskStore**: 任务持久化存储
|
||||
- **SchedulerService**: 后台调度服务
|
||||
- **SchedulerTool**: Agent 工具接口
|
||||
- **Integration**: 与 AgentBridge 集成
|
||||
|
||||
## 依赖
|
||||
|
||||
- `croniter`: Cron 表达式解析(轻量级,仅 ~50KB)
|
||||
7
agent/tools/scheduler/__init__.py
Normal file
7
agent/tools/scheduler/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
"""
|
||||
Scheduler tool for managing scheduled tasks
|
||||
"""
|
||||
|
||||
from .scheduler_tool import SchedulerTool
|
||||
|
||||
__all__ = ["SchedulerTool"]
|
||||
548
agent/tools/scheduler/integration.py
Normal file
548
agent/tools/scheduler/integration.py
Normal file
@@ -0,0 +1,548 @@
|
||||
"""
|
||||
Integration module for scheduler with AgentBridge
|
||||
"""
|
||||
|
||||
import os
|
||||
import threading
|
||||
from typing import Optional
|
||||
from config import conf
|
||||
from common.log import logger
|
||||
from common.utils import expand_path
|
||||
from bridge.context import Context, ContextType
|
||||
from bridge.reply import Reply, ReplyType
|
||||
|
||||
# Global scheduler service instance
|
||||
_scheduler_service = None
|
||||
_task_store = None
|
||||
# Module-level lock to guard idempotent initialization across threads
|
||||
_init_lock = threading.Lock()
|
||||
|
||||
|
||||
def init_scheduler(agent_bridge) -> bool:
    """
    Start the scheduler singleton if it is not already running.

    The running check is made once without the lock and again while holding
    ``_init_lock``, so concurrent callers create at most one
    ``SchedulerService``. The task store is created on first use and reused
    afterwards.

    Args:
        agent_bridge: AgentBridge instance, handed to the per-action executors

    Returns:
        True if the scheduler is running after the call (newly started or
        already running), False if initialization raised an exception
    """
    global _scheduler_service, _task_store

    def _already_running():
        # True only when a service exists and reports itself as running
        return _scheduler_service is not None and getattr(_scheduler_service, "running", False)

    # Fast path: nothing to do, skip the lock entirely
    if _already_running():
        return True

    with _init_lock:
        # Another thread may have finished initialization while we waited
        if _already_running():
            return True

        try:
            from agent.tools.scheduler.task_store import TaskStore
            from agent.tools.scheduler.scheduler_service import SchedulerService

            # Tasks are persisted under <workspace>/scheduler/tasks.json
            workspace_root = expand_path(conf().get("agent_workspace", "~/cow"))
            store_path = os.path.join(workspace_root, "scheduler", "tasks.json")

            # Reuse an existing store; create it only on first initialization
            if _task_store is None:
                _task_store = TaskStore(store_path)
                logger.debug(f"[Scheduler] Task store initialized: {store_path}")

            def execute_task_callback(task: dict):
                """Run one due task; True means done, False asks for a retry on the next tick."""
                try:
                    action = task.get("action", {})
                    action_type = action.get("type")
                    channel_type = action.get("channel_type", "unknown")
                    receiver = action.get("receiver", "")

                    # Defer while the outbound channel cannot deliver
                    # (e.g. right after process start)
                    if not _is_channel_ready(channel_type, receiver):
                        logger.warning(
                            f"[Scheduler] Task {task.get('id')}: channel "
                            f"'{channel_type}' not ready for receiver={receiver} "
                            f"(no inbound msg cached since restart?); deferring"
                        )
                        return False

                    if action_type == "agent_task":
                        return _execute_agent_task(task, agent_bridge)
                    if action_type == "send_message":
                        return _execute_send_message(task, agent_bridge)
                    if action_type == "tool_call":
                        return _execute_tool_call(task, agent_bridge)
                    if action_type == "skill_call":
                        return _execute_skill_call(task, agent_bridge)

                    # Unrecognised action: log it and report success
                    logger.warning(f"[Scheduler] Unknown action type: {action_type}")
                    return True
                except Exception as e:
                    logger.error(f"[Scheduler] Error executing task {task.get('id')}: {e}")
                    return False

            # Build the service around the store and callback, then start it
            _scheduler_service = SchedulerService(_task_store, execute_task_callback)
            _scheduler_service.start()

            logger.info("[Scheduler] Service initialized and started")
            return True

        except Exception as e:
            logger.error(f"[Scheduler] Failed to initialize scheduler: {e}")
            return False
|
||||
|
||||
|
||||
def _is_channel_ready(channel_type: str, receiver: str) -> bool:
    """Best-effort check that an outbound channel can deliver to ``receiver``.

    An empty or "unknown" channel type is reported ready without any lookup.
    Otherwise the channel is obtained from ``create_channel``; a missing
    channel is not ready. For "weixin" and "web" the receiver must be present
    in the channel's ``_context_tokens`` / ``session_queues`` mapping
    respectively. Every other channel type, and any error raised during the
    probe, is reported ready.
    """
    if not channel_type or channel_type == "unknown":
        return True

    # Channel type -> attribute holding the per-receiver delivery state
    registry_attr_by_channel = {
        "weixin": "_context_tokens",
        "web": "session_queues",
    }

    try:
        from channel.channel_factory import create_channel

        channel = create_channel(channel_type)
        if channel is None:
            return False

        registry_attr = registry_attr_by_channel.get(channel_type)
        if registry_attr is None:
            # No readiness probe for this channel type
            return True

        registry = getattr(channel, registry_attr, None)
        # Ready only when the registry is non-empty and knows this receiver
        return bool(registry) and receiver in registry
    except Exception as e:
        logger.warning(f"[Scheduler] Channel readiness check failed for {channel_type}: {e}")
        return True
|
||||
|
||||
|
||||
def get_task_store():
    """Return the module-level task store (``None`` until one has been created)."""
    return _task_store
|
||||
|
||||
|
||||
def get_scheduler_service():
    """Return the module-level scheduler service (``None`` until one has been created)."""
    return _scheduler_service
|
||||
|
||||
|
||||
def _remember_delivered_output(
|
||||
agent_bridge,
|
||||
task: dict,
|
||||
channel_type: str,
|
||||
content: str,
|
||||
) -> None:
|
||||
"""Best-effort persistence of the message the scheduler sent to a user.
|
||||
|
||||
Uses notify_session_id (the real chat session_id stored at task creation time)
|
||||
so that group chats correctly associate the output with the user's conversation.
|
||||
Falls back to receiver for backward compatibility with old tasks.
|
||||
|
||||
Per-action-type behaviour:
|
||||
- agent_task / tool_call / skill_call: gated by ``scheduler_inject_to_session``
|
||||
(default True). These produce AI-generated content worth remembering.
|
||||
- send_message: additionally gated by ``scheduler_inject_send_message``
|
||||
(default False). Fixed reminder text rarely benefits follow-up Q&A and
|
||||
would just consume context tokens.
|
||||
"""
|
||||
if not content:
|
||||
return
|
||||
action = task.get("action", {})
|
||||
action_type = action.get("type", "")
|
||||
|
||||
# send_message defaults to NOT being injected; explicit opt-in via config.
|
||||
if action_type == "send_message":
|
||||
if not conf().get("scheduler_inject_send_message", False):
|
||||
return
|
||||
|
||||
session_id = action.get("notify_session_id") or action.get("receiver")
|
||||
if not session_id:
|
||||
return
|
||||
try:
|
||||
remember = getattr(agent_bridge, "remember_scheduled_output", None)
|
||||
if remember:
|
||||
task_desc = action.get("task_description") or action.get("content", "")
|
||||
remember(session_id, str(content), channel_type=channel_type, task_description=task_desc)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"[Scheduler] Failed to remember delivered output for {session_id}: {e}"
|
||||
)
|
||||
|
||||
|
||||
def _execute_agent_task(task: dict, agent_bridge) -> bool:
    """Run an ``agent_task`` action through the agent and send the reply.

    The agent is invoked with ``action["task_description"]`` in a dedicated
    session (``scheduler_<receiver>_<task id>``) so the run does not share
    history with the user's own conversation. The reply is then sent through
    the channel named by ``action["channel_type"]``.

    Args:
        task: Task dictionary; reads ``task["id"]`` and ``task["action"]``.
        agent_bridge: Object providing ``agent_reply(...)``.

    Returns:
        True when the task should be considered done: the reply was sent, the
        task is malformed (no description / receiver), or the agent returned
        no content. False when the agent call, channel creation or the send
        failed, or any other error occurred.
    """
    import traceback
    import uuid

    try:
        spec = task.get("action", {})
        prompt = spec.get("task_description")
        receiver = spec.get("receiver")
        in_group = spec.get("is_group", False)
        channel_type = spec.get("channel_type", "unknown")
        task_id = task['id']

        # Malformed tasks are reported as "done" so they do not loop forever.
        if not prompt:
            logger.error(f"[Scheduler] Task {task_id}: No task_description specified")
            return True
        if not receiver:
            logger.error(f"[Scheduler] Task {task_id}: No receiver specified")
            return True

        if channel_type == "dingtalk":
            logger.warning(f"[Scheduler] Task {task_id}: DingTalk channel does not support scheduled messages (Stream mode limitation). Task will execute but message cannot be sent.")

        logger.info(f"[Scheduler] Task {task_id}: Executing agent task '{prompt}'")

        # Isolated session id: scheduler_<receiver>_<task_id>.
        context = Context(ContextType.TEXT, prompt)
        context["receiver"] = receiver
        context["isgroup"] = in_group
        context["session_id"] = f"scheduler_{receiver}_{task_id}"

        # Channel-specific context fields.
        if channel_type == "web":
            context["request_id"] = f"scheduler_{task_id}_{uuid.uuid4().hex[:8]}"
        elif channel_type == "feishu":
            context["receive_id_type"] = "chat_id" if in_group else "open_id"
            context["msg"] = None
        elif channel_type == "dingtalk":
            # No original message exists for a scheduled run.
            context["msg"] = None
            if not in_group:
                staff_id = spec.get("dingtalk_sender_staff_id")
                if staff_id:
                    context["dingtalk_sender_staff_id"] = staff_id
        elif channel_type == "wecom_bot":
            context["msg"] = None

        # Flag the run as scheduler-driven (prevents recursive task creation).
        context["is_scheduled_task"] = True

        try:
            # History is kept: the session id above is already isolated.
            reply = agent_bridge.agent_reply(
                prompt, context=context, on_event=None, clear_history=False
            )
            has_content = reply and reply.content
            if not has_content:
                logger.error(f"[Scheduler] Task {task_id}: No result from agent execution")
                return True  # agent ran but produced nothing; don't loop

            from channel.channel_factory import create_channel

            channel = create_channel(channel_type)
            if not channel:
                logger.error(f"[Scheduler] Failed to create channel: {channel_type}")
                return False

            if channel_type == "web" and hasattr(channel, 'request_to_session'):
                web_request_id = context.get("request_id")
                if web_request_id:
                    channel.request_to_session[web_request_id] = receiver

            try:
                channel.send(reply, context)
            except Exception as e:
                logger.error(f"[Scheduler] Failed to send result: {e}")
                return False

            _remember_delivered_output(agent_bridge, task, channel_type, reply.content)
            logger.info(f"[Scheduler] Task {task_id} executed successfully, result sent to {receiver}")
            return True

        except Exception as e:
            logger.error(f"[Scheduler] Failed to execute task via Agent: {e}")
            logger.error(f"[Scheduler] Traceback: {traceback.format_exc()}")
            return False

    except Exception as e:
        logger.error(f"[Scheduler] Error in _execute_agent_task: {e}")
        logger.error(f"[Scheduler] Traceback: {traceback.format_exc()}")
        return False
|
||||
|
||||
|
||||
def _execute_send_message(task: dict, agent_bridge) -> bool:
    """Send the fixed text of a ``send_message`` action to its receiver.

    Args:
        task: Task dictionary; reads ``task["id"]`` and ``task["action"]``
            (``content``, ``receiver``, ``is_group``, ``channel_type`` and,
            for DingTalk private chats, ``dingtalk_sender_staff_id``).
        agent_bridge: Passed on to ``_remember_delivered_output`` after a
            successful send.

    Returns:
        True when the message was sent or the task has no receiver (malformed
        tasks are not retried). False when the channel could not be created,
        the send raised, or any other error occurred.
    """
    import traceback
    import uuid

    try:
        spec = task.get("action", {})
        text = spec.get("content", "")
        receiver = spec.get("receiver")
        in_group = spec.get("is_group", False)
        channel_type = spec.get("channel_type", "unknown")

        if not receiver:
            logger.error(f"[Scheduler] Task {task['id']}: No receiver specified")
            return True

        context = Context(ContextType.TEXT, text)
        context["receiver"] = receiver
        context["isgroup"] = in_group
        context["session_id"] = receiver

        # Only the web branch assigns a request id; it is consumed further down.
        request_id = None
        if channel_type == "web":
            request_id = f"scheduler_{task['id']}_{uuid.uuid4().hex[:8]}"
            context["request_id"] = request_id
            logger.debug(f"[Scheduler] Generated request_id for web channel: {request_id}")
        elif channel_type == "feishu":
            # Scheduled messages are new messages: chat_id addresses a group,
            # open_id a private chat, and there is no message to reply to.
            context["receive_id_type"] = "chat_id" if in_group else "open_id"
            context["msg"] = None
            logger.debug(f"[Scheduler] Feishu: receive_id_type={context['receive_id_type']}, is_group={in_group}, receiver={receiver}")
        elif channel_type == "dingtalk":
            context["msg"] = None
            # Private chats additionally need the sender's staff id.
            if not in_group:
                staff_id = spec.get("dingtalk_sender_staff_id")
                if staff_id:
                    context["dingtalk_sender_staff_id"] = staff_id
                    logger.debug(f"[Scheduler] DingTalk single chat: sender_staff_id={staff_id}")
                else:
                    logger.warning(f"[Scheduler] Task {task['id']}: DingTalk single chat message missing sender_staff_id")
        elif channel_type == "wecom_bot":
            context["msg"] = None
        elif channel_type == "qq":
            context["msg"] = None

        outgoing = Reply(ReplyType.TEXT, text)

        from channel.channel_factory import create_channel

        channel = create_channel(channel_type)
        if not channel:
            logger.error(f"[Scheduler] Failed to create channel: {channel_type}")
            return False

        if channel_type == "web" and hasattr(channel, 'request_to_session'):
            channel.request_to_session[request_id] = receiver

        try:
            channel.send(outgoing, context)
        except Exception as e:
            logger.error(f"[Scheduler] Failed to send message: {e}")
            return False

        _remember_delivered_output(agent_bridge, task, channel_type, text)
        logger.info(f"[Scheduler] Task {task['id']} executed: sent message to {receiver}")
        return True

    except Exception as e:
        logger.error(f"[Scheduler] Error in _execute_send_message: {e}")
        logger.error(f"[Scheduler] Traceback: {traceback.format_exc()}")
        return False
|
||||
|
||||
|
||||
def _execute_tool_call(task: dict, agent_bridge) -> bool:
    """Run the tool named by a ``tool_call`` action and send its result.

    The tool name is read from ``call_name`` (or ``tool_name``) and its
    parameters from ``call_params`` (or ``tool_params``). The result text is
    ``result.result`` when that attribute exists, otherwise ``str(result)``,
    optionally preceded by ``result_prefix`` and a blank line.

    Args:
        task: Task dictionary; reads ``task["id"]`` and ``task["action"]``.
        agent_bridge: Passed on to ``_remember_delivered_output`` after a
            successful send.

    Returns:
        True when the result was sent, or when the task cannot run at all
        (no tool name, no receiver, unknown tool). False when the channel
        could not be created, the send raised, or any other error occurred
        (including an exception raised by the tool itself).
    """
    try:
        spec = task.get("action", {})
        tool_name = spec.get("call_name") or spec.get("tool_name")
        tool_params = spec.get("call_params") or spec.get("tool_params", {})
        prefix = spec.get("result_prefix", "")
        receiver = spec.get("receiver")
        in_group = spec.get("is_group", False)
        channel_type = spec.get("channel_type", "unknown")

        if not tool_name:
            logger.error(f"[Scheduler] Task {task['id']}: No tool_name specified")
            return True
        if not receiver:
            logger.error(f"[Scheduler] Task {task['id']}: No receiver specified")
            return True

        from agent.tools.tool_manager import ToolManager

        tool = ToolManager().create_tool(tool_name)
        if not tool:
            logger.error(f"[Scheduler] Task {task['id']}: Tool '{tool_name}' not found")
            return True

        logger.info(f"[Scheduler] Task {task['id']}: Executing tool '{tool_name}' with params {tool_params}")
        outcome = tool.execute(tool_params)
        if hasattr(outcome, 'result'):
            text = outcome.result
        else:
            text = str(outcome)
        if prefix:
            text = f"{prefix}\n\n{text}"

        context = Context(ContextType.TEXT, text)
        context["receiver"] = receiver
        context["isgroup"] = in_group
        context["session_id"] = receiver

        # Channel-specific context fields; only web produces a request id.
        request_id = None
        if channel_type == "web":
            import uuid

            request_id = f"scheduler_{task['id']}_{uuid.uuid4().hex[:8]}"
            context["request_id"] = request_id
        elif channel_type == "feishu":
            context["receive_id_type"] = "chat_id" if in_group else "open_id"
            context["msg"] = None
        elif channel_type == "wecom_bot":
            context["msg"] = None

        outgoing = Reply(ReplyType.TEXT, text)

        from channel.channel_factory import create_channel

        channel = create_channel(channel_type)
        if not channel:
            logger.error(f"[Scheduler] Failed to create channel: {channel_type}")
            return False

        if channel_type == "web" and request_id and hasattr(channel, 'request_to_session'):
            channel.request_to_session[request_id] = receiver

        try:
            channel.send(outgoing, context)
        except Exception as e:
            logger.error(f"[Scheduler] Failed to send tool result: {e}")
            return False

        _remember_delivered_output(agent_bridge, task, channel_type, text)
        logger.info(f"[Scheduler] Task {task['id']} executed: sent tool result to {receiver}")
        return True

    except Exception as e:
        logger.error(f"[Scheduler] Error in _execute_tool_call: {e}")
        return False
|
||||
|
||||
|
||||
def _execute_skill_call(task: dict, agent_bridge) -> bool:
    """Execute a ``skill_call`` action by asking the agent to run the skill.

    The agent receives the prompt ``"Use <skill> skill"`` (plus
    ``" with k=v, ..."`` when parameters are present) in a dedicated session
    (``scheduler_<receiver>_<task id>``). Its reply, optionally preceded by
    ``result_prefix`` and a blank line, is sent through the channel named by
    ``action["channel_type"]``.

    Args:
        task: Task dictionary; reads ``task["id"]`` and ``task["action"]``.
        agent_bridge: Object providing ``agent_reply(...)``.

    Returns:
        True when the result was sent, the task is malformed (no skill name /
        receiver), or the agent returned no content. False when the agent
        call, channel creation or the send failed, or any other error
        occurred.
    """
    import traceback

    try:
        action = task.get("action", {})
        skill_name = action.get("call_name") or action.get("skill_name")
        # A stored ``None`` must not reach ``.items()`` below.
        skill_params = action.get("call_params") or action.get("skill_params") or {}
        result_prefix = action.get("result_prefix", "")
        receiver = action.get("receiver")
        # Tasks are stored with "is_group" (the key every other executor in
        # this module reads); the legacy "isgroup" spelling is still honoured.
        is_group = action.get("is_group", action.get("isgroup", False))
        channel_type = action.get("channel_type", "unknown")

        if not skill_name:
            logger.error(f"[Scheduler] Task {task['id']}: No skill_name specified")
            return True
        if not receiver:
            logger.error(f"[Scheduler] Task {task['id']}: No receiver specified")
            return True

        logger.info(f"[Scheduler] Task {task['id']}: Executing skill '{skill_name}' with params {skill_params}")

        # Isolated session so the run does not share the user's chat history.
        scheduler_session_id = f"scheduler_{receiver}_{task['id']}"
        param_str = ", ".join(f"{k}={v}" for k, v in skill_params.items())
        query = f"Use {skill_name} skill"
        if param_str:
            query += f" with {param_str}"

        context = Context(ContextType.TEXT, query)
        context["receiver"] = receiver
        context["isgroup"] = is_group
        context["session_id"] = scheduler_session_id

        # Channel-specific context fields.
        if channel_type == "web":
            import uuid

            request_id = f"scheduler_{task['id']}_{uuid.uuid4().hex[:8]}"
            context["request_id"] = request_id
        elif channel_type == "feishu":
            # chat_id addresses a group, open_id a private chat.
            context["receive_id_type"] = "chat_id" if is_group else "open_id"
            context["msg"] = None
        elif channel_type == "wecom_bot":
            context["msg"] = None

        try:
            reply = agent_bridge.agent_reply(query, context=context, on_event=None, clear_history=False)
        except Exception as e:
            logger.error(f"[Scheduler] Failed to execute skill via Agent: {e}")
            logger.error(f"[Scheduler] Traceback: {traceback.format_exc()}")
            return False

        if not (reply and reply.content):
            logger.error(f"[Scheduler] Task {task['id']}: No result from skill execution")
            return True  # agent ran but produced nothing; don't loop

        content = reply.content
        if result_prefix:
            content = f"{result_prefix}\n\n{content}"

        from channel.channel_factory import create_channel

        channel = create_channel(channel_type)
        if not channel:
            logger.error(f"[Scheduler] Failed to create channel: {channel_type}")
            return False

        if channel_type == "web" and hasattr(channel, 'request_to_session'):
            req_id = context.get("request_id")
            if req_id:
                channel.request_to_session[req_id] = receiver

        try:
            channel.send(Reply(ReplyType.TEXT, content), context)
        except Exception as e:
            logger.error(f"[Scheduler] Failed to send skill result: {e}")
            return False

        _remember_delivered_output(agent_bridge, task, channel_type, content)
        logger.info(f"[Scheduler] Task {task['id']} executed: skill result sent to {receiver}")
        return True

    except Exception as e:
        logger.error(f"[Scheduler] Error in _execute_skill_call: {e}")
        logger.error(f"[Scheduler] Traceback: {traceback.format_exc()}")
        return False
|
||||
|
||||
|
||||
def attach_scheduler_to_tool(tool, context: Context = None):
    """Wire the module-level scheduler state into a SchedulerTool instance.

    Args:
        tool: SchedulerTool instance; its ``task_store``, ``current_context``
            and ``config["channel_type"]`` may be assigned.
        context: Current context (optional). When falsy, only the task store
            is attached.
    """
    if _task_store:
        tool.task_store = _task_store

    if not context:
        return

    tool.current_context = context

    # The context's own channel type wins; the global config is the fallback.
    resolved_channel = context.get("channel_type") or conf().get("channel_type", "unknown")
    if not tool.config:
        tool.config = {}
    tool.config["channel_type"] = resolved_channel
|
||||
243
agent/tools/scheduler/scheduler_service.py
Normal file
243
agent/tools/scheduler/scheduler_service.py
Normal file
@@ -0,0 +1,243 @@
|
||||
"""
|
||||
Background scheduler service for executing scheduled tasks
|
||||
"""
|
||||
|
||||
import time
|
||||
import threading
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Callable, Optional
|
||||
from croniter import croniter
|
||||
from common.log import logger
|
||||
|
||||
|
||||
def _parse_naive_local(iso_str: str) -> datetime:
|
||||
"""Parse an ISO datetime and coerce it to tz-naive local time.
|
||||
|
||||
The scheduler uses ``datetime.now()`` (tz-naive) for all comparisons,
|
||||
so any persisted timestamp must be normalized to the same flavor —
|
||||
otherwise comparing naive vs aware raises TypeError.
|
||||
"""
|
||||
dt = datetime.fromisoformat(iso_str)
|
||||
if dt.tzinfo is not None:
|
||||
dt = dt.astimezone().replace(tzinfo=None)
|
||||
return dt
|
||||
|
||||
|
||||
class SchedulerService:
    """Background service that polls the task store and executes due tasks.

    A single daemon thread wakes up every ``POLL_INTERVAL_SECONDS``, asks the
    task store for enabled tasks and hands each due task to
    ``execute_callback``. A callback result of ``False`` (or an exception)
    leaves ``next_run_at`` untouched so the task is retried on the next poll.
    """

    # Seconds the worker waits between two polls of the task store.
    POLL_INTERVAL_SECONDS = 30
    # A tick that is at most this many seconds late is still executed; an
    # older one is skipped and the task is rescheduled (or removed if "once").
    CATCH_UP_WINDOW_SECONDS = 600

    def __init__(self, task_store, execute_callback: Callable):
        """
        Initialize scheduler service

        Args:
            task_store: TaskStore instance
            execute_callback: Function to call when executing a task
        """
        self.task_store = task_store
        self.execute_callback = execute_callback
        self.running = False
        self.thread = None
        self._lock = threading.Lock()
        # Replaced by a fresh event on every start(); set by stop().
        self._stop_event = threading.Event()

    def start(self):
        """Start the scheduler service (no-op with a warning if running)."""
        with self._lock:
            if self.running:
                logger.warning("[Scheduler] Service already running")
                return

            # Each run gets its own stop event. A worker from a previous run
            # that is still busy keeps its own (already set) event and exits
            # instead of resuming alongside the new worker.
            stop_event = threading.Event()
            self._stop_event = stop_event
            self.running = True
            self.thread = threading.Thread(
                target=self._run_loop, args=(stop_event,), daemon=True
            )
            self.thread.start()

    def stop(self):
        """Stop the scheduler service and wait up to 5 seconds for the worker."""
        with self._lock:
            if not self.running:
                return

            self.running = False
            # Wakes the worker immediately if it is waiting between polls.
            self._stop_event.set()
            worker = self.thread
            # Joining the current thread raises RuntimeError, so skip the
            # join when stop() is called from inside the worker itself.
            if worker and worker is not threading.current_thread():
                worker.join(timeout=5)
            logger.info("[Scheduler] Service stopped")

    def _run_loop(self, stop_event: Optional[threading.Event] = None):
        """Main scheduler loop.

        Args:
            stop_event: Event that ends this particular run; defaults to the
                service's current stop event.
        """
        if stop_event is None:
            stop_event = self._stop_event

        logger.info("[Scheduler] Scheduler loop started")

        while self.running and not stop_event.is_set():
            try:
                self._check_and_execute_tasks()
            except Exception as e:
                logger.error(f"[Scheduler] Error in scheduler loop: {e}")

            # Interruptible sleep: returns early as soon as stop() is called.
            stop_event.wait(self.POLL_INTERVAL_SECONDS)

    def _check_and_execute_tasks(self):
        """Check for due tasks and execute them"""
        now = datetime.now()
        tasks = self.task_store.list_tasks(enabled_only=True)

        for task in tasks:
            try:
                if self._is_task_due(task, now):
                    logger.info(f"[Scheduler] Executing task: {task['id']} - {task['name']}")
                    ok = self._execute_task(task)
                    if not ok:
                        # Leave next_run_at as-is so the next loop retries.
                        # Cron tasks within the catch-up window will keep
                        # firing; beyond it _is_task_due will reschedule.
                        logger.warning(
                            f"[Scheduler] Task {task['id']} delivery failed, will retry next tick"
                        )
                        continue

                    next_run = self._calculate_next_run(task, now)
                    if next_run:
                        self.task_store.update_task(task['id'], {
                            "next_run_at": next_run.isoformat(),
                            "last_run_at": now.isoformat()
                        })
                    else:
                        # No further run can be computed: the task is removed.
                        self.task_store.delete_task(task['id'])
                        logger.info(f"[Scheduler] One-time task completed and removed: {task['id']}")
            except Exception as e:
                logger.error(f"[Scheduler] Error processing task {task.get('id')}: {e}")

    def _is_task_due(self, task: dict, now: datetime) -> bool:
        """
        Check if a task is due to run

        Besides answering the question this method maintains the task: a
        missing ``next_run_at`` is initialised, and a tick older than the
        catch-up window is skipped (a "once" task is deleted, any other task
        is rescheduled).

        Args:
            task: Task dictionary
            now: Current datetime

        Returns:
            True if task should run now
        """
        next_run_str = task.get("next_run_at")
        if not next_run_str:
            # Calculate initial next_run_at; the task runs on a later poll.
            next_run = self._calculate_next_run(task, now)
            if next_run:
                self.task_store.update_task(task['id'], {
                    "next_run_at": next_run.isoformat()
                })
            return False

        try:
            next_run = _parse_naive_local(next_run_str)

            if next_run < now:
                time_diff = (now - next_run).total_seconds()
                schedule_type = task.get("schedule", {}).get("type")

                # Catch-up window: fire if we're close enough to the
                # scheduled tick. Beyond that we'd rather skip than push a
                # stale daily report to the user.
                if time_diff <= self.CATCH_UP_WINDOW_SECONDS:
                    return True

                logger.warning(
                    f"[Scheduler] Task {task['id']} is overdue by {int(time_diff)}s, "
                    f"skipping and scheduling next run"
                )

                if schedule_type == "once":
                    self.task_store.delete_task(task['id'])
                    logger.info(f"[Scheduler] One-time task {task['id']} expired, removed")
                    return False

                next_next_run = self._calculate_next_run(task, now)
                if next_next_run:
                    self.task_store.update_task(task['id'], {
                        "next_run_at": next_next_run.isoformat()
                    })
                    logger.info(f"[Scheduler] Rescheduled task {task['id']} to {next_next_run}")
                return False

            return now >= next_run
        except Exception as e:
            logger.error(
                f"[Scheduler] Failed to evaluate due-state for task "
                f"{task.get('id')} (next_run_at={next_run_str!r}): {e}"
            )
            return False

    def _calculate_next_run(self, task: dict, from_time: datetime) -> Optional[datetime]:
        """
        Calculate next run time for a task

        Args:
            task: Task dictionary
            from_time: Calculate from this time

        Returns:
            Next run datetime, or None when there is none: a "once" task
            whose time is not after ``from_time``, a missing/invalid cron
            expression, a non-positive interval, or an unknown schedule type.
        """
        schedule = task.get("schedule", {})
        schedule_type = schedule.get("type")

        if schedule_type == "cron":
            expression = schedule.get("expression")
            if not expression:
                return None

            try:
                cron = croniter(expression, from_time)
                return cron.get_next(datetime)
            except Exception as e:
                logger.error(f"[Scheduler] Invalid cron expression '{expression}': {e}")
                return None

        elif schedule_type == "interval":
            # Interval in seconds
            seconds = schedule.get("seconds", 0)
            if seconds <= 0:
                return None
            return from_time + timedelta(seconds=seconds)

        elif schedule_type == "once":
            # One-time task at specific time
            run_at_str = schedule.get("run_at")
            if not run_at_str:
                return None

            try:
                run_at = _parse_naive_local(run_at_str)
                if run_at > from_time:
                    return run_at
            except Exception as e:
                logger.error(
                    f"[Scheduler] Failed to parse once-task run_at "
                    f"{run_at_str!r}: {e}"
                )
            return None

        return None

    def _execute_task(self, task: dict) -> bool:
        """
        Execute a task.

        Returns True if delivery succeeded (caller should advance state),
        False if it failed (caller should keep next_run_at so the next
        loop iteration retries). Callback may return None for legacy
        behaviour, treated as success.
        """
        try:
            result = self.execute_callback(task)
            return result is not False
        except Exception as e:
            logger.error(f"[Scheduler] Error executing task {task['id']}: {e}")
            # Persist the failure on the task record for later inspection.
            self.task_store.update_task(task['id'], {
                "last_error": str(e),
                "last_error_at": datetime.now().isoformat()
            })
            return False
|
||||
453
agent/tools/scheduler/scheduler_tool.py
Normal file
453
agent/tools/scheduler/scheduler_tool.py
Normal file
@@ -0,0 +1,453 @@
|
||||
"""
|
||||
Scheduler tool for creating and managing scheduled tasks
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, Optional
|
||||
from croniter import croniter
|
||||
|
||||
from agent.tools.base_tool import BaseTool, ToolResult
|
||||
from bridge.context import Context, ContextType
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from common.log import logger
|
||||
|
||||
|
||||
class SchedulerTool(BaseTool):
|
||||
"""
|
||||
Tool for managing scheduled tasks (reminders, notifications, etc.)
|
||||
"""
|
||||
|
||||
name: str = "scheduler"
|
||||
description: str = (
|
||||
"创建、查询和管理定时任务(提醒、周期性任务等)。\n\n"
|
||||
"⚠️ 重要:仅当需要「定时/提醒/每天/每周/X分钟后/X点」等延迟或周期执行时才使用此工具。"
|
||||
"使用方法:\n"
|
||||
"- 创建:action='create', name='任务名', message/ai_task='内容', schedule_type='once/interval/cron', schedule_value='...'\n"
|
||||
"- 查询:action='list' / action='get', task_id='任务ID'\n"
|
||||
"- 管理:action='delete/enable/disable', task_id='任务ID'\n\n"
|
||||
"调度类型:\n"
|
||||
"- once: 一次性任务,支持相对时间(+5s,+10m,+1h,+1d)或ISO时间\n"
|
||||
"- interval: 固定间隔(秒),如3600表示每小时\n"
|
||||
"- cron: cron表达式,如'0 8 * * *'表示每天8点\n\n"
|
||||
"注意:'X秒后'用once+相对时间,'每X秒'用interval"
|
||||
)
|
||||
params: dict = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"action": {
|
||||
"type": "string",
|
||||
"enum": ["create", "list", "get", "delete", "enable", "disable"],
|
||||
"description": "操作类型: create(创建), list(列表), get(查询), delete(删除), enable(启用), disable(禁用)"
|
||||
},
|
||||
"task_id": {
|
||||
"type": "string",
|
||||
"description": "任务ID (用于 get/delete/enable/disable 操作)"
|
||||
},
|
||||
"name": {
|
||||
"type": "string",
|
||||
"description": "任务名称 (用于 create 操作)"
|
||||
},
|
||||
"message": {
|
||||
"type": "string",
|
||||
"description": "固定消息内容 (与ai_task二选一)"
|
||||
},
|
||||
"ai_task": {
|
||||
"type": "string",
|
||||
"description": "AI任务描述 (与message二选一),用于定时让AI执行的任务"
|
||||
},
|
||||
"schedule_type": {
|
||||
"type": "string",
|
||||
"enum": ["cron", "interval", "once"],
|
||||
"description": "调度类型 (用于 create 操作): cron(cron表达式), interval(固定间隔秒数), once(一次性)"
|
||||
},
|
||||
"schedule_value": {
|
||||
"type": "string",
|
||||
"description": "调度值: cron表达式/间隔秒数/时间(+5s,+10m,+1h或ISO格式)"
|
||||
}
|
||||
},
|
||||
"required": ["action"]
|
||||
}
|
||||
|
||||
def __init__(self, config: dict = None):
    """Create the tool with an optional config dict.

    Args:
        config: Tool configuration; a falsy value is replaced by a new
            empty dict.
    """
    super().__init__()
    self.config = config if config else {}

    # Both are filled in later by the agent bridge.
    self.task_store = None
    self.current_context = None
|
||||
|
||||
def execute(self, params: dict) -> ToolResult:
    """Dispatch a scheduler operation to the matching private handler.

    Args:
        params: Dictionary containing ``action``
            (create/list/get/delete/enable/disable) plus whatever that
            action needs; the whole dict is forwarded to the handler as
            keyword arguments.

    Returns:
        ``ToolResult.success`` wrapping the handler's text, or
        ``ToolResult.fail`` when the task store is not attached, the action
        is unknown, or the handler raised.
    """
    action = params.get("action")

    if not self.task_store:
        return ToolResult.fail("错误: 定时任务系统未初始化")

    # Ordered (action, handler method name) pairs; the handler is resolved
    # only for the action that matches.
    routes = (
        ("create", "_create_task"),
        ("list", "_list_tasks"),
        ("get", "_get_task"),
        ("delete", "_delete_task"),
        ("enable", "_enable_task"),
        ("disable", "_disable_task"),
    )

    try:
        for verb, method_name in routes:
            if action == verb:
                handler = getattr(self, method_name)
                return ToolResult.success(handler(**params))
        return ToolResult.fail(f"未知操作: {action}")
    except Exception as e:
        logger.error(f"[SchedulerTool] Error: {e}")
        return ToolResult.fail(f"操作失败: {str(e)}")
|
||||
|
||||
def _create_task(self, **kwargs) -> str:
|
||||
"""Create a new scheduled task"""
|
||||
name = kwargs.get("name")
|
||||
message = kwargs.get("message")
|
||||
ai_task = kwargs.get("ai_task")
|
||||
schedule_type = kwargs.get("schedule_type")
|
||||
schedule_value = kwargs.get("schedule_value")
|
||||
|
||||
# Validate required fields
|
||||
if not name:
|
||||
return "错误: 缺少任务名称 (name)"
|
||||
|
||||
# Check that exactly one of message/ai_task is provided
|
||||
if not message and not ai_task:
|
||||
return "错误: 必须提供 message(固定消息)或 ai_task(AI任务)之一"
|
||||
if message and ai_task:
|
||||
return "错误: message 和 ai_task 只能提供其中一个"
|
||||
|
||||
if not schedule_type:
|
||||
return "错误: 缺少调度类型 (schedule_type)"
|
||||
if not schedule_value:
|
||||
return "错误: 缺少调度值 (schedule_value)"
|
||||
|
||||
# Validate schedule
|
||||
schedule = self._parse_schedule(schedule_type, schedule_value)
|
||||
if not schedule:
|
||||
return f"错误: 无效的调度配置 - type: {schedule_type}, value: {schedule_value}"
|
||||
|
||||
# Get context info for receiver
|
||||
if not self.current_context:
|
||||
return "错误: 无法获取当前对话上下文"
|
||||
|
||||
context = self.current_context
|
||||
|
||||
# Create task
|
||||
task_id = str(uuid.uuid4())[:8]
|
||||
|
||||
# Capture the real chat session_id at task creation time so that scheduler
|
||||
# can later inject the delivered output into the user's actual conversation
|
||||
# (in group chats, session_id != receiver, e.g. "user_id:group_id" on feishu).
|
||||
notify_session_id = context.get("session_id")
|
||||
|
||||
# Build action based on message or ai_task
|
||||
if message:
|
||||
action = {
|
||||
"type": "send_message",
|
||||
"content": message,
|
||||
"receiver": context.get("receiver"),
|
||||
"receiver_name": self._get_receiver_name(context),
|
||||
"is_group": context.get("isgroup", False),
|
||||
"channel_type": self.config.get("channel_type", "unknown"),
|
||||
"notify_session_id": notify_session_id,
|
||||
}
|
||||
else: # ai_task
|
||||
action = {
|
||||
"type": "agent_task",
|
||||
"task_description": ai_task,
|
||||
"receiver": context.get("receiver"),
|
||||
"receiver_name": self._get_receiver_name(context),
|
||||
"is_group": context.get("isgroup", False),
|
||||
"channel_type": self.config.get("channel_type", "unknown"),
|
||||
"notify_session_id": notify_session_id,
|
||||
}
|
||||
|
||||
# 针对钉钉单聊,额外存储 sender_staff_id
|
||||
msg = context.kwargs.get("msg")
|
||||
if msg and hasattr(msg, 'sender_staff_id') and not context.get("isgroup", False):
|
||||
action["dingtalk_sender_staff_id"] = msg.sender_staff_id
|
||||
|
||||
task_data = {
|
||||
"id": task_id,
|
||||
"name": name,
|
||||
"enabled": True,
|
||||
"created_at": datetime.now().isoformat(),
|
||||
"updated_at": datetime.now().isoformat(),
|
||||
"schedule": schedule,
|
||||
"action": action
|
||||
}
|
||||
|
||||
# Calculate initial next_run_at
|
||||
next_run = self._calculate_next_run(task_data)
|
||||
if next_run:
|
||||
task_data["next_run_at"] = next_run.isoformat()
|
||||
|
||||
# Save task
|
||||
self.task_store.add_task(task_data)
|
||||
|
||||
# Format response
|
||||
schedule_desc = self._format_schedule_description(schedule)
|
||||
receiver_desc = task_data["action"]["receiver_name"] or task_data["action"]["receiver"]
|
||||
|
||||
if message:
|
||||
content_desc = f"💬 固定消息: {message}"
|
||||
else:
|
||||
content_desc = f"🤖 AI任务: {ai_task}"
|
||||
|
||||
return (
|
||||
f"✅ 定时任务创建成功\n\n"
|
||||
f"📋 任务ID: {task_id}\n"
|
||||
f"📝 名称: {name}\n"
|
||||
f"⏰ 调度: {schedule_desc}\n"
|
||||
f"👤 接收者: {receiver_desc}\n"
|
||||
f"{content_desc}\n"
|
||||
f"🕐 下次执行: {next_run.strftime('%Y-%m-%d %H:%M:%S') if next_run else '未知'}"
|
||||
)
|
||||
|
||||
def _list_tasks(self, **kwargs) -> str:
|
||||
"""List all tasks"""
|
||||
tasks = self.task_store.list_tasks()
|
||||
|
||||
if not tasks:
|
||||
return "📋 暂无定时任务"
|
||||
|
||||
lines = [f"📋 定时任务列表 (共 {len(tasks)} 个)\n"]
|
||||
|
||||
for task in tasks:
|
||||
status = "✅" if task.get("enabled", True) else "❌"
|
||||
schedule_desc = self._format_schedule_description(task.get("schedule", {}))
|
||||
next_run = task.get("next_run_at")
|
||||
next_run_str = datetime.fromisoformat(next_run).strftime('%m-%d %H:%M') if next_run else "未知"
|
||||
|
||||
lines.append(
|
||||
f"{status} [{task['id']}] {task['name']}\n"
|
||||
f" ⏰ {schedule_desc} | 下次: {next_run_str}"
|
||||
)
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
def _get_task(self, **kwargs) -> str:
|
||||
"""Get task details"""
|
||||
task_id = kwargs.get("task_id")
|
||||
if not task_id:
|
||||
return "错误: 缺少任务ID (task_id)"
|
||||
|
||||
task = self.task_store.get_task(task_id)
|
||||
if not task:
|
||||
return f"错误: 任务 '{task_id}' 不存在"
|
||||
|
||||
status = "启用" if task.get("enabled", True) else "禁用"
|
||||
schedule_desc = self._format_schedule_description(task.get("schedule", {}))
|
||||
action = task.get("action", {})
|
||||
next_run = task.get("next_run_at")
|
||||
next_run_str = datetime.fromisoformat(next_run).strftime('%Y-%m-%d %H:%M:%S') if next_run else "未知"
|
||||
last_run = task.get("last_run_at")
|
||||
last_run_str = datetime.fromisoformat(last_run).strftime('%Y-%m-%d %H:%M:%S') if last_run else "从未执行"
|
||||
|
||||
return (
|
||||
f"📋 任务详情\n\n"
|
||||
f"ID: {task['id']}\n"
|
||||
f"名称: {task['name']}\n"
|
||||
f"状态: {status}\n"
|
||||
f"调度: {schedule_desc}\n"
|
||||
f"接收者: {action.get('receiver_name', action.get('receiver'))}\n"
|
||||
f"消息: {action.get('content')}\n"
|
||||
f"下次执行: {next_run_str}\n"
|
||||
f"上次执行: {last_run_str}\n"
|
||||
f"创建时间: {datetime.fromisoformat(task['created_at']).strftime('%Y-%m-%d %H:%M:%S')}"
|
||||
)
|
||||
|
||||
def _delete_task(self, **kwargs) -> str:
|
||||
"""Delete a task"""
|
||||
task_id = kwargs.get("task_id")
|
||||
if not task_id:
|
||||
return "错误: 缺少任务ID (task_id)"
|
||||
|
||||
task = self.task_store.get_task(task_id)
|
||||
if not task:
|
||||
return f"错误: 任务 '{task_id}' 不存在"
|
||||
|
||||
self.task_store.delete_task(task_id)
|
||||
return f"✅ 任务 '{task['name']}' ({task_id}) 已删除"
|
||||
|
||||
def _enable_task(self, **kwargs) -> str:
|
||||
"""Enable a task"""
|
||||
task_id = kwargs.get("task_id")
|
||||
if not task_id:
|
||||
return "错误: 缺少任务ID (task_id)"
|
||||
|
||||
task = self.task_store.get_task(task_id)
|
||||
if not task:
|
||||
return f"错误: 任务 '{task_id}' 不存在"
|
||||
|
||||
self.task_store.enable_task(task_id, True)
|
||||
return f"✅ 任务 '{task['name']}' ({task_id}) 已启用"
|
||||
|
||||
def _disable_task(self, **kwargs) -> str:
|
||||
"""Disable a task"""
|
||||
task_id = kwargs.get("task_id")
|
||||
if not task_id:
|
||||
return "错误: 缺少任务ID (task_id)"
|
||||
|
||||
task = self.task_store.get_task(task_id)
|
||||
if not task:
|
||||
return f"错误: 任务 '{task_id}' 不存在"
|
||||
|
||||
self.task_store.enable_task(task_id, False)
|
||||
return f"✅ 任务 '{task['name']}' ({task_id}) 已禁用"
|
||||
|
||||
def _parse_schedule(self, schedule_type: str, schedule_value: str) -> Optional[dict]:
|
||||
"""Parse and validate schedule configuration"""
|
||||
try:
|
||||
if schedule_type == "cron":
|
||||
# Validate cron expression
|
||||
croniter(schedule_value)
|
||||
return {"type": "cron", "expression": schedule_value}
|
||||
|
||||
elif schedule_type == "interval":
|
||||
# Parse interval in seconds
|
||||
seconds = int(schedule_value)
|
||||
if seconds <= 0:
|
||||
return None
|
||||
return {"type": "interval", "seconds": seconds}
|
||||
|
||||
elif schedule_type == "once":
|
||||
# Parse datetime - support both relative and absolute time
|
||||
|
||||
# Check if it's relative time (e.g., "+5s", "+10m", "+1h", "+1d")
|
||||
if schedule_value.startswith("+"):
|
||||
import re
|
||||
match = re.match(r'\+(\d+)([smhd])', schedule_value)
|
||||
if match:
|
||||
amount = int(match.group(1))
|
||||
unit = match.group(2)
|
||||
|
||||
from datetime import timedelta
|
||||
now = datetime.now()
|
||||
|
||||
if unit == 's': # seconds
|
||||
target_time = now + timedelta(seconds=amount)
|
||||
elif unit == 'm': # minutes
|
||||
target_time = now + timedelta(minutes=amount)
|
||||
elif unit == 'h': # hours
|
||||
target_time = now + timedelta(hours=amount)
|
||||
elif unit == 'd': # days
|
||||
target_time = now + timedelta(days=amount)
|
||||
else:
|
||||
return None
|
||||
|
||||
return {"type": "once", "run_at": target_time.isoformat()}
|
||||
else:
|
||||
logger.error(f"[SchedulerTool] Invalid relative time format: {schedule_value}")
|
||||
return None
|
||||
else:
|
||||
# Absolute ISO time. Normalize to tz-naive local so it
|
||||
# stays comparable with the scheduler's datetime.now().
|
||||
parsed = datetime.fromisoformat(schedule_value)
|
||||
if parsed.tzinfo is not None:
|
||||
parsed = parsed.astimezone().replace(tzinfo=None)
|
||||
return {"type": "once", "run_at": parsed.isoformat()}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[SchedulerTool] Invalid schedule: {e}")
|
||||
return None
|
||||
|
||||
return None
|
||||
|
||||
def _calculate_next_run(self, task: dict) -> Optional[datetime]:
|
||||
"""Calculate next run time for a task"""
|
||||
schedule = task.get("schedule", {})
|
||||
schedule_type = schedule.get("type")
|
||||
now = datetime.now()
|
||||
|
||||
if schedule_type == "cron":
|
||||
expression = schedule.get("expression")
|
||||
cron = croniter(expression, now)
|
||||
return cron.get_next(datetime)
|
||||
|
||||
elif schedule_type == "interval":
|
||||
seconds = schedule.get("seconds", 0)
|
||||
from datetime import timedelta
|
||||
return now + timedelta(seconds=seconds)
|
||||
|
||||
elif schedule_type == "once":
|
||||
run_at_str = schedule.get("run_at")
|
||||
return datetime.fromisoformat(run_at_str)
|
||||
|
||||
return None
|
||||
|
||||
def _format_schedule_description(self, schedule: dict) -> str:
|
||||
"""Format schedule as human-readable description"""
|
||||
schedule_type = schedule.get("type")
|
||||
|
||||
if schedule_type == "cron":
|
||||
expr = schedule.get("expression", "")
|
||||
# Try to provide friendly description
|
||||
if expr == "0 9 * * *":
|
||||
return "每天 9:00"
|
||||
elif expr == "0 */1 * * *":
|
||||
return "每小时"
|
||||
elif expr == "*/30 * * * *":
|
||||
return "每30分钟"
|
||||
else:
|
||||
return f"Cron: {expr}"
|
||||
|
||||
elif schedule_type == "interval":
|
||||
seconds = schedule.get("seconds", 0)
|
||||
if seconds >= 86400:
|
||||
days = seconds // 86400
|
||||
return f"每 {days} 天"
|
||||
elif seconds >= 3600:
|
||||
hours = seconds // 3600
|
||||
return f"每 {hours} 小时"
|
||||
elif seconds >= 60:
|
||||
minutes = seconds // 60
|
||||
return f"每 {minutes} 分钟"
|
||||
else:
|
||||
return f"每 {seconds} 秒"
|
||||
|
||||
elif schedule_type == "once":
|
||||
run_at = schedule.get("run_at", "")
|
||||
try:
|
||||
dt = datetime.fromisoformat(run_at)
|
||||
return f"一次性 ({dt.strftime('%Y-%m-%d %H:%M')})"
|
||||
except Exception:
|
||||
return "一次性"
|
||||
|
||||
return "未知"
|
||||
|
||||
def _get_receiver_name(self, context: Context) -> str:
|
||||
"""Get receiver name from context"""
|
||||
try:
|
||||
msg = context.get("msg")
|
||||
if msg:
|
||||
if context.get("isgroup"):
|
||||
return msg.other_user_nickname or "群聊"
|
||||
else:
|
||||
return msg.from_user_nickname or "用户"
|
||||
except Exception:
|
||||
pass
|
||||
return "未知"
|
||||
201
agent/tools/scheduler/task_store.py
Normal file
201
agent/tools/scheduler/task_store.py
Normal file
@@ -0,0 +1,201 @@
|
||||
"""
|
||||
Task storage management for scheduler
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import threading
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional
|
||||
from pathlib import Path
|
||||
from common.utils import expand_path
|
||||
|
||||
|
||||
class TaskStore:
    """
    Manages persistent storage of scheduled tasks.

    All tasks live in one JSON file of the form
    ``{"version": 1, "updated_at": ..., "tasks": {task_id: task_data}}``.
    Every public method takes ``self.lock``, and the mutating helpers hold
    it across their whole load -> modify -> save sequence so concurrent
    callers cannot overwrite each other's changes.
    """

    def __init__(self, store_path: str = None):
        """
        Initialize task store

        Args:
            store_path: Path to tasks.json file. Defaults to ~/cow/scheduler/tasks.json
        """
        if store_path is None:
            # Default to ~/cow/scheduler/tasks.json
            home = expand_path("~")
            store_path = os.path.join(home, "cow", "scheduler", "tasks.json")

        self.store_path = store_path
        # Re-entrant: add/update/delete hold the lock while calling
        # load_tasks()/save_tasks(), which acquire it again themselves.
        self.lock = threading.RLock()
        self._ensure_store_dir()

    def _ensure_store_dir(self):
        """Ensure the storage directory exists"""
        store_dir = os.path.dirname(self.store_path)
        # A bare filename has no directory part; makedirs("") would raise.
        if store_dir:
            os.makedirs(store_dir, exist_ok=True)

    def load_tasks(self) -> Dict[str, dict]:
        """
        Load all tasks from storage

        Returns:
            Dictionary of task_id -> task_data. Empty when the file does
            not exist or cannot be read/parsed (the error is printed).
        """
        with self.lock:
            if not os.path.exists(self.store_path):
                return {}

            try:
                with open(self.store_path, 'r', encoding='utf-8') as f:
                    data = json.load(f)
                return data.get("tasks", {})
            except Exception as e:
                print(f"Error loading tasks: {e}")
                return {}

    def save_tasks(self, tasks: Dict[str, dict]):
        """
        Save all tasks to storage

        The previous file is copied to ``<store_path>.bak`` (best effort),
        then the new content is written to a temporary file and moved into
        place, so a failed write never leaves a truncated store behind.

        Args:
            tasks: Dictionary of task_id -> task_data

        Raises:
            Exception: whatever serialization or file I/O raised, after
                printing it.
        """
        with self.lock:
            tmp_path = f"{self.store_path}.tmp"
            try:
                # Create backup. Copy bytes so the result does not depend
                # on the platform's default text encoding.
                if os.path.exists(self.store_path):
                    backup_path = f"{self.store_path}.bak"
                    try:
                        with open(self.store_path, 'rb') as src, open(backup_path, 'wb') as dst:
                            dst.write(src.read())
                    except OSError:
                        # Backup is best effort; saving must still proceed.
                        pass

                data = {
                    "version": 1,
                    "updated_at": datetime.now().isoformat(),
                    "tasks": tasks
                }

                with open(tmp_path, 'w', encoding='utf-8') as f:
                    json.dump(data, f, ensure_ascii=False, indent=2)
                os.replace(tmp_path, self.store_path)
            except Exception as e:
                print(f"Error saving tasks: {e}")
                # Do not leave a half-written temp file around.
                try:
                    os.remove(tmp_path)
                except OSError:
                    pass
                raise

    def add_task(self, task: dict) -> bool:
        """
        Add a new task

        Args:
            task: Task data dictionary

        Returns:
            True if successful

        Raises:
            ValueError: if the task has no 'id' or the id already exists.
        """
        task_id = task.get("id")
        if not task_id:
            raise ValueError("Task must have an 'id' field")

        with self.lock:
            tasks = self.load_tasks()
            if task_id in tasks:
                raise ValueError(f"Task with id '{task_id}' already exists")

            tasks[task_id] = task
            self.save_tasks(tasks)
        return True

    def update_task(self, task_id: str, updates: dict) -> bool:
        """
        Update an existing task

        Args:
            task_id: Task ID
            updates: Dictionary of fields to update

        Returns:
            True if successful

        Raises:
            ValueError: if the task does not exist.
        """
        with self.lock:
            tasks = self.load_tasks()
            if task_id not in tasks:
                raise ValueError(f"Task '{task_id}' not found")

            # Update fields
            tasks[task_id].update(updates)
            tasks[task_id]["updated_at"] = datetime.now().isoformat()

            self.save_tasks(tasks)
        return True

    def delete_task(self, task_id: str) -> bool:
        """
        Delete a task

        Args:
            task_id: Task ID

        Returns:
            True if successful

        Raises:
            ValueError: if the task does not exist.
        """
        with self.lock:
            tasks = self.load_tasks()
            if task_id not in tasks:
                raise ValueError(f"Task '{task_id}' not found")

            del tasks[task_id]
            self.save_tasks(tasks)
        return True

    def get_task(self, task_id: str) -> Optional[dict]:
        """
        Get a specific task

        Args:
            task_id: Task ID

        Returns:
            Task data or None if not found
        """
        return self.load_tasks().get(task_id)

    def list_tasks(self, enabled_only: bool = False) -> List[dict]:
        """
        List all tasks

        Args:
            enabled_only: If True, only return enabled tasks

        Returns:
            List of task dictionaries sorted by next_run_at; tasks without
            a next_run_at come last.
        """
        task_list = list(self.load_tasks().values())

        if enabled_only:
            task_list = [t for t in task_list if t.get("enabled", True)]

        def _sort_key(task: dict):
            # Two-part key keeps comparisons type-safe: mixing ISO strings
            # with a float sentinel raises TypeError in Python 3.
            next_run = task.get("next_run_at")
            return (next_run is None, next_run or "")

        task_list.sort(key=_sort_key)
        return task_list

    def enable_task(self, task_id: str, enabled: bool = True) -> bool:
        """
        Enable or disable a task

        Args:
            task_id: Task ID
            enabled: True to enable, False to disable

        Returns:
            True if successful
        """
        return self.update_task(task_id, {"enabled": enabled})
|
||||
3
agent/tools/send/__init__.py
Normal file
3
agent/tools/send/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .send import Send
|
||||
|
||||
__all__ = ['Send']
|
||||
171
agent/tools/send/send.py
Normal file
171
agent/tools/send/send.py
Normal file
@@ -0,0 +1,171 @@
|
||||
"""
|
||||
Send tool - Send files to the user
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Dict, Any
|
||||
from pathlib import Path
|
||||
|
||||
from agent.tools.base_tool import BaseTool, ToolResult
|
||||
from common.utils import expand_path
|
||||
|
||||
|
||||
class Send(BaseTool):
    """
    Tool for sending files to the user.

    The tool does not transmit anything itself: execute() validates the
    path and returns a "file_to_send" metadata dict that the channel layer
    is expected to act on.
    """

    name: str = "send"
    description: str = "Send a LOCAL file (image, video, audio, document) to the user. Only for local file paths. Do NOT use this for URLs — URLs should be included directly in your text reply, the system will handle them automatically."

    params: dict = {
        "type": "object",
        "properties": {
            "path": {
                "type": "string",
                "description": "Local file path to send. Must be an absolute path or relative to workspace. Do NOT pass URLs here."
            },
            "message": {
                "type": "string",
                "description": "Optional message to accompany the file"
            }
        },
        "required": ["path"]
    }

    def __init__(self, config: dict = None):
        """
        :param config: optional settings; "cwd" sets the directory that
            relative paths are resolved against (default: process cwd)
        """
        self.config = config or {}
        self.cwd = self.config.get("cwd", os.getcwd())

        # Supported file types
        self.image_extensions = {'.jpg', '.jpeg', '.png', '.gif', '.webp', '.bmp', '.svg', '.ico'}
        self.video_extensions = {'.mp4', '.avi', '.mov', '.mkv', '.flv', '.wmv', '.webm', '.m4v'}
        self.audio_extensions = {'.mp3', '.wav', '.ogg', '.m4a', '.flac', '.aac', '.wma'}
        self.document_extensions = {'.pdf', '.doc', '.docx', '.xls', '.xlsx', '.ppt', '.pptx', '.txt', '.md'}

    def execute(self, args: Dict[str, Any]) -> ToolResult:
        """
        Execute file send operation

        :param args: Contains file path and optional message
        :return: File metadata for channel to send, or a failure result
            when the path is missing, not found, not a regular file, or
            not readable
        """
        # The arguments come from a model; tolerate a null / non-string path.
        raw_path = args.get("path")
        path = raw_path.strip() if isinstance(raw_path, str) else ""
        message = args.get("message", "")

        if not path:
            return ToolResult.fail("Error: path parameter is required")

        # Resolve path
        absolute_path = self._resolve_path(path)

        # Check if file exists
        if not os.path.exists(absolute_path):
            return ToolResult.fail(f"Error: File not found: {path}")

        # A directory also "exists" and is readable, but cannot be sent.
        if not os.path.isfile(absolute_path):
            return ToolResult.fail(f"Error: Not a file: {path}")

        # Check if readable
        if not os.access(absolute_path, os.R_OK):
            return ToolResult.fail(f"Error: File is not readable: {path}")

        # Get file info
        file_ext = Path(absolute_path).suffix.lower()
        file_size = os.path.getsize(absolute_path)
        file_name = Path(absolute_path).name

        # Determine file type
        if file_ext in self.image_extensions:
            file_type = "image"
            mime_type = self._get_image_mime_type(file_ext)
        elif file_ext in self.video_extensions:
            file_type = "video"
            mime_type = self._get_video_mime_type(file_ext)
        elif file_ext in self.audio_extensions:
            file_type = "audio"
            mime_type = self._get_audio_mime_type(file_ext)
        elif file_ext in self.document_extensions:
            file_type = "document"
            mime_type = self._get_document_mime_type(file_ext)
        else:
            file_type = "file"
            mime_type = "application/octet-stream"

        # Return file_to_send metadata
        result = {
            "type": "file_to_send",
            "file_type": file_type,
            "path": absolute_path,
            "file_name": file_name,
            "mime_type": mime_type,
            "size": file_size,
            "size_formatted": self._format_size(file_size),
            "message": message or f"正在发送 {file_name}"
        }

        try:
            from common.cloud_client import get_website_base_url, copy_send_file

            # Do nothing when in local env
            if get_website_base_url():
                url = copy_send_file(absolute_path, self.cwd)
                if url:
                    result["url"] = url
        except Exception:
            # Deliberately best effort: the cloud copy only adds an
            # optional "url"; the local metadata is still returned.
            pass

        return ToolResult.success(result)

    def _resolve_path(self, path: str) -> str:
        """Resolve path to absolute path (relative paths are joined onto self.cwd)"""
        path = expand_path(path)
        if os.path.isabs(path):
            return path
        return os.path.abspath(os.path.join(self.cwd, path))

    def _get_image_mime_type(self, ext: str) -> str:
        """Get MIME type for image (unknown extensions map to image/jpeg)"""
        mime_map = {
            '.jpg': 'image/jpeg', '.jpeg': 'image/jpeg',
            '.png': 'image/png', '.gif': 'image/gif',
            '.webp': 'image/webp', '.bmp': 'image/bmp',
            '.svg': 'image/svg+xml', '.ico': 'image/x-icon'
        }
        return mime_map.get(ext, 'image/jpeg')

    def _get_video_mime_type(self, ext: str) -> str:
        """Get MIME type for video (unknown extensions map to video/mp4)"""
        mime_map = {
            '.mp4': 'video/mp4', '.avi': 'video/x-msvideo',
            '.mov': 'video/quicktime', '.mkv': 'video/x-matroska',
            '.webm': 'video/webm', '.flv': 'video/x-flv'
        }
        return mime_map.get(ext, 'video/mp4')

    def _get_audio_mime_type(self, ext: str) -> str:
        """Get MIME type for audio (unknown extensions map to audio/mpeg)"""
        mime_map = {
            '.mp3': 'audio/mpeg', '.wav': 'audio/wav',
            '.ogg': 'audio/ogg', '.m4a': 'audio/mp4',
            '.flac': 'audio/flac', '.aac': 'audio/aac'
        }
        return mime_map.get(ext, 'audio/mpeg')

    def _get_document_mime_type(self, ext: str) -> str:
        """Get MIME type for document (unknown extensions map to application/octet-stream)"""
        mime_map = {
            '.pdf': 'application/pdf',
            '.doc': 'application/msword',
            '.docx': 'application/vnd.openxmlformats-officedocument.wordprocessingml.document',
            '.xls': 'application/vnd.ms-excel',
            '.xlsx': 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet',
            '.ppt': 'application/vnd.ms-powerpoint',
            '.pptx': 'application/vnd.openxmlformats-officedocument.presentationml.presentation',
            '.txt': 'text/plain',
            '.md': 'text/markdown'
        }
        return mime_map.get(ext, 'application/octet-stream')

    def _format_size(self, size_bytes: int) -> str:
        """Format file size in human-readable format (B, KB, MB, GB, then TB)"""
        for unit in ['B', 'KB', 'MB', 'GB']:
            if size_bytes < 1024.0:
                return f"{size_bytes:.1f}{unit}"
            size_bytes /= 1024.0
        return f"{size_bytes:.1f}TB"
|
||||
619
agent/tools/tool_manager.py
Normal file
619
agent/tools/tool_manager.py
Normal file
@@ -0,0 +1,619 @@
|
||||
import importlib
|
||||
import importlib.util
|
||||
import threading
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, Type
|
||||
from agent.tools.base_tool import BaseTool
|
||||
from common.log import logger
|
||||
from config import conf
|
||||
|
||||
|
||||
def _normalize_mcp_configs(raw) -> list:
|
||||
"""
|
||||
Convert MCP server config to internal list format.
|
||||
Supports:
|
||||
- list format (mcp_servers): [{"name": "x", "type": "stdio", ...}]
|
||||
- dict format (mcpServers): {"x": {"command": "npx", ...}}
|
||||
"""
|
||||
if isinstance(raw, list):
|
||||
return raw
|
||||
if isinstance(raw, dict):
|
||||
result = []
|
||||
for name, cfg in raw.items():
|
||||
entry = {"name": name, **cfg}
|
||||
if "type" not in entry:
|
||||
entry["type"] = "sse" if "url" in entry else "stdio"
|
||||
result.append(entry)
|
||||
return result
|
||||
return []
|
||||
|
||||
|
||||
class ToolManager:
|
||||
"""
|
||||
Tool manager for managing tools.
|
||||
"""
|
||||
_instance = None
|
||||
|
||||
def __new__(cls):
    """Return the shared ToolManager, creating it on first use (singleton)."""
    if cls._instance is not None:
        return cls._instance

    manager = super(ToolManager, cls).__new__(cls)
    cls._instance = manager
    # Tool classes are stored, not instances; instances are built on demand.
    manager.tool_classes = {}
    manager._initialized = False
    return manager
|
||||
|
||||
def __init__(self):
|
||||
# Initialize only once
|
||||
if not hasattr(self, 'tool_classes'):
|
||||
self.tool_classes = {} # Dictionary to store tool classes
|
||||
if not hasattr(self, '_mcp_registry'):
|
||||
self._mcp_registry = None # Lazy init: only created when MCP servers are configured
|
||||
if not hasattr(self, '_mcp_tool_instances'):
|
||||
self._mcp_tool_instances: dict = {} # tool_name -> McpTool instance
|
||||
if not hasattr(self, '_mcp_lock'):
|
||||
# Guards _mcp_loaded check-then-set so concurrent callers
|
||||
# don't trigger duplicate background loaders.
|
||||
self._mcp_lock = threading.Lock()
|
||||
if not hasattr(self, '_mcp_loaded'):
|
||||
# Idempotency flag. Flipped to True the moment the first loader
|
||||
# is dispatched (synchronously, inside _mcp_lock). Subsequent
|
||||
# _load_mcp_tools() calls become no-ops, so per-session agent
|
||||
# initialization never re-forks MCP subprocesses.
|
||||
self._mcp_loaded = False
|
||||
if not hasattr(self, '_mcp_status'):
|
||||
# server_name -> "pending" / "ready" / "failed"
|
||||
# Useful for UI / introspection while async loading is in progress.
|
||||
self._mcp_status: dict = {}
|
||||
if not hasattr(self, '_mcp_signature'):
|
||||
# (mtime, sha256) of mcp.json the last time we loaded.
|
||||
# Used by refresh_mcp_if_changed() to skip re-parsing when nothing changed.
|
||||
self._mcp_signature: tuple = (None, None)
|
||||
if not hasattr(self, '_mcp_active_configs'):
|
||||
# server_name -> normalized config dict, for diff-based reload.
|
||||
self._mcp_active_configs: dict = {}
|
||||
|
||||
def load_tools(self, tools_dir: str = "", config_dict=None):
    """
    Load tools from both directory and configuration.

    Tool classes come from ``tools_dir`` when it is given, otherwise from
    the agent.tools package exports. Tool configuration is then applied
    and MCP tool loading is triggered.

    :param tools_dir: Directory to scan for tool modules
    :param config_dict: Optional tools configuration; when omitted, the
        configuration step falls back to the global config
    """
    if tools_dir:
        self._load_tools_from_directory(tools_dir)
    else:
        self._load_tools_from_init()

    # Forward config_dict on both paths. It used to be dropped whenever
    # tools_dir was set; passing None keeps the old global-config fallback.
    self._configure_tools_from_config(config_dict)

    self._load_mcp_tools()
|
||||
|
||||
def _load_tools_from_init(self) -> bool:
|
||||
"""
|
||||
Load tool classes from tools.__init__.__all__
|
||||
|
||||
:return: True if tools were loaded, False otherwise
|
||||
"""
|
||||
try:
|
||||
# Try to import the tools package
|
||||
tools_package = importlib.import_module("agent.tools")
|
||||
|
||||
# Check if __all__ is defined
|
||||
if hasattr(tools_package, "__all__"):
|
||||
tool_classes = tools_package.__all__
|
||||
|
||||
# Import each tool class directly from the tools package
|
||||
for class_name in tool_classes:
|
||||
try:
|
||||
# Skip base classes
|
||||
if class_name in ["BaseTool", "ToolManager"]:
|
||||
continue
|
||||
|
||||
# Get the class directly from the tools package
|
||||
if hasattr(tools_package, class_name):
|
||||
cls = getattr(tools_package, class_name)
|
||||
|
||||
if (
|
||||
isinstance(cls, type)
|
||||
and issubclass(cls, BaseTool)
|
||||
and cls != BaseTool
|
||||
):
|
||||
try:
|
||||
# Skip tools that need special initialization
|
||||
if class_name in ["MemorySearchTool", "MemoryGetTool"]:
|
||||
logger.debug(f"Skipped tool {class_name} (requires memory_manager)")
|
||||
continue
|
||||
# McpTool instances are registered dynamically via _load_mcp_tools()
|
||||
if class_name == "McpTool":
|
||||
logger.debug(f"Skipped tool {class_name} (registered dynamically via mcp_servers config)")
|
||||
continue
|
||||
|
||||
# Create a temporary instance to get the name
|
||||
temp_instance = cls()
|
||||
tool_name = temp_instance.name
|
||||
# Store the class, not the instance
|
||||
self.tool_classes[tool_name] = cls
|
||||
logger.debug(f"Loaded tool: {tool_name} from class {class_name}")
|
||||
except ImportError as e:
|
||||
# Handle missing dependencies with helpful messages
|
||||
error_msg = str(e)
|
||||
if "playwright" in error_msg:
|
||||
logger.warning(
|
||||
f"[ToolManager] Browser tool not loaded - missing dependencies.\n"
|
||||
f" To enable browser tool, run:\n"
|
||||
f" pip install playwright\n"
|
||||
f" playwright install chromium"
|
||||
)
|
||||
elif "markdownify" in error_msg:
|
||||
logger.warning(
|
||||
f"[ToolManager] {cls.__name__} not loaded - missing markdownify.\n"
|
||||
f" Install with: pip install markdownify"
|
||||
)
|
||||
else:
|
||||
logger.warning(f"[ToolManager] {cls.__name__} not loaded due to missing dependency: {error_msg}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error initializing tool class {cls.__name__}: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error importing class {class_name}: {e}")
|
||||
|
||||
return len(self.tool_classes) > 0
|
||||
return False
|
||||
except ImportError:
|
||||
logger.warning("Could not import agent.tools package")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading tools from __init__.__all__: {e}")
|
||||
return False
|
||||
|
||||
def _load_tools_from_directory(self, tools_dir: str):
|
||||
"""Dynamically load tool classes from directory"""
|
||||
tools_path = Path(tools_dir)
|
||||
|
||||
# Traverse all .py files
|
||||
for py_file in tools_path.rglob("*.py"):
|
||||
# Skip initialization files and base tool files
|
||||
if py_file.name in ["__init__.py", "base_tool.py", "tool_manager.py"]:
|
||||
continue
|
||||
|
||||
# Get module name
|
||||
module_name = py_file.stem
|
||||
|
||||
try:
|
||||
# Load module directly from file
|
||||
spec = importlib.util.spec_from_file_location(module_name, py_file)
|
||||
if spec and spec.loader:
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(module)
|
||||
|
||||
# Find tool classes in the module
|
||||
for attr_name in dir(module):
|
||||
cls = getattr(module, attr_name)
|
||||
if (
|
||||
isinstance(cls, type)
|
||||
and issubclass(cls, BaseTool)
|
||||
and cls != BaseTool
|
||||
):
|
||||
try:
|
||||
# Skip memory tools (they need special initialization with memory_manager)
|
||||
if attr_name in ["MemorySearchTool", "MemoryGetTool"]:
|
||||
logger.debug(f"Skipped tool {attr_name} (requires memory_manager)")
|
||||
continue
|
||||
|
||||
# Create a temporary instance to get the name
|
||||
temp_instance = cls()
|
||||
tool_name = temp_instance.name
|
||||
# Store the class, not the instance
|
||||
self.tool_classes[tool_name] = cls
|
||||
except ImportError as e:
|
||||
# Handle missing dependencies with helpful messages
|
||||
error_msg = str(e)
|
||||
if "playwright" in error_msg:
|
||||
logger.warning(
|
||||
f"[ToolManager] Browser tool not loaded - missing dependencies.\n"
|
||||
f" To enable browser tool, run:\n"
|
||||
f" pip install playwright\n"
|
||||
f" playwright install chromium"
|
||||
)
|
||||
elif "markdownify" in error_msg:
|
||||
logger.warning(
|
||||
f"[ToolManager] {cls.__name__} not loaded - missing markdownify.\n"
|
||||
f" Install with: pip install markdownify"
|
||||
)
|
||||
else:
|
||||
logger.warning(f"[ToolManager] {cls.__name__} not loaded due to missing dependency: {error_msg}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error initializing tool class {cls.__name__}: {e}")
|
||||
except Exception as e:
|
||||
print(f"Error importing module {py_file}: {e}")
|
||||
|
||||
def _configure_tools_from_config(self, config_dict=None):
|
||||
"""Configure tool classes based on configuration file"""
|
||||
try:
|
||||
# Get tools configuration
|
||||
tools_config = config_dict or conf().get("tools", {})
|
||||
|
||||
# Record tools that are configured but not loaded
|
||||
missing_tools = []
|
||||
|
||||
# Store configurations for later use when instantiating
|
||||
self.tool_configs = tools_config
|
||||
|
||||
# Check which configured tools are missing
|
||||
for tool_name in tools_config:
|
||||
if tool_name not in self.tool_classes:
|
||||
missing_tools.append(tool_name)
|
||||
|
||||
# If there are missing tools, record warnings
|
||||
if missing_tools:
|
||||
for tool_name in missing_tools:
|
||||
if tool_name == "browser":
|
||||
logger.warning(
|
||||
f"[ToolManager] Browser tool is configured but not loaded.\n"
|
||||
f" To enable browser tool, run:\n"
|
||||
f" pip install playwright\n"
|
||||
f" playwright install chromium"
|
||||
)
|
||||
elif tool_name == "google_search":
|
||||
logger.warning(
|
||||
f"[ToolManager] Google Search tool is configured but may need API key.\n"
|
||||
f" Get API key from: https://serper.dev\n"
|
||||
f" Configure in config.json: tools.google_search.api_key"
|
||||
)
|
||||
else:
|
||||
logger.warning(f"[ToolManager] Tool '{tool_name}' is configured but could not be loaded.")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error configuring tools from config: {e}")
|
||||
|
||||
def _mcp_json_path(self) -> str:
    """Return the path of mcp.json inside the configured agent workspace (default ~/cow)."""
    import os
    workspace_setting = conf().get("agent_workspace", "~/cow")
    workspace_dir = os.path.expanduser(workspace_setting)
    return os.path.join(workspace_dir, "mcp.json")
|
||||
|
||||
def _read_mcp_json_signature(self):
|
||||
"""
|
||||
Return (mtime, sha256_of_bytes) for ~/cow/mcp.json without parsing.
|
||||
Returns (None, None) if the file doesn't exist or is unreadable.
|
||||
Cheap enough (one stat + one small read) to call on every agent init.
|
||||
"""
|
||||
import os
|
||||
import hashlib
|
||||
path = self._mcp_json_path()
|
||||
try:
|
||||
mtime = os.path.getmtime(path)
|
||||
except OSError:
|
||||
return (None, None)
|
||||
try:
|
||||
with open(path, "rb") as f:
|
||||
digest = hashlib.sha256(f.read()).hexdigest()
|
||||
except OSError:
|
||||
return (mtime, None)
|
||||
return (mtime, digest)
|
||||
|
||||
def _load_mcp_configs(self) -> list:
    """
    Load MCP server configs with priority:
    1. mcp.json in the agent workspace (supports both mcpServers and
       mcp_servers keys, a bare name -> config mapping, or a top-level
       list of server entries)
    2. config.json mcp_servers field (fallback, also used when mcp.json
       cannot be read or parsed)

    :return: list of normalized server config dicts (may be empty)
    """
    import os
    import json as _json

    mcp_json_path = self._mcp_json_path()

    if os.path.exists(mcp_json_path):
        try:
            with open(mcp_json_path, "r", encoding="utf-8") as f:
                data = _json.load(f)

            if isinstance(data, dict):
                raw = data.get("mcpServers") or data.get("mcp_servers")
                has_wrapper_key = "mcpServers" in data or "mcp_servers" in data
                # Only treat the whole document as a name -> config
                # mapping when no wrapper key exists. Otherwise an
                # emptied {"mcpServers": {}} would be read as one bogus
                # server called "mcpServers".
                if not raw and not has_wrapper_key:
                    raw = data
            elif isinstance(data, list):
                raw = data
            else:
                # Handled below like any unreadable file: warn and
                # fall back to config.json.
                raise ValueError("top-level JSON value must be an object or a list")

            logger.info(f"[ToolManager] Loading MCP config from {mcp_json_path}")
            return _normalize_mcp_configs(raw)
        except Exception as e:
            logger.warning(f"[ToolManager] Failed to read {mcp_json_path}: {e}, falling back to config.json")

    raw = conf().get("mcp_servers", [])
    return _normalize_mcp_configs(raw)
|
||||
|
||||
def _load_mcp_tools(self):
    """
    Kick off MCP tool loading on a background thread; repeated calls are no-ops.

    The call itself returns right away. Starting MCP servers (npx, uvx,
    ...) can take from seconds to tens of seconds on a first run, and
    doing that inline would stall agent initialization and the user's
    first message. Built-in tools do not depend on MCP, so the agent is
    allowed to serve traffic immediately while MCP servers come up in the
    background; per-session agents pick up whatever is ready when they
    are constructed.
    """
    with self._mcp_lock:
        if self._mcp_loaded:
            return

        server_cfgs = self._load_mcp_configs()
        # Remember what the file looked like at load time so that
        # refresh_mcp_if_changed() can return early while it is unchanged.
        self._mcp_signature = self._read_mcp_json_signature()
        server_names = [c.get("name", "<unnamed>") for c in server_cfgs]
        self._mcp_active_configs = dict(zip(server_names, server_cfgs))

        if not server_cfgs:
            # Nothing to start, but still flag as loaded so later calls
            # do not re-read the config.
            self._mcp_loaded = True
            return

        # Publish "pending" up front: list_mcp_status() should report
        # work in progress rather than an empty mapping.
        for server in server_names:
            self._mcp_status[server] = "pending"

        self._mcp_loaded = True
        loader = threading.Thread(
            target=self._load_mcp_tools_async,
            args=(server_cfgs,),
            daemon=True,
            name="mcp-loader",
        )
        loader.start()
        logger.info(
            f"[ToolManager] MCP loading started in background "
            f"({len(server_cfgs)} server(s) configured)"
        )
|
||||
|
||||
def refresh_mcp_if_changed(self):
    """
    Reload MCP servers if mcp.json differs from what was last loaded.

    The file signature (mtime + content hash, see
    _read_mcp_json_signature) is compared with the stored one. When it
    is unchanged the method returns immediately. Otherwise the configs
    are re-read and diffed by server name:

    - servers only in the new config are started
    - servers only in the old config are torn down
    - servers whose config dict changed are torn down and started again
    - everything else keeps running untouched

    Intended to be cheap enough to call whenever an agent is created.
    """
    with self._mcp_lock:
        current_sig = self._read_mcp_json_signature()
        if current_sig == self._mcp_signature:
            return  # nothing changed on disk

        try:
            fresh_configs = self._load_mcp_configs()
        except Exception as e:
            logger.warning(f"[ToolManager] MCP reload — failed to parse config: {e}")
            return

        desired = {cfg.get("name", "<unnamed>"): cfg for cfg in fresh_configs}
        active = self._mcp_active_configs

        added = []
        changed = []
        for server, cfg in desired.items():
            if server not in active:
                added.append(server)
            elif cfg != active[server]:
                changed.append(server)
        removed = [server for server in active if server not in desired]

        if not added and not removed and not changed:
            # The bytes changed but the effective config did not (for
            # example a re-save without edits): only record the signature.
            self._mcp_signature = current_sig
            return

        logger.info(
            f"[ToolManager] mcp.json changed — "
            f"adding={added}, removing={removed}, restarting={changed}"
        )

        # Stop everything that is going away or about to be restarted.
        for server in removed + changed:
            self._teardown_mcp_server(server)

        # Bring up new and restarted servers on a background thread.
        to_start = [desired[server] for server in added + changed]
        if to_start:
            for cfg in to_start:
                self._mcp_status[cfg.get("name", "<unnamed>")] = "pending"
            reloader = threading.Thread(
                target=self._load_mcp_tools_async,
                args=(to_start,),
                daemon=True,
                name="mcp-loader-reload",
            )
            reloader.start()

        self._mcp_active_configs = desired
        self._mcp_signature = current_sig
|
||||
|
||||
def _teardown_mcp_server(self, server_name: str):
    """
    Stop a single MCP server and forget everything that belonged to it.

    Does nothing when no registry exists yet. Otherwise the server's
    client (if registered) is removed from the registry and shut down,
    every tool whose ``server_name`` attribute matches is dropped from
    ``_mcp_tool_instances``, and the server's status entry is removed.
    A failing ``shutdown()`` is logged and does not stop the cleanup.
    """
    registry = self._mcp_registry
    if registry is None:
        return

    with registry._registry_lock:
        client = registry._clients.pop(server_name, None)

    if client is not None:
        try:
            client.shutdown()
        except Exception as e:
            logger.warning(f"[MCP] Error shutting down '{server_name}': {e}")

    # Work from a snapshot so that removing entries cannot disturb iteration.
    orphaned = [
        tool_name
        for tool_name, tool in list(self._mcp_tool_instances.items())
        if tool is not None and getattr(tool, "server_name", None) == server_name
    ]
    for tool_name in orphaned:
        self._mcp_tool_instances.pop(tool_name, None)

    self._mcp_status.pop(server_name, None)
|
||||
|
||||
def _load_mcp_tools_async(self, mcp_servers_config):
    """
    Background worker: bring up each MCP server one-by-one and
    publish ready tools to _mcp_tool_instances as they come online.

    Server failures are isolated — one bad server cannot block
    the others, and nothing is raised out of the worker thread.

    The client registry is created once and reused by later runs (for
    example the reload started by refresh_mcp_if_changed), so clients
    registered by an earlier run stay reachable for
    _teardown_mcp_server() and shutdown_mcp().

    When a server fails after its client object was created, tools it
    had already published are withdrawn and the client is shut down
    (best effort) instead of being left running unregistered.

    :param mcp_servers_config: iterable of server config dicts; each is
        passed to McpClient and its ``name`` key (default ``<unnamed>``)
        is used as the server name.
    """
    try:
        from agent.tools.mcp.mcp_client import McpClient, McpClientRegistry
        from agent.tools.mcp.mcp_tool import McpTool

        # Reuse the existing registry: replacing it would orphan every
        # client that an earlier run registered.
        registry = getattr(self, "_mcp_registry", None)
        if registry is None:
            registry = McpClientRegistry()
            self._mcp_registry = registry

        for cfg in mcp_servers_config:
            server_name = cfg.get("name", "<unnamed>")
            client = None
            registered = False
            published = {}  # tool name -> instance published by this server
            try:
                client = McpClient(cfg)
                if not client.initialize():
                    self._mcp_status[server_name] = "failed"
                    logger.warning(
                        f"[MCP] Server '{server_name}' failed to initialize — skipping"
                    )
                    continue

                tool_schemas = client.list_tools()
                added = []
                for schema in tool_schemas:
                    tool_name = schema.get("name", "")
                    if not tool_name:
                        continue
                    mcp_tool = McpTool(client, schema, server_name)
                    # Atomic dict assignment is GIL-safe; readers iterate
                    # over a list() snapshot to avoid concurrent mutation.
                    self._mcp_tool_instances[tool_name] = mcp_tool
                    published[tool_name] = mcp_tool
                    added.append(tool_name)

                # Register client into the shared registry only after its
                # tools are visible, so callers never see a half-loaded server.
                with registry._registry_lock:
                    registry._clients[server_name] = client
                registered = True
                self._mcp_status[server_name] = "ready"
                logger.info(
                    f"[MCP] Server '{server_name}' ready — "
                    f"{len(added)} tool(s): {added}"
                )
            except Exception as e:
                self._mcp_status[server_name] = "failed"
                logger.warning(f"[MCP] Server '{server_name}' load failed: {e}")
                if not registered:
                    # Withdraw tools bound to the failed client, but only
                    # entries that still point at the instance we published.
                    for tool_name, mcp_tool in published.items():
                        if self._mcp_tool_instances.get(tool_name) is mcp_tool:
                            self._mcp_tool_instances.pop(tool_name, None)
                    if client is not None:
                        try:
                            client.shutdown()
                        except Exception as shutdown_err:
                            logger.warning(
                                f"[MCP] Error shutting down '{server_name}': {shutdown_err}"
                            )

        ready = sum(1 for s in self._mcp_status.values() if s == "ready")
        total = len(self._mcp_status)
        logger.info(
            f"[ToolManager] MCP loading complete: "
            f"{ready}/{total} server(s) ready, "
            f"{len(self._mcp_tool_instances)} tool(s) available"
        )
    except Exception as e:
        logger.warning(f"[ToolManager] MCP background loader crashed: {e}")
|
||||
|
||||
def list_mcp_status(self) -> dict:
    """
    Snapshot of the per-server MCP status for UI / debugging.

    :return: a new ``{server_name: status}`` dict; mutating it does not
        affect the manager's own bookkeeping.
    """
    return {server: state for server, state in self._mcp_status.items()}
|
||||
|
||||
def sync_mcp_into_agent(self, agent) -> tuple:
    """
    Reconcile a live agent's tool collection with the current MCP tool registry.

    - Adds MCP tools that finished loading after the agent was created.
    - Removes MCP tools whose server was torn down.
    - Replaces MCP tools whose name is still registered but whose
      instance differs from the registry's (the server was restarted, so
      the agent's copy is bound to a client that was shut down).
    - Leaves non-MCP tools untouched: an MCP tool is not added when the
      agent already holds a non-MCP tool under the same name.

    Handles both representations CowAgent uses:
    - Agent.tools: list[BaseTool]            (default Agent class)
    - AgentStream.tools: dict[str, BaseTool] (streaming agent)

    :param agent: agent object exposing a ``tools`` attribute, or None.
    :return: ``(added_names, removed_names)`` as sorted lists for logging.
        Replaced (restarted) tools are reported in ``added_names``.
        ``([], [])`` when nothing changed or the agent is unsupported.
    """
    if agent is None or not hasattr(agent, "tools"):
        return ([], [])

    from agent.tools.mcp.mcp_tool import McpTool

    # Work on a snapshot: the loader/teardown threads may mutate the live
    # dict while we are reconciling.
    current = dict(self._mcp_tool_instances)
    agent_tools = agent.tools

    if isinstance(agent_tools, dict):
        agent_mcp = {
            name: tool for name, tool in agent_tools.items()
            if isinstance(tool, McpTool)
        }
        reserved = {
            name for name, tool in agent_tools.items()
            if not isinstance(tool, McpTool)
        }
    elif isinstance(agent_tools, list):
        agent_mcp = {t.name: t for t in agent_tools if isinstance(t, McpTool)}
        reserved = {
            getattr(t, "name", None) for t in agent_tools
            if not isinstance(t, McpTool)
        }
    else:
        return ([], [])

    registry_names = set(current)
    held_names = set(agent_mcp)

    removed = held_names - registry_names
    # Names owned by non-MCP tools are off limits.
    added = registry_names - held_names - reserved
    # Same name, different object: the agent holds a stale instance.
    restarted = {
        name for name, tool in agent_mcp.items()
        if name in current and current[name] is not tool
    }

    if not (added or removed or restarted):
        return ([], [])

    incoming = sorted(added | restarted)

    if isinstance(agent_tools, dict):
        for name in removed:
            agent_tools.pop(name, None)
        for name in incoming:
            agent_tools[name] = current[name]
    else:
        outgoing = removed | restarted
        if outgoing:
            agent.tools = [
                t for t in agent_tools
                if not (isinstance(t, McpTool) and t.name in outgoing)
            ]
        agent.tools.extend(current[name] for name in incoming)

    return (incoming, sorted(removed))
|
||||
|
||||
def create_tool(self, name: str) -> BaseTool:
    """
    Look up a tool by name and hand back an instance of it.

    Registered tool classes take precedence: a fresh instance is created
    and, when a config is stored for that name, attached as ``.config``.
    If no class is registered, the already-built MCP tool instance with
    that name is returned (shared, not copied).

    :param name: The name of the tool to get.
    :return: The tool instance, or None if nothing matches.
    """
    factory = self.tool_classes.get(name)
    if not factory:
        # Not a built-in: fall back to the MCP instances.
        return self._mcp_tool_instances.get(name) or None

    instance = factory()
    if hasattr(self, 'tool_configs'):
        known_configs = self.tool_configs
        if name in known_configs:
            instance.config = known_configs[name]
    return instance
|
||||
|
||||
def list_tools(self) -> dict:
    """
    Describe every tool the manager currently knows about.

    Registered tool classes are instantiated once each to read their
    description and JSON schema; MCP tools are described from their
    existing instances (``description`` / ``params``). MCP entries are
    written last, so they win on a name clash.

    :return: ``{tool_name: {"description": ..., "parameters": ...}}``.
    """
    catalog = {}

    for tool_name, tool_cls in self.tool_classes.items():
        # A throwaway instance is the only way to reach the schema.
        probe = tool_cls()
        catalog[tool_name] = {
            "description": probe.description,
            "parameters": probe.get_json_schema(),
        }

    catalog.update({
        tool_name: {
            "description": tool.description,
            "parameters": tool.params,
        }
        for tool_name, tool in self._mcp_tool_instances.items()
    })

    return catalog
|
||||
|
||||
def shutdown_mcp(self):
    """Ask the MCP client registry, if there is one, to shut down all clients."""
    registry = self._mcp_registry
    if not registry:
        return
    registry.shutdown_all()
|
||||
40
agent/tools/utils/__init__.py
Normal file
40
agent/tools/utils/__init__.py
Normal file
@@ -0,0 +1,40 @@
|
||||
from .truncate import (
|
||||
truncate_head,
|
||||
truncate_tail,
|
||||
truncate_line,
|
||||
format_size,
|
||||
TruncationResult,
|
||||
DEFAULT_MAX_LINES,
|
||||
DEFAULT_MAX_BYTES,
|
||||
GREP_MAX_LINE_LENGTH
|
||||
)
|
||||
|
||||
from .diff import (
|
||||
strip_bom,
|
||||
detect_line_ending,
|
||||
normalize_to_lf,
|
||||
restore_line_endings,
|
||||
normalize_for_fuzzy_match,
|
||||
fuzzy_find_text,
|
||||
generate_diff_string,
|
||||
FuzzyMatchResult
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
'truncate_head',
|
||||
'truncate_tail',
|
||||
'truncate_line',
|
||||
'format_size',
|
||||
'TruncationResult',
|
||||
'DEFAULT_MAX_LINES',
|
||||
'DEFAULT_MAX_BYTES',
|
||||
'GREP_MAX_LINE_LENGTH',
|
||||
'strip_bom',
|
||||
'detect_line_ending',
|
||||
'normalize_to_lf',
|
||||
'restore_line_endings',
|
||||
'normalize_for_fuzzy_match',
|
||||
'fuzzy_find_text',
|
||||
'generate_diff_string',
|
||||
'FuzzyMatchResult'
|
||||
]
|
||||
167
agent/tools/utils/diff.py
Normal file
167
agent/tools/utils/diff.py
Normal file
@@ -0,0 +1,167 @@
|
||||
"""
|
||||
Diff tools for file editing
|
||||
Provides fuzzy matching and diff generation functionality
|
||||
"""
|
||||
|
||||
import difflib
|
||||
import re
|
||||
from typing import Optional, Tuple
|
||||
|
||||
|
||||
def strip_bom(text: str) -> Tuple[str, str]:
    """
    Split a leading BOM (U+FEFF) off the text.

    :param text: Original text
    :return: (BOM, remaining text); the BOM part is '' when the text
        does not start with one
    """
    bom = '\ufeff'
    if text[:1] == bom:
        return bom, text[len(bom):]
    return '', text
|
||||
|
||||
|
||||
def detect_line_ending(text: str) -> str:
    """
    Decide which line ending a text uses.

    :param text: Text content
    :return: CRLF if the text contains at least one CRLF pair, otherwise LF
    """
    uses_crlf = text.find('\r\n') != -1
    return '\r\n' if uses_crlf else '\n'
|
||||
|
||||
|
||||
def normalize_to_lf(text: str) -> str:
    """
    Convert CRLF and lone CR line endings to LF.

    :param text: Original text
    :return: Text in which every line ending is a single LF
    """
    # CRLF pairs first, so that the CR pass below only sees lone CRs.
    without_crlf = text.replace('\r\n', '\n')
    return '\n'.join(without_crlf.split('\r'))
|
||||
|
||||
|
||||
def restore_line_endings(text: str, original_ending: str) -> str:
    """
    Re-apply the line ending the file originally used.

    :param text: LF normalized text
    :param original_ending: Original line ending
    :return: The text with every LF turned into CRLF when the original
        ending was CRLF; otherwise the text unchanged
    """
    if original_ending != '\r\n':
        return text
    return '\r\n'.join(text.split('\n'))
|
||||
|
||||
|
||||
def normalize_for_fuzzy_match(text: str) -> str:
    """
    Produce a whitespace-insensitive form of the text for fuzzy matching.

    Steps, in order:
    1. every run of spaces/tabs collapses to a single space
    2. spaces directly before a newline are dropped
    3. per line, leading whitespace is replaced by the same number of
       spaces; lines containing only whitespace become empty

    Because of step 1, indentation made of spaces/tabs is at most one
    space wide in the result.

    :param text: Original text
    :return: Normalized text
    """
    collapsed = re.sub(r'[ \t]+', ' ', text)
    collapsed = re.sub(r' +\n', '\n', collapsed)

    rebuilt = []
    for raw_line in collapsed.split('\n'):
        body = raw_line.lstrip()
        if not body:
            rebuilt.append('')
            continue
        # Whatever whitespace led the line is re-emitted as plain spaces.
        indent_width = len(raw_line) - len(body)
        rebuilt.append(' ' * indent_width + body)
    return '\n'.join(rebuilt)
|
||||
|
||||
|
||||
class FuzzyMatchResult:
    """
    Outcome of a text search performed by fuzzy_find_text.

    Attributes:
        found: whether the text was located
        index: start offset of the match inside content_for_replacement
            (-1 by default)
        match_length: number of characters the match covers
        content_for_replacement: the string that index and match_length
            refer to
    """

    def __init__(self, found: bool, index: int = -1, match_length: int = 0, content_for_replacement: str = ""):
        self.found: bool = found
        self.index: int = index
        self.match_length: int = match_length
        self.content_for_replacement: str = content_for_replacement
|
||||
|
||||
|
||||
def fuzzy_find_text(content: str, old_text: str) -> FuzzyMatchResult:
    """
    Locate old_text in content: exact search first, fuzzy search second.

    On an exact hit the result refers to the original content. On a
    fuzzy hit both strings are passed through normalize_for_fuzzy_match
    and the result's index, match_length and content_for_replacement
    all refer to the NORMALIZED content, not to the original string.

    :param content: Content to search in
    :param old_text: Text to find
    :return: Match result (found=False when neither search succeeds)
    """
    exact_at = content.find(old_text)
    if exact_at != -1:
        return FuzzyMatchResult(
            found=True,
            index=exact_at,
            match_length=len(old_text),
            content_for_replacement=content,
        )

    loose_content = normalize_for_fuzzy_match(content)
    loose_needle = normalize_for_fuzzy_match(old_text)
    loose_at = loose_content.find(loose_needle)
    if loose_at == -1:
        return FuzzyMatchResult(found=False)

    # Offsets below are positions in the normalized text.
    return FuzzyMatchResult(
        found=True,
        index=loose_at,
        match_length=len(loose_needle),
        content_for_replacement=loose_content,
    )
|
||||
|
||||
|
||||
def generate_diff_string(old_content: str, new_content: str) -> dict:
    """
    Generate unified diff string

    Both inputs are split on LF and compared with difflib.unified_diff
    (file labels 'original' / 'modified', default context).

    :param old_content: Old content
    :param new_content: New content
    :return: Dictionary with
        'diff': the unified diff joined by LF ('' when there is no difference)
        'first_changed_line': 1-based line number, in the new content, of
            the first added or removed line of the first hunk (the hunk's
            leading context lines are skipped); None when nothing changed
    """
    old_lines = old_content.split('\n')
    new_lines = new_content.split('\n')

    # Generate unified diff
    diff_lines = list(difflib.unified_diff(
        old_lines,
        new_lines,
        lineterm='',
        fromfile='original',
        tofile='modified'
    ))

    # Find first changed line number. The hunk header only gives the line
    # where the hunk (including leading context) starts, so walk the hunk
    # and count context lines until the first '+' / '-' line.
    first_changed_line = None
    new_line_no = None  # new-file line number of the diff line being scanned
    for line in diff_lines:
        if new_line_no is None:
            # Still before the first hunk ('---' / '+++' file headers).
            if line.startswith('@@'):
                # Parse @@ -1,3 +1,3 @@ format
                match = re.search(r'@@ -\d+,?\d* \+(\d+)', line)
                if match:
                    new_line_no = int(match.group(1))
            continue
        if line.startswith(('+', '-')):
            first_changed_line = new_line_no
            break
        new_line_no += 1  # context line: present in the new file

    diff_string = '\n'.join(diff_lines)

    return {
        'diff': diff_string,
        'first_changed_line': first_changed_line
    }
|
||||
295
agent/tools/utils/truncate.py
Normal file
295
agent/tools/utils/truncate.py
Normal file
@@ -0,0 +1,295 @@
|
||||
"""
|
||||
Shared truncation utilities for tool outputs.
|
||||
|
||||
Truncation is based on two independent limits - whichever is hit first wins:
|
||||
- Line limit (default: 2000 lines)
|
||||
- Byte limit (default: 50KB)
|
||||
|
||||
Never returns partial lines (except bash tail truncation edge case).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
from typing import Dict, Any, Optional, Tuple, TYPE_CHECKING
|
||||
if TYPE_CHECKING:
|
||||
from typing import Literal
|
||||
|
||||
|
||||
DEFAULT_MAX_LINES = 2000
|
||||
DEFAULT_MAX_BYTES = 50 * 1024 # 50KB
|
||||
GREP_MAX_LINE_LENGTH = 500 # Max chars per grep match line
|
||||
|
||||
|
||||
class TruncationResult:
    """
    Plain record describing what a truncate_* call produced.

    Holds the (possibly shortened) content, whether and why it was
    shortened, line/byte counts before and after, the two edge-case
    flags, and the limits that were applied.
    """

    def __init__(
        self,
        content: str,
        truncated: bool,
        truncated_by: Optional[Literal["lines", "bytes"]],
        total_lines: int,
        total_bytes: int,
        output_lines: int,
        output_bytes: int,
        last_line_partial: bool = False,
        first_line_exceeds_limit: bool = False,
        max_lines: int = DEFAULT_MAX_LINES,
        max_bytes: int = DEFAULT_MAX_BYTES
    ):
        # Output text and truncation verdict
        self.content = content
        self.truncated = truncated
        self.truncated_by = truncated_by
        # Size of the input
        self.total_lines = total_lines
        self.total_bytes = total_bytes
        # Size of the output
        self.output_lines = output_lines
        self.output_bytes = output_bytes
        # Edge-case markers
        self.last_line_partial = last_line_partial
        self.first_line_exceeds_limit = first_line_exceeds_limit
        # Limits in effect
        self.max_lines = max_lines
        self.max_bytes = max_bytes

    def to_dict(self) -> Dict[str, Any]:
        """Return all fields as a dict, keyed by attribute name."""
        field_names = (
            "content",
            "truncated",
            "truncated_by",
            "total_lines",
            "total_bytes",
            "output_lines",
            "output_bytes",
            "last_line_partial",
            "first_line_exceeds_limit",
            "max_lines",
            "max_bytes",
        )
        return {field_name: getattr(self, field_name) for field_name in field_names}
|
||||
|
||||
|
||||
def format_size(bytes_count: int) -> str:
    """
    Render a byte count for humans.

    Below 1024 the plain number with a 'B' suffix is returned; below
    1 MiB the value is shown in KB, otherwise in MB, both with one
    decimal place.
    """
    kib = 1024
    mib = 1024 * 1024
    if bytes_count >= mib:
        return f"{bytes_count / mib:.1f}MB"
    if bytes_count >= kib:
        return f"{bytes_count / kib:.1f}KB"
    return f"{bytes_count}B"
|
||||
|
||||
|
||||
def truncate_head(content: str, max_lines: Optional[int] = None, max_bytes: Optional[int] = None) -> TruncationResult:
    """
    Keep the beginning of the content (first N lines / bytes).
    Meant for file reads, where the start is what matters.

    Only whole lines are returned. When the very first line is already
    larger than the byte limit, the content is empty and
    first_line_exceeds_limit is True.

    :param content: Content to truncate
    :param max_lines: Maximum number of lines (default: 2000)
    :param max_bytes: Maximum number of bytes (default: 50KB)
    :return: Truncation result
    """
    max_lines = DEFAULT_MAX_LINES if max_lines is None else max_lines
    max_bytes = DEFAULT_MAX_BYTES if max_bytes is None else max_bytes

    total_bytes = len(content.encode('utf-8'))
    all_lines = content.split('\n')
    total_lines = len(all_lines)

    # Fields shared by every result built below.
    shared = dict(
        total_lines=total_lines,
        total_bytes=total_bytes,
        last_line_partial=False,
        max_lines=max_lines,
        max_bytes=max_bytes,
    )

    # Within both limits: hand the content back untouched.
    if total_lines <= max_lines and total_bytes <= max_bytes:
        return TruncationResult(
            content=content,
            truncated=False,
            truncated_by=None,
            output_lines=total_lines,
            output_bytes=total_bytes,
            first_line_exceeds_limit=False,
            **shared
        )

    # A first line that cannot fit means nothing can be returned.
    if len(all_lines[0].encode('utf-8')) > max_bytes:
        return TruncationResult(
            content="",
            truncated=True,
            truncated_by="bytes",
            output_lines=0,
            output_bytes=0,
            first_line_exceeds_limit=True,
            **shared
        )

    kept = []
    used_bytes = 0
    reason = "lines"
    for position, text_line in enumerate(all_lines):
        if position >= max_lines:
            break

        # Every line after the first also costs one byte for its separator.
        cost = len(text_line.encode('utf-8'))
        if position:
            cost += 1

        if used_bytes + cost > max_bytes:
            reason = "bytes"
            break

        kept.append(text_line)
        used_bytes += cost

    # A byte-limit stop always leaves fewer than max_lines lines, so the
    # reason recorded in the loop is final.
    output_content = '\n'.join(kept)

    return TruncationResult(
        content=output_content,
        truncated=True,
        truncated_by=reason,
        output_lines=len(kept),
        output_bytes=len(output_content.encode('utf-8')),
        first_line_exceeds_limit=False,
        **shared
    )
|
||||
|
||||
|
||||
def truncate_tail(content: str, max_lines: Optional[int] = None, max_bytes: Optional[int] = None) -> TruncationResult:
    """
    Truncate content from tail (keep last N lines/bytes).
    Suitable for bash output where you want to see the ending content (errors, final results).

    Whole lines are kept, walking backwards from the end. The one
    exception: if the last line alone exceeds the byte limit, the tail
    end of that line is returned, last_line_partial is True and
    truncated_by is "bytes".

    :param content: Content to truncate
    :param max_lines: Maximum lines (default: 2000)
    :param max_bytes: Maximum bytes (default: 50KB)
    :return: Truncation result
    """
    if max_lines is None:
        max_lines = DEFAULT_MAX_LINES
    if max_bytes is None:
        max_bytes = DEFAULT_MAX_BYTES

    total_bytes = len(content.encode('utf-8'))
    lines = content.split('\n')
    total_lines = len(lines)

    # Check if no truncation is needed
    if total_lines <= max_lines and total_bytes <= max_bytes:
        return TruncationResult(
            content=content,
            truncated=False,
            truncated_by=None,
            total_lines=total_lines,
            total_bytes=total_bytes,
            output_lines=total_lines,
            output_bytes=total_bytes,
            last_line_partial=False,
            first_line_exceeds_limit=False,
            max_lines=max_lines,
            max_bytes=max_bytes
        )

    # Work backwards from the end. Lines are collected newest-first and
    # reversed once at the end (appending is O(1), inserting at 0 is not).
    kept_reversed = []
    output_bytes_count = 0
    truncated_by = "lines"
    last_line_partial = False

    for line in reversed(lines):
        if len(kept_reversed) >= max_lines:
            break

        # Calculate line bytes (add newline if not the first added line)
        line_bytes = len(line.encode('utf-8')) + (1 if kept_reversed else 0)

        if output_bytes_count + line_bytes > max_bytes:
            truncated_by = "bytes"
            # Edge case: nothing kept yet and this line alone is too
            # large, so keep only the end portion of it.
            if not kept_reversed:
                truncated_line = _truncate_string_to_bytes_from_end(line, max_bytes)
                kept_reversed.append(truncated_line)
                output_bytes_count = len(truncated_line.encode('utf-8'))
                last_line_partial = True
            break

        kept_reversed.append(line)
        output_bytes_count += line_bytes

    # truncated_by is final here: it is "bytes" only when the byte limit
    # stopped the loop. It must not be reset to "lines" afterwards, or a
    # partially kept line with max_lines == 1 would be mislabelled.
    output_lines_arr = kept_reversed[::-1]
    output_content = '\n'.join(output_lines_arr)
    final_output_bytes = len(output_content.encode('utf-8'))

    return TruncationResult(
        content=output_content,
        truncated=True,
        truncated_by=truncated_by,
        total_lines=total_lines,
        total_bytes=total_bytes,
        output_lines=len(output_lines_arr),
        output_bytes=final_output_bytes,
        last_line_partial=last_line_partial,
        first_line_exceeds_limit=False,
        max_lines=max_lines,
        max_bytes=max_bytes
    )
|
||||
|
||||
|
||||
def _truncate_string_to_bytes_from_end(text: str, max_bytes: int) -> str:
    """
    Keep the end of a string so that its UTF-8 encoding fits a byte budget.

    The cut never lands inside a multi-byte character: if the naive cut
    position is a continuation byte, the cut moves forward to the next
    character start, so the result may be shorter than max_bytes.

    :param text: String to truncate
    :param max_bytes: Maximum bytes
    :return: The unchanged text when it already fits, else its tail
    """
    raw = text.encode('utf-8')
    overflow = len(raw) - max_bytes
    if overflow <= 0:
        return text

    # Continuation bytes look like 10xxxxxx; skip them to reach a
    # character boundary.
    cut = overflow
    size = len(raw)
    while cut < size and (raw[cut] & 0xC0) == 0x80:
        cut += 1

    return raw[cut:].decode('utf-8', errors='ignore')
|
||||
|
||||
|
||||
def truncate_line(line: str, max_chars: int = GREP_MAX_LINE_LENGTH) -> Tuple[str, bool]:
    """
    Cap a single line at max_chars characters.
    Used for grep match lines.

    A line that is too long is cut to its first max_chars characters and
    gets the suffix "... [truncated]" appended (the suffix is not counted
    against the limit).

    :param line: Line to truncate
    :param max_chars: Maximum characters
    :return: (resulting text, whether it was truncated)
    """
    fits = len(line) <= max_chars
    if fits:
        return line, False
    return line[:max_chars] + "... [truncated]", True
|
||||
1
agent/tools/vision/__init__.py
Normal file
1
agent/tools/vision/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from agent.tools.vision.vision import Vision
|
||||
814
agent/tools/vision/vision.py
Normal file
814
agent/tools/vision/vision.py
Normal file
@@ -0,0 +1,814 @@
|
||||
"""
|
||||
Vision tool - Analyze images using Vision API.
|
||||
Supports local files (auto base64-encoded) and HTTP URLs.
|
||||
|
||||
Provider resolution:
|
||||
- tools.vision.model (if set) means "prefer this model first; fall back to
|
||||
other configured providers if it fails". The model name is mapped to its
|
||||
native provider (e.g. doubao-* → Doubao, kimi-* → Moonshot, gpt-* →
|
||||
OpenAI/LinkAI). That provider is tried first, then the standard auto
|
||||
chain runs as fallback (with the preferred provider de-duplicated).
|
||||
- Auto chain priority:
|
||||
1. Main model via bot.call_vision — only when the main bot is known
|
||||
to actually support vision (not just expose a call_vision method).
|
||||
2. Other models whose API key is configured.
|
||||
3. OpenAI / LinkAI raw HTTP.
|
||||
When use_linkai=true, LinkAI is promoted to #1.
|
||||
"""
|
||||
|
||||
import base64
|
||||
import os
|
||||
import subprocess
|
||||
import tempfile
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import requests
|
||||
|
||||
from agent.tools.base_tool import BaseTool, ToolResult
|
||||
from common import const
|
||||
from common.log import logger
|
||||
from config import conf
|
||||
|
||||
DEFAULT_MODEL = const.GPT_41_MINI
|
||||
DEFAULT_TIMEOUT = 60
|
||||
MAX_TOKENS = 1000
|
||||
COMPRESS_THRESHOLD = 1_048_576 # 1 MB
|
||||
|
||||
SUPPORTED_EXTENSIONS = {
|
||||
"jpg": "image/jpeg",
|
||||
"jpeg": "image/jpeg",
|
||||
"png": "image/png",
|
||||
"gif": "image/gif",
|
||||
"webp": "image/webp",
|
||||
}
|
||||
|
||||
_MAIN_MODEL_PROVIDER_NAME = "MainModel"
|
||||
|
||||
# (config_key_for_api_key, bot_type, default_vision_model, provider_display_name)
|
||||
# Auto-discovered as fallback vision providers when their API key is configured.
|
||||
# OpenAI and LinkAI are handled separately (raw HTTP providers), so not listed here.
|
||||
_DISCOVERABLE_MODELS = [
|
||||
("moonshot_api_key", const.MOONSHOT, const.KIMI_K2_6, "Moonshot"),
|
||||
("ark_api_key", const.DOUBAO, const.DOUBAO_SEED_2_PRO, "Doubao"),
|
||||
("dashscope_api_key", const.QWEN_DASHSCOPE, const.QWEN36_PLUS, "DashScope"),
|
||||
("claude_api_key", const.CLAUDEAPI, const.CLAUDE_4_6_SONNET, "Claude"),
|
||||
("gemini_api_key", const.GEMINI, const.GEMINI_35_FLASH, "Gemini"),
|
||||
("qianfan_api_key", const.QIANFAN, const.ERNIE_45_TURBO_VL, "Qianfan"),
|
||||
("zhipu_ai_api_key", const.ZHIPU_AI, const.GLM_4_7, "ZhipuAI"),
|
||||
("minimax_api_key", const.MiniMax, const.MINIMAX_M2_7, "MiniMax"),
|
||||
("mimo_api_key", const.MIMO, const.MIMO_V2_5_PRO, "MiMo"),
|
||||
]
|
||||
|
||||
# Model name prefix → discoverable provider display_name.
|
||||
# Used to auto-route tools.vision.model to its native provider.
|
||||
# Matched case-insensitively; longest prefix wins.
|
||||
_MODEL_PREFIX_TO_PROVIDER = [
|
||||
("doubao-", "Doubao"),
|
||||
("kimi-", "Moonshot"),
|
||||
("moonshot-", "Moonshot"),
|
||||
("qwen", "DashScope"), # qwen-*, qwen3-*, qwen3.6-*, etc.
|
||||
("claude-", "Claude"),
|
||||
("ernie-", "Qianfan"),
|
||||
("gemini-", "Gemini"),
|
||||
("glm-", "ZhipuAI"),
|
||||
("minimax-", "MiniMax"),
|
||||
("abab", "MiniMax"),
|
||||
("mimo-", "MiMo"),
|
||||
]
|
||||
|
||||
# Model prefixes that natively belong to OpenAI / LinkAI (raw HTTP providers).
|
||||
_OPENAI_MODEL_PREFIXES = ("gpt-", "o1-", "o3-", "o4-", "chatgpt-")
|
||||
|
||||
# Maps the UI provider id (persisted in tools.vision.provider) to the internal
|
||||
# display name used in VisionProvider.name. Keep in sync with _DISCOVERABLE_MODELS
|
||||
# and the openai/linkai branches in _route_by_model_name.
|
||||
_PROVIDER_ID_TO_DISPLAY = {
|
||||
"openai": "OpenAI",
|
||||
"linkai": "LinkAI",
|
||||
"moonshot": "Moonshot",
|
||||
"doubao": "Doubao",
|
||||
"dashscope": "DashScope",
|
||||
"claudeAPI": "Claude",
|
||||
"gemini": "Gemini",
|
||||
"qianfan": "Qianfan",
|
||||
"zhipu": "ZhipuAI",
|
||||
"minimax": "MiniMax",
|
||||
"mimo": "MiMo",
|
||||
}
|
||||
|
||||
|
||||
@dataclass
class VisionProvider:
    """
    A single Vision API provider configuration.

    One instance describes one way of reaching a vision-capable model:
    either through a bot object (use_bot / fallback_bot) or, presumably,
    through a raw HTTP endpoint described by api_key / api_base.
    NOTE(review): the code consuming these fields is outside this block;
    confirm the HTTP-related descriptions against it.
    """
    # Internal display name of the provider (the values of _PROVIDER_ID_TO_DISPLAY).
    name: str
    # Credential for the provider — presumably sent with raw HTTP calls; confirm.
    api_key: str
    # Base URL of the provider's API — presumably used for raw HTTP calls; confirm.
    api_base: str
    # Additional headers; each instance gets its own empty dict by default.
    extra_headers: dict = field(default_factory=dict)
    # Model name to use instead of the default, if any — TODO confirm against caller.
    model_override: Optional[str] = None
    use_bot: bool = False  # When True, call via bot.call_vision instead of raw HTTP
    fallback_bot: Any = None  # Bot instance for non-main-model providers
|
||||
|
||||
|
||||
class VisionAPIError(Exception):
    """Signals a failed Vision API call that should trigger provider fallback."""
|
||||
|
||||
|
||||
class Vision(BaseTool):
|
||||
"""Analyze images using Vision API"""
|
||||
|
||||
name: str = "vision"
|
||||
description: str = (
|
||||
"Analyze a local image or image URL (jpg/jpeg/png) using Vision API. "
|
||||
"Can describe content, extract text, identify objects, colors, etc. "
|
||||
)
|
||||
|
||||
params: dict = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"image": {
|
||||
"type": "string",
|
||||
"description": "Local file path or HTTP(S) URL of the image to analyze",
|
||||
},
|
||||
"question": {
|
||||
"type": "string",
|
||||
"description": "Question to ask about the image",
|
||||
},
|
||||
},
|
||||
"required": ["image", "question"],
|
||||
}
|
||||
|
||||
def __init__(self, config: dict = None):
|
||||
self.config = config or {}
|
||||
|
||||
@staticmethod
|
||||
def is_available() -> bool:
|
||||
return True
|
||||
|
||||
def execute(self, args: Dict[str, Any]) -> ToolResult:
    """
    Analyze an image with the first Vision provider that succeeds.

    Args:
        args: Tool arguments. ``image`` is a local file path or HTTP(S) URL,
            ``question`` is the prompt to ask about the image.

    Returns:
        The successful provider's ToolResult, or a failed ToolResult when an
        argument is missing, no provider is available, the image cannot be
        prepared, or every provider fails.
    """
    # dict.get only falls back to its default when the key is absent. A caller
    # that sends an explicit null (or a non-string value) must get a clean
    # "parameter is required" error rather than an AttributeError on .strip().
    image = str(args.get("image") or "").strip()
    question = str(args.get("question") or "").strip()

    if not image:
        return ToolResult.fail("Error: 'image' parameter is required")
    if not question:
        return ToolResult.fail("Error: 'question' parameter is required")

    providers = self._resolve_providers()
    if not providers:
        return ToolResult.fail(
            "Error: No model available for Vision.\n"
            "The main model does not support vision and no other API keys are configured.\n"
            "Options:\n"
            " 1. Switch to a multimodal model (e.g. ernie-4.5-turbo-vl, qwen3.6-plus, claude-sonnet-4-6, gemini-2.0-flash)\n"
            " 2. Configure OPENAI_API_KEY: env_config(action=\"set\", key=\"OPENAI_API_KEY\", value=\"your-key\")\n"
            " 3. Configure LINKAI_API_KEY: env_config(action=\"set\", key=\"LINKAI_API_KEY\", value=\"your-key\")"
        )

    # Any failure while loading/downloading/encoding the image is reported to
    # the caller as a tool failure instead of propagating.
    try:
        image_content = self._build_image_content(image)
    except Exception as e:
        return ToolResult.fail(f"Error: {e}")

    # Default model is only used as a last-resort placeholder for providers
    # whose VisionProvider.model_override is None (e.g. raw OpenAI provider
    # when the user did not configure tools.vision.model).
    return self._call_with_fallback(providers, DEFAULT_MODEL, question, image_content)
|
||||
|
||||
def _call_with_fallback(self, providers: List[VisionProvider], model: str,
                        question: str, image_content: dict) -> ToolResult:
    """
    Walk the provider list in order and return the first successful result.

    Each provider uses its own ``model_override`` when set, otherwise
    ``model``. Every failure is recorded and the next provider is tried; if
    none succeeds, a failed ToolResult listing all collected errors is
    returned.
    """
    total = len(providers)
    failures: List[str] = []

    for position, provider in enumerate(providers, start=1):
        use_model = provider.model_override or model
        tag = f"[{provider.name}/{use_model}]"
        try:
            logger.info(f"[Vision] Trying provider '{provider.name}' "
                        f"with model '{use_model}' ({position}/{total})")
            # Bot-backed providers go through call_vision; the rest use raw HTTP.
            if provider.use_bot:
                outcome = self._call_via_bot(use_model, question, image_content, provider)
            else:
                outcome = self._call_api(provider, use_model, question, image_content)
            logger.info(f"[Vision] ✅ Success via {provider.name} (model={use_model})")
            return outcome
        except VisionAPIError as exc:
            failures.append(f"{tag} {exc}")
            logger.warning(f"[Vision] Provider '{provider.name}' failed: {exc}")
        except requests.Timeout:
            failures.append(f"{tag} Request timed out after {DEFAULT_TIMEOUT}s")
            logger.warning(f"[Vision] Provider '{provider.name}' timed out")
        except requests.ConnectionError:
            failures.append(f"{tag} Connection failed")
            logger.warning(f"[Vision] Provider '{provider.name}' connection failed")
        except Exception as exc:
            failures.append(f"{tag} {exc}")
            logger.error(f"[Vision] Provider '{provider.name}' unexpected error: {exc}", exc_info=True)

    summary = "\n".join(f" - {err}" for err in failures)
    return ToolResult.fail("Error: All Vision API providers failed.\n" + summary)
|
||||
|
||||
def _resolve_providers(self) -> List[VisionProvider]:
    """
    Build the ordered list of providers to try.

    ``tools.vision.model`` means "prefer this model first, then fall back to
    the other configured providers if it fails".

    Order:
      1. Preferred provider(s) for the user's model: routed by the explicit
         ``tools.vision.provider`` id when given, otherwise inferred from the
         model-name prefix. The user-specified model name is used verbatim.
      2. Auto-discovery chain as fallback:
         - use_linkai=true -> [LinkAI, MainModel?, OtherModels..., OpenAI]
         - default         -> [MainModel?, OtherModels..., OpenAI, LinkAI]
         MainModel is only included when the main bot is known to support
         vision (see _main_bot_supports_vision).

    Fallback entries whose display name is already in the list are dropped so
    the same endpoint is not retried twice.
    """
    user_model = self._resolve_user_vision_model()
    user_provider = self._resolve_user_vision_provider()

    # Step 1: preferred provider. An explicit provider id wins so custom model
    # names can still be routed; prefix inference is the fallback.
    preferred = None
    if user_model:
        if user_provider:
            preferred = self._route_by_provider_id(user_provider, user_model)
        if not preferred:
            preferred = self._route_by_model_name(user_model)
    ordered: List[VisionProvider] = list(preferred) if preferred else []

    # Step 2: auto-discovery chain.
    seen = {p.name for p in ordered}
    linkai_first = conf().get("use_linkai", False) and conf().get("linkai_api_key")

    def build_linkai():
        return self._build_linkai_provider(user_model)

    def build_openai():
        return self._build_openai_provider(user_model)

    candidates: List[VisionProvider] = []
    if linkai_first:
        self._append_provider(candidates, build_linkai)
    self._append_provider(candidates, self._build_main_model_provider)
    self._append_other_model_providers(candidates, preferred_model=user_model)
    self._append_provider(candidates, build_openai)
    if not linkai_first:
        self._append_provider(candidates, build_linkai)

    for candidate in candidates:
        if candidate.name not in seen:
            ordered.append(candidate)
            seen.add(candidate.name)

    return ordered
|
||||
|
||||
@staticmethod
|
||||
def _append_provider(providers: List[VisionProvider], builder) -> None:
|
||||
p = builder()
|
||||
if p:
|
||||
providers.append(p)
|
||||
|
||||
@staticmethod
def _resolve_user_vision_model() -> Optional[str]:
    """Return the stripped ``tools.vision.model`` setting, or None when unset.

    The singular ``tool`` key is accepted as a runtime fallback for ``tools``.
    Any non-dict section or non-string / blank value yields None.
    """
    settings = conf().get("tools") or conf().get("tool") or {}
    if not isinstance(settings, dict):
        return None
    vision_section = settings.get("vision", {})
    if not isinstance(vision_section, dict):
        return None
    value = vision_section.get("model")
    if not isinstance(value, str):
        return None
    return value.strip() or None
|
||||
|
||||
@staticmethod
def _resolve_user_vision_provider() -> Optional[str]:
    """Return the stripped ``tools.vision.provider`` setting (UI vendor id).

    This lets users pin a vendor for custom model names that prefix inference
    cannot recognize. The singular ``tool`` key is accepted as a fallback for
    ``tools``. Returns None when the value is unset, blank, or not a string.
    """
    settings = conf().get("tools") or conf().get("tool") or {}
    if not isinstance(settings, dict):
        return None
    vision_section = settings.get("vision", {})
    if not isinstance(vision_section, dict):
        return None
    value = vision_section.get("provider")
    if not isinstance(value, str):
        return None
    return value.strip() or None
|
||||
|
||||
@staticmethod
def _infer_provider_from_model(model_name: str) -> Optional[str]:
    """
    Map a model name to a provider display name via its prefix.

    Matching is case-insensitive and the longest matching prefix wins.
    Returns None for an empty name or when no prefix rule matches; OpenAI
    family names have no rule here and are handled by the caller.
    """
    if not model_name:
        return None
    lowered = model_name.lower()
    matches = [
        (prefix, display_name)
        for prefix, display_name in _MODEL_PREFIX_TO_PROVIDER
        if lowered.startswith(prefix.lower())
    ]
    if not matches:
        return None
    # max() keeps the first of equally long prefixes, i.e. declaration order.
    return max(matches, key=lambda pair: len(pair[0]))[1]
|
||||
|
||||
def _route_by_provider_id(self, provider_id: str, user_model: str) -> Optional[List[VisionProvider]]:
    """Build the preferred provider from the UI-persisted provider id.

    Returns:
        - [provider] : the id is known and the provider could be built.
        - None       : unknown id, missing API key, or the bot could not be
                       created / lacks call_vision. The caller then falls
                       through to model-name-based routing.
    """
    display_name = _PROVIDER_ID_TO_DISPLAY.get(provider_id)
    if not display_name:
        return None

    # OpenAI / LinkAI use raw HTTP providers, not the discoverable bot path.
    raw_builders = {
        "openai": self._build_openai_provider,
        "linkai": self._build_linkai_provider,
    }
    if provider_id in raw_builders:
        built = raw_builders[provider_id](user_model)
        return [built] if built else None

    # Discoverable bot-backed providers: only the first matching row counts.
    entry = next((row for row in _DISCOVERABLE_MODELS if row[3] == display_name), None)
    if entry is None:
        return None
    config_key, bot_type = entry[0], entry[1]

    api_key = conf().get(config_key, "")
    if not (api_key and api_key.strip()):
        logger.warning(f"[Vision] tools.vision.provider='{provider_id}' "
                       f"but '{config_key}' is not configured. Falling back.")
        return None

    try:
        from models.bot_factory import create_bot
        bot = create_bot(bot_type)
        if not hasattr(bot, 'call_vision'):
            logger.warning(f"[Vision] '{display_name}' bot does not implement call_vision.")
            return None
    except Exception as e:
        logger.warning(f"[Vision] Failed to create '{display_name}' bot: {e}")
        return None

    routed = VisionProvider(
        name=display_name,
        api_key="",
        api_base="",
        model_override=user_model,
        use_bot=True,
        fallback_bot=bot,
    )
    return [routed]
|
||||
|
||||
def _route_by_model_name(self, user_model: str) -> Optional[List[VisionProvider]]:
    """
    Build the preferred provider list from the user-specified model name.

    Returns:
        - non-empty list : the model's family was recognized and at least
                           one matching provider could be built.
        - None           : no prefix rule matches, the matching provider's
                           key is missing, or its bot is unusable. A warning
                           is logged for the latter cases and the caller
                           falls through to auto-discovery.
    """
    lowered = user_model.lower()

    # OpenAI / LinkAI family
    if lowered.startswith(_OPENAI_MODEL_PREFIXES):
        def build_linkai():
            return self._build_linkai_provider(user_model)

        def build_openai():
            return self._build_openai_provider(user_model)

        # Prefer LinkAI when explicitly enabled, else OpenAI first.
        linkai_first = conf().get("use_linkai", False) and conf().get("linkai_api_key")
        builders = [build_linkai, build_openai] if linkai_first else [build_openai, build_linkai]

        routed: List[VisionProvider] = []
        for builder in builders:
            self._append_provider(routed, builder)
        if routed:
            return routed
        logger.warning(f"[Vision] tools.vision.model='{user_model}' looks like an OpenAI "
                       f"model but neither OPENAI_API_KEY nor LINKAI_API_KEY is configured.")
        return None  # fall through to auto

    # Discoverable native providers (Doubao, Moonshot, etc.)
    target_display = self._infer_provider_from_model(user_model)
    if not target_display:
        return None  # unknown prefix -> auto

    entry = next((row for row in _DISCOVERABLE_MODELS if row[3] == target_display), None)
    if entry is None:
        return None
    config_key, bot_type, display_name = entry[0], entry[1], entry[3]

    api_key = conf().get(config_key, "")
    if not (api_key and api_key.strip()):
        logger.warning(f"[Vision] tools.vision.model='{user_model}' routes to "
                       f"'{display_name}' but '{config_key}' is not configured. "
                       f"Falling back to auto-discovery.")
        return None  # fall through to auto

    try:
        from models.bot_factory import create_bot
        bot = create_bot(bot_type)
        if not hasattr(bot, 'call_vision'):
            logger.warning(f"[Vision] '{display_name}' bot does not implement call_vision.")
            return None
    except Exception as e:
        logger.warning(f"[Vision] Failed to create '{display_name}' bot: {e}")
        return None

    return [VisionProvider(
        name=display_name,
        api_key="",
        api_base="",
        model_override=user_model,
        use_bot=True,
        fallback_bot=bot,
    )]
|
||||
|
||||
def _append_other_model_providers(self, providers: List[VisionProvider],
                                  preferred_model: Optional[str] = None) -> None:
    """
    Add every discoverable provider whose API key is configured.

    ``providers`` is mutated in place. A provider is skipped when its display
    name is already listed, when its key is blank, when its bot cannot be
    created or lacks ``call_vision``, or when it shares the main model's
    bot_type and the main model already supports vision (the MainModel
    provider covers that case).

    When ``preferred_model`` belongs to a provider's family, that name is
    used instead of the provider's hard-coded default model.
    """
    main_bot_type = None
    main_supports_vision = False
    if self.model and hasattr(self.model, '_resolve_bot_type'):
        main_bot_type = self.model._resolve_bot_type(conf().get("model", ""))
        main_supports_vision = self._main_bot_supports_vision(getattr(self.model, "bot", None))

    already_listed = {p.name for p in providers}
    preferred_family = self._infer_provider_from_model(preferred_model) if preferred_model else None

    for config_key, bot_type, default_model, display_name in _DISCOVERABLE_MODELS:
        same_vendor_as_main = bot_type == main_bot_type
        # The main model's own vendor is only skipped when the main model can
        # do vision itself; otherwise its dedicated vision model is a fallback.
        if display_name in already_listed or (same_vendor_as_main and main_supports_vision):
            continue

        api_key = conf().get(config_key, "")
        if not (api_key and api_key.strip()):
            continue

        try:
            from models.bot_factory import create_bot
            bot = create_bot(bot_type)
            if not hasattr(bot, 'call_vision'):
                continue
        except Exception:
            continue

        chosen_model = default_model
        if preferred_model and preferred_family == display_name:
            chosen_model = preferred_model

        entry = VisionProvider(
            name=display_name,
            api_key="",
            api_base="",
            model_override=chosen_model,
            use_bot=True,
            fallback_bot=bot,
        )

        # The main bot's own vendor is the most natural fallback when the main
        # model lacks vision, so it goes to the front rather than relying on
        # declaration order.
        if same_vendor_as_main:
            providers.insert(0, entry)
        else:
            providers.append(entry)
|
||||
|
||||
def _main_bot_supports_vision(self, bot) -> bool:
|
||||
"""
|
||||
Whether the main bot is known to natively support vision.
|
||||
|
||||
Having a `call_vision` method is necessary but not sufficient —
|
||||
some bots implement the method against an endpoint that does not
|
||||
actually serve vision models, which causes silent failures when a
|
||||
vendor-foreign model name is forwarded.
|
||||
|
||||
Resolution order:
|
||||
1. If the bot explicitly declares `supports_vision`, trust it.
|
||||
This lets bots opt in or out based on their own runtime
|
||||
configuration (e.g. the currently selected model).
|
||||
2. Otherwise, fall back to a model-name prefix heuristic: trust
|
||||
call_vision when the main model looks like an OpenAI family
|
||||
model or matches a known multimodal vendor prefix.
|
||||
"""
|
||||
if bot is None:
|
||||
return False
|
||||
if hasattr(bot, "supports_vision"):
|
||||
return bool(getattr(bot, "supports_vision"))
|
||||
main_model = (conf().get("model") or "").lower()
|
||||
if not main_model:
|
||||
return False
|
||||
if main_model.startswith(_OPENAI_MODEL_PREFIXES):
|
||||
return True
|
||||
return self._infer_provider_from_model(main_model) is not None
|
||||
|
||||
def _build_main_model_provider(self) -> Optional[VisionProvider]:
|
||||
"""
|
||||
Use the vendor's own model for vision via bot.call_vision.
|
||||
Gated by _main_bot_supports_vision so non-vision bots (DeepSeek, etc.)
|
||||
do not get routed vendor-foreign model names.
|
||||
"""
|
||||
if not (self.model and hasattr(self.model, 'bot')):
|
||||
return None
|
||||
try:
|
||||
bot = self.model.bot
|
||||
except Exception:
|
||||
return None
|
||||
if not hasattr(bot, 'call_vision'):
|
||||
return None
|
||||
if not self._main_bot_supports_vision(bot):
|
||||
return None
|
||||
|
||||
# Use the configured main model name; do NOT inject tools.vision.model
|
||||
# here, because by the time we reach this branch the tools.vision.model
|
||||
# routing has already been attempted (and either matched the main bot
|
||||
# or failed to find a provider).
|
||||
main_model_name = conf().get("model") or None
|
||||
|
||||
return VisionProvider(
|
||||
name=_MAIN_MODEL_PROVIDER_NAME,
|
||||
api_key="",
|
||||
api_base="",
|
||||
model_override=main_model_name,
|
||||
use_bot=True,
|
||||
)
|
||||
|
||||
def _build_openai_provider(self, preferred_model: Optional[str] = None) -> Optional[VisionProvider]:
    """
    Build the raw-HTTP OpenAI provider, or None when no API key is available.

    The key and base URL come from the config first, then from the
    OPENAI_API_KEY / OPENAI_API_BASE environment variables; the base
    defaults to the public OpenAI endpoint.
    """
    api_key = conf().get("open_ai_api_key") or os.environ.get("OPENAI_API_KEY")
    if not api_key:
        return None

    configured_base = conf().get("open_ai_api_base") or os.environ.get("OPENAI_API_BASE", "")
    api_base = configured_base.rstrip("/") or "https://api.openai.com/v1"

    # Only honor preferred_model when it looks like an OpenAI-family name;
    # otherwise the OpenAI endpoint would 400 on a vendor-specific name.
    model_override = None
    if preferred_model and preferred_model.lower().startswith(_OPENAI_MODEL_PREFIXES):
        model_override = preferred_model

    return VisionProvider(
        name="OpenAI",
        api_key=api_key,
        api_base=self._ensure_v1(api_base),
        model_override=model_override,
    )
|
||||
|
||||
def _build_linkai_provider(self, preferred_model: Optional[str] = None) -> Optional[VisionProvider]:
    """
    Build the raw-HTTP LinkAI provider, or None when no API key is available.

    The key and base URL come from the config first, then from the
    LINKAI_API_KEY / LINKAI_API_BASE environment variables. Cloud headers
    are attached as extra headers, minus the two that the request code sets
    itself.
    """
    api_key = conf().get("linkai_api_key") or os.environ.get("LINKAI_API_KEY")
    if not api_key:
        return None

    configured_base = conf().get("linkai_api_base") or os.environ.get("LINKAI_API_BASE", "")
    api_base = configured_base.rstrip("/") or "https://api.link-ai.tech"

    from common.utils import get_cloud_headers
    extra_headers = get_cloud_headers(api_key)
    for header in ("Authorization", "Content-Type"):
        extra_headers.pop(header, None)

    # LinkAI is a multi-vendor proxy and accepts most model names, so any
    # user-configured model name is honored here.
    return VisionProvider(
        name="LinkAI",
        api_key=api_key,
        api_base=self._ensure_v1(api_base),
        extra_headers=extra_headers,
        model_override=preferred_model,
    )
|
||||
|
||||
def _call_via_bot(self, model: str, question: str, image_content: dict,
                  provider: Optional[VisionProvider] = None) -> ToolResult:
    """
    Call a bot's ``call_vision`` using the vendor-native API format.

    Uses ``provider.fallback_bot`` when it is set, otherwise the main
    model's bot.

    Args:
        model: Model name forwarded to ``call_vision``.
        question: Question to ask about the image.
        image_content: OpenAI-format ``image_url`` content block.
        provider: Provider being tried; None means the main model.

    Returns:
        A successful ToolResult with ``model``, ``provider``, ``content``
        and ``usage``.

    Raises:
        VisionAPIError: on any failure, so the caller can fall back to the
            next provider. Underlying exceptions are chained as the cause.
    """
    try:
        bot = (provider and provider.fallback_bot) or self.model.bot
    except Exception as e:
        # Chain the cause so the original traceback is not lost.
        raise VisionAPIError(f"Cannot access bot: {e}") from e

    # Extract the raw image URL from the OpenAI-format image_content block
    image_url = image_content.get("image_url", {}).get("url", "")
    if not image_url:
        raise VisionAPIError("No image URL in content block")

    try:
        response = bot.call_vision(
            image_url=image_url,
            question=question,
            model=model,
            max_tokens=MAX_TOKENS,
        )
    except Exception as e:
        raise VisionAPIError(f"call_vision failed: {e}") from e

    if response is NotImplemented:
        raise VisionAPIError("Bot does not support vision")

    is_mapping = isinstance(response, dict)
    if is_mapping and response.get("error"):
        raise VisionAPIError(f"API error - {response.get('message', 'Unknown')}")

    # Anything that is not a dict carries no usable content.
    content = response.get("content", "") if is_mapping else ""
    if not content:
        raise VisionAPIError("Empty response from main model")

    usage_info = response.get("usage", {}) if is_mapping else {}

    # Use the actual model name from the bot response if available
    actual_model = response.get("model", model) if is_mapping else model
    provider_name = provider.name if provider else _MAIN_MODEL_PROVIDER_NAME
    return ToolResult.success({
        "model": actual_model,
        "provider": provider_name,
        "content": content,
        "usage": usage_info,
    })
|
||||
|
||||
@staticmethod
|
||||
def _ensure_v1(api_base: str) -> str:
|
||||
"""Append /v1 if the base URL doesn't already end with a versioned path."""
|
||||
if not api_base:
|
||||
return api_base
|
||||
# Already has /v1 or similar version suffix
|
||||
if api_base.rstrip("/").split("/")[-1].startswith("v"):
|
||||
return api_base
|
||||
return api_base.rstrip("/") + "/v1"
|
||||
|
||||
def _build_image_content(self, image: str) -> dict:
    """
    Build the OpenAI-format ``image_url`` content block for an image.

    Both remote URLs and local files end up as base64 data URLs so every
    bot backend can consume them without extra downloads.

    Args:
        image: HTTP(S) URL or local file path.

    Returns:
        ``{"type": "image_url", "image_url": {"url": "data:<mime>;base64,..."}}``

    Raises:
        FileNotFoundError: the local path is not an existing file.
        ValueError: the file extension is not in SUPPORTED_EXTENSIONS.
    """
    if image.startswith(("http://", "https://")):
        return self._download_to_data_url(image)

    if not os.path.isfile(image):
        raise FileNotFoundError(f"Image file not found: {image}")

    ext = image.rsplit(".", 1)[-1].lower() if "." in image else ""
    mime_type = SUPPORTED_EXTENSIONS.get(ext)
    if not mime_type:
        raise ValueError(
            f"Unsupported image format '.{ext}'. "
            f"Supported: {', '.join(SUPPORTED_EXTENSIONS.keys())}"
        )

    file_path = self._maybe_compress(image)
    was_compressed = file_path != image
    try:
        with open(file_path, "rb") as f:
            raw = f.read()
    finally:
        # A compressed copy is a temp file owned by this call; always clean up.
        if was_compressed and os.path.exists(file_path):
            os.remove(file_path)

    if was_compressed:
        # The compressor writes to a ".jpg" temp file and may re-encode the
        # image, so the MIME type derived from the source extension can be
        # wrong. Trust the magic bytes of what was actually produced; keep the
        # extension-based type when the signature is not recognized.
        if raw.startswith(b"\xff\xd8\xff"):
            mime_type = "image/jpeg"
        elif raw.startswith(b"\x89PNG\r\n\x1a\n"):
            mime_type = "image/png"

    b64 = base64.b64encode(raw).decode("ascii")
    data_url = f"data:{mime_type};base64,{b64}"
    return {"type": "image_url", "image_url": {"url": data_url}}
|
||||
|
||||
@staticmethod
def _download_to_data_url(url: str) -> dict:
    """Fetch a remote image and wrap it in a base64 data-URL content block.

    The MIME type comes from the response's Content-Type header and falls
    back to ``image/jpeg`` when the header is missing or not an image type.

    Raises:
        VisionAPIError: the server answered with a status other than 200.
    """
    response = requests.get(url, timeout=30)
    if response.status_code != 200:
        raise VisionAPIError(f"Failed to download image: HTTP {response.status_code}")

    declared = response.headers.get("Content-Type", "image/jpeg")
    mime = declared.split(";")[0].strip()
    if not mime.startswith("image/"):
        mime = "image/jpeg"

    encoded = base64.b64encode(response.content).decode("ascii")
    return {
        "type": "image_url",
        "image_url": {"url": f"data:{mime};base64,{encoded}"},
    }
|
||||
|
||||
@staticmethod
def _maybe_compress(path: str) -> str:
    """Shrink an image towards COMPRESS_THRESHOLD with a 1536px long edge.

    Files already at or below the threshold are returned untouched.
    Otherwise ``sips`` is tried first and ``convert`` second, at decreasing
    quality, writing into a temporary ``.jpg`` file.

    Returns:
        The original ``path`` when no compression was needed or possible,
        otherwise the temp file path (the caller must delete it). If no
        attempt gets under the threshold, the last output written is
        returned anyway.
    """
    file_size = os.path.getsize(path)
    if file_size <= COMPRESS_THRESHOLD:
        return path

    with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as handle:
        out_path = handle.name

    def _run(command) -> bool:
        # A missing tool or a non-zero exit both count as "did not work".
        try:
            subprocess.run(command, capture_output=True, check=True)
        except (FileNotFoundError, subprocess.CalledProcessError):
            return False
        return True

    attempts = [
        ("1536", "85"),
        ("1536", "70"),
        ("1536", "50"),
    ]

    for max_dim, quality in attempts:
        sips_cmd = ["sips", "-Z", max_dim, "-s", "formatOptions", quality,
                    path, "--out", out_path]
        convert_cmd = ["convert", path, "-resize", f"{max_dim}x{max_dim}>",
                       "-quality", quality, out_path]
        if not (_run(sips_cmd) or _run(convert_cmd)):
            continue
        new_size = os.path.getsize(out_path)
        logger.debug(f"[Vision] Compressed image "
                     f"({file_size // 1024}KB -> {new_size // 1024}KB, "
                     f"max_dim={max_dim}, q={quality})")
        if new_size <= COMPRESS_THRESHOLD:
            return out_path

    # Still too large: a non-empty output is better than nothing.
    if os.path.exists(out_path) and os.path.getsize(out_path) > 0:
        return out_path

    os.remove(out_path)
    return path
|
||||
|
||||
def _call_api(self, provider: VisionProvider, model: str,
              question: str, image_content: dict) -> ToolResult:
    """
    Call one provider's OpenAI-compatible ``/chat/completions`` endpoint.

    Args:
        provider: Endpoint, credentials and extra headers to use.
        model: Model name sent in the payload.
        question: Question to ask about the image.
        image_content: OpenAI-format ``image_url`` content block.

    Returns:
        A successful ToolResult with ``model``, ``provider``, ``content``
        and token ``usage``.

    Raises:
        VisionAPIError: on a non-200 status, an unparsable or unexpected
            body, an API-level error, or an empty answer, so the caller can
            try the next provider. Network errors raised by ``requests``
            propagate unchanged.
    """
    payload = {
        "model": model,
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": question},
                    image_content,
                ],
            }
        ],
    }

    headers = {
        "Authorization": f"Bearer {provider.api_key}",
        "Content-Type": "application/json",
        **provider.extra_headers,
    }

    resp = requests.post(
        f"{provider.api_base}/chat/completions",
        headers=headers,
        json=payload,
        timeout=DEFAULT_TIMEOUT,
    )

    if resp.status_code != 200:
        raise VisionAPIError(f"HTTP {resp.status_code}: {resp.text[:200]}")

    # A 200 with a non-JSON body (e.g. a proxy's HTML page) is a provider
    # failure like any other, not an unexpected crash.
    try:
        data = resp.json()
    except ValueError as e:
        raise VisionAPIError(f"Invalid JSON response: {resp.text[:200]}") from e
    if not isinstance(data, dict):
        raise VisionAPIError("Unexpected response format")

    error = data.get("error")
    if error:
        # Some gateways return the error as a plain string instead of an object.
        if isinstance(error, dict):
            msg = error.get("message", "Unknown API error")
        else:
            msg = str(error)
        raise VisionAPIError(f"API error - {msg}")

    choices = data.get("choices") or []
    message = (choices[0].get("message") or {}) if choices else {}
    content = message.get("content") or ""
    if not content:
        # Mirror _call_via_bot: an empty answer must trigger fallback rather
        # than be reported as a successful analysis.
        raise VisionAPIError("Empty response from API")

    usage = data.get("usage") or {}
    result = {
        "model": model,
        "provider": provider.name,
        "content": content,
        "usage": {
            "prompt_tokens": usage.get("prompt_tokens", 0),
            "completion_tokens": usage.get("completion_tokens", 0),
            "total_tokens": usage.get("total_tokens", 0),
        },
    }
    return ToolResult.success(result)
|
||||
0
agent/tools/web_fetch/__init__.py
Normal file
0
agent/tools/web_fetch/__init__.py
Normal file
444
agent/tools/web_fetch/web_fetch.py
Normal file
444
agent/tools/web_fetch/web_fetch.py
Normal file
@@ -0,0 +1,444 @@
|
||||
"""
|
||||
Web Fetch tool - Fetch and extract readable content from web pages and remote files.
|
||||
|
||||
Supports:
|
||||
- HTML web pages: extracts readable text content
|
||||
- Document files (PDF, Word, TXT, Markdown, etc.): downloads to workspace/tmp and parses content
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
import uuid
|
||||
from typing import Dict, Any, Optional, Set
|
||||
from urllib.parse import urlparse, unquote
|
||||
|
||||
import requests
|
||||
|
||||
from agent.tools.base_tool import BaseTool, ToolResult
|
||||
from agent.tools.utils.truncate import truncate_head, format_size
|
||||
from common.log import logger
|
||||
|
||||
|
||||
DEFAULT_TIMEOUT = 30
|
||||
MAX_FILE_SIZE = 50 * 1024 * 1024 # 50MB
|
||||
|
||||
DEFAULT_HEADERS = {
|
||||
"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36",
|
||||
"Accept": "*/*",
|
||||
}
|
||||
|
||||
# Supported document file extensions
|
||||
PDF_SUFFIXES: Set[str] = {".pdf"}
|
||||
WORD_SUFFIXES: Set[str] = {".docx"}
|
||||
TEXT_SUFFIXES: Set[str] = {".txt", ".md", ".markdown", ".rst", ".csv", ".tsv", ".log"}
|
||||
SPREADSHEET_SUFFIXES: Set[str] = {".xls", ".xlsx"}
|
||||
PPT_SUFFIXES: Set[str] = {".ppt", ".pptx"}
|
||||
|
||||
ALL_DOC_SUFFIXES = PDF_SUFFIXES | WORD_SUFFIXES | TEXT_SUFFIXES | SPREADSHEET_SUFFIXES | PPT_SUFFIXES
|
||||
|
||||
_CHARSET_RE = re.compile(r'charset\s*=\s*["\']?\s*([\w\-]+)', re.IGNORECASE)
|
||||
_META_CHARSET_RE = re.compile(rb'<meta[^>]+charset\s*=\s*["\']?\s*([\w\-]+)', re.IGNORECASE)
|
||||
_META_HTTP_EQUIV_RE = re.compile(
|
||||
rb'<meta[^>]+http-equiv\s*=\s*["\']?Content-Type["\']?[^>]+content\s*=\s*["\'][^"\']*charset=([\w\-]+)',
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
|
||||
def _extract_charset_from_content_type(content_type: str) -> Optional[str]:
    """Return the charset named in a Content-Type header value, or None."""
    found = _CHARSET_RE.search(content_type)
    if found is None:
        return None
    return found.group(1)
|
||||
|
||||
|
||||
def _extract_charset_from_html_meta(raw_bytes: bytes) -> Optional[str]:
    """Return the charset declared in an HTML <meta> tag, or None.

    The plain ``<meta ... charset=...>`` form is checked before the
    ``http-equiv="Content-Type"`` form. The whole of ``raw_bytes`` is
    searched; presumably callers pass only a leading slice of the page.
    """
    for pattern in (_META_CHARSET_RE, _META_HTTP_EQUIV_RE):
        found = pattern.search(raw_bytes)
        if found:
            return found.group(1).decode("ascii", errors="ignore")
    return None
|
||||
|
||||
|
||||
def _get_url_suffix(url: str) -> str:
|
||||
"""Extract file extension from URL path, ignoring query params."""
|
||||
path = urlparse(url).path
|
||||
return os.path.splitext(path)[-1].lower()
|
||||
|
||||
|
||||
def _is_document_url(url: str) -> bool:
    """Tell whether the URL's path ends in one of the supported document suffixes."""
    return _get_url_suffix(url) in ALL_DOC_SUFFIXES
|
||||
|
||||
|
||||
class WebFetch(BaseTool):
    """Tool for fetching web pages and remote document files.

    URLs whose path ends with a known document suffix are downloaded into
    ``<cwd>/tmp`` and parsed to text. Everything else is fetched as a web
    page and reduced to a title plus plain text. If a "web page" response
    turns out to carry a supported document Content-Type, it is downloaded
    and parsed as a document instead.
    """

    name: str = "web_fetch"
    description: str = (
        "Fetch content from a http/https URL. For web pages, extracts readable text. "
        "For document files (PDF, Word, TXT, Markdown, Excel, PPT), downloads and parses the file content. "
        "Supported file types: .pdf, .docx, .txt, .md, .csv, .xls, .xlsx, .ppt, .pptx"
    )

    params: dict = {
        "type": "object",
        "properties": {
            "url": {
                "type": "string",
                "description": "The HTTP/HTTPS URL to fetch (web page or document file link)"
            }
        },
        "required": ["url"]
    }

    def __init__(self, config: dict = None):
        """Store the tool config.

        Args:
            config: Optional settings; ``config["cwd"]`` selects the workspace
                directory and defaults to the process working directory.
        """
        self.config = config or {}
        self.cwd = self.config.get("cwd", os.getcwd())

    def execute(self, args: Dict[str, Any]) -> ToolResult:
        """Validate the ``url`` argument and dispatch to the matching fetcher.

        Args:
            args: Tool arguments; only ``url`` is read.

        Returns:
            A failed ``ToolResult`` for a missing URL or a non-http(s) scheme,
            otherwise the result of the document or web-page fetch.
        """
        # `or ""` keeps an explicit `"url": None` from crashing on .strip().
        url = (args.get("url") or "").strip()
        if not url:
            return ToolResult.fail("Error: 'url' parameter is required")

        parsed = urlparse(url)
        if parsed.scheme not in ("http", "https"):
            return ToolResult.fail("Error: Invalid URL (must start with http:// or https://)")

        if _is_document_url(url):
            return self._fetch_document(url)

        return self._fetch_webpage(url)

    # ---- Web page fetching ----

    def _fetch_webpage(self, url: str) -> ToolResult:
        """Fetch and extract readable text from an HTML web page.

        Network and HTTP failures are converted into failed ``ToolResult``
        objects rather than raised.
        """
        parsed = urlparse(url)
        try:
            response = requests.get(
                url,
                headers=DEFAULT_HEADERS,
                timeout=DEFAULT_TIMEOUT,
                allow_redirects=True,
            )
            response.raise_for_status()
        except requests.Timeout:
            return ToolResult.fail(f"Error: Request timed out after {DEFAULT_TIMEOUT}s")
        except requests.ConnectionError:
            return ToolResult.fail(f"Error: Failed to connect to {parsed.netloc}")
        except requests.HTTPError as e:
            return ToolResult.fail(f"Error: HTTP {e.response.status_code} for URL: {url}")
        except Exception as e:
            return ToolResult.fail(f"Error: Failed to fetch URL: {e}")

        # A URL without a document suffix may still serve a document
        # (e.g. a download endpoint); route those by Content-Type.
        content_type = response.headers.get("Content-Type", "")
        if self._is_binary_content_type(content_type) and not _is_document_url(url):
            return self._handle_download_by_content_type(url, response, content_type)

        response.encoding = self._detect_encoding(response)
        html = response.text
        title = self._extract_title(html)
        text = self._extract_text(html)

        return ToolResult.success(f"Title: {title}\n\nContent:\n{text}")

    # ---- Document fetching ----

    def _fetch_document(self, url: str, suffix: Optional[str] = None) -> ToolResult:
        """Download a document file and extract its text content.

        Args:
            url: The URL to download. It is requested exactly as given.
            suffix: Optional document suffix (e.g. ``".pdf"``) that selects
                the parser when the URL path carries no usable extension.
                Defaults to the suffix of the URL path.

        Returns:
            A ``ToolResult`` holding a header line plus the (possibly
            truncated) extracted text, or a failure describing what went
            wrong. The downloaded file is kept on success and removed on
            download or parse failure.
        """
        explicit_suffix = suffix
        suffix = suffix or _get_url_suffix(url)
        parsed = urlparse(url)
        filename = self._extract_filename(url)
        # Keep the saved file's extension consistent with the parser chosen
        # for it when the suffix came from the caller, not from the URL.
        if explicit_suffix and not filename.lower().endswith(explicit_suffix):
            filename += explicit_suffix
        tmp_dir = self._ensure_tmp_dir()

        local_path = os.path.join(tmp_dir, filename)
        logger.info(f"[WebFetch] Downloading document: {url} -> {local_path}")

        response = None
        try:
            response = requests.get(
                url,
                headers=DEFAULT_HEADERS,
                timeout=DEFAULT_TIMEOUT,
                stream=True,
                allow_redirects=True,
            )
            response.raise_for_status()

            # A malformed Content-Length must not abort the download; the
            # streaming size check below still enforces the limit.
            try:
                content_length = int(response.headers.get("Content-Length", 0))
            except (TypeError, ValueError):
                content_length = 0
            if content_length > MAX_FILE_SIZE:
                return ToolResult.fail(
                    f"Error: File too large ({format_size(content_length)} > {format_size(MAX_FILE_SIZE)})"
                )

            downloaded = 0
            too_large = False
            with open(local_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    downloaded += len(chunk)
                    if downloaded > MAX_FILE_SIZE:
                        too_large = True
                        break
                    f.write(chunk)
            if too_large:
                # The file handle is closed here, so removal is safe.
                self._cleanup_file(local_path)
                return ToolResult.fail(
                    f"Error: File too large (>{format_size(MAX_FILE_SIZE)}), download aborted"
                )

        except requests.Timeout:
            self._cleanup_file(local_path)
            return ToolResult.fail(f"Error: Download timed out after {DEFAULT_TIMEOUT}s")
        except requests.ConnectionError:
            self._cleanup_file(local_path)
            return ToolResult.fail(f"Error: Failed to connect to {parsed.netloc}")
        except requests.HTTPError as e:
            self._cleanup_file(local_path)
            return ToolResult.fail(f"Error: HTTP {e.response.status_code} for URL: {url}")
        except Exception as e:
            self._cleanup_file(local_path)
            return ToolResult.fail(f"Error: Failed to download file: {e}")
        finally:
            # With stream=True the connection is held until the body is
            # consumed or the response is closed.
            if response is not None:
                response.close()

        try:
            text = self._parse_document(local_path, suffix)
        except Exception as e:
            self._cleanup_file(local_path)
            return ToolResult.fail(f"Error: Failed to parse document: {e}")

        if not text or not text.strip():
            file_size = os.path.getsize(local_path)
            return ToolResult.success(
                f"File downloaded to: {local_path} ({format_size(file_size)})\n"
                f"No text content could be extracted. The file may contain only images or be encrypted."
            )

        truncation = truncate_head(text)
        result_text = truncation.content

        file_size = os.path.getsize(local_path)
        header = f"[Document: {filename} | Size: {format_size(file_size)} | Saved to: {local_path}]\n\n"

        if truncation.truncated:
            header += f"[Content truncated: showing {truncation.output_lines} of {truncation.total_lines} lines]\n\n"

        return ToolResult.success(header + result_text)

    def _parse_document(self, file_path: str, suffix: str) -> str:
        """Parse a document file and return its extracted text.

        The parser is selected by ``suffix``; unknown suffixes fall back to
        plain-text decoding.
        """
        if suffix in PDF_SUFFIXES:
            return self._parse_pdf(file_path)
        elif suffix in WORD_SUFFIXES:
            return self._parse_word(file_path)
        elif suffix in TEXT_SUFFIXES:
            return self._parse_text(file_path)
        elif suffix in SPREADSHEET_SUFFIXES:
            return self._parse_spreadsheet(file_path)
        elif suffix in PPT_SUFFIXES:
            return self._parse_ppt(file_path)
        else:
            return self._parse_text(file_path)

    def _parse_pdf(self, file_path: str) -> str:
        """Extract text from a PDF using pypdf, one labelled section per non-empty page.

        Raises:
            ImportError: If pypdf is not installed.
        """
        try:
            from pypdf import PdfReader
        except ImportError:
            raise ImportError("pypdf library is required for PDF parsing. Install with: pip install pypdf")

        reader = PdfReader(file_path)
        text_parts = []
        for page_num, page in enumerate(reader.pages, 1):
            page_text = page.extract_text()
            if page_text and page_text.strip():
                text_parts.append(f"--- Page {page_num}/{len(reader.pages)} ---\n{page_text}")

        return "\n\n".join(text_parts)

    def _parse_word(self, file_path: str) -> str:
        """Extract the non-empty paragraphs of a Word document (.docx).

        Raises:
            ImportError: If python-docx is not installed.
        """
        try:
            from docx import Document
        except ImportError:
            raise ImportError(
                "python-docx library is required for .docx parsing. Install with: pip install python-docx"
            )
        doc = Document(file_path)
        paragraphs = [p.text for p in doc.paragraphs if p.text.strip()]
        return "\n\n".join(paragraphs)

    def _parse_text(self, file_path: str) -> str:
        """Read a plain text file (txt, md, csv, etc.), trying several encodings in order.

        Raises:
            ValueError: If none of the candidate encodings can decode the file.
        """
        encodings = ["utf-8", "utf-8-sig", "gbk", "gb2312", "latin-1"]
        for enc in encodings:
            try:
                with open(file_path, "r", encoding=enc) as f:
                    return f.read()
            except (UnicodeDecodeError, UnicodeError):
                continue
        raise ValueError(f"Unable to decode file with any supported encoding: {encodings}")

    def _parse_spreadsheet(self, file_path: str) -> str:
        """Extract text from Excel files, one labelled section per non-empty sheet.

        Cells are joined with `` | `` and rows with newlines; rows whose cells
        are all empty are skipped.

        Raises:
            ImportError: If openpyxl is not installed.
        """
        try:
            import openpyxl
        except ImportError:
            raise ImportError(
                "openpyxl library is required for .xlsx parsing. Install with: pip install openpyxl"
            )

        wb = openpyxl.load_workbook(file_path, read_only=True, data_only=True)
        result_parts = []
        try:
            for sheet_name in wb.sheetnames:
                ws = wb[sheet_name]
                rows = []
                for row in ws.iter_rows(values_only=True):
                    cells = [str(c) if c is not None else "" for c in row]
                    if any(cells):
                        rows.append(" | ".join(cells))
                if rows:
                    result_parts.append(f"--- Sheet: {sheet_name} ---\n" + "\n".join(rows))
        finally:
            # Release the workbook even when a sheet fails to iterate.
            wb.close()

        return "\n\n".join(result_parts)

    def _parse_ppt(self, file_path: str) -> str:
        """Extract text from PowerPoint files, one labelled section per slide with text.

        Raises:
            ImportError: If python-pptx is not installed.
        """
        try:
            from pptx import Presentation
        except ImportError:
            raise ImportError(
                "python-pptx library is required for .pptx parsing. Install with: pip install python-pptx"
            )

        prs = Presentation(file_path)
        text_parts = []

        for slide_num, slide in enumerate(prs.slides, 1):
            slide_texts = []
            for shape in slide.shapes:
                if shape.has_text_frame:
                    for paragraph in shape.text_frame.paragraphs:
                        text = paragraph.text.strip()
                        if text:
                            slide_texts.append(text)
            if slide_texts:
                text_parts.append(f"--- Slide {slide_num}/{len(prs.slides)} ---\n" + "\n".join(slide_texts))

        return "\n\n".join(text_parts)

    # ---- Encoding detection ----

    @staticmethod
    def _detect_encoding(response: requests.Response) -> str:
        """Detect response encoding with priority: Content-Type header > HTML meta > chardet > utf-8."""
        # 1. Check Content-Type header for explicit charset
        content_type = response.headers.get("Content-Type", "")
        charset = _extract_charset_from_content_type(content_type)
        if charset:
            return charset

        # 2. Scan raw bytes for HTML meta charset declaration
        raw = response.content[:4096]
        charset = _extract_charset_from_html_meta(raw)
        if charset:
            return charset

        # 3. Use apparent_encoding (chardet-based detection) if confident enough
        apparent = response.apparent_encoding
        if apparent:
            apparent_lower = apparent.lower()
            # Trust CJK / Windows encodings detected by chardet
            trusted_prefixes = ("utf", "gb", "big5", "euc", "shift_jis", "iso-2022", "windows", "ascii")
            if any(apparent_lower.startswith(p) for p in trusted_prefixes):
                return apparent

        # 4. Fallback
        return "utf-8"

    # ---- Helper methods ----

    def _ensure_tmp_dir(self) -> str:
        """Ensure the ``tmp`` directory under ``self.cwd`` exists and return its path."""
        tmp_dir = os.path.join(self.cwd, "tmp")
        os.makedirs(tmp_dir, exist_ok=True)
        return tmp_dir

    def _extract_filename(self, url: str) -> str:
        """Extract a safe filename from URL, with a short UUID prefix to avoid collisions."""
        path = urlparse(url).path
        basename = os.path.basename(unquote(path))
        if not basename or basename == "/":
            basename = "downloaded_file"
        # Sanitize: keep only safe chars
        basename = re.sub(r'[^\w.\-]', '_', basename)
        short_id = uuid.uuid4().hex[:8]
        return f"{short_id}_{basename}"

    @staticmethod
    def _cleanup_file(path: str):
        """Remove a file if it exists, ignoring errors (best-effort cleanup)."""
        try:
            if os.path.exists(path):
                os.remove(path)
        except Exception:
            pass

    @staticmethod
    def _is_binary_content_type(content_type: str) -> bool:
        """Check if Content-Type indicates a binary/document response."""
        binary_types = [
            "application/pdf",
            "application/vnd.openxmlformats",
            "application/vnd.ms-excel",
            "application/vnd.ms-powerpoint",
            "application/octet-stream",
        ]
        ct_lower = content_type.lower()
        return any(bt in ct_lower for bt in binary_types)

    def _handle_download_by_content_type(self, url: str, response: requests.Response, content_type: str) -> ToolResult:
        """Handle a URL that returned binary content instead of HTML.

        The Content-Type is mapped to a document suffix and the ORIGINAL URL
        is downloaded again as a document. The suffix is passed along only to
        select the parser; the URL itself is not rewritten, because appending
        a suffix to its path would request a different resource.

        Returns:
            The document fetch result, or a failure when the Content-Type
            does not map to a supported document type.
        """
        ct_lower = content_type.lower()
        suffix_map = {
            "application/pdf": ".pdf",
            "application/vnd.openxmlformats-officedocument.wordprocessingml": ".docx",
            "application/vnd.ms-excel": ".xls",
            "application/vnd.openxmlformats-officedocument.spreadsheetml": ".xlsx",
            "application/vnd.ms-powerpoint": ".ppt",
            "application/vnd.openxmlformats-officedocument.presentationml": ".pptx",
        }
        detected_suffix = None
        for ct_prefix, ext in suffix_map.items():
            if ct_prefix in ct_lower:
                detected_suffix = ext
                break

        if detected_suffix and detected_suffix in ALL_DOC_SUFFIXES:
            # A recognised suffix already present in the URL wins, as before.
            url_suffix = _get_url_suffix(url)
            chosen = url_suffix if url_suffix in ALL_DOC_SUFFIXES else detected_suffix
            return self._fetch_document(url, suffix=chosen)
        return ToolResult.fail(f"Error: URL returned binary content ({content_type}), not a supported document type")

    @staticmethod
    def _rewrite_url_with_suffix(url: str, suffix: str) -> str:
        """Append a suffix to the URL path so _get_url_suffix works correctly.

        Kept for backward compatibility. The returned URL generally points to
        a different resource, so it should not be used for downloading.
        """
        parsed = urlparse(url)
        new_path = parsed.path.rstrip("/") + suffix
        return parsed._replace(path=new_path).geturl()

    # ---- HTML extraction (unchanged) ----

    @staticmethod
    def _extract_title(html: str) -> str:
        """Return the stripped contents of the first ``<title>`` element, or "Untitled"."""
        match = re.search(r"<title[^>]*>(.*?)</title>", html, re.IGNORECASE | re.DOTALL)
        return match.group(1).strip() if match else "Untitled"

    @staticmethod
    def _extract_text(html: str) -> str:
        """Reduce an HTML document to plain text.

        Script and style elements are dropped, remaining tags are stripped, a
        small set of common entities is decoded, and whitespace is normalised
        (runs of spaces collapsed, at most one blank line in a row, every
        line stripped).
        """
        text = re.sub(r"<script[^>]*>.*?</script>", "", html, flags=re.IGNORECASE | re.DOTALL)
        text = re.sub(r"<style[^>]*>.*?</style>", "", text, flags=re.IGNORECASE | re.DOTALL)
        text = re.sub(r"<[^>]+>", "", text)
        text = text.replace("&lt;", "<").replace("&gt;", ">")
        text = text.replace("&quot;", '"').replace("&#39;", "'").replace("&nbsp;", " ")
        # Decode &amp; last so an escaped entity such as "&amp;lt;" becomes
        # the literal text "&lt;" instead of being unescaped twice into "<".
        text = text.replace("&amp;", "&")
        text = re.sub(r"[^\S\n]+", " ", text)
        text = re.sub(r"\n{3,}", "\n\n", text)
        lines = [line.strip() for line in text.splitlines()]
        text = "\n".join(lines)
        return text.strip()
|
||||
3
agent/tools/web_search/__init__.py
Normal file
3
agent/tools/web_search/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from agent.tools.web_search.web_search import WebSearch
|
||||
|
||||
__all__ = ["WebSearch"]
|
||||
487
agent/tools/web_search/web_search.py
Normal file
487
agent/tools/web_search/web_search.py
Normal file
@@ -0,0 +1,487 @@
|
||||
"""Web Search tool. Supports four backends with a unified response format:
|
||||
- bocha (https://open.bochaai.com)
|
||||
- zhipu (https://docs.bigmodel.cn/cn/guide/tools/web-search)
|
||||
- qianfan (https://cloud.baidu.com/doc/qianfan/s/2mh4su4uy)
|
||||
- linkai (https://link-ai.tech, fallback)
|
||||
|
||||
Provider selection
|
||||
- strategy 'auto' (default): pick the first configured provider in the
|
||||
canonical order [bocha, qianfan, zhipu, linkai]. When the caller passes
|
||||
an explicit `provider` it overrides the pick; an invalid/unconfigured
|
||||
one silently falls back to the auto order.
|
||||
- strategy 'fixed': use the configured provider; if its credential is
|
||||
missing at call time, silently fall back to auto order (no card hint).
|
||||
|
||||
Credentials
|
||||
- bocha : tools.web_search.bocha_api_key -> env BOCHA_API_KEY
|
||||
- zhipu : conf.zhipu_ai_api_key -> env ZHIPUAI_API_KEY
|
||||
- qianfan : conf.qianfan_api_key -> env QIANFAN_API_KEY
|
||||
- linkai : conf.linkai_api_key -> env LINKAI_API_KEY
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import requests
|
||||
|
||||
from agent.tools.base_tool import BaseTool, ToolResult
|
||||
from common.log import logger
|
||||
from config import conf
|
||||
|
||||
|
||||
# Timeout, in seconds, applied to every search API request.
DEFAULT_TIMEOUT = 30

# Canonical fallback order, chosen empirically for Chinese real-time
# quality and relevance:
#   bocha   - best overall
#   qianfan - best for hot news
#   zhipu   - strong on long-form articles
#   linkai  - cloud aggregator, last resort
PROVIDER_ORDER = ("bocha", "qianfan", "zhipu", "linkai")

# Display names for each provider id.
PROVIDER_LABELS = dict(
    bocha="Bocha",
    zhipu="Zhipu",
    qianfan="Baidu Qianfan",
    linkai="LinkAI",
)
|
||||
|
||||
|
||||
def _tools_web_search_conf() -> dict:
    """Return the ``tools.web_search`` config block.

    An empty dict is returned when either ``tools`` or ``tools.web_search``
    is missing, falsy, or not a dict.
    """
    tools_section = conf().get("tools") or {}
    if isinstance(tools_section, dict):
        web_search_section = tools_section.get("web_search") or {}
        if isinstance(web_search_section, dict):
            return web_search_section
    return {}
|
||||
|
||||
|
||||
def _get_api_key(provider: str) -> str:
    """Resolve the API key for ``provider``.

    The configured value is preferred; when it is empty the provider's
    environment variable is used instead. Unknown providers yield ``""``.
    Bocha's key lives in the ``tools.web_search`` block, the others in the
    top-level config.
    """
    if provider == "bocha":
        configured = _tools_web_search_conf().get("bocha_api_key")
        env_name = "BOCHA_API_KEY"
    elif provider == "zhipu":
        configured = conf().get("zhipu_ai_api_key")
        env_name = "ZHIPUAI_API_KEY"
    elif provider == "qianfan":
        configured = conf().get("qianfan_api_key")
        env_name = "QIANFAN_API_KEY"
    elif provider == "linkai":
        configured = conf().get("linkai_api_key")
        env_name = "LINKAI_API_KEY"
    else:
        return ""

    key = (configured or "").strip()
    if key:
        return key
    return os.environ.get(env_name, "").strip()
|
||||
|
||||
|
||||
def configured_providers() -> List[str]:
    """List every provider that has an API key, in ``PROVIDER_ORDER`` order."""
    found: List[str] = []
    for candidate in PROVIDER_ORDER:
        if _get_api_key(candidate):
            found.append(candidate)
    return found
|
||||
|
||||
|
||||
def _configured_strategy() -> str:
    """Return the normalised ``tools.web_search.strategy`` value ("auto" when unset)."""
    raw_strategy = _tools_web_search_conf().get("strategy")
    if not raw_strategy:
        raw_strategy = "auto"
    return raw_strategy.strip().lower()
|
||||
|
||||
|
||||
def _configured_provider() -> str:
    """Return the normalised ``tools.web_search.provider`` value ("" when unset)."""
    raw_provider = _tools_web_search_conf().get("provider")
    if not raw_provider:
        raw_provider = ""
    return raw_provider.strip().lower()
|
||||
|
||||
|
||||
class WebSearch(BaseTool):
    """Web search tool that routes each query to one of several provider backends."""

    # Identifier and description of the tool.
    name: str = "web_search"
    description: str = "Search the web for real-time information. Returns titles, URLs, and snippets."

    # JSON-schema style description of the accepted arguments.
    params: dict = {
        "type": "object",
        "properties": {
            "query": {"type": "string", "description": "Search query string"},
            "count": {
                "type": "integer",
                "description": "Number of results to return (1-50, default: 10)",
            },
            "freshness": {
                "type": "string",
                "description": (
                    "Time range filter. Options: "
                    "'noLimit' (default), 'oneDay', 'oneWeek', 'oneMonth', 'oneYear', "
                    "or date range like '2025-01-01..2025-02-01'"
                ),
            },
            "summary": {
                "type": "boolean",
                "description": "Whether to include text summary for each result (default: false)",
            },
        },
        "required": ["query"],
    }
|
||||
|
||||
def __init__(self, config: dict = None):
|
||||
self.config = config or {}
|
||||
|
||||
@staticmethod
def is_available() -> bool:
    """Report whether at least one search provider has an API key configured."""
    return len(configured_providers()) > 0
|
||||
|
||||
@classmethod
def get_json_schema(cls) -> dict:
    """Build the tool schema, optionally exposing a ``provider`` argument.

    The ``provider`` enum is added only when the strategy is 'auto' and at
    least two providers are configured. In every other case the backend is
    picked silently and the extra field would only cost the agent tokens.
    """
    # A JSON round-trip gives a deep copy, so cls.params is never mutated.
    parameters = json.loads(json.dumps(cls.params))
    schema = {
        "name": cls.name,
        "description": cls.description,
        "parameters": parameters,
    }

    if _configured_strategy() == "auto":
        available = configured_providers()
        if len(available) >= 2:
            schema["parameters"]["properties"]["provider"] = {
                "type": "string",
                "enum": available,
                "description": "Optional. Specifies the search backend. You may switch between providers when the user wants results from a particular source or from multiple sources.",
            }
    return schema
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Provider resolution
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _resolve_provider(self, requested: Optional[str]) -> Optional[str]:
    """Choose the provider that will serve this call.

    Order of preference: the caller-requested provider (when it has a key),
    then the pinned provider under the 'fixed' strategy (when it has a key),
    then the first configured provider in canonical order. A preferred
    provider without a key only produces a warning before falling through.

    Returns:
        The chosen provider id, or ``None`` when no provider is configured.
    """
    candidates = configured_providers()
    if not candidates:
        return None

    if requested:
        wanted = requested.strip().lower()
        if wanted in candidates:
            return wanted
        logger.warning(f"[WebSearch] requested provider '{requested}' unavailable, falling back")

    if _configured_strategy() == "fixed":
        pinned = _configured_provider()
        if pinned in candidates:
            return pinned
        if pinned:
            logger.warning(f"[WebSearch] pinned provider '{pinned}' unavailable, falling back to auto")

    return candidates[0]
|
||||
|
||||
@staticmethod
def _resolution_reason(requested: Optional[str], chosen: str) -> str:
    """Label why ``chosen`` was selected, for the routing log line.

    Returns "caller-requested", "fixed-strategy" or "auto-fallback".
    """
    caller_asked_for_it = bool(requested) and requested.strip().lower() == chosen
    if caller_asked_for_it:
        return "caller-requested"

    if _configured_strategy() == "fixed" and _configured_provider() == chosen:
        return "fixed-strategy"
    return "auto-fallback"
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Entry point
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def execute(self, args: Dict[str, Any]) -> ToolResult:
    """Run a web search through the resolved provider.

    Args:
        args: Tool arguments. ``query`` is required; ``count``,
            ``freshness``, ``summary`` and ``provider`` are optional.

    Returns:
        The provider's ``ToolResult``, or a failed result when the query is
        empty, no provider is configured, or the request raises.
    """
    query = (args.get("query") or "").strip()
    if not query:
        return ToolResult.fail("Error: 'query' parameter is required")

    count = args.get("count", 10)
    # Agents sometimes send numbers as strings; accept plain digit strings.
    if isinstance(count, str) and count.strip().isdigit():
        count = int(count.strip())
    # bool is a subclass of int, so True/False must be rejected explicitly
    # or they would be forwarded to the provider as the result count.
    if isinstance(count, bool) or not isinstance(count, int) or count < 1 or count > 50:
        count = 10

    # An explicit null falls back to the default instead of being forwarded.
    freshness = args.get("freshness") or "noLimit"
    summary = bool(args.get("summary", False))

    requested = args.get("provider")
    if not isinstance(requested, str):
        # Anything but a string cannot name a provider; treat as "not requested".
        requested = None

    provider = self._resolve_provider(requested)
    if not provider:
        return ToolResult.fail(
            "Error: No search provider configured. "
            "Configure one of BOCHA_API_KEY / zhipu_ai_api_key / qianfan_api_key / linkai_api_key."
        )

    # Always log the routing decision so multi-provider deployments can
    # tell at a glance which backend served any given query.
    available = configured_providers()
    reason = self._resolution_reason(requested, provider)
    q_preview = query if len(query) <= 60 else (query[:57] + "...")
    logger.info(
        f"[WebSearch] provider={provider} reason={reason} "
        f"available={list(available)} query={q_preview!r} count={count} freshness={freshness}"
    )

    try:
        if provider == "bocha":
            return self._search_bocha(query, count, freshness, summary)
        if provider == "zhipu":
            return self._search_zhipu(query, count, freshness)
        if provider == "qianfan":
            return self._search_qianfan(query, count, freshness)
        if provider == "linkai":
            return self._search_linkai(query, count, freshness)
        return ToolResult.fail(f"Error: Unknown provider '{provider}'")
    except requests.Timeout:
        return ToolResult.fail(f"Error: Search request timed out after {DEFAULT_TIMEOUT}s")
    except requests.ConnectionError:
        return ToolResult.fail("Error: Failed to connect to search API")
    except Exception as e:
        logger.error(f"[WebSearch] Unexpected error ({provider}): {e}", exc_info=True)
        return ToolResult.fail(f"Error: Search failed - {str(e)}")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Bocha
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _search_bocha(self, query: str, count: int, freshness: str, summary: bool) -> ToolResult:
    """Query the Bocha web-search API and normalise its response.

    Args:
        query: Search query string.
        count: Number of results requested.
        freshness: Time range filter, forwarded as-is.
        summary: Whether per-result summaries are requested.

    Returns:
        A successful ``ToolResult`` carrying ``query``, ``backend``,
        ``total``, ``count`` and ``results``, or a failed one for HTTP or
        API-level errors. Network exceptions propagate to the caller.
    """
    api_key = _get_api_key("bocha")
    url = "https://api.bochaai.com/v1/web-search"
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
        "Accept": "application/json",
    }
    payload = {"query": query, "count": count, "freshness": freshness, "summary": summary}

    logger.debug(f"[WebSearch] bocha: query='{query}', count={count}")
    resp = requests.post(url, headers=headers, json=payload, timeout=DEFAULT_TIMEOUT)

    if resp.status_code == 401:
        return ToolResult.fail("Error: Invalid bocha API key.")
    if resp.status_code == 403:
        return ToolResult.fail("Error: bocha API — insufficient balance. Top up at https://open.bochaai.com")
    if resp.status_code == 429:
        return ToolResult.fail("Error: bocha API rate limit reached.")
    if resp.status_code != 200:
        return ToolResult.fail(f"Error: bocha API returned HTTP {resp.status_code}")

    data = resp.json()
    api_code = data.get("code")
    if api_code is not None and api_code != 200:
        msg = data.get("msg") or "Unknown error"
        return ToolResult.fail(f"Error: bocha API error (code={api_code}): {msg}")

    # `or {}` at each level tolerates explicit JSON nulls ("data": null,
    # "webPages": null), matching how the LinkAI response is parsed.
    web_pages = (data.get("data") or {}).get("webPages") or {}
    pages = web_pages.get("value") or []
    results = []
    for p in pages:
        item = {
            "title": p.get("name", ""),
            "url": p.get("url", ""),
            "snippet": p.get("snippet", ""),
            "siteName": p.get("siteName", ""),
            "datePublished": p.get("datePublished") or p.get("dateLastCrawled", ""),
        }
        if p.get("summary"):
            item["summary"] = p["summary"]
        results.append(item)
    total = web_pages.get("totalEstimatedMatches", len(results))
    return ToolResult.success({
        "query": query, "backend": "bocha",
        "total": total, "count": len(results), "results": results,
    })
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Zhipu
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _search_zhipu(self, query: str, count: int, freshness: str) -> ToolResult:
    """Query the Zhipu web-search endpoint and normalise its response.

    The query is cut to 70 characters and ``count`` is clamped to 1-50.
    Unknown engines fall back to ``search_pro`` and unknown freshness values
    to ``noLimit``.

    Returns:
        A successful ``ToolResult`` with the normalised results, or a failed
        one for HTTP errors and for error payloads returned with HTTP 200.
    """
    api_key = _get_api_key("zhipu")
    base_url = conf().get("zhipu_ai_api_base") or "https://open.bigmodel.cn/api/paas/v4"
    api_base = base_url.rstrip("/")
    url = f"{api_base}/web_search"
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
    }

    # Zhipu Web Search expects `search_query` <= 70 chars, so shorten a long
    # agent-supplied query up front rather than have it rejected.
    trimmed_query = (query or "")[:70]

    known_engines = ("search_std", "search_pro", "search_pro_sogou", "search_pro_quark")
    engine = (_tools_web_search_conf().get("zhipu_search_engine") or "search_pro").strip().lower()
    if engine not in known_engines:
        engine = "search_pro"

    known_recency = ("oneDay", "oneWeek", "oneMonth", "oneYear", "noLimit")
    recency = freshness if freshness in known_recency else "noLimit"
    clamped_count = max(1, min(int(count or 10), 50))

    payload: Dict[str, Any] = {
        "search_engine": engine,
        "search_query": trimmed_query,
        "search_intent": False,
        "count": clamped_count,
        "search_recency_filter": recency,
    }
    content_size = (_tools_web_search_conf().get("zhipu_content_size") or "").strip().lower()
    if content_size in ("medium", "high"):
        payload["content_size"] = content_size

    logger.debug(f"[WebSearch] zhipu: query='{trimmed_query}', count={payload['count']}, engine={engine}")
    resp = requests.post(url, headers=headers, json=payload, timeout=DEFAULT_TIMEOUT)

    status = resp.status_code
    if status == 401:
        return ToolResult.fail("Error: Invalid Zhipu API key.")
    if status != 200:
        return ToolResult.fail(f"Error: Zhipu API returned HTTP {resp.status_code}: {resp.text[:200]}")

    data = resp.json()
    # Business-level errors (1701/1702/1703 etc.) come back as
    # {"error": {"code","message"}} even on HTTP 200.
    if isinstance(data, dict) and data.get("error"):
        err = data["error"] or {}
        return ToolResult.fail(f"Error: Zhipu returned {err.get('code')}: {err.get('message','')}")

    items = data.get("search_result") or (data.get("data") or {}).get("search_result") or []
    # Each field prefers Zhipu's own key and falls back to the shared name.
    results = [
        {
            "title": entry.get("title", ""),
            "url": entry.get("link") or entry.get("url", ""),
            "snippet": entry.get("content") or entry.get("snippet", ""),
            "siteName": entry.get("media") or entry.get("siteName", ""),
            "datePublished": entry.get("publish_date") or entry.get("datePublished", ""),
        }
        for entry in items
    ]
    return ToolResult.success({
        "query": query, "backend": "zhipu",
        "total": len(results), "count": len(results), "results": results,
    })
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Qianfan (Baidu)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _search_qianfan(self, query: str, count: int, freshness: str) -> ToolResult:
    """Query the Baidu Qianfan AI-search endpoint and normalise its response.

    ``count`` is clamped to 1-50 and sent as the web ``top_k``. Snippets are
    cut to 200 characters.

    Returns:
        A successful ``ToolResult`` with the normalised results, or a failed
        one for HTTP errors and for ``code``-bearing payloads on HTTP 200.
    """
    api_key = _get_api_key("qianfan")
    base_url = conf().get("qianfan_api_base") or "https://qianfan.baidubce.com/v2"
    api_base = base_url.rstrip("/")
    url = f"{api_base}/ai_search/web_search"
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
        "X-Appbuilder-From": "cow",
    }

    count = max(1, min(int(count or 10), 50))
    payload: Dict[str, Any] = {
        "messages": [{"role": "user", "content": query}],
        "search_source": "baidu_search_v2",
        "resource_type_filter": [{"type": "web", "top_k": count}],
    }

    # Baidu AI Search takes freshness as a page_time date-range filter, not
    # as a named recency token, so translate the shared vocabulary first.
    search_filter = self._qianfan_build_freshness_filter(freshness)
    if search_filter:
        payload["search_filter"] = search_filter

    logger.debug(f"[WebSearch] qianfan: query='{query}', count={count}, freshness={freshness!r}")
    resp = requests.post(url, headers=headers, json=payload, timeout=DEFAULT_TIMEOUT)

    status = resp.status_code
    if status == 401:
        return ToolResult.fail("Error: Invalid Qianfan API key.")
    if status != 200:
        return ToolResult.fail(f"Error: Qianfan API returned HTTP {resp.status_code}: {resp.text[:200]}")

    data = resp.json()
    # Even on HTTP 200 Baidu surfaces business errors as {"code","message"}.
    if isinstance(data, dict) and data.get("code"):
        return ToolResult.fail(f"Error: Qianfan returned {data.get('code')}: {data.get('message','')}")

    refs = data.get("references") or []
    results = [
        {
            "title": ref.get("title", ""),
            "url": ref.get("url", ""),
            "snippet": (ref.get("content") or "")[:200],
            "siteName": ref.get("web_anchor") or ref.get("website") or "",
            "datePublished": ref.get("date", ""),
        }
        for ref in refs
    ]
    return ToolResult.success({
        "query": query, "backend": "qianfan",
        "total": len(results), "count": len(results), "results": results,
    })
|
||||
|
||||
@staticmethod
|
||||
def _qianfan_build_freshness_filter(freshness: str) -> Optional[Dict[str, Any]]:
|
||||
if not freshness or freshness == "noLimit":
|
||||
return None
|
||||
delta_days = {"oneDay": 1, "oneWeek": 7, "oneMonth": 30, "oneYear": 365}.get(freshness)
|
||||
if not delta_days:
|
||||
return None
|
||||
from datetime import datetime, timedelta
|
||||
now = datetime.now()
|
||||
end_date = (now + timedelta(days=1)).strftime("%Y-%m-%d")
|
||||
start_date = (now - timedelta(days=delta_days)).strftime("%Y-%m-%d")
|
||||
return {"range": {"page_time": {"gte": start_date, "lt": end_date}}}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# LinkAI (plugin)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _search_linkai(self, query: str, count: int, freshness: str) -> ToolResult:
    """Run a web search through the LinkAI ``web-search`` plugin endpoint.

    :param query: search text forwarded to the plugin
    :param count: requested number of results
    :param freshness: freshness keyword forwarded unchanged
    :return: a failed ToolResult on HTTP / business errors, otherwise a
        successful one holding either structured page entries or a single
        raw-content entry
    """
    api_key = _get_api_key("linkai")
    base_url = (conf().get("linkai_api_base") or "https://api.link-ai.tech").rstrip("/")
    url = f"{base_url}/v1/plugin/execute"

    from common.utils import get_cloud_headers
    headers = get_cloud_headers(api_key)

    request_body = {
        "code": "web-search",
        "args": {"query": query, "count": count, "freshness": freshness},
    }
    logger.debug(f"[WebSearch] linkai: query='{query}', count={count}")
    resp = requests.post(url, headers=headers, json=request_body, timeout=DEFAULT_TIMEOUT)

    if resp.status_code != 200:
        if resp.status_code == 401:
            return ToolResult.fail("Error: Invalid LinkAI API key.")
        return ToolResult.fail(f"Error: LinkAI API returned HTTP {resp.status_code}")

    data = resp.json()
    if not data.get("success"):
        reason = data.get("message") or "Unknown error"
        return ToolResult.fail(f"Error: LinkAI search failed: {reason}")

    def single_entry(content):
        # Fallback shape used whenever the payload is not a structured page list.
        return ToolResult.success({
            "query": query, "backend": "linkai",
            "total": 1, "count": 1, "results": [{"content": content}],
        })

    raw = data.get("data", "")
    if isinstance(raw, str):
        # The plugin may return its payload as a JSON-encoded string.
        try:
            raw = json.loads(raw)
        except (json.JSONDecodeError, TypeError):
            return single_entry(raw)

    if isinstance(raw, dict):
        web_pages = raw.get("webPages") or {}
        pages = web_pages.get("value", []) or []
        if pages:
            results = []
            for page in pages:
                entry = {
                    "title": page.get("name", ""),
                    "url": page.get("url", ""),
                    "snippet": page.get("snippet", ""),
                    "siteName": page.get("siteName", ""),
                    "datePublished": page.get("datePublished") or page.get("dateLastCrawled", ""),
                }
                summary = page.get("summary")
                if summary:
                    entry["summary"] = summary
                results.append(entry)
            return ToolResult.success({
                "query": query, "backend": "linkai",
                "total": web_pages.get("totalEstimatedMatches", len(results)),
                "count": len(results), "results": results,
            })

    return single_entry(str(raw))
|
||||
3
agent/tools/write/__init__.py
Normal file
3
agent/tools/write/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .write import Write
|
||||
|
||||
__all__ = ['Write']
|
||||
97
agent/tools/write/write.py
Normal file
97
agent/tools/write/write.py
Normal file
@@ -0,0 +1,97 @@
|
||||
"""
|
||||
Write tool - Write file content
|
||||
Creates or overwrites files, automatically creates parent directories
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Dict, Any
|
||||
from pathlib import Path
|
||||
|
||||
from agent.tools.base_tool import BaseTool, ToolResult
|
||||
from common.utils import expand_path
|
||||
|
||||
|
||||
class Write(BaseTool):
    """Tool that creates or overwrites a text file, creating parent directories as needed."""

    # Tool identifier.
    name: str = "write"
    # Human/LLM-facing description of the tool.
    description: str = "Write content to a file. Creates the file if it doesn't exist, overwrites if it does. Automatically creates parent directories. IMPORTANT: Single write should not exceed 10KB. For large files, create a skeleton first, then use edit to add content in chunks."

    # JSON-schema style description of the accepted arguments.
    params: dict = {
        "type": "object",
        "properties": {
            "path": {
                "type": "string",
                "description": "Path to the file to write (relative or absolute)"
            },
            "content": {
                "type": "string",
                "description": "Content to write to the file"
            }
        },
        "required": ["path", "content"]
    }

    def __init__(self, config: dict = None):
        """
        :param config: optional settings; ``cwd`` is the base for relative
            paths (defaults to the process working directory) and
            ``memory_manager`` is notified when a memory file is written
        """
        self.config = config or {}
        self.cwd = self.config.get("cwd", os.getcwd())
        self.memory_manager = self.config.get("memory_manager", None)

    def execute(self, args: Dict[str, Any]) -> ToolResult:
        """
        Execute file write operation

        :param args: Contains file path and content
        :return: Operation result (failure result on invalid arguments or I/O errors)
        """
        path = args.get("path", "")
        content = args.get("content", "")

        # Validate before touching the filesystem: a non-string path used to
        # raise outside the try block, and a non-string content used to fail
        # only after open(..., 'w') had already truncated the target file.
        if not isinstance(path, str) or not path.strip():
            return ToolResult.fail("Error: path parameter is required")
        path = path.strip()
        if not isinstance(content, str):
            return ToolResult.fail("Error: content parameter must be a string")

        # Resolve path
        absolute_path = self._resolve_path(path)

        try:
            # Encode first so un-encodable text fails before the file is truncated.
            bytes_written = len(content.encode('utf-8'))

            # Create parent directory (if needed)
            parent_dir = os.path.dirname(absolute_path)
            if parent_dir:
                os.makedirs(parent_dir, exist_ok=True)

            # Write file
            with open(absolute_path, 'w', encoding='utf-8') as f:
                f.write(content)

            # Auto-sync to memory database if this is a memory file
            if self.memory_manager and 'memory/' in path:
                self.memory_manager.mark_dirty()

            result = {
                "message": f"Successfully wrote {bytes_written} bytes to {path}",
                "path": path,
                "bytes_written": bytes_written
            }

            return ToolResult.success(result)

        except PermissionError:
            return ToolResult.fail(f"Error: Permission denied writing to {path}")
        except Exception as e:
            return ToolResult.fail(f"Error writing file: {str(e)}")

    def _resolve_path(self, path: str) -> str:
        """
        Resolve path to absolute path

        :param path: Relative or absolute path
        :return: Absolute path (relative paths are joined onto ``self.cwd``)
        """
        # Expand ~ to user home directory
        path = expand_path(path)
        if os.path.isabs(path):
            return path
        return os.path.abspath(os.path.join(self.cwd, path))
|
||||
351
app.py
351
app.py
@@ -3,11 +3,263 @@
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
import time
|
||||
|
||||
from channel import channel_factory
|
||||
from common import const
|
||||
from common.log import logger
|
||||
from config import conf, load_config
|
||||
from config import load_config, conf
|
||||
from plugins import *
|
||||
import threading
|
||||
|
||||
|
||||
_channel_mgr = None
|
||||
|
||||
|
||||
def get_channel_manager():
    """Return the module-level channel manager (None until one has been assigned)."""
    return _channel_mgr
|
||||
|
||||
|
||||
def _parse_channel_type(raw) -> list:
|
||||
"""
|
||||
Parse channel_type config value into a list of channel names.
|
||||
Supports:
|
||||
- single string: "feishu"
|
||||
- comma-separated string: "feishu, dingtalk"
|
||||
- list: ["feishu", "dingtalk"]
|
||||
"""
|
||||
if isinstance(raw, list):
|
||||
return [ch.strip() for ch in raw if ch.strip()]
|
||||
if isinstance(raw, str):
|
||||
return [ch.strip() for ch in raw.split(",") if ch.strip()]
|
||||
return []
|
||||
|
||||
|
||||
class ChannelManager:
    """
    Manage the lifecycle of multiple channels running concurrently.

    Each channel.startup() runs in its own daemon thread.
    The web channel is started as default console unless explicitly disabled.
    """

    def __init__(self):
        self._channels = {}  # channel_name -> channel instance
        self._threads = {}  # channel_name -> thread
        self._primary_channel = None
        self._lock = threading.Lock()
        self.cloud_mode = False  # set to True when cloud client is active

    @property
    def channel(self):
        """Return the primary channel (first non-web one, else the first started) for backward compatibility."""
        return self._primary_channel

    def get_channel(self, channel_name: str):
        """Return the managed channel instance registered under channel_name, or None."""
        return self._channels.get(channel_name)

    def start(self, channel_names: list, first_start: bool = False):
        """
        Create and start one or more channels in sub-threads.
        If first_start is True, plugins and linkai client will also be initialized.
        """
        with self._lock:
            channels = []
            for name in channel_names:
                ch = channel_factory.create_channel(name)
                ch.cloud_mode = self.cloud_mode
                self._channels[name] = ch
                channels.append((name, ch))
                if self._primary_channel is None and name != "web":
                    self._primary_channel = ch

            # Only the web channel was requested: fall back to it as primary.
            if self._primary_channel is None and channels:
                self._primary_channel = channels[0][1]

            if first_start:
                PluginManager().load_plugins()

                # Cloud client is optional. It is only started when
                # use_linkai=True AND cloud_deployment_id is set.
                # By default neither is configured, so the app runs
                # entirely locally without any remote connection.
                if conf().get("use_linkai") and (
                    os.environ.get("CLOUD_DEPLOYMENT_ID") or conf().get("cloud_deployment_id")
                ):
                    try:
                        from common import cloud_client
                        threading.Thread(
                            target=cloud_client.start,
                            args=(self._primary_channel, self),
                            daemon=True,
                        ).start()
                    except Exception as e:
                        # Best-effort: the app keeps running locally, but the
                        # failure is reported instead of being swallowed.
                        logger.warning(f"[ChannelManager] Failed to start cloud client: {e}")

            # Start web console first so its logs print cleanly,
            # then start remaining channels after a brief pause.
            web_entry = None
            other_entries = []
            for entry in channels:
                if entry[0] == "web":
                    web_entry = entry
                else:
                    other_entries.append(entry)

            ordered = ([web_entry] if web_entry else []) + other_entries
            for i, (name, ch) in enumerate(ordered):
                if i > 0 and name != "web":
                    time.sleep(0.1)
                t = threading.Thread(target=self._run_channel, args=(name, ch), daemon=True)
                self._threads[name] = t
                t.start()
                logger.debug(f"[ChannelManager] Channel '{name}' started in sub-thread")

    def _run_channel(self, name: str, channel):
        """Thread target: run channel.startup() and log (rather than propagate) any error."""
        try:
            channel.startup()
        except Exception as e:
            logger.error(f"[ChannelManager] Channel '{name}' startup error: {e}")
            logger.exception(e)

    def stop(self, channel_name: str = None):
        """
        Stop channel(s). If channel_name is given, stop only that channel;
        otherwise stop all channels.
        """
        # Pop under lock, then stop outside lock to avoid deadlock
        with self._lock:
            names = [channel_name] if channel_name else list(self._channels.keys())
            to_stop = []
            for name in names:
                ch = self._channels.pop(name, None)
                th = self._threads.pop(name, None)
                to_stop.append((name, ch, th))
                # Compare against the popped instance: looking the name up in
                # self._channels after the pop always yields None, which left
                # a stale primary channel behind.
                if ch is not None and ch is self._primary_channel:
                    self._primary_channel = None

        for name, ch, th in to_stop:
            if ch is None:
                logger.warning(f"[ChannelManager] Channel '{name}' not found in managed channels")
                if th and th.is_alive():
                    self._interrupt_thread(th, name)
                continue
            logger.info(f"[ChannelManager] Stopping channel '{name}'...")
            graceful = False
            if hasattr(ch, 'stop'):
                try:
                    ch.stop()
                    graceful = True
                except Exception as e:
                    logger.warning(f"[ChannelManager] Error during channel '{name}' stop: {e}")
            if th and th.is_alive():
                th.join(timeout=5)
                if th.is_alive():
                    if graceful:
                        logger.info(f"[ChannelManager] Channel '{name}' thread still alive after stop(), "
                                    "leaving daemon thread to finish on its own")
                    else:
                        logger.warning(f"[ChannelManager] Channel '{name}' thread did not exit in 5s, forcing interrupt")
                        self._interrupt_thread(th, name)

    @staticmethod
    def _interrupt_thread(th: threading.Thread, name: str):
        """Raise SystemExit in target thread to break blocking loops like start_forever."""
        import ctypes
        try:
            tid = th.ident
            if tid is None:
                return
            res = ctypes.pythonapi.PyThreadState_SetAsyncExc(
                ctypes.c_ulong(tid), ctypes.py_object(SystemExit)
            )
            if res == 1:
                logger.info(f"[ChannelManager] Interrupted thread for channel '{name}'")
            elif res > 1:
                # More than one thread state was affected: revert the request.
                ctypes.pythonapi.PyThreadState_SetAsyncExc(ctypes.c_ulong(tid), None)
                logger.warning(f"[ChannelManager] Failed to interrupt thread for channel '{name}'")
        except Exception as e:
            logger.warning(f"[ChannelManager] Thread interrupt error for '{name}': {e}")

    def restart(self, new_channel_name: str):
        """
        Restart a single channel with a new channel type.
        Can be called from any thread (e.g. linkai config callback).
        """
        logger.info(f"[ChannelManager] Restarting channel to '{new_channel_name}'...")
        self.stop(new_channel_name)
        _clear_singleton_cache(new_channel_name)
        time.sleep(1)
        self.start([new_channel_name], first_start=False)
        logger.info(f"[ChannelManager] Channel restarted to '{new_channel_name}' successfully")

    def add_channel(self, channel_name: str):
        """
        Dynamically add and start a new channel.
        If the channel is already running, restart it instead.
        """
        # The lock only guards the membership check; restart()/start() take
        # the (non-reentrant) lock themselves.
        with self._lock:
            if channel_name in self._channels:
                logger.info(f"[ChannelManager] Channel '{channel_name}' already exists, restarting")
        if self._channels.get(channel_name):
            self.restart(channel_name)
            return
        logger.info(f"[ChannelManager] Adding channel '{channel_name}'...")
        _clear_singleton_cache(channel_name)
        self.start([channel_name], first_start=False)
        logger.info(f"[ChannelManager] Channel '{channel_name}' added successfully")

    def remove_channel(self, channel_name: str):
        """
        Dynamically stop and remove a running channel.
        """
        # The lock only guards the membership check; stop() takes the
        # (non-reentrant) lock itself.
        with self._lock:
            if channel_name not in self._channels:
                logger.warning(f"[ChannelManager] Channel '{channel_name}' not found, nothing to remove")
                return
        logger.info(f"[ChannelManager] Removing channel '{channel_name}'...")
        self.stop(channel_name)
        logger.info(f"[ChannelManager] Channel '{channel_name}' removed successfully")
|
||||
|
||||
|
||||
def _clear_singleton_cache(channel_name: str):
    """
    Reset the singleton cache of the channel class registered for channel_name.

    The class is looked up by dotted path, imported, and the first dict found
    in the closure of the module attribute is emptied so that a new instance
    can be created with updated config. Unknown channel names are ignored and
    failures are logged as warnings.
    """
    dotted_paths = {
        "web": "channel.web.web_channel.WebChannel",
        "wechatmp": "channel.wechatmp.wechatmp_channel.WechatMPChannel",
        "wechatmp_service": "channel.wechatmp.wechatmp_channel.WechatMPChannel",
        "wechatcom_app": "channel.wechatcom.wechatcomapp_channel.WechatComAppChannel",
        const.WECHAT_KF: "channel.wechat_kf.wechat_kf_channel.WechatKfChannel",
        const.FEISHU: "channel.feishu.feishu_channel.FeiShuChanel",
        const.DINGTALK: "channel.dingtalk.dingtalk_channel.DingTalkChanel",
        const.WECOM_BOT: "channel.wecom_bot.wecom_bot_channel.WecomBotChannel",
        const.QQ: "channel.qq.qq_channel.QQChannel",
        const.WEIXIN: "channel.weixin.weixin_channel.WeixinChannel",
        "wx": "channel.weixin.weixin_channel.WeixinChannel",
    }
    dotted = dotted_paths.get(channel_name)
    if not dotted:
        return

    try:
        module_name, class_name = dotted.rsplit(".", 1)
        import importlib
        wrapper = getattr(importlib.import_module(module_name), class_name, None)
        if not wrapper:
            return
        closure = getattr(wrapper, '__closure__', None)
        if not closure:
            return
        for cell in closure:
            try:
                held = cell.cell_contents
                if isinstance(held, dict):
                    held.clear()
                    logger.debug(f"[ChannelManager] Cleared singleton cache for {class_name}")
                    break
            except ValueError:
                # Reading an empty closure cell raises ValueError; skip it.
                pass
    except Exception as e:
        logger.warning(f"[ChannelManager] Failed to clear singleton cache: {e}")
|
||||
|
||||
|
||||
def sigterm_handler_wrap(_signo):
|
||||
@@ -23,7 +275,65 @@ def sigterm_handler_wrap(_signo):
|
||||
signal.signal(_signo, func)
|
||||
|
||||
|
||||
def _warmup_mcp_tools():
    """
    Trigger MCP tool loading at process startup.

    The intent is that MCP server subprocesses (npx / uvx etc.) finish
    initializing before the first user message arrives. Any failure is
    logged as a non-fatal warning, so this is safe to call when MCP is
    not configured.
    """
    try:
        from agent.tools import ToolManager
        manager = ToolManager()
        manager._load_mcp_tools()
    except Exception as e:
        logger.warning(f"[App] MCP warmup failed (non-fatal): {e}")
|
||||
|
||||
|
||||
def _warmup_scheduler():
    """
    Eagerly obtain the agent bridge from Bridge at process boot.

    The intent is for the scheduler thread to start now rather than on the
    first user message. Failures are logged as warnings and never raised.
    """
    try:
        from bridge.bridge import Bridge
        bridge = Bridge()
        bridge.get_agent_bridge()
    except Exception as e:
        logger.warning(f"[App] Scheduler warmup failed: {e}")
|
||||
|
||||
|
||||
def _sync_builtin_skills():
    """
    Copy builtin skills from the project's skills/ directory into the workspace.

    Every sub-directory of <project>/skills that contains a SKILL.md file
    replaces the same-named directory under <agent_workspace>/skills.
    Per-skill and overall failures are logged as warnings, never raised.
    """
    import shutil
    try:
        workspace = os.path.expanduser(conf().get("agent_workspace", "~/cow"))
        here = os.path.dirname(os.path.abspath(__file__))
        builtin_dir = os.path.join(here, "skills")
        if not os.path.isdir(builtin_dir):
            return

        custom_dir = os.path.join(workspace, "skills")
        os.makedirs(custom_dir, exist_ok=True)

        synced = 0
        for name in os.listdir(builtin_dir):
            src = os.path.join(builtin_dir, name)
            is_skill = os.path.isdir(src) and os.path.isfile(os.path.join(src, "SKILL.md"))
            if not is_skill:
                continue
            dst = os.path.join(custom_dir, name)
            try:
                # Replace any existing copy wholesale so removed files do not linger.
                if os.path.isdir(dst):
                    shutil.rmtree(dst)
                shutil.copytree(src, dst)
            except Exception as e:
                logger.warning(f"[App] Failed to sync builtin skill '{name}': {e}")
            else:
                synced += 1

        if synced:
            logger.info(f"[App] Synced {synced} builtin skill(s) to workspace")
    except Exception as e:
        logger.warning(f"[App] Builtin skills sync failed: {e}")
|
||||
|
||||
|
||||
def run():
|
||||
global _channel_mgr
|
||||
try:
|
||||
# load config
|
||||
load_config()
|
||||
@@ -32,22 +342,39 @@ def run():
|
||||
# kill signal
|
||||
sigterm_handler_wrap(signal.SIGTERM)
|
||||
|
||||
# create channel
|
||||
channel_name = conf().get("channel_type", "wx")
|
||||
# Parse channel_type into a list
|
||||
raw_channel = conf().get("channel_type", "web")
|
||||
|
||||
if "--cmd" in sys.argv:
|
||||
channel_name = "terminal"
|
||||
channel_names = ["terminal"]
|
||||
else:
|
||||
channel_names = _parse_channel_type(raw_channel)
|
||||
if not channel_names:
|
||||
channel_names = ["web"]
|
||||
|
||||
if channel_name == "wxy":
|
||||
os.environ["WECHATY_LOG"] = "warn"
|
||||
# os.environ['WECHATY_PUPPET_SERVICE_ENDPOINT'] = '127.0.0.1:9001'
|
||||
# Auto-start web console unless explicitly disabled
|
||||
web_console_enabled = conf().get("web_console", True)
|
||||
if web_console_enabled and "web" not in channel_names:
|
||||
channel_names.append("web")
|
||||
|
||||
channel = channel_factory.create_channel(channel_name)
|
||||
if channel_name in ["wx", "wxy", "terminal", "wechatmp", "wechatmp_service", "wechatcom_app", "wework"]:
|
||||
PluginManager().load_plugins()
|
||||
# Sync builtin skills to workspace before channels start
|
||||
_sync_builtin_skills()
|
||||
|
||||
# startup channel
|
||||
channel.startup()
|
||||
# Kick off MCP server loading in the background so first-message
|
||||
# latency isn't dominated by npx package downloads.
|
||||
_warmup_mcp_tools()
|
||||
|
||||
_warmup_scheduler()
|
||||
|
||||
logger.info(f"[App] Starting channels: {channel_names}")
|
||||
|
||||
_channel_mgr = ChannelManager()
|
||||
_channel_mgr.start(channel_names, first_start=True)
|
||||
|
||||
while True:
|
||||
time.sleep(1)
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error("App startup failed!")
|
||||
logger.exception(e)
|
||||
|
||||
17
bot/bot.py
17
bot/bot.py
@@ -1,17 +0,0 @@
|
||||
"""
|
||||
Auto-reply chat robot abstract class
|
||||
"""
|
||||
|
||||
|
||||
from bridge.context import Context
|
||||
from bridge.reply import Reply
|
||||
|
||||
|
||||
class Bot(object):
    """Abstract base for auto-reply bots; subclasses must override reply()."""

    def reply(self, query, context: Context = None) -> Reply:
        """
        Produce the bot's reply to a received message.

        :param query: received message
        :param context: optional context accompanying the message
        :return: reply content
        :raises NotImplementedError: always, in this base class
        """
        raise NotImplementedError
|
||||
@@ -1,46 +0,0 @@
|
||||
"""
|
||||
bot factory
|
||||
"""
|
||||
from common import const
|
||||
|
||||
|
||||
def create_bot(bot_type):
    """
    Build the bot instance matching bot_type.

    Each bot module is imported lazily, only when its type is requested.

    :param bot_type: bot type code
    :return: bot instance
    :raises RuntimeError: when bot_type matches no known bot
    """
    if bot_type == const.BAIDU:
        # Baidu Unit was replaced by the Baidu Wenxin Qianfan chat API.
        from bot.baidu.baidu_wenxin import BaiduWenxinBot
        return BaiduWenxinBot()

    if bot_type == const.CHATGPT:
        # ChatGPT web interface
        from bot.chatgpt.chat_gpt_bot import ChatGPTBot
        return ChatGPTBot()

    if bot_type == const.OPEN_AI:
        # Official OpenAI chat model API
        from bot.openai.open_ai_bot import OpenAIBot
        return OpenAIBot()

    if bot_type == const.CHATGPTONAZURE:
        # Azure chatgpt service https://azure.microsoft.com/en-in/products/cognitive-services/openai-service/
        from bot.chatgpt.chat_gpt_bot import AzureChatGPTBot
        return AzureChatGPTBot()

    if bot_type == const.XUNFEI:
        from bot.xunfei.xunfei_spark_bot import XunFeiBot
        return XunFeiBot()

    if bot_type == const.LINKAI:
        from bot.linkai.link_ai_bot import LinkAIBot
        return LinkAIBot()

    if bot_type == const.CLAUDEAI:
        from bot.claude.claude_ai_bot import ClaudeAIBot
        return ClaudeAIBot()

    raise RuntimeError
|
||||
@@ -1,193 +0,0 @@
|
||||
# encoding:utf-8
|
||||
|
||||
import time
|
||||
|
||||
import openai
|
||||
import openai.error
|
||||
import requests
|
||||
|
||||
from bot.bot import Bot
|
||||
from bot.chatgpt.chat_gpt_session import ChatGPTSession
|
||||
from bot.openai.open_ai_image import OpenAIImage
|
||||
from bot.session_manager import SessionManager
|
||||
from bridge.context import ContextType
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from common.log import logger
|
||||
from common.token_bucket import TokenBucket
|
||||
from config import conf, load_config
|
||||
|
||||
|
||||
# OpenAI对话模型API (可用)
|
||||
class ChatGPTBot(Bot, OpenAIImage):
    """Bot backed by the OpenAI ChatCompletion API, with per-session history."""

    def __init__(self):
        super().__init__()
        # set the default api_key
        openai.api_key = conf().get("open_ai_api_key")
        if conf().get("open_ai_api_base"):
            openai.api_base = conf().get("open_ai_api_base")
        proxy = conf().get("proxy")
        if proxy:
            openai.proxy = proxy
        if conf().get("rate_limit_chatgpt"):
            self.tb4chatgpt = TokenBucket(conf().get("rate_limit_chatgpt", 20))

        self.sessions = SessionManager(ChatGPTSession, model=conf().get("model") or "gpt-3.5-turbo")
        self.args = {
            "model": conf().get("model") or "gpt-3.5-turbo",  # name of the chat model
            "temperature": conf().get("temperature", 0.9),  # in [0,1]; larger means less deterministic replies
            # "max_tokens":4096,  # maximum number of characters in a reply
            "top_p": conf().get("top_p", 1),
            "frequency_penalty": conf().get("frequency_penalty", 0.0),  # in [-2,2]; larger favors producing different content
            "presence_penalty": conf().get("presence_penalty", 0.0),  # in [-2,2]; larger favors producing different content
            "request_timeout": conf().get("request_timeout", None),  # request timeout; the openai API defaults to 600, hard questions usually need longer
            "timeout": conf().get("request_timeout", None),  # retry timeout; the call is retried automatically within this window
        }

    def reply(self, query, context=None):
        """
        Answer a text query or create an image, depending on context.type.

        :param query: received message text (or image prompt)
        :param context: message context; its type selects the handling and it
            supplies session_id plus optional per-request api key / model
        :return: a Reply (TEXT, INFO, IMAGE_URL or ERROR)
        """
        # acquire reply content
        if context.type == ContextType.TEXT:
            logger.info("[CHATGPT] query={}".format(query))

            session_id = context["session_id"]
            reply = None
            # Built-in maintenance commands are answered without calling the model.
            clear_memory_commands = conf().get("clear_memory_commands", ["#清除记忆"])
            if query in clear_memory_commands:
                self.sessions.clear_session(session_id)
                reply = Reply(ReplyType.INFO, "记忆已清除")
            elif query == "#清除所有":
                self.sessions.clear_all_session()
                reply = Reply(ReplyType.INFO, "所有人记忆已清除")
            elif query == "#更新配置":
                load_config()
                reply = Reply(ReplyType.INFO, "配置已更新")
            if reply:
                return reply
            session = self.sessions.session_query(query, session_id)
            logger.debug("[CHATGPT] session query={}".format(session.messages))

            api_key = context.get("openai_api_key")
            model = context.get("gpt_model")
            new_args = None
            if model:
                # Per-request model override; copy so the shared defaults stay untouched.
                new_args = self.args.copy()
                new_args["model"] = model
            # if context.get('stream'):
            #     # reply in stream
            #     return self.reply_text_stream(query, new_query, session_id)

            reply_content = self.reply_text(session, api_key, args=new_args)
            logger.debug(
                "[CHATGPT] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(
                    session.messages,
                    session_id,
                    reply_content["content"],
                    reply_content["completion_tokens"],
                )
            )
            # completion_tokens == 0 marks the fallback result built by reply_text on failure.
            if reply_content["completion_tokens"] == 0 and len(reply_content["content"]) > 0:
                reply = Reply(ReplyType.ERROR, reply_content["content"])
            elif reply_content["completion_tokens"] > 0:
                self.sessions.session_reply(reply_content["content"], session_id, reply_content["total_tokens"])
                reply = Reply(ReplyType.TEXT, reply_content["content"])
            else:
                reply = Reply(ReplyType.ERROR, reply_content["content"])
                logger.debug("[CHATGPT] reply {} used 0 tokens.".format(reply_content))
            return reply

        elif context.type == ContextType.IMAGE_CREATE:
            ok, retstring = self.create_img(query, 0)
            reply = None
            if ok:
                reply = Reply(ReplyType.IMAGE_URL, retstring)
            else:
                reply = Reply(ReplyType.ERROR, retstring)
            return reply
        else:
            reply = Reply(ReplyType.ERROR, "Bot不支持处理{}类型的消息".format(context.type))
            return reply

    def reply_text(self, session: ChatGPTSession, api_key=None, args=None, retry_count=0) -> dict:
        """
        call openai's ChatCompletion to get the answer

        :param session: a conversation session
        :param api_key: optional per-request key; None uses openai.api_key
        :param args: optional request arguments; None uses self.args
        :param retry_count: retry count (at most two retries are attempted)
        :return: dict with "completion_tokens" and "content" (plus
            "total_tokens" on success); completion_tokens is 0 on failure
        """
        try:
            if conf().get("rate_limit_chatgpt") and not self.tb4chatgpt.get_token():
                raise openai.error.RateLimitError("RateLimitError: rate limit exceeded")
            # if api_key == None, the default openai.api_key will be used
            if args is None:
                args = self.args
            response = openai.ChatCompletion.create(api_key=api_key, messages=session.messages, **args)
            # logger.debug("[CHATGPT] response={}".format(response))
            # logger.info("[ChatGPT] reply={}, total_tokens={}".format(response.choices[0]['message']['content'], response["usage"]["total_tokens"]))
            return {
                "total_tokens": response["usage"]["total_tokens"],
                "completion_tokens": response["usage"]["completion_tokens"],
                "content": response.choices[0]["message"]["content"],
            }
        except Exception as e:
            need_retry = retry_count < 2
            result = {"completion_tokens": 0, "content": "我现在有点累了,等会再来吧"}
            # logger.warning is used throughout: Logger.warn is a deprecated
            # alias that newer Python versions no longer provide.
            if isinstance(e, openai.error.RateLimitError):
                logger.warning("[CHATGPT] RateLimitError: {}".format(e))
                result["content"] = "提问太快啦,请休息一下再问我吧"
                if need_retry:
                    time.sleep(20)
            elif isinstance(e, openai.error.Timeout):
                logger.warning("[CHATGPT] Timeout: {}".format(e))
                result["content"] = "我没有收到你的消息"
                if need_retry:
                    time.sleep(5)
            elif isinstance(e, openai.error.APIError):
                logger.warning("[CHATGPT] Bad Gateway: {}".format(e))
                result["content"] = "请再问我一次"
                if need_retry:
                    time.sleep(10)
            elif isinstance(e, openai.error.APIConnectionError):
                logger.warning("[CHATGPT] APIConnectionError: {}".format(e))
                need_retry = False
                result["content"] = "我连接不到你的网络"
            else:
                # Unknown failure: do not retry, and drop the session history.
                logger.exception("[CHATGPT] Exception: {}".format(e))
                need_retry = False
                self.sessions.clear_session(session.session_id)

            if need_retry:
                logger.warning("[CHATGPT] 第{}次重试".format(retry_count + 1))
                return self.reply_text(session, api_key, args, retry_count + 1)
            else:
                return result
|
||||
|
||||
|
||||
class AzureChatGPTBot(ChatGPTBot):
    """ChatGPTBot variant configured for the Azure OpenAI service."""

    def __init__(self):
        super().__init__()
        openai.api_type = "azure"
        openai.api_version = conf().get("azure_api_version", "2023-06-01-preview")
        self.args["deployment_id"] = conf().get("azure_deployment_id")

    def create_img(self, query, retry_count=0, api_key=None):
        """
        Submit a text-to-image job and poll its operation URL until it finishes.

        :param query: image caption / prompt
        :param retry_count: accepted for interface compatibility; not used here
        :param api_key: optional key; falls back to openai.api_key
        :return: (True, image_url) on success, (False, error message) on any failure
        """
        api_version = "2022-08-03-preview"
        url = "{}dalle/text-to-image?api-version={}".format(openai.api_base, api_version)
        api_key = api_key or openai.api_key
        headers = {"api-key": api_key, "Content-Type": "application/json"}
        try:
            body = {"caption": query, "resolution": conf().get("image_create_size", "256x256")}
            submission = requests.post(url, headers=headers, json=body)
            operation_location = submission.headers["Operation-Location"]
            retry_after = submission.headers["Retry-after"]
            status = ""
            image_url = ""
            while status != "Succeeded":
                logger.info("waiting for image create..., " + status + ",retry after " + retry_after + " seconds")
                time.sleep(int(retry_after))
                response = requests.get(operation_location, headers=headers)
                payload = response.json()
                status = payload["status"]
                # A failed operation never reaches "Succeeded"; stop polling
                # instead of looping forever.
                if str(status).lower() == "failed":
                    raise RuntimeError("image generation operation failed: {}".format(payload))
                # Only read the result once the operation has succeeded, so a
                # poll of a still-running operation cannot abort the loop with
                # a KeyError.
                if status == "Succeeded":
                    image_url = payload["result"]["contentUrl"]
            return True, image_url
        except Exception as e:
            logger.error("create image error: {}".format(e))
            return False, "图片生成失败"
|
||||
@@ -1,222 +0,0 @@
|
||||
import re
|
||||
import time
|
||||
import json
|
||||
import uuid
|
||||
from curl_cffi import requests
|
||||
from bot.bot import Bot
|
||||
from bot.claude.claude_ai_session import ClaudeAiSession
|
||||
from bot.openai.open_ai_image import OpenAIImage
|
||||
from bot.session_manager import SessionManager
|
||||
from bridge.context import Context, ContextType
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from common.log import logger
|
||||
from config import conf
|
||||
|
||||
|
||||
class ClaudeAIBot(Bot, OpenAIImage):
|
||||
def __init__(self):
    """Read Claude settings from config and resolve the organization id."""
    super().__init__()
    self.sessions = SessionManager(ClaudeAiSession, model=conf().get("model") or "gpt-3.5-turbo")
    self.claude_api_cookie = conf().get("claude_api_cookie")
    self.proxy = conf().get("proxy")
    # session_id -> conversation uuid
    self.con_uuid_dic = {}
    # The same proxy is used for both schemes; None means a direct connection.
    self.proxies = {"http": self.proxy, "https": self.proxy} if self.proxy else None
    self.error = ""
    self.org_uuid = self.get_organization_id()
|
||||
|
||||
def generate_uuid(self):
    """Return a fresh random UUID (version 4) in canonical 8-4-4-4-12 text form."""
    # str(UUID) already yields the five dash-separated groups, so re-joining
    # them reproduces the same string the slice-based formatting built.
    groups = str(uuid.uuid4()).split("-")
    return "-".join(groups)
|
||||
|
||||
def reply(self, query, context: Context = None) -> Reply:
    """
    Dispatch on context.type: chat for text, image creation for image requests.

    :param query: received message text (or image prompt)
    :param context: message context whose type selects the handling
    :return: the chat reply, an IMAGE_URL reply, or an ERROR reply
    """
    kind = context.type
    if kind == ContextType.TEXT:
        return self._chat(query, context)

    if kind == ContextType.IMAGE_CREATE:
        ok, res = self.create_img(query, 0)
        reply_type = ReplyType.IMAGE_URL if ok else ReplyType.ERROR
        return Reply(reply_type, res)

    return Reply(ReplyType.ERROR, "Bot不支持处理{}类型的消息".format(context.type))
|
||||
|
||||
def get_organization_id(self):
|
||||
url = "https://claude.ai/api/organizations"
|
||||
headers = {
|
||||
'User-Agent':
|
||||
'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:109.0) Gecko/20100101 Firefox/115.0',
|
||||
'Accept-Language': 'en-US,en;q=0.5',
|
||||
'Referer': 'https://claude.ai/chats',
|
||||
'Content-Type': 'application/json',
|
||||
'Sec-Fetch-Dest': 'empty',
|
||||
'Sec-Fetch-Mode': 'cors',
|
||||
'Sec-Fetch-Site': 'same-origin',
|
||||
'Connection': 'keep-alive',
|
||||
'Cookie': f'{self.claude_api_cookie}'
|
||||
}
|
||||
try:
|
||||
response = requests.get(url, headers=headers, impersonate="chrome110", proxies =self.proxies, timeout=400)
|
||||
res = json.loads(response.text)
|
||||
uuid = res[0]['uuid']
|
||||
except:
|
||||
if "App unavailable" in response.text:
|
||||
logger.error("IP error: The IP is not allowed to be used on Claude")
|
||||
self.error = "ip所在地区不被claude支持"
|
||||
elif "Invalid authorization" in response.text:
|
||||
logger.error("Cookie error: Invalid authorization of claude, check cookie please.")
|
||||
self.error = "无法通过claude身份验证,请检查cookie"
|
||||
return None
|
||||
return uuid
|
||||
|
||||
def conversation_share_check(self,session_id):
|
||||
if conf().get("claude_uuid") is not None and conf().get("claude_uuid") != "":
|
||||
con_uuid = conf().get("claude_uuid")
|
||||
return con_uuid
|
||||
if session_id not in self.con_uuid_dic:
|
||||
self.con_uuid_dic[session_id] = self.generate_uuid()
|
||||
self.create_new_chat(self.con_uuid_dic[session_id])
|
||||
return self.con_uuid_dic[session_id]
|
||||
|
||||
def check_cookie(self):
|
||||
flag = self.get_organization_id()
|
||||
return flag
|
||||
|
||||
def create_new_chat(self, con_uuid):
|
||||
"""
|
||||
新建claude对话实体
|
||||
:param con_uuid: 对话id
|
||||
:return:
|
||||
"""
|
||||
url = f"https://claude.ai/api/organizations/{self.org_uuid}/chat_conversations"
|
||||
payload = json.dumps({"uuid": con_uuid, "name": ""})
|
||||
headers = {
|
||||
'User-Agent':
|
||||
'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:109.0) Gecko/20100101 Firefox/115.0',
|
||||
'Accept-Language': 'en-US,en;q=0.5',
|
||||
'Referer': 'https://claude.ai/chats',
|
||||
'Content-Type': 'application/json',
|
||||
'Origin': 'https://claude.ai',
|
||||
'DNT': '1',
|
||||
'Connection': 'keep-alive',
|
||||
'Cookie': self.claude_api_cookie,
|
||||
'Sec-Fetch-Dest': 'empty',
|
||||
'Sec-Fetch-Mode': 'cors',
|
||||
'Sec-Fetch-Site': 'same-origin',
|
||||
'TE': 'trailers'
|
||||
}
|
||||
response = requests.post(url, headers=headers, data=payload, impersonate="chrome110", proxies=self.proxies, timeout=400)
|
||||
# Returns JSON of the newly created conversation information
|
||||
return response.json()
|
||||
|
||||
def _chat(self, query, context, retry_count=0) -> Reply:
|
||||
"""
|
||||
发起对话请求
|
||||
:param query: 请求提示词
|
||||
:param context: 对话上下文
|
||||
:param retry_count: 当前递归重试次数
|
||||
:return: 回复
|
||||
"""
|
||||
if retry_count >= 2:
|
||||
# exit from retry 2 times
|
||||
logger.warn("[CLAUDEAI] failed after maximum number of retry times")
|
||||
return Reply(ReplyType.ERROR, "请再问我一次吧")
|
||||
|
||||
try:
|
||||
session_id = context["session_id"]
|
||||
if self.org_uuid is None:
|
||||
return Reply(ReplyType.ERROR, self.error)
|
||||
|
||||
session = self.sessions.session_query(query, session_id)
|
||||
con_uuid = self.conversation_share_check(session_id)
|
||||
|
||||
model = conf().get("model") or "gpt-3.5-turbo"
|
||||
# remove system message
|
||||
if session.messages[0].get("role") == "system":
|
||||
if model == "wenxin" or model == "claude":
|
||||
session.messages.pop(0)
|
||||
logger.info(f"[CLAUDEAI] query={query}")
|
||||
|
||||
# do http request
|
||||
base_url = "https://claude.ai"
|
||||
payload = json.dumps({
|
||||
"completion": {
|
||||
"prompt": f"{query}",
|
||||
"timezone": "Asia/Kolkata",
|
||||
"model": "claude-2"
|
||||
},
|
||||
"organization_uuid": f"{self.org_uuid}",
|
||||
"conversation_uuid": f"{con_uuid}",
|
||||
"text": f"{query}",
|
||||
"attachments": []
|
||||
})
|
||||
headers = {
|
||||
'User-Agent':
|
||||
'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:109.0) Gecko/20100101 Firefox/115.0',
|
||||
'Accept': 'text/event-stream, text/event-stream',
|
||||
'Accept-Language': 'en-US,en;q=0.5',
|
||||
'Referer': 'https://claude.ai/chats',
|
||||
'Content-Type': 'application/json',
|
||||
'Origin': 'https://claude.ai',
|
||||
'DNT': '1',
|
||||
'Connection': 'keep-alive',
|
||||
'Cookie': f'{self.claude_api_cookie}',
|
||||
'Sec-Fetch-Dest': 'empty',
|
||||
'Sec-Fetch-Mode': 'cors',
|
||||
'Sec-Fetch-Site': 'same-origin',
|
||||
'TE': 'trailers'
|
||||
}
|
||||
|
||||
res = requests.post(base_url + "/api/append_message", headers=headers, data=payload,impersonate="chrome110",proxies= self.proxies,timeout=400)
|
||||
if res.status_code == 200 or "pemission" in res.text:
|
||||
# execute success
|
||||
decoded_data = res.content.decode("utf-8")
|
||||
decoded_data = re.sub('\n+', '\n', decoded_data).strip()
|
||||
data_strings = decoded_data.split('\n')
|
||||
completions = []
|
||||
for data_string in data_strings:
|
||||
json_str = data_string[6:].strip()
|
||||
data = json.loads(json_str)
|
||||
if 'completion' in data:
|
||||
completions.append(data['completion'])
|
||||
|
||||
reply_content = ''.join(completions)
|
||||
|
||||
if "rate limi" in reply_content:
|
||||
logger.error("rate limit error: The conversation has reached the system speed limit and is synchronized with Cladue. Please go to the official website to check the lifting time")
|
||||
return Reply(ReplyType.ERROR, "对话达到系统速率限制,与cladue同步,请进入官网查看解除限制时间")
|
||||
logger.info(f"[CLAUDE] reply={reply_content}, total_tokens=invisible")
|
||||
self.sessions.session_reply(reply_content, session_id, 100)
|
||||
return Reply(ReplyType.TEXT, reply_content)
|
||||
else:
|
||||
flag = self.check_cookie()
|
||||
if flag == None:
|
||||
return Reply(ReplyType.ERROR, self.error)
|
||||
|
||||
response = res.json()
|
||||
error = response.get("error")
|
||||
logger.error(f"[CLAUDE] chat failed, status_code={res.status_code}, "
|
||||
f"msg={error.get('message')}, type={error.get('type')}, detail: {res.text}, uuid: {con_uuid}")
|
||||
|
||||
if res.status_code >= 500:
|
||||
# server error, need retry
|
||||
time.sleep(2)
|
||||
logger.warn(f"[CLAUDE] do retry, times={retry_count}")
|
||||
return self._chat(query, context, retry_count + 1)
|
||||
return Reply(ReplyType.ERROR, "提问太快啦,请休息一下再问我吧")
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
# retry
|
||||
time.sleep(2)
|
||||
logger.warn(f"[CLAUDE] do retry, times={retry_count}")
|
||||
return self._chat(query, context, retry_count + 1)
|
||||
@@ -1,9 +0,0 @@
|
||||
from bot.session_manager import Session
|
||||
|
||||
|
||||
class ClaudeAiSession(Session):
    """Session used by the claude.ai bot; it only adds the model name."""

    def __init__(self, session_id, system_prompt=None, model="claude"):
        # Let the base class set itself up first, then record the model.
        super().__init__(session_id, system_prompt)
        self.model = model
        # The reverse-engineered claude API does not support a role prompt,
        # so the reset call stays disabled:
        # self.reset()
|
||||
@@ -1,264 +0,0 @@
|
||||
# access LinkAI knowledge base platform
|
||||
# docs: https://link-ai.tech/platform/link-app/wechat
|
||||
|
||||
import time
|
||||
|
||||
import requests
|
||||
|
||||
from bot.bot import Bot
|
||||
from bot.chatgpt.chat_gpt_session import ChatGPTSession
|
||||
from bot.openai.open_ai_image import OpenAIImage
|
||||
from bot.session_manager import SessionManager
|
||||
from bridge.context import Context, ContextType
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from common.log import logger
|
||||
from config import conf, pconf
|
||||
|
||||
|
||||
class LinkAIBot(Bot):
    """Bot that forwards chat and image requests to the LinkAI HTTP API."""

    # authentication failed
    AUTH_FAILED_CODE = 401
    NO_QUOTA_CODE = 406

    def __init__(self):
        super().__init__()
        self.sessions = SessionManager(ChatGPTSession, model=conf().get("model") or "gpt-3.5-turbo")
        # Optional request overrides; only "max_tokens" is read (in reply_text).
        self.args = {}

    def reply(self, query, context: Context = None) -> Reply:
        """Dispatch a query according to ``context.type``.

        :param query: prompt text (or image description)
        :param context: conversation context; its ``type`` selects the handler
        :return: a text, image-url or error ``Reply``
        """
        if context.type == ContextType.TEXT:
            return self._chat(query, context)
        elif context.type == ContextType.IMAGE_CREATE:
            ok, res = self.create_img(query, 0)
            if ok:
                reply = Reply(ReplyType.IMAGE_URL, res)
            else:
                reply = Reply(ReplyType.ERROR, res)
            return reply
        else:
            reply = Reply(ReplyType.ERROR, "Bot不支持处理{}类型的消息".format(context.type))
            return reply

    def _chat(self, query, context, retry_count=0) -> Reply:
        """Send a chat request.

        :param query: prompt text
        :param context: conversation context
        :param retry_count: current recursive retry count
        :return: reply
        """
        if retry_count >= 2:
            # exit from retry 2 times
            logger.warning("[LINKAI] failed after maximum number of retry times")
            return Reply(ReplyType.ERROR, "请再问我一次吧")

        try:
            # load config
            if context.get("generate_breaked_by"):
                logger.info(f"[LINKAI] won't set appcode because a plugin ({context['generate_breaked_by']}) affected the context")
                app_code = None
            else:
                app_code = context.kwargs.get("app_code") or conf().get("linkai_app_code")
            linkai_api_key = conf().get("linkai_api_key")

            session_id = context["session_id"]

            session = self.sessions.session_query(query, session_id)
            model = conf().get("model") or "gpt-3.5-turbo"
            # remove system message
            if session.messages[0].get("role") == "system":
                if app_code or model == "wenxin":
                    session.messages.pop(0)

            body = {
                "app_code": app_code,
                "messages": session.messages,
                "model": model,  # chat model name; supports gpt-3.5-turbo, gpt-3.5-turbo-16k, gpt-4, wenxin, xunfei
                "temperature": conf().get("temperature"),
                "top_p": conf().get("top_p", 1),
                "frequency_penalty": conf().get("frequency_penalty", 0.0),  # in [-2,2]; larger values favour producing different content
                "presence_penalty": conf().get("presence_penalty", 0.0),  # in [-2,2]; larger values favour producing different content
            }
            file_id = context.kwargs.get("file_id")
            if file_id:
                body["file_id"] = file_id
            logger.info(f"[LINKAI] query={query}, app_code={app_code}, mode={body.get('model')}, file_id={file_id}")
            headers = {"Authorization": "Bearer " + linkai_api_key}

            # do http request
            base_url = conf().get("linkai_api_base", "https://api.link-ai.chat")
            res = requests.post(url=base_url + "/v1/chat/completions", json=body, headers=headers,
                                timeout=conf().get("request_timeout", 180))
            if res.status_code == 200:
                # execute success
                response = res.json()
                reply_content = response["choices"][0]["message"]["content"]
                total_tokens = response["usage"]["total_tokens"]
                logger.info(f"[LINKAI] reply={reply_content}, total_tokens={total_tokens}")
                # The session stores the raw model reply; suffixes are appended
                # afterwards and only appear in what is sent back to the user.
                self.sessions.session_reply(reply_content, session_id, total_tokens)

                agent_suffix = self._fetch_agent_suffix(response)
                if agent_suffix:
                    reply_content += agent_suffix
                else:
                    knowledge_suffix = self._fetch_knowledge_search_suffix(response)
                    if knowledge_suffix:
                        reply_content += knowledge_suffix
                return Reply(ReplyType.TEXT, reply_content)

            else:
                response = res.json()
                # Fall back to an empty dict so a body without "error" is still
                # logged instead of raising AttributeError and forcing retries.
                error = response.get("error") or {}
                logger.error(f"[LINKAI] chat failed, status_code={res.status_code}, "
                             f"msg={error.get('message')}, type={error.get('type')}")

                if res.status_code >= 500:
                    # server error, need retry
                    time.sleep(2)
                    logger.warning(f"[LINKAI] do retry, times={retry_count}")
                    return self._chat(query, context, retry_count + 1)

                return Reply(ReplyType.ERROR, "提问太快啦，请休息一下再问我吧")

        except Exception as e:
            logger.exception(e)
            # retry
            time.sleep(2)
            logger.warning(f"[LINKAI] do retry, times={retry_count}")
            return self._chat(query, context, retry_count + 1)

    def reply_text(self, session: ChatGPTSession, app_code="", retry_count=0) -> dict:
        """Request a completion for an existing session.

        :param session: session whose ``messages`` are sent as the conversation
        :param app_code: LinkAI app code to send with the request
        :param retry_count: current recursive retry count
        :return: dict with ``total_tokens``, ``completion_tokens`` and ``content``;
                 both token counts are 0 when the request failed
        """
        if retry_count >= 2:
            # exit from retry 2 times
            logger.warning("[LINKAI] failed after maximum number of retry times")
            return {
                "total_tokens": 0,
                "completion_tokens": 0,
                "content": "请再问我一次吧"
            }

        try:
            body = {
                "app_code": app_code,
                "messages": session.messages,
                "model": conf().get("model") or "gpt-3.5-turbo",  # chat model name; supports gpt-3.5-turbo, gpt-3.5-turbo-16k, gpt-4, wenxin, xunfei
                "temperature": conf().get("temperature"),
                "top_p": conf().get("top_p", 1),
                "frequency_penalty": conf().get("frequency_penalty", 0.0),  # in [-2,2]; larger values favour producing different content
                "presence_penalty": conf().get("presence_penalty", 0.0),  # in [-2,2]; larger values favour producing different content
            }
            if self.args.get("max_tokens"):
                body["max_tokens"] = self.args.get("max_tokens")
            headers = {"Authorization": "Bearer " + conf().get("linkai_api_key")}

            # do http request
            base_url = conf().get("linkai_api_base", "https://api.link-ai.chat")
            res = requests.post(url=base_url + "/v1/chat/completions", json=body, headers=headers,
                                timeout=conf().get("request_timeout", 180))
            if res.status_code == 200:
                # execute success
                response = res.json()
                reply_content = response["choices"][0]["message"]["content"]
                total_tokens = response["usage"]["total_tokens"]
                logger.info(f"[LINKAI] reply={reply_content}, total_tokens={total_tokens}")
                return {
                    "total_tokens": total_tokens,
                    "completion_tokens": response["usage"]["completion_tokens"],
                    "content": reply_content,
                }

            else:
                response = res.json()
                # Same guard as in _chat: tolerate a body without "error".
                error = response.get("error") or {}
                logger.error(f"[LINKAI] chat failed, status_code={res.status_code}, "
                             f"msg={error.get('message')}, type={error.get('type')}")

                if res.status_code >= 500:
                    # server error, need retry
                    time.sleep(2)
                    logger.warning(f"[LINKAI] do retry, times={retry_count}")
                    return self.reply_text(session, app_code, retry_count + 1)

                return {
                    "total_tokens": 0,
                    "completion_tokens": 0,
                    "content": "提问太快啦，请休息一下再问我吧"
                }

        except Exception as e:
            logger.exception(e)
            # retry
            time.sleep(2)
            logger.warning(f"[LINKAI] do retry, times={retry_count}")
            return self.reply_text(session, app_code, retry_count + 1)

    def create_img(self, query, retry_count=0, api_key=None):
        """Generate one image through the LinkAI images endpoint.

        :param query: image description
        :param retry_count: accepted for signature compatibility; not used here
        :param api_key: accepted for signature compatibility; not used here
        :return: ``(True, image_url)`` on success, ``(False, message)`` on failure
        """
        try:
            logger.info("[LinkImage] image_query={}".format(query))
            headers = {
                "Content-Type": "application/json",
                "Authorization": f"Bearer {conf().get('linkai_api_key')}"
            }
            data = {
                "prompt": query,
                "n": 1,
                "model": conf().get("text_to_image") or "dall-e-2",
                "response_format": "url",
                "img_proxy": conf().get("image_proxy")
            }
            url = conf().get("linkai_api_base", "https://api.link-ai.chat") + "/v1/images/generations"
            res = requests.post(url, headers=headers, json=data, timeout=(5, 90))
            image_url = res.json()["data"][0]["url"]
            logger.info("[OPEN_AI] image_url={}".format(image_url))
            return True, image_url

        except Exception as e:
            logger.error(format(e))
            return False, "画图出现问题，请休息一下再问我吧"

    def _fetch_knowledge_search_suffix(self, response) -> str:
        """Return the configured "search miss" suffix when the knowledge base missed.

        The suffix is returned when the ``linkai`` plugin config enables
        ``search_miss_text_enabled`` and either the response reports no search
        hit or its first similarity is below the configured threshold.
        Returns ``None`` in every other case, including on error.
        """
        try:
            knowledge_base = response.get("knowledge_base")
            if knowledge_base:
                search_hit = knowledge_base.get("search_hit")
                first_similarity = knowledge_base.get("first_similarity")
                logger.info(f"[LINKAI] knowledge base, search_hit={search_hit}, first_similarity={first_similarity}")
                plugin_config = pconf("linkai")
                if plugin_config and plugin_config.get("knowledge_base") and plugin_config.get("knowledge_base").get("search_miss_text_enabled"):
                    search_miss_similarity = plugin_config.get("knowledge_base").get("search_miss_similarity")
                    search_miss_text = plugin_config.get("knowledge_base").get("search_miss_suffix")
                    if not search_hit:
                        return search_miss_text
                    # Guard against a missing similarity so the comparison does
                    # not raise TypeError and spam the log with a traceback.
                    if search_miss_similarity and first_similarity is not None \
                            and float(search_miss_similarity) > first_similarity:
                        return search_miss_text
        except Exception as e:
            logger.exception(e)

    def _fetch_agent_suffix(self, response):
        """Build the text block listing the plugins used by the agent chain.

        :param response: decoded chat-completion response
        :return: the suffix string when the response has an agent chain with
                 ``need_show_plugin`` set; otherwise ``None`` (also on error)
        """
        try:
            plugin_list = []
            logger.debug(f"[LinkAgent] res={response}")
            agent = response.get("agent")
            if agent and agent.get("chain") and agent.get("need_show_plugin"):
                chain = agent.get("chain")
                # Loop-invariant, so read it once up front.
                need_show_thought = agent.get("need_show_thought")
                suffix = "\n\n- - - - - - - - - - - -"
                last_index = len(chain) - 1
                for i, turn in enumerate(chain):
                    plugin_name = turn.get('plugin_name')
                    suffix += "\n"
                    if turn.get("thought") and plugin_name and need_show_thought:
                        suffix += f"{turn.get('thought')}\n"
                    if plugin_name:
                        plugin_list.append(plugin_name)
                        suffix += f"{turn.get('plugin_icon')} {plugin_name}"
                        if turn.get('plugin_input'):
                            suffix += f"：{turn.get('plugin_input')}"
                    # Blank separator between turns, but not after the last one.
                    if i < last_index:
                        suffix += "\n"
                logger.info(f"[LinkAgent] use plugins: {plugin_list}")
                return suffix
        except Exception as e:
            logger.exception(e)
|
||||
@@ -1,122 +0,0 @@
|
||||
# encoding:utf-8
|
||||
|
||||
import time
|
||||
|
||||
import openai
|
||||
import openai.error
|
||||
|
||||
from bot.bot import Bot
|
||||
from bot.openai.open_ai_image import OpenAIImage
|
||||
from bot.openai.open_ai_session import OpenAISession
|
||||
from bot.session_manager import SessionManager
|
||||
from bridge.context import ContextType
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from common.log import logger
|
||||
from config import conf
|
||||
|
||||
# Module-level mapping, created empty at import time.
user_session = {}
|
||||
|
||||
|
||||
# OpenAI对话模型API (可用)
|
||||
class OpenAIBot(Bot, OpenAIImage):
    """Bot built on the OpenAI text-completion API, with image creation via ``create_img``."""

    def __init__(self):
        super().__init__()
        openai.api_key = conf().get("open_ai_api_key")
        if conf().get("open_ai_api_base"):
            openai.api_base = conf().get("open_ai_api_base")
        proxy = conf().get("proxy")
        if proxy:
            openai.proxy = proxy

        self.sessions = SessionManager(OpenAISession, model=conf().get("model") or "text-davinci-003")
        # Keyword arguments passed to every openai.Completion.create call.
        self.args = {
            "model": conf().get("model") or "text-davinci-003",  # name of the completion model
            "temperature": conf().get("temperature", 0.9),  # in [0,1]; larger means less deterministic replies
            "max_tokens": 1200,  # upper bound on the reply length
            "top_p": 1,
            "frequency_penalty": conf().get("frequency_penalty", 0.0),  # in [-2,2]; larger values favour producing different content
            "presence_penalty": conf().get("presence_penalty", 0.0),  # in [-2,2]; larger values favour producing different content
            "request_timeout": conf().get("request_timeout", None),  # request timeout; the openai API defaults to 600, hard questions usually need longer
            "timeout": conf().get("request_timeout", None),  # retry timeout; retries happen automatically within this window
            "stop": ["\n\n\n"],
        }

    def reply(self, query, context=None):
        """Produce a reply for ``query`` according to ``context.type``.

        Text contexts handle the two memory-clearing commands, otherwise ask
        the model. Image-creation contexts call ``create_img``. Any other
        context, or a missing one, falls through and returns ``None``.
        """
        # acquire reply content
        if context and context.type:
            if context.type == ContextType.TEXT:
                logger.info("[OPEN_AI] query={}".format(query))
                session_id = context["session_id"]
                reply = None
                if query == "#清除记忆":
                    self.sessions.clear_session(session_id)
                    reply = Reply(ReplyType.INFO, "记忆已清除")
                elif query == "#清除所有":
                    self.sessions.clear_all_session()
                    reply = Reply(ReplyType.INFO, "所有人记忆已清除")
                else:
                    session = self.sessions.session_query(query, session_id)
                    result = self.reply_text(session)
                    total_tokens, completion_tokens, reply_content = (
                        result["total_tokens"],
                        result["completion_tokens"],
                        result["content"],
                    )
                    logger.debug(
                        "[OPEN_AI] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(str(session), session_id, reply_content, completion_tokens)
                    )

                    # reply_text reports failure as total_tokens == 0 with the
                    # user-facing error message in "content".
                    if total_tokens == 0:
                        reply = Reply(ReplyType.ERROR, reply_content)
                    else:
                        self.sessions.session_reply(reply_content, session_id, total_tokens)
                        reply = Reply(ReplyType.TEXT, reply_content)
                return reply
            elif context.type == ContextType.IMAGE_CREATE:
                ok, retstring = self.create_img(query, 0)
                reply = None
                if ok:
                    reply = Reply(ReplyType.IMAGE_URL, retstring)
                else:
                    reply = Reply(ReplyType.ERROR, retstring)
                return reply

    def reply_text(self, session: OpenAISession, retry_count=0):
        """Call the completion API for ``session``, retrying on transient errors.

        :param session: session whose string form is used as the prompt
        :param retry_count: number of retries already performed
        :return: dict with ``total_tokens``, ``completion_tokens`` and ``content``;
                 on failure both token counts are 0 and ``content`` is an
                 error message for the user
        """
        try:
            response = openai.Completion.create(prompt=str(session), **self.args)
            res_content = response.choices[0]["text"].strip().replace("<|endoftext|>", "")
            total_tokens = response["usage"]["total_tokens"]
            completion_tokens = response["usage"]["completion_tokens"]
            logger.info("[OPEN_AI] reply={}".format(res_content))
            return {
                "total_tokens": total_tokens,
                "completion_tokens": completion_tokens,
                "content": res_content,
            }
        except Exception as e:
            need_retry = retry_count < 2
            # "total_tokens" must be present: reply() indexes it unconditionally
            # and uses 0 as the failure marker. Without the key every failure
            # surfaced as a KeyError instead of an error reply.
            result = {"total_tokens": 0, "completion_tokens": 0, "content": "我现在有点累了，等会再来吧"}
            if isinstance(e, openai.error.RateLimitError):
                logger.warning("[OPEN_AI] RateLimitError: {}".format(e))
                result["content"] = "提问太快啦，请休息一下再问我吧"
                if need_retry:
                    time.sleep(20)
            elif isinstance(e, openai.error.Timeout):
                logger.warning("[OPEN_AI] Timeout: {}".format(e))
                result["content"] = "我没有收到你的消息"
                if need_retry:
                    time.sleep(5)
            elif isinstance(e, openai.error.APIConnectionError):
                logger.warning("[OPEN_AI] APIConnectionError: {}".format(e))
                need_retry = False
                result["content"] = "我连接不到你的网络"
            else:
                logger.warning("[OPEN_AI] Exception: {}".format(e))
                need_retry = False
                # Unknown failure: drop the session rather than retrying.
                self.sessions.clear_session(session.session_id)

            if need_retry:
                logger.warning("[OPEN_AI] 第{}次重试".format(retry_count + 1))
                return self.reply_text(session, retry_count + 1)
            else:
                return result
|
||||
@@ -1,43 +0,0 @@
|
||||
import time
|
||||
|
||||
import openai
|
||||
import openai.error
|
||||
|
||||
from common.log import logger
|
||||
from common.token_bucket import TokenBucket
|
||||
from config import conf
|
||||
|
||||
|
||||
# OPENAI提供的画图接口
|
||||
class OpenAIImage(object):
    """Image generation through the OpenAI image API."""

    def __init__(self):
        openai.api_key = conf().get("open_ai_api_key")
        # The token bucket only exists when a DALL-E rate limit is configured.
        if conf().get("rate_limit_dalle"):
            self.tb4dalle = TokenBucket(conf().get("rate_limit_dalle", 50))

    def create_img(self, query, retry_count=0, api_key=None):
        """Generate one image for ``query``.

        :param query: image description
        :param retry_count: number of retries already performed; a rate-limit
                            error is retried once
        :param api_key: API key forwarded to ``openai.Image.create``
        :return: ``(True, image_url)`` on success, ``(False, message)`` on failure
        """
        try:
            if conf().get("rate_limit_dalle") and not self.tb4dalle.get_token():
                return False, "请求太快了，请休息一下再问我吧"
            logger.info("[OPEN_AI] image_query={}".format(query))
            response = openai.Image.create(
                api_key=api_key,
                prompt=query,  # image description
                n=1,  # number of images generated per request
                model=conf().get("text_to_image") or "dall-e-2",
                # size=conf().get("image_create_size", "256x256"),  # image size; one of 256x256, 512x512, 1024x1024
            )
            image_url = response["data"][0]["url"]
            logger.info("[OPEN_AI] image_url={}".format(image_url))
            return True, image_url
        except openai.error.RateLimitError as e:
            logger.warning(e)
            if retry_count < 1:
                time.sleep(5)
                logger.warning("[OPEN_AI] ImgCreate RateLimit exceed, 第{}次重试".format(retry_count + 1))
                # Forward api_key so the retry uses the same credentials as the
                # first attempt; it was previously dropped on retry.
                return self.create_img(query, retry_count + 1, api_key)
            else:
                return False, "画图出现问题，请休息一下再问我吧"
        except Exception as e:
            logger.exception(e)
            return False, "画图出现问题，请休息一下再问我吧"
|
||||
1030
bridge/agent_bridge.py
Normal file
1030
bridge/agent_bridge.py
Normal file
File diff suppressed because it is too large
Load Diff
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user