mirror of
https://github.com/zhayujie/chatgpt-on-wechat.git
synced 2026-06-02 18:17:11 +08:00
Compare commits
1123 Commits
1.2.2.2
...
feat-qq-ch
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a4d54f58c8 | ||
|
|
005a0e1bad | ||
|
|
46d97fd57d | ||
|
|
72a26b6353 | ||
|
|
89a4033fbf | ||
|
|
39a5dc64bd | ||
|
|
d4bdd9b1b7 | ||
|
|
2f5ba87280 | ||
|
|
8b45d6c750 | ||
|
|
4ecd4df2d4 | ||
|
|
a42f31fe52 | ||
|
|
d4480b695e | ||
|
|
c4b5f7fbae | ||
|
|
ba915f2cc0 | ||
|
|
4b91140f31 | ||
|
|
9879878dd0 | ||
|
|
d78105d57c | ||
|
|
153c9e3565 | ||
|
|
c11623596d | ||
|
|
e791a77f77 | ||
|
|
b641bffb2c | ||
|
|
ee0c47ac1e | ||
|
|
eba90e9343 | ||
|
|
d8374d0fa5 | ||
|
|
fa61744c6d | ||
|
|
4fec55cc01 | ||
|
|
1767413712 | ||
|
|
734c8fa84f | ||
|
|
9a8d422554 | ||
|
|
b21e945c76 | ||
|
|
a02bf1ea09 | ||
|
|
eda82bac92 | ||
|
|
e8d4f7dc4f | ||
|
|
c4a93b7789 | ||
|
|
c3f9925097 | ||
|
|
2a0cf7511a | ||
|
|
d0a70d3339 | ||
|
|
f37e4675dd | ||
|
|
4e32f67eeb | ||
|
|
36d54cab52 | ||
|
|
9d8df10dcf | ||
|
|
45ea88e070 | ||
|
|
d5d0b947f5 | ||
|
|
f775f1f11e | ||
|
|
f1e888f3de | ||
|
|
71c8436e90 | ||
|
|
08c69f5e9b | ||
|
|
a50fafaca2 | ||
|
|
3c6781d240 | ||
|
|
3b8b5625f8 | ||
|
|
6be2034110 | ||
|
|
924dc79f00 | ||
|
|
ccb9030d3c | ||
|
|
8623287ac1 | ||
|
|
022c13f3a4 | ||
|
|
0687916e7f | ||
|
|
bb868b83ba | ||
|
|
24298130b9 | ||
|
|
6e5ee92ebd | ||
|
|
5b91fe04aa | ||
|
|
1623deb3ee | ||
|
|
4a16e05b7a | ||
|
|
f1c04bc60d | ||
|
|
84c6f31c76 | ||
|
|
9d528190bf | ||
|
|
0f23b209ad | ||
|
|
63d9325900 | ||
|
|
f342097f81 | ||
|
|
b4806c4366 | ||
|
|
ff37d8a577 | ||
|
|
a773eb7893 | ||
|
|
7c67513d24 | ||
|
|
6ed85029c5 | ||
|
|
e9c57ddf4d | ||
|
|
a33ce97ed9 | ||
|
|
b788a3dd4e | ||
|
|
fccfa92d7e | ||
|
|
8705bf0a70 | ||
|
|
9318138af7 | ||
|
|
269fa7d2d5 | ||
|
|
e99837a8b9 | ||
|
|
553861a2c4 | ||
|
|
628a85d1be | ||
|
|
2cb54514a4 | ||
|
|
6db22827f2 | ||
|
|
4cc6d5426b | ||
|
|
7d258b5202 | ||
|
|
c8d19ee0bc | ||
|
|
d891312032 | ||
|
|
5edbf4ce32 | ||
|
|
3ddbdd713d | ||
|
|
9ba107b511 | ||
|
|
c9adddb76a | ||
|
|
f0a12d5ff5 | ||
|
|
7cce224499 | ||
|
|
97397ca585 | ||
|
|
f2fbc602a8 | ||
|
|
925d728a86 | ||
|
|
f5f229871b | ||
|
|
9917552b4b | ||
|
|
adca89b973 | ||
|
|
29bfbecdc9 | ||
|
|
1a7a8c98d9 | ||
|
|
cddb38ac3d | ||
|
|
394853c0fb | ||
|
|
c0702c8b36 | ||
|
|
d610608391 | ||
|
|
9082eec91d | ||
|
|
f1a1413b5f | ||
|
|
c1e7f9af9b | ||
|
|
1c71c4e38b | ||
|
|
5e3eccb3f6 | ||
|
|
e1dc037eb9 | ||
|
|
97e9b4c801 | ||
|
|
52d7cad735 | ||
|
|
c0b1d270ba | ||
|
|
e59a2892e4 | ||
|
|
5fa0376a49 | ||
|
|
05a33042c8 | ||
|
|
ce58f23cbc | ||
|
|
b6fc9fa370 | ||
|
|
00ae38faae | ||
|
|
ab28ee58ab | ||
|
|
48db538a2e | ||
|
|
46945942e1 | ||
|
|
a24b26a1ef | ||
|
|
6f8421cdd5 | ||
|
|
284cd9bca9 | ||
|
|
23fd6b8d2b | ||
|
|
4f0ea5d756 | ||
|
|
6c218331b1 | ||
|
|
cea7fb7490 | ||
|
|
8acf2dbdfe | ||
|
|
0542700f90 | ||
|
|
5264f7ce18 | ||
|
|
051ffd78a3 | ||
|
|
bea95d4fae | ||
|
|
fdf7bc312f | ||
|
|
5b094e1097 | ||
|
|
9ad3968084 | ||
|
|
3958b6aae1 | ||
|
|
eaa413caf0 | ||
|
|
9095225b5b | ||
|
|
c529f86dbc | ||
|
|
e4fcfa356a | ||
|
|
8218cff7c1 | ||
|
|
6949bbcf39 | ||
|
|
480c60c0a7 | ||
|
|
eec10cb5db | ||
|
|
02c83d8689 | ||
|
|
72b1cacea1 | ||
|
|
c72cda3386 | ||
|
|
867442155e | ||
|
|
229b14b6fc | ||
|
|
158c87ab8b | ||
|
|
cb303e6109 | ||
|
|
a77a8741b5 | ||
|
|
3d63459c25 | ||
|
|
ce63de3c58 | ||
|
|
4b3b1219b5 | ||
|
|
73b069a76c | ||
|
|
101cf8d108 | ||
|
|
2e926dfb6e | ||
|
|
501866d12a | ||
|
|
39bcb0869f | ||
|
|
a7b99cde4e | ||
|
|
60abcd92a3 | ||
|
|
cdd36e7052 | ||
|
|
c6ac175ce4 | ||
|
|
46bcd87c23 | ||
|
|
ab74be8e33 | ||
|
|
d8298b3eab | ||
|
|
50e60e6d05 | ||
|
|
5d02acbf37 | ||
|
|
8901d91f96 | ||
|
|
b55021bb3d | ||
|
|
0ef51b85e6 | ||
|
|
c77566cc02 | ||
|
|
c1bcedfb51 | ||
|
|
d085a3c7d7 | ||
|
|
46fa07e4a9 | ||
|
|
a8d5309c90 | ||
|
|
77c2bfcc1e | ||
|
|
4c8712d683 | ||
|
|
d337140577 | ||
|
|
99c273a293 | ||
|
|
85578a06b7 | ||
|
|
6f70a8efda | ||
|
|
c693e39196 | ||
|
|
4a1fae3cb4 | ||
|
|
08b592816b | ||
|
|
0e85fcfe51 | ||
|
|
8ef788e799 | ||
|
|
645c8899b1 | ||
|
|
9bf5b0fc48 | ||
|
|
07959a3bff | ||
|
|
86a6182e41 | ||
|
|
89e229ab75 | ||
|
|
624917fac4 | ||
|
|
489894c61d | ||
|
|
ac87979cb7 | ||
|
|
5fd3e85a83 | ||
|
|
0e53ba4311 | ||
|
|
3ce57ef851 | ||
|
|
481570d059 | ||
|
|
04442b7ddb | ||
|
|
e1a71723bc | ||
|
|
f044fb8b47 | ||
|
|
e3350d5bec | ||
|
|
8a69d4354e | ||
|
|
dd6a9c26bd | ||
|
|
49fb4034c6 | ||
|
|
5a466d0ff6 | ||
|
|
bb850bb6c5 | ||
|
|
25cf6823d0 | ||
|
|
7e12744b8b | ||
|
|
8f2432e0f8 | ||
|
|
94451db638 | ||
|
|
f8b8eeec3a | ||
|
|
a4260cc5de | ||
|
|
8c1622798b | ||
|
|
e75bed1be5 | ||
|
|
8c0517de0f | ||
|
|
94e78365a5 | ||
|
|
29c056ca65 | ||
|
|
d8c57f27db | ||
|
|
3cac2bad55 | ||
|
|
e7905fdf49 | ||
|
|
a492bc2242 | ||
|
|
e663364f64 | ||
|
|
ef6466e26f | ||
|
|
7fcbbf1cdc | ||
|
|
ec6ad51ff7 | ||
|
|
1e80c59448 | ||
|
|
e48cb4fd5d | ||
|
|
7c9fbd2625 | ||
|
|
0f504415fb | ||
|
|
4998c324d1 | ||
|
|
fb5fbe76e8 | ||
|
|
223b0bfc88 | ||
|
|
51094a68c8 | ||
|
|
83cb1ec911 | ||
|
|
a77e4bfb7a | ||
|
|
654c177333 | ||
|
|
b92669ba33 | ||
|
|
f2e4f6607d | ||
|
|
5ec909c565 | ||
|
|
a84f31d54a | ||
|
|
e0dd21406d | ||
|
|
72f5f7a0b8 | ||
|
|
e3d20085c5 | ||
|
|
8bf1aef801 | ||
|
|
5f7ade20dc | ||
|
|
70d7e52df0 | ||
|
|
8e6afa5614 | ||
|
|
a1ae3804e3 | ||
|
|
814ce7a43b | ||
|
|
628f75009e | ||
|
|
03fc8c1202 | ||
|
|
8c8e996c87 | ||
|
|
933bb0b1fb | ||
|
|
931fbc3eb5 | ||
|
|
3db5e70a3d | ||
|
|
7b19b70d90 | ||
|
|
99b8103d70 | ||
|
|
7167310ccd | ||
|
|
263667a2d4 | ||
|
|
d5cef291f6 | ||
|
|
c8d166e833 | ||
|
|
6e25782d8b | ||
|
|
c3127f7e84 | ||
|
|
7b90fb018b | ||
|
|
e8bc173cd7 | ||
|
|
4d1cdf5207 | ||
|
|
57a473364e | ||
|
|
40b62e9d38 | ||
|
|
ead5f9926b | ||
|
|
814b6753c2 | ||
|
|
ce505251f8 | ||
|
|
5d2a987aaa | ||
|
|
4d67e08723 | ||
|
|
2e71dd5fe2 | ||
|
|
c3b9643227 | ||
|
|
0aad5dc2b7 | ||
|
|
cec900168f | ||
|
|
f9b1c403d5 | ||
|
|
9024b602f5 | ||
|
|
c139fd9a57 | ||
|
|
e299b68163 | ||
|
|
7777a53a82 | ||
|
|
3e185dbbfe | ||
|
|
e8a32af369 | ||
|
|
7b0ec6687e | ||
|
|
ec1c6c7b92 | ||
|
|
8dfaa86760 | ||
|
|
323aebd1be | ||
|
|
436c038a2f | ||
|
|
ccd50ec6c0 | ||
|
|
a7541c2c0f | ||
|
|
c3a57d756c | ||
|
|
aa300a4c98 | ||
|
|
83ea7352b9 | ||
|
|
9050712cd8 | ||
|
|
8d92fdbb6e | ||
|
|
a2442ec1b9 | ||
|
|
71662c9cd9 | ||
|
|
54ff5dbcc2 | ||
|
|
4ab7bd3b51 | ||
|
|
ef3c61a297 | ||
|
|
abf79bf60c | ||
|
|
5d3cecd926 | ||
|
|
16324e7283 | ||
|
|
9f7e2e1572 | ||
|
|
857ce1d530 | ||
|
|
be0d72775d | ||
|
|
7832a2495b | ||
|
|
0506b7f735 | ||
|
|
4c0b7942f0 | ||
|
|
651c840c4a | ||
|
|
2a351ca415 | ||
|
|
49b7106d71 | ||
|
|
8bf633f539 | ||
|
|
0f8efcb4b0 | ||
|
|
c567641c5c | ||
|
|
bdc3820382 | ||
|
|
33a69a7907 | ||
|
|
a4d0e9bbc3 | ||
|
|
afc753e1d2 | ||
|
|
e641a41224 | ||
|
|
79305c0632 | ||
|
|
ef2ce3f09d | ||
|
|
71c18c04fc | ||
|
|
cf84e57f81 | ||
|
|
9421d44579 | ||
|
|
5cd2ae8cc8 | ||
|
|
22d67b3a59 | ||
|
|
e102cbb8c4 | ||
|
|
d90eeb7ee4 | ||
|
|
1989d53031 | ||
|
|
04ef0907b4 | ||
|
|
517b43561c | ||
|
|
ccb8c7227f | ||
|
|
9fbfeeb04f | ||
|
|
8b753a5a1f | ||
|
|
d25cab0627 | ||
|
|
84da0a8a35 | ||
|
|
6f665cffba | ||
|
|
aea8ac2e97 | ||
|
|
8418fa7b45 | ||
|
|
9cc4d0ee07 | ||
|
|
da60831c44 | ||
|
|
0773174a20 | ||
|
|
70e007d8ca | ||
|
|
fcc4d02c2f | ||
|
|
f4a5f00593 | ||
|
|
1170ed6566 | ||
|
|
883f0d449b | ||
|
|
f4c62e7844 | ||
|
|
f0d212a9d2 | ||
|
|
76a8974034 | ||
|
|
0614e822f4 | ||
|
|
6f682c9a2e | ||
|
|
a9fdbc31c5 | ||
|
|
086fdb5856 | ||
|
|
63c8ef4f17 | ||
|
|
736f6523c7 | ||
|
|
8b0b360d25 | ||
|
|
80b84e2ee6 | ||
|
|
b5b7d86f7b | ||
|
|
f20d704390 | ||
|
|
e4e1e2e944 | ||
|
|
6bc7eeb4cc | ||
|
|
656ed5de7b | ||
|
|
a11d695c78 | ||
|
|
c4f9acd5c5 | ||
|
|
5ef929dc42 | ||
|
|
c8cf27b544 | ||
|
|
bb5ecfc398 | ||
|
|
c91e7c35bb | ||
|
|
532d56df2d | ||
|
|
111ad44029 | ||
|
|
6b02bae957 | ||
|
|
6831743416 | ||
|
|
63e2f42636 | ||
|
|
f6e6805453 | ||
|
|
ad77ad8f2b | ||
|
|
469524e8ae | ||
|
|
f4f55d5dfd | ||
|
|
c248d0f3f4 | ||
|
|
648a04b513 | ||
|
|
bdc86c16ec | ||
|
|
21efd17c17 | ||
|
|
aaa75e7b62 | ||
|
|
6d0cef3152 | ||
|
|
c18472289f | ||
|
|
02b7c70a81 | ||
|
|
4eaa2b93c6 | ||
|
|
d347905373 | ||
|
|
f495213b2c | ||
|
|
9b125913ae | ||
|
|
da81f05804 | ||
|
|
9a371a4d4d | ||
|
|
1e92828f1a | ||
|
|
7e724b3fa3 | ||
|
|
3f5b976a87 | ||
|
|
49f2339cc2 | ||
|
|
29f1699de8 | ||
|
|
c415485801 | ||
|
|
6937673472 | ||
|
|
c4f10fe876 | ||
|
|
55ca652ad8 | ||
|
|
3effd5afd1 | ||
|
|
000c2029de | ||
|
|
ab88e3af06 | ||
|
|
b544a4c954 | ||
|
|
baff5fafec | ||
|
|
1673de73ba | ||
|
|
e68936e36e | ||
|
|
7dbd195e45 | ||
|
|
3dc22f98bf | ||
|
|
805e870c18 | ||
|
|
de2c031797 | ||
|
|
3aa571aa1b | ||
|
|
3e4969efe6 | ||
|
|
446e94df76 | ||
|
|
5b26066a4c | ||
|
|
8a80de5c3f | ||
|
|
52a490c87e | ||
|
|
29490741fd | ||
|
|
f0e416455f | ||
|
|
f7a2c97943 | ||
|
|
993853757b | ||
|
|
a3abfb987d | ||
|
|
2711fa1b1b | ||
|
|
1f7afaba07 | ||
|
|
e02c8bff81 | ||
|
|
22391ba1a5 | ||
|
|
a05781ec19 | ||
|
|
f898ed6a2a | ||
|
|
e6d0a15b54 | ||
|
|
49cff026e2 | ||
|
|
08f0023cfd | ||
|
|
e311466ee6 | ||
|
|
56789e68d7 | ||
|
|
87525bb383 | ||
|
|
bb2880191a | ||
|
|
4f1acf26d6 | ||
|
|
fc2d6b21ac | ||
|
|
b9e84fefbd | ||
|
|
91f5ffb2d9 | ||
|
|
70ff2341cb | ||
|
|
74eed93497 | ||
|
|
d02e26c014 | ||
|
|
523cade7c3 | ||
|
|
e22c183ca9 | ||
|
|
3afd99da30 | ||
|
|
f44979f983 | ||
|
|
095f9cc108 | ||
|
|
1089076fce | ||
|
|
cad3b691a9 | ||
|
|
bac21426d3 | ||
|
|
c4a35314cd | ||
|
|
7090722565 | ||
|
|
6d972c7c18 | ||
|
|
6961a88feb | ||
|
|
c41ec13984 | ||
|
|
ca8e06e562 | ||
|
|
200cd33a8e | ||
|
|
1da7991c65 | ||
|
|
fdfb7e369a | ||
|
|
c2b01cc957 | ||
|
|
5de8e94bb4 | ||
|
|
7a2c15d912 | ||
|
|
70344dd214 | ||
|
|
405372d1a7 | ||
|
|
b8c5174da5 | ||
|
|
1f6f9103d9 | ||
|
|
6431487c7a | ||
|
|
8b2d1189db | ||
|
|
b777f27cb7 | ||
|
|
b31c3b124a | ||
|
|
fa1e965fba | ||
|
|
91dc8b4d58 | ||
|
|
6d16ea8830 | ||
|
|
7db4253264 | ||
|
|
4d2b7d9bf9 | ||
|
|
8f6f4acb88 | ||
|
|
f20d84cb37 | ||
|
|
afbdf1d5d5 | ||
|
|
bc8364d594 | ||
|
|
c8d388f70f | ||
|
|
be13cc3194 | ||
|
|
a46320e744 | ||
|
|
071709d263 | ||
|
|
93a32ae5ff | ||
|
|
eee96f226f | ||
|
|
e19a8b479c | ||
|
|
9ef459112e | ||
|
|
e96474bd5c | ||
|
|
6fed719e09 | ||
|
|
99aac76618 | ||
|
|
599f458201 | ||
|
|
2f8099059c | ||
|
|
e24f177832 | ||
|
|
48cc143e88 | ||
|
|
b09b46c045 | ||
|
|
2c6583cc9c | ||
|
|
e381d1bfb8 | ||
|
|
eac619d54f | ||
|
|
a6ef3bc0ce | ||
|
|
118122c541 | ||
|
|
bfdf33ac09 | ||
|
|
fa3370df5b | ||
|
|
f1e51672c5 | ||
|
|
91f97b2728 | ||
|
|
2c542e03fe | ||
|
|
71a11b4267 | ||
|
|
ea642757db | ||
|
|
fb72b601aa | ||
|
|
27e507e744 | ||
|
|
4db19f816f | ||
|
|
096d5776d1 | ||
|
|
3d799eb4d9 | ||
|
|
e4ac3afa4d | ||
|
|
d38e4eed5b | ||
|
|
97787fac91 | ||
|
|
b494ee2f1c | ||
|
|
31ac80a074 | ||
|
|
c8896450f6 | ||
|
|
c662fa4c63 | ||
|
|
db2ee802ca | ||
|
|
d40e915e2b | ||
|
|
c0616e7efa | ||
|
|
01660597e3 | ||
|
|
c5b549f450 | ||
|
|
802d8457bb | ||
|
|
c3a3df67b0 | ||
|
|
5798aeb3cd | ||
|
|
cc81dd9172 | ||
|
|
44fdadda08 | ||
|
|
66a014150b | ||
|
|
1da596639f | ||
|
|
76614ae9e5 | ||
|
|
6ddddffc0f | ||
|
|
dd95f849d4 | ||
|
|
22c7f8fe9e | ||
|
|
3d47be1f49 | ||
|
|
5e399c46b1 | ||
|
|
38e1db7a37 | ||
|
|
8309f7cdbe | ||
|
|
b8cc62ae95 | ||
|
|
c0eb433fa2 | ||
|
|
7f857d66f6 | ||
|
|
93b14d38f4 | ||
|
|
21825faab0 | ||
|
|
1fafd39298 | ||
|
|
23b750fc4f | ||
|
|
90581c840d | ||
|
|
cac7a6228a | ||
|
|
674fbc3f69 | ||
|
|
9577bf1cc7 | ||
|
|
654ebe93e7 | ||
|
|
ecb1b3c491 | ||
|
|
c3d1711edc | ||
|
|
c12c7f10f0 | ||
|
|
f71820bf4e | ||
|
|
748c53c774 | ||
|
|
b290a71bfb | ||
|
|
3204c51eca | ||
|
|
2c4b8a44dc | ||
|
|
943aa05eaa | ||
|
|
d0fd36e7e1 | ||
|
|
f45ff5fd0a | ||
|
|
c22c7102d5 | ||
|
|
11ecfd1b41 | ||
|
|
798e30e5ac | ||
|
|
15e0702329 | ||
|
|
a2bc22c37d | ||
|
|
8093fcc64c | ||
|
|
800419e7cc | ||
|
|
a241dc6785 | ||
|
|
805bea0d5f | ||
|
|
9d394adf24 | ||
|
|
2074f27aff | ||
|
|
283ad48b86 | ||
|
|
07e10a7943 | ||
|
|
2812a5026c | ||
|
|
3a20461abf | ||
|
|
64ae3d1e21 | ||
|
|
a25d7ea65b | ||
|
|
74ebbdd761 | ||
|
|
a0427b569e | ||
|
|
5346dfdd8b | ||
|
|
3ee4147285 | ||
|
|
c41e486bfc | ||
|
|
eda3ba92fd | ||
|
|
40255290b0 | ||
|
|
af5bc73dc0 | ||
|
|
0247cd4c45 | ||
|
|
916762cc8c | ||
|
|
d6fdf8ca2a | ||
|
|
95708489c9 | ||
|
|
ced0fa4608 | ||
|
|
7e0fbd600f | ||
|
|
f33e4e0323 | ||
|
|
d0fd78497d | ||
|
|
8045019603 | ||
|
|
7d92b9435e | ||
|
|
1e0822703a | ||
|
|
0403ff88ef | ||
|
|
78376d591b | ||
|
|
8e23d0df20 | ||
|
|
9e281d20ab | ||
|
|
644bd4a106 | ||
|
|
7729e66a96 | ||
|
|
d67d6b7948 | ||
|
|
4c4a46bfbe | ||
|
|
4536f9c177 | ||
|
|
977d3bc02e | ||
|
|
eae95dfef5 | ||
|
|
b67d4460ca | ||
|
|
3dea8311b1 | ||
|
|
11f6e98874 | ||
|
|
2609e595f4 | ||
|
|
ac6e41abc8 | ||
|
|
9c17e16d0a | ||
|
|
55e9064307 | ||
|
|
91cabd7d49 | ||
|
|
7456950530 | ||
|
|
8fcdda625d | ||
|
|
40a10ee926 | ||
|
|
c3f7e2645c | ||
|
|
b264af1892 | ||
|
|
43e93e8e22 | ||
|
|
d6c4789688 | ||
|
|
cb31ee6f01 | ||
|
|
f7b694ac56 | ||
|
|
eb809055d4 | ||
|
|
78d9be82b2 | ||
|
|
76a95c0226 | ||
|
|
d3ab8fb04a | ||
|
|
f7a0b63a00 | ||
|
|
a21dd97786 | ||
|
|
04943c0bfa | ||
|
|
203d4d8bfb | ||
|
|
c049a619dc | ||
|
|
cc1b14b607 | ||
|
|
e04a12a8f4 | ||
|
|
a2c82bc583 | ||
|
|
b4dc382f7c | ||
|
|
eca1892e2a | ||
|
|
23a237074e | ||
|
|
219e9eca4f | ||
|
|
413e09fb9e | ||
|
|
3514c37e4c | ||
|
|
95260e303c | ||
|
|
0cef34bdfa | ||
|
|
9838979bbd | ||
|
|
c8910b8e14 | ||
|
|
207fa1d019 | ||
|
|
be0bb591e7 | ||
|
|
bfacdb9c3b | ||
|
|
ae4077ed6c | ||
|
|
6eb3c90e18 | ||
|
|
8c2a53a504 | ||
|
|
74db1e0308 | ||
|
|
b9dfdcef3d | ||
|
|
9d4afeac31 | ||
|
|
14ae2f169a | ||
|
|
55df19142f | ||
|
|
40fd545b2c | ||
|
|
95fb07343e | ||
|
|
4d87906559 | ||
|
|
6b30dced43 | ||
|
|
293a03b7c8 | ||
|
|
c010549f17 | ||
|
|
cc0be22026 | ||
|
|
e5ba26febe | ||
|
|
36f9680eec | ||
|
|
f4f5be5b08 | ||
|
|
d89b056886 | ||
|
|
65424c7db9 | ||
|
|
32a8a847fc | ||
|
|
88fb3dbf60 | ||
|
|
f6bee3aa58 | ||
|
|
5f19f37dcb | ||
|
|
dd36d8ce9e | ||
|
|
865e4b5349 | ||
|
|
e70564752b | ||
|
|
6e0d2f9437 | ||
|
|
291f936097 | ||
|
|
0b2ce48586 | ||
|
|
da87fd9e20 | ||
|
|
d4da4d2575 | ||
|
|
bad20ff483 | ||
|
|
21ad51ffbf | ||
|
|
697c6d5fbe | ||
|
|
293c659053 | ||
|
|
a12507abbd | ||
|
|
4e675b84fb | ||
|
|
c1022feab8 | ||
|
|
ddcfcf21fe | ||
|
|
86a58c3d80 | ||
|
|
abf9a9048d | ||
|
|
b1030a527a | ||
|
|
8d07ba6332 | ||
|
|
4ce37f84e4 | ||
|
|
061d8a3a5f | ||
|
|
374cd5dbb8 | ||
|
|
5ad53c2b9c | ||
|
|
a2ec1a063d | ||
|
|
e431dbe2df | ||
|
|
7218463f9e | ||
|
|
aeb09a95b0 | ||
|
|
0c8f292e12 | ||
|
|
f001ac6903 | ||
|
|
db8e506de0 | ||
|
|
099f859dd4 | ||
|
|
b7684c1c2b | ||
|
|
058c167f79 | ||
|
|
49446d4872 | ||
|
|
ced560e1e1 | ||
|
|
339102c3cd | ||
|
|
6331350239 | ||
|
|
34e06fcbf8 | ||
|
|
70aac312ff | ||
|
|
5e00704152 | ||
|
|
1a9edb6907 | ||
|
|
0c18c3a6dd | ||
|
|
847bb51ce4 | ||
|
|
fa60a5dc63 | ||
|
|
aaed3f9839 | ||
|
|
21b956b983 | ||
|
|
792e940279 | ||
|
|
c2477b26c0 | ||
|
|
4b27de809b | ||
|
|
572932d8e8 | ||
|
|
270dd778d9 | ||
|
|
dd04287b0a | ||
|
|
36ac6d005a | ||
|
|
701daedf49 | ||
|
|
238f05f453 | ||
|
|
dd082bd212 | ||
|
|
cfd2f27b0b | ||
|
|
a2160d135e | ||
|
|
16d7836369 | ||
|
|
f3de4dcc5f | ||
|
|
e34523028f | ||
|
|
efe2fbacd6 | ||
|
|
2fa1df29be | ||
|
|
f72cd13fba | ||
|
|
5b552dffbf | ||
|
|
a0ae2d13dc | ||
|
|
f7262a0a3a | ||
|
|
9736f121eb | ||
|
|
7c8fb7eacc | ||
|
|
b45eea5908 | ||
|
|
6babf4ee6c | ||
|
|
576526d4ee | ||
|
|
c03e31b7be | ||
|
|
a1aa925019 | ||
|
|
a5a234ed97 | ||
|
|
5b5dbcd78b | ||
|
|
bd1c6361d3 | ||
|
|
1fc1febf03 | ||
|
|
55cc35efa9 | ||
|
|
5ba8fdc5e7 | ||
|
|
6ea295e227 | ||
|
|
5010c76ef7 | ||
|
|
79c7f0c29f | ||
|
|
2b3e643786 | ||
|
|
90cdff327c | ||
|
|
55c116e727 | ||
|
|
3dd83aa6b7 | ||
|
|
a74aa12641 | ||
|
|
151e8c69f9 | ||
|
|
d8bfa77705 | ||
|
|
6bd286e8d5 | ||
|
|
905532b681 | ||
|
|
04d5c1ab01 | ||
|
|
28be141dc7 | ||
|
|
652b786baf | ||
|
|
ba6c671051 | ||
|
|
ca25d0433f | ||
|
|
5338106dfa | ||
|
|
854d613a81 | ||
|
|
b6b76be4f6 | ||
|
|
03d94fcfa0 | ||
|
|
b2c5f0d455 | ||
|
|
54f60dd38c | ||
|
|
42f181aca2 | ||
|
|
9c3a27894f | ||
|
|
f7cd348912 | ||
|
|
aeaeb75d3b | ||
|
|
96542b532e | ||
|
|
139295fe0d | ||
|
|
13217b2ce2 | ||
|
|
5cc8b56a7c | ||
|
|
e23e01c95e | ||
|
|
bca8ba12c7 | ||
|
|
3c44bdbe1c | ||
|
|
db93ed025b | ||
|
|
4209e108d0 | ||
|
|
14cbf011af | ||
|
|
03a41ec199 | ||
|
|
125fe2a026 | ||
|
|
ac4adac29e | ||
|
|
ac449d078e | ||
|
|
79be4530d4 | ||
|
|
85ce52d70c | ||
|
|
7ab56b9076 | ||
|
|
dedf976375 | ||
|
|
89f438208a | ||
|
|
ffbc5080ae | ||
|
|
4167f13bac | ||
|
|
6ba0baabb0 | ||
|
|
081003df47 | ||
|
|
559194ffb2 | ||
|
|
97a26d4a46 | ||
|
|
503c6c9b7e | ||
|
|
9a1e10deff | ||
|
|
054f927c05 | ||
|
|
22210747d0 | ||
|
|
53b2deb72c | ||
|
|
6fc158e7d6 | ||
|
|
a23a65c731 | ||
|
|
7dc7105ee2 | ||
|
|
bac70108b2 | ||
|
|
297404b21e | ||
|
|
33a7f8b558 | ||
|
|
4a670b7df7 | ||
|
|
79e4af315e | ||
|
|
c6e31b2fdc | ||
|
|
91dc44df53 | ||
|
|
7e57f8f157 | ||
|
|
15f6b7c6d3 | ||
|
|
b213ba541d | ||
|
|
7c6ed9944e | ||
|
|
a5a825e439 | ||
|
|
a4ab547f77 | ||
|
|
76ed763abe | ||
|
|
b9e3125610 | ||
|
|
8d9d5b7b6f | ||
|
|
187601da1e | ||
|
|
cc3a0fc367 | ||
|
|
44cc4165d1 | ||
|
|
f98b43514e | ||
|
|
3c9b1a14e9 | ||
|
|
827e8eddf8 | ||
|
|
7bc27d6167 | ||
|
|
ba06edd63a | ||
|
|
cacf553a5b | ||
|
|
d89091a8ea | ||
|
|
01a56e1155 | ||
|
|
a64d7c42b1 | ||
|
|
36b6cc58bf | ||
|
|
5ac8a257e7 | ||
|
|
74119d0372 | ||
|
|
4e162c73e5 | ||
|
|
5ff753a492 | ||
|
|
89400630c0 | ||
|
|
3899c0cfe3 | ||
|
|
a086f1989f | ||
|
|
1171b04e93 | ||
|
|
c55d81825a | ||
|
|
2dcd026e9f | ||
|
|
cdf8609d24 | ||
|
|
36580c5f7f | ||
|
|
1cff2521f4 | ||
|
|
db4998a56b | ||
|
|
acbd506568 | ||
|
|
0cf8e3be73 | ||
|
|
2473334dfc | ||
|
|
1ff72d1d37 | ||
|
|
241fad5524 | ||
|
|
1b48cea50a | ||
|
|
88bf345b91 | ||
|
|
ab4ff3d1a3 | ||
|
|
3502e0d643 | ||
|
|
995894d3aa | ||
|
|
4da8714124 | ||
|
|
6b247ae880 | ||
|
|
176941ea3b | ||
|
|
5176b56d3b | ||
|
|
8abf18ab25 | ||
|
|
395edbd9f4 | ||
|
|
2386eb8fc2 | ||
|
|
68208f82a0 | ||
|
|
ca916b7ce5 | ||
|
|
01e02934da | ||
|
|
c81a79f7b9 | ||
|
|
1133648bf6 | ||
|
|
e05bc541d7 | ||
|
|
d689d20482 | ||
|
|
39dd99b272 | ||
|
|
cda21acb43 | ||
|
|
9bd7d09f20 | ||
|
|
b22994c2d2 | ||
|
|
e027286b6d | ||
|
|
d6e16995e0 | ||
|
|
782bff3a51 | ||
|
|
de26dc0597 | ||
|
|
233b24ab0f | ||
|
|
2f9e5b1219 | ||
|
|
dd36b8b150 | ||
|
|
f81ac31fe1 | ||
|
|
24b63bc5bd | ||
|
|
1817a972c6 | ||
|
|
74a253f521 | ||
|
|
41762a1c57 | ||
|
|
a786fa4b75 | ||
|
|
e4c7602c0c | ||
|
|
e0d2e34980 | ||
|
|
9ef8e1be3f | ||
|
|
aae9b64833 | ||
|
|
4bab4299f2 | ||
|
|
954e55f4b4 | ||
|
|
2361e3c28c | ||
|
|
8224c2fc16 | ||
|
|
8aac86f0a9 | ||
|
|
6384e9310b | ||
|
|
7a9205dfba | ||
|
|
94b47a56f4 | ||
|
|
709b5be634 | ||
|
|
f970b2c168 | ||
|
|
973acb37ed | ||
|
|
1c9020a565 | ||
|
|
c5f1d0042c | ||
|
|
fa706e8b1d | ||
|
|
12c170f227 | ||
|
|
db27dfe227 | ||
|
|
2db4673392 | ||
|
|
38619db629 | ||
|
|
930fd436ea | ||
|
|
98b8ff2fc8 | ||
|
|
d0662683f9 | ||
|
|
957f2574a9 | ||
|
|
109b362ebd | ||
|
|
ff3fdfa738 | ||
|
|
e2636ed54a | ||
|
|
dbe2f17e1a | ||
|
|
4dc535673f | ||
|
|
f414b6408e | ||
|
|
3aa2e6a04d | ||
|
|
1963ff273f | ||
|
|
bb737a71d5 | ||
|
|
a582a46ce9 | ||
|
|
abf80a3266 | ||
|
|
d768f5c66d | ||
|
|
b25e843351 | ||
|
|
419a3e518e | ||
|
|
d1b867a7c0 | ||
|
|
c34d70b3cb | ||
|
|
a33df9312f | ||
|
|
ebf8db0b37 | ||
|
|
e539ae3b69 | ||
|
|
4c5e8850aa | ||
|
|
94c0af3037 | ||
|
|
165182c68f | ||
|
|
65b9542599 | ||
|
|
d01d1f8830 | ||
|
|
ad3e9f3d42 | ||
|
|
4589974095 | ||
|
|
ed4553ddf8 | ||
|
|
ff97ae73f1 | ||
|
|
f96b4d2781 | ||
|
|
ce32cfffdb | ||
|
|
f66df8531e | ||
|
|
dfe1c23e76 | ||
|
|
07fd81919f | ||
|
|
210042bb81 | ||
|
|
12dc7427e9 | ||
|
|
b476085110 | ||
|
|
776cdaf63c | ||
|
|
69b6855745 | ||
|
|
3590babd8b | ||
|
|
c29d391c1d | ||
|
|
50e44dbb2a | ||
|
|
34277a3940 | ||
|
|
f1a00d58ca | ||
|
|
d1a5f17ae8 | ||
|
|
4dbc54fa15 | ||
|
|
1d4ff796d7 | ||
|
|
44cb54a9ea | ||
|
|
6409f49609 | ||
|
|
9ee0ea88b5 | ||
|
|
a3819d8673 | ||
|
|
2d7dd71a3d | ||
|
|
0e8195ae61 | ||
|
|
3e92d07618 | ||
|
|
e59597280d | ||
|
|
f2e3d69d8a | ||
|
|
9d2cb75c84 | ||
|
|
f971505c4a | ||
|
|
2133c1d6af | ||
|
|
0bf06ddfd3 | ||
|
|
024a50d642 | ||
|
|
e4eebd64d1 | ||
|
|
c9055989e9 | ||
|
|
4f1ed197ce | ||
|
|
3e710aa2a1 | ||
|
|
b6226a45bb | ||
|
|
3001ba9266 | ||
|
|
b0a401a1ed | ||
|
|
6b4dc37428 | ||
|
|
8528c9b262 | ||
|
|
7222a5c2f4 | ||
|
|
59050001ef | ||
|
|
2ba8f18724 | ||
|
|
fb22e01b89 | ||
|
|
76a81d5360 | ||
|
|
3314b05648 | ||
|
|
45b89218de | ||
|
|
beb7bda243 | ||
|
|
bef2896f50 | ||
|
|
9fea949b25 | ||
|
|
be258e5b05 | ||
|
|
008178d737 | ||
|
|
527d5e1dbc | ||
|
|
9b47e2d6f9 | ||
|
|
8781b1e976 | ||
|
|
38c653d8d8 | ||
|
|
74e48bb137 | ||
|
|
c3aaa1f735 | ||
|
|
bead2aa228 | ||
|
|
dc52ab8aa9 | ||
|
|
20b71f206b | ||
|
|
73c87d5959 | ||
|
|
c6601aaeed | ||
|
|
6e14fce1fe | ||
|
|
be5a62f1b8 | ||
|
|
1fa8cefaea | ||
|
|
d7c251ac83 | ||
|
|
d03229a183 | ||
|
|
243482e829 | ||
|
|
79d10be8a0 | ||
|
|
dca5c058e0 | ||
|
|
9163ce71fd | ||
|
|
2ec5374765 | ||
|
|
d6a4b35cd3 | ||
|
|
8205d2552c | ||
|
|
9a99caeb9d | ||
|
|
1e09bd0e76 | ||
|
|
cae12eb187 | ||
|
|
8bb36e0eb6 | ||
|
|
d183204caa | ||
|
|
4a22ae6b61 | ||
|
|
a52f54d988 | ||
|
|
618c94edb8 | ||
|
|
eaf4e9174f | ||
|
|
4af2c7f3d7 | ||
|
|
361f599df0 | ||
|
|
ffe4ea5e4c | ||
|
|
9461e3e01a | ||
|
|
7c85c6f742 | ||
|
|
b5df6faadf | ||
|
|
7cefe2d825 | ||
|
|
350633b69b | ||
|
|
1cd6a71ce0 | ||
|
|
3a08b002a0 | ||
|
|
665001732b | ||
|
|
cca49da730 | ||
|
|
f6d370ad29 | ||
|
|
c9131b333b | ||
|
|
e44161bf42 | ||
|
|
a26189fb25 | ||
|
|
89dd8a1db6 | ||
|
|
650e0b4ad4 | ||
|
|
c60f0517fb | ||
|
|
0f8dc91a8b | ||
|
|
b58feb5d8e | ||
|
|
71c8043699 | ||
|
|
40264bc9cb | ||
|
|
a7772316f9 | ||
|
|
34209021c8 | ||
|
|
3e9e8d442a | ||
|
|
d2bf90c6c7 | ||
|
|
1e58c1ad2b | ||
|
|
8cea022ec5 | ||
|
|
f32f8aa08e | ||
|
|
3ea8781381 | ||
|
|
ab83dacb76 | ||
|
|
4cbf46fd4d | ||
|
|
0a7d6e4577 | ||
|
|
df4c1f0401 | ||
|
|
9a86a67984 | ||
|
|
a0cbe9c3e2 | ||
|
|
a83e5a9b65 | ||
|
|
de33911460 | ||
|
|
0be56e5b25 | ||
|
|
abcbb34b1c | ||
|
|
6a13dd04a3 | ||
|
|
f2e29f3f2e | ||
|
|
68361cddd2 | ||
|
|
6404332adc | ||
|
|
e060b6fea2 | ||
|
|
e8aae27ee9 | ||
|
|
2f732e5493 | ||
|
|
65f20ff2c1 | ||
|
|
8f72e8c3e6 | ||
|
|
3b8972ce1f | ||
|
|
fc5d3e4e9c | ||
|
|
29fbf69945 | ||
|
|
583440b82b | ||
|
|
720de9d73f | ||
|
|
78332d882b | ||
|
|
2dfbc840b3 | ||
|
|
0b4bf15163 | ||
|
|
2989249e4b | ||
|
|
9cef559a05 | ||
|
|
47fe16c92a | ||
|
|
36b5c821ff | ||
|
|
82ec440b45 | ||
|
|
88f4a45cae | ||
|
|
7fb4f72b84 | ||
|
|
d4fc322101 | ||
|
|
8fa3da9ca5 | ||
|
|
68ef5aa3ae | ||
|
|
28bd917c9f | ||
|
|
0eb1b94300 | ||
|
|
15e6cf850b | ||
|
|
f687b2b6f4 | ||
|
|
8ee7a48151 |
31
.github/ISSUE_TEMPLATE.md
vendored
31
.github/ISSUE_TEMPLATE.md
vendored
@@ -1,31 +0,0 @@
|
||||
### 前置确认
|
||||
|
||||
1. 网络能够访问openai接口
|
||||
2. python 已安装:版本在 3.7 ~ 3.10 之间
|
||||
3. `git pull` 拉取最新代码
|
||||
4. 执行`pip3 install -r requirements.txt`,检查依赖是否满足
|
||||
5. 拓展功能请执行`pip3 install -r requirements-optional.txt`,检查依赖是否满足
|
||||
6. 在已有 issue 中未搜索到类似问题
|
||||
7. [FAQS](https://github.com/zhayujie/chatgpt-on-wechat/wiki/FAQs) 中无类似问题
|
||||
|
||||
|
||||
### 问题描述
|
||||
|
||||
> 简要说明、截图、复现步骤等,也可以是需求或想法
|
||||
|
||||
|
||||
|
||||
|
||||
### 终端日志 (如有报错)
|
||||
|
||||
```
|
||||
[在此处粘贴终端日志, 可在主目录下`run.log`文件中找到]
|
||||
```
|
||||
|
||||
|
||||
|
||||
### 环境
|
||||
|
||||
- 操作系统类型 (Mac/Windows/Linux):
|
||||
- Python版本 ( 执行 `python3 -V` ):
|
||||
- pip版本 ( 依赖问题此项必填,执行 `pip3 -V`):
|
||||
131
.github/ISSUE_TEMPLATE/1.bug.yml
vendored
Normal file
131
.github/ISSUE_TEMPLATE/1.bug.yml
vendored
Normal file
@@ -0,0 +1,131 @@
|
||||
name: Bug report 🐛
|
||||
description: 项目运行中遇到的Bug或问题。
|
||||
labels: ['status: needs check']
|
||||
body:
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
### ⚠️ 前置确认
|
||||
1. 网络能够访问openai接口
|
||||
2. python 已安装:版本在 3.7 ~ 3.10 之间
|
||||
3. `git pull` 拉取最新代码
|
||||
4. 执行`pip3 install -r requirements.txt`,检查依赖是否满足
|
||||
5. 拓展功能请执行`pip3 install -r requirements-optional.txt`,检查依赖是否满足
|
||||
6. [FAQS](https://github.com/zhayujie/chatgpt-on-wechat/wiki/FAQs) 中无类似问题
|
||||
- type: checkboxes
|
||||
attributes:
|
||||
label: 前置确认
|
||||
options:
|
||||
- label: 我确认我运行的是最新版本的代码,并且安装了所需的依赖,在[FAQS](https://github.com/zhayujie/chatgpt-on-wechat/wiki/FAQs)中也未找到类似问题。
|
||||
required: true
|
||||
- type: checkboxes
|
||||
attributes:
|
||||
label: ⚠️ 搜索issues中是否已存在类似问题
|
||||
description: >
|
||||
请在 [历史issue](https://github.com/zhayujie/chatgpt-on-wechat/issues) 中清空输入框,搜索你的问题
|
||||
或相关日志的关键词来查找是否存在类似问题。
|
||||
options:
|
||||
- label: 我已经搜索过issues和disscussions,没有跟我遇到的问题相关的issue
|
||||
required: true
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
请在上方的`title`中填写你对你所遇到问题的简略总结,这将帮助其他人更好的找到相似问题,谢谢❤️。
|
||||
- type: dropdown
|
||||
attributes:
|
||||
label: 操作系统类型?
|
||||
description: >
|
||||
请选择你运行程序的操作系统类型。
|
||||
options:
|
||||
- Windows
|
||||
- Linux
|
||||
- MacOS
|
||||
- Docker
|
||||
- Railway
|
||||
- Windows Subsystem for Linux (WSL)
|
||||
- Other (请在问题中说明)
|
||||
validations:
|
||||
required: true
|
||||
- type: dropdown
|
||||
attributes:
|
||||
label: 运行的python版本是?
|
||||
description: |
|
||||
请选择你运行程序的`python`版本。
|
||||
注意:在`python 3.7`中,有部分可选依赖无法安装。
|
||||
经过长时间的观察,我们认为`python 3.8`是兼容性最好的版本。
|
||||
`python 3.7`~`python 3.10`以外版本的issue,将视情况直接关闭。
|
||||
options:
|
||||
- python 3.7
|
||||
- python 3.8
|
||||
- python 3.9
|
||||
- python 3.10
|
||||
- other
|
||||
validations:
|
||||
required: true
|
||||
- type: dropdown
|
||||
attributes:
|
||||
label: 使用的chatgpt-on-wechat版本是?
|
||||
description: |
|
||||
请确保你使用的是 [releases](https://github.com/zhayujie/chatgpt-on-wechat/releases) 中的最新版本。
|
||||
如果你使用git, 请使用`git branch`命令来查看分支。
|
||||
options:
|
||||
- Latest Release
|
||||
- Master (branch)
|
||||
validations:
|
||||
required: true
|
||||
- type: dropdown
|
||||
attributes:
|
||||
label: 运行的`channel`类型是?
|
||||
description: |
|
||||
请确保你正确配置了该`channel`所需的配置项,所有可选的配置项都写在了[该文件中](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/config.py),请将所需配置项填写在根目录下的`config.json`文件中。
|
||||
options:
|
||||
- wechatmp(公众号, 订阅号)
|
||||
- wechatmp_service(公众号, 服务号)
|
||||
- terminal
|
||||
- other
|
||||
validations:
|
||||
required: true
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: 复现步骤 🕹
|
||||
description: |
|
||||
**⚠️ 不能复现将会关闭issue.**
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: 问题描述 😯
|
||||
description: 详细描述出现的问题,或提供有关截图。
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: 终端日志 📒
|
||||
description: |
|
||||
在此处粘贴终端日志,可在主目录下`run.log`文件中找到,这会帮助我们更好的分析问题,注意隐去你的API key。
|
||||
如果在配置文件中加入`"debug": true`,打印出的日志会更有帮助。
|
||||
|
||||
<details>
|
||||
<summary><i>示例</i></summary>
|
||||
```log
|
||||
[DEBUG][2023-04-16 00:23:22][plugin_manager.py:157] - Plugin SUMMARY triggered by event Event.ON_HANDLE_CONTEXT
|
||||
[DEBUG][2023-04-16 00:23:22][main.py:221] - [Summary] on_handle_context. content: $总结前100条消息
|
||||
[DEBUG][2023-04-16 00:23:24][main.py:240] - [Summary] limit: 100, duration: -1 seconds
|
||||
[ERROR][2023-04-16 00:23:24][chat_channel.py:244] - Worker return exception: name 'start_date' is not defined
|
||||
Traceback (most recent call last):
|
||||
File "C:\ProgramData\Anaconda3\lib\concurrent\futures\thread.py", line 57, in run
|
||||
result = self.fn(*self.args, **self.kwargs)
|
||||
File "D:\project\chatgpt-on-wechat\channel\chat_channel.py", line 132, in _handle
|
||||
reply = self._generate_reply(context)
|
||||
File "D:\project\chatgpt-on-wechat\channel\chat_channel.py", line 142, in _generate_reply
|
||||
e_context = PluginManager().emit_event(EventContext(Event.ON_HANDLE_CONTEXT, {
|
||||
File "D:\project\chatgpt-on-wechat\plugins\plugin_manager.py", line 159, in emit_event
|
||||
instance.handlers[e_context.event](e_context, *args, **kwargs)
|
||||
File "D:\project\chatgpt-on-wechat\plugins\summary\main.py", line 255, in on_handle_context
|
||||
records = self._get_records(session_id, start_time, limit)
|
||||
File "D:\project\chatgpt-on-wechat\plugins\summary\main.py", line 96, in _get_records
|
||||
c.execute("SELECT * FROM chat_records WHERE sessionid=? and timestamp>? ORDER BY timestamp DESC LIMIT ?", (session_id, start_date, limit))
|
||||
NameError: name 'start_date' is not defined
|
||||
[INFO][2023-04-16 00:23:36][app.py:14] - signal 2 received, exiting...
|
||||
```
|
||||
</details>
|
||||
value: |
|
||||
```log
|
||||
<此处粘贴终端日志>
|
||||
```
|
||||
28
.github/ISSUE_TEMPLATE/2.feature.yml
vendored
Normal file
28
.github/ISSUE_TEMPLATE/2.feature.yml
vendored
Normal file
@@ -0,0 +1,28 @@
|
||||
name: Feature request 🚀
|
||||
description: 提出你对项目的新想法或建议。
|
||||
labels: ['status: needs check']
|
||||
body:
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
请在上方的`title`中填写简略总结,谢谢❤️。
|
||||
- type: checkboxes
|
||||
attributes:
|
||||
label: ⚠️ 搜索是否存在类似issue
|
||||
description: >
|
||||
请在 [历史issue](https://github.com/zhayujie/chatgpt-on-wechat/issues) 中清空输入框,搜索关键词查找是否存在相似issue。
|
||||
options:
|
||||
- label: 我已经搜索过issues和disscussions,没有发现相似issue
|
||||
required: true
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: 总结
|
||||
description: 描述feature的功能。
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: 举例
|
||||
description: 提供聊天示例,草图或相关网址。
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: 动机
|
||||
description: 描述你提出该feature的动机,比如没有这项feature对你的使用造成了怎样的影响。 请提供更详细的场景描述,这可能会帮助我们发现并提出更好的解决方案。
|
||||
72
.github/workflows/deploy-image-arm.yml
vendored
Normal file
72
.github/workflows/deploy-image-arm.yml
vendored
Normal file
@@ -0,0 +1,72 @@
|
||||
# This workflow uses actions that are not certified by GitHub.
|
||||
# They are provided by a third-party and are governed by
|
||||
# separate terms of service, privacy policy, and support
|
||||
# documentation.
|
||||
|
||||
# GitHub recommends pinning actions to a commit SHA.
|
||||
# To get a newer version, you will need to update the SHA.
|
||||
# You can also reference a tag or branch, but the action may change without warning.
|
||||
|
||||
name: Create and publish a Docker image
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: ['master']
|
||||
create:
|
||||
env:
|
||||
REGISTRY: ghcr.io
|
||||
IMAGE_NAME: ${{ github.repository }}
|
||||
|
||||
jobs:
|
||||
build-and-push-image:
|
||||
if: github.repository == 'zhayujie/chatgpt-on-wechat'
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: read
|
||||
packages: write
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@v1
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
id: buildx
|
||||
uses: docker/setup-buildx-action@v1
|
||||
|
||||
- name: Available platforms
|
||||
run: echo ${{ steps.buildx.outputs.platforms }}
|
||||
|
||||
- name: Log in to the Container registry
|
||||
uses: docker/login-action@v2
|
||||
with:
|
||||
registry: ${{ env.REGISTRY }}
|
||||
username: ${{ github.actor }}
|
||||
password: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Extract metadata (tags, labels) for Docker
|
||||
id: meta
|
||||
uses: docker/metadata-action@v4
|
||||
with:
|
||||
images: |
|
||||
${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
|
||||
|
||||
- name: Build and push Docker image
|
||||
uses: docker/build-push-action@v3
|
||||
with:
|
||||
context: .
|
||||
push: true
|
||||
file: ./docker/Dockerfile.latest
|
||||
platforms: linux/arm64
|
||||
tags: ${{ steps.meta.outputs.tags }}-arm64
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
|
||||
- uses: actions/delete-package-versions@v4
|
||||
with:
|
||||
package-name: 'chatgpt-on-wechat'
|
||||
package-type: 'container'
|
||||
min-versions-to-keep: 10
|
||||
delete-only-untagged-versions: 'true'
|
||||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
15
.github/workflows/deploy-image.yml
vendored
15
.github/workflows/deploy-image.yml
vendored
@@ -19,6 +19,7 @@ env:
|
||||
|
||||
jobs:
|
||||
build-and-push-image:
|
||||
if: github.repository == 'zhayujie/chatgpt-on-wechat'
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: read
|
||||
@@ -28,6 +29,12 @@ jobs:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v2
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
|
||||
- name: Log in to the Container registry
|
||||
uses: docker/login-action@v2
|
||||
with:
|
||||
@@ -39,7 +46,9 @@ jobs:
|
||||
id: meta
|
||||
uses: docker/metadata-action@v4
|
||||
with:
|
||||
images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
|
||||
images: |
|
||||
${{ env.IMAGE_NAME }}
|
||||
${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
|
||||
|
||||
- name: Build and push Docker image
|
||||
uses: docker/build-push-action@v3
|
||||
@@ -49,9 +58,9 @@ jobs:
|
||||
file: ./docker/Dockerfile.latest
|
||||
tags: ${{ steps.meta.outputs.tags }}
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
|
||||
|
||||
- uses: actions/delete-package-versions@v4
|
||||
with:
|
||||
with:
|
||||
package-name: 'chatgpt-on-wechat'
|
||||
package-type: 'container'
|
||||
min-versions-to-keep: 10
|
||||
|
||||
21
.gitignore
vendored
21
.gitignore
vendored
@@ -1,18 +1,23 @@
|
||||
.DS_Store
|
||||
.idea
|
||||
.vscode
|
||||
.wechaty/
|
||||
.venv
|
||||
.vs
|
||||
__pycache__/
|
||||
venv*
|
||||
*.pyc
|
||||
python
|
||||
config.json
|
||||
QR.png
|
||||
nohup.out
|
||||
tmp
|
||||
plugins.json
|
||||
itchat.pkl
|
||||
*.log
|
||||
logs/
|
||||
workspace
|
||||
config.yaml
|
||||
user_datas.pkl
|
||||
chatgpt_tool_hub/
|
||||
plugins/**/
|
||||
!plugins/bdunit
|
||||
!plugins/dungeon
|
||||
@@ -20,5 +25,15 @@ plugins/**/
|
||||
!plugins/godcmd
|
||||
!plugins/tool
|
||||
!plugins/banwords
|
||||
!plugins/banwords/**/
|
||||
plugins/banwords/__pycache__
|
||||
plugins/banwords/lib/__pycache__
|
||||
!plugins/hello
|
||||
!plugins/role
|
||||
!plugins/role
|
||||
!plugins/keyword
|
||||
!plugins/linkai
|
||||
!plugins/agent
|
||||
client_config.json
|
||||
ref/
|
||||
.cursor/
|
||||
local/
|
||||
|
||||
865
README.md
865
README.md
@@ -1,75 +1,112 @@
|
||||
<p align="center"><img src= "https://github.com/user-attachments/assets/eca9a9ec-8534-4615-9e0f-96c5ac1d10a3" alt="Chatgpt-on-Wechat" width="550" /></p>
|
||||
|
||||
<p align="center">
|
||||
<a href="https://github.com/zhayujie/chatgpt-on-wechat/releases/latest"><img src="https://img.shields.io/github/v/release/zhayujie/chatgpt-on-wechat" alt="Latest release"></a>
|
||||
<a href="https://github.com/zhayujie/chatgpt-on-wechat/blob/master/LICENSE"><img src="https://img.shields.io/github/license/zhayujie/chatgpt-on-wechat" alt="License: MIT"></a>
|
||||
<a href="https://github.com/zhayujie/chatgpt-on-wechat"><img src="https://img.shields.io/github/stars/zhayujie/chatgpt-on-wechat?style=flat-square" alt="Stars"></a> <br/>
|
||||
[中文] | [<a href="docs/en/README.md">English</a>]
|
||||
</p>
|
||||
|
||||
**CowAgent** 是基于大模型的超级AI助理,能够主动思考和任务规划、操作计算机和外部资源、创造和执行Skills、拥有长期记忆并不断成长。CowAgent 支持灵活切换多种模型,能处理文本、语音、图片、文件等多模态消息,可接入网页、飞书、钉钉、企微智能机器人、QQ、企微自建应用、微信公众号中使用,7*24小时运行于你的个人电脑或服务器中。
|
||||
|
||||
<p align="center">
|
||||
<a href="https://cowagent.ai/">🌐 官网</a> ·
|
||||
<a href="https://docs.cowagent.ai/">📖 文档中心</a> ·
|
||||
<a href="https://docs.cowagent.ai/guide/quick-start">🚀 快速开始</a>
|
||||
</p>
|
||||
|
||||
|
||||
|
||||
# 简介
|
||||
|
||||
> ChatGPT近期以强大的对话和信息整合能力风靡全网,可以写代码、改论文、讲故事,几乎无所不能,这让人不禁有个大胆的想法,能否用他的对话模型把我们的微信打造成一个智能机器人,可以在与好友对话中给出意想不到的回应,而且再也不用担心女朋友影响我们 ~~打游戏~~ 工作了。
|
||||
> 该项目既是一个可以开箱即用的超级AI助理,也是一个支持高扩展的Agent框架,可以通过为项目扩展大模型接口、接入渠道、内置工具、Skills系统来灵活实现各种定制需求。核心能力如下:
|
||||
|
||||
- ✅ **复杂任务规划**:能够理解复杂任务并自主规划执行,持续思考和调用工具直到完成目标,支持通过工具操作访问文件、终端、浏览器、定时任务等系统资源
|
||||
- ✅ **长期记忆:** 自动将对话记忆持久化至本地文件和数据库中,包括全局记忆和天级记忆,支持关键词及向量检索
|
||||
- ✅ **技能系统:** 实现了Skills创建和运行的引擎,内置多种技能,并支持通过自然语言对话完成自定义Skills开发
|
||||
- ✅ **多模态消息:** 支持对文本、图片、语音、文件等多类型消息进行解析、处理、生成、发送等操作
|
||||
- ✅ **多模型接入:** 支持OpenAI, Claude, Gemini, DeepSeek, MiniMax、GLM、Qwen、Kimi、Doubao等国内外主流模型厂商
|
||||
- ✅ **多端部署:** 支持运行在本地计算机或服务器,可集成到网页、飞书、钉钉、微信公众号、企业微信应用中使用
|
||||
- ✅ **知识库:** 集成企业知识库能力,让Agent成为专属数字员工,基于[LinkAI](https://link-ai.tech)平台实现
|
||||
|
||||
基于ChatGPT的微信聊天机器人,通过 [ChatGPT](https://github.com/openai/openai-python) 接口生成对话内容,使用 [itchat](https://github.com/littlecodersh/ItChat) 实现微信消息的接收和自动回复。已实现的特性如下:
|
||||
## 声明
|
||||
|
||||
- [x] **文本对话:** 接收私聊及群组中的微信消息,使用ChatGPT生成回复内容,完成自动回复
|
||||
- [x] **规则定制化:** 支持私聊中按指定规则触发自动回复,支持对群组设置自动回复白名单
|
||||
- [x] **多账号:** 支持多微信账号同时运行
|
||||
- [x] **图片生成:** 支持根据描述生成图片,并自动发送至个人聊天或群聊
|
||||
- [x] **上下文记忆**:支持多轮对话记忆,且为每个好友维护独立的上下会话
|
||||
- [x] **语音识别:** 支持接收和处理语音消息,通过文字或语音回复
|
||||
- [x] **插件化:** 支持个性化功能插件,提供角色扮演、文字冒险游戏等预设插件
|
||||
1. 本项目遵循 [MIT开源协议](/LICENSE),主要用于技术研究和学习,使用本项目时需遵守所在地法律法规、相关政策以及企业章程,禁止用于任何违法或侵犯他人权益的行为。任何个人、团队和企业,无论以何种方式使用该项目、对何对象提供服务,所产生的一切后果,本项目均不承担任何责任。
|
||||
2. 成本与安全:Agent模式下Token使用量高于普通对话模式,请根据效果及成本综合选择模型。Agent具有访问所在操作系统的能力,请谨慎选择项目部署环境。同时项目也会持续升级安全机制、并降低模型消耗成本。
|
||||
3. CowAgent项目专注于开源技术开发,不会参与、授权或发行任何加密货币。
|
||||
|
||||
> 目前支持微信和微信个人号部署,欢迎接入更多应用,参考[`Terminal`代码](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/channel/terminal/terminal_channel.py)实现接收和发送消息逻辑即可接入。
|
||||
## 演示
|
||||
|
||||
使用说明(Agent模式):[CowAgent介绍](https://docs.cowagent.ai/intro/features)
|
||||
|
||||
快速部署:
|
||||
DEMO视频(对话模式):https://cdn.link-ai.tech/doc/cow_demo.mp4
|
||||
|
||||
## 社区
|
||||
|
||||
添加小助手微信加入开源项目交流群:
|
||||
|
||||
<img width="140" src="https://img-1317903499.cos.ap-guangzhou.myqcloud.com/docs/open-community.png">
|
||||
|
||||
<br/>
|
||||
|
||||
# 企业服务
|
||||
|
||||
<a href="https://link-ai.tech" target="_blank"><img width="720" src="https://cdn.link-ai.tech/image/link-ai-intro.jpg"></a>
|
||||
|
||||
> [LinkAI](https://link-ai.tech/) 是面向企业和开发者的一站式AI智能体平台,聚合多模态大模型、知识库、Agent 插件、工作流等能力,支持一键接入主流平台并进行管理,支持SaaS、私有化部署等多种模式。
|
||||
>
|
||||
>[](https://railway.app/template/qApznZ?referralCode=RC3znh)
|
||||
> LinkAI 目前已在智能客服、私域运营、企业效率助手等场景积累了丰富的AI解决方案,在消费、健康、文教、科技制造等各行业沉淀了大模型落地应用的最佳实践,致力于帮助更多企业和开发者拥抱 AI 生产力。
|
||||
|
||||
**产品咨询和企业服务** 可联系产品客服:
|
||||
|
||||
<img width="150" src="https://cdn.link-ai.tech/portal/linkai-customer-service.png">
|
||||
|
||||
<br/>
|
||||
|
||||
# 🏷 更新日志
|
||||
|
||||
>**2026.02.27:** [2.0.2版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/2.0.2),Web 控制台全面升级(流式对话、模型/技能/记忆/通道/定时任务/日志管理)、支持多通道同时运行、会话持久化存储、新增多个模型。
|
||||
|
||||
>**2026.02.13:** [2.0.1版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/2.0.1),内置 Web Search 工具、智能上下文裁剪策略、运行时信息动态更新、Windows 兼容性适配,修复定时任务记忆丢失、飞书连接等多项问题。
|
||||
|
||||
>**2026.02.03:** [2.0.0版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/2.0.0),正式升级为超级Agent助理,支持多轮任务决策、具备长期记忆、实现多种系统工具、支持Skills框架,新增多种模型并优化了接入渠道。
|
||||
|
||||
>**2025.05.23:** [1.7.6版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.7.6) 优化web网页channel、新增 [AgentMesh](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/plugins/agent/README.md)多智能体插件、百度语音合成优化、企微应用`access_token`获取优化、支持`claude-4-sonnet`和`claude-4-opus`模型
|
||||
|
||||
>**2025.04.11:** [1.7.5版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.7.5) 新增支持 [wechatferry](https://github.com/zhayujie/chatgpt-on-wechat/pull/2562) 协议、新增 deepseek 模型、新增支持腾讯云语音能力、新增支持 ModelScope 和 Gitee-AI API接口
|
||||
|
||||
更多更新历史请查看: [更新日志](https://docs.cowagent.ai/releases)
|
||||
|
||||
<br/>
|
||||
|
||||
# 🚀 快速开始
|
||||
|
||||
项目提供了一键安装、配置、启动、管理程序的脚本,推荐使用脚本快速运行,也可以根据下文中的详细指引一步步安装运行。
|
||||
|
||||
在终端执行以下命令:
|
||||
|
||||
```bash
|
||||
bash <(curl -sS https://cdn.link-ai.tech/code/cow/run.sh)
|
||||
```
|
||||
|
||||
脚本使用说明:[一键运行脚本](https://docs.cowagent.ai/guide/quick-start)
|
||||
|
||||
|
||||
# 更新日志
|
||||
## 一、准备
|
||||
|
||||
>**2023.04.05:** 支持微信个人号部署,兼容角色扮演等预设插件,[使用文档](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/channel/wechatmp/README.md)。(contributed by [@JS00000](https://github.com/JS00000) in [#686](https://github.com/zhayujie/chatgpt-on-wechat/pull/686))
|
||||
### 1. 模型API
|
||||
|
||||
>**2023.04.05:** 增加能让ChatGPT使用工具的`tool`插件,[使用文档](https://github.com/goldfishh/chatgpt-on-wechat/blob/master/plugins/tool/README.md)。工具相关issue可反馈至[chatgpt-tool-hub](https://github.com/goldfishh/chatgpt-tool-hub)。(contributed by [@goldfishh](https://github.com/goldfishh) in [#663](https://github.com/zhayujie/chatgpt-on-wechat/pull/663))
|
||||
项目支持国内外主流厂商的模型接口,可选模型及配置说明参考:[模型说明](#模型说明)。
|
||||
|
||||
>**2023.03.25:** 支持插件化开发,目前已实现 多角色切换、文字冒险游戏、管理员指令、Stable Diffusion等插件,使用参考 [#578](https://github.com/zhayujie/chatgpt-on-wechat/issues/578)。(contributed by [@lanvent](https://github.com/lanvent) in [#565](https://github.com/zhayujie/chatgpt-on-wechat/pull/565))
|
||||
> 注:Agent模式下推荐使用以下模型,可根据效果及成本综合选择:MiniMax-M2.5、glm-5、kimi-k2.5、qwen3.5-plus、claude-sonnet-4-6、gemini-3.1-pro-preview、gpt-5.4
|
||||
|
||||
>**2023.03.09:** 基于 `whisper API`(后续已接入更多的语音`API`服务) 实现对微信语音消息的解析和回复,添加配置项 `"speech_recognition":true` 即可启用,使用参考 [#415](https://github.com/zhayujie/chatgpt-on-wechat/issues/415)。(contributed by [wanggang1987](https://github.com/wanggang1987) in [#385](https://github.com/zhayujie/chatgpt-on-wechat/pull/385))
|
||||
同时支持使用 **LinkAI平台** 接口,可灵活切换 OpenAI、Claude、Gemini、DeepSeek、Qwen、Kimi 等多种常用模型,并支持知识库、工作流、插件等Agent能力,参考 [接口文档](https://docs.link-ai.tech/platform/api)。
|
||||
|
||||
>**2023.03.02:** 接入[ChatGPT API](https://platform.openai.com/docs/guides/chat) (gpt-3.5-turbo),默认使用该模型进行对话,需升级openai依赖 (`pip3 install --upgrade openai`)。网络问题参考 [#351](https://github.com/zhayujie/chatgpt-on-wechat/issues/351)
|
||||
### 2.环境安装
|
||||
|
||||
>**2023.02.09:** 扫码登录存在封号风险,请谨慎使用,参考[#58](https://github.com/AutumnWhj/ChatGPT-wechat-bot/issues/158)
|
||||
支持 Linux、MacOS、Windows 操作系统,可在个人计算机及服务器上运行,需安装 `Python`,Python版本需在3.7 ~ 3.12 之间,推荐使用3.9版本。
|
||||
|
||||
>**2023.02.05:** 在openai官方接口方案中 (GPT-3模型) 实现上下文对话
|
||||
|
||||
>**2022.12.18:** 支持根据描述生成图片并发送,openai版本需大于0.25.0
|
||||
|
||||
>**2022.12.17:** 原来的方案是从 [ChatGPT页面](https://chat.openai.com/chat) 获取session_token,使用 [revChatGPT](https://github.com/acheong08/ChatGPT) 直接访问web接口,但随着ChatGPT接入Cloudflare人机验证,这一方案难以在服务器顺利运行。 所以目前使用的方案是调用 OpenAI 官方提供的 [API](https://beta.openai.com/docs/api-reference/introduction),回复质量上基本接近于ChatGPT的内容,劣势是暂不支持有上下文记忆的对话,优势是稳定性和响应速度较好。
|
||||
|
||||
# 使用效果
|
||||
|
||||
### 个人聊天
|
||||
|
||||

|
||||
|
||||
### 群组聊天
|
||||
|
||||

|
||||
|
||||
### 图片生成
|
||||
|
||||

|
||||
|
||||
|
||||
# 快速开始
|
||||
|
||||
## 准备
|
||||
|
||||
### 1. OpenAI账号注册
|
||||
|
||||
前往 [OpenAI注册页面](https://beta.openai.com/signup) 创建账号,参考这篇 [教程](https://www.pythonthree.com/register-openai-chatgpt/) 可以通过虚拟手机号来接收验证码。创建完账号则前往 [API管理页面](https://beta.openai.com/account/api-keys) 创建一个 API Key 并保存下来,后面需要在项目中配置这个key。
|
||||
|
||||
> 项目中使用的对话模型是 davinci,计费方式是约每 750 字 (包含请求和回复) 消耗 $0.02,图片生成是每张消耗 $0.016,账号创建有免费的 $18 额度 (更新3.25: 最新注册的已经无免费额度了),使用完可以更换邮箱重新注册。
|
||||
|
||||
### 2.运行环境
|
||||
|
||||
支持 Linux、MacOS、Windows 系统(可在Linux服务器上长期运行),同时需安装 `Python`。
|
||||
> 建议Python版本在 3.7.1~3.9.X 之间,推荐3.8版本,3.10及以上版本在 MacOS 可用,其他系统上不确定能否正常运行。
|
||||
> 注意:Agent模式推荐使用源码运行,若选择Docker部署则无需安装python环境和下载源码,可直接快进到下一节。
|
||||
|
||||
**(1) 克隆项目代码:**
|
||||
|
||||
@@ -78,8 +115,10 @@ git clone https://github.com/zhayujie/chatgpt-on-wechat
|
||||
cd chatgpt-on-wechat/
|
||||
```
|
||||
|
||||
若遇到网络问题可使用国内仓库地址:https://gitee.com/zhayujie/chatgpt-on-wechat
|
||||
|
||||
**(2) 安装核心依赖 (必选):**
|
||||
> 能够使用`itchat`创建机器人,并具有文字交流功能所需的最小依赖集合。
|
||||
|
||||
```bash
|
||||
pip3 install -r requirements.txt
|
||||
```
|
||||
@@ -89,27 +128,9 @@ pip3 install -r requirements.txt
|
||||
```bash
|
||||
pip3 install -r requirements-optional.txt
|
||||
```
|
||||
> 如果某项依赖安装失败请注释掉对应的行再继续。
|
||||
如果某项依赖安装失败可注释掉对应的行后重试。
|
||||
|
||||
其中`tiktoken`要求`python`版本在3.8以上,它用于精确计算会话使用的tokens数量,强烈建议安装。
|
||||
|
||||
|
||||
使用`google`或`baidu`语音识别需安装`ffmpeg`,
|
||||
|
||||
默认的`openai`语音识别不需要安装`ffmpeg`。
|
||||
|
||||
参考[#415](https://github.com/zhayujie/chatgpt-on-wechat/issues/415)
|
||||
|
||||
使用`azure`语音功能需安装依赖(列在`requirements-optional.txt`内,但为便于`railway`部署已注释):
|
||||
|
||||
```bash
|
||||
pip3 install azure-cognitiveservices-speech
|
||||
```
|
||||
|
||||
> 目前默认发布的镜像和`railway`部署,都基于`apline`,无法安装`azure`的依赖。若有需求请自行基于[`debian`](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/docker/Dockerfile.debian.latest)打包。
|
||||
参考[文档](https://learn.microsoft.com/en-us/azure/cognitive-services/speech-service/quickstarts/setup-platform?pivots=programming-language-python&tabs=linux%2Cubuntu%2Cdotnet%2Cjre%2Cmaven%2Cnodejs%2Cmac%2Cpypi)
|
||||
|
||||
## 配置
|
||||
## 二、配置
|
||||
|
||||
配置文件的模板在根目录的`config-template.json`中,需复制该模板创建最终生效的 `config.json` 文件:
|
||||
|
||||
@@ -117,106 +138,678 @@ pip3 install azure-cognitiveservices-speech
|
||||
cp config-template.json config.json
|
||||
```
|
||||
|
||||
然后在`config.json`中填入配置,以下是对默认配置的说明,可根据需要进行自定义修改:
|
||||
然后在`config.json`中填入配置,以下是对默认配置的说明,可根据需要进行自定义修改(注意实际使用时请去掉注释,保证JSON格式的规范):
|
||||
|
||||
```bash
|
||||
# config.json文件内容示例
|
||||
{
|
||||
"open_ai_api_key": "YOUR API KEY", # 填入上面创建的 OpenAI API KEY
|
||||
"model": "gpt-3.5-turbo", # 模型名称。当use_azure_chatgpt为true时,其名称为Azure上model deployment名称
|
||||
"proxy": "127.0.0.1:7890", # 代理客户端的ip和端口
|
||||
"single_chat_prefix": ["bot", "@bot"], # 私聊时文本需要包含该前缀才能触发机器人回复
|
||||
"single_chat_reply_prefix": "[bot] ", # 私聊时自动回复的前缀,用于区分真人
|
||||
"group_chat_prefix": ["@bot"], # 群聊时包含该前缀则会触发机器人回复
|
||||
"group_name_white_list": ["ChatGPT测试群", "ChatGPT测试群2"], # 开启自动回复的群名称列表
|
||||
"group_chat_in_one_session": ["ChatGPT测试群"], # 支持会话上下文共享的群名称
|
||||
"image_create_prefix": ["画", "看", "找"], # 开启图片回复的前缀
|
||||
"conversation_max_tokens": 1000, # 支持上下文记忆的最多字符数
|
||||
# config.json 文件内容示例
|
||||
{
|
||||
"channel_type": "web", # 接入渠道类型,默认为web,支持修改为:feishu,dingtalk,wecom_bot,qq,wechatcom_app,wechatmp_service,wechatmp,terminal
|
||||
"model": "MiniMax-M2.5", # 模型名称
|
||||
"minimax_api_key": "", # MiniMax API Key
|
||||
"zhipu_ai_api_key": "", # 智谱GLM API Key
|
||||
"moonshot_api_key": "", # Kimi/Moonshot API Key
|
||||
"ark_api_key": "", # 豆包(火山方舟) API Key
|
||||
"dashscope_api_key": "", # 百炼(通义千问)API Key
|
||||
"claude_api_key": "", # Claude API Key
|
||||
"claude_api_base": "https://api.anthropic.com/v1", # Claude API 地址,修改可接入三方代理平台
|
||||
"gemini_api_key": "", # Gemini API Key
|
||||
"gemini_api_base": "https://generativelanguage.googleapis.com", # Gemini API地址
|
||||
"open_ai_api_key": "", # OpenAI API Key
|
||||
"open_ai_api_base": "https://api.openai.com/v1", # OpenAI API 地址
|
||||
"linkai_api_key": "", # LinkAI API Key
|
||||
"proxy": "", # 代理客户端的ip和端口,国内环境需要开启代理的可填写该项,如 "127.0.0.1:7890"
|
||||
"speech_recognition": false, # 是否开启语音识别
|
||||
"group_speech_recognition": false, # 是否开启群组语音识别
|
||||
"use_azure_chatgpt": false, # 是否使用Azure ChatGPT service代替openai ChatGPT service. 当设置为true时需要设置 open_ai_api_base,如 https://xxx.openai.azure.com/
|
||||
"character_desc": "你是ChatGPT, 一个由OpenAI训练的大型语言模型, 你旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。", # 人格描述,
|
||||
"voice_reply_voice": false, # 是否使用语音回复语音
|
||||
"use_linkai": false, # 是否使用LinkAI接口,默认关闭,设置为true后可对接LinkAI平台接口
|
||||
"agent": true, # 是否启用Agent模式,启用后拥有多轮工具决策、长期记忆、Skills能力等
|
||||
"agent_workspace": "~/cow", # Agent的工作空间路径,用于存储memory、skills、系统设定等
|
||||
"agent_max_context_tokens": 40000, # Agent模式下最大上下文tokens,超出将自动丢弃最早的上下文
|
||||
"agent_max_context_turns": 30, # Agent模式下最大上下文记忆轮次,每轮包括一次用户提问和AI回复
|
||||
"agent_max_steps": 15 # Agent模式下单次任务的最大决策步数,超出后将停止继续调用工具
|
||||
}
|
||||
```
|
||||
**配置说明:**
|
||||
|
||||
**1.个人聊天**
|
||||
**配置补充说明:**
|
||||
|
||||
+ 个人聊天中,需要以 "bot"或"@bot" 为开头的内容触发机器人,对应配置项 `single_chat_prefix` (如果不需要以前缀触发可以填写 `"single_chat_prefix": [""]`)
|
||||
+ 机器人回复的内容会以 "[bot] " 作为前缀, 以区分真人,对应的配置项为 `single_chat_reply_prefix` (如果不需要前缀可以填写 `"single_chat_reply_prefix": ""`)
|
||||
|
||||
**2.群组聊天**
|
||||
|
||||
+ 群组聊天中,群名称需配置在 `group_name_white_list ` 中才能开启群聊自动回复。如果想对所有群聊生效,可以直接填写 `"group_name_white_list": ["ALL_GROUP"]`
|
||||
+ 默认只要被人 @ 就会触发机器人自动回复;另外群聊天中只要检测到以 "@bot" 开头的内容,同样会自动回复(方便自己触发),这对应配置项 `group_chat_prefix`
|
||||
+ 可选配置: `group_name_keyword_white_list`配置项支持模糊匹配群名称,`group_chat_keyword`配置项则支持模糊匹配群消息内容,用法与上述两个配置项相同。(Contributed by [evolay](https://github.com/evolay))
|
||||
+ `group_chat_in_one_session`:使群聊共享一个会话上下文,配置 `["ALL_GROUP"]` 则作用于所有群聊
|
||||
|
||||
**3.语音识别**
|
||||
<details>
|
||||
<summary>1. 语音配置</summary>
|
||||
|
||||
+ 添加 `"speech_recognition": true` 将开启语音识别,默认使用openai的whisper模型识别为文字,同时以文字回复,该参数仅支持私聊 (注意由于语音消息无法匹配前缀,一旦开启将对所有语音自动回复,支持语音触发画图);
|
||||
+ 添加 `"group_speech_recognition": true` 将开启群组语音识别,默认使用openai的whisper模型识别为文字,同时以文字回复,参数仅支持群聊 (会匹配group_chat_prefix和group_chat_keyword, 支持语音触发画图);
|
||||
+ 添加 `"voice_reply_voice": true` 将开启语音回复语音(同时作用于私聊和群聊),但是需要配置对应语音合成平台的key,由于itchat协议的限制,只能发送语音mp3文件,若使用wechaty则回复的是微信语音。
|
||||
+ 添加 `"voice_reply_voice": true` 将开启语音回复语音(同时作用于私聊和群聊)
|
||||
</details>
|
||||
|
||||
**4.其他配置**
|
||||
<details>
|
||||
<summary>2. 其他配置</summary>
|
||||
|
||||
+ `model`: 模型名称,目前支持 `gpt-3.5-turbo`, `text-davinci-003`, `gpt-4`, `gpt-4-32k` (其中gpt-4 api暂未开放)
|
||||
+ `temperature`,`frequency_penalty`,`presence_penalty`: Chat API接口参数,详情参考[OpenAI官方文档。](https://platform.openai.com/docs/api-reference/chat)
|
||||
+ `proxy`:由于目前 `openai` 接口国内无法访问,需配置代理客户端的地址,详情参考 [#351](https://github.com/zhayujie/chatgpt-on-wechat/issues/351)
|
||||
+ 对于图像生成,在满足个人或群组触发条件外,还需要额外的关键词前缀来触发,对应配置 `image_create_prefix `
|
||||
+ 关于OpenAI对话及图片接口的参数配置(内容自由度、回复字数限制、图片大小等),可以参考 [对话接口](https://beta.openai.com/docs/api-reference/completions) 和 [图像接口](https://beta.openai.com/docs/api-reference/completions) 文档直接在 [代码](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/bot/openai/open_ai_bot.py) `bot/openai/open_ai_bot.py` 中进行调整。
|
||||
+ `conversation_max_tokens`:表示能够记忆的上下文最大字数(一问一答为一组对话,如果累积的对话字数超出限制,就会优先移除最早的一组对话)
|
||||
+ `rate_limit_chatgpt`,`rate_limit_dalle`:每分钟最高问答速率、画图速率,超速后排队按序处理。
|
||||
+ `clear_memory_commands`: 对话内指令,主动清空前文记忆,字符串数组可自定义指令别名。
|
||||
+ `hot_reload`: 程序退出后,暂存微信扫码状态,默认关闭。
|
||||
+ `character_desc` 配置中保存着你对机器人说的一段话,他会记住这段话并作为他的设定,你可以为他定制任何人格 (关于会话上下文的更多内容参考该 [issue](https://github.com/zhayujie/chatgpt-on-wechat/issues/43))
|
||||
+ `model`: 模型名称,Agent模式下推荐使用 `MiniMax-M2.5`、`glm-5`、`kimi-k2.5`、`qwen3.5-plus`、`claude-sonnet-4-6`、`gemini-3.1-pro-preview`,全部模型名称参考[common/const.py](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/common/const.py)文件
|
||||
+ `character_desc`:普通对话模式下的机器人系统提示词。在Agent模式下该配置不生效,由工作空间中的文件内容构成。
|
||||
+ `subscribe_msg`:订阅消息,公众号和企业微信channel中请填写,当被订阅时会自动回复, 可使用特殊占位符。目前支持的占位符有{trigger_prefix},在程序中它会自动替换成bot的触发词。
|
||||
</details>
|
||||
|
||||
**所有可选的配置项均在该[文件](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/config.py)中列出。**
|
||||
<details>
|
||||
<summary>3. LinkAI配置</summary>
|
||||
|
||||
## 运行
|
||||
+ `use_linkai`: 是否使用LinkAI接口,默认关闭,设置为true后可对接LinkAI平台,使用知识库、工作流、插件等能力, 参考[接口文档](https://docs.link-ai.tech/platform/api/chat)
|
||||
+ `linkai_api_key`: LinkAI Api Key,可在 [控制台](https://link-ai.tech/console/interface) 创建
|
||||
+ `linkai_app_code`: LinkAI 应用或工作流的code,选填,普通对话模式中使用。
|
||||
</details>
|
||||
|
||||
注:全部配置项说明可在 [`config.py`](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/config.py) 文件中查看。
|
||||
|
||||
## 三、运行
|
||||
|
||||
### 1.本地运行
|
||||
|
||||
如果是开发机 **本地运行**,直接在项目根目录下执行:
|
||||
如果是个人计算机 **本地运行**,直接在项目根目录下执行:
|
||||
|
||||
```bash
|
||||
python3 app.py
|
||||
python3 app.py # windows环境下该命令通常为 python app.py
|
||||
```
|
||||
终端输出二维码后,使用微信进行扫码,当输出 "Start auto replying" 时表示自动回复程序已经成功运行了(注意:用于登录的微信需要在支付处已完成实名认证)。扫码登录后你的账号就成为机器人了,可以在微信手机端通过配置的关键词触发自动回复 (任意好友发送消息给你,或是自己发消息给好友),参考[#142](https://github.com/zhayujie/chatgpt-on-wechat/issues/142)。
|
||||
|
||||
运行后默认会启动web服务,可通过访问 `http://localhost:9899/chat` 在网页端对话。
|
||||
|
||||
如果需要接入其他应用通道只需修改 `config.json` 配置文件中的 `channel_type` 参数,详情参考:[通道说明](#通道说明)。
|
||||
|
||||
|
||||
### 2.服务器部署
|
||||
|
||||
使用nohup命令在后台运行程序:
|
||||
在服务器中可使用 `nohup` 命令在后台运行程序:
|
||||
|
||||
```bash
|
||||
touch nohup.out # 首次运行需要新建日志文件
|
||||
nohup python3 app.py & tail -f nohup.out # 在后台运行程序并通过日志输出二维码
|
||||
nohup python3 app.py & tail -f nohup.out
|
||||
```
|
||||
扫码登录后程序即可运行于服务器后台,此时可通过 `ctrl+c` 关闭日志,不会影响后台程序的运行。使用 `ps -ef | grep app.py | grep -v grep` 命令可查看运行于后台的进程,如果想要重新启动程序可以先 `kill` 掉对应的进程。日志关闭后如果想要再次打开只需输入 `tail -f nohup.out`。此外,`scripts` 目录下有一键运行、关闭程序的脚本供使用。
|
||||
|
||||
> **多账号支持:** 将项目复制多份,分别启动程序,用不同账号扫码登录即可实现同时运行。
|
||||
执行后程序运行于服务器后台,可通过 `ctrl+c` 关闭日志,不会影响后台程序的运行。使用 `ps -ef | grep app.py | grep -v grep` 命令可查看运行于后台的进程,如果想要重新启动程序可以先 `kill` 掉对应的进程。 日志关闭后如果想要再次打开只需输入 `tail -f nohup.out`。
|
||||
|
||||
> **特殊指令:** 用户向机器人发送 **#reset** 即可清空该用户的上下文记忆。
|
||||
此外,项目的 `scripts` 目录下有一键运行、关闭程序的脚本供使用。 运行后默认channel为web,通过可以通过修改配置文件进行切换。
|
||||
|
||||
|
||||
### 3.Docker部署
|
||||
|
||||
参考文档 [Docker部署](https://github.com/limccn/chatgpt-on-wechat/wiki/Docker%E9%83%A8%E7%BD%B2) (Contributed by [limccn](https://github.com/limccn))。
|
||||
使用docker部署无需下载源码和安装依赖,只需要获取 `docker-compose.yml` 配置文件并启动容器即可。Agent模式下更推荐使用源码进行部署,以获得更多系统访问能力。
|
||||
|
||||
### 4. Railway部署(✅推荐)
|
||||
> Railway每月提供5刀和最多500小时的免费额度。
|
||||
1. 进入 [Railway](https://railway.app/template/qApznZ?referralCode=RC3znh)。
|
||||
2. 点击 `Deploy Now` 按钮。
|
||||
3. 设置环境变量来重载程序运行的参数,例如`open_ai_api_key`, `character_desc`。
|
||||
> 前提是需要安装好 `docker` 及 `docker-compose`,安装成功后执行 `docker -v` 和 `docker-compose version` (或 `docker compose version`) 可查看到版本号。安装地址为 [docker官网](https://docs.docker.com/engine/install/) 。
|
||||
|
||||
## 常见问题
|
||||
**(1) 下载 docker-compose.yml 文件**
|
||||
|
||||
```bash
|
||||
wget https://cdn.link-ai.tech/code/cow/docker-compose.yml
|
||||
```
|
||||
|
||||
下载完成后打开 `docker-compose.yml` 填写所需配置,例如 `CHANNEL_TYPE`、`OPEN_AI_API_KEY` 和等配置。
|
||||
|
||||
**(2) 启动容器**
|
||||
|
||||
在 `docker-compose.yml` 所在目录下执行以下命令启动容器:
|
||||
|
||||
```bash
|
||||
sudo docker compose up -d # 若docker-compose为 1.X 版本,则执行 `sudo docker-compose up -d`
|
||||
```
|
||||
|
||||
运行命令后,会自动取 [docker hub](https://hub.docker.com/r/zhayujie/chatgpt-on-wechat) 拉取最新release版本的镜像。当执行 `sudo docker ps` 能查看到 NAMES 为 chatgpt-on-wechat 的容器即表示运行成功。最后执行以下命令可查看容器的运行日志:
|
||||
|
||||
```bash
|
||||
sudo docker logs -f chatgpt-on-wechat
|
||||
```
|
||||
|
||||
**(3) 插件使用**
|
||||
|
||||
如果需要在docker容器中修改插件配置,可通过挂载的方式完成,将 [插件配置文件](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/plugins/config.json.template)
|
||||
重命名为 `config.json`,放置于 `docker-compose.yml` 相同目录下,并在 `docker-compose.yml` 中的 `chatgpt-on-wechat` 部分下添加 `volumes` 映射:
|
||||
|
||||
```
|
||||
volumes:
|
||||
- ./config.json:/app/plugins/config.json
|
||||
```
|
||||
**注**:使用docker方式部署的详细教程可以参考:[docker部署CoW项目](https://www.wangpc.cc/ai/docker-deploy-cow/)
|
||||
|
||||
|
||||
## 模型说明
|
||||
|
||||
以下对所有可支持的模型的配置和使用方法进行说明,模型接口实现在项目的 `models/` 目录下。
|
||||
|
||||
<details>
|
||||
<summary>OpenAI</summary>
|
||||
|
||||
1. API Key创建:在 [OpenAI平台](https://platform.openai.com/api-keys) 创建API Key
|
||||
|
||||
2. 填写配置
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "gpt-5.4",
|
||||
"open_ai_api_key": "YOUR_API_KEY",
|
||||
"open_ai_api_base": "https://api.openai.com/v1",
|
||||
"bot_type": "chatGPT"
|
||||
}
|
||||
```
|
||||
|
||||
- `model`: 与OpenAI接口的 [model参数](https://platform.openai.com/docs/models) 一致,支持包括 gpt-5.4、o系列、gpt-4.1等模型,Agent模式推荐使用 `gpt-5.4`
|
||||
- `open_ai_api_base`: 如果需要接入第三方代理接口,可通过修改该参数进行接入
|
||||
- `bot_type`: 使用OpenAI相关模型时无需填写。当使用第三方代理接口接入Claude等非OpenAI官方模型时,该参数设为 `chatGPT`
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>LinkAI</summary>
|
||||
|
||||
1. API Key创建:在 [LinkAI平台](https://link-ai.tech/console/interface) 创建API Key
|
||||
|
||||
2. 填写配置
|
||||
|
||||
```json
|
||||
{
|
||||
"use_linkai": true,
|
||||
"linkai_api_key": "YOUR API KEY",
|
||||
"linkai_app_code": "YOUR APP CODE"
|
||||
}
|
||||
```
|
||||
|
||||
+ `use_linkai`: 是否使用LinkAI接口,默认关闭,设置为true后可对接LinkAI平台的智能体,使用知识库、工作流、数据库、MCP插件等丰富的Agent能力
|
||||
+ `linkai_api_key`: LinkAI平台的API Key,可在 [控制台](https://link-ai.tech/console/interface) 中创建
|
||||
+ `linkai_app_code`: LinkAI智能体 (应用或工作流) 的code,选填,普通对话模式可用。智能体创建可参考 [说明文档](https://docs.link-ai.tech/platform/quick-start)
|
||||
+ `model`: model字段填写空则直接使用智能体的模型,可在平台中灵活切换,[模型列表](https://link-ai.tech/console/models)中的全部模型均可使用
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>MiniMax</summary>
|
||||
|
||||
方式一:官方接入,配置如下(推荐):
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "MiniMax-M2.5",
|
||||
"minimax_api_key": ""
|
||||
}
|
||||
```
|
||||
- `model`: 可填写 `MiniMax-M2.5、MiniMax-M2.1、MiniMax-M2.1-lightning、MiniMax-M2、abab6.5-chat` 等
|
||||
- `minimax_api_key`:MiniMax平台的API-KEY,在 [控制台](https://platform.minimaxi.com/user-center/basic-information/interface-key) 创建
|
||||
|
||||
方式二:OpenAI兼容方式接入,配置如下:
|
||||
```json
|
||||
{
|
||||
"bot_type": "chatGPT",
|
||||
"model": "MiniMax-M2.5",
|
||||
"open_ai_api_base": "https://api.minimaxi.com/v1",
|
||||
"open_ai_api_key": ""
|
||||
}
|
||||
```
|
||||
- `bot_type`: OpenAI兼容方式
|
||||
- `model`: 可填 `MiniMax-M2.5、MiniMax-M2.1、MiniMax-M2.1-lightning、MiniMax-M2`,参考[API文档](https://platform.minimaxi.com/document/%E5%AF%B9%E8%AF%9D?key=66701d281d57f38758d581d0#QklxsNSbaf6kM4j6wjO5eEek)
|
||||
- `open_ai_api_base`: MiniMax平台API的 BASE URL
|
||||
- `open_ai_api_key`: MiniMax平台的API-KEY
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>智谱AI (GLM)</summary>
|
||||
|
||||
方式一:官方接入,配置如下(推荐):
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "glm-5",
|
||||
"zhipu_ai_api_key": ""
|
||||
}
|
||||
```
|
||||
- `model`: 可填 `glm-5、glm-4.7、glm-4-plus、glm-4-flash、glm-4-air、glm-4-airx、glm-4-long` 等, 参考 [glm系列模型编码](https://bigmodel.cn/dev/api/normal-model/glm-4)
|
||||
- `zhipu_ai_api_key`: 智谱AI平台的 API KEY,在 [控制台](https://www.bigmodel.cn/usercenter/proj-mgmt/apikeys) 创建
|
||||
|
||||
方式二:OpenAI兼容方式接入,配置如下:
|
||||
```json
|
||||
{
|
||||
"bot_type": "chatGPT",
|
||||
"model": "glm-5",
|
||||
"open_ai_api_base": "https://open.bigmodel.cn/api/paas/v4",
|
||||
"open_ai_api_key": ""
|
||||
}
|
||||
```
|
||||
- `bot_type`: OpenAI兼容方式
|
||||
- `model`: 可填 `glm-5、glm-4.7、glm-4-plus、glm-4-flash、glm-4-air、glm-4-airx、glm-4-long` 等
|
||||
- `open_ai_api_base`: 智谱AI平台的 BASE URL
|
||||
- `open_ai_api_key`: 智谱AI平台的 API KEY
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>通义千问 (Qwen)</summary>
|
||||
|
||||
方式一:官方SDK接入,配置如下(推荐):
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "qwen3.5-plus",
|
||||
"dashscope_api_key": "sk-qVxxxxG"
|
||||
}
|
||||
```
|
||||
- `model`: 可填写 `qwen3.5-plus、qwen3-max、qwen-max、qwen-plus、qwen-turbo、qwen-long、qwq-plus` 等
|
||||
- `dashscope_api_key`: 通义千问的 API-KEY,参考 [官方文档](https://bailian.console.aliyun.com/?tab=api#/api) ,在 [控制台](https://bailian.console.aliyun.com/?tab=model#/api-key) 创建
|
||||
|
||||
方式二:OpenAI兼容方式接入,配置如下:
|
||||
```json
|
||||
{
|
||||
"bot_type": "chatGPT",
|
||||
"model": "qwen3.5-plus",
|
||||
"open_ai_api_base": "https://dashscope.aliyuncs.com/compatible-mode/v1",
|
||||
"open_ai_api_key": "sk-qVxxxxG"
|
||||
}
|
||||
```
|
||||
- `bot_type`: OpenAI兼容方式
|
||||
- `model`: 支持官方所有模型,参考[模型列表](https://help.aliyun.com/zh/model-studio/models?spm=a2c4g.11186623.0.0.78d84823Kth5on#9f8890ce29g5u)
|
||||
- `open_ai_api_base`: 通义千问API的 BASE URL
|
||||
- `open_ai_api_key`: 通义千问的 API-KEY
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>Kimi (Moonshot)</summary>
|
||||
|
||||
方式一:官方接入,配置如下:
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "kimi-k2.5",
|
||||
"moonshot_api_key": ""
|
||||
}
|
||||
```
|
||||
- `model`: 可填写 `kimi-k2.5、kimi-k2、moonshot-v1-8k、moonshot-v1-32k、moonshot-v1-128k`
|
||||
- `moonshot_api_key`: Moonshot的API-KEY,在 [控制台](https://platform.moonshot.cn/console/api-keys) 创建
|
||||
|
||||
方式二:OpenAI兼容方式接入,配置如下:
|
||||
```json
|
||||
{
|
||||
"bot_type": "chatGPT",
|
||||
"model": "kimi-k2.5",
|
||||
"open_ai_api_base": "https://api.moonshot.cn/v1",
|
||||
"open_ai_api_key": ""
|
||||
}
|
||||
```
|
||||
- `bot_type`: OpenAI兼容方式
|
||||
- `model`: 可填写 `kimi-k2.5、kimi-k2、moonshot-v1-8k、moonshot-v1-32k、moonshot-v1-128k`
|
||||
- `open_ai_api_base`: Moonshot的 BASE URL
|
||||
- `open_ai_api_key`: Moonshot的 API-KEY
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>豆包 (Doubao)</summary>
|
||||
|
||||
1. API Key创建:在 [火山方舟控制台](https://console.volcengine.com/ark/region:ark+cn-beijing/apikey) 创建API Key
|
||||
|
||||
2. 填写配置
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "doubao-seed-2-0-code-preview-260215",
|
||||
"ark_api_key": "YOUR_API_KEY"
|
||||
}
|
||||
```
|
||||
- `model`: 可填写 `doubao-seed-2-0-code-preview-260215、doubao-seed-2-0-pro-260215、doubao-seed-2-0-lite-260215、doubao-seed-2-0-mini-260215` 等
|
||||
- `ark_api_key`: 火山方舟平台的 API Key,在 [控制台](https://console.volcengine.com/ark/region:ark+cn-beijing/apikey) 创建
|
||||
- `ark_base_url`: 可选,默认为 `https://ark.cn-beijing.volces.com/api/v3`
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>Claude</summary>
|
||||
|
||||
1. API Key创建:在 [Claude控制台](https://console.anthropic.com/settings/keys) 创建API Key
|
||||
|
||||
2. 填写配置
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "claude-sonnet-4-6",
|
||||
"claude_api_key": "YOUR_API_KEY"
|
||||
}
|
||||
```
|
||||
- `model`: 参考 [官方模型ID](https://docs.anthropic.com/en/docs/about-claude/models/overview#model-aliases) ,支持 `claude-sonnet-4-6、claude-opus-4-6、claude-sonnet-4-5、claude-sonnet-4-0、claude-opus-4-0、claude-3-5-sonnet-latest` 等
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>Gemini</summary>
|
||||
|
||||
API Key创建:在 [控制台](https://aistudio.google.com/app/apikey?hl=zh-cn) 创建API Key ,配置如下
|
||||
```json
|
||||
{
|
||||
"model": "gemini-3.1-flash-lite-preview",
|
||||
"gemini_api_key": ""
|
||||
}
|
||||
```
|
||||
- `model`: 参考[官方文档-模型列表](https://ai.google.dev/gemini-api/docs/models?hl=zh-cn),支持 `gemini-3.1-flash-lite-preview、gemini-3.1-pro-preview、gemini-3-flash-preview、gemini-3-pro-preview` 等
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>DeepSeek</summary>
|
||||
|
||||
1. API Key创建:在 [DeepSeek平台](https://platform.deepseek.com/api_keys) 创建API Key
|
||||
|
||||
2. 填写配置
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "deepseek-chat",
|
||||
"open_ai_api_key": "sk-xxxxxxxxxxx",
|
||||
"open_ai_api_base": "https://api.deepseek.com/v1",
|
||||
"bot_type": "chatGPT"
|
||||
|
||||
}
|
||||
```
|
||||
|
||||
- `bot_type`: OpenAI兼容方式
|
||||
- `model`: 可填 `deepseek-chat、deepseek-reasoner`,分别对应的是 DeepSeek-V3 和 DeepSeek-R1 模型
|
||||
- `open_ai_api_key`: DeepSeek平台的 API Key
|
||||
- `open_ai_api_base`: DeepSeek平台 BASE URL
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>Azure</summary>
|
||||
|
||||
1. API Key创建:在 [Azure平台](https://oai.azure.com/) 创建API Key
|
||||
|
||||
2. 填写配置
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "",
|
||||
"use_azure_chatgpt": true,
|
||||
"open_ai_api_key": "",
|
||||
"open_ai_api_base": "",
|
||||
"azure_deployment_id": "",
|
||||
"azure_api_version": "2025-01-01-preview"
|
||||
}
|
||||
```
|
||||
|
||||
- `model`: 留空即可
|
||||
- `use_azure_chatgpt`: 设为 true
|
||||
- `open_ai_api_key`: Azure平台的密钥
|
||||
- `open_ai_api_base`: Azure平台的 BASE URL
|
||||
- `azure_deployment_id`: Azure平台部署的模型名称
|
||||
- `azure_api_version`: api版本以及以上参数可以在部署的 [模型配置](https://oai.azure.com/resource/deployments) 界面查看
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>百度文心</summary>
|
||||
方式一:官方SDK接入,配置如下:
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "wenxin-4",
|
||||
"baidu_wenxin_api_key": "IajztZ0bDxgnP9bEykU7lBer",
|
||||
"baidu_wenxin_secret_key": "EDPZn6L24uAS9d8RWFfotK47dPvkjD6G"
|
||||
}
|
||||
```
|
||||
- `model`: 可填 `wenxin`和`wenxin-4`,对应模型为 文心-3.5 和 文心-4.0
|
||||
- `baidu_wenxin_api_key`:参考 [千帆平台-access_token鉴权](https://cloud.baidu.com/doc/WENXINWORKSHOP/s/dlv4pct3s) 文档获取 API Key
|
||||
- `baidu_wenxin_secret_key`:参考 [千帆平台-access_token鉴权](https://cloud.baidu.com/doc/WENXINWORKSHOP/s/dlv4pct3s) 文档获取 Secret Key
|
||||
|
||||
方式二:OpenAI兼容方式接入,配置如下:
|
||||
```json
|
||||
{
|
||||
"bot_type": "chatGPT",
|
||||
"model": "ERNIE-4.0-Turbo-8K",
|
||||
"open_ai_api_base": "https://qianfan.baidubce.com/v2",
|
||||
"open_ai_api_key": "bce-v3/ALTxxxxxxd2b"
|
||||
}
|
||||
```
|
||||
- `bot_type`: OpenAI兼容方式
|
||||
- `model`: 支持官方所有模型,参考[模型列表](https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Wm9cvy6rl)
|
||||
- `open_ai_api_base`: 百度文心API的 BASE URL
|
||||
- `open_ai_api_key`: 百度文心的 API-KEY,参考 [官方文档](https://cloud.baidu.com/doc/qianfan-api/s/ym9chdsy5) ,在 [控制台](https://console.bce.baidu.com/iam/#/iam/apikey/list) 创建API Key
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>讯飞星火</summary>
|
||||
|
||||
方式一:官方接入,配置如下:
|
||||
参考 [官方文档-快速指引](https://www.xfyun.cn/doc/platform/quickguide.html#%E7%AC%AC%E4%BA%8C%E6%AD%A5-%E5%88%9B%E5%BB%BA%E6%82%A8%E7%9A%84%E7%AC%AC%E4%B8%80%E4%B8%AA%E5%BA%94%E7%94%A8-%E5%BC%80%E5%A7%8B%E4%BD%BF%E7%94%A8%E6%9C%8D%E5%8A%A1) 获取 `APPID、 APISecret、 APIKey` 三个参数
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "xunfei",
|
||||
"xunfei_app_id": "",
|
||||
"xunfei_api_key": "",
|
||||
"xunfei_api_secret": "",
|
||||
"xunfei_domain": "4.0Ultra",
|
||||
"xunfei_spark_url": "wss://spark-api.xf-yun.com/v4.0/chat"
|
||||
}
|
||||
```
|
||||
- `model`: 填 `xunfei`
|
||||
- `xunfei_domain`: 可填写 `4.0Ultra、generalv3.5、max-32k、generalv3、pro-128k、lite`
|
||||
- `xunfei_spark_url`: 填写参考 [官方文档-请求地址](https://www.xfyun.cn/doc/spark/Web.html#_1-1-%E8%AF%B7%E6%B1%82%E5%9C%B0%E5%9D%80) 的说明
|
||||
|
||||
方式二:OpenAI兼容方式接入,配置如下:
|
||||
```json
|
||||
{
|
||||
"bot_type": "chatGPT",
|
||||
"model": "4.0Ultra",
|
||||
"open_ai_api_base": "https://spark-api-open.xf-yun.com/v1",
|
||||
"open_ai_api_key": ""
|
||||
}
|
||||
```
|
||||
- `bot_type`: OpenAI兼容方式
|
||||
- `model`: 可填写 `4.0Ultra、generalv3.5、max-32k、generalv3、pro-128k、lite`
|
||||
- `open_ai_api_base`: 讯飞星火平台的 BASE URL
|
||||
- `open_ai_api_key`: 讯飞星火平台的[APIPassword](https://console.xfyun.cn/services/bm3) ,因模型而已
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>ModelScope</summary>
|
||||
|
||||
```json
|
||||
{
|
||||
"bot_type": "modelscope",
|
||||
"model": "Qwen/QwQ-32B",
|
||||
"modelscope_api_key": "your_api_key",
|
||||
"modelscope_base_url": "https://api-inference.modelscope.cn/v1/chat/completions",
|
||||
"text_to_image": "MusePublic/489_ckpt_FLUX_1"
|
||||
}
|
||||
```
|
||||
|
||||
- `bot_type`: modelscope接口格式
|
||||
- `model`: 参考[模型列表](https://www.modelscope.cn/models?filter=inference_type&page=1)
|
||||
- `modelscope_api_key`: 参考 [官方文档-访问令牌](https://modelscope.cn/docs/accounts/token) ,在 [控制台](https://modelscope.cn/my/myaccesstoken)
|
||||
- `modelscope_base_url`: modelscope平台的 BASE URL
|
||||
- `text_to_image`: 图像生成模型,参考[模型列表](https://www.modelscope.cn/models?filter=inference_type&page=1)
|
||||
</details>
|
||||
|
||||
|
||||
## 通道说明
|
||||
|
||||
以下对可接入通道的配置方式进行说明,应用通道代码在项目的 `channel/` 目录下。
|
||||
|
||||
支持同时可接入多个通道,配置时可通过逗号进行分割,例如 `"channel_type": "feishu,dingtalk"`。
|
||||
|
||||
<details>
|
||||
<summary>1. Web</summary>
|
||||
|
||||
项目启动后会默认运行Web控制台,配置如下:
|
||||
|
||||
```json
|
||||
{
|
||||
"channel_type": "web",
|
||||
"web_port": 9899
|
||||
}
|
||||
```
|
||||
|
||||
- `web_port`: 默认为 9899,可按需更改,需要服务器防火墙和安全组放行该端口
|
||||
- 如本地运行,启动后请访问 `http://localhost:9899/chat` ;如服务器运行,请访问 `http://ip:9899/chat`
|
||||
> 注:请将上述 url 中的 ip 或者 port 替换为实际的值
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>2. Feishu - 飞书</summary>
|
||||
|
||||
飞书支持两种事件接收模式:WebSocket 长连接(推荐)和 Webhook。
|
||||
|
||||
**方式一:WebSocket 模式(推荐,无需公网 IP)**
|
||||
|
||||
```json
|
||||
{
|
||||
"channel_type": "feishu",
|
||||
"feishu_app_id": "APP_ID",
|
||||
"feishu_app_secret": "APP_SECRET",
|
||||
"feishu_event_mode": "websocket"
|
||||
}
|
||||
```
|
||||
|
||||
**方式二:Webhook 模式(需要公网 IP)**
|
||||
|
||||
```json
|
||||
{
|
||||
"channel_type": "feishu",
|
||||
"feishu_app_id": "APP_ID",
|
||||
"feishu_app_secret": "APP_SECRET",
|
||||
"feishu_token": "VERIFICATION_TOKEN",
|
||||
"feishu_event_mode": "webhook",
|
||||
"feishu_port": 9891
|
||||
}
|
||||
```
|
||||
|
||||
- `feishu_event_mode`: 事件接收模式,`websocket`(推荐)或 `webhook`
|
||||
- WebSocket 模式需安装依赖:`pip3 install lark-oapi`
|
||||
|
||||
详细步骤和参数说明参考 [飞书接入](https://docs.cowagent.ai/channels/feishu)
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>3. DingTalk - 钉钉</summary>
|
||||
|
||||
钉钉需要在开放平台创建智能机器人应用,将以下配置填入 `config.json`:
|
||||
|
||||
```json
|
||||
{
|
||||
"channel_type": "dingtalk",
|
||||
"dingtalk_client_id": "CLIENT_ID",
|
||||
"dingtalk_client_secret": "CLIENT_SECRET"
|
||||
}
|
||||
```
|
||||
详细步骤和参数说明参考 [钉钉接入](https://docs.cowagent.ai/channels/dingtalk)
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>4. WeCom Bot - 企微智能机器人</summary>
|
||||
|
||||
企微智能机器人使用 WebSocket 长连接模式,无需公网 IP 和域名,配置简单:
|
||||
|
||||
```json
|
||||
{
|
||||
"channel_type": "wecom_bot",
|
||||
"wecom_bot_id": "YOUR_BOT_ID",
|
||||
"wecom_bot_secret": "YOUR_SECRET"
|
||||
}
|
||||
```
|
||||
详细步骤和参数说明参考 [企微智能机器人接入](https://docs.cowagent.ai/channels/wecom-bot)
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>5. QQ - QQ 机器人</summary>
|
||||
|
||||
QQ 机器人使用 WebSocket 长连接模式,无需公网 IP 和域名,支持 QQ 单聊、群聊和频道消息:
|
||||
|
||||
```json
|
||||
{
|
||||
"channel_type": "qq",
|
||||
"qq_app_id": "YOUR_APP_ID",
|
||||
"qq_app_secret": "YOUR_APP_SECRET"
|
||||
}
|
||||
```
|
||||
详细步骤和参数说明参考 [QQ 机器人接入](https://docs.cowagent.ai/channels/qq)
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>6. WeCom App - 企业微信应用</summary>
|
||||
|
||||
企业微信自建应用接入需在后台创建应用并启用消息回调,配置示例:
|
||||
|
||||
```json
|
||||
{
|
||||
"channel_type": "wechatcom_app",
|
||||
"wechatcom_corp_id": "CORPID",
|
||||
"wechatcomapp_token": "TOKEN",
|
||||
"wechatcomapp_port": 9898,
|
||||
"wechatcomapp_secret": "SECRET",
|
||||
"wechatcomapp_agent_id": "AGENTID",
|
||||
"wechatcomapp_aes_key": "AESKEY"
|
||||
}
|
||||
```
|
||||
详细步骤和参数说明参考 [企微自建应用接入](https://docs.cowagent.ai/channels/wecom)
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>7. WeChat MP - 微信公众号</summary>
|
||||
|
||||
本项目支持订阅号和服务号两种公众号,通过服务号(`wechatmp_service`)体验更佳。
|
||||
|
||||
**个人订阅号(wechatmp)**
|
||||
|
||||
```json
|
||||
{
|
||||
"channel_type": "wechatmp",
|
||||
"wechatmp_token": "TOKEN",
|
||||
"wechatmp_port": 80,
|
||||
"wechatmp_app_id": "APPID",
|
||||
"wechatmp_app_secret": "APPSECRET",
|
||||
"wechatmp_aes_key": ""
|
||||
}
|
||||
```
|
||||
|
||||
**企业服务号(wechatmp_service)**
|
||||
|
||||
```json
|
||||
{
|
||||
"channel_type": "wechatmp_service",
|
||||
"wechatmp_token": "TOKEN",
|
||||
"wechatmp_port": 80,
|
||||
"wechatmp_app_id": "APPID",
|
||||
"wechatmp_app_secret": "APPSECRET",
|
||||
"wechatmp_aes_key": ""
|
||||
}
|
||||
```
|
||||
|
||||
详细步骤和参数说明参考 [微信公众号接入](https://docs.cowagent.ai/channels/wechatmp)
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>8. Terminal - 终端</summary>
|
||||
|
||||
修改 `config.json` 中的 `channel_type` 字段:
|
||||
|
||||
```json
|
||||
{
|
||||
"channel_type": "terminal"
|
||||
}
|
||||
```
|
||||
|
||||
运行后可在终端与机器人进行对话。
|
||||
|
||||
</details>
|
||||
|
||||
<br/>
|
||||
|
||||
# 🔗 相关项目
|
||||
|
||||
- [bot-on-anything](https://github.com/zhayujie/bot-on-anything):轻量和高可扩展的大模型应用框架,支持接入Slack, Telegram, Discord, Gmail等海外平台,可作为本项目的补充使用。
|
||||
- [AgentMesh](https://github.com/MinimalFuture/AgentMesh):开源的多智能体(Multi-Agent)框架,可以通过多智能体团队的协同来解决复杂问题。本项目基于该框架实现了[Agent插件](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/plugins/agent/README.md),可访问终端、浏览器、文件系统、搜索引擎 等各类工具,并实现了多智能体协同。
|
||||
|
||||
|
||||
|
||||
# 🔎 常见问题
|
||||
|
||||
FAQs: <https://github.com/zhayujie/chatgpt-on-wechat/wiki/FAQs>
|
||||
|
||||
或直接在线咨询 [项目小助手](https://link-ai.tech/app/Kv2fXJcH) (知识库持续完善中,回复供参考)
|
||||
|
||||
## 联系
|
||||
# 🛠️ 开发
|
||||
|
||||
欢迎提交PR、Issues,以及Star支持一下。程序运行遇到问题优先查看 [常见问题列表](https://github.com/zhayujie/chatgpt-on-wechat/wiki/FAQs) ,其次前往 [Issues](https://github.com/zhayujie/chatgpt-on-wechat/issues) 中搜索,若无相似问题可创建Issue,或加微信 eijuyahz 交流。
|
||||
欢迎接入更多应用通道,参考 [飞书通道](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/channel/feishu/feishu_channel.py) 新增自定义通道,实现接收和发送消息逻辑即可完成接入。 同时欢迎贡献新的Skills,参考 [Skill创造器说明](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/skills/skill-creator/SKILL.md)。
|
||||
|
||||
|
||||
# ✉ 联系
|
||||
|
||||
欢迎提交PR、Issues进行反馈,以及通过 🌟Star 支持并关注项目更新。项目运行遇到问题可以查看 [常见问题列表](https://github.com/zhayujie/chatgpt-on-wechat/wiki/FAQs) ,以及前往 [Issues](https://github.com/zhayujie/chatgpt-on-wechat/issues) 中搜索。个人开发者可加入开源交流群参与更多讨论,企业用户可联系[产品客服](https://cdn.link-ai.tech/portal/linkai-customer-service.png)咨询。
|
||||
|
||||
# 🌟 贡献者
|
||||
|
||||

|
||||
|
||||
3
agent/chat/__init__.py
Normal file
3
agent/chat/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from agent.chat.service import ChatService
|
||||
|
||||
__all__ = ["ChatService"]
|
||||
213
agent/chat/service.py
Normal file
213
agent/chat/service.py
Normal file
@@ -0,0 +1,213 @@
|
||||
"""
|
||||
ChatService - Wraps the Agent stream execution to produce CHAT protocol chunks.
|
||||
|
||||
Translates agent events (message_update, message_end, tool_execution_end, etc.)
|
||||
into the CHAT socket protocol format (content chunks with segment_id, tool_calls chunks).
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import Callable, Optional
|
||||
|
||||
from common.log import logger
|
||||
|
||||
|
||||
class ChatService:
|
||||
"""
|
||||
High-level service that runs an Agent for a given query and streams
|
||||
the results as CHAT protocol chunks via a callback.
|
||||
|
||||
Usage:
|
||||
svc = ChatService(agent_bridge)
|
||||
svc.run(query, session_id, send_chunk_fn)
|
||||
"""
|
||||
|
||||
def __init__(self, agent_bridge):
|
||||
"""
|
||||
:param agent_bridge: AgentBridge instance (manages agent lifecycle)
|
||||
"""
|
||||
self.agent_bridge = agent_bridge
|
||||
|
||||
def run(self, query: str, session_id: str, send_chunk_fn: Callable[[dict], None],
|
||||
channel_type: str = ""):
|
||||
"""
|
||||
Run the agent for *query* and stream results back via *send_chunk_fn*.
|
||||
|
||||
The method blocks until the agent finishes. After it returns the SDK
|
||||
will automatically send the final (streaming=false) message.
|
||||
|
||||
:param query: user query text
|
||||
:param session_id: session identifier for agent isolation
|
||||
:param send_chunk_fn: callable(chunk_data: dict) to send a streaming chunk
|
||||
:param channel_type: source channel (e.g. "web", "feishu") for persistence
|
||||
"""
|
||||
agent = self.agent_bridge.get_agent(session_id=session_id)
|
||||
if agent is None:
|
||||
raise RuntimeError("Failed to initialise agent for the session")
|
||||
|
||||
# State shared between the event callback and this method
|
||||
state = _StreamState()
|
||||
|
||||
def on_event(event: dict):
|
||||
"""Translate agent events into CHAT protocol chunks."""
|
||||
event_type = event.get("type")
|
||||
data = event.get("data", {})
|
||||
|
||||
if event_type == "message_update":
|
||||
# Incremental text delta
|
||||
delta = data.get("delta", "")
|
||||
if delta:
|
||||
send_chunk_fn({
|
||||
"chunk_type": "content",
|
||||
"delta": delta,
|
||||
"segment_id": state.segment_id,
|
||||
})
|
||||
|
||||
elif event_type == "message_end":
|
||||
# A content segment finished.
|
||||
tool_calls = data.get("tool_calls", [])
|
||||
if tool_calls:
|
||||
# After tool_calls are executed the next content will be
|
||||
# a new segment; collect tool results until turn_end.
|
||||
state.pending_tool_results = []
|
||||
|
||||
elif event_type == "tool_execution_start":
|
||||
# Notify the client that a tool is about to run (with its input args)
|
||||
tool_name = data.get("tool_name", "")
|
||||
arguments = data.get("arguments", {})
|
||||
# Cache arguments keyed by tool_call_id so tool_execution_end can include them
|
||||
tool_call_id = data.get("tool_call_id", tool_name)
|
||||
state.pending_tool_arguments[tool_call_id] = arguments
|
||||
send_chunk_fn({
|
||||
"chunk_type": "tool_start",
|
||||
"tool": tool_name,
|
||||
"arguments": arguments,
|
||||
})
|
||||
|
||||
elif event_type == "tool_execution_end":
|
||||
tool_name = data.get("tool_name", "")
|
||||
tool_call_id = data.get("tool_call_id", tool_name)
|
||||
# Retrieve cached arguments from the matching tool_execution_start event
|
||||
arguments = state.pending_tool_arguments.pop(tool_call_id, data.get("arguments", {}))
|
||||
result = data.get("result", "")
|
||||
status = data.get("status", "unknown")
|
||||
execution_time = data.get("execution_time", 0)
|
||||
elapsed_str = f"{execution_time:.2f}s"
|
||||
|
||||
# Serialise result to string if needed
|
||||
if not isinstance(result, str):
|
||||
import json
|
||||
try:
|
||||
result = json.dumps(result, ensure_ascii=False)
|
||||
except Exception:
|
||||
result = str(result)
|
||||
|
||||
tool_info = {
|
||||
"name": tool_name,
|
||||
"arguments": arguments,
|
||||
"result": result,
|
||||
"status": status,
|
||||
"elapsed": elapsed_str,
|
||||
}
|
||||
|
||||
if state.pending_tool_results is not None:
|
||||
state.pending_tool_results.append(tool_info)
|
||||
|
||||
elif event_type == "turn_end":
|
||||
has_tool_calls = data.get("has_tool_calls", False)
|
||||
if has_tool_calls and state.pending_tool_results:
|
||||
# Flush collected tool results as a single tool_calls chunk
|
||||
send_chunk_fn({
|
||||
"chunk_type": "tool_calls",
|
||||
"tool_calls": state.pending_tool_results,
|
||||
})
|
||||
state.pending_tool_results = None
|
||||
# Next content belongs to a new segment
|
||||
state.segment_id += 1
|
||||
|
||||
# Run the agent with our event callback ---------------------------
|
||||
logger.info(f"[ChatService] Starting agent run: session={session_id}, query={query[:80]}")
|
||||
|
||||
from config import conf
|
||||
max_context_turns = conf().get("agent_max_context_turns", 20)
|
||||
|
||||
# Get full system prompt with skills
|
||||
full_system_prompt = agent.get_full_system_prompt()
|
||||
|
||||
# Create a copy of messages for this execution
|
||||
with agent.messages_lock:
|
||||
messages_copy = agent.messages.copy()
|
||||
original_length = len(agent.messages)
|
||||
|
||||
from agent.protocol.agent_stream import AgentStreamExecutor
|
||||
|
||||
executor = AgentStreamExecutor(
|
||||
agent=agent,
|
||||
model=agent.model,
|
||||
system_prompt=full_system_prompt,
|
||||
tools=agent.tools,
|
||||
max_turns=agent.max_steps,
|
||||
on_event=on_event,
|
||||
messages=messages_copy,
|
||||
max_context_turns=max_context_turns,
|
||||
)
|
||||
|
||||
try:
|
||||
response = executor.run_stream(query)
|
||||
except Exception:
|
||||
# If executor cleared messages (context overflow), sync back
|
||||
if len(executor.messages) == 0:
|
||||
with agent.messages_lock:
|
||||
agent.messages.clear()
|
||||
logger.info("[ChatService] Cleared agent message history after executor recovery")
|
||||
raise
|
||||
|
||||
# Append only the NEW messages from this execution (thread-safe)
|
||||
with agent.messages_lock:
|
||||
new_messages = executor.messages[original_length:]
|
||||
agent.messages.extend(new_messages)
|
||||
|
||||
# Persist new messages to SQLite so they survive restarts and
|
||||
# can be queried via the HISTORY interface.
|
||||
if new_messages:
|
||||
self._persist_messages(session_id, list(new_messages), channel_type)
|
||||
|
||||
# Store executor reference for files_to_send access
|
||||
agent.stream_executor = executor
|
||||
|
||||
# Execute post-process tools
|
||||
agent._execute_post_process_tools()
|
||||
|
||||
logger.info(f"[ChatService] Agent run completed: session={session_id}")
|
||||
|
||||
|
||||
|
||||
@staticmethod
|
||||
def _persist_messages(session_id: str, new_messages: list, channel_type: str = ""):
|
||||
try:
|
||||
from config import conf
|
||||
if not conf().get("conversation_persistence", True):
|
||||
return
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
from agent.memory import get_conversation_store
|
||||
get_conversation_store().append_messages(
|
||||
session_id, new_messages, channel_type=channel_type
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"[ChatService] Failed to persist messages for session={session_id}: {e}"
|
||||
)
|
||||
|
||||
|
||||
class _StreamState:
|
||||
"""Mutable state shared between the event callback and the run method."""
|
||||
|
||||
def __init__(self):
|
||||
self.segment_id: int = 0
|
||||
# None means we are not accumulating tool results right now.
|
||||
# A list means we are in the middle of a tool-execution phase.
|
||||
self.pending_tool_results: Optional[list] = None
|
||||
# Maps tool_call_id -> arguments captured from tool_execution_start,
|
||||
# so that tool_execution_end can attach the correct input args.
|
||||
self.pending_tool_arguments: dict = {}
|
||||
23
agent/memory/__init__.py
Normal file
23
agent/memory/__init__.py
Normal file
@@ -0,0 +1,23 @@
|
||||
"""
|
||||
Memory module for AgentMesh
|
||||
|
||||
Provides both long-term memory (vector/keyword search) and short-term
|
||||
conversation history persistence (SQLite).
|
||||
"""
|
||||
|
||||
from agent.memory.manager import MemoryManager
|
||||
from agent.memory.config import MemoryConfig, get_default_memory_config, set_global_memory_config
|
||||
from agent.memory.embedding import create_embedding_provider
|
||||
from agent.memory.conversation_store import ConversationStore, get_conversation_store
|
||||
from agent.memory.summarizer import ensure_daily_memory_file
|
||||
|
||||
__all__ = [
|
||||
'MemoryManager',
|
||||
'MemoryConfig',
|
||||
'get_default_memory_config',
|
||||
'set_global_memory_config',
|
||||
'create_embedding_provider',
|
||||
'ConversationStore',
|
||||
'get_conversation_store',
|
||||
'ensure_daily_memory_file',
|
||||
]
|
||||
140
agent/memory/chunker.py
Normal file
140
agent/memory/chunker.py
Normal file
@@ -0,0 +1,140 @@
|
||||
"""
|
||||
Text chunking utilities for memory
|
||||
|
||||
Splits text into chunks with token limits and overlap
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
from typing import List, Tuple
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class TextChunk:
|
||||
"""Represents a text chunk with line numbers"""
|
||||
text: str
|
||||
start_line: int
|
||||
end_line: int
|
||||
|
||||
|
||||
class TextChunker:
|
||||
"""Chunks text by line count with token estimation"""
|
||||
|
||||
def __init__(self, max_tokens: int = 500, overlap_tokens: int = 50):
|
||||
"""
|
||||
Initialize chunker
|
||||
|
||||
Args:
|
||||
max_tokens: Maximum tokens per chunk
|
||||
overlap_tokens: Overlap tokens between chunks
|
||||
"""
|
||||
self.max_tokens = max_tokens
|
||||
self.overlap_tokens = overlap_tokens
|
||||
# Rough estimation: ~4 chars per token for English/Chinese mixed
|
||||
self.chars_per_token = 4
|
||||
|
||||
def chunk_text(self, text: str) -> List[TextChunk]:
|
||||
"""
|
||||
Chunk text into overlapping segments
|
||||
|
||||
Args:
|
||||
text: Input text to chunk
|
||||
|
||||
Returns:
|
||||
List of TextChunk objects
|
||||
"""
|
||||
if not text.strip():
|
||||
return []
|
||||
|
||||
lines = text.split('\n')
|
||||
chunks = []
|
||||
|
||||
max_chars = self.max_tokens * self.chars_per_token
|
||||
overlap_chars = self.overlap_tokens * self.chars_per_token
|
||||
|
||||
current_chunk = []
|
||||
current_chars = 0
|
||||
start_line = 1
|
||||
|
||||
for i, line in enumerate(lines, start=1):
|
||||
line_chars = len(line)
|
||||
|
||||
# If single line exceeds max, split it
|
||||
if line_chars > max_chars:
|
||||
# Save current chunk if exists
|
||||
if current_chunk:
|
||||
chunks.append(TextChunk(
|
||||
text='\n'.join(current_chunk),
|
||||
start_line=start_line,
|
||||
end_line=i - 1
|
||||
))
|
||||
current_chunk = []
|
||||
current_chars = 0
|
||||
|
||||
# Split long line into multiple chunks
|
||||
for sub_chunk in self._split_long_line(line, max_chars):
|
||||
chunks.append(TextChunk(
|
||||
text=sub_chunk,
|
||||
start_line=i,
|
||||
end_line=i
|
||||
))
|
||||
|
||||
start_line = i + 1
|
||||
continue
|
||||
|
||||
# Check if adding this line would exceed limit
|
||||
if current_chars + line_chars > max_chars and current_chunk:
|
||||
# Save current chunk
|
||||
chunks.append(TextChunk(
|
||||
text='\n'.join(current_chunk),
|
||||
start_line=start_line,
|
||||
end_line=i - 1
|
||||
))
|
||||
|
||||
# Start new chunk with overlap
|
||||
overlap_lines = self._get_overlap_lines(current_chunk, overlap_chars)
|
||||
current_chunk = overlap_lines + [line]
|
||||
current_chars = sum(len(l) for l in current_chunk)
|
||||
start_line = i - len(overlap_lines)
|
||||
else:
|
||||
# Add line to current chunk
|
||||
current_chunk.append(line)
|
||||
current_chars += line_chars
|
||||
|
||||
# Save last chunk
|
||||
if current_chunk:
|
||||
chunks.append(TextChunk(
|
||||
text='\n'.join(current_chunk),
|
||||
start_line=start_line,
|
||||
end_line=len(lines)
|
||||
))
|
||||
|
||||
return chunks
|
||||
|
||||
def _split_long_line(self, line: str, max_chars: int) -> List[str]:
|
||||
"""Split a single long line into multiple chunks"""
|
||||
chunks = []
|
||||
for i in range(0, len(line), max_chars):
|
||||
chunks.append(line[i:i + max_chars])
|
||||
return chunks
|
||||
|
||||
def _get_overlap_lines(self, lines: List[str], target_chars: int) -> List[str]:
|
||||
"""Get last few lines that fit within target_chars for overlap"""
|
||||
overlap = []
|
||||
chars = 0
|
||||
|
||||
for line in reversed(lines):
|
||||
line_chars = len(line)
|
||||
if chars + line_chars > target_chars:
|
||||
break
|
||||
overlap.insert(0, line)
|
||||
chars += line_chars
|
||||
|
||||
return overlap
|
||||
|
||||
def chunk_markdown(self, text: str) -> List[TextChunk]:
|
||||
"""
|
||||
Chunk markdown text while respecting structure
|
||||
(For future enhancement: respect markdown sections)
|
||||
"""
|
||||
return self.chunk_text(text)
|
||||
122
agent/memory/config.py
Normal file
122
agent/memory/config.py
Normal file
@@ -0,0 +1,122 @@
|
||||
"""
|
||||
Memory configuration module
|
||||
|
||||
Provides global memory configuration with simplified workspace structure
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional, List
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def _default_workspace():
|
||||
"""Get default workspace path with proper Windows support"""
|
||||
from common.utils import expand_path
|
||||
return expand_path("~/cow")
|
||||
|
||||
|
||||
@dataclass
|
||||
class MemoryConfig:
|
||||
"""Configuration for memory storage and search"""
|
||||
|
||||
# Storage paths (default: ~/cow)
|
||||
workspace_root: str = field(default_factory=_default_workspace)
|
||||
|
||||
# Embedding config
|
||||
embedding_provider: str = "openai" # "openai" | "local"
|
||||
embedding_model: str = "text-embedding-3-small"
|
||||
embedding_dim: int = 1536
|
||||
|
||||
# Chunking config
|
||||
chunk_max_tokens: int = 500
|
||||
chunk_overlap_tokens: int = 50
|
||||
|
||||
# Search config
|
||||
max_results: int = 10
|
||||
min_score: float = 0.1
|
||||
|
||||
# Hybrid search weights
|
||||
vector_weight: float = 0.7
|
||||
keyword_weight: float = 0.3
|
||||
|
||||
# Memory sources
|
||||
sources: List[str] = field(default_factory=lambda: ["memory", "session"])
|
||||
|
||||
# Sync config
|
||||
enable_auto_sync: bool = True
|
||||
sync_on_search: bool = True
|
||||
|
||||
|
||||
def get_workspace(self) -> Path:
|
||||
"""Get workspace root directory"""
|
||||
return Path(self.workspace_root)
|
||||
|
||||
def get_memory_dir(self) -> Path:
|
||||
"""Get memory files directory"""
|
||||
return self.get_workspace() / "memory"
|
||||
|
||||
def get_db_path(self) -> Path:
|
||||
"""Get SQLite database path for long-term memory index"""
|
||||
index_dir = self.get_memory_dir() / "long-term"
|
||||
index_dir.mkdir(parents=True, exist_ok=True)
|
||||
return index_dir / "index.db"
|
||||
|
||||
def get_skills_dir(self) -> Path:
|
||||
"""Get skills directory"""
|
||||
return self.get_workspace() / "skills"
|
||||
|
||||
def get_agent_workspace(self, agent_name: Optional[str] = None) -> Path:
|
||||
"""
|
||||
Get workspace directory for an agent
|
||||
|
||||
Args:
|
||||
agent_name: Optional agent name (not used in current implementation)
|
||||
|
||||
Returns:
|
||||
Path to workspace directory
|
||||
"""
|
||||
workspace = self.get_workspace()
|
||||
# Ensure workspace directory exists
|
||||
workspace.mkdir(parents=True, exist_ok=True)
|
||||
return workspace
|
||||
|
||||
|
||||
# Global memory configuration
|
||||
_global_memory_config: Optional[MemoryConfig] = None
|
||||
|
||||
|
||||
def get_default_memory_config() -> MemoryConfig:
|
||||
"""
|
||||
Get the global memory configuration.
|
||||
If not set, returns a default configuration.
|
||||
|
||||
Returns:
|
||||
MemoryConfig instance
|
||||
"""
|
||||
global _global_memory_config
|
||||
if _global_memory_config is None:
|
||||
_global_memory_config = MemoryConfig()
|
||||
return _global_memory_config
|
||||
|
||||
|
||||
def set_global_memory_config(config: MemoryConfig):
|
||||
"""
|
||||
Set the global memory configuration.
|
||||
This should be called before creating any MemoryManager instances.
|
||||
|
||||
Args:
|
||||
config: MemoryConfig instance to use globally
|
||||
|
||||
Example:
|
||||
>>> from agent.memory import MemoryConfig, set_global_memory_config
|
||||
>>> config = MemoryConfig(
|
||||
... workspace_root="~/my_agents",
|
||||
... embedding_provider="openai",
|
||||
... vector_weight=0.8
|
||||
... )
|
||||
>>> set_global_memory_config(config)
|
||||
"""
|
||||
global _global_memory_config
|
||||
_global_memory_config = config
|
||||
618
agent/memory/conversation_store.py
Normal file
618
agent/memory/conversation_store.py
Normal file
@@ -0,0 +1,618 @@
|
||||
"""
|
||||
Conversation history persistence using SQLite.
|
||||
|
||||
Design:
|
||||
- sessions table: per-session metadata (channel_type, last_active, msg_count)
|
||||
- messages table: individual messages stored as JSON, append-only
|
||||
- Pruning: age-based only (sessions not updated within N days are deleted)
|
||||
- Thread-safe via a single in-process lock
|
||||
|
||||
Storage path: ~/cow/sessions/conversations.db
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import sqlite3
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from common.log import logger
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Schema
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_DDL = """
|
||||
CREATE TABLE IF NOT EXISTS sessions (
|
||||
session_id TEXT PRIMARY KEY,
|
||||
channel_type TEXT NOT NULL DEFAULT '',
|
||||
created_at INTEGER NOT NULL,
|
||||
last_active INTEGER NOT NULL,
|
||||
msg_count INTEGER NOT NULL DEFAULT 0
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS messages (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
session_id TEXT NOT NULL,
|
||||
seq INTEGER NOT NULL,
|
||||
role TEXT NOT NULL,
|
||||
content TEXT NOT NULL,
|
||||
created_at INTEGER NOT NULL,
|
||||
UNIQUE (session_id, seq)
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_messages_session
|
||||
ON messages (session_id, seq);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_sessions_last_active
|
||||
ON sessions (last_active);
|
||||
"""
|
||||
|
||||
# Migration: add channel_type column to existing databases that predate it.
|
||||
_MIGRATION_ADD_CHANNEL_TYPE = """
|
||||
ALTER TABLE sessions ADD COLUMN channel_type TEXT NOT NULL DEFAULT '';
|
||||
"""
|
||||
|
||||
DEFAULT_MAX_AGE_DAYS: int = 30
|
||||
|
||||
|
||||
def _is_visible_user_message(content: Any) -> bool:
|
||||
"""
|
||||
Return True when a user-role message represents actual user input
|
||||
(not an internal tool_result injected by the agent loop).
|
||||
"""
|
||||
if isinstance(content, str):
|
||||
return bool(content.strip())
|
||||
if isinstance(content, list):
|
||||
return any(
|
||||
isinstance(b, dict) and b.get("type") == "text"
|
||||
for b in content
|
||||
)
|
||||
return False
|
||||
|
||||
|
||||
def _extract_display_text(content: Any) -> str:
|
||||
"""
|
||||
Extract the human-readable text portion from a message content value.
|
||||
Returns an empty string for tool_use / tool_result blocks.
|
||||
"""
|
||||
if isinstance(content, str):
|
||||
return content.strip()
|
||||
if isinstance(content, list):
|
||||
parts = [
|
||||
b.get("text", "")
|
||||
for b in content
|
||||
if isinstance(b, dict) and b.get("type") == "text"
|
||||
]
|
||||
return "\n".join(p for p in parts if p).strip()
|
||||
return ""
|
||||
|
||||
|
||||
def _extract_tool_calls(content: Any) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Extract tool_use blocks from an assistant message content.
|
||||
Returns a list of {name, arguments} dicts (result filled in later).
|
||||
"""
|
||||
if not isinstance(content, list):
|
||||
return []
|
||||
return [
|
||||
{"id": b.get("id", ""), "name": b.get("name", ""), "arguments": b.get("input", {})}
|
||||
for b in content
|
||||
if isinstance(b, dict) and b.get("type") == "tool_use"
|
||||
]
|
||||
|
||||
|
||||
def _extract_tool_results(content: Any) -> Dict[str, str]:
|
||||
"""
|
||||
Extract tool_result blocks from a user message, keyed by tool_use_id.
|
||||
"""
|
||||
if not isinstance(content, list):
|
||||
return {}
|
||||
results = {}
|
||||
for b in content:
|
||||
if not isinstance(b, dict) or b.get("type") != "tool_result":
|
||||
continue
|
||||
tool_id = b.get("tool_use_id", "")
|
||||
result_content = b.get("content", "")
|
||||
if isinstance(result_content, list):
|
||||
result_content = "\n".join(
|
||||
rb.get("text", "") for rb in result_content
|
||||
if isinstance(rb, dict) and rb.get("type") == "text"
|
||||
)
|
||||
results[tool_id] = str(result_content)
|
||||
return results
|
||||
|
||||
|
||||
def _group_into_display_turns(
|
||||
rows: List[tuple],
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Convert raw (role, content_json, created_at) DB rows into display turns.
|
||||
|
||||
One display turn = one visible user message + one merged assistant reply.
|
||||
All intermediate assistant messages (those carrying tool_use) and the final
|
||||
assistant text reply produced for the same user query are collapsed into a
|
||||
single assistant turn, exactly matching the live SSE rendering where tools
|
||||
and the final answer appear inside the same bubble.
|
||||
|
||||
Grouping rules:
|
||||
- A visible user message starts a new group.
|
||||
- tool_result user messages are internal; their content is attached to the
|
||||
matching tool_use entry via tool_use_id and they never become own turns.
|
||||
- All assistant messages within a group are merged:
|
||||
* tool_use blocks → tool_calls list (result filled from tool_results)
|
||||
* text blocks → last non-empty text becomes the display content
|
||||
"""
|
||||
# ------------------------------------------------------------------ #
|
||||
# Pass 1: split rows into groups, each starting with a visible user msg
|
||||
# ------------------------------------------------------------------ #
|
||||
# group = (user_row | None, [subsequent_rows])
|
||||
# user_row: (content, created_at)
|
||||
groups: List[tuple] = []
|
||||
cur_user: Optional[tuple] = None
|
||||
cur_rest: List[tuple] = []
|
||||
started = False
|
||||
|
||||
for role, raw_content, created_at in rows:
|
||||
try:
|
||||
content = json.loads(raw_content)
|
||||
except Exception:
|
||||
content = raw_content
|
||||
|
||||
if role == "user" and _is_visible_user_message(content):
|
||||
if started:
|
||||
groups.append((cur_user, cur_rest))
|
||||
cur_user = (content, created_at)
|
||||
cur_rest = []
|
||||
started = True
|
||||
else:
|
||||
cur_rest.append((role, content, created_at))
|
||||
|
||||
if started:
|
||||
groups.append((cur_user, cur_rest))
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# Pass 2: build display turns from each group
|
||||
# ------------------------------------------------------------------ #
|
||||
turns: List[Dict[str, Any]] = []
|
||||
|
||||
for user_row, rest in groups:
|
||||
# User turn
|
||||
if user_row:
|
||||
content, created_at = user_row
|
||||
text = _extract_display_text(content)
|
||||
if text:
|
||||
turns.append({"role": "user", "content": text, "created_at": created_at})
|
||||
|
||||
# Collect all tool_calls and tool_results from the rest of the group
|
||||
all_tool_calls: List[Dict[str, Any]] = []
|
||||
tool_results: Dict[str, str] = {}
|
||||
final_text = ""
|
||||
final_ts: Optional[int] = None
|
||||
|
||||
for role, content, created_at in rest:
|
||||
if role == "user":
|
||||
tool_results.update(_extract_tool_results(content))
|
||||
elif role == "assistant":
|
||||
tcs = _extract_tool_calls(content)
|
||||
all_tool_calls.extend(tcs)
|
||||
t = _extract_display_text(content)
|
||||
if t:
|
||||
final_text = t
|
||||
final_ts = created_at
|
||||
|
||||
# Attach tool results to their matching tool_call entries
|
||||
for tc in all_tool_calls:
|
||||
tc["result"] = tool_results.get(tc.get("id", ""), "")
|
||||
|
||||
if final_text or all_tool_calls:
|
||||
turns.append({
|
||||
"role": "assistant",
|
||||
"content": final_text,
|
||||
"tool_calls": all_tool_calls,
|
||||
"created_at": final_ts or (user_row[1] if user_row else 0),
|
||||
})
|
||||
|
||||
return turns
|
||||
|
||||
|
||||
class ConversationStore:
|
||||
"""
|
||||
SQLite-backed store for per-session conversation history.
|
||||
|
||||
Usage:
|
||||
store = ConversationStore(db_path)
|
||||
store.append_messages("user_123", new_messages, channel_type="feishu")
|
||||
msgs = store.load_messages("user_123", max_turns=30)
|
||||
"""
|
||||
|
||||
def __init__(self, db_path: Path):
|
||||
self._db_path = db_path
|
||||
self._lock = threading.Lock()
|
||||
self._init_db()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Public API
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def load_messages(
|
||||
self,
|
||||
session_id: str,
|
||||
max_turns: int = 30,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Load the most recent messages for a session, for injection into the LLM.
|
||||
|
||||
ALL message types (user text, assistant tool_use, tool_result) are returned
|
||||
in their original JSON form so the LLM can reconstruct the full context.
|
||||
|
||||
max_turns is a *visible-turn* count: we count only user messages whose
|
||||
content is actual user text (not tool_result blocks). This prevents
|
||||
tool-heavy sessions from exhausting the turn budget prematurely.
|
||||
|
||||
Args:
|
||||
session_id: Unique session identifier.
|
||||
max_turns: Maximum number of visible user-assistant turns to keep.
|
||||
|
||||
Returns:
|
||||
Chronologically ordered list of message dicts (role, content).
|
||||
"""
|
||||
with self._lock:
|
||||
conn = self._connect()
|
||||
try:
|
||||
rows = conn.execute(
|
||||
"""
|
||||
SELECT seq, role, content
|
||||
FROM messages
|
||||
WHERE session_id = ?
|
||||
ORDER BY seq DESC
|
||||
""",
|
||||
(session_id,),
|
||||
).fetchall()
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
if not rows:
|
||||
return []
|
||||
|
||||
# Walk newest-to-oldest counting *visible* user turns (actual user text,
|
||||
# not tool_result injections). Record the seq of every visible user
|
||||
# message so we can find a clean cut point later.
|
||||
visible_turn_seqs: List[int] = [] # newest first
|
||||
for seq, role, raw_content in rows:
|
||||
if role != "user":
|
||||
continue
|
||||
try:
|
||||
content = json.loads(raw_content)
|
||||
except Exception:
|
||||
content = raw_content
|
||||
if _is_visible_user_message(content):
|
||||
visible_turn_seqs.append(seq)
|
||||
|
||||
# Determine the seq of the oldest visible user message we want to keep.
|
||||
# If the total turns fit within max_turns, keep everything.
|
||||
if len(visible_turn_seqs) <= max_turns:
|
||||
cutoff_seq = None # keep all
|
||||
else:
|
||||
# The Nth visible user message (0-indexed) is the oldest we keep.
|
||||
cutoff_seq = visible_turn_seqs[max_turns - 1]
|
||||
|
||||
# Build result in chronological order, starting from cutoff.
|
||||
# IMPORTANT: we start exactly at cutoff_seq (the visible user message),
|
||||
# never mid-group, so tool_use / tool_result pairs are always complete.
|
||||
result = []
|
||||
for seq, role, raw_content in reversed(rows):
|
||||
if cutoff_seq is not None and seq < cutoff_seq:
|
||||
continue
|
||||
try:
|
||||
content = json.loads(raw_content)
|
||||
except Exception:
|
||||
content = raw_content
|
||||
result.append({"role": role, "content": content})
|
||||
return result
|
||||
|
||||
def append_messages(
|
||||
self,
|
||||
session_id: str,
|
||||
messages: List[Dict[str, Any]],
|
||||
channel_type: str = "",
|
||||
) -> None:
|
||||
"""
|
||||
Append new messages to a session's history.
|
||||
|
||||
Seq numbers continue from the session's current maximum, so
|
||||
concurrent callers on distinct sessions never collide.
|
||||
|
||||
Args:
|
||||
session_id: Unique session identifier.
|
||||
messages: List of message dicts to append.
|
||||
channel_type: Source channel (e.g. "feishu", "web", "wechat").
|
||||
Only written on session creation; ignored on update.
|
||||
"""
|
||||
if not messages:
|
||||
return
|
||||
|
||||
now = int(time.time())
|
||||
with self._lock:
|
||||
conn = self._connect()
|
||||
try:
|
||||
with conn:
|
||||
# INSERT OR IGNORE creates the row on first visit;
|
||||
# the UPDATE always refreshes last_active.
|
||||
# Avoids ON CONFLICT...DO UPDATE (requires SQLite >= 3.24).
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT OR IGNORE INTO sessions
|
||||
(session_id, channel_type, created_at, last_active, msg_count)
|
||||
VALUES (?, ?, ?, ?, 0)
|
||||
""",
|
||||
(session_id, channel_type, now, now),
|
||||
)
|
||||
conn.execute(
|
||||
"UPDATE sessions SET last_active = ? WHERE session_id = ?",
|
||||
(now, session_id),
|
||||
)
|
||||
|
||||
# Determine starting seq for the new batch.
|
||||
row = conn.execute(
|
||||
"SELECT COALESCE(MAX(seq), -1) FROM messages WHERE session_id = ?",
|
||||
(session_id,),
|
||||
).fetchone()
|
||||
next_seq = row[0] + 1
|
||||
|
||||
for msg in messages:
|
||||
role = msg.get("role", "")
|
||||
content = json.dumps(
|
||||
msg.get("content", ""), ensure_ascii=False
|
||||
)
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT OR IGNORE INTO messages
|
||||
(session_id, seq, role, content, created_at)
|
||||
VALUES (?, ?, ?, ?, ?)
|
||||
""",
|
||||
(session_id, next_seq, role, content, now),
|
||||
)
|
||||
next_seq += 1
|
||||
|
||||
conn.execute(
|
||||
"""
|
||||
UPDATE sessions
|
||||
SET msg_count = (
|
||||
SELECT COUNT(*) FROM messages WHERE session_id = ?
|
||||
)
|
||||
WHERE session_id = ?
|
||||
""",
|
||||
(session_id, session_id),
|
||||
)
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def clear_session(self, session_id: str) -> None:
|
||||
"""Delete all messages and the session record for a given session_id."""
|
||||
with self._lock:
|
||||
conn = self._connect()
|
||||
try:
|
||||
with conn:
|
||||
conn.execute(
|
||||
"DELETE FROM messages WHERE session_id = ?", (session_id,)
|
||||
)
|
||||
conn.execute(
|
||||
"DELETE FROM sessions WHERE session_id = ?", (session_id,)
|
||||
)
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def cleanup_old_sessions(self, max_age_days: Optional[int] = None) -> int:
|
||||
"""
|
||||
Delete sessions that have not been active within max_age_days.
|
||||
|
||||
Args:
|
||||
max_age_days: Override the default retention period.
|
||||
|
||||
Returns:
|
||||
Number of sessions deleted.
|
||||
"""
|
||||
try:
|
||||
from config import conf
|
||||
max_age = max_age_days or conf().get(
|
||||
"conversation_max_age_days", DEFAULT_MAX_AGE_DAYS
|
||||
)
|
||||
except Exception:
|
||||
max_age = max_age_days or DEFAULT_MAX_AGE_DAYS
|
||||
|
||||
cutoff = int(time.time()) - max_age * 86400
|
||||
deleted = 0
|
||||
|
||||
with self._lock:
|
||||
conn = self._connect()
|
||||
try:
|
||||
with conn:
|
||||
stale = conn.execute(
|
||||
"SELECT session_id FROM sessions WHERE last_active < ?",
|
||||
(cutoff,),
|
||||
).fetchall()
|
||||
for (sid,) in stale:
|
||||
conn.execute(
|
||||
"DELETE FROM messages WHERE session_id = ?", (sid,)
|
||||
)
|
||||
conn.execute(
|
||||
"DELETE FROM sessions WHERE session_id = ?", (sid,)
|
||||
)
|
||||
deleted += 1
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
if deleted:
|
||||
logger.info(f"[ConversationStore] Pruned {deleted} expired sessions")
|
||||
return deleted
|
||||
|
||||
def load_history_page(
|
||||
self,
|
||||
session_id: str,
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Load a page of conversation history for UI display, grouped into turns.
|
||||
|
||||
Each "turn" maps to one of:
|
||||
- A user message (role="user", content=str)
|
||||
- An assistant message (role="assistant", content=str,
|
||||
tool_calls=[{name, arguments, result}] when tools were used)
|
||||
|
||||
Internal tool_result user messages are merged into the preceding
|
||||
assistant entry's tool_calls list and never appear as standalone items.
|
||||
|
||||
Pages are numbered from 1 (most recent). Messages within a page are
|
||||
returned in chronological order.
|
||||
|
||||
Returns:
|
||||
{
|
||||
"messages": [
|
||||
{
|
||||
"role": "user" | "assistant",
|
||||
"content": str,
|
||||
"tool_calls": [...], # assistant only, may be []
|
||||
"created_at": int,
|
||||
},
|
||||
...
|
||||
],
|
||||
"total": <visible turn count>,
|
||||
"page": <current page>,
|
||||
"page_size": <page_size>,
|
||||
"has_more": bool,
|
||||
}
|
||||
"""
|
||||
page = max(1, page)
|
||||
with self._lock:
|
||||
conn = self._connect()
|
||||
try:
|
||||
rows = conn.execute(
|
||||
"""
|
||||
SELECT role, content, created_at
|
||||
FROM messages
|
||||
WHERE session_id = ?
|
||||
ORDER BY seq ASC
|
||||
""",
|
||||
(session_id,),
|
||||
).fetchall()
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
visible = _group_into_display_turns(rows)
|
||||
|
||||
total = len(visible)
|
||||
offset = (page - 1) * page_size
|
||||
page_items = list(reversed(visible))[offset: offset + page_size]
|
||||
page_items = list(reversed(page_items))
|
||||
|
||||
return {
|
||||
"messages": page_items,
|
||||
"total": total,
|
||||
"page": page,
|
||||
"page_size": page_size,
|
||||
"has_more": offset + page_size < total,
|
||||
}
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""Return basic stats keyed by channel_type, for monitoring."""
|
||||
with self._lock:
|
||||
conn = self._connect()
|
||||
try:
|
||||
total_sessions = conn.execute(
|
||||
"SELECT COUNT(*) FROM sessions"
|
||||
).fetchone()[0]
|
||||
total_messages = conn.execute(
|
||||
"SELECT COUNT(*) FROM messages"
|
||||
).fetchone()[0]
|
||||
by_channel = conn.execute(
|
||||
"""
|
||||
SELECT channel_type, COUNT(*) as cnt
|
||||
FROM sessions
|
||||
GROUP BY channel_type
|
||||
ORDER BY cnt DESC
|
||||
"""
|
||||
).fetchall()
|
||||
return {
|
||||
"total_sessions": total_sessions,
|
||||
"total_messages": total_messages,
|
||||
"by_channel": {row[0] or "unknown": row[1] for row in by_channel},
|
||||
}
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Internal helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _init_db(self) -> None:
|
||||
self._db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
conn = self._connect()
|
||||
try:
|
||||
conn.executescript(_DDL)
|
||||
conn.commit()
|
||||
self._migrate(conn)
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def _migrate(self, conn: sqlite3.Connection) -> None:
|
||||
"""Apply incremental schema migrations on existing databases."""
|
||||
cols = {
|
||||
row[1]
|
||||
for row in conn.execute("PRAGMA table_info(sessions)").fetchall()
|
||||
}
|
||||
if "channel_type" not in cols:
|
||||
try:
|
||||
conn.execute(_MIGRATION_ADD_CHANNEL_TYPE)
|
||||
conn.commit()
|
||||
logger.info("[ConversationStore] Migrated: added channel_type column")
|
||||
except Exception as e:
|
||||
logger.warning(f"[ConversationStore] Migration failed: {e}")
|
||||
|
||||
def _connect(self) -> sqlite3.Connection:
|
||||
conn = sqlite3.connect(str(self._db_path), timeout=10)
|
||||
conn.execute("PRAGMA journal_mode=WAL")
|
||||
conn.execute("PRAGMA synchronous=NORMAL")
|
||||
return conn
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Singleton
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_store_instance: Optional[ConversationStore] = None
|
||||
_store_lock = threading.Lock()
|
||||
|
||||
|
||||
def get_conversation_store() -> ConversationStore:
|
||||
"""
|
||||
Return the process-wide ConversationStore singleton.
|
||||
|
||||
Reuses the long-term memory database so the project stays with a single
|
||||
SQLite file: ~/cow/memory/long-term/index.db
|
||||
The conversation tables (sessions / messages) are separate from the
|
||||
memory tables (memory_chunks / file_metadata) — no conflicts.
|
||||
"""
|
||||
global _store_instance
|
||||
if _store_instance is not None:
|
||||
return _store_instance
|
||||
|
||||
with _store_lock:
|
||||
if _store_instance is not None:
|
||||
return _store_instance
|
||||
|
||||
try:
|
||||
from agent.memory.config import get_default_memory_config
|
||||
db_path = get_default_memory_config().get_db_path()
|
||||
except Exception:
|
||||
from common.utils import expand_path
|
||||
db_path = Path(expand_path("~/cow")) / "memory" / "long-term" / "index.db"
|
||||
|
||||
_store_instance = ConversationStore(db_path)
|
||||
logger.debug(f"[ConversationStore] Using shared DB at: {db_path}")
|
||||
return _store_instance
|
||||
161
agent/memory/embedding.py
Normal file
161
agent/memory/embedding.py
Normal file
@@ -0,0 +1,161 @@
|
||||
"""
|
||||
Embedding providers for memory
|
||||
|
||||
Supports OpenAI and local embedding models
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Optional
|
||||
|
||||
|
||||
class EmbeddingProvider(ABC):
|
||||
"""Base class for embedding providers"""
|
||||
|
||||
@abstractmethod
|
||||
def embed(self, text: str) -> List[float]:
|
||||
"""Generate embedding for text"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def embed_batch(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Generate embeddings for multiple texts"""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def dimensions(self) -> int:
|
||||
"""Get embedding dimensions"""
|
||||
pass
|
||||
|
||||
|
||||
class OpenAIEmbeddingProvider(EmbeddingProvider):
|
||||
"""OpenAI embedding provider using REST API"""
|
||||
|
||||
def __init__(self, model: str = "text-embedding-3-small", api_key: Optional[str] = None, api_base: Optional[str] = None):
|
||||
"""
|
||||
Initialize OpenAI embedding provider
|
||||
|
||||
Args:
|
||||
model: Model name (text-embedding-3-small or text-embedding-3-large)
|
||||
api_key: OpenAI API key
|
||||
api_base: Optional API base URL
|
||||
"""
|
||||
self.model = model
|
||||
self.api_key = api_key
|
||||
self.api_base = api_base or "https://api.openai.com/v1"
|
||||
|
||||
# Validate API key
|
||||
if not self.api_key or self.api_key in ["", "YOUR API KEY", "YOUR_API_KEY"]:
|
||||
raise ValueError("OpenAI API key is not configured. Please set 'open_ai_api_key' in config.json")
|
||||
|
||||
# Set dimensions based on model
|
||||
self._dimensions = 1536 if "small" in model else 3072
|
||||
|
||||
def _call_api(self, input_data):
|
||||
"""Call OpenAI embedding API using requests"""
|
||||
import requests
|
||||
|
||||
url = f"{self.api_base}/embeddings"
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {self.api_key}"
|
||||
}
|
||||
data = {
|
||||
"input": input_data,
|
||||
"model": self.model
|
||||
}
|
||||
|
||||
try:
|
||||
response = requests.post(url, headers=headers, json=data, timeout=5)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
except requests.exceptions.ConnectionError as e:
|
||||
raise ConnectionError(f"Failed to connect to OpenAI API at {url}. Please check your network connection and api_base configuration. Error: {str(e)}")
|
||||
except requests.exceptions.Timeout as e:
|
||||
raise TimeoutError(f"OpenAI API request timed out after 10s. Please check your network connection. Error: {str(e)}")
|
||||
except requests.exceptions.HTTPError as e:
|
||||
if e.response.status_code == 401:
|
||||
raise ValueError(f"Invalid OpenAI API key. Please check your 'open_ai_api_key' in config.json")
|
||||
elif e.response.status_code == 429:
|
||||
raise ValueError(f"OpenAI API rate limit exceeded. Please try again later.")
|
||||
else:
|
||||
raise ValueError(f"OpenAI API request failed: {e.response.status_code} - {e.response.text}")
|
||||
|
||||
def embed(self, text: str) -> List[float]:
|
||||
"""Generate embedding for text"""
|
||||
result = self._call_api(text)
|
||||
return result["data"][0]["embedding"]
|
||||
|
||||
def embed_batch(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Generate embeddings for multiple texts"""
|
||||
if not texts:
|
||||
return []
|
||||
|
||||
result = self._call_api(texts)
|
||||
return [item["embedding"] for item in result["data"]]
|
||||
|
||||
@property
|
||||
def dimensions(self) -> int:
|
||||
return self._dimensions
|
||||
|
||||
|
||||
# LocalEmbeddingProvider removed - only use OpenAI embedding or keyword search
|
||||
|
||||
|
||||
class EmbeddingCache:
|
||||
"""Cache for embeddings to avoid recomputation"""
|
||||
|
||||
def __init__(self):
|
||||
self.cache = {}
|
||||
|
||||
def get(self, text: str, provider: str, model: str) -> Optional[List[float]]:
|
||||
"""Get cached embedding"""
|
||||
key = self._compute_key(text, provider, model)
|
||||
return self.cache.get(key)
|
||||
|
||||
def put(self, text: str, provider: str, model: str, embedding: List[float]):
|
||||
"""Cache embedding"""
|
||||
key = self._compute_key(text, provider, model)
|
||||
self.cache[key] = embedding
|
||||
|
||||
@staticmethod
|
||||
def _compute_key(text: str, provider: str, model: str) -> str:
|
||||
"""Compute cache key"""
|
||||
content = f"{provider}:{model}:{text}"
|
||||
return hashlib.md5(content.encode('utf-8')).hexdigest()
|
||||
|
||||
def clear(self):
|
||||
"""Clear cache"""
|
||||
self.cache.clear()
|
||||
|
||||
|
||||
def create_embedding_provider(
|
||||
provider: str = "openai",
|
||||
model: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None
|
||||
) -> EmbeddingProvider:
|
||||
"""
|
||||
Factory function to create embedding provider
|
||||
|
||||
Supports "openai" and "linkai" providers (both use OpenAI-compatible REST API).
|
||||
If initialization fails, caller should fall back to keyword-only search.
|
||||
|
||||
Args:
|
||||
provider: Provider name ("openai" or "linkai")
|
||||
model: Model name (default: text-embedding-3-small)
|
||||
api_key: API key (required)
|
||||
api_base: API base URL
|
||||
|
||||
Returns:
|
||||
EmbeddingProvider instance
|
||||
|
||||
Raises:
|
||||
ValueError: If provider is unsupported or api_key is missing
|
||||
"""
|
||||
if provider not in ("openai", "linkai"):
|
||||
raise ValueError(f"Unsupported embedding provider: {provider}. Use 'openai' or 'linkai'.")
|
||||
|
||||
model = model or "text-embedding-3-small"
|
||||
return OpenAIEmbeddingProvider(model=model, api_key=api_key, api_base=api_base)
|
||||
523
agent/memory/manager.py
Normal file
523
agent/memory/manager.py
Normal file
@@ -0,0 +1,523 @@
|
||||
"""
|
||||
Memory manager for AgentMesh
|
||||
|
||||
Provides high-level interface for memory operations
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import List, Optional, Dict, Any
|
||||
from pathlib import Path
|
||||
import hashlib
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from agent.memory.config import MemoryConfig, get_default_memory_config
|
||||
from agent.memory.storage import MemoryStorage, MemoryChunk, SearchResult
|
||||
from agent.memory.chunker import TextChunker
|
||||
from agent.memory.embedding import create_embedding_provider, EmbeddingProvider
|
||||
from agent.memory.summarizer import MemoryFlushManager, create_memory_files_if_needed
|
||||
|
||||
|
||||
class MemoryManager:
|
||||
"""
|
||||
Memory manager with hybrid search capabilities
|
||||
|
||||
Provides long-term memory for agents with vector and keyword search
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: Optional[MemoryConfig] = None,
|
||||
embedding_provider: Optional[EmbeddingProvider] = None,
|
||||
llm_model: Optional[Any] = None
|
||||
):
|
||||
"""
|
||||
Initialize memory manager
|
||||
|
||||
Args:
|
||||
config: Memory configuration (uses global config if not provided)
|
||||
embedding_provider: Custom embedding provider (optional)
|
||||
llm_model: LLM model for summarization (optional)
|
||||
"""
|
||||
self.config = config or get_default_memory_config()
|
||||
|
||||
# Initialize storage
|
||||
db_path = self.config.get_db_path()
|
||||
self.storage = MemoryStorage(db_path)
|
||||
|
||||
# Initialize chunker
|
||||
self.chunker = TextChunker(
|
||||
max_tokens=self.config.chunk_max_tokens,
|
||||
overlap_tokens=self.config.chunk_overlap_tokens
|
||||
)
|
||||
|
||||
# Initialize embedding provider (optional, prefer OpenAI, fallback to LinkAI)
|
||||
self.embedding_provider = None
|
||||
if embedding_provider:
|
||||
self.embedding_provider = embedding_provider
|
||||
else:
|
||||
# Try OpenAI first
|
||||
try:
|
||||
api_key = os.environ.get('OPENAI_API_KEY')
|
||||
api_base = os.environ.get('OPENAI_API_BASE')
|
||||
if api_key:
|
||||
self.embedding_provider = create_embedding_provider(
|
||||
provider="openai",
|
||||
model=self.config.embedding_model,
|
||||
api_key=api_key,
|
||||
api_base=api_base
|
||||
)
|
||||
except Exception as e:
|
||||
from common.log import logger
|
||||
logger.warning(f"[MemoryManager] OpenAI embedding failed: {e}")
|
||||
|
||||
# Fallback to LinkAI
|
||||
if self.embedding_provider is None:
|
||||
try:
|
||||
linkai_key = os.environ.get('LINKAI_API_KEY')
|
||||
linkai_base = os.environ.get('LINKAI_API_BASE', 'https://api.link-ai.tech')
|
||||
if linkai_key:
|
||||
self.embedding_provider = create_embedding_provider(
|
||||
provider="linkai",
|
||||
model=self.config.embedding_model,
|
||||
api_key=linkai_key,
|
||||
api_base=f"{linkai_base}/v1"
|
||||
)
|
||||
except Exception as e:
|
||||
from common.log import logger
|
||||
logger.warning(f"[MemoryManager] LinkAI embedding failed: {e}")
|
||||
|
||||
if self.embedding_provider is None:
|
||||
from common.log import logger
|
||||
logger.info(f"[MemoryManager] Memory will work with keyword search only (no vector search)")
|
||||
|
||||
# Initialize memory flush manager
|
||||
workspace_dir = self.config.get_workspace()
|
||||
self.flush_manager = MemoryFlushManager(
|
||||
workspace_dir=workspace_dir,
|
||||
llm_model=llm_model
|
||||
)
|
||||
|
||||
# Ensure workspace directories exist
|
||||
self._init_workspace()
|
||||
|
||||
self._dirty = False
|
||||
|
||||
def _init_workspace(self):
|
||||
"""Initialize workspace directories"""
|
||||
memory_dir = self.config.get_memory_dir()
|
||||
memory_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Create default memory files
|
||||
workspace_dir = self.config.get_workspace()
|
||||
create_memory_files_if_needed(workspace_dir)
|
||||
|
||||
async def search(
|
||||
self,
|
||||
query: str,
|
||||
user_id: Optional[str] = None,
|
||||
max_results: Optional[int] = None,
|
||||
min_score: Optional[float] = None,
|
||||
include_shared: bool = True
|
||||
) -> List[SearchResult]:
|
||||
"""
|
||||
Search memory with hybrid search (vector + keyword)
|
||||
|
||||
Args:
|
||||
query: Search query
|
||||
user_id: User ID for scoped search
|
||||
max_results: Maximum results to return
|
||||
min_score: Minimum score threshold
|
||||
include_shared: Include shared memories
|
||||
|
||||
Returns:
|
||||
List of search results sorted by relevance
|
||||
"""
|
||||
max_results = max_results or self.config.max_results
|
||||
min_score = min_score or self.config.min_score
|
||||
|
||||
# Determine scopes
|
||||
scopes = []
|
||||
if include_shared:
|
||||
scopes.append("shared")
|
||||
if user_id:
|
||||
scopes.append("user")
|
||||
|
||||
if not scopes:
|
||||
return []
|
||||
|
||||
# Sync if needed
|
||||
if self.config.sync_on_search and self._dirty:
|
||||
await self.sync()
|
||||
|
||||
# Perform vector search (if embedding provider available)
|
||||
vector_results = []
|
||||
if self.embedding_provider:
|
||||
try:
|
||||
from common.log import logger
|
||||
query_embedding = self.embedding_provider.embed(query)
|
||||
vector_results = self.storage.search_vector(
|
||||
query_embedding=query_embedding,
|
||||
user_id=user_id,
|
||||
scopes=scopes,
|
||||
limit=max_results * 2 # Get more candidates for merging
|
||||
)
|
||||
logger.info(f"[MemoryManager] Vector search found {len(vector_results)} results for query: {query}")
|
||||
except Exception as e:
|
||||
from common.log import logger
|
||||
logger.warning(f"[MemoryManager] Vector search failed: {e}")
|
||||
|
||||
# Perform keyword search
|
||||
keyword_results = self.storage.search_keyword(
|
||||
query=query,
|
||||
user_id=user_id,
|
||||
scopes=scopes,
|
||||
limit=max_results * 2
|
||||
)
|
||||
from common.log import logger
|
||||
logger.info(f"[MemoryManager] Keyword search found {len(keyword_results)} results for query: {query}")
|
||||
|
||||
# Merge results
|
||||
merged = self._merge_results(
|
||||
vector_results,
|
||||
keyword_results,
|
||||
self.config.vector_weight,
|
||||
self.config.keyword_weight
|
||||
)
|
||||
|
||||
# Filter by min score and limit
|
||||
filtered = [r for r in merged if r.score >= min_score]
|
||||
return filtered[:max_results]
|
||||
|
||||
async def add_memory(
|
||||
self,
|
||||
content: str,
|
||||
user_id: Optional[str] = None,
|
||||
scope: str = "shared",
|
||||
source: str = "memory",
|
||||
path: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
):
|
||||
"""
|
||||
Add new memory content
|
||||
|
||||
Args:
|
||||
content: Memory content
|
||||
user_id: User ID for user-scoped memory
|
||||
scope: Memory scope ("shared", "user", "session")
|
||||
source: Memory source ("memory" or "session")
|
||||
path: File path (auto-generated if not provided)
|
||||
metadata: Additional metadata
|
||||
"""
|
||||
if not content.strip():
|
||||
return
|
||||
|
||||
# Generate path if not provided
|
||||
if not path:
|
||||
content_hash = hashlib.md5(content.encode('utf-8')).hexdigest()[:8]
|
||||
if user_id and scope == "user":
|
||||
path = f"memory/users/{user_id}/memory_{content_hash}.md"
|
||||
else:
|
||||
path = f"memory/shared/memory_{content_hash}.md"
|
||||
|
||||
# Chunk content
|
||||
chunks = self.chunker.chunk_text(content)
|
||||
|
||||
# Generate embeddings (if provider available)
|
||||
texts = [chunk.text for chunk in chunks]
|
||||
if self.embedding_provider:
|
||||
embeddings = self.embedding_provider.embed_batch(texts)
|
||||
else:
|
||||
# No embeddings, just use None
|
||||
embeddings = [None] * len(texts)
|
||||
|
||||
# Create memory chunks
|
||||
memory_chunks = []
|
||||
for chunk, embedding in zip(chunks, embeddings):
|
||||
chunk_id = self._generate_chunk_id(path, chunk.start_line, chunk.end_line)
|
||||
chunk_hash = MemoryStorage.compute_hash(chunk.text)
|
||||
|
||||
memory_chunks.append(MemoryChunk(
|
||||
id=chunk_id,
|
||||
user_id=user_id,
|
||||
scope=scope,
|
||||
source=source,
|
||||
path=path,
|
||||
start_line=chunk.start_line,
|
||||
end_line=chunk.end_line,
|
||||
text=chunk.text,
|
||||
embedding=embedding,
|
||||
hash=chunk_hash,
|
||||
metadata=metadata
|
||||
))
|
||||
|
||||
# Save to storage
|
||||
self.storage.save_chunks_batch(memory_chunks)
|
||||
|
||||
# Update file metadata
|
||||
file_hash = MemoryStorage.compute_hash(content)
|
||||
self.storage.update_file_metadata(
|
||||
path=path,
|
||||
source=source,
|
||||
file_hash=file_hash,
|
||||
mtime=int(os.path.getmtime(__file__)), # Use current time
|
||||
size=len(content)
|
||||
)
|
||||
|
||||
async def sync(self, force: bool = False):
|
||||
"""
|
||||
Synchronize memory from files
|
||||
|
||||
Args:
|
||||
force: Force full reindex
|
||||
"""
|
||||
memory_dir = self.config.get_memory_dir()
|
||||
workspace_dir = self.config.get_workspace()
|
||||
|
||||
# Scan MEMORY.md (workspace root)
|
||||
memory_file = Path(workspace_dir) / "MEMORY.md"
|
||||
if memory_file.exists():
|
||||
await self._sync_file(memory_file, "memory", "shared", None)
|
||||
|
||||
# Scan memory directory (including daily summaries)
|
||||
if memory_dir.exists():
|
||||
for file_path in memory_dir.rglob("*.md"):
|
||||
# Determine scope and user_id from path
|
||||
rel_path = file_path.relative_to(workspace_dir)
|
||||
parts = rel_path.parts
|
||||
|
||||
# Check if it's in daily summary directory
|
||||
if "daily" in parts:
|
||||
# Daily summary files
|
||||
if "users" in parts or len(parts) > 3:
|
||||
# User-scoped daily summary: memory/daily/{user_id}/2024-01-29.md
|
||||
user_idx = parts.index("daily") + 1
|
||||
user_id = parts[user_idx] if user_idx < len(parts) else None
|
||||
scope = "user"
|
||||
else:
|
||||
# Shared daily summary: memory/daily/2024-01-29.md
|
||||
user_id = None
|
||||
scope = "shared"
|
||||
elif "users" in parts:
|
||||
# User-scoped memory
|
||||
user_idx = parts.index("users") + 1
|
||||
user_id = parts[user_idx] if user_idx < len(parts) else None
|
||||
scope = "user"
|
||||
else:
|
||||
# Shared memory
|
||||
user_id = None
|
||||
scope = "shared"
|
||||
|
||||
await self._sync_file(file_path, "memory", scope, user_id)
|
||||
|
||||
self._dirty = False
|
||||
|
||||
async def _sync_file(
|
||||
self,
|
||||
file_path: Path,
|
||||
source: str,
|
||||
scope: str,
|
||||
user_id: Optional[str]
|
||||
):
|
||||
"""Sync a single file"""
|
||||
# Compute file hash
|
||||
content = file_path.read_text(encoding='utf-8')
|
||||
file_hash = MemoryStorage.compute_hash(content)
|
||||
|
||||
# Get relative path
|
||||
workspace_dir = self.config.get_workspace()
|
||||
rel_path = str(file_path.relative_to(workspace_dir))
|
||||
|
||||
# Check if file changed
|
||||
stored_hash = self.storage.get_file_hash(rel_path)
|
||||
if stored_hash == file_hash:
|
||||
return # No changes
|
||||
|
||||
# Delete old chunks
|
||||
self.storage.delete_by_path(rel_path)
|
||||
|
||||
# Chunk and embed
|
||||
chunks = self.chunker.chunk_text(content)
|
||||
if not chunks:
|
||||
return
|
||||
|
||||
texts = [chunk.text for chunk in chunks]
|
||||
if self.embedding_provider:
|
||||
embeddings = self.embedding_provider.embed_batch(texts)
|
||||
else:
|
||||
embeddings = [None] * len(texts)
|
||||
|
||||
# Create memory chunks
|
||||
memory_chunks = []
|
||||
for chunk, embedding in zip(chunks, embeddings):
|
||||
chunk_id = self._generate_chunk_id(rel_path, chunk.start_line, chunk.end_line)
|
||||
chunk_hash = MemoryStorage.compute_hash(chunk.text)
|
||||
|
||||
memory_chunks.append(MemoryChunk(
|
||||
id=chunk_id,
|
||||
user_id=user_id,
|
||||
scope=scope,
|
||||
source=source,
|
||||
path=rel_path,
|
||||
start_line=chunk.start_line,
|
||||
end_line=chunk.end_line,
|
||||
text=chunk.text,
|
||||
embedding=embedding,
|
||||
hash=chunk_hash,
|
||||
metadata=None
|
||||
))
|
||||
|
||||
# Save
|
||||
self.storage.save_chunks_batch(memory_chunks)
|
||||
|
||||
# Update file metadata
|
||||
stat = file_path.stat()
|
||||
self.storage.update_file_metadata(
|
||||
path=rel_path,
|
||||
source=source,
|
||||
file_hash=file_hash,
|
||||
mtime=int(stat.st_mtime),
|
||||
size=stat.st_size
|
||||
)
|
||||
|
||||
def flush_memory(
|
||||
self,
|
||||
messages: list,
|
||||
user_id: Optional[str] = None,
|
||||
reason: str = "threshold",
|
||||
max_messages: int = 10,
|
||||
) -> bool:
|
||||
"""
|
||||
Flush conversation summary to daily memory file.
|
||||
|
||||
Args:
|
||||
messages: Conversation message list
|
||||
user_id: Optional user ID
|
||||
reason: "threshold" | "overflow" | "daily_summary"
|
||||
max_messages: Max recent messages to include (0 = all)
|
||||
|
||||
Returns:
|
||||
True if content was written
|
||||
"""
|
||||
success = self.flush_manager.flush_from_messages(
|
||||
messages=messages,
|
||||
user_id=user_id,
|
||||
reason=reason,
|
||||
max_messages=max_messages,
|
||||
)
|
||||
if success:
|
||||
self._dirty = True
|
||||
return success
|
||||
|
||||
def get_status(self) -> Dict[str, Any]:
|
||||
"""Get memory status"""
|
||||
stats = self.storage.get_stats()
|
||||
return {
|
||||
'chunks': stats['chunks'],
|
||||
'files': stats['files'],
|
||||
'workspace': str(self.config.get_workspace()),
|
||||
'dirty': self._dirty,
|
||||
'embedding_enabled': self.embedding_provider is not None,
|
||||
'embedding_provider': self.config.embedding_provider if self.embedding_provider else 'disabled',
|
||||
'embedding_model': self.config.embedding_model if self.embedding_provider else 'N/A',
|
||||
'search_mode': 'hybrid (vector + keyword)' if self.embedding_provider else 'keyword only (FTS5)'
|
||||
}
|
||||
|
||||
def mark_dirty(self):
|
||||
"""Mark memory as dirty (needs sync)"""
|
||||
self._dirty = True
|
||||
|
||||
def close(self):
|
||||
"""Close memory manager and release resources"""
|
||||
self.storage.close()
|
||||
|
||||
# Helper methods
|
||||
|
||||
def _generate_chunk_id(self, path: str, start_line: int, end_line: int) -> str:
|
||||
"""Generate unique chunk ID"""
|
||||
content = f"{path}:{start_line}:{end_line}"
|
||||
return hashlib.md5(content.encode('utf-8')).hexdigest()
|
||||
|
||||
@staticmethod
|
||||
def _compute_temporal_decay(path: str, half_life_days: float = 30.0) -> float:
|
||||
"""
|
||||
Compute temporal decay multiplier for dated memory files.
|
||||
|
||||
Inspired by OpenClaw's temporal-decay: exponential decay based on file date.
|
||||
MEMORY.md and non-dated files are "evergreen" (no decay, multiplier=1.0).
|
||||
Daily files like memory/2025-03-01.md decay based on age.
|
||||
|
||||
Formula: multiplier = exp(-ln2/half_life * age_in_days)
|
||||
"""
|
||||
import re
|
||||
import math
|
||||
|
||||
match = re.search(r'(\d{4})-(\d{2})-(\d{2})\.md$', path)
|
||||
if not match:
|
||||
return 1.0 # evergreen: MEMORY.md, non-dated files
|
||||
|
||||
try:
|
||||
file_date = datetime(
|
||||
int(match.group(1)), int(match.group(2)), int(match.group(3))
|
||||
)
|
||||
age_days = (datetime.now() - file_date).days
|
||||
if age_days <= 0:
|
||||
return 1.0
|
||||
|
||||
decay_lambda = math.log(2) / half_life_days
|
||||
return math.exp(-decay_lambda * age_days)
|
||||
except (ValueError, OverflowError):
|
||||
return 1.0
|
||||
|
||||
def _merge_results(
|
||||
self,
|
||||
vector_results: List[SearchResult],
|
||||
keyword_results: List[SearchResult],
|
||||
vector_weight: float,
|
||||
keyword_weight: float
|
||||
) -> List[SearchResult]:
|
||||
"""Merge vector and keyword search results with temporal decay for dated files"""
|
||||
merged_map = {}
|
||||
|
||||
for result in vector_results:
|
||||
key = (result.path, result.start_line, result.end_line)
|
||||
merged_map[key] = {
|
||||
'result': result,
|
||||
'vector_score': result.score,
|
||||
'keyword_score': 0.0
|
||||
}
|
||||
|
||||
for result in keyword_results:
|
||||
key = (result.path, result.start_line, result.end_line)
|
||||
if key in merged_map:
|
||||
merged_map[key]['keyword_score'] = result.score
|
||||
else:
|
||||
merged_map[key] = {
|
||||
'result': result,
|
||||
'vector_score': 0.0,
|
||||
'keyword_score': result.score
|
||||
}
|
||||
|
||||
merged_results = []
|
||||
for entry in merged_map.values():
|
||||
combined_score = (
|
||||
vector_weight * entry['vector_score'] +
|
||||
keyword_weight * entry['keyword_score']
|
||||
)
|
||||
|
||||
# Apply temporal decay for dated memory files
|
||||
result = entry['result']
|
||||
decay = self._compute_temporal_decay(result.path)
|
||||
combined_score *= decay
|
||||
|
||||
merged_results.append(SearchResult(
|
||||
path=result.path,
|
||||
start_line=result.start_line,
|
||||
end_line=result.end_line,
|
||||
score=combined_score,
|
||||
snippet=result.snippet,
|
||||
source=result.source,
|
||||
user_id=result.user_id
|
||||
))
|
||||
|
||||
merged_results.sort(key=lambda r: r.score, reverse=True)
|
||||
return merged_results
|
||||
167
agent/memory/service.py
Normal file
167
agent/memory/service.py
Normal file
@@ -0,0 +1,167 @@
|
||||
"""
|
||||
Memory service for handling memory query operations via cloud protocol.
|
||||
|
||||
Provides a unified interface for listing and reading memory files,
|
||||
callable from the cloud client (LinkAI) or a future web console.
|
||||
|
||||
Memory file layout (under workspace_root):
|
||||
MEMORY.md -> type: global
|
||||
memory/2026-02-20.md -> type: daily
|
||||
"""
|
||||
|
||||
import os
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional
|
||||
from pathlib import Path
|
||||
from common.log import logger
|
||||
|
||||
|
||||
class MemoryService:
|
||||
"""
|
||||
High-level service for memory file queries.
|
||||
Operates directly on the filesystem — no MemoryManager dependency.
|
||||
"""
|
||||
|
||||
def __init__(self, workspace_root: str):
|
||||
"""
|
||||
:param workspace_root: Workspace root directory (e.g. ~/cow)
|
||||
"""
|
||||
self.workspace_root = workspace_root
|
||||
self.memory_dir = os.path.join(workspace_root, "memory")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# list — paginated file metadata
|
||||
# ------------------------------------------------------------------
|
||||
def list_files(self, page: int = 1, page_size: int = 20) -> dict:
|
||||
"""
|
||||
List all memory files with metadata (without content).
|
||||
|
||||
Returns::
|
||||
|
||||
{
|
||||
"page": 1,
|
||||
"page_size": 20,
|
||||
"total": 15,
|
||||
"list": [
|
||||
{"filename": "MEMORY.md", "type": "global", "size": 2048, "updated_at": "2026-02-20 10:00:00"},
|
||||
{"filename": "2026-02-20.md", "type": "daily", "size": 512, "updated_at": "2026-02-20 09:30:00"},
|
||||
...
|
||||
]
|
||||
}
|
||||
"""
|
||||
files: List[dict] = []
|
||||
|
||||
# 1. Global memory — MEMORY.md in workspace root
|
||||
global_path = os.path.join(self.workspace_root, "MEMORY.md")
|
||||
if os.path.isfile(global_path):
|
||||
files.append(self._file_info(global_path, "MEMORY.md", "global"))
|
||||
|
||||
# 2. Daily memory files — memory/*.md (sorted newest first)
|
||||
if os.path.isdir(self.memory_dir):
|
||||
daily_files = []
|
||||
for name in os.listdir(self.memory_dir):
|
||||
full = os.path.join(self.memory_dir, name)
|
||||
if os.path.isfile(full) and name.endswith(".md"):
|
||||
daily_files.append((name, full))
|
||||
# Sort by filename descending (newest date first)
|
||||
daily_files.sort(key=lambda x: x[0], reverse=True)
|
||||
for name, full in daily_files:
|
||||
files.append(self._file_info(full, name, "daily"))
|
||||
|
||||
total = len(files)
|
||||
|
||||
# Paginate
|
||||
start = (page - 1) * page_size
|
||||
end = start + page_size
|
||||
page_items = files[start:end]
|
||||
|
||||
return {
|
||||
"page": page,
|
||||
"page_size": page_size,
|
||||
"total": total,
|
||||
"list": page_items,
|
||||
}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# content — read a single file
|
||||
# ------------------------------------------------------------------
|
||||
def get_content(self, filename: str) -> dict:
|
||||
"""
|
||||
Read the full content of a memory file.
|
||||
|
||||
:param filename: File name, e.g. ``MEMORY.md`` or ``2026-02-20.md``
|
||||
:return: dict with ``filename`` and ``content``
|
||||
:raises FileNotFoundError: if the file does not exist
|
||||
"""
|
||||
path = self._resolve_path(filename)
|
||||
if not os.path.isfile(path):
|
||||
raise FileNotFoundError(f"Memory file not found: {filename}")
|
||||
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
content = f.read()
|
||||
|
||||
return {
|
||||
"filename": filename,
|
||||
"content": content,
|
||||
}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# dispatch — single entry point for protocol messages
|
||||
# ------------------------------------------------------------------
|
||||
def dispatch(self, action: str, payload: Optional[dict] = None) -> dict:
|
||||
"""
|
||||
Dispatch a memory management action.
|
||||
|
||||
:param action: ``list`` or ``content``
|
||||
:param payload: action-specific payload
|
||||
:return: protocol-compatible response dict
|
||||
"""
|
||||
payload = payload or {}
|
||||
try:
|
||||
if action == "list":
|
||||
page = payload.get("page", 1)
|
||||
page_size = payload.get("page_size", 20)
|
||||
result_payload = self.list_files(page=page, page_size=page_size)
|
||||
return {"action": action, "code": 200, "message": "success", "payload": result_payload}
|
||||
|
||||
elif action == "content":
|
||||
filename = payload.get("filename")
|
||||
if not filename:
|
||||
return {"action": action, "code": 400, "message": "filename is required", "payload": None}
|
||||
result_payload = self.get_content(filename)
|
||||
return {"action": action, "code": 200, "message": "success", "payload": result_payload}
|
||||
|
||||
else:
|
||||
return {"action": action, "code": 400, "message": f"unknown action: {action}", "payload": None}
|
||||
|
||||
except FileNotFoundError as e:
|
||||
return {"action": action, "code": 404, "message": str(e), "payload": None}
|
||||
except Exception as e:
|
||||
logger.error(f"[MemoryService] dispatch error: action={action}, error={e}")
|
||||
return {"action": action, "code": 500, "message": str(e), "payload": None}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# internal helpers
|
||||
# ------------------------------------------------------------------
|
||||
def _resolve_path(self, filename: str) -> str:
|
||||
"""
|
||||
Resolve a filename to its absolute path.
|
||||
|
||||
- ``MEMORY.md`` → ``{workspace_root}/MEMORY.md``
|
||||
- ``2026-02-20.md`` → ``{workspace_root}/memory/2026-02-20.md``
|
||||
"""
|
||||
if filename == "MEMORY.md":
|
||||
return os.path.join(self.workspace_root, filename)
|
||||
return os.path.join(self.memory_dir, filename)
|
||||
|
||||
@staticmethod
|
||||
def _file_info(path: str, filename: str, file_type: str) -> dict:
|
||||
"""Build a file metadata dict."""
|
||||
stat = os.stat(path)
|
||||
updated_at = datetime.fromtimestamp(stat.st_mtime).strftime("%Y-%m-%d %H:%M:%S")
|
||||
return {
|
||||
"filename": filename,
|
||||
"type": file_type,
|
||||
"size": stat.st_size,
|
||||
"updated_at": updated_at,
|
||||
}
|
||||
589
agent/memory/storage.py
Normal file
589
agent/memory/storage.py
Normal file
@@ -0,0 +1,589 @@
|
||||
"""
|
||||
Storage layer for memory using SQLite + FTS5
|
||||
|
||||
Provides vector and keyword search capabilities
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import sqlite3
|
||||
import json
|
||||
import hashlib
|
||||
from typing import List, Dict, Optional, Any
|
||||
from pathlib import Path
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class MemoryChunk:
|
||||
"""Represents a memory chunk with text and embedding"""
|
||||
id: str
|
||||
user_id: Optional[str]
|
||||
scope: str # "shared" | "user" | "session"
|
||||
source: str # "memory" | "session"
|
||||
path: str
|
||||
start_line: int
|
||||
end_line: int
|
||||
text: str
|
||||
embedding: Optional[List[float]]
|
||||
hash: str
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class SearchResult:
|
||||
"""Search result with score and snippet"""
|
||||
path: str
|
||||
start_line: int
|
||||
end_line: int
|
||||
score: float
|
||||
snippet: str
|
||||
source: str
|
||||
user_id: Optional[str] = None
|
||||
|
||||
|
||||
class MemoryStorage:
|
||||
"""SQLite-based storage with FTS5 for keyword search"""
|
||||
|
||||
def __init__(self, db_path: Path):
|
||||
self.db_path = db_path
|
||||
self.conn: Optional[sqlite3.Connection] = None
|
||||
self.fts5_available = False # Track FTS5 availability
|
||||
self._init_db()
|
||||
|
||||
def _check_fts5_support(self) -> bool:
|
||||
"""Check if SQLite has FTS5 support"""
|
||||
try:
|
||||
self.conn.execute("CREATE VIRTUAL TABLE IF NOT EXISTS fts5_test USING fts5(test)")
|
||||
self.conn.execute("DROP TABLE IF EXISTS fts5_test")
|
||||
return True
|
||||
except sqlite3.OperationalError as e:
|
||||
if "no such module: fts5" in str(e):
|
||||
return False
|
||||
raise
|
||||
|
||||
def _init_db(self):
|
||||
"""Initialize database with schema"""
|
||||
try:
|
||||
self.conn = sqlite3.connect(str(self.db_path), check_same_thread=False)
|
||||
self.conn.row_factory = sqlite3.Row
|
||||
|
||||
# Check FTS5 support
|
||||
self.fts5_available = self._check_fts5_support()
|
||||
if not self.fts5_available:
|
||||
from common.log import logger
|
||||
logger.debug("[MemoryStorage] FTS5 not available, using LIKE-based keyword search")
|
||||
|
||||
# Check database integrity
|
||||
try:
|
||||
result = self.conn.execute("PRAGMA integrity_check").fetchone()
|
||||
if result[0] != 'ok':
|
||||
print(f"⚠️ Database integrity check failed: {result[0]}")
|
||||
print(f" Recreating database...")
|
||||
self.conn.close()
|
||||
self.conn = None
|
||||
# Remove corrupted database
|
||||
self.db_path.unlink(missing_ok=True)
|
||||
# Remove WAL files
|
||||
Path(str(self.db_path) + '-wal').unlink(missing_ok=True)
|
||||
Path(str(self.db_path) + '-shm').unlink(missing_ok=True)
|
||||
# Reconnect to create new database
|
||||
self.conn = sqlite3.connect(str(self.db_path), check_same_thread=False)
|
||||
self.conn.row_factory = sqlite3.Row
|
||||
except sqlite3.DatabaseError:
|
||||
# Database is corrupted, recreate it
|
||||
print(f"⚠️ Database is corrupted, recreating...")
|
||||
if self.conn:
|
||||
self.conn.close()
|
||||
self.conn = None
|
||||
self.db_path.unlink(missing_ok=True)
|
||||
Path(str(self.db_path) + '-wal').unlink(missing_ok=True)
|
||||
Path(str(self.db_path) + '-shm').unlink(missing_ok=True)
|
||||
self.conn = sqlite3.connect(str(self.db_path), check_same_thread=False)
|
||||
self.conn.row_factory = sqlite3.Row
|
||||
|
||||
# Enable WAL mode for better concurrency
|
||||
self.conn.execute("PRAGMA journal_mode=WAL")
|
||||
# Set busy timeout to avoid "database is locked" errors
|
||||
self.conn.execute("PRAGMA busy_timeout=5000")
|
||||
except Exception as e:
|
||||
print(f"⚠️ Unexpected error during database initialization: {e}")
|
||||
raise
|
||||
|
||||
# Create chunks table with embeddings
|
||||
self.conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS chunks (
|
||||
id TEXT PRIMARY KEY,
|
||||
user_id TEXT,
|
||||
scope TEXT NOT NULL DEFAULT 'shared',
|
||||
source TEXT NOT NULL DEFAULT 'memory',
|
||||
path TEXT NOT NULL,
|
||||
start_line INTEGER NOT NULL,
|
||||
end_line INTEGER NOT NULL,
|
||||
text TEXT NOT NULL,
|
||||
embedding TEXT,
|
||||
hash TEXT NOT NULL,
|
||||
metadata TEXT,
|
||||
created_at INTEGER DEFAULT (strftime('%s', 'now')),
|
||||
updated_at INTEGER DEFAULT (strftime('%s', 'now'))
|
||||
)
|
||||
""")
|
||||
|
||||
# Create indexes
|
||||
self.conn.execute("""
|
||||
CREATE INDEX IF NOT EXISTS idx_chunks_user
|
||||
ON chunks(user_id)
|
||||
""")
|
||||
|
||||
self.conn.execute("""
|
||||
CREATE INDEX IF NOT EXISTS idx_chunks_scope
|
||||
ON chunks(scope)
|
||||
""")
|
||||
|
||||
self.conn.execute("""
|
||||
CREATE INDEX IF NOT EXISTS idx_chunks_hash
|
||||
ON chunks(path, hash)
|
||||
""")
|
||||
|
||||
# Create FTS5 virtual table for keyword search (only if supported)
|
||||
if self.fts5_available:
|
||||
# Use default unicode61 tokenizer (stable and compatible)
|
||||
# For CJK support, we'll use LIKE queries as fallback
|
||||
self.conn.execute("""
|
||||
CREATE VIRTUAL TABLE IF NOT EXISTS chunks_fts USING fts5(
|
||||
text,
|
||||
id UNINDEXED,
|
||||
user_id UNINDEXED,
|
||||
path UNINDEXED,
|
||||
source UNINDEXED,
|
||||
scope UNINDEXED,
|
||||
content='chunks',
|
||||
content_rowid='rowid'
|
||||
)
|
||||
""")
|
||||
|
||||
# Create triggers to keep FTS in sync
|
||||
self.conn.execute("""
|
||||
CREATE TRIGGER IF NOT EXISTS chunks_ai AFTER INSERT ON chunks BEGIN
|
||||
INSERT INTO chunks_fts(rowid, text, id, user_id, path, source, scope)
|
||||
VALUES (new.rowid, new.text, new.id, new.user_id, new.path, new.source, new.scope);
|
||||
END
|
||||
""")
|
||||
|
||||
self.conn.execute("""
|
||||
CREATE TRIGGER IF NOT EXISTS chunks_ad AFTER DELETE ON chunks BEGIN
|
||||
DELETE FROM chunks_fts WHERE rowid = old.rowid;
|
||||
END
|
||||
""")
|
||||
|
||||
self.conn.execute("""
|
||||
CREATE TRIGGER IF NOT EXISTS chunks_au AFTER UPDATE ON chunks BEGIN
|
||||
UPDATE chunks_fts SET text = new.text, id = new.id,
|
||||
user_id = new.user_id, path = new.path, source = new.source, scope = new.scope
|
||||
WHERE rowid = new.rowid;
|
||||
END
|
||||
""")
|
||||
|
||||
# Create files metadata table
|
||||
self.conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS files (
|
||||
path TEXT PRIMARY KEY,
|
||||
source TEXT NOT NULL DEFAULT 'memory',
|
||||
hash TEXT NOT NULL,
|
||||
mtime INTEGER NOT NULL,
|
||||
size INTEGER NOT NULL,
|
||||
updated_at INTEGER DEFAULT (strftime('%s', 'now'))
|
||||
)
|
||||
""")
|
||||
|
||||
self.conn.commit()
|
||||
|
||||
def save_chunk(self, chunk: MemoryChunk):
|
||||
"""Save a memory chunk"""
|
||||
self.conn.execute("""
|
||||
INSERT OR REPLACE INTO chunks
|
||||
(id, user_id, scope, source, path, start_line, end_line, text, embedding, hash, metadata, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, strftime('%s', 'now'))
|
||||
""", (
|
||||
chunk.id,
|
||||
chunk.user_id,
|
||||
chunk.scope,
|
||||
chunk.source,
|
||||
chunk.path,
|
||||
chunk.start_line,
|
||||
chunk.end_line,
|
||||
chunk.text,
|
||||
json.dumps(chunk.embedding) if chunk.embedding else None,
|
||||
chunk.hash,
|
||||
json.dumps(chunk.metadata) if chunk.metadata else None
|
||||
))
|
||||
self.conn.commit()
|
||||
|
||||
def save_chunks_batch(self, chunks: List[MemoryChunk]):
|
||||
"""Save multiple chunks in a batch"""
|
||||
self.conn.executemany("""
|
||||
INSERT OR REPLACE INTO chunks
|
||||
(id, user_id, scope, source, path, start_line, end_line, text, embedding, hash, metadata, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, strftime('%s', 'now'))
|
||||
""", [
|
||||
(
|
||||
c.id, c.user_id, c.scope, c.source, c.path,
|
||||
c.start_line, c.end_line, c.text,
|
||||
json.dumps(c.embedding) if c.embedding else None,
|
||||
c.hash,
|
||||
json.dumps(c.metadata) if c.metadata else None
|
||||
)
|
||||
for c in chunks
|
||||
])
|
||||
self.conn.commit()
|
||||
|
||||
def get_chunk(self, chunk_id: str) -> Optional[MemoryChunk]:
|
||||
"""Get a chunk by ID"""
|
||||
row = self.conn.execute("""
|
||||
SELECT * FROM chunks WHERE id = ?
|
||||
""", (chunk_id,)).fetchone()
|
||||
|
||||
if not row:
|
||||
return None
|
||||
|
||||
return self._row_to_chunk(row)
|
||||
|
||||
def search_vector(
|
||||
self,
|
||||
query_embedding: List[float],
|
||||
user_id: Optional[str] = None,
|
||||
scopes: List[str] = None,
|
||||
limit: int = 10
|
||||
) -> List[SearchResult]:
|
||||
"""
|
||||
Vector similarity search using in-memory cosine similarity
|
||||
(sqlite-vec can be added later for better performance)
|
||||
"""
|
||||
if scopes is None:
|
||||
scopes = ["shared"]
|
||||
if user_id:
|
||||
scopes.append("user")
|
||||
|
||||
# Build query
|
||||
scope_placeholders = ','.join('?' * len(scopes))
|
||||
params = scopes
|
||||
|
||||
if user_id:
|
||||
query = f"""
|
||||
SELECT * FROM chunks
|
||||
WHERE scope IN ({scope_placeholders})
|
||||
AND (scope = 'shared' OR user_id = ?)
|
||||
AND embedding IS NOT NULL
|
||||
"""
|
||||
params.append(user_id)
|
||||
else:
|
||||
query = f"""
|
||||
SELECT * FROM chunks
|
||||
WHERE scope IN ({scope_placeholders})
|
||||
AND embedding IS NOT NULL
|
||||
"""
|
||||
|
||||
rows = self.conn.execute(query, params).fetchall()
|
||||
|
||||
# Calculate cosine similarity
|
||||
results = []
|
||||
for row in rows:
|
||||
embedding = json.loads(row['embedding'])
|
||||
similarity = self._cosine_similarity(query_embedding, embedding)
|
||||
|
||||
if similarity > 0:
|
||||
results.append((similarity, row))
|
||||
|
||||
# Sort by similarity and limit
|
||||
results.sort(key=lambda x: x[0], reverse=True)
|
||||
results = results[:limit]
|
||||
|
||||
return [
|
||||
SearchResult(
|
||||
path=row['path'],
|
||||
start_line=row['start_line'],
|
||||
end_line=row['end_line'],
|
||||
score=score,
|
||||
snippet=self._truncate_text(row['text'], 500),
|
||||
source=row['source'],
|
||||
user_id=row['user_id']
|
||||
)
|
||||
for score, row in results
|
||||
]
|
||||
|
||||
def search_keyword(
|
||||
self,
|
||||
query: str,
|
||||
user_id: Optional[str] = None,
|
||||
scopes: List[str] = None,
|
||||
limit: int = 10
|
||||
) -> List[SearchResult]:
|
||||
"""
|
||||
Keyword search using FTS5 + LIKE fallback
|
||||
|
||||
Strategy:
|
||||
1. If FTS5 available: Try FTS5 search first (good for English and word-based languages)
|
||||
2. If no FTS5 or no results and query contains CJK: Use LIKE search
|
||||
"""
|
||||
if scopes is None:
|
||||
scopes = ["shared"]
|
||||
if user_id:
|
||||
scopes.append("user")
|
||||
|
||||
# Try FTS5 search first (if available)
|
||||
if self.fts5_available:
|
||||
fts_results = self._search_fts5(query, user_id, scopes, limit)
|
||||
if fts_results:
|
||||
return fts_results
|
||||
|
||||
# Fallback to LIKE search (always for CJK, or if FTS5 not available)
|
||||
if not self.fts5_available or MemoryStorage._contains_cjk(query):
|
||||
return self._search_like(query, user_id, scopes, limit)
|
||||
|
||||
return []
|
||||
|
||||
def _search_fts5(
|
||||
self,
|
||||
query: str,
|
||||
user_id: Optional[str],
|
||||
scopes: List[str],
|
||||
limit: int
|
||||
) -> List[SearchResult]:
|
||||
"""FTS5 full-text search"""
|
||||
fts_query = self._build_fts_query(query)
|
||||
if not fts_query:
|
||||
return []
|
||||
|
||||
scope_placeholders = ','.join('?' * len(scopes))
|
||||
params = [fts_query] + scopes
|
||||
|
||||
if user_id:
|
||||
sql_query = f"""
|
||||
SELECT chunks.*, bm25(chunks_fts) as rank
|
||||
FROM chunks_fts
|
||||
JOIN chunks ON chunks.id = chunks_fts.id
|
||||
WHERE chunks_fts MATCH ?
|
||||
AND chunks.scope IN ({scope_placeholders})
|
||||
AND (chunks.scope = 'shared' OR chunks.user_id = ?)
|
||||
ORDER BY rank
|
||||
LIMIT ?
|
||||
"""
|
||||
params.extend([user_id, limit])
|
||||
else:
|
||||
sql_query = f"""
|
||||
SELECT chunks.*, bm25(chunks_fts) as rank
|
||||
FROM chunks_fts
|
||||
JOIN chunks ON chunks.id = chunks_fts.id
|
||||
WHERE chunks_fts MATCH ?
|
||||
AND chunks.scope IN ({scope_placeholders})
|
||||
ORDER BY rank
|
||||
LIMIT ?
|
||||
"""
|
||||
params.append(limit)
|
||||
|
||||
try:
|
||||
rows = self.conn.execute(sql_query, params).fetchall()
|
||||
return [
|
||||
SearchResult(
|
||||
path=row['path'],
|
||||
start_line=row['start_line'],
|
||||
end_line=row['end_line'],
|
||||
score=self._bm25_rank_to_score(row['rank']),
|
||||
snippet=self._truncate_text(row['text'], 500),
|
||||
source=row['source'],
|
||||
user_id=row['user_id']
|
||||
)
|
||||
for row in rows
|
||||
]
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
def _search_like(
|
||||
self,
|
||||
query: str,
|
||||
user_id: Optional[str],
|
||||
scopes: List[str],
|
||||
limit: int
|
||||
) -> List[SearchResult]:
|
||||
"""LIKE-based search for CJK characters"""
|
||||
import re
|
||||
# Extract CJK words (2+ characters)
|
||||
cjk_words = re.findall(r'[\u4e00-\u9fff]{2,}', query)
|
||||
if not cjk_words:
|
||||
return []
|
||||
|
||||
scope_placeholders = ','.join('?' * len(scopes))
|
||||
|
||||
# Build LIKE conditions for each word
|
||||
like_conditions = []
|
||||
params = []
|
||||
for word in cjk_words:
|
||||
like_conditions.append("text LIKE ?")
|
||||
params.append(f'%{word}%')
|
||||
|
||||
where_clause = ' OR '.join(like_conditions)
|
||||
params.extend(scopes)
|
||||
|
||||
if user_id:
|
||||
sql_query = f"""
|
||||
SELECT * FROM chunks
|
||||
WHERE ({where_clause})
|
||||
AND scope IN ({scope_placeholders})
|
||||
AND (scope = 'shared' OR user_id = ?)
|
||||
LIMIT ?
|
||||
"""
|
||||
params.extend([user_id, limit])
|
||||
else:
|
||||
sql_query = f"""
|
||||
SELECT * FROM chunks
|
||||
WHERE ({where_clause})
|
||||
AND scope IN ({scope_placeholders})
|
||||
LIMIT ?
|
||||
"""
|
||||
params.append(limit)
|
||||
|
||||
try:
|
||||
rows = self.conn.execute(sql_query, params).fetchall()
|
||||
return [
|
||||
SearchResult(
|
||||
path=row['path'],
|
||||
start_line=row['start_line'],
|
||||
end_line=row['end_line'],
|
||||
score=0.5, # Fixed score for LIKE search
|
||||
snippet=self._truncate_text(row['text'], 500),
|
||||
source=row['source'],
|
||||
user_id=row['user_id']
|
||||
)
|
||||
for row in rows
|
||||
]
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
def delete_by_path(self, path: str):
|
||||
"""Delete all chunks from a file"""
|
||||
self.conn.execute("""
|
||||
DELETE FROM chunks WHERE path = ?
|
||||
""", (path,))
|
||||
self.conn.commit()
|
||||
|
||||
def get_file_hash(self, path: str) -> Optional[str]:
|
||||
"""Get stored file hash"""
|
||||
row = self.conn.execute("""
|
||||
SELECT hash FROM files WHERE path = ?
|
||||
""", (path,)).fetchone()
|
||||
return row['hash'] if row else None
|
||||
|
||||
def update_file_metadata(self, path: str, source: str, file_hash: str, mtime: int, size: int):
|
||||
"""Update file metadata"""
|
||||
self.conn.execute("""
|
||||
INSERT OR REPLACE INTO files (path, source, hash, mtime, size, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?, strftime('%s', 'now'))
|
||||
""", (path, source, file_hash, mtime, size))
|
||||
self.conn.commit()
|
||||
|
||||
def get_stats(self) -> Dict[str, int]:
|
||||
"""Get storage statistics"""
|
||||
chunks_count = self.conn.execute("""
|
||||
SELECT COUNT(*) as cnt FROM chunks
|
||||
""").fetchone()['cnt']
|
||||
|
||||
files_count = self.conn.execute("""
|
||||
SELECT COUNT(*) as cnt FROM files
|
||||
""").fetchone()['cnt']
|
||||
|
||||
return {
|
||||
'chunks': chunks_count,
|
||||
'files': files_count
|
||||
}
|
||||
|
||||
def close(self):
|
||||
"""Close database connection"""
|
||||
if self.conn:
|
||||
try:
|
||||
self.conn.commit() # Ensure all changes are committed
|
||||
self.conn.close()
|
||||
self.conn = None # Mark as closed
|
||||
except Exception as e:
|
||||
print(f"⚠️ Error closing database connection: {e}")
|
||||
|
||||
def __del__(self):
|
||||
"""Destructor to ensure connection is closed"""
|
||||
try:
|
||||
self.close()
|
||||
except Exception:
|
||||
pass # Ignore errors during cleanup
|
||||
|
||||
# Helper methods
|
||||
|
||||
def _row_to_chunk(self, row) -> MemoryChunk:
|
||||
"""Convert database row to MemoryChunk"""
|
||||
return MemoryChunk(
|
||||
id=row['id'],
|
||||
user_id=row['user_id'],
|
||||
scope=row['scope'],
|
||||
source=row['source'],
|
||||
path=row['path'],
|
||||
start_line=row['start_line'],
|
||||
end_line=row['end_line'],
|
||||
text=row['text'],
|
||||
embedding=json.loads(row['embedding']) if row['embedding'] else None,
|
||||
hash=row['hash'],
|
||||
metadata=json.loads(row['metadata']) if row['metadata'] else None
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _cosine_similarity(vec1: List[float], vec2: List[float]) -> float:
|
||||
"""Calculate cosine similarity between two vectors"""
|
||||
if len(vec1) != len(vec2):
|
||||
return 0.0
|
||||
|
||||
dot_product = sum(a * b for a, b in zip(vec1, vec2))
|
||||
norm1 = sum(a * a for a in vec1) ** 0.5
|
||||
norm2 = sum(b * b for b in vec2) ** 0.5
|
||||
|
||||
if norm1 == 0 or norm2 == 0:
|
||||
return 0.0
|
||||
|
||||
return dot_product / (norm1 * norm2)
|
||||
|
||||
@staticmethod
|
||||
def _contains_cjk(text: str) -> bool:
|
||||
"""Check if text contains CJK (Chinese/Japanese/Korean) characters"""
|
||||
import re
|
||||
return bool(re.search(r'[\u4e00-\u9fff]', text))
|
||||
|
||||
@staticmethod
|
||||
def _build_fts_query(raw_query: str) -> Optional[str]:
|
||||
"""
|
||||
Build FTS5 query from raw text
|
||||
|
||||
Works best for English and word-based languages.
|
||||
For CJK characters, LIKE search will be used as fallback.
|
||||
"""
|
||||
import re
|
||||
# Extract words (primarily English words and numbers)
|
||||
tokens = re.findall(r'[A-Za-z0-9_]+', raw_query)
|
||||
if not tokens:
|
||||
return None
|
||||
|
||||
# Quote tokens for exact matching
|
||||
quoted = [f'"{t}"' for t in tokens]
|
||||
# Use OR for more flexible matching
|
||||
return ' OR '.join(quoted)
|
||||
|
||||
@staticmethod
|
||||
def _bm25_rank_to_score(rank: float) -> float:
|
||||
"""Convert BM25 rank to 0-1 score"""
|
||||
normalized = max(0, rank) if rank is not None else 999
|
||||
return 1 / (1 + normalized)
|
||||
|
||||
@staticmethod
|
||||
def _truncate_text(text: str, max_chars: int) -> str:
|
||||
"""Truncate text to max characters"""
|
||||
if len(text) <= max_chars:
|
||||
return text
|
||||
return text[:max_chars] + "..."
|
||||
|
||||
@staticmethod
|
||||
def compute_hash(content: str) -> str:
|
||||
"""Compute SHA256 hash of content"""
|
||||
return hashlib.sha256(content.encode('utf-8')).hexdigest()
|
||||
370
agent/memory/summarizer.py
Normal file
370
agent/memory/summarizer.py
Normal file
@@ -0,0 +1,370 @@
|
||||
"""
|
||||
Memory flush manager
|
||||
|
||||
Handles memory persistence when conversation context is trimmed or overflows:
|
||||
- Uses LLM to summarize discarded messages into concise key-information entries
|
||||
- Writes to daily memory files (lazy creation)
|
||||
- Deduplicates trim flushes to avoid repeated writes
|
||||
- Runs summarization asynchronously to avoid blocking normal replies
|
||||
- Provides daily summary interface for scheduler
|
||||
"""
|
||||
|
||||
import threading
|
||||
from typing import Optional, Callable, Any, List, Dict
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
from common.log import logger
|
||||
|
||||
|
||||
SUMMARIZE_SYSTEM_PROMPT = """你是一个记忆提取助手。你的任务是从对话记录中提取值得记住的信息,生成简洁的记忆摘要。
|
||||
|
||||
输出要求:
|
||||
1. 以事件/关键信息为维度记录,每条一行,用 "- " 开头
|
||||
2. 记录有价值的关键信息,例如用户提出的要求及助手的解决方案,对话中涉及的事实信息,用户的偏好、决策或重要结论
|
||||
3. 每条摘要需要简明扼要,只保留关键信息
|
||||
4. 直接输出摘要内容,不要加任何前缀说明
|
||||
5. 当对话没有任何记录价值例如只是简单问候,可回复"无\""""
|
||||
|
||||
SUMMARIZE_USER_PROMPT = """请从以下对话记录中提取关键信息,生成记忆摘要:
|
||||
|
||||
{conversation}"""
|
||||
|
||||
|
||||
class MemoryFlushManager:
|
||||
"""
|
||||
Manages memory flush operations.
|
||||
|
||||
Flush is triggered by agent_stream in two scenarios:
|
||||
1. Context trim: _trim_messages discards old turns → flush discarded content
|
||||
2. Context overflow: API rejects request → emergency flush before clearing
|
||||
|
||||
Additionally, create_daily_summary() can be called by scheduler for end-of-day summaries.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
workspace_dir: Path,
|
||||
llm_model: Optional[Any] = None,
|
||||
):
|
||||
self.workspace_dir = workspace_dir
|
||||
self.llm_model = llm_model
|
||||
|
||||
self.memory_dir = workspace_dir / "memory"
|
||||
self.memory_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self.last_flush_timestamp: Optional[datetime] = None
|
||||
self._trim_flushed_hashes: set = set() # Content hashes of already-flushed messages
|
||||
self._last_flushed_content_hash: str = "" # Content hash at last flush, for daily dedup
|
||||
|
||||
def get_today_memory_file(self, user_id: Optional[str] = None, ensure_exists: bool = False) -> Path:
|
||||
"""Get today's memory file path: memory/YYYY-MM-DD.md"""
|
||||
today = datetime.now().strftime("%Y-%m-%d")
|
||||
|
||||
if user_id:
|
||||
user_dir = self.memory_dir / "users" / user_id
|
||||
if ensure_exists:
|
||||
user_dir.mkdir(parents=True, exist_ok=True)
|
||||
today_file = user_dir / f"{today}.md"
|
||||
else:
|
||||
today_file = self.memory_dir / f"{today}.md"
|
||||
|
||||
if ensure_exists and not today_file.exists():
|
||||
today_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
today_file.write_text(f"# Daily Memory: {today}\n\n")
|
||||
|
||||
return today_file
|
||||
|
||||
def get_main_memory_file(self, user_id: Optional[str] = None) -> Path:
|
||||
"""Get main memory file path: MEMORY.md (workspace root)"""
|
||||
if user_id:
|
||||
user_dir = self.memory_dir / "users" / user_id
|
||||
user_dir.mkdir(parents=True, exist_ok=True)
|
||||
return user_dir / "MEMORY.md"
|
||||
else:
|
||||
return Path(self.workspace_dir) / "MEMORY.md"
|
||||
|
||||
def get_status(self) -> dict:
|
||||
return {
|
||||
'last_flush_time': self.last_flush_timestamp.isoformat() if self.last_flush_timestamp else None,
|
||||
'today_file': str(self.get_today_memory_file()),
|
||||
'main_file': str(self.get_main_memory_file())
|
||||
}
|
||||
|
||||
# ---- Flush execution (called by agent_stream or scheduler) ----
|
||||
|
||||
def flush_from_messages(
|
||||
self,
|
||||
messages: List[Dict],
|
||||
user_id: Optional[str] = None,
|
||||
reason: str = "trim",
|
||||
max_messages: int = 0,
|
||||
) -> bool:
|
||||
"""
|
||||
Asynchronously summarize and flush messages to daily memory.
|
||||
|
||||
Deduplication runs synchronously, then LLM summarization + file write
|
||||
run in a background thread so the main reply flow is never blocked.
|
||||
|
||||
Args:
|
||||
messages: Conversation message list (OpenAI/Claude format)
|
||||
user_id: Optional user ID for user-scoped memory
|
||||
reason: Why flush was triggered ("trim" | "overflow" | "daily_summary")
|
||||
max_messages: Max recent messages to summarize (0 = all)
|
||||
|
||||
Returns:
|
||||
True if flush was dispatched
|
||||
"""
|
||||
try:
|
||||
import hashlib
|
||||
deduped = []
|
||||
for m in messages:
|
||||
text = self._extract_text_from_content(m.get("content", ""))
|
||||
if not text or not text.strip():
|
||||
continue
|
||||
h = hashlib.md5(text.encode("utf-8")).hexdigest()
|
||||
if h not in self._trim_flushed_hashes:
|
||||
self._trim_flushed_hashes.add(h)
|
||||
deduped.append(m)
|
||||
if not deduped:
|
||||
return False
|
||||
|
||||
import copy
|
||||
snapshot = copy.deepcopy(deduped)
|
||||
thread = threading.Thread(
|
||||
target=self._flush_worker,
|
||||
args=(snapshot, user_id, reason, max_messages),
|
||||
daemon=True,
|
||||
)
|
||||
thread.start()
|
||||
logger.info(f"[MemoryFlush] Async flush dispatched (reason={reason}, msgs={len(snapshot)})")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"[MemoryFlush] Failed to dispatch flush (reason={reason}): {e}")
|
||||
return False
|
||||
|
||||
def _flush_worker(
|
||||
self,
|
||||
messages: List[Dict],
|
||||
user_id: Optional[str],
|
||||
reason: str,
|
||||
max_messages: int,
|
||||
):
|
||||
"""Background worker: summarize with LLM and write to daily file."""
|
||||
try:
|
||||
summary = self._summarize_messages(messages, max_messages)
|
||||
if not summary or not summary.strip() or summary.strip() == "无":
|
||||
logger.info(f"[MemoryFlush] No valuable content to flush (reason={reason})")
|
||||
return
|
||||
|
||||
daily_file = ensure_daily_memory_file(self.workspace_dir, user_id)
|
||||
|
||||
if reason == "overflow":
|
||||
header = f"## Context Overflow Recovery ({datetime.now().strftime('%H:%M')})"
|
||||
note = "The following conversation was trimmed due to context overflow:\n"
|
||||
elif reason == "trim":
|
||||
header = f"## Trimmed Context ({datetime.now().strftime('%H:%M')})"
|
||||
note = ""
|
||||
elif reason == "daily_summary":
|
||||
header = f"## Daily Summary ({datetime.now().strftime('%H:%M')})"
|
||||
note = ""
|
||||
else:
|
||||
header = f"## Session Notes ({datetime.now().strftime('%H:%M')})"
|
||||
note = ""
|
||||
|
||||
flush_entry = f"\n{header}\n\n{note}{summary}\n"
|
||||
|
||||
with open(daily_file, "a", encoding="utf-8") as f:
|
||||
f.write(flush_entry)
|
||||
|
||||
self.last_flush_timestamp = datetime.now()
|
||||
|
||||
logger.info(f"[MemoryFlush] Wrote to {daily_file.name} (reason={reason}, chars={len(summary)})")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"[MemoryFlush] Async flush failed (reason={reason}): {e}")
|
||||
|
||||
def create_daily_summary(
|
||||
self,
|
||||
messages: List[Dict],
|
||||
user_id: Optional[str] = None
|
||||
) -> bool:
|
||||
"""
|
||||
Generate end-of-day summary. Called by daily timer.
|
||||
Skips if messages haven't changed since last flush.
|
||||
"""
|
||||
import hashlib
|
||||
content = "".join(
|
||||
self._extract_text_from_content(m.get("content", ""))
|
||||
for m in messages
|
||||
)
|
||||
content_hash = hashlib.md5(content.encode("utf-8")).hexdigest()
|
||||
if content_hash == self._last_flushed_content_hash:
|
||||
logger.debug("[MemoryFlush] Daily summary skipped: no new content since last flush")
|
||||
return False
|
||||
self._last_flushed_content_hash = content_hash
|
||||
return self.flush_from_messages(
|
||||
messages=messages,
|
||||
user_id=user_id,
|
||||
reason="daily_summary",
|
||||
max_messages=0,
|
||||
)
|
||||
|
||||
# ---- Internal helpers ----
|
||||
|
||||
def _summarize_messages(self, messages: List[Dict], max_messages: int = 0) -> str:
|
||||
"""
|
||||
Summarize conversation messages using LLM, with rule-based fallback.
|
||||
"""
|
||||
conversation_text = self._format_conversation_for_summary(messages, max_messages)
|
||||
if not conversation_text.strip():
|
||||
return ""
|
||||
|
||||
# Try LLM summarization first
|
||||
if self.llm_model:
|
||||
try:
|
||||
summary = self._call_llm_for_summary(conversation_text)
|
||||
if summary and summary.strip() and summary.strip() != "无":
|
||||
return summary.strip()
|
||||
except Exception as e:
|
||||
logger.warning(f"[MemoryFlush] LLM summarization failed, using fallback: {e}")
|
||||
|
||||
return self._extract_summary_fallback(messages, max_messages)
|
||||
|
||||
def _format_conversation_for_summary(self, messages: List[Dict], max_messages: int = 0) -> str:
|
||||
"""Format messages into readable conversation text for LLM summarization."""
|
||||
msgs = messages if max_messages == 0 else messages[-max_messages * 2:]
|
||||
lines = []
|
||||
for msg in msgs:
|
||||
role = msg.get("role", "")
|
||||
text = self._extract_text_from_content(msg.get("content", ""))
|
||||
if not text or not text.strip():
|
||||
continue
|
||||
text = text.strip()
|
||||
if role == "user":
|
||||
lines.append(f"用户: {text[:500]}")
|
||||
elif role == "assistant":
|
||||
lines.append(f"助手: {text[:500]}")
|
||||
return "\n".join(lines)
|
||||
|
||||
def _call_llm_for_summary(self, conversation_text: str) -> str:
|
||||
"""Call LLM to generate a concise summary of the conversation."""
|
||||
from agent.protocol.models import LLMRequest
|
||||
|
||||
request = LLMRequest(
|
||||
messages=[{"role": "user", "content": SUMMARIZE_USER_PROMPT.format(conversation=conversation_text)}],
|
||||
temperature=0,
|
||||
max_tokens=500,
|
||||
stream=False,
|
||||
system=SUMMARIZE_SYSTEM_PROMPT,
|
||||
)
|
||||
|
||||
response = self.llm_model.call(request)
|
||||
|
||||
if isinstance(response, dict):
|
||||
if response.get("error"):
|
||||
raise RuntimeError(response.get("message", "LLM call failed"))
|
||||
# OpenAI format
|
||||
choices = response.get("choices", [])
|
||||
if choices:
|
||||
return choices[0].get("message", {}).get("content", "")
|
||||
|
||||
# Handle response object with attribute access (e.g. OpenAI SDK response)
|
||||
if hasattr(response, "choices") and response.choices:
|
||||
return response.choices[0].message.content or ""
|
||||
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
def _extract_summary_fallback(messages: List[Dict], max_messages: int = 0) -> str:
|
||||
"""Rule-based fallback when LLM is unavailable."""
|
||||
msgs = messages if max_messages == 0 else messages[-max_messages * 2:]
|
||||
|
||||
items = []
|
||||
for msg in msgs:
|
||||
role = msg.get("role", "")
|
||||
text = MemoryFlushManager._extract_text_from_content(msg.get("content", ""))
|
||||
if not text or not text.strip():
|
||||
continue
|
||||
text = text.strip()
|
||||
|
||||
if role == "user":
|
||||
if len(text) <= 5:
|
||||
continue
|
||||
items.append(f"- 用户请求: {text[:200]}")
|
||||
elif role == "assistant":
|
||||
first_line = text.split("\n")[0].strip()
|
||||
if len(first_line) > 10:
|
||||
items.append(f"- 处理结果: {first_line[:200]}")
|
||||
|
||||
return "\n".join(items[:15])
|
||||
|
||||
@staticmethod
|
||||
def _extract_text_from_content(content) -> str:
|
||||
"""Extract plain text from message content (string or content blocks)."""
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
if isinstance(content, list):
|
||||
parts = []
|
||||
for block in content:
|
||||
if isinstance(block, dict) and block.get("type") == "text":
|
||||
parts.append(block.get("text", ""))
|
||||
elif isinstance(block, str):
|
||||
parts.append(block)
|
||||
return "\n".join(parts)
|
||||
return ""
|
||||
|
||||
|
||||
def create_memory_files_if_needed(workspace_dir: Path, user_id: Optional[str] = None):
|
||||
"""
|
||||
Create essential memory files if they don't exist.
|
||||
Only creates MEMORY.md; daily files are created lazily on first write.
|
||||
|
||||
Args:
|
||||
workspace_dir: Workspace directory
|
||||
user_id: Optional user ID for user-specific files
|
||||
"""
|
||||
memory_dir = workspace_dir / "memory"
|
||||
memory_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Create main MEMORY.md in workspace root (always needed for bootstrap)
|
||||
if user_id:
|
||||
user_dir = memory_dir / "users" / user_id
|
||||
user_dir.mkdir(parents=True, exist_ok=True)
|
||||
main_memory = user_dir / "MEMORY.md"
|
||||
else:
|
||||
main_memory = Path(workspace_dir) / "MEMORY.md"
|
||||
|
||||
if not main_memory.exists():
|
||||
main_memory.write_text("")
|
||||
|
||||
|
||||
def ensure_daily_memory_file(workspace_dir: Path, user_id: Optional[str] = None) -> Path:
|
||||
"""
|
||||
Ensure today's daily memory file exists, creating it only when actually needed.
|
||||
Called lazily before first write to daily memory.
|
||||
|
||||
Args:
|
||||
workspace_dir: Workspace directory
|
||||
user_id: Optional user ID for user-specific files
|
||||
|
||||
Returns:
|
||||
Path to today's memory file
|
||||
"""
|
||||
memory_dir = workspace_dir / "memory"
|
||||
memory_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
today = datetime.now().strftime("%Y-%m-%d")
|
||||
if user_id:
|
||||
user_dir = memory_dir / "users" / user_id
|
||||
user_dir.mkdir(parents=True, exist_ok=True)
|
||||
today_memory = user_dir / f"{today}.md"
|
||||
else:
|
||||
today_memory = memory_dir / f"{today}.md"
|
||||
|
||||
if not today_memory.exists():
|
||||
today_memory.write_text(
|
||||
f"# Daily Memory: {today}\n\n"
|
||||
)
|
||||
|
||||
return today_memory
|
||||
13
agent/prompt/__init__.py
Normal file
13
agent/prompt/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
"""
|
||||
Agent Prompt Module - 系统提示词构建模块
|
||||
"""
|
||||
|
||||
from .builder import PromptBuilder, build_agent_system_prompt
|
||||
from .workspace import ensure_workspace, load_context_files
|
||||
|
||||
__all__ = [
|
||||
'PromptBuilder',
|
||||
'build_agent_system_prompt',
|
||||
'ensure_workspace',
|
||||
'load_context_files',
|
||||
]
|
||||
488
agent/prompt/builder.py
Normal file
488
agent/prompt/builder.py
Normal file
@@ -0,0 +1,488 @@
|
||||
"""
|
||||
System Prompt Builder - 系统提示词构建器
|
||||
|
||||
实现模块化的系统提示词构建,支持工具、技能、记忆等多个子系统
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import os
|
||||
from typing import List, Dict, Optional, Any
|
||||
from dataclasses import dataclass
|
||||
|
||||
from common.log import logger
|
||||
|
||||
|
||||
@dataclass
|
||||
class ContextFile:
|
||||
"""上下文文件"""
|
||||
path: str
|
||||
content: str
|
||||
|
||||
|
||||
class PromptBuilder:
|
||||
"""提示词构建器"""
|
||||
|
||||
def __init__(self, workspace_dir: str, language: str = "zh"):
|
||||
"""
|
||||
初始化提示词构建器
|
||||
|
||||
Args:
|
||||
workspace_dir: 工作空间目录
|
||||
language: 语言 ("zh" 或 "en")
|
||||
"""
|
||||
self.workspace_dir = workspace_dir
|
||||
self.language = language
|
||||
|
||||
def build(
|
||||
self,
|
||||
base_persona: Optional[str] = None,
|
||||
user_identity: Optional[Dict[str, str]] = None,
|
||||
tools: Optional[List[Any]] = None,
|
||||
context_files: Optional[List[ContextFile]] = None,
|
||||
skill_manager: Any = None,
|
||||
memory_manager: Any = None,
|
||||
runtime_info: Optional[Dict[str, Any]] = None,
|
||||
**kwargs
|
||||
) -> str:
|
||||
"""
|
||||
构建完整的系统提示词
|
||||
|
||||
Args:
|
||||
base_persona: 基础人格描述(会被context_files中的AGENT.md覆盖)
|
||||
user_identity: 用户身份信息
|
||||
tools: 工具列表
|
||||
context_files: 上下文文件列表(AGENT.md, USER.md, RULE.md, BOOTSTRAP.md等)
|
||||
skill_manager: 技能管理器
|
||||
memory_manager: 记忆管理器
|
||||
runtime_info: 运行时信息
|
||||
**kwargs: 其他参数
|
||||
|
||||
Returns:
|
||||
完整的系统提示词
|
||||
"""
|
||||
return build_agent_system_prompt(
|
||||
workspace_dir=self.workspace_dir,
|
||||
language=self.language,
|
||||
base_persona=base_persona,
|
||||
user_identity=user_identity,
|
||||
tools=tools,
|
||||
context_files=context_files,
|
||||
skill_manager=skill_manager,
|
||||
memory_manager=memory_manager,
|
||||
runtime_info=runtime_info,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
|
||||
def build_agent_system_prompt(
|
||||
workspace_dir: str,
|
||||
language: str = "zh",
|
||||
base_persona: Optional[str] = None,
|
||||
user_identity: Optional[Dict[str, str]] = None,
|
||||
tools: Optional[List[Any]] = None,
|
||||
context_files: Optional[List[ContextFile]] = None,
|
||||
skill_manager: Any = None,
|
||||
memory_manager: Any = None,
|
||||
runtime_info: Optional[Dict[str, Any]] = None,
|
||||
**kwargs
|
||||
) -> str:
|
||||
"""
|
||||
构建Agent系统提示词
|
||||
|
||||
顺序说明(按重要性和逻辑关系排列):
|
||||
1. 工具系统 - 核心能力,最先介绍
|
||||
2. 技能系统 - 紧跟工具,因为技能需要用 read 工具读取
|
||||
3. 记忆系统 - 独立的记忆能力
|
||||
4. 工作空间 - 工作环境说明
|
||||
5. 用户身份 - 用户信息(可选)
|
||||
6. 项目上下文 - AGENT.md, USER.md, RULE.md, BOOTSTRAP.md(定义人格、身份、规则、初始化引导)
|
||||
7. 运行时信息 - 元信息(时间、模型等)
|
||||
|
||||
Args:
|
||||
workspace_dir: 工作空间目录
|
||||
language: 语言 ("zh" 或 "en")
|
||||
base_persona: 基础人格描述(已废弃,由AGENT.md定义)
|
||||
user_identity: 用户身份信息
|
||||
tools: 工具列表
|
||||
context_files: 上下文文件列表
|
||||
skill_manager: 技能管理器
|
||||
memory_manager: 记忆管理器
|
||||
runtime_info: 运行时信息
|
||||
**kwargs: 其他参数
|
||||
|
||||
Returns:
|
||||
完整的系统提示词
|
||||
"""
|
||||
sections = []
|
||||
|
||||
# 1. 工具系统(最重要,放在最前面)
|
||||
if tools:
|
||||
sections.extend(_build_tooling_section(tools, language))
|
||||
|
||||
# 2. 技能系统(紧跟工具,因为需要用 read 工具)
|
||||
if skill_manager:
|
||||
sections.extend(_build_skills_section(skill_manager, tools, language))
|
||||
|
||||
# 3. 记忆系统(独立的记忆能力)
|
||||
if memory_manager:
|
||||
sections.extend(_build_memory_section(memory_manager, tools, language))
|
||||
|
||||
# 4. 工作空间(工作环境说明)
|
||||
sections.extend(_build_workspace_section(workspace_dir, language))
|
||||
|
||||
# 5. 用户身份(如果有)
|
||||
if user_identity:
|
||||
sections.extend(_build_user_identity_section(user_identity, language))
|
||||
|
||||
# 6. 项目上下文文件(AGENT.md, USER.md, RULE.md - 定义人格)
|
||||
if context_files:
|
||||
sections.extend(_build_context_files_section(context_files, language))
|
||||
|
||||
# 7. 运行时信息(元信息,放在最后)
|
||||
if runtime_info:
|
||||
sections.extend(_build_runtime_section(runtime_info, language))
|
||||
|
||||
return "\n".join(sections)
|
||||
|
||||
|
||||
def _build_identity_section(base_persona: Optional[str], language: str) -> List[str]:
|
||||
"""构建基础身份section - 不再需要,身份由AGENT.md定义"""
|
||||
# 不再生成基础身份section,完全由AGENT.md定义
|
||||
return []
|
||||
|
||||
|
||||
def _build_tooling_section(tools: List[Any], language: str) -> List[str]:
|
||||
"""Build tooling section with concise tool list and call style guide."""
|
||||
# One-line summaries for known tools (details are in the tool schema)
|
||||
core_summaries = {
|
||||
"read": "读取文件内容",
|
||||
"write": "创建或覆盖文件",
|
||||
"edit": "精确编辑文件",
|
||||
"ls": "列出目录内容",
|
||||
"grep": "搜索文件内容",
|
||||
"find": "按模式查找文件",
|
||||
"bash": "执行shell命令",
|
||||
"terminal": "管理后台进程",
|
||||
"web_search": "网络搜索",
|
||||
"web_fetch": "获取URL内容",
|
||||
"browser": "控制浏览器",
|
||||
"memory_search": "搜索记忆",
|
||||
"memory_get": "读取记忆内容",
|
||||
"env_config": "管理API密钥和技能配置",
|
||||
"scheduler": "管理定时任务和提醒",
|
||||
"send": "发送本地文件给用户(仅限本地文件,URL直接放在回复文本中)",
|
||||
}
|
||||
|
||||
# Preferred display order
|
||||
tool_order = [
|
||||
"read", "write", "edit", "ls", "grep", "find",
|
||||
"bash", "terminal",
|
||||
"web_search", "web_fetch", "browser",
|
||||
"memory_search", "memory_get",
|
||||
"env_config", "scheduler", "send",
|
||||
]
|
||||
|
||||
# Build name -> summary mapping for available tools
|
||||
available = {}
|
||||
for tool in tools:
|
||||
name = tool.name if hasattr(tool, 'name') else str(tool)
|
||||
available[name] = core_summaries.get(name, "")
|
||||
|
||||
# Generate tool lines: ordered tools first, then extras
|
||||
tool_lines = []
|
||||
for name in tool_order:
|
||||
if name in available:
|
||||
summary = available.pop(name)
|
||||
tool_lines.append(f"- {name}: {summary}" if summary else f"- {name}")
|
||||
for name in sorted(available):
|
||||
summary = available[name]
|
||||
tool_lines.append(f"- {name}: {summary}" if summary else f"- {name}")
|
||||
|
||||
lines = [
|
||||
"## 工具系统",
|
||||
"",
|
||||
"可用工具(名称大小写敏感,严格按列表调用):",
|
||||
"\n".join(tool_lines),
|
||||
"",
|
||||
"工具调用风格:",
|
||||
"",
|
||||
"- 在多步骤任务、敏感操作或用户要求时简要解释决策过程",
|
||||
"- 持续推进直到任务完成,完成后向用户报告结果。",
|
||||
"- 回复中涉及密钥、令牌等敏感信息必须脱敏。",
|
||||
"- URL链接直接放在回复文本中即可,系统会自动处理和渲染。无需下载后使用send工具发送",
|
||||
"",
|
||||
]
|
||||
|
||||
return lines
|
||||
|
||||
|
||||
def _build_skills_section(skill_manager: Any, tools: Optional[List[Any]], language: str) -> List[str]:
|
||||
"""构建技能系统section"""
|
||||
if not skill_manager:
|
||||
return []
|
||||
|
||||
# 获取read工具名称
|
||||
read_tool_name = "read"
|
||||
if tools:
|
||||
for tool in tools:
|
||||
tool_name = tool.name if hasattr(tool, 'name') else str(tool)
|
||||
if tool_name.lower() == "read":
|
||||
read_tool_name = tool_name
|
||||
break
|
||||
|
||||
lines = [
|
||||
"## 技能系统(mandatory)",
|
||||
"",
|
||||
"在回复之前:扫描下方 <available_skills> 中每个技能的 <description>。",
|
||||
"",
|
||||
f"- 如果有技能的描述与用户需求匹配:使用 `{read_tool_name}` 工具读取其 <location> 路径的 SKILL.md 文件,然后严格遵循文件中的指令。"
|
||||
"当有匹配的技能时,应优先使用技能",
|
||||
"- 如果多个技能都适用则选择最匹配的一个,然后读取并遵循。",
|
||||
"- 如果没有技能明确适用:不要读取任何 SKILL.md,直接使用通用工具。",
|
||||
"",
|
||||
f"**重要**: 技能不是工具,不能直接调用。使用技能的唯一方式是用 `{read_tool_name}` 读取 SKILL.md 文件,然后按文件内容操作。"
|
||||
"永远不要一次性读取多个技能,只在选择后再读取。",
|
||||
"",
|
||||
"以下是可用技能:"
|
||||
]
|
||||
|
||||
# 添加技能列表(通过skill_manager获取)
|
||||
try:
|
||||
skills_prompt = skill_manager.build_skills_prompt()
|
||||
logger.debug(f"[PromptBuilder] Skills prompt length: {len(skills_prompt) if skills_prompt else 0}")
|
||||
if skills_prompt:
|
||||
lines.append(skills_prompt.strip())
|
||||
lines.append("")
|
||||
else:
|
||||
logger.warning("[PromptBuilder] No skills prompt generated - skills_prompt is empty")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to build skills prompt: {e}")
|
||||
import traceback
|
||||
logger.debug(f"Skills prompt error traceback: {traceback.format_exc()}")
|
||||
|
||||
return lines
|
||||
|
||||
|
||||
def _build_memory_section(memory_manager: Any, tools: Optional[List[Any]], language: str) -> List[str]:
|
||||
"""构建记忆系统section"""
|
||||
if not memory_manager:
|
||||
return []
|
||||
|
||||
# 检查是否有memory工具
|
||||
has_memory_tools = False
|
||||
if tools:
|
||||
tool_names = [tool.name if hasattr(tool, 'name') else str(tool) for tool in tools]
|
||||
has_memory_tools = any(name in ['memory_search', 'memory_get'] for name in tool_names)
|
||||
|
||||
if not has_memory_tools:
|
||||
return []
|
||||
|
||||
from datetime import datetime
|
||||
today_file = datetime.now().strftime("%Y-%m-%d") + ".md"
|
||||
|
||||
lines = [
|
||||
"## 记忆系统",
|
||||
"",
|
||||
"### 检索记忆",
|
||||
"",
|
||||
"在回答关于以前的工作、决定、日期、人物、偏好或待办事项的任何问题之前:",
|
||||
"",
|
||||
"1. 不确定记忆文件位置 → 先用 `memory_search` 通过关键词和语义检索相关内容",
|
||||
"2. 已知文件位置 → 直接用 `memory_get` 读取相应的行 (例如:MEMORY.md, memory/YYYY-MM-DD.md)",
|
||||
"3. search 无结果 → 尝试用 `memory_get` 读取MEMORY.md及最近两天记忆文件",
|
||||
"",
|
||||
"**记忆文件结构**:",
|
||||
f"- `MEMORY.md`: 长期记忆(核心信息、偏好、决策等)",
|
||||
f"- `memory/YYYY-MM-DD.md`: 每日记忆,今天是 `memory/{today_file}`",
|
||||
"",
|
||||
"### 写入记忆",
|
||||
"",
|
||||
"**主动存储**:遇到以下情况时,应主动将信息写入记忆文件(无需告知用户):",
|
||||
"",
|
||||
"- 用户明确要求你记住某些信息",
|
||||
"- 用户分享了重要的个人偏好、习惯、决策",
|
||||
"- 对话中产生了重要的结论、方案、约定",
|
||||
"- 完成了复杂任务,值得记录关键步骤和结果",
|
||||
"- 发现了用户经常遇到的问题或解决方案",
|
||||
"",
|
||||
"**存储规则**:",
|
||||
f"- 长期有效的核心信息 → `MEMORY.md`(文件保持精简,< 2000 tokens)",
|
||||
f"- 当天的事件、进展、笔记 → `memory/{today_file}`",
|
||||
"- 追加内容 → `edit` 工具,oldText 留空",
|
||||
"- 修改内容 → `edit` 工具,oldText 填写要替换的文本",
|
||||
"- **禁止写入敏感信息**:API密钥、令牌等敏感信息严禁写入记忆文件",
|
||||
"",
|
||||
"**使用原则**: 自然使用记忆,就像你本来就知道;不用刻意提起,除非用户问起。",
|
||||
"",
|
||||
]
|
||||
|
||||
return lines
|
||||
|
||||
|
||||
def _build_user_identity_section(user_identity: Dict[str, str], language: str) -> List[str]:
|
||||
"""构建用户身份section"""
|
||||
if not user_identity:
|
||||
return []
|
||||
|
||||
lines = [
|
||||
"## 用户身份",
|
||||
"",
|
||||
]
|
||||
|
||||
if user_identity.get("name"):
|
||||
lines.append(f"**用户姓名**: {user_identity['name']}")
|
||||
if user_identity.get("nickname"):
|
||||
lines.append(f"**称呼**: {user_identity['nickname']}")
|
||||
if user_identity.get("timezone"):
|
||||
lines.append(f"**时区**: {user_identity['timezone']}")
|
||||
if user_identity.get("notes"):
|
||||
lines.append(f"**备注**: {user_identity['notes']}")
|
||||
|
||||
lines.append("")
|
||||
|
||||
return lines
|
||||
|
||||
|
||||
def _build_docs_section(workspace_dir: str, language: str) -> List[str]:
|
||||
"""构建文档路径section - 已移除,不再需要"""
|
||||
# 不再生成文档section
|
||||
return []
|
||||
|
||||
|
||||
def _build_workspace_section(workspace_dir: str, language: str) -> List[str]:
|
||||
"""构建工作空间section"""
|
||||
lines = [
|
||||
"## 工作空间",
|
||||
"",
|
||||
f"你的工作目录是: `{workspace_dir}`",
|
||||
"",
|
||||
"**路径使用规则** (非常重要):",
|
||||
"",
|
||||
f"1. **相对路径的基准目录**: 所有相对路径都是相对于 `{workspace_dir}` 而言的",
|
||||
f" - ✅ 正确: 访问工作空间内的文件用相对路径,如 `AGENT.md`",
|
||||
f" - ❌ 错误: 用相对路径访问其他目录的文件 (如果它不在 `{workspace_dir}` 内)",
|
||||
"",
|
||||
"2. **访问其他目录**: 如果要访问工作空间之外的目录(如项目代码、系统文件),**必须使用绝对路径**",
|
||||
f" - ✅ 正确: 例如 `~/chatgpt-on-wechat`、`/usr/local/`",
|
||||
f" - ❌ 错误: 假设相对路径会指向其他目录",
|
||||
"",
|
||||
"3. **路径解析示例**:",
|
||||
f" - 相对路径 `memory/` → 实际路径 `{workspace_dir}/memory/`",
|
||||
f" - 绝对路径 `~/chatgpt-on-wechat/docs/` → 实际路径 `~/chatgpt-on-wechat/docs/`",
|
||||
"",
|
||||
"4. **不确定时**: 先用 `bash pwd` 确认当前目录,或用 `ls .` 查看当前位置",
|
||||
"",
|
||||
"**重要说明 - 文件已自动加载**:",
|
||||
"",
|
||||
"以下文件在会话启动时**已经自动加载**到系统提示词的「项目上下文」section 中,你**无需再用 read 工具读取它们**:",
|
||||
"",
|
||||
"- ✅ `AGENT.md`: 已加载 - 你的人格和灵魂设定。当用户修改你的名字、性格或交流风格时,用 `edit` 更新此文件",
|
||||
"- ✅ `USER.md`: 已加载 - 用户的身份信息。当用户修改称呼、姓名等身份信息时,用 `edit` 更新此文件",
|
||||
"- ✅ `RULE.md`: 已加载 - 工作空间使用指南和规则",
|
||||
"",
|
||||
"**交流规范**:",
|
||||
"",
|
||||
"- 在对话中,无需直接输出工作空间中的技术细节,例如 AGENT.md、USER.md、MEMORY.md 等文件名称",
|
||||
"- 例如用自然表达例如「我已记住」而不是「已更新 MEMORY.md」",
|
||||
"",
|
||||
]
|
||||
|
||||
# Cloud deployment: inject websites directory info and access URL
|
||||
cloud_website_lines = _build_cloud_website_section(workspace_dir)
|
||||
if cloud_website_lines:
|
||||
lines.extend(cloud_website_lines)
|
||||
|
||||
return lines
|
||||
|
||||
|
||||
def _build_cloud_website_section(workspace_dir: str) -> List[str]:
|
||||
"""Build cloud website access prompt when cloud deployment is configured."""
|
||||
try:
|
||||
from common.cloud_client import build_website_prompt
|
||||
return build_website_prompt(workspace_dir)
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
|
||||
def _build_context_files_section(context_files: List[ContextFile], language: str) -> List[str]:
|
||||
"""构建项目上下文文件section"""
|
||||
if not context_files:
|
||||
return []
|
||||
|
||||
# 检查是否有AGENT.md
|
||||
has_agent = any(
|
||||
f.path.lower().endswith('agent.md') or 'agent.md' in f.path.lower()
|
||||
for f in context_files
|
||||
)
|
||||
|
||||
lines = [
|
||||
"# 项目上下文",
|
||||
"",
|
||||
"以下项目上下文文件已被加载:",
|
||||
"",
|
||||
]
|
||||
|
||||
if has_agent:
|
||||
lines.append("如果存在 `AGENT.md`,请体现其中定义的人格和语气。避免僵硬、模板化的回复;遵循其指导,除非有更高优先级的指令覆盖它。")
|
||||
lines.append("")
|
||||
|
||||
# 添加每个文件的内容
|
||||
for file in context_files:
|
||||
lines.append(f"## {file.path}")
|
||||
lines.append("")
|
||||
lines.append(file.content)
|
||||
lines.append("")
|
||||
|
||||
return lines
|
||||
|
||||
|
||||
def _build_runtime_section(runtime_info: Dict[str, Any], language: str) -> List[str]:
|
||||
"""构建运行时信息section - 支持动态时间"""
|
||||
if not runtime_info:
|
||||
return []
|
||||
|
||||
lines = [
|
||||
"## 运行时信息",
|
||||
"",
|
||||
]
|
||||
|
||||
# Add current time if available
|
||||
# Support dynamic time via callable function
|
||||
if callable(runtime_info.get("_get_current_time")):
|
||||
try:
|
||||
time_info = runtime_info["_get_current_time"]()
|
||||
time_line = f"当前时间: {time_info['time']} {time_info['weekday']} ({time_info['timezone']})"
|
||||
lines.append(time_line)
|
||||
lines.append("")
|
||||
except Exception as e:
|
||||
logger.warning(f"[PromptBuilder] Failed to get dynamic time: {e}")
|
||||
elif runtime_info.get("current_time"):
|
||||
# Fallback to static time for backward compatibility
|
||||
time_str = runtime_info["current_time"]
|
||||
weekday = runtime_info.get("weekday", "")
|
||||
timezone = runtime_info.get("timezone", "")
|
||||
|
||||
time_line = f"当前时间: {time_str}"
|
||||
if weekday:
|
||||
time_line += f" {weekday}"
|
||||
if timezone:
|
||||
time_line += f" ({timezone})"
|
||||
|
||||
lines.append(time_line)
|
||||
lines.append("")
|
||||
|
||||
# Add other runtime info
|
||||
runtime_parts = []
|
||||
if runtime_info.get("model"):
|
||||
runtime_parts.append(f"模型={runtime_info['model']}")
|
||||
if runtime_info.get("workspace"):
|
||||
runtime_parts.append(f"工作空间={runtime_info['workspace']}")
|
||||
# Only add channel if it's not the default "web"
|
||||
if runtime_info.get("channel") and runtime_info.get("channel") != "web":
|
||||
runtime_parts.append(f"渠道={runtime_info['channel']}")
|
||||
|
||||
if runtime_parts:
|
||||
lines.append("运行时: " + " | ".join(runtime_parts))
|
||||
lines.append("")
|
||||
|
||||
return lines
|
||||
382
agent/prompt/workspace.py
Normal file
382
agent/prompt/workspace.py
Normal file
@@ -0,0 +1,382 @@
|
||||
"""
|
||||
Workspace Management - 工作空间管理模块
|
||||
|
||||
负责初始化工作空间、创建模板文件、加载上下文文件
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import os
|
||||
from typing import List, Optional, Dict
|
||||
from dataclasses import dataclass
|
||||
|
||||
from common.log import logger
|
||||
from .builder import ContextFile
|
||||
|
||||
|
||||
# 默认文件名常量
|
||||
DEFAULT_AGENT_FILENAME = "AGENT.md"
|
||||
DEFAULT_USER_FILENAME = "USER.md"
|
||||
DEFAULT_RULE_FILENAME = "RULE.md"
|
||||
DEFAULT_MEMORY_FILENAME = "MEMORY.md"
|
||||
DEFAULT_BOOTSTRAP_FILENAME = "BOOTSTRAP.md"
|
||||
|
||||
|
||||
@dataclass
|
||||
class WorkspaceFiles:
|
||||
"""工作空间文件路径"""
|
||||
agent_path: str
|
||||
user_path: str
|
||||
rule_path: str
|
||||
memory_path: str
|
||||
memory_dir: str
|
||||
|
||||
|
||||
def ensure_workspace(workspace_dir: str, create_templates: bool = True) -> WorkspaceFiles:
|
||||
"""
|
||||
确保工作空间存在,并创建必要的模板文件
|
||||
|
||||
Args:
|
||||
workspace_dir: 工作空间目录路径
|
||||
create_templates: 是否创建模板文件(首次运行时)
|
||||
|
||||
Returns:
|
||||
WorkspaceFiles对象,包含所有文件路径
|
||||
"""
|
||||
# Check if this is a brand new workspace (AGENT.md not yet created).
|
||||
# Cannot rely on directory existence because other modules (e.g. ConversationStore)
|
||||
# may create the workspace directory before ensure_workspace is called.
|
||||
agent_path = os.path.join(workspace_dir, DEFAULT_AGENT_FILENAME)
|
||||
is_new_workspace = not os.path.exists(agent_path)
|
||||
|
||||
# 确保目录存在
|
||||
os.makedirs(workspace_dir, exist_ok=True)
|
||||
|
||||
# 定义文件路径
|
||||
user_path = os.path.join(workspace_dir, DEFAULT_USER_FILENAME)
|
||||
rule_path = os.path.join(workspace_dir, DEFAULT_RULE_FILENAME)
|
||||
memory_path = os.path.join(workspace_dir, DEFAULT_MEMORY_FILENAME) # MEMORY.md 在根目录
|
||||
memory_dir = os.path.join(workspace_dir, "memory") # 每日记忆子目录
|
||||
|
||||
# 创建memory子目录
|
||||
os.makedirs(memory_dir, exist_ok=True)
|
||||
|
||||
# 创建skills子目录 (for workspace-level skills installed by agent)
|
||||
skills_dir = os.path.join(workspace_dir, "skills")
|
||||
os.makedirs(skills_dir, exist_ok=True)
|
||||
|
||||
# 创建websites子目录 (for web pages / sites generated by agent)
|
||||
websites_dir = os.path.join(workspace_dir, "websites")
|
||||
os.makedirs(websites_dir, exist_ok=True)
|
||||
|
||||
# 如果需要,创建模板文件
|
||||
if create_templates:
|
||||
_create_template_if_missing(agent_path, _get_agent_template())
|
||||
_create_template_if_missing(user_path, _get_user_template())
|
||||
_create_template_if_missing(rule_path, _get_rule_template())
|
||||
_create_template_if_missing(memory_path, _get_memory_template())
|
||||
|
||||
# Only create BOOTSTRAP.md for brand new workspaces;
|
||||
# agent deletes it after completing onboarding
|
||||
if is_new_workspace:
|
||||
bootstrap_path = os.path.join(workspace_dir, DEFAULT_BOOTSTRAP_FILENAME)
|
||||
_create_template_if_missing(bootstrap_path, _get_bootstrap_template())
|
||||
|
||||
logger.debug(f"[Workspace] Initialized workspace at: {workspace_dir}")
|
||||
|
||||
return WorkspaceFiles(
|
||||
agent_path=agent_path,
|
||||
user_path=user_path,
|
||||
rule_path=rule_path,
|
||||
memory_path=memory_path,
|
||||
memory_dir=memory_dir,
|
||||
)
|
||||
|
||||
|
||||
def load_context_files(workspace_dir: str, files_to_load: Optional[List[str]] = None) -> List[ContextFile]:
|
||||
"""
|
||||
加载工作空间的上下文文件
|
||||
|
||||
Args:
|
||||
workspace_dir: 工作空间目录
|
||||
files_to_load: 要加载的文件列表(相对路径),如果为None则加载所有标准文件
|
||||
|
||||
Returns:
|
||||
ContextFile对象列表
|
||||
"""
|
||||
if files_to_load is None:
|
||||
# 默认加载的文件(按优先级排序)
|
||||
files_to_load = [
|
||||
DEFAULT_AGENT_FILENAME,
|
||||
DEFAULT_USER_FILENAME,
|
||||
DEFAULT_RULE_FILENAME,
|
||||
DEFAULT_BOOTSTRAP_FILENAME, # Only exists when onboarding is incomplete
|
||||
]
|
||||
|
||||
context_files = []
|
||||
|
||||
for filename in files_to_load:
|
||||
filepath = os.path.join(workspace_dir, filename)
|
||||
|
||||
if not os.path.exists(filepath):
|
||||
continue
|
||||
|
||||
# Auto-cleanup: if BOOTSTRAP.md still exists but AGENT.md is already
|
||||
# filled in, the agent forgot to delete it — clean up and skip loading
|
||||
if filename == DEFAULT_BOOTSTRAP_FILENAME:
|
||||
if _is_onboarding_done(workspace_dir):
|
||||
try:
|
||||
os.remove(filepath)
|
||||
logger.info("[Workspace] Auto-removed BOOTSTRAP.md (onboarding already complete)")
|
||||
except Exception:
|
||||
pass
|
||||
continue
|
||||
|
||||
try:
|
||||
with open(filepath, 'r', encoding='utf-8') as f:
|
||||
content = f.read().strip()
|
||||
|
||||
# 跳过空文件或只包含模板占位符的文件
|
||||
if not content or _is_template_placeholder(content):
|
||||
continue
|
||||
|
||||
context_files.append(ContextFile(
|
||||
path=filename,
|
||||
content=content
|
||||
))
|
||||
|
||||
logger.debug(f"[Workspace] Loaded context file: {filename}")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"[Workspace] Failed to load {filename}: {e}")
|
||||
|
||||
return context_files
|
||||
|
||||
|
||||
def _create_template_if_missing(filepath: str, template_content: str):
|
||||
"""如果文件不存在,创建模板文件"""
|
||||
if not os.path.exists(filepath):
|
||||
try:
|
||||
with open(filepath, 'w', encoding='utf-8') as f:
|
||||
f.write(template_content)
|
||||
logger.debug(f"[Workspace] Created template: {os.path.basename(filepath)}")
|
||||
except Exception as e:
|
||||
logger.error(f"[Workspace] Failed to create template {filepath}: {e}")
|
||||
|
||||
|
||||
def _is_template_placeholder(content: str) -> bool:
|
||||
"""检查内容是否为模板占位符"""
|
||||
# 常见的占位符模式
|
||||
placeholders = [
|
||||
"*(填写",
|
||||
"*(在首次对话时填写",
|
||||
"*(可选)",
|
||||
"*(根据需要添加",
|
||||
]
|
||||
|
||||
lines = content.split('\n')
|
||||
non_empty_lines = [line.strip() for line in lines if line.strip() and not line.strip().startswith('#')]
|
||||
|
||||
# 如果没有实际内容(只有标题和占位符)
|
||||
if len(non_empty_lines) <= 3:
|
||||
for placeholder in placeholders:
|
||||
if any(placeholder in line for line in non_empty_lines):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def _is_onboarding_done(workspace_dir: str) -> bool:
|
||||
"""Check if AGENT.md or USER.md has been modified from the original template"""
|
||||
agent_path = os.path.join(workspace_dir, DEFAULT_AGENT_FILENAME)
|
||||
user_path = os.path.join(workspace_dir, DEFAULT_USER_FILENAME)
|
||||
|
||||
agent_template = _get_agent_template().strip()
|
||||
user_template = _get_user_template().strip()
|
||||
|
||||
for path, template in [(agent_path, agent_template), (user_path, user_template)]:
|
||||
if not os.path.exists(path):
|
||||
continue
|
||||
try:
|
||||
with open(path, 'r', encoding='utf-8') as f:
|
||||
content = f.read().strip()
|
||||
if content != template:
|
||||
return True
|
||||
except Exception:
|
||||
continue
|
||||
return False
|
||||
|
||||
|
||||
# ============= 模板内容 =============
|
||||
|
||||
def _get_agent_template() -> str:
|
||||
"""Agent人格设定模板"""
|
||||
return """# AGENT.md - 我是谁?
|
||||
|
||||
*在首次对话时与用户一起填写这个文件,定义你的身份和性格。*
|
||||
|
||||
## 基本信息
|
||||
|
||||
- **名字**: *(在首次对话时填写,可以是用户给你起的名字)*
|
||||
- **角色**: *(AI助理、智能管家、技术顾问等)*
|
||||
- **性格**: *(友好、专业、幽默、严谨等)*
|
||||
|
||||
## 交流风格
|
||||
|
||||
*(描述你如何与用户交流:)*
|
||||
- 使用什么样的语言风格?(正式/轻松/幽默)
|
||||
- 回复长度偏好?(简洁/详细)
|
||||
- 是否使用表情符号?
|
||||
|
||||
## 核心能力
|
||||
|
||||
*(你擅长什么?)*
|
||||
- 文件管理和代码编辑
|
||||
- 网络搜索和信息查询
|
||||
- 记忆管理和上下文理解
|
||||
- 任务规划和执行
|
||||
|
||||
## 行为准则
|
||||
|
||||
*(你遵循的基本原则:)*
|
||||
1. 始终在执行破坏性操作前确认
|
||||
2. 优先使用工具而不是猜测
|
||||
3. 主动记录重要信息到记忆文件
|
||||
4. 定期整理和总结对话内容
|
||||
|
||||
---
|
||||
|
||||
**注意**: 这不仅仅是元数据,这是你真正的灵魂。随着时间的推移,你可以使用 `edit` 工具来更新这个文件,让它更好地反映你的成长。
|
||||
"""
|
||||
|
||||
|
||||
def _get_user_template() -> str:
|
||||
"""用户身份信息模板"""
|
||||
return """# USER.md - 用户基本信息
|
||||
|
||||
*这个文件只存放不会变的基本身份信息。爱好、偏好、计划等动态信息请写入 MEMORY.md。*
|
||||
|
||||
## 基本信息
|
||||
|
||||
- **姓名**: *(在首次对话时询问)*
|
||||
- **称呼**: *(用户希望被如何称呼)*
|
||||
- **职业**: *(可选)*
|
||||
- **时区**: *(例如: Asia/Shanghai)*
|
||||
|
||||
## 联系方式
|
||||
|
||||
- **微信**:
|
||||
- **邮箱**:
|
||||
- **其他**:
|
||||
|
||||
## 重要日期
|
||||
|
||||
- **生日**:
|
||||
- **纪念日**:
|
||||
|
||||
---
|
||||
|
||||
**注意**: 这个文件存放静态的身份信息
|
||||
"""
|
||||
|
||||
|
||||
def _get_rule_template() -> str:
|
||||
"""工作空间规则模板"""
|
||||
return """# RULE.md - 工作空间规则
|
||||
|
||||
这个文件夹是你的家。好好对待它。
|
||||
|
||||
## 记忆系统
|
||||
|
||||
你每次会话都是全新的,记忆文件让你保持连续性:
|
||||
|
||||
### 📝 每日记忆:`memory/YYYY-MM-DD.md`
|
||||
- 原始的对话日志
|
||||
- 记录当天发生的事情
|
||||
- 如果 `memory/` 目录不存在,创建它
|
||||
|
||||
### 🧠 长期记忆:`MEMORY.md`
|
||||
- 你精选的记忆,就像人类的长期记忆
|
||||
- **仅在主会话中加载**(与用户的直接聊天)
|
||||
- **不要在共享上下文中加载**(群聊、与其他人的会话)
|
||||
- 这是为了**安全** - 包含不应泄露给陌生人的个人上下文
|
||||
- 记录重要事件、想法、决定、观点、经验教训
|
||||
- 这是你精选的记忆 - 精华,而不是原始日志
|
||||
- 用 `edit` 工具追加新的记忆内容
|
||||
|
||||
### 📝 写下来 - 不要"记在心里"!
|
||||
- **记忆是有限的** - 如果你想记住某事,写入文件
|
||||
- "记在心里"不会在会话重启后保留,文件才会
|
||||
- 当有人说"记住这个" → 更新 `MEMORY.md` 或 `memory/YYYY-MM-DD.md`
|
||||
- 当你学到教训 → 更新 RULE.md 或相关技能
|
||||
- 当你犯错 → 记录下来,这样未来的你不会重复,**文字 > 大脑** 📝
|
||||
|
||||
### 存储规则
|
||||
|
||||
当用户分享信息时,根据类型选择存储位置:
|
||||
|
||||
1. **你的身份设定 → AGENT.md**(你的名字、角色、性格、交流风格——用户修改时必须用 `edit` 更新)
|
||||
2. **用户静态身份 → USER.md**(姓名、称呼、职业、时区、联系方式、生日——用户修改时必须用 `edit` 更新)
|
||||
3. **动态记忆 → MEMORY.md**(爱好、偏好、决策、目标、项目、教训、待办事项)
|
||||
4. **当天对话 → memory/YYYY-MM-DD.md**(今天聊的内容)
|
||||
|
||||
## 安全
|
||||
|
||||
- 永远不要泄露秘钥等私人数据
|
||||
- 不要在未经询问的情况下运行破坏性命令
|
||||
- 当有疑问时,先问
|
||||
|
||||
## 工作空间演化
|
||||
|
||||
这个工作空间会随着你的使用而不断成长。当你学到新东西、发现更好的方式,或者犯错后改正时,记录下来。你可以随时更新这个规则文件。
|
||||
"""
|
||||
|
||||
|
||||
def _get_memory_template() -> str:
|
||||
"""长期记忆模板 - 创建一个空文件,由 Agent 自己填充"""
|
||||
return """# MEMORY.md - 长期记忆
|
||||
|
||||
*这是你的长期记忆文件。记录重要的事件、决策、偏好、学到的教训。*
|
||||
|
||||
---
|
||||
|
||||
"""
|
||||
|
||||
|
||||
def _get_bootstrap_template() -> str:
|
||||
"""First-run onboarding guide, deleted by agent after completion"""
|
||||
return """# BOOTSTRAP.md - 首次初始化引导
|
||||
|
||||
_你刚刚启动,这是你的第一次对话。_
|
||||
|
||||
## 对话流程
|
||||
|
||||
不要审问式地提问,自然地交流:
|
||||
|
||||
1. **表达初次启动的感觉** - 像是第一次睁开眼看到世界,带着好奇和期待
|
||||
2. **简短介绍能力**:一行说明你能帮助解决各种问题、管理计算机、使用各种技能等等,且拥有长期记忆能不断成长
|
||||
3. **询问核心问题**:
|
||||
- 你希望给我起个什么名字?
|
||||
- 我该怎么称呼你?
|
||||
- 你希望我们是什么样的交流风格?(一行列举选项:如专业严谨、轻松幽默、温暖友好、简洁高效等)
|
||||
4. **风格要求**:温暖自然、简洁清晰,整体控制在 100 字以内
|
||||
5. 能力介绍和交流风格选项都只要一行,保持精简
|
||||
6. 不要问太多其他信息(职业、时区等可以后续自然了解)
|
||||
|
||||
**重要**: 如果用户第一句话是具体的任务或提问,先回答他们的问题,然后在回复末尾自然地引导初始化(如:"顺便问一下,你想怎么称呼我?我该怎么叫你?")。
|
||||
|
||||
## 信息写入(必须严格执行)
|
||||
|
||||
每当用户提供了名字、称呼、风格等任何初始化信息时,**必须在当轮回复中立即调用 `edit` 工具写入文件**,不能只口头确认。
|
||||
|
||||
- `AGENT.md` — 你的名字、角色、性格、交流风格(每收到一条相关信息就立即更新对应字段)
|
||||
- `USER.md` — 用户的姓名、称呼、基本信息等
|
||||
|
||||
⚠️ 只说"记住了"而不调用 edit 写入 = 没有完成。信息只有写入文件才会被持久保存。
|
||||
|
||||
## 全部完成后
|
||||
|
||||
当 AGENT.md 和 USER.md 的核心字段都已填写后,用 bash 执行 `rm BOOTSTRAP.md` 删除此文件。你不再需要引导脚本了——你已经是你了。
|
||||
"""
|
||||
|
||||
|
||||
|
||||
20
agent/protocol/__init__.py
Normal file
20
agent/protocol/__init__.py
Normal file
@@ -0,0 +1,20 @@
|
||||
from .agent import Agent
|
||||
from .agent_stream import AgentStreamExecutor
|
||||
from .task import Task, TaskType, TaskStatus
|
||||
from .result import AgentResult, AgentAction, AgentActionType, ToolResult
|
||||
from .models import LLMModel, LLMRequest, ModelFactory
|
||||
|
||||
__all__ = [
|
||||
'Agent',
|
||||
'AgentStreamExecutor',
|
||||
'Task',
|
||||
'TaskType',
|
||||
'TaskStatus',
|
||||
'AgentResult',
|
||||
'AgentAction',
|
||||
'AgentActionType',
|
||||
'ToolResult',
|
||||
'LLMModel',
|
||||
'LLMRequest',
|
||||
'ModelFactory'
|
||||
]
|
||||
571
agent/protocol/agent.py
Normal file
571
agent/protocol/agent.py
Normal file
@@ -0,0 +1,571 @@
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
import threading
|
||||
|
||||
from common.log import logger
|
||||
from agent.protocol.models import LLMRequest, LLMModel
|
||||
from agent.protocol.agent_stream import AgentStreamExecutor
|
||||
from agent.protocol.result import AgentAction, AgentActionType, ToolResult, AgentResult
|
||||
from agent.tools.base_tool import BaseTool, ToolStage
|
||||
|
||||
|
||||
class Agent:
|
||||
def __init__(self, system_prompt: str, description: str = "AI Agent", model: LLMModel = None,
|
||||
tools=None, output_mode="print", max_steps=100, max_context_tokens=None,
|
||||
context_reserve_tokens=None, memory_manager=None, name: str = None,
|
||||
workspace_dir: str = None, skill_manager=None, enable_skills: bool = True,
|
||||
runtime_info: dict = None):
|
||||
"""
|
||||
Initialize the Agent with system prompt, model, description.
|
||||
|
||||
:param system_prompt: The system prompt for the agent.
|
||||
:param description: A description of the agent.
|
||||
:param model: An instance of LLMModel to be used by the agent.
|
||||
:param tools: Optional list of tools for the agent to use.
|
||||
:param output_mode: Control how execution progress is displayed:
|
||||
"print" for console output or "logger" for using logger
|
||||
:param max_steps: Maximum number of steps the agent can take (default: 100)
|
||||
:param max_context_tokens: Maximum tokens to keep in context (default: None, auto-calculated based on model)
|
||||
:param context_reserve_tokens: Reserve tokens for new requests (default: None, auto-calculated)
|
||||
:param memory_manager: Optional MemoryManager instance for memory operations
|
||||
:param name: [Deprecated] The name of the agent (no longer used in single-agent system)
|
||||
:param workspace_dir: Optional workspace directory for workspace-specific skills
|
||||
:param skill_manager: Optional SkillManager instance (will be created if None and enable_skills=True)
|
||||
:param enable_skills: Whether to enable skills support (default: True)
|
||||
:param runtime_info: Optional runtime info dict (with _get_current_time callable for dynamic time)
|
||||
"""
|
||||
self.name = name or "Agent"
|
||||
self.system_prompt = system_prompt
|
||||
self.model: LLMModel = model # Instance of LLMModel
|
||||
self.description = description
|
||||
self.tools: list = []
|
||||
self.max_steps = max_steps # max tool-call steps, default 100
|
||||
self.max_context_tokens = max_context_tokens # max tokens in context
|
||||
self.context_reserve_tokens = context_reserve_tokens # reserve tokens for new requests
|
||||
self.captured_actions = [] # Initialize captured actions list
|
||||
self.output_mode = output_mode
|
||||
self.last_usage = None # Store last API response usage info
|
||||
self.messages = [] # Unified message history for stream mode
|
||||
self.messages_lock = threading.Lock() # Lock for thread-safe message operations
|
||||
self.memory_manager = memory_manager # Memory manager for auto memory flush
|
||||
self.workspace_dir = workspace_dir # Workspace directory
|
||||
self.enable_skills = enable_skills # Skills enabled flag
|
||||
self.runtime_info = runtime_info # Runtime info for dynamic time update
|
||||
|
||||
# Initialize skill manager
|
||||
self.skill_manager = None
|
||||
if enable_skills:
|
||||
if skill_manager:
|
||||
self.skill_manager = skill_manager
|
||||
else:
|
||||
# Auto-create skill manager
|
||||
try:
|
||||
from agent.skills import SkillManager
|
||||
custom_dir = os.path.join(workspace_dir, "skills") if workspace_dir else None
|
||||
self.skill_manager = SkillManager(custom_dir=custom_dir)
|
||||
logger.debug(f"Initialized SkillManager with {len(self.skill_manager.skills)} skills")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to initialize SkillManager: {e}")
|
||||
|
||||
if tools:
|
||||
for tool in tools:
|
||||
self.add_tool(tool)
|
||||
|
||||
def add_tool(self, tool: BaseTool):
|
||||
"""
|
||||
Add a tool to the agent.
|
||||
|
||||
:param tool: The tool to add (either a tool instance or a tool name)
|
||||
"""
|
||||
# If tool is already an instance, use it directly
|
||||
tool.model = self.model
|
||||
self.tools.append(tool)
|
||||
|
||||
def get_skills_prompt(self, skill_filter=None) -> str:
|
||||
"""
|
||||
Get the skills prompt to append to system prompt.
|
||||
|
||||
:param skill_filter: Optional list of skill names to include
|
||||
:return: Formatted skills prompt or empty string
|
||||
"""
|
||||
if not self.skill_manager:
|
||||
return ""
|
||||
|
||||
try:
|
||||
return self.skill_manager.build_skills_prompt(skill_filter=skill_filter)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to build skills prompt: {e}")
|
||||
return ""
|
||||
|
||||
def get_full_system_prompt(self, skill_filter=None) -> str:
|
||||
"""
|
||||
Get the full system prompt including skills.
|
||||
|
||||
Note: Skills are now built into the system prompt by PromptBuilder,
|
||||
so we just return the base prompt directly. This method is kept for
|
||||
backward compatibility.
|
||||
|
||||
:param skill_filter: Optional list of skill names to include (deprecated)
|
||||
:return: Complete system prompt
|
||||
"""
|
||||
prompt = self.system_prompt
|
||||
|
||||
# Rebuild tool list section to reflect current self.tools
|
||||
prompt = self._rebuild_tool_list_section(prompt)
|
||||
|
||||
# If runtime_info contains dynamic time function, rebuild runtime section
|
||||
if self.runtime_info and callable(self.runtime_info.get('_get_current_time')):
|
||||
prompt = self._rebuild_runtime_section(prompt)
|
||||
|
||||
# Rebuild skills section to pick up newly installed/removed skills
|
||||
if self.skill_manager:
|
||||
prompt = self._rebuild_skills_section(prompt)
|
||||
|
||||
return prompt
|
||||
|
||||
def _rebuild_runtime_section(self, prompt: str) -> str:
|
||||
"""
|
||||
Rebuild runtime info section with current time.
|
||||
|
||||
This method dynamically updates the runtime info section by calling
|
||||
the _get_current_time function from runtime_info.
|
||||
|
||||
:param prompt: Original system prompt
|
||||
:return: Updated system prompt with current runtime info
|
||||
"""
|
||||
try:
|
||||
# Get current time dynamically
|
||||
time_info = self.runtime_info['_get_current_time']()
|
||||
|
||||
# Build new runtime section
|
||||
runtime_lines = [
|
||||
"\n## 运行时信息\n",
|
||||
"\n",
|
||||
f"当前时间: {time_info['time']} {time_info['weekday']} ({time_info['timezone']})\n",
|
||||
"\n"
|
||||
]
|
||||
|
||||
# Add other runtime info
|
||||
runtime_parts = []
|
||||
if self.runtime_info.get("model"):
|
||||
runtime_parts.append(f"模型={self.runtime_info['model']}")
|
||||
if self.runtime_info.get("workspace"):
|
||||
# Replace backslashes with forward slashes for Windows paths
|
||||
workspace_path = str(self.runtime_info['workspace']).replace('\\', '/')
|
||||
runtime_parts.append(f"工作空间={workspace_path}")
|
||||
if self.runtime_info.get("channel") and self.runtime_info.get("channel") != "web":
|
||||
runtime_parts.append(f"渠道={self.runtime_info['channel']}")
|
||||
|
||||
if runtime_parts:
|
||||
runtime_lines.append("运行时: " + " | ".join(runtime_parts) + "\n")
|
||||
runtime_lines.append("\n")
|
||||
|
||||
new_runtime_section = "".join(runtime_lines)
|
||||
|
||||
# Find and replace the runtime section
|
||||
import re
|
||||
pattern = r'\n## 运行时信息\s*\n.*?(?=\n##|\Z)'
|
||||
_repl = new_runtime_section.rstrip('\n')
|
||||
updated_prompt = re.sub(pattern, lambda m: _repl, prompt, flags=re.DOTALL)
|
||||
|
||||
return updated_prompt
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to rebuild runtime section: {e}")
|
||||
return prompt
|
||||
|
||||
def _rebuild_skills_section(self, prompt: str) -> str:
|
||||
"""
|
||||
Rebuild the <available_skills> block so that newly installed or
|
||||
removed skills are reflected without re-creating the agent.
|
||||
"""
|
||||
try:
|
||||
import re
|
||||
self.skill_manager.refresh_skills()
|
||||
new_skills_xml = self.skill_manager.build_skills_prompt()
|
||||
|
||||
old_block_pattern = r'<available_skills>.*?</available_skills>'
|
||||
has_old_block = re.search(old_block_pattern, prompt, flags=re.DOTALL)
|
||||
|
||||
# Extract the new <available_skills>...</available_skills> tag from the prompt
|
||||
new_block = ""
|
||||
if new_skills_xml and new_skills_xml.strip():
|
||||
m = re.search(old_block_pattern, new_skills_xml, flags=re.DOTALL)
|
||||
if m:
|
||||
new_block = m.group(0)
|
||||
|
||||
if has_old_block:
|
||||
replacement = new_block or "<available_skills>\n</available_skills>"
|
||||
# Use lambda to prevent re.sub from interpreting backslashes in replacement
|
||||
# (e.g. Windows paths like \LinkAI would be treated as bad escape sequences)
|
||||
prompt = re.sub(old_block_pattern, lambda m: replacement, prompt, flags=re.DOTALL)
|
||||
elif new_block:
|
||||
skills_header = "以下是可用技能:"
|
||||
idx = prompt.find(skills_header)
|
||||
if idx != -1:
|
||||
insert_pos = idx + len(skills_header)
|
||||
prompt = prompt[:insert_pos] + "\n" + new_block + prompt[insert_pos:]
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to rebuild skills section: {e}")
|
||||
return prompt
|
||||
|
||||
def _rebuild_tool_list_section(self, prompt: str) -> str:
|
||||
"""
|
||||
Rebuild the tool list inside the '## 工具系统' section so that it
|
||||
always reflects the current ``self.tools`` (handles dynamic add/remove
|
||||
of conditional tools like web_search).
|
||||
"""
|
||||
import re
|
||||
from agent.prompt.builder import _build_tooling_section
|
||||
|
||||
try:
|
||||
if not self.tools:
|
||||
return prompt
|
||||
|
||||
new_lines = _build_tooling_section(self.tools, "zh")
|
||||
new_section = "\n".join(new_lines).rstrip("\n")
|
||||
|
||||
# Replace existing tooling section
|
||||
pattern = r'## 工具系统\s*\n.*?(?=\n## |\Z)'
|
||||
updated = re.sub(pattern, lambda m: new_section, prompt, count=1, flags=re.DOTALL)
|
||||
return updated
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to rebuild tool list section: {e}")
|
||||
return prompt
|
||||
|
||||
def refresh_skills(self):
|
||||
"""Refresh the loaded skills."""
|
||||
if self.skill_manager:
|
||||
self.skill_manager.refresh_skills()
|
||||
logger.info(f"Refreshed skills: {len(self.skill_manager.skills)} skills loaded")
|
||||
|
||||
def list_skills(self):
|
||||
"""
|
||||
List all loaded skills.
|
||||
|
||||
:return: List of skill entries or empty list
|
||||
"""
|
||||
if not self.skill_manager:
|
||||
return []
|
||||
return self.skill_manager.list_skills()
|
||||
|
||||
def _get_model_context_window(self) -> int:
|
||||
"""
|
||||
Get the model's context window size in tokens.
|
||||
Auto-detect based on model name.
|
||||
|
||||
Model context windows:
|
||||
- Claude 3.5/3.7 Sonnet: 200K tokens
|
||||
- Claude 3 Opus: 200K tokens
|
||||
- GPT-4 Turbo/128K: 128K tokens
|
||||
- GPT-4: 8K-32K tokens
|
||||
- GPT-3.5: 16K tokens
|
||||
- DeepSeek: 64K tokens
|
||||
|
||||
:return: Context window size in tokens
|
||||
"""
|
||||
if self.model and hasattr(self.model, 'model'):
|
||||
model_name = self.model.model.lower()
|
||||
|
||||
# Claude models - 200K context
|
||||
if 'claude-3' in model_name or 'claude-sonnet' in model_name:
|
||||
return 200000
|
||||
|
||||
# GPT-4 models
|
||||
elif 'gpt-4' in model_name:
|
||||
if 'turbo' in model_name or '128k' in model_name:
|
||||
return 128000
|
||||
elif '32k' in model_name:
|
||||
return 32000
|
||||
else:
|
||||
return 8000
|
||||
|
||||
# GPT-3.5
|
||||
elif 'gpt-3.5' in model_name:
|
||||
if '16k' in model_name:
|
||||
return 16000
|
||||
else:
|
||||
return 4000
|
||||
|
||||
# DeepSeek
|
||||
elif 'deepseek' in model_name:
|
||||
return 64000
|
||||
|
||||
# Gemini models
|
||||
elif 'gemini' in model_name:
|
||||
if '2.0' in model_name or 'exp' in model_name:
|
||||
return 2000000 # Gemini 2.0: 2M tokens
|
||||
else:
|
||||
return 1000000 # Gemini 1.5: 1M tokens
|
||||
|
||||
# Default conservative value
|
||||
return 128000
|
||||
|
||||
def _get_context_reserve_tokens(self) -> int:
|
||||
"""
|
||||
Get the number of tokens to reserve for new requests.
|
||||
This prevents context overflow by keeping a buffer.
|
||||
|
||||
:return: Number of tokens to reserve
|
||||
"""
|
||||
if self.context_reserve_tokens is not None:
|
||||
return self.context_reserve_tokens
|
||||
|
||||
# Reserve ~10% of context window, with min 10K and max 200K
|
||||
context_window = self._get_model_context_window()
|
||||
reserve = int(context_window * 0.1)
|
||||
return max(10000, min(200000, reserve))
|
||||
|
||||
def _estimate_message_tokens(self, message: dict) -> int:
|
||||
"""
|
||||
Estimate token count for a message.
|
||||
|
||||
Uses chars/3 for Chinese-heavy content and chars/4 for ASCII-heavy content,
|
||||
plus per-block overhead for tool_use / tool_result structures.
|
||||
|
||||
:param message: Message dict with 'role' and 'content'
|
||||
:return: Estimated token count
|
||||
"""
|
||||
content = message.get('content', '')
|
||||
if isinstance(content, str):
|
||||
return max(1, self._estimate_text_tokens(content))
|
||||
elif isinstance(content, list):
|
||||
total_tokens = 0
|
||||
for part in content:
|
||||
if not isinstance(part, dict):
|
||||
continue
|
||||
block_type = part.get('type', '')
|
||||
if block_type == 'text':
|
||||
total_tokens += self._estimate_text_tokens(part.get('text', ''))
|
||||
elif block_type == 'image':
|
||||
total_tokens += 1200
|
||||
elif block_type == 'tool_use':
|
||||
# tool_use has id + name + input (JSON-encoded)
|
||||
total_tokens += 50 # overhead for structure
|
||||
input_data = part.get('input', {})
|
||||
if isinstance(input_data, dict):
|
||||
import json
|
||||
input_str = json.dumps(input_data, ensure_ascii=False)
|
||||
total_tokens += self._estimate_text_tokens(input_str)
|
||||
elif block_type == 'tool_result':
|
||||
# tool_result has tool_use_id + content
|
||||
total_tokens += 30 # overhead for structure
|
||||
result_content = part.get('content', '')
|
||||
if isinstance(result_content, str):
|
||||
total_tokens += self._estimate_text_tokens(result_content)
|
||||
else:
|
||||
# Unknown block type, estimate conservatively
|
||||
total_tokens += 10
|
||||
return max(1, total_tokens)
|
||||
return 1
|
||||
|
||||
@staticmethod
|
||||
def _estimate_text_tokens(text: str) -> int:
|
||||
"""
|
||||
Estimate token count for a text string.
|
||||
|
||||
Chinese / CJK characters typically use ~1.5 tokens each,
|
||||
while ASCII uses ~0.25 tokens per char (4 chars/token).
|
||||
We use a weighted average based on the character mix.
|
||||
|
||||
:param text: Input text
|
||||
:return: Estimated token count
|
||||
"""
|
||||
if not text:
|
||||
return 0
|
||||
# Count non-ASCII characters (CJK, emoji, etc.)
|
||||
non_ascii = sum(1 for c in text if ord(c) > 127)
|
||||
ascii_count = len(text) - non_ascii
|
||||
# CJK chars: ~1.5 tokens each; ASCII: ~0.25 tokens per char
|
||||
return int(non_ascii * 1.5 + ascii_count * 0.25) + 1
|
||||
|
||||
def _find_tool(self, tool_name: str):
|
||||
"""Find and return a tool with the specified name"""
|
||||
for tool in self.tools:
|
||||
if tool.name == tool_name:
|
||||
# Only pre-process stage tools can be actively called
|
||||
if tool.stage == ToolStage.PRE_PROCESS:
|
||||
tool.model = self.model
|
||||
tool.context = self # Set tool context
|
||||
return tool
|
||||
else:
|
||||
# If it's a post-process tool, return None to prevent direct calling
|
||||
logger.warning(f"Tool {tool_name} is a post-process tool and cannot be called directly.")
|
||||
return None
|
||||
return None
|
||||
|
||||
# output function based on mode
|
||||
def output(self, message="", end="\n"):
|
||||
if self.output_mode == "print":
|
||||
print(message, end=end)
|
||||
elif message:
|
||||
logger.info(message)
|
||||
|
||||
def _execute_post_process_tools(self):
|
||||
"""Execute all post-process stage tools"""
|
||||
# Get all post-process stage tools
|
||||
post_process_tools = [tool for tool in self.tools if tool.stage == ToolStage.POST_PROCESS]
|
||||
|
||||
# Execute each tool
|
||||
for tool in post_process_tools:
|
||||
# Set tool context
|
||||
tool.context = self
|
||||
|
||||
# Record start time for execution timing
|
||||
start_time = time.time()
|
||||
|
||||
# Execute tool (with empty parameters, tool will extract needed info from context)
|
||||
result = tool.execute({})
|
||||
|
||||
# Calculate execution time
|
||||
execution_time = time.time() - start_time
|
||||
|
||||
# Capture tool use for tracking
|
||||
self.capture_tool_use(
|
||||
tool_name=tool.name,
|
||||
input_params={}, # Post-process tools typically don't take parameters
|
||||
output=result.result,
|
||||
status=result.status,
|
||||
error_message=str(result.result) if result.status == "error" else None,
|
||||
execution_time=execution_time
|
||||
)
|
||||
|
||||
# Log result
|
||||
if result.status == "success":
|
||||
# Print tool execution result in the desired format
|
||||
self.output(f"\n🛠️ {tool.name}: {json.dumps(result.result)}")
|
||||
else:
|
||||
# Print failure in print mode
|
||||
self.output(f"\n🛠️ {tool.name}: {json.dumps({'status': 'error', 'message': str(result.result)})}")
|
||||
|
||||
def capture_tool_use(self, tool_name, input_params, output, status, thought=None, error_message=None,
|
||||
execution_time=0.0):
|
||||
"""
|
||||
Capture a tool use action.
|
||||
|
||||
:param thought: thought content
|
||||
:param tool_name: Name of the tool used
|
||||
:param input_params: Parameters passed to the tool
|
||||
:param output: Output from the tool
|
||||
:param status: Status of the tool execution
|
||||
:param error_message: Error message if the tool execution failed
|
||||
:param execution_time: Time taken to execute the tool
|
||||
"""
|
||||
tool_result = ToolResult(
|
||||
tool_name=tool_name,
|
||||
input_params=input_params,
|
||||
output=output,
|
||||
status=status,
|
||||
error_message=error_message,
|
||||
execution_time=execution_time
|
||||
)
|
||||
|
||||
action = AgentAction(
|
||||
agent_id=self.id if hasattr(self, 'id') else str(id(self)),
|
||||
agent_name=self.name,
|
||||
action_type=AgentActionType.TOOL_USE,
|
||||
tool_result=tool_result,
|
||||
thought=thought
|
||||
)
|
||||
|
||||
self.captured_actions.append(action)
|
||||
|
||||
return action
|
||||
|
||||
def run_stream(self, user_message: str, on_event=None, clear_history: bool = False, skill_filter=None) -> str:
|
||||
"""
|
||||
Execute single agent task with streaming (based on tool-call)
|
||||
|
||||
This method supports:
|
||||
- Streaming output
|
||||
- Multi-turn reasoning based on tool-call
|
||||
- Event callbacks
|
||||
- Persistent conversation history across calls
|
||||
|
||||
Args:
|
||||
user_message: User message
|
||||
on_event: Event callback function callback(event: dict)
|
||||
event = {"type": str, "timestamp": float, "data": dict}
|
||||
clear_history: If True, clear conversation history before this call (default: False)
|
||||
skill_filter: Optional list of skill names to include in this run
|
||||
|
||||
Returns:
|
||||
Final response text
|
||||
|
||||
Example:
|
||||
# Multi-turn conversation with memory
|
||||
response1 = agent.run_stream("My name is Alice")
|
||||
response2 = agent.run_stream("What's my name?") # Will remember Alice
|
||||
|
||||
# Single-turn without memory
|
||||
response = agent.run_stream("Hello", clear_history=True)
|
||||
"""
|
||||
# Clear history if requested
|
||||
if clear_history:
|
||||
with self.messages_lock:
|
||||
self.messages = []
|
||||
|
||||
# Get model to use
|
||||
if not self.model:
|
||||
raise ValueError("No model available for agent")
|
||||
|
||||
# Get full system prompt with skills
|
||||
full_system_prompt = self.get_full_system_prompt(skill_filter=skill_filter)
|
||||
|
||||
# Create a copy of messages for this execution to avoid concurrent modification
|
||||
# Record the original length to track which messages are new
|
||||
with self.messages_lock:
|
||||
messages_copy = self.messages.copy()
|
||||
original_length = len(self.messages)
|
||||
|
||||
# Get max_context_turns from config
|
||||
from config import conf
|
||||
max_context_turns = conf().get("agent_max_context_turns", 20)
|
||||
|
||||
# Create stream executor with copied message history
|
||||
executor = AgentStreamExecutor(
|
||||
agent=self,
|
||||
model=self.model,
|
||||
system_prompt=full_system_prompt,
|
||||
tools=self.tools,
|
||||
max_turns=self.max_steps,
|
||||
on_event=on_event,
|
||||
messages=messages_copy, # Pass copied message history
|
||||
max_context_turns=max_context_turns
|
||||
)
|
||||
|
||||
# Execute
|
||||
try:
|
||||
response = executor.run_stream(user_message)
|
||||
except Exception:
|
||||
# If executor cleared its messages (context overflow / message format error),
|
||||
# sync that back to the Agent's own message list so the next request
|
||||
# starts fresh instead of hitting the same overflow forever.
|
||||
if len(executor.messages) == 0:
|
||||
with self.messages_lock:
|
||||
self.messages.clear()
|
||||
logger.info("[Agent] Cleared Agent message history after executor recovery")
|
||||
raise
|
||||
|
||||
# Sync executor's messages back to agent (thread-safe).
|
||||
# If the executor trimmed context, its message list is shorter than
|
||||
# original_length, so we must replace rather than append.
|
||||
with self.messages_lock:
|
||||
self.messages = list(executor.messages)
|
||||
# Track messages added in this run (user query + all assistant/tool messages)
|
||||
# original_length may exceed executor.messages length after trimming
|
||||
trim_adjusted_start = min(original_length, len(executor.messages))
|
||||
self._last_run_new_messages = list(executor.messages[trim_adjusted_start:])
|
||||
|
||||
# Store executor reference for agent_bridge to access files_to_send
|
||||
self.stream_executor = executor
|
||||
|
||||
# Execute all post-process tools
|
||||
self._execute_post_process_tools()
|
||||
|
||||
return response
|
||||
|
||||
def clear_history(self):
|
||||
"""Clear conversation history and captured actions"""
|
||||
self.messages = []
|
||||
self.captured_actions = []
|
||||
1331
agent/protocol/agent_stream.py
Normal file
1331
agent/protocol/agent_stream.py
Normal file
File diff suppressed because it is too large
Load Diff
27
agent/protocol/context.py
Normal file
27
agent/protocol/context.py
Normal file
@@ -0,0 +1,27 @@
|
||||
class TeamContext:
|
||||
def __init__(self, name: str, description: str, rule: str, agents: list, max_steps: int = 100):
|
||||
"""
|
||||
Initialize the TeamContext with a name, description, rules, a list of agents, and a user question.
|
||||
:param name: The name of the group context.
|
||||
:param description: A description of the group context.
|
||||
:param rule: The rules governing the group context.
|
||||
:param agents: A list of agents in the context.
|
||||
"""
|
||||
self.name = name
|
||||
self.description = description
|
||||
self.rule = rule
|
||||
self.agents = agents
|
||||
self.user_task = "" # For backward compatibility
|
||||
self.task = None # Will be a Task instance
|
||||
self.model = None # Will be an instance of LLMModel
|
||||
self.task_short_name = None # Store the task directory name
|
||||
# List of agents that have been executed
|
||||
self.agent_outputs: list = []
|
||||
self.current_steps = 0
|
||||
self.max_steps = max_steps
|
||||
|
||||
|
||||
class AgentOutput:
|
||||
def __init__(self, agent_name: str, output: str):
|
||||
self.agent_name = agent_name
|
||||
self.output = output
|
||||
240
agent/protocol/message_utils.py
Normal file
240
agent/protocol/message_utils.py
Normal file
@@ -0,0 +1,240 @@
|
||||
"""
|
||||
Message sanitizer — fix broken tool_use / tool_result pairs.
|
||||
|
||||
Provides two public helpers that can be reused across agent_stream.py
|
||||
and any bot that converts messages to OpenAI format:
|
||||
|
||||
1. sanitize_claude_messages(messages)
|
||||
Operates on the internal Claude-format message list (in-place).
|
||||
|
||||
2. drop_orphaned_tool_results_openai(messages)
|
||||
Operates on an already-converted OpenAI-format message list,
|
||||
returning a cleaned copy.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Dict, List, Set
|
||||
|
||||
from common.log import logger
|
||||
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# Claude-format sanitizer (used by agent_stream)
|
||||
# ------------------------------------------------------------------ #
|
||||
|
||||
def sanitize_claude_messages(messages: List[Dict]) -> int:
|
||||
"""
|
||||
Validate and fix a Claude-format message list **in-place**.
|
||||
|
||||
Fixes handled:
|
||||
- Trailing assistant message with tool_use but no following tool_result
|
||||
- Leading orphaned tool_result user messages
|
||||
- Mid-list tool_result blocks whose tool_use_id has no matching
|
||||
tool_use in any preceding assistant message
|
||||
|
||||
Returns the number of messages / blocks removed.
|
||||
"""
|
||||
if not messages:
|
||||
return 0
|
||||
|
||||
removed = 0
|
||||
|
||||
# 1. Remove trailing incomplete tool_use assistant messages
|
||||
while messages:
|
||||
last = messages[-1]
|
||||
if last.get("role") != "assistant":
|
||||
break
|
||||
content = last.get("content", [])
|
||||
if isinstance(content, list) and any(
|
||||
isinstance(b, dict) and b.get("type") == "tool_use"
|
||||
for b in content
|
||||
):
|
||||
logger.warning("⚠️ Removing trailing incomplete tool_use assistant message")
|
||||
messages.pop()
|
||||
removed += 1
|
||||
else:
|
||||
break
|
||||
|
||||
# 2. Remove leading orphaned tool_result user messages
|
||||
while messages:
|
||||
first = messages[0]
|
||||
if first.get("role") != "user":
|
||||
break
|
||||
content = first.get("content", [])
|
||||
if isinstance(content, list) and _has_block_type(content, "tool_result") \
|
||||
and not _has_block_type(content, "text"):
|
||||
logger.warning("⚠️ Removing leading orphaned tool_result user message")
|
||||
messages.pop(0)
|
||||
removed += 1
|
||||
else:
|
||||
break
|
||||
|
||||
# 3. Iteratively remove unmatched tool_use / tool_result until stable.
|
||||
# Removing one broken message can orphan others (e.g. an assistant msg
|
||||
# with both matched and unmatched tool_use — deleting it orphans the
|
||||
# previously-matched tool_result). Loop until clean.
|
||||
for _ in range(5):
|
||||
use_ids: Set[str] = set()
|
||||
result_ids: Set[str] = set()
|
||||
for msg in messages:
|
||||
for block in (msg.get("content") or []):
|
||||
if not isinstance(block, dict):
|
||||
continue
|
||||
if block.get("type") == "tool_use" and block.get("id"):
|
||||
use_ids.add(block["id"])
|
||||
elif block.get("type") == "tool_result" and block.get("tool_use_id"):
|
||||
result_ids.add(block["tool_use_id"])
|
||||
|
||||
bad_use = use_ids - result_ids
|
||||
bad_result = result_ids - use_ids
|
||||
if not bad_use and not bad_result:
|
||||
break
|
||||
|
||||
pass_removed = 0
|
||||
i = 0
|
||||
while i < len(messages):
|
||||
msg = messages[i]
|
||||
role = msg.get("role")
|
||||
content = msg.get("content", [])
|
||||
if not isinstance(content, list):
|
||||
i += 1
|
||||
continue
|
||||
|
||||
if role == "assistant" and bad_use and any(
|
||||
isinstance(b, dict) and b.get("type") == "tool_use"
|
||||
and b.get("id") in bad_use for b in content
|
||||
):
|
||||
logger.warning(f"⚠️ Removing assistant msg with unmatched tool_use")
|
||||
messages.pop(i)
|
||||
pass_removed += 1
|
||||
continue
|
||||
|
||||
if role == "user" and bad_result and _has_block_type(content, "tool_result"):
|
||||
has_bad = any(
|
||||
isinstance(b, dict) and b.get("type") == "tool_result"
|
||||
and b.get("tool_use_id") in bad_result for b in content
|
||||
)
|
||||
if has_bad:
|
||||
if not _has_block_type(content, "text"):
|
||||
logger.warning(f"⚠️ Removing user msg with unmatched tool_result")
|
||||
messages.pop(i)
|
||||
pass_removed += 1
|
||||
continue
|
||||
else:
|
||||
before = len(content)
|
||||
msg["content"] = [
|
||||
b for b in content
|
||||
if not (isinstance(b, dict) and b.get("type") == "tool_result"
|
||||
and b.get("tool_use_id") in bad_result)
|
||||
]
|
||||
pass_removed += before - len(msg["content"])
|
||||
|
||||
i += 1
|
||||
|
||||
removed += pass_removed
|
||||
if pass_removed == 0:
|
||||
break
|
||||
|
||||
if removed:
|
||||
logger.info(f"🔧 Message validation: removed {removed} broken message(s)")
|
||||
return removed
|
||||
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# OpenAI-format sanitizer (used by minimax_bot, openai_compatible_bot)
|
||||
# ------------------------------------------------------------------ #
|
||||
|
||||
def drop_orphaned_tool_results_openai(messages: List[Dict]) -> List[Dict]:
|
||||
"""
|
||||
Return a copy of *messages* (OpenAI format) with any ``role=tool``
|
||||
messages removed if their ``tool_call_id`` does not match a
|
||||
``tool_calls[].id`` in a preceding assistant message.
|
||||
"""
|
||||
known_ids: Set[str] = set()
|
||||
cleaned: List[Dict] = []
|
||||
for msg in messages:
|
||||
if msg.get("role") == "assistant" and msg.get("tool_calls"):
|
||||
for tc in msg["tool_calls"]:
|
||||
tc_id = tc.get("id", "")
|
||||
if tc_id:
|
||||
known_ids.add(tc_id)
|
||||
|
||||
if msg.get("role") == "tool":
|
||||
ref_id = msg.get("tool_call_id", "")
|
||||
if ref_id and ref_id not in known_ids:
|
||||
logger.warning(
|
||||
f"[MessageSanitizer] Dropping orphaned tool result "
|
||||
f"(tool_call_id={ref_id} not in known ids)"
|
||||
)
|
||||
continue
|
||||
cleaned.append(msg)
|
||||
return cleaned
|
||||
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# Internal helpers
|
||||
# ------------------------------------------------------------------ #
|
||||
|
||||
def _has_block_type(content: list, block_type: str) -> bool:
|
||||
return any(
|
||||
isinstance(b, dict) and b.get("type") == block_type
|
||||
for b in content
|
||||
)
|
||||
|
||||
|
||||
def _extract_text_from_content(content) -> str:
|
||||
"""Extract plain text from a message content field (str or list of blocks)."""
|
||||
if isinstance(content, str):
|
||||
return content.strip()
|
||||
if isinstance(content, list):
|
||||
parts = [
|
||||
b.get("text", "")
|
||||
for b in content
|
||||
if isinstance(b, dict) and b.get("type") == "text"
|
||||
]
|
||||
return "\n".join(p for p in parts if p).strip()
|
||||
return ""
|
||||
|
||||
|
||||
def compress_turn_to_text_only(turn: Dict) -> Dict:
|
||||
"""
|
||||
Compress a full turn (with tool_use/tool_result chains) into a lightweight
|
||||
text-only turn that keeps only the first user text and the last assistant text.
|
||||
|
||||
This preserves the conversational context (what the user asked and what the
|
||||
agent concluded) while stripping out the bulky intermediate tool interactions.
|
||||
|
||||
Returns a new turn dict with a ``messages`` list; the original is not mutated.
|
||||
"""
|
||||
user_text = ""
|
||||
last_assistant_text = ""
|
||||
|
||||
for msg in turn["messages"]:
|
||||
role = msg.get("role")
|
||||
content = msg.get("content", [])
|
||||
|
||||
if role == "user":
|
||||
if isinstance(content, list) and _has_block_type(content, "tool_result"):
|
||||
continue
|
||||
if not user_text:
|
||||
user_text = _extract_text_from_content(content)
|
||||
|
||||
elif role == "assistant":
|
||||
text = _extract_text_from_content(content)
|
||||
if text:
|
||||
last_assistant_text = text
|
||||
|
||||
compressed_messages = []
|
||||
if user_text:
|
||||
compressed_messages.append({
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": user_text}]
|
||||
})
|
||||
if last_assistant_text:
|
||||
compressed_messages.append({
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "text": last_assistant_text}]
|
||||
})
|
||||
|
||||
return {"messages": compressed_messages}
|
||||
57
agent/protocol/models.py
Normal file
57
agent/protocol/models.py
Normal file
@@ -0,0 +1,57 @@
|
||||
"""
|
||||
Models module for agent system.
|
||||
Provides basic model classes needed by tools and bridge integration.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
|
||||
class LLMRequest:
|
||||
"""Request model for LLM operations"""
|
||||
|
||||
def __init__(self, messages: List[Dict[str, str]] = None, model: Optional[str] = None,
|
||||
temperature: float = 0.7, max_tokens: Optional[int] = None,
|
||||
stream: bool = False, tools: Optional[List] = None, **kwargs):
|
||||
self.messages = messages or []
|
||||
self.model = model
|
||||
self.temperature = temperature
|
||||
self.max_tokens = max_tokens
|
||||
self.stream = stream
|
||||
self.tools = tools
|
||||
# Allow extra attributes
|
||||
for key, value in kwargs.items():
|
||||
setattr(self, key, value)
|
||||
|
||||
|
||||
class LLMModel:
|
||||
"""Base class for LLM models"""
|
||||
|
||||
def __init__(self, model: str = None, **kwargs):
|
||||
self.model = model
|
||||
self.config = kwargs
|
||||
|
||||
def call(self, request: LLMRequest):
|
||||
"""
|
||||
Call the model with a request.
|
||||
This is a placeholder implementation.
|
||||
"""
|
||||
raise NotImplementedError("LLMModel.call not implemented in this context")
|
||||
|
||||
def call_stream(self, request: LLMRequest):
|
||||
"""
|
||||
Call the model with streaming.
|
||||
This is a placeholder implementation.
|
||||
"""
|
||||
raise NotImplementedError("LLMModel.call_stream not implemented in this context")
|
||||
|
||||
|
||||
class ModelFactory:
|
||||
"""Factory for creating model instances"""
|
||||
|
||||
@staticmethod
|
||||
def create_model(model_type: str, **kwargs):
|
||||
"""
|
||||
Create a model instance based on type.
|
||||
This is a placeholder implementation.
|
||||
"""
|
||||
raise NotImplementedError("ModelFactory.create_model not implemented in this context")
|
||||
97
agent/protocol/result.py
Normal file
97
agent/protocol/result.py
Normal file
@@ -0,0 +1,97 @@
|
||||
from __future__ import annotations
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import List, Dict, Any, Optional
|
||||
|
||||
from agent.protocol.task import Task, TaskStatus
|
||||
|
||||
|
||||
class AgentActionType(Enum):
|
||||
"""Enum representing different types of agent actions."""
|
||||
TOOL_USE = "tool_use"
|
||||
THINKING = "thinking"
|
||||
FINAL_ANSWER = "final_answer"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolResult:
|
||||
"""
|
||||
Represents the result of a tool use.
|
||||
|
||||
Attributes:
|
||||
tool_name: Name of the tool used
|
||||
input_params: Parameters passed to the tool
|
||||
output: Output from the tool
|
||||
status: Status of the tool execution (success/error)
|
||||
error_message: Error message if the tool execution failed
|
||||
execution_time: Time taken to execute the tool
|
||||
"""
|
||||
tool_name: str
|
||||
input_params: Dict[str, Any]
|
||||
output: Any
|
||||
status: str
|
||||
error_message: Optional[str] = None
|
||||
execution_time: float = 0.0
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentAction:
|
||||
"""
|
||||
Represents an action taken by an agent.
|
||||
|
||||
Attributes:
|
||||
id: Unique identifier for the action
|
||||
agent_id: ID of the agent that performed the action
|
||||
agent_name: Name of the agent that performed the action
|
||||
action_type: Type of action (tool use, thinking, final answer)
|
||||
content: Content of the action (thought content, final answer content)
|
||||
tool_result: Tool use details if action_type is TOOL_USE
|
||||
timestamp: When the action was performed
|
||||
"""
|
||||
agent_id: str
|
||||
agent_name: str
|
||||
action_type: AgentActionType
|
||||
id: str = field(default_factory=lambda: str(uuid.uuid4()))
|
||||
content: str = ""
|
||||
tool_result: Optional[ToolResult] = None
|
||||
thought: Optional[str] = None
|
||||
timestamp: float = field(default_factory=time.time)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentResult:
|
||||
"""
|
||||
Represents the result of an agent's execution.
|
||||
|
||||
Attributes:
|
||||
final_answer: The final answer provided by the agent
|
||||
step_count: Number of steps taken by the agent
|
||||
status: Status of the execution (success/error)
|
||||
error_message: Error message if execution failed
|
||||
"""
|
||||
final_answer: str
|
||||
step_count: int
|
||||
status: str = "success"
|
||||
error_message: Optional[str] = None
|
||||
|
||||
@classmethod
|
||||
def success(cls, final_answer: str, step_count: int) -> "AgentResult":
|
||||
"""Create a successful result"""
|
||||
return cls(final_answer=final_answer, step_count=step_count)
|
||||
|
||||
@classmethod
|
||||
def error(cls, error_message: str, step_count: int = 0) -> "AgentResult":
|
||||
"""Create an error result"""
|
||||
return cls(
|
||||
final_answer=f"Error: {error_message}",
|
||||
step_count=step_count,
|
||||
status="error",
|
||||
error_message=error_message
|
||||
)
|
||||
|
||||
@property
|
||||
def is_error(self) -> bool:
|
||||
"""Check if the result represents an error"""
|
||||
return self.status == "error"
|
||||
96
agent/protocol/task.py
Normal file
96
agent/protocol/task.py
Normal file
@@ -0,0 +1,96 @@
|
||||
from __future__ import annotations
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Dict, Any, List
|
||||
|
||||
|
||||
class TaskType(Enum):
|
||||
"""Enum representing different types of tasks."""
|
||||
TEXT = "text"
|
||||
IMAGE = "image"
|
||||
VIDEO = "video"
|
||||
AUDIO = "audio"
|
||||
FILE = "file"
|
||||
MIXED = "mixed"
|
||||
|
||||
|
||||
class TaskStatus(Enum):
|
||||
"""Enum representing the status of a task."""
|
||||
INIT = "init" # Initial state
|
||||
PROCESSING = "processing" # In progress
|
||||
COMPLETED = "completed" # Completed
|
||||
FAILED = "failed" # Failed
|
||||
|
||||
|
||||
@dataclass
|
||||
class Task:
|
||||
"""
|
||||
Represents a task to be processed by an agent.
|
||||
|
||||
Attributes:
|
||||
id: Unique identifier for the task
|
||||
content: The primary text content of the task
|
||||
type: Type of the task
|
||||
status: Current status of the task
|
||||
created_at: Timestamp when the task was created
|
||||
updated_at: Timestamp when the task was last updated
|
||||
metadata: Additional metadata for the task
|
||||
images: List of image URLs or base64 encoded images
|
||||
videos: List of video URLs
|
||||
audios: List of audio URLs or base64 encoded audios
|
||||
files: List of file URLs or paths
|
||||
"""
|
||||
id: str = field(default_factory=lambda: str(uuid.uuid4()))
|
||||
content: str = ""
|
||||
type: TaskType = TaskType.TEXT
|
||||
status: TaskStatus = TaskStatus.INIT
|
||||
created_at: float = field(default_factory=time.time)
|
||||
updated_at: float = field(default_factory=time.time)
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
# Media content
|
||||
images: List[str] = field(default_factory=list)
|
||||
videos: List[str] = field(default_factory=list)
|
||||
audios: List[str] = field(default_factory=list)
|
||||
files: List[str] = field(default_factory=list)
|
||||
|
||||
def __init__(self, content: str = "", **kwargs):
|
||||
"""
|
||||
Initialize a Task with content and optional keyword arguments.
|
||||
|
||||
Args:
|
||||
content: The text content of the task
|
||||
**kwargs: Additional attributes to set
|
||||
"""
|
||||
self.id = kwargs.get('id', str(uuid.uuid4()))
|
||||
self.content = content
|
||||
self.type = kwargs.get('type', TaskType.TEXT)
|
||||
self.status = kwargs.get('status', TaskStatus.INIT)
|
||||
self.created_at = kwargs.get('created_at', time.time())
|
||||
self.updated_at = kwargs.get('updated_at', time.time())
|
||||
self.metadata = kwargs.get('metadata', {})
|
||||
self.images = kwargs.get('images', [])
|
||||
self.videos = kwargs.get('videos', [])
|
||||
self.audios = kwargs.get('audios', [])
|
||||
self.files = kwargs.get('files', [])
|
||||
|
||||
def get_text(self) -> str:
|
||||
"""
|
||||
Get the text content of the task.
|
||||
|
||||
Returns:
|
||||
The text content
|
||||
"""
|
||||
return self.content
|
||||
|
||||
def update_status(self, status: TaskStatus) -> None:
|
||||
"""
|
||||
Update the status of the task.
|
||||
|
||||
Args:
|
||||
status: The new status
|
||||
"""
|
||||
self.status = status
|
||||
self.updated_at = time.time()
|
||||
31
agent/skills/__init__.py
Normal file
31
agent/skills/__init__.py
Normal file
@@ -0,0 +1,31 @@
|
||||
"""
|
||||
Skills module for agent system.
|
||||
|
||||
This module provides the framework for loading, managing, and executing skills.
|
||||
Skills are markdown files with frontmatter that provide specialized instructions
|
||||
for specific tasks.
|
||||
"""
|
||||
|
||||
from agent.skills.types import (
|
||||
Skill,
|
||||
SkillEntry,
|
||||
SkillMetadata,
|
||||
SkillInstallSpec,
|
||||
LoadSkillsResult,
|
||||
)
|
||||
from agent.skills.loader import SkillLoader
|
||||
from agent.skills.manager import SkillManager
|
||||
from agent.skills.service import SkillService
|
||||
from agent.skills.formatter import format_skills_for_prompt
|
||||
|
||||
__all__ = [
|
||||
"Skill",
|
||||
"SkillEntry",
|
||||
"SkillMetadata",
|
||||
"SkillInstallSpec",
|
||||
"LoadSkillsResult",
|
||||
"SkillLoader",
|
||||
"SkillManager",
|
||||
"SkillService",
|
||||
"format_skills_for_prompt",
|
||||
]
|
||||
189
agent/skills/config.py
Normal file
189
agent/skills/config.py
Normal file
@@ -0,0 +1,189 @@
|
||||
"""
|
||||
Configuration support for skills.
|
||||
"""
|
||||
|
||||
import os
|
||||
import platform
|
||||
from typing import Dict, Optional, List
|
||||
from agent.skills.types import SkillEntry
|
||||
|
||||
|
||||
def resolve_runtime_platform() -> str:
|
||||
"""Get the current runtime platform."""
|
||||
return platform.system().lower()
|
||||
|
||||
|
||||
def has_binary(bin_name: str) -> bool:
|
||||
"""
|
||||
Check if a binary is available in PATH.
|
||||
|
||||
:param bin_name: Binary name to check
|
||||
:return: True if binary is available
|
||||
"""
|
||||
import shutil
|
||||
return shutil.which(bin_name) is not None
|
||||
|
||||
|
||||
def has_any_binary(bin_names: List[str]) -> bool:
|
||||
"""
|
||||
Check if any of the given binaries is available.
|
||||
|
||||
:param bin_names: List of binary names to check
|
||||
:return: True if at least one binary is available
|
||||
"""
|
||||
return any(has_binary(bin_name) for bin_name in bin_names)
|
||||
|
||||
|
||||
def has_env_var(env_name: str) -> bool:
|
||||
"""
|
||||
Check if an environment variable is set.
|
||||
|
||||
:param env_name: Environment variable name
|
||||
:return: True if environment variable is set
|
||||
"""
|
||||
return env_name in os.environ and bool(os.environ[env_name].strip())
|
||||
|
||||
|
||||
def get_skill_config(config: Optional[Dict], skill_name: str) -> Optional[Dict]:
|
||||
"""
|
||||
Get skill-specific configuration.
|
||||
|
||||
:param config: Global configuration dictionary
|
||||
:param skill_name: Name of the skill
|
||||
:return: Skill configuration or None
|
||||
"""
|
||||
if not config:
|
||||
return None
|
||||
|
||||
skills_config = config.get('skills', {})
|
||||
if not isinstance(skills_config, dict):
|
||||
return None
|
||||
|
||||
entries = skills_config.get('entries', {})
|
||||
if not isinstance(entries, dict):
|
||||
return None
|
||||
|
||||
return entries.get(skill_name)
|
||||
|
||||
|
||||
def should_include_skill(
|
||||
entry: SkillEntry,
|
||||
config: Optional[Dict] = None,
|
||||
current_platform: Optional[str] = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Determine if a skill should be included based on requirements.
|
||||
|
||||
Simple rule: Skills are auto-enabled if their requirements are met.
|
||||
- Has required API keys → enabled
|
||||
- Missing API keys → disabled
|
||||
- Wrong keys → enabled but will fail at runtime (LLM will handle error)
|
||||
|
||||
:param entry: SkillEntry to check
|
||||
:param config: Configuration dictionary (currently unused, reserved for future)
|
||||
:param current_platform: Current platform (default: auto-detect)
|
||||
:return: True if skill should be included
|
||||
"""
|
||||
metadata = entry.metadata
|
||||
|
||||
# No metadata = always include (no requirements)
|
||||
if not metadata:
|
||||
return True
|
||||
|
||||
# Check platform requirements (can't work on wrong platform)
|
||||
if metadata.os:
|
||||
platform_name = current_platform or resolve_runtime_platform()
|
||||
# Map common platform names
|
||||
platform_map = {
|
||||
'darwin': 'darwin',
|
||||
'linux': 'linux',
|
||||
'windows': 'win32',
|
||||
}
|
||||
normalized_platform = platform_map.get(platform_name, platform_name)
|
||||
|
||||
if normalized_platform not in metadata.os:
|
||||
return False
|
||||
|
||||
# If skill has 'always: true', include it regardless of other requirements
|
||||
if metadata.always:
|
||||
return True
|
||||
|
||||
# Check requirements
|
||||
if metadata.requires:
|
||||
# Check required binaries (all must be present)
|
||||
required_bins = metadata.requires.get('bins', [])
|
||||
if required_bins:
|
||||
if not all(has_binary(bin_name) for bin_name in required_bins):
|
||||
return False
|
||||
|
||||
# Check anyBins (at least one must be present)
|
||||
any_bins = metadata.requires.get('anyBins', [])
|
||||
if any_bins:
|
||||
if not has_any_binary(any_bins):
|
||||
return False
|
||||
|
||||
# Check environment variables (API keys)
|
||||
# All required env vars must be set
|
||||
required_env = metadata.requires.get('env', [])
|
||||
if required_env:
|
||||
for env_name in required_env:
|
||||
if not has_env_var(env_name):
|
||||
return False
|
||||
|
||||
# Check anyEnv (at least one must be present)
|
||||
any_env = metadata.requires.get('anyEnv', [])
|
||||
if any_env:
|
||||
if not any(has_env_var(e) for e in any_env):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def is_config_path_truthy(config: Dict, path: str) -> bool:
|
||||
"""
|
||||
Check if a config path resolves to a truthy value.
|
||||
|
||||
:param config: Configuration dictionary
|
||||
:param path: Dot-separated path (e.g., 'skills.enabled')
|
||||
:return: True if path resolves to truthy value
|
||||
"""
|
||||
parts = path.split('.')
|
||||
current = config
|
||||
|
||||
for part in parts:
|
||||
if not isinstance(current, dict):
|
||||
return False
|
||||
current = current.get(part)
|
||||
if current is None:
|
||||
return False
|
||||
|
||||
# Check if value is truthy
|
||||
if isinstance(current, bool):
|
||||
return current
|
||||
if isinstance(current, (int, float)):
|
||||
return current != 0
|
||||
if isinstance(current, str):
|
||||
return bool(current.strip())
|
||||
|
||||
return bool(current)
|
||||
|
||||
|
||||
def resolve_config_path(config: Dict, path: str):
|
||||
"""
|
||||
Resolve a dot-separated config path to its value.
|
||||
|
||||
:param config: Configuration dictionary
|
||||
:param path: Dot-separated path
|
||||
:return: Value at path or None
|
||||
"""
|
||||
parts = path.split('.')
|
||||
current = config
|
||||
|
||||
for part in parts:
|
||||
if not isinstance(current, dict):
|
||||
return None
|
||||
current = current.get(part)
|
||||
if current is None:
|
||||
return None
|
||||
|
||||
return current
|
||||
61
agent/skills/formatter.py
Normal file
61
agent/skills/formatter.py
Normal file
@@ -0,0 +1,61 @@
|
||||
"""
|
||||
Skill formatter for generating prompts from skills.
|
||||
"""
|
||||
|
||||
from typing import List
|
||||
from agent.skills.types import Skill, SkillEntry
|
||||
|
||||
|
||||
def format_skills_for_prompt(skills: List[Skill]) -> str:
|
||||
"""
|
||||
Format skills for inclusion in a system prompt.
|
||||
|
||||
Uses XML format per Agent Skills standard.
|
||||
Skills with disable_model_invocation=True are excluded.
|
||||
|
||||
:param skills: List of skills to format
|
||||
:return: Formatted prompt text
|
||||
"""
|
||||
# Filter out skills that should not be invoked by the model
|
||||
visible_skills = [s for s in skills if not s.disable_model_invocation]
|
||||
|
||||
if not visible_skills:
|
||||
return ""
|
||||
|
||||
lines = [
|
||||
"",
|
||||
"<available_skills>",
|
||||
]
|
||||
|
||||
for skill in visible_skills:
|
||||
lines.append(" <skill>")
|
||||
lines.append(f" <name>{_escape_xml(skill.name)}</name>")
|
||||
lines.append(f" <description>{_escape_xml(skill.description)}</description>")
|
||||
lines.append(f" <location>{_escape_xml(skill.file_path)}</location>")
|
||||
lines.append(f" <base_dir>{_escape_xml(skill.base_dir)}</base_dir>")
|
||||
lines.append(" </skill>")
|
||||
|
||||
lines.append("</available_skills>")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def format_skill_entries_for_prompt(entries: List[SkillEntry]) -> str:
|
||||
"""
|
||||
Format skill entries for inclusion in a system prompt.
|
||||
|
||||
:param entries: List of skill entries to format
|
||||
:return: Formatted prompt text
|
||||
"""
|
||||
skills = [entry.skill for entry in entries]
|
||||
return format_skills_for_prompt(skills)
|
||||
|
||||
|
||||
def _escape_xml(text: str) -> str:
|
||||
"""Escape XML special characters."""
|
||||
return (text
|
||||
.replace('&', '&')
|
||||
.replace('<', '<')
|
||||
.replace('>', '>')
|
||||
.replace('"', '"')
|
||||
.replace("'", '''))
|
||||
172
agent/skills/frontmatter.py
Normal file
172
agent/skills/frontmatter.py
Normal file
@@ -0,0 +1,172 @@
|
||||
"""
|
||||
Frontmatter parsing for skills.
|
||||
"""
|
||||
|
||||
import re
|
||||
import json
|
||||
from typing import Dict, Any, Optional, List
|
||||
from agent.skills.types import SkillMetadata, SkillInstallSpec
|
||||
|
||||
|
||||
def parse_frontmatter(content: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Parse YAML-style frontmatter from markdown content.
|
||||
|
||||
Returns a dictionary of frontmatter fields.
|
||||
"""
|
||||
frontmatter = {}
|
||||
|
||||
# Match frontmatter block between --- markers
|
||||
match = re.match(r'^---\s*\n(.*?)\n---\s*\n', content, re.DOTALL)
|
||||
if not match:
|
||||
return frontmatter
|
||||
|
||||
frontmatter_text = match.group(1)
|
||||
|
||||
# Try to use PyYAML for proper YAML parsing
|
||||
try:
|
||||
import yaml
|
||||
frontmatter = yaml.safe_load(frontmatter_text)
|
||||
if not isinstance(frontmatter, dict):
|
||||
frontmatter = {}
|
||||
return frontmatter
|
||||
except ImportError:
|
||||
# Fallback to simple parsing if PyYAML not available
|
||||
pass
|
||||
except Exception:
|
||||
# If YAML parsing fails, fall back to simple parsing
|
||||
pass
|
||||
|
||||
# Simple YAML-like parsing (supports key: value format only)
|
||||
# This is a fallback for when PyYAML is not available
|
||||
for line in frontmatter_text.split('\n'):
|
||||
line = line.strip()
|
||||
if not line or line.startswith('#'):
|
||||
continue
|
||||
|
||||
if ':' in line:
|
||||
key, value = line.split(':', 1)
|
||||
key = key.strip()
|
||||
value = value.strip()
|
||||
|
||||
# Try to parse as JSON if it looks like JSON
|
||||
if value.startswith('{') or value.startswith('['):
|
||||
try:
|
||||
value = json.loads(value)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
# Parse boolean values
|
||||
elif value.lower() in ('true', 'false'):
|
||||
value = value.lower() == 'true'
|
||||
# Parse numbers
|
||||
elif value.isdigit():
|
||||
value = int(value)
|
||||
|
||||
frontmatter[key] = value
|
||||
|
||||
return frontmatter
|
||||
|
||||
|
||||
def parse_metadata(frontmatter: Dict[str, Any]) -> Optional[SkillMetadata]:
|
||||
"""
|
||||
Parse skill metadata from frontmatter.
|
||||
|
||||
Looks for 'metadata' field containing JSON with skill configuration.
|
||||
"""
|
||||
metadata_raw = frontmatter.get('metadata')
|
||||
if not metadata_raw:
|
||||
return None
|
||||
|
||||
# If it's a string, try to parse as JSON
|
||||
if isinstance(metadata_raw, str):
|
||||
try:
|
||||
metadata_raw = json.loads(metadata_raw)
|
||||
except json.JSONDecodeError:
|
||||
return None
|
||||
|
||||
if not isinstance(metadata_raw, dict):
|
||||
return None
|
||||
|
||||
# Use metadata_raw directly (COW format)
|
||||
meta_obj = metadata_raw
|
||||
|
||||
# Parse install specs
|
||||
install_specs = []
|
||||
install_raw = meta_obj.get('install', [])
|
||||
if isinstance(install_raw, list):
|
||||
for spec_raw in install_raw:
|
||||
if not isinstance(spec_raw, dict):
|
||||
continue
|
||||
|
||||
kind = spec_raw.get('kind', spec_raw.get('type', '')).lower()
|
||||
if not kind:
|
||||
continue
|
||||
|
||||
spec = SkillInstallSpec(
|
||||
kind=kind,
|
||||
id=spec_raw.get('id'),
|
||||
label=spec_raw.get('label'),
|
||||
bins=_normalize_string_list(spec_raw.get('bins')),
|
||||
os=_normalize_string_list(spec_raw.get('os')),
|
||||
formula=spec_raw.get('formula'),
|
||||
package=spec_raw.get('package'),
|
||||
module=spec_raw.get('module'),
|
||||
url=spec_raw.get('url'),
|
||||
archive=spec_raw.get('archive'),
|
||||
extract=spec_raw.get('extract', False),
|
||||
strip_components=spec_raw.get('stripComponents'),
|
||||
target_dir=spec_raw.get('targetDir'),
|
||||
)
|
||||
install_specs.append(spec)
|
||||
|
||||
# Parse requires
|
||||
requires = {}
|
||||
requires_raw = meta_obj.get('requires', {})
|
||||
if isinstance(requires_raw, dict):
|
||||
for key, value in requires_raw.items():
|
||||
requires[key] = _normalize_string_list(value)
|
||||
|
||||
return SkillMetadata(
|
||||
always=meta_obj.get('always', False),
|
||||
skill_key=meta_obj.get('skillKey'),
|
||||
primary_env=meta_obj.get('primaryEnv'),
|
||||
emoji=meta_obj.get('emoji'),
|
||||
homepage=meta_obj.get('homepage'),
|
||||
os=_normalize_string_list(meta_obj.get('os')),
|
||||
requires=requires,
|
||||
install=install_specs,
|
||||
)
|
||||
|
||||
|
||||
def _normalize_string_list(value: Any) -> List[str]:
|
||||
"""Normalize a value to a list of strings."""
|
||||
if not value:
|
||||
return []
|
||||
|
||||
if isinstance(value, list):
|
||||
return [str(v).strip() for v in value if v]
|
||||
|
||||
if isinstance(value, str):
|
||||
return [v.strip() for v in value.split(',') if v.strip()]
|
||||
|
||||
return []
|
||||
|
||||
|
||||
def parse_boolean_value(value: Optional[str], default: bool = False) -> bool:
|
||||
"""Parse a boolean value from frontmatter."""
|
||||
if value is None:
|
||||
return default
|
||||
|
||||
if isinstance(value, bool):
|
||||
return value
|
||||
|
||||
if isinstance(value, str):
|
||||
return value.lower() in ('true', '1', 'yes', 'on')
|
||||
|
||||
return default
|
||||
|
||||
|
||||
def get_frontmatter_value(frontmatter: Dict[str, Any], key: str) -> Optional[str]:
|
||||
"""Get a frontmatter value as a string."""
|
||||
value = frontmatter.get(key)
|
||||
return str(value) if value is not None else None
|
||||
278
agent/skills/loader.py
Normal file
278
agent/skills/loader.py
Normal file
@@ -0,0 +1,278 @@
|
||||
"""
|
||||
Skill loader for discovering and loading skills from directories.
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Dict
|
||||
from common.log import logger
|
||||
from agent.skills.types import Skill, SkillEntry, LoadSkillsResult, SkillMetadata
|
||||
from agent.skills.frontmatter import parse_frontmatter, parse_metadata, parse_boolean_value, get_frontmatter_value
|
||||
|
||||
|
||||
class SkillLoader:
|
||||
"""Loads skills from various directories."""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def load_skills_from_dir(self, dir_path: str, source: str) -> LoadSkillsResult:
|
||||
"""
|
||||
Load skills from a directory.
|
||||
|
||||
Discovery rules:
|
||||
- Direct .md files in the root directory
|
||||
- Recursive SKILL.md files under subdirectories
|
||||
|
||||
:param dir_path: Directory path to scan
|
||||
:param source: Source identifier ('builtin' or 'custom')
|
||||
:return: LoadSkillsResult with skills and diagnostics
|
||||
"""
|
||||
skills = []
|
||||
diagnostics = []
|
||||
|
||||
if not os.path.exists(dir_path):
|
||||
diagnostics.append(f"Directory does not exist: {dir_path}")
|
||||
return LoadSkillsResult(skills=skills, diagnostics=diagnostics)
|
||||
|
||||
if not os.path.isdir(dir_path):
|
||||
diagnostics.append(f"Path is not a directory: {dir_path}")
|
||||
return LoadSkillsResult(skills=skills, diagnostics=diagnostics)
|
||||
|
||||
# Load skills from root-level .md files and subdirectories
|
||||
result = self._load_skills_recursive(dir_path, source, include_root_files=True)
|
||||
|
||||
return result
|
||||
|
||||
def _load_skills_recursive(
|
||||
self,
|
||||
dir_path: str,
|
||||
source: str,
|
||||
include_root_files: bool = False
|
||||
) -> LoadSkillsResult:
|
||||
"""
|
||||
Recursively load skills from a directory.
|
||||
|
||||
:param dir_path: Directory to scan
|
||||
:param source: Source identifier
|
||||
:param include_root_files: Whether to include root-level .md files
|
||||
:return: LoadSkillsResult
|
||||
"""
|
||||
skills = []
|
||||
diagnostics = []
|
||||
|
||||
try:
|
||||
entries = os.listdir(dir_path)
|
||||
except Exception as e:
|
||||
diagnostics.append(f"Failed to list directory {dir_path}: {e}")
|
||||
return LoadSkillsResult(skills=skills, diagnostics=diagnostics)
|
||||
|
||||
for entry in entries:
|
||||
# Skip hidden files and directories
|
||||
if entry.startswith('.'):
|
||||
continue
|
||||
|
||||
# Skip common non-skill directories
|
||||
if entry in ('node_modules', '__pycache__', 'venv', '.git'):
|
||||
continue
|
||||
|
||||
full_path = os.path.join(dir_path, entry)
|
||||
|
||||
# Handle directories
|
||||
if os.path.isdir(full_path):
|
||||
# Recursively scan subdirectories
|
||||
sub_result = self._load_skills_recursive(full_path, source, include_root_files=False)
|
||||
skills.extend(sub_result.skills)
|
||||
diagnostics.extend(sub_result.diagnostics)
|
||||
continue
|
||||
|
||||
# Handle files
|
||||
if not os.path.isfile(full_path):
|
||||
continue
|
||||
|
||||
# Check if this is a skill file
|
||||
is_root_md = include_root_files and entry.endswith('.md')
|
||||
is_skill_md = not include_root_files and entry == 'SKILL.md'
|
||||
|
||||
if not (is_root_md or is_skill_md):
|
||||
continue
|
||||
|
||||
# Load the skill
|
||||
skill_result = self._load_skill_from_file(full_path, source)
|
||||
if skill_result.skills:
|
||||
skills.extend(skill_result.skills)
|
||||
diagnostics.extend(skill_result.diagnostics)
|
||||
|
||||
return LoadSkillsResult(skills=skills, diagnostics=diagnostics)
|
||||
|
||||
def _load_skill_from_file(self, file_path: str, source: str) -> LoadSkillsResult:
|
||||
"""
|
||||
Load a single skill from a markdown file.
|
||||
|
||||
:param file_path: Path to the skill markdown file
|
||||
:param source: Source identifier
|
||||
:return: LoadSkillsResult
|
||||
"""
|
||||
diagnostics = []
|
||||
|
||||
try:
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
content = f.read()
|
||||
except Exception as e:
|
||||
diagnostics.append(f"Failed to read skill file {file_path}: {e}")
|
||||
return LoadSkillsResult(skills=[], diagnostics=diagnostics)
|
||||
|
||||
# Parse frontmatter
|
||||
frontmatter = parse_frontmatter(content)
|
||||
|
||||
# Get skill name and description
|
||||
skill_dir = os.path.dirname(file_path)
|
||||
parent_dir_name = os.path.basename(skill_dir)
|
||||
|
||||
name = frontmatter.get('name', parent_dir_name)
|
||||
description = frontmatter.get('description', '')
|
||||
|
||||
# Normalize name (handle both string and list)
|
||||
if isinstance(name, list):
|
||||
name = name[0] if name else parent_dir_name
|
||||
elif not isinstance(name, str):
|
||||
name = str(name) if name else parent_dir_name
|
||||
|
||||
# Normalize description (handle both string and list)
|
||||
if isinstance(description, list):
|
||||
description = ' '.join(str(d) for d in description if d)
|
||||
elif not isinstance(description, str):
|
||||
description = str(description) if description else ''
|
||||
|
||||
# Special handling for linkai-agent: dynamically load apps from config.json
|
||||
if name == 'linkai-agent':
|
||||
description = self._load_linkai_agent_description(skill_dir, description)
|
||||
|
||||
if not description or not description.strip():
|
||||
diagnostics.append(f"Skill {name} has no description: {file_path}")
|
||||
return LoadSkillsResult(skills=[], diagnostics=diagnostics)
|
||||
|
||||
# Parse disable-model-invocation flag
|
||||
disable_model_invocation = parse_boolean_value(
|
||||
get_frontmatter_value(frontmatter, 'disable-model-invocation'),
|
||||
default=False
|
||||
)
|
||||
|
||||
# Create skill object
|
||||
skill = Skill(
|
||||
name=name,
|
||||
description=description,
|
||||
file_path=file_path,
|
||||
base_dir=skill_dir,
|
||||
source=source,
|
||||
content=content,
|
||||
disable_model_invocation=disable_model_invocation,
|
||||
frontmatter=frontmatter,
|
||||
)
|
||||
|
||||
return LoadSkillsResult(skills=[skill], diagnostics=diagnostics)
|
||||
|
||||
def _load_linkai_agent_description(self, skill_dir: str, default_description: str) -> str:
|
||||
"""
|
||||
Dynamically load LinkAI agent description from config.json
|
||||
|
||||
:param skill_dir: Skill directory
|
||||
:param default_description: Default description from SKILL.md
|
||||
:return: Dynamic description with app list
|
||||
"""
|
||||
import json
|
||||
|
||||
config_path = os.path.join(skill_dir, "config.json")
|
||||
|
||||
# Without config.json, skip this skill entirely (return empty to trigger exclusion)
|
||||
if not os.path.exists(config_path):
|
||||
logger.debug(f"[SkillLoader] linkai-agent skipped: no config.json found")
|
||||
return ""
|
||||
|
||||
try:
|
||||
with open(config_path, 'r', encoding='utf-8') as f:
|
||||
config = json.load(f)
|
||||
|
||||
apps = config.get("apps", [])
|
||||
if not apps:
|
||||
return default_description
|
||||
|
||||
# Build dynamic description with app details
|
||||
app_descriptions = "; ".join([
|
||||
f"{app['app_name']}({app['app_code']}: {app['app_description']})"
|
||||
for app in apps
|
||||
])
|
||||
|
||||
return f"Call LinkAI apps/workflows. {app_descriptions}"
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"[SkillLoader] Failed to load linkai-agent config: {e}")
|
||||
return default_description
|
||||
|
||||
def load_all_skills(
|
||||
self,
|
||||
builtin_dir: Optional[str] = None,
|
||||
custom_dir: Optional[str] = None,
|
||||
) -> Dict[str, SkillEntry]:
|
||||
"""
|
||||
Load skills from builtin and custom directories.
|
||||
|
||||
Precedence (lowest to highest):
|
||||
1. builtin — project root ``skills/``, shipped with the codebase
|
||||
2. custom — workspace ``skills/``, installed via cloud console or skill creator
|
||||
|
||||
Same-name custom skills override builtin ones.
|
||||
|
||||
:param builtin_dir: Built-in skills directory
|
||||
:param custom_dir: Custom skills directory
|
||||
:return: Dictionary mapping skill name to SkillEntry
|
||||
"""
|
||||
skill_map: Dict[str, SkillEntry] = {}
|
||||
all_diagnostics = []
|
||||
|
||||
# Load builtin skills (lower precedence)
|
||||
if builtin_dir and os.path.exists(builtin_dir):
|
||||
result = self.load_skills_from_dir(builtin_dir, source='builtin')
|
||||
all_diagnostics.extend(result.diagnostics)
|
||||
for skill in result.skills:
|
||||
entry = self._create_skill_entry(skill)
|
||||
skill_map[skill.name] = entry
|
||||
|
||||
# Load custom skills (higher precedence, overrides builtin)
|
||||
if custom_dir and os.path.exists(custom_dir):
|
||||
result = self.load_skills_from_dir(custom_dir, source='custom')
|
||||
all_diagnostics.extend(result.diagnostics)
|
||||
for skill in result.skills:
|
||||
entry = self._create_skill_entry(skill)
|
||||
skill_map[skill.name] = entry
|
||||
|
||||
# Log diagnostics
|
||||
if all_diagnostics:
|
||||
logger.debug(f"Skill loading diagnostics: {len(all_diagnostics)} issues")
|
||||
for diag in all_diagnostics[:5]:
|
||||
logger.debug(f" - {diag}")
|
||||
|
||||
logger.debug(f"Loaded {len(skill_map)} skills total")
|
||||
|
||||
return skill_map
|
||||
|
||||
def _create_skill_entry(self, skill: Skill) -> SkillEntry:
|
||||
"""
|
||||
Create a SkillEntry from a Skill with parsed metadata.
|
||||
|
||||
:param skill: The skill to create an entry for
|
||||
:return: SkillEntry with metadata
|
||||
"""
|
||||
metadata = parse_metadata(skill.frontmatter)
|
||||
|
||||
# Parse user-invocable flag
|
||||
user_invocable = parse_boolean_value(
|
||||
get_frontmatter_value(skill.frontmatter, 'user-invocable'),
|
||||
default=True
|
||||
)
|
||||
|
||||
return SkillEntry(
|
||||
skill=skill,
|
||||
metadata=metadata,
|
||||
user_invocable=user_invocable,
|
||||
)
|
||||
303
agent/skills/manager.py
Normal file
303
agent/skills/manager.py
Normal file
@@ -0,0 +1,303 @@
|
||||
"""
|
||||
Skill manager for managing skill lifecycle and operations.
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
from typing import Dict, List, Optional
|
||||
from pathlib import Path
|
||||
from common.log import logger
|
||||
from agent.skills.types import Skill, SkillEntry, SkillSnapshot
|
||||
from agent.skills.loader import SkillLoader
|
||||
from agent.skills.formatter import format_skill_entries_for_prompt
|
||||
|
||||
SKILLS_CONFIG_FILE = "skills_config.json"
|
||||
|
||||
|
||||
class SkillManager:
|
||||
"""Manages skills for an agent."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
builtin_dir: Optional[str] = None,
|
||||
custom_dir: Optional[str] = None,
|
||||
config: Optional[Dict] = None,
|
||||
):
|
||||
"""
|
||||
Initialize the skill manager.
|
||||
|
||||
:param builtin_dir: Built-in skills directory (project root ``skills/``)
|
||||
:param custom_dir: Custom skills directory (workspace ``skills/``)
|
||||
:param config: Configuration dictionary
|
||||
"""
|
||||
project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
self.builtin_dir = builtin_dir or os.path.join(project_root, 'skills')
|
||||
self.custom_dir = custom_dir or os.path.join(project_root, 'workspace', 'skills')
|
||||
self.config = config or {}
|
||||
self._skills_config_path = os.path.join(self.custom_dir, SKILLS_CONFIG_FILE)
|
||||
|
||||
# skills_config: full skill metadata keyed by name
|
||||
# { "web-fetch": {"name": ..., "description": ..., "source": ..., "enabled": true}, ... }
|
||||
self.skills_config: Dict[str, dict] = {}
|
||||
|
||||
self.loader = SkillLoader()
|
||||
self.skills: Dict[str, SkillEntry] = {}
|
||||
|
||||
# Load skills on initialization
|
||||
self.refresh_skills()
|
||||
|
||||
def refresh_skills(self):
|
||||
"""Reload all skills from builtin and custom directories, then sync config."""
|
||||
self.skills = self.loader.load_all_skills(
|
||||
builtin_dir=self.builtin_dir,
|
||||
custom_dir=self.custom_dir,
|
||||
)
|
||||
self._sync_skills_config()
|
||||
logger.debug(f"SkillManager: Loaded {len(self.skills)} skills")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# skills_config.json management
|
||||
# ------------------------------------------------------------------
|
||||
def _load_skills_config(self) -> Dict[str, dict]:
|
||||
"""Load skills_config.json from custom_dir. Returns empty dict if not found."""
|
||||
if not os.path.exists(self._skills_config_path):
|
||||
return {}
|
||||
try:
|
||||
with open(self._skills_config_path, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
if isinstance(data, dict):
|
||||
return data
|
||||
except Exception as e:
|
||||
logger.warning(f"[SkillManager] Failed to load {SKILLS_CONFIG_FILE}: {e}")
|
||||
return {}
|
||||
|
||||
def _save_skills_config(self):
|
||||
"""Persist skills_config to custom_dir/skills_config.json."""
|
||||
os.makedirs(self.custom_dir, exist_ok=True)
|
||||
try:
|
||||
with open(self._skills_config_path, "w", encoding="utf-8") as f:
|
||||
json.dump(self.skills_config, f, indent=4, ensure_ascii=False)
|
||||
except Exception as e:
|
||||
logger.error(f"[SkillManager] Failed to save {SKILLS_CONFIG_FILE}: {e}")
|
||||
|
||||
def _sync_skills_config(self):
|
||||
"""
|
||||
Merge directory-scanned skills with the persisted config file.
|
||||
|
||||
- New skills discovered on disk are added with enabled=True.
|
||||
- Skills that no longer exist on disk are removed.
|
||||
- Existing entries preserve their enabled state; name/description/source
|
||||
are refreshed from the latest scan.
|
||||
"""
|
||||
saved = self._load_skills_config()
|
||||
merged: Dict[str, dict] = {}
|
||||
|
||||
for name, entry in self.skills.items():
|
||||
skill = entry.skill
|
||||
prev = saved.get(name, {})
|
||||
# category priority: persisted config (set by cloud) > default "skill"
|
||||
category = prev.get("category", "skill")
|
||||
merged[name] = {
|
||||
"name": name,
|
||||
"description": skill.description,
|
||||
"source": skill.source,
|
||||
"enabled": prev.get("enabled", True),
|
||||
"category": category,
|
||||
}
|
||||
|
||||
self.skills_config = merged
|
||||
self._save_skills_config()
|
||||
|
||||
def is_skill_enabled(self, name: str) -> bool:
|
||||
"""
|
||||
Check if a skill is enabled according to skills_config.
|
||||
|
||||
:param name: skill name
|
||||
:return: True if enabled (default True if not in config)
|
||||
"""
|
||||
entry = self.skills_config.get(name)
|
||||
if entry is None:
|
||||
return True
|
||||
return entry.get("enabled", True)
|
||||
|
||||
def set_skill_enabled(self, name: str, enabled: bool):
|
||||
"""
|
||||
Set a skill's enabled state and persist.
|
||||
|
||||
:param name: skill name
|
||||
:param enabled: True to enable, False to disable
|
||||
"""
|
||||
if name not in self.skills_config:
|
||||
raise ValueError(f"skill '{name}' not found in config")
|
||||
self.skills_config[name]["enabled"] = enabled
|
||||
self._save_skills_config()
|
||||
|
||||
def get_skills_config(self) -> Dict[str, dict]:
|
||||
"""
|
||||
Return the full skills_config dict (for query API).
|
||||
|
||||
:return: copy of skills_config
|
||||
"""
|
||||
return dict(self.skills_config)
|
||||
|
||||
def get_skill(self, name: str) -> Optional[SkillEntry]:
|
||||
"""
|
||||
Get a skill by name.
|
||||
|
||||
:param name: Skill name
|
||||
:return: SkillEntry or None if not found
|
||||
"""
|
||||
return self.skills.get(name)
|
||||
|
||||
def list_skills(self) -> List[SkillEntry]:
|
||||
"""
|
||||
Get all loaded skills.
|
||||
|
||||
:return: List of all skill entries
|
||||
"""
|
||||
return list(self.skills.values())
|
||||
|
||||
def filter_skills(
|
||||
self,
|
||||
skill_filter: Optional[List[str]] = None,
|
||||
include_disabled: bool = False,
|
||||
) -> List[SkillEntry]:
|
||||
"""
|
||||
Filter skills based on criteria.
|
||||
|
||||
Simple rule: Skills are auto-enabled if requirements are met.
|
||||
- Has required API keys -> included
|
||||
- Missing API keys -> excluded
|
||||
|
||||
:param skill_filter: List of skill names to include (None = all)
|
||||
:param include_disabled: Whether to include disabled skills
|
||||
:return: Filtered list of skill entries
|
||||
"""
|
||||
from agent.skills.config import should_include_skill
|
||||
|
||||
entries = list(self.skills.values())
|
||||
|
||||
# Check requirements (platform, binaries, env vars)
|
||||
entries = [e for e in entries if should_include_skill(e, self.config)]
|
||||
|
||||
# Apply skill filter
|
||||
if skill_filter is not None:
|
||||
normalized = []
|
||||
for item in skill_filter:
|
||||
if isinstance(item, str):
|
||||
name = item.strip()
|
||||
if name:
|
||||
normalized.append(name)
|
||||
elif isinstance(item, list):
|
||||
for subitem in item:
|
||||
if isinstance(subitem, str):
|
||||
name = subitem.strip()
|
||||
if name:
|
||||
normalized.append(name)
|
||||
if normalized:
|
||||
entries = [e for e in entries if e.skill.name in normalized]
|
||||
|
||||
# Filter out disabled skills based on skills_config.json
|
||||
if not include_disabled:
|
||||
entries = [e for e in entries if self.is_skill_enabled(e.skill.name)]
|
||||
|
||||
return entries
|
||||
|
||||
def build_skills_prompt(
|
||||
self,
|
||||
skill_filter: Optional[List[str]] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Build a formatted prompt containing available skills.
|
||||
|
||||
:param skill_filter: Optional list of skill names to include
|
||||
:return: Formatted skills prompt
|
||||
"""
|
||||
from common.log import logger
|
||||
entries = self.filter_skills(skill_filter=skill_filter, include_disabled=False)
|
||||
logger.debug(f"[SkillManager] Filtered {len(entries)} skills for prompt (total: {len(self.skills)})")
|
||||
if entries:
|
||||
skill_names = [e.skill.name for e in entries]
|
||||
logger.debug(f"[SkillManager] Skills to include: {skill_names}")
|
||||
result = format_skill_entries_for_prompt(entries)
|
||||
logger.debug(f"[SkillManager] Generated prompt length: {len(result)}")
|
||||
return result
|
||||
|
||||
def build_skill_snapshot(
|
||||
self,
|
||||
skill_filter: Optional[List[str]] = None,
|
||||
version: Optional[int] = None,
|
||||
) -> SkillSnapshot:
|
||||
"""
|
||||
Build a snapshot of skills for a specific run.
|
||||
|
||||
:param skill_filter: Optional list of skill names to include
|
||||
:param version: Optional version number for the snapshot
|
||||
:return: SkillSnapshot
|
||||
"""
|
||||
entries = self.filter_skills(skill_filter=skill_filter, include_disabled=False)
|
||||
prompt = format_skill_entries_for_prompt(entries)
|
||||
|
||||
skills_info = []
|
||||
resolved_skills = []
|
||||
|
||||
for entry in entries:
|
||||
skills_info.append({
|
||||
'name': entry.skill.name,
|
||||
'primary_env': entry.metadata.primary_env if entry.metadata else None,
|
||||
})
|
||||
resolved_skills.append(entry.skill)
|
||||
|
||||
return SkillSnapshot(
|
||||
prompt=prompt,
|
||||
skills=skills_info,
|
||||
resolved_skills=resolved_skills,
|
||||
version=version,
|
||||
)
|
||||
|
||||
def sync_skills_to_workspace(self, target_workspace_dir: str):
|
||||
"""
|
||||
Sync all loaded skills to a target workspace directory.
|
||||
|
||||
This is useful for sandbox environments where skills need to be copied.
|
||||
|
||||
:param target_workspace_dir: Target workspace directory
|
||||
"""
|
||||
import shutil
|
||||
|
||||
target_skills_dir = os.path.join(target_workspace_dir, 'skills')
|
||||
|
||||
# Remove existing skills directory
|
||||
if os.path.exists(target_skills_dir):
|
||||
shutil.rmtree(target_skills_dir)
|
||||
|
||||
# Create new skills directory
|
||||
os.makedirs(target_skills_dir, exist_ok=True)
|
||||
|
||||
# Copy each skill
|
||||
for entry in self.skills.values():
|
||||
skill_name = entry.skill.name
|
||||
source_dir = entry.skill.base_dir
|
||||
target_dir = os.path.join(target_skills_dir, skill_name)
|
||||
|
||||
try:
|
||||
shutil.copytree(source_dir, target_dir)
|
||||
logger.debug(f"Synced skill '{skill_name}' to {target_dir}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to sync skill '{skill_name}': {e}")
|
||||
|
||||
logger.info(f"Synced {len(self.skills)} skills to {target_skills_dir}")
|
||||
|
||||
def get_skill_by_key(self, skill_key: str) -> Optional[SkillEntry]:
|
||||
"""
|
||||
Get a skill by its skill key (which may differ from name).
|
||||
|
||||
:param skill_key: Skill key to look up
|
||||
:return: SkillEntry or None
|
||||
"""
|
||||
for entry in self.skills.values():
|
||||
if entry.metadata and entry.metadata.skill_key == skill_key:
|
||||
return entry
|
||||
if entry.skill.name == skill_key:
|
||||
return entry
|
||||
return None
|
||||
285
agent/skills/service.py
Normal file
285
agent/skills/service.py
Normal file
@@ -0,0 +1,285 @@
|
||||
"""
|
||||
Skill service for handling skill CRUD operations.
|
||||
|
||||
This service provides a unified interface for managing skills, which can be
|
||||
called from the cloud control client (LinkAI), the local web console, or any
|
||||
other management entry point.
|
||||
"""
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import zipfile
|
||||
import tempfile
|
||||
from typing import Dict, List, Optional
|
||||
from common.log import logger
|
||||
from agent.skills.types import Skill, SkillEntry
|
||||
from agent.skills.manager import SkillManager
|
||||
|
||||
try:
|
||||
import requests
|
||||
except ImportError:
|
||||
requests = None
|
||||
|
||||
|
||||
class SkillService:
|
||||
"""
|
||||
High-level service for skill lifecycle management.
|
||||
Wraps SkillManager and provides network-aware operations such as
|
||||
downloading skill files from remote URLs.
|
||||
"""
|
||||
|
||||
def __init__(self, skill_manager: SkillManager):
|
||||
"""
|
||||
:param skill_manager: The SkillManager instance to operate on
|
||||
"""
|
||||
self.manager = skill_manager
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# query
|
||||
# ------------------------------------------------------------------
|
||||
def query(self) -> List[dict]:
|
||||
"""
|
||||
Query all skills and return a serialisable list.
|
||||
Reads from skills_config.json (refreshes from disk if needed).
|
||||
|
||||
:return: list of skill info dicts
|
||||
"""
|
||||
self.manager.refresh_skills()
|
||||
config = self.manager.get_skills_config()
|
||||
result = list(config.values())
|
||||
logger.info(f"[SkillService] query: {len(result)} skills found")
|
||||
return result
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# add / install
|
||||
# ------------------------------------------------------------------
|
||||
def add(self, payload: dict) -> None:
|
||||
"""
|
||||
Add (install) a skill from a remote payload.
|
||||
|
||||
Supported payload types:
|
||||
|
||||
1. ``type: "url"`` – download individual files::
|
||||
|
||||
{
|
||||
"name": "web_search",
|
||||
"type": "url",
|
||||
"enabled": true,
|
||||
"files": [
|
||||
{"url": "https://...", "path": "README.md"},
|
||||
{"url": "https://...", "path": "scripts/main.py"}
|
||||
]
|
||||
}
|
||||
|
||||
2. ``type: "package"`` – download a zip archive and extract::
|
||||
|
||||
{
|
||||
"name": "plugin-custom-tool",
|
||||
"type": "package",
|
||||
"category": "skills",
|
||||
"enabled": true,
|
||||
"files": [{"url": "https://cdn.example.com/skills/custom-tool.zip"}]
|
||||
}
|
||||
|
||||
:param payload: skill add payload from server
|
||||
"""
|
||||
name = payload.get("name")
|
||||
if not name:
|
||||
raise ValueError("skill name is required")
|
||||
|
||||
payload_type = payload.get("type", "url")
|
||||
|
||||
if payload_type == "package":
|
||||
self._add_package(name, payload)
|
||||
else:
|
||||
self._add_url(name, payload)
|
||||
|
||||
self.manager.refresh_skills()
|
||||
|
||||
category = payload.get("category")
|
||||
if category and name in self.manager.skills_config:
|
||||
self.manager.skills_config[name]["category"] = category
|
||||
self.manager._save_skills_config()
|
||||
|
||||
def _add_url(self, name: str, payload: dict) -> None:
|
||||
"""Install a skill by downloading individual files."""
|
||||
files = payload.get("files", [])
|
||||
if not files:
|
||||
raise ValueError("skill files list is empty")
|
||||
|
||||
skill_dir = os.path.join(self.manager.custom_dir, name)
|
||||
|
||||
tmp_dir = skill_dir + ".tmp"
|
||||
if os.path.exists(tmp_dir):
|
||||
shutil.rmtree(tmp_dir)
|
||||
os.makedirs(tmp_dir, exist_ok=True)
|
||||
|
||||
try:
|
||||
for file_info in files:
|
||||
url = file_info.get("url")
|
||||
rel_path = file_info.get("path")
|
||||
if not url or not rel_path:
|
||||
logger.warning(f"[SkillService] add: skip invalid file entry {file_info}")
|
||||
continue
|
||||
dest = os.path.join(tmp_dir, rel_path)
|
||||
self._download_file(url, dest)
|
||||
except Exception:
|
||||
shutil.rmtree(tmp_dir, ignore_errors=True)
|
||||
raise
|
||||
|
||||
if os.path.exists(skill_dir):
|
||||
shutil.rmtree(skill_dir)
|
||||
os.rename(tmp_dir, skill_dir)
|
||||
|
||||
logger.info(f"[SkillService] add: skill '{name}' installed via url ({len(files)} files)")
|
||||
|
||||
def _add_package(self, name: str, payload: dict) -> None:
|
||||
"""
|
||||
Install a skill by downloading a zip archive and extracting it.
|
||||
|
||||
If the archive contains a single top-level directory, that directory
|
||||
is used as the skill folder directly; otherwise a new directory named
|
||||
after the skill is created to hold the extracted contents.
|
||||
"""
|
||||
files = payload.get("files", [])
|
||||
if not files or not files[0].get("url"):
|
||||
raise ValueError("package url is required")
|
||||
|
||||
url = files[0]["url"]
|
||||
skill_dir = os.path.join(self.manager.custom_dir, name)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
zip_path = os.path.join(tmp_dir, "package.zip")
|
||||
self._download_file(url, zip_path)
|
||||
|
||||
if not zipfile.is_zipfile(zip_path):
|
||||
raise ValueError(f"downloaded file is not a valid zip archive: {url}")
|
||||
|
||||
extract_dir = os.path.join(tmp_dir, "extracted")
|
||||
with zipfile.ZipFile(zip_path, "r") as zf:
|
||||
zf.extractall(extract_dir)
|
||||
|
||||
# Determine the actual content root.
|
||||
# If the zip has a single top-level directory, use its contents
|
||||
# so the skill folder is clean (no extra nesting).
|
||||
top_items = [
|
||||
item for item in os.listdir(extract_dir)
|
||||
if not item.startswith(".")
|
||||
]
|
||||
if len(top_items) == 1:
|
||||
single = os.path.join(extract_dir, top_items[0])
|
||||
if os.path.isdir(single):
|
||||
extract_dir = single
|
||||
|
||||
if os.path.exists(skill_dir):
|
||||
shutil.rmtree(skill_dir)
|
||||
shutil.copytree(extract_dir, skill_dir)
|
||||
|
||||
logger.info(f"[SkillService] add: skill '{name}' installed via package ({url})")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# open / close (enable / disable)
|
||||
# ------------------------------------------------------------------
|
||||
def open(self, payload: dict) -> None:
|
||||
"""
|
||||
Enable a skill by name.
|
||||
|
||||
:param payload: {"name": "skill_name"}
|
||||
"""
|
||||
name = payload.get("name")
|
||||
if not name:
|
||||
raise ValueError("skill name is required")
|
||||
self.manager.set_skill_enabled(name, enabled=True)
|
||||
logger.info(f"[SkillService] open: skill '{name}' enabled")
|
||||
|
||||
def close(self, payload: dict) -> None:
|
||||
"""
|
||||
Disable a skill by name.
|
||||
|
||||
:param payload: {"name": "skill_name"}
|
||||
"""
|
||||
name = payload.get("name")
|
||||
if not name:
|
||||
raise ValueError("skill name is required")
|
||||
self.manager.set_skill_enabled(name, enabled=False)
|
||||
logger.info(f"[SkillService] close: skill '{name}' disabled")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# delete
|
||||
# ------------------------------------------------------------------
|
||||
def delete(self, payload: dict) -> None:
|
||||
"""
|
||||
Delete a skill by removing its directory entirely.
|
||||
|
||||
:param payload: {"name": "skill_name"}
|
||||
"""
|
||||
name = payload.get("name")
|
||||
if not name:
|
||||
raise ValueError("skill name is required")
|
||||
|
||||
skill_dir = os.path.join(self.manager.custom_dir, name)
|
||||
if os.path.exists(skill_dir):
|
||||
shutil.rmtree(skill_dir)
|
||||
logger.info(f"[SkillService] delete: removed directory {skill_dir}")
|
||||
else:
|
||||
logger.warning(f"[SkillService] delete: skill directory not found: {skill_dir}")
|
||||
|
||||
# Refresh will remove the deleted skill from config automatically
|
||||
self.manager.refresh_skills()
|
||||
logger.info(f"[SkillService] delete: skill '{name}' deleted")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# dispatch - single entry point for protocol messages
|
||||
# ------------------------------------------------------------------
|
||||
def dispatch(self, action: str, payload: Optional[dict] = None) -> dict:
|
||||
"""
|
||||
Dispatch a skill management action and return a protocol-compatible
|
||||
response dict.
|
||||
|
||||
:param action: one of query / add / open / close / delete
|
||||
:param payload: action-specific payload (may be None for query)
|
||||
:return: dict with action, code, message, payload
|
||||
"""
|
||||
payload = payload or {}
|
||||
try:
|
||||
if action == "query":
|
||||
result_payload = self.query()
|
||||
return {"action": action, "code": 200, "message": "success", "payload": result_payload}
|
||||
elif action == "add":
|
||||
self.add(payload)
|
||||
elif action == "open":
|
||||
self.open(payload)
|
||||
elif action == "close":
|
||||
self.close(payload)
|
||||
elif action == "delete":
|
||||
self.delete(payload)
|
||||
else:
|
||||
return {"action": action, "code": 400, "message": f"unknown action: {action}", "payload": None}
|
||||
return {"action": action, "code": 200, "message": "success", "payload": None}
|
||||
except Exception as e:
|
||||
logger.error(f"[SkillService] dispatch error: action={action}, error={e}")
|
||||
return {"action": action, "code": 500, "message": str(e), "payload": None}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# internal helpers
|
||||
# ------------------------------------------------------------------
|
||||
@staticmethod
|
||||
def _download_file(url: str, dest: str):
|
||||
"""
|
||||
Download a file from *url* and save to *dest*.
|
||||
|
||||
:param url: remote file URL
|
||||
:param dest: local destination path
|
||||
"""
|
||||
if requests is None:
|
||||
raise RuntimeError("requests library is required for downloading skill files")
|
||||
|
||||
dest_dir = os.path.dirname(dest)
|
||||
if dest_dir:
|
||||
os.makedirs(dest_dir, exist_ok=True)
|
||||
|
||||
resp = requests.get(url, timeout=60)
|
||||
resp.raise_for_status()
|
||||
with open(dest, "wb") as f:
|
||||
f.write(resp.content)
|
||||
logger.debug(f"[SkillService] downloaded {url} -> {dest}")
|
||||
75
agent/skills/types.py
Normal file
75
agent/skills/types.py
Normal file
@@ -0,0 +1,75 @@
|
||||
"""
|
||||
Type definitions for skills system.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
from typing import Dict, List, Optional, Any
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
@dataclass
|
||||
class SkillInstallSpec:
|
||||
"""Specification for installing skill dependencies."""
|
||||
kind: str # brew, pip, npm, download, etc.
|
||||
id: Optional[str] = None
|
||||
label: Optional[str] = None
|
||||
bins: List[str] = field(default_factory=list)
|
||||
os: List[str] = field(default_factory=list)
|
||||
formula: Optional[str] = None # for brew
|
||||
package: Optional[str] = None # for pip/npm
|
||||
module: Optional[str] = None
|
||||
url: Optional[str] = None # for download
|
||||
archive: Optional[str] = None
|
||||
extract: bool = False
|
||||
strip_components: Optional[int] = None
|
||||
target_dir: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class SkillMetadata:
|
||||
"""Metadata for a skill from frontmatter."""
|
||||
always: bool = False # Always include this skill
|
||||
skill_key: Optional[str] = None # Override skill key
|
||||
primary_env: Optional[str] = None # Primary environment variable
|
||||
emoji: Optional[str] = None
|
||||
homepage: Optional[str] = None
|
||||
os: List[str] = field(default_factory=list) # Supported OS platforms
|
||||
requires: Dict[str, List[str]] = field(default_factory=dict) # Requirements
|
||||
install: List[SkillInstallSpec] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Skill:
|
||||
"""Represents a skill loaded from a markdown file."""
|
||||
name: str
|
||||
description: str
|
||||
file_path: str
|
||||
base_dir: str
|
||||
source: str # builtin or custom
|
||||
content: str # Full markdown content
|
||||
disable_model_invocation: bool = False
|
||||
frontmatter: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SkillEntry:
|
||||
"""A skill with parsed metadata."""
|
||||
skill: Skill
|
||||
metadata: Optional[SkillMetadata] = None
|
||||
user_invocable: bool = True # Can users invoke this skill directly
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoadSkillsResult:
|
||||
"""Result of loading skills from a directory."""
|
||||
skills: List[Skill]
|
||||
diagnostics: List[str] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SkillSnapshot:
|
||||
"""Snapshot of skills for a specific run."""
|
||||
prompt: str # Formatted prompt text
|
||||
skills: List[Dict[str, str]] # List of skill info (name, primary_env)
|
||||
resolved_skills: List[Skill] = field(default_factory=list)
|
||||
version: Optional[int] = None
|
||||
133
agent/tools/__init__.py
Normal file
133
agent/tools/__init__.py
Normal file
@@ -0,0 +1,133 @@
|
||||
# Import base tool
|
||||
from agent.tools.base_tool import BaseTool
|
||||
from agent.tools.tool_manager import ToolManager
|
||||
|
||||
# Import file operation tools
|
||||
from agent.tools.read.read import Read
|
||||
from agent.tools.write.write import Write
|
||||
from agent.tools.edit.edit import Edit
|
||||
from agent.tools.bash.bash import Bash
|
||||
from agent.tools.ls.ls import Ls
|
||||
from agent.tools.send.send import Send
|
||||
|
||||
# Import memory tools
|
||||
from agent.tools.memory.memory_search import MemorySearchTool
|
||||
from agent.tools.memory.memory_get import MemoryGetTool
|
||||
|
||||
# Import tools with optional dependencies
|
||||
def _import_optional_tools():
|
||||
"""Import tools that have optional dependencies"""
|
||||
from common.log import logger
|
||||
tools = {}
|
||||
|
||||
# EnvConfig Tool (requires python-dotenv)
|
||||
try:
|
||||
from agent.tools.env_config.env_config import EnvConfig
|
||||
tools['EnvConfig'] = EnvConfig
|
||||
except ImportError as e:
|
||||
logger.error(
|
||||
f"[Tools] EnvConfig tool not loaded - missing dependency: {e}\n"
|
||||
f" To enable environment variable management, run:\n"
|
||||
f" pip install python-dotenv>=1.0.0"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[Tools] EnvConfig tool failed to load: {e}")
|
||||
|
||||
# Scheduler Tool (requires croniter)
|
||||
try:
|
||||
from agent.tools.scheduler.scheduler_tool import SchedulerTool
|
||||
tools['SchedulerTool'] = SchedulerTool
|
||||
except ImportError as e:
|
||||
logger.error(
|
||||
f"[Tools] Scheduler tool not loaded - missing dependency: {e}\n"
|
||||
f" To enable scheduled tasks, run:\n"
|
||||
f" pip install croniter>=2.0.0"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[Tools] Scheduler tool failed to load: {e}")
|
||||
|
||||
# WebSearch Tool (conditionally loaded based on API key availability at init time)
|
||||
try:
|
||||
from agent.tools.web_search.web_search import WebSearch
|
||||
tools['WebSearch'] = WebSearch
|
||||
except ImportError as e:
|
||||
logger.error(f"[Tools] WebSearch not loaded - missing dependency: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"[Tools] WebSearch failed to load: {e}")
|
||||
|
||||
# WebFetch Tool
|
||||
try:
|
||||
from agent.tools.web_fetch.web_fetch import WebFetch
|
||||
tools['WebFetch'] = WebFetch
|
||||
except ImportError as e:
|
||||
logger.error(f"[Tools] WebFetch not loaded - missing dependency: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"[Tools] WebFetch failed to load: {e}")
|
||||
|
||||
# Vision Tool (conditionally loaded based on API key availability)
|
||||
try:
|
||||
from agent.tools.vision.vision import Vision
|
||||
tools['Vision'] = Vision
|
||||
except ImportError as e:
|
||||
logger.error(f"[Tools] Vision not loaded - missing dependency: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"[Tools] Vision failed to load: {e}")
|
||||
|
||||
return tools
|
||||
|
||||
# Load optional tools
|
||||
_optional_tools = _import_optional_tools()
|
||||
EnvConfig = _optional_tools.get('EnvConfig')
|
||||
SchedulerTool = _optional_tools.get('SchedulerTool')
|
||||
WebSearch = _optional_tools.get('WebSearch')
|
||||
WebFetch = _optional_tools.get('WebFetch')
|
||||
Vision = _optional_tools.get('Vision')
|
||||
GoogleSearch = _optional_tools.get('GoogleSearch')
|
||||
FileSave = _optional_tools.get('FileSave')
|
||||
Terminal = _optional_tools.get('Terminal')
|
||||
|
||||
|
||||
# Delayed import for BrowserTool
|
||||
def _import_browser_tool():
|
||||
try:
|
||||
from agent.tools.browser.browser_tool import BrowserTool
|
||||
return BrowserTool
|
||||
except ImportError:
|
||||
# Return a placeholder class that will prompt the user to install dependencies when instantiated
|
||||
class BrowserToolPlaceholder:
|
||||
def __init__(self, *args, **kwargs):
|
||||
raise ImportError(
|
||||
"The 'browser-use' package is required to use BrowserTool. "
|
||||
"Please install it with 'pip install browser-use>=0.1.40'."
|
||||
)
|
||||
|
||||
return BrowserToolPlaceholder
|
||||
|
||||
|
||||
# Dynamically set BrowserTool
|
||||
# BrowserTool = _import_browser_tool()
|
||||
|
||||
# Export all tools (including optional ones that might be None)
|
||||
__all__ = [
|
||||
'BaseTool',
|
||||
'ToolManager',
|
||||
'Read',
|
||||
'Write',
|
||||
'Edit',
|
||||
'Bash',
|
||||
'Ls',
|
||||
'Send',
|
||||
'MemorySearchTool',
|
||||
'MemoryGetTool',
|
||||
'EnvConfig',
|
||||
'SchedulerTool',
|
||||
'WebSearch',
|
||||
'WebFetch',
|
||||
'Vision',
|
||||
# Optional tools (may be None if dependencies not available)
|
||||
# 'BrowserTool'
|
||||
]
|
||||
|
||||
"""
|
||||
Tools module for Agent.
|
||||
"""
|
||||
99
agent/tools/base_tool.py
Normal file
99
agent/tools/base_tool.py
Normal file
@@ -0,0 +1,99 @@
|
||||
from enum import Enum
|
||||
from typing import Any, Optional
|
||||
from common.log import logger
|
||||
import copy
|
||||
|
||||
|
||||
class ToolStage(Enum):
|
||||
"""Enum representing tool decision stages"""
|
||||
PRE_PROCESS = "pre_process" # Tools that need to be actively selected by the agent
|
||||
POST_PROCESS = "post_process" # Tools that automatically execute after final_answer
|
||||
|
||||
|
||||
class ToolResult:
|
||||
"""Tool execution result"""
|
||||
|
||||
def __init__(self, status: str = None, result: Any = None, ext_data: Any = None):
|
||||
self.status = status
|
||||
self.result = result
|
||||
self.ext_data = ext_data
|
||||
|
||||
@staticmethod
|
||||
def success(result, ext_data: Any = None):
|
||||
return ToolResult(status="success", result=result, ext_data=ext_data)
|
||||
|
||||
@staticmethod
|
||||
def fail(result, ext_data: Any = None):
|
||||
return ToolResult(status="error", result=result, ext_data=ext_data)
|
||||
|
||||
|
||||
class BaseTool:
|
||||
"""Base class for all tools."""
|
||||
|
||||
# Default decision stage is pre-process
|
||||
stage = ToolStage.PRE_PROCESS
|
||||
|
||||
# Class attributes must be inherited
|
||||
name: str = "base_tool"
|
||||
description: str = "Base tool"
|
||||
params: dict = {} # Store JSON Schema
|
||||
model: Optional[Any] = None # LLM model instance, type depends on bot implementation
|
||||
|
||||
@classmethod
|
||||
def get_json_schema(cls) -> dict:
|
||||
"""Get the standard description of the tool"""
|
||||
return {
|
||||
"name": cls.name,
|
||||
"description": cls.description,
|
||||
"parameters": cls.params
|
||||
}
|
||||
|
||||
def execute_tool(self, params: dict) -> ToolResult:
|
||||
try:
|
||||
return self.execute(params)
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
|
||||
def execute(self, params: dict) -> ToolResult:
|
||||
"""Specific logic to be implemented by subclasses"""
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def _parse_schema(cls) -> dict:
|
||||
"""Convert JSON Schema to Pydantic fields"""
|
||||
fields = {}
|
||||
for name, prop in cls.params["properties"].items():
|
||||
# Convert JSON Schema types to Python types
|
||||
type_map = {
|
||||
"string": str,
|
||||
"number": float,
|
||||
"integer": int,
|
||||
"boolean": bool,
|
||||
"array": list,
|
||||
"object": dict
|
||||
}
|
||||
fields[name] = (
|
||||
type_map[prop["type"]],
|
||||
prop.get("default", ...)
|
||||
)
|
||||
return fields
|
||||
|
||||
def should_auto_execute(self, context) -> bool:
|
||||
"""
|
||||
Determine if this tool should be automatically executed based on context.
|
||||
|
||||
:param context: The agent context
|
||||
:return: True if the tool should be executed, False otherwise
|
||||
"""
|
||||
# Only tools in post-process stage will be automatically executed
|
||||
return self.stage == ToolStage.POST_PROCESS
|
||||
|
||||
def close(self):
|
||||
"""
|
||||
Close any resources used by the tool.
|
||||
This method should be overridden by tools that need to clean up resources
|
||||
such as browser connections, file handles, etc.
|
||||
|
||||
By default, this method does nothing.
|
||||
"""
|
||||
pass
|
||||
3
agent/tools/bash/__init__.py
Normal file
3
agent/tools/bash/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .bash import Bash
|
||||
|
||||
__all__ = ['Bash']
|
||||
291
agent/tools/bash/bash.py
Normal file
291
agent/tools/bash/bash.py
Normal file
@@ -0,0 +1,291 @@
|
||||
"""
|
||||
Bash tool - Execute bash commands
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import subprocess
|
||||
import tempfile
|
||||
from typing import Dict, Any
|
||||
|
||||
from agent.tools.base_tool import BaseTool, ToolResult
|
||||
from agent.tools.utils.truncate import truncate_tail, format_size, DEFAULT_MAX_LINES, DEFAULT_MAX_BYTES
|
||||
from common.log import logger
|
||||
from common.utils import expand_path
|
||||
|
||||
|
||||
class Bash(BaseTool):
|
||||
"""Tool for executing bash commands"""
|
||||
|
||||
name: str = "bash"
|
||||
description: str = f"""Execute a bash command in the current working directory. Returns stdout and stderr. Output is truncated to last {DEFAULT_MAX_LINES} lines or {DEFAULT_MAX_BYTES // 1024}KB (whichever is hit first). If truncated, full output is saved to a temp file.
|
||||
|
||||
ENVIRONMENT: All API keys from env_config are auto-injected. Use $VAR_NAME directly.
|
||||
|
||||
SAFETY:
|
||||
- Freely create/modify/delete files within the workspace
|
||||
- For destructive and out-of-workspace commands, explain and confirm first"""
|
||||
|
||||
params: dict = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"command": {
|
||||
"type": "string",
|
||||
"description": "Bash command to execute"
|
||||
},
|
||||
"timeout": {
|
||||
"type": "integer",
|
||||
"description": "Timeout in seconds (optional, default: 30)"
|
||||
}
|
||||
},
|
||||
"required": ["command"]
|
||||
}
|
||||
|
||||
def __init__(self, config: dict = None):
|
||||
self.config = config or {}
|
||||
self.cwd = self.config.get("cwd", os.getcwd())
|
||||
# Ensure working directory exists
|
||||
if not os.path.exists(self.cwd):
|
||||
os.makedirs(self.cwd, exist_ok=True)
|
||||
self.default_timeout = self.config.get("timeout", 30)
|
||||
# Enable safety mode by default (can be disabled in config)
|
||||
self.safety_mode = self.config.get("safety_mode", True)
|
||||
|
||||
def execute(self, args: Dict[str, Any]) -> ToolResult:
|
||||
"""
|
||||
Execute a bash command
|
||||
|
||||
:param args: Dictionary containing the command and optional timeout
|
||||
:return: Command output or error
|
||||
"""
|
||||
command = args.get("command", "").strip()
|
||||
timeout = args.get("timeout", self.default_timeout)
|
||||
|
||||
if not command:
|
||||
return ToolResult.fail("Error: command parameter is required")
|
||||
|
||||
# Security check: Prevent accessing sensitive config files
|
||||
if "~/.cow/.env" in command or "~/.cow" in command:
|
||||
return ToolResult.fail(
|
||||
"Error: Access denied. API keys and credentials must be accessed through the env_config tool only."
|
||||
)
|
||||
|
||||
# Optional safety check - only warn about extremely dangerous commands
|
||||
if self.safety_mode:
|
||||
warning = self._get_safety_warning(command)
|
||||
if warning:
|
||||
return ToolResult.fail(
|
||||
f"Safety Warning: {warning}\n\nIf you believe this command is safe and necessary, please ask the user for confirmation first, explaining what the command does and why it's needed.")
|
||||
|
||||
try:
|
||||
# Prepare environment with .env file variables
|
||||
env = os.environ.copy()
|
||||
|
||||
# Load environment variables from ~/.cow/.env if it exists
|
||||
env_file = expand_path("~/.cow/.env")
|
||||
dotenv_vars = {}
|
||||
if os.path.exists(env_file):
|
||||
try:
|
||||
from dotenv import dotenv_values
|
||||
dotenv_vars = dotenv_values(env_file)
|
||||
env.update(dotenv_vars)
|
||||
logger.debug(f"[Bash] Loaded {len(dotenv_vars)} variables from {env_file}")
|
||||
except ImportError:
|
||||
logger.debug("[Bash] python-dotenv not installed, skipping .env loading")
|
||||
except Exception as e:
|
||||
logger.debug(f"[Bash] Failed to load .env: {e}")
|
||||
|
||||
# getuid() only exists on Unix-like systems
|
||||
if hasattr(os, 'getuid'):
|
||||
logger.debug(f"[Bash] Process UID: {os.getuid()}")
|
||||
else:
|
||||
logger.debug(f"[Bash] Process User: {os.environ.get('USERNAME', os.environ.get('USER', 'unknown'))}")
|
||||
|
||||
# On Windows, convert $VAR references to %VAR% for cmd.exe
|
||||
if sys.platform == "win32":
|
||||
env["PYTHONIOENCODING"] = "utf-8"
|
||||
command = self._convert_env_vars_for_windows(command, dotenv_vars)
|
||||
if command and not command.strip().lower().startswith("chcp"):
|
||||
command = f"chcp 65001 >nul 2>&1 && {command}"
|
||||
|
||||
# Execute command with inherited environment variables
|
||||
result = subprocess.run(
|
||||
command,
|
||||
shell=True,
|
||||
cwd=self.cwd,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
text=True,
|
||||
encoding="utf-8",
|
||||
errors="replace",
|
||||
timeout=timeout,
|
||||
env=env
|
||||
)
|
||||
|
||||
logger.debug(f"[Bash] Exit code: {result.returncode}")
|
||||
logger.debug(f"[Bash] Stdout length: {len(result.stdout)}")
|
||||
logger.debug(f"[Bash] Stderr length: {len(result.stderr)}")
|
||||
|
||||
# Workaround for exit code 126 with no output
|
||||
if result.returncode == 126 and not result.stdout and not result.stderr:
|
||||
logger.warning(f"[Bash] Exit 126 with no output - trying alternative execution method")
|
||||
# Try using argument list instead of shell=True
|
||||
import shlex
|
||||
try:
|
||||
parts = shlex.split(command)
|
||||
if len(parts) > 0:
|
||||
logger.info(f"[Bash] Retrying with argument list: {parts[:3]}...")
|
||||
retry_result = subprocess.run(
|
||||
parts,
|
||||
cwd=self.cwd,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
text=True,
|
||||
encoding="utf-8",
|
||||
errors="replace",
|
||||
timeout=timeout,
|
||||
env=env
|
||||
)
|
||||
logger.debug(f"[Bash] Retry exit code: {retry_result.returncode}, stdout: {len(retry_result.stdout)}, stderr: {len(retry_result.stderr)}")
|
||||
|
||||
# If retry succeeded, use retry result
|
||||
if retry_result.returncode == 0 or retry_result.stdout or retry_result.stderr:
|
||||
result = retry_result
|
||||
else:
|
||||
# Both attempts failed - check if this is openai-image-vision skill
|
||||
if 'openai-image-vision' in command or 'vision.sh' in command:
|
||||
# Create a mock result with helpful error message
|
||||
from types import SimpleNamespace
|
||||
result = SimpleNamespace(
|
||||
returncode=1,
|
||||
stdout='{"error": "图片无法解析", "reason": "该图片格式可能不受支持,或图片文件存在问题", "suggestion": "请尝试其他图片"}',
|
||||
stderr=''
|
||||
)
|
||||
logger.info(f"[Bash] Converted exit 126 to user-friendly image error message for vision skill")
|
||||
except Exception as retry_err:
|
||||
logger.warning(f"[Bash] Retry failed: {retry_err}")
|
||||
|
||||
# Combine stdout and stderr
|
||||
output = result.stdout
|
||||
if result.stderr:
|
||||
output += "\n" + result.stderr
|
||||
|
||||
# Check if we need to save full output to temp file
|
||||
temp_file_path = None
|
||||
total_bytes = len(output.encode('utf-8'))
|
||||
|
||||
if total_bytes > DEFAULT_MAX_BYTES:
|
||||
# Save full output to temp file
|
||||
with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.log', prefix='bash-') as f:
|
||||
f.write(output)
|
||||
temp_file_path = f.name
|
||||
|
||||
# Apply tail truncation
|
||||
truncation = truncate_tail(output)
|
||||
output_text = truncation.content or "(no output)"
|
||||
|
||||
# Build result
|
||||
details = {}
|
||||
|
||||
if truncation.truncated:
|
||||
details["truncation"] = truncation.to_dict()
|
||||
if temp_file_path:
|
||||
details["full_output_path"] = temp_file_path
|
||||
|
||||
# Build notice
|
||||
start_line = truncation.total_lines - truncation.output_lines + 1
|
||||
end_line = truncation.total_lines
|
||||
|
||||
if truncation.last_line_partial:
|
||||
# Edge case: last line alone > 30KB
|
||||
last_line = output.split('\n')[-1] if output else ""
|
||||
last_line_size = format_size(len(last_line.encode('utf-8')))
|
||||
output_text += f"\n\n[Showing last {format_size(truncation.output_bytes)} of line {end_line} (line is {last_line_size}). Full output: {temp_file_path}]"
|
||||
elif truncation.truncated_by == "lines":
|
||||
output_text += f"\n\n[Showing lines {start_line}-{end_line} of {truncation.total_lines}. Full output: {temp_file_path}]"
|
||||
else:
|
||||
output_text += f"\n\n[Showing lines {start_line}-{end_line} of {truncation.total_lines} ({format_size(DEFAULT_MAX_BYTES)} limit). Full output: {temp_file_path}]"
|
||||
|
||||
# Check exit code
|
||||
if result.returncode != 0:
|
||||
output_text += f"\n\nCommand exited with code {result.returncode}"
|
||||
return ToolResult.fail({
|
||||
"output": output_text,
|
||||
"exit_code": result.returncode,
|
||||
"details": details if details else None
|
||||
})
|
||||
|
||||
return ToolResult.success({
|
||||
"output": output_text,
|
||||
"exit_code": result.returncode,
|
||||
"details": details if details else None
|
||||
})
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
return ToolResult.fail(f"Error: Command timed out after {timeout} seconds")
|
||||
except Exception as e:
|
||||
return ToolResult.fail(f"Error executing command: {str(e)}")
|
||||
|
||||
def _get_safety_warning(self, command: str) -> str:
|
||||
"""
|
||||
Get safety warning for potentially dangerous commands
|
||||
Only warns about extremely dangerous system-level operations
|
||||
|
||||
:param command: Command to check
|
||||
:return: Warning message if dangerous, empty string if safe
|
||||
"""
|
||||
cmd_lower = command.lower().strip()
|
||||
|
||||
# Only block extremely dangerous system operations
|
||||
dangerous_patterns = [
|
||||
# System shutdown/reboot
|
||||
("shutdown", "This command will shut down the system"),
|
||||
("reboot", "This command will reboot the system"),
|
||||
("halt", "This command will halt the system"),
|
||||
("poweroff", "This command will power off the system"),
|
||||
|
||||
# Critical system modifications
|
||||
("rm -rf /", "This command will delete the entire filesystem"),
|
||||
("rm -rf /*", "This command will delete the entire filesystem"),
|
||||
("dd if=/dev/zero", "This command can destroy disk data"),
|
||||
("mkfs", "This command will format a filesystem, destroying all data"),
|
||||
("fdisk", "This command modifies disk partitions"),
|
||||
|
||||
# User/system management (only if targeting system users)
|
||||
("userdel root", "This command will delete the root user"),
|
||||
("passwd root", "This command will change the root password"),
|
||||
]
|
||||
|
||||
for pattern, warning in dangerous_patterns:
|
||||
if pattern in cmd_lower:
|
||||
return warning
|
||||
|
||||
# Check for recursive deletion outside workspace
|
||||
if "rm" in cmd_lower and "-rf" in cmd_lower:
|
||||
# Allow deletion within current workspace
|
||||
if not any(path in cmd_lower for path in ["./", self.cwd.lower()]):
|
||||
# Check if targeting system directories
|
||||
system_dirs = ["/bin", "/usr", "/etc", "/var", "/home", "/root", "/sys", "/proc"]
|
||||
if any(sysdir in cmd_lower for sysdir in system_dirs):
|
||||
return "This command will recursively delete system directories"
|
||||
|
||||
return "" # No warning needed
|
||||
|
||||
@staticmethod
|
||||
def _convert_env_vars_for_windows(command: str, dotenv_vars: dict) -> str:
|
||||
"""
|
||||
Convert bash-style $VAR / ${VAR} references to cmd.exe %VAR% syntax.
|
||||
Only converts variables loaded from .env (user-configured API keys etc.)
|
||||
to avoid breaking $PATH, jq expressions, regex, etc.
|
||||
"""
|
||||
if not dotenv_vars:
|
||||
return command
|
||||
|
||||
def replace_match(m):
|
||||
var_name = m.group(1) or m.group(2)
|
||||
if var_name in dotenv_vars:
|
||||
return f"%{var_name}%"
|
||||
return m.group(0)
|
||||
|
||||
return re.sub(r'\$\{(\w+)\}|\$(\w+)', replace_match, command)
|
||||
18
agent/tools/browser_tool.py
Normal file
18
agent/tools/browser_tool.py
Normal file
@@ -0,0 +1,18 @@
|
||||
def copy(self):
|
||||
"""
|
||||
Special copy method for browser tool to avoid recreating browser instance.
|
||||
|
||||
:return: A new instance with shared browser reference but unique model
|
||||
"""
|
||||
new_tool = self.__class__()
|
||||
|
||||
# Copy essential attributes
|
||||
new_tool.model = self.model
|
||||
new_tool.context = getattr(self, 'context', None)
|
||||
new_tool.config = getattr(self, 'config', None)
|
||||
|
||||
# Share the browser instance instead of creating a new one
|
||||
if hasattr(self, 'browser'):
|
||||
new_tool.browser = self.browser
|
||||
|
||||
return new_tool
|
||||
3
agent/tools/edit/__init__.py
Normal file
3
agent/tools/edit/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .edit import Edit
|
||||
|
||||
__all__ = ['Edit']
|
||||
185
agent/tools/edit/edit.py
Normal file
185
agent/tools/edit/edit.py
Normal file
@@ -0,0 +1,185 @@
|
||||
"""
|
||||
Edit tool - Precise file editing
|
||||
Edit files through exact text replacement
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Dict, Any
|
||||
|
||||
from agent.tools.base_tool import BaseTool, ToolResult
|
||||
from common.utils import expand_path
|
||||
from agent.tools.utils.diff import (
|
||||
strip_bom,
|
||||
detect_line_ending,
|
||||
normalize_to_lf,
|
||||
restore_line_endings,
|
||||
normalize_for_fuzzy_match,
|
||||
fuzzy_find_text,
|
||||
generate_diff_string
|
||||
)
|
||||
|
||||
|
||||
class Edit(BaseTool):
|
||||
"""Tool for precise file editing"""
|
||||
|
||||
name: str = "edit"
|
||||
description: str = "Edit a file by replacing exact text, or append to end if oldText is empty. For append: use empty oldText. For replace: oldText must match exactly (including whitespace)."
|
||||
|
||||
params: dict = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "Path to the file to edit (relative or absolute)"
|
||||
},
|
||||
"oldText": {
|
||||
"type": "string",
|
||||
"description": "Text to find and replace. Use empty string to append to end of file. For replacement: must match exactly including whitespace."
|
||||
},
|
||||
"newText": {
|
||||
"type": "string",
|
||||
"description": "New text to replace the old text with"
|
||||
}
|
||||
},
|
||||
"required": ["path", "oldText", "newText"]
|
||||
}
|
||||
|
||||
def __init__(self, config: dict = None):
|
||||
self.config = config or {}
|
||||
self.cwd = self.config.get("cwd", os.getcwd())
|
||||
self.memory_manager = self.config.get("memory_manager", None)
|
||||
|
||||
def execute(self, args: Dict[str, Any]) -> ToolResult:
|
||||
"""
|
||||
Execute file edit operation
|
||||
|
||||
:param args: Contains file path, old text and new text
|
||||
:return: Operation result
|
||||
"""
|
||||
path = args.get("path", "").strip()
|
||||
old_text = args.get("oldText", "")
|
||||
new_text = args.get("newText", "")
|
||||
|
||||
if not path:
|
||||
return ToolResult.fail("Error: path parameter is required")
|
||||
|
||||
# Resolve path
|
||||
absolute_path = self._resolve_path(path)
|
||||
|
||||
# Check if file exists
|
||||
if not os.path.exists(absolute_path):
|
||||
return ToolResult.fail(f"Error: File not found: {path}")
|
||||
|
||||
# Check if readable/writable
|
||||
if not os.access(absolute_path, os.R_OK | os.W_OK):
|
||||
return ToolResult.fail(f"Error: File is not readable/writable: {path}")
|
||||
|
||||
try:
|
||||
# Read file
|
||||
with open(absolute_path, 'r', encoding='utf-8') as f:
|
||||
raw_content = f.read()
|
||||
|
||||
# Remove BOM (LLM won't include invisible BOM in oldText)
|
||||
bom, content = strip_bom(raw_content)
|
||||
|
||||
# Detect original line ending
|
||||
original_ending = detect_line_ending(content)
|
||||
|
||||
# Normalize to LF
|
||||
normalized_content = normalize_to_lf(content)
|
||||
normalized_old_text = normalize_to_lf(old_text)
|
||||
normalized_new_text = normalize_to_lf(new_text)
|
||||
|
||||
# Special case: empty oldText means append to end of file
|
||||
if not old_text or not old_text.strip():
|
||||
# Append mode: add newText to the end
|
||||
# Add newline before newText if file doesn't end with one
|
||||
if normalized_content and not normalized_content.endswith('\n'):
|
||||
new_content = normalized_content + '\n' + normalized_new_text
|
||||
else:
|
||||
new_content = normalized_content + normalized_new_text
|
||||
base_content = normalized_content # For verification
|
||||
else:
|
||||
# Normal edit mode: find and replace
|
||||
# Use fuzzy matching to find old text (try exact match first, then fuzzy match)
|
||||
match_result = fuzzy_find_text(normalized_content, normalized_old_text)
|
||||
|
||||
if not match_result.found:
|
||||
return ToolResult.fail(
|
||||
f"Error: Could not find the exact text in {path}. "
|
||||
"The old text must match exactly including all whitespace and newlines."
|
||||
)
|
||||
|
||||
# Calculate occurrence count (use fuzzy normalized content for consistency)
|
||||
fuzzy_content = normalize_for_fuzzy_match(normalized_content)
|
||||
fuzzy_old_text = normalize_for_fuzzy_match(normalized_old_text)
|
||||
occurrences = fuzzy_content.count(fuzzy_old_text)
|
||||
|
||||
if occurrences > 1:
|
||||
return ToolResult.fail(
|
||||
f"Error: Found {occurrences} occurrences of the text in {path}. "
|
||||
"The text must be unique. Please provide more context to make it unique."
|
||||
)
|
||||
|
||||
# Execute replacement (use matched text position)
|
||||
base_content = match_result.content_for_replacement
|
||||
new_content = (
|
||||
base_content[:match_result.index] +
|
||||
normalized_new_text +
|
||||
base_content[match_result.index + match_result.match_length:]
|
||||
)
|
||||
|
||||
# Verify replacement actually changed content
|
||||
if base_content == new_content:
|
||||
return ToolResult.fail(
|
||||
f"Error: No changes made to {path}. "
|
||||
"The replacement produced identical content. "
|
||||
"This might indicate an issue with special characters or the text not existing as expected."
|
||||
)
|
||||
|
||||
# Restore original line endings
|
||||
final_content = bom + restore_line_endings(new_content, original_ending)
|
||||
|
||||
# Write file
|
||||
with open(absolute_path, 'w', encoding='utf-8') as f:
|
||||
f.write(final_content)
|
||||
|
||||
# Generate diff
|
||||
diff_result = generate_diff_string(base_content, new_content)
|
||||
|
||||
result = {
|
||||
"message": f"Successfully replaced text in {path}",
|
||||
"path": path,
|
||||
"diff": diff_result['diff'],
|
||||
"first_changed_line": diff_result['first_changed_line']
|
||||
}
|
||||
|
||||
# Notify memory manager if file is in memory directory
|
||||
if self.memory_manager and "memory/" in path:
|
||||
try:
|
||||
self.memory_manager.mark_dirty()
|
||||
except Exception as e:
|
||||
# Don't fail the edit if memory notification fails
|
||||
pass
|
||||
|
||||
return ToolResult.success(result)
|
||||
|
||||
except UnicodeDecodeError:
|
||||
return ToolResult.fail(f"Error: File is not a valid text file (encoding error): {path}")
|
||||
except PermissionError:
|
||||
return ToolResult.fail(f"Error: Permission denied accessing {path}")
|
||||
except Exception as e:
|
||||
return ToolResult.fail(f"Error editing file: {str(e)}")
|
||||
|
||||
def _resolve_path(self, path: str) -> str:
|
||||
"""
|
||||
Resolve path to absolute path
|
||||
|
||||
:param path: Relative or absolute path
|
||||
:return: Absolute path
|
||||
"""
|
||||
# Expand ~ to user home directory
|
||||
path = expand_path(path)
|
||||
if os.path.isabs(path):
|
||||
return path
|
||||
return os.path.abspath(os.path.join(self.cwd, path))
|
||||
3
agent/tools/env_config/__init__.py
Normal file
3
agent/tools/env_config/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from agent.tools.env_config.env_config import EnvConfig
|
||||
|
||||
__all__ = ['EnvConfig']
|
||||
286
agent/tools/env_config/env_config.py
Normal file
286
agent/tools/env_config/env_config.py
Normal file
@@ -0,0 +1,286 @@
|
||||
"""
|
||||
Environment Configuration Tool - Manage API keys and environment variables
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
from typing import Dict, Any
|
||||
from pathlib import Path
|
||||
|
||||
from agent.tools.base_tool import BaseTool, ToolResult
|
||||
from common.log import logger
|
||||
from common.utils import expand_path
|
||||
|
||||
|
||||
# API Key 知识库:常见的环境变量及其描述
|
||||
API_KEY_REGISTRY = {
|
||||
# AI 模型服务
|
||||
"OPENAI_API_KEY": "OpenAI API 密钥 (用于GPT模型、Embedding模型)",
|
||||
"GEMINI_API_KEY": "Google Gemini API 密钥",
|
||||
"CLAUDE_API_KEY": "Claude API 密钥 (用于Claude模型)",
|
||||
"LINKAI_API_KEY": "LinkAI智能体平台 API 密钥,支持多种模型切换",
|
||||
# 搜索服务
|
||||
"BOCHA_API_KEY": "博查 AI 搜索 API 密钥 ",
|
||||
}
|
||||
|
||||
class EnvConfig(BaseTool):
|
||||
"""Tool for managing environment variables (API keys, etc.)"""
|
||||
|
||||
name: str = "env_config"
|
||||
description: str = (
|
||||
"Manage API keys and skill configurations securely. "
|
||||
"Use this tool when user wants to configure API keys (like BOCHA_API_KEY, OPENAI_API_KEY), "
|
||||
"view configured keys, or manage skill settings. "
|
||||
"Actions: 'set' (add/update key), 'get' (view specific key), 'list' (show all configured keys), 'delete' (remove key). "
|
||||
"Values are automatically masked for security. Changes take effect immediately via hot reload."
|
||||
)
|
||||
|
||||
params: dict = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"action": {
|
||||
"type": "string",
|
||||
"description": "Action to perform: 'set', 'get', 'list', 'delete'",
|
||||
"enum": ["set", "get", "list", "delete"]
|
||||
},
|
||||
"key": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Environment variable key name. Common keys:\n"
|
||||
"- OPENAI_API_KEY: OpenAI API (GPT models)\n"
|
||||
"- OPENAI_API_BASE: OpenAI API base URL\n"
|
||||
"- CLAUDE_API_KEY: Anthropic Claude API\n"
|
||||
"- GEMINI_API_KEY: Google Gemini API\n"
|
||||
"- LINKAI_API_KEY: LinkAI platform\n"
|
||||
"- BOCHA_API_KEY: Bocha AI search (博查搜索)\n"
|
||||
"Use exact key names (case-sensitive, all uppercase with underscores)"
|
||||
)
|
||||
},
|
||||
"value": {
|
||||
"type": "string",
|
||||
"description": "Value to set for the environment variable (for 'set' action)"
|
||||
}
|
||||
},
|
||||
"required": ["action"]
|
||||
}
|
||||
|
||||
def __init__(self, config: dict = None):
|
||||
self.config = config or {}
|
||||
# Store env config in ~/.cow directory (outside workspace for security)
|
||||
self.env_dir = expand_path("~/.cow")
|
||||
self.env_path = os.path.join(self.env_dir, '.env')
|
||||
self.agent_bridge = self.config.get("agent_bridge") # Reference to AgentBridge for hot reload
|
||||
# Don't create .env file in __init__ to avoid issues during tool discovery
|
||||
# It will be created on first use in execute()
|
||||
|
||||
def _ensure_env_file(self):
|
||||
"""Ensure the .env file exists"""
|
||||
# Create ~/.cow directory if it doesn't exist
|
||||
os.makedirs(self.env_dir, exist_ok=True)
|
||||
|
||||
if not os.path.exists(self.env_path):
|
||||
Path(self.env_path).touch()
|
||||
logger.info(f"[EnvConfig] Created .env file at {self.env_path}")
|
||||
|
||||
def _mask_value(self, value: str) -> str:
|
||||
"""Mask sensitive parts of a value for logging"""
|
||||
if not value or len(value) <= 10:
|
||||
return "***"
|
||||
return f"{value[:6]}***{value[-4:]}"
|
||||
|
||||
def _read_env_file(self) -> Dict[str, str]:
|
||||
"""Read all key-value pairs from .env file"""
|
||||
env_vars = {}
|
||||
if os.path.exists(self.env_path):
|
||||
with open(self.env_path, 'r', encoding='utf-8') as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
# Skip empty lines and comments
|
||||
if not line or line.startswith('#'):
|
||||
continue
|
||||
# Parse KEY=VALUE
|
||||
match = re.match(r'^([^=]+)=(.*)$', line)
|
||||
if match:
|
||||
key, value = match.groups()
|
||||
env_vars[key.strip()] = value.strip()
|
||||
return env_vars
|
||||
|
||||
def _write_env_file(self, env_vars: Dict[str, str]):
|
||||
"""Write all key-value pairs to .env file"""
|
||||
with open(self.env_path, 'w', encoding='utf-8') as f:
|
||||
f.write("# Environment variables for agent skills\n")
|
||||
f.write("# Auto-managed by env_config tool\n\n")
|
||||
for key, value in sorted(env_vars.items()):
|
||||
f.write(f"{key}={value}\n")
|
||||
|
||||
def _reload_env(self):
|
||||
"""Reload environment variables from .env file"""
|
||||
env_vars = self._read_env_file()
|
||||
for key, value in env_vars.items():
|
||||
os.environ[key] = value
|
||||
logger.debug(f"[EnvConfig] Reloaded {len(env_vars)} environment variables")
|
||||
|
||||
def _refresh_skills(self):
|
||||
"""Refresh skills after environment variable changes"""
|
||||
if self.agent_bridge:
|
||||
try:
|
||||
# Reload .env file
|
||||
self._reload_env()
|
||||
|
||||
# Refresh skills in all agent instances
|
||||
refreshed = self.agent_bridge.refresh_all_skills()
|
||||
logger.info(f"[EnvConfig] Refreshed skills in {refreshed} agent instance(s)")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.warning(f"[EnvConfig] Failed to refresh skills: {e}")
|
||||
return False
|
||||
return False
|
||||
|
||||
def execute(self, args: Dict[str, Any]) -> ToolResult:
|
||||
"""
|
||||
Execute environment configuration operation
|
||||
|
||||
:param args: Contains action, key, and value parameters
|
||||
:return: Result of the operation
|
||||
"""
|
||||
# Ensure .env file exists on first use
|
||||
self._ensure_env_file()
|
||||
|
||||
action = args.get("action")
|
||||
key = args.get("key")
|
||||
value = args.get("value")
|
||||
|
||||
try:
|
||||
if action == "set":
|
||||
if not key or not value:
|
||||
return ToolResult.fail("Error: 'key' and 'value' are required for 'set' action.")
|
||||
|
||||
# Read current env vars
|
||||
env_vars = self._read_env_file()
|
||||
|
||||
# Update the key
|
||||
env_vars[key] = value
|
||||
|
||||
# Write back to file
|
||||
self._write_env_file(env_vars)
|
||||
|
||||
# Update current process env
|
||||
os.environ[key] = value
|
||||
|
||||
logger.info(f"[EnvConfig] Set {key}={self._mask_value(value)}")
|
||||
|
||||
# Try to refresh skills immediately
|
||||
refreshed = self._refresh_skills()
|
||||
|
||||
result = {
|
||||
"message": f"Successfully set {key}",
|
||||
"key": key,
|
||||
"value": self._mask_value(value),
|
||||
}
|
||||
|
||||
if refreshed:
|
||||
result["note"] = "✅ Skills refreshed automatically - changes are now active"
|
||||
else:
|
||||
result["note"] = "⚠️ Skills not refreshed - restart agent to load new skills"
|
||||
|
||||
return ToolResult.success(result)
|
||||
|
||||
elif action == "get":
|
||||
if not key:
|
||||
return ToolResult.fail("Error: 'key' is required for 'get' action.")
|
||||
|
||||
# Check in file first, then in current env
|
||||
env_vars = self._read_env_file()
|
||||
value = env_vars.get(key) or os.getenv(key)
|
||||
|
||||
# Get description from registry
|
||||
description = API_KEY_REGISTRY.get(key, "未知用途的环境变量")
|
||||
|
||||
if value is not None:
|
||||
logger.info(f"[EnvConfig] Got {key}={self._mask_value(value)}")
|
||||
return ToolResult.success({
|
||||
"key": key,
|
||||
"value": self._mask_value(value),
|
||||
"description": description,
|
||||
"exists": True,
|
||||
"note": f"Value is masked for security. In bash, use ${key} directly — it is auto-injected."
|
||||
})
|
||||
else:
|
||||
return ToolResult.success({
|
||||
"key": key,
|
||||
"description": description,
|
||||
"exists": False,
|
||||
"message": f"Environment variable '{key}' is not set"
|
||||
})
|
||||
|
||||
elif action == "list":
|
||||
env_vars = self._read_env_file()
|
||||
|
||||
# Build detailed variable list with descriptions
|
||||
variables_with_info = {}
|
||||
for key, value in env_vars.items():
|
||||
variables_with_info[key] = {
|
||||
"value": self._mask_value(value),
|
||||
"description": API_KEY_REGISTRY.get(key, "未知用途的环境变量")
|
||||
}
|
||||
|
||||
logger.info(f"[EnvConfig] Listed {len(env_vars)} environment variables")
|
||||
|
||||
if not env_vars:
|
||||
return ToolResult.success({
|
||||
"message": "No environment variables configured",
|
||||
"variables": {},
|
||||
"note": "常用的 API 密钥可以通过 env_config(action='set', key='KEY_NAME', value='your-key') 来配置"
|
||||
})
|
||||
|
||||
return ToolResult.success({
|
||||
"message": f"Found {len(env_vars)} environment variable(s)",
|
||||
"variables": variables_with_info
|
||||
})
|
||||
|
||||
elif action == "delete":
|
||||
if not key:
|
||||
return ToolResult.fail("Error: 'key' is required for 'delete' action.")
|
||||
|
||||
# Read current env vars
|
||||
env_vars = self._read_env_file()
|
||||
|
||||
if key not in env_vars:
|
||||
return ToolResult.success({
|
||||
"message": f"Environment variable '{key}' was not set",
|
||||
"key": key
|
||||
})
|
||||
|
||||
# Remove the key
|
||||
del env_vars[key]
|
||||
|
||||
# Write back to file
|
||||
self._write_env_file(env_vars)
|
||||
|
||||
# Remove from current process env
|
||||
if key in os.environ:
|
||||
del os.environ[key]
|
||||
|
||||
logger.info(f"[EnvConfig] Deleted {key}")
|
||||
|
||||
# Try to refresh skills immediately
|
||||
refreshed = self._refresh_skills()
|
||||
|
||||
result = {
|
||||
"message": f"Successfully deleted {key}",
|
||||
"key": key,
|
||||
}
|
||||
|
||||
if refreshed:
|
||||
result["note"] = "✅ Skills refreshed automatically - changes are now active"
|
||||
else:
|
||||
result["note"] = "⚠️ Skills not refreshed - restart agent to apply changes"
|
||||
|
||||
return ToolResult.success(result)
|
||||
|
||||
else:
|
||||
return ToolResult.fail(f"Error: Unknown action '{action}'. Use 'set', 'get', 'list', or 'delete'.")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[EnvConfig] Error: {e}", exc_info=True)
|
||||
return ToolResult.fail(f"EnvConfig tool error: {str(e)}")
|
||||
3
agent/tools/ls/__init__.py
Normal file
3
agent/tools/ls/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .ls import Ls
|
||||
|
||||
__all__ = ['Ls']
|
||||
140
agent/tools/ls/ls.py
Normal file
140
agent/tools/ls/ls.py
Normal file
@@ -0,0 +1,140 @@
|
||||
"""
|
||||
Ls tool - List directory contents
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Dict, Any
|
||||
|
||||
from agent.tools.base_tool import BaseTool, ToolResult
|
||||
from agent.tools.utils.truncate import truncate_head, format_size, DEFAULT_MAX_BYTES
|
||||
from common.utils import expand_path
|
||||
|
||||
|
||||
DEFAULT_LIMIT = 500
|
||||
|
||||
|
||||
class Ls(BaseTool):
|
||||
"""Tool for listing directory contents"""
|
||||
|
||||
name: str = "ls"
|
||||
description: str = f"List directory contents. Returns entries sorted alphabetically, with '/' suffix for directories. Includes dotfiles. Output is truncated to {DEFAULT_LIMIT} entries or {DEFAULT_MAX_BYTES // 1024}KB (whichever is hit first)."
|
||||
|
||||
params: dict = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "Directory to list. IMPORTANT: Relative paths are based on workspace directory. To access directories outside workspace, use absolute paths starting with ~ or /."
|
||||
},
|
||||
"limit": {
|
||||
"type": "integer",
|
||||
"description": f"Maximum number of entries to return (default: {DEFAULT_LIMIT})"
|
||||
}
|
||||
},
|
||||
"required": []
|
||||
}
|
||||
|
||||
def __init__(self, config: dict = None):
|
||||
self.config = config or {}
|
||||
self.cwd = self.config.get("cwd", os.getcwd())
|
||||
|
||||
def execute(self, args: Dict[str, Any]) -> ToolResult:
|
||||
"""
|
||||
Execute directory listing
|
||||
|
||||
:param args: Listing parameters
|
||||
:return: Directory contents or error
|
||||
"""
|
||||
path = args.get("path", ".").strip()
|
||||
limit = args.get("limit", DEFAULT_LIMIT)
|
||||
|
||||
# Resolve path
|
||||
absolute_path = self._resolve_path(path)
|
||||
|
||||
# Security check: Prevent accessing sensitive config directory
|
||||
env_config_dir = expand_path("~/.cow")
|
||||
if os.path.abspath(absolute_path) == os.path.abspath(env_config_dir):
|
||||
return ToolResult.fail(
|
||||
"Error: Access denied. API keys and credentials must be accessed through the env_config tool only."
|
||||
)
|
||||
|
||||
if not os.path.exists(absolute_path):
|
||||
# Provide helpful hint if using relative path
|
||||
if not os.path.isabs(path) and not path.startswith('~'):
|
||||
return ToolResult.fail(
|
||||
f"Error: Path not found: {path}\n"
|
||||
f"Resolved to: {absolute_path}\n"
|
||||
f"Hint: Relative paths are based on workspace ({self.cwd}). For files outside workspace, use absolute paths."
|
||||
)
|
||||
return ToolResult.fail(f"Error: Path not found: {path}")
|
||||
|
||||
if not os.path.isdir(absolute_path):
|
||||
return ToolResult.fail(f"Error: Not a directory: {path}")
|
||||
|
||||
try:
|
||||
# Read directory entries
|
||||
entries = os.listdir(absolute_path)
|
||||
|
||||
# Sort alphabetically (case-insensitive)
|
||||
entries.sort(key=lambda x: x.lower())
|
||||
|
||||
# Format entries with directory indicators
|
||||
results = []
|
||||
entry_limit_reached = False
|
||||
|
||||
for entry in entries:
|
||||
if len(results) >= limit:
|
||||
entry_limit_reached = True
|
||||
break
|
||||
|
||||
full_path = os.path.join(absolute_path, entry)
|
||||
|
||||
try:
|
||||
if os.path.isdir(full_path):
|
||||
results.append(entry + '/')
|
||||
else:
|
||||
results.append(entry)
|
||||
except Exception:
|
||||
# Skip entries we can't stat
|
||||
continue
|
||||
|
||||
if not results:
|
||||
return ToolResult.success({"message": "(empty directory)", "entries": []})
|
||||
|
||||
# Format output
|
||||
raw_output = '\n'.join(results)
|
||||
truncation = truncate_head(raw_output, max_lines=999999) # Only limit by bytes
|
||||
|
||||
output = truncation.content
|
||||
details = {}
|
||||
notices = []
|
||||
|
||||
if entry_limit_reached:
|
||||
notices.append(f"{limit} entries limit reached. Use limit={limit * 2} for more")
|
||||
details["entry_limit_reached"] = limit
|
||||
|
||||
if truncation.truncated:
|
||||
notices.append(f"{format_size(DEFAULT_MAX_BYTES)} limit reached")
|
||||
details["truncation"] = truncation.to_dict()
|
||||
|
||||
if notices:
|
||||
output += f"\n\n[{'. '.join(notices)}]"
|
||||
|
||||
return ToolResult.success({
|
||||
"output": output,
|
||||
"entry_count": len(results),
|
||||
"details": details if details else None
|
||||
})
|
||||
|
||||
except PermissionError:
|
||||
return ToolResult.fail(f"Error: Permission denied reading directory: {path}")
|
||||
except Exception as e:
|
||||
return ToolResult.fail(f"Error listing directory: {str(e)}")
|
||||
|
||||
def _resolve_path(self, path: str) -> str:
|
||||
"""Resolve path to absolute path"""
|
||||
# Expand ~ to user home directory
|
||||
path = expand_path(path)
|
||||
if os.path.isabs(path):
|
||||
return path
|
||||
return os.path.abspath(os.path.join(self.cwd, path))
|
||||
10
agent/tools/memory/__init__.py
Normal file
10
agent/tools/memory/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
||||
"""
|
||||
Memory tools for Agent
|
||||
|
||||
Provides memory_search and memory_get tools
|
||||
"""
|
||||
|
||||
from agent.tools.memory.memory_search import MemorySearchTool
|
||||
from agent.tools.memory.memory_get import MemoryGetTool
|
||||
|
||||
__all__ = ['MemorySearchTool', 'MemoryGetTool']
|
||||
111
agent/tools/memory/memory_get.py
Normal file
111
agent/tools/memory/memory_get.py
Normal file
@@ -0,0 +1,111 @@
|
||||
"""
|
||||
Memory get tool
|
||||
|
||||
Allows agents to read specific sections from memory files
|
||||
"""
|
||||
|
||||
from agent.tools.base_tool import BaseTool
|
||||
|
||||
|
||||
class MemoryGetTool(BaseTool):
|
||||
"""Tool for reading memory file contents"""
|
||||
|
||||
name: str = "memory_get"
|
||||
description: str = (
|
||||
"Read specific content from memory files. "
|
||||
"Use this to get full context from a memory file or specific line range."
|
||||
)
|
||||
params: dict = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "Relative path to the memory file (e.g. 'MEMORY.md', 'memory/2026-01-01.md')"
|
||||
},
|
||||
"start_line": {
|
||||
"type": "integer",
|
||||
"description": "Starting line number (optional, default: 1)",
|
||||
"default": 1
|
||||
},
|
||||
"num_lines": {
|
||||
"type": "integer",
|
||||
"description": "Number of lines to read (optional, reads all if not specified)"
|
||||
}
|
||||
},
|
||||
"required": ["path"]
|
||||
}
|
||||
|
||||
def __init__(self, memory_manager):
|
||||
"""
|
||||
Initialize memory get tool
|
||||
|
||||
Args:
|
||||
memory_manager: MemoryManager instance
|
||||
"""
|
||||
super().__init__()
|
||||
self.memory_manager = memory_manager
|
||||
|
||||
def execute(self, args: dict):
|
||||
"""
|
||||
Execute memory file read
|
||||
|
||||
Args:
|
||||
args: Dictionary with path, start_line, num_lines
|
||||
|
||||
Returns:
|
||||
ToolResult with file content
|
||||
"""
|
||||
from agent.tools.base_tool import ToolResult
|
||||
|
||||
path = args.get("path")
|
||||
start_line = args.get("start_line", 1)
|
||||
num_lines = args.get("num_lines")
|
||||
|
||||
if not path:
|
||||
return ToolResult.fail("Error: path parameter is required")
|
||||
|
||||
try:
|
||||
workspace_dir = self.memory_manager.config.get_workspace()
|
||||
|
||||
# Auto-prepend memory/ if not present and not absolute path
|
||||
# Exception: MEMORY.md is in the root directory
|
||||
if not path.startswith('memory/') and not path.startswith('/') and path != 'MEMORY.md':
|
||||
path = f'memory/{path}'
|
||||
|
||||
file_path = workspace_dir / path
|
||||
|
||||
if not file_path.exists():
|
||||
return ToolResult.fail(f"Error: File not found: {path}")
|
||||
|
||||
content = file_path.read_text(encoding='utf-8')
|
||||
lines = content.split('\n')
|
||||
|
||||
# Handle line range
|
||||
if start_line < 1:
|
||||
start_line = 1
|
||||
|
||||
start_idx = start_line - 1
|
||||
|
||||
if num_lines:
|
||||
end_idx = start_idx + num_lines
|
||||
selected_lines = lines[start_idx:end_idx]
|
||||
else:
|
||||
selected_lines = lines[start_idx:]
|
||||
|
||||
result = '\n'.join(selected_lines)
|
||||
|
||||
# Add metadata
|
||||
total_lines = len(lines)
|
||||
shown_lines = len(selected_lines)
|
||||
|
||||
output = [
|
||||
f"File: {path}",
|
||||
f"Lines: {start_line}-{start_line + shown_lines - 1} (total: {total_lines})",
|
||||
"",
|
||||
result
|
||||
]
|
||||
|
||||
return ToolResult.success('\n'.join(output))
|
||||
|
||||
except Exception as e:
|
||||
return ToolResult.fail(f"Error reading memory file: {str(e)}")
|
||||
102
agent/tools/memory/memory_search.py
Normal file
102
agent/tools/memory/memory_search.py
Normal file
@@ -0,0 +1,102 @@
|
||||
"""
|
||||
Memory search tool
|
||||
|
||||
Allows agents to search their memory using semantic and keyword search
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, Optional
|
||||
from agent.tools.base_tool import BaseTool
|
||||
|
||||
|
||||
class MemorySearchTool(BaseTool):
|
||||
"""Tool for searching agent memory"""
|
||||
|
||||
name: str = "memory_search"
|
||||
description: str = (
|
||||
"Search agent's long-term memory using semantic and keyword search. "
|
||||
"Use this to recall past conversations, preferences, and knowledge."
|
||||
)
|
||||
params: dict = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "Search query (can be natural language question or keywords)"
|
||||
},
|
||||
"max_results": {
|
||||
"type": "integer",
|
||||
"description": "Maximum number of results to return (default: 10)",
|
||||
"default": 10
|
||||
},
|
||||
"min_score": {
|
||||
"type": "number",
|
||||
"description": "Minimum relevance score (0-1, default: 0.1)",
|
||||
"default": 0.1
|
||||
}
|
||||
},
|
||||
"required": ["query"]
|
||||
}
|
||||
|
||||
def __init__(self, memory_manager, user_id: Optional[str] = None):
|
||||
"""
|
||||
Initialize memory search tool
|
||||
|
||||
Args:
|
||||
memory_manager: MemoryManager instance
|
||||
user_id: Optional user ID for scoped search
|
||||
"""
|
||||
super().__init__()
|
||||
self.memory_manager = memory_manager
|
||||
self.user_id = user_id
|
||||
|
||||
def execute(self, args: dict):
|
||||
"""
|
||||
Execute memory search
|
||||
|
||||
Args:
|
||||
args: Dictionary with query, max_results, min_score
|
||||
|
||||
Returns:
|
||||
ToolResult with formatted search results
|
||||
"""
|
||||
from agent.tools.base_tool import ToolResult
|
||||
import asyncio
|
||||
|
||||
query = args.get("query")
|
||||
max_results = args.get("max_results", 10)
|
||||
min_score = args.get("min_score", 0.1)
|
||||
|
||||
if not query:
|
||||
return ToolResult.fail("Error: query parameter is required")
|
||||
|
||||
try:
|
||||
# Run async search in sync context
|
||||
results = asyncio.run(self.memory_manager.search(
|
||||
query=query,
|
||||
user_id=self.user_id,
|
||||
max_results=max_results,
|
||||
min_score=min_score,
|
||||
include_shared=True
|
||||
))
|
||||
|
||||
if not results:
|
||||
# Return clear message that no memories exist yet
|
||||
# This prevents infinite retry loops
|
||||
return ToolResult.success(
|
||||
f"No memories found for '{query}'. "
|
||||
f"This is normal if no memories have been stored yet. "
|
||||
f"You can store new memories by writing to MEMORY.md or memory/YYYY-MM-DD.md files."
|
||||
)
|
||||
|
||||
# Format results
|
||||
output = [f"Found {len(results)} relevant memories:\n"]
|
||||
|
||||
for i, result in enumerate(results, 1):
|
||||
output.append(f"\n{i}. {result.path} (lines {result.start_line}-{result.end_line})")
|
||||
output.append(f" Score: {result.score:.3f}")
|
||||
output.append(f" Snippet: {result.snippet}")
|
||||
|
||||
return ToolResult.success("\n".join(output))
|
||||
|
||||
except Exception as e:
|
||||
return ToolResult.fail(f"Error searching memory: {str(e)}")
|
||||
3
agent/tools/read/__init__.py
Normal file
3
agent/tools/read/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .read import Read
|
||||
|
||||
__all__ = ['Read']
|
||||
443
agent/tools/read/read.py
Normal file
443
agent/tools/read/read.py
Normal file
@@ -0,0 +1,443 @@
|
||||
"""
|
||||
Read tool - Read file contents
|
||||
Supports text files, images (jpg, png, gif, webp), and PDF files
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Dict, Any
|
||||
from pathlib import Path
|
||||
|
||||
from agent.tools.base_tool import BaseTool, ToolResult
|
||||
from agent.tools.utils.truncate import truncate_head, format_size, DEFAULT_MAX_LINES, DEFAULT_MAX_BYTES
|
||||
from common.utils import expand_path
|
||||
|
||||
|
||||
class Read(BaseTool):
|
||||
"""Tool for reading file contents"""
|
||||
|
||||
name: str = "read"
|
||||
description: str = f"Read or inspect file contents. For text/PDF files, returns content (truncated to {DEFAULT_MAX_LINES} lines or {DEFAULT_MAX_BYTES // 1024}KB). For images/videos/audio, returns metadata only (file info, size, type). Use offset/limit for large text files."
|
||||
|
||||
params: dict = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "Path to the file to read. IMPORTANT: Relative paths are based on workspace directory. To access files outside workspace, use absolute paths starting with ~ or /."
|
||||
},
|
||||
"offset": {
|
||||
"type": "integer",
|
||||
"description": "Line number to start reading from (1-indexed, optional). Use negative values to read from end (e.g. -20 for last 20 lines)"
|
||||
},
|
||||
"limit": {
|
||||
"type": "integer",
|
||||
"description": "Maximum number of lines to read (optional)"
|
||||
}
|
||||
},
|
||||
"required": ["path"]
|
||||
}
|
||||
|
||||
def __init__(self, config: dict = None):
|
||||
self.config = config or {}
|
||||
self.cwd = self.config.get("cwd", os.getcwd())
|
||||
|
||||
# File type categories
|
||||
self.image_extensions = {'.jpg', '.jpeg', '.png', '.gif', '.webp', '.bmp', '.svg', '.ico'}
|
||||
self.video_extensions = {'.mp4', '.avi', '.mov', '.mkv', '.flv', '.wmv', '.webm', '.m4v'}
|
||||
self.audio_extensions = {'.mp3', '.wav', '.ogg', '.m4a', '.flac', '.aac', '.wma'}
|
||||
self.binary_extensions = {'.exe', '.dll', '.so', '.dylib', '.bin', '.dat', '.db', '.sqlite'}
|
||||
self.archive_extensions = {'.zip', '.tar', '.gz', '.rar', '.7z', '.bz2', '.xz'}
|
||||
self.pdf_extensions = {'.pdf'}
|
||||
|
||||
# Readable text formats (will be read with truncation)
|
||||
self.text_extensions = {
|
||||
'.txt', '.md', '.markdown', '.rst', '.log', '.csv', '.tsv', '.json', '.xml', '.yaml', '.yml',
|
||||
'.py', '.js', '.ts', '.java', '.c', '.cpp', '.h', '.hpp', '.go', '.rs', '.rb', '.php',
|
||||
'.html', '.css', '.scss', '.sass', '.less', '.vue', '.jsx', '.tsx',
|
||||
'.sh', '.bash', '.zsh', '.fish', '.ps1', '.bat', '.cmd',
|
||||
'.sql', '.r', '.m', '.swift', '.kt', '.scala', '.clj', '.erl', '.ex',
|
||||
'.dockerfile', '.makefile', '.cmake', '.gradle', '.properties', '.ini', '.conf', '.cfg',
|
||||
'.doc', '.docx', '.xls', '.xlsx', '.ppt', '.pptx' # Office documents
|
||||
}
|
||||
|
||||
def execute(self, args: Dict[str, Any]) -> ToolResult:
|
||||
"""
|
||||
Execute file read operation
|
||||
|
||||
:param args: Contains file path and optional offset/limit parameters
|
||||
:return: File content or error message
|
||||
"""
|
||||
# Support 'location' as alias for 'path' (LLM may use it from skill listing)
|
||||
path = args.get("path", "") or args.get("location", "")
|
||||
path = path.strip() if isinstance(path, str) else ""
|
||||
offset = args.get("offset")
|
||||
limit = args.get("limit")
|
||||
|
||||
if not path:
|
||||
return ToolResult.fail("Error: path parameter is required")
|
||||
|
||||
# Resolve path
|
||||
absolute_path = self._resolve_path(path)
|
||||
|
||||
# Security check: Prevent reading sensitive config files
|
||||
env_config_path = expand_path("~/.cow/.env")
|
||||
if os.path.abspath(absolute_path) == os.path.abspath(env_config_path):
|
||||
return ToolResult.fail(
|
||||
"Error: Access denied. API keys and credentials must be accessed through the env_config tool only."
|
||||
)
|
||||
|
||||
# Check if file exists
|
||||
if not os.path.exists(absolute_path):
|
||||
# Provide helpful hint if using relative path
|
||||
if not os.path.isabs(path) and not path.startswith('~'):
|
||||
return ToolResult.fail(
|
||||
f"Error: File not found: {path}\n"
|
||||
f"Resolved to: {absolute_path}\n"
|
||||
f"Hint: Relative paths are based on workspace ({self.cwd}). For files outside workspace, use absolute paths."
|
||||
)
|
||||
return ToolResult.fail(f"Error: File not found: {path}")
|
||||
|
||||
# Check if readable
|
||||
if not os.access(absolute_path, os.R_OK):
|
||||
return ToolResult.fail(f"Error: File is not readable: {path}")
|
||||
|
||||
# Check file type
|
||||
file_ext = Path(absolute_path).suffix.lower()
|
||||
file_size = os.path.getsize(absolute_path)
|
||||
|
||||
# Check if image - return metadata for sending
|
||||
if file_ext in self.image_extensions:
|
||||
return self._read_image(absolute_path, file_ext)
|
||||
|
||||
# Check if video/audio/binary/archive - return metadata only
|
||||
if file_ext in self.video_extensions:
|
||||
return self._return_file_metadata(absolute_path, "video", file_size)
|
||||
if file_ext in self.audio_extensions:
|
||||
return self._return_file_metadata(absolute_path, "audio", file_size)
|
||||
if file_ext in self.binary_extensions or file_ext in self.archive_extensions:
|
||||
return self._return_file_metadata(absolute_path, "binary", file_size)
|
||||
|
||||
# Check if PDF
|
||||
if file_ext in self.pdf_extensions:
|
||||
return self._read_pdf(absolute_path, path, offset, limit)
|
||||
|
||||
# Read text file (with truncation for large files)
|
||||
return self._read_text(absolute_path, path, offset, limit)
|
||||
|
||||
def _resolve_path(self, path: str) -> str:
|
||||
"""
|
||||
Resolve path to absolute path
|
||||
|
||||
:param path: Relative or absolute path
|
||||
:return: Absolute path
|
||||
"""
|
||||
# Expand ~ to user home directory
|
||||
path = expand_path(path)
|
||||
if os.path.isabs(path):
|
||||
return path
|
||||
return os.path.abspath(os.path.join(self.cwd, path))
|
||||
|
||||
def _return_file_metadata(self, absolute_path: str, file_type: str, file_size: int) -> ToolResult:
|
||||
"""
|
||||
Return file metadata for non-readable files (video, audio, binary, etc.)
|
||||
|
||||
:param absolute_path: Absolute path to the file
|
||||
:param file_type: Type of file (video, audio, binary, etc.)
|
||||
:param file_size: File size in bytes
|
||||
:return: File metadata
|
||||
"""
|
||||
file_name = Path(absolute_path).name
|
||||
file_ext = Path(absolute_path).suffix.lower()
|
||||
|
||||
# Determine MIME type
|
||||
mime_types = {
|
||||
# Video
|
||||
'.mp4': 'video/mp4', '.avi': 'video/x-msvideo', '.mov': 'video/quicktime',
|
||||
'.mkv': 'video/x-matroska', '.webm': 'video/webm',
|
||||
# Audio
|
||||
'.mp3': 'audio/mpeg', '.wav': 'audio/wav', '.ogg': 'audio/ogg',
|
||||
'.m4a': 'audio/mp4', '.flac': 'audio/flac',
|
||||
# Binary
|
||||
'.zip': 'application/zip', '.tar': 'application/x-tar',
|
||||
'.gz': 'application/gzip', '.rar': 'application/x-rar-compressed',
|
||||
}
|
||||
mime_type = mime_types.get(file_ext, 'application/octet-stream')
|
||||
|
||||
result = {
|
||||
"type": f"{file_type}_metadata",
|
||||
"file_type": file_type,
|
||||
"path": absolute_path,
|
||||
"file_name": file_name,
|
||||
"mime_type": mime_type,
|
||||
"size": file_size,
|
||||
"size_formatted": format_size(file_size),
|
||||
"message": f"{file_type.capitalize()} 文件: {file_name} ({format_size(file_size)})\n提示: 如果需要发送此文件,请使用 send 工具。"
|
||||
}
|
||||
|
||||
return ToolResult.success(result)
|
||||
|
||||
def _read_image(self, absolute_path: str, file_ext: str) -> ToolResult:
|
||||
"""
|
||||
Read image file - always return metadata only (images should be sent, not read into context)
|
||||
|
||||
:param absolute_path: Absolute path to the image file
|
||||
:param file_ext: File extension
|
||||
:return: Result containing image metadata for sending
|
||||
"""
|
||||
try:
|
||||
# Get file size
|
||||
file_size = os.path.getsize(absolute_path)
|
||||
|
||||
# Determine MIME type
|
||||
mime_type_map = {
|
||||
'.jpg': 'image/jpeg',
|
||||
'.jpeg': 'image/jpeg',
|
||||
'.png': 'image/png',
|
||||
'.gif': 'image/gif',
|
||||
'.webp': 'image/webp'
|
||||
}
|
||||
mime_type = mime_type_map.get(file_ext, 'image/jpeg')
|
||||
|
||||
# Return metadata for images (NOT file_to_send - use send tool to actually send)
|
||||
result = {
|
||||
"type": "image_metadata",
|
||||
"file_type": "image",
|
||||
"path": absolute_path,
|
||||
"mime_type": mime_type,
|
||||
"size": file_size,
|
||||
"size_formatted": format_size(file_size),
|
||||
"message": f"图片文件: {Path(absolute_path).name} ({format_size(file_size)})\n提示: 如果需要发送此图片,请使用 send 工具。"
|
||||
}
|
||||
|
||||
return ToolResult.success(result)
|
||||
|
||||
except Exception as e:
|
||||
return ToolResult.fail(f"Error reading image file: {str(e)}")
|
||||
|
||||
def _read_text(self, absolute_path: str, display_path: str, offset: int = None, limit: int = None) -> ToolResult:
|
||||
"""
|
||||
Read text file
|
||||
|
||||
:param absolute_path: Absolute path to the file
|
||||
:param display_path: Path to display
|
||||
:param offset: Starting line number (1-indexed)
|
||||
:param limit: Maximum number of lines to read
|
||||
:return: File content or error message
|
||||
"""
|
||||
try:
|
||||
# Check file size first
|
||||
file_size = os.path.getsize(absolute_path)
|
||||
MAX_FILE_SIZE = 50 * 1024 * 1024 # 50MB
|
||||
|
||||
if file_size > MAX_FILE_SIZE:
|
||||
# File too large, return metadata only
|
||||
return ToolResult.success({
|
||||
"type": "file_to_send",
|
||||
"file_type": "document",
|
||||
"path": absolute_path,
|
||||
"size": file_size,
|
||||
"size_formatted": format_size(file_size),
|
||||
"message": f"文件过大 ({format_size(file_size)} > 50MB),无法读取内容。文件路径: {absolute_path}"
|
||||
})
|
||||
|
||||
# Read file (utf-8-sig strips BOM automatically on Windows)
|
||||
with open(absolute_path, 'r', encoding='utf-8-sig') as f:
|
||||
content = f.read()
|
||||
|
||||
# Truncate content if too long (20K characters max for model context)
|
||||
MAX_CONTENT_CHARS = 20 * 1024 # 20K characters
|
||||
content_truncated = False
|
||||
if len(content) > MAX_CONTENT_CHARS:
|
||||
content = content[:MAX_CONTENT_CHARS]
|
||||
content_truncated = True
|
||||
|
||||
all_lines = content.split('\n')
|
||||
total_file_lines = len(all_lines)
|
||||
|
||||
# Apply offset (if specified)
|
||||
start_line = 0
|
||||
if offset is not None:
|
||||
if offset < 0:
|
||||
# Negative offset: read from end
|
||||
# -20 means "last 20 lines" → start from (total - 20)
|
||||
start_line = max(0, total_file_lines + offset)
|
||||
else:
|
||||
# Positive offset: read from start (1-indexed)
|
||||
start_line = max(0, offset - 1) # Convert to 0-indexed
|
||||
if start_line >= total_file_lines:
|
||||
return ToolResult.fail(
|
||||
f"Error: Offset {offset} is beyond end of file ({total_file_lines} lines total)"
|
||||
)
|
||||
|
||||
start_line_display = start_line + 1 # For display (1-indexed)
|
||||
|
||||
# If user specified limit, use it
|
||||
selected_content = content
|
||||
user_limited_lines = None
|
||||
if limit is not None:
|
||||
end_line = min(start_line + limit, total_file_lines)
|
||||
selected_content = '\n'.join(all_lines[start_line:end_line])
|
||||
user_limited_lines = end_line - start_line
|
||||
elif offset is not None:
|
||||
selected_content = '\n'.join(all_lines[start_line:])
|
||||
|
||||
# Apply truncation (considering line count and byte limits)
|
||||
truncation = truncate_head(selected_content)
|
||||
|
||||
output_text = ""
|
||||
details = {}
|
||||
|
||||
# Add truncation warning if content was truncated
|
||||
if content_truncated:
|
||||
output_text = f"[文件内容已截断到前 {format_size(MAX_CONTENT_CHARS)},完整文件大小: {format_size(file_size)}]\n\n"
|
||||
|
||||
if truncation.first_line_exceeds_limit:
|
||||
# First line exceeds 30KB limit
|
||||
first_line_size = format_size(len(all_lines[start_line].encode('utf-8')))
|
||||
output_text = f"[Line {start_line_display} is {first_line_size}, exceeds {format_size(DEFAULT_MAX_BYTES)} limit. Use bash tool to read: head -c {DEFAULT_MAX_BYTES} {display_path} | tail -n +{start_line_display}]"
|
||||
details["truncation"] = truncation.to_dict()
|
||||
elif truncation.truncated:
|
||||
# Truncation occurred
|
||||
end_line_display = start_line_display + truncation.output_lines - 1
|
||||
next_offset = end_line_display + 1
|
||||
|
||||
output_text = truncation.content
|
||||
|
||||
if truncation.truncated_by == "lines":
|
||||
output_text += f"\n\n[Showing lines {start_line_display}-{end_line_display} of {total_file_lines}. Use offset={next_offset} to continue.]"
|
||||
else:
|
||||
output_text += f"\n\n[Showing lines {start_line_display}-{end_line_display} of {total_file_lines} ({format_size(DEFAULT_MAX_BYTES)} limit). Use offset={next_offset} to continue.]"
|
||||
|
||||
details["truncation"] = truncation.to_dict()
|
||||
elif user_limited_lines is not None and start_line + user_limited_lines < total_file_lines:
|
||||
# User specified limit, more content available, but no truncation
|
||||
remaining = total_file_lines - (start_line + user_limited_lines)
|
||||
next_offset = start_line + user_limited_lines + 1
|
||||
|
||||
output_text = truncation.content
|
||||
output_text += f"\n\n[{remaining} more lines in file. Use offset={next_offset} to continue.]"
|
||||
else:
|
||||
# No truncation, no exceeding user limit
|
||||
output_text = truncation.content
|
||||
|
||||
result = {
|
||||
"content": output_text,
|
||||
"total_lines": total_file_lines,
|
||||
"start_line": start_line_display,
|
||||
"output_lines": truncation.output_lines
|
||||
}
|
||||
|
||||
if details:
|
||||
result["details"] = details
|
||||
|
||||
return ToolResult.success(result)
|
||||
|
||||
except UnicodeDecodeError:
|
||||
return ToolResult.fail(f"Error: File is not a valid text file (encoding error): {display_path}")
|
||||
except Exception as e:
|
||||
return ToolResult.fail(f"Error reading file: {str(e)}")
|
||||
|
||||
def _read_pdf(self, absolute_path: str, display_path: str, offset: int = None, limit: int = None) -> ToolResult:
|
||||
"""
|
||||
Read PDF file content
|
||||
|
||||
:param absolute_path: Absolute path to the file
|
||||
:param display_path: Path to display
|
||||
:param offset: Starting line number (1-indexed)
|
||||
:param limit: Maximum number of lines to read
|
||||
:return: PDF text content or error message
|
||||
"""
|
||||
try:
|
||||
# Try to import pypdf
|
||||
try:
|
||||
from pypdf import PdfReader
|
||||
except ImportError:
|
||||
return ToolResult.fail(
|
||||
"Error: pypdf library not installed. Install with: pip install pypdf"
|
||||
)
|
||||
|
||||
# Read PDF
|
||||
reader = PdfReader(absolute_path)
|
||||
total_pages = len(reader.pages)
|
||||
|
||||
# Extract text from all pages
|
||||
text_parts = []
|
||||
for page_num, page in enumerate(reader.pages, 1):
|
||||
page_text = page.extract_text()
|
||||
if page_text.strip():
|
||||
text_parts.append(f"--- Page {page_num} ---\n{page_text}")
|
||||
|
||||
if not text_parts:
|
||||
return ToolResult.success({
|
||||
"content": f"[PDF file with {total_pages} pages, but no text content could be extracted]",
|
||||
"total_pages": total_pages,
|
||||
"message": "PDF may contain only images or be encrypted"
|
||||
})
|
||||
|
||||
# Merge all text
|
||||
full_content = "\n\n".join(text_parts)
|
||||
all_lines = full_content.split('\n')
|
||||
total_lines = len(all_lines)
|
||||
|
||||
# Apply offset and limit (same logic as text files)
|
||||
start_line = 0
|
||||
if offset is not None:
|
||||
start_line = max(0, offset - 1)
|
||||
if start_line >= total_lines:
|
||||
return ToolResult.fail(
|
||||
f"Error: Offset {offset} is beyond end of content ({total_lines} lines total)"
|
||||
)
|
||||
|
||||
start_line_display = start_line + 1
|
||||
|
||||
selected_content = full_content
|
||||
user_limited_lines = None
|
||||
if limit is not None:
|
||||
end_line = min(start_line + limit, total_lines)
|
||||
selected_content = '\n'.join(all_lines[start_line:end_line])
|
||||
user_limited_lines = end_line - start_line
|
||||
elif offset is not None:
|
||||
selected_content = '\n'.join(all_lines[start_line:])
|
||||
|
||||
# Apply truncation
|
||||
truncation = truncate_head(selected_content)
|
||||
|
||||
output_text = ""
|
||||
details = {}
|
||||
|
||||
if truncation.truncated:
|
||||
end_line_display = start_line_display + truncation.output_lines - 1
|
||||
next_offset = end_line_display + 1
|
||||
|
||||
output_text = truncation.content
|
||||
|
||||
if truncation.truncated_by == "lines":
|
||||
output_text += f"\n\n[Showing lines {start_line_display}-{end_line_display} of {total_lines}. Use offset={next_offset} to continue.]"
|
||||
else:
|
||||
output_text += f"\n\n[Showing lines {start_line_display}-{end_line_display} of {total_lines} ({format_size(DEFAULT_MAX_BYTES)} limit). Use offset={next_offset} to continue.]"
|
||||
|
||||
details["truncation"] = truncation.to_dict()
|
||||
elif user_limited_lines is not None and start_line + user_limited_lines < total_lines:
|
||||
remaining = total_lines - (start_line + user_limited_lines)
|
||||
next_offset = start_line + user_limited_lines + 1
|
||||
|
||||
output_text = truncation.content
|
||||
output_text += f"\n\n[{remaining} more lines in file. Use offset={next_offset} to continue.]"
|
||||
else:
|
||||
output_text = truncation.content
|
||||
|
||||
result = {
|
||||
"content": output_text,
|
||||
"total_pages": total_pages,
|
||||
"total_lines": total_lines,
|
||||
"start_line": start_line_display,
|
||||
"output_lines": truncation.output_lines
|
||||
}
|
||||
|
||||
if details:
|
||||
result["details"] = details
|
||||
|
||||
return ToolResult.success(result)
|
||||
|
||||
except Exception as e:
|
||||
return ToolResult.fail(f"Error reading PDF file: {str(e)}")
|
||||
287
agent/tools/scheduler/README.md
Normal file
287
agent/tools/scheduler/README.md
Normal file
@@ -0,0 +1,287 @@
|
||||
# 定时任务工具 (Scheduler Tool)
|
||||
|
||||
## 功能简介
|
||||
|
||||
定时任务工具允许 Agent 创建、管理和执行定时任务,支持:
|
||||
|
||||
- ⏰ **定时提醒**: 在指定时间发送消息
|
||||
- 🔄 **周期性任务**: 按固定间隔或 cron 表达式重复执行
|
||||
- 🔧 **动态工具调用**: 定时执行其他工具并发送结果(如搜索新闻、查询天气等)
|
||||
- 📋 **任务管理**: 查询、启用、禁用、删除任务
|
||||
|
||||
## 安装依赖
|
||||
|
||||
```bash
|
||||
pip install croniter>=2.0.0
|
||||
```
|
||||
|
||||
## 使用方法
|
||||
|
||||
### 1. 创建定时任务
|
||||
|
||||
Agent 可以通过自然语言创建定时任务,支持两种类型:
|
||||
|
||||
#### 1.1 静态消息任务
|
||||
|
||||
发送预定义的消息:
|
||||
|
||||
**示例对话:**
|
||||
```
|
||||
用户: 每天早上9点提醒我开会
|
||||
Agent: [调用 scheduler 工具]
|
||||
action: create
|
||||
name: 每日开会提醒
|
||||
message: 该开会了!
|
||||
schedule_type: cron
|
||||
schedule_value: 0 9 * * *
|
||||
```
|
||||
|
||||
#### 1.2 动态工具调用任务
|
||||
|
||||
定时执行工具并发送结果:
|
||||
|
||||
**示例对话:**
|
||||
```
|
||||
用户: 每天早上8点帮我读取一下今日日程
|
||||
Agent: [调用 scheduler 工具]
|
||||
action: create
|
||||
name: 每日日程
|
||||
tool_call:
|
||||
tool_name: read
|
||||
tool_params:
|
||||
file_path: ~/cow/schedule.txt
|
||||
result_prefix: 📅 今日日程
|
||||
schedule_type: cron
|
||||
schedule_value: 0 8 * * *
|
||||
```
|
||||
|
||||
**工具调用参数说明:**
|
||||
- `tool_name`: 要调用的工具名称(如 `bash`、`read`、`write` 等内置工具)
|
||||
- `tool_params`: 工具的参数(字典格式)
|
||||
- `result_prefix`: 可选,在结果前添加的前缀文本
|
||||
|
||||
**注意:** 如果要使用 skills(如 bocha-search),需要通过 `bash` 工具调用 skill 脚本
|
||||
|
||||
### 2. 支持的调度类型
|
||||
|
||||
#### Cron 表达式 (`cron`)
|
||||
使用标准 cron 表达式:
|
||||
|
||||
```
|
||||
0 9 * * * # 每天 9:00
|
||||
0 */2 * * * # 每 2 小时
|
||||
30 8 * * 1-5 # 工作日 8:30
|
||||
0 0 1 * * # 每月 1 号
|
||||
```
|
||||
|
||||
#### 固定间隔 (`interval`)
|
||||
以秒为单位的间隔:
|
||||
|
||||
```
|
||||
3600 # 每小时
|
||||
86400 # 每天
|
||||
1800 # 每 30 分钟
|
||||
```
|
||||
|
||||
#### 一次性任务 (`once`)
|
||||
指定具体时间(ISO 格式):
|
||||
|
||||
```
|
||||
2024-12-25T09:00:00
|
||||
2024-12-31T23:59:59
|
||||
```
|
||||
|
||||
### 3. 查询任务列表
|
||||
|
||||
```
|
||||
用户: 查看我的定时任务
|
||||
Agent: [调用 scheduler 工具]
|
||||
action: list
|
||||
```
|
||||
|
||||
### 4. 查看任务详情
|
||||
|
||||
```
|
||||
用户: 查看任务 abc123 的详情
|
||||
Agent: [调用 scheduler 工具]
|
||||
action: get
|
||||
task_id: abc123
|
||||
```
|
||||
|
||||
### 5. 删除任务
|
||||
|
||||
```
|
||||
用户: 删除任务 abc123
|
||||
Agent: [调用 scheduler 工具]
|
||||
action: delete
|
||||
task_id: abc123
|
||||
```
|
||||
|
||||
### 6. 启用/禁用任务
|
||||
|
||||
```
|
||||
用户: 暂停任务 abc123
|
||||
Agent: [调用 scheduler 工具]
|
||||
action: disable
|
||||
task_id: abc123
|
||||
|
||||
用户: 恢复任务 abc123
|
||||
Agent: [调用 scheduler 工具]
|
||||
action: enable
|
||||
task_id: abc123
|
||||
```
|
||||
|
||||
## 任务存储
|
||||
|
||||
任务保存在 JSON 文件中:
|
||||
```
|
||||
~/cow/scheduler/tasks.json
|
||||
```
|
||||
|
||||
任务数据结构:
|
||||
|
||||
**静态消息任务:**
|
||||
```json
|
||||
{
|
||||
"id": "abc123",
|
||||
"name": "每日提醒",
|
||||
"enabled": true,
|
||||
"created_at": "2024-01-01T10:00:00",
|
||||
"updated_at": "2024-01-01T10:00:00",
|
||||
"schedule": {
|
||||
"type": "cron",
|
||||
"expression": "0 9 * * *"
|
||||
},
|
||||
"action": {
|
||||
"type": "send_message",
|
||||
"content": "该开会了!",
|
||||
"receiver": "wxid_xxx",
|
||||
"receiver_name": "张三",
|
||||
"is_group": false,
|
||||
"channel_type": "wechat"
|
||||
},
|
||||
"next_run_at": "2024-01-02T09:00:00",
|
||||
"last_run_at": "2024-01-01T09:00:00"
|
||||
}
|
||||
```
|
||||
|
||||
**动态工具调用任务:**
|
||||
```json
|
||||
{
|
||||
"id": "def456",
|
||||
"name": "每日日程",
|
||||
"enabled": true,
|
||||
"created_at": "2024-01-01T10:00:00",
|
||||
"updated_at": "2024-01-01T10:00:00",
|
||||
"schedule": {
|
||||
"type": "cron",
|
||||
"expression": "0 8 * * *"
|
||||
},
|
||||
"action": {
|
||||
"type": "tool_call",
|
||||
"tool_name": "read",
|
||||
"tool_params": {
|
||||
"file_path": "~/cow/schedule.txt"
|
||||
},
|
||||
"result_prefix": "📅 今日日程",
|
||||
"receiver": "wxid_xxx",
|
||||
"receiver_name": "张三",
|
||||
"is_group": false,
|
||||
"channel_type": "wechat"
|
||||
},
|
||||
"next_run_at": "2024-01-02T08:00:00"
|
||||
}
|
||||
```
|
||||
|
||||
## 后台服务
|
||||
|
||||
定时任务由后台服务 `SchedulerService` 管理:
|
||||
|
||||
- 每 30 秒检查一次到期任务
|
||||
- 自动执行到期任务
|
||||
- 计算下次执行时间
|
||||
- 记录执行历史和错误
|
||||
|
||||
服务在 Agent 初始化时自动启动,无需手动配置。
|
||||
|
||||
## 接收者确定
|
||||
|
||||
定时任务会发送给**创建任务时的对话对象**:
|
||||
|
||||
- 如果在私聊中创建,发送给该用户
|
||||
- 如果在群聊中创建,发送到该群
|
||||
- 接收者信息在创建时自动保存
|
||||
|
||||
## 常见用例
|
||||
|
||||
### 1. 每日提醒(静态消息)
|
||||
```
|
||||
用户: 每天早上8点提醒我吃药
|
||||
Agent: ✅ 定时任务创建成功
|
||||
任务ID: a1b2c3d4
|
||||
调度: 每天 8:00
|
||||
消息: 该吃药了!
|
||||
```
|
||||
|
||||
### 2. 工作日提醒(静态消息)
|
||||
```
|
||||
用户: 工作日下午6点提醒我下班
|
||||
Agent: [创建 cron: 0 18 * * 1-5]
|
||||
消息: 该下班了!
|
||||
```
|
||||
|
||||
### 3. 倒计时提醒(静态消息)
|
||||
```
|
||||
用户: 1小时后提醒我
|
||||
Agent: [创建 interval: 3600]
|
||||
```
|
||||
|
||||
### 4. 每日日程推送(动态工具调用)
|
||||
```
|
||||
用户: 每天早上8点帮我读取今日日程
|
||||
Agent: ✅ 定时任务创建成功
|
||||
任务ID: schedule001
|
||||
调度: 每天 8:00
|
||||
工具: read(file_path='~/cow/schedule.txt')
|
||||
前缀: 📅 今日日程
|
||||
```
|
||||
|
||||
### 5. 定时文件备份(动态工具调用)
|
||||
```
|
||||
用户: 每天晚上11点备份工作文件
|
||||
Agent: [创建 cron: 0 23 * * *]
|
||||
工具: bash(command='cp ~/cow/work.txt ~/cow/backup/work_$(date +%Y%m%d).txt')
|
||||
前缀: ✅ 文件已备份
|
||||
```
|
||||
|
||||
### 6. 周报提醒(静态消息)
|
||||
```
|
||||
用户: 每周五下午5点提醒我写周报
|
||||
Agent: [创建 cron: 0 17 * * 5]
|
||||
消息: 📊 该写周报了!
|
||||
```
|
||||
|
||||
### 4. 特定日期提醒
|
||||
```
|
||||
用户: 12月25日早上9点提醒我圣诞快乐
|
||||
Agent: [创建 once: 2024-12-25T09:00:00]
|
||||
```
|
||||
|
||||
## 注意事项
|
||||
|
||||
1. **时区**: 使用系统本地时区
|
||||
2. **精度**: 检查间隔为 30 秒,实际执行可能有 ±30 秒误差
|
||||
3. **持久化**: 任务保存在文件中,重启后自动恢复
|
||||
4. **一次性任务**: 执行后自动禁用,不会删除(可手动删除)
|
||||
5. **错误处理**: 执行失败会记录错误,不影响其他任务
|
||||
|
||||
## 技术实现
|
||||
|
||||
- **TaskStore**: 任务持久化存储
|
||||
- **SchedulerService**: 后台调度服务
|
||||
- **SchedulerTool**: Agent 工具接口
|
||||
- **Integration**: 与 AgentBridge 集成
|
||||
|
||||
## 依赖
|
||||
|
||||
- `croniter`: Cron 表达式解析(轻量级,仅 ~50KB)
|
||||
7
agent/tools/scheduler/__init__.py
Normal file
7
agent/tools/scheduler/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
"""
|
||||
Scheduler tool for managing scheduled tasks
|
||||
"""
|
||||
|
||||
from .scheduler_tool import SchedulerTool
|
||||
|
||||
__all__ = ["SchedulerTool"]
|
||||
464
agent/tools/scheduler/integration.py
Normal file
464
agent/tools/scheduler/integration.py
Normal file
@@ -0,0 +1,464 @@
|
||||
"""
|
||||
Integration module for scheduler with AgentBridge
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Optional
|
||||
from config import conf
|
||||
from common.log import logger
|
||||
from common.utils import expand_path
|
||||
from bridge.context import Context, ContextType
|
||||
from bridge.reply import Reply, ReplyType
|
||||
|
||||
# Global scheduler service instance
|
||||
_scheduler_service = None
|
||||
_task_store = None
|
||||
|
||||
|
||||
def init_scheduler(agent_bridge) -> bool:
|
||||
"""
|
||||
Initialize scheduler service
|
||||
|
||||
Args:
|
||||
agent_bridge: AgentBridge instance
|
||||
|
||||
Returns:
|
||||
True if initialized successfully
|
||||
"""
|
||||
global _scheduler_service, _task_store
|
||||
|
||||
try:
|
||||
from agent.tools.scheduler.task_store import TaskStore
|
||||
from agent.tools.scheduler.scheduler_service import SchedulerService
|
||||
|
||||
# Get workspace from config
|
||||
workspace_root = expand_path(conf().get("agent_workspace", "~/cow"))
|
||||
store_path = os.path.join(workspace_root, "scheduler", "tasks.json")
|
||||
|
||||
# Create task store
|
||||
_task_store = TaskStore(store_path)
|
||||
logger.debug(f"[Scheduler] Task store initialized: {store_path}")
|
||||
|
||||
# Create execute callback
|
||||
def execute_task_callback(task: dict):
|
||||
"""Callback to execute a scheduled task"""
|
||||
try:
|
||||
action = task.get("action", {})
|
||||
action_type = action.get("type")
|
||||
|
||||
if action_type == "agent_task":
|
||||
_execute_agent_task(task, agent_bridge)
|
||||
elif action_type == "send_message":
|
||||
# Legacy support for old tasks
|
||||
_execute_send_message(task, agent_bridge)
|
||||
elif action_type == "tool_call":
|
||||
# Legacy support for old tasks
|
||||
_execute_tool_call(task, agent_bridge)
|
||||
elif action_type == "skill_call":
|
||||
# Legacy support for old tasks
|
||||
_execute_skill_call(task, agent_bridge)
|
||||
else:
|
||||
logger.warning(f"[Scheduler] Unknown action type: {action_type}")
|
||||
except Exception as e:
|
||||
logger.error(f"[Scheduler] Error executing task {task.get('id')}: {e}")
|
||||
|
||||
# Create scheduler service
|
||||
_scheduler_service = SchedulerService(_task_store, execute_task_callback)
|
||||
_scheduler_service.start()
|
||||
|
||||
logger.debug("[Scheduler] Scheduler service initialized and started")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Scheduler] Failed to initialize scheduler: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def get_task_store():
|
||||
"""Get the global task store instance"""
|
||||
return _task_store
|
||||
|
||||
|
||||
def get_scheduler_service():
|
||||
"""Get the global scheduler service instance"""
|
||||
return _scheduler_service
|
||||
|
||||
|
||||
def _execute_agent_task(task: dict, agent_bridge):
|
||||
"""
|
||||
Execute an agent_task action - let Agent handle the task
|
||||
|
||||
Args:
|
||||
task: Task dictionary
|
||||
agent_bridge: AgentBridge instance
|
||||
"""
|
||||
try:
|
||||
action = task.get("action", {})
|
||||
task_description = action.get("task_description")
|
||||
receiver = action.get("receiver")
|
||||
is_group = action.get("is_group", False)
|
||||
channel_type = action.get("channel_type", "unknown")
|
||||
|
||||
if not task_description:
|
||||
logger.error(f"[Scheduler] Task {task['id']}: No task_description specified")
|
||||
return
|
||||
|
||||
if not receiver:
|
||||
logger.error(f"[Scheduler] Task {task['id']}: No receiver specified")
|
||||
return
|
||||
|
||||
# Check for unsupported channels
|
||||
if channel_type == "dingtalk":
|
||||
logger.warning(f"[Scheduler] Task {task['id']}: DingTalk channel does not support scheduled messages (Stream mode limitation). Task will execute but message cannot be sent.")
|
||||
|
||||
logger.info(f"[Scheduler] Task {task['id']}: Executing agent task '{task_description}'")
|
||||
|
||||
# Create a unique session_id for this scheduled task to avoid polluting user's conversation
|
||||
# Format: scheduler_<receiver>_<task_id> to ensure isolation
|
||||
scheduler_session_id = f"scheduler_{receiver}_{task['id']}"
|
||||
|
||||
# Create context for Agent
|
||||
context = Context(ContextType.TEXT, task_description)
|
||||
context["receiver"] = receiver
|
||||
context["isgroup"] = is_group
|
||||
context["session_id"] = scheduler_session_id
|
||||
|
||||
# Channel-specific setup
|
||||
if channel_type == "web":
|
||||
import uuid
|
||||
request_id = f"scheduler_{task['id']}_{uuid.uuid4().hex[:8]}"
|
||||
context["request_id"] = request_id
|
||||
elif channel_type == "feishu":
|
||||
context["receive_id_type"] = "chat_id" if is_group else "open_id"
|
||||
context["msg"] = None
|
||||
elif channel_type == "dingtalk":
|
||||
# DingTalk requires msg object, set to None for scheduled tasks
|
||||
context["msg"] = None
|
||||
if not is_group:
|
||||
sender_staff_id = action.get("dingtalk_sender_staff_id")
|
||||
if sender_staff_id:
|
||||
context["dingtalk_sender_staff_id"] = sender_staff_id
|
||||
elif channel_type == "wecom_bot":
|
||||
context["msg"] = None
|
||||
|
||||
# Use Agent to execute the task
|
||||
# Mark this as a scheduled task execution to prevent recursive task creation
|
||||
context["is_scheduled_task"] = True
|
||||
|
||||
try:
|
||||
# Don't clear history - scheduler tasks use isolated session_id so they won't pollute user conversations
|
||||
reply = agent_bridge.agent_reply(task_description, context=context, on_event=None, clear_history=False)
|
||||
|
||||
if reply and reply.content:
|
||||
# Send the reply via channel
|
||||
from channel.channel_factory import create_channel
|
||||
|
||||
try:
|
||||
channel = create_channel(channel_type)
|
||||
if channel:
|
||||
# For web channel, register request_id
|
||||
if channel_type == "web" and hasattr(channel, 'request_to_session'):
|
||||
request_id = context.get("request_id")
|
||||
if request_id:
|
||||
channel.request_to_session[request_id] = receiver
|
||||
logger.debug(f"[Scheduler] Registered request_id {request_id} -> session {receiver}")
|
||||
|
||||
# Send the reply
|
||||
channel.send(reply, context)
|
||||
logger.info(f"[Scheduler] Task {task['id']} executed successfully, result sent to {receiver}")
|
||||
else:
|
||||
logger.error(f"[Scheduler] Failed to create channel: {channel_type}")
|
||||
except Exception as e:
|
||||
logger.error(f"[Scheduler] Failed to send result: {e}")
|
||||
else:
|
||||
logger.error(f"[Scheduler] Task {task['id']}: No result from agent execution")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Scheduler] Failed to execute task via Agent: {e}")
|
||||
import traceback
|
||||
logger.error(f"[Scheduler] Traceback: {traceback.format_exc()}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Scheduler] Error in _execute_agent_task: {e}")
|
||||
import traceback
|
||||
logger.error(f"[Scheduler] Traceback: {traceback.format_exc()}")
|
||||
|
||||
|
||||
def _execute_send_message(task: dict, agent_bridge):
|
||||
"""
|
||||
Execute a send_message action
|
||||
|
||||
Args:
|
||||
task: Task dictionary
|
||||
agent_bridge: AgentBridge instance
|
||||
"""
|
||||
try:
|
||||
action = task.get("action", {})
|
||||
content = action.get("content", "")
|
||||
receiver = action.get("receiver")
|
||||
is_group = action.get("is_group", False)
|
||||
channel_type = action.get("channel_type", "unknown")
|
||||
|
||||
if not receiver:
|
||||
logger.error(f"[Scheduler] Task {task['id']}: No receiver specified")
|
||||
return
|
||||
|
||||
# Create context for sending message
|
||||
context = Context(ContextType.TEXT, content)
|
||||
context["receiver"] = receiver
|
||||
context["isgroup"] = is_group
|
||||
context["session_id"] = receiver
|
||||
|
||||
# Channel-specific context setup
|
||||
if channel_type == "web":
|
||||
# Web channel needs request_id
|
||||
import uuid
|
||||
request_id = f"scheduler_{task['id']}_{uuid.uuid4().hex[:8]}"
|
||||
context["request_id"] = request_id
|
||||
logger.debug(f"[Scheduler] Generated request_id for web channel: {request_id}")
|
||||
elif channel_type == "feishu":
|
||||
# Feishu channel: for scheduled tasks, send as new message (no msg_id to reply to)
|
||||
# Use chat_id for groups, open_id for private chats
|
||||
context["receive_id_type"] = "chat_id" if is_group else "open_id"
|
||||
# Keep isgroup as is, but set msg to None (no original message to reply to)
|
||||
# Feishu channel will detect this and send as new message instead of reply
|
||||
context["msg"] = None
|
||||
logger.debug(f"[Scheduler] Feishu: receive_id_type={context['receive_id_type']}, is_group={is_group}, receiver={receiver}")
|
||||
elif channel_type == "dingtalk":
|
||||
# DingTalk channel setup
|
||||
context["msg"] = None
|
||||
# 如果是单聊,需要传递 sender_staff_id
|
||||
if not is_group:
|
||||
sender_staff_id = action.get("dingtalk_sender_staff_id")
|
||||
if sender_staff_id:
|
||||
context["dingtalk_sender_staff_id"] = sender_staff_id
|
||||
logger.debug(f"[Scheduler] DingTalk single chat: sender_staff_id={sender_staff_id}")
|
||||
else:
|
||||
logger.warning(f"[Scheduler] Task {task['id']}: DingTalk single chat message missing sender_staff_id")
|
||||
elif channel_type == "wecom_bot":
|
||||
context["msg"] = None
|
||||
elif channel_type == "qq":
|
||||
context["msg"] = None
|
||||
|
||||
# Create reply
|
||||
reply = Reply(ReplyType.TEXT, content)
|
||||
|
||||
# Get channel and send
|
||||
from channel.channel_factory import create_channel
|
||||
|
||||
try:
|
||||
channel = create_channel(channel_type)
|
||||
if channel:
|
||||
# For web channel, register the request_id to session mapping
|
||||
if channel_type == "web" and hasattr(channel, 'request_to_session'):
|
||||
channel.request_to_session[request_id] = receiver
|
||||
logger.debug(f"[Scheduler] Registered request_id {request_id} -> session {receiver}")
|
||||
|
||||
channel.send(reply, context)
|
||||
logger.info(f"[Scheduler] Task {task['id']} executed: sent message to {receiver}")
|
||||
else:
|
||||
logger.error(f"[Scheduler] Failed to create channel: {channel_type}")
|
||||
except Exception as e:
|
||||
logger.error(f"[Scheduler] Failed to send message: {e}")
|
||||
import traceback
|
||||
logger.error(f"[Scheduler] Traceback: {traceback.format_exc()}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Scheduler] Error in _execute_send_message: {e}")
|
||||
import traceback
|
||||
logger.error(f"[Scheduler] Traceback: {traceback.format_exc()}")
|
||||
|
||||
|
||||
def _execute_tool_call(task: dict, agent_bridge):
|
||||
"""
|
||||
Execute a tool_call action
|
||||
|
||||
Args:
|
||||
task: Task dictionary
|
||||
agent_bridge: AgentBridge instance
|
||||
"""
|
||||
try:
|
||||
action = task.get("action", {})
|
||||
# Support both old and new field names
|
||||
tool_name = action.get("call_name") or action.get("tool_name")
|
||||
tool_params = action.get("call_params") or action.get("tool_params", {})
|
||||
result_prefix = action.get("result_prefix", "")
|
||||
receiver = action.get("receiver")
|
||||
is_group = action.get("is_group", False)
|
||||
channel_type = action.get("channel_type", "unknown")
|
||||
|
||||
if not tool_name:
|
||||
logger.error(f"[Scheduler] Task {task['id']}: No tool_name specified")
|
||||
return
|
||||
|
||||
if not receiver:
|
||||
logger.error(f"[Scheduler] Task {task['id']}: No receiver specified")
|
||||
return
|
||||
|
||||
# Get tool manager and create tool instance
|
||||
from agent.tools.tool_manager import ToolManager
|
||||
tool_manager = ToolManager()
|
||||
tool = tool_manager.create_tool(tool_name)
|
||||
|
||||
if not tool:
|
||||
logger.error(f"[Scheduler] Task {task['id']}: Tool '{tool_name}' not found")
|
||||
return
|
||||
|
||||
# Execute tool
|
||||
logger.info(f"[Scheduler] Task {task['id']}: Executing tool '{tool_name}' with params {tool_params}")
|
||||
result = tool.execute(tool_params)
|
||||
|
||||
# Get result content
|
||||
if hasattr(result, 'result'):
|
||||
content = result.result
|
||||
else:
|
||||
content = str(result)
|
||||
|
||||
# Add prefix if specified
|
||||
if result_prefix:
|
||||
content = f"{result_prefix}\n\n{content}"
|
||||
|
||||
# Send result as message
|
||||
context = Context(ContextType.TEXT, content)
|
||||
context["receiver"] = receiver
|
||||
context["isgroup"] = is_group
|
||||
context["session_id"] = receiver
|
||||
|
||||
# Channel-specific context setup
|
||||
if channel_type == "web":
|
||||
# Web channel needs request_id
|
||||
import uuid
|
||||
request_id = f"scheduler_{task['id']}_{uuid.uuid4().hex[:8]}"
|
||||
context["request_id"] = request_id
|
||||
logger.debug(f"[Scheduler] Generated request_id for web channel: {request_id}")
|
||||
elif channel_type == "feishu":
|
||||
context["receive_id_type"] = "chat_id" if is_group else "open_id"
|
||||
context["msg"] = None
|
||||
logger.debug(f"[Scheduler] Feishu: receive_id_type={context['receive_id_type']}, is_group={is_group}, receiver={receiver}")
|
||||
elif channel_type == "wecom_bot":
|
||||
context["msg"] = None
|
||||
|
||||
reply = Reply(ReplyType.TEXT, content)
|
||||
|
||||
# Get channel and send
|
||||
from channel.channel_factory import create_channel
|
||||
|
||||
try:
|
||||
channel = create_channel(channel_type)
|
||||
if channel:
|
||||
if channel_type == "web" and hasattr(channel, 'request_to_session'):
|
||||
channel.request_to_session[request_id] = receiver
|
||||
logger.debug(f"[Scheduler] Registered request_id {request_id} -> session {receiver}")
|
||||
|
||||
channel.send(reply, context)
|
||||
logger.info(f"[Scheduler] Task {task['id']} executed: sent tool result to {receiver}")
|
||||
else:
|
||||
logger.error(f"[Scheduler] Failed to create channel: {channel_type}")
|
||||
except Exception as e:
|
||||
logger.error(f"[Scheduler] Failed to send tool result: {e}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Scheduler] Error in _execute_tool_call: {e}")
|
||||
|
||||
|
||||
def _execute_skill_call(task: dict, agent_bridge):
|
||||
"""
|
||||
Execute a skill_call action by asking Agent to run the skill
|
||||
|
||||
Args:
|
||||
task: Task dictionary
|
||||
agent_bridge: AgentBridge instance
|
||||
"""
|
||||
try:
|
||||
action = task.get("action", {})
|
||||
# Support both old and new field names
|
||||
skill_name = action.get("call_name") or action.get("skill_name")
|
||||
skill_params = action.get("call_params") or action.get("skill_params", {})
|
||||
result_prefix = action.get("result_prefix", "")
|
||||
receiver = action.get("receiver")
|
||||
is_group = action.get("isgroup", False)
|
||||
channel_type = action.get("channel_type", "unknown")
|
||||
|
||||
if not skill_name:
|
||||
logger.error(f"[Scheduler] Task {task['id']}: No skill_name specified")
|
||||
return
|
||||
|
||||
if not receiver:
|
||||
logger.error(f"[Scheduler] Task {task['id']}: No receiver specified")
|
||||
return
|
||||
|
||||
logger.info(f"[Scheduler] Task {task['id']}: Executing skill '{skill_name}' with params {skill_params}")
|
||||
|
||||
# Create a unique session_id for this scheduled task to avoid polluting user's conversation
|
||||
# Format: scheduler_<receiver>_<task_id> to ensure isolation
|
||||
scheduler_session_id = f"scheduler_{receiver}_{task['id']}"
|
||||
|
||||
# Build a natural language query for the Agent to execute the skill
|
||||
# Format: "Use skill-name to do something with params"
|
||||
param_str = ", ".join([f"{k}={v}" for k, v in skill_params.items()])
|
||||
query = f"Use {skill_name} skill"
|
||||
if param_str:
|
||||
query += f" with {param_str}"
|
||||
|
||||
# Create context for Agent
|
||||
context = Context(ContextType.TEXT, query)
|
||||
context["receiver"] = receiver
|
||||
context["isgroup"] = is_group
|
||||
context["session_id"] = scheduler_session_id
|
||||
|
||||
# Channel-specific setup
|
||||
if channel_type == "web":
|
||||
import uuid
|
||||
request_id = f"scheduler_{task['id']}_{uuid.uuid4().hex[:8]}"
|
||||
context["request_id"] = request_id
|
||||
elif channel_type == "feishu":
|
||||
context["receive_id_type"] = "chat_id" if is_group else "open_id"
|
||||
context["msg"] = None
|
||||
elif channel_type == "wecom_bot":
|
||||
context["msg"] = None
|
||||
|
||||
# Use Agent to execute the skill
|
||||
try:
|
||||
# Don't clear history - scheduler tasks use isolated session_id so they won't pollute user conversations
|
||||
reply = agent_bridge.agent_reply(query, context=context, on_event=None, clear_history=False)
|
||||
|
||||
if reply and reply.content:
|
||||
content = reply.content
|
||||
|
||||
# Add prefix if specified
|
||||
if result_prefix:
|
||||
content = f"{result_prefix}\n\n{content}"
|
||||
|
||||
logger.info(f"[Scheduler] Task {task['id']} executed: skill result sent to {receiver}")
|
||||
else:
|
||||
logger.error(f"[Scheduler] Task {task['id']}: No result from skill execution")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Scheduler] Failed to execute skill via Agent: {e}")
|
||||
import traceback
|
||||
logger.error(f"[Scheduler] Traceback: {traceback.format_exc()}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Scheduler] Error in _execute_skill_call: {e}")
|
||||
import traceback
|
||||
logger.error(f"[Scheduler] Traceback: {traceback.format_exc()}")
|
||||
|
||||
|
||||
def attach_scheduler_to_tool(tool, context: Context = None):
|
||||
"""
|
||||
Attach scheduler components to a SchedulerTool instance
|
||||
|
||||
Args:
|
||||
tool: SchedulerTool instance
|
||||
context: Current context (optional)
|
||||
"""
|
||||
if _task_store:
|
||||
tool.task_store = _task_store
|
||||
|
||||
if context:
|
||||
tool.current_context = context
|
||||
|
||||
channel_type = context.get("channel_type") or conf().get("channel_type", "unknown")
|
||||
if not tool.config:
|
||||
tool.config = {}
|
||||
tool.config["channel_type"] = channel_type
|
||||
213
agent/tools/scheduler/scheduler_service.py
Normal file
213
agent/tools/scheduler/scheduler_service.py
Normal file
@@ -0,0 +1,213 @@
|
||||
"""
|
||||
Background scheduler service for executing scheduled tasks
|
||||
"""
|
||||
|
||||
import time
|
||||
import threading
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Callable, Optional
|
||||
from croniter import croniter
|
||||
from common.log import logger
|
||||
|
||||
|
||||
class SchedulerService:
|
||||
"""
|
||||
Background service that executes scheduled tasks
|
||||
"""
|
||||
|
||||
def __init__(self, task_store, execute_callback: Callable):
|
||||
"""
|
||||
Initialize scheduler service
|
||||
|
||||
Args:
|
||||
task_store: TaskStore instance
|
||||
execute_callback: Function to call when executing a task
|
||||
"""
|
||||
self.task_store = task_store
|
||||
self.execute_callback = execute_callback
|
||||
self.running = False
|
||||
self.thread = None
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def start(self):
|
||||
"""Start the scheduler service"""
|
||||
with self._lock:
|
||||
if self.running:
|
||||
logger.warning("[Scheduler] Service already running")
|
||||
return
|
||||
|
||||
self.running = True
|
||||
self.thread = threading.Thread(target=self._run_loop, daemon=True)
|
||||
self.thread.start()
|
||||
logger.debug("[Scheduler] Service started")
|
||||
|
||||
def stop(self):
|
||||
"""Stop the scheduler service"""
|
||||
with self._lock:
|
||||
if not self.running:
|
||||
return
|
||||
|
||||
self.running = False
|
||||
if self.thread:
|
||||
self.thread.join(timeout=5)
|
||||
logger.info("[Scheduler] Service stopped")
|
||||
|
||||
def _run_loop(self):
|
||||
"""Main scheduler loop"""
|
||||
logger.debug("[Scheduler] Scheduler loop started")
|
||||
|
||||
while self.running:
|
||||
try:
|
||||
self._check_and_execute_tasks()
|
||||
except Exception as e:
|
||||
logger.error(f"[Scheduler] Error in scheduler loop: {e}")
|
||||
|
||||
time.sleep(30)
|
||||
|
||||
def _check_and_execute_tasks(self):
|
||||
"""Check for due tasks and execute them"""
|
||||
now = datetime.now()
|
||||
tasks = self.task_store.list_tasks(enabled_only=True)
|
||||
|
||||
for task in tasks:
|
||||
try:
|
||||
# Check if task is due
|
||||
if self._is_task_due(task, now):
|
||||
logger.info(f"[Scheduler] Executing task: {task['id']} - {task['name']}")
|
||||
self._execute_task(task)
|
||||
|
||||
# Update next run time
|
||||
next_run = self._calculate_next_run(task, now)
|
||||
if next_run:
|
||||
self.task_store.update_task(task['id'], {
|
||||
"next_run_at": next_run.isoformat(),
|
||||
"last_run_at": now.isoformat()
|
||||
})
|
||||
else:
|
||||
# One-time task completed, remove it
|
||||
self.task_store.delete_task(task['id'])
|
||||
logger.info(f"[Scheduler] One-time task completed and removed: {task['id']}")
|
||||
except Exception as e:
|
||||
logger.error(f"[Scheduler] Error processing task {task.get('id')}: {e}")
|
||||
|
||||
def _is_task_due(self, task: dict, now: datetime) -> bool:
|
||||
"""
|
||||
Check if a task is due to run
|
||||
|
||||
Args:
|
||||
task: Task dictionary
|
||||
now: Current datetime
|
||||
|
||||
Returns:
|
||||
True if task should run now
|
||||
"""
|
||||
next_run_str = task.get("next_run_at")
|
||||
if not next_run_str:
|
||||
# Calculate initial next_run_at
|
||||
next_run = self._calculate_next_run(task, now)
|
||||
if next_run:
|
||||
self.task_store.update_task(task['id'], {
|
||||
"next_run_at": next_run.isoformat()
|
||||
})
|
||||
return False
|
||||
return False
|
||||
|
||||
try:
|
||||
next_run = datetime.fromisoformat(next_run_str)
|
||||
|
||||
# Check if task is overdue (e.g., service restart)
|
||||
if next_run < now:
|
||||
time_diff = (now - next_run).total_seconds()
|
||||
|
||||
# If overdue by more than 5 minutes, skip this run and schedule next
|
||||
if time_diff > 300: # 5 minutes
|
||||
logger.warning(f"[Scheduler] Task {task['id']} is overdue by {int(time_diff)}s, skipping and scheduling next run")
|
||||
|
||||
# For one-time tasks, remove them directly
|
||||
schedule = task.get("schedule", {})
|
||||
if schedule.get("type") == "once":
|
||||
self.task_store.delete_task(task['id'])
|
||||
logger.info(f"[Scheduler] One-time task {task['id']} expired, removed")
|
||||
return False
|
||||
|
||||
# For recurring tasks, calculate next run from now
|
||||
next_next_run = self._calculate_next_run(task, now)
|
||||
if next_next_run:
|
||||
self.task_store.update_task(task['id'], {
|
||||
"next_run_at": next_next_run.isoformat()
|
||||
})
|
||||
logger.info(f"[Scheduler] Rescheduled task {task['id']} to {next_next_run}")
|
||||
return False
|
||||
|
||||
return now >= next_run
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def _calculate_next_run(self, task: dict, from_time: datetime) -> Optional[datetime]:
|
||||
"""
|
||||
Calculate next run time for a task
|
||||
|
||||
Args:
|
||||
task: Task dictionary
|
||||
from_time: Calculate from this time
|
||||
|
||||
Returns:
|
||||
Next run datetime or None for one-time tasks
|
||||
"""
|
||||
schedule = task.get("schedule", {})
|
||||
schedule_type = schedule.get("type")
|
||||
|
||||
if schedule_type == "cron":
|
||||
# Cron expression
|
||||
expression = schedule.get("expression")
|
||||
if not expression:
|
||||
return None
|
||||
|
||||
try:
|
||||
cron = croniter(expression, from_time)
|
||||
return cron.get_next(datetime)
|
||||
except Exception as e:
|
||||
logger.error(f"[Scheduler] Invalid cron expression '{expression}': {e}")
|
||||
return None
|
||||
|
||||
elif schedule_type == "interval":
|
||||
# Interval in seconds
|
||||
seconds = schedule.get("seconds", 0)
|
||||
if seconds <= 0:
|
||||
return None
|
||||
return from_time + timedelta(seconds=seconds)
|
||||
|
||||
elif schedule_type == "once":
|
||||
# One-time task at specific time
|
||||
run_at_str = schedule.get("run_at")
|
||||
if not run_at_str:
|
||||
return None
|
||||
|
||||
try:
|
||||
run_at = datetime.fromisoformat(run_at_str)
|
||||
# Only return if in the future
|
||||
if run_at > from_time:
|
||||
return run_at
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
|
||||
return None
|
||||
|
||||
def _execute_task(self, task: dict):
|
||||
"""
|
||||
Execute a task
|
||||
|
||||
Args:
|
||||
task: Task dictionary
|
||||
"""
|
||||
try:
|
||||
# Call the execute callback
|
||||
self.execute_callback(task)
|
||||
except Exception as e:
|
||||
logger.error(f"[Scheduler] Error executing task {task['id']}: {e}")
|
||||
# Update task with error
|
||||
self.task_store.update_task(task['id'], {
|
||||
"last_error": str(e),
|
||||
"last_error_at": datetime.now().isoformat()
|
||||
})
|
||||
443
agent/tools/scheduler/scheduler_tool.py
Normal file
443
agent/tools/scheduler/scheduler_tool.py
Normal file
@@ -0,0 +1,443 @@
|
||||
"""
|
||||
Scheduler tool for creating and managing scheduled tasks
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, Optional
|
||||
from croniter import croniter
|
||||
|
||||
from agent.tools.base_tool import BaseTool, ToolResult
|
||||
from bridge.context import Context, ContextType
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from common.log import logger
|
||||
|
||||
|
||||
class SchedulerTool(BaseTool):
|
||||
"""
|
||||
Tool for managing scheduled tasks (reminders, notifications, etc.)
|
||||
"""
|
||||
|
||||
name: str = "scheduler"
|
||||
description: str = (
|
||||
"创建、查询和管理定时任务(提醒、周期性任务等)。\n\n"
|
||||
"⚠️ 重要:仅当需要「定时/提醒/每天/每周/X分钟后/X点」等延迟或周期执行时才使用此工具。"
|
||||
"使用方法:\n"
|
||||
"- 创建:action='create', name='任务名', message/ai_task='内容', schedule_type='once/interval/cron', schedule_value='...'\n"
|
||||
"- 查询:action='list' / action='get', task_id='任务ID'\n"
|
||||
"- 管理:action='delete/enable/disable', task_id='任务ID'\n\n"
|
||||
"调度类型:\n"
|
||||
"- once: 一次性任务,支持相对时间(+5s,+10m,+1h,+1d)或ISO时间\n"
|
||||
"- interval: 固定间隔(秒),如3600表示每小时\n"
|
||||
"- cron: cron表达式,如'0 8 * * *'表示每天8点\n\n"
|
||||
"注意:'X秒后'用once+相对时间,'每X秒'用interval"
|
||||
)
|
||||
params: dict = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"action": {
|
||||
"type": "string",
|
||||
"enum": ["create", "list", "get", "delete", "enable", "disable"],
|
||||
"description": "操作类型: create(创建), list(列表), get(查询), delete(删除), enable(启用), disable(禁用)"
|
||||
},
|
||||
"task_id": {
|
||||
"type": "string",
|
||||
"description": "任务ID (用于 get/delete/enable/disable 操作)"
|
||||
},
|
||||
"name": {
|
||||
"type": "string",
|
||||
"description": "任务名称 (用于 create 操作)"
|
||||
},
|
||||
"message": {
|
||||
"type": "string",
|
||||
"description": "固定消息内容 (与ai_task二选一)"
|
||||
},
|
||||
"ai_task": {
|
||||
"type": "string",
|
||||
"description": "AI任务描述 (与message二选一),用于定时让AI执行的任务"
|
||||
},
|
||||
"schedule_type": {
|
||||
"type": "string",
|
||||
"enum": ["cron", "interval", "once"],
|
||||
"description": "调度类型 (用于 create 操作): cron(cron表达式), interval(固定间隔秒数), once(一次性)"
|
||||
},
|
||||
"schedule_value": {
|
||||
"type": "string",
|
||||
"description": "调度值: cron表达式/间隔秒数/时间(+5s,+10m,+1h或ISO格式)"
|
||||
}
|
||||
},
|
||||
"required": ["action"]
|
||||
}
|
||||
|
||||
def __init__(self, config: dict = None):
|
||||
super().__init__()
|
||||
self.config = config or {}
|
||||
|
||||
# Will be set by agent bridge
|
||||
self.task_store = None
|
||||
self.current_context = None
|
||||
|
||||
def execute(self, params: dict) -> ToolResult:
|
||||
"""
|
||||
Execute scheduler operations
|
||||
|
||||
Args:
|
||||
params: Dictionary containing:
|
||||
- action: Operation type (create/list/get/delete/enable/disable)
|
||||
- Other parameters depending on action
|
||||
|
||||
Returns:
|
||||
ToolResult object
|
||||
"""
|
||||
# Extract parameters
|
||||
action = params.get("action")
|
||||
kwargs = params
|
||||
|
||||
if not self.task_store:
|
||||
return ToolResult.fail("错误: 定时任务系统未初始化")
|
||||
|
||||
try:
|
||||
if action == "create":
|
||||
result = self._create_task(**kwargs)
|
||||
return ToolResult.success(result)
|
||||
elif action == "list":
|
||||
result = self._list_tasks(**kwargs)
|
||||
return ToolResult.success(result)
|
||||
elif action == "get":
|
||||
result = self._get_task(**kwargs)
|
||||
return ToolResult.success(result)
|
||||
elif action == "delete":
|
||||
result = self._delete_task(**kwargs)
|
||||
return ToolResult.success(result)
|
||||
elif action == "enable":
|
||||
result = self._enable_task(**kwargs)
|
||||
return ToolResult.success(result)
|
||||
elif action == "disable":
|
||||
result = self._disable_task(**kwargs)
|
||||
return ToolResult.success(result)
|
||||
else:
|
||||
return ToolResult.fail(f"未知操作: {action}")
|
||||
except Exception as e:
|
||||
logger.error(f"[SchedulerTool] Error: {e}")
|
||||
return ToolResult.fail(f"操作失败: {str(e)}")
|
||||
|
||||
def _create_task(self, **kwargs) -> str:
|
||||
"""Create a new scheduled task"""
|
||||
name = kwargs.get("name")
|
||||
message = kwargs.get("message")
|
||||
ai_task = kwargs.get("ai_task")
|
||||
schedule_type = kwargs.get("schedule_type")
|
||||
schedule_value = kwargs.get("schedule_value")
|
||||
|
||||
# Validate required fields
|
||||
if not name:
|
||||
return "错误: 缺少任务名称 (name)"
|
||||
|
||||
# Check that exactly one of message/ai_task is provided
|
||||
if not message and not ai_task:
|
||||
return "错误: 必须提供 message(固定消息)或 ai_task(AI任务)之一"
|
||||
if message and ai_task:
|
||||
return "错误: message 和 ai_task 只能提供其中一个"
|
||||
|
||||
if not schedule_type:
|
||||
return "错误: 缺少调度类型 (schedule_type)"
|
||||
if not schedule_value:
|
||||
return "错误: 缺少调度值 (schedule_value)"
|
||||
|
||||
# Validate schedule
|
||||
schedule = self._parse_schedule(schedule_type, schedule_value)
|
||||
if not schedule:
|
||||
return f"错误: 无效的调度配置 - type: {schedule_type}, value: {schedule_value}"
|
||||
|
||||
# Get context info for receiver
|
||||
if not self.current_context:
|
||||
return "错误: 无法获取当前对话上下文"
|
||||
|
||||
context = self.current_context
|
||||
|
||||
# Create task
|
||||
task_id = str(uuid.uuid4())[:8]
|
||||
|
||||
# Build action based on message or ai_task
|
||||
if message:
|
||||
action = {
|
||||
"type": "send_message",
|
||||
"content": message,
|
||||
"receiver": context.get("receiver"),
|
||||
"receiver_name": self._get_receiver_name(context),
|
||||
"is_group": context.get("isgroup", False),
|
||||
"channel_type": self.config.get("channel_type", "unknown")
|
||||
}
|
||||
else: # ai_task
|
||||
action = {
|
||||
"type": "agent_task",
|
||||
"task_description": ai_task,
|
||||
"receiver": context.get("receiver"),
|
||||
"receiver_name": self._get_receiver_name(context),
|
||||
"is_group": context.get("isgroup", False),
|
||||
"channel_type": self.config.get("channel_type", "unknown")
|
||||
}
|
||||
|
||||
# 针对钉钉单聊,额外存储 sender_staff_id
|
||||
msg = context.kwargs.get("msg")
|
||||
if msg and hasattr(msg, 'sender_staff_id') and not context.get("isgroup", False):
|
||||
action["dingtalk_sender_staff_id"] = msg.sender_staff_id
|
||||
|
||||
task_data = {
|
||||
"id": task_id,
|
||||
"name": name,
|
||||
"enabled": True,
|
||||
"created_at": datetime.now().isoformat(),
|
||||
"updated_at": datetime.now().isoformat(),
|
||||
"schedule": schedule,
|
||||
"action": action
|
||||
}
|
||||
|
||||
# Calculate initial next_run_at
|
||||
next_run = self._calculate_next_run(task_data)
|
||||
if next_run:
|
||||
task_data["next_run_at"] = next_run.isoformat()
|
||||
|
||||
# Save task
|
||||
self.task_store.add_task(task_data)
|
||||
|
||||
# Format response
|
||||
schedule_desc = self._format_schedule_description(schedule)
|
||||
receiver_desc = task_data["action"]["receiver_name"] or task_data["action"]["receiver"]
|
||||
|
||||
if message:
|
||||
content_desc = f"💬 固定消息: {message}"
|
||||
else:
|
||||
content_desc = f"🤖 AI任务: {ai_task}"
|
||||
|
||||
return (
|
||||
f"✅ 定时任务创建成功\n\n"
|
||||
f"📋 任务ID: {task_id}\n"
|
||||
f"📝 名称: {name}\n"
|
||||
f"⏰ 调度: {schedule_desc}\n"
|
||||
f"👤 接收者: {receiver_desc}\n"
|
||||
f"{content_desc}\n"
|
||||
f"🕐 下次执行: {next_run.strftime('%Y-%m-%d %H:%M:%S') if next_run else '未知'}"
|
||||
)
|
||||
|
||||
def _list_tasks(self, **kwargs) -> str:
|
||||
"""List all tasks"""
|
||||
tasks = self.task_store.list_tasks()
|
||||
|
||||
if not tasks:
|
||||
return "📋 暂无定时任务"
|
||||
|
||||
lines = [f"📋 定时任务列表 (共 {len(tasks)} 个)\n"]
|
||||
|
||||
for task in tasks:
|
||||
status = "✅" if task.get("enabled", True) else "❌"
|
||||
schedule_desc = self._format_schedule_description(task.get("schedule", {}))
|
||||
next_run = task.get("next_run_at")
|
||||
next_run_str = datetime.fromisoformat(next_run).strftime('%m-%d %H:%M') if next_run else "未知"
|
||||
|
||||
lines.append(
|
||||
f"{status} [{task['id']}] {task['name']}\n"
|
||||
f" ⏰ {schedule_desc} | 下次: {next_run_str}"
|
||||
)
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
def _get_task(self, **kwargs) -> str:
|
||||
"""Get task details"""
|
||||
task_id = kwargs.get("task_id")
|
||||
if not task_id:
|
||||
return "错误: 缺少任务ID (task_id)"
|
||||
|
||||
task = self.task_store.get_task(task_id)
|
||||
if not task:
|
||||
return f"错误: 任务 '{task_id}' 不存在"
|
||||
|
||||
status = "启用" if task.get("enabled", True) else "禁用"
|
||||
schedule_desc = self._format_schedule_description(task.get("schedule", {}))
|
||||
action = task.get("action", {})
|
||||
next_run = task.get("next_run_at")
|
||||
next_run_str = datetime.fromisoformat(next_run).strftime('%Y-%m-%d %H:%M:%S') if next_run else "未知"
|
||||
last_run = task.get("last_run_at")
|
||||
last_run_str = datetime.fromisoformat(last_run).strftime('%Y-%m-%d %H:%M:%S') if last_run else "从未执行"
|
||||
|
||||
return (
|
||||
f"📋 任务详情\n\n"
|
||||
f"ID: {task['id']}\n"
|
||||
f"名称: {task['name']}\n"
|
||||
f"状态: {status}\n"
|
||||
f"调度: {schedule_desc}\n"
|
||||
f"接收者: {action.get('receiver_name', action.get('receiver'))}\n"
|
||||
f"消息: {action.get('content')}\n"
|
||||
f"下次执行: {next_run_str}\n"
|
||||
f"上次执行: {last_run_str}\n"
|
||||
f"创建时间: {datetime.fromisoformat(task['created_at']).strftime('%Y-%m-%d %H:%M:%S')}"
|
||||
)
|
||||
|
||||
def _delete_task(self, **kwargs) -> str:
|
||||
"""Delete a task"""
|
||||
task_id = kwargs.get("task_id")
|
||||
if not task_id:
|
||||
return "错误: 缺少任务ID (task_id)"
|
||||
|
||||
task = self.task_store.get_task(task_id)
|
||||
if not task:
|
||||
return f"错误: 任务 '{task_id}' 不存在"
|
||||
|
||||
self.task_store.delete_task(task_id)
|
||||
return f"✅ 任务 '{task['name']}' ({task_id}) 已删除"
|
||||
|
||||
def _enable_task(self, **kwargs) -> str:
|
||||
"""Enable a task"""
|
||||
task_id = kwargs.get("task_id")
|
||||
if not task_id:
|
||||
return "错误: 缺少任务ID (task_id)"
|
||||
|
||||
task = self.task_store.get_task(task_id)
|
||||
if not task:
|
||||
return f"错误: 任务 '{task_id}' 不存在"
|
||||
|
||||
self.task_store.enable_task(task_id, True)
|
||||
return f"✅ 任务 '{task['name']}' ({task_id}) 已启用"
|
||||
|
||||
def _disable_task(self, **kwargs) -> str:
|
||||
"""Disable a task"""
|
||||
task_id = kwargs.get("task_id")
|
||||
if not task_id:
|
||||
return "错误: 缺少任务ID (task_id)"
|
||||
|
||||
task = self.task_store.get_task(task_id)
|
||||
if not task:
|
||||
return f"错误: 任务 '{task_id}' 不存在"
|
||||
|
||||
self.task_store.enable_task(task_id, False)
|
||||
return f"✅ 任务 '{task['name']}' ({task_id}) 已禁用"
|
||||
|
||||
def _parse_schedule(self, schedule_type: str, schedule_value: str) -> Optional[dict]:
|
||||
"""Parse and validate schedule configuration"""
|
||||
try:
|
||||
if schedule_type == "cron":
|
||||
# Validate cron expression
|
||||
croniter(schedule_value)
|
||||
return {"type": "cron", "expression": schedule_value}
|
||||
|
||||
elif schedule_type == "interval":
|
||||
# Parse interval in seconds
|
||||
seconds = int(schedule_value)
|
||||
if seconds <= 0:
|
||||
return None
|
||||
return {"type": "interval", "seconds": seconds}
|
||||
|
||||
elif schedule_type == "once":
|
||||
# Parse datetime - support both relative and absolute time
|
||||
|
||||
# Check if it's relative time (e.g., "+5s", "+10m", "+1h", "+1d")
|
||||
if schedule_value.startswith("+"):
|
||||
import re
|
||||
match = re.match(r'\+(\d+)([smhd])', schedule_value)
|
||||
if match:
|
||||
amount = int(match.group(1))
|
||||
unit = match.group(2)
|
||||
|
||||
from datetime import timedelta
|
||||
now = datetime.now()
|
||||
|
||||
if unit == 's': # seconds
|
||||
target_time = now + timedelta(seconds=amount)
|
||||
elif unit == 'm': # minutes
|
||||
target_time = now + timedelta(minutes=amount)
|
||||
elif unit == 'h': # hours
|
||||
target_time = now + timedelta(hours=amount)
|
||||
elif unit == 'd': # days
|
||||
target_time = now + timedelta(days=amount)
|
||||
else:
|
||||
return None
|
||||
|
||||
return {"type": "once", "run_at": target_time.isoformat()}
|
||||
else:
|
||||
logger.error(f"[SchedulerTool] Invalid relative time format: {schedule_value}")
|
||||
return None
|
||||
else:
|
||||
# Absolute time in ISO format
|
||||
datetime.fromisoformat(schedule_value)
|
||||
return {"type": "once", "run_at": schedule_value}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[SchedulerTool] Invalid schedule: {e}")
|
||||
return None
|
||||
|
||||
return None
|
||||
|
||||
def _calculate_next_run(self, task: dict) -> Optional[datetime]:
|
||||
"""Calculate next run time for a task"""
|
||||
schedule = task.get("schedule", {})
|
||||
schedule_type = schedule.get("type")
|
||||
now = datetime.now()
|
||||
|
||||
if schedule_type == "cron":
|
||||
expression = schedule.get("expression")
|
||||
cron = croniter(expression, now)
|
||||
return cron.get_next(datetime)
|
||||
|
||||
elif schedule_type == "interval":
|
||||
seconds = schedule.get("seconds", 0)
|
||||
from datetime import timedelta
|
||||
return now + timedelta(seconds=seconds)
|
||||
|
||||
elif schedule_type == "once":
|
||||
run_at_str = schedule.get("run_at")
|
||||
return datetime.fromisoformat(run_at_str)
|
||||
|
||||
return None
|
||||
|
||||
def _format_schedule_description(self, schedule: dict) -> str:
|
||||
"""Format schedule as human-readable description"""
|
||||
schedule_type = schedule.get("type")
|
||||
|
||||
if schedule_type == "cron":
|
||||
expr = schedule.get("expression", "")
|
||||
# Try to provide friendly description
|
||||
if expr == "0 9 * * *":
|
||||
return "每天 9:00"
|
||||
elif expr == "0 */1 * * *":
|
||||
return "每小时"
|
||||
elif expr == "*/30 * * * *":
|
||||
return "每30分钟"
|
||||
else:
|
||||
return f"Cron: {expr}"
|
||||
|
||||
elif schedule_type == "interval":
|
||||
seconds = schedule.get("seconds", 0)
|
||||
if seconds >= 86400:
|
||||
days = seconds // 86400
|
||||
return f"每 {days} 天"
|
||||
elif seconds >= 3600:
|
||||
hours = seconds // 3600
|
||||
return f"每 {hours} 小时"
|
||||
elif seconds >= 60:
|
||||
minutes = seconds // 60
|
||||
return f"每 {minutes} 分钟"
|
||||
else:
|
||||
return f"每 {seconds} 秒"
|
||||
|
||||
elif schedule_type == "once":
|
||||
run_at = schedule.get("run_at", "")
|
||||
try:
|
||||
dt = datetime.fromisoformat(run_at)
|
||||
return f"一次性 ({dt.strftime('%Y-%m-%d %H:%M')})"
|
||||
except Exception:
|
||||
return "一次性"
|
||||
|
||||
return "未知"
|
||||
|
||||
def _get_receiver_name(self, context: Context) -> str:
|
||||
"""Get receiver name from context"""
|
||||
try:
|
||||
msg = context.get("msg")
|
||||
if msg:
|
||||
if context.get("isgroup"):
|
||||
return msg.other_user_nickname or "群聊"
|
||||
else:
|
||||
return msg.from_user_nickname or "用户"
|
||||
except Exception:
|
||||
pass
|
||||
return "未知"
|
||||
201
agent/tools/scheduler/task_store.py
Normal file
201
agent/tools/scheduler/task_store.py
Normal file
@@ -0,0 +1,201 @@
|
||||
"""
|
||||
Task storage management for scheduler
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import threading
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional
|
||||
from pathlib import Path
|
||||
from common.utils import expand_path
|
||||
|
||||
|
||||
class TaskStore:
|
||||
"""
|
||||
Manages persistent storage of scheduled tasks
|
||||
"""
|
||||
|
||||
def __init__(self, store_path: str = None):
|
||||
"""
|
||||
Initialize task store
|
||||
|
||||
Args:
|
||||
store_path: Path to tasks.json file. Defaults to ~/cow/scheduler/tasks.json
|
||||
"""
|
||||
if store_path is None:
|
||||
# Default to ~/cow/scheduler/tasks.json
|
||||
home = expand_path("~")
|
||||
store_path = os.path.join(home, "cow", "scheduler", "tasks.json")
|
||||
|
||||
self.store_path = store_path
|
||||
self.lock = threading.Lock()
|
||||
self._ensure_store_dir()
|
||||
|
||||
def _ensure_store_dir(self):
|
||||
"""Ensure the storage directory exists"""
|
||||
store_dir = os.path.dirname(self.store_path)
|
||||
os.makedirs(store_dir, exist_ok=True)
|
||||
|
||||
def load_tasks(self) -> Dict[str, dict]:
|
||||
"""
|
||||
Load all tasks from storage
|
||||
|
||||
Returns:
|
||||
Dictionary of task_id -> task_data
|
||||
"""
|
||||
with self.lock:
|
||||
if not os.path.exists(self.store_path):
|
||||
return {}
|
||||
|
||||
try:
|
||||
with open(self.store_path, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
return data.get("tasks", {})
|
||||
except Exception as e:
|
||||
print(f"Error loading tasks: {e}")
|
||||
return {}
|
||||
|
||||
def save_tasks(self, tasks: Dict[str, dict]):
|
||||
"""
|
||||
Save all tasks to storage
|
||||
|
||||
Args:
|
||||
tasks: Dictionary of task_id -> task_data
|
||||
"""
|
||||
with self.lock:
|
||||
try:
|
||||
# Create backup
|
||||
if os.path.exists(self.store_path):
|
||||
backup_path = f"{self.store_path}.bak"
|
||||
try:
|
||||
with open(self.store_path, 'r') as src:
|
||||
with open(backup_path, 'w') as dst:
|
||||
dst.write(src.read())
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Save tasks
|
||||
data = {
|
||||
"version": 1,
|
||||
"updated_at": datetime.now().isoformat(),
|
||||
"tasks": tasks
|
||||
}
|
||||
|
||||
with open(self.store_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(data, f, ensure_ascii=False, indent=2)
|
||||
except Exception as e:
|
||||
print(f"Error saving tasks: {e}")
|
||||
raise
|
||||
|
||||
def add_task(self, task: dict) -> bool:
|
||||
"""
|
||||
Add a new task
|
||||
|
||||
Args:
|
||||
task: Task data dictionary
|
||||
|
||||
Returns:
|
||||
True if successful
|
||||
"""
|
||||
tasks = self.load_tasks()
|
||||
task_id = task.get("id")
|
||||
|
||||
if not task_id:
|
||||
raise ValueError("Task must have an 'id' field")
|
||||
|
||||
if task_id in tasks:
|
||||
raise ValueError(f"Task with id '{task_id}' already exists")
|
||||
|
||||
tasks[task_id] = task
|
||||
self.save_tasks(tasks)
|
||||
return True
|
||||
|
||||
def update_task(self, task_id: str, updates: dict) -> bool:
|
||||
"""
|
||||
Update an existing task
|
||||
|
||||
Args:
|
||||
task_id: Task ID
|
||||
updates: Dictionary of fields to update
|
||||
|
||||
Returns:
|
||||
True if successful
|
||||
"""
|
||||
tasks = self.load_tasks()
|
||||
|
||||
if task_id not in tasks:
|
||||
raise ValueError(f"Task '{task_id}' not found")
|
||||
|
||||
# Update fields
|
||||
tasks[task_id].update(updates)
|
||||
tasks[task_id]["updated_at"] = datetime.now().isoformat()
|
||||
|
||||
self.save_tasks(tasks)
|
||||
return True
|
||||
|
||||
def delete_task(self, task_id: str) -> bool:
|
||||
"""
|
||||
Delete a task
|
||||
|
||||
Args:
|
||||
task_id: Task ID
|
||||
|
||||
Returns:
|
||||
True if successful
|
||||
"""
|
||||
tasks = self.load_tasks()
|
||||
|
||||
if task_id not in tasks:
|
||||
raise ValueError(f"Task '{task_id}' not found")
|
||||
|
||||
del tasks[task_id]
|
||||
self.save_tasks(tasks)
|
||||
return True
|
||||
|
||||
def get_task(self, task_id: str) -> Optional[dict]:
|
||||
"""
|
||||
Get a specific task
|
||||
|
||||
Args:
|
||||
task_id: Task ID
|
||||
|
||||
Returns:
|
||||
Task data or None if not found
|
||||
"""
|
||||
tasks = self.load_tasks()
|
||||
return tasks.get(task_id)
|
||||
|
||||
def list_tasks(self, enabled_only: bool = False) -> List[dict]:
|
||||
"""
|
||||
List all tasks
|
||||
|
||||
Args:
|
||||
enabled_only: If True, only return enabled tasks
|
||||
|
||||
Returns:
|
||||
List of task dictionaries
|
||||
"""
|
||||
tasks = self.load_tasks()
|
||||
task_list = list(tasks.values())
|
||||
|
||||
if enabled_only:
|
||||
task_list = [t for t in task_list if t.get("enabled", True)]
|
||||
|
||||
# Sort by next_run_at
|
||||
task_list.sort(key=lambda t: t.get("next_run_at", float('inf')))
|
||||
|
||||
return task_list
|
||||
|
||||
def enable_task(self, task_id: str, enabled: bool = True) -> bool:
|
||||
"""
|
||||
Enable or disable a task
|
||||
|
||||
Args:
|
||||
task_id: Task ID
|
||||
enabled: True to enable, False to disable
|
||||
|
||||
Returns:
|
||||
True if successful
|
||||
"""
|
||||
return self.update_task(task_id, {"enabled": enabled})
|
||||
3
agent/tools/send/__init__.py
Normal file
3
agent/tools/send/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .send import Send
|
||||
|
||||
__all__ = ['Send']
|
||||
160
agent/tools/send/send.py
Normal file
160
agent/tools/send/send.py
Normal file
@@ -0,0 +1,160 @@
|
||||
"""
|
||||
Send tool - Send files to the user
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Dict, Any
|
||||
from pathlib import Path
|
||||
|
||||
from agent.tools.base_tool import BaseTool, ToolResult
|
||||
from common.utils import expand_path
|
||||
|
||||
|
||||
class Send(BaseTool):
|
||||
"""Tool for sending files to the user"""
|
||||
|
||||
name: str = "send"
|
||||
description: str = "Send a LOCAL file (image, video, audio, document) to the user. Only for local file paths. Do NOT use this for URLs — URLs should be included directly in your text reply, the system will handle them automatically."
|
||||
|
||||
params: dict = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "Local file path to send. Must be an absolute path or relative to workspace. Do NOT pass URLs here."
|
||||
},
|
||||
"message": {
|
||||
"type": "string",
|
||||
"description": "Optional message to accompany the file"
|
||||
}
|
||||
},
|
||||
"required": ["path"]
|
||||
}
|
||||
|
||||
def __init__(self, config: dict = None):
|
||||
self.config = config or {}
|
||||
self.cwd = self.config.get("cwd", os.getcwd())
|
||||
|
||||
# Supported file types
|
||||
self.image_extensions = {'.jpg', '.jpeg', '.png', '.gif', '.webp', '.bmp', '.svg', '.ico'}
|
||||
self.video_extensions = {'.mp4', '.avi', '.mov', '.mkv', '.flv', '.wmv', '.webm', '.m4v'}
|
||||
self.audio_extensions = {'.mp3', '.wav', '.ogg', '.m4a', '.flac', '.aac', '.wma'}
|
||||
self.document_extensions = {'.pdf', '.doc', '.docx', '.xls', '.xlsx', '.ppt', '.pptx', '.txt', '.md'}
|
||||
|
||||
def execute(self, args: Dict[str, Any]) -> ToolResult:
|
||||
"""
|
||||
Execute file send operation
|
||||
|
||||
:param args: Contains file path and optional message
|
||||
:return: File metadata for channel to send
|
||||
"""
|
||||
path = args.get("path", "").strip()
|
||||
message = args.get("message", "")
|
||||
|
||||
if not path:
|
||||
return ToolResult.fail("Error: path parameter is required")
|
||||
|
||||
# Resolve path
|
||||
absolute_path = self._resolve_path(path)
|
||||
|
||||
# Check if file exists
|
||||
if not os.path.exists(absolute_path):
|
||||
return ToolResult.fail(f"Error: File not found: {path}")
|
||||
|
||||
# Check if readable
|
||||
if not os.access(absolute_path, os.R_OK):
|
||||
return ToolResult.fail(f"Error: File is not readable: {path}")
|
||||
|
||||
# Get file info
|
||||
file_ext = Path(absolute_path).suffix.lower()
|
||||
file_size = os.path.getsize(absolute_path)
|
||||
file_name = Path(absolute_path).name
|
||||
|
||||
# Determine file type
|
||||
if file_ext in self.image_extensions:
|
||||
file_type = "image"
|
||||
mime_type = self._get_image_mime_type(file_ext)
|
||||
elif file_ext in self.video_extensions:
|
||||
file_type = "video"
|
||||
mime_type = self._get_video_mime_type(file_ext)
|
||||
elif file_ext in self.audio_extensions:
|
||||
file_type = "audio"
|
||||
mime_type = self._get_audio_mime_type(file_ext)
|
||||
elif file_ext in self.document_extensions:
|
||||
file_type = "document"
|
||||
mime_type = self._get_document_mime_type(file_ext)
|
||||
else:
|
||||
file_type = "file"
|
||||
mime_type = "application/octet-stream"
|
||||
|
||||
# Return file_to_send metadata
|
||||
result = {
|
||||
"type": "file_to_send",
|
||||
"file_type": file_type,
|
||||
"path": absolute_path,
|
||||
"file_name": file_name,
|
||||
"mime_type": mime_type,
|
||||
"size": file_size,
|
||||
"size_formatted": self._format_size(file_size),
|
||||
"message": message or f"正在发送 {file_name}"
|
||||
}
|
||||
|
||||
return ToolResult.success(result)
|
||||
|
||||
def _resolve_path(self, path: str) -> str:
|
||||
"""Resolve path to absolute path"""
|
||||
path = expand_path(path)
|
||||
if os.path.isabs(path):
|
||||
return path
|
||||
return os.path.abspath(os.path.join(self.cwd, path))
|
||||
|
||||
def _get_image_mime_type(self, ext: str) -> str:
|
||||
"""Get MIME type for image"""
|
||||
mime_map = {
|
||||
'.jpg': 'image/jpeg', '.jpeg': 'image/jpeg',
|
||||
'.png': 'image/png', '.gif': 'image/gif',
|
||||
'.webp': 'image/webp', '.bmp': 'image/bmp',
|
||||
'.svg': 'image/svg+xml', '.ico': 'image/x-icon'
|
||||
}
|
||||
return mime_map.get(ext, 'image/jpeg')
|
||||
|
||||
def _get_video_mime_type(self, ext: str) -> str:
|
||||
"""Get MIME type for video"""
|
||||
mime_map = {
|
||||
'.mp4': 'video/mp4', '.avi': 'video/x-msvideo',
|
||||
'.mov': 'video/quicktime', '.mkv': 'video/x-matroska',
|
||||
'.webm': 'video/webm', '.flv': 'video/x-flv'
|
||||
}
|
||||
return mime_map.get(ext, 'video/mp4')
|
||||
|
||||
def _get_audio_mime_type(self, ext: str) -> str:
|
||||
"""Get MIME type for audio"""
|
||||
mime_map = {
|
||||
'.mp3': 'audio/mpeg', '.wav': 'audio/wav',
|
||||
'.ogg': 'audio/ogg', '.m4a': 'audio/mp4',
|
||||
'.flac': 'audio/flac', '.aac': 'audio/aac'
|
||||
}
|
||||
return mime_map.get(ext, 'audio/mpeg')
|
||||
|
||||
def _get_document_mime_type(self, ext: str) -> str:
|
||||
"""Get MIME type for document"""
|
||||
mime_map = {
|
||||
'.pdf': 'application/pdf',
|
||||
'.doc': 'application/msword',
|
||||
'.docx': 'application/vnd.openxmlformats-officedocument.wordprocessingml.document',
|
||||
'.xls': 'application/vnd.ms-excel',
|
||||
'.xlsx': 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet',
|
||||
'.ppt': 'application/vnd.ms-powerpoint',
|
||||
'.pptx': 'application/vnd.openxmlformats-officedocument.presentationml.presentation',
|
||||
'.txt': 'text/plain',
|
||||
'.md': 'text/markdown'
|
||||
}
|
||||
return mime_map.get(ext, 'application/octet-stream')
|
||||
|
||||
def _format_size(self, size_bytes: int) -> str:
|
||||
"""Format file size in human-readable format"""
|
||||
for unit in ['B', 'KB', 'MB', 'GB']:
|
||||
if size_bytes < 1024.0:
|
||||
return f"{size_bytes:.1f}{unit}"
|
||||
size_bytes /= 1024.0
|
||||
return f"{size_bytes:.1f}TB"
|
||||
248
agent/tools/tool_manager.py
Normal file
248
agent/tools/tool_manager.py
Normal file
@@ -0,0 +1,248 @@
|
||||
import importlib
|
||||
import importlib.util
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, Type
|
||||
from agent.tools.base_tool import BaseTool
|
||||
from common.log import logger
|
||||
from config import conf
|
||||
|
||||
|
||||
class ToolManager:
|
||||
"""
|
||||
Tool manager for managing tools.
|
||||
"""
|
||||
_instance = None
|
||||
|
||||
def __new__(cls):
|
||||
"""Singleton pattern to ensure only one instance of ToolManager exists."""
|
||||
if cls._instance is None:
|
||||
cls._instance = super(ToolManager, cls).__new__(cls)
|
||||
cls._instance.tool_classes = {} # Store tool classes instead of instances
|
||||
cls._instance._initialized = False
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
# Initialize only once
|
||||
if not hasattr(self, 'tool_classes'):
|
||||
self.tool_classes = {} # Dictionary to store tool classes
|
||||
|
||||
def load_tools(self, tools_dir: str = "", config_dict=None):
|
||||
"""
|
||||
Load tools from both directory and configuration.
|
||||
|
||||
:param tools_dir: Directory to scan for tool modules
|
||||
"""
|
||||
if tools_dir:
|
||||
self._load_tools_from_directory(tools_dir)
|
||||
self._configure_tools_from_config()
|
||||
else:
|
||||
self._load_tools_from_init()
|
||||
self._configure_tools_from_config(config_dict)
|
||||
|
||||
def _load_tools_from_init(self) -> bool:
|
||||
"""
|
||||
Load tool classes from tools.__init__.__all__
|
||||
|
||||
:return: True if tools were loaded, False otherwise
|
||||
"""
|
||||
try:
|
||||
# Try to import the tools package
|
||||
tools_package = importlib.import_module("agent.tools")
|
||||
|
||||
# Check if __all__ is defined
|
||||
if hasattr(tools_package, "__all__"):
|
||||
tool_classes = tools_package.__all__
|
||||
|
||||
# Import each tool class directly from the tools package
|
||||
for class_name in tool_classes:
|
||||
try:
|
||||
# Skip base classes
|
||||
if class_name in ["BaseTool", "ToolManager"]:
|
||||
continue
|
||||
|
||||
# Get the class directly from the tools package
|
||||
if hasattr(tools_package, class_name):
|
||||
cls = getattr(tools_package, class_name)
|
||||
|
||||
if (
|
||||
isinstance(cls, type)
|
||||
and issubclass(cls, BaseTool)
|
||||
and cls != BaseTool
|
||||
):
|
||||
try:
|
||||
# Skip memory tools (they need special initialization with memory_manager)
|
||||
if class_name in ["MemorySearchTool", "MemoryGetTool"]:
|
||||
logger.debug(f"Skipped tool {class_name} (requires memory_manager)")
|
||||
continue
|
||||
|
||||
# Create a temporary instance to get the name
|
||||
temp_instance = cls()
|
||||
tool_name = temp_instance.name
|
||||
# Store the class, not the instance
|
||||
self.tool_classes[tool_name] = cls
|
||||
logger.debug(f"Loaded tool: {tool_name} from class {class_name}")
|
||||
except ImportError as e:
|
||||
# Handle missing dependencies with helpful messages
|
||||
error_msg = str(e)
|
||||
if "browser-use" in error_msg or "browser_use" in error_msg:
|
||||
logger.warning(
|
||||
f"[ToolManager] Browser tool not loaded - missing dependencies.\n"
|
||||
f" To enable browser tool, run:\n"
|
||||
f" pip install browser-use markdownify playwright\n"
|
||||
f" playwright install chromium"
|
||||
)
|
||||
elif "markdownify" in error_msg:
|
||||
logger.warning(
|
||||
f"[ToolManager] {cls.__name__} not loaded - missing markdownify.\n"
|
||||
f" Install with: pip install markdownify"
|
||||
)
|
||||
else:
|
||||
logger.warning(f"[ToolManager] {cls.__name__} not loaded due to missing dependency: {error_msg}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error initializing tool class {cls.__name__}: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error importing class {class_name}: {e}")
|
||||
|
||||
return len(self.tool_classes) > 0
|
||||
return False
|
||||
except ImportError:
|
||||
logger.warning("Could not import agent.tools package")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading tools from __init__.__all__: {e}")
|
||||
return False
|
||||
|
||||
def _load_tools_from_directory(self, tools_dir: str):
|
||||
"""Dynamically load tool classes from directory"""
|
||||
tools_path = Path(tools_dir)
|
||||
|
||||
# Traverse all .py files
|
||||
for py_file in tools_path.rglob("*.py"):
|
||||
# Skip initialization files and base tool files
|
||||
if py_file.name in ["__init__.py", "base_tool.py", "tool_manager.py"]:
|
||||
continue
|
||||
|
||||
# Get module name
|
||||
module_name = py_file.stem
|
||||
|
||||
try:
|
||||
# Load module directly from file
|
||||
spec = importlib.util.spec_from_file_location(module_name, py_file)
|
||||
if spec and spec.loader:
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(module)
|
||||
|
||||
# Find tool classes in the module
|
||||
for attr_name in dir(module):
|
||||
cls = getattr(module, attr_name)
|
||||
if (
|
||||
isinstance(cls, type)
|
||||
and issubclass(cls, BaseTool)
|
||||
and cls != BaseTool
|
||||
):
|
||||
try:
|
||||
# Skip memory tools (they need special initialization with memory_manager)
|
||||
if attr_name in ["MemorySearchTool", "MemoryGetTool"]:
|
||||
logger.debug(f"Skipped tool {attr_name} (requires memory_manager)")
|
||||
continue
|
||||
|
||||
# Create a temporary instance to get the name
|
||||
temp_instance = cls()
|
||||
tool_name = temp_instance.name
|
||||
# Store the class, not the instance
|
||||
self.tool_classes[tool_name] = cls
|
||||
except ImportError as e:
|
||||
# Handle missing dependencies with helpful messages
|
||||
error_msg = str(e)
|
||||
if "browser-use" in error_msg or "browser_use" in error_msg:
|
||||
logger.warning(
|
||||
f"[ToolManager] Browser tool not loaded - missing dependencies.\n"
|
||||
f" To enable browser tool, run:\n"
|
||||
f" pip install browser-use markdownify playwright\n"
|
||||
f" playwright install chromium"
|
||||
)
|
||||
elif "markdownify" in error_msg:
|
||||
logger.warning(
|
||||
f"[ToolManager] {cls.__name__} not loaded - missing markdownify.\n"
|
||||
f" Install with: pip install markdownify"
|
||||
)
|
||||
else:
|
||||
logger.warning(f"[ToolManager] {cls.__name__} not loaded due to missing dependency: {error_msg}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error initializing tool class {cls.__name__}: {e}")
|
||||
except Exception as e:
|
||||
print(f"Error importing module {py_file}: {e}")
|
||||
|
||||
def _configure_tools_from_config(self, config_dict=None):
|
||||
"""Configure tool classes based on configuration file"""
|
||||
try:
|
||||
# Get tools configuration
|
||||
tools_config = config_dict or conf().get("tools", {})
|
||||
|
||||
# Record tools that are configured but not loaded
|
||||
missing_tools = []
|
||||
|
||||
# Store configurations for later use when instantiating
|
||||
self.tool_configs = tools_config
|
||||
|
||||
# Check which configured tools are missing
|
||||
for tool_name in tools_config:
|
||||
if tool_name not in self.tool_classes:
|
||||
missing_tools.append(tool_name)
|
||||
|
||||
# If there are missing tools, record warnings
|
||||
if missing_tools:
|
||||
for tool_name in missing_tools:
|
||||
if tool_name == "browser":
|
||||
logger.warning(
|
||||
f"[ToolManager] Browser tool is configured but not loaded.\n"
|
||||
f" To enable browser tool, run:\n"
|
||||
f" pip install browser-use markdownify playwright\n"
|
||||
f" playwright install chromium"
|
||||
)
|
||||
elif tool_name == "google_search":
|
||||
logger.warning(
|
||||
f"[ToolManager] Google Search tool is configured but may need API key.\n"
|
||||
f" Get API key from: https://serper.dev\n"
|
||||
f" Configure in config.json: tools.google_search.api_key"
|
||||
)
|
||||
else:
|
||||
logger.warning(f"[ToolManager] Tool '{tool_name}' is configured but could not be loaded.")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error configuring tools from config: {e}")
|
||||
|
||||
def create_tool(self, name: str) -> BaseTool:
|
||||
"""
|
||||
Get a new instance of a tool by name.
|
||||
|
||||
:param name: The name of the tool to get.
|
||||
:return: A new instance of the tool or None if not found.
|
||||
"""
|
||||
tool_class = self.tool_classes.get(name)
|
||||
if tool_class:
|
||||
# Create a new instance
|
||||
tool_instance = tool_class()
|
||||
|
||||
# Apply configuration if available
|
||||
if hasattr(self, 'tool_configs') and name in self.tool_configs:
|
||||
tool_instance.config = self.tool_configs[name]
|
||||
|
||||
return tool_instance
|
||||
return None
|
||||
|
||||
def list_tools(self) -> dict:
|
||||
"""
|
||||
Get information about all loaded tools.
|
||||
|
||||
:return: A dictionary with tool information.
|
||||
"""
|
||||
result = {}
|
||||
for name, tool_class in self.tool_classes.items():
|
||||
# Create a temporary instance to get schema
|
||||
temp_instance = tool_class()
|
||||
result[name] = {
|
||||
"description": temp_instance.description,
|
||||
"parameters": temp_instance.get_json_schema()
|
||||
}
|
||||
return result
|
||||
40
agent/tools/utils/__init__.py
Normal file
40
agent/tools/utils/__init__.py
Normal file
@@ -0,0 +1,40 @@
|
||||
from .truncate import (
|
||||
truncate_head,
|
||||
truncate_tail,
|
||||
truncate_line,
|
||||
format_size,
|
||||
TruncationResult,
|
||||
DEFAULT_MAX_LINES,
|
||||
DEFAULT_MAX_BYTES,
|
||||
GREP_MAX_LINE_LENGTH
|
||||
)
|
||||
|
||||
from .diff import (
|
||||
strip_bom,
|
||||
detect_line_ending,
|
||||
normalize_to_lf,
|
||||
restore_line_endings,
|
||||
normalize_for_fuzzy_match,
|
||||
fuzzy_find_text,
|
||||
generate_diff_string,
|
||||
FuzzyMatchResult
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
'truncate_head',
|
||||
'truncate_tail',
|
||||
'truncate_line',
|
||||
'format_size',
|
||||
'TruncationResult',
|
||||
'DEFAULT_MAX_LINES',
|
||||
'DEFAULT_MAX_BYTES',
|
||||
'GREP_MAX_LINE_LENGTH',
|
||||
'strip_bom',
|
||||
'detect_line_ending',
|
||||
'normalize_to_lf',
|
||||
'restore_line_endings',
|
||||
'normalize_for_fuzzy_match',
|
||||
'fuzzy_find_text',
|
||||
'generate_diff_string',
|
||||
'FuzzyMatchResult'
|
||||
]
|
||||
167
agent/tools/utils/diff.py
Normal file
167
agent/tools/utils/diff.py
Normal file
@@ -0,0 +1,167 @@
|
||||
"""
|
||||
Diff tools for file editing
|
||||
Provides fuzzy matching and diff generation functionality
|
||||
"""
|
||||
|
||||
import difflib
|
||||
import re
|
||||
from typing import Optional, Tuple
|
||||
|
||||
|
||||
def strip_bom(text: str) -> Tuple[str, str]:
|
||||
"""
|
||||
Remove BOM (Byte Order Mark)
|
||||
|
||||
:param text: Original text
|
||||
:return: (BOM, text after removing BOM)
|
||||
"""
|
||||
if text.startswith('\ufeff'):
|
||||
return '\ufeff', text[1:]
|
||||
return '', text
|
||||
|
||||
|
||||
def detect_line_ending(text: str) -> str:
|
||||
"""
|
||||
Detect line ending type
|
||||
|
||||
:param text: Text content
|
||||
:return: Line ending type ('\r\n' or '\n')
|
||||
"""
|
||||
if '\r\n' in text:
|
||||
return '\r\n'
|
||||
return '\n'
|
||||
|
||||
|
||||
def normalize_to_lf(text: str) -> str:
|
||||
"""
|
||||
Normalize all line endings to LF (\n)
|
||||
|
||||
:param text: Original text
|
||||
:return: Normalized text
|
||||
"""
|
||||
return text.replace('\r\n', '\n').replace('\r', '\n')
|
||||
|
||||
|
||||
def restore_line_endings(text: str, original_ending: str) -> str:
|
||||
"""
|
||||
Restore original line endings
|
||||
|
||||
:param text: LF normalized text
|
||||
:param original_ending: Original line ending
|
||||
:return: Text with restored line endings
|
||||
"""
|
||||
if original_ending == '\r\n':
|
||||
return text.replace('\n', '\r\n')
|
||||
return text
|
||||
|
||||
|
||||
def normalize_for_fuzzy_match(text: str) -> str:
|
||||
"""
|
||||
Normalize text for fuzzy matching
|
||||
Remove excess whitespace but preserve basic structure
|
||||
|
||||
:param text: Original text
|
||||
:return: Normalized text
|
||||
"""
|
||||
# Compress multiple spaces to one
|
||||
text = re.sub(r'[ \t]+', ' ', text)
|
||||
# Remove trailing spaces
|
||||
text = re.sub(r' +\n', '\n', text)
|
||||
# Remove leading spaces (but preserve indentation structure, only remove excess)
|
||||
lines = text.split('\n')
|
||||
normalized_lines = []
|
||||
for line in lines:
|
||||
# Preserve indentation but normalize to multiples of single spaces
|
||||
stripped = line.lstrip()
|
||||
if stripped:
|
||||
indent_count = len(line) - len(stripped)
|
||||
# Normalize indentation (convert tabs to spaces)
|
||||
normalized_indent = ' ' * indent_count
|
||||
normalized_lines.append(normalized_indent + stripped)
|
||||
else:
|
||||
normalized_lines.append('')
|
||||
return '\n'.join(normalized_lines)
|
||||
|
||||
|
||||
class FuzzyMatchResult:
|
||||
"""Fuzzy match result"""
|
||||
|
||||
def __init__(self, found: bool, index: int = -1, match_length: int = 0, content_for_replacement: str = ""):
|
||||
self.found = found
|
||||
self.index = index
|
||||
self.match_length = match_length
|
||||
self.content_for_replacement = content_for_replacement
|
||||
|
||||
|
||||
def fuzzy_find_text(content: str, old_text: str) -> FuzzyMatchResult:
|
||||
"""
|
||||
Find text in content, try exact match first, then fuzzy match
|
||||
|
||||
:param content: Content to search in
|
||||
:param old_text: Text to find
|
||||
:return: Match result
|
||||
"""
|
||||
# First try exact match
|
||||
index = content.find(old_text)
|
||||
if index != -1:
|
||||
return FuzzyMatchResult(
|
||||
found=True,
|
||||
index=index,
|
||||
match_length=len(old_text),
|
||||
content_for_replacement=content
|
||||
)
|
||||
|
||||
# Try fuzzy match
|
||||
fuzzy_content = normalize_for_fuzzy_match(content)
|
||||
fuzzy_old_text = normalize_for_fuzzy_match(old_text)
|
||||
|
||||
index = fuzzy_content.find(fuzzy_old_text)
|
||||
if index != -1:
|
||||
# Fuzzy match successful, use normalized content for replacement
|
||||
return FuzzyMatchResult(
|
||||
found=True,
|
||||
index=index,
|
||||
match_length=len(fuzzy_old_text),
|
||||
content_for_replacement=fuzzy_content
|
||||
)
|
||||
|
||||
# Not found
|
||||
return FuzzyMatchResult(found=False)
|
||||
|
||||
|
||||
def generate_diff_string(old_content: str, new_content: str) -> dict:
|
||||
"""
|
||||
Generate unified diff string
|
||||
|
||||
:param old_content: Old content
|
||||
:param new_content: New content
|
||||
:return: Dictionary containing diff and first changed line number
|
||||
"""
|
||||
old_lines = old_content.split('\n')
|
||||
new_lines = new_content.split('\n')
|
||||
|
||||
# Generate unified diff
|
||||
diff_lines = list(difflib.unified_diff(
|
||||
old_lines,
|
||||
new_lines,
|
||||
lineterm='',
|
||||
fromfile='original',
|
||||
tofile='modified'
|
||||
))
|
||||
|
||||
# Find first changed line number
|
||||
first_changed_line = None
|
||||
for line in diff_lines:
|
||||
if line.startswith('@@'):
|
||||
# Parse @@ -1,3 +1,3 @@ format
|
||||
match = re.search(r'@@ -\d+,?\d* \+(\d+)', line)
|
||||
if match:
|
||||
first_changed_line = int(match.group(1))
|
||||
break
|
||||
|
||||
diff_string = '\n'.join(diff_lines)
|
||||
|
||||
return {
|
||||
'diff': diff_string,
|
||||
'first_changed_line': first_changed_line
|
||||
}
|
||||
292
agent/tools/utils/truncate.py
Normal file
292
agent/tools/utils/truncate.py
Normal file
@@ -0,0 +1,292 @@
|
||||
"""
|
||||
Shared truncation utilities for tool outputs.
|
||||
|
||||
Truncation is based on two independent limits - whichever is hit first wins:
|
||||
- Line limit (default: 2000 lines)
|
||||
- Byte limit (default: 50KB)
|
||||
|
||||
Never returns partial lines (except bash tail truncation edge case).
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, Optional, Literal, Tuple
|
||||
|
||||
|
||||
DEFAULT_MAX_LINES = 2000
|
||||
DEFAULT_MAX_BYTES = 50 * 1024 # 50KB
|
||||
GREP_MAX_LINE_LENGTH = 500 # Max chars per grep match line
|
||||
|
||||
|
||||
class TruncationResult:
|
||||
"""Truncation result"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
content: str,
|
||||
truncated: bool,
|
||||
truncated_by: Optional[Literal["lines", "bytes"]],
|
||||
total_lines: int,
|
||||
total_bytes: int,
|
||||
output_lines: int,
|
||||
output_bytes: int,
|
||||
last_line_partial: bool = False,
|
||||
first_line_exceeds_limit: bool = False,
|
||||
max_lines: int = DEFAULT_MAX_LINES,
|
||||
max_bytes: int = DEFAULT_MAX_BYTES
|
||||
):
|
||||
self.content = content
|
||||
self.truncated = truncated
|
||||
self.truncated_by = truncated_by
|
||||
self.total_lines = total_lines
|
||||
self.total_bytes = total_bytes
|
||||
self.output_lines = output_lines
|
||||
self.output_bytes = output_bytes
|
||||
self.last_line_partial = last_line_partial
|
||||
self.first_line_exceeds_limit = first_line_exceeds_limit
|
||||
self.max_lines = max_lines
|
||||
self.max_bytes = max_bytes
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary"""
|
||||
return {
|
||||
"content": self.content,
|
||||
"truncated": self.truncated,
|
||||
"truncated_by": self.truncated_by,
|
||||
"total_lines": self.total_lines,
|
||||
"total_bytes": self.total_bytes,
|
||||
"output_lines": self.output_lines,
|
||||
"output_bytes": self.output_bytes,
|
||||
"last_line_partial": self.last_line_partial,
|
||||
"first_line_exceeds_limit": self.first_line_exceeds_limit,
|
||||
"max_lines": self.max_lines,
|
||||
"max_bytes": self.max_bytes
|
||||
}
|
||||
|
||||
|
||||
def format_size(bytes_count: int) -> str:
|
||||
"""Format bytes as human-readable size"""
|
||||
if bytes_count < 1024:
|
||||
return f"{bytes_count}B"
|
||||
elif bytes_count < 1024 * 1024:
|
||||
return f"{bytes_count / 1024:.1f}KB"
|
||||
else:
|
||||
return f"{bytes_count / (1024 * 1024):.1f}MB"
|
||||
|
||||
|
||||
def truncate_head(content: str, max_lines: Optional[int] = None, max_bytes: Optional[int] = None) -> TruncationResult:
|
||||
"""
|
||||
Truncate content from the head (keep first N lines/bytes).
|
||||
Suitable for file reads where you want to see the beginning.
|
||||
|
||||
Never returns partial lines. If first line exceeds byte limit,
|
||||
returns empty content with first_line_exceeds_limit=True.
|
||||
|
||||
:param content: Content to truncate
|
||||
:param max_lines: Maximum number of lines (default: 2000)
|
||||
:param max_bytes: Maximum number of bytes (default: 50KB)
|
||||
:return: Truncation result
|
||||
"""
|
||||
if max_lines is None:
|
||||
max_lines = DEFAULT_MAX_LINES
|
||||
if max_bytes is None:
|
||||
max_bytes = DEFAULT_MAX_BYTES
|
||||
|
||||
total_bytes = len(content.encode('utf-8'))
|
||||
lines = content.split('\n')
|
||||
total_lines = len(lines)
|
||||
|
||||
# Check if no truncation is needed
|
||||
if total_lines <= max_lines and total_bytes <= max_bytes:
|
||||
return TruncationResult(
|
||||
content=content,
|
||||
truncated=False,
|
||||
truncated_by=None,
|
||||
total_lines=total_lines,
|
||||
total_bytes=total_bytes,
|
||||
output_lines=total_lines,
|
||||
output_bytes=total_bytes,
|
||||
last_line_partial=False,
|
||||
first_line_exceeds_limit=False,
|
||||
max_lines=max_lines,
|
||||
max_bytes=max_bytes
|
||||
)
|
||||
|
||||
# Check if first line alone exceeds byte limit
|
||||
first_line_bytes = len(lines[0].encode('utf-8'))
|
||||
if first_line_bytes > max_bytes:
|
||||
return TruncationResult(
|
||||
content="",
|
||||
truncated=True,
|
||||
truncated_by="bytes",
|
||||
total_lines=total_lines,
|
||||
total_bytes=total_bytes,
|
||||
output_lines=0,
|
||||
output_bytes=0,
|
||||
last_line_partial=False,
|
||||
first_line_exceeds_limit=True,
|
||||
max_lines=max_lines,
|
||||
max_bytes=max_bytes
|
||||
)
|
||||
|
||||
# Collect complete lines that fit
|
||||
output_lines_arr = []
|
||||
output_bytes_count = 0
|
||||
truncated_by = "lines"
|
||||
|
||||
for i, line in enumerate(lines):
|
||||
if i >= max_lines:
|
||||
break
|
||||
|
||||
# Calculate line bytes (add 1 for newline if not first line)
|
||||
line_bytes = len(line.encode('utf-8')) + (1 if i > 0 else 0)
|
||||
|
||||
if output_bytes_count + line_bytes > max_bytes:
|
||||
truncated_by = "bytes"
|
||||
break
|
||||
|
||||
output_lines_arr.append(line)
|
||||
output_bytes_count += line_bytes
|
||||
|
||||
# If exited due to line limit
|
||||
if len(output_lines_arr) >= max_lines and output_bytes_count <= max_bytes:
|
||||
truncated_by = "lines"
|
||||
|
||||
output_content = '\n'.join(output_lines_arr)
|
||||
final_output_bytes = len(output_content.encode('utf-8'))
|
||||
|
||||
return TruncationResult(
|
||||
content=output_content,
|
||||
truncated=True,
|
||||
truncated_by=truncated_by,
|
||||
total_lines=total_lines,
|
||||
total_bytes=total_bytes,
|
||||
output_lines=len(output_lines_arr),
|
||||
output_bytes=final_output_bytes,
|
||||
last_line_partial=False,
|
||||
first_line_exceeds_limit=False,
|
||||
max_lines=max_lines,
|
||||
max_bytes=max_bytes
|
||||
)
|
||||
|
||||
|
||||
def truncate_tail(content: str, max_lines: Optional[int] = None, max_bytes: Optional[int] = None) -> TruncationResult:
|
||||
"""
|
||||
Truncate content from tail (keep last N lines/bytes).
|
||||
Suitable for bash output where you want to see the ending content (errors, final results).
|
||||
|
||||
If the last line of original content exceeds byte limit, may return partial first line.
|
||||
|
||||
:param content: Content to truncate
|
||||
:param max_lines: Maximum lines (default: 2000)
|
||||
:param max_bytes: Maximum bytes (default: 50KB)
|
||||
:return: Truncation result
|
||||
"""
|
||||
if max_lines is None:
|
||||
max_lines = DEFAULT_MAX_LINES
|
||||
if max_bytes is None:
|
||||
max_bytes = DEFAULT_MAX_BYTES
|
||||
|
||||
total_bytes = len(content.encode('utf-8'))
|
||||
lines = content.split('\n')
|
||||
total_lines = len(lines)
|
||||
|
||||
# Check if no truncation is needed
|
||||
if total_lines <= max_lines and total_bytes <= max_bytes:
|
||||
return TruncationResult(
|
||||
content=content,
|
||||
truncated=False,
|
||||
truncated_by=None,
|
||||
total_lines=total_lines,
|
||||
total_bytes=total_bytes,
|
||||
output_lines=total_lines,
|
||||
output_bytes=total_bytes,
|
||||
last_line_partial=False,
|
||||
first_line_exceeds_limit=False,
|
||||
max_lines=max_lines,
|
||||
max_bytes=max_bytes
|
||||
)
|
||||
|
||||
# Work backwards from the end
|
||||
output_lines_arr = []
|
||||
output_bytes_count = 0
|
||||
truncated_by = "lines"
|
||||
last_line_partial = False
|
||||
|
||||
for i in range(len(lines) - 1, -1, -1):
|
||||
if len(output_lines_arr) >= max_lines:
|
||||
break
|
||||
|
||||
line = lines[i]
|
||||
# Calculate line bytes (add newline if not the first added line)
|
||||
line_bytes = len(line.encode('utf-8')) + (1 if len(output_lines_arr) > 0 else 0)
|
||||
|
||||
if output_bytes_count + line_bytes > max_bytes:
|
||||
truncated_by = "bytes"
|
||||
# Edge case: if we haven't added any lines yet and this line exceeds maxBytes,
|
||||
# take the end portion of this line
|
||||
if len(output_lines_arr) == 0:
|
||||
truncated_line = _truncate_string_to_bytes_from_end(line, max_bytes)
|
||||
output_lines_arr.insert(0, truncated_line)
|
||||
output_bytes_count = len(truncated_line.encode('utf-8'))
|
||||
last_line_partial = True
|
||||
break
|
||||
|
||||
output_lines_arr.insert(0, line)
|
||||
output_bytes_count += line_bytes
|
||||
|
||||
# If exited due to line limit
|
||||
if len(output_lines_arr) >= max_lines and output_bytes_count <= max_bytes:
|
||||
truncated_by = "lines"
|
||||
|
||||
output_content = '\n'.join(output_lines_arr)
|
||||
final_output_bytes = len(output_content.encode('utf-8'))
|
||||
|
||||
return TruncationResult(
|
||||
content=output_content,
|
||||
truncated=True,
|
||||
truncated_by=truncated_by,
|
||||
total_lines=total_lines,
|
||||
total_bytes=total_bytes,
|
||||
output_lines=len(output_lines_arr),
|
||||
output_bytes=final_output_bytes,
|
||||
last_line_partial=last_line_partial,
|
||||
first_line_exceeds_limit=False,
|
||||
max_lines=max_lines,
|
||||
max_bytes=max_bytes
|
||||
)
|
||||
|
||||
|
||||
def _truncate_string_to_bytes_from_end(text: str, max_bytes: int) -> str:
|
||||
"""
|
||||
Truncate string to fit byte limit (from end).
|
||||
Properly handles multi-byte UTF-8 characters.
|
||||
|
||||
:param text: String to truncate
|
||||
:param max_bytes: Maximum bytes
|
||||
:return: Truncated string
|
||||
"""
|
||||
encoded = text.encode('utf-8')
|
||||
if len(encoded) <= max_bytes:
|
||||
return text
|
||||
|
||||
# Start from end, skip back maxBytes
|
||||
start = len(encoded) - max_bytes
|
||||
|
||||
# Find valid UTF-8 boundary (character start)
|
||||
while start < len(encoded) and (encoded[start] & 0xC0) == 0x80:
|
||||
start += 1
|
||||
|
||||
return encoded[start:].decode('utf-8', errors='ignore')
|
||||
|
||||
|
||||
def truncate_line(line: str, max_chars: int = GREP_MAX_LINE_LENGTH) -> Tuple[str, bool]:
|
||||
"""
|
||||
Truncate single line to max characters, add [truncated] suffix.
|
||||
Used for grep match lines.
|
||||
|
||||
:param line: Line to truncate
|
||||
:param max_chars: Maximum characters
|
||||
:return: (truncated text, whether truncated)
|
||||
"""
|
||||
if len(line) <= max_chars:
|
||||
return line, False
|
||||
return f"{line[:max_chars]}... [truncated]", True
|
||||
1
agent/tools/vision/__init__.py
Normal file
1
agent/tools/vision/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from agent.tools.vision.vision import Vision
|
||||
255
agent/tools/vision/vision.py
Normal file
255
agent/tools/vision/vision.py
Normal file
@@ -0,0 +1,255 @@
|
||||
"""
|
||||
Vision tool - Analyze images using OpenAI-compatible Vision API.
|
||||
Supports local files (auto base64-encoded) and HTTP URLs.
|
||||
Providers: OpenAI (preferred) > LinkAI (fallback).
|
||||
"""
|
||||
|
||||
import base64
|
||||
import os
|
||||
import subprocess
|
||||
import tempfile
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
import requests
|
||||
|
||||
from agent.tools.base_tool import BaseTool, ToolResult
|
||||
from common.log import logger
|
||||
from config import conf
|
||||
|
||||
DEFAULT_MODEL = "gpt-4.1-mini"
|
||||
DEFAULT_TIMEOUT = 60
|
||||
MAX_TOKENS = 1000
|
||||
COMPRESS_THRESHOLD = 1_048_576 # 1 MB
|
||||
|
||||
SUPPORTED_EXTENSIONS = {
|
||||
"jpg": "image/jpeg",
|
||||
"jpeg": "image/jpeg",
|
||||
"png": "image/png",
|
||||
"gif": "image/gif",
|
||||
"webp": "image/webp",
|
||||
}
|
||||
|
||||
|
||||
class Vision(BaseTool):
|
||||
"""Analyze images using OpenAI-compatible Vision API"""
|
||||
|
||||
name: str = "vision"
|
||||
description: str = (
|
||||
"Analyze an image (local file or URL) using Vision API. "
|
||||
"Can describe content, extract text, identify objects, colors, etc. "
|
||||
"Requires OPENAI_API_KEY or LINKAI_API_KEY."
|
||||
)
|
||||
|
||||
params: dict = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"image": {
|
||||
"type": "string",
|
||||
"description": "Local file path or HTTP(S) URL of the image to analyze",
|
||||
},
|
||||
"question": {
|
||||
"type": "string",
|
||||
"description": "Question to ask about the image",
|
||||
},
|
||||
"model": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
f"Vision model to use (default: {DEFAULT_MODEL}). "
|
||||
"Options: gpt-4.1-mini, gpt-4.1, gpt-4o-mini, gpt-4o"
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": ["image", "question"],
|
||||
}
|
||||
|
||||
def __init__(self, config: dict = None):
|
||||
self.config = config or {}
|
||||
|
||||
@staticmethod
|
||||
def is_available() -> bool:
|
||||
return bool(
|
||||
conf().get("open_ai_api_key") or os.environ.get("OPENAI_API_KEY")
|
||||
or conf().get("linkai_api_key") or os.environ.get("LINKAI_API_KEY")
|
||||
)
|
||||
|
||||
def execute(self, args: Dict[str, Any]) -> ToolResult:
|
||||
image = args.get("image", "").strip()
|
||||
question = args.get("question", "").strip()
|
||||
model = args.get("model", DEFAULT_MODEL).strip() or DEFAULT_MODEL
|
||||
|
||||
if not image:
|
||||
return ToolResult.fail("Error: 'image' parameter is required")
|
||||
if not question:
|
||||
return ToolResult.fail("Error: 'question' parameter is required")
|
||||
|
||||
api_key, api_base = self._resolve_provider()
|
||||
if not api_key:
|
||||
return ToolResult.fail(
|
||||
"Error: No API key configured for Vision.\n"
|
||||
"Please configure one of the following using env_config tool:\n"
|
||||
" 1. OPENAI_API_KEY (preferred): env_config(action=\"set\", key=\"OPENAI_API_KEY\", value=\"your-key\")\n"
|
||||
" 2. LINKAI_API_KEY (fallback): env_config(action=\"set\", key=\"LINKAI_API_KEY\", value=\"your-key\")\n\n"
|
||||
"Get your key at: https://platform.openai.com/api-keys or https://link-ai.tech"
|
||||
)
|
||||
|
||||
try:
|
||||
image_content = self._build_image_content(image)
|
||||
except Exception as e:
|
||||
return ToolResult.fail(f"Error: {e}")
|
||||
|
||||
try:
|
||||
return self._call_api(api_key, api_base, model, question, image_content)
|
||||
except requests.Timeout:
|
||||
return ToolResult.fail(f"Error: Vision API request timed out after {DEFAULT_TIMEOUT}s")
|
||||
except requests.ConnectionError:
|
||||
return ToolResult.fail("Error: Failed to connect to Vision API")
|
||||
except Exception as e:
|
||||
logger.error(f"[Vision] Unexpected error: {e}", exc_info=True)
|
||||
return ToolResult.fail(f"Error: Vision API call failed - {e}")
|
||||
|
||||
def _resolve_provider(self) -> Tuple[Optional[str], str]:
|
||||
"""Resolve API key and base URL. Priority: conf() > env vars."""
|
||||
api_key = conf().get("open_ai_api_key") or os.environ.get("OPENAI_API_KEY")
|
||||
if api_key:
|
||||
api_base = (conf().get("open_ai_api_base") or os.environ.get("OPENAI_API_BASE", "")).rstrip("/") \
|
||||
or "https://api.openai.com/v1"
|
||||
return api_key, self._ensure_v1(api_base)
|
||||
|
||||
api_key = conf().get("linkai_api_key") or os.environ.get("LINKAI_API_KEY")
|
||||
if api_key:
|
||||
api_base = (conf().get("linkai_api_base") or os.environ.get("LINKAI_API_BASE", "")).rstrip("/") \
|
||||
or "https://api.link-ai.tech"
|
||||
logger.debug("[Vision] Using LinkAI API (OPENAI_API_KEY not set)")
|
||||
return api_key, self._ensure_v1(api_base)
|
||||
|
||||
return None, ""
|
||||
|
||||
@staticmethod
|
||||
def _ensure_v1(api_base: str) -> str:
|
||||
"""Append /v1 if the base URL doesn't already end with a versioned path."""
|
||||
if not api_base:
|
||||
return api_base
|
||||
# Already has /v1 or similar version suffix
|
||||
if api_base.rstrip("/").split("/")[-1].startswith("v"):
|
||||
return api_base
|
||||
return api_base.rstrip("/") + "/v1"
|
||||
|
||||
def _build_image_content(self, image: str) -> dict:
|
||||
"""Build the image_url content block for the API request."""
|
||||
if image.startswith(("http://", "https://")):
|
||||
return {"type": "image_url", "image_url": {"url": image}}
|
||||
|
||||
if not os.path.isfile(image):
|
||||
raise FileNotFoundError(f"Image file not found: {image}")
|
||||
|
||||
ext = image.rsplit(".", 1)[-1].lower() if "." in image else ""
|
||||
mime_type = SUPPORTED_EXTENSIONS.get(ext)
|
||||
if not mime_type:
|
||||
raise ValueError(
|
||||
f"Unsupported image format '.{ext}'. "
|
||||
f"Supported: {', '.join(SUPPORTED_EXTENSIONS.keys())}"
|
||||
)
|
||||
|
||||
file_path = self._maybe_compress(image)
|
||||
try:
|
||||
with open(file_path, "rb") as f:
|
||||
b64 = base64.b64encode(f.read()).decode("ascii")
|
||||
finally:
|
||||
if file_path != image and os.path.exists(file_path):
|
||||
os.remove(file_path)
|
||||
|
||||
data_url = f"data:{mime_type};base64,{b64}"
|
||||
return {"type": "image_url", "image_url": {"url": data_url}}
|
||||
|
||||
@staticmethod
|
||||
def _maybe_compress(path: str) -> str:
|
||||
"""Compress image if larger than threshold; return path to use."""
|
||||
file_size = os.path.getsize(path)
|
||||
if file_size <= COMPRESS_THRESHOLD:
|
||||
return path
|
||||
|
||||
tmp = tempfile.NamedTemporaryFile(suffix=".jpg", delete=False)
|
||||
tmp.close()
|
||||
|
||||
try:
|
||||
# macOS: use sips
|
||||
subprocess.run(
|
||||
["sips", "-Z", "800", path, "--out", tmp.name],
|
||||
capture_output=True, check=True,
|
||||
)
|
||||
logger.debug(f"[Vision] Compressed image ({file_size // 1024}KB -> {os.path.getsize(tmp.name) // 1024}KB)")
|
||||
return tmp.name
|
||||
except (FileNotFoundError, subprocess.CalledProcessError):
|
||||
pass
|
||||
|
||||
try:
|
||||
# Linux: use ImageMagick convert
|
||||
subprocess.run(
|
||||
["convert", path, "-resize", "800x800>", tmp.name],
|
||||
capture_output=True, check=True,
|
||||
)
|
||||
logger.debug(f"[Vision] Compressed image ({file_size // 1024}KB -> {os.path.getsize(tmp.name) // 1024}KB)")
|
||||
return tmp.name
|
||||
except (FileNotFoundError, subprocess.CalledProcessError):
|
||||
pass
|
||||
|
||||
os.remove(tmp.name)
|
||||
return path
|
||||
|
||||
def _call_api(self, api_key: str, api_base: str, model: str,
|
||||
question: str, image_content: dict) -> ToolResult:
|
||||
payload = {
|
||||
"model": model,
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": question},
|
||||
image_content,
|
||||
],
|
||||
}
|
||||
],
|
||||
"max_tokens": MAX_TOKENS,
|
||||
}
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
resp = requests.post(
|
||||
f"{api_base}/chat/completions",
|
||||
headers=headers,
|
||||
json=payload,
|
||||
timeout=DEFAULT_TIMEOUT,
|
||||
)
|
||||
|
||||
if resp.status_code == 401:
|
||||
return ToolResult.fail("Error: Invalid API key. Please check your configuration.")
|
||||
if resp.status_code == 429:
|
||||
return ToolResult.fail("Error: API rate limit reached. Please try again later.")
|
||||
if resp.status_code != 200:
|
||||
return ToolResult.fail(f"Error: Vision API returned HTTP {resp.status_code}: {resp.text[:200]}")
|
||||
|
||||
data = resp.json()
|
||||
|
||||
if "error" in data:
|
||||
msg = data["error"].get("message", "Unknown API error")
|
||||
return ToolResult.fail(f"Error: Vision API error - {msg}")
|
||||
|
||||
content = ""
|
||||
choices = data.get("choices", [])
|
||||
if choices:
|
||||
content = choices[0].get("message", {}).get("content", "")
|
||||
|
||||
usage = data.get("usage", {})
|
||||
result = {
|
||||
"model": model,
|
||||
"content": content,
|
||||
"usage": {
|
||||
"prompt_tokens": usage.get("prompt_tokens", 0),
|
||||
"completion_tokens": usage.get("completion_tokens", 0),
|
||||
"total_tokens": usage.get("total_tokens", 0),
|
||||
},
|
||||
}
|
||||
return ToolResult.success(result)
|
||||
0
agent/tools/web_fetch/__init__.py
Normal file
0
agent/tools/web_fetch/__init__.py
Normal file
444
agent/tools/web_fetch/web_fetch.py
Normal file
444
agent/tools/web_fetch/web_fetch.py
Normal file
@@ -0,0 +1,444 @@
|
||||
"""
|
||||
Web Fetch tool - Fetch and extract readable content from web pages and remote files.
|
||||
|
||||
Supports:
|
||||
- HTML web pages: extracts readable text content
|
||||
- Document files (PDF, Word, TXT, Markdown, etc.): downloads to workspace/tmp and parses content
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
import uuid
|
||||
from typing import Dict, Any, Optional, Set
|
||||
from urllib.parse import urlparse, unquote
|
||||
|
||||
import requests
|
||||
|
||||
from agent.tools.base_tool import BaseTool, ToolResult
|
||||
from agent.tools.utils.truncate import truncate_head, format_size
|
||||
from common.log import logger
|
||||
|
||||
|
||||
DEFAULT_TIMEOUT = 30
|
||||
MAX_FILE_SIZE = 50 * 1024 * 1024 # 50MB
|
||||
|
||||
DEFAULT_HEADERS = {
|
||||
"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36",
|
||||
"Accept": "*/*",
|
||||
}
|
||||
|
||||
# Supported document file extensions
|
||||
PDF_SUFFIXES: Set[str] = {".pdf"}
|
||||
WORD_SUFFIXES: Set[str] = {".docx"}
|
||||
TEXT_SUFFIXES: Set[str] = {".txt", ".md", ".markdown", ".rst", ".csv", ".tsv", ".log"}
|
||||
SPREADSHEET_SUFFIXES: Set[str] = {".xls", ".xlsx"}
|
||||
PPT_SUFFIXES: Set[str] = {".ppt", ".pptx"}
|
||||
|
||||
ALL_DOC_SUFFIXES = PDF_SUFFIXES | WORD_SUFFIXES | TEXT_SUFFIXES | SPREADSHEET_SUFFIXES | PPT_SUFFIXES
|
||||
|
||||
_CHARSET_RE = re.compile(r'charset\s*=\s*["\']?\s*([\w\-]+)', re.IGNORECASE)
|
||||
_META_CHARSET_RE = re.compile(rb'<meta[^>]+charset\s*=\s*["\']?\s*([\w\-]+)', re.IGNORECASE)
|
||||
_META_HTTP_EQUIV_RE = re.compile(
|
||||
rb'<meta[^>]+http-equiv\s*=\s*["\']?Content-Type["\']?[^>]+content\s*=\s*["\'][^"\']*charset=([\w\-]+)',
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
|
||||
def _extract_charset_from_content_type(content_type: str) -> Optional[str]:
|
||||
"""Extract charset from Content-Type header value."""
|
||||
m = _CHARSET_RE.search(content_type)
|
||||
return m.group(1) if m else None
|
||||
|
||||
|
||||
def _extract_charset_from_html_meta(raw_bytes: bytes) -> Optional[str]:
|
||||
"""Extract charset from HTML <meta> tags in the first few KB of raw bytes."""
|
||||
m = _META_CHARSET_RE.search(raw_bytes)
|
||||
if m:
|
||||
return m.group(1).decode("ascii", errors="ignore")
|
||||
m = _META_HTTP_EQUIV_RE.search(raw_bytes)
|
||||
if m:
|
||||
return m.group(1).decode("ascii", errors="ignore")
|
||||
return None
|
||||
|
||||
|
||||
def _get_url_suffix(url: str) -> str:
|
||||
"""Extract file extension from URL path, ignoring query params."""
|
||||
path = urlparse(url).path
|
||||
return os.path.splitext(path)[-1].lower()
|
||||
|
||||
|
||||
def _is_document_url(url: str) -> bool:
|
||||
"""Check if URL points to a downloadable document file."""
|
||||
suffix = _get_url_suffix(url)
|
||||
return suffix in ALL_DOC_SUFFIXES
|
||||
|
||||
|
||||
class WebFetch(BaseTool):
|
||||
"""Tool for fetching web pages and remote document files"""
|
||||
|
||||
name: str = "web_fetch"
|
||||
description: str = (
|
||||
"Fetch content from a URL. For web pages, extracts readable text. "
|
||||
"For document files (PDF, Word, TXT, Markdown, Excel, PPT), downloads and parses the file content. "
|
||||
"Supported file types: .pdf, .docx, .txt, .md, .csv, .xls, .xlsx, .ppt, .pptx"
|
||||
)
|
||||
|
||||
params: dict = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"url": {
|
||||
"type": "string",
|
||||
"description": "The HTTP/HTTPS URL to fetch (web page or document file link)"
|
||||
}
|
||||
},
|
||||
"required": ["url"]
|
||||
}
|
||||
|
||||
def __init__(self, config: dict = None):
|
||||
self.config = config or {}
|
||||
self.cwd = self.config.get("cwd", os.getcwd())
|
||||
|
||||
def execute(self, args: Dict[str, Any]) -> ToolResult:
|
||||
url = args.get("url", "").strip()
|
||||
if not url:
|
||||
return ToolResult.fail("Error: 'url' parameter is required")
|
||||
|
||||
parsed = urlparse(url)
|
||||
if parsed.scheme not in ("http", "https"):
|
||||
return ToolResult.fail("Error: Invalid URL (must start with http:// or https://)")
|
||||
|
||||
if _is_document_url(url):
|
||||
return self._fetch_document(url)
|
||||
|
||||
return self._fetch_webpage(url)
|
||||
|
||||
# ---- Web page fetching ----
|
||||
|
||||
def _fetch_webpage(self, url: str) -> ToolResult:
|
||||
"""Fetch and extract readable text from an HTML web page."""
|
||||
parsed = urlparse(url)
|
||||
try:
|
||||
response = requests.get(
|
||||
url,
|
||||
headers=DEFAULT_HEADERS,
|
||||
timeout=DEFAULT_TIMEOUT,
|
||||
allow_redirects=True,
|
||||
)
|
||||
response.raise_for_status()
|
||||
except requests.Timeout:
|
||||
return ToolResult.fail(f"Error: Request timed out after {DEFAULT_TIMEOUT}s")
|
||||
except requests.ConnectionError:
|
||||
return ToolResult.fail(f"Error: Failed to connect to {parsed.netloc}")
|
||||
except requests.HTTPError as e:
|
||||
return ToolResult.fail(f"Error: HTTP {e.response.status_code} for URL: {url}")
|
||||
except Exception as e:
|
||||
return ToolResult.fail(f"Error: Failed to fetch URL: {e}")
|
||||
|
||||
content_type = response.headers.get("Content-Type", "")
|
||||
if self._is_binary_content_type(content_type) and not _is_document_url(url):
|
||||
return self._handle_download_by_content_type(url, response, content_type)
|
||||
|
||||
response.encoding = self._detect_encoding(response)
|
||||
html = response.text
|
||||
title = self._extract_title(html)
|
||||
text = self._extract_text(html)
|
||||
|
||||
return ToolResult.success(f"Title: {title}\n\nContent:\n{text}")
|
||||
|
||||
# ---- Document fetching ----
|
||||
|
||||
def _fetch_document(self, url: str) -> ToolResult:
|
||||
"""Download a document file and extract its text content."""
|
||||
suffix = _get_url_suffix(url)
|
||||
parsed = urlparse(url)
|
||||
filename = self._extract_filename(url)
|
||||
tmp_dir = self._ensure_tmp_dir()
|
||||
|
||||
local_path = os.path.join(tmp_dir, filename)
|
||||
logger.info(f"[WebFetch] Downloading document: {url} -> {local_path}")
|
||||
|
||||
try:
|
||||
response = requests.get(
|
||||
url,
|
||||
headers=DEFAULT_HEADERS,
|
||||
timeout=DEFAULT_TIMEOUT,
|
||||
stream=True,
|
||||
allow_redirects=True,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
content_length = int(response.headers.get("Content-Length", 0))
|
||||
if content_length > MAX_FILE_SIZE:
|
||||
return ToolResult.fail(
|
||||
f"Error: File too large ({format_size(content_length)} > {format_size(MAX_FILE_SIZE)})"
|
||||
)
|
||||
|
||||
downloaded = 0
|
||||
with open(local_path, "wb") as f:
|
||||
for chunk in response.iter_content(chunk_size=8192):
|
||||
downloaded += len(chunk)
|
||||
if downloaded > MAX_FILE_SIZE:
|
||||
f.close()
|
||||
os.remove(local_path)
|
||||
return ToolResult.fail(
|
||||
f"Error: File too large (>{format_size(MAX_FILE_SIZE)}), download aborted"
|
||||
)
|
||||
f.write(chunk)
|
||||
|
||||
except requests.Timeout:
|
||||
return ToolResult.fail(f"Error: Download timed out after {DEFAULT_TIMEOUT}s")
|
||||
except requests.ConnectionError:
|
||||
return ToolResult.fail(f"Error: Failed to connect to {parsed.netloc}")
|
||||
except requests.HTTPError as e:
|
||||
return ToolResult.fail(f"Error: HTTP {e.response.status_code} for URL: {url}")
|
||||
except Exception as e:
|
||||
self._cleanup_file(local_path)
|
||||
return ToolResult.fail(f"Error: Failed to download file: {e}")
|
||||
|
||||
try:
|
||||
text = self._parse_document(local_path, suffix)
|
||||
except Exception as e:
|
||||
self._cleanup_file(local_path)
|
||||
return ToolResult.fail(f"Error: Failed to parse document: {e}")
|
||||
|
||||
if not text or not text.strip():
|
||||
file_size = os.path.getsize(local_path)
|
||||
return ToolResult.success(
|
||||
f"File downloaded to: {local_path} ({format_size(file_size)})\n"
|
||||
f"No text content could be extracted. The file may contain only images or be encrypted."
|
||||
)
|
||||
|
||||
truncation = truncate_head(text)
|
||||
result_text = truncation.content
|
||||
|
||||
file_size = os.path.getsize(local_path)
|
||||
header = f"[Document: {filename} | Size: {format_size(file_size)} | Saved to: {local_path}]\n\n"
|
||||
|
||||
if truncation.truncated:
|
||||
header += f"[Content truncated: showing {truncation.output_lines} of {truncation.total_lines} lines]\n\n"
|
||||
|
||||
return ToolResult.success(header + result_text)
|
||||
|
||||
def _parse_document(self, file_path: str, suffix: str) -> str:
|
||||
"""Parse document file and return extracted text."""
|
||||
if suffix in PDF_SUFFIXES:
|
||||
return self._parse_pdf(file_path)
|
||||
elif suffix in WORD_SUFFIXES:
|
||||
return self._parse_word(file_path)
|
||||
elif suffix in TEXT_SUFFIXES:
|
||||
return self._parse_text(file_path)
|
||||
elif suffix in SPREADSHEET_SUFFIXES:
|
||||
return self._parse_spreadsheet(file_path)
|
||||
elif suffix in PPT_SUFFIXES:
|
||||
return self._parse_ppt(file_path)
|
||||
else:
|
||||
return self._parse_text(file_path)
|
||||
|
||||
def _parse_pdf(self, file_path: str) -> str:
|
||||
"""Extract text from PDF using pypdf."""
|
||||
try:
|
||||
from pypdf import PdfReader
|
||||
except ImportError:
|
||||
raise ImportError("pypdf library is required for PDF parsing. Install with: pip install pypdf")
|
||||
|
||||
reader = PdfReader(file_path)
|
||||
text_parts = []
|
||||
for page_num, page in enumerate(reader.pages, 1):
|
||||
page_text = page.extract_text()
|
||||
if page_text and page_text.strip():
|
||||
text_parts.append(f"--- Page {page_num}/{len(reader.pages)} ---\n{page_text}")
|
||||
|
||||
return "\n\n".join(text_parts)
|
||||
|
||||
def _parse_word(self, file_path: str) -> str:
|
||||
"""Extract text from Word documents (.docx)."""
|
||||
try:
|
||||
from docx import Document
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"python-docx library is required for .docx parsing. Install with: pip install python-docx"
|
||||
)
|
||||
doc = Document(file_path)
|
||||
paragraphs = [p.text for p in doc.paragraphs if p.text.strip()]
|
||||
return "\n\n".join(paragraphs)
|
||||
|
||||
def _parse_text(self, file_path: str) -> str:
|
||||
"""Read plain text files (txt, md, csv, etc.)."""
|
||||
encodings = ["utf-8", "utf-8-sig", "gbk", "gb2312", "latin-1"]
|
||||
for enc in encodings:
|
||||
try:
|
||||
with open(file_path, "r", encoding=enc) as f:
|
||||
return f.read()
|
||||
except (UnicodeDecodeError, UnicodeError):
|
||||
continue
|
||||
raise ValueError(f"Unable to decode file with any supported encoding: {encodings}")
|
||||
|
||||
def _parse_spreadsheet(self, file_path: str) -> str:
|
||||
"""Extract text from Excel files (.xls/.xlsx)."""
|
||||
try:
|
||||
import openpyxl
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"openpyxl library is required for .xlsx parsing. Install with: pip install openpyxl"
|
||||
)
|
||||
|
||||
wb = openpyxl.load_workbook(file_path, read_only=True, data_only=True)
|
||||
result_parts = []
|
||||
|
||||
for sheet_name in wb.sheetnames:
|
||||
ws = wb[sheet_name]
|
||||
rows = []
|
||||
for row in ws.iter_rows(values_only=True):
|
||||
cells = [str(c) if c is not None else "" for c in row]
|
||||
if any(cells):
|
||||
rows.append(" | ".join(cells))
|
||||
if rows:
|
||||
result_parts.append(f"--- Sheet: {sheet_name} ---\n" + "\n".join(rows))
|
||||
|
||||
wb.close()
|
||||
return "\n\n".join(result_parts)
|
||||
|
||||
def _parse_ppt(self, file_path: str) -> str:
|
||||
"""Extract text from PowerPoint files (.ppt/.pptx)."""
|
||||
try:
|
||||
from pptx import Presentation
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"python-pptx library is required for .pptx parsing. Install with: pip install python-pptx"
|
||||
)
|
||||
|
||||
prs = Presentation(file_path)
|
||||
text_parts = []
|
||||
|
||||
for slide_num, slide in enumerate(prs.slides, 1):
|
||||
slide_texts = []
|
||||
for shape in slide.shapes:
|
||||
if shape.has_text_frame:
|
||||
for paragraph in shape.text_frame.paragraphs:
|
||||
text = paragraph.text.strip()
|
||||
if text:
|
||||
slide_texts.append(text)
|
||||
if slide_texts:
|
||||
text_parts.append(f"--- Slide {slide_num}/{len(prs.slides)} ---\n" + "\n".join(slide_texts))
|
||||
|
||||
return "\n\n".join(text_parts)
|
||||
|
||||
# ---- Encoding detection ----
|
||||
|
||||
@staticmethod
|
||||
def _detect_encoding(response: requests.Response) -> str:
|
||||
"""Detect response encoding with priority: Content-Type header > HTML meta > chardet > utf-8."""
|
||||
# 1. Check Content-Type header for explicit charset
|
||||
content_type = response.headers.get("Content-Type", "")
|
||||
charset = _extract_charset_from_content_type(content_type)
|
||||
if charset:
|
||||
return charset
|
||||
|
||||
# 2. Scan raw bytes for HTML meta charset declaration
|
||||
raw = response.content[:4096]
|
||||
charset = _extract_charset_from_html_meta(raw)
|
||||
if charset:
|
||||
return charset
|
||||
|
||||
# 3. Use apparent_encoding (chardet-based detection) if confident enough
|
||||
apparent = response.apparent_encoding
|
||||
if apparent:
|
||||
apparent_lower = apparent.lower()
|
||||
# Trust CJK / Windows encodings detected by chardet
|
||||
trusted_prefixes = ("utf", "gb", "big5", "euc", "shift_jis", "iso-2022", "windows", "ascii")
|
||||
if any(apparent_lower.startswith(p) for p in trusted_prefixes):
|
||||
return apparent
|
||||
|
||||
# 4. Fallback
|
||||
return "utf-8"
|
||||
|
||||
# ---- Helper methods ----
|
||||
|
||||
def _ensure_tmp_dir(self) -> str:
|
||||
"""Ensure workspace/tmp directory exists and return its path."""
|
||||
tmp_dir = os.path.join(self.cwd, "tmp")
|
||||
os.makedirs(tmp_dir, exist_ok=True)
|
||||
return tmp_dir
|
||||
|
||||
def _extract_filename(self, url: str) -> str:
|
||||
"""Extract a safe filename from URL, with a short UUID prefix to avoid collisions."""
|
||||
path = urlparse(url).path
|
||||
basename = os.path.basename(unquote(path))
|
||||
if not basename or basename == "/":
|
||||
basename = "downloaded_file"
|
||||
# Sanitize: keep only safe chars
|
||||
basename = re.sub(r'[^\w.\-]', '_', basename)
|
||||
short_id = uuid.uuid4().hex[:8]
|
||||
return f"{short_id}_{basename}"
|
||||
|
||||
@staticmethod
|
||||
def _cleanup_file(path: str):
|
||||
"""Remove a file if it exists, ignoring errors."""
|
||||
try:
|
||||
if os.path.exists(path):
|
||||
os.remove(path)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def _is_binary_content_type(content_type: str) -> bool:
|
||||
"""Check if Content-Type indicates a binary/document response."""
|
||||
binary_types = [
|
||||
"application/pdf",
|
||||
"application/vnd.openxmlformats",
|
||||
"application/vnd.ms-excel",
|
||||
"application/vnd.ms-powerpoint",
|
||||
"application/octet-stream",
|
||||
]
|
||||
ct_lower = content_type.lower()
|
||||
return any(bt in ct_lower for bt in binary_types)
|
||||
|
||||
def _handle_download_by_content_type(self, url: str, response: requests.Response, content_type: str) -> ToolResult:
|
||||
"""Handle a URL that returned binary content instead of HTML."""
|
||||
ct_lower = content_type.lower()
|
||||
suffix_map = {
|
||||
"application/pdf": ".pdf",
|
||||
"application/vnd.openxmlformats-officedocument.wordprocessingml": ".docx",
|
||||
"application/vnd.ms-excel": ".xls",
|
||||
"application/vnd.openxmlformats-officedocument.spreadsheetml": ".xlsx",
|
||||
"application/vnd.ms-powerpoint": ".ppt",
|
||||
"application/vnd.openxmlformats-officedocument.presentationml": ".pptx",
|
||||
}
|
||||
detected_suffix = None
|
||||
for ct_prefix, ext in suffix_map.items():
|
||||
if ct_prefix in ct_lower:
|
||||
detected_suffix = ext
|
||||
break
|
||||
|
||||
if detected_suffix and detected_suffix in ALL_DOC_SUFFIXES:
|
||||
# Re-fetch as document
|
||||
return self._fetch_document(url if _get_url_suffix(url) in ALL_DOC_SUFFIXES
|
||||
else self._rewrite_url_with_suffix(url, detected_suffix))
|
||||
return ToolResult.fail(f"Error: URL returned binary content ({content_type}), not a supported document type")
|
||||
|
||||
@staticmethod
|
||||
def _rewrite_url_with_suffix(url: str, suffix: str) -> str:
|
||||
"""Append a suffix to the URL path so _get_url_suffix works correctly."""
|
||||
parsed = urlparse(url)
|
||||
new_path = parsed.path.rstrip("/") + suffix
|
||||
return parsed._replace(path=new_path).geturl()
|
||||
|
||||
# ---- HTML extraction (unchanged) ----
|
||||
|
||||
@staticmethod
|
||||
def _extract_title(html: str) -> str:
|
||||
match = re.search(r"<title[^>]*>(.*?)</title>", html, re.IGNORECASE | re.DOTALL)
|
||||
return match.group(1).strip() if match else "Untitled"
|
||||
|
||||
@staticmethod
|
||||
def _extract_text(html: str) -> str:
|
||||
text = re.sub(r"<script[^>]*>.*?</script>", "", html, flags=re.IGNORECASE | re.DOTALL)
|
||||
text = re.sub(r"<style[^>]*>.*?</style>", "", text, flags=re.IGNORECASE | re.DOTALL)
|
||||
text = re.sub(r"<[^>]+>", "", text)
|
||||
text = text.replace("&", "&").replace("<", "<").replace(">", ">")
|
||||
text = text.replace(""", '"').replace("'", "'").replace(" ", " ")
|
||||
text = re.sub(r"[^\S\n]+", " ", text)
|
||||
text = re.sub(r"\n{3,}", "\n\n", text)
|
||||
lines = [line.strip() for line in text.splitlines()]
|
||||
text = "\n".join(lines)
|
||||
return text.strip()
|
||||
3
agent/tools/web_search/__init__.py
Normal file
3
agent/tools/web_search/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from agent.tools.web_search.web_search import WebSearch
|
||||
|
||||
__all__ = ["WebSearch"]
|
||||
320
agent/tools/web_search/web_search.py
Normal file
320
agent/tools/web_search/web_search.py
Normal file
@@ -0,0 +1,320 @@
|
||||
"""
|
||||
Web Search tool - Search the web using Bocha or LinkAI search API.
|
||||
Supports two backends with unified response format:
|
||||
1. Bocha Search (primary, requires BOCHA_API_KEY)
|
||||
2. LinkAI Search (fallback, requires LINKAI_API_KEY)
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
from typing import Dict, Any, Optional
|
||||
|
||||
import requests
|
||||
|
||||
from agent.tools.base_tool import BaseTool, ToolResult
|
||||
from common.log import logger
|
||||
from config import conf
|
||||
|
||||
|
||||
# Default timeout for API requests (seconds)
|
||||
DEFAULT_TIMEOUT = 30
|
||||
|
||||
|
||||
class WebSearch(BaseTool):
|
||||
"""Tool for searching the web using Bocha or LinkAI search API"""
|
||||
|
||||
name: str = "web_search"
|
||||
description: str = "Search the web for real-time information. Returns titles, URLs, and snippets."
|
||||
|
||||
params: dict = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "Search query string"
|
||||
},
|
||||
"count": {
|
||||
"type": "integer",
|
||||
"description": "Number of results to return (1-50, default: 10)"
|
||||
},
|
||||
"freshness": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Time range filter. Options: "
|
||||
"'noLimit' (default), 'oneDay', 'oneWeek', 'oneMonth', 'oneYear', "
|
||||
"or date range like '2025-01-01..2025-02-01'"
|
||||
)
|
||||
},
|
||||
"summary": {
|
||||
"type": "boolean",
|
||||
"description": "Whether to include text summary for each result (default: false)"
|
||||
}
|
||||
},
|
||||
"required": ["query"]
|
||||
}
|
||||
|
||||
def __init__(self, config: dict = None):
|
||||
self.config = config or {}
|
||||
self._backend = None # Will be resolved on first execute
|
||||
|
||||
@staticmethod
|
||||
def is_available() -> bool:
|
||||
"""Check if web search is available (at least one API key is configured)"""
|
||||
return bool(os.environ.get("BOCHA_API_KEY") or os.environ.get("LINKAI_API_KEY"))
|
||||
|
||||
def _resolve_backend(self) -> Optional[str]:
|
||||
"""
|
||||
Determine which search backend to use.
|
||||
Priority: Bocha > LinkAI
|
||||
|
||||
:return: 'bocha', 'linkai', or None
|
||||
"""
|
||||
if os.environ.get("BOCHA_API_KEY"):
|
||||
return "bocha"
|
||||
if os.environ.get("LINKAI_API_KEY"):
|
||||
return "linkai"
|
||||
return None
|
||||
|
||||
def execute(self, args: Dict[str, Any]) -> ToolResult:
|
||||
"""
|
||||
Execute web search
|
||||
|
||||
:param args: Search parameters (query, count, freshness, summary)
|
||||
:return: Search results
|
||||
"""
|
||||
query = args.get("query", "").strip()
|
||||
if not query:
|
||||
return ToolResult.fail("Error: 'query' parameter is required")
|
||||
|
||||
count = args.get("count", 10)
|
||||
freshness = args.get("freshness", "noLimit")
|
||||
summary = args.get("summary", False)
|
||||
|
||||
# Validate count
|
||||
if not isinstance(count, int) or count < 1 or count > 50:
|
||||
count = 10
|
||||
|
||||
# Resolve backend
|
||||
backend = self._resolve_backend()
|
||||
if not backend:
|
||||
return ToolResult.fail(
|
||||
"Error: No search API key configured. "
|
||||
"Please set BOCHA_API_KEY or LINKAI_API_KEY using env_config tool.\n"
|
||||
" - Bocha Search: https://open.bocha.cn\n"
|
||||
" - LinkAI Search: https://link-ai.tech"
|
||||
)
|
||||
|
||||
try:
|
||||
if backend == "bocha":
|
||||
return self._search_bocha(query, count, freshness, summary)
|
||||
else:
|
||||
return self._search_linkai(query, count, freshness)
|
||||
except requests.Timeout:
|
||||
return ToolResult.fail(f"Error: Search request timed out after {DEFAULT_TIMEOUT}s")
|
||||
except requests.ConnectionError:
|
||||
return ToolResult.fail("Error: Failed to connect to search API")
|
||||
except Exception as e:
|
||||
logger.error(f"[WebSearch] Unexpected error: {e}", exc_info=True)
|
||||
return ToolResult.fail(f"Error: Search failed - {str(e)}")
|
||||
|
||||
def _search_bocha(self, query: str, count: int, freshness: str, summary: bool) -> ToolResult:
|
||||
"""
|
||||
Search using Bocha API
|
||||
|
||||
:param query: Search query
|
||||
:param count: Number of results
|
||||
:param freshness: Time range filter
|
||||
:param summary: Whether to include summary
|
||||
:return: Formatted search results
|
||||
"""
|
||||
api_key = os.environ.get("BOCHA_API_KEY", "")
|
||||
url = "https://api.bocha.cn/v1/web-search"
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json"
|
||||
}
|
||||
|
||||
payload = {
|
||||
"query": query,
|
||||
"count": count,
|
||||
"freshness": freshness,
|
||||
"summary": summary
|
||||
}
|
||||
|
||||
logger.debug(f"[WebSearch] Bocha search: query='{query}', count={count}")
|
||||
|
||||
response = requests.post(url, headers=headers, json=payload, timeout=DEFAULT_TIMEOUT)
|
||||
|
||||
if response.status_code == 401:
|
||||
return ToolResult.fail("Error: Invalid BOCHA_API_KEY. Please check your API key.")
|
||||
if response.status_code == 403:
|
||||
return ToolResult.fail("Error: Bocha API - insufficient balance. Please top up at https://open.bocha.cn")
|
||||
if response.status_code == 429:
|
||||
return ToolResult.fail("Error: Bocha API rate limit reached. Please try again later.")
|
||||
if response.status_code != 200:
|
||||
return ToolResult.fail(f"Error: Bocha API returned HTTP {response.status_code}")
|
||||
|
||||
data = response.json()
|
||||
|
||||
# Check API-level error code
|
||||
api_code = data.get("code")
|
||||
if api_code is not None and api_code != 200:
|
||||
msg = data.get("msg") or "Unknown error"
|
||||
return ToolResult.fail(f"Error: Bocha API error (code={api_code}): {msg}")
|
||||
|
||||
# Extract and format results
|
||||
return self._format_bocha_results(data, query)
|
||||
|
||||
def _format_bocha_results(self, data: dict, query: str) -> ToolResult:
|
||||
"""
|
||||
Format Bocha API response into unified result structure
|
||||
|
||||
:param data: Raw API response
|
||||
:param query: Original query
|
||||
:return: Formatted ToolResult
|
||||
"""
|
||||
search_data = data.get("data", {})
|
||||
web_pages = search_data.get("webPages", {})
|
||||
pages = web_pages.get("value", [])
|
||||
|
||||
if not pages:
|
||||
return ToolResult.success({
|
||||
"query": query,
|
||||
"backend": "bocha",
|
||||
"total": 0,
|
||||
"results": [],
|
||||
"message": "No results found"
|
||||
})
|
||||
|
||||
results = []
|
||||
for page in pages:
|
||||
result = {
|
||||
"title": page.get("name", ""),
|
||||
"url": page.get("url", ""),
|
||||
"snippet": page.get("snippet", ""),
|
||||
"siteName": page.get("siteName", ""),
|
||||
"datePublished": page.get("datePublished") or page.get("dateLastCrawled", ""),
|
||||
}
|
||||
# Include summary only if present
|
||||
if page.get("summary"):
|
||||
result["summary"] = page["summary"]
|
||||
results.append(result)
|
||||
|
||||
total = web_pages.get("totalEstimatedMatches", len(results))
|
||||
|
||||
return ToolResult.success({
|
||||
"query": query,
|
||||
"backend": "bocha",
|
||||
"total": total,
|
||||
"count": len(results),
|
||||
"results": results
|
||||
})
|
||||
|
||||
def _search_linkai(self, query: str, count: int, freshness: str) -> ToolResult:
|
||||
"""
|
||||
Search using LinkAI plugin API
|
||||
|
||||
:param query: Search query
|
||||
:param count: Number of results
|
||||
:param freshness: Time range filter
|
||||
:return: Formatted search results
|
||||
"""
|
||||
api_key = os.environ.get("LINKAI_API_KEY", "")
|
||||
api_base = conf().get("linkai_api_base", "https://api.link-ai.tech")
|
||||
url = f"{api_base.rstrip('/')}/v1/plugin/execute"
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {api_key}"
|
||||
}
|
||||
|
||||
payload = {
|
||||
"code": "web-search",
|
||||
"args": {
|
||||
"query": query,
|
||||
"count": count,
|
||||
"freshness": freshness
|
||||
}
|
||||
}
|
||||
|
||||
logger.debug(f"[WebSearch] LinkAI search: query='{query}', count={count}")
|
||||
|
||||
response = requests.post(url, headers=headers, json=payload, timeout=DEFAULT_TIMEOUT)
|
||||
|
||||
if response.status_code == 401:
|
||||
return ToolResult.fail("Error: Invalid LINKAI_API_KEY. Please check your API key.")
|
||||
if response.status_code != 200:
|
||||
return ToolResult.fail(f"Error: LinkAI API returned HTTP {response.status_code}")
|
||||
|
||||
data = response.json()
|
||||
|
||||
if not data.get("success"):
|
||||
msg = data.get("message") or "Unknown error"
|
||||
return ToolResult.fail(f"Error: LinkAI search failed: {msg}")
|
||||
|
||||
return self._format_linkai_results(data, query)
|
||||
|
||||
def _format_linkai_results(self, data: dict, query: str) -> ToolResult:
|
||||
"""
|
||||
Format LinkAI API response into unified result structure.
|
||||
LinkAI returns the search data in data.data field, which follows
|
||||
the same Bing-compatible format as Bocha.
|
||||
|
||||
:param data: Raw API response
|
||||
:param query: Original query
|
||||
:return: Formatted ToolResult
|
||||
"""
|
||||
raw_data = data.get("data", "")
|
||||
|
||||
# LinkAI may return data as a JSON string
|
||||
if isinstance(raw_data, str):
|
||||
try:
|
||||
raw_data = json.loads(raw_data)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
# If data is plain text, return it as a single result
|
||||
return ToolResult.success({
|
||||
"query": query,
|
||||
"backend": "linkai",
|
||||
"total": 1,
|
||||
"count": 1,
|
||||
"results": [{"content": raw_data}]
|
||||
})
|
||||
|
||||
# If the response follows Bing-compatible structure
|
||||
if isinstance(raw_data, dict):
|
||||
web_pages = raw_data.get("webPages", {})
|
||||
pages = web_pages.get("value", [])
|
||||
|
||||
if pages:
|
||||
results = []
|
||||
for page in pages:
|
||||
result = {
|
||||
"title": page.get("name", ""),
|
||||
"url": page.get("url", ""),
|
||||
"snippet": page.get("snippet", ""),
|
||||
"siteName": page.get("siteName", ""),
|
||||
"datePublished": page.get("datePublished") or page.get("dateLastCrawled", ""),
|
||||
}
|
||||
if page.get("summary"):
|
||||
result["summary"] = page["summary"]
|
||||
results.append(result)
|
||||
|
||||
total = web_pages.get("totalEstimatedMatches", len(results))
|
||||
return ToolResult.success({
|
||||
"query": query,
|
||||
"backend": "linkai",
|
||||
"total": total,
|
||||
"count": len(results),
|
||||
"results": results
|
||||
})
|
||||
|
||||
# Fallback: return raw data
|
||||
return ToolResult.success({
|
||||
"query": query,
|
||||
"backend": "linkai",
|
||||
"total": 1,
|
||||
"count": 1,
|
||||
"results": [{"content": str(raw_data)}]
|
||||
})
|
||||
3
agent/tools/write/__init__.py
Normal file
3
agent/tools/write/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .write import Write
|
||||
|
||||
__all__ = ['Write']
|
||||
97
agent/tools/write/write.py
Normal file
97
agent/tools/write/write.py
Normal file
@@ -0,0 +1,97 @@
|
||||
"""
|
||||
Write tool - Write file content
|
||||
Creates or overwrites files, automatically creates parent directories
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Dict, Any
|
||||
from pathlib import Path
|
||||
|
||||
from agent.tools.base_tool import BaseTool, ToolResult
|
||||
from common.utils import expand_path
|
||||
|
||||
|
||||
class Write(BaseTool):
|
||||
"""Tool for writing file content"""
|
||||
|
||||
name: str = "write"
|
||||
description: str = "Write content to a file. Creates the file if it doesn't exist, overwrites if it does. Automatically creates parent directories. IMPORTANT: Single write should not exceed 10KB. For large files, create a skeleton first, then use edit to add content in chunks."
|
||||
|
||||
params: dict = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "Path to the file to write (relative or absolute)"
|
||||
},
|
||||
"content": {
|
||||
"type": "string",
|
||||
"description": "Content to write to the file"
|
||||
}
|
||||
},
|
||||
"required": ["path", "content"]
|
||||
}
|
||||
|
||||
def __init__(self, config: dict = None):
|
||||
self.config = config or {}
|
||||
self.cwd = self.config.get("cwd", os.getcwd())
|
||||
self.memory_manager = self.config.get("memory_manager", None)
|
||||
|
||||
def execute(self, args: Dict[str, Any]) -> ToolResult:
|
||||
"""
|
||||
Execute file write operation
|
||||
|
||||
:param args: Contains file path and content
|
||||
:return: Operation result
|
||||
"""
|
||||
path = args.get("path", "").strip()
|
||||
content = args.get("content", "")
|
||||
|
||||
if not path:
|
||||
return ToolResult.fail("Error: path parameter is required")
|
||||
|
||||
# Resolve path
|
||||
absolute_path = self._resolve_path(path)
|
||||
|
||||
try:
|
||||
# Create parent directory (if needed)
|
||||
parent_dir = os.path.dirname(absolute_path)
|
||||
if parent_dir:
|
||||
os.makedirs(parent_dir, exist_ok=True)
|
||||
|
||||
# Write file
|
||||
with open(absolute_path, 'w', encoding='utf-8') as f:
|
||||
f.write(content)
|
||||
|
||||
# Get bytes written
|
||||
bytes_written = len(content.encode('utf-8'))
|
||||
|
||||
# Auto-sync to memory database if this is a memory file
|
||||
if self.memory_manager and 'memory/' in path:
|
||||
self.memory_manager.mark_dirty()
|
||||
|
||||
result = {
|
||||
"message": f"Successfully wrote {bytes_written} bytes to {path}",
|
||||
"path": path,
|
||||
"bytes_written": bytes_written
|
||||
}
|
||||
|
||||
return ToolResult.success(result)
|
||||
|
||||
except PermissionError:
|
||||
return ToolResult.fail(f"Error: Permission denied writing to {path}")
|
||||
except Exception as e:
|
||||
return ToolResult.fail(f"Error writing file: {str(e)}")
|
||||
|
||||
def _resolve_path(self, path: str) -> str:
|
||||
"""
|
||||
Resolve path to absolute path
|
||||
|
||||
:param path: Relative or absolute path
|
||||
:return: Absolute path
|
||||
"""
|
||||
# Expand ~ to user home directory
|
||||
path = expand_path(path)
|
||||
if os.path.isabs(path):
|
||||
return path
|
||||
return os.path.abspath(os.path.join(self.cwd, path))
|
||||
293
app.py
293
app.py
@@ -1,23 +1,273 @@
|
||||
# encoding:utf-8
|
||||
|
||||
import os
|
||||
from config import conf, load_config
|
||||
from channel import channel_factory
|
||||
from common.log import logger
|
||||
from plugins import *
|
||||
import signal
|
||||
import sys
|
||||
import time
|
||||
|
||||
from channel import channel_factory
|
||||
from common import const
|
||||
from common.log import logger
|
||||
from config import load_config, conf
|
||||
from plugins import *
|
||||
import threading
|
||||
|
||||
|
||||
_channel_mgr = None
|
||||
|
||||
|
||||
def get_channel_manager():
|
||||
return _channel_mgr
|
||||
|
||||
|
||||
def _parse_channel_type(raw) -> list:
|
||||
"""
|
||||
Parse channel_type config value into a list of channel names.
|
||||
Supports:
|
||||
- single string: "feishu"
|
||||
- comma-separated string: "feishu, dingtalk"
|
||||
- list: ["feishu", "dingtalk"]
|
||||
"""
|
||||
if isinstance(raw, list):
|
||||
return [ch.strip() for ch in raw if ch.strip()]
|
||||
if isinstance(raw, str):
|
||||
return [ch.strip() for ch in raw.split(",") if ch.strip()]
|
||||
return []
|
||||
|
||||
|
||||
class ChannelManager:
|
||||
"""
|
||||
Manage the lifecycle of multiple channels running concurrently.
|
||||
Each channel.startup() runs in its own daemon thread.
|
||||
The web channel is started as default console unless explicitly disabled.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._channels = {} # channel_name -> channel instance
|
||||
self._threads = {} # channel_name -> thread
|
||||
self._primary_channel = None
|
||||
self._lock = threading.Lock()
|
||||
self.cloud_mode = False # set to True when cloud client is active
|
||||
|
||||
@property
|
||||
def channel(self):
|
||||
"""Return the primary (first non-web) channel for backward compatibility."""
|
||||
return self._primary_channel
|
||||
|
||||
def get_channel(self, channel_name: str):
|
||||
return self._channels.get(channel_name)
|
||||
|
||||
def start(self, channel_names: list, first_start: bool = False):
|
||||
"""
|
||||
Create and start one or more channels in sub-threads.
|
||||
If first_start is True, plugins and linkai client will also be initialized.
|
||||
"""
|
||||
with self._lock:
|
||||
channels = []
|
||||
for name in channel_names:
|
||||
ch = channel_factory.create_channel(name)
|
||||
ch.cloud_mode = self.cloud_mode
|
||||
self._channels[name] = ch
|
||||
channels.append((name, ch))
|
||||
if self._primary_channel is None and name != "web":
|
||||
self._primary_channel = ch
|
||||
|
||||
if self._primary_channel is None and channels:
|
||||
self._primary_channel = channels[0][1]
|
||||
|
||||
if first_start:
|
||||
PluginManager().load_plugins()
|
||||
|
||||
if conf().get("use_linkai"):
|
||||
try:
|
||||
from common import cloud_client
|
||||
threading.Thread(
|
||||
target=cloud_client.start,
|
||||
args=(self._primary_channel, self),
|
||||
daemon=True,
|
||||
).start()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Start web console first so its logs print cleanly,
|
||||
# then start remaining channels after a brief pause.
|
||||
web_entry = None
|
||||
other_entries = []
|
||||
for entry in channels:
|
||||
if entry[0] == "web":
|
||||
web_entry = entry
|
||||
else:
|
||||
other_entries.append(entry)
|
||||
|
||||
ordered = ([web_entry] if web_entry else []) + other_entries
|
||||
for i, (name, ch) in enumerate(ordered):
|
||||
if i > 0 and name != "web":
|
||||
time.sleep(0.1)
|
||||
t = threading.Thread(target=self._run_channel, args=(name, ch), daemon=True)
|
||||
self._threads[name] = t
|
||||
t.start()
|
||||
logger.debug(f"[ChannelManager] Channel '{name}' started in sub-thread")
|
||||
|
||||
def _run_channel(self, name: str, channel):
|
||||
try:
|
||||
channel.startup()
|
||||
except Exception as e:
|
||||
logger.error(f"[ChannelManager] Channel '{name}' startup error: {e}")
|
||||
logger.exception(e)
|
||||
|
||||
def stop(self, channel_name: str = None):
|
||||
"""
|
||||
Stop channel(s). If channel_name is given, stop only that channel;
|
||||
otherwise stop all channels.
|
||||
"""
|
||||
# Pop under lock, then stop outside lock to avoid deadlock
|
||||
with self._lock:
|
||||
names = [channel_name] if channel_name else list(self._channels.keys())
|
||||
to_stop = []
|
||||
for name in names:
|
||||
ch = self._channels.pop(name, None)
|
||||
th = self._threads.pop(name, None)
|
||||
to_stop.append((name, ch, th))
|
||||
if channel_name and self._primary_channel is self._channels.get(channel_name):
|
||||
self._primary_channel = None
|
||||
|
||||
for name, ch, th in to_stop:
|
||||
if ch is None:
|
||||
logger.warning(f"[ChannelManager] Channel '{name}' not found in managed channels")
|
||||
if th and th.is_alive():
|
||||
self._interrupt_thread(th, name)
|
||||
continue
|
||||
logger.info(f"[ChannelManager] Stopping channel '{name}'...")
|
||||
graceful = False
|
||||
if hasattr(ch, 'stop'):
|
||||
try:
|
||||
ch.stop()
|
||||
graceful = True
|
||||
except Exception as e:
|
||||
logger.warning(f"[ChannelManager] Error during channel '{name}' stop: {e}")
|
||||
if th and th.is_alive():
|
||||
th.join(timeout=5)
|
||||
if th.is_alive():
|
||||
if graceful:
|
||||
logger.info(f"[ChannelManager] Channel '{name}' thread still alive after stop(), "
|
||||
"leaving daemon thread to finish on its own")
|
||||
else:
|
||||
logger.warning(f"[ChannelManager] Channel '{name}' thread did not exit in 5s, forcing interrupt")
|
||||
self._interrupt_thread(th, name)
|
||||
|
||||
@staticmethod
|
||||
def _interrupt_thread(th: threading.Thread, name: str):
|
||||
"""Raise SystemExit in target thread to break blocking loops like start_forever."""
|
||||
import ctypes
|
||||
try:
|
||||
tid = th.ident
|
||||
if tid is None:
|
||||
return
|
||||
res = ctypes.pythonapi.PyThreadState_SetAsyncExc(
|
||||
ctypes.c_ulong(tid), ctypes.py_object(SystemExit)
|
||||
)
|
||||
if res == 1:
|
||||
logger.info(f"[ChannelManager] Interrupted thread for channel '{name}'")
|
||||
elif res > 1:
|
||||
ctypes.pythonapi.PyThreadState_SetAsyncExc(ctypes.c_ulong(tid), None)
|
||||
logger.warning(f"[ChannelManager] Failed to interrupt thread for channel '{name}'")
|
||||
except Exception as e:
|
||||
logger.warning(f"[ChannelManager] Thread interrupt error for '{name}': {e}")
|
||||
|
||||
def restart(self, new_channel_name: str):
|
||||
"""
|
||||
Restart a single channel with a new channel type.
|
||||
Can be called from any thread (e.g. linkai config callback).
|
||||
"""
|
||||
logger.info(f"[ChannelManager] Restarting channel to '{new_channel_name}'...")
|
||||
self.stop(new_channel_name)
|
||||
_clear_singleton_cache(new_channel_name)
|
||||
time.sleep(1)
|
||||
self.start([new_channel_name], first_start=False)
|
||||
logger.info(f"[ChannelManager] Channel restarted to '{new_channel_name}' successfully")
|
||||
|
||||
def add_channel(self, channel_name: str):
|
||||
"""
|
||||
Dynamically add and start a new channel.
|
||||
If the channel is already running, restart it instead.
|
||||
"""
|
||||
with self._lock:
|
||||
if channel_name in self._channels:
|
||||
logger.info(f"[ChannelManager] Channel '{channel_name}' already exists, restarting")
|
||||
if self._channels.get(channel_name):
|
||||
self.restart(channel_name)
|
||||
return
|
||||
logger.info(f"[ChannelManager] Adding channel '{channel_name}'...")
|
||||
_clear_singleton_cache(channel_name)
|
||||
self.start([channel_name], first_start=False)
|
||||
logger.info(f"[ChannelManager] Channel '{channel_name}' added successfully")
|
||||
|
||||
def remove_channel(self, channel_name: str):
|
||||
"""
|
||||
Dynamically stop and remove a running channel.
|
||||
"""
|
||||
with self._lock:
|
||||
if channel_name not in self._channels:
|
||||
logger.warning(f"[ChannelManager] Channel '{channel_name}' not found, nothing to remove")
|
||||
return
|
||||
logger.info(f"[ChannelManager] Removing channel '{channel_name}'...")
|
||||
self.stop(channel_name)
|
||||
logger.info(f"[ChannelManager] Channel '{channel_name}' removed successfully")
|
||||
|
||||
|
||||
def _clear_singleton_cache(channel_name: str):
|
||||
"""
|
||||
Clear the singleton cache for the channel class so that
|
||||
a new instance can be created with updated config.
|
||||
"""
|
||||
cls_map = {
|
||||
"web": "channel.web.web_channel.WebChannel",
|
||||
"wechatmp": "channel.wechatmp.wechatmp_channel.WechatMPChannel",
|
||||
"wechatmp_service": "channel.wechatmp.wechatmp_channel.WechatMPChannel",
|
||||
"wechatcom_app": "channel.wechatcom.wechatcomapp_channel.WechatComAppChannel",
|
||||
const.FEISHU: "channel.feishu.feishu_channel.FeiShuChanel",
|
||||
const.DINGTALK: "channel.dingtalk.dingtalk_channel.DingTalkChanel",
|
||||
const.WECOM_BOT: "channel.wecom_bot.wecom_bot_channel.WecomBotChannel",
|
||||
const.QQ: "channel.qq.qq_channel.QQChannel",
|
||||
}
|
||||
module_path = cls_map.get(channel_name)
|
||||
if not module_path:
|
||||
return
|
||||
try:
|
||||
parts = module_path.rsplit(".", 1)
|
||||
module_name, class_name = parts[0], parts[1]
|
||||
import importlib
|
||||
module = importlib.import_module(module_name)
|
||||
wrapper = getattr(module, class_name, None)
|
||||
if wrapper and hasattr(wrapper, '__closure__') and wrapper.__closure__:
|
||||
for cell in wrapper.__closure__:
|
||||
try:
|
||||
cell_contents = cell.cell_contents
|
||||
if isinstance(cell_contents, dict):
|
||||
cell_contents.clear()
|
||||
logger.debug(f"[ChannelManager] Cleared singleton cache for {class_name}")
|
||||
break
|
||||
except ValueError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.warning(f"[ChannelManager] Failed to clear singleton cache: {e}")
|
||||
|
||||
|
||||
def sigterm_handler_wrap(_signo):
|
||||
old_handler = signal.getsignal(_signo)
|
||||
|
||||
def func(_signo, _stack_frame):
|
||||
logger.info("signal {} received, exiting...".format(_signo))
|
||||
conf().save_user_datas()
|
||||
if callable(old_handler): # check old_handler
|
||||
if callable(old_handler): # check old_handler
|
||||
return old_handler(_signo, _stack_frame)
|
||||
sys.exit(0)
|
||||
|
||||
signal.signal(_signo, func)
|
||||
|
||||
|
||||
def run():
|
||||
global _channel_mgr
|
||||
try:
|
||||
# load config
|
||||
load_config()
|
||||
@@ -26,25 +276,32 @@ def run():
|
||||
# kill signal
|
||||
sigterm_handler_wrap(signal.SIGTERM)
|
||||
|
||||
# create channel
|
||||
channel_name=conf().get('channel_type', 'wx')
|
||||
# Parse channel_type into a list
|
||||
raw_channel = conf().get("channel_type", "web")
|
||||
|
||||
if "--cmd" in sys.argv:
|
||||
channel_name = 'terminal'
|
||||
channel_names = ["terminal"]
|
||||
else:
|
||||
channel_names = _parse_channel_type(raw_channel)
|
||||
if not channel_names:
|
||||
channel_names = ["web"]
|
||||
|
||||
if channel_name == 'wxy':
|
||||
os.environ['WECHATY_LOG']="warn"
|
||||
# os.environ['WECHATY_PUPPET_SERVICE_ENDPOINT'] = '127.0.0.1:9001'
|
||||
# Auto-start web console unless explicitly disabled
|
||||
web_console_enabled = conf().get("web_console", True)
|
||||
if web_console_enabled and "web" not in channel_names:
|
||||
channel_names.append("web")
|
||||
|
||||
channel = channel_factory.create_channel(channel_name)
|
||||
if channel_name in ['wx','wxy','terminal','wechatmp','wechatmp_service']:
|
||||
PluginManager().load_plugins()
|
||||
logger.info(f"[App] Starting channels: {channel_names}")
|
||||
|
||||
# startup channel
|
||||
channel.startup()
|
||||
_channel_mgr = ChannelManager()
|
||||
_channel_mgr.start(channel_names, first_start=True)
|
||||
|
||||
while True:
|
||||
time.sleep(1)
|
||||
except Exception as e:
|
||||
logger.error("App startup failed!")
|
||||
logger.exception(e)
|
||||
|
||||
if __name__ == '__main__':
|
||||
run()
|
||||
|
||||
if __name__ == "__main__":
|
||||
run()
|
||||
|
||||
@@ -1,28 +0,0 @@
|
||||
# encoding:utf-8
|
||||
|
||||
import requests
|
||||
from bot.bot import Bot
|
||||
from bridge.reply import Reply, ReplyType
|
||||
|
||||
|
||||
# Baidu Unit对话接口 (可用, 但能力较弱)
|
||||
class BaiduUnitBot(Bot):
|
||||
def reply(self, query, context=None):
|
||||
token = self.get_token()
|
||||
url = 'https://aip.baidubce.com/rpc/2.0/unit/service/v3/chat?access_token=' + token
|
||||
post_data = "{\"version\":\"3.0\",\"service_id\":\"S73177\",\"session_id\":\"\",\"log_id\":\"7758521\",\"skill_ids\":[\"1221886\"],\"request\":{\"terminal_id\":\"88888\",\"query\":\"" + query + "\", \"hyper_params\": {\"chat_custom_bot_profile\": 1}}}"
|
||||
print(post_data)
|
||||
headers = {'content-type': 'application/x-www-form-urlencoded'}
|
||||
response = requests.post(url, data=post_data.encode(), headers=headers)
|
||||
if response:
|
||||
reply = Reply(ReplyType.TEXT, response.json()['result']['context']['SYS_PRESUMED_HIST'][1])
|
||||
return reply
|
||||
|
||||
def get_token(self):
|
||||
access_key = 'YOUR_ACCESS_KEY'
|
||||
secret_key = 'YOUR_SECRET_KEY'
|
||||
host = 'https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id=' + access_key + '&client_secret=' + secret_key
|
||||
response = requests.get(host)
|
||||
if response:
|
||||
print(response.json())
|
||||
return response.json()['access_token']
|
||||
@@ -1,32 +0,0 @@
|
||||
"""
|
||||
channel factory
|
||||
"""
|
||||
from common import const
|
||||
|
||||
|
||||
def create_bot(bot_type):
|
||||
"""
|
||||
create a bot_type instance
|
||||
:param bot_type: bot type code
|
||||
:return: bot instance
|
||||
"""
|
||||
if bot_type == const.BAIDU:
|
||||
# Baidu Unit对话接口
|
||||
from bot.baidu.baidu_unit_bot import BaiduUnitBot
|
||||
return BaiduUnitBot()
|
||||
|
||||
elif bot_type == const.CHATGPT:
|
||||
# ChatGPT 网页端web接口
|
||||
from bot.chatgpt.chat_gpt_bot import ChatGPTBot
|
||||
return ChatGPTBot()
|
||||
|
||||
elif bot_type == const.OPEN_AI:
|
||||
# OpenAI 官方对话模型API
|
||||
from bot.openai.open_ai_bot import OpenAIBot
|
||||
return OpenAIBot()
|
||||
|
||||
elif bot_type == const.CHATGPTONAZURE:
|
||||
# Azure chatgpt service https://azure.microsoft.com/en-in/products/cognitive-services/openai-service/
|
||||
from bot.chatgpt.chat_gpt_bot import AzureChatGPTBot
|
||||
return AzureChatGPTBot()
|
||||
raise RuntimeError
|
||||
@@ -1,156 +0,0 @@
|
||||
# encoding:utf-8
|
||||
|
||||
from bot.bot import Bot
|
||||
from bot.chatgpt.chat_gpt_session import ChatGPTSession
|
||||
from bot.openai.open_ai_image import OpenAIImage
|
||||
from bot.session_manager import SessionManager
|
||||
from bridge.context import ContextType
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from config import conf, load_config
|
||||
from common.log import logger
|
||||
from common.token_bucket import TokenBucket
|
||||
import openai
|
||||
import openai.error
|
||||
import time
|
||||
|
||||
# OpenAI对话模型API (可用)
|
||||
class ChatGPTBot(Bot,OpenAIImage):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
# set the default api_key
|
||||
openai.api_key = conf().get('open_ai_api_key')
|
||||
if conf().get('open_ai_api_base'):
|
||||
openai.api_base = conf().get('open_ai_api_base')
|
||||
proxy = conf().get('proxy')
|
||||
if proxy:
|
||||
openai.proxy = proxy
|
||||
if conf().get('rate_limit_chatgpt'):
|
||||
self.tb4chatgpt = TokenBucket(conf().get('rate_limit_chatgpt', 20))
|
||||
|
||||
self.sessions = SessionManager(ChatGPTSession, model= conf().get("model") or "gpt-3.5-turbo")
|
||||
|
||||
def reply(self, query, context=None):
|
||||
# acquire reply content
|
||||
if context.type == ContextType.TEXT:
|
||||
logger.info("[CHATGPT] query={}".format(query))
|
||||
|
||||
|
||||
session_id = context['session_id']
|
||||
reply = None
|
||||
clear_memory_commands = conf().get('clear_memory_commands', ['#清除记忆'])
|
||||
if query in clear_memory_commands:
|
||||
self.sessions.clear_session(session_id)
|
||||
reply = Reply(ReplyType.INFO, '记忆已清除')
|
||||
elif query == '#清除所有':
|
||||
self.sessions.clear_all_session()
|
||||
reply = Reply(ReplyType.INFO, '所有人记忆已清除')
|
||||
elif query == '#更新配置':
|
||||
load_config()
|
||||
reply = Reply(ReplyType.INFO, '配置已更新')
|
||||
if reply:
|
||||
return reply
|
||||
session = self.sessions.session_query(query, session_id)
|
||||
logger.debug("[CHATGPT] session query={}".format(session.messages))
|
||||
|
||||
api_key = context.get('openai_api_key')
|
||||
|
||||
# if context.get('stream'):
|
||||
# # reply in stream
|
||||
# return self.reply_text_stream(query, new_query, session_id)
|
||||
|
||||
reply_content = self.reply_text(session, session_id, api_key, 0)
|
||||
logger.debug("[CHATGPT] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(session.messages, session_id, reply_content["content"], reply_content["completion_tokens"]))
|
||||
if reply_content['completion_tokens'] == 0 and len(reply_content['content']) > 0:
|
||||
reply = Reply(ReplyType.ERROR, reply_content['content'])
|
||||
elif reply_content["completion_tokens"] > 0:
|
||||
self.sessions.session_reply(reply_content["content"], session_id, reply_content["total_tokens"])
|
||||
reply = Reply(ReplyType.TEXT, reply_content["content"])
|
||||
else:
|
||||
reply = Reply(ReplyType.ERROR, reply_content['content'])
|
||||
logger.debug("[CHATGPT] reply {} used 0 tokens.".format(reply_content))
|
||||
return reply
|
||||
|
||||
elif context.type == ContextType.IMAGE_CREATE:
|
||||
ok, retstring = self.create_img(query, 0)
|
||||
reply = None
|
||||
if ok:
|
||||
reply = Reply(ReplyType.IMAGE_URL, retstring)
|
||||
else:
|
||||
reply = Reply(ReplyType.ERROR, retstring)
|
||||
return reply
|
||||
else:
|
||||
reply = Reply(ReplyType.ERROR, 'Bot不支持处理{}类型的消息'.format(context.type))
|
||||
return reply
|
||||
|
||||
def compose_args(self):
|
||||
return {
|
||||
"model": conf().get("model") or "gpt-3.5-turbo", # 对话模型的名称
|
||||
"temperature":conf().get('temperature', 0.9), # 值在[0,1]之间,越大表示回复越具有不确定性
|
||||
# "max_tokens":4096, # 回复最大的字符数
|
||||
"top_p":1,
|
||||
"frequency_penalty":conf().get('frequency_penalty', 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
|
||||
"presence_penalty":conf().get('presence_penalty', 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
|
||||
"request_timeout": conf().get('request_timeout', None), # 请求超时时间,openai接口默认设置为600,对于难问题一般需要较长时间
|
||||
"timeout": conf().get('request_timeout', None), #重试超时时间,在这个时间内,将会自动重试
|
||||
}
|
||||
|
||||
def reply_text(self, session:ChatGPTSession, session_id, api_key, retry_count=0) -> dict:
|
||||
'''
|
||||
call openai's ChatCompletion to get the answer
|
||||
:param session: a conversation session
|
||||
:param session_id: session id
|
||||
:param retry_count: retry count
|
||||
:return: {}
|
||||
'''
|
||||
try:
|
||||
if conf().get('rate_limit_chatgpt') and not self.tb4chatgpt.get_token():
|
||||
raise openai.error.RateLimitError("RateLimitError: rate limit exceeded")
|
||||
# if api_key == None, the default openai.api_key will be used
|
||||
response = openai.ChatCompletion.create(
|
||||
api_key=api_key, messages=session.messages, **self.compose_args()
|
||||
)
|
||||
# logger.info("[ChatGPT] reply={}, total_tokens={}".format(response.choices[0]['message']['content'], response["usage"]["total_tokens"]))
|
||||
return {"total_tokens": response["usage"]["total_tokens"],
|
||||
"completion_tokens": response["usage"]["completion_tokens"],
|
||||
"content": response.choices[0]['message']['content']}
|
||||
except Exception as e:
|
||||
need_retry = retry_count < 2
|
||||
result = {"completion_tokens": 0, "content": "我现在有点累了,等会再来吧"}
|
||||
if isinstance(e, openai.error.RateLimitError):
|
||||
logger.warn("[CHATGPT] RateLimitError: {}".format(e))
|
||||
result['content'] = "提问太快啦,请休息一下再问我吧"
|
||||
if need_retry:
|
||||
time.sleep(5)
|
||||
elif isinstance(e, openai.error.Timeout):
|
||||
logger.warn("[CHATGPT] Timeout: {}".format(e))
|
||||
result['content'] = "我没有收到你的消息"
|
||||
if need_retry:
|
||||
time.sleep(5)
|
||||
elif isinstance(e, openai.error.APIConnectionError):
|
||||
logger.warn("[CHATGPT] APIConnectionError: {}".format(e))
|
||||
need_retry = False
|
||||
result['content'] = "我连接不到你的网络"
|
||||
else:
|
||||
logger.warn("[CHATGPT] Exception: {}".format(e))
|
||||
need_retry = False
|
||||
self.sessions.clear_session(session_id)
|
||||
|
||||
if need_retry:
|
||||
logger.warn("[CHATGPT] 第{}次重试".format(retry_count+1))
|
||||
return self.reply_text(session, session_id, api_key, retry_count+1)
|
||||
else:
|
||||
return result
|
||||
|
||||
|
||||
class AzureChatGPTBot(ChatGPTBot):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
openai.api_type = "azure"
|
||||
openai.api_version = "2023-03-15-preview"
|
||||
|
||||
def compose_args(self):
|
||||
args = super().compose_args()
|
||||
args["deployment_id"] = conf().get("azure_deployment_id")
|
||||
#args["engine"] = args["model"]
|
||||
#del(args["model"])
|
||||
return args
|
||||
@@ -1,109 +0,0 @@
|
||||
# encoding:utf-8
|
||||
|
||||
from bot.bot import Bot
|
||||
from bot.openai.open_ai_image import OpenAIImage
|
||||
from bot.openai.open_ai_session import OpenAISession
|
||||
from bot.session_manager import SessionManager
|
||||
from bridge.context import ContextType
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from config import conf
|
||||
from common.log import logger
|
||||
import openai
|
||||
import openai.error
|
||||
import time
|
||||
|
||||
user_session = dict()
|
||||
|
||||
# OpenAI对话模型API (可用)
|
||||
class OpenAIBot(Bot, OpenAIImage):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
openai.api_key = conf().get('open_ai_api_key')
|
||||
if conf().get('open_ai_api_base'):
|
||||
openai.api_base = conf().get('open_ai_api_base')
|
||||
proxy = conf().get('proxy')
|
||||
if proxy:
|
||||
openai.proxy = proxy
|
||||
|
||||
self.sessions = SessionManager(OpenAISession, model= conf().get("model") or "text-davinci-003")
|
||||
|
||||
def reply(self, query, context=None):
|
||||
# acquire reply content
|
||||
if context and context.type:
|
||||
if context.type == ContextType.TEXT:
|
||||
logger.info("[OPEN_AI] query={}".format(query))
|
||||
session_id = context['session_id']
|
||||
reply = None
|
||||
if query == '#清除记忆':
|
||||
self.sessions.clear_session(session_id)
|
||||
reply = Reply(ReplyType.INFO, '记忆已清除')
|
||||
elif query == '#清除所有':
|
||||
self.sessions.clear_all_session()
|
||||
reply = Reply(ReplyType.INFO, '所有人记忆已清除')
|
||||
else:
|
||||
session = self.sessions.session_query(query, session_id)
|
||||
new_query = str(session)
|
||||
logger.debug("[OPEN_AI] session query={}".format(new_query))
|
||||
|
||||
total_tokens, completion_tokens, reply_content = self.reply_text(new_query, session_id, 0)
|
||||
logger.debug("[OPEN_AI] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(new_query, session_id, reply_content, completion_tokens))
|
||||
|
||||
if total_tokens == 0 :
|
||||
reply = Reply(ReplyType.ERROR, reply_content)
|
||||
else:
|
||||
self.sessions.session_reply(reply_content, session_id, total_tokens)
|
||||
reply = Reply(ReplyType.TEXT, reply_content)
|
||||
return reply
|
||||
elif context.type == ContextType.IMAGE_CREATE:
|
||||
ok, retstring = self.create_img(query, 0)
|
||||
reply = None
|
||||
if ok:
|
||||
reply = Reply(ReplyType.IMAGE_URL, retstring)
|
||||
else:
|
||||
reply = Reply(ReplyType.ERROR, retstring)
|
||||
return reply
|
||||
|
||||
def reply_text(self, query, session_id, retry_count=0):
|
||||
try:
|
||||
response = openai.Completion.create(
|
||||
model= conf().get("model") or "text-davinci-003", # 对话模型的名称
|
||||
prompt=query,
|
||||
temperature=0.9, # 值在[0,1]之间,越大表示回复越具有不确定性
|
||||
max_tokens=1200, # 回复最大的字符数
|
||||
top_p=1,
|
||||
frequency_penalty=0.0, # [-2,2]之间,该值越大则更倾向于产生不同的内容
|
||||
presence_penalty=0.0, # [-2,2]之间,该值越大则更倾向于产生不同的内容
|
||||
stop=["\n\n\n"]
|
||||
)
|
||||
res_content = response.choices[0]['text'].strip().replace('<|endoftext|>', '')
|
||||
total_tokens = response["usage"]["total_tokens"]
|
||||
completion_tokens = response["usage"]["completion_tokens"]
|
||||
logger.info("[OPEN_AI] reply={}".format(res_content))
|
||||
return total_tokens, completion_tokens, res_content
|
||||
except Exception as e:
|
||||
need_retry = retry_count < 2
|
||||
result = [0,0,"我现在有点累了,等会再来吧"]
|
||||
if isinstance(e, openai.error.RateLimitError):
|
||||
logger.warn("[OPEN_AI] RateLimitError: {}".format(e))
|
||||
result[2] = "提问太快啦,请休息一下再问我吧"
|
||||
if need_retry:
|
||||
time.sleep(5)
|
||||
elif isinstance(e, openai.error.Timeout):
|
||||
logger.warn("[OPEN_AI] Timeout: {}".format(e))
|
||||
result[2] = "我没有收到你的消息"
|
||||
if need_retry:
|
||||
time.sleep(5)
|
||||
elif isinstance(e, openai.error.APIConnectionError):
|
||||
logger.warn("[OPEN_AI] APIConnectionError: {}".format(e))
|
||||
need_retry = False
|
||||
result[2] = "我连接不到你的网络"
|
||||
else:
|
||||
logger.warn("[OPEN_AI] Exception: {}".format(e))
|
||||
need_retry = False
|
||||
self.sessions.clear_session(session_id)
|
||||
|
||||
if need_retry:
|
||||
logger.warn("[OPEN_AI] 第{}次重试".format(retry_count+1))
|
||||
return self.reply_text(query, session_id, retry_count+1)
|
||||
else:
|
||||
return result
|
||||
@@ -1,38 +0,0 @@
|
||||
import time
|
||||
import openai
|
||||
import openai.error
|
||||
from common.token_bucket import TokenBucket
|
||||
from common.log import logger
|
||||
from config import conf
|
||||
|
||||
# OPENAI提供的画图接口
|
||||
class OpenAIImage(object):
|
||||
def __init__(self):
|
||||
openai.api_key = conf().get('open_ai_api_key')
|
||||
if conf().get('rate_limit_dalle'):
|
||||
self.tb4dalle = TokenBucket(conf().get('rate_limit_dalle', 50))
|
||||
|
||||
def create_img(self, query, retry_count=0):
|
||||
try:
|
||||
if conf().get('rate_limit_dalle') and not self.tb4dalle.get_token():
|
||||
return False, "请求太快了,请休息一下再问我吧"
|
||||
logger.info("[OPEN_AI] image_query={}".format(query))
|
||||
response = openai.Image.create(
|
||||
prompt=query, #图片描述
|
||||
n=1, #每次生成图片的数量
|
||||
size="256x256" #图片大小,可选有 256x256, 512x512, 1024x1024
|
||||
)
|
||||
image_url = response['data'][0]['url']
|
||||
logger.info("[OPEN_AI] image_url={}".format(image_url))
|
||||
return True, image_url
|
||||
except openai.error.RateLimitError as e:
|
||||
logger.warn(e)
|
||||
if retry_count < 1:
|
||||
time.sleep(5)
|
||||
logger.warn("[OPEN_AI] ImgCreate RateLimit exceed, 第{}次重试".format(retry_count+1))
|
||||
return self.create_img(query, retry_count+1)
|
||||
else:
|
||||
return False, "提问太快啦,请休息一下再问我吧"
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
return False, str(e)
|
||||
674
bridge/agent_bridge.py
Normal file
674
bridge/agent_bridge.py
Normal file
@@ -0,0 +1,674 @@
|
||||
"""
|
||||
Agent Bridge - Integrates Agent system with existing COW bridge
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Optional, List
|
||||
|
||||
from agent.protocol import Agent, LLMModel, LLMRequest
|
||||
from bridge.agent_event_handler import AgentEventHandler
|
||||
from bridge.agent_initializer import AgentInitializer
|
||||
from bridge.bridge import Bridge
|
||||
from bridge.context import Context
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from common import const
|
||||
from common.log import logger
|
||||
from common.utils import expand_path
|
||||
from models.openai_compatible_bot import OpenAICompatibleBot
|
||||
|
||||
|
||||
def add_openai_compatible_support(bot_instance):
|
||||
"""
|
||||
Dynamically add OpenAI-compatible tool calling support to a bot instance.
|
||||
|
||||
This allows any bot to gain tool calling capability without modifying its code,
|
||||
as long as it uses OpenAI-compatible API format.
|
||||
|
||||
Note: Some bots like ZHIPUAIBot have native tool calling support and don't need enhancement.
|
||||
"""
|
||||
if hasattr(bot_instance, 'call_with_tools'):
|
||||
# Bot already has tool calling support (e.g., ZHIPUAIBot)
|
||||
logger.debug(f"[AgentBridge] {type(bot_instance).__name__} already has native tool calling support")
|
||||
return bot_instance
|
||||
|
||||
# Create a temporary mixin class that combines the bot with OpenAI compatibility
|
||||
class EnhancedBot(bot_instance.__class__, OpenAICompatibleBot):
|
||||
"""Dynamically enhanced bot with OpenAI-compatible tool calling"""
|
||||
|
||||
def get_api_config(self):
|
||||
"""
|
||||
Infer API config from common configuration patterns.
|
||||
Most OpenAI-compatible bots use similar configuration.
|
||||
"""
|
||||
from config import conf
|
||||
|
||||
return {
|
||||
'api_key': conf().get("open_ai_api_key"),
|
||||
'api_base': conf().get("open_ai_api_base"),
|
||||
'model': conf().get("model", "gpt-3.5-turbo"),
|
||||
'default_temperature': conf().get("temperature", 0.9),
|
||||
'default_top_p': conf().get("top_p", 1.0),
|
||||
'default_frequency_penalty': conf().get("frequency_penalty", 0.0),
|
||||
'default_presence_penalty': conf().get("presence_penalty", 0.0),
|
||||
}
|
||||
|
||||
# Change the bot's class to the enhanced version
|
||||
bot_instance.__class__ = EnhancedBot
|
||||
logger.info(
|
||||
f"[AgentBridge] Enhanced {bot_instance.__class__.__bases__[0].__name__} with OpenAI-compatible tool calling")
|
||||
|
||||
return bot_instance
|
||||
|
||||
|
||||
class AgentLLMModel(LLMModel):
|
||||
"""
|
||||
LLM Model adapter that uses COW's existing bot infrastructure
|
||||
"""
|
||||
|
||||
_MODEL_BOT_TYPE_MAP = {
|
||||
"wenxin": const.BAIDU, "wenxin-4": const.BAIDU,
|
||||
"xunfei": const.XUNFEI, const.QWEN: const.QWEN,
|
||||
const.MODELSCOPE: const.MODELSCOPE,
|
||||
}
|
||||
_MODEL_PREFIX_MAP = [
|
||||
("qwen", const.QWEN_DASHSCOPE), ("qwq", const.QWEN_DASHSCOPE), ("qvq", const.QWEN_DASHSCOPE),
|
||||
("gemini", const.GEMINI), ("glm", const.ZHIPU_AI), ("claude", const.CLAUDEAPI),
|
||||
("moonshot", const.MOONSHOT), ("kimi", const.MOONSHOT),
|
||||
("doubao", const.DOUBAO),
|
||||
]
|
||||
|
||||
def __init__(self, bridge: Bridge, bot_type: str = "chat"):
|
||||
from config import conf
|
||||
super().__init__(model=conf().get("model", const.GPT_41))
|
||||
self.bridge = bridge
|
||||
self.bot_type = bot_type
|
||||
self._bot = None
|
||||
self._bot_model = None
|
||||
|
||||
@property
|
||||
def model(self):
|
||||
from config import conf
|
||||
return conf().get("model", const.GPT_41)
|
||||
|
||||
@model.setter
|
||||
def model(self, value):
|
||||
pass
|
||||
|
||||
def _resolve_bot_type(self, model_name: str) -> str:
|
||||
"""Resolve bot type from model name, matching Bridge.__init__ logic."""
|
||||
from config import conf
|
||||
|
||||
if conf().get("use_linkai", False) and conf().get("linkai_api_key"):
|
||||
return const.LINKAI
|
||||
# Support custom bot type configuration
|
||||
configured_bot_type = conf().get("bot_type")
|
||||
if configured_bot_type:
|
||||
return configured_bot_type
|
||||
|
||||
if not model_name or not isinstance(model_name, str):
|
||||
return const.CHATGPT
|
||||
if model_name in self._MODEL_BOT_TYPE_MAP:
|
||||
return self._MODEL_BOT_TYPE_MAP[model_name]
|
||||
if model_name.lower().startswith("minimax") or model_name in ["abab6.5-chat"]:
|
||||
return const.MiniMax
|
||||
if model_name in [const.QWEN_TURBO, const.QWEN_PLUS, const.QWEN_MAX]:
|
||||
return const.QWEN_DASHSCOPE
|
||||
if model_name in [const.MOONSHOT, "moonshot-v1-8k", "moonshot-v1-32k", "moonshot-v1-128k"]:
|
||||
return const.MOONSHOT
|
||||
if model_name in [const.DEEPSEEK_CHAT, const.DEEPSEEK_REASONER]:
|
||||
return const.CHATGPT
|
||||
for prefix, btype in self._MODEL_PREFIX_MAP:
|
||||
if model_name.startswith(prefix):
|
||||
return btype
|
||||
return const.CHATGPT
|
||||
|
||||
@property
|
||||
def bot(self):
|
||||
"""Lazy load the bot, re-create when model changes"""
|
||||
from models.bot_factory import create_bot
|
||||
cur_model = self.model
|
||||
if self._bot is None or self._bot_model != cur_model:
|
||||
bot_type = self._resolve_bot_type(cur_model)
|
||||
self._bot = create_bot(bot_type)
|
||||
self._bot = add_openai_compatible_support(self._bot)
|
||||
self._bot_model = cur_model
|
||||
return self._bot
|
||||
|
||||
def call(self, request: LLMRequest):
|
||||
"""
|
||||
Call the model using COW's bot infrastructure
|
||||
"""
|
||||
try:
|
||||
# For non-streaming calls, we'll use the existing reply method
|
||||
# This is a simplified implementation
|
||||
if hasattr(self.bot, 'call_with_tools'):
|
||||
# Use tool-enabled call if available
|
||||
kwargs = {
|
||||
'messages': request.messages,
|
||||
'tools': getattr(request, 'tools', None),
|
||||
'stream': False,
|
||||
'model': self.model # Pass model parameter
|
||||
}
|
||||
# Only pass max_tokens if it's explicitly set
|
||||
if request.max_tokens is not None:
|
||||
kwargs['max_tokens'] = request.max_tokens
|
||||
|
||||
# Extract system prompt if present
|
||||
system_prompt = getattr(request, 'system', None)
|
||||
if system_prompt:
|
||||
kwargs['system'] = system_prompt
|
||||
|
||||
response = self.bot.call_with_tools(**kwargs)
|
||||
return self._format_response(response)
|
||||
else:
|
||||
# Fallback to regular call
|
||||
# This would need to be implemented based on your specific needs
|
||||
raise NotImplementedError("Regular call not implemented yet")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"AgentLLMModel call error: {e}")
|
||||
raise
|
||||
|
||||
def call_stream(self, request: LLMRequest):
|
||||
"""
|
||||
Call the model with streaming using COW's bot infrastructure
|
||||
"""
|
||||
try:
|
||||
if hasattr(self.bot, 'call_with_tools'):
|
||||
# Use tool-enabled streaming call if available
|
||||
# Extract system prompt if present
|
||||
system_prompt = getattr(request, 'system', None)
|
||||
|
||||
# Build kwargs for call_with_tools
|
||||
kwargs = {
|
||||
'messages': request.messages,
|
||||
'tools': getattr(request, 'tools', None),
|
||||
'stream': True,
|
||||
'model': self.model # Pass model parameter
|
||||
}
|
||||
|
||||
# Only pass max_tokens if explicitly set, let the bot use its default
|
||||
if request.max_tokens is not None:
|
||||
kwargs['max_tokens'] = request.max_tokens
|
||||
|
||||
# Add system prompt if present
|
||||
if system_prompt:
|
||||
kwargs['system'] = system_prompt
|
||||
|
||||
# Pass channel_type for linkai tracking
|
||||
channel_type = getattr(self, 'channel_type', None)
|
||||
if channel_type:
|
||||
kwargs['channel_type'] = channel_type
|
||||
|
||||
stream = self.bot.call_with_tools(**kwargs)
|
||||
|
||||
# Convert stream format to our expected format
|
||||
for chunk in stream:
|
||||
yield self._format_stream_chunk(chunk)
|
||||
else:
|
||||
bot_type = type(self.bot).__name__
|
||||
raise NotImplementedError(f"Bot {bot_type} does not support call_with_tools. Please add the method.")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"AgentLLMModel call_stream error: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
def _format_response(self, response):
|
||||
"""Format Claude response to our expected format"""
|
||||
# This would need to be implemented based on Claude's response format
|
||||
return response
|
||||
|
||||
def _format_stream_chunk(self, chunk):
|
||||
"""Format Claude stream chunk to our expected format"""
|
||||
# This would need to be implemented based on Claude's stream format
|
||||
return chunk
|
||||
|
||||
|
||||
class AgentBridge:
|
||||
"""
|
||||
Bridge class that integrates super Agent with COW
|
||||
Manages multiple agent instances per session for conversation isolation
|
||||
"""
|
||||
|
||||
def __init__(self, bridge: Bridge):
|
||||
self.bridge = bridge
|
||||
self.agents = {} # session_id -> Agent instance mapping
|
||||
self.default_agent = None # For backward compatibility (no session_id)
|
||||
self.agent: Optional[Agent] = None
|
||||
self.scheduler_initialized = False
|
||||
|
||||
# Create helper instances
|
||||
self.initializer = AgentInitializer(bridge, self)
|
||||
def create_agent(self, system_prompt: str, tools: List = None, **kwargs) -> Agent:
|
||||
"""
|
||||
Create the super agent with COW integration
|
||||
|
||||
Args:
|
||||
system_prompt: System prompt
|
||||
tools: List of tools (optional)
|
||||
**kwargs: Additional agent parameters
|
||||
|
||||
Returns:
|
||||
Agent instance
|
||||
"""
|
||||
# Create LLM model that uses COW's bot infrastructure
|
||||
model = AgentLLMModel(self.bridge)
|
||||
|
||||
# Default tools if none provided
|
||||
if tools is None:
|
||||
# Use ToolManager to load all available tools
|
||||
from agent.tools import ToolManager
|
||||
tool_manager = ToolManager()
|
||||
tool_manager.load_tools()
|
||||
|
||||
tools = []
|
||||
for tool_name in tool_manager.tool_classes.keys():
|
||||
try:
|
||||
tool = tool_manager.create_tool(tool_name)
|
||||
if tool:
|
||||
tools.append(tool)
|
||||
except Exception as e:
|
||||
logger.warning(f"[AgentBridge] Failed to load tool {tool_name}: {e}")
|
||||
|
||||
# Create agent instance
|
||||
agent = Agent(
|
||||
system_prompt=system_prompt,
|
||||
description=kwargs.get("description", "AI Super Agent"),
|
||||
model=model,
|
||||
tools=tools,
|
||||
max_steps=kwargs.get("max_steps", 15),
|
||||
output_mode=kwargs.get("output_mode", "logger"),
|
||||
workspace_dir=kwargs.get("workspace_dir"), # Pass workspace for skills loading
|
||||
enable_skills=kwargs.get("enable_skills", True), # Enable skills by default
|
||||
memory_manager=kwargs.get("memory_manager"), # Pass memory manager
|
||||
max_context_tokens=kwargs.get("max_context_tokens"),
|
||||
context_reserve_tokens=kwargs.get("context_reserve_tokens"),
|
||||
runtime_info=kwargs.get("runtime_info") # Pass runtime_info for dynamic time updates
|
||||
)
|
||||
|
||||
# Log skill loading details
|
||||
if agent.skill_manager:
|
||||
logger.debug(f"[AgentBridge] SkillManager initialized with {len(agent.skill_manager.skills)} skills")
|
||||
|
||||
return agent
|
||||
|
||||
def get_agent(self, session_id: str = None) -> Optional[Agent]:
|
||||
"""
|
||||
Get agent instance for the given session
|
||||
|
||||
Args:
|
||||
session_id: Session identifier (e.g., user_id). If None, returns default agent.
|
||||
|
||||
Returns:
|
||||
Agent instance for this session
|
||||
"""
|
||||
# If no session_id, use default agent (backward compatibility)
|
||||
if session_id is None:
|
||||
if self.default_agent is None:
|
||||
self._init_default_agent()
|
||||
return self.default_agent
|
||||
|
||||
# Check if agent exists for this session
|
||||
if session_id not in self.agents:
|
||||
self._init_agent_for_session(session_id)
|
||||
|
||||
return self.agents[session_id]
|
||||
|
||||
def _init_default_agent(self):
|
||||
"""Initialize default super agent"""
|
||||
agent = self.initializer.initialize_agent(session_id=None)
|
||||
self.default_agent = agent
|
||||
|
||||
def _init_agent_for_session(self, session_id: str):
|
||||
"""Initialize agent for a specific session"""
|
||||
agent = self.initializer.initialize_agent(session_id=session_id)
|
||||
self.agents[session_id] = agent
|
||||
|
||||
def agent_reply(self, query: str, context: Context = None,
|
||||
on_event=None, clear_history: bool = False) -> Reply:
|
||||
"""
|
||||
Use super agent to reply to a query
|
||||
|
||||
Args:
|
||||
query: User query
|
||||
context: COW context (optional, contains session_id for user isolation)
|
||||
on_event: Event callback (optional)
|
||||
clear_history: Whether to clear conversation history
|
||||
|
||||
Returns:
|
||||
Reply object
|
||||
"""
|
||||
session_id = None
|
||||
agent = None
|
||||
try:
|
||||
# Extract session_id from context for user isolation
|
||||
if context:
|
||||
session_id = context.kwargs.get("session_id") or context.get("session_id")
|
||||
|
||||
# Get agent for this session (will auto-initialize if needed)
|
||||
agent = self.get_agent(session_id=session_id)
|
||||
if not agent:
|
||||
return Reply(ReplyType.ERROR, "Failed to initialize super agent")
|
||||
|
||||
# Create event handler for logging and channel communication
|
||||
event_handler = AgentEventHandler(context=context, original_callback=on_event)
|
||||
|
||||
# Filter tools based on context
|
||||
original_tools = agent.tools
|
||||
filtered_tools = original_tools
|
||||
|
||||
# If this is a scheduled task execution, exclude scheduler tool to prevent recursion
|
||||
if context and context.get("is_scheduled_task"):
|
||||
filtered_tools = [tool for tool in agent.tools if tool.name != "scheduler"]
|
||||
agent.tools = filtered_tools
|
||||
logger.info(f"[AgentBridge] Scheduled task execution: excluded scheduler tool ({len(filtered_tools)}/{len(original_tools)} tools)")
|
||||
else:
|
||||
# Attach context to scheduler tool if present
|
||||
if context and agent.tools:
|
||||
for tool in agent.tools:
|
||||
if tool.name == "scheduler":
|
||||
try:
|
||||
from agent.tools.scheduler.integration import attach_scheduler_to_tool
|
||||
attach_scheduler_to_tool(tool, context)
|
||||
except Exception as e:
|
||||
logger.warning(f"[AgentBridge] Failed to attach context to scheduler: {e}")
|
||||
break
|
||||
|
||||
# Pass channel_type to model so linkai requests carry it
|
||||
if context and hasattr(agent, 'model'):
|
||||
agent.model.channel_type = context.get("channel_type", "")
|
||||
|
||||
# Store session_id on agent so executor can clear DB on fatal errors
|
||||
agent._current_session_id = session_id
|
||||
|
||||
try:
|
||||
# Use agent's run_stream method with event handler
|
||||
response = agent.run_stream(
|
||||
user_message=query,
|
||||
on_event=event_handler.handle_event,
|
||||
clear_history=clear_history
|
||||
)
|
||||
finally:
|
||||
# Restore original tools
|
||||
if context and context.get("is_scheduled_task"):
|
||||
agent.tools = original_tools
|
||||
|
||||
# Log execution summary
|
||||
event_handler.log_summary()
|
||||
|
||||
# Persist new messages generated during this run
|
||||
if session_id:
|
||||
channel_type = (context.get("channel_type") or "") if context else ""
|
||||
new_messages = getattr(agent, '_last_run_new_messages', [])
|
||||
if new_messages:
|
||||
self._persist_messages(session_id, list(new_messages), channel_type)
|
||||
else:
|
||||
with agent.messages_lock:
|
||||
msg_count = len(agent.messages)
|
||||
if msg_count == 0:
|
||||
try:
|
||||
from agent.memory import get_conversation_store
|
||||
get_conversation_store().clear_session(session_id)
|
||||
logger.info(f"[AgentBridge] Cleared DB for recovered session: {session_id}")
|
||||
except Exception as e:
|
||||
logger.warning(f"[AgentBridge] Failed to clear DB after recovery: {e}")
|
||||
|
||||
# Check if there are files to send (from read tool)
|
||||
if hasattr(agent, 'stream_executor') and hasattr(agent.stream_executor, 'files_to_send'):
|
||||
files_to_send = agent.stream_executor.files_to_send
|
||||
if files_to_send:
|
||||
# Send the first file (for now, handle one file at a time)
|
||||
file_info = files_to_send[0]
|
||||
logger.info(f"[AgentBridge] Sending file: {file_info.get('path')}")
|
||||
|
||||
# Clear files_to_send for next request
|
||||
agent.stream_executor.files_to_send = []
|
||||
|
||||
# Return file reply based on file type
|
||||
return self._create_file_reply(file_info, response, context)
|
||||
|
||||
return Reply(ReplyType.TEXT, response)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Agent reply error: {e}")
|
||||
# If the agent cleared its messages due to format error / overflow,
|
||||
# also purge the DB so the next request starts clean.
|
||||
if session_id and agent:
|
||||
try:
|
||||
with agent.messages_lock:
|
||||
msg_count = len(agent.messages)
|
||||
if msg_count == 0:
|
||||
from agent.memory import get_conversation_store
|
||||
get_conversation_store().clear_session(session_id)
|
||||
logger.info(f"[AgentBridge] Cleared DB for session after error: {session_id}")
|
||||
except Exception as db_err:
|
||||
logger.warning(f"[AgentBridge] Failed to clear DB after error: {db_err}")
|
||||
return Reply(ReplyType.ERROR, f"Agent error: {str(e)}")
|
||||
|
||||
def _create_file_reply(self, file_info: dict, text_response: str, context: Context = None) -> Reply:
|
||||
"""
|
||||
Create a reply for sending files
|
||||
|
||||
Args:
|
||||
file_info: File metadata from read tool
|
||||
text_response: Text response from agent
|
||||
context: Context object
|
||||
|
||||
Returns:
|
||||
Reply object for file sending
|
||||
"""
|
||||
file_type = file_info.get("file_type", "file")
|
||||
file_path = file_info.get("path")
|
||||
|
||||
# For images, use IMAGE_URL type (channel will handle upload)
|
||||
if file_type == "image":
|
||||
# Convert local path to file:// URL for channel processing
|
||||
file_url = f"file://{file_path}"
|
||||
logger.info(f"[AgentBridge] Sending image: {file_url}")
|
||||
reply = Reply(ReplyType.IMAGE_URL, file_url)
|
||||
# Attach text message if present (for channels that support text+image)
|
||||
if text_response:
|
||||
reply.text_content = text_response # Store accompanying text
|
||||
return reply
|
||||
|
||||
# For all file types (document, video, audio), use FILE type
|
||||
if file_type in ["document", "video", "audio"]:
|
||||
file_url = f"file://{file_path}"
|
||||
logger.info(f"[AgentBridge] Sending {file_type}: {file_url}")
|
||||
reply = Reply(ReplyType.FILE, file_url)
|
||||
reply.file_name = file_info.get("file_name", os.path.basename(file_path))
|
||||
# Attach text message if present
|
||||
if text_response:
|
||||
reply.text_content = text_response
|
||||
return reply
|
||||
|
||||
# For other unknown file types, return text with file info
|
||||
message = text_response or file_info.get("message", "文件已准备")
|
||||
message += f"\n\n[文件: {file_info.get('file_name', file_path)}]"
|
||||
return Reply(ReplyType.TEXT, message)
|
||||
|
||||
def _migrate_config_to_env(self, workspace_root: str):
|
||||
"""
|
||||
Migrate API keys from config.json to .env file if not already set
|
||||
|
||||
Args:
|
||||
workspace_root: Workspace directory path (not used, kept for compatibility)
|
||||
"""
|
||||
from config import conf
|
||||
import os
|
||||
|
||||
# Mapping from config.json keys to environment variable names
|
||||
key_mapping = {
|
||||
"open_ai_api_key": "OPENAI_API_KEY",
|
||||
"open_ai_api_base": "OPENAI_API_BASE",
|
||||
"gemini_api_key": "GEMINI_API_KEY",
|
||||
"claude_api_key": "CLAUDE_API_KEY",
|
||||
"linkai_api_key": "LINKAI_API_KEY",
|
||||
}
|
||||
|
||||
# Use fixed secure location for .env file
|
||||
env_file = expand_path("~/.cow/.env")
|
||||
|
||||
# Read existing env vars from .env file
|
||||
existing_env_vars = {}
|
||||
if os.path.exists(env_file):
|
||||
try:
|
||||
with open(env_file, 'r', encoding='utf-8') as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line and not line.startswith('#') and '=' in line:
|
||||
key, _ = line.split('=', 1)
|
||||
existing_env_vars[key.strip()] = True
|
||||
except Exception as e:
|
||||
logger.warning(f"[AgentBridge] Failed to read .env file: {e}")
|
||||
|
||||
# Check which keys need to be migrated
|
||||
keys_to_migrate = {}
|
||||
for config_key, env_key in key_mapping.items():
|
||||
# Skip if already in .env file
|
||||
if env_key in existing_env_vars:
|
||||
continue
|
||||
|
||||
# Get value from config.json
|
||||
value = conf().get(config_key, "")
|
||||
if value and value.strip(): # Only migrate non-empty values
|
||||
keys_to_migrate[env_key] = value.strip()
|
||||
|
||||
# Log summary if there are keys to skip
|
||||
if existing_env_vars:
|
||||
logger.debug(f"[AgentBridge] {len(existing_env_vars)} env vars already in .env")
|
||||
|
||||
# Write new keys to .env file
|
||||
if keys_to_migrate:
|
||||
try:
|
||||
# Ensure ~/.cow directory and .env file exist
|
||||
env_dir = os.path.dirname(env_file)
|
||||
if not os.path.exists(env_dir):
|
||||
os.makedirs(env_dir, exist_ok=True)
|
||||
if not os.path.exists(env_file):
|
||||
open(env_file, 'a').close()
|
||||
|
||||
# Append new keys
|
||||
with open(env_file, 'a', encoding='utf-8') as f:
|
||||
f.write('\n# Auto-migrated from config.json\n')
|
||||
for key, value in keys_to_migrate.items():
|
||||
f.write(f'{key}={value}\n')
|
||||
# Also set in current process
|
||||
os.environ[key] = value
|
||||
|
||||
logger.info(f"[AgentBridge] Migrated {len(keys_to_migrate)} API keys from config.json to .env: {list(keys_to_migrate.keys())}")
|
||||
except Exception as e:
|
||||
logger.warning(f"[AgentBridge] Failed to migrate API keys: {e}")
|
||||
|
||||
def _persist_messages(
|
||||
self, session_id: str, new_messages: list, channel_type: str = ""
|
||||
) -> None:
|
||||
"""
|
||||
Persist new messages to the conversation store after each agent run.
|
||||
|
||||
Failures are logged but never propagate — they must not interrupt replies.
|
||||
"""
|
||||
if not new_messages:
|
||||
return
|
||||
try:
|
||||
from config import conf
|
||||
if not conf().get("conversation_persistence", True):
|
||||
return
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
from agent.memory import get_conversation_store
|
||||
get_conversation_store().append_messages(
|
||||
session_id, new_messages, channel_type=channel_type
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"[AgentBridge] Failed to persist messages for session={session_id}: {e}"
|
||||
)
|
||||
|
||||
def clear_session(self, session_id: str):
|
||||
"""
|
||||
Clear a specific session's agent and conversation history
|
||||
|
||||
Args:
|
||||
session_id: Session identifier to clear
|
||||
"""
|
||||
if session_id in self.agents:
|
||||
logger.info(f"[AgentBridge] Clearing session: {session_id}")
|
||||
del self.agents[session_id]
|
||||
|
||||
def clear_all_sessions(self):
|
||||
"""Clear all agent sessions"""
|
||||
logger.info(f"[AgentBridge] Clearing all sessions ({len(self.agents)} total)")
|
||||
self.agents.clear()
|
||||
self.default_agent = None
|
||||
|
||||
def refresh_all_skills(self) -> int:
|
||||
"""
|
||||
Refresh skills and conditional tools in all agent instances after
|
||||
environment variable changes. This allows hot-reload without restarting.
|
||||
|
||||
Returns:
|
||||
Number of agent instances refreshed
|
||||
"""
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
from config import conf
|
||||
|
||||
# Reload environment variables from .env file
|
||||
workspace_root = expand_path(conf().get("agent_workspace", "~/cow"))
|
||||
env_file = os.path.join(workspace_root, '.env')
|
||||
|
||||
if os.path.exists(env_file):
|
||||
load_dotenv(env_file, override=True)
|
||||
logger.info(f"[AgentBridge] Reloaded environment variables from {env_file}")
|
||||
|
||||
refreshed_count = 0
|
||||
|
||||
# Collect all agent instances to refresh
|
||||
agents_to_refresh = []
|
||||
if self.default_agent:
|
||||
agents_to_refresh.append(("default", self.default_agent))
|
||||
for session_id, agent in self.agents.items():
|
||||
agents_to_refresh.append((session_id, agent))
|
||||
|
||||
for label, agent in agents_to_refresh:
|
||||
# Refresh skills
|
||||
if hasattr(agent, 'skill_manager') and agent.skill_manager:
|
||||
agent.skill_manager.refresh_skills()
|
||||
|
||||
# Refresh conditional tools (e.g. web_search depends on API keys)
|
||||
self._refresh_conditional_tools(agent)
|
||||
|
||||
refreshed_count += 1
|
||||
|
||||
if refreshed_count > 0:
|
||||
logger.info(f"[AgentBridge] Refreshed skills & tools in {refreshed_count} agent instance(s)")
|
||||
|
||||
return refreshed_count
|
||||
|
||||
@staticmethod
|
||||
def _refresh_conditional_tools(agent):
|
||||
"""
|
||||
Add or remove conditional tools based on current environment variables.
|
||||
For example, web_search should only be present when BOCHA_API_KEY or
|
||||
LINKAI_API_KEY is set.
|
||||
"""
|
||||
try:
|
||||
from agent.tools.web_search.web_search import WebSearch
|
||||
|
||||
has_tool = any(t.name == "web_search" for t in agent.tools)
|
||||
available = WebSearch.is_available()
|
||||
|
||||
if available and not has_tool:
|
||||
# API key was added - inject the tool
|
||||
tool = WebSearch()
|
||||
tool.model = agent.model
|
||||
agent.tools.append(tool)
|
||||
logger.info("[AgentBridge] web_search tool added (API key now available)")
|
||||
elif not available and has_tool:
|
||||
# API key was removed - remove the tool
|
||||
agent.tools = [t for t in agent.tools if t.name != "web_search"]
|
||||
logger.info("[AgentBridge] web_search tool removed (API key no longer available)")
|
||||
except Exception as e:
|
||||
logger.debug(f"[AgentBridge] Failed to refresh conditional tools: {e}")
|
||||
115
bridge/agent_event_handler.py
Normal file
115
bridge/agent_event_handler.py
Normal file
@@ -0,0 +1,115 @@
|
||||
"""
|
||||
Agent Event Handler - Handles agent events and thinking process output
|
||||
"""
|
||||
|
||||
from common.log import logger
|
||||
|
||||
|
||||
class AgentEventHandler:
|
||||
"""
|
||||
Handles agent events and optionally sends intermediate messages to channel
|
||||
"""
|
||||
|
||||
def __init__(self, context=None, original_callback=None):
|
||||
"""
|
||||
Initialize event handler
|
||||
|
||||
Args:
|
||||
context: COW context (for accessing channel)
|
||||
original_callback: Original event callback to chain
|
||||
"""
|
||||
self.context = context
|
||||
self.original_callback = original_callback
|
||||
|
||||
# Get channel for sending intermediate messages
|
||||
self.channel = None
|
||||
if context:
|
||||
self.channel = context.kwargs.get("channel") if hasattr(context, "kwargs") else None
|
||||
|
||||
# Track current thinking for channel output
|
||||
self.current_thinking = ""
|
||||
self.turn_number = 0
|
||||
|
||||
def handle_event(self, event):
|
||||
"""
|
||||
Main event handler
|
||||
|
||||
Args:
|
||||
event: Event dict with type and data
|
||||
"""
|
||||
event_type = event.get("type")
|
||||
data = event.get("data", {})
|
||||
|
||||
# Dispatch to specific handlers
|
||||
if event_type == "turn_start":
|
||||
self._handle_turn_start(data)
|
||||
elif event_type == "message_update":
|
||||
self._handle_message_update(data)
|
||||
elif event_type == "message_end":
|
||||
self._handle_message_end(data)
|
||||
elif event_type == "tool_execution_start":
|
||||
self._handle_tool_execution_start(data)
|
||||
elif event_type == "tool_execution_end":
|
||||
self._handle_tool_execution_end(data)
|
||||
|
||||
# Call original callback if provided
|
||||
if self.original_callback:
|
||||
self.original_callback(event)
|
||||
|
||||
def _handle_turn_start(self, data):
|
||||
"""Handle turn start event"""
|
||||
self.turn_number = data.get("turn", 0)
|
||||
self.has_tool_calls_in_turn = False
|
||||
self.current_thinking = ""
|
||||
|
||||
def _handle_message_update(self, data):
|
||||
"""Handle message update event (streaming text)"""
|
||||
delta = data.get("delta", "")
|
||||
self.current_thinking += delta
|
||||
|
||||
def _handle_message_end(self, data):
|
||||
"""Handle message end event"""
|
||||
tool_calls = data.get("tool_calls", [])
|
||||
|
||||
# Only send thinking process if followed by tool calls
|
||||
if tool_calls:
|
||||
if self.current_thinking.strip():
|
||||
logger.info(f"💭 {self.current_thinking.strip()[:200]}{'...' if len(self.current_thinking) > 200 else ''}")
|
||||
# Send thinking process to channel
|
||||
self._send_to_channel(f"{self.current_thinking.strip()}")
|
||||
else:
|
||||
# No tool calls = final response (logged at agent_stream level)
|
||||
if self.current_thinking.strip():
|
||||
logger.debug(f"💬 {self.current_thinking.strip()[:200]}{'...' if len(self.current_thinking) > 200 else ''}")
|
||||
|
||||
self.current_thinking = ""
|
||||
|
||||
def _handle_tool_execution_start(self, data):
|
||||
"""Handle tool execution start event - logged by agent_stream.py"""
|
||||
pass
|
||||
|
||||
def _handle_tool_execution_end(self, data):
|
||||
"""Handle tool execution end event - logged by agent_stream.py"""
|
||||
pass
|
||||
|
||||
def _send_to_channel(self, message):
|
||||
"""
|
||||
Try to send intermediate message to channel.
|
||||
Skipped in SSE mode because thinking text is already streamed via on_event.
|
||||
"""
|
||||
if self.context and self.context.get("on_event"):
|
||||
return
|
||||
|
||||
if self.channel:
|
||||
try:
|
||||
from bridge.reply import Reply, ReplyType
|
||||
reply = Reply(ReplyType.TEXT, message)
|
||||
self.channel._send(reply, self.context)
|
||||
except Exception as e:
|
||||
logger.debug(f"[AgentEventHandler] Failed to send to channel: {e}")
|
||||
|
||||
def log_summary(self):
|
||||
"""Log execution summary - simplified"""
|
||||
# Summary removed as per user request
|
||||
# Real-time logging during execution is sufficient
|
||||
pass
|
||||
584
bridge/agent_initializer.py
Normal file
584
bridge/agent_initializer.py
Normal file
@@ -0,0 +1,584 @@
|
||||
"""
|
||||
Agent Initializer - Handles agent initialization logic
|
||||
"""
|
||||
|
||||
import os
|
||||
import asyncio
|
||||
import datetime
|
||||
import time
|
||||
from typing import Optional, List
|
||||
|
||||
from agent.protocol import Agent
|
||||
from agent.tools import ToolManager
|
||||
from common.log import logger
|
||||
from common.utils import expand_path
|
||||
|
||||
|
||||
class AgentInitializer:
|
||||
"""
|
||||
Handles agent initialization including:
|
||||
- Workspace setup
|
||||
- Memory system initialization
|
||||
- Tool loading
|
||||
- System prompt building
|
||||
"""
|
||||
|
||||
def __init__(self, bridge, agent_bridge):
|
||||
"""
|
||||
Initialize agent initializer
|
||||
|
||||
Args:
|
||||
bridge: COW bridge instance
|
||||
agent_bridge: AgentBridge instance (for create_agent method)
|
||||
"""
|
||||
self.bridge = bridge
|
||||
self.agent_bridge = agent_bridge
|
||||
|
||||
def initialize_agent(self, session_id: Optional[str] = None) -> Agent:
|
||||
"""
|
||||
Initialize agent for a session
|
||||
|
||||
Args:
|
||||
session_id: Session ID (None for default agent)
|
||||
|
||||
Returns:
|
||||
Initialized agent instance
|
||||
"""
|
||||
from config import conf
|
||||
|
||||
# Get workspace from config
|
||||
workspace_root = expand_path(conf().get("agent_workspace", "~/cow"))
|
||||
|
||||
# Migrate API keys
|
||||
self._migrate_config_to_env(workspace_root)
|
||||
|
||||
# Load environment variables
|
||||
self._load_env_file()
|
||||
|
||||
# Initialize workspace
|
||||
from agent.prompt import ensure_workspace, load_context_files, PromptBuilder
|
||||
workspace_files = ensure_workspace(workspace_root, create_templates=True)
|
||||
|
||||
if session_id is None:
|
||||
logger.info(f"[AgentInitializer] Workspace initialized at: {workspace_root}")
|
||||
|
||||
# Setup memory system
|
||||
memory_manager, memory_tools = self._setup_memory_system(workspace_root, session_id)
|
||||
|
||||
# Load tools
|
||||
tools = self._load_tools(workspace_root, memory_manager, memory_tools, session_id)
|
||||
|
||||
# Initialize scheduler if needed
|
||||
self._initialize_scheduler(tools, session_id)
|
||||
|
||||
# Load context files
|
||||
context_files = load_context_files(workspace_root)
|
||||
|
||||
# Initialize skill manager
|
||||
skill_manager = self._initialize_skill_manager(workspace_root, session_id)
|
||||
|
||||
# Build system prompt
|
||||
prompt_builder = PromptBuilder(workspace_dir=workspace_root, language="zh")
|
||||
runtime_info = self._get_runtime_info(workspace_root)
|
||||
|
||||
system_prompt = prompt_builder.build(
|
||||
tools=tools,
|
||||
context_files=context_files,
|
||||
skill_manager=skill_manager,
|
||||
memory_manager=memory_manager,
|
||||
runtime_info=runtime_info,
|
||||
)
|
||||
|
||||
# Get cost control parameters
|
||||
from config import conf
|
||||
max_steps = conf().get("agent_max_steps", 20)
|
||||
max_context_tokens = conf().get("agent_max_context_tokens", 50000)
|
||||
|
||||
# Create agent
|
||||
agent = self.agent_bridge.create_agent(
|
||||
system_prompt=system_prompt,
|
||||
tools=tools,
|
||||
max_steps=max_steps,
|
||||
output_mode="logger",
|
||||
workspace_dir=workspace_root,
|
||||
skill_manager=skill_manager,
|
||||
enable_skills=True,
|
||||
max_context_tokens=max_context_tokens,
|
||||
runtime_info=runtime_info # Pass runtime_info for dynamic time updates
|
||||
)
|
||||
|
||||
# Attach memory manager and share LLM model for summarization
|
||||
if memory_manager:
|
||||
agent.memory_manager = memory_manager
|
||||
if hasattr(agent, 'model') and agent.model:
|
||||
memory_manager.flush_manager.llm_model = agent.model
|
||||
|
||||
# Restore persisted conversation history for this session
|
||||
if session_id:
|
||||
self._restore_conversation_history(agent, session_id)
|
||||
|
||||
# Start daily memory flush timer (once, on first agent init regardless of session)
|
||||
self._start_daily_flush_timer()
|
||||
|
||||
return agent
|
||||
|
||||
def _restore_conversation_history(self, agent, session_id: str) -> None:
|
||||
"""
|
||||
Load persisted conversation messages from SQLite and inject them
|
||||
into the agent's in-memory message list.
|
||||
|
||||
Only user text and assistant text are restored. Tool call chains
|
||||
(tool_use / tool_result) are stripped out because:
|
||||
1. They are intermediate process, the value is already in the final
|
||||
assistant text reply.
|
||||
2. They consume massive context tokens (often 80%+ of history).
|
||||
3. Different models have incompatible tool message formats, so
|
||||
restoring tool chains across model switches causes 400 errors.
|
||||
4. Eliminates the entire class of tool_use/tool_result pairing bugs.
|
||||
"""
|
||||
from config import conf
|
||||
if not conf().get("conversation_persistence", True):
|
||||
return
|
||||
|
||||
try:
|
||||
from agent.memory import get_conversation_store
|
||||
store = get_conversation_store()
|
||||
max_turns = conf().get("agent_max_context_turns", 20)
|
||||
restore_turns = max(3, max_turns // 6)
|
||||
saved = store.load_messages(session_id, max_turns=restore_turns)
|
||||
if saved:
|
||||
filtered = self._filter_text_only_messages(saved)
|
||||
if filtered:
|
||||
with agent.messages_lock:
|
||||
agent.messages = filtered
|
||||
logger.debug(
|
||||
f"[AgentInitializer] Restored {len(filtered)} text messages "
|
||||
f"(from {len(saved)} total, {restore_turns} turns cap) "
|
||||
f"for session={session_id}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"[AgentInitializer] Failed to restore conversation history for "
|
||||
f"session={session_id}: {e}"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _filter_text_only_messages(messages: list) -> list:
|
||||
"""
|
||||
Extract clean user/assistant turn pairs from raw message history.
|
||||
|
||||
Groups messages into turns (each starting with a real user query),
|
||||
then keeps only:
|
||||
- The first user text in each turn (the actual user input)
|
||||
- The last assistant text in each turn (the final answer)
|
||||
|
||||
All tool_use, tool_result, intermediate assistant thoughts, and
|
||||
internal hint messages injected by the agent loop are discarded.
|
||||
"""
|
||||
|
||||
def _extract_text(content) -> str:
|
||||
if isinstance(content, str):
|
||||
return content.strip()
|
||||
if isinstance(content, list):
|
||||
parts = [
|
||||
b.get("text", "")
|
||||
for b in content
|
||||
if isinstance(b, dict) and b.get("type") == "text"
|
||||
]
|
||||
return "\n".join(p for p in parts if p).strip()
|
||||
return ""
|
||||
|
||||
def _is_real_user_msg(msg: dict) -> bool:
|
||||
"""True for actual user input, False for tool_result or internal hints."""
|
||||
if msg.get("role") != "user":
|
||||
return False
|
||||
content = msg.get("content")
|
||||
if isinstance(content, list):
|
||||
has_tool_result = any(
|
||||
isinstance(b, dict) and b.get("type") == "tool_result"
|
||||
for b in content
|
||||
)
|
||||
if has_tool_result:
|
||||
return False
|
||||
text = _extract_text(content)
|
||||
return bool(text)
|
||||
|
||||
# Group into turns: each turn starts with a real user message
|
||||
turns = []
|
||||
current_turn = None
|
||||
for msg in messages:
|
||||
if _is_real_user_msg(msg):
|
||||
if current_turn is not None:
|
||||
turns.append(current_turn)
|
||||
current_turn = {"user": msg, "assistants": []}
|
||||
elif current_turn is not None and msg.get("role") == "assistant":
|
||||
text = _extract_text(msg.get("content"))
|
||||
if text:
|
||||
current_turn["assistants"].append(text)
|
||||
if current_turn is not None:
|
||||
turns.append(current_turn)
|
||||
|
||||
# Build result: one user msg + one assistant msg per turn
|
||||
filtered = []
|
||||
for turn in turns:
|
||||
user_text = _extract_text(turn["user"].get("content"))
|
||||
if not user_text:
|
||||
continue
|
||||
filtered.append({
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": user_text}]
|
||||
})
|
||||
if turn["assistants"]:
|
||||
final_reply = turn["assistants"][-1]
|
||||
filtered.append({
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "text": final_reply}]
|
||||
})
|
||||
|
||||
return filtered
|
||||
|
||||
def _load_env_file(self):
|
||||
"""Load environment variables from .env file"""
|
||||
env_file = expand_path("~/.cow/.env")
|
||||
if os.path.exists(env_file):
|
||||
try:
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv(env_file, override=True)
|
||||
except ImportError:
|
||||
logger.warning("[AgentInitializer] python-dotenv not installed")
|
||||
except Exception as e:
|
||||
logger.warning(f"[AgentInitializer] Failed to load .env file: {e}")
|
||||
|
||||
def _setup_memory_system(self, workspace_root: str, session_id: Optional[str] = None):
|
||||
"""
|
||||
Setup memory system
|
||||
|
||||
Returns:
|
||||
(memory_manager, memory_tools) tuple
|
||||
"""
|
||||
memory_manager = None
|
||||
memory_tools = []
|
||||
|
||||
try:
|
||||
from agent.memory import MemoryManager, MemoryConfig, create_embedding_provider
|
||||
from agent.tools import MemorySearchTool, MemoryGetTool
|
||||
from config import conf
|
||||
|
||||
# Initialize embedding provider (prefer OpenAI, fallback to LinkAI)
|
||||
embedding_provider = None
|
||||
|
||||
openai_api_key = conf().get("open_ai_api_key", "")
|
||||
openai_api_base = conf().get("open_ai_api_base", "")
|
||||
if openai_api_key and openai_api_key not in ["", "YOUR API KEY", "YOUR_API_KEY"]:
|
||||
try:
|
||||
embedding_provider = create_embedding_provider(
|
||||
provider="openai",
|
||||
model="text-embedding-3-small",
|
||||
api_key=openai_api_key,
|
||||
api_base=openai_api_base or "https://api.openai.com/v1"
|
||||
)
|
||||
if session_id is None:
|
||||
logger.info("[AgentInitializer] OpenAI embedding initialized")
|
||||
except Exception as e:
|
||||
logger.warning(f"[AgentInitializer] OpenAI embedding failed: {e}")
|
||||
|
||||
if embedding_provider is None:
|
||||
linkai_api_key = conf().get("linkai_api_key", "") or os.environ.get("LINKAI_API_KEY", "")
|
||||
linkai_api_base = conf().get("linkai_api_base", "https://api.link-ai.tech")
|
||||
if linkai_api_key and linkai_api_key not in ["", "YOUR API KEY", "YOUR_API_KEY"]:
|
||||
try:
|
||||
embedding_provider = create_embedding_provider(
|
||||
provider="linkai",
|
||||
model="text-embedding-3-small",
|
||||
api_key=linkai_api_key,
|
||||
api_base=f"{linkai_api_base}/v1"
|
||||
)
|
||||
if session_id is None:
|
||||
logger.info("[AgentInitializer] LinkAI embedding initialized (fallback)")
|
||||
except Exception as e:
|
||||
logger.warning(f"[AgentInitializer] LinkAI embedding failed: {e}")
|
||||
|
||||
# Create memory manager
|
||||
memory_config = MemoryConfig(workspace_root=workspace_root)
|
||||
memory_manager = MemoryManager(memory_config, embedding_provider=embedding_provider)
|
||||
|
||||
# Sync memory
|
||||
self._sync_memory(memory_manager, session_id)
|
||||
|
||||
# Create memory tools
|
||||
memory_tools = [
|
||||
MemorySearchTool(memory_manager),
|
||||
MemoryGetTool(memory_manager)
|
||||
]
|
||||
|
||||
if session_id is None:
|
||||
logger.info("[AgentInitializer] Memory system initialized")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"[AgentInitializer] Memory system not available: {e}")
|
||||
|
||||
return memory_manager, memory_tools
|
||||
|
||||
def _sync_memory(self, memory_manager, session_id: Optional[str] = None):
|
||||
"""Sync memory database"""
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_closed():
|
||||
raise RuntimeError("Event loop is closed")
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
if loop.is_running():
|
||||
asyncio.create_task(memory_manager.sync())
|
||||
else:
|
||||
loop.run_until_complete(memory_manager.sync())
|
||||
except Exception as e:
|
||||
logger.warning(f"[AgentInitializer] Memory sync failed: {e}")
|
||||
|
||||
def _load_tools(self, workspace_root: str, memory_manager, memory_tools: List, session_id: Optional[str] = None):
|
||||
"""Load all tools"""
|
||||
tool_manager = ToolManager()
|
||||
tool_manager.load_tools()
|
||||
|
||||
tools = []
|
||||
file_config = {
|
||||
"cwd": workspace_root,
|
||||
"memory_manager": memory_manager
|
||||
} if memory_manager else {"cwd": workspace_root}
|
||||
|
||||
for tool_name in tool_manager.tool_classes.keys():
|
||||
try:
|
||||
# Skip web_search if no API key is available
|
||||
if tool_name == "web_search":
|
||||
from agent.tools.web_search.web_search import WebSearch
|
||||
if not WebSearch.is_available():
|
||||
logger.debug("[AgentInitializer] WebSearch skipped - no BOCHA_API_KEY or LINKAI_API_KEY")
|
||||
continue
|
||||
|
||||
# Special handling for EnvConfig tool
|
||||
if tool_name == "env_config":
|
||||
from agent.tools import EnvConfig
|
||||
tool = EnvConfig({"agent_bridge": self.agent_bridge})
|
||||
else:
|
||||
tool = tool_manager.create_tool(tool_name)
|
||||
|
||||
if tool:
|
||||
# Apply workspace config to file operation tools
|
||||
if tool_name in ['read', 'write', 'edit', 'bash', 'grep', 'find', 'ls', 'web_fetch']:
|
||||
tool.config = file_config
|
||||
tool.cwd = file_config.get("cwd", getattr(tool, 'cwd', None))
|
||||
if 'memory_manager' in file_config:
|
||||
tool.memory_manager = file_config['memory_manager']
|
||||
tools.append(tool)
|
||||
except Exception as e:
|
||||
logger.warning(f"[AgentInitializer] Failed to load tool {tool_name}: {e}")
|
||||
|
||||
# Add memory tools
|
||||
if memory_tools:
|
||||
tools.extend(memory_tools)
|
||||
if session_id is None:
|
||||
logger.info(f"[AgentInitializer] Added {len(memory_tools)} memory tools")
|
||||
|
||||
if session_id is None:
|
||||
logger.info(f"[AgentInitializer] Loaded {len(tools)} tools: {[t.name for t in tools]}")
|
||||
|
||||
return tools
|
||||
|
||||
def _initialize_scheduler(self, tools: List, session_id: Optional[str] = None):
|
||||
"""Initialize scheduler service if needed"""
|
||||
if not self.agent_bridge.scheduler_initialized:
|
||||
try:
|
||||
from agent.tools.scheduler.integration import init_scheduler
|
||||
if init_scheduler(self.agent_bridge):
|
||||
self.agent_bridge.scheduler_initialized = True
|
||||
if session_id is None:
|
||||
logger.info("[AgentInitializer] Scheduler service initialized")
|
||||
except Exception as e:
|
||||
logger.warning(f"[AgentInitializer] Failed to initialize scheduler: {e}")
|
||||
|
||||
# Inject scheduler dependencies
|
||||
if self.agent_bridge.scheduler_initialized:
|
||||
try:
|
||||
from agent.tools.scheduler.integration import get_task_store, get_scheduler_service
|
||||
from agent.tools import SchedulerTool
|
||||
from config import conf
|
||||
|
||||
task_store = get_task_store()
|
||||
scheduler_service = get_scheduler_service()
|
||||
|
||||
for tool in tools:
|
||||
if isinstance(tool, SchedulerTool):
|
||||
tool.task_store = task_store
|
||||
tool.scheduler_service = scheduler_service
|
||||
if not tool.config:
|
||||
tool.config = {}
|
||||
raw_ct = conf().get("channel_type", "unknown")
|
||||
if isinstance(raw_ct, list):
|
||||
ct = raw_ct[0] if raw_ct else "unknown"
|
||||
elif isinstance(raw_ct, str) and "," in raw_ct:
|
||||
ct = raw_ct.split(",")[0].strip()
|
||||
else:
|
||||
ct = raw_ct
|
||||
tool.config["channel_type"] = ct
|
||||
except Exception as e:
|
||||
logger.warning(f"[AgentInitializer] Failed to inject scheduler dependencies: {e}")
|
||||
|
||||
def _initialize_skill_manager(self, workspace_root: str, session_id: Optional[str] = None):
|
||||
"""Initialize skill manager"""
|
||||
try:
|
||||
from agent.skills import SkillManager
|
||||
skill_manager = SkillManager(custom_dir=os.path.join(workspace_root, "skills"))
|
||||
return skill_manager
|
||||
except Exception as e:
|
||||
logger.warning(f"[AgentInitializer] Failed to initialize SkillManager: {e}")
|
||||
return None
|
||||
|
||||
def _get_runtime_info(self, workspace_root: str):
|
||||
"""Get runtime information with dynamic time support"""
|
||||
from config import conf
|
||||
|
||||
def get_current_time():
|
||||
"""Get current time dynamically - called each time system prompt is accessed"""
|
||||
now = datetime.datetime.now()
|
||||
|
||||
# Get timezone info
|
||||
try:
|
||||
offset = -time.timezone if not time.daylight else -time.altzone
|
||||
hours = offset // 3600
|
||||
minutes = (offset % 3600) // 60
|
||||
timezone_name = f"UTC{hours:+03d}:{minutes:02d}" if minutes else f"UTC{hours:+03d}"
|
||||
except Exception:
|
||||
timezone_name = "UTC"
|
||||
|
||||
# Chinese weekday mapping
|
||||
weekday_map = {
|
||||
'Monday': '星期一', 'Tuesday': '星期二', 'Wednesday': '星期三',
|
||||
'Thursday': '星期四', 'Friday': '星期五', 'Saturday': '星期六', 'Sunday': '星期日'
|
||||
}
|
||||
weekday_zh = weekday_map.get(now.strftime("%A"), now.strftime("%A"))
|
||||
|
||||
return {
|
||||
'time': now.strftime("%Y-%m-%d %H:%M:%S"),
|
||||
'weekday': weekday_zh,
|
||||
'timezone': timezone_name
|
||||
}
|
||||
|
||||
return {
|
||||
"model": conf().get("model", "unknown"),
|
||||
"workspace": workspace_root,
|
||||
"channel": ", ".join(conf().get("channel_type")) if isinstance(conf().get("channel_type"), list) else conf().get("channel_type", "unknown"),
|
||||
"_get_current_time": get_current_time # Dynamic time function
|
||||
}
|
||||
|
||||
def _migrate_config_to_env(self, workspace_root: str):
|
||||
"""Migrate API keys from config.json to .env file"""
|
||||
from config import conf
|
||||
|
||||
key_mapping = {
|
||||
"open_ai_api_key": "OPENAI_API_KEY",
|
||||
"open_ai_api_base": "OPENAI_API_BASE",
|
||||
"gemini_api_key": "GEMINI_API_KEY",
|
||||
"claude_api_key": "CLAUDE_API_KEY",
|
||||
"linkai_api_key": "LINKAI_API_KEY",
|
||||
}
|
||||
|
||||
env_file = expand_path("~/.cow/.env")
|
||||
|
||||
# Read existing env vars
|
||||
existing_env_vars = {}
|
||||
if os.path.exists(env_file):
|
||||
try:
|
||||
with open(env_file, 'r', encoding='utf-8') as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line and not line.startswith('#') and '=' in line:
|
||||
key, _ = line.split('=', 1)
|
||||
existing_env_vars[key.strip()] = True
|
||||
except Exception as e:
|
||||
logger.warning(f"[AgentInitializer] Failed to read .env file: {e}")
|
||||
|
||||
# Check which keys need migration
|
||||
keys_to_migrate = {}
|
||||
for config_key, env_key in key_mapping.items():
|
||||
if env_key in existing_env_vars:
|
||||
continue
|
||||
value = conf().get(config_key, "")
|
||||
if value and value.strip():
|
||||
keys_to_migrate[env_key] = value.strip()
|
||||
|
||||
# Write new keys
|
||||
if keys_to_migrate:
|
||||
try:
|
||||
env_dir = os.path.dirname(env_file)
|
||||
if not os.path.exists(env_dir):
|
||||
os.makedirs(env_dir, exist_ok=True)
|
||||
if not os.path.exists(env_file):
|
||||
open(env_file, 'a').close()
|
||||
|
||||
with open(env_file, 'a', encoding='utf-8') as f:
|
||||
f.write('\n# Auto-migrated from config.json\n')
|
||||
for key, value in keys_to_migrate.items():
|
||||
f.write(f'{key}={value}\n')
|
||||
os.environ[key] = value
|
||||
|
||||
logger.info(f"[AgentInitializer] Migrated {len(keys_to_migrate)} API keys to .env: {list(keys_to_migrate.keys())}")
|
||||
except Exception as e:
|
||||
logger.warning(f"[AgentInitializer] Failed to migrate API keys: {e}")
|
||||
|
||||
def _start_daily_flush_timer(self):
|
||||
"""Start a background thread that flushes all agents' memory daily at 23:55."""
|
||||
if getattr(self.agent_bridge, '_daily_flush_started', False):
|
||||
return
|
||||
self.agent_bridge._daily_flush_started = True
|
||||
|
||||
import threading
|
||||
|
||||
def _daily_flush_loop():
|
||||
while True:
|
||||
try:
|
||||
now = datetime.datetime.now()
|
||||
target = now.replace(hour=23, minute=55, second=0, microsecond=0)
|
||||
if target <= now:
|
||||
target += datetime.timedelta(days=1)
|
||||
wait_seconds = (target - now).total_seconds()
|
||||
logger.info(f"[DailyFlush] Next flush at {target.strftime('%Y-%m-%d %H:%M')} (in {wait_seconds/3600:.1f}h)")
|
||||
time.sleep(wait_seconds)
|
||||
|
||||
self._flush_all_agents()
|
||||
except Exception as e:
|
||||
logger.warning(f"[DailyFlush] Error in daily flush loop: {e}")
|
||||
time.sleep(3600)
|
||||
|
||||
t = threading.Thread(target=_daily_flush_loop, daemon=True)
|
||||
t.start()
|
||||
|
||||
def _flush_all_agents(self):
|
||||
"""Flush memory for all active agent sessions."""
|
||||
agents = []
|
||||
if self.agent_bridge.default_agent:
|
||||
agents.append(("default", self.agent_bridge.default_agent))
|
||||
for sid, agent in self.agent_bridge.agents.items():
|
||||
agents.append((sid, agent))
|
||||
|
||||
if not agents:
|
||||
return
|
||||
|
||||
flushed = 0
|
||||
for label, agent in agents:
|
||||
try:
|
||||
if not agent.memory_manager:
|
||||
continue
|
||||
with agent.messages_lock:
|
||||
messages = list(agent.messages)
|
||||
if not messages:
|
||||
continue
|
||||
result = agent.memory_manager.flush_manager.create_daily_summary(messages)
|
||||
if result:
|
||||
flushed += 1
|
||||
except Exception as e:
|
||||
logger.warning(f"[DailyFlush] Failed for session {label}: {e}")
|
||||
|
||||
if flushed:
|
||||
logger.info(f"[DailyFlush] Flushed {flushed}/{len(agents)} agent session(s)")
|
||||
142
bridge/bridge.py
142
bridge/bridge.py
@@ -1,50 +1,146 @@
|
||||
from models.bot_factory import create_bot
|
||||
from bridge.context import Context
|
||||
from bridge.reply import Reply
|
||||
from common.log import logger
|
||||
from bot import bot_factory
|
||||
from common.singleton import singleton
|
||||
from voice import voice_factory
|
||||
from config import conf
|
||||
from common import const
|
||||
from common.log import logger
|
||||
from common.singleton import singleton
|
||||
from config import conf
|
||||
from translate.factory import create_translator
|
||||
from voice.factory import create_voice
|
||||
|
||||
|
||||
@singleton
|
||||
class Bridge(object):
|
||||
def __init__(self):
|
||||
self.btype={
|
||||
self.btype = {
|
||||
"chat": const.CHATGPT,
|
||||
"voice_to_text": conf().get("voice_to_text", "openai"),
|
||||
"text_to_voice": conf().get("text_to_voice", "google")
|
||||
"text_to_voice": conf().get("text_to_voice", "google"),
|
||||
"translate": conf().get("translate", "baidu"),
|
||||
}
|
||||
model_type = conf().get("model")
|
||||
if model_type in ["text-davinci-003"]:
|
||||
self.btype['chat'] = const.OPEN_AI
|
||||
if conf().get("use_azure_chatgpt", False):
|
||||
self.btype['chat'] = const.CHATGPTONAZURE
|
||||
self.bots={}
|
||||
# 这边取配置的模型
|
||||
bot_type = conf().get("bot_type")
|
||||
if bot_type:
|
||||
self.btype["chat"] = bot_type
|
||||
else:
|
||||
model_type = conf().get("model") or const.GPT_41_MINI
|
||||
|
||||
# Ensure model_type is string to prevent AttributeError when using startswith()
|
||||
# This handles cases where numeric model names (e.g., "1") are parsed as integers from YAML
|
||||
if not isinstance(model_type, str):
|
||||
logger.warning(f"[Bridge] model_type is not a string: {model_type} (type: {type(model_type).__name__}), converting to string")
|
||||
model_type = str(model_type)
|
||||
|
||||
if model_type in ["text-davinci-003"]:
|
||||
self.btype["chat"] = const.OPEN_AI
|
||||
if conf().get("use_azure_chatgpt", False):
|
||||
self.btype["chat"] = const.CHATGPTONAZURE
|
||||
if model_type in ["wenxin", "wenxin-4"]:
|
||||
self.btype["chat"] = const.BAIDU
|
||||
if model_type in ["xunfei"]:
|
||||
self.btype["chat"] = const.XUNFEI
|
||||
if model_type in [const.QWEN]:
|
||||
self.btype["chat"] = const.QWEN
|
||||
if model_type in [const.QWEN_TURBO, const.QWEN_PLUS, const.QWEN_MAX]:
|
||||
self.btype["chat"] = const.QWEN_DASHSCOPE
|
||||
# Support Qwen3 and other DashScope models
|
||||
if model_type and (model_type.startswith("qwen") or model_type.startswith("qwq") or model_type.startswith("qvq")):
|
||||
self.btype["chat"] = const.QWEN_DASHSCOPE
|
||||
if model_type and model_type.startswith("gemini"):
|
||||
self.btype["chat"] = const.GEMINI
|
||||
if model_type and model_type.startswith("glm"):
|
||||
self.btype["chat"] = const.ZHIPU_AI
|
||||
if model_type and model_type.startswith("claude"):
|
||||
self.btype["chat"] = const.CLAUDEAPI
|
||||
|
||||
def get_bot(self,typename):
|
||||
if model_type in [const.MOONSHOT, "moonshot-v1-8k", "moonshot-v1-32k", "moonshot-v1-128k"]:
|
||||
self.btype["chat"] = const.MOONSHOT
|
||||
if model_type and model_type.startswith("kimi"):
|
||||
self.btype["chat"] = const.MOONSHOT
|
||||
|
||||
if model_type and model_type.startswith("doubao"):
|
||||
self.btype["chat"] = const.DOUBAO
|
||||
|
||||
if model_type in [const.MODELSCOPE]:
|
||||
self.btype["chat"] = const.MODELSCOPE
|
||||
|
||||
# MiniMax models
|
||||
if model_type and (model_type in ["abab6.5-chat", "abab6.5"] or model_type.lower().startswith("minimax")):
|
||||
self.btype["chat"] = const.MiniMax
|
||||
|
||||
if conf().get("use_linkai") and conf().get("linkai_api_key"):
|
||||
self.btype["chat"] = const.LINKAI
|
||||
if not conf().get("voice_to_text") or conf().get("voice_to_text") in ["openai"]:
|
||||
self.btype["voice_to_text"] = const.LINKAI
|
||||
if not conf().get("text_to_voice") or conf().get("text_to_voice") in ["openai", const.TTS_1, const.TTS_1_HD]:
|
||||
self.btype["text_to_voice"] = const.LINKAI
|
||||
|
||||
self.bots = {}
|
||||
self.chat_bots = {}
|
||||
self._agent_bridge = None
|
||||
|
||||
# 模型对应的接口
|
||||
def get_bot(self, typename):
|
||||
if self.bots.get(typename) is None:
|
||||
logger.info("create bot {} for {}".format(self.btype[typename],typename))
|
||||
logger.info("create bot {} for {}".format(self.btype[typename], typename))
|
||||
if typename == "text_to_voice":
|
||||
self.bots[typename] = voice_factory.create_voice(self.btype[typename])
|
||||
self.bots[typename] = create_voice(self.btype[typename])
|
||||
elif typename == "voice_to_text":
|
||||
self.bots[typename] = voice_factory.create_voice(self.btype[typename])
|
||||
self.bots[typename] = create_voice(self.btype[typename])
|
||||
elif typename == "chat":
|
||||
self.bots[typename] = bot_factory.create_bot(self.btype[typename])
|
||||
self.bots[typename] = create_bot(self.btype[typename])
|
||||
elif typename == "translate":
|
||||
self.bots[typename] = create_translator(self.btype[typename])
|
||||
return self.bots[typename]
|
||||
|
||||
def get_bot_type(self,typename):
|
||||
|
||||
def get_bot_type(self, typename):
|
||||
return self.btype[typename]
|
||||
|
||||
|
||||
def fetch_reply_content(self, query, context : Context) -> Reply:
|
||||
def fetch_reply_content(self, query, context: Context) -> Reply:
|
||||
return self.get_bot("chat").reply(query, context)
|
||||
|
||||
|
||||
def fetch_voice_to_text(self, voiceFile) -> Reply:
|
||||
return self.get_bot("voice_to_text").voiceToText(voiceFile)
|
||||
|
||||
def fetch_text_to_voice(self, text) -> Reply:
|
||||
return self.get_bot("text_to_voice").textToVoice(text)
|
||||
|
||||
def fetch_translate(self, text, from_lang="", to_lang="en") -> Reply:
|
||||
return self.get_bot("translate").translate(text, from_lang, to_lang)
|
||||
|
||||
def find_chat_bot(self, bot_type: str):
|
||||
if self.chat_bots.get(bot_type) is None:
|
||||
self.chat_bots[bot_type] = create_bot(bot_type)
|
||||
return self.chat_bots.get(bot_type)
|
||||
|
||||
def reset_bot(self):
|
||||
"""
|
||||
重置bot路由
|
||||
"""
|
||||
self.__init__()
|
||||
|
||||
def get_agent_bridge(self):
|
||||
"""
|
||||
Get agent bridge for agent-based conversations
|
||||
"""
|
||||
if self._agent_bridge is None:
|
||||
from bridge.agent_bridge import AgentBridge
|
||||
self._agent_bridge = AgentBridge(self)
|
||||
return self._agent_bridge
|
||||
|
||||
def fetch_agent_reply(self, query: str, context: Context = None,
|
||||
on_event=None, clear_history: bool = False) -> Reply:
|
||||
"""
|
||||
Use super agent to handle the query
|
||||
|
||||
Args:
|
||||
query: User query
|
||||
context: Context object
|
||||
on_event: Event callback for streaming
|
||||
clear_history: Whether to clear conversation history
|
||||
|
||||
Returns:
|
||||
Reply object
|
||||
"""
|
||||
agent_bridge = self.get_agent_bridge()
|
||||
return agent_bridge.agent_reply(query, context, on_event, clear_history)
|
||||
|
||||
@@ -2,36 +2,49 @@
|
||||
|
||||
from enum import Enum
|
||||
|
||||
class ContextType (Enum):
|
||||
TEXT = 1 # 文本消息
|
||||
VOICE = 2 # 音频消息
|
||||
IMAGE = 3 # 图片消息
|
||||
IMAGE_CREATE = 10 # 创建图片命令
|
||||
|
||||
|
||||
class ContextType(Enum):
|
||||
TEXT = 1 # 文本消息
|
||||
VOICE = 2 # 音频消息
|
||||
IMAGE = 3 # 图片消息
|
||||
FILE = 4 # 文件信息
|
||||
VIDEO = 5 # 视频信息
|
||||
SHARING = 6 # 分享信息
|
||||
|
||||
IMAGE_CREATE = 10 # 创建图片命令
|
||||
ACCEPT_FRIEND = 19 # 同意好友请求
|
||||
JOIN_GROUP = 20 # 加入群聊
|
||||
PATPAT = 21 # 拍了拍
|
||||
FUNCTION = 22 # 函数调用
|
||||
EXIT_GROUP = 23 #退出
|
||||
|
||||
|
||||
def __str__(self):
|
||||
return self.name
|
||||
|
||||
|
||||
class Context:
|
||||
def __init__(self, type : ContextType = None , content = None, kwargs = dict()):
|
||||
def __init__(self, type: ContextType = None, content=None, kwargs=dict()):
|
||||
self.type = type
|
||||
self.content = content
|
||||
self.kwargs = kwargs
|
||||
|
||||
def __contains__(self, key):
|
||||
if key == 'type':
|
||||
if key == "type":
|
||||
return self.type is not None
|
||||
elif key == 'content':
|
||||
elif key == "content":
|
||||
return self.content is not None
|
||||
else:
|
||||
return key in self.kwargs
|
||||
|
||||
|
||||
def __getitem__(self, key):
|
||||
if key == 'type':
|
||||
if key == "type":
|
||||
return self.type
|
||||
elif key == 'content':
|
||||
elif key == "content":
|
||||
return self.content
|
||||
else:
|
||||
return self.kwargs[key]
|
||||
|
||||
|
||||
def get(self, key, default=None):
|
||||
try:
|
||||
return self[key]
|
||||
@@ -39,20 +52,20 @@ class Context:
|
||||
return default
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
if key == 'type':
|
||||
if key == "type":
|
||||
self.type = value
|
||||
elif key == 'content':
|
||||
elif key == "content":
|
||||
self.content = value
|
||||
else:
|
||||
self.kwargs[key] = value
|
||||
|
||||
def __delitem__(self, key):
|
||||
if key == 'type':
|
||||
if key == "type":
|
||||
self.type = None
|
||||
elif key == 'content':
|
||||
elif key == "content":
|
||||
self.content = None
|
||||
else:
|
||||
del self.kwargs[key]
|
||||
|
||||
|
||||
def __str__(self):
|
||||
return "Context(type={}, content={}, kwargs={})".format(self.type, self.content, self.kwargs)
|
||||
return "Context(type={}, content={}, kwargs={})".format(self.type, self.content, self.kwargs)
|
||||
|
||||
@@ -1,22 +1,31 @@
|
||||
|
||||
# encoding:utf-8
|
||||
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class ReplyType(Enum):
|
||||
TEXT = 1 # 文本
|
||||
VOICE = 2 # 音频文件
|
||||
IMAGE = 3 # 图片文件
|
||||
IMAGE_URL = 4 # 图片URL
|
||||
|
||||
TEXT = 1 # 文本
|
||||
VOICE = 2 # 音频文件
|
||||
IMAGE = 3 # 图片文件
|
||||
IMAGE_URL = 4 # 图片URL
|
||||
VIDEO_URL = 5 # 视频URL
|
||||
FILE = 6 # 文件
|
||||
CARD = 7 # 微信名片,仅支持ntchat
|
||||
INVITE_ROOM = 8 # 邀请好友进群
|
||||
INFO = 9
|
||||
ERROR = 10
|
||||
TEXT_ = 11 # 强制文本
|
||||
VIDEO = 12
|
||||
MINIAPP = 13 # 小程序
|
||||
|
||||
def __str__(self):
|
||||
return self.name
|
||||
|
||||
|
||||
class Reply:
|
||||
def __init__(self, type : ReplyType = None , content = None):
|
||||
def __init__(self, type: ReplyType = None, content=None):
|
||||
self.type = type
|
||||
self.content = content
|
||||
|
||||
def __str__(self):
|
||||
return "Reply(type={}, content={})".format(self.type, self.content)
|
||||
return "Reply(type={}, content={})".format(self.type, self.content)
|
||||
|
||||
@@ -5,15 +5,52 @@ Message sending channel abstract class
|
||||
from bridge.bridge import Bridge
|
||||
from bridge.context import Context
|
||||
from bridge.reply import *
|
||||
from common.log import logger
|
||||
from config import conf
|
||||
|
||||
|
||||
class Channel(object):
|
||||
channel_type = ""
|
||||
NOT_SUPPORT_REPLYTYPE = [ReplyType.VOICE, ReplyType.IMAGE]
|
||||
|
||||
def __init__(self):
|
||||
import threading
|
||||
self._startup_event = threading.Event()
|
||||
self._startup_error = None
|
||||
self.cloud_mode = False # set to True by ChannelManager when running with cloud client
|
||||
|
||||
def startup(self):
|
||||
"""
|
||||
init channel
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def report_startup_success(self):
|
||||
self._startup_error = None
|
||||
self._startup_event.set()
|
||||
|
||||
def report_startup_error(self, error: str):
|
||||
self._startup_error = error
|
||||
self._startup_event.set()
|
||||
|
||||
def wait_startup(self, timeout: float = 3) -> (bool, str):
|
||||
"""
|
||||
Wait for channel startup result.
|
||||
Returns (success: bool, error_msg: str).
|
||||
"""
|
||||
ready = self._startup_event.wait(timeout=timeout)
|
||||
if not ready:
|
||||
return True, ""
|
||||
if self._startup_error:
|
||||
return False, self._startup_error
|
||||
return True, ""
|
||||
|
||||
def stop(self):
|
||||
"""
|
||||
stop channel gracefully, called before restart
|
||||
"""
|
||||
pass
|
||||
|
||||
def handle_text(self, msg):
|
||||
"""
|
||||
process received msg
|
||||
@@ -27,15 +64,45 @@ class Channel(object):
|
||||
send message to user
|
||||
:param msg: message content
|
||||
:param receiver: receiver channel account
|
||||
:return:
|
||||
:return:
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def build_reply_content(self, query, context : Context=None) -> Reply:
|
||||
return Bridge().fetch_reply_content(query, context)
|
||||
def build_reply_content(self, query, context: Context = None) -> Reply:
|
||||
"""
|
||||
Build reply content, using agent if enabled in config
|
||||
"""
|
||||
# Check if agent mode is enabled
|
||||
use_agent = conf().get("agent", False)
|
||||
|
||||
if use_agent:
|
||||
try:
|
||||
logger.info("[Channel] Using agent mode")
|
||||
|
||||
# Add channel_type to context if not present
|
||||
if context and "channel_type" not in context:
|
||||
context["channel_type"] = self.channel_type
|
||||
|
||||
# Read on_event callback injected by the channel (e.g. web SSE)
|
||||
on_event = context.get("on_event") if context else None
|
||||
|
||||
# Use agent bridge to handle the query
|
||||
return Bridge().fetch_agent_reply(
|
||||
query=query,
|
||||
context=context,
|
||||
on_event=on_event,
|
||||
clear_history=False
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[Channel] Agent mode failed, fallback to normal mode: {e}")
|
||||
# Fallback to normal mode if agent fails
|
||||
return Bridge().fetch_reply_content(query, context)
|
||||
else:
|
||||
# Normal mode
|
||||
return Bridge().fetch_reply_content(query, context)
|
||||
|
||||
def build_voice_to_text(self, voice_file) -> Reply:
|
||||
return Bridge().fetch_voice_to_text(voice_file)
|
||||
|
||||
|
||||
def build_text_to_voice(self, text) -> Reply:
|
||||
return Bridge().fetch_text_to_voice(text)
|
||||
|
||||
@@ -1,26 +1,45 @@
|
||||
"""
|
||||
channel factory
|
||||
"""
|
||||
from common import const
|
||||
from .channel import Channel
|
||||
|
||||
def create_channel(channel_type):
|
||||
|
||||
def create_channel(channel_type) -> Channel:
|
||||
"""
|
||||
create a channel instance
|
||||
:param channel_type: channel type code
|
||||
:return: channel instance
|
||||
"""
|
||||
if channel_type == 'wx':
|
||||
from channel.wechat.wechat_channel import WechatChannel
|
||||
return WechatChannel()
|
||||
elif channel_type == 'wxy':
|
||||
from channel.wechat.wechaty_channel import WechatyChannel
|
||||
return WechatyChannel()
|
||||
elif channel_type == 'terminal':
|
||||
ch = Channel()
|
||||
if channel_type == "terminal":
|
||||
from channel.terminal.terminal_channel import TerminalChannel
|
||||
return TerminalChannel()
|
||||
elif channel_type == 'wechatmp':
|
||||
ch = TerminalChannel()
|
||||
elif channel_type == 'web':
|
||||
from channel.web.web_channel import WebChannel
|
||||
ch = WebChannel()
|
||||
elif channel_type == "wechatmp":
|
||||
from channel.wechatmp.wechatmp_channel import WechatMPChannel
|
||||
return WechatMPChannel(passive_reply = True)
|
||||
elif channel_type == 'wechatmp_service':
|
||||
ch = WechatMPChannel(passive_reply=True)
|
||||
elif channel_type == "wechatmp_service":
|
||||
from channel.wechatmp.wechatmp_channel import WechatMPChannel
|
||||
return WechatMPChannel(passive_reply = False)
|
||||
raise RuntimeError
|
||||
ch = WechatMPChannel(passive_reply=False)
|
||||
elif channel_type == "wechatcom_app":
|
||||
from channel.wechatcom.wechatcomapp_channel import WechatComAppChannel
|
||||
ch = WechatComAppChannel()
|
||||
elif channel_type == const.FEISHU:
|
||||
from channel.feishu.feishu_channel import FeiShuChanel
|
||||
ch = FeiShuChanel()
|
||||
elif channel_type == const.DINGTALK:
|
||||
from channel.dingtalk.dingtalk_channel import DingTalkChanel
|
||||
ch = DingTalkChanel()
|
||||
elif channel_type == const.WECOM_BOT:
|
||||
from channel.wecom_bot.wecom_bot_channel import WecomBotChannel
|
||||
ch = WecomBotChannel()
|
||||
elif channel_type == const.QQ:
|
||||
from channel.qq.qq_channel import QQChannel
|
||||
ch = QQChannel()
|
||||
else:
|
||||
raise RuntimeError
|
||||
ch.channel_type = channel_type
|
||||
return ch
|
||||
|
||||
@@ -1,156 +1,218 @@
|
||||
|
||||
|
||||
from asyncio import CancelledError
|
||||
from concurrent.futures import Future, ThreadPoolExecutor
|
||||
import os
|
||||
import re
|
||||
import threading
|
||||
import time
|
||||
from common.dequeue import Dequeue
|
||||
from channel.channel import Channel
|
||||
from bridge.reply import *
|
||||
from asyncio import CancelledError
|
||||
from concurrent.futures import Future, ThreadPoolExecutor
|
||||
|
||||
from bridge.context import *
|
||||
from config import conf
|
||||
from common.log import logger
|
||||
from bridge.reply import *
|
||||
from channel.channel import Channel
|
||||
from common.dequeue import Dequeue
|
||||
from common import memory
|
||||
from plugins import *
|
||||
|
||||
try:
|
||||
from voice.audio_convert import any_to_wav
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
handler_pool = ThreadPoolExecutor(max_workers=8) # 处理消息的线程池
|
||||
|
||||
|
||||
# 抽象类, 它包含了与消息通道无关的通用处理逻辑
|
||||
class ChatChannel(Channel):
|
||||
name = None # 登录的用户名
|
||||
user_id = None # 登录的用户id
|
||||
futures = {} # 记录每个session_id提交到线程池的future对象, 用于重置会话时把没执行的future取消掉,正在执行的不会被取消
|
||||
sessions = {} # 用于控制并发,每个session_id同时只能有一个context在处理
|
||||
lock = threading.Lock() # 用于控制对sessions的访问
|
||||
handler_pool = ThreadPoolExecutor(max_workers=8) # 处理消息的线程池
|
||||
name = None # 登录的用户名
|
||||
user_id = None # 登录的用户id
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
# Instance-level attributes so each channel subclass has its own
|
||||
# independent session queue and lock. Previously these were class-level,
|
||||
# which caused contexts from one channel (e.g. Feishu) to be consumed
|
||||
# by another channel's consume() thread (e.g. Web), leading to errors
|
||||
# like "No request_id found in context".
|
||||
self.futures = {}
|
||||
self.sessions = {}
|
||||
self.lock = threading.Lock()
|
||||
_thread = threading.Thread(target=self.consume)
|
||||
_thread.setDaemon(True)
|
||||
_thread.start()
|
||||
|
||||
|
||||
# 根据消息构造context,消息内容相关的触发项写在这里
|
||||
def _compose_context(self, ctype: ContextType, content, **kwargs):
|
||||
context = Context(ctype, content)
|
||||
context.kwargs = kwargs
|
||||
# context首次传入时,origin_ctype是None,
|
||||
# 引入的起因是:当输入语音时,会嵌套生成两个context,第一步语音转文本,第二步通过文本生成文字回复。
|
||||
# origin_ctype用于第二步文本回复时,判断是否需要匹配前缀,如果是私聊的语音,就不需要匹配前缀
|
||||
if 'origin_ctype' not in context:
|
||||
context['origin_ctype'] = ctype
|
||||
if "channel_type" not in context:
|
||||
context["channel_type"] = self.channel_type
|
||||
if "origin_ctype" not in context:
|
||||
context["origin_ctype"] = ctype
|
||||
# context首次传入时,receiver是None,根据类型设置receiver
|
||||
first_in = 'receiver' not in context
|
||||
first_in = "receiver" not in context
|
||||
# 群名匹配过程,设置session_id和receiver
|
||||
if first_in: # context首次传入时,receiver是None,根据类型设置receiver
|
||||
if first_in: # context首次传入时,receiver是None,根据类型设置receiver
|
||||
config = conf()
|
||||
cmsg = context['msg']
|
||||
if cmsg.from_user_id == self.user_id and not config.get('trigger_by_self', True):
|
||||
logger.debug("[WX]self message skipped")
|
||||
return None
|
||||
cmsg = context["msg"]
|
||||
user_data = conf().get_user_data(cmsg.from_user_id)
|
||||
context["openai_api_key"] = user_data.get("openai_api_key")
|
||||
context["gpt_model"] = user_data.get("gpt_model")
|
||||
if context.get("isgroup", False):
|
||||
group_name = cmsg.other_user_nickname
|
||||
group_id = cmsg.other_user_id
|
||||
|
||||
group_name_white_list = config.get('group_name_white_list', [])
|
||||
group_name_keyword_white_list = config.get('group_name_keyword_white_list', [])
|
||||
if any([group_name in group_name_white_list, 'ALL_GROUP' in group_name_white_list, check_contain(group_name, group_name_keyword_white_list)]):
|
||||
group_chat_in_one_session = conf().get('group_chat_in_one_session', [])
|
||||
session_id = cmsg.actual_user_id
|
||||
if any([group_name in group_chat_in_one_session, 'ALL_GROUP' in group_chat_in_one_session]):
|
||||
group_name_white_list = config.get("group_name_white_list", [])
|
||||
group_name_keyword_white_list = config.get("group_name_keyword_white_list", [])
|
||||
if any(
|
||||
[
|
||||
group_name in group_name_white_list,
|
||||
"ALL_GROUP" in group_name_white_list,
|
||||
check_contain(group_name, group_name_keyword_white_list),
|
||||
]
|
||||
):
|
||||
# Check global group_shared_session config first
|
||||
group_shared_session = conf().get("group_shared_session", True)
|
||||
if group_shared_session:
|
||||
# All users in the group share the same session
|
||||
session_id = group_id
|
||||
else:
|
||||
# Check group-specific whitelist (legacy behavior)
|
||||
group_chat_in_one_session = conf().get("group_chat_in_one_session", [])
|
||||
session_id = cmsg.actual_user_id
|
||||
if any(
|
||||
[
|
||||
group_name in group_chat_in_one_session,
|
||||
"ALL_GROUP" in group_chat_in_one_session,
|
||||
]
|
||||
):
|
||||
session_id = group_id
|
||||
else:
|
||||
logger.debug(f"No need reply, groupName not in whitelist, group_name={group_name}")
|
||||
return None
|
||||
context['session_id'] = session_id
|
||||
context['receiver'] = group_id
|
||||
context["session_id"] = session_id
|
||||
context["receiver"] = group_id
|
||||
else:
|
||||
context['session_id'] = cmsg.other_user_id
|
||||
context['receiver'] = cmsg.other_user_id
|
||||
context["session_id"] = cmsg.other_user_id
|
||||
context["receiver"] = cmsg.other_user_id
|
||||
e_context = PluginManager().emit_event(EventContext(Event.ON_RECEIVE_MESSAGE, {"channel": self, "context": context}))
|
||||
context = e_context["context"]
|
||||
if e_context.is_pass() or context is None:
|
||||
return context
|
||||
if cmsg.from_user_id == self.user_id and not config.get("trigger_by_self", True):
|
||||
logger.debug("[chat_channel]self message skipped")
|
||||
return None
|
||||
|
||||
# 消息内容匹配过程,并处理content
|
||||
if ctype == ContextType.TEXT:
|
||||
if first_in and "」\n- - - - - - -" in content: # 初次匹配 过滤引用消息
|
||||
logger.debug("[WX]reference query skipped")
|
||||
if first_in and "」\n- - - - - - -" in content: # 初次匹配 过滤引用消息
|
||||
logger.debug(content)
|
||||
logger.debug("[chat_channel]reference query skipped")
|
||||
return None
|
||||
|
||||
if context.get("isgroup", False): # 群聊
|
||||
|
||||
nick_name_black_list = conf().get("nick_name_black_list", [])
|
||||
if context.get("isgroup", False): # 群聊
|
||||
# 校验关键字
|
||||
match_prefix = check_prefix(content, conf().get('group_chat_prefix'))
|
||||
match_contain = check_contain(content, conf().get('group_chat_keyword'))
|
||||
match_prefix = check_prefix(content, conf().get("group_chat_prefix"))
|
||||
match_contain = check_contain(content, conf().get("group_chat_keyword"))
|
||||
flag = False
|
||||
if match_prefix is not None or match_contain is not None:
|
||||
flag = True
|
||||
if match_prefix:
|
||||
content = content.replace(match_prefix, '', 1).strip()
|
||||
if context['msg'].is_at:
|
||||
logger.info("[WX]receive group at")
|
||||
if not conf().get("group_at_off", False):
|
||||
if context["msg"].to_user_id != context["msg"].actual_user_id:
|
||||
if match_prefix is not None or match_contain is not None:
|
||||
flag = True
|
||||
pattern = f'@{self.name}(\u2005|\u0020)'
|
||||
content = re.sub(pattern, r'', content)
|
||||
|
||||
if match_prefix:
|
||||
content = content.replace(match_prefix, "", 1).strip()
|
||||
if context["msg"].is_at:
|
||||
nick_name = context["msg"].actual_user_nickname
|
||||
if nick_name and nick_name in nick_name_black_list:
|
||||
# 黑名单过滤
|
||||
logger.warning(f"[chat_channel] Nickname {nick_name} in In BlackList, ignore")
|
||||
return None
|
||||
|
||||
logger.info("[chat_channel]receive group at")
|
||||
if not conf().get("group_at_off", False):
|
||||
flag = True
|
||||
self.name = self.name if self.name is not None else "" # 部分渠道self.name可能没有赋值
|
||||
pattern = f"@{re.escape(self.name)}(\u2005|\u0020)"
|
||||
subtract_res = re.sub(pattern, r"", content)
|
||||
if isinstance(context["msg"].at_list, list):
|
||||
for at in context["msg"].at_list:
|
||||
pattern = f"@{re.escape(at)}(\u2005|\u0020)"
|
||||
subtract_res = re.sub(pattern, r"", subtract_res)
|
||||
if subtract_res == content and context["msg"].self_display_name:
|
||||
# 前缀移除后没有变化,使用群昵称再次移除
|
||||
pattern = f"@{re.escape(context['msg'].self_display_name)}(\u2005|\u0020)"
|
||||
subtract_res = re.sub(pattern, r"", content)
|
||||
content = subtract_res
|
||||
if not flag:
|
||||
if context["origin_ctype"] == ContextType.VOICE:
|
||||
logger.info("[WX]receive group voice, but checkprefix didn't match")
|
||||
logger.info("[chat_channel]receive group voice, but checkprefix didn't match")
|
||||
return None
|
||||
else: # 单聊
|
||||
match_prefix = check_prefix(content, conf().get('single_chat_prefix',['']))
|
||||
if match_prefix is not None: # 判断如果匹配到自定义前缀,则返回过滤掉前缀+空格后的内容
|
||||
content = content.replace(match_prefix, '', 1).strip()
|
||||
elif context["origin_ctype"] == ContextType.VOICE: # 如果源消息是私聊的语音消息,允许不匹配前缀,放宽条件
|
||||
else: # 单聊
|
||||
nick_name = context["msg"].from_user_nickname
|
||||
if nick_name and nick_name in nick_name_black_list:
|
||||
# 黑名单过滤
|
||||
logger.warning(f"[chat_channel] Nickname '{nick_name}' in In BlackList, ignore")
|
||||
return None
|
||||
|
||||
match_prefix = check_prefix(content, conf().get("single_chat_prefix", [""]))
|
||||
if match_prefix is not None: # 判断如果匹配到自定义前缀,则返回过滤掉前缀+空格后的内容
|
||||
content = content.replace(match_prefix, "", 1).strip()
|
||||
elif context["origin_ctype"] == ContextType.VOICE: # 如果源消息是私聊的语音消息,允许不匹配前缀,放宽条件
|
||||
pass
|
||||
else:
|
||||
return None
|
||||
|
||||
img_match_prefix = check_prefix(content, conf().get('image_create_prefix'))
|
||||
logger.info("[chat_channel]receive single chat msg, but checkprefix didn't match")
|
||||
return None
|
||||
content = content.strip()
|
||||
img_match_prefix = check_prefix(content, conf().get("image_create_prefix",[""]))
|
||||
if img_match_prefix:
|
||||
content = content.replace(img_match_prefix, '', 1).strip()
|
||||
content = content.replace(img_match_prefix, "", 1)
|
||||
context.type = ContextType.IMAGE_CREATE
|
||||
else:
|
||||
context.type = ContextType.TEXT
|
||||
context.content = content
|
||||
if 'desire_rtype' not in context and conf().get('always_reply_voice') and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE:
|
||||
context['desire_rtype'] = ReplyType.VOICE
|
||||
elif context.type == ContextType.VOICE:
|
||||
if 'desire_rtype' not in context and conf().get('voice_reply_voice') and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE:
|
||||
context['desire_rtype'] = ReplyType.VOICE
|
||||
|
||||
context.content = content.strip()
|
||||
if "desire_rtype" not in context and conf().get("always_reply_voice") and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE:
|
||||
context["desire_rtype"] = ReplyType.VOICE
|
||||
elif context.type == ContextType.VOICE:
|
||||
if "desire_rtype" not in context and conf().get("voice_reply_voice") and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE:
|
||||
context["desire_rtype"] = ReplyType.VOICE
|
||||
return context
|
||||
|
||||
def _handle(self, context: Context):
|
||||
if context is None or not context.content:
|
||||
return
|
||||
logger.debug('[WX] ready to handle context: {}'.format(context))
|
||||
logger.debug("[chat_channel] handling context: {}".format(context))
|
||||
# reply的构建步骤
|
||||
reply = self._generate_reply(context)
|
||||
|
||||
logger.debug('[WX] ready to decorate reply: {}'.format(reply))
|
||||
# reply的包装步骤
|
||||
reply = self._decorate_reply(context, reply)
|
||||
logger.debug("[chat_channel] decorating reply: {}".format(reply))
|
||||
|
||||
# reply的发送步骤
|
||||
self._send_reply(context, reply)
|
||||
# reply的包装步骤
|
||||
if reply and reply.content:
|
||||
reply = self._decorate_reply(context, reply)
|
||||
|
||||
# reply的发送步骤
|
||||
self._send_reply(context, reply)
|
||||
|
||||
def _generate_reply(self, context: Context, reply: Reply = Reply()) -> Reply:
|
||||
e_context = PluginManager().emit_event(EventContext(Event.ON_HANDLE_CONTEXT, {
|
||||
'channel': self, 'context': context, 'reply': reply}))
|
||||
reply = e_context['reply']
|
||||
e_context = PluginManager().emit_event(
|
||||
EventContext(
|
||||
Event.ON_HANDLE_CONTEXT,
|
||||
{"channel": self, "context": context, "reply": reply},
|
||||
)
|
||||
)
|
||||
reply = e_context["reply"]
|
||||
if not e_context.is_pass():
|
||||
logger.debug('[WX] ready to handle context: type={}, content={}'.format(context.type, context.content))
|
||||
logger.debug("[chat_channel] type={}, content={}".format(context.type, context.content))
|
||||
if context.type == ContextType.TEXT or context.type == ContextType.IMAGE_CREATE: # 文字和图片消息
|
||||
context["channel"] = e_context["channel"]
|
||||
reply = super().build_reply_content(context.content, context)
|
||||
elif context.type == ContextType.VOICE: # 语音消息
|
||||
cmsg = context['msg']
|
||||
cmsg = context["msg"]
|
||||
cmsg.prepare()
|
||||
file_path = context.content
|
||||
wav_path = os.path.splitext(file_path)[0] + '.wav'
|
||||
wav_path = os.path.splitext(file_path)[0] + ".wav"
|
||||
try:
|
||||
any_to_wav(file_path, wav_path)
|
||||
any_to_wav(file_path, wav_path)
|
||||
except Exception as e: # 转换失败,直接使用mp3,对于某些api,mp3也可以识别
|
||||
logger.warning("[WX]any to wav error, use raw path. " + str(e))
|
||||
logger.warning("[chat_channel]any to wav error, use raw path. " + str(e))
|
||||
wav_path = file_path
|
||||
# 语音识别
|
||||
reply = super().build_voice_to_text(wav_path)
|
||||
@@ -161,32 +223,41 @@ class ChatChannel(Channel):
|
||||
os.remove(wav_path)
|
||||
except Exception as e:
|
||||
pass
|
||||
# logger.warning("[WX]delete temp file error: " + str(e))
|
||||
# logger.warning("[chat_channel]delete temp file error: " + str(e))
|
||||
|
||||
if reply.type == ReplyType.TEXT:
|
||||
new_context = self._compose_context(
|
||||
ContextType.TEXT, reply.content, **context.kwargs)
|
||||
new_context = self._compose_context(ContextType.TEXT, reply.content, **context.kwargs)
|
||||
if new_context:
|
||||
reply = self._generate_reply(new_context)
|
||||
else:
|
||||
return
|
||||
elif context.type == ContextType.IMAGE: # 图片消息,当前无默认逻辑
|
||||
elif context.type == ContextType.IMAGE: # 图片消息,当前仅做下载保存到本地的逻辑
|
||||
memory.USER_IMAGE_CACHE[context["session_id"]] = {
|
||||
"path": context.content,
|
||||
"msg": context.get("msg")
|
||||
}
|
||||
elif context.type == ContextType.SHARING: # 分享信息,当前无默认逻辑
|
||||
pass
|
||||
elif context.type == ContextType.FUNCTION or context.type == ContextType.FILE: # 文件消息及函数调用等,当前无默认逻辑
|
||||
pass
|
||||
else:
|
||||
logger.error('[WX] unknown context type: {}'.format(context.type))
|
||||
logger.warning("[chat_channel] unknown context type: {}".format(context.type))
|
||||
return
|
||||
return reply
|
||||
|
||||
def _decorate_reply(self, context: Context, reply: Reply) -> Reply:
|
||||
if reply and reply.type:
|
||||
e_context = PluginManager().emit_event(EventContext(Event.ON_DECORATE_REPLY, {
|
||||
'channel': self, 'context': context, 'reply': reply}))
|
||||
reply = e_context['reply']
|
||||
desire_rtype = context.get('desire_rtype')
|
||||
e_context = PluginManager().emit_event(
|
||||
EventContext(
|
||||
Event.ON_DECORATE_REPLY,
|
||||
{"channel": self, "context": context, "reply": reply},
|
||||
)
|
||||
)
|
||||
reply = e_context["reply"]
|
||||
desire_rtype = context.get("desire_rtype")
|
||||
if not e_context.is_pass() and reply and reply.type:
|
||||
|
||||
if reply.type in self.NOT_SUPPORT_REPLYTYPE:
|
||||
logger.error("[WX]reply type not support: " + str(reply.type))
|
||||
logger.error("[chat_channel]reply type not support: " + str(reply.type))
|
||||
reply.type = ReplyType.ERROR
|
||||
reply.content = "不支持发送的消息类型: " + str(reply.type)
|
||||
|
||||
@@ -196,55 +267,153 @@ class ChatChannel(Channel):
|
||||
reply = super().build_text_to_voice(reply.content)
|
||||
return self._decorate_reply(context, reply)
|
||||
if context.get("isgroup", False):
|
||||
reply_text = '@' + context['msg'].actual_user_nickname + ' ' + reply_text.strip()
|
||||
reply_text = conf().get("group_chat_reply_prefix", "") + reply_text
|
||||
if not context.get("no_need_at", False):
|
||||
reply_text = "@" + context["msg"].actual_user_nickname + "\n" + reply_text.strip()
|
||||
reply_text = conf().get("group_chat_reply_prefix", "") + reply_text + conf().get("group_chat_reply_suffix", "")
|
||||
else:
|
||||
reply_text = conf().get("single_chat_reply_prefix", "") + reply_text
|
||||
reply_text = conf().get("single_chat_reply_prefix", "") + reply_text + conf().get("single_chat_reply_suffix", "")
|
||||
reply.content = reply_text
|
||||
elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO:
|
||||
reply.content = "["+str(reply.type)+"]\n" + reply.content
|
||||
elif reply.type == ReplyType.IMAGE_URL or reply.type == ReplyType.VOICE or reply.type == ReplyType.IMAGE:
|
||||
reply.content = "[" + str(reply.type) + "]\n" + reply.content
|
||||
elif reply.type == ReplyType.IMAGE_URL or reply.type == ReplyType.VOICE or reply.type == ReplyType.IMAGE or reply.type == ReplyType.FILE or reply.type == ReplyType.VIDEO or reply.type == ReplyType.VIDEO_URL:
|
||||
pass
|
||||
else:
|
||||
logger.error('[WX] unknown reply type: {}'.format(reply.type))
|
||||
logger.error("[chat_channel] unknown reply type: {}".format(reply.type))
|
||||
return
|
||||
if desire_rtype and desire_rtype != reply.type and reply.type not in [ReplyType.ERROR, ReplyType.INFO]:
|
||||
logger.warning('[WX] desire_rtype: {}, but reply type: {}'.format(context.get('desire_rtype'), reply.type))
|
||||
logger.warning("[chat_channel] desire_rtype: {}, but reply type: {}".format(context.get("desire_rtype"), reply.type))
|
||||
return reply
|
||||
|
||||
def _send_reply(self, context: Context, reply: Reply):
|
||||
if reply and reply.type:
|
||||
e_context = PluginManager().emit_event(EventContext(Event.ON_SEND_REPLY, {
|
||||
'channel': self, 'context': context, 'reply': reply}))
|
||||
reply = e_context['reply']
|
||||
e_context = PluginManager().emit_event(
|
||||
EventContext(
|
||||
Event.ON_SEND_REPLY,
|
||||
{"channel": self, "context": context, "reply": reply},
|
||||
)
|
||||
)
|
||||
reply = e_context["reply"]
|
||||
if not e_context.is_pass() and reply and reply.type:
|
||||
logger.debug('[WX] ready to send reply: {}, context: {}'.format(reply, context))
|
||||
logger.debug("[chat_channel] sending reply: {}, context: {}".format(reply, context))
|
||||
|
||||
# 如果是文本回复,尝试提取并发送图片
|
||||
if reply.type == ReplyType.TEXT:
|
||||
self._extract_and_send_images(reply, context)
|
||||
# 如果是图片回复但带有文本内容,先发文本再发图片
|
||||
elif reply.type == ReplyType.IMAGE_URL and hasattr(reply, 'text_content') and reply.text_content:
|
||||
# 先发送文本
|
||||
text_reply = Reply(ReplyType.TEXT, reply.text_content)
|
||||
self._send(text_reply, context)
|
||||
# 短暂延迟后发送图片
|
||||
time.sleep(0.3)
|
||||
self._send(reply, context)
|
||||
else:
|
||||
self._send(reply, context)
|
||||
|
||||
def _extract_and_send_images(self, reply: Reply, context: Context):
|
||||
"""
|
||||
从文本回复中提取图片/视频URL并单独发送
|
||||
支持格式:[图片: /path/to/image.png], [视频: /path/to/video.mp4], , <img src="url">
|
||||
最多发送5个媒体文件
|
||||
"""
|
||||
content = reply.content
|
||||
media_items = [] # [(url, type), ...]
|
||||
|
||||
# 正则提取各种格式的媒体URL
|
||||
patterns = [
|
||||
(r'\[图片:\s*([^\]]+)\]', 'image'), # [图片: /path/to/image.png]
|
||||
(r'\[视频:\s*([^\]]+)\]', 'video'), # [视频: /path/to/video.mp4]
|
||||
(r'!\[.*?\]\(([^\)]+)\)', 'image'), #  - 默认图片
|
||||
(r'<img[^>]+src=["\']([^"\']+)["\']', 'image'), # <img src="url">
|
||||
(r'<video[^>]+src=["\']([^"\']+)["\']', 'video'), # <video src="url">
|
||||
(r'https?://[^\s]+\.(?:jpg|jpeg|png|gif|webp)', 'image'), # 直接的图片URL
|
||||
(r'https?://[^\s]+\.(?:mp4|avi|mov|wmv|flv)', 'video'), # 直接的视频URL
|
||||
]
|
||||
|
||||
for pattern, media_type in patterns:
|
||||
matches = re.findall(pattern, content, re.IGNORECASE)
|
||||
for match in matches:
|
||||
media_items.append((match, media_type))
|
||||
|
||||
# 去重(保持顺序)并限制最多5个
|
||||
seen = set()
|
||||
unique_items = []
|
||||
for url, mtype in media_items:
|
||||
if url not in seen:
|
||||
seen.add(url)
|
||||
unique_items.append((url, mtype))
|
||||
media_items = unique_items[:5]
|
||||
|
||||
if media_items:
|
||||
logger.info(f"[chat_channel] Extracted {len(media_items)} media item(s) from reply")
|
||||
|
||||
# 先发送文本(保持原文本不变)
|
||||
logger.info(f"[chat_channel] Sending text content before media: {reply.content[:100]}...")
|
||||
self._send(reply, context)
|
||||
logger.info(f"[chat_channel] Text sent, now sending {len(media_items)} media item(s)")
|
||||
|
||||
# 然后逐个发送媒体文件
|
||||
for i, (url, media_type) in enumerate(media_items):
|
||||
try:
|
||||
# 判断是本地文件还是URL
|
||||
if url.startswith(('http://', 'https://')):
|
||||
# 网络资源
|
||||
if media_type == 'video':
|
||||
# 视频使用 FILE 类型发送
|
||||
media_reply = Reply(ReplyType.FILE, url)
|
||||
media_reply.file_name = os.path.basename(url)
|
||||
else:
|
||||
# 图片使用 IMAGE_URL 类型
|
||||
media_reply = Reply(ReplyType.IMAGE_URL, url)
|
||||
elif os.path.exists(url):
|
||||
# 本地文件
|
||||
if media_type == 'video':
|
||||
# 视频使用 FILE 类型,转换为 file:// URL
|
||||
media_reply = Reply(ReplyType.FILE, f"file://{url}")
|
||||
media_reply.file_name = os.path.basename(url)
|
||||
else:
|
||||
# 图片使用 IMAGE_URL 类型,转换为 file:// URL
|
||||
media_reply = Reply(ReplyType.IMAGE_URL, f"file://{url}")
|
||||
else:
|
||||
logger.warning(f"[chat_channel] Media file not found or invalid URL: {url}")
|
||||
continue
|
||||
|
||||
# 发送媒体文件(添加小延迟避免频率限制)
|
||||
if i > 0:
|
||||
time.sleep(0.5)
|
||||
self._send(media_reply, context)
|
||||
logger.info(f"[chat_channel] Sent {media_type} {i+1}/{len(media_items)}: {url[:50]}...")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[chat_channel] Failed to send {media_type} {url}: {e}")
|
||||
else:
|
||||
# 没有媒体文件,正常发送文本
|
||||
self._send(reply, context)
|
||||
|
||||
def _send(self, reply: Reply, context: Context, retry_cnt = 0):
|
||||
def _send(self, reply: Reply, context: Context, retry_cnt=0):
|
||||
try:
|
||||
self.send(reply, context)
|
||||
except Exception as e:
|
||||
logger.error('[WX] sendMsg error: {}'.format(str(e)))
|
||||
logger.error("[chat_channel] sendMsg error: {}".format(str(e)))
|
||||
if isinstance(e, NotImplementedError):
|
||||
return
|
||||
logger.exception(e)
|
||||
if retry_cnt < 2:
|
||||
time.sleep(3+3*retry_cnt)
|
||||
self._send(reply, context, retry_cnt+1)
|
||||
time.sleep(3 + 3 * retry_cnt)
|
||||
self._send(reply, context, retry_cnt + 1)
|
||||
|
||||
def _success_callback(self, session_id, **kwargs):# 线程正常结束时的回调函数
|
||||
def _success_callback(self, session_id, **kwargs): # 线程正常结束时的回调函数
|
||||
logger.debug("Worker return success, session_id = {}".format(session_id))
|
||||
|
||||
def _fail_callback(self, session_id, exception, **kwargs): # 线程异常结束时的回调函数
|
||||
def _fail_callback(self, session_id, exception, **kwargs): # 线程异常结束时的回调函数
|
||||
logger.exception("Worker return exception: {}".format(exception))
|
||||
|
||||
def _thread_pool_callback(self, session_id, **kwargs):
|
||||
def func(worker:Future):
|
||||
def func(worker: Future):
|
||||
try:
|
||||
worker_exception = worker.exception()
|
||||
if worker_exception:
|
||||
self._fail_callback(session_id, exception = worker_exception, **kwargs)
|
||||
self._fail_callback(session_id, exception=worker_exception, **kwargs)
|
||||
else:
|
||||
self._success_callback(session_id, **kwargs)
|
||||
except CancelledError as e:
|
||||
@@ -253,15 +422,19 @@ class ChatChannel(Channel):
|
||||
logger.exception("Worker raise exception: {}".format(e))
|
||||
with self.lock:
|
||||
self.sessions[session_id][1].release()
|
||||
|
||||
return func
|
||||
|
||||
def produce(self, context: Context):
|
||||
session_id = context['session_id']
|
||||
session_id = context["session_id"]
|
||||
with self.lock:
|
||||
if session_id not in self.sessions:
|
||||
self.sessions[session_id] = [Dequeue(), threading.BoundedSemaphore(conf().get("concurrency_in_session", 4))]
|
||||
if context.type == ContextType.TEXT and context.content.startswith("#"):
|
||||
self.sessions[session_id][0].putleft(context) # 优先处理管理命令
|
||||
self.sessions[session_id] = [
|
||||
Dequeue(),
|
||||
threading.BoundedSemaphore(conf().get("concurrency_in_session", 1)),
|
||||
]
|
||||
if context.type == ContextType.TEXT and context.content.startswith("#"):
|
||||
self.sessions[session_id][0].putleft(context) # 优先处理管理命令
|
||||
else:
|
||||
self.sessions[session_id][0].put(context)
|
||||
|
||||
@@ -270,46 +443,49 @@ class ChatChannel(Channel):
|
||||
while True:
|
||||
with self.lock:
|
||||
session_ids = list(self.sessions.keys())
|
||||
for session_id in session_ids:
|
||||
for session_id in session_ids:
|
||||
with self.lock:
|
||||
context_queue, semaphore = self.sessions[session_id]
|
||||
if semaphore.acquire(blocking = False): # 等线程处理完毕才能删除
|
||||
if not context_queue.empty():
|
||||
context = context_queue.get()
|
||||
logger.debug("[WX] consume context: {}".format(context))
|
||||
future:Future = self.handler_pool.submit(self._handle, context)
|
||||
future.add_done_callback(self._thread_pool_callback(session_id, context = context))
|
||||
if semaphore.acquire(blocking=False): # 等线程处理完毕才能删除
|
||||
if not context_queue.empty():
|
||||
context = context_queue.get()
|
||||
logger.debug("[chat_channel] consume context: {}".format(context))
|
||||
future: Future = handler_pool.submit(self._handle, context)
|
||||
future.add_done_callback(self._thread_pool_callback(session_id, context=context))
|
||||
with self.lock:
|
||||
if session_id not in self.futures:
|
||||
self.futures[session_id] = []
|
||||
self.futures[session_id].append(future)
|
||||
elif semaphore._initial_value == semaphore._value+1: # 除了当前,没有任务再申请到信号量,说明所有任务都处理完毕
|
||||
elif semaphore._initial_value == semaphore._value + 1: # 除了当前,没有任务再申请到信号量,说明所有任务都处理完毕
|
||||
with self.lock:
|
||||
self.futures[session_id] = [t for t in self.futures[session_id] if not t.done()]
|
||||
assert len(self.futures[session_id]) == 0, "thread pool error"
|
||||
del self.sessions[session_id]
|
||||
else:
|
||||
semaphore.release()
|
||||
time.sleep(0.1)
|
||||
else:
|
||||
semaphore.release()
|
||||
time.sleep(0.2)
|
||||
|
||||
# 取消session_id对应的所有任务,只能取消排队的消息和已提交线程池但未执行的任务
|
||||
def cancel_session(self, session_id):
|
||||
def cancel_session(self, session_id):
|
||||
with self.lock:
|
||||
if session_id in self.sessions:
|
||||
for future in self.futures[session_id]:
|
||||
future.cancel()
|
||||
cnt = self.sessions[session_id][0].qsize()
|
||||
if cnt>0:
|
||||
if cnt > 0:
|
||||
logger.info("Cancel {} messages in session {}".format(cnt, session_id))
|
||||
self.sessions[session_id][0] = Dequeue()
|
||||
|
||||
|
||||
def cancel_all_session(self):
|
||||
with self.lock:
|
||||
for session_id in self.sessions:
|
||||
for future in self.futures[session_id]:
|
||||
future.cancel()
|
||||
cnt = self.sessions[session_id][0].qsize()
|
||||
if cnt>0:
|
||||
if cnt > 0:
|
||||
logger.info("Cancel {} messages in session {}".format(cnt, session_id))
|
||||
self.sessions[session_id][0] = Dequeue()
|
||||
|
||||
|
||||
|
||||
def check_prefix(content, prefix_list):
|
||||
if not prefix_list:
|
||||
@@ -319,6 +495,7 @@ def check_prefix(content, prefix_list):
|
||||
return prefix
|
||||
return None
|
||||
|
||||
|
||||
def check_contain(content, keyword_list):
|
||||
if not keyword_list:
|
||||
return None
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
|
||||
"""
|
||||
本类表示聊天消息,用于对itchat和wechaty的消息进行统一的封装。
|
||||
"""
|
||||
Unified chat message class for different channel implementations.
|
||||
|
||||
填好必填项(群聊6个,非群聊8个),即可接入ChatChannel,并支持插件,参考TerminalChannel
|
||||
|
||||
@@ -20,45 +19,47 @@ other_user_id: 对方的id,如果你是发送者,那这个就是接收者id
|
||||
other_user_nickname: 同上
|
||||
|
||||
is_group: 是否是群消息 (群聊必填)
|
||||
is_at: 是否被at
|
||||
is_at: 是否被at
|
||||
|
||||
- (群消息时,一般会存在实际发送者,是群内某个成员的id和昵称,下列项仅在群消息时存在)
|
||||
actual_user_id: 实际发送者id (群聊必填)
|
||||
actual_user_nickname:实际发送者昵称
|
||||
|
||||
|
||||
|
||||
self_display_name: 自身的展示名,设置群昵称时,该字段表示群昵称
|
||||
|
||||
_prepare_fn: 准备函数,用于准备消息的内容,比如下载图片等,
|
||||
_prepared: 是否已经调用过准备函数
|
||||
_rawmsg: 原始消息对象
|
||||
|
||||
"""
|
||||
|
||||
|
||||
class ChatMessage(object):
|
||||
msg_id = None
|
||||
create_time = None
|
||||
|
||||
|
||||
ctype = None
|
||||
content = None
|
||||
|
||||
|
||||
from_user_id = None
|
||||
from_user_nickname = None
|
||||
to_user_id = None
|
||||
to_user_nickname = None
|
||||
other_user_id = None
|
||||
other_user_nickname = None
|
||||
|
||||
my_msg = False
|
||||
self_display_name = None
|
||||
|
||||
is_group = False
|
||||
is_at = False
|
||||
actual_user_id = None
|
||||
actual_user_nickname = None
|
||||
at_list = None
|
||||
|
||||
_prepare_fn = None
|
||||
_prepared = False
|
||||
_rawmsg = None
|
||||
|
||||
|
||||
def __init__(self,_rawmsg):
|
||||
def __init__(self, _rawmsg):
|
||||
self._rawmsg = _rawmsg
|
||||
|
||||
def prepare(self):
|
||||
@@ -67,7 +68,7 @@ class ChatMessage(object):
|
||||
self._prepare_fn()
|
||||
|
||||
def __str__(self):
|
||||
return 'ChatMessage: id={}, create_time={}, ctype={}, content={}, from_user_id={}, from_user_nickname={}, to_user_id={}, to_user_nickname={}, other_user_id={}, other_user_nickname={}, is_group={}, is_at={}, actual_user_id={}, actual_user_nickname={}'.format(
|
||||
return "ChatMessage: id={}, create_time={}, ctype={}, content={}, from_user_id={}, from_user_nickname={}, to_user_id={}, to_user_nickname={}, other_user_id={}, other_user_nickname={}, is_group={}, is_at={}, actual_user_id={}, actual_user_nickname={}, at_list={}".format(
|
||||
self.msg_id,
|
||||
self.create_time,
|
||||
self.ctype,
|
||||
@@ -82,4 +83,5 @@ class ChatMessage(object):
|
||||
self.is_at,
|
||||
self.actual_user_id,
|
||||
self.actual_user_nickname,
|
||||
)
|
||||
self.at_list
|
||||
)
|
||||
|
||||
970
channel/dingtalk/dingtalk_channel.py
Normal file
970
channel/dingtalk/dingtalk_channel.py
Normal file
@@ -0,0 +1,970 @@
|
||||
"""
|
||||
钉钉通道接入
|
||||
|
||||
@author huiwen
|
||||
@Date 2023/11/28
|
||||
"""
|
||||
import copy
|
||||
import json
|
||||
# -*- coding=utf-8 -*-
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
import requests
|
||||
|
||||
import dingtalk_stream
|
||||
from dingtalk_stream import AckMessage
|
||||
from dingtalk_stream.card_replier import AICardReplier
|
||||
from dingtalk_stream.card_replier import AICardStatus
|
||||
from dingtalk_stream.card_replier import CardReplier
|
||||
|
||||
from bridge.context import Context, ContextType
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from channel.chat_channel import ChatChannel
|
||||
from common.utils import expand_path
|
||||
from channel.dingtalk.dingtalk_message import DingTalkMessage
|
||||
from common.expired_dict import ExpiredDict
|
||||
from common.log import logger
|
||||
from common.singleton import singleton
|
||||
from common.time_check import time_checker
|
||||
from config import conf
|
||||
|
||||
|
||||
class CustomAICardReplier(CardReplier):
|
||||
def __init__(self, dingtalk_client, incoming_message):
|
||||
super(AICardReplier, self).__init__(dingtalk_client, incoming_message)
|
||||
|
||||
def start(
|
||||
self,
|
||||
card_template_id: str,
|
||||
card_data: dict,
|
||||
recipients: list = None,
|
||||
support_forward: bool = True,
|
||||
) -> str:
|
||||
"""
|
||||
AI卡片的创建接口
|
||||
:param support_forward:
|
||||
:param recipients:
|
||||
:param card_template_id:
|
||||
:param card_data:
|
||||
:return:
|
||||
"""
|
||||
card_data_with_status = copy.deepcopy(card_data)
|
||||
card_data_with_status["flowStatus"] = AICardStatus.PROCESSING
|
||||
return self.create_and_send_card(
|
||||
card_template_id,
|
||||
card_data_with_status,
|
||||
at_sender=True,
|
||||
at_all=False,
|
||||
recipients=recipients,
|
||||
support_forward=support_forward,
|
||||
)
|
||||
|
||||
|
||||
# 对 AICardReplier 进行猴子补丁
|
||||
AICardReplier.start = CustomAICardReplier.start
|
||||
|
||||
|
||||
def _check(func):
|
||||
def wrapper(self, cmsg: DingTalkMessage):
|
||||
msgId = cmsg.msg_id
|
||||
if msgId in self.receivedMsgs:
|
||||
logger.info("DingTalk message {} already received, ignore".format(msgId))
|
||||
return
|
||||
self.receivedMsgs[msgId] = True
|
||||
create_time = cmsg.create_time # 消息时间戳
|
||||
if conf().get("hot_reload") == True and int(create_time) < int(time.time()) - 60: # 跳过1分钟前的历史消息
|
||||
logger.debug("[DingTalk] History message {} skipped".format(msgId))
|
||||
return
|
||||
if cmsg.my_msg and not cmsg.is_group:
|
||||
logger.debug("[DingTalk] My message {} skipped".format(msgId))
|
||||
return
|
||||
return func(self, cmsg)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
@singleton
|
||||
class DingTalkChanel(ChatChannel, dingtalk_stream.ChatbotHandler):
|
||||
dingtalk_client_id = conf().get('dingtalk_client_id')
|
||||
dingtalk_client_secret = conf().get('dingtalk_client_secret')
|
||||
|
||||
def setup_logger(self):
|
||||
# Suppress verbose logs from dingtalk_stream SDK
|
||||
logging.getLogger("dingtalk_stream").setLevel(logging.WARNING)
|
||||
return logging.getLogger("DingTalk")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
super(dingtalk_stream.ChatbotHandler, self).__init__()
|
||||
self.logger = self.setup_logger()
|
||||
# 历史消息id暂存,用于幂等控制
|
||||
self.receivedMsgs = ExpiredDict(conf().get("expires_in_seconds", 3600))
|
||||
self._stream_client = None
|
||||
self._running = False
|
||||
self._event_loop = None
|
||||
logger.debug("[DingTalk] client_id={}, client_secret={} ".format(
|
||||
self.dingtalk_client_id, self.dingtalk_client_secret))
|
||||
# 无需群校验和前缀
|
||||
conf()["group_name_white_list"] = ["ALL_GROUP"]
|
||||
# 单聊无需前缀
|
||||
conf()["single_chat_prefix"] = [""]
|
||||
# Access token cache
|
||||
self._access_token = None
|
||||
self._access_token_expires_at = 0
|
||||
# Robot code cache (extracted from incoming messages)
|
||||
self._robot_code = None
|
||||
|
||||
def _open_connection(self, client):
|
||||
"""
|
||||
Open a DingTalk stream connection directly, bypassing SDK's internal error-swallowing.
|
||||
Returns (connection_dict, error_str). On success error_str is empty; on failure
|
||||
connection_dict is None and error_str contains a human-readable message.
|
||||
"""
|
||||
try:
|
||||
resp = requests.post(
|
||||
"https://api.dingtalk.com/v1.0/gateway/connections/open",
|
||||
headers={"Content-Type": "application/json", "Accept": "application/json"},
|
||||
json={
|
||||
"clientId": client.credential.client_id,
|
||||
"clientSecret": client.credential.client_secret,
|
||||
"subscriptions": [{"type": "CALLBACK",
|
||||
"topic": dingtalk_stream.chatbot.ChatbotMessage.TOPIC}],
|
||||
"ua": "dingtalk-sdk-python/cow",
|
||||
"localIp": "",
|
||||
},
|
||||
timeout=10,
|
||||
)
|
||||
body = resp.json()
|
||||
if not resp.ok:
|
||||
code = body.get("code", resp.status_code)
|
||||
message = body.get("message", resp.reason)
|
||||
return None, f"open connection failed: [{code}] {message}"
|
||||
return body, ""
|
||||
except Exception as e:
|
||||
return None, f"open connection failed: {e}"
|
||||
|
||||
def startup(self):
|
||||
import asyncio
|
||||
self.dingtalk_client_id = conf().get('dingtalk_client_id')
|
||||
self.dingtalk_client_secret = conf().get('dingtalk_client_secret')
|
||||
self._running = True
|
||||
credential = dingtalk_stream.Credential(self.dingtalk_client_id, self.dingtalk_client_secret)
|
||||
client = dingtalk_stream.DingTalkStreamClient(credential)
|
||||
self._stream_client = client
|
||||
client.register_callback_handler(dingtalk_stream.chatbot.ChatbotMessage.TOPIC, self)
|
||||
logger.info("[DingTalk] ✅ Stream client initialized, ready to receive messages")
|
||||
|
||||
# Run the connection loop ourselves instead of delegating to client.start(),
|
||||
# so we can get detailed error messages and respond to stop() quickly.
|
||||
import urllib.parse as _urlparse
|
||||
import websockets as _ws
|
||||
import json as _json
|
||||
client.pre_start()
|
||||
_first_connect = True
|
||||
while self._running:
|
||||
# Open connection using our own request so we get detailed error info.
|
||||
connection, err_msg = self._open_connection(client)
|
||||
|
||||
if connection is None:
|
||||
if _first_connect:
|
||||
logger.warning(f"[DingTalk] {err_msg}")
|
||||
self.report_startup_error(err_msg)
|
||||
_first_connect = False
|
||||
else:
|
||||
logger.warning(f"[DingTalk] {err_msg}, retrying in 10s...")
|
||||
|
||||
# Interruptible sleep: checks _running every 100ms.
|
||||
for _ in range(100):
|
||||
if not self._running:
|
||||
break
|
||||
time.sleep(0.1)
|
||||
continue
|
||||
|
||||
if _first_connect:
|
||||
logger.info("[DingTalk] ✅ Connected to DingTalk stream")
|
||||
self.report_startup_success()
|
||||
_first_connect = False
|
||||
else:
|
||||
logger.info("[DingTalk] Reconnected to DingTalk stream")
|
||||
|
||||
# Run the WebSocket session in an asyncio loop.
|
||||
uri = '%s?ticket=%s' % (
|
||||
connection['endpoint'],
|
||||
_urlparse.quote_plus(connection['ticket'])
|
||||
)
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
self._event_loop = loop
|
||||
try:
|
||||
async def _session():
|
||||
async with _ws.connect(uri) as websocket:
|
||||
client.websocket = websocket
|
||||
async for raw_message in websocket:
|
||||
json_message = _json.loads(raw_message)
|
||||
result = await client.route_message(json_message)
|
||||
if result == dingtalk_stream.DingTalkStreamClient.TAG_DISCONNECT:
|
||||
break
|
||||
|
||||
loop.run_until_complete(_session())
|
||||
except (KeyboardInterrupt, SystemExit):
|
||||
logger.info("[DingTalk] Session loop received stop signal, exiting")
|
||||
break
|
||||
except Exception as e:
|
||||
if not self._running:
|
||||
break
|
||||
logger.warning(f"[DingTalk] Stream session error: {e}, reconnecting in 3s...")
|
||||
for _ in range(30):
|
||||
if not self._running:
|
||||
break
|
||||
time.sleep(0.1)
|
||||
finally:
|
||||
self._event_loop = None
|
||||
try:
|
||||
loop.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
logger.info("[DingTalk] Startup loop exited")
|
||||
|
||||
def stop(self):
|
||||
logger.info("[DingTalk] stop() called, setting _running=False")
|
||||
self._running = False
|
||||
loop = self._event_loop
|
||||
if loop and not loop.is_closed():
|
||||
try:
|
||||
loop.call_soon_threadsafe(loop.stop)
|
||||
logger.info("[DingTalk] Sent stop signal to event loop")
|
||||
except Exception as e:
|
||||
logger.warning(f"[DingTalk] Error stopping event loop: {e}")
|
||||
self._stream_client = None
|
||||
logger.info("[DingTalk] stop() completed")
|
||||
|
||||
def get_access_token(self):
|
||||
"""
|
||||
获取企业内部应用的 access_token
|
||||
文档: https://open.dingtalk.com/document/orgapp/obtain-orgapp-token
|
||||
"""
|
||||
current_time = time.time()
|
||||
|
||||
# 如果 token 还没过期,直接返回缓存的 token
|
||||
if self._access_token and current_time < self._access_token_expires_at:
|
||||
return self._access_token
|
||||
|
||||
# 获取新的 access_token
|
||||
url = "https://api.dingtalk.com/v1.0/oauth2/accessToken"
|
||||
headers = {"Content-Type": "application/json"}
|
||||
data = {
|
||||
"appKey": self.dingtalk_client_id,
|
||||
"appSecret": self.dingtalk_client_secret
|
||||
}
|
||||
|
||||
try:
|
||||
response = requests.post(url, headers=headers, json=data, timeout=10)
|
||||
result = response.json()
|
||||
|
||||
if response.status_code == 200 and "accessToken" in result:
|
||||
self._access_token = result["accessToken"]
|
||||
# Token 有效期为 2 小时,提前 5 分钟刷新
|
||||
self._access_token_expires_at = current_time + result.get("expireIn", 7200) - 300
|
||||
logger.info("[DingTalk] Access token refreshed successfully")
|
||||
return self._access_token
|
||||
else:
|
||||
logger.error(f"[DingTalk] Failed to get access token: {result}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"[DingTalk] Error getting access token: {e}")
|
||||
return None
|
||||
|
||||
def send_single_message(self, user_id: str, content: str, robot_code: str) -> bool:
|
||||
"""
|
||||
Send message to single user (private chat)
|
||||
API: https://open.dingtalk.com/document/orgapp/chatbots-send-one-on-one-chat-messages-in-batches
|
||||
"""
|
||||
access_token = self.get_access_token()
|
||||
if not access_token:
|
||||
logger.error("[DingTalk] Failed to send single message: Access token not available.")
|
||||
return False
|
||||
|
||||
if not robot_code:
|
||||
logger.error("[DingTalk] Cannot send single message: robot_code is required")
|
||||
return False
|
||||
|
||||
url = "https://api.dingtalk.com/v1.0/robot/oToMessages/batchSend"
|
||||
headers = {
|
||||
"x-acs-dingtalk-access-token": access_token,
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
data = {
|
||||
"msgParam": json.dumps({"content": content}),
|
||||
"msgKey": "sampleText",
|
||||
"userIds": [user_id],
|
||||
"robotCode": robot_code
|
||||
}
|
||||
|
||||
logger.info(f"[DingTalk] Sending single message to user {user_id} with robot_code {robot_code}")
|
||||
try:
|
||||
response = requests.post(url, headers=headers, json=data, timeout=10)
|
||||
result = response.json()
|
||||
|
||||
if response.status_code == 200 and result.get("processQueryKey"):
|
||||
logger.info(f"[DingTalk] Single message sent successfully to {user_id}")
|
||||
return True
|
||||
else:
|
||||
logger.error(f"[DingTalk] Failed to send single message: {result}")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"[DingTalk] Error sending single message: {e}")
|
||||
return False
|
||||
|
||||
def send_group_message(self, conversation_id: str, content: str, robot_code: str = None):
|
||||
"""
|
||||
主动发送群消息
|
||||
文档: https://open.dingtalk.com/document/orgapp/the-robot-sends-a-group-message
|
||||
|
||||
Args:
|
||||
conversation_id: 会话ID (openConversationId)
|
||||
content: 消息内容
|
||||
robot_code: 机器人编码,默认使用 dingtalk_client_id
|
||||
"""
|
||||
access_token = self.get_access_token()
|
||||
if not access_token:
|
||||
logger.error("[DingTalk] Cannot send group message: no access token")
|
||||
return False
|
||||
|
||||
# Validate robot_code
|
||||
if not robot_code:
|
||||
logger.error("[DingTalk] Cannot send group message: robot_code is required")
|
||||
return False
|
||||
|
||||
url = "https://api.dingtalk.com/v1.0/robot/groupMessages/send"
|
||||
headers = {
|
||||
"x-acs-dingtalk-access-token": access_token,
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
data = {
|
||||
"msgParam": json.dumps({"content": content}),
|
||||
"msgKey": "sampleText",
|
||||
"openConversationId": conversation_id,
|
||||
"robotCode": robot_code
|
||||
}
|
||||
|
||||
try:
|
||||
response = requests.post(url, headers=headers, json=data, timeout=10)
|
||||
result = response.json()
|
||||
|
||||
if response.status_code == 200:
|
||||
logger.info(f"[DingTalk] Group message sent successfully to {conversation_id}")
|
||||
return True
|
||||
else:
|
||||
logger.error(f"[DingTalk] Failed to send group message: {result}")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"[DingTalk] Error sending group message: {e}")
|
||||
return False
|
||||
|
||||
def upload_media(self, file_path: str, media_type: str = "image") -> str:
|
||||
"""
|
||||
上传媒体文件到钉钉
|
||||
|
||||
Args:
|
||||
file_path: 本地文件路径或URL
|
||||
media_type: 媒体类型 (image, video, voice, file)
|
||||
|
||||
Returns:
|
||||
media_id,如果上传失败返回 None
|
||||
"""
|
||||
access_token = self.get_access_token()
|
||||
if not access_token:
|
||||
logger.error("[DingTalk] Cannot upload media: no access token")
|
||||
return None
|
||||
|
||||
# 处理 file:// URL
|
||||
if file_path.startswith("file://"):
|
||||
file_path = file_path[7:]
|
||||
|
||||
# 如果是 HTTP URL,先下载
|
||||
if file_path.startswith("http://") or file_path.startswith("https://"):
|
||||
try:
|
||||
import uuid
|
||||
response = requests.get(file_path, timeout=(5, 60))
|
||||
if response.status_code != 200:
|
||||
logger.error(f"[DingTalk] Failed to download file from URL: {file_path}")
|
||||
return None
|
||||
|
||||
# 保存到临时文件
|
||||
file_name = os.path.basename(file_path) or f"media_{uuid.uuid4()}"
|
||||
workspace_root = expand_path(conf().get("agent_workspace", "~/cow"))
|
||||
tmp_dir = os.path.join(workspace_root, "tmp")
|
||||
os.makedirs(tmp_dir, exist_ok=True)
|
||||
temp_file = os.path.join(tmp_dir, file_name)
|
||||
|
||||
with open(temp_file, "wb") as f:
|
||||
f.write(response.content)
|
||||
|
||||
file_path = temp_file
|
||||
logger.info(f"[DingTalk] Downloaded file to {file_path}")
|
||||
except Exception as e:
|
||||
logger.error(f"[DingTalk] Error downloading file: {e}")
|
||||
return None
|
||||
|
||||
if not os.path.exists(file_path):
|
||||
logger.error(f"[DingTalk] File not found: {file_path}")
|
||||
return None
|
||||
|
||||
# 上传到钉钉
|
||||
# 钉钉上传媒体文件 API: https://open.dingtalk.com/document/orgapp/upload-media-files
|
||||
url = "https://oapi.dingtalk.com/media/upload"
|
||||
params = {
|
||||
"access_token": access_token,
|
||||
"type": media_type
|
||||
}
|
||||
|
||||
try:
|
||||
with open(file_path, "rb") as f:
|
||||
files = {"media": (os.path.basename(file_path), f)}
|
||||
response = requests.post(url, params=params, files=files, timeout=(5, 60))
|
||||
result = response.json()
|
||||
|
||||
if result.get("errcode") == 0:
|
||||
media_id = result.get("media_id")
|
||||
logger.info(f"[DingTalk] Media uploaded successfully, media_id={media_id}")
|
||||
return media_id
|
||||
else:
|
||||
logger.error(f"[DingTalk] Failed to upload media: {result}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"[DingTalk] Error uploading media: {e}")
|
||||
return None
|
||||
|
||||
def send_image_with_media_id(self, access_token: str, media_id: str, incoming_message, is_group: bool) -> bool:
|
||||
"""
|
||||
发送图片消息(使用 media_id)
|
||||
|
||||
Args:
|
||||
access_token: 访问令牌
|
||||
media_id: 媒体ID
|
||||
incoming_message: 钉钉消息对象
|
||||
is_group: 是否为群聊
|
||||
|
||||
Returns:
|
||||
是否发送成功
|
||||
"""
|
||||
headers = {
|
||||
"x-acs-dingtalk-access-token": access_token,
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
|
||||
msg_param = {
|
||||
"photoURL": media_id # 钉钉图片消息使用 photoURL 字段
|
||||
}
|
||||
|
||||
body = {
|
||||
"robotCode": incoming_message.robot_code,
|
||||
"msgKey": "sampleImageMsg",
|
||||
"msgParam": json.dumps(msg_param),
|
||||
}
|
||||
|
||||
if is_group:
|
||||
# 群聊
|
||||
url = "https://api.dingtalk.com/v1.0/robot/groupMessages/send"
|
||||
body["openConversationId"] = incoming_message.conversation_id
|
||||
else:
|
||||
# 单聊
|
||||
url = "https://api.dingtalk.com/v1.0/robot/oToMessages/batchSend"
|
||||
body["userIds"] = [incoming_message.sender_staff_id]
|
||||
|
||||
try:
|
||||
response = requests.post(url=url, headers=headers, json=body, timeout=10)
|
||||
result = response.json()
|
||||
|
||||
logger.info(f"[DingTalk] Image send result: {response.text}")
|
||||
|
||||
if response.status_code == 200:
|
||||
return True
|
||||
else:
|
||||
logger.error(f"[DingTalk] Send image error: {response.text}")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"[DingTalk] Send image exception: {e}")
|
||||
return False
|
||||
|
||||
def send_image_message(self, receiver: str, media_id: str, is_group: bool, robot_code: str) -> bool:
|
||||
"""
|
||||
发送图片消息
|
||||
|
||||
Args:
|
||||
receiver: 接收者ID (user_id 或 conversation_id)
|
||||
media_id: 媒体ID
|
||||
is_group: 是否为群聊
|
||||
robot_code: 机器人编码
|
||||
|
||||
Returns:
|
||||
是否发送成功
|
||||
"""
|
||||
access_token = self.get_access_token()
|
||||
if not access_token:
|
||||
logger.error("[DingTalk] Cannot send image: no access token")
|
||||
return False
|
||||
|
||||
if not robot_code:
|
||||
logger.error("[DingTalk] Cannot send image: robot_code is required")
|
||||
return False
|
||||
|
||||
if is_group:
|
||||
# 发送群聊图片
|
||||
url = "https://api.dingtalk.com/v1.0/robot/groupMessages/send"
|
||||
headers = {
|
||||
"x-acs-dingtalk-access-token": access_token,
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
data = {
|
||||
"msgParam": json.dumps({"mediaId": media_id}),
|
||||
"msgKey": "sampleImageMsg",
|
||||
"openConversationId": receiver,
|
||||
"robotCode": robot_code
|
||||
}
|
||||
else:
|
||||
# 发送单聊图片
|
||||
url = "https://api.dingtalk.com/v1.0/robot/oToMessages/batchSend"
|
||||
headers = {
|
||||
"x-acs-dingtalk-access-token": access_token,
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
data = {
|
||||
"msgParam": json.dumps({"mediaId": media_id}),
|
||||
"msgKey": "sampleImageMsg",
|
||||
"userIds": [receiver],
|
||||
"robotCode": robot_code
|
||||
}
|
||||
|
||||
try:
|
||||
response = requests.post(url, headers=headers, json=data, timeout=10)
|
||||
result = response.json()
|
||||
|
||||
if response.status_code == 200:
|
||||
logger.info(f"[DingTalk] Image message sent successfully")
|
||||
return True
|
||||
else:
|
||||
logger.error(f"[DingTalk] Failed to send image message: {result}")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"[DingTalk] Error sending image message: {e}")
|
||||
return False
|
||||
|
||||
def get_image_download_url(self, download_code: str) -> str:
|
||||
"""
|
||||
获取图片下载地址
|
||||
返回一个特殊的 URL 格式:dingtalk://download/{robot_code}:{download_code}
|
||||
后续会在 download_image_file 中使用新版 API 下载
|
||||
"""
|
||||
# 获取 robot_code
|
||||
if not hasattr(self, '_robot_code_cache'):
|
||||
self._robot_code_cache = None
|
||||
|
||||
robot_code = self._robot_code_cache
|
||||
|
||||
if not robot_code:
|
||||
logger.error("[DingTalk] robot_code not available for image download")
|
||||
return None
|
||||
|
||||
# 返回一个特殊的 URL,包含 robot_code 和 download_code
|
||||
logger.info(f"[DingTalk] Successfully got image download URL for code: {download_code}")
|
||||
return f"dingtalk://download/{robot_code}:{download_code}"
|
||||
|
||||
async def process(self, callback: dingtalk_stream.CallbackMessage):
|
||||
try:
|
||||
incoming_message = dingtalk_stream.ChatbotMessage.from_dict(callback.data)
|
||||
|
||||
# 缓存 robot_code,用于后续图片下载
|
||||
if hasattr(incoming_message, 'robot_code'):
|
||||
self._robot_code_cache = incoming_message.robot_code
|
||||
|
||||
# Filter out stale messages from before channel startup (offline backlog)
|
||||
create_at = getattr(incoming_message, 'create_at', None)
|
||||
if create_at:
|
||||
msg_age_s = time.time() - int(create_at) / 1000
|
||||
if msg_age_s > 60:
|
||||
logger.warning(f"[DingTalk] stale msg filtered (age={msg_age_s:.0f}s), "
|
||||
f"msg_id={getattr(incoming_message, 'message_id', 'N/A')}")
|
||||
return AckMessage.STATUS_OK, 'OK'
|
||||
|
||||
image_download_handler = self
|
||||
dingtalk_msg = DingTalkMessage(incoming_message, image_download_handler)
|
||||
|
||||
if dingtalk_msg.is_group:
|
||||
self.handle_group(dingtalk_msg)
|
||||
else:
|
||||
self.handle_single(dingtalk_msg)
|
||||
return AckMessage.STATUS_OK, 'OK'
|
||||
except Exception as e:
|
||||
logger.error(f"[DingTalk] process error: {e}", exc_info=True)
|
||||
return AckMessage.STATUS_SYSTEM_EXCEPTION, 'ERROR'
|
||||
|
||||
@time_checker
|
||||
@_check
|
||||
def handle_single(self, cmsg: DingTalkMessage):
|
||||
# 处理单聊消息
|
||||
if cmsg.ctype == ContextType.VOICE:
|
||||
logger.debug("[DingTalk]receive voice msg: {}".format(cmsg.content))
|
||||
elif cmsg.ctype == ContextType.IMAGE:
|
||||
logger.debug("[DingTalk]receive image msg: {}".format(cmsg.content))
|
||||
elif cmsg.ctype == ContextType.IMAGE_CREATE:
|
||||
logger.debug("[DingTalk]receive image create msg: {}".format(cmsg.content))
|
||||
elif cmsg.ctype == ContextType.PATPAT:
|
||||
logger.debug("[DingTalk]receive patpat msg: {}".format(cmsg.content))
|
||||
elif cmsg.ctype == ContextType.TEXT:
|
||||
logger.debug("[DingTalk]receive text msg: {}".format(cmsg.content))
|
||||
else:
|
||||
logger.debug("[DingTalk]receive other msg: {}".format(cmsg.content))
|
||||
|
||||
# 处理文件缓存逻辑
|
||||
from channel.file_cache import get_file_cache
|
||||
file_cache = get_file_cache()
|
||||
|
||||
# 单聊的 session_id 就是 sender_id
|
||||
session_id = cmsg.from_user_id
|
||||
|
||||
# 如果是单张图片消息,缓存起来
|
||||
if cmsg.ctype == ContextType.IMAGE:
|
||||
if hasattr(cmsg, 'image_path') and cmsg.image_path:
|
||||
file_cache.add(session_id, cmsg.image_path, file_type='image')
|
||||
logger.info(f"[DingTalk] Image cached for session {session_id}, waiting for user query...")
|
||||
# 单张图片不直接处理,等待用户提问
|
||||
return
|
||||
|
||||
# 如果是文本消息,检查是否有缓存的文件
|
||||
if cmsg.ctype == ContextType.TEXT:
|
||||
cached_files = file_cache.get(session_id)
|
||||
if cached_files:
|
||||
# 将缓存的文件附加到文本消息中
|
||||
file_refs = []
|
||||
for file_info in cached_files:
|
||||
file_path = file_info['path']
|
||||
file_type = file_info['type']
|
||||
if file_type == 'image':
|
||||
file_refs.append(f"[图片: {file_path}]")
|
||||
elif file_type == 'video':
|
||||
file_refs.append(f"[视频: {file_path}]")
|
||||
else:
|
||||
file_refs.append(f"[文件: {file_path}]")
|
||||
|
||||
cmsg.content = cmsg.content + "\n" + "\n".join(file_refs)
|
||||
logger.info(f"[DingTalk] Attached {len(cached_files)} cached file(s) to user query")
|
||||
# 清除缓存
|
||||
file_cache.clear(session_id)
|
||||
|
||||
context = self._compose_context(cmsg.ctype, cmsg.content, isgroup=False, msg=cmsg)
|
||||
if context:
|
||||
self.produce(context)
|
||||
|
||||
|
||||
@time_checker
|
||||
@_check
|
||||
def handle_group(self, cmsg: DingTalkMessage):
|
||||
# 处理群聊消息
|
||||
if cmsg.ctype == ContextType.VOICE:
|
||||
logger.debug("[DingTalk]receive voice msg: {}".format(cmsg.content))
|
||||
elif cmsg.ctype == ContextType.IMAGE:
|
||||
logger.debug("[DingTalk]receive image msg: {}".format(cmsg.content))
|
||||
elif cmsg.ctype == ContextType.IMAGE_CREATE:
|
||||
logger.debug("[DingTalk]receive image create msg: {}".format(cmsg.content))
|
||||
elif cmsg.ctype == ContextType.PATPAT:
|
||||
logger.debug("[DingTalk]receive patpat msg: {}".format(cmsg.content))
|
||||
elif cmsg.ctype == ContextType.TEXT:
|
||||
logger.debug("[DingTalk]receive text msg: {}".format(cmsg.content))
|
||||
else:
|
||||
logger.debug("[DingTalk]receive other msg: {}".format(cmsg.content))
|
||||
|
||||
# 处理文件缓存逻辑
|
||||
from channel.file_cache import get_file_cache
|
||||
file_cache = get_file_cache()
|
||||
|
||||
# 群聊的 session_id
|
||||
if conf().get("group_shared_session", True):
|
||||
session_id = cmsg.other_user_id # conversation_id
|
||||
else:
|
||||
session_id = cmsg.from_user_id + "_" + cmsg.other_user_id
|
||||
|
||||
# 如果是单张图片消息,缓存起来
|
||||
if cmsg.ctype == ContextType.IMAGE:
|
||||
if hasattr(cmsg, 'image_path') and cmsg.image_path:
|
||||
file_cache.add(session_id, cmsg.image_path, file_type='image')
|
||||
logger.info(f"[DingTalk] Image cached for session {session_id}, waiting for user query...")
|
||||
# 单张图片不直接处理,等待用户提问
|
||||
return
|
||||
|
||||
# 如果是文本消息,检查是否有缓存的文件
|
||||
if cmsg.ctype == ContextType.TEXT:
|
||||
cached_files = file_cache.get(session_id)
|
||||
if cached_files:
|
||||
# 将缓存的文件附加到文本消息中
|
||||
file_refs = []
|
||||
for file_info in cached_files:
|
||||
file_path = file_info['path']
|
||||
file_type = file_info['type']
|
||||
if file_type == 'image':
|
||||
file_refs.append(f"[图片: {file_path}]")
|
||||
elif file_type == 'video':
|
||||
file_refs.append(f"[视频: {file_path}]")
|
||||
else:
|
||||
file_refs.append(f"[文件: {file_path}]")
|
||||
|
||||
cmsg.content = cmsg.content + "\n" + "\n".join(file_refs)
|
||||
logger.info(f"[DingTalk] Attached {len(cached_files)} cached file(s) to user query")
|
||||
# 清除缓存
|
||||
file_cache.clear(session_id)
|
||||
|
||||
context = self._compose_context(cmsg.ctype, cmsg.content, isgroup=True, msg=cmsg)
|
||||
context['no_need_at'] = True
|
||||
if context:
|
||||
self.produce(context)
|
||||
|
||||
|
||||
def send(self, reply: Reply, context: Context):
|
||||
logger.debug(f"[DingTalk] send() called with reply.type={reply.type}, content_length={len(str(reply.content))}")
|
||||
receiver = context["receiver"]
|
||||
|
||||
# Check if msg exists (for scheduled tasks, msg might be None)
|
||||
msg = context.kwargs.get('msg')
|
||||
if msg is None:
|
||||
# 定时任务场景:使用主动发送 API
|
||||
is_group = context.get("isgroup", False)
|
||||
logger.info(f"[DingTalk] Sending scheduled task message to {receiver} (is_group={is_group})")
|
||||
|
||||
# 使用缓存的 robot_code 或配置的值
|
||||
robot_code = self._robot_code or conf().get("dingtalk_robot_code")
|
||||
logger.info(f"[DingTalk] Using robot_code: {robot_code}, cached: {self._robot_code}, config: {conf().get('dingtalk_robot_code')}")
|
||||
|
||||
if not robot_code:
|
||||
logger.error(f"[DingTalk] Cannot send scheduled task: robot_code not available. Please send at least one message to the bot first, or configure dingtalk_robot_code in config.json")
|
||||
return
|
||||
|
||||
# 根据是否群聊选择不同的 API
|
||||
if is_group:
|
||||
success = self.send_group_message(receiver, reply.content, robot_code)
|
||||
else:
|
||||
# 单聊场景:尝试从 context 中获取 dingtalk_sender_staff_id
|
||||
sender_staff_id = context.get("dingtalk_sender_staff_id")
|
||||
if not sender_staff_id:
|
||||
logger.error(f"[DingTalk] Cannot send single chat scheduled message: sender_staff_id not available in context")
|
||||
return
|
||||
|
||||
logger.info(f"[DingTalk] Sending single message to staff_id: {sender_staff_id}")
|
||||
success = self.send_single_message(sender_staff_id, reply.content, robot_code)
|
||||
|
||||
if not success:
|
||||
logger.error(f"[DingTalk] Failed to send scheduled task message")
|
||||
return
|
||||
|
||||
# 从正常消息中提取并缓存 robot_code
|
||||
if hasattr(msg, 'robot_code'):
|
||||
robot_code = msg.robot_code
|
||||
if robot_code and robot_code != self._robot_code:
|
||||
self._robot_code = robot_code
|
||||
logger.debug(f"[DingTalk] Cached robot_code: {robot_code}")
|
||||
|
||||
isgroup = msg.is_group
|
||||
incoming_message = msg.incoming_message
|
||||
robot_code = self._robot_code or conf().get("dingtalk_robot_code")
|
||||
|
||||
# 处理图片和视频发送
|
||||
if reply.type == ReplyType.IMAGE_URL:
|
||||
logger.info(f"[DingTalk] Sending image: {reply.content}")
|
||||
|
||||
# 如果有附加的文本内容,先发送文本
|
||||
if hasattr(reply, 'text_content') and reply.text_content:
|
||||
self.reply_text(reply.text_content, incoming_message)
|
||||
import time
|
||||
time.sleep(0.3) # 短暂延迟,确保文本先到达
|
||||
|
||||
media_id = self.upload_media(reply.content, media_type="image")
|
||||
if media_id:
|
||||
# 使用主动发送 API 发送图片
|
||||
access_token = self.get_access_token()
|
||||
if access_token:
|
||||
success = self.send_image_with_media_id(
|
||||
access_token,
|
||||
media_id,
|
||||
incoming_message,
|
||||
isgroup
|
||||
)
|
||||
if not success:
|
||||
logger.error("[DingTalk] Failed to send image message")
|
||||
self.reply_text("抱歉,图片发送失败", incoming_message)
|
||||
else:
|
||||
logger.error("[DingTalk] Cannot get access token")
|
||||
self.reply_text("抱歉,图片发送失败(无法获取token)", incoming_message)
|
||||
else:
|
||||
logger.error("[DingTalk] Failed to upload image")
|
||||
self.reply_text("抱歉,图片上传失败", incoming_message)
|
||||
return
|
||||
|
||||
elif reply.type == ReplyType.FILE:
|
||||
# 如果有附加的文本内容,先发送文本
|
||||
if hasattr(reply, 'text_content') and reply.text_content:
|
||||
self.reply_text(reply.text_content, incoming_message)
|
||||
import time
|
||||
time.sleep(0.3) # 短暂延迟,确保文本先到达
|
||||
|
||||
# 判断是否为视频文件
|
||||
file_path = reply.content
|
||||
if file_path.startswith("file://"):
|
||||
file_path = file_path[7:]
|
||||
|
||||
is_video = file_path.lower().endswith(('.mp4', '.avi', '.mov', '.wmv', '.flv'))
|
||||
|
||||
access_token = self.get_access_token()
|
||||
if not access_token:
|
||||
logger.error("[DingTalk] Cannot get access token")
|
||||
self.reply_text("抱歉,文件发送失败(无法获取token)", incoming_message)
|
||||
return
|
||||
|
||||
if is_video:
|
||||
logger.info(f"[DingTalk] Sending video: {reply.content}")
|
||||
media_id = self.upload_media(reply.content, media_type="video")
|
||||
if media_id:
|
||||
# 发送视频消息
|
||||
msg_param = {
|
||||
"duration": "30", # TODO: 获取实际视频时长
|
||||
"videoMediaId": media_id,
|
||||
"videoType": "mp4",
|
||||
"height": "400",
|
||||
"width": "600",
|
||||
}
|
||||
success = self._send_file_message(
|
||||
access_token,
|
||||
incoming_message,
|
||||
"sampleVideo",
|
||||
msg_param,
|
||||
isgroup
|
||||
)
|
||||
if not success:
|
||||
self.reply_text("抱歉,视频发送失败", incoming_message)
|
||||
else:
|
||||
logger.error("[DingTalk] Failed to upload video")
|
||||
self.reply_text("抱歉,视频上传失败", incoming_message)
|
||||
else:
|
||||
# 其他文件类型
|
||||
logger.info(f"[DingTalk] Sending file: {reply.content}")
|
||||
media_id = self.upload_media(reply.content, media_type="file")
|
||||
if media_id:
|
||||
file_name = os.path.basename(file_path)
|
||||
file_base, file_extension = os.path.splitext(file_name)
|
||||
msg_param = {
|
||||
"mediaId": media_id,
|
||||
"fileName": file_name,
|
||||
"fileType": file_extension[1:] if file_extension else "file"
|
||||
}
|
||||
success = self._send_file_message(
|
||||
access_token,
|
||||
incoming_message,
|
||||
"sampleFile",
|
||||
msg_param,
|
||||
isgroup
|
||||
)
|
||||
if not success:
|
||||
self.reply_text("抱歉,文件发送失败", incoming_message)
|
||||
else:
|
||||
logger.error("[DingTalk] Failed to upload file")
|
||||
self.reply_text("抱歉,文件上传失败", incoming_message)
|
||||
return
|
||||
|
||||
# 处理文本消息
|
||||
elif reply.type == ReplyType.TEXT:
|
||||
logger.info(f"[DingTalk] Sending text message, length={len(reply.content)}")
|
||||
if conf().get("dingtalk_card_enabled"):
|
||||
logger.info("[Dingtalk] sendMsg={}, receiver={}".format(reply, receiver))
|
||||
def reply_with_text():
|
||||
self.reply_text(reply.content, incoming_message)
|
||||
def reply_with_at_text():
|
||||
self.reply_text("📢 您有一条新的消息,请查看。", incoming_message)
|
||||
def reply_with_ai_markdown():
|
||||
button_list, markdown_content = self.generate_button_markdown_content(context, reply)
|
||||
self.reply_ai_markdown_button(incoming_message, markdown_content, button_list, "", "📌 内容由AI生成", "",[incoming_message.sender_staff_id])
|
||||
|
||||
if reply.type in [ReplyType.IMAGE_URL, ReplyType.IMAGE, ReplyType.TEXT]:
|
||||
if isgroup:
|
||||
reply_with_ai_markdown()
|
||||
reply_with_at_text()
|
||||
else:
|
||||
reply_with_ai_markdown()
|
||||
else:
|
||||
# 暂不支持其它类型消息回复
|
||||
reply_with_text()
|
||||
else:
|
||||
self.reply_text(reply.content, incoming_message)
|
||||
return
|
||||
|
||||
def _send_file_message(self, access_token: str, incoming_message, msg_key: str, msg_param: dict, is_group: bool) -> bool:
|
||||
"""
|
||||
发送文件/视频消息的通用方法
|
||||
|
||||
Args:
|
||||
access_token: 访问令牌
|
||||
incoming_message: 钉钉消息对象
|
||||
msg_key: 消息类型 (sampleFile, sampleVideo, sampleAudio)
|
||||
msg_param: 消息参数
|
||||
is_group: 是否为群聊
|
||||
|
||||
Returns:
|
||||
是否发送成功
|
||||
"""
|
||||
headers = {
|
||||
"x-acs-dingtalk-access-token": access_token,
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
|
||||
body = {
|
||||
"robotCode": incoming_message.robot_code,
|
||||
"msgKey": msg_key,
|
||||
"msgParam": json.dumps(msg_param),
|
||||
}
|
||||
|
||||
if is_group:
|
||||
# 群聊
|
||||
url = "https://api.dingtalk.com/v1.0/robot/groupMessages/send"
|
||||
body["openConversationId"] = incoming_message.conversation_id
|
||||
else:
|
||||
# 单聊
|
||||
url = "https://api.dingtalk.com/v1.0/robot/oToMessages/batchSend"
|
||||
body["userIds"] = [incoming_message.sender_staff_id]
|
||||
|
||||
try:
|
||||
response = requests.post(url=url, headers=headers, json=body, timeout=10)
|
||||
result = response.json()
|
||||
|
||||
logger.info(f"[DingTalk] File send result: {response.text}")
|
||||
|
||||
if response.status_code == 200:
|
||||
return True
|
||||
else:
|
||||
logger.error(f"[DingTalk] Send file error: {response.text}")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"[DingTalk] Send file exception: {e}")
|
||||
return False
|
||||
|
||||
def generate_button_markdown_content(self, context, reply):
|
||||
image_url = context.kwargs.get("image_url")
|
||||
promptEn = context.kwargs.get("promptEn")
|
||||
reply_text = reply.content
|
||||
button_list = []
|
||||
markdown_content = f"""
|
||||
{reply.content}
|
||||
"""
|
||||
if image_url is not None and promptEn is not None:
|
||||
button_list = [
|
||||
{"text": "查看原图", "url": image_url, "iosUrl": image_url, "color": "blue"}
|
||||
]
|
||||
markdown_content = f"""
|
||||
{promptEn}
|
||||
|
||||

|
||||
|
||||
{reply_text}
|
||||
|
||||
"""
|
||||
logger.debug(f"[Dingtalk] generate_button_markdown_content, button_list={button_list} , markdown_content={markdown_content}")
|
||||
|
||||
return button_list, markdown_content
|
||||
244
channel/dingtalk/dingtalk_message.py
Normal file
244
channel/dingtalk/dingtalk_message.py
Normal file
@@ -0,0 +1,244 @@
|
||||
import os
|
||||
import re
|
||||
|
||||
import requests
|
||||
from dingtalk_stream import ChatbotMessage
|
||||
|
||||
from bridge.context import ContextType
|
||||
from channel.chat_message import ChatMessage
|
||||
# -*- coding=utf-8 -*-
|
||||
from common.log import logger
|
||||
from common.tmp_dir import TmpDir
|
||||
from common.utils import expand_path
|
||||
from config import conf
|
||||
|
||||
|
||||
class DingTalkMessage(ChatMessage):
|
||||
def __init__(self, event: ChatbotMessage, image_download_handler):
|
||||
super().__init__(event)
|
||||
self.image_download_handler = image_download_handler
|
||||
self.msg_id = event.message_id
|
||||
self.message_type = event.message_type
|
||||
self.incoming_message = event
|
||||
self.sender_staff_id = event.sender_staff_id
|
||||
self.other_user_id = event.conversation_id
|
||||
self.create_time = event.create_at
|
||||
self.image_content = event.image_content
|
||||
self.rich_text_content = event.rich_text_content
|
||||
self.robot_code = event.robot_code # 机器人编码
|
||||
if event.conversation_type == "1":
|
||||
self.is_group = False
|
||||
else:
|
||||
self.is_group = True
|
||||
|
||||
if self.message_type == "text":
|
||||
self.ctype = ContextType.TEXT
|
||||
|
||||
self.content = event.text.content.strip()
|
||||
elif self.message_type == "audio":
|
||||
# 钉钉支持直接识别语音,所以此处将直接提取文字,当文字处理
|
||||
self.content = event.extensions['content']['recognition'].strip()
|
||||
self.ctype = ContextType.TEXT
|
||||
elif (self.message_type == 'picture') or (self.message_type == 'richText'):
|
||||
# 钉钉图片类型或富文本类型消息处理
|
||||
image_list = event.get_image_list()
|
||||
|
||||
if self.message_type == 'picture' and len(image_list) > 0:
|
||||
# 单张图片消息:下载到工作空间,用于文件缓存
|
||||
self.ctype = ContextType.IMAGE
|
||||
download_code = image_list[0]
|
||||
download_url = image_download_handler.get_image_download_url(download_code)
|
||||
|
||||
# 下载到工作空间 tmp 目录
|
||||
workspace_root = expand_path(conf().get("agent_workspace", "~/cow"))
|
||||
tmp_dir = os.path.join(workspace_root, "tmp")
|
||||
os.makedirs(tmp_dir, exist_ok=True)
|
||||
|
||||
image_path = download_image_file(download_url, tmp_dir)
|
||||
if image_path:
|
||||
self.content = image_path
|
||||
self.image_path = image_path # 保存图片路径用于缓存
|
||||
logger.info(f"[DingTalk] Downloaded single image to {image_path}")
|
||||
else:
|
||||
self.content = "[图片下载失败]"
|
||||
self.image_path = None
|
||||
|
||||
elif self.message_type == 'richText' and len(image_list) > 0:
|
||||
# 富文本消息:下载所有图片并附加到文本中
|
||||
self.ctype = ContextType.TEXT
|
||||
|
||||
# 下载到工作空间 tmp 目录
|
||||
workspace_root = expand_path(conf().get("agent_workspace", "~/cow"))
|
||||
tmp_dir = os.path.join(workspace_root, "tmp")
|
||||
os.makedirs(tmp_dir, exist_ok=True)
|
||||
|
||||
# 提取富文本中的文本内容
|
||||
text_content = ""
|
||||
if self.rich_text_content:
|
||||
# rich_text_content 是一个 RichTextContent 对象,需要从中提取文本
|
||||
text_list = event.get_text_list()
|
||||
if text_list:
|
||||
text_content = "".join(text_list).strip()
|
||||
|
||||
# 下载所有图片
|
||||
image_paths = []
|
||||
for download_code in image_list:
|
||||
download_url = image_download_handler.get_image_download_url(download_code)
|
||||
image_path = download_image_file(download_url, tmp_dir)
|
||||
if image_path:
|
||||
image_paths.append(image_path)
|
||||
|
||||
# 构建消息内容:文本 + 图片路径
|
||||
content_parts = []
|
||||
if text_content:
|
||||
content_parts.append(text_content)
|
||||
for img_path in image_paths:
|
||||
content_parts.append(f"[图片: {img_path}]")
|
||||
|
||||
self.content = "\n".join(content_parts) if content_parts else "[富文本消息]"
|
||||
logger.info(f"[DingTalk] Received richText with {len(image_paths)} image(s): {self.content}")
|
||||
else:
|
||||
self.ctype = ContextType.IMAGE
|
||||
self.content = "[未找到图片]"
|
||||
logger.debug(f"[DingTalk] messageType: {self.message_type}, imageList isEmpty")
|
||||
|
||||
if self.is_group:
|
||||
self.from_user_id = event.conversation_id
|
||||
self.actual_user_id = event.sender_id
|
||||
self.is_at = True
|
||||
else:
|
||||
self.from_user_id = event.sender_id
|
||||
self.actual_user_id = event.sender_id
|
||||
self.to_user_id = event.chatbot_user_id
|
||||
self.other_user_nickname = event.conversation_title
|
||||
|
||||
|
||||
def download_image_file(image_url, temp_dir):
|
||||
"""
|
||||
下载图片文件
|
||||
支持两种方式:
|
||||
1. 普通 HTTP(S) URL
|
||||
2. 钉钉 downloadCode: dingtalk://download/{download_code}
|
||||
"""
|
||||
# 检查临时目录是否存在,如果不存在则创建
|
||||
if not os.path.exists(temp_dir):
|
||||
os.makedirs(temp_dir)
|
||||
|
||||
# 处理钉钉 downloadCode
|
||||
if image_url.startswith("dingtalk://download/"):
|
||||
download_code = image_url.replace("dingtalk://download/", "")
|
||||
logger.info(f"[DingTalk] Downloading image with downloadCode: {download_code[:20]}...")
|
||||
|
||||
# 需要从外部传入 access_token,这里先用一个临时方案
|
||||
# 从 config 获取 dingtalk_client_id 和 dingtalk_client_secret
|
||||
from config import conf
|
||||
client_id = conf().get("dingtalk_client_id")
|
||||
client_secret = conf().get("dingtalk_client_secret")
|
||||
|
||||
if not client_id or not client_secret:
|
||||
logger.error("[DingTalk] Missing dingtalk_client_id or dingtalk_client_secret")
|
||||
return None
|
||||
|
||||
# 解析 robot_code 和 download_code
|
||||
parts = download_code.split(":", 1)
|
||||
if len(parts) != 2:
|
||||
logger.error(f"[DingTalk] Invalid download_code format (expected robot_code:download_code): {download_code[:50]}")
|
||||
return None
|
||||
|
||||
robot_code, actual_download_code = parts
|
||||
|
||||
# 获取 access_token(使用新版 API)
|
||||
token_url = "https://api.dingtalk.com/v1.0/oauth2/accessToken"
|
||||
token_headers = {
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
token_body = {
|
||||
"appKey": client_id,
|
||||
"appSecret": client_secret
|
||||
}
|
||||
|
||||
try:
|
||||
token_response = requests.post(token_url, json=token_body, headers=token_headers, timeout=10)
|
||||
|
||||
if token_response.status_code == 200:
|
||||
token_data = token_response.json()
|
||||
access_token = token_data.get("accessToken")
|
||||
|
||||
if not access_token:
|
||||
logger.error(f"[DingTalk] Failed to get access token: {token_data}")
|
||||
return None
|
||||
|
||||
# 获取下载 URL(使用新版 API)
|
||||
download_api_url = "https://api.dingtalk.com/v1.0/robot/messageFiles/download"
|
||||
download_headers = {
|
||||
"x-acs-dingtalk-access-token": access_token,
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
download_body = {
|
||||
"downloadCode": actual_download_code,
|
||||
"robotCode": robot_code
|
||||
}
|
||||
|
||||
download_response = requests.post(download_api_url, json=download_body, headers=download_headers, timeout=10)
|
||||
|
||||
if download_response.status_code == 200:
|
||||
download_data = download_response.json()
|
||||
download_url = download_data.get("downloadUrl")
|
||||
|
||||
if not download_url:
|
||||
logger.error(f"[DingTalk] No downloadUrl in response: {download_data}")
|
||||
return None
|
||||
|
||||
# 从 downloadUrl 下载实际图片
|
||||
image_response = requests.get(download_url, stream=True, timeout=60)
|
||||
|
||||
if image_response.status_code == 200:
|
||||
# 生成文件名(使用 download_code 的 hash,避免特殊字符)
|
||||
import hashlib
|
||||
file_hash = hashlib.md5(actual_download_code.encode()).hexdigest()[:16]
|
||||
file_name = f"{file_hash}.png"
|
||||
file_path = os.path.join(temp_dir, file_name)
|
||||
|
||||
with open(file_path, 'wb') as file:
|
||||
file.write(image_response.content)
|
||||
|
||||
logger.info(f"[DingTalk] Image downloaded successfully: {file_path}")
|
||||
return file_path
|
||||
else:
|
||||
logger.error(f"[DingTalk] Failed to download image from URL: {image_response.status_code}")
|
||||
return None
|
||||
else:
|
||||
logger.error(f"[DingTalk] Failed to get download URL: {download_response.status_code}, {download_response.text}")
|
||||
return None
|
||||
else:
|
||||
logger.error(f"[DingTalk] Failed to get access token: {token_response.status_code}, {token_response.text}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"[DingTalk] Exception downloading image: {e}")
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
return None
|
||||
|
||||
# 普通 HTTP(S) URL
|
||||
else:
|
||||
headers = {
|
||||
'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/112.0.0.0 Safari/537.36'
|
||||
}
|
||||
|
||||
try:
|
||||
response = requests.get(image_url, headers=headers, stream=True, timeout=60 * 5)
|
||||
if response.status_code == 200:
|
||||
# 生成文件名
|
||||
file_name = image_url.split("/")[-1].split("?")[0]
|
||||
|
||||
# 将文件保存到临时目录
|
||||
file_path = os.path.join(temp_dir, file_name)
|
||||
with open(file_path, 'wb') as file:
|
||||
file.write(response.content)
|
||||
return file_path
|
||||
else:
|
||||
logger.info(f"[Dingtalk] Failed to download image file, {response.content}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"[Dingtalk] Exception downloading image: {e}")
|
||||
return None
|
||||
184
channel/feishu/README.md
Normal file
184
channel/feishu/README.md
Normal file
@@ -0,0 +1,184 @@
|
||||
# 飞书Channel使用说明
|
||||
|
||||
飞书Channel支持两种事件接收模式,可以根据部署环境灵活选择。
|
||||
|
||||
## 模式对比
|
||||
|
||||
| 模式 | 适用场景 | 优点 | 缺点 |
|
||||
|------|---------|------|------|
|
||||
| **webhook** | 生产环境 | 稳定可靠,官方推荐 | 需要公网IP或域名 |
|
||||
| **websocket** | 本地开发 | 无需公网IP,开发便捷 | 需要额外依赖 |
|
||||
|
||||
## 配置说明
|
||||
|
||||
### 基础配置
|
||||
|
||||
在 `config.json` 中添加以下配置:
|
||||
|
||||
```json
|
||||
{
|
||||
"channel_type": "feishu",
|
||||
"feishu_app_id": "cli_xxxxx",
|
||||
"feishu_app_secret": "your_app_secret",
|
||||
"feishu_token": "your_verification_token",
|
||||
"feishu_bot_name": "你的机器人名称",
|
||||
"feishu_event_mode": "webhook",
|
||||
"feishu_port": 9891
|
||||
}
|
||||
```
|
||||
|
||||
### 配置项说明
|
||||
|
||||
- `feishu_app_id`: 飞书应用的App ID
|
||||
- `feishu_app_secret`: 飞书应用的App Secret
|
||||
- `feishu_token`: 事件订阅的Verification Token
|
||||
- `feishu_bot_name`: 机器人名称(用于群聊@判断)
|
||||
- `feishu_event_mode`: 事件接收模式,可选值:
|
||||
- `"websocket"`: 长连接模式(默认)
|
||||
- `"webhook"`: HTTP服务器模式
|
||||
- `feishu_port`: webhook模式下的HTTP服务端口(默认9891)
|
||||
|
||||
## 模式一: Webhook模式(推荐生产环境)
|
||||
|
||||
### 1. 配置
|
||||
|
||||
```json
|
||||
{
|
||||
"feishu_event_mode": "webhook",
|
||||
"feishu_port": 9891
|
||||
}
|
||||
```
|
||||
|
||||
### 2. 启动服务
|
||||
|
||||
```bash
|
||||
python3 app.py
|
||||
```
|
||||
|
||||
服务将在 `http://0.0.0.0:9891` 启动。
|
||||
|
||||
### 3. 配置飞书应用
|
||||
|
||||
1. 登录[飞书开放平台](https://open.feishu.cn/)
|
||||
2. 进入应用详情 -> 事件订阅
|
||||
3. 选择 **将事件发送至开发者服务器**
|
||||
4. 填写请求地址: `http://your-domain:9891/`
|
||||
5. 添加事件: `im.message.receive_v1` (接收消息v2.0)
|
||||
6. 保存配置
|
||||
|
||||
### 4. 注意事项
|
||||
|
||||
- 需要有公网IP或域名
|
||||
- 确保防火墙开放对应端口
|
||||
- 建议使用HTTPS(需要配置反向代理)
|
||||
|
||||
## 模式二: WebSocket模式(推荐本地开发)
|
||||
|
||||
### 1. 安装依赖
|
||||
|
||||
```bash
|
||||
pip install lark-oapi
|
||||
```
|
||||
|
||||
### 2. 配置
|
||||
|
||||
```json
|
||||
{
|
||||
"feishu_event_mode": "websocket"
|
||||
}
|
||||
```
|
||||
|
||||
### 3. 启动服务
|
||||
|
||||
```bash
|
||||
python3 app.py
|
||||
```
|
||||
|
||||
程序将自动建立与飞书开放平台的长连接。
|
||||
|
||||
### 4. 配置飞书应用
|
||||
|
||||
1. 登录[飞书开放平台](https://open.feishu.cn/)
|
||||
2. 进入应用详情 -> 事件订阅
|
||||
3. 选择 **使用长连接接收事件**
|
||||
4. 添加事件: `im.message.receive_v1` (接收消息v2.0)
|
||||
5. 保存配置
|
||||
|
||||
### 5. 注意事项
|
||||
|
||||
- 无需公网IP
|
||||
- 需要能访问公网(建立WebSocket连接)
|
||||
- 每个应用最多50个连接
|
||||
- 集群模式下消息随机分发到一个客户端
|
||||
|
||||
## 平滑迁移
|
||||
|
||||
从webhook模式切换到websocket模式(或反向切换):
|
||||
|
||||
1. 修改 `config.json` 中的 `feishu_event_mode`
|
||||
2. 如果切换到websocket模式,安装 `lark-oapi` 依赖
|
||||
3. 重启服务
|
||||
4. 在飞书开放平台修改事件订阅方式
|
||||
|
||||
**重要**: 同一时间只能使用一种模式,否则会导致消息重复接收。
|
||||
|
||||
## 消息去重机制
|
||||
|
||||
两种模式都使用相同的消息去重机制:
|
||||
|
||||
- 使用 `ExpiredDict` 存储已处理的消息ID
|
||||
- 过期时间: 7.1小时
|
||||
- 确保消息不会重复处理
|
||||
|
||||
## 故障排查
|
||||
|
||||
### WebSocket模式连接失败
|
||||
|
||||
```
|
||||
[FeiShu] lark_oapi not installed
|
||||
```
|
||||
|
||||
**解决**: 安装依赖 `pip install lark-oapi`
|
||||
|
||||
### SSL证书验证失败
|
||||
|
||||
```
|
||||
[Lark][ERROR] connect failed, err:[SSL:CERTIFICATE_VERIFY_FAILED] certificate verify failed: self signed certificate in certificate chain
|
||||
```
|
||||
|
||||
**原因**: 网络环境中存在自签名证书或SSL中间人代理(如企业代理、VPN等)
|
||||
|
||||
**解决**: 程序会自动检测SSL证书验证失败,并自动重试禁用证书验证的连接。无需手动配置。
|
||||
|
||||
当遇到证书错误时,日志会显示:
|
||||
```
|
||||
[FeiShu] SSL certificate verification disabled due to certificate error. This may happen when using corporate proxy or self-signed certificates.
|
||||
```
|
||||
|
||||
这是正常现象,程序会自动处理并继续运行。
|
||||
|
||||
### Webhook模式端口被占用
|
||||
|
||||
```
|
||||
Address already in use
|
||||
```
|
||||
|
||||
**解决**: 修改 `feishu_port` 配置或关闭占用端口的进程
|
||||
|
||||
### 收不到消息
|
||||
|
||||
1. 检查飞书应用的事件订阅配置
|
||||
2. 确认已添加 `im.message.receive_v1` 事件
|
||||
3. 检查应用权限: 需要 `im:message` 权限
|
||||
4. 查看日志中的错误信息
|
||||
|
||||
## 开发建议
|
||||
|
||||
- **本地开发**: 使用websocket模式,快速迭代
|
||||
- **测试环境**: 可以使用webhook模式 + 内网穿透工具(如ngrok)
|
||||
- **生产环境**: 使用webhook模式,配置正式域名和HTTPS
|
||||
|
||||
## 参考文档
|
||||
|
||||
- [飞书开放平台 - 事件订阅](https://open.feishu.cn/document/ukTMukTMukTM/uUTNz4SN1MjL1UzM)
|
||||
- [飞书SDK - Python](https://github.com/larksuite/oapi-sdk-python)
|
||||
880
channel/feishu/feishu_channel.py
Normal file
880
channel/feishu/feishu_channel.py
Normal file
@@ -0,0 +1,880 @@
|
||||
"""
|
||||
飞书通道接入
|
||||
|
||||
支持两种事件接收模式:
|
||||
1. webhook模式: 通过HTTP服务器接收事件(需要公网IP)
|
||||
2. websocket模式: 通过长连接接收事件(本地开发友好)
|
||||
|
||||
通过配置项 feishu_event_mode 选择模式: "webhook" 或 "websocket"
|
||||
|
||||
@author Saboteur7
|
||||
@Date 2023/11/19
|
||||
"""
|
||||
|
||||
import importlib.util
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import ssl
|
||||
import threading
|
||||
# -*- coding=utf-8 -*-
|
||||
import uuid
|
||||
|
||||
import requests
|
||||
import web
|
||||
|
||||
from bridge.context import Context
|
||||
from bridge.context import ContextType
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from channel.chat_channel import ChatChannel, check_prefix
|
||||
from channel.feishu.feishu_message import FeishuMessage
|
||||
from common import utils
|
||||
from common.expired_dict import ExpiredDict
|
||||
from common.log import logger
|
||||
from common.singleton import singleton
|
||||
from config import conf
|
||||
|
||||
# Suppress verbose logs from Lark SDK
|
||||
logging.getLogger("Lark").setLevel(logging.WARNING)
|
||||
|
||||
URL_VERIFICATION = "url_verification"
|
||||
|
||||
# Lazy-check for lark_oapi SDK availability without importing it at module level.
|
||||
# The full `import lark_oapi` pulls in 10k+ files and takes 4-10s, so we defer
|
||||
# the actual import to _startup_websocket() where it is needed.
|
||||
LARK_SDK_AVAILABLE = importlib.util.find_spec("lark_oapi") is not None
|
||||
lark = None # will be populated on first use via _ensure_lark_imported()
|
||||
|
||||
|
||||
def _ensure_lark_imported():
|
||||
"""Import lark_oapi on first use (takes 4-10s due to 10k+ source files)."""
|
||||
global lark
|
||||
if lark is None:
|
||||
import lark_oapi as _lark
|
||||
lark = _lark
|
||||
return lark
|
||||
|
||||
|
||||
@singleton
|
||||
class FeiShuChanel(ChatChannel):
|
||||
feishu_app_id = conf().get('feishu_app_id')
|
||||
feishu_app_secret = conf().get('feishu_app_secret')
|
||||
feishu_token = conf().get('feishu_token')
|
||||
feishu_event_mode = conf().get('feishu_event_mode', 'websocket') # webhook 或 websocket
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
# 历史消息id暂存,用于幂等控制
|
||||
self.receivedMsgs = ExpiredDict(60 * 60 * 7.1)
|
||||
self._http_server = None
|
||||
self._ws_client = None
|
||||
self._ws_thread = None
|
||||
self._bot_open_id = None # cached bot open_id for @-mention matching
|
||||
logger.debug("[FeiShu] app_id={}, app_secret={}, verification_token={}, event_mode={}".format(
|
||||
self.feishu_app_id, self.feishu_app_secret, self.feishu_token, self.feishu_event_mode))
|
||||
# 无需群校验和前缀
|
||||
conf()["group_name_white_list"] = ["ALL_GROUP"]
|
||||
conf()["single_chat_prefix"] = [""]
|
||||
|
||||
# 验证配置
|
||||
if self.feishu_event_mode == 'websocket' and not LARK_SDK_AVAILABLE:
|
||||
logger.error("[FeiShu] websocket mode requires lark_oapi. Please install: pip install lark-oapi")
|
||||
raise Exception("lark_oapi not installed")
|
||||
|
||||
def startup(self):
|
||||
self.feishu_app_id = conf().get('feishu_app_id')
|
||||
self.feishu_app_secret = conf().get('feishu_app_secret')
|
||||
self.feishu_token = conf().get('feishu_token')
|
||||
self.feishu_event_mode = conf().get('feishu_event_mode', 'websocket')
|
||||
self._fetch_bot_open_id()
|
||||
if self.feishu_event_mode == 'websocket':
|
||||
self._startup_websocket()
|
||||
else:
|
||||
self._startup_webhook()
|
||||
|
||||
def _fetch_bot_open_id(self):
|
||||
"""Fetch the bot's own open_id via API so we can match @-mentions without feishu_bot_name."""
|
||||
try:
|
||||
access_token = self.fetch_access_token()
|
||||
if not access_token:
|
||||
logger.warning("[FeiShu] Cannot fetch bot info: no access_token")
|
||||
return
|
||||
headers = {"Authorization": "Bearer " + access_token}
|
||||
resp = requests.get("https://open.feishu.cn/open-apis/bot/v3/info/", headers=headers, timeout=5)
|
||||
if resp.status_code == 200:
|
||||
data = resp.json()
|
||||
if data.get("code") == 0:
|
||||
self._bot_open_id = data.get("bot", {}).get("open_id")
|
||||
logger.info(f"[FeiShu] Bot open_id fetched: {self._bot_open_id}")
|
||||
else:
|
||||
logger.warning(f"[FeiShu] Fetch bot info failed: code={data.get('code')}, msg={data.get('msg')}")
|
||||
except Exception as e:
|
||||
logger.warning(f"[FeiShu] Fetch bot open_id error: {e}")
|
||||
|
||||
def stop(self):
|
||||
import ctypes
|
||||
logger.info("[FeiShu] stop() called")
|
||||
ws_client = self._ws_client
|
||||
self._ws_client = None
|
||||
ws_thread = self._ws_thread
|
||||
self._ws_thread = None
|
||||
# Interrupt the ws thread first so its blocking start() unblocks
|
||||
if ws_thread and ws_thread.is_alive():
|
||||
try:
|
||||
tid = ws_thread.ident
|
||||
if tid:
|
||||
res = ctypes.pythonapi.PyThreadState_SetAsyncExc(
|
||||
ctypes.c_ulong(tid), ctypes.py_object(SystemExit)
|
||||
)
|
||||
if res == 1:
|
||||
logger.info("[FeiShu] Interrupted ws thread via ctypes")
|
||||
elif res > 1:
|
||||
ctypes.pythonapi.PyThreadState_SetAsyncExc(ctypes.c_ulong(tid), None)
|
||||
except Exception as e:
|
||||
logger.warning(f"[FeiShu] Error interrupting ws thread: {e}")
|
||||
# lark.ws.Client has no stop() method; thread interruption above is sufficient
|
||||
if self._http_server:
|
||||
try:
|
||||
self._http_server.stop()
|
||||
logger.info("[FeiShu] HTTP server stopped")
|
||||
except Exception as e:
|
||||
logger.warning(f"[FeiShu] Error stopping HTTP server: {e}")
|
||||
self._http_server = None
|
||||
logger.info("[FeiShu] stop() completed")
|
||||
|
||||
def _startup_webhook(self):
|
||||
"""启动HTTP服务器接收事件(webhook模式)"""
|
||||
logger.debug("[FeiShu] Starting in webhook mode...")
|
||||
urls = (
|
||||
'/', 'channel.feishu.feishu_channel.FeishuController'
|
||||
)
|
||||
app = web.application(urls, globals(), autoreload=False)
|
||||
port = conf().get("feishu_port", 9891)
|
||||
func = web.httpserver.StaticMiddleware(app.wsgifunc())
|
||||
func = web.httpserver.LogMiddleware(func)
|
||||
server = web.httpserver.WSGIServer(("0.0.0.0", port), func)
|
||||
self._http_server = server
|
||||
try:
|
||||
server.start()
|
||||
except (KeyboardInterrupt, SystemExit):
|
||||
server.stop()
|
||||
|
||||
def _startup_websocket(self):
|
||||
"""启动长连接接收事件(websocket模式)"""
|
||||
_ensure_lark_imported()
|
||||
logger.debug("[FeiShu] Starting in websocket mode...")
|
||||
|
||||
# 创建事件处理器
|
||||
def handle_message_event(data: lark.im.v1.P2ImMessageReceiveV1) -> None:
|
||||
"""处理接收消息事件 v2.0"""
|
||||
try:
|
||||
event_dict = json.loads(lark.JSON.marshal(data))
|
||||
event = event_dict.get("event", {})
|
||||
msg = event.get("message", {})
|
||||
|
||||
# Skip group messages that don't @-mention the bot (reduce log noise)
|
||||
if msg.get("chat_type") == "group" and not msg.get("mentions") and msg.get("message_type") == "text":
|
||||
return
|
||||
|
||||
logger.debug(f"[FeiShu] websocket receive event: {lark.JSON.marshal(data, indent=2)}")
|
||||
|
||||
# 处理消息
|
||||
self._handle_message_event(event)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[FeiShu] websocket handle message error: {e}", exc_info=True)
|
||||
|
||||
# 构建事件分发器
|
||||
event_handler = lark.EventDispatcherHandler.builder("", "") \
|
||||
.register_p2_im_message_receive_v1(handle_message_event) \
|
||||
.build()
|
||||
|
||||
def start_client_with_retry():
|
||||
"""Run ws client in this thread with its own event loop to avoid conflicts."""
|
||||
import asyncio
|
||||
import ssl as ssl_module
|
||||
original_create_default_context = ssl_module.create_default_context
|
||||
|
||||
def create_unverified_context(*args, **kwargs):
|
||||
context = original_create_default_context(*args, **kwargs)
|
||||
context.check_hostname = False
|
||||
context.verify_mode = ssl.CERT_NONE
|
||||
return context
|
||||
|
||||
# lark_oapi.ws.client captures the event loop at module-import time as a module-
|
||||
# level global variable. When a previous ws thread is force-killed via ctypes its
|
||||
# loop may still be marked as "running", which causes the next ws_client.start()
|
||||
# call (in this new thread) to raise "This event loop is already running".
|
||||
# Fix: replace the module-level loop with a brand-new, idle loop before starting.
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
import lark_oapi.ws.client as _lark_ws_client_mod
|
||||
_lark_ws_client_mod.loop = loop
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
startup_error = None
|
||||
for attempt in range(2):
|
||||
try:
|
||||
if attempt == 1:
|
||||
logger.warning("[FeiShu] Retrying with SSL verification disabled...")
|
||||
ssl_module.create_default_context = create_unverified_context
|
||||
ssl_module._create_unverified_context = create_unverified_context
|
||||
|
||||
ws_client = lark.ws.Client(
|
||||
self.feishu_app_id,
|
||||
self.feishu_app_secret,
|
||||
event_handler=event_handler,
|
||||
log_level=lark.LogLevel.WARNING
|
||||
)
|
||||
self._ws_client = ws_client
|
||||
logger.debug("[FeiShu] Websocket client starting...")
|
||||
ws_client.start()
|
||||
break
|
||||
|
||||
except (SystemExit, KeyboardInterrupt):
|
||||
logger.info("[FeiShu] Websocket thread received stop signal")
|
||||
break
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
is_ssl_error = ("CERTIFICATE_VERIFY_FAILED" in error_msg
|
||||
or "certificate verify failed" in error_msg.lower())
|
||||
if is_ssl_error and attempt == 0:
|
||||
logger.warning(f"[FeiShu] SSL error: {error_msg}, retrying...")
|
||||
continue
|
||||
logger.error(f"[FeiShu] Websocket client error: {e}", exc_info=True)
|
||||
startup_error = error_msg
|
||||
ssl_module.create_default_context = original_create_default_context
|
||||
break
|
||||
if startup_error:
|
||||
self.report_startup_error(startup_error)
|
||||
try:
|
||||
loop.close()
|
||||
except Exception:
|
||||
pass
|
||||
logger.info("[FeiShu] Websocket thread exited")
|
||||
|
||||
ws_thread = threading.Thread(target=start_client_with_retry, daemon=True)
|
||||
self._ws_thread = ws_thread
|
||||
ws_thread.start()
|
||||
logger.info("[FeiShu] ✅ Websocket thread started, ready to receive messages")
|
||||
ws_thread.join()
|
||||
|
||||
def _is_mention_bot(self, mentions: list) -> bool:
|
||||
"""Check whether any mention in the list refers to this bot.
|
||||
|
||||
Priority:
|
||||
1. Match by open_id (obtained from /bot/v3/info at startup, no config needed)
|
||||
2. Fallback to feishu_bot_name config for backward compatibility
|
||||
3. If neither is available, assume the first mention is the bot (Feishu only
|
||||
delivers group messages that @-mention the bot, so this is usually correct)
|
||||
"""
|
||||
if self._bot_open_id:
|
||||
return any(
|
||||
m.get("id", {}).get("open_id") == self._bot_open_id
|
||||
for m in mentions
|
||||
)
|
||||
bot_name = conf().get("feishu_bot_name")
|
||||
if bot_name:
|
||||
return any(m.get("name") == bot_name for m in mentions)
|
||||
# Feishu event subscription only delivers messages that @-mention the bot,
|
||||
# so reaching here means the bot was indeed mentioned.
|
||||
return True
|
||||
|
||||
def _handle_message_event(self, event: dict):
|
||||
"""
|
||||
处理消息事件的核心逻辑
|
||||
webhook和websocket模式共用此方法
|
||||
"""
|
||||
if not event.get("message") or not event.get("sender"):
|
||||
logger.warning(f"[FeiShu] invalid message, event={event}")
|
||||
return
|
||||
|
||||
msg = event.get("message")
|
||||
|
||||
# 幂等判断
|
||||
msg_id = msg.get("message_id")
|
||||
if self.receivedMsgs.get(msg_id):
|
||||
logger.warning(f"[FeiShu] repeat msg filtered, msg_id={msg_id}")
|
||||
return
|
||||
self.receivedMsgs[msg_id] = True
|
||||
|
||||
# Filter out stale messages from before channel startup (offline backlog)
|
||||
import time as _time
|
||||
create_time_ms = msg.get("create_time")
|
||||
if create_time_ms:
|
||||
msg_age_s = _time.time() - int(create_time_ms) / 1000
|
||||
if msg_age_s > 60:
|
||||
logger.warning(f"[FeiShu] stale msg filtered (age={msg_age_s:.0f}s), msg_id={msg_id}")
|
||||
return
|
||||
|
||||
is_group = False
|
||||
chat_type = msg.get("chat_type")
|
||||
|
||||
if chat_type == "group":
|
||||
if not msg.get("mentions") and msg.get("message_type") == "text":
|
||||
# 群聊中未@不响应
|
||||
return
|
||||
if msg.get("mentions") and msg.get("message_type") == "text":
|
||||
if not self._is_mention_bot(msg.get("mentions")):
|
||||
return
|
||||
# 群聊
|
||||
is_group = True
|
||||
receive_id_type = "chat_id"
|
||||
elif chat_type == "p2p":
|
||||
receive_id_type = "open_id"
|
||||
else:
|
||||
logger.warning("[FeiShu] message ignore")
|
||||
return
|
||||
|
||||
# 构造飞书消息对象
|
||||
feishu_msg = FeishuMessage(event, is_group=is_group, access_token=self.fetch_access_token())
|
||||
if not feishu_msg:
|
||||
return
|
||||
|
||||
# 处理文件缓存逻辑
|
||||
from channel.file_cache import get_file_cache
|
||||
file_cache = get_file_cache()
|
||||
|
||||
# 获取 session_id(用于缓存关联)
|
||||
if is_group:
|
||||
if conf().get("group_shared_session", True):
|
||||
session_id = msg.get("chat_id") # 群共享会话
|
||||
else:
|
||||
session_id = feishu_msg.from_user_id + "_" + msg.get("chat_id")
|
||||
else:
|
||||
session_id = feishu_msg.from_user_id
|
||||
|
||||
# 如果是单张图片消息,缓存起来
|
||||
if feishu_msg.ctype == ContextType.IMAGE:
|
||||
if hasattr(feishu_msg, 'image_path') and feishu_msg.image_path:
|
||||
file_cache.add(session_id, feishu_msg.image_path, file_type='image')
|
||||
logger.info(f"[FeiShu] Image cached for session {session_id}, waiting for user query...")
|
||||
# 单张图片不直接处理,等待用户提问
|
||||
return
|
||||
|
||||
# 如果是文本消息,检查是否有缓存的文件
|
||||
if feishu_msg.ctype == ContextType.TEXT:
|
||||
cached_files = file_cache.get(session_id)
|
||||
if cached_files:
|
||||
# 将缓存的文件附加到文本消息中
|
||||
file_refs = []
|
||||
for file_info in cached_files:
|
||||
file_path = file_info['path']
|
||||
file_type = file_info['type']
|
||||
if file_type == 'image':
|
||||
file_refs.append(f"[图片: {file_path}]")
|
||||
elif file_type == 'video':
|
||||
file_refs.append(f"[视频: {file_path}]")
|
||||
else:
|
||||
file_refs.append(f"[文件: {file_path}]")
|
||||
|
||||
feishu_msg.content = feishu_msg.content + "\n" + "\n".join(file_refs)
|
||||
logger.info(f"[FeiShu] Attached {len(cached_files)} cached file(s) to user query")
|
||||
# 清除缓存
|
||||
file_cache.clear(session_id)
|
||||
|
||||
context = self._compose_context(
|
||||
feishu_msg.ctype,
|
||||
feishu_msg.content,
|
||||
isgroup=is_group,
|
||||
msg=feishu_msg,
|
||||
receive_id_type=receive_id_type,
|
||||
no_need_at=True
|
||||
)
|
||||
if context:
|
||||
self.produce(context)
|
||||
logger.debug(f"[FeiShu] query={feishu_msg.content}, type={feishu_msg.ctype}")
|
||||
|
||||
def send(self, reply: Reply, context: Context):
|
||||
msg = context.get("msg")
|
||||
is_group = context["isgroup"]
|
||||
if msg:
|
||||
access_token = msg.access_token
|
||||
else:
|
||||
access_token = self.fetch_access_token()
|
||||
headers = {
|
||||
"Authorization": "Bearer " + access_token,
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
msg_type = "text"
|
||||
logger.debug(f"[FeiShu] sending reply, type={context.type}, content={reply.content[:100]}...")
|
||||
reply_content = reply.content
|
||||
content_key = "text"
|
||||
if reply.type == ReplyType.IMAGE_URL:
|
||||
# 图片上传
|
||||
reply_content = self._upload_image_url(reply.content, access_token)
|
||||
if not reply_content:
|
||||
logger.warning("[FeiShu] upload image failed")
|
||||
return
|
||||
msg_type = "image"
|
||||
content_key = "image_key"
|
||||
elif reply.type == ReplyType.FILE:
|
||||
# 如果有附加的文本内容,先发送文本
|
||||
if hasattr(reply, 'text_content') and reply.text_content:
|
||||
logger.info(f"[FeiShu] Sending text before file: {reply.text_content[:50]}...")
|
||||
text_reply = Reply(ReplyType.TEXT, reply.text_content)
|
||||
self._send(text_reply, context)
|
||||
import time
|
||||
time.sleep(0.3) # 短暂延迟,确保文本先到达
|
||||
|
||||
# 判断是否为视频文件
|
||||
file_path = reply.content
|
||||
if file_path.startswith("file://"):
|
||||
file_path = file_path[7:]
|
||||
|
||||
is_video = file_path.lower().endswith(('.mp4', '.avi', '.mov', '.wmv', '.flv'))
|
||||
|
||||
if is_video:
|
||||
# 视频上传(包含duration信息)
|
||||
upload_data = self._upload_video_url(reply.content, access_token)
|
||||
if not upload_data or not upload_data.get('file_key'):
|
||||
logger.warning("[FeiShu] upload video failed")
|
||||
return
|
||||
|
||||
# 视频使用 media 类型(根据官方文档)
|
||||
# 错误码 230055 说明:上传 mp4 时必须使用 msg_type="media"
|
||||
msg_type = "media"
|
||||
reply_content = upload_data # 完整的上传响应数据(包含file_key和duration)
|
||||
logger.info(
|
||||
f"[FeiShu] Sending video: file_key={upload_data.get('file_key')}, duration={upload_data.get('duration')}ms")
|
||||
content_key = None # 直接序列化整个对象
|
||||
else:
|
||||
# 其他文件使用 file 类型
|
||||
file_key = self._upload_file_url(reply.content, access_token)
|
||||
if not file_key:
|
||||
logger.warning("[FeiShu] upload file failed")
|
||||
return
|
||||
reply_content = file_key
|
||||
msg_type = "file"
|
||||
content_key = "file_key"
|
||||
|
||||
# Check if we can reply to an existing message (need msg_id)
|
||||
can_reply = is_group and msg and hasattr(msg, 'msg_id') and msg.msg_id
|
||||
|
||||
# Build content JSON
|
||||
content_json = json.dumps(reply_content) if content_key is None else json.dumps({content_key: reply_content})
|
||||
logger.debug(f"[FeiShu] Sending message: msg_type={msg_type}, content={content_json[:200]}")
|
||||
|
||||
if can_reply:
|
||||
# 群聊中回复已有消息
|
||||
url = f"https://open.feishu.cn/open-apis/im/v1/messages/{msg.msg_id}/reply"
|
||||
data = {
|
||||
"msg_type": msg_type,
|
||||
"content": content_json
|
||||
}
|
||||
res = requests.post(url=url, headers=headers, json=data, timeout=(5, 10))
|
||||
else:
|
||||
# 发送新消息(私聊或群聊中无msg_id的情况,如定时任务)
|
||||
url = "https://open.feishu.cn/open-apis/im/v1/messages"
|
||||
params = {"receive_id_type": context.get("receive_id_type") or "open_id"}
|
||||
data = {
|
||||
"receive_id": context.get("receiver"),
|
||||
"msg_type": msg_type,
|
||||
"content": content_json
|
||||
}
|
||||
res = requests.post(url=url, headers=headers, params=params, json=data, timeout=(5, 10))
|
||||
res = res.json()
|
||||
if res.get("code") == 0:
|
||||
logger.info(f"[FeiShu] send message success")
|
||||
else:
|
||||
logger.error(f"[FeiShu] send message failed, code={res.get('code')}, msg={res.get('msg')}")
|
||||
|
||||
def fetch_access_token(self) -> str:
|
||||
url = "https://open.feishu.cn/open-apis/auth/v3/tenant_access_token/internal/"
|
||||
headers = {
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
req_body = {
|
||||
"app_id": self.feishu_app_id,
|
||||
"app_secret": self.feishu_app_secret
|
||||
}
|
||||
data = bytes(json.dumps(req_body), encoding='utf8')
|
||||
response = requests.post(url=url, data=data, headers=headers)
|
||||
if response.status_code == 200:
|
||||
res = response.json()
|
||||
if res.get("code") != 0:
|
||||
logger.error(f"[FeiShu] get tenant_access_token error, code={res.get('code')}, msg={res.get('msg')}")
|
||||
return ""
|
||||
else:
|
||||
return res.get("tenant_access_token")
|
||||
else:
|
||||
logger.error(f"[FeiShu] fetch token error, res={response}")
|
||||
|
||||
def _upload_image_url(self, img_url, access_token):
|
||||
logger.debug(f"[FeiShu] start process image, img_url={img_url}")
|
||||
|
||||
# Check if it's a local file path (file:// protocol)
|
||||
if img_url.startswith("file://"):
|
||||
local_path = img_url[7:] # Remove "file://" prefix
|
||||
logger.info(f"[FeiShu] uploading local file: {local_path}")
|
||||
|
||||
if not os.path.exists(local_path):
|
||||
logger.error(f"[FeiShu] local file not found: {local_path}")
|
||||
return None
|
||||
|
||||
# Upload directly from local file
|
||||
upload_url = "https://open.feishu.cn/open-apis/im/v1/images"
|
||||
data = {'image_type': 'message'}
|
||||
headers = {'Authorization': f'Bearer {access_token}'}
|
||||
|
||||
with open(local_path, "rb") as file:
|
||||
upload_response = requests.post(upload_url, files={"image": file}, data=data, headers=headers)
|
||||
logger.info(f"[FeiShu] upload file, res={upload_response.content}")
|
||||
|
||||
response_data = upload_response.json()
|
||||
if response_data.get("code") == 0:
|
||||
return response_data.get("data").get("image_key")
|
||||
else:
|
||||
logger.error(f"[FeiShu] upload failed: {response_data}")
|
||||
return None
|
||||
|
||||
# Original logic for HTTP URLs
|
||||
response = requests.get(img_url)
|
||||
suffix = utils.get_path_suffix(img_url)
|
||||
temp_name = str(uuid.uuid4()) + "." + suffix
|
||||
if response.status_code == 200:
|
||||
# 将图片内容保存为临时文件
|
||||
with open(temp_name, "wb") as file:
|
||||
file.write(response.content)
|
||||
|
||||
# upload
|
||||
upload_url = "https://open.feishu.cn/open-apis/im/v1/images"
|
||||
data = {
|
||||
'image_type': 'message'
|
||||
}
|
||||
headers = {
|
||||
'Authorization': f'Bearer {access_token}',
|
||||
}
|
||||
with open(temp_name, "rb") as file:
|
||||
upload_response = requests.post(upload_url, files={"image": file}, data=data, headers=headers)
|
||||
logger.info(f"[FeiShu] upload file, res={upload_response.content}")
|
||||
os.remove(temp_name)
|
||||
return upload_response.json().get("data").get("image_key")
|
||||
|
||||
def _get_video_duration(self, file_path: str) -> int:
|
||||
"""
|
||||
获取视频时长(毫秒)
|
||||
|
||||
Args:
|
||||
file_path: 视频文件路径
|
||||
|
||||
Returns:
|
||||
视频时长(毫秒),如果获取失败返回0
|
||||
"""
|
||||
try:
|
||||
import subprocess
|
||||
|
||||
# 使用 ffprobe 获取视频时长
|
||||
cmd = [
|
||||
'ffprobe',
|
||||
'-v', 'error',
|
||||
'-show_entries', 'format=duration',
|
||||
'-of', 'default=noprint_wrappers=1:nokey=1',
|
||||
file_path
|
||||
]
|
||||
|
||||
result = subprocess.run(cmd, capture_output=True, text=True, timeout=10)
|
||||
if result.returncode == 0:
|
||||
duration_seconds = float(result.stdout.strip())
|
||||
duration_ms = int(duration_seconds * 1000)
|
||||
logger.info(f"[FeiShu] Video duration: {duration_seconds:.2f}s ({duration_ms}ms)")
|
||||
return duration_ms
|
||||
else:
|
||||
logger.warning(f"[FeiShu] Failed to get video duration via ffprobe: {result.stderr}")
|
||||
return 0
|
||||
except FileNotFoundError:
|
||||
logger.warning("[FeiShu] ffprobe not found, video duration will be 0. Install ffmpeg to fix this.")
|
||||
return 0
|
||||
except Exception as e:
|
||||
logger.warning(f"[FeiShu] Failed to get video duration: {e}")
|
||||
return 0
|
||||
|
||||
def _upload_video_url(self, video_url, access_token):
|
||||
"""
|
||||
Upload video to Feishu and return video info (file_key and duration)
|
||||
Supports:
|
||||
- file:// URLs for local files
|
||||
- http(s):// URLs (download then upload)
|
||||
|
||||
Returns:
|
||||
dict with 'file_key' and 'duration' (milliseconds), or None if failed
|
||||
"""
|
||||
local_path = None
|
||||
temp_file = None
|
||||
|
||||
try:
|
||||
# For file:// URLs (local files), upload directly
|
||||
if video_url.startswith("file://"):
|
||||
local_path = video_url[7:] # Remove file:// prefix
|
||||
if not os.path.exists(local_path):
|
||||
logger.error(f"[FeiShu] local video file not found: {local_path}")
|
||||
return None
|
||||
else:
|
||||
# For HTTP URLs, download first
|
||||
logger.info(f"[FeiShu] Downloading video from URL: {video_url}")
|
||||
response = requests.get(video_url, timeout=(5, 60))
|
||||
if response.status_code != 200:
|
||||
logger.error(f"[FeiShu] download video failed, status={response.status_code}")
|
||||
return None
|
||||
|
||||
# Save to temp file
|
||||
import uuid
|
||||
file_name = os.path.basename(video_url) or "video.mp4"
|
||||
temp_file = str(uuid.uuid4()) + "_" + file_name
|
||||
|
||||
with open(temp_file, "wb") as file:
|
||||
file.write(response.content)
|
||||
|
||||
logger.info(f"[FeiShu] Video downloaded, size={len(response.content)} bytes")
|
||||
local_path = temp_file
|
||||
|
||||
# Get video duration
|
||||
duration = self._get_video_duration(local_path)
|
||||
|
||||
# Upload to Feishu
|
||||
file_name = os.path.basename(local_path)
|
||||
file_ext = os.path.splitext(file_name)[1].lower()
|
||||
file_type_map = {'.mp4': 'mp4'}
|
||||
file_type = file_type_map.get(file_ext, 'mp4')
|
||||
|
||||
upload_url = "https://open.feishu.cn/open-apis/im/v1/files"
|
||||
data = {
|
||||
'file_type': file_type,
|
||||
'file_name': file_name
|
||||
}
|
||||
# Add duration only if available (required for video/audio)
|
||||
if duration:
|
||||
data['duration'] = duration # Must be int, not string
|
||||
|
||||
headers = {'Authorization': f'Bearer {access_token}'}
|
||||
|
||||
logger.info(f"[FeiShu] Uploading video: file_name={file_name}, duration={duration}ms")
|
||||
|
||||
with open(local_path, "rb") as file:
|
||||
upload_response = requests.post(
|
||||
upload_url,
|
||||
files={"file": file},
|
||||
data=data,
|
||||
headers=headers,
|
||||
timeout=(5, 60)
|
||||
)
|
||||
logger.info(
|
||||
f"[FeiShu] upload video response, status={upload_response.status_code}, res={upload_response.content}")
|
||||
|
||||
response_data = upload_response.json()
|
||||
if response_data.get("code") == 0:
|
||||
# Add duration to the response data (API doesn't return it)
|
||||
upload_data = response_data.get("data")
|
||||
upload_data['duration'] = duration # Add our calculated duration
|
||||
logger.info(
|
||||
f"[FeiShu] Upload complete: file_key={upload_data.get('file_key')}, duration={duration}ms")
|
||||
return upload_data
|
||||
else:
|
||||
logger.error(f"[FeiShu] upload video failed: {response_data}")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[FeiShu] upload video exception: {e}")
|
||||
return None
|
||||
|
||||
finally:
|
||||
# Clean up temp file
|
||||
if temp_file and os.path.exists(temp_file):
|
||||
try:
|
||||
os.remove(temp_file)
|
||||
except Exception as e:
|
||||
logger.warning(f"[FeiShu] Failed to remove temp file {temp_file}: {e}")
|
||||
|
||||
def _upload_file_url(self, file_url, access_token):
|
||||
"""
|
||||
Upload file to Feishu
|
||||
Supports both local files (file://) and HTTP URLs
|
||||
"""
|
||||
logger.debug(f"[FeiShu] start process file, file_url={file_url}")
|
||||
|
||||
# Check if it's a local file path (file:// protocol)
|
||||
if file_url.startswith("file://"):
|
||||
local_path = file_url[7:] # Remove "file://" prefix
|
||||
logger.info(f"[FeiShu] uploading local file: {local_path}")
|
||||
|
||||
if not os.path.exists(local_path):
|
||||
logger.error(f"[FeiShu] local file not found: {local_path}")
|
||||
return None
|
||||
|
||||
# Get file info
|
||||
file_name = os.path.basename(local_path)
|
||||
file_ext = os.path.splitext(file_name)[1].lower()
|
||||
|
||||
# Determine file type for Feishu API
|
||||
# Feishu supports: opus, mp4, pdf, doc, xls, ppt, stream (other types)
|
||||
file_type_map = {
|
||||
'.opus': 'opus',
|
||||
'.mp4': 'mp4',
|
||||
'.pdf': 'pdf',
|
||||
'.doc': 'doc', '.docx': 'doc',
|
||||
'.xls': 'xls', '.xlsx': 'xls',
|
||||
'.ppt': 'ppt', '.pptx': 'ppt',
|
||||
}
|
||||
file_type = file_type_map.get(file_ext, 'stream') # Default to stream for other types
|
||||
|
||||
# Upload file to Feishu
|
||||
upload_url = "https://open.feishu.cn/open-apis/im/v1/files"
|
||||
data = {'file_type': file_type, 'file_name': file_name}
|
||||
headers = {'Authorization': f'Bearer {access_token}'}
|
||||
|
||||
try:
|
||||
with open(local_path, "rb") as file:
|
||||
upload_response = requests.post(
|
||||
upload_url,
|
||||
files={"file": file},
|
||||
data=data,
|
||||
headers=headers,
|
||||
timeout=(5, 30) # 5s connect, 30s read timeout
|
||||
)
|
||||
logger.info(
|
||||
f"[FeiShu] upload file response, status={upload_response.status_code}, res={upload_response.content}")
|
||||
|
||||
response_data = upload_response.json()
|
||||
if response_data.get("code") == 0:
|
||||
return response_data.get("data").get("file_key")
|
||||
else:
|
||||
logger.error(f"[FeiShu] upload file failed: {response_data}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"[FeiShu] upload file exception: {e}")
|
||||
return None
|
||||
|
||||
# For HTTP URLs, download first then upload
|
||||
try:
|
||||
response = requests.get(file_url, timeout=(5, 30))
|
||||
if response.status_code != 200:
|
||||
logger.error(f"[FeiShu] download file failed, status={response.status_code}")
|
||||
return None
|
||||
|
||||
# Save to temp file
|
||||
import uuid
|
||||
file_name = os.path.basename(file_url)
|
||||
temp_name = str(uuid.uuid4()) + "_" + file_name
|
||||
|
||||
with open(temp_name, "wb") as file:
|
||||
file.write(response.content)
|
||||
|
||||
# Upload
|
||||
file_ext = os.path.splitext(file_name)[1].lower()
|
||||
file_type_map = {
|
||||
'.opus': 'opus', '.mp4': 'mp4', '.pdf': 'pdf',
|
||||
'.doc': 'doc', '.docx': 'doc',
|
||||
'.xls': 'xls', '.xlsx': 'xls',
|
||||
'.ppt': 'ppt', '.pptx': 'ppt',
|
||||
}
|
||||
file_type = file_type_map.get(file_ext, 'stream')
|
||||
|
||||
upload_url = "https://open.feishu.cn/open-apis/im/v1/files"
|
||||
data = {'file_type': file_type, 'file_name': file_name}
|
||||
headers = {'Authorization': f'Bearer {access_token}'}
|
||||
|
||||
with open(temp_name, "rb") as file:
|
||||
upload_response = requests.post(upload_url, files={"file": file}, data=data, headers=headers)
|
||||
logger.info(f"[FeiShu] upload file, res={upload_response.content}")
|
||||
|
||||
response_data = upload_response.json()
|
||||
os.remove(temp_name) # Clean up temp file
|
||||
|
||||
if response_data.get("code") == 0:
|
||||
return response_data.get("data").get("file_key")
|
||||
else:
|
||||
logger.error(f"[FeiShu] upload file failed: {response_data}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"[FeiShu] upload file from URL exception: {e}")
|
||||
return None
|
||||
|
||||
def _compose_context(self, ctype: ContextType, content, **kwargs):
|
||||
context = Context(ctype, content)
|
||||
context.kwargs = kwargs
|
||||
if "channel_type" not in context:
|
||||
context["channel_type"] = self.channel_type
|
||||
if "origin_ctype" not in context:
|
||||
context["origin_ctype"] = ctype
|
||||
|
||||
cmsg = context["msg"]
|
||||
|
||||
# Set session_id based on chat type
|
||||
if cmsg.is_group:
|
||||
# Group chat: check if group_shared_session is enabled
|
||||
if conf().get("group_shared_session", True):
|
||||
# All users in the group share the same session context
|
||||
context["session_id"] = cmsg.other_user_id # group_id
|
||||
else:
|
||||
# Each user has their own session within the group
|
||||
# This ensures:
|
||||
# - Same user in different groups have separate conversation histories
|
||||
# - Same user in private chat and group chat have separate histories
|
||||
context["session_id"] = f"{cmsg.from_user_id}:{cmsg.other_user_id}"
|
||||
else:
|
||||
# Private chat: use user_id only
|
||||
context["session_id"] = cmsg.from_user_id
|
||||
|
||||
context["receiver"] = cmsg.other_user_id
|
||||
|
||||
if ctype == ContextType.TEXT:
|
||||
# 1.文本请求
|
||||
# 图片生成处理
|
||||
img_match_prefix = check_prefix(content, conf().get("image_create_prefix"))
|
||||
if img_match_prefix:
|
||||
content = content.replace(img_match_prefix, "", 1)
|
||||
context.type = ContextType.IMAGE_CREATE
|
||||
else:
|
||||
context.type = ContextType.TEXT
|
||||
context.content = content.strip()
|
||||
|
||||
elif context.type == ContextType.VOICE:
|
||||
# 2.语音请求
|
||||
if "desire_rtype" not in context and conf().get("voice_reply_voice"):
|
||||
context["desire_rtype"] = ReplyType.VOICE
|
||||
|
||||
return context
|
||||
|
||||
|
||||
class FeishuController:
|
||||
"""
|
||||
HTTP服务器控制器,用于webhook模式
|
||||
"""
|
||||
# 类常量
|
||||
FAILED_MSG = '{"success": false}'
|
||||
SUCCESS_MSG = '{"success": true}'
|
||||
MESSAGE_RECEIVE_TYPE = "im.message.receive_v1"
|
||||
|
||||
def GET(self):
|
||||
return "Feishu service start success!"
|
||||
|
||||
def POST(self):
|
||||
try:
|
||||
channel = FeiShuChanel()
|
||||
|
||||
request = json.loads(web.data().decode("utf-8"))
|
||||
logger.debug(f"[FeiShu] receive request: {request}")
|
||||
|
||||
# 1.事件订阅回调验证
|
||||
if request.get("type") == URL_VERIFICATION:
|
||||
varify_res = {"challenge": request.get("challenge")}
|
||||
return json.dumps(varify_res)
|
||||
|
||||
# 2.消息接收处理
|
||||
# token 校验
|
||||
header = request.get("header")
|
||||
if not header or header.get("token") != channel.feishu_token:
|
||||
return self.FAILED_MSG
|
||||
|
||||
# 处理消息事件
|
||||
event = request.get("event")
|
||||
if header.get("event_type") == self.MESSAGE_RECEIVE_TYPE and event:
|
||||
channel._handle_message_event(event)
|
||||
|
||||
return self.SUCCESS_MSG
|
||||
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
return self.FAILED_MSG
|
||||
179
channel/feishu/feishu_message.py
Normal file
179
channel/feishu/feishu_message.py
Normal file
@@ -0,0 +1,179 @@
|
||||
from bridge.context import ContextType
|
||||
from channel.chat_message import ChatMessage
|
||||
import json
|
||||
import os
|
||||
import requests
|
||||
from common.log import logger
|
||||
from common.tmp_dir import TmpDir
|
||||
from common import utils
|
||||
from common.utils import expand_path
|
||||
from config import conf
|
||||
|
||||
|
||||
class FeishuMessage(ChatMessage):
|
||||
def __init__(self, event: dict, is_group=False, access_token=None):
|
||||
super().__init__(event)
|
||||
msg = event.get("message")
|
||||
sender = event.get("sender")
|
||||
self.access_token = access_token
|
||||
self.msg_id = msg.get("message_id")
|
||||
self.create_time = msg.get("create_time")
|
||||
self.is_group = is_group
|
||||
msg_type = msg.get("message_type")
|
||||
|
||||
if msg_type == "text":
|
||||
self.ctype = ContextType.TEXT
|
||||
content = json.loads(msg.get('content'))
|
||||
self.content = content.get("text").strip()
|
||||
elif msg_type == "image":
|
||||
# 单张图片消息:下载并缓存,等待用户提问时一起发送
|
||||
self.ctype = ContextType.IMAGE
|
||||
content = json.loads(msg.get("content"))
|
||||
image_key = content.get("image_key")
|
||||
|
||||
# 下载图片到工作空间临时目录
|
||||
workspace_root = expand_path(conf().get("agent_workspace", "~/cow"))
|
||||
tmp_dir = os.path.join(workspace_root, "tmp")
|
||||
os.makedirs(tmp_dir, exist_ok=True)
|
||||
image_path = os.path.join(tmp_dir, f"{image_key}.png")
|
||||
|
||||
# 下载图片
|
||||
url = f"https://open.feishu.cn/open-apis/im/v1/messages/{msg.get('message_id')}/resources/{image_key}"
|
||||
headers = {"Authorization": "Bearer " + access_token}
|
||||
params = {"type": "image"}
|
||||
response = requests.get(url=url, headers=headers, params=params)
|
||||
|
||||
if response.status_code == 200:
|
||||
with open(image_path, "wb") as f:
|
||||
f.write(response.content)
|
||||
logger.info(f"[FeiShu] Downloaded single image, key={image_key}, path={image_path}")
|
||||
self.content = image_path
|
||||
self.image_path = image_path # 保存图片路径
|
||||
else:
|
||||
logger.error(f"[FeiShu] Failed to download single image, key={image_key}, status={response.status_code}")
|
||||
self.content = f"[图片下载失败: {image_key}]"
|
||||
self.image_path = None
|
||||
elif msg_type == "post":
|
||||
# 富文本消息,可能包含图片、文本等多种元素
|
||||
content = json.loads(msg.get("content"))
|
||||
|
||||
# 飞书富文本消息结构:content 直接包含 title 和 content 数组
|
||||
# 不是嵌套在 post 字段下
|
||||
title = content.get("title", "")
|
||||
content_list = content.get("content", [])
|
||||
|
||||
logger.info(f"[FeiShu] Post message - title: '{title}', content_list length: {len(content_list)}")
|
||||
|
||||
# 收集所有图片和文本
|
||||
image_keys = []
|
||||
text_parts = []
|
||||
|
||||
if title:
|
||||
text_parts.append(title)
|
||||
|
||||
for block in content_list:
|
||||
logger.debug(f"[FeiShu] Processing block: {block}")
|
||||
# block 本身就是元素列表
|
||||
if not isinstance(block, list):
|
||||
continue
|
||||
|
||||
for element in block:
|
||||
element_tag = element.get("tag")
|
||||
logger.debug(f"[FeiShu] Element tag: {element_tag}, element: {element}")
|
||||
if element_tag == "img":
|
||||
# 找到图片元素
|
||||
image_key = element.get("image_key")
|
||||
if image_key:
|
||||
image_keys.append(image_key)
|
||||
elif element_tag == "text":
|
||||
# 文本元素
|
||||
text_content = element.get("text", "")
|
||||
if text_content:
|
||||
text_parts.append(text_content)
|
||||
|
||||
logger.info(f"[FeiShu] Parsed - images: {len(image_keys)}, text_parts: {text_parts}")
|
||||
|
||||
# 富文本消息统一作为文本消息处理
|
||||
self.ctype = ContextType.TEXT
|
||||
|
||||
if image_keys:
|
||||
# 如果包含图片,下载并在文本中引用本地路径
|
||||
workspace_root = expand_path(conf().get("agent_workspace", "~/cow"))
|
||||
tmp_dir = os.path.join(workspace_root, "tmp")
|
||||
os.makedirs(tmp_dir, exist_ok=True)
|
||||
|
||||
# 保存图片路径映射
|
||||
self.image_paths = {}
|
||||
for image_key in image_keys:
|
||||
image_path = os.path.join(tmp_dir, f"{image_key}.png")
|
||||
self.image_paths[image_key] = image_path
|
||||
|
||||
def _download_images():
|
||||
for image_key, image_path in self.image_paths.items():
|
||||
url = f"https://open.feishu.cn/open-apis/im/v1/messages/{self.msg_id}/resources/{image_key}"
|
||||
headers = {"Authorization": "Bearer " + access_token}
|
||||
params = {"type": "image"}
|
||||
response = requests.get(url=url, headers=headers, params=params)
|
||||
if response.status_code == 200:
|
||||
with open(image_path, "wb") as f:
|
||||
f.write(response.content)
|
||||
logger.info(f"[FeiShu] Image downloaded from post message, key={image_key}, path={image_path}")
|
||||
else:
|
||||
logger.error(f"[FeiShu] Failed to download image from post, key={image_key}, status={response.status_code}")
|
||||
|
||||
# 立即下载图片,不使用延迟下载
|
||||
# 因为 TEXT 类型消息不会调用 prepare()
|
||||
_download_images()
|
||||
|
||||
# 构建消息内容:文本 + 图片路径
|
||||
content_parts = []
|
||||
if text_parts:
|
||||
content_parts.append("\n".join(text_parts).strip())
|
||||
for image_key, image_path in self.image_paths.items():
|
||||
content_parts.append(f"[图片: {image_path}]")
|
||||
|
||||
self.content = "\n".join(content_parts)
|
||||
logger.info(f"[FeiShu] Received post message with {len(image_keys)} image(s) and text: {self.content}")
|
||||
else:
|
||||
# 纯文本富文本消息
|
||||
self.content = "\n".join(text_parts).strip() if text_parts else "[富文本消息]"
|
||||
logger.info(f"[FeiShu] Received post message (text only): {self.content}")
|
||||
elif msg_type == "file":
|
||||
self.ctype = ContextType.FILE
|
||||
content = json.loads(msg.get("content"))
|
||||
file_key = content.get("file_key")
|
||||
file_name = content.get("file_name")
|
||||
|
||||
self.content = TmpDir().path() + file_key + "." + utils.get_path_suffix(file_name)
|
||||
|
||||
def _download_file():
|
||||
# 如果响应状态码是200,则将响应内容写入本地文件
|
||||
url = f"https://open.feishu.cn/open-apis/im/v1/messages/{self.msg_id}/resources/{file_key}"
|
||||
headers = {
|
||||
"Authorization": "Bearer " + access_token,
|
||||
}
|
||||
params = {
|
||||
"type": "file"
|
||||
}
|
||||
response = requests.get(url=url, headers=headers, params=params)
|
||||
if response.status_code == 200:
|
||||
with open(self.content, "wb") as f:
|
||||
f.write(response.content)
|
||||
else:
|
||||
logger.info(f"[FeiShu] Failed to download file, key={file_key}, res={response.text}")
|
||||
self._prepare_fn = _download_file
|
||||
else:
|
||||
raise NotImplementedError("Unsupported message type: Type:{} ".format(msg_type))
|
||||
|
||||
self.from_user_id = sender.get("sender_id").get("open_id")
|
||||
self.to_user_id = event.get("app_id")
|
||||
if is_group:
|
||||
# 群聊
|
||||
self.other_user_id = msg.get("chat_id")
|
||||
self.actual_user_id = self.from_user_id
|
||||
self.content = self.content.replace("@_user_1", "").strip()
|
||||
self.actual_user_nickname = ""
|
||||
else:
|
||||
# 私聊
|
||||
self.other_user_id = self.from_user_id
|
||||
self.actual_user_id = self.from_user_id
|
||||
100
channel/file_cache.py
Normal file
100
channel/file_cache.py
Normal file
@@ -0,0 +1,100 @@
|
||||
"""
|
||||
文件缓存管理器
|
||||
用于缓存单独发送的文件消息(图片、视频、文档等),在用户提问时自动附加
|
||||
"""
|
||||
import time
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FileCache:
|
||||
"""文件缓存管理器,按 session_id 缓存文件,TTL=2分钟"""
|
||||
|
||||
def __init__(self, ttl=120):
|
||||
"""
|
||||
Args:
|
||||
ttl: 缓存过期时间(秒),默认2分钟
|
||||
"""
|
||||
self.cache = {}
|
||||
self.ttl = ttl
|
||||
|
||||
def add(self, session_id: str, file_path: str, file_type: str = "image"):
|
||||
"""
|
||||
添加文件到缓存
|
||||
|
||||
Args:
|
||||
session_id: 会话ID
|
||||
file_path: 文件本地路径
|
||||
file_type: 文件类型(image, video, file 等)
|
||||
"""
|
||||
if session_id not in self.cache:
|
||||
self.cache[session_id] = {
|
||||
'files': [],
|
||||
'timestamp': time.time()
|
||||
}
|
||||
|
||||
# 添加文件(去重)
|
||||
file_info = {'path': file_path, 'type': file_type}
|
||||
if file_info not in self.cache[session_id]['files']:
|
||||
self.cache[session_id]['files'].append(file_info)
|
||||
logger.info(f"[FileCache] Added {file_type} to cache for session {session_id}: {file_path}")
|
||||
|
||||
def get(self, session_id: str) -> list:
|
||||
"""
|
||||
获取缓存的文件列表
|
||||
|
||||
Args:
|
||||
session_id: 会话ID
|
||||
|
||||
Returns:
|
||||
文件信息列表 [{'path': '...', 'type': 'image'}, ...],如果没有或已过期返回空列表
|
||||
"""
|
||||
if session_id not in self.cache:
|
||||
return []
|
||||
|
||||
item = self.cache[session_id]
|
||||
|
||||
# 检查是否过期
|
||||
if time.time() - item['timestamp'] > self.ttl:
|
||||
logger.info(f"[FileCache] Cache expired for session {session_id}, clearing...")
|
||||
del self.cache[session_id]
|
||||
return []
|
||||
|
||||
return item['files']
|
||||
|
||||
def clear(self, session_id: str):
|
||||
"""
|
||||
清除指定会话的缓存
|
||||
|
||||
Args:
|
||||
session_id: 会话ID
|
||||
"""
|
||||
if session_id in self.cache:
|
||||
logger.info(f"[FileCache] Cleared cache for session {session_id}")
|
||||
del self.cache[session_id]
|
||||
|
||||
def cleanup_expired(self):
|
||||
"""清理所有过期的缓存"""
|
||||
current_time = time.time()
|
||||
expired_sessions = []
|
||||
|
||||
for session_id, item in self.cache.items():
|
||||
if current_time - item['timestamp'] > self.ttl:
|
||||
expired_sessions.append(session_id)
|
||||
|
||||
for session_id in expired_sessions:
|
||||
del self.cache[session_id]
|
||||
logger.debug(f"[FileCache] Cleaned up expired cache for session {session_id}")
|
||||
|
||||
if expired_sessions:
|
||||
logger.info(f"[FileCache] Cleaned up {len(expired_sessions)} expired cache(s)")
|
||||
|
||||
|
||||
# 全局单例
|
||||
_file_cache = FileCache()
|
||||
|
||||
|
||||
def get_file_cache() -> FileCache:
|
||||
"""获取全局文件缓存实例"""
|
||||
return _file_cache
|
||||
0
channel/qq/__init__.py
Normal file
0
channel/qq/__init__.py
Normal file
735
channel/qq/qq_channel.py
Normal file
735
channel/qq/qq_channel.py
Normal file
@@ -0,0 +1,735 @@
|
||||
"""
|
||||
QQ Bot channel via WebSocket long connection.
|
||||
|
||||
Supports:
|
||||
- Group chat (@bot), single chat (C2C), guild channel, guild DM
|
||||
- Text / image / file message send & receive
|
||||
- Heartbeat keep-alive and auto-reconnect with session resume
|
||||
"""
|
||||
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
|
||||
import requests
|
||||
import websocket
|
||||
|
||||
from bridge.context import Context, ContextType
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from channel.chat_channel import ChatChannel, check_prefix
|
||||
from channel.qq.qq_message import QQMessage
|
||||
from common.expired_dict import ExpiredDict
|
||||
from common.log import logger
|
||||
from common.singleton import singleton
|
||||
from config import conf
|
||||
|
||||
# Rich media file_type constants
|
||||
QQ_FILE_TYPE_IMAGE = 1
|
||||
QQ_FILE_TYPE_VIDEO = 2
|
||||
QQ_FILE_TYPE_VOICE = 3
|
||||
QQ_FILE_TYPE_FILE = 4
|
||||
|
||||
QQ_API_BASE = "https://api.sgroup.qq.com"
|
||||
|
||||
# Intents: GROUP_AND_C2C_EVENT(1<<25) | PUBLIC_GUILD_MESSAGES(1<<30)
|
||||
DEFAULT_INTENTS = (1 << 25) | (1 << 30)
|
||||
|
||||
# OpCode constants
|
||||
OP_DISPATCH = 0
|
||||
OP_HEARTBEAT = 1
|
||||
OP_IDENTIFY = 2
|
||||
OP_RESUME = 6
|
||||
OP_RECONNECT = 7
|
||||
OP_INVALID_SESSION = 9
|
||||
OP_HELLO = 10
|
||||
OP_HEARTBEAT_ACK = 11
|
||||
|
||||
# Resumable error codes
|
||||
RESUMABLE_CLOSE_CODES = {4008, 4009}
|
||||
|
||||
|
||||
@singleton
|
||||
class QQChannel(ChatChannel):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.app_id = ""
|
||||
self.app_secret = ""
|
||||
|
||||
self._access_token = ""
|
||||
self._token_expires_at = 0
|
||||
|
||||
self._ws = None
|
||||
self._ws_thread = None
|
||||
self._heartbeat_thread = None
|
||||
self._connected = False
|
||||
self._stop_event = threading.Event()
|
||||
self._token_lock = threading.Lock()
|
||||
|
||||
self._session_id = None
|
||||
self._last_seq = None
|
||||
self._heartbeat_interval = 45000
|
||||
self._can_resume = False
|
||||
|
||||
self.received_msgs = ExpiredDict(60 * 60 * 7.1)
|
||||
self._msg_seq_counter = {}
|
||||
|
||||
conf()["group_name_white_list"] = ["ALL_GROUP"]
|
||||
conf()["single_chat_prefix"] = [""]
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Lifecycle
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def startup(self):
|
||||
self.app_id = conf().get("qq_app_id", "")
|
||||
self.app_secret = conf().get("qq_app_secret", "")
|
||||
|
||||
if not self.app_id or not self.app_secret:
|
||||
err = "[QQ] qq_app_id and qq_app_secret are required"
|
||||
logger.error(err)
|
||||
self.report_startup_error(err)
|
||||
return
|
||||
|
||||
self._refresh_access_token()
|
||||
if not self._access_token:
|
||||
err = "[QQ] Failed to get initial access_token"
|
||||
logger.error(err)
|
||||
self.report_startup_error(err)
|
||||
return
|
||||
|
||||
self._stop_event.clear()
|
||||
self._start_ws()
|
||||
|
||||
def stop(self):
|
||||
logger.info("[QQ] stop() called")
|
||||
self._stop_event.set()
|
||||
if self._ws:
|
||||
try:
|
||||
self._ws.close()
|
||||
except Exception:
|
||||
pass
|
||||
self._ws = None
|
||||
self._connected = False
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Access Token
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _refresh_access_token(self):
|
||||
try:
|
||||
resp = requests.post(
|
||||
"https://bots.qq.com/app/getAppAccessToken",
|
||||
json={"appId": self.app_id, "clientSecret": self.app_secret},
|
||||
timeout=10,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
self._access_token = data.get("access_token", "")
|
||||
expires_in = int(data.get("expires_in", 7200))
|
||||
self._token_expires_at = time.time() + expires_in - 60
|
||||
logger.debug(f"[QQ] Access token refreshed, expires_in={expires_in}s")
|
||||
except Exception as e:
|
||||
logger.error(f"[QQ] Failed to refresh access_token: {e}")
|
||||
|
||||
def _get_access_token(self) -> str:
|
||||
with self._token_lock:
|
||||
if time.time() >= self._token_expires_at:
|
||||
self._refresh_access_token()
|
||||
return self._access_token
|
||||
|
||||
def _get_auth_headers(self) -> dict:
|
||||
return {
|
||||
"Authorization": f"QQBot {self._get_access_token()}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# WebSocket connection
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _get_ws_url(self) -> str:
|
||||
try:
|
||||
resp = requests.get(
|
||||
f"{QQ_API_BASE}/gateway",
|
||||
headers=self._get_auth_headers(),
|
||||
timeout=10,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
url = resp.json().get("url", "")
|
||||
logger.debug(f"[QQ] Gateway URL: {url}")
|
||||
return url
|
||||
except Exception as e:
|
||||
logger.error(f"[QQ] Failed to get gateway URL: {e}")
|
||||
return ""
|
||||
|
||||
def _start_ws(self):
|
||||
ws_url = self._get_ws_url()
|
||||
if not ws_url:
|
||||
logger.error("[QQ] Cannot start WebSocket without gateway URL")
|
||||
self.report_startup_error("Failed to get gateway URL")
|
||||
return
|
||||
|
||||
def _on_open(ws):
|
||||
logger.debug("[QQ] WebSocket connected, waiting for Hello...")
|
||||
|
||||
def _on_message(ws, raw):
|
||||
try:
|
||||
data = json.loads(raw)
|
||||
self._handle_ws_message(data)
|
||||
except Exception as e:
|
||||
logger.error(f"[QQ] Failed to handle ws message: {e}", exc_info=True)
|
||||
|
||||
def _on_error(ws, error):
|
||||
logger.error(f"[QQ] WebSocket error: {error}")
|
||||
|
||||
def _on_close(ws, close_status_code, close_msg):
|
||||
logger.warning(f"[QQ] WebSocket closed: status={close_status_code}, msg={close_msg}")
|
||||
self._connected = False
|
||||
if not self._stop_event.is_set():
|
||||
if close_status_code in RESUMABLE_CLOSE_CODES and self._session_id:
|
||||
self._can_resume = True
|
||||
logger.info("[QQ] Will attempt resume in 3s...")
|
||||
time.sleep(3)
|
||||
else:
|
||||
self._can_resume = False
|
||||
logger.info("[QQ] Will reconnect in 5s...")
|
||||
time.sleep(5)
|
||||
if not self._stop_event.is_set():
|
||||
self._start_ws()
|
||||
|
||||
self._ws = websocket.WebSocketApp(
|
||||
ws_url,
|
||||
on_open=_on_open,
|
||||
on_message=_on_message,
|
||||
on_error=_on_error,
|
||||
on_close=_on_close,
|
||||
)
|
||||
|
||||
def run_forever():
|
||||
try:
|
||||
self._ws.run_forever(ping_interval=0, reconnect=0)
|
||||
except (SystemExit, KeyboardInterrupt):
|
||||
logger.info("[QQ] WebSocket thread interrupted")
|
||||
except Exception as e:
|
||||
logger.error(f"[QQ] WebSocket run_forever error: {e}")
|
||||
|
||||
self._ws_thread = threading.Thread(target=run_forever, daemon=True)
|
||||
self._ws_thread.start()
|
||||
self._ws_thread.join()
|
||||
|
||||
def _ws_send(self, data: dict):
|
||||
if self._ws:
|
||||
self._ws.send(json.dumps(data, ensure_ascii=False))
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Identify & Resume & Heartbeat
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _send_identify(self):
|
||||
self._ws_send({
|
||||
"op": OP_IDENTIFY,
|
||||
"d": {
|
||||
"token": f"QQBot {self._get_access_token()}",
|
||||
"intents": DEFAULT_INTENTS,
|
||||
"shard": [0, 1],
|
||||
"properties": {
|
||||
"$os": "linux",
|
||||
"$browser": "chatgpt-on-wechat",
|
||||
"$device": "chatgpt-on-wechat",
|
||||
},
|
||||
},
|
||||
})
|
||||
logger.debug(f"[QQ] Identify sent with intents={DEFAULT_INTENTS}")
|
||||
|
||||
def _send_resume(self):
|
||||
self._ws_send({
|
||||
"op": OP_RESUME,
|
||||
"d": {
|
||||
"token": f"QQBot {self._get_access_token()}",
|
||||
"session_id": self._session_id,
|
||||
"seq": self._last_seq,
|
||||
},
|
||||
})
|
||||
logger.debug(f"[QQ] Resume sent: session_id={self._session_id}, seq={self._last_seq}")
|
||||
|
||||
def _start_heartbeat(self, interval_ms: int):
|
||||
if self._heartbeat_thread and self._heartbeat_thread.is_alive():
|
||||
return
|
||||
self._heartbeat_interval = interval_ms
|
||||
interval_sec = interval_ms / 1000.0
|
||||
|
||||
def heartbeat_loop():
|
||||
while not self._stop_event.is_set() and self._connected:
|
||||
try:
|
||||
self._ws_send({
|
||||
"op": OP_HEARTBEAT,
|
||||
"d": self._last_seq,
|
||||
})
|
||||
except Exception as e:
|
||||
logger.warning(f"[QQ] Heartbeat send failed: {e}")
|
||||
break
|
||||
self._stop_event.wait(interval_sec)
|
||||
|
||||
self._heartbeat_thread = threading.Thread(target=heartbeat_loop, daemon=True)
|
||||
self._heartbeat_thread.start()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Incoming message dispatch
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _handle_ws_message(self, data: dict):
|
||||
op = data.get("op")
|
||||
d = data.get("d")
|
||||
t = data.get("t")
|
||||
s = data.get("s")
|
||||
|
||||
if s is not None:
|
||||
self._last_seq = s
|
||||
|
||||
if op == OP_HELLO:
|
||||
heartbeat_interval = d.get("heartbeat_interval", 45000) if d else 45000
|
||||
logger.debug(f"[QQ] Received Hello, heartbeat_interval={heartbeat_interval}ms")
|
||||
self._heartbeat_interval = heartbeat_interval
|
||||
if self._can_resume and self._session_id:
|
||||
self._send_resume()
|
||||
else:
|
||||
self._send_identify()
|
||||
|
||||
elif op == OP_HEARTBEAT_ACK:
|
||||
logger.debug("[QQ] Heartbeat ACK received")
|
||||
|
||||
elif op == OP_HEARTBEAT:
|
||||
self._ws_send({"op": OP_HEARTBEAT, "d": self._last_seq})
|
||||
|
||||
elif op == OP_RECONNECT:
|
||||
logger.warning("[QQ] Server requested reconnect")
|
||||
self._can_resume = True
|
||||
if self._ws:
|
||||
self._ws.close()
|
||||
|
||||
elif op == OP_INVALID_SESSION:
|
||||
logger.warning("[QQ] Invalid session, re-identifying...")
|
||||
self._session_id = None
|
||||
self._can_resume = False
|
||||
time.sleep(2)
|
||||
self._send_identify()
|
||||
|
||||
elif op == OP_DISPATCH:
|
||||
if t == "READY":
|
||||
self._session_id = d.get("session_id", "")
|
||||
user = d.get("user", {})
|
||||
bot_name = user.get('username', '')
|
||||
logger.info(f"[QQ] ✅ Connected successfully (bot={bot_name})")
|
||||
self._connected = True
|
||||
self._can_resume = False
|
||||
self._start_heartbeat(self._heartbeat_interval)
|
||||
self.report_startup_success()
|
||||
|
||||
elif t == "RESUMED":
|
||||
logger.info("[QQ] Session resumed successfully")
|
||||
self._connected = True
|
||||
self._can_resume = False
|
||||
self._start_heartbeat(self._heartbeat_interval)
|
||||
|
||||
elif t in ("GROUP_AT_MESSAGE_CREATE", "C2C_MESSAGE_CREATE",
|
||||
"AT_MESSAGE_CREATE", "DIRECT_MESSAGE_CREATE"):
|
||||
self._handle_msg_event(d, t)
|
||||
|
||||
elif t in ("GROUP_ADD_ROBOT", "FRIEND_ADD"):
|
||||
logger.info(f"[QQ] Event: {t}")
|
||||
|
||||
else:
|
||||
logger.debug(f"[QQ] Dispatch event: {t}")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Message event handling
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _handle_msg_event(self, event_data: dict, event_type: str):
|
||||
msg_id = event_data.get("id", "")
|
||||
if self.received_msgs.get(msg_id):
|
||||
logger.debug(f"[QQ] Duplicate msg filtered: {msg_id}")
|
||||
return
|
||||
self.received_msgs[msg_id] = True
|
||||
|
||||
try:
|
||||
qq_msg = QQMessage(event_data, event_type)
|
||||
except NotImplementedError as e:
|
||||
logger.warning(f"[QQ] {e}")
|
||||
return
|
||||
except Exception as e:
|
||||
logger.error(f"[QQ] Failed to parse message: {e}", exc_info=True)
|
||||
return
|
||||
|
||||
is_group = qq_msg.is_group
|
||||
|
||||
from channel.file_cache import get_file_cache
|
||||
file_cache = get_file_cache()
|
||||
|
||||
if is_group:
|
||||
session_id = qq_msg.other_user_id
|
||||
else:
|
||||
session_id = qq_msg.from_user_id
|
||||
|
||||
if qq_msg.ctype == ContextType.IMAGE:
|
||||
if hasattr(qq_msg, "image_path") and qq_msg.image_path:
|
||||
file_cache.add(session_id, qq_msg.image_path, file_type="image")
|
||||
logger.info(f"[QQ] Image cached for session {session_id}")
|
||||
return
|
||||
|
||||
if qq_msg.ctype == ContextType.TEXT:
|
||||
cached_files = file_cache.get(session_id)
|
||||
if cached_files:
|
||||
file_refs = []
|
||||
for fi in cached_files:
|
||||
ftype = fi["type"]
|
||||
fpath = fi["path"]
|
||||
if ftype == "image":
|
||||
file_refs.append(f"[图片: {fpath}]")
|
||||
elif ftype == "video":
|
||||
file_refs.append(f"[视频: {fpath}]")
|
||||
else:
|
||||
file_refs.append(f"[文件: {fpath}]")
|
||||
qq_msg.content = qq_msg.content + "\n" + "\n".join(file_refs)
|
||||
logger.info(f"[QQ] Attached {len(cached_files)} cached file(s)")
|
||||
file_cache.clear(session_id)
|
||||
|
||||
context = self._compose_context(
|
||||
qq_msg.ctype,
|
||||
qq_msg.content,
|
||||
isgroup=is_group,
|
||||
msg=qq_msg,
|
||||
no_need_at=True,
|
||||
)
|
||||
if context:
|
||||
self.produce(context)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# _compose_context
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _compose_context(self, ctype: ContextType, content, **kwargs):
|
||||
context = Context(ctype, content)
|
||||
context.kwargs = kwargs
|
||||
if "channel_type" not in context:
|
||||
context["channel_type"] = self.channel_type
|
||||
if "origin_ctype" not in context:
|
||||
context["origin_ctype"] = ctype
|
||||
|
||||
cmsg = context["msg"]
|
||||
|
||||
if cmsg.is_group:
|
||||
context["session_id"] = cmsg.other_user_id
|
||||
else:
|
||||
context["session_id"] = cmsg.from_user_id
|
||||
|
||||
context["receiver"] = cmsg.other_user_id
|
||||
|
||||
if ctype == ContextType.TEXT:
|
||||
img_match_prefix = check_prefix(content, conf().get("image_create_prefix"))
|
||||
if img_match_prefix:
|
||||
content = content.replace(img_match_prefix, "", 1)
|
||||
context.type = ContextType.IMAGE_CREATE
|
||||
else:
|
||||
context.type = ContextType.TEXT
|
||||
context.content = content.strip()
|
||||
|
||||
return context
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Send reply
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def send(self, reply: Reply, context: Context):
|
||||
msg = context.get("msg")
|
||||
is_group = context.get("isgroup", False)
|
||||
receiver = context.get("receiver", "")
|
||||
|
||||
if not msg:
|
||||
# Active send (e.g. scheduled tasks), no original message to reply to
|
||||
self._active_send_text(reply.content if reply.type == ReplyType.TEXT else str(reply.content),
|
||||
receiver, is_group)
|
||||
return
|
||||
|
||||
event_type = getattr(msg, "event_type", "")
|
||||
msg_id = getattr(msg, "msg_id", "")
|
||||
|
||||
if reply.type == ReplyType.TEXT:
|
||||
self._send_text(reply.content, msg, event_type, msg_id)
|
||||
elif reply.type in (ReplyType.IMAGE_URL, ReplyType.IMAGE):
|
||||
self._send_image(reply.content, msg, event_type, msg_id)
|
||||
elif reply.type == ReplyType.FILE:
|
||||
if hasattr(reply, "text_content") and reply.text_content:
|
||||
self._send_text(reply.text_content, msg, event_type, msg_id)
|
||||
time.sleep(0.3)
|
||||
self._send_file(reply.content, msg, event_type, msg_id)
|
||||
elif reply.type in (ReplyType.VIDEO, ReplyType.VIDEO_URL):
|
||||
self._send_media(reply.content, msg, event_type, msg_id, QQ_FILE_TYPE_VIDEO)
|
||||
else:
|
||||
logger.warning(f"[QQ] Unsupported reply type: {reply.type}, falling back to text")
|
||||
self._send_text(str(reply.content), msg, event_type, msg_id)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Send helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _get_next_msg_seq(self, msg_id: str) -> int:
|
||||
seq = self._msg_seq_counter.get(msg_id, 1)
|
||||
self._msg_seq_counter[msg_id] = seq + 1
|
||||
return seq
|
||||
|
||||
def _build_msg_url_and_base_body(self, msg: QQMessage, event_type: str, msg_id: str):
|
||||
"""Build the API URL and base body dict for sending a message."""
|
||||
if event_type == "GROUP_AT_MESSAGE_CREATE":
|
||||
group_openid = msg._rawmsg.get("group_openid", "")
|
||||
url = f"{QQ_API_BASE}/v2/groups/{group_openid}/messages"
|
||||
body = {
|
||||
"msg_id": msg_id,
|
||||
"msg_seq": self._get_next_msg_seq(msg_id),
|
||||
}
|
||||
return url, body, "group", group_openid
|
||||
|
||||
elif event_type == "C2C_MESSAGE_CREATE":
|
||||
user_openid = msg._rawmsg.get("author", {}).get("user_openid", "") or msg.from_user_id
|
||||
url = f"{QQ_API_BASE}/v2/users/{user_openid}/messages"
|
||||
body = {
|
||||
"msg_id": msg_id,
|
||||
"msg_seq": self._get_next_msg_seq(msg_id),
|
||||
}
|
||||
return url, body, "c2c", user_openid
|
||||
|
||||
elif event_type == "AT_MESSAGE_CREATE":
|
||||
channel_id = msg._rawmsg.get("channel_id", "")
|
||||
url = f"{QQ_API_BASE}/channels/{channel_id}/messages"
|
||||
body = {"msg_id": msg_id}
|
||||
return url, body, "channel", channel_id
|
||||
|
||||
elif event_type == "DIRECT_MESSAGE_CREATE":
|
||||
guild_id = msg._rawmsg.get("guild_id", "")
|
||||
url = f"{QQ_API_BASE}/dms/{guild_id}/messages"
|
||||
body = {"msg_id": msg_id}
|
||||
return url, body, "dm", guild_id
|
||||
|
||||
return None, None, None, None
|
||||
|
||||
def _post_message(self, url: str, body: dict, event_type: str):
|
||||
try:
|
||||
resp = requests.post(url, json=body, headers=self._get_auth_headers(), timeout=10)
|
||||
if resp.status_code in (200, 201, 202, 204):
|
||||
logger.info(f"[QQ] Message sent successfully: event_type={event_type}")
|
||||
else:
|
||||
logger.error(f"[QQ] Failed to send message: status={resp.status_code}, "
|
||||
f"body={resp.text}")
|
||||
except Exception as e:
|
||||
logger.error(f"[QQ] Send message error: {e}")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Active send (no original message, e.g. scheduled tasks)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _active_send_text(self, content: str, receiver: str, is_group: bool):
|
||||
"""Send text without an original message (active push). QQ limits active messages to 4/month per user."""
|
||||
if not receiver:
|
||||
logger.warning("[QQ] No receiver for active send")
|
||||
return
|
||||
if is_group:
|
||||
url = f"{QQ_API_BASE}/v2/groups/{receiver}/messages"
|
||||
else:
|
||||
url = f"{QQ_API_BASE}/v2/users/{receiver}/messages"
|
||||
body = {
|
||||
"content": content,
|
||||
"msg_type": 0,
|
||||
}
|
||||
event_label = "GROUP_ACTIVE" if is_group else "C2C_ACTIVE"
|
||||
self._post_message(url, body, event_label)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Send text
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _send_text(self, content: str, msg: QQMessage, event_type: str, msg_id: str):
|
||||
url, body, _, _ = self._build_msg_url_and_base_body(msg, event_type, msg_id)
|
||||
if not url:
|
||||
logger.warning(f"[QQ] Cannot send reply for event_type: {event_type}")
|
||||
return
|
||||
body["content"] = content
|
||||
body["msg_type"] = 0
|
||||
self._post_message(url, body, event_type)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Rich media upload & send (image / video / file)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _upload_rich_media(self, file_url: str, file_type: int, msg: QQMessage,
|
||||
event_type: str) -> str:
|
||||
"""
|
||||
Upload media via QQ rich media API and return file_info.
|
||||
For group: POST /v2/groups/{group_openid}/files
|
||||
For c2c: POST /v2/users/{openid}/files
|
||||
"""
|
||||
if event_type == "GROUP_AT_MESSAGE_CREATE":
|
||||
group_openid = msg._rawmsg.get("group_openid", "")
|
||||
upload_url = f"{QQ_API_BASE}/v2/groups/{group_openid}/files"
|
||||
elif event_type == "C2C_MESSAGE_CREATE":
|
||||
user_openid = (msg._rawmsg.get("author", {}).get("user_openid", "")
|
||||
or msg.from_user_id)
|
||||
upload_url = f"{QQ_API_BASE}/v2/users/{user_openid}/files"
|
||||
else:
|
||||
logger.warning(f"[QQ] Rich media upload not supported for event_type: {event_type}")
|
||||
return ""
|
||||
|
||||
upload_body = {
|
||||
"file_type": file_type,
|
||||
"url": file_url,
|
||||
"srv_send_msg": False,
|
||||
}
|
||||
|
||||
try:
|
||||
resp = requests.post(
|
||||
upload_url, json=upload_body,
|
||||
headers=self._get_auth_headers(), timeout=30,
|
||||
)
|
||||
if resp.status_code in (200, 201):
|
||||
data = resp.json()
|
||||
file_info = data.get("file_info", "")
|
||||
logger.info(f"[QQ] Rich media uploaded: file_type={file_type}, "
|
||||
f"file_uuid={data.get('file_uuid', '')}")
|
||||
return file_info
|
||||
else:
|
||||
logger.error(f"[QQ] Rich media upload failed: status={resp.status_code}, "
|
||||
f"body={resp.text}")
|
||||
return ""
|
||||
except Exception as e:
|
||||
logger.error(f"[QQ] Rich media upload error: {e}")
|
||||
return ""
|
||||
|
||||
def _upload_rich_media_base64(self, file_path: str, file_type: int, msg: QQMessage,
|
||||
event_type: str) -> str:
|
||||
"""Upload local file via base64 file_data field."""
|
||||
if event_type == "GROUP_AT_MESSAGE_CREATE":
|
||||
group_openid = msg._rawmsg.get("group_openid", "")
|
||||
upload_url = f"{QQ_API_BASE}/v2/groups/{group_openid}/files"
|
||||
elif event_type == "C2C_MESSAGE_CREATE":
|
||||
user_openid = (msg._rawmsg.get("author", {}).get("user_openid", "")
|
||||
or msg.from_user_id)
|
||||
upload_url = f"{QQ_API_BASE}/v2/users/{user_openid}/files"
|
||||
else:
|
||||
logger.warning(f"[QQ] Rich media upload not supported for event_type: {event_type}")
|
||||
return ""
|
||||
|
||||
try:
|
||||
with open(file_path, "rb") as f:
|
||||
file_data = base64.b64encode(f.read()).decode("utf-8")
|
||||
except Exception as e:
|
||||
logger.error(f"[QQ] Failed to read file for upload: {e}")
|
||||
return ""
|
||||
|
||||
upload_body = {
|
||||
"file_type": file_type,
|
||||
"file_data": file_data,
|
||||
"srv_send_msg": False,
|
||||
}
|
||||
|
||||
try:
|
||||
resp = requests.post(
|
||||
upload_url, json=upload_body,
|
||||
headers=self._get_auth_headers(), timeout=30,
|
||||
)
|
||||
if resp.status_code in (200, 201):
|
||||
data = resp.json()
|
||||
file_info = data.get("file_info", "")
|
||||
logger.info(f"[QQ] Rich media uploaded (base64): file_type={file_type}, "
|
||||
f"file_uuid={data.get('file_uuid', '')}")
|
||||
return file_info
|
||||
else:
|
||||
logger.error(f"[QQ] Rich media upload (base64) failed: status={resp.status_code}, "
|
||||
f"body={resp.text}")
|
||||
return ""
|
||||
except Exception as e:
|
||||
logger.error(f"[QQ] Rich media upload (base64) error: {e}")
|
||||
return ""
|
||||
|
||||
def _send_media_msg(self, file_info: str, msg: QQMessage, event_type: str, msg_id: str):
|
||||
"""Send a message with msg_type=7 (rich media) using file_info."""
|
||||
url, body, _, _ = self._build_msg_url_and_base_body(msg, event_type, msg_id)
|
||||
if not url:
|
||||
return
|
||||
body["msg_type"] = 7
|
||||
body["media"] = {"file_info": file_info}
|
||||
self._post_message(url, body, event_type)
|
||||
|
||||
def _send_image(self, img_path_or_url: str, msg: QQMessage, event_type: str, msg_id: str):
|
||||
"""Send image reply. Supports URL and local file path."""
|
||||
if event_type not in ("GROUP_AT_MESSAGE_CREATE", "C2C_MESSAGE_CREATE"):
|
||||
self._send_text(str(img_path_or_url), msg, event_type, msg_id)
|
||||
return
|
||||
|
||||
if img_path_or_url.startswith("file://"):
|
||||
img_path_or_url = img_path_or_url[7:]
|
||||
|
||||
if img_path_or_url.startswith(("http://", "https://")):
|
||||
file_info = self._upload_rich_media(
|
||||
img_path_or_url, QQ_FILE_TYPE_IMAGE, msg, event_type)
|
||||
elif os.path.exists(img_path_or_url):
|
||||
file_info = self._upload_rich_media_base64(
|
||||
img_path_or_url, QQ_FILE_TYPE_IMAGE, msg, event_type)
|
||||
else:
|
||||
logger.error(f"[QQ] Image not found: {img_path_or_url}")
|
||||
self._send_text("[Image send failed]", msg, event_type, msg_id)
|
||||
return
|
||||
|
||||
if file_info:
|
||||
self._send_media_msg(file_info, msg, event_type, msg_id)
|
||||
else:
|
||||
self._send_text("[Image upload failed]", msg, event_type, msg_id)
|
||||
|
||||
def _send_file(self, file_path_or_url: str, msg: QQMessage, event_type: str, msg_id: str):
|
||||
"""Send file reply."""
|
||||
if event_type not in ("GROUP_AT_MESSAGE_CREATE", "C2C_MESSAGE_CREATE"):
|
||||
self._send_text(str(file_path_or_url), msg, event_type, msg_id)
|
||||
return
|
||||
|
||||
if file_path_or_url.startswith("file://"):
|
||||
file_path_or_url = file_path_or_url[7:]
|
||||
|
||||
if file_path_or_url.startswith(("http://", "https://")):
|
||||
file_info = self._upload_rich_media(
|
||||
file_path_or_url, QQ_FILE_TYPE_FILE, msg, event_type)
|
||||
elif os.path.exists(file_path_or_url):
|
||||
file_info = self._upload_rich_media_base64(
|
||||
file_path_or_url, QQ_FILE_TYPE_FILE, msg, event_type)
|
||||
else:
|
||||
logger.error(f"[QQ] File not found: {file_path_or_url}")
|
||||
self._send_text("[File send failed]", msg, event_type, msg_id)
|
||||
return
|
||||
|
||||
if file_info:
|
||||
self._send_media_msg(file_info, msg, event_type, msg_id)
|
||||
else:
|
||||
self._send_text("[File upload failed]", msg, event_type, msg_id)
|
||||
|
||||
def _send_media(self, path_or_url: str, msg: QQMessage, event_type: str,
|
||||
msg_id: str, file_type: int):
|
||||
"""Generic media send for video/voice etc."""
|
||||
if event_type not in ("GROUP_AT_MESSAGE_CREATE", "C2C_MESSAGE_CREATE"):
|
||||
self._send_text(str(path_or_url), msg, event_type, msg_id)
|
||||
return
|
||||
|
||||
if path_or_url.startswith("file://"):
|
||||
path_or_url = path_or_url[7:]
|
||||
|
||||
if path_or_url.startswith(("http://", "https://")):
|
||||
file_info = self._upload_rich_media(path_or_url, file_type, msg, event_type)
|
||||
elif os.path.exists(path_or_url):
|
||||
file_info = self._upload_rich_media_base64(path_or_url, file_type, msg, event_type)
|
||||
else:
|
||||
logger.error(f"[QQ] Media not found: {path_or_url}")
|
||||
return
|
||||
|
||||
if file_info:
|
||||
self._send_media_msg(file_info, msg, event_type, msg_id)
|
||||
else:
|
||||
logger.error(f"[QQ] Media upload failed: {path_or_url}")
|
||||
123
channel/qq/qq_message.py
Normal file
123
channel/qq/qq_message.py
Normal file
@@ -0,0 +1,123 @@
|
||||
import os
|
||||
import requests
|
||||
|
||||
from bridge.context import ContextType
|
||||
from channel.chat_message import ChatMessage
|
||||
from common.log import logger
|
||||
from common.utils import expand_path
|
||||
from config import conf
|
||||
|
||||
|
||||
def _get_tmp_dir() -> str:
|
||||
"""Return the workspace tmp directory (absolute path), creating it if needed."""
|
||||
ws_root = expand_path(conf().get("agent_workspace", "~/cow"))
|
||||
tmp_dir = os.path.join(ws_root, "tmp")
|
||||
os.makedirs(tmp_dir, exist_ok=True)
|
||||
return tmp_dir
|
||||
|
||||
|
||||
class QQMessage(ChatMessage):
|
||||
"""Message wrapper for QQ Bot (websocket long-connection mode)."""
|
||||
|
||||
def __init__(self, event_data: dict, event_type: str):
|
||||
super().__init__(event_data)
|
||||
self.msg_id = event_data.get("id", "")
|
||||
self.create_time = event_data.get("timestamp", "")
|
||||
self.is_group = event_type in ("GROUP_AT_MESSAGE_CREATE",)
|
||||
self.event_type = event_type
|
||||
|
||||
author = event_data.get("author", {})
|
||||
from_user_id = author.get("member_openid", "") or author.get("id", "")
|
||||
group_openid = event_data.get("group_openid", "")
|
||||
|
||||
content = event_data.get("content", "").strip()
|
||||
|
||||
attachments = event_data.get("attachments", [])
|
||||
has_image = any(
|
||||
a.get("content_type", "").startswith("image/") for a in attachments
|
||||
) if attachments else False
|
||||
|
||||
if has_image and not content:
|
||||
self.ctype = ContextType.IMAGE
|
||||
img_attachment = next(
|
||||
a for a in attachments if a.get("content_type", "").startswith("image/")
|
||||
)
|
||||
img_url = img_attachment.get("url", "")
|
||||
if img_url and not img_url.startswith("http"):
|
||||
img_url = "https://" + img_url
|
||||
tmp_dir = _get_tmp_dir()
|
||||
image_path = os.path.join(tmp_dir, f"qq_{self.msg_id}.png")
|
||||
try:
|
||||
resp = requests.get(img_url, timeout=30)
|
||||
resp.raise_for_status()
|
||||
with open(image_path, "wb") as f:
|
||||
f.write(resp.content)
|
||||
self.content = image_path
|
||||
self.image_path = image_path
|
||||
logger.info(f"[QQ] Image downloaded: {image_path}")
|
||||
except Exception as e:
|
||||
logger.error(f"[QQ] Failed to download image: {e}")
|
||||
self.content = "[Image download failed]"
|
||||
self.image_path = None
|
||||
elif has_image and content:
|
||||
self.ctype = ContextType.TEXT
|
||||
image_paths = []
|
||||
tmp_dir = _get_tmp_dir()
|
||||
for idx, att in enumerate(attachments):
|
||||
if not att.get("content_type", "").startswith("image/"):
|
||||
continue
|
||||
img_url = att.get("url", "")
|
||||
if img_url and not img_url.startswith("http"):
|
||||
img_url = "https://" + img_url
|
||||
img_path = os.path.join(tmp_dir, f"qq_{self.msg_id}_{idx}.png")
|
||||
try:
|
||||
resp = requests.get(img_url, timeout=30)
|
||||
resp.raise_for_status()
|
||||
with open(img_path, "wb") as f:
|
||||
f.write(resp.content)
|
||||
image_paths.append(img_path)
|
||||
except Exception as e:
|
||||
logger.error(f"[QQ] Failed to download mixed image: {e}")
|
||||
content_parts = [content]
|
||||
for p in image_paths:
|
||||
content_parts.append(f"[图片: {p}]")
|
||||
self.content = "\n".join(content_parts)
|
||||
else:
|
||||
self.ctype = ContextType.TEXT
|
||||
self.content = content
|
||||
|
||||
if event_type == "GROUP_AT_MESSAGE_CREATE":
|
||||
self.from_user_id = from_user_id
|
||||
self.to_user_id = ""
|
||||
self.other_user_id = group_openid
|
||||
self.actual_user_id = from_user_id
|
||||
self.actual_user_nickname = from_user_id
|
||||
|
||||
elif event_type == "C2C_MESSAGE_CREATE":
|
||||
user_openid = author.get("user_openid", "") or from_user_id
|
||||
self.from_user_id = user_openid
|
||||
self.to_user_id = ""
|
||||
self.other_user_id = user_openid
|
||||
self.actual_user_id = user_openid
|
||||
|
||||
elif event_type == "AT_MESSAGE_CREATE":
|
||||
self.from_user_id = from_user_id
|
||||
self.to_user_id = ""
|
||||
channel_id = event_data.get("channel_id", "")
|
||||
self.other_user_id = channel_id
|
||||
self.actual_user_id = from_user_id
|
||||
self.actual_user_nickname = author.get("username", from_user_id)
|
||||
|
||||
elif event_type == "DIRECT_MESSAGE_CREATE":
|
||||
self.from_user_id = from_user_id
|
||||
self.to_user_id = ""
|
||||
guild_id = event_data.get("guild_id", "")
|
||||
self.other_user_id = f"dm_{guild_id}_{from_user_id}"
|
||||
self.actual_user_id = from_user_id
|
||||
self.actual_user_nickname = author.get("username", from_user_id)
|
||||
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported QQ event type: {event_type}")
|
||||
|
||||
logger.debug(f"[QQ] Message parsed: type={event_type}, ctype={self.ctype}, "
|
||||
f"from={self.from_user_id}, content_len={len(self.content)}")
|
||||
@@ -1,14 +1,23 @@
|
||||
import sys
|
||||
|
||||
from bridge.context import *
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from channel.chat_channel import ChatChannel, check_prefix
|
||||
from channel.chat_message import ChatMessage
|
||||
import sys
|
||||
|
||||
from config import conf
|
||||
from common.log import logger
|
||||
from config import conf
|
||||
|
||||
|
||||
class TerminalMessage(ChatMessage):
|
||||
def __init__(self, msg_id, content, ctype = ContextType.TEXT, from_user_id = "User", to_user_id = "Chatgpt", other_user_id = "Chatgpt"):
|
||||
def __init__(
|
||||
self,
|
||||
msg_id,
|
||||
content,
|
||||
ctype=ContextType.TEXT,
|
||||
from_user_id="User",
|
||||
to_user_id="Chatgpt",
|
||||
other_user_id="Chatgpt",
|
||||
):
|
||||
self.msg_id = msg_id
|
||||
self.ctype = ctype
|
||||
self.content = content
|
||||
@@ -16,6 +25,7 @@ class TerminalMessage(ChatMessage):
|
||||
self.to_user_id = to_user_id
|
||||
self.other_user_id = other_user_id
|
||||
|
||||
|
||||
class TerminalChannel(ChatChannel):
|
||||
NOT_SUPPORT_REPLYTYPE = [ReplyType.VOICE]
|
||||
|
||||
@@ -23,14 +33,18 @@ class TerminalChannel(ChatChannel):
|
||||
print("\nBot:")
|
||||
if reply.type == ReplyType.IMAGE:
|
||||
from PIL import Image
|
||||
|
||||
image_storage = reply.content
|
||||
image_storage.seek(0)
|
||||
img = Image.open(image_storage)
|
||||
print("<IMAGE>")
|
||||
img.show()
|
||||
elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片
|
||||
elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片
|
||||
import io
|
||||
|
||||
import requests
|
||||
from PIL import Image
|
||||
import requests,io
|
||||
|
||||
img_url = reply.content
|
||||
pic_res = requests.get(img_url, stream=True)
|
||||
image_storage = io.BytesIO()
|
||||
@@ -59,11 +73,12 @@ class TerminalChannel(ChatChannel):
|
||||
print("\nExiting...")
|
||||
sys.exit()
|
||||
msg_id += 1
|
||||
trigger_prefixs = conf().get("single_chat_prefix",[""])
|
||||
trigger_prefixs = conf().get("single_chat_prefix", [""])
|
||||
if check_prefix(prompt, trigger_prefixs) is None:
|
||||
prompt = trigger_prefixs[0] + prompt # 给没触发的消息加上触发前缀
|
||||
|
||||
context = self._compose_context(ContextType.TEXT, prompt, msg = TerminalMessage(msg_id, prompt))
|
||||
prompt = trigger_prefixs[0] + prompt # 给没触发的消息加上触发前缀
|
||||
|
||||
context = self._compose_context(ContextType.TEXT, prompt, msg=TerminalMessage(msg_id, prompt))
|
||||
context["isgroup"] = False
|
||||
if context:
|
||||
self.produce(context)
|
||||
else:
|
||||
|
||||
10
channel/web/README.md
Normal file
10
channel/web/README.md
Normal file
@@ -0,0 +1,10 @@
|
||||
# Web Channel
|
||||
|
||||
提供了一个默认的AI对话页面,可展示文本、图片等消息交互,支持markdown语法渲染,兼容插件执行。
|
||||
|
||||
# 使用说明
|
||||
|
||||
- 在 `config.json` 配置文件中的 `channel_type` 字段填入 `web`
|
||||
- 程序运行后将监听9899端口,浏览器访问 http://localhost:9899/chat 即可使用
|
||||
- 监听端口可以在配置文件 `web_port` 中自定义
|
||||
- 对于Docker运行方式,如果需要外部访问,需要在 `docker-compose.yml` 中通过 ports配置将端口监听映射到宿主机
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user