mirror of
https://github.com/zhayujie/chatgpt-on-wechat.git
synced 2026-06-02 18:17:11 +08:00
Compare commits
890 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
de2c031797 | ||
|
|
3aa571aa1b | ||
|
|
3e4969efe6 | ||
|
|
446e94df76 | ||
|
|
5b26066a4c | ||
|
|
8a80de5c3f | ||
|
|
52a490c87e | ||
|
|
29490741fd | ||
|
|
f0e416455f | ||
|
|
f7a2c97943 | ||
|
|
993853757b | ||
|
|
a3abfb987d | ||
|
|
2711fa1b1b | ||
|
|
1f7afaba07 | ||
|
|
e02c8bff81 | ||
|
|
22391ba1a5 | ||
|
|
a05781ec19 | ||
|
|
f898ed6a2a | ||
|
|
e6d0a15b54 | ||
|
|
49cff026e2 | ||
|
|
08f0023cfd | ||
|
|
e311466ee6 | ||
|
|
56789e68d7 | ||
|
|
87525bb383 | ||
|
|
bb2880191a | ||
|
|
4f1acf26d6 | ||
|
|
fc2d6b21ac | ||
|
|
b9e84fefbd | ||
|
|
91f5ffb2d9 | ||
|
|
70ff2341cb | ||
|
|
74eed93497 | ||
|
|
d02e26c014 | ||
|
|
523cade7c3 | ||
|
|
e22c183ca9 | ||
|
|
3afd99da30 | ||
|
|
f44979f983 | ||
|
|
095f9cc108 | ||
|
|
1089076fce | ||
|
|
cad3b691a9 | ||
|
|
bac21426d3 | ||
|
|
c4a35314cd | ||
|
|
7090722565 | ||
|
|
6d972c7c18 | ||
|
|
6961a88feb | ||
|
|
c41ec13984 | ||
|
|
ca8e06e562 | ||
|
|
200cd33a8e | ||
|
|
1da7991c65 | ||
|
|
fdfb7e369a | ||
|
|
c2b01cc957 | ||
|
|
5de8e94bb4 | ||
|
|
7a2c15d912 | ||
|
|
70344dd214 | ||
|
|
405372d1a7 | ||
|
|
b8c5174da5 | ||
|
|
1f6f9103d9 | ||
|
|
6431487c7a | ||
|
|
8b2d1189db | ||
|
|
b777f27cb7 | ||
|
|
b31c3b124a | ||
|
|
fa1e965fba | ||
|
|
91dc8b4d58 | ||
|
|
6d16ea8830 | ||
|
|
7db4253264 | ||
|
|
4d2b7d9bf9 | ||
|
|
8f6f4acb88 | ||
|
|
f20d84cb37 | ||
|
|
afbdf1d5d5 | ||
|
|
bc8364d594 | ||
|
|
c8d388f70f | ||
|
|
be13cc3194 | ||
|
|
a46320e744 | ||
|
|
071709d263 | ||
|
|
93a32ae5ff | ||
|
|
eee96f226f | ||
|
|
e19a8b479c | ||
|
|
9ef459112e | ||
|
|
e96474bd5c | ||
|
|
6fed719e09 | ||
|
|
99aac76618 | ||
|
|
599f458201 | ||
|
|
2f8099059c | ||
|
|
e24f177832 | ||
|
|
48cc143e88 | ||
|
|
b09b46c045 | ||
|
|
2c6583cc9c | ||
|
|
e381d1bfb8 | ||
|
|
eac619d54f | ||
|
|
a6ef3bc0ce | ||
|
|
118122c541 | ||
|
|
bfdf33ac09 | ||
|
|
fa3370df5b | ||
|
|
f1e51672c5 | ||
|
|
91f97b2728 | ||
|
|
2c542e03fe | ||
|
|
71a11b4267 | ||
|
|
ea642757db | ||
|
|
fb72b601aa | ||
|
|
27e507e744 | ||
|
|
4db19f816f | ||
|
|
096d5776d1 | ||
|
|
3d799eb4d9 | ||
|
|
e4ac3afa4d | ||
|
|
d38e4eed5b | ||
|
|
97787fac91 | ||
|
|
b494ee2f1c | ||
|
|
31ac80a074 | ||
|
|
c8896450f6 | ||
|
|
c662fa4c63 | ||
|
|
db2ee802ca | ||
|
|
d40e915e2b | ||
|
|
c0616e7efa | ||
|
|
01660597e3 | ||
|
|
c5b549f450 | ||
|
|
802d8457bb | ||
|
|
c3a3df67b0 | ||
|
|
5798aeb3cd | ||
|
|
cc81dd9172 | ||
|
|
44fdadda08 | ||
|
|
66a014150b | ||
|
|
1da596639f | ||
|
|
76614ae9e5 | ||
|
|
6ddddffc0f | ||
|
|
dd95f849d4 | ||
|
|
22c7f8fe9e | ||
|
|
3d47be1f49 | ||
|
|
5e399c46b1 | ||
|
|
38e1db7a37 | ||
|
|
8309f7cdbe | ||
|
|
b8cc62ae95 | ||
|
|
c0eb433fa2 | ||
|
|
7f857d66f6 | ||
|
|
93b14d38f4 | ||
|
|
21825faab0 | ||
|
|
1fafd39298 | ||
|
|
23b750fc4f | ||
|
|
90581c840d | ||
|
|
cac7a6228a | ||
|
|
674fbc3f69 | ||
|
|
9577bf1cc7 | ||
|
|
654ebe93e7 | ||
|
|
ecb1b3c491 | ||
|
|
c3d1711edc | ||
|
|
c12c7f10f0 | ||
|
|
f71820bf4e | ||
|
|
748c53c774 | ||
|
|
b290a71bfb | ||
|
|
3204c51eca | ||
|
|
2c4b8a44dc | ||
|
|
943aa05eaa | ||
|
|
d0fd36e7e1 | ||
|
|
f45ff5fd0a | ||
|
|
c22c7102d5 | ||
|
|
11ecfd1b41 | ||
|
|
798e30e5ac | ||
|
|
15e0702329 | ||
|
|
a2bc22c37d | ||
|
|
8093fcc64c | ||
|
|
800419e7cc | ||
|
|
a241dc6785 | ||
|
|
805bea0d5f | ||
|
|
9d394adf24 | ||
|
|
2074f27aff | ||
|
|
283ad48b86 | ||
|
|
07e10a7943 | ||
|
|
2812a5026c | ||
|
|
3a20461abf | ||
|
|
64ae3d1e21 | ||
|
|
a25d7ea65b | ||
|
|
74ebbdd761 | ||
|
|
a0427b569e | ||
|
|
5346dfdd8b | ||
|
|
3ee4147285 | ||
|
|
c41e486bfc | ||
|
|
eda3ba92fd | ||
|
|
40255290b0 | ||
|
|
af5bc73dc0 | ||
|
|
0247cd4c45 | ||
|
|
916762cc8c | ||
|
|
d6fdf8ca2a | ||
|
|
95708489c9 | ||
|
|
ced0fa4608 | ||
|
|
7e0fbd600f | ||
|
|
f33e4e0323 | ||
|
|
d0fd78497d | ||
|
|
8045019603 | ||
|
|
7d92b9435e | ||
|
|
1e0822703a | ||
|
|
0403ff88ef | ||
|
|
78376d591b | ||
|
|
8e23d0df20 | ||
|
|
9e281d20ab | ||
|
|
644bd4a106 | ||
|
|
7729e66a96 | ||
|
|
d67d6b7948 | ||
|
|
4c4a46bfbe | ||
|
|
4536f9c177 | ||
|
|
977d3bc02e | ||
|
|
eae95dfef5 | ||
|
|
b67d4460ca | ||
|
|
3dea8311b1 | ||
|
|
11f6e98874 | ||
|
|
2609e595f4 | ||
|
|
ac6e41abc8 | ||
|
|
9c17e16d0a | ||
|
|
55e9064307 | ||
|
|
91cabd7d49 | ||
|
|
7456950530 | ||
|
|
8fcdda625d | ||
|
|
40a10ee926 | ||
|
|
c3f7e2645c | ||
|
|
b264af1892 | ||
|
|
43e93e8e22 | ||
|
|
d6c4789688 | ||
|
|
cb31ee6f01 | ||
|
|
f7b694ac56 | ||
|
|
eb809055d4 | ||
|
|
78d9be82b2 | ||
|
|
76a95c0226 | ||
|
|
d3ab8fb04a | ||
|
|
f7a0b63a00 | ||
|
|
a21dd97786 | ||
|
|
04943c0bfa | ||
|
|
203d4d8bfb | ||
|
|
c049a619dc | ||
|
|
cc1b14b607 | ||
|
|
e04a12a8f4 | ||
|
|
a2c82bc583 | ||
|
|
b4dc382f7c | ||
|
|
eca1892e2a | ||
|
|
23a237074e | ||
|
|
219e9eca4f | ||
|
|
413e09fb9e | ||
|
|
3514c37e4c | ||
|
|
95260e303c | ||
|
|
0cef34bdfa | ||
|
|
9838979bbd | ||
|
|
c8910b8e14 | ||
|
|
207fa1d019 | ||
|
|
be0bb591e7 | ||
|
|
bfacdb9c3b | ||
|
|
ae4077ed6c | ||
|
|
6eb3c90e18 | ||
|
|
8c2a53a504 | ||
|
|
74db1e0308 | ||
|
|
b9dfdcef3d | ||
|
|
9d4afeac31 | ||
|
|
14ae2f169a | ||
|
|
55df19142f | ||
|
|
40fd545b2c | ||
|
|
95fb07343e | ||
|
|
4d87906559 | ||
|
|
6b30dced43 | ||
|
|
293a03b7c8 | ||
|
|
c010549f17 | ||
|
|
cc0be22026 | ||
|
|
e5ba26febe | ||
|
|
36f9680eec | ||
|
|
f4f5be5b08 | ||
|
|
d89b056886 | ||
|
|
65424c7db9 | ||
|
|
32a8a847fc | ||
|
|
88fb3dbf60 | ||
|
|
f6bee3aa58 | ||
|
|
5f19f37dcb | ||
|
|
dd36d8ce9e | ||
|
|
865e4b5349 | ||
|
|
e70564752b | ||
|
|
6e0d2f9437 | ||
|
|
291f936097 | ||
|
|
0b2ce48586 | ||
|
|
da87fd9e20 | ||
|
|
d4da4d2575 | ||
|
|
bad20ff483 | ||
|
|
21ad51ffbf | ||
|
|
697c6d5fbe | ||
|
|
293c659053 | ||
|
|
a12507abbd | ||
|
|
4e675b84fb | ||
|
|
c1022feab8 | ||
|
|
ddcfcf21fe | ||
|
|
86a58c3d80 | ||
|
|
abf9a9048d | ||
|
|
b1030a527a | ||
|
|
8d07ba6332 | ||
|
|
4ce37f84e4 | ||
|
|
061d8a3a5f | ||
|
|
374cd5dbb8 | ||
|
|
5ad53c2b9c | ||
|
|
a2ec1a063d | ||
|
|
e431dbe2df | ||
|
|
7218463f9e | ||
|
|
aeb09a95b0 | ||
|
|
0c8f292e12 | ||
|
|
f001ac6903 | ||
|
|
db8e506de0 | ||
|
|
099f859dd4 | ||
|
|
b7684c1c2b | ||
|
|
058c167f79 | ||
|
|
49446d4872 | ||
|
|
ced560e1e1 | ||
|
|
339102c3cd | ||
|
|
6331350239 | ||
|
|
34e06fcbf8 | ||
|
|
70aac312ff | ||
|
|
5e00704152 | ||
|
|
1a9edb6907 | ||
|
|
0c18c3a6dd | ||
|
|
847bb51ce4 | ||
|
|
fa60a5dc63 | ||
|
|
aaed3f9839 | ||
|
|
21b956b983 | ||
|
|
792e940279 | ||
|
|
c2477b26c0 | ||
|
|
4b27de809b | ||
|
|
572932d8e8 | ||
|
|
270dd778d9 | ||
|
|
dd04287b0a | ||
|
|
36ac6d005a | ||
|
|
701daedf49 | ||
|
|
238f05f453 | ||
|
|
dd082bd212 | ||
|
|
cfd2f27b0b | ||
|
|
a2160d135e | ||
|
|
16d7836369 | ||
|
|
f3de4dcc5f | ||
|
|
e34523028f | ||
|
|
efe2fbacd6 | ||
|
|
2fa1df29be | ||
|
|
f72cd13fba | ||
|
|
5b552dffbf | ||
|
|
a0ae2d13dc | ||
|
|
f7262a0a3a | ||
|
|
9736f121eb | ||
|
|
7c8fb7eacc | ||
|
|
b45eea5908 | ||
|
|
6babf4ee6c | ||
|
|
576526d4ee | ||
|
|
c03e31b7be | ||
|
|
a1aa925019 | ||
|
|
a5a234ed97 | ||
|
|
5b5dbcd78b | ||
|
|
bd1c6361d3 | ||
|
|
1fc1febf03 | ||
|
|
55cc35efa9 | ||
|
|
5ba8fdc5e7 | ||
|
|
6ea295e227 | ||
|
|
5010c76ef7 | ||
|
|
79c7f0c29f | ||
|
|
2b3e643786 | ||
|
|
90cdff327c | ||
|
|
55c116e727 | ||
|
|
3dd83aa6b7 | ||
|
|
a74aa12641 | ||
|
|
151e8c69f9 | ||
|
|
d8bfa77705 | ||
|
|
6bd286e8d5 | ||
|
|
905532b681 | ||
|
|
04d5c1ab01 | ||
|
|
28be141dc7 | ||
|
|
652b786baf | ||
|
|
ba6c671051 | ||
|
|
ca25d0433f | ||
|
|
5338106dfa | ||
|
|
854d613a81 | ||
|
|
b6b76be4f6 | ||
|
|
03d94fcfa0 | ||
|
|
b2c5f0d455 | ||
|
|
54f60dd38c | ||
|
|
42f181aca2 | ||
|
|
9c3a27894f | ||
|
|
f7cd348912 | ||
|
|
aeaeb75d3b | ||
|
|
96542b532e | ||
|
|
139295fe0d | ||
|
|
13217b2ce2 | ||
|
|
5cc8b56a7c | ||
|
|
e23e01c95e | ||
|
|
bca8ba12c7 | ||
|
|
3c44bdbe1c | ||
|
|
db93ed025b | ||
|
|
4209e108d0 | ||
|
|
14cbf011af | ||
|
|
03a41ec199 | ||
|
|
125fe2a026 | ||
|
|
ac4adac29e | ||
|
|
ac449d078e | ||
|
|
79be4530d4 | ||
|
|
85ce52d70c | ||
|
|
7ab56b9076 | ||
|
|
dedf976375 | ||
|
|
89f438208a | ||
|
|
ffbc5080ae | ||
|
|
4167f13bac | ||
|
|
6ba0baabb0 | ||
|
|
081003df47 | ||
|
|
559194ffb2 | ||
|
|
97a26d4a46 | ||
|
|
503c6c9b7e | ||
|
|
9a1e10deff | ||
|
|
054f927c05 | ||
|
|
22210747d0 | ||
|
|
53b2deb72c | ||
|
|
6fc158e7d6 | ||
|
|
a23a65c731 | ||
|
|
7dc7105ee2 | ||
|
|
bac70108b2 | ||
|
|
297404b21e | ||
|
|
33a7f8b558 | ||
|
|
4a670b7df7 | ||
|
|
79e4af315e | ||
|
|
c6e31b2fdc | ||
|
|
91dc44df53 | ||
|
|
7e57f8f157 | ||
|
|
15f6b7c6d3 | ||
|
|
b213ba541d | ||
|
|
7c6ed9944e | ||
|
|
a5a825e439 | ||
|
|
a4ab547f77 | ||
|
|
76ed763abe | ||
|
|
b9e3125610 | ||
|
|
8d9d5b7b6f | ||
|
|
187601da1e | ||
|
|
cc3a0fc367 | ||
|
|
44cc4165d1 | ||
|
|
f98b43514e | ||
|
|
3c9b1a14e9 | ||
|
|
827e8eddf8 | ||
|
|
7bc27d6167 | ||
|
|
ba06edd63a | ||
|
|
cacf553a5b | ||
|
|
d89091a8ea | ||
|
|
01a56e1155 | ||
|
|
a64d7c42b1 | ||
|
|
36b6cc58bf | ||
|
|
5ac8a257e7 | ||
|
|
74119d0372 | ||
|
|
4e162c73e5 | ||
|
|
5ff753a492 | ||
|
|
89400630c0 | ||
|
|
3899c0cfe3 | ||
|
|
a086f1989f | ||
|
|
1171b04e93 | ||
|
|
c55d81825a | ||
|
|
2dcd026e9f | ||
|
|
cdf8609d24 | ||
|
|
36580c5f7f | ||
|
|
1cff2521f4 | ||
|
|
db4998a56b | ||
|
|
acbd506568 | ||
|
|
0cf8e3be73 | ||
|
|
2473334dfc | ||
|
|
1ff72d1d37 | ||
|
|
241fad5524 | ||
|
|
1b48cea50a | ||
|
|
88bf345b91 | ||
|
|
ab4ff3d1a3 | ||
|
|
3502e0d643 | ||
|
|
995894d3aa | ||
|
|
4da8714124 | ||
|
|
6b247ae880 | ||
|
|
176941ea3b | ||
|
|
5176b56d3b | ||
|
|
8abf18ab25 | ||
|
|
395edbd9f4 | ||
|
|
2386eb8fc2 | ||
|
|
68208f82a0 | ||
|
|
ca916b7ce5 | ||
|
|
01e02934da | ||
|
|
c81a79f7b9 | ||
|
|
1133648bf6 | ||
|
|
e05bc541d7 | ||
|
|
d689d20482 | ||
|
|
39dd99b272 | ||
|
|
cda21acb43 | ||
|
|
9bd7d09f20 | ||
|
|
b22994c2d2 | ||
|
|
e027286b6d | ||
|
|
d6e16995e0 | ||
|
|
782bff3a51 | ||
|
|
de26dc0597 | ||
|
|
233b24ab0f | ||
|
|
2f9e5b1219 | ||
|
|
dd36b8b150 | ||
|
|
f81ac31fe1 | ||
|
|
24b63bc5bd | ||
|
|
1817a972c6 | ||
|
|
74a253f521 | ||
|
|
41762a1c57 | ||
|
|
a786fa4b75 | ||
|
|
e4c7602c0c | ||
|
|
e0d2e34980 | ||
|
|
9ef8e1be3f | ||
|
|
aae9b64833 | ||
|
|
4bab4299f2 | ||
|
|
954e55f4b4 | ||
|
|
2361e3c28c | ||
|
|
8224c2fc16 | ||
|
|
8aac86f0a9 | ||
|
|
6384e9310b | ||
|
|
7a9205dfba | ||
|
|
94b47a56f4 | ||
|
|
709b5be634 | ||
|
|
f970b2c168 | ||
|
|
973acb37ed | ||
|
|
1c9020a565 | ||
|
|
c5f1d0042c | ||
|
|
fa706e8b1d | ||
|
|
12c170f227 | ||
|
|
db27dfe227 | ||
|
|
2db4673392 | ||
|
|
38619db629 | ||
|
|
930fd436ea | ||
|
|
98b8ff2fc8 | ||
|
|
d0662683f9 | ||
|
|
957f2574a9 | ||
|
|
109b362ebd | ||
|
|
ff3fdfa738 | ||
|
|
e2636ed54a | ||
|
|
dbe2f17e1a | ||
|
|
4dc535673f | ||
|
|
f414b6408e | ||
|
|
3aa2e6a04d | ||
|
|
1963ff273f | ||
|
|
bb737a71d5 | ||
|
|
a582a46ce9 | ||
|
|
abf80a3266 | ||
|
|
d768f5c66d | ||
|
|
b25e843351 | ||
|
|
419a3e518e | ||
|
|
d1b867a7c0 | ||
|
|
c34d70b3cb | ||
|
|
a33df9312f | ||
|
|
ebf8db0b37 | ||
|
|
e539ae3b69 | ||
|
|
4c5e8850aa | ||
|
|
94c0af3037 | ||
|
|
165182c68f | ||
|
|
65b9542599 | ||
|
|
d01d1f8830 | ||
|
|
ad3e9f3d42 | ||
|
|
4589974095 | ||
|
|
ed4553ddf8 | ||
|
|
ff97ae73f1 | ||
|
|
f96b4d2781 | ||
|
|
ce32cfffdb | ||
|
|
f66df8531e | ||
|
|
dfe1c23e76 | ||
|
|
07fd81919f | ||
|
|
210042bb81 | ||
|
|
12dc7427e9 | ||
|
|
b476085110 | ||
|
|
776cdaf63c | ||
|
|
69b6855745 | ||
|
|
3590babd8b | ||
|
|
c29d391c1d | ||
|
|
50e44dbb2a | ||
|
|
34277a3940 | ||
|
|
f1a00d58ca | ||
|
|
d1a5f17ae8 | ||
|
|
4dbc54fa15 | ||
|
|
1d4ff796d7 | ||
|
|
44cb54a9ea | ||
|
|
6409f49609 | ||
|
|
9ee0ea88b5 | ||
|
|
a3819d8673 | ||
|
|
2d7dd71a3d | ||
|
|
0e8195ae61 | ||
|
|
3e92d07618 | ||
|
|
e59597280d | ||
|
|
f2e3d69d8a | ||
|
|
9d2cb75c84 | ||
|
|
f971505c4a | ||
|
|
2133c1d6af | ||
|
|
0bf06ddfd3 | ||
|
|
024a50d642 | ||
|
|
e4eebd64d1 | ||
|
|
c9055989e9 | ||
|
|
4f1ed197ce | ||
|
|
3e710aa2a1 | ||
|
|
b6226a45bb | ||
|
|
3001ba9266 | ||
|
|
b0a401a1ed | ||
|
|
6b4dc37428 | ||
|
|
8528c9b262 | ||
|
|
7222a5c2f4 | ||
|
|
59050001ef | ||
|
|
2ba8f18724 | ||
|
|
fb22e01b89 | ||
|
|
76a81d5360 | ||
|
|
3314b05648 | ||
|
|
45b89218de | ||
|
|
beb7bda243 | ||
|
|
bef2896f50 | ||
|
|
9fea949b25 | ||
|
|
be258e5b05 | ||
|
|
008178d737 | ||
|
|
527d5e1dbc | ||
|
|
9b47e2d6f9 | ||
|
|
8781b1e976 | ||
|
|
38c653d8d8 | ||
|
|
74e48bb137 | ||
|
|
c3aaa1f735 | ||
|
|
bead2aa228 | ||
|
|
dc52ab8aa9 | ||
|
|
20b71f206b | ||
|
|
73c87d5959 | ||
|
|
c6601aaeed | ||
|
|
6e14fce1fe | ||
|
|
be5a62f1b8 | ||
|
|
1fa8cefaea | ||
|
|
d7c251ac83 | ||
|
|
d03229a183 | ||
|
|
243482e829 | ||
|
|
79d10be8a0 | ||
|
|
dca5c058e0 | ||
|
|
9163ce71fd | ||
|
|
2ec5374765 | ||
|
|
d6a4b35cd3 | ||
|
|
8205d2552c | ||
|
|
9a99caeb9d | ||
|
|
1e09bd0e76 | ||
|
|
cae12eb187 | ||
|
|
8bb36e0eb6 | ||
|
|
d183204caa | ||
|
|
4a22ae6b61 | ||
|
|
a52f54d988 | ||
|
|
618c94edb8 | ||
|
|
eaf4e9174f | ||
|
|
4af2c7f3d7 | ||
|
|
361f599df0 | ||
|
|
ffe4ea5e4c | ||
|
|
9461e3e01a | ||
|
|
7c85c6f742 | ||
|
|
b5df6faadf | ||
|
|
7cefe2d825 | ||
|
|
350633b69b | ||
|
|
1cd6a71ce0 | ||
|
|
3a08b002a0 | ||
|
|
665001732b | ||
|
|
cca49da730 | ||
|
|
f6d370ad29 | ||
|
|
c9131b333b | ||
|
|
e44161bf42 | ||
|
|
a26189fb25 | ||
|
|
89dd8a1db6 | ||
|
|
650e0b4ad4 | ||
|
|
c60f0517fb | ||
|
|
0f8dc91a8b | ||
|
|
b58feb5d8e | ||
|
|
71c8043699 | ||
|
|
40264bc9cb | ||
|
|
a7772316f9 | ||
|
|
34209021c8 | ||
|
|
3e9e8d442a | ||
|
|
d2bf90c6c7 | ||
|
|
1e58c1ad2b | ||
|
|
8cea022ec5 | ||
|
|
f32f8aa08e | ||
|
|
3ea8781381 | ||
|
|
ab83dacb76 | ||
|
|
4cbf46fd4d | ||
|
|
0a7d6e4577 | ||
|
|
df4c1f0401 | ||
|
|
9a86a67984 | ||
|
|
a0cbe9c3e2 | ||
|
|
a83e5a9b65 | ||
|
|
de33911460 | ||
|
|
0be56e5b25 | ||
|
|
abcbb34b1c | ||
|
|
6a13dd04a3 | ||
|
|
f2e29f3f2e | ||
|
|
68361cddd2 | ||
|
|
6404332adc | ||
|
|
e060b6fea2 | ||
|
|
e8aae27ee9 | ||
|
|
2f732e5493 | ||
|
|
65f20ff2c1 | ||
|
|
8f72e8c3e6 | ||
|
|
3b8972ce1f | ||
|
|
fc5d3e4e9c | ||
|
|
29fbf69945 | ||
|
|
583440b82b | ||
|
|
720de9d73f | ||
|
|
78332d882b | ||
|
|
2dfbc840b3 | ||
|
|
0b4bf15163 | ||
|
|
2989249e4b | ||
|
|
9cef559a05 | ||
|
|
47fe16c92a | ||
|
|
36b5c821ff | ||
|
|
82ec440b45 | ||
|
|
88f4a45cae | ||
|
|
7fb4f72b84 | ||
|
|
d4fc322101 | ||
|
|
8fa3da9ca5 | ||
|
|
68ef5aa3ae | ||
|
|
28bd917c9f | ||
|
|
0eb1b94300 | ||
|
|
15e6cf850b | ||
|
|
ee91c86a29 | ||
|
|
48c08f4aad | ||
|
|
fceabb8e67 | ||
|
|
fcfafb05f1 | ||
|
|
f1e8344beb | ||
|
|
f687b2b6f4 | ||
|
|
8ee7a48151 | ||
|
|
89e8f385b4 | ||
|
|
bf4ae9a051 | ||
|
|
6bd1242d43 | ||
|
|
8779eab36b | ||
|
|
3174b1158c | ||
|
|
18740093d1 | ||
|
|
8c7d1d4010 | ||
|
|
8c48a27e1a | ||
|
|
4278d2b8ef | ||
|
|
3a3affd3ec | ||
|
|
45d72b8b9b | ||
|
|
03b908c079 | ||
|
|
d35d01f980 | ||
|
|
9c208ffa2c | ||
|
|
bea4416f12 | ||
|
|
2ea8b4ef73 | ||
|
|
e6946ef989 | ||
|
|
9aeb60f66d | ||
|
|
d687f9329e | ||
|
|
3207258fd9 | ||
|
|
d8b75206fe | ||
|
|
88e8dd5162 | ||
|
|
c9306633b2 | ||
|
|
c50d1cc99d | ||
|
|
9a20c1cb02 | ||
|
|
176f77ba5b | ||
|
|
484de6237b | ||
|
|
898aa30b1d | ||
|
|
8b73a74609 | ||
|
|
3c6d42b22e | ||
|
|
40563c1e96 | ||
|
|
cb0c86ec1c | ||
|
|
614f3b1ea4 | ||
|
|
938e3b5cf2 | ||
|
|
5fe8d9a855 | ||
|
|
8193ecf5f6 | ||
|
|
1dff630257 | ||
|
|
eaac3e3579 | ||
|
|
d3758968d0 | ||
|
|
020f9a8d98 | ||
|
|
9d8ae80548 | ||
|
|
7e7484a27d | ||
|
|
0adf8d6e5d | ||
|
|
1a981ea970 | ||
|
|
5bd9f50818 | ||
|
|
44f6892cb7 | ||
|
|
fdf6b0dc6b | ||
|
|
a7914279a9 | ||
|
|
2cf71dd6f2 | ||
|
|
62e3baba20 | ||
|
|
e00c99c1d7 | ||
|
|
31d5b95611 | ||
|
|
cc881adda6 | ||
|
|
78d4c58b70 | ||
|
|
eca369532d | ||
|
|
9520d94b13 | ||
|
|
f973bc3fe2 | ||
|
|
94004b095b | ||
|
|
f652d592bd | ||
|
|
186e18fe94 | ||
|
|
28eb67bc24 | ||
|
|
6c7e4aaf37 | ||
|
|
709a1317ef | ||
|
|
371e38cfa6 | ||
|
|
5a221848e9 | ||
|
|
6901c5ba56 | ||
|
|
21a3b0d9a1 | ||
|
|
29422edcc9 | ||
|
|
2da1c18b71 | ||
|
|
be592cc290 | ||
|
|
ce8635dd99 | ||
|
|
76783f0ad3 | ||
|
|
441228e200 | ||
|
|
45a131aa0d | ||
|
|
a7900d4b2c | ||
|
|
a4b1d7446a | ||
|
|
7458a6298f | ||
|
|
b0f54bb8b7 | ||
|
|
acddadc406 | ||
|
|
761fb20dd9 | ||
|
|
b74274b96b | ||
|
|
7835379f8f | ||
|
|
49ba278316 | ||
|
|
388058467c | ||
|
|
cf25bd7869 | ||
|
|
02a95345aa | ||
|
|
6076e2ed0a | ||
|
|
cec674cb47 | ||
|
|
c5a90823fa | ||
|
|
18d82bc1f0 | ||
|
|
a68af990ea | ||
|
|
e71c600d10 | ||
|
|
d7f1f7182c | ||
|
|
dfb2e460b4 | ||
|
|
5badef8ba9 | ||
|
|
18aa5ce75c | ||
|
|
1545a9f262 | ||
|
|
47cc65a787 | ||
|
|
cda9d5873d | ||
|
|
02cd553990 | ||
|
|
71d288f550 | ||
|
|
87df588c80 | ||
|
|
4ad2997717 | ||
|
|
50a03e7c15 | ||
|
|
4f3d12129c | ||
|
|
37a95980d4 | ||
|
|
f49806558e | ||
|
|
8da362d6fe | ||
|
|
bf02a59aec | ||
|
|
461777cad3 | ||
|
|
0597ba20d2 | ||
|
|
0b5fd27cd8 | ||
|
|
f5f8033d4d | ||
|
|
a5f7dec011 | ||
|
|
d9ef5a6612 | ||
|
|
66a81cd47c | ||
|
|
81edd13470 | ||
|
|
7a94745b8a | ||
|
|
06b02f5df8 | ||
|
|
83136e3142 | ||
|
|
950a9f2ee0 | ||
|
|
a26c10fee8 | ||
|
|
4bcd76fe93 | ||
|
|
90ccb091ca | ||
|
|
62df27eaa1 | ||
|
|
349115b948 | ||
|
|
4fd7e4be67 | ||
|
|
947e892916 | ||
|
|
d62b7d1a99 | ||
|
|
432b39a9c4 | ||
|
|
26540bfb63 | ||
|
|
fd64f88a7e | ||
|
|
72994bc9ef | ||
|
|
7e1138af50 | ||
|
|
72dbddb7f7 | ||
|
|
10dba50843 | ||
|
|
d6af1b5827 | ||
|
|
6c362a9b4b | ||
|
|
9a0584d649 | ||
|
|
5ab5211c95 | ||
|
|
f644682be7 | ||
|
|
ffad8e4d26 | ||
|
|
8f07e6304a | ||
|
|
834c03359f | ||
|
|
3e2c68ba49 | ||
|
|
2a21941b68 | ||
|
|
e78886fb35 | ||
|
|
80bf6a0c7a | ||
|
|
48e066b677 | ||
|
|
dcb9d7fc2a | ||
|
|
279f0f0234 | ||
|
|
b3c8a7d8de | ||
|
|
1baf1a79e5 | ||
|
|
35160e717e | ||
|
|
a12f2d8fbd | ||
|
|
6b7c17374b | ||
|
|
9b3585e795 | ||
|
|
74f383a7d4 | ||
|
|
820fbeed18 | ||
|
|
f76e8d9a77 | ||
|
|
5b85e60d5d | ||
|
|
24de670c2c | ||
|
|
42aca71763 | ||
|
|
9b4ef85174 | ||
|
|
9b389ffc33 | ||
|
|
b3cb81aa52 | ||
|
|
61865bc408 | ||
|
|
c2ea6214a9 | ||
|
|
b6684fe7a3 | ||
|
|
b50ebc05a0 | ||
|
|
dbb0648c39 | ||
|
|
5fc0987cc3 | ||
|
|
7c4037147c | ||
|
|
f76cb1231e | ||
|
|
6701d8c5e6 | ||
|
|
ff3d143185 | ||
|
|
ea95ab9062 | ||
|
|
38c901a1c5 | ||
|
|
0c9753b7cd | ||
|
|
721b36c7f7 | ||
|
|
f8e0716474 | ||
|
|
3d428ee844 | ||
|
|
a3be1fcd8f |
13
.flake8
Normal file
13
.flake8
Normal file
@@ -0,0 +1,13 @@
|
||||
[flake8]
|
||||
max-line-length = 176
|
||||
select = E303,W293,W291,W292,E305,E231,E302
|
||||
exclude =
|
||||
.tox,
|
||||
__pycache__,
|
||||
*.pyc,
|
||||
.env
|
||||
venv/*
|
||||
.venv/*
|
||||
reports/*
|
||||
dist/*
|
||||
lib/*
|
||||
28
.github/ISSUE_TEMPLATE.md
vendored
28
.github/ISSUE_TEMPLATE.md
vendored
@@ -1,28 +0,0 @@
|
||||
### 前置确认
|
||||
|
||||
1. 网络能够访问openai接口
|
||||
2. python 已安装:版本在 3.7 ~ 3.10 之间,依赖已安装
|
||||
3. 在已有 issue 中未搜索到类似问题
|
||||
4. [FAQS](https://github.com/zhayujie/chatgpt-on-wechat/wiki/FAQs) 中无类似问题
|
||||
|
||||
|
||||
### 问题描述
|
||||
|
||||
> 简要说明、截图、复现步骤等,也可以是需求或想法
|
||||
|
||||
|
||||
|
||||
|
||||
### 终端日志 (如有报错)
|
||||
|
||||
```
|
||||
[在此处粘贴终端日志]
|
||||
```
|
||||
|
||||
|
||||
|
||||
### 环境
|
||||
|
||||
- 操作系统类型 (Mac/Windows/Linux):
|
||||
- Python版本 ( 执行 `python3 -V` ):
|
||||
- pip版本 ( 依赖问题此项必填,执行 `pip3 -V`):
|
||||
133
.github/ISSUE_TEMPLATE/1.bug.yml
vendored
Normal file
133
.github/ISSUE_TEMPLATE/1.bug.yml
vendored
Normal file
@@ -0,0 +1,133 @@
|
||||
name: Bug report 🐛
|
||||
description: 项目运行中遇到的Bug或问题。
|
||||
labels: ['status: needs check']
|
||||
body:
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
### ⚠️ 前置确认
|
||||
1. 网络能够访问openai接口
|
||||
2. python 已安装:版本在 3.7 ~ 3.10 之间
|
||||
3. `git pull` 拉取最新代码
|
||||
4. 执行`pip3 install -r requirements.txt`,检查依赖是否满足
|
||||
5. 拓展功能请执行`pip3 install -r requirements-optional.txt`,检查依赖是否满足
|
||||
6. [FAQS](https://github.com/zhayujie/chatgpt-on-wechat/wiki/FAQs) 中无类似问题
|
||||
- type: checkboxes
|
||||
attributes:
|
||||
label: 前置确认
|
||||
options:
|
||||
- label: 我确认我运行的是最新版本的代码,并且安装了所需的依赖,在[FAQS](https://github.com/zhayujie/chatgpt-on-wechat/wiki/FAQs)中也未找到类似问题。
|
||||
required: true
|
||||
- type: checkboxes
|
||||
attributes:
|
||||
label: ⚠️ 搜索issues中是否已存在类似问题
|
||||
description: >
|
||||
请在 [历史issue](https://github.com/zhayujie/chatgpt-on-wechat/issues) 中清空输入框,搜索你的问题
|
||||
或相关日志的关键词来查找是否存在类似问题。
|
||||
options:
|
||||
- label: 我已经搜索过issues和disscussions,没有跟我遇到的问题相关的issue
|
||||
required: true
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
请在上方的`title`中填写你对你所遇到问题的简略总结,这将帮助其他人更好的找到相似问题,谢谢❤️。
|
||||
- type: dropdown
|
||||
attributes:
|
||||
label: 操作系统类型?
|
||||
description: >
|
||||
请选择你运行程序的操作系统类型。
|
||||
options:
|
||||
- Windows
|
||||
- Linux
|
||||
- MacOS
|
||||
- Docker
|
||||
- Railway
|
||||
- Windows Subsystem for Linux (WSL)
|
||||
- Other (请在问题中说明)
|
||||
validations:
|
||||
required: true
|
||||
- type: dropdown
|
||||
attributes:
|
||||
label: 运行的python版本是?
|
||||
description: |
|
||||
请选择你运行程序的`python`版本。
|
||||
注意:在`python 3.7`中,有部分可选依赖无法安装。
|
||||
经过长时间的观察,我们认为`python 3.8`是兼容性最好的版本。
|
||||
`python 3.7`~`python 3.10`以外版本的issue,将视情况直接关闭。
|
||||
options:
|
||||
- python 3.7
|
||||
- python 3.8
|
||||
- python 3.9
|
||||
- python 3.10
|
||||
- other
|
||||
validations:
|
||||
required: true
|
||||
- type: dropdown
|
||||
attributes:
|
||||
label: 使用的chatgpt-on-wechat版本是?
|
||||
description: |
|
||||
请确保你使用的是 [releases](https://github.com/zhayujie/chatgpt-on-wechat/releases) 中的最新版本。
|
||||
如果你使用git, 请使用`git branch`命令来查看分支。
|
||||
options:
|
||||
- Latest Release
|
||||
- Master (branch)
|
||||
validations:
|
||||
required: true
|
||||
- type: dropdown
|
||||
attributes:
|
||||
label: 运行的`channel`类型是?
|
||||
description: |
|
||||
请确保你正确配置了该`channel`所需的配置项,所有可选的配置项都写在了[该文件中](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/config.py),请将所需配置项填写在根目录下的`config.json`文件中。
|
||||
options:
|
||||
- wx(个人微信, itchat)
|
||||
- wxy(个人微信, wechaty)
|
||||
- wechatmp(公众号, 订阅号)
|
||||
- wechatmp_service(公众号, 服务号)
|
||||
- terminal
|
||||
- other
|
||||
validations:
|
||||
required: true
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: 复现步骤 🕹
|
||||
description: |
|
||||
**⚠️ 不能复现将会关闭issue.**
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: 问题描述 😯
|
||||
description: 详细描述出现的问题,或提供有关截图。
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: 终端日志 📒
|
||||
description: |
|
||||
在此处粘贴终端日志,可在主目录下`run.log`文件中找到,这会帮助我们更好的分析问题,注意隐去你的API key。
|
||||
如果在配置文件中加入`"debug": true`,打印出的日志会更有帮助。
|
||||
|
||||
<details>
|
||||
<summary><i>示例</i></summary>
|
||||
```log
|
||||
[DEBUG][2023-04-16 00:23:22][plugin_manager.py:157] - Plugin SUMMARY triggered by event Event.ON_HANDLE_CONTEXT
|
||||
[DEBUG][2023-04-16 00:23:22][main.py:221] - [Summary] on_handle_context. content: $总结前100条消息
|
||||
[DEBUG][2023-04-16 00:23:24][main.py:240] - [Summary] limit: 100, duration: -1 seconds
|
||||
[ERROR][2023-04-16 00:23:24][chat_channel.py:244] - Worker return exception: name 'start_date' is not defined
|
||||
Traceback (most recent call last):
|
||||
File "C:\ProgramData\Anaconda3\lib\concurrent\futures\thread.py", line 57, in run
|
||||
result = self.fn(*self.args, **self.kwargs)
|
||||
File "D:\project\chatgpt-on-wechat\channel\chat_channel.py", line 132, in _handle
|
||||
reply = self._generate_reply(context)
|
||||
File "D:\project\chatgpt-on-wechat\channel\chat_channel.py", line 142, in _generate_reply
|
||||
e_context = PluginManager().emit_event(EventContext(Event.ON_HANDLE_CONTEXT, {
|
||||
File "D:\project\chatgpt-on-wechat\plugins\plugin_manager.py", line 159, in emit_event
|
||||
instance.handlers[e_context.event](e_context, *args, **kwargs)
|
||||
File "D:\project\chatgpt-on-wechat\plugins\summary\main.py", line 255, in on_handle_context
|
||||
records = self._get_records(session_id, start_time, limit)
|
||||
File "D:\project\chatgpt-on-wechat\plugins\summary\main.py", line 96, in _get_records
|
||||
c.execute("SELECT * FROM chat_records WHERE sessionid=? and timestamp>? ORDER BY timestamp DESC LIMIT ?", (session_id, start_date, limit))
|
||||
NameError: name 'start_date' is not defined
|
||||
[INFO][2023-04-16 00:23:36][app.py:14] - signal 2 received, exiting...
|
||||
```
|
||||
</details>
|
||||
value: |
|
||||
```log
|
||||
<此处粘贴终端日志>
|
||||
```
|
||||
28
.github/ISSUE_TEMPLATE/2.feature.yml
vendored
Normal file
28
.github/ISSUE_TEMPLATE/2.feature.yml
vendored
Normal file
@@ -0,0 +1,28 @@
|
||||
name: Feature request 🚀
|
||||
description: 提出你对项目的新想法或建议。
|
||||
labels: ['status: needs check']
|
||||
body:
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
请在上方的`title`中填写简略总结,谢谢❤️。
|
||||
- type: checkboxes
|
||||
attributes:
|
||||
label: ⚠️ 搜索是否存在类似issue
|
||||
description: >
|
||||
请在 [历史issue](https://github.com/zhayujie/chatgpt-on-wechat/issues) 中清空输入框,搜索关键词查找是否存在相似issue。
|
||||
options:
|
||||
- label: 我已经搜索过issues和disscussions,没有发现相似issue
|
||||
required: true
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: 总结
|
||||
description: 描述feature的功能。
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: 举例
|
||||
description: 提供聊天示例,草图或相关网址。
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: 动机
|
||||
description: 描述你提出该feature的动机,比如没有这项feature对你的使用造成了怎样的影响。 请提供更详细的场景描述,这可能会帮助我们发现并提出更好的解决方案。
|
||||
72
.github/workflows/deploy-image-arm.yml
vendored
Normal file
72
.github/workflows/deploy-image-arm.yml
vendored
Normal file
@@ -0,0 +1,72 @@
|
||||
# This workflow uses actions that are not certified by GitHub.
|
||||
# They are provided by a third-party and are governed by
|
||||
# separate terms of service, privacy policy, and support
|
||||
# documentation.
|
||||
|
||||
# GitHub recommends pinning actions to a commit SHA.
|
||||
# To get a newer version, you will need to update the SHA.
|
||||
# You can also reference a tag or branch, but the action may change without warning.
|
||||
|
||||
name: Create and publish a Docker image
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: ['master']
|
||||
create:
|
||||
env:
|
||||
REGISTRY: ghcr.io
|
||||
IMAGE_NAME: ${{ github.repository }}
|
||||
|
||||
jobs:
|
||||
build-and-push-image:
|
||||
if: github.repository == 'zhayujie/chatgpt-on-wechat'
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: read
|
||||
packages: write
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@v1
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
id: buildx
|
||||
uses: docker/setup-buildx-action@v1
|
||||
|
||||
- name: Available platforms
|
||||
run: echo ${{ steps.buildx.outputs.platforms }}
|
||||
|
||||
- name: Log in to the Container registry
|
||||
uses: docker/login-action@v2
|
||||
with:
|
||||
registry: ${{ env.REGISTRY }}
|
||||
username: ${{ github.actor }}
|
||||
password: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Extract metadata (tags, labels) for Docker
|
||||
id: meta
|
||||
uses: docker/metadata-action@v4
|
||||
with:
|
||||
images: |
|
||||
${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
|
||||
|
||||
- name: Build and push Docker image
|
||||
uses: docker/build-push-action@v3
|
||||
with:
|
||||
context: .
|
||||
push: true
|
||||
file: ./docker/Dockerfile.latest
|
||||
platforms: linux/arm64
|
||||
tags: ${{ steps.meta.outputs.tags }}-arm64
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
|
||||
- uses: actions/delete-package-versions@v4
|
||||
with:
|
||||
package-name: 'chatgpt-on-wechat'
|
||||
package-type: 'container'
|
||||
min-versions-to-keep: 10
|
||||
delete-only-untagged-versions: 'true'
|
||||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
68
.github/workflows/deploy-image.yml
vendored
Normal file
68
.github/workflows/deploy-image.yml
vendored
Normal file
@@ -0,0 +1,68 @@
|
||||
# This workflow uses actions that are not certified by GitHub.
|
||||
# They are provided by a third-party and are governed by
|
||||
# separate terms of service, privacy policy, and support
|
||||
# documentation.
|
||||
|
||||
# GitHub recommends pinning actions to a commit SHA.
|
||||
# To get a newer version, you will need to update the SHA.
|
||||
# You can also reference a tag or branch, but the action may change without warning.
|
||||
|
||||
name: Create and publish a Docker image
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: ['master']
|
||||
create:
|
||||
env:
|
||||
REGISTRY: ghcr.io
|
||||
IMAGE_NAME: ${{ github.repository }}
|
||||
|
||||
jobs:
|
||||
build-and-push-image:
|
||||
if: github.repository == 'zhayujie/chatgpt-on-wechat'
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: read
|
||||
packages: write
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v2
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
|
||||
- name: Log in to the Container registry
|
||||
uses: docker/login-action@v2
|
||||
with:
|
||||
registry: ${{ env.REGISTRY }}
|
||||
username: ${{ github.actor }}
|
||||
password: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Extract metadata (tags, labels) for Docker
|
||||
id: meta
|
||||
uses: docker/metadata-action@v4
|
||||
with:
|
||||
images: |
|
||||
${{ env.IMAGE_NAME }}
|
||||
${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
|
||||
|
||||
- name: Build and push Docker image
|
||||
uses: docker/build-push-action@v3
|
||||
with:
|
||||
context: .
|
||||
push: true
|
||||
file: ./docker/Dockerfile.latest
|
||||
tags: ${{ steps.meta.outputs.tags }}
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
|
||||
- uses: actions/delete-package-versions@v4
|
||||
with:
|
||||
package-name: 'chatgpt-on-wechat'
|
||||
package-type: 'container'
|
||||
min-versions-to-keep: 10
|
||||
delete-only-untagged-versions: 'true'
|
||||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
21
.gitignore
vendored
21
.gitignore
vendored
@@ -1,5 +1,8 @@
|
||||
.DS_Store
|
||||
.idea
|
||||
.vscode
|
||||
.venv
|
||||
.vs
|
||||
.wechaty/
|
||||
__pycache__/
|
||||
venv*
|
||||
@@ -10,3 +13,21 @@ nohup.out
|
||||
tmp
|
||||
plugins.json
|
||||
itchat.pkl
|
||||
*.log
|
||||
user_datas.pkl
|
||||
chatgpt_tool_hub/
|
||||
plugins/**/
|
||||
!plugins/bdunit
|
||||
!plugins/dungeon
|
||||
!plugins/finish
|
||||
!plugins/godcmd
|
||||
!plugins/tool
|
||||
!plugins/banwords
|
||||
!plugins/banwords/**/
|
||||
plugins/banwords/__pycache__
|
||||
plugins/banwords/lib/__pycache__
|
||||
!plugins/hello
|
||||
!plugins/role
|
||||
!plugins/keyword
|
||||
!plugins/linkai
|
||||
client_config.json
|
||||
|
||||
30
.pre-commit-config.yaml
Normal file
30
.pre-commit-config.yaml
Normal file
@@ -0,0 +1,30 @@
|
||||
repos:
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: v4.4.0
|
||||
hooks:
|
||||
- id: fix-byte-order-marker
|
||||
- id: check-case-conflict
|
||||
- id: check-merge-conflict
|
||||
- id: debug-statements
|
||||
- id: pretty-format-json
|
||||
types: [text]
|
||||
files: \.json(.template)?$
|
||||
args: [ --autofix , --no-ensure-ascii, --indent=2, --no-sort-keys]
|
||||
- id: trailing-whitespace
|
||||
exclude: '(\/|^)lib\/'
|
||||
args: [ --markdown-linebreak-ext=md ]
|
||||
- repo: https://github.com/PyCQA/isort
|
||||
rev: 5.12.0
|
||||
hooks:
|
||||
- id: isort
|
||||
exclude: '(\/|^)lib\/'
|
||||
- repo: https://github.com/psf/black
|
||||
rev: 23.3.0
|
||||
hooks:
|
||||
- id: black
|
||||
exclude: '(\/|^)lib\/'
|
||||
- repo: https://github.com/PyCQA/flake8
|
||||
rev: 6.0.0
|
||||
hooks:
|
||||
- id: flake8
|
||||
exclude: '(\/|^)lib\/'
|
||||
3
Dockerfile
Normal file
3
Dockerfile
Normal file
@@ -0,0 +1,3 @@
|
||||
FROM ghcr.io/zhayujie/chatgpt-on-wechat:latest
|
||||
|
||||
ENTRYPOINT ["/entrypoint.sh"]
|
||||
253
README.md
253
README.md
@@ -1,67 +1,101 @@
|
||||
# 简介
|
||||
|
||||
> ChatGPT近期以强大的对话和信息整合能力风靡全网,可以写代码、改论文、讲故事,几乎无所不能,这让人不禁有个大胆的想法,能否用他的对话模型把我们的微信打造成一个智能机器人,可以在与好友对话中给出意想不到的回应,而且再也不用担心女朋友影响我们 ~~打游戏~~ 工作了。
|
||||
> chatgpt-on-wechat(简称CoW)项目是基于大模型的智能对话机器人,支持微信公众号、企业微信应用、飞书、钉钉接入,可选择GPT3.5/GPT4.0/Claude/Gemini/LinkAI/ChatGLM/KIMI/文心一言/讯飞星火/通义千问/LinkAI,能处理文本、语音和图片,通过插件访问操作系统和互联网等外部资源,支持基于自有知识库定制企业AI应用。
|
||||
|
||||
|
||||
基于ChatGPT的微信聊天机器人,通过 [ChatGPT](https://github.com/openai/openai-python) 接口生成对话内容,使用 [itchat](https://github.com/littlecodersh/ItChat) 实现微信消息的接收和自动回复。已实现的特性如下:
|
||||
最新版本支持的功能如下:
|
||||
|
||||
- [x] **文本对话:** 接收私聊及群组中的微信消息,使用ChatGPT生成回复内容,完成自动回复
|
||||
- [x] **规则定制化:** 支持私聊中按指定规则触发自动回复,支持对群组设置自动回复白名单
|
||||
- [x] **多账号:** 支持多微信账号同时运行
|
||||
- [x] **图片生成:** 支持根据描述生成图片,并自动发送至个人聊天或群聊
|
||||
- [x] **上下文记忆**:支持多轮对话记忆,且为每个好友维护独立的上下会话
|
||||
- [x] **语音识别:** 支持接收和处理语音消息,通过文字或语音回复
|
||||
- ✅ **多端部署:** 有多种部署方式可选择且功能完备,目前已支持微信公众号、企业微信应用、飞书、钉钉等部署方式
|
||||
- ✅ **基础对话:** 私聊及群聊的消息智能回复,支持多轮会话上下文记忆,支持 GPT-3.5, GPT-4o-mini, GPT-4o, GPT-4, Claude-3.5, Gemini, 文心一言, 讯飞星火, 通义千问,ChatGLM-4,Kimi(月之暗面), MiniMax
|
||||
- ✅ **语音能力:** 可识别语音消息,通过文字或语音回复,支持 azure, baidu, google, openai(whisper/tts) 等多种语音模型
|
||||
- ✅ **图像能力:** 支持图片生成、图片识别、图生图(如照片修复),可选择 Dall-E-3, stable diffusion, replicate, midjourney, CogView-3, vision模型
|
||||
- ✅ **丰富插件:** 支持个性化插件扩展,已实现多角色切换、文字冒险、敏感词过滤、聊天记录总结、文档总结和对话、联网搜索等插件
|
||||
- ✅ **知识库:** 通过上传知识库文件自定义专属机器人,可作为数字分身、智能客服、私域助手使用,基于 [LinkAI](https://link-ai.tech) 实现
|
||||
|
||||
## 声明
|
||||
|
||||
# 更新日志
|
||||
1. 本项目遵循 [MIT开源协议](/LICENSE),仅用于技术研究和学习,使用本项目时需遵守所在地法律法规、相关政策以及企业章程,禁止用于任何违法或侵犯他人权益的行为
|
||||
2. 境内使用该项目时,请使用国内厂商的大模型服务,并进行必要的内容安全审核及过滤
|
||||
3. 本项目主要接入协同办公平台,推荐使用公众号、企微自建应用、钉钉、飞书等接入通道,其他通道为历史产物已不维护
|
||||
4. 任何个人、团队和企业,无论以何种方式使用该项目、对何对象提供服务,所产生的一切后果,本项目均不承担任何责任
|
||||
|
||||
>**2023.03.25:** 支持插件化开发,目前已实现 多角色切换、文字冒险游戏、管理员指令、Stable Diffusion等插件,使用参考 [#578](https://github.com/zhayujie/chatgpt-on-wechat/issues/578)。(contributed by [@lanvent](https://github.com/lanvent) in [#565](https://github.com/zhayujie/chatgpt-on-wechat/pull/565))
|
||||
## 演示
|
||||
|
||||
>**2023.03.09:** 基于 `whisper API` 实现对微信语音消息的解析和回复,添加配置项 `"speech_recognition":true` 即可启用,使用参考 [#415](https://github.com/zhayujie/chatgpt-on-wechat/issues/415)。(contributed by [wanggang1987](https://github.com/wanggang1987) in [#385](https://github.com/zhayujie/chatgpt-on-wechat/pull/385))
|
||||
DEMO视频:https://cdn.link-ai.tech/doc/cow_demo.mp4
|
||||
|
||||
>**2023.03.02:** 接入[ChatGPT API](https://platform.openai.com/docs/guides/chat) (gpt-3.5-turbo),默认使用该模型进行对话,需升级openai依赖 (`pip3 install --upgrade openai`)。网络问题参考 [#351](https://github.com/zhayujie/chatgpt-on-wechat/issues/351)
|
||||
## 社区
|
||||
|
||||
>**2023.02.09:** 扫码登录存在封号风险,请谨慎使用,参考[#58](https://github.com/AutumnWhj/ChatGPT-wechat-bot/issues/158)
|
||||
添加小助手微信加入开源项目交流群:
|
||||
|
||||
>**2023.02.05:** 在openai官方接口方案中 (GPT-3模型) 实现上下文对话
|
||||
<img width="160" src="https://img-1317903499.cos.ap-guangzhou.myqcloud.com/docs/open-community.png">
|
||||
|
||||
>**2022.12.18:** 支持根据描述生成图片并发送,openai版本需大于0.25.0
|
||||
<br>
|
||||
|
||||
>**2022.12.17:** 原来的方案是从 [ChatGPT页面](https://chat.openai.com/chat) 获取session_token,使用 [revChatGPT](https://github.com/acheong08/ChatGPT) 直接访问web接口,但随着ChatGPT接入Cloudflare人机验证,这一方案难以在服务器顺利运行。 所以目前使用的方案是调用 OpenAI 官方提供的 [API](https://beta.openai.com/docs/api-reference/introduction),回复质量上基本接近于ChatGPT的内容,劣势是暂不支持有上下文记忆的对话,优势是稳定性和响应速度较好。
|
||||
# 企业服务
|
||||
|
||||
# 使用效果
|
||||
<a href="https://link-ai.tech" target="_blank"><img width="800" src="https://cdn.link-ai.tech/image/link-ai-intro.jpg"></a>
|
||||
|
||||
### 个人聊天
|
||||
> [LinkAI](https://link-ai.tech/) 是面向企业和开发者的一站式AI应用平台,聚合多模态大模型、知识库、Agent 插件、工作流等能力,支持一键接入主流平台并进行管理,支持SaaS、私有化部署多种模式。
|
||||
>
|
||||
> LinkAI 目前 已在私域运营、智能客服、企业效率助手等场景积累了丰富的 AI 解决方案, 在电商、文教、健康、新消费、科技制造等各行业沉淀了大模型落地应用的最佳实践,致力于帮助更多企业和开发者拥抱 AI 生产力。
|
||||
|
||||

|
||||
**企业服务和产品咨询** 可联系产品顾问:
|
||||
|
||||
### 群组聊天
|
||||
<img width="160" src="https://img-1317903499.cos.ap-guangzhou.myqcloud.com/docs/github-product-consult.png">
|
||||
|
||||

|
||||
<br>
|
||||
|
||||
### 图片生成
|
||||
# 🏷 更新日志
|
||||
|
||||

|
||||
>**2024.07.19:** [1.6.9版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.6.9) 新增 gpt-4o-mini 模型、阿里语音识别、企微应用渠道路由优化
|
||||
|
||||
>**2024.07.05:** [1.6.8版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.6.8) 和 [1.6.7版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.6.7),Claude3.5, Gemini 1.5 Pro, MiniMax模型、工作流图片输入、模型列表完善
|
||||
|
||||
# 快速开始
|
||||
>**2024.06.04:** [1.6.6版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.6.6) 和 [1.6.5版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.6.5),gpt-4o模型、钉钉流式卡片、讯飞语音识别/合成
|
||||
|
||||
## 准备
|
||||
>**2024.04.26:** [1.6.0版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.6.0),新增 Kimi 接入、gpt-4-turbo版本升级、文件总结和语音识别问题修复
|
||||
|
||||
### 1. OpenAI账号注册
|
||||
>**2024.03.26:** [1.5.8版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.5.8) 和 [1.5.7版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.5.7),新增 GLM-4、Claude-3 模型,edge-tts 语音支持
|
||||
|
||||
前往 [OpenAI注册页面](https://beta.openai.com/signup) 创建账号,参考这篇 [教程](https://www.pythonthree.com/register-openai-chatgpt/) 可以通过虚拟手机号来接收验证码。创建完账号则前往 [API管理页面](https://beta.openai.com/account/api-keys) 创建一个 API Key 并保存下来,后面需要在项目中配置这个key。
|
||||
>**2024.01.26:** [1.5.6版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.5.6) 和 [1.5.5版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.5.5),钉钉接入,tool插件升级,4-turbo模型更新
|
||||
|
||||
> 项目中使用的对话模型是 davinci,计费方式是约每 750 字 (包含请求和回复) 消耗 $0.02,图片生成是每张消耗 $0.016,账号创建有免费的 $18 额度 (更新3.25: 最新注册的已经无免费额度了),使用完可以更换邮箱重新注册。
|
||||
>**2023.11.11:** [1.5.3版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.5.3) 和 [1.5.4版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.5.4),新增通义千问模型、Google Gemini
|
||||
|
||||
#### 1.1 ChapGPT service On Azure
|
||||
一种替换以上的方法是使用Azure推出的[ChatGPT service](https://azure.microsoft.com/en-in/products/cognitive-services/openai-service/)。它host在公有云Azure上,因此不需要VPN就可以直接访问。不过目前仍然处于preview阶段。新用户可以通过Try Azure for free来薅一段时间的羊毛
|
||||
>**2023.11.10:** [1.5.2版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.5.2),新增飞书通道、图像识别对话、黑名单配置
|
||||
|
||||
>**2023.11.10:** [1.5.0版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.5.0),新增 `gpt-4-turbo`, `dall-e-3`, `tts` 模型接入,完善图像理解&生成、语音识别&生成的多模态能力
|
||||
|
||||
>**2023.10.16:** 支持通过意图识别使用LinkAI联网搜索、数学计算、网页访问等插件,参考[插件文档](https://docs.link-ai.tech/platform/plugins)
|
||||
|
||||
>**2023.09.26:** 插件增加 文件/文章链接 一键总结和对话的功能,使用参考:[插件说明](https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/linkai#3%E6%96%87%E6%A1%A3%E6%80%BB%E7%BB%93%E5%AF%B9%E8%AF%9D%E5%8A%9F%E8%83%BD)
|
||||
|
||||
>**2023.08.08:** 接入百度文心一言模型,通过 [插件](https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/linkai) 支持 Midjourney 绘图
|
||||
|
||||
>**2023.06.12:** 接入 [LinkAI](https://link-ai.tech/console) 平台,可在线创建领域知识库,打造专属客服机器人。使用参考 [接入文档](https://link-ai.tech/platform/link-app/wechat)。
|
||||
|
||||
更早更新日志查看: [归档日志](/docs/version/old-version.md)
|
||||
|
||||
<br>
|
||||
|
||||
# 🚀 快速开始
|
||||
|
||||
快速开始详细文档:[项目搭建文档](https://docs.link-ai.tech/cow/quick-start)
|
||||
|
||||
## 一、准备
|
||||
|
||||
### 1. 账号注册
|
||||
|
||||
项目默认使用OpenAI接口,需前往 [OpenAI注册页面](https://beta.openai.com/signup) 创建账号,创建完账号则前往 [API管理页面](https://beta.openai.com/account/api-keys) 创建一个 API Key 并保存下来,后面需要在项目中配置这个key。接口需要海外网络访问及绑定信用卡支付。
|
||||
|
||||
> 默认对话模型是 openai 的 gpt-3.5-turbo,计费方式是约每 1000tokens (约750个英文单词 或 500汉字,包含请求和回复) 消耗 $0.002,图片生成是Dell E模型,每张消耗 $0.016。
|
||||
|
||||
项目同时也支持使用 LinkAI 接口,无需代理,可使用 Kimi、文心、讯飞、GPT-3.5、GPT-4o 等模型,支持 定制化知识库、联网搜索、MJ绘图、文档总结、工作流等能力。修改配置即可一键使用,参考 [接入文档](https://link-ai.tech/platform/link-app/wechat)。
|
||||
|
||||
### 2.运行环境
|
||||
|
||||
支持 Linux、MacOS、Windows 系统(可在Linux服务器上长期运行),同时需安装 `Python`。
|
||||
> 建议Python版本在 3.7.1~3.9.X 之间,3.10及以上版本在 MacOS 可用,其他系统上不确定能否正常运行。
|
||||
支持 Linux、MacOS、Windows 系统(可在Linux服务器上长期运行),同时需安装 `Python`。
|
||||
> 建议Python版本在 3.7.1~3.9.X 之间,推荐3.8版本,3.10及以上版本在 MacOS 可用,其他系统上不确定能否正常运行。
|
||||
|
||||
> 注意:Docker 或 Railway 部署无需安装python环境和下载源码,可直接快进到下一节。
|
||||
|
||||
**(1) 克隆项目代码:**
|
||||
|
||||
@@ -70,20 +104,22 @@ git clone https://github.com/zhayujie/chatgpt-on-wechat
|
||||
cd chatgpt-on-wechat/
|
||||
```
|
||||
|
||||
注: 如遇到网络问题可选择国内镜像 https://gitee.com/zhayujie/chatgpt-on-wechat
|
||||
|
||||
**(2) 安装核心依赖 (必选):**
|
||||
> 能够使用`itchat`创建机器人,并具有文字交流功能所需的最小依赖集合。
|
||||
```bash
|
||||
pip3 install -r requirements.txt
|
||||
```
|
||||
|
||||
**(3) 拓展依赖 (可选,建议安装):**
|
||||
|
||||
```bash
|
||||
pip3 install itchat-uos==1.5.0.dev0
|
||||
pip3 install --upgrade openai
|
||||
pip3 install -r requirements-optional.txt
|
||||
```
|
||||
注:`itchat-uos`使用指定版本1.5.0.dev0,`openai`使用最新版本,需高于0.27.0。
|
||||
> 如果某项依赖安装失败可注释掉对应的行再继续
|
||||
|
||||
**(3) 拓展依赖 (可选):**
|
||||
|
||||
语音识别及语音回复相关依赖:[#415](https://github.com/zhayujie/chatgpt-on-wechat/issues/415)。
|
||||
|
||||
|
||||
## 配置
|
||||
## 二、配置
|
||||
|
||||
配置文件的模板在根目录的`config-template.json`中,需复制该模板创建最终生效的 `config.json` 文件:
|
||||
|
||||
@@ -91,24 +127,30 @@ pip3 install --upgrade openai
|
||||
cp config-template.json config.json
|
||||
```
|
||||
|
||||
然后在`config.json`中填入配置,以下是对默认配置的说明,可根据需要进行自定义修改:
|
||||
然后在`config.json`中填入配置,以下是对默认配置的说明,可根据需要进行自定义修改(注意实际使用时请去掉注释,保证JSON格式的完整):
|
||||
|
||||
```bash
|
||||
# config.json文件内容示例
|
||||
{
|
||||
"open_ai_api_key": "YOUR API KEY", # 填入上面创建的 OpenAI API KEY
|
||||
"model": "gpt-3.5-turbo", # 模型名称。当use_azure_chatgpt为true时,其名称为Azure上model deployment名称
|
||||
"proxy": "127.0.0.1:7890", # 代理客户端的ip和端口
|
||||
{
|
||||
"model": "gpt-3.5-turbo", # 模型名称, 支持 gpt-3.5-turbo, gpt-4, gpt-4-turbo, wenxin, xunfei, glm-4, claude-3-haiku, moonshot
|
||||
"open_ai_api_key": "YOUR API KEY", # 如果使用openAI模型则填入上面创建的 OpenAI API KEY
|
||||
"proxy": "", # 代理客户端的ip和端口,国内环境开启代理的需要填写该项,如 "127.0.0.1:7890"
|
||||
"single_chat_prefix": ["bot", "@bot"], # 私聊时文本需要包含该前缀才能触发机器人回复
|
||||
"single_chat_reply_prefix": "[bot] ", # 私聊时自动回复的前缀,用于区分真人
|
||||
"group_chat_prefix": ["@bot"], # 群聊时包含该前缀则会触发机器人回复
|
||||
"group_name_white_list": ["ChatGPT测试群", "ChatGPT测试群2"], # 开启自动回复的群名称列表
|
||||
"group_chat_in_one_session": ["ChatGPT测试群"], # 支持会话上下文共享的群名称
|
||||
"group_chat_in_one_session": ["ChatGPT测试群"], # 支持会话上下文共享的群名称
|
||||
"image_create_prefix": ["画", "看", "找"], # 开启图片回复的前缀
|
||||
"conversation_max_tokens": 1000, # 支持上下文记忆的最多字符数
|
||||
"speech_recognition": false, # 是否开启语音识别
|
||||
"use_azure_chatgpt": false, # 是否使用Azure ChatGPT service代替openai ChatGPT service. 当设置为true时需要设置 open_ai_api_base,如 https://xxx.openai.azure.com/
|
||||
"character_desc": "你是ChatGPT, 一个由OpenAI训练的大型语言模型, 你旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。", # 人格描述,
|
||||
"group_speech_recognition": false, # 是否开启群组语音识别
|
||||
"voice_reply_voice": false, # 是否使用语音回复语音
|
||||
"character_desc": "你是基于大语言模型的AI智能助手,旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。", # 人格描述
|
||||
# 订阅消息,公众号和企业微信channel中请填写,当被订阅时会自动回复,可使用特殊占位符。目前支持的占位符有{trigger_prefix},在程序中它会自动替换成bot的触发词。
|
||||
"subscribe_msg": "感谢您的关注!\n这里是ChatGPT,可以自由对话。\n支持语音对话。\n支持图片输出,画字开头的消息将按要求创作图片。\n支持角色扮演和文字冒险等丰富插件。\n输入{trigger_prefix}#help 查看详细指令。",
|
||||
"use_linkai": false, # 是否使用LinkAI接口,默认关闭,开启后可国内访问,使用知识库和MJ
|
||||
"linkai_api_key": "", # LinkAI Api Key
|
||||
"linkai_app_code": "" # LinkAI 应用或工作流code
|
||||
}
|
||||
```
|
||||
**配置说明:**
|
||||
@@ -127,72 +169,131 @@ pip3 install --upgrade openai
|
||||
|
||||
**3.语音识别**
|
||||
|
||||
+ 添加 `"speech_recognition": true` 将开启语音识别,默认使用openai的whisper模型识别为文字,同时以文字回复,目前只支持私聊 (注意由于语音消息无法匹配前缀,一旦开启将对所有语音自动回复);
|
||||
+ 添加 `"voice_reply_voice": true` 将开启语音回复语音,但是需要配置对应语音合成平台的key,由于itchat协议的限制,只能发送语音mp3文件,若使用wechaty则回复的是微信语音。
|
||||
+ 添加 `"speech_recognition": true` 将开启语音识别,默认使用openai的whisper模型识别为文字,同时以文字回复,该参数仅支持私聊 (注意由于语音消息无法匹配前缀,一旦开启将对所有语音自动回复,支持语音触发画图);
|
||||
+ 添加 `"group_speech_recognition": true` 将开启群组语音识别,默认使用openai的whisper模型识别为文字,同时以文字回复,参数仅支持群聊 (会匹配group_chat_prefix和group_chat_keyword, 支持语音触发画图);
|
||||
+ 添加 `"voice_reply_voice": true` 将开启语音回复语音(同时作用于私聊和群聊)
|
||||
|
||||
**4.其他配置**
|
||||
|
||||
+ `model`: 模型名称,目前支持 `gpt-3.5-turbo`, `text-davinci-003`, `gpt-4`, `gpt-4-32k` (其中gpt-4 api暂未开放)
|
||||
+ `temperature`,`frequency_penalty`,`presence_penalty`: Chat API接口参数,详情参考[OpenAI官方文档。](https://platform.openai.com/docs/api-reference/chat)
|
||||
+ `model`: 模型名称,目前支持 `gpt-3.5-turbo`, `gpt-4o-mini`, `gpt-4o`, `gpt-4`, `wenxin` , `claude` , `gemini`, `glm-4`, `xunfei`, `moonshot`等,全部模型名称参考[common/const.py](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/common/const.py)文件
|
||||
+ `temperature`,`frequency_penalty`,`presence_penalty`: Chat API接口参数,详情参考[OpenAI官方文档。](https://platform.openai.com/docs/api-reference/chat)
|
||||
+ `proxy`:由于目前 `openai` 接口国内无法访问,需配置代理客户端的地址,详情参考 [#351](https://github.com/zhayujie/chatgpt-on-wechat/issues/351)
|
||||
+ 对于图像生成,在满足个人或群组触发条件外,还需要额外的关键词前缀来触发,对应配置 `image_create_prefix `
|
||||
+ 关于OpenAI对话及图片接口的参数配置(内容自由度、回复字数限制、图片大小等),可以参考 [对话接口](https://beta.openai.com/docs/api-reference/completions) 和 [图像接口](https://beta.openai.com/docs/api-reference/completions) 文档直接在 [代码](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/bot/openai/open_ai_bot.py) `bot/openai/open_ai_bot.py` 中进行调整。
|
||||
+ 关于OpenAI对话及图片接口的参数配置(内容自由度、回复字数限制、图片大小等),可以参考 [对话接口](https://beta.openai.com/docs/api-reference/completions) 和 [图像接口](https://beta.openai.com/docs/api-reference/completions) 文档,在[`config.py`](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/config.py)中检查哪些参数在本项目中是可配置的。
|
||||
+ `conversation_max_tokens`:表示能够记忆的上下文最大字数(一问一答为一组对话,如果累积的对话字数超出限制,就会优先移除最早的一组对话)
|
||||
+ `rate_limit_chatgpt`,`rate_limit_dalle`:每分钟最高问答速率、画图速率,超速后排队按序处理。
|
||||
+ `clear_memory_commands`: 对话内指令,主动清空前文记忆,字符串数组可自定义指令别名。
|
||||
+ `hot_reload`: 程序退出后,暂存微信扫码状态,默认关闭。
|
||||
+ `hot_reload`: 程序退出后,暂存等于状态,默认关闭。
|
||||
+ `character_desc` 配置中保存着你对机器人说的一段话,他会记住这段话并作为他的设定,你可以为他定制任何人格 (关于会话上下文的更多内容参考该 [issue](https://github.com/zhayujie/chatgpt-on-wechat/issues/43))
|
||||
+ `subscribe_msg`:订阅消息,公众号和企业微信channel中请填写,当被订阅时会自动回复, 可使用特殊占位符。目前支持的占位符有{trigger_prefix},在程序中它会自动替换成bot的触发词。
|
||||
|
||||
**5.LinkAI配置 (可选)**
|
||||
|
||||
## 运行
|
||||
+ `use_linkai`: 是否使用LinkAI接口,开启后可国内访问,使用知识库和 `Midjourney` 绘画, 参考 [文档](https://link-ai.tech/platform/link-app/wechat)
|
||||
+ `linkai_api_key`: LinkAI Api Key,可在 [控制台](https://link-ai.tech/console/interface) 创建
|
||||
+ `linkai_app_code`: LinkAI 应用或工作流的code,选填
|
||||
|
||||
**本说明文档可能会未及时更新,当前所有可选的配置项均在该[`config.py`](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/config.py)中列出。**
|
||||
|
||||
## 三、运行
|
||||
|
||||
### 1.本地运行
|
||||
|
||||
如果是开发机 **本地运行**,直接在项目根目录下执行:
|
||||
|
||||
```bash
|
||||
python3 app.py
|
||||
python3 app.py # windows环境下该命令通常为 python app.py
|
||||
```
|
||||
终端输出二维码后,使用微信进行扫码,当输出 "Start auto replying" 时表示自动回复程序已经成功运行了(注意:用于登录的微信需要在支付处已完成实名认证)。扫码登录后你的账号就成为机器人了,可以在微信手机端通过配置的关键词触发自动回复 (任意好友发送消息给你,或是自己发消息给好友),参考[#142](https://github.com/zhayujie/chatgpt-on-wechat/issues/142)。
|
||||
|
||||
终端输出二维码后,进行扫码登录,当输出 "Start auto replying" 时表示自动回复程序已经成功运行了(注意:用于登录的账号需要在支付处已完成实名认证)。扫码登录后你的账号就成为机器人了,可以在手机端通过配置的关键词触发自动回复 (任意好友发送消息给你,或是自己发消息给好友),参考[#142](https://github.com/zhayujie/chatgpt-on-wechat/issues/142)。
|
||||
|
||||
### 2.服务器部署
|
||||
|
||||
使用nohup命令在后台运行程序:
|
||||
|
||||
```bash
|
||||
touch nohup.out # 首次运行需要新建日志文件
|
||||
nohup python3 app.py & tail -f nohup.out # 在后台运行程序并通过日志输出二维码
|
||||
```
|
||||
扫码登录后程序即可运行于服务器后台,此时可通过 `ctrl+c` 关闭日志,不会影响后台程序的运行。使用 `ps -ef | grep app.py | grep -v grep` 命令可查看运行于后台的进程,如果想要重新启动程序可以先 `kill` 掉对应的进程。日志关闭后如果想要再次打开只需输入 `tail -f nohup.out`。此外,`scripts` 目录下有一键运行、关闭程序的脚本供使用。
|
||||
|
||||
> **注意:** 如果 扫码后手机提示登录验证需要等待5s,而终端的二维码再次刷新并提示 `Log in time out, reloading QR code`,此时需参考此 [issue](https://github.com/zhayujie/chatgpt-on-wechat/issues/8) 修改一行代码即可解决。
|
||||
> **多账号支持:** 将项目复制多份,分别启动程序,用不同账号扫码登录即可实现同时运行。
|
||||
|
||||
> **多账号支持:** 将 项目复制多份,分别启动程序,用不同账号扫码登录即可实现同时运行。
|
||||
|
||||
> **特殊指令:** 用户向机器人发送 **#清除记忆** 即可清空该用户的上下文记忆。
|
||||
> **特殊指令:** 用户向机器人发送 **#reset** 即可清空该用户的上下文记忆。
|
||||
|
||||
|
||||
### 3.Docker部署
|
||||
|
||||
参考文档 [Docker部署](https://github.com/limccn/chatgpt-on-wechat/wiki/Docker%E9%83%A8%E7%BD%B2) (Contributed by [limccn](https://github.com/limccn))。
|
||||
> 使用docker部署无需下载源码和安装依赖,只需要获取 docker-compose.yml 配置文件并启动容器即可。
|
||||
|
||||
> 前提是需要安装好 `docker` 及 `docker-compose`,安装成功的表现是执行 `docker -v` 和 `docker-compose version` (或 docker compose version) 可以查看到版本号,可前往 [docker官网](https://docs.docker.com/engine/install/) 进行下载。
|
||||
|
||||
**(1) 下载 docker-compose.yml 文件**
|
||||
|
||||
```bash
|
||||
wget https://open-1317903499.cos.ap-guangzhou.myqcloud.com/docker-compose.yml
|
||||
```
|
||||
|
||||
下载完成后打开 `docker-compose.yml` 修改所需配置,如 `OPEN_AI_API_KEY` 和 `GROUP_NAME_WHITE_LIST` 等。
|
||||
|
||||
**(2) 启动容器**
|
||||
|
||||
在 `docker-compose.yml` 所在目录下执行以下命令启动容器:
|
||||
|
||||
```bash
|
||||
sudo docker compose up -d
|
||||
```
|
||||
|
||||
运行 `sudo docker ps` 能查看到 NAMES 为 chatgpt-on-wechat 的容器即表示运行成功。
|
||||
|
||||
注意:
|
||||
|
||||
- 如果 `docker-compose` 是 1.X 版本 则需要执行 `sudo docker-compose up -d` 来启动容器
|
||||
- 该命令会自动去 [docker hub](https://hub.docker.com/r/zhayujie/chatgpt-on-wechat) 拉取 latest 版本的镜像,latest 镜像会在每次项目 release 新的版本时生成
|
||||
|
||||
最后运行以下命令可查看容器运行日志,扫描日志中的二维码即可完成登录:
|
||||
|
||||
```bash
|
||||
sudo docker logs -f chatgpt-on-wechat
|
||||
```
|
||||
|
||||
**(3) 插件使用**
|
||||
|
||||
如果需要在docker容器中修改插件配置,可通过挂载的方式完成,将 [插件配置文件](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/plugins/config.json.template)
|
||||
重命名为 `config.json`,放置于 `docker-compose.yml` 相同目录下,并在 `docker-compose.yml` 中的 `chatgpt-on-wechat` 部分下添加 `volumes` 映射:
|
||||
|
||||
```
|
||||
volumes:
|
||||
- ./config.json:/app/plugins/config.json
|
||||
```
|
||||
|
||||
### 4. Railway部署
|
||||
[Use with Railway](#use-with-railway)(PaaS, Free, Stable, ✅Recommended)
|
||||
> Railway offers $5 (500 hours) of runtime per month
|
||||
1. Click the [Railway](https://railway.app/) button to go to the Railway homepage
|
||||
2. Click the `Start New Project` button.
|
||||
3. Click the `Deploy from Github repo` button.
|
||||
4. Choose your repo (you can fork this repo firstly)
|
||||
5. Set environment variable to override settings in config-template.json, such as: model, open_ai_api_base, open_ai_api_key, use_azure_chatgpt etc.
|
||||
|
||||
## 常见问题
|
||||
> Railway 每月提供5刀和最多500小时的免费额度。 (07.11更新: 目前大部分账号已无法免费部署)
|
||||
|
||||
1. 进入 [Railway](https://railway.app/template/qApznZ?referralCode=RC3znh)
|
||||
2. 点击 `Deploy Now` 按钮。
|
||||
3. 设置环境变量来重载程序运行的参数,例如`open_ai_api_key`, `character_desc`。
|
||||
|
||||
**一键部署:**
|
||||
|
||||
[](https://railway.app/template/qApznZ?referralCode=RC3znh)
|
||||
|
||||
<br>
|
||||
|
||||
# 🔎 常见问题
|
||||
|
||||
FAQs: <https://github.com/zhayujie/chatgpt-on-wechat/wiki/FAQs>
|
||||
|
||||
或直接在线咨询 [项目小助手](https://link-ai.tech/app/Kv2fXJcH) (语料持续完善中,回复仅供参考)
|
||||
|
||||
## 联系
|
||||
# 🛠️ 开发
|
||||
|
||||
欢迎提交PR、Issues,以及Star支持一下。程序运行遇到问题优先查看 [常见问题列表](https://github.com/zhayujie/chatgpt-on-wechat/wiki/FAQs) ,其次前往 [Issues](https://github.com/zhayujie/chatgpt-on-wechat/issues) 中搜索,若无相似问题可创建Issue,或加微信 eijuyahz 交流。
|
||||
欢迎接入更多应用,参考 [Terminal代码](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/channel/terminal/terminal_channel.py) 实现接收和发送消息逻辑即可接入。 同时欢迎增加新的插件,参考 [插件说明文档](https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins)。
|
||||
|
||||
|
||||
# ✉ 联系
|
||||
|
||||
欢迎提交PR、Issues,以及Star支持一下。程序运行遇到问题可以查看 [常见问题列表](https://github.com/zhayujie/chatgpt-on-wechat/wiki/FAQs) ,其次前往 [Issues](https://github.com/zhayujie/chatgpt-on-wechat/issues) 中搜索。个人开发者可加入开源交流群参与更多讨论,企业用户可联系[产品顾问](https://img-1317903499.cos.ap-guangzhou.myqcloud.com/docs/product-manager-qrcode.jpg)咨询。
|
||||
|
||||
# 🌟 贡献者
|
||||
|
||||

|
||||
|
||||
68
app.py
68
app.py
@@ -1,27 +1,71 @@
|
||||
# encoding:utf-8
|
||||
|
||||
import config
|
||||
from channel import channel_factory
|
||||
from common.log import logger
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
import time
|
||||
|
||||
from channel import channel_factory
|
||||
from common import const
|
||||
from config import load_config
|
||||
from plugins import *
|
||||
import threading
|
||||
|
||||
|
||||
def sigterm_handler_wrap(_signo):
|
||||
old_handler = signal.getsignal(_signo)
|
||||
|
||||
def func(_signo, _stack_frame):
|
||||
logger.info("signal {} received, exiting...".format(_signo))
|
||||
conf().save_user_datas()
|
||||
if callable(old_handler): # check old_handler
|
||||
return old_handler(_signo, _stack_frame)
|
||||
sys.exit(0)
|
||||
|
||||
signal.signal(_signo, func)
|
||||
|
||||
|
||||
def start_channel(channel_name: str):
|
||||
channel = channel_factory.create_channel(channel_name)
|
||||
if channel_name in ["wx", "wxy", "terminal", "wechatmp", "wechatmp_service", "wechatcom_app", "wework",
|
||||
const.FEISHU, const.DINGTALK]:
|
||||
PluginManager().load_plugins()
|
||||
|
||||
if conf().get("use_linkai"):
|
||||
try:
|
||||
from common import linkai_client
|
||||
threading.Thread(target=linkai_client.start, args=(channel,)).start()
|
||||
except Exception as e:
|
||||
pass
|
||||
channel.startup()
|
||||
|
||||
|
||||
def run():
|
||||
try:
|
||||
# load config
|
||||
config.load_config()
|
||||
load_config()
|
||||
# ctrl + c
|
||||
sigterm_handler_wrap(signal.SIGINT)
|
||||
# kill signal
|
||||
sigterm_handler_wrap(signal.SIGTERM)
|
||||
|
||||
# create channel
|
||||
channel_name='wx'
|
||||
channel = channel_factory.create_channel(channel_name)
|
||||
if channel_name=='wx':
|
||||
PluginManager().load_plugins()
|
||||
channel_name = conf().get("channel_type", "wx")
|
||||
|
||||
# startup channel
|
||||
channel.startup()
|
||||
if "--cmd" in sys.argv:
|
||||
channel_name = "terminal"
|
||||
|
||||
if channel_name == "wxy":
|
||||
os.environ["WECHATY_LOG"] = "warn"
|
||||
|
||||
start_channel(channel_name)
|
||||
|
||||
while True:
|
||||
time.sleep(1)
|
||||
except Exception as e:
|
||||
logger.error("App startup failed!")
|
||||
logger.exception(e)
|
||||
|
||||
if __name__ == '__main__':
|
||||
run()
|
||||
|
||||
if __name__ == "__main__":
|
||||
run()
|
||||
|
||||
214
bot/ali/ali_qwen_bot.py
Normal file
214
bot/ali/ali_qwen_bot.py
Normal file
@@ -0,0 +1,214 @@
|
||||
# encoding:utf-8
|
||||
|
||||
import json
|
||||
import time
|
||||
from typing import List, Tuple
|
||||
|
||||
import openai
|
||||
import openai.error
|
||||
import broadscope_bailian
|
||||
from broadscope_bailian import ChatQaMessage
|
||||
|
||||
from bot.bot import Bot
|
||||
from bot.ali.ali_qwen_session import AliQwenSession
|
||||
from bot.session_manager import SessionManager
|
||||
from bridge.context import ContextType
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from common.log import logger
|
||||
from common import const
|
||||
from config import conf, load_config
|
||||
|
||||
class AliQwenBot(Bot):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.api_key_expired_time = self.set_api_key()
|
||||
self.sessions = SessionManager(AliQwenSession, model=conf().get("model", const.QWEN))
|
||||
|
||||
def api_key_client(self):
|
||||
return broadscope_bailian.AccessTokenClient(access_key_id=self.access_key_id(), access_key_secret=self.access_key_secret())
|
||||
|
||||
def access_key_id(self):
|
||||
return conf().get("qwen_access_key_id")
|
||||
|
||||
def access_key_secret(self):
|
||||
return conf().get("qwen_access_key_secret")
|
||||
|
||||
def agent_key(self):
|
||||
return conf().get("qwen_agent_key")
|
||||
|
||||
def app_id(self):
|
||||
return conf().get("qwen_app_id")
|
||||
|
||||
def node_id(self):
|
||||
return conf().get("qwen_node_id", "")
|
||||
|
||||
def temperature(self):
|
||||
return conf().get("temperature", 0.2 )
|
||||
|
||||
def top_p(self):
|
||||
return conf().get("top_p", 1)
|
||||
|
||||
def reply(self, query, context=None):
|
||||
# acquire reply content
|
||||
if context.type == ContextType.TEXT:
|
||||
logger.info("[QWEN] query={}".format(query))
|
||||
|
||||
session_id = context["session_id"]
|
||||
reply = None
|
||||
clear_memory_commands = conf().get("clear_memory_commands", ["#清除记忆"])
|
||||
if query in clear_memory_commands:
|
||||
self.sessions.clear_session(session_id)
|
||||
reply = Reply(ReplyType.INFO, "记忆已清除")
|
||||
elif query == "#清除所有":
|
||||
self.sessions.clear_all_session()
|
||||
reply = Reply(ReplyType.INFO, "所有人记忆已清除")
|
||||
elif query == "#更新配置":
|
||||
load_config()
|
||||
reply = Reply(ReplyType.INFO, "配置已更新")
|
||||
if reply:
|
||||
return reply
|
||||
session = self.sessions.session_query(query, session_id)
|
||||
logger.debug("[QWEN] session query={}".format(session.messages))
|
||||
|
||||
reply_content = self.reply_text(session)
|
||||
logger.debug(
|
||||
"[QWEN] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(
|
||||
session.messages,
|
||||
session_id,
|
||||
reply_content["content"],
|
||||
reply_content["completion_tokens"],
|
||||
)
|
||||
)
|
||||
if reply_content["completion_tokens"] == 0 and len(reply_content["content"]) > 0:
|
||||
reply = Reply(ReplyType.ERROR, reply_content["content"])
|
||||
elif reply_content["completion_tokens"] > 0:
|
||||
self.sessions.session_reply(reply_content["content"], session_id, reply_content["total_tokens"])
|
||||
reply = Reply(ReplyType.TEXT, reply_content["content"])
|
||||
else:
|
||||
reply = Reply(ReplyType.ERROR, reply_content["content"])
|
||||
logger.debug("[QWEN] reply {} used 0 tokens.".format(reply_content))
|
||||
return reply
|
||||
|
||||
else:
|
||||
reply = Reply(ReplyType.ERROR, "Bot不支持处理{}类型的消息".format(context.type))
|
||||
return reply
|
||||
|
||||
def reply_text(self, session: AliQwenSession, retry_count=0) -> dict:
|
||||
"""
|
||||
call bailian's ChatCompletion to get the answer
|
||||
:param session: a conversation session
|
||||
:param retry_count: retry count
|
||||
:return: {}
|
||||
"""
|
||||
try:
|
||||
prompt, history = self.convert_messages_format(session.messages)
|
||||
self.update_api_key_if_expired()
|
||||
# NOTE 阿里百炼的call()函数未提供temperature参数,考虑到temperature和top_p参数作用相同,取两者较小的值作为top_p参数传入,详情见文档 https://help.aliyun.com/document_detail/2587502.htm
|
||||
response = broadscope_bailian.Completions().call(app_id=self.app_id(), prompt=prompt, history=history, top_p=min(self.temperature(), self.top_p()))
|
||||
completion_content = self.get_completion_content(response, self.node_id())
|
||||
completion_tokens, total_tokens = self.calc_tokens(session.messages, completion_content)
|
||||
return {
|
||||
"total_tokens": total_tokens,
|
||||
"completion_tokens": completion_tokens,
|
||||
"content": completion_content,
|
||||
}
|
||||
except Exception as e:
|
||||
need_retry = retry_count < 2
|
||||
result = {"completion_tokens": 0, "content": "我现在有点累了,等会再来吧"}
|
||||
if isinstance(e, openai.error.RateLimitError):
|
||||
logger.warn("[QWEN] RateLimitError: {}".format(e))
|
||||
result["content"] = "提问太快啦,请休息一下再问我吧"
|
||||
if need_retry:
|
||||
time.sleep(20)
|
||||
elif isinstance(e, openai.error.Timeout):
|
||||
logger.warn("[QWEN] Timeout: {}".format(e))
|
||||
result["content"] = "我没有收到你的消息"
|
||||
if need_retry:
|
||||
time.sleep(5)
|
||||
elif isinstance(e, openai.error.APIError):
|
||||
logger.warn("[QWEN] Bad Gateway: {}".format(e))
|
||||
result["content"] = "请再问我一次"
|
||||
if need_retry:
|
||||
time.sleep(10)
|
||||
elif isinstance(e, openai.error.APIConnectionError):
|
||||
logger.warn("[QWEN] APIConnectionError: {}".format(e))
|
||||
need_retry = False
|
||||
result["content"] = "我连接不到你的网络"
|
||||
else:
|
||||
logger.exception("[QWEN] Exception: {}".format(e))
|
||||
need_retry = False
|
||||
self.sessions.clear_session(session.session_id)
|
||||
|
||||
if need_retry:
|
||||
logger.warn("[QWEN] 第{}次重试".format(retry_count + 1))
|
||||
return self.reply_text(session, retry_count + 1)
|
||||
else:
|
||||
return result
|
||||
|
||||
def set_api_key(self):
|
||||
api_key, expired_time = self.api_key_client().create_token(agent_key=self.agent_key())
|
||||
broadscope_bailian.api_key = api_key
|
||||
return expired_time
|
||||
|
||||
def update_api_key_if_expired(self):
|
||||
if time.time() > self.api_key_expired_time:
|
||||
self.api_key_expired_time = self.set_api_key()
|
||||
|
||||
def convert_messages_format(self, messages) -> Tuple[str, List[ChatQaMessage]]:
|
||||
history = []
|
||||
user_content = ''
|
||||
assistant_content = ''
|
||||
system_content = ''
|
||||
for message in messages:
|
||||
role = message.get('role')
|
||||
if role == 'user':
|
||||
user_content += message.get('content')
|
||||
elif role == 'assistant':
|
||||
assistant_content = message.get('content')
|
||||
history.append(ChatQaMessage(user_content, assistant_content))
|
||||
user_content = ''
|
||||
assistant_content = ''
|
||||
elif role =='system':
|
||||
system_content += message.get('content')
|
||||
if user_content == '':
|
||||
raise Exception('no user message')
|
||||
if system_content != '':
|
||||
# NOTE 模拟系统消息,测试发现人格描述以"你需要扮演ChatGPT"开头能够起作用,而以"你是ChatGPT"开头模型会直接否认
|
||||
system_qa = ChatQaMessage(system_content, '好的,我会严格按照你的设定回答问题')
|
||||
history.insert(0, system_qa)
|
||||
logger.debug("[QWEN] converted qa messages: {}".format([item.to_dict() for item in history]))
|
||||
logger.debug("[QWEN] user content as prompt: {}".format(user_content))
|
||||
return user_content, history
|
||||
|
||||
def get_completion_content(self, response, node_id):
|
||||
if not response['Success']:
|
||||
return f"[ERROR]\n{response['Code']}:{response['Message']}"
|
||||
text = response['Data']['Text']
|
||||
if node_id == '':
|
||||
return text
|
||||
# TODO: 当使用流程编排创建大模型应用时,响应结构如下,最终结果在['finalResult'][node_id]['response']['text']中,暂时先这么写
|
||||
# {
|
||||
# 'Success': True,
|
||||
# 'Code': None,
|
||||
# 'Message': None,
|
||||
# 'Data': {
|
||||
# 'ResponseId': '9822f38dbacf4c9b8daf5ca03a2daf15',
|
||||
# 'SessionId': 'session_id',
|
||||
# 'Text': '{"finalResult":{"LLM_T7islK":{"params":{"modelId":"qwen-plus-v1","prompt":"${systemVars.query}${bizVars.Text}"},"response":{"text":"作为一个AI语言模型,我没有年龄,因为我没有生日。\n我只是一个程序,没有生命和身体。"}}}}',
|
||||
# 'Thoughts': [],
|
||||
# 'Debug': {},
|
||||
# 'DocReferences': []
|
||||
# },
|
||||
# 'RequestId': '8e11d31551ce4c3f83f49e6e0dd998b0',
|
||||
# 'Failed': None
|
||||
# }
|
||||
text_dict = json.loads(text)
|
||||
completion_content = text_dict['finalResult'][node_id]['response']['text']
|
||||
return completion_content
|
||||
|
||||
def calc_tokens(self, messages, completion_content):
|
||||
completion_tokens = len(completion_content)
|
||||
prompt_tokens = 0
|
||||
for message in messages:
|
||||
prompt_tokens += len(message["content"])
|
||||
return completion_tokens, prompt_tokens + completion_tokens
|
||||
62
bot/ali/ali_qwen_session.py
Normal file
62
bot/ali/ali_qwen_session.py
Normal file
@@ -0,0 +1,62 @@
|
||||
from bot.session_manager import Session
|
||||
from common.log import logger
|
||||
|
||||
"""
|
||||
e.g.
|
||||
[
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "Who won the world series in 2020?"},
|
||||
{"role": "assistant", "content": "The Los Angeles Dodgers won the World Series in 2020."},
|
||||
{"role": "user", "content": "Where was it played?"}
|
||||
]
|
||||
"""
|
||||
|
||||
class AliQwenSession(Session):
|
||||
def __init__(self, session_id, system_prompt=None, model="qianwen"):
|
||||
super().__init__(session_id, system_prompt)
|
||||
self.model = model
|
||||
self.reset()
|
||||
|
||||
def discard_exceeding(self, max_tokens, cur_tokens=None):
|
||||
precise = True
|
||||
try:
|
||||
cur_tokens = self.calc_tokens()
|
||||
except Exception as e:
|
||||
precise = False
|
||||
if cur_tokens is None:
|
||||
raise e
|
||||
logger.debug("Exception when counting tokens precisely for query: {}".format(e))
|
||||
while cur_tokens > max_tokens:
|
||||
if len(self.messages) > 2:
|
||||
self.messages.pop(1)
|
||||
elif len(self.messages) == 2 and self.messages[1]["role"] == "assistant":
|
||||
self.messages.pop(1)
|
||||
if precise:
|
||||
cur_tokens = self.calc_tokens()
|
||||
else:
|
||||
cur_tokens = cur_tokens - max_tokens
|
||||
break
|
||||
elif len(self.messages) == 2 and self.messages[1]["role"] == "user":
|
||||
logger.warn("user message exceed max_tokens. total_tokens={}".format(cur_tokens))
|
||||
break
|
||||
else:
|
||||
logger.debug("max_tokens={}, total_tokens={}, len(messages)={}".format(max_tokens, cur_tokens, len(self.messages)))
|
||||
break
|
||||
if precise:
|
||||
cur_tokens = self.calc_tokens()
|
||||
else:
|
||||
cur_tokens = cur_tokens - max_tokens
|
||||
return cur_tokens
|
||||
|
||||
def calc_tokens(self):
|
||||
return num_tokens_from_messages(self.messages, self.model)
|
||||
|
||||
def num_tokens_from_messages(messages, model):
|
||||
"""Returns the number of tokens used by a list of messages."""
|
||||
# 官方token计算规则:"对于中文文本来说,1个token通常对应一个汉字;对于英文文本来说,1个token通常对应3至4个字母或1个单词"
|
||||
# 详情请产看文档:https://help.aliyun.com/document_detail/2586397.html
|
||||
# 目前根据字符串长度粗略估计token数,不影响正常使用
|
||||
tokens = 0
|
||||
for msg in messages:
|
||||
tokens += len(msg["content"])
|
||||
return tokens
|
||||
@@ -1,6 +1,7 @@
|
||||
# encoding:utf-8
|
||||
|
||||
import requests
|
||||
|
||||
from bot.bot import Bot
|
||||
from bridge.reply import Reply, ReplyType
|
||||
|
||||
@@ -9,20 +10,27 @@ from bridge.reply import Reply, ReplyType
|
||||
class BaiduUnitBot(Bot):
|
||||
def reply(self, query, context=None):
|
||||
token = self.get_token()
|
||||
url = 'https://aip.baidubce.com/rpc/2.0/unit/service/v3/chat?access_token=' + token
|
||||
post_data = "{\"version\":\"3.0\",\"service_id\":\"S73177\",\"session_id\":\"\",\"log_id\":\"7758521\",\"skill_ids\":[\"1221886\"],\"request\":{\"terminal_id\":\"88888\",\"query\":\"" + query + "\", \"hyper_params\": {\"chat_custom_bot_profile\": 1}}}"
|
||||
url = "https://aip.baidubce.com/rpc/2.0/unit/service/v3/chat?access_token=" + token
|
||||
post_data = (
|
||||
'{"version":"3.0","service_id":"S73177","session_id":"","log_id":"7758521","skill_ids":["1221886"],"request":{"terminal_id":"88888","query":"'
|
||||
+ query
|
||||
+ '", "hyper_params": {"chat_custom_bot_profile": 1}}}'
|
||||
)
|
||||
print(post_data)
|
||||
headers = {'content-type': 'application/x-www-form-urlencoded'}
|
||||
headers = {"content-type": "application/x-www-form-urlencoded"}
|
||||
response = requests.post(url, data=post_data.encode(), headers=headers)
|
||||
if response:
|
||||
reply = Reply(ReplyType.TEXT, response.json()['result']['context']['SYS_PRESUMED_HIST'][1])
|
||||
reply = Reply(
|
||||
ReplyType.TEXT,
|
||||
response.json()["result"]["context"]["SYS_PRESUMED_HIST"][1],
|
||||
)
|
||||
return reply
|
||||
|
||||
def get_token(self):
|
||||
access_key = 'YOUR_ACCESS_KEY'
|
||||
secret_key = 'YOUR_SECRET_KEY'
|
||||
host = 'https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id=' + access_key + '&client_secret=' + secret_key
|
||||
access_key = "YOUR_ACCESS_KEY"
|
||||
secret_key = "YOUR_SECRET_KEY"
|
||||
host = "https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id=" + access_key + "&client_secret=" + secret_key
|
||||
response = requests.get(host)
|
||||
if response:
|
||||
print(response.json())
|
||||
return response.json()['access_token']
|
||||
return response.json()["access_token"]
|
||||
|
||||
115
bot/baidu/baidu_wenxin.py
Normal file
115
bot/baidu/baidu_wenxin.py
Normal file
@@ -0,0 +1,115 @@
|
||||
# encoding:utf-8
|
||||
|
||||
import requests
|
||||
import json
|
||||
from common import const
|
||||
from bot.bot import Bot
|
||||
from bot.session_manager import SessionManager
|
||||
from bridge.context import ContextType
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from common.log import logger
|
||||
from config import conf
|
||||
from bot.baidu.baidu_wenxin_session import BaiduWenxinSession
|
||||
|
||||
BAIDU_API_KEY = conf().get("baidu_wenxin_api_key")
|
||||
BAIDU_SECRET_KEY = conf().get("baidu_wenxin_secret_key")
|
||||
|
||||
class BaiduWenxinBot(Bot):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
wenxin_model = conf().get("baidu_wenxin_model")
|
||||
if wenxin_model is not None:
|
||||
wenxin_model = conf().get("baidu_wenxin_model") or "eb-instant"
|
||||
else:
|
||||
if conf().get("model") and conf().get("model") == const.WEN_XIN:
|
||||
wenxin_model = "completions"
|
||||
elif conf().get("model") and conf().get("model") == const.WEN_XIN_4:
|
||||
wenxin_model = "completions_pro"
|
||||
|
||||
self.sessions = SessionManager(BaiduWenxinSession, model=wenxin_model)
|
||||
|
||||
def reply(self, query, context=None):
|
||||
# acquire reply content
|
||||
if context and context.type:
|
||||
if context.type == ContextType.TEXT:
|
||||
logger.info("[BAIDU] query={}".format(query))
|
||||
session_id = context["session_id"]
|
||||
reply = None
|
||||
if query == "#清除记忆":
|
||||
self.sessions.clear_session(session_id)
|
||||
reply = Reply(ReplyType.INFO, "记忆已清除")
|
||||
elif query == "#清除所有":
|
||||
self.sessions.clear_all_session()
|
||||
reply = Reply(ReplyType.INFO, "所有人记忆已清除")
|
||||
else:
|
||||
session = self.sessions.session_query(query, session_id)
|
||||
result = self.reply_text(session)
|
||||
total_tokens, completion_tokens, reply_content = (
|
||||
result["total_tokens"],
|
||||
result["completion_tokens"],
|
||||
result["content"],
|
||||
)
|
||||
logger.debug(
|
||||
"[BAIDU] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(session.messages, session_id, reply_content, completion_tokens)
|
||||
)
|
||||
|
||||
if total_tokens == 0:
|
||||
reply = Reply(ReplyType.ERROR, reply_content)
|
||||
else:
|
||||
self.sessions.session_reply(reply_content, session_id, total_tokens)
|
||||
reply = Reply(ReplyType.TEXT, reply_content)
|
||||
return reply
|
||||
elif context.type == ContextType.IMAGE_CREATE:
|
||||
ok, retstring = self.create_img(query, 0)
|
||||
reply = None
|
||||
if ok:
|
||||
reply = Reply(ReplyType.IMAGE_URL, retstring)
|
||||
else:
|
||||
reply = Reply(ReplyType.ERROR, retstring)
|
||||
return reply
|
||||
|
||||
def reply_text(self, session: BaiduWenxinSession, retry_count=0):
|
||||
try:
|
||||
logger.info("[BAIDU] model={}".format(session.model))
|
||||
access_token = self.get_access_token()
|
||||
if access_token == 'None':
|
||||
logger.warn("[BAIDU] access token 获取失败")
|
||||
return {
|
||||
"total_tokens": 0,
|
||||
"completion_tokens": 0,
|
||||
"content": 0,
|
||||
}
|
||||
url = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/" + session.model + "?access_token=" + access_token
|
||||
headers = {
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
payload = {'messages': session.messages}
|
||||
response = requests.request("POST", url, headers=headers, data=json.dumps(payload))
|
||||
response_text = json.loads(response.text)
|
||||
logger.info(f"[BAIDU] response text={response_text}")
|
||||
res_content = response_text["result"]
|
||||
total_tokens = response_text["usage"]["total_tokens"]
|
||||
completion_tokens = response_text["usage"]["completion_tokens"]
|
||||
logger.info("[BAIDU] reply={}".format(res_content))
|
||||
return {
|
||||
"total_tokens": total_tokens,
|
||||
"completion_tokens": completion_tokens,
|
||||
"content": res_content,
|
||||
}
|
||||
except Exception as e:
|
||||
need_retry = retry_count < 2
|
||||
logger.warn("[BAIDU] Exception: {}".format(e))
|
||||
need_retry = False
|
||||
self.sessions.clear_session(session.session_id)
|
||||
result = {"total_tokens": 0, "completion_tokens": 0, "content": "出错了: {}".format(e)}
|
||||
return result
|
||||
|
||||
def get_access_token(self):
|
||||
"""
|
||||
使用 AK,SK 生成鉴权签名(Access Token)
|
||||
:return: access_token,或是None(如果错误)
|
||||
"""
|
||||
url = "https://aip.baidubce.com/oauth/2.0/token"
|
||||
params = {"grant_type": "client_credentials", "client_id": BAIDU_API_KEY, "client_secret": BAIDU_SECRET_KEY}
|
||||
return str(requests.post(url, params=params).json().get("access_token"))
|
||||
53
bot/baidu/baidu_wenxin_session.py
Normal file
53
bot/baidu/baidu_wenxin_session.py
Normal file
@@ -0,0 +1,53 @@
|
||||
from bot.session_manager import Session
|
||||
from common.log import logger
|
||||
|
||||
"""
|
||||
e.g. [
|
||||
{"role": "user", "content": "Who won the world series in 2020?"},
|
||||
{"role": "assistant", "content": "The Los Angeles Dodgers won the World Series in 2020."},
|
||||
{"role": "user", "content": "Where was it played?"}
|
||||
]
|
||||
"""
|
||||
|
||||
|
||||
class BaiduWenxinSession(Session):
|
||||
def __init__(self, session_id, system_prompt=None, model="gpt-3.5-turbo"):
|
||||
super().__init__(session_id, system_prompt)
|
||||
self.model = model
|
||||
# 百度文心不支持system prompt
|
||||
# self.reset()
|
||||
|
||||
def discard_exceeding(self, max_tokens, cur_tokens=None):
|
||||
precise = True
|
||||
try:
|
||||
cur_tokens = self.calc_tokens()
|
||||
except Exception as e:
|
||||
precise = False
|
||||
if cur_tokens is None:
|
||||
raise e
|
||||
logger.debug("Exception when counting tokens precisely for query: {}".format(e))
|
||||
while cur_tokens > max_tokens:
|
||||
if len(self.messages) >= 2:
|
||||
self.messages.pop(0)
|
||||
self.messages.pop(0)
|
||||
else:
|
||||
logger.debug("max_tokens={}, total_tokens={}, len(messages)={}".format(max_tokens, cur_tokens, len(self.messages)))
|
||||
break
|
||||
if precise:
|
||||
cur_tokens = self.calc_tokens()
|
||||
else:
|
||||
cur_tokens = cur_tokens - max_tokens
|
||||
return cur_tokens
|
||||
|
||||
def calc_tokens(self):
|
||||
return num_tokens_from_messages(self.messages, self.model)
|
||||
|
||||
|
||||
def num_tokens_from_messages(messages, model):
|
||||
"""Returns the number of tokens used by a list of messages."""
|
||||
tokens = 0
|
||||
for msg in messages:
|
||||
# 官方token计算规则暂不明确: "大约为 token数为 "中文字 + 其他语种单词数 x 1.3"
|
||||
# 这里先直接根据字数粗略估算吧,暂不影响正常使用,仅在判断是否丢弃历史会话的时候会有偏差
|
||||
tokens += len(msg["content"])
|
||||
return tokens
|
||||
@@ -8,7 +8,7 @@ from bridge.reply import Reply
|
||||
|
||||
|
||||
class Bot(object):
|
||||
def reply(self, query, context : Context =None) -> Reply:
|
||||
def reply(self, query, context: Context = None) -> Reply:
|
||||
"""
|
||||
bot auto-reply content
|
||||
:param req: received message
|
||||
|
||||
@@ -6,14 +6,16 @@ from common import const
|
||||
|
||||
def create_bot(bot_type):
|
||||
"""
|
||||
create a channel instance
|
||||
:param channel_type: channel type code
|
||||
:return: channel instance
|
||||
create a bot_type instance
|
||||
:param bot_type: bot type code
|
||||
:return: bot instance
|
||||
"""
|
||||
if bot_type == const.BAIDU:
|
||||
# Baidu Unit对话接口
|
||||
from bot.baidu.baidu_unit_bot import BaiduUnitBot
|
||||
return BaiduUnitBot()
|
||||
# 替换Baidu Unit为Baidu文心千帆对话接口
|
||||
# from bot.baidu.baidu_unit_bot import BaiduUnitBot
|
||||
# return BaiduUnitBot()
|
||||
from bot.baidu.baidu_wenxin import BaiduWenxinBot
|
||||
return BaiduWenxinBot()
|
||||
|
||||
elif bot_type == const.CHATGPT:
|
||||
# ChatGPT 网页端web接口
|
||||
@@ -29,4 +31,42 @@ def create_bot(bot_type):
|
||||
# Azure chatgpt service https://azure.microsoft.com/en-in/products/cognitive-services/openai-service/
|
||||
from bot.chatgpt.chat_gpt_bot import AzureChatGPTBot
|
||||
return AzureChatGPTBot()
|
||||
|
||||
elif bot_type == const.XUNFEI:
|
||||
from bot.xunfei.xunfei_spark_bot import XunFeiBot
|
||||
return XunFeiBot()
|
||||
|
||||
elif bot_type == const.LINKAI:
|
||||
from bot.linkai.link_ai_bot import LinkAIBot
|
||||
return LinkAIBot()
|
||||
|
||||
elif bot_type == const.CLAUDEAI:
|
||||
from bot.claude.claude_ai_bot import ClaudeAIBot
|
||||
return ClaudeAIBot()
|
||||
elif bot_type == const.CLAUDEAPI:
|
||||
from bot.claudeapi.claude_api_bot import ClaudeAPIBot
|
||||
return ClaudeAPIBot()
|
||||
elif bot_type == const.QWEN:
|
||||
from bot.ali.ali_qwen_bot import AliQwenBot
|
||||
return AliQwenBot()
|
||||
elif bot_type == const.QWEN_DASHSCOPE:
|
||||
from bot.dashscope.dashscope_bot import DashscopeBot
|
||||
return DashscopeBot()
|
||||
elif bot_type == const.GEMINI:
|
||||
from bot.gemini.google_gemini_bot import GoogleGeminiBot
|
||||
return GoogleGeminiBot()
|
||||
|
||||
elif bot_type == const.ZHIPU_AI:
|
||||
from bot.zhipuai.zhipuai_bot import ZHIPUAIBot
|
||||
return ZHIPUAIBot()
|
||||
|
||||
elif bot_type == const.MOONSHOT:
|
||||
from bot.moonshot.moonshot_bot import MoonshotBot
|
||||
return MoonshotBot()
|
||||
|
||||
elif bot_type == const.MiniMax:
|
||||
from bot.minimax.minimax_bot import MinimaxBot
|
||||
return MinimaxBot()
|
||||
|
||||
|
||||
raise RuntimeError
|
||||
|
||||
@@ -1,67 +1,97 @@
|
||||
# encoding:utf-8
|
||||
|
||||
import time
|
||||
|
||||
import openai
|
||||
import openai.error
|
||||
import requests
|
||||
|
||||
from bot.bot import Bot
|
||||
from bot.chatgpt.chat_gpt_session import ChatGPTSession
|
||||
from bot.openai.open_ai_image import OpenAIImage
|
||||
from bot.session_manager import SessionManager
|
||||
from bridge.context import ContextType
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from config import conf, load_config
|
||||
from common.log import logger
|
||||
from common.token_bucket import TokenBucket
|
||||
from common.expired_dict import ExpiredDict
|
||||
import openai
|
||||
import time
|
||||
from config import conf, load_config
|
||||
|
||||
|
||||
# OpenAI对话模型API (可用)
|
||||
class ChatGPTBot(Bot):
|
||||
class ChatGPTBot(Bot, OpenAIImage):
|
||||
def __init__(self):
|
||||
openai.api_key = conf().get('open_ai_api_key')
|
||||
if conf().get('open_ai_api_base'):
|
||||
openai.api_base = conf().get('open_ai_api_base')
|
||||
proxy = conf().get('proxy')
|
||||
self.sessions = SessionManager()
|
||||
super().__init__()
|
||||
# set the default api_key
|
||||
openai.api_key = conf().get("open_ai_api_key")
|
||||
if conf().get("open_ai_api_base"):
|
||||
openai.api_base = conf().get("open_ai_api_base")
|
||||
proxy = conf().get("proxy")
|
||||
if proxy:
|
||||
openai.proxy = proxy
|
||||
if conf().get('rate_limit_chatgpt'):
|
||||
self.tb4chatgpt = TokenBucket(conf().get('rate_limit_chatgpt', 20))
|
||||
if conf().get('rate_limit_dalle'):
|
||||
self.tb4dalle = TokenBucket(conf().get('rate_limit_dalle', 50))
|
||||
if conf().get("rate_limit_chatgpt"):
|
||||
self.tb4chatgpt = TokenBucket(conf().get("rate_limit_chatgpt", 20))
|
||||
|
||||
self.sessions = SessionManager(ChatGPTSession, model=conf().get("model") or "gpt-3.5-turbo")
|
||||
self.args = {
|
||||
"model": conf().get("model") or "gpt-3.5-turbo", # 对话模型的名称
|
||||
"temperature": conf().get("temperature", 0.9), # 值在[0,1]之间,越大表示回复越具有不确定性
|
||||
# "max_tokens":4096, # 回复最大的字符数
|
||||
"top_p": conf().get("top_p", 1),
|
||||
"frequency_penalty": conf().get("frequency_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
|
||||
"presence_penalty": conf().get("presence_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
|
||||
"request_timeout": conf().get("request_timeout", None), # 请求超时时间,openai接口默认设置为600,对于难问题一般需要较长时间
|
||||
"timeout": conf().get("request_timeout", None), # 重试超时时间,在这个时间内,将会自动重试
|
||||
}
|
||||
|
||||
def reply(self, query, context=None):
|
||||
# acquire reply content
|
||||
if context.type == ContextType.TEXT:
|
||||
logger.info("[OPEN_AI] query={}".format(query))
|
||||
logger.info("[CHATGPT] query={}".format(query))
|
||||
|
||||
session_id = context['session_id']
|
||||
session_id = context["session_id"]
|
||||
reply = None
|
||||
clear_memory_commands = conf().get('clear_memory_commands', ['#清除记忆'])
|
||||
clear_memory_commands = conf().get("clear_memory_commands", ["#清除记忆"])
|
||||
if query in clear_memory_commands:
|
||||
self.sessions.clear_session(session_id)
|
||||
reply = Reply(ReplyType.INFO, '记忆已清除')
|
||||
elif query == '#清除所有':
|
||||
reply = Reply(ReplyType.INFO, "记忆已清除")
|
||||
elif query == "#清除所有":
|
||||
self.sessions.clear_all_session()
|
||||
reply = Reply(ReplyType.INFO, '所有人记忆已清除')
|
||||
elif query == '#更新配置':
|
||||
reply = Reply(ReplyType.INFO, "所有人记忆已清除")
|
||||
elif query == "#更新配置":
|
||||
load_config()
|
||||
reply = Reply(ReplyType.INFO, '配置已更新')
|
||||
reply = Reply(ReplyType.INFO, "配置已更新")
|
||||
if reply:
|
||||
return reply
|
||||
session = self.sessions.build_session_query(query, session_id)
|
||||
logger.debug("[OPEN_AI] session query={}".format(session))
|
||||
session = self.sessions.session_query(query, session_id)
|
||||
logger.debug("[CHATGPT] session query={}".format(session.messages))
|
||||
|
||||
api_key = context.get("openai_api_key")
|
||||
model = context.get("gpt_model")
|
||||
new_args = None
|
||||
if model:
|
||||
new_args = self.args.copy()
|
||||
new_args["model"] = model
|
||||
# if context.get('stream'):
|
||||
# # reply in stream
|
||||
# return self.reply_text_stream(query, new_query, session_id)
|
||||
|
||||
reply_content = self.reply_text(session, session_id, 0)
|
||||
logger.debug("[OPEN_AI] new_query={}, session_id={}, reply_cont={}".format(session, session_id, reply_content["content"]))
|
||||
if reply_content['completion_tokens'] == 0 and len(reply_content['content']) > 0:
|
||||
reply = Reply(ReplyType.ERROR, reply_content['content'])
|
||||
reply_content = self.reply_text(session, api_key, args=new_args)
|
||||
logger.debug(
|
||||
"[CHATGPT] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(
|
||||
session.messages,
|
||||
session_id,
|
||||
reply_content["content"],
|
||||
reply_content["completion_tokens"],
|
||||
)
|
||||
)
|
||||
if reply_content["completion_tokens"] == 0 and len(reply_content["content"]) > 0:
|
||||
reply = Reply(ReplyType.ERROR, reply_content["content"])
|
||||
elif reply_content["completion_tokens"] > 0:
|
||||
self.sessions.save_session(reply_content["content"], session_id, reply_content["total_tokens"])
|
||||
self.sessions.session_reply(reply_content["content"], session_id, reply_content["total_tokens"])
|
||||
reply = Reply(ReplyType.TEXT, reply_content["content"])
|
||||
else:
|
||||
reply = Reply(ReplyType.ERROR, reply_content['content'])
|
||||
logger.debug("[OPEN_AI] reply {} used 0 tokens.".format(reply_content))
|
||||
reply = Reply(ReplyType.ERROR, reply_content["content"])
|
||||
logger.debug("[CHATGPT] reply {} used 0 tokens.".format(reply_content))
|
||||
return reply
|
||||
|
||||
elif context.type == ContextType.IMAGE_CREATE:
|
||||
@@ -73,166 +103,116 @@ class ChatGPTBot(Bot):
|
||||
reply = Reply(ReplyType.ERROR, retstring)
|
||||
return reply
|
||||
else:
|
||||
reply = Reply(ReplyType.ERROR, 'Bot不支持处理{}类型的消息'.format(context.type))
|
||||
reply = Reply(ReplyType.ERROR, "Bot不支持处理{}类型的消息".format(context.type))
|
||||
return reply
|
||||
|
||||
def compose_args(self):
|
||||
return {
|
||||
"model": conf().get("model") or "gpt-3.5-turbo", # 对话模型的名称
|
||||
"temperature":conf().get('temperature', 0.9), # 值在[0,1]之间,越大表示回复越具有不确定性
|
||||
# "max_tokens":4096, # 回复最大的字符数
|
||||
"top_p":1,
|
||||
"frequency_penalty":conf().get('frequency_penalty', 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
|
||||
"presence_penalty":conf().get('presence_penalty', 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
|
||||
}
|
||||
|
||||
def reply_text(self, session, session_id, retry_count=0) -> dict:
|
||||
'''
|
||||
def reply_text(self, session: ChatGPTSession, api_key=None, args=None, retry_count=0) -> dict:
|
||||
"""
|
||||
call openai's ChatCompletion to get the answer
|
||||
:param session: a conversation session
|
||||
:param session_id: session id
|
||||
:param retry_count: retry count
|
||||
:return: {}
|
||||
'''
|
||||
"""
|
||||
try:
|
||||
if conf().get('rate_limit_chatgpt') and not self.tb4chatgpt.get_token():
|
||||
return {"completion_tokens": 0, "content": "提问太快啦,请休息一下再问我吧"}
|
||||
response = openai.ChatCompletion.create(
|
||||
messages=session, **self.compose_args()
|
||||
)
|
||||
if conf().get("rate_limit_chatgpt") and not self.tb4chatgpt.get_token():
|
||||
raise openai.error.RateLimitError("RateLimitError: rate limit exceeded")
|
||||
# if api_key == None, the default openai.api_key will be used
|
||||
if args is None:
|
||||
args = self.args
|
||||
response = openai.ChatCompletion.create(api_key=api_key, messages=session.messages, **args)
|
||||
# logger.debug("[CHATGPT] response={}".format(response))
|
||||
# logger.info("[ChatGPT] reply={}, total_tokens={}".format(response.choices[0]['message']['content'], response["usage"]["total_tokens"]))
|
||||
return {"total_tokens": response["usage"]["total_tokens"],
|
||||
"completion_tokens": response["usage"]["completion_tokens"],
|
||||
"content": response.choices[0]['message']['content']}
|
||||
except openai.error.RateLimitError as e:
|
||||
# rate limit exception
|
||||
logger.warn(e)
|
||||
if retry_count < 1:
|
||||
time.sleep(5)
|
||||
logger.warn("[OPEN_AI] RateLimit exceed, 第{}次重试".format(retry_count+1))
|
||||
return self.reply_text(session, session_id, retry_count+1)
|
||||
else:
|
||||
return {"completion_tokens": 0, "content": "提问太快啦,请休息一下再问我吧"}
|
||||
except openai.error.APIConnectionError as e:
|
||||
# api connection exception
|
||||
logger.warn(e)
|
||||
logger.warn("[OPEN_AI] APIConnection failed")
|
||||
return {"completion_tokens": 0, "content": "我连接不到你的网络"}
|
||||
except openai.error.Timeout as e:
|
||||
logger.warn(e)
|
||||
logger.warn("[OPEN_AI] Timeout")
|
||||
return {"completion_tokens": 0, "content": "我没有收到你的消息"}
|
||||
return {
|
||||
"total_tokens": response["usage"]["total_tokens"],
|
||||
"completion_tokens": response["usage"]["completion_tokens"],
|
||||
"content": response.choices[0]["message"]["content"],
|
||||
}
|
||||
except Exception as e:
|
||||
# unknown exception
|
||||
logger.exception(e)
|
||||
self.sessions.clear_session(session_id)
|
||||
return {"completion_tokens": 0, "content": "请再问我一次吧"}
|
||||
need_retry = retry_count < 2
|
||||
result = {"completion_tokens": 0, "content": "我现在有点累了,等会再来吧"}
|
||||
if isinstance(e, openai.error.RateLimitError):
|
||||
logger.warn("[CHATGPT] RateLimitError: {}".format(e))
|
||||
result["content"] = "提问太快啦,请休息一下再问我吧"
|
||||
if need_retry:
|
||||
time.sleep(20)
|
||||
elif isinstance(e, openai.error.Timeout):
|
||||
logger.warn("[CHATGPT] Timeout: {}".format(e))
|
||||
result["content"] = "我没有收到你的消息"
|
||||
if need_retry:
|
||||
time.sleep(5)
|
||||
elif isinstance(e, openai.error.APIError):
|
||||
logger.warn("[CHATGPT] Bad Gateway: {}".format(e))
|
||||
result["content"] = "请再问我一次"
|
||||
if need_retry:
|
||||
time.sleep(10)
|
||||
elif isinstance(e, openai.error.APIConnectionError):
|
||||
logger.warn("[CHATGPT] APIConnectionError: {}".format(e))
|
||||
result["content"] = "我连接不到你的网络"
|
||||
if need_retry:
|
||||
time.sleep(5)
|
||||
else:
|
||||
logger.exception("[CHATGPT] Exception: {}".format(e))
|
||||
need_retry = False
|
||||
self.sessions.clear_session(session.session_id)
|
||||
|
||||
def create_img(self, query, retry_count=0):
|
||||
try:
|
||||
if conf().get('rate_limit_dalle') and not self.tb4dalle.get_token():
|
||||
return False, "请求太快了,请休息一下再问我吧"
|
||||
logger.info("[OPEN_AI] image_query={}".format(query))
|
||||
response = openai.Image.create(
|
||||
prompt=query, #图片描述
|
||||
n=1, #每次生成图片的数量
|
||||
size="256x256" #图片大小,可选有 256x256, 512x512, 1024x1024
|
||||
)
|
||||
image_url = response['data'][0]['url']
|
||||
logger.info("[OPEN_AI] image_url={}".format(image_url))
|
||||
return True, image_url
|
||||
except openai.error.RateLimitError as e:
|
||||
logger.warn(e)
|
||||
if retry_count < 1:
|
||||
time.sleep(5)
|
||||
logger.warn("[OPEN_AI] ImgCreate RateLimit exceed, 第{}次重试".format(retry_count+1))
|
||||
return self.create_img(query, retry_count+1)
|
||||
if need_retry:
|
||||
logger.warn("[CHATGPT] 第{}次重试".format(retry_count + 1))
|
||||
return self.reply_text(session, api_key, args, retry_count + 1)
|
||||
else:
|
||||
return False, "提问太快啦,请休息一下再问我吧"
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
return False, str(e)
|
||||
return result
|
||||
|
||||
|
||||
class AzureChatGPTBot(ChatGPTBot):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
openai.api_type = "azure"
|
||||
openai.api_version = "2023-03-15-preview"
|
||||
openai.api_version = conf().get("azure_api_version", "2023-06-01-preview")
|
||||
self.args["deployment_id"] = conf().get("azure_deployment_id")
|
||||
|
||||
def compose_args(self):
|
||||
args = super().compose_args()
|
||||
args["engine"] = args["model"]
|
||||
del(args["model"])
|
||||
return args
|
||||
|
||||
|
||||
class SessionManager(object):
|
||||
def __init__(self):
|
||||
if conf().get('expires_in_seconds'):
|
||||
sessions = ExpiredDict(conf().get('expires_in_seconds'))
|
||||
def create_img(self, query, retry_count=0, api_key=None):
|
||||
text_to_image_model = conf().get("text_to_image")
|
||||
if text_to_image_model == "dall-e-2":
|
||||
api_version = "2023-06-01-preview"
|
||||
endpoint = conf().get("azure_openai_dalle_api_base","open_ai_api_base")
|
||||
# 检查endpoint是否以/结尾
|
||||
if not endpoint.endswith("/"):
|
||||
endpoint = endpoint + "/"
|
||||
url = "{}openai/images/generations:submit?api-version={}".format(endpoint, api_version)
|
||||
api_key = conf().get("azure_openai_dalle_api_key","open_ai_api_key")
|
||||
headers = {"api-key": api_key, "Content-Type": "application/json"}
|
||||
try:
|
||||
body = {"prompt": query, "size": conf().get("image_create_size", "256x256"),"n": 1}
|
||||
submission = requests.post(url, headers=headers, json=body)
|
||||
operation_location = submission.headers['operation-location']
|
||||
status = ""
|
||||
while (status != "succeeded"):
|
||||
if retry_count > 3:
|
||||
return False, "图片生成失败"
|
||||
response = requests.get(operation_location, headers=headers)
|
||||
status = response.json()['status']
|
||||
retry_count += 1
|
||||
image_url = response.json()['result']['data'][0]['url']
|
||||
return True, image_url
|
||||
except Exception as e:
|
||||
logger.error("create image error: {}".format(e))
|
||||
return False, "图片生成失败"
|
||||
elif text_to_image_model == "dall-e-3":
|
||||
api_version = conf().get("azure_api_version", "2024-02-15-preview")
|
||||
endpoint = conf().get("azure_openai_dalle_api_base","open_ai_api_base")
|
||||
# 检查endpoint是否以/结尾
|
||||
if not endpoint.endswith("/"):
|
||||
endpoint = endpoint + "/"
|
||||
url = "{}openai/deployments/{}/images/generations?api-version={}".format(endpoint, conf().get("azure_openai_dalle_deployment_id","text_to_image"),api_version)
|
||||
api_key = conf().get("azure_openai_dalle_api_key","open_ai_api_key")
|
||||
headers = {"api-key": api_key, "Content-Type": "application/json"}
|
||||
try:
|
||||
body = {"prompt": query, "size": conf().get("image_create_size", "1024x1024"), "quality": conf().get("dalle3_image_quality", "standard")}
|
||||
submission = requests.post(url, headers=headers, json=body)
|
||||
image_url = submission.json()['data'][0]['url']
|
||||
return True, image_url
|
||||
except Exception as e:
|
||||
logger.error("create image error: {}".format(e))
|
||||
return False, "图片生成失败"
|
||||
else:
|
||||
sessions = dict()
|
||||
self.sessions = sessions
|
||||
|
||||
def build_session(self, session_id, system_prompt=None):
|
||||
session = self.sessions.get(session_id, [])
|
||||
if len(session) == 0:
|
||||
if system_prompt is None:
|
||||
system_prompt = conf().get("character_desc", "")
|
||||
system_item = {'role': 'system', 'content': system_prompt}
|
||||
session.append(system_item)
|
||||
self.sessions[session_id] = session
|
||||
return session
|
||||
|
||||
def build_session_query(self, query, session_id):
|
||||
'''
|
||||
build query with conversation history
|
||||
e.g. [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "Who won the world series in 2020?"},
|
||||
{"role": "assistant", "content": "The Los Angeles Dodgers won the World Series in 2020."},
|
||||
{"role": "user", "content": "Where was it played?"}
|
||||
]
|
||||
:param query: query content
|
||||
:param session_id: session id
|
||||
:return: query content with conversaction
|
||||
'''
|
||||
session = self.build_session(session_id)
|
||||
user_item = {'role': 'user', 'content': query}
|
||||
session.append(user_item)
|
||||
return session
|
||||
|
||||
def save_session(self, answer, session_id, total_tokens):
|
||||
max_tokens = conf().get("conversation_max_tokens")
|
||||
if not max_tokens:
|
||||
# default 3000
|
||||
max_tokens = 1000
|
||||
max_tokens = int(max_tokens)
|
||||
|
||||
session = self.sessions.get(session_id)
|
||||
if session:
|
||||
# append conversation
|
||||
gpt_item = {'role': 'assistant', 'content': answer}
|
||||
session.append(gpt_item)
|
||||
|
||||
# discard exceed limit conversation
|
||||
self.discard_exceed_conversation(session, max_tokens, total_tokens)
|
||||
|
||||
def discard_exceed_conversation(self, session, max_tokens, total_tokens):
|
||||
dec_tokens = int(total_tokens)
|
||||
# logger.info("prompt tokens used={},max_tokens={}".format(used_tokens,max_tokens))
|
||||
while dec_tokens > max_tokens:
|
||||
# pop first conversation
|
||||
if len(session) > 3:
|
||||
session.pop(1)
|
||||
session.pop(1)
|
||||
else:
|
||||
break
|
||||
dec_tokens = dec_tokens - max_tokens
|
||||
|
||||
def clear_session(self, session_id):
|
||||
self.sessions[session_id] = []
|
||||
|
||||
def clear_all_session(self):
|
||||
self.sessions.clear()
|
||||
return False, "图片生成失败,未配置text_to_image参数"
|
||||
|
||||
104
bot/chatgpt/chat_gpt_session.py
Normal file
104
bot/chatgpt/chat_gpt_session.py
Normal file
@@ -0,0 +1,104 @@
|
||||
from bot.session_manager import Session
|
||||
from common.log import logger
|
||||
from common import const
|
||||
|
||||
"""
|
||||
e.g. [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "Who won the world series in 2020?"},
|
||||
{"role": "assistant", "content": "The Los Angeles Dodgers won the World Series in 2020."},
|
||||
{"role": "user", "content": "Where was it played?"}
|
||||
]
|
||||
"""
|
||||
|
||||
|
||||
class ChatGPTSession(Session):
|
||||
def __init__(self, session_id, system_prompt=None, model="gpt-3.5-turbo"):
|
||||
super().__init__(session_id, system_prompt)
|
||||
self.model = model
|
||||
self.reset()
|
||||
|
||||
def discard_exceeding(self, max_tokens, cur_tokens=None):
|
||||
precise = True
|
||||
try:
|
||||
cur_tokens = self.calc_tokens()
|
||||
except Exception as e:
|
||||
precise = False
|
||||
if cur_tokens is None:
|
||||
raise e
|
||||
logger.debug("Exception when counting tokens precisely for query: {}".format(e))
|
||||
while cur_tokens > max_tokens:
|
||||
if len(self.messages) > 2:
|
||||
self.messages.pop(1)
|
||||
elif len(self.messages) == 2 and self.messages[1]["role"] == "assistant":
|
||||
self.messages.pop(1)
|
||||
if precise:
|
||||
cur_tokens = self.calc_tokens()
|
||||
else:
|
||||
cur_tokens = cur_tokens - max_tokens
|
||||
break
|
||||
elif len(self.messages) == 2 and self.messages[1]["role"] == "user":
|
||||
logger.warn("user message exceed max_tokens. total_tokens={}".format(cur_tokens))
|
||||
break
|
||||
else:
|
||||
logger.debug("max_tokens={}, total_tokens={}, len(messages)={}".format(max_tokens, cur_tokens, len(self.messages)))
|
||||
break
|
||||
if precise:
|
||||
cur_tokens = self.calc_tokens()
|
||||
else:
|
||||
cur_tokens = cur_tokens - max_tokens
|
||||
return cur_tokens
|
||||
|
||||
def calc_tokens(self):
|
||||
return num_tokens_from_messages(self.messages, self.model)
|
||||
|
||||
|
||||
# refer to https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
|
||||
def num_tokens_from_messages(messages, model):
|
||||
"""Returns the number of tokens used by a list of messages."""
|
||||
|
||||
if model in ["wenxin", "xunfei", const.GEMINI]:
|
||||
return num_tokens_by_character(messages)
|
||||
|
||||
import tiktoken
|
||||
|
||||
if model in ["gpt-3.5-turbo-0301", "gpt-35-turbo", "gpt-3.5-turbo-1106", "moonshot", const.LINKAI_35]:
|
||||
return num_tokens_from_messages(messages, model="gpt-3.5-turbo")
|
||||
elif model in ["gpt-4-0314", "gpt-4-0613", "gpt-4-32k", "gpt-4-32k-0613", "gpt-3.5-turbo-0613",
|
||||
"gpt-3.5-turbo-16k", "gpt-3.5-turbo-16k-0613", "gpt-35-turbo-16k", "gpt-4-turbo-preview",
|
||||
"gpt-4-1106-preview", const.GPT4_TURBO_PREVIEW, const.GPT4_VISION_PREVIEW, const.GPT4_TURBO_01_25,
|
||||
const.GPT_4o, const.GPT_4o_MINI, const.LINKAI_4o, const.LINKAI_4_TURBO]:
|
||||
return num_tokens_from_messages(messages, model="gpt-4")
|
||||
elif model.startswith("claude-3"):
|
||||
return num_tokens_from_messages(messages, model="gpt-3.5-turbo")
|
||||
try:
|
||||
encoding = tiktoken.encoding_for_model(model)
|
||||
except KeyError:
|
||||
logger.debug("Warning: model not found. Using cl100k_base encoding.")
|
||||
encoding = tiktoken.get_encoding("cl100k_base")
|
||||
if model == "gpt-3.5-turbo":
|
||||
tokens_per_message = 4 # every message follows <|start|>{role/name}\n{content}<|end|>\n
|
||||
tokens_per_name = -1 # if there's a name, the role is omitted
|
||||
elif model == "gpt-4":
|
||||
tokens_per_message = 3
|
||||
tokens_per_name = 1
|
||||
else:
|
||||
logger.warn(f"num_tokens_from_messages() is not implemented for model {model}. Returning num tokens assuming gpt-3.5-turbo.")
|
||||
return num_tokens_from_messages(messages, model="gpt-3.5-turbo")
|
||||
num_tokens = 0
|
||||
for message in messages:
|
||||
num_tokens += tokens_per_message
|
||||
for key, value in message.items():
|
||||
num_tokens += len(encoding.encode(value))
|
||||
if key == "name":
|
||||
num_tokens += tokens_per_name
|
||||
num_tokens += 3 # every reply is primed with <|start|>assistant<|message|>
|
||||
return num_tokens
|
||||
|
||||
|
||||
def num_tokens_by_character(messages):
|
||||
"""Returns the number of tokens used by a list of messages."""
|
||||
tokens = 0
|
||||
for msg in messages:
|
||||
tokens += len(msg["content"])
|
||||
return tokens
|
||||
222
bot/claude/claude_ai_bot.py
Normal file
222
bot/claude/claude_ai_bot.py
Normal file
@@ -0,0 +1,222 @@
|
||||
import re
|
||||
import time
|
||||
import json
|
||||
import uuid
|
||||
from curl_cffi import requests
|
||||
from bot.bot import Bot
|
||||
from bot.claude.claude_ai_session import ClaudeAiSession
|
||||
from bot.openai.open_ai_image import OpenAIImage
|
||||
from bot.session_manager import SessionManager
|
||||
from bridge.context import Context, ContextType
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from common.log import logger
|
||||
from config import conf
|
||||
|
||||
|
||||
class ClaudeAIBot(Bot, OpenAIImage):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.sessions = SessionManager(ClaudeAiSession, model=conf().get("model") or "gpt-3.5-turbo")
|
||||
self.claude_api_cookie = conf().get("claude_api_cookie")
|
||||
self.proxy = conf().get("proxy")
|
||||
self.con_uuid_dic = {}
|
||||
if self.proxy:
|
||||
self.proxies = {
|
||||
"http": self.proxy,
|
||||
"https": self.proxy
|
||||
}
|
||||
else:
|
||||
self.proxies = None
|
||||
self.error = ""
|
||||
self.org_uuid = self.get_organization_id()
|
||||
|
||||
def generate_uuid(self):
|
||||
random_uuid = uuid.uuid4()
|
||||
random_uuid_str = str(random_uuid)
|
||||
formatted_uuid = f"{random_uuid_str[0:8]}-{random_uuid_str[9:13]}-{random_uuid_str[14:18]}-{random_uuid_str[19:23]}-{random_uuid_str[24:]}"
|
||||
return formatted_uuid
|
||||
|
||||
def reply(self, query, context: Context = None) -> Reply:
|
||||
if context.type == ContextType.TEXT:
|
||||
return self._chat(query, context)
|
||||
elif context.type == ContextType.IMAGE_CREATE:
|
||||
ok, res = self.create_img(query, 0)
|
||||
if ok:
|
||||
reply = Reply(ReplyType.IMAGE_URL, res)
|
||||
else:
|
||||
reply = Reply(ReplyType.ERROR, res)
|
||||
return reply
|
||||
else:
|
||||
reply = Reply(ReplyType.ERROR, "Bot不支持处理{}类型的消息".format(context.type))
|
||||
return reply
|
||||
|
||||
def get_organization_id(self):
|
||||
url = "https://claude.ai/api/organizations"
|
||||
headers = {
|
||||
'User-Agent':
|
||||
'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:109.0) Gecko/20100101 Firefox/115.0',
|
||||
'Accept-Language': 'en-US,en;q=0.5',
|
||||
'Referer': 'https://claude.ai/chats',
|
||||
'Content-Type': 'application/json',
|
||||
'Sec-Fetch-Dest': 'empty',
|
||||
'Sec-Fetch-Mode': 'cors',
|
||||
'Sec-Fetch-Site': 'same-origin',
|
||||
'Connection': 'keep-alive',
|
||||
'Cookie': f'{self.claude_api_cookie}'
|
||||
}
|
||||
try:
|
||||
response = requests.get(url, headers=headers, impersonate="chrome110", proxies =self.proxies, timeout=400)
|
||||
res = json.loads(response.text)
|
||||
uuid = res[0]['uuid']
|
||||
except:
|
||||
if "App unavailable" in response.text:
|
||||
logger.error("IP error: The IP is not allowed to be used on Claude")
|
||||
self.error = "ip所在地区不被claude支持"
|
||||
elif "Invalid authorization" in response.text:
|
||||
logger.error("Cookie error: Invalid authorization of claude, check cookie please.")
|
||||
self.error = "无法通过claude身份验证,请检查cookie"
|
||||
return None
|
||||
return uuid
|
||||
|
||||
def conversation_share_check(self,session_id):
|
||||
if conf().get("claude_uuid") is not None and conf().get("claude_uuid") != "":
|
||||
con_uuid = conf().get("claude_uuid")
|
||||
return con_uuid
|
||||
if session_id not in self.con_uuid_dic:
|
||||
self.con_uuid_dic[session_id] = self.generate_uuid()
|
||||
self.create_new_chat(self.con_uuid_dic[session_id])
|
||||
return self.con_uuid_dic[session_id]
|
||||
|
||||
def check_cookie(self):
|
||||
flag = self.get_organization_id()
|
||||
return flag
|
||||
|
||||
def create_new_chat(self, con_uuid):
|
||||
"""
|
||||
新建claude对话实体
|
||||
:param con_uuid: 对话id
|
||||
:return:
|
||||
"""
|
||||
url = f"https://claude.ai/api/organizations/{self.org_uuid}/chat_conversations"
|
||||
payload = json.dumps({"uuid": con_uuid, "name": ""})
|
||||
headers = {
|
||||
'User-Agent':
|
||||
'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:109.0) Gecko/20100101 Firefox/115.0',
|
||||
'Accept-Language': 'en-US,en;q=0.5',
|
||||
'Referer': 'https://claude.ai/chats',
|
||||
'Content-Type': 'application/json',
|
||||
'Origin': 'https://claude.ai',
|
||||
'DNT': '1',
|
||||
'Connection': 'keep-alive',
|
||||
'Cookie': self.claude_api_cookie,
|
||||
'Sec-Fetch-Dest': 'empty',
|
||||
'Sec-Fetch-Mode': 'cors',
|
||||
'Sec-Fetch-Site': 'same-origin',
|
||||
'TE': 'trailers'
|
||||
}
|
||||
response = requests.post(url, headers=headers, data=payload, impersonate="chrome110", proxies=self.proxies, timeout=400)
|
||||
# Returns JSON of the newly created conversation information
|
||||
return response.json()
|
||||
|
||||
def _chat(self, query, context, retry_count=0) -> Reply:
|
||||
"""
|
||||
发起对话请求
|
||||
:param query: 请求提示词
|
||||
:param context: 对话上下文
|
||||
:param retry_count: 当前递归重试次数
|
||||
:return: 回复
|
||||
"""
|
||||
if retry_count >= 2:
|
||||
# exit from retry 2 times
|
||||
logger.warn("[CLAUDEAI] failed after maximum number of retry times")
|
||||
return Reply(ReplyType.ERROR, "请再问我一次吧")
|
||||
|
||||
try:
|
||||
session_id = context["session_id"]
|
||||
if self.org_uuid is None:
|
||||
return Reply(ReplyType.ERROR, self.error)
|
||||
|
||||
session = self.sessions.session_query(query, session_id)
|
||||
con_uuid = self.conversation_share_check(session_id)
|
||||
|
||||
model = conf().get("model") or "gpt-3.5-turbo"
|
||||
# remove system message
|
||||
if session.messages[0].get("role") == "system":
|
||||
if model == "wenxin" or model == "claude":
|
||||
session.messages.pop(0)
|
||||
logger.info(f"[CLAUDEAI] query={query}")
|
||||
|
||||
# do http request
|
||||
base_url = "https://claude.ai"
|
||||
payload = json.dumps({
|
||||
"completion": {
|
||||
"prompt": f"{query}",
|
||||
"timezone": "Asia/Kolkata",
|
||||
"model": "claude-2"
|
||||
},
|
||||
"organization_uuid": f"{self.org_uuid}",
|
||||
"conversation_uuid": f"{con_uuid}",
|
||||
"text": f"{query}",
|
||||
"attachments": []
|
||||
})
|
||||
headers = {
|
||||
'User-Agent':
|
||||
'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:109.0) Gecko/20100101 Firefox/115.0',
|
||||
'Accept': 'text/event-stream, text/event-stream',
|
||||
'Accept-Language': 'en-US,en;q=0.5',
|
||||
'Referer': 'https://claude.ai/chats',
|
||||
'Content-Type': 'application/json',
|
||||
'Origin': 'https://claude.ai',
|
||||
'DNT': '1',
|
||||
'Connection': 'keep-alive',
|
||||
'Cookie': f'{self.claude_api_cookie}',
|
||||
'Sec-Fetch-Dest': 'empty',
|
||||
'Sec-Fetch-Mode': 'cors',
|
||||
'Sec-Fetch-Site': 'same-origin',
|
||||
'TE': 'trailers'
|
||||
}
|
||||
|
||||
res = requests.post(base_url + "/api/append_message", headers=headers, data=payload,impersonate="chrome110",proxies= self.proxies,timeout=400)
|
||||
if res.status_code == 200 or "pemission" in res.text:
|
||||
# execute success
|
||||
decoded_data = res.content.decode("utf-8")
|
||||
decoded_data = re.sub('\n+', '\n', decoded_data).strip()
|
||||
data_strings = decoded_data.split('\n')
|
||||
completions = []
|
||||
for data_string in data_strings:
|
||||
json_str = data_string[6:].strip()
|
||||
data = json.loads(json_str)
|
||||
if 'completion' in data:
|
||||
completions.append(data['completion'])
|
||||
|
||||
reply_content = ''.join(completions)
|
||||
|
||||
if "rate limi" in reply_content:
|
||||
logger.error("rate limit error: The conversation has reached the system speed limit and is synchronized with Cladue. Please go to the official website to check the lifting time")
|
||||
return Reply(ReplyType.ERROR, "对话达到系统速率限制,与cladue同步,请进入官网查看解除限制时间")
|
||||
logger.info(f"[CLAUDE] reply={reply_content}, total_tokens=invisible")
|
||||
self.sessions.session_reply(reply_content, session_id, 100)
|
||||
return Reply(ReplyType.TEXT, reply_content)
|
||||
else:
|
||||
flag = self.check_cookie()
|
||||
if flag == None:
|
||||
return Reply(ReplyType.ERROR, self.error)
|
||||
|
||||
response = res.json()
|
||||
error = response.get("error")
|
||||
logger.error(f"[CLAUDE] chat failed, status_code={res.status_code}, "
|
||||
f"msg={error.get('message')}, type={error.get('type')}, detail: {res.text}, uuid: {con_uuid}")
|
||||
|
||||
if res.status_code >= 500:
|
||||
# server error, need retry
|
||||
time.sleep(2)
|
||||
logger.warn(f"[CLAUDE] do retry, times={retry_count}")
|
||||
return self._chat(query, context, retry_count + 1)
|
||||
return Reply(ReplyType.ERROR, "提问太快啦,请休息一下再问我吧")
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
# retry
|
||||
time.sleep(2)
|
||||
logger.warn(f"[CLAUDE] do retry, times={retry_count}")
|
||||
return self._chat(query, context, retry_count + 1)
|
||||
9
bot/claude/claude_ai_session.py
Normal file
9
bot/claude/claude_ai_session.py
Normal file
@@ -0,0 +1,9 @@
|
||||
from bot.session_manager import Session
|
||||
|
||||
|
||||
class ClaudeAiSession(Session):
|
||||
def __init__(self, session_id, system_prompt=None, model="claude"):
|
||||
super().__init__(session_id, system_prompt)
|
||||
self.model = model
|
||||
# claude逆向不支持role prompt
|
||||
# self.reset()
|
||||
135
bot/claudeapi/claude_api_bot.py
Normal file
135
bot/claudeapi/claude_api_bot.py
Normal file
@@ -0,0 +1,135 @@
|
||||
# encoding:utf-8
|
||||
|
||||
import time
|
||||
|
||||
import openai
|
||||
import openai.error
|
||||
import anthropic
|
||||
|
||||
from bot.bot import Bot
|
||||
from bot.openai.open_ai_image import OpenAIImage
|
||||
from bot.chatgpt.chat_gpt_session import ChatGPTSession
|
||||
from bot.gemini.google_gemini_bot import GoogleGeminiBot
|
||||
from bot.session_manager import SessionManager
|
||||
from bridge.context import ContextType
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from common.log import logger
|
||||
from config import conf
|
||||
|
||||
user_session = dict()
|
||||
|
||||
|
||||
# OpenAI对话模型API (可用)
|
||||
class ClaudeAPIBot(Bot, OpenAIImage):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.claudeClient = anthropic.Anthropic(
|
||||
api_key=conf().get("claude_api_key")
|
||||
)
|
||||
openai.api_key = conf().get("open_ai_api_key")
|
||||
if conf().get("open_ai_api_base"):
|
||||
openai.api_base = conf().get("open_ai_api_base")
|
||||
proxy = conf().get("proxy")
|
||||
if proxy:
|
||||
openai.proxy = proxy
|
||||
|
||||
self.sessions = SessionManager(ChatGPTSession, model=conf().get("model") or "text-davinci-003")
|
||||
|
||||
def reply(self, query, context=None):
|
||||
# acquire reply content
|
||||
if context and context.type:
|
||||
if context.type == ContextType.TEXT:
|
||||
logger.info("[CLAUDE_API] query={}".format(query))
|
||||
session_id = context["session_id"]
|
||||
reply = None
|
||||
if query == "#清除记忆":
|
||||
self.sessions.clear_session(session_id)
|
||||
reply = Reply(ReplyType.INFO, "记忆已清除")
|
||||
elif query == "#清除所有":
|
||||
self.sessions.clear_all_session()
|
||||
reply = Reply(ReplyType.INFO, "所有人记忆已清除")
|
||||
else:
|
||||
session = self.sessions.session_query(query, session_id)
|
||||
result = self.reply_text(session)
|
||||
logger.info(result)
|
||||
total_tokens, completion_tokens, reply_content = (
|
||||
result["total_tokens"],
|
||||
result["completion_tokens"],
|
||||
result["content"],
|
||||
)
|
||||
logger.debug(
|
||||
"[CLAUDE_API] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(str(session), session_id, reply_content, completion_tokens)
|
||||
)
|
||||
|
||||
if total_tokens == 0:
|
||||
reply = Reply(ReplyType.ERROR, reply_content)
|
||||
else:
|
||||
self.sessions.session_reply(reply_content, session_id, total_tokens)
|
||||
reply = Reply(ReplyType.TEXT, reply_content)
|
||||
return reply
|
||||
elif context.type == ContextType.IMAGE_CREATE:
|
||||
ok, retstring = self.create_img(query, 0)
|
||||
reply = None
|
||||
if ok:
|
||||
reply = Reply(ReplyType.IMAGE_URL, retstring)
|
||||
else:
|
||||
reply = Reply(ReplyType.ERROR, retstring)
|
||||
return reply
|
||||
|
||||
def reply_text(self, session: ChatGPTSession, retry_count=0):
|
||||
try:
|
||||
actual_model = self._model_mapping(conf().get("model"))
|
||||
response = self.claudeClient.messages.create(
|
||||
model=actual_model,
|
||||
max_tokens=1024,
|
||||
# system=conf().get("system"),
|
||||
messages=GoogleGeminiBot.filter_messages(session.messages)
|
||||
)
|
||||
# response = openai.Completion.create(prompt=str(session), **self.args)
|
||||
res_content = response.content[0].text.strip().replace("<|endoftext|>", "")
|
||||
total_tokens = response.usage.input_tokens+response.usage.output_tokens
|
||||
completion_tokens = response.usage.output_tokens
|
||||
logger.info("[CLAUDE_API] reply={}".format(res_content))
|
||||
return {
|
||||
"total_tokens": total_tokens,
|
||||
"completion_tokens": completion_tokens,
|
||||
"content": res_content,
|
||||
}
|
||||
except Exception as e:
|
||||
need_retry = retry_count < 2
|
||||
result = {"completion_tokens": 0, "content": "我现在有点累了,等会再来吧"}
|
||||
if isinstance(e, openai.error.RateLimitError):
|
||||
logger.warn("[CLAUDE_API] RateLimitError: {}".format(e))
|
||||
result["content"] = "提问太快啦,请休息一下再问我吧"
|
||||
if need_retry:
|
||||
time.sleep(20)
|
||||
elif isinstance(e, openai.error.Timeout):
|
||||
logger.warn("[CLAUDE_API] Timeout: {}".format(e))
|
||||
result["content"] = "我没有收到你的消息"
|
||||
if need_retry:
|
||||
time.sleep(5)
|
||||
elif isinstance(e, openai.error.APIConnectionError):
|
||||
logger.warn("[CLAUDE_API] APIConnectionError: {}".format(e))
|
||||
need_retry = False
|
||||
result["content"] = "我连接不到你的网络"
|
||||
else:
|
||||
logger.warn("[CLAUDE_API] Exception: {}".format(e))
|
||||
need_retry = False
|
||||
self.sessions.clear_session(session.session_id)
|
||||
|
||||
if need_retry:
|
||||
logger.warn("[CLAUDE_API] 第{}次重试".format(retry_count + 1))
|
||||
return self.reply_text(session, retry_count + 1)
|
||||
else:
|
||||
return result
|
||||
|
||||
def _model_mapping(self, model) -> str:
|
||||
if model == "claude-3-opus":
|
||||
return "claude-3-opus-20240229"
|
||||
elif model == "claude-3-sonnet":
|
||||
return "claude-3-sonnet-20240229"
|
||||
elif model == "claude-3-haiku":
|
||||
return "claude-3-haiku-20240307"
|
||||
elif model == "claude-3.5-sonnet":
|
||||
return "claude-3-5-sonnet-20240620"
|
||||
return model
|
||||
117
bot/dashscope/dashscope_bot.py
Normal file
117
bot/dashscope/dashscope_bot.py
Normal file
@@ -0,0 +1,117 @@
|
||||
# encoding:utf-8
|
||||
|
||||
from bot.bot import Bot
|
||||
from bot.session_manager import SessionManager
|
||||
from bridge.context import ContextType
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from common.log import logger
|
||||
from config import conf, load_config
|
||||
from .dashscope_session import DashscopeSession
|
||||
import os
|
||||
import dashscope
|
||||
from http import HTTPStatus
|
||||
|
||||
|
||||
|
||||
dashscope_models = {
|
||||
"qwen-turbo": dashscope.Generation.Models.qwen_turbo,
|
||||
"qwen-plus": dashscope.Generation.Models.qwen_plus,
|
||||
"qwen-max": dashscope.Generation.Models.qwen_max,
|
||||
"qwen-bailian-v1": dashscope.Generation.Models.bailian_v1
|
||||
}
|
||||
# ZhipuAI对话模型API
|
||||
class DashscopeBot(Bot):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.sessions = SessionManager(DashscopeSession, model=conf().get("model") or "qwen-plus")
|
||||
self.model_name = conf().get("model") or "qwen-plus"
|
||||
self.api_key = conf().get("dashscope_api_key")
|
||||
os.environ["DASHSCOPE_API_KEY"] = self.api_key
|
||||
self.client = dashscope.Generation
|
||||
|
||||
def reply(self, query, context=None):
|
||||
# acquire reply content
|
||||
if context.type == ContextType.TEXT:
|
||||
logger.info("[DASHSCOPE] query={}".format(query))
|
||||
|
||||
session_id = context["session_id"]
|
||||
reply = None
|
||||
clear_memory_commands = conf().get("clear_memory_commands", ["#清除记忆"])
|
||||
if query in clear_memory_commands:
|
||||
self.sessions.clear_session(session_id)
|
||||
reply = Reply(ReplyType.INFO, "记忆已清除")
|
||||
elif query == "#清除所有":
|
||||
self.sessions.clear_all_session()
|
||||
reply = Reply(ReplyType.INFO, "所有人记忆已清除")
|
||||
elif query == "#更新配置":
|
||||
load_config()
|
||||
reply = Reply(ReplyType.INFO, "配置已更新")
|
||||
if reply:
|
||||
return reply
|
||||
session = self.sessions.session_query(query, session_id)
|
||||
logger.debug("[DASHSCOPE] session query={}".format(session.messages))
|
||||
|
||||
reply_content = self.reply_text(session)
|
||||
logger.debug(
|
||||
"[DASHSCOPE] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(
|
||||
session.messages,
|
||||
session_id,
|
||||
reply_content["content"],
|
||||
reply_content["completion_tokens"],
|
||||
)
|
||||
)
|
||||
if reply_content["completion_tokens"] == 0 and len(reply_content["content"]) > 0:
|
||||
reply = Reply(ReplyType.ERROR, reply_content["content"])
|
||||
elif reply_content["completion_tokens"] > 0:
|
||||
self.sessions.session_reply(reply_content["content"], session_id, reply_content["total_tokens"])
|
||||
reply = Reply(ReplyType.TEXT, reply_content["content"])
|
||||
else:
|
||||
reply = Reply(ReplyType.ERROR, reply_content["content"])
|
||||
logger.debug("[DASHSCOPE] reply {} used 0 tokens.".format(reply_content))
|
||||
return reply
|
||||
else:
|
||||
reply = Reply(ReplyType.ERROR, "Bot不支持处理{}类型的消息".format(context.type))
|
||||
return reply
|
||||
|
||||
def reply_text(self, session: DashscopeSession, retry_count=0) -> dict:
|
||||
"""
|
||||
call openai's ChatCompletion to get the answer
|
||||
:param session: a conversation session
|
||||
:param session_id: session id
|
||||
:param retry_count: retry count
|
||||
:return: {}
|
||||
"""
|
||||
try:
|
||||
dashscope.api_key = self.api_key
|
||||
response = self.client.call(
|
||||
dashscope_models[self.model_name],
|
||||
messages=session.messages,
|
||||
result_format="message"
|
||||
)
|
||||
if response.status_code == HTTPStatus.OK:
|
||||
content = response.output.choices[0]["message"]["content"]
|
||||
return {
|
||||
"total_tokens": response.usage["total_tokens"],
|
||||
"completion_tokens": response.usage["output_tokens"],
|
||||
"content": content,
|
||||
}
|
||||
else:
|
||||
logger.error('Request id: %s, Status code: %s, error code: %s, error message: %s' % (
|
||||
response.request_id, response.status_code,
|
||||
response.code, response.message
|
||||
))
|
||||
result = {"completion_tokens": 0, "content": "我现在有点累了,等会再来吧"}
|
||||
need_retry = retry_count < 2
|
||||
result = {"completion_tokens": 0, "content": "我现在有点累了,等会再来吧"}
|
||||
if need_retry:
|
||||
return self.reply_text(session, retry_count + 1)
|
||||
else:
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
need_retry = retry_count < 2
|
||||
result = {"completion_tokens": 0, "content": "我现在有点累了,等会再来吧"}
|
||||
if need_retry:
|
||||
return self.reply_text(session, retry_count + 1)
|
||||
else:
|
||||
return result
|
||||
51
bot/dashscope/dashscope_session.py
Normal file
51
bot/dashscope/dashscope_session.py
Normal file
@@ -0,0 +1,51 @@
|
||||
from bot.session_manager import Session
|
||||
from common.log import logger
|
||||
|
||||
|
||||
class DashscopeSession(Session):
|
||||
def __init__(self, session_id, system_prompt=None, model="qwen-turbo"):
|
||||
super().__init__(session_id)
|
||||
self.reset()
|
||||
|
||||
def discard_exceeding(self, max_tokens, cur_tokens=None):
|
||||
precise = True
|
||||
try:
|
||||
cur_tokens = self.calc_tokens()
|
||||
except Exception as e:
|
||||
precise = False
|
||||
if cur_tokens is None:
|
||||
raise e
|
||||
logger.debug("Exception when counting tokens precisely for query: {}".format(e))
|
||||
while cur_tokens > max_tokens:
|
||||
if len(self.messages) > 2:
|
||||
self.messages.pop(1)
|
||||
elif len(self.messages) == 2 and self.messages[1]["role"] == "assistant":
|
||||
self.messages.pop(1)
|
||||
if precise:
|
||||
cur_tokens = self.calc_tokens()
|
||||
else:
|
||||
cur_tokens = cur_tokens - max_tokens
|
||||
break
|
||||
elif len(self.messages) == 2 and self.messages[1]["role"] == "user":
|
||||
logger.warn("user message exceed max_tokens. total_tokens={}".format(cur_tokens))
|
||||
break
|
||||
else:
|
||||
logger.debug("max_tokens={}, total_tokens={}, len(messages)={}".format(max_tokens, cur_tokens,
|
||||
len(self.messages)))
|
||||
break
|
||||
if precise:
|
||||
cur_tokens = self.calc_tokens()
|
||||
else:
|
||||
cur_tokens = cur_tokens - max_tokens
|
||||
return cur_tokens
|
||||
|
||||
def calc_tokens(self):
|
||||
return num_tokens_from_messages(self.messages)
|
||||
|
||||
|
||||
def num_tokens_from_messages(messages):
|
||||
# 只是大概,具体计算规则:https://help.aliyun.com/zh/dashscope/developer-reference/token-api?spm=a2c4g.11186623.0.0.4d8b12b0BkP3K9
|
||||
tokens = 0
|
||||
for msg in messages:
|
||||
tokens += len(msg["content"])
|
||||
return tokens
|
||||
81
bot/gemini/google_gemini_bot.py
Normal file
81
bot/gemini/google_gemini_bot.py
Normal file
@@ -0,0 +1,81 @@
|
||||
"""
|
||||
Google gemini bot
|
||||
|
||||
@author zhayujie
|
||||
@Date 2023/12/15
|
||||
"""
|
||||
# encoding:utf-8
|
||||
|
||||
from bot.bot import Bot
|
||||
import google.generativeai as genai
|
||||
from bot.session_manager import SessionManager
|
||||
from bridge.context import ContextType, Context
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from common.log import logger
|
||||
from config import conf
|
||||
from bot.baidu.baidu_wenxin_session import BaiduWenxinSession
|
||||
|
||||
|
||||
# OpenAI对话模型API (可用)
|
||||
class GoogleGeminiBot(Bot):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.api_key = conf().get("gemini_api_key")
|
||||
# 复用文心的token计算方式
|
||||
self.sessions = SessionManager(BaiduWenxinSession, model=conf().get("model") or "gpt-3.5-turbo")
|
||||
self.model = conf().get("model") or "gemini-pro"
|
||||
if self.model == "gemini":
|
||||
self.model = "gemini-pro"
|
||||
def reply(self, query, context: Context = None) -> Reply:
|
||||
try:
|
||||
if context.type != ContextType.TEXT:
|
||||
logger.warn(f"[Gemini] Unsupported message type, type={context.type}")
|
||||
return Reply(ReplyType.TEXT, None)
|
||||
logger.info(f"[Gemini] query={query}")
|
||||
session_id = context["session_id"]
|
||||
session = self.sessions.session_query(query, session_id)
|
||||
gemini_messages = self._convert_to_gemini_messages(self.filter_messages(session.messages))
|
||||
genai.configure(api_key=self.api_key)
|
||||
model = genai.GenerativeModel(self.model)
|
||||
response = model.generate_content(gemini_messages)
|
||||
reply_text = response.text
|
||||
self.sessions.session_reply(reply_text, session_id)
|
||||
logger.info(f"[Gemini] reply={reply_text}")
|
||||
return Reply(ReplyType.TEXT, reply_text)
|
||||
except Exception as e:
|
||||
logger.error("[Gemini] fetch reply error, may contain unsafe content")
|
||||
logger.error(e)
|
||||
return Reply(ReplyType.ERROR, "invoke [Gemini] api failed!")
|
||||
|
||||
def _convert_to_gemini_messages(self, messages: list):
|
||||
res = []
|
||||
for msg in messages:
|
||||
if msg.get("role") == "user":
|
||||
role = "user"
|
||||
elif msg.get("role") == "assistant":
|
||||
role = "model"
|
||||
else:
|
||||
continue
|
||||
res.append({
|
||||
"role": role,
|
||||
"parts": [{"text": msg.get("content")}]
|
||||
})
|
||||
return res
|
||||
|
||||
@staticmethod
|
||||
def filter_messages(messages: list):
|
||||
res = []
|
||||
turn = "user"
|
||||
if not messages:
|
||||
return res
|
||||
for i in range(len(messages) - 1, -1, -1):
|
||||
message = messages[i]
|
||||
if message.get("role") != turn:
|
||||
continue
|
||||
res.insert(0, message)
|
||||
if turn == "user":
|
||||
turn = "assistant"
|
||||
elif turn == "assistant":
|
||||
turn = "user"
|
||||
return res
|
||||
475
bot/linkai/link_ai_bot.py
Normal file
475
bot/linkai/link_ai_bot.py
Normal file
@@ -0,0 +1,475 @@
|
||||
# access LinkAI knowledge base platform
|
||||
# docs: https://link-ai.tech/platform/link-app/wechat
|
||||
|
||||
import re
|
||||
import time
|
||||
import requests
|
||||
import config
|
||||
from bot.bot import Bot
|
||||
from bot.chatgpt.chat_gpt_session import ChatGPTSession
|
||||
from bot.session_manager import SessionManager
|
||||
from bridge.context import Context, ContextType
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from common.log import logger
|
||||
from config import conf, pconf
|
||||
import threading
|
||||
from common import memory, utils
|
||||
import base64
|
||||
import os
|
||||
|
||||
class LinkAIBot(Bot):
|
||||
# authentication failed
|
||||
AUTH_FAILED_CODE = 401
|
||||
NO_QUOTA_CODE = 406
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.sessions = LinkAISessionManager(LinkAISession, model=conf().get("model") or "gpt-3.5-turbo")
|
||||
self.args = {}
|
||||
|
||||
def reply(self, query, context: Context = None) -> Reply:
|
||||
if context.type == ContextType.TEXT:
|
||||
return self._chat(query, context)
|
||||
elif context.type == ContextType.IMAGE_CREATE:
|
||||
if not conf().get("text_to_image"):
|
||||
logger.warn("[LinkAI] text_to_image is not enabled, ignore the IMAGE_CREATE request")
|
||||
return Reply(ReplyType.TEXT, "")
|
||||
ok, res = self.create_img(query, 0)
|
||||
if ok:
|
||||
reply = Reply(ReplyType.IMAGE_URL, res)
|
||||
else:
|
||||
reply = Reply(ReplyType.ERROR, res)
|
||||
return reply
|
||||
else:
|
||||
reply = Reply(ReplyType.ERROR, "Bot不支持处理{}类型的消息".format(context.type))
|
||||
return reply
|
||||
|
||||
def _chat(self, query, context, retry_count=0) -> Reply:
|
||||
"""
|
||||
发起对话请求
|
||||
:param query: 请求提示词
|
||||
:param context: 对话上下文
|
||||
:param retry_count: 当前递归重试次数
|
||||
:return: 回复
|
||||
"""
|
||||
if retry_count > 2:
|
||||
# exit from retry 2 times
|
||||
logger.warn("[LINKAI] failed after maximum number of retry times")
|
||||
return Reply(ReplyType.TEXT, "请再问我一次吧")
|
||||
|
||||
try:
|
||||
# load config
|
||||
if context.get("generate_breaked_by"):
|
||||
logger.info(f"[LINKAI] won't set appcode because a plugin ({context['generate_breaked_by']}) affected the context")
|
||||
app_code = None
|
||||
else:
|
||||
plugin_app_code = self._find_group_mapping_code(context)
|
||||
app_code = context.kwargs.get("app_code") or plugin_app_code or conf().get("linkai_app_code")
|
||||
linkai_api_key = conf().get("linkai_api_key")
|
||||
|
||||
session_id = context["session_id"]
|
||||
session_message = self.sessions.session_msg_query(query, session_id)
|
||||
logger.debug(f"[LinkAI] session={session_message}, session_id={session_id}")
|
||||
|
||||
# image process
|
||||
img_cache = memory.USER_IMAGE_CACHE.get(session_id)
|
||||
if img_cache:
|
||||
messages = self._process_image_msg(app_code=app_code, session_id=session_id, query=query, img_cache=img_cache)
|
||||
if messages:
|
||||
session_message = messages
|
||||
|
||||
model = conf().get("model")
|
||||
# remove system message
|
||||
if session_message[0].get("role") == "system":
|
||||
if app_code or model == "wenxin":
|
||||
session_message.pop(0)
|
||||
body = {
|
||||
"app_code": app_code,
|
||||
"messages": session_message,
|
||||
"model": model, # 对话模型的名称, 支持 gpt-3.5-turbo, gpt-3.5-turbo-16k, gpt-4, wenxin, xunfei
|
||||
"temperature": conf().get("temperature"),
|
||||
"top_p": conf().get("top_p", 1),
|
||||
"frequency_penalty": conf().get("frequency_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
|
||||
"presence_penalty": conf().get("presence_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
|
||||
"session_id": session_id,
|
||||
"sender_id": session_id,
|
||||
"channel_type": conf().get("channel_type", "wx")
|
||||
}
|
||||
try:
|
||||
from linkai import LinkAIClient
|
||||
client_id = LinkAIClient.fetch_client_id()
|
||||
if client_id:
|
||||
body["client_id"] = client_id
|
||||
# start: client info deliver
|
||||
if context.kwargs.get("msg"):
|
||||
body["session_id"] = context.kwargs.get("msg").from_user_id
|
||||
if context.kwargs.get("msg").is_group:
|
||||
body["is_group"] = True
|
||||
body["group_name"] = context.kwargs.get("msg").from_user_nickname
|
||||
body["sender_name"] = context.kwargs.get("msg").actual_user_nickname
|
||||
else:
|
||||
if body.get("channel_type") in ["wechatcom_app"]:
|
||||
body["sender_name"] = context.kwargs.get("msg").from_user_id
|
||||
else:
|
||||
body["sender_name"] = context.kwargs.get("msg").from_user_nickname
|
||||
|
||||
except Exception as e:
|
||||
pass
|
||||
file_id = context.kwargs.get("file_id")
|
||||
if file_id:
|
||||
body["file_id"] = file_id
|
||||
logger.info(f"[LINKAI] query={query}, app_code={app_code}, model={body.get('model')}, file_id={file_id}")
|
||||
headers = {"Authorization": "Bearer " + linkai_api_key}
|
||||
|
||||
# do http request
|
||||
base_url = conf().get("linkai_api_base", "https://api.link-ai.tech")
|
||||
res = requests.post(url=base_url + "/v1/chat/completions", json=body, headers=headers,
|
||||
timeout=conf().get("request_timeout", 180))
|
||||
if res.status_code == 200:
|
||||
# execute success
|
||||
response = res.json()
|
||||
reply_content = response["choices"][0]["message"]["content"]
|
||||
total_tokens = response["usage"]["total_tokens"]
|
||||
res_code = response.get('code')
|
||||
logger.info(f"[LINKAI] reply={reply_content}, total_tokens={total_tokens}, res_code={res_code}")
|
||||
if res_code == 429:
|
||||
logger.warn(f"[LINKAI] 用户访问超出限流配置,sender_id={body.get('sender_id')}")
|
||||
else:
|
||||
self.sessions.session_reply(reply_content, session_id, total_tokens, query=query)
|
||||
agent_suffix = self._fetch_agent_suffix(response)
|
||||
if agent_suffix:
|
||||
reply_content += agent_suffix
|
||||
if not agent_suffix:
|
||||
knowledge_suffix = self._fetch_knowledge_search_suffix(response)
|
||||
if knowledge_suffix:
|
||||
reply_content += knowledge_suffix
|
||||
# image process
|
||||
if response["choices"][0].get("img_urls"):
|
||||
thread = threading.Thread(target=self._send_image, args=(context.get("channel"), context, response["choices"][0].get("img_urls")))
|
||||
thread.start()
|
||||
if response["choices"][0].get("text_content"):
|
||||
reply_content = response["choices"][0].get("text_content")
|
||||
reply_content = self._process_url(reply_content)
|
||||
return Reply(ReplyType.TEXT, reply_content)
|
||||
|
||||
else:
|
||||
response = res.json()
|
||||
error = response.get("error")
|
||||
logger.error(f"[LINKAI] chat failed, status_code={res.status_code}, "
|
||||
f"msg={error.get('message')}, type={error.get('type')}")
|
||||
|
||||
if res.status_code >= 500:
|
||||
# server error, need retry
|
||||
time.sleep(2)
|
||||
logger.warn(f"[LINKAI] do retry, times={retry_count}")
|
||||
return self._chat(query, context, retry_count + 1)
|
||||
|
||||
error_reply = "提问太快啦,请休息一下再问我吧"
|
||||
if res.status_code == 409:
|
||||
error_reply = "这个问题我还没有学会,请问我其它问题吧"
|
||||
return Reply(ReplyType.TEXT, error_reply)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
# retry
|
||||
time.sleep(2)
|
||||
logger.warn(f"[LINKAI] do retry, times={retry_count}")
|
||||
return self._chat(query, context, retry_count + 1)
|
||||
|
||||
def _process_image_msg(self, app_code: str, session_id: str, query:str, img_cache: dict):
|
||||
try:
|
||||
enable_image_input = False
|
||||
app_info = self._fetch_app_info(app_code)
|
||||
if not app_info:
|
||||
logger.debug(f"[LinkAI] not found app, can't process images, app_code={app_code}")
|
||||
return None
|
||||
plugins = app_info.get("data").get("plugins")
|
||||
for plugin in plugins:
|
||||
if plugin.get("input_type") and "IMAGE" in plugin.get("input_type"):
|
||||
enable_image_input = True
|
||||
if not enable_image_input:
|
||||
return
|
||||
msg = img_cache.get("msg")
|
||||
path = img_cache.get("path")
|
||||
msg.prepare()
|
||||
logger.info(f"[LinkAI] query with images, path={path}")
|
||||
messages = self._build_vision_msg(query, path)
|
||||
memory.USER_IMAGE_CACHE[session_id] = None
|
||||
return messages
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
|
||||
def _find_group_mapping_code(self, context):
|
||||
try:
|
||||
if context.kwargs.get("isgroup"):
|
||||
group_name = context.kwargs.get("msg").from_user_nickname
|
||||
if config.plugin_config and config.plugin_config.get("linkai"):
|
||||
linkai_config = config.plugin_config.get("linkai")
|
||||
group_mapping = linkai_config.get("group_app_map")
|
||||
if group_mapping and group_name:
|
||||
return group_mapping.get(group_name)
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
return None
|
||||
|
||||
def _build_vision_msg(self, query: str, path: str):
|
||||
try:
|
||||
suffix = utils.get_path_suffix(path)
|
||||
with open(path, "rb") as file:
|
||||
base64_str = base64.b64encode(file.read()).decode('utf-8')
|
||||
messages = [{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": query
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/{suffix};base64,{base64_str}"
|
||||
}
|
||||
}
|
||||
]
|
||||
}]
|
||||
return messages
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
|
||||
def reply_text(self, session: ChatGPTSession, app_code="", retry_count=0) -> dict:
|
||||
if retry_count >= 2:
|
||||
# exit from retry 2 times
|
||||
logger.warn("[LINKAI] failed after maximum number of retry times")
|
||||
return {
|
||||
"total_tokens": 0,
|
||||
"completion_tokens": 0,
|
||||
"content": "请再问我一次吧"
|
||||
}
|
||||
|
||||
try:
|
||||
body = {
|
||||
"app_code": app_code,
|
||||
"messages": session.messages,
|
||||
"model": conf().get("model") or "gpt-3.5-turbo", # 对话模型的名称, 支持 gpt-3.5-turbo, gpt-3.5-turbo-16k, gpt-4, wenxin, xunfei
|
||||
"temperature": conf().get("temperature"),
|
||||
"top_p": conf().get("top_p", 1),
|
||||
"frequency_penalty": conf().get("frequency_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
|
||||
"presence_penalty": conf().get("presence_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
|
||||
}
|
||||
if self.args.get("max_tokens"):
|
||||
body["max_tokens"] = self.args.get("max_tokens")
|
||||
headers = {"Authorization": "Bearer " + conf().get("linkai_api_key")}
|
||||
|
||||
# do http request
|
||||
base_url = conf().get("linkai_api_base", "https://api.link-ai.tech")
|
||||
res = requests.post(url=base_url + "/v1/chat/completions", json=body, headers=headers,
|
||||
timeout=conf().get("request_timeout", 180))
|
||||
if res.status_code == 200:
|
||||
# execute success
|
||||
response = res.json()
|
||||
reply_content = response["choices"][0]["message"]["content"]
|
||||
total_tokens = response["usage"]["total_tokens"]
|
||||
logger.info(f"[LINKAI] reply={reply_content}, total_tokens={total_tokens}")
|
||||
return {
|
||||
"total_tokens": total_tokens,
|
||||
"completion_tokens": response["usage"]["completion_tokens"],
|
||||
"content": reply_content,
|
||||
}
|
||||
|
||||
else:
|
||||
response = res.json()
|
||||
error = response.get("error")
|
||||
logger.error(f"[LINKAI] chat failed, status_code={res.status_code}, "
|
||||
f"msg={error.get('message')}, type={error.get('type')}")
|
||||
|
||||
if res.status_code >= 500:
|
||||
# server error, need retry
|
||||
time.sleep(2)
|
||||
logger.warn(f"[LINKAI] do retry, times={retry_count}")
|
||||
return self.reply_text(session, app_code, retry_count + 1)
|
||||
|
||||
return {
|
||||
"total_tokens": 0,
|
||||
"completion_tokens": 0,
|
||||
"content": "提问太快啦,请休息一下再问我吧"
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
# retry
|
||||
time.sleep(2)
|
||||
logger.warn(f"[LINKAI] do retry, times={retry_count}")
|
||||
return self.reply_text(session, app_code, retry_count + 1)
|
||||
|
||||
def _fetch_app_info(self, app_code: str):
|
||||
headers = {"Authorization": "Bearer " + conf().get("linkai_api_key")}
|
||||
# do http request
|
||||
base_url = conf().get("linkai_api_base", "https://api.link-ai.tech")
|
||||
params = {"app_code": app_code}
|
||||
res = requests.get(url=base_url + "/v1/app/info", params=params, headers=headers, timeout=(5, 10))
|
||||
if res.status_code == 200:
|
||||
return res.json()
|
||||
else:
|
||||
logger.warning(f"[LinkAI] find app info exception, res={res}")
|
||||
|
||||
def create_img(self, query, retry_count=0, api_key=None):
|
||||
try:
|
||||
logger.info("[LinkImage] image_query={}".format(query))
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {conf().get('linkai_api_key')}"
|
||||
}
|
||||
data = {
|
||||
"prompt": query,
|
||||
"n": 1,
|
||||
"model": conf().get("text_to_image") or "dall-e-2",
|
||||
"response_format": "url",
|
||||
"img_proxy": conf().get("image_proxy")
|
||||
}
|
||||
url = conf().get("linkai_api_base", "https://api.link-ai.tech") + "/v1/images/generations"
|
||||
res = requests.post(url, headers=headers, json=data, timeout=(5, 90))
|
||||
t2 = time.time()
|
||||
image_url = res.json()["data"][0]["url"]
|
||||
logger.info("[OPEN_AI] image_url={}".format(image_url))
|
||||
return True, image_url
|
||||
|
||||
except Exception as e:
|
||||
logger.error(format(e))
|
||||
return False, "画图出现问题,请休息一下再问我吧"
|
||||
|
||||
|
||||
def _fetch_knowledge_search_suffix(self, response) -> str:
|
||||
try:
|
||||
if response.get("knowledge_base"):
|
||||
search_hit = response.get("knowledge_base").get("search_hit")
|
||||
first_similarity = response.get("knowledge_base").get("first_similarity")
|
||||
logger.info(f"[LINKAI] knowledge base, search_hit={search_hit}, first_similarity={first_similarity}")
|
||||
plugin_config = pconf("linkai")
|
||||
if plugin_config and plugin_config.get("knowledge_base") and plugin_config.get("knowledge_base").get("search_miss_text_enabled"):
|
||||
search_miss_similarity = plugin_config.get("knowledge_base").get("search_miss_similarity")
|
||||
search_miss_text = plugin_config.get("knowledge_base").get("search_miss_suffix")
|
||||
if not search_hit:
|
||||
return search_miss_text
|
||||
if search_miss_similarity and float(search_miss_similarity) > first_similarity:
|
||||
return search_miss_text
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
|
||||
|
||||
def _fetch_agent_suffix(self, response):
|
||||
try:
|
||||
plugin_list = []
|
||||
logger.debug(f"[LinkAgent] res={response}")
|
||||
if response.get("agent") and response.get("agent").get("chain") and response.get("agent").get("need_show_plugin"):
|
||||
chain = response.get("agent").get("chain")
|
||||
suffix = "\n\n- - - - - - - - - - - -"
|
||||
i = 0
|
||||
for turn in chain:
|
||||
plugin_name = turn.get('plugin_name')
|
||||
suffix += "\n"
|
||||
need_show_thought = response.get("agent").get("need_show_thought")
|
||||
if turn.get("thought") and plugin_name and need_show_thought:
|
||||
suffix += f"{turn.get('thought')}\n"
|
||||
if plugin_name:
|
||||
plugin_list.append(turn.get('plugin_name'))
|
||||
if turn.get('plugin_icon'):
|
||||
suffix += f"{turn.get('plugin_icon')} "
|
||||
suffix += f"{turn.get('plugin_name')}"
|
||||
if turn.get('plugin_input'):
|
||||
suffix += f":{turn.get('plugin_input')}"
|
||||
if i < len(chain) - 1:
|
||||
suffix += "\n"
|
||||
i += 1
|
||||
logger.info(f"[LinkAgent] use plugins: {plugin_list}")
|
||||
return suffix
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
|
||||
def _process_url(self, text):
|
||||
try:
|
||||
url_pattern = re.compile(r'\[(.*?)\]\((http[s]?://.*?)\)')
|
||||
def replace_markdown_url(match):
|
||||
return f"{match.group(2)}"
|
||||
return url_pattern.sub(replace_markdown_url, text)
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
|
||||
def _send_image(self, channel, context, image_urls):
|
||||
if not image_urls:
|
||||
return
|
||||
max_send_num = conf().get("max_media_send_count")
|
||||
send_interval = conf().get("media_send_interval")
|
||||
file_type = (".pdf", ".doc", ".docx", ".csv", ".xls", ".xlsx", ".txt", ".rtf", ".ppt", ".pptx")
|
||||
try:
|
||||
i = 0
|
||||
for url in image_urls:
|
||||
if max_send_num and i >= max_send_num:
|
||||
continue
|
||||
i += 1
|
||||
if url.endswith(".mp4"):
|
||||
reply_type = ReplyType.VIDEO_URL
|
||||
elif url.endswith(file_type):
|
||||
reply_type = ReplyType.FILE
|
||||
url = _download_file(url)
|
||||
if not url:
|
||||
continue
|
||||
else:
|
||||
reply_type = ReplyType.IMAGE_URL
|
||||
reply = Reply(reply_type, url)
|
||||
channel.send(reply, context)
|
||||
if send_interval:
|
||||
time.sleep(send_interval)
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
|
||||
|
||||
def _download_file(url: str):
|
||||
try:
|
||||
file_path = "tmp"
|
||||
if not os.path.exists(file_path):
|
||||
os.makedirs(file_path)
|
||||
file_name = url.split("/")[-1] # 获取文件名
|
||||
file_path = os.path.join(file_path, file_name)
|
||||
response = requests.get(url)
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(response.content)
|
||||
return file_path
|
||||
except Exception as e:
|
||||
logger.warn(e)
|
||||
|
||||
|
||||
class LinkAISessionManager(SessionManager):
|
||||
def session_msg_query(self, query, session_id):
|
||||
session = self.build_session(session_id)
|
||||
messages = session.messages + [{"role": "user", "content": query}]
|
||||
return messages
|
||||
|
||||
def session_reply(self, reply, session_id, total_tokens=None, query=None):
|
||||
session = self.build_session(session_id)
|
||||
if query:
|
||||
session.add_query(query)
|
||||
session.add_reply(reply)
|
||||
try:
|
||||
max_tokens = conf().get("conversation_max_tokens", 2500)
|
||||
tokens_cnt = session.discard_exceeding(max_tokens, total_tokens)
|
||||
logger.debug(f"[LinkAI] chat history, before tokens={total_tokens}, now tokens={tokens_cnt}")
|
||||
except Exception as e:
|
||||
logger.warning("Exception when counting tokens precisely for session: {}".format(str(e)))
|
||||
return session
|
||||
|
||||
|
||||
class LinkAISession(ChatGPTSession):
|
||||
def calc_tokens(self):
|
||||
if not self.messages:
|
||||
return 0
|
||||
return len(str(self.messages))
|
||||
|
||||
def discard_exceeding(self, max_tokens, cur_tokens=None):
|
||||
cur_tokens = self.calc_tokens()
|
||||
if cur_tokens > max_tokens:
|
||||
for i in range(0, len(self.messages)):
|
||||
if i > 0 and self.messages[i].get("role") == "assistant" and self.messages[i - 1].get("role") == "user":
|
||||
self.messages.pop(i)
|
||||
self.messages.pop(i - 1)
|
||||
return self.calc_tokens()
|
||||
return cur_tokens
|
||||
151
bot/minimax/minimax_bot.py
Normal file
151
bot/minimax/minimax_bot.py
Normal file
@@ -0,0 +1,151 @@
|
||||
# encoding:utf-8
|
||||
|
||||
import time
|
||||
|
||||
import openai
|
||||
import openai.error
|
||||
from bot.bot import Bot
|
||||
from bot.minimax.minimax_session import MinimaxSession
|
||||
from bot.session_manager import SessionManager
|
||||
from bridge.context import Context, ContextType
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from common.log import logger
|
||||
from config import conf, load_config
|
||||
from bot.chatgpt.chat_gpt_session import ChatGPTSession
|
||||
import requests
|
||||
from common import const
|
||||
|
||||
|
||||
# ZhipuAI对话模型API
|
||||
class MinimaxBot(Bot):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.args = {
|
||||
"model": conf().get("model") or "abab6.5", # 对话模型的名称
|
||||
"temperature": conf().get("temperature", 0.3), # 如果设置,值域须为 [0, 1] 我们推荐 0.3,以达到较合适的效果。
|
||||
"top_p": conf().get("top_p", 0.95), # 使用默认值
|
||||
}
|
||||
self.api_key = conf().get("Minimax_api_key")
|
||||
self.group_id = conf().get("Minimax_group_id")
|
||||
self.base_url = conf().get("Minimax_base_url", f"https://api.minimax.chat/v1/text/chatcompletion_pro?GroupId={self.group_id}")
|
||||
# tokens_to_generate/bot_setting/reply_constraints可自行修改
|
||||
self.request_body = {
|
||||
"model": self.args["model"],
|
||||
"tokens_to_generate": 2048,
|
||||
"reply_constraints": {"sender_type": "BOT", "sender_name": "MM智能助理"},
|
||||
"messages": [],
|
||||
"bot_setting": [
|
||||
{
|
||||
"bot_name": "MM智能助理",
|
||||
"content": "MM智能助理是一款由MiniMax自研的,没有调用其他产品的接口的大型语言模型。MiniMax是一家中国科技公司,一直致力于进行大模型相关的研究。",
|
||||
}
|
||||
],
|
||||
}
|
||||
self.sessions = SessionManager(MinimaxSession, model=const.MiniMax)
|
||||
|
||||
def reply(self, query, context: Context = None) -> Reply:
|
||||
# acquire reply content
|
||||
logger.info("[Minimax_AI] query={}".format(query))
|
||||
if context.type == ContextType.TEXT:
|
||||
session_id = context["session_id"]
|
||||
reply = None
|
||||
clear_memory_commands = conf().get("clear_memory_commands", ["#清除记忆"])
|
||||
if query in clear_memory_commands:
|
||||
self.sessions.clear_session(session_id)
|
||||
reply = Reply(ReplyType.INFO, "记忆已清除")
|
||||
elif query == "#清除所有":
|
||||
self.sessions.clear_all_session()
|
||||
reply = Reply(ReplyType.INFO, "所有人记忆已清除")
|
||||
elif query == "#更新配置":
|
||||
load_config()
|
||||
reply = Reply(ReplyType.INFO, "配置已更新")
|
||||
if reply:
|
||||
return reply
|
||||
session = self.sessions.session_query(query, session_id)
|
||||
logger.debug("[Minimax_AI] session query={}".format(session))
|
||||
|
||||
model = context.get("Minimax_model")
|
||||
new_args = self.args.copy()
|
||||
if model:
|
||||
new_args["model"] = model
|
||||
# if context.get('stream'):
|
||||
# # reply in stream
|
||||
# return self.reply_text_stream(query, new_query, session_id)
|
||||
|
||||
reply_content = self.reply_text(session, args=new_args)
|
||||
logger.debug(
|
||||
"[Minimax_AI] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(
|
||||
session.messages,
|
||||
session_id,
|
||||
reply_content["content"],
|
||||
reply_content["completion_tokens"],
|
||||
)
|
||||
)
|
||||
if reply_content["completion_tokens"] == 0 and len(reply_content["content"]) > 0:
|
||||
reply = Reply(ReplyType.ERROR, reply_content["content"])
|
||||
elif reply_content["completion_tokens"] > 0:
|
||||
self.sessions.session_reply(reply_content["content"], session_id, reply_content["total_tokens"])
|
||||
reply = Reply(ReplyType.TEXT, reply_content["content"])
|
||||
else:
|
||||
reply = Reply(ReplyType.ERROR, reply_content["content"])
|
||||
logger.debug("[Minimax_AI] reply {} used 0 tokens.".format(reply_content))
|
||||
return reply
|
||||
else:
|
||||
reply = Reply(ReplyType.ERROR, "Bot不支持处理{}类型的消息".format(context.type))
|
||||
return reply
|
||||
|
||||
def reply_text(self, session: MinimaxSession, args=None, retry_count=0) -> dict:
|
||||
"""
|
||||
call openai's ChatCompletion to get the answer
|
||||
:param session: a conversation session
|
||||
:param session_id: session id
|
||||
:param retry_count: retry count
|
||||
:return: {}
|
||||
"""
|
||||
try:
|
||||
headers = {"Content-Type": "application/json", "Authorization": "Bearer " + self.api_key}
|
||||
self.request_body["messages"].extend(session.messages)
|
||||
logger.info("[Minimax_AI] request_body={}".format(self.request_body))
|
||||
# logger.info("[Minimax_AI] reply={}, total_tokens={}".format(response.choices[0]['message']['content'], response["usage"]["total_tokens"]))
|
||||
res = requests.post(self.base_url, headers=headers, json=self.request_body)
|
||||
|
||||
# self.request_body["messages"].extend(response.json()["choices"][0]["messages"])
|
||||
if res.status_code == 200:
|
||||
response = res.json()
|
||||
return {
|
||||
"total_tokens": response["usage"]["total_tokens"],
|
||||
"completion_tokens": response["usage"]["total_tokens"],
|
||||
"content": response["reply"],
|
||||
}
|
||||
else:
|
||||
response = res.json()
|
||||
error = response.get("error")
|
||||
logger.error(f"[Minimax_AI] chat failed, status_code={res.status_code}, " f"msg={error.get('message')}, type={error.get('type')}")
|
||||
|
||||
result = {"completion_tokens": 0, "content": "提问太快啦,请休息一下再问我吧"}
|
||||
need_retry = False
|
||||
if res.status_code >= 500:
|
||||
# server error, need retry
|
||||
logger.warn(f"[Minimax_AI] do retry, times={retry_count}")
|
||||
need_retry = retry_count < 2
|
||||
elif res.status_code == 401:
|
||||
result["content"] = "授权失败,请检查API Key是否正确"
|
||||
elif res.status_code == 429:
|
||||
result["content"] = "请求过于频繁,请稍后再试"
|
||||
need_retry = retry_count < 2
|
||||
else:
|
||||
need_retry = False
|
||||
|
||||
if need_retry:
|
||||
time.sleep(3)
|
||||
return self.reply_text(session, args, retry_count + 1)
|
||||
else:
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
need_retry = retry_count < 2
|
||||
result = {"completion_tokens": 0, "content": "我现在有点累了,等会再来吧"}
|
||||
if need_retry:
|
||||
return self.reply_text(session, args, retry_count + 1)
|
||||
else:
|
||||
return result
|
||||
72
bot/minimax/minimax_session.py
Normal file
72
bot/minimax/minimax_session.py
Normal file
@@ -0,0 +1,72 @@
|
||||
from bot.session_manager import Session
|
||||
from common.log import logger
|
||||
|
||||
"""
|
||||
e.g.
|
||||
[
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "Who won the world series in 2020?"},
|
||||
{"role": "assistant", "content": "The Los Angeles Dodgers won the World Series in 2020."},
|
||||
{"role": "user", "content": "Where was it played?"}
|
||||
]
|
||||
"""
|
||||
|
||||
|
||||
class MinimaxSession(Session):
|
||||
def __init__(self, session_id, system_prompt=None, model="minimax"):
|
||||
super().__init__(session_id, system_prompt)
|
||||
self.model = model
|
||||
# self.reset()
|
||||
|
||||
def add_query(self, query):
|
||||
user_item = {"sender_type": "USER", "sender_name": self.session_id, "text": query}
|
||||
self.messages.append(user_item)
|
||||
|
||||
def add_reply(self, reply):
|
||||
assistant_item = {"sender_type": "BOT", "sender_name": "MM智能助理", "text": reply}
|
||||
self.messages.append(assistant_item)
|
||||
|
||||
def discard_exceeding(self, max_tokens, cur_tokens=None):
|
||||
precise = True
|
||||
try:
|
||||
cur_tokens = self.calc_tokens()
|
||||
except Exception as e:
|
||||
precise = False
|
||||
if cur_tokens is None:
|
||||
raise e
|
||||
logger.debug("Exception when counting tokens precisely for query: {}".format(e))
|
||||
while cur_tokens > max_tokens:
|
||||
if len(self.messages) > 2:
|
||||
self.messages.pop(1)
|
||||
elif len(self.messages) == 2 and self.messages[1]["sender_type"] == "BOT":
|
||||
self.messages.pop(1)
|
||||
if precise:
|
||||
cur_tokens = self.calc_tokens()
|
||||
else:
|
||||
cur_tokens = cur_tokens - max_tokens
|
||||
break
|
||||
elif len(self.messages) == 2 and self.messages[1]["sender_type"] == "USER":
|
||||
logger.warn("user message exceed max_tokens. total_tokens={}".format(cur_tokens))
|
||||
break
|
||||
else:
|
||||
logger.debug("max_tokens={}, total_tokens={}, len(messages)={}".format(max_tokens, cur_tokens, len(self.messages)))
|
||||
break
|
||||
if precise:
|
||||
cur_tokens = self.calc_tokens()
|
||||
else:
|
||||
cur_tokens = cur_tokens - max_tokens
|
||||
return cur_tokens
|
||||
|
||||
def calc_tokens(self):
|
||||
return num_tokens_from_messages(self.messages, self.model)
|
||||
|
||||
|
||||
def num_tokens_from_messages(messages, model):
|
||||
"""Returns the number of tokens used by a list of messages."""
|
||||
# 官方token计算规则:"对于中文文本来说,1个token通常对应一个汉字;对于英文文本来说,1个token通常对应3至4个字母或1个单词"
|
||||
# 详情请产看文档:https://help.aliyun.com/document_detail/2586397.html
|
||||
# 目前根据字符串长度粗略估计token数,不影响正常使用
|
||||
tokens = 0
|
||||
for msg in messages:
|
||||
tokens += len(msg["text"])
|
||||
return tokens
|
||||
143
bot/moonshot/moonshot_bot.py
Normal file
143
bot/moonshot/moonshot_bot.py
Normal file
@@ -0,0 +1,143 @@
|
||||
# encoding:utf-8
|
||||
|
||||
import time
|
||||
|
||||
import openai
|
||||
import openai.error
|
||||
from bot.bot import Bot
|
||||
from bot.session_manager import SessionManager
|
||||
from bridge.context import ContextType
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from common.log import logger
|
||||
from config import conf, load_config
|
||||
from .moonshot_session import MoonshotSession
|
||||
import requests
|
||||
|
||||
|
||||
# ZhipuAI对话模型API
|
||||
class MoonshotBot(Bot):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.sessions = SessionManager(MoonshotSession, model=conf().get("model") or "moonshot-v1-128k")
|
||||
self.args = {
|
||||
"model": conf().get("model") or "moonshot-v1-128k", # 对话模型的名称
|
||||
"temperature": conf().get("temperature", 0.3), # 如果设置,值域须为 [0, 1] 我们推荐 0.3,以达到较合适的效果。
|
||||
"top_p": conf().get("top_p", 1.0), # 使用默认值
|
||||
}
|
||||
self.api_key = conf().get("moonshot_api_key")
|
||||
self.base_url = conf().get("moonshot_base_url", "https://api.moonshot.cn/v1/chat/completions")
|
||||
|
||||
def reply(self, query, context=None):
|
||||
# acquire reply content
|
||||
if context.type == ContextType.TEXT:
|
||||
logger.info("[MOONSHOT_AI] query={}".format(query))
|
||||
|
||||
session_id = context["session_id"]
|
||||
reply = None
|
||||
clear_memory_commands = conf().get("clear_memory_commands", ["#清除记忆"])
|
||||
if query in clear_memory_commands:
|
||||
self.sessions.clear_session(session_id)
|
||||
reply = Reply(ReplyType.INFO, "记忆已清除")
|
||||
elif query == "#清除所有":
|
||||
self.sessions.clear_all_session()
|
||||
reply = Reply(ReplyType.INFO, "所有人记忆已清除")
|
||||
elif query == "#更新配置":
|
||||
load_config()
|
||||
reply = Reply(ReplyType.INFO, "配置已更新")
|
||||
if reply:
|
||||
return reply
|
||||
session = self.sessions.session_query(query, session_id)
|
||||
logger.debug("[MOONSHOT_AI] session query={}".format(session.messages))
|
||||
|
||||
model = context.get("moonshot_model")
|
||||
new_args = self.args.copy()
|
||||
if model:
|
||||
new_args["model"] = model
|
||||
# if context.get('stream'):
|
||||
# # reply in stream
|
||||
# return self.reply_text_stream(query, new_query, session_id)
|
||||
|
||||
reply_content = self.reply_text(session, args=new_args)
|
||||
logger.debug(
|
||||
"[MOONSHOT_AI] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(
|
||||
session.messages,
|
||||
session_id,
|
||||
reply_content["content"],
|
||||
reply_content["completion_tokens"],
|
||||
)
|
||||
)
|
||||
if reply_content["completion_tokens"] == 0 and len(reply_content["content"]) > 0:
|
||||
reply = Reply(ReplyType.ERROR, reply_content["content"])
|
||||
elif reply_content["completion_tokens"] > 0:
|
||||
self.sessions.session_reply(reply_content["content"], session_id, reply_content["total_tokens"])
|
||||
reply = Reply(ReplyType.TEXT, reply_content["content"])
|
||||
else:
|
||||
reply = Reply(ReplyType.ERROR, reply_content["content"])
|
||||
logger.debug("[MOONSHOT_AI] reply {} used 0 tokens.".format(reply_content))
|
||||
return reply
|
||||
else:
|
||||
reply = Reply(ReplyType.ERROR, "Bot不支持处理{}类型的消息".format(context.type))
|
||||
return reply
|
||||
|
||||
def reply_text(self, session: MoonshotSession, args=None, retry_count=0) -> dict:
|
||||
"""
|
||||
call openai's ChatCompletion to get the answer
|
||||
:param session: a conversation session
|
||||
:param session_id: session id
|
||||
:param retry_count: retry count
|
||||
:return: {}
|
||||
"""
|
||||
try:
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": "Bearer " + self.api_key
|
||||
}
|
||||
body = args
|
||||
body["messages"] = session.messages
|
||||
# logger.debug("[MOONSHOT_AI] response={}".format(response))
|
||||
# logger.info("[MOONSHOT_AI] reply={}, total_tokens={}".format(response.choices[0]['message']['content'], response["usage"]["total_tokens"]))
|
||||
res = requests.post(
|
||||
self.base_url,
|
||||
headers=headers,
|
||||
json=body
|
||||
)
|
||||
if res.status_code == 200:
|
||||
response = res.json()
|
||||
return {
|
||||
"total_tokens": response["usage"]["total_tokens"],
|
||||
"completion_tokens": response["usage"]["completion_tokens"],
|
||||
"content": response["choices"][0]["message"]["content"]
|
||||
}
|
||||
else:
|
||||
response = res.json()
|
||||
error = response.get("error")
|
||||
logger.error(f"[MOONSHOT_AI] chat failed, status_code={res.status_code}, "
|
||||
f"msg={error.get('message')}, type={error.get('type')}")
|
||||
|
||||
result = {"completion_tokens": 0, "content": "提问太快啦,请休息一下再问我吧"}
|
||||
need_retry = False
|
||||
if res.status_code >= 500:
|
||||
# server error, need retry
|
||||
logger.warn(f"[MOONSHOT_AI] do retry, times={retry_count}")
|
||||
need_retry = retry_count < 2
|
||||
elif res.status_code == 401:
|
||||
result["content"] = "授权失败,请检查API Key是否正确"
|
||||
elif res.status_code == 429:
|
||||
result["content"] = "请求过于频繁,请稍后再试"
|
||||
need_retry = retry_count < 2
|
||||
else:
|
||||
need_retry = False
|
||||
|
||||
if need_retry:
|
||||
time.sleep(3)
|
||||
return self.reply_text(session, args, retry_count + 1)
|
||||
else:
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
need_retry = retry_count < 2
|
||||
result = {"completion_tokens": 0, "content": "我现在有点累了,等会再来吧"}
|
||||
if need_retry:
|
||||
return self.reply_text(session, args, retry_count + 1)
|
||||
else:
|
||||
return result
|
||||
51
bot/moonshot/moonshot_session.py
Normal file
51
bot/moonshot/moonshot_session.py
Normal file
@@ -0,0 +1,51 @@
|
||||
from bot.session_manager import Session
|
||||
from common.log import logger
|
||||
|
||||
|
||||
class MoonshotSession(Session):
|
||||
def __init__(self, session_id, system_prompt=None, model="moonshot-v1-128k"):
|
||||
super().__init__(session_id, system_prompt)
|
||||
self.model = model
|
||||
self.reset()
|
||||
|
||||
def discard_exceeding(self, max_tokens, cur_tokens=None):
|
||||
precise = True
|
||||
try:
|
||||
cur_tokens = self.calc_tokens()
|
||||
except Exception as e:
|
||||
precise = False
|
||||
if cur_tokens is None:
|
||||
raise e
|
||||
logger.debug("Exception when counting tokens precisely for query: {}".format(e))
|
||||
while cur_tokens > max_tokens:
|
||||
if len(self.messages) > 2:
|
||||
self.messages.pop(1)
|
||||
elif len(self.messages) == 2 and self.messages[1]["role"] == "assistant":
|
||||
self.messages.pop(1)
|
||||
if precise:
|
||||
cur_tokens = self.calc_tokens()
|
||||
else:
|
||||
cur_tokens = cur_tokens - max_tokens
|
||||
break
|
||||
elif len(self.messages) == 2 and self.messages[1]["role"] == "user":
|
||||
logger.warn("user message exceed max_tokens. total_tokens={}".format(cur_tokens))
|
||||
break
|
||||
else:
|
||||
logger.debug("max_tokens={}, total_tokens={}, len(messages)={}".format(max_tokens, cur_tokens,
|
||||
len(self.messages)))
|
||||
break
|
||||
if precise:
|
||||
cur_tokens = self.calc_tokens()
|
||||
else:
|
||||
cur_tokens = cur_tokens - max_tokens
|
||||
return cur_tokens
|
||||
|
||||
def calc_tokens(self):
|
||||
return num_tokens_from_messages(self.messages, self.model)
|
||||
|
||||
|
||||
def num_tokens_from_messages(messages, model):
|
||||
tokens = 0
|
||||
for msg in messages:
|
||||
tokens += len(msg["content"])
|
||||
return tokens
|
||||
@@ -1,175 +1,122 @@
|
||||
# encoding:utf-8
|
||||
|
||||
import time
|
||||
|
||||
import openai
|
||||
import openai.error
|
||||
|
||||
from bot.bot import Bot
|
||||
from bot.openai.open_ai_image import OpenAIImage
|
||||
from bot.openai.open_ai_session import OpenAISession
|
||||
from bot.session_manager import SessionManager
|
||||
from bridge.context import ContextType
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from config import conf
|
||||
from common.log import logger
|
||||
import openai
|
||||
import time
|
||||
from config import conf
|
||||
|
||||
user_session = dict()
|
||||
|
||||
|
||||
# OpenAI对话模型API (可用)
|
||||
class OpenAIBot(Bot):
|
||||
class OpenAIBot(Bot, OpenAIImage):
|
||||
def __init__(self):
|
||||
openai.api_key = conf().get('open_ai_api_key')
|
||||
if conf().get('open_ai_api_base'):
|
||||
openai.api_base = conf().get('open_ai_api_base')
|
||||
proxy = conf().get('proxy')
|
||||
super().__init__()
|
||||
openai.api_key = conf().get("open_ai_api_key")
|
||||
if conf().get("open_ai_api_base"):
|
||||
openai.api_base = conf().get("open_ai_api_base")
|
||||
proxy = conf().get("proxy")
|
||||
if proxy:
|
||||
openai.proxy = proxy
|
||||
|
||||
self.sessions = SessionManager(OpenAISession, model=conf().get("model") or "text-davinci-003")
|
||||
self.args = {
|
||||
"model": conf().get("model") or "text-davinci-003", # 对话模型的名称
|
||||
"temperature": conf().get("temperature", 0.9), # 值在[0,1]之间,越大表示回复越具有不确定性
|
||||
"max_tokens": 1200, # 回复最大的字符数
|
||||
"top_p": 1,
|
||||
"frequency_penalty": conf().get("frequency_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
|
||||
"presence_penalty": conf().get("presence_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
|
||||
"request_timeout": conf().get("request_timeout", None), # 请求超时时间,openai接口默认设置为600,对于难问题一般需要较长时间
|
||||
"timeout": conf().get("request_timeout", None), # 重试超时时间,在这个时间内,将会自动重试
|
||||
"stop": ["\n\n\n"],
|
||||
}
|
||||
|
||||
def reply(self, query, context=None):
|
||||
# acquire reply content
|
||||
if context and context.type:
|
||||
if context.type == ContextType.TEXT:
|
||||
logger.info("[OPEN_AI] query={}".format(query))
|
||||
from_user_id = context['session_id']
|
||||
session_id = context["session_id"]
|
||||
reply = None
|
||||
if query == '#清除记忆':
|
||||
Session.clear_session(from_user_id)
|
||||
reply = Reply(ReplyType.INFO, '记忆已清除')
|
||||
elif query == '#清除所有':
|
||||
Session.clear_all_session()
|
||||
reply = Reply(ReplyType.INFO, '所有人记忆已清除')
|
||||
if query == "#清除记忆":
|
||||
self.sessions.clear_session(session_id)
|
||||
reply = Reply(ReplyType.INFO, "记忆已清除")
|
||||
elif query == "#清除所有":
|
||||
self.sessions.clear_all_session()
|
||||
reply = Reply(ReplyType.INFO, "所有人记忆已清除")
|
||||
else:
|
||||
new_query = Session.build_session_query(query, from_user_id)
|
||||
logger.debug("[OPEN_AI] session query={}".format(new_query))
|
||||
session = self.sessions.session_query(query, session_id)
|
||||
result = self.reply_text(session)
|
||||
total_tokens, completion_tokens, reply_content = (
|
||||
result["total_tokens"],
|
||||
result["completion_tokens"],
|
||||
result["content"],
|
||||
)
|
||||
logger.debug(
|
||||
"[OPEN_AI] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(str(session), session_id, reply_content, completion_tokens)
|
||||
)
|
||||
|
||||
reply_content = self.reply_text(new_query, from_user_id, 0)
|
||||
logger.debug("[OPEN_AI] new_query={}, user={}, reply_cont={}".format(new_query, from_user_id, reply_content))
|
||||
if reply_content and query:
|
||||
Session.save_session(query, reply_content, from_user_id)
|
||||
reply = Reply(ReplyType.TEXT, reply_content)
|
||||
if total_tokens == 0:
|
||||
reply = Reply(ReplyType.ERROR, reply_content)
|
||||
else:
|
||||
self.sessions.session_reply(reply_content, session_id, total_tokens)
|
||||
reply = Reply(ReplyType.TEXT, reply_content)
|
||||
return reply
|
||||
elif context.type == ContextType.IMAGE_CREATE:
|
||||
return self.create_img(query, 0)
|
||||
ok, retstring = self.create_img(query, 0)
|
||||
reply = None
|
||||
if ok:
|
||||
reply = Reply(ReplyType.IMAGE_URL, retstring)
|
||||
else:
|
||||
reply = Reply(ReplyType.ERROR, retstring)
|
||||
return reply
|
||||
|
||||
def reply_text(self, query, user_id, retry_count=0):
|
||||
def reply_text(self, session: OpenAISession, retry_count=0):
|
||||
try:
|
||||
response = openai.Completion.create(
|
||||
model= conf().get("model") or "text-davinci-003", # 对话模型的名称
|
||||
prompt=query,
|
||||
temperature=0.9, # 值在[0,1]之间,越大表示回复越具有不确定性
|
||||
max_tokens=1200, # 回复最大的字符数
|
||||
top_p=1,
|
||||
frequency_penalty=0.0, # [-2,2]之间,该值越大则更倾向于产生不同的内容
|
||||
presence_penalty=0.0, # [-2,2]之间,该值越大则更倾向于产生不同的内容
|
||||
stop=["\n\n\n"]
|
||||
)
|
||||
res_content = response.choices[0]['text'].strip().replace('<|endoftext|>', '')
|
||||
response = openai.Completion.create(prompt=str(session), **self.args)
|
||||
res_content = response.choices[0]["text"].strip().replace("<|endoftext|>", "")
|
||||
total_tokens = response["usage"]["total_tokens"]
|
||||
completion_tokens = response["usage"]["completion_tokens"]
|
||||
logger.info("[OPEN_AI] reply={}".format(res_content))
|
||||
return res_content
|
||||
except openai.error.RateLimitError as e:
|
||||
# rate limit exception
|
||||
logger.warn(e)
|
||||
if retry_count < 1:
|
||||
time.sleep(5)
|
||||
logger.warn("[OPEN_AI] RateLimit exceed, 第{}次重试".format(retry_count+1))
|
||||
return self.reply_text(query, user_id, retry_count+1)
|
||||
else:
|
||||
return "提问太快啦,请休息一下再问我吧"
|
||||
return {
|
||||
"total_tokens": total_tokens,
|
||||
"completion_tokens": completion_tokens,
|
||||
"content": res_content,
|
||||
}
|
||||
except Exception as e:
|
||||
# unknown exception
|
||||
logger.exception(e)
|
||||
Session.clear_session(user_id)
|
||||
return "请再问我一次吧"
|
||||
|
||||
|
||||
def create_img(self, query, retry_count=0):
|
||||
try:
|
||||
logger.info("[OPEN_AI] image_query={}".format(query))
|
||||
response = openai.Image.create(
|
||||
prompt=query, #图片描述
|
||||
n=1, #每次生成图片的数量
|
||||
size="256x256" #图片大小,可选有 256x256, 512x512, 1024x1024
|
||||
)
|
||||
image_url = response['data'][0]['url']
|
||||
logger.info("[OPEN_AI] image_url={}".format(image_url))
|
||||
return image_url
|
||||
except openai.error.RateLimitError as e:
|
||||
logger.warn(e)
|
||||
if retry_count < 1:
|
||||
time.sleep(5)
|
||||
logger.warn("[OPEN_AI] ImgCreate RateLimit exceed, 第{}次重试".format(retry_count+1))
|
||||
return self.reply_text(query, retry_count+1)
|
||||
need_retry = retry_count < 2
|
||||
result = {"completion_tokens": 0, "content": "我现在有点累了,等会再来吧"}
|
||||
if isinstance(e, openai.error.RateLimitError):
|
||||
logger.warn("[OPEN_AI] RateLimitError: {}".format(e))
|
||||
result["content"] = "提问太快啦,请休息一下再问我吧"
|
||||
if need_retry:
|
||||
time.sleep(20)
|
||||
elif isinstance(e, openai.error.Timeout):
|
||||
logger.warn("[OPEN_AI] Timeout: {}".format(e))
|
||||
result["content"] = "我没有收到你的消息"
|
||||
if need_retry:
|
||||
time.sleep(5)
|
||||
elif isinstance(e, openai.error.APIConnectionError):
|
||||
logger.warn("[OPEN_AI] APIConnectionError: {}".format(e))
|
||||
need_retry = False
|
||||
result["content"] = "我连接不到你的网络"
|
||||
else:
|
||||
return "提问太快啦,请休息一下再问我吧"
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
return None
|
||||
logger.warn("[OPEN_AI] Exception: {}".format(e))
|
||||
need_retry = False
|
||||
self.sessions.clear_session(session.session_id)
|
||||
|
||||
|
||||
class Session(object):
|
||||
@staticmethod
|
||||
def build_session_query(query, user_id):
|
||||
'''
|
||||
build query with conversation history
|
||||
e.g. Q: xxx
|
||||
A: xxx
|
||||
Q: xxx
|
||||
:param query: query content
|
||||
:param user_id: from user id
|
||||
:return: query content with conversaction
|
||||
'''
|
||||
prompt = conf().get("character_desc", "")
|
||||
if prompt:
|
||||
prompt += "<|endoftext|>\n\n\n"
|
||||
session = user_session.get(user_id, None)
|
||||
if session:
|
||||
for conversation in session:
|
||||
prompt += "Q: " + conversation["question"] + "\n\n\nA: " + conversation["answer"] + "<|endoftext|>\n"
|
||||
prompt += "Q: " + query + "\nA: "
|
||||
return prompt
|
||||
else:
|
||||
return prompt + "Q: " + query + "\nA: "
|
||||
|
||||
@staticmethod
|
||||
def save_session(query, answer, user_id):
|
||||
max_tokens = conf().get("conversation_max_tokens")
|
||||
if not max_tokens:
|
||||
# default 3000
|
||||
max_tokens = 1000
|
||||
conversation = dict()
|
||||
conversation["question"] = query
|
||||
conversation["answer"] = answer
|
||||
session = user_session.get(user_id)
|
||||
logger.debug(conversation)
|
||||
logger.debug(session)
|
||||
if session:
|
||||
# append conversation
|
||||
session.append(conversation)
|
||||
else:
|
||||
# create session
|
||||
queue = list()
|
||||
queue.append(conversation)
|
||||
user_session[user_id] = queue
|
||||
|
||||
# discard exceed limit conversation
|
||||
Session.discard_exceed_conversation(user_session[user_id], max_tokens)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def discard_exceed_conversation(session, max_tokens):
|
||||
count = 0
|
||||
count_list = list()
|
||||
for i in range(len(session)-1, -1, -1):
|
||||
# count tokens of conversation list
|
||||
history_conv = session[i]
|
||||
count += len(history_conv["question"]) + len(history_conv["answer"])
|
||||
count_list.append(count)
|
||||
|
||||
for c in count_list:
|
||||
if c > max_tokens:
|
||||
# pop first conversation
|
||||
session.pop(0)
|
||||
|
||||
@staticmethod
|
||||
def clear_session(user_id):
|
||||
user_session[user_id] = []
|
||||
|
||||
@staticmethod
|
||||
def clear_all_session():
|
||||
user_session.clear()
|
||||
if need_retry:
|
||||
logger.warn("[OPEN_AI] 第{}次重试".format(retry_count + 1))
|
||||
return self.reply_text(session, retry_count + 1)
|
||||
else:
|
||||
return result
|
||||
|
||||
43
bot/openai/open_ai_image.py
Normal file
43
bot/openai/open_ai_image.py
Normal file
@@ -0,0 +1,43 @@
|
||||
import time
|
||||
|
||||
import openai
|
||||
import openai.error
|
||||
|
||||
from common.log import logger
|
||||
from common.token_bucket import TokenBucket
|
||||
from config import conf
|
||||
|
||||
|
||||
# OPENAI提供的画图接口
|
||||
class OpenAIImage(object):
|
||||
def __init__(self):
|
||||
openai.api_key = conf().get("open_ai_api_key")
|
||||
if conf().get("rate_limit_dalle"):
|
||||
self.tb4dalle = TokenBucket(conf().get("rate_limit_dalle", 50))
|
||||
|
||||
def create_img(self, query, retry_count=0, api_key=None, api_base=None):
|
||||
try:
|
||||
if conf().get("rate_limit_dalle") and not self.tb4dalle.get_token():
|
||||
return False, "请求太快了,请休息一下再问我吧"
|
||||
logger.info("[OPEN_AI] image_query={}".format(query))
|
||||
response = openai.Image.create(
|
||||
api_key=api_key,
|
||||
prompt=query, # 图片描述
|
||||
n=1, # 每次生成图片的数量
|
||||
model=conf().get("text_to_image") or "dall-e-2",
|
||||
# size=conf().get("image_create_size", "256x256"), # 图片大小,可选有 256x256, 512x512, 1024x1024
|
||||
)
|
||||
image_url = response["data"][0]["url"]
|
||||
logger.info("[OPEN_AI] image_url={}".format(image_url))
|
||||
return True, image_url
|
||||
except openai.error.RateLimitError as e:
|
||||
logger.warn(e)
|
||||
if retry_count < 1:
|
||||
time.sleep(5)
|
||||
logger.warn("[OPEN_AI] ImgCreate RateLimit exceed, 第{}次重试".format(retry_count + 1))
|
||||
return self.create_img(query, retry_count + 1)
|
||||
else:
|
||||
return False, "画图出现问题,请休息一下再问我吧"
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
return False, "画图出现问题,请休息一下再问我吧"
|
||||
73
bot/openai/open_ai_session.py
Normal file
73
bot/openai/open_ai_session.py
Normal file
@@ -0,0 +1,73 @@
|
||||
from bot.session_manager import Session
|
||||
from common.log import logger
|
||||
|
||||
|
||||
class OpenAISession(Session):
|
||||
def __init__(self, session_id, system_prompt=None, model="text-davinci-003"):
|
||||
super().__init__(session_id, system_prompt)
|
||||
self.model = model
|
||||
self.reset()
|
||||
|
||||
def __str__(self):
|
||||
# 构造对话模型的输入
|
||||
"""
|
||||
e.g. Q: xxx
|
||||
A: xxx
|
||||
Q: xxx
|
||||
"""
|
||||
prompt = ""
|
||||
for item in self.messages:
|
||||
if item["role"] == "system":
|
||||
prompt += item["content"] + "<|endoftext|>\n\n\n"
|
||||
elif item["role"] == "user":
|
||||
prompt += "Q: " + item["content"] + "\n"
|
||||
elif item["role"] == "assistant":
|
||||
prompt += "\n\nA: " + item["content"] + "<|endoftext|>\n"
|
||||
|
||||
if len(self.messages) > 0 and self.messages[-1]["role"] == "user":
|
||||
prompt += "A: "
|
||||
return prompt
|
||||
|
||||
def discard_exceeding(self, max_tokens, cur_tokens=None):
|
||||
precise = True
|
||||
try:
|
||||
cur_tokens = self.calc_tokens()
|
||||
except Exception as e:
|
||||
precise = False
|
||||
if cur_tokens is None:
|
||||
raise e
|
||||
logger.debug("Exception when counting tokens precisely for query: {}".format(e))
|
||||
while cur_tokens > max_tokens:
|
||||
if len(self.messages) > 1:
|
||||
self.messages.pop(0)
|
||||
elif len(self.messages) == 1 and self.messages[0]["role"] == "assistant":
|
||||
self.messages.pop(0)
|
||||
if precise:
|
||||
cur_tokens = self.calc_tokens()
|
||||
else:
|
||||
cur_tokens = len(str(self))
|
||||
break
|
||||
elif len(self.messages) == 1 and self.messages[0]["role"] == "user":
|
||||
logger.warn("user question exceed max_tokens. total_tokens={}".format(cur_tokens))
|
||||
break
|
||||
else:
|
||||
logger.debug("max_tokens={}, total_tokens={}, len(conversation)={}".format(max_tokens, cur_tokens, len(self.messages)))
|
||||
break
|
||||
if precise:
|
||||
cur_tokens = self.calc_tokens()
|
||||
else:
|
||||
cur_tokens = len(str(self))
|
||||
return cur_tokens
|
||||
|
||||
def calc_tokens(self):
|
||||
return num_tokens_from_string(str(self), self.model)
|
||||
|
||||
|
||||
# refer to https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
|
||||
def num_tokens_from_string(string: str, model: str) -> int:
|
||||
"""Returns the number of tokens in a text string."""
|
||||
import tiktoken
|
||||
|
||||
encoding = tiktoken.encoding_for_model(model)
|
||||
num_tokens = len(encoding.encode(string, disallowed_special=()))
|
||||
return num_tokens
|
||||
91
bot/session_manager.py
Normal file
91
bot/session_manager.py
Normal file
@@ -0,0 +1,91 @@
|
||||
from common.expired_dict import ExpiredDict
|
||||
from common.log import logger
|
||||
from config import conf
|
||||
|
||||
|
||||
class Session(object):
|
||||
def __init__(self, session_id, system_prompt=None):
|
||||
self.session_id = session_id
|
||||
self.messages = []
|
||||
if system_prompt is None:
|
||||
self.system_prompt = conf().get("character_desc", "")
|
||||
else:
|
||||
self.system_prompt = system_prompt
|
||||
|
||||
# 重置会话
|
||||
def reset(self):
|
||||
system_item = {"role": "system", "content": self.system_prompt}
|
||||
self.messages = [system_item]
|
||||
|
||||
def set_system_prompt(self, system_prompt):
|
||||
self.system_prompt = system_prompt
|
||||
self.reset()
|
||||
|
||||
def add_query(self, query):
|
||||
user_item = {"role": "user", "content": query}
|
||||
self.messages.append(user_item)
|
||||
|
||||
def add_reply(self, reply):
|
||||
assistant_item = {"role": "assistant", "content": reply}
|
||||
self.messages.append(assistant_item)
|
||||
|
||||
def discard_exceeding(self, max_tokens=None, cur_tokens=None):
|
||||
raise NotImplementedError
|
||||
|
||||
def calc_tokens(self):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class SessionManager(object):
|
||||
def __init__(self, sessioncls, **session_args):
|
||||
if conf().get("expires_in_seconds"):
|
||||
sessions = ExpiredDict(conf().get("expires_in_seconds"))
|
||||
else:
|
||||
sessions = dict()
|
||||
self.sessions = sessions
|
||||
self.sessioncls = sessioncls
|
||||
self.session_args = session_args
|
||||
|
||||
def build_session(self, session_id, system_prompt=None):
|
||||
"""
|
||||
如果session_id不在sessions中,创建一个新的session并添加到sessions中
|
||||
如果system_prompt不会空,会更新session的system_prompt并重置session
|
||||
"""
|
||||
if session_id is None:
|
||||
return self.sessioncls(session_id, system_prompt, **self.session_args)
|
||||
|
||||
if session_id not in self.sessions:
|
||||
self.sessions[session_id] = self.sessioncls(session_id, system_prompt, **self.session_args)
|
||||
elif system_prompt is not None: # 如果有新的system_prompt,更新并重置session
|
||||
self.sessions[session_id].set_system_prompt(system_prompt)
|
||||
session = self.sessions[session_id]
|
||||
return session
|
||||
|
||||
def session_query(self, query, session_id):
|
||||
session = self.build_session(session_id)
|
||||
session.add_query(query)
|
||||
try:
|
||||
max_tokens = conf().get("conversation_max_tokens", 1000)
|
||||
total_tokens = session.discard_exceeding(max_tokens, None)
|
||||
logger.debug("prompt tokens used={}".format(total_tokens))
|
||||
except Exception as e:
|
||||
logger.warning("Exception when counting tokens precisely for prompt: {}".format(str(e)))
|
||||
return session
|
||||
|
||||
def session_reply(self, reply, session_id, total_tokens=None):
|
||||
session = self.build_session(session_id)
|
||||
session.add_reply(reply)
|
||||
try:
|
||||
max_tokens = conf().get("conversation_max_tokens", 1000)
|
||||
tokens_cnt = session.discard_exceeding(max_tokens, total_tokens)
|
||||
logger.debug("raw total_tokens={}, savesession tokens={}".format(total_tokens, tokens_cnt))
|
||||
except Exception as e:
|
||||
logger.warning("Exception when counting tokens precisely for session: {}".format(str(e)))
|
||||
return session
|
||||
|
||||
def clear_session(self, session_id):
|
||||
if session_id in self.sessions:
|
||||
del self.sessions[session_id]
|
||||
|
||||
def clear_all_session(self):
|
||||
self.sessions.clear()
|
||||
268
bot/xunfei/xunfei_spark_bot.py
Normal file
268
bot/xunfei/xunfei_spark_bot.py
Normal file
@@ -0,0 +1,268 @@
|
||||
# encoding:utf-8
|
||||
|
||||
import requests, json
|
||||
from bot.bot import Bot
|
||||
from bot.session_manager import SessionManager
|
||||
from bot.baidu.baidu_wenxin_session import BaiduWenxinSession
|
||||
from bridge.context import ContextType, Context
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from common.log import logger
|
||||
from config import conf
|
||||
from common import const
|
||||
import time
|
||||
import _thread as thread
|
||||
import datetime
|
||||
from datetime import datetime
|
||||
from wsgiref.handlers import format_date_time
|
||||
from urllib.parse import urlencode
|
||||
import base64
|
||||
import ssl
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
from time import mktime
|
||||
from urllib.parse import urlparse
|
||||
import websocket
|
||||
import queue
|
||||
import threading
|
||||
import random
|
||||
|
||||
# 消息队列 map
|
||||
queue_map = dict()
|
||||
|
||||
# 响应队列 map
|
||||
reply_map = dict()
|
||||
|
||||
|
||||
class XunFeiBot(Bot):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.app_id = conf().get("xunfei_app_id")
|
||||
self.api_key = conf().get("xunfei_api_key")
|
||||
self.api_secret = conf().get("xunfei_api_secret")
|
||||
# 默认使用v2.0版本: "generalv2"
|
||||
# v1.5版本为 "general"
|
||||
# v3.0版本为: "generalv3"
|
||||
self.domain = "generalv3"
|
||||
# 默认使用v2.0版本: "ws://spark-api.xf-yun.com/v2.1/chat"
|
||||
# v1.5版本为: "ws://spark-api.xf-yun.com/v1.1/chat"
|
||||
# v3.0版本为: "ws://spark-api.xf-yun.com/v3.1/chat"
|
||||
# v3.5版本为: "wss://spark-api.xf-yun.com/v3.5/chat"
|
||||
self.spark_url = "wss://spark-api.xf-yun.com/v3.5/chat"
|
||||
self.host = urlparse(self.spark_url).netloc
|
||||
self.path = urlparse(self.spark_url).path
|
||||
# 和wenxin使用相同的session机制
|
||||
self.sessions = SessionManager(BaiduWenxinSession, model=const.XUNFEI)
|
||||
|
||||
def reply(self, query, context: Context = None) -> Reply:
|
||||
if context.type == ContextType.TEXT:
|
||||
logger.info("[XunFei] query={}".format(query))
|
||||
session_id = context["session_id"]
|
||||
request_id = self.gen_request_id(session_id)
|
||||
reply_map[request_id] = ""
|
||||
session = self.sessions.session_query(query, session_id)
|
||||
threading.Thread(target=self.create_web_socket,
|
||||
args=(session.messages, request_id)).start()
|
||||
depth = 0
|
||||
time.sleep(0.1)
|
||||
t1 = time.time()
|
||||
usage = {}
|
||||
while depth <= 300:
|
||||
try:
|
||||
data_queue = queue_map.get(request_id)
|
||||
if not data_queue:
|
||||
depth += 1
|
||||
time.sleep(0.1)
|
||||
continue
|
||||
data_item = data_queue.get(block=True, timeout=0.1)
|
||||
if data_item.is_end:
|
||||
# 请求结束
|
||||
del queue_map[request_id]
|
||||
if data_item.reply:
|
||||
reply_map[request_id] += data_item.reply
|
||||
usage = data_item.usage
|
||||
break
|
||||
|
||||
reply_map[request_id] += data_item.reply
|
||||
depth += 1
|
||||
except Exception as e:
|
||||
depth += 1
|
||||
continue
|
||||
t2 = time.time()
|
||||
logger.info(
|
||||
f"[XunFei-API] response={reply_map[request_id]}, time={t2 - t1}s, usage={usage}"
|
||||
)
|
||||
self.sessions.session_reply(reply_map[request_id], session_id,
|
||||
usage.get("total_tokens"))
|
||||
reply = Reply(ReplyType.TEXT, reply_map[request_id])
|
||||
del reply_map[request_id]
|
||||
return reply
|
||||
else:
|
||||
reply = Reply(ReplyType.ERROR,
|
||||
"Bot不支持处理{}类型的消息".format(context.type))
|
||||
return reply
|
||||
|
||||
def create_web_socket(self, prompt, session_id, temperature=0.5):
|
||||
logger.info(f"[XunFei] start connect, prompt={prompt}")
|
||||
websocket.enableTrace(False)
|
||||
wsUrl = self.create_url()
|
||||
ws = websocket.WebSocketApp(wsUrl,
|
||||
on_message=on_message,
|
||||
on_error=on_error,
|
||||
on_close=on_close,
|
||||
on_open=on_open)
|
||||
data_queue = queue.Queue(1000)
|
||||
queue_map[session_id] = data_queue
|
||||
ws.appid = self.app_id
|
||||
ws.question = prompt
|
||||
ws.domain = self.domain
|
||||
ws.session_id = session_id
|
||||
ws.temperature = temperature
|
||||
ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
|
||||
|
||||
def gen_request_id(self, session_id: str):
|
||||
return session_id + "_" + str(int(time.time())) + "" + str(
|
||||
random.randint(0, 100))
|
||||
|
||||
# 生成url
|
||||
def create_url(self):
|
||||
# 生成RFC1123格式的时间戳
|
||||
now = datetime.now()
|
||||
date = format_date_time(mktime(now.timetuple()))
|
||||
|
||||
# 拼接字符串
|
||||
signature_origin = "host: " + self.host + "\n"
|
||||
signature_origin += "date: " + date + "\n"
|
||||
signature_origin += "GET " + self.path + " HTTP/1.1"
|
||||
|
||||
# 进行hmac-sha256进行加密
|
||||
signature_sha = hmac.new(self.api_secret.encode('utf-8'),
|
||||
signature_origin.encode('utf-8'),
|
||||
digestmod=hashlib.sha256).digest()
|
||||
|
||||
signature_sha_base64 = base64.b64encode(signature_sha).decode(
|
||||
encoding='utf-8')
|
||||
|
||||
authorization_origin = f'api_key="{self.api_key}", algorithm="hmac-sha256", headers="host date request-line", ' \
|
||||
f'signature="{signature_sha_base64}"'
|
||||
|
||||
authorization = base64.b64encode(
|
||||
authorization_origin.encode('utf-8')).decode(encoding='utf-8')
|
||||
|
||||
# 将请求的鉴权参数组合为字典
|
||||
v = {"authorization": authorization, "date": date, "host": self.host}
|
||||
# 拼接鉴权参数,生成url
|
||||
url = self.spark_url + '?' + urlencode(v)
|
||||
# 此处打印出建立连接时候的url,参考本demo的时候可取消上方打印的注释,比对相同参数时生成的url与自己代码生成的url是否一致
|
||||
return url
|
||||
|
||||
def gen_params(self, appid, domain, question):
|
||||
"""
|
||||
通过appid和用户的提问来生成请参数
|
||||
"""
|
||||
data = {
|
||||
"header": {
|
||||
"app_id": appid,
|
||||
"uid": "1234"
|
||||
},
|
||||
"parameter": {
|
||||
"chat": {
|
||||
"domain": domain,
|
||||
"random_threshold": 0.5,
|
||||
"max_tokens": 2048,
|
||||
"auditing": "default"
|
||||
}
|
||||
},
|
||||
"payload": {
|
||||
"message": {
|
||||
"text": question
|
||||
}
|
||||
}
|
||||
}
|
||||
return data
|
||||
|
||||
|
||||
class ReplyItem:
|
||||
def __init__(self, reply, usage=None, is_end=False):
|
||||
self.is_end = is_end
|
||||
self.reply = reply
|
||||
self.usage = usage
|
||||
|
||||
|
||||
# 收到websocket错误的处理
|
||||
def on_error(ws, error):
|
||||
logger.error(f"[XunFei] error: {str(error)}")
|
||||
|
||||
|
||||
# 收到websocket关闭的处理
|
||||
def on_close(ws, one, two):
|
||||
data_queue = queue_map.get(ws.session_id)
|
||||
data_queue.put("END")
|
||||
|
||||
|
||||
# 收到websocket连接建立的处理
|
||||
def on_open(ws):
|
||||
logger.info(f"[XunFei] Start websocket, session_id={ws.session_id}")
|
||||
thread.start_new_thread(run, (ws, ))
|
||||
|
||||
|
||||
def run(ws, *args):
|
||||
data = json.dumps(
|
||||
gen_params(appid=ws.appid,
|
||||
domain=ws.domain,
|
||||
question=ws.question,
|
||||
temperature=ws.temperature))
|
||||
ws.send(data)
|
||||
|
||||
|
||||
# Websocket 操作
|
||||
# 收到websocket消息的处理
|
||||
def on_message(ws, message):
|
||||
data = json.loads(message)
|
||||
code = data['header']['code']
|
||||
if code != 0:
|
||||
logger.error(f'请求错误: {code}, {data}')
|
||||
ws.close()
|
||||
else:
|
||||
choices = data["payload"]["choices"]
|
||||
status = choices["status"]
|
||||
content = choices["text"][0]["content"]
|
||||
data_queue = queue_map.get(ws.session_id)
|
||||
if not data_queue:
|
||||
logger.error(
|
||||
f"[XunFei] can't find data queue, session_id={ws.session_id}")
|
||||
return
|
||||
reply_item = ReplyItem(content)
|
||||
if status == 2:
|
||||
usage = data["payload"].get("usage")
|
||||
reply_item = ReplyItem(content, usage)
|
||||
reply_item.is_end = True
|
||||
ws.close()
|
||||
data_queue.put(reply_item)
|
||||
|
||||
|
||||
def gen_params(appid, domain, question, temperature=0.5):
|
||||
"""
|
||||
通过appid和用户的提问来生成请参数
|
||||
"""
|
||||
data = {
|
||||
"header": {
|
||||
"app_id": appid,
|
||||
"uid": "1234"
|
||||
},
|
||||
"parameter": {
|
||||
"chat": {
|
||||
"domain": domain,
|
||||
"temperature": temperature,
|
||||
"random_threshold": 0.5,
|
||||
"max_tokens": 2048,
|
||||
"auditing": "default"
|
||||
}
|
||||
},
|
||||
"payload": {
|
||||
"message": {
|
||||
"text": question
|
||||
}
|
||||
}
|
||||
}
|
||||
return data
|
||||
29
bot/zhipuai/zhipu_ai_image.py
Normal file
29
bot/zhipuai/zhipu_ai_image.py
Normal file
@@ -0,0 +1,29 @@
|
||||
from common.log import logger
|
||||
from config import conf
|
||||
|
||||
|
||||
# ZhipuAI提供的画图接口
|
||||
|
||||
class ZhipuAIImage(object):
|
||||
def __init__(self):
|
||||
from zhipuai import ZhipuAI
|
||||
self.client = ZhipuAI(api_key=conf().get("zhipu_ai_api_key"))
|
||||
|
||||
def create_img(self, query, retry_count=0, api_key=None, api_base=None):
|
||||
try:
|
||||
if conf().get("rate_limit_dalle"):
|
||||
return False, "请求太快了,请休息一下再问我吧"
|
||||
logger.info("[ZHIPU_AI] image_query={}".format(query))
|
||||
response = self.client.images.generations(
|
||||
prompt=query,
|
||||
n=1, # 每次生成图片的数量
|
||||
model=conf().get("text_to_image") or "cogview-3",
|
||||
size=conf().get("image_create_size", "1024x1024"), # 图片大小,可选有 256x256, 512x512, 1024x1024
|
||||
quality="standard",
|
||||
)
|
||||
image_url = response.data[0].url
|
||||
logger.info("[ZHIPU_AI] image_url={}".format(image_url))
|
||||
return True, image_url
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
return False, "画图出现问题,请休息一下再问我吧"
|
||||
53
bot/zhipuai/zhipu_ai_session.py
Normal file
53
bot/zhipuai/zhipu_ai_session.py
Normal file
@@ -0,0 +1,53 @@
|
||||
from bot.session_manager import Session
|
||||
from common.log import logger
|
||||
|
||||
|
||||
class ZhipuAISession(Session):
|
||||
def __init__(self, session_id, system_prompt=None, model="glm-4"):
|
||||
super().__init__(session_id, system_prompt)
|
||||
self.model = model
|
||||
self.reset()
|
||||
if not system_prompt:
|
||||
logger.warn("[ZhiPu] `character_desc` can not be empty")
|
||||
|
||||
def discard_exceeding(self, max_tokens, cur_tokens=None):
|
||||
precise = True
|
||||
try:
|
||||
cur_tokens = self.calc_tokens()
|
||||
except Exception as e:
|
||||
precise = False
|
||||
if cur_tokens is None:
|
||||
raise e
|
||||
logger.debug("Exception when counting tokens precisely for query: {}".format(e))
|
||||
while cur_tokens > max_tokens:
|
||||
if len(self.messages) > 2:
|
||||
self.messages.pop(1)
|
||||
elif len(self.messages) == 2 and self.messages[1]["role"] == "assistant":
|
||||
self.messages.pop(1)
|
||||
if precise:
|
||||
cur_tokens = self.calc_tokens()
|
||||
else:
|
||||
cur_tokens = cur_tokens - max_tokens
|
||||
break
|
||||
elif len(self.messages) == 2 and self.messages[1]["role"] == "user":
|
||||
logger.warn("user message exceed max_tokens. total_tokens={}".format(cur_tokens))
|
||||
break
|
||||
else:
|
||||
logger.debug("max_tokens={}, total_tokens={}, len(messages)={}".format(max_tokens, cur_tokens,
|
||||
len(self.messages)))
|
||||
break
|
||||
if precise:
|
||||
cur_tokens = self.calc_tokens()
|
||||
else:
|
||||
cur_tokens = cur_tokens - max_tokens
|
||||
return cur_tokens
|
||||
|
||||
def calc_tokens(self):
|
||||
return num_tokens_from_messages(self.messages, self.model)
|
||||
|
||||
|
||||
def num_tokens_from_messages(messages, model):
|
||||
tokens = 0
|
||||
for msg in messages:
|
||||
tokens += len(msg["content"])
|
||||
return tokens
|
||||
149
bot/zhipuai/zhipuai_bot.py
Normal file
149
bot/zhipuai/zhipuai_bot.py
Normal file
@@ -0,0 +1,149 @@
|
||||
# encoding:utf-8
|
||||
|
||||
import time
|
||||
|
||||
import openai
|
||||
import openai.error
|
||||
from bot.bot import Bot
|
||||
from bot.zhipuai.zhipu_ai_session import ZhipuAISession
|
||||
from bot.zhipuai.zhipu_ai_image import ZhipuAIImage
|
||||
from bot.session_manager import SessionManager
|
||||
from bridge.context import ContextType
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from common.log import logger
|
||||
from config import conf, load_config
|
||||
from zhipuai import ZhipuAI
|
||||
|
||||
|
||||
# ZhipuAI对话模型API
|
||||
class ZHIPUAIBot(Bot, ZhipuAIImage):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.sessions = SessionManager(ZhipuAISession, model=conf().get("model") or "ZHIPU_AI")
|
||||
self.args = {
|
||||
"model": conf().get("model") or "glm-4", # 对话模型的名称
|
||||
"temperature": conf().get("temperature", 0.9), # 值在(0,1)之间(智谱AI 的温度不能取 0 或者 1)
|
||||
"top_p": conf().get("top_p", 0.7), # 值在(0,1)之间(智谱AI 的 top_p 不能取 0 或者 1)
|
||||
}
|
||||
self.client = ZhipuAI(api_key=conf().get("zhipu_ai_api_key"))
|
||||
|
||||
def reply(self, query, context=None):
|
||||
# acquire reply content
|
||||
if context.type == ContextType.TEXT:
|
||||
logger.info("[ZHIPU_AI] query={}".format(query))
|
||||
|
||||
session_id = context["session_id"]
|
||||
reply = None
|
||||
clear_memory_commands = conf().get("clear_memory_commands", ["#清除记忆"])
|
||||
if query in clear_memory_commands:
|
||||
self.sessions.clear_session(session_id)
|
||||
reply = Reply(ReplyType.INFO, "记忆已清除")
|
||||
elif query == "#清除所有":
|
||||
self.sessions.clear_all_session()
|
||||
reply = Reply(ReplyType.INFO, "所有人记忆已清除")
|
||||
elif query == "#更新配置":
|
||||
load_config()
|
||||
reply = Reply(ReplyType.INFO, "配置已更新")
|
||||
if reply:
|
||||
return reply
|
||||
session = self.sessions.session_query(query, session_id)
|
||||
logger.debug("[ZHIPU_AI] session query={}".format(session.messages))
|
||||
|
||||
api_key = context.get("openai_api_key") or openai.api_key
|
||||
model = context.get("gpt_model")
|
||||
new_args = None
|
||||
if model:
|
||||
new_args = self.args.copy()
|
||||
new_args["model"] = model
|
||||
# if context.get('stream'):
|
||||
# # reply in stream
|
||||
# return self.reply_text_stream(query, new_query, session_id)
|
||||
|
||||
reply_content = self.reply_text(session, api_key, args=new_args)
|
||||
logger.debug(
|
||||
"[ZHIPU_AI] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(
|
||||
session.messages,
|
||||
session_id,
|
||||
reply_content["content"],
|
||||
reply_content["completion_tokens"],
|
||||
)
|
||||
)
|
||||
if reply_content["completion_tokens"] == 0 and len(reply_content["content"]) > 0:
|
||||
reply = Reply(ReplyType.ERROR, reply_content["content"])
|
||||
elif reply_content["completion_tokens"] > 0:
|
||||
self.sessions.session_reply(reply_content["content"], session_id, reply_content["total_tokens"])
|
||||
reply = Reply(ReplyType.TEXT, reply_content["content"])
|
||||
else:
|
||||
reply = Reply(ReplyType.ERROR, reply_content["content"])
|
||||
logger.debug("[ZHIPU_AI] reply {} used 0 tokens.".format(reply_content))
|
||||
return reply
|
||||
elif context.type == ContextType.IMAGE_CREATE:
|
||||
ok, retstring = self.create_img(query, 0)
|
||||
reply = None
|
||||
if ok:
|
||||
reply = Reply(ReplyType.IMAGE_URL, retstring)
|
||||
else:
|
||||
reply = Reply(ReplyType.ERROR, retstring)
|
||||
return reply
|
||||
|
||||
else:
|
||||
reply = Reply(ReplyType.ERROR, "Bot不支持处理{}类型的消息".format(context.type))
|
||||
return reply
|
||||
|
||||
def reply_text(self, session: ZhipuAISession, api_key=None, args=None, retry_count=0) -> dict:
|
||||
"""
|
||||
call openai's ChatCompletion to get the answer
|
||||
:param session: a conversation session
|
||||
:param session_id: session id
|
||||
:param retry_count: retry count
|
||||
:return: {}
|
||||
"""
|
||||
try:
|
||||
# if conf().get("rate_limit_chatgpt") and not self.tb4chatgpt.get_token():
|
||||
# raise openai.error.RateLimitError("RateLimitError: rate limit exceeded")
|
||||
# if api_key == None, the default openai.api_key will be used
|
||||
if args is None:
|
||||
args = self.args
|
||||
# response = openai.ChatCompletion.create(api_key=api_key, messages=session.messages, **args)
|
||||
response = self.client.chat.completions.create(messages=session.messages, **args)
|
||||
# logger.debug("[ZHIPU_AI] response={}".format(response))
|
||||
# logger.info("[ZHIPU_AI] reply={}, total_tokens={}".format(response.choices[0]['message']['content'], response["usage"]["total_tokens"]))
|
||||
|
||||
return {
|
||||
"total_tokens": response.usage.total_tokens,
|
||||
"completion_tokens": response.usage.completion_tokens,
|
||||
"content": response.choices[0].message.content,
|
||||
}
|
||||
except Exception as e:
|
||||
need_retry = retry_count < 2
|
||||
result = {"completion_tokens": 0, "content": "我现在有点累了,等会再来吧"}
|
||||
if isinstance(e, openai.error.RateLimitError):
|
||||
logger.warn("[ZHIPU_AI] RateLimitError: {}".format(e))
|
||||
result["content"] = "提问太快啦,请休息一下再问我吧"
|
||||
if need_retry:
|
||||
time.sleep(20)
|
||||
elif isinstance(e, openai.error.Timeout):
|
||||
logger.warn("[ZHIPU_AI] Timeout: {}".format(e))
|
||||
result["content"] = "我没有收到你的消息"
|
||||
if need_retry:
|
||||
time.sleep(5)
|
||||
elif isinstance(e, openai.error.APIError):
|
||||
logger.warn("[ZHIPU_AI] Bad Gateway: {}".format(e))
|
||||
result["content"] = "请再问我一次"
|
||||
if need_retry:
|
||||
time.sleep(10)
|
||||
elif isinstance(e, openai.error.APIConnectionError):
|
||||
logger.warn("[ZHIPU_AI] APIConnectionError: {}".format(e))
|
||||
result["content"] = "我连接不到你的网络"
|
||||
if need_retry:
|
||||
time.sleep(5)
|
||||
else:
|
||||
logger.exception("[ZHIPU_AI] Exception: {}".format(e), e)
|
||||
need_retry = False
|
||||
self.sessions.clear_session(session.session_id)
|
||||
|
||||
if need_retry:
|
||||
logger.warn("[ZHIPU_AI] 第{}次重试".format(retry_count + 1))
|
||||
return self.reply_text(session, api_key, args, retry_count + 1)
|
||||
else:
|
||||
return result
|
||||
@@ -1,50 +1,103 @@
|
||||
from bot.bot_factory import create_bot
|
||||
from bridge.context import Context
|
||||
from bridge.reply import Reply
|
||||
from common.log import logger
|
||||
from bot import bot_factory
|
||||
from common.singleton import singleton
|
||||
from voice import voice_factory
|
||||
from config import conf
|
||||
from common import const
|
||||
from common.log import logger
|
||||
from common.singleton import singleton
|
||||
from config import conf
|
||||
from translate.factory import create_translator
|
||||
from voice.factory import create_voice
|
||||
|
||||
|
||||
@singleton
|
||||
class Bridge(object):
|
||||
def __init__(self):
|
||||
self.btype={
|
||||
self.btype = {
|
||||
"chat": const.CHATGPT,
|
||||
"voice_to_text": conf().get("voice_to_text", "openai"),
|
||||
"text_to_voice": conf().get("text_to_voice", "baidu")
|
||||
"text_to_voice": conf().get("text_to_voice", "google"),
|
||||
"translate": conf().get("translate", "baidu"),
|
||||
}
|
||||
model_type = conf().get("model")
|
||||
if model_type in ["text-davinci-003"]:
|
||||
self.btype['chat'] = const.OPEN_AI
|
||||
if conf().get("use_azure_chatgpt"):
|
||||
self.btype['chat'] = const.CHATGPTONAZURE
|
||||
self.bots={}
|
||||
# 这边取配置的模型
|
||||
bot_type = conf().get("bot_type")
|
||||
if bot_type:
|
||||
self.btype["chat"] = bot_type
|
||||
else:
|
||||
model_type = conf().get("model") or const.GPT35
|
||||
if model_type in ["text-davinci-003"]:
|
||||
self.btype["chat"] = const.OPEN_AI
|
||||
if conf().get("use_azure_chatgpt", False):
|
||||
self.btype["chat"] = const.CHATGPTONAZURE
|
||||
if model_type in ["wenxin", "wenxin-4"]:
|
||||
self.btype["chat"] = const.BAIDU
|
||||
if model_type in ["xunfei"]:
|
||||
self.btype["chat"] = const.XUNFEI
|
||||
if model_type in [const.QWEN]:
|
||||
self.btype["chat"] = const.QWEN
|
||||
if model_type in [const.QWEN_TURBO, const.QWEN_PLUS, const.QWEN_MAX]:
|
||||
self.btype["chat"] = const.QWEN_DASHSCOPE
|
||||
if model_type and model_type.startswith("gemini"):
|
||||
self.btype["chat"] = const.GEMINI
|
||||
if model_type in [const.ZHIPU_AI]:
|
||||
self.btype["chat"] = const.ZHIPU_AI
|
||||
if model_type and model_type.startswith("claude-3"):
|
||||
self.btype["chat"] = const.CLAUDEAPI
|
||||
|
||||
def get_bot(self,typename):
|
||||
if model_type in ["claude"]:
|
||||
self.btype["chat"] = const.CLAUDEAI
|
||||
|
||||
if model_type in ["moonshot-v1-8k", "moonshot-v1-32k", "moonshot-v1-128k"]:
|
||||
self.btype["chat"] = const.MOONSHOT
|
||||
|
||||
if model_type in ["abab6.5-chat"]:
|
||||
self.btype["chat"] = const.MiniMax
|
||||
|
||||
if conf().get("use_linkai") and conf().get("linkai_api_key"):
|
||||
self.btype["chat"] = const.LINKAI
|
||||
if not conf().get("voice_to_text") or conf().get("voice_to_text") in ["openai"]:
|
||||
self.btype["voice_to_text"] = const.LINKAI
|
||||
if not conf().get("text_to_voice") or conf().get("text_to_voice") in ["openai", const.TTS_1, const.TTS_1_HD]:
|
||||
self.btype["text_to_voice"] = const.LINKAI
|
||||
|
||||
self.bots = {}
|
||||
self.chat_bots = {}
|
||||
|
||||
# 模型对应的接口
|
||||
def get_bot(self, typename):
|
||||
if self.bots.get(typename) is None:
|
||||
logger.info("create bot {} for {}".format(self.btype[typename],typename))
|
||||
logger.info("create bot {} for {}".format(self.btype[typename], typename))
|
||||
if typename == "text_to_voice":
|
||||
self.bots[typename] = voice_factory.create_voice(self.btype[typename])
|
||||
self.bots[typename] = create_voice(self.btype[typename])
|
||||
elif typename == "voice_to_text":
|
||||
self.bots[typename] = voice_factory.create_voice(self.btype[typename])
|
||||
self.bots[typename] = create_voice(self.btype[typename])
|
||||
elif typename == "chat":
|
||||
self.bots[typename] = bot_factory.create_bot(self.btype[typename])
|
||||
self.bots[typename] = create_bot(self.btype[typename])
|
||||
elif typename == "translate":
|
||||
self.bots[typename] = create_translator(self.btype[typename])
|
||||
return self.bots[typename]
|
||||
|
||||
def get_bot_type(self,typename):
|
||||
|
||||
def get_bot_type(self, typename):
|
||||
return self.btype[typename]
|
||||
|
||||
|
||||
def fetch_reply_content(self, query, context : Context) -> Reply:
|
||||
def fetch_reply_content(self, query, context: Context) -> Reply:
|
||||
return self.get_bot("chat").reply(query, context)
|
||||
|
||||
|
||||
def fetch_voice_to_text(self, voiceFile) -> Reply:
|
||||
return self.get_bot("voice_to_text").voiceToText(voiceFile)
|
||||
|
||||
def fetch_text_to_voice(self, text) -> Reply:
|
||||
return self.get_bot("text_to_voice").textToVoice(text)
|
||||
|
||||
def fetch_translate(self, text, from_lang="", to_lang="en") -> Reply:
|
||||
return self.get_bot("translate").translate(text, from_lang, to_lang)
|
||||
|
||||
def find_chat_bot(self, bot_type: str):
|
||||
if self.chat_bots.get(bot_type) is None:
|
||||
self.chat_bots[bot_type] = create_bot(bot_type)
|
||||
return self.chat_bots.get(bot_type)
|
||||
|
||||
def reset_bot(self):
|
||||
"""
|
||||
重置bot路由
|
||||
"""
|
||||
self.__init__()
|
||||
|
||||
@@ -2,41 +2,70 @@
|
||||
|
||||
from enum import Enum
|
||||
|
||||
class ContextType (Enum):
|
||||
TEXT = 1 # 文本消息
|
||||
VOICE = 2 # 音频消息
|
||||
IMAGE_CREATE = 3 # 创建图片命令
|
||||
|
||||
|
||||
class ContextType(Enum):
|
||||
TEXT = 1 # 文本消息
|
||||
VOICE = 2 # 音频消息
|
||||
IMAGE = 3 # 图片消息
|
||||
FILE = 4 # 文件信息
|
||||
VIDEO = 5 # 视频信息
|
||||
SHARING = 6 # 分享信息
|
||||
|
||||
IMAGE_CREATE = 10 # 创建图片命令
|
||||
ACCEPT_FRIEND = 19 # 同意好友请求
|
||||
JOIN_GROUP = 20 # 加入群聊
|
||||
PATPAT = 21 # 拍了拍
|
||||
FUNCTION = 22 # 函数调用
|
||||
EXIT_GROUP = 23 #退出
|
||||
|
||||
|
||||
def __str__(self):
|
||||
return self.name
|
||||
|
||||
|
||||
class Context:
|
||||
def __init__(self, type : ContextType = None , content = None, kwargs = dict()):
|
||||
def __init__(self, type: ContextType = None, content=None, kwargs=dict()):
|
||||
self.type = type
|
||||
self.content = content
|
||||
self.kwargs = kwargs
|
||||
|
||||
def __contains__(self, key):
|
||||
if key == "type":
|
||||
return self.type is not None
|
||||
elif key == "content":
|
||||
return self.content is not None
|
||||
else:
|
||||
return key in self.kwargs
|
||||
|
||||
def __getitem__(self, key):
|
||||
if key == 'type':
|
||||
if key == "type":
|
||||
return self.type
|
||||
elif key == 'content':
|
||||
elif key == "content":
|
||||
return self.content
|
||||
else:
|
||||
return self.kwargs[key]
|
||||
|
||||
def get(self, key, default=None):
|
||||
try:
|
||||
return self[key]
|
||||
except KeyError:
|
||||
return default
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
if key == 'type':
|
||||
if key == "type":
|
||||
self.type = value
|
||||
elif key == 'content':
|
||||
elif key == "content":
|
||||
self.content = value
|
||||
else:
|
||||
self.kwargs[key] = value
|
||||
|
||||
def __delitem__(self, key):
|
||||
if key == 'type':
|
||||
if key == "type":
|
||||
self.type = None
|
||||
elif key == 'content':
|
||||
elif key == "content":
|
||||
self.content = None
|
||||
else:
|
||||
del self.kwargs[key]
|
||||
|
||||
|
||||
def __str__(self):
|
||||
return "Context(type={}, content={}, kwargs={})".format(self.type, self.content, self.kwargs)
|
||||
return "Context(type={}, content={}, kwargs={})".format(self.type, self.content, self.kwargs)
|
||||
|
||||
@@ -1,22 +1,31 @@
|
||||
|
||||
# encoding:utf-8
|
||||
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class ReplyType(Enum):
|
||||
TEXT = 1 # 文本
|
||||
VOICE = 2 # 音频文件
|
||||
IMAGE = 3 # 图片文件
|
||||
IMAGE_URL = 4 # 图片URL
|
||||
|
||||
TEXT = 1 # 文本
|
||||
VOICE = 2 # 音频文件
|
||||
IMAGE = 3 # 图片文件
|
||||
IMAGE_URL = 4 # 图片URL
|
||||
VIDEO_URL = 5 # 视频URL
|
||||
FILE = 6 # 文件
|
||||
CARD = 7 # 微信名片,仅支持ntchat
|
||||
INVITE_ROOM = 8 # 邀请好友进群
|
||||
INFO = 9
|
||||
ERROR = 10
|
||||
TEXT_ = 11 # 强制文本
|
||||
VIDEO = 12
|
||||
MINIAPP = 13 # 小程序
|
||||
|
||||
def __str__(self):
|
||||
return self.name
|
||||
|
||||
|
||||
class Reply:
|
||||
def __init__(self, type : ReplyType = None , content = None):
|
||||
def __init__(self, type: ReplyType = None, content=None):
|
||||
self.type = type
|
||||
self.content = content
|
||||
|
||||
def __str__(self):
|
||||
return "Reply(type={}, content={})".format(self.type, self.content)
|
||||
return "Reply(type={}, content={})".format(self.type, self.content)
|
||||
|
||||
@@ -4,9 +4,13 @@ Message sending channel abstract class
|
||||
|
||||
from bridge.bridge import Bridge
|
||||
from bridge.context import Context
|
||||
from bridge.reply import Reply
|
||||
from bridge.reply import *
|
||||
|
||||
|
||||
class Channel(object):
|
||||
channel_type = ""
|
||||
NOT_SUPPORT_REPLYTYPE = [ReplyType.VOICE, ReplyType.IMAGE]
|
||||
|
||||
def startup(self):
|
||||
"""
|
||||
init channel
|
||||
@@ -20,20 +24,21 @@ class Channel(object):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def send(self, msg, receiver):
|
||||
# 统一的发送函数,每个Channel自行实现,根据reply的type字段发送不同类型的消息
|
||||
def send(self, reply: Reply, context: Context):
|
||||
"""
|
||||
send message to user
|
||||
:param msg: message content
|
||||
:param receiver: receiver channel account
|
||||
:return:
|
||||
:return:
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def build_reply_content(self, query, context : Context=None) -> Reply:
|
||||
def build_reply_content(self, query, context: Context = None) -> Reply:
|
||||
return Bridge().fetch_reply_content(query, context)
|
||||
|
||||
def build_voice_to_text(self, voice_file) -> Reply:
|
||||
return Bridge().fetch_voice_to_text(voice_file)
|
||||
|
||||
|
||||
def build_text_to_voice(self, text) -> Reply:
|
||||
return Bridge().fetch_text_to_voice(text)
|
||||
|
||||
@@ -1,20 +1,45 @@
|
||||
"""
|
||||
channel factory
|
||||
"""
|
||||
from common import const
|
||||
from .channel import Channel
|
||||
|
||||
def create_channel(channel_type):
|
||||
|
||||
def create_channel(channel_type) -> Channel:
|
||||
"""
|
||||
create a channel instance
|
||||
:param channel_type: channel type code
|
||||
:return: channel instance
|
||||
"""
|
||||
if channel_type == 'wx':
|
||||
ch = Channel()
|
||||
if channel_type == "wx":
|
||||
from channel.wechat.wechat_channel import WechatChannel
|
||||
return WechatChannel()
|
||||
elif channel_type == 'wxy':
|
||||
ch = WechatChannel()
|
||||
elif channel_type == "wxy":
|
||||
from channel.wechat.wechaty_channel import WechatyChannel
|
||||
return WechatyChannel()
|
||||
elif channel_type == 'terminal':
|
||||
ch = WechatyChannel()
|
||||
elif channel_type == "terminal":
|
||||
from channel.terminal.terminal_channel import TerminalChannel
|
||||
return TerminalChannel()
|
||||
raise RuntimeError
|
||||
ch = TerminalChannel()
|
||||
elif channel_type == "wechatmp":
|
||||
from channel.wechatmp.wechatmp_channel import WechatMPChannel
|
||||
ch = WechatMPChannel(passive_reply=True)
|
||||
elif channel_type == "wechatmp_service":
|
||||
from channel.wechatmp.wechatmp_channel import WechatMPChannel
|
||||
ch = WechatMPChannel(passive_reply=False)
|
||||
elif channel_type == "wechatcom_app":
|
||||
from channel.wechatcom.wechatcomapp_channel import WechatComAppChannel
|
||||
ch = WechatComAppChannel()
|
||||
elif channel_type == "wework":
|
||||
from channel.wework.wework_channel import WeworkChannel
|
||||
ch = WeworkChannel()
|
||||
elif channel_type == const.FEISHU:
|
||||
from channel.feishu.feishu_channel import FeiShuChanel
|
||||
ch = FeiShuChanel()
|
||||
elif channel_type == const.DINGTALK:
|
||||
from channel.dingtalk.dingtalk_channel import DingTalkChanel
|
||||
ch = DingTalkChanel()
|
||||
else:
|
||||
raise RuntimeError
|
||||
ch.channel_type = channel_type
|
||||
return ch
|
||||
|
||||
396
channel/chat_channel.py
Normal file
396
channel/chat_channel.py
Normal file
@@ -0,0 +1,396 @@
|
||||
import os
|
||||
import re
|
||||
import threading
|
||||
import time
|
||||
from asyncio import CancelledError
|
||||
from concurrent.futures import Future, ThreadPoolExecutor
|
||||
|
||||
from bridge.context import *
|
||||
from bridge.reply import *
|
||||
from channel.channel import Channel
|
||||
from common.dequeue import Dequeue
|
||||
from common import memory
|
||||
from plugins import *
|
||||
|
||||
try:
|
||||
from voice.audio_convert import any_to_wav
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
handler_pool = ThreadPoolExecutor(max_workers=8) # 处理消息的线程池
|
||||
|
||||
|
||||
# 抽象类, 它包含了与消息通道无关的通用处理逻辑
|
||||
class ChatChannel(Channel):
|
||||
name = None # 登录的用户名
|
||||
user_id = None # 登录的用户id
|
||||
futures = {} # 记录每个session_id提交到线程池的future对象, 用于重置会话时把没执行的future取消掉,正在执行的不会被取消
|
||||
sessions = {} # 用于控制并发,每个session_id同时只能有一个context在处理
|
||||
lock = threading.Lock() # 用于控制对sessions的访问
|
||||
|
||||
def __init__(self):
|
||||
_thread = threading.Thread(target=self.consume)
|
||||
_thread.setDaemon(True)
|
||||
_thread.start()
|
||||
|
||||
# 根据消息构造context,消息内容相关的触发项写在这里
|
||||
def _compose_context(self, ctype: ContextType, content, **kwargs):
|
||||
context = Context(ctype, content)
|
||||
context.kwargs = kwargs
|
||||
# context首次传入时,origin_ctype是None,
|
||||
# 引入的起因是:当输入语音时,会嵌套生成两个context,第一步语音转文本,第二步通过文本生成文字回复。
|
||||
# origin_ctype用于第二步文本回复时,判断是否需要匹配前缀,如果是私聊的语音,就不需要匹配前缀
|
||||
if "origin_ctype" not in context:
|
||||
context["origin_ctype"] = ctype
|
||||
# context首次传入时,receiver是None,根据类型设置receiver
|
||||
first_in = "receiver" not in context
|
||||
# 群名匹配过程,设置session_id和receiver
|
||||
if first_in: # context首次传入时,receiver是None,根据类型设置receiver
|
||||
config = conf()
|
||||
cmsg = context["msg"]
|
||||
user_data = conf().get_user_data(cmsg.from_user_id)
|
||||
context["openai_api_key"] = user_data.get("openai_api_key")
|
||||
context["gpt_model"] = user_data.get("gpt_model")
|
||||
if context.get("isgroup", False):
|
||||
group_name = cmsg.other_user_nickname
|
||||
group_id = cmsg.other_user_id
|
||||
|
||||
group_name_white_list = config.get("group_name_white_list", [])
|
||||
group_name_keyword_white_list = config.get("group_name_keyword_white_list", [])
|
||||
if any(
|
||||
[
|
||||
group_name in group_name_white_list,
|
||||
"ALL_GROUP" in group_name_white_list,
|
||||
check_contain(group_name, group_name_keyword_white_list),
|
||||
]
|
||||
):
|
||||
group_chat_in_one_session = conf().get("group_chat_in_one_session", [])
|
||||
session_id = cmsg.actual_user_id
|
||||
if any(
|
||||
[
|
||||
group_name in group_chat_in_one_session,
|
||||
"ALL_GROUP" in group_chat_in_one_session,
|
||||
]
|
||||
):
|
||||
session_id = group_id
|
||||
else:
|
||||
logger.debug(f"No need reply, groupName not in whitelist, group_name={group_name}")
|
||||
return None
|
||||
context["session_id"] = session_id
|
||||
context["receiver"] = group_id
|
||||
else:
|
||||
context["session_id"] = cmsg.other_user_id
|
||||
context["receiver"] = cmsg.other_user_id
|
||||
e_context = PluginManager().emit_event(EventContext(Event.ON_RECEIVE_MESSAGE, {"channel": self, "context": context}))
|
||||
context = e_context["context"]
|
||||
if e_context.is_pass() or context is None:
|
||||
return context
|
||||
if cmsg.from_user_id == self.user_id and not config.get("trigger_by_self", True):
|
||||
logger.debug("[chat_channel]self message skipped")
|
||||
return None
|
||||
|
||||
# 消息内容匹配过程,并处理content
|
||||
if ctype == ContextType.TEXT:
|
||||
if first_in and "」\n- - - - - - -" in content: # 初次匹配 过滤引用消息
|
||||
logger.debug(content)
|
||||
logger.debug("[chat_channel]reference query skipped")
|
||||
return None
|
||||
|
||||
nick_name_black_list = conf().get("nick_name_black_list", [])
|
||||
if context.get("isgroup", False): # 群聊
|
||||
# 校验关键字
|
||||
match_prefix = check_prefix(content, conf().get("group_chat_prefix"))
|
||||
match_contain = check_contain(content, conf().get("group_chat_keyword"))
|
||||
flag = False
|
||||
if context["msg"].to_user_id != context["msg"].actual_user_id:
|
||||
if match_prefix is not None or match_contain is not None:
|
||||
flag = True
|
||||
if match_prefix:
|
||||
content = content.replace(match_prefix, "", 1).strip()
|
||||
if context["msg"].is_at:
|
||||
nick_name = context["msg"].actual_user_nickname
|
||||
if nick_name and nick_name in nick_name_black_list:
|
||||
# 黑名单过滤
|
||||
logger.warning(f"[chat_channel] Nickname {nick_name} in In BlackList, ignore")
|
||||
return None
|
||||
|
||||
logger.info("[chat_channel]receive group at")
|
||||
if not conf().get("group_at_off", False):
|
||||
flag = True
|
||||
self.name = self.name if self.name is not None else "" # 部分渠道self.name可能没有赋值
|
||||
pattern = f"@{re.escape(self.name)}(\u2005|\u0020)"
|
||||
subtract_res = re.sub(pattern, r"", content)
|
||||
if isinstance(context["msg"].at_list, list):
|
||||
for at in context["msg"].at_list:
|
||||
pattern = f"@{re.escape(at)}(\u2005|\u0020)"
|
||||
subtract_res = re.sub(pattern, r"", subtract_res)
|
||||
if subtract_res == content and context["msg"].self_display_name:
|
||||
# 前缀移除后没有变化,使用群昵称再次移除
|
||||
pattern = f"@{re.escape(context['msg'].self_display_name)}(\u2005|\u0020)"
|
||||
subtract_res = re.sub(pattern, r"", content)
|
||||
content = subtract_res
|
||||
if not flag:
|
||||
if context["origin_ctype"] == ContextType.VOICE:
|
||||
logger.info("[chat_channel]receive group voice, but checkprefix didn't match")
|
||||
return None
|
||||
else: # 单聊
|
||||
nick_name = context["msg"].from_user_nickname
|
||||
if nick_name and nick_name in nick_name_black_list:
|
||||
# 黑名单过滤
|
||||
logger.warning(f"[chat_channel] Nickname '{nick_name}' in In BlackList, ignore")
|
||||
return None
|
||||
|
||||
match_prefix = check_prefix(content, conf().get("single_chat_prefix", [""]))
|
||||
if match_prefix is not None: # 判断如果匹配到自定义前缀,则返回过滤掉前缀+空格后的内容
|
||||
content = content.replace(match_prefix, "", 1).strip()
|
||||
elif context["origin_ctype"] == ContextType.VOICE: # 如果源消息是私聊的语音消息,允许不匹配前缀,放宽条件
|
||||
pass
|
||||
else:
|
||||
return None
|
||||
content = content.strip()
|
||||
img_match_prefix = check_prefix(content, conf().get("image_create_prefix",[""]))
|
||||
if img_match_prefix:
|
||||
content = content.replace(img_match_prefix, "", 1)
|
||||
context.type = ContextType.IMAGE_CREATE
|
||||
else:
|
||||
context.type = ContextType.TEXT
|
||||
context.content = content.strip()
|
||||
if "desire_rtype" not in context and conf().get("always_reply_voice") and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE:
|
||||
context["desire_rtype"] = ReplyType.VOICE
|
||||
elif context.type == ContextType.VOICE:
|
||||
if "desire_rtype" not in context and conf().get("voice_reply_voice") and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE:
|
||||
context["desire_rtype"] = ReplyType.VOICE
|
||||
return context
|
||||
|
||||
def _handle(self, context: Context):
|
||||
if context is None or not context.content:
|
||||
return
|
||||
logger.debug("[chat_channel] ready to handle context: {}".format(context))
|
||||
# reply的构建步骤
|
||||
reply = self._generate_reply(context)
|
||||
|
||||
logger.debug("[chat_channel] ready to decorate reply: {}".format(reply))
|
||||
|
||||
# reply的包装步骤
|
||||
if reply and reply.content:
|
||||
reply = self._decorate_reply(context, reply)
|
||||
|
||||
# reply的发送步骤
|
||||
self._send_reply(context, reply)
|
||||
|
||||
def _generate_reply(self, context: Context, reply: Reply = Reply()) -> Reply:
|
||||
e_context = PluginManager().emit_event(
|
||||
EventContext(
|
||||
Event.ON_HANDLE_CONTEXT,
|
||||
{"channel": self, "context": context, "reply": reply},
|
||||
)
|
||||
)
|
||||
reply = e_context["reply"]
|
||||
if not e_context.is_pass():
|
||||
logger.debug("[chat_channel] ready to handle context: type={}, content={}".format(context.type, context.content))
|
||||
if context.type == ContextType.TEXT or context.type == ContextType.IMAGE_CREATE: # 文字和图片消息
|
||||
context["channel"] = e_context["channel"]
|
||||
reply = super().build_reply_content(context.content, context)
|
||||
elif context.type == ContextType.VOICE: # 语音消息
|
||||
cmsg = context["msg"]
|
||||
cmsg.prepare()
|
||||
file_path = context.content
|
||||
wav_path = os.path.splitext(file_path)[0] + ".wav"
|
||||
try:
|
||||
any_to_wav(file_path, wav_path)
|
||||
except Exception as e: # 转换失败,直接使用mp3,对于某些api,mp3也可以识别
|
||||
logger.warning("[chat_channel]any to wav error, use raw path. " + str(e))
|
||||
wav_path = file_path
|
||||
# 语音识别
|
||||
reply = super().build_voice_to_text(wav_path)
|
||||
# 删除临时文件
|
||||
try:
|
||||
os.remove(file_path)
|
||||
if wav_path != file_path:
|
||||
os.remove(wav_path)
|
||||
except Exception as e:
|
||||
pass
|
||||
# logger.warning("[chat_channel]delete temp file error: " + str(e))
|
||||
|
||||
if reply.type == ReplyType.TEXT:
|
||||
new_context = self._compose_context(ContextType.TEXT, reply.content, **context.kwargs)
|
||||
if new_context:
|
||||
reply = self._generate_reply(new_context)
|
||||
else:
|
||||
return
|
||||
elif context.type == ContextType.IMAGE: # 图片消息,当前仅做下载保存到本地的逻辑
|
||||
memory.USER_IMAGE_CACHE[context["session_id"]] = {
|
||||
"path": context.content,
|
||||
"msg": context.get("msg")
|
||||
}
|
||||
elif context.type == ContextType.SHARING: # 分享信息,当前无默认逻辑
|
||||
pass
|
||||
elif context.type == ContextType.FUNCTION or context.type == ContextType.FILE: # 文件消息及函数调用等,当前无默认逻辑
|
||||
pass
|
||||
else:
|
||||
logger.warning("[chat_channel] unknown context type: {}".format(context.type))
|
||||
return
|
||||
return reply
|
||||
|
||||
def _decorate_reply(self, context: Context, reply: Reply) -> Reply:
|
||||
if reply and reply.type:
|
||||
e_context = PluginManager().emit_event(
|
||||
EventContext(
|
||||
Event.ON_DECORATE_REPLY,
|
||||
{"channel": self, "context": context, "reply": reply},
|
||||
)
|
||||
)
|
||||
reply = e_context["reply"]
|
||||
desire_rtype = context.get("desire_rtype")
|
||||
if not e_context.is_pass() and reply and reply.type:
|
||||
if reply.type in self.NOT_SUPPORT_REPLYTYPE:
|
||||
logger.error("[chat_channel]reply type not support: " + str(reply.type))
|
||||
reply.type = ReplyType.ERROR
|
||||
reply.content = "不支持发送的消息类型: " + str(reply.type)
|
||||
|
||||
if reply.type == ReplyType.TEXT:
|
||||
reply_text = reply.content
|
||||
if desire_rtype == ReplyType.VOICE and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE:
|
||||
reply = super().build_text_to_voice(reply.content)
|
||||
return self._decorate_reply(context, reply)
|
||||
if context.get("isgroup", False):
|
||||
if not context.get("no_need_at", False):
|
||||
reply_text = "@" + context["msg"].actual_user_nickname + "\n" + reply_text.strip()
|
||||
reply_text = conf().get("group_chat_reply_prefix", "") + reply_text + conf().get("group_chat_reply_suffix", "")
|
||||
else:
|
||||
reply_text = conf().get("single_chat_reply_prefix", "") + reply_text + conf().get("single_chat_reply_suffix", "")
|
||||
reply.content = reply_text
|
||||
elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO:
|
||||
reply.content = "[" + str(reply.type) + "]\n" + reply.content
|
||||
elif reply.type == ReplyType.IMAGE_URL or reply.type == ReplyType.VOICE or reply.type == ReplyType.IMAGE or reply.type == ReplyType.FILE or reply.type == ReplyType.VIDEO or reply.type == ReplyType.VIDEO_URL:
|
||||
pass
|
||||
else:
|
||||
logger.error("[chat_channel] unknown reply type: {}".format(reply.type))
|
||||
return
|
||||
if desire_rtype and desire_rtype != reply.type and reply.type not in [ReplyType.ERROR, ReplyType.INFO]:
|
||||
logger.warning("[chat_channel] desire_rtype: {}, but reply type: {}".format(context.get("desire_rtype"), reply.type))
|
||||
return reply
|
||||
|
||||
def _send_reply(self, context: Context, reply: Reply):
|
||||
if reply and reply.type:
|
||||
e_context = PluginManager().emit_event(
|
||||
EventContext(
|
||||
Event.ON_SEND_REPLY,
|
||||
{"channel": self, "context": context, "reply": reply},
|
||||
)
|
||||
)
|
||||
reply = e_context["reply"]
|
||||
if not e_context.is_pass() and reply and reply.type:
|
||||
logger.debug("[chat_channel] ready to send reply: {}, context: {}".format(reply, context))
|
||||
self._send(reply, context)
|
||||
|
||||
def _send(self, reply: Reply, context: Context, retry_cnt=0):
|
||||
try:
|
||||
self.send(reply, context)
|
||||
except Exception as e:
|
||||
logger.error("[chat_channel] sendMsg error: {}".format(str(e)))
|
||||
if isinstance(e, NotImplementedError):
|
||||
return
|
||||
logger.exception(e)
|
||||
if retry_cnt < 2:
|
||||
time.sleep(3 + 3 * retry_cnt)
|
||||
self._send(reply, context, retry_cnt + 1)
|
||||
|
||||
def _success_callback(self, session_id, **kwargs): # 线程正常结束时的回调函数
|
||||
logger.debug("Worker return success, session_id = {}".format(session_id))
|
||||
|
||||
def _fail_callback(self, session_id, exception, **kwargs): # 线程异常结束时的回调函数
|
||||
logger.exception("Worker return exception: {}".format(exception))
|
||||
|
||||
def _thread_pool_callback(self, session_id, **kwargs):
|
||||
def func(worker: Future):
|
||||
try:
|
||||
worker_exception = worker.exception()
|
||||
if worker_exception:
|
||||
self._fail_callback(session_id, exception=worker_exception, **kwargs)
|
||||
else:
|
||||
self._success_callback(session_id, **kwargs)
|
||||
except CancelledError as e:
|
||||
logger.info("Worker cancelled, session_id = {}".format(session_id))
|
||||
except Exception as e:
|
||||
logger.exception("Worker raise exception: {}".format(e))
|
||||
with self.lock:
|
||||
self.sessions[session_id][1].release()
|
||||
|
||||
return func
|
||||
|
||||
def produce(self, context: Context):
|
||||
session_id = context["session_id"]
|
||||
with self.lock:
|
||||
if session_id not in self.sessions:
|
||||
self.sessions[session_id] = [
|
||||
Dequeue(),
|
||||
threading.BoundedSemaphore(conf().get("concurrency_in_session", 4)),
|
||||
]
|
||||
if context.type == ContextType.TEXT and context.content.startswith("#"):
|
||||
self.sessions[session_id][0].putleft(context) # 优先处理管理命令
|
||||
else:
|
||||
self.sessions[session_id][0].put(context)
|
||||
|
||||
# 消费者函数,单独线程,用于从消息队列中取出消息并处理
|
||||
def consume(self):
|
||||
while True:
|
||||
with self.lock:
|
||||
session_ids = list(self.sessions.keys())
|
||||
for session_id in session_ids:
|
||||
context_queue, semaphore = self.sessions[session_id]
|
||||
if semaphore.acquire(blocking=False): # 等线程处理完毕才能删除
|
||||
if not context_queue.empty():
|
||||
context = context_queue.get()
|
||||
logger.debug("[chat_channel] consume context: {}".format(context))
|
||||
future: Future = handler_pool.submit(self._handle, context)
|
||||
future.add_done_callback(self._thread_pool_callback(session_id, context=context))
|
||||
if session_id not in self.futures:
|
||||
self.futures[session_id] = []
|
||||
self.futures[session_id].append(future)
|
||||
elif semaphore._initial_value == semaphore._value + 1: # 除了当前,没有任务再申请到信号量,说明所有任务都处理完毕
|
||||
self.futures[session_id] = [t for t in self.futures[session_id] if not t.done()]
|
||||
assert len(self.futures[session_id]) == 0, "thread pool error"
|
||||
del self.sessions[session_id]
|
||||
else:
|
||||
semaphore.release()
|
||||
time.sleep(0.1)
|
||||
|
||||
# 取消session_id对应的所有任务,只能取消排队的消息和已提交线程池但未执行的任务
|
||||
def cancel_session(self, session_id):
|
||||
with self.lock:
|
||||
if session_id in self.sessions:
|
||||
for future in self.futures[session_id]:
|
||||
future.cancel()
|
||||
cnt = self.sessions[session_id][0].qsize()
|
||||
if cnt > 0:
|
||||
logger.info("Cancel {} messages in session {}".format(cnt, session_id))
|
||||
self.sessions[session_id][0] = Dequeue()
|
||||
|
||||
def cancel_all_session(self):
|
||||
with self.lock:
|
||||
for session_id in self.sessions:
|
||||
for future in self.futures[session_id]:
|
||||
future.cancel()
|
||||
cnt = self.sessions[session_id][0].qsize()
|
||||
if cnt > 0:
|
||||
logger.info("Cancel {} messages in session {}".format(cnt, session_id))
|
||||
self.sessions[session_id][0] = Dequeue()
|
||||
|
||||
|
||||
def check_prefix(content, prefix_list):
|
||||
if not prefix_list:
|
||||
return None
|
||||
for prefix in prefix_list:
|
||||
if content.startswith(prefix):
|
||||
return prefix
|
||||
return None
|
||||
|
||||
|
||||
def check_contain(content, keyword_list):
|
||||
if not keyword_list:
|
||||
return None
|
||||
for ky in keyword_list:
|
||||
if content.find(ky) != -1:
|
||||
return True
|
||||
return None
|
||||
87
channel/chat_message.py
Normal file
87
channel/chat_message.py
Normal file
@@ -0,0 +1,87 @@
|
||||
"""
|
||||
本类表示聊天消息,用于对itchat和wechaty的消息进行统一的封装。
|
||||
|
||||
填好必填项(群聊6个,非群聊8个),即可接入ChatChannel,并支持插件,参考TerminalChannel
|
||||
|
||||
ChatMessage
|
||||
msg_id: 消息id (必填)
|
||||
create_time: 消息创建时间
|
||||
|
||||
ctype: 消息类型 : ContextType (必填)
|
||||
content: 消息内容, 如果是声音/图片,这里是文件路径 (必填)
|
||||
|
||||
from_user_id: 发送者id (必填)
|
||||
from_user_nickname: 发送者昵称
|
||||
to_user_id: 接收者id (必填)
|
||||
to_user_nickname: 接收者昵称
|
||||
|
||||
other_user_id: 对方的id,如果你是发送者,那这个就是接收者id,如果你是接收者,那这个就是发送者id,如果是群消息,那这一直是群id (必填)
|
||||
other_user_nickname: 同上
|
||||
|
||||
is_group: 是否是群消息 (群聊必填)
|
||||
is_at: 是否被at
|
||||
|
||||
- (群消息时,一般会存在实际发送者,是群内某个成员的id和昵称,下列项仅在群消息时存在)
|
||||
actual_user_id: 实际发送者id (群聊必填)
|
||||
actual_user_nickname:实际发送者昵称
|
||||
self_display_name: 自身的展示名,设置群昵称时,该字段表示群昵称
|
||||
|
||||
_prepare_fn: 准备函数,用于准备消息的内容,比如下载图片等,
|
||||
_prepared: 是否已经调用过准备函数
|
||||
_rawmsg: 原始消息对象
|
||||
|
||||
"""
|
||||
|
||||
|
||||
class ChatMessage(object):
|
||||
msg_id = None
|
||||
create_time = None
|
||||
|
||||
ctype = None
|
||||
content = None
|
||||
|
||||
from_user_id = None
|
||||
from_user_nickname = None
|
||||
to_user_id = None
|
||||
to_user_nickname = None
|
||||
other_user_id = None
|
||||
other_user_nickname = None
|
||||
my_msg = False
|
||||
self_display_name = None
|
||||
|
||||
is_group = False
|
||||
is_at = False
|
||||
actual_user_id = None
|
||||
actual_user_nickname = None
|
||||
at_list = None
|
||||
|
||||
_prepare_fn = None
|
||||
_prepared = False
|
||||
_rawmsg = None
|
||||
|
||||
def __init__(self, _rawmsg):
|
||||
self._rawmsg = _rawmsg
|
||||
|
||||
def prepare(self):
|
||||
if self._prepare_fn and not self._prepared:
|
||||
self._prepared = True
|
||||
self._prepare_fn()
|
||||
|
||||
def __str__(self):
|
||||
return "ChatMessage: id={}, create_time={}, ctype={}, content={}, from_user_id={}, from_user_nickname={}, to_user_id={}, to_user_nickname={}, other_user_id={}, other_user_nickname={}, is_group={}, is_at={}, actual_user_id={}, actual_user_nickname={}, at_list={}".format(
|
||||
self.msg_id,
|
||||
self.create_time,
|
||||
self.ctype,
|
||||
self.content,
|
||||
self.from_user_id,
|
||||
self.from_user_nickname,
|
||||
self.to_user_id,
|
||||
self.to_user_nickname,
|
||||
self.other_user_id,
|
||||
self.other_user_nickname,
|
||||
self.is_group,
|
||||
self.is_at,
|
||||
self.actual_user_id,
|
||||
self.actual_user_nickname,
|
||||
self.at_list
|
||||
)
|
||||
225
channel/dingtalk/dingtalk_channel.py
Normal file
225
channel/dingtalk/dingtalk_channel.py
Normal file
@@ -0,0 +1,225 @@
|
||||
"""
|
||||
钉钉通道接入
|
||||
|
||||
@author huiwen
|
||||
@Date 2023/11/28
|
||||
"""
|
||||
import copy
|
||||
import json
|
||||
# -*- coding=utf-8 -*-
|
||||
import logging
|
||||
import time
|
||||
|
||||
import dingtalk_stream
|
||||
from dingtalk_stream import AckMessage
|
||||
from dingtalk_stream.card_replier import AICardReplier
|
||||
from dingtalk_stream.card_replier import AICardStatus
|
||||
from dingtalk_stream.card_replier import CardReplier
|
||||
|
||||
from bridge.context import Context, ContextType
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from channel.chat_channel import ChatChannel
|
||||
from channel.dingtalk.dingtalk_message import DingTalkMessage
|
||||
from common.expired_dict import ExpiredDict
|
||||
from common.log import logger
|
||||
from common.singleton import singleton
|
||||
from common.time_check import time_checker
|
||||
from config import conf
|
||||
|
||||
|
||||
class CustomAICardReplier(CardReplier):
|
||||
def __init__(self, dingtalk_client, incoming_message):
|
||||
super(AICardReplier, self).__init__(dingtalk_client, incoming_message)
|
||||
|
||||
def start(
|
||||
self,
|
||||
card_template_id: str,
|
||||
card_data: dict,
|
||||
recipients: list = None,
|
||||
support_forward: bool = True,
|
||||
) -> str:
|
||||
"""
|
||||
AI卡片的创建接口
|
||||
:param support_forward:
|
||||
:param recipients:
|
||||
:param card_template_id:
|
||||
:param card_data:
|
||||
:return:
|
||||
"""
|
||||
card_data_with_status = copy.deepcopy(card_data)
|
||||
card_data_with_status["flowStatus"] = AICardStatus.PROCESSING
|
||||
return self.create_and_send_card(
|
||||
card_template_id,
|
||||
card_data_with_status,
|
||||
at_sender=True,
|
||||
at_all=False,
|
||||
recipients=recipients,
|
||||
support_forward=support_forward,
|
||||
)
|
||||
|
||||
|
||||
# 对 AICardReplier 进行猴子补丁
|
||||
AICardReplier.start = CustomAICardReplier.start
|
||||
|
||||
|
||||
def _check(func):
|
||||
def wrapper(self, cmsg: DingTalkMessage):
|
||||
msgId = cmsg.msg_id
|
||||
if msgId in self.receivedMsgs:
|
||||
logger.info("DingTalk message {} already received, ignore".format(msgId))
|
||||
return
|
||||
self.receivedMsgs[msgId] = True
|
||||
create_time = cmsg.create_time # 消息时间戳
|
||||
if conf().get("hot_reload") == True and int(create_time) < int(time.time()) - 60: # 跳过1分钟前的历史消息
|
||||
logger.debug("[DingTalk] History message {} skipped".format(msgId))
|
||||
return
|
||||
if cmsg.my_msg and not cmsg.is_group:
|
||||
logger.debug("[DingTalk] My message {} skipped".format(msgId))
|
||||
return
|
||||
return func(self, cmsg)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
@singleton
|
||||
class DingTalkChanel(ChatChannel, dingtalk_stream.ChatbotHandler):
|
||||
dingtalk_client_id = conf().get('dingtalk_client_id')
|
||||
dingtalk_client_secret = conf().get('dingtalk_client_secret')
|
||||
|
||||
def setup_logger(self):
|
||||
logger = logging.getLogger()
|
||||
handler = logging.StreamHandler()
|
||||
handler.setFormatter(
|
||||
logging.Formatter('%(asctime)s %(name)-8s %(levelname)-8s %(message)s [%(filename)s:%(lineno)d]'))
|
||||
logger.addHandler(handler)
|
||||
logger.setLevel(logging.INFO)
|
||||
return logger
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
super(dingtalk_stream.ChatbotHandler, self).__init__()
|
||||
self.logger = self.setup_logger()
|
||||
# 历史消息id暂存,用于幂等控制
|
||||
self.receivedMsgs = ExpiredDict(conf().get("expires_in_seconds"))
|
||||
logger.info("[DingTalk] client_id={}, client_secret={} ".format(
|
||||
self.dingtalk_client_id, self.dingtalk_client_secret))
|
||||
# 无需群校验和前缀
|
||||
conf()["group_name_white_list"] = ["ALL_GROUP"]
|
||||
# 单聊无需前缀
|
||||
conf()["single_chat_prefix"] = [""]
|
||||
|
||||
def startup(self):
|
||||
credential = dingtalk_stream.Credential(self.dingtalk_client_id, self.dingtalk_client_secret)
|
||||
client = dingtalk_stream.DingTalkStreamClient(credential)
|
||||
client.register_callback_handler(dingtalk_stream.chatbot.ChatbotMessage.TOPIC, self)
|
||||
client.start_forever()
|
||||
|
||||
async def process(self, callback: dingtalk_stream.CallbackMessage):
|
||||
try:
|
||||
incoming_message = dingtalk_stream.ChatbotMessage.from_dict(callback.data)
|
||||
image_download_handler = self # 传入方法所在的类实例
|
||||
dingtalk_msg = DingTalkMessage(incoming_message, image_download_handler)
|
||||
|
||||
if dingtalk_msg.is_group:
|
||||
self.handle_group(dingtalk_msg)
|
||||
else:
|
||||
self.handle_single(dingtalk_msg)
|
||||
return AckMessage.STATUS_OK, 'OK'
|
||||
except Exception as e:
|
||||
logger.error(f"dingtalk process error={e}")
|
||||
return AckMessage.STATUS_SYSTEM_EXCEPTION, 'ERROR'
|
||||
|
||||
@time_checker
|
||||
@_check
|
||||
def handle_single(self, cmsg: DingTalkMessage):
|
||||
# 处理单聊消息
|
||||
if cmsg.ctype == ContextType.VOICE:
|
||||
logger.debug("[DingTalk]receive voice msg: {}".format(cmsg.content))
|
||||
elif cmsg.ctype == ContextType.IMAGE:
|
||||
logger.debug("[DingTalk]receive image msg: {}".format(cmsg.content))
|
||||
elif cmsg.ctype == ContextType.IMAGE_CREATE:
|
||||
logger.debug("[DingTalk]receive image create msg: {}".format(cmsg.content))
|
||||
elif cmsg.ctype == ContextType.PATPAT:
|
||||
logger.debug("[DingTalk]receive patpat msg: {}".format(cmsg.content))
|
||||
elif cmsg.ctype == ContextType.TEXT:
|
||||
logger.debug("[DingTalk]receive text msg: {}".format(cmsg.content))
|
||||
else:
|
||||
logger.debug("[DingTalk]receive other msg: {}".format(cmsg.content))
|
||||
context = self._compose_context(cmsg.ctype, cmsg.content, isgroup=False, msg=cmsg)
|
||||
if context:
|
||||
self.produce(context)
|
||||
|
||||
|
||||
@time_checker
|
||||
@_check
|
||||
def handle_group(self, cmsg: DingTalkMessage):
|
||||
# 处理群聊消息
|
||||
if cmsg.ctype == ContextType.VOICE:
|
||||
logger.debug("[DingTalk]receive voice msg: {}".format(cmsg.content))
|
||||
elif cmsg.ctype == ContextType.IMAGE:
|
||||
logger.debug("[DingTalk]receive image msg: {}".format(cmsg.content))
|
||||
elif cmsg.ctype == ContextType.IMAGE_CREATE:
|
||||
logger.debug("[DingTalk]receive image create msg: {}".format(cmsg.content))
|
||||
elif cmsg.ctype == ContextType.PATPAT:
|
||||
logger.debug("[DingTalk]receive patpat msg: {}".format(cmsg.content))
|
||||
elif cmsg.ctype == ContextType.TEXT:
|
||||
logger.debug("[DingTalk]receive text msg: {}".format(cmsg.content))
|
||||
else:
|
||||
logger.debug("[DingTalk]receive other msg: {}".format(cmsg.content))
|
||||
context = self._compose_context(cmsg.ctype, cmsg.content, isgroup=True, msg=cmsg)
|
||||
context['no_need_at'] = True
|
||||
if context:
|
||||
self.produce(context)
|
||||
|
||||
|
||||
def send(self, reply: Reply, context: Context):
|
||||
receiver = context["receiver"]
|
||||
isgroup = context.kwargs['msg'].is_group
|
||||
incoming_message = context.kwargs['msg'].incoming_message
|
||||
|
||||
if conf().get("dingtalk_card_enabled"):
|
||||
logger.info("[Dingtalk] sendMsg={}, receiver={}".format(reply, receiver))
|
||||
def reply_with_text():
|
||||
self.reply_text(reply.content, incoming_message)
|
||||
def reply_with_at_text():
|
||||
self.reply_text("📢 您有一条新的消息,请查看。", incoming_message)
|
||||
def reply_with_ai_markdown():
|
||||
button_list, markdown_content = self.generate_button_markdown_content(context, reply)
|
||||
self.reply_ai_markdown_button(incoming_message, markdown_content, button_list, "", "📌 内容由AI生成", "",[incoming_message.sender_staff_id])
|
||||
|
||||
if reply.type in [ReplyType.IMAGE_URL, ReplyType.IMAGE, ReplyType.TEXT]:
|
||||
if isgroup:
|
||||
reply_with_ai_markdown()
|
||||
reply_with_at_text()
|
||||
else:
|
||||
reply_with_ai_markdown()
|
||||
else:
|
||||
# 暂不支持其它类型消息回复
|
||||
reply_with_text()
|
||||
else:
|
||||
self.reply_text(reply.content, incoming_message)
|
||||
|
||||
|
||||
def generate_button_markdown_content(self, context, reply):
|
||||
image_url = context.kwargs.get("image_url")
|
||||
promptEn = context.kwargs.get("promptEn")
|
||||
reply_text = reply.content
|
||||
button_list = []
|
||||
markdown_content = f"""
|
||||
{reply.content}
|
||||
"""
|
||||
if image_url is not None and promptEn is not None:
|
||||
button_list = [
|
||||
{"text": "查看原图", "url": image_url, "iosUrl": image_url, "color": "blue"}
|
||||
]
|
||||
markdown_content = f"""
|
||||
{promptEn}
|
||||
|
||||

|
||||
|
||||
{reply_text}
|
||||
|
||||
"""
|
||||
logger.debug(f"[Dingtalk] generate_button_markdown_content, button_list={button_list} , markdown_content={markdown_content}")
|
||||
|
||||
return button_list, markdown_content
|
||||
84
channel/dingtalk/dingtalk_message.py
Normal file
84
channel/dingtalk/dingtalk_message.py
Normal file
@@ -0,0 +1,84 @@
|
||||
import os
|
||||
|
||||
import requests
|
||||
from dingtalk_stream import ChatbotMessage
|
||||
|
||||
from bridge.context import ContextType
|
||||
from channel.chat_message import ChatMessage
|
||||
# -*- coding=utf-8 -*-
|
||||
from common.log import logger
|
||||
from common.tmp_dir import TmpDir
|
||||
|
||||
|
||||
class DingTalkMessage(ChatMessage):
|
||||
def __init__(self, event: ChatbotMessage, image_download_handler):
|
||||
super().__init__(event)
|
||||
self.image_download_handler = image_download_handler
|
||||
self.msg_id = event.message_id
|
||||
self.message_type = event.message_type
|
||||
self.incoming_message = event
|
||||
self.sender_staff_id = event.sender_staff_id
|
||||
self.other_user_id = event.conversation_id
|
||||
self.create_time = event.create_at
|
||||
self.image_content = event.image_content
|
||||
self.rich_text_content = event.rich_text_content
|
||||
if event.conversation_type == "1":
|
||||
self.is_group = False
|
||||
else:
|
||||
self.is_group = True
|
||||
|
||||
if self.message_type == "text":
|
||||
self.ctype = ContextType.TEXT
|
||||
|
||||
self.content = event.text.content.strip()
|
||||
elif self.message_type == "audio":
|
||||
# 钉钉支持直接识别语音,所以此处将直接提取文字,当文字处理
|
||||
self.content = event.extensions['content']['recognition'].strip()
|
||||
self.ctype = ContextType.TEXT
|
||||
elif (self.message_type == 'picture') or (self.message_type == 'richText'):
|
||||
self.ctype = ContextType.IMAGE
|
||||
# 钉钉图片类型或富文本类型消息处理
|
||||
image_list = event.get_image_list()
|
||||
if len(image_list) > 0:
|
||||
download_code = image_list[0]
|
||||
download_url = image_download_handler.get_image_download_url(download_code)
|
||||
self.content = download_image_file(download_url, TmpDir().path())
|
||||
else:
|
||||
logger.debug(f"[Dingtalk] messageType :{self.message_type} , imageList isEmpty")
|
||||
|
||||
if self.is_group:
|
||||
self.from_user_id = event.conversation_id
|
||||
self.actual_user_id = event.sender_id
|
||||
self.is_at = True
|
||||
else:
|
||||
self.from_user_id = event.sender_id
|
||||
self.actual_user_id = event.sender_id
|
||||
self.to_user_id = event.chatbot_user_id
|
||||
self.other_user_nickname = event.conversation_title
|
||||
|
||||
|
||||
def download_image_file(image_url, temp_dir):
|
||||
headers = {
|
||||
'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/112.0.0.0 Safari/537.36'
|
||||
}
|
||||
# 设置代理
|
||||
# self.proxies
|
||||
# , proxies=self.proxies
|
||||
response = requests.get(image_url, headers=headers, stream=True, timeout=60 * 5)
|
||||
if response.status_code == 200:
|
||||
|
||||
# 生成文件名
|
||||
file_name = image_url.split("/")[-1].split("?")[0]
|
||||
|
||||
# 检查临时目录是否存在,如果不存在则创建
|
||||
if not os.path.exists(temp_dir):
|
||||
os.makedirs(temp_dir)
|
||||
|
||||
# 将文件保存到临时目录
|
||||
file_path = os.path.join(temp_dir, file_name)
|
||||
with open(file_path, 'wb') as file:
|
||||
file.write(response.content)
|
||||
return file_path
|
||||
else:
|
||||
logger.info(f"[Dingtalk] Failed to download image file, {response.content}")
|
||||
return None
|
||||
254
channel/feishu/feishu_channel.py
Normal file
254
channel/feishu/feishu_channel.py
Normal file
@@ -0,0 +1,254 @@
|
||||
"""
|
||||
飞书通道接入
|
||||
|
||||
@author Saboteur7
|
||||
@Date 2023/11/19
|
||||
"""
|
||||
|
||||
# -*- coding=utf-8 -*-
|
||||
import uuid
|
||||
|
||||
import requests
|
||||
import web
|
||||
from channel.feishu.feishu_message import FeishuMessage
|
||||
from bridge.context import Context
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from common.log import logger
|
||||
from common.singleton import singleton
|
||||
from config import conf
|
||||
from common.expired_dict import ExpiredDict
|
||||
from bridge.context import ContextType
|
||||
from channel.chat_channel import ChatChannel, check_prefix
|
||||
from common import utils
|
||||
import json
|
||||
import os
|
||||
|
||||
URL_VERIFICATION = "url_verification"
|
||||
|
||||
|
||||
@singleton
|
||||
class FeiShuChanel(ChatChannel):
|
||||
feishu_app_id = conf().get('feishu_app_id')
|
||||
feishu_app_secret = conf().get('feishu_app_secret')
|
||||
feishu_token = conf().get('feishu_token')
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
# 历史消息id暂存,用于幂等控制
|
||||
self.receivedMsgs = ExpiredDict(60 * 60 * 7.1)
|
||||
logger.info("[FeiShu] app_id={}, app_secret={} verification_token={}".format(
|
||||
self.feishu_app_id, self.feishu_app_secret, self.feishu_token))
|
||||
# 无需群校验和前缀
|
||||
conf()["group_name_white_list"] = ["ALL_GROUP"]
|
||||
conf()["single_chat_prefix"] = [""]
|
||||
|
||||
def startup(self):
|
||||
urls = (
|
||||
'/', 'channel.feishu.feishu_channel.FeishuController'
|
||||
)
|
||||
app = web.application(urls, globals(), autoreload=False)
|
||||
port = conf().get("feishu_port", 9891)
|
||||
web.httpserver.runsimple(app.wsgifunc(), ("0.0.0.0", port))
|
||||
|
||||
def send(self, reply: Reply, context: Context):
|
||||
msg = context.get("msg")
|
||||
is_group = context["isgroup"]
|
||||
if msg:
|
||||
access_token = msg.access_token
|
||||
else:
|
||||
access_token = self.fetch_access_token()
|
||||
headers = {
|
||||
"Authorization": "Bearer " + access_token,
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
msg_type = "text"
|
||||
logger.info(f"[FeiShu] start send reply message, type={context.type}, content={reply.content}")
|
||||
reply_content = reply.content
|
||||
content_key = "text"
|
||||
if reply.type == ReplyType.IMAGE_URL:
|
||||
# 图片上传
|
||||
reply_content = self._upload_image_url(reply.content, access_token)
|
||||
if not reply_content:
|
||||
logger.warning("[FeiShu] upload file failed")
|
||||
return
|
||||
msg_type = "image"
|
||||
content_key = "image_key"
|
||||
if is_group:
|
||||
# 群聊中直接回复
|
||||
url = f"https://open.feishu.cn/open-apis/im/v1/messages/{msg.msg_id}/reply"
|
||||
data = {
|
||||
"msg_type": msg_type,
|
||||
"content": json.dumps({content_key: reply_content})
|
||||
}
|
||||
res = requests.post(url=url, headers=headers, json=data, timeout=(5, 10))
|
||||
else:
|
||||
url = "https://open.feishu.cn/open-apis/im/v1/messages"
|
||||
params = {"receive_id_type": context.get("receive_id_type") or "open_id"}
|
||||
data = {
|
||||
"receive_id": context.get("receiver"),
|
||||
"msg_type": msg_type,
|
||||
"content": json.dumps({content_key: reply_content})
|
||||
}
|
||||
res = requests.post(url=url, headers=headers, params=params, json=data, timeout=(5, 10))
|
||||
res = res.json()
|
||||
if res.get("code") == 0:
|
||||
logger.info(f"[FeiShu] send message success")
|
||||
else:
|
||||
logger.error(f"[FeiShu] send message failed, code={res.get('code')}, msg={res.get('msg')}")
|
||||
|
||||
|
||||
def fetch_access_token(self) -> str:
|
||||
url = "https://open.feishu.cn/open-apis/auth/v3/tenant_access_token/internal/"
|
||||
headers = {
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
req_body = {
|
||||
"app_id": self.feishu_app_id,
|
||||
"app_secret": self.feishu_app_secret
|
||||
}
|
||||
data = bytes(json.dumps(req_body), encoding='utf8')
|
||||
response = requests.post(url=url, data=data, headers=headers)
|
||||
if response.status_code == 200:
|
||||
res = response.json()
|
||||
if res.get("code") != 0:
|
||||
logger.error(f"[FeiShu] get tenant_access_token error, code={res.get('code')}, msg={res.get('msg')}")
|
||||
return ""
|
||||
else:
|
||||
return res.get("tenant_access_token")
|
||||
else:
|
||||
logger.error(f"[FeiShu] fetch token error, res={response}")
|
||||
|
||||
|
||||
def _upload_image_url(self, img_url, access_token):
|
||||
logger.debug(f"[WX] start download image, img_url={img_url}")
|
||||
response = requests.get(img_url)
|
||||
suffix = utils.get_path_suffix(img_url)
|
||||
temp_name = str(uuid.uuid4()) + "." + suffix
|
||||
if response.status_code == 200:
|
||||
# 将图片内容保存为临时文件
|
||||
with open(temp_name, "wb") as file:
|
||||
file.write(response.content)
|
||||
|
||||
# upload
|
||||
upload_url = "https://open.feishu.cn/open-apis/im/v1/images"
|
||||
data = {
|
||||
'image_type': 'message'
|
||||
}
|
||||
headers = {
|
||||
'Authorization': f'Bearer {access_token}',
|
||||
}
|
||||
with open(temp_name, "rb") as file:
|
||||
upload_response = requests.post(upload_url, files={"image": file}, data=data, headers=headers)
|
||||
logger.info(f"[FeiShu] upload file, res={upload_response.content}")
|
||||
os.remove(temp_name)
|
||||
return upload_response.json().get("data").get("image_key")
|
||||
|
||||
|
||||
|
||||
class FeishuController:
|
||||
# 类常量
|
||||
FAILED_MSG = '{"success": false}'
|
||||
SUCCESS_MSG = '{"success": true}'
|
||||
MESSAGE_RECEIVE_TYPE = "im.message.receive_v1"
|
||||
|
||||
def GET(self):
|
||||
return "Feishu service start success!"
|
||||
|
||||
def POST(self):
|
||||
try:
|
||||
channel = FeiShuChanel()
|
||||
|
||||
request = json.loads(web.data().decode("utf-8"))
|
||||
logger.debug(f"[FeiShu] receive request: {request}")
|
||||
|
||||
# 1.事件订阅回调验证
|
||||
if request.get("type") == URL_VERIFICATION:
|
||||
varify_res = {"challenge": request.get("challenge")}
|
||||
return json.dumps(varify_res)
|
||||
|
||||
# 2.消息接收处理
|
||||
# token 校验
|
||||
header = request.get("header")
|
||||
if not header or header.get("token") != channel.feishu_token:
|
||||
return self.FAILED_MSG
|
||||
|
||||
# 处理消息事件
|
||||
event = request.get("event")
|
||||
if header.get("event_type") == self.MESSAGE_RECEIVE_TYPE and event:
|
||||
if not event.get("message") or not event.get("sender"):
|
||||
logger.warning(f"[FeiShu] invalid message, msg={request}")
|
||||
return self.FAILED_MSG
|
||||
msg = event.get("message")
|
||||
|
||||
# 幂等判断
|
||||
if channel.receivedMsgs.get(msg.get("message_id")):
|
||||
logger.warning(f"[FeiShu] repeat msg filtered, event_id={header.get('event_id')}")
|
||||
return self.SUCCESS_MSG
|
||||
channel.receivedMsgs[msg.get("message_id")] = True
|
||||
|
||||
is_group = False
|
||||
chat_type = msg.get("chat_type")
|
||||
if chat_type == "group":
|
||||
if not msg.get("mentions") and msg.get("message_type") == "text":
|
||||
# 群聊中未@不响应
|
||||
return self.SUCCESS_MSG
|
||||
if msg.get("mentions")[0].get("name") != conf().get("feishu_bot_name") and msg.get("message_type") == "text":
|
||||
# 不是@机器人,不响应
|
||||
return self.SUCCESS_MSG
|
||||
# 群聊
|
||||
is_group = True
|
||||
receive_id_type = "chat_id"
|
||||
elif chat_type == "p2p":
|
||||
receive_id_type = "open_id"
|
||||
else:
|
||||
logger.warning("[FeiShu] message ignore")
|
||||
return self.SUCCESS_MSG
|
||||
# 构造飞书消息对象
|
||||
feishu_msg = FeishuMessage(event, is_group=is_group, access_token=channel.fetch_access_token())
|
||||
if not feishu_msg:
|
||||
return self.SUCCESS_MSG
|
||||
|
||||
context = self._compose_context(
|
||||
feishu_msg.ctype,
|
||||
feishu_msg.content,
|
||||
isgroup=is_group,
|
||||
msg=feishu_msg,
|
||||
receive_id_type=receive_id_type,
|
||||
no_need_at=True
|
||||
)
|
||||
if context:
|
||||
channel.produce(context)
|
||||
logger.info(f"[FeiShu] query={feishu_msg.content}, type={feishu_msg.ctype}")
|
||||
return self.SUCCESS_MSG
|
||||
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
return self.FAILED_MSG
|
||||
|
||||
def _compose_context(self, ctype: ContextType, content, **kwargs):
|
||||
context = Context(ctype, content)
|
||||
context.kwargs = kwargs
|
||||
if "origin_ctype" not in context:
|
||||
context["origin_ctype"] = ctype
|
||||
|
||||
cmsg = context["msg"]
|
||||
context["session_id"] = cmsg.from_user_id
|
||||
context["receiver"] = cmsg.other_user_id
|
||||
|
||||
if ctype == ContextType.TEXT:
|
||||
# 1.文本请求
|
||||
# 图片生成处理
|
||||
img_match_prefix = check_prefix(content, conf().get("image_create_prefix"))
|
||||
if img_match_prefix:
|
||||
content = content.replace(img_match_prefix, "", 1)
|
||||
context.type = ContextType.IMAGE_CREATE
|
||||
else:
|
||||
context.type = ContextType.TEXT
|
||||
context.content = content.strip()
|
||||
|
||||
elif context.type == ContextType.VOICE:
|
||||
# 2.语音请求
|
||||
if "desire_rtype" not in context and conf().get("voice_reply_voice"):
|
||||
context["desire_rtype"] = ReplyType.VOICE
|
||||
|
||||
return context
|
||||
63
channel/feishu/feishu_message.py
Normal file
63
channel/feishu/feishu_message.py
Normal file
@@ -0,0 +1,63 @@
|
||||
from bridge.context import ContextType
|
||||
from channel.chat_message import ChatMessage
|
||||
import json
|
||||
import requests
|
||||
from common.log import logger
|
||||
from common.tmp_dir import TmpDir
|
||||
from common import utils
|
||||
|
||||
|
||||
class FeishuMessage(ChatMessage):
|
||||
def __init__(self, event: dict, is_group=False, access_token=None):
|
||||
super().__init__(event)
|
||||
msg = event.get("message")
|
||||
sender = event.get("sender")
|
||||
self.access_token = access_token
|
||||
self.msg_id = msg.get("message_id")
|
||||
self.create_time = msg.get("create_time")
|
||||
self.is_group = is_group
|
||||
msg_type = msg.get("message_type")
|
||||
|
||||
if msg_type == "text":
|
||||
self.ctype = ContextType.TEXT
|
||||
content = json.loads(msg.get('content'))
|
||||
self.content = content.get("text").strip()
|
||||
elif msg_type == "file":
|
||||
self.ctype = ContextType.FILE
|
||||
content = json.loads(msg.get("content"))
|
||||
file_key = content.get("file_key")
|
||||
file_name = content.get("file_name")
|
||||
|
||||
self.content = TmpDir().path() + file_key + "." + utils.get_path_suffix(file_name)
|
||||
|
||||
def _download_file():
|
||||
# 如果响应状态码是200,则将响应内容写入本地文件
|
||||
url = f"https://open.feishu.cn/open-apis/im/v1/messages/{self.msg_id}/resources/{file_key}"
|
||||
headers = {
|
||||
"Authorization": "Bearer " + access_token,
|
||||
}
|
||||
params = {
|
||||
"type": "file"
|
||||
}
|
||||
response = requests.get(url=url, headers=headers, params=params)
|
||||
if response.status_code == 200:
|
||||
with open(self.content, "wb") as f:
|
||||
f.write(response.content)
|
||||
else:
|
||||
logger.info(f"[FeiShu] Failed to download file, key={file_key}, res={response.text}")
|
||||
self._prepare_fn = _download_file
|
||||
else:
|
||||
raise NotImplementedError("Unsupported message type: Type:{} ".format(msg_type))
|
||||
|
||||
self.from_user_id = sender.get("sender_id").get("open_id")
|
||||
self.to_user_id = event.get("app_id")
|
||||
if is_group:
|
||||
# 群聊
|
||||
self.other_user_id = msg.get("chat_id")
|
||||
self.actual_user_id = self.from_user_id
|
||||
self.content = self.content.replace("@_user_1", "").strip()
|
||||
self.actual_user_nickname = ""
|
||||
else:
|
||||
# 私聊
|
||||
self.other_user_id = self.from_user_id
|
||||
self.actual_user_id = self.from_user_id
|
||||
@@ -1,31 +1,93 @@
|
||||
from bridge.context import *
|
||||
from channel.channel import Channel
|
||||
import sys
|
||||
|
||||
class TerminalChannel(Channel):
|
||||
from bridge.context import *
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from channel.chat_channel import ChatChannel, check_prefix
|
||||
from channel.chat_message import ChatMessage
|
||||
from common.log import logger
|
||||
from config import conf
|
||||
|
||||
|
||||
class TerminalMessage(ChatMessage):
|
||||
def __init__(
|
||||
self,
|
||||
msg_id,
|
||||
content,
|
||||
ctype=ContextType.TEXT,
|
||||
from_user_id="User",
|
||||
to_user_id="Chatgpt",
|
||||
other_user_id="Chatgpt",
|
||||
):
|
||||
self.msg_id = msg_id
|
||||
self.ctype = ctype
|
||||
self.content = content
|
||||
self.from_user_id = from_user_id
|
||||
self.to_user_id = to_user_id
|
||||
self.other_user_id = other_user_id
|
||||
|
||||
|
||||
class TerminalChannel(ChatChannel):
|
||||
NOT_SUPPORT_REPLYTYPE = [ReplyType.VOICE]
|
||||
|
||||
def send(self, reply: Reply, context: Context):
|
||||
print("\nBot:")
|
||||
if reply.type == ReplyType.IMAGE:
|
||||
from PIL import Image
|
||||
|
||||
image_storage = reply.content
|
||||
image_storage.seek(0)
|
||||
img = Image.open(image_storage)
|
||||
print("<IMAGE>")
|
||||
img.show()
|
||||
elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片
|
||||
import io
|
||||
|
||||
import requests
|
||||
from PIL import Image
|
||||
|
||||
img_url = reply.content
|
||||
pic_res = requests.get(img_url, stream=True)
|
||||
image_storage = io.BytesIO()
|
||||
for block in pic_res.iter_content(1024):
|
||||
image_storage.write(block)
|
||||
image_storage.seek(0)
|
||||
img = Image.open(image_storage)
|
||||
print(img_url)
|
||||
img.show()
|
||||
else:
|
||||
print(reply.content)
|
||||
print("\nUser:", end="")
|
||||
sys.stdout.flush()
|
||||
return
|
||||
|
||||
def startup(self):
|
||||
context = Context()
|
||||
print("\nPlease input your question")
|
||||
logger.setLevel("WARN")
|
||||
print("\nPlease input your question:\nUser:", end="")
|
||||
sys.stdout.flush()
|
||||
msg_id = 0
|
||||
while True:
|
||||
try:
|
||||
prompt = self.get_input("User:\n")
|
||||
prompt = self.get_input()
|
||||
except KeyboardInterrupt:
|
||||
print("\nExiting...")
|
||||
sys.exit()
|
||||
msg_id += 1
|
||||
trigger_prefixs = conf().get("single_chat_prefix", [""])
|
||||
if check_prefix(prompt, trigger_prefixs) is None:
|
||||
prompt = trigger_prefixs[0] + prompt # 给没触发的消息加上触发前缀
|
||||
|
||||
context.type = ContextType.TEXT
|
||||
context['session_id'] = "User"
|
||||
context.content = prompt
|
||||
print("Bot:")
|
||||
sys.stdout.flush()
|
||||
res = super().build_reply_content(prompt, context).content
|
||||
print(res)
|
||||
context = self._compose_context(ContextType.TEXT, prompt, msg=TerminalMessage(msg_id, prompt))
|
||||
context["isgroup"] = False
|
||||
if context:
|
||||
self.produce(context)
|
||||
else:
|
||||
raise Exception("context is None")
|
||||
|
||||
|
||||
def get_input(self, prompt):
|
||||
def get_input(self):
|
||||
"""
|
||||
Multi-line input function
|
||||
"""
|
||||
print(prompt, end="")
|
||||
sys.stdout.flush()
|
||||
line = input()
|
||||
return line
|
||||
|
||||
@@ -4,71 +4,154 @@
|
||||
wechat channel
|
||||
"""
|
||||
|
||||
import os
|
||||
from lib import itchat
|
||||
import json
|
||||
from lib.itchat.content import *
|
||||
from bridge.reply import *
|
||||
from bridge.context import *
|
||||
from channel.channel import Channel
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from common.log import logger
|
||||
from common.tmp_dir import TmpDir
|
||||
from config import conf
|
||||
from common.time_check import time_checker
|
||||
from plugins import *
|
||||
import requests
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
|
||||
import requests
|
||||
|
||||
thread_pool = ThreadPoolExecutor(max_workers=8)
|
||||
def thread_pool_callback(worker):
|
||||
worker_exception = worker.exception()
|
||||
if worker_exception:
|
||||
logger.exception("Worker return exception: {}".format(worker_exception))
|
||||
from bridge.context import *
|
||||
from bridge.reply import *
|
||||
from channel.chat_channel import ChatChannel
|
||||
from channel import chat_channel
|
||||
from channel.wechat.wechat_message import *
|
||||
from common.expired_dict import ExpiredDict
|
||||
from common.log import logger
|
||||
from common.singleton import singleton
|
||||
from common.time_check import time_checker
|
||||
from config import conf, get_appdata_dir
|
||||
from lib import itchat
|
||||
from lib.itchat.content import *
|
||||
|
||||
@itchat.msg_register(TEXT)
|
||||
|
||||
@itchat.msg_register([TEXT, VOICE, PICTURE, NOTE, ATTACHMENT, SHARING])
|
||||
def handler_single_msg(msg):
|
||||
WechatChannel().handle_text(msg)
|
||||
try:
|
||||
cmsg = WechatMessage(msg, False)
|
||||
except NotImplementedError as e:
|
||||
logger.debug("[WX]single message {} skipped: {}".format(msg["MsgId"], e))
|
||||
return None
|
||||
WechatChannel().handle_single(cmsg)
|
||||
return None
|
||||
|
||||
|
||||
@itchat.msg_register(TEXT, isGroupChat=True)
|
||||
@itchat.msg_register([TEXT, VOICE, PICTURE, NOTE, ATTACHMENT, SHARING], isGroupChat=True)
|
||||
def handler_group_msg(msg):
|
||||
WechatChannel().handle_group(msg)
|
||||
try:
|
||||
cmsg = WechatMessage(msg, True)
|
||||
except NotImplementedError as e:
|
||||
logger.debug("[WX]group message {} skipped: {}".format(msg["MsgId"], e))
|
||||
return None
|
||||
WechatChannel().handle_group(cmsg)
|
||||
return None
|
||||
|
||||
|
||||
@itchat.msg_register(VOICE)
|
||||
def handler_single_voice(msg):
|
||||
WechatChannel().handle_voice(msg)
|
||||
return None
|
||||
def _check(func):
|
||||
def wrapper(self, cmsg: ChatMessage):
|
||||
msgId = cmsg.msg_id
|
||||
if msgId in self.receivedMsgs:
|
||||
logger.info("Wechat message {} already received, ignore".format(msgId))
|
||||
return
|
||||
self.receivedMsgs[msgId] = True
|
||||
create_time = cmsg.create_time # 消息时间戳
|
||||
if conf().get("hot_reload") == True and int(create_time) < int(time.time()) - 60: # 跳过1分钟前的历史消息
|
||||
logger.debug("[WX]history message {} skipped".format(msgId))
|
||||
return
|
||||
if cmsg.my_msg and not cmsg.is_group:
|
||||
logger.debug("[WX]my message {} skipped".format(msgId))
|
||||
return
|
||||
return func(self, cmsg)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class WechatChannel(Channel):
|
||||
# 可用的二维码生成接口
|
||||
# https://api.qrserver.com/v1/create-qr-code/?size=400×400&data=https://www.abc.com
|
||||
# https://api.isoyu.com/qr/?m=1&e=L&p=20&url=https://www.abc.com
|
||||
def qrCallback(uuid, status, qrcode):
|
||||
# logger.debug("qrCallback: {} {}".format(uuid,status))
|
||||
if status == "0":
|
||||
try:
|
||||
from PIL import Image
|
||||
|
||||
img = Image.open(io.BytesIO(qrcode))
|
||||
_thread = threading.Thread(target=img.show, args=("QRCode",))
|
||||
_thread.setDaemon(True)
|
||||
_thread.start()
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
import qrcode
|
||||
|
||||
url = f"https://login.weixin.qq.com/l/{uuid}"
|
||||
|
||||
qr_api1 = "https://api.isoyu.com/qr/?m=1&e=L&p=20&url={}".format(url)
|
||||
qr_api2 = "https://api.qrserver.com/v1/create-qr-code/?size=400×400&data={}".format(url)
|
||||
qr_api3 = "https://api.pwmqr.com/qrcode/create/?url={}".format(url)
|
||||
qr_api4 = "https://my.tv.sohu.com/user/a/wvideo/getQRCode.do?text={}".format(url)
|
||||
print("You can also scan QRCode in any website below:")
|
||||
print(qr_api3)
|
||||
print(qr_api4)
|
||||
print(qr_api2)
|
||||
print(qr_api1)
|
||||
_send_qr_code([qr_api3, qr_api4, qr_api2, qr_api1])
|
||||
qr = qrcode.QRCode(border=1)
|
||||
qr.add_data(url)
|
||||
qr.make(fit=True)
|
||||
qr.print_ascii(invert=True)
|
||||
|
||||
|
||||
@singleton
|
||||
class WechatChannel(ChatChannel):
|
||||
NOT_SUPPORT_REPLYTYPE = []
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
super().__init__()
|
||||
self.receivedMsgs = ExpiredDict(conf().get("expires_in_seconds"))
|
||||
self.auto_login_times = 0
|
||||
|
||||
def startup(self):
|
||||
|
||||
itchat.instance.receivingRetryCount = 600 # 修改断线超时时间
|
||||
# login by scan QRCode
|
||||
hotReload = conf().get('hot_reload', False)
|
||||
try:
|
||||
itchat.auto_login(enableCmdQR=2, hotReload=hotReload)
|
||||
itchat.instance.receivingRetryCount = 600 # 修改断线超时时间
|
||||
# login by scan QRCode
|
||||
hotReload = conf().get("hot_reload", False)
|
||||
status_path = os.path.join(get_appdata_dir(), "itchat.pkl")
|
||||
itchat.auto_login(
|
||||
enableCmdQR=2,
|
||||
hotReload=hotReload,
|
||||
statusStorageDir=status_path,
|
||||
qrCallback=qrCallback,
|
||||
exitCallback=self.exitCallback,
|
||||
loginCallback=self.loginCallback
|
||||
)
|
||||
self.user_id = itchat.instance.storageClass.userName
|
||||
self.name = itchat.instance.storageClass.nickName
|
||||
logger.info("Wechat login success, user_id: {}, nickname: {}".format(self.user_id, self.name))
|
||||
# start message listener
|
||||
itchat.run()
|
||||
except Exception as e:
|
||||
if hotReload:
|
||||
logger.error("Hot reload failed, try to login without hot reload")
|
||||
itchat.logout()
|
||||
os.remove("itchat.pkl")
|
||||
itchat.auto_login(enableCmdQR=2, hotReload=hotReload)
|
||||
else:
|
||||
raise e
|
||||
# start message listener
|
||||
itchat.run()
|
||||
logger.exception(e)
|
||||
|
||||
# handle_* 系列函数处理收到的消息后构造Context,然后传入handle函数中处理Context和发送回复
|
||||
def exitCallback(self):
|
||||
try:
|
||||
from common.linkai_client import chat_client
|
||||
if chat_client.client_id and conf().get("use_linkai"):
|
||||
_send_logout()
|
||||
time.sleep(2)
|
||||
self.auto_login_times += 1
|
||||
if self.auto_login_times < 100:
|
||||
chat_channel.handler_pool._shutdown = False
|
||||
self.startup()
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
def loginCallback(self):
|
||||
logger.debug("Login success")
|
||||
_send_login_success()
|
||||
|
||||
# handle_* 系列函数处理收到的消息后构造Context,然后传入produce函数中处理Context和发送回复
|
||||
# Context包含了消息的所有信息,包括以下属性
|
||||
# type 消息类型, 包括TEXT、VOICE、IMAGE_CREATE
|
||||
# content 消息内容,如果是TEXT类型,content就是文本内容,如果是VOICE类型,content就是语音文件名,如果是IMAGE_CREATE类型,content就是图片生成命令
|
||||
@@ -76,197 +159,125 @@ class WechatChannel(Channel):
|
||||
# session_id: 会话id
|
||||
# isgroup: 是否是群聊
|
||||
# receiver: 需要回复的对象
|
||||
# msg: itchat的原始消息对象
|
||||
|
||||
def handle_voice(self, msg):
|
||||
if conf().get('speech_recognition') != True:
|
||||
return
|
||||
logger.debug("[WX]receive voice msg: " + msg['FileName'])
|
||||
from_user_id = msg['FromUserName']
|
||||
other_user_id = msg['User']['UserName']
|
||||
if from_user_id == other_user_id:
|
||||
context = Context(ContextType.VOICE,msg['FileName'])
|
||||
context.kwargs = {'isgroup': False, 'msg': msg, 'receiver': other_user_id, 'session_id': other_user_id}
|
||||
thread_pool.submit(self.handle, context).add_done_callback(thread_pool_callback)
|
||||
|
||||
|
||||
# msg: ChatMessage消息对象
|
||||
# origin_ctype: 原始消息类型,语音转文字后,私聊时如果匹配前缀失败,会根据初始消息是否是语音来放宽触发规则
|
||||
# desire_rtype: 希望回复类型,默认是文本回复,设置为ReplyType.VOICE是语音回复
|
||||
@time_checker
|
||||
def handle_text(self, msg):
|
||||
logger.debug("[WX]receive text msg: " + json.dumps(msg, ensure_ascii=False))
|
||||
content = msg['Text']
|
||||
from_user_id = msg['FromUserName']
|
||||
to_user_id = msg['ToUserName'] # 接收人id
|
||||
other_user_id = msg['User']['UserName'] # 对手方id
|
||||
create_time = msg['CreateTime'] # 消息时间
|
||||
match_prefix = check_prefix(content, conf().get('single_chat_prefix'))
|
||||
if conf().get('hot_reload') == True and int(create_time) < int(time.time()) - 60: #跳过1分钟前的历史消息
|
||||
logger.debug("[WX]history message skipped")
|
||||
@_check
|
||||
def handle_single(self, cmsg: ChatMessage):
|
||||
# filter system message
|
||||
if cmsg.other_user_id in ["weixin"]:
|
||||
return
|
||||
if "」\n- - - - - - - - - - - - - - -" in content:
|
||||
logger.debug("[WX]reference query skipped")
|
||||
return
|
||||
if match_prefix:
|
||||
content = content.replace(match_prefix, '', 1).strip()
|
||||
elif match_prefix is None:
|
||||
return
|
||||
context = Context()
|
||||
context.kwargs = {'isgroup': False, 'msg': msg, 'receiver': other_user_id, 'session_id': other_user_id}
|
||||
|
||||
img_match_prefix = check_prefix(content, conf().get('image_create_prefix'))
|
||||
if img_match_prefix:
|
||||
content = content.replace(img_match_prefix, '', 1).strip()
|
||||
context.type = ContextType.IMAGE_CREATE
|
||||
if cmsg.ctype == ContextType.VOICE:
|
||||
if conf().get("speech_recognition") != True:
|
||||
return
|
||||
logger.debug("[WX]receive voice msg: {}".format(cmsg.content))
|
||||
elif cmsg.ctype == ContextType.IMAGE:
|
||||
logger.debug("[WX]receive image msg: {}".format(cmsg.content))
|
||||
elif cmsg.ctype == ContextType.PATPAT:
|
||||
logger.debug("[WX]receive patpat msg: {}".format(cmsg.content))
|
||||
elif cmsg.ctype == ContextType.TEXT:
|
||||
logger.debug("[WX]receive text msg: {}, cmsg={}".format(json.dumps(cmsg._rawmsg, ensure_ascii=False), cmsg))
|
||||
else:
|
||||
context.type = ContextType.TEXT
|
||||
|
||||
context.content = content
|
||||
thread_pool.submit(self.handle, context).add_done_callback(thread_pool_callback)
|
||||
logger.debug("[WX]receive msg: {}, cmsg={}".format(cmsg.content, cmsg))
|
||||
context = self._compose_context(cmsg.ctype, cmsg.content, isgroup=False, msg=cmsg)
|
||||
if context:
|
||||
self.produce(context)
|
||||
|
||||
@time_checker
|
||||
def handle_group(self, msg):
|
||||
logger.debug("[WX]receive group msg: " + json.dumps(msg, ensure_ascii=False))
|
||||
group_name = msg['User'].get('NickName', None)
|
||||
group_id = msg['User'].get('UserName', None)
|
||||
create_time = msg['CreateTime'] # 消息时间
|
||||
if conf().get('hot_reload') == True and int(create_time) < int(time.time()) - 60: #跳过1分钟前的历史消息
|
||||
logger.debug("[WX]history group message skipped")
|
||||
return
|
||||
if not group_name:
|
||||
return ""
|
||||
origin_content = msg['Content']
|
||||
content = msg['Content']
|
||||
content_list = content.split(' ', 1)
|
||||
context_special_list = content.split('\u2005', 1)
|
||||
if len(context_special_list) == 2:
|
||||
content = context_special_list[1]
|
||||
elif len(content_list) == 2:
|
||||
content = content_list[1]
|
||||
if "」\n- - - - - - - - - - - - - - -" in content:
|
||||
logger.debug("[WX]reference query skipped")
|
||||
return ""
|
||||
config = conf()
|
||||
match_prefix = (msg['IsAt'] and not config.get("group_at_off", False)) or check_prefix(origin_content, config.get('group_chat_prefix')) \
|
||||
or check_contain(origin_content, config.get('group_chat_keyword'))
|
||||
if ('ALL_GROUP' in config.get('group_name_white_list') or group_name in config.get('group_name_white_list') or check_contain(group_name, config.get('group_name_keyword_white_list'))) and match_prefix:
|
||||
context = Context()
|
||||
context.kwargs = { 'isgroup': True, 'msg': msg, 'receiver': group_id}
|
||||
|
||||
img_match_prefix = check_prefix(content, conf().get('image_create_prefix'))
|
||||
if img_match_prefix:
|
||||
content = content.replace(img_match_prefix, '', 1).strip()
|
||||
context.type = ContextType.IMAGE_CREATE
|
||||
else:
|
||||
context.type = ContextType.TEXT
|
||||
context.content = content
|
||||
|
||||
group_chat_in_one_session = conf().get('group_chat_in_one_session', [])
|
||||
if ('ALL_GROUP' in group_chat_in_one_session or
|
||||
group_name in group_chat_in_one_session or
|
||||
check_contain(group_name, group_chat_in_one_session)):
|
||||
context['session_id'] = group_id
|
||||
else:
|
||||
context['session_id'] = msg['ActualUserName']
|
||||
|
||||
thread_pool.submit(self.handle, context).add_done_callback(thread_pool_callback)
|
||||
@_check
|
||||
def handle_group(self, cmsg: ChatMessage):
|
||||
if cmsg.ctype == ContextType.VOICE:
|
||||
if conf().get("group_speech_recognition") != True:
|
||||
return
|
||||
logger.debug("[WX]receive voice for group msg: {}".format(cmsg.content))
|
||||
elif cmsg.ctype == ContextType.IMAGE:
|
||||
logger.debug("[WX]receive image for group msg: {}".format(cmsg.content))
|
||||
elif cmsg.ctype in [ContextType.JOIN_GROUP, ContextType.PATPAT, ContextType.ACCEPT_FRIEND, ContextType.EXIT_GROUP]:
|
||||
logger.debug("[WX]receive note msg: {}".format(cmsg.content))
|
||||
elif cmsg.ctype == ContextType.TEXT:
|
||||
# logger.debug("[WX]receive group msg: {}, cmsg={}".format(json.dumps(cmsg._rawmsg, ensure_ascii=False), cmsg))
|
||||
pass
|
||||
elif cmsg.ctype == ContextType.FILE:
|
||||
logger.debug(f"[WX]receive attachment msg, file_name={cmsg.content}")
|
||||
else:
|
||||
logger.debug("[WX]receive group msg: {}".format(cmsg.content))
|
||||
context = self._compose_context(cmsg.ctype, cmsg.content, isgroup=True, msg=cmsg)
|
||||
if context:
|
||||
self.produce(context)
|
||||
|
||||
# 统一的发送函数,每个Channel自行实现,根据reply的type字段发送不同类型的消息
|
||||
def send(self, reply : Reply, receiver):
|
||||
def send(self, reply: Reply, context: Context):
|
||||
receiver = context["receiver"]
|
||||
if reply.type == ReplyType.TEXT:
|
||||
itchat.send(reply.content, toUserName=receiver)
|
||||
logger.info('[WX] sendMsg={}, receiver={}'.format(reply, receiver))
|
||||
logger.info("[WX] sendMsg={}, receiver={}".format(reply, receiver))
|
||||
elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO:
|
||||
itchat.send(reply.content, toUserName=receiver)
|
||||
logger.info('[WX] sendMsg={}, receiver={}'.format(reply, receiver))
|
||||
logger.info("[WX] sendMsg={}, receiver={}".format(reply, receiver))
|
||||
elif reply.type == ReplyType.VOICE:
|
||||
itchat.send_file(reply.content, toUserName=receiver)
|
||||
logger.info('[WX] sendFile={}, receiver={}'.format(reply.content, receiver))
|
||||
elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片
|
||||
logger.info("[WX] sendFile={}, receiver={}".format(reply.content, receiver))
|
||||
elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片
|
||||
img_url = reply.content
|
||||
logger.debug(f"[WX] start download image, img_url={img_url}")
|
||||
pic_res = requests.get(img_url, stream=True)
|
||||
image_storage = io.BytesIO()
|
||||
size = 0
|
||||
for block in pic_res.iter_content(1024):
|
||||
size += len(block)
|
||||
image_storage.write(block)
|
||||
logger.info(f"[WX] download image success, size={size}, img_url={img_url}")
|
||||
image_storage.seek(0)
|
||||
itchat.send_image(image_storage, toUserName=receiver)
|
||||
logger.info('[WX] sendImage url={}, receiver={}'.format(img_url,receiver))
|
||||
elif reply.type == ReplyType.IMAGE: # 从文件读取图片
|
||||
logger.info("[WX] sendImage url={}, receiver={}".format(img_url, receiver))
|
||||
elif reply.type == ReplyType.IMAGE: # 从文件读取图片
|
||||
image_storage = reply.content
|
||||
image_storage.seek(0)
|
||||
itchat.send_image(image_storage, toUserName=receiver)
|
||||
logger.info('[WX] sendImage, receiver={}'.format(receiver))
|
||||
logger.info("[WX] sendImage, receiver={}".format(receiver))
|
||||
elif reply.type == ReplyType.FILE: # 新增文件回复类型
|
||||
file_storage = reply.content
|
||||
itchat.send_file(file_storage, toUserName=receiver)
|
||||
logger.info("[WX] sendFile, receiver={}".format(receiver))
|
||||
elif reply.type == ReplyType.VIDEO: # 新增视频回复类型
|
||||
video_storage = reply.content
|
||||
itchat.send_video(video_storage, toUserName=receiver)
|
||||
logger.info("[WX] sendFile, receiver={}".format(receiver))
|
||||
elif reply.type == ReplyType.VIDEO_URL: # 新增视频URL回复类型
|
||||
video_url = reply.content
|
||||
logger.debug(f"[WX] start download video, video_url={video_url}")
|
||||
video_res = requests.get(video_url, stream=True)
|
||||
video_storage = io.BytesIO()
|
||||
size = 0
|
||||
for block in video_res.iter_content(1024):
|
||||
size += len(block)
|
||||
video_storage.write(block)
|
||||
logger.info(f"[WX] download video success, size={size}, video_url={video_url}")
|
||||
video_storage.seek(0)
|
||||
itchat.send_video(video_storage, toUserName=receiver)
|
||||
logger.info("[WX] sendVideo url={}, receiver={}".format(video_url, receiver))
|
||||
|
||||
# 处理消息 TODO: 如果wechaty解耦,此处逻辑可以放置到父类
|
||||
def handle(self, context):
|
||||
reply = Reply()
|
||||
def _send_login_success():
|
||||
try:
|
||||
from common.linkai_client import chat_client
|
||||
if chat_client.client_id:
|
||||
chat_client.send_login_success()
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
logger.debug('[WX] ready to handle context: {}'.format(context))
|
||||
|
||||
# reply的构建步骤
|
||||
e_context = PluginManager().emit_event(EventContext(Event.ON_HANDLE_CONTEXT, {'channel' : self, 'context': context, 'reply': reply}))
|
||||
reply = e_context['reply']
|
||||
if not e_context.is_pass():
|
||||
logger.debug('[WX] ready to handle context: type={}, content={}'.format(context.type, context.content))
|
||||
if context.type == ContextType.TEXT or context.type == ContextType.IMAGE_CREATE:
|
||||
reply = super().build_reply_content(context.content, context)
|
||||
elif context.type == ContextType.VOICE:
|
||||
msg = context['msg']
|
||||
file_name = TmpDir().path() + context.content
|
||||
msg.download(file_name)
|
||||
reply = super().build_voice_to_text(file_name)
|
||||
if reply.type != ReplyType.ERROR and reply.type != ReplyType.INFO:
|
||||
context.content = reply.content # 语音转文字后,将文字内容作为新的context
|
||||
context.type = ContextType.TEXT
|
||||
reply = super().build_reply_content(context.content, context)
|
||||
if reply.type == ReplyType.TEXT:
|
||||
if conf().get('voice_reply_voice'):
|
||||
reply = super().build_text_to_voice(reply.content)
|
||||
else:
|
||||
logger.error('[WX] unknown context type: {}'.format(context.type))
|
||||
return
|
||||
def _send_logout():
|
||||
try:
|
||||
from common.linkai_client import chat_client
|
||||
if chat_client.client_id:
|
||||
chat_client.send_logout()
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
logger.debug('[WX] ready to decorate reply: {}'.format(reply))
|
||||
|
||||
# reply的包装步骤
|
||||
if reply and reply.type:
|
||||
e_context = PluginManager().emit_event(EventContext(Event.ON_DECORATE_REPLY, {'channel' : self, 'context': context, 'reply': reply}))
|
||||
reply=e_context['reply']
|
||||
if not e_context.is_pass() and reply and reply.type:
|
||||
if reply.type == ReplyType.TEXT:
|
||||
reply_text = reply.content
|
||||
if context['isgroup']:
|
||||
reply_text = '@' + context['msg']['ActualNickName'] + ' ' + reply_text.strip()
|
||||
reply_text = conf().get("group_chat_reply_prefix", "")+reply_text
|
||||
else:
|
||||
reply_text = conf().get("single_chat_reply_prefix", "")+reply_text
|
||||
reply.content = reply_text
|
||||
elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO:
|
||||
reply.content = str(reply.type)+":\n" + reply.content
|
||||
elif reply.type == ReplyType.IMAGE_URL or reply.type == ReplyType.VOICE or reply.type == ReplyType.IMAGE:
|
||||
pass
|
||||
else:
|
||||
logger.error('[WX] unknown reply type: {}'.format(reply.type))
|
||||
return
|
||||
|
||||
# reply的发送步骤
|
||||
if reply and reply.type:
|
||||
e_context = PluginManager().emit_event(EventContext(Event.ON_SEND_REPLY, {'channel' : self, 'context': context, 'reply': reply}))
|
||||
reply=e_context['reply']
|
||||
if not e_context.is_pass() and reply and reply.type:
|
||||
logger.debug('[WX] ready to send reply: {} to {}'.format(reply, context['receiver']))
|
||||
self.send(reply, context['receiver'])
|
||||
|
||||
|
||||
def check_prefix(content, prefix_list):
|
||||
for prefix in prefix_list:
|
||||
if content.startswith(prefix):
|
||||
return prefix
|
||||
return None
|
||||
|
||||
|
||||
def check_contain(content, keyword_list):
|
||||
if not keyword_list:
|
||||
return None
|
||||
for ky in keyword_list:
|
||||
if content.find(ky) != -1:
|
||||
return True
|
||||
return None
|
||||
def _send_qr_code(qrcode_list: list):
|
||||
try:
|
||||
from common.linkai_client import chat_client
|
||||
if chat_client.client_id:
|
||||
chat_client.send_qrcode(qrcode_list)
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
102
channel/wechat/wechat_message.py
Normal file
102
channel/wechat/wechat_message.py
Normal file
@@ -0,0 +1,102 @@
|
||||
import re
|
||||
|
||||
from bridge.context import ContextType
|
||||
from channel.chat_message import ChatMessage
|
||||
from common.log import logger
|
||||
from common.tmp_dir import TmpDir
|
||||
from lib import itchat
|
||||
from lib.itchat.content import *
|
||||
|
||||
class WechatMessage(ChatMessage):
|
||||
def __init__(self, itchat_msg, is_group=False):
|
||||
super().__init__(itchat_msg)
|
||||
self.msg_id = itchat_msg["MsgId"]
|
||||
self.create_time = itchat_msg["CreateTime"]
|
||||
self.is_group = is_group
|
||||
|
||||
if itchat_msg["Type"] == TEXT:
|
||||
self.ctype = ContextType.TEXT
|
||||
self.content = itchat_msg["Text"]
|
||||
elif itchat_msg["Type"] == VOICE:
|
||||
self.ctype = ContextType.VOICE
|
||||
self.content = TmpDir().path() + itchat_msg["FileName"] # content直接存临时目录路径
|
||||
self._prepare_fn = lambda: itchat_msg.download(self.content)
|
||||
elif itchat_msg["Type"] == PICTURE and itchat_msg["MsgType"] == 3:
|
||||
self.ctype = ContextType.IMAGE
|
||||
self.content = TmpDir().path() + itchat_msg["FileName"] # content直接存临时目录路径
|
||||
self._prepare_fn = lambda: itchat_msg.download(self.content)
|
||||
elif itchat_msg["Type"] == NOTE and itchat_msg["MsgType"] == 10000:
|
||||
if is_group and ("加入群聊" in itchat_msg["Content"] or "加入了群聊" in itchat_msg["Content"]):
|
||||
# 这里只能得到nickname, actual_user_id还是机器人的id
|
||||
if "加入了群聊" in itchat_msg["Content"]:
|
||||
self.ctype = ContextType.JOIN_GROUP
|
||||
self.content = itchat_msg["Content"]
|
||||
self.actual_user_nickname = re.findall(r"\"(.*?)\"", itchat_msg["Content"])[-1]
|
||||
elif "加入群聊" in itchat_msg["Content"]:
|
||||
self.ctype = ContextType.JOIN_GROUP
|
||||
self.content = itchat_msg["Content"]
|
||||
self.actual_user_nickname = re.findall(r"\"(.*?)\"", itchat_msg["Content"])[0]
|
||||
|
||||
elif is_group and ("移出了群聊" in itchat_msg["Content"]):
|
||||
self.ctype = ContextType.EXIT_GROUP
|
||||
self.content = itchat_msg["Content"]
|
||||
self.actual_user_nickname = re.findall(r"\"(.*?)\"", itchat_msg["Content"])[0]
|
||||
|
||||
elif "你已添加了" in itchat_msg["Content"]: #通过好友请求
|
||||
self.ctype = ContextType.ACCEPT_FRIEND
|
||||
self.content = itchat_msg["Content"]
|
||||
elif "拍了拍我" in itchat_msg["Content"]:
|
||||
self.ctype = ContextType.PATPAT
|
||||
self.content = itchat_msg["Content"]
|
||||
if is_group:
|
||||
self.actual_user_nickname = re.findall(r"\"(.*?)\"", itchat_msg["Content"])[0]
|
||||
else:
|
||||
raise NotImplementedError("Unsupported note message: " + itchat_msg["Content"])
|
||||
elif itchat_msg["Type"] == ATTACHMENT:
|
||||
self.ctype = ContextType.FILE
|
||||
self.content = TmpDir().path() + itchat_msg["FileName"] # content直接存临时目录路径
|
||||
self._prepare_fn = lambda: itchat_msg.download(self.content)
|
||||
elif itchat_msg["Type"] == SHARING:
|
||||
self.ctype = ContextType.SHARING
|
||||
self.content = itchat_msg.get("Url")
|
||||
|
||||
else:
|
||||
raise NotImplementedError("Unsupported message type: Type:{} MsgType:{}".format(itchat_msg["Type"], itchat_msg["MsgType"]))
|
||||
|
||||
self.from_user_id = itchat_msg["FromUserName"]
|
||||
self.to_user_id = itchat_msg["ToUserName"]
|
||||
|
||||
user_id = itchat.instance.storageClass.userName
|
||||
nickname = itchat.instance.storageClass.nickName
|
||||
|
||||
# 虽然from_user_id和to_user_id用的少,但是为了保持一致性,还是要填充一下
|
||||
# 以下很繁琐,一句话总结:能填的都填了。
|
||||
if self.from_user_id == user_id:
|
||||
self.from_user_nickname = nickname
|
||||
if self.to_user_id == user_id:
|
||||
self.to_user_nickname = nickname
|
||||
try: # 陌生人时候, User字段可能不存在
|
||||
# my_msg 为True是表示是自己发送的消息
|
||||
self.my_msg = itchat_msg["ToUserName"] == itchat_msg["User"]["UserName"] and \
|
||||
itchat_msg["ToUserName"] != itchat_msg["FromUserName"]
|
||||
self.other_user_id = itchat_msg["User"]["UserName"]
|
||||
self.other_user_nickname = itchat_msg["User"]["NickName"]
|
||||
if self.other_user_id == self.from_user_id:
|
||||
self.from_user_nickname = self.other_user_nickname
|
||||
if self.other_user_id == self.to_user_id:
|
||||
self.to_user_nickname = self.other_user_nickname
|
||||
if itchat_msg["User"].get("Self"):
|
||||
# 自身的展示名,当设置了群昵称时,该字段表示群昵称
|
||||
self.self_display_name = itchat_msg["User"].get("Self").get("DisplayName")
|
||||
except KeyError as e: # 处理偶尔没有对方信息的情况
|
||||
logger.warn("[WX]get other_user_id failed: " + str(e))
|
||||
if self.from_user_id == user_id:
|
||||
self.other_user_id = self.to_user_id
|
||||
else:
|
||||
self.other_user_id = self.from_user_id
|
||||
|
||||
if self.is_group:
|
||||
self.is_at = itchat_msg["IsAt"]
|
||||
self.actual_user_id = itchat_msg["ActualUserName"]
|
||||
if self.ctype not in [ContextType.JOIN_GROUP, ContextType.PATPAT, ContextType.EXIT_GROUP]:
|
||||
self.actual_user_nickname = itchat_msg["ActualNickName"]
|
||||
@@ -4,289 +4,126 @@
|
||||
wechaty channel
|
||||
Python Wechaty - https://github.com/wechaty/python-wechaty
|
||||
"""
|
||||
import io
|
||||
import os
|
||||
import json
|
||||
import time
|
||||
import asyncio
|
||||
import requests
|
||||
import pysilk
|
||||
import wave
|
||||
from pydub import AudioSegment
|
||||
from typing import Optional, Union
|
||||
from bridge.context import Context, ContextType
|
||||
from wechaty_puppet import MessageType, FileBox, ScanStatus # type: ignore
|
||||
from wechaty import Wechaty, Contact
|
||||
from wechaty.user import Message, Room, MiniProgram, UrlLink
|
||||
from channel.channel import Channel
|
||||
import base64
|
||||
import os
|
||||
import time
|
||||
|
||||
from wechaty import Contact, Wechaty
|
||||
from wechaty.user import Message
|
||||
from wechaty_puppet import FileBox
|
||||
|
||||
from bridge.context import *
|
||||
from bridge.context import Context
|
||||
from bridge.reply import *
|
||||
from channel.chat_channel import ChatChannel
|
||||
from channel.wechat.wechaty_message import WechatyMessage
|
||||
from common.log import logger
|
||||
from common.tmp_dir import TmpDir
|
||||
from common.singleton import singleton
|
||||
from config import conf
|
||||
|
||||
try:
|
||||
from voice.audio_convert import any_to_sil
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
class WechatyChannel(Channel):
|
||||
|
||||
@singleton
|
||||
class WechatyChannel(ChatChannel):
|
||||
NOT_SUPPORT_REPLYTYPE = []
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
super().__init__()
|
||||
|
||||
def startup(self):
|
||||
config = conf()
|
||||
token = config.get("wechaty_puppet_service_token")
|
||||
os.environ["WECHATY_PUPPET_SERVICE_TOKEN"] = token
|
||||
asyncio.run(self.main())
|
||||
|
||||
async def main(self):
|
||||
config = conf()
|
||||
# 使用PadLocal协议 比较稳定(免费web协议 os.environ['WECHATY_PUPPET_SERVICE_ENDPOINT'] = '127.0.0.1:8080')
|
||||
token = config.get('wechaty_puppet_service_token')
|
||||
os.environ['WECHATY_PUPPET_SERVICE_TOKEN'] = token
|
||||
global bot
|
||||
bot = Wechaty()
|
||||
|
||||
bot.on('scan', self.on_scan)
|
||||
bot.on('login', self.on_login)
|
||||
bot.on('message', self.on_message)
|
||||
await bot.start()
|
||||
loop = asyncio.get_event_loop()
|
||||
# 将asyncio的loop传入处理线程
|
||||
self.handler_pool._initializer = lambda: asyncio.set_event_loop(loop)
|
||||
self.bot = Wechaty()
|
||||
self.bot.on("login", self.on_login)
|
||||
self.bot.on("message", self.on_message)
|
||||
await self.bot.start()
|
||||
|
||||
async def on_login(self, contact: Contact):
|
||||
logger.info('[WX] login user={}'.format(contact))
|
||||
self.user_id = contact.contact_id
|
||||
self.name = contact.name
|
||||
logger.info("[WX] login user={}".format(contact))
|
||||
|
||||
async def on_scan(self, status: ScanStatus, qr_code: Optional[str] = None,
|
||||
data: Optional[str] = None):
|
||||
contact = self.Contact.load(self.contact_id)
|
||||
logger.info('[WX] scan user={}, scan status={}, scan qr_code={}'.format(contact, status.name, qr_code))
|
||||
# print(f'user <{contact}> scan status: {status.name} , 'f'qr_code: {qr_code}')
|
||||
# 统一的发送函数,每个Channel自行实现,根据reply的type字段发送不同类型的消息
|
||||
def send(self, reply: Reply, context: Context):
|
||||
receiver_id = context["receiver"]
|
||||
loop = asyncio.get_event_loop()
|
||||
if context["isgroup"]:
|
||||
receiver = asyncio.run_coroutine_threadsafe(self.bot.Room.find(receiver_id), loop).result()
|
||||
else:
|
||||
receiver = asyncio.run_coroutine_threadsafe(self.bot.Contact.find(receiver_id), loop).result()
|
||||
msg = None
|
||||
if reply.type == ReplyType.TEXT:
|
||||
msg = reply.content
|
||||
asyncio.run_coroutine_threadsafe(receiver.say(msg), loop).result()
|
||||
logger.info("[WX] sendMsg={}, receiver={}".format(reply, receiver))
|
||||
elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO:
|
||||
msg = reply.content
|
||||
asyncio.run_coroutine_threadsafe(receiver.say(msg), loop).result()
|
||||
logger.info("[WX] sendMsg={}, receiver={}".format(reply, receiver))
|
||||
elif reply.type == ReplyType.VOICE:
|
||||
voiceLength = None
|
||||
file_path = reply.content
|
||||
sil_file = os.path.splitext(file_path)[0] + ".sil"
|
||||
voiceLength = int(any_to_sil(file_path, sil_file))
|
||||
if voiceLength >= 60000:
|
||||
voiceLength = 60000
|
||||
logger.info("[WX] voice too long, length={}, set to 60s".format(voiceLength))
|
||||
# 发送语音
|
||||
t = int(time.time())
|
||||
msg = FileBox.from_file(sil_file, name=str(t) + ".sil")
|
||||
if voiceLength is not None:
|
||||
msg.metadata["voiceLength"] = voiceLength
|
||||
asyncio.run_coroutine_threadsafe(receiver.say(msg), loop).result()
|
||||
try:
|
||||
os.remove(file_path)
|
||||
if sil_file != file_path:
|
||||
os.remove(sil_file)
|
||||
except Exception as e:
|
||||
pass
|
||||
logger.info("[WX] sendVoice={}, receiver={}".format(reply.content, receiver))
|
||||
elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片
|
||||
img_url = reply.content
|
||||
t = int(time.time())
|
||||
msg = FileBox.from_url(url=img_url, name=str(t) + ".png")
|
||||
asyncio.run_coroutine_threadsafe(receiver.say(msg), loop).result()
|
||||
logger.info("[WX] sendImage url={}, receiver={}".format(img_url, receiver))
|
||||
elif reply.type == ReplyType.IMAGE: # 从文件读取图片
|
||||
image_storage = reply.content
|
||||
image_storage.seek(0)
|
||||
t = int(time.time())
|
||||
msg = FileBox.from_base64(base64.b64encode(image_storage.read()), str(t) + ".png")
|
||||
asyncio.run_coroutine_threadsafe(receiver.say(msg), loop).result()
|
||||
logger.info("[WX] sendImage, receiver={}".format(receiver))
|
||||
|
||||
async def on_message(self, msg: Message):
|
||||
"""
|
||||
listen for message event
|
||||
"""
|
||||
from_contact = msg.talker() # 获取消息的发送者
|
||||
to_contact = msg.to() # 接收人
|
||||
room = msg.room() # 获取消息来自的群聊. 如果消息不是来自群聊, 则返回None
|
||||
from_user_id = from_contact.contact_id
|
||||
to_user_id = to_contact.contact_id # 接收人id
|
||||
# other_user_id = msg['User']['UserName'] # 对手方id
|
||||
content = msg.text()
|
||||
mention_content = await msg.mention_text() # 返回过滤掉@name后的消息
|
||||
match_prefix = self.check_prefix(content, conf().get('single_chat_prefix'))
|
||||
conversation: Union[Room, Contact] = from_contact if room is None else room
|
||||
|
||||
if room is None and msg.type() == MessageType.MESSAGE_TYPE_TEXT:
|
||||
if not msg.is_self() and match_prefix is not None:
|
||||
# 好友向自己发送消息
|
||||
if match_prefix != '':
|
||||
str_list = content.split(match_prefix, 1)
|
||||
if len(str_list) == 2:
|
||||
content = str_list[1].strip()
|
||||
|
||||
img_match_prefix = self.check_prefix(content, conf().get('image_create_prefix'))
|
||||
if img_match_prefix:
|
||||
content = content.split(img_match_prefix, 1)[1].strip()
|
||||
await self._do_send_img(content, from_user_id)
|
||||
else:
|
||||
await self._do_send(content, from_user_id)
|
||||
elif msg.is_self() and match_prefix:
|
||||
# 自己给好友发送消息
|
||||
str_list = content.split(match_prefix, 1)
|
||||
if len(str_list) == 2:
|
||||
content = str_list[1].strip()
|
||||
img_match_prefix = self.check_prefix(content, conf().get('image_create_prefix'))
|
||||
if img_match_prefix:
|
||||
content = content.split(img_match_prefix, 1)[1].strip()
|
||||
await self._do_send_img(content, to_user_id)
|
||||
else:
|
||||
await self._do_send(content, to_user_id)
|
||||
elif room is None and msg.type() == MessageType.MESSAGE_TYPE_AUDIO:
|
||||
if not msg.is_self(): # 接收语音消息
|
||||
# 下载语音文件
|
||||
voice_file = await msg.to_file_box()
|
||||
silk_file = TmpDir().path() + voice_file.name
|
||||
await voice_file.to_file(silk_file)
|
||||
logger.info("[WX]receive voice file: " + silk_file)
|
||||
# 将文件转成wav格式音频
|
||||
wav_file = silk_file.replace(".slk", ".wav")
|
||||
with open(silk_file, 'rb') as f:
|
||||
silk_data = f.read()
|
||||
pcm_data = pysilk.decode(silk_data)
|
||||
|
||||
with wave.open(wav_file, 'wb') as wav_data:
|
||||
wav_data.setnchannels(1)
|
||||
wav_data.setsampwidth(2)
|
||||
wav_data.setframerate(24000)
|
||||
wav_data.writeframes(pcm_data)
|
||||
if os.path.exists(wav_file):
|
||||
converter_state = "true" # 转换wav成功
|
||||
else:
|
||||
converter_state = "false" # 转换wav失败
|
||||
logger.info("[WX]receive voice converter: " + converter_state)
|
||||
# 语音识别为文本
|
||||
query = super().build_voice_to_text(wav_file).content
|
||||
# 交验关键字
|
||||
match_prefix = self.check_prefix(query, conf().get('single_chat_prefix'))
|
||||
if match_prefix is not None:
|
||||
if match_prefix != '':
|
||||
str_list = query.split(match_prefix, 1)
|
||||
if len(str_list) == 2:
|
||||
query = str_list[1].strip()
|
||||
# 返回消息
|
||||
if conf().get('voice_reply_voice'):
|
||||
await self._do_send_voice(query, from_user_id)
|
||||
else:
|
||||
await self._do_send(query, from_user_id)
|
||||
else:
|
||||
logger.info("[WX]receive voice check prefix: " + 'False')
|
||||
# 清除缓存文件
|
||||
os.remove(wav_file)
|
||||
os.remove(silk_file)
|
||||
elif room and msg.type() == MessageType.MESSAGE_TYPE_TEXT:
|
||||
# 群组&文本消息
|
||||
room_id = room.room_id
|
||||
room_name = await room.topic()
|
||||
from_user_id = from_contact.contact_id
|
||||
from_user_name = from_contact.name
|
||||
is_at = await msg.mention_self()
|
||||
content = mention_content
|
||||
config = conf()
|
||||
match_prefix = (is_at and not config.get("group_at_off", False)) \
|
||||
or self.check_prefix(content, config.get('group_chat_prefix')) \
|
||||
or self.check_contain(content, config.get('group_chat_keyword'))
|
||||
# Wechaty判断is_at为True,返回的内容是过滤掉@之后的内容;而is_at为False,则会返回完整的内容
|
||||
# 故判断如果匹配到自定义前缀,则返回过滤掉前缀+空格后的内容,用于实现类似自定义+前缀触发生成AI图片的功能
|
||||
prefixes = config.get('group_chat_prefix')
|
||||
for prefix in prefixes:
|
||||
if content.startswith(prefix):
|
||||
content = content.replace(prefix, '', 1).strip()
|
||||
break
|
||||
if ('ALL_GROUP' in config.get('group_name_white_list') or room_name in config.get(
|
||||
'group_name_white_list') or self.check_contain(room_name, config.get(
|
||||
'group_name_keyword_white_list'))) and match_prefix:
|
||||
img_match_prefix = self.check_prefix(content, conf().get('image_create_prefix'))
|
||||
if img_match_prefix:
|
||||
content = content.split(img_match_prefix, 1)[1].strip()
|
||||
await self._do_send_group_img(content, room_id)
|
||||
else:
|
||||
await self._do_send_group(content, room_id, room_name, from_user_id, from_user_name)
|
||||
|
||||
async def send(self, message: Union[str, Message, FileBox, Contact, UrlLink, MiniProgram], receiver):
|
||||
logger.info('[WX] sendMsg={}, receiver={}'.format(message, receiver))
|
||||
if receiver:
|
||||
contact = await bot.Contact.find(receiver)
|
||||
await contact.say(message)
|
||||
|
||||
async def send_group(self, message: Union[str, Message, FileBox, Contact, UrlLink, MiniProgram], receiver):
|
||||
logger.info('[WX] sendMsg={}, receiver={}'.format(message, receiver))
|
||||
if receiver:
|
||||
room = await bot.Room.find(receiver)
|
||||
await room.say(message)
|
||||
|
||||
async def _do_send(self, query, reply_user_id):
|
||||
try:
|
||||
if not query:
|
||||
return
|
||||
context = Context(ContextType.TEXT, query)
|
||||
context['session_id'] = reply_user_id
|
||||
reply_text = super().build_reply_content(query, context).content
|
||||
if reply_text:
|
||||
await self.send(conf().get("single_chat_reply_prefix") + reply_text, reply_user_id)
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
|
||||
|
||||
async def _do_send_voice(self, query, reply_user_id):
|
||||
try:
|
||||
if not query:
|
||||
return
|
||||
context = Context(ContextType.TEXT, query)
|
||||
context['session_id'] = reply_user_id
|
||||
reply_text = super().build_reply_content(query, context).content
|
||||
if reply_text:
|
||||
# 转换 mp3 文件为 silk 格式
|
||||
mp3_file = super().build_text_to_voice(reply_text).content
|
||||
silk_file = mp3_file.replace(".mp3", ".silk")
|
||||
# Load the MP3 file
|
||||
audio = AudioSegment.from_file(mp3_file, format="mp3")
|
||||
# Convert to WAV format
|
||||
audio = audio.set_frame_rate(24000).set_channels(1)
|
||||
wav_data = audio.raw_data
|
||||
sample_width = audio.sample_width
|
||||
# Encode to SILK format
|
||||
silk_data = pysilk.encode(wav_data, 24000)
|
||||
# Save the silk file
|
||||
with open(silk_file, "wb") as f:
|
||||
f.write(silk_data)
|
||||
# 发送语音
|
||||
t = int(time.time())
|
||||
file_box = FileBox.from_file(silk_file, name=str(t) + '.silk')
|
||||
await self.send(file_box, reply_user_id)
|
||||
# 清除缓存文件
|
||||
os.remove(mp3_file)
|
||||
os.remove(silk_file)
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
|
||||
async def _do_send_img(self, query, reply_user_id):
|
||||
try:
|
||||
if not query:
|
||||
return
|
||||
context = Context(ContextType.IMAGE_CREATE, query)
|
||||
img_url = super().build_reply_content(query, context).content
|
||||
if not img_url:
|
||||
return
|
||||
# 图片下载
|
||||
# pic_res = requests.get(img_url, stream=True)
|
||||
# image_storage = io.BytesIO()
|
||||
# for block in pic_res.iter_content(1024):
|
||||
# image_storage.write(block)
|
||||
# image_storage.seek(0)
|
||||
|
||||
# 图片发送
|
||||
logger.info('[WX] sendImage, receiver={}'.format(reply_user_id))
|
||||
t = int(time.time())
|
||||
file_box = FileBox.from_url(url=img_url, name=str(t) + '.png')
|
||||
await self.send(file_box, reply_user_id)
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
|
||||
async def _do_send_group(self, query, group_id, group_name, group_user_id, group_user_name):
|
||||
if not query:
|
||||
cmsg = await WechatyMessage(msg)
|
||||
except NotImplementedError as e:
|
||||
logger.debug("[WX] {}".format(e))
|
||||
return
|
||||
context = Context(ContextType.TEXT, query)
|
||||
group_chat_in_one_session = conf().get('group_chat_in_one_session', [])
|
||||
if ('ALL_GROUP' in group_chat_in_one_session or \
|
||||
group_name in group_chat_in_one_session or \
|
||||
self.check_contain(group_name, group_chat_in_one_session)):
|
||||
context['session_id'] = str(group_id)
|
||||
else:
|
||||
context['session_id'] = str(group_id) + '-' + str(group_user_id)
|
||||
reply_text = super().build_reply_content(query, context).content
|
||||
if reply_text:
|
||||
reply_text = '@' + group_user_name + ' ' + reply_text.strip()
|
||||
await self.send_group(conf().get("group_chat_reply_prefix", "") + reply_text, group_id)
|
||||
|
||||
async def _do_send_group_img(self, query, reply_room_id):
|
||||
try:
|
||||
if not query:
|
||||
return
|
||||
context = Context(ContextType.IMAGE_CREATE, query)
|
||||
img_url = super().build_reply_content(query, context).content
|
||||
if not img_url:
|
||||
return
|
||||
# 图片发送
|
||||
logger.info('[WX] sendImage, receiver={}'.format(reply_room_id))
|
||||
t = int(time.time())
|
||||
file_box = FileBox.from_url(url=img_url, name=str(t) + '.png')
|
||||
await self.send_group(file_box, reply_room_id)
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
|
||||
def check_prefix(self, content, prefix_list):
|
||||
for prefix in prefix_list:
|
||||
if content.startswith(prefix):
|
||||
return prefix
|
||||
return None
|
||||
|
||||
def check_contain(self, content, keyword_list):
|
||||
if not keyword_list:
|
||||
return None
|
||||
for ky in keyword_list:
|
||||
if content.find(ky) != -1:
|
||||
return True
|
||||
return None
|
||||
logger.exception("[WX] {}".format(e))
|
||||
return
|
||||
logger.debug("[WX] message:{}".format(cmsg))
|
||||
room = msg.room() # 获取消息来自的群聊. 如果消息不是来自群聊, 则返回None
|
||||
isgroup = room is not None
|
||||
ctype = cmsg.ctype
|
||||
context = self._compose_context(ctype, cmsg.content, isgroup=isgroup, msg=cmsg)
|
||||
if context:
|
||||
logger.info("[WX] receiveMsg={}, context={}".format(cmsg, context))
|
||||
self.produce(context)
|
||||
|
||||
89
channel/wechat/wechaty_message.py
Normal file
89
channel/wechat/wechaty_message.py
Normal file
@@ -0,0 +1,89 @@
|
||||
import asyncio
|
||||
import re
|
||||
|
||||
from wechaty import MessageType
|
||||
from wechaty.user import Message
|
||||
|
||||
from bridge.context import ContextType
|
||||
from channel.chat_message import ChatMessage
|
||||
from common.log import logger
|
||||
from common.tmp_dir import TmpDir
|
||||
|
||||
|
||||
class aobject(object):
|
||||
"""Inheriting this class allows you to define an async __init__.
|
||||
|
||||
So you can create objects by doing something like `await MyClass(params)`
|
||||
"""
|
||||
|
||||
async def __new__(cls, *a, **kw):
|
||||
instance = super().__new__(cls)
|
||||
await instance.__init__(*a, **kw)
|
||||
return instance
|
||||
|
||||
async def __init__(self):
|
||||
pass
|
||||
|
||||
|
||||
class WechatyMessage(ChatMessage, aobject):
|
||||
async def __init__(self, wechaty_msg: Message):
|
||||
super().__init__(wechaty_msg)
|
||||
|
||||
room = wechaty_msg.room()
|
||||
|
||||
self.msg_id = wechaty_msg.message_id
|
||||
self.create_time = wechaty_msg.payload.timestamp
|
||||
self.is_group = room is not None
|
||||
|
||||
if wechaty_msg.type() == MessageType.MESSAGE_TYPE_TEXT:
|
||||
self.ctype = ContextType.TEXT
|
||||
self.content = wechaty_msg.text()
|
||||
elif wechaty_msg.type() == MessageType.MESSAGE_TYPE_AUDIO:
|
||||
self.ctype = ContextType.VOICE
|
||||
voice_file = await wechaty_msg.to_file_box()
|
||||
self.content = TmpDir().path() + voice_file.name # content直接存临时目录路径
|
||||
|
||||
def func():
|
||||
loop = asyncio.get_event_loop()
|
||||
asyncio.run_coroutine_threadsafe(voice_file.to_file(self.content), loop).result()
|
||||
|
||||
self._prepare_fn = func
|
||||
|
||||
else:
|
||||
raise NotImplementedError("Unsupported message type: {}".format(wechaty_msg.type()))
|
||||
|
||||
from_contact = wechaty_msg.talker() # 获取消息的发送者
|
||||
self.from_user_id = from_contact.contact_id
|
||||
self.from_user_nickname = from_contact.name
|
||||
|
||||
# group中的from和to,wechaty跟itchat含义不一样
|
||||
# wecahty: from是消息实际发送者, to:所在群
|
||||
# itchat: 如果是你发送群消息,from和to是你自己和所在群,如果是别人发群消息,from和to是所在群和你自己
|
||||
# 但这个差别不影响逻辑,group中只使用到:1.用from来判断是否是自己发的,2.actual_user_id来判断实际发送用户
|
||||
|
||||
if self.is_group:
|
||||
self.to_user_id = room.room_id
|
||||
self.to_user_nickname = await room.topic()
|
||||
else:
|
||||
to_contact = wechaty_msg.to()
|
||||
self.to_user_id = to_contact.contact_id
|
||||
self.to_user_nickname = to_contact.name
|
||||
|
||||
if self.is_group or wechaty_msg.is_self(): # 如果是群消息,other_user设置为群,如果是私聊消息,而且自己发的,就设置成对方。
|
||||
self.other_user_id = self.to_user_id
|
||||
self.other_user_nickname = self.to_user_nickname
|
||||
else:
|
||||
self.other_user_id = self.from_user_id
|
||||
self.other_user_nickname = self.from_user_nickname
|
||||
|
||||
if self.is_group: # wechaty群聊中,实际发送用户就是from_user
|
||||
self.is_at = await wechaty_msg.mention_self()
|
||||
if not self.is_at: # 有时候复制粘贴的消息,不算做@,但是内容里面会有@xxx,这里做一下兼容
|
||||
name = wechaty_msg.wechaty.user_self().name
|
||||
pattern = f"@{re.escape(name)}(\u2005|\u0020)"
|
||||
if re.search(pattern, self.content):
|
||||
logger.debug(f"wechaty message {self.msg_id} include at")
|
||||
self.is_at = True
|
||||
|
||||
self.actual_user_id = self.from_user_id
|
||||
self.actual_user_nickname = self.from_user_nickname
|
||||
85
channel/wechatcom/README.md
Normal file
85
channel/wechatcom/README.md
Normal file
@@ -0,0 +1,85 @@
|
||||
# 企业微信应用号channel
|
||||
|
||||
企业微信官方提供了客服、应用等API,本channel使用的是企业微信的自建应用API的能力。
|
||||
|
||||
因为未来可能还会开发客服能力,所以本channel的类型名叫作`wechatcom_app`。
|
||||
|
||||
`wechatcom_app` channel支持插件系统和图片声音交互等能力,除了无法加入群聊,作为个人使用的私人助理已绰绰有余。
|
||||
|
||||
## 开始之前
|
||||
|
||||
- 在企业中确认自己拥有在企业内自建应用的权限。
|
||||
- 如果没有权限或者是个人用户,也可创建未认证的企业。操作方式:登录手机企业微信,选择`创建/加入企业`来创建企业,类型请选择企业,企业名称可随意填写。
|
||||
未认证的企业有100人的服务人数上限,其他功能与认证企业没有差异。
|
||||
|
||||
本channel需安装的依赖与公众号一致,需要安装`wechatpy`和`web.py`,它们包含在`requirements-optional.txt`中。
|
||||
|
||||
此外,如果你是`Linux`系统,除了`ffmpeg`还需要安装`amr`编码器,否则会出现找不到编码器的错误,无法正常使用语音功能。
|
||||
|
||||
- Ubuntu/Debian
|
||||
|
||||
```bash
|
||||
apt-get install libavcodec-extra
|
||||
```
|
||||
|
||||
- Alpine
|
||||
|
||||
需自行编译`ffmpeg`,在编译参数里加入`amr`编码器的支持
|
||||
|
||||
## 使用方法
|
||||
|
||||
1.查看企业ID
|
||||
|
||||
- 扫码登陆[企业微信后台](https://work.weixin.qq.com)
|
||||
- 选择`我的企业`,点击`企业信息`,记住该`企业ID`
|
||||
|
||||
2.创建自建应用
|
||||
|
||||
- 选择应用管理, 在自建区选创建应用来创建企业自建应用
|
||||
- 上传应用logo,填写应用名称等项
|
||||
- 创建应用后进入应用详情页面,记住`AgentId`和`Secert`
|
||||
|
||||
3.配置应用
|
||||
|
||||
- 在详情页点击`企业可信IP`的配置(没看到可以不管),填入你服务器的公网IP,如果不知道可以先不填
|
||||
- 点击`接收消息`下的启用API接收消息
|
||||
- `URL`填写格式为`http://url:port/wxcomapp`,`port`是程序监听的端口,默认是9898
|
||||
如果是未认证的企业,url可直接使用服务器的IP。如果是认证企业,需要使用备案的域名,可使用二级域名。
|
||||
- `Token`可随意填写,停留在这个页面
|
||||
- 在程序根目录`config.json`中增加配置(**去掉注释**),`wechatcomapp_aes_key`是当前页面的`wechatcomapp_aes_key`
|
||||
|
||||
```python
|
||||
"channel_type": "wechatcom_app",
|
||||
"wechatcom_corp_id": "", # 企业微信公司的corpID
|
||||
"wechatcomapp_token": "", # 企业微信app的token
|
||||
"wechatcomapp_port": 9898, # 企业微信app的服务端口, 不需要端口转发
|
||||
"wechatcomapp_secret": "", # 企业微信app的secret
|
||||
"wechatcomapp_agent_id": "", # 企业微信app的agent_id
|
||||
"wechatcomapp_aes_key": "", # 企业微信app的aes_key
|
||||
```
|
||||
|
||||
- 运行程序,在页面中点击保存,保存成功说明验证成功
|
||||
|
||||
4.连接个人微信
|
||||
|
||||
选择`我的企业`,点击`微信插件`,下面有个邀请关注的二维码。微信扫码后,即可在微信中看到对应企业,在这里你便可以和机器人沟通。
|
||||
|
||||
向机器人发送消息,如果日志里出现报错:
|
||||
|
||||
```bash
|
||||
Error code: 60020, message: "not allow to access from your ip, ...from ip: xx.xx.xx.xx"
|
||||
```
|
||||
|
||||
意思是IP不可信,需要参考上一步的`企业可信IP`配置,把这里的IP加进去。
|
||||
|
||||
~~### Railway部署方式~~(2023-06-08已失效)
|
||||
|
||||
~~公众号不能在`Railway`上部署,但企业微信应用[可以](https://railway.app/template/-FHS--?referralCode=RC3znh)!~~
|
||||
|
||||
~~填写配置后,将部署完成后的网址```**.railway.app/wxcomapp```,填写在上一步的URL中。发送信息后观察日志,把报错的IP加入到可信IP。(每次重启后都需要加入可信IP)~~
|
||||
|
||||
## 测试体验
|
||||
|
||||
AIGC开放社区中已经部署了多个可免费使用的Bot,扫描下方的二维码会自动邀请你来体验。
|
||||
|
||||
<img width="200" src="../../docs/images/aigcopen.png">
|
||||
178
channel/wechatcom/wechatcomapp_channel.py
Normal file
178
channel/wechatcom/wechatcomapp_channel.py
Normal file
@@ -0,0 +1,178 @@
|
||||
# -*- coding=utf-8 -*-
|
||||
import io
|
||||
import os
|
||||
import time
|
||||
|
||||
import requests
|
||||
import web
|
||||
from wechatpy.enterprise import create_reply, parse_message
|
||||
from wechatpy.enterprise.crypto import WeChatCrypto
|
||||
from wechatpy.enterprise.exceptions import InvalidCorpIdException
|
||||
from wechatpy.exceptions import InvalidSignatureException, WeChatClientException
|
||||
|
||||
from bridge.context import Context
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from channel.chat_channel import ChatChannel
|
||||
from channel.wechatcom.wechatcomapp_client import WechatComAppClient
|
||||
from channel.wechatcom.wechatcomapp_message import WechatComAppMessage
|
||||
from common.log import logger
|
||||
from common.singleton import singleton
|
||||
from common.utils import compress_imgfile, fsize, split_string_by_utf8_length
|
||||
from config import conf, subscribe_msg
|
||||
from voice.audio_convert import any_to_amr, split_audio
|
||||
|
||||
MAX_UTF8_LEN = 2048
|
||||
|
||||
|
||||
@singleton
|
||||
class WechatComAppChannel(ChatChannel):
|
||||
NOT_SUPPORT_REPLYTYPE = []
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.corp_id = conf().get("wechatcom_corp_id")
|
||||
self.secret = conf().get("wechatcomapp_secret")
|
||||
self.agent_id = conf().get("wechatcomapp_agent_id")
|
||||
self.token = conf().get("wechatcomapp_token")
|
||||
self.aes_key = conf().get("wechatcomapp_aes_key")
|
||||
print(self.corp_id, self.secret, self.agent_id, self.token, self.aes_key)
|
||||
logger.info(
|
||||
"[wechatcom] init: corp_id: {}, secret: {}, agent_id: {}, token: {}, aes_key: {}".format(self.corp_id, self.secret, self.agent_id, self.token, self.aes_key)
|
||||
)
|
||||
self.crypto = WeChatCrypto(self.token, self.aes_key, self.corp_id)
|
||||
self.client = WechatComAppClient(self.corp_id, self.secret)
|
||||
|
||||
def startup(self):
|
||||
# start message listener
|
||||
urls = ("/wxcomapp/?", "channel.wechatcom.wechatcomapp_channel.Query")
|
||||
app = web.application(urls, globals(), autoreload=False)
|
||||
port = conf().get("wechatcomapp_port", 9898)
|
||||
web.httpserver.runsimple(app.wsgifunc(), ("0.0.0.0", port))
|
||||
|
||||
def send(self, reply: Reply, context: Context):
|
||||
receiver = context["receiver"]
|
||||
if reply.type in [ReplyType.TEXT, ReplyType.ERROR, ReplyType.INFO]:
|
||||
reply_text = reply.content
|
||||
texts = split_string_by_utf8_length(reply_text, MAX_UTF8_LEN)
|
||||
if len(texts) > 1:
|
||||
logger.info("[wechatcom] text too long, split into {} parts".format(len(texts)))
|
||||
for i, text in enumerate(texts):
|
||||
self.client.message.send_text(self.agent_id, receiver, text)
|
||||
if i != len(texts) - 1:
|
||||
time.sleep(0.5) # 休眠0.5秒,防止发送过快乱序
|
||||
logger.info("[wechatcom] Do send text to {}: {}".format(receiver, reply_text))
|
||||
elif reply.type == ReplyType.VOICE:
|
||||
try:
|
||||
media_ids = []
|
||||
file_path = reply.content
|
||||
amr_file = os.path.splitext(file_path)[0] + ".amr"
|
||||
any_to_amr(file_path, amr_file)
|
||||
duration, files = split_audio(amr_file, 60 * 1000)
|
||||
if len(files) > 1:
|
||||
logger.info("[wechatcom] voice too long {}s > 60s , split into {} parts".format(duration / 1000.0, len(files)))
|
||||
for path in files:
|
||||
response = self.client.media.upload("voice", open(path, "rb"))
|
||||
logger.debug("[wechatcom] upload voice response: {}".format(response))
|
||||
media_ids.append(response["media_id"])
|
||||
except WeChatClientException as e:
|
||||
logger.error("[wechatcom] upload voice failed: {}".format(e))
|
||||
return
|
||||
try:
|
||||
os.remove(file_path)
|
||||
if amr_file != file_path:
|
||||
os.remove(amr_file)
|
||||
except Exception:
|
||||
pass
|
||||
for media_id in media_ids:
|
||||
self.client.message.send_voice(self.agent_id, receiver, media_id)
|
||||
time.sleep(1)
|
||||
logger.info("[wechatcom] sendVoice={}, receiver={}".format(reply.content, receiver))
|
||||
elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片
|
||||
img_url = reply.content
|
||||
pic_res = requests.get(img_url, stream=True)
|
||||
image_storage = io.BytesIO()
|
||||
for block in pic_res.iter_content(1024):
|
||||
image_storage.write(block)
|
||||
sz = fsize(image_storage)
|
||||
if sz >= 10 * 1024 * 1024:
|
||||
logger.info("[wechatcom] image too large, ready to compress, sz={}".format(sz))
|
||||
image_storage = compress_imgfile(image_storage, 10 * 1024 * 1024 - 1)
|
||||
logger.info("[wechatcom] image compressed, sz={}".format(fsize(image_storage)))
|
||||
image_storage.seek(0)
|
||||
try:
|
||||
response = self.client.media.upload("image", image_storage)
|
||||
logger.debug("[wechatcom] upload image response: {}".format(response))
|
||||
except WeChatClientException as e:
|
||||
logger.error("[wechatcom] upload image failed: {}".format(e))
|
||||
return
|
||||
|
||||
self.client.message.send_image(self.agent_id, receiver, response["media_id"])
|
||||
logger.info("[wechatcom] sendImage url={}, receiver={}".format(img_url, receiver))
|
||||
elif reply.type == ReplyType.IMAGE: # 从文件读取图片
|
||||
image_storage = reply.content
|
||||
sz = fsize(image_storage)
|
||||
if sz >= 10 * 1024 * 1024:
|
||||
logger.info("[wechatcom] image too large, ready to compress, sz={}".format(sz))
|
||||
image_storage = compress_imgfile(image_storage, 10 * 1024 * 1024 - 1)
|
||||
logger.info("[wechatcom] image compressed, sz={}".format(fsize(image_storage)))
|
||||
image_storage.seek(0)
|
||||
try:
|
||||
response = self.client.media.upload("image", image_storage)
|
||||
logger.debug("[wechatcom] upload image response: {}".format(response))
|
||||
except WeChatClientException as e:
|
||||
logger.error("[wechatcom] upload image failed: {}".format(e))
|
||||
return
|
||||
self.client.message.send_image(self.agent_id, receiver, response["media_id"])
|
||||
logger.info("[wechatcom] sendImage, receiver={}".format(receiver))
|
||||
|
||||
|
||||
class Query:
|
||||
def GET(self):
|
||||
channel = WechatComAppChannel()
|
||||
params = web.input()
|
||||
logger.info("[wechatcom] receive params: {}".format(params))
|
||||
try:
|
||||
signature = params.msg_signature
|
||||
timestamp = params.timestamp
|
||||
nonce = params.nonce
|
||||
echostr = params.echostr
|
||||
echostr = channel.crypto.check_signature(signature, timestamp, nonce, echostr)
|
||||
except InvalidSignatureException:
|
||||
raise web.Forbidden()
|
||||
return echostr
|
||||
|
||||
def POST(self):
|
||||
channel = WechatComAppChannel()
|
||||
params = web.input()
|
||||
logger.info("[wechatcom] receive params: {}".format(params))
|
||||
try:
|
||||
signature = params.msg_signature
|
||||
timestamp = params.timestamp
|
||||
nonce = params.nonce
|
||||
message = channel.crypto.decrypt_message(web.data(), signature, timestamp, nonce)
|
||||
except (InvalidSignatureException, InvalidCorpIdException):
|
||||
raise web.Forbidden()
|
||||
msg = parse_message(message)
|
||||
logger.debug("[wechatcom] receive message: {}, msg= {}".format(message, msg))
|
||||
if msg.type == "event":
|
||||
if msg.event == "subscribe":
|
||||
reply_content = subscribe_msg()
|
||||
if reply_content:
|
||||
reply = create_reply(reply_content, msg).render()
|
||||
res = channel.crypto.encrypt_message(reply, nonce, timestamp)
|
||||
return res
|
||||
else:
|
||||
try:
|
||||
wechatcom_msg = WechatComAppMessage(msg, client=channel.client)
|
||||
except NotImplementedError as e:
|
||||
logger.debug("[wechatcom] " + str(e))
|
||||
return "success"
|
||||
context = channel._compose_context(
|
||||
wechatcom_msg.ctype,
|
||||
wechatcom_msg.content,
|
||||
isgroup=False,
|
||||
msg=wechatcom_msg,
|
||||
)
|
||||
if context:
|
||||
channel.produce(context)
|
||||
return "success"
|
||||
21
channel/wechatcom/wechatcomapp_client.py
Normal file
21
channel/wechatcom/wechatcomapp_client.py
Normal file
@@ -0,0 +1,21 @@
|
||||
import threading
|
||||
import time
|
||||
|
||||
from wechatpy.enterprise import WeChatClient
|
||||
|
||||
|
||||
class WechatComAppClient(WeChatClient):
|
||||
def __init__(self, corp_id, secret, access_token=None, session=None, timeout=None, auto_retry=True):
|
||||
super(WechatComAppClient, self).__init__(corp_id, secret, access_token, session, timeout, auto_retry)
|
||||
self.fetch_access_token_lock = threading.Lock()
|
||||
|
||||
def fetch_access_token(self): # 重载父类方法,加锁避免多线程重复获取access_token
|
||||
with self.fetch_access_token_lock:
|
||||
access_token = self.session.get(self.access_token_key)
|
||||
if access_token:
|
||||
if not self.expires_at:
|
||||
return access_token
|
||||
timestamp = time.time()
|
||||
if self.expires_at - timestamp > 60:
|
||||
return access_token
|
||||
return super().fetch_access_token()
|
||||
52
channel/wechatcom/wechatcomapp_message.py
Normal file
52
channel/wechatcom/wechatcomapp_message.py
Normal file
@@ -0,0 +1,52 @@
|
||||
from wechatpy.enterprise import WeChatClient
|
||||
|
||||
from bridge.context import ContextType
|
||||
from channel.chat_message import ChatMessage
|
||||
from common.log import logger
|
||||
from common.tmp_dir import TmpDir
|
||||
|
||||
|
||||
class WechatComAppMessage(ChatMessage):
|
||||
def __init__(self, msg, client: WeChatClient, is_group=False):
|
||||
super().__init__(msg)
|
||||
self.msg_id = msg.id
|
||||
self.create_time = msg.time
|
||||
self.is_group = is_group
|
||||
|
||||
if msg.type == "text":
|
||||
self.ctype = ContextType.TEXT
|
||||
self.content = msg.content
|
||||
elif msg.type == "voice":
|
||||
self.ctype = ContextType.VOICE
|
||||
self.content = TmpDir().path() + msg.media_id + "." + msg.format # content直接存临时目录路径
|
||||
|
||||
def download_voice():
|
||||
# 如果响应状态码是200,则将响应内容写入本地文件
|
||||
response = client.media.download(msg.media_id)
|
||||
if response.status_code == 200:
|
||||
with open(self.content, "wb") as f:
|
||||
f.write(response.content)
|
||||
else:
|
||||
logger.info(f"[wechatcom] Failed to download voice file, {response.content}")
|
||||
|
||||
self._prepare_fn = download_voice
|
||||
elif msg.type == "image":
|
||||
self.ctype = ContextType.IMAGE
|
||||
self.content = TmpDir().path() + msg.media_id + ".png" # content直接存临时目录路径
|
||||
|
||||
def download_image():
|
||||
# 如果响应状态码是200,则将响应内容写入本地文件
|
||||
response = client.media.download(msg.media_id)
|
||||
if response.status_code == 200:
|
||||
with open(self.content, "wb") as f:
|
||||
f.write(response.content)
|
||||
else:
|
||||
logger.info(f"[wechatcom] Failed to download image file, {response.content}")
|
||||
|
||||
self._prepare_fn = download_image
|
||||
else:
|
||||
raise NotImplementedError("Unsupported message type: Type:{} ".format(msg.type))
|
||||
|
||||
self.from_user_id = msg.source
|
||||
self.to_user_id = msg.target
|
||||
self.other_user_id = msg.source
|
||||
100
channel/wechatmp/README.md
Normal file
100
channel/wechatmp/README.md
Normal file
@@ -0,0 +1,100 @@
|
||||
# 微信公众号channel
|
||||
|
||||
鉴于个人微信号在服务器上通过itchat登录有封号风险,这里新增了微信公众号channel,提供无风险的服务。
|
||||
目前支持订阅号和服务号两种类型的公众号,它们都支持文本交互,语音和图片输入。其中个人主体的微信订阅号由于无法通过微信认证,存在回复时间限制,每天的图片和声音回复次数也有限制。
|
||||
|
||||
## 使用方法(订阅号,服务号类似)
|
||||
|
||||
在开始部署前,你需要一个拥有公网IP的服务器,以提供微信服务器和我们自己服务器的连接。或者你需要进行内网穿透,否则微信服务器无法将消息发送给我们的服务器。
|
||||
|
||||
此外,需要在我们的服务器上安装python的web框架web.py和wechatpy。
|
||||
以ubuntu为例(在ubuntu 22.04上测试):
|
||||
```
|
||||
pip3 install web.py
|
||||
pip3 install wechatpy
|
||||
```
|
||||
|
||||
然后在[微信公众平台](https://mp.weixin.qq.com)注册一个自己的公众号,类型选择订阅号,主体为个人即可。
|
||||
|
||||
然后根据[接入指南](https://developers.weixin.qq.com/doc/offiaccount/Basic_Information/Access_Overview.html)的说明,在[微信公众平台](https://mp.weixin.qq.com)的“设置与开发”-“基本配置”-“服务器配置”中填写服务器地址`URL`和令牌`Token`。`URL`填写格式为`http://url/wx`,可使用IP(成功几率看脸),`Token`是你自己编的一个特定的令牌。消息加解密方式如果选择了需要加密的模式,需要在配置中填写`wechatmp_aes_key`。
|
||||
|
||||
相关的服务器验证代码已经写好,你不需要再添加任何代码。你只需要在本项目根目录的`config.json`中添加
|
||||
```
|
||||
"channel_type": "wechatmp", # 如果通过了微信认证,将"wechatmp"替换为"wechatmp_service",可极大的优化使用体验
|
||||
"wechatmp_token": "xxxx", # 微信公众平台的Token
|
||||
"wechatmp_port": 8080, # 微信公众平台的端口,需要端口转发到80或443
|
||||
"wechatmp_app_id": "xxxx", # 微信公众平台的appID
|
||||
"wechatmp_app_secret": "xxxx", # 微信公众平台的appsecret
|
||||
"wechatmp_aes_key": "", # 微信公众平台的EncodingAESKey,加密模式需要
|
||||
"single_chat_prefix": [""], # 推荐设置,任意对话都可以触发回复,不添加前缀
|
||||
"single_chat_reply_prefix": "", # 推荐设置,回复不设置前缀
|
||||
"plugin_trigger_prefix": "&", # 推荐设置,在手机微信客户端中,$%^等符号与中文连在一起时会自动显示一段较大的间隔,用户体验不好。请不要使用管理员指令前缀"#",这会造成未知问题。
|
||||
```
|
||||
然后运行`python3 app.py`启动web服务器。这里会默认监听8080端口,但是微信公众号的服务器配置只支持80/443端口,有两种方法来解决这个问题。第一个是推荐的方法,使用端口转发命令将80端口转发到8080端口:
|
||||
```
|
||||
sudo iptables -t nat -A PREROUTING -p tcp --dport 80 -j REDIRECT --to-port 8080
|
||||
sudo iptables-save > /etc/iptables/rules.v4
|
||||
```
|
||||
第二个方法是让python程序直接监听80端口,在配置文件中设置`"wechatmp_port": 80` ,在linux上需要使用`sudo python3 app.py`启动程序。然而这会导致一系列环境和权限问题,因此不是推荐的方法。
|
||||
|
||||
443端口同理,注意需要支持SSL,也就是https的访问,在`wechatmp_channel.py`中需要修改相应的证书路径。
|
||||
|
||||
程序启动并监听端口后,在刚才的“服务器配置”中点击`提交`即可验证你的服务器。
|
||||
随后在[微信公众平台](https://mp.weixin.qq.com)启用服务器,关闭手动填写规则的自动回复,即可实现ChatGPT的自动回复。
|
||||
|
||||
之后需要在公众号开发信息下将本机IP加入到IP白名单。
|
||||
|
||||
不然在启用后,发送语音、图片等消息可能会遇到如下报错:
|
||||
```
|
||||
'errcode': 40164, 'errmsg': 'invalid ip xx.xx.xx.xx not in whitelist rid
|
||||
```
|
||||
|
||||
|
||||
## 个人微信公众号的限制
|
||||
由于人微信公众号不能通过微信认证,所以没有客服接口,因此公众号无法主动发出消息,只能被动回复。而微信官方对被动回复有5秒的时间限制,最多重试2次,因此最多只有15秒的自动回复时间窗口。因此如果问题比较复杂或者我们的服务器比较忙,ChatGPT的回答就没办法及时回复给用户。为了解决这个问题,这里做了回答缓存,它需要你在回复超时后,再次主动发送任意文字(例如1)来尝试拿到回答缓存。为了优化使用体验,目前设置了两分钟(120秒)的timeout,用户在至多两分钟后即可得到查询到回复或者错误原因。
|
||||
|
||||
另外,由于微信官方的限制,自动回复有长度限制。因此这里将ChatGPT的回答进行了拆分,以满足限制。
|
||||
|
||||
## 私有api_key
|
||||
公共api有访问频率限制(免费账号每分钟最多3次ChatGPT的API调用),这在服务多人的时候会遇到问题。因此这里多加了一个设置私有api_key的功能。目前通过godcmd插件的命令来设置私有api_key。
|
||||
|
||||
## 语音输入
|
||||
利用微信自带的语音识别功能,提供语音输入能力。需要在公众号管理页面的“设置与开发”->“接口权限”页面开启“接收语音识别结果”。
|
||||
|
||||
## 语音回复
|
||||
请在配置文件中添加以下词条:
|
||||
```
|
||||
"voice_reply_voice": true,
|
||||
```
|
||||
这样公众号将会用语音回复语音消息,实现语音对话。
|
||||
|
||||
默认的语音合成引擎是`google`,它是免费使用的。
|
||||
|
||||
如果要选择其他的语音合成引擎,请添加以下配置项:
|
||||
```
|
||||
"text_to_voice": "pytts"
|
||||
```
|
||||
|
||||
pytts是本地的语音合成引擎。还支持baidu,azure,这些你需要自行配置相关的依赖和key。
|
||||
|
||||
如果使用pytts,在ubuntu上需要安装如下依赖:
|
||||
```
|
||||
sudo apt update
|
||||
sudo apt install espeak
|
||||
sudo apt install ffmpeg
|
||||
python3 -m pip install pyttsx3
|
||||
```
|
||||
不是很建议开启pytts语音回复,因为它是离线本地计算,算的慢会拖垮服务器,且声音不好听。
|
||||
|
||||
## 图片回复
|
||||
现在认证公众号和非认证公众号都可以实现的图片和语音回复。但是非认证公众号使用了永久素材接口,每天有1000次的调用上限(每个月有10次重置机会,程序中已设定遇到上限会自动重置),且永久素材库存也有上限。因此对于非认证公众号,我们会在回复图片或者语音消息后的10秒内从永久素材库存内删除该素材。
|
||||
|
||||
## 测试
|
||||
目前在`RoboStyle`这个公众号上进行了测试(基于[wechatmp分支](https://github.com/JS00000/chatgpt-on-wechat/tree/wechatmp)),感兴趣的可以关注并体验。开启了godcmd, Banwords, role, dungeon, finish这五个插件,其他的插件还没有详尽测试。百度的接口暂未测试。[wechatmp-stable分支](https://github.com/JS00000/chatgpt-on-wechat/tree/wechatmp-stable)是较稳定的上个版本,但也缺少最新的功能支持。
|
||||
|
||||
## TODO
|
||||
- [x] 语音输入
|
||||
- [x] 图片输入
|
||||
- [x] 使用临时素材接口提供认证公众号的图片和语音回复
|
||||
- [x] 使用永久素材接口提供未认证公众号的图片和语音回复
|
||||
- [ ] 高并发支持
|
||||
75
channel/wechatmp/active_reply.py
Normal file
75
channel/wechatmp/active_reply.py
Normal file
@@ -0,0 +1,75 @@
|
||||
import time
|
||||
|
||||
import web
|
||||
from wechatpy import parse_message
|
||||
from wechatpy.replies import create_reply
|
||||
|
||||
from bridge.context import *
|
||||
from bridge.reply import *
|
||||
from channel.wechatmp.common import *
|
||||
from channel.wechatmp.wechatmp_channel import WechatMPChannel
|
||||
from channel.wechatmp.wechatmp_message import WeChatMPMessage
|
||||
from common.log import logger
|
||||
from config import conf, subscribe_msg
|
||||
|
||||
|
||||
# This class is instantiated once per query
|
||||
class Query:
|
||||
def GET(self):
|
||||
return verify_server(web.input())
|
||||
|
||||
def POST(self):
|
||||
# Make sure to return the instance that first created, @singleton will do that.
|
||||
try:
|
||||
args = web.input()
|
||||
verify_server(args)
|
||||
channel = WechatMPChannel()
|
||||
message = web.data()
|
||||
encrypt_func = lambda x: x
|
||||
if args.get("encrypt_type") == "aes":
|
||||
logger.debug("[wechatmp] Receive encrypted post data:\n" + message.decode("utf-8"))
|
||||
if not channel.crypto:
|
||||
raise Exception("Crypto not initialized, Please set wechatmp_aes_key in config.json")
|
||||
message = channel.crypto.decrypt_message(message, args.msg_signature, args.timestamp, args.nonce)
|
||||
encrypt_func = lambda x: channel.crypto.encrypt_message(x, args.nonce, args.timestamp)
|
||||
else:
|
||||
logger.debug("[wechatmp] Receive post data:\n" + message.decode("utf-8"))
|
||||
msg = parse_message(message)
|
||||
if msg.type in ["text", "voice", "image"]:
|
||||
wechatmp_msg = WeChatMPMessage(msg, client=channel.client)
|
||||
from_user = wechatmp_msg.from_user_id
|
||||
content = wechatmp_msg.content
|
||||
message_id = wechatmp_msg.msg_id
|
||||
|
||||
logger.info(
|
||||
"[wechatmp] {}:{} Receive post query {} {}: {}".format(
|
||||
web.ctx.env.get("REMOTE_ADDR"),
|
||||
web.ctx.env.get("REMOTE_PORT"),
|
||||
from_user,
|
||||
message_id,
|
||||
content,
|
||||
)
|
||||
)
|
||||
if msg.type == "voice" and wechatmp_msg.ctype == ContextType.TEXT and conf().get("voice_reply_voice", False):
|
||||
context = channel._compose_context(wechatmp_msg.ctype, content, isgroup=False, desire_rtype=ReplyType.VOICE, msg=wechatmp_msg)
|
||||
else:
|
||||
context = channel._compose_context(wechatmp_msg.ctype, content, isgroup=False, msg=wechatmp_msg)
|
||||
if context:
|
||||
channel.produce(context)
|
||||
# The reply will be sent by channel.send() in another thread
|
||||
return "success"
|
||||
elif msg.type == "event":
|
||||
logger.info("[wechatmp] Event {} from {}".format(msg.event, msg.source))
|
||||
if msg.event in ["subscribe", "subscribe_scan"]:
|
||||
reply_text = subscribe_msg()
|
||||
if reply_text:
|
||||
replyPost = create_reply(reply_text, msg)
|
||||
return encrypt_func(replyPost.render())
|
||||
else:
|
||||
return "success"
|
||||
else:
|
||||
logger.info("暂且不处理")
|
||||
return "success"
|
||||
except Exception as exc:
|
||||
logger.exception(exc)
|
||||
return exc
|
||||
27
channel/wechatmp/common.py
Normal file
27
channel/wechatmp/common.py
Normal file
@@ -0,0 +1,27 @@
|
||||
import web
|
||||
from wechatpy.crypto import WeChatCrypto
|
||||
from wechatpy.exceptions import InvalidSignatureException
|
||||
from wechatpy.utils import check_signature
|
||||
|
||||
from config import conf
|
||||
|
||||
MAX_UTF8_LEN = 2048
|
||||
|
||||
|
||||
class WeChatAPIException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def verify_server(data):
|
||||
try:
|
||||
signature = data.signature
|
||||
timestamp = data.timestamp
|
||||
nonce = data.nonce
|
||||
echostr = data.get("echostr", None)
|
||||
token = conf().get("wechatmp_token") # 请按照公众平台官网\基本配置中信息填写
|
||||
check_signature(token, signature, timestamp, nonce)
|
||||
return echostr
|
||||
except InvalidSignatureException:
|
||||
raise web.Forbidden("Invalid signature")
|
||||
except Exception as e:
|
||||
raise web.Forbidden(str(e))
|
||||
211
channel/wechatmp/passive_reply.py
Normal file
211
channel/wechatmp/passive_reply.py
Normal file
@@ -0,0 +1,211 @@
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
import web
|
||||
from wechatpy import parse_message
|
||||
from wechatpy.replies import ImageReply, VoiceReply, create_reply
|
||||
import textwrap
|
||||
from bridge.context import *
|
||||
from bridge.reply import *
|
||||
from channel.wechatmp.common import *
|
||||
from channel.wechatmp.wechatmp_channel import WechatMPChannel
|
||||
from channel.wechatmp.wechatmp_message import WeChatMPMessage
|
||||
from common.log import logger
|
||||
from common.utils import split_string_by_utf8_length
|
||||
from config import conf, subscribe_msg
|
||||
|
||||
|
||||
# This class is instantiated once per query
|
||||
class Query:
|
||||
def GET(self):
|
||||
return verify_server(web.input())
|
||||
|
||||
def POST(self):
|
||||
try:
|
||||
args = web.input()
|
||||
verify_server(args)
|
||||
request_time = time.time()
|
||||
channel = WechatMPChannel()
|
||||
message = web.data()
|
||||
encrypt_func = lambda x: x
|
||||
if args.get("encrypt_type") == "aes":
|
||||
logger.debug("[wechatmp] Receive encrypted post data:\n" + message.decode("utf-8"))
|
||||
if not channel.crypto:
|
||||
raise Exception("Crypto not initialized, Please set wechatmp_aes_key in config.json")
|
||||
message = channel.crypto.decrypt_message(message, args.msg_signature, args.timestamp, args.nonce)
|
||||
encrypt_func = lambda x: channel.crypto.encrypt_message(x, args.nonce, args.timestamp)
|
||||
else:
|
||||
logger.debug("[wechatmp] Receive post data:\n" + message.decode("utf-8"))
|
||||
msg = parse_message(message)
|
||||
if msg.type in ["text", "voice", "image"]:
|
||||
wechatmp_msg = WeChatMPMessage(msg, client=channel.client)
|
||||
from_user = wechatmp_msg.from_user_id
|
||||
content = wechatmp_msg.content
|
||||
message_id = wechatmp_msg.msg_id
|
||||
|
||||
supported = True
|
||||
if "【收到不支持的消息类型,暂无法显示】" in content:
|
||||
supported = False # not supported, used to refresh
|
||||
|
||||
# New request
|
||||
if (
|
||||
channel.cache_dict.get(from_user) is None
|
||||
and from_user not in channel.running
|
||||
or content.startswith("#")
|
||||
and message_id not in channel.request_cnt # insert the godcmd
|
||||
):
|
||||
# The first query begin
|
||||
if msg.type == "voice" and wechatmp_msg.ctype == ContextType.TEXT and conf().get("voice_reply_voice", False):
|
||||
context = channel._compose_context(wechatmp_msg.ctype, content, isgroup=False, desire_rtype=ReplyType.VOICE, msg=wechatmp_msg)
|
||||
else:
|
||||
context = channel._compose_context(wechatmp_msg.ctype, content, isgroup=False, msg=wechatmp_msg)
|
||||
logger.debug("[wechatmp] context: {} {} {}".format(context, wechatmp_msg, supported))
|
||||
|
||||
if supported and context:
|
||||
channel.running.add(from_user)
|
||||
channel.produce(context)
|
||||
else:
|
||||
trigger_prefix = conf().get("single_chat_prefix", [""])[0]
|
||||
if trigger_prefix or not supported:
|
||||
if trigger_prefix:
|
||||
reply_text = textwrap.dedent(
|
||||
f"""\
|
||||
请输入'{trigger_prefix}'接你想说的话跟我说话。
|
||||
例如:
|
||||
{trigger_prefix}你好,很高兴见到你。"""
|
||||
)
|
||||
else:
|
||||
reply_text = textwrap.dedent(
|
||||
"""\
|
||||
你好,很高兴见到你。
|
||||
请跟我说话吧。"""
|
||||
)
|
||||
else:
|
||||
logger.error(f"[wechatmp] unknown error")
|
||||
reply_text = textwrap.dedent(
|
||||
"""\
|
||||
未知错误,请稍后再试"""
|
||||
)
|
||||
|
||||
replyPost = create_reply(reply_text, msg)
|
||||
return encrypt_func(replyPost.render())
|
||||
|
||||
# Wechat official server will request 3 times (5 seconds each), with the same message_id.
|
||||
# Because the interval is 5 seconds, here assumed that do not have multithreading problems.
|
||||
request_cnt = channel.request_cnt.get(message_id, 0) + 1
|
||||
channel.request_cnt[message_id] = request_cnt
|
||||
logger.info(
|
||||
"[wechatmp] Request {} from {} {} {}:{}\n{}".format(
|
||||
request_cnt, from_user, message_id, web.ctx.env.get("REMOTE_ADDR"), web.ctx.env.get("REMOTE_PORT"), content
|
||||
)
|
||||
)
|
||||
|
||||
task_running = True
|
||||
waiting_until = request_time + 4
|
||||
while time.time() < waiting_until:
|
||||
if from_user in channel.running:
|
||||
time.sleep(0.1)
|
||||
else:
|
||||
task_running = False
|
||||
break
|
||||
|
||||
reply_text = ""
|
||||
if task_running:
|
||||
if request_cnt < 3:
|
||||
# waiting for timeout (the POST request will be closed by Wechat official server)
|
||||
time.sleep(2)
|
||||
# and do nothing, waiting for the next request
|
||||
return "success"
|
||||
else: # request_cnt == 3:
|
||||
# return timeout message
|
||||
reply_text = "【正在思考中,回复任意文字尝试获取回复】"
|
||||
replyPost = create_reply(reply_text, msg)
|
||||
return encrypt_func(replyPost.render())
|
||||
|
||||
# reply is ready
|
||||
channel.request_cnt.pop(message_id)
|
||||
|
||||
# no return because of bandwords or other reasons
|
||||
if from_user not in channel.cache_dict and from_user not in channel.running:
|
||||
return "success"
|
||||
|
||||
# Only one request can access to the cached data
|
||||
try:
|
||||
(reply_type, reply_content) = channel.cache_dict[from_user].pop(0)
|
||||
if not channel.cache_dict[from_user]: # If popping the message makes the list empty, delete the user entry from cache
|
||||
del channel.cache_dict[from_user]
|
||||
except IndexError:
|
||||
return "success"
|
||||
|
||||
if reply_type == "text":
|
||||
if len(reply_content.encode("utf8")) <= MAX_UTF8_LEN:
|
||||
reply_text = reply_content
|
||||
else:
|
||||
continue_text = "\n【未完待续,回复任意文字以继续】"
|
||||
splits = split_string_by_utf8_length(
|
||||
reply_content,
|
||||
MAX_UTF8_LEN - len(continue_text.encode("utf-8")),
|
||||
max_split=1,
|
||||
)
|
||||
reply_text = splits[0] + continue_text
|
||||
channel.cache_dict[from_user].append(("text", splits[1]))
|
||||
|
||||
logger.info(
|
||||
"[wechatmp] Request {} do send to {} {}: {}\n{}".format(
|
||||
request_cnt,
|
||||
from_user,
|
||||
message_id,
|
||||
content,
|
||||
reply_text,
|
||||
)
|
||||
)
|
||||
replyPost = create_reply(reply_text, msg)
|
||||
return encrypt_func(replyPost.render())
|
||||
|
||||
elif reply_type == "voice":
|
||||
media_id = reply_content
|
||||
asyncio.run_coroutine_threadsafe(channel.delete_media(media_id), channel.delete_media_loop)
|
||||
logger.info(
|
||||
"[wechatmp] Request {} do send to {} {}: {} voice media_id {}".format(
|
||||
request_cnt,
|
||||
from_user,
|
||||
message_id,
|
||||
content,
|
||||
media_id,
|
||||
)
|
||||
)
|
||||
replyPost = VoiceReply(message=msg)
|
||||
replyPost.media_id = media_id
|
||||
return encrypt_func(replyPost.render())
|
||||
|
||||
elif reply_type == "image":
|
||||
media_id = reply_content
|
||||
asyncio.run_coroutine_threadsafe(channel.delete_media(media_id), channel.delete_media_loop)
|
||||
logger.info(
|
||||
"[wechatmp] Request {} do send to {} {}: {} image media_id {}".format(
|
||||
request_cnt,
|
||||
from_user,
|
||||
message_id,
|
||||
content,
|
||||
media_id,
|
||||
)
|
||||
)
|
||||
replyPost = ImageReply(message=msg)
|
||||
replyPost.media_id = media_id
|
||||
return encrypt_func(replyPost.render())
|
||||
|
||||
elif msg.type == "event":
|
||||
logger.info("[wechatmp] Event {} from {}".format(msg.event, msg.source))
|
||||
if msg.event in ["subscribe", "subscribe_scan"]:
|
||||
reply_text = subscribe_msg()
|
||||
if reply_text:
|
||||
replyPost = create_reply(reply_text, msg)
|
||||
return encrypt_func(replyPost.render())
|
||||
else:
|
||||
return "success"
|
||||
else:
|
||||
logger.info("暂且不处理")
|
||||
return "success"
|
||||
except Exception as exc:
|
||||
logger.exception(exc)
|
||||
return exc
|
||||
304
channel/wechatmp/wechatmp_channel.py
Normal file
304
channel/wechatmp/wechatmp_channel.py
Normal file
@@ -0,0 +1,304 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import asyncio
|
||||
import imghdr
|
||||
import io
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
|
||||
import requests
|
||||
import web
|
||||
from wechatpy.crypto import WeChatCrypto
|
||||
from wechatpy.exceptions import WeChatClientException
|
||||
from collections import defaultdict
|
||||
|
||||
from bridge.context import *
|
||||
from bridge.reply import *
|
||||
from channel.chat_channel import ChatChannel
|
||||
from channel.wechatmp.common import *
|
||||
from channel.wechatmp.wechatmp_client import WechatMPClient
|
||||
from common.log import logger
|
||||
from common.singleton import singleton
|
||||
from common.utils import split_string_by_utf8_length
|
||||
from config import conf
|
||||
from voice.audio_convert import any_to_mp3, split_audio
|
||||
|
||||
# If using SSL, uncomment the following lines, and modify the certificate path.
|
||||
# from cheroot.server import HTTPServer
|
||||
# from cheroot.ssl.builtin import BuiltinSSLAdapter
|
||||
# HTTPServer.ssl_adapter = BuiltinSSLAdapter(
|
||||
# certificate='/ssl/cert.pem',
|
||||
# private_key='/ssl/cert.key')
|
||||
|
||||
|
||||
@singleton
|
||||
class WechatMPChannel(ChatChannel):
|
||||
def __init__(self, passive_reply=True):
|
||||
super().__init__()
|
||||
self.passive_reply = passive_reply
|
||||
self.NOT_SUPPORT_REPLYTYPE = []
|
||||
appid = conf().get("wechatmp_app_id")
|
||||
secret = conf().get("wechatmp_app_secret")
|
||||
token = conf().get("wechatmp_token")
|
||||
aes_key = conf().get("wechatmp_aes_key")
|
||||
self.client = WechatMPClient(appid, secret)
|
||||
self.crypto = None
|
||||
if aes_key:
|
||||
self.crypto = WeChatCrypto(token, aes_key, appid)
|
||||
if self.passive_reply:
|
||||
# Cache the reply to the user's first message
|
||||
self.cache_dict = defaultdict(list)
|
||||
# Record whether the current message is being processed
|
||||
self.running = set()
|
||||
# Count the request from wechat official server by message_id
|
||||
self.request_cnt = dict()
|
||||
# The permanent media need to be deleted to avoid media number limit
|
||||
self.delete_media_loop = asyncio.new_event_loop()
|
||||
t = threading.Thread(target=self.start_loop, args=(self.delete_media_loop,))
|
||||
t.setDaemon(True)
|
||||
t.start()
|
||||
|
||||
def startup(self):
|
||||
if self.passive_reply:
|
||||
urls = ("/wx", "channel.wechatmp.passive_reply.Query")
|
||||
else:
|
||||
urls = ("/wx", "channel.wechatmp.active_reply.Query")
|
||||
app = web.application(urls, globals(), autoreload=False)
|
||||
port = conf().get("wechatmp_port", 8080)
|
||||
web.httpserver.runsimple(app.wsgifunc(), ("0.0.0.0", port))
|
||||
|
||||
def start_loop(self, loop):
|
||||
asyncio.set_event_loop(loop)
|
||||
loop.run_forever()
|
||||
|
||||
async def delete_media(self, media_id):
|
||||
logger.debug("[wechatmp] permanent media {} will be deleted in 10s".format(media_id))
|
||||
await asyncio.sleep(10)
|
||||
self.client.material.delete(media_id)
|
||||
logger.info("[wechatmp] permanent media {} has been deleted".format(media_id))
|
||||
|
||||
def send(self, reply: Reply, context: Context):
|
||||
receiver = context["receiver"]
|
||||
if self.passive_reply:
|
||||
if reply.type == ReplyType.TEXT or reply.type == ReplyType.INFO or reply.type == ReplyType.ERROR:
|
||||
reply_text = reply.content
|
||||
logger.info("[wechatmp] text cached, receiver {}\n{}".format(receiver, reply_text))
|
||||
self.cache_dict[receiver].append(("text", reply_text))
|
||||
elif reply.type == ReplyType.VOICE:
|
||||
voice_file_path = reply.content
|
||||
duration, files = split_audio(voice_file_path, 60 * 1000)
|
||||
if len(files) > 1:
|
||||
logger.info("[wechatmp] voice too long {}s > 60s , split into {} parts".format(duration / 1000.0, len(files)))
|
||||
|
||||
for path in files:
|
||||
# support: <2M, <60s, mp3/wma/wav/amr
|
||||
try:
|
||||
with open(path, "rb") as f:
|
||||
response = self.client.material.add("voice", f)
|
||||
logger.debug("[wechatmp] upload voice response: {}".format(response))
|
||||
f_size = os.fstat(f.fileno()).st_size
|
||||
time.sleep(1.0 + 2 * f_size / 1024 / 1024)
|
||||
# todo check media_id
|
||||
except WeChatClientException as e:
|
||||
logger.error("[wechatmp] upload voice failed: {}".format(e))
|
||||
return
|
||||
media_id = response["media_id"]
|
||||
logger.info("[wechatmp] voice uploaded, receiver {}, media_id {}".format(receiver, media_id))
|
||||
self.cache_dict[receiver].append(("voice", media_id))
|
||||
|
||||
elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片
|
||||
img_url = reply.content
|
||||
pic_res = requests.get(img_url, stream=True)
|
||||
image_storage = io.BytesIO()
|
||||
for block in pic_res.iter_content(1024):
|
||||
image_storage.write(block)
|
||||
image_storage.seek(0)
|
||||
image_type = imghdr.what(image_storage)
|
||||
filename = receiver + "-" + str(context["msg"].msg_id) + "." + image_type
|
||||
content_type = "image/" + image_type
|
||||
try:
|
||||
response = self.client.material.add("image", (filename, image_storage, content_type))
|
||||
logger.debug("[wechatmp] upload image response: {}".format(response))
|
||||
except WeChatClientException as e:
|
||||
logger.error("[wechatmp] upload image failed: {}".format(e))
|
||||
return
|
||||
media_id = response["media_id"]
|
||||
logger.info("[wechatmp] image uploaded, receiver {}, media_id {}".format(receiver, media_id))
|
||||
self.cache_dict[receiver].append(("image", media_id))
|
||||
elif reply.type == ReplyType.IMAGE: # 从文件读取图片
|
||||
image_storage = reply.content
|
||||
image_storage.seek(0)
|
||||
image_type = imghdr.what(image_storage)
|
||||
filename = receiver + "-" + str(context["msg"].msg_id) + "." + image_type
|
||||
content_type = "image/" + image_type
|
||||
try:
|
||||
response = self.client.material.add("image", (filename, image_storage, content_type))
|
||||
logger.debug("[wechatmp] upload image response: {}".format(response))
|
||||
except WeChatClientException as e:
|
||||
logger.error("[wechatmp] upload image failed: {}".format(e))
|
||||
return
|
||||
media_id = response["media_id"]
|
||||
logger.info("[wechatmp] image uploaded, receiver {}, media_id {}".format(receiver, media_id))
|
||||
self.cache_dict[receiver].append(("image", media_id))
|
||||
elif reply.type == ReplyType.VIDEO_URL: # 从网络下载视频
|
||||
video_url = reply.content
|
||||
video_res = requests.get(video_url, stream=True)
|
||||
video_storage = io.BytesIO()
|
||||
for block in video_res.iter_content(1024):
|
||||
video_storage.write(block)
|
||||
video_storage.seek(0)
|
||||
video_type = 'mp4'
|
||||
filename = receiver + "-" + str(context["msg"].msg_id) + "." + video_type
|
||||
content_type = "video/" + video_type
|
||||
try:
|
||||
response = self.client.material.add("video", (filename, video_storage, content_type))
|
||||
logger.debug("[wechatmp] upload video response: {}".format(response))
|
||||
except WeChatClientException as e:
|
||||
logger.error("[wechatmp] upload video failed: {}".format(e))
|
||||
return
|
||||
media_id = response["media_id"]
|
||||
logger.info("[wechatmp] video uploaded, receiver {}, media_id {}".format(receiver, media_id))
|
||||
self.cache_dict[receiver].append(("video", media_id))
|
||||
|
||||
elif reply.type == ReplyType.VIDEO: # 从文件读取视频
|
||||
video_storage = reply.content
|
||||
video_storage.seek(0)
|
||||
video_type = 'mp4'
|
||||
filename = receiver + "-" + str(context["msg"].msg_id) + "." + video_type
|
||||
content_type = "video/" + video_type
|
||||
try:
|
||||
response = self.client.material.add("video", (filename, video_storage, content_type))
|
||||
logger.debug("[wechatmp] upload video response: {}".format(response))
|
||||
except WeChatClientException as e:
|
||||
logger.error("[wechatmp] upload video failed: {}".format(e))
|
||||
return
|
||||
media_id = response["media_id"]
|
||||
logger.info("[wechatmp] video uploaded, receiver {}, media_id {}".format(receiver, media_id))
|
||||
self.cache_dict[receiver].append(("video", media_id))
|
||||
|
||||
else:
|
||||
if reply.type == ReplyType.TEXT or reply.type == ReplyType.INFO or reply.type == ReplyType.ERROR:
|
||||
reply_text = reply.content
|
||||
texts = split_string_by_utf8_length(reply_text, MAX_UTF8_LEN)
|
||||
if len(texts) > 1:
|
||||
logger.info("[wechatmp] text too long, split into {} parts".format(len(texts)))
|
||||
for i, text in enumerate(texts):
|
||||
self.client.message.send_text(receiver, text)
|
||||
if i != len(texts) - 1:
|
||||
time.sleep(0.5) # 休眠0.5秒,防止发送过快乱序
|
||||
logger.info("[wechatmp] Do send text to {}: {}".format(receiver, reply_text))
|
||||
elif reply.type == ReplyType.VOICE:
|
||||
try:
|
||||
file_path = reply.content
|
||||
file_name = os.path.basename(file_path)
|
||||
file_type = os.path.splitext(file_name)[1]
|
||||
if file_type == ".mp3":
|
||||
file_type = "audio/mpeg"
|
||||
elif file_type == ".amr":
|
||||
file_type = "audio/amr"
|
||||
else:
|
||||
mp3_file = os.path.splitext(file_path)[0] + ".mp3"
|
||||
any_to_mp3(file_path, mp3_file)
|
||||
file_path = mp3_file
|
||||
file_name = os.path.basename(file_path)
|
||||
file_type = "audio/mpeg"
|
||||
logger.info("[wechatmp] file_name: {}, file_type: {} ".format(file_name, file_type))
|
||||
media_ids = []
|
||||
duration, files = split_audio(file_path, 60 * 1000)
|
||||
if len(files) > 1:
|
||||
logger.info("[wechatmp] voice too long {}s > 60s , split into {} parts".format(duration / 1000.0, len(files)))
|
||||
for path in files:
|
||||
# support: <2M, <60s, AMR\MP3
|
||||
response = self.client.media.upload("voice", (os.path.basename(path), open(path, "rb"), file_type))
|
||||
logger.debug("[wechatcom] upload voice response: {}".format(response))
|
||||
media_ids.append(response["media_id"])
|
||||
os.remove(path)
|
||||
except WeChatClientException as e:
|
||||
logger.error("[wechatmp] upload voice failed: {}".format(e))
|
||||
return
|
||||
|
||||
try:
|
||||
os.remove(file_path)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
for media_id in media_ids:
|
||||
self.client.message.send_voice(receiver, media_id)
|
||||
time.sleep(1)
|
||||
logger.info("[wechatmp] Do send voice to {}".format(receiver))
|
||||
elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片
|
||||
img_url = reply.content
|
||||
pic_res = requests.get(img_url, stream=True)
|
||||
image_storage = io.BytesIO()
|
||||
for block in pic_res.iter_content(1024):
|
||||
image_storage.write(block)
|
||||
image_storage.seek(0)
|
||||
image_type = imghdr.what(image_storage)
|
||||
filename = receiver + "-" + str(context["msg"].msg_id) + "." + image_type
|
||||
content_type = "image/" + image_type
|
||||
try:
|
||||
response = self.client.media.upload("image", (filename, image_storage, content_type))
|
||||
logger.debug("[wechatmp] upload image response: {}".format(response))
|
||||
except WeChatClientException as e:
|
||||
logger.error("[wechatmp] upload image failed: {}".format(e))
|
||||
return
|
||||
self.client.message.send_image(receiver, response["media_id"])
|
||||
logger.info("[wechatmp] Do send image to {}".format(receiver))
|
||||
elif reply.type == ReplyType.IMAGE: # 从文件读取图片
|
||||
image_storage = reply.content
|
||||
image_storage.seek(0)
|
||||
image_type = imghdr.what(image_storage)
|
||||
filename = receiver + "-" + str(context["msg"].msg_id) + "." + image_type
|
||||
content_type = "image/" + image_type
|
||||
try:
|
||||
response = self.client.media.upload("image", (filename, image_storage, content_type))
|
||||
logger.debug("[wechatmp] upload image response: {}".format(response))
|
||||
except WeChatClientException as e:
|
||||
logger.error("[wechatmp] upload image failed: {}".format(e))
|
||||
return
|
||||
self.client.message.send_image(receiver, response["media_id"])
|
||||
logger.info("[wechatmp] Do send image to {}".format(receiver))
|
||||
elif reply.type == ReplyType.VIDEO_URL: # 从网络下载视频
|
||||
video_url = reply.content
|
||||
video_res = requests.get(video_url, stream=True)
|
||||
video_storage = io.BytesIO()
|
||||
for block in video_res.iter_content(1024):
|
||||
video_storage.write(block)
|
||||
video_storage.seek(0)
|
||||
video_type = 'mp4'
|
||||
filename = receiver + "-" + str(context["msg"].msg_id) + "." + video_type
|
||||
content_type = "video/" + video_type
|
||||
try:
|
||||
response = self.client.media.upload("video", (filename, video_storage, content_type))
|
||||
logger.debug("[wechatmp] upload video response: {}".format(response))
|
||||
except WeChatClientException as e:
|
||||
logger.error("[wechatmp] upload video failed: {}".format(e))
|
||||
return
|
||||
self.client.message.send_video(receiver, response["media_id"])
|
||||
logger.info("[wechatmp] Do send video to {}".format(receiver))
|
||||
elif reply.type == ReplyType.VIDEO: # 从文件读取视频
|
||||
video_storage = reply.content
|
||||
video_storage.seek(0)
|
||||
video_type = 'mp4'
|
||||
filename = receiver + "-" + str(context["msg"].msg_id) + "." + video_type
|
||||
content_type = "video/" + video_type
|
||||
try:
|
||||
response = self.client.media.upload("video", (filename, video_storage, content_type))
|
||||
logger.debug("[wechatmp] upload video response: {}".format(response))
|
||||
except WeChatClientException as e:
|
||||
logger.error("[wechatmp] upload video failed: {}".format(e))
|
||||
return
|
||||
self.client.message.send_video(receiver, response["media_id"])
|
||||
logger.info("[wechatmp] Do send video to {}".format(receiver))
|
||||
return
|
||||
|
||||
def _success_callback(self, session_id, context, **kwargs): # 线程异常结束时的回调函数
|
||||
logger.debug("[wechatmp] Success to generate reply, msgId={}".format(context["msg"].msg_id))
|
||||
if self.passive_reply:
|
||||
self.running.remove(session_id)
|
||||
|
||||
def _fail_callback(self, session_id, exception, context, **kwargs): # 线程异常结束时的回调函数
|
||||
logger.exception("[wechatmp] Fail to generate reply to user, msgId={}, exception={}".format(context["msg"].msg_id, exception))
|
||||
if self.passive_reply:
|
||||
assert session_id not in self.cache_dict
|
||||
self.running.remove(session_id)
|
||||
49
channel/wechatmp/wechatmp_client.py
Normal file
49
channel/wechatmp/wechatmp_client.py
Normal file
@@ -0,0 +1,49 @@
|
||||
import threading
|
||||
import time
|
||||
|
||||
from wechatpy.client import WeChatClient
|
||||
from wechatpy.exceptions import APILimitedException
|
||||
|
||||
from channel.wechatmp.common import *
|
||||
from common.log import logger
|
||||
|
||||
|
||||
class WechatMPClient(WeChatClient):
|
||||
def __init__(self, appid, secret, access_token=None, session=None, timeout=None, auto_retry=True):
|
||||
super(WechatMPClient, self).__init__(appid, secret, access_token, session, timeout, auto_retry)
|
||||
self.fetch_access_token_lock = threading.Lock()
|
||||
self.clear_quota_lock = threading.Lock()
|
||||
self.last_clear_quota_time = -1
|
||||
|
||||
def clear_quota(self):
|
||||
return self.post("clear_quota", data={"appid": self.appid})
|
||||
|
||||
def clear_quota_v2(self):
|
||||
return self.post("clear_quota/v2", params={"appid": self.appid, "appsecret": self.secret})
|
||||
|
||||
def fetch_access_token(self): # 重载父类方法,加锁避免多线程重复获取access_token
|
||||
with self.fetch_access_token_lock:
|
||||
access_token = self.session.get(self.access_token_key)
|
||||
if access_token:
|
||||
if not self.expires_at:
|
||||
return access_token
|
||||
timestamp = time.time()
|
||||
if self.expires_at - timestamp > 60:
|
||||
return access_token
|
||||
return super().fetch_access_token()
|
||||
|
||||
def _request(self, method, url_or_endpoint, **kwargs): # 重载父类方法,遇到API限流时,清除quota后重试
|
||||
try:
|
||||
return super()._request(method, url_or_endpoint, **kwargs)
|
||||
except APILimitedException as e:
|
||||
logger.error("[wechatmp] API quata has been used up. {}".format(e))
|
||||
if self.last_clear_quota_time == -1 or time.time() - self.last_clear_quota_time > 60:
|
||||
with self.clear_quota_lock:
|
||||
if self.last_clear_quota_time == -1 or time.time() - self.last_clear_quota_time > 60:
|
||||
self.last_clear_quota_time = time.time()
|
||||
response = self.clear_quota_v2()
|
||||
logger.debug("[wechatmp] API quata has been cleard, {}".format(response))
|
||||
return super()._request(method, url_or_endpoint, **kwargs)
|
||||
else:
|
||||
logger.error("[wechatmp] last clear quota time is {}, less than 60s, skip clear quota")
|
||||
raise e
|
||||
56
channel/wechatmp/wechatmp_message.py
Normal file
56
channel/wechatmp/wechatmp_message.py
Normal file
@@ -0,0 +1,56 @@
|
||||
# -*- coding: utf-8 -*-#
|
||||
|
||||
from bridge.context import ContextType
|
||||
from channel.chat_message import ChatMessage
|
||||
from common.log import logger
|
||||
from common.tmp_dir import TmpDir
|
||||
|
||||
|
||||
class WeChatMPMessage(ChatMessage):
|
||||
def __init__(self, msg, client=None):
|
||||
super().__init__(msg)
|
||||
self.msg_id = msg.id
|
||||
self.create_time = msg.time
|
||||
self.is_group = False
|
||||
|
||||
if msg.type == "text":
|
||||
self.ctype = ContextType.TEXT
|
||||
self.content = msg.content
|
||||
elif msg.type == "voice":
|
||||
if msg.recognition == None:
|
||||
self.ctype = ContextType.VOICE
|
||||
self.content = TmpDir().path() + msg.media_id + "." + msg.format # content直接存临时目录路径
|
||||
|
||||
def download_voice():
|
||||
# 如果响应状态码是200,则将响应内容写入本地文件
|
||||
response = client.media.download(msg.media_id)
|
||||
if response.status_code == 200:
|
||||
with open(self.content, "wb") as f:
|
||||
f.write(response.content)
|
||||
else:
|
||||
logger.info(f"[wechatmp] Failed to download voice file, {response.content}")
|
||||
|
||||
self._prepare_fn = download_voice
|
||||
else:
|
||||
self.ctype = ContextType.TEXT
|
||||
self.content = msg.recognition
|
||||
elif msg.type == "image":
|
||||
self.ctype = ContextType.IMAGE
|
||||
self.content = TmpDir().path() + msg.media_id + ".png" # content直接存临时目录路径
|
||||
|
||||
def download_image():
|
||||
# 如果响应状态码是200,则将响应内容写入本地文件
|
||||
response = client.media.download(msg.media_id)
|
||||
if response.status_code == 200:
|
||||
with open(self.content, "wb") as f:
|
||||
f.write(response.content)
|
||||
else:
|
||||
logger.info(f"[wechatmp] Failed to download image file, {response.content}")
|
||||
|
||||
self._prepare_fn = download_image
|
||||
else:
|
||||
raise NotImplementedError("Unsupported message type: Type:{} ".format(msg.type))
|
||||
|
||||
self.from_user_id = msg.source
|
||||
self.to_user_id = msg.target
|
||||
self.other_user_id = msg.source
|
||||
17
channel/wework/run.py
Normal file
17
channel/wework/run.py
Normal file
@@ -0,0 +1,17 @@
|
||||
import os
|
||||
import time
|
||||
os.environ['ntwork_LOG'] = "ERROR"
|
||||
import ntwork
|
||||
|
||||
wework = ntwork.WeWork()
|
||||
|
||||
|
||||
def forever():
|
||||
try:
|
||||
while True:
|
||||
time.sleep(0.1)
|
||||
except KeyboardInterrupt:
|
||||
ntwork.exit_()
|
||||
os._exit(0)
|
||||
|
||||
|
||||
326
channel/wework/wework_channel.py
Normal file
326
channel/wework/wework_channel.py
Normal file
@@ -0,0 +1,326 @@
|
||||
import io
|
||||
import os
|
||||
import random
|
||||
import tempfile
|
||||
import threading
|
||||
os.environ['ntwork_LOG'] = "ERROR"
|
||||
import ntwork
|
||||
import requests
|
||||
import uuid
|
||||
|
||||
from bridge.context import *
|
||||
from bridge.reply import *
|
||||
from channel.chat_channel import ChatChannel
|
||||
from channel.wework.wework_message import *
|
||||
from channel.wework.wework_message import WeworkMessage
|
||||
from common.singleton import singleton
|
||||
from common.log import logger
|
||||
from common.time_check import time_checker
|
||||
from common.utils import compress_imgfile, fsize
|
||||
from config import conf
|
||||
from channel.wework.run import wework
|
||||
from channel.wework import run
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def get_wxid_by_name(room_members, group_wxid, name):
|
||||
if group_wxid in room_members:
|
||||
for member in room_members[group_wxid]['member_list']:
|
||||
if member['room_nickname'] == name or member['username'] == name:
|
||||
return member['user_id']
|
||||
return None # 如果没有找到对应的group_wxid或name,则返回None
|
||||
|
||||
|
||||
def download_and_compress_image(url, filename, quality=30):
|
||||
# 确定保存图片的目录
|
||||
directory = os.path.join(os.getcwd(), "tmp")
|
||||
# 如果目录不存在,则创建目录
|
||||
if not os.path.exists(directory):
|
||||
os.makedirs(directory)
|
||||
|
||||
# 下载图片
|
||||
pic_res = requests.get(url, stream=True)
|
||||
image_storage = io.BytesIO()
|
||||
for block in pic_res.iter_content(1024):
|
||||
image_storage.write(block)
|
||||
|
||||
# 检查图片大小并可能进行压缩
|
||||
sz = fsize(image_storage)
|
||||
if sz >= 10 * 1024 * 1024: # 如果图片大于 10 MB
|
||||
logger.info("[wework] image too large, ready to compress, sz={}".format(sz))
|
||||
image_storage = compress_imgfile(image_storage, 10 * 1024 * 1024 - 1)
|
||||
logger.info("[wework] image compressed, sz={}".format(fsize(image_storage)))
|
||||
|
||||
# 将内存缓冲区的指针重置到起始位置
|
||||
image_storage.seek(0)
|
||||
|
||||
# 读取并保存图片
|
||||
image = Image.open(image_storage)
|
||||
image_path = os.path.join(directory, f"{filename}.png")
|
||||
image.save(image_path, "png")
|
||||
|
||||
return image_path
|
||||
|
||||
|
||||
def download_video(url, filename):
|
||||
# 确定保存视频的目录
|
||||
directory = os.path.join(os.getcwd(), "tmp")
|
||||
# 如果目录不存在,则创建目录
|
||||
if not os.path.exists(directory):
|
||||
os.makedirs(directory)
|
||||
|
||||
# 下载视频
|
||||
response = requests.get(url, stream=True)
|
||||
total_size = 0
|
||||
|
||||
video_path = os.path.join(directory, f"{filename}.mp4")
|
||||
|
||||
with open(video_path, 'wb') as f:
|
||||
for block in response.iter_content(1024):
|
||||
total_size += len(block)
|
||||
|
||||
# 如果视频的总大小超过30MB (30 * 1024 * 1024 bytes),则停止下载并返回
|
||||
if total_size > 30 * 1024 * 1024:
|
||||
logger.info("[WX] Video is larger than 30MB, skipping...")
|
||||
return None
|
||||
|
||||
f.write(block)
|
||||
|
||||
return video_path
|
||||
|
||||
|
||||
def create_message(wework_instance, message, is_group):
|
||||
logger.debug(f"正在为{'群聊' if is_group else '单聊'}创建 WeworkMessage")
|
||||
cmsg = WeworkMessage(message, wework=wework_instance, is_group=is_group)
|
||||
logger.debug(f"cmsg:{cmsg}")
|
||||
return cmsg
|
||||
|
||||
|
||||
def handle_message(cmsg, is_group):
|
||||
logger.debug(f"准备用 WeworkChannel 处理{'群聊' if is_group else '单聊'}消息")
|
||||
if is_group:
|
||||
WeworkChannel().handle_group(cmsg)
|
||||
else:
|
||||
WeworkChannel().handle_single(cmsg)
|
||||
logger.debug(f"已用 WeworkChannel 处理完{'群聊' if is_group else '单聊'}消息")
|
||||
|
||||
|
||||
def _check(func):
|
||||
def wrapper(self, cmsg: ChatMessage):
|
||||
msgId = cmsg.msg_id
|
||||
create_time = cmsg.create_time # 消息时间戳
|
||||
if create_time is None:
|
||||
return func(self, cmsg)
|
||||
if int(create_time) < int(time.time()) - 60: # 跳过1分钟前的历史消息
|
||||
logger.debug("[WX]history message {} skipped".format(msgId))
|
||||
return
|
||||
return func(self, cmsg)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
@wework.msg_register(
|
||||
[ntwork.MT_RECV_TEXT_MSG, ntwork.MT_RECV_IMAGE_MSG, 11072, ntwork.MT_RECV_LINK_CARD_MSG,ntwork.MT_RECV_FILE_MSG, ntwork.MT_RECV_VOICE_MSG])
|
||||
def all_msg_handler(wework_instance: ntwork.WeWork, message):
|
||||
logger.debug(f"收到消息: {message}")
|
||||
if 'data' in message:
|
||||
# 首先查找conversation_id,如果没有找到,则查找room_conversation_id
|
||||
conversation_id = message['data'].get('conversation_id', message['data'].get('room_conversation_id'))
|
||||
if conversation_id is not None:
|
||||
is_group = "R:" in conversation_id
|
||||
try:
|
||||
cmsg = create_message(wework_instance=wework_instance, message=message, is_group=is_group)
|
||||
except NotImplementedError as e:
|
||||
logger.error(f"[WX]{message.get('MsgId', 'unknown')} 跳过: {e}")
|
||||
return None
|
||||
delay = random.randint(1, 2)
|
||||
timer = threading.Timer(delay, handle_message, args=(cmsg, is_group))
|
||||
timer.start()
|
||||
else:
|
||||
logger.debug("消息数据中无 conversation_id")
|
||||
return None
|
||||
return None
|
||||
|
||||
|
||||
def accept_friend_with_retries(wework_instance, user_id, corp_id):
|
||||
result = wework_instance.accept_friend(user_id, corp_id)
|
||||
logger.debug(f'result:{result}')
|
||||
|
||||
|
||||
# @wework.msg_register(ntwork.MT_RECV_FRIEND_MSG)
|
||||
# def friend(wework_instance: ntwork.WeWork, message):
|
||||
# data = message["data"]
|
||||
# user_id = data["user_id"]
|
||||
# corp_id = data["corp_id"]
|
||||
# logger.info(f"接收到好友请求,消息内容:{data}")
|
||||
# delay = random.randint(1, 180)
|
||||
# threading.Timer(delay, accept_friend_with_retries, args=(wework_instance, user_id, corp_id)).start()
|
||||
#
|
||||
# return None
|
||||
|
||||
|
||||
def get_with_retry(get_func, max_retries=5, delay=5):
|
||||
retries = 0
|
||||
result = None
|
||||
while retries < max_retries:
|
||||
result = get_func()
|
||||
if result:
|
||||
break
|
||||
logger.warning(f"获取数据失败,重试第{retries + 1}次······")
|
||||
retries += 1
|
||||
time.sleep(delay) # 等待一段时间后重试
|
||||
return result
|
||||
|
||||
|
||||
@singleton
|
||||
class WeworkChannel(ChatChannel):
|
||||
NOT_SUPPORT_REPLYTYPE = []
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def startup(self):
|
||||
smart = conf().get("wework_smart", True)
|
||||
wework.open(smart)
|
||||
logger.info("等待登录······")
|
||||
wework.wait_login()
|
||||
login_info = wework.get_login_info()
|
||||
self.user_id = login_info['user_id']
|
||||
self.name = login_info['nickname']
|
||||
logger.info(f"登录信息:>>>user_id:{self.user_id}>>>>>>>>name:{self.name}")
|
||||
logger.info("静默延迟60s,等待客户端刷新数据,请勿进行任何操作······")
|
||||
time.sleep(60)
|
||||
contacts = get_with_retry(wework.get_external_contacts)
|
||||
rooms = get_with_retry(wework.get_rooms)
|
||||
directory = os.path.join(os.getcwd(), "tmp")
|
||||
if not contacts or not rooms:
|
||||
logger.error("获取contacts或rooms失败,程序退出")
|
||||
ntwork.exit_()
|
||||
os.exit(0)
|
||||
if not os.path.exists(directory):
|
||||
os.makedirs(directory)
|
||||
# 将contacts保存到json文件中
|
||||
with open(os.path.join(directory, 'wework_contacts.json'), 'w', encoding='utf-8') as f:
|
||||
json.dump(contacts, f, ensure_ascii=False, indent=4)
|
||||
with open(os.path.join(directory, 'wework_rooms.json'), 'w', encoding='utf-8') as f:
|
||||
json.dump(rooms, f, ensure_ascii=False, indent=4)
|
||||
# 创建一个空字典来保存结果
|
||||
result = {}
|
||||
|
||||
# 遍历列表中的每个字典
|
||||
for room in rooms['room_list']:
|
||||
# 获取聊天室ID
|
||||
room_wxid = room['conversation_id']
|
||||
|
||||
# 获取聊天室成员
|
||||
room_members = wework.get_room_members(room_wxid)
|
||||
|
||||
# 将聊天室成员保存到结果字典中
|
||||
result[room_wxid] = room_members
|
||||
|
||||
# 将结果保存到json文件中
|
||||
with open(os.path.join(directory, 'wework_room_members.json'), 'w', encoding='utf-8') as f:
|
||||
json.dump(result, f, ensure_ascii=False, indent=4)
|
||||
logger.info("wework程序初始化完成········")
|
||||
run.forever()
|
||||
|
||||
@time_checker
|
||||
@_check
|
||||
def handle_single(self, cmsg: ChatMessage):
|
||||
if cmsg.from_user_id == cmsg.to_user_id:
|
||||
# ignore self reply
|
||||
return
|
||||
if cmsg.ctype == ContextType.VOICE:
|
||||
if not conf().get("speech_recognition"):
|
||||
return
|
||||
logger.debug("[WX]receive voice msg: {}".format(cmsg.content))
|
||||
elif cmsg.ctype == ContextType.IMAGE:
|
||||
logger.debug("[WX]receive image msg: {}".format(cmsg.content))
|
||||
elif cmsg.ctype == ContextType.PATPAT:
|
||||
logger.debug("[WX]receive patpat msg: {}".format(cmsg.content))
|
||||
elif cmsg.ctype == ContextType.TEXT:
|
||||
logger.debug("[WX]receive text msg: {}, cmsg={}".format(json.dumps(cmsg._rawmsg, ensure_ascii=False), cmsg))
|
||||
else:
|
||||
logger.debug("[WX]receive msg: {}, cmsg={}".format(cmsg.content, cmsg))
|
||||
context = self._compose_context(cmsg.ctype, cmsg.content, isgroup=False, msg=cmsg)
|
||||
if context:
|
||||
self.produce(context)
|
||||
|
||||
@time_checker
|
||||
@_check
|
||||
def handle_group(self, cmsg: ChatMessage):
|
||||
if cmsg.ctype == ContextType.VOICE:
|
||||
if not conf().get("speech_recognition"):
|
||||
return
|
||||
logger.debug("[WX]receive voice for group msg: {}".format(cmsg.content))
|
||||
elif cmsg.ctype == ContextType.IMAGE:
|
||||
logger.debug("[WX]receive image for group msg: {}".format(cmsg.content))
|
||||
elif cmsg.ctype in [ContextType.JOIN_GROUP, ContextType.PATPAT]:
|
||||
logger.debug("[WX]receive note msg: {}".format(cmsg.content))
|
||||
elif cmsg.ctype == ContextType.TEXT:
|
||||
pass
|
||||
else:
|
||||
logger.debug("[WX]receive group msg: {}".format(cmsg.content))
|
||||
context = self._compose_context(cmsg.ctype, cmsg.content, isgroup=True, msg=cmsg)
|
||||
if context:
|
||||
self.produce(context)
|
||||
|
||||
# 统一的发送函数,每个Channel自行实现,根据reply的type字段发送不同类型的消息
|
||||
def send(self, reply: Reply, context: Context):
|
||||
logger.debug(f"context: {context}")
|
||||
receiver = context["receiver"]
|
||||
actual_user_id = context["msg"].actual_user_id
|
||||
if reply.type == ReplyType.TEXT or reply.type == ReplyType.TEXT_:
|
||||
match = re.search(r"^@(.*?)\n", reply.content)
|
||||
logger.debug(f"match: {match}")
|
||||
if match:
|
||||
new_content = re.sub(r"^@(.*?)\n", "\n", reply.content)
|
||||
at_list = [actual_user_id]
|
||||
logger.debug(f"new_content: {new_content}")
|
||||
wework.send_room_at_msg(receiver, new_content, at_list)
|
||||
else:
|
||||
wework.send_text(receiver, reply.content)
|
||||
logger.info("[WX] sendMsg={}, receiver={}".format(reply, receiver))
|
||||
elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO:
|
||||
wework.send_text(receiver, reply.content)
|
||||
logger.info("[WX] sendMsg={}, receiver={}".format(reply, receiver))
|
||||
elif reply.type == ReplyType.IMAGE: # 从文件读取图片
|
||||
image_storage = reply.content
|
||||
image_storage.seek(0)
|
||||
# Read data from image_storage
|
||||
data = image_storage.read()
|
||||
# Create a temporary file
|
||||
with tempfile.NamedTemporaryFile(delete=False) as temp:
|
||||
temp_path = temp.name
|
||||
temp.write(data)
|
||||
# Send the image
|
||||
wework.send_image(receiver, temp_path)
|
||||
logger.info("[WX] sendImage, receiver={}".format(receiver))
|
||||
# Remove the temporary file
|
||||
os.remove(temp_path)
|
||||
elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片
|
||||
img_url = reply.content
|
||||
filename = str(uuid.uuid4())
|
||||
|
||||
# 调用你的函数,下载图片并保存为本地文件
|
||||
image_path = download_and_compress_image(img_url, filename)
|
||||
|
||||
wework.send_image(receiver, file_path=image_path)
|
||||
logger.info("[WX] sendImage url={}, receiver={}".format(img_url, receiver))
|
||||
elif reply.type == ReplyType.VIDEO_URL:
|
||||
video_url = reply.content
|
||||
filename = str(uuid.uuid4())
|
||||
video_path = download_video(video_url, filename)
|
||||
|
||||
if video_path is None:
|
||||
# 如果视频太大,下载可能会被跳过,此时 video_path 将为 None
|
||||
wework.send_text(receiver, "抱歉,视频太大了!!!")
|
||||
else:
|
||||
wework.send_video(receiver, video_path)
|
||||
logger.info("[WX] sendVideo, receiver={}".format(receiver))
|
||||
elif reply.type == ReplyType.VOICE:
|
||||
current_dir = os.getcwd()
|
||||
voice_file = reply.content.split("/")[-1]
|
||||
reply.content = os.path.join(current_dir, "tmp", voice_file)
|
||||
wework.send_file(receiver, reply.content)
|
||||
logger.info("[WX] sendFile={}, receiver={}".format(reply.content, receiver))
|
||||
227
channel/wework/wework_message.py
Normal file
227
channel/wework/wework_message.py
Normal file
@@ -0,0 +1,227 @@
|
||||
import datetime
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
import pilk
|
||||
|
||||
from bridge.context import ContextType
|
||||
from channel.chat_message import ChatMessage
|
||||
from common.log import logger
|
||||
from ntwork.const import send_type
|
||||
|
||||
|
||||
def get_with_retry(get_func, max_retries=5, delay=5):
|
||||
retries = 0
|
||||
result = None
|
||||
while retries < max_retries:
|
||||
result = get_func()
|
||||
if result:
|
||||
break
|
||||
logger.warning(f"获取数据失败,重试第{retries + 1}次······")
|
||||
retries += 1
|
||||
time.sleep(delay) # 等待一段时间后重试
|
||||
return result
|
||||
|
||||
|
||||
def get_room_info(wework, conversation_id):
|
||||
logger.debug(f"传入的 conversation_id: {conversation_id}")
|
||||
rooms = wework.get_rooms()
|
||||
if not rooms or 'room_list' not in rooms:
|
||||
logger.error(f"获取群聊信息失败: {rooms}")
|
||||
return None
|
||||
time.sleep(1)
|
||||
logger.debug(f"获取到的群聊信息: {rooms}")
|
||||
for room in rooms['room_list']:
|
||||
if room['conversation_id'] == conversation_id:
|
||||
return room
|
||||
return None
|
||||
|
||||
|
||||
def cdn_download(wework, message, file_name):
|
||||
data = message["data"]
|
||||
aes_key = data["cdn"]["aes_key"]
|
||||
file_size = data["cdn"]["size"]
|
||||
|
||||
# 获取当前工作目录,然后与文件名拼接得到保存路径
|
||||
current_dir = os.getcwd()
|
||||
save_path = os.path.join(current_dir, "tmp", file_name)
|
||||
|
||||
# 下载保存图片到本地
|
||||
if "url" in data["cdn"].keys() and "auth_key" in data["cdn"].keys():
|
||||
url = data["cdn"]["url"]
|
||||
auth_key = data["cdn"]["auth_key"]
|
||||
# result = wework.wx_cdn_download(url, auth_key, aes_key, file_size, save_path) # ntwork库本身接口有问题,缺失了aes_key这个参数
|
||||
"""
|
||||
下载wx类型的cdn文件,以https开头
|
||||
"""
|
||||
data = {
|
||||
'url': url,
|
||||
'auth_key': auth_key,
|
||||
'aes_key': aes_key,
|
||||
'size': file_size,
|
||||
'save_path': save_path
|
||||
}
|
||||
result = wework._WeWork__send_sync(send_type.MT_WXCDN_DOWNLOAD_MSG, data) # 直接用wx_cdn_download的接口内部实现来调用
|
||||
elif "file_id" in data["cdn"].keys():
|
||||
if message["type"] == 11042:
|
||||
file_type = 2
|
||||
elif message["type"] == 11045:
|
||||
file_type = 5
|
||||
file_id = data["cdn"]["file_id"]
|
||||
result = wework.c2c_cdn_download(file_id, aes_key, file_size, file_type, save_path)
|
||||
else:
|
||||
logger.error(f"something is wrong, data: {data}")
|
||||
return
|
||||
|
||||
# 输出下载结果
|
||||
logger.debug(f"result: {result}")
|
||||
|
||||
|
||||
def c2c_download_and_convert(wework, message, file_name):
|
||||
data = message["data"]
|
||||
aes_key = data["cdn"]["aes_key"]
|
||||
file_size = data["cdn"]["size"]
|
||||
file_type = 5
|
||||
file_id = data["cdn"]["file_id"]
|
||||
|
||||
current_dir = os.getcwd()
|
||||
save_path = os.path.join(current_dir, "tmp", file_name)
|
||||
result = wework.c2c_cdn_download(file_id, aes_key, file_size, file_type, save_path)
|
||||
logger.debug(result)
|
||||
|
||||
# 在下载完SILK文件之后,立即将其转换为WAV文件
|
||||
base_name, _ = os.path.splitext(save_path)
|
||||
wav_file = base_name + ".wav"
|
||||
pilk.silk_to_wav(save_path, wav_file, rate=24000)
|
||||
|
||||
# 删除SILK文件
|
||||
try:
|
||||
os.remove(save_path)
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
|
||||
class WeworkMessage(ChatMessage):
|
||||
def __init__(self, wework_msg, wework, is_group=False):
|
||||
try:
|
||||
super().__init__(wework_msg)
|
||||
self.msg_id = wework_msg['data'].get('conversation_id', wework_msg['data'].get('room_conversation_id'))
|
||||
# 使用.get()防止 'send_time' 键不存在时抛出错误
|
||||
self.create_time = wework_msg['data'].get("send_time")
|
||||
self.is_group = is_group
|
||||
self.wework = wework
|
||||
|
||||
if wework_msg["type"] == 11041: # 文本消息类型
|
||||
if any(substring in wework_msg['data']['content'] for substring in ("该消息类型暂不能展示", "不支持的消息类型")):
|
||||
return
|
||||
self.ctype = ContextType.TEXT
|
||||
self.content = wework_msg['data']['content']
|
||||
elif wework_msg["type"] == 11044: # 语音消息类型,需要缓存文件
|
||||
file_name = datetime.datetime.now().strftime('%Y%m%d%H%M%S') + ".silk"
|
||||
base_name, _ = os.path.splitext(file_name)
|
||||
file_name_2 = base_name + ".wav"
|
||||
current_dir = os.getcwd()
|
||||
self.ctype = ContextType.VOICE
|
||||
self.content = os.path.join(current_dir, "tmp", file_name_2)
|
||||
self._prepare_fn = lambda: c2c_download_and_convert(wework, wework_msg, file_name)
|
||||
elif wework_msg["type"] == 11042: # 图片消息类型,需要下载文件
|
||||
file_name = datetime.datetime.now().strftime('%Y%m%d%H%M%S') + ".jpg"
|
||||
current_dir = os.getcwd()
|
||||
self.ctype = ContextType.IMAGE
|
||||
self.content = os.path.join(current_dir, "tmp", file_name)
|
||||
self._prepare_fn = lambda: cdn_download(wework, wework_msg, file_name)
|
||||
elif wework_msg["type"] == 11045: # 文件消息
|
||||
print("文件消息")
|
||||
print(wework_msg)
|
||||
file_name = datetime.datetime.now().strftime('%Y%m%d%H%M%S')
|
||||
file_name = file_name + wework_msg['data']['cdn']['file_name']
|
||||
current_dir = os.getcwd()
|
||||
self.ctype = ContextType.FILE
|
||||
self.content = os.path.join(current_dir, "tmp", file_name)
|
||||
self._prepare_fn = lambda: cdn_download(wework, wework_msg, file_name)
|
||||
elif wework_msg["type"] == 11047: # 链接消息
|
||||
self.ctype = ContextType.SHARING
|
||||
self.content = wework_msg['data']['url']
|
||||
elif wework_msg["type"] == 11072: # 新成员入群通知
|
||||
self.ctype = ContextType.JOIN_GROUP
|
||||
member_list = wework_msg['data']['member_list']
|
||||
self.actual_user_nickname = member_list[0]['name']
|
||||
self.actual_user_id = member_list[0]['user_id']
|
||||
self.content = f"{self.actual_user_nickname}加入了群聊!"
|
||||
directory = os.path.join(os.getcwd(), "tmp")
|
||||
rooms = get_with_retry(wework.get_rooms)
|
||||
if not rooms:
|
||||
logger.error("更新群信息失败···")
|
||||
else:
|
||||
result = {}
|
||||
for room in rooms['room_list']:
|
||||
# 获取聊天室ID
|
||||
room_wxid = room['conversation_id']
|
||||
|
||||
# 获取聊天室成员
|
||||
room_members = wework.get_room_members(room_wxid)
|
||||
|
||||
# 将聊天室成员保存到结果字典中
|
||||
result[room_wxid] = room_members
|
||||
with open(os.path.join(directory, 'wework_room_members.json'), 'w', encoding='utf-8') as f:
|
||||
json.dump(result, f, ensure_ascii=False, indent=4)
|
||||
logger.info("有新成员加入,已自动更新群成员列表缓存!")
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"Unsupported message type: Type:{} MsgType:{}".format(wework_msg["type"], wework_msg["MsgType"]))
|
||||
|
||||
data = wework_msg['data']
|
||||
login_info = self.wework.get_login_info()
|
||||
logger.debug(f"login_info: {login_info}")
|
||||
nickname = f"{login_info['username']}({login_info['nickname']})" if login_info['nickname'] else login_info['username']
|
||||
user_id = login_info['user_id']
|
||||
|
||||
sender_id = data.get('sender')
|
||||
conversation_id = data.get('conversation_id')
|
||||
sender_name = data.get("sender_name")
|
||||
|
||||
self.from_user_id = user_id if sender_id == user_id else conversation_id
|
||||
self.from_user_nickname = nickname if sender_id == user_id else sender_name
|
||||
self.to_user_id = user_id
|
||||
self.to_user_nickname = nickname
|
||||
self.other_user_nickname = sender_name
|
||||
self.other_user_id = conversation_id
|
||||
|
||||
if self.is_group:
|
||||
conversation_id = data.get('conversation_id') or data.get('room_conversation_id')
|
||||
self.other_user_id = conversation_id
|
||||
if conversation_id:
|
||||
room_info = get_room_info(wework=wework, conversation_id=conversation_id)
|
||||
self.other_user_nickname = room_info.get('nickname', None) if room_info else None
|
||||
self.from_user_nickname = room_info.get('nickname', None) if room_info else None
|
||||
at_list = data.get('at_list', [])
|
||||
tmp_list = []
|
||||
for at in at_list:
|
||||
tmp_list.append(at['nickname'])
|
||||
at_list = tmp_list
|
||||
logger.debug(f"at_list: {at_list}")
|
||||
logger.debug(f"nickname: {nickname}")
|
||||
self.is_at = False
|
||||
if nickname in at_list or login_info['nickname'] in at_list or login_info['username'] in at_list:
|
||||
self.is_at = True
|
||||
self.at_list = at_list
|
||||
|
||||
# 检查消息内容是否包含@用户名。处理复制粘贴的消息,这类消息可能不会触发@通知,但内容中可能包含 "@用户名"。
|
||||
content = data.get('content', '')
|
||||
name = nickname
|
||||
pattern = f"@{re.escape(name)}(\u2005|\u0020)"
|
||||
if re.search(pattern, content):
|
||||
logger.debug(f"Wechaty message {self.msg_id} includes at")
|
||||
self.is_at = True
|
||||
|
||||
if not self.actual_user_id:
|
||||
self.actual_user_id = data.get("sender")
|
||||
self.actual_user_nickname = sender_name if self.ctype != ContextType.JOIN_GROUP else self.actual_user_nickname
|
||||
else:
|
||||
logger.error("群聊消息中没有找到 conversation_id 或 room_conversation_id")
|
||||
|
||||
logger.debug(f"WeworkMessage has been successfully instantiated with message id: {self.msg_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"在 WeworkMessage 的初始化过程中出现错误:{e}")
|
||||
raise e
|
||||
@@ -1,5 +1,73 @@
|
||||
# bot_type
|
||||
OPEN_AI = "openAI"
|
||||
CHATGPT = "chatGPT"
|
||||
BAIDU = "baidu"
|
||||
CHATGPTONAZURE = "chatGPTOnAzure"
|
||||
BAIDU = "baidu" # 百度文心一言模型
|
||||
XUNFEI = "xunfei"
|
||||
CHATGPTONAZURE = "chatGPTOnAzure"
|
||||
LINKAI = "linkai"
|
||||
CLAUDEAI = "claude" # 使用cookie的历史模型
|
||||
CLAUDEAPI= "claudeAPI" # 通过Claude api调用模型
|
||||
QWEN = "qwen" # 旧版通义模型
|
||||
QWEN_DASHSCOPE = "dashscope" # 通义新版sdk和api key
|
||||
|
||||
|
||||
GEMINI = "gemini" # gemini-1.0-pro
|
||||
ZHIPU_AI = "glm-4"
|
||||
MOONSHOT = "moonshot"
|
||||
MiniMax = "minimax"
|
||||
|
||||
|
||||
# model
|
||||
CLAUDE3 = "claude-3-opus-20240229"
|
||||
GPT35 = "gpt-3.5-turbo"
|
||||
GPT35_0125 = "gpt-3.5-turbo-0125"
|
||||
GPT35_1106 = "gpt-3.5-turbo-1106"
|
||||
|
||||
GPT_4o = "gpt-4o"
|
||||
GPT4_TURBO = "gpt-4-turbo"
|
||||
GPT4_TURBO_PREVIEW = "gpt-4-turbo-preview"
|
||||
GPT4_TURBO_04_09 = "gpt-4-turbo-2024-04-09"
|
||||
GPT4_TURBO_01_25 = "gpt-4-0125-preview"
|
||||
GPT4_TURBO_11_06 = "gpt-4-1106-preview"
|
||||
GPT4_VISION_PREVIEW = "gpt-4-vision-preview"
|
||||
|
||||
GPT4 = "gpt-4"
|
||||
GPT_4o_MINI = "gpt-4o-mini"
|
||||
GPT4_32k = "gpt-4-32k"
|
||||
GPT4_06_13 = "gpt-4-0613"
|
||||
GPT4_32k_06_13 = "gpt-4-32k-0613"
|
||||
|
||||
WHISPER_1 = "whisper-1"
|
||||
TTS_1 = "tts-1"
|
||||
TTS_1_HD = "tts-1-hd"
|
||||
|
||||
WEN_XIN = "wenxin"
|
||||
WEN_XIN_4 = "wenxin-4"
|
||||
|
||||
QWEN_TURBO = "qwen-turbo"
|
||||
QWEN_PLUS = "qwen-plus"
|
||||
QWEN_MAX = "qwen-max"
|
||||
|
||||
LINKAI_35 = "linkai-3.5"
|
||||
LINKAI_4_TURBO = "linkai-4-turbo"
|
||||
LINKAI_4o = "linkai-4o"
|
||||
|
||||
GEMINI_PRO = "gemini-1.0-pro"
|
||||
GEMINI_15_flash = "gemini-1.5-flash"
|
||||
GEMINI_15_PRO = "gemini-1.5-pro"
|
||||
|
||||
MODEL_LIST = [
|
||||
GPT35, GPT35_0125, GPT35_1106, "gpt-3.5-turbo-16k",
|
||||
GPT_4o, GPT_4o_MINI, GPT4_TURBO, GPT4_TURBO_PREVIEW, GPT4_TURBO_01_25, GPT4_TURBO_11_06, GPT4, GPT4_32k, GPT4_06_13, GPT4_32k_06_13,
|
||||
WEN_XIN, WEN_XIN_4,
|
||||
XUNFEI, ZHIPU_AI, MOONSHOT, MiniMax,
|
||||
GEMINI, GEMINI_PRO, GEMINI_15_flash, GEMINI_15_PRO,
|
||||
"claude", "claude-3-haiku", "claude-3-sonnet", "claude-3-opus", "claude-3-opus-20240229", "claude-3.5-sonnet",
|
||||
"moonshot-v1-8k", "moonshot-v1-32k", "moonshot-v1-128k",
|
||||
QWEN, QWEN_TURBO, QWEN_PLUS, QWEN_MAX,
|
||||
LINKAI_35, LINKAI_4_TURBO, LINKAI_4o
|
||||
]
|
||||
|
||||
# channel
|
||||
FEISHU = "feishu"
|
||||
DINGTALK = "dingtalk"
|
||||
|
||||
33
common/dequeue.py
Normal file
33
common/dequeue.py
Normal file
@@ -0,0 +1,33 @@
|
||||
from queue import Full, Queue
|
||||
from time import monotonic as time
|
||||
|
||||
|
||||
# add implementation of putleft to Queue
|
||||
class Dequeue(Queue):
|
||||
def putleft(self, item, block=True, timeout=None):
|
||||
with self.not_full:
|
||||
if self.maxsize > 0:
|
||||
if not block:
|
||||
if self._qsize() >= self.maxsize:
|
||||
raise Full
|
||||
elif timeout is None:
|
||||
while self._qsize() >= self.maxsize:
|
||||
self.not_full.wait()
|
||||
elif timeout < 0:
|
||||
raise ValueError("'timeout' must be a non-negative number")
|
||||
else:
|
||||
endtime = time() + timeout
|
||||
while self._qsize() >= self.maxsize:
|
||||
remaining = endtime - time()
|
||||
if remaining <= 0.0:
|
||||
raise Full
|
||||
self.not_full.wait(remaining)
|
||||
self._putleft(item)
|
||||
self.unfinished_tasks += 1
|
||||
self.not_empty.notify()
|
||||
|
||||
def putleft_nowait(self, item):
|
||||
return self.putleft(item, block=False)
|
||||
|
||||
def _putleft(self, item):
|
||||
self.queue.appendleft(item)
|
||||
@@ -39,4 +39,4 @@ class ExpiredDict(dict):
|
||||
return [(key, self[key]) for key in self.keys()]
|
||||
|
||||
def __iter__(self):
|
||||
return self.keys().__iter__()
|
||||
return self.keys().__iter__()
|
||||
|
||||
105
common/linkai_client.py
Normal file
105
common/linkai_client.py
Normal file
@@ -0,0 +1,105 @@
|
||||
from bridge.context import Context, ContextType
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from common.log import logger
|
||||
from linkai import LinkAIClient, PushMsg
|
||||
from config import conf, pconf, plugin_config, available_setting
|
||||
from plugins import PluginManager
|
||||
import time
|
||||
|
||||
|
||||
chat_client: LinkAIClient
|
||||
|
||||
|
||||
class ChatClient(LinkAIClient):
|
||||
def __init__(self, api_key, host, channel):
|
||||
super().__init__(api_key, host)
|
||||
self.channel = channel
|
||||
self.client_type = channel.channel_type
|
||||
|
||||
def on_message(self, push_msg: PushMsg):
|
||||
session_id = push_msg.session_id
|
||||
msg_content = push_msg.msg_content
|
||||
logger.info(f"receive msg push, session_id={session_id}, msg_content={msg_content}")
|
||||
context = Context()
|
||||
context.type = ContextType.TEXT
|
||||
context["receiver"] = session_id
|
||||
context["isgroup"] = push_msg.is_group
|
||||
self.channel.send(Reply(ReplyType.TEXT, content=msg_content), context)
|
||||
|
||||
def on_config(self, config: dict):
|
||||
if not self.client_id:
|
||||
return
|
||||
logger.info(f"[LinkAI] 从客户端管理加载远程配置: {config}")
|
||||
if config.get("enabled") != "Y":
|
||||
return
|
||||
|
||||
local_config = conf()
|
||||
for key in config.keys():
|
||||
if key in available_setting and config.get(key) is not None:
|
||||
local_config[key] = config.get(key)
|
||||
# 语音配置
|
||||
reply_voice_mode = config.get("reply_voice_mode")
|
||||
if reply_voice_mode:
|
||||
if reply_voice_mode == "voice_reply_voice":
|
||||
local_config["voice_reply_voice"] = True
|
||||
elif reply_voice_mode == "always_reply_voice":
|
||||
local_config["always_reply_voice"] = True
|
||||
|
||||
if config.get("admin_password"):
|
||||
if not plugin_config.get("Godcmd"):
|
||||
plugin_config["Godcmd"] = {"password": config.get("admin_password"), "admin_users": []}
|
||||
else:
|
||||
plugin_config["Godcmd"]["password"] = config.get("admin_password")
|
||||
PluginManager().instances["GODCMD"].reload()
|
||||
|
||||
if config.get("group_app_map") and pconf("linkai"):
|
||||
local_group_map = {}
|
||||
for mapping in config.get("group_app_map"):
|
||||
local_group_map[mapping.get("group_name")] = mapping.get("app_code")
|
||||
pconf("linkai")["group_app_map"] = local_group_map
|
||||
PluginManager().instances["LINKAI"].reload()
|
||||
|
||||
if config.get("text_to_image") and config.get("text_to_image") == "midjourney" and pconf("linkai"):
|
||||
if pconf("linkai")["midjourney"]:
|
||||
pconf("linkai")["midjourney"]["enabled"] = True
|
||||
pconf("linkai")["midjourney"]["use_image_create_prefix"] = True
|
||||
elif config.get("text_to_image") and config.get("text_to_image") in ["dall-e-2", "dall-e-3"]:
|
||||
if pconf("linkai")["midjourney"]:
|
||||
pconf("linkai")["midjourney"]["use_image_create_prefix"] = False
|
||||
|
||||
|
||||
def start(channel):
|
||||
global chat_client
|
||||
chat_client = ChatClient(api_key=conf().get("linkai_api_key"), host="", channel=channel)
|
||||
chat_client.config = _build_config()
|
||||
chat_client.start()
|
||||
time.sleep(1.5)
|
||||
if chat_client.client_id:
|
||||
logger.info("[LinkAI] 可前往控制台进行线上登录和配置:https://link-ai.tech/console/clients")
|
||||
|
||||
|
||||
def _build_config():
|
||||
local_conf = conf()
|
||||
config = {
|
||||
"linkai_app_code": local_conf.get("linkai_app_code"),
|
||||
"single_chat_prefix": local_conf.get("single_chat_prefix"),
|
||||
"single_chat_reply_prefix": local_conf.get("single_chat_reply_prefix"),
|
||||
"single_chat_reply_suffix": local_conf.get("single_chat_reply_suffix"),
|
||||
"group_chat_prefix": local_conf.get("group_chat_prefix"),
|
||||
"group_chat_reply_prefix": local_conf.get("group_chat_reply_prefix"),
|
||||
"group_chat_reply_suffix": local_conf.get("group_chat_reply_suffix"),
|
||||
"group_name_white_list": local_conf.get("group_name_white_list"),
|
||||
"nick_name_black_list": local_conf.get("nick_name_black_list"),
|
||||
"speech_recognition": "Y" if local_conf.get("speech_recognition") else "N",
|
||||
"text_to_image": local_conf.get("text_to_image"),
|
||||
"image_create_prefix": local_conf.get("image_create_prefix")
|
||||
}
|
||||
if local_conf.get("always_reply_voice"):
|
||||
config["reply_voice_mode"] = "always_reply_voice"
|
||||
elif local_conf.get("voice_reply_voice"):
|
||||
config["reply_voice_mode"] = "voice_reply_voice"
|
||||
if pconf("linkai"):
|
||||
config["group_app_map"] = pconf("linkai").get("group_app_map")
|
||||
if plugin_config.get("Godcmd"):
|
||||
config["admin_password"] = plugin_config.get("Godcmd").get("password")
|
||||
return config
|
||||
@@ -2,15 +2,37 @@ import logging
|
||||
import sys
|
||||
|
||||
|
||||
def _get_logger():
|
||||
log = logging.getLogger('log')
|
||||
log.setLevel(logging.INFO)
|
||||
def _reset_logger(log):
|
||||
for handler in log.handlers:
|
||||
handler.close()
|
||||
log.removeHandler(handler)
|
||||
del handler
|
||||
log.handlers.clear()
|
||||
log.propagate = False
|
||||
console_handle = logging.StreamHandler(sys.stdout)
|
||||
console_handle.setFormatter(logging.Formatter('[%(levelname)s][%(asctime)s][%(filename)s:%(lineno)d] - %(message)s',
|
||||
datefmt='%Y-%m-%d %H:%M:%S'))
|
||||
console_handle.setFormatter(
|
||||
logging.Formatter(
|
||||
"[%(levelname)s][%(asctime)s][%(filename)s:%(lineno)d] - %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
)
|
||||
file_handle = logging.FileHandler("run.log", encoding="utf-8")
|
||||
file_handle.setFormatter(
|
||||
logging.Formatter(
|
||||
"[%(levelname)s][%(asctime)s][%(filename)s:%(lineno)d] - %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
)
|
||||
log.addHandler(file_handle)
|
||||
log.addHandler(console_handle)
|
||||
|
||||
|
||||
def _get_logger():
|
||||
log = logging.getLogger("log")
|
||||
_reset_logger(log)
|
||||
log.setLevel(logging.INFO)
|
||||
return log
|
||||
|
||||
|
||||
# 日志句柄
|
||||
logger = _get_logger()
|
||||
logger = _get_logger()
|
||||
|
||||
3
common/memory.py
Normal file
3
common/memory.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from common.expired_dict import ExpiredDict
|
||||
|
||||
USER_IMAGE_CACHE = ExpiredDict(60 * 3)
|
||||
36
common/package_manager.py
Normal file
36
common/package_manager.py
Normal file
@@ -0,0 +1,36 @@
|
||||
import time
|
||||
|
||||
import pip
|
||||
from pip._internal import main as pipmain
|
||||
|
||||
from common.log import _reset_logger, logger
|
||||
|
||||
|
||||
def install(package):
|
||||
pipmain(["install", package])
|
||||
|
||||
|
||||
def install_requirements(file):
|
||||
pipmain(["install", "-r", file, "--upgrade"])
|
||||
_reset_logger(logger)
|
||||
|
||||
|
||||
def check_dulwich():
|
||||
needwait = False
|
||||
for i in range(2):
|
||||
if needwait:
|
||||
time.sleep(3)
|
||||
needwait = False
|
||||
try:
|
||||
import dulwich
|
||||
|
||||
return
|
||||
except ImportError:
|
||||
try:
|
||||
install("dulwich")
|
||||
except:
|
||||
needwait = True
|
||||
try:
|
||||
import dulwich
|
||||
except ImportError:
|
||||
raise ImportError("Unable to import dulwich")
|
||||
@@ -62,4 +62,4 @@ class SortedDict(dict):
|
||||
return iter(self.keys())
|
||||
|
||||
def __repr__(self):
|
||||
return f'{type(self).__name__}({dict(self)}, sort_func={self.sort_func.__name__}, reverse={self.reverse})'
|
||||
return f"{type(self).__name__}({dict(self)}, sort_func={self.sort_func.__name__}, reverse={self.reverse})"
|
||||
|
||||
@@ -1,38 +1,42 @@
|
||||
import time,re,hashlib
|
||||
import re
|
||||
import time
|
||||
import config
|
||||
from common.log import logger
|
||||
|
||||
|
||||
def time_checker(f):
|
||||
def _time_checker(self, *args, **kwargs):
|
||||
_config = config.conf()
|
||||
chat_time_module = _config.get("chat_time_module", False)
|
||||
|
||||
if chat_time_module:
|
||||
chat_start_time = _config.get("chat_start_time", "00:00")
|
||||
chat_stopt_time = _config.get("chat_stop_time", "24:00")
|
||||
time_regex = re.compile(r'^([01]?[0-9]|2[0-4])(:)([0-5][0-9])$') #时间匹配,包含24:00
|
||||
chat_stop_time = _config.get("chat_stop_time", "24:00")
|
||||
|
||||
starttime_format_check = time_regex.match(chat_start_time) # 检查停止时间格式
|
||||
stoptime_format_check = time_regex.match(chat_stopt_time) # 检查停止时间格式
|
||||
chat_time_check = chat_start_time < chat_stopt_time # 确定启动时间<停止时间
|
||||
time_regex = re.compile(r"^([01]?[0-9]|2[0-4])(:)([0-5][0-9])$")
|
||||
|
||||
# 时间格式检查
|
||||
if not (starttime_format_check and stoptime_format_check and chat_time_check):
|
||||
logger.warn('时间格式不正确,请在config.json中修改您的CHAT_START_TIME/CHAT_STOP_TIME,否则可能会影响您正常使用,开始({})-结束({})'.format(starttime_format_check,stoptime_format_check))
|
||||
if chat_start_time>"23:59":
|
||||
logger.error('启动时间可能存在问题,请修改!')
|
||||
|
||||
# 服务时间检查
|
||||
now_time = time.strftime("%H:%M", time.localtime())
|
||||
if chat_start_time <= now_time <= chat_stopt_time: # 服务时间内,正常返回回答
|
||||
f(self, *args, **kwargs)
|
||||
if not (time_regex.match(chat_start_time) and time_regex.match(chat_stop_time)):
|
||||
logger.warning("时间格式不正确,请在config.json中修改CHAT_START_TIME/CHAT_STOP_TIME。")
|
||||
return None
|
||||
|
||||
now_time = time.strptime(time.strftime("%H:%M"), "%H:%M")
|
||||
chat_start_time = time.strptime(chat_start_time, "%H:%M")
|
||||
chat_stop_time = time.strptime(chat_stop_time, "%H:%M")
|
||||
# 结束时间小于开始时间,跨天了
|
||||
if chat_stop_time < chat_start_time and (chat_start_time <= now_time or now_time <= chat_stop_time):
|
||||
f(self, *args, **kwargs)
|
||||
# 结束大于开始时间代表,没有跨天
|
||||
elif chat_start_time < chat_stop_time and chat_start_time <= now_time <= chat_stop_time:
|
||||
f(self, *args, **kwargs)
|
||||
else:
|
||||
if args[0]['Content'] == "#更新配置": # 不在服务时间内也可以更新配置
|
||||
# 定义匹配规则,如果以 #reconf 或者 #更新配置 结尾, 非服务时间可以修改开始/结束时间并重载配置
|
||||
pattern = re.compile(r"^.*#(?:reconf|更新配置)$")
|
||||
if args and pattern.match(args[0].content):
|
||||
f(self, *args, **kwargs)
|
||||
else:
|
||||
logger.info('非服务时间内,不接受访问')
|
||||
logger.info("非服务时间内,不接受访问")
|
||||
return None
|
||||
else:
|
||||
f(self, *args, **kwargs) # 未开启时间模块则直接回答
|
||||
return _time_checker
|
||||
|
||||
return _time_checker
|
||||
|
||||
@@ -1,20 +1,18 @@
|
||||
|
||||
import os
|
||||
import pathlib
|
||||
|
||||
from config import conf
|
||||
|
||||
|
||||
class TmpDir(object):
|
||||
"""A temporary directory that is deleted when the object is destroyed.
|
||||
"""
|
||||
"""A temporary directory that is deleted when the object is destroyed."""
|
||||
|
||||
tmpFilePath = pathlib.Path("./tmp/")
|
||||
|
||||
tmpFilePath = pathlib.Path('./tmp/')
|
||||
|
||||
def __init__(self):
|
||||
pathExists = os.path.exists(self.tmpFilePath)
|
||||
if not pathExists and conf().get('speech_recognition') == True:
|
||||
if not pathExists:
|
||||
os.makedirs(self.tmpFilePath)
|
||||
|
||||
def path(self):
|
||||
return str(self.tmpFilePath) + '/'
|
||||
|
||||
return str(self.tmpFilePath) + "/"
|
||||
|
||||
56
common/utils.py
Normal file
56
common/utils.py
Normal file
@@ -0,0 +1,56 @@
|
||||
import io
|
||||
import os
|
||||
from urllib.parse import urlparse
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def fsize(file):
|
||||
if isinstance(file, io.BytesIO):
|
||||
return file.getbuffer().nbytes
|
||||
elif isinstance(file, str):
|
||||
return os.path.getsize(file)
|
||||
elif hasattr(file, "seek") and hasattr(file, "tell"):
|
||||
pos = file.tell()
|
||||
file.seek(0, os.SEEK_END)
|
||||
size = file.tell()
|
||||
file.seek(pos)
|
||||
return size
|
||||
else:
|
||||
raise TypeError("Unsupported type")
|
||||
|
||||
|
||||
def compress_imgfile(file, max_size):
|
||||
if fsize(file) <= max_size:
|
||||
return file
|
||||
file.seek(0)
|
||||
img = Image.open(file)
|
||||
rgb_image = img.convert("RGB")
|
||||
quality = 95
|
||||
while True:
|
||||
out_buf = io.BytesIO()
|
||||
rgb_image.save(out_buf, "JPEG", quality=quality)
|
||||
if fsize(out_buf) <= max_size:
|
||||
return out_buf
|
||||
quality -= 5
|
||||
|
||||
|
||||
def split_string_by_utf8_length(string, max_length, max_split=0):
|
||||
encoded = string.encode("utf-8")
|
||||
start, end = 0, 0
|
||||
result = []
|
||||
while end < len(encoded):
|
||||
if max_split > 0 and len(result) >= max_split:
|
||||
result.append(encoded[start:].decode("utf-8"))
|
||||
break
|
||||
end = min(start + max_length, len(encoded))
|
||||
# 如果当前字节不是 UTF-8 编码的开始字节,则向前查找直到找到开始字节为止
|
||||
while end < len(encoded) and (encoded[end] & 0b11000000) == 0b10000000:
|
||||
end -= 1
|
||||
result.append(encoded[start:end].decode("utf-8"))
|
||||
start = end
|
||||
return result
|
||||
|
||||
|
||||
def get_path_suffix(path):
|
||||
path = urlparse(path).path
|
||||
return os.path.splitext(path)[-1].lstrip('.')
|
||||
@@ -1,17 +1,37 @@
|
||||
{
|
||||
"channel_type": "wx",
|
||||
"model": "",
|
||||
"open_ai_api_key": "YOUR API KEY",
|
||||
"model": "gpt-3.5-turbo",
|
||||
"claude_api_key": "YOUR API KEY",
|
||||
"text_to_image": "dall-e-2",
|
||||
"voice_to_text": "openai",
|
||||
"text_to_voice": "openai",
|
||||
"proxy": "",
|
||||
"use_azure_chatgpt": false,
|
||||
"single_chat_prefix": ["bot", "@bot"],
|
||||
"hot_reload": false,
|
||||
"single_chat_prefix": [
|
||||
"bot",
|
||||
"@bot"
|
||||
],
|
||||
"single_chat_reply_prefix": "[bot] ",
|
||||
"group_chat_prefix": ["@bot"],
|
||||
"group_name_white_list": ["ChatGPT测试群", "ChatGPT测试群2"],
|
||||
"group_chat_in_one_session": ["ChatGPT测试群"],
|
||||
"image_create_prefix": ["画", "看", "找"],
|
||||
"speech_recognition": false,
|
||||
"group_chat_prefix": [
|
||||
"@bot"
|
||||
],
|
||||
"group_name_white_list": [
|
||||
"ChatGPT测试群",
|
||||
"ChatGPT测试群2"
|
||||
],
|
||||
"image_create_prefix": [
|
||||
"画"
|
||||
],
|
||||
"speech_recognition": true,
|
||||
"group_speech_recognition": false,
|
||||
"voice_reply_voice": false,
|
||||
"conversation_max_tokens": 1000,
|
||||
"conversation_max_tokens": 2500,
|
||||
"expires_in_seconds": 3600,
|
||||
"character_desc": "你是ChatGPT, 一个由OpenAI训练的大型语言模型, 你旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。"
|
||||
}
|
||||
"character_desc": "你是基于大语言模型的AI智能助手,旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。",
|
||||
"temperature": 0.7,
|
||||
"subscribe_msg": "感谢您的关注!\n这里是AI智能助手,可以自由对话。\n支持语音对话。\n支持图片输入。\n支持图片输出,画字开头的消息将按要求创作图片。\n支持tool、角色扮演和文字冒险等丰富的插件。\n输入{trigger_prefix}#help 查看详细指令。",
|
||||
"use_linkai": false,
|
||||
"linkai_api_key": "",
|
||||
"linkai_app_code": ""
|
||||
}
|
||||
|
||||
346
config.py
346
config.py
@@ -1,75 +1,193 @@
|
||||
# encoding:utf-8
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import pickle
|
||||
import copy
|
||||
|
||||
from common.log import logger
|
||||
|
||||
# 将所有可用的配置项写在字典里, 请使用小写字母
|
||||
available_setting ={
|
||||
#openai api配置
|
||||
"open_ai_api_key": "", # openai api key
|
||||
"open_ai_api_base": "https://api.openai.com/v1", # openai apibase,当use_azure_chatgpt为true时,需要设置对应的api base
|
||||
"proxy": "", # openai使用的代理
|
||||
"model": "gpt-3.5-turbo", # chatgpt模型, 当use_azure_chatgpt为true时,其名称为Azure上model deployment名称
|
||||
"use_azure_chatgpt": False, # 是否使用azure的chatgpt
|
||||
|
||||
#Bot触发配置
|
||||
"single_chat_prefix": ["bot", "@bot"], # 私聊时文本需要包含该前缀才能触发机器人回复
|
||||
"single_chat_reply_prefix": "[bot] ", # 私聊时自动回复的前缀,用于区分真人
|
||||
"group_chat_prefix": ["@bot"], # 群聊时包含该前缀则会触发机器人回复
|
||||
"group_chat_reply_prefix": "", # 群聊时自动回复的前缀
|
||||
"group_chat_keyword": [], # 群聊时包含该关键词则会触发机器人回复
|
||||
"group_at_off": False, # 是否关闭群聊时@bot的触发
|
||||
"group_name_white_list": ["ChatGPT测试群", "ChatGPT测试群2"], # 开启自动回复的群名称列表
|
||||
"group_name_keyword_white_list": [], # 开启自动回复的群名称关键词列表
|
||||
"group_chat_in_one_session": ["ChatGPT测试群"], # 支持会话上下文共享的群名称
|
||||
"image_create_prefix": ["画", "看", "找"], # 开启图片回复的前缀
|
||||
|
||||
#chatgpt会话参数
|
||||
"expires_in_seconds": 3600, # 无操作会话的过期时间
|
||||
"character_desc": "你是ChatGPT, 一个由OpenAI训练的大型语言模型, 你旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。", # 人格描述
|
||||
"conversation_max_tokens": 1000, # 支持上下文记忆的最多字符数
|
||||
|
||||
#chatgpt限流配置
|
||||
"rate_limit_chatgpt": 20, # chatgpt的调用频率限制
|
||||
"rate_limit_dalle": 50, # openai dalle的调用频率限制
|
||||
|
||||
|
||||
#chatgpt api参数 参考https://platform.openai.com/docs/api-reference/chat/create
|
||||
# 此处的配置值无实际意义,程序不会读取此处的配置,仅用于提示格式,请将配置加入到config.json中
|
||||
available_setting = {
|
||||
# openai api配置
|
||||
"open_ai_api_key": "", # openai api key
|
||||
# openai apibase,当use_azure_chatgpt为true时,需要设置对应的api base
|
||||
"open_ai_api_base": "https://api.openai.com/v1",
|
||||
"proxy": "", # openai使用的代理
|
||||
# chatgpt模型, 当use_azure_chatgpt为true时,其名称为Azure上model deployment名称
|
||||
"model": "gpt-3.5-turbo", # 可选择: gpt-4o, pt-4o-mini, gpt-4-turbo, claude-3-sonnet, wenxin, moonshot, qwen-turbo, xunfei, glm-4, minimax, gemini等模型,全部可选模型详见common/const.py文件
|
||||
"bot_type": "", # 可选配置,使用兼容openai格式的三方服务时候,需填"chatGPT"。bot具体名称详见common/const.py文件列出的bot_type,如不填根据model名称判断,
|
||||
"use_azure_chatgpt": False, # 是否使用azure的chatgpt
|
||||
"azure_deployment_id": "", # azure 模型部署名称
|
||||
"azure_api_version": "", # azure api版本
|
||||
# Bot触发配置
|
||||
"single_chat_prefix": ["bot", "@bot"], # 私聊时文本需要包含该前缀才能触发机器人回复
|
||||
"single_chat_reply_prefix": "[bot] ", # 私聊时自动回复的前缀,用于区分真人
|
||||
"single_chat_reply_suffix": "", # 私聊时自动回复的后缀,\n 可以换行
|
||||
"group_chat_prefix": ["@bot"], # 群聊时包含该前缀则会触发机器人回复
|
||||
"group_chat_reply_prefix": "", # 群聊时自动回复的前缀
|
||||
"group_chat_reply_suffix": "", # 群聊时自动回复的后缀,\n 可以换行
|
||||
"group_chat_keyword": [], # 群聊时包含该关键词则会触发机器人回复
|
||||
"group_at_off": False, # 是否关闭群聊时@bot的触发
|
||||
"group_name_white_list": ["ChatGPT测试群", "ChatGPT测试群2"], # 开启自动回复的群名称列表
|
||||
"group_name_keyword_white_list": [], # 开启自动回复的群名称关键词列表
|
||||
"group_chat_in_one_session": ["ChatGPT测试群"], # 支持会话上下文共享的群名称
|
||||
"nick_name_black_list": [], # 用户昵称黑名单
|
||||
"group_welcome_msg": "", # 配置新人进群固定欢迎语,不配置则使用随机风格欢迎
|
||||
"trigger_by_self": False, # 是否允许机器人触发
|
||||
"text_to_image": "dall-e-2", # 图片生成模型,可选 dall-e-2, dall-e-3
|
||||
# Azure OpenAI dall-e-3 配置
|
||||
"dalle3_image_style": "vivid", # 图片生成dalle3的风格,可选有 vivid, natural
|
||||
"dalle3_image_quality": "hd", # 图片生成dalle3的质量,可选有 standard, hd
|
||||
# Azure OpenAI DALL-E API 配置, 当use_azure_chatgpt为true时,用于将文字回复的资源和Dall-E的资源分开.
|
||||
"azure_openai_dalle_api_base": "", # [可选] azure openai 用于回复图片的资源 endpoint,默认使用 open_ai_api_base
|
||||
"azure_openai_dalle_api_key": "", # [可选] azure openai 用于回复图片的资源 key,默认使用 open_ai_api_key
|
||||
"azure_openai_dalle_deployment_id":"", # [可选] azure openai 用于回复图片的资源 deployment id,默认使用 text_to_image
|
||||
"image_proxy": True, # 是否需要图片代理,国内访问LinkAI时需要
|
||||
"image_create_prefix": ["画", "看", "找"], # 开启图片回复的前缀
|
||||
"concurrency_in_session": 1, # 同一会话最多有多少条消息在处理中,大于1可能乱序
|
||||
"image_create_size": "256x256", # 图片大小,可选有 256x256, 512x512, 1024x1024 (dall-e-3默认为1024x1024)
|
||||
"group_chat_exit_group": False,
|
||||
# chatgpt会话参数
|
||||
"expires_in_seconds": 3600, # 无操作会话的过期时间
|
||||
# 人格描述
|
||||
"character_desc": "你是ChatGPT, 一个由OpenAI训练的大型语言模型, 你旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。",
|
||||
"conversation_max_tokens": 1000, # 支持上下文记忆的最多字符数
|
||||
# chatgpt限流配置
|
||||
"rate_limit_chatgpt": 20, # chatgpt的调用频率限制
|
||||
"rate_limit_dalle": 50, # openai dalle的调用频率限制
|
||||
# chatgpt api参数 参考https://platform.openai.com/docs/api-reference/chat/create
|
||||
"temperature": 0.9,
|
||||
"top_p": 1,
|
||||
"frequency_penalty": 0,
|
||||
"presence_penalty": 0,
|
||||
|
||||
#语音设置
|
||||
"speech_recognition": False, # 是否开启语音识别
|
||||
"voice_reply_voice": False, # 是否使用语音回复语音,需要设置对应语音合成引擎的api key
|
||||
"voice_to_text": "openai", # 语音识别引擎,支持openai和google
|
||||
"text_to_voice": "baidu", # 语音合成引擎,支持baidu和google
|
||||
|
||||
# baidu api的配置, 使用百度语音识别和语音合成时需要
|
||||
'baidu_app_id': "",
|
||||
'baidu_api_key': "",
|
||||
'baidu_secret_key': "",
|
||||
|
||||
#服务时间限制,目前支持itchat
|
||||
"chat_time_module": False, # 是否开启服务时间限制
|
||||
"chat_start_time": "00:00", # 服务开始时间
|
||||
"chat_stop_time": "24:00", # 服务结束时间
|
||||
|
||||
"request_timeout": 180, # chatgpt请求超时时间,openai接口默认设置为600,对于难问题一般需要较长时间
|
||||
"timeout": 120, # chatgpt重试超时时间,在这个时间内,将会自动重试
|
||||
# Baidu 文心一言参数
|
||||
"baidu_wenxin_model": "eb-instant", # 默认使用ERNIE-Bot-turbo模型
|
||||
"baidu_wenxin_api_key": "", # Baidu api key
|
||||
"baidu_wenxin_secret_key": "", # Baidu secret key
|
||||
# 讯飞星火API
|
||||
"xunfei_app_id": "", # 讯飞应用ID
|
||||
"xunfei_api_key": "", # 讯飞 API key
|
||||
"xunfei_api_secret": "", # 讯飞 API secret
|
||||
# claude 配置
|
||||
"claude_api_cookie": "",
|
||||
"claude_uuid": "",
|
||||
# claude api key
|
||||
"claude_api_key": "",
|
||||
# 通义千问API, 获取方式查看文档 https://help.aliyun.com/document_detail/2587494.html
|
||||
"qwen_access_key_id": "",
|
||||
"qwen_access_key_secret": "",
|
||||
"qwen_agent_key": "",
|
||||
"qwen_app_id": "",
|
||||
"qwen_node_id": "", # 流程编排模型用到的id,如果没有用到qwen_node_id,请务必保持为空字符串
|
||||
# 阿里灵积(通义新版sdk)模型api key
|
||||
"dashscope_api_key": "",
|
||||
# Google Gemini Api Key
|
||||
"gemini_api_key": "",
|
||||
# wework的通用配置
|
||||
"wework_smart": True, # 配置wework是否使用已登录的企业微信,False为多开
|
||||
# 语音设置
|
||||
"speech_recognition": True, # 是否开启语音识别
|
||||
"group_speech_recognition": False, # 是否开启群组语音识别
|
||||
"voice_reply_voice": False, # 是否使用语音回复语音,需要设置对应语音合成引擎的api key
|
||||
"always_reply_voice": False, # 是否一直使用语音回复
|
||||
"voice_to_text": "openai", # 语音识别引擎,支持openai,baidu,google,azure,xunfei,ali
|
||||
"text_to_voice": "openai", # 语音合成引擎,支持openai,baidu,google,azure,xunfei,ali,pytts(offline),elevenlabs,edge(online)
|
||||
"text_to_voice_model": "tts-1",
|
||||
"tts_voice_id": "alloy",
|
||||
# baidu 语音api配置, 使用百度语音识别和语音合成时需要
|
||||
"baidu_app_id": "",
|
||||
"baidu_api_key": "",
|
||||
"baidu_secret_key": "",
|
||||
# 1536普通话(支持简单的英文识别) 1737英语 1637粤语 1837四川话 1936普通话远场
|
||||
"baidu_dev_pid": 1536,
|
||||
# azure 语音api配置, 使用azure语音识别和语音合成时需要
|
||||
"azure_voice_api_key": "",
|
||||
"azure_voice_region": "japaneast",
|
||||
# elevenlabs 语音api配置
|
||||
"xi_api_key": "", # 获取ap的方法可以参考https://docs.elevenlabs.io/api-reference/quick-start/authentication
|
||||
"xi_voice_id": "", # ElevenLabs提供了9种英式、美式等英语发音id,分别是“Adam/Antoni/Arnold/Bella/Domi/Elli/Josh/Rachel/Sam”
|
||||
# 服务时间限制,目前支持itchat
|
||||
"chat_time_module": False, # 是否开启服务时间限制
|
||||
"chat_start_time": "00:00", # 服务开始时间
|
||||
"chat_stop_time": "24:00", # 服务结束时间
|
||||
# 翻译api
|
||||
"translate": "baidu", # 翻译api,支持baidu
|
||||
# baidu翻译api的配置
|
||||
"baidu_translate_app_id": "", # 百度翻译api的appid
|
||||
"baidu_translate_app_key": "", # 百度翻译api的秘钥
|
||||
# itchat的配置
|
||||
"hot_reload": False, # 是否开启热重载
|
||||
|
||||
"hot_reload": False, # 是否开启热重载
|
||||
# wechaty的配置
|
||||
"wechaty_puppet_service_token": "", # wechaty的token
|
||||
|
||||
"wechaty_puppet_service_token": "", # wechaty的token
|
||||
# wechatmp的配置
|
||||
"wechatmp_token": "", # 微信公众平台的Token
|
||||
"wechatmp_port": 8080, # 微信公众平台的端口,需要端口转发到80或443
|
||||
"wechatmp_app_id": "", # 微信公众平台的appID
|
||||
"wechatmp_app_secret": "", # 微信公众平台的appsecret
|
||||
"wechatmp_aes_key": "", # 微信公众平台的EncodingAESKey,加密模式需要
|
||||
# wechatcom的通用配置
|
||||
"wechatcom_corp_id": "", # 企业微信公司的corpID
|
||||
# wechatcomapp的配置
|
||||
"wechatcomapp_token": "", # 企业微信app的token
|
||||
"wechatcomapp_port": 9898, # 企业微信app的服务端口,不需要端口转发
|
||||
"wechatcomapp_secret": "", # 企业微信app的secret
|
||||
"wechatcomapp_agent_id": "", # 企业微信app的agent_id
|
||||
"wechatcomapp_aes_key": "", # 企业微信app的aes_key
|
||||
# 飞书配置
|
||||
"feishu_port": 80, # 飞书bot监听端口
|
||||
"feishu_app_id": "", # 飞书机器人应用APP Id
|
||||
"feishu_app_secret": "", # 飞书机器人APP secret
|
||||
"feishu_token": "", # 飞书 verification token
|
||||
"feishu_bot_name": "", # 飞书机器人的名字
|
||||
# 钉钉配置
|
||||
"dingtalk_client_id": "", # 钉钉机器人Client ID
|
||||
"dingtalk_client_secret": "", # 钉钉机器人Client Secret
|
||||
"dingtalk_card_enabled": False,
|
||||
|
||||
# chatgpt指令自定义触发词
|
||||
"clear_memory_commands": ['#清除记忆'], # 重置会话指令
|
||||
|
||||
|
||||
"clear_memory_commands": ["#清除记忆"], # 重置会话指令,必须以#开头
|
||||
# channel配置
|
||||
"channel_type": "", # 通道类型,支持:{wx,wxy,terminal,wechatmp,wechatmp_service,wechatcom_app,dingtalk}
|
||||
"subscribe_msg": "", # 订阅消息, 支持: wechatmp, wechatmp_service, wechatcom_app
|
||||
"debug": False, # 是否开启debug模式,开启后会打印更多日志
|
||||
"appdata_dir": "", # 数据目录
|
||||
# 插件配置
|
||||
"plugin_trigger_prefix": "$", # 规范插件提供聊天相关指令的前缀,建议不要和管理员指令前缀"#"冲突
|
||||
# 是否使用全局插件配置
|
||||
"use_global_plugin_config": False,
|
||||
"max_media_send_count": 3, # 单次最大发送媒体资源的个数
|
||||
"media_send_interval": 1, # 发送图片的事件间隔,单位秒
|
||||
# 智谱AI 平台配置
|
||||
"zhipu_ai_api_key": "",
|
||||
"zhipu_ai_api_base": "https://open.bigmodel.cn/api/paas/v4",
|
||||
"moonshot_api_key": "",
|
||||
"moonshot_base_url": "https://api.moonshot.cn/v1/chat/completions",
|
||||
# LinkAI平台配置
|
||||
"use_linkai": False,
|
||||
"linkai_api_key": "",
|
||||
"linkai_app_code": "",
|
||||
"linkai_api_base": "https://api.link-ai.tech", # linkAI服务地址
|
||||
"Minimax_api_key": "",
|
||||
"Minimax_group_id": "",
|
||||
"Minimax_base_url": "",
|
||||
}
|
||||
|
||||
|
||||
class Config(dict):
|
||||
def __init__(self, d=None):
|
||||
super().__init__()
|
||||
if d is None:
|
||||
d = {}
|
||||
for k, v in d.items():
|
||||
self[k] = v
|
||||
# user_datas: 用户数据,key为用户名,value为用户数据,也是dict
|
||||
self.user_datas = {}
|
||||
|
||||
def __getitem__(self, key):
|
||||
if key not in available_setting:
|
||||
raise Exception("key {} not in available_setting".format(key))
|
||||
@@ -81,24 +199,75 @@ class Config(dict):
|
||||
return super().__setitem__(key, value)
|
||||
|
||||
def get(self, key, default=None):
|
||||
try :
|
||||
try:
|
||||
return self[key]
|
||||
except KeyError as e:
|
||||
return default
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
|
||||
# Make sure to return a dictionary to ensure atomic
|
||||
def get_user_data(self, user) -> dict:
|
||||
if self.user_datas.get(user) is None:
|
||||
self.user_datas[user] = {}
|
||||
return self.user_datas[user]
|
||||
|
||||
def load_user_datas(self):
|
||||
try:
|
||||
with open(os.path.join(get_appdata_dir(), "user_datas.pkl"), "rb") as f:
|
||||
self.user_datas = pickle.load(f)
|
||||
logger.info("[Config] User datas loaded.")
|
||||
except FileNotFoundError as e:
|
||||
logger.info("[Config] User datas file not found, ignore.")
|
||||
except Exception as e:
|
||||
logger.info("[Config] User datas error: {}".format(e))
|
||||
self.user_datas = {}
|
||||
|
||||
def save_user_datas(self):
|
||||
try:
|
||||
with open(os.path.join(get_appdata_dir(), "user_datas.pkl"), "wb") as f:
|
||||
pickle.dump(self.user_datas, f)
|
||||
logger.info("[Config] User datas saved.")
|
||||
except Exception as e:
|
||||
logger.info("[Config] User datas error: {}".format(e))
|
||||
|
||||
|
||||
config = Config()
|
||||
|
||||
|
||||
def drag_sensitive(config):
|
||||
try:
|
||||
if isinstance(config, str):
|
||||
conf_dict: dict = json.loads(config)
|
||||
conf_dict_copy = copy.deepcopy(conf_dict)
|
||||
for key in conf_dict_copy:
|
||||
if "key" in key or "secret" in key:
|
||||
if isinstance(conf_dict_copy[key], str):
|
||||
conf_dict_copy[key] = conf_dict_copy[key][0:3] + "*" * 5 + conf_dict_copy[key][-3:]
|
||||
return json.dumps(conf_dict_copy, indent=4)
|
||||
|
||||
elif isinstance(config, dict):
|
||||
config_copy = copy.deepcopy(config)
|
||||
for key in config:
|
||||
if "key" in key or "secret" in key:
|
||||
if isinstance(config_copy[key], str):
|
||||
config_copy[key] = config_copy[key][0:3] + "*" * 5 + config_copy[key][-3:]
|
||||
return config_copy
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
return config
|
||||
return config
|
||||
|
||||
|
||||
def load_config():
|
||||
global config
|
||||
config_path = "./config.json"
|
||||
if not os.path.exists(config_path):
|
||||
logger.info('配置文件不存在,将使用config-template.json模板')
|
||||
logger.info("配置文件不存在,将使用config-template.json模板")
|
||||
config_path = "./config-template.json"
|
||||
|
||||
config_str = read_file(config_path)
|
||||
logger.debug("[INIT] config str: {}".format(config_str))
|
||||
logger.debug("[INIT] config str: {}".format(drag_sensitive(config_str)))
|
||||
|
||||
# 将json字符串反序列化为dict类型
|
||||
config = Config(json.loads(config_str))
|
||||
@@ -112,20 +281,71 @@ def load_config():
|
||||
try:
|
||||
config[name] = eval(value)
|
||||
except:
|
||||
config[name] = value
|
||||
if value == "false":
|
||||
config[name] = False
|
||||
elif value == "true":
|
||||
config[name] = True
|
||||
else:
|
||||
config[name] = value
|
||||
|
||||
logger.info("[INIT] load config: {}".format(config))
|
||||
if config.get("debug", False):
|
||||
logger.setLevel(logging.DEBUG)
|
||||
logger.debug("[INIT] set log level to DEBUG")
|
||||
|
||||
logger.info("[INIT] load config: {}".format(drag_sensitive(config)))
|
||||
|
||||
config.load_user_datas()
|
||||
|
||||
|
||||
def get_root():
|
||||
return os.path.dirname(os.path.abspath( __file__ ))
|
||||
return os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
|
||||
def read_file(path):
|
||||
with open(path, mode='r', encoding='utf-8') as f:
|
||||
with open(path, mode="r", encoding="utf-8") as f:
|
||||
return f.read()
|
||||
|
||||
|
||||
def conf():
|
||||
return config
|
||||
|
||||
|
||||
def get_appdata_dir():
|
||||
data_path = os.path.join(get_root(), conf().get("appdata_dir", ""))
|
||||
if not os.path.exists(data_path):
|
||||
logger.info("[INIT] data path not exists, create it: {}".format(data_path))
|
||||
os.makedirs(data_path)
|
||||
return data_path
|
||||
|
||||
|
||||
def subscribe_msg():
|
||||
trigger_prefix = conf().get("single_chat_prefix", [""])[0]
|
||||
msg = conf().get("subscribe_msg", "")
|
||||
return msg.format(trigger_prefix=trigger_prefix)
|
||||
|
||||
|
||||
# global plugin config
|
||||
plugin_config = {}
|
||||
|
||||
|
||||
def write_plugin_config(pconf: dict):
|
||||
"""
|
||||
写入插件全局配置
|
||||
:param pconf: 全量插件配置
|
||||
"""
|
||||
global plugin_config
|
||||
for k in pconf:
|
||||
plugin_config[k.lower()] = pconf[k]
|
||||
|
||||
|
||||
def pconf(plugin_name: str) -> dict:
|
||||
"""
|
||||
根据插件名称获取配置
|
||||
:param plugin_name: 插件名称
|
||||
:return: 该插件的配置项
|
||||
"""
|
||||
return plugin_config.get(plugin_name.lower())
|
||||
|
||||
|
||||
# 全局配置,用于存放全局生效的状态
|
||||
global_config = {"admin_users": []}
|
||||
|
||||
@@ -1,40 +0,0 @@
|
||||
FROM python:3.7.9-alpine
|
||||
|
||||
LABEL maintainer="foo@bar.com"
|
||||
ARG TZ='Asia/Shanghai'
|
||||
|
||||
ARG CHATGPT_ON_WECHAT_VER
|
||||
|
||||
ENV BUILD_PREFIX=/app
|
||||
|
||||
RUN apk add --no-cache \
|
||||
bash \
|
||||
curl \
|
||||
wget \
|
||||
&& export BUILD_GITHUB_TAG=${CHATGPT_ON_WECHAT_VER:-`curl -sL "https://api.github.com/repos/zhayujie/chatgpt-on-wechat/releases/latest" | \
|
||||
grep '"tag_name":' | \
|
||||
sed -E 's/.*"([^"]+)".*/\1/'`} \
|
||||
&& wget -t 3 -T 30 -nv -O chatgpt-on-wechat-${BUILD_GITHUB_TAG}.tar.gz \
|
||||
https://github.com/zhayujie/chatgpt-on-wechat/archive/refs/tags/${BUILD_GITHUB_TAG}.tar.gz \
|
||||
&& tar -xzf chatgpt-on-wechat-${BUILD_GITHUB_TAG}.tar.gz \
|
||||
&& mv chatgpt-on-wechat-${BUILD_GITHUB_TAG} ${BUILD_PREFIX} \
|
||||
&& rm chatgpt-on-wechat-${BUILD_GITHUB_TAG}.tar.gz \
|
||||
&& cd ${BUILD_PREFIX} \
|
||||
&& cp config-template.json ${BUILD_PREFIX}/config.json \
|
||||
&& /usr/local/bin/python -m pip install --no-cache --upgrade pip \
|
||||
&& pip install --no-cache \
|
||||
itchat-uos==1.5.0.dev0 \
|
||||
openai \
|
||||
&& apk del curl wget
|
||||
|
||||
WORKDIR ${BUILD_PREFIX}
|
||||
|
||||
ADD ./entrypoint.sh /entrypoint.sh
|
||||
|
||||
RUN chmod +x /entrypoint.sh \
|
||||
&& adduser -D -h /home/noroot -u 1000 -s /bin/bash noroot \
|
||||
&& chown -R noroot:noroot ${BUILD_PREFIX}
|
||||
|
||||
USER noroot
|
||||
|
||||
ENTRYPOINT ["/entrypoint.sh"]
|
||||
@@ -1,41 +0,0 @@
|
||||
FROM python:3.7.9
|
||||
|
||||
LABEL maintainer="foo@bar.com"
|
||||
ARG TZ='Asia/Shanghai'
|
||||
|
||||
ARG CHATGPT_ON_WECHAT_VER
|
||||
|
||||
ENV BUILD_PREFIX=/app
|
||||
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y --no-install-recommends \
|
||||
wget \
|
||||
curl \
|
||||
&& rm -rf /var/lib/apt/lists/* \
|
||||
&& export BUILD_GITHUB_TAG=${CHATGPT_ON_WECHAT_VER:-`curl -sL "https://api.github.com/repos/zhayujie/chatgpt-on-wechat/releases/latest" | \
|
||||
grep '"tag_name":' | \
|
||||
sed -E 's/.*"([^"]+)".*/\1/'`} \
|
||||
&& wget -t 3 -T 30 -nv -O chatgpt-on-wechat-${BUILD_GITHUB_TAG}.tar.gz \
|
||||
https://github.com/zhayujie/chatgpt-on-wechat/archive/refs/tags/${BUILD_GITHUB_TAG}.tar.gz \
|
||||
&& tar -xzf chatgpt-on-wechat-${BUILD_GITHUB_TAG}.tar.gz \
|
||||
&& mv chatgpt-on-wechat-${BUILD_GITHUB_TAG} ${BUILD_PREFIX} \
|
||||
&& rm chatgpt-on-wechat-${BUILD_GITHUB_TAG}.tar.gz \
|
||||
&& cd ${BUILD_PREFIX} \
|
||||
&& cp config-template.json ${BUILD_PREFIX}/config.json \
|
||||
&& /usr/local/bin/python -m pip install --no-cache --upgrade pip \
|
||||
&& pip install --no-cache \
|
||||
itchat-uos==1.5.0.dev0 \
|
||||
openai
|
||||
|
||||
WORKDIR ${BUILD_PREFIX}
|
||||
|
||||
ADD ./entrypoint.sh /entrypoint.sh
|
||||
|
||||
RUN chmod +x /entrypoint.sh \
|
||||
&& groupadd -r noroot \
|
||||
&& useradd -r -g noroot -s /bin/bash -d /home/noroot noroot \
|
||||
&& chown -R noroot:noroot ${BUILD_PREFIX}
|
||||
|
||||
USER noroot
|
||||
|
||||
ENTRYPOINT ["/entrypoint.sh"]
|
||||
@@ -1,33 +1,35 @@
|
||||
FROM python:3.7.9-alpine
|
||||
FROM python:3.10-slim-bullseye
|
||||
|
||||
LABEL maintainer="foo@bar.com"
|
||||
ARG TZ='Asia/Shanghai'
|
||||
|
||||
ARG CHATGPT_ON_WECHAT_VER
|
||||
|
||||
RUN echo /etc/apt/sources.list
|
||||
# RUN sed -i 's/deb.debian.org/mirrors.tuna.tsinghua.edu.cn/g' /etc/apt/sources.list
|
||||
ENV BUILD_PREFIX=/app
|
||||
|
||||
COPY chatgpt-on-wechat.tar.gz ./chatgpt-on-wechat.tar.gz
|
||||
ADD . ${BUILD_PREFIX}
|
||||
|
||||
RUN apk add --no-cache \
|
||||
bash \
|
||||
&& tar -xf chatgpt-on-wechat.tar.gz \
|
||||
&& mv chatgpt-on-wechat ${BUILD_PREFIX} \
|
||||
RUN apt-get update \
|
||||
&&apt-get install -y --no-install-recommends bash ffmpeg espeak libavcodec-extra\
|
||||
&& cd ${BUILD_PREFIX} \
|
||||
&& cp config-template.json ${BUILD_PREFIX}/config.json \
|
||||
&& cp config-template.json config.json \
|
||||
&& /usr/local/bin/python -m pip install --no-cache --upgrade pip \
|
||||
&& pip install --no-cache \
|
||||
itchat-uos==1.5.0.dev0 \
|
||||
openai
|
||||
&& pip install --no-cache -r requirements.txt \
|
||||
&& pip install --no-cache -r requirements-optional.txt \
|
||||
&& pip install azure-cognitiveservices-speech
|
||||
|
||||
WORKDIR ${BUILD_PREFIX}
|
||||
|
||||
ADD ./entrypoint.sh /entrypoint.sh
|
||||
ADD docker/entrypoint.sh /entrypoint.sh
|
||||
|
||||
RUN chmod +x /entrypoint.sh \
|
||||
&& adduser -D -h /home/noroot -u 1000 -s /bin/bash noroot \
|
||||
&& chown -R noroot:noroot ${BUILD_PREFIX}
|
||||
&& mkdir -p /home/noroot \
|
||||
&& groupadd -r noroot \
|
||||
&& useradd -r -g noroot -s /bin/bash -d /home/noroot noroot \
|
||||
&& chown -R noroot:noroot /home/noroot ${BUILD_PREFIX} /usr/local/lib
|
||||
|
||||
USER noroot
|
||||
|
||||
ENTRYPOINT ["/entrypoint.sh"]
|
||||
ENTRYPOINT ["/entrypoint.sh"]
|
||||
|
||||
@@ -1,16 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
# fetch latest release tag
|
||||
CHATGPT_ON_WECHAT_TAG=`curl -sL "https://api.github.com/repos/zhayujie/chatgpt-on-wechat/releases/latest" | \
|
||||
grep '"tag_name":' | \
|
||||
sed -E 's/.*"([^"]+)".*/\1/'`
|
||||
|
||||
# build image
|
||||
docker build -f Dockerfile.alpine \
|
||||
--build-arg CHATGPT_ON_WECHAT_VER=$CHATGPT_ON_WECHAT_TAG \
|
||||
-t zhayujie/chatgpt-on-wechat .
|
||||
|
||||
# tag image
|
||||
docker tag zhayujie/chatgpt-on-wechat zhayujie/chatgpt-on-wechat:alpine
|
||||
docker tag zhayujie/chatgpt-on-wechat zhayujie/chatgpt-on-wechat:$CHATGPT_ON_WECHAT_TAG-alpine
|
||||
|
||||
@@ -1,15 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
# fetch latest release tag
|
||||
CHATGPT_ON_WECHAT_TAG=`curl -sL "https://api.github.com/repos/zhayujie/chatgpt-on-wechat/releases/latest" | \
|
||||
grep '"tag_name":' | \
|
||||
sed -E 's/.*"([^"]+)".*/\1/'`
|
||||
|
||||
# build image
|
||||
docker build -f Dockerfile.debian \
|
||||
--build-arg CHATGPT_ON_WECHAT_VER=$CHATGPT_ON_WECHAT_TAG \
|
||||
-t zhayujie/chatgpt-on-wechat .
|
||||
|
||||
# tag image
|
||||
docker tag zhayujie/chatgpt-on-wechat zhayujie/chatgpt-on-wechat:debian
|
||||
docker tag zhayujie/chatgpt-on-wechat zhayujie/chatgpt-on-wechat:$CHATGPT_ON_WECHAT_TAG-debian
|
||||
@@ -1,8 +1,8 @@
|
||||
#!/bin/bash
|
||||
|
||||
# move chatgpt-on-wechat
|
||||
tar -zcf chatgpt-on-wechat.tar.gz --exclude=../../chatgpt-on-wechat/docker ../../chatgpt-on-wechat
|
||||
unset KUBECONFIG
|
||||
|
||||
# build image
|
||||
docker build -f Dockerfile.latest \
|
||||
-t zhayujie/chatgpt-on-wechat .
|
||||
cd .. && docker build -f docker/Dockerfile.latest \
|
||||
-t zhayujie/chatgpt-on-wechat .
|
||||
|
||||
docker tag zhayujie/chatgpt-on-wechat zhayujie/chatgpt-on-wechat:$(date +%y%m%d)
|
||||
@@ -1,23 +0,0 @@
|
||||
FROM zhayujie/chatgpt-on-wechat:alpine
|
||||
|
||||
LABEL maintainer="foo@bar.com"
|
||||
ARG TZ='Asia/Shanghai'
|
||||
|
||||
USER root
|
||||
|
||||
RUN apk add --no-cache \
|
||||
ffmpeg \
|
||||
espeak \
|
||||
&& pip install --no-cache \
|
||||
baidu-aip \
|
||||
chardet \
|
||||
SpeechRecognition
|
||||
|
||||
# replace entrypoint
|
||||
ADD ./entrypoint.sh /entrypoint.sh
|
||||
|
||||
RUN chmod +x /entrypoint.sh
|
||||
|
||||
USER noroot
|
||||
|
||||
ENTRYPOINT ["/entrypoint.sh"]
|
||||
@@ -1,24 +0,0 @@
|
||||
FROM zhayujie/chatgpt-on-wechat:debian
|
||||
|
||||
LABEL maintainer="foo@bar.com"
|
||||
ARG TZ='Asia/Shanghai'
|
||||
|
||||
USER root
|
||||
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y --no-install-recommends \
|
||||
ffmpeg \
|
||||
espeak \
|
||||
&& pip install --no-cache \
|
||||
baidu-aip \
|
||||
chardet \
|
||||
SpeechRecognition
|
||||
|
||||
# replace entrypoint
|
||||
ADD ./entrypoint.sh /entrypoint.sh
|
||||
|
||||
RUN chmod +x /entrypoint.sh
|
||||
|
||||
USER noroot
|
||||
|
||||
ENTRYPOINT ["/entrypoint.sh"]
|
||||
@@ -1,24 +0,0 @@
|
||||
version: '2.0'
|
||||
services:
|
||||
chatgpt-on-wechat:
|
||||
build:
|
||||
context: ./
|
||||
dockerfile: Dockerfile.alpine
|
||||
image: zhayujie/chatgpt-on-wechat-voice-reply
|
||||
container_name: chatgpt-on-wechat-voice-reply
|
||||
environment:
|
||||
OPEN_AI_API_KEY: 'YOUR API KEY'
|
||||
OPEN_AI_PROXY: ''
|
||||
SINGLE_CHAT_PREFIX: '["bot", "@bot"]'
|
||||
SINGLE_CHAT_REPLY_PREFIX: '"[bot] "'
|
||||
GROUP_CHAT_PREFIX: '["@bot"]'
|
||||
GROUP_NAME_WHITE_LIST: '["ChatGPT测试群", "ChatGPT测试群2"]'
|
||||
IMAGE_CREATE_PREFIX: '["画", "看", "找"]'
|
||||
CONVERSATION_MAX_TOKENS: 1000
|
||||
SPEECH_RECOGNITION: 'true'
|
||||
CHARACTER_DESC: '你是ChatGPT, 一个由OpenAI训练的大型语言模型, 你旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。'
|
||||
EXPIRES_IN_SECONDS: 3600
|
||||
VOICE_REPLY_VOICE: 'true'
|
||||
BAIDU_APP_ID: 'YOUR BAIDU APP ID'
|
||||
BAIDU_API_KEY: 'YOUR BAIDU API KEY'
|
||||
BAIDU_SECRET_KEY: 'YOUR BAIDU SERVICE KEY'
|
||||
@@ -1,117 +0,0 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
# build prefix
|
||||
CHATGPT_ON_WECHAT_PREFIX=${CHATGPT_ON_WECHAT_PREFIX:-""}
|
||||
# path to config.json
|
||||
CHATGPT_ON_WECHAT_CONFIG_PATH=${CHATGPT_ON_WECHAT_CONFIG_PATH:-""}
|
||||
# execution command line
|
||||
CHATGPT_ON_WECHAT_EXEC=${CHATGPT_ON_WECHAT_EXEC:-""}
|
||||
|
||||
OPEN_AI_API_KEY=${OPEN_AI_API_KEY:-""}
|
||||
OPEN_AI_PROXY=${OPEN_AI_PROXY:-""}
|
||||
SINGLE_CHAT_PREFIX=${SINGLE_CHAT_PREFIX:-""}
|
||||
SINGLE_CHAT_REPLY_PREFIX=${SINGLE_CHAT_REPLY_PREFIX:-""}
|
||||
GROUP_CHAT_PREFIX=${GROUP_CHAT_PREFIX:-""}
|
||||
GROUP_NAME_WHITE_LIST=${GROUP_NAME_WHITE_LIST:-""}
|
||||
IMAGE_CREATE_PREFIX=${IMAGE_CREATE_PREFIX:-""}
|
||||
CONVERSATION_MAX_TOKENS=${CONVERSATION_MAX_TOKENS:-""}
|
||||
SPEECH_RECOGNITION=${SPEECH_RECOGNITION:-""}
|
||||
CHARACTER_DESC=${CHARACTER_DESC:-""}
|
||||
EXPIRES_IN_SECONDS=${EXPIRES_IN_SECONDS:-""}
|
||||
|
||||
VOICE_REPLY_VOICE=${VOICE_REPLY_VOICE:-""}
|
||||
BAIDU_APP_ID=${BAIDU_APP_ID:-""}
|
||||
BAIDU_API_KEY=${BAIDU_API_KEY:-""}
|
||||
BAIDU_SECRET_KEY=${BAIDU_SECRET_KEY:-""}
|
||||
|
||||
# CHATGPT_ON_WECHAT_PREFIX is empty, use /app
|
||||
if [ "$CHATGPT_ON_WECHAT_PREFIX" == "" ] ; then
|
||||
CHATGPT_ON_WECHAT_PREFIX=/app
|
||||
fi
|
||||
|
||||
# CHATGPT_ON_WECHAT_CONFIG_PATH is empty, use '/app/config.json'
|
||||
if [ "$CHATGPT_ON_WECHAT_CONFIG_PATH" == "" ] ; then
|
||||
CHATGPT_ON_WECHAT_CONFIG_PATH=$CHATGPT_ON_WECHAT_PREFIX/config.json
|
||||
fi
|
||||
|
||||
# CHATGPT_ON_WECHAT_EXEC is empty, use ‘python app.py’
|
||||
if [ "$CHATGPT_ON_WECHAT_EXEC" == "" ] ; then
|
||||
CHATGPT_ON_WECHAT_EXEC="python app.py"
|
||||
fi
|
||||
|
||||
# modify content in config.json
|
||||
if [ "$OPEN_AI_API_KEY" != "" ] ; then
|
||||
sed -i "s/\"open_ai_api_key\".*,$/\"open_ai_api_key\": \"$OPEN_AI_API_KEY\",/" $CHATGPT_ON_WECHAT_CONFIG_PATH
|
||||
else
|
||||
echo -e "\033[31m[Warning] You need to set OPEN_AI_API_KEY before running!\033[0m"
|
||||
fi
|
||||
|
||||
# use http_proxy as default
|
||||
if [ "$HTTP_PROXY" != "" ] ; then
|
||||
sed -i "s/\"proxy\".*,$/\"proxy\": \"$HTTP_PROXY\",/" $CHATGPT_ON_WECHAT_CONFIG_PATH
|
||||
fi
|
||||
|
||||
if [ "$OPEN_AI_PROXY" != "" ] ; then
|
||||
sed -i "s/\"proxy\".*,$/\"proxy\": \"$OPEN_AI_PROXY\",/" $CHATGPT_ON_WECHAT_CONFIG_PATH
|
||||
fi
|
||||
|
||||
if [ "$SINGLE_CHAT_PREFIX" != "" ] ; then
|
||||
sed -i "s/\"single_chat_prefix\".*,$/\"single_chat_prefix\": $SINGLE_CHAT_PREFIX,/" $CHATGPT_ON_WECHAT_CONFIG_PATH
|
||||
fi
|
||||
|
||||
if [ "$SINGLE_CHAT_REPLY_PREFIX" != "" ] ; then
|
||||
sed -i "s/\"single_chat_reply_prefix\".*,$/\"single_chat_reply_prefix\": $SINGLE_CHAT_REPLY_PREFIX,/" $CHATGPT_ON_WECHAT_CONFIG_PATH
|
||||
fi
|
||||
|
||||
if [ "$GROUP_CHAT_PREFIX" != "" ] ; then
|
||||
sed -i "s/\"group_chat_prefix\".*,$/\"group_chat_prefix\": $GROUP_CHAT_PREFIX,/" $CHATGPT_ON_WECHAT_CONFIG_PATH
|
||||
fi
|
||||
|
||||
if [ "$GROUP_NAME_WHITE_LIST" != "" ] ; then
|
||||
sed -i "s/\"group_name_white_list\".*,$/\"group_name_white_list\": $GROUP_NAME_WHITE_LIST,/" $CHATGPT_ON_WECHAT_CONFIG_PATH
|
||||
fi
|
||||
|
||||
if [ "$IMAGE_CREATE_PREFIX" != "" ] ; then
|
||||
sed -i "s/\"image_create_prefix\".*,$/\"image_create_prefix\": $IMAGE_CREATE_PREFIX,/" $CHATGPT_ON_WECHAT_CONFIG_PATH
|
||||
fi
|
||||
|
||||
if [ "$CONVERSATION_MAX_TOKENS" != "" ] ; then
|
||||
sed -i "s/\"conversation_max_tokens\".*,$/\"conversation_max_tokens\": $CONVERSATION_MAX_TOKENS,/" $CHATGPT_ON_WECHAT_CONFIG_PATH
|
||||
fi
|
||||
|
||||
if [ "$SPEECH_RECOGNITION" != "" ] ; then
|
||||
sed -i "s/\"speech_recognition\".*,$/\"speech_recognition\": $SPEECH_RECOGNITION,/" $CHATGPT_ON_WECHAT_CONFIG_PATH
|
||||
fi
|
||||
|
||||
if [ "$CHARACTER_DESC" != "" ] ; then
|
||||
sed -i "s/\"character_desc\".*,$/\"character_desc\": \"$CHARACTER_DESC\",/" $CHATGPT_ON_WECHAT_CONFIG_PATH
|
||||
fi
|
||||
|
||||
if [ "$EXPIRES_IN_SECONDS" != "" ] ; then
|
||||
sed -i "s/\"expires_in_seconds\".*$/\"expires_in_seconds\": $EXPIRES_IN_SECONDS/" $CHATGPT_ON_WECHAT_CONFIG_PATH
|
||||
fi
|
||||
|
||||
# append
|
||||
if [ "$BAIDU_SECRET_KEY" != "" ] ; then
|
||||
sed -i "1a \ \ \"baidu_secret_key\": \"$BAIDU_SECRET_KEY\"," $CHATGPT_ON_WECHAT_CONFIG_PATH
|
||||
fi
|
||||
|
||||
if [ "$BAIDU_API_KEY" != "" ] ; then
|
||||
sed -i "1a \ \ \"baidu_api_key\": \"$BAIDU_API_KEY\"," $CHATGPT_ON_WECHAT_CONFIG_PATH
|
||||
fi
|
||||
|
||||
if [ "$BAIDU_APP_ID" != "" ] ; then
|
||||
sed -i "1a \ \ \"baidu_app_id\": \"$BAIDU_APP_ID\"," $CHATGPT_ON_WECHAT_CONFIG_PATH
|
||||
fi
|
||||
|
||||
if [ "$VOICE_REPLY_VOICE" != "" ] ; then
|
||||
sed -i "1a \ \ \"voice_reply_voice\": $VOICE_REPLY_VOICE," $CHATGPT_ON_WECHAT_CONFIG_PATH
|
||||
fi
|
||||
|
||||
# go to prefix dir
|
||||
cd $CHATGPT_ON_WECHAT_PREFIX
|
||||
# excute
|
||||
$CHATGPT_ON_WECHAT_EXEC
|
||||
|
||||
|
||||
@@ -1,20 +1,25 @@
|
||||
version: '2.0'
|
||||
services:
|
||||
chatgpt-on-wechat:
|
||||
build:
|
||||
context: ./
|
||||
dockerfile: Dockerfile.alpine
|
||||
image: zhayujie/chatgpt-on-wechat
|
||||
container_name: sample-chatgpt-on-wechat
|
||||
container_name: chatgpt-on-wechat
|
||||
security_opt:
|
||||
- seccomp:unconfined
|
||||
environment:
|
||||
TZ: 'Asia/Shanghai'
|
||||
OPEN_AI_API_KEY: 'YOUR API KEY'
|
||||
OPEN_AI_PROXY: ''
|
||||
MODEL: 'gpt-3.5-turbo'
|
||||
PROXY: ''
|
||||
SINGLE_CHAT_PREFIX: '["bot", "@bot"]'
|
||||
SINGLE_CHAT_REPLY_PREFIX: '"[bot] "'
|
||||
GROUP_CHAT_PREFIX: '["@bot"]'
|
||||
GROUP_NAME_WHITE_LIST: '["ChatGPT测试群", "ChatGPT测试群2"]'
|
||||
IMAGE_CREATE_PREFIX: '["画", "看", "找"]'
|
||||
CONVERSATION_MAX_TOKENS: 1000
|
||||
SPEECH_RECOGNITION: "False"
|
||||
SPEECH_RECOGNITION: 'False'
|
||||
CHARACTER_DESC: '你是ChatGPT, 一个由OpenAI训练的大型语言模型, 你旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。'
|
||||
EXPIRES_IN_SECONDS: 3600
|
||||
EXPIRES_IN_SECONDS: 3600
|
||||
USE_GLOBAL_PLUGIN_CONFIG: 'True'
|
||||
USE_LINKAI: 'False'
|
||||
LINKAI_API_KEY: ''
|
||||
LINKAI_APP_CODE: ''
|
||||
@@ -10,17 +10,17 @@ CHATGPT_ON_WECHAT_EXEC=${CHATGPT_ON_WECHAT_EXEC:-""}
|
||||
|
||||
# use environment variables to pass parameters
|
||||
# if you have not defined environment variables, set them below
|
||||
export OPEN_AI_API_KEY=${OPEN_AI_API_KEY:-'YOUR API KEY'}
|
||||
export OPEN_AI_PROXY=${OPEN_AI_PROXY:-""}
|
||||
export SINGLE_CHAT_PREFIX=${SINGLE_CHAT_PREFIX:-'["bot", "@bot"]'}
|
||||
export SINGLE_CHAT_REPLY_PREFIX=${SINGLE_CHAT_REPLY_PREFIX:-'"[bot] "'}
|
||||
export GROUP_CHAT_PREFIX=${GROUP_CHAT_PREFIX:-'["@bot"]'}
|
||||
export GROUP_NAME_WHITE_LIST=${GROUP_NAME_WHITE_LIST:-'["ChatGPT测试群", "ChatGPT测试群2"]'}
|
||||
export IMAGE_CREATE_PREFIX=${IMAGE_CREATE_PREFIX:-'["画", "看", "找"]'}
|
||||
export CONVERSATION_MAX_TOKENS=${CONVERSATION_MAX_TOKENS:-"1000"}
|
||||
export SPEECH_RECOGNITION=${SPEECH_RECOGNITION:-"False"}
|
||||
export CHARACTER_DESC=${CHARACTER_DESC:-"你是ChatGPT, 一个由OpenAI训练的大型语言模型, 你旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。"}
|
||||
export EXPIRES_IN_SECONDS=${EXPIRES_IN_SECONDS:-"3600"}
|
||||
# export OPEN_AI_API_KEY=${OPEN_AI_API_KEY:-'YOUR API KEY'}
|
||||
# export OPEN_AI_PROXY=${OPEN_AI_PROXY:-""}
|
||||
# export SINGLE_CHAT_PREFIX=${SINGLE_CHAT_PREFIX:-'["bot", "@bot"]'}
|
||||
# export SINGLE_CHAT_REPLY_PREFIX=${SINGLE_CHAT_REPLY_PREFIX:-'"[bot] "'}
|
||||
# export GROUP_CHAT_PREFIX=${GROUP_CHAT_PREFIX:-'["@bot"]'}
|
||||
# export GROUP_NAME_WHITE_LIST=${GROUP_NAME_WHITE_LIST:-'["ChatGPT测试群", "ChatGPT测试群2"]'}
|
||||
# export IMAGE_CREATE_PREFIX=${IMAGE_CREATE_PREFIX:-'["画", "看", "找"]'}
|
||||
# export CONVERSATION_MAX_TOKENS=${CONVERSATION_MAX_TOKENS:-"1000"}
|
||||
# export SPEECH_RECOGNITION=${SPEECH_RECOGNITION:-"False"}
|
||||
# export CHARACTER_DESC=${CHARACTER_DESC:-"你是ChatGPT, 一个由OpenAI训练的大型语言模型, 你旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。"}
|
||||
# export EXPIRES_IN_SECONDS=${EXPIRES_IN_SECONDS:-"3600"}
|
||||
|
||||
# CHATGPT_ON_WECHAT_PREFIX is empty, use /app
|
||||
if [ "$CHATGPT_ON_WECHAT_PREFIX" == "" ] ; then
|
||||
@@ -38,9 +38,9 @@ if [ "$CHATGPT_ON_WECHAT_EXEC" == "" ] ; then
|
||||
fi
|
||||
|
||||
# modify content in config.json
|
||||
if [ "$OPEN_AI_API_KEY" == "YOUR API KEY" ] ; then
|
||||
echo -e "\033[31m[Warning] You need to set OPEN_AI_API_KEY before running!\033[0m"
|
||||
fi
|
||||
# if [ "$OPEN_AI_API_KEY" == "YOUR API KEY" ] || [ "$OPEN_AI_API_KEY" == "" ]; then
|
||||
# echo -e "\033[31m[Warning] You need to set OPEN_AI_API_KEY before running!\033[0m"
|
||||
# fi
|
||||
|
||||
|
||||
# go to prefix dir
|
||||
|
||||
@@ -1,16 +0,0 @@
|
||||
OPEN_AI_API_KEY=YOUR API KEY
|
||||
OPEN_AI_PROXY=
|
||||
SINGLE_CHAT_PREFIX=["bot", "@bot"]
|
||||
SINGLE_CHAT_REPLY_PREFIX="[bot] "
|
||||
GROUP_CHAT_PREFIX=["@bot"]
|
||||
GROUP_NAME_WHITE_LIST=["ChatGPT测试群", "ChatGPT测试群2"]
|
||||
IMAGE_CREATE_PREFIX=["画", "看", "找"]
|
||||
CONVERSATION_MAX_TOKENS=1000
|
||||
SPEECH_RECOGNITION=false
|
||||
CHARACTER_DESC=你是ChatGPT, 一个由OpenAI训练的大型语言模型, 你旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。
|
||||
EXPIRES_IN_SECONDS=3600
|
||||
|
||||
# Optional
|
||||
#CHATGPT_ON_WECHAT_PREFIX=/app
|
||||
#CHATGPT_ON_WECHAT_CONFIG_PATH=/app/config.json
|
||||
#CHATGPT_ON_WECHAT_EXEC=python app.py
|
||||
@@ -1,26 +0,0 @@
|
||||
IMG:=`cat Name`
|
||||
MOUNT:=
|
||||
PORT_MAP:=
|
||||
DOTENV:=.env
|
||||
CONTAINER_NAME:=sample-chatgpt-on-wechat
|
||||
|
||||
echo:
|
||||
echo $(IMG)
|
||||
|
||||
run_d:
|
||||
docker rm $(CONTAINER_NAME) || echo
|
||||
docker run -dt --name $(CONTAINER_NAME) $(PORT_MAP) \
|
||||
--env-file=$(DOTENV) \
|
||||
$(MOUNT) $(IMG)
|
||||
|
||||
run_i:
|
||||
docker rm $(CONTAINER_NAME) || echo
|
||||
docker run -it --name $(CONTAINER_NAME) $(PORT_MAP) \
|
||||
--env-file=$(DOTENV) \
|
||||
$(MOUNT) $(IMG)
|
||||
|
||||
stop:
|
||||
docker stop $(CONTAINER_NAME)
|
||||
|
||||
rm: stop
|
||||
docker rm $(CONTAINER_NAME)
|
||||
@@ -1 +0,0 @@
|
||||
zhayujie/chatgpt-on-wechat
|
||||
BIN
docs/images/contact.jpg
Normal file
BIN
docs/images/contact.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 151 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 326 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 382 KiB |
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user