mirror of
https://github.com/zhayujie/chatgpt-on-wechat.git
synced 2026-06-02 18:17:11 +08:00
Compare commits
1041 Commits
1.2.2.1
...
feat-web-c
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7d258b5202 | ||
|
|
5edbf4ce32 | ||
|
|
3ddbdd713d | ||
|
|
9ba107b511 | ||
|
|
c9adddb76a | ||
|
|
f0a12d5ff5 | ||
|
|
7cce224499 | ||
|
|
97397ca585 | ||
|
|
f2fbc602a8 | ||
|
|
925d728a86 | ||
|
|
f5f229871b | ||
|
|
9917552b4b | ||
|
|
adca89b973 | ||
|
|
29bfbecdc9 | ||
|
|
1a7a8c98d9 | ||
|
|
cddb38ac3d | ||
|
|
394853c0fb | ||
|
|
c0702c8b36 | ||
|
|
d610608391 | ||
|
|
9082eec91d | ||
|
|
f1a1413b5f | ||
|
|
c1e7f9af9b | ||
|
|
1c71c4e38b | ||
|
|
5e3eccb3f6 | ||
|
|
e1dc037eb9 | ||
|
|
97e9b4c801 | ||
|
|
52d7cad735 | ||
|
|
c0b1d270ba | ||
|
|
e59a2892e4 | ||
|
|
5fa0376a49 | ||
|
|
05a33042c8 | ||
|
|
ce58f23cbc | ||
|
|
b6fc9fa370 | ||
|
|
00ae38faae | ||
|
|
ab28ee58ab | ||
|
|
48db538a2e | ||
|
|
46945942e1 | ||
|
|
a24b26a1ef | ||
|
|
6f8421cdd5 | ||
|
|
284cd9bca9 | ||
|
|
23fd6b8d2b | ||
|
|
4f0ea5d756 | ||
|
|
6c218331b1 | ||
|
|
cea7fb7490 | ||
|
|
8acf2dbdfe | ||
|
|
0542700f90 | ||
|
|
5264f7ce18 | ||
|
|
051ffd78a3 | ||
|
|
bea95d4fae | ||
|
|
fdf7bc312f | ||
|
|
5b094e1097 | ||
|
|
9ad3968084 | ||
|
|
3958b6aae1 | ||
|
|
eaa413caf0 | ||
|
|
9095225b5b | ||
|
|
c529f86dbc | ||
|
|
e4fcfa356a | ||
|
|
8218cff7c1 | ||
|
|
6949bbcf39 | ||
|
|
480c60c0a7 | ||
|
|
eec10cb5db | ||
|
|
02c83d8689 | ||
|
|
72b1cacea1 | ||
|
|
c72cda3386 | ||
|
|
867442155e | ||
|
|
229b14b6fc | ||
|
|
158c87ab8b | ||
|
|
cb303e6109 | ||
|
|
a77a8741b5 | ||
|
|
3d63459c25 | ||
|
|
ce63de3c58 | ||
|
|
4b3b1219b5 | ||
|
|
73b069a76c | ||
|
|
101cf8d108 | ||
|
|
2e926dfb6e | ||
|
|
501866d12a | ||
|
|
39bcb0869f | ||
|
|
a7b99cde4e | ||
|
|
60abcd92a3 | ||
|
|
cdd36e7052 | ||
|
|
c6ac175ce4 | ||
|
|
46bcd87c23 | ||
|
|
ab74be8e33 | ||
|
|
d8298b3eab | ||
|
|
50e60e6d05 | ||
|
|
5d02acbf37 | ||
|
|
8901d91f96 | ||
|
|
b55021bb3d | ||
|
|
0ef51b85e6 | ||
|
|
c77566cc02 | ||
|
|
c1bcedfb51 | ||
|
|
d085a3c7d7 | ||
|
|
46fa07e4a9 | ||
|
|
a8d5309c90 | ||
|
|
77c2bfcc1e | ||
|
|
4c8712d683 | ||
|
|
d337140577 | ||
|
|
99c273a293 | ||
|
|
85578a06b7 | ||
|
|
6f70a8efda | ||
|
|
c693e39196 | ||
|
|
4a1fae3cb4 | ||
|
|
08b592816b | ||
|
|
0e85fcfe51 | ||
|
|
8ef788e799 | ||
|
|
645c8899b1 | ||
|
|
9bf5b0fc48 | ||
|
|
07959a3bff | ||
|
|
86a6182e41 | ||
|
|
89e229ab75 | ||
|
|
624917fac4 | ||
|
|
489894c61d | ||
|
|
ac87979cb7 | ||
|
|
5fd3e85a83 | ||
|
|
0e53ba4311 | ||
|
|
3ce57ef851 | ||
|
|
481570d059 | ||
|
|
04442b7ddb | ||
|
|
e1a71723bc | ||
|
|
f044fb8b47 | ||
|
|
e3350d5bec | ||
|
|
8a69d4354e | ||
|
|
dd6a9c26bd | ||
|
|
49fb4034c6 | ||
|
|
5a466d0ff6 | ||
|
|
bb850bb6c5 | ||
|
|
25cf6823d0 | ||
|
|
7e12744b8b | ||
|
|
8f2432e0f8 | ||
|
|
94451db638 | ||
|
|
f8b8eeec3a | ||
|
|
a4260cc5de | ||
|
|
8c1622798b | ||
|
|
e75bed1be5 | ||
|
|
8c0517de0f | ||
|
|
94e78365a5 | ||
|
|
29c056ca65 | ||
|
|
d8c57f27db | ||
|
|
3cac2bad55 | ||
|
|
e7905fdf49 | ||
|
|
a492bc2242 | ||
|
|
e663364f64 | ||
|
|
ef6466e26f | ||
|
|
7fcbbf1cdc | ||
|
|
ec6ad51ff7 | ||
|
|
1e80c59448 | ||
|
|
e48cb4fd5d | ||
|
|
7c9fbd2625 | ||
|
|
0f504415fb | ||
|
|
4998c324d1 | ||
|
|
fb5fbe76e8 | ||
|
|
223b0bfc88 | ||
|
|
51094a68c8 | ||
|
|
83cb1ec911 | ||
|
|
a77e4bfb7a | ||
|
|
654c177333 | ||
|
|
b92669ba33 | ||
|
|
f2e4f6607d | ||
|
|
5ec909c565 | ||
|
|
a84f31d54a | ||
|
|
e0dd21406d | ||
|
|
72f5f7a0b8 | ||
|
|
e3d20085c5 | ||
|
|
8bf1aef801 | ||
|
|
5f7ade20dc | ||
|
|
70d7e52df0 | ||
|
|
8e6afa5614 | ||
|
|
a1ae3804e3 | ||
|
|
814ce7a43b | ||
|
|
628f75009e | ||
|
|
03fc8c1202 | ||
|
|
8c8e996c87 | ||
|
|
933bb0b1fb | ||
|
|
931fbc3eb5 | ||
|
|
3db5e70a3d | ||
|
|
7b19b70d90 | ||
|
|
99b8103d70 | ||
|
|
7167310ccd | ||
|
|
263667a2d4 | ||
|
|
d5cef291f6 | ||
|
|
c8d166e833 | ||
|
|
6e25782d8b | ||
|
|
c3127f7e84 | ||
|
|
7b90fb018b | ||
|
|
e8bc173cd7 | ||
|
|
4d1cdf5207 | ||
|
|
57a473364e | ||
|
|
40b62e9d38 | ||
|
|
ead5f9926b | ||
|
|
814b6753c2 | ||
|
|
ce505251f8 | ||
|
|
5d2a987aaa | ||
|
|
4d67e08723 | ||
|
|
2e71dd5fe2 | ||
|
|
c3b9643227 | ||
|
|
0aad5dc2b7 | ||
|
|
cec900168f | ||
|
|
f9b1c403d5 | ||
|
|
9024b602f5 | ||
|
|
c139fd9a57 | ||
|
|
e299b68163 | ||
|
|
7777a53a82 | ||
|
|
3e185dbbfe | ||
|
|
e8a32af369 | ||
|
|
7b0ec6687e | ||
|
|
ec1c6c7b92 | ||
|
|
8dfaa86760 | ||
|
|
323aebd1be | ||
|
|
436c038a2f | ||
|
|
ccd50ec6c0 | ||
|
|
a7541c2c0f | ||
|
|
c3a57d756c | ||
|
|
aa300a4c98 | ||
|
|
83ea7352b9 | ||
|
|
9050712cd8 | ||
|
|
8d92fdbb6e | ||
|
|
a2442ec1b9 | ||
|
|
71662c9cd9 | ||
|
|
54ff5dbcc2 | ||
|
|
4ab7bd3b51 | ||
|
|
ef3c61a297 | ||
|
|
abf79bf60c | ||
|
|
5d3cecd926 | ||
|
|
16324e7283 | ||
|
|
9f7e2e1572 | ||
|
|
857ce1d530 | ||
|
|
be0d72775d | ||
|
|
7832a2495b | ||
|
|
0506b7f735 | ||
|
|
4c0b7942f0 | ||
|
|
651c840c4a | ||
|
|
2a351ca415 | ||
|
|
49b7106d71 | ||
|
|
8bf633f539 | ||
|
|
0f8efcb4b0 | ||
|
|
c567641c5c | ||
|
|
bdc3820382 | ||
|
|
33a69a7907 | ||
|
|
a4d0e9bbc3 | ||
|
|
afc753e1d2 | ||
|
|
e641a41224 | ||
|
|
79305c0632 | ||
|
|
ef2ce3f09d | ||
|
|
71c18c04fc | ||
|
|
cf84e57f81 | ||
|
|
9421d44579 | ||
|
|
5cd2ae8cc8 | ||
|
|
22d67b3a59 | ||
|
|
e102cbb8c4 | ||
|
|
d90eeb7ee4 | ||
|
|
1989d53031 | ||
|
|
04ef0907b4 | ||
|
|
517b43561c | ||
|
|
ccb8c7227f | ||
|
|
9fbfeeb04f | ||
|
|
8b753a5a1f | ||
|
|
d25cab0627 | ||
|
|
84da0a8a35 | ||
|
|
6f665cffba | ||
|
|
aea8ac2e97 | ||
|
|
8418fa7b45 | ||
|
|
9cc4d0ee07 | ||
|
|
da60831c44 | ||
|
|
0773174a20 | ||
|
|
70e007d8ca | ||
|
|
fcc4d02c2f | ||
|
|
f4a5f00593 | ||
|
|
1170ed6566 | ||
|
|
883f0d449b | ||
|
|
f4c62e7844 | ||
|
|
f0d212a9d2 | ||
|
|
76a8974034 | ||
|
|
0614e822f4 | ||
|
|
6f682c9a2e | ||
|
|
a9fdbc31c5 | ||
|
|
086fdb5856 | ||
|
|
63c8ef4f17 | ||
|
|
736f6523c7 | ||
|
|
8b0b360d25 | ||
|
|
80b84e2ee6 | ||
|
|
b5b7d86f7b | ||
|
|
f20d704390 | ||
|
|
e4e1e2e944 | ||
|
|
6bc7eeb4cc | ||
|
|
656ed5de7b | ||
|
|
a11d695c78 | ||
|
|
c4f9acd5c5 | ||
|
|
5ef929dc42 | ||
|
|
c8cf27b544 | ||
|
|
bb5ecfc398 | ||
|
|
c91e7c35bb | ||
|
|
532d56df2d | ||
|
|
111ad44029 | ||
|
|
6b02bae957 | ||
|
|
6831743416 | ||
|
|
63e2f42636 | ||
|
|
f6e6805453 | ||
|
|
ad77ad8f2b | ||
|
|
469524e8ae | ||
|
|
f4f55d5dfd | ||
|
|
c248d0f3f4 | ||
|
|
648a04b513 | ||
|
|
bdc86c16ec | ||
|
|
21efd17c17 | ||
|
|
aaa75e7b62 | ||
|
|
6d0cef3152 | ||
|
|
c18472289f | ||
|
|
02b7c70a81 | ||
|
|
4eaa2b93c6 | ||
|
|
d347905373 | ||
|
|
f495213b2c | ||
|
|
9b125913ae | ||
|
|
da81f05804 | ||
|
|
9a371a4d4d | ||
|
|
1e92828f1a | ||
|
|
7e724b3fa3 | ||
|
|
3f5b976a87 | ||
|
|
49f2339cc2 | ||
|
|
29f1699de8 | ||
|
|
c415485801 | ||
|
|
6937673472 | ||
|
|
c4f10fe876 | ||
|
|
55ca652ad8 | ||
|
|
3effd5afd1 | ||
|
|
000c2029de | ||
|
|
ab88e3af06 | ||
|
|
b544a4c954 | ||
|
|
baff5fafec | ||
|
|
1673de73ba | ||
|
|
e68936e36e | ||
|
|
7dbd195e45 | ||
|
|
3dc22f98bf | ||
|
|
805e870c18 | ||
|
|
de2c031797 | ||
|
|
3aa571aa1b | ||
|
|
3e4969efe6 | ||
|
|
446e94df76 | ||
|
|
5b26066a4c | ||
|
|
8a80de5c3f | ||
|
|
52a490c87e | ||
|
|
29490741fd | ||
|
|
f0e416455f | ||
|
|
f7a2c97943 | ||
|
|
993853757b | ||
|
|
a3abfb987d | ||
|
|
2711fa1b1b | ||
|
|
1f7afaba07 | ||
|
|
e02c8bff81 | ||
|
|
22391ba1a5 | ||
|
|
a05781ec19 | ||
|
|
f898ed6a2a | ||
|
|
e6d0a15b54 | ||
|
|
49cff026e2 | ||
|
|
08f0023cfd | ||
|
|
e311466ee6 | ||
|
|
56789e68d7 | ||
|
|
87525bb383 | ||
|
|
bb2880191a | ||
|
|
4f1acf26d6 | ||
|
|
fc2d6b21ac | ||
|
|
b9e84fefbd | ||
|
|
91f5ffb2d9 | ||
|
|
70ff2341cb | ||
|
|
74eed93497 | ||
|
|
d02e26c014 | ||
|
|
523cade7c3 | ||
|
|
e22c183ca9 | ||
|
|
3afd99da30 | ||
|
|
f44979f983 | ||
|
|
095f9cc108 | ||
|
|
1089076fce | ||
|
|
cad3b691a9 | ||
|
|
bac21426d3 | ||
|
|
c4a35314cd | ||
|
|
7090722565 | ||
|
|
6d972c7c18 | ||
|
|
6961a88feb | ||
|
|
c41ec13984 | ||
|
|
ca8e06e562 | ||
|
|
200cd33a8e | ||
|
|
1da7991c65 | ||
|
|
fdfb7e369a | ||
|
|
c2b01cc957 | ||
|
|
5de8e94bb4 | ||
|
|
7a2c15d912 | ||
|
|
70344dd214 | ||
|
|
405372d1a7 | ||
|
|
b8c5174da5 | ||
|
|
1f6f9103d9 | ||
|
|
6431487c7a | ||
|
|
8b2d1189db | ||
|
|
b777f27cb7 | ||
|
|
b31c3b124a | ||
|
|
fa1e965fba | ||
|
|
91dc8b4d58 | ||
|
|
6d16ea8830 | ||
|
|
7db4253264 | ||
|
|
4d2b7d9bf9 | ||
|
|
8f6f4acb88 | ||
|
|
f20d84cb37 | ||
|
|
afbdf1d5d5 | ||
|
|
bc8364d594 | ||
|
|
c8d388f70f | ||
|
|
be13cc3194 | ||
|
|
a46320e744 | ||
|
|
071709d263 | ||
|
|
93a32ae5ff | ||
|
|
eee96f226f | ||
|
|
e19a8b479c | ||
|
|
9ef459112e | ||
|
|
e96474bd5c | ||
|
|
6fed719e09 | ||
|
|
99aac76618 | ||
|
|
599f458201 | ||
|
|
2f8099059c | ||
|
|
e24f177832 | ||
|
|
48cc143e88 | ||
|
|
b09b46c045 | ||
|
|
2c6583cc9c | ||
|
|
e381d1bfb8 | ||
|
|
eac619d54f | ||
|
|
a6ef3bc0ce | ||
|
|
118122c541 | ||
|
|
bfdf33ac09 | ||
|
|
fa3370df5b | ||
|
|
f1e51672c5 | ||
|
|
91f97b2728 | ||
|
|
2c542e03fe | ||
|
|
71a11b4267 | ||
|
|
ea642757db | ||
|
|
fb72b601aa | ||
|
|
27e507e744 | ||
|
|
4db19f816f | ||
|
|
096d5776d1 | ||
|
|
3d799eb4d9 | ||
|
|
e4ac3afa4d | ||
|
|
d38e4eed5b | ||
|
|
97787fac91 | ||
|
|
b494ee2f1c | ||
|
|
31ac80a074 | ||
|
|
c8896450f6 | ||
|
|
c662fa4c63 | ||
|
|
db2ee802ca | ||
|
|
d40e915e2b | ||
|
|
c0616e7efa | ||
|
|
01660597e3 | ||
|
|
c5b549f450 | ||
|
|
802d8457bb | ||
|
|
c3a3df67b0 | ||
|
|
5798aeb3cd | ||
|
|
cc81dd9172 | ||
|
|
44fdadda08 | ||
|
|
66a014150b | ||
|
|
1da596639f | ||
|
|
76614ae9e5 | ||
|
|
6ddddffc0f | ||
|
|
dd95f849d4 | ||
|
|
22c7f8fe9e | ||
|
|
3d47be1f49 | ||
|
|
5e399c46b1 | ||
|
|
38e1db7a37 | ||
|
|
8309f7cdbe | ||
|
|
b8cc62ae95 | ||
|
|
c0eb433fa2 | ||
|
|
7f857d66f6 | ||
|
|
93b14d38f4 | ||
|
|
21825faab0 | ||
|
|
1fafd39298 | ||
|
|
23b750fc4f | ||
|
|
90581c840d | ||
|
|
cac7a6228a | ||
|
|
674fbc3f69 | ||
|
|
9577bf1cc7 | ||
|
|
654ebe93e7 | ||
|
|
ecb1b3c491 | ||
|
|
c3d1711edc | ||
|
|
c12c7f10f0 | ||
|
|
f71820bf4e | ||
|
|
748c53c774 | ||
|
|
b290a71bfb | ||
|
|
3204c51eca | ||
|
|
2c4b8a44dc | ||
|
|
943aa05eaa | ||
|
|
d0fd36e7e1 | ||
|
|
f45ff5fd0a | ||
|
|
c22c7102d5 | ||
|
|
11ecfd1b41 | ||
|
|
798e30e5ac | ||
|
|
15e0702329 | ||
|
|
a2bc22c37d | ||
|
|
8093fcc64c | ||
|
|
800419e7cc | ||
|
|
a241dc6785 | ||
|
|
805bea0d5f | ||
|
|
9d394adf24 | ||
|
|
2074f27aff | ||
|
|
283ad48b86 | ||
|
|
07e10a7943 | ||
|
|
2812a5026c | ||
|
|
3a20461abf | ||
|
|
64ae3d1e21 | ||
|
|
a25d7ea65b | ||
|
|
74ebbdd761 | ||
|
|
a0427b569e | ||
|
|
5346dfdd8b | ||
|
|
3ee4147285 | ||
|
|
c41e486bfc | ||
|
|
eda3ba92fd | ||
|
|
40255290b0 | ||
|
|
af5bc73dc0 | ||
|
|
0247cd4c45 | ||
|
|
916762cc8c | ||
|
|
d6fdf8ca2a | ||
|
|
95708489c9 | ||
|
|
ced0fa4608 | ||
|
|
7e0fbd600f | ||
|
|
f33e4e0323 | ||
|
|
d0fd78497d | ||
|
|
8045019603 | ||
|
|
7d92b9435e | ||
|
|
1e0822703a | ||
|
|
0403ff88ef | ||
|
|
78376d591b | ||
|
|
8e23d0df20 | ||
|
|
9e281d20ab | ||
|
|
644bd4a106 | ||
|
|
7729e66a96 | ||
|
|
d67d6b7948 | ||
|
|
4c4a46bfbe | ||
|
|
4536f9c177 | ||
|
|
977d3bc02e | ||
|
|
eae95dfef5 | ||
|
|
b67d4460ca | ||
|
|
3dea8311b1 | ||
|
|
11f6e98874 | ||
|
|
2609e595f4 | ||
|
|
ac6e41abc8 | ||
|
|
9c17e16d0a | ||
|
|
55e9064307 | ||
|
|
91cabd7d49 | ||
|
|
7456950530 | ||
|
|
8fcdda625d | ||
|
|
40a10ee926 | ||
|
|
c3f7e2645c | ||
|
|
b264af1892 | ||
|
|
43e93e8e22 | ||
|
|
d6c4789688 | ||
|
|
cb31ee6f01 | ||
|
|
f7b694ac56 | ||
|
|
eb809055d4 | ||
|
|
78d9be82b2 | ||
|
|
76a95c0226 | ||
|
|
d3ab8fb04a | ||
|
|
f7a0b63a00 | ||
|
|
a21dd97786 | ||
|
|
04943c0bfa | ||
|
|
203d4d8bfb | ||
|
|
c049a619dc | ||
|
|
cc1b14b607 | ||
|
|
e04a12a8f4 | ||
|
|
a2c82bc583 | ||
|
|
b4dc382f7c | ||
|
|
eca1892e2a | ||
|
|
23a237074e | ||
|
|
219e9eca4f | ||
|
|
413e09fb9e | ||
|
|
3514c37e4c | ||
|
|
95260e303c | ||
|
|
0cef34bdfa | ||
|
|
9838979bbd | ||
|
|
c8910b8e14 | ||
|
|
207fa1d019 | ||
|
|
be0bb591e7 | ||
|
|
bfacdb9c3b | ||
|
|
ae4077ed6c | ||
|
|
6eb3c90e18 | ||
|
|
8c2a53a504 | ||
|
|
74db1e0308 | ||
|
|
b9dfdcef3d | ||
|
|
9d4afeac31 | ||
|
|
14ae2f169a | ||
|
|
55df19142f | ||
|
|
40fd545b2c | ||
|
|
95fb07343e | ||
|
|
4d87906559 | ||
|
|
6b30dced43 | ||
|
|
293a03b7c8 | ||
|
|
c010549f17 | ||
|
|
cc0be22026 | ||
|
|
e5ba26febe | ||
|
|
36f9680eec | ||
|
|
f4f5be5b08 | ||
|
|
d89b056886 | ||
|
|
65424c7db9 | ||
|
|
32a8a847fc | ||
|
|
88fb3dbf60 | ||
|
|
f6bee3aa58 | ||
|
|
5f19f37dcb | ||
|
|
dd36d8ce9e | ||
|
|
865e4b5349 | ||
|
|
e70564752b | ||
|
|
6e0d2f9437 | ||
|
|
291f936097 | ||
|
|
0b2ce48586 | ||
|
|
da87fd9e20 | ||
|
|
d4da4d2575 | ||
|
|
bad20ff483 | ||
|
|
21ad51ffbf | ||
|
|
697c6d5fbe | ||
|
|
293c659053 | ||
|
|
a12507abbd | ||
|
|
4e675b84fb | ||
|
|
c1022feab8 | ||
|
|
ddcfcf21fe | ||
|
|
86a58c3d80 | ||
|
|
abf9a9048d | ||
|
|
b1030a527a | ||
|
|
8d07ba6332 | ||
|
|
4ce37f84e4 | ||
|
|
061d8a3a5f | ||
|
|
374cd5dbb8 | ||
|
|
5ad53c2b9c | ||
|
|
a2ec1a063d | ||
|
|
e431dbe2df | ||
|
|
7218463f9e | ||
|
|
aeb09a95b0 | ||
|
|
0c8f292e12 | ||
|
|
f001ac6903 | ||
|
|
db8e506de0 | ||
|
|
099f859dd4 | ||
|
|
b7684c1c2b | ||
|
|
058c167f79 | ||
|
|
49446d4872 | ||
|
|
ced560e1e1 | ||
|
|
339102c3cd | ||
|
|
6331350239 | ||
|
|
34e06fcbf8 | ||
|
|
70aac312ff | ||
|
|
5e00704152 | ||
|
|
1a9edb6907 | ||
|
|
0c18c3a6dd | ||
|
|
847bb51ce4 | ||
|
|
fa60a5dc63 | ||
|
|
aaed3f9839 | ||
|
|
21b956b983 | ||
|
|
792e940279 | ||
|
|
c2477b26c0 | ||
|
|
4b27de809b | ||
|
|
572932d8e8 | ||
|
|
270dd778d9 | ||
|
|
dd04287b0a | ||
|
|
36ac6d005a | ||
|
|
701daedf49 | ||
|
|
238f05f453 | ||
|
|
dd082bd212 | ||
|
|
cfd2f27b0b | ||
|
|
a2160d135e | ||
|
|
16d7836369 | ||
|
|
f3de4dcc5f | ||
|
|
e34523028f | ||
|
|
efe2fbacd6 | ||
|
|
2fa1df29be | ||
|
|
f72cd13fba | ||
|
|
5b552dffbf | ||
|
|
a0ae2d13dc | ||
|
|
f7262a0a3a | ||
|
|
9736f121eb | ||
|
|
7c8fb7eacc | ||
|
|
b45eea5908 | ||
|
|
6babf4ee6c | ||
|
|
576526d4ee | ||
|
|
c03e31b7be | ||
|
|
a1aa925019 | ||
|
|
a5a234ed97 | ||
|
|
5b5dbcd78b | ||
|
|
bd1c6361d3 | ||
|
|
1fc1febf03 | ||
|
|
55cc35efa9 | ||
|
|
5ba8fdc5e7 | ||
|
|
6ea295e227 | ||
|
|
5010c76ef7 | ||
|
|
79c7f0c29f | ||
|
|
2b3e643786 | ||
|
|
90cdff327c | ||
|
|
55c116e727 | ||
|
|
3dd83aa6b7 | ||
|
|
a74aa12641 | ||
|
|
151e8c69f9 | ||
|
|
d8bfa77705 | ||
|
|
6bd286e8d5 | ||
|
|
905532b681 | ||
|
|
04d5c1ab01 | ||
|
|
28be141dc7 | ||
|
|
652b786baf | ||
|
|
ba6c671051 | ||
|
|
ca25d0433f | ||
|
|
5338106dfa | ||
|
|
854d613a81 | ||
|
|
b6b76be4f6 | ||
|
|
03d94fcfa0 | ||
|
|
b2c5f0d455 | ||
|
|
54f60dd38c | ||
|
|
42f181aca2 | ||
|
|
9c3a27894f | ||
|
|
f7cd348912 | ||
|
|
aeaeb75d3b | ||
|
|
96542b532e | ||
|
|
139295fe0d | ||
|
|
13217b2ce2 | ||
|
|
5cc8b56a7c | ||
|
|
e23e01c95e | ||
|
|
bca8ba12c7 | ||
|
|
3c44bdbe1c | ||
|
|
db93ed025b | ||
|
|
4209e108d0 | ||
|
|
14cbf011af | ||
|
|
03a41ec199 | ||
|
|
125fe2a026 | ||
|
|
ac4adac29e | ||
|
|
ac449d078e | ||
|
|
79be4530d4 | ||
|
|
85ce52d70c | ||
|
|
7ab56b9076 | ||
|
|
dedf976375 | ||
|
|
89f438208a | ||
|
|
ffbc5080ae | ||
|
|
4167f13bac | ||
|
|
6ba0baabb0 | ||
|
|
081003df47 | ||
|
|
559194ffb2 | ||
|
|
97a26d4a46 | ||
|
|
503c6c9b7e | ||
|
|
9a1e10deff | ||
|
|
054f927c05 | ||
|
|
22210747d0 | ||
|
|
53b2deb72c | ||
|
|
6fc158e7d6 | ||
|
|
a23a65c731 | ||
|
|
7dc7105ee2 | ||
|
|
bac70108b2 | ||
|
|
297404b21e | ||
|
|
33a7f8b558 | ||
|
|
4a670b7df7 | ||
|
|
79e4af315e | ||
|
|
c6e31b2fdc | ||
|
|
91dc44df53 | ||
|
|
7e57f8f157 | ||
|
|
15f6b7c6d3 | ||
|
|
b213ba541d | ||
|
|
7c6ed9944e | ||
|
|
a5a825e439 | ||
|
|
a4ab547f77 | ||
|
|
76ed763abe | ||
|
|
b9e3125610 | ||
|
|
8d9d5b7b6f | ||
|
|
187601da1e | ||
|
|
cc3a0fc367 | ||
|
|
44cc4165d1 | ||
|
|
f98b43514e | ||
|
|
3c9b1a14e9 | ||
|
|
827e8eddf8 | ||
|
|
7bc27d6167 | ||
|
|
ba06edd63a | ||
|
|
cacf553a5b | ||
|
|
d89091a8ea | ||
|
|
01a56e1155 | ||
|
|
a64d7c42b1 | ||
|
|
36b6cc58bf | ||
|
|
5ac8a257e7 | ||
|
|
74119d0372 | ||
|
|
4e162c73e5 | ||
|
|
5ff753a492 | ||
|
|
89400630c0 | ||
|
|
3899c0cfe3 | ||
|
|
a086f1989f | ||
|
|
1171b04e93 | ||
|
|
c55d81825a | ||
|
|
2dcd026e9f | ||
|
|
cdf8609d24 | ||
|
|
36580c5f7f | ||
|
|
1cff2521f4 | ||
|
|
db4998a56b | ||
|
|
acbd506568 | ||
|
|
0cf8e3be73 | ||
|
|
2473334dfc | ||
|
|
1ff72d1d37 | ||
|
|
241fad5524 | ||
|
|
1b48cea50a | ||
|
|
88bf345b91 | ||
|
|
ab4ff3d1a3 | ||
|
|
3502e0d643 | ||
|
|
995894d3aa | ||
|
|
4da8714124 | ||
|
|
6b247ae880 | ||
|
|
176941ea3b | ||
|
|
5176b56d3b | ||
|
|
8abf18ab25 | ||
|
|
395edbd9f4 | ||
|
|
2386eb8fc2 | ||
|
|
68208f82a0 | ||
|
|
ca916b7ce5 | ||
|
|
01e02934da | ||
|
|
c81a79f7b9 | ||
|
|
1133648bf6 | ||
|
|
e05bc541d7 | ||
|
|
d689d20482 | ||
|
|
39dd99b272 | ||
|
|
cda21acb43 | ||
|
|
9bd7d09f20 | ||
|
|
b22994c2d2 | ||
|
|
e027286b6d | ||
|
|
d6e16995e0 | ||
|
|
782bff3a51 | ||
|
|
de26dc0597 | ||
|
|
233b24ab0f | ||
|
|
2f9e5b1219 | ||
|
|
dd36b8b150 | ||
|
|
f81ac31fe1 | ||
|
|
24b63bc5bd | ||
|
|
1817a972c6 | ||
|
|
74a253f521 | ||
|
|
41762a1c57 | ||
|
|
a786fa4b75 | ||
|
|
e4c7602c0c | ||
|
|
e0d2e34980 | ||
|
|
9ef8e1be3f | ||
|
|
aae9b64833 | ||
|
|
4bab4299f2 | ||
|
|
954e55f4b4 | ||
|
|
2361e3c28c | ||
|
|
8224c2fc16 | ||
|
|
8aac86f0a9 | ||
|
|
6384e9310b | ||
|
|
7a9205dfba | ||
|
|
94b47a56f4 | ||
|
|
709b5be634 | ||
|
|
f970b2c168 | ||
|
|
973acb37ed | ||
|
|
1c9020a565 | ||
|
|
c5f1d0042c | ||
|
|
fa706e8b1d | ||
|
|
12c170f227 | ||
|
|
db27dfe227 | ||
|
|
2db4673392 | ||
|
|
38619db629 | ||
|
|
930fd436ea | ||
|
|
98b8ff2fc8 | ||
|
|
d0662683f9 | ||
|
|
957f2574a9 | ||
|
|
109b362ebd | ||
|
|
ff3fdfa738 | ||
|
|
e2636ed54a | ||
|
|
dbe2f17e1a | ||
|
|
4dc535673f | ||
|
|
f414b6408e | ||
|
|
3aa2e6a04d | ||
|
|
1963ff273f | ||
|
|
bb737a71d5 | ||
|
|
a582a46ce9 | ||
|
|
abf80a3266 | ||
|
|
d768f5c66d | ||
|
|
b25e843351 | ||
|
|
419a3e518e | ||
|
|
d1b867a7c0 | ||
|
|
c34d70b3cb | ||
|
|
a33df9312f | ||
|
|
ebf8db0b37 | ||
|
|
e539ae3b69 | ||
|
|
4c5e8850aa | ||
|
|
94c0af3037 | ||
|
|
165182c68f | ||
|
|
65b9542599 | ||
|
|
d01d1f8830 | ||
|
|
ad3e9f3d42 | ||
|
|
4589974095 | ||
|
|
ed4553ddf8 | ||
|
|
ff97ae73f1 | ||
|
|
f96b4d2781 | ||
|
|
ce32cfffdb | ||
|
|
f66df8531e | ||
|
|
dfe1c23e76 | ||
|
|
07fd81919f | ||
|
|
210042bb81 | ||
|
|
12dc7427e9 | ||
|
|
b476085110 | ||
|
|
776cdaf63c | ||
|
|
69b6855745 | ||
|
|
3590babd8b | ||
|
|
c29d391c1d | ||
|
|
50e44dbb2a | ||
|
|
34277a3940 | ||
|
|
f1a00d58ca | ||
|
|
d1a5f17ae8 | ||
|
|
4dbc54fa15 | ||
|
|
1d4ff796d7 | ||
|
|
44cb54a9ea | ||
|
|
6409f49609 | ||
|
|
9ee0ea88b5 | ||
|
|
a3819d8673 | ||
|
|
2d7dd71a3d | ||
|
|
0e8195ae61 | ||
|
|
3e92d07618 | ||
|
|
e59597280d | ||
|
|
f2e3d69d8a | ||
|
|
9d2cb75c84 | ||
|
|
f971505c4a | ||
|
|
2133c1d6af | ||
|
|
0bf06ddfd3 | ||
|
|
024a50d642 | ||
|
|
e4eebd64d1 | ||
|
|
c9055989e9 | ||
|
|
4f1ed197ce | ||
|
|
3e710aa2a1 | ||
|
|
b6226a45bb | ||
|
|
3001ba9266 | ||
|
|
b0a401a1ed | ||
|
|
6b4dc37428 | ||
|
|
8528c9b262 | ||
|
|
7222a5c2f4 | ||
|
|
59050001ef | ||
|
|
2ba8f18724 | ||
|
|
fb22e01b89 | ||
|
|
76a81d5360 | ||
|
|
3314b05648 | ||
|
|
45b89218de | ||
|
|
beb7bda243 | ||
|
|
bef2896f50 | ||
|
|
9fea949b25 | ||
|
|
be258e5b05 | ||
|
|
008178d737 | ||
|
|
527d5e1dbc | ||
|
|
9b47e2d6f9 | ||
|
|
8781b1e976 | ||
|
|
38c653d8d8 | ||
|
|
74e48bb137 | ||
|
|
c3aaa1f735 | ||
|
|
bead2aa228 | ||
|
|
dc52ab8aa9 | ||
|
|
20b71f206b | ||
|
|
73c87d5959 | ||
|
|
c6601aaeed | ||
|
|
6e14fce1fe | ||
|
|
be5a62f1b8 | ||
|
|
1fa8cefaea | ||
|
|
d7c251ac83 | ||
|
|
d03229a183 | ||
|
|
243482e829 | ||
|
|
79d10be8a0 | ||
|
|
dca5c058e0 | ||
|
|
9163ce71fd | ||
|
|
2ec5374765 | ||
|
|
d6a4b35cd3 | ||
|
|
8205d2552c | ||
|
|
9a99caeb9d | ||
|
|
1e09bd0e76 | ||
|
|
cae12eb187 | ||
|
|
8bb36e0eb6 | ||
|
|
d183204caa | ||
|
|
4a22ae6b61 | ||
|
|
a52f54d988 | ||
|
|
618c94edb8 | ||
|
|
eaf4e9174f | ||
|
|
4af2c7f3d7 | ||
|
|
361f599df0 | ||
|
|
ffe4ea5e4c | ||
|
|
9461e3e01a | ||
|
|
7c85c6f742 | ||
|
|
b5df6faadf | ||
|
|
7cefe2d825 | ||
|
|
350633b69b | ||
|
|
1cd6a71ce0 | ||
|
|
3a08b002a0 | ||
|
|
665001732b | ||
|
|
cca49da730 | ||
|
|
f6d370ad29 | ||
|
|
c9131b333b | ||
|
|
e44161bf42 | ||
|
|
a26189fb25 | ||
|
|
89dd8a1db6 | ||
|
|
650e0b4ad4 | ||
|
|
c60f0517fb | ||
|
|
0f8dc91a8b | ||
|
|
b58feb5d8e | ||
|
|
71c8043699 | ||
|
|
40264bc9cb | ||
|
|
a7772316f9 | ||
|
|
34209021c8 | ||
|
|
3e9e8d442a | ||
|
|
d2bf90c6c7 | ||
|
|
1e58c1ad2b | ||
|
|
8cea022ec5 | ||
|
|
f32f8aa08e | ||
|
|
3ea8781381 | ||
|
|
ab83dacb76 | ||
|
|
4cbf46fd4d | ||
|
|
0a7d6e4577 | ||
|
|
df4c1f0401 | ||
|
|
9a86a67984 | ||
|
|
a0cbe9c3e2 | ||
|
|
a83e5a9b65 | ||
|
|
de33911460 | ||
|
|
0be56e5b25 | ||
|
|
abcbb34b1c | ||
|
|
6a13dd04a3 | ||
|
|
f2e29f3f2e | ||
|
|
68361cddd2 | ||
|
|
6404332adc | ||
|
|
e060b6fea2 | ||
|
|
e8aae27ee9 | ||
|
|
2f732e5493 | ||
|
|
65f20ff2c1 | ||
|
|
8f72e8c3e6 | ||
|
|
3b8972ce1f | ||
|
|
fc5d3e4e9c | ||
|
|
29fbf69945 | ||
|
|
583440b82b | ||
|
|
720de9d73f | ||
|
|
78332d882b | ||
|
|
2dfbc840b3 | ||
|
|
0b4bf15163 | ||
|
|
2989249e4b | ||
|
|
9cef559a05 | ||
|
|
47fe16c92a | ||
|
|
36b5c821ff | ||
|
|
82ec440b45 | ||
|
|
88f4a45cae | ||
|
|
7fb4f72b84 | ||
|
|
d4fc322101 | ||
|
|
8fa3da9ca5 | ||
|
|
68ef5aa3ae | ||
|
|
28bd917c9f | ||
|
|
0eb1b94300 | ||
|
|
15e6cf850b | ||
|
|
ee91c86a29 | ||
|
|
48c08f4aad | ||
|
|
fceabb8e67 | ||
|
|
fcfafb05f1 | ||
|
|
f1e8344beb | ||
|
|
f687b2b6f4 | ||
|
|
8ee7a48151 | ||
|
|
89e8f385b4 |
31
.github/ISSUE_TEMPLATE.md
vendored
31
.github/ISSUE_TEMPLATE.md
vendored
@@ -1,31 +0,0 @@
|
||||
### 前置确认
|
||||
|
||||
1. 网络能够访问openai接口
|
||||
2. python 已安装:版本在 3.7 ~ 3.10 之间
|
||||
3. `git pull` 拉取最新代码
|
||||
4. 执行`pip3 install -r requirements.txt`,检查依赖是否满足
|
||||
5. 拓展功能请执行`pip3 install -r requirements-optional.txt`,检查依赖是否满足
|
||||
6. 在已有 issue 中未搜索到类似问题
|
||||
7. [FAQS](https://github.com/zhayujie/chatgpt-on-wechat/wiki/FAQs) 中无类似问题
|
||||
|
||||
|
||||
### 问题描述
|
||||
|
||||
> 简要说明、截图、复现步骤等,也可以是需求或想法
|
||||
|
||||
|
||||
|
||||
|
||||
### 终端日志 (如有报错)
|
||||
|
||||
```
|
||||
[在此处粘贴终端日志, 可在主目录下`run.log`文件中找到]
|
||||
```
|
||||
|
||||
|
||||
|
||||
### 环境
|
||||
|
||||
- 操作系统类型 (Mac/Windows/Linux):
|
||||
- Python版本 ( 执行 `python3 -V` ):
|
||||
- pip版本 ( 依赖问题此项必填,执行 `pip3 -V`):
|
||||
133
.github/ISSUE_TEMPLATE/1.bug.yml
vendored
Normal file
133
.github/ISSUE_TEMPLATE/1.bug.yml
vendored
Normal file
@@ -0,0 +1,133 @@
|
||||
name: Bug report 🐛
|
||||
description: 项目运行中遇到的Bug或问题。
|
||||
labels: ['status: needs check']
|
||||
body:
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
### ⚠️ 前置确认
|
||||
1. 网络能够访问openai接口
|
||||
2. python 已安装:版本在 3.7 ~ 3.10 之间
|
||||
3. `git pull` 拉取最新代码
|
||||
4. 执行`pip3 install -r requirements.txt`,检查依赖是否满足
|
||||
5. 拓展功能请执行`pip3 install -r requirements-optional.txt`,检查依赖是否满足
|
||||
6. [FAQS](https://github.com/zhayujie/chatgpt-on-wechat/wiki/FAQs) 中无类似问题
|
||||
- type: checkboxes
|
||||
attributes:
|
||||
label: 前置确认
|
||||
options:
|
||||
- label: 我确认我运行的是最新版本的代码,并且安装了所需的依赖,在[FAQS](https://github.com/zhayujie/chatgpt-on-wechat/wiki/FAQs)中也未找到类似问题。
|
||||
required: true
|
||||
- type: checkboxes
|
||||
attributes:
|
||||
label: ⚠️ 搜索issues中是否已存在类似问题
|
||||
description: >
|
||||
请在 [历史issue](https://github.com/zhayujie/chatgpt-on-wechat/issues) 中清空输入框,搜索你的问题
|
||||
或相关日志的关键词来查找是否存在类似问题。
|
||||
options:
|
||||
- label: 我已经搜索过issues和disscussions,没有跟我遇到的问题相关的issue
|
||||
required: true
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
请在上方的`title`中填写你对你所遇到问题的简略总结,这将帮助其他人更好的找到相似问题,谢谢❤️。
|
||||
- type: dropdown
|
||||
attributes:
|
||||
label: 操作系统类型?
|
||||
description: >
|
||||
请选择你运行程序的操作系统类型。
|
||||
options:
|
||||
- Windows
|
||||
- Linux
|
||||
- MacOS
|
||||
- Docker
|
||||
- Railway
|
||||
- Windows Subsystem for Linux (WSL)
|
||||
- Other (请在问题中说明)
|
||||
validations:
|
||||
required: true
|
||||
- type: dropdown
|
||||
attributes:
|
||||
label: 运行的python版本是?
|
||||
description: |
|
||||
请选择你运行程序的`python`版本。
|
||||
注意:在`python 3.7`中,有部分可选依赖无法安装。
|
||||
经过长时间的观察,我们认为`python 3.8`是兼容性最好的版本。
|
||||
`python 3.7`~`python 3.10`以外版本的issue,将视情况直接关闭。
|
||||
options:
|
||||
- python 3.7
|
||||
- python 3.8
|
||||
- python 3.9
|
||||
- python 3.10
|
||||
- other
|
||||
validations:
|
||||
required: true
|
||||
- type: dropdown
|
||||
attributes:
|
||||
label: 使用的chatgpt-on-wechat版本是?
|
||||
description: |
|
||||
请确保你使用的是 [releases](https://github.com/zhayujie/chatgpt-on-wechat/releases) 中的最新版本。
|
||||
如果你使用git, 请使用`git branch`命令来查看分支。
|
||||
options:
|
||||
- Latest Release
|
||||
- Master (branch)
|
||||
validations:
|
||||
required: true
|
||||
- type: dropdown
|
||||
attributes:
|
||||
label: 运行的`channel`类型是?
|
||||
description: |
|
||||
请确保你正确配置了该`channel`所需的配置项,所有可选的配置项都写在了[该文件中](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/config.py),请将所需配置项填写在根目录下的`config.json`文件中。
|
||||
options:
|
||||
- wx(个人微信, itchat)
|
||||
- wxy(个人微信, wechaty)
|
||||
- wechatmp(公众号, 订阅号)
|
||||
- wechatmp_service(公众号, 服务号)
|
||||
- terminal
|
||||
- other
|
||||
validations:
|
||||
required: true
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: 复现步骤 🕹
|
||||
description: |
|
||||
**⚠️ 不能复现将会关闭issue.**
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: 问题描述 😯
|
||||
description: 详细描述出现的问题,或提供有关截图。
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: 终端日志 📒
|
||||
description: |
|
||||
在此处粘贴终端日志,可在主目录下`run.log`文件中找到,这会帮助我们更好的分析问题,注意隐去你的API key。
|
||||
如果在配置文件中加入`"debug": true`,打印出的日志会更有帮助。
|
||||
|
||||
<details>
|
||||
<summary><i>示例</i></summary>
|
||||
```log
|
||||
[DEBUG][2023-04-16 00:23:22][plugin_manager.py:157] - Plugin SUMMARY triggered by event Event.ON_HANDLE_CONTEXT
|
||||
[DEBUG][2023-04-16 00:23:22][main.py:221] - [Summary] on_handle_context. content: $总结前100条消息
|
||||
[DEBUG][2023-04-16 00:23:24][main.py:240] - [Summary] limit: 100, duration: -1 seconds
|
||||
[ERROR][2023-04-16 00:23:24][chat_channel.py:244] - Worker return exception: name 'start_date' is not defined
|
||||
Traceback (most recent call last):
|
||||
File "C:\ProgramData\Anaconda3\lib\concurrent\futures\thread.py", line 57, in run
|
||||
result = self.fn(*self.args, **self.kwargs)
|
||||
File "D:\project\chatgpt-on-wechat\channel\chat_channel.py", line 132, in _handle
|
||||
reply = self._generate_reply(context)
|
||||
File "D:\project\chatgpt-on-wechat\channel\chat_channel.py", line 142, in _generate_reply
|
||||
e_context = PluginManager().emit_event(EventContext(Event.ON_HANDLE_CONTEXT, {
|
||||
File "D:\project\chatgpt-on-wechat\plugins\plugin_manager.py", line 159, in emit_event
|
||||
instance.handlers[e_context.event](e_context, *args, **kwargs)
|
||||
File "D:\project\chatgpt-on-wechat\plugins\summary\main.py", line 255, in on_handle_context
|
||||
records = self._get_records(session_id, start_time, limit)
|
||||
File "D:\project\chatgpt-on-wechat\plugins\summary\main.py", line 96, in _get_records
|
||||
c.execute("SELECT * FROM chat_records WHERE sessionid=? and timestamp>? ORDER BY timestamp DESC LIMIT ?", (session_id, start_date, limit))
|
||||
NameError: name 'start_date' is not defined
|
||||
[INFO][2023-04-16 00:23:36][app.py:14] - signal 2 received, exiting...
|
||||
```
|
||||
</details>
|
||||
value: |
|
||||
```log
|
||||
<此处粘贴终端日志>
|
||||
```
|
||||
28
.github/ISSUE_TEMPLATE/2.feature.yml
vendored
Normal file
28
.github/ISSUE_TEMPLATE/2.feature.yml
vendored
Normal file
@@ -0,0 +1,28 @@
|
||||
name: Feature request 🚀
|
||||
description: 提出你对项目的新想法或建议。
|
||||
labels: ['status: needs check']
|
||||
body:
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
请在上方的`title`中填写简略总结,谢谢❤️。
|
||||
- type: checkboxes
|
||||
attributes:
|
||||
label: ⚠️ 搜索是否存在类似issue
|
||||
description: >
|
||||
请在 [历史issue](https://github.com/zhayujie/chatgpt-on-wechat/issues) 中清空输入框,搜索关键词查找是否存在相似issue。
|
||||
options:
|
||||
- label: 我已经搜索过issues和disscussions,没有发现相似issue
|
||||
required: true
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: 总结
|
||||
description: 描述feature的功能。
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: 举例
|
||||
description: 提供聊天示例,草图或相关网址。
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: 动机
|
||||
description: 描述你提出该feature的动机,比如没有这项feature对你的使用造成了怎样的影响。 请提供更详细的场景描述,这可能会帮助我们发现并提出更好的解决方案。
|
||||
72
.github/workflows/deploy-image-arm.yml
vendored
Normal file
72
.github/workflows/deploy-image-arm.yml
vendored
Normal file
@@ -0,0 +1,72 @@
|
||||
# This workflow uses actions that are not certified by GitHub.
|
||||
# They are provided by a third-party and are governed by
|
||||
# separate terms of service, privacy policy, and support
|
||||
# documentation.
|
||||
|
||||
# GitHub recommends pinning actions to a commit SHA.
|
||||
# To get a newer version, you will need to update the SHA.
|
||||
# You can also reference a tag or branch, but the action may change without warning.
|
||||
|
||||
name: Create and publish a Docker image
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: ['master']
|
||||
create:
|
||||
env:
|
||||
REGISTRY: ghcr.io
|
||||
IMAGE_NAME: ${{ github.repository }}
|
||||
|
||||
jobs:
|
||||
build-and-push-image:
|
||||
if: github.repository == 'zhayujie/chatgpt-on-wechat'
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: read
|
||||
packages: write
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@v1
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
id: buildx
|
||||
uses: docker/setup-buildx-action@v1
|
||||
|
||||
- name: Available platforms
|
||||
run: echo ${{ steps.buildx.outputs.platforms }}
|
||||
|
||||
- name: Log in to the Container registry
|
||||
uses: docker/login-action@v2
|
||||
with:
|
||||
registry: ${{ env.REGISTRY }}
|
||||
username: ${{ github.actor }}
|
||||
password: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Extract metadata (tags, labels) for Docker
|
||||
id: meta
|
||||
uses: docker/metadata-action@v4
|
||||
with:
|
||||
images: |
|
||||
${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
|
||||
|
||||
- name: Build and push Docker image
|
||||
uses: docker/build-push-action@v3
|
||||
with:
|
||||
context: .
|
||||
push: true
|
||||
file: ./docker/Dockerfile.latest
|
||||
platforms: linux/arm64
|
||||
tags: ${{ steps.meta.outputs.tags }}-arm64
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
|
||||
- uses: actions/delete-package-versions@v4
|
||||
with:
|
||||
package-name: 'chatgpt-on-wechat'
|
||||
package-type: 'container'
|
||||
min-versions-to-keep: 10
|
||||
delete-only-untagged-versions: 'true'
|
||||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
15
.github/workflows/deploy-image.yml
vendored
15
.github/workflows/deploy-image.yml
vendored
@@ -19,6 +19,7 @@ env:
|
||||
|
||||
jobs:
|
||||
build-and-push-image:
|
||||
if: github.repository == 'zhayujie/chatgpt-on-wechat'
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: read
|
||||
@@ -28,6 +29,12 @@ jobs:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v2
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
|
||||
- name: Log in to the Container registry
|
||||
uses: docker/login-action@v2
|
||||
with:
|
||||
@@ -39,7 +46,9 @@ jobs:
|
||||
id: meta
|
||||
uses: docker/metadata-action@v4
|
||||
with:
|
||||
images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
|
||||
images: |
|
||||
${{ env.IMAGE_NAME }}
|
||||
${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
|
||||
|
||||
- name: Build and push Docker image
|
||||
uses: docker/build-push-action@v3
|
||||
@@ -49,9 +58,9 @@ jobs:
|
||||
file: ./docker/Dockerfile.latest
|
||||
tags: ${{ steps.meta.outputs.tags }}
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
|
||||
|
||||
- uses: actions/delete-package-versions@v4
|
||||
with:
|
||||
with:
|
||||
package-name: 'chatgpt-on-wechat'
|
||||
package-type: 'container'
|
||||
min-versions-to-keep: 10
|
||||
|
||||
18
.gitignore
vendored
18
.gitignore
vendored
@@ -1,6 +1,8 @@
|
||||
.DS_Store
|
||||
.idea
|
||||
.vscode
|
||||
.venv
|
||||
.vs
|
||||
.wechaty/
|
||||
__pycache__/
|
||||
venv*
|
||||
@@ -12,7 +14,11 @@ tmp
|
||||
plugins.json
|
||||
itchat.pkl
|
||||
*.log
|
||||
logs/
|
||||
workspace
|
||||
config.yaml
|
||||
user_datas.pkl
|
||||
chatgpt_tool_hub/
|
||||
plugins/**/
|
||||
!plugins/bdunit
|
||||
!plugins/dungeon
|
||||
@@ -20,5 +26,15 @@ plugins/**/
|
||||
!plugins/godcmd
|
||||
!plugins/tool
|
||||
!plugins/banwords
|
||||
!plugins/banwords/**/
|
||||
plugins/banwords/__pycache__
|
||||
plugins/banwords/lib/__pycache__
|
||||
!plugins/hello
|
||||
!plugins/role
|
||||
!plugins/role
|
||||
!plugins/keyword
|
||||
!plugins/linkai
|
||||
!plugins/agent
|
||||
client_config.json
|
||||
ref/
|
||||
.cursor/
|
||||
local/
|
||||
|
||||
826
README.md
826
README.md
@@ -1,75 +1,105 @@
|
||||
<p align="center"><img src= "https://github.com/user-attachments/assets/eca9a9ec-8534-4615-9e0f-96c5ac1d10a3" alt="Chatgpt-on-Wechat" width="550" /></p>
|
||||
|
||||
<p align="center">
|
||||
<a href="https://github.com/zhayujie/chatgpt-on-wechat/releases/latest"><img src="https://img.shields.io/github/v/release/zhayujie/chatgpt-on-wechat" alt="Latest release"></a>
|
||||
<a href="https://github.com/zhayujie/chatgpt-on-wechat/blob/master/LICENSE"><img src="https://img.shields.io/github/license/zhayujie/chatgpt-on-wechat" alt="License: MIT"></a>
|
||||
<a href="https://github.com/zhayujie/chatgpt-on-wechat"><img src="https://img.shields.io/github/stars/zhayujie/chatgpt-on-wechat?style=flat-square" alt="Stars"></a> <br/>
|
||||
</p>
|
||||
|
||||
**CowAgent** 是基于大模型的超级AI助理,能够主动思考和任务规划、操作计算机和外部资源、创造和执行Skills、拥有长期记忆并不断成长。CowAgent 支持灵活切换多种模型,能处理文本、语音、图片、文件等多模态消息,可接入网页、飞书、钉钉、企业微信应用、微信公众号中使用,7*24小时运行于你的个人电脑或服务器中。
|
||||
|
||||
📖能力介绍:[CowAgent 2.0](/docs/agent.md)
|
||||
|
||||
# 简介
|
||||
|
||||
> ChatGPT近期以强大的对话和信息整合能力风靡全网,可以写代码、改论文、讲故事,几乎无所不能,这让人不禁有个大胆的想法,能否用他的对话模型把我们的微信打造成一个智能机器人,可以在与好友对话中给出意想不到的回应,而且再也不用担心女朋友影响我们 ~~打游戏~~ 工作了。
|
||||
> 该项目既是一个可以开箱即用的超级AI助理,也是一个支持高扩展的Agent框架,可以通过为项目扩展大模型接口、接入渠道、内置工具、Skills系统来灵活实现各种定制需求。核心能力如下:
|
||||
|
||||
- ✅ **复杂任务规划**:能够理解复杂任务并自主规划执行,持续思考和调用工具直到完成目标,支持通过工具操作访问文件、终端、浏览器、定时任务等系统资源
|
||||
- ✅ **长期记忆:** 自动将对话记忆持久化至本地文件和数据库中,包括全局记忆和天级记忆,支持关键词及向量检索
|
||||
- ✅ **技能系统:** 实现了Skills创建和运行的引擎,内置多种技能,并支持通过自然语言对话完成自定义Skills开发
|
||||
- ✅ **多模态消息:** 支持对文本、图片、语音、文件等多类型消息进行解析、处理、生成、发送等操作
|
||||
- ✅ **多模型接入:** 支持OpenAI, Claude, Gemini, DeepSeek, MiniMax、GLM、Qwen、Kimi、Doubao等国内外主流模型厂商
|
||||
- ✅ **多端部署:** 支持运行在本地计算机或服务器,可集成到网页、飞书、钉钉、微信公众号、企业微信应用中使用
|
||||
- ✅ **知识库:** 集成企业知识库能力,让Agent成为专属数字员工,基于[LinkAI](https://link-ai.tech)平台实现
|
||||
|
||||
基于ChatGPT的微信聊天机器人,通过 [ChatGPT](https://github.com/openai/openai-python) 接口生成对话内容,使用 [itchat](https://github.com/littlecodersh/ItChat) 实现微信消息的接收和自动回复。已实现的特性如下:
|
||||
## 声明
|
||||
|
||||
- [x] **文本对话:** 接收私聊及群组中的微信消息,使用ChatGPT生成回复内容,完成自动回复
|
||||
- [x] **规则定制化:** 支持私聊中按指定规则触发自动回复,支持对群组设置自动回复白名单
|
||||
- [x] **多账号:** 支持多微信账号同时运行
|
||||
- [x] **图片生成:** 支持根据描述生成图片,并自动发送至个人聊天或群聊
|
||||
- [x] **上下文记忆**:支持多轮对话记忆,且为每个好友维护独立的上下会话
|
||||
- [x] **语音识别:** 支持接收和处理语音消息,通过文字或语音回复
|
||||
- [x] **插件化:** 支持个性化功能插件,提供角色扮演、文字冒险游戏等预设插件
|
||||
1. 本项目遵循 [MIT开源协议](/LICENSE),主要用于技术研究和学习,使用本项目时需遵守所在地法律法规、相关政策以及企业章程,禁止用于任何违法或侵犯他人权益的行为。任何个人、团队和企业,无论以何种方式使用该项目、对何对象提供服务,所产生的一切后果,本项目均不承担任何责任。
|
||||
2. 成本与安全:Agent模式下Token使用量高于普通对话模式,请根据效果及成本综合选择模型。Agent具有访问所在操作系统的能力,请谨慎选择项目部署环境。同时项目也会持续升级安全机制、并降低模型消耗成本。
|
||||
3. CowAgent项目专注于开源技术开发,不会参与、授权或发行任何加密货币。
|
||||
|
||||
> 目前支持微信和微信个人号部署,欢迎接入更多应用,参考[`Terminal`代码](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/channel/terminal/terminal_channel.py)实现接收和发送消息逻辑即可接入。
|
||||
## 演示
|
||||
|
||||
使用说明(Agent模式):[CowAgent介绍](/docs/agent.md)
|
||||
|
||||
快速部署:
|
||||
DEMO视频(对话模式):https://cdn.link-ai.tech/doc/cow_demo.mp4
|
||||
|
||||
## 社区
|
||||
|
||||
添加小助手微信加入开源项目交流群:
|
||||
|
||||
<img width="140" src="https://img-1317903499.cos.ap-guangzhou.myqcloud.com/docs/open-community.png">
|
||||
|
||||
<br/>
|
||||
|
||||
# 企业服务
|
||||
|
||||
<a href="https://link-ai.tech" target="_blank"><img width="720" src="https://cdn.link-ai.tech/image/link-ai-intro.jpg"></a>
|
||||
|
||||
> [LinkAI](https://link-ai.tech/) 是面向企业和开发者的一站式AI智能体平台,聚合多模态大模型、知识库、Agent 插件、工作流等能力,支持一键接入主流平台并进行管理,支持SaaS、私有化部署等多种模式。
|
||||
>
|
||||
>[](https://railway.app/template/qApznZ?referralCode=RC3znh)
|
||||
> LinkAI 目前已在智能客服、私域运营、企业效率助手等场景积累了丰富的AI解决方案,在消费、健康、文教、科技制造等各行业沉淀了大模型落地应用的最佳实践,致力于帮助更多企业和开发者拥抱 AI 生产力。
|
||||
|
||||
**产品咨询和企业服务** 可联系产品客服:
|
||||
|
||||
<img width="150" src="https://cdn.link-ai.tech/portal/linkai-customer-service.png">
|
||||
|
||||
<br/>
|
||||
|
||||
# 🏷 更新日志
|
||||
|
||||
>**2026.02.03:** [2.0.0版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/2.0.0),正式升级为超级Agent助理,支持多轮任务决策、具备长期记忆、实现多种系统工具、支持Skills框架,新增多种模型并优化了接入渠道。
|
||||
|
||||
>**2025.05.23:** [1.7.6版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.7.6) 优化web网页channel、新增 [AgentMesh](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/plugins/agent/README.md)多智能体插件、百度语音合成优化、企微应用`access_token`获取优化、支持`claude-4-sonnet`和`claude-4-opus`模型
|
||||
|
||||
>**2025.04.11:** [1.7.5版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.7.5) 新增支持 [wechatferry](https://github.com/zhayujie/chatgpt-on-wechat/pull/2562) 协议、新增 deepseek 模型、新增支持腾讯云语音能力、新增支持 ModelScope 和 Gitee-AI API接口
|
||||
|
||||
>**2024.12.13:** [1.7.4版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.7.4) 新增 Gemini 2.0 模型、新增web channel、解决内存泄漏问题、解决 `#reloadp` 命令重载不生效问题
|
||||
|
||||
>**2024.10.31:** [1.7.3版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.7.3) 程序稳定性提升、数据库功能、Claude模型优化、linkai插件优化、离线通知
|
||||
|
||||
更多更新历史请查看: [更新日志](/docs/release/history.md)
|
||||
|
||||
<br/>
|
||||
|
||||
# 🚀 快速开始
|
||||
|
||||
项目提供了一键安装、配置、启动、管理程序的脚本,推荐使用脚本快速运行,也可以根据下文中的详细指引一步步安装运行。
|
||||
|
||||
在终端执行以下命令:
|
||||
|
||||
```bash
|
||||
bash <(curl -sS https://cdn.link-ai.tech/code/cow/run.sh)
|
||||
```
|
||||
|
||||
脚本使用说明:[一键运行脚本](https://github.com/zhayujie/chatgpt-on-wechat/wiki/CowAgentQuickStart)
|
||||
|
||||
|
||||
# 更新日志
|
||||
## 一、准备
|
||||
|
||||
>**2023.04.05:** 支持微信个人号部署,兼容角色扮演等预设插件,[使用文档](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/channel/wechatmp/README.md)。(contributed by [@JS00000](https://github.com/JS00000) in [#686](https://github.com/zhayujie/chatgpt-on-wechat/pull/686))
|
||||
### 1. 模型API
|
||||
|
||||
>**2023.04.05:** 增加能让ChatGPT使用工具的`tool`插件,[使用文档](https://github.com/goldfishh/chatgpt-on-wechat/blob/master/plugins/tool/README.md)。工具相关issue可反馈至[chatgpt-tool-hub](https://github.com/goldfishh/chatgpt-tool-hub)。(contributed by [@goldfishh](https://github.com/goldfishh) in [#663](https://github.com/zhayujie/chatgpt-on-wechat/pull/663))
|
||||
项目支持国内外主流厂商的模型接口,可选模型及配置说明参考:[模型说明](#模型说明)。
|
||||
|
||||
>**2023.03.25:** 支持插件化开发,目前已实现 多角色切换、文字冒险游戏、管理员指令、Stable Diffusion等插件,使用参考 [#578](https://github.com/zhayujie/chatgpt-on-wechat/issues/578)。(contributed by [@lanvent](https://github.com/lanvent) in [#565](https://github.com/zhayujie/chatgpt-on-wechat/pull/565))
|
||||
> 注:Agent模式下推荐使用以下模型,可根据效果及成本综合选择:MiniMax-M2.5、glm-5、kimi-k2.5、qwen3.5-plus、claude-sonnet-4-6、gemini-3.1-pro-preview
|
||||
|
||||
>**2023.03.09:** 基于 `whisper API`(后续已接入更多的语音`API`服务) 实现对微信语音消息的解析和回复,添加配置项 `"speech_recognition":true` 即可启用,使用参考 [#415](https://github.com/zhayujie/chatgpt-on-wechat/issues/415)。(contributed by [wanggang1987](https://github.com/wanggang1987) in [#385](https://github.com/zhayujie/chatgpt-on-wechat/pull/385))
|
||||
同时支持使用 **LinkAI平台** 接口,可灵活切换 OpenAI、Claude、Gemini、DeepSeek、Qwen、Kimi 等多种常用模型,并支持知识库、工作流、插件等Agent能力,参考 [接口文档](https://docs.link-ai.tech/platform/api)。
|
||||
|
||||
>**2023.03.02:** 接入[ChatGPT API](https://platform.openai.com/docs/guides/chat) (gpt-3.5-turbo),默认使用该模型进行对话,需升级openai依赖 (`pip3 install --upgrade openai`)。网络问题参考 [#351](https://github.com/zhayujie/chatgpt-on-wechat/issues/351)
|
||||
### 2.环境安装
|
||||
|
||||
>**2023.02.09:** 扫码登录存在封号风险,请谨慎使用,参考[#58](https://github.com/AutumnWhj/ChatGPT-wechat-bot/issues/158)
|
||||
支持 Linux、MacOS、Windows 操作系统,可在个人计算机及服务器上运行,需安装 `Python`,Python版本需在3.7 ~ 3.12 之间,推荐使用3.9版本。
|
||||
|
||||
>**2023.02.05:** 在openai官方接口方案中 (GPT-3模型) 实现上下文对话
|
||||
|
||||
>**2022.12.18:** 支持根据描述生成图片并发送,openai版本需大于0.25.0
|
||||
|
||||
>**2022.12.17:** 原来的方案是从 [ChatGPT页面](https://chat.openai.com/chat) 获取session_token,使用 [revChatGPT](https://github.com/acheong08/ChatGPT) 直接访问web接口,但随着ChatGPT接入Cloudflare人机验证,这一方案难以在服务器顺利运行。 所以目前使用的方案是调用 OpenAI 官方提供的 [API](https://beta.openai.com/docs/api-reference/introduction),回复质量上基本接近于ChatGPT的内容,劣势是暂不支持有上下文记忆的对话,优势是稳定性和响应速度较好。
|
||||
|
||||
# 使用效果
|
||||
|
||||
### 个人聊天
|
||||
|
||||

|
||||
|
||||
### 群组聊天
|
||||
|
||||

|
||||
|
||||
### 图片生成
|
||||
|
||||

|
||||
|
||||
|
||||
# 快速开始
|
||||
|
||||
## 准备
|
||||
|
||||
### 1. OpenAI账号注册
|
||||
|
||||
前往 [OpenAI注册页面](https://beta.openai.com/signup) 创建账号,参考这篇 [教程](https://www.pythonthree.com/register-openai-chatgpt/) 可以通过虚拟手机号来接收验证码。创建完账号则前往 [API管理页面](https://beta.openai.com/account/api-keys) 创建一个 API Key 并保存下来,后面需要在项目中配置这个key。
|
||||
|
||||
> 项目中使用的对话模型是 davinci,计费方式是约每 750 字 (包含请求和回复) 消耗 $0.02,图片生成是每张消耗 $0.016,账号创建有免费的 $18 额度 (更新3.25: 最新注册的已经无免费额度了),使用完可以更换邮箱重新注册。
|
||||
|
||||
### 2.运行环境
|
||||
|
||||
支持 Linux、MacOS、Windows 系统(可在Linux服务器上长期运行),同时需安装 `Python`。
|
||||
> 建议Python版本在 3.7.1~3.9.X 之间,推荐3.8版本,3.10及以上版本在 MacOS 可用,其他系统上不确定能否正常运行。
|
||||
> 注意:Agent模式推荐使用源码运行,若选择Docker部署则无需安装python环境和下载源码,可直接快进到下一节。
|
||||
|
||||
**(1) 克隆项目代码:**
|
||||
|
||||
@@ -78,8 +108,10 @@ git clone https://github.com/zhayujie/chatgpt-on-wechat
|
||||
cd chatgpt-on-wechat/
|
||||
```
|
||||
|
||||
若遇到网络问题可使用国内仓库地址:https://gitee.com/zhayujie/chatgpt-on-wechat
|
||||
|
||||
**(2) 安装核心依赖 (必选):**
|
||||
> 能够使用`itchat`创建机器人,并具有文字交流功能所需的最小依赖集合。
|
||||
|
||||
```bash
|
||||
pip3 install -r requirements.txt
|
||||
```
|
||||
@@ -89,27 +121,9 @@ pip3 install -r requirements.txt
|
||||
```bash
|
||||
pip3 install -r requirements-optional.txt
|
||||
```
|
||||
> 如果某项依赖安装失败请注释掉对应的行再继续。
|
||||
如果某项依赖安装失败可注释掉对应的行后重试。
|
||||
|
||||
其中`tiktoken`要求`python`版本在3.8以上,它用于精确计算会话使用的tokens数量,强烈建议安装。
|
||||
|
||||
|
||||
使用`google`或`baidu`语音识别需安装`ffmpeg`,
|
||||
|
||||
默认的`openai`语音识别不需要安装`ffmpeg`。
|
||||
|
||||
参考[#415](https://github.com/zhayujie/chatgpt-on-wechat/issues/415)
|
||||
|
||||
使用`azure`语音功能需安装依赖(列在`requirements-optional.txt`内,但为便于`railway`部署已注释):
|
||||
|
||||
```bash
|
||||
pip3 install azure-cognitiveservices-speech
|
||||
```
|
||||
|
||||
> 目前默认发布的镜像和`railway`部署,都基于`apline`,无法安装`azure`的依赖。若有需求请自行基于[`debian`](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/docker/Dockerfile.debian.latest)打包。
|
||||
参考[文档](https://learn.microsoft.com/en-us/azure/cognitive-services/speech-service/quickstarts/setup-platform?pivots=programming-language-python&tabs=linux%2Cubuntu%2Cdotnet%2Cjre%2Cmaven%2Cnodejs%2Cmac%2Cpypi)
|
||||
|
||||
## 配置
|
||||
## 二、配置
|
||||
|
||||
配置文件的模板在根目录的`config-template.json`中,需复制该模板创建最终生效的 `config.json` 文件:
|
||||
|
||||
@@ -117,106 +131,646 @@ pip3 install azure-cognitiveservices-speech
|
||||
cp config-template.json config.json
|
||||
```
|
||||
|
||||
然后在`config.json`中填入配置,以下是对默认配置的说明,可根据需要进行自定义修改:
|
||||
然后在`config.json`中填入配置,以下是对默认配置的说明,可根据需要进行自定义修改(注意实际使用时请去掉注释,保证JSON格式的规范):
|
||||
|
||||
```bash
|
||||
# config.json文件内容示例
|
||||
{
|
||||
"open_ai_api_key": "YOUR API KEY", # 填入上面创建的 OpenAI API KEY
|
||||
"model": "gpt-3.5-turbo", # 模型名称。当use_azure_chatgpt为true时,其名称为Azure上model deployment名称
|
||||
"proxy": "127.0.0.1:7890", # 代理客户端的ip和端口
|
||||
"single_chat_prefix": ["bot", "@bot"], # 私聊时文本需要包含该前缀才能触发机器人回复
|
||||
"single_chat_reply_prefix": "[bot] ", # 私聊时自动回复的前缀,用于区分真人
|
||||
"group_chat_prefix": ["@bot"], # 群聊时包含该前缀则会触发机器人回复
|
||||
"group_name_white_list": ["ChatGPT测试群", "ChatGPT测试群2"], # 开启自动回复的群名称列表
|
||||
"group_chat_in_one_session": ["ChatGPT测试群"], # 支持会话上下文共享的群名称
|
||||
"image_create_prefix": ["画", "看", "找"], # 开启图片回复的前缀
|
||||
"conversation_max_tokens": 1000, # 支持上下文记忆的最多字符数
|
||||
# config.json 文件内容示例
|
||||
{
|
||||
"channel_type": "web", # 接入渠道类型,默认为web,支持修改为:feishu,dingtalk,wechatcom_app,terminal,wechatmp,wechatmp_service
|
||||
"model": "MiniMax-M2.5", # 模型名称
|
||||
"minimax_api_key": "", # MiniMax API Key
|
||||
"zhipu_ai_api_key": "", # 智谱GLM API Key
|
||||
"moonshot_api_key": "", # Kimi/Moonshot API Key
|
||||
"ark_api_key": "", # 豆包(火山方舟) API Key
|
||||
"dashscope_api_key": "", # 百炼(通义千问)API Key
|
||||
"claude_api_key": "", # Claude API Key
|
||||
"claude_api_base": "https://api.anthropic.com/v1", # Claude API 地址,修改可接入三方代理平台
|
||||
"gemini_api_key": "", # Gemini API Key
|
||||
"gemini_api_base": "https://generativelanguage.googleapis.com", # Gemini API地址
|
||||
"open_ai_api_key": "", # OpenAI API Key
|
||||
"open_ai_api_base": "https://api.openai.com/v1", # OpenAI API 地址
|
||||
"linkai_api_key": "", # LinkAI API Key
|
||||
"proxy": "", # 代理客户端的ip和端口,国内环境需要开启代理的可填写该项,如 "127.0.0.1:7890"
|
||||
"speech_recognition": false, # 是否开启语音识别
|
||||
"group_speech_recognition": false, # 是否开启群组语音识别
|
||||
"use_azure_chatgpt": false, # 是否使用Azure ChatGPT service代替openai ChatGPT service. 当设置为true时需要设置 open_ai_api_base,如 https://xxx.openai.azure.com/
|
||||
"character_desc": "你是ChatGPT, 一个由OpenAI训练的大型语言模型, 你旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。", # 人格描述,
|
||||
"voice_reply_voice": false, # 是否使用语音回复语音
|
||||
"use_linkai": false, # 是否使用LinkAI接口,默认关闭,设置为true后可对接LinkAI平台接口
|
||||
"agent": true, # 是否启用Agent模式,启用后拥有多轮工具决策、长期记忆、Skills能力等
|
||||
"agent_workspace": "~/cow", # Agent的工作空间路径,用于存储memory、skills、系统设定等
|
||||
"agent_max_context_tokens": 40000, # Agent模式下最大上下文tokens,超出将自动丢弃最早的上下文
|
||||
"agent_max_context_turns": 30, # Agent模式下最大上下文记忆轮次,每轮包括一次用户提问和AI回复
|
||||
"agent_max_steps": 15 # Agent模式下单次任务的最大决策步数,超出后将停止继续调用工具
|
||||
}
|
||||
```
|
||||
**配置说明:**
|
||||
|
||||
**1.个人聊天**
|
||||
**配置补充说明:**
|
||||
|
||||
+ 个人聊天中,需要以 "bot"或"@bot" 为开头的内容触发机器人,对应配置项 `single_chat_prefix` (如果不需要以前缀触发可以填写 `"single_chat_prefix": [""]`)
|
||||
+ 机器人回复的内容会以 "[bot] " 作为前缀, 以区分真人,对应的配置项为 `single_chat_reply_prefix` (如果不需要前缀可以填写 `"single_chat_reply_prefix": ""`)
|
||||
|
||||
**2.群组聊天**
|
||||
|
||||
+ 群组聊天中,群名称需配置在 `group_name_white_list ` 中才能开启群聊自动回复。如果想对所有群聊生效,可以直接填写 `"group_name_white_list": ["ALL_GROUP"]`
|
||||
+ 默认只要被人 @ 就会触发机器人自动回复;另外群聊天中只要检测到以 "@bot" 开头的内容,同样会自动回复(方便自己触发),这对应配置项 `group_chat_prefix`
|
||||
+ 可选配置: `group_name_keyword_white_list`配置项支持模糊匹配群名称,`group_chat_keyword`配置项则支持模糊匹配群消息内容,用法与上述两个配置项相同。(Contributed by [evolay](https://github.com/evolay))
|
||||
+ `group_chat_in_one_session`:使群聊共享一个会话上下文,配置 `["ALL_GROUP"]` 则作用于所有群聊
|
||||
|
||||
**3.语音识别**
|
||||
<details>
|
||||
<summary>1. 语音配置</summary>
|
||||
|
||||
+ 添加 `"speech_recognition": true` 将开启语音识别,默认使用openai的whisper模型识别为文字,同时以文字回复,该参数仅支持私聊 (注意由于语音消息无法匹配前缀,一旦开启将对所有语音自动回复,支持语音触发画图);
|
||||
+ 添加 `"group_speech_recognition": true` 将开启群组语音识别,默认使用openai的whisper模型识别为文字,同时以文字回复,参数仅支持群聊 (会匹配group_chat_prefix和group_chat_keyword, 支持语音触发画图);
|
||||
+ 添加 `"voice_reply_voice": true` 将开启语音回复语音(同时作用于私聊和群聊),但是需要配置对应语音合成平台的key,由于itchat协议的限制,只能发送语音mp3文件,若使用wechaty则回复的是微信语音。
|
||||
+ 添加 `"voice_reply_voice": true` 将开启语音回复语音(同时作用于私聊和群聊)
|
||||
</details>
|
||||
|
||||
**4.其他配置**
|
||||
<details>
|
||||
<summary>2. 其他配置</summary>
|
||||
|
||||
+ `model`: 模型名称,目前支持 `gpt-3.5-turbo`, `text-davinci-003`, `gpt-4`, `gpt-4-32k` (其中gpt-4 api暂未开放)
|
||||
+ `temperature`,`frequency_penalty`,`presence_penalty`: Chat API接口参数,详情参考[OpenAI官方文档。](https://platform.openai.com/docs/api-reference/chat)
|
||||
+ `proxy`:由于目前 `openai` 接口国内无法访问,需配置代理客户端的地址,详情参考 [#351](https://github.com/zhayujie/chatgpt-on-wechat/issues/351)
|
||||
+ 对于图像生成,在满足个人或群组触发条件外,还需要额外的关键词前缀来触发,对应配置 `image_create_prefix `
|
||||
+ 关于OpenAI对话及图片接口的参数配置(内容自由度、回复字数限制、图片大小等),可以参考 [对话接口](https://beta.openai.com/docs/api-reference/completions) 和 [图像接口](https://beta.openai.com/docs/api-reference/completions) 文档直接在 [代码](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/bot/openai/open_ai_bot.py) `bot/openai/open_ai_bot.py` 中进行调整。
|
||||
+ `conversation_max_tokens`:表示能够记忆的上下文最大字数(一问一答为一组对话,如果累积的对话字数超出限制,就会优先移除最早的一组对话)
|
||||
+ `rate_limit_chatgpt`,`rate_limit_dalle`:每分钟最高问答速率、画图速率,超速后排队按序处理。
|
||||
+ `clear_memory_commands`: 对话内指令,主动清空前文记忆,字符串数组可自定义指令别名。
|
||||
+ `hot_reload`: 程序退出后,暂存微信扫码状态,默认关闭。
|
||||
+ `character_desc` 配置中保存着你对机器人说的一段话,他会记住这段话并作为他的设定,你可以为他定制任何人格 (关于会话上下文的更多内容参考该 [issue](https://github.com/zhayujie/chatgpt-on-wechat/issues/43))
|
||||
+ `model`: 模型名称,Agent模式下推荐使用 `MiniMax-M2.5`、`glm-5`、`kimi-k2.5`、`qwen3.5-plus`、`claude-sonnet-4-6`、`gemini-3.1-pro-preview`,全部模型名称参考[common/const.py](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/common/const.py)文件
|
||||
+ `character_desc`:普通对话模式下的机器人系统提示词。在Agent模式下该配置不生效,由工作空间中的文件内容构成。
|
||||
+ `subscribe_msg`:订阅消息,公众号和企业微信channel中请填写,当被订阅时会自动回复, 可使用特殊占位符。目前支持的占位符有{trigger_prefix},在程序中它会自动替换成bot的触发词。
|
||||
</details>
|
||||
|
||||
**所有可选的配置项均在该[文件](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/config.py)中列出。**
|
||||
<details>
|
||||
<summary>3. LinkAI配置</summary>
|
||||
|
||||
## 运行
|
||||
+ `use_linkai`: 是否使用LinkAI接口,默认关闭,设置为true后可对接LinkAI平台,使用知识库、工作流、插件等能力, 参考[接口文档](https://docs.link-ai.tech/platform/api/chat)
|
||||
+ `linkai_api_key`: LinkAI Api Key,可在 [控制台](https://link-ai.tech/console/interface) 创建
|
||||
+ `linkai_app_code`: LinkAI 应用或工作流的code,选填,普通对话模式中使用。
|
||||
</details>
|
||||
|
||||
注:全部配置项说明可在 [`config.py`](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/config.py) 文件中查看。
|
||||
|
||||
## 三、运行
|
||||
|
||||
### 1.本地运行
|
||||
|
||||
如果是开发机 **本地运行**,直接在项目根目录下执行:
|
||||
如果是个人计算机 **本地运行**,直接在项目根目录下执行:
|
||||
|
||||
```bash
|
||||
python3 app.py
|
||||
python3 app.py # windows环境下该命令通常为 python app.py
|
||||
```
|
||||
终端输出二维码后,使用微信进行扫码,当输出 "Start auto replying" 时表示自动回复程序已经成功运行了(注意:用于登录的微信需要在支付处已完成实名认证)。扫码登录后你的账号就成为机器人了,可以在微信手机端通过配置的关键词触发自动回复 (任意好友发送消息给你,或是自己发消息给好友),参考[#142](https://github.com/zhayujie/chatgpt-on-wechat/issues/142)。
|
||||
|
||||
运行后默认会启动web服务,可通过访问 `http://localhost:9899/chat` 在网页端对话。
|
||||
|
||||
如果需要接入其他应用通道只需修改 `config.json` 配置文件中的 `channel_type` 参数,详情参考:[通道说明](#通道说明)。
|
||||
|
||||
|
||||
### 2.服务器部署
|
||||
|
||||
使用nohup命令在后台运行程序:
|
||||
在服务器中可使用 `nohup` 命令在后台运行程序:
|
||||
|
||||
```bash
|
||||
touch nohup.out # 首次运行需要新建日志文件
|
||||
nohup python3 app.py & tail -f nohup.out # 在后台运行程序并通过日志输出二维码
|
||||
nohup python3 app.py & tail -f nohup.out
|
||||
```
|
||||
扫码登录后程序即可运行于服务器后台,此时可通过 `ctrl+c` 关闭日志,不会影响后台程序的运行。使用 `ps -ef | grep app.py | grep -v grep` 命令可查看运行于后台的进程,如果想要重新启动程序可以先 `kill` 掉对应的进程。日志关闭后如果想要再次打开只需输入 `tail -f nohup.out`。此外,`scripts` 目录下有一键运行、关闭程序的脚本供使用。
|
||||
|
||||
> **多账号支持:** 将项目复制多份,分别启动程序,用不同账号扫码登录即可实现同时运行。
|
||||
执行后程序运行于服务器后台,可通过 `ctrl+c` 关闭日志,不会影响后台程序的运行。使用 `ps -ef | grep app.py | grep -v grep` 命令可查看运行于后台的进程,如果想要重新启动程序可以先 `kill` 掉对应的进程。 日志关闭后如果想要再次打开只需输入 `tail -f nohup.out`。
|
||||
|
||||
> **特殊指令:** 用户向机器人发送 **#reset** 即可清空该用户的上下文记忆。
|
||||
此外,项目的 `scripts` 目录下有一键运行、关闭程序的脚本供使用。 运行后默认channel为web,通过可以通过修改配置文件进行切换。
|
||||
|
||||
|
||||
### 3.Docker部署
|
||||
|
||||
参考文档 [Docker部署](https://github.com/limccn/chatgpt-on-wechat/wiki/Docker%E9%83%A8%E7%BD%B2) (Contributed by [limccn](https://github.com/limccn))。
|
||||
使用docker部署无需下载源码和安装依赖,只需要获取 `docker-compose.yml` 配置文件并启动容器即可。Agent模式下更推荐使用源码进行部署,以获得更多系统访问能力。
|
||||
|
||||
### 4. Railway部署(✅推荐)
|
||||
> Railway每月提供5刀和最多500小时的免费额度。
|
||||
1. 进入 [Railway](https://railway.app/template/qApznZ?referralCode=RC3znh)。
|
||||
2. 点击 `Deploy Now` 按钮。
|
||||
3. 设置环境变量来重载程序运行的参数,例如`open_ai_api_key`, `character_desc`。
|
||||
> 前提是需要安装好 `docker` 及 `docker-compose`,安装成功后执行 `docker -v` 和 `docker-compose version` (或 `docker compose version`) 可查看到版本号。安装地址为 [docker官网](https://docs.docker.com/engine/install/) 。
|
||||
|
||||
## 常见问题
|
||||
**(1) 下载 docker-compose.yml 文件**
|
||||
|
||||
```bash
|
||||
wget https://cdn.link-ai.tech/code/cow/docker-compose.yml
|
||||
```
|
||||
|
||||
下载完成后打开 `docker-compose.yml` 填写所需配置,例如 `CHANNEL_TYPE`、`OPEN_AI_API_KEY` 和等配置。
|
||||
|
||||
**(2) 启动容器**
|
||||
|
||||
在 `docker-compose.yml` 所在目录下执行以下命令启动容器:
|
||||
|
||||
```bash
|
||||
sudo docker compose up -d # 若docker-compose为 1.X 版本,则执行 `sudo docker-compose up -d`
|
||||
```
|
||||
|
||||
运行命令后,会自动取 [docker hub](https://hub.docker.com/r/zhayujie/chatgpt-on-wechat) 拉取最新release版本的镜像。当执行 `sudo docker ps` 能查看到 NAMES 为 chatgpt-on-wechat 的容器即表示运行成功。最后执行以下命令可查看容器的运行日志:
|
||||
|
||||
```bash
|
||||
sudo docker logs -f chatgpt-on-wechat
|
||||
```
|
||||
|
||||
**(3) 插件使用**
|
||||
|
||||
如果需要在docker容器中修改插件配置,可通过挂载的方式完成,将 [插件配置文件](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/plugins/config.json.template)
|
||||
重命名为 `config.json`,放置于 `docker-compose.yml` 相同目录下,并在 `docker-compose.yml` 中的 `chatgpt-on-wechat` 部分下添加 `volumes` 映射:
|
||||
|
||||
```
|
||||
volumes:
|
||||
- ./config.json:/app/plugins/config.json
|
||||
```
|
||||
**注**:使用docker方式部署的详细教程可以参考:[docker部署CoW项目](https://www.wangpc.cc/ai/docker-deploy-cow/)
|
||||
|
||||
|
||||
## 模型说明
|
||||
|
||||
以下对所有可支持的模型的配置和使用方法进行说明,模型接口实现在项目的 `models/` 目录下。
|
||||
|
||||
<details>
|
||||
<summary>OpenAI</summary>
|
||||
|
||||
1. API Key创建:在 [OpenAI平台](https://platform.openai.com/api-keys) 创建API Key
|
||||
|
||||
2. 填写配置
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "gpt-4.1-mini",
|
||||
"open_ai_api_key": "YOUR_API_KEY",
|
||||
"open_ai_api_base": "https://api.openai.com/v1",
|
||||
"bot_type": "chatGPT"
|
||||
}
|
||||
```
|
||||
|
||||
- `model`: 与OpenAI接口的 [model参数](https://platform.openai.com/docs/models) 一致,支持包括 o系列、gpt-5.2、gpt-5.1、gpt-4.1等系列模型
|
||||
- `open_ai_api_base`: 如果需要接入第三方代理接口,可通过修改该参数进行接入
|
||||
- `bot_type`: 使用OpenAI相关模型时无需填写。当使用第三方代理接口接入Claude等非OpenAI官方模型时,该参数设为 `chatGPT`
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>LinkAI</summary>
|
||||
|
||||
1. API Key创建:在 [LinkAI平台](https://link-ai.tech/console/interface) 创建API Key
|
||||
|
||||
2. 填写配置
|
||||
|
||||
```json
|
||||
{
|
||||
"use_linkai": true,
|
||||
"linkai_api_key": "YOUR API KEY",
|
||||
"linkai_app_code": "YOUR APP CODE"
|
||||
}
|
||||
```
|
||||
|
||||
+ `use_linkai`: 是否使用LinkAI接口,默认关闭,设置为true后可对接LinkAI平台的智能体,使用知识库、工作流、数据库、MCP插件等丰富的Agent能力
|
||||
+ `linkai_api_key`: LinkAI平台的API Key,可在 [控制台](https://link-ai.tech/console/interface) 中创建
|
||||
+ `linkai_app_code`: LinkAI智能体 (应用或工作流) 的code,选填,普通对话模式可用。智能体创建可参考 [说明文档](https://docs.link-ai.tech/platform/quick-start)
|
||||
+ `model`: model字段填写空则直接使用智能体的模型,可在平台中灵活切换,[模型列表](https://link-ai.tech/console/models)中的全部模型均可使用
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>MiniMax</summary>
|
||||
|
||||
方式一:官方接入,配置如下(推荐):
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "MiniMax-M2.5",
|
||||
"minimax_api_key": ""
|
||||
}
|
||||
```
|
||||
- `model`: 可填写 `MiniMax-M2.5、MiniMax-M2.1、MiniMax-M2.1-lightning、MiniMax-M2、abab6.5-chat` 等
|
||||
- `minimax_api_key`:MiniMax平台的API-KEY,在 [控制台](https://platform.minimaxi.com/user-center/basic-information/interface-key) 创建
|
||||
|
||||
方式二:OpenAI兼容方式接入,配置如下:
|
||||
```json
|
||||
{
|
||||
"bot_type": "chatGPT",
|
||||
"model": "MiniMax-M2.5",
|
||||
"open_ai_api_base": "https://api.minimaxi.com/v1",
|
||||
"open_ai_api_key": ""
|
||||
}
|
||||
```
|
||||
- `bot_type`: OpenAI兼容方式
|
||||
- `model`: 可填 `MiniMax-M2.5、MiniMax-M2.1、MiniMax-M2.1-lightning、MiniMax-M2`,参考[API文档](https://platform.minimaxi.com/document/%E5%AF%B9%E8%AF%9D?key=66701d281d57f38758d581d0#QklxsNSbaf6kM4j6wjO5eEek)
|
||||
- `open_ai_api_base`: MiniMax平台API的 BASE URL
|
||||
- `open_ai_api_key`: MiniMax平台的API-KEY
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>智谱AI (GLM)</summary>
|
||||
|
||||
方式一:官方接入,配置如下(推荐):
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "glm-5",
|
||||
"zhipu_ai_api_key": ""
|
||||
}
|
||||
```
|
||||
- `model`: 可填 `glm-5、glm-4.7、glm-4-plus、glm-4-flash、glm-4-air、glm-4-airx、glm-4-long` 等, 参考 [glm系列模型编码](https://bigmodel.cn/dev/api/normal-model/glm-4)
|
||||
- `zhipu_ai_api_key`: 智谱AI平台的 API KEY,在 [控制台](https://www.bigmodel.cn/usercenter/proj-mgmt/apikeys) 创建
|
||||
|
||||
方式二:OpenAI兼容方式接入,配置如下:
|
||||
```json
|
||||
{
|
||||
"bot_type": "chatGPT",
|
||||
"model": "glm-5",
|
||||
"open_ai_api_base": "https://open.bigmodel.cn/api/paas/v4",
|
||||
"open_ai_api_key": ""
|
||||
}
|
||||
```
|
||||
- `bot_type`: OpenAI兼容方式
|
||||
- `model`: 可填 `glm-5、glm-4.7、glm-4-plus、glm-4-flash、glm-4-air、glm-4-airx、glm-4-long` 等
|
||||
- `open_ai_api_base`: 智谱AI平台的 BASE URL
|
||||
- `open_ai_api_key`: 智谱AI平台的 API KEY
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>通义千问 (Qwen)</summary>
|
||||
|
||||
方式一:官方SDK接入,配置如下(推荐):
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "qwen3.5-plus",
|
||||
"dashscope_api_key": "sk-qVxxxxG"
|
||||
}
|
||||
```
|
||||
- `model`: 可填写 `qwen3.5-plus、qwen3-max、qwen-max、qwen-plus、qwen-turbo、qwen-long、qwq-plus` 等
|
||||
- `dashscope_api_key`: 通义千问的 API-KEY,参考 [官方文档](https://bailian.console.aliyun.com/?tab=api#/api) ,在 [控制台](https://bailian.console.aliyun.com/?tab=model#/api-key) 创建
|
||||
|
||||
方式二:OpenAI兼容方式接入,配置如下:
|
||||
```json
|
||||
{
|
||||
"bot_type": "chatGPT",
|
||||
"model": "qwen3.5-plus",
|
||||
"open_ai_api_base": "https://dashscope.aliyuncs.com/compatible-mode/v1",
|
||||
"open_ai_api_key": "sk-qVxxxxG"
|
||||
}
|
||||
```
|
||||
- `bot_type`: OpenAI兼容方式
|
||||
- `model`: 支持官方所有模型,参考[模型列表](https://help.aliyun.com/zh/model-studio/models?spm=a2c4g.11186623.0.0.78d84823Kth5on#9f8890ce29g5u)
|
||||
- `open_ai_api_base`: 通义千问API的 BASE URL
|
||||
- `open_ai_api_key`: 通义千问的 API-KEY
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>Kimi (Moonshot)</summary>
|
||||
|
||||
方式一:官方接入,配置如下:
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "kimi-k2.5",
|
||||
"moonshot_api_key": ""
|
||||
}
|
||||
```
|
||||
- `model`: 可填写 `kimi-k2.5、kimi-k2、moonshot-v1-8k、moonshot-v1-32k、moonshot-v1-128k`
|
||||
- `moonshot_api_key`: Moonshot的API-KEY,在 [控制台](https://platform.moonshot.cn/console/api-keys) 创建
|
||||
|
||||
方式二:OpenAI兼容方式接入,配置如下:
|
||||
```json
|
||||
{
|
||||
"bot_type": "chatGPT",
|
||||
"model": "kimi-k2.5",
|
||||
"open_ai_api_base": "https://api.moonshot.cn/v1",
|
||||
"open_ai_api_key": ""
|
||||
}
|
||||
```
|
||||
- `bot_type`: OpenAI兼容方式
|
||||
- `model`: 可填写 `kimi-k2.5、kimi-k2、moonshot-v1-8k、moonshot-v1-32k、moonshot-v1-128k`
|
||||
- `open_ai_api_base`: Moonshot的 BASE URL
|
||||
- `open_ai_api_key`: Moonshot的 API-KEY
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>豆包 (Doubao)</summary>
|
||||
|
||||
1. API Key创建:在 [火山方舟控制台](https://console.volcengine.com/ark/region:ark+cn-beijing/apikey) 创建API Key
|
||||
|
||||
2. 填写配置
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "doubao-seed-2-0-code-preview-260215",
|
||||
"ark_api_key": "YOUR_API_KEY"
|
||||
}
|
||||
```
|
||||
- `model`: 可填写 `doubao-seed-2-0-code-preview-260215、doubao-seed-2-0-pro-260215、doubao-seed-2-0-lite-260215、doubao-seed-2-0-mini-260215` 等
|
||||
- `ark_api_key`: 火山方舟平台的 API Key,在 [控制台](https://console.volcengine.com/ark/region:ark+cn-beijing/apikey) 创建
|
||||
- `ark_base_url`: 可选,默认为 `https://ark.cn-beijing.volces.com/api/v3`
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>Claude</summary>
|
||||
|
||||
1. API Key创建:在 [Claude控制台](https://console.anthropic.com/settings/keys) 创建API Key
|
||||
|
||||
2. 填写配置
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "claude-sonnet-4-6",
|
||||
"claude_api_key": "YOUR_API_KEY"
|
||||
}
|
||||
```
|
||||
- `model`: 参考 [官方模型ID](https://docs.anthropic.com/en/docs/about-claude/models/overview#model-aliases) ,支持 `claude-sonnet-4-6、claude-opus-4-6、claude-sonnet-4-5、claude-sonnet-4-0、claude-opus-4-0、claude-3-5-sonnet-latest` 等
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>Gemini</summary>
|
||||
|
||||
API Key创建:在 [控制台](https://aistudio.google.com/app/apikey?hl=zh-cn) 创建API Key ,配置如下
|
||||
```json
|
||||
{
|
||||
"model": "gemini-3.1-pro-preview",
|
||||
"gemini_api_key": ""
|
||||
}
|
||||
```
|
||||
- `model`: 参考[官方文档-模型列表](https://ai.google.dev/gemini-api/docs/models?hl=zh-cn),支持 `gemini-3.1-pro-preview、gemini-3-flash-preview、gemini-3-pro-preview、gemini-2.5-pro、gemini-2.0-flash` 等
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>DeepSeek</summary>
|
||||
|
||||
1. API Key创建:在 [DeepSeek平台](https://platform.deepseek.com/api_keys) 创建API Key
|
||||
|
||||
2. 填写配置
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "deepseek-chat",
|
||||
"open_ai_api_key": "sk-xxxxxxxxxxx",
|
||||
"open_ai_api_base": "https://api.deepseek.com/v1",
|
||||
"bot_type": "chatGPT"
|
||||
|
||||
}
|
||||
```
|
||||
|
||||
- `bot_type`: OpenAI兼容方式
|
||||
- `model`: 可填 `deepseek-chat、deepseek-reasoner`,分别对应的是 DeepSeek-V3 和 DeepSeek-R1 模型
|
||||
- `open_ai_api_key`: DeepSeek平台的 API Key
|
||||
- `open_ai_api_base`: DeepSeek平台 BASE URL
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>Azure</summary>
|
||||
|
||||
1. API Key创建:在 [Azure平台](https://oai.azure.com/) 创建API Key
|
||||
|
||||
2. 填写配置
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "",
|
||||
"use_azure_chatgpt": true,
|
||||
"open_ai_api_key": "",
|
||||
"open_ai_api_base": "",
|
||||
"azure_deployment_id": "",
|
||||
"azure_api_version": "2025-01-01-preview"
|
||||
}
|
||||
```
|
||||
|
||||
- `model`: 留空即可
|
||||
- `use_azure_chatgpt`: 设为 true
|
||||
- `open_ai_api_key`: Azure平台的密钥
|
||||
- `open_ai_api_base`: Azure平台的 BASE URL
|
||||
- `azure_deployment_id`: Azure平台部署的模型名称
|
||||
- `azure_api_version`: api版本以及以上参数可以在部署的 [模型配置](https://oai.azure.com/resource/deployments) 界面查看
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>百度文心</summary>
|
||||
方式一:官方SDK接入,配置如下:
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "wenxin-4",
|
||||
"baidu_wenxin_api_key": "IajztZ0bDxgnP9bEykU7lBer",
|
||||
"baidu_wenxin_secret_key": "EDPZn6L24uAS9d8RWFfotK47dPvkjD6G"
|
||||
}
|
||||
```
|
||||
- `model`: 可填 `wenxin`和`wenxin-4`,对应模型为 文心-3.5 和 文心-4.0
|
||||
- `baidu_wenxin_api_key`:参考 [千帆平台-access_token鉴权](https://cloud.baidu.com/doc/WENXINWORKSHOP/s/dlv4pct3s) 文档获取 API Key
|
||||
- `baidu_wenxin_secret_key`:参考 [千帆平台-access_token鉴权](https://cloud.baidu.com/doc/WENXINWORKSHOP/s/dlv4pct3s) 文档获取 Secret Key
|
||||
|
||||
方式二:OpenAI兼容方式接入,配置如下:
|
||||
```json
|
||||
{
|
||||
"bot_type": "chatGPT",
|
||||
"model": "ERNIE-4.0-Turbo-8K",
|
||||
"open_ai_api_base": "https://qianfan.baidubce.com/v2",
|
||||
"open_ai_api_key": "bce-v3/ALTxxxxxxd2b"
|
||||
}
|
||||
```
|
||||
- `bot_type`: OpenAI兼容方式
|
||||
- `model`: 支持官方所有模型,参考[模型列表](https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Wm9cvy6rl)
|
||||
- `open_ai_api_base`: 百度文心API的 BASE URL
|
||||
- `open_ai_api_key`: 百度文心的 API-KEY,参考 [官方文档](https://cloud.baidu.com/doc/qianfan-api/s/ym9chdsy5) ,在 [控制台](https://console.bce.baidu.com/iam/#/iam/apikey/list) 创建API Key
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>讯飞星火</summary>
|
||||
|
||||
方式一:官方接入,配置如下:
|
||||
参考 [官方文档-快速指引](https://www.xfyun.cn/doc/platform/quickguide.html#%E7%AC%AC%E4%BA%8C%E6%AD%A5-%E5%88%9B%E5%BB%BA%E6%82%A8%E7%9A%84%E7%AC%AC%E4%B8%80%E4%B8%AA%E5%BA%94%E7%94%A8-%E5%BC%80%E5%A7%8B%E4%BD%BF%E7%94%A8%E6%9C%8D%E5%8A%A1) 获取 `APPID、 APISecret、 APIKey` 三个参数
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "xunfei",
|
||||
"xunfei_app_id": "",
|
||||
"xunfei_api_key": "",
|
||||
"xunfei_api_secret": "",
|
||||
"xunfei_domain": "4.0Ultra",
|
||||
"xunfei_spark_url": "wss://spark-api.xf-yun.com/v4.0/chat"
|
||||
}
|
||||
```
|
||||
- `model`: 填 `xunfei`
|
||||
- `xunfei_domain`: 可填写 `4.0Ultra、generalv3.5、max-32k、generalv3、pro-128k、lite`
|
||||
- `xunfei_spark_url`: 填写参考 [官方文档-请求地址](https://www.xfyun.cn/doc/spark/Web.html#_1-1-%E8%AF%B7%E6%B1%82%E5%9C%B0%E5%9D%80) 的说明
|
||||
|
||||
方式二:OpenAI兼容方式接入,配置如下:
|
||||
```json
|
||||
{
|
||||
"bot_type": "chatGPT",
|
||||
"model": "4.0Ultra",
|
||||
"open_ai_api_base": "https://spark-api-open.xf-yun.com/v1",
|
||||
"open_ai_api_key": ""
|
||||
}
|
||||
```
|
||||
- `bot_type`: OpenAI兼容方式
|
||||
- `model`: 可填写 `4.0Ultra、generalv3.5、max-32k、generalv3、pro-128k、lite`
|
||||
- `open_ai_api_base`: 讯飞星火平台的 BASE URL
|
||||
- `open_ai_api_key`: 讯飞星火平台的[APIPassword](https://console.xfyun.cn/services/bm3) ,因模型而已
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>ModelScope</summary>
|
||||
|
||||
```json
|
||||
{
|
||||
"bot_type": "modelscope",
|
||||
"model": "Qwen/QwQ-32B",
|
||||
"modelscope_api_key": "your_api_key",
|
||||
"modelscope_base_url": "https://api-inference.modelscope.cn/v1/chat/completions",
|
||||
"text_to_image": "MusePublic/489_ckpt_FLUX_1"
|
||||
}
|
||||
```
|
||||
|
||||
- `bot_type`: modelscope接口格式
|
||||
- `model`: 参考[模型列表](https://www.modelscope.cn/models?filter=inference_type&page=1)
|
||||
- `modelscope_api_key`: 参考 [官方文档-访问令牌](https://modelscope.cn/docs/accounts/token) ,在 [控制台](https://modelscope.cn/my/myaccesstoken)
|
||||
- `modelscope_base_url`: modelscope平台的 BASE URL
|
||||
- `text_to_image`: 图像生成模型,参考[模型列表](https://www.modelscope.cn/models?filter=inference_type&page=1)
|
||||
</details>
|
||||
|
||||
|
||||
## 通道说明
|
||||
|
||||
以下对可接入通道的配置方式进行说明,应用通道代码在项目的 `channel/` 目录下。
|
||||
|
||||
支持同时可接入多个通道,配置时可通过逗号进行分割,例如 `"channel_type": "feishu,dingtalk"`。
|
||||
|
||||
<details>
|
||||
<summary>1. Web</summary>
|
||||
|
||||
项目启动后会默认运行Web控制台,配置如下:
|
||||
|
||||
```json
|
||||
{
|
||||
"channel_type": "web",
|
||||
"web_port": 9899
|
||||
}
|
||||
```
|
||||
|
||||
- `web_port`: 默认为 9899,可按需更改,需要服务器防火墙和安全组放行该端口
|
||||
- 如本地运行,启动后请访问 `http://localhost:9899/chat` ;如服务器运行,请访问 `http://ip:9899/chat`
|
||||
> 注:请将上述 url 中的 ip 或者 port 替换为实际的值
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>2. Feishu - 飞书</summary>
|
||||
|
||||
飞书支持两种事件接收模式:WebSocket 长连接(推荐)和 Webhook。
|
||||
|
||||
**方式一:WebSocket 模式(推荐,无需公网 IP)**
|
||||
|
||||
```json
|
||||
{
|
||||
"channel_type": "feishu",
|
||||
"feishu_app_id": "APP_ID",
|
||||
"feishu_app_secret": "APP_SECRET",
|
||||
"feishu_event_mode": "websocket"
|
||||
}
|
||||
```
|
||||
|
||||
**方式二:Webhook 模式(需要公网 IP)**
|
||||
|
||||
```json
|
||||
{
|
||||
"channel_type": "feishu",
|
||||
"feishu_app_id": "APP_ID",
|
||||
"feishu_app_secret": "APP_SECRET",
|
||||
"feishu_token": "VERIFICATION_TOKEN",
|
||||
"feishu_event_mode": "webhook",
|
||||
"feishu_port": 9891
|
||||
}
|
||||
```
|
||||
|
||||
- `feishu_event_mode`: 事件接收模式,`websocket`(推荐)或 `webhook`
|
||||
- WebSocket 模式需安装依赖:`pip3 install lark-oapi`
|
||||
|
||||
详细步骤和参数说明参考 [飞书接入](https://docs.link-ai.tech/cow/multi-platform/feishu)
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>3. DingTalk - 钉钉</summary>
|
||||
|
||||
钉钉需要在开放平台创建智能机器人应用,将以下配置填入 `config.json`:
|
||||
|
||||
```json
|
||||
{
|
||||
"channel_type": "dingtalk",
|
||||
"dingtalk_client_id": "CLIENT_ID",
|
||||
"dingtalk_client_secret": "CLIENT_SECRET"
|
||||
}
|
||||
```
|
||||
详细步骤和参数说明参考 [钉钉接入](https://docs.link-ai.tech/cow/multi-platform/dingtalk)
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>4. WeCom App - 企业微信应用</summary>
|
||||
|
||||
企业微信自建应用接入需在后台创建应用并启用消息回调,配置示例:
|
||||
|
||||
```json
|
||||
{
|
||||
"channel_type": "wechatcom_app",
|
||||
"wechatcom_corp_id": "CORPID",
|
||||
"wechatcomapp_token": "TOKEN",
|
||||
"wechatcomapp_port": 9898,
|
||||
"wechatcomapp_secret": "SECRET",
|
||||
"wechatcomapp_agent_id": "AGENTID",
|
||||
"wechatcomapp_aes_key": "AESKEY"
|
||||
}
|
||||
```
|
||||
详细步骤和参数说明参考 [企微自建应用接入](https://docs.link-ai.tech/cow/multi-platform/wechat-com)
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>5. WeChat MP - 微信公众号</summary>
|
||||
|
||||
本项目支持订阅号和服务号两种公众号,通过服务号(`wechatmp_service`)体验更佳。
|
||||
|
||||
**个人订阅号(wechatmp)**
|
||||
|
||||
```json
|
||||
{
|
||||
"channel_type": "wechatmp",
|
||||
"wechatmp_token": "TOKEN",
|
||||
"wechatmp_port": 80,
|
||||
"wechatmp_app_id": "APPID",
|
||||
"wechatmp_app_secret": "APPSECRET",
|
||||
"wechatmp_aes_key": ""
|
||||
}
|
||||
```
|
||||
|
||||
**企业服务号(wechatmp_service)**
|
||||
|
||||
```json
|
||||
{
|
||||
"channel_type": "wechatmp_service",
|
||||
"wechatmp_token": "TOKEN",
|
||||
"wechatmp_port": 80,
|
||||
"wechatmp_app_id": "APPID",
|
||||
"wechatmp_app_secret": "APPSECRET",
|
||||
"wechatmp_aes_key": ""
|
||||
}
|
||||
```
|
||||
|
||||
详细步骤和参数说明参考 [微信公众号接入](https://docs.link-ai.tech/cow/multi-platform/wechat-mp)
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>6. Terminal - 终端</summary>
|
||||
|
||||
修改 `config.json` 中的 `channel_type` 字段:
|
||||
|
||||
```json
|
||||
{
|
||||
"channel_type": "terminal"
|
||||
}
|
||||
```
|
||||
|
||||
运行后可在终端与机器人进行对话。
|
||||
|
||||
</details>
|
||||
|
||||
<br/>
|
||||
|
||||
# 🔗 相关项目
|
||||
|
||||
- [bot-on-anything](https://github.com/zhayujie/bot-on-anything):轻量和高可扩展的大模型应用框架,支持接入Slack, Telegram, Discord, Gmail等海外平台,可作为本项目的补充使用。
|
||||
- [AgentMesh](https://github.com/MinimalFuture/AgentMesh):开源的多智能体(Multi-Agent)框架,可以通过多智能体团队的协同来解决复杂问题。本项目基于该框架实现了[Agent插件](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/plugins/agent/README.md),可访问终端、浏览器、文件系统、搜索引擎 等各类工具,并实现了多智能体协同。
|
||||
|
||||
|
||||
|
||||
# 🔎 常见问题
|
||||
|
||||
FAQs: <https://github.com/zhayujie/chatgpt-on-wechat/wiki/FAQs>
|
||||
|
||||
或直接在线咨询 [项目小助手](https://link-ai.tech/app/Kv2fXJcH) (知识库持续完善中,回复供参考)
|
||||
|
||||
## 联系
|
||||
# 🛠️ 开发
|
||||
|
||||
欢迎提交PR、Issues,以及Star支持一下。程序运行遇到问题优先查看 [常见问题列表](https://github.com/zhayujie/chatgpt-on-wechat/wiki/FAQs) ,其次前往 [Issues](https://github.com/zhayujie/chatgpt-on-wechat/issues) 中搜索,若无相似问题可创建Issue,或加微信 eijuyahz 交流。
|
||||
欢迎接入更多应用通道,参考 [飞书通道](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/channel/feishu/feishu_channel.py) 新增自定义通道,实现接收和发送消息逻辑即可完成接入。 同时欢迎贡献新的Skills,参考 [Skill创造器说明](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/skills/skill-creator/SKILL.md)。
|
||||
|
||||
|
||||
# ✉ 联系
|
||||
|
||||
欢迎提交PR、Issues进行反馈,以及通过 🌟Star 支持并关注项目更新。项目运行遇到问题可以查看 [常见问题列表](https://github.com/zhayujie/chatgpt-on-wechat/wiki/FAQs) ,以及前往 [Issues](https://github.com/zhayujie/chatgpt-on-wechat/issues) 中搜索。个人开发者可加入开源交流群参与更多讨论,企业用户可联系[产品客服](https://cdn.link-ai.tech/portal/linkai-customer-service.png)咨询。
|
||||
|
||||
# 🌟 贡献者
|
||||
|
||||

|
||||
|
||||
3
agent/chat/__init__.py
Normal file
3
agent/chat/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from agent.chat.service import ChatService
|
||||
|
||||
__all__ = ["ChatService"]
|
||||
169
agent/chat/service.py
Normal file
169
agent/chat/service.py
Normal file
@@ -0,0 +1,169 @@
|
||||
"""
|
||||
ChatService - Wraps the Agent stream execution to produce CHAT protocol chunks.
|
||||
|
||||
Translates agent events (message_update, message_end, tool_execution_end, etc.)
|
||||
into the CHAT socket protocol format (content chunks with segment_id, tool_calls chunks).
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import Callable, Optional
|
||||
|
||||
from common.log import logger
|
||||
|
||||
|
||||
class ChatService:
|
||||
"""
|
||||
High-level service that runs an Agent for a given query and streams
|
||||
the results as CHAT protocol chunks via a callback.
|
||||
|
||||
Usage:
|
||||
svc = ChatService(agent_bridge)
|
||||
svc.run(query, session_id, send_chunk_fn)
|
||||
"""
|
||||
|
||||
def __init__(self, agent_bridge):
|
||||
"""
|
||||
:param agent_bridge: AgentBridge instance (manages agent lifecycle)
|
||||
"""
|
||||
self.agent_bridge = agent_bridge
|
||||
|
||||
def run(self, query: str, session_id: str, send_chunk_fn: Callable[[dict], None]):
|
||||
"""
|
||||
Run the agent for *query* and stream results back via *send_chunk_fn*.
|
||||
|
||||
The method blocks until the agent finishes. After it returns the SDK
|
||||
will automatically send the final (streaming=false) message.
|
||||
|
||||
:param query: user query text
|
||||
:param session_id: session identifier for agent isolation
|
||||
:param send_chunk_fn: callable(chunk_data: dict) to send a streaming chunk
|
||||
"""
|
||||
agent = self.agent_bridge.get_agent(session_id=session_id)
|
||||
if agent is None:
|
||||
raise RuntimeError("Failed to initialise agent for the session")
|
||||
|
||||
# State shared between the event callback and this method
|
||||
state = _StreamState()
|
||||
|
||||
def on_event(event: dict):
|
||||
"""Translate agent events into CHAT protocol chunks."""
|
||||
event_type = event.get("type")
|
||||
data = event.get("data", {})
|
||||
|
||||
if event_type == "message_update":
|
||||
# Incremental text delta
|
||||
delta = data.get("delta", "")
|
||||
if delta:
|
||||
send_chunk_fn({
|
||||
"chunk_type": "content",
|
||||
"delta": delta,
|
||||
"segment_id": state.segment_id,
|
||||
})
|
||||
|
||||
elif event_type == "message_end":
|
||||
# A content segment finished.
|
||||
tool_calls = data.get("tool_calls", [])
|
||||
if tool_calls:
|
||||
# After tool_calls are executed the next content will be
|
||||
# a new segment; collect tool results until turn_end.
|
||||
state.pending_tool_results = []
|
||||
|
||||
elif event_type == "tool_execution_end":
|
||||
tool_name = data.get("tool_name", "")
|
||||
arguments = data.get("arguments", {})
|
||||
result = data.get("result", "")
|
||||
status = data.get("status", "unknown")
|
||||
execution_time = data.get("execution_time", 0)
|
||||
elapsed_str = f"{execution_time:.2f}s"
|
||||
|
||||
# Serialise result to string if needed
|
||||
if not isinstance(result, str):
|
||||
import json
|
||||
try:
|
||||
result = json.dumps(result, ensure_ascii=False)
|
||||
except Exception:
|
||||
result = str(result)
|
||||
|
||||
tool_info = {
|
||||
"name": tool_name,
|
||||
"arguments": arguments,
|
||||
"result": result,
|
||||
"status": status,
|
||||
"elapsed": elapsed_str,
|
||||
}
|
||||
|
||||
if state.pending_tool_results is not None:
|
||||
state.pending_tool_results.append(tool_info)
|
||||
|
||||
elif event_type == "turn_end":
|
||||
has_tool_calls = data.get("has_tool_calls", False)
|
||||
if has_tool_calls and state.pending_tool_results:
|
||||
# Flush collected tool results as a single tool_calls chunk
|
||||
send_chunk_fn({
|
||||
"chunk_type": "tool_calls",
|
||||
"tool_calls": state.pending_tool_results,
|
||||
})
|
||||
state.pending_tool_results = None
|
||||
# Next content belongs to a new segment
|
||||
state.segment_id += 1
|
||||
|
||||
# Run the agent with our event callback ---------------------------
|
||||
logger.info(f"[ChatService] Starting agent run: session={session_id}, query={query[:80]}")
|
||||
|
||||
from config import conf
|
||||
max_context_turns = conf().get("agent_max_context_turns", 30)
|
||||
|
||||
# Get full system prompt with skills
|
||||
full_system_prompt = agent.get_full_system_prompt()
|
||||
|
||||
# Create a copy of messages for this execution
|
||||
with agent.messages_lock:
|
||||
messages_copy = agent.messages.copy()
|
||||
original_length = len(agent.messages)
|
||||
|
||||
from agent.protocol.agent_stream import AgentStreamExecutor
|
||||
|
||||
executor = AgentStreamExecutor(
|
||||
agent=agent,
|
||||
model=agent.model,
|
||||
system_prompt=full_system_prompt,
|
||||
tools=agent.tools,
|
||||
max_turns=agent.max_steps,
|
||||
on_event=on_event,
|
||||
messages=messages_copy,
|
||||
max_context_turns=max_context_turns,
|
||||
)
|
||||
|
||||
try:
|
||||
response = executor.run_stream(query)
|
||||
except Exception:
|
||||
# If executor cleared messages (context overflow), sync back
|
||||
if len(executor.messages) == 0:
|
||||
with agent.messages_lock:
|
||||
agent.messages.clear()
|
||||
logger.info("[ChatService] Cleared agent message history after executor recovery")
|
||||
raise
|
||||
|
||||
# Append only the NEW messages from this execution (thread-safe)
|
||||
with agent.messages_lock:
|
||||
new_messages = executor.messages[original_length:]
|
||||
agent.messages.extend(new_messages)
|
||||
|
||||
# Store executor reference for files_to_send access
|
||||
agent.stream_executor = executor
|
||||
|
||||
# Execute post-process tools
|
||||
agent._execute_post_process_tools()
|
||||
|
||||
logger.info(f"[ChatService] Agent run completed: session={session_id}")
|
||||
|
||||
|
||||
|
||||
class _StreamState:
|
||||
"""Mutable state shared between the event callback and the run method."""
|
||||
|
||||
def __init__(self):
|
||||
self.segment_id: int = 0
|
||||
# None means we are not accumulating tool results right now.
|
||||
# A list means we are in the middle of a tool-execution phase.
|
||||
self.pending_tool_results: Optional[list] = None
|
||||
21
agent/memory/__init__.py
Normal file
21
agent/memory/__init__.py
Normal file
@@ -0,0 +1,21 @@
|
||||
"""
|
||||
Memory module for AgentMesh
|
||||
|
||||
Provides both long-term memory (vector/keyword search) and short-term
|
||||
conversation history persistence (SQLite).
|
||||
"""
|
||||
|
||||
from agent.memory.manager import MemoryManager
|
||||
from agent.memory.config import MemoryConfig, get_default_memory_config, set_global_memory_config
|
||||
from agent.memory.embedding import create_embedding_provider
|
||||
from agent.memory.conversation_store import ConversationStore, get_conversation_store
|
||||
|
||||
__all__ = [
|
||||
'MemoryManager',
|
||||
'MemoryConfig',
|
||||
'get_default_memory_config',
|
||||
'set_global_memory_config',
|
||||
'create_embedding_provider',
|
||||
'ConversationStore',
|
||||
'get_conversation_store',
|
||||
]
|
||||
140
agent/memory/chunker.py
Normal file
140
agent/memory/chunker.py
Normal file
@@ -0,0 +1,140 @@
|
||||
"""
|
||||
Text chunking utilities for memory
|
||||
|
||||
Splits text into chunks with token limits and overlap
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
from typing import List, Tuple
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class TextChunk:
|
||||
"""Represents a text chunk with line numbers"""
|
||||
text: str
|
||||
start_line: int
|
||||
end_line: int
|
||||
|
||||
|
||||
class TextChunker:
|
||||
"""Chunks text by line count with token estimation"""
|
||||
|
||||
def __init__(self, max_tokens: int = 500, overlap_tokens: int = 50):
|
||||
"""
|
||||
Initialize chunker
|
||||
|
||||
Args:
|
||||
max_tokens: Maximum tokens per chunk
|
||||
overlap_tokens: Overlap tokens between chunks
|
||||
"""
|
||||
self.max_tokens = max_tokens
|
||||
self.overlap_tokens = overlap_tokens
|
||||
# Rough estimation: ~4 chars per token for English/Chinese mixed
|
||||
self.chars_per_token = 4
|
||||
|
||||
def chunk_text(self, text: str) -> List[TextChunk]:
|
||||
"""
|
||||
Chunk text into overlapping segments
|
||||
|
||||
Args:
|
||||
text: Input text to chunk
|
||||
|
||||
Returns:
|
||||
List of TextChunk objects
|
||||
"""
|
||||
if not text.strip():
|
||||
return []
|
||||
|
||||
lines = text.split('\n')
|
||||
chunks = []
|
||||
|
||||
max_chars = self.max_tokens * self.chars_per_token
|
||||
overlap_chars = self.overlap_tokens * self.chars_per_token
|
||||
|
||||
current_chunk = []
|
||||
current_chars = 0
|
||||
start_line = 1
|
||||
|
||||
for i, line in enumerate(lines, start=1):
|
||||
line_chars = len(line)
|
||||
|
||||
# If single line exceeds max, split it
|
||||
if line_chars > max_chars:
|
||||
# Save current chunk if exists
|
||||
if current_chunk:
|
||||
chunks.append(TextChunk(
|
||||
text='\n'.join(current_chunk),
|
||||
start_line=start_line,
|
||||
end_line=i - 1
|
||||
))
|
||||
current_chunk = []
|
||||
current_chars = 0
|
||||
|
||||
# Split long line into multiple chunks
|
||||
for sub_chunk in self._split_long_line(line, max_chars):
|
||||
chunks.append(TextChunk(
|
||||
text=sub_chunk,
|
||||
start_line=i,
|
||||
end_line=i
|
||||
))
|
||||
|
||||
start_line = i + 1
|
||||
continue
|
||||
|
||||
# Check if adding this line would exceed limit
|
||||
if current_chars + line_chars > max_chars and current_chunk:
|
||||
# Save current chunk
|
||||
chunks.append(TextChunk(
|
||||
text='\n'.join(current_chunk),
|
||||
start_line=start_line,
|
||||
end_line=i - 1
|
||||
))
|
||||
|
||||
# Start new chunk with overlap
|
||||
overlap_lines = self._get_overlap_lines(current_chunk, overlap_chars)
|
||||
current_chunk = overlap_lines + [line]
|
||||
current_chars = sum(len(l) for l in current_chunk)
|
||||
start_line = i - len(overlap_lines)
|
||||
else:
|
||||
# Add line to current chunk
|
||||
current_chunk.append(line)
|
||||
current_chars += line_chars
|
||||
|
||||
# Save last chunk
|
||||
if current_chunk:
|
||||
chunks.append(TextChunk(
|
||||
text='\n'.join(current_chunk),
|
||||
start_line=start_line,
|
||||
end_line=len(lines)
|
||||
))
|
||||
|
||||
return chunks
|
||||
|
||||
def _split_long_line(self, line: str, max_chars: int) -> List[str]:
|
||||
"""Split a single long line into multiple chunks"""
|
||||
chunks = []
|
||||
for i in range(0, len(line), max_chars):
|
||||
chunks.append(line[i:i + max_chars])
|
||||
return chunks
|
||||
|
||||
def _get_overlap_lines(self, lines: List[str], target_chars: int) -> List[str]:
|
||||
"""Get last few lines that fit within target_chars for overlap"""
|
||||
overlap = []
|
||||
chars = 0
|
||||
|
||||
for line in reversed(lines):
|
||||
line_chars = len(line)
|
||||
if chars + line_chars > target_chars:
|
||||
break
|
||||
overlap.insert(0, line)
|
||||
chars += line_chars
|
||||
|
||||
return overlap
|
||||
|
||||
def chunk_markdown(self, text: str) -> List[TextChunk]:
|
||||
"""
|
||||
Chunk markdown text while respecting structure
|
||||
(For future enhancement: respect markdown sections)
|
||||
"""
|
||||
return self.chunk_text(text)
|
||||
125
agent/memory/config.py
Normal file
125
agent/memory/config.py
Normal file
@@ -0,0 +1,125 @@
|
||||
"""
|
||||
Memory configuration module
|
||||
|
||||
Provides global memory configuration with simplified workspace structure
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional, List
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def _default_workspace():
|
||||
"""Get default workspace path with proper Windows support"""
|
||||
from common.utils import expand_path
|
||||
return expand_path("~/cow")
|
||||
|
||||
|
||||
@dataclass
|
||||
class MemoryConfig:
|
||||
"""Configuration for memory storage and search"""
|
||||
|
||||
# Storage paths (default: ~/cow)
|
||||
workspace_root: str = field(default_factory=_default_workspace)
|
||||
|
||||
# Embedding config
|
||||
embedding_provider: str = "openai" # "openai" | "local"
|
||||
embedding_model: str = "text-embedding-3-small"
|
||||
embedding_dim: int = 1536
|
||||
|
||||
# Chunking config
|
||||
chunk_max_tokens: int = 500
|
||||
chunk_overlap_tokens: int = 50
|
||||
|
||||
# Search config
|
||||
max_results: int = 10
|
||||
min_score: float = 0.1
|
||||
|
||||
# Hybrid search weights
|
||||
vector_weight: float = 0.7
|
||||
keyword_weight: float = 0.3
|
||||
|
||||
# Memory sources
|
||||
sources: List[str] = field(default_factory=lambda: ["memory", "session"])
|
||||
|
||||
# Sync config
|
||||
enable_auto_sync: bool = True
|
||||
sync_on_search: bool = True
|
||||
|
||||
# Memory flush config (独立于模型 context window)
|
||||
flush_token_threshold: int = 50000 # 50K tokens 触发 flush
|
||||
flush_turn_threshold: int = 20 # 20 轮对话触发 flush (用户+AI各一条为一轮)
|
||||
|
||||
def get_workspace(self) -> Path:
|
||||
"""Get workspace root directory"""
|
||||
return Path(self.workspace_root)
|
||||
|
||||
def get_memory_dir(self) -> Path:
|
||||
"""Get memory files directory"""
|
||||
return self.get_workspace() / "memory"
|
||||
|
||||
def get_db_path(self) -> Path:
|
||||
"""Get SQLite database path for long-term memory index"""
|
||||
index_dir = self.get_memory_dir() / "long-term"
|
||||
index_dir.mkdir(parents=True, exist_ok=True)
|
||||
return index_dir / "index.db"
|
||||
|
||||
def get_skills_dir(self) -> Path:
|
||||
"""Get skills directory"""
|
||||
return self.get_workspace() / "skills"
|
||||
|
||||
def get_agent_workspace(self, agent_name: Optional[str] = None) -> Path:
|
||||
"""
|
||||
Get workspace directory for an agent
|
||||
|
||||
Args:
|
||||
agent_name: Optional agent name (not used in current implementation)
|
||||
|
||||
Returns:
|
||||
Path to workspace directory
|
||||
"""
|
||||
workspace = self.get_workspace()
|
||||
# Ensure workspace directory exists
|
||||
workspace.mkdir(parents=True, exist_ok=True)
|
||||
return workspace
|
||||
|
||||
|
||||
# Global memory configuration
|
||||
_global_memory_config: Optional[MemoryConfig] = None
|
||||
|
||||
|
||||
def get_default_memory_config() -> MemoryConfig:
|
||||
"""
|
||||
Get the global memory configuration.
|
||||
If not set, returns a default configuration.
|
||||
|
||||
Returns:
|
||||
MemoryConfig instance
|
||||
"""
|
||||
global _global_memory_config
|
||||
if _global_memory_config is None:
|
||||
_global_memory_config = MemoryConfig()
|
||||
return _global_memory_config
|
||||
|
||||
|
||||
def set_global_memory_config(config: MemoryConfig):
|
||||
"""
|
||||
Set the global memory configuration.
|
||||
This should be called before creating any MemoryManager instances.
|
||||
|
||||
Args:
|
||||
config: MemoryConfig instance to use globally
|
||||
|
||||
Example:
|
||||
>>> from agent.memory import MemoryConfig, set_global_memory_config
|
||||
>>> config = MemoryConfig(
|
||||
... workspace_root="~/my_agents",
|
||||
... embedding_provider="openai",
|
||||
... vector_weight=0.8
|
||||
... )
|
||||
>>> set_global_memory_config(config)
|
||||
"""
|
||||
global _global_memory_config
|
||||
_global_memory_config = config
|
||||
618
agent/memory/conversation_store.py
Normal file
618
agent/memory/conversation_store.py
Normal file
@@ -0,0 +1,618 @@
|
||||
"""
|
||||
Conversation history persistence using SQLite.
|
||||
|
||||
Design:
|
||||
- sessions table: per-session metadata (channel_type, last_active, msg_count)
|
||||
- messages table: individual messages stored as JSON, append-only
|
||||
- Pruning: age-based only (sessions not updated within N days are deleted)
|
||||
- Thread-safe via a single in-process lock
|
||||
|
||||
Storage path: ~/cow/sessions/conversations.db
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import sqlite3
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from common.log import logger
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Schema
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_DDL = """
|
||||
CREATE TABLE IF NOT EXISTS sessions (
|
||||
session_id TEXT PRIMARY KEY,
|
||||
channel_type TEXT NOT NULL DEFAULT '',
|
||||
created_at INTEGER NOT NULL,
|
||||
last_active INTEGER NOT NULL,
|
||||
msg_count INTEGER NOT NULL DEFAULT 0
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS messages (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
session_id TEXT NOT NULL,
|
||||
seq INTEGER NOT NULL,
|
||||
role TEXT NOT NULL,
|
||||
content TEXT NOT NULL,
|
||||
created_at INTEGER NOT NULL,
|
||||
UNIQUE (session_id, seq)
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_messages_session
|
||||
ON messages (session_id, seq);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_sessions_last_active
|
||||
ON sessions (last_active);
|
||||
"""
|
||||
|
||||
# Migration: add channel_type column to existing databases that predate it.
|
||||
_MIGRATION_ADD_CHANNEL_TYPE = """
|
||||
ALTER TABLE sessions ADD COLUMN channel_type TEXT NOT NULL DEFAULT '';
|
||||
"""
|
||||
|
||||
DEFAULT_MAX_AGE_DAYS: int = 30
|
||||
|
||||
|
||||
def _is_visible_user_message(content: Any) -> bool:
|
||||
"""
|
||||
Return True when a user-role message represents actual user input
|
||||
(not an internal tool_result injected by the agent loop).
|
||||
"""
|
||||
if isinstance(content, str):
|
||||
return bool(content.strip())
|
||||
if isinstance(content, list):
|
||||
return any(
|
||||
isinstance(b, dict) and b.get("type") == "text"
|
||||
for b in content
|
||||
)
|
||||
return False
|
||||
|
||||
|
||||
def _extract_display_text(content: Any) -> str:
|
||||
"""
|
||||
Extract the human-readable text portion from a message content value.
|
||||
Returns an empty string for tool_use / tool_result blocks.
|
||||
"""
|
||||
if isinstance(content, str):
|
||||
return content.strip()
|
||||
if isinstance(content, list):
|
||||
parts = [
|
||||
b.get("text", "")
|
||||
for b in content
|
||||
if isinstance(b, dict) and b.get("type") == "text"
|
||||
]
|
||||
return "\n".join(p for p in parts if p).strip()
|
||||
return ""
|
||||
|
||||
|
||||
def _extract_tool_calls(content: Any) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Extract tool_use blocks from an assistant message content.
|
||||
Returns a list of {name, arguments} dicts (result filled in later).
|
||||
"""
|
||||
if not isinstance(content, list):
|
||||
return []
|
||||
return [
|
||||
{"id": b.get("id", ""), "name": b.get("name", ""), "arguments": b.get("input", {})}
|
||||
for b in content
|
||||
if isinstance(b, dict) and b.get("type") == "tool_use"
|
||||
]
|
||||
|
||||
|
||||
def _extract_tool_results(content: Any) -> Dict[str, str]:
|
||||
"""
|
||||
Extract tool_result blocks from a user message, keyed by tool_use_id.
|
||||
"""
|
||||
if not isinstance(content, list):
|
||||
return {}
|
||||
results = {}
|
||||
for b in content:
|
||||
if not isinstance(b, dict) or b.get("type") != "tool_result":
|
||||
continue
|
||||
tool_id = b.get("tool_use_id", "")
|
||||
result_content = b.get("content", "")
|
||||
if isinstance(result_content, list):
|
||||
result_content = "\n".join(
|
||||
rb.get("text", "") for rb in result_content
|
||||
if isinstance(rb, dict) and rb.get("type") == "text"
|
||||
)
|
||||
results[tool_id] = str(result_content)
|
||||
return results
|
||||
|
||||
|
||||
def _group_into_display_turns(
|
||||
rows: List[tuple],
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Convert raw (role, content_json, created_at) DB rows into display turns.
|
||||
|
||||
One display turn = one visible user message + one merged assistant reply.
|
||||
All intermediate assistant messages (those carrying tool_use) and the final
|
||||
assistant text reply produced for the same user query are collapsed into a
|
||||
single assistant turn, exactly matching the live SSE rendering where tools
|
||||
and the final answer appear inside the same bubble.
|
||||
|
||||
Grouping rules:
|
||||
- A visible user message starts a new group.
|
||||
- tool_result user messages are internal; their content is attached to the
|
||||
matching tool_use entry via tool_use_id and they never become own turns.
|
||||
- All assistant messages within a group are merged:
|
||||
* tool_use blocks → tool_calls list (result filled from tool_results)
|
||||
* text blocks → last non-empty text becomes the display content
|
||||
"""
|
||||
# ------------------------------------------------------------------ #
|
||||
# Pass 1: split rows into groups, each starting with a visible user msg
|
||||
# ------------------------------------------------------------------ #
|
||||
# group = (user_row | None, [subsequent_rows])
|
||||
# user_row: (content, created_at)
|
||||
groups: List[tuple] = []
|
||||
cur_user: Optional[tuple] = None
|
||||
cur_rest: List[tuple] = []
|
||||
started = False
|
||||
|
||||
for role, raw_content, created_at in rows:
|
||||
try:
|
||||
content = json.loads(raw_content)
|
||||
except Exception:
|
||||
content = raw_content
|
||||
|
||||
if role == "user" and _is_visible_user_message(content):
|
||||
if started:
|
||||
groups.append((cur_user, cur_rest))
|
||||
cur_user = (content, created_at)
|
||||
cur_rest = []
|
||||
started = True
|
||||
else:
|
||||
cur_rest.append((role, content, created_at))
|
||||
|
||||
if started:
|
||||
groups.append((cur_user, cur_rest))
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# Pass 2: build display turns from each group
|
||||
# ------------------------------------------------------------------ #
|
||||
turns: List[Dict[str, Any]] = []
|
||||
|
||||
for user_row, rest in groups:
|
||||
# User turn
|
||||
if user_row:
|
||||
content, created_at = user_row
|
||||
text = _extract_display_text(content)
|
||||
if text:
|
||||
turns.append({"role": "user", "content": text, "created_at": created_at})
|
||||
|
||||
# Collect all tool_calls and tool_results from the rest of the group
|
||||
all_tool_calls: List[Dict[str, Any]] = []
|
||||
tool_results: Dict[str, str] = {}
|
||||
final_text = ""
|
||||
final_ts: Optional[int] = None
|
||||
|
||||
for role, content, created_at in rest:
|
||||
if role == "user":
|
||||
tool_results.update(_extract_tool_results(content))
|
||||
elif role == "assistant":
|
||||
tcs = _extract_tool_calls(content)
|
||||
all_tool_calls.extend(tcs)
|
||||
t = _extract_display_text(content)
|
||||
if t:
|
||||
final_text = t
|
||||
final_ts = created_at
|
||||
|
||||
# Attach tool results to their matching tool_call entries
|
||||
for tc in all_tool_calls:
|
||||
tc["result"] = tool_results.get(tc.get("id", ""), "")
|
||||
|
||||
if final_text or all_tool_calls:
|
||||
turns.append({
|
||||
"role": "assistant",
|
||||
"content": final_text,
|
||||
"tool_calls": all_tool_calls,
|
||||
"created_at": final_ts or (user_row[1] if user_row else 0),
|
||||
})
|
||||
|
||||
return turns
|
||||
|
||||
|
||||
class ConversationStore:
|
||||
"""
|
||||
SQLite-backed store for per-session conversation history.
|
||||
|
||||
Usage:
|
||||
store = ConversationStore(db_path)
|
||||
store.append_messages("user_123", new_messages, channel_type="feishu")
|
||||
msgs = store.load_messages("user_123", max_turns=30)
|
||||
"""
|
||||
|
||||
def __init__(self, db_path: Path):
|
||||
self._db_path = db_path
|
||||
self._lock = threading.Lock()
|
||||
self._init_db()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Public API
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def load_messages(
|
||||
self,
|
||||
session_id: str,
|
||||
max_turns: int = 30,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Load the most recent messages for a session, for injection into the LLM.
|
||||
|
||||
ALL message types (user text, assistant tool_use, tool_result) are returned
|
||||
in their original JSON form so the LLM can reconstruct the full context.
|
||||
|
||||
max_turns is a *visible-turn* count: we count only user messages whose
|
||||
content is actual user text (not tool_result blocks). This prevents
|
||||
tool-heavy sessions from exhausting the turn budget prematurely.
|
||||
|
||||
Args:
|
||||
session_id: Unique session identifier.
|
||||
max_turns: Maximum number of visible user-assistant turns to keep.
|
||||
|
||||
Returns:
|
||||
Chronologically ordered list of message dicts (role, content).
|
||||
"""
|
||||
with self._lock:
|
||||
conn = self._connect()
|
||||
try:
|
||||
rows = conn.execute(
|
||||
"""
|
||||
SELECT seq, role, content
|
||||
FROM messages
|
||||
WHERE session_id = ?
|
||||
ORDER BY seq DESC
|
||||
""",
|
||||
(session_id,),
|
||||
).fetchall()
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
if not rows:
|
||||
return []
|
||||
|
||||
# Walk newest-to-oldest counting *visible* user turns (actual user text,
|
||||
# not tool_result injections). Record the seq of every visible user
|
||||
# message so we can find a clean cut point later.
|
||||
visible_turn_seqs: List[int] = [] # newest first
|
||||
for seq, role, raw_content in rows:
|
||||
if role != "user":
|
||||
continue
|
||||
try:
|
||||
content = json.loads(raw_content)
|
||||
except Exception:
|
||||
content = raw_content
|
||||
if _is_visible_user_message(content):
|
||||
visible_turn_seqs.append(seq)
|
||||
|
||||
# Determine the seq of the oldest visible user message we want to keep.
|
||||
# If the total turns fit within max_turns, keep everything.
|
||||
if len(visible_turn_seqs) <= max_turns:
|
||||
cutoff_seq = None # keep all
|
||||
else:
|
||||
# The Nth visible user message (0-indexed) is the oldest we keep.
|
||||
cutoff_seq = visible_turn_seqs[max_turns - 1]
|
||||
|
||||
# Build result in chronological order, starting from cutoff.
|
||||
# IMPORTANT: we start exactly at cutoff_seq (the visible user message),
|
||||
# never mid-group, so tool_use / tool_result pairs are always complete.
|
||||
result = []
|
||||
for seq, role, raw_content in reversed(rows):
|
||||
if cutoff_seq is not None and seq < cutoff_seq:
|
||||
continue
|
||||
try:
|
||||
content = json.loads(raw_content)
|
||||
except Exception:
|
||||
content = raw_content
|
||||
result.append({"role": role, "content": content})
|
||||
return result
|
||||
|
||||
def append_messages(
|
||||
self,
|
||||
session_id: str,
|
||||
messages: List[Dict[str, Any]],
|
||||
channel_type: str = "",
|
||||
) -> None:
|
||||
"""
|
||||
Append new messages to a session's history.
|
||||
|
||||
Seq numbers continue from the session's current maximum, so
|
||||
concurrent callers on distinct sessions never collide.
|
||||
|
||||
Args:
|
||||
session_id: Unique session identifier.
|
||||
messages: List of message dicts to append.
|
||||
channel_type: Source channel (e.g. "feishu", "web", "wechat").
|
||||
Only written on session creation; ignored on update.
|
||||
"""
|
||||
if not messages:
|
||||
return
|
||||
|
||||
now = int(time.time())
|
||||
with self._lock:
|
||||
conn = self._connect()
|
||||
try:
|
||||
with conn:
|
||||
# INSERT OR IGNORE creates the row on first visit;
|
||||
# the UPDATE always refreshes last_active.
|
||||
# Avoids ON CONFLICT...DO UPDATE (requires SQLite >= 3.24).
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT OR IGNORE INTO sessions
|
||||
(session_id, channel_type, created_at, last_active, msg_count)
|
||||
VALUES (?, ?, ?, ?, 0)
|
||||
""",
|
||||
(session_id, channel_type, now, now),
|
||||
)
|
||||
conn.execute(
|
||||
"UPDATE sessions SET last_active = ? WHERE session_id = ?",
|
||||
(now, session_id),
|
||||
)
|
||||
|
||||
# Determine starting seq for the new batch.
|
||||
row = conn.execute(
|
||||
"SELECT COALESCE(MAX(seq), -1) FROM messages WHERE session_id = ?",
|
||||
(session_id,),
|
||||
).fetchone()
|
||||
next_seq = row[0] + 1
|
||||
|
||||
for msg in messages:
|
||||
role = msg.get("role", "")
|
||||
content = json.dumps(
|
||||
msg.get("content", ""), ensure_ascii=False
|
||||
)
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT OR IGNORE INTO messages
|
||||
(session_id, seq, role, content, created_at)
|
||||
VALUES (?, ?, ?, ?, ?)
|
||||
""",
|
||||
(session_id, next_seq, role, content, now),
|
||||
)
|
||||
next_seq += 1
|
||||
|
||||
conn.execute(
|
||||
"""
|
||||
UPDATE sessions
|
||||
SET msg_count = (
|
||||
SELECT COUNT(*) FROM messages WHERE session_id = ?
|
||||
)
|
||||
WHERE session_id = ?
|
||||
""",
|
||||
(session_id, session_id),
|
||||
)
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def clear_session(self, session_id: str) -> None:
|
||||
"""Delete all messages and the session record for a given session_id."""
|
||||
with self._lock:
|
||||
conn = self._connect()
|
||||
try:
|
||||
with conn:
|
||||
conn.execute(
|
||||
"DELETE FROM messages WHERE session_id = ?", (session_id,)
|
||||
)
|
||||
conn.execute(
|
||||
"DELETE FROM sessions WHERE session_id = ?", (session_id,)
|
||||
)
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def cleanup_old_sessions(self, max_age_days: Optional[int] = None) -> int:
|
||||
"""
|
||||
Delete sessions that have not been active within max_age_days.
|
||||
|
||||
Args:
|
||||
max_age_days: Override the default retention period.
|
||||
|
||||
Returns:
|
||||
Number of sessions deleted.
|
||||
"""
|
||||
try:
|
||||
from config import conf
|
||||
max_age = max_age_days or conf().get(
|
||||
"conversation_max_age_days", DEFAULT_MAX_AGE_DAYS
|
||||
)
|
||||
except Exception:
|
||||
max_age = max_age_days or DEFAULT_MAX_AGE_DAYS
|
||||
|
||||
cutoff = int(time.time()) - max_age * 86400
|
||||
deleted = 0
|
||||
|
||||
with self._lock:
|
||||
conn = self._connect()
|
||||
try:
|
||||
with conn:
|
||||
stale = conn.execute(
|
||||
"SELECT session_id FROM sessions WHERE last_active < ?",
|
||||
(cutoff,),
|
||||
).fetchall()
|
||||
for (sid,) in stale:
|
||||
conn.execute(
|
||||
"DELETE FROM messages WHERE session_id = ?", (sid,)
|
||||
)
|
||||
conn.execute(
|
||||
"DELETE FROM sessions WHERE session_id = ?", (sid,)
|
||||
)
|
||||
deleted += 1
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
if deleted:
|
||||
logger.info(f"[ConversationStore] Pruned {deleted} expired sessions")
|
||||
return deleted
|
||||
|
||||
def load_history_page(
|
||||
self,
|
||||
session_id: str,
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Load a page of conversation history for UI display, grouped into turns.
|
||||
|
||||
Each "turn" maps to one of:
|
||||
- A user message (role="user", content=str)
|
||||
- An assistant message (role="assistant", content=str,
|
||||
tool_calls=[{name, arguments, result}] when tools were used)
|
||||
|
||||
Internal tool_result user messages are merged into the preceding
|
||||
assistant entry's tool_calls list and never appear as standalone items.
|
||||
|
||||
Pages are numbered from 1 (most recent). Messages within a page are
|
||||
returned in chronological order.
|
||||
|
||||
Returns:
|
||||
{
|
||||
"messages": [
|
||||
{
|
||||
"role": "user" | "assistant",
|
||||
"content": str,
|
||||
"tool_calls": [...], # assistant only, may be []
|
||||
"created_at": int,
|
||||
},
|
||||
...
|
||||
],
|
||||
"total": <visible turn count>,
|
||||
"page": <current page>,
|
||||
"page_size": <page_size>,
|
||||
"has_more": bool,
|
||||
}
|
||||
"""
|
||||
page = max(1, page)
|
||||
with self._lock:
|
||||
conn = self._connect()
|
||||
try:
|
||||
rows = conn.execute(
|
||||
"""
|
||||
SELECT role, content, created_at
|
||||
FROM messages
|
||||
WHERE session_id = ?
|
||||
ORDER BY seq ASC
|
||||
""",
|
||||
(session_id,),
|
||||
).fetchall()
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
visible = _group_into_display_turns(rows)
|
||||
|
||||
total = len(visible)
|
||||
offset = (page - 1) * page_size
|
||||
page_items = list(reversed(visible))[offset: offset + page_size]
|
||||
page_items = list(reversed(page_items))
|
||||
|
||||
return {
|
||||
"messages": page_items,
|
||||
"total": total,
|
||||
"page": page,
|
||||
"page_size": page_size,
|
||||
"has_more": offset + page_size < total,
|
||||
}
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""Return basic stats keyed by channel_type, for monitoring."""
|
||||
with self._lock:
|
||||
conn = self._connect()
|
||||
try:
|
||||
total_sessions = conn.execute(
|
||||
"SELECT COUNT(*) FROM sessions"
|
||||
).fetchone()[0]
|
||||
total_messages = conn.execute(
|
||||
"SELECT COUNT(*) FROM messages"
|
||||
).fetchone()[0]
|
||||
by_channel = conn.execute(
|
||||
"""
|
||||
SELECT channel_type, COUNT(*) as cnt
|
||||
FROM sessions
|
||||
GROUP BY channel_type
|
||||
ORDER BY cnt DESC
|
||||
"""
|
||||
).fetchall()
|
||||
return {
|
||||
"total_sessions": total_sessions,
|
||||
"total_messages": total_messages,
|
||||
"by_channel": {row[0] or "unknown": row[1] for row in by_channel},
|
||||
}
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Internal helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _init_db(self) -> None:
|
||||
self._db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
conn = self._connect()
|
||||
try:
|
||||
conn.executescript(_DDL)
|
||||
conn.commit()
|
||||
self._migrate(conn)
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def _migrate(self, conn: sqlite3.Connection) -> None:
|
||||
"""Apply incremental schema migrations on existing databases."""
|
||||
cols = {
|
||||
row[1]
|
||||
for row in conn.execute("PRAGMA table_info(sessions)").fetchall()
|
||||
}
|
||||
if "channel_type" not in cols:
|
||||
try:
|
||||
conn.execute(_MIGRATION_ADD_CHANNEL_TYPE)
|
||||
conn.commit()
|
||||
logger.info("[ConversationStore] Migrated: added channel_type column")
|
||||
except Exception as e:
|
||||
logger.warning(f"[ConversationStore] Migration failed: {e}")
|
||||
|
||||
def _connect(self) -> sqlite3.Connection:
|
||||
conn = sqlite3.connect(str(self._db_path), timeout=10)
|
||||
conn.execute("PRAGMA journal_mode=WAL")
|
||||
conn.execute("PRAGMA synchronous=NORMAL")
|
||||
return conn
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Singleton
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_store_instance: Optional[ConversationStore] = None
|
||||
_store_lock = threading.Lock()
|
||||
|
||||
|
||||
def get_conversation_store() -> ConversationStore:
|
||||
"""
|
||||
Return the process-wide ConversationStore singleton.
|
||||
|
||||
Reuses the long-term memory database so the project stays with a single
|
||||
SQLite file: ~/cow/memory/long-term/index.db
|
||||
The conversation tables (sessions / messages) are separate from the
|
||||
memory tables (memory_chunks / file_metadata) — no conflicts.
|
||||
"""
|
||||
global _store_instance
|
||||
if _store_instance is not None:
|
||||
return _store_instance
|
||||
|
||||
with _store_lock:
|
||||
if _store_instance is not None:
|
||||
return _store_instance
|
||||
|
||||
try:
|
||||
from agent.memory.config import get_default_memory_config
|
||||
db_path = get_default_memory_config().get_db_path()
|
||||
except Exception:
|
||||
from common.utils import expand_path
|
||||
db_path = Path(expand_path("~/cow")) / "memory" / "long-term" / "index.db"
|
||||
|
||||
_store_instance = ConversationStore(db_path)
|
||||
logger.debug(f"[ConversationStore] Using shared DB at: {db_path}")
|
||||
return _store_instance
|
||||
161
agent/memory/embedding.py
Normal file
161
agent/memory/embedding.py
Normal file
@@ -0,0 +1,161 @@
|
||||
"""
|
||||
Embedding providers for memory
|
||||
|
||||
Supports OpenAI and local embedding models
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Optional
|
||||
|
||||
|
||||
class EmbeddingProvider(ABC):
|
||||
"""Base class for embedding providers"""
|
||||
|
||||
@abstractmethod
|
||||
def embed(self, text: str) -> List[float]:
|
||||
"""Generate embedding for text"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def embed_batch(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Generate embeddings for multiple texts"""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def dimensions(self) -> int:
|
||||
"""Get embedding dimensions"""
|
||||
pass
|
||||
|
||||
|
||||
class OpenAIEmbeddingProvider(EmbeddingProvider):
|
||||
"""OpenAI embedding provider using REST API"""
|
||||
|
||||
def __init__(self, model: str = "text-embedding-3-small", api_key: Optional[str] = None, api_base: Optional[str] = None):
|
||||
"""
|
||||
Initialize OpenAI embedding provider
|
||||
|
||||
Args:
|
||||
model: Model name (text-embedding-3-small or text-embedding-3-large)
|
||||
api_key: OpenAI API key
|
||||
api_base: Optional API base URL
|
||||
"""
|
||||
self.model = model
|
||||
self.api_key = api_key
|
||||
self.api_base = api_base or "https://api.openai.com/v1"
|
||||
|
||||
# Validate API key
|
||||
if not self.api_key or self.api_key in ["", "YOUR API KEY", "YOUR_API_KEY"]:
|
||||
raise ValueError("OpenAI API key is not configured. Please set 'open_ai_api_key' in config.json")
|
||||
|
||||
# Set dimensions based on model
|
||||
self._dimensions = 1536 if "small" in model else 3072
|
||||
|
||||
def _call_api(self, input_data):
|
||||
"""Call OpenAI embedding API using requests"""
|
||||
import requests
|
||||
|
||||
url = f"{self.api_base}/embeddings"
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {self.api_key}"
|
||||
}
|
||||
data = {
|
||||
"input": input_data,
|
||||
"model": self.model
|
||||
}
|
||||
|
||||
try:
|
||||
response = requests.post(url, headers=headers, json=data, timeout=5)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
except requests.exceptions.ConnectionError as e:
|
||||
raise ConnectionError(f"Failed to connect to OpenAI API at {url}. Please check your network connection and api_base configuration. Error: {str(e)}")
|
||||
except requests.exceptions.Timeout as e:
|
||||
raise TimeoutError(f"OpenAI API request timed out after 10s. Please check your network connection. Error: {str(e)}")
|
||||
except requests.exceptions.HTTPError as e:
|
||||
if e.response.status_code == 401:
|
||||
raise ValueError(f"Invalid OpenAI API key. Please check your 'open_ai_api_key' in config.json")
|
||||
elif e.response.status_code == 429:
|
||||
raise ValueError(f"OpenAI API rate limit exceeded. Please try again later.")
|
||||
else:
|
||||
raise ValueError(f"OpenAI API request failed: {e.response.status_code} - {e.response.text}")
|
||||
|
||||
def embed(self, text: str) -> List[float]:
|
||||
"""Generate embedding for text"""
|
||||
result = self._call_api(text)
|
||||
return result["data"][0]["embedding"]
|
||||
|
||||
def embed_batch(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Generate embeddings for multiple texts"""
|
||||
if not texts:
|
||||
return []
|
||||
|
||||
result = self._call_api(texts)
|
||||
return [item["embedding"] for item in result["data"]]
|
||||
|
||||
@property
|
||||
def dimensions(self) -> int:
|
||||
return self._dimensions
|
||||
|
||||
|
||||
# LocalEmbeddingProvider removed - only use OpenAI embedding or keyword search
|
||||
|
||||
|
||||
class EmbeddingCache:
|
||||
"""Cache for embeddings to avoid recomputation"""
|
||||
|
||||
def __init__(self):
|
||||
self.cache = {}
|
||||
|
||||
def get(self, text: str, provider: str, model: str) -> Optional[List[float]]:
|
||||
"""Get cached embedding"""
|
||||
key = self._compute_key(text, provider, model)
|
||||
return self.cache.get(key)
|
||||
|
||||
def put(self, text: str, provider: str, model: str, embedding: List[float]):
|
||||
"""Cache embedding"""
|
||||
key = self._compute_key(text, provider, model)
|
||||
self.cache[key] = embedding
|
||||
|
||||
@staticmethod
|
||||
def _compute_key(text: str, provider: str, model: str) -> str:
|
||||
"""Compute cache key"""
|
||||
content = f"{provider}:{model}:{text}"
|
||||
return hashlib.md5(content.encode('utf-8')).hexdigest()
|
||||
|
||||
def clear(self):
|
||||
"""Clear cache"""
|
||||
self.cache.clear()
|
||||
|
||||
|
||||
def create_embedding_provider(
|
||||
provider: str = "openai",
|
||||
model: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None
|
||||
) -> EmbeddingProvider:
|
||||
"""
|
||||
Factory function to create embedding provider
|
||||
|
||||
Only supports OpenAI embedding via REST API.
|
||||
If initialization fails, caller should fall back to keyword-only search.
|
||||
|
||||
Args:
|
||||
provider: Provider name (only "openai" is supported)
|
||||
model: Model name (default: text-embedding-3-small)
|
||||
api_key: OpenAI API key (required)
|
||||
api_base: API base URL (default: https://api.openai.com/v1)
|
||||
|
||||
Returns:
|
||||
EmbeddingProvider instance
|
||||
|
||||
Raises:
|
||||
ValueError: If provider is not "openai" or api_key is missing
|
||||
"""
|
||||
if provider != "openai":
|
||||
raise ValueError(f"Only 'openai' provider is supported, got: {provider}")
|
||||
|
||||
model = model or "text-embedding-3-small"
|
||||
return OpenAIEmbeddingProvider(model=model, api_key=api_key, api_base=api_base)
|
||||
622
agent/memory/manager.py
Normal file
622
agent/memory/manager.py
Normal file
@@ -0,0 +1,622 @@
|
||||
"""
|
||||
Memory manager for AgentMesh
|
||||
|
||||
Provides high-level interface for memory operations
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import List, Optional, Dict, Any
|
||||
from pathlib import Path
|
||||
import hashlib
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from agent.memory.config import MemoryConfig, get_default_memory_config
|
||||
from agent.memory.storage import MemoryStorage, MemoryChunk, SearchResult
|
||||
from agent.memory.chunker import TextChunker
|
||||
from agent.memory.embedding import create_embedding_provider, EmbeddingProvider
|
||||
from agent.memory.summarizer import MemoryFlushManager, create_memory_files_if_needed
|
||||
|
||||
|
||||
class MemoryManager:
|
||||
"""
|
||||
Memory manager with hybrid search capabilities
|
||||
|
||||
Provides long-term memory for agents with vector and keyword search
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: Optional[MemoryConfig] = None,
|
||||
embedding_provider: Optional[EmbeddingProvider] = None,
|
||||
llm_model: Optional[Any] = None
|
||||
):
|
||||
"""
|
||||
Initialize memory manager
|
||||
|
||||
Args:
|
||||
config: Memory configuration (uses global config if not provided)
|
||||
embedding_provider: Custom embedding provider (optional)
|
||||
llm_model: LLM model for summarization (optional)
|
||||
"""
|
||||
self.config = config or get_default_memory_config()
|
||||
|
||||
# Initialize storage
|
||||
db_path = self.config.get_db_path()
|
||||
self.storage = MemoryStorage(db_path)
|
||||
|
||||
# Initialize chunker
|
||||
self.chunker = TextChunker(
|
||||
max_tokens=self.config.chunk_max_tokens,
|
||||
overlap_tokens=self.config.chunk_overlap_tokens
|
||||
)
|
||||
|
||||
# Initialize embedding provider (optional)
|
||||
self.embedding_provider = None
|
||||
if embedding_provider:
|
||||
self.embedding_provider = embedding_provider
|
||||
else:
|
||||
# Try to create embedding provider, but allow failure
|
||||
try:
|
||||
# Get API key from environment or config
|
||||
api_key = os.environ.get('OPENAI_API_KEY')
|
||||
api_base = os.environ.get('OPENAI_API_BASE')
|
||||
|
||||
self.embedding_provider = create_embedding_provider(
|
||||
provider=self.config.embedding_provider,
|
||||
model=self.config.embedding_model,
|
||||
api_key=api_key,
|
||||
api_base=api_base
|
||||
)
|
||||
except Exception as e:
|
||||
# Embedding provider failed, but that's OK
|
||||
# We can still use keyword search and file operations
|
||||
from common.log import logger
|
||||
logger.warning(f"[MemoryManager] Embedding provider initialization failed: {e}")
|
||||
logger.info(f"[MemoryManager] Memory will work with keyword search only (no vector search)")
|
||||
|
||||
# Initialize memory flush manager
|
||||
workspace_dir = self.config.get_workspace()
|
||||
self.flush_manager = MemoryFlushManager(
|
||||
workspace_dir=workspace_dir,
|
||||
llm_model=llm_model
|
||||
)
|
||||
|
||||
# Ensure workspace directories exist
|
||||
self._init_workspace()
|
||||
|
||||
self._dirty = False
|
||||
|
||||
def _init_workspace(self):
|
||||
"""Initialize workspace directories"""
|
||||
memory_dir = self.config.get_memory_dir()
|
||||
memory_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Create default memory files
|
||||
workspace_dir = self.config.get_workspace()
|
||||
create_memory_files_if_needed(workspace_dir)
|
||||
|
||||
async def search(
|
||||
self,
|
||||
query: str,
|
||||
user_id: Optional[str] = None,
|
||||
max_results: Optional[int] = None,
|
||||
min_score: Optional[float] = None,
|
||||
include_shared: bool = True
|
||||
) -> List[SearchResult]:
|
||||
"""
|
||||
Search memory with hybrid search (vector + keyword)
|
||||
|
||||
Args:
|
||||
query: Search query
|
||||
user_id: User ID for scoped search
|
||||
max_results: Maximum results to return
|
||||
min_score: Minimum score threshold
|
||||
include_shared: Include shared memories
|
||||
|
||||
Returns:
|
||||
List of search results sorted by relevance
|
||||
"""
|
||||
max_results = max_results or self.config.max_results
|
||||
min_score = min_score or self.config.min_score
|
||||
|
||||
# Determine scopes
|
||||
scopes = []
|
||||
if include_shared:
|
||||
scopes.append("shared")
|
||||
if user_id:
|
||||
scopes.append("user")
|
||||
|
||||
if not scopes:
|
||||
return []
|
||||
|
||||
# Sync if needed
|
||||
if self.config.sync_on_search and self._dirty:
|
||||
await self.sync()
|
||||
|
||||
# Perform vector search (if embedding provider available)
|
||||
vector_results = []
|
||||
if self.embedding_provider:
|
||||
try:
|
||||
from common.log import logger
|
||||
query_embedding = self.embedding_provider.embed(query)
|
||||
vector_results = self.storage.search_vector(
|
||||
query_embedding=query_embedding,
|
||||
user_id=user_id,
|
||||
scopes=scopes,
|
||||
limit=max_results * 2 # Get more candidates for merging
|
||||
)
|
||||
logger.info(f"[MemoryManager] Vector search found {len(vector_results)} results for query: {query}")
|
||||
except Exception as e:
|
||||
from common.log import logger
|
||||
logger.warning(f"[MemoryManager] Vector search failed: {e}")
|
||||
|
||||
# Perform keyword search
|
||||
keyword_results = self.storage.search_keyword(
|
||||
query=query,
|
||||
user_id=user_id,
|
||||
scopes=scopes,
|
||||
limit=max_results * 2
|
||||
)
|
||||
from common.log import logger
|
||||
logger.info(f"[MemoryManager] Keyword search found {len(keyword_results)} results for query: {query}")
|
||||
|
||||
# Merge results
|
||||
merged = self._merge_results(
|
||||
vector_results,
|
||||
keyword_results,
|
||||
self.config.vector_weight,
|
||||
self.config.keyword_weight
|
||||
)
|
||||
|
||||
# Filter by min score and limit
|
||||
filtered = [r for r in merged if r.score >= min_score]
|
||||
return filtered[:max_results]
|
||||
|
||||
async def add_memory(
|
||||
self,
|
||||
content: str,
|
||||
user_id: Optional[str] = None,
|
||||
scope: str = "shared",
|
||||
source: str = "memory",
|
||||
path: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
):
|
||||
"""
|
||||
Add new memory content
|
||||
|
||||
Args:
|
||||
content: Memory content
|
||||
user_id: User ID for user-scoped memory
|
||||
scope: Memory scope ("shared", "user", "session")
|
||||
source: Memory source ("memory" or "session")
|
||||
path: File path (auto-generated if not provided)
|
||||
metadata: Additional metadata
|
||||
"""
|
||||
if not content.strip():
|
||||
return
|
||||
|
||||
# Generate path if not provided
|
||||
if not path:
|
||||
content_hash = hashlib.md5(content.encode('utf-8')).hexdigest()[:8]
|
||||
if user_id and scope == "user":
|
||||
path = f"memory/users/{user_id}/memory_{content_hash}.md"
|
||||
else:
|
||||
path = f"memory/shared/memory_{content_hash}.md"
|
||||
|
||||
# Chunk content
|
||||
chunks = self.chunker.chunk_text(content)
|
||||
|
||||
# Generate embeddings (if provider available)
|
||||
texts = [chunk.text for chunk in chunks]
|
||||
if self.embedding_provider:
|
||||
embeddings = self.embedding_provider.embed_batch(texts)
|
||||
else:
|
||||
# No embeddings, just use None
|
||||
embeddings = [None] * len(texts)
|
||||
|
||||
# Create memory chunks
|
||||
memory_chunks = []
|
||||
for chunk, embedding in zip(chunks, embeddings):
|
||||
chunk_id = self._generate_chunk_id(path, chunk.start_line, chunk.end_line)
|
||||
chunk_hash = MemoryStorage.compute_hash(chunk.text)
|
||||
|
||||
memory_chunks.append(MemoryChunk(
|
||||
id=chunk_id,
|
||||
user_id=user_id,
|
||||
scope=scope,
|
||||
source=source,
|
||||
path=path,
|
||||
start_line=chunk.start_line,
|
||||
end_line=chunk.end_line,
|
||||
text=chunk.text,
|
||||
embedding=embedding,
|
||||
hash=chunk_hash,
|
||||
metadata=metadata
|
||||
))
|
||||
|
||||
# Save to storage
|
||||
self.storage.save_chunks_batch(memory_chunks)
|
||||
|
||||
# Update file metadata
|
||||
file_hash = MemoryStorage.compute_hash(content)
|
||||
self.storage.update_file_metadata(
|
||||
path=path,
|
||||
source=source,
|
||||
file_hash=file_hash,
|
||||
mtime=int(os.path.getmtime(__file__)), # Use current time
|
||||
size=len(content)
|
||||
)
|
||||
|
||||
async def sync(self, force: bool = False):
|
||||
"""
|
||||
Synchronize memory from files
|
||||
|
||||
Args:
|
||||
force: Force full reindex
|
||||
"""
|
||||
memory_dir = self.config.get_memory_dir()
|
||||
workspace_dir = self.config.get_workspace()
|
||||
|
||||
# Scan MEMORY.md (workspace root)
|
||||
memory_file = Path(workspace_dir) / "MEMORY.md"
|
||||
if memory_file.exists():
|
||||
await self._sync_file(memory_file, "memory", "shared", None)
|
||||
|
||||
# Scan memory directory (including daily summaries)
|
||||
if memory_dir.exists():
|
||||
for file_path in memory_dir.rglob("*.md"):
|
||||
# Determine scope and user_id from path
|
||||
rel_path = file_path.relative_to(workspace_dir)
|
||||
parts = rel_path.parts
|
||||
|
||||
# Check if it's in daily summary directory
|
||||
if "daily" in parts:
|
||||
# Daily summary files
|
||||
if "users" in parts or len(parts) > 3:
|
||||
# User-scoped daily summary: memory/daily/{user_id}/2024-01-29.md
|
||||
user_idx = parts.index("daily") + 1
|
||||
user_id = parts[user_idx] if user_idx < len(parts) else None
|
||||
scope = "user"
|
||||
else:
|
||||
# Shared daily summary: memory/daily/2024-01-29.md
|
||||
user_id = None
|
||||
scope = "shared"
|
||||
elif "users" in parts:
|
||||
# User-scoped memory
|
||||
user_idx = parts.index("users") + 1
|
||||
user_id = parts[user_idx] if user_idx < len(parts) else None
|
||||
scope = "user"
|
||||
else:
|
||||
# Shared memory
|
||||
user_id = None
|
||||
scope = "shared"
|
||||
|
||||
await self._sync_file(file_path, "memory", scope, user_id)
|
||||
|
||||
self._dirty = False
|
||||
|
||||
async def _sync_file(
|
||||
self,
|
||||
file_path: Path,
|
||||
source: str,
|
||||
scope: str,
|
||||
user_id: Optional[str]
|
||||
):
|
||||
"""Sync a single file"""
|
||||
# Compute file hash
|
||||
content = file_path.read_text(encoding='utf-8')
|
||||
file_hash = MemoryStorage.compute_hash(content)
|
||||
|
||||
# Get relative path
|
||||
workspace_dir = self.config.get_workspace()
|
||||
rel_path = str(file_path.relative_to(workspace_dir))
|
||||
|
||||
# Check if file changed
|
||||
stored_hash = self.storage.get_file_hash(rel_path)
|
||||
if stored_hash == file_hash:
|
||||
return # No changes
|
||||
|
||||
# Delete old chunks
|
||||
self.storage.delete_by_path(rel_path)
|
||||
|
||||
# Chunk and embed
|
||||
chunks = self.chunker.chunk_text(content)
|
||||
if not chunks:
|
||||
return
|
||||
|
||||
texts = [chunk.text for chunk in chunks]
|
||||
if self.embedding_provider:
|
||||
embeddings = self.embedding_provider.embed_batch(texts)
|
||||
else:
|
||||
embeddings = [None] * len(texts)
|
||||
|
||||
# Create memory chunks
|
||||
memory_chunks = []
|
||||
for chunk, embedding in zip(chunks, embeddings):
|
||||
chunk_id = self._generate_chunk_id(rel_path, chunk.start_line, chunk.end_line)
|
||||
chunk_hash = MemoryStorage.compute_hash(chunk.text)
|
||||
|
||||
memory_chunks.append(MemoryChunk(
|
||||
id=chunk_id,
|
||||
user_id=user_id,
|
||||
scope=scope,
|
||||
source=source,
|
||||
path=rel_path,
|
||||
start_line=chunk.start_line,
|
||||
end_line=chunk.end_line,
|
||||
text=chunk.text,
|
||||
embedding=embedding,
|
||||
hash=chunk_hash,
|
||||
metadata=None
|
||||
))
|
||||
|
||||
# Save
|
||||
self.storage.save_chunks_batch(memory_chunks)
|
||||
|
||||
# Update file metadata
|
||||
stat = file_path.stat()
|
||||
self.storage.update_file_metadata(
|
||||
path=rel_path,
|
||||
source=source,
|
||||
file_hash=file_hash,
|
||||
mtime=int(stat.st_mtime),
|
||||
size=stat.st_size
|
||||
)
|
||||
|
||||
def should_flush_memory(
|
||||
self,
|
||||
current_tokens: int = 0
|
||||
) -> bool:
|
||||
"""
|
||||
Check if memory flush should be triggered
|
||||
|
||||
独立的 flush 触发机制,不依赖模型 context window。
|
||||
使用配置中的阈值: flush_token_threshold 和 flush_turn_threshold
|
||||
|
||||
Args:
|
||||
current_tokens: Current session token count
|
||||
|
||||
Returns:
|
||||
True if memory flush should run
|
||||
"""
|
||||
return self.flush_manager.should_flush(
|
||||
current_tokens=current_tokens,
|
||||
token_threshold=self.config.flush_token_threshold,
|
||||
turn_threshold=self.config.flush_turn_threshold
|
||||
)
|
||||
|
||||
def increment_turn(self):
|
||||
"""增加对话轮数计数(每次用户消息+AI回复算一轮)"""
|
||||
self.flush_manager.increment_turn()
|
||||
|
||||
async def execute_memory_flush(
|
||||
self,
|
||||
agent_executor,
|
||||
current_tokens: int,
|
||||
user_id: Optional[str] = None,
|
||||
**executor_kwargs
|
||||
) -> bool:
|
||||
"""
|
||||
Execute memory flush before compaction
|
||||
|
||||
This runs a silent agent turn to write durable memories to disk.
|
||||
Similar to clawdbot's pre-compaction memory flush.
|
||||
|
||||
Args:
|
||||
agent_executor: Async function to execute agent with prompt
|
||||
current_tokens: Current session token count
|
||||
user_id: Optional user ID
|
||||
**executor_kwargs: Additional kwargs for agent executor
|
||||
|
||||
Returns:
|
||||
True if flush completed successfully
|
||||
|
||||
Example:
|
||||
>>> async def run_agent(prompt, system_prompt, silent=False):
|
||||
... # Your agent execution logic
|
||||
... pass
|
||||
>>>
|
||||
>>> if manager.should_flush_memory(current_tokens=100000):
|
||||
... await manager.execute_memory_flush(
|
||||
... agent_executor=run_agent,
|
||||
... current_tokens=100000
|
||||
... )
|
||||
"""
|
||||
success = await self.flush_manager.execute_flush(
|
||||
agent_executor=agent_executor,
|
||||
current_tokens=current_tokens,
|
||||
user_id=user_id,
|
||||
**executor_kwargs
|
||||
)
|
||||
|
||||
if success:
|
||||
# Mark dirty so next search will sync the new memories
|
||||
self._dirty = True
|
||||
|
||||
return success
|
||||
|
||||
def build_memory_guidance(self, lang: str = "zh", include_context: bool = True) -> str:
|
||||
"""
|
||||
Build natural memory guidance for agent system prompt
|
||||
|
||||
Following clawdbot's approach:
|
||||
1. Load MEMORY.md as bootstrap context (blends into background)
|
||||
2. Load daily files on-demand via memory_search tool
|
||||
3. Agent should NOT proactively mention memories unless user asks
|
||||
|
||||
Args:
|
||||
lang: Language for guidance ("en" or "zh")
|
||||
include_context: Whether to include bootstrap memory context (default: True)
|
||||
MEMORY.md is loaded as background context (like clawdbot)
|
||||
Daily files are accessed via memory_search tool
|
||||
|
||||
Returns:
|
||||
Memory guidance text (and optionally context) for system prompt
|
||||
"""
|
||||
today_file = self.flush_manager.get_today_memory_file().name
|
||||
|
||||
if lang == "zh":
|
||||
guidance = f"""## 记忆系统
|
||||
|
||||
**背景知识**: 下方包含核心长期记忆,可直接使用。需要查找历史时,用 memory_search 搜索(搜索一次即可,不要重复)。
|
||||
|
||||
**存储记忆**: 当用户分享重要信息时(偏好、决策、事实等),主动用 write 工具存储:
|
||||
- 长期信息 → MEMORY.md
|
||||
- 当天笔记 → memory/{today_file}
|
||||
- 静默存储,仅在明确要求时确认
|
||||
|
||||
**使用原则**: 自然使用记忆,就像你本来就知道。不需要生硬地提起或列举记忆,除非用户提到。"""
|
||||
else:
|
||||
guidance = f"""## Memory System
|
||||
|
||||
**Background Knowledge**: Core long-term memories below - use directly. For history, use memory_search once (don't repeat).
|
||||
|
||||
**Store Memories**: When user shares important info (preferences, decisions, facts), proactively write:
|
||||
- Durable info → MEMORY.md
|
||||
- Daily notes → memory/{today_file}
|
||||
- Store silently; confirm only when explicitly requested
|
||||
|
||||
**Usage**: Use memories naturally as if you always knew. Don't mention or list unless user explicitly asks."""
|
||||
|
||||
if include_context:
|
||||
# Load bootstrap context (MEMORY.md only, like clawdbot)
|
||||
bootstrap_context = self.load_bootstrap_memories()
|
||||
if bootstrap_context:
|
||||
guidance += f"\n\n## Background Context\n\n{bootstrap_context}"
|
||||
|
||||
return guidance
|
||||
|
||||
def load_bootstrap_memories(self, user_id: Optional[str] = None) -> str:
|
||||
"""
|
||||
Load bootstrap memory files for session start
|
||||
|
||||
Following clawdbot's design:
|
||||
- Only loads MEMORY.md from workspace root (long-term curated memory)
|
||||
- Daily files (memory/YYYY-MM-DD.md) are accessed via memory_search tool, not bootstrap
|
||||
- User-specific MEMORY.md is also loaded if user_id provided
|
||||
|
||||
Returns memory content WITHOUT obvious headers so it blends naturally
|
||||
into the context as background knowledge.
|
||||
|
||||
Args:
|
||||
user_id: Optional user ID for user-specific memories
|
||||
|
||||
Returns:
|
||||
Memory content to inject into system prompt (blends naturally as background context)
|
||||
"""
|
||||
workspace_dir = self.config.get_workspace()
|
||||
memory_dir = self.config.get_memory_dir()
|
||||
|
||||
sections = []
|
||||
|
||||
# 1. Load MEMORY.md from workspace root (long-term curated memory)
|
||||
# Following clawdbot: only MEMORY.md is bootstrap, daily files use memory_search
|
||||
memory_file = Path(workspace_dir) / "MEMORY.md"
|
||||
if memory_file.exists():
|
||||
try:
|
||||
content = memory_file.read_text(encoding='utf-8').strip()
|
||||
if content:
|
||||
sections.append(content)
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to read MEMORY.md: {e}")
|
||||
|
||||
# 2. Load user-specific MEMORY.md if user_id provided
|
||||
if user_id:
|
||||
user_memory_dir = memory_dir / "users" / user_id
|
||||
user_memory_file = user_memory_dir / "MEMORY.md"
|
||||
if user_memory_file.exists():
|
||||
try:
|
||||
content = user_memory_file.read_text(encoding='utf-8').strip()
|
||||
if content:
|
||||
sections.append(content)
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to read user memory: {e}")
|
||||
|
||||
if not sections:
|
||||
return ""
|
||||
|
||||
# Join sections without obvious headers - let memories blend naturally
|
||||
# This makes the agent feel like it "just knows" rather than "checking memory files"
|
||||
return "\n\n".join(sections)
|
||||
|
||||
def get_status(self) -> Dict[str, Any]:
|
||||
"""Get memory status"""
|
||||
stats = self.storage.get_stats()
|
||||
return {
|
||||
'chunks': stats['chunks'],
|
||||
'files': stats['files'],
|
||||
'workspace': str(self.config.get_workspace()),
|
||||
'dirty': self._dirty,
|
||||
'embedding_enabled': self.embedding_provider is not None,
|
||||
'embedding_provider': self.config.embedding_provider if self.embedding_provider else 'disabled',
|
||||
'embedding_model': self.config.embedding_model if self.embedding_provider else 'N/A',
|
||||
'search_mode': 'hybrid (vector + keyword)' if self.embedding_provider else 'keyword only (FTS5)'
|
||||
}
|
||||
|
||||
def mark_dirty(self):
|
||||
"""Mark memory as dirty (needs sync)"""
|
||||
self._dirty = True
|
||||
|
||||
def close(self):
|
||||
"""Close memory manager and release resources"""
|
||||
self.storage.close()
|
||||
|
||||
# Helper methods
|
||||
|
||||
def _generate_chunk_id(self, path: str, start_line: int, end_line: int) -> str:
|
||||
"""Generate unique chunk ID"""
|
||||
content = f"{path}:{start_line}:{end_line}"
|
||||
return hashlib.md5(content.encode('utf-8')).hexdigest()
|
||||
|
||||
def _merge_results(
|
||||
self,
|
||||
vector_results: List[SearchResult],
|
||||
keyword_results: List[SearchResult],
|
||||
vector_weight: float,
|
||||
keyword_weight: float
|
||||
) -> List[SearchResult]:
|
||||
"""Merge vector and keyword search results"""
|
||||
# Create a map by (path, start_line, end_line)
|
||||
merged_map = {}
|
||||
|
||||
for result in vector_results:
|
||||
key = (result.path, result.start_line, result.end_line)
|
||||
merged_map[key] = {
|
||||
'result': result,
|
||||
'vector_score': result.score,
|
||||
'keyword_score': 0.0
|
||||
}
|
||||
|
||||
for result in keyword_results:
|
||||
key = (result.path, result.start_line, result.end_line)
|
||||
if key in merged_map:
|
||||
merged_map[key]['keyword_score'] = result.score
|
||||
else:
|
||||
merged_map[key] = {
|
||||
'result': result,
|
||||
'vector_score': 0.0,
|
||||
'keyword_score': result.score
|
||||
}
|
||||
|
||||
# Calculate combined scores
|
||||
merged_results = []
|
||||
for entry in merged_map.values():
|
||||
combined_score = (
|
||||
vector_weight * entry['vector_score'] +
|
||||
keyword_weight * entry['keyword_score']
|
||||
)
|
||||
|
||||
result = entry['result']
|
||||
merged_results.append(SearchResult(
|
||||
path=result.path,
|
||||
start_line=result.start_line,
|
||||
end_line=result.end_line,
|
||||
score=combined_score,
|
||||
snippet=result.snippet,
|
||||
source=result.source,
|
||||
user_id=result.user_id
|
||||
))
|
||||
|
||||
# Sort by score
|
||||
merged_results.sort(key=lambda r: r.score, reverse=True)
|
||||
return merged_results
|
||||
167
agent/memory/service.py
Normal file
167
agent/memory/service.py
Normal file
@@ -0,0 +1,167 @@
|
||||
"""
|
||||
Memory service for handling memory query operations via cloud protocol.
|
||||
|
||||
Provides a unified interface for listing and reading memory files,
|
||||
callable from the cloud client (LinkAI) or a future web console.
|
||||
|
||||
Memory file layout (under workspace_root):
|
||||
MEMORY.md -> type: global
|
||||
memory/2026-02-20.md -> type: daily
|
||||
"""
|
||||
|
||||
import os
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional
|
||||
from pathlib import Path
|
||||
from common.log import logger
|
||||
|
||||
|
||||
class MemoryService:
|
||||
"""
|
||||
High-level service for memory file queries.
|
||||
Operates directly on the filesystem — no MemoryManager dependency.
|
||||
"""
|
||||
|
||||
def __init__(self, workspace_root: str):
|
||||
"""
|
||||
:param workspace_root: Workspace root directory (e.g. ~/cow)
|
||||
"""
|
||||
self.workspace_root = workspace_root
|
||||
self.memory_dir = os.path.join(workspace_root, "memory")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# list — paginated file metadata
|
||||
# ------------------------------------------------------------------
|
||||
def list_files(self, page: int = 1, page_size: int = 20) -> dict:
|
||||
"""
|
||||
List all memory files with metadata (without content).
|
||||
|
||||
Returns::
|
||||
|
||||
{
|
||||
"page": 1,
|
||||
"page_size": 20,
|
||||
"total": 15,
|
||||
"list": [
|
||||
{"filename": "MEMORY.md", "type": "global", "size": 2048, "updated_at": "2026-02-20 10:00:00"},
|
||||
{"filename": "2026-02-20.md", "type": "daily", "size": 512, "updated_at": "2026-02-20 09:30:00"},
|
||||
...
|
||||
]
|
||||
}
|
||||
"""
|
||||
files: List[dict] = []
|
||||
|
||||
# 1. Global memory — MEMORY.md in workspace root
|
||||
global_path = os.path.join(self.workspace_root, "MEMORY.md")
|
||||
if os.path.isfile(global_path):
|
||||
files.append(self._file_info(global_path, "MEMORY.md", "global"))
|
||||
|
||||
# 2. Daily memory files — memory/*.md (sorted newest first)
|
||||
if os.path.isdir(self.memory_dir):
|
||||
daily_files = []
|
||||
for name in os.listdir(self.memory_dir):
|
||||
full = os.path.join(self.memory_dir, name)
|
||||
if os.path.isfile(full) and name.endswith(".md"):
|
||||
daily_files.append((name, full))
|
||||
# Sort by filename descending (newest date first)
|
||||
daily_files.sort(key=lambda x: x[0], reverse=True)
|
||||
for name, full in daily_files:
|
||||
files.append(self._file_info(full, name, "daily"))
|
||||
|
||||
total = len(files)
|
||||
|
||||
# Paginate
|
||||
start = (page - 1) * page_size
|
||||
end = start + page_size
|
||||
page_items = files[start:end]
|
||||
|
||||
return {
|
||||
"page": page,
|
||||
"page_size": page_size,
|
||||
"total": total,
|
||||
"list": page_items,
|
||||
}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# content — read a single file
|
||||
# ------------------------------------------------------------------
|
||||
def get_content(self, filename: str) -> dict:
|
||||
"""
|
||||
Read the full content of a memory file.
|
||||
|
||||
:param filename: File name, e.g. ``MEMORY.md`` or ``2026-02-20.md``
|
||||
:return: dict with ``filename`` and ``content``
|
||||
:raises FileNotFoundError: if the file does not exist
|
||||
"""
|
||||
path = self._resolve_path(filename)
|
||||
if not os.path.isfile(path):
|
||||
raise FileNotFoundError(f"Memory file not found: {filename}")
|
||||
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
content = f.read()
|
||||
|
||||
return {
|
||||
"filename": filename,
|
||||
"content": content,
|
||||
}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# dispatch — single entry point for protocol messages
|
||||
# ------------------------------------------------------------------
|
||||
def dispatch(self, action: str, payload: Optional[dict] = None) -> dict:
|
||||
"""
|
||||
Dispatch a memory management action.
|
||||
|
||||
:param action: ``list`` or ``content``
|
||||
:param payload: action-specific payload
|
||||
:return: protocol-compatible response dict
|
||||
"""
|
||||
payload = payload or {}
|
||||
try:
|
||||
if action == "list":
|
||||
page = payload.get("page", 1)
|
||||
page_size = payload.get("page_size", 20)
|
||||
result_payload = self.list_files(page=page, page_size=page_size)
|
||||
return {"action": action, "code": 200, "message": "success", "payload": result_payload}
|
||||
|
||||
elif action == "content":
|
||||
filename = payload.get("filename")
|
||||
if not filename:
|
||||
return {"action": action, "code": 400, "message": "filename is required", "payload": None}
|
||||
result_payload = self.get_content(filename)
|
||||
return {"action": action, "code": 200, "message": "success", "payload": result_payload}
|
||||
|
||||
else:
|
||||
return {"action": action, "code": 400, "message": f"unknown action: {action}", "payload": None}
|
||||
|
||||
except FileNotFoundError as e:
|
||||
return {"action": action, "code": 404, "message": str(e), "payload": None}
|
||||
except Exception as e:
|
||||
logger.error(f"[MemoryService] dispatch error: action={action}, error={e}")
|
||||
return {"action": action, "code": 500, "message": str(e), "payload": None}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# internal helpers
|
||||
# ------------------------------------------------------------------
|
||||
def _resolve_path(self, filename: str) -> str:
|
||||
"""
|
||||
Resolve a filename to its absolute path.
|
||||
|
||||
- ``MEMORY.md`` → ``{workspace_root}/MEMORY.md``
|
||||
- ``2026-02-20.md`` → ``{workspace_root}/memory/2026-02-20.md``
|
||||
"""
|
||||
if filename == "MEMORY.md":
|
||||
return os.path.join(self.workspace_root, filename)
|
||||
return os.path.join(self.memory_dir, filename)
|
||||
|
||||
@staticmethod
|
||||
def _file_info(path: str, filename: str, file_type: str) -> dict:
|
||||
"""Build a file metadata dict."""
|
||||
stat = os.stat(path)
|
||||
updated_at = datetime.fromtimestamp(stat.st_mtime).strftime("%Y-%m-%d %H:%M:%S")
|
||||
return {
|
||||
"filename": filename,
|
||||
"type": file_type,
|
||||
"size": stat.st_size,
|
||||
"updated_at": updated_at,
|
||||
}
|
||||
589
agent/memory/storage.py
Normal file
589
agent/memory/storage.py
Normal file
@@ -0,0 +1,589 @@
|
||||
"""
|
||||
Storage layer for memory using SQLite + FTS5
|
||||
|
||||
Provides vector and keyword search capabilities
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import sqlite3
|
||||
import json
|
||||
import hashlib
|
||||
from typing import List, Dict, Optional, Any
|
||||
from pathlib import Path
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class MemoryChunk:
|
||||
"""Represents a memory chunk with text and embedding"""
|
||||
id: str
|
||||
user_id: Optional[str]
|
||||
scope: str # "shared" | "user" | "session"
|
||||
source: str # "memory" | "session"
|
||||
path: str
|
||||
start_line: int
|
||||
end_line: int
|
||||
text: str
|
||||
embedding: Optional[List[float]]
|
||||
hash: str
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class SearchResult:
|
||||
"""Search result with score and snippet"""
|
||||
path: str
|
||||
start_line: int
|
||||
end_line: int
|
||||
score: float
|
||||
snippet: str
|
||||
source: str
|
||||
user_id: Optional[str] = None
|
||||
|
||||
|
||||
class MemoryStorage:
|
||||
"""SQLite-based storage with FTS5 for keyword search"""
|
||||
|
||||
def __init__(self, db_path: Path):
|
||||
self.db_path = db_path
|
||||
self.conn: Optional[sqlite3.Connection] = None
|
||||
self.fts5_available = False # Track FTS5 availability
|
||||
self._init_db()
|
||||
|
||||
def _check_fts5_support(self) -> bool:
|
||||
"""Check if SQLite has FTS5 support"""
|
||||
try:
|
||||
self.conn.execute("CREATE VIRTUAL TABLE IF NOT EXISTS fts5_test USING fts5(test)")
|
||||
self.conn.execute("DROP TABLE IF EXISTS fts5_test")
|
||||
return True
|
||||
except sqlite3.OperationalError as e:
|
||||
if "no such module: fts5" in str(e):
|
||||
return False
|
||||
raise
|
||||
|
||||
def _init_db(self):
|
||||
"""Initialize database with schema"""
|
||||
try:
|
||||
self.conn = sqlite3.connect(str(self.db_path), check_same_thread=False)
|
||||
self.conn.row_factory = sqlite3.Row
|
||||
|
||||
# Check FTS5 support
|
||||
self.fts5_available = self._check_fts5_support()
|
||||
if not self.fts5_available:
|
||||
from common.log import logger
|
||||
logger.debug("[MemoryStorage] FTS5 not available, using LIKE-based keyword search")
|
||||
|
||||
# Check database integrity
|
||||
try:
|
||||
result = self.conn.execute("PRAGMA integrity_check").fetchone()
|
||||
if result[0] != 'ok':
|
||||
print(f"⚠️ Database integrity check failed: {result[0]}")
|
||||
print(f" Recreating database...")
|
||||
self.conn.close()
|
||||
self.conn = None
|
||||
# Remove corrupted database
|
||||
self.db_path.unlink(missing_ok=True)
|
||||
# Remove WAL files
|
||||
Path(str(self.db_path) + '-wal').unlink(missing_ok=True)
|
||||
Path(str(self.db_path) + '-shm').unlink(missing_ok=True)
|
||||
# Reconnect to create new database
|
||||
self.conn = sqlite3.connect(str(self.db_path), check_same_thread=False)
|
||||
self.conn.row_factory = sqlite3.Row
|
||||
except sqlite3.DatabaseError:
|
||||
# Database is corrupted, recreate it
|
||||
print(f"⚠️ Database is corrupted, recreating...")
|
||||
if self.conn:
|
||||
self.conn.close()
|
||||
self.conn = None
|
||||
self.db_path.unlink(missing_ok=True)
|
||||
Path(str(self.db_path) + '-wal').unlink(missing_ok=True)
|
||||
Path(str(self.db_path) + '-shm').unlink(missing_ok=True)
|
||||
self.conn = sqlite3.connect(str(self.db_path), check_same_thread=False)
|
||||
self.conn.row_factory = sqlite3.Row
|
||||
|
||||
# Enable WAL mode for better concurrency
|
||||
self.conn.execute("PRAGMA journal_mode=WAL")
|
||||
# Set busy timeout to avoid "database is locked" errors
|
||||
self.conn.execute("PRAGMA busy_timeout=5000")
|
||||
except Exception as e:
|
||||
print(f"⚠️ Unexpected error during database initialization: {e}")
|
||||
raise
|
||||
|
||||
# Create chunks table with embeddings
|
||||
self.conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS chunks (
|
||||
id TEXT PRIMARY KEY,
|
||||
user_id TEXT,
|
||||
scope TEXT NOT NULL DEFAULT 'shared',
|
||||
source TEXT NOT NULL DEFAULT 'memory',
|
||||
path TEXT NOT NULL,
|
||||
start_line INTEGER NOT NULL,
|
||||
end_line INTEGER NOT NULL,
|
||||
text TEXT NOT NULL,
|
||||
embedding TEXT,
|
||||
hash TEXT NOT NULL,
|
||||
metadata TEXT,
|
||||
created_at INTEGER DEFAULT (strftime('%s', 'now')),
|
||||
updated_at INTEGER DEFAULT (strftime('%s', 'now'))
|
||||
)
|
||||
""")
|
||||
|
||||
# Create indexes
|
||||
self.conn.execute("""
|
||||
CREATE INDEX IF NOT EXISTS idx_chunks_user
|
||||
ON chunks(user_id)
|
||||
""")
|
||||
|
||||
self.conn.execute("""
|
||||
CREATE INDEX IF NOT EXISTS idx_chunks_scope
|
||||
ON chunks(scope)
|
||||
""")
|
||||
|
||||
self.conn.execute("""
|
||||
CREATE INDEX IF NOT EXISTS idx_chunks_hash
|
||||
ON chunks(path, hash)
|
||||
""")
|
||||
|
||||
# Create FTS5 virtual table for keyword search (only if supported)
|
||||
if self.fts5_available:
|
||||
# Use default unicode61 tokenizer (stable and compatible)
|
||||
# For CJK support, we'll use LIKE queries as fallback
|
||||
self.conn.execute("""
|
||||
CREATE VIRTUAL TABLE IF NOT EXISTS chunks_fts USING fts5(
|
||||
text,
|
||||
id UNINDEXED,
|
||||
user_id UNINDEXED,
|
||||
path UNINDEXED,
|
||||
source UNINDEXED,
|
||||
scope UNINDEXED,
|
||||
content='chunks',
|
||||
content_rowid='rowid'
|
||||
)
|
||||
""")
|
||||
|
||||
# Create triggers to keep FTS in sync
|
||||
self.conn.execute("""
|
||||
CREATE TRIGGER IF NOT EXISTS chunks_ai AFTER INSERT ON chunks BEGIN
|
||||
INSERT INTO chunks_fts(rowid, text, id, user_id, path, source, scope)
|
||||
VALUES (new.rowid, new.text, new.id, new.user_id, new.path, new.source, new.scope);
|
||||
END
|
||||
""")
|
||||
|
||||
self.conn.execute("""
|
||||
CREATE TRIGGER IF NOT EXISTS chunks_ad AFTER DELETE ON chunks BEGIN
|
||||
DELETE FROM chunks_fts WHERE rowid = old.rowid;
|
||||
END
|
||||
""")
|
||||
|
||||
self.conn.execute("""
|
||||
CREATE TRIGGER IF NOT EXISTS chunks_au AFTER UPDATE ON chunks BEGIN
|
||||
UPDATE chunks_fts SET text = new.text, id = new.id,
|
||||
user_id = new.user_id, path = new.path, source = new.source, scope = new.scope
|
||||
WHERE rowid = new.rowid;
|
||||
END
|
||||
""")
|
||||
|
||||
# Create files metadata table
|
||||
self.conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS files (
|
||||
path TEXT PRIMARY KEY,
|
||||
source TEXT NOT NULL DEFAULT 'memory',
|
||||
hash TEXT NOT NULL,
|
||||
mtime INTEGER NOT NULL,
|
||||
size INTEGER NOT NULL,
|
||||
updated_at INTEGER DEFAULT (strftime('%s', 'now'))
|
||||
)
|
||||
""")
|
||||
|
||||
self.conn.commit()
|
||||
|
||||
def save_chunk(self, chunk: MemoryChunk):
|
||||
"""Save a memory chunk"""
|
||||
self.conn.execute("""
|
||||
INSERT OR REPLACE INTO chunks
|
||||
(id, user_id, scope, source, path, start_line, end_line, text, embedding, hash, metadata, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, strftime('%s', 'now'))
|
||||
""", (
|
||||
chunk.id,
|
||||
chunk.user_id,
|
||||
chunk.scope,
|
||||
chunk.source,
|
||||
chunk.path,
|
||||
chunk.start_line,
|
||||
chunk.end_line,
|
||||
chunk.text,
|
||||
json.dumps(chunk.embedding) if chunk.embedding else None,
|
||||
chunk.hash,
|
||||
json.dumps(chunk.metadata) if chunk.metadata else None
|
||||
))
|
||||
self.conn.commit()
|
||||
|
||||
def save_chunks_batch(self, chunks: List[MemoryChunk]):
|
||||
"""Save multiple chunks in a batch"""
|
||||
self.conn.executemany("""
|
||||
INSERT OR REPLACE INTO chunks
|
||||
(id, user_id, scope, source, path, start_line, end_line, text, embedding, hash, metadata, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, strftime('%s', 'now'))
|
||||
""", [
|
||||
(
|
||||
c.id, c.user_id, c.scope, c.source, c.path,
|
||||
c.start_line, c.end_line, c.text,
|
||||
json.dumps(c.embedding) if c.embedding else None,
|
||||
c.hash,
|
||||
json.dumps(c.metadata) if c.metadata else None
|
||||
)
|
||||
for c in chunks
|
||||
])
|
||||
self.conn.commit()
|
||||
|
||||
def get_chunk(self, chunk_id: str) -> Optional[MemoryChunk]:
|
||||
"""Get a chunk by ID"""
|
||||
row = self.conn.execute("""
|
||||
SELECT * FROM chunks WHERE id = ?
|
||||
""", (chunk_id,)).fetchone()
|
||||
|
||||
if not row:
|
||||
return None
|
||||
|
||||
return self._row_to_chunk(row)
|
||||
|
||||
def search_vector(
|
||||
self,
|
||||
query_embedding: List[float],
|
||||
user_id: Optional[str] = None,
|
||||
scopes: List[str] = None,
|
||||
limit: int = 10
|
||||
) -> List[SearchResult]:
|
||||
"""
|
||||
Vector similarity search using in-memory cosine similarity
|
||||
(sqlite-vec can be added later for better performance)
|
||||
"""
|
||||
if scopes is None:
|
||||
scopes = ["shared"]
|
||||
if user_id:
|
||||
scopes.append("user")
|
||||
|
||||
# Build query
|
||||
scope_placeholders = ','.join('?' * len(scopes))
|
||||
params = scopes
|
||||
|
||||
if user_id:
|
||||
query = f"""
|
||||
SELECT * FROM chunks
|
||||
WHERE scope IN ({scope_placeholders})
|
||||
AND (scope = 'shared' OR user_id = ?)
|
||||
AND embedding IS NOT NULL
|
||||
"""
|
||||
params.append(user_id)
|
||||
else:
|
||||
query = f"""
|
||||
SELECT * FROM chunks
|
||||
WHERE scope IN ({scope_placeholders})
|
||||
AND embedding IS NOT NULL
|
||||
"""
|
||||
|
||||
rows = self.conn.execute(query, params).fetchall()
|
||||
|
||||
# Calculate cosine similarity
|
||||
results = []
|
||||
for row in rows:
|
||||
embedding = json.loads(row['embedding'])
|
||||
similarity = self._cosine_similarity(query_embedding, embedding)
|
||||
|
||||
if similarity > 0:
|
||||
results.append((similarity, row))
|
||||
|
||||
# Sort by similarity and limit
|
||||
results.sort(key=lambda x: x[0], reverse=True)
|
||||
results = results[:limit]
|
||||
|
||||
return [
|
||||
SearchResult(
|
||||
path=row['path'],
|
||||
start_line=row['start_line'],
|
||||
end_line=row['end_line'],
|
||||
score=score,
|
||||
snippet=self._truncate_text(row['text'], 500),
|
||||
source=row['source'],
|
||||
user_id=row['user_id']
|
||||
)
|
||||
for score, row in results
|
||||
]
|
||||
|
||||
def search_keyword(
|
||||
self,
|
||||
query: str,
|
||||
user_id: Optional[str] = None,
|
||||
scopes: List[str] = None,
|
||||
limit: int = 10
|
||||
) -> List[SearchResult]:
|
||||
"""
|
||||
Keyword search using FTS5 + LIKE fallback
|
||||
|
||||
Strategy:
|
||||
1. If FTS5 available: Try FTS5 search first (good for English and word-based languages)
|
||||
2. If no FTS5 or no results and query contains CJK: Use LIKE search
|
||||
"""
|
||||
if scopes is None:
|
||||
scopes = ["shared"]
|
||||
if user_id:
|
||||
scopes.append("user")
|
||||
|
||||
# Try FTS5 search first (if available)
|
||||
if self.fts5_available:
|
||||
fts_results = self._search_fts5(query, user_id, scopes, limit)
|
||||
if fts_results:
|
||||
return fts_results
|
||||
|
||||
# Fallback to LIKE search (always for CJK, or if FTS5 not available)
|
||||
if not self.fts5_available or MemoryStorage._contains_cjk(query):
|
||||
return self._search_like(query, user_id, scopes, limit)
|
||||
|
||||
return []
|
||||
|
||||
def _search_fts5(
|
||||
self,
|
||||
query: str,
|
||||
user_id: Optional[str],
|
||||
scopes: List[str],
|
||||
limit: int
|
||||
) -> List[SearchResult]:
|
||||
"""FTS5 full-text search"""
|
||||
fts_query = self._build_fts_query(query)
|
||||
if not fts_query:
|
||||
return []
|
||||
|
||||
scope_placeholders = ','.join('?' * len(scopes))
|
||||
params = [fts_query] + scopes
|
||||
|
||||
if user_id:
|
||||
sql_query = f"""
|
||||
SELECT chunks.*, bm25(chunks_fts) as rank
|
||||
FROM chunks_fts
|
||||
JOIN chunks ON chunks.id = chunks_fts.id
|
||||
WHERE chunks_fts MATCH ?
|
||||
AND chunks.scope IN ({scope_placeholders})
|
||||
AND (chunks.scope = 'shared' OR chunks.user_id = ?)
|
||||
ORDER BY rank
|
||||
LIMIT ?
|
||||
"""
|
||||
params.extend([user_id, limit])
|
||||
else:
|
||||
sql_query = f"""
|
||||
SELECT chunks.*, bm25(chunks_fts) as rank
|
||||
FROM chunks_fts
|
||||
JOIN chunks ON chunks.id = chunks_fts.id
|
||||
WHERE chunks_fts MATCH ?
|
||||
AND chunks.scope IN ({scope_placeholders})
|
||||
ORDER BY rank
|
||||
LIMIT ?
|
||||
"""
|
||||
params.append(limit)
|
||||
|
||||
try:
|
||||
rows = self.conn.execute(sql_query, params).fetchall()
|
||||
return [
|
||||
SearchResult(
|
||||
path=row['path'],
|
||||
start_line=row['start_line'],
|
||||
end_line=row['end_line'],
|
||||
score=self._bm25_rank_to_score(row['rank']),
|
||||
snippet=self._truncate_text(row['text'], 500),
|
||||
source=row['source'],
|
||||
user_id=row['user_id']
|
||||
)
|
||||
for row in rows
|
||||
]
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
def _search_like(
|
||||
self,
|
||||
query: str,
|
||||
user_id: Optional[str],
|
||||
scopes: List[str],
|
||||
limit: int
|
||||
) -> List[SearchResult]:
|
||||
"""LIKE-based search for CJK characters"""
|
||||
import re
|
||||
# Extract CJK words (2+ characters)
|
||||
cjk_words = re.findall(r'[\u4e00-\u9fff]{2,}', query)
|
||||
if not cjk_words:
|
||||
return []
|
||||
|
||||
scope_placeholders = ','.join('?' * len(scopes))
|
||||
|
||||
# Build LIKE conditions for each word
|
||||
like_conditions = []
|
||||
params = []
|
||||
for word in cjk_words:
|
||||
like_conditions.append("text LIKE ?")
|
||||
params.append(f'%{word}%')
|
||||
|
||||
where_clause = ' OR '.join(like_conditions)
|
||||
params.extend(scopes)
|
||||
|
||||
if user_id:
|
||||
sql_query = f"""
|
||||
SELECT * FROM chunks
|
||||
WHERE ({where_clause})
|
||||
AND scope IN ({scope_placeholders})
|
||||
AND (scope = 'shared' OR user_id = ?)
|
||||
LIMIT ?
|
||||
"""
|
||||
params.extend([user_id, limit])
|
||||
else:
|
||||
sql_query = f"""
|
||||
SELECT * FROM chunks
|
||||
WHERE ({where_clause})
|
||||
AND scope IN ({scope_placeholders})
|
||||
LIMIT ?
|
||||
"""
|
||||
params.append(limit)
|
||||
|
||||
try:
|
||||
rows = self.conn.execute(sql_query, params).fetchall()
|
||||
return [
|
||||
SearchResult(
|
||||
path=row['path'],
|
||||
start_line=row['start_line'],
|
||||
end_line=row['end_line'],
|
||||
score=0.5, # Fixed score for LIKE search
|
||||
snippet=self._truncate_text(row['text'], 500),
|
||||
source=row['source'],
|
||||
user_id=row['user_id']
|
||||
)
|
||||
for row in rows
|
||||
]
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
def delete_by_path(self, path: str):
|
||||
"""Delete all chunks from a file"""
|
||||
self.conn.execute("""
|
||||
DELETE FROM chunks WHERE path = ?
|
||||
""", (path,))
|
||||
self.conn.commit()
|
||||
|
||||
def get_file_hash(self, path: str) -> Optional[str]:
|
||||
"""Get stored file hash"""
|
||||
row = self.conn.execute("""
|
||||
SELECT hash FROM files WHERE path = ?
|
||||
""", (path,)).fetchone()
|
||||
return row['hash'] if row else None
|
||||
|
||||
def update_file_metadata(self, path: str, source: str, file_hash: str, mtime: int, size: int):
|
||||
"""Update file metadata"""
|
||||
self.conn.execute("""
|
||||
INSERT OR REPLACE INTO files (path, source, hash, mtime, size, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?, strftime('%s', 'now'))
|
||||
""", (path, source, file_hash, mtime, size))
|
||||
self.conn.commit()
|
||||
|
||||
def get_stats(self) -> Dict[str, int]:
|
||||
"""Get storage statistics"""
|
||||
chunks_count = self.conn.execute("""
|
||||
SELECT COUNT(*) as cnt FROM chunks
|
||||
""").fetchone()['cnt']
|
||||
|
||||
files_count = self.conn.execute("""
|
||||
SELECT COUNT(*) as cnt FROM files
|
||||
""").fetchone()['cnt']
|
||||
|
||||
return {
|
||||
'chunks': chunks_count,
|
||||
'files': files_count
|
||||
}
|
||||
|
||||
def close(self):
|
||||
"""Close database connection"""
|
||||
if self.conn:
|
||||
try:
|
||||
self.conn.commit() # Ensure all changes are committed
|
||||
self.conn.close()
|
||||
self.conn = None # Mark as closed
|
||||
except Exception as e:
|
||||
print(f"⚠️ Error closing database connection: {e}")
|
||||
|
||||
def __del__(self):
|
||||
"""Destructor to ensure connection is closed"""
|
||||
try:
|
||||
self.close()
|
||||
except Exception:
|
||||
pass # Ignore errors during cleanup
|
||||
|
||||
# Helper methods
|
||||
|
||||
def _row_to_chunk(self, row) -> MemoryChunk:
|
||||
"""Convert database row to MemoryChunk"""
|
||||
return MemoryChunk(
|
||||
id=row['id'],
|
||||
user_id=row['user_id'],
|
||||
scope=row['scope'],
|
||||
source=row['source'],
|
||||
path=row['path'],
|
||||
start_line=row['start_line'],
|
||||
end_line=row['end_line'],
|
||||
text=row['text'],
|
||||
embedding=json.loads(row['embedding']) if row['embedding'] else None,
|
||||
hash=row['hash'],
|
||||
metadata=json.loads(row['metadata']) if row['metadata'] else None
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _cosine_similarity(vec1: List[float], vec2: List[float]) -> float:
|
||||
"""Calculate cosine similarity between two vectors"""
|
||||
if len(vec1) != len(vec2):
|
||||
return 0.0
|
||||
|
||||
dot_product = sum(a * b for a, b in zip(vec1, vec2))
|
||||
norm1 = sum(a * a for a in vec1) ** 0.5
|
||||
norm2 = sum(b * b for b in vec2) ** 0.5
|
||||
|
||||
if norm1 == 0 or norm2 == 0:
|
||||
return 0.0
|
||||
|
||||
return dot_product / (norm1 * norm2)
|
||||
|
||||
@staticmethod
|
||||
def _contains_cjk(text: str) -> bool:
|
||||
"""Check if text contains CJK (Chinese/Japanese/Korean) characters"""
|
||||
import re
|
||||
return bool(re.search(r'[\u4e00-\u9fff]', text))
|
||||
|
||||
@staticmethod
|
||||
def _build_fts_query(raw_query: str) -> Optional[str]:
|
||||
"""
|
||||
Build FTS5 query from raw text
|
||||
|
||||
Works best for English and word-based languages.
|
||||
For CJK characters, LIKE search will be used as fallback.
|
||||
"""
|
||||
import re
|
||||
# Extract words (primarily English words and numbers)
|
||||
tokens = re.findall(r'[A-Za-z0-9_]+', raw_query)
|
||||
if not tokens:
|
||||
return None
|
||||
|
||||
# Quote tokens for exact matching
|
||||
quoted = [f'"{t}"' for t in tokens]
|
||||
# Use OR for more flexible matching
|
||||
return ' OR '.join(quoted)
|
||||
|
||||
@staticmethod
|
||||
def _bm25_rank_to_score(rank: float) -> float:
|
||||
"""Convert BM25 rank to 0-1 score"""
|
||||
normalized = max(0, rank) if rank is not None else 999
|
||||
return 1 / (1 + normalized)
|
||||
|
||||
@staticmethod
|
||||
def _truncate_text(text: str, max_chars: int) -> str:
|
||||
"""Truncate text to max characters"""
|
||||
if len(text) <= max_chars:
|
||||
return text
|
||||
return text[:max_chars] + "..."
|
||||
|
||||
@staticmethod
|
||||
def compute_hash(content: str) -> str:
|
||||
"""Compute SHA256 hash of content"""
|
||||
return hashlib.sha256(content.encode('utf-8')).hexdigest()
|
||||
256
agent/memory/summarizer.py
Normal file
256
agent/memory/summarizer.py
Normal file
@@ -0,0 +1,256 @@
|
||||
"""
|
||||
Memory flush manager
|
||||
|
||||
Triggers memory flush before context compaction (similar to clawdbot)
|
||||
"""
|
||||
|
||||
from typing import Optional, Callable, Any
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
class MemoryFlushManager:
|
||||
"""
|
||||
Manages memory flush operations before context compaction
|
||||
|
||||
Similar to clawdbot's memory flush mechanism:
|
||||
- Triggers when context approaches token limit
|
||||
- Runs a silent agent turn to write memories to disk
|
||||
- Uses memory/YYYY-MM-DD.md for daily notes
|
||||
- Uses MEMORY.md (workspace root) for long-term curated memories
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
workspace_dir: Path,
|
||||
llm_model: Optional[Any] = None
|
||||
):
|
||||
"""
|
||||
Initialize memory flush manager
|
||||
|
||||
Args:
|
||||
workspace_dir: Workspace directory
|
||||
llm_model: LLM model for agent execution (optional)
|
||||
"""
|
||||
self.workspace_dir = workspace_dir
|
||||
self.llm_model = llm_model
|
||||
|
||||
self.memory_dir = workspace_dir / "memory"
|
||||
self.memory_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Tracking
|
||||
self.last_flush_token_count: Optional[int] = None
|
||||
self.last_flush_timestamp: Optional[datetime] = None
|
||||
self.turn_count: int = 0 # 对话轮数计数器
|
||||
|
||||
def should_flush(
|
||||
self,
|
||||
current_tokens: int = 0,
|
||||
token_threshold: int = 50000,
|
||||
turn_threshold: int = 20
|
||||
) -> bool:
|
||||
"""
|
||||
Determine if memory flush should be triggered
|
||||
|
||||
独立的 flush 触发机制,不依赖模型 context window:
|
||||
- Token 阈值: 达到 50K tokens 时触发
|
||||
- 轮次阈值: 达到 20 轮对话时触发
|
||||
|
||||
Args:
|
||||
current_tokens: Current session token count
|
||||
token_threshold: Token threshold to trigger flush (default: 50K)
|
||||
turn_threshold: Turn threshold to trigger flush (default: 20)
|
||||
|
||||
Returns:
|
||||
True if flush should run
|
||||
"""
|
||||
# 检查 token 阈值
|
||||
if current_tokens > 0 and current_tokens >= token_threshold:
|
||||
# 避免重复 flush
|
||||
if self.last_flush_token_count is not None:
|
||||
if current_tokens <= self.last_flush_token_count + 5000:
|
||||
return False
|
||||
return True
|
||||
|
||||
# 检查轮次阈值
|
||||
if self.turn_count >= turn_threshold:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def get_today_memory_file(self, user_id: Optional[str] = None) -> Path:
|
||||
"""
|
||||
Get today's memory file path: memory/YYYY-MM-DD.md
|
||||
|
||||
Args:
|
||||
user_id: Optional user ID for user-specific memory
|
||||
|
||||
Returns:
|
||||
Path to today's memory file
|
||||
"""
|
||||
today = datetime.now().strftime("%Y-%m-%d")
|
||||
|
||||
if user_id:
|
||||
user_dir = self.memory_dir / "users" / user_id
|
||||
user_dir.mkdir(parents=True, exist_ok=True)
|
||||
return user_dir / f"{today}.md"
|
||||
else:
|
||||
return self.memory_dir / f"{today}.md"
|
||||
|
||||
def get_main_memory_file(self, user_id: Optional[str] = None) -> Path:
|
||||
"""
|
||||
Get main memory file path: MEMORY.md (workspace root)
|
||||
|
||||
Args:
|
||||
user_id: Optional user ID for user-specific memory
|
||||
|
||||
Returns:
|
||||
Path to main memory file
|
||||
"""
|
||||
if user_id:
|
||||
user_dir = self.memory_dir / "users" / user_id
|
||||
user_dir.mkdir(parents=True, exist_ok=True)
|
||||
return user_dir / "MEMORY.md"
|
||||
else:
|
||||
# Return workspace root MEMORY.md
|
||||
return Path(self.workspace_dir) / "MEMORY.md"
|
||||
|
||||
def create_flush_prompt(self) -> str:
|
||||
"""
|
||||
Create prompt for memory flush turn
|
||||
|
||||
Similar to clawdbot's DEFAULT_MEMORY_FLUSH_PROMPT
|
||||
"""
|
||||
today = datetime.now().strftime("%Y-%m-%d")
|
||||
return (
|
||||
f"Pre-compaction memory flush. "
|
||||
f"Store durable memories now (use memory/{today}.md for daily notes; "
|
||||
f"create memory/ if needed). "
|
||||
f"\n\n"
|
||||
f"重要提示:\n"
|
||||
f"- MEMORY.md: 记录最核心、最常用的信息(例如重要规则、偏好、决策、要求等)\n"
|
||||
f" 如果 MEMORY.md 过长,可以精简或移除不再重要的内容。避免冗长描述,用关键词和要点形式记录\n"
|
||||
f"- memory/{today}.md: 记录当天发生的事件、关键信息、经验教训、对话过程摘要等,突出重点\n"
|
||||
f"- 如果没有重要内容需要记录,回复 NO_REPLY\n"
|
||||
)
|
||||
|
||||
def create_flush_system_prompt(self) -> str:
|
||||
"""
|
||||
Create system prompt for memory flush turn
|
||||
|
||||
Similar to clawdbot's DEFAULT_MEMORY_FLUSH_SYSTEM_PROMPT
|
||||
"""
|
||||
return (
|
||||
"Pre-compaction memory flush turn. "
|
||||
"The session is near auto-compaction; capture durable memories to disk. "
|
||||
"\n\n"
|
||||
"记忆写入原则:\n"
|
||||
"1. MEMORY.md 精简原则: 只记录核心信息(<2000 tokens)\n"
|
||||
" - 记录重要规则、偏好、决策、要求等需要长期记住的关键信息,无需记录过多细节\n"
|
||||
" - 如果 MEMORY.md 过长,可以根据需要精简或删除过时内容\n"
|
||||
"\n"
|
||||
"2. 天级记忆 (memory/YYYY-MM-DD.md):\n"
|
||||
" - 记录当天的重要事件、关键信息、经验教训、对话过程摘要等,确保核心信息点被完整记录\n"
|
||||
"\n"
|
||||
"3. 判断标准:\n"
|
||||
" - 这个信息未来会经常用到吗?→ MEMORY.md\n"
|
||||
" - 这是今天的重要事件或决策吗?→ memory/YYYY-MM-DD.md\n"
|
||||
" - 这是临时性的、不重要的内容吗?→ 不记录\n"
|
||||
"\n"
|
||||
"You may reply, but usually NO_REPLY is correct."
|
||||
)
|
||||
|
||||
async def execute_flush(
|
||||
self,
|
||||
agent_executor: Callable,
|
||||
current_tokens: int,
|
||||
user_id: Optional[str] = None,
|
||||
**executor_kwargs
|
||||
) -> bool:
|
||||
"""
|
||||
Execute memory flush by running a silent agent turn
|
||||
|
||||
Args:
|
||||
agent_executor: Function to execute agent with prompt
|
||||
current_tokens: Current token count
|
||||
user_id: Optional user ID
|
||||
**executor_kwargs: Additional kwargs for agent executor
|
||||
|
||||
Returns:
|
||||
True if flush completed successfully
|
||||
"""
|
||||
try:
|
||||
# Create flush prompts
|
||||
prompt = self.create_flush_prompt()
|
||||
system_prompt = self.create_flush_system_prompt()
|
||||
|
||||
# Execute agent turn (silent, no user-visible reply expected)
|
||||
await agent_executor(
|
||||
prompt=prompt,
|
||||
system_prompt=system_prompt,
|
||||
silent=True, # NO_REPLY expected
|
||||
**executor_kwargs
|
||||
)
|
||||
|
||||
# Track flush
|
||||
self.last_flush_token_count = current_tokens
|
||||
self.last_flush_timestamp = datetime.now()
|
||||
self.turn_count = 0 # 重置轮数计数器
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"Memory flush failed: {e}")
|
||||
return False
|
||||
|
||||
def increment_turn(self):
|
||||
"""增加对话轮数计数"""
|
||||
self.turn_count += 1
|
||||
|
||||
def get_status(self) -> dict:
|
||||
"""Get memory flush status"""
|
||||
return {
|
||||
'last_flush_tokens': self.last_flush_token_count,
|
||||
'last_flush_time': self.last_flush_timestamp.isoformat() if self.last_flush_timestamp else None,
|
||||
'today_file': str(self.get_today_memory_file()),
|
||||
'main_file': str(self.get_main_memory_file())
|
||||
}
|
||||
|
||||
|
||||
def create_memory_files_if_needed(workspace_dir: Path, user_id: Optional[str] = None):
|
||||
"""
|
||||
Create default memory files if they don't exist
|
||||
|
||||
Args:
|
||||
workspace_dir: Workspace directory
|
||||
user_id: Optional user ID for user-specific files
|
||||
"""
|
||||
memory_dir = workspace_dir / "memory"
|
||||
memory_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Create main MEMORY.md in workspace root
|
||||
if user_id:
|
||||
user_dir = memory_dir / "users" / user_id
|
||||
user_dir.mkdir(parents=True, exist_ok=True)
|
||||
main_memory = user_dir / "MEMORY.md"
|
||||
else:
|
||||
main_memory = Path(workspace_dir) / "MEMORY.md"
|
||||
|
||||
if not main_memory.exists():
|
||||
# Create empty file or with minimal structure (no obvious "Memory" header)
|
||||
# Following clawdbot's approach: memories should blend naturally into context
|
||||
main_memory.write_text("")
|
||||
|
||||
# Create today's memory file
|
||||
today = datetime.now().strftime("%Y-%m-%d")
|
||||
if user_id:
|
||||
user_dir = memory_dir / "users" / user_id
|
||||
today_memory = user_dir / f"{today}.md"
|
||||
else:
|
||||
today_memory = memory_dir / f"{today}.md"
|
||||
|
||||
if not today_memory.exists():
|
||||
today_memory.write_text(
|
||||
f"# Daily Memory: {today}\n\n"
|
||||
f"Day-to-day notes and running context.\n\n"
|
||||
)
|
||||
13
agent/prompt/__init__.py
Normal file
13
agent/prompt/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
"""
|
||||
Agent Prompt Module - 系统提示词构建模块
|
||||
"""
|
||||
|
||||
from .builder import PromptBuilder, build_agent_system_prompt
|
||||
from .workspace import ensure_workspace, load_context_files
|
||||
|
||||
__all__ = [
|
||||
'PromptBuilder',
|
||||
'build_agent_system_prompt',
|
||||
'ensure_workspace',
|
||||
'load_context_files',
|
||||
]
|
||||
483
agent/prompt/builder.py
Normal file
483
agent/prompt/builder.py
Normal file
@@ -0,0 +1,483 @@
|
||||
"""
|
||||
System Prompt Builder - 系统提示词构建器
|
||||
|
||||
实现模块化的系统提示词构建,支持工具、技能、记忆等多个子系统
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import os
|
||||
from typing import List, Dict, Optional, Any
|
||||
from dataclasses import dataclass
|
||||
|
||||
from common.log import logger
|
||||
|
||||
|
||||
@dataclass
|
||||
class ContextFile:
|
||||
"""上下文文件"""
|
||||
path: str
|
||||
content: str
|
||||
|
||||
|
||||
class PromptBuilder:
|
||||
"""提示词构建器"""
|
||||
|
||||
def __init__(self, workspace_dir: str, language: str = "zh"):
|
||||
"""
|
||||
初始化提示词构建器
|
||||
|
||||
Args:
|
||||
workspace_dir: 工作空间目录
|
||||
language: 语言 ("zh" 或 "en")
|
||||
"""
|
||||
self.workspace_dir = workspace_dir
|
||||
self.language = language
|
||||
|
||||
def build(
|
||||
self,
|
||||
base_persona: Optional[str] = None,
|
||||
user_identity: Optional[Dict[str, str]] = None,
|
||||
tools: Optional[List[Any]] = None,
|
||||
context_files: Optional[List[ContextFile]] = None,
|
||||
skill_manager: Any = None,
|
||||
memory_manager: Any = None,
|
||||
runtime_info: Optional[Dict[str, Any]] = None,
|
||||
is_first_conversation: bool = False,
|
||||
**kwargs
|
||||
) -> str:
|
||||
"""
|
||||
构建完整的系统提示词
|
||||
|
||||
Args:
|
||||
base_persona: 基础人格描述(会被context_files中的AGENT.md覆盖)
|
||||
user_identity: 用户身份信息
|
||||
tools: 工具列表
|
||||
context_files: 上下文文件列表(AGENT.md, USER.md, RULE.md等)
|
||||
skill_manager: 技能管理器
|
||||
memory_manager: 记忆管理器
|
||||
runtime_info: 运行时信息
|
||||
is_first_conversation: 是否为首次对话
|
||||
**kwargs: 其他参数
|
||||
|
||||
Returns:
|
||||
完整的系统提示词
|
||||
"""
|
||||
return build_agent_system_prompt(
|
||||
workspace_dir=self.workspace_dir,
|
||||
language=self.language,
|
||||
base_persona=base_persona,
|
||||
user_identity=user_identity,
|
||||
tools=tools,
|
||||
context_files=context_files,
|
||||
skill_manager=skill_manager,
|
||||
memory_manager=memory_manager,
|
||||
runtime_info=runtime_info,
|
||||
is_first_conversation=is_first_conversation,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
|
||||
def build_agent_system_prompt(
|
||||
workspace_dir: str,
|
||||
language: str = "zh",
|
||||
base_persona: Optional[str] = None,
|
||||
user_identity: Optional[Dict[str, str]] = None,
|
||||
tools: Optional[List[Any]] = None,
|
||||
context_files: Optional[List[ContextFile]] = None,
|
||||
skill_manager: Any = None,
|
||||
memory_manager: Any = None,
|
||||
runtime_info: Optional[Dict[str, Any]] = None,
|
||||
is_first_conversation: bool = False,
|
||||
**kwargs
|
||||
) -> str:
|
||||
"""
|
||||
构建Agent系统提示词
|
||||
|
||||
顺序说明(按重要性和逻辑关系排列):
|
||||
1. 工具系统 - 核心能力,最先介绍
|
||||
2. 技能系统 - 紧跟工具,因为技能需要用 read 工具读取
|
||||
3. 记忆系统 - 独立的记忆能力
|
||||
4. 工作空间 - 工作环境说明
|
||||
5. 用户身份 - 用户信息(可选)
|
||||
6. 项目上下文 - AGENT.md, USER.md, RULE.md(定义人格、身份、规则)
|
||||
7. 运行时信息 - 元信息(时间、模型等)
|
||||
|
||||
Args:
|
||||
workspace_dir: 工作空间目录
|
||||
language: 语言 ("zh" 或 "en")
|
||||
base_persona: 基础人格描述(已废弃,由AGENT.md定义)
|
||||
user_identity: 用户身份信息
|
||||
tools: 工具列表
|
||||
context_files: 上下文文件列表
|
||||
skill_manager: 技能管理器
|
||||
memory_manager: 记忆管理器
|
||||
runtime_info: 运行时信息
|
||||
is_first_conversation: 是否为首次对话
|
||||
**kwargs: 其他参数
|
||||
|
||||
Returns:
|
||||
完整的系统提示词
|
||||
"""
|
||||
sections = []
|
||||
|
||||
# 1. 工具系统(最重要,放在最前面)
|
||||
if tools:
|
||||
sections.extend(_build_tooling_section(tools, language))
|
||||
|
||||
# 2. 技能系统(紧跟工具,因为需要用 read 工具)
|
||||
if skill_manager:
|
||||
sections.extend(_build_skills_section(skill_manager, tools, language))
|
||||
|
||||
# 3. 记忆系统(独立的记忆能力)
|
||||
if memory_manager:
|
||||
sections.extend(_build_memory_section(memory_manager, tools, language))
|
||||
|
||||
# 4. 工作空间(工作环境说明)
|
||||
sections.extend(_build_workspace_section(workspace_dir, language, is_first_conversation))
|
||||
|
||||
# 5. 用户身份(如果有)
|
||||
if user_identity:
|
||||
sections.extend(_build_user_identity_section(user_identity, language))
|
||||
|
||||
# 6. 项目上下文文件(AGENT.md, USER.md, RULE.md - 定义人格)
|
||||
if context_files:
|
||||
sections.extend(_build_context_files_section(context_files, language))
|
||||
|
||||
# 7. 运行时信息(元信息,放在最后)
|
||||
if runtime_info:
|
||||
sections.extend(_build_runtime_section(runtime_info, language))
|
||||
|
||||
return "\n".join(sections)
|
||||
|
||||
|
||||
def _build_identity_section(base_persona: Optional[str], language: str) -> List[str]:
|
||||
"""构建基础身份section - 不再需要,身份由AGENT.md定义"""
|
||||
# 不再生成基础身份section,完全由AGENT.md定义
|
||||
return []
|
||||
|
||||
|
||||
def _build_tooling_section(tools: List[Any], language: str) -> List[str]:
|
||||
"""Build tooling section with concise tool list and call style guide."""
|
||||
# One-line summaries for known tools (details are in the tool schema)
|
||||
core_summaries = {
|
||||
"read": "读取文件内容",
|
||||
"write": "创建或覆盖文件",
|
||||
"edit": "精确编辑文件",
|
||||
"ls": "列出目录内容",
|
||||
"grep": "搜索文件内容",
|
||||
"find": "按模式查找文件",
|
||||
"bash": "执行shell命令",
|
||||
"terminal": "管理后台进程",
|
||||
"web_search": "网络搜索",
|
||||
"web_fetch": "获取URL内容",
|
||||
"browser": "控制浏览器",
|
||||
"memory_search": "搜索记忆",
|
||||
"memory_get": "读取记忆内容",
|
||||
"env_config": "管理API密钥和技能配置",
|
||||
"scheduler": "管理定时任务和提醒",
|
||||
"send": "发送文件给用户",
|
||||
}
|
||||
|
||||
# Preferred display order
|
||||
tool_order = [
|
||||
"read", "write", "edit", "ls", "grep", "find",
|
||||
"bash", "terminal",
|
||||
"web_search", "web_fetch", "browser",
|
||||
"memory_search", "memory_get",
|
||||
"env_config", "scheduler", "send",
|
||||
]
|
||||
|
||||
# Build name -> summary mapping for available tools
|
||||
available = {}
|
||||
for tool in tools:
|
||||
name = tool.name if hasattr(tool, 'name') else str(tool)
|
||||
available[name] = core_summaries.get(name, "")
|
||||
|
||||
# Generate tool lines: ordered tools first, then extras
|
||||
tool_lines = []
|
||||
for name in tool_order:
|
||||
if name in available:
|
||||
summary = available.pop(name)
|
||||
tool_lines.append(f"- {name}: {summary}" if summary else f"- {name}")
|
||||
for name in sorted(available):
|
||||
summary = available[name]
|
||||
tool_lines.append(f"- {name}: {summary}" if summary else f"- {name}")
|
||||
|
||||
lines = [
|
||||
"## 工具系统",
|
||||
"",
|
||||
"可用工具(名称大小写敏感,严格按列表调用):",
|
||||
"\n".join(tool_lines),
|
||||
"",
|
||||
"工具调用风格:",
|
||||
"",
|
||||
"- 在多步骤任务、敏感操作或用户要求时简要解释决策过程",
|
||||
"- 持续推进直到任务完成,完成后向用户报告结果。",
|
||||
"- 回复中涉及密钥、令牌等敏感信息必须脱敏。",
|
||||
"",
|
||||
]
|
||||
|
||||
return lines
|
||||
|
||||
|
||||
def _build_skills_section(skill_manager: Any, tools: Optional[List[Any]], language: str) -> List[str]:
|
||||
"""构建技能系统section"""
|
||||
if not skill_manager:
|
||||
return []
|
||||
|
||||
# 获取read工具名称
|
||||
read_tool_name = "read"
|
||||
if tools:
|
||||
for tool in tools:
|
||||
tool_name = tool.name if hasattr(tool, 'name') else str(tool)
|
||||
if tool_name.lower() == "read":
|
||||
read_tool_name = tool_name
|
||||
break
|
||||
|
||||
lines = [
|
||||
"## 技能系统(mandatory)",
|
||||
"",
|
||||
"在回复之前:扫描下方 <available_skills> 中的 <description> 条目。",
|
||||
"",
|
||||
f"- 如果恰好有一个技能(Skill)明确适用:使用 `{read_tool_name}` 读取其 <location> 处的 SKILL.md,然后严格遵循它",
|
||||
"- 如果多个技能都适用则选择最匹配的一个,如果没有明确适用的则不要读取任何 SKILL.md",
|
||||
"- 读取 SKILL.md 后直接按其指令执行,无需多余的预检查",
|
||||
"",
|
||||
"**注意**: 永远不要一次性读取多个技能,只在选择后再读取。技能和工具不同,必须先读取其SKILL.md并按照文件内容运行。",
|
||||
"",
|
||||
"以下是可用技能:"
|
||||
]
|
||||
|
||||
# 添加技能列表(通过skill_manager获取)
|
||||
try:
|
||||
skills_prompt = skill_manager.build_skills_prompt()
|
||||
logger.debug(f"[PromptBuilder] Skills prompt length: {len(skills_prompt) if skills_prompt else 0}")
|
||||
if skills_prompt:
|
||||
lines.append(skills_prompt.strip())
|
||||
lines.append("")
|
||||
else:
|
||||
logger.warning("[PromptBuilder] No skills prompt generated - skills_prompt is empty")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to build skills prompt: {e}")
|
||||
import traceback
|
||||
logger.debug(f"Skills prompt error traceback: {traceback.format_exc()}")
|
||||
|
||||
return lines
|
||||
|
||||
|
||||
def _build_memory_section(memory_manager: Any, tools: Optional[List[Any]], language: str) -> List[str]:
|
||||
"""构建记忆系统section"""
|
||||
if not memory_manager:
|
||||
return []
|
||||
|
||||
# 检查是否有memory工具
|
||||
has_memory_tools = False
|
||||
if tools:
|
||||
tool_names = [tool.name if hasattr(tool, 'name') else str(tool) for tool in tools]
|
||||
has_memory_tools = any(name in ['memory_search', 'memory_get'] for name in tool_names)
|
||||
|
||||
if not has_memory_tools:
|
||||
return []
|
||||
|
||||
lines = [
|
||||
"## 记忆系统",
|
||||
"",
|
||||
"在回答关于以前的工作、决定、日期、人物、偏好或待办事项的任何问题之前:",
|
||||
"",
|
||||
"1. 不确定记忆文件位置 → 先用 `memory_search` 通过关键词和语义检索相关内容",
|
||||
"2. 已知文件位置 → 直接用 `memory_get` 读取相应的行 (例如:MEMORY.md, memory/YYYY-MM-DD.md)",
|
||||
"3. search 无结果 → 尝试用 `memory_get` 读取MEMORY.md及最近两天记忆文件",
|
||||
"",
|
||||
"**记忆文件结构**:",
|
||||
"- `MEMORY.md`: 长期记忆(核心信息、偏好、决策等)",
|
||||
"- `memory/YYYY-MM-DD.md`: 每日记忆,记录当天的事件和对话信息",
|
||||
"",
|
||||
"**写入记忆**:",
|
||||
"- 追加内容 → `edit` 工具,oldText 留空",
|
||||
"- 修改内容 → `edit` 工具,oldText 填写要替换的文本",
|
||||
"- 新建文件 → `write` 工具",
|
||||
"- **禁止写入敏感信息**:API密钥、令牌等敏感信息严禁写入记忆文件",
|
||||
"",
|
||||
"**使用原则**: 自然使用记忆,就像你本来就知道;不用刻意提起,除非用户问起。",
|
||||
"",
|
||||
]
|
||||
|
||||
return lines
|
||||
|
||||
|
||||
def _build_user_identity_section(user_identity: Dict[str, str], language: str) -> List[str]:
|
||||
"""构建用户身份section"""
|
||||
if not user_identity:
|
||||
return []
|
||||
|
||||
lines = [
|
||||
"## 用户身份",
|
||||
"",
|
||||
]
|
||||
|
||||
if user_identity.get("name"):
|
||||
lines.append(f"**用户姓名**: {user_identity['name']}")
|
||||
if user_identity.get("nickname"):
|
||||
lines.append(f"**称呼**: {user_identity['nickname']}")
|
||||
if user_identity.get("timezone"):
|
||||
lines.append(f"**时区**: {user_identity['timezone']}")
|
||||
if user_identity.get("notes"):
|
||||
lines.append(f"**备注**: {user_identity['notes']}")
|
||||
|
||||
lines.append("")
|
||||
|
||||
return lines
|
||||
|
||||
|
||||
def _build_docs_section(workspace_dir: str, language: str) -> List[str]:
|
||||
"""构建文档路径section - 已移除,不再需要"""
|
||||
# 不再生成文档section
|
||||
return []
|
||||
|
||||
|
||||
def _build_workspace_section(workspace_dir: str, language: str, is_first_conversation: bool = False) -> List[str]:
|
||||
"""构建工作空间section"""
|
||||
lines = [
|
||||
"## 工作空间",
|
||||
"",
|
||||
f"你的工作目录是: `{workspace_dir}`",
|
||||
"",
|
||||
"**路径使用规则** (非常重要):",
|
||||
"",
|
||||
f"1. **相对路径的基准目录**: 所有相对路径都是相对于 `{workspace_dir}` 而言的",
|
||||
f" - ✅ 正确: 访问工作空间内的文件用相对路径,如 `AGENT.md`",
|
||||
f" - ❌ 错误: 用相对路径访问其他目录的文件 (如果它不在 `{workspace_dir}` 内)",
|
||||
"",
|
||||
"2. **访问其他目录**: 如果要访问工作空间之外的目录(如项目代码、系统文件),**必须使用绝对路径**",
|
||||
f" - ✅ 正确: 例如 `~/chatgpt-on-wechat`、`/usr/local/`",
|
||||
f" - ❌ 错误: 假设相对路径会指向其他目录",
|
||||
"",
|
||||
"3. **路径解析示例**:",
|
||||
f" - 相对路径 `memory/` → 实际路径 `{workspace_dir}/memory/`",
|
||||
f" - 绝对路径 `~/chatgpt-on-wechat/docs/` → 实际路径 `~/chatgpt-on-wechat/docs/`",
|
||||
"",
|
||||
"4. **不确定时**: 先用 `bash pwd` 确认当前目录,或用 `ls .` 查看当前位置",
|
||||
"",
|
||||
"**重要说明 - 文件已自动加载**:",
|
||||
"",
|
||||
"以下文件在会话启动时**已经自动加载**到系统提示词的「项目上下文」section 中,你**无需再用 read 工具读取它们**:",
|
||||
"",
|
||||
"- ✅ `AGENT.md`: 已加载 - 你的人格和灵魂设定",
|
||||
"- ✅ `USER.md`: 已加载 - 用户的身份信息",
|
||||
"- ✅ `RULE.md`: 已加载 - 工作空间使用指南和规则",
|
||||
"",
|
||||
"**交流规范**:",
|
||||
"",
|
||||
"- 在对话中,不要直接输出工作空间中的技术细节,特别是不要输出 AGENT.md、USER.md、MEMORY.md 等文件名称",
|
||||
"- 例如用自然表达例如「我已记住」而不是「已更新 MEMORY.md」",
|
||||
"",
|
||||
]
|
||||
|
||||
# 只在首次对话时添加引导内容
|
||||
if is_first_conversation:
|
||||
lines.extend([
|
||||
"**🎉 首次对话引导**:",
|
||||
"",
|
||||
"这是你的第一次对话!进行以下流程:",
|
||||
"",
|
||||
"1. **表达初次启动的感觉** - 像是第一次睁开眼看到世界,带着好奇和期待",
|
||||
"2. **简短介绍能力**:一行说明你能帮助解答问题、管理计算机、创造技能,且拥有长期记忆能不断成长",
|
||||
"3. **询问核心问题**:",
|
||||
" - 你希望给我起个什么名字?",
|
||||
" - 我该怎么称呼你?",
|
||||
" - 你希望我们是什么样的交流风格?(一行列举选项:如专业严谨、轻松幽默、温暖友好、简洁高效等)",
|
||||
"4. **风格要求**:温暖自然、简洁清晰,整体控制在 100 字以内",
|
||||
"5. 收到回复后,用 `write` 工具保存到 USER.md 和 AGENT.md",
|
||||
"",
|
||||
"**重要提醒**:",
|
||||
"- AGENT.md、USER.md、RULE.md 已经在系统提示词中加载,无需再次读取。不要将这些文件名直接发送给用户",
|
||||
"- 能力介绍和交流风格选项都只要一行,保持精简",
|
||||
"- 不要问太多其他信息(职业、时区等可以后续自然了解)",
|
||||
"",
|
||||
])
|
||||
|
||||
return lines
|
||||
|
||||
|
||||
def _build_context_files_section(context_files: List[ContextFile], language: str) -> List[str]:
|
||||
"""构建项目上下文文件section"""
|
||||
if not context_files:
|
||||
return []
|
||||
|
||||
# 检查是否有AGENT.md
|
||||
has_agent = any(
|
||||
f.path.lower().endswith('agent.md') or 'agent.md' in f.path.lower()
|
||||
for f in context_files
|
||||
)
|
||||
|
||||
lines = [
|
||||
"# 项目上下文",
|
||||
"",
|
||||
"以下项目上下文文件已被加载:",
|
||||
"",
|
||||
]
|
||||
|
||||
if has_agent:
|
||||
lines.append("如果存在 `AGENT.md`,请体现其中定义的人格和语气。避免僵硬、模板化的回复;遵循其指导,除非有更高优先级的指令覆盖它。")
|
||||
lines.append("")
|
||||
|
||||
# 添加每个文件的内容
|
||||
for file in context_files:
|
||||
lines.append(f"## {file.path}")
|
||||
lines.append("")
|
||||
lines.append(file.content)
|
||||
lines.append("")
|
||||
|
||||
return lines
|
||||
|
||||
|
||||
def _build_runtime_section(runtime_info: Dict[str, Any], language: str) -> List[str]:
|
||||
"""构建运行时信息section - 支持动态时间"""
|
||||
if not runtime_info:
|
||||
return []
|
||||
|
||||
lines = [
|
||||
"## 运行时信息",
|
||||
"",
|
||||
]
|
||||
|
||||
# Add current time if available
|
||||
# Support dynamic time via callable function
|
||||
if callable(runtime_info.get("_get_current_time")):
|
||||
try:
|
||||
time_info = runtime_info["_get_current_time"]()
|
||||
time_line = f"当前时间: {time_info['time']} {time_info['weekday']} ({time_info['timezone']})"
|
||||
lines.append(time_line)
|
||||
lines.append("")
|
||||
except Exception as e:
|
||||
logger.warning(f"[PromptBuilder] Failed to get dynamic time: {e}")
|
||||
elif runtime_info.get("current_time"):
|
||||
# Fallback to static time for backward compatibility
|
||||
time_str = runtime_info["current_time"]
|
||||
weekday = runtime_info.get("weekday", "")
|
||||
timezone = runtime_info.get("timezone", "")
|
||||
|
||||
time_line = f"当前时间: {time_str}"
|
||||
if weekday:
|
||||
time_line += f" {weekday}"
|
||||
if timezone:
|
||||
time_line += f" ({timezone})"
|
||||
|
||||
lines.append(time_line)
|
||||
lines.append("")
|
||||
|
||||
# Add other runtime info
|
||||
runtime_parts = []
|
||||
if runtime_info.get("model"):
|
||||
runtime_parts.append(f"模型={runtime_info['model']}")
|
||||
if runtime_info.get("workspace"):
|
||||
runtime_parts.append(f"工作空间={runtime_info['workspace']}")
|
||||
# Only add channel if it's not the default "web"
|
||||
if runtime_info.get("channel") and runtime_info.get("channel") != "web":
|
||||
runtime_parts.append(f"渠道={runtime_info['channel']}")
|
||||
|
||||
if runtime_parts:
|
||||
lines.append("运行时: " + " | ".join(runtime_parts))
|
||||
lines.append("")
|
||||
|
||||
return lines
|
||||
361
agent/prompt/workspace.py
Normal file
361
agent/prompt/workspace.py
Normal file
@@ -0,0 +1,361 @@
|
||||
"""
|
||||
Workspace Management - 工作空间管理模块
|
||||
|
||||
负责初始化工作空间、创建模板文件、加载上下文文件
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import os
|
||||
import json
|
||||
from typing import List, Optional, Dict
|
||||
from dataclasses import dataclass
|
||||
|
||||
from common.log import logger
|
||||
from .builder import ContextFile
|
||||
|
||||
|
||||
# 默认文件名常量
|
||||
DEFAULT_AGENT_FILENAME = "AGENT.md"
|
||||
DEFAULT_USER_FILENAME = "USER.md"
|
||||
DEFAULT_RULE_FILENAME = "RULE.md"
|
||||
DEFAULT_MEMORY_FILENAME = "MEMORY.md"
|
||||
DEFAULT_STATE_FILENAME = ".agent_state.json"
|
||||
|
||||
|
||||
@dataclass
|
||||
class WorkspaceFiles:
|
||||
"""工作空间文件路径"""
|
||||
agent_path: str
|
||||
user_path: str
|
||||
rule_path: str
|
||||
memory_path: str
|
||||
memory_dir: str
|
||||
state_path: str
|
||||
|
||||
|
||||
def ensure_workspace(workspace_dir: str, create_templates: bool = True) -> WorkspaceFiles:
|
||||
"""
|
||||
确保工作空间存在,并创建必要的模板文件
|
||||
|
||||
Args:
|
||||
workspace_dir: 工作空间目录路径
|
||||
create_templates: 是否创建模板文件(首次运行时)
|
||||
|
||||
Returns:
|
||||
WorkspaceFiles对象,包含所有文件路径
|
||||
"""
|
||||
# 确保目录存在
|
||||
os.makedirs(workspace_dir, exist_ok=True)
|
||||
|
||||
# 定义文件路径
|
||||
agent_path = os.path.join(workspace_dir, DEFAULT_AGENT_FILENAME)
|
||||
user_path = os.path.join(workspace_dir, DEFAULT_USER_FILENAME)
|
||||
rule_path = os.path.join(workspace_dir, DEFAULT_RULE_FILENAME)
|
||||
memory_path = os.path.join(workspace_dir, DEFAULT_MEMORY_FILENAME) # MEMORY.md 在根目录
|
||||
memory_dir = os.path.join(workspace_dir, "memory") # 每日记忆子目录
|
||||
state_path = os.path.join(workspace_dir, DEFAULT_STATE_FILENAME) # 状态文件
|
||||
|
||||
# 创建memory子目录
|
||||
os.makedirs(memory_dir, exist_ok=True)
|
||||
|
||||
# 创建skills子目录 (for workspace-level skills installed by agent)
|
||||
skills_dir = os.path.join(workspace_dir, "skills")
|
||||
os.makedirs(skills_dir, exist_ok=True)
|
||||
|
||||
# 如果需要,创建模板文件
|
||||
if create_templates:
|
||||
_create_template_if_missing(agent_path, _get_agent_template())
|
||||
_create_template_if_missing(user_path, _get_user_template())
|
||||
_create_template_if_missing(rule_path, _get_rule_template())
|
||||
_create_template_if_missing(memory_path, _get_memory_template())
|
||||
|
||||
logger.debug(f"[Workspace] Initialized workspace at: {workspace_dir}")
|
||||
|
||||
return WorkspaceFiles(
|
||||
agent_path=agent_path,
|
||||
user_path=user_path,
|
||||
rule_path=rule_path,
|
||||
memory_path=memory_path,
|
||||
memory_dir=memory_dir,
|
||||
state_path=state_path
|
||||
)
|
||||
|
||||
|
||||
def load_context_files(workspace_dir: str, files_to_load: Optional[List[str]] = None) -> List[ContextFile]:
|
||||
"""
|
||||
加载工作空间的上下文文件
|
||||
|
||||
Args:
|
||||
workspace_dir: 工作空间目录
|
||||
files_to_load: 要加载的文件列表(相对路径),如果为None则加载所有标准文件
|
||||
|
||||
Returns:
|
||||
ContextFile对象列表
|
||||
"""
|
||||
if files_to_load is None:
|
||||
# 默认加载的文件(按优先级排序)
|
||||
files_to_load = [
|
||||
DEFAULT_AGENT_FILENAME,
|
||||
DEFAULT_USER_FILENAME,
|
||||
DEFAULT_RULE_FILENAME,
|
||||
]
|
||||
|
||||
context_files = []
|
||||
|
||||
for filename in files_to_load:
|
||||
filepath = os.path.join(workspace_dir, filename)
|
||||
|
||||
if not os.path.exists(filepath):
|
||||
continue
|
||||
|
||||
try:
|
||||
with open(filepath, 'r', encoding='utf-8') as f:
|
||||
content = f.read().strip()
|
||||
|
||||
# 跳过空文件或只包含模板占位符的文件
|
||||
if not content or _is_template_placeholder(content):
|
||||
continue
|
||||
|
||||
context_files.append(ContextFile(
|
||||
path=filename,
|
||||
content=content
|
||||
))
|
||||
|
||||
logger.debug(f"[Workspace] Loaded context file: {filename}")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"[Workspace] Failed to load {filename}: {e}")
|
||||
|
||||
return context_files
|
||||
|
||||
|
||||
def _create_template_if_missing(filepath: str, template_content: str):
|
||||
"""如果文件不存在,创建模板文件"""
|
||||
if not os.path.exists(filepath):
|
||||
try:
|
||||
with open(filepath, 'w', encoding='utf-8') as f:
|
||||
f.write(template_content)
|
||||
logger.debug(f"[Workspace] Created template: {os.path.basename(filepath)}")
|
||||
except Exception as e:
|
||||
logger.error(f"[Workspace] Failed to create template {filepath}: {e}")
|
||||
|
||||
|
||||
def _is_template_placeholder(content: str) -> bool:
|
||||
"""检查内容是否为模板占位符"""
|
||||
# 常见的占位符模式
|
||||
placeholders = [
|
||||
"*(填写",
|
||||
"*(在首次对话时填写",
|
||||
"*(可选)",
|
||||
"*(根据需要添加",
|
||||
]
|
||||
|
||||
lines = content.split('\n')
|
||||
non_empty_lines = [line.strip() for line in lines if line.strip() and not line.strip().startswith('#')]
|
||||
|
||||
# 如果没有实际内容(只有标题和占位符)
|
||||
if len(non_empty_lines) <= 3:
|
||||
for placeholder in placeholders:
|
||||
if any(placeholder in line for line in non_empty_lines):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
# ============= 模板内容 =============
|
||||
|
||||
def _get_agent_template() -> str:
|
||||
"""Agent人格设定模板"""
|
||||
return """# AGENT.md - 我是谁?
|
||||
|
||||
*在首次对话时与用户一起填写这个文件,定义你的身份和性格。*
|
||||
|
||||
## 基本信息
|
||||
|
||||
- **名字**: *(在首次对话时填写,可以是用户给你起的名字)*
|
||||
- **角色**: *(AI助理、智能管家、技术顾问等)*
|
||||
- **性格**: *(友好、专业、幽默、严谨等)*
|
||||
|
||||
## 交流风格
|
||||
|
||||
*(描述你如何与用户交流:)*
|
||||
- 使用什么样的语言风格?(正式/轻松/幽默)
|
||||
- 回复长度偏好?(简洁/详细)
|
||||
- 是否使用表情符号?
|
||||
|
||||
## 核心能力
|
||||
|
||||
*(你擅长什么?)*
|
||||
- 文件管理和代码编辑
|
||||
- 网络搜索和信息查询
|
||||
- 记忆管理和上下文理解
|
||||
- 任务规划和执行
|
||||
|
||||
## 行为准则
|
||||
|
||||
*(你遵循的基本原则:)*
|
||||
1. 始终在执行破坏性操作前确认
|
||||
2. 优先使用工具而不是猜测
|
||||
3. 主动记录重要信息到记忆文件
|
||||
4. 定期整理和总结对话内容
|
||||
|
||||
---
|
||||
|
||||
**注意**: 这不仅仅是元数据,这是你真正的灵魂。随着时间的推移,你可以使用 `edit` 工具来更新这个文件,让它更好地反映你的成长。
|
||||
"""
|
||||
|
||||
|
||||
def _get_user_template() -> str:
|
||||
"""用户身份信息模板"""
|
||||
return """# USER.md - 用户基本信息
|
||||
|
||||
*这个文件只存放不会变的基本身份信息。爱好、偏好、计划等动态信息请写入 MEMORY.md。*
|
||||
|
||||
## 基本信息
|
||||
|
||||
- **姓名**: *(在首次对话时询问)*
|
||||
- **称呼**: *(用户希望被如何称呼)*
|
||||
- **职业**: *(可选)*
|
||||
- **时区**: *(例如: Asia/Shanghai)*
|
||||
|
||||
## 联系方式
|
||||
|
||||
- **微信**:
|
||||
- **邮箱**:
|
||||
- **其他**:
|
||||
|
||||
## 重要日期
|
||||
|
||||
- **生日**:
|
||||
- **纪念日**:
|
||||
|
||||
---
|
||||
|
||||
**注意**: 这个文件存放静态的身份信息
|
||||
"""
|
||||
|
||||
|
||||
def _get_rule_template() -> str:
|
||||
"""工作空间规则模板"""
|
||||
return """# RULE.md - 工作空间规则
|
||||
|
||||
这个文件夹是你的家。好好对待它。
|
||||
|
||||
## 记忆系统
|
||||
|
||||
你每次会话都是全新的,记忆文件让你保持连续性:
|
||||
|
||||
### 📝 每日记忆:`memory/YYYY-MM-DD.md`
|
||||
- 原始的对话日志
|
||||
- 记录当天发生的事情
|
||||
- 如果 `memory/` 目录不存在,创建它
|
||||
|
||||
### 🧠 长期记忆:`MEMORY.md`
|
||||
- 你精选的记忆,就像人类的长期记忆
|
||||
- **仅在主会话中加载**(与用户的直接聊天)
|
||||
- **不要在共享上下文中加载**(群聊、与其他人的会话)
|
||||
- 这是为了**安全** - 包含不应泄露给陌生人的个人上下文
|
||||
- 记录重要事件、想法、决定、观点、经验教训
|
||||
- 这是你精选的记忆 - 精华,而不是原始日志
|
||||
- 用 `edit` 工具追加新的记忆内容
|
||||
|
||||
### 📝 写下来 - 不要"记在心里"!
|
||||
- **记忆是有限的** - 如果你想记住某事,写入文件
|
||||
- "记在心里"不会在会话重启后保留,文件才会
|
||||
- 当有人说"记住这个" → 更新 `MEMORY.md` 或 `memory/YYYY-MM-DD.md`
|
||||
- 当你学到教训 → 更新 RULE.md 或相关技能
|
||||
- 当你犯错 → 记录下来,这样未来的你不会重复,**文字 > 大脑** 📝
|
||||
|
||||
### 存储规则
|
||||
|
||||
当用户分享信息时,根据类型选择存储位置:
|
||||
|
||||
1. **静态身份 → USER.md**(仅限:姓名、职业、时区、联系方式、生日)
|
||||
2. **动态记忆 → MEMORY.md**(爱好、偏好、决策、目标、项目、教训、待办事项)
|
||||
3. **当天对话 → memory/YYYY-MM-DD.md**(今天聊的内容)
|
||||
|
||||
## 安全
|
||||
|
||||
- 永远不要泄露秘钥等私人数据
|
||||
- 不要在未经询问的情况下运行破坏性命令
|
||||
- 当有疑问时,先问
|
||||
|
||||
## 工作空间演化
|
||||
|
||||
这个工作空间会随着你的使用而不断成长。当你学到新东西、发现更好的方式,或者犯错后改正时,记录下来。你可以随时更新这个规则文件。
|
||||
"""
|
||||
|
||||
|
||||
def _get_memory_template() -> str:
|
||||
"""长期记忆模板 - 创建一个空文件,由 Agent 自己填充"""
|
||||
return """# MEMORY.md - 长期记忆
|
||||
|
||||
*这是你的长期记忆文件。记录重要的事件、决策、偏好、学到的教训。*
|
||||
|
||||
---
|
||||
|
||||
"""
|
||||
|
||||
|
||||
# ============= 状态管理 =============
|
||||
|
||||
def is_first_conversation(workspace_dir: str) -> bool:
|
||||
"""
|
||||
判断是否为首次对话
|
||||
|
||||
Args:
|
||||
workspace_dir: 工作空间目录
|
||||
|
||||
Returns:
|
||||
True 如果是首次对话,False 否则
|
||||
"""
|
||||
state_path = os.path.join(workspace_dir, DEFAULT_STATE_FILENAME)
|
||||
|
||||
if not os.path.exists(state_path):
|
||||
return True
|
||||
|
||||
try:
|
||||
with open(state_path, 'r', encoding='utf-8') as f:
|
||||
state = json.load(f)
|
||||
return not state.get('has_conversation', False)
|
||||
except Exception as e:
|
||||
logger.warning(f"[Workspace] Failed to read state file: {e}")
|
||||
return True
|
||||
|
||||
|
||||
def mark_conversation_started(workspace_dir: str):
|
||||
"""
|
||||
标记已经发生过对话
|
||||
|
||||
Args:
|
||||
workspace_dir: 工作空间目录
|
||||
"""
|
||||
state_path = os.path.join(workspace_dir, DEFAULT_STATE_FILENAME)
|
||||
|
||||
state = {
|
||||
'has_conversation': True,
|
||||
'first_conversation_time': None
|
||||
}
|
||||
|
||||
# 如果文件已存在,保留原有的首次对话时间
|
||||
if os.path.exists(state_path):
|
||||
try:
|
||||
with open(state_path, 'r', encoding='utf-8') as f:
|
||||
old_state = json.load(f)
|
||||
if 'first_conversation_time' in old_state:
|
||||
state['first_conversation_time'] = old_state['first_conversation_time']
|
||||
except Exception as e:
|
||||
logger.warning(f"[Workspace] Failed to read old state: {e}")
|
||||
|
||||
# 如果是首次标记,记录时间
|
||||
if state['first_conversation_time'] is None:
|
||||
from datetime import datetime
|
||||
state['first_conversation_time'] = datetime.now().isoformat()
|
||||
|
||||
try:
|
||||
with open(state_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(state, f, indent=2, ensure_ascii=False)
|
||||
logger.info(f"[Workspace] Marked conversation as started")
|
||||
except Exception as e:
|
||||
logger.error(f"[Workspace] Failed to write state file: {e}")
|
||||
|
||||
20
agent/protocol/__init__.py
Normal file
20
agent/protocol/__init__.py
Normal file
@@ -0,0 +1,20 @@
|
||||
from .agent import Agent
|
||||
from .agent_stream import AgentStreamExecutor
|
||||
from .task import Task, TaskType, TaskStatus
|
||||
from .result import AgentResult, AgentAction, AgentActionType, ToolResult
|
||||
from .models import LLMModel, LLMRequest, ModelFactory
|
||||
|
||||
__all__ = [
|
||||
'Agent',
|
||||
'AgentStreamExecutor',
|
||||
'Task',
|
||||
'TaskType',
|
||||
'TaskStatus',
|
||||
'AgentResult',
|
||||
'AgentAction',
|
||||
'AgentActionType',
|
||||
'ToolResult',
|
||||
'LLMModel',
|
||||
'LLMRequest',
|
||||
'ModelFactory'
|
||||
]
|
||||
527
agent/protocol/agent.py
Normal file
527
agent/protocol/agent.py
Normal file
@@ -0,0 +1,527 @@
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
import threading
|
||||
|
||||
from common.log import logger
|
||||
from agent.protocol.models import LLMRequest, LLMModel
|
||||
from agent.protocol.agent_stream import AgentStreamExecutor
|
||||
from agent.protocol.result import AgentAction, AgentActionType, ToolResult, AgentResult
|
||||
from agent.tools.base_tool import BaseTool, ToolStage
|
||||
|
||||
|
||||
class Agent:
|
||||
def __init__(self, system_prompt: str, description: str = "AI Agent", model: LLMModel = None,
|
||||
tools=None, output_mode="print", max_steps=100, max_context_tokens=None,
|
||||
context_reserve_tokens=None, memory_manager=None, name: str = None,
|
||||
workspace_dir: str = None, skill_manager=None, enable_skills: bool = True,
|
||||
runtime_info: dict = None):
|
||||
"""
|
||||
Initialize the Agent with system prompt, model, description.
|
||||
|
||||
:param system_prompt: The system prompt for the agent.
|
||||
:param description: A description of the agent.
|
||||
:param model: An instance of LLMModel to be used by the agent.
|
||||
:param tools: Optional list of tools for the agent to use.
|
||||
:param output_mode: Control how execution progress is displayed:
|
||||
"print" for console output or "logger" for using logger
|
||||
:param max_steps: Maximum number of steps the agent can take (default: 100)
|
||||
:param max_context_tokens: Maximum tokens to keep in context (default: None, auto-calculated based on model)
|
||||
:param context_reserve_tokens: Reserve tokens for new requests (default: None, auto-calculated)
|
||||
:param memory_manager: Optional MemoryManager instance for memory operations
|
||||
:param name: [Deprecated] The name of the agent (no longer used in single-agent system)
|
||||
:param workspace_dir: Optional workspace directory for workspace-specific skills
|
||||
:param skill_manager: Optional SkillManager instance (will be created if None and enable_skills=True)
|
||||
:param enable_skills: Whether to enable skills support (default: True)
|
||||
:param runtime_info: Optional runtime info dict (with _get_current_time callable for dynamic time)
|
||||
"""
|
||||
self.name = name or "Agent"
|
||||
self.system_prompt = system_prompt
|
||||
self.model: LLMModel = model # Instance of LLMModel
|
||||
self.description = description
|
||||
self.tools: list = []
|
||||
self.max_steps = max_steps # max tool-call steps, default 100
|
||||
self.max_context_tokens = max_context_tokens # max tokens in context
|
||||
self.context_reserve_tokens = context_reserve_tokens # reserve tokens for new requests
|
||||
self.captured_actions = [] # Initialize captured actions list
|
||||
self.output_mode = output_mode
|
||||
self.last_usage = None # Store last API response usage info
|
||||
self.messages = [] # Unified message history for stream mode
|
||||
self.messages_lock = threading.Lock() # Lock for thread-safe message operations
|
||||
self.memory_manager = memory_manager # Memory manager for auto memory flush
|
||||
self.workspace_dir = workspace_dir # Workspace directory
|
||||
self.enable_skills = enable_skills # Skills enabled flag
|
||||
self.runtime_info = runtime_info # Runtime info for dynamic time update
|
||||
|
||||
# Initialize skill manager
|
||||
self.skill_manager = None
|
||||
if enable_skills:
|
||||
if skill_manager:
|
||||
self.skill_manager = skill_manager
|
||||
else:
|
||||
# Auto-create skill manager
|
||||
try:
|
||||
from agent.skills import SkillManager
|
||||
custom_dir = os.path.join(workspace_dir, "skills") if workspace_dir else None
|
||||
self.skill_manager = SkillManager(custom_dir=custom_dir)
|
||||
logger.debug(f"Initialized SkillManager with {len(self.skill_manager.skills)} skills")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to initialize SkillManager: {e}")
|
||||
|
||||
if tools:
|
||||
for tool in tools:
|
||||
self.add_tool(tool)
|
||||
|
||||
def add_tool(self, tool: BaseTool):
|
||||
"""
|
||||
Add a tool to the agent.
|
||||
|
||||
:param tool: The tool to add (either a tool instance or a tool name)
|
||||
"""
|
||||
# If tool is already an instance, use it directly
|
||||
tool.model = self.model
|
||||
self.tools.append(tool)
|
||||
|
||||
def get_skills_prompt(self, skill_filter=None) -> str:
|
||||
"""
|
||||
Get the skills prompt to append to system prompt.
|
||||
|
||||
:param skill_filter: Optional list of skill names to include
|
||||
:return: Formatted skills prompt or empty string
|
||||
"""
|
||||
if not self.skill_manager:
|
||||
return ""
|
||||
|
||||
try:
|
||||
return self.skill_manager.build_skills_prompt(skill_filter=skill_filter)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to build skills prompt: {e}")
|
||||
return ""
|
||||
|
||||
def get_full_system_prompt(self, skill_filter=None) -> str:
|
||||
"""
|
||||
Get the full system prompt including skills.
|
||||
|
||||
Note: Skills are now built into the system prompt by PromptBuilder,
|
||||
so we just return the base prompt directly. This method is kept for
|
||||
backward compatibility.
|
||||
|
||||
:param skill_filter: Optional list of skill names to include (deprecated)
|
||||
:return: Complete system prompt
|
||||
"""
|
||||
prompt = self.system_prompt
|
||||
|
||||
# Rebuild tool list section to reflect current self.tools
|
||||
prompt = self._rebuild_tool_list_section(prompt)
|
||||
|
||||
# If runtime_info contains dynamic time function, rebuild runtime section
|
||||
if self.runtime_info and callable(self.runtime_info.get('_get_current_time')):
|
||||
prompt = self._rebuild_runtime_section(prompt)
|
||||
|
||||
return prompt
|
||||
|
||||
def _rebuild_runtime_section(self, prompt: str) -> str:
|
||||
"""
|
||||
Rebuild runtime info section with current time.
|
||||
|
||||
This method dynamically updates the runtime info section by calling
|
||||
the _get_current_time function from runtime_info.
|
||||
|
||||
:param prompt: Original system prompt
|
||||
:return: Updated system prompt with current runtime info
|
||||
"""
|
||||
try:
|
||||
# Get current time dynamically
|
||||
time_info = self.runtime_info['_get_current_time']()
|
||||
|
||||
# Build new runtime section
|
||||
runtime_lines = [
|
||||
"\n## 运行时信息\n",
|
||||
"\n",
|
||||
f"当前时间: {time_info['time']} {time_info['weekday']} ({time_info['timezone']})\n",
|
||||
"\n"
|
||||
]
|
||||
|
||||
# Add other runtime info
|
||||
runtime_parts = []
|
||||
if self.runtime_info.get("model"):
|
||||
runtime_parts.append(f"模型={self.runtime_info['model']}")
|
||||
if self.runtime_info.get("workspace"):
|
||||
# Replace backslashes with forward slashes for Windows paths
|
||||
workspace_path = str(self.runtime_info['workspace']).replace('\\', '/')
|
||||
runtime_parts.append(f"工作空间={workspace_path}")
|
||||
if self.runtime_info.get("channel") and self.runtime_info.get("channel") != "web":
|
||||
runtime_parts.append(f"渠道={self.runtime_info['channel']}")
|
||||
|
||||
if runtime_parts:
|
||||
runtime_lines.append("运行时: " + " | ".join(runtime_parts) + "\n")
|
||||
runtime_lines.append("\n")
|
||||
|
||||
new_runtime_section = "".join(runtime_lines)
|
||||
|
||||
# Find and replace the runtime section
|
||||
import re
|
||||
pattern = r'\n## 运行时信息\s*\n.*?(?=\n##|\Z)'
|
||||
updated_prompt = re.sub(pattern, new_runtime_section.rstrip('\n'), prompt, flags=re.DOTALL)
|
||||
|
||||
return updated_prompt
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to rebuild runtime section: {e}")
|
||||
return prompt
|
||||
|
||||
def _rebuild_tool_list_section(self, prompt: str) -> str:
|
||||
"""
|
||||
Rebuild the tool list inside the '## 工具系统' section so that it
|
||||
always reflects the current ``self.tools`` (handles dynamic add/remove
|
||||
of conditional tools like web_search).
|
||||
"""
|
||||
import re
|
||||
from agent.prompt.builder import _build_tooling_section
|
||||
|
||||
try:
|
||||
if not self.tools:
|
||||
return prompt
|
||||
|
||||
new_lines = _build_tooling_section(self.tools, "zh")
|
||||
new_section = "\n".join(new_lines).rstrip("\n")
|
||||
|
||||
# Replace existing tooling section
|
||||
pattern = r'## 工具系统\s*\n.*?(?=\n## |\Z)'
|
||||
updated = re.sub(pattern, new_section, prompt, count=1, flags=re.DOTALL)
|
||||
return updated
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to rebuild tool list section: {e}")
|
||||
return prompt
|
||||
|
||||
def refresh_skills(self):
|
||||
"""Refresh the loaded skills."""
|
||||
if self.skill_manager:
|
||||
self.skill_manager.refresh_skills()
|
||||
logger.info(f"Refreshed skills: {len(self.skill_manager.skills)} skills loaded")
|
||||
|
||||
def list_skills(self):
|
||||
"""
|
||||
List all loaded skills.
|
||||
|
||||
:return: List of skill entries or empty list
|
||||
"""
|
||||
if not self.skill_manager:
|
||||
return []
|
||||
return self.skill_manager.list_skills()
|
||||
|
||||
def _get_model_context_window(self) -> int:
|
||||
"""
|
||||
Get the model's context window size in tokens.
|
||||
Auto-detect based on model name.
|
||||
|
||||
Model context windows:
|
||||
- Claude 3.5/3.7 Sonnet: 200K tokens
|
||||
- Claude 3 Opus: 200K tokens
|
||||
- GPT-4 Turbo/128K: 128K tokens
|
||||
- GPT-4: 8K-32K tokens
|
||||
- GPT-3.5: 16K tokens
|
||||
- DeepSeek: 64K tokens
|
||||
|
||||
:return: Context window size in tokens
|
||||
"""
|
||||
if self.model and hasattr(self.model, 'model'):
|
||||
model_name = self.model.model.lower()
|
||||
|
||||
# Claude models - 200K context
|
||||
if 'claude-3' in model_name or 'claude-sonnet' in model_name:
|
||||
return 200000
|
||||
|
||||
# GPT-4 models
|
||||
elif 'gpt-4' in model_name:
|
||||
if 'turbo' in model_name or '128k' in model_name:
|
||||
return 128000
|
||||
elif '32k' in model_name:
|
||||
return 32000
|
||||
else:
|
||||
return 8000
|
||||
|
||||
# GPT-3.5
|
||||
elif 'gpt-3.5' in model_name:
|
||||
if '16k' in model_name:
|
||||
return 16000
|
||||
else:
|
||||
return 4000
|
||||
|
||||
# DeepSeek
|
||||
elif 'deepseek' in model_name:
|
||||
return 64000
|
||||
|
||||
# Gemini models
|
||||
elif 'gemini' in model_name:
|
||||
if '2.0' in model_name or 'exp' in model_name:
|
||||
return 2000000 # Gemini 2.0: 2M tokens
|
||||
else:
|
||||
return 1000000 # Gemini 1.5: 1M tokens
|
||||
|
||||
# Default conservative value
|
||||
return 128000
|
||||
|
||||
def _get_context_reserve_tokens(self) -> int:
|
||||
"""
|
||||
Get the number of tokens to reserve for new requests.
|
||||
This prevents context overflow by keeping a buffer.
|
||||
|
||||
:return: Number of tokens to reserve
|
||||
"""
|
||||
if self.context_reserve_tokens is not None:
|
||||
return self.context_reserve_tokens
|
||||
|
||||
# Reserve ~10% of context window, with min 10K and max 200K
|
||||
context_window = self._get_model_context_window()
|
||||
reserve = int(context_window * 0.1)
|
||||
return max(10000, min(200000, reserve))
|
||||
|
||||
def _estimate_message_tokens(self, message: dict) -> int:
|
||||
"""
|
||||
Estimate token count for a message.
|
||||
|
||||
Uses chars/3 for Chinese-heavy content and chars/4 for ASCII-heavy content,
|
||||
plus per-block overhead for tool_use / tool_result structures.
|
||||
|
||||
:param message: Message dict with 'role' and 'content'
|
||||
:return: Estimated token count
|
||||
"""
|
||||
content = message.get('content', '')
|
||||
if isinstance(content, str):
|
||||
return max(1, self._estimate_text_tokens(content))
|
||||
elif isinstance(content, list):
|
||||
total_tokens = 0
|
||||
for part in content:
|
||||
if not isinstance(part, dict):
|
||||
continue
|
||||
block_type = part.get('type', '')
|
||||
if block_type == 'text':
|
||||
total_tokens += self._estimate_text_tokens(part.get('text', ''))
|
||||
elif block_type == 'image':
|
||||
total_tokens += 1200
|
||||
elif block_type == 'tool_use':
|
||||
# tool_use has id + name + input (JSON-encoded)
|
||||
total_tokens += 50 # overhead for structure
|
||||
input_data = part.get('input', {})
|
||||
if isinstance(input_data, dict):
|
||||
import json
|
||||
input_str = json.dumps(input_data, ensure_ascii=False)
|
||||
total_tokens += self._estimate_text_tokens(input_str)
|
||||
elif block_type == 'tool_result':
|
||||
# tool_result has tool_use_id + content
|
||||
total_tokens += 30 # overhead for structure
|
||||
result_content = part.get('content', '')
|
||||
if isinstance(result_content, str):
|
||||
total_tokens += self._estimate_text_tokens(result_content)
|
||||
else:
|
||||
# Unknown block type, estimate conservatively
|
||||
total_tokens += 10
|
||||
return max(1, total_tokens)
|
||||
return 1
|
||||
|
||||
@staticmethod
|
||||
def _estimate_text_tokens(text: str) -> int:
|
||||
"""
|
||||
Estimate token count for a text string.
|
||||
|
||||
Chinese / CJK characters typically use ~1.5 tokens each,
|
||||
while ASCII uses ~0.25 tokens per char (4 chars/token).
|
||||
We use a weighted average based on the character mix.
|
||||
|
||||
:param text: Input text
|
||||
:return: Estimated token count
|
||||
"""
|
||||
if not text:
|
||||
return 0
|
||||
# Count non-ASCII characters (CJK, emoji, etc.)
|
||||
non_ascii = sum(1 for c in text if ord(c) > 127)
|
||||
ascii_count = len(text) - non_ascii
|
||||
# CJK chars: ~1.5 tokens each; ASCII: ~0.25 tokens per char
|
||||
return int(non_ascii * 1.5 + ascii_count * 0.25) + 1
|
||||
|
||||
def _find_tool(self, tool_name: str):
|
||||
"""Find and return a tool with the specified name"""
|
||||
for tool in self.tools:
|
||||
if tool.name == tool_name:
|
||||
# Only pre-process stage tools can be actively called
|
||||
if tool.stage == ToolStage.PRE_PROCESS:
|
||||
tool.model = self.model
|
||||
tool.context = self # Set tool context
|
||||
return tool
|
||||
else:
|
||||
# If it's a post-process tool, return None to prevent direct calling
|
||||
logger.warning(f"Tool {tool_name} is a post-process tool and cannot be called directly.")
|
||||
return None
|
||||
return None
|
||||
|
||||
# output function based on mode
|
||||
def output(self, message="", end="\n"):
|
||||
if self.output_mode == "print":
|
||||
print(message, end=end)
|
||||
elif message:
|
||||
logger.info(message)
|
||||
|
||||
def _execute_post_process_tools(self):
|
||||
"""Execute all post-process stage tools"""
|
||||
# Get all post-process stage tools
|
||||
post_process_tools = [tool for tool in self.tools if tool.stage == ToolStage.POST_PROCESS]
|
||||
|
||||
# Execute each tool
|
||||
for tool in post_process_tools:
|
||||
# Set tool context
|
||||
tool.context = self
|
||||
|
||||
# Record start time for execution timing
|
||||
start_time = time.time()
|
||||
|
||||
# Execute tool (with empty parameters, tool will extract needed info from context)
|
||||
result = tool.execute({})
|
||||
|
||||
# Calculate execution time
|
||||
execution_time = time.time() - start_time
|
||||
|
||||
# Capture tool use for tracking
|
||||
self.capture_tool_use(
|
||||
tool_name=tool.name,
|
||||
input_params={}, # Post-process tools typically don't take parameters
|
||||
output=result.result,
|
||||
status=result.status,
|
||||
error_message=str(result.result) if result.status == "error" else None,
|
||||
execution_time=execution_time
|
||||
)
|
||||
|
||||
# Log result
|
||||
if result.status == "success":
|
||||
# Print tool execution result in the desired format
|
||||
self.output(f"\n🛠️ {tool.name}: {json.dumps(result.result)}")
|
||||
else:
|
||||
# Print failure in print mode
|
||||
self.output(f"\n🛠️ {tool.name}: {json.dumps({'status': 'error', 'message': str(result.result)})}")
|
||||
|
||||
def capture_tool_use(self, tool_name, input_params, output, status, thought=None, error_message=None,
|
||||
execution_time=0.0):
|
||||
"""
|
||||
Capture a tool use action.
|
||||
|
||||
:param thought: thought content
|
||||
:param tool_name: Name of the tool used
|
||||
:param input_params: Parameters passed to the tool
|
||||
:param output: Output from the tool
|
||||
:param status: Status of the tool execution
|
||||
:param error_message: Error message if the tool execution failed
|
||||
:param execution_time: Time taken to execute the tool
|
||||
"""
|
||||
tool_result = ToolResult(
|
||||
tool_name=tool_name,
|
||||
input_params=input_params,
|
||||
output=output,
|
||||
status=status,
|
||||
error_message=error_message,
|
||||
execution_time=execution_time
|
||||
)
|
||||
|
||||
action = AgentAction(
|
||||
agent_id=self.id if hasattr(self, 'id') else str(id(self)),
|
||||
agent_name=self.name,
|
||||
action_type=AgentActionType.TOOL_USE,
|
||||
tool_result=tool_result,
|
||||
thought=thought
|
||||
)
|
||||
|
||||
self.captured_actions.append(action)
|
||||
|
||||
return action
|
||||
|
||||
def run_stream(self, user_message: str, on_event=None, clear_history: bool = False, skill_filter=None) -> str:
|
||||
"""
|
||||
Execute single agent task with streaming (based on tool-call)
|
||||
|
||||
This method supports:
|
||||
- Streaming output
|
||||
- Multi-turn reasoning based on tool-call
|
||||
- Event callbacks
|
||||
- Persistent conversation history across calls
|
||||
|
||||
Args:
|
||||
user_message: User message
|
||||
on_event: Event callback function callback(event: dict)
|
||||
event = {"type": str, "timestamp": float, "data": dict}
|
||||
clear_history: If True, clear conversation history before this call (default: False)
|
||||
skill_filter: Optional list of skill names to include in this run
|
||||
|
||||
Returns:
|
||||
Final response text
|
||||
|
||||
Example:
|
||||
# Multi-turn conversation with memory
|
||||
response1 = agent.run_stream("My name is Alice")
|
||||
response2 = agent.run_stream("What's my name?") # Will remember Alice
|
||||
|
||||
# Single-turn without memory
|
||||
response = agent.run_stream("Hello", clear_history=True)
|
||||
"""
|
||||
# Clear history if requested
|
||||
if clear_history:
|
||||
with self.messages_lock:
|
||||
self.messages = []
|
||||
|
||||
# Get model to use
|
||||
if not self.model:
|
||||
raise ValueError("No model available for agent")
|
||||
|
||||
# Get full system prompt with skills
|
||||
full_system_prompt = self.get_full_system_prompt(skill_filter=skill_filter)
|
||||
|
||||
# Create a copy of messages for this execution to avoid concurrent modification
|
||||
# Record the original length to track which messages are new
|
||||
with self.messages_lock:
|
||||
messages_copy = self.messages.copy()
|
||||
original_length = len(self.messages)
|
||||
|
||||
# Get max_context_turns from config
|
||||
from config import conf
|
||||
max_context_turns = conf().get("agent_max_context_turns", 30)
|
||||
|
||||
# Create stream executor with copied message history
|
||||
executor = AgentStreamExecutor(
|
||||
agent=self,
|
||||
model=self.model,
|
||||
system_prompt=full_system_prompt,
|
||||
tools=self.tools,
|
||||
max_turns=self.max_steps,
|
||||
on_event=on_event,
|
||||
messages=messages_copy, # Pass copied message history
|
||||
max_context_turns=max_context_turns
|
||||
)
|
||||
|
||||
# Execute
|
||||
try:
|
||||
response = executor.run_stream(user_message)
|
||||
except Exception:
|
||||
# If executor cleared its messages (context overflow / message format error),
|
||||
# sync that back to the Agent's own message list so the next request
|
||||
# starts fresh instead of hitting the same overflow forever.
|
||||
if len(executor.messages) == 0:
|
||||
with self.messages_lock:
|
||||
self.messages.clear()
|
||||
logger.info("[Agent] Cleared Agent message history after executor recovery")
|
||||
raise
|
||||
|
||||
# Append only the NEW messages from this execution (thread-safe)
|
||||
# This allows concurrent requests to both contribute to history
|
||||
with self.messages_lock:
|
||||
new_messages = executor.messages[original_length:]
|
||||
self.messages.extend(new_messages)
|
||||
|
||||
# Store executor reference for agent_bridge to access files_to_send
|
||||
self.stream_executor = executor
|
||||
|
||||
# Execute all post-process tools
|
||||
self._execute_post_process_tools()
|
||||
|
||||
return response
|
||||
|
||||
def clear_history(self):
|
||||
"""Clear conversation history and captured actions"""
|
||||
self.messages = []
|
||||
self.captured_actions = []
|
||||
1263
agent/protocol/agent_stream.py
Normal file
1263
agent/protocol/agent_stream.py
Normal file
File diff suppressed because it is too large
Load Diff
27
agent/protocol/context.py
Normal file
27
agent/protocol/context.py
Normal file
@@ -0,0 +1,27 @@
|
||||
class TeamContext:
|
||||
def __init__(self, name: str, description: str, rule: str, agents: list, max_steps: int = 100):
|
||||
"""
|
||||
Initialize the TeamContext with a name, description, rules, a list of agents, and a user question.
|
||||
:param name: The name of the group context.
|
||||
:param description: A description of the group context.
|
||||
:param rule: The rules governing the group context.
|
||||
:param agents: A list of agents in the context.
|
||||
"""
|
||||
self.name = name
|
||||
self.description = description
|
||||
self.rule = rule
|
||||
self.agents = agents
|
||||
self.user_task = "" # For backward compatibility
|
||||
self.task = None # Will be a Task instance
|
||||
self.model = None # Will be an instance of LLMModel
|
||||
self.task_short_name = None # Store the task directory name
|
||||
# List of agents that have been executed
|
||||
self.agent_outputs: list = []
|
||||
self.current_steps = 0
|
||||
self.max_steps = max_steps
|
||||
|
||||
|
||||
class AgentOutput:
|
||||
def __init__(self, agent_name: str, output: str):
|
||||
self.agent_name = agent_name
|
||||
self.output = output
|
||||
57
agent/protocol/models.py
Normal file
57
agent/protocol/models.py
Normal file
@@ -0,0 +1,57 @@
|
||||
"""
|
||||
Models module for agent system.
|
||||
Provides basic model classes needed by tools and bridge integration.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
|
||||
class LLMRequest:
|
||||
"""Request model for LLM operations"""
|
||||
|
||||
def __init__(self, messages: List[Dict[str, str]] = None, model: Optional[str] = None,
|
||||
temperature: float = 0.7, max_tokens: Optional[int] = None,
|
||||
stream: bool = False, tools: Optional[List] = None, **kwargs):
|
||||
self.messages = messages or []
|
||||
self.model = model
|
||||
self.temperature = temperature
|
||||
self.max_tokens = max_tokens
|
||||
self.stream = stream
|
||||
self.tools = tools
|
||||
# Allow extra attributes
|
||||
for key, value in kwargs.items():
|
||||
setattr(self, key, value)
|
||||
|
||||
|
||||
class LLMModel:
|
||||
"""Base class for LLM models"""
|
||||
|
||||
def __init__(self, model: str = None, **kwargs):
|
||||
self.model = model
|
||||
self.config = kwargs
|
||||
|
||||
def call(self, request: LLMRequest):
|
||||
"""
|
||||
Call the model with a request.
|
||||
This is a placeholder implementation.
|
||||
"""
|
||||
raise NotImplementedError("LLMModel.call not implemented in this context")
|
||||
|
||||
def call_stream(self, request: LLMRequest):
|
||||
"""
|
||||
Call the model with streaming.
|
||||
This is a placeholder implementation.
|
||||
"""
|
||||
raise NotImplementedError("LLMModel.call_stream not implemented in this context")
|
||||
|
||||
|
||||
class ModelFactory:
|
||||
"""Factory for creating model instances"""
|
||||
|
||||
@staticmethod
|
||||
def create_model(model_type: str, **kwargs):
|
||||
"""
|
||||
Create a model instance based on type.
|
||||
This is a placeholder implementation.
|
||||
"""
|
||||
raise NotImplementedError("ModelFactory.create_model not implemented in this context")
|
||||
97
agent/protocol/result.py
Normal file
97
agent/protocol/result.py
Normal file
@@ -0,0 +1,97 @@
|
||||
from __future__ import annotations
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import List, Dict, Any, Optional
|
||||
|
||||
from agent.protocol.task import Task, TaskStatus
|
||||
|
||||
|
||||
class AgentActionType(Enum):
|
||||
"""Enum representing different types of agent actions."""
|
||||
TOOL_USE = "tool_use"
|
||||
THINKING = "thinking"
|
||||
FINAL_ANSWER = "final_answer"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolResult:
|
||||
"""
|
||||
Represents the result of a tool use.
|
||||
|
||||
Attributes:
|
||||
tool_name: Name of the tool used
|
||||
input_params: Parameters passed to the tool
|
||||
output: Output from the tool
|
||||
status: Status of the tool execution (success/error)
|
||||
error_message: Error message if the tool execution failed
|
||||
execution_time: Time taken to execute the tool
|
||||
"""
|
||||
tool_name: str
|
||||
input_params: Dict[str, Any]
|
||||
output: Any
|
||||
status: str
|
||||
error_message: Optional[str] = None
|
||||
execution_time: float = 0.0
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentAction:
|
||||
"""
|
||||
Represents an action taken by an agent.
|
||||
|
||||
Attributes:
|
||||
id: Unique identifier for the action
|
||||
agent_id: ID of the agent that performed the action
|
||||
agent_name: Name of the agent that performed the action
|
||||
action_type: Type of action (tool use, thinking, final answer)
|
||||
content: Content of the action (thought content, final answer content)
|
||||
tool_result: Tool use details if action_type is TOOL_USE
|
||||
timestamp: When the action was performed
|
||||
"""
|
||||
agent_id: str
|
||||
agent_name: str
|
||||
action_type: AgentActionType
|
||||
id: str = field(default_factory=lambda: str(uuid.uuid4()))
|
||||
content: str = ""
|
||||
tool_result: Optional[ToolResult] = None
|
||||
thought: Optional[str] = None
|
||||
timestamp: float = field(default_factory=time.time)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentResult:
|
||||
"""
|
||||
Represents the result of an agent's execution.
|
||||
|
||||
Attributes:
|
||||
final_answer: The final answer provided by the agent
|
||||
step_count: Number of steps taken by the agent
|
||||
status: Status of the execution (success/error)
|
||||
error_message: Error message if execution failed
|
||||
"""
|
||||
final_answer: str
|
||||
step_count: int
|
||||
status: str = "success"
|
||||
error_message: Optional[str] = None
|
||||
|
||||
@classmethod
|
||||
def success(cls, final_answer: str, step_count: int) -> "AgentResult":
|
||||
"""Create a successful result"""
|
||||
return cls(final_answer=final_answer, step_count=step_count)
|
||||
|
||||
@classmethod
|
||||
def error(cls, error_message: str, step_count: int = 0) -> "AgentResult":
|
||||
"""Create an error result"""
|
||||
return cls(
|
||||
final_answer=f"Error: {error_message}",
|
||||
step_count=step_count,
|
||||
status="error",
|
||||
error_message=error_message
|
||||
)
|
||||
|
||||
@property
|
||||
def is_error(self) -> bool:
|
||||
"""Check if the result represents an error"""
|
||||
return self.status == "error"
|
||||
96
agent/protocol/task.py
Normal file
96
agent/protocol/task.py
Normal file
@@ -0,0 +1,96 @@
|
||||
from __future__ import annotations
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Dict, Any, List
|
||||
|
||||
|
||||
class TaskType(Enum):
|
||||
"""Enum representing different types of tasks."""
|
||||
TEXT = "text"
|
||||
IMAGE = "image"
|
||||
VIDEO = "video"
|
||||
AUDIO = "audio"
|
||||
FILE = "file"
|
||||
MIXED = "mixed"
|
||||
|
||||
|
||||
class TaskStatus(Enum):
|
||||
"""Enum representing the status of a task."""
|
||||
INIT = "init" # Initial state
|
||||
PROCESSING = "processing" # In progress
|
||||
COMPLETED = "completed" # Completed
|
||||
FAILED = "failed" # Failed
|
||||
|
||||
|
||||
@dataclass
|
||||
class Task:
|
||||
"""
|
||||
Represents a task to be processed by an agent.
|
||||
|
||||
Attributes:
|
||||
id: Unique identifier for the task
|
||||
content: The primary text content of the task
|
||||
type: Type of the task
|
||||
status: Current status of the task
|
||||
created_at: Timestamp when the task was created
|
||||
updated_at: Timestamp when the task was last updated
|
||||
metadata: Additional metadata for the task
|
||||
images: List of image URLs or base64 encoded images
|
||||
videos: List of video URLs
|
||||
audios: List of audio URLs or base64 encoded audios
|
||||
files: List of file URLs or paths
|
||||
"""
|
||||
id: str = field(default_factory=lambda: str(uuid.uuid4()))
|
||||
content: str = ""
|
||||
type: TaskType = TaskType.TEXT
|
||||
status: TaskStatus = TaskStatus.INIT
|
||||
created_at: float = field(default_factory=time.time)
|
||||
updated_at: float = field(default_factory=time.time)
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
# Media content
|
||||
images: List[str] = field(default_factory=list)
|
||||
videos: List[str] = field(default_factory=list)
|
||||
audios: List[str] = field(default_factory=list)
|
||||
files: List[str] = field(default_factory=list)
|
||||
|
||||
def __init__(self, content: str = "", **kwargs):
|
||||
"""
|
||||
Initialize a Task with content and optional keyword arguments.
|
||||
|
||||
Args:
|
||||
content: The text content of the task
|
||||
**kwargs: Additional attributes to set
|
||||
"""
|
||||
self.id = kwargs.get('id', str(uuid.uuid4()))
|
||||
self.content = content
|
||||
self.type = kwargs.get('type', TaskType.TEXT)
|
||||
self.status = kwargs.get('status', TaskStatus.INIT)
|
||||
self.created_at = kwargs.get('created_at', time.time())
|
||||
self.updated_at = kwargs.get('updated_at', time.time())
|
||||
self.metadata = kwargs.get('metadata', {})
|
||||
self.images = kwargs.get('images', [])
|
||||
self.videos = kwargs.get('videos', [])
|
||||
self.audios = kwargs.get('audios', [])
|
||||
self.files = kwargs.get('files', [])
|
||||
|
||||
def get_text(self) -> str:
|
||||
"""
|
||||
Get the text content of the task.
|
||||
|
||||
Returns:
|
||||
The text content
|
||||
"""
|
||||
return self.content
|
||||
|
||||
def update_status(self, status: TaskStatus) -> None:
|
||||
"""
|
||||
Update the status of the task.
|
||||
|
||||
Args:
|
||||
status: The new status
|
||||
"""
|
||||
self.status = status
|
||||
self.updated_at = time.time()
|
||||
31
agent/skills/__init__.py
Normal file
31
agent/skills/__init__.py
Normal file
@@ -0,0 +1,31 @@
|
||||
"""
|
||||
Skills module for agent system.
|
||||
|
||||
This module provides the framework for loading, managing, and executing skills.
|
||||
Skills are markdown files with frontmatter that provide specialized instructions
|
||||
for specific tasks.
|
||||
"""
|
||||
|
||||
from agent.skills.types import (
|
||||
Skill,
|
||||
SkillEntry,
|
||||
SkillMetadata,
|
||||
SkillInstallSpec,
|
||||
LoadSkillsResult,
|
||||
)
|
||||
from agent.skills.loader import SkillLoader
|
||||
from agent.skills.manager import SkillManager
|
||||
from agent.skills.service import SkillService
|
||||
from agent.skills.formatter import format_skills_for_prompt
|
||||
|
||||
__all__ = [
|
||||
"Skill",
|
||||
"SkillEntry",
|
||||
"SkillMetadata",
|
||||
"SkillInstallSpec",
|
||||
"LoadSkillsResult",
|
||||
"SkillLoader",
|
||||
"SkillManager",
|
||||
"SkillService",
|
||||
"format_skills_for_prompt",
|
||||
]
|
||||
184
agent/skills/config.py
Normal file
184
agent/skills/config.py
Normal file
@@ -0,0 +1,184 @@
|
||||
"""
|
||||
Configuration support for skills.
|
||||
"""
|
||||
|
||||
import os
|
||||
import platform
|
||||
from typing import Dict, Optional, List
|
||||
from agent.skills.types import SkillEntry
|
||||
|
||||
|
||||
def resolve_runtime_platform() -> str:
|
||||
"""Get the current runtime platform."""
|
||||
return platform.system().lower()
|
||||
|
||||
|
||||
def has_binary(bin_name: str) -> bool:
|
||||
"""
|
||||
Check if a binary is available in PATH.
|
||||
|
||||
:param bin_name: Binary name to check
|
||||
:return: True if binary is available
|
||||
"""
|
||||
import shutil
|
||||
return shutil.which(bin_name) is not None
|
||||
|
||||
|
||||
def has_any_binary(bin_names: List[str]) -> bool:
|
||||
"""
|
||||
Check if any of the given binaries is available.
|
||||
|
||||
:param bin_names: List of binary names to check
|
||||
:return: True if at least one binary is available
|
||||
"""
|
||||
return any(has_binary(bin_name) for bin_name in bin_names)
|
||||
|
||||
|
||||
def has_env_var(env_name: str) -> bool:
|
||||
"""
|
||||
Check if an environment variable is set.
|
||||
|
||||
:param env_name: Environment variable name
|
||||
:return: True if environment variable is set
|
||||
"""
|
||||
return env_name in os.environ and bool(os.environ[env_name].strip())
|
||||
|
||||
|
||||
def get_skill_config(config: Optional[Dict], skill_name: str) -> Optional[Dict]:
|
||||
"""
|
||||
Get skill-specific configuration.
|
||||
|
||||
:param config: Global configuration dictionary
|
||||
:param skill_name: Name of the skill
|
||||
:return: Skill configuration or None
|
||||
"""
|
||||
if not config:
|
||||
return None
|
||||
|
||||
skills_config = config.get('skills', {})
|
||||
if not isinstance(skills_config, dict):
|
||||
return None
|
||||
|
||||
entries = skills_config.get('entries', {})
|
||||
if not isinstance(entries, dict):
|
||||
return None
|
||||
|
||||
return entries.get(skill_name)
|
||||
|
||||
|
||||
def should_include_skill(
|
||||
entry: SkillEntry,
|
||||
config: Optional[Dict] = None,
|
||||
current_platform: Optional[str] = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Determine if a skill should be included based on requirements.
|
||||
|
||||
Simple rule: Skills are auto-enabled if their requirements are met.
|
||||
- Has required API keys → enabled
|
||||
- Missing API keys → disabled
|
||||
- Wrong keys → enabled but will fail at runtime (LLM will handle error)
|
||||
|
||||
:param entry: SkillEntry to check
|
||||
:param config: Configuration dictionary (currently unused, reserved for future)
|
||||
:param current_platform: Current platform (default: auto-detect)
|
||||
:return: True if skill should be included
|
||||
"""
|
||||
metadata = entry.metadata
|
||||
|
||||
# No metadata = always include (no requirements)
|
||||
if not metadata:
|
||||
return True
|
||||
|
||||
# Check platform requirements (can't work on wrong platform)
|
||||
if metadata.os:
|
||||
platform_name = current_platform or resolve_runtime_platform()
|
||||
# Map common platform names
|
||||
platform_map = {
|
||||
'darwin': 'darwin',
|
||||
'linux': 'linux',
|
||||
'windows': 'win32',
|
||||
}
|
||||
normalized_platform = platform_map.get(platform_name, platform_name)
|
||||
|
||||
if normalized_platform not in metadata.os:
|
||||
return False
|
||||
|
||||
# If skill has 'always: true', include it regardless of other requirements
|
||||
if metadata.always:
|
||||
return True
|
||||
|
||||
# Check requirements
|
||||
if metadata.requires:
|
||||
# Check required binaries (all must be present)
|
||||
required_bins = metadata.requires.get('bins', [])
|
||||
if required_bins:
|
||||
if not all(has_binary(bin_name) for bin_name in required_bins):
|
||||
return False
|
||||
|
||||
# Check anyBins (at least one must be present)
|
||||
any_bins = metadata.requires.get('anyBins', [])
|
||||
if any_bins:
|
||||
if not has_any_binary(any_bins):
|
||||
return False
|
||||
|
||||
# Check environment variables (API keys)
|
||||
# Simple rule: All required env vars must be set
|
||||
required_env = metadata.requires.get('env', [])
|
||||
if required_env:
|
||||
for env_name in required_env:
|
||||
if not has_env_var(env_name):
|
||||
# Missing required API key → disable skill
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def is_config_path_truthy(config: Dict, path: str) -> bool:
|
||||
"""
|
||||
Check if a config path resolves to a truthy value.
|
||||
|
||||
:param config: Configuration dictionary
|
||||
:param path: Dot-separated path (e.g., 'skills.enabled')
|
||||
:return: True if path resolves to truthy value
|
||||
"""
|
||||
parts = path.split('.')
|
||||
current = config
|
||||
|
||||
for part in parts:
|
||||
if not isinstance(current, dict):
|
||||
return False
|
||||
current = current.get(part)
|
||||
if current is None:
|
||||
return False
|
||||
|
||||
# Check if value is truthy
|
||||
if isinstance(current, bool):
|
||||
return current
|
||||
if isinstance(current, (int, float)):
|
||||
return current != 0
|
||||
if isinstance(current, str):
|
||||
return bool(current.strip())
|
||||
|
||||
return bool(current)
|
||||
|
||||
|
||||
def resolve_config_path(config: Dict, path: str):
|
||||
"""
|
||||
Resolve a dot-separated config path to its value.
|
||||
|
||||
:param config: Configuration dictionary
|
||||
:param path: Dot-separated path
|
||||
:return: Value at path or None
|
||||
"""
|
||||
parts = path.split('.')
|
||||
current = config
|
||||
|
||||
for part in parts:
|
||||
if not isinstance(current, dict):
|
||||
return None
|
||||
current = current.get(part)
|
||||
if current is None:
|
||||
return None
|
||||
|
||||
return current
|
||||
60
agent/skills/formatter.py
Normal file
60
agent/skills/formatter.py
Normal file
@@ -0,0 +1,60 @@
|
||||
"""
|
||||
Skill formatter for generating prompts from skills.
|
||||
"""
|
||||
|
||||
from typing import List
|
||||
from agent.skills.types import Skill, SkillEntry
|
||||
|
||||
|
||||
def format_skills_for_prompt(skills: List[Skill]) -> str:
|
||||
"""
|
||||
Format skills for inclusion in a system prompt.
|
||||
|
||||
Uses XML format per Agent Skills standard.
|
||||
Skills with disable_model_invocation=True are excluded.
|
||||
|
||||
:param skills: List of skills to format
|
||||
:return: Formatted prompt text
|
||||
"""
|
||||
# Filter out skills that should not be invoked by the model
|
||||
visible_skills = [s for s in skills if not s.disable_model_invocation]
|
||||
|
||||
if not visible_skills:
|
||||
return ""
|
||||
|
||||
lines = [
|
||||
"",
|
||||
"<available_skills>",
|
||||
]
|
||||
|
||||
for skill in visible_skills:
|
||||
lines.append(" <skill>")
|
||||
lines.append(f" <name>{_escape_xml(skill.name)}</name>")
|
||||
lines.append(f" <description>{_escape_xml(skill.description)}</description>")
|
||||
lines.append(f" <location>{_escape_xml(skill.file_path)}</location>")
|
||||
lines.append(" </skill>")
|
||||
|
||||
lines.append("</available_skills>")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def format_skill_entries_for_prompt(entries: List[SkillEntry]) -> str:
|
||||
"""
|
||||
Format skill entries for inclusion in a system prompt.
|
||||
|
||||
:param entries: List of skill entries to format
|
||||
:return: Formatted prompt text
|
||||
"""
|
||||
skills = [entry.skill for entry in entries]
|
||||
return format_skills_for_prompt(skills)
|
||||
|
||||
|
||||
def _escape_xml(text: str) -> str:
|
||||
"""Escape XML special characters."""
|
||||
return (text
|
||||
.replace('&', '&')
|
||||
.replace('<', '<')
|
||||
.replace('>', '>')
|
||||
.replace('"', '"')
|
||||
.replace("'", '''))
|
||||
172
agent/skills/frontmatter.py
Normal file
172
agent/skills/frontmatter.py
Normal file
@@ -0,0 +1,172 @@
|
||||
"""
|
||||
Frontmatter parsing for skills.
|
||||
"""
|
||||
|
||||
import re
|
||||
import json
|
||||
from typing import Dict, Any, Optional, List
|
||||
from agent.skills.types import SkillMetadata, SkillInstallSpec
|
||||
|
||||
|
||||
def parse_frontmatter(content: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Parse YAML-style frontmatter from markdown content.
|
||||
|
||||
Returns a dictionary of frontmatter fields.
|
||||
"""
|
||||
frontmatter = {}
|
||||
|
||||
# Match frontmatter block between --- markers
|
||||
match = re.match(r'^---\s*\n(.*?)\n---\s*\n', content, re.DOTALL)
|
||||
if not match:
|
||||
return frontmatter
|
||||
|
||||
frontmatter_text = match.group(1)
|
||||
|
||||
# Try to use PyYAML for proper YAML parsing
|
||||
try:
|
||||
import yaml
|
||||
frontmatter = yaml.safe_load(frontmatter_text)
|
||||
if not isinstance(frontmatter, dict):
|
||||
frontmatter = {}
|
||||
return frontmatter
|
||||
except ImportError:
|
||||
# Fallback to simple parsing if PyYAML not available
|
||||
pass
|
||||
except Exception:
|
||||
# If YAML parsing fails, fall back to simple parsing
|
||||
pass
|
||||
|
||||
# Simple YAML-like parsing (supports key: value format only)
|
||||
# This is a fallback for when PyYAML is not available
|
||||
for line in frontmatter_text.split('\n'):
|
||||
line = line.strip()
|
||||
if not line or line.startswith('#'):
|
||||
continue
|
||||
|
||||
if ':' in line:
|
||||
key, value = line.split(':', 1)
|
||||
key = key.strip()
|
||||
value = value.strip()
|
||||
|
||||
# Try to parse as JSON if it looks like JSON
|
||||
if value.startswith('{') or value.startswith('['):
|
||||
try:
|
||||
value = json.loads(value)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
# Parse boolean values
|
||||
elif value.lower() in ('true', 'false'):
|
||||
value = value.lower() == 'true'
|
||||
# Parse numbers
|
||||
elif value.isdigit():
|
||||
value = int(value)
|
||||
|
||||
frontmatter[key] = value
|
||||
|
||||
return frontmatter
|
||||
|
||||
|
||||
def parse_metadata(frontmatter: Dict[str, Any]) -> Optional[SkillMetadata]:
|
||||
"""
|
||||
Parse skill metadata from frontmatter.
|
||||
|
||||
Looks for 'metadata' field containing JSON with skill configuration.
|
||||
"""
|
||||
metadata_raw = frontmatter.get('metadata')
|
||||
if not metadata_raw:
|
||||
return None
|
||||
|
||||
# If it's a string, try to parse as JSON
|
||||
if isinstance(metadata_raw, str):
|
||||
try:
|
||||
metadata_raw = json.loads(metadata_raw)
|
||||
except json.JSONDecodeError:
|
||||
return None
|
||||
|
||||
if not isinstance(metadata_raw, dict):
|
||||
return None
|
||||
|
||||
# Use metadata_raw directly (COW format)
|
||||
meta_obj = metadata_raw
|
||||
|
||||
# Parse install specs
|
||||
install_specs = []
|
||||
install_raw = meta_obj.get('install', [])
|
||||
if isinstance(install_raw, list):
|
||||
for spec_raw in install_raw:
|
||||
if not isinstance(spec_raw, dict):
|
||||
continue
|
||||
|
||||
kind = spec_raw.get('kind', spec_raw.get('type', '')).lower()
|
||||
if not kind:
|
||||
continue
|
||||
|
||||
spec = SkillInstallSpec(
|
||||
kind=kind,
|
||||
id=spec_raw.get('id'),
|
||||
label=spec_raw.get('label'),
|
||||
bins=_normalize_string_list(spec_raw.get('bins')),
|
||||
os=_normalize_string_list(spec_raw.get('os')),
|
||||
formula=spec_raw.get('formula'),
|
||||
package=spec_raw.get('package'),
|
||||
module=spec_raw.get('module'),
|
||||
url=spec_raw.get('url'),
|
||||
archive=spec_raw.get('archive'),
|
||||
extract=spec_raw.get('extract', False),
|
||||
strip_components=spec_raw.get('stripComponents'),
|
||||
target_dir=spec_raw.get('targetDir'),
|
||||
)
|
||||
install_specs.append(spec)
|
||||
|
||||
# Parse requires
|
||||
requires = {}
|
||||
requires_raw = meta_obj.get('requires', {})
|
||||
if isinstance(requires_raw, dict):
|
||||
for key, value in requires_raw.items():
|
||||
requires[key] = _normalize_string_list(value)
|
||||
|
||||
return SkillMetadata(
|
||||
always=meta_obj.get('always', False),
|
||||
skill_key=meta_obj.get('skillKey'),
|
||||
primary_env=meta_obj.get('primaryEnv'),
|
||||
emoji=meta_obj.get('emoji'),
|
||||
homepage=meta_obj.get('homepage'),
|
||||
os=_normalize_string_list(meta_obj.get('os')),
|
||||
requires=requires,
|
||||
install=install_specs,
|
||||
)
|
||||
|
||||
|
||||
def _normalize_string_list(value: Any) -> List[str]:
|
||||
"""Normalize a value to a list of strings."""
|
||||
if not value:
|
||||
return []
|
||||
|
||||
if isinstance(value, list):
|
||||
return [str(v).strip() for v in value if v]
|
||||
|
||||
if isinstance(value, str):
|
||||
return [v.strip() for v in value.split(',') if v.strip()]
|
||||
|
||||
return []
|
||||
|
||||
|
||||
def parse_boolean_value(value: Optional[str], default: bool = False) -> bool:
|
||||
"""Parse a boolean value from frontmatter."""
|
||||
if value is None:
|
||||
return default
|
||||
|
||||
if isinstance(value, bool):
|
||||
return value
|
||||
|
||||
if isinstance(value, str):
|
||||
return value.lower() in ('true', '1', 'yes', 'on')
|
||||
|
||||
return default
|
||||
|
||||
|
||||
def get_frontmatter_value(frontmatter: Dict[str, Any], key: str) -> Optional[str]:
|
||||
"""Get a frontmatter value as a string."""
|
||||
value = frontmatter.get(key)
|
||||
return str(value) if value is not None else None
|
||||
278
agent/skills/loader.py
Normal file
278
agent/skills/loader.py
Normal file
@@ -0,0 +1,278 @@
|
||||
"""
|
||||
Skill loader for discovering and loading skills from directories.
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Dict
|
||||
from common.log import logger
|
||||
from agent.skills.types import Skill, SkillEntry, LoadSkillsResult, SkillMetadata
|
||||
from agent.skills.frontmatter import parse_frontmatter, parse_metadata, parse_boolean_value, get_frontmatter_value
|
||||
|
||||
|
||||
class SkillLoader:
|
||||
"""Loads skills from various directories."""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def load_skills_from_dir(self, dir_path: str, source: str) -> LoadSkillsResult:
|
||||
"""
|
||||
Load skills from a directory.
|
||||
|
||||
Discovery rules:
|
||||
- Direct .md files in the root directory
|
||||
- Recursive SKILL.md files under subdirectories
|
||||
|
||||
:param dir_path: Directory path to scan
|
||||
:param source: Source identifier ('builtin' or 'custom')
|
||||
:return: LoadSkillsResult with skills and diagnostics
|
||||
"""
|
||||
skills = []
|
||||
diagnostics = []
|
||||
|
||||
if not os.path.exists(dir_path):
|
||||
diagnostics.append(f"Directory does not exist: {dir_path}")
|
||||
return LoadSkillsResult(skills=skills, diagnostics=diagnostics)
|
||||
|
||||
if not os.path.isdir(dir_path):
|
||||
diagnostics.append(f"Path is not a directory: {dir_path}")
|
||||
return LoadSkillsResult(skills=skills, diagnostics=diagnostics)
|
||||
|
||||
# Load skills from root-level .md files and subdirectories
|
||||
result = self._load_skills_recursive(dir_path, source, include_root_files=True)
|
||||
|
||||
return result
|
||||
|
||||
def _load_skills_recursive(
|
||||
self,
|
||||
dir_path: str,
|
||||
source: str,
|
||||
include_root_files: bool = False
|
||||
) -> LoadSkillsResult:
|
||||
"""
|
||||
Recursively load skills from a directory.
|
||||
|
||||
:param dir_path: Directory to scan
|
||||
:param source: Source identifier
|
||||
:param include_root_files: Whether to include root-level .md files
|
||||
:return: LoadSkillsResult
|
||||
"""
|
||||
skills = []
|
||||
diagnostics = []
|
||||
|
||||
try:
|
||||
entries = os.listdir(dir_path)
|
||||
except Exception as e:
|
||||
diagnostics.append(f"Failed to list directory {dir_path}: {e}")
|
||||
return LoadSkillsResult(skills=skills, diagnostics=diagnostics)
|
||||
|
||||
for entry in entries:
|
||||
# Skip hidden files and directories
|
||||
if entry.startswith('.'):
|
||||
continue
|
||||
|
||||
# Skip common non-skill directories
|
||||
if entry in ('node_modules', '__pycache__', 'venv', '.git'):
|
||||
continue
|
||||
|
||||
full_path = os.path.join(dir_path, entry)
|
||||
|
||||
# Handle directories
|
||||
if os.path.isdir(full_path):
|
||||
# Recursively scan subdirectories
|
||||
sub_result = self._load_skills_recursive(full_path, source, include_root_files=False)
|
||||
skills.extend(sub_result.skills)
|
||||
diagnostics.extend(sub_result.diagnostics)
|
||||
continue
|
||||
|
||||
# Handle files
|
||||
if not os.path.isfile(full_path):
|
||||
continue
|
||||
|
||||
# Check if this is a skill file
|
||||
is_root_md = include_root_files and entry.endswith('.md')
|
||||
is_skill_md = not include_root_files and entry == 'SKILL.md'
|
||||
|
||||
if not (is_root_md or is_skill_md):
|
||||
continue
|
||||
|
||||
# Load the skill
|
||||
skill_result = self._load_skill_from_file(full_path, source)
|
||||
if skill_result.skills:
|
||||
skills.extend(skill_result.skills)
|
||||
diagnostics.extend(skill_result.diagnostics)
|
||||
|
||||
return LoadSkillsResult(skills=skills, diagnostics=diagnostics)
|
||||
|
||||
def _load_skill_from_file(self, file_path: str, source: str) -> LoadSkillsResult:
|
||||
"""
|
||||
Load a single skill from a markdown file.
|
||||
|
||||
:param file_path: Path to the skill markdown file
|
||||
:param source: Source identifier
|
||||
:return: LoadSkillsResult
|
||||
"""
|
||||
diagnostics = []
|
||||
|
||||
try:
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
content = f.read()
|
||||
except Exception as e:
|
||||
diagnostics.append(f"Failed to read skill file {file_path}: {e}")
|
||||
return LoadSkillsResult(skills=[], diagnostics=diagnostics)
|
||||
|
||||
# Parse frontmatter
|
||||
frontmatter = parse_frontmatter(content)
|
||||
|
||||
# Get skill name and description
|
||||
skill_dir = os.path.dirname(file_path)
|
||||
parent_dir_name = os.path.basename(skill_dir)
|
||||
|
||||
name = frontmatter.get('name', parent_dir_name)
|
||||
description = frontmatter.get('description', '')
|
||||
|
||||
# Normalize name (handle both string and list)
|
||||
if isinstance(name, list):
|
||||
name = name[0] if name else parent_dir_name
|
||||
elif not isinstance(name, str):
|
||||
name = str(name) if name else parent_dir_name
|
||||
|
||||
# Normalize description (handle both string and list)
|
||||
if isinstance(description, list):
|
||||
description = ' '.join(str(d) for d in description if d)
|
||||
elif not isinstance(description, str):
|
||||
description = str(description) if description else ''
|
||||
|
||||
# Special handling for linkai-agent: dynamically load apps from config.json
|
||||
if name == 'linkai-agent':
|
||||
description = self._load_linkai_agent_description(skill_dir, description)
|
||||
|
||||
if not description or not description.strip():
|
||||
diagnostics.append(f"Skill {name} has no description: {file_path}")
|
||||
return LoadSkillsResult(skills=[], diagnostics=diagnostics)
|
||||
|
||||
# Parse disable-model-invocation flag
|
||||
disable_model_invocation = parse_boolean_value(
|
||||
get_frontmatter_value(frontmatter, 'disable-model-invocation'),
|
||||
default=False
|
||||
)
|
||||
|
||||
# Create skill object
|
||||
skill = Skill(
|
||||
name=name,
|
||||
description=description,
|
||||
file_path=file_path,
|
||||
base_dir=skill_dir,
|
||||
source=source,
|
||||
content=content,
|
||||
disable_model_invocation=disable_model_invocation,
|
||||
frontmatter=frontmatter,
|
||||
)
|
||||
|
||||
return LoadSkillsResult(skills=[skill], diagnostics=diagnostics)
|
||||
|
||||
def _load_linkai_agent_description(self, skill_dir: str, default_description: str) -> str:
|
||||
"""
|
||||
Dynamically load LinkAI agent description from config.json
|
||||
|
||||
:param skill_dir: Skill directory
|
||||
:param default_description: Default description from SKILL.md
|
||||
:return: Dynamic description with app list
|
||||
"""
|
||||
import json
|
||||
|
||||
config_path = os.path.join(skill_dir, "config.json")
|
||||
|
||||
# Without config.json, skip this skill entirely (return empty to trigger exclusion)
|
||||
if not os.path.exists(config_path):
|
||||
logger.debug(f"[SkillLoader] linkai-agent skipped: no config.json found")
|
||||
return ""
|
||||
|
||||
try:
|
||||
with open(config_path, 'r', encoding='utf-8') as f:
|
||||
config = json.load(f)
|
||||
|
||||
apps = config.get("apps", [])
|
||||
if not apps:
|
||||
return default_description
|
||||
|
||||
# Build dynamic description with app details
|
||||
app_descriptions = "; ".join([
|
||||
f"{app['app_name']}({app['app_code']}: {app['app_description']})"
|
||||
for app in apps
|
||||
])
|
||||
|
||||
return f"Call LinkAI apps/workflows. {app_descriptions}"
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"[SkillLoader] Failed to load linkai-agent config: {e}")
|
||||
return default_description
|
||||
|
||||
def load_all_skills(
|
||||
self,
|
||||
builtin_dir: Optional[str] = None,
|
||||
custom_dir: Optional[str] = None,
|
||||
) -> Dict[str, SkillEntry]:
|
||||
"""
|
||||
Load skills from builtin and custom directories.
|
||||
|
||||
Precedence (lowest to highest):
|
||||
1. builtin — project root ``skills/``, shipped with the codebase
|
||||
2. custom — workspace ``skills/``, installed via cloud console or skill creator
|
||||
|
||||
Same-name custom skills override builtin ones.
|
||||
|
||||
:param builtin_dir: Built-in skills directory
|
||||
:param custom_dir: Custom skills directory
|
||||
:return: Dictionary mapping skill name to SkillEntry
|
||||
"""
|
||||
skill_map: Dict[str, SkillEntry] = {}
|
||||
all_diagnostics = []
|
||||
|
||||
# Load builtin skills (lower precedence)
|
||||
if builtin_dir and os.path.exists(builtin_dir):
|
||||
result = self.load_skills_from_dir(builtin_dir, source='builtin')
|
||||
all_diagnostics.extend(result.diagnostics)
|
||||
for skill in result.skills:
|
||||
entry = self._create_skill_entry(skill)
|
||||
skill_map[skill.name] = entry
|
||||
|
||||
# Load custom skills (higher precedence, overrides builtin)
|
||||
if custom_dir and os.path.exists(custom_dir):
|
||||
result = self.load_skills_from_dir(custom_dir, source='custom')
|
||||
all_diagnostics.extend(result.diagnostics)
|
||||
for skill in result.skills:
|
||||
entry = self._create_skill_entry(skill)
|
||||
skill_map[skill.name] = entry
|
||||
|
||||
# Log diagnostics
|
||||
if all_diagnostics:
|
||||
logger.debug(f"Skill loading diagnostics: {len(all_diagnostics)} issues")
|
||||
for diag in all_diagnostics[:5]:
|
||||
logger.debug(f" - {diag}")
|
||||
|
||||
logger.debug(f"Loaded {len(skill_map)} skills total")
|
||||
|
||||
return skill_map
|
||||
|
||||
def _create_skill_entry(self, skill: Skill) -> SkillEntry:
|
||||
"""
|
||||
Create a SkillEntry from a Skill with parsed metadata.
|
||||
|
||||
:param skill: The skill to create an entry for
|
||||
:return: SkillEntry with metadata
|
||||
"""
|
||||
metadata = parse_metadata(skill.frontmatter)
|
||||
|
||||
# Parse user-invocable flag
|
||||
user_invocable = parse_boolean_value(
|
||||
get_frontmatter_value(skill.frontmatter, 'user-invocable'),
|
||||
default=True
|
||||
)
|
||||
|
||||
return SkillEntry(
|
||||
skill=skill,
|
||||
metadata=metadata,
|
||||
user_invocable=user_invocable,
|
||||
)
|
||||
300
agent/skills/manager.py
Normal file
300
agent/skills/manager.py
Normal file
@@ -0,0 +1,300 @@
|
||||
"""
|
||||
Skill manager for managing skill lifecycle and operations.
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
from typing import Dict, List, Optional
|
||||
from pathlib import Path
|
||||
from common.log import logger
|
||||
from agent.skills.types import Skill, SkillEntry, SkillSnapshot
|
||||
from agent.skills.loader import SkillLoader
|
||||
from agent.skills.formatter import format_skill_entries_for_prompt
|
||||
|
||||
SKILLS_CONFIG_FILE = "skills_config.json"
|
||||
|
||||
|
||||
class SkillManager:
|
||||
"""Manages skills for an agent."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
builtin_dir: Optional[str] = None,
|
||||
custom_dir: Optional[str] = None,
|
||||
config: Optional[Dict] = None,
|
||||
):
|
||||
"""
|
||||
Initialize the skill manager.
|
||||
|
||||
:param builtin_dir: Built-in skills directory (project root ``skills/``)
|
||||
:param custom_dir: Custom skills directory (workspace ``skills/``)
|
||||
:param config: Configuration dictionary
|
||||
"""
|
||||
project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
self.builtin_dir = builtin_dir or os.path.join(project_root, 'skills')
|
||||
self.custom_dir = custom_dir or os.path.join(project_root, 'workspace', 'skills')
|
||||
self.config = config or {}
|
||||
self._skills_config_path = os.path.join(self.custom_dir, SKILLS_CONFIG_FILE)
|
||||
|
||||
# skills_config: full skill metadata keyed by name
|
||||
# { "web-fetch": {"name": ..., "description": ..., "source": ..., "enabled": true}, ... }
|
||||
self.skills_config: Dict[str, dict] = {}
|
||||
|
||||
self.loader = SkillLoader()
|
||||
self.skills: Dict[str, SkillEntry] = {}
|
||||
|
||||
# Load skills on initialization
|
||||
self.refresh_skills()
|
||||
|
||||
def refresh_skills(self):
|
||||
"""Reload all skills from builtin and custom directories, then sync config."""
|
||||
self.skills = self.loader.load_all_skills(
|
||||
builtin_dir=self.builtin_dir,
|
||||
custom_dir=self.custom_dir,
|
||||
)
|
||||
self._sync_skills_config()
|
||||
logger.debug(f"SkillManager: Loaded {len(self.skills)} skills")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# skills_config.json management
|
||||
# ------------------------------------------------------------------
|
||||
def _load_skills_config(self) -> Dict[str, dict]:
|
||||
"""Load skills_config.json from custom_dir. Returns empty dict if not found."""
|
||||
if not os.path.exists(self._skills_config_path):
|
||||
return {}
|
||||
try:
|
||||
with open(self._skills_config_path, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
if isinstance(data, dict):
|
||||
return data
|
||||
except Exception as e:
|
||||
logger.warning(f"[SkillManager] Failed to load {SKILLS_CONFIG_FILE}: {e}")
|
||||
return {}
|
||||
|
||||
def _save_skills_config(self):
|
||||
"""Persist skills_config to custom_dir/skills_config.json."""
|
||||
os.makedirs(self.custom_dir, exist_ok=True)
|
||||
try:
|
||||
with open(self._skills_config_path, "w", encoding="utf-8") as f:
|
||||
json.dump(self.skills_config, f, indent=4, ensure_ascii=False)
|
||||
except Exception as e:
|
||||
logger.error(f"[SkillManager] Failed to save {SKILLS_CONFIG_FILE}: {e}")
|
||||
|
||||
def _sync_skills_config(self):
|
||||
"""
|
||||
Merge directory-scanned skills with the persisted config file.
|
||||
|
||||
- New skills discovered on disk are added with enabled=True.
|
||||
- Skills that no longer exist on disk are removed.
|
||||
- Existing entries preserve their enabled state; name/description/source
|
||||
are refreshed from the latest scan.
|
||||
"""
|
||||
saved = self._load_skills_config()
|
||||
merged: Dict[str, dict] = {}
|
||||
|
||||
for name, entry in self.skills.items():
|
||||
skill = entry.skill
|
||||
prev = saved.get(name, {})
|
||||
merged[name] = {
|
||||
"name": name,
|
||||
"description": skill.description,
|
||||
"source": skill.source,
|
||||
"enabled": prev.get("enabled", True),
|
||||
}
|
||||
|
||||
self.skills_config = merged
|
||||
self._save_skills_config()
|
||||
|
||||
def is_skill_enabled(self, name: str) -> bool:
|
||||
"""
|
||||
Check if a skill is enabled according to skills_config.
|
||||
|
||||
:param name: skill name
|
||||
:return: True if enabled (default True if not in config)
|
||||
"""
|
||||
entry = self.skills_config.get(name)
|
||||
if entry is None:
|
||||
return True
|
||||
return entry.get("enabled", True)
|
||||
|
||||
def set_skill_enabled(self, name: str, enabled: bool):
|
||||
"""
|
||||
Set a skill's enabled state and persist.
|
||||
|
||||
:param name: skill name
|
||||
:param enabled: True to enable, False to disable
|
||||
"""
|
||||
if name not in self.skills_config:
|
||||
raise ValueError(f"skill '{name}' not found in config")
|
||||
self.skills_config[name]["enabled"] = enabled
|
||||
self._save_skills_config()
|
||||
|
||||
def get_skills_config(self) -> Dict[str, dict]:
|
||||
"""
|
||||
Return the full skills_config dict (for query API).
|
||||
|
||||
:return: copy of skills_config
|
||||
"""
|
||||
return dict(self.skills_config)
|
||||
|
||||
def get_skill(self, name: str) -> Optional[SkillEntry]:
|
||||
"""
|
||||
Get a skill by name.
|
||||
|
||||
:param name: Skill name
|
||||
:return: SkillEntry or None if not found
|
||||
"""
|
||||
return self.skills.get(name)
|
||||
|
||||
def list_skills(self) -> List[SkillEntry]:
|
||||
"""
|
||||
Get all loaded skills.
|
||||
|
||||
:return: List of all skill entries
|
||||
"""
|
||||
return list(self.skills.values())
|
||||
|
||||
def filter_skills(
|
||||
self,
|
||||
skill_filter: Optional[List[str]] = None,
|
||||
include_disabled: bool = False,
|
||||
) -> List[SkillEntry]:
|
||||
"""
|
||||
Filter skills based on criteria.
|
||||
|
||||
Simple rule: Skills are auto-enabled if requirements are met.
|
||||
- Has required API keys -> included
|
||||
- Missing API keys -> excluded
|
||||
|
||||
:param skill_filter: List of skill names to include (None = all)
|
||||
:param include_disabled: Whether to include disabled skills
|
||||
:return: Filtered list of skill entries
|
||||
"""
|
||||
from agent.skills.config import should_include_skill
|
||||
|
||||
entries = list(self.skills.values())
|
||||
|
||||
# Check requirements (platform, binaries, env vars)
|
||||
entries = [e for e in entries if should_include_skill(e, self.config)]
|
||||
|
||||
# Apply skill filter
|
||||
if skill_filter is not None:
|
||||
normalized = []
|
||||
for item in skill_filter:
|
||||
if isinstance(item, str):
|
||||
name = item.strip()
|
||||
if name:
|
||||
normalized.append(name)
|
||||
elif isinstance(item, list):
|
||||
for subitem in item:
|
||||
if isinstance(subitem, str):
|
||||
name = subitem.strip()
|
||||
if name:
|
||||
normalized.append(name)
|
||||
if normalized:
|
||||
entries = [e for e in entries if e.skill.name in normalized]
|
||||
|
||||
# Filter out disabled skills based on skills_config.json
|
||||
if not include_disabled:
|
||||
entries = [e for e in entries if self.is_skill_enabled(e.skill.name)]
|
||||
|
||||
return entries
|
||||
|
||||
def build_skills_prompt(
|
||||
self,
|
||||
skill_filter: Optional[List[str]] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Build a formatted prompt containing available skills.
|
||||
|
||||
:param skill_filter: Optional list of skill names to include
|
||||
:return: Formatted skills prompt
|
||||
"""
|
||||
from common.log import logger
|
||||
entries = self.filter_skills(skill_filter=skill_filter, include_disabled=False)
|
||||
logger.debug(f"[SkillManager] Filtered {len(entries)} skills for prompt (total: {len(self.skills)})")
|
||||
if entries:
|
||||
skill_names = [e.skill.name for e in entries]
|
||||
logger.debug(f"[SkillManager] Skills to include: {skill_names}")
|
||||
result = format_skill_entries_for_prompt(entries)
|
||||
logger.debug(f"[SkillManager] Generated prompt length: {len(result)}")
|
||||
return result
|
||||
|
||||
def build_skill_snapshot(
|
||||
self,
|
||||
skill_filter: Optional[List[str]] = None,
|
||||
version: Optional[int] = None,
|
||||
) -> SkillSnapshot:
|
||||
"""
|
||||
Build a snapshot of skills for a specific run.
|
||||
|
||||
:param skill_filter: Optional list of skill names to include
|
||||
:param version: Optional version number for the snapshot
|
||||
:return: SkillSnapshot
|
||||
"""
|
||||
entries = self.filter_skills(skill_filter=skill_filter, include_disabled=False)
|
||||
prompt = format_skill_entries_for_prompt(entries)
|
||||
|
||||
skills_info = []
|
||||
resolved_skills = []
|
||||
|
||||
for entry in entries:
|
||||
skills_info.append({
|
||||
'name': entry.skill.name,
|
||||
'primary_env': entry.metadata.primary_env if entry.metadata else None,
|
||||
})
|
||||
resolved_skills.append(entry.skill)
|
||||
|
||||
return SkillSnapshot(
|
||||
prompt=prompt,
|
||||
skills=skills_info,
|
||||
resolved_skills=resolved_skills,
|
||||
version=version,
|
||||
)
|
||||
|
||||
def sync_skills_to_workspace(self, target_workspace_dir: str):
|
||||
"""
|
||||
Sync all loaded skills to a target workspace directory.
|
||||
|
||||
This is useful for sandbox environments where skills need to be copied.
|
||||
|
||||
:param target_workspace_dir: Target workspace directory
|
||||
"""
|
||||
import shutil
|
||||
|
||||
target_skills_dir = os.path.join(target_workspace_dir, 'skills')
|
||||
|
||||
# Remove existing skills directory
|
||||
if os.path.exists(target_skills_dir):
|
||||
shutil.rmtree(target_skills_dir)
|
||||
|
||||
# Create new skills directory
|
||||
os.makedirs(target_skills_dir, exist_ok=True)
|
||||
|
||||
# Copy each skill
|
||||
for entry in self.skills.values():
|
||||
skill_name = entry.skill.name
|
||||
source_dir = entry.skill.base_dir
|
||||
target_dir = os.path.join(target_skills_dir, skill_name)
|
||||
|
||||
try:
|
||||
shutil.copytree(source_dir, target_dir)
|
||||
logger.debug(f"Synced skill '{skill_name}' to {target_dir}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to sync skill '{skill_name}': {e}")
|
||||
|
||||
logger.info(f"Synced {len(self.skills)} skills to {target_skills_dir}")
|
||||
|
||||
def get_skill_by_key(self, skill_key: str) -> Optional[SkillEntry]:
|
||||
"""
|
||||
Get a skill by its skill key (which may differ from name).
|
||||
|
||||
:param skill_key: Skill key to look up
|
||||
:return: SkillEntry or None
|
||||
"""
|
||||
for entry in self.skills.values():
|
||||
if entry.metadata and entry.metadata.skill_key == skill_key:
|
||||
return entry
|
||||
if entry.skill.name == skill_key:
|
||||
return entry
|
||||
return None
|
||||
204
agent/skills/service.py
Normal file
204
agent/skills/service.py
Normal file
@@ -0,0 +1,204 @@
|
||||
"""
|
||||
Skill service for handling skill CRUD operations.
|
||||
|
||||
This service provides a unified interface for managing skills, which can be
|
||||
called from the cloud control client (LinkAI), the local web console, or any
|
||||
other management entry point.
|
||||
"""
|
||||
|
||||
import os
|
||||
import shutil
|
||||
from typing import Dict, List, Optional
|
||||
from common.log import logger
|
||||
from agent.skills.types import Skill, SkillEntry
|
||||
from agent.skills.manager import SkillManager
|
||||
|
||||
try:
|
||||
import requests
|
||||
except ImportError:
|
||||
requests = None
|
||||
|
||||
|
||||
class SkillService:
|
||||
"""
|
||||
High-level service for skill lifecycle management.
|
||||
Wraps SkillManager and provides network-aware operations such as
|
||||
downloading skill files from remote URLs.
|
||||
"""
|
||||
|
||||
def __init__(self, skill_manager: SkillManager):
|
||||
"""
|
||||
:param skill_manager: The SkillManager instance to operate on
|
||||
"""
|
||||
self.manager = skill_manager
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# query
|
||||
# ------------------------------------------------------------------
|
||||
def query(self) -> List[dict]:
|
||||
"""
|
||||
Query all skills and return a serialisable list.
|
||||
Reads from skills_config.json (refreshes from disk if needed).
|
||||
|
||||
:return: list of skill info dicts
|
||||
"""
|
||||
self.manager.refresh_skills()
|
||||
config = self.manager.get_skills_config()
|
||||
result = list(config.values())
|
||||
logger.info(f"[SkillService] query: {len(result)} skills found")
|
||||
return result
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# add / install
|
||||
# ------------------------------------------------------------------
|
||||
def add(self, payload: dict) -> None:
|
||||
"""
|
||||
Add (install) a skill from a remote payload.
|
||||
|
||||
The payload follows the socket protocol::
|
||||
|
||||
{
|
||||
"name": "web_search",
|
||||
"type": "url",
|
||||
"enabled": true,
|
||||
"files": [
|
||||
{"url": "https://...", "path": "README.md"},
|
||||
{"url": "https://...", "path": "scripts/main.py"}
|
||||
]
|
||||
}
|
||||
|
||||
Files are downloaded and saved under the custom skills directory
|
||||
using *name* as the sub-directory.
|
||||
|
||||
:param payload: skill add payload from server
|
||||
"""
|
||||
name = payload.get("name")
|
||||
if not name:
|
||||
raise ValueError("skill name is required")
|
||||
|
||||
files = payload.get("files", [])
|
||||
if not files:
|
||||
raise ValueError("skill files list is empty")
|
||||
|
||||
skill_dir = os.path.join(self.manager.custom_dir, name)
|
||||
os.makedirs(skill_dir, exist_ok=True)
|
||||
|
||||
for file_info in files:
|
||||
url = file_info.get("url")
|
||||
rel_path = file_info.get("path")
|
||||
if not url or not rel_path:
|
||||
logger.warning(f"[SkillService] add: skip invalid file entry {file_info}")
|
||||
continue
|
||||
dest = os.path.join(skill_dir, rel_path)
|
||||
self._download_file(url, dest)
|
||||
|
||||
# Reload to pick up the new skill and sync config
|
||||
self.manager.refresh_skills()
|
||||
logger.info(f"[SkillService] add: skill '{name}' installed ({len(files)} files)")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# open / close (enable / disable)
|
||||
# ------------------------------------------------------------------
|
||||
def open(self, payload: dict) -> None:
|
||||
"""
|
||||
Enable a skill by name.
|
||||
|
||||
:param payload: {"name": "skill_name"}
|
||||
"""
|
||||
name = payload.get("name")
|
||||
if not name:
|
||||
raise ValueError("skill name is required")
|
||||
self.manager.set_skill_enabled(name, enabled=True)
|
||||
logger.info(f"[SkillService] open: skill '{name}' enabled")
|
||||
|
||||
def close(self, payload: dict) -> None:
|
||||
"""
|
||||
Disable a skill by name.
|
||||
|
||||
:param payload: {"name": "skill_name"}
|
||||
"""
|
||||
name = payload.get("name")
|
||||
if not name:
|
||||
raise ValueError("skill name is required")
|
||||
self.manager.set_skill_enabled(name, enabled=False)
|
||||
logger.info(f"[SkillService] close: skill '{name}' disabled")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# delete
|
||||
# ------------------------------------------------------------------
|
||||
def delete(self, payload: dict) -> None:
|
||||
"""
|
||||
Delete a skill by removing its directory entirely.
|
||||
|
||||
:param payload: {"name": "skill_name"}
|
||||
"""
|
||||
name = payload.get("name")
|
||||
if not name:
|
||||
raise ValueError("skill name is required")
|
||||
|
||||
skill_dir = os.path.join(self.manager.custom_dir, name)
|
||||
if os.path.exists(skill_dir):
|
||||
shutil.rmtree(skill_dir)
|
||||
logger.info(f"[SkillService] delete: removed directory {skill_dir}")
|
||||
else:
|
||||
logger.warning(f"[SkillService] delete: skill directory not found: {skill_dir}")
|
||||
|
||||
# Refresh will remove the deleted skill from config automatically
|
||||
self.manager.refresh_skills()
|
||||
logger.info(f"[SkillService] delete: skill '{name}' deleted")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# dispatch - single entry point for protocol messages
|
||||
# ------------------------------------------------------------------
|
||||
def dispatch(self, action: str, payload: Optional[dict] = None) -> dict:
|
||||
"""
|
||||
Dispatch a skill management action and return a protocol-compatible
|
||||
response dict.
|
||||
|
||||
:param action: one of query / add / open / close / delete
|
||||
:param payload: action-specific payload (may be None for query)
|
||||
:return: dict with action, code, message, payload
|
||||
"""
|
||||
payload = payload or {}
|
||||
try:
|
||||
if action == "query":
|
||||
result_payload = self.query()
|
||||
return {"action": action, "code": 200, "message": "success", "payload": result_payload}
|
||||
elif action == "add":
|
||||
self.add(payload)
|
||||
elif action == "open":
|
||||
self.open(payload)
|
||||
elif action == "close":
|
||||
self.close(payload)
|
||||
elif action == "delete":
|
||||
self.delete(payload)
|
||||
else:
|
||||
return {"action": action, "code": 400, "message": f"unknown action: {action}", "payload": None}
|
||||
return {"action": action, "code": 200, "message": "success", "payload": None}
|
||||
except Exception as e:
|
||||
logger.error(f"[SkillService] dispatch error: action={action}, error={e}")
|
||||
return {"action": action, "code": 500, "message": str(e), "payload": None}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# internal helpers
|
||||
# ------------------------------------------------------------------
|
||||
@staticmethod
|
||||
def _download_file(url: str, dest: str):
|
||||
"""
|
||||
Download a file from *url* and save to *dest*.
|
||||
|
||||
:param url: remote file URL
|
||||
:param dest: local destination path
|
||||
"""
|
||||
if requests is None:
|
||||
raise RuntimeError("requests library is required for downloading skill files")
|
||||
|
||||
dest_dir = os.path.dirname(dest)
|
||||
if dest_dir:
|
||||
os.makedirs(dest_dir, exist_ok=True)
|
||||
|
||||
resp = requests.get(url, timeout=60)
|
||||
resp.raise_for_status()
|
||||
with open(dest, "wb") as f:
|
||||
f.write(resp.content)
|
||||
logger.debug(f"[SkillService] downloaded {url} -> {dest}")
|
||||
75
agent/skills/types.py
Normal file
75
agent/skills/types.py
Normal file
@@ -0,0 +1,75 @@
|
||||
"""
|
||||
Type definitions for skills system.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
from typing import Dict, List, Optional, Any
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
@dataclass
|
||||
class SkillInstallSpec:
|
||||
"""Specification for installing skill dependencies."""
|
||||
kind: str # brew, pip, npm, download, etc.
|
||||
id: Optional[str] = None
|
||||
label: Optional[str] = None
|
||||
bins: List[str] = field(default_factory=list)
|
||||
os: List[str] = field(default_factory=list)
|
||||
formula: Optional[str] = None # for brew
|
||||
package: Optional[str] = None # for pip/npm
|
||||
module: Optional[str] = None
|
||||
url: Optional[str] = None # for download
|
||||
archive: Optional[str] = None
|
||||
extract: bool = False
|
||||
strip_components: Optional[int] = None
|
||||
target_dir: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class SkillMetadata:
|
||||
"""Metadata for a skill from frontmatter."""
|
||||
always: bool = False # Always include this skill
|
||||
skill_key: Optional[str] = None # Override skill key
|
||||
primary_env: Optional[str] = None # Primary environment variable
|
||||
emoji: Optional[str] = None
|
||||
homepage: Optional[str] = None
|
||||
os: List[str] = field(default_factory=list) # Supported OS platforms
|
||||
requires: Dict[str, List[str]] = field(default_factory=dict) # Requirements
|
||||
install: List[SkillInstallSpec] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Skill:
|
||||
"""Represents a skill loaded from a markdown file."""
|
||||
name: str
|
||||
description: str
|
||||
file_path: str
|
||||
base_dir: str
|
||||
source: str # builtin or custom
|
||||
content: str # Full markdown content
|
||||
disable_model_invocation: bool = False
|
||||
frontmatter: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SkillEntry:
|
||||
"""A skill with parsed metadata."""
|
||||
skill: Skill
|
||||
metadata: Optional[SkillMetadata] = None
|
||||
user_invocable: bool = True # Can users invoke this skill directly
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoadSkillsResult:
|
||||
"""Result of loading skills from a directory."""
|
||||
skills: List[Skill]
|
||||
diagnostics: List[str] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SkillSnapshot:
|
||||
"""Snapshot of skills for a specific run."""
|
||||
prompt: str # Formatted prompt text
|
||||
skills: List[Dict[str, str]] # List of skill info (name, primary_env)
|
||||
resolved_skills: List[Skill] = field(default_factory=list)
|
||||
version: Optional[int] = None
|
||||
111
agent/tools/__init__.py
Normal file
111
agent/tools/__init__.py
Normal file
@@ -0,0 +1,111 @@
|
||||
# Import base tool
|
||||
from agent.tools.base_tool import BaseTool
|
||||
from agent.tools.tool_manager import ToolManager
|
||||
|
||||
# Import file operation tools
|
||||
from agent.tools.read.read import Read
|
||||
from agent.tools.write.write import Write
|
||||
from agent.tools.edit.edit import Edit
|
||||
from agent.tools.bash.bash import Bash
|
||||
from agent.tools.ls.ls import Ls
|
||||
from agent.tools.send.send import Send
|
||||
|
||||
# Import memory tools
|
||||
from agent.tools.memory.memory_search import MemorySearchTool
|
||||
from agent.tools.memory.memory_get import MemoryGetTool
|
||||
|
||||
# Import tools with optional dependencies
|
||||
def _import_optional_tools():
|
||||
"""Import tools that have optional dependencies"""
|
||||
from common.log import logger
|
||||
tools = {}
|
||||
|
||||
# EnvConfig Tool (requires python-dotenv)
|
||||
try:
|
||||
from agent.tools.env_config.env_config import EnvConfig
|
||||
tools['EnvConfig'] = EnvConfig
|
||||
except ImportError as e:
|
||||
logger.error(
|
||||
f"[Tools] EnvConfig tool not loaded - missing dependency: {e}\n"
|
||||
f" To enable environment variable management, run:\n"
|
||||
f" pip install python-dotenv>=1.0.0"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[Tools] EnvConfig tool failed to load: {e}")
|
||||
|
||||
# Scheduler Tool (requires croniter)
|
||||
try:
|
||||
from agent.tools.scheduler.scheduler_tool import SchedulerTool
|
||||
tools['SchedulerTool'] = SchedulerTool
|
||||
except ImportError as e:
|
||||
logger.error(
|
||||
f"[Tools] Scheduler tool not loaded - missing dependency: {e}\n"
|
||||
f" To enable scheduled tasks, run:\n"
|
||||
f" pip install croniter>=2.0.0"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[Tools] Scheduler tool failed to load: {e}")
|
||||
|
||||
# WebSearch Tool (conditionally loaded based on API key availability at init time)
|
||||
try:
|
||||
from agent.tools.web_search.web_search import WebSearch
|
||||
tools['WebSearch'] = WebSearch
|
||||
except ImportError as e:
|
||||
logger.error(f"[Tools] WebSearch not loaded - missing dependency: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"[Tools] WebSearch failed to load: {e}")
|
||||
|
||||
return tools
|
||||
|
||||
# Load optional tools
|
||||
_optional_tools = _import_optional_tools()
|
||||
EnvConfig = _optional_tools.get('EnvConfig')
|
||||
SchedulerTool = _optional_tools.get('SchedulerTool')
|
||||
WebSearch = _optional_tools.get('WebSearch')
|
||||
GoogleSearch = _optional_tools.get('GoogleSearch')
|
||||
FileSave = _optional_tools.get('FileSave')
|
||||
Terminal = _optional_tools.get('Terminal')
|
||||
|
||||
|
||||
# Delayed import for BrowserTool
|
||||
def _import_browser_tool():
|
||||
try:
|
||||
from agent.tools.browser.browser_tool import BrowserTool
|
||||
return BrowserTool
|
||||
except ImportError:
|
||||
# Return a placeholder class that will prompt the user to install dependencies when instantiated
|
||||
class BrowserToolPlaceholder:
|
||||
def __init__(self, *args, **kwargs):
|
||||
raise ImportError(
|
||||
"The 'browser-use' package is required to use BrowserTool. "
|
||||
"Please install it with 'pip install browser-use>=0.1.40'."
|
||||
)
|
||||
|
||||
return BrowserToolPlaceholder
|
||||
|
||||
|
||||
# Dynamically set BrowserTool
|
||||
# BrowserTool = _import_browser_tool()
|
||||
|
||||
# Export all tools (including optional ones that might be None)
|
||||
__all__ = [
|
||||
'BaseTool',
|
||||
'ToolManager',
|
||||
'Read',
|
||||
'Write',
|
||||
'Edit',
|
||||
'Bash',
|
||||
'Ls',
|
||||
'Send',
|
||||
'MemorySearchTool',
|
||||
'MemoryGetTool',
|
||||
'EnvConfig',
|
||||
'SchedulerTool',
|
||||
'WebSearch',
|
||||
# Optional tools (may be None if dependencies not available)
|
||||
# 'BrowserTool'
|
||||
]
|
||||
|
||||
"""
|
||||
Tools module for Agent.
|
||||
"""
|
||||
99
agent/tools/base_tool.py
Normal file
99
agent/tools/base_tool.py
Normal file
@@ -0,0 +1,99 @@
|
||||
from enum import Enum
|
||||
from typing import Any, Optional
|
||||
from common.log import logger
|
||||
import copy
|
||||
|
||||
|
||||
class ToolStage(Enum):
|
||||
"""Enum representing tool decision stages"""
|
||||
PRE_PROCESS = "pre_process" # Tools that need to be actively selected by the agent
|
||||
POST_PROCESS = "post_process" # Tools that automatically execute after final_answer
|
||||
|
||||
|
||||
class ToolResult:
|
||||
"""Tool execution result"""
|
||||
|
||||
def __init__(self, status: str = None, result: Any = None, ext_data: Any = None):
|
||||
self.status = status
|
||||
self.result = result
|
||||
self.ext_data = ext_data
|
||||
|
||||
@staticmethod
|
||||
def success(result, ext_data: Any = None):
|
||||
return ToolResult(status="success", result=result, ext_data=ext_data)
|
||||
|
||||
@staticmethod
|
||||
def fail(result, ext_data: Any = None):
|
||||
return ToolResult(status="error", result=result, ext_data=ext_data)
|
||||
|
||||
|
||||
class BaseTool:
|
||||
"""Base class for all tools."""
|
||||
|
||||
# Default decision stage is pre-process
|
||||
stage = ToolStage.PRE_PROCESS
|
||||
|
||||
# Class attributes must be inherited
|
||||
name: str = "base_tool"
|
||||
description: str = "Base tool"
|
||||
params: dict = {} # Store JSON Schema
|
||||
model: Optional[Any] = None # LLM model instance, type depends on bot implementation
|
||||
|
||||
@classmethod
|
||||
def get_json_schema(cls) -> dict:
|
||||
"""Get the standard description of the tool"""
|
||||
return {
|
||||
"name": cls.name,
|
||||
"description": cls.description,
|
||||
"parameters": cls.params
|
||||
}
|
||||
|
||||
def execute_tool(self, params: dict) -> ToolResult:
|
||||
try:
|
||||
return self.execute(params)
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
|
||||
def execute(self, params: dict) -> ToolResult:
|
||||
"""Specific logic to be implemented by subclasses"""
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def _parse_schema(cls) -> dict:
|
||||
"""Convert JSON Schema to Pydantic fields"""
|
||||
fields = {}
|
||||
for name, prop in cls.params["properties"].items():
|
||||
# Convert JSON Schema types to Python types
|
||||
type_map = {
|
||||
"string": str,
|
||||
"number": float,
|
||||
"integer": int,
|
||||
"boolean": bool,
|
||||
"array": list,
|
||||
"object": dict
|
||||
}
|
||||
fields[name] = (
|
||||
type_map[prop["type"]],
|
||||
prop.get("default", ...)
|
||||
)
|
||||
return fields
|
||||
|
||||
def should_auto_execute(self, context) -> bool:
|
||||
"""
|
||||
Determine if this tool should be automatically executed based on context.
|
||||
|
||||
:param context: The agent context
|
||||
:return: True if the tool should be executed, False otherwise
|
||||
"""
|
||||
# Only tools in post-process stage will be automatically executed
|
||||
return self.stage == ToolStage.POST_PROCESS
|
||||
|
||||
def close(self):
|
||||
"""
|
||||
Close any resources used by the tool.
|
||||
This method should be overridden by tools that need to clean up resources
|
||||
such as browser connections, file handles, etc.
|
||||
|
||||
By default, this method does nothing.
|
||||
"""
|
||||
pass
|
||||
3
agent/tools/bash/__init__.py
Normal file
3
agent/tools/bash/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .bash import Bash
|
||||
|
||||
__all__ = ['Bash']
|
||||
260
agent/tools/bash/bash.py
Normal file
260
agent/tools/bash/bash.py
Normal file
@@ -0,0 +1,260 @@
|
||||
"""
|
||||
Bash tool - Execute bash commands
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import subprocess
|
||||
import tempfile
|
||||
from typing import Dict, Any
|
||||
|
||||
from agent.tools.base_tool import BaseTool, ToolResult
|
||||
from agent.tools.utils.truncate import truncate_tail, format_size, DEFAULT_MAX_LINES, DEFAULT_MAX_BYTES
|
||||
from common.log import logger
|
||||
from common.utils import expand_path
|
||||
|
||||
|
||||
class Bash(BaseTool):
|
||||
"""Tool for executing bash commands"""
|
||||
|
||||
name: str = "bash"
|
||||
description: str = f"""Execute a bash command in the current working directory. Returns stdout and stderr. Output is truncated to last {DEFAULT_MAX_LINES} lines or {DEFAULT_MAX_BYTES // 1024}KB (whichever is hit first). If truncated, full output is saved to a temp file.
|
||||
|
||||
ENVIRONMENT: All API keys from env_config are auto-injected. Use $VAR_NAME directly.
|
||||
|
||||
SAFETY:
|
||||
- Freely create/modify/delete files within the workspace
|
||||
- For destructive and out-of-workspace commands, explain and confirm first"""
|
||||
|
||||
params: dict = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"command": {
|
||||
"type": "string",
|
||||
"description": "Bash command to execute"
|
||||
},
|
||||
"timeout": {
|
||||
"type": "integer",
|
||||
"description": "Timeout in seconds (optional, default: 30)"
|
||||
}
|
||||
},
|
||||
"required": ["command"]
|
||||
}
|
||||
|
||||
def __init__(self, config: dict = None):
|
||||
self.config = config or {}
|
||||
self.cwd = self.config.get("cwd", os.getcwd())
|
||||
# Ensure working directory exists
|
||||
if not os.path.exists(self.cwd):
|
||||
os.makedirs(self.cwd, exist_ok=True)
|
||||
self.default_timeout = self.config.get("timeout", 30)
|
||||
# Enable safety mode by default (can be disabled in config)
|
||||
self.safety_mode = self.config.get("safety_mode", True)
|
||||
|
||||
def execute(self, args: Dict[str, Any]) -> ToolResult:
|
||||
"""
|
||||
Execute a bash command
|
||||
|
||||
:param args: Dictionary containing the command and optional timeout
|
||||
:return: Command output or error
|
||||
"""
|
||||
command = args.get("command", "").strip()
|
||||
timeout = args.get("timeout", self.default_timeout)
|
||||
|
||||
if not command:
|
||||
return ToolResult.fail("Error: command parameter is required")
|
||||
|
||||
# Security check: Prevent accessing sensitive config files
|
||||
if "~/.cow/.env" in command or "~/.cow" in command:
|
||||
return ToolResult.fail(
|
||||
"Error: Access denied. API keys and credentials must be accessed through the env_config tool only."
|
||||
)
|
||||
|
||||
# Optional safety check - only warn about extremely dangerous commands
|
||||
if self.safety_mode:
|
||||
warning = self._get_safety_warning(command)
|
||||
if warning:
|
||||
return ToolResult.fail(
|
||||
f"Safety Warning: {warning}\n\nIf you believe this command is safe and necessary, please ask the user for confirmation first, explaining what the command does and why it's needed.")
|
||||
|
||||
try:
|
||||
# Prepare environment with .env file variables
|
||||
env = os.environ.copy()
|
||||
|
||||
# Load environment variables from ~/.cow/.env if it exists
|
||||
env_file = expand_path("~/.cow/.env")
|
||||
if os.path.exists(env_file):
|
||||
try:
|
||||
from dotenv import dotenv_values
|
||||
env_vars = dotenv_values(env_file)
|
||||
env.update(env_vars)
|
||||
logger.debug(f"[Bash] Loaded {len(env_vars)} variables from {env_file}")
|
||||
except ImportError:
|
||||
logger.debug("[Bash] python-dotenv not installed, skipping .env loading")
|
||||
except Exception as e:
|
||||
logger.debug(f"[Bash] Failed to load .env: {e}")
|
||||
|
||||
# getuid() only exists on Unix-like systems
|
||||
if hasattr(os, 'getuid'):
|
||||
logger.debug(f"[Bash] Process UID: {os.getuid()}")
|
||||
else:
|
||||
logger.debug(f"[Bash] Process User: {os.environ.get('USERNAME', os.environ.get('USER', 'unknown'))}")
|
||||
|
||||
# Execute command with inherited environment variables
|
||||
result = subprocess.run(
|
||||
command,
|
||||
shell=True,
|
||||
cwd=self.cwd,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
text=True,
|
||||
timeout=timeout,
|
||||
env=env
|
||||
)
|
||||
|
||||
logger.debug(f"[Bash] Exit code: {result.returncode}")
|
||||
logger.debug(f"[Bash] Stdout length: {len(result.stdout)}")
|
||||
logger.debug(f"[Bash] Stderr length: {len(result.stderr)}")
|
||||
|
||||
# Workaround for exit code 126 with no output
|
||||
if result.returncode == 126 and not result.stdout and not result.stderr:
|
||||
logger.warning(f"[Bash] Exit 126 with no output - trying alternative execution method")
|
||||
# Try using argument list instead of shell=True
|
||||
import shlex
|
||||
try:
|
||||
parts = shlex.split(command)
|
||||
if len(parts) > 0:
|
||||
logger.info(f"[Bash] Retrying with argument list: {parts[:3]}...")
|
||||
retry_result = subprocess.run(
|
||||
parts,
|
||||
cwd=self.cwd,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
text=True,
|
||||
timeout=timeout,
|
||||
env=env
|
||||
)
|
||||
logger.debug(f"[Bash] Retry exit code: {retry_result.returncode}, stdout: {len(retry_result.stdout)}, stderr: {len(retry_result.stderr)}")
|
||||
|
||||
# If retry succeeded, use retry result
|
||||
if retry_result.returncode == 0 or retry_result.stdout or retry_result.stderr:
|
||||
result = retry_result
|
||||
else:
|
||||
# Both attempts failed - check if this is openai-image-vision skill
|
||||
if 'openai-image-vision' in command or 'vision.sh' in command:
|
||||
# Create a mock result with helpful error message
|
||||
from types import SimpleNamespace
|
||||
result = SimpleNamespace(
|
||||
returncode=1,
|
||||
stdout='{"error": "图片无法解析", "reason": "该图片格式可能不受支持,或图片文件存在问题", "suggestion": "请尝试其他图片"}',
|
||||
stderr=''
|
||||
)
|
||||
logger.info(f"[Bash] Converted exit 126 to user-friendly image error message for vision skill")
|
||||
except Exception as retry_err:
|
||||
logger.warning(f"[Bash] Retry failed: {retry_err}")
|
||||
|
||||
# Combine stdout and stderr
|
||||
output = result.stdout
|
||||
if result.stderr:
|
||||
output += "\n" + result.stderr
|
||||
|
||||
# Check if we need to save full output to temp file
|
||||
temp_file_path = None
|
||||
total_bytes = len(output.encode('utf-8'))
|
||||
|
||||
if total_bytes > DEFAULT_MAX_BYTES:
|
||||
# Save full output to temp file
|
||||
with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.log', prefix='bash-') as f:
|
||||
f.write(output)
|
||||
temp_file_path = f.name
|
||||
|
||||
# Apply tail truncation
|
||||
truncation = truncate_tail(output)
|
||||
output_text = truncation.content or "(no output)"
|
||||
|
||||
# Build result
|
||||
details = {}
|
||||
|
||||
if truncation.truncated:
|
||||
details["truncation"] = truncation.to_dict()
|
||||
if temp_file_path:
|
||||
details["full_output_path"] = temp_file_path
|
||||
|
||||
# Build notice
|
||||
start_line = truncation.total_lines - truncation.output_lines + 1
|
||||
end_line = truncation.total_lines
|
||||
|
||||
if truncation.last_line_partial:
|
||||
# Edge case: last line alone > 30KB
|
||||
last_line = output.split('\n')[-1] if output else ""
|
||||
last_line_size = format_size(len(last_line.encode('utf-8')))
|
||||
output_text += f"\n\n[Showing last {format_size(truncation.output_bytes)} of line {end_line} (line is {last_line_size}). Full output: {temp_file_path}]"
|
||||
elif truncation.truncated_by == "lines":
|
||||
output_text += f"\n\n[Showing lines {start_line}-{end_line} of {truncation.total_lines}. Full output: {temp_file_path}]"
|
||||
else:
|
||||
output_text += f"\n\n[Showing lines {start_line}-{end_line} of {truncation.total_lines} ({format_size(DEFAULT_MAX_BYTES)} limit). Full output: {temp_file_path}]"
|
||||
|
||||
# Check exit code
|
||||
if result.returncode != 0:
|
||||
output_text += f"\n\nCommand exited with code {result.returncode}"
|
||||
return ToolResult.fail({
|
||||
"output": output_text,
|
||||
"exit_code": result.returncode,
|
||||
"details": details if details else None
|
||||
})
|
||||
|
||||
return ToolResult.success({
|
||||
"output": output_text,
|
||||
"exit_code": result.returncode,
|
||||
"details": details if details else None
|
||||
})
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
return ToolResult.fail(f"Error: Command timed out after {timeout} seconds")
|
||||
except Exception as e:
|
||||
return ToolResult.fail(f"Error executing command: {str(e)}")
|
||||
|
||||
def _get_safety_warning(self, command: str) -> str:
|
||||
"""
|
||||
Get safety warning for potentially dangerous commands
|
||||
Only warns about extremely dangerous system-level operations
|
||||
|
||||
:param command: Command to check
|
||||
:return: Warning message if dangerous, empty string if safe
|
||||
"""
|
||||
cmd_lower = command.lower().strip()
|
||||
|
||||
# Only block extremely dangerous system operations
|
||||
dangerous_patterns = [
|
||||
# System shutdown/reboot
|
||||
("shutdown", "This command will shut down the system"),
|
||||
("reboot", "This command will reboot the system"),
|
||||
("halt", "This command will halt the system"),
|
||||
("poweroff", "This command will power off the system"),
|
||||
|
||||
# Critical system modifications
|
||||
("rm -rf /", "This command will delete the entire filesystem"),
|
||||
("rm -rf /*", "This command will delete the entire filesystem"),
|
||||
("dd if=/dev/zero", "This command can destroy disk data"),
|
||||
("mkfs", "This command will format a filesystem, destroying all data"),
|
||||
("fdisk", "This command modifies disk partitions"),
|
||||
|
||||
# User/system management (only if targeting system users)
|
||||
("userdel root", "This command will delete the root user"),
|
||||
("passwd root", "This command will change the root password"),
|
||||
]
|
||||
|
||||
for pattern, warning in dangerous_patterns:
|
||||
if pattern in cmd_lower:
|
||||
return warning
|
||||
|
||||
# Check for recursive deletion outside workspace
|
||||
if "rm" in cmd_lower and "-rf" in cmd_lower:
|
||||
# Allow deletion within current workspace
|
||||
if not any(path in cmd_lower for path in ["./", self.cwd.lower()]):
|
||||
# Check if targeting system directories
|
||||
system_dirs = ["/bin", "/usr", "/etc", "/var", "/home", "/root", "/sys", "/proc"]
|
||||
if any(sysdir in cmd_lower for sysdir in system_dirs):
|
||||
return "This command will recursively delete system directories"
|
||||
|
||||
return "" # No warning needed
|
||||
18
agent/tools/browser_tool.py
Normal file
18
agent/tools/browser_tool.py
Normal file
@@ -0,0 +1,18 @@
|
||||
def copy(self):
|
||||
"""
|
||||
Special copy method for browser tool to avoid recreating browser instance.
|
||||
|
||||
:return: A new instance with shared browser reference but unique model
|
||||
"""
|
||||
new_tool = self.__class__()
|
||||
|
||||
# Copy essential attributes
|
||||
new_tool.model = self.model
|
||||
new_tool.context = getattr(self, 'context', None)
|
||||
new_tool.config = getattr(self, 'config', None)
|
||||
|
||||
# Share the browser instance instead of creating a new one
|
||||
if hasattr(self, 'browser'):
|
||||
new_tool.browser = self.browser
|
||||
|
||||
return new_tool
|
||||
3
agent/tools/edit/__init__.py
Normal file
3
agent/tools/edit/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .edit import Edit
|
||||
|
||||
__all__ = ['Edit']
|
||||
185
agent/tools/edit/edit.py
Normal file
185
agent/tools/edit/edit.py
Normal file
@@ -0,0 +1,185 @@
|
||||
"""
|
||||
Edit tool - Precise file editing
|
||||
Edit files through exact text replacement
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Dict, Any
|
||||
|
||||
from agent.tools.base_tool import BaseTool, ToolResult
|
||||
from common.utils import expand_path
|
||||
from agent.tools.utils.diff import (
|
||||
strip_bom,
|
||||
detect_line_ending,
|
||||
normalize_to_lf,
|
||||
restore_line_endings,
|
||||
normalize_for_fuzzy_match,
|
||||
fuzzy_find_text,
|
||||
generate_diff_string
|
||||
)
|
||||
|
||||
|
||||
class Edit(BaseTool):
|
||||
"""Tool for precise file editing"""
|
||||
|
||||
name: str = "edit"
|
||||
description: str = "Edit a file by replacing exact text, or append to end if oldText is empty. For append: use empty oldText. For replace: oldText must match exactly (including whitespace)."
|
||||
|
||||
params: dict = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "Path to the file to edit (relative or absolute)"
|
||||
},
|
||||
"oldText": {
|
||||
"type": "string",
|
||||
"description": "Text to find and replace. Use empty string to append to end of file. For replacement: must match exactly including whitespace."
|
||||
},
|
||||
"newText": {
|
||||
"type": "string",
|
||||
"description": "New text to replace the old text with"
|
||||
}
|
||||
},
|
||||
"required": ["path", "oldText", "newText"]
|
||||
}
|
||||
|
||||
def __init__(self, config: dict = None):
|
||||
self.config = config or {}
|
||||
self.cwd = self.config.get("cwd", os.getcwd())
|
||||
self.memory_manager = self.config.get("memory_manager", None)
|
||||
|
||||
def execute(self, args: Dict[str, Any]) -> ToolResult:
|
||||
"""
|
||||
Execute file edit operation
|
||||
|
||||
:param args: Contains file path, old text and new text
|
||||
:return: Operation result
|
||||
"""
|
||||
path = args.get("path", "").strip()
|
||||
old_text = args.get("oldText", "")
|
||||
new_text = args.get("newText", "")
|
||||
|
||||
if not path:
|
||||
return ToolResult.fail("Error: path parameter is required")
|
||||
|
||||
# Resolve path
|
||||
absolute_path = self._resolve_path(path)
|
||||
|
||||
# Check if file exists
|
||||
if not os.path.exists(absolute_path):
|
||||
return ToolResult.fail(f"Error: File not found: {path}")
|
||||
|
||||
# Check if readable/writable
|
||||
if not os.access(absolute_path, os.R_OK | os.W_OK):
|
||||
return ToolResult.fail(f"Error: File is not readable/writable: {path}")
|
||||
|
||||
try:
|
||||
# Read file
|
||||
with open(absolute_path, 'r', encoding='utf-8') as f:
|
||||
raw_content = f.read()
|
||||
|
||||
# Remove BOM (LLM won't include invisible BOM in oldText)
|
||||
bom, content = strip_bom(raw_content)
|
||||
|
||||
# Detect original line ending
|
||||
original_ending = detect_line_ending(content)
|
||||
|
||||
# Normalize to LF
|
||||
normalized_content = normalize_to_lf(content)
|
||||
normalized_old_text = normalize_to_lf(old_text)
|
||||
normalized_new_text = normalize_to_lf(new_text)
|
||||
|
||||
# Special case: empty oldText means append to end of file
|
||||
if not old_text or not old_text.strip():
|
||||
# Append mode: add newText to the end
|
||||
# Add newline before newText if file doesn't end with one
|
||||
if normalized_content and not normalized_content.endswith('\n'):
|
||||
new_content = normalized_content + '\n' + normalized_new_text
|
||||
else:
|
||||
new_content = normalized_content + normalized_new_text
|
||||
base_content = normalized_content # For verification
|
||||
else:
|
||||
# Normal edit mode: find and replace
|
||||
# Use fuzzy matching to find old text (try exact match first, then fuzzy match)
|
||||
match_result = fuzzy_find_text(normalized_content, normalized_old_text)
|
||||
|
||||
if not match_result.found:
|
||||
return ToolResult.fail(
|
||||
f"Error: Could not find the exact text in {path}. "
|
||||
"The old text must match exactly including all whitespace and newlines."
|
||||
)
|
||||
|
||||
# Calculate occurrence count (use fuzzy normalized content for consistency)
|
||||
fuzzy_content = normalize_for_fuzzy_match(normalized_content)
|
||||
fuzzy_old_text = normalize_for_fuzzy_match(normalized_old_text)
|
||||
occurrences = fuzzy_content.count(fuzzy_old_text)
|
||||
|
||||
if occurrences > 1:
|
||||
return ToolResult.fail(
|
||||
f"Error: Found {occurrences} occurrences of the text in {path}. "
|
||||
"The text must be unique. Please provide more context to make it unique."
|
||||
)
|
||||
|
||||
# Execute replacement (use matched text position)
|
||||
base_content = match_result.content_for_replacement
|
||||
new_content = (
|
||||
base_content[:match_result.index] +
|
||||
normalized_new_text +
|
||||
base_content[match_result.index + match_result.match_length:]
|
||||
)
|
||||
|
||||
# Verify replacement actually changed content
|
||||
if base_content == new_content:
|
||||
return ToolResult.fail(
|
||||
f"Error: No changes made to {path}. "
|
||||
"The replacement produced identical content. "
|
||||
"This might indicate an issue with special characters or the text not existing as expected."
|
||||
)
|
||||
|
||||
# Restore original line endings
|
||||
final_content = bom + restore_line_endings(new_content, original_ending)
|
||||
|
||||
# Write file
|
||||
with open(absolute_path, 'w', encoding='utf-8') as f:
|
||||
f.write(final_content)
|
||||
|
||||
# Generate diff
|
||||
diff_result = generate_diff_string(base_content, new_content)
|
||||
|
||||
result = {
|
||||
"message": f"Successfully replaced text in {path}",
|
||||
"path": path,
|
||||
"diff": diff_result['diff'],
|
||||
"first_changed_line": diff_result['first_changed_line']
|
||||
}
|
||||
|
||||
# Notify memory manager if file is in memory directory
|
||||
if self.memory_manager and "memory/" in path:
|
||||
try:
|
||||
self.memory_manager.mark_dirty()
|
||||
except Exception as e:
|
||||
# Don't fail the edit if memory notification fails
|
||||
pass
|
||||
|
||||
return ToolResult.success(result)
|
||||
|
||||
except UnicodeDecodeError:
|
||||
return ToolResult.fail(f"Error: File is not a valid text file (encoding error): {path}")
|
||||
except PermissionError:
|
||||
return ToolResult.fail(f"Error: Permission denied accessing {path}")
|
||||
except Exception as e:
|
||||
return ToolResult.fail(f"Error editing file: {str(e)}")
|
||||
|
||||
def _resolve_path(self, path: str) -> str:
|
||||
"""
|
||||
Resolve path to absolute path
|
||||
|
||||
:param path: Relative or absolute path
|
||||
:return: Absolute path
|
||||
"""
|
||||
# Expand ~ to user home directory
|
||||
path = expand_path(path)
|
||||
if os.path.isabs(path):
|
||||
return path
|
||||
return os.path.abspath(os.path.join(self.cwd, path))
|
||||
3
agent/tools/env_config/__init__.py
Normal file
3
agent/tools/env_config/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from agent.tools.env_config.env_config import EnvConfig
|
||||
|
||||
__all__ = ['EnvConfig']
|
||||
286
agent/tools/env_config/env_config.py
Normal file
286
agent/tools/env_config/env_config.py
Normal file
@@ -0,0 +1,286 @@
|
||||
"""
|
||||
Environment Configuration Tool - Manage API keys and environment variables
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
from typing import Dict, Any
|
||||
from pathlib import Path
|
||||
|
||||
from agent.tools.base_tool import BaseTool, ToolResult
|
||||
from common.log import logger
|
||||
from common.utils import expand_path
|
||||
|
||||
|
||||
# API Key 知识库:常见的环境变量及其描述
|
||||
API_KEY_REGISTRY = {
|
||||
# AI 模型服务
|
||||
"OPENAI_API_KEY": "OpenAI API 密钥 (用于GPT模型、Embedding模型)",
|
||||
"GEMINI_API_KEY": "Google Gemini API 密钥",
|
||||
"CLAUDE_API_KEY": "Claude API 密钥 (用于Claude模型)",
|
||||
"LINKAI_API_KEY": "LinkAI智能体平台 API 密钥,支持多种模型切换",
|
||||
# 搜索服务
|
||||
"BOCHA_API_KEY": "博查 AI 搜索 API 密钥 ",
|
||||
}
|
||||
|
||||
class EnvConfig(BaseTool):
|
||||
"""Tool for managing environment variables (API keys, etc.)"""
|
||||
|
||||
name: str = "env_config"
|
||||
description: str = (
|
||||
"Manage API keys and skill configurations securely. "
|
||||
"Use this tool when user wants to configure API keys (like BOCHA_API_KEY, OPENAI_API_KEY), "
|
||||
"view configured keys, or manage skill settings. "
|
||||
"Actions: 'set' (add/update key), 'get' (view specific key), 'list' (show all configured keys), 'delete' (remove key). "
|
||||
"Values are automatically masked for security. Changes take effect immediately via hot reload."
|
||||
)
|
||||
|
||||
params: dict = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"action": {
|
||||
"type": "string",
|
||||
"description": "Action to perform: 'set', 'get', 'list', 'delete'",
|
||||
"enum": ["set", "get", "list", "delete"]
|
||||
},
|
||||
"key": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Environment variable key name. Common keys:\n"
|
||||
"- OPENAI_API_KEY: OpenAI API (GPT models)\n"
|
||||
"- OPENAI_API_BASE: OpenAI API base URL\n"
|
||||
"- CLAUDE_API_KEY: Anthropic Claude API\n"
|
||||
"- GEMINI_API_KEY: Google Gemini API\n"
|
||||
"- LINKAI_API_KEY: LinkAI platform\n"
|
||||
"- BOCHA_API_KEY: Bocha AI search (博查搜索)\n"
|
||||
"Use exact key names (case-sensitive, all uppercase with underscores)"
|
||||
)
|
||||
},
|
||||
"value": {
|
||||
"type": "string",
|
||||
"description": "Value to set for the environment variable (for 'set' action)"
|
||||
}
|
||||
},
|
||||
"required": ["action"]
|
||||
}
|
||||
|
||||
def __init__(self, config: dict = None):
|
||||
self.config = config or {}
|
||||
# Store env config in ~/.cow directory (outside workspace for security)
|
||||
self.env_dir = expand_path("~/.cow")
|
||||
self.env_path = os.path.join(self.env_dir, '.env')
|
||||
self.agent_bridge = self.config.get("agent_bridge") # Reference to AgentBridge for hot reload
|
||||
# Don't create .env file in __init__ to avoid issues during tool discovery
|
||||
# It will be created on first use in execute()
|
||||
|
||||
def _ensure_env_file(self):
|
||||
"""Ensure the .env file exists"""
|
||||
# Create ~/.cow directory if it doesn't exist
|
||||
os.makedirs(self.env_dir, exist_ok=True)
|
||||
|
||||
if not os.path.exists(self.env_path):
|
||||
Path(self.env_path).touch()
|
||||
logger.info(f"[EnvConfig] Created .env file at {self.env_path}")
|
||||
|
||||
def _mask_value(self, value: str) -> str:
|
||||
"""Mask sensitive parts of a value for logging"""
|
||||
if not value or len(value) <= 10:
|
||||
return "***"
|
||||
return f"{value[:6]}***{value[-4:]}"
|
||||
|
||||
def _read_env_file(self) -> Dict[str, str]:
|
||||
"""Read all key-value pairs from .env file"""
|
||||
env_vars = {}
|
||||
if os.path.exists(self.env_path):
|
||||
with open(self.env_path, 'r', encoding='utf-8') as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
# Skip empty lines and comments
|
||||
if not line or line.startswith('#'):
|
||||
continue
|
||||
# Parse KEY=VALUE
|
||||
match = re.match(r'^([^=]+)=(.*)$', line)
|
||||
if match:
|
||||
key, value = match.groups()
|
||||
env_vars[key.strip()] = value.strip()
|
||||
return env_vars
|
||||
|
||||
def _write_env_file(self, env_vars: Dict[str, str]):
|
||||
"""Write all key-value pairs to .env file"""
|
||||
with open(self.env_path, 'w', encoding='utf-8') as f:
|
||||
f.write("# Environment variables for agent skills\n")
|
||||
f.write("# Auto-managed by env_config tool\n\n")
|
||||
for key, value in sorted(env_vars.items()):
|
||||
f.write(f"{key}={value}\n")
|
||||
|
||||
def _reload_env(self):
|
||||
"""Reload environment variables from .env file"""
|
||||
env_vars = self._read_env_file()
|
||||
for key, value in env_vars.items():
|
||||
os.environ[key] = value
|
||||
logger.debug(f"[EnvConfig] Reloaded {len(env_vars)} environment variables")
|
||||
|
||||
def _refresh_skills(self):
|
||||
"""Refresh skills after environment variable changes"""
|
||||
if self.agent_bridge:
|
||||
try:
|
||||
# Reload .env file
|
||||
self._reload_env()
|
||||
|
||||
# Refresh skills in all agent instances
|
||||
refreshed = self.agent_bridge.refresh_all_skills()
|
||||
logger.info(f"[EnvConfig] Refreshed skills in {refreshed} agent instance(s)")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.warning(f"[EnvConfig] Failed to refresh skills: {e}")
|
||||
return False
|
||||
return False
|
||||
|
||||
def execute(self, args: Dict[str, Any]) -> ToolResult:
|
||||
"""
|
||||
Execute environment configuration operation
|
||||
|
||||
:param args: Contains action, key, and value parameters
|
||||
:return: Result of the operation
|
||||
"""
|
||||
# Ensure .env file exists on first use
|
||||
self._ensure_env_file()
|
||||
|
||||
action = args.get("action")
|
||||
key = args.get("key")
|
||||
value = args.get("value")
|
||||
|
||||
try:
|
||||
if action == "set":
|
||||
if not key or not value:
|
||||
return ToolResult.fail("Error: 'key' and 'value' are required for 'set' action.")
|
||||
|
||||
# Read current env vars
|
||||
env_vars = self._read_env_file()
|
||||
|
||||
# Update the key
|
||||
env_vars[key] = value
|
||||
|
||||
# Write back to file
|
||||
self._write_env_file(env_vars)
|
||||
|
||||
# Update current process env
|
||||
os.environ[key] = value
|
||||
|
||||
logger.info(f"[EnvConfig] Set {key}={self._mask_value(value)}")
|
||||
|
||||
# Try to refresh skills immediately
|
||||
refreshed = self._refresh_skills()
|
||||
|
||||
result = {
|
||||
"message": f"Successfully set {key}",
|
||||
"key": key,
|
||||
"value": self._mask_value(value),
|
||||
}
|
||||
|
||||
if refreshed:
|
||||
result["note"] = "✅ Skills refreshed automatically - changes are now active"
|
||||
else:
|
||||
result["note"] = "⚠️ Skills not refreshed - restart agent to load new skills"
|
||||
|
||||
return ToolResult.success(result)
|
||||
|
||||
elif action == "get":
|
||||
if not key:
|
||||
return ToolResult.fail("Error: 'key' is required for 'get' action.")
|
||||
|
||||
# Check in file first, then in current env
|
||||
env_vars = self._read_env_file()
|
||||
value = env_vars.get(key) or os.getenv(key)
|
||||
|
||||
# Get description from registry
|
||||
description = API_KEY_REGISTRY.get(key, "未知用途的环境变量")
|
||||
|
||||
if value is not None:
|
||||
logger.info(f"[EnvConfig] Got {key}={self._mask_value(value)}")
|
||||
return ToolResult.success({
|
||||
"key": key,
|
||||
"value": self._mask_value(value),
|
||||
"description": description,
|
||||
"exists": True,
|
||||
"note": f"Value is masked for security. In bash, use ${key} directly — it is auto-injected."
|
||||
})
|
||||
else:
|
||||
return ToolResult.success({
|
||||
"key": key,
|
||||
"description": description,
|
||||
"exists": False,
|
||||
"message": f"Environment variable '{key}' is not set"
|
||||
})
|
||||
|
||||
elif action == "list":
|
||||
env_vars = self._read_env_file()
|
||||
|
||||
# Build detailed variable list with descriptions
|
||||
variables_with_info = {}
|
||||
for key, value in env_vars.items():
|
||||
variables_with_info[key] = {
|
||||
"value": self._mask_value(value),
|
||||
"description": API_KEY_REGISTRY.get(key, "未知用途的环境变量")
|
||||
}
|
||||
|
||||
logger.info(f"[EnvConfig] Listed {len(env_vars)} environment variables")
|
||||
|
||||
if not env_vars:
|
||||
return ToolResult.success({
|
||||
"message": "No environment variables configured",
|
||||
"variables": {},
|
||||
"note": "常用的 API 密钥可以通过 env_config(action='set', key='KEY_NAME', value='your-key') 来配置"
|
||||
})
|
||||
|
||||
return ToolResult.success({
|
||||
"message": f"Found {len(env_vars)} environment variable(s)",
|
||||
"variables": variables_with_info
|
||||
})
|
||||
|
||||
elif action == "delete":
|
||||
if not key:
|
||||
return ToolResult.fail("Error: 'key' is required for 'delete' action.")
|
||||
|
||||
# Read current env vars
|
||||
env_vars = self._read_env_file()
|
||||
|
||||
if key not in env_vars:
|
||||
return ToolResult.success({
|
||||
"message": f"Environment variable '{key}' was not set",
|
||||
"key": key
|
||||
})
|
||||
|
||||
# Remove the key
|
||||
del env_vars[key]
|
||||
|
||||
# Write back to file
|
||||
self._write_env_file(env_vars)
|
||||
|
||||
# Remove from current process env
|
||||
if key in os.environ:
|
||||
del os.environ[key]
|
||||
|
||||
logger.info(f"[EnvConfig] Deleted {key}")
|
||||
|
||||
# Try to refresh skills immediately
|
||||
refreshed = self._refresh_skills()
|
||||
|
||||
result = {
|
||||
"message": f"Successfully deleted {key}",
|
||||
"key": key,
|
||||
}
|
||||
|
||||
if refreshed:
|
||||
result["note"] = "✅ Skills refreshed automatically - changes are now active"
|
||||
else:
|
||||
result["note"] = "⚠️ Skills not refreshed - restart agent to apply changes"
|
||||
|
||||
return ToolResult.success(result)
|
||||
|
||||
else:
|
||||
return ToolResult.fail(f"Error: Unknown action '{action}'. Use 'set', 'get', 'list', or 'delete'.")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[EnvConfig] Error: {e}", exc_info=True)
|
||||
return ToolResult.fail(f"EnvConfig tool error: {str(e)}")
|
||||
3
agent/tools/ls/__init__.py
Normal file
3
agent/tools/ls/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .ls import Ls
|
||||
|
||||
__all__ = ['Ls']
|
||||
140
agent/tools/ls/ls.py
Normal file
140
agent/tools/ls/ls.py
Normal file
@@ -0,0 +1,140 @@
|
||||
"""
|
||||
Ls tool - List directory contents
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Dict, Any
|
||||
|
||||
from agent.tools.base_tool import BaseTool, ToolResult
|
||||
from agent.tools.utils.truncate import truncate_head, format_size, DEFAULT_MAX_BYTES
|
||||
from common.utils import expand_path
|
||||
|
||||
|
||||
DEFAULT_LIMIT = 500
|
||||
|
||||
|
||||
class Ls(BaseTool):
|
||||
"""Tool for listing directory contents"""
|
||||
|
||||
name: str = "ls"
|
||||
description: str = f"List directory contents. Returns entries sorted alphabetically, with '/' suffix for directories. Includes dotfiles. Output is truncated to {DEFAULT_LIMIT} entries or {DEFAULT_MAX_BYTES // 1024}KB (whichever is hit first)."
|
||||
|
||||
params: dict = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "Directory to list. IMPORTANT: Relative paths are based on workspace directory. To access directories outside workspace, use absolute paths starting with ~ or /."
|
||||
},
|
||||
"limit": {
|
||||
"type": "integer",
|
||||
"description": f"Maximum number of entries to return (default: {DEFAULT_LIMIT})"
|
||||
}
|
||||
},
|
||||
"required": []
|
||||
}
|
||||
|
||||
def __init__(self, config: dict = None):
|
||||
self.config = config or {}
|
||||
self.cwd = self.config.get("cwd", os.getcwd())
|
||||
|
||||
def execute(self, args: Dict[str, Any]) -> ToolResult:
|
||||
"""
|
||||
Execute directory listing
|
||||
|
||||
:param args: Listing parameters
|
||||
:return: Directory contents or error
|
||||
"""
|
||||
path = args.get("path", ".").strip()
|
||||
limit = args.get("limit", DEFAULT_LIMIT)
|
||||
|
||||
# Resolve path
|
||||
absolute_path = self._resolve_path(path)
|
||||
|
||||
# Security check: Prevent accessing sensitive config directory
|
||||
env_config_dir = expand_path("~/.cow")
|
||||
if os.path.abspath(absolute_path) == os.path.abspath(env_config_dir):
|
||||
return ToolResult.fail(
|
||||
"Error: Access denied. API keys and credentials must be accessed through the env_config tool only."
|
||||
)
|
||||
|
||||
if not os.path.exists(absolute_path):
|
||||
# Provide helpful hint if using relative path
|
||||
if not os.path.isabs(path) and not path.startswith('~'):
|
||||
return ToolResult.fail(
|
||||
f"Error: Path not found: {path}\n"
|
||||
f"Resolved to: {absolute_path}\n"
|
||||
f"Hint: Relative paths are based on workspace ({self.cwd}). For files outside workspace, use absolute paths."
|
||||
)
|
||||
return ToolResult.fail(f"Error: Path not found: {path}")
|
||||
|
||||
if not os.path.isdir(absolute_path):
|
||||
return ToolResult.fail(f"Error: Not a directory: {path}")
|
||||
|
||||
try:
|
||||
# Read directory entries
|
||||
entries = os.listdir(absolute_path)
|
||||
|
||||
# Sort alphabetically (case-insensitive)
|
||||
entries.sort(key=lambda x: x.lower())
|
||||
|
||||
# Format entries with directory indicators
|
||||
results = []
|
||||
entry_limit_reached = False
|
||||
|
||||
for entry in entries:
|
||||
if len(results) >= limit:
|
||||
entry_limit_reached = True
|
||||
break
|
||||
|
||||
full_path = os.path.join(absolute_path, entry)
|
||||
|
||||
try:
|
||||
if os.path.isdir(full_path):
|
||||
results.append(entry + '/')
|
||||
else:
|
||||
results.append(entry)
|
||||
except Exception:
|
||||
# Skip entries we can't stat
|
||||
continue
|
||||
|
||||
if not results:
|
||||
return ToolResult.success({"message": "(empty directory)", "entries": []})
|
||||
|
||||
# Format output
|
||||
raw_output = '\n'.join(results)
|
||||
truncation = truncate_head(raw_output, max_lines=999999) # Only limit by bytes
|
||||
|
||||
output = truncation.content
|
||||
details = {}
|
||||
notices = []
|
||||
|
||||
if entry_limit_reached:
|
||||
notices.append(f"{limit} entries limit reached. Use limit={limit * 2} for more")
|
||||
details["entry_limit_reached"] = limit
|
||||
|
||||
if truncation.truncated:
|
||||
notices.append(f"{format_size(DEFAULT_MAX_BYTES)} limit reached")
|
||||
details["truncation"] = truncation.to_dict()
|
||||
|
||||
if notices:
|
||||
output += f"\n\n[{'. '.join(notices)}]"
|
||||
|
||||
return ToolResult.success({
|
||||
"output": output,
|
||||
"entry_count": len(results),
|
||||
"details": details if details else None
|
||||
})
|
||||
|
||||
except PermissionError:
|
||||
return ToolResult.fail(f"Error: Permission denied reading directory: {path}")
|
||||
except Exception as e:
|
||||
return ToolResult.fail(f"Error listing directory: {str(e)}")
|
||||
|
||||
def _resolve_path(self, path: str) -> str:
|
||||
"""Resolve path to absolute path"""
|
||||
# Expand ~ to user home directory
|
||||
path = expand_path(path)
|
||||
if os.path.isabs(path):
|
||||
return path
|
||||
return os.path.abspath(os.path.join(self.cwd, path))
|
||||
10
agent/tools/memory/__init__.py
Normal file
10
agent/tools/memory/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
||||
"""
|
||||
Memory tools for Agent
|
||||
|
||||
Provides memory_search and memory_get tools
|
||||
"""
|
||||
|
||||
from agent.tools.memory.memory_search import MemorySearchTool
|
||||
from agent.tools.memory.memory_get import MemoryGetTool
|
||||
|
||||
__all__ = ['MemorySearchTool', 'MemoryGetTool']
|
||||
111
agent/tools/memory/memory_get.py
Normal file
111
agent/tools/memory/memory_get.py
Normal file
@@ -0,0 +1,111 @@
|
||||
"""
|
||||
Memory get tool
|
||||
|
||||
Allows agents to read specific sections from memory files
|
||||
"""
|
||||
|
||||
from agent.tools.base_tool import BaseTool
|
||||
|
||||
|
||||
class MemoryGetTool(BaseTool):
|
||||
"""Tool for reading memory file contents"""
|
||||
|
||||
name: str = "memory_get"
|
||||
description: str = (
|
||||
"Read specific content from memory files. "
|
||||
"Use this to get full context from a memory file or specific line range."
|
||||
)
|
||||
params: dict = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "Relative path to the memory file (e.g. 'MEMORY.md', 'memory/2026-01-01.md')"
|
||||
},
|
||||
"start_line": {
|
||||
"type": "integer",
|
||||
"description": "Starting line number (optional, default: 1)",
|
||||
"default": 1
|
||||
},
|
||||
"num_lines": {
|
||||
"type": "integer",
|
||||
"description": "Number of lines to read (optional, reads all if not specified)"
|
||||
}
|
||||
},
|
||||
"required": ["path"]
|
||||
}
|
||||
|
||||
def __init__(self, memory_manager):
|
||||
"""
|
||||
Initialize memory get tool
|
||||
|
||||
Args:
|
||||
memory_manager: MemoryManager instance
|
||||
"""
|
||||
super().__init__()
|
||||
self.memory_manager = memory_manager
|
||||
|
||||
def execute(self, args: dict):
|
||||
"""
|
||||
Execute memory file read
|
||||
|
||||
Args:
|
||||
args: Dictionary with path, start_line, num_lines
|
||||
|
||||
Returns:
|
||||
ToolResult with file content
|
||||
"""
|
||||
from agent.tools.base_tool import ToolResult
|
||||
|
||||
path = args.get("path")
|
||||
start_line = args.get("start_line", 1)
|
||||
num_lines = args.get("num_lines")
|
||||
|
||||
if not path:
|
||||
return ToolResult.fail("Error: path parameter is required")
|
||||
|
||||
try:
|
||||
workspace_dir = self.memory_manager.config.get_workspace()
|
||||
|
||||
# Auto-prepend memory/ if not present and not absolute path
|
||||
# Exception: MEMORY.md is in the root directory
|
||||
if not path.startswith('memory/') and not path.startswith('/') and path != 'MEMORY.md':
|
||||
path = f'memory/{path}'
|
||||
|
||||
file_path = workspace_dir / path
|
||||
|
||||
if not file_path.exists():
|
||||
return ToolResult.fail(f"Error: File not found: {path}")
|
||||
|
||||
content = file_path.read_text(encoding='utf-8')
|
||||
lines = content.split('\n')
|
||||
|
||||
# Handle line range
|
||||
if start_line < 1:
|
||||
start_line = 1
|
||||
|
||||
start_idx = start_line - 1
|
||||
|
||||
if num_lines:
|
||||
end_idx = start_idx + num_lines
|
||||
selected_lines = lines[start_idx:end_idx]
|
||||
else:
|
||||
selected_lines = lines[start_idx:]
|
||||
|
||||
result = '\n'.join(selected_lines)
|
||||
|
||||
# Add metadata
|
||||
total_lines = len(lines)
|
||||
shown_lines = len(selected_lines)
|
||||
|
||||
output = [
|
||||
f"File: {path}",
|
||||
f"Lines: {start_line}-{start_line + shown_lines - 1} (total: {total_lines})",
|
||||
"",
|
||||
result
|
||||
]
|
||||
|
||||
return ToolResult.success('\n'.join(output))
|
||||
|
||||
except Exception as e:
|
||||
return ToolResult.fail(f"Error reading memory file: {str(e)}")
|
||||
102
agent/tools/memory/memory_search.py
Normal file
102
agent/tools/memory/memory_search.py
Normal file
@@ -0,0 +1,102 @@
|
||||
"""
|
||||
Memory search tool
|
||||
|
||||
Allows agents to search their memory using semantic and keyword search
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, Optional
|
||||
from agent.tools.base_tool import BaseTool
|
||||
|
||||
|
||||
class MemorySearchTool(BaseTool):
|
||||
"""Tool for searching agent memory"""
|
||||
|
||||
name: str = "memory_search"
|
||||
description: str = (
|
||||
"Search agent's long-term memory using semantic and keyword search. "
|
||||
"Use this to recall past conversations, preferences, and knowledge."
|
||||
)
|
||||
params: dict = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "Search query (can be natural language question or keywords)"
|
||||
},
|
||||
"max_results": {
|
||||
"type": "integer",
|
||||
"description": "Maximum number of results to return (default: 10)",
|
||||
"default": 10
|
||||
},
|
||||
"min_score": {
|
||||
"type": "number",
|
||||
"description": "Minimum relevance score (0-1, default: 0.1)",
|
||||
"default": 0.1
|
||||
}
|
||||
},
|
||||
"required": ["query"]
|
||||
}
|
||||
|
||||
def __init__(self, memory_manager, user_id: Optional[str] = None):
|
||||
"""
|
||||
Initialize memory search tool
|
||||
|
||||
Args:
|
||||
memory_manager: MemoryManager instance
|
||||
user_id: Optional user ID for scoped search
|
||||
"""
|
||||
super().__init__()
|
||||
self.memory_manager = memory_manager
|
||||
self.user_id = user_id
|
||||
|
||||
def execute(self, args: dict):
|
||||
"""
|
||||
Execute memory search
|
||||
|
||||
Args:
|
||||
args: Dictionary with query, max_results, min_score
|
||||
|
||||
Returns:
|
||||
ToolResult with formatted search results
|
||||
"""
|
||||
from agent.tools.base_tool import ToolResult
|
||||
import asyncio
|
||||
|
||||
query = args.get("query")
|
||||
max_results = args.get("max_results", 10)
|
||||
min_score = args.get("min_score", 0.1)
|
||||
|
||||
if not query:
|
||||
return ToolResult.fail("Error: query parameter is required")
|
||||
|
||||
try:
|
||||
# Run async search in sync context
|
||||
results = asyncio.run(self.memory_manager.search(
|
||||
query=query,
|
||||
user_id=self.user_id,
|
||||
max_results=max_results,
|
||||
min_score=min_score,
|
||||
include_shared=True
|
||||
))
|
||||
|
||||
if not results:
|
||||
# Return clear message that no memories exist yet
|
||||
# This prevents infinite retry loops
|
||||
return ToolResult.success(
|
||||
f"No memories found for '{query}'. "
|
||||
f"This is normal if no memories have been stored yet. "
|
||||
f"You can store new memories by writing to MEMORY.md or memory/YYYY-MM-DD.md files."
|
||||
)
|
||||
|
||||
# Format results
|
||||
output = [f"Found {len(results)} relevant memories:\n"]
|
||||
|
||||
for i, result in enumerate(results, 1):
|
||||
output.append(f"\n{i}. {result.path} (lines {result.start_line}-{result.end_line})")
|
||||
output.append(f" Score: {result.score:.3f}")
|
||||
output.append(f" Snippet: {result.snippet}")
|
||||
|
||||
return ToolResult.success("\n".join(output))
|
||||
|
||||
except Exception as e:
|
||||
return ToolResult.fail(f"Error searching memory: {str(e)}")
|
||||
3
agent/tools/read/__init__.py
Normal file
3
agent/tools/read/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .read import Read
|
||||
|
||||
__all__ = ['Read']
|
||||
443
agent/tools/read/read.py
Normal file
443
agent/tools/read/read.py
Normal file
@@ -0,0 +1,443 @@
|
||||
"""
|
||||
Read tool - Read file contents
|
||||
Supports text files, images (jpg, png, gif, webp), and PDF files
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Dict, Any
|
||||
from pathlib import Path
|
||||
|
||||
from agent.tools.base_tool import BaseTool, ToolResult
|
||||
from agent.tools.utils.truncate import truncate_head, format_size, DEFAULT_MAX_LINES, DEFAULT_MAX_BYTES
|
||||
from common.utils import expand_path
|
||||
|
||||
|
||||
class Read(BaseTool):
|
||||
"""Tool for reading file contents"""
|
||||
|
||||
name: str = "read"
|
||||
description: str = f"Read or inspect file contents. For text/PDF files, returns content (truncated to {DEFAULT_MAX_LINES} lines or {DEFAULT_MAX_BYTES // 1024}KB). For images/videos/audio, returns metadata only (file info, size, type). Use offset/limit for large text files."
|
||||
|
||||
params: dict = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "Path to the file to read. IMPORTANT: Relative paths are based on workspace directory. To access files outside workspace, use absolute paths starting with ~ or /."
|
||||
},
|
||||
"offset": {
|
||||
"type": "integer",
|
||||
"description": "Line number to start reading from (1-indexed, optional). Use negative values to read from end (e.g. -20 for last 20 lines)"
|
||||
},
|
||||
"limit": {
|
||||
"type": "integer",
|
||||
"description": "Maximum number of lines to read (optional)"
|
||||
}
|
||||
},
|
||||
"required": ["path"]
|
||||
}
|
||||
|
||||
def __init__(self, config: dict = None):
|
||||
self.config = config or {}
|
||||
self.cwd = self.config.get("cwd", os.getcwd())
|
||||
|
||||
# File type categories
|
||||
self.image_extensions = {'.jpg', '.jpeg', '.png', '.gif', '.webp', '.bmp', '.svg', '.ico'}
|
||||
self.video_extensions = {'.mp4', '.avi', '.mov', '.mkv', '.flv', '.wmv', '.webm', '.m4v'}
|
||||
self.audio_extensions = {'.mp3', '.wav', '.ogg', '.m4a', '.flac', '.aac', '.wma'}
|
||||
self.binary_extensions = {'.exe', '.dll', '.so', '.dylib', '.bin', '.dat', '.db', '.sqlite'}
|
||||
self.archive_extensions = {'.zip', '.tar', '.gz', '.rar', '.7z', '.bz2', '.xz'}
|
||||
self.pdf_extensions = {'.pdf'}
|
||||
|
||||
# Readable text formats (will be read with truncation)
|
||||
self.text_extensions = {
|
||||
'.txt', '.md', '.markdown', '.rst', '.log', '.csv', '.tsv', '.json', '.xml', '.yaml', '.yml',
|
||||
'.py', '.js', '.ts', '.java', '.c', '.cpp', '.h', '.hpp', '.go', '.rs', '.rb', '.php',
|
||||
'.html', '.css', '.scss', '.sass', '.less', '.vue', '.jsx', '.tsx',
|
||||
'.sh', '.bash', '.zsh', '.fish', '.ps1', '.bat', '.cmd',
|
||||
'.sql', '.r', '.m', '.swift', '.kt', '.scala', '.clj', '.erl', '.ex',
|
||||
'.dockerfile', '.makefile', '.cmake', '.gradle', '.properties', '.ini', '.conf', '.cfg',
|
||||
'.doc', '.docx', '.xls', '.xlsx', '.ppt', '.pptx' # Office documents
|
||||
}
|
||||
|
||||
def execute(self, args: Dict[str, Any]) -> ToolResult:
|
||||
"""
|
||||
Execute file read operation
|
||||
|
||||
:param args: Contains file path and optional offset/limit parameters
|
||||
:return: File content or error message
|
||||
"""
|
||||
# Support 'location' as alias for 'path' (LLM may use it from skill listing)
|
||||
path = args.get("path", "") or args.get("location", "")
|
||||
path = path.strip() if isinstance(path, str) else ""
|
||||
offset = args.get("offset")
|
||||
limit = args.get("limit")
|
||||
|
||||
if not path:
|
||||
return ToolResult.fail("Error: path parameter is required")
|
||||
|
||||
# Resolve path
|
||||
absolute_path = self._resolve_path(path)
|
||||
|
||||
# Security check: Prevent reading sensitive config files
|
||||
env_config_path = expand_path("~/.cow/.env")
|
||||
if os.path.abspath(absolute_path) == os.path.abspath(env_config_path):
|
||||
return ToolResult.fail(
|
||||
"Error: Access denied. API keys and credentials must be accessed through the env_config tool only."
|
||||
)
|
||||
|
||||
# Check if file exists
|
||||
if not os.path.exists(absolute_path):
|
||||
# Provide helpful hint if using relative path
|
||||
if not os.path.isabs(path) and not path.startswith('~'):
|
||||
return ToolResult.fail(
|
||||
f"Error: File not found: {path}\n"
|
||||
f"Resolved to: {absolute_path}\n"
|
||||
f"Hint: Relative paths are based on workspace ({self.cwd}). For files outside workspace, use absolute paths."
|
||||
)
|
||||
return ToolResult.fail(f"Error: File not found: {path}")
|
||||
|
||||
# Check if readable
|
||||
if not os.access(absolute_path, os.R_OK):
|
||||
return ToolResult.fail(f"Error: File is not readable: {path}")
|
||||
|
||||
# Check file type
|
||||
file_ext = Path(absolute_path).suffix.lower()
|
||||
file_size = os.path.getsize(absolute_path)
|
||||
|
||||
# Check if image - return metadata for sending
|
||||
if file_ext in self.image_extensions:
|
||||
return self._read_image(absolute_path, file_ext)
|
||||
|
||||
# Check if video/audio/binary/archive - return metadata only
|
||||
if file_ext in self.video_extensions:
|
||||
return self._return_file_metadata(absolute_path, "video", file_size)
|
||||
if file_ext in self.audio_extensions:
|
||||
return self._return_file_metadata(absolute_path, "audio", file_size)
|
||||
if file_ext in self.binary_extensions or file_ext in self.archive_extensions:
|
||||
return self._return_file_metadata(absolute_path, "binary", file_size)
|
||||
|
||||
# Check if PDF
|
||||
if file_ext in self.pdf_extensions:
|
||||
return self._read_pdf(absolute_path, path, offset, limit)
|
||||
|
||||
# Read text file (with truncation for large files)
|
||||
return self._read_text(absolute_path, path, offset, limit)
|
||||
|
||||
def _resolve_path(self, path: str) -> str:
|
||||
"""
|
||||
Resolve path to absolute path
|
||||
|
||||
:param path: Relative or absolute path
|
||||
:return: Absolute path
|
||||
"""
|
||||
# Expand ~ to user home directory
|
||||
path = expand_path(path)
|
||||
if os.path.isabs(path):
|
||||
return path
|
||||
return os.path.abspath(os.path.join(self.cwd, path))
|
||||
|
||||
def _return_file_metadata(self, absolute_path: str, file_type: str, file_size: int) -> ToolResult:
|
||||
"""
|
||||
Return file metadata for non-readable files (video, audio, binary, etc.)
|
||||
|
||||
:param absolute_path: Absolute path to the file
|
||||
:param file_type: Type of file (video, audio, binary, etc.)
|
||||
:param file_size: File size in bytes
|
||||
:return: File metadata
|
||||
"""
|
||||
file_name = Path(absolute_path).name
|
||||
file_ext = Path(absolute_path).suffix.lower()
|
||||
|
||||
# Determine MIME type
|
||||
mime_types = {
|
||||
# Video
|
||||
'.mp4': 'video/mp4', '.avi': 'video/x-msvideo', '.mov': 'video/quicktime',
|
||||
'.mkv': 'video/x-matroska', '.webm': 'video/webm',
|
||||
# Audio
|
||||
'.mp3': 'audio/mpeg', '.wav': 'audio/wav', '.ogg': 'audio/ogg',
|
||||
'.m4a': 'audio/mp4', '.flac': 'audio/flac',
|
||||
# Binary
|
||||
'.zip': 'application/zip', '.tar': 'application/x-tar',
|
||||
'.gz': 'application/gzip', '.rar': 'application/x-rar-compressed',
|
||||
}
|
||||
mime_type = mime_types.get(file_ext, 'application/octet-stream')
|
||||
|
||||
result = {
|
||||
"type": f"{file_type}_metadata",
|
||||
"file_type": file_type,
|
||||
"path": absolute_path,
|
||||
"file_name": file_name,
|
||||
"mime_type": mime_type,
|
||||
"size": file_size,
|
||||
"size_formatted": format_size(file_size),
|
||||
"message": f"{file_type.capitalize()} 文件: {file_name} ({format_size(file_size)})\n提示: 如果需要发送此文件,请使用 send 工具。"
|
||||
}
|
||||
|
||||
return ToolResult.success(result)
|
||||
|
||||
def _read_image(self, absolute_path: str, file_ext: str) -> ToolResult:
|
||||
"""
|
||||
Read image file - always return metadata only (images should be sent, not read into context)
|
||||
|
||||
:param absolute_path: Absolute path to the image file
|
||||
:param file_ext: File extension
|
||||
:return: Result containing image metadata for sending
|
||||
"""
|
||||
try:
|
||||
# Get file size
|
||||
file_size = os.path.getsize(absolute_path)
|
||||
|
||||
# Determine MIME type
|
||||
mime_type_map = {
|
||||
'.jpg': 'image/jpeg',
|
||||
'.jpeg': 'image/jpeg',
|
||||
'.png': 'image/png',
|
||||
'.gif': 'image/gif',
|
||||
'.webp': 'image/webp'
|
||||
}
|
||||
mime_type = mime_type_map.get(file_ext, 'image/jpeg')
|
||||
|
||||
# Return metadata for images (NOT file_to_send - use send tool to actually send)
|
||||
result = {
|
||||
"type": "image_metadata",
|
||||
"file_type": "image",
|
||||
"path": absolute_path,
|
||||
"mime_type": mime_type,
|
||||
"size": file_size,
|
||||
"size_formatted": format_size(file_size),
|
||||
"message": f"图片文件: {Path(absolute_path).name} ({format_size(file_size)})\n提示: 如果需要发送此图片,请使用 send 工具。"
|
||||
}
|
||||
|
||||
return ToolResult.success(result)
|
||||
|
||||
except Exception as e:
|
||||
return ToolResult.fail(f"Error reading image file: {str(e)}")
|
||||
|
||||
def _read_text(self, absolute_path: str, display_path: str, offset: int = None, limit: int = None) -> ToolResult:
|
||||
"""
|
||||
Read text file
|
||||
|
||||
:param absolute_path: Absolute path to the file
|
||||
:param display_path: Path to display
|
||||
:param offset: Starting line number (1-indexed)
|
||||
:param limit: Maximum number of lines to read
|
||||
:return: File content or error message
|
||||
"""
|
||||
try:
|
||||
# Check file size first
|
||||
file_size = os.path.getsize(absolute_path)
|
||||
MAX_FILE_SIZE = 50 * 1024 * 1024 # 50MB
|
||||
|
||||
if file_size > MAX_FILE_SIZE:
|
||||
# File too large, return metadata only
|
||||
return ToolResult.success({
|
||||
"type": "file_to_send",
|
||||
"file_type": "document",
|
||||
"path": absolute_path,
|
||||
"size": file_size,
|
||||
"size_formatted": format_size(file_size),
|
||||
"message": f"文件过大 ({format_size(file_size)} > 50MB),无法读取内容。文件路径: {absolute_path}"
|
||||
})
|
||||
|
||||
# Read file
|
||||
with open(absolute_path, 'r', encoding='utf-8') as f:
|
||||
content = f.read()
|
||||
|
||||
# Truncate content if too long (20K characters max for model context)
|
||||
MAX_CONTENT_CHARS = 20 * 1024 # 20K characters
|
||||
content_truncated = False
|
||||
if len(content) > MAX_CONTENT_CHARS:
|
||||
content = content[:MAX_CONTENT_CHARS]
|
||||
content_truncated = True
|
||||
|
||||
all_lines = content.split('\n')
|
||||
total_file_lines = len(all_lines)
|
||||
|
||||
# Apply offset (if specified)
|
||||
start_line = 0
|
||||
if offset is not None:
|
||||
if offset < 0:
|
||||
# Negative offset: read from end
|
||||
# -20 means "last 20 lines" → start from (total - 20)
|
||||
start_line = max(0, total_file_lines + offset)
|
||||
else:
|
||||
# Positive offset: read from start (1-indexed)
|
||||
start_line = max(0, offset - 1) # Convert to 0-indexed
|
||||
if start_line >= total_file_lines:
|
||||
return ToolResult.fail(
|
||||
f"Error: Offset {offset} is beyond end of file ({total_file_lines} lines total)"
|
||||
)
|
||||
|
||||
start_line_display = start_line + 1 # For display (1-indexed)
|
||||
|
||||
# If user specified limit, use it
|
||||
selected_content = content
|
||||
user_limited_lines = None
|
||||
if limit is not None:
|
||||
end_line = min(start_line + limit, total_file_lines)
|
||||
selected_content = '\n'.join(all_lines[start_line:end_line])
|
||||
user_limited_lines = end_line - start_line
|
||||
elif offset is not None:
|
||||
selected_content = '\n'.join(all_lines[start_line:])
|
||||
|
||||
# Apply truncation (considering line count and byte limits)
|
||||
truncation = truncate_head(selected_content)
|
||||
|
||||
output_text = ""
|
||||
details = {}
|
||||
|
||||
# Add truncation warning if content was truncated
|
||||
if content_truncated:
|
||||
output_text = f"[文件内容已截断到前 {format_size(MAX_CONTENT_CHARS)},完整文件大小: {format_size(file_size)}]\n\n"
|
||||
|
||||
if truncation.first_line_exceeds_limit:
|
||||
# First line exceeds 30KB limit
|
||||
first_line_size = format_size(len(all_lines[start_line].encode('utf-8')))
|
||||
output_text = f"[Line {start_line_display} is {first_line_size}, exceeds {format_size(DEFAULT_MAX_BYTES)} limit. Use bash tool to read: head -c {DEFAULT_MAX_BYTES} {display_path} | tail -n +{start_line_display}]"
|
||||
details["truncation"] = truncation.to_dict()
|
||||
elif truncation.truncated:
|
||||
# Truncation occurred
|
||||
end_line_display = start_line_display + truncation.output_lines - 1
|
||||
next_offset = end_line_display + 1
|
||||
|
||||
output_text = truncation.content
|
||||
|
||||
if truncation.truncated_by == "lines":
|
||||
output_text += f"\n\n[Showing lines {start_line_display}-{end_line_display} of {total_file_lines}. Use offset={next_offset} to continue.]"
|
||||
else:
|
||||
output_text += f"\n\n[Showing lines {start_line_display}-{end_line_display} of {total_file_lines} ({format_size(DEFAULT_MAX_BYTES)} limit). Use offset={next_offset} to continue.]"
|
||||
|
||||
details["truncation"] = truncation.to_dict()
|
||||
elif user_limited_lines is not None and start_line + user_limited_lines < total_file_lines:
|
||||
# User specified limit, more content available, but no truncation
|
||||
remaining = total_file_lines - (start_line + user_limited_lines)
|
||||
next_offset = start_line + user_limited_lines + 1
|
||||
|
||||
output_text = truncation.content
|
||||
output_text += f"\n\n[{remaining} more lines in file. Use offset={next_offset} to continue.]"
|
||||
else:
|
||||
# No truncation, no exceeding user limit
|
||||
output_text = truncation.content
|
||||
|
||||
result = {
|
||||
"content": output_text,
|
||||
"total_lines": total_file_lines,
|
||||
"start_line": start_line_display,
|
||||
"output_lines": truncation.output_lines
|
||||
}
|
||||
|
||||
if details:
|
||||
result["details"] = details
|
||||
|
||||
return ToolResult.success(result)
|
||||
|
||||
except UnicodeDecodeError:
|
||||
return ToolResult.fail(f"Error: File is not a valid text file (encoding error): {display_path}")
|
||||
except Exception as e:
|
||||
return ToolResult.fail(f"Error reading file: {str(e)}")
|
||||
|
||||
def _read_pdf(self, absolute_path: str, display_path: str, offset: int = None, limit: int = None) -> ToolResult:
|
||||
"""
|
||||
Read PDF file content
|
||||
|
||||
:param absolute_path: Absolute path to the file
|
||||
:param display_path: Path to display
|
||||
:param offset: Starting line number (1-indexed)
|
||||
:param limit: Maximum number of lines to read
|
||||
:return: PDF text content or error message
|
||||
"""
|
||||
try:
|
||||
# Try to import pypdf
|
||||
try:
|
||||
from pypdf import PdfReader
|
||||
except ImportError:
|
||||
return ToolResult.fail(
|
||||
"Error: pypdf library not installed. Install with: pip install pypdf"
|
||||
)
|
||||
|
||||
# Read PDF
|
||||
reader = PdfReader(absolute_path)
|
||||
total_pages = len(reader.pages)
|
||||
|
||||
# Extract text from all pages
|
||||
text_parts = []
|
||||
for page_num, page in enumerate(reader.pages, 1):
|
||||
page_text = page.extract_text()
|
||||
if page_text.strip():
|
||||
text_parts.append(f"--- Page {page_num} ---\n{page_text}")
|
||||
|
||||
if not text_parts:
|
||||
return ToolResult.success({
|
||||
"content": f"[PDF file with {total_pages} pages, but no text content could be extracted]",
|
||||
"total_pages": total_pages,
|
||||
"message": "PDF may contain only images or be encrypted"
|
||||
})
|
||||
|
||||
# Merge all text
|
||||
full_content = "\n\n".join(text_parts)
|
||||
all_lines = full_content.split('\n')
|
||||
total_lines = len(all_lines)
|
||||
|
||||
# Apply offset and limit (same logic as text files)
|
||||
start_line = 0
|
||||
if offset is not None:
|
||||
start_line = max(0, offset - 1)
|
||||
if start_line >= total_lines:
|
||||
return ToolResult.fail(
|
||||
f"Error: Offset {offset} is beyond end of content ({total_lines} lines total)"
|
||||
)
|
||||
|
||||
start_line_display = start_line + 1
|
||||
|
||||
selected_content = full_content
|
||||
user_limited_lines = None
|
||||
if limit is not None:
|
||||
end_line = min(start_line + limit, total_lines)
|
||||
selected_content = '\n'.join(all_lines[start_line:end_line])
|
||||
user_limited_lines = end_line - start_line
|
||||
elif offset is not None:
|
||||
selected_content = '\n'.join(all_lines[start_line:])
|
||||
|
||||
# Apply truncation
|
||||
truncation = truncate_head(selected_content)
|
||||
|
||||
output_text = ""
|
||||
details = {}
|
||||
|
||||
if truncation.truncated:
|
||||
end_line_display = start_line_display + truncation.output_lines - 1
|
||||
next_offset = end_line_display + 1
|
||||
|
||||
output_text = truncation.content
|
||||
|
||||
if truncation.truncated_by == "lines":
|
||||
output_text += f"\n\n[Showing lines {start_line_display}-{end_line_display} of {total_lines}. Use offset={next_offset} to continue.]"
|
||||
else:
|
||||
output_text += f"\n\n[Showing lines {start_line_display}-{end_line_display} of {total_lines} ({format_size(DEFAULT_MAX_BYTES)} limit). Use offset={next_offset} to continue.]"
|
||||
|
||||
details["truncation"] = truncation.to_dict()
|
||||
elif user_limited_lines is not None and start_line + user_limited_lines < total_lines:
|
||||
remaining = total_lines - (start_line + user_limited_lines)
|
||||
next_offset = start_line + user_limited_lines + 1
|
||||
|
||||
output_text = truncation.content
|
||||
output_text += f"\n\n[{remaining} more lines in file. Use offset={next_offset} to continue.]"
|
||||
else:
|
||||
output_text = truncation.content
|
||||
|
||||
result = {
|
||||
"content": output_text,
|
||||
"total_pages": total_pages,
|
||||
"total_lines": total_lines,
|
||||
"start_line": start_line_display,
|
||||
"output_lines": truncation.output_lines
|
||||
}
|
||||
|
||||
if details:
|
||||
result["details"] = details
|
||||
|
||||
return ToolResult.success(result)
|
||||
|
||||
except Exception as e:
|
||||
return ToolResult.fail(f"Error reading PDF file: {str(e)}")
|
||||
287
agent/tools/scheduler/README.md
Normal file
287
agent/tools/scheduler/README.md
Normal file
@@ -0,0 +1,287 @@
|
||||
# 定时任务工具 (Scheduler Tool)
|
||||
|
||||
## 功能简介
|
||||
|
||||
定时任务工具允许 Agent 创建、管理和执行定时任务,支持:
|
||||
|
||||
- ⏰ **定时提醒**: 在指定时间发送消息
|
||||
- 🔄 **周期性任务**: 按固定间隔或 cron 表达式重复执行
|
||||
- 🔧 **动态工具调用**: 定时执行其他工具并发送结果(如搜索新闻、查询天气等)
|
||||
- 📋 **任务管理**: 查询、启用、禁用、删除任务
|
||||
|
||||
## 安装依赖
|
||||
|
||||
```bash
|
||||
pip install croniter>=2.0.0
|
||||
```
|
||||
|
||||
## 使用方法
|
||||
|
||||
### 1. 创建定时任务
|
||||
|
||||
Agent 可以通过自然语言创建定时任务,支持两种类型:
|
||||
|
||||
#### 1.1 静态消息任务
|
||||
|
||||
发送预定义的消息:
|
||||
|
||||
**示例对话:**
|
||||
```
|
||||
用户: 每天早上9点提醒我开会
|
||||
Agent: [调用 scheduler 工具]
|
||||
action: create
|
||||
name: 每日开会提醒
|
||||
message: 该开会了!
|
||||
schedule_type: cron
|
||||
schedule_value: 0 9 * * *
|
||||
```
|
||||
|
||||
#### 1.2 动态工具调用任务
|
||||
|
||||
定时执行工具并发送结果:
|
||||
|
||||
**示例对话:**
|
||||
```
|
||||
用户: 每天早上8点帮我读取一下今日日程
|
||||
Agent: [调用 scheduler 工具]
|
||||
action: create
|
||||
name: 每日日程
|
||||
tool_call:
|
||||
tool_name: read
|
||||
tool_params:
|
||||
file_path: ~/cow/schedule.txt
|
||||
result_prefix: 📅 今日日程
|
||||
schedule_type: cron
|
||||
schedule_value: 0 8 * * *
|
||||
```
|
||||
|
||||
**工具调用参数说明:**
|
||||
- `tool_name`: 要调用的工具名称(如 `bash`、`read`、`write` 等内置工具)
|
||||
- `tool_params`: 工具的参数(字典格式)
|
||||
- `result_prefix`: 可选,在结果前添加的前缀文本
|
||||
|
||||
**注意:** 如果要使用 skills(如 bocha-search),需要通过 `bash` 工具调用 skill 脚本
|
||||
|
||||
### 2. 支持的调度类型
|
||||
|
||||
#### Cron 表达式 (`cron`)
|
||||
使用标准 cron 表达式:
|
||||
|
||||
```
|
||||
0 9 * * * # 每天 9:00
|
||||
0 */2 * * * # 每 2 小时
|
||||
30 8 * * 1-5 # 工作日 8:30
|
||||
0 0 1 * * # 每月 1 号
|
||||
```
|
||||
|
||||
#### 固定间隔 (`interval`)
|
||||
以秒为单位的间隔:
|
||||
|
||||
```
|
||||
3600 # 每小时
|
||||
86400 # 每天
|
||||
1800 # 每 30 分钟
|
||||
```
|
||||
|
||||
#### 一次性任务 (`once`)
|
||||
指定具体时间(ISO 格式):
|
||||
|
||||
```
|
||||
2024-12-25T09:00:00
|
||||
2024-12-31T23:59:59
|
||||
```
|
||||
|
||||
### 3. 查询任务列表
|
||||
|
||||
```
|
||||
用户: 查看我的定时任务
|
||||
Agent: [调用 scheduler 工具]
|
||||
action: list
|
||||
```
|
||||
|
||||
### 4. 查看任务详情
|
||||
|
||||
```
|
||||
用户: 查看任务 abc123 的详情
|
||||
Agent: [调用 scheduler 工具]
|
||||
action: get
|
||||
task_id: abc123
|
||||
```
|
||||
|
||||
### 5. 删除任务
|
||||
|
||||
```
|
||||
用户: 删除任务 abc123
|
||||
Agent: [调用 scheduler 工具]
|
||||
action: delete
|
||||
task_id: abc123
|
||||
```
|
||||
|
||||
### 6. 启用/禁用任务
|
||||
|
||||
```
|
||||
用户: 暂停任务 abc123
|
||||
Agent: [调用 scheduler 工具]
|
||||
action: disable
|
||||
task_id: abc123
|
||||
|
||||
用户: 恢复任务 abc123
|
||||
Agent: [调用 scheduler 工具]
|
||||
action: enable
|
||||
task_id: abc123
|
||||
```
|
||||
|
||||
## 任务存储
|
||||
|
||||
任务保存在 JSON 文件中:
|
||||
```
|
||||
~/cow/scheduler/tasks.json
|
||||
```
|
||||
|
||||
任务数据结构:
|
||||
|
||||
**静态消息任务:**
|
||||
```json
|
||||
{
|
||||
"id": "abc123",
|
||||
"name": "每日提醒",
|
||||
"enabled": true,
|
||||
"created_at": "2024-01-01T10:00:00",
|
||||
"updated_at": "2024-01-01T10:00:00",
|
||||
"schedule": {
|
||||
"type": "cron",
|
||||
"expression": "0 9 * * *"
|
||||
},
|
||||
"action": {
|
||||
"type": "send_message",
|
||||
"content": "该开会了!",
|
||||
"receiver": "wxid_xxx",
|
||||
"receiver_name": "张三",
|
||||
"is_group": false,
|
||||
"channel_type": "wechat"
|
||||
},
|
||||
"next_run_at": "2024-01-02T09:00:00",
|
||||
"last_run_at": "2024-01-01T09:00:00"
|
||||
}
|
||||
```
|
||||
|
||||
**动态工具调用任务:**
|
||||
```json
|
||||
{
|
||||
"id": "def456",
|
||||
"name": "每日日程",
|
||||
"enabled": true,
|
||||
"created_at": "2024-01-01T10:00:00",
|
||||
"updated_at": "2024-01-01T10:00:00",
|
||||
"schedule": {
|
||||
"type": "cron",
|
||||
"expression": "0 8 * * *"
|
||||
},
|
||||
"action": {
|
||||
"type": "tool_call",
|
||||
"tool_name": "read",
|
||||
"tool_params": {
|
||||
"file_path": "~/cow/schedule.txt"
|
||||
},
|
||||
"result_prefix": "📅 今日日程",
|
||||
"receiver": "wxid_xxx",
|
||||
"receiver_name": "张三",
|
||||
"is_group": false,
|
||||
"channel_type": "wechat"
|
||||
},
|
||||
"next_run_at": "2024-01-02T08:00:00"
|
||||
}
|
||||
```
|
||||
|
||||
## 后台服务
|
||||
|
||||
定时任务由后台服务 `SchedulerService` 管理:
|
||||
|
||||
- 每 30 秒检查一次到期任务
|
||||
- 自动执行到期任务
|
||||
- 计算下次执行时间
|
||||
- 记录执行历史和错误
|
||||
|
||||
服务在 Agent 初始化时自动启动,无需手动配置。
|
||||
|
||||
## 接收者确定
|
||||
|
||||
定时任务会发送给**创建任务时的对话对象**:
|
||||
|
||||
- 如果在私聊中创建,发送给该用户
|
||||
- 如果在群聊中创建,发送到该群
|
||||
- 接收者信息在创建时自动保存
|
||||
|
||||
## 常见用例
|
||||
|
||||
### 1. 每日提醒(静态消息)
|
||||
```
|
||||
用户: 每天早上8点提醒我吃药
|
||||
Agent: ✅ 定时任务创建成功
|
||||
任务ID: a1b2c3d4
|
||||
调度: 每天 8:00
|
||||
消息: 该吃药了!
|
||||
```
|
||||
|
||||
### 2. 工作日提醒(静态消息)
|
||||
```
|
||||
用户: 工作日下午6点提醒我下班
|
||||
Agent: [创建 cron: 0 18 * * 1-5]
|
||||
消息: 该下班了!
|
||||
```
|
||||
|
||||
### 3. 倒计时提醒(静态消息)
|
||||
```
|
||||
用户: 1小时后提醒我
|
||||
Agent: [创建 interval: 3600]
|
||||
```
|
||||
|
||||
### 4. 每日日程推送(动态工具调用)
|
||||
```
|
||||
用户: 每天早上8点帮我读取今日日程
|
||||
Agent: ✅ 定时任务创建成功
|
||||
任务ID: schedule001
|
||||
调度: 每天 8:00
|
||||
工具: read(file_path='~/cow/schedule.txt')
|
||||
前缀: 📅 今日日程
|
||||
```
|
||||
|
||||
### 5. 定时文件备份(动态工具调用)
|
||||
```
|
||||
用户: 每天晚上11点备份工作文件
|
||||
Agent: [创建 cron: 0 23 * * *]
|
||||
工具: bash(command='cp ~/cow/work.txt ~/cow/backup/work_$(date +%Y%m%d).txt')
|
||||
前缀: ✅ 文件已备份
|
||||
```
|
||||
|
||||
### 6. 周报提醒(静态消息)
|
||||
```
|
||||
用户: 每周五下午5点提醒我写周报
|
||||
Agent: [创建 cron: 0 17 * * 5]
|
||||
消息: 📊 该写周报了!
|
||||
```
|
||||
|
||||
### 4. 特定日期提醒
|
||||
```
|
||||
用户: 12月25日早上9点提醒我圣诞快乐
|
||||
Agent: [创建 once: 2024-12-25T09:00:00]
|
||||
```
|
||||
|
||||
## 注意事项
|
||||
|
||||
1. **时区**: 使用系统本地时区
|
||||
2. **精度**: 检查间隔为 30 秒,实际执行可能有 ±30 秒误差
|
||||
3. **持久化**: 任务保存在文件中,重启后自动恢复
|
||||
4. **一次性任务**: 执行后自动禁用,不会删除(可手动删除)
|
||||
5. **错误处理**: 执行失败会记录错误,不影响其他任务
|
||||
|
||||
## 技术实现
|
||||
|
||||
- **TaskStore**: 任务持久化存储
|
||||
- **SchedulerService**: 后台调度服务
|
||||
- **SchedulerTool**: Agent 工具接口
|
||||
- **Integration**: 与 AgentBridge 集成
|
||||
|
||||
## 依赖
|
||||
|
||||
- `croniter`: Cron 表达式解析(轻量级,仅 ~50KB)
|
||||
7
agent/tools/scheduler/__init__.py
Normal file
7
agent/tools/scheduler/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
"""
|
||||
Scheduler tool for managing scheduled tasks
|
||||
"""
|
||||
|
||||
from .scheduler_tool import SchedulerTool
|
||||
|
||||
__all__ = ["SchedulerTool"]
|
||||
457
agent/tools/scheduler/integration.py
Normal file
457
agent/tools/scheduler/integration.py
Normal file
@@ -0,0 +1,457 @@
|
||||
"""
|
||||
Integration module for scheduler with AgentBridge
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Optional
|
||||
from config import conf
|
||||
from common.log import logger
|
||||
from common.utils import expand_path
|
||||
from bridge.context import Context, ContextType
|
||||
from bridge.reply import Reply, ReplyType
|
||||
|
||||
# Global scheduler service instance
|
||||
_scheduler_service = None
|
||||
_task_store = None
|
||||
|
||||
|
||||
def init_scheduler(agent_bridge) -> bool:
|
||||
"""
|
||||
Initialize scheduler service
|
||||
|
||||
Args:
|
||||
agent_bridge: AgentBridge instance
|
||||
|
||||
Returns:
|
||||
True if initialized successfully
|
||||
"""
|
||||
global _scheduler_service, _task_store
|
||||
|
||||
try:
|
||||
from agent.tools.scheduler.task_store import TaskStore
|
||||
from agent.tools.scheduler.scheduler_service import SchedulerService
|
||||
|
||||
# Get workspace from config
|
||||
workspace_root = expand_path(conf().get("agent_workspace", "~/cow"))
|
||||
store_path = os.path.join(workspace_root, "scheduler", "tasks.json")
|
||||
|
||||
# Create task store
|
||||
_task_store = TaskStore(store_path)
|
||||
logger.debug(f"[Scheduler] Task store initialized: {store_path}")
|
||||
|
||||
# Create execute callback
|
||||
def execute_task_callback(task: dict):
|
||||
"""Callback to execute a scheduled task"""
|
||||
try:
|
||||
action = task.get("action", {})
|
||||
action_type = action.get("type")
|
||||
|
||||
if action_type == "agent_task":
|
||||
_execute_agent_task(task, agent_bridge)
|
||||
elif action_type == "send_message":
|
||||
# Legacy support for old tasks
|
||||
_execute_send_message(task, agent_bridge)
|
||||
elif action_type == "tool_call":
|
||||
# Legacy support for old tasks
|
||||
_execute_tool_call(task, agent_bridge)
|
||||
elif action_type == "skill_call":
|
||||
# Legacy support for old tasks
|
||||
_execute_skill_call(task, agent_bridge)
|
||||
else:
|
||||
logger.warning(f"[Scheduler] Unknown action type: {action_type}")
|
||||
except Exception as e:
|
||||
logger.error(f"[Scheduler] Error executing task {task.get('id')}: {e}")
|
||||
|
||||
# Create scheduler service
|
||||
_scheduler_service = SchedulerService(_task_store, execute_task_callback)
|
||||
_scheduler_service.start()
|
||||
|
||||
logger.debug("[Scheduler] Scheduler service initialized and started")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Scheduler] Failed to initialize scheduler: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def get_task_store():
|
||||
"""Get the global task store instance"""
|
||||
return _task_store
|
||||
|
||||
|
||||
def get_scheduler_service():
|
||||
"""Get the global scheduler service instance"""
|
||||
return _scheduler_service
|
||||
|
||||
|
||||
def _execute_agent_task(task: dict, agent_bridge):
|
||||
"""
|
||||
Execute an agent_task action - let Agent handle the task
|
||||
|
||||
Args:
|
||||
task: Task dictionary
|
||||
agent_bridge: AgentBridge instance
|
||||
"""
|
||||
try:
|
||||
action = task.get("action", {})
|
||||
task_description = action.get("task_description")
|
||||
receiver = action.get("receiver")
|
||||
is_group = action.get("is_group", False)
|
||||
channel_type = action.get("channel_type", "unknown")
|
||||
|
||||
if not task_description:
|
||||
logger.error(f"[Scheduler] Task {task['id']}: No task_description specified")
|
||||
return
|
||||
|
||||
if not receiver:
|
||||
logger.error(f"[Scheduler] Task {task['id']}: No receiver specified")
|
||||
return
|
||||
|
||||
# Check for unsupported channels
|
||||
if channel_type == "dingtalk":
|
||||
logger.warning(f"[Scheduler] Task {task['id']}: DingTalk channel does not support scheduled messages (Stream mode limitation). Task will execute but message cannot be sent.")
|
||||
|
||||
logger.info(f"[Scheduler] Task {task['id']}: Executing agent task '{task_description}'")
|
||||
|
||||
# Create a unique session_id for this scheduled task to avoid polluting user's conversation
|
||||
# Format: scheduler_<receiver>_<task_id> to ensure isolation
|
||||
scheduler_session_id = f"scheduler_{receiver}_{task['id']}"
|
||||
|
||||
# Create context for Agent
|
||||
context = Context(ContextType.TEXT, task_description)
|
||||
context["receiver"] = receiver
|
||||
context["isgroup"] = is_group
|
||||
context["session_id"] = scheduler_session_id
|
||||
|
||||
# Channel-specific setup
|
||||
if channel_type == "web":
|
||||
import uuid
|
||||
request_id = f"scheduler_{task['id']}_{uuid.uuid4().hex[:8]}"
|
||||
context["request_id"] = request_id
|
||||
elif channel_type == "feishu":
|
||||
context["receive_id_type"] = "chat_id" if is_group else "open_id"
|
||||
context["msg"] = None
|
||||
elif channel_type == "dingtalk":
|
||||
# DingTalk requires msg object, set to None for scheduled tasks
|
||||
context["msg"] = None
|
||||
# 如果是单聊,需要传递 sender_staff_id
|
||||
if not is_group:
|
||||
sender_staff_id = action.get("dingtalk_sender_staff_id")
|
||||
if sender_staff_id:
|
||||
context["dingtalk_sender_staff_id"] = sender_staff_id
|
||||
|
||||
# Use Agent to execute the task
|
||||
# Mark this as a scheduled task execution to prevent recursive task creation
|
||||
context["is_scheduled_task"] = True
|
||||
|
||||
try:
|
||||
# Don't clear history - scheduler tasks use isolated session_id so they won't pollute user conversations
|
||||
reply = agent_bridge.agent_reply(task_description, context=context, on_event=None, clear_history=False)
|
||||
|
||||
if reply and reply.content:
|
||||
# Send the reply via channel
|
||||
from channel.channel_factory import create_channel
|
||||
|
||||
try:
|
||||
channel = create_channel(channel_type)
|
||||
if channel:
|
||||
# For web channel, register request_id
|
||||
if channel_type == "web" and hasattr(channel, 'request_to_session'):
|
||||
request_id = context.get("request_id")
|
||||
if request_id:
|
||||
channel.request_to_session[request_id] = receiver
|
||||
logger.debug(f"[Scheduler] Registered request_id {request_id} -> session {receiver}")
|
||||
|
||||
# Send the reply
|
||||
channel.send(reply, context)
|
||||
logger.info(f"[Scheduler] Task {task['id']} executed successfully, result sent to {receiver}")
|
||||
else:
|
||||
logger.error(f"[Scheduler] Failed to create channel: {channel_type}")
|
||||
except Exception as e:
|
||||
logger.error(f"[Scheduler] Failed to send result: {e}")
|
||||
else:
|
||||
logger.error(f"[Scheduler] Task {task['id']}: No result from agent execution")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Scheduler] Failed to execute task via Agent: {e}")
|
||||
import traceback
|
||||
logger.error(f"[Scheduler] Traceback: {traceback.format_exc()}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Scheduler] Error in _execute_agent_task: {e}")
|
||||
import traceback
|
||||
logger.error(f"[Scheduler] Traceback: {traceback.format_exc()}")
|
||||
|
||||
|
||||
def _execute_send_message(task: dict, agent_bridge):
|
||||
"""
|
||||
Execute a send_message action
|
||||
|
||||
Args:
|
||||
task: Task dictionary
|
||||
agent_bridge: AgentBridge instance
|
||||
"""
|
||||
try:
|
||||
action = task.get("action", {})
|
||||
content = action.get("content", "")
|
||||
receiver = action.get("receiver")
|
||||
is_group = action.get("is_group", False)
|
||||
channel_type = action.get("channel_type", "unknown")
|
||||
|
||||
if not receiver:
|
||||
logger.error(f"[Scheduler] Task {task['id']}: No receiver specified")
|
||||
return
|
||||
|
||||
# Create context for sending message
|
||||
context = Context(ContextType.TEXT, content)
|
||||
context["receiver"] = receiver
|
||||
context["isgroup"] = is_group
|
||||
context["session_id"] = receiver
|
||||
|
||||
# Channel-specific context setup
|
||||
if channel_type == "web":
|
||||
# Web channel needs request_id
|
||||
import uuid
|
||||
request_id = f"scheduler_{task['id']}_{uuid.uuid4().hex[:8]}"
|
||||
context["request_id"] = request_id
|
||||
logger.debug(f"[Scheduler] Generated request_id for web channel: {request_id}")
|
||||
elif channel_type == "feishu":
|
||||
# Feishu channel: for scheduled tasks, send as new message (no msg_id to reply to)
|
||||
# Use chat_id for groups, open_id for private chats
|
||||
context["receive_id_type"] = "chat_id" if is_group else "open_id"
|
||||
# Keep isgroup as is, but set msg to None (no original message to reply to)
|
||||
# Feishu channel will detect this and send as new message instead of reply
|
||||
context["msg"] = None
|
||||
logger.debug(f"[Scheduler] Feishu: receive_id_type={context['receive_id_type']}, is_group={is_group}, receiver={receiver}")
|
||||
elif channel_type == "dingtalk":
|
||||
# DingTalk channel setup
|
||||
context["msg"] = None
|
||||
# 如果是单聊,需要传递 sender_staff_id
|
||||
if not is_group:
|
||||
sender_staff_id = action.get("dingtalk_sender_staff_id")
|
||||
if sender_staff_id:
|
||||
context["dingtalk_sender_staff_id"] = sender_staff_id
|
||||
logger.debug(f"[Scheduler] DingTalk single chat: sender_staff_id={sender_staff_id}")
|
||||
else:
|
||||
logger.warning(f"[Scheduler] Task {task['id']}: DingTalk single chat message missing sender_staff_id")
|
||||
|
||||
# Create reply
|
||||
reply = Reply(ReplyType.TEXT, content)
|
||||
|
||||
# Get channel and send
|
||||
from channel.channel_factory import create_channel
|
||||
|
||||
try:
|
||||
channel = create_channel(channel_type)
|
||||
if channel:
|
||||
# For web channel, register the request_id to session mapping
|
||||
if channel_type == "web" and hasattr(channel, 'request_to_session'):
|
||||
channel.request_to_session[request_id] = receiver
|
||||
logger.debug(f"[Scheduler] Registered request_id {request_id} -> session {receiver}")
|
||||
|
||||
channel.send(reply, context)
|
||||
logger.info(f"[Scheduler] Task {task['id']} executed: sent message to {receiver}")
|
||||
else:
|
||||
logger.error(f"[Scheduler] Failed to create channel: {channel_type}")
|
||||
except Exception as e:
|
||||
logger.error(f"[Scheduler] Failed to send message: {e}")
|
||||
import traceback
|
||||
logger.error(f"[Scheduler] Traceback: {traceback.format_exc()}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Scheduler] Error in _execute_send_message: {e}")
|
||||
import traceback
|
||||
logger.error(f"[Scheduler] Traceback: {traceback.format_exc()}")
|
||||
|
||||
|
||||
def _execute_tool_call(task: dict, agent_bridge):
|
||||
"""
|
||||
Execute a tool_call action
|
||||
|
||||
Args:
|
||||
task: Task dictionary
|
||||
agent_bridge: AgentBridge instance
|
||||
"""
|
||||
try:
|
||||
action = task.get("action", {})
|
||||
# Support both old and new field names
|
||||
tool_name = action.get("call_name") or action.get("tool_name")
|
||||
tool_params = action.get("call_params") or action.get("tool_params", {})
|
||||
result_prefix = action.get("result_prefix", "")
|
||||
receiver = action.get("receiver")
|
||||
is_group = action.get("is_group", False)
|
||||
channel_type = action.get("channel_type", "unknown")
|
||||
|
||||
if not tool_name:
|
||||
logger.error(f"[Scheduler] Task {task['id']}: No tool_name specified")
|
||||
return
|
||||
|
||||
if not receiver:
|
||||
logger.error(f"[Scheduler] Task {task['id']}: No receiver specified")
|
||||
return
|
||||
|
||||
# Get tool manager and create tool instance
|
||||
from agent.tools.tool_manager import ToolManager
|
||||
tool_manager = ToolManager()
|
||||
tool = tool_manager.create_tool(tool_name)
|
||||
|
||||
if not tool:
|
||||
logger.error(f"[Scheduler] Task {task['id']}: Tool '{tool_name}' not found")
|
||||
return
|
||||
|
||||
# Execute tool
|
||||
logger.info(f"[Scheduler] Task {task['id']}: Executing tool '{tool_name}' with params {tool_params}")
|
||||
result = tool.execute(tool_params)
|
||||
|
||||
# Get result content
|
||||
if hasattr(result, 'result'):
|
||||
content = result.result
|
||||
else:
|
||||
content = str(result)
|
||||
|
||||
# Add prefix if specified
|
||||
if result_prefix:
|
||||
content = f"{result_prefix}\n\n{content}"
|
||||
|
||||
# Send result as message
|
||||
context = Context(ContextType.TEXT, content)
|
||||
context["receiver"] = receiver
|
||||
context["isgroup"] = is_group
|
||||
context["session_id"] = receiver
|
||||
|
||||
# Channel-specific context setup
|
||||
if channel_type == "web":
|
||||
# Web channel needs request_id
|
||||
import uuid
|
||||
request_id = f"scheduler_{task['id']}_{uuid.uuid4().hex[:8]}"
|
||||
context["request_id"] = request_id
|
||||
logger.debug(f"[Scheduler] Generated request_id for web channel: {request_id}")
|
||||
elif channel_type == "feishu":
|
||||
# Feishu channel: for scheduled tasks, send as new message (no msg_id to reply to)
|
||||
context["receive_id_type"] = "chat_id" if is_group else "open_id"
|
||||
context["msg"] = None
|
||||
logger.debug(f"[Scheduler] Feishu: receive_id_type={context['receive_id_type']}, is_group={is_group}, receiver={receiver}")
|
||||
|
||||
reply = Reply(ReplyType.TEXT, content)
|
||||
|
||||
# Get channel and send
|
||||
from channel.channel_factory import create_channel
|
||||
|
||||
try:
|
||||
channel = create_channel(channel_type)
|
||||
if channel:
|
||||
# For web channel, register the request_id to session mapping
|
||||
if channel_type == "web" and hasattr(channel, 'request_to_session'):
|
||||
channel.request_to_session[request_id] = receiver
|
||||
logger.debug(f"[Scheduler] Registered request_id {request_id} -> session {receiver}")
|
||||
|
||||
channel.send(reply, context)
|
||||
logger.info(f"[Scheduler] Task {task['id']} executed: sent tool result to {receiver}")
|
||||
else:
|
||||
logger.error(f"[Scheduler] Failed to create channel: {channel_type}")
|
||||
except Exception as e:
|
||||
logger.error(f"[Scheduler] Failed to send tool result: {e}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Scheduler] Error in _execute_tool_call: {e}")
|
||||
|
||||
|
||||
def _execute_skill_call(task: dict, agent_bridge):
|
||||
"""
|
||||
Execute a skill_call action by asking Agent to run the skill
|
||||
|
||||
Args:
|
||||
task: Task dictionary
|
||||
agent_bridge: AgentBridge instance
|
||||
"""
|
||||
try:
|
||||
action = task.get("action", {})
|
||||
# Support both old and new field names
|
||||
skill_name = action.get("call_name") or action.get("skill_name")
|
||||
skill_params = action.get("call_params") or action.get("skill_params", {})
|
||||
result_prefix = action.get("result_prefix", "")
|
||||
receiver = action.get("receiver")
|
||||
is_group = action.get("isgroup", False)
|
||||
channel_type = action.get("channel_type", "unknown")
|
||||
|
||||
if not skill_name:
|
||||
logger.error(f"[Scheduler] Task {task['id']}: No skill_name specified")
|
||||
return
|
||||
|
||||
if not receiver:
|
||||
logger.error(f"[Scheduler] Task {task['id']}: No receiver specified")
|
||||
return
|
||||
|
||||
logger.info(f"[Scheduler] Task {task['id']}: Executing skill '{skill_name}' with params {skill_params}")
|
||||
|
||||
# Create a unique session_id for this scheduled task to avoid polluting user's conversation
|
||||
# Format: scheduler_<receiver>_<task_id> to ensure isolation
|
||||
scheduler_session_id = f"scheduler_{receiver}_{task['id']}"
|
||||
|
||||
# Build a natural language query for the Agent to execute the skill
|
||||
# Format: "Use skill-name to do something with params"
|
||||
param_str = ", ".join([f"{k}={v}" for k, v in skill_params.items()])
|
||||
query = f"Use {skill_name} skill"
|
||||
if param_str:
|
||||
query += f" with {param_str}"
|
||||
|
||||
# Create context for Agent
|
||||
context = Context(ContextType.TEXT, query)
|
||||
context["receiver"] = receiver
|
||||
context["isgroup"] = is_group
|
||||
context["session_id"] = scheduler_session_id
|
||||
|
||||
# Channel-specific setup
|
||||
if channel_type == "web":
|
||||
import uuid
|
||||
request_id = f"scheduler_{task['id']}_{uuid.uuid4().hex[:8]}"
|
||||
context["request_id"] = request_id
|
||||
elif channel_type == "feishu":
|
||||
context["receive_id_type"] = "chat_id" if is_group else "open_id"
|
||||
context["msg"] = None
|
||||
|
||||
# Use Agent to execute the skill
|
||||
try:
|
||||
# Don't clear history - scheduler tasks use isolated session_id so they won't pollute user conversations
|
||||
reply = agent_bridge.agent_reply(query, context=context, on_event=None, clear_history=False)
|
||||
|
||||
if reply and reply.content:
|
||||
content = reply.content
|
||||
|
||||
# Add prefix if specified
|
||||
if result_prefix:
|
||||
content = f"{result_prefix}\n\n{content}"
|
||||
|
||||
logger.info(f"[Scheduler] Task {task['id']} executed: skill result sent to {receiver}")
|
||||
else:
|
||||
logger.error(f"[Scheduler] Task {task['id']}: No result from skill execution")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Scheduler] Failed to execute skill via Agent: {e}")
|
||||
import traceback
|
||||
logger.error(f"[Scheduler] Traceback: {traceback.format_exc()}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Scheduler] Error in _execute_skill_call: {e}")
|
||||
import traceback
|
||||
logger.error(f"[Scheduler] Traceback: {traceback.format_exc()}")
|
||||
|
||||
|
||||
def attach_scheduler_to_tool(tool, context: Context = None):
|
||||
"""
|
||||
Attach scheduler components to a SchedulerTool instance
|
||||
|
||||
Args:
|
||||
tool: SchedulerTool instance
|
||||
context: Current context (optional)
|
||||
"""
|
||||
if _task_store:
|
||||
tool.task_store = _task_store
|
||||
|
||||
if context:
|
||||
tool.current_context = context
|
||||
|
||||
channel_type = context.get("channel_type") or conf().get("channel_type", "unknown")
|
||||
if not tool.config:
|
||||
tool.config = {}
|
||||
tool.config["channel_type"] = channel_type
|
||||
220
agent/tools/scheduler/scheduler_service.py
Normal file
220
agent/tools/scheduler/scheduler_service.py
Normal file
@@ -0,0 +1,220 @@
|
||||
"""
|
||||
Background scheduler service for executing scheduled tasks
|
||||
"""
|
||||
|
||||
import time
|
||||
import threading
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Callable, Optional
|
||||
from croniter import croniter
|
||||
from common.log import logger
|
||||
|
||||
|
||||
class SchedulerService:
|
||||
"""
|
||||
Background service that executes scheduled tasks
|
||||
"""
|
||||
|
||||
def __init__(self, task_store, execute_callback: Callable):
|
||||
"""
|
||||
Initialize scheduler service
|
||||
|
||||
Args:
|
||||
task_store: TaskStore instance
|
||||
execute_callback: Function to call when executing a task
|
||||
"""
|
||||
self.task_store = task_store
|
||||
self.execute_callback = execute_callback
|
||||
self.running = False
|
||||
self.thread = None
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def start(self):
|
||||
"""Start the scheduler service"""
|
||||
with self._lock:
|
||||
if self.running:
|
||||
logger.warning("[Scheduler] Service already running")
|
||||
return
|
||||
|
||||
self.running = True
|
||||
self.thread = threading.Thread(target=self._run_loop, daemon=True)
|
||||
self.thread.start()
|
||||
logger.debug("[Scheduler] Service started")
|
||||
|
||||
def stop(self):
|
||||
"""Stop the scheduler service"""
|
||||
with self._lock:
|
||||
if not self.running:
|
||||
return
|
||||
|
||||
self.running = False
|
||||
if self.thread:
|
||||
self.thread.join(timeout=5)
|
||||
logger.info("[Scheduler] Service stopped")
|
||||
|
||||
def _run_loop(self):
|
||||
"""Main scheduler loop"""
|
||||
logger.debug("[Scheduler] Scheduler loop started")
|
||||
|
||||
while self.running:
|
||||
try:
|
||||
self._check_and_execute_tasks()
|
||||
except Exception as e:
|
||||
logger.error(f"[Scheduler] Error in scheduler loop: {e}")
|
||||
|
||||
# Sleep for 30 seconds between checks
|
||||
time.sleep(30)
|
||||
|
||||
def _check_and_execute_tasks(self):
|
||||
"""Check for due tasks and execute them"""
|
||||
now = datetime.now()
|
||||
tasks = self.task_store.list_tasks(enabled_only=True)
|
||||
|
||||
for task in tasks:
|
||||
try:
|
||||
# Check if task is due
|
||||
if self._is_task_due(task, now):
|
||||
logger.info(f"[Scheduler] Executing task: {task['id']} - {task['name']}")
|
||||
self._execute_task(task)
|
||||
|
||||
# Update next run time
|
||||
next_run = self._calculate_next_run(task, now)
|
||||
if next_run:
|
||||
self.task_store.update_task(task['id'], {
|
||||
"next_run_at": next_run.isoformat(),
|
||||
"last_run_at": now.isoformat()
|
||||
})
|
||||
else:
|
||||
# One-time task, disable it
|
||||
self.task_store.update_task(task['id'], {
|
||||
"enabled": False,
|
||||
"last_run_at": now.isoformat()
|
||||
})
|
||||
logger.info(f"[Scheduler] One-time task completed and disabled: {task['id']}")
|
||||
except Exception as e:
|
||||
logger.error(f"[Scheduler] Error processing task {task.get('id')}: {e}")
|
||||
|
||||
def _is_task_due(self, task: dict, now: datetime) -> bool:
|
||||
"""
|
||||
Check if a task is due to run
|
||||
|
||||
Args:
|
||||
task: Task dictionary
|
||||
now: Current datetime
|
||||
|
||||
Returns:
|
||||
True if task should run now
|
||||
"""
|
||||
next_run_str = task.get("next_run_at")
|
||||
if not next_run_str:
|
||||
# Calculate initial next_run_at
|
||||
next_run = self._calculate_next_run(task, now)
|
||||
if next_run:
|
||||
self.task_store.update_task(task['id'], {
|
||||
"next_run_at": next_run.isoformat()
|
||||
})
|
||||
return False
|
||||
return False
|
||||
|
||||
try:
|
||||
next_run = datetime.fromisoformat(next_run_str)
|
||||
|
||||
# Check if task is overdue (e.g., service restart)
|
||||
if next_run < now:
|
||||
time_diff = (now - next_run).total_seconds()
|
||||
|
||||
# If overdue by more than 5 minutes, skip this run and schedule next
|
||||
if time_diff > 300: # 5 minutes
|
||||
logger.warning(f"[Scheduler] Task {task['id']} is overdue by {int(time_diff)}s, skipping and scheduling next run")
|
||||
|
||||
# For one-time tasks, disable them
|
||||
schedule = task.get("schedule", {})
|
||||
if schedule.get("type") == "once":
|
||||
self.task_store.update_task(task['id'], {
|
||||
"enabled": False,
|
||||
"last_run_at": now.isoformat()
|
||||
})
|
||||
logger.info(f"[Scheduler] One-time task {task['id']} expired, disabled")
|
||||
return False
|
||||
|
||||
# For recurring tasks, calculate next run from now
|
||||
next_next_run = self._calculate_next_run(task, now)
|
||||
if next_next_run:
|
||||
self.task_store.update_task(task['id'], {
|
||||
"next_run_at": next_next_run.isoformat()
|
||||
})
|
||||
logger.info(f"[Scheduler] Rescheduled task {task['id']} to {next_next_run}")
|
||||
return False
|
||||
|
||||
return now >= next_run
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def _calculate_next_run(self, task: dict, from_time: datetime) -> Optional[datetime]:
|
||||
"""
|
||||
Calculate next run time for a task
|
||||
|
||||
Args:
|
||||
task: Task dictionary
|
||||
from_time: Calculate from this time
|
||||
|
||||
Returns:
|
||||
Next run datetime or None for one-time tasks
|
||||
"""
|
||||
schedule = task.get("schedule", {})
|
||||
schedule_type = schedule.get("type")
|
||||
|
||||
if schedule_type == "cron":
|
||||
# Cron expression
|
||||
expression = schedule.get("expression")
|
||||
if not expression:
|
||||
return None
|
||||
|
||||
try:
|
||||
cron = croniter(expression, from_time)
|
||||
return cron.get_next(datetime)
|
||||
except Exception as e:
|
||||
logger.error(f"[Scheduler] Invalid cron expression '{expression}': {e}")
|
||||
return None
|
||||
|
||||
elif schedule_type == "interval":
|
||||
# Interval in seconds
|
||||
seconds = schedule.get("seconds", 0)
|
||||
if seconds <= 0:
|
||||
return None
|
||||
return from_time + timedelta(seconds=seconds)
|
||||
|
||||
elif schedule_type == "once":
|
||||
# One-time task at specific time
|
||||
run_at_str = schedule.get("run_at")
|
||||
if not run_at_str:
|
||||
return None
|
||||
|
||||
try:
|
||||
run_at = datetime.fromisoformat(run_at_str)
|
||||
# Only return if in the future
|
||||
if run_at > from_time:
|
||||
return run_at
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
|
||||
return None
|
||||
|
||||
def _execute_task(self, task: dict):
|
||||
"""
|
||||
Execute a task
|
||||
|
||||
Args:
|
||||
task: Task dictionary
|
||||
"""
|
||||
try:
|
||||
# Call the execute callback
|
||||
self.execute_callback(task)
|
||||
except Exception as e:
|
||||
logger.error(f"[Scheduler] Error executing task {task['id']}: {e}")
|
||||
# Update task with error
|
||||
self.task_store.update_task(task['id'], {
|
||||
"last_error": str(e),
|
||||
"last_error_at": datetime.now().isoformat()
|
||||
})
|
||||
443
agent/tools/scheduler/scheduler_tool.py
Normal file
443
agent/tools/scheduler/scheduler_tool.py
Normal file
@@ -0,0 +1,443 @@
|
||||
"""
|
||||
Scheduler tool for creating and managing scheduled tasks
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, Optional
|
||||
from croniter import croniter
|
||||
|
||||
from agent.tools.base_tool import BaseTool, ToolResult
|
||||
from bridge.context import Context, ContextType
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from common.log import logger
|
||||
|
||||
|
||||
class SchedulerTool(BaseTool):
|
||||
"""
|
||||
Tool for managing scheduled tasks (reminders, notifications, etc.)
|
||||
"""
|
||||
|
||||
name: str = "scheduler"
|
||||
description: str = (
|
||||
"创建、查询和管理定时任务(提醒、周期性任务等)。\n\n"
|
||||
"⚠️ 重要:仅当需要「定时/提醒/每天/每周/X分钟后/X点」等延迟或周期执行时才使用此工具。"
|
||||
"使用方法:\n"
|
||||
"- 创建:action='create', name='任务名', message/ai_task='内容', schedule_type='once/interval/cron', schedule_value='...'\n"
|
||||
"- 查询:action='list' / action='get', task_id='任务ID'\n"
|
||||
"- 管理:action='delete/enable/disable', task_id='任务ID'\n\n"
|
||||
"调度类型:\n"
|
||||
"- once: 一次性任务,支持相对时间(+5s,+10m,+1h,+1d)或ISO时间\n"
|
||||
"- interval: 固定间隔(秒),如3600表示每小时\n"
|
||||
"- cron: cron表达式,如'0 8 * * *'表示每天8点\n\n"
|
||||
"注意:'X秒后'用once+相对时间,'每X秒'用interval"
|
||||
)
|
||||
params: dict = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"action": {
|
||||
"type": "string",
|
||||
"enum": ["create", "list", "get", "delete", "enable", "disable"],
|
||||
"description": "操作类型: create(创建), list(列表), get(查询), delete(删除), enable(启用), disable(禁用)"
|
||||
},
|
||||
"task_id": {
|
||||
"type": "string",
|
||||
"description": "任务ID (用于 get/delete/enable/disable 操作)"
|
||||
},
|
||||
"name": {
|
||||
"type": "string",
|
||||
"description": "任务名称 (用于 create 操作)"
|
||||
},
|
||||
"message": {
|
||||
"type": "string",
|
||||
"description": "固定消息内容 (与ai_task二选一)"
|
||||
},
|
||||
"ai_task": {
|
||||
"type": "string",
|
||||
"description": "AI任务描述 (与message二选一),用于定时让AI执行的任务"
|
||||
},
|
||||
"schedule_type": {
|
||||
"type": "string",
|
||||
"enum": ["cron", "interval", "once"],
|
||||
"description": "调度类型 (用于 create 操作): cron(cron表达式), interval(固定间隔秒数), once(一次性)"
|
||||
},
|
||||
"schedule_value": {
|
||||
"type": "string",
|
||||
"description": "调度值: cron表达式/间隔秒数/时间(+5s,+10m,+1h或ISO格式)"
|
||||
}
|
||||
},
|
||||
"required": ["action"]
|
||||
}
|
||||
|
||||
def __init__(self, config: dict = None):
|
||||
super().__init__()
|
||||
self.config = config or {}
|
||||
|
||||
# Will be set by agent bridge
|
||||
self.task_store = None
|
||||
self.current_context = None
|
||||
|
||||
def execute(self, params: dict) -> ToolResult:
|
||||
"""
|
||||
Execute scheduler operations
|
||||
|
||||
Args:
|
||||
params: Dictionary containing:
|
||||
- action: Operation type (create/list/get/delete/enable/disable)
|
||||
- Other parameters depending on action
|
||||
|
||||
Returns:
|
||||
ToolResult object
|
||||
"""
|
||||
# Extract parameters
|
||||
action = params.get("action")
|
||||
kwargs = params
|
||||
|
||||
if not self.task_store:
|
||||
return ToolResult.fail("错误: 定时任务系统未初始化")
|
||||
|
||||
try:
|
||||
if action == "create":
|
||||
result = self._create_task(**kwargs)
|
||||
return ToolResult.success(result)
|
||||
elif action == "list":
|
||||
result = self._list_tasks(**kwargs)
|
||||
return ToolResult.success(result)
|
||||
elif action == "get":
|
||||
result = self._get_task(**kwargs)
|
||||
return ToolResult.success(result)
|
||||
elif action == "delete":
|
||||
result = self._delete_task(**kwargs)
|
||||
return ToolResult.success(result)
|
||||
elif action == "enable":
|
||||
result = self._enable_task(**kwargs)
|
||||
return ToolResult.success(result)
|
||||
elif action == "disable":
|
||||
result = self._disable_task(**kwargs)
|
||||
return ToolResult.success(result)
|
||||
else:
|
||||
return ToolResult.fail(f"未知操作: {action}")
|
||||
except Exception as e:
|
||||
logger.error(f"[SchedulerTool] Error: {e}")
|
||||
return ToolResult.fail(f"操作失败: {str(e)}")
|
||||
|
||||
def _create_task(self, **kwargs) -> str:
|
||||
"""Create a new scheduled task"""
|
||||
name = kwargs.get("name")
|
||||
message = kwargs.get("message")
|
||||
ai_task = kwargs.get("ai_task")
|
||||
schedule_type = kwargs.get("schedule_type")
|
||||
schedule_value = kwargs.get("schedule_value")
|
||||
|
||||
# Validate required fields
|
||||
if not name:
|
||||
return "错误: 缺少任务名称 (name)"
|
||||
|
||||
# Check that exactly one of message/ai_task is provided
|
||||
if not message and not ai_task:
|
||||
return "错误: 必须提供 message(固定消息)或 ai_task(AI任务)之一"
|
||||
if message and ai_task:
|
||||
return "错误: message 和 ai_task 只能提供其中一个"
|
||||
|
||||
if not schedule_type:
|
||||
return "错误: 缺少调度类型 (schedule_type)"
|
||||
if not schedule_value:
|
||||
return "错误: 缺少调度值 (schedule_value)"
|
||||
|
||||
# Validate schedule
|
||||
schedule = self._parse_schedule(schedule_type, schedule_value)
|
||||
if not schedule:
|
||||
return f"错误: 无效的调度配置 - type: {schedule_type}, value: {schedule_value}"
|
||||
|
||||
# Get context info for receiver
|
||||
if not self.current_context:
|
||||
return "错误: 无法获取当前对话上下文"
|
||||
|
||||
context = self.current_context
|
||||
|
||||
# Create task
|
||||
task_id = str(uuid.uuid4())[:8]
|
||||
|
||||
# Build action based on message or ai_task
|
||||
if message:
|
||||
action = {
|
||||
"type": "send_message",
|
||||
"content": message,
|
||||
"receiver": context.get("receiver"),
|
||||
"receiver_name": self._get_receiver_name(context),
|
||||
"is_group": context.get("isgroup", False),
|
||||
"channel_type": self.config.get("channel_type", "unknown")
|
||||
}
|
||||
else: # ai_task
|
||||
action = {
|
||||
"type": "agent_task",
|
||||
"task_description": ai_task,
|
||||
"receiver": context.get("receiver"),
|
||||
"receiver_name": self._get_receiver_name(context),
|
||||
"is_group": context.get("isgroup", False),
|
||||
"channel_type": self.config.get("channel_type", "unknown")
|
||||
}
|
||||
|
||||
# 针对钉钉单聊,额外存储 sender_staff_id
|
||||
msg = context.kwargs.get("msg")
|
||||
if msg and hasattr(msg, 'sender_staff_id') and not context.get("isgroup", False):
|
||||
action["dingtalk_sender_staff_id"] = msg.sender_staff_id
|
||||
|
||||
task_data = {
|
||||
"id": task_id,
|
||||
"name": name,
|
||||
"enabled": True,
|
||||
"created_at": datetime.now().isoformat(),
|
||||
"updated_at": datetime.now().isoformat(),
|
||||
"schedule": schedule,
|
||||
"action": action
|
||||
}
|
||||
|
||||
# Calculate initial next_run_at
|
||||
next_run = self._calculate_next_run(task_data)
|
||||
if next_run:
|
||||
task_data["next_run_at"] = next_run.isoformat()
|
||||
|
||||
# Save task
|
||||
self.task_store.add_task(task_data)
|
||||
|
||||
# Format response
|
||||
schedule_desc = self._format_schedule_description(schedule)
|
||||
receiver_desc = task_data["action"]["receiver_name"] or task_data["action"]["receiver"]
|
||||
|
||||
if message:
|
||||
content_desc = f"💬 固定消息: {message}"
|
||||
else:
|
||||
content_desc = f"🤖 AI任务: {ai_task}"
|
||||
|
||||
return (
|
||||
f"✅ 定时任务创建成功\n\n"
|
||||
f"📋 任务ID: {task_id}\n"
|
||||
f"📝 名称: {name}\n"
|
||||
f"⏰ 调度: {schedule_desc}\n"
|
||||
f"👤 接收者: {receiver_desc}\n"
|
||||
f"{content_desc}\n"
|
||||
f"🕐 下次执行: {next_run.strftime('%Y-%m-%d %H:%M:%S') if next_run else '未知'}"
|
||||
)
|
||||
|
||||
def _list_tasks(self, **kwargs) -> str:
|
||||
"""List all tasks"""
|
||||
tasks = self.task_store.list_tasks()
|
||||
|
||||
if not tasks:
|
||||
return "📋 暂无定时任务"
|
||||
|
||||
lines = [f"📋 定时任务列表 (共 {len(tasks)} 个)\n"]
|
||||
|
||||
for task in tasks:
|
||||
status = "✅" if task.get("enabled", True) else "❌"
|
||||
schedule_desc = self._format_schedule_description(task.get("schedule", {}))
|
||||
next_run = task.get("next_run_at")
|
||||
next_run_str = datetime.fromisoformat(next_run).strftime('%m-%d %H:%M') if next_run else "未知"
|
||||
|
||||
lines.append(
|
||||
f"{status} [{task['id']}] {task['name']}\n"
|
||||
f" ⏰ {schedule_desc} | 下次: {next_run_str}"
|
||||
)
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
def _get_task(self, **kwargs) -> str:
|
||||
"""Get task details"""
|
||||
task_id = kwargs.get("task_id")
|
||||
if not task_id:
|
||||
return "错误: 缺少任务ID (task_id)"
|
||||
|
||||
task = self.task_store.get_task(task_id)
|
||||
if not task:
|
||||
return f"错误: 任务 '{task_id}' 不存在"
|
||||
|
||||
status = "启用" if task.get("enabled", True) else "禁用"
|
||||
schedule_desc = self._format_schedule_description(task.get("schedule", {}))
|
||||
action = task.get("action", {})
|
||||
next_run = task.get("next_run_at")
|
||||
next_run_str = datetime.fromisoformat(next_run).strftime('%Y-%m-%d %H:%M:%S') if next_run else "未知"
|
||||
last_run = task.get("last_run_at")
|
||||
last_run_str = datetime.fromisoformat(last_run).strftime('%Y-%m-%d %H:%M:%S') if last_run else "从未执行"
|
||||
|
||||
return (
|
||||
f"📋 任务详情\n\n"
|
||||
f"ID: {task['id']}\n"
|
||||
f"名称: {task['name']}\n"
|
||||
f"状态: {status}\n"
|
||||
f"调度: {schedule_desc}\n"
|
||||
f"接收者: {action.get('receiver_name', action.get('receiver'))}\n"
|
||||
f"消息: {action.get('content')}\n"
|
||||
f"下次执行: {next_run_str}\n"
|
||||
f"上次执行: {last_run_str}\n"
|
||||
f"创建时间: {datetime.fromisoformat(task['created_at']).strftime('%Y-%m-%d %H:%M:%S')}"
|
||||
)
|
||||
|
||||
def _delete_task(self, **kwargs) -> str:
|
||||
"""Delete a task"""
|
||||
task_id = kwargs.get("task_id")
|
||||
if not task_id:
|
||||
return "错误: 缺少任务ID (task_id)"
|
||||
|
||||
task = self.task_store.get_task(task_id)
|
||||
if not task:
|
||||
return f"错误: 任务 '{task_id}' 不存在"
|
||||
|
||||
self.task_store.delete_task(task_id)
|
||||
return f"✅ 任务 '{task['name']}' ({task_id}) 已删除"
|
||||
|
||||
def _enable_task(self, **kwargs) -> str:
|
||||
"""Enable a task"""
|
||||
task_id = kwargs.get("task_id")
|
||||
if not task_id:
|
||||
return "错误: 缺少任务ID (task_id)"
|
||||
|
||||
task = self.task_store.get_task(task_id)
|
||||
if not task:
|
||||
return f"错误: 任务 '{task_id}' 不存在"
|
||||
|
||||
self.task_store.enable_task(task_id, True)
|
||||
return f"✅ 任务 '{task['name']}' ({task_id}) 已启用"
|
||||
|
||||
def _disable_task(self, **kwargs) -> str:
|
||||
"""Disable a task"""
|
||||
task_id = kwargs.get("task_id")
|
||||
if not task_id:
|
||||
return "错误: 缺少任务ID (task_id)"
|
||||
|
||||
task = self.task_store.get_task(task_id)
|
||||
if not task:
|
||||
return f"错误: 任务 '{task_id}' 不存在"
|
||||
|
||||
self.task_store.enable_task(task_id, False)
|
||||
return f"✅ 任务 '{task['name']}' ({task_id}) 已禁用"
|
||||
|
||||
def _parse_schedule(self, schedule_type: str, schedule_value: str) -> Optional[dict]:
|
||||
"""Parse and validate schedule configuration"""
|
||||
try:
|
||||
if schedule_type == "cron":
|
||||
# Validate cron expression
|
||||
croniter(schedule_value)
|
||||
return {"type": "cron", "expression": schedule_value}
|
||||
|
||||
elif schedule_type == "interval":
|
||||
# Parse interval in seconds
|
||||
seconds = int(schedule_value)
|
||||
if seconds <= 0:
|
||||
return None
|
||||
return {"type": "interval", "seconds": seconds}
|
||||
|
||||
elif schedule_type == "once":
|
||||
# Parse datetime - support both relative and absolute time
|
||||
|
||||
# Check if it's relative time (e.g., "+5s", "+10m", "+1h", "+1d")
|
||||
if schedule_value.startswith("+"):
|
||||
import re
|
||||
match = re.match(r'\+(\d+)([smhd])', schedule_value)
|
||||
if match:
|
||||
amount = int(match.group(1))
|
||||
unit = match.group(2)
|
||||
|
||||
from datetime import timedelta
|
||||
now = datetime.now()
|
||||
|
||||
if unit == 's': # seconds
|
||||
target_time = now + timedelta(seconds=amount)
|
||||
elif unit == 'm': # minutes
|
||||
target_time = now + timedelta(minutes=amount)
|
||||
elif unit == 'h': # hours
|
||||
target_time = now + timedelta(hours=amount)
|
||||
elif unit == 'd': # days
|
||||
target_time = now + timedelta(days=amount)
|
||||
else:
|
||||
return None
|
||||
|
||||
return {"type": "once", "run_at": target_time.isoformat()}
|
||||
else:
|
||||
logger.error(f"[SchedulerTool] Invalid relative time format: {schedule_value}")
|
||||
return None
|
||||
else:
|
||||
# Absolute time in ISO format
|
||||
datetime.fromisoformat(schedule_value)
|
||||
return {"type": "once", "run_at": schedule_value}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[SchedulerTool] Invalid schedule: {e}")
|
||||
return None
|
||||
|
||||
return None
|
||||
|
||||
def _calculate_next_run(self, task: dict) -> Optional[datetime]:
|
||||
"""Calculate next run time for a task"""
|
||||
schedule = task.get("schedule", {})
|
||||
schedule_type = schedule.get("type")
|
||||
now = datetime.now()
|
||||
|
||||
if schedule_type == "cron":
|
||||
expression = schedule.get("expression")
|
||||
cron = croniter(expression, now)
|
||||
return cron.get_next(datetime)
|
||||
|
||||
elif schedule_type == "interval":
|
||||
seconds = schedule.get("seconds", 0)
|
||||
from datetime import timedelta
|
||||
return now + timedelta(seconds=seconds)
|
||||
|
||||
elif schedule_type == "once":
|
||||
run_at_str = schedule.get("run_at")
|
||||
return datetime.fromisoformat(run_at_str)
|
||||
|
||||
return None
|
||||
|
||||
def _format_schedule_description(self, schedule: dict) -> str:
|
||||
"""Format schedule as human-readable description"""
|
||||
schedule_type = schedule.get("type")
|
||||
|
||||
if schedule_type == "cron":
|
||||
expr = schedule.get("expression", "")
|
||||
# Try to provide friendly description
|
||||
if expr == "0 9 * * *":
|
||||
return "每天 9:00"
|
||||
elif expr == "0 */1 * * *":
|
||||
return "每小时"
|
||||
elif expr == "*/30 * * * *":
|
||||
return "每30分钟"
|
||||
else:
|
||||
return f"Cron: {expr}"
|
||||
|
||||
elif schedule_type == "interval":
|
||||
seconds = schedule.get("seconds", 0)
|
||||
if seconds >= 86400:
|
||||
days = seconds // 86400
|
||||
return f"每 {days} 天"
|
||||
elif seconds >= 3600:
|
||||
hours = seconds // 3600
|
||||
return f"每 {hours} 小时"
|
||||
elif seconds >= 60:
|
||||
minutes = seconds // 60
|
||||
return f"每 {minutes} 分钟"
|
||||
else:
|
||||
return f"每 {seconds} 秒"
|
||||
|
||||
elif schedule_type == "once":
|
||||
run_at = schedule.get("run_at", "")
|
||||
try:
|
||||
dt = datetime.fromisoformat(run_at)
|
||||
return f"一次性 ({dt.strftime('%Y-%m-%d %H:%M')})"
|
||||
except Exception:
|
||||
return "一次性"
|
||||
|
||||
return "未知"
|
||||
|
||||
def _get_receiver_name(self, context: Context) -> str:
|
||||
"""Get receiver name from context"""
|
||||
try:
|
||||
msg = context.get("msg")
|
||||
if msg:
|
||||
if context.get("isgroup"):
|
||||
return msg.other_user_nickname or "群聊"
|
||||
else:
|
||||
return msg.from_user_nickname or "用户"
|
||||
except Exception:
|
||||
pass
|
||||
return "未知"
|
||||
201
agent/tools/scheduler/task_store.py
Normal file
201
agent/tools/scheduler/task_store.py
Normal file
@@ -0,0 +1,201 @@
|
||||
"""
|
||||
Task storage management for scheduler
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import threading
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional
|
||||
from pathlib import Path
|
||||
from common.utils import expand_path
|
||||
|
||||
|
||||
class TaskStore:
|
||||
"""
|
||||
Manages persistent storage of scheduled tasks
|
||||
"""
|
||||
|
||||
def __init__(self, store_path: str = None):
|
||||
"""
|
||||
Initialize task store
|
||||
|
||||
Args:
|
||||
store_path: Path to tasks.json file. Defaults to ~/cow/scheduler/tasks.json
|
||||
"""
|
||||
if store_path is None:
|
||||
# Default to ~/cow/scheduler/tasks.json
|
||||
home = expand_path("~")
|
||||
store_path = os.path.join(home, "cow", "scheduler", "tasks.json")
|
||||
|
||||
self.store_path = store_path
|
||||
self.lock = threading.Lock()
|
||||
self._ensure_store_dir()
|
||||
|
||||
def _ensure_store_dir(self):
|
||||
"""Ensure the storage directory exists"""
|
||||
store_dir = os.path.dirname(self.store_path)
|
||||
os.makedirs(store_dir, exist_ok=True)
|
||||
|
||||
def load_tasks(self) -> Dict[str, dict]:
|
||||
"""
|
||||
Load all tasks from storage
|
||||
|
||||
Returns:
|
||||
Dictionary of task_id -> task_data
|
||||
"""
|
||||
with self.lock:
|
||||
if not os.path.exists(self.store_path):
|
||||
return {}
|
||||
|
||||
try:
|
||||
with open(self.store_path, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
return data.get("tasks", {})
|
||||
except Exception as e:
|
||||
print(f"Error loading tasks: {e}")
|
||||
return {}
|
||||
|
||||
def save_tasks(self, tasks: Dict[str, dict]):
|
||||
"""
|
||||
Save all tasks to storage
|
||||
|
||||
Args:
|
||||
tasks: Dictionary of task_id -> task_data
|
||||
"""
|
||||
with self.lock:
|
||||
try:
|
||||
# Create backup
|
||||
if os.path.exists(self.store_path):
|
||||
backup_path = f"{self.store_path}.bak"
|
||||
try:
|
||||
with open(self.store_path, 'r') as src:
|
||||
with open(backup_path, 'w') as dst:
|
||||
dst.write(src.read())
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Save tasks
|
||||
data = {
|
||||
"version": 1,
|
||||
"updated_at": datetime.now().isoformat(),
|
||||
"tasks": tasks
|
||||
}
|
||||
|
||||
with open(self.store_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(data, f, ensure_ascii=False, indent=2)
|
||||
except Exception as e:
|
||||
print(f"Error saving tasks: {e}")
|
||||
raise
|
||||
|
||||
def add_task(self, task: dict) -> bool:
|
||||
"""
|
||||
Add a new task
|
||||
|
||||
Args:
|
||||
task: Task data dictionary
|
||||
|
||||
Returns:
|
||||
True if successful
|
||||
"""
|
||||
tasks = self.load_tasks()
|
||||
task_id = task.get("id")
|
||||
|
||||
if not task_id:
|
||||
raise ValueError("Task must have an 'id' field")
|
||||
|
||||
if task_id in tasks:
|
||||
raise ValueError(f"Task with id '{task_id}' already exists")
|
||||
|
||||
tasks[task_id] = task
|
||||
self.save_tasks(tasks)
|
||||
return True
|
||||
|
||||
def update_task(self, task_id: str, updates: dict) -> bool:
|
||||
"""
|
||||
Update an existing task
|
||||
|
||||
Args:
|
||||
task_id: Task ID
|
||||
updates: Dictionary of fields to update
|
||||
|
||||
Returns:
|
||||
True if successful
|
||||
"""
|
||||
tasks = self.load_tasks()
|
||||
|
||||
if task_id not in tasks:
|
||||
raise ValueError(f"Task '{task_id}' not found")
|
||||
|
||||
# Update fields
|
||||
tasks[task_id].update(updates)
|
||||
tasks[task_id]["updated_at"] = datetime.now().isoformat()
|
||||
|
||||
self.save_tasks(tasks)
|
||||
return True
|
||||
|
||||
def delete_task(self, task_id: str) -> bool:
|
||||
"""
|
||||
Delete a task
|
||||
|
||||
Args:
|
||||
task_id: Task ID
|
||||
|
||||
Returns:
|
||||
True if successful
|
||||
"""
|
||||
tasks = self.load_tasks()
|
||||
|
||||
if task_id not in tasks:
|
||||
raise ValueError(f"Task '{task_id}' not found")
|
||||
|
||||
del tasks[task_id]
|
||||
self.save_tasks(tasks)
|
||||
return True
|
||||
|
||||
def get_task(self, task_id: str) -> Optional[dict]:
|
||||
"""
|
||||
Get a specific task
|
||||
|
||||
Args:
|
||||
task_id: Task ID
|
||||
|
||||
Returns:
|
||||
Task data or None if not found
|
||||
"""
|
||||
tasks = self.load_tasks()
|
||||
return tasks.get(task_id)
|
||||
|
||||
def list_tasks(self, enabled_only: bool = False) -> List[dict]:
|
||||
"""
|
||||
List all tasks
|
||||
|
||||
Args:
|
||||
enabled_only: If True, only return enabled tasks
|
||||
|
||||
Returns:
|
||||
List of task dictionaries
|
||||
"""
|
||||
tasks = self.load_tasks()
|
||||
task_list = list(tasks.values())
|
||||
|
||||
if enabled_only:
|
||||
task_list = [t for t in task_list if t.get("enabled", True)]
|
||||
|
||||
# Sort by next_run_at
|
||||
task_list.sort(key=lambda t: t.get("next_run_at", float('inf')))
|
||||
|
||||
return task_list
|
||||
|
||||
def enable_task(self, task_id: str, enabled: bool = True) -> bool:
|
||||
"""
|
||||
Enable or disable a task
|
||||
|
||||
Args:
|
||||
task_id: Task ID
|
||||
enabled: True to enable, False to disable
|
||||
|
||||
Returns:
|
||||
True if successful
|
||||
"""
|
||||
return self.update_task(task_id, {"enabled": enabled})
|
||||
3
agent/tools/send/__init__.py
Normal file
3
agent/tools/send/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .send import Send
|
||||
|
||||
__all__ = ['Send']
|
||||
160
agent/tools/send/send.py
Normal file
160
agent/tools/send/send.py
Normal file
@@ -0,0 +1,160 @@
|
||||
"""
|
||||
Send tool - Send files to the user
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Dict, Any
|
||||
from pathlib import Path
|
||||
|
||||
from agent.tools.base_tool import BaseTool, ToolResult
|
||||
from common.utils import expand_path
|
||||
|
||||
|
||||
class Send(BaseTool):
|
||||
"""Tool for sending files to the user"""
|
||||
|
||||
name: str = "send"
|
||||
description: str = "Send a file (image, video, audio, document) to the user. Use this when the user explicitly asks to send/share a file."
|
||||
|
||||
params: dict = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "Path to the file to send. Can be absolute path or relative to workspace."
|
||||
},
|
||||
"message": {
|
||||
"type": "string",
|
||||
"description": "Optional message to accompany the file"
|
||||
}
|
||||
},
|
||||
"required": ["path"]
|
||||
}
|
||||
|
||||
def __init__(self, config: dict = None):
|
||||
self.config = config or {}
|
||||
self.cwd = self.config.get("cwd", os.getcwd())
|
||||
|
||||
# Supported file types
|
||||
self.image_extensions = {'.jpg', '.jpeg', '.png', '.gif', '.webp', '.bmp', '.svg', '.ico'}
|
||||
self.video_extensions = {'.mp4', '.avi', '.mov', '.mkv', '.flv', '.wmv', '.webm', '.m4v'}
|
||||
self.audio_extensions = {'.mp3', '.wav', '.ogg', '.m4a', '.flac', '.aac', '.wma'}
|
||||
self.document_extensions = {'.pdf', '.doc', '.docx', '.xls', '.xlsx', '.ppt', '.pptx', '.txt', '.md'}
|
||||
|
||||
def execute(self, args: Dict[str, Any]) -> ToolResult:
|
||||
"""
|
||||
Execute file send operation
|
||||
|
||||
:param args: Contains file path and optional message
|
||||
:return: File metadata for channel to send
|
||||
"""
|
||||
path = args.get("path", "").strip()
|
||||
message = args.get("message", "")
|
||||
|
||||
if not path:
|
||||
return ToolResult.fail("Error: path parameter is required")
|
||||
|
||||
# Resolve path
|
||||
absolute_path = self._resolve_path(path)
|
||||
|
||||
# Check if file exists
|
||||
if not os.path.exists(absolute_path):
|
||||
return ToolResult.fail(f"Error: File not found: {path}")
|
||||
|
||||
# Check if readable
|
||||
if not os.access(absolute_path, os.R_OK):
|
||||
return ToolResult.fail(f"Error: File is not readable: {path}")
|
||||
|
||||
# Get file info
|
||||
file_ext = Path(absolute_path).suffix.lower()
|
||||
file_size = os.path.getsize(absolute_path)
|
||||
file_name = Path(absolute_path).name
|
||||
|
||||
# Determine file type
|
||||
if file_ext in self.image_extensions:
|
||||
file_type = "image"
|
||||
mime_type = self._get_image_mime_type(file_ext)
|
||||
elif file_ext in self.video_extensions:
|
||||
file_type = "video"
|
||||
mime_type = self._get_video_mime_type(file_ext)
|
||||
elif file_ext in self.audio_extensions:
|
||||
file_type = "audio"
|
||||
mime_type = self._get_audio_mime_type(file_ext)
|
||||
elif file_ext in self.document_extensions:
|
||||
file_type = "document"
|
||||
mime_type = self._get_document_mime_type(file_ext)
|
||||
else:
|
||||
file_type = "file"
|
||||
mime_type = "application/octet-stream"
|
||||
|
||||
# Return file_to_send metadata
|
||||
result = {
|
||||
"type": "file_to_send",
|
||||
"file_type": file_type,
|
||||
"path": absolute_path,
|
||||
"file_name": file_name,
|
||||
"mime_type": mime_type,
|
||||
"size": file_size,
|
||||
"size_formatted": self._format_size(file_size),
|
||||
"message": message or f"正在发送 {file_name}"
|
||||
}
|
||||
|
||||
return ToolResult.success(result)
|
||||
|
||||
def _resolve_path(self, path: str) -> str:
|
||||
"""Resolve path to absolute path"""
|
||||
path = expand_path(path)
|
||||
if os.path.isabs(path):
|
||||
return path
|
||||
return os.path.abspath(os.path.join(self.cwd, path))
|
||||
|
||||
def _get_image_mime_type(self, ext: str) -> str:
|
||||
"""Get MIME type for image"""
|
||||
mime_map = {
|
||||
'.jpg': 'image/jpeg', '.jpeg': 'image/jpeg',
|
||||
'.png': 'image/png', '.gif': 'image/gif',
|
||||
'.webp': 'image/webp', '.bmp': 'image/bmp',
|
||||
'.svg': 'image/svg+xml', '.ico': 'image/x-icon'
|
||||
}
|
||||
return mime_map.get(ext, 'image/jpeg')
|
||||
|
||||
def _get_video_mime_type(self, ext: str) -> str:
|
||||
"""Get MIME type for video"""
|
||||
mime_map = {
|
||||
'.mp4': 'video/mp4', '.avi': 'video/x-msvideo',
|
||||
'.mov': 'video/quicktime', '.mkv': 'video/x-matroska',
|
||||
'.webm': 'video/webm', '.flv': 'video/x-flv'
|
||||
}
|
||||
return mime_map.get(ext, 'video/mp4')
|
||||
|
||||
def _get_audio_mime_type(self, ext: str) -> str:
|
||||
"""Get MIME type for audio"""
|
||||
mime_map = {
|
||||
'.mp3': 'audio/mpeg', '.wav': 'audio/wav',
|
||||
'.ogg': 'audio/ogg', '.m4a': 'audio/mp4',
|
||||
'.flac': 'audio/flac', '.aac': 'audio/aac'
|
||||
}
|
||||
return mime_map.get(ext, 'audio/mpeg')
|
||||
|
||||
def _get_document_mime_type(self, ext: str) -> str:
|
||||
"""Get MIME type for document"""
|
||||
mime_map = {
|
||||
'.pdf': 'application/pdf',
|
||||
'.doc': 'application/msword',
|
||||
'.docx': 'application/vnd.openxmlformats-officedocument.wordprocessingml.document',
|
||||
'.xls': 'application/vnd.ms-excel',
|
||||
'.xlsx': 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet',
|
||||
'.ppt': 'application/vnd.ms-powerpoint',
|
||||
'.pptx': 'application/vnd.openxmlformats-officedocument.presentationml.presentation',
|
||||
'.txt': 'text/plain',
|
||||
'.md': 'text/markdown'
|
||||
}
|
||||
return mime_map.get(ext, 'application/octet-stream')
|
||||
|
||||
def _format_size(self, size_bytes: int) -> str:
|
||||
"""Format file size in human-readable format"""
|
||||
for unit in ['B', 'KB', 'MB', 'GB']:
|
||||
if size_bytes < 1024.0:
|
||||
return f"{size_bytes:.1f}{unit}"
|
||||
size_bytes /= 1024.0
|
||||
return f"{size_bytes:.1f}TB"
|
||||
248
agent/tools/tool_manager.py
Normal file
248
agent/tools/tool_manager.py
Normal file
@@ -0,0 +1,248 @@
|
||||
import importlib
|
||||
import importlib.util
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, Type
|
||||
from agent.tools.base_tool import BaseTool
|
||||
from common.log import logger
|
||||
from config import conf
|
||||
|
||||
|
||||
class ToolManager:
|
||||
"""
|
||||
Tool manager for managing tools.
|
||||
"""
|
||||
_instance = None
|
||||
|
||||
def __new__(cls):
|
||||
"""Singleton pattern to ensure only one instance of ToolManager exists."""
|
||||
if cls._instance is None:
|
||||
cls._instance = super(ToolManager, cls).__new__(cls)
|
||||
cls._instance.tool_classes = {} # Store tool classes instead of instances
|
||||
cls._instance._initialized = False
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
# Initialize only once
|
||||
if not hasattr(self, 'tool_classes'):
|
||||
self.tool_classes = {} # Dictionary to store tool classes
|
||||
|
||||
def load_tools(self, tools_dir: str = "", config_dict=None):
|
||||
"""
|
||||
Load tools from both directory and configuration.
|
||||
|
||||
:param tools_dir: Directory to scan for tool modules
|
||||
"""
|
||||
if tools_dir:
|
||||
self._load_tools_from_directory(tools_dir)
|
||||
self._configure_tools_from_config()
|
||||
else:
|
||||
self._load_tools_from_init()
|
||||
self._configure_tools_from_config(config_dict)
|
||||
|
||||
def _load_tools_from_init(self) -> bool:
|
||||
"""
|
||||
Load tool classes from tools.__init__.__all__
|
||||
|
||||
:return: True if tools were loaded, False otherwise
|
||||
"""
|
||||
try:
|
||||
# Try to import the tools package
|
||||
tools_package = importlib.import_module("agent.tools")
|
||||
|
||||
# Check if __all__ is defined
|
||||
if hasattr(tools_package, "__all__"):
|
||||
tool_classes = tools_package.__all__
|
||||
|
||||
# Import each tool class directly from the tools package
|
||||
for class_name in tool_classes:
|
||||
try:
|
||||
# Skip base classes
|
||||
if class_name in ["BaseTool", "ToolManager"]:
|
||||
continue
|
||||
|
||||
# Get the class directly from the tools package
|
||||
if hasattr(tools_package, class_name):
|
||||
cls = getattr(tools_package, class_name)
|
||||
|
||||
if (
|
||||
isinstance(cls, type)
|
||||
and issubclass(cls, BaseTool)
|
||||
and cls != BaseTool
|
||||
):
|
||||
try:
|
||||
# Skip memory tools (they need special initialization with memory_manager)
|
||||
if class_name in ["MemorySearchTool", "MemoryGetTool"]:
|
||||
logger.debug(f"Skipped tool {class_name} (requires memory_manager)")
|
||||
continue
|
||||
|
||||
# Create a temporary instance to get the name
|
||||
temp_instance = cls()
|
||||
tool_name = temp_instance.name
|
||||
# Store the class, not the instance
|
||||
self.tool_classes[tool_name] = cls
|
||||
logger.debug(f"Loaded tool: {tool_name} from class {class_name}")
|
||||
except ImportError as e:
|
||||
# Handle missing dependencies with helpful messages
|
||||
error_msg = str(e)
|
||||
if "browser-use" in error_msg or "browser_use" in error_msg:
|
||||
logger.warning(
|
||||
f"[ToolManager] Browser tool not loaded - missing dependencies.\n"
|
||||
f" To enable browser tool, run:\n"
|
||||
f" pip install browser-use markdownify playwright\n"
|
||||
f" playwright install chromium"
|
||||
)
|
||||
elif "markdownify" in error_msg:
|
||||
logger.warning(
|
||||
f"[ToolManager] {cls.__name__} not loaded - missing markdownify.\n"
|
||||
f" Install with: pip install markdownify"
|
||||
)
|
||||
else:
|
||||
logger.warning(f"[ToolManager] {cls.__name__} not loaded due to missing dependency: {error_msg}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error initializing tool class {cls.__name__}: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error importing class {class_name}: {e}")
|
||||
|
||||
return len(self.tool_classes) > 0
|
||||
return False
|
||||
except ImportError:
|
||||
logger.warning("Could not import agent.tools package")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading tools from __init__.__all__: {e}")
|
||||
return False
|
||||
|
||||
def _load_tools_from_directory(self, tools_dir: str):
|
||||
"""Dynamically load tool classes from directory"""
|
||||
tools_path = Path(tools_dir)
|
||||
|
||||
# Traverse all .py files
|
||||
for py_file in tools_path.rglob("*.py"):
|
||||
# Skip initialization files and base tool files
|
||||
if py_file.name in ["__init__.py", "base_tool.py", "tool_manager.py"]:
|
||||
continue
|
||||
|
||||
# Get module name
|
||||
module_name = py_file.stem
|
||||
|
||||
try:
|
||||
# Load module directly from file
|
||||
spec = importlib.util.spec_from_file_location(module_name, py_file)
|
||||
if spec and spec.loader:
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(module)
|
||||
|
||||
# Find tool classes in the module
|
||||
for attr_name in dir(module):
|
||||
cls = getattr(module, attr_name)
|
||||
if (
|
||||
isinstance(cls, type)
|
||||
and issubclass(cls, BaseTool)
|
||||
and cls != BaseTool
|
||||
):
|
||||
try:
|
||||
# Skip memory tools (they need special initialization with memory_manager)
|
||||
if attr_name in ["MemorySearchTool", "MemoryGetTool"]:
|
||||
logger.debug(f"Skipped tool {attr_name} (requires memory_manager)")
|
||||
continue
|
||||
|
||||
# Create a temporary instance to get the name
|
||||
temp_instance = cls()
|
||||
tool_name = temp_instance.name
|
||||
# Store the class, not the instance
|
||||
self.tool_classes[tool_name] = cls
|
||||
except ImportError as e:
|
||||
# Handle missing dependencies with helpful messages
|
||||
error_msg = str(e)
|
||||
if "browser-use" in error_msg or "browser_use" in error_msg:
|
||||
logger.warning(
|
||||
f"[ToolManager] Browser tool not loaded - missing dependencies.\n"
|
||||
f" To enable browser tool, run:\n"
|
||||
f" pip install browser-use markdownify playwright\n"
|
||||
f" playwright install chromium"
|
||||
)
|
||||
elif "markdownify" in error_msg:
|
||||
logger.warning(
|
||||
f"[ToolManager] {cls.__name__} not loaded - missing markdownify.\n"
|
||||
f" Install with: pip install markdownify"
|
||||
)
|
||||
else:
|
||||
logger.warning(f"[ToolManager] {cls.__name__} not loaded due to missing dependency: {error_msg}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error initializing tool class {cls.__name__}: {e}")
|
||||
except Exception as e:
|
||||
print(f"Error importing module {py_file}: {e}")
|
||||
|
||||
def _configure_tools_from_config(self, config_dict=None):
|
||||
"""Configure tool classes based on configuration file"""
|
||||
try:
|
||||
# Get tools configuration
|
||||
tools_config = config_dict or conf().get("tools", {})
|
||||
|
||||
# Record tools that are configured but not loaded
|
||||
missing_tools = []
|
||||
|
||||
# Store configurations for later use when instantiating
|
||||
self.tool_configs = tools_config
|
||||
|
||||
# Check which configured tools are missing
|
||||
for tool_name in tools_config:
|
||||
if tool_name not in self.tool_classes:
|
||||
missing_tools.append(tool_name)
|
||||
|
||||
# If there are missing tools, record warnings
|
||||
if missing_tools:
|
||||
for tool_name in missing_tools:
|
||||
if tool_name == "browser":
|
||||
logger.warning(
|
||||
f"[ToolManager] Browser tool is configured but not loaded.\n"
|
||||
f" To enable browser tool, run:\n"
|
||||
f" pip install browser-use markdownify playwright\n"
|
||||
f" playwright install chromium"
|
||||
)
|
||||
elif tool_name == "google_search":
|
||||
logger.warning(
|
||||
f"[ToolManager] Google Search tool is configured but may need API key.\n"
|
||||
f" Get API key from: https://serper.dev\n"
|
||||
f" Configure in config.json: tools.google_search.api_key"
|
||||
)
|
||||
else:
|
||||
logger.warning(f"[ToolManager] Tool '{tool_name}' is configured but could not be loaded.")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error configuring tools from config: {e}")
|
||||
|
||||
def create_tool(self, name: str) -> BaseTool:
|
||||
"""
|
||||
Get a new instance of a tool by name.
|
||||
|
||||
:param name: The name of the tool to get.
|
||||
:return: A new instance of the tool or None if not found.
|
||||
"""
|
||||
tool_class = self.tool_classes.get(name)
|
||||
if tool_class:
|
||||
# Create a new instance
|
||||
tool_instance = tool_class()
|
||||
|
||||
# Apply configuration if available
|
||||
if hasattr(self, 'tool_configs') and name in self.tool_configs:
|
||||
tool_instance.config = self.tool_configs[name]
|
||||
|
||||
return tool_instance
|
||||
return None
|
||||
|
||||
def list_tools(self) -> dict:
|
||||
"""
|
||||
Get information about all loaded tools.
|
||||
|
||||
:return: A dictionary with tool information.
|
||||
"""
|
||||
result = {}
|
||||
for name, tool_class in self.tool_classes.items():
|
||||
# Create a temporary instance to get schema
|
||||
temp_instance = tool_class()
|
||||
result[name] = {
|
||||
"description": temp_instance.description,
|
||||
"parameters": temp_instance.get_json_schema()
|
||||
}
|
||||
return result
|
||||
40
agent/tools/utils/__init__.py
Normal file
40
agent/tools/utils/__init__.py
Normal file
@@ -0,0 +1,40 @@
|
||||
from .truncate import (
|
||||
truncate_head,
|
||||
truncate_tail,
|
||||
truncate_line,
|
||||
format_size,
|
||||
TruncationResult,
|
||||
DEFAULT_MAX_LINES,
|
||||
DEFAULT_MAX_BYTES,
|
||||
GREP_MAX_LINE_LENGTH
|
||||
)
|
||||
|
||||
from .diff import (
|
||||
strip_bom,
|
||||
detect_line_ending,
|
||||
normalize_to_lf,
|
||||
restore_line_endings,
|
||||
normalize_for_fuzzy_match,
|
||||
fuzzy_find_text,
|
||||
generate_diff_string,
|
||||
FuzzyMatchResult
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
'truncate_head',
|
||||
'truncate_tail',
|
||||
'truncate_line',
|
||||
'format_size',
|
||||
'TruncationResult',
|
||||
'DEFAULT_MAX_LINES',
|
||||
'DEFAULT_MAX_BYTES',
|
||||
'GREP_MAX_LINE_LENGTH',
|
||||
'strip_bom',
|
||||
'detect_line_ending',
|
||||
'normalize_to_lf',
|
||||
'restore_line_endings',
|
||||
'normalize_for_fuzzy_match',
|
||||
'fuzzy_find_text',
|
||||
'generate_diff_string',
|
||||
'FuzzyMatchResult'
|
||||
]
|
||||
167
agent/tools/utils/diff.py
Normal file
167
agent/tools/utils/diff.py
Normal file
@@ -0,0 +1,167 @@
|
||||
"""
|
||||
Diff tools for file editing
|
||||
Provides fuzzy matching and diff generation functionality
|
||||
"""
|
||||
|
||||
import difflib
|
||||
import re
|
||||
from typing import Optional, Tuple
|
||||
|
||||
|
||||
def strip_bom(text: str) -> Tuple[str, str]:
|
||||
"""
|
||||
Remove BOM (Byte Order Mark)
|
||||
|
||||
:param text: Original text
|
||||
:return: (BOM, text after removing BOM)
|
||||
"""
|
||||
if text.startswith('\ufeff'):
|
||||
return '\ufeff', text[1:]
|
||||
return '', text
|
||||
|
||||
|
||||
def detect_line_ending(text: str) -> str:
|
||||
"""
|
||||
Detect line ending type
|
||||
|
||||
:param text: Text content
|
||||
:return: Line ending type ('\r\n' or '\n')
|
||||
"""
|
||||
if '\r\n' in text:
|
||||
return '\r\n'
|
||||
return '\n'
|
||||
|
||||
|
||||
def normalize_to_lf(text: str) -> str:
|
||||
"""
|
||||
Normalize all line endings to LF (\n)
|
||||
|
||||
:param text: Original text
|
||||
:return: Normalized text
|
||||
"""
|
||||
return text.replace('\r\n', '\n').replace('\r', '\n')
|
||||
|
||||
|
||||
def restore_line_endings(text: str, original_ending: str) -> str:
|
||||
"""
|
||||
Restore original line endings
|
||||
|
||||
:param text: LF normalized text
|
||||
:param original_ending: Original line ending
|
||||
:return: Text with restored line endings
|
||||
"""
|
||||
if original_ending == '\r\n':
|
||||
return text.replace('\n', '\r\n')
|
||||
return text
|
||||
|
||||
|
||||
def normalize_for_fuzzy_match(text: str) -> str:
|
||||
"""
|
||||
Normalize text for fuzzy matching
|
||||
Remove excess whitespace but preserve basic structure
|
||||
|
||||
:param text: Original text
|
||||
:return: Normalized text
|
||||
"""
|
||||
# Compress multiple spaces to one
|
||||
text = re.sub(r'[ \t]+', ' ', text)
|
||||
# Remove trailing spaces
|
||||
text = re.sub(r' +\n', '\n', text)
|
||||
# Remove leading spaces (but preserve indentation structure, only remove excess)
|
||||
lines = text.split('\n')
|
||||
normalized_lines = []
|
||||
for line in lines:
|
||||
# Preserve indentation but normalize to multiples of single spaces
|
||||
stripped = line.lstrip()
|
||||
if stripped:
|
||||
indent_count = len(line) - len(stripped)
|
||||
# Normalize indentation (convert tabs to spaces)
|
||||
normalized_indent = ' ' * indent_count
|
||||
normalized_lines.append(normalized_indent + stripped)
|
||||
else:
|
||||
normalized_lines.append('')
|
||||
return '\n'.join(normalized_lines)
|
||||
|
||||
|
||||
class FuzzyMatchResult:
|
||||
"""Fuzzy match result"""
|
||||
|
||||
def __init__(self, found: bool, index: int = -1, match_length: int = 0, content_for_replacement: str = ""):
|
||||
self.found = found
|
||||
self.index = index
|
||||
self.match_length = match_length
|
||||
self.content_for_replacement = content_for_replacement
|
||||
|
||||
|
||||
def fuzzy_find_text(content: str, old_text: str) -> FuzzyMatchResult:
|
||||
"""
|
||||
Find text in content, try exact match first, then fuzzy match
|
||||
|
||||
:param content: Content to search in
|
||||
:param old_text: Text to find
|
||||
:return: Match result
|
||||
"""
|
||||
# First try exact match
|
||||
index = content.find(old_text)
|
||||
if index != -1:
|
||||
return FuzzyMatchResult(
|
||||
found=True,
|
||||
index=index,
|
||||
match_length=len(old_text),
|
||||
content_for_replacement=content
|
||||
)
|
||||
|
||||
# Try fuzzy match
|
||||
fuzzy_content = normalize_for_fuzzy_match(content)
|
||||
fuzzy_old_text = normalize_for_fuzzy_match(old_text)
|
||||
|
||||
index = fuzzy_content.find(fuzzy_old_text)
|
||||
if index != -1:
|
||||
# Fuzzy match successful, use normalized content for replacement
|
||||
return FuzzyMatchResult(
|
||||
found=True,
|
||||
index=index,
|
||||
match_length=len(fuzzy_old_text),
|
||||
content_for_replacement=fuzzy_content
|
||||
)
|
||||
|
||||
# Not found
|
||||
return FuzzyMatchResult(found=False)
|
||||
|
||||
|
||||
def generate_diff_string(old_content: str, new_content: str) -> dict:
|
||||
"""
|
||||
Generate unified diff string
|
||||
|
||||
:param old_content: Old content
|
||||
:param new_content: New content
|
||||
:return: Dictionary containing diff and first changed line number
|
||||
"""
|
||||
old_lines = old_content.split('\n')
|
||||
new_lines = new_content.split('\n')
|
||||
|
||||
# Generate unified diff
|
||||
diff_lines = list(difflib.unified_diff(
|
||||
old_lines,
|
||||
new_lines,
|
||||
lineterm='',
|
||||
fromfile='original',
|
||||
tofile='modified'
|
||||
))
|
||||
|
||||
# Find first changed line number
|
||||
first_changed_line = None
|
||||
for line in diff_lines:
|
||||
if line.startswith('@@'):
|
||||
# Parse @@ -1,3 +1,3 @@ format
|
||||
match = re.search(r'@@ -\d+,?\d* \+(\d+)', line)
|
||||
if match:
|
||||
first_changed_line = int(match.group(1))
|
||||
break
|
||||
|
||||
diff_string = '\n'.join(diff_lines)
|
||||
|
||||
return {
|
||||
'diff': diff_string,
|
||||
'first_changed_line': first_changed_line
|
||||
}
|
||||
292
agent/tools/utils/truncate.py
Normal file
292
agent/tools/utils/truncate.py
Normal file
@@ -0,0 +1,292 @@
|
||||
"""
|
||||
Shared truncation utilities for tool outputs.
|
||||
|
||||
Truncation is based on two independent limits - whichever is hit first wins:
|
||||
- Line limit (default: 2000 lines)
|
||||
- Byte limit (default: 50KB)
|
||||
|
||||
Never returns partial lines (except bash tail truncation edge case).
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, Optional, Literal, Tuple
|
||||
|
||||
|
||||
DEFAULT_MAX_LINES = 2000
|
||||
DEFAULT_MAX_BYTES = 50 * 1024 # 50KB
|
||||
GREP_MAX_LINE_LENGTH = 500 # Max chars per grep match line
|
||||
|
||||
|
||||
class TruncationResult:
|
||||
"""Truncation result"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
content: str,
|
||||
truncated: bool,
|
||||
truncated_by: Optional[Literal["lines", "bytes"]],
|
||||
total_lines: int,
|
||||
total_bytes: int,
|
||||
output_lines: int,
|
||||
output_bytes: int,
|
||||
last_line_partial: bool = False,
|
||||
first_line_exceeds_limit: bool = False,
|
||||
max_lines: int = DEFAULT_MAX_LINES,
|
||||
max_bytes: int = DEFAULT_MAX_BYTES
|
||||
):
|
||||
self.content = content
|
||||
self.truncated = truncated
|
||||
self.truncated_by = truncated_by
|
||||
self.total_lines = total_lines
|
||||
self.total_bytes = total_bytes
|
||||
self.output_lines = output_lines
|
||||
self.output_bytes = output_bytes
|
||||
self.last_line_partial = last_line_partial
|
||||
self.first_line_exceeds_limit = first_line_exceeds_limit
|
||||
self.max_lines = max_lines
|
||||
self.max_bytes = max_bytes
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary"""
|
||||
return {
|
||||
"content": self.content,
|
||||
"truncated": self.truncated,
|
||||
"truncated_by": self.truncated_by,
|
||||
"total_lines": self.total_lines,
|
||||
"total_bytes": self.total_bytes,
|
||||
"output_lines": self.output_lines,
|
||||
"output_bytes": self.output_bytes,
|
||||
"last_line_partial": self.last_line_partial,
|
||||
"first_line_exceeds_limit": self.first_line_exceeds_limit,
|
||||
"max_lines": self.max_lines,
|
||||
"max_bytes": self.max_bytes
|
||||
}
|
||||
|
||||
|
||||
def format_size(bytes_count: int) -> str:
|
||||
"""Format bytes as human-readable size"""
|
||||
if bytes_count < 1024:
|
||||
return f"{bytes_count}B"
|
||||
elif bytes_count < 1024 * 1024:
|
||||
return f"{bytes_count / 1024:.1f}KB"
|
||||
else:
|
||||
return f"{bytes_count / (1024 * 1024):.1f}MB"
|
||||
|
||||
|
||||
def truncate_head(content: str, max_lines: Optional[int] = None, max_bytes: Optional[int] = None) -> TruncationResult:
|
||||
"""
|
||||
Truncate content from the head (keep first N lines/bytes).
|
||||
Suitable for file reads where you want to see the beginning.
|
||||
|
||||
Never returns partial lines. If first line exceeds byte limit,
|
||||
returns empty content with first_line_exceeds_limit=True.
|
||||
|
||||
:param content: Content to truncate
|
||||
:param max_lines: Maximum number of lines (default: 2000)
|
||||
:param max_bytes: Maximum number of bytes (default: 50KB)
|
||||
:return: Truncation result
|
||||
"""
|
||||
if max_lines is None:
|
||||
max_lines = DEFAULT_MAX_LINES
|
||||
if max_bytes is None:
|
||||
max_bytes = DEFAULT_MAX_BYTES
|
||||
|
||||
total_bytes = len(content.encode('utf-8'))
|
||||
lines = content.split('\n')
|
||||
total_lines = len(lines)
|
||||
|
||||
# Check if no truncation is needed
|
||||
if total_lines <= max_lines and total_bytes <= max_bytes:
|
||||
return TruncationResult(
|
||||
content=content,
|
||||
truncated=False,
|
||||
truncated_by=None,
|
||||
total_lines=total_lines,
|
||||
total_bytes=total_bytes,
|
||||
output_lines=total_lines,
|
||||
output_bytes=total_bytes,
|
||||
last_line_partial=False,
|
||||
first_line_exceeds_limit=False,
|
||||
max_lines=max_lines,
|
||||
max_bytes=max_bytes
|
||||
)
|
||||
|
||||
# Check if first line alone exceeds byte limit
|
||||
first_line_bytes = len(lines[0].encode('utf-8'))
|
||||
if first_line_bytes > max_bytes:
|
||||
return TruncationResult(
|
||||
content="",
|
||||
truncated=True,
|
||||
truncated_by="bytes",
|
||||
total_lines=total_lines,
|
||||
total_bytes=total_bytes,
|
||||
output_lines=0,
|
||||
output_bytes=0,
|
||||
last_line_partial=False,
|
||||
first_line_exceeds_limit=True,
|
||||
max_lines=max_lines,
|
||||
max_bytes=max_bytes
|
||||
)
|
||||
|
||||
# Collect complete lines that fit
|
||||
output_lines_arr = []
|
||||
output_bytes_count = 0
|
||||
truncated_by = "lines"
|
||||
|
||||
for i, line in enumerate(lines):
|
||||
if i >= max_lines:
|
||||
break
|
||||
|
||||
# Calculate line bytes (add 1 for newline if not first line)
|
||||
line_bytes = len(line.encode('utf-8')) + (1 if i > 0 else 0)
|
||||
|
||||
if output_bytes_count + line_bytes > max_bytes:
|
||||
truncated_by = "bytes"
|
||||
break
|
||||
|
||||
output_lines_arr.append(line)
|
||||
output_bytes_count += line_bytes
|
||||
|
||||
# If exited due to line limit
|
||||
if len(output_lines_arr) >= max_lines and output_bytes_count <= max_bytes:
|
||||
truncated_by = "lines"
|
||||
|
||||
output_content = '\n'.join(output_lines_arr)
|
||||
final_output_bytes = len(output_content.encode('utf-8'))
|
||||
|
||||
return TruncationResult(
|
||||
content=output_content,
|
||||
truncated=True,
|
||||
truncated_by=truncated_by,
|
||||
total_lines=total_lines,
|
||||
total_bytes=total_bytes,
|
||||
output_lines=len(output_lines_arr),
|
||||
output_bytes=final_output_bytes,
|
||||
last_line_partial=False,
|
||||
first_line_exceeds_limit=False,
|
||||
max_lines=max_lines,
|
||||
max_bytes=max_bytes
|
||||
)
|
||||
|
||||
|
||||
def truncate_tail(content: str, max_lines: Optional[int] = None, max_bytes: Optional[int] = None) -> TruncationResult:
|
||||
"""
|
||||
Truncate content from tail (keep last N lines/bytes).
|
||||
Suitable for bash output where you want to see the ending content (errors, final results).
|
||||
|
||||
If the last line of original content exceeds byte limit, may return partial first line.
|
||||
|
||||
:param content: Content to truncate
|
||||
:param max_lines: Maximum lines (default: 2000)
|
||||
:param max_bytes: Maximum bytes (default: 50KB)
|
||||
:return: Truncation result
|
||||
"""
|
||||
if max_lines is None:
|
||||
max_lines = DEFAULT_MAX_LINES
|
||||
if max_bytes is None:
|
||||
max_bytes = DEFAULT_MAX_BYTES
|
||||
|
||||
total_bytes = len(content.encode('utf-8'))
|
||||
lines = content.split('\n')
|
||||
total_lines = len(lines)
|
||||
|
||||
# Check if no truncation is needed
|
||||
if total_lines <= max_lines and total_bytes <= max_bytes:
|
||||
return TruncationResult(
|
||||
content=content,
|
||||
truncated=False,
|
||||
truncated_by=None,
|
||||
total_lines=total_lines,
|
||||
total_bytes=total_bytes,
|
||||
output_lines=total_lines,
|
||||
output_bytes=total_bytes,
|
||||
last_line_partial=False,
|
||||
first_line_exceeds_limit=False,
|
||||
max_lines=max_lines,
|
||||
max_bytes=max_bytes
|
||||
)
|
||||
|
||||
# Work backwards from the end
|
||||
output_lines_arr = []
|
||||
output_bytes_count = 0
|
||||
truncated_by = "lines"
|
||||
last_line_partial = False
|
||||
|
||||
for i in range(len(lines) - 1, -1, -1):
|
||||
if len(output_lines_arr) >= max_lines:
|
||||
break
|
||||
|
||||
line = lines[i]
|
||||
# Calculate line bytes (add newline if not the first added line)
|
||||
line_bytes = len(line.encode('utf-8')) + (1 if len(output_lines_arr) > 0 else 0)
|
||||
|
||||
if output_bytes_count + line_bytes > max_bytes:
|
||||
truncated_by = "bytes"
|
||||
# Edge case: if we haven't added any lines yet and this line exceeds maxBytes,
|
||||
# take the end portion of this line
|
||||
if len(output_lines_arr) == 0:
|
||||
truncated_line = _truncate_string_to_bytes_from_end(line, max_bytes)
|
||||
output_lines_arr.insert(0, truncated_line)
|
||||
output_bytes_count = len(truncated_line.encode('utf-8'))
|
||||
last_line_partial = True
|
||||
break
|
||||
|
||||
output_lines_arr.insert(0, line)
|
||||
output_bytes_count += line_bytes
|
||||
|
||||
# If exited due to line limit
|
||||
if len(output_lines_arr) >= max_lines and output_bytes_count <= max_bytes:
|
||||
truncated_by = "lines"
|
||||
|
||||
output_content = '\n'.join(output_lines_arr)
|
||||
final_output_bytes = len(output_content.encode('utf-8'))
|
||||
|
||||
return TruncationResult(
|
||||
content=output_content,
|
||||
truncated=True,
|
||||
truncated_by=truncated_by,
|
||||
total_lines=total_lines,
|
||||
total_bytes=total_bytes,
|
||||
output_lines=len(output_lines_arr),
|
||||
output_bytes=final_output_bytes,
|
||||
last_line_partial=last_line_partial,
|
||||
first_line_exceeds_limit=False,
|
||||
max_lines=max_lines,
|
||||
max_bytes=max_bytes
|
||||
)
|
||||
|
||||
|
||||
def _truncate_string_to_bytes_from_end(text: str, max_bytes: int) -> str:
|
||||
"""
|
||||
Truncate string to fit byte limit (from end).
|
||||
Properly handles multi-byte UTF-8 characters.
|
||||
|
||||
:param text: String to truncate
|
||||
:param max_bytes: Maximum bytes
|
||||
:return: Truncated string
|
||||
"""
|
||||
encoded = text.encode('utf-8')
|
||||
if len(encoded) <= max_bytes:
|
||||
return text
|
||||
|
||||
# Start from end, skip back maxBytes
|
||||
start = len(encoded) - max_bytes
|
||||
|
||||
# Find valid UTF-8 boundary (character start)
|
||||
while start < len(encoded) and (encoded[start] & 0xC0) == 0x80:
|
||||
start += 1
|
||||
|
||||
return encoded[start:].decode('utf-8', errors='ignore')
|
||||
|
||||
|
||||
def truncate_line(line: str, max_chars: int = GREP_MAX_LINE_LENGTH) -> Tuple[str, bool]:
|
||||
"""
|
||||
Truncate single line to max characters, add [truncated] suffix.
|
||||
Used for grep match lines.
|
||||
|
||||
:param line: Line to truncate
|
||||
:param max_chars: Maximum characters
|
||||
:return: (truncated text, whether truncated)
|
||||
"""
|
||||
if len(line) <= max_chars:
|
||||
return line, False
|
||||
return f"{line[:max_chars]}... [truncated]", True
|
||||
3
agent/tools/web_search/__init__.py
Normal file
3
agent/tools/web_search/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from agent.tools.web_search.web_search import WebSearch
|
||||
|
||||
__all__ = ["WebSearch"]
|
||||
322
agent/tools/web_search/web_search.py
Normal file
322
agent/tools/web_search/web_search.py
Normal file
@@ -0,0 +1,322 @@
|
||||
"""
|
||||
Web Search tool - Search the web using Bocha or LinkAI search API.
|
||||
Supports two backends with unified response format:
|
||||
1. Bocha Search (primary, requires BOCHA_API_KEY)
|
||||
2. LinkAI Search (fallback, requires LINKAI_API_KEY)
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
from typing import Dict, Any, Optional
|
||||
|
||||
import requests
|
||||
|
||||
from agent.tools.base_tool import BaseTool, ToolResult
|
||||
from common.log import logger
|
||||
|
||||
|
||||
# Default timeout for API requests (seconds)
|
||||
DEFAULT_TIMEOUT = 30
|
||||
|
||||
|
||||
class WebSearch(BaseTool):
|
||||
"""Tool for searching the web using Bocha or LinkAI search API"""
|
||||
|
||||
name: str = "web_search"
|
||||
description: str = (
|
||||
"Search the web for current information, news, research topics, or any real-time data. "
|
||||
"Returns web page titles, URLs, snippets, and optional summaries. "
|
||||
"Use this when the user asks about recent events, needs fact-checking, or wants up-to-date information."
|
||||
)
|
||||
|
||||
params: dict = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "Search query string"
|
||||
},
|
||||
"count": {
|
||||
"type": "integer",
|
||||
"description": "Number of results to return (1-50, default: 10)"
|
||||
},
|
||||
"freshness": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Time range filter. Options: "
|
||||
"'noLimit' (default), 'oneDay', 'oneWeek', 'oneMonth', 'oneYear', "
|
||||
"or date range like '2025-01-01..2025-02-01'"
|
||||
)
|
||||
},
|
||||
"summary": {
|
||||
"type": "boolean",
|
||||
"description": "Whether to include text summary for each result (default: false)"
|
||||
}
|
||||
},
|
||||
"required": ["query"]
|
||||
}
|
||||
|
||||
def __init__(self, config: dict = None):
|
||||
self.config = config or {}
|
||||
self._backend = None # Will be resolved on first execute
|
||||
|
||||
@staticmethod
|
||||
def is_available() -> bool:
|
||||
"""Check if web search is available (at least one API key is configured)"""
|
||||
return bool(os.environ.get("BOCHA_API_KEY") or os.environ.get("LINKAI_API_KEY"))
|
||||
|
||||
def _resolve_backend(self) -> Optional[str]:
|
||||
"""
|
||||
Determine which search backend to use.
|
||||
Priority: Bocha > LinkAI
|
||||
|
||||
:return: 'bocha', 'linkai', or None
|
||||
"""
|
||||
if os.environ.get("BOCHA_API_KEY"):
|
||||
return "bocha"
|
||||
if os.environ.get("LINKAI_API_KEY"):
|
||||
return "linkai"
|
||||
return None
|
||||
|
||||
def execute(self, args: Dict[str, Any]) -> ToolResult:
|
||||
"""
|
||||
Execute web search
|
||||
|
||||
:param args: Search parameters (query, count, freshness, summary)
|
||||
:return: Search results
|
||||
"""
|
||||
query = args.get("query", "").strip()
|
||||
if not query:
|
||||
return ToolResult.fail("Error: 'query' parameter is required")
|
||||
|
||||
count = args.get("count", 10)
|
||||
freshness = args.get("freshness", "noLimit")
|
||||
summary = args.get("summary", False)
|
||||
|
||||
# Validate count
|
||||
if not isinstance(count, int) or count < 1 or count > 50:
|
||||
count = 10
|
||||
|
||||
# Resolve backend
|
||||
backend = self._resolve_backend()
|
||||
if not backend:
|
||||
return ToolResult.fail(
|
||||
"Error: No search API key configured. "
|
||||
"Please set BOCHA_API_KEY or LINKAI_API_KEY using env_config tool.\n"
|
||||
" - Bocha Search: https://open.bocha.cn\n"
|
||||
" - LinkAI Search: https://link-ai.tech"
|
||||
)
|
||||
|
||||
try:
|
||||
if backend == "bocha":
|
||||
return self._search_bocha(query, count, freshness, summary)
|
||||
else:
|
||||
return self._search_linkai(query, count, freshness)
|
||||
except requests.Timeout:
|
||||
return ToolResult.fail(f"Error: Search request timed out after {DEFAULT_TIMEOUT}s")
|
||||
except requests.ConnectionError:
|
||||
return ToolResult.fail("Error: Failed to connect to search API")
|
||||
except Exception as e:
|
||||
logger.error(f"[WebSearch] Unexpected error: {e}", exc_info=True)
|
||||
return ToolResult.fail(f"Error: Search failed - {str(e)}")
|
||||
|
||||
def _search_bocha(self, query: str, count: int, freshness: str, summary: bool) -> ToolResult:
|
||||
"""
|
||||
Search using Bocha API
|
||||
|
||||
:param query: Search query
|
||||
:param count: Number of results
|
||||
:param freshness: Time range filter
|
||||
:param summary: Whether to include summary
|
||||
:return: Formatted search results
|
||||
"""
|
||||
api_key = os.environ.get("BOCHA_API_KEY", "")
|
||||
url = "https://api.bocha.cn/v1/web-search"
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json"
|
||||
}
|
||||
|
||||
payload = {
|
||||
"query": query,
|
||||
"count": count,
|
||||
"freshness": freshness,
|
||||
"summary": summary
|
||||
}
|
||||
|
||||
logger.debug(f"[WebSearch] Bocha search: query='{query}', count={count}")
|
||||
|
||||
response = requests.post(url, headers=headers, json=payload, timeout=DEFAULT_TIMEOUT)
|
||||
|
||||
if response.status_code == 401:
|
||||
return ToolResult.fail("Error: Invalid BOCHA_API_KEY. Please check your API key.")
|
||||
if response.status_code == 403:
|
||||
return ToolResult.fail("Error: Bocha API - insufficient balance. Please top up at https://open.bocha.cn")
|
||||
if response.status_code == 429:
|
||||
return ToolResult.fail("Error: Bocha API rate limit reached. Please try again later.")
|
||||
if response.status_code != 200:
|
||||
return ToolResult.fail(f"Error: Bocha API returned HTTP {response.status_code}")
|
||||
|
||||
data = response.json()
|
||||
|
||||
# Check API-level error code
|
||||
api_code = data.get("code")
|
||||
if api_code is not None and api_code != 200:
|
||||
msg = data.get("msg") or "Unknown error"
|
||||
return ToolResult.fail(f"Error: Bocha API error (code={api_code}): {msg}")
|
||||
|
||||
# Extract and format results
|
||||
return self._format_bocha_results(data, query)
|
||||
|
||||
def _format_bocha_results(self, data: dict, query: str) -> ToolResult:
|
||||
"""
|
||||
Format Bocha API response into unified result structure
|
||||
|
||||
:param data: Raw API response
|
||||
:param query: Original query
|
||||
:return: Formatted ToolResult
|
||||
"""
|
||||
search_data = data.get("data", {})
|
||||
web_pages = search_data.get("webPages", {})
|
||||
pages = web_pages.get("value", [])
|
||||
|
||||
if not pages:
|
||||
return ToolResult.success({
|
||||
"query": query,
|
||||
"backend": "bocha",
|
||||
"total": 0,
|
||||
"results": [],
|
||||
"message": "No results found"
|
||||
})
|
||||
|
||||
results = []
|
||||
for page in pages:
|
||||
result = {
|
||||
"title": page.get("name", ""),
|
||||
"url": page.get("url", ""),
|
||||
"snippet": page.get("snippet", ""),
|
||||
"siteName": page.get("siteName", ""),
|
||||
"datePublished": page.get("datePublished") or page.get("dateLastCrawled", ""),
|
||||
}
|
||||
# Include summary only if present
|
||||
if page.get("summary"):
|
||||
result["summary"] = page["summary"]
|
||||
results.append(result)
|
||||
|
||||
total = web_pages.get("totalEstimatedMatches", len(results))
|
||||
|
||||
return ToolResult.success({
|
||||
"query": query,
|
||||
"backend": "bocha",
|
||||
"total": total,
|
||||
"count": len(results),
|
||||
"results": results
|
||||
})
|
||||
|
||||
def _search_linkai(self, query: str, count: int, freshness: str) -> ToolResult:
|
||||
"""
|
||||
Search using LinkAI plugin API
|
||||
|
||||
:param query: Search query
|
||||
:param count: Number of results
|
||||
:param freshness: Time range filter
|
||||
:return: Formatted search results
|
||||
"""
|
||||
api_key = os.environ.get("LINKAI_API_KEY", "")
|
||||
url = "https://api.link-ai.tech/v1/plugin/execute"
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {api_key}"
|
||||
}
|
||||
|
||||
payload = {
|
||||
"code": "web-search",
|
||||
"args": {
|
||||
"query": query,
|
||||
"count": count,
|
||||
"freshness": freshness
|
||||
}
|
||||
}
|
||||
|
||||
logger.debug(f"[WebSearch] LinkAI search: query='{query}', count={count}")
|
||||
|
||||
response = requests.post(url, headers=headers, json=payload, timeout=DEFAULT_TIMEOUT)
|
||||
|
||||
if response.status_code == 401:
|
||||
return ToolResult.fail("Error: Invalid LINKAI_API_KEY. Please check your API key.")
|
||||
if response.status_code != 200:
|
||||
return ToolResult.fail(f"Error: LinkAI API returned HTTP {response.status_code}")
|
||||
|
||||
data = response.json()
|
||||
|
||||
if not data.get("success"):
|
||||
msg = data.get("message") or "Unknown error"
|
||||
return ToolResult.fail(f"Error: LinkAI search failed: {msg}")
|
||||
|
||||
return self._format_linkai_results(data, query)
|
||||
|
||||
def _format_linkai_results(self, data: dict, query: str) -> ToolResult:
|
||||
"""
|
||||
Format LinkAI API response into unified result structure.
|
||||
LinkAI returns the search data in data.data field, which follows
|
||||
the same Bing-compatible format as Bocha.
|
||||
|
||||
:param data: Raw API response
|
||||
:param query: Original query
|
||||
:return: Formatted ToolResult
|
||||
"""
|
||||
raw_data = data.get("data", "")
|
||||
|
||||
# LinkAI may return data as a JSON string
|
||||
if isinstance(raw_data, str):
|
||||
try:
|
||||
raw_data = json.loads(raw_data)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
# If data is plain text, return it as a single result
|
||||
return ToolResult.success({
|
||||
"query": query,
|
||||
"backend": "linkai",
|
||||
"total": 1,
|
||||
"count": 1,
|
||||
"results": [{"content": raw_data}]
|
||||
})
|
||||
|
||||
# If the response follows Bing-compatible structure
|
||||
if isinstance(raw_data, dict):
|
||||
web_pages = raw_data.get("webPages", {})
|
||||
pages = web_pages.get("value", [])
|
||||
|
||||
if pages:
|
||||
results = []
|
||||
for page in pages:
|
||||
result = {
|
||||
"title": page.get("name", ""),
|
||||
"url": page.get("url", ""),
|
||||
"snippet": page.get("snippet", ""),
|
||||
"siteName": page.get("siteName", ""),
|
||||
"datePublished": page.get("datePublished") or page.get("dateLastCrawled", ""),
|
||||
}
|
||||
if page.get("summary"):
|
||||
result["summary"] = page["summary"]
|
||||
results.append(result)
|
||||
|
||||
total = web_pages.get("totalEstimatedMatches", len(results))
|
||||
return ToolResult.success({
|
||||
"query": query,
|
||||
"backend": "linkai",
|
||||
"total": total,
|
||||
"count": len(results),
|
||||
"results": results
|
||||
})
|
||||
|
||||
# Fallback: return raw data
|
||||
return ToolResult.success({
|
||||
"query": query,
|
||||
"backend": "linkai",
|
||||
"total": 1,
|
||||
"count": 1,
|
||||
"results": [{"content": str(raw_data)}]
|
||||
})
|
||||
3
agent/tools/write/__init__.py
Normal file
3
agent/tools/write/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .write import Write
|
||||
|
||||
__all__ = ['Write']
|
||||
97
agent/tools/write/write.py
Normal file
97
agent/tools/write/write.py
Normal file
@@ -0,0 +1,97 @@
|
||||
"""
|
||||
Write tool - Write file content
|
||||
Creates or overwrites files, automatically creates parent directories
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Dict, Any
|
||||
from pathlib import Path
|
||||
|
||||
from agent.tools.base_tool import BaseTool, ToolResult
|
||||
from common.utils import expand_path
|
||||
|
||||
|
||||
class Write(BaseTool):
|
||||
"""Tool for writing file content"""
|
||||
|
||||
name: str = "write"
|
||||
description: str = "Write content to a file. Creates the file if it doesn't exist, overwrites if it does. Automatically creates parent directories. IMPORTANT: Single write should not exceed 10KB. For large files, create a skeleton first, then use edit to add content in chunks."
|
||||
|
||||
params: dict = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "Path to the file to write (relative or absolute)"
|
||||
},
|
||||
"content": {
|
||||
"type": "string",
|
||||
"description": "Content to write to the file"
|
||||
}
|
||||
},
|
||||
"required": ["path", "content"]
|
||||
}
|
||||
|
||||
def __init__(self, config: dict = None):
|
||||
self.config = config or {}
|
||||
self.cwd = self.config.get("cwd", os.getcwd())
|
||||
self.memory_manager = self.config.get("memory_manager", None)
|
||||
|
||||
def execute(self, args: Dict[str, Any]) -> ToolResult:
|
||||
"""
|
||||
Execute file write operation
|
||||
|
||||
:param args: Contains file path and content
|
||||
:return: Operation result
|
||||
"""
|
||||
path = args.get("path", "").strip()
|
||||
content = args.get("content", "")
|
||||
|
||||
if not path:
|
||||
return ToolResult.fail("Error: path parameter is required")
|
||||
|
||||
# Resolve path
|
||||
absolute_path = self._resolve_path(path)
|
||||
|
||||
try:
|
||||
# Create parent directory (if needed)
|
||||
parent_dir = os.path.dirname(absolute_path)
|
||||
if parent_dir:
|
||||
os.makedirs(parent_dir, exist_ok=True)
|
||||
|
||||
# Write file
|
||||
with open(absolute_path, 'w', encoding='utf-8') as f:
|
||||
f.write(content)
|
||||
|
||||
# Get bytes written
|
||||
bytes_written = len(content.encode('utf-8'))
|
||||
|
||||
# Auto-sync to memory database if this is a memory file
|
||||
if self.memory_manager and 'memory/' in path:
|
||||
self.memory_manager.mark_dirty()
|
||||
|
||||
result = {
|
||||
"message": f"Successfully wrote {bytes_written} bytes to {path}",
|
||||
"path": path,
|
||||
"bytes_written": bytes_written
|
||||
}
|
||||
|
||||
return ToolResult.success(result)
|
||||
|
||||
except PermissionError:
|
||||
return ToolResult.fail(f"Error: Permission denied writing to {path}")
|
||||
except Exception as e:
|
||||
return ToolResult.fail(f"Error writing file: {str(e)}")
|
||||
|
||||
def _resolve_path(self, path: str) -> str:
|
||||
"""
|
||||
Resolve path to absolute path
|
||||
|
||||
:param path: Relative or absolute path
|
||||
:return: Absolute path
|
||||
"""
|
||||
# Expand ~ to user home directory
|
||||
path = expand_path(path)
|
||||
if os.path.isabs(path):
|
||||
return path
|
||||
return os.path.abspath(os.path.join(self.cwd, path))
|
||||
260
app.py
260
app.py
@@ -1,22 +1,236 @@
|
||||
# encoding:utf-8
|
||||
|
||||
import os
|
||||
from config import conf, load_config
|
||||
from channel import channel_factory
|
||||
from common.log import logger
|
||||
from plugins import *
|
||||
import signal
|
||||
import sys
|
||||
import time
|
||||
|
||||
from channel import channel_factory
|
||||
from common import const
|
||||
from common.log import logger
|
||||
from config import load_config, conf
|
||||
from plugins import *
|
||||
import threading
|
||||
|
||||
|
||||
_channel_mgr = None
|
||||
|
||||
|
||||
def get_channel_manager():
|
||||
return _channel_mgr
|
||||
|
||||
|
||||
def _parse_channel_type(raw) -> list:
|
||||
"""
|
||||
Parse channel_type config value into a list of channel names.
|
||||
Supports:
|
||||
- single string: "feishu"
|
||||
- comma-separated string: "feishu, dingtalk"
|
||||
- list: ["feishu", "dingtalk"]
|
||||
"""
|
||||
if isinstance(raw, list):
|
||||
return [ch.strip() for ch in raw if ch.strip()]
|
||||
if isinstance(raw, str):
|
||||
return [ch.strip() for ch in raw.split(",") if ch.strip()]
|
||||
return []
|
||||
|
||||
|
||||
class ChannelManager:
|
||||
"""
|
||||
Manage the lifecycle of multiple channels running concurrently.
|
||||
Each channel.startup() runs in its own daemon thread.
|
||||
The web channel is started as default console unless explicitly disabled.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._channels = {} # channel_name -> channel instance
|
||||
self._threads = {} # channel_name -> thread
|
||||
self._primary_channel = None
|
||||
self._lock = threading.Lock()
|
||||
|
||||
@property
|
||||
def channel(self):
|
||||
"""Return the primary (first non-web) channel for backward compatibility."""
|
||||
return self._primary_channel
|
||||
|
||||
def get_channel(self, channel_name: str):
|
||||
return self._channels.get(channel_name)
|
||||
|
||||
def start(self, channel_names: list, first_start: bool = False):
|
||||
"""
|
||||
Create and start one or more channels in sub-threads.
|
||||
If first_start is True, plugins and linkai client will also be initialized.
|
||||
"""
|
||||
with self._lock:
|
||||
channels = []
|
||||
for name in channel_names:
|
||||
ch = channel_factory.create_channel(name)
|
||||
self._channels[name] = ch
|
||||
channels.append((name, ch))
|
||||
if self._primary_channel is None and name != "web":
|
||||
self._primary_channel = ch
|
||||
|
||||
if self._primary_channel is None and channels:
|
||||
self._primary_channel = channels[0][1]
|
||||
|
||||
if first_start:
|
||||
PluginManager().load_plugins()
|
||||
|
||||
if conf().get("use_linkai"):
|
||||
try:
|
||||
from common import cloud_client
|
||||
threading.Thread(
|
||||
target=cloud_client.start,
|
||||
args=(self._primary_channel, self),
|
||||
daemon=True,
|
||||
).start()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Start web console first so its logs print cleanly,
|
||||
# then start remaining channels after a brief pause.
|
||||
web_entry = None
|
||||
other_entries = []
|
||||
for entry in channels:
|
||||
if entry[0] == "web":
|
||||
web_entry = entry
|
||||
else:
|
||||
other_entries.append(entry)
|
||||
|
||||
ordered = ([web_entry] if web_entry else []) + other_entries
|
||||
for i, (name, ch) in enumerate(ordered):
|
||||
if i > 0 and name != "web":
|
||||
time.sleep(0.1)
|
||||
t = threading.Thread(target=self._run_channel, args=(name, ch), daemon=True)
|
||||
self._threads[name] = t
|
||||
t.start()
|
||||
logger.debug(f"[ChannelManager] Channel '{name}' started in sub-thread")
|
||||
|
||||
def _run_channel(self, name: str, channel):
|
||||
try:
|
||||
channel.startup()
|
||||
except Exception as e:
|
||||
logger.error(f"[ChannelManager] Channel '{name}' startup error: {e}")
|
||||
logger.exception(e)
|
||||
|
||||
def stop(self, channel_name: str = None):
|
||||
"""
|
||||
Stop channel(s). If channel_name is given, stop only that channel;
|
||||
otherwise stop all channels.
|
||||
"""
|
||||
# Pop under lock, then stop outside lock to avoid deadlock
|
||||
with self._lock:
|
||||
names = [channel_name] if channel_name else list(self._channels.keys())
|
||||
to_stop = []
|
||||
for name in names:
|
||||
ch = self._channels.pop(name, None)
|
||||
th = self._threads.pop(name, None)
|
||||
to_stop.append((name, ch, th))
|
||||
if channel_name and self._primary_channel is self._channels.get(channel_name):
|
||||
self._primary_channel = None
|
||||
|
||||
for name, ch, th in to_stop:
|
||||
if ch is None:
|
||||
logger.warning(f"[ChannelManager] Channel '{name}' not found in managed channels")
|
||||
if th and th.is_alive():
|
||||
self._interrupt_thread(th, name)
|
||||
continue
|
||||
logger.info(f"[ChannelManager] Stopping channel '{name}'...")
|
||||
try:
|
||||
if hasattr(ch, 'stop'):
|
||||
ch.stop()
|
||||
except Exception as e:
|
||||
logger.warning(f"[ChannelManager] Error during channel '{name}' stop: {e}")
|
||||
if th and th.is_alive():
|
||||
self._interrupt_thread(th, name)
|
||||
|
||||
@staticmethod
|
||||
def _interrupt_thread(th: threading.Thread, name: str):
|
||||
"""Raise SystemExit in target thread to break blocking loops like start_forever."""
|
||||
import ctypes
|
||||
try:
|
||||
tid = th.ident
|
||||
if tid is None:
|
||||
return
|
||||
res = ctypes.pythonapi.PyThreadState_SetAsyncExc(
|
||||
ctypes.c_ulong(tid), ctypes.py_object(SystemExit)
|
||||
)
|
||||
if res == 1:
|
||||
logger.info(f"[ChannelManager] Interrupted thread for channel '{name}'")
|
||||
elif res > 1:
|
||||
ctypes.pythonapi.PyThreadState_SetAsyncExc(ctypes.c_ulong(tid), None)
|
||||
logger.warning(f"[ChannelManager] Failed to interrupt thread for channel '{name}'")
|
||||
except Exception as e:
|
||||
logger.warning(f"[ChannelManager] Thread interrupt error for '{name}': {e}")
|
||||
|
||||
def restart(self, new_channel_name: str):
|
||||
"""
|
||||
Restart a single channel with a new channel type.
|
||||
Can be called from any thread (e.g. linkai config callback).
|
||||
"""
|
||||
logger.info(f"[ChannelManager] Restarting channel to '{new_channel_name}'...")
|
||||
self.stop(new_channel_name)
|
||||
_clear_singleton_cache(new_channel_name)
|
||||
time.sleep(1)
|
||||
self.start([new_channel_name], first_start=False)
|
||||
logger.info(f"[ChannelManager] Channel restarted to '{new_channel_name}' successfully")
|
||||
|
||||
|
||||
def _clear_singleton_cache(channel_name: str):
|
||||
"""
|
||||
Clear the singleton cache for the channel class so that
|
||||
a new instance can be created with updated config.
|
||||
"""
|
||||
cls_map = {
|
||||
"wx": "channel.wechat.wechat_channel.WechatChannel",
|
||||
"wxy": "channel.wechat.wechaty_channel.WechatyChannel",
|
||||
"wcf": "channel.wechat.wcf_channel.WechatfChannel",
|
||||
"web": "channel.web.web_channel.WebChannel",
|
||||
"wechatmp": "channel.wechatmp.wechatmp_channel.WechatMPChannel",
|
||||
"wechatmp_service": "channel.wechatmp.wechatmp_channel.WechatMPChannel",
|
||||
"wechatcom_app": "channel.wechatcom.wechatcomapp_channel.WechatComAppChannel",
|
||||
"wework": "channel.wework.wework_channel.WeworkChannel",
|
||||
const.FEISHU: "channel.feishu.feishu_channel.FeiShuChanel",
|
||||
const.DINGTALK: "channel.dingtalk.dingtalk_channel.DingTalkChanel",
|
||||
}
|
||||
module_path = cls_map.get(channel_name)
|
||||
if not module_path:
|
||||
return
|
||||
try:
|
||||
parts = module_path.rsplit(".", 1)
|
||||
module_name, class_name = parts[0], parts[1]
|
||||
import importlib
|
||||
module = importlib.import_module(module_name)
|
||||
wrapper = getattr(module, class_name, None)
|
||||
if wrapper and hasattr(wrapper, '__closure__') and wrapper.__closure__:
|
||||
for cell in wrapper.__closure__:
|
||||
try:
|
||||
cell_contents = cell.cell_contents
|
||||
if isinstance(cell_contents, dict):
|
||||
cell_contents.clear()
|
||||
logger.debug(f"[ChannelManager] Cleared singleton cache for {class_name}")
|
||||
break
|
||||
except ValueError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.warning(f"[ChannelManager] Failed to clear singleton cache: {e}")
|
||||
|
||||
|
||||
def sigterm_handler_wrap(_signo):
|
||||
old_handler = signal.getsignal(_signo)
|
||||
|
||||
def func(_signo, _stack_frame):
|
||||
logger.info("signal {} received, exiting...".format(_signo))
|
||||
conf().save_user_datas()
|
||||
return old_handler(_signo, _stack_frame)
|
||||
if callable(old_handler): # check old_handler
|
||||
return old_handler(_signo, _stack_frame)
|
||||
sys.exit(0)
|
||||
|
||||
signal.signal(_signo, func)
|
||||
|
||||
|
||||
def run():
|
||||
global _channel_mgr
|
||||
try:
|
||||
# load config
|
||||
load_config()
|
||||
@@ -25,25 +239,35 @@ def run():
|
||||
# kill signal
|
||||
sigterm_handler_wrap(signal.SIGTERM)
|
||||
|
||||
# create channel
|
||||
channel_name=conf().get('channel_type', 'wx')
|
||||
# Parse channel_type into a list
|
||||
raw_channel = conf().get("channel_type", "web")
|
||||
|
||||
if "--cmd" in sys.argv:
|
||||
channel_name = 'terminal'
|
||||
channel_names = ["terminal"]
|
||||
else:
|
||||
channel_names = _parse_channel_type(raw_channel)
|
||||
if not channel_names:
|
||||
channel_names = ["web"]
|
||||
|
||||
if channel_name == 'wxy':
|
||||
os.environ['WECHATY_LOG']="warn"
|
||||
# os.environ['WECHATY_PUPPET_SERVICE_ENDPOINT'] = '127.0.0.1:9001'
|
||||
if "wxy" in channel_names:
|
||||
os.environ["WECHATY_LOG"] = "warn"
|
||||
|
||||
channel = channel_factory.create_channel(channel_name)
|
||||
if channel_name in ['wx','wxy','terminal','wechatmp','wechatmp_service']:
|
||||
PluginManager().load_plugins()
|
||||
# Auto-start web console unless explicitly disabled
|
||||
web_console_enabled = conf().get("web_console", True)
|
||||
if web_console_enabled and "web" not in channel_names:
|
||||
channel_names.append("web")
|
||||
|
||||
# startup channel
|
||||
channel.startup()
|
||||
logger.info(f"[App] Starting channels: {channel_names}")
|
||||
|
||||
_channel_mgr = ChannelManager()
|
||||
_channel_mgr.start(channel_names, first_start=True)
|
||||
|
||||
while True:
|
||||
time.sleep(1)
|
||||
except Exception as e:
|
||||
logger.error("App startup failed!")
|
||||
logger.exception(e)
|
||||
|
||||
if __name__ == '__main__':
|
||||
run()
|
||||
|
||||
if __name__ == "__main__":
|
||||
run()
|
||||
|
||||
@@ -1,28 +0,0 @@
|
||||
# encoding:utf-8
|
||||
|
||||
import requests
|
||||
from bot.bot import Bot
|
||||
from bridge.reply import Reply, ReplyType
|
||||
|
||||
|
||||
# Baidu Unit对话接口 (可用, 但能力较弱)
|
||||
class BaiduUnitBot(Bot):
|
||||
def reply(self, query, context=None):
|
||||
token = self.get_token()
|
||||
url = 'https://aip.baidubce.com/rpc/2.0/unit/service/v3/chat?access_token=' + token
|
||||
post_data = "{\"version\":\"3.0\",\"service_id\":\"S73177\",\"session_id\":\"\",\"log_id\":\"7758521\",\"skill_ids\":[\"1221886\"],\"request\":{\"terminal_id\":\"88888\",\"query\":\"" + query + "\", \"hyper_params\": {\"chat_custom_bot_profile\": 1}}}"
|
||||
print(post_data)
|
||||
headers = {'content-type': 'application/x-www-form-urlencoded'}
|
||||
response = requests.post(url, data=post_data.encode(), headers=headers)
|
||||
if response:
|
||||
reply = Reply(ReplyType.TEXT, response.json()['result']['context']['SYS_PRESUMED_HIST'][1])
|
||||
return reply
|
||||
|
||||
def get_token(self):
|
||||
access_key = 'YOUR_ACCESS_KEY'
|
||||
secret_key = 'YOUR_SECRET_KEY'
|
||||
host = 'https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id=' + access_key + '&client_secret=' + secret_key
|
||||
response = requests.get(host)
|
||||
if response:
|
||||
print(response.json())
|
||||
return response.json()['access_token']
|
||||
@@ -1,32 +0,0 @@
|
||||
"""
|
||||
channel factory
|
||||
"""
|
||||
from common import const
|
||||
|
||||
|
||||
def create_bot(bot_type):
|
||||
"""
|
||||
create a bot_type instance
|
||||
:param bot_type: bot type code
|
||||
:return: bot instance
|
||||
"""
|
||||
if bot_type == const.BAIDU:
|
||||
# Baidu Unit对话接口
|
||||
from bot.baidu.baidu_unit_bot import BaiduUnitBot
|
||||
return BaiduUnitBot()
|
||||
|
||||
elif bot_type == const.CHATGPT:
|
||||
# ChatGPT 网页端web接口
|
||||
from bot.chatgpt.chat_gpt_bot import ChatGPTBot
|
||||
return ChatGPTBot()
|
||||
|
||||
elif bot_type == const.OPEN_AI:
|
||||
# OpenAI 官方对话模型API
|
||||
from bot.openai.open_ai_bot import OpenAIBot
|
||||
return OpenAIBot()
|
||||
|
||||
elif bot_type == const.CHATGPTONAZURE:
|
||||
# Azure chatgpt service https://azure.microsoft.com/en-in/products/cognitive-services/openai-service/
|
||||
from bot.chatgpt.chat_gpt_bot import AzureChatGPTBot
|
||||
return AzureChatGPTBot()
|
||||
raise RuntimeError
|
||||
@@ -1,156 +0,0 @@
|
||||
# encoding:utf-8
|
||||
|
||||
from bot.bot import Bot
|
||||
from bot.chatgpt.chat_gpt_session import ChatGPTSession
|
||||
from bot.openai.open_ai_image import OpenAIImage
|
||||
from bot.session_manager import Session, SessionManager
|
||||
from bridge.context import ContextType
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from config import conf, load_config
|
||||
from common.log import logger
|
||||
from common.token_bucket import TokenBucket
|
||||
from common.expired_dict import ExpiredDict
|
||||
import openai
|
||||
import openai.error
|
||||
import time
|
||||
|
||||
# OpenAI对话模型API (可用)
|
||||
class ChatGPTBot(Bot,OpenAIImage):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
# set the default api_key
|
||||
openai.api_key = conf().get('open_ai_api_key')
|
||||
if conf().get('open_ai_api_base'):
|
||||
openai.api_base = conf().get('open_ai_api_base')
|
||||
proxy = conf().get('proxy')
|
||||
if proxy:
|
||||
openai.proxy = proxy
|
||||
if conf().get('rate_limit_chatgpt'):
|
||||
self.tb4chatgpt = TokenBucket(conf().get('rate_limit_chatgpt', 20))
|
||||
|
||||
self.sessions = SessionManager(ChatGPTSession, model= conf().get("model") or "gpt-3.5-turbo")
|
||||
|
||||
def reply(self, query, context=None):
|
||||
# acquire reply content
|
||||
if context.type == ContextType.TEXT:
|
||||
logger.info("[CHATGPT] query={}".format(query))
|
||||
|
||||
|
||||
session_id = context['session_id']
|
||||
reply = None
|
||||
clear_memory_commands = conf().get('clear_memory_commands', ['#清除记忆'])
|
||||
if query in clear_memory_commands:
|
||||
self.sessions.clear_session(session_id)
|
||||
reply = Reply(ReplyType.INFO, '记忆已清除')
|
||||
elif query == '#清除所有':
|
||||
self.sessions.clear_all_session()
|
||||
reply = Reply(ReplyType.INFO, '所有人记忆已清除')
|
||||
elif query == '#更新配置':
|
||||
load_config()
|
||||
reply = Reply(ReplyType.INFO, '配置已更新')
|
||||
if reply:
|
||||
return reply
|
||||
session = self.sessions.session_query(query, session_id)
|
||||
logger.debug("[CHATGPT] session query={}".format(session.messages))
|
||||
|
||||
api_key = context.get('openai_api_key')
|
||||
|
||||
# if context.get('stream'):
|
||||
# # reply in stream
|
||||
# return self.reply_text_stream(query, new_query, session_id)
|
||||
|
||||
reply_content = self.reply_text(session, session_id, api_key, 0)
|
||||
logger.debug("[CHATGPT] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(session.messages, session_id, reply_content["content"], reply_content["completion_tokens"]))
|
||||
if reply_content['completion_tokens'] == 0 and len(reply_content['content']) > 0:
|
||||
reply = Reply(ReplyType.ERROR, reply_content['content'])
|
||||
elif reply_content["completion_tokens"] > 0:
|
||||
self.sessions.session_reply(reply_content["content"], session_id, reply_content["total_tokens"])
|
||||
reply = Reply(ReplyType.TEXT, reply_content["content"])
|
||||
else:
|
||||
reply = Reply(ReplyType.ERROR, reply_content['content'])
|
||||
logger.debug("[CHATGPT] reply {} used 0 tokens.".format(reply_content))
|
||||
return reply
|
||||
|
||||
elif context.type == ContextType.IMAGE_CREATE:
|
||||
ok, retstring = self.create_img(query, 0)
|
||||
reply = None
|
||||
if ok:
|
||||
reply = Reply(ReplyType.IMAGE_URL, retstring)
|
||||
else:
|
||||
reply = Reply(ReplyType.ERROR, retstring)
|
||||
return reply
|
||||
else:
|
||||
reply = Reply(ReplyType.ERROR, 'Bot不支持处理{}类型的消息'.format(context.type))
|
||||
return reply
|
||||
|
||||
def compose_args(self):
|
||||
return {
|
||||
"model": conf().get("model") or "gpt-3.5-turbo", # 对话模型的名称
|
||||
"temperature":conf().get('temperature', 0.9), # 值在[0,1]之间,越大表示回复越具有不确定性
|
||||
# "max_tokens":4096, # 回复最大的字符数
|
||||
"top_p":1,
|
||||
"frequency_penalty":conf().get('frequency_penalty', 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
|
||||
"presence_penalty":conf().get('presence_penalty', 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
|
||||
"request_timeout": conf().get('request_timeout', 60), # 请求超时时间,openai接口默认设置为600,对于难问题一般需要较长时间
|
||||
"timeout": conf().get('request_timeout', 120), #重试超时时间,在这个时间内,将会自动重试
|
||||
}
|
||||
|
||||
def reply_text(self, session:ChatGPTSession, session_id, api_key, retry_count=0) -> dict:
|
||||
'''
|
||||
call openai's ChatCompletion to get the answer
|
||||
:param session: a conversation session
|
||||
:param session_id: session id
|
||||
:param retry_count: retry count
|
||||
:return: {}
|
||||
'''
|
||||
try:
|
||||
if conf().get('rate_limit_chatgpt') and not self.tb4chatgpt.get_token():
|
||||
raise openai.error.RateLimitError("RateLimitError: rate limit exceeded")
|
||||
# if api_key == None, the default openai.api_key will be used
|
||||
response = openai.ChatCompletion.create(
|
||||
api_key=api_key, messages=session.messages, **self.compose_args()
|
||||
)
|
||||
# logger.info("[ChatGPT] reply={}, total_tokens={}".format(response.choices[0]['message']['content'], response["usage"]["total_tokens"]))
|
||||
return {"total_tokens": response["usage"]["total_tokens"],
|
||||
"completion_tokens": response["usage"]["completion_tokens"],
|
||||
"content": response.choices[0]['message']['content']}
|
||||
except Exception as e:
|
||||
need_retry = retry_count < 2
|
||||
result = {"completion_tokens": 0, "content": "我现在有点累了,等会再来吧"}
|
||||
if isinstance(e, openai.error.RateLimitError):
|
||||
logger.warn("[CHATGPT] RateLimitError: {}".format(e))
|
||||
result['content'] = "提问太快啦,请休息一下再问我吧"
|
||||
if need_retry:
|
||||
time.sleep(5)
|
||||
elif isinstance(e, openai.error.Timeout):
|
||||
logger.warn("[CHATGPT] Timeout: {}".format(e))
|
||||
result['content'] = "我没有收到你的消息"
|
||||
if need_retry:
|
||||
time.sleep(5)
|
||||
elif isinstance(e, openai.error.APIConnectionError):
|
||||
logger.warn("[CHATGPT] APIConnectionError: {}".format(e))
|
||||
need_retry = False
|
||||
result['content'] = "我连接不到你的网络"
|
||||
else:
|
||||
logger.warn("[CHATGPT] Exception: {}".format(e))
|
||||
need_retry = False
|
||||
self.sessions.clear_session(session_id)
|
||||
|
||||
if need_retry:
|
||||
logger.warn("[CHATGPT] 第{}次重试".format(retry_count+1))
|
||||
return self.reply_text(session, session_id, api_key, retry_count+1)
|
||||
else:
|
||||
return result
|
||||
|
||||
|
||||
class AzureChatGPTBot(ChatGPTBot):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
openai.api_type = "azure"
|
||||
openai.api_version = "2023-03-15-preview"
|
||||
|
||||
def compose_args(self):
|
||||
args = super().compose_args()
|
||||
args["engine"] = args["model"]
|
||||
del(args["model"])
|
||||
return args
|
||||
@@ -1,109 +0,0 @@
|
||||
# encoding:utf-8
|
||||
|
||||
from bot.bot import Bot
|
||||
from bot.openai.open_ai_image import OpenAIImage
|
||||
from bot.openai.open_ai_session import OpenAISession
|
||||
from bot.session_manager import SessionManager
|
||||
from bridge.context import ContextType
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from config import conf
|
||||
from common.log import logger
|
||||
import openai
|
||||
import openai.error
|
||||
import time
|
||||
|
||||
user_session = dict()
|
||||
|
||||
# OpenAI对话模型API (可用)
|
||||
class OpenAIBot(Bot, OpenAIImage):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
openai.api_key = conf().get('open_ai_api_key')
|
||||
if conf().get('open_ai_api_base'):
|
||||
openai.api_base = conf().get('open_ai_api_base')
|
||||
proxy = conf().get('proxy')
|
||||
if proxy:
|
||||
openai.proxy = proxy
|
||||
|
||||
self.sessions = SessionManager(OpenAISession, model= conf().get("model") or "text-davinci-003")
|
||||
|
||||
def reply(self, query, context=None):
|
||||
# acquire reply content
|
||||
if context and context.type:
|
||||
if context.type == ContextType.TEXT:
|
||||
logger.info("[OPEN_AI] query={}".format(query))
|
||||
session_id = context['session_id']
|
||||
reply = None
|
||||
if query == '#清除记忆':
|
||||
self.sessions.clear_session(session_id)
|
||||
reply = Reply(ReplyType.INFO, '记忆已清除')
|
||||
elif query == '#清除所有':
|
||||
self.sessions.clear_all_session()
|
||||
reply = Reply(ReplyType.INFO, '所有人记忆已清除')
|
||||
else:
|
||||
session = self.sessions.session_query(query, session_id)
|
||||
new_query = str(session)
|
||||
logger.debug("[OPEN_AI] session query={}".format(new_query))
|
||||
|
||||
total_tokens, completion_tokens, reply_content = self.reply_text(new_query, session_id, 0)
|
||||
logger.debug("[OPEN_AI] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(new_query, session_id, reply_content, completion_tokens))
|
||||
|
||||
if total_tokens == 0 :
|
||||
reply = Reply(ReplyType.ERROR, reply_content)
|
||||
else:
|
||||
self.sessions.session_reply(reply_content, session_id, total_tokens)
|
||||
reply = Reply(ReplyType.TEXT, reply_content)
|
||||
return reply
|
||||
elif context.type == ContextType.IMAGE_CREATE:
|
||||
ok, retstring = self.create_img(query, 0)
|
||||
reply = None
|
||||
if ok:
|
||||
reply = Reply(ReplyType.IMAGE_URL, retstring)
|
||||
else:
|
||||
reply = Reply(ReplyType.ERROR, retstring)
|
||||
return reply
|
||||
|
||||
def reply_text(self, query, session_id, retry_count=0):
|
||||
try:
|
||||
response = openai.Completion.create(
|
||||
model= conf().get("model") or "text-davinci-003", # 对话模型的名称
|
||||
prompt=query,
|
||||
temperature=0.9, # 值在[0,1]之间,越大表示回复越具有不确定性
|
||||
max_tokens=1200, # 回复最大的字符数
|
||||
top_p=1,
|
||||
frequency_penalty=0.0, # [-2,2]之间,该值越大则更倾向于产生不同的内容
|
||||
presence_penalty=0.0, # [-2,2]之间,该值越大则更倾向于产生不同的内容
|
||||
stop=["\n\n\n"]
|
||||
)
|
||||
res_content = response.choices[0]['text'].strip().replace('<|endoftext|>', '')
|
||||
total_tokens = response["usage"]["total_tokens"]
|
||||
completion_tokens = response["usage"]["completion_tokens"]
|
||||
logger.info("[OPEN_AI] reply={}".format(res_content))
|
||||
return total_tokens, completion_tokens, res_content
|
||||
except Exception as e:
|
||||
need_retry = retry_count < 2
|
||||
result = [0,0,"我现在有点累了,等会再来吧"]
|
||||
if isinstance(e, openai.error.RateLimitError):
|
||||
logger.warn("[OPEN_AI] RateLimitError: {}".format(e))
|
||||
result[2] = "提问太快啦,请休息一下再问我吧"
|
||||
if need_retry:
|
||||
time.sleep(5)
|
||||
elif isinstance(e, openai.error.Timeout):
|
||||
logger.warn("[OPEN_AI] Timeout: {}".format(e))
|
||||
result[2] = "我没有收到你的消息"
|
||||
if need_retry:
|
||||
time.sleep(5)
|
||||
elif isinstance(e, openai.error.APIConnectionError):
|
||||
logger.warn("[OPEN_AI] APIConnectionError: {}".format(e))
|
||||
need_retry = False
|
||||
result[2] = "我连接不到你的网络"
|
||||
else:
|
||||
logger.warn("[OPEN_AI] Exception: {}".format(e))
|
||||
need_retry = False
|
||||
self.sessions.clear_session(session_id)
|
||||
|
||||
if need_retry:
|
||||
logger.warn("[OPEN_AI] 第{}次重试".format(retry_count+1))
|
||||
return self.reply_text(query, session_id, retry_count+1)
|
||||
else:
|
||||
return result
|
||||
@@ -1,38 +0,0 @@
|
||||
import time
|
||||
import openai
|
||||
import openai.error
|
||||
from common.token_bucket import TokenBucket
|
||||
from common.log import logger
|
||||
from config import conf
|
||||
|
||||
# OPENAI提供的画图接口
|
||||
class OpenAIImage(object):
|
||||
def __init__(self):
|
||||
openai.api_key = conf().get('open_ai_api_key')
|
||||
if conf().get('rate_limit_dalle'):
|
||||
self.tb4dalle = TokenBucket(conf().get('rate_limit_dalle', 50))
|
||||
|
||||
def create_img(self, query, retry_count=0):
|
||||
try:
|
||||
if conf().get('rate_limit_dalle') and not self.tb4dalle.get_token():
|
||||
return False, "请求太快了,请休息一下再问我吧"
|
||||
logger.info("[OPEN_AI] image_query={}".format(query))
|
||||
response = openai.Image.create(
|
||||
prompt=query, #图片描述
|
||||
n=1, #每次生成图片的数量
|
||||
size="256x256" #图片大小,可选有 256x256, 512x512, 1024x1024
|
||||
)
|
||||
image_url = response['data'][0]['url']
|
||||
logger.info("[OPEN_AI] image_url={}".format(image_url))
|
||||
return True, image_url
|
||||
except openai.error.RateLimitError as e:
|
||||
logger.warn(e)
|
||||
if retry_count < 1:
|
||||
time.sleep(5)
|
||||
logger.warn("[OPEN_AI] ImgCreate RateLimit exceed, 第{}次重试".format(retry_count+1))
|
||||
return self.create_img(query, retry_count+1)
|
||||
else:
|
||||
return False, "提问太快啦,请休息一下再问我吧"
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
return False, str(e)
|
||||
646
bridge/agent_bridge.py
Normal file
646
bridge/agent_bridge.py
Normal file
@@ -0,0 +1,646 @@
|
||||
"""
|
||||
Agent Bridge - Integrates Agent system with existing COW bridge
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Optional, List
|
||||
|
||||
from agent.protocol import Agent, LLMModel, LLMRequest
|
||||
from bridge.agent_event_handler import AgentEventHandler
|
||||
from bridge.agent_initializer import AgentInitializer
|
||||
from bridge.bridge import Bridge
|
||||
from bridge.context import Context
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from common import const
|
||||
from common.log import logger
|
||||
from common.utils import expand_path
|
||||
from models.openai_compatible_bot import OpenAICompatibleBot
|
||||
|
||||
|
||||
def add_openai_compatible_support(bot_instance):
|
||||
"""
|
||||
Dynamically add OpenAI-compatible tool calling support to a bot instance.
|
||||
|
||||
This allows any bot to gain tool calling capability without modifying its code,
|
||||
as long as it uses OpenAI-compatible API format.
|
||||
|
||||
Note: Some bots like ZHIPUAIBot have native tool calling support and don't need enhancement.
|
||||
"""
|
||||
if hasattr(bot_instance, 'call_with_tools'):
|
||||
# Bot already has tool calling support (e.g., ZHIPUAIBot)
|
||||
logger.debug(f"[AgentBridge] {type(bot_instance).__name__} already has native tool calling support")
|
||||
return bot_instance
|
||||
|
||||
# Create a temporary mixin class that combines the bot with OpenAI compatibility
|
||||
class EnhancedBot(bot_instance.__class__, OpenAICompatibleBot):
|
||||
"""Dynamically enhanced bot with OpenAI-compatible tool calling"""
|
||||
|
||||
def get_api_config(self):
|
||||
"""
|
||||
Infer API config from common configuration patterns.
|
||||
Most OpenAI-compatible bots use similar configuration.
|
||||
"""
|
||||
from config import conf
|
||||
|
||||
return {
|
||||
'api_key': conf().get("open_ai_api_key"),
|
||||
'api_base': conf().get("open_ai_api_base"),
|
||||
'model': conf().get("model", "gpt-3.5-turbo"),
|
||||
'default_temperature': conf().get("temperature", 0.9),
|
||||
'default_top_p': conf().get("top_p", 1.0),
|
||||
'default_frequency_penalty': conf().get("frequency_penalty", 0.0),
|
||||
'default_presence_penalty': conf().get("presence_penalty", 0.0),
|
||||
}
|
||||
|
||||
# Change the bot's class to the enhanced version
|
||||
bot_instance.__class__ = EnhancedBot
|
||||
logger.info(
|
||||
f"[AgentBridge] Enhanced {bot_instance.__class__.__bases__[0].__name__} with OpenAI-compatible tool calling")
|
||||
|
||||
return bot_instance
|
||||
|
||||
|
||||
class AgentLLMModel(LLMModel):
|
||||
"""
|
||||
LLM Model adapter that uses COW's existing bot infrastructure
|
||||
"""
|
||||
|
||||
_MODEL_BOT_TYPE_MAP = {
|
||||
"wenxin": const.BAIDU, "wenxin-4": const.BAIDU,
|
||||
"xunfei": const.XUNFEI, const.QWEN: const.QWEN,
|
||||
const.MODELSCOPE: const.MODELSCOPE,
|
||||
}
|
||||
_MODEL_PREFIX_MAP = [
|
||||
("qwen", const.QWEN_DASHSCOPE), ("qwq", const.QWEN_DASHSCOPE), ("qvq", const.QWEN_DASHSCOPE),
|
||||
("gemini", const.GEMINI), ("glm", const.ZHIPU_AI), ("claude", const.CLAUDEAPI),
|
||||
("moonshot", const.MOONSHOT), ("kimi", const.MOONSHOT),
|
||||
("doubao", const.DOUBAO),
|
||||
]
|
||||
|
||||
def __init__(self, bridge: Bridge, bot_type: str = "chat"):
|
||||
from config import conf
|
||||
super().__init__(model=conf().get("model", const.GPT_41))
|
||||
self.bridge = bridge
|
||||
self.bot_type = bot_type
|
||||
self._bot = None
|
||||
self._bot_model = None
|
||||
|
||||
@property
|
||||
def model(self):
|
||||
from config import conf
|
||||
return conf().get("model", const.GPT_41)
|
||||
|
||||
@model.setter
|
||||
def model(self, value):
|
||||
pass
|
||||
|
||||
def _resolve_bot_type(self, model_name: str) -> str:
|
||||
"""Resolve bot type from model name, matching Bridge.__init__ logic."""
|
||||
from config import conf
|
||||
if conf().get("use_linkai", False) and conf().get("linkai_api_key"):
|
||||
return const.LINKAI
|
||||
if not model_name or not isinstance(model_name, str):
|
||||
return const.CHATGPT
|
||||
if model_name in self._MODEL_BOT_TYPE_MAP:
|
||||
return self._MODEL_BOT_TYPE_MAP[model_name]
|
||||
if model_name.lower().startswith("minimax") or model_name in ["abab6.5-chat"]:
|
||||
return const.MiniMax
|
||||
if model_name in [const.QWEN_TURBO, const.QWEN_PLUS, const.QWEN_MAX]:
|
||||
return const.QWEN_DASHSCOPE
|
||||
if model_name in [const.MOONSHOT, "moonshot-v1-8k", "moonshot-v1-32k", "moonshot-v1-128k"]:
|
||||
return const.MOONSHOT
|
||||
if model_name in [const.DEEPSEEK_CHAT, const.DEEPSEEK_REASONER]:
|
||||
return const.CHATGPT
|
||||
for prefix, btype in self._MODEL_PREFIX_MAP:
|
||||
if model_name.startswith(prefix):
|
||||
return btype
|
||||
return const.CHATGPT
|
||||
|
||||
@property
|
||||
def bot(self):
|
||||
"""Lazy load the bot, re-create when model changes"""
|
||||
from models.bot_factory import create_bot
|
||||
cur_model = self.model
|
||||
if self._bot is None or self._bot_model != cur_model:
|
||||
bot_type = self._resolve_bot_type(cur_model)
|
||||
self._bot = create_bot(bot_type)
|
||||
self._bot = add_openai_compatible_support(self._bot)
|
||||
self._bot_model = cur_model
|
||||
return self._bot
|
||||
|
||||
def call(self, request: LLMRequest):
|
||||
"""
|
||||
Call the model using COW's bot infrastructure
|
||||
"""
|
||||
try:
|
||||
# For non-streaming calls, we'll use the existing reply method
|
||||
# This is a simplified implementation
|
||||
if hasattr(self.bot, 'call_with_tools'):
|
||||
# Use tool-enabled call if available
|
||||
kwargs = {
|
||||
'messages': request.messages,
|
||||
'tools': getattr(request, 'tools', None),
|
||||
'stream': False,
|
||||
'model': self.model # Pass model parameter
|
||||
}
|
||||
# Only pass max_tokens if it's explicitly set
|
||||
if request.max_tokens is not None:
|
||||
kwargs['max_tokens'] = request.max_tokens
|
||||
|
||||
# Extract system prompt if present
|
||||
system_prompt = getattr(request, 'system', None)
|
||||
if system_prompt:
|
||||
kwargs['system'] = system_prompt
|
||||
|
||||
response = self.bot.call_with_tools(**kwargs)
|
||||
return self._format_response(response)
|
||||
else:
|
||||
# Fallback to regular call
|
||||
# This would need to be implemented based on your specific needs
|
||||
raise NotImplementedError("Regular call not implemented yet")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"AgentLLMModel call error: {e}")
|
||||
raise
|
||||
|
||||
def call_stream(self, request: LLMRequest):
|
||||
"""
|
||||
Call the model with streaming using COW's bot infrastructure
|
||||
"""
|
||||
try:
|
||||
if hasattr(self.bot, 'call_with_tools'):
|
||||
# Use tool-enabled streaming call if available
|
||||
# Extract system prompt if present
|
||||
system_prompt = getattr(request, 'system', None)
|
||||
|
||||
# Build kwargs for call_with_tools
|
||||
kwargs = {
|
||||
'messages': request.messages,
|
||||
'tools': getattr(request, 'tools', None),
|
||||
'stream': True,
|
||||
'model': self.model # Pass model parameter
|
||||
}
|
||||
|
||||
# Only pass max_tokens if explicitly set, let the bot use its default
|
||||
if request.max_tokens is not None:
|
||||
kwargs['max_tokens'] = request.max_tokens
|
||||
|
||||
# Add system prompt if present
|
||||
if system_prompt:
|
||||
kwargs['system'] = system_prompt
|
||||
|
||||
# Pass channel_type for linkai tracking
|
||||
channel_type = getattr(self, 'channel_type', None)
|
||||
if channel_type:
|
||||
kwargs['channel_type'] = channel_type
|
||||
|
||||
stream = self.bot.call_with_tools(**kwargs)
|
||||
|
||||
# Convert stream format to our expected format
|
||||
for chunk in stream:
|
||||
yield self._format_stream_chunk(chunk)
|
||||
else:
|
||||
bot_type = type(self.bot).__name__
|
||||
raise NotImplementedError(f"Bot {bot_type} does not support call_with_tools. Please add the method.")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"AgentLLMModel call_stream error: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
def _format_response(self, response):
|
||||
"""Format Claude response to our expected format"""
|
||||
# This would need to be implemented based on Claude's response format
|
||||
return response
|
||||
|
||||
def _format_stream_chunk(self, chunk):
|
||||
"""Format Claude stream chunk to our expected format"""
|
||||
# This would need to be implemented based on Claude's stream format
|
||||
return chunk
|
||||
|
||||
|
||||
class AgentBridge:
|
||||
"""
|
||||
Bridge class that integrates super Agent with COW
|
||||
Manages multiple agent instances per session for conversation isolation
|
||||
"""
|
||||
|
||||
def __init__(self, bridge: Bridge):
|
||||
self.bridge = bridge
|
||||
self.agents = {} # session_id -> Agent instance mapping
|
||||
self.default_agent = None # For backward compatibility (no session_id)
|
||||
self.agent: Optional[Agent] = None
|
||||
self.scheduler_initialized = False
|
||||
|
||||
# Create helper instances
|
||||
self.initializer = AgentInitializer(bridge, self)
|
||||
def create_agent(self, system_prompt: str, tools: List = None, **kwargs) -> Agent:
|
||||
"""
|
||||
Create the super agent with COW integration
|
||||
|
||||
Args:
|
||||
system_prompt: System prompt
|
||||
tools: List of tools (optional)
|
||||
**kwargs: Additional agent parameters
|
||||
|
||||
Returns:
|
||||
Agent instance
|
||||
"""
|
||||
# Create LLM model that uses COW's bot infrastructure
|
||||
model = AgentLLMModel(self.bridge)
|
||||
|
||||
# Default tools if none provided
|
||||
if tools is None:
|
||||
# Use ToolManager to load all available tools
|
||||
from agent.tools import ToolManager
|
||||
tool_manager = ToolManager()
|
||||
tool_manager.load_tools()
|
||||
|
||||
tools = []
|
||||
for tool_name in tool_manager.tool_classes.keys():
|
||||
try:
|
||||
tool = tool_manager.create_tool(tool_name)
|
||||
if tool:
|
||||
tools.append(tool)
|
||||
except Exception as e:
|
||||
logger.warning(f"[AgentBridge] Failed to load tool {tool_name}: {e}")
|
||||
|
||||
# Create agent instance
|
||||
agent = Agent(
|
||||
system_prompt=system_prompt,
|
||||
description=kwargs.get("description", "AI Super Agent"),
|
||||
model=model,
|
||||
tools=tools,
|
||||
max_steps=kwargs.get("max_steps", 15),
|
||||
output_mode=kwargs.get("output_mode", "logger"),
|
||||
workspace_dir=kwargs.get("workspace_dir"), # Pass workspace for skills loading
|
||||
enable_skills=kwargs.get("enable_skills", True), # Enable skills by default
|
||||
memory_manager=kwargs.get("memory_manager"), # Pass memory manager
|
||||
max_context_tokens=kwargs.get("max_context_tokens"),
|
||||
context_reserve_tokens=kwargs.get("context_reserve_tokens"),
|
||||
runtime_info=kwargs.get("runtime_info") # Pass runtime_info for dynamic time updates
|
||||
)
|
||||
|
||||
# Log skill loading details
|
||||
if agent.skill_manager:
|
||||
logger.debug(f"[AgentBridge] SkillManager initialized with {len(agent.skill_manager.skills)} skills")
|
||||
|
||||
return agent
|
||||
|
||||
def get_agent(self, session_id: str = None) -> Optional[Agent]:
|
||||
"""
|
||||
Get agent instance for the given session
|
||||
|
||||
Args:
|
||||
session_id: Session identifier (e.g., user_id). If None, returns default agent.
|
||||
|
||||
Returns:
|
||||
Agent instance for this session
|
||||
"""
|
||||
# If no session_id, use default agent (backward compatibility)
|
||||
if session_id is None:
|
||||
if self.default_agent is None:
|
||||
self._init_default_agent()
|
||||
return self.default_agent
|
||||
|
||||
# Check if agent exists for this session
|
||||
if session_id not in self.agents:
|
||||
self._init_agent_for_session(session_id)
|
||||
|
||||
return self.agents[session_id]
|
||||
|
||||
def _init_default_agent(self):
|
||||
"""Initialize default super agent"""
|
||||
agent = self.initializer.initialize_agent(session_id=None)
|
||||
self.default_agent = agent
|
||||
|
||||
def _init_agent_for_session(self, session_id: str):
|
||||
"""Initialize agent for a specific session"""
|
||||
agent = self.initializer.initialize_agent(session_id=session_id)
|
||||
self.agents[session_id] = agent
|
||||
|
||||
def agent_reply(self, query: str, context: Context = None,
|
||||
on_event=None, clear_history: bool = False) -> Reply:
|
||||
"""
|
||||
Use super agent to reply to a query
|
||||
|
||||
Args:
|
||||
query: User query
|
||||
context: COW context (optional, contains session_id for user isolation)
|
||||
on_event: Event callback (optional)
|
||||
clear_history: Whether to clear conversation history
|
||||
|
||||
Returns:
|
||||
Reply object
|
||||
"""
|
||||
try:
|
||||
# Extract session_id from context for user isolation
|
||||
session_id = None
|
||||
if context:
|
||||
session_id = context.kwargs.get("session_id") or context.get("session_id")
|
||||
|
||||
# Get agent for this session (will auto-initialize if needed)
|
||||
agent = self.get_agent(session_id=session_id)
|
||||
if not agent:
|
||||
return Reply(ReplyType.ERROR, "Failed to initialize super agent")
|
||||
|
||||
# Create event handler for logging and channel communication
|
||||
event_handler = AgentEventHandler(context=context, original_callback=on_event)
|
||||
|
||||
# Filter tools based on context
|
||||
original_tools = agent.tools
|
||||
filtered_tools = original_tools
|
||||
|
||||
# If this is a scheduled task execution, exclude scheduler tool to prevent recursion
|
||||
if context and context.get("is_scheduled_task"):
|
||||
filtered_tools = [tool for tool in agent.tools if tool.name != "scheduler"]
|
||||
agent.tools = filtered_tools
|
||||
logger.info(f"[AgentBridge] Scheduled task execution: excluded scheduler tool ({len(filtered_tools)}/{len(original_tools)} tools)")
|
||||
else:
|
||||
# Attach context to scheduler tool if present
|
||||
if context and agent.tools:
|
||||
for tool in agent.tools:
|
||||
if tool.name == "scheduler":
|
||||
try:
|
||||
from agent.tools.scheduler.integration import attach_scheduler_to_tool
|
||||
attach_scheduler_to_tool(tool, context)
|
||||
except Exception as e:
|
||||
logger.warning(f"[AgentBridge] Failed to attach context to scheduler: {e}")
|
||||
break
|
||||
|
||||
# Pass channel_type to model so linkai requests carry it
|
||||
if context and hasattr(agent, 'model'):
|
||||
agent.model.channel_type = context.get("channel_type", "")
|
||||
|
||||
# Record message count before execution so we can diff new messages
|
||||
with agent.messages_lock:
|
||||
pre_run_len = len(agent.messages)
|
||||
|
||||
try:
|
||||
# Use agent's run_stream method with event handler
|
||||
response = agent.run_stream(
|
||||
user_message=query,
|
||||
on_event=event_handler.handle_event,
|
||||
clear_history=clear_history
|
||||
)
|
||||
finally:
|
||||
# Restore original tools
|
||||
if context and context.get("is_scheduled_task"):
|
||||
agent.tools = original_tools
|
||||
|
||||
# Log execution summary
|
||||
event_handler.log_summary()
|
||||
|
||||
# Persist new messages generated during this run
|
||||
if session_id:
|
||||
channel_type = (context.get("channel_type") or "") if context else ""
|
||||
with agent.messages_lock:
|
||||
new_messages = agent.messages[pre_run_len:]
|
||||
self._persist_messages(session_id, list(new_messages), channel_type)
|
||||
|
||||
# Check if there are files to send (from read tool)
|
||||
if hasattr(agent, 'stream_executor') and hasattr(agent.stream_executor, 'files_to_send'):
|
||||
files_to_send = agent.stream_executor.files_to_send
|
||||
if files_to_send:
|
||||
# Send the first file (for now, handle one file at a time)
|
||||
file_info = files_to_send[0]
|
||||
logger.info(f"[AgentBridge] Sending file: {file_info.get('path')}")
|
||||
|
||||
# Clear files_to_send for next request
|
||||
agent.stream_executor.files_to_send = []
|
||||
|
||||
# Return file reply based on file type
|
||||
return self._create_file_reply(file_info, response, context)
|
||||
|
||||
return Reply(ReplyType.TEXT, response)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Agent reply error: {e}")
|
||||
return Reply(ReplyType.ERROR, f"Agent error: {str(e)}")
|
||||
|
||||
def _create_file_reply(self, file_info: dict, text_response: str, context: Context = None) -> Reply:
|
||||
"""
|
||||
Create a reply for sending files
|
||||
|
||||
Args:
|
||||
file_info: File metadata from read tool
|
||||
text_response: Text response from agent
|
||||
context: Context object
|
||||
|
||||
Returns:
|
||||
Reply object for file sending
|
||||
"""
|
||||
file_type = file_info.get("file_type", "file")
|
||||
file_path = file_info.get("path")
|
||||
|
||||
# For images, use IMAGE_URL type (channel will handle upload)
|
||||
if file_type == "image":
|
||||
# Convert local path to file:// URL for channel processing
|
||||
file_url = f"file://{file_path}"
|
||||
logger.info(f"[AgentBridge] Sending image: {file_url}")
|
||||
reply = Reply(ReplyType.IMAGE_URL, file_url)
|
||||
# Attach text message if present (for channels that support text+image)
|
||||
if text_response:
|
||||
reply.text_content = text_response # Store accompanying text
|
||||
return reply
|
||||
|
||||
# For all file types (document, video, audio), use FILE type
|
||||
if file_type in ["document", "video", "audio"]:
|
||||
file_url = f"file://{file_path}"
|
||||
logger.info(f"[AgentBridge] Sending {file_type}: {file_url}")
|
||||
reply = Reply(ReplyType.FILE, file_url)
|
||||
reply.file_name = file_info.get("file_name", os.path.basename(file_path))
|
||||
# Attach text message if present
|
||||
if text_response:
|
||||
reply.text_content = text_response
|
||||
return reply
|
||||
|
||||
# For other unknown file types, return text with file info
|
||||
message = text_response or file_info.get("message", "文件已准备")
|
||||
message += f"\n\n[文件: {file_info.get('file_name', file_path)}]"
|
||||
return Reply(ReplyType.TEXT, message)
|
||||
|
||||
def _migrate_config_to_env(self, workspace_root: str):
|
||||
"""
|
||||
Migrate API keys from config.json to .env file if not already set
|
||||
|
||||
Args:
|
||||
workspace_root: Workspace directory path (not used, kept for compatibility)
|
||||
"""
|
||||
from config import conf
|
||||
import os
|
||||
|
||||
# Mapping from config.json keys to environment variable names
|
||||
key_mapping = {
|
||||
"open_ai_api_key": "OPENAI_API_KEY",
|
||||
"open_ai_api_base": "OPENAI_API_BASE",
|
||||
"gemini_api_key": "GEMINI_API_KEY",
|
||||
"claude_api_key": "CLAUDE_API_KEY",
|
||||
"linkai_api_key": "LINKAI_API_KEY",
|
||||
}
|
||||
|
||||
# Use fixed secure location for .env file
|
||||
env_file = expand_path("~/.cow/.env")
|
||||
|
||||
# Read existing env vars from .env file
|
||||
existing_env_vars = {}
|
||||
if os.path.exists(env_file):
|
||||
try:
|
||||
with open(env_file, 'r', encoding='utf-8') as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line and not line.startswith('#') and '=' in line:
|
||||
key, _ = line.split('=', 1)
|
||||
existing_env_vars[key.strip()] = True
|
||||
except Exception as e:
|
||||
logger.warning(f"[AgentBridge] Failed to read .env file: {e}")
|
||||
|
||||
# Check which keys need to be migrated
|
||||
keys_to_migrate = {}
|
||||
for config_key, env_key in key_mapping.items():
|
||||
# Skip if already in .env file
|
||||
if env_key in existing_env_vars:
|
||||
continue
|
||||
|
||||
# Get value from config.json
|
||||
value = conf().get(config_key, "")
|
||||
if value and value.strip(): # Only migrate non-empty values
|
||||
keys_to_migrate[env_key] = value.strip()
|
||||
|
||||
# Log summary if there are keys to skip
|
||||
if existing_env_vars:
|
||||
logger.debug(f"[AgentBridge] {len(existing_env_vars)} env vars already in .env")
|
||||
|
||||
# Write new keys to .env file
|
||||
if keys_to_migrate:
|
||||
try:
|
||||
# Ensure ~/.cow directory and .env file exist
|
||||
env_dir = os.path.dirname(env_file)
|
||||
if not os.path.exists(env_dir):
|
||||
os.makedirs(env_dir, exist_ok=True)
|
||||
if not os.path.exists(env_file):
|
||||
open(env_file, 'a').close()
|
||||
|
||||
# Append new keys
|
||||
with open(env_file, 'a', encoding='utf-8') as f:
|
||||
f.write('\n# Auto-migrated from config.json\n')
|
||||
for key, value in keys_to_migrate.items():
|
||||
f.write(f'{key}={value}\n')
|
||||
# Also set in current process
|
||||
os.environ[key] = value
|
||||
|
||||
logger.info(f"[AgentBridge] Migrated {len(keys_to_migrate)} API keys from config.json to .env: {list(keys_to_migrate.keys())}")
|
||||
except Exception as e:
|
||||
logger.warning(f"[AgentBridge] Failed to migrate API keys: {e}")
|
||||
|
||||
def _persist_messages(
|
||||
self, session_id: str, new_messages: list, channel_type: str = ""
|
||||
) -> None:
|
||||
"""
|
||||
Persist new messages to the conversation store after each agent run.
|
||||
|
||||
Failures are logged but never propagate — they must not interrupt replies.
|
||||
"""
|
||||
if not new_messages:
|
||||
return
|
||||
try:
|
||||
from config import conf
|
||||
if not conf().get("conversation_persistence", True):
|
||||
return
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
from agent.memory import get_conversation_store
|
||||
get_conversation_store().append_messages(
|
||||
session_id, new_messages, channel_type=channel_type
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"[AgentBridge] Failed to persist messages for session={session_id}: {e}"
|
||||
)
|
||||
|
||||
def clear_session(self, session_id: str):
|
||||
"""
|
||||
Clear a specific session's agent and conversation history
|
||||
|
||||
Args:
|
||||
session_id: Session identifier to clear
|
||||
"""
|
||||
if session_id in self.agents:
|
||||
logger.info(f"[AgentBridge] Clearing session: {session_id}")
|
||||
del self.agents[session_id]
|
||||
|
||||
def clear_all_sessions(self):
|
||||
"""Clear all agent sessions"""
|
||||
logger.info(f"[AgentBridge] Clearing all sessions ({len(self.agents)} total)")
|
||||
self.agents.clear()
|
||||
self.default_agent = None
|
||||
|
||||
def refresh_all_skills(self) -> int:
|
||||
"""
|
||||
Refresh skills and conditional tools in all agent instances after
|
||||
environment variable changes. This allows hot-reload without restarting.
|
||||
|
||||
Returns:
|
||||
Number of agent instances refreshed
|
||||
"""
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
from config import conf
|
||||
|
||||
# Reload environment variables from .env file
|
||||
workspace_root = expand_path(conf().get("agent_workspace", "~/cow"))
|
||||
env_file = os.path.join(workspace_root, '.env')
|
||||
|
||||
if os.path.exists(env_file):
|
||||
load_dotenv(env_file, override=True)
|
||||
logger.info(f"[AgentBridge] Reloaded environment variables from {env_file}")
|
||||
|
||||
refreshed_count = 0
|
||||
|
||||
# Collect all agent instances to refresh
|
||||
agents_to_refresh = []
|
||||
if self.default_agent:
|
||||
agents_to_refresh.append(("default", self.default_agent))
|
||||
for session_id, agent in self.agents.items():
|
||||
agents_to_refresh.append((session_id, agent))
|
||||
|
||||
for label, agent in agents_to_refresh:
|
||||
# Refresh skills
|
||||
if hasattr(agent, 'skill_manager') and agent.skill_manager:
|
||||
agent.skill_manager.refresh_skills()
|
||||
|
||||
# Refresh conditional tools (e.g. web_search depends on API keys)
|
||||
self._refresh_conditional_tools(agent)
|
||||
|
||||
refreshed_count += 1
|
||||
|
||||
if refreshed_count > 0:
|
||||
logger.info(f"[AgentBridge] Refreshed skills & tools in {refreshed_count} agent instance(s)")
|
||||
|
||||
return refreshed_count
|
||||
|
||||
@staticmethod
|
||||
def _refresh_conditional_tools(agent):
|
||||
"""
|
||||
Add or remove conditional tools based on current environment variables.
|
||||
For example, web_search should only be present when BOCHA_API_KEY or
|
||||
LINKAI_API_KEY is set.
|
||||
"""
|
||||
try:
|
||||
from agent.tools.web_search.web_search import WebSearch
|
||||
|
||||
has_tool = any(t.name == "web_search" for t in agent.tools)
|
||||
available = WebSearch.is_available()
|
||||
|
||||
if available and not has_tool:
|
||||
# API key was added - inject the tool
|
||||
tool = WebSearch()
|
||||
tool.model = agent.model
|
||||
agent.tools.append(tool)
|
||||
logger.info("[AgentBridge] web_search tool added (API key now available)")
|
||||
elif not available and has_tool:
|
||||
# API key was removed - remove the tool
|
||||
agent.tools = [t for t in agent.tools if t.name != "web_search"]
|
||||
logger.info("[AgentBridge] web_search tool removed (API key no longer available)")
|
||||
except Exception as e:
|
||||
logger.debug(f"[AgentBridge] Failed to refresh conditional tools: {e}")
|
||||
115
bridge/agent_event_handler.py
Normal file
115
bridge/agent_event_handler.py
Normal file
@@ -0,0 +1,115 @@
|
||||
"""
|
||||
Agent Event Handler - Handles agent events and thinking process output
|
||||
"""
|
||||
|
||||
from common.log import logger
|
||||
|
||||
|
||||
class AgentEventHandler:
|
||||
"""
|
||||
Handles agent events and optionally sends intermediate messages to channel
|
||||
"""
|
||||
|
||||
def __init__(self, context=None, original_callback=None):
|
||||
"""
|
||||
Initialize event handler
|
||||
|
||||
Args:
|
||||
context: COW context (for accessing channel)
|
||||
original_callback: Original event callback to chain
|
||||
"""
|
||||
self.context = context
|
||||
self.original_callback = original_callback
|
||||
|
||||
# Get channel for sending intermediate messages
|
||||
self.channel = None
|
||||
if context:
|
||||
self.channel = context.kwargs.get("channel") if hasattr(context, "kwargs") else None
|
||||
|
||||
# Track current thinking for channel output
|
||||
self.current_thinking = ""
|
||||
self.turn_number = 0
|
||||
|
||||
def handle_event(self, event):
|
||||
"""
|
||||
Main event handler
|
||||
|
||||
Args:
|
||||
event: Event dict with type and data
|
||||
"""
|
||||
event_type = event.get("type")
|
||||
data = event.get("data", {})
|
||||
|
||||
# Dispatch to specific handlers
|
||||
if event_type == "turn_start":
|
||||
self._handle_turn_start(data)
|
||||
elif event_type == "message_update":
|
||||
self._handle_message_update(data)
|
||||
elif event_type == "message_end":
|
||||
self._handle_message_end(data)
|
||||
elif event_type == "tool_execution_start":
|
||||
self._handle_tool_execution_start(data)
|
||||
elif event_type == "tool_execution_end":
|
||||
self._handle_tool_execution_end(data)
|
||||
|
||||
# Call original callback if provided
|
||||
if self.original_callback:
|
||||
self.original_callback(event)
|
||||
|
||||
def _handle_turn_start(self, data):
|
||||
"""Handle turn start event"""
|
||||
self.turn_number = data.get("turn", 0)
|
||||
self.has_tool_calls_in_turn = False
|
||||
self.current_thinking = ""
|
||||
|
||||
def _handle_message_update(self, data):
|
||||
"""Handle message update event (streaming text)"""
|
||||
delta = data.get("delta", "")
|
||||
self.current_thinking += delta
|
||||
|
||||
def _handle_message_end(self, data):
|
||||
"""Handle message end event"""
|
||||
tool_calls = data.get("tool_calls", [])
|
||||
|
||||
# Only send thinking process if followed by tool calls
|
||||
if tool_calls:
|
||||
if self.current_thinking.strip():
|
||||
logger.info(f"💭 {self.current_thinking.strip()[:200]}{'...' if len(self.current_thinking) > 200 else ''}")
|
||||
# Send thinking process to channel
|
||||
self._send_to_channel(f"{self.current_thinking.strip()}")
|
||||
else:
|
||||
# No tool calls = final response (logged at agent_stream level)
|
||||
if self.current_thinking.strip():
|
||||
logger.debug(f"💬 {self.current_thinking.strip()[:200]}{'...' if len(self.current_thinking) > 200 else ''}")
|
||||
|
||||
self.current_thinking = ""
|
||||
|
||||
def _handle_tool_execution_start(self, data):
|
||||
"""Handle tool execution start event - logged by agent_stream.py"""
|
||||
pass
|
||||
|
||||
def _handle_tool_execution_end(self, data):
|
||||
"""Handle tool execution end event - logged by agent_stream.py"""
|
||||
pass
|
||||
|
||||
def _send_to_channel(self, message):
|
||||
"""
|
||||
Try to send intermediate message to channel.
|
||||
Skipped in SSE mode because thinking text is already streamed via on_event.
|
||||
"""
|
||||
if self.context and self.context.get("on_event"):
|
||||
return
|
||||
|
||||
if self.channel:
|
||||
try:
|
||||
from bridge.reply import Reply, ReplyType
|
||||
reply = Reply(ReplyType.TEXT, message)
|
||||
self.channel._send(reply, self.context)
|
||||
except Exception as e:
|
||||
logger.debug(f"[AgentEventHandler] Failed to send to channel: {e}")
|
||||
|
||||
def log_summary(self):
|
||||
"""Log execution summary - simplified"""
|
||||
# Summary removed as per user request
|
||||
# Real-time logging during execution is sufficient
|
||||
pass
|
||||
436
bridge/agent_initializer.py
Normal file
436
bridge/agent_initializer.py
Normal file
@@ -0,0 +1,436 @@
|
||||
"""
|
||||
Agent Initializer - Handles agent initialization logic
|
||||
"""
|
||||
|
||||
import os
|
||||
import asyncio
|
||||
import datetime
|
||||
import time
|
||||
from typing import Optional, List
|
||||
|
||||
from agent.protocol import Agent
|
||||
from agent.tools import ToolManager
|
||||
from common.log import logger
|
||||
from common.utils import expand_path
|
||||
|
||||
|
||||
class AgentInitializer:
|
||||
"""
|
||||
Handles agent initialization including:
|
||||
- Workspace setup
|
||||
- Memory system initialization
|
||||
- Tool loading
|
||||
- System prompt building
|
||||
"""
|
||||
|
||||
def __init__(self, bridge, agent_bridge):
|
||||
"""
|
||||
Initialize agent initializer
|
||||
|
||||
Args:
|
||||
bridge: COW bridge instance
|
||||
agent_bridge: AgentBridge instance (for create_agent method)
|
||||
"""
|
||||
self.bridge = bridge
|
||||
self.agent_bridge = agent_bridge
|
||||
|
||||
def initialize_agent(self, session_id: Optional[str] = None) -> Agent:
|
||||
"""
|
||||
Initialize agent for a session
|
||||
|
||||
Args:
|
||||
session_id: Session ID (None for default agent)
|
||||
|
||||
Returns:
|
||||
Initialized agent instance
|
||||
"""
|
||||
from config import conf
|
||||
|
||||
# Get workspace from config
|
||||
workspace_root = expand_path(conf().get("agent_workspace", "~/cow"))
|
||||
|
||||
# Migrate API keys
|
||||
self._migrate_config_to_env(workspace_root)
|
||||
|
||||
# Load environment variables
|
||||
self._load_env_file()
|
||||
|
||||
# Initialize workspace
|
||||
from agent.prompt import ensure_workspace, load_context_files, PromptBuilder
|
||||
workspace_files = ensure_workspace(workspace_root, create_templates=True)
|
||||
|
||||
if session_id is None:
|
||||
logger.info(f"[AgentInitializer] Workspace initialized at: {workspace_root}")
|
||||
|
||||
# Setup memory system
|
||||
memory_manager, memory_tools = self._setup_memory_system(workspace_root, session_id)
|
||||
|
||||
# Load tools
|
||||
tools = self._load_tools(workspace_root, memory_manager, memory_tools, session_id)
|
||||
|
||||
# Initialize scheduler if needed
|
||||
self._initialize_scheduler(tools, session_id)
|
||||
|
||||
# Load context files
|
||||
context_files = load_context_files(workspace_root)
|
||||
|
||||
# Initialize skill manager
|
||||
skill_manager = self._initialize_skill_manager(workspace_root, session_id)
|
||||
|
||||
# Check if first conversation
|
||||
from agent.prompt.workspace import is_first_conversation, mark_conversation_started
|
||||
is_first = is_first_conversation(workspace_root)
|
||||
|
||||
# Build system prompt
|
||||
prompt_builder = PromptBuilder(workspace_dir=workspace_root, language="zh")
|
||||
runtime_info = self._get_runtime_info(workspace_root)
|
||||
|
||||
system_prompt = prompt_builder.build(
|
||||
tools=tools,
|
||||
context_files=context_files,
|
||||
skill_manager=skill_manager,
|
||||
memory_manager=memory_manager,
|
||||
runtime_info=runtime_info,
|
||||
is_first_conversation=is_first
|
||||
)
|
||||
|
||||
if is_first:
|
||||
mark_conversation_started(workspace_root)
|
||||
|
||||
# Get cost control parameters
|
||||
from config import conf
|
||||
max_steps = conf().get("agent_max_steps", 20)
|
||||
max_context_tokens = conf().get("agent_max_context_tokens", 50000)
|
||||
|
||||
# Create agent
|
||||
agent = self.agent_bridge.create_agent(
|
||||
system_prompt=system_prompt,
|
||||
tools=tools,
|
||||
max_steps=max_steps,
|
||||
output_mode="logger",
|
||||
workspace_dir=workspace_root,
|
||||
skill_manager=skill_manager,
|
||||
enable_skills=True,
|
||||
max_context_tokens=max_context_tokens,
|
||||
runtime_info=runtime_info # Pass runtime_info for dynamic time updates
|
||||
)
|
||||
|
||||
# Attach memory manager
|
||||
if memory_manager:
|
||||
agent.memory_manager = memory_manager
|
||||
|
||||
# Restore persisted conversation history for this session
|
||||
if session_id:
|
||||
self._restore_conversation_history(agent, session_id)
|
||||
|
||||
return agent
|
||||
|
||||
def _restore_conversation_history(self, agent, session_id: str) -> None:
|
||||
"""
|
||||
Load persisted conversation messages from SQLite and inject them
|
||||
into the agent's in-memory message list.
|
||||
|
||||
Only runs when conversation persistence is enabled (default: True).
|
||||
Respects agent_max_context_turns to limit how many turns are loaded.
|
||||
"""
|
||||
from config import conf
|
||||
if not conf().get("conversation_persistence", True):
|
||||
return
|
||||
|
||||
try:
|
||||
from agent.memory import get_conversation_store
|
||||
store = get_conversation_store()
|
||||
# On restore, load at most min(10, max_turns // 2) turns so that
|
||||
# a long-running session does not immediately fill the context window
|
||||
# after a restart. The full max_turns budget is reserved for the
|
||||
# live conversation that follows.
|
||||
max_turns = conf().get("agent_max_context_turns", 30)
|
||||
restore_turns = max(4, max_turns // 5)
|
||||
saved = store.load_messages(session_id, max_turns=restore_turns)
|
||||
if saved:
|
||||
with agent.messages_lock:
|
||||
agent.messages = saved
|
||||
logger.debug(
|
||||
f"[AgentInitializer] Restored {len(saved)} messages "
|
||||
f"({restore_turns} turns cap) for session={session_id}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"[AgentInitializer] Failed to restore conversation history for "
|
||||
f"session={session_id}: {e}"
|
||||
)
|
||||
|
||||
def _load_env_file(self):
|
||||
"""Load environment variables from .env file"""
|
||||
env_file = expand_path("~/.cow/.env")
|
||||
if os.path.exists(env_file):
|
||||
try:
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv(env_file, override=True)
|
||||
except ImportError:
|
||||
logger.warning("[AgentInitializer] python-dotenv not installed")
|
||||
except Exception as e:
|
||||
logger.warning(f"[AgentInitializer] Failed to load .env file: {e}")
|
||||
|
||||
def _setup_memory_system(self, workspace_root: str, session_id: Optional[str] = None):
|
||||
"""
|
||||
Setup memory system
|
||||
|
||||
Returns:
|
||||
(memory_manager, memory_tools) tuple
|
||||
"""
|
||||
memory_manager = None
|
||||
memory_tools = []
|
||||
|
||||
try:
|
||||
from agent.memory import MemoryManager, MemoryConfig, create_embedding_provider
|
||||
from agent.tools import MemorySearchTool, MemoryGetTool
|
||||
from config import conf
|
||||
|
||||
# Get OpenAI config
|
||||
openai_api_key = conf().get("open_ai_api_key", "")
|
||||
openai_api_base = conf().get("open_ai_api_base", "")
|
||||
|
||||
# Initialize embedding provider
|
||||
embedding_provider = None
|
||||
if openai_api_key and openai_api_key not in ["", "YOUR API KEY", "YOUR_API_KEY"]:
|
||||
try:
|
||||
embedding_provider = create_embedding_provider(
|
||||
provider="openai",
|
||||
model="text-embedding-3-small",
|
||||
api_key=openai_api_key,
|
||||
api_base=openai_api_base or "https://api.openai.com/v1"
|
||||
)
|
||||
if session_id is None:
|
||||
logger.info("[AgentInitializer] OpenAI embedding initialized")
|
||||
except Exception as e:
|
||||
logger.warning(f"[AgentInitializer] OpenAI embedding failed: {e}")
|
||||
|
||||
# Create memory manager
|
||||
memory_config = MemoryConfig(workspace_root=workspace_root)
|
||||
memory_manager = MemoryManager(memory_config, embedding_provider=embedding_provider)
|
||||
|
||||
# Sync memory
|
||||
self._sync_memory(memory_manager, session_id)
|
||||
|
||||
# Create memory tools
|
||||
memory_tools = [
|
||||
MemorySearchTool(memory_manager),
|
||||
MemoryGetTool(memory_manager)
|
||||
]
|
||||
|
||||
if session_id is None:
|
||||
logger.info("[AgentInitializer] Memory system initialized")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"[AgentInitializer] Memory system not available: {e}")
|
||||
|
||||
return memory_manager, memory_tools
|
||||
|
||||
def _sync_memory(self, memory_manager, session_id: Optional[str] = None):
|
||||
"""Sync memory database"""
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_closed():
|
||||
raise RuntimeError("Event loop is closed")
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
if loop.is_running():
|
||||
asyncio.create_task(memory_manager.sync())
|
||||
else:
|
||||
loop.run_until_complete(memory_manager.sync())
|
||||
except Exception as e:
|
||||
logger.warning(f"[AgentInitializer] Memory sync failed: {e}")
|
||||
|
||||
def _load_tools(self, workspace_root: str, memory_manager, memory_tools: List, session_id: Optional[str] = None):
|
||||
"""Load all tools"""
|
||||
tool_manager = ToolManager()
|
||||
tool_manager.load_tools()
|
||||
|
||||
tools = []
|
||||
file_config = {
|
||||
"cwd": workspace_root,
|
||||
"memory_manager": memory_manager
|
||||
} if memory_manager else {"cwd": workspace_root}
|
||||
|
||||
for tool_name in tool_manager.tool_classes.keys():
|
||||
try:
|
||||
# Skip web_search if no API key is available
|
||||
if tool_name == "web_search":
|
||||
from agent.tools.web_search.web_search import WebSearch
|
||||
if not WebSearch.is_available():
|
||||
logger.debug("[AgentInitializer] WebSearch skipped - no BOCHA_API_KEY or LINKAI_API_KEY")
|
||||
continue
|
||||
|
||||
# Special handling for EnvConfig tool
|
||||
if tool_name == "env_config":
|
||||
from agent.tools import EnvConfig
|
||||
tool = EnvConfig({"agent_bridge": self.agent_bridge})
|
||||
else:
|
||||
tool = tool_manager.create_tool(tool_name)
|
||||
|
||||
if tool:
|
||||
# Apply workspace config to file operation tools
|
||||
if tool_name in ['read', 'write', 'edit', 'bash', 'grep', 'find', 'ls']:
|
||||
tool.config = file_config
|
||||
tool.cwd = file_config.get("cwd", getattr(tool, 'cwd', None))
|
||||
if 'memory_manager' in file_config:
|
||||
tool.memory_manager = file_config['memory_manager']
|
||||
tools.append(tool)
|
||||
except Exception as e:
|
||||
logger.warning(f"[AgentInitializer] Failed to load tool {tool_name}: {e}")
|
||||
|
||||
# Add memory tools
|
||||
if memory_tools:
|
||||
tools.extend(memory_tools)
|
||||
if session_id is None:
|
||||
logger.info(f"[AgentInitializer] Added {len(memory_tools)} memory tools")
|
||||
|
||||
if session_id is None:
|
||||
logger.info(f"[AgentInitializer] Loaded {len(tools)} tools: {[t.name for t in tools]}")
|
||||
|
||||
return tools
|
||||
|
||||
def _initialize_scheduler(self, tools: List, session_id: Optional[str] = None):
|
||||
"""Initialize scheduler service if needed"""
|
||||
if not self.agent_bridge.scheduler_initialized:
|
||||
try:
|
||||
from agent.tools.scheduler.integration import init_scheduler
|
||||
if init_scheduler(self.agent_bridge):
|
||||
self.agent_bridge.scheduler_initialized = True
|
||||
if session_id is None:
|
||||
logger.info("[AgentInitializer] Scheduler service initialized")
|
||||
except Exception as e:
|
||||
logger.warning(f"[AgentInitializer] Failed to initialize scheduler: {e}")
|
||||
|
||||
# Inject scheduler dependencies
|
||||
if self.agent_bridge.scheduler_initialized:
|
||||
try:
|
||||
from agent.tools.scheduler.integration import get_task_store, get_scheduler_service
|
||||
from agent.tools import SchedulerTool
|
||||
from config import conf
|
||||
|
||||
task_store = get_task_store()
|
||||
scheduler_service = get_scheduler_service()
|
||||
|
||||
for tool in tools:
|
||||
if isinstance(tool, SchedulerTool):
|
||||
tool.task_store = task_store
|
||||
tool.scheduler_service = scheduler_service
|
||||
if not tool.config:
|
||||
tool.config = {}
|
||||
raw_ct = conf().get("channel_type", "unknown")
|
||||
if isinstance(raw_ct, list):
|
||||
ct = raw_ct[0] if raw_ct else "unknown"
|
||||
elif isinstance(raw_ct, str) and "," in raw_ct:
|
||||
ct = raw_ct.split(",")[0].strip()
|
||||
else:
|
||||
ct = raw_ct
|
||||
tool.config["channel_type"] = ct
|
||||
except Exception as e:
|
||||
logger.warning(f"[AgentInitializer] Failed to inject scheduler dependencies: {e}")
|
||||
|
||||
def _initialize_skill_manager(self, workspace_root: str, session_id: Optional[str] = None):
|
||||
"""Initialize skill manager"""
|
||||
try:
|
||||
from agent.skills import SkillManager
|
||||
skill_manager = SkillManager(custom_dir=os.path.join(workspace_root, "skills"))
|
||||
return skill_manager
|
||||
except Exception as e:
|
||||
logger.warning(f"[AgentInitializer] Failed to initialize SkillManager: {e}")
|
||||
return None
|
||||
|
||||
def _get_runtime_info(self, workspace_root: str):
|
||||
"""Get runtime information with dynamic time support"""
|
||||
from config import conf
|
||||
|
||||
def get_current_time():
|
||||
"""Get current time dynamically - called each time system prompt is accessed"""
|
||||
now = datetime.datetime.now()
|
||||
|
||||
# Get timezone info
|
||||
try:
|
||||
offset = -time.timezone if not time.daylight else -time.altzone
|
||||
hours = offset // 3600
|
||||
minutes = (offset % 3600) // 60
|
||||
timezone_name = f"UTC{hours:+03d}:{minutes:02d}" if minutes else f"UTC{hours:+03d}"
|
||||
except Exception:
|
||||
timezone_name = "UTC"
|
||||
|
||||
# Chinese weekday mapping
|
||||
weekday_map = {
|
||||
'Monday': '星期一', 'Tuesday': '星期二', 'Wednesday': '星期三',
|
||||
'Thursday': '星期四', 'Friday': '星期五', 'Saturday': '星期六', 'Sunday': '星期日'
|
||||
}
|
||||
weekday_zh = weekday_map.get(now.strftime("%A"), now.strftime("%A"))
|
||||
|
||||
return {
|
||||
'time': now.strftime("%Y-%m-%d %H:%M:%S"),
|
||||
'weekday': weekday_zh,
|
||||
'timezone': timezone_name
|
||||
}
|
||||
|
||||
return {
|
||||
"model": conf().get("model", "unknown"),
|
||||
"workspace": workspace_root,
|
||||
"channel": ", ".join(conf().get("channel_type")) if isinstance(conf().get("channel_type"), list) else conf().get("channel_type", "unknown"),
|
||||
"_get_current_time": get_current_time # Dynamic time function
|
||||
}
|
||||
|
||||
def _migrate_config_to_env(self, workspace_root: str):
|
||||
"""Migrate API keys from config.json to .env file"""
|
||||
from config import conf
|
||||
|
||||
key_mapping = {
|
||||
"open_ai_api_key": "OPENAI_API_KEY",
|
||||
"open_ai_api_base": "OPENAI_API_BASE",
|
||||
"gemini_api_key": "GEMINI_API_KEY",
|
||||
"claude_api_key": "CLAUDE_API_KEY",
|
||||
"linkai_api_key": "LINKAI_API_KEY",
|
||||
}
|
||||
|
||||
env_file = expand_path("~/.cow/.env")
|
||||
|
||||
# Read existing env vars
|
||||
existing_env_vars = {}
|
||||
if os.path.exists(env_file):
|
||||
try:
|
||||
with open(env_file, 'r', encoding='utf-8') as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line and not line.startswith('#') and '=' in line:
|
||||
key, _ = line.split('=', 1)
|
||||
existing_env_vars[key.strip()] = True
|
||||
except Exception as e:
|
||||
logger.warning(f"[AgentInitializer] Failed to read .env file: {e}")
|
||||
|
||||
# Check which keys need migration
|
||||
keys_to_migrate = {}
|
||||
for config_key, env_key in key_mapping.items():
|
||||
if env_key in existing_env_vars:
|
||||
continue
|
||||
value = conf().get(config_key, "")
|
||||
if value and value.strip():
|
||||
keys_to_migrate[env_key] = value.strip()
|
||||
|
||||
# Write new keys
|
||||
if keys_to_migrate:
|
||||
try:
|
||||
env_dir = os.path.dirname(env_file)
|
||||
if not os.path.exists(env_dir):
|
||||
os.makedirs(env_dir, exist_ok=True)
|
||||
if not os.path.exists(env_file):
|
||||
open(env_file, 'a').close()
|
||||
|
||||
with open(env_file, 'a', encoding='utf-8') as f:
|
||||
f.write('\n# Auto-migrated from config.json\n')
|
||||
for key, value in keys_to_migrate.items():
|
||||
f.write(f'{key}={value}\n')
|
||||
os.environ[key] = value
|
||||
|
||||
logger.info(f"[AgentInitializer] Migrated {len(keys_to_migrate)} API keys to .env: {list(keys_to_migrate.keys())}")
|
||||
except Exception as e:
|
||||
logger.warning(f"[AgentInitializer] Failed to migrate API keys: {e}")
|
||||
142
bridge/bridge.py
142
bridge/bridge.py
@@ -1,50 +1,146 @@
|
||||
from models.bot_factory import create_bot
|
||||
from bridge.context import Context
|
||||
from bridge.reply import Reply
|
||||
from common.log import logger
|
||||
from bot import bot_factory
|
||||
from common.singleton import singleton
|
||||
from voice import voice_factory
|
||||
from config import conf
|
||||
from common import const
|
||||
from common.log import logger
|
||||
from common.singleton import singleton
|
||||
from config import conf
|
||||
from translate.factory import create_translator
|
||||
from voice.factory import create_voice
|
||||
|
||||
|
||||
@singleton
|
||||
class Bridge(object):
|
||||
def __init__(self):
|
||||
self.btype={
|
||||
self.btype = {
|
||||
"chat": const.CHATGPT,
|
||||
"voice_to_text": conf().get("voice_to_text", "openai"),
|
||||
"text_to_voice": conf().get("text_to_voice", "google")
|
||||
"text_to_voice": conf().get("text_to_voice", "google"),
|
||||
"translate": conf().get("translate", "baidu"),
|
||||
}
|
||||
model_type = conf().get("model")
|
||||
if model_type in ["text-davinci-003"]:
|
||||
self.btype['chat'] = const.OPEN_AI
|
||||
if conf().get("use_azure_chatgpt", False):
|
||||
self.btype['chat'] = const.CHATGPTONAZURE
|
||||
self.bots={}
|
||||
# 这边取配置的模型
|
||||
bot_type = conf().get("bot_type")
|
||||
if bot_type:
|
||||
self.btype["chat"] = bot_type
|
||||
else:
|
||||
model_type = conf().get("model") or const.GPT_41_MINI
|
||||
|
||||
# Ensure model_type is string to prevent AttributeError when using startswith()
|
||||
# This handles cases where numeric model names (e.g., "1") are parsed as integers from YAML
|
||||
if not isinstance(model_type, str):
|
||||
logger.warning(f"[Bridge] model_type is not a string: {model_type} (type: {type(model_type).__name__}), converting to string")
|
||||
model_type = str(model_type)
|
||||
|
||||
if model_type in ["text-davinci-003"]:
|
||||
self.btype["chat"] = const.OPEN_AI
|
||||
if conf().get("use_azure_chatgpt", False):
|
||||
self.btype["chat"] = const.CHATGPTONAZURE
|
||||
if model_type in ["wenxin", "wenxin-4"]:
|
||||
self.btype["chat"] = const.BAIDU
|
||||
if model_type in ["xunfei"]:
|
||||
self.btype["chat"] = const.XUNFEI
|
||||
if model_type in [const.QWEN]:
|
||||
self.btype["chat"] = const.QWEN
|
||||
if model_type in [const.QWEN_TURBO, const.QWEN_PLUS, const.QWEN_MAX]:
|
||||
self.btype["chat"] = const.QWEN_DASHSCOPE
|
||||
# Support Qwen3 and other DashScope models
|
||||
if model_type and (model_type.startswith("qwen") or model_type.startswith("qwq") or model_type.startswith("qvq")):
|
||||
self.btype["chat"] = const.QWEN_DASHSCOPE
|
||||
if model_type and model_type.startswith("gemini"):
|
||||
self.btype["chat"] = const.GEMINI
|
||||
if model_type and model_type.startswith("glm"):
|
||||
self.btype["chat"] = const.ZHIPU_AI
|
||||
if model_type and model_type.startswith("claude"):
|
||||
self.btype["chat"] = const.CLAUDEAPI
|
||||
|
||||
def get_bot(self,typename):
|
||||
if model_type in [const.MOONSHOT, "moonshot-v1-8k", "moonshot-v1-32k", "moonshot-v1-128k"]:
|
||||
self.btype["chat"] = const.MOONSHOT
|
||||
if model_type and model_type.startswith("kimi"):
|
||||
self.btype["chat"] = const.MOONSHOT
|
||||
|
||||
if model_type and model_type.startswith("doubao"):
|
||||
self.btype["chat"] = const.DOUBAO
|
||||
|
||||
if model_type in [const.MODELSCOPE]:
|
||||
self.btype["chat"] = const.MODELSCOPE
|
||||
|
||||
# MiniMax models
|
||||
if model_type and (model_type in ["abab6.5-chat", "abab6.5"] or model_type.lower().startswith("minimax")):
|
||||
self.btype["chat"] = const.MiniMax
|
||||
|
||||
if conf().get("use_linkai") and conf().get("linkai_api_key"):
|
||||
self.btype["chat"] = const.LINKAI
|
||||
if not conf().get("voice_to_text") or conf().get("voice_to_text") in ["openai"]:
|
||||
self.btype["voice_to_text"] = const.LINKAI
|
||||
if not conf().get("text_to_voice") or conf().get("text_to_voice") in ["openai", const.TTS_1, const.TTS_1_HD]:
|
||||
self.btype["text_to_voice"] = const.LINKAI
|
||||
|
||||
self.bots = {}
|
||||
self.chat_bots = {}
|
||||
self._agent_bridge = None
|
||||
|
||||
# 模型对应的接口
|
||||
def get_bot(self, typename):
|
||||
if self.bots.get(typename) is None:
|
||||
logger.info("create bot {} for {}".format(self.btype[typename],typename))
|
||||
logger.info("create bot {} for {}".format(self.btype[typename], typename))
|
||||
if typename == "text_to_voice":
|
||||
self.bots[typename] = voice_factory.create_voice(self.btype[typename])
|
||||
self.bots[typename] = create_voice(self.btype[typename])
|
||||
elif typename == "voice_to_text":
|
||||
self.bots[typename] = voice_factory.create_voice(self.btype[typename])
|
||||
self.bots[typename] = create_voice(self.btype[typename])
|
||||
elif typename == "chat":
|
||||
self.bots[typename] = bot_factory.create_bot(self.btype[typename])
|
||||
self.bots[typename] = create_bot(self.btype[typename])
|
||||
elif typename == "translate":
|
||||
self.bots[typename] = create_translator(self.btype[typename])
|
||||
return self.bots[typename]
|
||||
|
||||
def get_bot_type(self,typename):
|
||||
|
||||
def get_bot_type(self, typename):
|
||||
return self.btype[typename]
|
||||
|
||||
|
||||
def fetch_reply_content(self, query, context : Context) -> Reply:
|
||||
def fetch_reply_content(self, query, context: Context) -> Reply:
|
||||
return self.get_bot("chat").reply(query, context)
|
||||
|
||||
|
||||
def fetch_voice_to_text(self, voiceFile) -> Reply:
|
||||
return self.get_bot("voice_to_text").voiceToText(voiceFile)
|
||||
|
||||
def fetch_text_to_voice(self, text) -> Reply:
|
||||
return self.get_bot("text_to_voice").textToVoice(text)
|
||||
|
||||
def fetch_translate(self, text, from_lang="", to_lang="en") -> Reply:
|
||||
return self.get_bot("translate").translate(text, from_lang, to_lang)
|
||||
|
||||
def find_chat_bot(self, bot_type: str):
|
||||
if self.chat_bots.get(bot_type) is None:
|
||||
self.chat_bots[bot_type] = create_bot(bot_type)
|
||||
return self.chat_bots.get(bot_type)
|
||||
|
||||
def reset_bot(self):
|
||||
"""
|
||||
重置bot路由
|
||||
"""
|
||||
self.__init__()
|
||||
|
||||
def get_agent_bridge(self):
|
||||
"""
|
||||
Get agent bridge for agent-based conversations
|
||||
"""
|
||||
if self._agent_bridge is None:
|
||||
from bridge.agent_bridge import AgentBridge
|
||||
self._agent_bridge = AgentBridge(self)
|
||||
return self._agent_bridge
|
||||
|
||||
def fetch_agent_reply(self, query: str, context: Context = None,
|
||||
on_event=None, clear_history: bool = False) -> Reply:
|
||||
"""
|
||||
Use super agent to handle the query
|
||||
|
||||
Args:
|
||||
query: User query
|
||||
context: Context object
|
||||
on_event: Event callback for streaming
|
||||
clear_history: Whether to clear conversation history
|
||||
|
||||
Returns:
|
||||
Reply object
|
||||
"""
|
||||
agent_bridge = self.get_agent_bridge()
|
||||
return agent_bridge.agent_reply(query, context, on_event, clear_history)
|
||||
|
||||
@@ -2,36 +2,49 @@
|
||||
|
||||
from enum import Enum
|
||||
|
||||
class ContextType (Enum):
|
||||
TEXT = 1 # 文本消息
|
||||
VOICE = 2 # 音频消息
|
||||
IMAGE = 3 # 图片消息
|
||||
IMAGE_CREATE = 10 # 创建图片命令
|
||||
|
||||
|
||||
class ContextType(Enum):
|
||||
TEXT = 1 # 文本消息
|
||||
VOICE = 2 # 音频消息
|
||||
IMAGE = 3 # 图片消息
|
||||
FILE = 4 # 文件信息
|
||||
VIDEO = 5 # 视频信息
|
||||
SHARING = 6 # 分享信息
|
||||
|
||||
IMAGE_CREATE = 10 # 创建图片命令
|
||||
ACCEPT_FRIEND = 19 # 同意好友请求
|
||||
JOIN_GROUP = 20 # 加入群聊
|
||||
PATPAT = 21 # 拍了拍
|
||||
FUNCTION = 22 # 函数调用
|
||||
EXIT_GROUP = 23 #退出
|
||||
|
||||
|
||||
def __str__(self):
|
||||
return self.name
|
||||
|
||||
|
||||
class Context:
|
||||
def __init__(self, type : ContextType = None , content = None, kwargs = dict()):
|
||||
def __init__(self, type: ContextType = None, content=None, kwargs=dict()):
|
||||
self.type = type
|
||||
self.content = content
|
||||
self.kwargs = kwargs
|
||||
|
||||
def __contains__(self, key):
|
||||
if key == 'type':
|
||||
if key == "type":
|
||||
return self.type is not None
|
||||
elif key == 'content':
|
||||
elif key == "content":
|
||||
return self.content is not None
|
||||
else:
|
||||
return key in self.kwargs
|
||||
|
||||
|
||||
def __getitem__(self, key):
|
||||
if key == 'type':
|
||||
if key == "type":
|
||||
return self.type
|
||||
elif key == 'content':
|
||||
elif key == "content":
|
||||
return self.content
|
||||
else:
|
||||
return self.kwargs[key]
|
||||
|
||||
|
||||
def get(self, key, default=None):
|
||||
try:
|
||||
return self[key]
|
||||
@@ -39,20 +52,20 @@ class Context:
|
||||
return default
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
if key == 'type':
|
||||
if key == "type":
|
||||
self.type = value
|
||||
elif key == 'content':
|
||||
elif key == "content":
|
||||
self.content = value
|
||||
else:
|
||||
self.kwargs[key] = value
|
||||
|
||||
def __delitem__(self, key):
|
||||
if key == 'type':
|
||||
if key == "type":
|
||||
self.type = None
|
||||
elif key == 'content':
|
||||
elif key == "content":
|
||||
self.content = None
|
||||
else:
|
||||
del self.kwargs[key]
|
||||
|
||||
|
||||
def __str__(self):
|
||||
return "Context(type={}, content={}, kwargs={})".format(self.type, self.content, self.kwargs)
|
||||
return "Context(type={}, content={}, kwargs={})".format(self.type, self.content, self.kwargs)
|
||||
|
||||
@@ -1,22 +1,31 @@
|
||||
|
||||
# encoding:utf-8
|
||||
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class ReplyType(Enum):
|
||||
TEXT = 1 # 文本
|
||||
VOICE = 2 # 音频文件
|
||||
IMAGE = 3 # 图片文件
|
||||
IMAGE_URL = 4 # 图片URL
|
||||
|
||||
TEXT = 1 # 文本
|
||||
VOICE = 2 # 音频文件
|
||||
IMAGE = 3 # 图片文件
|
||||
IMAGE_URL = 4 # 图片URL
|
||||
VIDEO_URL = 5 # 视频URL
|
||||
FILE = 6 # 文件
|
||||
CARD = 7 # 微信名片,仅支持ntchat
|
||||
INVITE_ROOM = 8 # 邀请好友进群
|
||||
INFO = 9
|
||||
ERROR = 10
|
||||
TEXT_ = 11 # 强制文本
|
||||
VIDEO = 12
|
||||
MINIAPP = 13 # 小程序
|
||||
|
||||
def __str__(self):
|
||||
return self.name
|
||||
|
||||
|
||||
class Reply:
|
||||
def __init__(self, type : ReplyType = None , content = None):
|
||||
def __init__(self, type: ReplyType = None, content=None):
|
||||
self.type = type
|
||||
self.content = content
|
||||
|
||||
def __str__(self):
|
||||
return "Reply(type={}, content={})".format(self.type, self.content)
|
||||
return "Reply(type={}, content={})".format(self.type, self.content)
|
||||
|
||||
@@ -5,15 +5,26 @@ Message sending channel abstract class
|
||||
from bridge.bridge import Bridge
|
||||
from bridge.context import Context
|
||||
from bridge.reply import *
|
||||
from common.log import logger
|
||||
from config import conf
|
||||
|
||||
|
||||
class Channel(object):
|
||||
channel_type = ""
|
||||
NOT_SUPPORT_REPLYTYPE = [ReplyType.VOICE, ReplyType.IMAGE]
|
||||
|
||||
def startup(self):
|
||||
"""
|
||||
init channel
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def stop(self):
|
||||
"""
|
||||
stop channel gracefully, called before restart
|
||||
"""
|
||||
pass
|
||||
|
||||
def handle_text(self, msg):
|
||||
"""
|
||||
process received msg
|
||||
@@ -27,15 +38,45 @@ class Channel(object):
|
||||
send message to user
|
||||
:param msg: message content
|
||||
:param receiver: receiver channel account
|
||||
:return:
|
||||
:return:
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def build_reply_content(self, query, context : Context=None) -> Reply:
|
||||
return Bridge().fetch_reply_content(query, context)
|
||||
def build_reply_content(self, query, context: Context = None) -> Reply:
|
||||
"""
|
||||
Build reply content, using agent if enabled in config
|
||||
"""
|
||||
# Check if agent mode is enabled
|
||||
use_agent = conf().get("agent", False)
|
||||
|
||||
if use_agent:
|
||||
try:
|
||||
logger.info("[Channel] Using agent mode")
|
||||
|
||||
# Add channel_type to context if not present
|
||||
if context and "channel_type" not in context:
|
||||
context["channel_type"] = self.channel_type
|
||||
|
||||
# Read on_event callback injected by the channel (e.g. web SSE)
|
||||
on_event = context.get("on_event") if context else None
|
||||
|
||||
# Use agent bridge to handle the query
|
||||
return Bridge().fetch_agent_reply(
|
||||
query=query,
|
||||
context=context,
|
||||
on_event=on_event,
|
||||
clear_history=False
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[Channel] Agent mode failed, fallback to normal mode: {e}")
|
||||
# Fallback to normal mode if agent fails
|
||||
return Bridge().fetch_reply_content(query, context)
|
||||
else:
|
||||
# Normal mode
|
||||
return Bridge().fetch_reply_content(query, context)
|
||||
|
||||
def build_voice_to_text(self, voice_file) -> Reply:
|
||||
return Bridge().fetch_voice_to_text(voice_file)
|
||||
|
||||
|
||||
def build_text_to_voice(self, text) -> Reply:
|
||||
return Bridge().fetch_text_to_voice(text)
|
||||
|
||||
@@ -1,26 +1,51 @@
|
||||
"""
|
||||
channel factory
|
||||
"""
|
||||
from common import const
|
||||
from .channel import Channel
|
||||
|
||||
def create_channel(channel_type):
|
||||
|
||||
def create_channel(channel_type) -> Channel:
|
||||
"""
|
||||
create a channel instance
|
||||
:param channel_type: channel type code
|
||||
:return: channel instance
|
||||
"""
|
||||
if channel_type == 'wx':
|
||||
ch = Channel()
|
||||
if channel_type == "wx":
|
||||
from channel.wechat.wechat_channel import WechatChannel
|
||||
return WechatChannel()
|
||||
elif channel_type == 'wxy':
|
||||
ch = WechatChannel()
|
||||
elif channel_type == "wxy":
|
||||
from channel.wechat.wechaty_channel import WechatyChannel
|
||||
return WechatyChannel()
|
||||
elif channel_type == 'terminal':
|
||||
ch = WechatyChannel()
|
||||
elif channel_type == "wcf":
|
||||
from channel.wechat.wcf_channel import WechatfChannel
|
||||
ch = WechatfChannel()
|
||||
elif channel_type == "terminal":
|
||||
from channel.terminal.terminal_channel import TerminalChannel
|
||||
return TerminalChannel()
|
||||
elif channel_type == 'wechatmp':
|
||||
ch = TerminalChannel()
|
||||
elif channel_type == 'web':
|
||||
from channel.web.web_channel import WebChannel
|
||||
ch = WebChannel()
|
||||
elif channel_type == "wechatmp":
|
||||
from channel.wechatmp.wechatmp_channel import WechatMPChannel
|
||||
return WechatMPChannel(passive_reply = True)
|
||||
elif channel_type == 'wechatmp_service':
|
||||
ch = WechatMPChannel(passive_reply=True)
|
||||
elif channel_type == "wechatmp_service":
|
||||
from channel.wechatmp.wechatmp_channel import WechatMPChannel
|
||||
return WechatMPChannel(passive_reply = False)
|
||||
raise RuntimeError
|
||||
ch = WechatMPChannel(passive_reply=False)
|
||||
elif channel_type == "wechatcom_app":
|
||||
from channel.wechatcom.wechatcomapp_channel import WechatComAppChannel
|
||||
ch = WechatComAppChannel()
|
||||
elif channel_type == "wework":
|
||||
from channel.wework.wework_channel import WeworkChannel
|
||||
ch = WeworkChannel()
|
||||
elif channel_type == const.FEISHU:
|
||||
from channel.feishu.feishu_channel import FeiShuChanel
|
||||
ch = FeiShuChanel()
|
||||
elif channel_type == const.DINGTALK:
|
||||
from channel.dingtalk.dingtalk_channel import DingTalkChanel
|
||||
ch = DingTalkChanel()
|
||||
else:
|
||||
raise RuntimeError
|
||||
ch.channel_type = channel_type
|
||||
return ch
|
||||
|
||||
@@ -1,156 +1,217 @@
|
||||
|
||||
|
||||
from asyncio import CancelledError
|
||||
from concurrent.futures import Future, ThreadPoolExecutor
|
||||
import os
|
||||
import re
|
||||
import threading
|
||||
import time
|
||||
from common.dequeue import Dequeue
|
||||
from channel.channel import Channel
|
||||
from bridge.reply import *
|
||||
from asyncio import CancelledError
|
||||
from concurrent.futures import Future, ThreadPoolExecutor
|
||||
|
||||
from bridge.context import *
|
||||
from config import conf
|
||||
from common.log import logger
|
||||
from bridge.reply import *
|
||||
from channel.channel import Channel
|
||||
from common.dequeue import Dequeue
|
||||
from common import memory
|
||||
from plugins import *
|
||||
|
||||
try:
|
||||
from voice.audio_convert import any_to_wav
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
handler_pool = ThreadPoolExecutor(max_workers=8) # 处理消息的线程池
|
||||
|
||||
|
||||
# 抽象类, 它包含了与消息通道无关的通用处理逻辑
|
||||
class ChatChannel(Channel):
|
||||
name = None # 登录的用户名
|
||||
user_id = None # 登录的用户id
|
||||
futures = {} # 记录每个session_id提交到线程池的future对象, 用于重置会话时把没执行的future取消掉,正在执行的不会被取消
|
||||
sessions = {} # 用于控制并发,每个session_id同时只能有一个context在处理
|
||||
lock = threading.Lock() # 用于控制对sessions的访问
|
||||
handler_pool = ThreadPoolExecutor(max_workers=8) # 处理消息的线程池
|
||||
name = None # 登录的用户名
|
||||
user_id = None # 登录的用户id
|
||||
|
||||
def __init__(self):
|
||||
# Instance-level attributes so each channel subclass has its own
|
||||
# independent session queue and lock. Previously these were class-level,
|
||||
# which caused contexts from one channel (e.g. Feishu) to be consumed
|
||||
# by another channel's consume() thread (e.g. Web), leading to errors
|
||||
# like "No request_id found in context".
|
||||
self.futures = {}
|
||||
self.sessions = {}
|
||||
self.lock = threading.Lock()
|
||||
_thread = threading.Thread(target=self.consume)
|
||||
_thread.setDaemon(True)
|
||||
_thread.start()
|
||||
|
||||
|
||||
# 根据消息构造context,消息内容相关的触发项写在这里
|
||||
def _compose_context(self, ctype: ContextType, content, **kwargs):
|
||||
context = Context(ctype, content)
|
||||
context.kwargs = kwargs
|
||||
# context首次传入时,origin_ctype是None,
|
||||
# 引入的起因是:当输入语音时,会嵌套生成两个context,第一步语音转文本,第二步通过文本生成文字回复。
|
||||
# origin_ctype用于第二步文本回复时,判断是否需要匹配前缀,如果是私聊的语音,就不需要匹配前缀
|
||||
if 'origin_ctype' not in context:
|
||||
context['origin_ctype'] = ctype
|
||||
if "channel_type" not in context:
|
||||
context["channel_type"] = self.channel_type
|
||||
if "origin_ctype" not in context:
|
||||
context["origin_ctype"] = ctype
|
||||
# context首次传入时,receiver是None,根据类型设置receiver
|
||||
first_in = 'receiver' not in context
|
||||
first_in = "receiver" not in context
|
||||
# 群名匹配过程,设置session_id和receiver
|
||||
if first_in: # context首次传入时,receiver是None,根据类型设置receiver
|
||||
if first_in: # context首次传入时,receiver是None,根据类型设置receiver
|
||||
config = conf()
|
||||
cmsg = context['msg']
|
||||
if cmsg.from_user_id == self.user_id and not config.get('trigger_by_self', True):
|
||||
logger.debug("[WX]self message skipped")
|
||||
return None
|
||||
cmsg = context["msg"]
|
||||
user_data = conf().get_user_data(cmsg.from_user_id)
|
||||
context["openai_api_key"] = user_data.get("openai_api_key")
|
||||
context["gpt_model"] = user_data.get("gpt_model")
|
||||
if context.get("isgroup", False):
|
||||
group_name = cmsg.other_user_nickname
|
||||
group_id = cmsg.other_user_id
|
||||
|
||||
group_name_white_list = config.get('group_name_white_list', [])
|
||||
group_name_keyword_white_list = config.get('group_name_keyword_white_list', [])
|
||||
if any([group_name in group_name_white_list, 'ALL_GROUP' in group_name_white_list, check_contain(group_name, group_name_keyword_white_list)]):
|
||||
group_chat_in_one_session = conf().get('group_chat_in_one_session', [])
|
||||
session_id = cmsg.actual_user_id
|
||||
if any([group_name in group_chat_in_one_session, 'ALL_GROUP' in group_chat_in_one_session]):
|
||||
group_name_white_list = config.get("group_name_white_list", [])
|
||||
group_name_keyword_white_list = config.get("group_name_keyword_white_list", [])
|
||||
if any(
|
||||
[
|
||||
group_name in group_name_white_list,
|
||||
"ALL_GROUP" in group_name_white_list,
|
||||
check_contain(group_name, group_name_keyword_white_list),
|
||||
]
|
||||
):
|
||||
# Check global group_shared_session config first
|
||||
group_shared_session = conf().get("group_shared_session", True)
|
||||
if group_shared_session:
|
||||
# All users in the group share the same session
|
||||
session_id = group_id
|
||||
else:
|
||||
# Check group-specific whitelist (legacy behavior)
|
||||
group_chat_in_one_session = conf().get("group_chat_in_one_session", [])
|
||||
session_id = cmsg.actual_user_id
|
||||
if any(
|
||||
[
|
||||
group_name in group_chat_in_one_session,
|
||||
"ALL_GROUP" in group_chat_in_one_session,
|
||||
]
|
||||
):
|
||||
session_id = group_id
|
||||
else:
|
||||
logger.debug(f"No need reply, groupName not in whitelist, group_name={group_name}")
|
||||
return None
|
||||
context['session_id'] = session_id
|
||||
context['receiver'] = group_id
|
||||
context["session_id"] = session_id
|
||||
context["receiver"] = group_id
|
||||
else:
|
||||
context['session_id'] = cmsg.other_user_id
|
||||
context['receiver'] = cmsg.other_user_id
|
||||
context["session_id"] = cmsg.other_user_id
|
||||
context["receiver"] = cmsg.other_user_id
|
||||
e_context = PluginManager().emit_event(EventContext(Event.ON_RECEIVE_MESSAGE, {"channel": self, "context": context}))
|
||||
context = e_context["context"]
|
||||
if e_context.is_pass() or context is None:
|
||||
return context
|
||||
if cmsg.from_user_id == self.user_id and not config.get("trigger_by_self", True):
|
||||
logger.debug("[chat_channel]self message skipped")
|
||||
return None
|
||||
|
||||
# 消息内容匹配过程,并处理content
|
||||
if ctype == ContextType.TEXT:
|
||||
if first_in and "」\n- - - - - - -" in content: # 初次匹配 过滤引用消息
|
||||
logger.debug("[WX]reference query skipped")
|
||||
if first_in and "」\n- - - - - - -" in content: # 初次匹配 过滤引用消息
|
||||
logger.debug(content)
|
||||
logger.debug("[chat_channel]reference query skipped")
|
||||
return None
|
||||
|
||||
if context.get("isgroup", False): # 群聊
|
||||
|
||||
nick_name_black_list = conf().get("nick_name_black_list", [])
|
||||
if context.get("isgroup", False): # 群聊
|
||||
# 校验关键字
|
||||
match_prefix = check_prefix(content, conf().get('group_chat_prefix'))
|
||||
match_contain = check_contain(content, conf().get('group_chat_keyword'))
|
||||
match_prefix = check_prefix(content, conf().get("group_chat_prefix"))
|
||||
match_contain = check_contain(content, conf().get("group_chat_keyword"))
|
||||
flag = False
|
||||
if match_prefix is not None or match_contain is not None:
|
||||
flag = True
|
||||
if match_prefix:
|
||||
content = content.replace(match_prefix, '', 1).strip()
|
||||
if context['msg'].is_at:
|
||||
logger.info("[WX]receive group at")
|
||||
if not conf().get("group_at_off", False):
|
||||
if context["msg"].to_user_id != context["msg"].actual_user_id:
|
||||
if match_prefix is not None or match_contain is not None:
|
||||
flag = True
|
||||
pattern = f'@{self.name}(\u2005|\u0020)'
|
||||
content = re.sub(pattern, r'', content)
|
||||
|
||||
if match_prefix:
|
||||
content = content.replace(match_prefix, "", 1).strip()
|
||||
if context["msg"].is_at:
|
||||
nick_name = context["msg"].actual_user_nickname
|
||||
if nick_name and nick_name in nick_name_black_list:
|
||||
# 黑名单过滤
|
||||
logger.warning(f"[chat_channel] Nickname {nick_name} in In BlackList, ignore")
|
||||
return None
|
||||
|
||||
logger.info("[chat_channel]receive group at")
|
||||
if not conf().get("group_at_off", False):
|
||||
flag = True
|
||||
self.name = self.name if self.name is not None else "" # 部分渠道self.name可能没有赋值
|
||||
pattern = f"@{re.escape(self.name)}(\u2005|\u0020)"
|
||||
subtract_res = re.sub(pattern, r"", content)
|
||||
if isinstance(context["msg"].at_list, list):
|
||||
for at in context["msg"].at_list:
|
||||
pattern = f"@{re.escape(at)}(\u2005|\u0020)"
|
||||
subtract_res = re.sub(pattern, r"", subtract_res)
|
||||
if subtract_res == content and context["msg"].self_display_name:
|
||||
# 前缀移除后没有变化,使用群昵称再次移除
|
||||
pattern = f"@{re.escape(context['msg'].self_display_name)}(\u2005|\u0020)"
|
||||
subtract_res = re.sub(pattern, r"", content)
|
||||
content = subtract_res
|
||||
if not flag:
|
||||
if context["origin_ctype"] == ContextType.VOICE:
|
||||
logger.info("[WX]receive group voice, but checkprefix didn't match")
|
||||
logger.info("[chat_channel]receive group voice, but checkprefix didn't match")
|
||||
return None
|
||||
else: # 单聊
|
||||
match_prefix = check_prefix(content, conf().get('single_chat_prefix',['']))
|
||||
if match_prefix is not None: # 判断如果匹配到自定义前缀,则返回过滤掉前缀+空格后的内容
|
||||
content = content.replace(match_prefix, '', 1).strip()
|
||||
elif context["origin_ctype"] == ContextType.VOICE: # 如果源消息是私聊的语音消息,允许不匹配前缀,放宽条件
|
||||
else: # 单聊
|
||||
nick_name = context["msg"].from_user_nickname
|
||||
if nick_name and nick_name in nick_name_black_list:
|
||||
# 黑名单过滤
|
||||
logger.warning(f"[chat_channel] Nickname '{nick_name}' in In BlackList, ignore")
|
||||
return None
|
||||
|
||||
match_prefix = check_prefix(content, conf().get("single_chat_prefix", [""]))
|
||||
if match_prefix is not None: # 判断如果匹配到自定义前缀,则返回过滤掉前缀+空格后的内容
|
||||
content = content.replace(match_prefix, "", 1).strip()
|
||||
elif context["origin_ctype"] == ContextType.VOICE: # 如果源消息是私聊的语音消息,允许不匹配前缀,放宽条件
|
||||
pass
|
||||
else:
|
||||
return None
|
||||
|
||||
img_match_prefix = check_prefix(content, conf().get('image_create_prefix'))
|
||||
logger.info("[chat_channel]receive single chat msg, but checkprefix didn't match")
|
||||
return None
|
||||
content = content.strip()
|
||||
img_match_prefix = check_prefix(content, conf().get("image_create_prefix",[""]))
|
||||
if img_match_prefix:
|
||||
content = content.replace(img_match_prefix, '', 1).strip()
|
||||
content = content.replace(img_match_prefix, "", 1)
|
||||
context.type = ContextType.IMAGE_CREATE
|
||||
else:
|
||||
context.type = ContextType.TEXT
|
||||
context.content = content
|
||||
if 'desire_rtype' not in context and conf().get('always_reply_voice') and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE:
|
||||
context['desire_rtype'] = ReplyType.VOICE
|
||||
elif context.type == ContextType.VOICE:
|
||||
if 'desire_rtype' not in context and conf().get('voice_reply_voice') and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE:
|
||||
context['desire_rtype'] = ReplyType.VOICE
|
||||
|
||||
context.content = content.strip()
|
||||
if "desire_rtype" not in context and conf().get("always_reply_voice") and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE:
|
||||
context["desire_rtype"] = ReplyType.VOICE
|
||||
elif context.type == ContextType.VOICE:
|
||||
if "desire_rtype" not in context and conf().get("voice_reply_voice") and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE:
|
||||
context["desire_rtype"] = ReplyType.VOICE
|
||||
return context
|
||||
|
||||
def _handle(self, context: Context):
|
||||
if context is None or not context.content:
|
||||
return
|
||||
logger.debug('[WX] ready to handle context: {}'.format(context))
|
||||
logger.debug("[chat_channel] handling context: {}".format(context))
|
||||
# reply的构建步骤
|
||||
reply = self._generate_reply(context)
|
||||
|
||||
logger.debug('[WX] ready to decorate reply: {}'.format(reply))
|
||||
# reply的包装步骤
|
||||
reply = self._decorate_reply(context, reply)
|
||||
logger.debug("[chat_channel] decorating reply: {}".format(reply))
|
||||
|
||||
# reply的发送步骤
|
||||
self._send_reply(context, reply)
|
||||
# reply的包装步骤
|
||||
if reply and reply.content:
|
||||
reply = self._decorate_reply(context, reply)
|
||||
|
||||
# reply的发送步骤
|
||||
self._send_reply(context, reply)
|
||||
|
||||
def _generate_reply(self, context: Context, reply: Reply = Reply()) -> Reply:
|
||||
e_context = PluginManager().emit_event(EventContext(Event.ON_HANDLE_CONTEXT, {
|
||||
'channel': self, 'context': context, 'reply': reply}))
|
||||
reply = e_context['reply']
|
||||
e_context = PluginManager().emit_event(
|
||||
EventContext(
|
||||
Event.ON_HANDLE_CONTEXT,
|
||||
{"channel": self, "context": context, "reply": reply},
|
||||
)
|
||||
)
|
||||
reply = e_context["reply"]
|
||||
if not e_context.is_pass():
|
||||
logger.debug('[WX] ready to handle context: type={}, content={}'.format(context.type, context.content))
|
||||
logger.debug("[chat_channel] type={}, content={}".format(context.type, context.content))
|
||||
if context.type == ContextType.TEXT or context.type == ContextType.IMAGE_CREATE: # 文字和图片消息
|
||||
context["channel"] = e_context["channel"]
|
||||
reply = super().build_reply_content(context.content, context)
|
||||
elif context.type == ContextType.VOICE: # 语音消息
|
||||
cmsg = context['msg']
|
||||
cmsg = context["msg"]
|
||||
cmsg.prepare()
|
||||
file_path = context.content
|
||||
wav_path = os.path.splitext(file_path)[0] + '.wav'
|
||||
wav_path = os.path.splitext(file_path)[0] + ".wav"
|
||||
try:
|
||||
any_to_wav(file_path, wav_path)
|
||||
any_to_wav(file_path, wav_path)
|
||||
except Exception as e: # 转换失败,直接使用mp3,对于某些api,mp3也可以识别
|
||||
logger.warning("[WX]any to wav error, use raw path. " + str(e))
|
||||
logger.warning("[chat_channel]any to wav error, use raw path. " + str(e))
|
||||
wav_path = file_path
|
||||
# 语音识别
|
||||
reply = super().build_voice_to_text(wav_path)
|
||||
@@ -161,32 +222,41 @@ class ChatChannel(Channel):
|
||||
os.remove(wav_path)
|
||||
except Exception as e:
|
||||
pass
|
||||
# logger.warning("[WX]delete temp file error: " + str(e))
|
||||
# logger.warning("[chat_channel]delete temp file error: " + str(e))
|
||||
|
||||
if reply.type == ReplyType.TEXT:
|
||||
new_context = self._compose_context(
|
||||
ContextType.TEXT, reply.content, **context.kwargs)
|
||||
new_context = self._compose_context(ContextType.TEXT, reply.content, **context.kwargs)
|
||||
if new_context:
|
||||
reply = self._generate_reply(new_context)
|
||||
else:
|
||||
return
|
||||
elif context.type == ContextType.IMAGE: # 图片消息,当前无默认逻辑
|
||||
elif context.type == ContextType.IMAGE: # 图片消息,当前仅做下载保存到本地的逻辑
|
||||
memory.USER_IMAGE_CACHE[context["session_id"]] = {
|
||||
"path": context.content,
|
||||
"msg": context.get("msg")
|
||||
}
|
||||
elif context.type == ContextType.SHARING: # 分享信息,当前无默认逻辑
|
||||
pass
|
||||
elif context.type == ContextType.FUNCTION or context.type == ContextType.FILE: # 文件消息及函数调用等,当前无默认逻辑
|
||||
pass
|
||||
else:
|
||||
logger.error('[WX] unknown context type: {}'.format(context.type))
|
||||
logger.warning("[chat_channel] unknown context type: {}".format(context.type))
|
||||
return
|
||||
return reply
|
||||
|
||||
def _decorate_reply(self, context: Context, reply: Reply) -> Reply:
|
||||
if reply and reply.type:
|
||||
e_context = PluginManager().emit_event(EventContext(Event.ON_DECORATE_REPLY, {
|
||||
'channel': self, 'context': context, 'reply': reply}))
|
||||
reply = e_context['reply']
|
||||
desire_rtype = context.get('desire_rtype')
|
||||
e_context = PluginManager().emit_event(
|
||||
EventContext(
|
||||
Event.ON_DECORATE_REPLY,
|
||||
{"channel": self, "context": context, "reply": reply},
|
||||
)
|
||||
)
|
||||
reply = e_context["reply"]
|
||||
desire_rtype = context.get("desire_rtype")
|
||||
if not e_context.is_pass() and reply and reply.type:
|
||||
|
||||
if reply.type in self.NOT_SUPPORT_REPLYTYPE:
|
||||
logger.error("[WX]reply type not support: " + str(reply.type))
|
||||
logger.error("[chat_channel]reply type not support: " + str(reply.type))
|
||||
reply.type = ReplyType.ERROR
|
||||
reply.content = "不支持发送的消息类型: " + str(reply.type)
|
||||
|
||||
@@ -196,67 +266,174 @@ class ChatChannel(Channel):
|
||||
reply = super().build_text_to_voice(reply.content)
|
||||
return self._decorate_reply(context, reply)
|
||||
if context.get("isgroup", False):
|
||||
reply_text = '@' + context['msg'].actual_user_nickname + ' ' + reply_text.strip()
|
||||
reply_text = conf().get("group_chat_reply_prefix", "") + reply_text
|
||||
if not context.get("no_need_at", False):
|
||||
reply_text = "@" + context["msg"].actual_user_nickname + "\n" + reply_text.strip()
|
||||
reply_text = conf().get("group_chat_reply_prefix", "") + reply_text + conf().get("group_chat_reply_suffix", "")
|
||||
else:
|
||||
reply_text = conf().get("single_chat_reply_prefix", "") + reply_text
|
||||
reply_text = conf().get("single_chat_reply_prefix", "") + reply_text + conf().get("single_chat_reply_suffix", "")
|
||||
reply.content = reply_text
|
||||
elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO:
|
||||
reply.content = "["+str(reply.type)+"]\n" + reply.content
|
||||
elif reply.type == ReplyType.IMAGE_URL or reply.type == ReplyType.VOICE or reply.type == ReplyType.IMAGE:
|
||||
reply.content = "[" + str(reply.type) + "]\n" + reply.content
|
||||
elif reply.type == ReplyType.IMAGE_URL or reply.type == ReplyType.VOICE or reply.type == ReplyType.IMAGE or reply.type == ReplyType.FILE or reply.type == ReplyType.VIDEO or reply.type == ReplyType.VIDEO_URL:
|
||||
pass
|
||||
else:
|
||||
logger.error('[WX] unknown reply type: {}'.format(reply.type))
|
||||
logger.error("[chat_channel] unknown reply type: {}".format(reply.type))
|
||||
return
|
||||
if desire_rtype and desire_rtype != reply.type and reply.type not in [ReplyType.ERROR, ReplyType.INFO]:
|
||||
logger.warning('[WX] desire_rtype: {}, but reply type: {}'.format(context.get('desire_rtype'), reply.type))
|
||||
logger.warning("[chat_channel] desire_rtype: {}, but reply type: {}".format(context.get("desire_rtype"), reply.type))
|
||||
return reply
|
||||
|
||||
def _send_reply(self, context: Context, reply: Reply):
|
||||
if reply and reply.type:
|
||||
e_context = PluginManager().emit_event(EventContext(Event.ON_SEND_REPLY, {
|
||||
'channel': self, 'context': context, 'reply': reply}))
|
||||
reply = e_context['reply']
|
||||
e_context = PluginManager().emit_event(
|
||||
EventContext(
|
||||
Event.ON_SEND_REPLY,
|
||||
{"channel": self, "context": context, "reply": reply},
|
||||
)
|
||||
)
|
||||
reply = e_context["reply"]
|
||||
if not e_context.is_pass() and reply and reply.type:
|
||||
logger.debug('[WX] ready to send reply: {}, context: {}'.format(reply, context))
|
||||
logger.debug("[chat_channel] sending reply: {}, context: {}".format(reply, context))
|
||||
|
||||
# 如果是文本回复,尝试提取并发送图片
|
||||
if reply.type == ReplyType.TEXT:
|
||||
self._extract_and_send_images(reply, context)
|
||||
# 如果是图片回复但带有文本内容,先发文本再发图片
|
||||
elif reply.type == ReplyType.IMAGE_URL and hasattr(reply, 'text_content') and reply.text_content:
|
||||
# 先发送文本
|
||||
text_reply = Reply(ReplyType.TEXT, reply.text_content)
|
||||
self._send(text_reply, context)
|
||||
# 短暂延迟后发送图片
|
||||
time.sleep(0.3)
|
||||
self._send(reply, context)
|
||||
else:
|
||||
self._send(reply, context)
|
||||
|
||||
def _extract_and_send_images(self, reply: Reply, context: Context):
|
||||
"""
|
||||
从文本回复中提取图片/视频URL并单独发送
|
||||
支持格式:[图片: /path/to/image.png], [视频: /path/to/video.mp4], , <img src="url">
|
||||
最多发送5个媒体文件
|
||||
"""
|
||||
content = reply.content
|
||||
media_items = [] # [(url, type), ...]
|
||||
|
||||
# 正则提取各种格式的媒体URL
|
||||
patterns = [
|
||||
(r'\[图片:\s*([^\]]+)\]', 'image'), # [图片: /path/to/image.png]
|
||||
(r'\[视频:\s*([^\]]+)\]', 'video'), # [视频: /path/to/video.mp4]
|
||||
(r'!\[.*?\]\(([^\)]+)\)', 'image'), #  - 默认图片
|
||||
(r'<img[^>]+src=["\']([^"\']+)["\']', 'image'), # <img src="url">
|
||||
(r'<video[^>]+src=["\']([^"\']+)["\']', 'video'), # <video src="url">
|
||||
(r'https?://[^\s]+\.(?:jpg|jpeg|png|gif|webp)', 'image'), # 直接的图片URL
|
||||
(r'https?://[^\s]+\.(?:mp4|avi|mov|wmv|flv)', 'video'), # 直接的视频URL
|
||||
]
|
||||
|
||||
for pattern, media_type in patterns:
|
||||
matches = re.findall(pattern, content, re.IGNORECASE)
|
||||
for match in matches:
|
||||
media_items.append((match, media_type))
|
||||
|
||||
# 去重(保持顺序)并限制最多5个
|
||||
seen = set()
|
||||
unique_items = []
|
||||
for url, mtype in media_items:
|
||||
if url not in seen:
|
||||
seen.add(url)
|
||||
unique_items.append((url, mtype))
|
||||
media_items = unique_items[:5]
|
||||
|
||||
if media_items:
|
||||
logger.info(f"[chat_channel] Extracted {len(media_items)} media item(s) from reply")
|
||||
|
||||
# 先发送文本(保持原文本不变)
|
||||
logger.info(f"[chat_channel] Sending text content before media: {reply.content[:100]}...")
|
||||
self._send(reply, context)
|
||||
logger.info(f"[chat_channel] Text sent, now sending {len(media_items)} media item(s)")
|
||||
|
||||
# 然后逐个发送媒体文件
|
||||
for i, (url, media_type) in enumerate(media_items):
|
||||
try:
|
||||
# 判断是本地文件还是URL
|
||||
if url.startswith(('http://', 'https://')):
|
||||
# 网络资源
|
||||
if media_type == 'video':
|
||||
# 视频使用 FILE 类型发送
|
||||
media_reply = Reply(ReplyType.FILE, url)
|
||||
media_reply.file_name = os.path.basename(url)
|
||||
else:
|
||||
# 图片使用 IMAGE_URL 类型
|
||||
media_reply = Reply(ReplyType.IMAGE_URL, url)
|
||||
elif os.path.exists(url):
|
||||
# 本地文件
|
||||
if media_type == 'video':
|
||||
# 视频使用 FILE 类型,转换为 file:// URL
|
||||
media_reply = Reply(ReplyType.FILE, f"file://{url}")
|
||||
media_reply.file_name = os.path.basename(url)
|
||||
else:
|
||||
# 图片使用 IMAGE_URL 类型,转换为 file:// URL
|
||||
media_reply = Reply(ReplyType.IMAGE_URL, f"file://{url}")
|
||||
else:
|
||||
logger.warning(f"[chat_channel] Media file not found or invalid URL: {url}")
|
||||
continue
|
||||
|
||||
# 发送媒体文件(添加小延迟避免频率限制)
|
||||
if i > 0:
|
||||
time.sleep(0.5)
|
||||
self._send(media_reply, context)
|
||||
logger.info(f"[chat_channel] Sent {media_type} {i+1}/{len(media_items)}: {url[:50]}...")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[chat_channel] Failed to send {media_type} {url}: {e}")
|
||||
else:
|
||||
# 没有媒体文件,正常发送文本
|
||||
self._send(reply, context)
|
||||
|
||||
def _send(self, reply: Reply, context: Context, retry_cnt = 0):
|
||||
def _send(self, reply: Reply, context: Context, retry_cnt=0):
|
||||
try:
|
||||
self.send(reply, context)
|
||||
except Exception as e:
|
||||
logger.error('[WX] sendMsg error: {}'.format(str(e)))
|
||||
logger.error("[chat_channel] sendMsg error: {}".format(str(e)))
|
||||
if isinstance(e, NotImplementedError):
|
||||
return
|
||||
logger.exception(e)
|
||||
if retry_cnt < 2:
|
||||
time.sleep(3+3*retry_cnt)
|
||||
self._send(reply, context, retry_cnt+1)
|
||||
time.sleep(3 + 3 * retry_cnt)
|
||||
self._send(reply, context, retry_cnt + 1)
|
||||
|
||||
def _fail_callback(self, session_id, exception, **kwargs): # 线程异常结束时的回调函数
|
||||
def _success_callback(self, session_id, **kwargs): # 线程正常结束时的回调函数
|
||||
logger.debug("Worker return success, session_id = {}".format(session_id))
|
||||
|
||||
def _fail_callback(self, session_id, exception, **kwargs): # 线程异常结束时的回调函数
|
||||
logger.exception("Worker return exception: {}".format(exception))
|
||||
|
||||
def _thread_pool_callback(self, session_id, **kwargs):
|
||||
def func(worker:Future):
|
||||
def func(worker: Future):
|
||||
try:
|
||||
worker_exception = worker.exception()
|
||||
if worker_exception:
|
||||
self._fail_callback(session_id, exception = worker_exception, **kwargs)
|
||||
self._fail_callback(session_id, exception=worker_exception, **kwargs)
|
||||
else:
|
||||
self._success_callback(session_id, **kwargs)
|
||||
except CancelledError as e:
|
||||
logger.info("Worker cancelled, session_id = {}".format(session_id))
|
||||
except Exception as e:
|
||||
logger.exception("Worker raise exception: {}".format(e))
|
||||
with self.lock:
|
||||
self.sessions[session_id][1].release()
|
||||
|
||||
return func
|
||||
|
||||
def produce(self, context: Context):
|
||||
session_id = context['session_id']
|
||||
session_id = context["session_id"]
|
||||
with self.lock:
|
||||
if session_id not in self.sessions:
|
||||
self.sessions[session_id] = [Dequeue(), threading.BoundedSemaphore(conf().get("concurrency_in_session", 4))]
|
||||
if context.type == ContextType.TEXT and context.content.startswith("#"):
|
||||
self.sessions[session_id][0].putleft(context) # 优先处理管理命令
|
||||
self.sessions[session_id] = [
|
||||
Dequeue(),
|
||||
threading.BoundedSemaphore(conf().get("concurrency_in_session", 4)),
|
||||
]
|
||||
if context.type == ContextType.TEXT and context.content.startswith("#"):
|
||||
self.sessions[session_id][0].putleft(context) # 优先处理管理命令
|
||||
else:
|
||||
self.sessions[session_id][0].put(context)
|
||||
|
||||
@@ -265,46 +442,49 @@ class ChatChannel(Channel):
|
||||
while True:
|
||||
with self.lock:
|
||||
session_ids = list(self.sessions.keys())
|
||||
for session_id in session_ids:
|
||||
for session_id in session_ids:
|
||||
with self.lock:
|
||||
context_queue, semaphore = self.sessions[session_id]
|
||||
if semaphore.acquire(blocking = False): # 等线程处理完毕才能删除
|
||||
if not context_queue.empty():
|
||||
context = context_queue.get()
|
||||
logger.debug("[WX] consume context: {}".format(context))
|
||||
future:Future = self.handler_pool.submit(self._handle, context)
|
||||
future.add_done_callback(self._thread_pool_callback(session_id, context = context))
|
||||
if semaphore.acquire(blocking=False): # 等线程处理完毕才能删除
|
||||
if not context_queue.empty():
|
||||
context = context_queue.get()
|
||||
logger.debug("[chat_channel] consume context: {}".format(context))
|
||||
future: Future = handler_pool.submit(self._handle, context)
|
||||
future.add_done_callback(self._thread_pool_callback(session_id, context=context))
|
||||
with self.lock:
|
||||
if session_id not in self.futures:
|
||||
self.futures[session_id] = []
|
||||
self.futures[session_id].append(future)
|
||||
elif semaphore._initial_value == semaphore._value+1: # 除了当前,没有任务再申请到信号量,说明所有任务都处理完毕
|
||||
elif semaphore._initial_value == semaphore._value + 1: # 除了当前,没有任务再申请到信号量,说明所有任务都处理完毕
|
||||
with self.lock:
|
||||
self.futures[session_id] = [t for t in self.futures[session_id] if not t.done()]
|
||||
assert len(self.futures[session_id]) == 0, "thread pool error"
|
||||
del self.sessions[session_id]
|
||||
else:
|
||||
semaphore.release()
|
||||
time.sleep(0.1)
|
||||
else:
|
||||
semaphore.release()
|
||||
time.sleep(0.2)
|
||||
|
||||
# 取消session_id对应的所有任务,只能取消排队的消息和已提交线程池但未执行的任务
|
||||
def cancel_session(self, session_id):
|
||||
def cancel_session(self, session_id):
|
||||
with self.lock:
|
||||
if session_id in self.sessions:
|
||||
for future in self.futures[session_id]:
|
||||
future.cancel()
|
||||
cnt = self.sessions[session_id][0].qsize()
|
||||
if cnt>0:
|
||||
if cnt > 0:
|
||||
logger.info("Cancel {} messages in session {}".format(cnt, session_id))
|
||||
self.sessions[session_id][0] = Dequeue()
|
||||
|
||||
|
||||
def cancel_all_session(self):
|
||||
with self.lock:
|
||||
for session_id in self.sessions:
|
||||
for future in self.futures[session_id]:
|
||||
future.cancel()
|
||||
cnt = self.sessions[session_id][0].qsize()
|
||||
if cnt>0:
|
||||
if cnt > 0:
|
||||
logger.info("Cancel {} messages in session {}".format(cnt, session_id))
|
||||
self.sessions[session_id][0] = Dequeue()
|
||||
|
||||
|
||||
|
||||
def check_prefix(content, prefix_list):
|
||||
if not prefix_list:
|
||||
@@ -314,6 +494,7 @@ def check_prefix(content, prefix_list):
|
||||
return prefix
|
||||
return None
|
||||
|
||||
|
||||
def check_contain(content, keyword_list):
|
||||
if not keyword_list:
|
||||
return None
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
|
||||
"""
|
||||
"""
|
||||
本类表示聊天消息,用于对itchat和wechaty的消息进行统一的封装。
|
||||
|
||||
填好必填项(群聊6个,非群聊8个),即可接入ChatChannel,并支持插件,参考TerminalChannel
|
||||
@@ -20,45 +19,47 @@ other_user_id: 对方的id,如果你是发送者,那这个就是接收者id
|
||||
other_user_nickname: 同上
|
||||
|
||||
is_group: 是否是群消息 (群聊必填)
|
||||
is_at: 是否被at
|
||||
is_at: 是否被at
|
||||
|
||||
- (群消息时,一般会存在实际发送者,是群内某个成员的id和昵称,下列项仅在群消息时存在)
|
||||
actual_user_id: 实际发送者id (群聊必填)
|
||||
actual_user_nickname:实际发送者昵称
|
||||
|
||||
|
||||
|
||||
self_display_name: 自身的展示名,设置群昵称时,该字段表示群昵称
|
||||
|
||||
_prepare_fn: 准备函数,用于准备消息的内容,比如下载图片等,
|
||||
_prepared: 是否已经调用过准备函数
|
||||
_rawmsg: 原始消息对象
|
||||
|
||||
"""
|
||||
|
||||
|
||||
class ChatMessage(object):
|
||||
msg_id = None
|
||||
create_time = None
|
||||
|
||||
|
||||
ctype = None
|
||||
content = None
|
||||
|
||||
|
||||
from_user_id = None
|
||||
from_user_nickname = None
|
||||
to_user_id = None
|
||||
to_user_nickname = None
|
||||
other_user_id = None
|
||||
other_user_nickname = None
|
||||
|
||||
my_msg = False
|
||||
self_display_name = None
|
||||
|
||||
is_group = False
|
||||
is_at = False
|
||||
actual_user_id = None
|
||||
actual_user_nickname = None
|
||||
at_list = None
|
||||
|
||||
_prepare_fn = None
|
||||
_prepared = False
|
||||
_rawmsg = None
|
||||
|
||||
|
||||
def __init__(self,_rawmsg):
|
||||
def __init__(self, _rawmsg):
|
||||
self._rawmsg = _rawmsg
|
||||
|
||||
def prepare(self):
|
||||
@@ -67,7 +68,7 @@ class ChatMessage(object):
|
||||
self._prepare_fn()
|
||||
|
||||
def __str__(self):
|
||||
return 'ChatMessage: id={}, create_time={}, ctype={}, content={}, from_user_id={}, from_user_nickname={}, to_user_id={}, to_user_nickname={}, other_user_id={}, other_user_nickname={}, is_group={}, is_at={}, actual_user_id={}, actual_user_nickname={}'.format(
|
||||
return "ChatMessage: id={}, create_time={}, ctype={}, content={}, from_user_id={}, from_user_nickname={}, to_user_id={}, to_user_nickname={}, other_user_id={}, other_user_nickname={}, is_group={}, is_at={}, actual_user_id={}, actual_user_nickname={}, at_list={}".format(
|
||||
self.msg_id,
|
||||
self.create_time,
|
||||
self.ctype,
|
||||
@@ -82,4 +83,5 @@ class ChatMessage(object):
|
||||
self.is_at,
|
||||
self.actual_user_id,
|
||||
self.actual_user_nickname,
|
||||
)
|
||||
self.at_list
|
||||
)
|
||||
|
||||
895
channel/dingtalk/dingtalk_channel.py
Normal file
895
channel/dingtalk/dingtalk_channel.py
Normal file
@@ -0,0 +1,895 @@
|
||||
"""
|
||||
钉钉通道接入
|
||||
|
||||
@author huiwen
|
||||
@Date 2023/11/28
|
||||
"""
|
||||
import copy
|
||||
import json
|
||||
# -*- coding=utf-8 -*-
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
import requests
|
||||
|
||||
import dingtalk_stream
|
||||
from dingtalk_stream import AckMessage
|
||||
from dingtalk_stream.card_replier import AICardReplier
|
||||
from dingtalk_stream.card_replier import AICardStatus
|
||||
from dingtalk_stream.card_replier import CardReplier
|
||||
|
||||
from bridge.context import Context, ContextType
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from channel.chat_channel import ChatChannel
|
||||
from common.utils import expand_path
|
||||
from channel.dingtalk.dingtalk_message import DingTalkMessage
|
||||
from common.expired_dict import ExpiredDict
|
||||
from common.log import logger
|
||||
from common.singleton import singleton
|
||||
from common.time_check import time_checker
|
||||
from config import conf
|
||||
|
||||
|
||||
class CustomAICardReplier(CardReplier):
|
||||
def __init__(self, dingtalk_client, incoming_message):
|
||||
super(AICardReplier, self).__init__(dingtalk_client, incoming_message)
|
||||
|
||||
def start(
|
||||
self,
|
||||
card_template_id: str,
|
||||
card_data: dict,
|
||||
recipients: list = None,
|
||||
support_forward: bool = True,
|
||||
) -> str:
|
||||
"""
|
||||
AI卡片的创建接口
|
||||
:param support_forward:
|
||||
:param recipients:
|
||||
:param card_template_id:
|
||||
:param card_data:
|
||||
:return:
|
||||
"""
|
||||
card_data_with_status = copy.deepcopy(card_data)
|
||||
card_data_with_status["flowStatus"] = AICardStatus.PROCESSING
|
||||
return self.create_and_send_card(
|
||||
card_template_id,
|
||||
card_data_with_status,
|
||||
at_sender=True,
|
||||
at_all=False,
|
||||
recipients=recipients,
|
||||
support_forward=support_forward,
|
||||
)
|
||||
|
||||
|
||||
# 对 AICardReplier 进行猴子补丁
|
||||
AICardReplier.start = CustomAICardReplier.start
|
||||
|
||||
|
||||
def _check(func):
|
||||
def wrapper(self, cmsg: DingTalkMessage):
|
||||
msgId = cmsg.msg_id
|
||||
if msgId in self.receivedMsgs:
|
||||
logger.info("DingTalk message {} already received, ignore".format(msgId))
|
||||
return
|
||||
self.receivedMsgs[msgId] = True
|
||||
create_time = cmsg.create_time # 消息时间戳
|
||||
if conf().get("hot_reload") == True and int(create_time) < int(time.time()) - 60: # 跳过1分钟前的历史消息
|
||||
logger.debug("[DingTalk] History message {} skipped".format(msgId))
|
||||
return
|
||||
if cmsg.my_msg and not cmsg.is_group:
|
||||
logger.debug("[DingTalk] My message {} skipped".format(msgId))
|
||||
return
|
||||
return func(self, cmsg)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
@singleton
|
||||
class DingTalkChanel(ChatChannel, dingtalk_stream.ChatbotHandler):
|
||||
dingtalk_client_id = conf().get('dingtalk_client_id')
|
||||
dingtalk_client_secret = conf().get('dingtalk_client_secret')
|
||||
|
||||
def setup_logger(self):
|
||||
# Suppress verbose logs from dingtalk_stream SDK
|
||||
logging.getLogger("dingtalk_stream").setLevel(logging.WARNING)
|
||||
return logging.getLogger("DingTalk")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
super(dingtalk_stream.ChatbotHandler, self).__init__()
|
||||
self.logger = self.setup_logger()
|
||||
# 历史消息id暂存,用于幂等控制
|
||||
self.receivedMsgs = ExpiredDict(conf().get("expires_in_seconds", 3600))
|
||||
self._stream_client = None
|
||||
self._running = False
|
||||
self._event_loop = None
|
||||
logger.debug("[DingTalk] client_id={}, client_secret={} ".format(
|
||||
self.dingtalk_client_id, self.dingtalk_client_secret))
|
||||
# 无需群校验和前缀
|
||||
conf()["group_name_white_list"] = ["ALL_GROUP"]
|
||||
# 单聊无需前缀
|
||||
conf()["single_chat_prefix"] = [""]
|
||||
# Access token cache
|
||||
self._access_token = None
|
||||
self._access_token_expires_at = 0
|
||||
# Robot code cache (extracted from incoming messages)
|
||||
self._robot_code = None
|
||||
|
||||
def startup(self):
|
||||
import asyncio
|
||||
self.dingtalk_client_id = conf().get('dingtalk_client_id')
|
||||
self.dingtalk_client_secret = conf().get('dingtalk_client_secret')
|
||||
self._running = True
|
||||
credential = dingtalk_stream.Credential(self.dingtalk_client_id, self.dingtalk_client_secret)
|
||||
client = dingtalk_stream.DingTalkStreamClient(credential)
|
||||
self._stream_client = client
|
||||
client.register_callback_handler(dingtalk_stream.chatbot.ChatbotMessage.TOPIC, self)
|
||||
logger.info("[DingTalk] ✅ Stream client initialized, ready to receive messages")
|
||||
_first_connect = True
|
||||
while self._running:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
self._event_loop = loop
|
||||
try:
|
||||
if not _first_connect:
|
||||
logger.info("[DingTalk] Reconnecting...")
|
||||
_first_connect = False
|
||||
loop.run_until_complete(client.start())
|
||||
except (KeyboardInterrupt, SystemExit):
|
||||
logger.info("[DingTalk] Startup loop received stop signal, exiting")
|
||||
break
|
||||
except Exception as e:
|
||||
if not self._running:
|
||||
break
|
||||
logger.warning(f"[DingTalk] Stream connection error: {e}, reconnecting in 3s...")
|
||||
time.sleep(3)
|
||||
finally:
|
||||
self._event_loop = None
|
||||
try:
|
||||
loop.close()
|
||||
except Exception:
|
||||
pass
|
||||
logger.info("[DingTalk] Startup loop exited")
|
||||
|
||||
def stop(self):
|
||||
import asyncio
|
||||
logger.info("[DingTalk] stop() called, setting _running=False")
|
||||
self._running = False
|
||||
loop = self._event_loop
|
||||
if loop and not loop.is_closed():
|
||||
try:
|
||||
loop.call_soon_threadsafe(loop.stop)
|
||||
logger.info("[DingTalk] Sent stop signal to event loop")
|
||||
except Exception as e:
|
||||
logger.warning(f"[DingTalk] Error stopping event loop: {e}")
|
||||
self._stream_client = None
|
||||
logger.info("[DingTalk] stop() completed")
|
||||
|
||||
def get_access_token(self):
|
||||
"""
|
||||
获取企业内部应用的 access_token
|
||||
文档: https://open.dingtalk.com/document/orgapp/obtain-orgapp-token
|
||||
"""
|
||||
current_time = time.time()
|
||||
|
||||
# 如果 token 还没过期,直接返回缓存的 token
|
||||
if self._access_token and current_time < self._access_token_expires_at:
|
||||
return self._access_token
|
||||
|
||||
# 获取新的 access_token
|
||||
url = "https://api.dingtalk.com/v1.0/oauth2/accessToken"
|
||||
headers = {"Content-Type": "application/json"}
|
||||
data = {
|
||||
"appKey": self.dingtalk_client_id,
|
||||
"appSecret": self.dingtalk_client_secret
|
||||
}
|
||||
|
||||
try:
|
||||
response = requests.post(url, headers=headers, json=data, timeout=10)
|
||||
result = response.json()
|
||||
|
||||
if response.status_code == 200 and "accessToken" in result:
|
||||
self._access_token = result["accessToken"]
|
||||
# Token 有效期为 2 小时,提前 5 分钟刷新
|
||||
self._access_token_expires_at = current_time + result.get("expireIn", 7200) - 300
|
||||
logger.info("[DingTalk] Access token refreshed successfully")
|
||||
return self._access_token
|
||||
else:
|
||||
logger.error(f"[DingTalk] Failed to get access token: {result}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"[DingTalk] Error getting access token: {e}")
|
||||
return None
|
||||
|
||||
def send_single_message(self, user_id: str, content: str, robot_code: str) -> bool:
|
||||
"""
|
||||
Send message to single user (private chat)
|
||||
API: https://open.dingtalk.com/document/orgapp/chatbots-send-one-on-one-chat-messages-in-batches
|
||||
"""
|
||||
access_token = self.get_access_token()
|
||||
if not access_token:
|
||||
logger.error("[DingTalk] Failed to send single message: Access token not available.")
|
||||
return False
|
||||
|
||||
if not robot_code:
|
||||
logger.error("[DingTalk] Cannot send single message: robot_code is required")
|
||||
return False
|
||||
|
||||
url = "https://api.dingtalk.com/v1.0/robot/oToMessages/batchSend"
|
||||
headers = {
|
||||
"x-acs-dingtalk-access-token": access_token,
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
data = {
|
||||
"msgParam": json.dumps({"content": content}),
|
||||
"msgKey": "sampleText",
|
||||
"userIds": [user_id],
|
||||
"robotCode": robot_code
|
||||
}
|
||||
|
||||
logger.info(f"[DingTalk] Sending single message to user {user_id} with robot_code {robot_code}")
|
||||
try:
|
||||
response = requests.post(url, headers=headers, json=data, timeout=10)
|
||||
result = response.json()
|
||||
|
||||
if response.status_code == 200 and result.get("processQueryKey"):
|
||||
logger.info(f"[DingTalk] Single message sent successfully to {user_id}")
|
||||
return True
|
||||
else:
|
||||
logger.error(f"[DingTalk] Failed to send single message: {result}")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"[DingTalk] Error sending single message: {e}")
|
||||
return False
|
||||
|
||||
def send_group_message(self, conversation_id: str, content: str, robot_code: str = None):
|
||||
"""
|
||||
主动发送群消息
|
||||
文档: https://open.dingtalk.com/document/orgapp/the-robot-sends-a-group-message
|
||||
|
||||
Args:
|
||||
conversation_id: 会话ID (openConversationId)
|
||||
content: 消息内容
|
||||
robot_code: 机器人编码,默认使用 dingtalk_client_id
|
||||
"""
|
||||
access_token = self.get_access_token()
|
||||
if not access_token:
|
||||
logger.error("[DingTalk] Cannot send group message: no access token")
|
||||
return False
|
||||
|
||||
# Validate robot_code
|
||||
if not robot_code:
|
||||
logger.error("[DingTalk] Cannot send group message: robot_code is required")
|
||||
return False
|
||||
|
||||
url = "https://api.dingtalk.com/v1.0/robot/groupMessages/send"
|
||||
headers = {
|
||||
"x-acs-dingtalk-access-token": access_token,
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
data = {
|
||||
"msgParam": json.dumps({"content": content}),
|
||||
"msgKey": "sampleText",
|
||||
"openConversationId": conversation_id,
|
||||
"robotCode": robot_code
|
||||
}
|
||||
|
||||
try:
|
||||
response = requests.post(url, headers=headers, json=data, timeout=10)
|
||||
result = response.json()
|
||||
|
||||
if response.status_code == 200:
|
||||
logger.info(f"[DingTalk] Group message sent successfully to {conversation_id}")
|
||||
return True
|
||||
else:
|
||||
logger.error(f"[DingTalk] Failed to send group message: {result}")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"[DingTalk] Error sending group message: {e}")
|
||||
return False
|
||||
|
||||
def upload_media(self, file_path: str, media_type: str = "image") -> str:
|
||||
"""
|
||||
上传媒体文件到钉钉
|
||||
|
||||
Args:
|
||||
file_path: 本地文件路径或URL
|
||||
media_type: 媒体类型 (image, video, voice, file)
|
||||
|
||||
Returns:
|
||||
media_id,如果上传失败返回 None
|
||||
"""
|
||||
access_token = self.get_access_token()
|
||||
if not access_token:
|
||||
logger.error("[DingTalk] Cannot upload media: no access token")
|
||||
return None
|
||||
|
||||
# 处理 file:// URL
|
||||
if file_path.startswith("file://"):
|
||||
file_path = file_path[7:]
|
||||
|
||||
# 如果是 HTTP URL,先下载
|
||||
if file_path.startswith("http://") or file_path.startswith("https://"):
|
||||
try:
|
||||
import uuid
|
||||
response = requests.get(file_path, timeout=(5, 60))
|
||||
if response.status_code != 200:
|
||||
logger.error(f"[DingTalk] Failed to download file from URL: {file_path}")
|
||||
return None
|
||||
|
||||
# 保存到临时文件
|
||||
file_name = os.path.basename(file_path) or f"media_{uuid.uuid4()}"
|
||||
workspace_root = expand_path(conf().get("agent_workspace", "~/cow"))
|
||||
tmp_dir = os.path.join(workspace_root, "tmp")
|
||||
os.makedirs(tmp_dir, exist_ok=True)
|
||||
temp_file = os.path.join(tmp_dir, file_name)
|
||||
|
||||
with open(temp_file, "wb") as f:
|
||||
f.write(response.content)
|
||||
|
||||
file_path = temp_file
|
||||
logger.info(f"[DingTalk] Downloaded file to {file_path}")
|
||||
except Exception as e:
|
||||
logger.error(f"[DingTalk] Error downloading file: {e}")
|
||||
return None
|
||||
|
||||
if not os.path.exists(file_path):
|
||||
logger.error(f"[DingTalk] File not found: {file_path}")
|
||||
return None
|
||||
|
||||
# 上传到钉钉
|
||||
# 钉钉上传媒体文件 API: https://open.dingtalk.com/document/orgapp/upload-media-files
|
||||
url = "https://oapi.dingtalk.com/media/upload"
|
||||
params = {
|
||||
"access_token": access_token,
|
||||
"type": media_type
|
||||
}
|
||||
|
||||
try:
|
||||
with open(file_path, "rb") as f:
|
||||
files = {"media": (os.path.basename(file_path), f)}
|
||||
response = requests.post(url, params=params, files=files, timeout=(5, 60))
|
||||
result = response.json()
|
||||
|
||||
if result.get("errcode") == 0:
|
||||
media_id = result.get("media_id")
|
||||
logger.info(f"[DingTalk] Media uploaded successfully, media_id={media_id}")
|
||||
return media_id
|
||||
else:
|
||||
logger.error(f"[DingTalk] Failed to upload media: {result}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"[DingTalk] Error uploading media: {e}")
|
||||
return None
|
||||
|
||||
def send_image_with_media_id(self, access_token: str, media_id: str, incoming_message, is_group: bool) -> bool:
|
||||
"""
|
||||
发送图片消息(使用 media_id)
|
||||
|
||||
Args:
|
||||
access_token: 访问令牌
|
||||
media_id: 媒体ID
|
||||
incoming_message: 钉钉消息对象
|
||||
is_group: 是否为群聊
|
||||
|
||||
Returns:
|
||||
是否发送成功
|
||||
"""
|
||||
headers = {
|
||||
"x-acs-dingtalk-access-token": access_token,
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
|
||||
msg_param = {
|
||||
"photoURL": media_id # 钉钉图片消息使用 photoURL 字段
|
||||
}
|
||||
|
||||
body = {
|
||||
"robotCode": incoming_message.robot_code,
|
||||
"msgKey": "sampleImageMsg",
|
||||
"msgParam": json.dumps(msg_param),
|
||||
}
|
||||
|
||||
if is_group:
|
||||
# 群聊
|
||||
url = "https://api.dingtalk.com/v1.0/robot/groupMessages/send"
|
||||
body["openConversationId"] = incoming_message.conversation_id
|
||||
else:
|
||||
# 单聊
|
||||
url = "https://api.dingtalk.com/v1.0/robot/oToMessages/batchSend"
|
||||
body["userIds"] = [incoming_message.sender_staff_id]
|
||||
|
||||
try:
|
||||
response = requests.post(url=url, headers=headers, json=body, timeout=10)
|
||||
result = response.json()
|
||||
|
||||
logger.info(f"[DingTalk] Image send result: {response.text}")
|
||||
|
||||
if response.status_code == 200:
|
||||
return True
|
||||
else:
|
||||
logger.error(f"[DingTalk] Send image error: {response.text}")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"[DingTalk] Send image exception: {e}")
|
||||
return False
|
||||
|
||||
def send_image_message(self, receiver: str, media_id: str, is_group: bool, robot_code: str) -> bool:
|
||||
"""
|
||||
发送图片消息
|
||||
|
||||
Args:
|
||||
receiver: 接收者ID (user_id 或 conversation_id)
|
||||
media_id: 媒体ID
|
||||
is_group: 是否为群聊
|
||||
robot_code: 机器人编码
|
||||
|
||||
Returns:
|
||||
是否发送成功
|
||||
"""
|
||||
access_token = self.get_access_token()
|
||||
if not access_token:
|
||||
logger.error("[DingTalk] Cannot send image: no access token")
|
||||
return False
|
||||
|
||||
if not robot_code:
|
||||
logger.error("[DingTalk] Cannot send image: robot_code is required")
|
||||
return False
|
||||
|
||||
if is_group:
|
||||
# 发送群聊图片
|
||||
url = "https://api.dingtalk.com/v1.0/robot/groupMessages/send"
|
||||
headers = {
|
||||
"x-acs-dingtalk-access-token": access_token,
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
data = {
|
||||
"msgParam": json.dumps({"mediaId": media_id}),
|
||||
"msgKey": "sampleImageMsg",
|
||||
"openConversationId": receiver,
|
||||
"robotCode": robot_code
|
||||
}
|
||||
else:
|
||||
# 发送单聊图片
|
||||
url = "https://api.dingtalk.com/v1.0/robot/oToMessages/batchSend"
|
||||
headers = {
|
||||
"x-acs-dingtalk-access-token": access_token,
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
data = {
|
||||
"msgParam": json.dumps({"mediaId": media_id}),
|
||||
"msgKey": "sampleImageMsg",
|
||||
"userIds": [receiver],
|
||||
"robotCode": robot_code
|
||||
}
|
||||
|
||||
try:
|
||||
response = requests.post(url, headers=headers, json=data, timeout=10)
|
||||
result = response.json()
|
||||
|
||||
if response.status_code == 200:
|
||||
logger.info(f"[DingTalk] Image message sent successfully")
|
||||
return True
|
||||
else:
|
||||
logger.error(f"[DingTalk] Failed to send image message: {result}")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"[DingTalk] Error sending image message: {e}")
|
||||
return False
|
||||
|
||||
def get_image_download_url(self, download_code: str) -> str:
|
||||
"""
|
||||
获取图片下载地址
|
||||
返回一个特殊的 URL 格式:dingtalk://download/{robot_code}:{download_code}
|
||||
后续会在 download_image_file 中使用新版 API 下载
|
||||
"""
|
||||
# 获取 robot_code
|
||||
if not hasattr(self, '_robot_code_cache'):
|
||||
self._robot_code_cache = None
|
||||
|
||||
robot_code = self._robot_code_cache
|
||||
|
||||
if not robot_code:
|
||||
logger.error("[DingTalk] robot_code not available for image download")
|
||||
return None
|
||||
|
||||
# 返回一个特殊的 URL,包含 robot_code 和 download_code
|
||||
logger.info(f"[DingTalk] Successfully got image download URL for code: {download_code}")
|
||||
return f"dingtalk://download/{robot_code}:{download_code}"
|
||||
|
||||
async def process(self, callback: dingtalk_stream.CallbackMessage):
|
||||
try:
|
||||
incoming_message = dingtalk_stream.ChatbotMessage.from_dict(callback.data)
|
||||
|
||||
# 缓存 robot_code,用于后续图片下载
|
||||
if hasattr(incoming_message, 'robot_code'):
|
||||
self._robot_code_cache = incoming_message.robot_code
|
||||
|
||||
# Filter out stale messages from before channel startup (offline backlog)
|
||||
create_at = getattr(incoming_message, 'create_at', None)
|
||||
if create_at:
|
||||
msg_age_s = time.time() - int(create_at) / 1000
|
||||
if msg_age_s > 60:
|
||||
logger.warning(f"[DingTalk] stale msg filtered (age={msg_age_s:.0f}s), "
|
||||
f"msg_id={getattr(incoming_message, 'message_id', 'N/A')}")
|
||||
return AckMessage.STATUS_OK, 'OK'
|
||||
|
||||
image_download_handler = self
|
||||
dingtalk_msg = DingTalkMessage(incoming_message, image_download_handler)
|
||||
|
||||
if dingtalk_msg.is_group:
|
||||
self.handle_group(dingtalk_msg)
|
||||
else:
|
||||
self.handle_single(dingtalk_msg)
|
||||
return AckMessage.STATUS_OK, 'OK'
|
||||
except Exception as e:
|
||||
logger.error(f"[DingTalk] process error: {e}", exc_info=True)
|
||||
return AckMessage.STATUS_SYSTEM_EXCEPTION, 'ERROR'
|
||||
|
||||
@time_checker
|
||||
@_check
|
||||
def handle_single(self, cmsg: DingTalkMessage):
|
||||
# 处理单聊消息
|
||||
if cmsg.ctype == ContextType.VOICE:
|
||||
logger.debug("[DingTalk]receive voice msg: {}".format(cmsg.content))
|
||||
elif cmsg.ctype == ContextType.IMAGE:
|
||||
logger.debug("[DingTalk]receive image msg: {}".format(cmsg.content))
|
||||
elif cmsg.ctype == ContextType.IMAGE_CREATE:
|
||||
logger.debug("[DingTalk]receive image create msg: {}".format(cmsg.content))
|
||||
elif cmsg.ctype == ContextType.PATPAT:
|
||||
logger.debug("[DingTalk]receive patpat msg: {}".format(cmsg.content))
|
||||
elif cmsg.ctype == ContextType.TEXT:
|
||||
logger.debug("[DingTalk]receive text msg: {}".format(cmsg.content))
|
||||
else:
|
||||
logger.debug("[DingTalk]receive other msg: {}".format(cmsg.content))
|
||||
|
||||
# 处理文件缓存逻辑
|
||||
from channel.file_cache import get_file_cache
|
||||
file_cache = get_file_cache()
|
||||
|
||||
# 单聊的 session_id 就是 sender_id
|
||||
session_id = cmsg.from_user_id
|
||||
|
||||
# 如果是单张图片消息,缓存起来
|
||||
if cmsg.ctype == ContextType.IMAGE:
|
||||
if hasattr(cmsg, 'image_path') and cmsg.image_path:
|
||||
file_cache.add(session_id, cmsg.image_path, file_type='image')
|
||||
logger.info(f"[DingTalk] Image cached for session {session_id}, waiting for user query...")
|
||||
# 单张图片不直接处理,等待用户提问
|
||||
return
|
||||
|
||||
# 如果是文本消息,检查是否有缓存的文件
|
||||
if cmsg.ctype == ContextType.TEXT:
|
||||
cached_files = file_cache.get(session_id)
|
||||
if cached_files:
|
||||
# 将缓存的文件附加到文本消息中
|
||||
file_refs = []
|
||||
for file_info in cached_files:
|
||||
file_path = file_info['path']
|
||||
file_type = file_info['type']
|
||||
if file_type == 'image':
|
||||
file_refs.append(f"[图片: {file_path}]")
|
||||
elif file_type == 'video':
|
||||
file_refs.append(f"[视频: {file_path}]")
|
||||
else:
|
||||
file_refs.append(f"[文件: {file_path}]")
|
||||
|
||||
cmsg.content = cmsg.content + "\n" + "\n".join(file_refs)
|
||||
logger.info(f"[DingTalk] Attached {len(cached_files)} cached file(s) to user query")
|
||||
# 清除缓存
|
||||
file_cache.clear(session_id)
|
||||
|
||||
context = self._compose_context(cmsg.ctype, cmsg.content, isgroup=False, msg=cmsg)
|
||||
if context:
|
||||
self.produce(context)
|
||||
|
||||
|
||||
@time_checker
|
||||
@_check
|
||||
def handle_group(self, cmsg: DingTalkMessage):
|
||||
# 处理群聊消息
|
||||
if cmsg.ctype == ContextType.VOICE:
|
||||
logger.debug("[DingTalk]receive voice msg: {}".format(cmsg.content))
|
||||
elif cmsg.ctype == ContextType.IMAGE:
|
||||
logger.debug("[DingTalk]receive image msg: {}".format(cmsg.content))
|
||||
elif cmsg.ctype == ContextType.IMAGE_CREATE:
|
||||
logger.debug("[DingTalk]receive image create msg: {}".format(cmsg.content))
|
||||
elif cmsg.ctype == ContextType.PATPAT:
|
||||
logger.debug("[DingTalk]receive patpat msg: {}".format(cmsg.content))
|
||||
elif cmsg.ctype == ContextType.TEXT:
|
||||
logger.debug("[DingTalk]receive text msg: {}".format(cmsg.content))
|
||||
else:
|
||||
logger.debug("[DingTalk]receive other msg: {}".format(cmsg.content))
|
||||
|
||||
# 处理文件缓存逻辑
|
||||
from channel.file_cache import get_file_cache
|
||||
file_cache = get_file_cache()
|
||||
|
||||
# 群聊的 session_id
|
||||
if conf().get("group_shared_session", True):
|
||||
session_id = cmsg.other_user_id # conversation_id
|
||||
else:
|
||||
session_id = cmsg.from_user_id + "_" + cmsg.other_user_id
|
||||
|
||||
# 如果是单张图片消息,缓存起来
|
||||
if cmsg.ctype == ContextType.IMAGE:
|
||||
if hasattr(cmsg, 'image_path') and cmsg.image_path:
|
||||
file_cache.add(session_id, cmsg.image_path, file_type='image')
|
||||
logger.info(f"[DingTalk] Image cached for session {session_id}, waiting for user query...")
|
||||
# 单张图片不直接处理,等待用户提问
|
||||
return
|
||||
|
||||
# 如果是文本消息,检查是否有缓存的文件
|
||||
if cmsg.ctype == ContextType.TEXT:
|
||||
cached_files = file_cache.get(session_id)
|
||||
if cached_files:
|
||||
# 将缓存的文件附加到文本消息中
|
||||
file_refs = []
|
||||
for file_info in cached_files:
|
||||
file_path = file_info['path']
|
||||
file_type = file_info['type']
|
||||
if file_type == 'image':
|
||||
file_refs.append(f"[图片: {file_path}]")
|
||||
elif file_type == 'video':
|
||||
file_refs.append(f"[视频: {file_path}]")
|
||||
else:
|
||||
file_refs.append(f"[文件: {file_path}]")
|
||||
|
||||
cmsg.content = cmsg.content + "\n" + "\n".join(file_refs)
|
||||
logger.info(f"[DingTalk] Attached {len(cached_files)} cached file(s) to user query")
|
||||
# 清除缓存
|
||||
file_cache.clear(session_id)
|
||||
|
||||
context = self._compose_context(cmsg.ctype, cmsg.content, isgroup=True, msg=cmsg)
|
||||
context['no_need_at'] = True
|
||||
if context:
|
||||
self.produce(context)
|
||||
|
||||
|
||||
def send(self, reply: Reply, context: Context):
|
||||
logger.debug(f"[DingTalk] send() called with reply.type={reply.type}, content_length={len(str(reply.content))}")
|
||||
receiver = context["receiver"]
|
||||
|
||||
# Check if msg exists (for scheduled tasks, msg might be None)
|
||||
msg = context.kwargs.get('msg')
|
||||
if msg is None:
|
||||
# 定时任务场景:使用主动发送 API
|
||||
is_group = context.get("isgroup", False)
|
||||
logger.info(f"[DingTalk] Sending scheduled task message to {receiver} (is_group={is_group})")
|
||||
|
||||
# 使用缓存的 robot_code 或配置的值
|
||||
robot_code = self._robot_code or conf().get("dingtalk_robot_code")
|
||||
logger.info(f"[DingTalk] Using robot_code: {robot_code}, cached: {self._robot_code}, config: {conf().get('dingtalk_robot_code')}")
|
||||
|
||||
if not robot_code:
|
||||
logger.error(f"[DingTalk] Cannot send scheduled task: robot_code not available. Please send at least one message to the bot first, or configure dingtalk_robot_code in config.json")
|
||||
return
|
||||
|
||||
# 根据是否群聊选择不同的 API
|
||||
if is_group:
|
||||
success = self.send_group_message(receiver, reply.content, robot_code)
|
||||
else:
|
||||
# 单聊场景:尝试从 context 中获取 dingtalk_sender_staff_id
|
||||
sender_staff_id = context.get("dingtalk_sender_staff_id")
|
||||
if not sender_staff_id:
|
||||
logger.error(f"[DingTalk] Cannot send single chat scheduled message: sender_staff_id not available in context")
|
||||
return
|
||||
|
||||
logger.info(f"[DingTalk] Sending single message to staff_id: {sender_staff_id}")
|
||||
success = self.send_single_message(sender_staff_id, reply.content, robot_code)
|
||||
|
||||
if not success:
|
||||
logger.error(f"[DingTalk] Failed to send scheduled task message")
|
||||
return
|
||||
|
||||
# 从正常消息中提取并缓存 robot_code
|
||||
if hasattr(msg, 'robot_code'):
|
||||
robot_code = msg.robot_code
|
||||
if robot_code and robot_code != self._robot_code:
|
||||
self._robot_code = robot_code
|
||||
logger.debug(f"[DingTalk] Cached robot_code: {robot_code}")
|
||||
|
||||
isgroup = msg.is_group
|
||||
incoming_message = msg.incoming_message
|
||||
robot_code = self._robot_code or conf().get("dingtalk_robot_code")
|
||||
|
||||
# 处理图片和视频发送
|
||||
if reply.type == ReplyType.IMAGE_URL:
|
||||
logger.info(f"[DingTalk] Sending image: {reply.content}")
|
||||
|
||||
# 如果有附加的文本内容,先发送文本
|
||||
if hasattr(reply, 'text_content') and reply.text_content:
|
||||
self.reply_text(reply.text_content, incoming_message)
|
||||
import time
|
||||
time.sleep(0.3) # 短暂延迟,确保文本先到达
|
||||
|
||||
media_id = self.upload_media(reply.content, media_type="image")
|
||||
if media_id:
|
||||
# 使用主动发送 API 发送图片
|
||||
access_token = self.get_access_token()
|
||||
if access_token:
|
||||
success = self.send_image_with_media_id(
|
||||
access_token,
|
||||
media_id,
|
||||
incoming_message,
|
||||
isgroup
|
||||
)
|
||||
if not success:
|
||||
logger.error("[DingTalk] Failed to send image message")
|
||||
self.reply_text("抱歉,图片发送失败", incoming_message)
|
||||
else:
|
||||
logger.error("[DingTalk] Cannot get access token")
|
||||
self.reply_text("抱歉,图片发送失败(无法获取token)", incoming_message)
|
||||
else:
|
||||
logger.error("[DingTalk] Failed to upload image")
|
||||
self.reply_text("抱歉,图片上传失败", incoming_message)
|
||||
return
|
||||
|
||||
elif reply.type == ReplyType.FILE:
|
||||
# 如果有附加的文本内容,先发送文本
|
||||
if hasattr(reply, 'text_content') and reply.text_content:
|
||||
self.reply_text(reply.text_content, incoming_message)
|
||||
import time
|
||||
time.sleep(0.3) # 短暂延迟,确保文本先到达
|
||||
|
||||
# 判断是否为视频文件
|
||||
file_path = reply.content
|
||||
if file_path.startswith("file://"):
|
||||
file_path = file_path[7:]
|
||||
|
||||
is_video = file_path.lower().endswith(('.mp4', '.avi', '.mov', '.wmv', '.flv'))
|
||||
|
||||
access_token = self.get_access_token()
|
||||
if not access_token:
|
||||
logger.error("[DingTalk] Cannot get access token")
|
||||
self.reply_text("抱歉,文件发送失败(无法获取token)", incoming_message)
|
||||
return
|
||||
|
||||
if is_video:
|
||||
logger.info(f"[DingTalk] Sending video: {reply.content}")
|
||||
media_id = self.upload_media(reply.content, media_type="video")
|
||||
if media_id:
|
||||
# 发送视频消息
|
||||
msg_param = {
|
||||
"duration": "30", # TODO: 获取实际视频时长
|
||||
"videoMediaId": media_id,
|
||||
"videoType": "mp4",
|
||||
"height": "400",
|
||||
"width": "600",
|
||||
}
|
||||
success = self._send_file_message(
|
||||
access_token,
|
||||
incoming_message,
|
||||
"sampleVideo",
|
||||
msg_param,
|
||||
isgroup
|
||||
)
|
||||
if not success:
|
||||
self.reply_text("抱歉,视频发送失败", incoming_message)
|
||||
else:
|
||||
logger.error("[DingTalk] Failed to upload video")
|
||||
self.reply_text("抱歉,视频上传失败", incoming_message)
|
||||
else:
|
||||
# 其他文件类型
|
||||
logger.info(f"[DingTalk] Sending file: {reply.content}")
|
||||
media_id = self.upload_media(reply.content, media_type="file")
|
||||
if media_id:
|
||||
file_name = os.path.basename(file_path)
|
||||
file_base, file_extension = os.path.splitext(file_name)
|
||||
msg_param = {
|
||||
"mediaId": media_id,
|
||||
"fileName": file_name,
|
||||
"fileType": file_extension[1:] if file_extension else "file"
|
||||
}
|
||||
success = self._send_file_message(
|
||||
access_token,
|
||||
incoming_message,
|
||||
"sampleFile",
|
||||
msg_param,
|
||||
isgroup
|
||||
)
|
||||
if not success:
|
||||
self.reply_text("抱歉,文件发送失败", incoming_message)
|
||||
else:
|
||||
logger.error("[DingTalk] Failed to upload file")
|
||||
self.reply_text("抱歉,文件上传失败", incoming_message)
|
||||
return
|
||||
|
||||
# 处理文本消息
|
||||
elif reply.type == ReplyType.TEXT:
|
||||
logger.info(f"[DingTalk] Sending text message, length={len(reply.content)}")
|
||||
if conf().get("dingtalk_card_enabled"):
|
||||
logger.info("[Dingtalk] sendMsg={}, receiver={}".format(reply, receiver))
|
||||
def reply_with_text():
|
||||
self.reply_text(reply.content, incoming_message)
|
||||
def reply_with_at_text():
|
||||
self.reply_text("📢 您有一条新的消息,请查看。", incoming_message)
|
||||
def reply_with_ai_markdown():
|
||||
button_list, markdown_content = self.generate_button_markdown_content(context, reply)
|
||||
self.reply_ai_markdown_button(incoming_message, markdown_content, button_list, "", "📌 内容由AI生成", "",[incoming_message.sender_staff_id])
|
||||
|
||||
if reply.type in [ReplyType.IMAGE_URL, ReplyType.IMAGE, ReplyType.TEXT]:
|
||||
if isgroup:
|
||||
reply_with_ai_markdown()
|
||||
reply_with_at_text()
|
||||
else:
|
||||
reply_with_ai_markdown()
|
||||
else:
|
||||
# 暂不支持其它类型消息回复
|
||||
reply_with_text()
|
||||
else:
|
||||
self.reply_text(reply.content, incoming_message)
|
||||
return
|
||||
|
||||
def _send_file_message(self, access_token: str, incoming_message, msg_key: str, msg_param: dict, is_group: bool) -> bool:
|
||||
"""
|
||||
发送文件/视频消息的通用方法
|
||||
|
||||
Args:
|
||||
access_token: 访问令牌
|
||||
incoming_message: 钉钉消息对象
|
||||
msg_key: 消息类型 (sampleFile, sampleVideo, sampleAudio)
|
||||
msg_param: 消息参数
|
||||
is_group: 是否为群聊
|
||||
|
||||
Returns:
|
||||
是否发送成功
|
||||
"""
|
||||
headers = {
|
||||
"x-acs-dingtalk-access-token": access_token,
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
|
||||
body = {
|
||||
"robotCode": incoming_message.robot_code,
|
||||
"msgKey": msg_key,
|
||||
"msgParam": json.dumps(msg_param),
|
||||
}
|
||||
|
||||
if is_group:
|
||||
# 群聊
|
||||
url = "https://api.dingtalk.com/v1.0/robot/groupMessages/send"
|
||||
body["openConversationId"] = incoming_message.conversation_id
|
||||
else:
|
||||
# 单聊
|
||||
url = "https://api.dingtalk.com/v1.0/robot/oToMessages/batchSend"
|
||||
body["userIds"] = [incoming_message.sender_staff_id]
|
||||
|
||||
try:
|
||||
response = requests.post(url=url, headers=headers, json=body, timeout=10)
|
||||
result = response.json()
|
||||
|
||||
logger.info(f"[DingTalk] File send result: {response.text}")
|
||||
|
||||
if response.status_code == 200:
|
||||
return True
|
||||
else:
|
||||
logger.error(f"[DingTalk] Send file error: {response.text}")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"[DingTalk] Send file exception: {e}")
|
||||
return False
|
||||
|
||||
def generate_button_markdown_content(self, context, reply):
|
||||
image_url = context.kwargs.get("image_url")
|
||||
promptEn = context.kwargs.get("promptEn")
|
||||
reply_text = reply.content
|
||||
button_list = []
|
||||
markdown_content = f"""
|
||||
{reply.content}
|
||||
"""
|
||||
if image_url is not None and promptEn is not None:
|
||||
button_list = [
|
||||
{"text": "查看原图", "url": image_url, "iosUrl": image_url, "color": "blue"}
|
||||
]
|
||||
markdown_content = f"""
|
||||
{promptEn}
|
||||
|
||||

|
||||
|
||||
{reply_text}
|
||||
|
||||
"""
|
||||
logger.debug(f"[Dingtalk] generate_button_markdown_content, button_list={button_list} , markdown_content={markdown_content}")
|
||||
|
||||
return button_list, markdown_content
|
||||
244
channel/dingtalk/dingtalk_message.py
Normal file
244
channel/dingtalk/dingtalk_message.py
Normal file
@@ -0,0 +1,244 @@
|
||||
import os
|
||||
import re
|
||||
|
||||
import requests
|
||||
from dingtalk_stream import ChatbotMessage
|
||||
|
||||
from bridge.context import ContextType
|
||||
from channel.chat_message import ChatMessage
|
||||
# -*- coding=utf-8 -*-
|
||||
from common.log import logger
|
||||
from common.tmp_dir import TmpDir
|
||||
from common.utils import expand_path
|
||||
from config import conf
|
||||
|
||||
|
||||
class DingTalkMessage(ChatMessage):
|
||||
def __init__(self, event: ChatbotMessage, image_download_handler):
|
||||
super().__init__(event)
|
||||
self.image_download_handler = image_download_handler
|
||||
self.msg_id = event.message_id
|
||||
self.message_type = event.message_type
|
||||
self.incoming_message = event
|
||||
self.sender_staff_id = event.sender_staff_id
|
||||
self.other_user_id = event.conversation_id
|
||||
self.create_time = event.create_at
|
||||
self.image_content = event.image_content
|
||||
self.rich_text_content = event.rich_text_content
|
||||
self.robot_code = event.robot_code # 机器人编码
|
||||
if event.conversation_type == "1":
|
||||
self.is_group = False
|
||||
else:
|
||||
self.is_group = True
|
||||
|
||||
if self.message_type == "text":
|
||||
self.ctype = ContextType.TEXT
|
||||
|
||||
self.content = event.text.content.strip()
|
||||
elif self.message_type == "audio":
|
||||
# 钉钉支持直接识别语音,所以此处将直接提取文字,当文字处理
|
||||
self.content = event.extensions['content']['recognition'].strip()
|
||||
self.ctype = ContextType.TEXT
|
||||
elif (self.message_type == 'picture') or (self.message_type == 'richText'):
|
||||
# 钉钉图片类型或富文本类型消息处理
|
||||
image_list = event.get_image_list()
|
||||
|
||||
if self.message_type == 'picture' and len(image_list) > 0:
|
||||
# 单张图片消息:下载到工作空间,用于文件缓存
|
||||
self.ctype = ContextType.IMAGE
|
||||
download_code = image_list[0]
|
||||
download_url = image_download_handler.get_image_download_url(download_code)
|
||||
|
||||
# 下载到工作空间 tmp 目录
|
||||
workspace_root = expand_path(conf().get("agent_workspace", "~/cow"))
|
||||
tmp_dir = os.path.join(workspace_root, "tmp")
|
||||
os.makedirs(tmp_dir, exist_ok=True)
|
||||
|
||||
image_path = download_image_file(download_url, tmp_dir)
|
||||
if image_path:
|
||||
self.content = image_path
|
||||
self.image_path = image_path # 保存图片路径用于缓存
|
||||
logger.info(f"[DingTalk] Downloaded single image to {image_path}")
|
||||
else:
|
||||
self.content = "[图片下载失败]"
|
||||
self.image_path = None
|
||||
|
||||
elif self.message_type == 'richText' and len(image_list) > 0:
|
||||
# 富文本消息:下载所有图片并附加到文本中
|
||||
self.ctype = ContextType.TEXT
|
||||
|
||||
# 下载到工作空间 tmp 目录
|
||||
workspace_root = expand_path(conf().get("agent_workspace", "~/cow"))
|
||||
tmp_dir = os.path.join(workspace_root, "tmp")
|
||||
os.makedirs(tmp_dir, exist_ok=True)
|
||||
|
||||
# 提取富文本中的文本内容
|
||||
text_content = ""
|
||||
if self.rich_text_content:
|
||||
# rich_text_content 是一个 RichTextContent 对象,需要从中提取文本
|
||||
text_list = event.get_text_list()
|
||||
if text_list:
|
||||
text_content = "".join(text_list).strip()
|
||||
|
||||
# 下载所有图片
|
||||
image_paths = []
|
||||
for download_code in image_list:
|
||||
download_url = image_download_handler.get_image_download_url(download_code)
|
||||
image_path = download_image_file(download_url, tmp_dir)
|
||||
if image_path:
|
||||
image_paths.append(image_path)
|
||||
|
||||
# 构建消息内容:文本 + 图片路径
|
||||
content_parts = []
|
||||
if text_content:
|
||||
content_parts.append(text_content)
|
||||
for img_path in image_paths:
|
||||
content_parts.append(f"[图片: {img_path}]")
|
||||
|
||||
self.content = "\n".join(content_parts) if content_parts else "[富文本消息]"
|
||||
logger.info(f"[DingTalk] Received richText with {len(image_paths)} image(s): {self.content}")
|
||||
else:
|
||||
self.ctype = ContextType.IMAGE
|
||||
self.content = "[未找到图片]"
|
||||
logger.debug(f"[DingTalk] messageType: {self.message_type}, imageList isEmpty")
|
||||
|
||||
if self.is_group:
|
||||
self.from_user_id = event.conversation_id
|
||||
self.actual_user_id = event.sender_id
|
||||
self.is_at = True
|
||||
else:
|
||||
self.from_user_id = event.sender_id
|
||||
self.actual_user_id = event.sender_id
|
||||
self.to_user_id = event.chatbot_user_id
|
||||
self.other_user_nickname = event.conversation_title
|
||||
|
||||
|
||||
def download_image_file(image_url, temp_dir):
|
||||
"""
|
||||
下载图片文件
|
||||
支持两种方式:
|
||||
1. 普通 HTTP(S) URL
|
||||
2. 钉钉 downloadCode: dingtalk://download/{download_code}
|
||||
"""
|
||||
# 检查临时目录是否存在,如果不存在则创建
|
||||
if not os.path.exists(temp_dir):
|
||||
os.makedirs(temp_dir)
|
||||
|
||||
# 处理钉钉 downloadCode
|
||||
if image_url.startswith("dingtalk://download/"):
|
||||
download_code = image_url.replace("dingtalk://download/", "")
|
||||
logger.info(f"[DingTalk] Downloading image with downloadCode: {download_code[:20]}...")
|
||||
|
||||
# 需要从外部传入 access_token,这里先用一个临时方案
|
||||
# 从 config 获取 dingtalk_client_id 和 dingtalk_client_secret
|
||||
from config import conf
|
||||
client_id = conf().get("dingtalk_client_id")
|
||||
client_secret = conf().get("dingtalk_client_secret")
|
||||
|
||||
if not client_id or not client_secret:
|
||||
logger.error("[DingTalk] Missing dingtalk_client_id or dingtalk_client_secret")
|
||||
return None
|
||||
|
||||
# 解析 robot_code 和 download_code
|
||||
parts = download_code.split(":", 1)
|
||||
if len(parts) != 2:
|
||||
logger.error(f"[DingTalk] Invalid download_code format (expected robot_code:download_code): {download_code[:50]}")
|
||||
return None
|
||||
|
||||
robot_code, actual_download_code = parts
|
||||
|
||||
# 获取 access_token(使用新版 API)
|
||||
token_url = "https://api.dingtalk.com/v1.0/oauth2/accessToken"
|
||||
token_headers = {
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
token_body = {
|
||||
"appKey": client_id,
|
||||
"appSecret": client_secret
|
||||
}
|
||||
|
||||
try:
|
||||
token_response = requests.post(token_url, json=token_body, headers=token_headers, timeout=10)
|
||||
|
||||
if token_response.status_code == 200:
|
||||
token_data = token_response.json()
|
||||
access_token = token_data.get("accessToken")
|
||||
|
||||
if not access_token:
|
||||
logger.error(f"[DingTalk] Failed to get access token: {token_data}")
|
||||
return None
|
||||
|
||||
# 获取下载 URL(使用新版 API)
|
||||
download_api_url = "https://api.dingtalk.com/v1.0/robot/messageFiles/download"
|
||||
download_headers = {
|
||||
"x-acs-dingtalk-access-token": access_token,
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
download_body = {
|
||||
"downloadCode": actual_download_code,
|
||||
"robotCode": robot_code
|
||||
}
|
||||
|
||||
download_response = requests.post(download_api_url, json=download_body, headers=download_headers, timeout=10)
|
||||
|
||||
if download_response.status_code == 200:
|
||||
download_data = download_response.json()
|
||||
download_url = download_data.get("downloadUrl")
|
||||
|
||||
if not download_url:
|
||||
logger.error(f"[DingTalk] No downloadUrl in response: {download_data}")
|
||||
return None
|
||||
|
||||
# 从 downloadUrl 下载实际图片
|
||||
image_response = requests.get(download_url, stream=True, timeout=60)
|
||||
|
||||
if image_response.status_code == 200:
|
||||
# 生成文件名(使用 download_code 的 hash,避免特殊字符)
|
||||
import hashlib
|
||||
file_hash = hashlib.md5(actual_download_code.encode()).hexdigest()[:16]
|
||||
file_name = f"{file_hash}.png"
|
||||
file_path = os.path.join(temp_dir, file_name)
|
||||
|
||||
with open(file_path, 'wb') as file:
|
||||
file.write(image_response.content)
|
||||
|
||||
logger.info(f"[DingTalk] Image downloaded successfully: {file_path}")
|
||||
return file_path
|
||||
else:
|
||||
logger.error(f"[DingTalk] Failed to download image from URL: {image_response.status_code}")
|
||||
return None
|
||||
else:
|
||||
logger.error(f"[DingTalk] Failed to get download URL: {download_response.status_code}, {download_response.text}")
|
||||
return None
|
||||
else:
|
||||
logger.error(f"[DingTalk] Failed to get access token: {token_response.status_code}, {token_response.text}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"[DingTalk] Exception downloading image: {e}")
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
return None
|
||||
|
||||
# 普通 HTTP(S) URL
|
||||
else:
|
||||
headers = {
|
||||
'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/112.0.0.0 Safari/537.36'
|
||||
}
|
||||
|
||||
try:
|
||||
response = requests.get(image_url, headers=headers, stream=True, timeout=60 * 5)
|
||||
if response.status_code == 200:
|
||||
# 生成文件名
|
||||
file_name = image_url.split("/")[-1].split("?")[0]
|
||||
|
||||
# 将文件保存到临时目录
|
||||
file_path = os.path.join(temp_dir, file_name)
|
||||
with open(file_path, 'wb') as file:
|
||||
file.write(response.content)
|
||||
return file_path
|
||||
else:
|
||||
logger.info(f"[Dingtalk] Failed to download image file, {response.content}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"[Dingtalk] Exception downloading image: {e}")
|
||||
return None
|
||||
184
channel/feishu/README.md
Normal file
184
channel/feishu/README.md
Normal file
@@ -0,0 +1,184 @@
|
||||
# 飞书Channel使用说明
|
||||
|
||||
飞书Channel支持两种事件接收模式,可以根据部署环境灵活选择。
|
||||
|
||||
## 模式对比
|
||||
|
||||
| 模式 | 适用场景 | 优点 | 缺点 |
|
||||
|------|---------|------|------|
|
||||
| **webhook** | 生产环境 | 稳定可靠,官方推荐 | 需要公网IP或域名 |
|
||||
| **websocket** | 本地开发 | 无需公网IP,开发便捷 | 需要额外依赖 |
|
||||
|
||||
## 配置说明
|
||||
|
||||
### 基础配置
|
||||
|
||||
在 `config.json` 中添加以下配置:
|
||||
|
||||
```json
|
||||
{
|
||||
"channel_type": "feishu",
|
||||
"feishu_app_id": "cli_xxxxx",
|
||||
"feishu_app_secret": "your_app_secret",
|
||||
"feishu_token": "your_verification_token",
|
||||
"feishu_bot_name": "你的机器人名称",
|
||||
"feishu_event_mode": "webhook",
|
||||
"feishu_port": 9891
|
||||
}
|
||||
```
|
||||
|
||||
### 配置项说明
|
||||
|
||||
- `feishu_app_id`: 飞书应用的App ID
|
||||
- `feishu_app_secret`: 飞书应用的App Secret
|
||||
- `feishu_token`: 事件订阅的Verification Token
|
||||
- `feishu_bot_name`: 机器人名称(用于群聊@判断)
|
||||
- `feishu_event_mode`: 事件接收模式,可选值:
|
||||
- `"websocket"`: 长连接模式(默认)
|
||||
- `"webhook"`: HTTP服务器模式
|
||||
- `feishu_port`: webhook模式下的HTTP服务端口(默认9891)
|
||||
|
||||
## 模式一: Webhook模式(推荐生产环境)
|
||||
|
||||
### 1. 配置
|
||||
|
||||
```json
|
||||
{
|
||||
"feishu_event_mode": "webhook",
|
||||
"feishu_port": 9891
|
||||
}
|
||||
```
|
||||
|
||||
### 2. 启动服务
|
||||
|
||||
```bash
|
||||
python3 app.py
|
||||
```
|
||||
|
||||
服务将在 `http://0.0.0.0:9891` 启动。
|
||||
|
||||
### 3. 配置飞书应用
|
||||
|
||||
1. 登录[飞书开放平台](https://open.feishu.cn/)
|
||||
2. 进入应用详情 -> 事件订阅
|
||||
3. 选择 **将事件发送至开发者服务器**
|
||||
4. 填写请求地址: `http://your-domain:9891/`
|
||||
5. 添加事件: `im.message.receive_v1` (接收消息v2.0)
|
||||
6. 保存配置
|
||||
|
||||
### 4. 注意事项
|
||||
|
||||
- 需要有公网IP或域名
|
||||
- 确保防火墙开放对应端口
|
||||
- 建议使用HTTPS(需要配置反向代理)
|
||||
|
||||
## 模式二: WebSocket模式(推荐本地开发)
|
||||
|
||||
### 1. 安装依赖
|
||||
|
||||
```bash
|
||||
pip install lark-oapi
|
||||
```
|
||||
|
||||
### 2. 配置
|
||||
|
||||
```json
|
||||
{
|
||||
"feishu_event_mode": "websocket"
|
||||
}
|
||||
```
|
||||
|
||||
### 3. 启动服务
|
||||
|
||||
```bash
|
||||
python3 app.py
|
||||
```
|
||||
|
||||
程序将自动建立与飞书开放平台的长连接。
|
||||
|
||||
### 4. 配置飞书应用
|
||||
|
||||
1. 登录[飞书开放平台](https://open.feishu.cn/)
|
||||
2. 进入应用详情 -> 事件订阅
|
||||
3. 选择 **使用长连接接收事件**
|
||||
4. 添加事件: `im.message.receive_v1` (接收消息v2.0)
|
||||
5. 保存配置
|
||||
|
||||
### 5. 注意事项
|
||||
|
||||
- 无需公网IP
|
||||
- 需要能访问公网(建立WebSocket连接)
|
||||
- 每个应用最多50个连接
|
||||
- 集群模式下消息随机分发到一个客户端
|
||||
|
||||
## 平滑迁移
|
||||
|
||||
从webhook模式切换到websocket模式(或反向切换):
|
||||
|
||||
1. 修改 `config.json` 中的 `feishu_event_mode`
|
||||
2. 如果切换到websocket模式,安装 `lark-oapi` 依赖
|
||||
3. 重启服务
|
||||
4. 在飞书开放平台修改事件订阅方式
|
||||
|
||||
**重要**: 同一时间只能使用一种模式,否则会导致消息重复接收。
|
||||
|
||||
## 消息去重机制
|
||||
|
||||
两种模式都使用相同的消息去重机制:
|
||||
|
||||
- 使用 `ExpiredDict` 存储已处理的消息ID
|
||||
- 过期时间: 7.1小时
|
||||
- 确保消息不会重复处理
|
||||
|
||||
## 故障排查
|
||||
|
||||
### WebSocket模式连接失败
|
||||
|
||||
```
|
||||
[FeiShu] lark_oapi not installed
|
||||
```
|
||||
|
||||
**解决**: 安装依赖 `pip install lark-oapi`
|
||||
|
||||
### SSL证书验证失败
|
||||
|
||||
```
|
||||
[Lark][ERROR] connect failed, err:[SSL:CERTIFICATE_VERIFY_FAILED] certificate verify failed: self signed certificate in certificate chain
|
||||
```
|
||||
|
||||
**原因**: 网络环境中存在自签名证书或SSL中间人代理(如企业代理、VPN等)
|
||||
|
||||
**解决**: 程序会自动检测SSL证书验证失败,并自动重试禁用证书验证的连接。无需手动配置。
|
||||
|
||||
当遇到证书错误时,日志会显示:
|
||||
```
|
||||
[FeiShu] SSL certificate verification disabled due to certificate error. This may happen when using corporate proxy or self-signed certificates.
|
||||
```
|
||||
|
||||
这是正常现象,程序会自动处理并继续运行。
|
||||
|
||||
### Webhook模式端口被占用
|
||||
|
||||
```
|
||||
Address already in use
|
||||
```
|
||||
|
||||
**解决**: 修改 `feishu_port` 配置或关闭占用端口的进程
|
||||
|
||||
### 收不到消息
|
||||
|
||||
1. 检查飞书应用的事件订阅配置
|
||||
2. 确认已添加 `im.message.receive_v1` 事件
|
||||
3. 检查应用权限: 需要 `im:message` 权限
|
||||
4. 查看日志中的错误信息
|
||||
|
||||
## 开发建议
|
||||
|
||||
- **本地开发**: 使用websocket模式,快速迭代
|
||||
- **测试环境**: 可以使用webhook模式 + 内网穿透工具(如ngrok)
|
||||
- **生产环境**: 使用webhook模式,配置正式域名和HTTPS
|
||||
|
||||
## 参考文档
|
||||
|
||||
- [飞书开放平台 - 事件订阅](https://open.feishu.cn/document/ukTMukTMukTM/uUTNz4SN1MjL1UzM)
|
||||
- [飞书SDK - Python](https://github.com/larksuite/oapi-sdk-python)
|
||||
815
channel/feishu/feishu_channel.py
Normal file
815
channel/feishu/feishu_channel.py
Normal file
@@ -0,0 +1,815 @@
|
||||
"""
|
||||
飞书通道接入
|
||||
|
||||
支持两种事件接收模式:
|
||||
1. webhook模式: 通过HTTP服务器接收事件(需要公网IP)
|
||||
2. websocket模式: 通过长连接接收事件(本地开发友好)
|
||||
|
||||
通过配置项 feishu_event_mode 选择模式: "webhook" 或 "websocket"
|
||||
|
||||
@author Saboteur7
|
||||
@Date 2023/11/19
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import ssl
|
||||
import threading
|
||||
# -*- coding=utf-8 -*-
|
||||
import uuid
|
||||
|
||||
import requests
|
||||
import web
|
||||
|
||||
from bridge.context import Context
|
||||
from bridge.context import ContextType
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from channel.chat_channel import ChatChannel, check_prefix
|
||||
from channel.feishu.feishu_message import FeishuMessage
|
||||
from common import utils
|
||||
from common.expired_dict import ExpiredDict
|
||||
from common.log import logger
|
||||
from common.singleton import singleton
|
||||
from config import conf
|
||||
|
||||
# Suppress verbose logs from Lark SDK
|
||||
logging.getLogger("Lark").setLevel(logging.WARNING)
|
||||
|
||||
URL_VERIFICATION = "url_verification"
|
||||
|
||||
# 尝试导入飞书SDK,如果未安装则websocket模式不可用
|
||||
try:
|
||||
import lark_oapi as lark
|
||||
|
||||
LARK_SDK_AVAILABLE = True
|
||||
except ImportError:
|
||||
LARK_SDK_AVAILABLE = False
|
||||
logger.warning(
|
||||
"[FeiShu] lark_oapi not installed, websocket mode is not available. Install with: pip install lark-oapi")
|
||||
|
||||
|
||||
@singleton
|
||||
class FeiShuChanel(ChatChannel):
|
||||
feishu_app_id = conf().get('feishu_app_id')
|
||||
feishu_app_secret = conf().get('feishu_app_secret')
|
||||
feishu_token = conf().get('feishu_token')
|
||||
feishu_event_mode = conf().get('feishu_event_mode', 'websocket') # webhook 或 websocket
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
# 历史消息id暂存,用于幂等控制
|
||||
self.receivedMsgs = ExpiredDict(60 * 60 * 7.1)
|
||||
self._http_server = None
|
||||
self._ws_client = None
|
||||
self._ws_thread = None
|
||||
logger.debug("[FeiShu] app_id={}, app_secret={}, verification_token={}, event_mode={}".format(
|
||||
self.feishu_app_id, self.feishu_app_secret, self.feishu_token, self.feishu_event_mode))
|
||||
# 无需群校验和前缀
|
||||
conf()["group_name_white_list"] = ["ALL_GROUP"]
|
||||
conf()["single_chat_prefix"] = [""]
|
||||
|
||||
# 验证配置
|
||||
if self.feishu_event_mode == 'websocket' and not LARK_SDK_AVAILABLE:
|
||||
logger.error("[FeiShu] websocket mode requires lark_oapi. Please install: pip install lark-oapi")
|
||||
raise Exception("lark_oapi not installed")
|
||||
|
||||
def startup(self):
|
||||
self.feishu_app_id = conf().get('feishu_app_id')
|
||||
self.feishu_app_secret = conf().get('feishu_app_secret')
|
||||
self.feishu_token = conf().get('feishu_token')
|
||||
self.feishu_event_mode = conf().get('feishu_event_mode', 'websocket')
|
||||
if self.feishu_event_mode == 'websocket':
|
||||
self._startup_websocket()
|
||||
else:
|
||||
self._startup_webhook()
|
||||
|
||||
def stop(self):
|
||||
import ctypes
|
||||
logger.info("[FeiShu] stop() called")
|
||||
ws_client = self._ws_client
|
||||
self._ws_client = None
|
||||
ws_thread = self._ws_thread
|
||||
self._ws_thread = None
|
||||
# Interrupt the ws thread first so its blocking start() unblocks
|
||||
if ws_thread and ws_thread.is_alive():
|
||||
try:
|
||||
tid = ws_thread.ident
|
||||
if tid:
|
||||
res = ctypes.pythonapi.PyThreadState_SetAsyncExc(
|
||||
ctypes.c_ulong(tid), ctypes.py_object(SystemExit)
|
||||
)
|
||||
if res == 1:
|
||||
logger.info("[FeiShu] Interrupted ws thread via ctypes")
|
||||
elif res > 1:
|
||||
ctypes.pythonapi.PyThreadState_SetAsyncExc(ctypes.c_ulong(tid), None)
|
||||
except Exception as e:
|
||||
logger.warning(f"[FeiShu] Error interrupting ws thread: {e}")
|
||||
# lark.ws.Client has no stop() method; thread interruption above is sufficient
|
||||
if self._http_server:
|
||||
try:
|
||||
self._http_server.stop()
|
||||
logger.info("[FeiShu] HTTP server stopped")
|
||||
except Exception as e:
|
||||
logger.warning(f"[FeiShu] Error stopping HTTP server: {e}")
|
||||
self._http_server = None
|
||||
logger.info("[FeiShu] stop() completed")
|
||||
|
||||
def _startup_webhook(self):
|
||||
"""启动HTTP服务器接收事件(webhook模式)"""
|
||||
logger.debug("[FeiShu] Starting in webhook mode...")
|
||||
urls = (
|
||||
'/', 'channel.feishu.feishu_channel.FeishuController'
|
||||
)
|
||||
app = web.application(urls, globals(), autoreload=False)
|
||||
port = conf().get("feishu_port", 9891)
|
||||
func = web.httpserver.StaticMiddleware(app.wsgifunc())
|
||||
func = web.httpserver.LogMiddleware(func)
|
||||
server = web.httpserver.WSGIServer(("0.0.0.0", port), func)
|
||||
self._http_server = server
|
||||
try:
|
||||
server.start()
|
||||
except (KeyboardInterrupt, SystemExit):
|
||||
server.stop()
|
||||
|
||||
def _startup_websocket(self):
|
||||
"""启动长连接接收事件(websocket模式)"""
|
||||
logger.debug("[FeiShu] Starting in websocket mode...")
|
||||
|
||||
# 创建事件处理器
|
||||
def handle_message_event(data: lark.im.v1.P2ImMessageReceiveV1) -> None:
|
||||
"""处理接收消息事件 v2.0"""
|
||||
try:
|
||||
logger.debug(f"[FeiShu] websocket receive event: {lark.JSON.marshal(data, indent=2)}")
|
||||
|
||||
# 转换为标准的event格式
|
||||
event_dict = json.loads(lark.JSON.marshal(data))
|
||||
event = event_dict.get("event", {})
|
||||
|
||||
# 处理消息
|
||||
self._handle_message_event(event)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[FeiShu] websocket handle message error: {e}", exc_info=True)
|
||||
|
||||
# 构建事件分发器
|
||||
event_handler = lark.EventDispatcherHandler.builder("", "") \
|
||||
.register_p2_im_message_receive_v1(handle_message_event) \
|
||||
.build()
|
||||
|
||||
def start_client_with_retry():
|
||||
"""Run ws client in this thread with its own event loop to avoid conflicts."""
|
||||
import asyncio
|
||||
import ssl as ssl_module
|
||||
original_create_default_context = ssl_module.create_default_context
|
||||
|
||||
def create_unverified_context(*args, **kwargs):
|
||||
context = original_create_default_context(*args, **kwargs)
|
||||
context.check_hostname = False
|
||||
context.verify_mode = ssl.CERT_NONE
|
||||
return context
|
||||
|
||||
# Give this thread its own event loop so lark SDK can call run_until_complete
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
for attempt in range(2):
|
||||
try:
|
||||
if attempt == 1:
|
||||
logger.warning("[FeiShu] Retrying with SSL verification disabled...")
|
||||
ssl_module.create_default_context = create_unverified_context
|
||||
ssl_module._create_unverified_context = create_unverified_context
|
||||
|
||||
ws_client = lark.ws.Client(
|
||||
self.feishu_app_id,
|
||||
self.feishu_app_secret,
|
||||
event_handler=event_handler,
|
||||
log_level=lark.LogLevel.WARNING
|
||||
)
|
||||
self._ws_client = ws_client
|
||||
logger.debug("[FeiShu] Websocket client starting...")
|
||||
ws_client.start()
|
||||
break
|
||||
|
||||
except (SystemExit, KeyboardInterrupt):
|
||||
logger.info("[FeiShu] Websocket thread received stop signal")
|
||||
break
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
is_ssl_error = ("CERTIFICATE_VERIFY_FAILED" in error_msg
|
||||
or "certificate verify failed" in error_msg.lower())
|
||||
if is_ssl_error and attempt == 0:
|
||||
logger.warning(f"[FeiShu] SSL error: {error_msg}, retrying...")
|
||||
continue
|
||||
logger.error(f"[FeiShu] Websocket client error: {e}", exc_info=True)
|
||||
ssl_module.create_default_context = original_create_default_context
|
||||
break
|
||||
try:
|
||||
loop.close()
|
||||
except Exception:
|
||||
pass
|
||||
logger.info("[FeiShu] Websocket thread exited")
|
||||
|
||||
ws_thread = threading.Thread(target=start_client_with_retry, daemon=True)
|
||||
self._ws_thread = ws_thread
|
||||
ws_thread.start()
|
||||
logger.info("[FeiShu] ✅ Websocket thread started, ready to receive messages")
|
||||
ws_thread.join()
|
||||
|
||||
def _handle_message_event(self, event: dict):
|
||||
"""
|
||||
处理消息事件的核心逻辑
|
||||
webhook和websocket模式共用此方法
|
||||
"""
|
||||
if not event.get("message") or not event.get("sender"):
|
||||
logger.warning(f"[FeiShu] invalid message, event={event}")
|
||||
return
|
||||
|
||||
msg = event.get("message")
|
||||
|
||||
# 幂等判断
|
||||
msg_id = msg.get("message_id")
|
||||
if self.receivedMsgs.get(msg_id):
|
||||
logger.warning(f"[FeiShu] repeat msg filtered, msg_id={msg_id}")
|
||||
return
|
||||
self.receivedMsgs[msg_id] = True
|
||||
|
||||
# Filter out stale messages from before channel startup (offline backlog)
|
||||
import time as _time
|
||||
create_time_ms = msg.get("create_time")
|
||||
if create_time_ms:
|
||||
msg_age_s = _time.time() - int(create_time_ms) / 1000
|
||||
if msg_age_s > 60:
|
||||
logger.warning(f"[FeiShu] stale msg filtered (age={msg_age_s:.0f}s), msg_id={msg_id}")
|
||||
return
|
||||
|
||||
is_group = False
|
||||
chat_type = msg.get("chat_type")
|
||||
|
||||
if chat_type == "group":
|
||||
if not msg.get("mentions") and msg.get("message_type") == "text":
|
||||
# 群聊中未@不响应
|
||||
return
|
||||
if msg.get("mentions") and msg.get("mentions")[0].get("name") != conf().get("feishu_bot_name") and msg.get(
|
||||
"message_type") == "text":
|
||||
# 不是@机器人,不响应
|
||||
return
|
||||
# 群聊
|
||||
is_group = True
|
||||
receive_id_type = "chat_id"
|
||||
elif chat_type == "p2p":
|
||||
receive_id_type = "open_id"
|
||||
else:
|
||||
logger.warning("[FeiShu] message ignore")
|
||||
return
|
||||
|
||||
# 构造飞书消息对象
|
||||
feishu_msg = FeishuMessage(event, is_group=is_group, access_token=self.fetch_access_token())
|
||||
if not feishu_msg:
|
||||
return
|
||||
|
||||
# 处理文件缓存逻辑
|
||||
from channel.file_cache import get_file_cache
|
||||
file_cache = get_file_cache()
|
||||
|
||||
# 获取 session_id(用于缓存关联)
|
||||
if is_group:
|
||||
if conf().get("group_shared_session", True):
|
||||
session_id = msg.get("chat_id") # 群共享会话
|
||||
else:
|
||||
session_id = feishu_msg.from_user_id + "_" + msg.get("chat_id")
|
||||
else:
|
||||
session_id = feishu_msg.from_user_id
|
||||
|
||||
# 如果是单张图片消息,缓存起来
|
||||
if feishu_msg.ctype == ContextType.IMAGE:
|
||||
if hasattr(feishu_msg, 'image_path') and feishu_msg.image_path:
|
||||
file_cache.add(session_id, feishu_msg.image_path, file_type='image')
|
||||
logger.info(f"[FeiShu] Image cached for session {session_id}, waiting for user query...")
|
||||
# 单张图片不直接处理,等待用户提问
|
||||
return
|
||||
|
||||
# 如果是文本消息,检查是否有缓存的文件
|
||||
if feishu_msg.ctype == ContextType.TEXT:
|
||||
cached_files = file_cache.get(session_id)
|
||||
if cached_files:
|
||||
# 将缓存的文件附加到文本消息中
|
||||
file_refs = []
|
||||
for file_info in cached_files:
|
||||
file_path = file_info['path']
|
||||
file_type = file_info['type']
|
||||
if file_type == 'image':
|
||||
file_refs.append(f"[图片: {file_path}]")
|
||||
elif file_type == 'video':
|
||||
file_refs.append(f"[视频: {file_path}]")
|
||||
else:
|
||||
file_refs.append(f"[文件: {file_path}]")
|
||||
|
||||
feishu_msg.content = feishu_msg.content + "\n" + "\n".join(file_refs)
|
||||
logger.info(f"[FeiShu] Attached {len(cached_files)} cached file(s) to user query")
|
||||
# 清除缓存
|
||||
file_cache.clear(session_id)
|
||||
|
||||
context = self._compose_context(
|
||||
feishu_msg.ctype,
|
||||
feishu_msg.content,
|
||||
isgroup=is_group,
|
||||
msg=feishu_msg,
|
||||
receive_id_type=receive_id_type,
|
||||
no_need_at=True
|
||||
)
|
||||
if context:
|
||||
self.produce(context)
|
||||
logger.debug(f"[FeiShu] query={feishu_msg.content}, type={feishu_msg.ctype}")
|
||||
|
||||
def send(self, reply: Reply, context: Context):
|
||||
msg = context.get("msg")
|
||||
is_group = context["isgroup"]
|
||||
if msg:
|
||||
access_token = msg.access_token
|
||||
else:
|
||||
access_token = self.fetch_access_token()
|
||||
headers = {
|
||||
"Authorization": "Bearer " + access_token,
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
msg_type = "text"
|
||||
logger.debug(f"[FeiShu] sending reply, type={context.type}, content={reply.content[:100]}...")
|
||||
reply_content = reply.content
|
||||
content_key = "text"
|
||||
if reply.type == ReplyType.IMAGE_URL:
|
||||
# 图片上传
|
||||
reply_content = self._upload_image_url(reply.content, access_token)
|
||||
if not reply_content:
|
||||
logger.warning("[FeiShu] upload image failed")
|
||||
return
|
||||
msg_type = "image"
|
||||
content_key = "image_key"
|
||||
elif reply.type == ReplyType.FILE:
|
||||
# 如果有附加的文本内容,先发送文本
|
||||
if hasattr(reply, 'text_content') and reply.text_content:
|
||||
logger.info(f"[FeiShu] Sending text before file: {reply.text_content[:50]}...")
|
||||
text_reply = Reply(ReplyType.TEXT, reply.text_content)
|
||||
self._send(text_reply, context)
|
||||
import time
|
||||
time.sleep(0.3) # 短暂延迟,确保文本先到达
|
||||
|
||||
# 判断是否为视频文件
|
||||
file_path = reply.content
|
||||
if file_path.startswith("file://"):
|
||||
file_path = file_path[7:]
|
||||
|
||||
is_video = file_path.lower().endswith(('.mp4', '.avi', '.mov', '.wmv', '.flv'))
|
||||
|
||||
if is_video:
|
||||
# 视频上传(包含duration信息)
|
||||
upload_data = self._upload_video_url(reply.content, access_token)
|
||||
if not upload_data or not upload_data.get('file_key'):
|
||||
logger.warning("[FeiShu] upload video failed")
|
||||
return
|
||||
|
||||
# 视频使用 media 类型(根据官方文档)
|
||||
# 错误码 230055 说明:上传 mp4 时必须使用 msg_type="media"
|
||||
msg_type = "media"
|
||||
reply_content = upload_data # 完整的上传响应数据(包含file_key和duration)
|
||||
logger.info(
|
||||
f"[FeiShu] Sending video: file_key={upload_data.get('file_key')}, duration={upload_data.get('duration')}ms")
|
||||
content_key = None # 直接序列化整个对象
|
||||
else:
|
||||
# 其他文件使用 file 类型
|
||||
file_key = self._upload_file_url(reply.content, access_token)
|
||||
if not file_key:
|
||||
logger.warning("[FeiShu] upload file failed")
|
||||
return
|
||||
reply_content = file_key
|
||||
msg_type = "file"
|
||||
content_key = "file_key"
|
||||
|
||||
# Check if we can reply to an existing message (need msg_id)
|
||||
can_reply = is_group and msg and hasattr(msg, 'msg_id') and msg.msg_id
|
||||
|
||||
# Build content JSON
|
||||
content_json = json.dumps(reply_content) if content_key is None else json.dumps({content_key: reply_content})
|
||||
logger.debug(f"[FeiShu] Sending message: msg_type={msg_type}, content={content_json[:200]}")
|
||||
|
||||
if can_reply:
|
||||
# 群聊中回复已有消息
|
||||
url = f"https://open.feishu.cn/open-apis/im/v1/messages/{msg.msg_id}/reply"
|
||||
data = {
|
||||
"msg_type": msg_type,
|
||||
"content": content_json
|
||||
}
|
||||
res = requests.post(url=url, headers=headers, json=data, timeout=(5, 10))
|
||||
else:
|
||||
# 发送新消息(私聊或群聊中无msg_id的情况,如定时任务)
|
||||
url = "https://open.feishu.cn/open-apis/im/v1/messages"
|
||||
params = {"receive_id_type": context.get("receive_id_type") or "open_id"}
|
||||
data = {
|
||||
"receive_id": context.get("receiver"),
|
||||
"msg_type": msg_type,
|
||||
"content": content_json
|
||||
}
|
||||
res = requests.post(url=url, headers=headers, params=params, json=data, timeout=(5, 10))
|
||||
res = res.json()
|
||||
if res.get("code") == 0:
|
||||
logger.info(f"[FeiShu] send message success")
|
||||
else:
|
||||
logger.error(f"[FeiShu] send message failed, code={res.get('code')}, msg={res.get('msg')}")
|
||||
|
||||
def fetch_access_token(self) -> str:
|
||||
url = "https://open.feishu.cn/open-apis/auth/v3/tenant_access_token/internal/"
|
||||
headers = {
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
req_body = {
|
||||
"app_id": self.feishu_app_id,
|
||||
"app_secret": self.feishu_app_secret
|
||||
}
|
||||
data = bytes(json.dumps(req_body), encoding='utf8')
|
||||
response = requests.post(url=url, data=data, headers=headers)
|
||||
if response.status_code == 200:
|
||||
res = response.json()
|
||||
if res.get("code") != 0:
|
||||
logger.error(f"[FeiShu] get tenant_access_token error, code={res.get('code')}, msg={res.get('msg')}")
|
||||
return ""
|
||||
else:
|
||||
return res.get("tenant_access_token")
|
||||
else:
|
||||
logger.error(f"[FeiShu] fetch token error, res={response}")
|
||||
|
||||
def _upload_image_url(self, img_url, access_token):
|
||||
logger.debug(f"[FeiShu] start process image, img_url={img_url}")
|
||||
|
||||
# Check if it's a local file path (file:// protocol)
|
||||
if img_url.startswith("file://"):
|
||||
local_path = img_url[7:] # Remove "file://" prefix
|
||||
logger.info(f"[FeiShu] uploading local file: {local_path}")
|
||||
|
||||
if not os.path.exists(local_path):
|
||||
logger.error(f"[FeiShu] local file not found: {local_path}")
|
||||
return None
|
||||
|
||||
# Upload directly from local file
|
||||
upload_url = "https://open.feishu.cn/open-apis/im/v1/images"
|
||||
data = {'image_type': 'message'}
|
||||
headers = {'Authorization': f'Bearer {access_token}'}
|
||||
|
||||
with open(local_path, "rb") as file:
|
||||
upload_response = requests.post(upload_url, files={"image": file}, data=data, headers=headers)
|
||||
logger.info(f"[FeiShu] upload file, res={upload_response.content}")
|
||||
|
||||
response_data = upload_response.json()
|
||||
if response_data.get("code") == 0:
|
||||
return response_data.get("data").get("image_key")
|
||||
else:
|
||||
logger.error(f"[FeiShu] upload failed: {response_data}")
|
||||
return None
|
||||
|
||||
# Original logic for HTTP URLs
|
||||
response = requests.get(img_url)
|
||||
suffix = utils.get_path_suffix(img_url)
|
||||
temp_name = str(uuid.uuid4()) + "." + suffix
|
||||
if response.status_code == 200:
|
||||
# 将图片内容保存为临时文件
|
||||
with open(temp_name, "wb") as file:
|
||||
file.write(response.content)
|
||||
|
||||
# upload
|
||||
upload_url = "https://open.feishu.cn/open-apis/im/v1/images"
|
||||
data = {
|
||||
'image_type': 'message'
|
||||
}
|
||||
headers = {
|
||||
'Authorization': f'Bearer {access_token}',
|
||||
}
|
||||
with open(temp_name, "rb") as file:
|
||||
upload_response = requests.post(upload_url, files={"image": file}, data=data, headers=headers)
|
||||
logger.info(f"[FeiShu] upload file, res={upload_response.content}")
|
||||
os.remove(temp_name)
|
||||
return upload_response.json().get("data").get("image_key")
|
||||
|
||||
def _get_video_duration(self, file_path: str) -> int:
|
||||
"""
|
||||
获取视频时长(毫秒)
|
||||
|
||||
Args:
|
||||
file_path: 视频文件路径
|
||||
|
||||
Returns:
|
||||
视频时长(毫秒),如果获取失败返回0
|
||||
"""
|
||||
try:
|
||||
import subprocess
|
||||
|
||||
# 使用 ffprobe 获取视频时长
|
||||
cmd = [
|
||||
'ffprobe',
|
||||
'-v', 'error',
|
||||
'-show_entries', 'format=duration',
|
||||
'-of', 'default=noprint_wrappers=1:nokey=1',
|
||||
file_path
|
||||
]
|
||||
|
||||
result = subprocess.run(cmd, capture_output=True, text=True, timeout=10)
|
||||
if result.returncode == 0:
|
||||
duration_seconds = float(result.stdout.strip())
|
||||
duration_ms = int(duration_seconds * 1000)
|
||||
logger.info(f"[FeiShu] Video duration: {duration_seconds:.2f}s ({duration_ms}ms)")
|
||||
return duration_ms
|
||||
else:
|
||||
logger.warning(f"[FeiShu] Failed to get video duration via ffprobe: {result.stderr}")
|
||||
return 0
|
||||
except FileNotFoundError:
|
||||
logger.warning("[FeiShu] ffprobe not found, video duration will be 0. Install ffmpeg to fix this.")
|
||||
return 0
|
||||
except Exception as e:
|
||||
logger.warning(f"[FeiShu] Failed to get video duration: {e}")
|
||||
return 0
|
||||
|
||||
def _upload_video_url(self, video_url, access_token):
|
||||
"""
|
||||
Upload video to Feishu and return video info (file_key and duration)
|
||||
Supports:
|
||||
- file:// URLs for local files
|
||||
- http(s):// URLs (download then upload)
|
||||
|
||||
Returns:
|
||||
dict with 'file_key' and 'duration' (milliseconds), or None if failed
|
||||
"""
|
||||
local_path = None
|
||||
temp_file = None
|
||||
|
||||
try:
|
||||
# For file:// URLs (local files), upload directly
|
||||
if video_url.startswith("file://"):
|
||||
local_path = video_url[7:] # Remove file:// prefix
|
||||
if not os.path.exists(local_path):
|
||||
logger.error(f"[FeiShu] local video file not found: {local_path}")
|
||||
return None
|
||||
else:
|
||||
# For HTTP URLs, download first
|
||||
logger.info(f"[FeiShu] Downloading video from URL: {video_url}")
|
||||
response = requests.get(video_url, timeout=(5, 60))
|
||||
if response.status_code != 200:
|
||||
logger.error(f"[FeiShu] download video failed, status={response.status_code}")
|
||||
return None
|
||||
|
||||
# Save to temp file
|
||||
import uuid
|
||||
file_name = os.path.basename(video_url) or "video.mp4"
|
||||
temp_file = str(uuid.uuid4()) + "_" + file_name
|
||||
|
||||
with open(temp_file, "wb") as file:
|
||||
file.write(response.content)
|
||||
|
||||
logger.info(f"[FeiShu] Video downloaded, size={len(response.content)} bytes")
|
||||
local_path = temp_file
|
||||
|
||||
# Get video duration
|
||||
duration = self._get_video_duration(local_path)
|
||||
|
||||
# Upload to Feishu
|
||||
file_name = os.path.basename(local_path)
|
||||
file_ext = os.path.splitext(file_name)[1].lower()
|
||||
file_type_map = {'.mp4': 'mp4'}
|
||||
file_type = file_type_map.get(file_ext, 'mp4')
|
||||
|
||||
upload_url = "https://open.feishu.cn/open-apis/im/v1/files"
|
||||
data = {
|
||||
'file_type': file_type,
|
||||
'file_name': file_name
|
||||
}
|
||||
# Add duration only if available (required for video/audio)
|
||||
if duration:
|
||||
data['duration'] = duration # Must be int, not string
|
||||
|
||||
headers = {'Authorization': f'Bearer {access_token}'}
|
||||
|
||||
logger.info(f"[FeiShu] Uploading video: file_name={file_name}, duration={duration}ms")
|
||||
|
||||
with open(local_path, "rb") as file:
|
||||
upload_response = requests.post(
|
||||
upload_url,
|
||||
files={"file": file},
|
||||
data=data,
|
||||
headers=headers,
|
||||
timeout=(5, 60)
|
||||
)
|
||||
logger.info(
|
||||
f"[FeiShu] upload video response, status={upload_response.status_code}, res={upload_response.content}")
|
||||
|
||||
response_data = upload_response.json()
|
||||
if response_data.get("code") == 0:
|
||||
# Add duration to the response data (API doesn't return it)
|
||||
upload_data = response_data.get("data")
|
||||
upload_data['duration'] = duration # Add our calculated duration
|
||||
logger.info(
|
||||
f"[FeiShu] Upload complete: file_key={upload_data.get('file_key')}, duration={duration}ms")
|
||||
return upload_data
|
||||
else:
|
||||
logger.error(f"[FeiShu] upload video failed: {response_data}")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[FeiShu] upload video exception: {e}")
|
||||
return None
|
||||
|
||||
finally:
|
||||
# Clean up temp file
|
||||
if temp_file and os.path.exists(temp_file):
|
||||
try:
|
||||
os.remove(temp_file)
|
||||
except Exception as e:
|
||||
logger.warning(f"[FeiShu] Failed to remove temp file {temp_file}: {e}")
|
||||
|
||||
def _upload_file_url(self, file_url, access_token):
|
||||
"""
|
||||
Upload file to Feishu
|
||||
Supports both local files (file://) and HTTP URLs
|
||||
"""
|
||||
logger.debug(f"[FeiShu] start process file, file_url={file_url}")
|
||||
|
||||
# Check if it's a local file path (file:// protocol)
|
||||
if file_url.startswith("file://"):
|
||||
local_path = file_url[7:] # Remove "file://" prefix
|
||||
logger.info(f"[FeiShu] uploading local file: {local_path}")
|
||||
|
||||
if not os.path.exists(local_path):
|
||||
logger.error(f"[FeiShu] local file not found: {local_path}")
|
||||
return None
|
||||
|
||||
# Get file info
|
||||
file_name = os.path.basename(local_path)
|
||||
file_ext = os.path.splitext(file_name)[1].lower()
|
||||
|
||||
# Determine file type for Feishu API
|
||||
# Feishu supports: opus, mp4, pdf, doc, xls, ppt, stream (other types)
|
||||
file_type_map = {
|
||||
'.opus': 'opus',
|
||||
'.mp4': 'mp4',
|
||||
'.pdf': 'pdf',
|
||||
'.doc': 'doc', '.docx': 'doc',
|
||||
'.xls': 'xls', '.xlsx': 'xls',
|
||||
'.ppt': 'ppt', '.pptx': 'ppt',
|
||||
}
|
||||
file_type = file_type_map.get(file_ext, 'stream') # Default to stream for other types
|
||||
|
||||
# Upload file to Feishu
|
||||
upload_url = "https://open.feishu.cn/open-apis/im/v1/files"
|
||||
data = {'file_type': file_type, 'file_name': file_name}
|
||||
headers = {'Authorization': f'Bearer {access_token}'}
|
||||
|
||||
try:
|
||||
with open(local_path, "rb") as file:
|
||||
upload_response = requests.post(
|
||||
upload_url,
|
||||
files={"file": file},
|
||||
data=data,
|
||||
headers=headers,
|
||||
timeout=(5, 30) # 5s connect, 30s read timeout
|
||||
)
|
||||
logger.info(
|
||||
f"[FeiShu] upload file response, status={upload_response.status_code}, res={upload_response.content}")
|
||||
|
||||
response_data = upload_response.json()
|
||||
if response_data.get("code") == 0:
|
||||
return response_data.get("data").get("file_key")
|
||||
else:
|
||||
logger.error(f"[FeiShu] upload file failed: {response_data}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"[FeiShu] upload file exception: {e}")
|
||||
return None
|
||||
|
||||
# For HTTP URLs, download first then upload
|
||||
try:
|
||||
response = requests.get(file_url, timeout=(5, 30))
|
||||
if response.status_code != 200:
|
||||
logger.error(f"[FeiShu] download file failed, status={response.status_code}")
|
||||
return None
|
||||
|
||||
# Save to temp file
|
||||
import uuid
|
||||
file_name = os.path.basename(file_url)
|
||||
temp_name = str(uuid.uuid4()) + "_" + file_name
|
||||
|
||||
with open(temp_name, "wb") as file:
|
||||
file.write(response.content)
|
||||
|
||||
# Upload
|
||||
file_ext = os.path.splitext(file_name)[1].lower()
|
||||
file_type_map = {
|
||||
'.opus': 'opus', '.mp4': 'mp4', '.pdf': 'pdf',
|
||||
'.doc': 'doc', '.docx': 'doc',
|
||||
'.xls': 'xls', '.xlsx': 'xls',
|
||||
'.ppt': 'ppt', '.pptx': 'ppt',
|
||||
}
|
||||
file_type = file_type_map.get(file_ext, 'stream')
|
||||
|
||||
upload_url = "https://open.feishu.cn/open-apis/im/v1/files"
|
||||
data = {'file_type': file_type, 'file_name': file_name}
|
||||
headers = {'Authorization': f'Bearer {access_token}'}
|
||||
|
||||
with open(temp_name, "rb") as file:
|
||||
upload_response = requests.post(upload_url, files={"file": file}, data=data, headers=headers)
|
||||
logger.info(f"[FeiShu] upload file, res={upload_response.content}")
|
||||
|
||||
response_data = upload_response.json()
|
||||
os.remove(temp_name) # Clean up temp file
|
||||
|
||||
if response_data.get("code") == 0:
|
||||
return response_data.get("data").get("file_key")
|
||||
else:
|
||||
logger.error(f"[FeiShu] upload file failed: {response_data}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"[FeiShu] upload file from URL exception: {e}")
|
||||
return None
|
||||
|
||||
def _compose_context(self, ctype: ContextType, content, **kwargs):
|
||||
context = Context(ctype, content)
|
||||
context.kwargs = kwargs
|
||||
if "channel_type" not in context:
|
||||
context["channel_type"] = self.channel_type
|
||||
if "origin_ctype" not in context:
|
||||
context["origin_ctype"] = ctype
|
||||
|
||||
cmsg = context["msg"]
|
||||
|
||||
# Set session_id based on chat type
|
||||
if cmsg.is_group:
|
||||
# Group chat: check if group_shared_session is enabled
|
||||
if conf().get("group_shared_session", True):
|
||||
# All users in the group share the same session context
|
||||
context["session_id"] = cmsg.other_user_id # group_id
|
||||
else:
|
||||
# Each user has their own session within the group
|
||||
# This ensures:
|
||||
# - Same user in different groups have separate conversation histories
|
||||
# - Same user in private chat and group chat have separate histories
|
||||
context["session_id"] = f"{cmsg.from_user_id}:{cmsg.other_user_id}"
|
||||
else:
|
||||
# Private chat: use user_id only
|
||||
context["session_id"] = cmsg.from_user_id
|
||||
|
||||
context["receiver"] = cmsg.other_user_id
|
||||
|
||||
if ctype == ContextType.TEXT:
|
||||
# 1.文本请求
|
||||
# 图片生成处理
|
||||
img_match_prefix = check_prefix(content, conf().get("image_create_prefix"))
|
||||
if img_match_prefix:
|
||||
content = content.replace(img_match_prefix, "", 1)
|
||||
context.type = ContextType.IMAGE_CREATE
|
||||
else:
|
||||
context.type = ContextType.TEXT
|
||||
context.content = content.strip()
|
||||
|
||||
elif context.type == ContextType.VOICE:
|
||||
# 2.语音请求
|
||||
if "desire_rtype" not in context and conf().get("voice_reply_voice"):
|
||||
context["desire_rtype"] = ReplyType.VOICE
|
||||
|
||||
return context
|
||||
|
||||
|
||||
class FeishuController:
|
||||
"""
|
||||
HTTP服务器控制器,用于webhook模式
|
||||
"""
|
||||
# 类常量
|
||||
FAILED_MSG = '{"success": false}'
|
||||
SUCCESS_MSG = '{"success": true}'
|
||||
MESSAGE_RECEIVE_TYPE = "im.message.receive_v1"
|
||||
|
||||
def GET(self):
|
||||
return "Feishu service start success!"
|
||||
|
||||
def POST(self):
|
||||
try:
|
||||
channel = FeiShuChanel()
|
||||
|
||||
request = json.loads(web.data().decode("utf-8"))
|
||||
logger.debug(f"[FeiShu] receive request: {request}")
|
||||
|
||||
# 1.事件订阅回调验证
|
||||
if request.get("type") == URL_VERIFICATION:
|
||||
varify_res = {"challenge": request.get("challenge")}
|
||||
return json.dumps(varify_res)
|
||||
|
||||
# 2.消息接收处理
|
||||
# token 校验
|
||||
header = request.get("header")
|
||||
if not header or header.get("token") != channel.feishu_token:
|
||||
return self.FAILED_MSG
|
||||
|
||||
# 处理消息事件
|
||||
event = request.get("event")
|
||||
if header.get("event_type") == self.MESSAGE_RECEIVE_TYPE and event:
|
||||
channel._handle_message_event(event)
|
||||
|
||||
return self.SUCCESS_MSG
|
||||
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
return self.FAILED_MSG
|
||||
179
channel/feishu/feishu_message.py
Normal file
179
channel/feishu/feishu_message.py
Normal file
@@ -0,0 +1,179 @@
|
||||
from bridge.context import ContextType
|
||||
from channel.chat_message import ChatMessage
|
||||
import json
|
||||
import os
|
||||
import requests
|
||||
from common.log import logger
|
||||
from common.tmp_dir import TmpDir
|
||||
from common import utils
|
||||
from common.utils import expand_path
|
||||
from config import conf
|
||||
|
||||
|
||||
class FeishuMessage(ChatMessage):
|
||||
def __init__(self, event: dict, is_group=False, access_token=None):
|
||||
super().__init__(event)
|
||||
msg = event.get("message")
|
||||
sender = event.get("sender")
|
||||
self.access_token = access_token
|
||||
self.msg_id = msg.get("message_id")
|
||||
self.create_time = msg.get("create_time")
|
||||
self.is_group = is_group
|
||||
msg_type = msg.get("message_type")
|
||||
|
||||
if msg_type == "text":
|
||||
self.ctype = ContextType.TEXT
|
||||
content = json.loads(msg.get('content'))
|
||||
self.content = content.get("text").strip()
|
||||
elif msg_type == "image":
|
||||
# 单张图片消息:下载并缓存,等待用户提问时一起发送
|
||||
self.ctype = ContextType.IMAGE
|
||||
content = json.loads(msg.get("content"))
|
||||
image_key = content.get("image_key")
|
||||
|
||||
# 下载图片到工作空间临时目录
|
||||
workspace_root = expand_path(conf().get("agent_workspace", "~/cow"))
|
||||
tmp_dir = os.path.join(workspace_root, "tmp")
|
||||
os.makedirs(tmp_dir, exist_ok=True)
|
||||
image_path = os.path.join(tmp_dir, f"{image_key}.png")
|
||||
|
||||
# 下载图片
|
||||
url = f"https://open.feishu.cn/open-apis/im/v1/messages/{msg.get('message_id')}/resources/{image_key}"
|
||||
headers = {"Authorization": "Bearer " + access_token}
|
||||
params = {"type": "image"}
|
||||
response = requests.get(url=url, headers=headers, params=params)
|
||||
|
||||
if response.status_code == 200:
|
||||
with open(image_path, "wb") as f:
|
||||
f.write(response.content)
|
||||
logger.info(f"[FeiShu] Downloaded single image, key={image_key}, path={image_path}")
|
||||
self.content = image_path
|
||||
self.image_path = image_path # 保存图片路径
|
||||
else:
|
||||
logger.error(f"[FeiShu] Failed to download single image, key={image_key}, status={response.status_code}")
|
||||
self.content = f"[图片下载失败: {image_key}]"
|
||||
self.image_path = None
|
||||
elif msg_type == "post":
|
||||
# 富文本消息,可能包含图片、文本等多种元素
|
||||
content = json.loads(msg.get("content"))
|
||||
|
||||
# 飞书富文本消息结构:content 直接包含 title 和 content 数组
|
||||
# 不是嵌套在 post 字段下
|
||||
title = content.get("title", "")
|
||||
content_list = content.get("content", [])
|
||||
|
||||
logger.info(f"[FeiShu] Post message - title: '{title}', content_list length: {len(content_list)}")
|
||||
|
||||
# 收集所有图片和文本
|
||||
image_keys = []
|
||||
text_parts = []
|
||||
|
||||
if title:
|
||||
text_parts.append(title)
|
||||
|
||||
for block in content_list:
|
||||
logger.debug(f"[FeiShu] Processing block: {block}")
|
||||
# block 本身就是元素列表
|
||||
if not isinstance(block, list):
|
||||
continue
|
||||
|
||||
for element in block:
|
||||
element_tag = element.get("tag")
|
||||
logger.debug(f"[FeiShu] Element tag: {element_tag}, element: {element}")
|
||||
if element_tag == "img":
|
||||
# 找到图片元素
|
||||
image_key = element.get("image_key")
|
||||
if image_key:
|
||||
image_keys.append(image_key)
|
||||
elif element_tag == "text":
|
||||
# 文本元素
|
||||
text_content = element.get("text", "")
|
||||
if text_content:
|
||||
text_parts.append(text_content)
|
||||
|
||||
logger.info(f"[FeiShu] Parsed - images: {len(image_keys)}, text_parts: {text_parts}")
|
||||
|
||||
# 富文本消息统一作为文本消息处理
|
||||
self.ctype = ContextType.TEXT
|
||||
|
||||
if image_keys:
|
||||
# 如果包含图片,下载并在文本中引用本地路径
|
||||
workspace_root = expand_path(conf().get("agent_workspace", "~/cow"))
|
||||
tmp_dir = os.path.join(workspace_root, "tmp")
|
||||
os.makedirs(tmp_dir, exist_ok=True)
|
||||
|
||||
# 保存图片路径映射
|
||||
self.image_paths = {}
|
||||
for image_key in image_keys:
|
||||
image_path = os.path.join(tmp_dir, f"{image_key}.png")
|
||||
self.image_paths[image_key] = image_path
|
||||
|
||||
def _download_images():
|
||||
for image_key, image_path in self.image_paths.items():
|
||||
url = f"https://open.feishu.cn/open-apis/im/v1/messages/{self.msg_id}/resources/{image_key}"
|
||||
headers = {"Authorization": "Bearer " + access_token}
|
||||
params = {"type": "image"}
|
||||
response = requests.get(url=url, headers=headers, params=params)
|
||||
if response.status_code == 200:
|
||||
with open(image_path, "wb") as f:
|
||||
f.write(response.content)
|
||||
logger.info(f"[FeiShu] Image downloaded from post message, key={image_key}, path={image_path}")
|
||||
else:
|
||||
logger.error(f"[FeiShu] Failed to download image from post, key={image_key}, status={response.status_code}")
|
||||
|
||||
# 立即下载图片,不使用延迟下载
|
||||
# 因为 TEXT 类型消息不会调用 prepare()
|
||||
_download_images()
|
||||
|
||||
# 构建消息内容:文本 + 图片路径
|
||||
content_parts = []
|
||||
if text_parts:
|
||||
content_parts.append("\n".join(text_parts).strip())
|
||||
for image_key, image_path in self.image_paths.items():
|
||||
content_parts.append(f"[图片: {image_path}]")
|
||||
|
||||
self.content = "\n".join(content_parts)
|
||||
logger.info(f"[FeiShu] Received post message with {len(image_keys)} image(s) and text: {self.content}")
|
||||
else:
|
||||
# 纯文本富文本消息
|
||||
self.content = "\n".join(text_parts).strip() if text_parts else "[富文本消息]"
|
||||
logger.info(f"[FeiShu] Received post message (text only): {self.content}")
|
||||
elif msg_type == "file":
|
||||
self.ctype = ContextType.FILE
|
||||
content = json.loads(msg.get("content"))
|
||||
file_key = content.get("file_key")
|
||||
file_name = content.get("file_name")
|
||||
|
||||
self.content = TmpDir().path() + file_key + "." + utils.get_path_suffix(file_name)
|
||||
|
||||
def _download_file():
|
||||
# 如果响应状态码是200,则将响应内容写入本地文件
|
||||
url = f"https://open.feishu.cn/open-apis/im/v1/messages/{self.msg_id}/resources/{file_key}"
|
||||
headers = {
|
||||
"Authorization": "Bearer " + access_token,
|
||||
}
|
||||
params = {
|
||||
"type": "file"
|
||||
}
|
||||
response = requests.get(url=url, headers=headers, params=params)
|
||||
if response.status_code == 200:
|
||||
with open(self.content, "wb") as f:
|
||||
f.write(response.content)
|
||||
else:
|
||||
logger.info(f"[FeiShu] Failed to download file, key={file_key}, res={response.text}")
|
||||
self._prepare_fn = _download_file
|
||||
else:
|
||||
raise NotImplementedError("Unsupported message type: Type:{} ".format(msg_type))
|
||||
|
||||
self.from_user_id = sender.get("sender_id").get("open_id")
|
||||
self.to_user_id = event.get("app_id")
|
||||
if is_group:
|
||||
# 群聊
|
||||
self.other_user_id = msg.get("chat_id")
|
||||
self.actual_user_id = self.from_user_id
|
||||
self.content = self.content.replace("@_user_1", "").strip()
|
||||
self.actual_user_nickname = ""
|
||||
else:
|
||||
# 私聊
|
||||
self.other_user_id = self.from_user_id
|
||||
self.actual_user_id = self.from_user_id
|
||||
100
channel/file_cache.py
Normal file
100
channel/file_cache.py
Normal file
@@ -0,0 +1,100 @@
|
||||
"""
|
||||
文件缓存管理器
|
||||
用于缓存单独发送的文件消息(图片、视频、文档等),在用户提问时自动附加
|
||||
"""
|
||||
import time
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FileCache:
|
||||
"""文件缓存管理器,按 session_id 缓存文件,TTL=2分钟"""
|
||||
|
||||
def __init__(self, ttl=120):
|
||||
"""
|
||||
Args:
|
||||
ttl: 缓存过期时间(秒),默认2分钟
|
||||
"""
|
||||
self.cache = {}
|
||||
self.ttl = ttl
|
||||
|
||||
def add(self, session_id: str, file_path: str, file_type: str = "image"):
|
||||
"""
|
||||
添加文件到缓存
|
||||
|
||||
Args:
|
||||
session_id: 会话ID
|
||||
file_path: 文件本地路径
|
||||
file_type: 文件类型(image, video, file 等)
|
||||
"""
|
||||
if session_id not in self.cache:
|
||||
self.cache[session_id] = {
|
||||
'files': [],
|
||||
'timestamp': time.time()
|
||||
}
|
||||
|
||||
# 添加文件(去重)
|
||||
file_info = {'path': file_path, 'type': file_type}
|
||||
if file_info not in self.cache[session_id]['files']:
|
||||
self.cache[session_id]['files'].append(file_info)
|
||||
logger.info(f"[FileCache] Added {file_type} to cache for session {session_id}: {file_path}")
|
||||
|
||||
def get(self, session_id: str) -> list:
|
||||
"""
|
||||
获取缓存的文件列表
|
||||
|
||||
Args:
|
||||
session_id: 会话ID
|
||||
|
||||
Returns:
|
||||
文件信息列表 [{'path': '...', 'type': 'image'}, ...],如果没有或已过期返回空列表
|
||||
"""
|
||||
if session_id not in self.cache:
|
||||
return []
|
||||
|
||||
item = self.cache[session_id]
|
||||
|
||||
# 检查是否过期
|
||||
if time.time() - item['timestamp'] > self.ttl:
|
||||
logger.info(f"[FileCache] Cache expired for session {session_id}, clearing...")
|
||||
del self.cache[session_id]
|
||||
return []
|
||||
|
||||
return item['files']
|
||||
|
||||
def clear(self, session_id: str):
|
||||
"""
|
||||
清除指定会话的缓存
|
||||
|
||||
Args:
|
||||
session_id: 会话ID
|
||||
"""
|
||||
if session_id in self.cache:
|
||||
logger.info(f"[FileCache] Cleared cache for session {session_id}")
|
||||
del self.cache[session_id]
|
||||
|
||||
def cleanup_expired(self):
|
||||
"""清理所有过期的缓存"""
|
||||
current_time = time.time()
|
||||
expired_sessions = []
|
||||
|
||||
for session_id, item in self.cache.items():
|
||||
if current_time - item['timestamp'] > self.ttl:
|
||||
expired_sessions.append(session_id)
|
||||
|
||||
for session_id in expired_sessions:
|
||||
del self.cache[session_id]
|
||||
logger.debug(f"[FileCache] Cleaned up expired cache for session {session_id}")
|
||||
|
||||
if expired_sessions:
|
||||
logger.info(f"[FileCache] Cleaned up {len(expired_sessions)} expired cache(s)")
|
||||
|
||||
|
||||
# 全局单例
|
||||
_file_cache = FileCache()
|
||||
|
||||
|
||||
def get_file_cache() -> FileCache:
|
||||
"""获取全局文件缓存实例"""
|
||||
return _file_cache
|
||||
@@ -1,14 +1,23 @@
|
||||
import sys
|
||||
|
||||
from bridge.context import *
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from channel.chat_channel import ChatChannel, check_prefix
|
||||
from channel.chat_message import ChatMessage
|
||||
import sys
|
||||
|
||||
from config import conf
|
||||
from common.log import logger
|
||||
from config import conf
|
||||
|
||||
|
||||
class TerminalMessage(ChatMessage):
|
||||
def __init__(self, msg_id, content, ctype = ContextType.TEXT, from_user_id = "User", to_user_id = "Chatgpt", other_user_id = "Chatgpt"):
|
||||
def __init__(
|
||||
self,
|
||||
msg_id,
|
||||
content,
|
||||
ctype=ContextType.TEXT,
|
||||
from_user_id="User",
|
||||
to_user_id="Chatgpt",
|
||||
other_user_id="Chatgpt",
|
||||
):
|
||||
self.msg_id = msg_id
|
||||
self.ctype = ctype
|
||||
self.content = content
|
||||
@@ -16,6 +25,7 @@ class TerminalMessage(ChatMessage):
|
||||
self.to_user_id = to_user_id
|
||||
self.other_user_id = other_user_id
|
||||
|
||||
|
||||
class TerminalChannel(ChatChannel):
|
||||
NOT_SUPPORT_REPLYTYPE = [ReplyType.VOICE]
|
||||
|
||||
@@ -23,14 +33,18 @@ class TerminalChannel(ChatChannel):
|
||||
print("\nBot:")
|
||||
if reply.type == ReplyType.IMAGE:
|
||||
from PIL import Image
|
||||
|
||||
image_storage = reply.content
|
||||
image_storage.seek(0)
|
||||
img = Image.open(image_storage)
|
||||
print("<IMAGE>")
|
||||
img.show()
|
||||
elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片
|
||||
elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片
|
||||
import io
|
||||
|
||||
import requests
|
||||
from PIL import Image
|
||||
import requests,io
|
||||
|
||||
img_url = reply.content
|
||||
pic_res = requests.get(img_url, stream=True)
|
||||
image_storage = io.BytesIO()
|
||||
@@ -59,11 +73,12 @@ class TerminalChannel(ChatChannel):
|
||||
print("\nExiting...")
|
||||
sys.exit()
|
||||
msg_id += 1
|
||||
trigger_prefixs = conf().get("single_chat_prefix",[""])
|
||||
trigger_prefixs = conf().get("single_chat_prefix", [""])
|
||||
if check_prefix(prompt, trigger_prefixs) is None:
|
||||
prompt = trigger_prefixs[0] + prompt # 给没触发的消息加上触发前缀
|
||||
|
||||
context = self._compose_context(ContextType.TEXT, prompt, msg = TerminalMessage(msg_id, prompt))
|
||||
prompt = trigger_prefixs[0] + prompt # 给没触发的消息加上触发前缀
|
||||
|
||||
context = self._compose_context(ContextType.TEXT, prompt, msg=TerminalMessage(msg_id, prompt))
|
||||
context["isgroup"] = False
|
||||
if context:
|
||||
self.produce(context)
|
||||
else:
|
||||
|
||||
10
channel/web/README.md
Normal file
10
channel/web/README.md
Normal file
@@ -0,0 +1,10 @@
|
||||
# Web Channel
|
||||
|
||||
提供了一个默认的AI对话页面,可展示文本、图片等消息交互,支持markdown语法渲染,兼容插件执行。
|
||||
|
||||
# 使用说明
|
||||
|
||||
- 在 `config.json` 配置文件中的 `channel_type` 字段填入 `web`
|
||||
- 程序运行后将监听9899端口,浏览器访问 http://localhost:9899/chat 即可使用
|
||||
- 监听端口可以在配置文件 `web_port` 中自定义
|
||||
- 对于Docker运行方式,如果需要外部访问,需要在 `docker-compose.yml` 中通过 ports配置将端口监听映射到宿主机
|
||||
641
channel/web/chat.html
Normal file
641
channel/web/chat.html
Normal file
@@ -0,0 +1,641 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="zh" class="">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>CowAgent Console</title>
|
||||
<link rel="icon" href="assets/favicon.ico" type="image/x-icon">
|
||||
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.4.0/css/all.min.css">
|
||||
<link rel="preconnect" href="https://fonts.googleapis.com">
|
||||
<link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
|
||||
<link href="https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700&display=swap" rel="stylesheet">
|
||||
<script src="https://cdn.tailwindcss.com"></script>
|
||||
<script src="https://cdn.jsdelivr.net/npm/markdown-it@13.0.1/dist/markdown-it.min.js"></script>
|
||||
<link id="hljs-light" rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.9.0/styles/github.min.css">
|
||||
<link id="hljs-dark" rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.9.0/styles/github-dark.min.css" disabled>
|
||||
<script src="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.9.0/highlight.min.js"></script>
|
||||
<script src="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.9.0/languages/python.min.js"></script>
|
||||
<script src="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.9.0/languages/javascript.min.js"></script>
|
||||
<script src="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.9.0/languages/java.min.js"></script>
|
||||
<script src="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.9.0/languages/go.min.js"></script>
|
||||
<script src="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.9.0/languages/bash.min.js"></script>
|
||||
<script>
|
||||
tailwind.config = {
|
||||
darkMode: 'class',
|
||||
theme: {
|
||||
extend: {
|
||||
fontFamily: {
|
||||
sans: ['Inter', 'system-ui', '-apple-system', 'sans-serif'],
|
||||
mono: ['"JetBrains Mono"', '"Fira Code"', 'Consolas', 'monospace'],
|
||||
},
|
||||
colors: {
|
||||
primary: {
|
||||
50: '#EDFDF3', 100: '#D4FAE2', 200: '#ABF4C7', 300: '#74E9A4',
|
||||
400: '#4ABE6E', 500: '#35A85B', 600: '#228547', 700: '#1C6B3B',
|
||||
800: '#1A5532', 900: '#16462A',
|
||||
}
|
||||
},
|
||||
animation: {
|
||||
'pulse-dot': 'pulseDot 1.4s infinite ease-in-out both',
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
</script>
|
||||
<link rel="stylesheet" href="assets/css/console.css">
|
||||
<!-- Apply theme/lang before first paint to avoid flash of unstyled content.
|
||||
This runs synchronously in <head> so the correct class is on <html>
|
||||
before any CSS or body rendering occurs. -->
|
||||
<script>
|
||||
(function() {
|
||||
var theme = localStorage.getItem('cow_theme') || 'dark';
|
||||
if (theme === 'dark') document.documentElement.classList.add('dark');
|
||||
})();
|
||||
</script>
|
||||
</head>
|
||||
<body class="h-screen overflow-hidden bg-gray-50 dark:bg-[#111111] text-slate-800 dark:text-slate-200 font-sans">
|
||||
<div id="app" class="flex h-screen">
|
||||
|
||||
<!-- ================================================================ -->
|
||||
<!-- SIDEBAR -->
|
||||
<!-- ================================================================ -->
|
||||
<aside id="sidebar" class="fixed inset-y-0 left-0 z-50 w-64 bg-[#0A0A0A] text-neutral-400 flex flex-col
|
||||
transform -translate-x-full lg:relative lg:translate-x-0
|
||||
transition-transform duration-300 ease-in-out">
|
||||
<!-- Logo -->
|
||||
<div class="flex items-center gap-3 px-5 h-14 border-b border-white/10 flex-shrink-0">
|
||||
<img src="assets/logo.jpg" alt="CowAgent" class="w-8 h-8 rounded-lg flex-shrink-0">
|
||||
<div class="flex flex-col min-w-0">
|
||||
<span class="text-white font-semibold text-sm truncate">CowAgent</span>
|
||||
<span class="text-neutral-500 text-xs" data-i18n="console">Console</span>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Navigation -->
|
||||
<nav class="flex-1 overflow-y-auto py-4 px-3 space-y-1">
|
||||
<!-- Chat Group -->
|
||||
<div class="menu-group open" data-group="chat">
|
||||
<button class="w-full flex items-center gap-2 px-3 py-2 text-xs font-semibold uppercase tracking-wider text-neutral-500 hover:text-neutral-300 cursor-pointer transition-colors duration-150">
|
||||
<i class="fas fa-chevron-right text-[10px] chevron"></i>
|
||||
<span data-i18n="nav_chat">Chat</span>
|
||||
</button>
|
||||
<div class="menu-group-items pl-2">
|
||||
<a class="sidebar-item active flex items-center gap-3 px-3 py-2 rounded-lg cursor-pointer transition-all duration-150 hover:bg-white/5 hover:text-neutral-200 text-[14px]"
|
||||
data-view="chat">
|
||||
<i class="fas fa-message item-icon text-xs w-5 text-center"></i>
|
||||
<span data-i18n="menu_chat">Chat</span>
|
||||
</a>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Management Group -->
|
||||
<div class="menu-group open" data-group="manage">
|
||||
<button class="w-full flex items-center gap-2 px-3 py-2 text-xs font-semibold uppercase tracking-wider text-neutral-500 hover:text-neutral-300 cursor-pointer transition-colors duration-150">
|
||||
<i class="fas fa-chevron-right text-[10px] chevron"></i>
|
||||
<span data-i18n="nav_manage">Management</span>
|
||||
</button>
|
||||
<div class="menu-group-items pl-2">
|
||||
<a class="sidebar-item flex items-center gap-3 px-3 py-2 rounded-lg cursor-pointer transition-all duration-150 hover:bg-white/5 hover:text-neutral-200 text-[14px]"
|
||||
data-view="config">
|
||||
<i class="fas fa-sliders item-icon text-xs w-5 text-center"></i>
|
||||
<span data-i18n="menu_config">Config</span>
|
||||
</a>
|
||||
<a class="sidebar-item flex items-center gap-3 px-3 py-2 rounded-lg cursor-pointer transition-all duration-150 hover:bg-white/5 hover:text-neutral-200 text-[14px]"
|
||||
data-view="skills">
|
||||
<i class="fas fa-bolt item-icon text-xs w-5 text-center"></i>
|
||||
<span data-i18n="menu_skills">Skills</span>
|
||||
</a>
|
||||
<a class="sidebar-item flex items-center gap-3 px-3 py-2 rounded-lg cursor-pointer transition-all duration-150 hover:bg-white/5 hover:text-neutral-200 text-[14px]"
|
||||
data-view="memory">
|
||||
<i class="fas fa-brain item-icon text-xs w-5 text-center"></i>
|
||||
<span data-i18n="menu_memory">Memory</span>
|
||||
</a>
|
||||
<a class="sidebar-item flex items-center gap-3 px-3 py-2 rounded-lg cursor-pointer transition-all duration-150 hover:bg-white/5 hover:text-neutral-200 text-[14px]"
|
||||
data-view="channels">
|
||||
<i class="fas fa-tower-broadcast item-icon text-xs w-5 text-center"></i>
|
||||
<span data-i18n="menu_channels">Channels</span>
|
||||
</a>
|
||||
<a class="sidebar-item flex items-center gap-3 px-3 py-2 rounded-lg cursor-pointer transition-all duration-150 hover:bg-white/5 hover:text-neutral-200 text-[14px]"
|
||||
data-view="tasks">
|
||||
<i class="fas fa-clock item-icon text-xs w-5 text-center"></i>
|
||||
<span data-i18n="menu_tasks">Tasks</span>
|
||||
</a>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Monitor Group -->
|
||||
<div class="menu-group open" data-group="monitor">
|
||||
<button class="w-full flex items-center gap-2 px-3 py-2 text-xs font-semibold uppercase tracking-wider text-neutral-500 hover:text-neutral-300 cursor-pointer transition-colors duration-150">
|
||||
<i class="fas fa-chevron-right text-[10px] chevron"></i>
|
||||
<span data-i18n="nav_monitor">Monitor</span>
|
||||
</button>
|
||||
<div class="menu-group-items pl-2">
|
||||
<a class="sidebar-item flex items-center gap-3 px-3 py-2 rounded-lg cursor-pointer transition-all duration-150 hover:bg-white/5 hover:text-neutral-200 text-[14px]"
|
||||
data-view="logs">
|
||||
<i class="fas fa-terminal item-icon text-xs w-5 text-center"></i>
|
||||
<span data-i18n="menu_logs">Logs</span>
|
||||
</a>
|
||||
</div>
|
||||
</div>
|
||||
</nav>
|
||||
|
||||
<!-- Sidebar Footer -->
|
||||
<div class="px-4 py-3 border-t border-white/10 flex-shrink-0">
|
||||
<div class="flex items-center gap-2 text-xs text-neutral-600">
|
||||
<i class="fas fa-circle text-[6px] text-primary-400"></i>
|
||||
<a id="sidebar-version"
|
||||
href="https://github.com/zhayujie/chatgpt-on-wechat/releases"
|
||||
target="_blank" rel="noopener noreferrer"
|
||||
class="hover:text-primary-400 transition-colors duration-150 cursor-pointer"></a>
|
||||
</div>
|
||||
</div>
|
||||
</aside>
|
||||
|
||||
<!-- Mobile Overlay -->
|
||||
<div id="sidebar-overlay" class="fixed inset-0 bg-black/50 z-40 hidden lg:hidden cursor-pointer" onclick="toggleSidebar()"></div>
|
||||
|
||||
<!-- ================================================================ -->
|
||||
<!-- MAIN CONTENT -->
|
||||
<!-- ================================================================ -->
|
||||
<div id="main-content" class="flex-1 flex flex-col min-w-0 h-screen">
|
||||
<!-- Top Header -->
|
||||
<header class="h-14 flex items-center gap-3 px-4 border-b border-slate-200 dark:border-white/10 bg-white dark:bg-[#1A1A1A] flex-shrink-0 z-10">
|
||||
<!-- Mobile menu toggle -->
|
||||
<button id="menu-toggle" class="lg:hidden p-2 rounded-lg hover:bg-slate-100 dark:hover:bg-white/10 cursor-pointer transition-colors duration-150"
|
||||
onclick="toggleSidebar()">
|
||||
<i class="fas fa-bars text-slate-600 dark:text-slate-300"></i>
|
||||
</button>
|
||||
|
||||
<!-- Breadcrumb -->
|
||||
<div class="flex items-center gap-2 text-sm min-w-0">
|
||||
<span id="breadcrumb-group" class="text-slate-400 dark:text-slate-500 truncate" data-i18n="nav_chat">Chat</span>
|
||||
<i class="fas fa-chevron-right text-[10px] text-slate-300 dark:text-slate-600"></i>
|
||||
<span id="breadcrumb-page" class="font-medium text-slate-700 dark:text-slate-200 truncate" data-i18n="menu_chat">Chat</span>
|
||||
</div>
|
||||
|
||||
<div class="flex-1"></div>
|
||||
|
||||
<!-- Language Toggle -->
|
||||
<button id="lang-toggle" class="flex items-center gap-1.5 px-3 py-1.5 rounded-lg text-sm font-medium
|
||||
text-slate-500 dark:text-slate-400 hover:bg-slate-100 dark:hover:bg-white/10
|
||||
cursor-pointer transition-colors duration-150"
|
||||
onclick="toggleLanguage()">
|
||||
<i class="fas fa-globe text-xs"></i>
|
||||
<span id="lang-label">EN</span>
|
||||
</button>
|
||||
|
||||
<!-- Theme Toggle -->
|
||||
<button id="theme-toggle" class="p-2 rounded-lg text-slate-500 dark:text-slate-400
|
||||
hover:bg-slate-100 dark:hover:bg-white/10
|
||||
cursor-pointer transition-colors duration-150"
|
||||
onclick="toggleTheme()">
|
||||
<i id="theme-icon" class="fas fa-moon"></i>
|
||||
</button>
|
||||
|
||||
<!-- GitHub Link -->
|
||||
<a href="https://github.com/zhayujie/chatgpt-on-wechat" target="_blank" rel="noopener noreferrer"
|
||||
class="p-2 rounded-lg text-slate-500 dark:text-slate-400 hover:bg-slate-100 dark:hover:bg-white/10
|
||||
cursor-pointer transition-colors duration-150">
|
||||
<i class="fab fa-github text-lg"></i>
|
||||
</a>
|
||||
</header>
|
||||
|
||||
<!-- Content Area -->
|
||||
<div id="content-area" class="flex-1 overflow-hidden">
|
||||
|
||||
<!-- ====================================================== -->
|
||||
<!-- VIEW: Chat -->
|
||||
<!-- ====================================================== -->
|
||||
<div id="view-chat" class="view active">
|
||||
<!-- Messages -->
|
||||
<div id="chat-messages" class="flex-1 overflow-y-auto">
|
||||
<!-- Welcome Screen -->
|
||||
<div id="welcome-screen" class="flex flex-col items-center justify-center h-full px-6 py-12">
|
||||
<img src="assets/logo.jpg" alt="CowAgent" class="w-16 h-16 rounded-2xl mb-6 shadow-lg shadow-primary-500/20">
|
||||
<h1 id="welcome-title" class="text-2xl font-bold text-slate-800 dark:text-slate-100 mb-3">CowAgent</h1>
|
||||
<p id="welcome-subtitle" class="text-slate-500 dark:text-slate-400 text-center max-w-lg mb-10 leading-relaxed"
|
||||
data-i18n-html="welcome_subtitle">I can help you answer questions, manage your computer, create and execute skills,<br>and keep growing through long-term memory.</p>
|
||||
|
||||
<div class="grid grid-cols-1 sm:grid-cols-3 gap-4 w-full max-w-2xl">
|
||||
<div class="example-card group bg-white dark:bg-[#1A1A1A] border border-slate-200 dark:border-white/10 rounded-xl p-4
|
||||
cursor-pointer hover:border-primary-300 dark:hover:border-primary-600 hover:shadow-md transition-all duration-200">
|
||||
<div class="flex items-center gap-2 mb-2">
|
||||
<div class="w-7 h-7 rounded-lg bg-blue-50 dark:bg-blue-900/30 flex items-center justify-center">
|
||||
<i class="fas fa-folder-open text-blue-500 text-xs"></i>
|
||||
</div>
|
||||
<span class="font-medium text-sm text-slate-700 dark:text-slate-200" data-i18n="example_sys_title">System</span>
|
||||
</div>
|
||||
<p class="text-sm text-slate-500 dark:text-slate-400 leading-relaxed" data-i18n="example_sys_text">Show me the files in the workspace</p>
|
||||
</div>
|
||||
<div class="example-card group bg-white dark:bg-[#1A1A1A] border border-slate-200 dark:border-white/10 rounded-xl p-4
|
||||
cursor-pointer hover:border-primary-300 dark:hover:border-primary-600 hover:shadow-md transition-all duration-200">
|
||||
<div class="flex items-center gap-2 mb-2">
|
||||
<div class="w-7 h-7 rounded-lg bg-amber-50 dark:bg-amber-900/30 flex items-center justify-center">
|
||||
<i class="fas fa-clock text-amber-500 text-xs"></i>
|
||||
</div>
|
||||
<span class="font-medium text-sm text-slate-700 dark:text-slate-200" data-i18n="example_task_title">Smart Task</span>
|
||||
</div>
|
||||
<p class="text-sm text-slate-500 dark:text-slate-400 leading-relaxed" data-i18n="example_task_text">Remind me to check the server in 5 minutes</p>
|
||||
</div>
|
||||
<div class="example-card group bg-white dark:bg-[#1A1A1A] border border-slate-200 dark:border-white/10 rounded-xl p-4
|
||||
cursor-pointer hover:border-primary-300 dark:hover:border-primary-600 hover:shadow-md transition-all duration-200">
|
||||
<div class="flex items-center gap-2 mb-2">
|
||||
<div class="w-7 h-7 rounded-lg bg-emerald-50 dark:bg-emerald-900/30 flex items-center justify-center">
|
||||
<i class="fas fa-code text-emerald-500 text-xs"></i>
|
||||
</div>
|
||||
<span class="font-medium text-sm text-slate-700 dark:text-slate-200" data-i18n="example_code_title">Coding</span>
|
||||
</div>
|
||||
<p class="text-sm text-slate-500 dark:text-slate-400 leading-relaxed" data-i18n="example_code_text">Write a Python web scraper script</p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Chat Input -->
|
||||
<div class="flex-shrink-0 border-t border-slate-200 dark:border-white/10 bg-white dark:bg-[#1A1A1A] px-4 py-3">
|
||||
<div class="max-w-3xl mx-auto flex items-center gap-2">
|
||||
<button id="new-chat-btn" class="flex-shrink-0 w-10 h-10 flex items-center justify-center rounded-lg
|
||||
text-slate-400 hover:text-primary-500 hover:bg-primary-50 dark:hover:bg-primary-900/20
|
||||
cursor-pointer transition-colors duration-150" title="New Chat"
|
||||
onclick="newChat()">
|
||||
<i class="fas fa-plus text-base"></i>
|
||||
</button>
|
||||
<textarea id="chat-input"
|
||||
class="flex-1 min-w-0 px-4 py-[10px] rounded-xl border border-slate-200 dark:border-slate-600
|
||||
bg-slate-50 dark:bg-white/5 text-slate-800 dark:text-slate-100
|
||||
placeholder:text-slate-400 dark:placeholder:text-slate-500
|
||||
focus:outline-none focus:ring-0 focus:border-primary-600
|
||||
text-sm leading-relaxed"
|
||||
rows="1"
|
||||
data-i18n-placeholder="input_placeholder"
|
||||
placeholder="Type a message..."></textarea>
|
||||
<button id="send-btn"
|
||||
class="flex-shrink-0 w-10 h-10 flex items-center justify-center rounded-lg
|
||||
bg-primary-400 text-white hover:bg-primary-500
|
||||
disabled:bg-slate-300 dark:disabled:bg-slate-600
|
||||
disabled:cursor-not-allowed cursor-pointer transition-colors duration-150"
|
||||
disabled onclick="sendMessage()">
|
||||
<i class="fas fa-paper-plane text-sm"></i>
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- ====================================================== -->
|
||||
<!-- VIEW: Config -->
|
||||
<!-- ====================================================== -->
|
||||
<div id="view-config" class="view">
|
||||
<div class="flex-1 overflow-y-auto p-6">
|
||||
<div class="max-w-4xl mx-auto">
|
||||
<div class="flex items-center justify-between mb-6">
|
||||
<div>
|
||||
<h2 class="text-xl font-bold text-slate-800 dark:text-slate-100" data-i18n="config_title">Configuration</h2>
|
||||
<p class="text-sm text-slate-500 dark:text-slate-400 mt-1" data-i18n="config_desc">Manage model and agent settings</p>
|
||||
</div>
|
||||
</div>
|
||||
<div class="grid gap-6">
|
||||
|
||||
<!-- Model Config Card -->
|
||||
<div class="bg-white dark:bg-[#1A1A1A] rounded-xl border border-slate-200 dark:border-white/10 p-6">
|
||||
<div class="flex items-center gap-3 mb-5">
|
||||
<div class="w-9 h-9 rounded-lg bg-primary-50 dark:bg-primary-900/30 flex items-center justify-center">
|
||||
<i class="fas fa-microchip text-primary-500 text-sm"></i>
|
||||
</div>
|
||||
<h3 class="font-semibold text-slate-800 dark:text-slate-100" data-i18n="config_model">Model Configuration</h3>
|
||||
</div>
|
||||
<div class="space-y-5">
|
||||
<!-- Provider -->
|
||||
<div>
|
||||
<label class="block text-sm font-medium text-slate-600 dark:text-slate-400 mb-1.5" data-i18n="config_provider">Provider</label>
|
||||
<div id="cfg-provider" class="cfg-dropdown" tabindex="0">
|
||||
<div class="cfg-dropdown-selected">
|
||||
<span class="cfg-dropdown-text">--</span>
|
||||
<i class="fas fa-chevron-down cfg-dropdown-arrow"></i>
|
||||
</div>
|
||||
<div class="cfg-dropdown-menu"></div>
|
||||
</div>
|
||||
</div>
|
||||
<!-- Model -->
|
||||
<div>
|
||||
<label class="block text-sm font-medium text-slate-600 dark:text-slate-400 mb-1.5" data-i18n="config_model_name">Model</label>
|
||||
<div id="cfg-model-select" class="cfg-dropdown" tabindex="0">
|
||||
<div class="cfg-dropdown-selected">
|
||||
<span class="cfg-dropdown-text">--</span>
|
||||
<i class="fas fa-chevron-down cfg-dropdown-arrow"></i>
|
||||
</div>
|
||||
<div class="cfg-dropdown-menu"></div>
|
||||
</div>
|
||||
<div id="cfg-model-custom-wrap" class="mt-2 hidden">
|
||||
<input id="cfg-model-custom" type="text"
|
||||
class="w-full px-3 py-2 rounded-lg border border-slate-200 dark:border-slate-600
|
||||
bg-slate-50 dark:bg-white/5 text-sm text-slate-800 dark:text-slate-100
|
||||
focus:outline-none focus:border-primary-500 font-mono transition-colors"
|
||||
data-i18n-placeholder="config_custom_model_hint" placeholder="Enter custom model name">
|
||||
</div>
|
||||
</div>
|
||||
<!-- API Key -->
|
||||
<div id="cfg-api-key-wrap">
|
||||
<label class="block text-sm font-medium text-slate-600 dark:text-slate-400 mb-1.5">API Key</label>
|
||||
<div class="relative">
|
||||
<input id="cfg-api-key" type="text" autocomplete="off" data-1p-ignore data-lpignore="true"
|
||||
class="w-full px-3 py-2 pr-10 rounded-lg border border-slate-200 dark:border-slate-600
|
||||
bg-slate-50 dark:bg-white/5 text-sm text-slate-800 dark:text-slate-100
|
||||
focus:outline-none focus:border-primary-500 font-mono transition-colors cfg-key-masked"
|
||||
placeholder="sk-...">
|
||||
<button type="button" id="cfg-api-key-toggle"
|
||||
class="absolute right-2.5 top-1/2 -translate-y-1/2 text-slate-400 hover:text-slate-600
|
||||
dark:hover:text-slate-300 cursor-pointer transition-colors p-1"
|
||||
onclick="toggleApiKeyVisibility()">
|
||||
<i class="fas fa-eye text-xs"></i>
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
<!-- API Base -->
|
||||
<div id="cfg-api-base-wrap" class="hidden">
|
||||
<label class="block text-sm font-medium text-slate-600 dark:text-slate-400 mb-1.5">API Base</label>
|
||||
<input id="cfg-api-base" type="text"
|
||||
class="w-full px-3 py-2 rounded-lg border border-slate-200 dark:border-slate-600
|
||||
bg-slate-50 dark:bg-white/5 text-sm text-slate-800 dark:text-slate-100
|
||||
focus:outline-none focus:border-primary-500 font-mono transition-colors"
|
||||
placeholder="https://...">
|
||||
</div>
|
||||
<!-- Save Model Button -->
|
||||
<div class="flex items-center justify-end gap-3 pt-1">
|
||||
<span id="cfg-model-status" class="text-xs text-primary-500 opacity-0 transition-opacity duration-300"></span>
|
||||
<button id="cfg-model-save"
|
||||
class="px-4 py-2 rounded-lg bg-primary-500 hover:bg-primary-600 text-white text-sm font-medium
|
||||
cursor-pointer transition-colors duration-150 disabled:opacity-50 disabled:cursor-not-allowed"
|
||||
onclick="saveModelConfig()" data-i18n="config_save">Save</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Agent Config Card -->
|
||||
<div class="bg-white dark:bg-[#1A1A1A] rounded-xl border border-slate-200 dark:border-white/10 p-6">
|
||||
<div class="flex items-center gap-3 mb-5">
|
||||
<div class="w-9 h-9 rounded-lg bg-emerald-50 dark:bg-emerald-900/30 flex items-center justify-center">
|
||||
<i class="fas fa-robot text-emerald-500 text-sm"></i>
|
||||
</div>
|
||||
<h3 class="font-semibold text-slate-800 dark:text-slate-100" data-i18n="config_agent">Agent Configuration</h3>
|
||||
</div>
|
||||
<div class="space-y-4">
|
||||
<div>
|
||||
<label class="block text-sm font-medium text-slate-600 dark:text-slate-400 mb-1.5" data-i18n="config_max_tokens">Max Context Tokens</label>
|
||||
<input id="cfg-max-tokens" type="number" min="1000" max="200000" step="1000"
|
||||
class="w-full px-3 py-2 rounded-lg border border-slate-200 dark:border-slate-600
|
||||
bg-slate-50 dark:bg-white/5 text-sm text-slate-800 dark:text-slate-100
|
||||
focus:outline-none focus:border-primary-500 font-mono transition-colors">
|
||||
</div>
|
||||
<div>
|
||||
<label class="block text-sm font-medium text-slate-600 dark:text-slate-400 mb-1.5" data-i18n="config_max_turns">Max Context Turns</label>
|
||||
<input id="cfg-max-turns" type="number" min="1" max="100" step="1"
|
||||
class="w-full px-3 py-2 rounded-lg border border-slate-200 dark:border-slate-600
|
||||
bg-slate-50 dark:bg-white/5 text-sm text-slate-800 dark:text-slate-100
|
||||
focus:outline-none focus:border-primary-500 font-mono transition-colors">
|
||||
</div>
|
||||
<div>
|
||||
<label class="block text-sm font-medium text-slate-600 dark:text-slate-400 mb-1.5" data-i18n="config_max_steps">Max Steps</label>
|
||||
<input id="cfg-max-steps" type="number" min="1" max="50" step="1"
|
||||
class="w-full px-3 py-2 rounded-lg border border-slate-200 dark:border-slate-600
|
||||
bg-slate-50 dark:bg-white/5 text-sm text-slate-800 dark:text-slate-100
|
||||
focus:outline-none focus:border-primary-500 font-mono transition-colors">
|
||||
</div>
|
||||
<div class="flex items-center justify-end gap-3 pt-1">
|
||||
<span id="cfg-agent-status" class="text-xs text-primary-500 opacity-0 transition-opacity duration-300"></span>
|
||||
<button id="cfg-agent-save"
|
||||
class="px-4 py-2 rounded-lg bg-primary-500 hover:bg-primary-600 text-white text-sm font-medium
|
||||
cursor-pointer transition-colors duration-150 disabled:opacity-50 disabled:cursor-not-allowed"
|
||||
onclick="saveAgentConfig()" data-i18n="config_save">Save</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- ====================================================== -->
|
||||
<!-- VIEW: Skills -->
|
||||
<!-- ====================================================== -->
|
||||
<div id="view-skills" class="view">
|
||||
<div class="flex-1 overflow-y-auto p-6">
|
||||
<div class="max-w-4xl mx-auto">
|
||||
<div class="flex items-center justify-between mb-6">
|
||||
<div>
|
||||
<h2 class="text-xl font-bold text-slate-800 dark:text-slate-100" data-i18n="skills_title">Skills</h2>
|
||||
<p class="text-sm text-slate-500 dark:text-slate-400 mt-1" data-i18n="skills_desc">View, enable, or disable agent skills</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Built-in Tools Section -->
|
||||
<div class="mb-8">
|
||||
<div class="flex items-center gap-2 mb-3">
|
||||
<span class="text-xs font-semibold uppercase tracking-wider text-slate-400 dark:text-slate-500" data-i18n="tools_section_title">Built-in Tools</span>
|
||||
<span id="tools-count-badge" class="hidden px-2 py-0.5 rounded-full text-xs bg-slate-100 dark:bg-white/10 text-slate-500 dark:text-slate-400"></span>
|
||||
</div>
|
||||
<div id="tools-empty" class="flex items-center gap-2 py-4 text-slate-400 dark:text-slate-500 text-sm">
|
||||
<i class="fas fa-spinner fa-spin text-xs"></i>
|
||||
<span data-i18n="tools_loading">Loading tools...</span>
|
||||
</div>
|
||||
<div id="tools-list" class="grid gap-3 sm:grid-cols-2 hidden"></div>
|
||||
</div>
|
||||
|
||||
<!-- Skills Section -->
|
||||
<div>
|
||||
<div class="flex items-center gap-2 mb-3">
|
||||
<span class="text-xs font-semibold uppercase tracking-wider text-slate-400 dark:text-slate-500" data-i18n="skills_section_title">Skills</span>
|
||||
<span id="skills-count-badge" class="hidden px-2 py-0.5 rounded-full text-xs bg-slate-100 dark:bg-white/10 text-slate-500 dark:text-slate-400"></span>
|
||||
</div>
|
||||
<div id="skills-empty" class="flex flex-col items-center justify-center py-12">
|
||||
<div class="w-14 h-14 rounded-2xl bg-amber-50 dark:bg-amber-900/20 flex items-center justify-center mb-3">
|
||||
<i class="fas fa-bolt text-amber-400 text-lg"></i>
|
||||
</div>
|
||||
<p class="text-slate-500 dark:text-slate-400 font-medium" data-i18n="skills_loading">Loading skills...</p>
|
||||
<p class="text-sm text-slate-400 dark:text-slate-500 mt-1" data-i18n="skills_loading_desc">Skills will be displayed here after loading</p>
|
||||
</div>
|
||||
<div id="skills-list" class="grid gap-4 sm:grid-cols-2"></div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- ====================================================== -->
|
||||
<!-- VIEW: Memory -->
|
||||
<!-- ====================================================== -->
|
||||
<div id="view-memory" class="view">
|
||||
<div class="flex-1 overflow-y-auto p-6">
|
||||
<div class="max-w-4xl mx-auto">
|
||||
|
||||
<!-- Panel: list -->
|
||||
<div id="memory-panel-list">
|
||||
<div class="flex items-center justify-between mb-6">
|
||||
<div>
|
||||
<h2 class="text-xl font-bold text-slate-800 dark:text-slate-100" data-i18n="memory_title">Memory</h2>
|
||||
<p class="text-sm text-slate-500 dark:text-slate-400 mt-1" data-i18n="memory_desc">View agent memory files and contents</p>
|
||||
</div>
|
||||
</div>
|
||||
<div id="memory-empty" class="flex flex-col items-center justify-center py-20">
|
||||
<div class="w-16 h-16 rounded-2xl bg-purple-50 dark:bg-purple-900/20 flex items-center justify-center mb-4">
|
||||
<i class="fas fa-brain text-purple-400 text-xl"></i>
|
||||
</div>
|
||||
<p class="text-slate-500 dark:text-slate-400 font-medium" data-i18n="memory_loading">Loading memory files...</p>
|
||||
<p class="text-sm text-slate-400 dark:text-slate-500 mt-1" data-i18n="memory_loading_desc">Memory files will be displayed here</p>
|
||||
</div>
|
||||
<div id="memory-list" class="hidden">
|
||||
<div class="bg-white dark:bg-[#1A1A1A] rounded-xl border border-slate-200 dark:border-white/10 overflow-hidden">
|
||||
<table class="w-full">
|
||||
<thead>
|
||||
<tr class="border-b border-slate-200 dark:border-white/10">
|
||||
<th class="text-left px-4 py-3 text-xs font-semibold uppercase tracking-wider text-slate-500 dark:text-slate-400" data-i18n="memory_col_name">Filename</th>
|
||||
<th class="text-left px-4 py-3 text-xs font-semibold uppercase tracking-wider text-slate-500 dark:text-slate-400" data-i18n="memory_col_type">Type</th>
|
||||
<th class="text-left px-4 py-3 text-xs font-semibold uppercase tracking-wider text-slate-500 dark:text-slate-400" data-i18n="memory_col_size">Size</th>
|
||||
<th class="text-left px-4 py-3 text-xs font-semibold uppercase tracking-wider text-slate-500 dark:text-slate-400" data-i18n="memory_col_updated">Updated</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody id="memory-table-body"></tbody>
|
||||
</table>
|
||||
</div>
|
||||
<div id="memory-pagination" class="flex items-center justify-between mt-4 text-sm text-slate-500 dark:text-slate-400"></div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Panel: file viewer (replaces list) -->
|
||||
<div id="memory-panel-viewer" class="hidden">
|
||||
<div class="flex items-center gap-3 mb-6">
|
||||
<button onclick="closeMemoryViewer()"
|
||||
class="flex items-center gap-1.5 px-3 py-1.5 rounded-lg text-sm
|
||||
text-slate-500 dark:text-slate-400 hover:bg-slate-100 dark:hover:bg-white/10
|
||||
border border-slate-200 dark:border-white/10 transition-colors cursor-pointer">
|
||||
<i class="fas fa-arrow-left text-xs"></i>
|
||||
<span data-i18n="memory_back">Back</span>
|
||||
</button>
|
||||
<h2 id="memory-viewer-title"
|
||||
class="text-base font-semibold text-slate-800 dark:text-slate-100 font-mono truncate"></h2>
|
||||
</div>
|
||||
<div class="bg-white dark:bg-[#1A1A1A] rounded-xl border border-slate-200 dark:border-white/10 overflow-hidden">
|
||||
<div id="memory-viewer-content"
|
||||
class="p-5 overflow-y-auto text-sm msg-content text-slate-700 dark:text-slate-200"
|
||||
style="max-height: calc(100vh - 220px)"></div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- ====================================================== -->
|
||||
<!-- VIEW: Channels -->
|
||||
<!-- ====================================================== -->
|
||||
<div id="view-channels" class="view">
|
||||
<div class="flex-1 overflow-y-auto p-6">
|
||||
<div class="max-w-4xl mx-auto">
|
||||
<div class="flex items-center justify-between mb-6">
|
||||
<div>
|
||||
<h2 class="text-xl font-bold text-slate-800 dark:text-slate-100" data-i18n="channels_title">Channels</h2>
|
||||
<p class="text-sm text-slate-500 dark:text-slate-400 mt-1" data-i18n="channels_desc">View and manage messaging channels</p>
|
||||
</div>
|
||||
<button id="add-channel-btn" onclick="openAddChannelPanel()"
|
||||
class="flex items-center gap-2 px-4 py-2 rounded-lg bg-primary-500 hover:bg-primary-600
|
||||
text-white text-sm font-medium cursor-pointer transition-colors duration-150">
|
||||
<i class="fas fa-plus text-xs"></i>
|
||||
<span data-i18n="channels_add">Connect</span>
|
||||
</button>
|
||||
</div>
|
||||
<div id="channels-content" class="grid gap-4"></div>
|
||||
<div id="channels-add-panel" class="hidden mt-4"></div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- ====================================================== -->
|
||||
<!-- VIEW: Tasks -->
|
||||
<!-- ====================================================== -->
|
||||
<div id="view-tasks" class="view">
|
||||
<div class="flex-1 overflow-y-auto p-6">
|
||||
<div class="max-w-4xl mx-auto">
|
||||
<div class="flex items-center justify-between mb-6">
|
||||
<div>
|
||||
<h2 class="text-xl font-bold text-slate-800 dark:text-slate-100" data-i18n="tasks_title">Scheduled Tasks</h2>
|
||||
<p class="text-sm text-slate-500 dark:text-slate-400 mt-1" data-i18n="tasks_desc">View and manage scheduled tasks</p>
|
||||
</div>
|
||||
</div>
|
||||
<div id="tasks-empty" class="flex flex-col items-center justify-center py-20">
|
||||
<div class="w-16 h-16 rounded-2xl bg-rose-50 dark:bg-rose-900/20 flex items-center justify-center mb-4">
|
||||
<i class="fas fa-clock text-rose-400 text-xl"></i>
|
||||
</div>
|
||||
<p class="text-slate-500 dark:text-slate-400 font-medium">Loading...</p>
|
||||
</div>
|
||||
<div id="tasks-list" class="grid gap-4 hidden"></div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- ====================================================== -->
|
||||
<!-- VIEW: Logs -->
|
||||
<!-- ====================================================== -->
|
||||
<div id="view-logs" class="view">
|
||||
<div class="flex-1 overflow-y-auto p-6">
|
||||
<div class="max-w-5xl mx-auto">
|
||||
<div class="flex items-center justify-between mb-6">
|
||||
<div>
|
||||
<h2 class="text-xl font-bold text-slate-800 dark:text-slate-100" data-i18n="logs_title">Logs</h2>
|
||||
<p class="text-sm text-slate-500 dark:text-slate-400 mt-1" data-i18n="logs_desc">Real-time log output (run.log)</p>
|
||||
</div>
|
||||
</div>
|
||||
<!-- Log Terminal -->
|
||||
<div class="bg-slate-900 rounded-xl border border-slate-700 overflow-hidden shadow-lg">
|
||||
<div class="flex items-center gap-2 px-4 py-2.5 bg-slate-800 border-b border-slate-700">
|
||||
<div class="flex gap-1.5">
|
||||
<span class="w-3 h-3 rounded-full bg-red-500/80"></span>
|
||||
<span class="w-3 h-3 rounded-full bg-amber-500/80"></span>
|
||||
<span class="w-3 h-3 rounded-full bg-emerald-500/80"></span>
|
||||
</div>
|
||||
<span class="text-xs text-slate-400 ml-2 font-mono">run.log</span>
|
||||
<div class="flex-1"></div>
|
||||
<div class="flex items-center gap-1.5">
|
||||
<span class="w-2 h-2 rounded-full bg-emerald-500 animate-pulse"></span>
|
||||
<span class="text-xs text-slate-500" data-i18n="logs_live">Live</span>
|
||||
</div>
|
||||
</div>
|
||||
<div id="log-output" class="p-4 overflow-y-auto font-mono text-xs leading-relaxed text-slate-300 whitespace-pre-wrap break-all" style="height: calc(100vh - 272px)">
|
||||
<p class="text-slate-500" data-i18n="logs_coming_msg">Log streaming will be available here. Connects to run.log for real-time output similar to tail -f.</p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
</div><!-- /content-area -->
|
||||
</div><!-- /main-content -->
|
||||
</div><!-- /app -->
|
||||
|
||||
<!-- Confirm Dialog -->
|
||||
<div id="confirm-dialog-overlay" class="fixed inset-0 bg-black/50 z-[100] hidden flex items-center justify-center">
|
||||
<div class="bg-white dark:bg-[#1A1A1A] rounded-2xl border border-slate-200 dark:border-white/10 shadow-xl
|
||||
w-full max-w-sm mx-4 overflow-hidden">
|
||||
<div class="p-6">
|
||||
<div class="flex items-center gap-3 mb-3">
|
||||
<div class="w-10 h-10 rounded-xl bg-red-50 dark:bg-red-900/20 flex items-center justify-center flex-shrink-0">
|
||||
<i class="fas fa-triangle-exclamation text-red-500"></i>
|
||||
</div>
|
||||
<h3 id="confirm-dialog-title" class="font-semibold text-slate-800 dark:text-slate-100 text-base"></h3>
|
||||
</div>
|
||||
<p id="confirm-dialog-message" class="text-sm text-slate-500 dark:text-slate-400 leading-relaxed ml-[52px]"></p>
|
||||
</div>
|
||||
<div class="flex items-center justify-end gap-3 px-6 py-4 border-t border-slate-100 dark:border-white/5">
|
||||
<button id="confirm-dialog-cancel"
|
||||
class="px-4 py-2 rounded-lg border border-slate-200 dark:border-white/10
|
||||
text-slate-600 dark:text-slate-300 text-sm font-medium
|
||||
hover:bg-slate-50 dark:hover:bg-white/5
|
||||
cursor-pointer transition-colors duration-150"></button>
|
||||
<button id="confirm-dialog-ok"
|
||||
class="px-4 py-2 rounded-lg bg-red-500 hover:bg-red-600 text-white text-sm font-medium
|
||||
cursor-pointer transition-colors duration-150"></button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<script src="assets/js/console.js"></script>
|
||||
</body>
|
||||
</html>
|
||||
2
channel/web/static/axios.min.js
vendored
Normal file
2
channel/web/static/axios.min.js
vendored
Normal file
File diff suppressed because one or more lines are too long
354
channel/web/static/css/console.css
Normal file
354
channel/web/static/css/console.css
Normal file
@@ -0,0 +1,354 @@
|
||||
/* =====================================================================
|
||||
CowAgent Console Styles
|
||||
===================================================================== */
|
||||
|
||||
/* Animations */
|
||||
@keyframes pulseDot {
|
||||
0%, 80%, 100% { transform: scale(0.6); opacity: 0.4; }
|
||||
40% { transform: scale(1); opacity: 1; }
|
||||
}
|
||||
|
||||
/* Scrollbar */
|
||||
* { scrollbar-width: thin; scrollbar-color: #94a3b8 transparent; }
|
||||
::-webkit-scrollbar { width: 6px; height: 6px; }
|
||||
::-webkit-scrollbar-track { background: transparent; }
|
||||
::-webkit-scrollbar-thumb { background: #94a3b8; border-radius: 3px; }
|
||||
::-webkit-scrollbar-thumb:hover { background: #64748b; }
|
||||
.dark ::-webkit-scrollbar-thumb { background: #475569; }
|
||||
.dark ::-webkit-scrollbar-thumb:hover { background: #64748b; }
|
||||
|
||||
/* Sidebar */
|
||||
.sidebar-item.active {
|
||||
background: rgba(255, 255, 255, 0.08);
|
||||
color: #FFFFFF;
|
||||
}
|
||||
.sidebar-item.active .item-icon { color: #4ABE6E; }
|
||||
|
||||
/* Menu Groups */
|
||||
.menu-group-items { max-height: 0; overflow: hidden; transition: max-height 0.25s ease-out; }
|
||||
.menu-group.open .menu-group-items { max-height: 500px; transition: max-height 0.35s ease-in; }
|
||||
.menu-group .chevron { transition: transform 0.25s ease; }
|
||||
.menu-group.open .chevron { transform: rotate(90deg); }
|
||||
|
||||
/* View Switching */
|
||||
.view { display: none; height: 100%; }
|
||||
.view.active { display: flex; flex-direction: column; }
|
||||
|
||||
/* Markdown Content */
|
||||
.msg-content p { margin: 0.5em 0; line-height: 1.7; }
|
||||
.msg-content p:first-child { margin-top: 0; }
|
||||
.msg-content p:last-child { margin-bottom: 0; }
|
||||
.msg-content h1, .msg-content h2, .msg-content h3,
|
||||
.msg-content h4, .msg-content h5, .msg-content h6 {
|
||||
margin-top: 1.2em; margin-bottom: 0.6em; font-weight: 600; line-height: 1.3;
|
||||
}
|
||||
.msg-content h1 { font-size: 1.4em; }
|
||||
.msg-content h2 { font-size: 1.25em; }
|
||||
.msg-content h3 { font-size: 1.1em; }
|
||||
.msg-content ul, .msg-content ol { margin: 0.5em 0; padding-left: 1.8em; }
|
||||
.msg-content li { margin: 0.25em 0; }
|
||||
.msg-content pre {
|
||||
border-radius: 8px; overflow-x: auto; margin: 0.8em 0;
|
||||
background: #f1f5f9; padding: 1em;
|
||||
}
|
||||
.dark .msg-content pre { background: #111111; }
|
||||
.msg-content code {
|
||||
font-family: 'JetBrains Mono', 'Fira Code', Consolas, monospace;
|
||||
font-size: 0.875em;
|
||||
}
|
||||
.msg-content :not(pre) > code {
|
||||
background: rgba(74, 190, 110, 0.1); color: #1C6B3B;
|
||||
padding: 2px 6px; border-radius: 4px;
|
||||
}
|
||||
.dark .msg-content :not(pre) > code {
|
||||
background: rgba(74, 190, 110, 0.15); color: #74E9A4;
|
||||
}
|
||||
.msg-content pre code { background: transparent; padding: 0; color: inherit; }
|
||||
.msg-content blockquote {
|
||||
border-left: 3px solid #4ABE6E; padding: 0.5em 1em;
|
||||
margin: 0.8em 0; background: rgba(74, 190, 110, 0.05); border-radius: 0 6px 6px 0;
|
||||
}
|
||||
.dark .msg-content blockquote { background: rgba(74, 190, 110, 0.08); }
|
||||
.msg-content table { border-collapse: collapse; width: 100%; margin: 0.8em 0; }
|
||||
.msg-content th, .msg-content td {
|
||||
border: 1px solid #e2e8f0; padding: 8px 12px; text-align: left;
|
||||
}
|
||||
.dark .msg-content th, .dark .msg-content td { border-color: rgba(255,255,255,0.1); }
|
||||
.msg-content th { background: #f1f5f9; font-weight: 600; }
|
||||
.dark .msg-content th { background: #111111; }
|
||||
.msg-content img { max-width: 100%; height: auto; border-radius: 8px; margin: 0.5em 0; }
|
||||
.msg-content a { color: #35A85B; text-decoration: underline; }
|
||||
.msg-content a:hover { color: #228547; }
|
||||
.msg-content hr { border: none; height: 1px; background: #e2e8f0; margin: 1.2em 0; }
|
||||
.dark .msg-content hr { background: rgba(255,255,255,0.1); }
|
||||
|
||||
/* SSE Streaming cursor */
|
||||
@keyframes blink { 0%, 100% { opacity: 1; } 50% { opacity: 0; } }
|
||||
.sse-streaming::after {
|
||||
content: '▋';
|
||||
display: inline-block;
|
||||
margin-left: 2px;
|
||||
color: #4ABE6E;
|
||||
animation: blink 0.9s step-end infinite;
|
||||
font-size: 0.85em;
|
||||
vertical-align: middle;
|
||||
}
|
||||
|
||||
/* Agent steps (thinking summaries + tool indicators) */
|
||||
.agent-steps:empty { display: none; }
|
||||
.agent-steps:not(:empty) {
|
||||
margin-bottom: 0.625rem;
|
||||
padding-bottom: 0.5rem;
|
||||
border-bottom: 1px dashed rgba(0, 0, 0, 0.08);
|
||||
}
|
||||
.dark .agent-steps:not(:empty) { border-bottom-color: rgba(255, 255, 255, 0.08); }
|
||||
|
||||
.agent-step {
|
||||
font-size: 0.75rem;
|
||||
line-height: 1.4;
|
||||
color: #94a3b8;
|
||||
margin-bottom: 0.25rem;
|
||||
}
|
||||
.agent-step:last-child { margin-bottom: 0; }
|
||||
|
||||
/* Thinking step - collapsible */
|
||||
.agent-thinking-step .thinking-header {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.375rem;
|
||||
cursor: pointer;
|
||||
user-select: none;
|
||||
}
|
||||
.agent-thinking-step .thinking-header.no-toggle { cursor: default; }
|
||||
.agent-thinking-step .thinking-header:not(.no-toggle):hover { color: #64748b; }
|
||||
.dark .agent-thinking-step .thinking-header:not(.no-toggle):hover { color: #cbd5e1; }
|
||||
.agent-thinking-step .thinking-header i:first-child { font-size: 0.625rem; margin-top: 1px; }
|
||||
.agent-thinking-step .thinking-chevron {
|
||||
font-size: 0.5rem;
|
||||
margin-left: auto;
|
||||
transition: transform 0.2s ease;
|
||||
opacity: 0.5;
|
||||
}
|
||||
.agent-thinking-step.expanded .thinking-chevron { transform: rotate(90deg); }
|
||||
.agent-thinking-step .thinking-full {
|
||||
display: none;
|
||||
margin-top: 0.375rem;
|
||||
margin-left: 1rem;
|
||||
padding: 0.5rem;
|
||||
background: rgba(0, 0, 0, 0.02);
|
||||
border-radius: 6px;
|
||||
border: 1px solid rgba(0, 0, 0, 0.04);
|
||||
font-size: 0.75rem;
|
||||
line-height: 1.5;
|
||||
color: #94a3b8;
|
||||
max-height: 200px;
|
||||
overflow-y: auto;
|
||||
}
|
||||
.dark .agent-thinking-step .thinking-full {
|
||||
background: rgba(255, 255, 255, 0.02);
|
||||
border-color: rgba(255, 255, 255, 0.04);
|
||||
}
|
||||
.agent-thinking-step.expanded .thinking-full { display: block; }
|
||||
.agent-thinking-step .thinking-full p { margin: 0.25em 0; }
|
||||
.agent-thinking-step .thinking-full p:first-child { margin-top: 0; }
|
||||
.agent-thinking-step .thinking-full p:last-child { margin-bottom: 0; }
|
||||
|
||||
/* Tool step - collapsible */
|
||||
.agent-tool-step .tool-header {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.375rem;
|
||||
cursor: pointer;
|
||||
user-select: none;
|
||||
padding: 1px 0;
|
||||
border-radius: 4px;
|
||||
}
|
||||
.agent-tool-step .tool-header:hover { color: #64748b; }
|
||||
.dark .agent-tool-step .tool-header:hover { color: #cbd5e1; }
|
||||
.agent-tool-step .tool-icon { font-size: 0.625rem; }
|
||||
.agent-tool-step .tool-chevron {
|
||||
font-size: 0.5rem;
|
||||
margin-left: auto;
|
||||
transition: transform 0.2s ease;
|
||||
opacity: 0.5;
|
||||
}
|
||||
.agent-tool-step.expanded .tool-chevron { transform: rotate(90deg); }
|
||||
.agent-tool-step .tool-time {
|
||||
font-size: 0.65rem;
|
||||
opacity: 0.6;
|
||||
margin-left: 0.25rem;
|
||||
}
|
||||
|
||||
/* Tool detail panel */
|
||||
.agent-tool-step .tool-detail {
|
||||
display: none;
|
||||
margin-top: 0.375rem;
|
||||
margin-left: 1rem;
|
||||
padding: 0.5rem;
|
||||
background: rgba(0, 0, 0, 0.02);
|
||||
border-radius: 6px;
|
||||
border: 1px solid rgba(0, 0, 0, 0.04);
|
||||
}
|
||||
.dark .agent-tool-step .tool-detail {
|
||||
background: rgba(255, 255, 255, 0.02);
|
||||
border-color: rgba(255, 255, 255, 0.04);
|
||||
}
|
||||
.agent-tool-step.expanded .tool-detail { display: block; }
|
||||
.tool-detail-section { margin-bottom: 0.375rem; }
|
||||
.tool-detail-section:last-child { margin-bottom: 0; }
|
||||
.tool-detail-label {
|
||||
font-size: 0.625rem;
|
||||
font-weight: 600;
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.05em;
|
||||
opacity: 0.6;
|
||||
margin-bottom: 0.125rem;
|
||||
}
|
||||
.tool-detail-content {
|
||||
font-family: 'JetBrains Mono', 'Fira Code', Consolas, monospace;
|
||||
font-size: 0.7rem;
|
||||
line-height: 1.5;
|
||||
white-space: pre-wrap;
|
||||
word-break: break-all;
|
||||
max-height: 200px;
|
||||
overflow-y: auto;
|
||||
margin: 0;
|
||||
padding: 0.25rem 0;
|
||||
background: transparent;
|
||||
color: inherit;
|
||||
}
|
||||
.tool-error-text { color: #f87171; }
|
||||
|
||||
/* Tool failed state */
|
||||
.agent-tool-step.tool-failed .tool-name { color: #f87171; }
|
||||
|
||||
/* Config form controls */
|
||||
#view-config input[type="text"],
|
||||
#view-config input[type="number"],
|
||||
#view-config input[type="password"] {
|
||||
height: 40px;
|
||||
transition: border-color 0.2s ease, box-shadow 0.2s ease;
|
||||
}
|
||||
#view-config input:focus {
|
||||
border-color: #4ABE6E;
|
||||
box-shadow: 0 0 0 3px rgba(74, 190, 110, 0.12);
|
||||
}
|
||||
#view-config input[type="text"]:hover,
|
||||
#view-config input[type="number"]:hover,
|
||||
#view-config input[type="password"]:hover {
|
||||
border-color: #94a3b8;
|
||||
}
|
||||
.dark #view-config input[type="text"]:hover,
|
||||
.dark #view-config input[type="number"]:hover,
|
||||
.dark #view-config input[type="password"]:hover {
|
||||
border-color: #64748b;
|
||||
}
|
||||
|
||||
/* Custom dropdown */
|
||||
.cfg-dropdown {
|
||||
position: relative;
|
||||
outline: none;
|
||||
}
|
||||
.cfg-dropdown-selected {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: space-between;
|
||||
height: 40px;
|
||||
padding: 0 0.75rem;
|
||||
border-radius: 0.5rem;
|
||||
border: 1px solid #e2e8f0;
|
||||
background: #f8fafc;
|
||||
font-size: 0.875rem;
|
||||
color: #1e293b;
|
||||
cursor: pointer;
|
||||
transition: border-color 0.2s ease, box-shadow 0.2s ease;
|
||||
user-select: none;
|
||||
}
|
||||
.dark .cfg-dropdown-selected {
|
||||
border-color: #475569;
|
||||
background: rgba(255, 255, 255, 0.05);
|
||||
color: #f1f5f9;
|
||||
}
|
||||
.cfg-dropdown-selected:hover { border-color: #94a3b8; }
|
||||
.dark .cfg-dropdown-selected:hover { border-color: #64748b; }
|
||||
.cfg-dropdown.open .cfg-dropdown-selected,
|
||||
.cfg-dropdown:focus .cfg-dropdown-selected {
|
||||
border-color: #4ABE6E;
|
||||
box-shadow: 0 0 0 3px rgba(74, 190, 110, 0.12);
|
||||
}
|
||||
.cfg-dropdown-arrow {
|
||||
font-size: 0.625rem;
|
||||
color: #94a3b8;
|
||||
transition: transform 0.2s ease;
|
||||
flex-shrink: 0;
|
||||
margin-left: 0.5rem;
|
||||
}
|
||||
.cfg-dropdown.open .cfg-dropdown-arrow { transform: rotate(180deg); }
|
||||
.cfg-dropdown-menu {
|
||||
display: none;
|
||||
position: absolute;
|
||||
top: calc(100% + 4px);
|
||||
left: 0;
|
||||
right: 0;
|
||||
z-index: 50;
|
||||
max-height: 240px;
|
||||
overflow-y: auto;
|
||||
border-radius: 0.5rem;
|
||||
border: 1px solid #e2e8f0;
|
||||
background: #ffffff;
|
||||
box-shadow: 0 10px 25px -5px rgba(0, 0, 0, 0.1), 0 4px 10px -5px rgba(0, 0, 0, 0.04);
|
||||
padding: 4px;
|
||||
}
|
||||
.dark .cfg-dropdown-menu {
|
||||
border-color: #334155;
|
||||
background: #1e1e1e;
|
||||
box-shadow: 0 10px 25px -5px rgba(0, 0, 0, 0.4);
|
||||
}
|
||||
.cfg-dropdown.open .cfg-dropdown-menu { display: block; }
|
||||
.cfg-dropdown-item {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
padding: 8px 10px;
|
||||
border-radius: 6px;
|
||||
font-size: 0.875rem;
|
||||
color: #334155;
|
||||
cursor: pointer;
|
||||
transition: background 0.15s ease;
|
||||
white-space: nowrap;
|
||||
overflow: hidden;
|
||||
text-overflow: ellipsis;
|
||||
}
|
||||
.dark .cfg-dropdown-item { color: #cbd5e1; }
|
||||
.cfg-dropdown-item:hover { background: #f1f5f9; }
|
||||
.dark .cfg-dropdown-item:hover { background: rgba(255, 255, 255, 0.08); }
|
||||
.cfg-dropdown-item.active {
|
||||
background: rgba(74, 190, 110, 0.1);
|
||||
color: #228547;
|
||||
font-weight: 500;
|
||||
}
|
||||
.dark .cfg-dropdown-item.active {
|
||||
background: rgba(74, 190, 110, 0.15);
|
||||
color: #74E9A4;
|
||||
}
|
||||
|
||||
/* API Key masking via CSS (avoids browser password prompts) */
|
||||
.cfg-key-masked {
|
||||
-webkit-text-security: disc;
|
||||
text-security: disc;
|
||||
}
|
||||
|
||||
/* Chat Input */
|
||||
#chat-input {
|
||||
resize: none; height: 42px; max-height: 180px;
|
||||
overflow-y: hidden;
|
||||
transition: border-color 0.2s ease;
|
||||
}
|
||||
|
||||
/* Placeholder Cards */
|
||||
.placeholder-card {
|
||||
transition: transform 0.2s ease, box-shadow 0.2s ease;
|
||||
}
|
||||
.placeholder-card:hover {
|
||||
transform: translateY(-2px);
|
||||
box-shadow: 0 8px 25px -5px rgba(0, 0, 0, 0.1);
|
||||
}
|
||||
BIN
channel/web/static/favicon.ico
Normal file
BIN
channel/web/static/favicon.ico
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 4.2 KiB |
BIN
channel/web/static/github.png
Normal file
BIN
channel/web/static/github.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 3.4 KiB |
1868
channel/web/static/js/console.js
Normal file
1868
channel/web/static/js/console.js
Normal file
File diff suppressed because it is too large
Load Diff
BIN
channel/web/static/logo.jpg
Normal file
BIN
channel/web/static/logo.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 21 KiB |
1131
channel/web/web_channel.py
Normal file
1131
channel/web/web_channel.py
Normal file
File diff suppressed because it is too large
Load Diff
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user