mirror of
https://github.com/zhayujie/chatgpt-on-wechat.git
synced 2026-06-03 02:27:09 +08:00
Compare commits
620 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
eae95dfef5 | ||
|
|
11f6e98874 | ||
|
|
2609e595f4 | ||
|
|
ac6e41abc8 | ||
|
|
9c17e16d0a | ||
|
|
91cabd7d49 | ||
|
|
7456950530 | ||
|
|
8fcdda625d | ||
|
|
40a10ee926 | ||
|
|
c3f7e2645c | ||
|
|
b264af1892 | ||
|
|
43e93e8e22 | ||
|
|
d6c4789688 | ||
|
|
cb31ee6f01 | ||
|
|
f7b694ac56 | ||
|
|
eb809055d4 | ||
|
|
78d9be82b2 | ||
|
|
76a95c0226 | ||
|
|
d3ab8fb04a | ||
|
|
f7a0b63a00 | ||
|
|
a21dd97786 | ||
|
|
04943c0bfa | ||
|
|
203d4d8bfb | ||
|
|
c049a619dc | ||
|
|
cc1b14b607 | ||
|
|
e04a12a8f4 | ||
|
|
a2c82bc583 | ||
|
|
b4dc382f7c | ||
|
|
eca1892e2a | ||
|
|
23a237074e | ||
|
|
219e9eca4f | ||
|
|
413e09fb9e | ||
|
|
3514c37e4c | ||
|
|
95260e303c | ||
|
|
0cef34bdfa | ||
|
|
9838979bbd | ||
|
|
c8910b8e14 | ||
|
|
207fa1d019 | ||
|
|
be0bb591e7 | ||
|
|
bfacdb9c3b | ||
|
|
ae4077ed6c | ||
|
|
6eb3c90e18 | ||
|
|
8c2a53a504 | ||
|
|
74db1e0308 | ||
|
|
b9dfdcef3d | ||
|
|
9d4afeac31 | ||
|
|
14ae2f169a | ||
|
|
55df19142f | ||
|
|
40fd545b2c | ||
|
|
95fb07343e | ||
|
|
4d87906559 | ||
|
|
6b30dced43 | ||
|
|
293a03b7c8 | ||
|
|
c010549f17 | ||
|
|
cc0be22026 | ||
|
|
e5ba26febe | ||
|
|
36f9680eec | ||
|
|
f4f5be5b08 | ||
|
|
d89b056886 | ||
|
|
65424c7db9 | ||
|
|
32a8a847fc | ||
|
|
88fb3dbf60 | ||
|
|
f6bee3aa58 | ||
|
|
5f19f37dcb | ||
|
|
dd36d8ce9e | ||
|
|
865e4b5349 | ||
|
|
e70564752b | ||
|
|
6e0d2f9437 | ||
|
|
291f936097 | ||
|
|
0b2ce48586 | ||
|
|
da87fd9e20 | ||
|
|
d4da4d2575 | ||
|
|
bad20ff483 | ||
|
|
21ad51ffbf | ||
|
|
697c6d5fbe | ||
|
|
293c659053 | ||
|
|
a12507abbd | ||
|
|
4e675b84fb | ||
|
|
c1022feab8 | ||
|
|
ddcfcf21fe | ||
|
|
86a58c3d80 | ||
|
|
abf9a9048d | ||
|
|
b1030a527a | ||
|
|
8d07ba6332 | ||
|
|
4ce37f84e4 | ||
|
|
061d8a3a5f | ||
|
|
374cd5dbb8 | ||
|
|
5ad53c2b9c | ||
|
|
a2ec1a063d | ||
|
|
e431dbe2df | ||
|
|
7218463f9e | ||
|
|
aeb09a95b0 | ||
|
|
0c8f292e12 | ||
|
|
f001ac6903 | ||
|
|
db8e506de0 | ||
|
|
099f859dd4 | ||
|
|
b7684c1c2b | ||
|
|
058c167f79 | ||
|
|
49446d4872 | ||
|
|
ced560e1e1 | ||
|
|
339102c3cd | ||
|
|
6331350239 | ||
|
|
34e06fcbf8 | ||
|
|
70aac312ff | ||
|
|
5e00704152 | ||
|
|
1a9edb6907 | ||
|
|
0c18c3a6dd | ||
|
|
847bb51ce4 | ||
|
|
fa60a5dc63 | ||
|
|
aaed3f9839 | ||
|
|
21b956b983 | ||
|
|
792e940279 | ||
|
|
c2477b26c0 | ||
|
|
4b27de809b | ||
|
|
572932d8e8 | ||
|
|
270dd778d9 | ||
|
|
dd04287b0a | ||
|
|
36ac6d005a | ||
|
|
701daedf49 | ||
|
|
238f05f453 | ||
|
|
dd082bd212 | ||
|
|
cfd2f27b0b | ||
|
|
a2160d135e | ||
|
|
16d7836369 | ||
|
|
f3de4dcc5f | ||
|
|
e34523028f | ||
|
|
efe2fbacd6 | ||
|
|
2fa1df29be | ||
|
|
f72cd13fba | ||
|
|
5b552dffbf | ||
|
|
a0ae2d13dc | ||
|
|
f7262a0a3a | ||
|
|
9736f121eb | ||
|
|
7c8fb7eacc | ||
|
|
b45eea5908 | ||
|
|
6babf4ee6c | ||
|
|
576526d4ee | ||
|
|
c03e31b7be | ||
|
|
a1aa925019 | ||
|
|
a5a234ed97 | ||
|
|
5b5dbcd78b | ||
|
|
bd1c6361d3 | ||
|
|
1fc1febf03 | ||
|
|
55cc35efa9 | ||
|
|
5ba8fdc5e7 | ||
|
|
6ea295e227 | ||
|
|
5010c76ef7 | ||
|
|
79c7f0c29f | ||
|
|
2b3e643786 | ||
|
|
90cdff327c | ||
|
|
55c116e727 | ||
|
|
3dd83aa6b7 | ||
|
|
a74aa12641 | ||
|
|
151e8c69f9 | ||
|
|
d8bfa77705 | ||
|
|
6bd286e8d5 | ||
|
|
905532b681 | ||
|
|
04d5c1ab01 | ||
|
|
28be141dc7 | ||
|
|
652b786baf | ||
|
|
ba6c671051 | ||
|
|
ca25d0433f | ||
|
|
5338106dfa | ||
|
|
b6b76be4f6 | ||
|
|
03d94fcfa0 | ||
|
|
b2c5f0d455 | ||
|
|
54f60dd38c | ||
|
|
42f181aca2 | ||
|
|
9c3a27894f | ||
|
|
f7cd348912 | ||
|
|
aeaeb75d3b | ||
|
|
96542b532e | ||
|
|
139295fe0d | ||
|
|
13217b2ce2 | ||
|
|
5cc8b56a7c | ||
|
|
e23e01c95e | ||
|
|
bca8ba12c7 | ||
|
|
3c44bdbe1c | ||
|
|
db93ed025b | ||
|
|
4209e108d0 | ||
|
|
14cbf011af | ||
|
|
03a41ec199 | ||
|
|
125fe2a026 | ||
|
|
ac4adac29e | ||
|
|
ac449d078e | ||
|
|
79be4530d4 | ||
|
|
85ce52d70c | ||
|
|
7ab56b9076 | ||
|
|
dedf976375 | ||
|
|
89f438208a | ||
|
|
ffbc5080ae | ||
|
|
4167f13bac | ||
|
|
6ba0baabb0 | ||
|
|
081003df47 | ||
|
|
559194ffb2 | ||
|
|
97a26d4a46 | ||
|
|
503c6c9b7e | ||
|
|
9a1e10deff | ||
|
|
054f927c05 | ||
|
|
22210747d0 | ||
|
|
53b2deb72c | ||
|
|
6fc158e7d6 | ||
|
|
a23a65c731 | ||
|
|
7dc7105ee2 | ||
|
|
bac70108b2 | ||
|
|
297404b21e | ||
|
|
33a7f8b558 | ||
|
|
4a670b7df7 | ||
|
|
79e4af315e | ||
|
|
c6e31b2fdc | ||
|
|
91dc44df53 | ||
|
|
7e57f8f157 | ||
|
|
15f6b7c6d3 | ||
|
|
b213ba541d | ||
|
|
7c6ed9944e | ||
|
|
a5a825e439 | ||
|
|
a4ab547f77 | ||
|
|
76ed763abe | ||
|
|
b9e3125610 | ||
|
|
8d9d5b7b6f | ||
|
|
187601da1e | ||
|
|
cc3a0fc367 | ||
|
|
44cc4165d1 | ||
|
|
f98b43514e | ||
|
|
3c9b1a14e9 | ||
|
|
827e8eddf8 | ||
|
|
7bc27d6167 | ||
|
|
ba06edd63a | ||
|
|
cacf553a5b | ||
|
|
d89091a8ea | ||
|
|
01a56e1155 | ||
|
|
a64d7c42b1 | ||
|
|
36b6cc58bf | ||
|
|
5ac8a257e7 | ||
|
|
74119d0372 | ||
|
|
4e162c73e5 | ||
|
|
5ff753a492 | ||
|
|
89400630c0 | ||
|
|
3899c0cfe3 | ||
|
|
a086f1989f | ||
|
|
1171b04e93 | ||
|
|
c55d81825a | ||
|
|
2dcd026e9f | ||
|
|
cdf8609d24 | ||
|
|
36580c5f7f | ||
|
|
1cff2521f4 | ||
|
|
db4998a56b | ||
|
|
acbd506568 | ||
|
|
0cf8e3be73 | ||
|
|
2473334dfc | ||
|
|
1ff72d1d37 | ||
|
|
241fad5524 | ||
|
|
1b48cea50a | ||
|
|
88bf345b91 | ||
|
|
ab4ff3d1a3 | ||
|
|
3502e0d643 | ||
|
|
995894d3aa | ||
|
|
4da8714124 | ||
|
|
6b247ae880 | ||
|
|
176941ea3b | ||
|
|
5176b56d3b | ||
|
|
8abf18ab25 | ||
|
|
395edbd9f4 | ||
|
|
2386eb8fc2 | ||
|
|
68208f82a0 | ||
|
|
ca916b7ce5 | ||
|
|
01e02934da | ||
|
|
c81a79f7b9 | ||
|
|
1133648bf6 | ||
|
|
e05bc541d7 | ||
|
|
d689d20482 | ||
|
|
39dd99b272 | ||
|
|
cda21acb43 | ||
|
|
9bd7d09f20 | ||
|
|
b22994c2d2 | ||
|
|
e027286b6d | ||
|
|
d6e16995e0 | ||
|
|
782bff3a51 | ||
|
|
de26dc0597 | ||
|
|
233b24ab0f | ||
|
|
2f9e5b1219 | ||
|
|
dd36b8b150 | ||
|
|
f81ac31fe1 | ||
|
|
24b63bc5bd | ||
|
|
1817a972c6 | ||
|
|
74a253f521 | ||
|
|
41762a1c57 | ||
|
|
a786fa4b75 | ||
|
|
e4c7602c0c | ||
|
|
e0d2e34980 | ||
|
|
9ef8e1be3f | ||
|
|
aae9b64833 | ||
|
|
4bab4299f2 | ||
|
|
954e55f4b4 | ||
|
|
2361e3c28c | ||
|
|
8224c2fc16 | ||
|
|
8aac86f0a9 | ||
|
|
6384e9310b | ||
|
|
7a9205dfba | ||
|
|
94b47a56f4 | ||
|
|
709b5be634 | ||
|
|
f970b2c168 | ||
|
|
973acb37ed | ||
|
|
1c9020a565 | ||
|
|
c5f1d0042c | ||
|
|
fa706e8b1d | ||
|
|
12c170f227 | ||
|
|
db27dfe227 | ||
|
|
2db4673392 | ||
|
|
38619db629 | ||
|
|
930fd436ea | ||
|
|
98b8ff2fc8 | ||
|
|
d0662683f9 | ||
|
|
957f2574a9 | ||
|
|
109b362ebd | ||
|
|
ff3fdfa738 | ||
|
|
e2636ed54a | ||
|
|
dbe2f17e1a | ||
|
|
4dc535673f | ||
|
|
f414b6408e | ||
|
|
3aa2e6a04d | ||
|
|
1963ff273f | ||
|
|
bb737a71d5 | ||
|
|
a582a46ce9 | ||
|
|
abf80a3266 | ||
|
|
d768f5c66d | ||
|
|
b25e843351 | ||
|
|
419a3e518e | ||
|
|
d1b867a7c0 | ||
|
|
c34d70b3cb | ||
|
|
a33df9312f | ||
|
|
ebf8db0b37 | ||
|
|
e539ae3b69 | ||
|
|
4c5e8850aa | ||
|
|
94c0af3037 | ||
|
|
165182c68f | ||
|
|
65b9542599 | ||
|
|
d01d1f8830 | ||
|
|
ad3e9f3d42 | ||
|
|
4589974095 | ||
|
|
ed4553ddf8 | ||
|
|
ff97ae73f1 | ||
|
|
f96b4d2781 | ||
|
|
ce32cfffdb | ||
|
|
f66df8531e | ||
|
|
dfe1c23e76 | ||
|
|
07fd81919f | ||
|
|
210042bb81 | ||
|
|
12dc7427e9 | ||
|
|
b476085110 | ||
|
|
776cdaf63c | ||
|
|
69b6855745 | ||
|
|
3590babd8b | ||
|
|
c29d391c1d | ||
|
|
50e44dbb2a | ||
|
|
34277a3940 | ||
|
|
f1a00d58ca | ||
|
|
d1a5f17ae8 | ||
|
|
4dbc54fa15 | ||
|
|
1d4ff796d7 | ||
|
|
44cb54a9ea | ||
|
|
6409f49609 | ||
|
|
9ee0ea88b5 | ||
|
|
a3819d8673 | ||
|
|
2d7dd71a3d | ||
|
|
0e8195ae61 | ||
|
|
3e92d07618 | ||
|
|
e59597280d | ||
|
|
f2e3d69d8a | ||
|
|
9d2cb75c84 | ||
|
|
f971505c4a | ||
|
|
2133c1d6af | ||
|
|
0bf06ddfd3 | ||
|
|
024a50d642 | ||
|
|
e4eebd64d1 | ||
|
|
c9055989e9 | ||
|
|
4f1ed197ce | ||
|
|
3e710aa2a1 | ||
|
|
b6226a45bb | ||
|
|
3001ba9266 | ||
|
|
b0a401a1ed | ||
|
|
6b4dc37428 | ||
|
|
8528c9b262 | ||
|
|
7222a5c2f4 | ||
|
|
59050001ef | ||
|
|
2ba8f18724 | ||
|
|
fb22e01b89 | ||
|
|
76a81d5360 | ||
|
|
3314b05648 | ||
|
|
45b89218de | ||
|
|
beb7bda243 | ||
|
|
bef2896f50 | ||
|
|
9fea949b25 | ||
|
|
be258e5b05 | ||
|
|
008178d737 | ||
|
|
527d5e1dbc | ||
|
|
9b47e2d6f9 | ||
|
|
8781b1e976 | ||
|
|
38c653d8d8 | ||
|
|
74e48bb137 | ||
|
|
c3aaa1f735 | ||
|
|
bead2aa228 | ||
|
|
dc52ab8aa9 | ||
|
|
20b71f206b | ||
|
|
73c87d5959 | ||
|
|
c6601aaeed | ||
|
|
6e14fce1fe | ||
|
|
be5a62f1b8 | ||
|
|
1fa8cefaea | ||
|
|
d7c251ac83 | ||
|
|
d03229a183 | ||
|
|
243482e829 | ||
|
|
79d10be8a0 | ||
|
|
dca5c058e0 | ||
|
|
9163ce71fd | ||
|
|
2ec5374765 | ||
|
|
d6a4b35cd3 | ||
|
|
8205d2552c | ||
|
|
9a99caeb9d | ||
|
|
1e09bd0e76 | ||
|
|
cae12eb187 | ||
|
|
8bb36e0eb6 | ||
|
|
d183204caa | ||
|
|
4a22ae6b61 | ||
|
|
a52f54d988 | ||
|
|
618c94edb8 | ||
|
|
eaf4e9174f | ||
|
|
4af2c7f3d7 | ||
|
|
361f599df0 | ||
|
|
ffe4ea5e4c | ||
|
|
9461e3e01a | ||
|
|
7c85c6f742 | ||
|
|
b5df6faadf | ||
|
|
7cefe2d825 | ||
|
|
350633b69b | ||
|
|
1cd6a71ce0 | ||
|
|
3a08b002a0 | ||
|
|
665001732b | ||
|
|
cca49da730 | ||
|
|
f6d370ad29 | ||
|
|
c9131b333b | ||
|
|
e44161bf42 | ||
|
|
a26189fb25 | ||
|
|
89dd8a1db6 | ||
|
|
650e0b4ad4 | ||
|
|
c60f0517fb | ||
|
|
0f8dc91a8b | ||
|
|
b58feb5d8e | ||
|
|
71c8043699 | ||
|
|
40264bc9cb | ||
|
|
a7772316f9 | ||
|
|
34209021c8 | ||
|
|
3e9e8d442a | ||
|
|
d2bf90c6c7 | ||
|
|
1e58c1ad2b | ||
|
|
8cea022ec5 | ||
|
|
f32f8aa08e | ||
|
|
3ea8781381 | ||
|
|
ab83dacb76 | ||
|
|
4cbf46fd4d | ||
|
|
0a7d6e4577 | ||
|
|
df4c1f0401 | ||
|
|
9a86a67984 | ||
|
|
a0cbe9c3e2 | ||
|
|
a83e5a9b65 | ||
|
|
de33911460 | ||
|
|
0be56e5b25 | ||
|
|
abcbb34b1c | ||
|
|
6a13dd04a3 | ||
|
|
f2e29f3f2e | ||
|
|
68361cddd2 | ||
|
|
6404332adc | ||
|
|
e060b6fea2 | ||
|
|
e8aae27ee9 | ||
|
|
2f732e5493 | ||
|
|
65f20ff2c1 | ||
|
|
8f72e8c3e6 | ||
|
|
3b8972ce1f | ||
|
|
fc5d3e4e9c | ||
|
|
29fbf69945 | ||
|
|
583440b82b | ||
|
|
720de9d73f | ||
|
|
78332d882b | ||
|
|
2dfbc840b3 | ||
|
|
0b4bf15163 | ||
|
|
2989249e4b | ||
|
|
9cef559a05 | ||
|
|
47fe16c92a | ||
|
|
36b5c821ff | ||
|
|
82ec440b45 | ||
|
|
88f4a45cae | ||
|
|
7fb4f72b84 | ||
|
|
d4fc322101 | ||
|
|
8fa3da9ca5 | ||
|
|
68ef5aa3ae | ||
|
|
28bd917c9f | ||
|
|
0eb1b94300 | ||
|
|
15e6cf850b | ||
|
|
ee91c86a29 | ||
|
|
48c08f4aad | ||
|
|
fceabb8e67 | ||
|
|
fcfafb05f1 | ||
|
|
f1e8344beb | ||
|
|
f687b2b6f4 | ||
|
|
8ee7a48151 | ||
|
|
89e8f385b4 | ||
|
|
bf4ae9a051 | ||
|
|
6bd1242d43 | ||
|
|
8779eab36b | ||
|
|
3174b1158c | ||
|
|
18740093d1 | ||
|
|
8c7d1d4010 | ||
|
|
8c48a27e1a | ||
|
|
4278d2b8ef | ||
|
|
3a3affd3ec | ||
|
|
45d72b8b9b | ||
|
|
03b908c079 | ||
|
|
d35d01f980 | ||
|
|
9c208ffa2c | ||
|
|
bea4416f12 | ||
|
|
2ea8b4ef73 | ||
|
|
e6946ef989 | ||
|
|
9aeb60f66d | ||
|
|
d687f9329e | ||
|
|
3207258fd9 | ||
|
|
d8b75206fe | ||
|
|
88e8dd5162 | ||
|
|
c9306633b2 | ||
|
|
c50d1cc99d | ||
|
|
9a20c1cb02 | ||
|
|
176f77ba5b | ||
|
|
484de6237b | ||
|
|
898aa30b1d | ||
|
|
8b73a74609 | ||
|
|
3c6d42b22e | ||
|
|
40563c1e96 | ||
|
|
cb0c86ec1c | ||
|
|
614f3b1ea4 | ||
|
|
938e3b5cf2 | ||
|
|
5fe8d9a855 | ||
|
|
8193ecf5f6 | ||
|
|
1dff630257 | ||
|
|
eaac3e3579 | ||
|
|
d3758968d0 | ||
|
|
020f9a8d98 | ||
|
|
9d8ae80548 | ||
|
|
7e7484a27d | ||
|
|
0adf8d6e5d | ||
|
|
1a981ea970 | ||
|
|
5bd9f50818 | ||
|
|
44f6892cb7 | ||
|
|
fdf6b0dc6b | ||
|
|
a7914279a9 | ||
|
|
2cf71dd6f2 | ||
|
|
62e3baba20 | ||
|
|
e00c99c1d7 | ||
|
|
31d5b95611 | ||
|
|
cc881adda6 | ||
|
|
78d4c58b70 | ||
|
|
eca369532d | ||
|
|
9520d94b13 | ||
|
|
f973bc3fe2 | ||
|
|
94004b095b | ||
|
|
f652d592bd | ||
|
|
186e18fe94 | ||
|
|
28eb67bc24 | ||
|
|
6c7e4aaf37 | ||
|
|
709a1317ef | ||
|
|
371e38cfa6 | ||
|
|
5a221848e9 | ||
|
|
6901c5ba56 | ||
|
|
21a3b0d9a1 | ||
|
|
29422edcc9 | ||
|
|
2da1c18b71 | ||
|
|
be592cc290 | ||
|
|
ce8635dd99 | ||
|
|
76783f0ad3 | ||
|
|
441228e200 | ||
|
|
45a131aa0d | ||
|
|
a7900d4b2c | ||
|
|
a4b1d7446a | ||
|
|
7458a6298f | ||
|
|
b0f54bb8b7 | ||
|
|
acddadc406 | ||
|
|
761fb20dd9 | ||
|
|
b74274b96b | ||
|
|
7835379f8f | ||
|
|
49ba278316 | ||
|
|
388058467c | ||
|
|
cf25bd7869 | ||
|
|
02a95345aa | ||
|
|
6076e2ed0a | ||
|
|
cec674cb47 | ||
|
|
c5a90823fa | ||
|
|
18d82bc1f0 | ||
|
|
a68af990ea | ||
|
|
e71c600d10 | ||
|
|
d7f1f7182c | ||
|
|
dfb2e460b4 | ||
|
|
5badef8ba9 | ||
|
|
18aa5ce75c | ||
|
|
1545a9f262 | ||
|
|
47cc65a787 | ||
|
|
cda9d5873d | ||
|
|
02cd553990 | ||
|
|
71d288f550 | ||
|
|
87df588c80 | ||
|
|
4ad2997717 | ||
|
|
50a03e7c15 | ||
|
|
4f3d12129c | ||
|
|
37a95980d4 | ||
|
|
f49806558e | ||
|
|
8da362d6fe | ||
|
|
bf02a59aec | ||
|
|
461777cad3 | ||
|
|
0597ba20d2 | ||
|
|
0b5fd27cd8 | ||
|
|
f5f8033d4d | ||
|
|
a5f7dec011 | ||
|
|
d9ef5a6612 |
13
.flake8
Normal file
13
.flake8
Normal file
@@ -0,0 +1,13 @@
|
||||
[flake8]
|
||||
max-line-length = 176
|
||||
select = E303,W293,W291,W292,E305,E231,E302
|
||||
exclude =
|
||||
.tox,
|
||||
__pycache__,
|
||||
*.pyc,
|
||||
.env
|
||||
venv/*
|
||||
.venv/*
|
||||
reports/*
|
||||
dist/*
|
||||
lib/*
|
||||
30
.github/ISSUE_TEMPLATE.md
vendored
30
.github/ISSUE_TEMPLATE.md
vendored
@@ -1,30 +0,0 @@
|
||||
### 前置确认
|
||||
|
||||
1. 网络能够访问openai接口
|
||||
2. python 已安装:版本在 3.7 ~ 3.10 之间
|
||||
3. `git pull` 拉取最新代码
|
||||
4. 执行`pip3 install -r requirements.txt`,检查依赖是否满足
|
||||
5. 在已有 issue 中未搜索到类似问题
|
||||
6. [FAQS](https://github.com/zhayujie/chatgpt-on-wechat/wiki/FAQs) 中无类似问题
|
||||
|
||||
|
||||
### 问题描述
|
||||
|
||||
> 简要说明、截图、复现步骤等,也可以是需求或想法
|
||||
|
||||
|
||||
|
||||
|
||||
### 终端日志 (如有报错)
|
||||
|
||||
```
|
||||
[在此处粘贴终端日志]
|
||||
```
|
||||
|
||||
|
||||
|
||||
### 环境
|
||||
|
||||
- 操作系统类型 (Mac/Windows/Linux):
|
||||
- Python版本 ( 执行 `python3 -V` ):
|
||||
- pip版本 ( 依赖问题此项必填,执行 `pip3 -V`):
|
||||
133
.github/ISSUE_TEMPLATE/1.bug.yml
vendored
Normal file
133
.github/ISSUE_TEMPLATE/1.bug.yml
vendored
Normal file
@@ -0,0 +1,133 @@
|
||||
name: Bug report 🐛
|
||||
description: 项目运行中遇到的Bug或问题。
|
||||
labels: ['status: needs check']
|
||||
body:
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
### ⚠️ 前置确认
|
||||
1. 网络能够访问openai接口
|
||||
2. python 已安装:版本在 3.7 ~ 3.10 之间
|
||||
3. `git pull` 拉取最新代码
|
||||
4. 执行`pip3 install -r requirements.txt`,检查依赖是否满足
|
||||
5. 拓展功能请执行`pip3 install -r requirements-optional.txt`,检查依赖是否满足
|
||||
6. [FAQS](https://github.com/zhayujie/chatgpt-on-wechat/wiki/FAQs) 中无类似问题
|
||||
- type: checkboxes
|
||||
attributes:
|
||||
label: 前置确认
|
||||
options:
|
||||
- label: 我确认我运行的是最新版本的代码,并且安装了所需的依赖,在[FAQS](https://github.com/zhayujie/chatgpt-on-wechat/wiki/FAQs)中也未找到类似问题。
|
||||
required: true
|
||||
- type: checkboxes
|
||||
attributes:
|
||||
label: ⚠️ 搜索issues中是否已存在类似问题
|
||||
description: >
|
||||
请在 [历史issue](https://github.com/zhayujie/chatgpt-on-wechat/issues) 中清空输入框,搜索你的问题
|
||||
或相关日志的关键词来查找是否存在类似问题。
|
||||
options:
|
||||
- label: 我已经搜索过issues和disscussions,没有跟我遇到的问题相关的issue
|
||||
required: true
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
请在上方的`title`中填写你对你所遇到问题的简略总结,这将帮助其他人更好的找到相似问题,谢谢❤️。
|
||||
- type: dropdown
|
||||
attributes:
|
||||
label: 操作系统类型?
|
||||
description: >
|
||||
请选择你运行程序的操作系统类型。
|
||||
options:
|
||||
- Windows
|
||||
- Linux
|
||||
- MacOS
|
||||
- Docker
|
||||
- Railway
|
||||
- Windows Subsystem for Linux (WSL)
|
||||
- Other (请在问题中说明)
|
||||
validations:
|
||||
required: true
|
||||
- type: dropdown
|
||||
attributes:
|
||||
label: 运行的python版本是?
|
||||
description: |
|
||||
请选择你运行程序的`python`版本。
|
||||
注意:在`python 3.7`中,有部分可选依赖无法安装。
|
||||
经过长时间的观察,我们认为`python 3.8`是兼容性最好的版本。
|
||||
`python 3.7`~`python 3.10`以外版本的issue,将视情况直接关闭。
|
||||
options:
|
||||
- python 3.7
|
||||
- python 3.8
|
||||
- python 3.9
|
||||
- python 3.10
|
||||
- other
|
||||
validations:
|
||||
required: true
|
||||
- type: dropdown
|
||||
attributes:
|
||||
label: 使用的chatgpt-on-wechat版本是?
|
||||
description: |
|
||||
请确保你使用的是 [releases](https://github.com/zhayujie/chatgpt-on-wechat/releases) 中的最新版本。
|
||||
如果你使用git, 请使用`git branch`命令来查看分支。
|
||||
options:
|
||||
- Latest Release
|
||||
- Master (branch)
|
||||
validations:
|
||||
required: true
|
||||
- type: dropdown
|
||||
attributes:
|
||||
label: 运行的`channel`类型是?
|
||||
description: |
|
||||
请确保你正确配置了该`channel`所需的配置项,所有可选的配置项都写在了[该文件中](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/config.py),请将所需配置项填写在根目录下的`config.json`文件中。
|
||||
options:
|
||||
- wx(个人微信, itchat)
|
||||
- wxy(个人微信, wechaty)
|
||||
- wechatmp(公众号, 订阅号)
|
||||
- wechatmp_service(公众号, 服务号)
|
||||
- terminal
|
||||
- other
|
||||
validations:
|
||||
required: true
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: 复现步骤 🕹
|
||||
description: |
|
||||
**⚠️ 不能复现将会关闭issue.**
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: 问题描述 😯
|
||||
description: 详细描述出现的问题,或提供有关截图。
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: 终端日志 📒
|
||||
description: |
|
||||
在此处粘贴终端日志,可在主目录下`run.log`文件中找到,这会帮助我们更好的分析问题,注意隐去你的API key。
|
||||
如果在配置文件中加入`"debug": true`,打印出的日志会更有帮助。
|
||||
|
||||
<details>
|
||||
<summary><i>示例</i></summary>
|
||||
```log
|
||||
[DEBUG][2023-04-16 00:23:22][plugin_manager.py:157] - Plugin SUMMARY triggered by event Event.ON_HANDLE_CONTEXT
|
||||
[DEBUG][2023-04-16 00:23:22][main.py:221] - [Summary] on_handle_context. content: $总结前100条消息
|
||||
[DEBUG][2023-04-16 00:23:24][main.py:240] - [Summary] limit: 100, duration: -1 seconds
|
||||
[ERROR][2023-04-16 00:23:24][chat_channel.py:244] - Worker return exception: name 'start_date' is not defined
|
||||
Traceback (most recent call last):
|
||||
File "C:\ProgramData\Anaconda3\lib\concurrent\futures\thread.py", line 57, in run
|
||||
result = self.fn(*self.args, **self.kwargs)
|
||||
File "D:\project\chatgpt-on-wechat\channel\chat_channel.py", line 132, in _handle
|
||||
reply = self._generate_reply(context)
|
||||
File "D:\project\chatgpt-on-wechat\channel\chat_channel.py", line 142, in _generate_reply
|
||||
e_context = PluginManager().emit_event(EventContext(Event.ON_HANDLE_CONTEXT, {
|
||||
File "D:\project\chatgpt-on-wechat\plugins\plugin_manager.py", line 159, in emit_event
|
||||
instance.handlers[e_context.event](e_context, *args, **kwargs)
|
||||
File "D:\project\chatgpt-on-wechat\plugins\summary\main.py", line 255, in on_handle_context
|
||||
records = self._get_records(session_id, start_time, limit)
|
||||
File "D:\project\chatgpt-on-wechat\plugins\summary\main.py", line 96, in _get_records
|
||||
c.execute("SELECT * FROM chat_records WHERE sessionid=? and timestamp>? ORDER BY timestamp DESC LIMIT ?", (session_id, start_date, limit))
|
||||
NameError: name 'start_date' is not defined
|
||||
[INFO][2023-04-16 00:23:36][app.py:14] - signal 2 received, exiting...
|
||||
```
|
||||
</details>
|
||||
value: |
|
||||
```log
|
||||
<此处粘贴终端日志>
|
||||
```
|
||||
28
.github/ISSUE_TEMPLATE/2.feature.yml
vendored
Normal file
28
.github/ISSUE_TEMPLATE/2.feature.yml
vendored
Normal file
@@ -0,0 +1,28 @@
|
||||
name: Feature request 🚀
|
||||
description: 提出你对项目的新想法或建议。
|
||||
labels: ['status: needs check']
|
||||
body:
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
请在上方的`title`中填写简略总结,谢谢❤️。
|
||||
- type: checkboxes
|
||||
attributes:
|
||||
label: ⚠️ 搜索是否存在类似issue
|
||||
description: >
|
||||
请在 [历史issue](https://github.com/zhayujie/chatgpt-on-wechat/issues) 中清空输入框,搜索关键词查找是否存在相似issue。
|
||||
options:
|
||||
- label: 我已经搜索过issues和disscussions,没有发现相似issue
|
||||
required: true
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: 总结
|
||||
description: 描述feature的功能。
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: 举例
|
||||
description: 提供聊天示例,草图或相关网址。
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: 动机
|
||||
description: 描述你提出该feature的动机,比如没有这项feature对你的使用造成了怎样的影响。 请提供更详细的场景描述,这可能会帮助我们发现并提出更好的解决方案。
|
||||
72
.github/workflows/deploy-image-arm.yml
vendored
Normal file
72
.github/workflows/deploy-image-arm.yml
vendored
Normal file
@@ -0,0 +1,72 @@
|
||||
# This workflow uses actions that are not certified by GitHub.
|
||||
# They are provided by a third-party and are governed by
|
||||
# separate terms of service, privacy policy, and support
|
||||
# documentation.
|
||||
|
||||
# GitHub recommends pinning actions to a commit SHA.
|
||||
# To get a newer version, you will need to update the SHA.
|
||||
# You can also reference a tag or branch, but the action may change without warning.
|
||||
|
||||
name: Create and publish a Docker image
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: ['master']
|
||||
create:
|
||||
env:
|
||||
REGISTRY: ghcr.io
|
||||
IMAGE_NAME: ${{ github.repository }}
|
||||
|
||||
jobs:
|
||||
build-and-push-image:
|
||||
if: github.repository == 'zhayujie/chatgpt-on-wechat'
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: read
|
||||
packages: write
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@v1
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
id: buildx
|
||||
uses: docker/setup-buildx-action@v1
|
||||
|
||||
- name: Available platforms
|
||||
run: echo ${{ steps.buildx.outputs.platforms }}
|
||||
|
||||
- name: Log in to the Container registry
|
||||
uses: docker/login-action@v2
|
||||
with:
|
||||
registry: ${{ env.REGISTRY }}
|
||||
username: ${{ github.actor }}
|
||||
password: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Extract metadata (tags, labels) for Docker
|
||||
id: meta
|
||||
uses: docker/metadata-action@v4
|
||||
with:
|
||||
images: |
|
||||
${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
|
||||
|
||||
- name: Build and push Docker image
|
||||
uses: docker/build-push-action@v3
|
||||
with:
|
||||
context: .
|
||||
push: true
|
||||
file: ./docker/Dockerfile.latest
|
||||
platforms: linux/arm64
|
||||
tags: ${{ steps.meta.outputs.tags }}-arm64
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
|
||||
- uses: actions/delete-package-versions@v4
|
||||
with:
|
||||
package-name: 'chatgpt-on-wechat'
|
||||
package-type: 'container'
|
||||
min-versions-to-keep: 10
|
||||
delete-only-untagged-versions: 'true'
|
||||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
18
.github/workflows/deploy-image.yml
vendored
18
.github/workflows/deploy-image.yml
vendored
@@ -12,14 +12,14 @@ name: Create and publish a Docker image
|
||||
on:
|
||||
push:
|
||||
branches: ['master']
|
||||
release:
|
||||
types: [published]
|
||||
create:
|
||||
env:
|
||||
REGISTRY: ghcr.io
|
||||
IMAGE_NAME: ${{ github.repository }}
|
||||
|
||||
jobs:
|
||||
build-and-push-image:
|
||||
if: github.repository == 'zhayujie/chatgpt-on-wechat'
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: read
|
||||
@@ -29,6 +29,12 @@ jobs:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v2
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
|
||||
- name: Log in to the Container registry
|
||||
uses: docker/login-action@v2
|
||||
with:
|
||||
@@ -40,7 +46,9 @@ jobs:
|
||||
id: meta
|
||||
uses: docker/metadata-action@v4
|
||||
with:
|
||||
images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
|
||||
images: |
|
||||
${{ env.IMAGE_NAME }}
|
||||
${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
|
||||
|
||||
- name: Build and push Docker image
|
||||
uses: docker/build-push-action@v3
|
||||
@@ -50,9 +58,9 @@ jobs:
|
||||
file: ./docker/Dockerfile.latest
|
||||
tags: ${{ steps.meta.outputs.tags }}
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
|
||||
|
||||
- uses: actions/delete-package-versions@v4
|
||||
with:
|
||||
with:
|
||||
package-name: 'chatgpt-on-wechat'
|
||||
package-type: 'container'
|
||||
min-versions-to-keep: 10
|
||||
|
||||
21
.gitignore
vendored
21
.gitignore
vendored
@@ -1,5 +1,8 @@
|
||||
.DS_Store
|
||||
.idea
|
||||
.vscode
|
||||
.venv
|
||||
.vs
|
||||
.wechaty/
|
||||
__pycache__/
|
||||
venv*
|
||||
@@ -10,3 +13,21 @@ nohup.out
|
||||
tmp
|
||||
plugins.json
|
||||
itchat.pkl
|
||||
*.log
|
||||
user_datas.pkl
|
||||
chatgpt_tool_hub/
|
||||
plugins/**/
|
||||
!plugins/bdunit
|
||||
!plugins/dungeon
|
||||
!plugins/finish
|
||||
!plugins/godcmd
|
||||
!plugins/tool
|
||||
!plugins/banwords
|
||||
!plugins/banwords/**/
|
||||
plugins/banwords/__pycache__
|
||||
plugins/banwords/lib/__pycache__
|
||||
!plugins/hello
|
||||
!plugins/role
|
||||
!plugins/keyword
|
||||
!plugins/linkai
|
||||
client_config.json
|
||||
|
||||
30
.pre-commit-config.yaml
Normal file
30
.pre-commit-config.yaml
Normal file
@@ -0,0 +1,30 @@
|
||||
repos:
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: v4.4.0
|
||||
hooks:
|
||||
- id: fix-byte-order-marker
|
||||
- id: check-case-conflict
|
||||
- id: check-merge-conflict
|
||||
- id: debug-statements
|
||||
- id: pretty-format-json
|
||||
types: [text]
|
||||
files: \.json(.template)?$
|
||||
args: [ --autofix , --no-ensure-ascii, --indent=2, --no-sort-keys]
|
||||
- id: trailing-whitespace
|
||||
exclude: '(\/|^)lib\/'
|
||||
args: [ --markdown-linebreak-ext=md ]
|
||||
- repo: https://github.com/PyCQA/isort
|
||||
rev: 5.12.0
|
||||
hooks:
|
||||
- id: isort
|
||||
exclude: '(\/|^)lib\/'
|
||||
- repo: https://github.com/psf/black
|
||||
rev: 23.3.0
|
||||
hooks:
|
||||
- id: black
|
||||
exclude: '(\/|^)lib\/'
|
||||
- repo: https://github.com/PyCQA/flake8
|
||||
rev: 6.0.0
|
||||
hooks:
|
||||
- id: flake8
|
||||
exclude: '(\/|^)lib\/'
|
||||
219
README.md
219
README.md
@@ -2,70 +2,78 @@
|
||||
|
||||
> ChatGPT近期以强大的对话和信息整合能力风靡全网,可以写代码、改论文、讲故事,几乎无所不能,这让人不禁有个大胆的想法,能否用他的对话模型把我们的微信打造成一个智能机器人,可以在与好友对话中给出意想不到的回应,而且再也不用担心女朋友影响我们 ~~打游戏~~ 工作了。
|
||||
|
||||
最新版本支持的功能如下:
|
||||
|
||||
基于ChatGPT的微信聊天机器人,通过 [ChatGPT](https://github.com/openai/openai-python) 接口生成对话内容,使用 [itchat](https://github.com/littlecodersh/ItChat) 实现微信消息的接收和自动回复。已实现的特性如下:
|
||||
- [x] **多端部署:** 有多种部署方式可选择且功能完备,目前已支持个人微信、微信公众号和、企业微信、飞书等部署方式
|
||||
- [x] **基础对话:** 私聊及群聊的消息智能回复,支持多轮会话上下文记忆,支持 GPT-3.5, GPT-4, claude, Gemini, 文心一言, 讯飞星火, 通义千问
|
||||
- [x] **语音能力:** 可识别语音消息,通过文字或语音回复,支持 azure, baidu, google, openai(whisper/tts) 等多种语音模型
|
||||
- [x] **图像能力:** 支持图片生成、图片识别、图生图(如照片修复),可选择 Dall-E-3, stable diffusion, replicate, midjourney, vision模型
|
||||
- [x] **丰富插件:** 支持个性化插件扩展,已实现多角色切换、文字冒险、敏感词过滤、聊天记录总结、文档总结和对话等插件
|
||||
- [X] **Tool工具:** 与操作系统和互联网交互,支持最新信息搜索、数学计算、天气和资讯查询、网页总结,基于 [chatgpt-tool-hub](https://github.com/goldfishh/chatgpt-tool-hub) 实现
|
||||
- [x] **知识库:** 通过上传知识库文件自定义专属机器人,可作为数字分身、领域知识库、智能客服使用,基于 [LinkAI](https://link-ai.tech/console) 实现
|
||||
|
||||
- [x] **文本对话:** 接收私聊及群组中的微信消息,使用ChatGPT生成回复内容,完成自动回复
|
||||
- [x] **规则定制化:** 支持私聊中按指定规则触发自动回复,支持对群组设置自动回复白名单
|
||||
- [x] **多账号:** 支持多微信账号同时运行
|
||||
- [x] **图片生成:** 支持根据描述生成图片,并自动发送至个人聊天或群聊
|
||||
- [x] **上下文记忆**:支持多轮对话记忆,且为每个好友维护独立的上下会话
|
||||
- [x] **语音识别:** 支持接收和处理语音消息,通过文字或语音回复
|
||||
- [x] **插件化:** 支持个性化功能插件,提供角色扮演、文字冒险游戏等预设插件
|
||||
> 欢迎接入更多应用,参考 [Terminal代码](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/channel/terminal/terminal_channel.py)实现接收和发送消息逻辑即可接入。 同时欢迎增加新的插件,参考 [插件说明文档](https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins)。
|
||||
|
||||
> 快速部署:
|
||||
>
|
||||
>[](https://railway.app/template/qApznZ?referralCode=RC3znh)
|
||||
# 演示
|
||||
|
||||
https://github.com/zhayujie/chatgpt-on-wechat/assets/26161723/d5154020-36e3-41db-8706-40ce9f3f1b1e
|
||||
|
||||
Demo made by [Visionn](https://www.wangpc.cc/)
|
||||
|
||||
# 交流群
|
||||
|
||||
添加小助手微信进群,请备注 "wechat":
|
||||
|
||||
<img width="240" src="./docs/images/contact.jpg">
|
||||
|
||||
# 更新日志
|
||||
|
||||
>**2023.11.11:** [1.5.3版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.5.3) 和 [1.5.4版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.5.4),新增Google Gemini、通义千问模型
|
||||
|
||||
>**2023.11.10:** [1.5.2版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.5.2),新增飞书通道、图像识别对话、黑名单配置
|
||||
|
||||
>**2023.11.10:** [1.5.0版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.5.0),新增 `gpt-4-turbo`, `dall-e-3`, `tts` 模型接入,完善图像理解&生成、语音识别&生成的多模态能力
|
||||
|
||||
>**2023.10.16:** 支持通过意图识别使用LinkAI联网搜索、数学计算、网页访问等插件,参考[插件文档](https://docs.link-ai.tech/platform/plugins)
|
||||
|
||||
>**2023.09.26:** 插件增加 文件/文章链接 一键总结和对话的功能,使用参考:[插件说明](https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/linkai#3%E6%96%87%E6%A1%A3%E6%80%BB%E7%BB%93%E5%AF%B9%E8%AF%9D%E5%8A%9F%E8%83%BD)
|
||||
|
||||
>**2023.08.08:** 接入百度文心一言模型,通过 [插件](https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/linkai) 支持 Midjourney 绘图
|
||||
|
||||
>**2023.06.12:** 接入 [LinkAI](https://link-ai.tech/console) 平台,可在线创建领域知识库,并接入微信、公众号及企业微信中,打造专属客服机器人。使用参考 [接入文档](https://link-ai.tech/platform/link-app/wechat)。
|
||||
|
||||
>**2023.04.26:** 支持企业微信应用号部署,兼容插件,并支持语音图片交互,私人助理理想选择,[使用文档](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/channel/wechatcom/README.md)。(contributed by [@lanvent](https://github.com/lanvent) in [#944](https://github.com/zhayujie/chatgpt-on-wechat/pull/944))
|
||||
|
||||
>**2023.04.05:** 支持微信公众号部署,兼容插件,并支持语音图片交互,[使用文档](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/channel/wechatmp/README.md)。(contributed by [@JS00000](https://github.com/JS00000) in [#686](https://github.com/zhayujie/chatgpt-on-wechat/pull/686))
|
||||
|
||||
>**2023.04.05:** 增加能让ChatGPT使用工具的`tool`插件,[使用文档](https://github.com/goldfishh/chatgpt-on-wechat/blob/master/plugins/tool/README.md)。工具相关issue可反馈至[chatgpt-tool-hub](https://github.com/goldfishh/chatgpt-tool-hub)。(contributed by [@goldfishh](https://github.com/goldfishh) in [#663](https://github.com/zhayujie/chatgpt-on-wechat/pull/663))
|
||||
|
||||
>**2023.03.25:** 支持插件化开发,目前已实现 多角色切换、文字冒险游戏、管理员指令、Stable Diffusion等插件,使用参考 [#578](https://github.com/zhayujie/chatgpt-on-wechat/issues/578)。(contributed by [@lanvent](https://github.com/lanvent) in [#565](https://github.com/zhayujie/chatgpt-on-wechat/pull/565))
|
||||
|
||||
>**2023.03.09:** 基于 `whisper API` 实现对微信语音消息的解析和回复,添加配置项 `"speech_recognition":true` 即可启用,使用参考 [#415](https://github.com/zhayujie/chatgpt-on-wechat/issues/415)。(contributed by [wanggang1987](https://github.com/wanggang1987) in [#385](https://github.com/zhayujie/chatgpt-on-wechat/pull/385))
|
||||
|
||||
>**2023.03.02:** 接入[ChatGPT API](https://platform.openai.com/docs/guides/chat) (gpt-3.5-turbo),默认使用该模型进行对话,需升级openai依赖 (`pip3 install --upgrade openai`)。网络问题参考 [#351](https://github.com/zhayujie/chatgpt-on-wechat/issues/351)
|
||||
|
||||
>**2023.02.09:** 扫码登录存在封号风险,请谨慎使用,参考[#58](https://github.com/AutumnWhj/ChatGPT-wechat-bot/issues/158)
|
||||
|
||||
>**2023.02.05:** 在openai官方接口方案中 (GPT-3模型) 实现上下文对话
|
||||
|
||||
>**2022.12.18:** 支持根据描述生成图片并发送,openai版本需大于0.25.0
|
||||
|
||||
>**2022.12.17:** 原来的方案是从 [ChatGPT页面](https://chat.openai.com/chat) 获取session_token,使用 [revChatGPT](https://github.com/acheong08/ChatGPT) 直接访问web接口,但随着ChatGPT接入Cloudflare人机验证,这一方案难以在服务器顺利运行。 所以目前使用的方案是调用 OpenAI 官方提供的 [API](https://beta.openai.com/docs/api-reference/introduction),回复质量上基本接近于ChatGPT的内容,劣势是暂不支持有上下文记忆的对话,优势是稳定性和响应速度较好。
|
||||
|
||||
# 使用效果
|
||||
|
||||
### 个人聊天
|
||||
|
||||

|
||||
|
||||
### 群组聊天
|
||||
|
||||

|
||||
|
||||
### 图片生成
|
||||
|
||||

|
||||
>**2023.03.09:** 基于 `whisper API`(后续已接入更多的语音`API`服务) 实现对微信语音消息的解析和回复,添加配置项 `"speech_recognition":true` 即可启用,使用参考 [#415](https://github.com/zhayujie/chatgpt-on-wechat/issues/415)。(contributed by [wanggang1987](https://github.com/wanggang1987) in [#385](https://github.com/zhayujie/chatgpt-on-wechat/pull/385))
|
||||
|
||||
>**2023.02.09:** 扫码登录存在账号限制风险,请谨慎使用,参考[#58](https://github.com/AutumnWhj/ChatGPT-wechat-bot/issues/158)
|
||||
|
||||
# 快速开始
|
||||
|
||||
快速开始文档:[项目搭建文档](https://docs.link-ai.tech/cow/quick-start)
|
||||
|
||||
## 准备
|
||||
|
||||
### 1. OpenAI账号注册
|
||||
### 1. 账号注册
|
||||
|
||||
前往 [OpenAI注册页面](https://beta.openai.com/signup) 创建账号,参考这篇 [教程](https://www.pythonthree.com/register-openai-chatgpt/) 可以通过虚拟手机号来接收验证码。创建完账号则前往 [API管理页面](https://beta.openai.com/account/api-keys) 创建一个 API Key 并保存下来,后面需要在项目中配置这个key。
|
||||
项目默认使用OpenAI接口,需前往 [OpenAI注册页面](https://beta.openai.com/signup) 创建账号,创建完账号则前往 [API管理页面](https://beta.openai.com/account/api-keys) 创建一个 API Key 并保存下来,后面需要在项目中配置这个key。接口需要海外网络访问及绑定信用卡支付。
|
||||
|
||||
> 项目中使用的对话模型是 davinci,计费方式是约每 750 字 (包含请求和回复) 消耗 $0.02,图片生成是每张消耗 $0.016,账号创建有免费的 $18 额度 (更新3.25: 最新注册的已经无免费额度了),使用完可以更换邮箱重新注册。
|
||||
|
||||
#### 1.1 ChapGPT service On Azure
|
||||
一种替换以上的方法是使用Azure推出的[ChatGPT service](https://azure.microsoft.com/en-in/products/cognitive-services/openai-service/)。它host在公有云Azure上,因此不需要VPN就可以直接访问。不过目前仍然处于preview阶段。新用户可以通过Try Azure for free来薅一段时间的羊毛
|
||||
> 默认对话模型是 openai 的 gpt-3.5-turbo,计费方式是约每 1000tokens (约750个英文单词 或 500汉字,包含请求和回复) 消耗 $0.002,图片生成是Dell E模型,每张消耗 $0.016。
|
||||
|
||||
项目同时也支持使用 LinkAI 接口,无需代理,可使用 文心、讯飞、GPT-3、GPT-4 等模型,支持 定制化知识库、联网搜索、MJ绘图、文档总结和对话等能力。修改配置即可一键切换,参考 [接入文档](https://link-ai.tech/platform/link-app/wechat)。
|
||||
|
||||
### 2.运行环境
|
||||
|
||||
支持 Linux、MacOS、Windows 系统(可在Linux服务器上长期运行),同时需安装 `Python`。
|
||||
> 建议Python版本在 3.7.1~3.9.X 之间,3.10及以上版本在 MacOS 可用,其他系统上不确定能否正常运行。
|
||||
> 建议Python版本在 3.7.1~3.9.X 之间,推荐3.8版本,3.10及以上版本在 MacOS 可用,其他系统上不确定能否正常运行。
|
||||
|
||||
> 注意:Docker 或 Railway 部署无需安装python环境和下载源码,可直接快进到下一节。
|
||||
|
||||
**(1) 克隆项目代码:**
|
||||
|
||||
@@ -75,14 +83,20 @@ cd chatgpt-on-wechat/
|
||||
```
|
||||
|
||||
**(2) 安装核心依赖 (必选):**
|
||||
|
||||
> 能够使用`itchat`创建机器人,并具有文字交流功能所需的最小依赖集合。
|
||||
```bash
|
||||
pip3 install -r requirements.txt
|
||||
```
|
||||
|
||||
其中`tiktoken`要求`python`版本在3.8以上,它用于精确计算会话使用的tokens数量,可以不装但建议安装。
|
||||
**(3) 拓展依赖 (可选,建议安装):**
|
||||
|
||||
```bash
|
||||
pip3 install -r requirements-optional.txt
|
||||
```
|
||||
> 如果某项依赖安装失败请注释掉对应的行再继续。
|
||||
|
||||
其中`tiktoken`要求`python`版本在3.8以上,它用于精确计算会话使用的tokens数量,强烈建议安装。
|
||||
|
||||
**(3) 拓展依赖 (可选):**
|
||||
|
||||
使用`google`或`baidu`语音识别需安装`ffmpeg`,
|
||||
|
||||
@@ -90,6 +104,13 @@ pip3 install -r requirements.txt
|
||||
|
||||
参考[#415](https://github.com/zhayujie/chatgpt-on-wechat/issues/415)
|
||||
|
||||
使用`azure`语音功能需安装依赖,并参考[文档](https://learn.microsoft.com/en-us/azure/cognitive-services/speech-service/quickstarts/setup-platform?pivots=programming-language-python&tabs=linux%2Cubuntu%2Cdotnet%2Cjre%2Cmaven%2Cnodejs%2Cmac%2Cpypi)的环境要求。
|
||||
:
|
||||
|
||||
```bash
|
||||
pip3 install azure-cognitiveservices-speech
|
||||
```
|
||||
|
||||
## 配置
|
||||
|
||||
配置文件的模板在根目录的`config-template.json`中,需复制该模板创建最终生效的 `config.json` 文件:
|
||||
@@ -98,25 +119,32 @@ pip3 install -r requirements.txt
|
||||
cp config-template.json config.json
|
||||
```
|
||||
|
||||
然后在`config.json`中填入配置,以下是对默认配置的说明,可根据需要进行自定义修改:
|
||||
然后在`config.json`中填入配置,以下是对默认配置的说明,可根据需要进行自定义修改(请去掉注释):
|
||||
|
||||
```bash
|
||||
# config.json文件内容示例
|
||||
{
|
||||
{
|
||||
"open_ai_api_key": "YOUR API KEY", # 填入上面创建的 OpenAI API KEY
|
||||
"model": "gpt-3.5-turbo", # 模型名称。当use_azure_chatgpt为true时,其名称为Azure上model deployment名称
|
||||
"proxy": "127.0.0.1:7890", # 代理客户端的ip和端口
|
||||
"model": "gpt-3.5-turbo", # 模型名称, 支持 gpt-3.5-turbo, gpt-3.5-turbo-16k, gpt-4, wenxin, xunfei
|
||||
"proxy": "", # 代理客户端的ip和端口,国内环境开启代理的需要填写该项,如 "127.0.0.1:7890"
|
||||
"single_chat_prefix": ["bot", "@bot"], # 私聊时文本需要包含该前缀才能触发机器人回复
|
||||
"single_chat_reply_prefix": "[bot] ", # 私聊时自动回复的前缀,用于区分真人
|
||||
"group_chat_prefix": ["@bot"], # 群聊时包含该前缀则会触发机器人回复
|
||||
"group_name_white_list": ["ChatGPT测试群", "ChatGPT测试群2"], # 开启自动回复的群名称列表
|
||||
"group_chat_in_one_session": ["ChatGPT测试群"], # 支持会话上下文共享的群名称
|
||||
"group_chat_in_one_session": ["ChatGPT测试群"], # 支持会话上下文共享的群名称
|
||||
"image_create_prefix": ["画", "看", "找"], # 开启图片回复的前缀
|
||||
"conversation_max_tokens": 1000, # 支持上下文记忆的最多字符数
|
||||
"speech_recognition": false, # 是否开启语音识别
|
||||
"group_speech_recognition": false, # 是否开启群组语音识别
|
||||
"use_azure_chatgpt": false, # 是否使用Azure ChatGPT service代替openai ChatGPT service. 当设置为true时需要设置 open_ai_api_base,如 https://xxx.openai.azure.com/
|
||||
"character_desc": "你是ChatGPT, 一个由OpenAI训练的大型语言模型, 你旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。", # 人格描述,
|
||||
"azure_deployment_id": "", # 采用Azure ChatGPT时,模型部署名称
|
||||
"azure_api_version": "", # 采用Azure ChatGPT时,API版本
|
||||
"character_desc": "你是ChatGPT, 一个由OpenAI训练的大型语言模型, 你旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。", # 人格描述
|
||||
# 订阅消息,公众号和企业微信channel中请填写,当被订阅时会自动回复,可使用特殊占位符。目前支持的占位符有{trigger_prefix},在程序中它会自动替换成bot的触发词。
|
||||
"subscribe_msg": "感谢您的关注!\n这里是ChatGPT,可以自由对话。\n支持语音对话。\n支持图片输出,画字开头的消息将按要求创作图片。\n支持角色扮演和文字冒险等丰富插件。\n输入{trigger_prefix}#help 查看详细指令。",
|
||||
"use_linkai": false, # 是否使用LinkAI接口,默认关闭,开启后可国内访问,使用知识库和MJ
|
||||
"linkai_api_key": "", # LinkAI Api Key
|
||||
"linkai_app_code": "" # LinkAI 应用code
|
||||
}
|
||||
```
|
||||
**配置说明:**
|
||||
@@ -141,18 +169,25 @@ pip3 install -r requirements.txt
|
||||
|
||||
**4.其他配置**
|
||||
|
||||
+ `model`: 模型名称,目前支持 `gpt-3.5-turbo`, `text-davinci-003`, `gpt-4`, `gpt-4-32k` (其中gpt-4 api暂未开放)
|
||||
+ `temperature`,`frequency_penalty`,`presence_penalty`: Chat API接口参数,详情参考[OpenAI官方文档。](https://platform.openai.com/docs/api-reference/chat)
|
||||
+ `model`: 模型名称,目前支持 `gpt-3.5-turbo`, `text-davinci-003`, `gpt-4`, `gpt-4-32k`, `wenxin` , `claude` , `xunfei`(其中gpt-4 api暂未完全开放,申请通过后可使用)
|
||||
+ `temperature`,`frequency_penalty`,`presence_penalty`: Chat API接口参数,详情参考[OpenAI官方文档。](https://platform.openai.com/docs/api-reference/chat)
|
||||
+ `proxy`:由于目前 `openai` 接口国内无法访问,需配置代理客户端的地址,详情参考 [#351](https://github.com/zhayujie/chatgpt-on-wechat/issues/351)
|
||||
+ 对于图像生成,在满足个人或群组触发条件外,还需要额外的关键词前缀来触发,对应配置 `image_create_prefix `
|
||||
+ 关于OpenAI对话及图片接口的参数配置(内容自由度、回复字数限制、图片大小等),可以参考 [对话接口](https://beta.openai.com/docs/api-reference/completions) 和 [图像接口](https://beta.openai.com/docs/api-reference/completions) 文档直接在 [代码](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/bot/openai/open_ai_bot.py) `bot/openai/open_ai_bot.py` 中进行调整。
|
||||
+ 关于OpenAI对话及图片接口的参数配置(内容自由度、回复字数限制、图片大小等),可以参考 [对话接口](https://beta.openai.com/docs/api-reference/completions) 和 [图像接口](https://beta.openai.com/docs/api-reference/completions) 文档,在[`config.py`](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/config.py)中检查哪些参数在本项目中是可配置的。
|
||||
+ `conversation_max_tokens`:表示能够记忆的上下文最大字数(一问一答为一组对话,如果累积的对话字数超出限制,就会优先移除最早的一组对话)
|
||||
+ `rate_limit_chatgpt`,`rate_limit_dalle`:每分钟最高问答速率、画图速率,超速后排队按序处理。
|
||||
+ `clear_memory_commands`: 对话内指令,主动清空前文记忆,字符串数组可自定义指令别名。
|
||||
+ `hot_reload`: 程序退出后,暂存微信扫码状态,默认关闭。
|
||||
+ `character_desc` 配置中保存着你对机器人说的一段话,他会记住这段话并作为他的设定,你可以为他定制任何人格 (关于会话上下文的更多内容参考该 [issue](https://github.com/zhayujie/chatgpt-on-wechat/issues/43))
|
||||
+ `subscribe_msg`:订阅消息,公众号和企业微信channel中请填写,当被订阅时会自动回复, 可使用特殊占位符。目前支持的占位符有{trigger_prefix},在程序中它会自动替换成bot的触发词。
|
||||
|
||||
**所有可选的配置项均在该[文件](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/config.py)中列出。**
|
||||
**5.LinkAI配置 (可选)**
|
||||
|
||||
+ `use_linkai`: 是否使用LinkAI接口,开启后可国内访问,使用知识库和 `Midjourney` 绘画, 参考 [文档](https://link-ai.tech/platform/link-app/wechat)
|
||||
+ `linkai_api_key`: LinkAI Api Key,可在 [控制台](https://link-ai.tech/console/interface) 创建
|
||||
+ `linkai_app_code`: LinkAI 应用code,选填
|
||||
|
||||
**本说明文档可能会未及时更新,当前所有可选的配置项均在该[`config.py`](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/config.py)中列出。**
|
||||
|
||||
## 运行
|
||||
|
||||
@@ -161,45 +196,89 @@ pip3 install -r requirements.txt
|
||||
如果是开发机 **本地运行**,直接在项目根目录下执行:
|
||||
|
||||
```bash
|
||||
python3 app.py
|
||||
python3 app.py # windows环境下该命令通常为 python app.py
|
||||
```
|
||||
终端输出二维码后,使用微信进行扫码,当输出 "Start auto replying" 时表示自动回复程序已经成功运行了(注意:用于登录的微信需要在支付处已完成实名认证)。扫码登录后你的账号就成为机器人了,可以在微信手机端通过配置的关键词触发自动回复 (任意好友发送消息给你,或是自己发消息给好友),参考[#142](https://github.com/zhayujie/chatgpt-on-wechat/issues/142)。
|
||||
|
||||
终端输出二维码后,使用微信进行扫码,当输出 "Start auto replying" 时表示自动回复程序已经成功运行了(注意:用于登录的微信需要在支付处已完成实名认证)。扫码登录后你的账号就成为机器人了,可以在微信手机端通过配置的关键词触发自动回复 (任意好友发送消息给你,或是自己发消息给好友),参考[#142](https://github.com/zhayujie/chatgpt-on-wechat/issues/142)。
|
||||
|
||||
### 2.服务器部署
|
||||
|
||||
使用nohup命令在后台运行程序:
|
||||
|
||||
```bash
|
||||
touch nohup.out # 首次运行需要新建日志文件
|
||||
touch nohup.out # 首次运行需要新建日志文件
|
||||
nohup python3 app.py & tail -f nohup.out # 在后台运行程序并通过日志输出二维码
|
||||
```
|
||||
扫码登录后程序即可运行于服务器后台,此时可通过 `ctrl+c` 关闭日志,不会影响后台程序的运行。使用 `ps -ef | grep app.py | grep -v grep` 命令可查看运行于后台的进程,如果想要重新启动程序可以先 `kill` 掉对应的进程。日志关闭后如果想要再次打开只需输入 `tail -f nohup.out`。此外,`scripts` 目录下有一键运行、关闭程序的脚本供使用。
|
||||
|
||||
> **注意:** 如果 扫码后手机提示登录验证需要等待5s,而终端的二维码再次刷新并提示 `Log in time out, reloading QR code`,此时需参考此 [issue](https://github.com/zhayujie/chatgpt-on-wechat/issues/8) 修改一行代码即可解决。
|
||||
> **多账号支持:** 将项目复制多份,分别启动程序,用不同账号扫码登录即可实现同时运行。
|
||||
|
||||
> **多账号支持:** 将 项目复制多份,分别启动程序,用不同账号扫码登录即可实现同时运行。
|
||||
|
||||
> **特殊指令:** 用户向机器人发送 **#清除记忆** 即可清空该用户的上下文记忆。
|
||||
> **特殊指令:** 用户向机器人发送 **#reset** 即可清空该用户的上下文记忆。
|
||||
|
||||
|
||||
### 3.Docker部署
|
||||
|
||||
参考文档 [Docker部署](https://github.com/limccn/chatgpt-on-wechat/wiki/Docker%E9%83%A8%E7%BD%B2) (Contributed by [limccn](https://github.com/limccn))。
|
||||
> 使用docker部署无需下载源码和安装依赖,只需要获取 docker-compose.yml 配置文件并启动容器即可。
|
||||
|
||||
### 4. Railway部署(✅推荐)
|
||||
> Railway每月提供5刀和最多500小时的免费额度。
|
||||
1. 进入 [Railway](https://railway.app/template/qApznZ?referralCode=RC3znh)。
|
||||
> 前提是需要安装好 `docker` 及 `docker-compose`,安装成功的表现是执行 `docker -v` 和 `docker-compose version` (或 docker compose version) 可以查看到版本号,可前往 [docker官网](https://docs.docker.com/engine/install/) 进行下载。
|
||||
|
||||
#### (1) 下载 docker-compose.yml 文件
|
||||
|
||||
```bash
|
||||
wget https://open-1317903499.cos.ap-guangzhou.myqcloud.com/docker-compose.yml
|
||||
```
|
||||
|
||||
下载完成后打开 `docker-compose.yml` 修改所需配置,如 `OPEN_AI_API_KEY` 和 `GROUP_NAME_WHITE_LIST` 等。
|
||||
|
||||
#### (2) 启动容器
|
||||
|
||||
在 `docker-compose.yml` 所在目录下执行以下命令启动容器:
|
||||
|
||||
```bash
|
||||
sudo docker compose up -d
|
||||
```
|
||||
|
||||
运行 `sudo docker ps` 能查看到 NAMES 为 chatgpt-on-wechat 的容器即表示运行成功。
|
||||
|
||||
注意:
|
||||
|
||||
- 如果 `docker-compose` 是 1.X 版本 则需要执行 `sudo docker-compose up -d` 来启动容器
|
||||
- 该命令会自动去 [docker hub](https://hub.docker.com/r/zhayujie/chatgpt-on-wechat) 拉取 latest 版本的镜像,latest 镜像会在每次项目 release 新的版本时生成
|
||||
|
||||
最后运行以下命令可查看容器运行日志,扫描日志中的二维码即可完成登录:
|
||||
|
||||
```bash
|
||||
sudo docker logs -f chatgpt-on-wechat
|
||||
```
|
||||
|
||||
#### (3) 插件使用
|
||||
|
||||
如果需要在docker容器中修改插件配置,可通过挂载的方式完成,将 [插件配置文件](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/plugins/config.json.template)
|
||||
重命名为 `config.json`,放置于 `docker-compose.yml` 相同目录下,并在 `docker-compose.yml` 中的 `chatgpt-on-wechat` 部分下添加 `volumes` 映射:
|
||||
|
||||
```
|
||||
volumes:
|
||||
- ./config.json:/app/plugins/config.json
|
||||
```
|
||||
|
||||
### 4. Railway部署
|
||||
|
||||
> Railway 每月提供5刀和最多500小时的免费额度。 (07.11更新: 目前大部分账号已无法免费部署)
|
||||
|
||||
1. 进入 [Railway](https://railway.app/template/qApznZ?referralCode=RC3znh)
|
||||
2. 点击 `Deploy Now` 按钮。
|
||||
3. 设置环境变量来重载程序运行的参数,例如`open_ai_api_key`, `character_desc`。
|
||||
|
||||
**一键部署:**
|
||||
|
||||
[](https://railway.app/template/qApznZ?referralCode=RC3znh)
|
||||
|
||||
## 常见问题
|
||||
|
||||
FAQs: <https://github.com/zhayujie/chatgpt-on-wechat/wiki/FAQs>
|
||||
|
||||
或直接在线咨询 [项目小助手](https://link-ai.tech/app/Kv2fXJcH) (beta版本,语料完善中,回复仅供参考)
|
||||
|
||||
## 联系
|
||||
|
||||
欢迎提交PR、Issues,以及Star支持一下。程序运行遇到问题优先查看 [常见问题列表](https://github.com/zhayujie/chatgpt-on-wechat/wiki/FAQs) ,其次前往 [Issues](https://github.com/zhayujie/chatgpt-on-wechat/issues) 中搜索,若无相似问题可创建Issue,或加微信 eijuyahz 交流。
|
||||
|
||||
|
||||
欢迎提交PR、Issues,以及Star支持一下。程序运行遇到问题可以查看 [常见问题列表](https://github.com/zhayujie/chatgpt-on-wechat/wiki/FAQs) ,其次前往 [Issues](https://github.com/zhayujie/chatgpt-on-wechat/issues) 中搜索。参与更多讨论可加入技术交流群。
|
||||
|
||||
53
app.py
53
app.py
@@ -1,27 +1,66 @@
|
||||
# encoding:utf-8
|
||||
|
||||
from config import conf, load_config
|
||||
from channel import channel_factory
|
||||
from common.log import logger
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
|
||||
from channel import channel_factory
|
||||
from common import const
|
||||
from config import load_config
|
||||
from plugins import *
|
||||
import threading
|
||||
|
||||
|
||||
def sigterm_handler_wrap(_signo):
|
||||
old_handler = signal.getsignal(_signo)
|
||||
|
||||
def func(_signo, _stack_frame):
|
||||
logger.info("signal {} received, exiting...".format(_signo))
|
||||
conf().save_user_datas()
|
||||
if callable(old_handler): # check old_handler
|
||||
return old_handler(_signo, _stack_frame)
|
||||
sys.exit(0)
|
||||
|
||||
signal.signal(_signo, func)
|
||||
|
||||
|
||||
def run():
|
||||
try:
|
||||
# load config
|
||||
load_config()
|
||||
# ctrl + c
|
||||
sigterm_handler_wrap(signal.SIGINT)
|
||||
# kill signal
|
||||
sigterm_handler_wrap(signal.SIGTERM)
|
||||
|
||||
# create channel
|
||||
channel_name=conf().get('channel_type', 'wx')
|
||||
channel_name = conf().get("channel_type", "wx")
|
||||
|
||||
if "--cmd" in sys.argv:
|
||||
channel_name = "terminal"
|
||||
|
||||
if channel_name == "wxy":
|
||||
os.environ["WECHATY_LOG"] = "warn"
|
||||
# os.environ['WECHATY_PUPPET_SERVICE_ENDPOINT'] = '127.0.0.1:9001'
|
||||
|
||||
channel = channel_factory.create_channel(channel_name)
|
||||
if channel_name=='wx':
|
||||
if channel_name in ["wx", "wxy", "terminal", "wechatmp", "wechatmp_service", "wechatcom_app", "wework", const.FEISHU,const.DINGTALK]:
|
||||
PluginManager().load_plugins()
|
||||
|
||||
if conf().get("use_linkai"):
|
||||
try:
|
||||
from common import linkai_client
|
||||
threading.Thread(target=linkai_client.start, args=(channel, )).start()
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
# startup channel
|
||||
channel.startup()
|
||||
|
||||
except Exception as e:
|
||||
logger.error("App startup failed!")
|
||||
logger.exception(e)
|
||||
|
||||
if __name__ == '__main__':
|
||||
run()
|
||||
|
||||
if __name__ == "__main__":
|
||||
run()
|
||||
|
||||
214
bot/ali/ali_qwen_bot.py
Normal file
214
bot/ali/ali_qwen_bot.py
Normal file
@@ -0,0 +1,214 @@
|
||||
# encoding:utf-8
|
||||
|
||||
import json
|
||||
import time
|
||||
from typing import List, Tuple
|
||||
|
||||
import openai
|
||||
import openai.error
|
||||
import broadscope_bailian
|
||||
from broadscope_bailian import ChatQaMessage
|
||||
|
||||
from bot.bot import Bot
|
||||
from bot.ali.ali_qwen_session import AliQwenSession
|
||||
from bot.session_manager import SessionManager
|
||||
from bridge.context import ContextType
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from common.log import logger
|
||||
from common import const
|
||||
from config import conf, load_config
|
||||
|
||||
class AliQwenBot(Bot):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.api_key_expired_time = self.set_api_key()
|
||||
self.sessions = SessionManager(AliQwenSession, model=conf().get("model", const.QWEN))
|
||||
|
||||
def api_key_client(self):
|
||||
return broadscope_bailian.AccessTokenClient(access_key_id=self.access_key_id(), access_key_secret=self.access_key_secret())
|
||||
|
||||
def access_key_id(self):
|
||||
return conf().get("qwen_access_key_id")
|
||||
|
||||
def access_key_secret(self):
|
||||
return conf().get("qwen_access_key_secret")
|
||||
|
||||
def agent_key(self):
|
||||
return conf().get("qwen_agent_key")
|
||||
|
||||
def app_id(self):
|
||||
return conf().get("qwen_app_id")
|
||||
|
||||
def node_id(self):
|
||||
return conf().get("qwen_node_id", "")
|
||||
|
||||
def temperature(self):
|
||||
return conf().get("temperature", 0.2 )
|
||||
|
||||
def top_p(self):
|
||||
return conf().get("top_p", 1)
|
||||
|
||||
def reply(self, query, context=None):
|
||||
# acquire reply content
|
||||
if context.type == ContextType.TEXT:
|
||||
logger.info("[QWEN] query={}".format(query))
|
||||
|
||||
session_id = context["session_id"]
|
||||
reply = None
|
||||
clear_memory_commands = conf().get("clear_memory_commands", ["#清除记忆"])
|
||||
if query in clear_memory_commands:
|
||||
self.sessions.clear_session(session_id)
|
||||
reply = Reply(ReplyType.INFO, "记忆已清除")
|
||||
elif query == "#清除所有":
|
||||
self.sessions.clear_all_session()
|
||||
reply = Reply(ReplyType.INFO, "所有人记忆已清除")
|
||||
elif query == "#更新配置":
|
||||
load_config()
|
||||
reply = Reply(ReplyType.INFO, "配置已更新")
|
||||
if reply:
|
||||
return reply
|
||||
session = self.sessions.session_query(query, session_id)
|
||||
logger.debug("[QWEN] session query={}".format(session.messages))
|
||||
|
||||
reply_content = self.reply_text(session)
|
||||
logger.debug(
|
||||
"[QWEN] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(
|
||||
session.messages,
|
||||
session_id,
|
||||
reply_content["content"],
|
||||
reply_content["completion_tokens"],
|
||||
)
|
||||
)
|
||||
if reply_content["completion_tokens"] == 0 and len(reply_content["content"]) > 0:
|
||||
reply = Reply(ReplyType.ERROR, reply_content["content"])
|
||||
elif reply_content["completion_tokens"] > 0:
|
||||
self.sessions.session_reply(reply_content["content"], session_id, reply_content["total_tokens"])
|
||||
reply = Reply(ReplyType.TEXT, reply_content["content"])
|
||||
else:
|
||||
reply = Reply(ReplyType.ERROR, reply_content["content"])
|
||||
logger.debug("[QWEN] reply {} used 0 tokens.".format(reply_content))
|
||||
return reply
|
||||
|
||||
else:
|
||||
reply = Reply(ReplyType.ERROR, "Bot不支持处理{}类型的消息".format(context.type))
|
||||
return reply
|
||||
|
||||
def reply_text(self, session: AliQwenSession, retry_count=0) -> dict:
|
||||
"""
|
||||
call bailian's ChatCompletion to get the answer
|
||||
:param session: a conversation session
|
||||
:param retry_count: retry count
|
||||
:return: {}
|
||||
"""
|
||||
try:
|
||||
prompt, history = self.convert_messages_format(session.messages)
|
||||
self.update_api_key_if_expired()
|
||||
# NOTE 阿里百炼的call()函数未提供temperature参数,考虑到temperature和top_p参数作用相同,取两者较小的值作为top_p参数传入,详情见文档 https://help.aliyun.com/document_detail/2587502.htm
|
||||
response = broadscope_bailian.Completions().call(app_id=self.app_id(), prompt=prompt, history=history, top_p=min(self.temperature(), self.top_p()))
|
||||
completion_content = self.get_completion_content(response, self.node_id())
|
||||
completion_tokens, total_tokens = self.calc_tokens(session.messages, completion_content)
|
||||
return {
|
||||
"total_tokens": total_tokens,
|
||||
"completion_tokens": completion_tokens,
|
||||
"content": completion_content,
|
||||
}
|
||||
except Exception as e:
|
||||
need_retry = retry_count < 2
|
||||
result = {"completion_tokens": 0, "content": "我现在有点累了,等会再来吧"}
|
||||
if isinstance(e, openai.error.RateLimitError):
|
||||
logger.warn("[QWEN] RateLimitError: {}".format(e))
|
||||
result["content"] = "提问太快啦,请休息一下再问我吧"
|
||||
if need_retry:
|
||||
time.sleep(20)
|
||||
elif isinstance(e, openai.error.Timeout):
|
||||
logger.warn("[QWEN] Timeout: {}".format(e))
|
||||
result["content"] = "我没有收到你的消息"
|
||||
if need_retry:
|
||||
time.sleep(5)
|
||||
elif isinstance(e, openai.error.APIError):
|
||||
logger.warn("[QWEN] Bad Gateway: {}".format(e))
|
||||
result["content"] = "请再问我一次"
|
||||
if need_retry:
|
||||
time.sleep(10)
|
||||
elif isinstance(e, openai.error.APIConnectionError):
|
||||
logger.warn("[QWEN] APIConnectionError: {}".format(e))
|
||||
need_retry = False
|
||||
result["content"] = "我连接不到你的网络"
|
||||
else:
|
||||
logger.exception("[QWEN] Exception: {}".format(e))
|
||||
need_retry = False
|
||||
self.sessions.clear_session(session.session_id)
|
||||
|
||||
if need_retry:
|
||||
logger.warn("[QWEN] 第{}次重试".format(retry_count + 1))
|
||||
return self.reply_text(session, retry_count + 1)
|
||||
else:
|
||||
return result
|
||||
|
||||
def set_api_key(self):
|
||||
api_key, expired_time = self.api_key_client().create_token(agent_key=self.agent_key())
|
||||
broadscope_bailian.api_key = api_key
|
||||
return expired_time
|
||||
|
||||
def update_api_key_if_expired(self):
|
||||
if time.time() > self.api_key_expired_time:
|
||||
self.api_key_expired_time = self.set_api_key()
|
||||
|
||||
def convert_messages_format(self, messages) -> Tuple[str, List[ChatQaMessage]]:
|
||||
history = []
|
||||
user_content = ''
|
||||
assistant_content = ''
|
||||
system_content = ''
|
||||
for message in messages:
|
||||
role = message.get('role')
|
||||
if role == 'user':
|
||||
user_content += message.get('content')
|
||||
elif role == 'assistant':
|
||||
assistant_content = message.get('content')
|
||||
history.append(ChatQaMessage(user_content, assistant_content))
|
||||
user_content = ''
|
||||
assistant_content = ''
|
||||
elif role =='system':
|
||||
system_content += message.get('content')
|
||||
if user_content == '':
|
||||
raise Exception('no user message')
|
||||
if system_content != '':
|
||||
# NOTE 模拟系统消息,测试发现人格描述以"你需要扮演ChatGPT"开头能够起作用,而以"你是ChatGPT"开头模型会直接否认
|
||||
system_qa = ChatQaMessage(system_content, '好的,我会严格按照你的设定回答问题')
|
||||
history.insert(0, system_qa)
|
||||
logger.debug("[QWEN] converted qa messages: {}".format([item.to_dict() for item in history]))
|
||||
logger.debug("[QWEN] user content as prompt: {}".format(user_content))
|
||||
return user_content, history
|
||||
|
||||
def get_completion_content(self, response, node_id):
|
||||
if not response['Success']:
|
||||
return f"[ERROR]\n{response['Code']}:{response['Message']}"
|
||||
text = response['Data']['Text']
|
||||
if node_id == '':
|
||||
return text
|
||||
# TODO: 当使用流程编排创建大模型应用时,响应结构如下,最终结果在['finalResult'][node_id]['response']['text']中,暂时先这么写
|
||||
# {
|
||||
# 'Success': True,
|
||||
# 'Code': None,
|
||||
# 'Message': None,
|
||||
# 'Data': {
|
||||
# 'ResponseId': '9822f38dbacf4c9b8daf5ca03a2daf15',
|
||||
# 'SessionId': 'session_id',
|
||||
# 'Text': '{"finalResult":{"LLM_T7islK":{"params":{"modelId":"qwen-plus-v1","prompt":"${systemVars.query}${bizVars.Text}"},"response":{"text":"作为一个AI语言模型,我没有年龄,因为我没有生日。\n我只是一个程序,没有生命和身体。"}}}}',
|
||||
# 'Thoughts': [],
|
||||
# 'Debug': {},
|
||||
# 'DocReferences': []
|
||||
# },
|
||||
# 'RequestId': '8e11d31551ce4c3f83f49e6e0dd998b0',
|
||||
# 'Failed': None
|
||||
# }
|
||||
text_dict = json.loads(text)
|
||||
completion_content = text_dict['finalResult'][node_id]['response']['text']
|
||||
return completion_content
|
||||
|
||||
def calc_tokens(self, messages, completion_content):
|
||||
completion_tokens = len(completion_content)
|
||||
prompt_tokens = 0
|
||||
for message in messages:
|
||||
prompt_tokens += len(message["content"])
|
||||
return completion_tokens, prompt_tokens + completion_tokens
|
||||
62
bot/ali/ali_qwen_session.py
Normal file
62
bot/ali/ali_qwen_session.py
Normal file
@@ -0,0 +1,62 @@
|
||||
from bot.session_manager import Session
|
||||
from common.log import logger
|
||||
|
||||
"""
|
||||
e.g.
|
||||
[
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "Who won the world series in 2020?"},
|
||||
{"role": "assistant", "content": "The Los Angeles Dodgers won the World Series in 2020."},
|
||||
{"role": "user", "content": "Where was it played?"}
|
||||
]
|
||||
"""
|
||||
|
||||
class AliQwenSession(Session):
|
||||
def __init__(self, session_id, system_prompt=None, model="qianwen"):
|
||||
super().__init__(session_id, system_prompt)
|
||||
self.model = model
|
||||
self.reset()
|
||||
|
||||
def discard_exceeding(self, max_tokens, cur_tokens=None):
|
||||
precise = True
|
||||
try:
|
||||
cur_tokens = self.calc_tokens()
|
||||
except Exception as e:
|
||||
precise = False
|
||||
if cur_tokens is None:
|
||||
raise e
|
||||
logger.debug("Exception when counting tokens precisely for query: {}".format(e))
|
||||
while cur_tokens > max_tokens:
|
||||
if len(self.messages) > 2:
|
||||
self.messages.pop(1)
|
||||
elif len(self.messages) == 2 and self.messages[1]["role"] == "assistant":
|
||||
self.messages.pop(1)
|
||||
if precise:
|
||||
cur_tokens = self.calc_tokens()
|
||||
else:
|
||||
cur_tokens = cur_tokens - max_tokens
|
||||
break
|
||||
elif len(self.messages) == 2 and self.messages[1]["role"] == "user":
|
||||
logger.warn("user message exceed max_tokens. total_tokens={}".format(cur_tokens))
|
||||
break
|
||||
else:
|
||||
logger.debug("max_tokens={}, total_tokens={}, len(messages)={}".format(max_tokens, cur_tokens, len(self.messages)))
|
||||
break
|
||||
if precise:
|
||||
cur_tokens = self.calc_tokens()
|
||||
else:
|
||||
cur_tokens = cur_tokens - max_tokens
|
||||
return cur_tokens
|
||||
|
||||
def calc_tokens(self):
|
||||
return num_tokens_from_messages(self.messages, self.model)
|
||||
|
||||
def num_tokens_from_messages(messages, model):
|
||||
"""Returns the number of tokens used by a list of messages."""
|
||||
# 官方token计算规则:"对于中文文本来说,1个token通常对应一个汉字;对于英文文本来说,1个token通常对应3至4个字母或1个单词"
|
||||
# 详情请产看文档:https://help.aliyun.com/document_detail/2586397.html
|
||||
# 目前根据字符串长度粗略估计token数,不影响正常使用
|
||||
tokens = 0
|
||||
for msg in messages:
|
||||
tokens += len(msg["content"])
|
||||
return tokens
|
||||
@@ -1,6 +1,7 @@
|
||||
# encoding:utf-8
|
||||
|
||||
import requests
|
||||
|
||||
from bot.bot import Bot
|
||||
from bridge.reply import Reply, ReplyType
|
||||
|
||||
@@ -9,20 +10,27 @@ from bridge.reply import Reply, ReplyType
|
||||
class BaiduUnitBot(Bot):
|
||||
def reply(self, query, context=None):
|
||||
token = self.get_token()
|
||||
url = 'https://aip.baidubce.com/rpc/2.0/unit/service/v3/chat?access_token=' + token
|
||||
post_data = "{\"version\":\"3.0\",\"service_id\":\"S73177\",\"session_id\":\"\",\"log_id\":\"7758521\",\"skill_ids\":[\"1221886\"],\"request\":{\"terminal_id\":\"88888\",\"query\":\"" + query + "\", \"hyper_params\": {\"chat_custom_bot_profile\": 1}}}"
|
||||
url = "https://aip.baidubce.com/rpc/2.0/unit/service/v3/chat?access_token=" + token
|
||||
post_data = (
|
||||
'{"version":"3.0","service_id":"S73177","session_id":"","log_id":"7758521","skill_ids":["1221886"],"request":{"terminal_id":"88888","query":"'
|
||||
+ query
|
||||
+ '", "hyper_params": {"chat_custom_bot_profile": 1}}}'
|
||||
)
|
||||
print(post_data)
|
||||
headers = {'content-type': 'application/x-www-form-urlencoded'}
|
||||
headers = {"content-type": "application/x-www-form-urlencoded"}
|
||||
response = requests.post(url, data=post_data.encode(), headers=headers)
|
||||
if response:
|
||||
reply = Reply(ReplyType.TEXT, response.json()['result']['context']['SYS_PRESUMED_HIST'][1])
|
||||
reply = Reply(
|
||||
ReplyType.TEXT,
|
||||
response.json()["result"]["context"]["SYS_PRESUMED_HIST"][1],
|
||||
)
|
||||
return reply
|
||||
|
||||
def get_token(self):
|
||||
access_key = 'YOUR_ACCESS_KEY'
|
||||
secret_key = 'YOUR_SECRET_KEY'
|
||||
host = 'https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id=' + access_key + '&client_secret=' + secret_key
|
||||
access_key = "YOUR_ACCESS_KEY"
|
||||
secret_key = "YOUR_SECRET_KEY"
|
||||
host = "https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id=" + access_key + "&client_secret=" + secret_key
|
||||
response = requests.get(host)
|
||||
if response:
|
||||
print(response.json())
|
||||
return response.json()['access_token']
|
||||
return response.json()["access_token"]
|
||||
|
||||
107
bot/baidu/baidu_wenxin.py
Normal file
107
bot/baidu/baidu_wenxin.py
Normal file
@@ -0,0 +1,107 @@
|
||||
# encoding:utf-8
|
||||
|
||||
import requests, json
|
||||
from bot.bot import Bot
|
||||
from bot.session_manager import SessionManager
|
||||
from bridge.context import ContextType
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from common.log import logger
|
||||
from config import conf
|
||||
from bot.baidu.baidu_wenxin_session import BaiduWenxinSession
|
||||
|
||||
BAIDU_API_KEY = conf().get("baidu_wenxin_api_key")
|
||||
BAIDU_SECRET_KEY = conf().get("baidu_wenxin_secret_key")
|
||||
|
||||
class BaiduWenxinBot(Bot):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
wenxin_model = conf().get("baidu_wenxin_model") or "eb-instant"
|
||||
if conf().get("model") and conf().get("model") == "wenxin-4":
|
||||
wenxin_model = "completions_pro"
|
||||
self.sessions = SessionManager(BaiduWenxinSession, model=wenxin_model)
|
||||
|
||||
def reply(self, query, context=None):
|
||||
# acquire reply content
|
||||
if context and context.type:
|
||||
if context.type == ContextType.TEXT:
|
||||
logger.info("[BAIDU] query={}".format(query))
|
||||
session_id = context["session_id"]
|
||||
reply = None
|
||||
if query == "#清除记忆":
|
||||
self.sessions.clear_session(session_id)
|
||||
reply = Reply(ReplyType.INFO, "记忆已清除")
|
||||
elif query == "#清除所有":
|
||||
self.sessions.clear_all_session()
|
||||
reply = Reply(ReplyType.INFO, "所有人记忆已清除")
|
||||
else:
|
||||
session = self.sessions.session_query(query, session_id)
|
||||
result = self.reply_text(session)
|
||||
total_tokens, completion_tokens, reply_content = (
|
||||
result["total_tokens"],
|
||||
result["completion_tokens"],
|
||||
result["content"],
|
||||
)
|
||||
logger.debug(
|
||||
"[BAIDU] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(session.messages, session_id, reply_content, completion_tokens)
|
||||
)
|
||||
|
||||
if total_tokens == 0:
|
||||
reply = Reply(ReplyType.ERROR, reply_content)
|
||||
else:
|
||||
self.sessions.session_reply(reply_content, session_id, total_tokens)
|
||||
reply = Reply(ReplyType.TEXT, reply_content)
|
||||
return reply
|
||||
elif context.type == ContextType.IMAGE_CREATE:
|
||||
ok, retstring = self.create_img(query, 0)
|
||||
reply = None
|
||||
if ok:
|
||||
reply = Reply(ReplyType.IMAGE_URL, retstring)
|
||||
else:
|
||||
reply = Reply(ReplyType.ERROR, retstring)
|
||||
return reply
|
||||
|
||||
def reply_text(self, session: BaiduWenxinSession, retry_count=0):
|
||||
try:
|
||||
logger.info("[BAIDU] model={}".format(session.model))
|
||||
access_token = self.get_access_token()
|
||||
if access_token == 'None':
|
||||
logger.warn("[BAIDU] access token 获取失败")
|
||||
return {
|
||||
"total_tokens": 0,
|
||||
"completion_tokens": 0,
|
||||
"content": 0,
|
||||
}
|
||||
url = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/" + session.model + "?access_token=" + access_token
|
||||
headers = {
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
payload = {'messages': session.messages}
|
||||
response = requests.request("POST", url, headers=headers, data=json.dumps(payload))
|
||||
response_text = json.loads(response.text)
|
||||
logger.info(f"[BAIDU] response text={response_text}")
|
||||
res_content = response_text["result"]
|
||||
total_tokens = response_text["usage"]["total_tokens"]
|
||||
completion_tokens = response_text["usage"]["completion_tokens"]
|
||||
logger.info("[BAIDU] reply={}".format(res_content))
|
||||
return {
|
||||
"total_tokens": total_tokens,
|
||||
"completion_tokens": completion_tokens,
|
||||
"content": res_content,
|
||||
}
|
||||
except Exception as e:
|
||||
need_retry = retry_count < 2
|
||||
logger.warn("[BAIDU] Exception: {}".format(e))
|
||||
need_retry = False
|
||||
self.sessions.clear_session(session.session_id)
|
||||
result = {"completion_tokens": 0, "content": "出错了: {}".format(e)}
|
||||
return result
|
||||
|
||||
def get_access_token(self):
|
||||
"""
|
||||
使用 AK,SK 生成鉴权签名(Access Token)
|
||||
:return: access_token,或是None(如果错误)
|
||||
"""
|
||||
url = "https://aip.baidubce.com/oauth/2.0/token"
|
||||
params = {"grant_type": "client_credentials", "client_id": BAIDU_API_KEY, "client_secret": BAIDU_SECRET_KEY}
|
||||
return str(requests.post(url, params=params).json().get("access_token"))
|
||||
53
bot/baidu/baidu_wenxin_session.py
Normal file
53
bot/baidu/baidu_wenxin_session.py
Normal file
@@ -0,0 +1,53 @@
|
||||
from bot.session_manager import Session
|
||||
from common.log import logger
|
||||
|
||||
"""
|
||||
e.g. [
|
||||
{"role": "user", "content": "Who won the world series in 2020?"},
|
||||
{"role": "assistant", "content": "The Los Angeles Dodgers won the World Series in 2020."},
|
||||
{"role": "user", "content": "Where was it played?"}
|
||||
]
|
||||
"""
|
||||
|
||||
|
||||
class BaiduWenxinSession(Session):
|
||||
def __init__(self, session_id, system_prompt=None, model="gpt-3.5-turbo"):
|
||||
super().__init__(session_id, system_prompt)
|
||||
self.model = model
|
||||
# 百度文心不支持system prompt
|
||||
# self.reset()
|
||||
|
||||
def discard_exceeding(self, max_tokens, cur_tokens=None):
|
||||
precise = True
|
||||
try:
|
||||
cur_tokens = self.calc_tokens()
|
||||
except Exception as e:
|
||||
precise = False
|
||||
if cur_tokens is None:
|
||||
raise e
|
||||
logger.debug("Exception when counting tokens precisely for query: {}".format(e))
|
||||
while cur_tokens > max_tokens:
|
||||
if len(self.messages) >= 2:
|
||||
self.messages.pop(0)
|
||||
self.messages.pop(0)
|
||||
else:
|
||||
logger.debug("max_tokens={}, total_tokens={}, len(messages)={}".format(max_tokens, cur_tokens, len(self.messages)))
|
||||
break
|
||||
if precise:
|
||||
cur_tokens = self.calc_tokens()
|
||||
else:
|
||||
cur_tokens = cur_tokens - max_tokens
|
||||
return cur_tokens
|
||||
|
||||
def calc_tokens(self):
|
||||
return num_tokens_from_messages(self.messages, self.model)
|
||||
|
||||
|
||||
def num_tokens_from_messages(messages, model):
|
||||
"""Returns the number of tokens used by a list of messages."""
|
||||
tokens = 0
|
||||
for msg in messages:
|
||||
# 官方token计算规则暂不明确: "大约为 token数为 "中文字 + 其他语种单词数 x 1.3"
|
||||
# 这里先直接根据字数粗略估算吧,暂不影响正常使用,仅在判断是否丢弃历史会话的时候会有偏差
|
||||
tokens += len(msg["content"])
|
||||
return tokens
|
||||
@@ -8,7 +8,7 @@ from bridge.reply import Reply
|
||||
|
||||
|
||||
class Bot(object):
|
||||
def reply(self, query, context : Context =None) -> Reply:
|
||||
def reply(self, query, context: Context = None) -> Reply:
|
||||
"""
|
||||
bot auto-reply content
|
||||
:param req: received message
|
||||
|
||||
@@ -11,9 +11,11 @@ def create_bot(bot_type):
|
||||
:return: bot instance
|
||||
"""
|
||||
if bot_type == const.BAIDU:
|
||||
# Baidu Unit对话接口
|
||||
from bot.baidu.baidu_unit_bot import BaiduUnitBot
|
||||
return BaiduUnitBot()
|
||||
# 替换Baidu Unit为Baidu文心千帆对话接口
|
||||
# from bot.baidu.baidu_unit_bot import BaiduUnitBot
|
||||
# return BaiduUnitBot()
|
||||
from bot.baidu.baidu_wenxin import BaiduWenxinBot
|
||||
return BaiduWenxinBot()
|
||||
|
||||
elif bot_type == const.CHATGPT:
|
||||
# ChatGPT 网页端web接口
|
||||
@@ -29,4 +31,25 @@ def create_bot(bot_type):
|
||||
# Azure chatgpt service https://azure.microsoft.com/en-in/products/cognitive-services/openai-service/
|
||||
from bot.chatgpt.chat_gpt_bot import AzureChatGPTBot
|
||||
return AzureChatGPTBot()
|
||||
|
||||
elif bot_type == const.XUNFEI:
|
||||
from bot.xunfei.xunfei_spark_bot import XunFeiBot
|
||||
return XunFeiBot()
|
||||
|
||||
elif bot_type == const.LINKAI:
|
||||
from bot.linkai.link_ai_bot import LinkAIBot
|
||||
return LinkAIBot()
|
||||
|
||||
elif bot_type == const.CLAUDEAI:
|
||||
from bot.claude.claude_ai_bot import ClaudeAIBot
|
||||
return ClaudeAIBot()
|
||||
|
||||
elif bot_type == const.QWEN:
|
||||
from bot.ali.ali_qwen_bot import AliQwenBot
|
||||
return AliQwenBot()
|
||||
|
||||
elif bot_type == const.GEMINI:
|
||||
from bot.gemini.google_gemini_bot import GoogleGeminiBot
|
||||
return GoogleGeminiBot()
|
||||
|
||||
raise RuntimeError
|
||||
|
||||
@@ -1,68 +1,96 @@
|
||||
# encoding:utf-8
|
||||
|
||||
import time
|
||||
|
||||
import openai
|
||||
import openai.error
|
||||
import requests
|
||||
|
||||
from bot.bot import Bot
|
||||
from bot.chatgpt.chat_gpt_session import ChatGPTSession
|
||||
from bot.openai.open_ai_image import OpenAIImage
|
||||
from bot.session_manager import Session, SessionManager
|
||||
from bot.session_manager import SessionManager
|
||||
from bridge.context import ContextType
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from config import conf, load_config
|
||||
from common.log import logger
|
||||
from common.token_bucket import TokenBucket
|
||||
from common.expired_dict import ExpiredDict
|
||||
import openai
|
||||
import openai.error
|
||||
import time
|
||||
from config import conf, load_config
|
||||
|
||||
|
||||
# OpenAI对话模型API (可用)
|
||||
class ChatGPTBot(Bot,OpenAIImage):
|
||||
class ChatGPTBot(Bot, OpenAIImage):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
openai.api_key = conf().get('open_ai_api_key')
|
||||
if conf().get('open_ai_api_base'):
|
||||
openai.api_base = conf().get('open_ai_api_base')
|
||||
proxy = conf().get('proxy')
|
||||
# set the default api_key
|
||||
openai.api_key = conf().get("open_ai_api_key")
|
||||
if conf().get("open_ai_api_base"):
|
||||
openai.api_base = conf().get("open_ai_api_base")
|
||||
proxy = conf().get("proxy")
|
||||
if proxy:
|
||||
openai.proxy = proxy
|
||||
if conf().get('rate_limit_chatgpt'):
|
||||
self.tb4chatgpt = TokenBucket(conf().get('rate_limit_chatgpt', 20))
|
||||
|
||||
self.sessions = SessionManager(ChatGPTSession, model= conf().get("model") or "gpt-3.5-turbo")
|
||||
if conf().get("rate_limit_chatgpt"):
|
||||
self.tb4chatgpt = TokenBucket(conf().get("rate_limit_chatgpt", 20))
|
||||
|
||||
self.sessions = SessionManager(ChatGPTSession, model=conf().get("model") or "gpt-3.5-turbo")
|
||||
self.args = {
|
||||
"model": conf().get("model") or "gpt-3.5-turbo", # 对话模型的名称
|
||||
"temperature": conf().get("temperature", 0.9), # 值在[0,1]之间,越大表示回复越具有不确定性
|
||||
# "max_tokens":4096, # 回复最大的字符数
|
||||
"top_p": conf().get("top_p", 1),
|
||||
"frequency_penalty": conf().get("frequency_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
|
||||
"presence_penalty": conf().get("presence_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
|
||||
"request_timeout": conf().get("request_timeout", None), # 请求超时时间,openai接口默认设置为600,对于难问题一般需要较长时间
|
||||
"timeout": conf().get("request_timeout", None), # 重试超时时间,在这个时间内,将会自动重试
|
||||
}
|
||||
|
||||
def reply(self, query, context=None):
|
||||
# acquire reply content
|
||||
if context.type == ContextType.TEXT:
|
||||
logger.info("[CHATGPT] query={}".format(query))
|
||||
|
||||
session_id = context['session_id']
|
||||
session_id = context["session_id"]
|
||||
reply = None
|
||||
clear_memory_commands = conf().get('clear_memory_commands', ['#清除记忆'])
|
||||
clear_memory_commands = conf().get("clear_memory_commands", ["#清除记忆"])
|
||||
if query in clear_memory_commands:
|
||||
self.sessions.clear_session(session_id)
|
||||
reply = Reply(ReplyType.INFO, '记忆已清除')
|
||||
elif query == '#清除所有':
|
||||
reply = Reply(ReplyType.INFO, "记忆已清除")
|
||||
elif query == "#清除所有":
|
||||
self.sessions.clear_all_session()
|
||||
reply = Reply(ReplyType.INFO, '所有人记忆已清除')
|
||||
elif query == '#更新配置':
|
||||
reply = Reply(ReplyType.INFO, "所有人记忆已清除")
|
||||
elif query == "#更新配置":
|
||||
load_config()
|
||||
reply = Reply(ReplyType.INFO, '配置已更新')
|
||||
reply = Reply(ReplyType.INFO, "配置已更新")
|
||||
if reply:
|
||||
return reply
|
||||
session = self.sessions.session_query(query, session_id)
|
||||
logger.debug("[CHATGPT] session query={}".format(session.messages))
|
||||
|
||||
api_key = context.get("openai_api_key")
|
||||
model = context.get("gpt_model")
|
||||
new_args = None
|
||||
if model:
|
||||
new_args = self.args.copy()
|
||||
new_args["model"] = model
|
||||
# if context.get('stream'):
|
||||
# # reply in stream
|
||||
# return self.reply_text_stream(query, new_query, session_id)
|
||||
|
||||
reply_content = self.reply_text(session, session_id, 0)
|
||||
logger.debug("[CHATGPT] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(session.messages, session_id, reply_content["content"], reply_content["completion_tokens"]))
|
||||
if reply_content['completion_tokens'] == 0 and len(reply_content['content']) > 0:
|
||||
reply = Reply(ReplyType.ERROR, reply_content['content'])
|
||||
reply_content = self.reply_text(session, api_key, args=new_args)
|
||||
logger.debug(
|
||||
"[CHATGPT] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(
|
||||
session.messages,
|
||||
session_id,
|
||||
reply_content["content"],
|
||||
reply_content["completion_tokens"],
|
||||
)
|
||||
)
|
||||
if reply_content["completion_tokens"] == 0 and len(reply_content["content"]) > 0:
|
||||
reply = Reply(ReplyType.ERROR, reply_content["content"])
|
||||
elif reply_content["completion_tokens"] > 0:
|
||||
self.sessions.session_reply(reply_content["content"], session_id, reply_content["total_tokens"])
|
||||
reply = Reply(ReplyType.TEXT, reply_content["content"])
|
||||
else:
|
||||
reply = Reply(ReplyType.ERROR, reply_content['content'])
|
||||
reply = Reply(ReplyType.ERROR, reply_content["content"])
|
||||
logger.debug("[CHATGPT] reply {} used 0 tokens.".format(reply_content))
|
||||
return reply
|
||||
|
||||
@@ -75,62 +103,62 @@ class ChatGPTBot(Bot,OpenAIImage):
|
||||
reply = Reply(ReplyType.ERROR, retstring)
|
||||
return reply
|
||||
else:
|
||||
reply = Reply(ReplyType.ERROR, 'Bot不支持处理{}类型的消息'.format(context.type))
|
||||
reply = Reply(ReplyType.ERROR, "Bot不支持处理{}类型的消息".format(context.type))
|
||||
return reply
|
||||
|
||||
def compose_args(self):
|
||||
return {
|
||||
"model": conf().get("model") or "gpt-3.5-turbo", # 对话模型的名称
|
||||
"temperature":conf().get('temperature', 0.9), # 值在[0,1]之间,越大表示回复越具有不确定性
|
||||
# "max_tokens":4096, # 回复最大的字符数
|
||||
"top_p":1,
|
||||
"frequency_penalty":conf().get('frequency_penalty', 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
|
||||
"presence_penalty":conf().get('presence_penalty', 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
|
||||
}
|
||||
|
||||
def reply_text(self, session:ChatGPTSession, session_id, retry_count=0) -> dict:
|
||||
'''
|
||||
def reply_text(self, session: ChatGPTSession, api_key=None, args=None, retry_count=0) -> dict:
|
||||
"""
|
||||
call openai's ChatCompletion to get the answer
|
||||
:param session: a conversation session
|
||||
:param session_id: session id
|
||||
:param retry_count: retry count
|
||||
:return: {}
|
||||
'''
|
||||
"""
|
||||
try:
|
||||
if conf().get('rate_limit_chatgpt') and not self.tb4chatgpt.get_token():
|
||||
if conf().get("rate_limit_chatgpt") and not self.tb4chatgpt.get_token():
|
||||
raise openai.error.RateLimitError("RateLimitError: rate limit exceeded")
|
||||
response = openai.ChatCompletion.create(
|
||||
messages=session.messages, **self.compose_args()
|
||||
)
|
||||
# if api_key == None, the default openai.api_key will be used
|
||||
if args is None:
|
||||
args = self.args
|
||||
response = openai.ChatCompletion.create(api_key=api_key, messages=session.messages, **args)
|
||||
# logger.debug("[CHATGPT] response={}".format(response))
|
||||
# logger.info("[ChatGPT] reply={}, total_tokens={}".format(response.choices[0]['message']['content'], response["usage"]["total_tokens"]))
|
||||
return {"total_tokens": response["usage"]["total_tokens"],
|
||||
"completion_tokens": response["usage"]["completion_tokens"],
|
||||
"content": response.choices[0]['message']['content']}
|
||||
return {
|
||||
"total_tokens": response["usage"]["total_tokens"],
|
||||
"completion_tokens": response["usage"]["completion_tokens"],
|
||||
"content": response.choices[0]["message"]["content"],
|
||||
}
|
||||
except Exception as e:
|
||||
need_retry = retry_count < 2
|
||||
result = {"completion_tokens": 0, "content": "我现在有点累了,等会再来吧"}
|
||||
if isinstance(e, openai.error.RateLimitError):
|
||||
logger.warn("[CHATGPT] RateLimitError: {}".format(e))
|
||||
result['content'] = "提问太快啦,请休息一下再问我吧"
|
||||
result["content"] = "提问太快啦,请休息一下再问我吧"
|
||||
if need_retry:
|
||||
time.sleep(5)
|
||||
time.sleep(20)
|
||||
elif isinstance(e, openai.error.Timeout):
|
||||
logger.warn("[CHATGPT] Timeout: {}".format(e))
|
||||
result['content'] = "我没有收到你的消息"
|
||||
result["content"] = "我没有收到你的消息"
|
||||
if need_retry:
|
||||
time.sleep(5)
|
||||
elif isinstance(e, openai.error.APIError):
|
||||
logger.warn("[CHATGPT] Bad Gateway: {}".format(e))
|
||||
result["content"] = "请再问我一次"
|
||||
if need_retry:
|
||||
time.sleep(10)
|
||||
elif isinstance(e, openai.error.APIConnectionError):
|
||||
logger.warn("[CHATGPT] APIConnectionError: {}".format(e))
|
||||
need_retry = False
|
||||
result['content'] = "我连接不到你的网络"
|
||||
result["content"] = "我连接不到你的网络"
|
||||
if need_retry:
|
||||
time.sleep(5)
|
||||
else:
|
||||
logger.warn("[CHATGPT] Exception: {}".format(e))
|
||||
logger.exception("[CHATGPT] Exception: {}".format(e))
|
||||
need_retry = False
|
||||
self.sessions.clear_session(session_id)
|
||||
self.sessions.clear_session(session.session_id)
|
||||
|
||||
if need_retry:
|
||||
logger.warn("[CHATGPT] 第{}次重试".format(retry_count+1))
|
||||
return self.reply_text(session, session_id, retry_count+1)
|
||||
logger.warn("[CHATGPT] 第{}次重试".format(retry_count + 1))
|
||||
return self.reply_text(session, api_key, args, retry_count + 1)
|
||||
else:
|
||||
return result
|
||||
|
||||
@@ -139,10 +167,28 @@ class AzureChatGPTBot(ChatGPTBot):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
openai.api_type = "azure"
|
||||
openai.api_version = "2023-03-15-preview"
|
||||
openai.api_version = conf().get("azure_api_version", "2023-06-01-preview")
|
||||
self.args["deployment_id"] = conf().get("azure_deployment_id")
|
||||
|
||||
def compose_args(self):
|
||||
args = super().compose_args()
|
||||
args["engine"] = args["model"]
|
||||
del(args["model"])
|
||||
return args
|
||||
def create_img(self, query, retry_count=0, api_key=None):
|
||||
api_version = "2022-08-03-preview"
|
||||
url = "{}dalle/text-to-image?api-version={}".format(openai.api_base, api_version)
|
||||
api_key = api_key or openai.api_key
|
||||
headers = {"api-key": api_key, "Content-Type": "application/json"}
|
||||
try:
|
||||
body = {"caption": query, "resolution": conf().get("image_create_size", "256x256")}
|
||||
submission = requests.post(url, headers=headers, json=body)
|
||||
operation_location = submission.headers["Operation-Location"]
|
||||
retry_after = submission.headers["Retry-after"]
|
||||
status = ""
|
||||
image_url = ""
|
||||
while status != "Succeeded":
|
||||
logger.info("waiting for image create..., " + status + ",retry after " + retry_after + " seconds")
|
||||
time.sleep(int(retry_after))
|
||||
response = requests.get(operation_location, headers=headers)
|
||||
status = response.json()["status"]
|
||||
image_url = response.json()["result"]["contentUrl"]
|
||||
return True, image_url
|
||||
except Exception as e:
|
||||
logger.error("create image error: {}".format(e))
|
||||
return False, "图片生成失败"
|
||||
|
||||
@@ -1,23 +1,27 @@
|
||||
from bot.session_manager import Session
|
||||
from common.log import logger
|
||||
'''
|
||||
from common import const
|
||||
|
||||
"""
|
||||
e.g. [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "Who won the world series in 2020?"},
|
||||
{"role": "assistant", "content": "The Los Angeles Dodgers won the World Series in 2020."},
|
||||
{"role": "user", "content": "Where was it played?"}
|
||||
]
|
||||
'''
|
||||
"""
|
||||
|
||||
|
||||
class ChatGPTSession(Session):
|
||||
def __init__(self, session_id, system_prompt=None, model= "gpt-3.5-turbo"):
|
||||
def __init__(self, session_id, system_prompt=None, model="gpt-3.5-turbo"):
|
||||
super().__init__(session_id, system_prompt)
|
||||
self.model = model
|
||||
self.reset()
|
||||
|
||||
def discard_exceeding(self, max_tokens, cur_tokens= None):
|
||||
|
||||
def discard_exceeding(self, max_tokens, cur_tokens=None):
|
||||
precise = True
|
||||
try:
|
||||
cur_tokens = num_tokens_from_messages(self.messages, self.model)
|
||||
cur_tokens = self.calc_tokens()
|
||||
except Exception as e:
|
||||
precise = False
|
||||
if cur_tokens is None:
|
||||
@@ -29,7 +33,7 @@ class ChatGPTSession(Session):
|
||||
elif len(self.messages) == 2 and self.messages[1]["role"] == "assistant":
|
||||
self.messages.pop(1)
|
||||
if precise:
|
||||
cur_tokens = num_tokens_from_messages(self.messages, self.model)
|
||||
cur_tokens = self.calc_tokens()
|
||||
else:
|
||||
cur_tokens = cur_tokens - max_tokens
|
||||
break
|
||||
@@ -40,34 +44,44 @@ class ChatGPTSession(Session):
|
||||
logger.debug("max_tokens={}, total_tokens={}, len(messages)={}".format(max_tokens, cur_tokens, len(self.messages)))
|
||||
break
|
||||
if precise:
|
||||
cur_tokens = num_tokens_from_messages(self.messages, self.model)
|
||||
cur_tokens = self.calc_tokens()
|
||||
else:
|
||||
cur_tokens = cur_tokens - max_tokens
|
||||
return cur_tokens
|
||||
|
||||
|
||||
def calc_tokens(self):
|
||||
return num_tokens_from_messages(self.messages, self.model)
|
||||
|
||||
|
||||
# refer to https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
|
||||
def num_tokens_from_messages(messages, model):
|
||||
"""Returns the number of tokens used by a list of messages."""
|
||||
|
||||
if model in ["wenxin", "xunfei", const.GEMINI]:
|
||||
return num_tokens_by_character(messages)
|
||||
|
||||
import tiktoken
|
||||
|
||||
if model in ["gpt-3.5-turbo-0301", "gpt-35-turbo", "gpt-3.5-turbo-1106"]:
|
||||
return num_tokens_from_messages(messages, model="gpt-3.5-turbo")
|
||||
elif model in ["gpt-4-0314", "gpt-4-0613", "gpt-4-32k", "gpt-4-32k-0613", "gpt-3.5-turbo-0613",
|
||||
"gpt-3.5-turbo-16k", "gpt-3.5-turbo-16k-0613", "gpt-35-turbo-16k", const.GPT4_TURBO_PREVIEW, const.GPT4_VISION_PREVIEW]:
|
||||
return num_tokens_from_messages(messages, model="gpt-4")
|
||||
|
||||
try:
|
||||
encoding = tiktoken.encoding_for_model(model)
|
||||
except KeyError:
|
||||
logger.debug("Warning: model not found. Using cl100k_base encoding.")
|
||||
encoding = tiktoken.get_encoding("cl100k_base")
|
||||
if model == "gpt-3.5-turbo":
|
||||
return num_tokens_from_messages(messages, model="gpt-3.5-turbo-0301")
|
||||
elif model == "gpt-4":
|
||||
return num_tokens_from_messages(messages, model="gpt-4-0314")
|
||||
elif model == "gpt-3.5-turbo-0301":
|
||||
tokens_per_message = 4 # every message follows <|start|>{role/name}\n{content}<|end|>\n
|
||||
tokens_per_name = -1 # if there's a name, the role is omitted
|
||||
elif model == "gpt-4-0314":
|
||||
elif model == "gpt-4":
|
||||
tokens_per_message = 3
|
||||
tokens_per_name = 1
|
||||
else:
|
||||
logger.warn(f"num_tokens_from_messages() is not implemented for model {model}. Returning num tokens assuming gpt-3.5-turbo-0301.")
|
||||
return num_tokens_from_messages(messages, model="gpt-3.5-turbo-0301")
|
||||
logger.warn(f"num_tokens_from_messages() is not implemented for model {model}. Returning num tokens assuming gpt-3.5-turbo.")
|
||||
return num_tokens_from_messages(messages, model="gpt-3.5-turbo")
|
||||
num_tokens = 0
|
||||
for message in messages:
|
||||
num_tokens += tokens_per_message
|
||||
@@ -76,4 +90,12 @@ def num_tokens_from_messages(messages, model):
|
||||
if key == "name":
|
||||
num_tokens += tokens_per_name
|
||||
num_tokens += 3 # every reply is primed with <|start|>assistant<|message|>
|
||||
return num_tokens
|
||||
return num_tokens
|
||||
|
||||
|
||||
def num_tokens_by_character(messages):
|
||||
"""Returns the number of tokens used by a list of messages."""
|
||||
tokens = 0
|
||||
for msg in messages:
|
||||
tokens += len(msg["content"])
|
||||
return tokens
|
||||
|
||||
222
bot/claude/claude_ai_bot.py
Normal file
222
bot/claude/claude_ai_bot.py
Normal file
@@ -0,0 +1,222 @@
|
||||
import re
|
||||
import time
|
||||
import json
|
||||
import uuid
|
||||
from curl_cffi import requests
|
||||
from bot.bot import Bot
|
||||
from bot.claude.claude_ai_session import ClaudeAiSession
|
||||
from bot.openai.open_ai_image import OpenAIImage
|
||||
from bot.session_manager import SessionManager
|
||||
from bridge.context import Context, ContextType
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from common.log import logger
|
||||
from config import conf
|
||||
|
||||
|
||||
class ClaudeAIBot(Bot, OpenAIImage):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.sessions = SessionManager(ClaudeAiSession, model=conf().get("model") or "gpt-3.5-turbo")
|
||||
self.claude_api_cookie = conf().get("claude_api_cookie")
|
||||
self.proxy = conf().get("proxy")
|
||||
self.con_uuid_dic = {}
|
||||
if self.proxy:
|
||||
self.proxies = {
|
||||
"http": self.proxy,
|
||||
"https": self.proxy
|
||||
}
|
||||
else:
|
||||
self.proxies = None
|
||||
self.error = ""
|
||||
self.org_uuid = self.get_organization_id()
|
||||
|
||||
def generate_uuid(self):
|
||||
random_uuid = uuid.uuid4()
|
||||
random_uuid_str = str(random_uuid)
|
||||
formatted_uuid = f"{random_uuid_str[0:8]}-{random_uuid_str[9:13]}-{random_uuid_str[14:18]}-{random_uuid_str[19:23]}-{random_uuid_str[24:]}"
|
||||
return formatted_uuid
|
||||
|
||||
def reply(self, query, context: Context = None) -> Reply:
|
||||
if context.type == ContextType.TEXT:
|
||||
return self._chat(query, context)
|
||||
elif context.type == ContextType.IMAGE_CREATE:
|
||||
ok, res = self.create_img(query, 0)
|
||||
if ok:
|
||||
reply = Reply(ReplyType.IMAGE_URL, res)
|
||||
else:
|
||||
reply = Reply(ReplyType.ERROR, res)
|
||||
return reply
|
||||
else:
|
||||
reply = Reply(ReplyType.ERROR, "Bot不支持处理{}类型的消息".format(context.type))
|
||||
return reply
|
||||
|
||||
def get_organization_id(self):
|
||||
url = "https://claude.ai/api/organizations"
|
||||
headers = {
|
||||
'User-Agent':
|
||||
'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:109.0) Gecko/20100101 Firefox/115.0',
|
||||
'Accept-Language': 'en-US,en;q=0.5',
|
||||
'Referer': 'https://claude.ai/chats',
|
||||
'Content-Type': 'application/json',
|
||||
'Sec-Fetch-Dest': 'empty',
|
||||
'Sec-Fetch-Mode': 'cors',
|
||||
'Sec-Fetch-Site': 'same-origin',
|
||||
'Connection': 'keep-alive',
|
||||
'Cookie': f'{self.claude_api_cookie}'
|
||||
}
|
||||
try:
|
||||
response = requests.get(url, headers=headers, impersonate="chrome110", proxies =self.proxies, timeout=400)
|
||||
res = json.loads(response.text)
|
||||
uuid = res[0]['uuid']
|
||||
except:
|
||||
if "App unavailable" in response.text:
|
||||
logger.error("IP error: The IP is not allowed to be used on Claude")
|
||||
self.error = "ip所在地区不被claude支持"
|
||||
elif "Invalid authorization" in response.text:
|
||||
logger.error("Cookie error: Invalid authorization of claude, check cookie please.")
|
||||
self.error = "无法通过claude身份验证,请检查cookie"
|
||||
return None
|
||||
return uuid
|
||||
|
||||
def conversation_share_check(self,session_id):
|
||||
if conf().get("claude_uuid") is not None and conf().get("claude_uuid") != "":
|
||||
con_uuid = conf().get("claude_uuid")
|
||||
return con_uuid
|
||||
if session_id not in self.con_uuid_dic:
|
||||
self.con_uuid_dic[session_id] = self.generate_uuid()
|
||||
self.create_new_chat(self.con_uuid_dic[session_id])
|
||||
return self.con_uuid_dic[session_id]
|
||||
|
||||
def check_cookie(self):
|
||||
flag = self.get_organization_id()
|
||||
return flag
|
||||
|
||||
def create_new_chat(self, con_uuid):
|
||||
"""
|
||||
新建claude对话实体
|
||||
:param con_uuid: 对话id
|
||||
:return:
|
||||
"""
|
||||
url = f"https://claude.ai/api/organizations/{self.org_uuid}/chat_conversations"
|
||||
payload = json.dumps({"uuid": con_uuid, "name": ""})
|
||||
headers = {
|
||||
'User-Agent':
|
||||
'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:109.0) Gecko/20100101 Firefox/115.0',
|
||||
'Accept-Language': 'en-US,en;q=0.5',
|
||||
'Referer': 'https://claude.ai/chats',
|
||||
'Content-Type': 'application/json',
|
||||
'Origin': 'https://claude.ai',
|
||||
'DNT': '1',
|
||||
'Connection': 'keep-alive',
|
||||
'Cookie': self.claude_api_cookie,
|
||||
'Sec-Fetch-Dest': 'empty',
|
||||
'Sec-Fetch-Mode': 'cors',
|
||||
'Sec-Fetch-Site': 'same-origin',
|
||||
'TE': 'trailers'
|
||||
}
|
||||
response = requests.post(url, headers=headers, data=payload, impersonate="chrome110", proxies=self.proxies, timeout=400)
|
||||
# Returns JSON of the newly created conversation information
|
||||
return response.json()
|
||||
|
||||
def _chat(self, query, context, retry_count=0) -> Reply:
|
||||
"""
|
||||
发起对话请求
|
||||
:param query: 请求提示词
|
||||
:param context: 对话上下文
|
||||
:param retry_count: 当前递归重试次数
|
||||
:return: 回复
|
||||
"""
|
||||
if retry_count >= 2:
|
||||
# exit from retry 2 times
|
||||
logger.warn("[CLAUDEAI] failed after maximum number of retry times")
|
||||
return Reply(ReplyType.ERROR, "请再问我一次吧")
|
||||
|
||||
try:
|
||||
session_id = context["session_id"]
|
||||
if self.org_uuid is None:
|
||||
return Reply(ReplyType.ERROR, self.error)
|
||||
|
||||
session = self.sessions.session_query(query, session_id)
|
||||
con_uuid = self.conversation_share_check(session_id)
|
||||
|
||||
model = conf().get("model") or "gpt-3.5-turbo"
|
||||
# remove system message
|
||||
if session.messages[0].get("role") == "system":
|
||||
if model == "wenxin" or model == "claude":
|
||||
session.messages.pop(0)
|
||||
logger.info(f"[CLAUDEAI] query={query}")
|
||||
|
||||
# do http request
|
||||
base_url = "https://claude.ai"
|
||||
payload = json.dumps({
|
||||
"completion": {
|
||||
"prompt": f"{query}",
|
||||
"timezone": "Asia/Kolkata",
|
||||
"model": "claude-2"
|
||||
},
|
||||
"organization_uuid": f"{self.org_uuid}",
|
||||
"conversation_uuid": f"{con_uuid}",
|
||||
"text": f"{query}",
|
||||
"attachments": []
|
||||
})
|
||||
headers = {
|
||||
'User-Agent':
|
||||
'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:109.0) Gecko/20100101 Firefox/115.0',
|
||||
'Accept': 'text/event-stream, text/event-stream',
|
||||
'Accept-Language': 'en-US,en;q=0.5',
|
||||
'Referer': 'https://claude.ai/chats',
|
||||
'Content-Type': 'application/json',
|
||||
'Origin': 'https://claude.ai',
|
||||
'DNT': '1',
|
||||
'Connection': 'keep-alive',
|
||||
'Cookie': f'{self.claude_api_cookie}',
|
||||
'Sec-Fetch-Dest': 'empty',
|
||||
'Sec-Fetch-Mode': 'cors',
|
||||
'Sec-Fetch-Site': 'same-origin',
|
||||
'TE': 'trailers'
|
||||
}
|
||||
|
||||
res = requests.post(base_url + "/api/append_message", headers=headers, data=payload,impersonate="chrome110",proxies= self.proxies,timeout=400)
|
||||
if res.status_code == 200 or "pemission" in res.text:
|
||||
# execute success
|
||||
decoded_data = res.content.decode("utf-8")
|
||||
decoded_data = re.sub('\n+', '\n', decoded_data).strip()
|
||||
data_strings = decoded_data.split('\n')
|
||||
completions = []
|
||||
for data_string in data_strings:
|
||||
json_str = data_string[6:].strip()
|
||||
data = json.loads(json_str)
|
||||
if 'completion' in data:
|
||||
completions.append(data['completion'])
|
||||
|
||||
reply_content = ''.join(completions)
|
||||
|
||||
if "rate limi" in reply_content:
|
||||
logger.error("rate limit error: The conversation has reached the system speed limit and is synchronized with Cladue. Please go to the official website to check the lifting time")
|
||||
return Reply(ReplyType.ERROR, "对话达到系统速率限制,与cladue同步,请进入官网查看解除限制时间")
|
||||
logger.info(f"[CLAUDE] reply={reply_content}, total_tokens=invisible")
|
||||
self.sessions.session_reply(reply_content, session_id, 100)
|
||||
return Reply(ReplyType.TEXT, reply_content)
|
||||
else:
|
||||
flag = self.check_cookie()
|
||||
if flag == None:
|
||||
return Reply(ReplyType.ERROR, self.error)
|
||||
|
||||
response = res.json()
|
||||
error = response.get("error")
|
||||
logger.error(f"[CLAUDE] chat failed, status_code={res.status_code}, "
|
||||
f"msg={error.get('message')}, type={error.get('type')}, detail: {res.text}, uuid: {con_uuid}")
|
||||
|
||||
if res.status_code >= 500:
|
||||
# server error, need retry
|
||||
time.sleep(2)
|
||||
logger.warn(f"[CLAUDE] do retry, times={retry_count}")
|
||||
return self._chat(query, context, retry_count + 1)
|
||||
return Reply(ReplyType.ERROR, "提问太快啦,请休息一下再问我吧")
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
# retry
|
||||
time.sleep(2)
|
||||
logger.warn(f"[CLAUDE] do retry, times={retry_count}")
|
||||
return self._chat(query, context, retry_count + 1)
|
||||
9
bot/claude/claude_ai_session.py
Normal file
9
bot/claude/claude_ai_session.py
Normal file
@@ -0,0 +1,9 @@
|
||||
from bot.session_manager import Session
|
||||
|
||||
|
||||
class ClaudeAiSession(Session):
|
||||
def __init__(self, session_id, system_prompt=None, model="claude"):
|
||||
super().__init__(session_id, system_prompt)
|
||||
self.model = model
|
||||
# claude逆向不支持role prompt
|
||||
# self.reset()
|
||||
75
bot/gemini/google_gemini_bot.py
Normal file
75
bot/gemini/google_gemini_bot.py
Normal file
@@ -0,0 +1,75 @@
|
||||
"""
|
||||
Google gemini bot
|
||||
|
||||
@author zhayujie
|
||||
@Date 2023/12/15
|
||||
"""
|
||||
# encoding:utf-8
|
||||
|
||||
from bot.bot import Bot
|
||||
import google.generativeai as genai
|
||||
from bot.session_manager import SessionManager
|
||||
from bridge.context import ContextType, Context
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from common.log import logger
|
||||
from config import conf
|
||||
from bot.baidu.baidu_wenxin_session import BaiduWenxinSession
|
||||
|
||||
|
||||
# OpenAI对话模型API (可用)
|
||||
class GoogleGeminiBot(Bot):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.api_key = conf().get("gemini_api_key")
|
||||
# 复用文心的token计算方式
|
||||
self.sessions = SessionManager(BaiduWenxinSession, model=conf().get("model") or "gpt-3.5-turbo")
|
||||
|
||||
def reply(self, query, context: Context = None) -> Reply:
|
||||
try:
|
||||
if context.type != ContextType.TEXT:
|
||||
logger.warn(f"[Gemini] Unsupported message type, type={context.type}")
|
||||
return Reply(ReplyType.TEXT, None)
|
||||
logger.info(f"[Gemini] query={query}")
|
||||
session_id = context["session_id"]
|
||||
session = self.sessions.session_query(query, session_id)
|
||||
gemini_messages = self._convert_to_gemini_messages(self._filter_messages(session.messages))
|
||||
genai.configure(api_key=self.api_key)
|
||||
model = genai.GenerativeModel('gemini-pro')
|
||||
response = model.generate_content(gemini_messages)
|
||||
reply_text = response.text
|
||||
self.sessions.session_reply(reply_text, session_id)
|
||||
logger.info(f"[Gemini] reply={reply_text}")
|
||||
return Reply(ReplyType.TEXT, reply_text)
|
||||
except Exception as e:
|
||||
logger.error("[Gemini] fetch reply error, may contain unsafe content")
|
||||
logger.error(e)
|
||||
|
||||
def _convert_to_gemini_messages(self, messages: list):
|
||||
res = []
|
||||
for msg in messages:
|
||||
if msg.get("role") == "user":
|
||||
role = "user"
|
||||
elif msg.get("role") == "assistant":
|
||||
role = "model"
|
||||
else:
|
||||
continue
|
||||
res.append({
|
||||
"role": role,
|
||||
"parts": [{"text": msg.get("content")}]
|
||||
})
|
||||
return res
|
||||
|
||||
def _filter_messages(self, messages: list):
|
||||
res = []
|
||||
turn = "user"
|
||||
for i in range(len(messages) - 1, -1, -1):
|
||||
message = messages[i]
|
||||
if message.get("role") != turn:
|
||||
continue
|
||||
res.insert(0, message)
|
||||
if turn == "user":
|
||||
turn = "assistant"
|
||||
elif turn == "assistant":
|
||||
turn = "user"
|
||||
return res
|
||||
428
bot/linkai/link_ai_bot.py
Normal file
428
bot/linkai/link_ai_bot.py
Normal file
@@ -0,0 +1,428 @@
|
||||
# access LinkAI knowledge base platform
|
||||
# docs: https://link-ai.tech/platform/link-app/wechat
|
||||
|
||||
import re
|
||||
import time
|
||||
import requests
|
||||
import config
|
||||
from bot.bot import Bot
|
||||
from bot.chatgpt.chat_gpt_session import ChatGPTSession
|
||||
from bot.session_manager import SessionManager
|
||||
from bridge.context import Context, ContextType
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from common.log import logger
|
||||
from config import conf, pconf
|
||||
import threading
|
||||
from common import memory, utils
|
||||
import base64
|
||||
|
||||
class LinkAIBot(Bot):
|
||||
# authentication failed
|
||||
AUTH_FAILED_CODE = 401
|
||||
NO_QUOTA_CODE = 406
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.sessions = LinkAISessionManager(LinkAISession, model=conf().get("model") or "gpt-3.5-turbo")
|
||||
self.args = {}
|
||||
|
||||
def reply(self, query, context: Context = None) -> Reply:
|
||||
if context.type == ContextType.TEXT:
|
||||
return self._chat(query, context)
|
||||
elif context.type == ContextType.IMAGE_CREATE:
|
||||
if not conf().get("text_to_image"):
|
||||
logger.warn("[LinkAI] text_to_image is not enabled, ignore the IMAGE_CREATE request")
|
||||
return Reply(ReplyType.TEXT, "")
|
||||
ok, res = self.create_img(query, 0)
|
||||
if ok:
|
||||
reply = Reply(ReplyType.IMAGE_URL, res)
|
||||
else:
|
||||
reply = Reply(ReplyType.ERROR, res)
|
||||
return reply
|
||||
else:
|
||||
reply = Reply(ReplyType.ERROR, "Bot不支持处理{}类型的消息".format(context.type))
|
||||
return reply
|
||||
|
||||
def _chat(self, query, context, retry_count=0) -> Reply:
|
||||
"""
|
||||
发起对话请求
|
||||
:param query: 请求提示词
|
||||
:param context: 对话上下文
|
||||
:param retry_count: 当前递归重试次数
|
||||
:return: 回复
|
||||
"""
|
||||
if retry_count > 2:
|
||||
# exit from retry 2 times
|
||||
logger.warn("[LINKAI] failed after maximum number of retry times")
|
||||
return Reply(ReplyType.TEXT, "请再问我一次吧")
|
||||
|
||||
try:
|
||||
# load config
|
||||
if context.get("generate_breaked_by"):
|
||||
logger.info(f"[LINKAI] won't set appcode because a plugin ({context['generate_breaked_by']}) affected the context")
|
||||
app_code = None
|
||||
else:
|
||||
plugin_app_code = self._find_group_mapping_code(context)
|
||||
app_code = context.kwargs.get("app_code") or plugin_app_code or conf().get("linkai_app_code")
|
||||
linkai_api_key = conf().get("linkai_api_key")
|
||||
|
||||
session_id = context["session_id"]
|
||||
session_message = self.sessions.session_msg_query(query, session_id)
|
||||
logger.debug(f"[LinkAI] session={session_message}, session_id={session_id}")
|
||||
|
||||
# image process
|
||||
img_cache = memory.USER_IMAGE_CACHE.get(session_id)
|
||||
if img_cache:
|
||||
messages = self._process_image_msg(app_code=app_code, session_id=session_id, query=query, img_cache=img_cache)
|
||||
if messages:
|
||||
session_message = messages
|
||||
|
||||
model = conf().get("model")
|
||||
# remove system message
|
||||
if session_message[0].get("role") == "system":
|
||||
if app_code or model == "wenxin":
|
||||
session_message.pop(0)
|
||||
body = {
|
||||
"app_code": app_code,
|
||||
"messages": session_message,
|
||||
"model": model, # 对话模型的名称, 支持 gpt-3.5-turbo, gpt-3.5-turbo-16k, gpt-4, wenxin, xunfei
|
||||
"temperature": conf().get("temperature"),
|
||||
"top_p": conf().get("top_p", 1),
|
||||
"frequency_penalty": conf().get("frequency_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
|
||||
"presence_penalty": conf().get("presence_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
|
||||
"session_id": session_id,
|
||||
"channel_type": conf().get("channel_type")
|
||||
}
|
||||
try:
|
||||
from linkai import LinkAIClient
|
||||
client_id = LinkAIClient.fetch_client_id()
|
||||
if client_id:
|
||||
body["client_id"] = client_id
|
||||
# start: client info deliver
|
||||
if context.kwargs.get("msg"):
|
||||
body["session_id"] = context.kwargs.get("msg").from_user_id
|
||||
if context.kwargs.get("msg").is_group:
|
||||
body["is_group"] = True
|
||||
body["group_name"] = context.kwargs.get("msg").from_user_nickname
|
||||
body["sender_name"] = context.kwargs.get("msg").actual_user_nickname
|
||||
else:
|
||||
body["sender_name"] = context.kwargs.get("msg").from_user_nickname
|
||||
except Exception as e:
|
||||
pass
|
||||
file_id = context.kwargs.get("file_id")
|
||||
if file_id:
|
||||
body["file_id"] = file_id
|
||||
logger.info(f"[LINKAI] query={query}, app_code={app_code}, model={body.get('model')}, file_id={file_id}")
|
||||
headers = {"Authorization": "Bearer " + linkai_api_key}
|
||||
|
||||
# do http request
|
||||
base_url = conf().get("linkai_api_base", "https://api.link-ai.chat")
|
||||
res = requests.post(url=base_url + "/v1/chat/completions", json=body, headers=headers,
|
||||
timeout=conf().get("request_timeout", 180))
|
||||
if res.status_code == 200:
|
||||
# execute success
|
||||
response = res.json()
|
||||
reply_content = response["choices"][0]["message"]["content"]
|
||||
total_tokens = response["usage"]["total_tokens"]
|
||||
logger.info(f"[LINKAI] reply={reply_content}, total_tokens={total_tokens}")
|
||||
self.sessions.session_reply(reply_content, session_id, total_tokens, query=query)
|
||||
|
||||
agent_suffix = self._fetch_agent_suffix(response)
|
||||
if agent_suffix:
|
||||
reply_content += agent_suffix
|
||||
if not agent_suffix:
|
||||
knowledge_suffix = self._fetch_knowledge_search_suffix(response)
|
||||
if knowledge_suffix:
|
||||
reply_content += knowledge_suffix
|
||||
# image process
|
||||
if response["choices"][0].get("img_urls"):
|
||||
thread = threading.Thread(target=self._send_image, args=(context.get("channel"), context, response["choices"][0].get("img_urls")))
|
||||
thread.start()
|
||||
if response["choices"][0].get("text_content"):
|
||||
reply_content = response["choices"][0].get("text_content")
|
||||
reply_content = self._process_url(reply_content)
|
||||
return Reply(ReplyType.TEXT, reply_content)
|
||||
|
||||
else:
|
||||
response = res.json()
|
||||
error = response.get("error")
|
||||
logger.error(f"[LINKAI] chat failed, status_code={res.status_code}, "
|
||||
f"msg={error.get('message')}, type={error.get('type')}")
|
||||
|
||||
if res.status_code >= 500:
|
||||
# server error, need retry
|
||||
time.sleep(2)
|
||||
logger.warn(f"[LINKAI] do retry, times={retry_count}")
|
||||
return self._chat(query, context, retry_count + 1)
|
||||
|
||||
return Reply(ReplyType.TEXT, "提问太快啦,请休息一下再问我吧")
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
# retry
|
||||
time.sleep(2)
|
||||
logger.warn(f"[LINKAI] do retry, times={retry_count}")
|
||||
return self._chat(query, context, retry_count + 1)
|
||||
|
||||
def _process_image_msg(self, app_code: str, session_id: str, query:str, img_cache: dict):
|
||||
try:
|
||||
enable_image_input = False
|
||||
app_info = self._fetch_app_info(app_code)
|
||||
if not app_info:
|
||||
logger.debug(f"[LinkAI] not found app, can't process images, app_code={app_code}")
|
||||
return None
|
||||
plugins = app_info.get("data").get("plugins")
|
||||
for plugin in plugins:
|
||||
if plugin.get("input_type") and "IMAGE" in plugin.get("input_type"):
|
||||
enable_image_input = True
|
||||
if not enable_image_input:
|
||||
return
|
||||
msg = img_cache.get("msg")
|
||||
path = img_cache.get("path")
|
||||
msg.prepare()
|
||||
logger.info(f"[LinkAI] query with images, path={path}")
|
||||
messages = self._build_vision_msg(query, path)
|
||||
memory.USER_IMAGE_CACHE[session_id] = None
|
||||
return messages
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
|
||||
def _find_group_mapping_code(self, context):
|
||||
try:
|
||||
if context.kwargs.get("isgroup"):
|
||||
group_name = context.kwargs.get("msg").from_user_nickname
|
||||
if config.plugin_config and config.plugin_config.get("linkai"):
|
||||
linkai_config = config.plugin_config.get("linkai")
|
||||
group_mapping = linkai_config.get("group_app_map")
|
||||
if group_mapping and group_name:
|
||||
return group_mapping.get(group_name)
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
return None
|
||||
|
||||
def _build_vision_msg(self, query: str, path: str):
|
||||
try:
|
||||
suffix = utils.get_path_suffix(path)
|
||||
with open(path, "rb") as file:
|
||||
base64_str = base64.b64encode(file.read()).decode('utf-8')
|
||||
messages = [{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": query
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/{suffix};base64,{base64_str}"
|
||||
}
|
||||
}
|
||||
]
|
||||
}]
|
||||
return messages
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
|
||||
def reply_text(self, session: ChatGPTSession, app_code="", retry_count=0) -> dict:
|
||||
if retry_count >= 2:
|
||||
# exit from retry 2 times
|
||||
logger.warn("[LINKAI] failed after maximum number of retry times")
|
||||
return {
|
||||
"total_tokens": 0,
|
||||
"completion_tokens": 0,
|
||||
"content": "请再问我一次吧"
|
||||
}
|
||||
|
||||
try:
|
||||
body = {
|
||||
"app_code": app_code,
|
||||
"messages": session.messages,
|
||||
"model": conf().get("model") or "gpt-3.5-turbo", # 对话模型的名称, 支持 gpt-3.5-turbo, gpt-3.5-turbo-16k, gpt-4, wenxin, xunfei
|
||||
"temperature": conf().get("temperature"),
|
||||
"top_p": conf().get("top_p", 1),
|
||||
"frequency_penalty": conf().get("frequency_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
|
||||
"presence_penalty": conf().get("presence_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
|
||||
}
|
||||
if self.args.get("max_tokens"):
|
||||
body["max_tokens"] = self.args.get("max_tokens")
|
||||
headers = {"Authorization": "Bearer " + conf().get("linkai_api_key")}
|
||||
|
||||
# do http request
|
||||
base_url = conf().get("linkai_api_base", "https://api.link-ai.chat")
|
||||
res = requests.post(url=base_url + "/v1/chat/completions", json=body, headers=headers,
|
||||
timeout=conf().get("request_timeout", 180))
|
||||
if res.status_code == 200:
|
||||
# execute success
|
||||
response = res.json()
|
||||
reply_content = response["choices"][0]["message"]["content"]
|
||||
total_tokens = response["usage"]["total_tokens"]
|
||||
logger.info(f"[LINKAI] reply={reply_content}, total_tokens={total_tokens}")
|
||||
return {
|
||||
"total_tokens": total_tokens,
|
||||
"completion_tokens": response["usage"]["completion_tokens"],
|
||||
"content": reply_content,
|
||||
}
|
||||
|
||||
else:
|
||||
response = res.json()
|
||||
error = response.get("error")
|
||||
logger.error(f"[LINKAI] chat failed, status_code={res.status_code}, "
|
||||
f"msg={error.get('message')}, type={error.get('type')}")
|
||||
|
||||
if res.status_code >= 500:
|
||||
# server error, need retry
|
||||
time.sleep(2)
|
||||
logger.warn(f"[LINKAI] do retry, times={retry_count}")
|
||||
return self.reply_text(session, app_code, retry_count + 1)
|
||||
|
||||
return {
|
||||
"total_tokens": 0,
|
||||
"completion_tokens": 0,
|
||||
"content": "提问太快啦,请休息一下再问我吧"
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
# retry
|
||||
time.sleep(2)
|
||||
logger.warn(f"[LINKAI] do retry, times={retry_count}")
|
||||
return self.reply_text(session, app_code, retry_count + 1)
|
||||
|
||||
def _fetch_app_info(self, app_code: str):
|
||||
headers = {"Authorization": "Bearer " + conf().get("linkai_api_key")}
|
||||
# do http request
|
||||
base_url = conf().get("linkai_api_base", "https://api.link-ai.chat")
|
||||
params = {"app_code": app_code}
|
||||
res = requests.get(url=base_url + "/v1/app/info", params=params, headers=headers, timeout=(5, 10))
|
||||
if res.status_code == 200:
|
||||
return res.json()
|
||||
else:
|
||||
logger.warning(f"[LinkAI] find app info exception, res={res}")
|
||||
|
||||
def create_img(self, query, retry_count=0, api_key=None):
|
||||
try:
|
||||
logger.info("[LinkImage] image_query={}".format(query))
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {conf().get('linkai_api_key')}"
|
||||
}
|
||||
data = {
|
||||
"prompt": query,
|
||||
"n": 1,
|
||||
"model": conf().get("text_to_image") or "dall-e-2",
|
||||
"response_format": "url",
|
||||
"img_proxy": conf().get("image_proxy")
|
||||
}
|
||||
url = conf().get("linkai_api_base", "https://api.link-ai.chat") + "/v1/images/generations"
|
||||
res = requests.post(url, headers=headers, json=data, timeout=(5, 90))
|
||||
t2 = time.time()
|
||||
image_url = res.json()["data"][0]["url"]
|
||||
logger.info("[OPEN_AI] image_url={}".format(image_url))
|
||||
return True, image_url
|
||||
|
||||
except Exception as e:
|
||||
logger.error(format(e))
|
||||
return False, "画图出现问题,请休息一下再问我吧"
|
||||
|
||||
|
||||
def _fetch_knowledge_search_suffix(self, response) -> str:
|
||||
try:
|
||||
if response.get("knowledge_base"):
|
||||
search_hit = response.get("knowledge_base").get("search_hit")
|
||||
first_similarity = response.get("knowledge_base").get("first_similarity")
|
||||
logger.info(f"[LINKAI] knowledge base, search_hit={search_hit}, first_similarity={first_similarity}")
|
||||
plugin_config = pconf("linkai")
|
||||
if plugin_config and plugin_config.get("knowledge_base") and plugin_config.get("knowledge_base").get("search_miss_text_enabled"):
|
||||
search_miss_similarity = plugin_config.get("knowledge_base").get("search_miss_similarity")
|
||||
search_miss_text = plugin_config.get("knowledge_base").get("search_miss_suffix")
|
||||
if not search_hit:
|
||||
return search_miss_text
|
||||
if search_miss_similarity and float(search_miss_similarity) > first_similarity:
|
||||
return search_miss_text
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
|
||||
|
||||
def _fetch_agent_suffix(self, response):
|
||||
try:
|
||||
plugin_list = []
|
||||
logger.debug(f"[LinkAgent] res={response}")
|
||||
if response.get("agent") and response.get("agent").get("chain") and response.get("agent").get("need_show_plugin"):
|
||||
chain = response.get("agent").get("chain")
|
||||
suffix = "\n\n- - - - - - - - - - - -"
|
||||
i = 0
|
||||
for turn in chain:
|
||||
plugin_name = turn.get('plugin_name')
|
||||
suffix += "\n"
|
||||
need_show_thought = response.get("agent").get("need_show_thought")
|
||||
if turn.get("thought") and plugin_name and need_show_thought:
|
||||
suffix += f"{turn.get('thought')}\n"
|
||||
if plugin_name:
|
||||
plugin_list.append(turn.get('plugin_name'))
|
||||
suffix += f"{turn.get('plugin_icon')} {turn.get('plugin_name')}"
|
||||
if turn.get('plugin_input'):
|
||||
suffix += f":{turn.get('plugin_input')}"
|
||||
if i < len(chain) - 1:
|
||||
suffix += "\n"
|
||||
i += 1
|
||||
logger.info(f"[LinkAgent] use plugins: {plugin_list}")
|
||||
return suffix
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
|
||||
def _process_url(self, text):
|
||||
try:
|
||||
url_pattern = re.compile(r'\[(.*?)\]\((http[s]?://.*?)\)')
|
||||
def replace_markdown_url(match):
|
||||
return f"{match.group(2)}"
|
||||
return url_pattern.sub(replace_markdown_url, text)
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
|
||||
def _send_image(self, channel, context, image_urls):
|
||||
if not image_urls:
|
||||
return
|
||||
try:
|
||||
for url in image_urls:
|
||||
reply = Reply(ReplyType.IMAGE_URL, url)
|
||||
channel.send(reply, context)
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
|
||||
|
||||
class LinkAISessionManager(SessionManager):
|
||||
def session_msg_query(self, query, session_id):
|
||||
session = self.build_session(session_id)
|
||||
messages = session.messages + [{"role": "user", "content": query}]
|
||||
return messages
|
||||
|
||||
def session_reply(self, reply, session_id, total_tokens=None, query=None):
|
||||
session = self.build_session(session_id)
|
||||
if query:
|
||||
session.add_query(query)
|
||||
session.add_reply(reply)
|
||||
try:
|
||||
max_tokens = conf().get("conversation_max_tokens", 2500)
|
||||
tokens_cnt = session.discard_exceeding(max_tokens, total_tokens)
|
||||
logger.debug(f"[LinkAI] chat history, before tokens={total_tokens}, now tokens={tokens_cnt}")
|
||||
except Exception as e:
|
||||
logger.warning("Exception when counting tokens precisely for session: {}".format(str(e)))
|
||||
return session
|
||||
|
||||
|
||||
class LinkAISession(ChatGPTSession):
|
||||
def calc_tokens(self):
|
||||
if not self.messages:
|
||||
return 0
|
||||
return len(str(self.messages))
|
||||
|
||||
def discard_exceeding(self, max_tokens, cur_tokens=None):
|
||||
cur_tokens = self.calc_tokens()
|
||||
if cur_tokens > max_tokens:
|
||||
for i in range(0, len(self.messages)):
|
||||
if i > 0 and self.messages[i].get("role") == "assistant" and self.messages[i - 1].get("role") == "user":
|
||||
self.messages.pop(i)
|
||||
self.messages.pop(i - 1)
|
||||
return self.calc_tokens()
|
||||
return cur_tokens
|
||||
@@ -1,54 +1,72 @@
|
||||
# encoding:utf-8
|
||||
|
||||
import time
|
||||
|
||||
import openai
|
||||
import openai.error
|
||||
|
||||
from bot.bot import Bot
|
||||
from bot.openai.open_ai_image import OpenAIImage
|
||||
from bot.openai.open_ai_session import OpenAISession
|
||||
from bot.session_manager import SessionManager
|
||||
from bridge.context import ContextType
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from config import conf
|
||||
from common.log import logger
|
||||
import openai
|
||||
import openai.error
|
||||
import time
|
||||
from config import conf
|
||||
|
||||
user_session = dict()
|
||||
|
||||
|
||||
# OpenAI对话模型API (可用)
|
||||
class OpenAIBot(Bot, OpenAIImage):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
openai.api_key = conf().get('open_ai_api_key')
|
||||
if conf().get('open_ai_api_base'):
|
||||
openai.api_base = conf().get('open_ai_api_base')
|
||||
proxy = conf().get('proxy')
|
||||
openai.api_key = conf().get("open_ai_api_key")
|
||||
if conf().get("open_ai_api_base"):
|
||||
openai.api_base = conf().get("open_ai_api_base")
|
||||
proxy = conf().get("proxy")
|
||||
if proxy:
|
||||
openai.proxy = proxy
|
||||
|
||||
self.sessions = SessionManager(OpenAISession, model= conf().get("model") or "text-davinci-003")
|
||||
self.sessions = SessionManager(OpenAISession, model=conf().get("model") or "text-davinci-003")
|
||||
self.args = {
|
||||
"model": conf().get("model") or "text-davinci-003", # 对话模型的名称
|
||||
"temperature": conf().get("temperature", 0.9), # 值在[0,1]之间,越大表示回复越具有不确定性
|
||||
"max_tokens": 1200, # 回复最大的字符数
|
||||
"top_p": 1,
|
||||
"frequency_penalty": conf().get("frequency_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
|
||||
"presence_penalty": conf().get("presence_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
|
||||
"request_timeout": conf().get("request_timeout", None), # 请求超时时间,openai接口默认设置为600,对于难问题一般需要较长时间
|
||||
"timeout": conf().get("request_timeout", None), # 重试超时时间,在这个时间内,将会自动重试
|
||||
"stop": ["\n\n\n"],
|
||||
}
|
||||
|
||||
def reply(self, query, context=None):
|
||||
# acquire reply content
|
||||
if context and context.type:
|
||||
if context.type == ContextType.TEXT:
|
||||
logger.info("[OPEN_AI] query={}".format(query))
|
||||
session_id = context['session_id']
|
||||
session_id = context["session_id"]
|
||||
reply = None
|
||||
if query == '#清除记忆':
|
||||
if query == "#清除记忆":
|
||||
self.sessions.clear_session(session_id)
|
||||
reply = Reply(ReplyType.INFO, '记忆已清除')
|
||||
elif query == '#清除所有':
|
||||
reply = Reply(ReplyType.INFO, "记忆已清除")
|
||||
elif query == "#清除所有":
|
||||
self.sessions.clear_all_session()
|
||||
reply = Reply(ReplyType.INFO, '所有人记忆已清除')
|
||||
reply = Reply(ReplyType.INFO, "所有人记忆已清除")
|
||||
else:
|
||||
session = self.sessions.session_query(query, session_id)
|
||||
new_query = str(session)
|
||||
logger.debug("[OPEN_AI] session query={}".format(new_query))
|
||||
result = self.reply_text(session)
|
||||
total_tokens, completion_tokens, reply_content = (
|
||||
result["total_tokens"],
|
||||
result["completion_tokens"],
|
||||
result["content"],
|
||||
)
|
||||
logger.debug(
|
||||
"[OPEN_AI] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(str(session), session_id, reply_content, completion_tokens)
|
||||
)
|
||||
|
||||
total_tokens, completion_tokens, reply_content = self.reply_text(new_query, session_id, 0)
|
||||
logger.debug("[OPEN_AI] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(new_query, session_id, reply_content, completion_tokens))
|
||||
|
||||
if total_tokens == 0 :
|
||||
if total_tokens == 0:
|
||||
reply = Reply(ReplyType.ERROR, reply_content)
|
||||
else:
|
||||
self.sessions.session_reply(reply_content, session_id, total_tokens)
|
||||
@@ -63,47 +81,42 @@ class OpenAIBot(Bot, OpenAIImage):
|
||||
reply = Reply(ReplyType.ERROR, retstring)
|
||||
return reply
|
||||
|
||||
def reply_text(self, query, session_id, retry_count=0):
|
||||
def reply_text(self, session: OpenAISession, retry_count=0):
|
||||
try:
|
||||
response = openai.Completion.create(
|
||||
model= conf().get("model") or "text-davinci-003", # 对话模型的名称
|
||||
prompt=query,
|
||||
temperature=0.9, # 值在[0,1]之间,越大表示回复越具有不确定性
|
||||
max_tokens=1200, # 回复最大的字符数
|
||||
top_p=1,
|
||||
frequency_penalty=0.0, # [-2,2]之间,该值越大则更倾向于产生不同的内容
|
||||
presence_penalty=0.0, # [-2,2]之间,该值越大则更倾向于产生不同的内容
|
||||
stop=["\n\n\n"]
|
||||
)
|
||||
res_content = response.choices[0]['text'].strip().replace('<|endoftext|>', '')
|
||||
response = openai.Completion.create(prompt=str(session), **self.args)
|
||||
res_content = response.choices[0]["text"].strip().replace("<|endoftext|>", "")
|
||||
total_tokens = response["usage"]["total_tokens"]
|
||||
completion_tokens = response["usage"]["completion_tokens"]
|
||||
logger.info("[OPEN_AI] reply={}".format(res_content))
|
||||
return total_tokens, completion_tokens, res_content
|
||||
return {
|
||||
"total_tokens": total_tokens,
|
||||
"completion_tokens": completion_tokens,
|
||||
"content": res_content,
|
||||
}
|
||||
except Exception as e:
|
||||
need_retry = retry_count < 2
|
||||
result = [0,0,"我现在有点累了,等会再来吧"]
|
||||
result = {"completion_tokens": 0, "content": "我现在有点累了,等会再来吧"}
|
||||
if isinstance(e, openai.error.RateLimitError):
|
||||
logger.warn("[OPEN_AI] RateLimitError: {}".format(e))
|
||||
result[2] = "提问太快啦,请休息一下再问我吧"
|
||||
result["content"] = "提问太快啦,请休息一下再问我吧"
|
||||
if need_retry:
|
||||
time.sleep(5)
|
||||
time.sleep(20)
|
||||
elif isinstance(e, openai.error.Timeout):
|
||||
logger.warn("[OPEN_AI] Timeout: {}".format(e))
|
||||
result[2] = "我没有收到你的消息"
|
||||
result["content"] = "我没有收到你的消息"
|
||||
if need_retry:
|
||||
time.sleep(5)
|
||||
elif isinstance(e, openai.error.APIConnectionError):
|
||||
logger.warn("[OPEN_AI] APIConnectionError: {}".format(e))
|
||||
need_retry = False
|
||||
result[2] = "我连接不到你的网络"
|
||||
result["content"] = "我连接不到你的网络"
|
||||
else:
|
||||
logger.warn("[OPEN_AI] Exception: {}".format(e))
|
||||
need_retry = False
|
||||
self.sessions.clear_session(session_id)
|
||||
self.sessions.clear_session(session.session_id)
|
||||
|
||||
if need_retry:
|
||||
logger.warn("[OPEN_AI] 第{}次重试".format(retry_count+1))
|
||||
return self.reply_text(query, session_id, retry_count+1)
|
||||
logger.warn("[OPEN_AI] 第{}次重试".format(retry_count + 1))
|
||||
return self.reply_text(session, retry_count + 1)
|
||||
else:
|
||||
return result
|
||||
return result
|
||||
|
||||
@@ -1,38 +1,43 @@
|
||||
import time
|
||||
|
||||
import openai
|
||||
import openai.error
|
||||
from common.token_bucket import TokenBucket
|
||||
|
||||
from common.log import logger
|
||||
from common.token_bucket import TokenBucket
|
||||
from config import conf
|
||||
|
||||
|
||||
# OPENAI提供的画图接口
|
||||
class OpenAIImage(object):
|
||||
def __init__(self):
|
||||
openai.api_key = conf().get('open_ai_api_key')
|
||||
if conf().get('rate_limit_dalle'):
|
||||
self.tb4dalle = TokenBucket(conf().get('rate_limit_dalle', 50))
|
||||
|
||||
def create_img(self, query, retry_count=0):
|
||||
openai.api_key = conf().get("open_ai_api_key")
|
||||
if conf().get("rate_limit_dalle"):
|
||||
self.tb4dalle = TokenBucket(conf().get("rate_limit_dalle", 50))
|
||||
|
||||
def create_img(self, query, retry_count=0, api_key=None, api_base=None):
|
||||
try:
|
||||
if conf().get('rate_limit_dalle') and not self.tb4dalle.get_token():
|
||||
if conf().get("rate_limit_dalle") and not self.tb4dalle.get_token():
|
||||
return False, "请求太快了,请休息一下再问我吧"
|
||||
logger.info("[OPEN_AI] image_query={}".format(query))
|
||||
response = openai.Image.create(
|
||||
prompt=query, #图片描述
|
||||
n=1, #每次生成图片的数量
|
||||
size="256x256" #图片大小,可选有 256x256, 512x512, 1024x1024
|
||||
api_key=api_key,
|
||||
prompt=query, # 图片描述
|
||||
n=1, # 每次生成图片的数量
|
||||
model=conf().get("text_to_image") or "dall-e-2",
|
||||
# size=conf().get("image_create_size", "256x256"), # 图片大小,可选有 256x256, 512x512, 1024x1024
|
||||
)
|
||||
image_url = response['data'][0]['url']
|
||||
image_url = response["data"][0]["url"]
|
||||
logger.info("[OPEN_AI] image_url={}".format(image_url))
|
||||
return True, image_url
|
||||
except openai.error.RateLimitError as e:
|
||||
logger.warn(e)
|
||||
if retry_count < 1:
|
||||
time.sleep(5)
|
||||
logger.warn("[OPEN_AI] ImgCreate RateLimit exceed, 第{}次重试".format(retry_count+1))
|
||||
return self.create_img(query, retry_count+1)
|
||||
logger.warn("[OPEN_AI] ImgCreate RateLimit exceed, 第{}次重试".format(retry_count + 1))
|
||||
return self.create_img(query, retry_count + 1)
|
||||
else:
|
||||
return False, "提问太快啦,请休息一下再问我吧"
|
||||
return False, "画图出现问题,请休息一下再问我吧"
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
return False, str(e)
|
||||
return False, "画图出现问题,请休息一下再问我吧"
|
||||
|
||||
@@ -1,35 +1,37 @@
|
||||
from bot.session_manager import Session
|
||||
from common.log import logger
|
||||
|
||||
|
||||
class OpenAISession(Session):
|
||||
def __init__(self, session_id, system_prompt=None, model= "text-davinci-003"):
|
||||
def __init__(self, session_id, system_prompt=None, model="text-davinci-003"):
|
||||
super().__init__(session_id, system_prompt)
|
||||
self.model = model
|
||||
self.reset()
|
||||
|
||||
def __str__(self):
|
||||
# 构造对话模型的输入
|
||||
'''
|
||||
"""
|
||||
e.g. Q: xxx
|
||||
A: xxx
|
||||
Q: xxx
|
||||
'''
|
||||
"""
|
||||
prompt = ""
|
||||
for item in self.messages:
|
||||
if item['role'] == 'system':
|
||||
prompt += item['content'] + "<|endoftext|>\n\n\n"
|
||||
elif item['role'] == 'user':
|
||||
prompt += "Q: " + item['content'] + "\n"
|
||||
elif item['role'] == 'assistant':
|
||||
prompt += "\n\nA: " + item['content'] + "<|endoftext|>\n"
|
||||
if item["role"] == "system":
|
||||
prompt += item["content"] + "<|endoftext|>\n\n\n"
|
||||
elif item["role"] == "user":
|
||||
prompt += "Q: " + item["content"] + "\n"
|
||||
elif item["role"] == "assistant":
|
||||
prompt += "\n\nA: " + item["content"] + "<|endoftext|>\n"
|
||||
|
||||
if len(self.messages) > 0 and self.messages[-1]['role'] == 'user':
|
||||
if len(self.messages) > 0 and self.messages[-1]["role"] == "user":
|
||||
prompt += "A: "
|
||||
return prompt
|
||||
|
||||
def discard_exceeding(self, max_tokens, cur_tokens= None):
|
||||
def discard_exceeding(self, max_tokens, cur_tokens=None):
|
||||
precise = True
|
||||
try:
|
||||
cur_tokens = num_tokens_from_string(str(self), self.model)
|
||||
cur_tokens = self.calc_tokens()
|
||||
except Exception as e:
|
||||
precise = False
|
||||
if cur_tokens is None:
|
||||
@@ -41,7 +43,7 @@ class OpenAISession(Session):
|
||||
elif len(self.messages) == 1 and self.messages[0]["role"] == "assistant":
|
||||
self.messages.pop(0)
|
||||
if precise:
|
||||
cur_tokens = num_tokens_from_string(str(self), self.model)
|
||||
cur_tokens = self.calc_tokens()
|
||||
else:
|
||||
cur_tokens = len(str(self))
|
||||
break
|
||||
@@ -52,16 +54,20 @@ class OpenAISession(Session):
|
||||
logger.debug("max_tokens={}, total_tokens={}, len(conversation)={}".format(max_tokens, cur_tokens, len(self.messages)))
|
||||
break
|
||||
if precise:
|
||||
cur_tokens = num_tokens_from_string(str(self), self.model)
|
||||
cur_tokens = self.calc_tokens()
|
||||
else:
|
||||
cur_tokens = len(str(self))
|
||||
return cur_tokens
|
||||
|
||||
|
||||
def calc_tokens(self):
|
||||
return num_tokens_from_string(str(self), self.model)
|
||||
|
||||
|
||||
# refer to https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
|
||||
def num_tokens_from_string(string: str, model: str) -> int:
|
||||
"""Returns the number of tokens in a text string."""
|
||||
import tiktoken
|
||||
|
||||
encoding = tiktoken.encoding_for_model(model)
|
||||
num_tokens = len(encoding.encode(string,disallowed_special=()))
|
||||
return num_tokens
|
||||
num_tokens = len(encoding.encode(string, disallowed_special=()))
|
||||
return num_tokens
|
||||
|
||||
@@ -2,6 +2,7 @@ from common.expired_dict import ExpiredDict
|
||||
from common.log import logger
|
||||
from config import conf
|
||||
|
||||
|
||||
class Session(object):
|
||||
def __init__(self, session_id, system_prompt=None):
|
||||
self.session_id = session_id
|
||||
@@ -13,7 +14,7 @@ class Session(object):
|
||||
|
||||
# 重置会话
|
||||
def reset(self):
|
||||
system_item = {'role': 'system', 'content': self.system_prompt}
|
||||
system_item = {"role": "system", "content": self.system_prompt}
|
||||
self.messages = [system_item]
|
||||
|
||||
def set_system_prompt(self, system_prompt):
|
||||
@@ -21,22 +22,24 @@ class Session(object):
|
||||
self.reset()
|
||||
|
||||
def add_query(self, query):
|
||||
user_item = {'role': 'user', 'content': query}
|
||||
user_item = {"role": "user", "content": query}
|
||||
self.messages.append(user_item)
|
||||
|
||||
def add_reply(self, reply):
|
||||
assistant_item = {'role': 'assistant', 'content': reply}
|
||||
assistant_item = {"role": "assistant", "content": reply}
|
||||
self.messages.append(assistant_item)
|
||||
|
||||
|
||||
def discard_exceeding(self, max_tokens=None, cur_tokens=None):
|
||||
raise NotImplementedError
|
||||
|
||||
def calc_tokens(self):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class SessionManager(object):
|
||||
def __init__(self, sessioncls, **session_args):
|
||||
if conf().get('expires_in_seconds'):
|
||||
sessions = ExpiredDict(conf().get('expires_in_seconds'))
|
||||
if conf().get("expires_in_seconds"):
|
||||
sessions = ExpiredDict(conf().get("expires_in_seconds"))
|
||||
else:
|
||||
sessions = dict()
|
||||
self.sessions = sessions
|
||||
@@ -44,17 +47,20 @@ class SessionManager(object):
|
||||
self.session_args = session_args
|
||||
|
||||
def build_session(self, session_id, system_prompt=None):
|
||||
'''
|
||||
如果session_id不在sessions中,创建一个新的session并添加到sessions中
|
||||
如果system_prompt不会空,会更新session的system_prompt并重置session
|
||||
'''
|
||||
"""
|
||||
如果session_id不在sessions中,创建一个新的session并添加到sessions中
|
||||
如果system_prompt不会空,会更新session的system_prompt并重置session
|
||||
"""
|
||||
if session_id is None:
|
||||
return self.sessioncls(session_id, system_prompt, **self.session_args)
|
||||
|
||||
if session_id not in self.sessions:
|
||||
self.sessions[session_id] = self.sessioncls(session_id, system_prompt, **self.session_args)
|
||||
elif system_prompt is not None: # 如果有新的system_prompt,更新并重置session
|
||||
elif system_prompt is not None: # 如果有新的system_prompt,更新并重置session
|
||||
self.sessions[session_id].set_system_prompt(system_prompt)
|
||||
session = self.sessions[session_id]
|
||||
return session
|
||||
|
||||
|
||||
def session_query(self, query, session_id):
|
||||
session = self.build_session(session_id)
|
||||
session.add_query(query)
|
||||
@@ -63,10 +69,10 @@ class SessionManager(object):
|
||||
total_tokens = session.discard_exceeding(max_tokens, None)
|
||||
logger.debug("prompt tokens used={}".format(total_tokens))
|
||||
except Exception as e:
|
||||
logger.debug("Exception when counting tokens precisely for prompt: {}".format(str(e)))
|
||||
logger.warning("Exception when counting tokens precisely for prompt: {}".format(str(e)))
|
||||
return session
|
||||
|
||||
def session_reply(self, reply, session_id, total_tokens = None):
|
||||
def session_reply(self, reply, session_id, total_tokens=None):
|
||||
session = self.build_session(session_id)
|
||||
session.add_reply(reply)
|
||||
try:
|
||||
@@ -74,12 +80,12 @@ class SessionManager(object):
|
||||
tokens_cnt = session.discard_exceeding(max_tokens, total_tokens)
|
||||
logger.debug("raw total_tokens={}, savesession tokens={}".format(total_tokens, tokens_cnt))
|
||||
except Exception as e:
|
||||
logger.debug("Exception when counting tokens precisely for session: {}".format(str(e)))
|
||||
logger.warning("Exception when counting tokens precisely for session: {}".format(str(e)))
|
||||
return session
|
||||
|
||||
def clear_session(self, session_id):
|
||||
if session_id in self.sessions:
|
||||
del(self.sessions[session_id])
|
||||
del self.sessions[session_id]
|
||||
|
||||
def clear_all_session(self):
|
||||
self.sessions.clear()
|
||||
|
||||
267
bot/xunfei/xunfei_spark_bot.py
Normal file
267
bot/xunfei/xunfei_spark_bot.py
Normal file
@@ -0,0 +1,267 @@
|
||||
# encoding:utf-8
|
||||
|
||||
import requests, json
|
||||
from bot.bot import Bot
|
||||
from bot.session_manager import SessionManager
|
||||
from bot.baidu.baidu_wenxin_session import BaiduWenxinSession
|
||||
from bridge.context import ContextType, Context
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from common.log import logger
|
||||
from config import conf
|
||||
from common import const
|
||||
import time
|
||||
import _thread as thread
|
||||
import datetime
|
||||
from datetime import datetime
|
||||
from wsgiref.handlers import format_date_time
|
||||
from urllib.parse import urlencode
|
||||
import base64
|
||||
import ssl
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
from time import mktime
|
||||
from urllib.parse import urlparse
|
||||
import websocket
|
||||
import queue
|
||||
import threading
|
||||
import random
|
||||
|
||||
# 消息队列 map
|
||||
queue_map = dict()
|
||||
|
||||
# 响应队列 map
|
||||
reply_map = dict()
|
||||
|
||||
|
||||
class XunFeiBot(Bot):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.app_id = conf().get("xunfei_app_id")
|
||||
self.api_key = conf().get("xunfei_api_key")
|
||||
self.api_secret = conf().get("xunfei_api_secret")
|
||||
# 默认使用v2.0版本: "generalv2"
|
||||
# v1.5版本为 "general"
|
||||
# v3.0版本为: "generalv3"
|
||||
self.domain = "generalv3"
|
||||
# 默认使用v2.0版本: "ws://spark-api.xf-yun.com/v2.1/chat"
|
||||
# v1.5版本为: "ws://spark-api.xf-yun.com/v1.1/chat"
|
||||
# v3.0版本为: "ws://spark-api.xf-yun.com/v3.1/chat"
|
||||
self.spark_url = "ws://spark-api.xf-yun.com/v3.1/chat"
|
||||
self.host = urlparse(self.spark_url).netloc
|
||||
self.path = urlparse(self.spark_url).path
|
||||
# 和wenxin使用相同的session机制
|
||||
self.sessions = SessionManager(BaiduWenxinSession, model=const.XUNFEI)
|
||||
|
||||
def reply(self, query, context: Context = None) -> Reply:
|
||||
if context.type == ContextType.TEXT:
|
||||
logger.info("[XunFei] query={}".format(query))
|
||||
session_id = context["session_id"]
|
||||
request_id = self.gen_request_id(session_id)
|
||||
reply_map[request_id] = ""
|
||||
session = self.sessions.session_query(query, session_id)
|
||||
threading.Thread(target=self.create_web_socket,
|
||||
args=(session.messages, request_id)).start()
|
||||
depth = 0
|
||||
time.sleep(0.1)
|
||||
t1 = time.time()
|
||||
usage = {}
|
||||
while depth <= 300:
|
||||
try:
|
||||
data_queue = queue_map.get(request_id)
|
||||
if not data_queue:
|
||||
depth += 1
|
||||
time.sleep(0.1)
|
||||
continue
|
||||
data_item = data_queue.get(block=True, timeout=0.1)
|
||||
if data_item.is_end:
|
||||
# 请求结束
|
||||
del queue_map[request_id]
|
||||
if data_item.reply:
|
||||
reply_map[request_id] += data_item.reply
|
||||
usage = data_item.usage
|
||||
break
|
||||
|
||||
reply_map[request_id] += data_item.reply
|
||||
depth += 1
|
||||
except Exception as e:
|
||||
depth += 1
|
||||
continue
|
||||
t2 = time.time()
|
||||
logger.info(
|
||||
f"[XunFei-API] response={reply_map[request_id]}, time={t2 - t1}s, usage={usage}"
|
||||
)
|
||||
self.sessions.session_reply(reply_map[request_id], session_id,
|
||||
usage.get("total_tokens"))
|
||||
reply = Reply(ReplyType.TEXT, reply_map[request_id])
|
||||
del reply_map[request_id]
|
||||
return reply
|
||||
else:
|
||||
reply = Reply(ReplyType.ERROR,
|
||||
"Bot不支持处理{}类型的消息".format(context.type))
|
||||
return reply
|
||||
|
||||
def create_web_socket(self, prompt, session_id, temperature=0.5):
|
||||
logger.info(f"[XunFei] start connect, prompt={prompt}")
|
||||
websocket.enableTrace(False)
|
||||
wsUrl = self.create_url()
|
||||
ws = websocket.WebSocketApp(wsUrl,
|
||||
on_message=on_message,
|
||||
on_error=on_error,
|
||||
on_close=on_close,
|
||||
on_open=on_open)
|
||||
data_queue = queue.Queue(1000)
|
||||
queue_map[session_id] = data_queue
|
||||
ws.appid = self.app_id
|
||||
ws.question = prompt
|
||||
ws.domain = self.domain
|
||||
ws.session_id = session_id
|
||||
ws.temperature = temperature
|
||||
ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
|
||||
|
||||
def gen_request_id(self, session_id: str):
|
||||
return session_id + "_" + str(int(time.time())) + "" + str(
|
||||
random.randint(0, 100))
|
||||
|
||||
# 生成url
|
||||
def create_url(self):
|
||||
# 生成RFC1123格式的时间戳
|
||||
now = datetime.now()
|
||||
date = format_date_time(mktime(now.timetuple()))
|
||||
|
||||
# 拼接字符串
|
||||
signature_origin = "host: " + self.host + "\n"
|
||||
signature_origin += "date: " + date + "\n"
|
||||
signature_origin += "GET " + self.path + " HTTP/1.1"
|
||||
|
||||
# 进行hmac-sha256进行加密
|
||||
signature_sha = hmac.new(self.api_secret.encode('utf-8'),
|
||||
signature_origin.encode('utf-8'),
|
||||
digestmod=hashlib.sha256).digest()
|
||||
|
||||
signature_sha_base64 = base64.b64encode(signature_sha).decode(
|
||||
encoding='utf-8')
|
||||
|
||||
authorization_origin = f'api_key="{self.api_key}", algorithm="hmac-sha256", headers="host date request-line", ' \
|
||||
f'signature="{signature_sha_base64}"'
|
||||
|
||||
authorization = base64.b64encode(
|
||||
authorization_origin.encode('utf-8')).decode(encoding='utf-8')
|
||||
|
||||
# 将请求的鉴权参数组合为字典
|
||||
v = {"authorization": authorization, "date": date, "host": self.host}
|
||||
# 拼接鉴权参数,生成url
|
||||
url = self.spark_url + '?' + urlencode(v)
|
||||
# 此处打印出建立连接时候的url,参考本demo的时候可取消上方打印的注释,比对相同参数时生成的url与自己代码生成的url是否一致
|
||||
return url
|
||||
|
||||
def gen_params(self, appid, domain, question):
|
||||
"""
|
||||
通过appid和用户的提问来生成请参数
|
||||
"""
|
||||
data = {
|
||||
"header": {
|
||||
"app_id": appid,
|
||||
"uid": "1234"
|
||||
},
|
||||
"parameter": {
|
||||
"chat": {
|
||||
"domain": domain,
|
||||
"random_threshold": 0.5,
|
||||
"max_tokens": 2048,
|
||||
"auditing": "default"
|
||||
}
|
||||
},
|
||||
"payload": {
|
||||
"message": {
|
||||
"text": question
|
||||
}
|
||||
}
|
||||
}
|
||||
return data
|
||||
|
||||
|
||||
class ReplyItem:
|
||||
def __init__(self, reply, usage=None, is_end=False):
|
||||
self.is_end = is_end
|
||||
self.reply = reply
|
||||
self.usage = usage
|
||||
|
||||
|
||||
# 收到websocket错误的处理
|
||||
def on_error(ws, error):
|
||||
logger.error(f"[XunFei] error: {str(error)}")
|
||||
|
||||
|
||||
# 收到websocket关闭的处理
|
||||
def on_close(ws, one, two):
|
||||
data_queue = queue_map.get(ws.session_id)
|
||||
data_queue.put("END")
|
||||
|
||||
|
||||
# 收到websocket连接建立的处理
|
||||
def on_open(ws):
|
||||
logger.info(f"[XunFei] Start websocket, session_id={ws.session_id}")
|
||||
thread.start_new_thread(run, (ws, ))
|
||||
|
||||
|
||||
def run(ws, *args):
|
||||
data = json.dumps(
|
||||
gen_params(appid=ws.appid,
|
||||
domain=ws.domain,
|
||||
question=ws.question,
|
||||
temperature=ws.temperature))
|
||||
ws.send(data)
|
||||
|
||||
|
||||
# Websocket 操作
|
||||
# 收到websocket消息的处理
|
||||
def on_message(ws, message):
|
||||
data = json.loads(message)
|
||||
code = data['header']['code']
|
||||
if code != 0:
|
||||
logger.error(f'请求错误: {code}, {data}')
|
||||
ws.close()
|
||||
else:
|
||||
choices = data["payload"]["choices"]
|
||||
status = choices["status"]
|
||||
content = choices["text"][0]["content"]
|
||||
data_queue = queue_map.get(ws.session_id)
|
||||
if not data_queue:
|
||||
logger.error(
|
||||
f"[XunFei] can't find data queue, session_id={ws.session_id}")
|
||||
return
|
||||
reply_item = ReplyItem(content)
|
||||
if status == 2:
|
||||
usage = data["payload"].get("usage")
|
||||
reply_item = ReplyItem(content, usage)
|
||||
reply_item.is_end = True
|
||||
ws.close()
|
||||
data_queue.put(reply_item)
|
||||
|
||||
|
||||
def gen_params(appid, domain, question, temperature=0.5):
|
||||
"""
|
||||
通过appid和用户的提问来生成请参数
|
||||
"""
|
||||
data = {
|
||||
"header": {
|
||||
"app_id": appid,
|
||||
"uid": "1234"
|
||||
},
|
||||
"parameter": {
|
||||
"chat": {
|
||||
"domain": domain,
|
||||
"temperature": temperature,
|
||||
"random_threshold": 0.5,
|
||||
"max_tokens": 2048,
|
||||
"auditing": "default"
|
||||
}
|
||||
},
|
||||
"payload": {
|
||||
"message": {
|
||||
"text": question
|
||||
}
|
||||
}
|
||||
}
|
||||
return data
|
||||
@@ -1,50 +1,84 @@
|
||||
from bot.bot_factory import create_bot
|
||||
from bridge.context import Context
|
||||
from bridge.reply import Reply
|
||||
from common.log import logger
|
||||
from bot import bot_factory
|
||||
from common.singleton import singleton
|
||||
from voice import voice_factory
|
||||
from config import conf
|
||||
from common import const
|
||||
from common.log import logger
|
||||
from common.singleton import singleton
|
||||
from config import conf
|
||||
from translate.factory import create_translator
|
||||
from voice.factory import create_voice
|
||||
|
||||
|
||||
@singleton
|
||||
class Bridge(object):
|
||||
def __init__(self):
|
||||
self.btype={
|
||||
self.btype = {
|
||||
"chat": const.CHATGPT,
|
||||
"voice_to_text": conf().get("voice_to_text", "openai"),
|
||||
"text_to_voice": conf().get("text_to_voice", "google")
|
||||
"text_to_voice": conf().get("text_to_voice", "google"),
|
||||
"translate": conf().get("translate", "baidu"),
|
||||
}
|
||||
model_type = conf().get("model")
|
||||
model_type = conf().get("model") or const.GPT35
|
||||
if model_type in ["text-davinci-003"]:
|
||||
self.btype['chat'] = const.OPEN_AI
|
||||
if conf().get("use_azure_chatgpt"):
|
||||
self.btype['chat'] = const.CHATGPTONAZURE
|
||||
self.bots={}
|
||||
self.btype["chat"] = const.OPEN_AI
|
||||
if conf().get("use_azure_chatgpt", False):
|
||||
self.btype["chat"] = const.CHATGPTONAZURE
|
||||
if model_type in ["wenxin", "wenxin-4"]:
|
||||
self.btype["chat"] = const.BAIDU
|
||||
if model_type in ["xunfei"]:
|
||||
self.btype["chat"] = const.XUNFEI
|
||||
if model_type in [const.QWEN]:
|
||||
self.btype["chat"] = const.QWEN
|
||||
if model_type in [const.GEMINI]:
|
||||
self.btype["chat"] = const.GEMINI
|
||||
|
||||
def get_bot(self,typename):
|
||||
if conf().get("use_linkai") and conf().get("linkai_api_key"):
|
||||
self.btype["chat"] = const.LINKAI
|
||||
if not conf().get("voice_to_text") or conf().get("voice_to_text") in ["openai"]:
|
||||
self.btype["voice_to_text"] = const.LINKAI
|
||||
if not conf().get("text_to_voice") or conf().get("text_to_voice") in ["openai", const.TTS_1, const.TTS_1_HD]:
|
||||
self.btype["text_to_voice"] = const.LINKAI
|
||||
|
||||
if model_type in ["claude"]:
|
||||
self.btype["chat"] = const.CLAUDEAI
|
||||
self.bots = {}
|
||||
self.chat_bots = {}
|
||||
|
||||
def get_bot(self, typename):
|
||||
if self.bots.get(typename) is None:
|
||||
logger.info("create bot {} for {}".format(self.btype[typename],typename))
|
||||
logger.info("create bot {} for {}".format(self.btype[typename], typename))
|
||||
if typename == "text_to_voice":
|
||||
self.bots[typename] = voice_factory.create_voice(self.btype[typename])
|
||||
self.bots[typename] = create_voice(self.btype[typename])
|
||||
elif typename == "voice_to_text":
|
||||
self.bots[typename] = voice_factory.create_voice(self.btype[typename])
|
||||
self.bots[typename] = create_voice(self.btype[typename])
|
||||
elif typename == "chat":
|
||||
self.bots[typename] = bot_factory.create_bot(self.btype[typename])
|
||||
self.bots[typename] = create_bot(self.btype[typename])
|
||||
elif typename == "translate":
|
||||
self.bots[typename] = create_translator(self.btype[typename])
|
||||
return self.bots[typename]
|
||||
|
||||
def get_bot_type(self,typename):
|
||||
|
||||
def get_bot_type(self, typename):
|
||||
return self.btype[typename]
|
||||
|
||||
|
||||
def fetch_reply_content(self, query, context : Context) -> Reply:
|
||||
def fetch_reply_content(self, query, context: Context) -> Reply:
|
||||
return self.get_bot("chat").reply(query, context)
|
||||
|
||||
|
||||
def fetch_voice_to_text(self, voiceFile) -> Reply:
|
||||
return self.get_bot("voice_to_text").voiceToText(voiceFile)
|
||||
|
||||
def fetch_text_to_voice(self, text) -> Reply:
|
||||
return self.get_bot("text_to_voice").textToVoice(text)
|
||||
|
||||
def fetch_translate(self, text, from_lang="", to_lang="en") -> Reply:
|
||||
return self.get_bot("translate").translate(text, from_lang, to_lang)
|
||||
|
||||
def find_chat_bot(self, bot_type: str):
|
||||
if self.chat_bots.get(bot_type) is None:
|
||||
self.chat_bots[bot_type] = create_bot(bot_type)
|
||||
return self.chat_bots.get(bot_type)
|
||||
|
||||
def reset_bot(self):
|
||||
"""
|
||||
重置bot路由
|
||||
"""
|
||||
self.__init__()
|
||||
|
||||
@@ -2,35 +2,49 @@
|
||||
|
||||
from enum import Enum
|
||||
|
||||
class ContextType (Enum):
|
||||
TEXT = 1 # 文本消息
|
||||
VOICE = 2 # 音频消息
|
||||
IMAGE_CREATE = 3 # 创建图片命令
|
||||
|
||||
|
||||
class ContextType(Enum):
|
||||
TEXT = 1 # 文本消息
|
||||
VOICE = 2 # 音频消息
|
||||
IMAGE = 3 # 图片消息
|
||||
FILE = 4 # 文件信息
|
||||
VIDEO = 5 # 视频信息
|
||||
SHARING = 6 # 分享信息
|
||||
|
||||
IMAGE_CREATE = 10 # 创建图片命令
|
||||
ACCEPT_FRIEND = 19 # 同意好友请求
|
||||
JOIN_GROUP = 20 # 加入群聊
|
||||
PATPAT = 21 # 拍了拍
|
||||
FUNCTION = 22 # 函数调用
|
||||
EXIT_GROUP = 23 #退出
|
||||
|
||||
|
||||
def __str__(self):
|
||||
return self.name
|
||||
|
||||
|
||||
class Context:
|
||||
def __init__(self, type : ContextType = None , content = None, kwargs = dict()):
|
||||
def __init__(self, type: ContextType = None, content=None, kwargs=dict()):
|
||||
self.type = type
|
||||
self.content = content
|
||||
self.kwargs = kwargs
|
||||
|
||||
def __contains__(self, key):
|
||||
if key == 'type':
|
||||
if key == "type":
|
||||
return self.type is not None
|
||||
elif key == 'content':
|
||||
elif key == "content":
|
||||
return self.content is not None
|
||||
else:
|
||||
return key in self.kwargs
|
||||
|
||||
|
||||
def __getitem__(self, key):
|
||||
if key == 'type':
|
||||
if key == "type":
|
||||
return self.type
|
||||
elif key == 'content':
|
||||
elif key == "content":
|
||||
return self.content
|
||||
else:
|
||||
return self.kwargs[key]
|
||||
|
||||
|
||||
def get(self, key, default=None):
|
||||
try:
|
||||
return self[key]
|
||||
@@ -38,20 +52,20 @@ class Context:
|
||||
return default
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
if key == 'type':
|
||||
if key == "type":
|
||||
self.type = value
|
||||
elif key == 'content':
|
||||
elif key == "content":
|
||||
self.content = value
|
||||
else:
|
||||
self.kwargs[key] = value
|
||||
|
||||
def __delitem__(self, key):
|
||||
if key == 'type':
|
||||
if key == "type":
|
||||
self.type = None
|
||||
elif key == 'content':
|
||||
elif key == "content":
|
||||
self.content = None
|
||||
else:
|
||||
del self.kwargs[key]
|
||||
|
||||
|
||||
def __str__(self):
|
||||
return "Context(type={}, content={}, kwargs={})".format(self.type, self.content, self.kwargs)
|
||||
return "Context(type={}, content={}, kwargs={})".format(self.type, self.content, self.kwargs)
|
||||
|
||||
@@ -1,22 +1,31 @@
|
||||
|
||||
# encoding:utf-8
|
||||
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class ReplyType(Enum):
|
||||
TEXT = 1 # 文本
|
||||
VOICE = 2 # 音频文件
|
||||
IMAGE = 3 # 图片文件
|
||||
IMAGE_URL = 4 # 图片URL
|
||||
|
||||
TEXT = 1 # 文本
|
||||
VOICE = 2 # 音频文件
|
||||
IMAGE = 3 # 图片文件
|
||||
IMAGE_URL = 4 # 图片URL
|
||||
VIDEO_URL = 5 # 视频URL
|
||||
FILE = 6 # 文件
|
||||
CARD = 7 # 微信名片,仅支持ntchat
|
||||
InviteRoom = 8 # 邀请好友进群
|
||||
INFO = 9
|
||||
ERROR = 10
|
||||
TEXT_ = 11 # 强制文本
|
||||
VIDEO = 12
|
||||
MINIAPP = 13 # 小程序
|
||||
|
||||
def __str__(self):
|
||||
return self.name
|
||||
|
||||
|
||||
class Reply:
|
||||
def __init__(self, type : ReplyType = None , content = None):
|
||||
def __init__(self, type: ReplyType = None, content=None):
|
||||
self.type = type
|
||||
self.content = content
|
||||
|
||||
def __str__(self):
|
||||
return "Reply(type={}, content={})".format(self.type, self.content)
|
||||
return "Reply(type={}, content={})".format(self.type, self.content)
|
||||
|
||||
@@ -4,9 +4,13 @@ Message sending channel abstract class
|
||||
|
||||
from bridge.bridge import Bridge
|
||||
from bridge.context import Context
|
||||
from bridge.reply import Reply
|
||||
from bridge.reply import *
|
||||
|
||||
|
||||
class Channel(object):
|
||||
channel_type = ""
|
||||
NOT_SUPPORT_REPLYTYPE = [ReplyType.VOICE, ReplyType.IMAGE]
|
||||
|
||||
def startup(self):
|
||||
"""
|
||||
init channel
|
||||
@@ -20,20 +24,21 @@ class Channel(object):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def send(self, msg, receiver):
|
||||
# 统一的发送函数,每个Channel自行实现,根据reply的type字段发送不同类型的消息
|
||||
def send(self, reply: Reply, context: Context):
|
||||
"""
|
||||
send message to user
|
||||
:param msg: message content
|
||||
:param receiver: receiver channel account
|
||||
:return:
|
||||
:return:
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def build_reply_content(self, query, context : Context=None) -> Reply:
|
||||
def build_reply_content(self, query, context: Context = None) -> Reply:
|
||||
return Bridge().fetch_reply_content(query, context)
|
||||
|
||||
def build_voice_to_text(self, voice_file) -> Reply:
|
||||
return Bridge().fetch_voice_to_text(voice_file)
|
||||
|
||||
|
||||
def build_text_to_voice(self, text) -> Reply:
|
||||
return Bridge().fetch_text_to_voice(text)
|
||||
|
||||
@@ -1,20 +1,45 @@
|
||||
"""
|
||||
channel factory
|
||||
"""
|
||||
from common import const
|
||||
from .channel import Channel
|
||||
|
||||
def create_channel(channel_type):
|
||||
|
||||
def create_channel(channel_type) -> Channel:
|
||||
"""
|
||||
create a channel instance
|
||||
:param channel_type: channel type code
|
||||
:return: channel instance
|
||||
"""
|
||||
if channel_type == 'wx':
|
||||
ch = Channel()
|
||||
if channel_type == "wx":
|
||||
from channel.wechat.wechat_channel import WechatChannel
|
||||
return WechatChannel()
|
||||
elif channel_type == 'wxy':
|
||||
ch = WechatChannel()
|
||||
elif channel_type == "wxy":
|
||||
from channel.wechat.wechaty_channel import WechatyChannel
|
||||
return WechatyChannel()
|
||||
elif channel_type == 'terminal':
|
||||
ch = WechatyChannel()
|
||||
elif channel_type == "terminal":
|
||||
from channel.terminal.terminal_channel import TerminalChannel
|
||||
return TerminalChannel()
|
||||
raise RuntimeError
|
||||
ch = TerminalChannel()
|
||||
elif channel_type == "wechatmp":
|
||||
from channel.wechatmp.wechatmp_channel import WechatMPChannel
|
||||
ch = WechatMPChannel(passive_reply=True)
|
||||
elif channel_type == "wechatmp_service":
|
||||
from channel.wechatmp.wechatmp_channel import WechatMPChannel
|
||||
ch = WechatMPChannel(passive_reply=False)
|
||||
elif channel_type == "wechatcom_app":
|
||||
from channel.wechatcom.wechatcomapp_channel import WechatComAppChannel
|
||||
ch = WechatComAppChannel()
|
||||
elif channel_type == "wework":
|
||||
from channel.wework.wework_channel import WeworkChannel
|
||||
ch = WeworkChannel()
|
||||
elif channel_type == const.FEISHU:
|
||||
from channel.feishu.feishu_channel import FeiShuChanel
|
||||
ch = FeiShuChanel()
|
||||
elif channel_type == const.DINGTALK:
|
||||
from channel.dingtalk.dingtalk_channel import DingTalkChanel
|
||||
ch = DingTalkChanel()
|
||||
else:
|
||||
raise RuntimeError
|
||||
ch.channel_type = channel_type
|
||||
return ch
|
||||
|
||||
392
channel/chat_channel.py
Normal file
392
channel/chat_channel.py
Normal file
@@ -0,0 +1,392 @@
|
||||
import os
|
||||
import re
|
||||
import threading
|
||||
import time
|
||||
from asyncio import CancelledError
|
||||
from concurrent.futures import Future, ThreadPoolExecutor
|
||||
|
||||
from bridge.context import *
|
||||
from bridge.reply import *
|
||||
from channel.channel import Channel
|
||||
from common.dequeue import Dequeue
|
||||
from common import memory
|
||||
from plugins import *
|
||||
|
||||
try:
|
||||
from voice.audio_convert import any_to_wav
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
|
||||
# 抽象类, 它包含了与消息通道无关的通用处理逻辑
|
||||
class ChatChannel(Channel):
|
||||
name = None # 登录的用户名
|
||||
user_id = None # 登录的用户id
|
||||
futures = {} # 记录每个session_id提交到线程池的future对象, 用于重置会话时把没执行的future取消掉,正在执行的不会被取消
|
||||
sessions = {} # 用于控制并发,每个session_id同时只能有一个context在处理
|
||||
lock = threading.Lock() # 用于控制对sessions的访问
|
||||
handler_pool = ThreadPoolExecutor(max_workers=8) # 处理消息的线程池
|
||||
|
||||
def __init__(self):
|
||||
_thread = threading.Thread(target=self.consume)
|
||||
_thread.setDaemon(True)
|
||||
_thread.start()
|
||||
|
||||
# 根据消息构造context,消息内容相关的触发项写在这里
|
||||
def _compose_context(self, ctype: ContextType, content, **kwargs):
|
||||
context = Context(ctype, content)
|
||||
context.kwargs = kwargs
|
||||
# context首次传入时,origin_ctype是None,
|
||||
# 引入的起因是:当输入语音时,会嵌套生成两个context,第一步语音转文本,第二步通过文本生成文字回复。
|
||||
# origin_ctype用于第二步文本回复时,判断是否需要匹配前缀,如果是私聊的语音,就不需要匹配前缀
|
||||
if "origin_ctype" not in context:
|
||||
context["origin_ctype"] = ctype
|
||||
# context首次传入时,receiver是None,根据类型设置receiver
|
||||
first_in = "receiver" not in context
|
||||
# 群名匹配过程,设置session_id和receiver
|
||||
if first_in: # context首次传入时,receiver是None,根据类型设置receiver
|
||||
config = conf()
|
||||
cmsg = context["msg"]
|
||||
user_data = conf().get_user_data(cmsg.from_user_id)
|
||||
context["openai_api_key"] = user_data.get("openai_api_key")
|
||||
context["gpt_model"] = user_data.get("gpt_model")
|
||||
if context.get("isgroup", False):
|
||||
group_name = cmsg.other_user_nickname
|
||||
group_id = cmsg.other_user_id
|
||||
|
||||
group_name_white_list = config.get("group_name_white_list", [])
|
||||
group_name_keyword_white_list = config.get("group_name_keyword_white_list", [])
|
||||
if any(
|
||||
[
|
||||
group_name in group_name_white_list,
|
||||
"ALL_GROUP" in group_name_white_list,
|
||||
check_contain(group_name, group_name_keyword_white_list),
|
||||
]
|
||||
):
|
||||
group_chat_in_one_session = conf().get("group_chat_in_one_session", [])
|
||||
session_id = cmsg.actual_user_id
|
||||
if any(
|
||||
[
|
||||
group_name in group_chat_in_one_session,
|
||||
"ALL_GROUP" in group_chat_in_one_session,
|
||||
]
|
||||
):
|
||||
session_id = group_id
|
||||
else:
|
||||
return None
|
||||
context["session_id"] = session_id
|
||||
context["receiver"] = group_id
|
||||
else:
|
||||
context["session_id"] = cmsg.other_user_id
|
||||
context["receiver"] = cmsg.other_user_id
|
||||
e_context = PluginManager().emit_event(EventContext(Event.ON_RECEIVE_MESSAGE, {"channel": self, "context": context}))
|
||||
context = e_context["context"]
|
||||
if e_context.is_pass() or context is None:
|
||||
return context
|
||||
if cmsg.from_user_id == self.user_id and not config.get("trigger_by_self", True):
|
||||
logger.debug("[WX]self message skipped")
|
||||
return None
|
||||
|
||||
# 消息内容匹配过程,并处理content
|
||||
if ctype == ContextType.TEXT:
|
||||
if first_in and "」\n- - - - - - -" in content: # 初次匹配 过滤引用消息
|
||||
logger.debug(content)
|
||||
logger.debug("[WX]reference query skipped")
|
||||
return None
|
||||
|
||||
nick_name_black_list = conf().get("nick_name_black_list", [])
|
||||
if context.get("isgroup", False): # 群聊
|
||||
# 校验关键字
|
||||
match_prefix = check_prefix(content, conf().get("group_chat_prefix"))
|
||||
match_contain = check_contain(content, conf().get("group_chat_keyword"))
|
||||
flag = False
|
||||
if context["msg"].to_user_id != context["msg"].actual_user_id:
|
||||
if match_prefix is not None or match_contain is not None:
|
||||
flag = True
|
||||
if match_prefix:
|
||||
content = content.replace(match_prefix, "", 1).strip()
|
||||
if context["msg"].is_at:
|
||||
nick_name = context["msg"].actual_user_nickname
|
||||
if nick_name and nick_name in nick_name_black_list:
|
||||
# 黑名单过滤
|
||||
logger.warning(f"[WX] Nickname {nick_name} in In BlackList, ignore")
|
||||
return None
|
||||
|
||||
logger.info("[WX]receive group at")
|
||||
if not conf().get("group_at_off", False):
|
||||
flag = True
|
||||
pattern = f"@{re.escape(self.name)}(\u2005|\u0020)"
|
||||
subtract_res = re.sub(pattern, r"", content)
|
||||
if isinstance(context["msg"].at_list, list):
|
||||
for at in context["msg"].at_list:
|
||||
pattern = f"@{re.escape(at)}(\u2005|\u0020)"
|
||||
subtract_res = re.sub(pattern, r"", subtract_res)
|
||||
if subtract_res == content and context["msg"].self_display_name:
|
||||
# 前缀移除后没有变化,使用群昵称再次移除
|
||||
pattern = f"@{re.escape(context['msg'].self_display_name)}(\u2005|\u0020)"
|
||||
subtract_res = re.sub(pattern, r"", content)
|
||||
content = subtract_res
|
||||
if not flag:
|
||||
if context["origin_ctype"] == ContextType.VOICE:
|
||||
logger.info("[WX]receive group voice, but checkprefix didn't match")
|
||||
return None
|
||||
else: # 单聊
|
||||
nick_name = context["msg"].from_user_nickname
|
||||
if nick_name and nick_name in nick_name_black_list:
|
||||
# 黑名单过滤
|
||||
logger.warning(f"[WX] Nickname '{nick_name}' in In BlackList, ignore")
|
||||
return None
|
||||
|
||||
match_prefix = check_prefix(content, conf().get("single_chat_prefix", [""]))
|
||||
if match_prefix is not None: # 判断如果匹配到自定义前缀,则返回过滤掉前缀+空格后的内容
|
||||
content = content.replace(match_prefix, "", 1).strip()
|
||||
elif context["origin_ctype"] == ContextType.VOICE: # 如果源消息是私聊的语音消息,允许不匹配前缀,放宽条件
|
||||
pass
|
||||
else:
|
||||
return None
|
||||
content = content.strip()
|
||||
img_match_prefix = check_prefix(content, conf().get("image_create_prefix"))
|
||||
if img_match_prefix:
|
||||
content = content.replace(img_match_prefix, "", 1)
|
||||
context.type = ContextType.IMAGE_CREATE
|
||||
else:
|
||||
context.type = ContextType.TEXT
|
||||
context.content = content.strip()
|
||||
if "desire_rtype" not in context and conf().get("always_reply_voice") and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE:
|
||||
context["desire_rtype"] = ReplyType.VOICE
|
||||
elif context.type == ContextType.VOICE:
|
||||
if "desire_rtype" not in context and conf().get("voice_reply_voice") and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE:
|
||||
context["desire_rtype"] = ReplyType.VOICE
|
||||
|
||||
return context
|
||||
|
||||
def _handle(self, context: Context):
|
||||
if context is None or not context.content:
|
||||
return
|
||||
logger.debug("[WX] ready to handle context: {}".format(context))
|
||||
# reply的构建步骤
|
||||
reply = self._generate_reply(context)
|
||||
|
||||
logger.debug("[WX] ready to decorate reply: {}".format(reply))
|
||||
# reply的包装步骤
|
||||
reply = self._decorate_reply(context, reply)
|
||||
|
||||
# reply的发送步骤
|
||||
self._send_reply(context, reply)
|
||||
|
||||
def _generate_reply(self, context: Context, reply: Reply = Reply()) -> Reply:
|
||||
e_context = PluginManager().emit_event(
|
||||
EventContext(
|
||||
Event.ON_HANDLE_CONTEXT,
|
||||
{"channel": self, "context": context, "reply": reply},
|
||||
)
|
||||
)
|
||||
reply = e_context["reply"]
|
||||
if not e_context.is_pass():
|
||||
logger.debug("[WX] ready to handle context: type={}, content={}".format(context.type, context.content))
|
||||
if context.type == ContextType.TEXT or context.type == ContextType.IMAGE_CREATE: # 文字和图片消息
|
||||
context["channel"] = e_context["channel"]
|
||||
reply = super().build_reply_content(context.content, context)
|
||||
elif context.type == ContextType.VOICE: # 语音消息
|
||||
cmsg = context["msg"]
|
||||
cmsg.prepare()
|
||||
file_path = context.content
|
||||
wav_path = os.path.splitext(file_path)[0] + ".wav"
|
||||
try:
|
||||
any_to_wav(file_path, wav_path)
|
||||
except Exception as e: # 转换失败,直接使用mp3,对于某些api,mp3也可以识别
|
||||
logger.warning("[WX]any to wav error, use raw path. " + str(e))
|
||||
wav_path = file_path
|
||||
# 语音识别
|
||||
reply = super().build_voice_to_text(wav_path)
|
||||
# 删除临时文件
|
||||
try:
|
||||
os.remove(file_path)
|
||||
if wav_path != file_path:
|
||||
os.remove(wav_path)
|
||||
except Exception as e:
|
||||
pass
|
||||
# logger.warning("[WX]delete temp file error: " + str(e))
|
||||
|
||||
if reply.type == ReplyType.TEXT:
|
||||
new_context = self._compose_context(ContextType.TEXT, reply.content, **context.kwargs)
|
||||
if new_context:
|
||||
reply = self._generate_reply(new_context)
|
||||
else:
|
||||
return
|
||||
elif context.type == ContextType.IMAGE: # 图片消息,当前仅做下载保存到本地的逻辑
|
||||
memory.USER_IMAGE_CACHE[context["session_id"]] = {
|
||||
"path": context.content,
|
||||
"msg": context.get("msg")
|
||||
}
|
||||
elif context.type == ContextType.SHARING: # 分享信息,当前无默认逻辑
|
||||
pass
|
||||
elif context.type == ContextType.FUNCTION or context.type == ContextType.FILE: # 文件消息及函数调用等,当前无默认逻辑
|
||||
pass
|
||||
else:
|
||||
logger.warning("[WX] unknown context type: {}".format(context.type))
|
||||
return
|
||||
return reply
|
||||
|
||||
def _decorate_reply(self, context: Context, reply: Reply) -> Reply:
|
||||
if reply and reply.type:
|
||||
e_context = PluginManager().emit_event(
|
||||
EventContext(
|
||||
Event.ON_DECORATE_REPLY,
|
||||
{"channel": self, "context": context, "reply": reply},
|
||||
)
|
||||
)
|
||||
reply = e_context["reply"]
|
||||
desire_rtype = context.get("desire_rtype")
|
||||
if not e_context.is_pass() and reply and reply.type:
|
||||
if reply.type in self.NOT_SUPPORT_REPLYTYPE:
|
||||
logger.error("[WX]reply type not support: " + str(reply.type))
|
||||
reply.type = ReplyType.ERROR
|
||||
reply.content = "不支持发送的消息类型: " + str(reply.type)
|
||||
|
||||
if reply.type == ReplyType.TEXT:
|
||||
reply_text = reply.content
|
||||
if desire_rtype == ReplyType.VOICE and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE:
|
||||
reply = super().build_text_to_voice(reply.content)
|
||||
return self._decorate_reply(context, reply)
|
||||
if context.get("isgroup", False):
|
||||
if not context.get("no_need_at", False):
|
||||
reply_text = "@" + context["msg"].actual_user_nickname + "\n" + reply_text.strip()
|
||||
reply_text = conf().get("group_chat_reply_prefix", "") + reply_text + conf().get("group_chat_reply_suffix", "")
|
||||
else:
|
||||
reply_text = conf().get("single_chat_reply_prefix", "") + reply_text + conf().get("single_chat_reply_suffix", "")
|
||||
reply.content = reply_text
|
||||
elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO:
|
||||
reply.content = "[" + str(reply.type) + "]\n" + reply.content
|
||||
elif reply.type == ReplyType.IMAGE_URL or reply.type == ReplyType.VOICE or reply.type == ReplyType.IMAGE or reply.type == ReplyType.FILE or reply.type == ReplyType.VIDEO or reply.type == ReplyType.VIDEO_URL:
|
||||
pass
|
||||
else:
|
||||
logger.error("[WX] unknown reply type: {}".format(reply.type))
|
||||
return
|
||||
if desire_rtype and desire_rtype != reply.type and reply.type not in [ReplyType.ERROR, ReplyType.INFO]:
|
||||
logger.warning("[WX] desire_rtype: {}, but reply type: {}".format(context.get("desire_rtype"), reply.type))
|
||||
return reply
|
||||
|
||||
def _send_reply(self, context: Context, reply: Reply):
|
||||
if reply and reply.type:
|
||||
e_context = PluginManager().emit_event(
|
||||
EventContext(
|
||||
Event.ON_SEND_REPLY,
|
||||
{"channel": self, "context": context, "reply": reply},
|
||||
)
|
||||
)
|
||||
reply = e_context["reply"]
|
||||
if not e_context.is_pass() and reply and reply.type:
|
||||
logger.debug("[WX] ready to send reply: {}, context: {}".format(reply, context))
|
||||
self._send(reply, context)
|
||||
|
||||
def _send(self, reply: Reply, context: Context, retry_cnt=0):
|
||||
try:
|
||||
self.send(reply, context)
|
||||
except Exception as e:
|
||||
logger.error("[WX] sendMsg error: {}".format(str(e)))
|
||||
if isinstance(e, NotImplementedError):
|
||||
return
|
||||
logger.exception(e)
|
||||
if retry_cnt < 2:
|
||||
time.sleep(3 + 3 * retry_cnt)
|
||||
self._send(reply, context, retry_cnt + 1)
|
||||
|
||||
def _success_callback(self, session_id, **kwargs): # 线程正常结束时的回调函数
|
||||
logger.debug("Worker return success, session_id = {}".format(session_id))
|
||||
|
||||
def _fail_callback(self, session_id, exception, **kwargs): # 线程异常结束时的回调函数
|
||||
logger.exception("Worker return exception: {}".format(exception))
|
||||
|
||||
def _thread_pool_callback(self, session_id, **kwargs):
|
||||
def func(worker: Future):
|
||||
try:
|
||||
worker_exception = worker.exception()
|
||||
if worker_exception:
|
||||
self._fail_callback(session_id, exception=worker_exception, **kwargs)
|
||||
else:
|
||||
self._success_callback(session_id, **kwargs)
|
||||
except CancelledError as e:
|
||||
logger.info("Worker cancelled, session_id = {}".format(session_id))
|
||||
except Exception as e:
|
||||
logger.exception("Worker raise exception: {}".format(e))
|
||||
with self.lock:
|
||||
self.sessions[session_id][1].release()
|
||||
|
||||
return func
|
||||
|
||||
def produce(self, context: Context):
|
||||
session_id = context["session_id"]
|
||||
with self.lock:
|
||||
if session_id not in self.sessions:
|
||||
self.sessions[session_id] = [
|
||||
Dequeue(),
|
||||
threading.BoundedSemaphore(conf().get("concurrency_in_session", 4)),
|
||||
]
|
||||
if context.type == ContextType.TEXT and context.content.startswith("#"):
|
||||
self.sessions[session_id][0].putleft(context) # 优先处理管理命令
|
||||
else:
|
||||
self.sessions[session_id][0].put(context)
|
||||
|
||||
# 消费者函数,单独线程,用于从消息队列中取出消息并处理
|
||||
def consume(self):
|
||||
while True:
|
||||
with self.lock:
|
||||
session_ids = list(self.sessions.keys())
|
||||
for session_id in session_ids:
|
||||
context_queue, semaphore = self.sessions[session_id]
|
||||
if semaphore.acquire(blocking=False): # 等线程处理完毕才能删除
|
||||
if not context_queue.empty():
|
||||
context = context_queue.get()
|
||||
logger.debug("[WX] consume context: {}".format(context))
|
||||
future: Future = self.handler_pool.submit(self._handle, context)
|
||||
future.add_done_callback(self._thread_pool_callback(session_id, context=context))
|
||||
if session_id not in self.futures:
|
||||
self.futures[session_id] = []
|
||||
self.futures[session_id].append(future)
|
||||
elif semaphore._initial_value == semaphore._value + 1: # 除了当前,没有任务再申请到信号量,说明所有任务都处理完毕
|
||||
self.futures[session_id] = [t for t in self.futures[session_id] if not t.done()]
|
||||
assert len(self.futures[session_id]) == 0, "thread pool error"
|
||||
del self.sessions[session_id]
|
||||
else:
|
||||
semaphore.release()
|
||||
time.sleep(0.1)
|
||||
|
||||
# 取消session_id对应的所有任务,只能取消排队的消息和已提交线程池但未执行的任务
|
||||
def cancel_session(self, session_id):
|
||||
with self.lock:
|
||||
if session_id in self.sessions:
|
||||
for future in self.futures[session_id]:
|
||||
future.cancel()
|
||||
cnt = self.sessions[session_id][0].qsize()
|
||||
if cnt > 0:
|
||||
logger.info("Cancel {} messages in session {}".format(cnt, session_id))
|
||||
self.sessions[session_id][0] = Dequeue()
|
||||
|
||||
def cancel_all_session(self):
|
||||
with self.lock:
|
||||
for session_id in self.sessions:
|
||||
for future in self.futures[session_id]:
|
||||
future.cancel()
|
||||
cnt = self.sessions[session_id][0].qsize()
|
||||
if cnt > 0:
|
||||
logger.info("Cancel {} messages in session {}".format(cnt, session_id))
|
||||
self.sessions[session_id][0] = Dequeue()
|
||||
|
||||
|
||||
def check_prefix(content, prefix_list):
|
||||
if not prefix_list:
|
||||
return None
|
||||
for prefix in prefix_list:
|
||||
if content.startswith(prefix):
|
||||
return prefix
|
||||
return None
|
||||
|
||||
|
||||
def check_contain(content, keyword_list):
|
||||
if not keyword_list:
|
||||
return None
|
||||
for ky in keyword_list:
|
||||
if content.find(ky) != -1:
|
||||
return True
|
||||
return None
|
||||
87
channel/chat_message.py
Normal file
87
channel/chat_message.py
Normal file
@@ -0,0 +1,87 @@
|
||||
"""
|
||||
本类表示聊天消息,用于对itchat和wechaty的消息进行统一的封装。
|
||||
|
||||
填好必填项(群聊6个,非群聊8个),即可接入ChatChannel,并支持插件,参考TerminalChannel
|
||||
|
||||
ChatMessage
|
||||
msg_id: 消息id (必填)
|
||||
create_time: 消息创建时间
|
||||
|
||||
ctype: 消息类型 : ContextType (必填)
|
||||
content: 消息内容, 如果是声音/图片,这里是文件路径 (必填)
|
||||
|
||||
from_user_id: 发送者id (必填)
|
||||
from_user_nickname: 发送者昵称
|
||||
to_user_id: 接收者id (必填)
|
||||
to_user_nickname: 接收者昵称
|
||||
|
||||
other_user_id: 对方的id,如果你是发送者,那这个就是接收者id,如果你是接收者,那这个就是发送者id,如果是群消息,那这一直是群id (必填)
|
||||
other_user_nickname: 同上
|
||||
|
||||
is_group: 是否是群消息 (群聊必填)
|
||||
is_at: 是否被at
|
||||
|
||||
- (群消息时,一般会存在实际发送者,是群内某个成员的id和昵称,下列项仅在群消息时存在)
|
||||
actual_user_id: 实际发送者id (群聊必填)
|
||||
actual_user_nickname:实际发送者昵称
|
||||
self_display_name: 自身的展示名,设置群昵称时,该字段表示群昵称
|
||||
|
||||
_prepare_fn: 准备函数,用于准备消息的内容,比如下载图片等,
|
||||
_prepared: 是否已经调用过准备函数
|
||||
_rawmsg: 原始消息对象
|
||||
|
||||
"""
|
||||
|
||||
|
||||
class ChatMessage(object):
|
||||
msg_id = None
|
||||
create_time = None
|
||||
|
||||
ctype = None
|
||||
content = None
|
||||
|
||||
from_user_id = None
|
||||
from_user_nickname = None
|
||||
to_user_id = None
|
||||
to_user_nickname = None
|
||||
other_user_id = None
|
||||
other_user_nickname = None
|
||||
my_msg = False
|
||||
self_display_name = None
|
||||
|
||||
is_group = False
|
||||
is_at = False
|
||||
actual_user_id = None
|
||||
actual_user_nickname = None
|
||||
at_list = None
|
||||
|
||||
_prepare_fn = None
|
||||
_prepared = False
|
||||
_rawmsg = None
|
||||
|
||||
def __init__(self, _rawmsg):
|
||||
self._rawmsg = _rawmsg
|
||||
|
||||
def prepare(self):
|
||||
if self._prepare_fn and not self._prepared:
|
||||
self._prepared = True
|
||||
self._prepare_fn()
|
||||
|
||||
def __str__(self):
|
||||
return "ChatMessage: id={}, create_time={}, ctype={}, content={}, from_user_id={}, from_user_nickname={}, to_user_id={}, to_user_nickname={}, other_user_id={}, other_user_nickname={}, is_group={}, is_at={}, actual_user_id={}, actual_user_nickname={}, at_list={}".format(
|
||||
self.msg_id,
|
||||
self.create_time,
|
||||
self.ctype,
|
||||
self.content,
|
||||
self.from_user_id,
|
||||
self.from_user_nickname,
|
||||
self.to_user_id,
|
||||
self.to_user_nickname,
|
||||
self.other_user_id,
|
||||
self.other_user_nickname,
|
||||
self.is_group,
|
||||
self.is_at,
|
||||
self.actual_user_id,
|
||||
self.actual_user_nickname,
|
||||
self.at_list
|
||||
)
|
||||
100
channel/dingtalk/dingtalk_channel.py
Normal file
100
channel/dingtalk/dingtalk_channel.py
Normal file
@@ -0,0 +1,100 @@
|
||||
"""
|
||||
钉钉通道接入
|
||||
|
||||
@author huiwen
|
||||
@Date 2023/11/28
|
||||
"""
|
||||
|
||||
# -*- coding=utf-8 -*-
|
||||
from channel.dingtalk.dingtalk_message import DingTalkMessage
|
||||
from bridge.context import Context
|
||||
from bridge.reply import Reply
|
||||
from common.log import logger
|
||||
from common.singleton import singleton
|
||||
from config import conf
|
||||
from common.expired_dict import ExpiredDict
|
||||
from bridge.context import ContextType
|
||||
from channel.chat_channel import ChatChannel
|
||||
import logging
|
||||
from dingtalk_stream import AckMessage
|
||||
import dingtalk_stream
|
||||
|
||||
|
||||
@singleton
|
||||
class DingTalkChanel(ChatChannel, dingtalk_stream.ChatbotHandler):
|
||||
dingtalk_client_id = conf().get('dingtalk_client_id')
|
||||
dingtalk_client_secret = conf().get('dingtalk_client_secret')
|
||||
|
||||
def setup_logger(self):
|
||||
logger = logging.getLogger()
|
||||
handler = logging.StreamHandler()
|
||||
handler.setFormatter(
|
||||
logging.Formatter('%(asctime)s %(name)-8s %(levelname)-8s %(message)s [%(filename)s:%(lineno)d]'))
|
||||
logger.addHandler(handler)
|
||||
logger.setLevel(logging.INFO)
|
||||
return logger
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
super(dingtalk_stream.ChatbotHandler, self).__init__()
|
||||
self.logger = self.setup_logger()
|
||||
# 历史消息id暂存,用于幂等控制
|
||||
self.receivedMsgs = ExpiredDict(60 * 60 * 7.1)
|
||||
logger.info("[dingtalk] client_id={}, client_secret={} ".format(
|
||||
self.dingtalk_client_id, self.dingtalk_client_secret))
|
||||
# 无需群校验和前缀
|
||||
conf()["group_name_white_list"] = ["ALL_GROUP"]
|
||||
|
||||
def startup(self):
|
||||
credential = dingtalk_stream.Credential(self.dingtalk_client_id, self.dingtalk_client_secret)
|
||||
client = dingtalk_stream.DingTalkStreamClient(credential)
|
||||
client.register_callback_handler(dingtalk_stream.chatbot.ChatbotMessage.TOPIC, self)
|
||||
client.start_forever()
|
||||
|
||||
def handle_single(self, cmsg: DingTalkMessage):
|
||||
# 处理单聊消息
|
||||
if cmsg.ctype == ContextType.VOICE:
|
||||
logger.debug("[dingtalk]receive voice msg: {}".format(cmsg.content))
|
||||
elif cmsg.ctype == ContextType.IMAGE:
|
||||
logger.debug("[dingtalk]receive image msg: {}".format(cmsg.content))
|
||||
elif cmsg.ctype == ContextType.PATPAT:
|
||||
logger.debug("[dingtalk]receive patpat msg: {}".format(cmsg.content))
|
||||
elif cmsg.ctype == ContextType.TEXT:
|
||||
expression = cmsg.my_msg
|
||||
cmsg.content = conf()["single_chat_prefix"][0] + cmsg.content
|
||||
context = self._compose_context(cmsg.ctype, cmsg.content, isgroup=False, msg=cmsg)
|
||||
if context:
|
||||
self.produce(context)
|
||||
|
||||
def handle_group(self, cmsg: DingTalkMessage):
|
||||
# 处理群聊消息
|
||||
if cmsg.ctype == ContextType.VOICE:
|
||||
logger.debug("[dingtalk]receive voice msg: {}".format(cmsg.content))
|
||||
elif cmsg.ctype == ContextType.IMAGE:
|
||||
logger.debug("[dingtalk]receive image msg: {}".format(cmsg.content))
|
||||
elif cmsg.ctype == ContextType.PATPAT:
|
||||
logger.debug("[dingtalk]receive patpat msg: {}".format(cmsg.content))
|
||||
elif cmsg.ctype == ContextType.TEXT:
|
||||
expression = cmsg.my_msg
|
||||
cmsg.content = conf()["group_chat_prefix"][0] + cmsg.content
|
||||
context = self._compose_context(cmsg.ctype, cmsg.content, isgroup=True, msg=cmsg)
|
||||
context['no_need_at'] = True
|
||||
if context:
|
||||
self.produce(context)
|
||||
|
||||
async def process(self, callback: dingtalk_stream.CallbackMessage):
|
||||
try:
|
||||
incoming_message = dingtalk_stream.ChatbotMessage.from_dict(callback.data)
|
||||
dingtalk_msg = DingTalkMessage(incoming_message)
|
||||
if incoming_message.conversation_type == '1':
|
||||
self.handle_single(dingtalk_msg)
|
||||
else:
|
||||
self.handle_group(dingtalk_msg)
|
||||
return AckMessage.STATUS_OK, 'OK'
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
return self.FAILED_MSG
|
||||
|
||||
def send(self, reply: Reply, context: Context):
|
||||
incoming_message = context.kwargs['msg'].incoming_message
|
||||
self.reply_text(reply.content, incoming_message)
|
||||
44
channel/dingtalk/dingtalk_message.py
Normal file
44
channel/dingtalk/dingtalk_message.py
Normal file
@@ -0,0 +1,44 @@
|
||||
from bridge.context import ContextType
|
||||
from channel.chat_message import ChatMessage
|
||||
import json
|
||||
import requests
|
||||
from common.log import logger
|
||||
from common.tmp_dir import TmpDir
|
||||
from common import utils
|
||||
from dingtalk_stream import ChatbotMessage
|
||||
|
||||
class DingTalkMessage(ChatMessage):
|
||||
def __init__(self, event: ChatbotMessage):
|
||||
super().__init__(event)
|
||||
|
||||
self.msg_id = event.message_id
|
||||
msg_type = event.message_type
|
||||
self.incoming_message =event
|
||||
self.sender_staff_id = event.sender_staff_id
|
||||
self.other_user_id = event.conversation_id
|
||||
self.create_time = event.create_at
|
||||
if event.conversation_type=="1":
|
||||
self.is_group = False
|
||||
else:
|
||||
self.is_group = True
|
||||
|
||||
|
||||
if msg_type == "text":
|
||||
self.ctype = ContextType.TEXT
|
||||
|
||||
self.content = event.text.content.strip()
|
||||
elif msg_type == "audio":
|
||||
|
||||
# 钉钉支持直接识别语音,所以此处将直接提取文字,当文字处理
|
||||
self.content = event.extensions['content']['recognition'].strip()
|
||||
self.ctype = ContextType.TEXT
|
||||
self.from_user_id = event.sender_id
|
||||
self.to_user_id = event.chatbot_user_id
|
||||
self.other_user_nickname = event.conversation_title
|
||||
|
||||
user_id = event.sender_id
|
||||
nickname =event.sender_nick
|
||||
|
||||
|
||||
|
||||
|
||||
254
channel/feishu/feishu_channel.py
Normal file
254
channel/feishu/feishu_channel.py
Normal file
@@ -0,0 +1,254 @@
|
||||
"""
|
||||
飞书通道接入
|
||||
|
||||
@author Saboteur7
|
||||
@Date 2023/11/19
|
||||
"""
|
||||
|
||||
# -*- coding=utf-8 -*-
|
||||
import uuid
|
||||
|
||||
import requests
|
||||
import web
|
||||
from channel.feishu.feishu_message import FeishuMessage
|
||||
from bridge.context import Context
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from common.log import logger
|
||||
from common.singleton import singleton
|
||||
from config import conf
|
||||
from common.expired_dict import ExpiredDict
|
||||
from bridge.context import ContextType
|
||||
from channel.chat_channel import ChatChannel, check_prefix
|
||||
from common import utils
|
||||
import json
|
||||
import os
|
||||
|
||||
URL_VERIFICATION = "url_verification"
|
||||
|
||||
|
||||
@singleton
|
||||
class FeiShuChanel(ChatChannel):
|
||||
feishu_app_id = conf().get('feishu_app_id')
|
||||
feishu_app_secret = conf().get('feishu_app_secret')
|
||||
feishu_token = conf().get('feishu_token')
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
# 历史消息id暂存,用于幂等控制
|
||||
self.receivedMsgs = ExpiredDict(60 * 60 * 7.1)
|
||||
logger.info("[FeiShu] app_id={}, app_secret={} verification_token={}".format(
|
||||
self.feishu_app_id, self.feishu_app_secret, self.feishu_token))
|
||||
# 无需群校验和前缀
|
||||
conf()["group_name_white_list"] = ["ALL_GROUP"]
|
||||
conf()["single_chat_prefix"] = []
|
||||
|
||||
def startup(self):
|
||||
urls = (
|
||||
'/', 'channel.feishu.feishu_channel.FeishuController'
|
||||
)
|
||||
app = web.application(urls, globals(), autoreload=False)
|
||||
port = conf().get("feishu_port", 9891)
|
||||
web.httpserver.runsimple(app.wsgifunc(), ("0.0.0.0", port))
|
||||
|
||||
def send(self, reply: Reply, context: Context):
|
||||
msg = context.get("msg")
|
||||
is_group = context["isgroup"]
|
||||
if msg:
|
||||
access_token = msg.access_token
|
||||
else:
|
||||
access_token = self.fetch_access_token()
|
||||
headers = {
|
||||
"Authorization": "Bearer " + access_token,
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
msg_type = "text"
|
||||
logger.info(f"[FeiShu] start send reply message, type={context.type}, content={reply.content}")
|
||||
reply_content = reply.content
|
||||
content_key = "text"
|
||||
if reply.type == ReplyType.IMAGE_URL:
|
||||
# 图片上传
|
||||
reply_content = self._upload_image_url(reply.content, access_token)
|
||||
if not reply_content:
|
||||
logger.warning("[FeiShu] upload file failed")
|
||||
return
|
||||
msg_type = "image"
|
||||
content_key = "image_key"
|
||||
if is_group:
|
||||
# 群聊中直接回复
|
||||
url = f"https://open.feishu.cn/open-apis/im/v1/messages/{msg.msg_id}/reply"
|
||||
data = {
|
||||
"msg_type": msg_type,
|
||||
"content": json.dumps({content_key: reply_content})
|
||||
}
|
||||
res = requests.post(url=url, headers=headers, json=data, timeout=(5, 10))
|
||||
else:
|
||||
url = "https://open.feishu.cn/open-apis/im/v1/messages"
|
||||
params = {"receive_id_type": context.get("receive_id_type") or "open_id"}
|
||||
data = {
|
||||
"receive_id": context.get("receiver"),
|
||||
"msg_type": msg_type,
|
||||
"content": json.dumps({content_key: reply_content})
|
||||
}
|
||||
res = requests.post(url=url, headers=headers, params=params, json=data, timeout=(5, 10))
|
||||
res = res.json()
|
||||
if res.get("code") == 0:
|
||||
logger.info(f"[FeiShu] send message success")
|
||||
else:
|
||||
logger.error(f"[FeiShu] send message failed, code={res.get('code')}, msg={res.get('msg')}")
|
||||
|
||||
|
||||
def fetch_access_token(self) -> str:
|
||||
url = "https://open.feishu.cn/open-apis/auth/v3/tenant_access_token/internal/"
|
||||
headers = {
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
req_body = {
|
||||
"app_id": self.feishu_app_id,
|
||||
"app_secret": self.feishu_app_secret
|
||||
}
|
||||
data = bytes(json.dumps(req_body), encoding='utf8')
|
||||
response = requests.post(url=url, data=data, headers=headers)
|
||||
if response.status_code == 200:
|
||||
res = response.json()
|
||||
if res.get("code") != 0:
|
||||
logger.error(f"[FeiShu] get tenant_access_token error, code={res.get('code')}, msg={res.get('msg')}")
|
||||
return ""
|
||||
else:
|
||||
return res.get("tenant_access_token")
|
||||
else:
|
||||
logger.error(f"[FeiShu] fetch token error, res={response}")
|
||||
|
||||
|
||||
def _upload_image_url(self, img_url, access_token):
|
||||
logger.debug(f"[WX] start download image, img_url={img_url}")
|
||||
response = requests.get(img_url)
|
||||
suffix = utils.get_path_suffix(img_url)
|
||||
temp_name = str(uuid.uuid4()) + "." + suffix
|
||||
if response.status_code == 200:
|
||||
# 将图片内容保存为临时文件
|
||||
with open(temp_name, "wb") as file:
|
||||
file.write(response.content)
|
||||
|
||||
# upload
|
||||
upload_url = "https://open.feishu.cn/open-apis/im/v1/images"
|
||||
data = {
|
||||
'image_type': 'message'
|
||||
}
|
||||
headers = {
|
||||
'Authorization': f'Bearer {access_token}',
|
||||
}
|
||||
with open(temp_name, "rb") as file:
|
||||
upload_response = requests.post(upload_url, files={"image": file}, data=data, headers=headers)
|
||||
logger.info(f"[FeiShu] upload file, res={upload_response.content}")
|
||||
os.remove(temp_name)
|
||||
return upload_response.json().get("data").get("image_key")
|
||||
|
||||
|
||||
|
||||
class FeishuController:
|
||||
# 类常量
|
||||
FAILED_MSG = '{"success": false}'
|
||||
SUCCESS_MSG = '{"success": true}'
|
||||
MESSAGE_RECEIVE_TYPE = "im.message.receive_v1"
|
||||
|
||||
def GET(self):
|
||||
return "Feishu service start success!"
|
||||
|
||||
def POST(self):
|
||||
try:
|
||||
channel = FeiShuChanel()
|
||||
|
||||
request = json.loads(web.data().decode("utf-8"))
|
||||
logger.debug(f"[FeiShu] receive request: {request}")
|
||||
|
||||
# 1.事件订阅回调验证
|
||||
if request.get("type") == URL_VERIFICATION:
|
||||
varify_res = {"challenge": request.get("challenge")}
|
||||
return json.dumps(varify_res)
|
||||
|
||||
# 2.消息接收处理
|
||||
# token 校验
|
||||
header = request.get("header")
|
||||
if not header or header.get("token") != channel.feishu_token:
|
||||
return self.FAILED_MSG
|
||||
|
||||
# 处理消息事件
|
||||
event = request.get("event")
|
||||
if header.get("event_type") == self.MESSAGE_RECEIVE_TYPE and event:
|
||||
if not event.get("message") or not event.get("sender"):
|
||||
logger.warning(f"[FeiShu] invalid message, msg={request}")
|
||||
return self.FAILED_MSG
|
||||
msg = event.get("message")
|
||||
|
||||
# 幂等判断
|
||||
if channel.receivedMsgs.get(msg.get("message_id")):
|
||||
logger.warning(f"[FeiShu] repeat msg filtered, event_id={header.get('event_id')}")
|
||||
return self.SUCCESS_MSG
|
||||
channel.receivedMsgs[msg.get("message_id")] = True
|
||||
|
||||
is_group = False
|
||||
chat_type = msg.get("chat_type")
|
||||
if chat_type == "group":
|
||||
if not msg.get("mentions") and msg.get("message_type") == "text":
|
||||
# 群聊中未@不响应
|
||||
return self.SUCCESS_MSG
|
||||
if msg.get("mentions")[0].get("name") != conf().get("feishu_bot_name") and msg.get("message_type") == "text":
|
||||
# 不是@机器人,不响应
|
||||
return self.SUCCESS_MSG
|
||||
# 群聊
|
||||
is_group = True
|
||||
receive_id_type = "chat_id"
|
||||
elif chat_type == "p2p":
|
||||
receive_id_type = "open_id"
|
||||
else:
|
||||
logger.warning("[FeiShu] message ignore")
|
||||
return self.SUCCESS_MSG
|
||||
# 构造飞书消息对象
|
||||
feishu_msg = FeishuMessage(event, is_group=is_group, access_token=channel.fetch_access_token())
|
||||
if not feishu_msg:
|
||||
return self.SUCCESS_MSG
|
||||
|
||||
context = self._compose_context(
|
||||
feishu_msg.ctype,
|
||||
feishu_msg.content,
|
||||
isgroup=is_group,
|
||||
msg=feishu_msg,
|
||||
receive_id_type=receive_id_type,
|
||||
no_need_at=True
|
||||
)
|
||||
if context:
|
||||
channel.produce(context)
|
||||
logger.info(f"[FeiShu] query={feishu_msg.content}, type={feishu_msg.ctype}")
|
||||
return self.SUCCESS_MSG
|
||||
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
return self.FAILED_MSG
|
||||
|
||||
def _compose_context(self, ctype: ContextType, content, **kwargs):
|
||||
context = Context(ctype, content)
|
||||
context.kwargs = kwargs
|
||||
if "origin_ctype" not in context:
|
||||
context["origin_ctype"] = ctype
|
||||
|
||||
cmsg = context["msg"]
|
||||
context["session_id"] = cmsg.from_user_id
|
||||
context["receiver"] = cmsg.other_user_id
|
||||
|
||||
if ctype == ContextType.TEXT:
|
||||
# 1.文本请求
|
||||
# 图片生成处理
|
||||
img_match_prefix = check_prefix(content, conf().get("image_create_prefix"))
|
||||
if img_match_prefix:
|
||||
content = content.replace(img_match_prefix, "", 1)
|
||||
context.type = ContextType.IMAGE_CREATE
|
||||
else:
|
||||
context.type = ContextType.TEXT
|
||||
context.content = content.strip()
|
||||
|
||||
elif context.type == ContextType.VOICE:
|
||||
# 2.语音请求
|
||||
if "desire_rtype" not in context and conf().get("voice_reply_voice"):
|
||||
context["desire_rtype"] = ReplyType.VOICE
|
||||
|
||||
return context
|
||||
63
channel/feishu/feishu_message.py
Normal file
63
channel/feishu/feishu_message.py
Normal file
@@ -0,0 +1,63 @@
|
||||
from bridge.context import ContextType
|
||||
from channel.chat_message import ChatMessage
|
||||
import json
|
||||
import requests
|
||||
from common.log import logger
|
||||
from common.tmp_dir import TmpDir
|
||||
from common import utils
|
||||
|
||||
|
||||
class FeishuMessage(ChatMessage):
|
||||
def __init__(self, event: dict, is_group=False, access_token=None):
|
||||
super().__init__(event)
|
||||
msg = event.get("message")
|
||||
sender = event.get("sender")
|
||||
self.access_token = access_token
|
||||
self.msg_id = msg.get("message_id")
|
||||
self.create_time = msg.get("create_time")
|
||||
self.is_group = is_group
|
||||
msg_type = msg.get("message_type")
|
||||
|
||||
if msg_type == "text":
|
||||
self.ctype = ContextType.TEXT
|
||||
content = json.loads(msg.get('content'))
|
||||
self.content = content.get("text").strip()
|
||||
elif msg_type == "file":
|
||||
self.ctype = ContextType.FILE
|
||||
content = json.loads(msg.get("content"))
|
||||
file_key = content.get("file_key")
|
||||
file_name = content.get("file_name")
|
||||
|
||||
self.content = TmpDir().path() + file_key + "." + utils.get_path_suffix(file_name)
|
||||
|
||||
def _download_file():
|
||||
# 如果响应状态码是200,则将响应内容写入本地文件
|
||||
url = f"https://open.feishu.cn/open-apis/im/v1/messages/{self.msg_id}/resources/{file_key}"
|
||||
headers = {
|
||||
"Authorization": "Bearer " + access_token,
|
||||
}
|
||||
params = {
|
||||
"type": "file"
|
||||
}
|
||||
response = requests.get(url=url, headers=headers, params=params)
|
||||
if response.status_code == 200:
|
||||
with open(self.content, "wb") as f:
|
||||
f.write(response.content)
|
||||
else:
|
||||
logger.info(f"[FeiShu] Failed to download file, key={file_key}, res={response.text}")
|
||||
self._prepare_fn = _download_file
|
||||
else:
|
||||
raise NotImplementedError("Unsupported message type: Type:{} ".format(msg_type))
|
||||
|
||||
self.from_user_id = sender.get("sender_id").get("open_id")
|
||||
self.to_user_id = event.get("app_id")
|
||||
if is_group:
|
||||
# 群聊
|
||||
self.other_user_id = msg.get("chat_id")
|
||||
self.actual_user_id = self.from_user_id
|
||||
self.content = self.content.replace("@_user_1", "").strip()
|
||||
self.actual_user_nickname = ""
|
||||
else:
|
||||
# 私聊
|
||||
self.other_user_id = self.from_user_id
|
||||
self.actual_user_id = self.from_user_id
|
||||
@@ -1,31 +1,92 @@
|
||||
from bridge.context import *
|
||||
from channel.channel import Channel
|
||||
import sys
|
||||
|
||||
class TerminalChannel(Channel):
|
||||
from bridge.context import *
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from channel.chat_channel import ChatChannel, check_prefix
|
||||
from channel.chat_message import ChatMessage
|
||||
from common.log import logger
|
||||
from config import conf
|
||||
|
||||
|
||||
class TerminalMessage(ChatMessage):
|
||||
def __init__(
|
||||
self,
|
||||
msg_id,
|
||||
content,
|
||||
ctype=ContextType.TEXT,
|
||||
from_user_id="User",
|
||||
to_user_id="Chatgpt",
|
||||
other_user_id="Chatgpt",
|
||||
):
|
||||
self.msg_id = msg_id
|
||||
self.ctype = ctype
|
||||
self.content = content
|
||||
self.from_user_id = from_user_id
|
||||
self.to_user_id = to_user_id
|
||||
self.other_user_id = other_user_id
|
||||
|
||||
|
||||
class TerminalChannel(ChatChannel):
|
||||
NOT_SUPPORT_REPLYTYPE = [ReplyType.VOICE]
|
||||
|
||||
def send(self, reply: Reply, context: Context):
|
||||
print("\nBot:")
|
||||
if reply.type == ReplyType.IMAGE:
|
||||
from PIL import Image
|
||||
|
||||
image_storage = reply.content
|
||||
image_storage.seek(0)
|
||||
img = Image.open(image_storage)
|
||||
print("<IMAGE>")
|
||||
img.show()
|
||||
elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片
|
||||
import io
|
||||
|
||||
import requests
|
||||
from PIL import Image
|
||||
|
||||
img_url = reply.content
|
||||
pic_res = requests.get(img_url, stream=True)
|
||||
image_storage = io.BytesIO()
|
||||
for block in pic_res.iter_content(1024):
|
||||
image_storage.write(block)
|
||||
image_storage.seek(0)
|
||||
img = Image.open(image_storage)
|
||||
print(img_url)
|
||||
img.show()
|
||||
else:
|
||||
print(reply.content)
|
||||
print("\nUser:", end="")
|
||||
sys.stdout.flush()
|
||||
return
|
||||
|
||||
def startup(self):
|
||||
context = Context()
|
||||
print("\nPlease input your question")
|
||||
logger.setLevel("WARN")
|
||||
print("\nPlease input your question:\nUser:", end="")
|
||||
sys.stdout.flush()
|
||||
msg_id = 0
|
||||
while True:
|
||||
try:
|
||||
prompt = self.get_input("User:\n")
|
||||
prompt = self.get_input()
|
||||
except KeyboardInterrupt:
|
||||
print("\nExiting...")
|
||||
sys.exit()
|
||||
msg_id += 1
|
||||
trigger_prefixs = conf().get("single_chat_prefix", [""])
|
||||
if check_prefix(prompt, trigger_prefixs) is None:
|
||||
prompt = trigger_prefixs[0] + prompt # 给没触发的消息加上触发前缀
|
||||
|
||||
context.type = ContextType.TEXT
|
||||
context['session_id'] = "User"
|
||||
context.content = prompt
|
||||
print("Bot:")
|
||||
sys.stdout.flush()
|
||||
res = super().build_reply_content(prompt, context).content
|
||||
print(res)
|
||||
context = self._compose_context(ContextType.TEXT, prompt, msg=TerminalMessage(msg_id, prompt))
|
||||
if context:
|
||||
self.produce(context)
|
||||
else:
|
||||
raise Exception("context is None")
|
||||
|
||||
|
||||
def get_input(self, prompt):
|
||||
def get_input(self):
|
||||
"""
|
||||
Multi-line input function
|
||||
"""
|
||||
print(prompt, end="")
|
||||
sys.stdout.flush()
|
||||
line = input()
|
||||
return line
|
||||
|
||||
@@ -4,99 +4,141 @@
|
||||
wechat channel
|
||||
"""
|
||||
|
||||
import os
|
||||
import requests
|
||||
import io
|
||||
import time
|
||||
from lib import itchat
|
||||
import json
|
||||
from lib.itchat.content import *
|
||||
from bridge.reply import *
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
|
||||
import requests
|
||||
|
||||
from bridge.context import *
|
||||
from channel.channel import Channel
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from common.log import logger
|
||||
from common.tmp_dir import TmpDir
|
||||
from config import conf
|
||||
from common.time_check import time_checker
|
||||
from bridge.reply import *
|
||||
from channel.chat_channel import ChatChannel
|
||||
from channel.wechat.wechat_message import *
|
||||
from common.expired_dict import ExpiredDict
|
||||
from plugins import *
|
||||
try:
|
||||
from voice.audio_convert import mp3_to_wav
|
||||
except Exception as e:
|
||||
pass
|
||||
thread_pool = ThreadPoolExecutor(max_workers=8)
|
||||
from common.log import logger
|
||||
from common.singleton import singleton
|
||||
from common.time_check import time_checker
|
||||
from config import conf, get_appdata_dir
|
||||
from lib import itchat
|
||||
from lib.itchat.content import *
|
||||
|
||||
|
||||
def thread_pool_callback(worker):
|
||||
worker_exception = worker.exception()
|
||||
if worker_exception:
|
||||
logger.exception("Worker return exception: {}".format(worker_exception))
|
||||
|
||||
|
||||
@itchat.msg_register(TEXT)
|
||||
@itchat.msg_register([TEXT, VOICE, PICTURE, NOTE, ATTACHMENT, SHARING])
|
||||
def handler_single_msg(msg):
|
||||
WechatChannel().handle_text(msg)
|
||||
try:
|
||||
cmsg = WechatMessage(msg, False)
|
||||
except NotImplementedError as e:
|
||||
logger.debug("[WX]single message {} skipped: {}".format(msg["MsgId"], e))
|
||||
return None
|
||||
WechatChannel().handle_single(cmsg)
|
||||
return None
|
||||
|
||||
@itchat.msg_register(TEXT, isGroupChat=True)
|
||||
|
||||
@itchat.msg_register([TEXT, VOICE, PICTURE, NOTE, ATTACHMENT, SHARING], isGroupChat=True)
|
||||
def handler_group_msg(msg):
|
||||
WechatChannel().handle_group(msg)
|
||||
try:
|
||||
cmsg = WechatMessage(msg, True)
|
||||
except NotImplementedError as e:
|
||||
logger.debug("[WX]group message {} skipped: {}".format(msg["MsgId"], e))
|
||||
return None
|
||||
WechatChannel().handle_group(cmsg)
|
||||
return None
|
||||
|
||||
@itchat.msg_register(VOICE)
|
||||
def handler_single_voice(msg):
|
||||
WechatChannel().handle_voice(msg)
|
||||
return None
|
||||
|
||||
@itchat.msg_register(VOICE, isGroupChat=True)
|
||||
def handler_group_voice(msg):
|
||||
WechatChannel().handle_group_voice(msg)
|
||||
return None
|
||||
|
||||
def _check(func):
|
||||
def wrapper(self, msg):
|
||||
msgId = msg['MsgId']
|
||||
def wrapper(self, cmsg: ChatMessage):
|
||||
msgId = cmsg.msg_id
|
||||
if msgId in self.receivedMsgs:
|
||||
logger.info("Wechat message {} already received, ignore".format(msgId))
|
||||
return
|
||||
self.receivedMsgs[msgId] = msg
|
||||
create_time = msg['CreateTime'] # 消息时间
|
||||
if conf().get('hot_reload') == True and int(create_time) < int(time.time()) - 60: # 跳过1分钟前的历史消息
|
||||
self.receivedMsgs[msgId] = True
|
||||
create_time = cmsg.create_time # 消息时间戳
|
||||
if conf().get("hot_reload") == True and int(create_time) < int(time.time()) - 60: # 跳过1分钟前的历史消息
|
||||
logger.debug("[WX]history message {} skipped".format(msgId))
|
||||
return
|
||||
return func(self, msg)
|
||||
if cmsg.my_msg and not cmsg.is_group:
|
||||
logger.debug("[WX]my message {} skipped".format(msgId))
|
||||
return
|
||||
return func(self, cmsg)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class WechatChannel(Channel):
|
||||
# 可用的二维码生成接口
|
||||
# https://api.qrserver.com/v1/create-qr-code/?size=400×400&data=https://www.abc.com
|
||||
# https://api.isoyu.com/qr/?m=1&e=L&p=20&url=https://www.abc.com
|
||||
def qrCallback(uuid, status, qrcode):
|
||||
# logger.debug("qrCallback: {} {}".format(uuid,status))
|
||||
if status == "0":
|
||||
try:
|
||||
from PIL import Image
|
||||
|
||||
img = Image.open(io.BytesIO(qrcode))
|
||||
_thread = threading.Thread(target=img.show, args=("QRCode",))
|
||||
_thread.setDaemon(True)
|
||||
_thread.start()
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
import qrcode
|
||||
|
||||
url = f"https://login.weixin.qq.com/l/{uuid}"
|
||||
|
||||
qr_api1 = "https://api.isoyu.com/qr/?m=1&e=L&p=20&url={}".format(url)
|
||||
qr_api2 = "https://api.qrserver.com/v1/create-qr-code/?size=400×400&data={}".format(url)
|
||||
qr_api3 = "https://api.pwmqr.com/qrcode/create/?url={}".format(url)
|
||||
qr_api4 = "https://my.tv.sohu.com/user/a/wvideo/getQRCode.do?text={}".format(url)
|
||||
print("You can also scan QRCode in any website below:")
|
||||
print(qr_api3)
|
||||
print(qr_api4)
|
||||
print(qr_api2)
|
||||
print(qr_api1)
|
||||
|
||||
qr = qrcode.QRCode(border=1)
|
||||
qr.add_data(url)
|
||||
qr.make(fit=True)
|
||||
qr.print_ascii(invert=True)
|
||||
|
||||
|
||||
@singleton
|
||||
class WechatChannel(ChatChannel):
|
||||
NOT_SUPPORT_REPLYTYPE = []
|
||||
|
||||
def __init__(self):
|
||||
self.userName = None
|
||||
self.nickName = None
|
||||
self.receivedMsgs = ExpiredDict(60*60*24)
|
||||
super().__init__()
|
||||
self.receivedMsgs = ExpiredDict(60 * 60)
|
||||
self.auto_login_times = 0
|
||||
|
||||
def startup(self):
|
||||
|
||||
itchat.instance.receivingRetryCount = 600 # 修改断线超时时间
|
||||
# login by scan QRCode
|
||||
hotReload = conf().get('hot_reload', False)
|
||||
try:
|
||||
itchat.auto_login(enableCmdQR=2, hotReload=hotReload)
|
||||
except Exception as e:
|
||||
if hotReload:
|
||||
logger.error("Hot reload failed, try to login without hot reload")
|
||||
itchat.logout()
|
||||
os.remove("itchat.pkl")
|
||||
itchat.auto_login(enableCmdQR=2, hotReload=hotReload)
|
||||
else:
|
||||
raise e
|
||||
self.userName = itchat.instance.storageClass.userName
|
||||
self.nickName = itchat.instance.storageClass.nickName
|
||||
logger.info("Wechat login success, username: {}, nickname: {}".format(self.userName, self.nickName))
|
||||
hotReload = conf().get("hot_reload", False)
|
||||
status_path = os.path.join(get_appdata_dir(), "itchat.pkl")
|
||||
itchat.auto_login(
|
||||
enableCmdQR=2,
|
||||
hotReload=hotReload,
|
||||
statusStorageDir=status_path,
|
||||
qrCallback=qrCallback,
|
||||
exitCallback=self.exitCallback,
|
||||
loginCallback=self.loginCallback
|
||||
)
|
||||
self.user_id = itchat.instance.storageClass.userName
|
||||
self.name = itchat.instance.storageClass.nickName
|
||||
logger.info("Wechat login success, user_id: {}, nickname: {}".format(self.user_id, self.name))
|
||||
# start message listener
|
||||
itchat.run()
|
||||
|
||||
# handle_* 系列函数处理收到的消息后构造Context,然后传入handle函数中处理Context和发送回复
|
||||
def exitCallback(self):
|
||||
self.auto_login_times += 1
|
||||
if self.auto_login_times < 100:
|
||||
self.startup()
|
||||
|
||||
def loginCallback(self):
|
||||
pass
|
||||
|
||||
# handle_* 系列函数处理收到的消息后构造Context,然后传入produce函数中处理Context和发送回复
|
||||
# Context包含了消息的所有信息,包括以下属性
|
||||
# type 消息类型, 包括TEXT、VOICE、IMAGE_CREATE
|
||||
# content 消息内容,如果是TEXT类型,content就是文本内容,如果是VOICE类型,content就是语音文件名,如果是IMAGE_CREATE类型,content就是图片生成命令
|
||||
@@ -104,289 +146,102 @@ class WechatChannel(Channel):
|
||||
# session_id: 会话id
|
||||
# isgroup: 是否是群聊
|
||||
# receiver: 需要回复的对象
|
||||
# msg: itchat的原始消息对象
|
||||
# origin_ctype: 原始消息类型,用于私聊语音消息时,避免匹配前缀
|
||||
# desire_rtype: 希望回复类型,TEXT类型是文本回复,VOICE类型是语音回复
|
||||
# msg: ChatMessage消息对象
|
||||
# origin_ctype: 原始消息类型,语音转文字后,私聊时如果匹配前缀失败,会根据初始消息是否是语音来放宽触发规则
|
||||
# desire_rtype: 希望回复类型,默认是文本回复,设置为ReplyType.VOICE是语音回复
|
||||
|
||||
@time_checker
|
||||
@_check
|
||||
def handle_voice(self, msg):
|
||||
if conf().get('speech_recognition') != True:
|
||||
def handle_single(self, cmsg: ChatMessage):
|
||||
# filter system message
|
||||
if cmsg.other_user_id in ["weixin"]:
|
||||
return
|
||||
logger.debug("[WX]receive voice msg: " + msg['FileName'])
|
||||
to_user_id = msg['ToUserName']
|
||||
from_user_id = msg['FromUserName']
|
||||
try:
|
||||
other_user_id = msg['User']['UserName'] # 对手方id
|
||||
except Exception as e:
|
||||
logger.warn("[WX]get other_user_id failed: " + str(e))
|
||||
if from_user_id == self.userName:
|
||||
other_user_id = to_user_id
|
||||
else:
|
||||
other_user_id = from_user_id
|
||||
if from_user_id == other_user_id:
|
||||
context = self._compose_context(ContextType.VOICE, msg['FileName'], isgroup=False, msg=msg, receiver=other_user_id, session_id=other_user_id)
|
||||
if context:
|
||||
thread_pool.submit(self.handle, context).add_done_callback(thread_pool_callback)
|
||||
|
||||
@time_checker
|
||||
@_check
|
||||
def handle_text(self, msg):
|
||||
logger.debug("[WX]receive text msg: " + json.dumps(msg, ensure_ascii=False))
|
||||
content = msg['Text']
|
||||
from_user_id = msg['FromUserName']
|
||||
to_user_id = msg['ToUserName'] # 接收人id
|
||||
try:
|
||||
other_user_id = msg['User']['UserName'] # 对手方id
|
||||
except Exception as e:
|
||||
logger.warn("[WX]get other_user_id failed: " + str(e))
|
||||
if from_user_id == self.userName:
|
||||
other_user_id = to_user_id
|
||||
else:
|
||||
other_user_id = from_user_id
|
||||
if "」\n- - - - - - - - - - - - - - -" in content:
|
||||
logger.debug("[WX]reference query skipped")
|
||||
return
|
||||
|
||||
context = self._compose_context(ContextType.TEXT, content, isgroup=False, msg=msg, receiver=other_user_id, session_id=other_user_id)
|
||||
if context:
|
||||
thread_pool.submit(self.handle, context).add_done_callback(thread_pool_callback)
|
||||
|
||||
@time_checker
|
||||
@_check
|
||||
def handle_group(self, msg):
|
||||
logger.debug("[WX]receive group msg: " + json.dumps(msg, ensure_ascii=False))
|
||||
group_name = msg['User'].get('NickName', None)
|
||||
group_id = msg['User'].get('UserName', None)
|
||||
if not group_name:
|
||||
return ""
|
||||
origin_content = msg['Content']
|
||||
content = msg['Content']
|
||||
content_list = content.split(' ', 1)
|
||||
context_special_list = content.split('\u2005', 1)
|
||||
if len(context_special_list) == 2:
|
||||
content = context_special_list[1]
|
||||
elif len(content_list) == 2:
|
||||
content = content_list[1]
|
||||
if "」\n- - - - - - - - - - - - - - -" in content:
|
||||
logger.debug("[WX]reference query skipped")
|
||||
return ""
|
||||
|
||||
config = conf()
|
||||
group_name_white_list = config.get('group_name_white_list', [])
|
||||
group_name_keyword_white_list = config.get('group_name_keyword_white_list', [])
|
||||
|
||||
if any([group_name in group_name_white_list, 'ALL_GROUP' in group_name_white_list, check_contain(group_name, group_name_keyword_white_list)]):
|
||||
group_chat_in_one_session = conf().get('group_chat_in_one_session', [])
|
||||
session_id = msg['ActualUserName']
|
||||
if any([group_name in group_chat_in_one_session, 'ALL_GROUP' in group_chat_in_one_session]):
|
||||
session_id = group_id
|
||||
context = self._compose_context(ContextType.TEXT, content, isgroup=True, msg=msg, receiver=group_id, session_id=session_id)
|
||||
if context:
|
||||
thread_pool.submit(self.handle, context).add_done_callback(thread_pool_callback)
|
||||
|
||||
@time_checker
|
||||
@_check
|
||||
def handle_group_voice(self, msg):
|
||||
if conf().get('group_speech_recognition', False) != True:
|
||||
return
|
||||
logger.debug("[WX]receive voice for group msg: " + msg['FileName'])
|
||||
group_name = msg['User'].get('NickName', None)
|
||||
group_id = msg['User'].get('UserName', None)
|
||||
# 验证群名
|
||||
if not group_name:
|
||||
return ""
|
||||
|
||||
config = conf()
|
||||
group_name_white_list = config.get('group_name_white_list', [])
|
||||
group_name_keyword_white_list = config.get('group_name_keyword_white_list', [])
|
||||
if any([group_name in group_name_white_list, 'ALL_GROUP' in group_name_white_list, check_contain(group_name, group_name_keyword_white_list)]):
|
||||
group_chat_in_one_session = conf().get('group_chat_in_one_session', [])
|
||||
session_id =msg['ActualUserName']
|
||||
if any([group_name in group_chat_in_one_session, 'ALL_GROUP' in group_chat_in_one_session]):
|
||||
session_id = group_id
|
||||
context = self._compose_context(ContextType.VOICE, msg['FileName'], isgroup=True, msg=msg, receiver=group_id, session_id=session_id)
|
||||
if context:
|
||||
thread_pool.submit(self.handle, context).add_done_callback(thread_pool_callback)
|
||||
|
||||
def _compose_context(self, ctype: ContextType, content, **kwargs):
|
||||
context = Context(ctype, content)
|
||||
context.kwargs = kwargs
|
||||
if 'origin_ctype' not in context:
|
||||
context['origin_ctype'] = ctype
|
||||
|
||||
if ctype == ContextType.TEXT:
|
||||
if context["isgroup"]: # 群聊
|
||||
# 校验关键字
|
||||
match_prefix = check_prefix(content, conf().get('group_chat_prefix'))
|
||||
match_contain = check_contain(content, conf().get('group_chat_keyword'))
|
||||
if match_prefix is not None or match_contain is not None:
|
||||
# 判断如果匹配到自定义前缀,则返回过滤掉前缀+空格后的内容,用于实现类似自定义+前缀触发生成AI图片的功能
|
||||
if match_prefix:
|
||||
content = content.replace(match_prefix, '', 1).strip()
|
||||
elif context['msg']['IsAt'] and not conf().get("group_at_off", False):
|
||||
logger.info("[WX]receive group at, continue")
|
||||
elif context["origin_ctype"] == ContextType.VOICE:
|
||||
logger.info("[WX]receive group voice, checkprefix didn't match")
|
||||
return None
|
||||
else:
|
||||
return None
|
||||
else: # 单聊
|
||||
match_prefix = check_prefix(content, conf().get('single_chat_prefix'))
|
||||
if match_prefix: # 判断如果匹配到自定义前缀,则返回过滤掉前缀+空格后的内容
|
||||
content = content.replace(match_prefix, '', 1).strip()
|
||||
elif context["origin_ctype"] == ContextType.VOICE: # 如果源消息是私聊的语音消息,不匹配前缀,直接返回
|
||||
pass
|
||||
else:
|
||||
return None
|
||||
img_match_prefix = check_prefix(content, conf().get('image_create_prefix'))
|
||||
if img_match_prefix:
|
||||
content = content.replace(img_match_prefix, '', 1).strip()
|
||||
context.type = ContextType.IMAGE_CREATE
|
||||
else:
|
||||
context.type = ContextType.TEXT
|
||||
context.content = content
|
||||
elif context.type == ContextType.VOICE:
|
||||
if 'desire_rtype' not in context and conf().get('voice_reply_voice'):
|
||||
context['desire_rtype'] = ReplyType.VOICE
|
||||
return context
|
||||
|
||||
# 统一的发送函数,每个Channel自行实现,根据reply的type字段发送不同类型的消息
|
||||
def send(self, reply: Reply, receiver, retry_cnt = 0):
|
||||
try:
|
||||
if reply.type == ReplyType.TEXT:
|
||||
itchat.send(reply.content, toUserName=receiver)
|
||||
logger.info('[WX] sendMsg={}, receiver={}'.format(reply, receiver))
|
||||
elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO:
|
||||
itchat.send(reply.content, toUserName=receiver)
|
||||
logger.info('[WX] sendMsg={}, receiver={}'.format(reply, receiver))
|
||||
elif reply.type == ReplyType.VOICE:
|
||||
itchat.send_file(reply.content, toUserName=receiver)
|
||||
logger.info('[WX] sendFile={}, receiver={}'.format(reply.content, receiver))
|
||||
elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片
|
||||
img_url = reply.content
|
||||
pic_res = requests.get(img_url, stream=True)
|
||||
image_storage = io.BytesIO()
|
||||
for block in pic_res.iter_content(1024):
|
||||
image_storage.write(block)
|
||||
image_storage.seek(0)
|
||||
itchat.send_image(image_storage, toUserName=receiver)
|
||||
logger.info('[WX] sendImage url={}, receiver={}'.format(img_url,receiver))
|
||||
elif reply.type == ReplyType.IMAGE: # 从文件读取图片
|
||||
image_storage = reply.content
|
||||
image_storage.seek(0)
|
||||
itchat.send_image(image_storage, toUserName=receiver)
|
||||
logger.info('[WX] sendImage, receiver={}'.format(receiver))
|
||||
except Exception as e:
|
||||
logger.error('[WX] sendMsg error: {}, receiver={}'.format(e, receiver))
|
||||
if retry_cnt < 2:
|
||||
time.sleep(3+3*retry_cnt)
|
||||
self.send(reply, receiver, retry_cnt + 1)
|
||||
|
||||
# 处理消息 TODO: 如果wechaty解耦,此处逻辑可以放置到父类
|
||||
def handle(self, context: Context):
|
||||
if context is None or not context.content:
|
||||
return
|
||||
logger.debug('[WX] ready to handle context: {}'.format(context))
|
||||
# reply的构建步骤
|
||||
reply = self._generate_reply(context)
|
||||
|
||||
logger.debug('[WX] ready to decorate reply: {}'.format(reply))
|
||||
# reply的包装步骤
|
||||
reply = self._decorate_reply(context, reply)
|
||||
|
||||
# reply的发送步骤
|
||||
self._send_reply(context, reply)
|
||||
|
||||
def _generate_reply(self, context: Context, reply: Reply = Reply()) -> Reply:
|
||||
e_context = PluginManager().emit_event(EventContext(Event.ON_HANDLE_CONTEXT, {
|
||||
'channel': self, 'context': context, 'reply': reply}))
|
||||
reply = e_context['reply']
|
||||
if not e_context.is_pass():
|
||||
logger.debug('[WX] ready to handle context: type={}, content={}'.format(context.type, context.content))
|
||||
if context.type == ContextType.TEXT or context.type == ContextType.IMAGE_CREATE: # 文字和图片消息
|
||||
reply = super().build_reply_content(context.content, context)
|
||||
elif context.type == ContextType.VOICE: # 语音消息
|
||||
msg = context['msg']
|
||||
mp3_path = TmpDir().path() + context.content
|
||||
msg.download(mp3_path)
|
||||
# mp3转wav
|
||||
wav_path = os.path.splitext(mp3_path)[0] + '.wav'
|
||||
try:
|
||||
mp3_to_wav(mp3_path=mp3_path, wav_path=wav_path)
|
||||
except Exception as e: # 转换失败,直接使用mp3,对于某些api,mp3也可以识别
|
||||
logger.warning("[WX]mp3 to wav error, use mp3 path. " + str(e))
|
||||
wav_path = mp3_path
|
||||
# 语音识别
|
||||
reply = super().build_voice_to_text(wav_path)
|
||||
# 删除临时文件
|
||||
try:
|
||||
os.remove(wav_path)
|
||||
os.remove(mp3_path)
|
||||
except Exception as e:
|
||||
logger.warning("[WX]delete temp file error: " + str(e))
|
||||
|
||||
if reply.type == ReplyType.TEXT:
|
||||
new_context = self._compose_context(
|
||||
ContextType.TEXT, reply.content, **context.kwargs)
|
||||
if new_context:
|
||||
reply = self._generate_reply(new_context)
|
||||
else:
|
||||
return
|
||||
else:
|
||||
logger.error('[WX] unknown context type: {}'.format(context.type))
|
||||
if cmsg.ctype == ContextType.VOICE:
|
||||
if conf().get("speech_recognition") != True:
|
||||
return
|
||||
return reply
|
||||
logger.debug("[WX]receive voice msg: {}".format(cmsg.content))
|
||||
elif cmsg.ctype == ContextType.IMAGE:
|
||||
logger.debug("[WX]receive image msg: {}".format(cmsg.content))
|
||||
elif cmsg.ctype == ContextType.PATPAT:
|
||||
logger.debug("[WX]receive patpat msg: {}".format(cmsg.content))
|
||||
elif cmsg.ctype == ContextType.TEXT:
|
||||
logger.debug("[WX]receive text msg: {}, cmsg={}".format(json.dumps(cmsg._rawmsg, ensure_ascii=False), cmsg))
|
||||
else:
|
||||
logger.debug("[WX]receive msg: {}, cmsg={}".format(cmsg.content, cmsg))
|
||||
context = self._compose_context(cmsg.ctype, cmsg.content, isgroup=False, msg=cmsg)
|
||||
if context:
|
||||
self.produce(context)
|
||||
|
||||
def _decorate_reply(self, context: Context, reply: Reply) -> Reply:
|
||||
if reply and reply.type:
|
||||
e_context = PluginManager().emit_event(EventContext(Event.ON_DECORATE_REPLY, {
|
||||
'channel': self, 'context': context, 'reply': reply}))
|
||||
reply = e_context['reply']
|
||||
desire_rtype = context.get('desire_rtype')
|
||||
if not e_context.is_pass() and reply and reply.type:
|
||||
if reply.type == ReplyType.TEXT:
|
||||
reply_text = reply.content
|
||||
if desire_rtype == ReplyType.VOICE:
|
||||
reply = super().build_text_to_voice(reply.content)
|
||||
return self._decorate_reply(context, reply)
|
||||
if context['isgroup']:
|
||||
reply_text = '@' + context['msg']['ActualNickName'] + ' ' + reply_text.strip()
|
||||
reply_text = conf().get("group_chat_reply_prefix", "")+reply_text
|
||||
else:
|
||||
reply_text = conf().get("single_chat_reply_prefix", "")+reply_text
|
||||
reply.content = reply_text
|
||||
elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO:
|
||||
reply.content = str(reply.type)+":\n" + reply.content
|
||||
elif reply.type == ReplyType.IMAGE_URL or reply.type == ReplyType.VOICE or reply.type == ReplyType.IMAGE:
|
||||
pass
|
||||
else:
|
||||
logger.error('[WX] unknown reply type: {}'.format(reply.type))
|
||||
return
|
||||
if desire_rtype and desire_rtype != reply.type and reply.type not in [ReplyType.ERROR, ReplyType.INFO]:
|
||||
logger.warning('[WX] desire_rtype: {}, but reply type: {}'.format(context.get('desire_rtype'), reply.type))
|
||||
return reply
|
||||
@time_checker
|
||||
@_check
|
||||
def handle_group(self, cmsg: ChatMessage):
|
||||
if cmsg.ctype == ContextType.VOICE:
|
||||
if conf().get("group_speech_recognition") != True:
|
||||
return
|
||||
logger.debug("[WX]receive voice for group msg: {}".format(cmsg.content))
|
||||
elif cmsg.ctype == ContextType.IMAGE:
|
||||
logger.debug("[WX]receive image for group msg: {}".format(cmsg.content))
|
||||
elif cmsg.ctype in [ContextType.JOIN_GROUP, ContextType.PATPAT, ContextType.ACCEPT_FRIEND, ContextType.EXIT_GROUP]:
|
||||
logger.debug("[WX]receive note msg: {}".format(cmsg.content))
|
||||
elif cmsg.ctype == ContextType.TEXT:
|
||||
# logger.debug("[WX]receive group msg: {}, cmsg={}".format(json.dumps(cmsg._rawmsg, ensure_ascii=False), cmsg))
|
||||
pass
|
||||
elif cmsg.ctype == ContextType.FILE:
|
||||
logger.debug(f"[WX]receive attachment msg, file_name={cmsg.content}")
|
||||
else:
|
||||
logger.debug("[WX]receive group msg: {}".format(cmsg.content))
|
||||
context = self._compose_context(cmsg.ctype, cmsg.content, isgroup=True, msg=cmsg)
|
||||
if context:
|
||||
self.produce(context)
|
||||
|
||||
def _send_reply(self, context: Context, reply: Reply):
|
||||
if reply and reply.type:
|
||||
e_context = PluginManager().emit_event(EventContext(Event.ON_SEND_REPLY, {
|
||||
'channel': self, 'context': context, 'reply': reply}))
|
||||
reply = e_context['reply']
|
||||
if not e_context.is_pass() and reply and reply.type:
|
||||
logger.debug('[WX] ready to send reply: {} to {}'.format(reply, context['receiver']))
|
||||
self.send(reply, context['receiver'])
|
||||
|
||||
|
||||
def check_prefix(content, prefix_list):
|
||||
for prefix in prefix_list:
|
||||
if content.startswith(prefix):
|
||||
return prefix
|
||||
return None
|
||||
|
||||
def check_contain(content, keyword_list):
|
||||
if not keyword_list:
|
||||
return None
|
||||
for ky in keyword_list:
|
||||
if content.find(ky) != -1:
|
||||
return True
|
||||
return None
|
||||
# 统一的发送函数,每个Channel自行实现,根据reply的type字段发送不同类型的消息
|
||||
def send(self, reply: Reply, context: Context):
|
||||
receiver = context["receiver"]
|
||||
if reply.type == ReplyType.TEXT:
|
||||
itchat.send(reply.content, toUserName=receiver)
|
||||
logger.info("[WX] sendMsg={}, receiver={}".format(reply, receiver))
|
||||
elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO:
|
||||
itchat.send(reply.content, toUserName=receiver)
|
||||
logger.info("[WX] sendMsg={}, receiver={}".format(reply, receiver))
|
||||
elif reply.type == ReplyType.VOICE:
|
||||
itchat.send_file(reply.content, toUserName=receiver)
|
||||
logger.info("[WX] sendFile={}, receiver={}".format(reply.content, receiver))
|
||||
elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片
|
||||
img_url = reply.content
|
||||
logger.debug(f"[WX] start download image, img_url={img_url}")
|
||||
pic_res = requests.get(img_url, stream=True)
|
||||
image_storage = io.BytesIO()
|
||||
size = 0
|
||||
for block in pic_res.iter_content(1024):
|
||||
size += len(block)
|
||||
image_storage.write(block)
|
||||
logger.info(f"[WX] download image success, size={size}, img_url={img_url}")
|
||||
image_storage.seek(0)
|
||||
itchat.send_image(image_storage, toUserName=receiver)
|
||||
logger.info("[WX] sendImage url={}, receiver={}".format(img_url, receiver))
|
||||
elif reply.type == ReplyType.IMAGE: # 从文件读取图片
|
||||
image_storage = reply.content
|
||||
image_storage.seek(0)
|
||||
itchat.send_image(image_storage, toUserName=receiver)
|
||||
logger.info("[WX] sendImage, receiver={}".format(receiver))
|
||||
elif reply.type == ReplyType.FILE: # 新增文件回复类型
|
||||
file_storage = reply.content
|
||||
itchat.send_file(file_storage, toUserName=receiver)
|
||||
logger.info("[WX] sendFile, receiver={}".format(receiver))
|
||||
elif reply.type == ReplyType.VIDEO: # 新增视频回复类型
|
||||
video_storage = reply.content
|
||||
itchat.send_video(video_storage, toUserName=receiver)
|
||||
logger.info("[WX] sendFile, receiver={}".format(receiver))
|
||||
elif reply.type == ReplyType.VIDEO_URL: # 新增视频URL回复类型
|
||||
video_url = reply.content
|
||||
logger.debug(f"[WX] start download video, video_url={video_url}")
|
||||
video_res = requests.get(video_url, stream=True)
|
||||
video_storage = io.BytesIO()
|
||||
size = 0
|
||||
for block in video_res.iter_content(1024):
|
||||
size += len(block)
|
||||
video_storage.write(block)
|
||||
logger.info(f"[WX] download video success, size={size}, video_url={video_url}")
|
||||
video_storage.seek(0)
|
||||
itchat.send_video(video_storage, toUserName=receiver)
|
||||
logger.info("[WX] sendVideo url={}, receiver={}".format(video_url, receiver))
|
||||
|
||||
102
channel/wechat/wechat_message.py
Normal file
102
channel/wechat/wechat_message.py
Normal file
@@ -0,0 +1,102 @@
|
||||
import re
|
||||
|
||||
from bridge.context import ContextType
|
||||
from channel.chat_message import ChatMessage
|
||||
from common.log import logger
|
||||
from common.tmp_dir import TmpDir
|
||||
from lib import itchat
|
||||
from lib.itchat.content import *
|
||||
|
||||
class WechatMessage(ChatMessage):
|
||||
def __init__(self, itchat_msg, is_group=False):
|
||||
super().__init__(itchat_msg)
|
||||
self.msg_id = itchat_msg["MsgId"]
|
||||
self.create_time = itchat_msg["CreateTime"]
|
||||
self.is_group = is_group
|
||||
|
||||
if itchat_msg["Type"] == TEXT:
|
||||
self.ctype = ContextType.TEXT
|
||||
self.content = itchat_msg["Text"]
|
||||
elif itchat_msg["Type"] == VOICE:
|
||||
self.ctype = ContextType.VOICE
|
||||
self.content = TmpDir().path() + itchat_msg["FileName"] # content直接存临时目录路径
|
||||
self._prepare_fn = lambda: itchat_msg.download(self.content)
|
||||
elif itchat_msg["Type"] == PICTURE and itchat_msg["MsgType"] == 3:
|
||||
self.ctype = ContextType.IMAGE
|
||||
self.content = TmpDir().path() + itchat_msg["FileName"] # content直接存临时目录路径
|
||||
self._prepare_fn = lambda: itchat_msg.download(self.content)
|
||||
elif itchat_msg["Type"] == NOTE and itchat_msg["MsgType"] == 10000:
|
||||
if is_group and ("加入群聊" in itchat_msg["Content"] or "加入了群聊" in itchat_msg["Content"]):
|
||||
# 这里只能得到nickname, actual_user_id还是机器人的id
|
||||
if "加入了群聊" in itchat_msg["Content"]:
|
||||
self.ctype = ContextType.JOIN_GROUP
|
||||
self.content = itchat_msg["Content"]
|
||||
self.actual_user_nickname = re.findall(r"\"(.*?)\"", itchat_msg["Content"])[-1]
|
||||
elif "加入群聊" in itchat_msg["Content"]:
|
||||
self.ctype = ContextType.JOIN_GROUP
|
||||
self.content = itchat_msg["Content"]
|
||||
self.actual_user_nickname = re.findall(r"\"(.*?)\"", itchat_msg["Content"])[0]
|
||||
|
||||
elif is_group and ("移出了群聊" in itchat_msg["Content"]):
|
||||
self.ctype = ContextType.EXIT_GROUP
|
||||
self.content = itchat_msg["Content"]
|
||||
self.actual_user_nickname = re.findall(r"\"(.*?)\"", itchat_msg["Content"])[0]
|
||||
|
||||
elif "你已添加了" in itchat_msg["Content"]: #通过好友请求
|
||||
self.ctype = ContextType.ACCEPT_FRIEND
|
||||
self.content = itchat_msg["Content"]
|
||||
elif "拍了拍我" in itchat_msg["Content"]:
|
||||
self.ctype = ContextType.PATPAT
|
||||
self.content = itchat_msg["Content"]
|
||||
if is_group:
|
||||
self.actual_user_nickname = re.findall(r"\"(.*?)\"", itchat_msg["Content"])[0]
|
||||
else:
|
||||
raise NotImplementedError("Unsupported note message: " + itchat_msg["Content"])
|
||||
elif itchat_msg["Type"] == ATTACHMENT:
|
||||
self.ctype = ContextType.FILE
|
||||
self.content = TmpDir().path() + itchat_msg["FileName"] # content直接存临时目录路径
|
||||
self._prepare_fn = lambda: itchat_msg.download(self.content)
|
||||
elif itchat_msg["Type"] == SHARING:
|
||||
self.ctype = ContextType.SHARING
|
||||
self.content = itchat_msg.get("Url")
|
||||
|
||||
else:
|
||||
raise NotImplementedError("Unsupported message type: Type:{} MsgType:{}".format(itchat_msg["Type"], itchat_msg["MsgType"]))
|
||||
|
||||
self.from_user_id = itchat_msg["FromUserName"]
|
||||
self.to_user_id = itchat_msg["ToUserName"]
|
||||
|
||||
user_id = itchat.instance.storageClass.userName
|
||||
nickname = itchat.instance.storageClass.nickName
|
||||
|
||||
# 虽然from_user_id和to_user_id用的少,但是为了保持一致性,还是要填充一下
|
||||
# 以下很繁琐,一句话总结:能填的都填了。
|
||||
if self.from_user_id == user_id:
|
||||
self.from_user_nickname = nickname
|
||||
if self.to_user_id == user_id:
|
||||
self.to_user_nickname = nickname
|
||||
try: # 陌生人时候, User字段可能不存在
|
||||
# my_msg 为True是表示是自己发送的消息
|
||||
self.my_msg = itchat_msg["ToUserName"] == itchat_msg["User"]["UserName"] and \
|
||||
itchat_msg["ToUserName"] != itchat_msg["FromUserName"]
|
||||
self.other_user_id = itchat_msg["User"]["UserName"]
|
||||
self.other_user_nickname = itchat_msg["User"]["NickName"]
|
||||
if self.other_user_id == self.from_user_id:
|
||||
self.from_user_nickname = self.other_user_nickname
|
||||
if self.other_user_id == self.to_user_id:
|
||||
self.to_user_nickname = self.other_user_nickname
|
||||
if itchat_msg["User"].get("Self"):
|
||||
# 自身的展示名,当设置了群昵称时,该字段表示群昵称
|
||||
self.self_display_name = itchat_msg["User"].get("Self").get("DisplayName")
|
||||
except KeyError as e: # 处理偶尔没有对方信息的情况
|
||||
logger.warn("[WX]get other_user_id failed: " + str(e))
|
||||
if self.from_user_id == user_id:
|
||||
self.other_user_id = self.to_user_id
|
||||
else:
|
||||
self.other_user_id = self.from_user_id
|
||||
|
||||
if self.is_group:
|
||||
self.is_at = itchat_msg["IsAt"]
|
||||
self.actual_user_id = itchat_msg["ActualUserName"]
|
||||
if self.ctype not in [ContextType.JOIN_GROUP, ContextType.PATPAT, ContextType.EXIT_GROUP]:
|
||||
self.actual_user_nickname = itchat_msg["ActualNickName"]
|
||||
@@ -4,336 +4,126 @@
|
||||
wechaty channel
|
||||
Python Wechaty - https://github.com/wechaty/python-wechaty
|
||||
"""
|
||||
import asyncio
|
||||
import base64
|
||||
import os
|
||||
import time
|
||||
import asyncio
|
||||
from typing import Optional, Union
|
||||
from bridge.context import Context, ContextType
|
||||
from wechaty_puppet import MessageType, FileBox, ScanStatus # type: ignore
|
||||
from wechaty import Wechaty, Contact
|
||||
from wechaty.user import Message, MiniProgram, UrlLink
|
||||
from channel.channel import Channel
|
||||
from common.log import logger
|
||||
from common.tmp_dir import TmpDir
|
||||
from config import conf
|
||||
from voice.audio_convert import sil_to_wav, mp3_to_sil
|
||||
|
||||
class WechatyChannel(Channel):
|
||||
from wechaty import Contact, Wechaty
|
||||
from wechaty.user import Message
|
||||
from wechaty_puppet import FileBox
|
||||
|
||||
from bridge.context import *
|
||||
from bridge.context import Context
|
||||
from bridge.reply import *
|
||||
from channel.chat_channel import ChatChannel
|
||||
from channel.wechat.wechaty_message import WechatyMessage
|
||||
from common.log import logger
|
||||
from common.singleton import singleton
|
||||
from config import conf
|
||||
|
||||
try:
|
||||
from voice.audio_convert import any_to_sil
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
|
||||
@singleton
|
||||
class WechatyChannel(ChatChannel):
|
||||
NOT_SUPPORT_REPLYTYPE = []
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
super().__init__()
|
||||
|
||||
def startup(self):
|
||||
config = conf()
|
||||
token = config.get("wechaty_puppet_service_token")
|
||||
os.environ["WECHATY_PUPPET_SERVICE_TOKEN"] = token
|
||||
asyncio.run(self.main())
|
||||
|
||||
async def main(self):
|
||||
config = conf()
|
||||
# 使用PadLocal协议 比较稳定(免费web协议 os.environ['WECHATY_PUPPET_SERVICE_ENDPOINT'] = '127.0.0.1:8080')
|
||||
token = config.get('wechaty_puppet_service_token')
|
||||
os.environ['WECHATY_PUPPET_SERVICE_TOKEN'] = token
|
||||
global bot
|
||||
bot = Wechaty()
|
||||
|
||||
bot.on('scan', self.on_scan)
|
||||
bot.on('login', self.on_login)
|
||||
bot.on('message', self.on_message)
|
||||
await bot.start()
|
||||
loop = asyncio.get_event_loop()
|
||||
# 将asyncio的loop传入处理线程
|
||||
self.handler_pool._initializer = lambda: asyncio.set_event_loop(loop)
|
||||
self.bot = Wechaty()
|
||||
self.bot.on("login", self.on_login)
|
||||
self.bot.on("message", self.on_message)
|
||||
await self.bot.start()
|
||||
|
||||
async def on_login(self, contact: Contact):
|
||||
logger.info('[WX] login user={}'.format(contact))
|
||||
self.user_id = contact.contact_id
|
||||
self.name = contact.name
|
||||
logger.info("[WX] login user={}".format(contact))
|
||||
|
||||
async def on_scan(self, status: ScanStatus, qr_code: Optional[str] = None,
|
||||
data: Optional[str] = None):
|
||||
pass
|
||||
# contact = self.Contact.load(self.contact_id)
|
||||
# logger.info('[WX] scan user={}, scan status={}, scan qr_code={}'.format(contact, status.name, qr_code))
|
||||
# print(f'user <{contact}> scan status: {status.name} , 'f'qr_code: {qr_code}')
|
||||
# 统一的发送函数,每个Channel自行实现,根据reply的type字段发送不同类型的消息
|
||||
def send(self, reply: Reply, context: Context):
|
||||
receiver_id = context["receiver"]
|
||||
loop = asyncio.get_event_loop()
|
||||
if context["isgroup"]:
|
||||
receiver = asyncio.run_coroutine_threadsafe(self.bot.Room.find(receiver_id), loop).result()
|
||||
else:
|
||||
receiver = asyncio.run_coroutine_threadsafe(self.bot.Contact.find(receiver_id), loop).result()
|
||||
msg = None
|
||||
if reply.type == ReplyType.TEXT:
|
||||
msg = reply.content
|
||||
asyncio.run_coroutine_threadsafe(receiver.say(msg), loop).result()
|
||||
logger.info("[WX] sendMsg={}, receiver={}".format(reply, receiver))
|
||||
elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO:
|
||||
msg = reply.content
|
||||
asyncio.run_coroutine_threadsafe(receiver.say(msg), loop).result()
|
||||
logger.info("[WX] sendMsg={}, receiver={}".format(reply, receiver))
|
||||
elif reply.type == ReplyType.VOICE:
|
||||
voiceLength = None
|
||||
file_path = reply.content
|
||||
sil_file = os.path.splitext(file_path)[0] + ".sil"
|
||||
voiceLength = int(any_to_sil(file_path, sil_file))
|
||||
if voiceLength >= 60000:
|
||||
voiceLength = 60000
|
||||
logger.info("[WX] voice too long, length={}, set to 60s".format(voiceLength))
|
||||
# 发送语音
|
||||
t = int(time.time())
|
||||
msg = FileBox.from_file(sil_file, name=str(t) + ".sil")
|
||||
if voiceLength is not None:
|
||||
msg.metadata["voiceLength"] = voiceLength
|
||||
asyncio.run_coroutine_threadsafe(receiver.say(msg), loop).result()
|
||||
try:
|
||||
os.remove(file_path)
|
||||
if sil_file != file_path:
|
||||
os.remove(sil_file)
|
||||
except Exception as e:
|
||||
pass
|
||||
logger.info("[WX] sendVoice={}, receiver={}".format(reply.content, receiver))
|
||||
elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片
|
||||
img_url = reply.content
|
||||
t = int(time.time())
|
||||
msg = FileBox.from_url(url=img_url, name=str(t) + ".png")
|
||||
asyncio.run_coroutine_threadsafe(receiver.say(msg), loop).result()
|
||||
logger.info("[WX] sendImage url={}, receiver={}".format(img_url, receiver))
|
||||
elif reply.type == ReplyType.IMAGE: # 从文件读取图片
|
||||
image_storage = reply.content
|
||||
image_storage.seek(0)
|
||||
t = int(time.time())
|
||||
msg = FileBox.from_base64(base64.b64encode(image_storage.read()), str(t) + ".png")
|
||||
asyncio.run_coroutine_threadsafe(receiver.say(msg), loop).result()
|
||||
logger.info("[WX] sendImage, receiver={}".format(receiver))
|
||||
|
||||
async def on_message(self, msg: Message):
|
||||
"""
|
||||
listen for message event
|
||||
"""
|
||||
from_contact = msg.talker() # 获取消息的发送者
|
||||
to_contact = msg.to() # 接收人
|
||||
try:
|
||||
cmsg = await WechatyMessage(msg)
|
||||
except NotImplementedError as e:
|
||||
logger.debug("[WX] {}".format(e))
|
||||
return
|
||||
except Exception as e:
|
||||
logger.exception("[WX] {}".format(e))
|
||||
return
|
||||
logger.debug("[WX] message:{}".format(cmsg))
|
||||
room = msg.room() # 获取消息来自的群聊. 如果消息不是来自群聊, 则返回None
|
||||
from_user_id = from_contact.contact_id
|
||||
to_user_id = to_contact.contact_id # 接收人id
|
||||
# other_user_id = msg['User']['UserName'] # 对手方id
|
||||
content = msg.text()
|
||||
mention_content = await msg.mention_text() # 返回过滤掉@name后的消息
|
||||
match_prefix = self.check_prefix(content, conf().get('single_chat_prefix'))
|
||||
# conversation: Union[Room, Contact] = from_contact if room is None else room
|
||||
|
||||
if room is None and msg.type() == MessageType.MESSAGE_TYPE_TEXT:
|
||||
if not msg.is_self() and match_prefix is not None:
|
||||
# 好友向自己发送消息
|
||||
if match_prefix != '':
|
||||
str_list = content.split(match_prefix, 1)
|
||||
if len(str_list) == 2:
|
||||
content = str_list[1].strip()
|
||||
|
||||
img_match_prefix = self.check_prefix(content, conf().get('image_create_prefix'))
|
||||
if img_match_prefix:
|
||||
content = content.split(img_match_prefix, 1)[1].strip()
|
||||
await self._do_send_img(content, from_user_id)
|
||||
else:
|
||||
await self._do_send(content, from_user_id)
|
||||
elif msg.is_self() and match_prefix:
|
||||
# 自己给好友发送消息
|
||||
str_list = content.split(match_prefix, 1)
|
||||
if len(str_list) == 2:
|
||||
content = str_list[1].strip()
|
||||
img_match_prefix = self.check_prefix(content, conf().get('image_create_prefix'))
|
||||
if img_match_prefix:
|
||||
content = content.split(img_match_prefix, 1)[1].strip()
|
||||
await self._do_send_img(content, to_user_id)
|
||||
else:
|
||||
await self._do_send(content, to_user_id)
|
||||
elif room is None and msg.type() == MessageType.MESSAGE_TYPE_AUDIO:
|
||||
if not msg.is_self(): # 接收语音消息
|
||||
# 下载语音文件
|
||||
voice_file = await msg.to_file_box()
|
||||
silk_file = TmpDir().path() + voice_file.name
|
||||
await voice_file.to_file(silk_file)
|
||||
logger.info("[WX]receive voice file: " + silk_file)
|
||||
# 将文件转成wav格式音频
|
||||
wav_file = os.path.splitext(silk_file)[0] + '.wav'
|
||||
sil_to_wav(silk_file, wav_file)
|
||||
# 语音识别为文本
|
||||
query = super().build_voice_to_text(wav_file).content
|
||||
# 交验关键字
|
||||
match_prefix = self.check_prefix(query, conf().get('single_chat_prefix'))
|
||||
if match_prefix is not None:
|
||||
if match_prefix != '':
|
||||
str_list = query.split(match_prefix, 1)
|
||||
if len(str_list) == 2:
|
||||
query = str_list[1].strip()
|
||||
# 返回消息
|
||||
if conf().get('voice_reply_voice'):
|
||||
await self._do_send_voice(query, from_user_id)
|
||||
else:
|
||||
await self._do_send(query, from_user_id)
|
||||
else:
|
||||
logger.info("[WX]receive voice check prefix: " + 'False')
|
||||
# 清除缓存文件
|
||||
os.remove(wav_file)
|
||||
os.remove(silk_file)
|
||||
elif room and msg.type() == MessageType.MESSAGE_TYPE_TEXT:
|
||||
# 群组&文本消息
|
||||
room_id = room.room_id
|
||||
room_name = await room.topic()
|
||||
from_user_id = from_contact.contact_id
|
||||
from_user_name = from_contact.name
|
||||
is_at = await msg.mention_self()
|
||||
content = mention_content
|
||||
config = conf()
|
||||
match_prefix = (is_at and not config.get("group_at_off", False)) \
|
||||
or self.check_prefix(content, config.get('group_chat_prefix')) \
|
||||
or self.check_contain(content, config.get('group_chat_keyword'))
|
||||
# Wechaty判断is_at为True,返回的内容是过滤掉@之后的内容;而is_at为False,则会返回完整的内容
|
||||
# 故判断如果匹配到自定义前缀,则返回过滤掉前缀+空格后的内容,用于实现类似自定义+前缀触发生成AI图片的功能
|
||||
prefixes = config.get('group_chat_prefix')
|
||||
for prefix in prefixes:
|
||||
if content.startswith(prefix):
|
||||
content = content.replace(prefix, '', 1).strip()
|
||||
break
|
||||
if ('ALL_GROUP' in config.get('group_name_white_list') or room_name in config.get(
|
||||
'group_name_white_list') or self.check_contain(room_name, config.get(
|
||||
'group_name_keyword_white_list'))) and match_prefix:
|
||||
img_match_prefix = self.check_prefix(content, conf().get('image_create_prefix'))
|
||||
if img_match_prefix:
|
||||
content = content.split(img_match_prefix, 1)[1].strip()
|
||||
await self._do_send_group_img(content, room_id)
|
||||
else:
|
||||
await self._do_send_group(content, room_id, room_name, from_user_id, from_user_name)
|
||||
elif room and msg.type() == MessageType.MESSAGE_TYPE_AUDIO:
|
||||
# 群组&语音消息
|
||||
room_id = room.room_id
|
||||
room_name = await room.topic()
|
||||
from_user_id = from_contact.contact_id
|
||||
from_user_name = from_contact.name
|
||||
is_at = await msg.mention_self()
|
||||
config = conf()
|
||||
# 是否开启语音识别、群消息响应功能、群名白名单符合等条件
|
||||
if config.get('group_speech_recognition') and (
|
||||
'ALL_GROUP' in config.get('group_name_white_list') or room_name in config.get(
|
||||
'group_name_white_list') or self.check_contain(room_name, config.get(
|
||||
'group_name_keyword_white_list'))):
|
||||
# 下载语音文件
|
||||
voice_file = await msg.to_file_box()
|
||||
silk_file = TmpDir().path() + voice_file.name
|
||||
await voice_file.to_file(silk_file)
|
||||
logger.info("[WX]receive voice file: " + silk_file)
|
||||
# 将文件转成wav格式音频
|
||||
wav_file = os.path.splitext(silk_file)[0] + '.wav'
|
||||
sil_to_wav(silk_file, wav_file)
|
||||
# 语音识别为文本
|
||||
query = super().build_voice_to_text(wav_file).content
|
||||
# 校验关键字
|
||||
match_prefix = self.check_prefix(query, config.get('group_chat_prefix')) \
|
||||
or self.check_contain(query, config.get('group_chat_keyword'))
|
||||
# Wechaty判断is_at为True,返回的内容是过滤掉@之后的内容;而is_at为False,则会返回完整的内容
|
||||
if match_prefix is not None:
|
||||
# 故判断如果匹配到自定义前缀,则返回过滤掉前缀+空格后的内容,用于实现类似自定义+前缀触发生成AI图片的功能
|
||||
prefixes = config.get('group_chat_prefix')
|
||||
for prefix in prefixes:
|
||||
if query.startswith(prefix):
|
||||
query = query.replace(prefix, '', 1).strip()
|
||||
break
|
||||
# 返回消息
|
||||
img_match_prefix = self.check_prefix(query, conf().get('image_create_prefix'))
|
||||
if img_match_prefix:
|
||||
query = query.split(img_match_prefix, 1)[1].strip()
|
||||
await self._do_send_group_img(query, room_id)
|
||||
elif config.get('voice_reply_voice'):
|
||||
await self._do_send_group_voice(query, room_id, room_name, from_user_id, from_user_name)
|
||||
else:
|
||||
await self._do_send_group(query, room_id, room_name, from_user_id, from_user_name)
|
||||
else:
|
||||
logger.info("[WX]receive voice check prefix: " + 'False')
|
||||
# 清除缓存文件
|
||||
os.remove(wav_file)
|
||||
os.remove(silk_file)
|
||||
|
||||
async def send(self, message: Union[str, Message, FileBox, Contact, UrlLink, MiniProgram], receiver):
|
||||
logger.info('[WX] sendMsg={}, receiver={}'.format(message, receiver))
|
||||
if receiver:
|
||||
contact = await bot.Contact.find(receiver)
|
||||
await contact.say(message)
|
||||
|
||||
async def send_group(self, message: Union[str, Message, FileBox, Contact, UrlLink, MiniProgram], receiver):
|
||||
logger.info('[WX] sendMsg={}, receiver={}'.format(message, receiver))
|
||||
if receiver:
|
||||
room = await bot.Room.find(receiver)
|
||||
await room.say(message)
|
||||
|
||||
async def _do_send(self, query, reply_user_id):
|
||||
try:
|
||||
if not query:
|
||||
return
|
||||
context = Context(ContextType.TEXT, query)
|
||||
context['session_id'] = reply_user_id
|
||||
reply_text = super().build_reply_content(query, context).content
|
||||
if reply_text:
|
||||
await self.send(conf().get("single_chat_reply_prefix") + reply_text, reply_user_id)
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
|
||||
async def _do_send_voice(self, query, reply_user_id):
|
||||
try:
|
||||
if not query:
|
||||
return
|
||||
context = Context(ContextType.TEXT, query)
|
||||
context['session_id'] = reply_user_id
|
||||
reply_text = super().build_reply_content(query, context).content
|
||||
if reply_text:
|
||||
# 转换 mp3 文件为 silk 格式
|
||||
mp3_file = super().build_text_to_voice(reply_text).content
|
||||
silk_file = os.path.splitext(mp3_file)[0] + '.sil'
|
||||
voiceLength = mp3_to_sil(mp3_file, silk_file)
|
||||
# 发送语音
|
||||
t = int(time.time())
|
||||
file_box = FileBox.from_file(silk_file, name=str(t) + '.sil')
|
||||
file_box.metadata = {'voiceLength': voiceLength}
|
||||
await self.send(file_box, reply_user_id)
|
||||
# 清除缓存文件
|
||||
os.remove(mp3_file)
|
||||
os.remove(silk_file)
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
|
||||
async def _do_send_img(self, query, reply_user_id):
|
||||
try:
|
||||
if not query:
|
||||
return
|
||||
context = Context(ContextType.IMAGE_CREATE, query)
|
||||
img_url = super().build_reply_content(query, context).content
|
||||
if not img_url:
|
||||
return
|
||||
# 图片下载
|
||||
# pic_res = requests.get(img_url, stream=True)
|
||||
# image_storage = io.BytesIO()
|
||||
# for block in pic_res.iter_content(1024):
|
||||
# image_storage.write(block)
|
||||
# image_storage.seek(0)
|
||||
|
||||
# 图片发送
|
||||
logger.info('[WX] sendImage, receiver={}'.format(reply_user_id))
|
||||
t = int(time.time())
|
||||
file_box = FileBox.from_url(url=img_url, name=str(t) + '.png')
|
||||
await self.send(file_box, reply_user_id)
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
|
||||
async def _do_send_group(self, query, group_id, group_name, group_user_id, group_user_name):
|
||||
if not query:
|
||||
return
|
||||
context = Context(ContextType.TEXT, query)
|
||||
group_chat_in_one_session = conf().get('group_chat_in_one_session', [])
|
||||
if ('ALL_GROUP' in group_chat_in_one_session or \
|
||||
group_name in group_chat_in_one_session or \
|
||||
self.check_contain(group_name, group_chat_in_one_session)):
|
||||
context['session_id'] = str(group_id)
|
||||
else:
|
||||
context['session_id'] = str(group_id) + '-' + str(group_user_id)
|
||||
reply_text = super().build_reply_content(query, context).content
|
||||
if reply_text:
|
||||
reply_text = '@' + group_user_name + ' ' + reply_text.strip()
|
||||
await self.send_group(conf().get("group_chat_reply_prefix", "") + reply_text, group_id)
|
||||
|
||||
async def _do_send_group_voice(self, query, group_id, group_name, group_user_id, group_user_name):
|
||||
if not query:
|
||||
return
|
||||
context = Context(ContextType.TEXT, query)
|
||||
group_chat_in_one_session = conf().get('group_chat_in_one_session', [])
|
||||
if ('ALL_GROUP' in group_chat_in_one_session or \
|
||||
group_name in group_chat_in_one_session or \
|
||||
self.check_contain(group_name, group_chat_in_one_session)):
|
||||
context['session_id'] = str(group_id)
|
||||
else:
|
||||
context['session_id'] = str(group_id) + '-' + str(group_user_id)
|
||||
reply_text = super().build_reply_content(query, context).content
|
||||
if reply_text:
|
||||
reply_text = '@' + group_user_name + ' ' + reply_text.strip()
|
||||
# 转换 mp3 文件为 silk 格式
|
||||
mp3_file = super().build_text_to_voice(reply_text).content
|
||||
silk_file = os.path.splitext(mp3_file)[0] + '.sil'
|
||||
voiceLength = mp3_to_sil(mp3_file, silk_file)
|
||||
# 发送语音
|
||||
t = int(time.time())
|
||||
file_box = FileBox.from_file(silk_file, name=str(t) + '.silk')
|
||||
file_box.metadata = {'voiceLength': voiceLength}
|
||||
await self.send_group(file_box, group_id)
|
||||
# 清除缓存文件
|
||||
os.remove(mp3_file)
|
||||
os.remove(silk_file)
|
||||
|
||||
async def _do_send_group_img(self, query, reply_room_id):
|
||||
try:
|
||||
if not query:
|
||||
return
|
||||
context = Context(ContextType.IMAGE_CREATE, query)
|
||||
img_url = super().build_reply_content(query, context).content
|
||||
if not img_url:
|
||||
return
|
||||
# 图片发送
|
||||
logger.info('[WX] sendImage, receiver={}'.format(reply_room_id))
|
||||
t = int(time.time())
|
||||
file_box = FileBox.from_url(url=img_url, name=str(t) + '.png')
|
||||
await self.send_group(file_box, reply_room_id)
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
|
||||
def check_prefix(self, content, prefix_list):
|
||||
for prefix in prefix_list:
|
||||
if content.startswith(prefix):
|
||||
return prefix
|
||||
return None
|
||||
|
||||
def check_contain(self, content, keyword_list):
|
||||
if not keyword_list:
|
||||
return None
|
||||
for ky in keyword_list:
|
||||
if content.find(ky) != -1:
|
||||
return True
|
||||
return None
|
||||
isgroup = room is not None
|
||||
ctype = cmsg.ctype
|
||||
context = self._compose_context(ctype, cmsg.content, isgroup=isgroup, msg=cmsg)
|
||||
if context:
|
||||
logger.info("[WX] receiveMsg={}, context={}".format(cmsg, context))
|
||||
self.produce(context)
|
||||
|
||||
89
channel/wechat/wechaty_message.py
Normal file
89
channel/wechat/wechaty_message.py
Normal file
@@ -0,0 +1,89 @@
|
||||
import asyncio
|
||||
import re
|
||||
|
||||
from wechaty import MessageType
|
||||
from wechaty.user import Message
|
||||
|
||||
from bridge.context import ContextType
|
||||
from channel.chat_message import ChatMessage
|
||||
from common.log import logger
|
||||
from common.tmp_dir import TmpDir
|
||||
|
||||
|
||||
class aobject(object):
|
||||
"""Inheriting this class allows you to define an async __init__.
|
||||
|
||||
So you can create objects by doing something like `await MyClass(params)`
|
||||
"""
|
||||
|
||||
async def __new__(cls, *a, **kw):
|
||||
instance = super().__new__(cls)
|
||||
await instance.__init__(*a, **kw)
|
||||
return instance
|
||||
|
||||
async def __init__(self):
|
||||
pass
|
||||
|
||||
|
||||
class WechatyMessage(ChatMessage, aobject):
|
||||
async def __init__(self, wechaty_msg: Message):
|
||||
super().__init__(wechaty_msg)
|
||||
|
||||
room = wechaty_msg.room()
|
||||
|
||||
self.msg_id = wechaty_msg.message_id
|
||||
self.create_time = wechaty_msg.payload.timestamp
|
||||
self.is_group = room is not None
|
||||
|
||||
if wechaty_msg.type() == MessageType.MESSAGE_TYPE_TEXT:
|
||||
self.ctype = ContextType.TEXT
|
||||
self.content = wechaty_msg.text()
|
||||
elif wechaty_msg.type() == MessageType.MESSAGE_TYPE_AUDIO:
|
||||
self.ctype = ContextType.VOICE
|
||||
voice_file = await wechaty_msg.to_file_box()
|
||||
self.content = TmpDir().path() + voice_file.name # content直接存临时目录路径
|
||||
|
||||
def func():
|
||||
loop = asyncio.get_event_loop()
|
||||
asyncio.run_coroutine_threadsafe(voice_file.to_file(self.content), loop).result()
|
||||
|
||||
self._prepare_fn = func
|
||||
|
||||
else:
|
||||
raise NotImplementedError("Unsupported message type: {}".format(wechaty_msg.type()))
|
||||
|
||||
from_contact = wechaty_msg.talker() # 获取消息的发送者
|
||||
self.from_user_id = from_contact.contact_id
|
||||
self.from_user_nickname = from_contact.name
|
||||
|
||||
# group中的from和to,wechaty跟itchat含义不一样
|
||||
# wecahty: from是消息实际发送者, to:所在群
|
||||
# itchat: 如果是你发送群消息,from和to是你自己和所在群,如果是别人发群消息,from和to是所在群和你自己
|
||||
# 但这个差别不影响逻辑,group中只使用到:1.用from来判断是否是自己发的,2.actual_user_id来判断实际发送用户
|
||||
|
||||
if self.is_group:
|
||||
self.to_user_id = room.room_id
|
||||
self.to_user_nickname = await room.topic()
|
||||
else:
|
||||
to_contact = wechaty_msg.to()
|
||||
self.to_user_id = to_contact.contact_id
|
||||
self.to_user_nickname = to_contact.name
|
||||
|
||||
if self.is_group or wechaty_msg.is_self(): # 如果是群消息,other_user设置为群,如果是私聊消息,而且自己发的,就设置成对方。
|
||||
self.other_user_id = self.to_user_id
|
||||
self.other_user_nickname = self.to_user_nickname
|
||||
else:
|
||||
self.other_user_id = self.from_user_id
|
||||
self.other_user_nickname = self.from_user_nickname
|
||||
|
||||
if self.is_group: # wechaty群聊中,实际发送用户就是from_user
|
||||
self.is_at = await wechaty_msg.mention_self()
|
||||
if not self.is_at: # 有时候复制粘贴的消息,不算做@,但是内容里面会有@xxx,这里做一下兼容
|
||||
name = wechaty_msg.wechaty.user_self().name
|
||||
pattern = f"@{re.escape(name)}(\u2005|\u0020)"
|
||||
if re.search(pattern, self.content):
|
||||
logger.debug(f"wechaty message {self.msg_id} include at")
|
||||
self.is_at = True
|
||||
|
||||
self.actual_user_id = self.from_user_id
|
||||
self.actual_user_nickname = self.from_user_nickname
|
||||
85
channel/wechatcom/README.md
Normal file
85
channel/wechatcom/README.md
Normal file
@@ -0,0 +1,85 @@
|
||||
# 企业微信应用号channel
|
||||
|
||||
企业微信官方提供了客服、应用等API,本channel使用的是企业微信的自建应用API的能力。
|
||||
|
||||
因为未来可能还会开发客服能力,所以本channel的类型名叫作`wechatcom_app`。
|
||||
|
||||
`wechatcom_app` channel支持插件系统和图片声音交互等能力,除了无法加入群聊,作为个人使用的私人助理已绰绰有余。
|
||||
|
||||
## 开始之前
|
||||
|
||||
- 在企业中确认自己拥有在企业内自建应用的权限。
|
||||
- 如果没有权限或者是个人用户,也可创建未认证的企业。操作方式:登录手机企业微信,选择`创建/加入企业`来创建企业,类型请选择企业,企业名称可随意填写。
|
||||
未认证的企业有100人的服务人数上限,其他功能与认证企业没有差异。
|
||||
|
||||
本channel需安装的依赖与公众号一致,需要安装`wechatpy`和`web.py`,它们包含在`requirements-optional.txt`中。
|
||||
|
||||
此外,如果你是`Linux`系统,除了`ffmpeg`还需要安装`amr`编码器,否则会出现找不到编码器的错误,无法正常使用语音功能。
|
||||
|
||||
- Ubuntu/Debian
|
||||
|
||||
```bash
|
||||
apt-get install libavcodec-extra
|
||||
```
|
||||
|
||||
- Alpine
|
||||
|
||||
需自行编译`ffmpeg`,在编译参数里加入`amr`编码器的支持
|
||||
|
||||
## 使用方法
|
||||
|
||||
1.查看企业ID
|
||||
|
||||
- 扫码登陆[企业微信后台](https://work.weixin.qq.com)
|
||||
- 选择`我的企业`,点击`企业信息`,记住该`企业ID`
|
||||
|
||||
2.创建自建应用
|
||||
|
||||
- 选择应用管理, 在自建区选创建应用来创建企业自建应用
|
||||
- 上传应用logo,填写应用名称等项
|
||||
- 创建应用后进入应用详情页面,记住`AgentId`和`Secert`
|
||||
|
||||
3.配置应用
|
||||
|
||||
- 在详情页点击`企业可信IP`的配置(没看到可以不管),填入你服务器的公网IP,如果不知道可以先不填
|
||||
- 点击`接收消息`下的启用API接收消息
|
||||
- `URL`填写格式为`http://url:port/wxcomapp`,`port`是程序监听的端口,默认是9898
|
||||
如果是未认证的企业,url可直接使用服务器的IP。如果是认证企业,需要使用备案的域名,可使用二级域名。
|
||||
- `Token`可随意填写,停留在这个页面
|
||||
- 在程序根目录`config.json`中增加配置(**去掉注释**),`wechatcomapp_aes_key`是当前页面的`wechatcomapp_aes_key`
|
||||
|
||||
```python
|
||||
"channel_type": "wechatcom_app",
|
||||
"wechatcom_corp_id": "", # 企业微信公司的corpID
|
||||
"wechatcomapp_token": "", # 企业微信app的token
|
||||
"wechatcomapp_port": 9898, # 企业微信app的服务端口, 不需要端口转发
|
||||
"wechatcomapp_secret": "", # 企业微信app的secret
|
||||
"wechatcomapp_agent_id": "", # 企业微信app的agent_id
|
||||
"wechatcomapp_aes_key": "", # 企业微信app的aes_key
|
||||
```
|
||||
|
||||
- 运行程序,在页面中点击保存,保存成功说明验证成功
|
||||
|
||||
4.连接个人微信
|
||||
|
||||
选择`我的企业`,点击`微信插件`,下面有个邀请关注的二维码。微信扫码后,即可在微信中看到对应企业,在这里你便可以和机器人沟通。
|
||||
|
||||
向机器人发送消息,如果日志里出现报错:
|
||||
|
||||
```bash
|
||||
Error code: 60020, message: "not allow to access from your ip, ...from ip: xx.xx.xx.xx"
|
||||
```
|
||||
|
||||
意思是IP不可信,需要参考上一步的`企业可信IP`配置,把这里的IP加进去。
|
||||
|
||||
~~### Railway部署方式~~(2023-06-08已失效)
|
||||
|
||||
~~公众号不能在`Railway`上部署,但企业微信应用[可以](https://railway.app/template/-FHS--?referralCode=RC3znh)!~~
|
||||
|
||||
~~填写配置后,将部署完成后的网址```**.railway.app/wxcomapp```,填写在上一步的URL中。发送信息后观察日志,把报错的IP加入到可信IP。(每次重启后都需要加入可信IP)~~
|
||||
|
||||
## 测试体验
|
||||
|
||||
AIGC开放社区中已经部署了多个可免费使用的Bot,扫描下方的二维码会自动邀请你来体验。
|
||||
|
||||
<img width="200" src="../../docs/images/aigcopen.png">
|
||||
178
channel/wechatcom/wechatcomapp_channel.py
Normal file
178
channel/wechatcom/wechatcomapp_channel.py
Normal file
@@ -0,0 +1,178 @@
|
||||
# -*- coding=utf-8 -*-
|
||||
import io
|
||||
import os
|
||||
import time
|
||||
|
||||
import requests
|
||||
import web
|
||||
from wechatpy.enterprise import create_reply, parse_message
|
||||
from wechatpy.enterprise.crypto import WeChatCrypto
|
||||
from wechatpy.enterprise.exceptions import InvalidCorpIdException
|
||||
from wechatpy.exceptions import InvalidSignatureException, WeChatClientException
|
||||
|
||||
from bridge.context import Context
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from channel.chat_channel import ChatChannel
|
||||
from channel.wechatcom.wechatcomapp_client import WechatComAppClient
|
||||
from channel.wechatcom.wechatcomapp_message import WechatComAppMessage
|
||||
from common.log import logger
|
||||
from common.singleton import singleton
|
||||
from common.utils import compress_imgfile, fsize, split_string_by_utf8_length
|
||||
from config import conf, subscribe_msg
|
||||
from voice.audio_convert import any_to_amr, split_audio
|
||||
|
||||
MAX_UTF8_LEN = 2048
|
||||
|
||||
|
||||
@singleton
|
||||
class WechatComAppChannel(ChatChannel):
|
||||
NOT_SUPPORT_REPLYTYPE = []
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.corp_id = conf().get("wechatcom_corp_id")
|
||||
self.secret = conf().get("wechatcomapp_secret")
|
||||
self.agent_id = conf().get("wechatcomapp_agent_id")
|
||||
self.token = conf().get("wechatcomapp_token")
|
||||
self.aes_key = conf().get("wechatcomapp_aes_key")
|
||||
print(self.corp_id, self.secret, self.agent_id, self.token, self.aes_key)
|
||||
logger.info(
|
||||
"[wechatcom] init: corp_id: {}, secret: {}, agent_id: {}, token: {}, aes_key: {}".format(self.corp_id, self.secret, self.agent_id, self.token, self.aes_key)
|
||||
)
|
||||
self.crypto = WeChatCrypto(self.token, self.aes_key, self.corp_id)
|
||||
self.client = WechatComAppClient(self.corp_id, self.secret)
|
||||
|
||||
def startup(self):
|
||||
# start message listener
|
||||
urls = ("/wxcomapp", "channel.wechatcom.wechatcomapp_channel.Query")
|
||||
app = web.application(urls, globals(), autoreload=False)
|
||||
port = conf().get("wechatcomapp_port", 9898)
|
||||
web.httpserver.runsimple(app.wsgifunc(), ("0.0.0.0", port))
|
||||
|
||||
def send(self, reply: Reply, context: Context):
|
||||
receiver = context["receiver"]
|
||||
if reply.type in [ReplyType.TEXT, ReplyType.ERROR, ReplyType.INFO]:
|
||||
reply_text = reply.content
|
||||
texts = split_string_by_utf8_length(reply_text, MAX_UTF8_LEN)
|
||||
if len(texts) > 1:
|
||||
logger.info("[wechatcom] text too long, split into {} parts".format(len(texts)))
|
||||
for i, text in enumerate(texts):
|
||||
self.client.message.send_text(self.agent_id, receiver, text)
|
||||
if i != len(texts) - 1:
|
||||
time.sleep(0.5) # 休眠0.5秒,防止发送过快乱序
|
||||
logger.info("[wechatcom] Do send text to {}: {}".format(receiver, reply_text))
|
||||
elif reply.type == ReplyType.VOICE:
|
||||
try:
|
||||
media_ids = []
|
||||
file_path = reply.content
|
||||
amr_file = os.path.splitext(file_path)[0] + ".amr"
|
||||
any_to_amr(file_path, amr_file)
|
||||
duration, files = split_audio(amr_file, 60 * 1000)
|
||||
if len(files) > 1:
|
||||
logger.info("[wechatcom] voice too long {}s > 60s , split into {} parts".format(duration / 1000.0, len(files)))
|
||||
for path in files:
|
||||
response = self.client.media.upload("voice", open(path, "rb"))
|
||||
logger.debug("[wechatcom] upload voice response: {}".format(response))
|
||||
media_ids.append(response["media_id"])
|
||||
except WeChatClientException as e:
|
||||
logger.error("[wechatcom] upload voice failed: {}".format(e))
|
||||
return
|
||||
try:
|
||||
os.remove(file_path)
|
||||
if amr_file != file_path:
|
||||
os.remove(amr_file)
|
||||
except Exception:
|
||||
pass
|
||||
for media_id in media_ids:
|
||||
self.client.message.send_voice(self.agent_id, receiver, media_id)
|
||||
time.sleep(1)
|
||||
logger.info("[wechatcom] sendVoice={}, receiver={}".format(reply.content, receiver))
|
||||
elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片
|
||||
img_url = reply.content
|
||||
pic_res = requests.get(img_url, stream=True)
|
||||
image_storage = io.BytesIO()
|
||||
for block in pic_res.iter_content(1024):
|
||||
image_storage.write(block)
|
||||
sz = fsize(image_storage)
|
||||
if sz >= 10 * 1024 * 1024:
|
||||
logger.info("[wechatcom] image too large, ready to compress, sz={}".format(sz))
|
||||
image_storage = compress_imgfile(image_storage, 10 * 1024 * 1024 - 1)
|
||||
logger.info("[wechatcom] image compressed, sz={}".format(fsize(image_storage)))
|
||||
image_storage.seek(0)
|
||||
try:
|
||||
response = self.client.media.upload("image", image_storage)
|
||||
logger.debug("[wechatcom] upload image response: {}".format(response))
|
||||
except WeChatClientException as e:
|
||||
logger.error("[wechatcom] upload image failed: {}".format(e))
|
||||
return
|
||||
|
||||
self.client.message.send_image(self.agent_id, receiver, response["media_id"])
|
||||
logger.info("[wechatcom] sendImage url={}, receiver={}".format(img_url, receiver))
|
||||
elif reply.type == ReplyType.IMAGE: # 从文件读取图片
|
||||
image_storage = reply.content
|
||||
sz = fsize(image_storage)
|
||||
if sz >= 10 * 1024 * 1024:
|
||||
logger.info("[wechatcom] image too large, ready to compress, sz={}".format(sz))
|
||||
image_storage = compress_imgfile(image_storage, 10 * 1024 * 1024 - 1)
|
||||
logger.info("[wechatcom] image compressed, sz={}".format(fsize(image_storage)))
|
||||
image_storage.seek(0)
|
||||
try:
|
||||
response = self.client.media.upload("image", image_storage)
|
||||
logger.debug("[wechatcom] upload image response: {}".format(response))
|
||||
except WeChatClientException as e:
|
||||
logger.error("[wechatcom] upload image failed: {}".format(e))
|
||||
return
|
||||
self.client.message.send_image(self.agent_id, receiver, response["media_id"])
|
||||
logger.info("[wechatcom] sendImage, receiver={}".format(receiver))
|
||||
|
||||
|
||||
class Query:
|
||||
def GET(self):
|
||||
channel = WechatComAppChannel()
|
||||
params = web.input()
|
||||
logger.info("[wechatcom] receive params: {}".format(params))
|
||||
try:
|
||||
signature = params.msg_signature
|
||||
timestamp = params.timestamp
|
||||
nonce = params.nonce
|
||||
echostr = params.echostr
|
||||
echostr = channel.crypto.check_signature(signature, timestamp, nonce, echostr)
|
||||
except InvalidSignatureException:
|
||||
raise web.Forbidden()
|
||||
return echostr
|
||||
|
||||
def POST(self):
|
||||
channel = WechatComAppChannel()
|
||||
params = web.input()
|
||||
logger.info("[wechatcom] receive params: {}".format(params))
|
||||
try:
|
||||
signature = params.msg_signature
|
||||
timestamp = params.timestamp
|
||||
nonce = params.nonce
|
||||
message = channel.crypto.decrypt_message(web.data(), signature, timestamp, nonce)
|
||||
except (InvalidSignatureException, InvalidCorpIdException):
|
||||
raise web.Forbidden()
|
||||
msg = parse_message(message)
|
||||
logger.debug("[wechatcom] receive message: {}, msg= {}".format(message, msg))
|
||||
if msg.type == "event":
|
||||
if msg.event == "subscribe":
|
||||
reply_content = subscribe_msg()
|
||||
if reply_content:
|
||||
reply = create_reply(reply_content, msg).render()
|
||||
res = channel.crypto.encrypt_message(reply, nonce, timestamp)
|
||||
return res
|
||||
else:
|
||||
try:
|
||||
wechatcom_msg = WechatComAppMessage(msg, client=channel.client)
|
||||
except NotImplementedError as e:
|
||||
logger.debug("[wechatcom] " + str(e))
|
||||
return "success"
|
||||
context = channel._compose_context(
|
||||
wechatcom_msg.ctype,
|
||||
wechatcom_msg.content,
|
||||
isgroup=False,
|
||||
msg=wechatcom_msg,
|
||||
)
|
||||
if context:
|
||||
channel.produce(context)
|
||||
return "success"
|
||||
21
channel/wechatcom/wechatcomapp_client.py
Normal file
21
channel/wechatcom/wechatcomapp_client.py
Normal file
@@ -0,0 +1,21 @@
|
||||
import threading
|
||||
import time
|
||||
|
||||
from wechatpy.enterprise import WeChatClient
|
||||
|
||||
|
||||
class WechatComAppClient(WeChatClient):
|
||||
def __init__(self, corp_id, secret, access_token=None, session=None, timeout=None, auto_retry=True):
|
||||
super(WechatComAppClient, self).__init__(corp_id, secret, access_token, session, timeout, auto_retry)
|
||||
self.fetch_access_token_lock = threading.Lock()
|
||||
|
||||
def fetch_access_token(self): # 重载父类方法,加锁避免多线程重复获取access_token
|
||||
with self.fetch_access_token_lock:
|
||||
access_token = self.session.get(self.access_token_key)
|
||||
if access_token:
|
||||
if not self.expires_at:
|
||||
return access_token
|
||||
timestamp = time.time()
|
||||
if self.expires_at - timestamp > 60:
|
||||
return access_token
|
||||
return super().fetch_access_token()
|
||||
52
channel/wechatcom/wechatcomapp_message.py
Normal file
52
channel/wechatcom/wechatcomapp_message.py
Normal file
@@ -0,0 +1,52 @@
|
||||
from wechatpy.enterprise import WeChatClient
|
||||
|
||||
from bridge.context import ContextType
|
||||
from channel.chat_message import ChatMessage
|
||||
from common.log import logger
|
||||
from common.tmp_dir import TmpDir
|
||||
|
||||
|
||||
class WechatComAppMessage(ChatMessage):
|
||||
def __init__(self, msg, client: WeChatClient, is_group=False):
|
||||
super().__init__(msg)
|
||||
self.msg_id = msg.id
|
||||
self.create_time = msg.time
|
||||
self.is_group = is_group
|
||||
|
||||
if msg.type == "text":
|
||||
self.ctype = ContextType.TEXT
|
||||
self.content = msg.content
|
||||
elif msg.type == "voice":
|
||||
self.ctype = ContextType.VOICE
|
||||
self.content = TmpDir().path() + msg.media_id + "." + msg.format # content直接存临时目录路径
|
||||
|
||||
def download_voice():
|
||||
# 如果响应状态码是200,则将响应内容写入本地文件
|
||||
response = client.media.download(msg.media_id)
|
||||
if response.status_code == 200:
|
||||
with open(self.content, "wb") as f:
|
||||
f.write(response.content)
|
||||
else:
|
||||
logger.info(f"[wechatcom] Failed to download voice file, {response.content}")
|
||||
|
||||
self._prepare_fn = download_voice
|
||||
elif msg.type == "image":
|
||||
self.ctype = ContextType.IMAGE
|
||||
self.content = TmpDir().path() + msg.media_id + ".png" # content直接存临时目录路径
|
||||
|
||||
def download_image():
|
||||
# 如果响应状态码是200,则将响应内容写入本地文件
|
||||
response = client.media.download(msg.media_id)
|
||||
if response.status_code == 200:
|
||||
with open(self.content, "wb") as f:
|
||||
f.write(response.content)
|
||||
else:
|
||||
logger.info(f"[wechatcom] Failed to download image file, {response.content}")
|
||||
|
||||
self._prepare_fn = download_image
|
||||
else:
|
||||
raise NotImplementedError("Unsupported message type: Type:{} ".format(msg.type))
|
||||
|
||||
self.from_user_id = msg.source
|
||||
self.to_user_id = msg.target
|
||||
self.other_user_id = msg.source
|
||||
100
channel/wechatmp/README.md
Normal file
100
channel/wechatmp/README.md
Normal file
@@ -0,0 +1,100 @@
|
||||
# 微信公众号channel
|
||||
|
||||
鉴于个人微信号在服务器上通过itchat登录有封号风险,这里新增了微信公众号channel,提供无风险的服务。
|
||||
目前支持订阅号和服务号两种类型的公众号,它们都支持文本交互,语音和图片输入。其中个人主体的微信订阅号由于无法通过微信认证,存在回复时间限制,每天的图片和声音回复次数也有限制。
|
||||
|
||||
## 使用方法(订阅号,服务号类似)
|
||||
|
||||
在开始部署前,你需要一个拥有公网IP的服务器,以提供微信服务器和我们自己服务器的连接。或者你需要进行内网穿透,否则微信服务器无法将消息发送给我们的服务器。
|
||||
|
||||
此外,需要在我们的服务器上安装python的web框架web.py和wechatpy。
|
||||
以ubuntu为例(在ubuntu 22.04上测试):
|
||||
```
|
||||
pip3 install web.py
|
||||
pip3 install wechatpy
|
||||
```
|
||||
|
||||
然后在[微信公众平台](https://mp.weixin.qq.com)注册一个自己的公众号,类型选择订阅号,主体为个人即可。
|
||||
|
||||
然后根据[接入指南](https://developers.weixin.qq.com/doc/offiaccount/Basic_Information/Access_Overview.html)的说明,在[微信公众平台](https://mp.weixin.qq.com)的“设置与开发”-“基本配置”-“服务器配置”中填写服务器地址`URL`和令牌`Token`。`URL`填写格式为`http://url/wx`,可使用IP(成功几率看脸),`Token`是你自己编的一个特定的令牌。消息加解密方式如果选择了需要加密的模式,需要在配置中填写`wechatmp_aes_key`。
|
||||
|
||||
相关的服务器验证代码已经写好,你不需要再添加任何代码。你只需要在本项目根目录的`config.json`中添加
|
||||
```
|
||||
"channel_type": "wechatmp", # 如果通过了微信认证,将"wechatmp"替换为"wechatmp_service",可极大的优化使用体验
|
||||
"wechatmp_token": "xxxx", # 微信公众平台的Token
|
||||
"wechatmp_port": 8080, # 微信公众平台的端口,需要端口转发到80或443
|
||||
"wechatmp_app_id": "xxxx", # 微信公众平台的appID
|
||||
"wechatmp_app_secret": "xxxx", # 微信公众平台的appsecret
|
||||
"wechatmp_aes_key": "", # 微信公众平台的EncodingAESKey,加密模式需要
|
||||
"single_chat_prefix": [""], # 推荐设置,任意对话都可以触发回复,不添加前缀
|
||||
"single_chat_reply_prefix": "", # 推荐设置,回复不设置前缀
|
||||
"plugin_trigger_prefix": "&", # 推荐设置,在手机微信客户端中,$%^等符号与中文连在一起时会自动显示一段较大的间隔,用户体验不好。请不要使用管理员指令前缀"#",这会造成未知问题。
|
||||
```
|
||||
然后运行`python3 app.py`启动web服务器。这里会默认监听8080端口,但是微信公众号的服务器配置只支持80/443端口,有两种方法来解决这个问题。第一个是推荐的方法,使用端口转发命令将80端口转发到8080端口:
|
||||
```
|
||||
sudo iptables -t nat -A PREROUTING -p tcp --dport 80 -j REDIRECT --to-port 8080
|
||||
sudo iptables-save > /etc/iptables/rules.v4
|
||||
```
|
||||
第二个方法是让python程序直接监听80端口,在配置文件中设置`"wechatmp_port": 80` ,在linux上需要使用`sudo python3 app.py`启动程序。然而这会导致一系列环境和权限问题,因此不是推荐的方法。
|
||||
|
||||
443端口同理,注意需要支持SSL,也就是https的访问,在`wechatmp_channel.py`中需要修改相应的证书路径。
|
||||
|
||||
程序启动并监听端口后,在刚才的“服务器配置”中点击`提交`即可验证你的服务器。
|
||||
随后在[微信公众平台](https://mp.weixin.qq.com)启用服务器,关闭手动填写规则的自动回复,即可实现ChatGPT的自动回复。
|
||||
|
||||
之后需要在公众号开发信息下将本机IP加入到IP白名单。
|
||||
|
||||
不然在启用后,发送语音、图片等消息可能会遇到如下报错:
|
||||
```
|
||||
'errcode': 40164, 'errmsg': 'invalid ip xx.xx.xx.xx not in whitelist rid
|
||||
```
|
||||
|
||||
|
||||
## 个人微信公众号的限制
|
||||
由于人微信公众号不能通过微信认证,所以没有客服接口,因此公众号无法主动发出消息,只能被动回复。而微信官方对被动回复有5秒的时间限制,最多重试2次,因此最多只有15秒的自动回复时间窗口。因此如果问题比较复杂或者我们的服务器比较忙,ChatGPT的回答就没办法及时回复给用户。为了解决这个问题,这里做了回答缓存,它需要你在回复超时后,再次主动发送任意文字(例如1)来尝试拿到回答缓存。为了优化使用体验,目前设置了两分钟(120秒)的timeout,用户在至多两分钟后即可得到查询到回复或者错误原因。
|
||||
|
||||
另外,由于微信官方的限制,自动回复有长度限制。因此这里将ChatGPT的回答进行了拆分,以满足限制。
|
||||
|
||||
## 私有api_key
|
||||
公共api有访问频率限制(免费账号每分钟最多3次ChatGPT的API调用),这在服务多人的时候会遇到问题。因此这里多加了一个设置私有api_key的功能。目前通过godcmd插件的命令来设置私有api_key。
|
||||
|
||||
## 语音输入
|
||||
利用微信自带的语音识别功能,提供语音输入能力。需要在公众号管理页面的“设置与开发”->“接口权限”页面开启“接收语音识别结果”。
|
||||
|
||||
## 语音回复
|
||||
请在配置文件中添加以下词条:
|
||||
```
|
||||
"voice_reply_voice": true,
|
||||
```
|
||||
这样公众号将会用语音回复语音消息,实现语音对话。
|
||||
|
||||
默认的语音合成引擎是`google`,它是免费使用的。
|
||||
|
||||
如果要选择其他的语音合成引擎,请添加以下配置项:
|
||||
```
|
||||
"text_to_voice": "pytts"
|
||||
```
|
||||
|
||||
pytts是本地的语音合成引擎。还支持baidu,azure,这些你需要自行配置相关的依赖和key。
|
||||
|
||||
如果使用pytts,在ubuntu上需要安装如下依赖:
|
||||
```
|
||||
sudo apt update
|
||||
sudo apt install espeak
|
||||
sudo apt install ffmpeg
|
||||
python3 -m pip install pyttsx3
|
||||
```
|
||||
不是很建议开启pytts语音回复,因为它是离线本地计算,算的慢会拖垮服务器,且声音不好听。
|
||||
|
||||
## 图片回复
|
||||
现在认证公众号和非认证公众号都可以实现的图片和语音回复。但是非认证公众号使用了永久素材接口,每天有1000次的调用上限(每个月有10次重置机会,程序中已设定遇到上限会自动重置),且永久素材库存也有上限。因此对于非认证公众号,我们会在回复图片或者语音消息后的10秒内从永久素材库存内删除该素材。
|
||||
|
||||
## 测试
|
||||
目前在`RoboStyle`这个公众号上进行了测试(基于[wechatmp分支](https://github.com/JS00000/chatgpt-on-wechat/tree/wechatmp)),感兴趣的可以关注并体验。开启了godcmd, Banwords, role, dungeon, finish这五个插件,其他的插件还没有详尽测试。百度的接口暂未测试。[wechatmp-stable分支](https://github.com/JS00000/chatgpt-on-wechat/tree/wechatmp-stable)是较稳定的上个版本,但也缺少最新的功能支持。
|
||||
|
||||
## TODO
|
||||
- [x] 语音输入
|
||||
- [x] 图片输入
|
||||
- [x] 使用临时素材接口提供认证公众号的图片和语音回复
|
||||
- [x] 使用永久素材接口提供未认证公众号的图片和语音回复
|
||||
- [ ] 高并发支持
|
||||
75
channel/wechatmp/active_reply.py
Normal file
75
channel/wechatmp/active_reply.py
Normal file
@@ -0,0 +1,75 @@
|
||||
import time
|
||||
|
||||
import web
|
||||
from wechatpy import parse_message
|
||||
from wechatpy.replies import create_reply
|
||||
|
||||
from bridge.context import *
|
||||
from bridge.reply import *
|
||||
from channel.wechatmp.common import *
|
||||
from channel.wechatmp.wechatmp_channel import WechatMPChannel
|
||||
from channel.wechatmp.wechatmp_message import WeChatMPMessage
|
||||
from common.log import logger
|
||||
from config import conf, subscribe_msg
|
||||
|
||||
|
||||
# This class is instantiated once per query
|
||||
class Query:
|
||||
def GET(self):
|
||||
return verify_server(web.input())
|
||||
|
||||
def POST(self):
|
||||
# Make sure to return the instance that first created, @singleton will do that.
|
||||
try:
|
||||
args = web.input()
|
||||
verify_server(args)
|
||||
channel = WechatMPChannel()
|
||||
message = web.data()
|
||||
encrypt_func = lambda x: x
|
||||
if args.get("encrypt_type") == "aes":
|
||||
logger.debug("[wechatmp] Receive encrypted post data:\n" + message.decode("utf-8"))
|
||||
if not channel.crypto:
|
||||
raise Exception("Crypto not initialized, Please set wechatmp_aes_key in config.json")
|
||||
message = channel.crypto.decrypt_message(message, args.msg_signature, args.timestamp, args.nonce)
|
||||
encrypt_func = lambda x: channel.crypto.encrypt_message(x, args.nonce, args.timestamp)
|
||||
else:
|
||||
logger.debug("[wechatmp] Receive post data:\n" + message.decode("utf-8"))
|
||||
msg = parse_message(message)
|
||||
if msg.type in ["text", "voice", "image"]:
|
||||
wechatmp_msg = WeChatMPMessage(msg, client=channel.client)
|
||||
from_user = wechatmp_msg.from_user_id
|
||||
content = wechatmp_msg.content
|
||||
message_id = wechatmp_msg.msg_id
|
||||
|
||||
logger.info(
|
||||
"[wechatmp] {}:{} Receive post query {} {}: {}".format(
|
||||
web.ctx.env.get("REMOTE_ADDR"),
|
||||
web.ctx.env.get("REMOTE_PORT"),
|
||||
from_user,
|
||||
message_id,
|
||||
content,
|
||||
)
|
||||
)
|
||||
if msg.type == "voice" and wechatmp_msg.ctype == ContextType.TEXT and conf().get("voice_reply_voice", False):
|
||||
context = channel._compose_context(wechatmp_msg.ctype, content, isgroup=False, desire_rtype=ReplyType.VOICE, msg=wechatmp_msg)
|
||||
else:
|
||||
context = channel._compose_context(wechatmp_msg.ctype, content, isgroup=False, msg=wechatmp_msg)
|
||||
if context:
|
||||
channel.produce(context)
|
||||
# The reply will be sent by channel.send() in another thread
|
||||
return "success"
|
||||
elif msg.type == "event":
|
||||
logger.info("[wechatmp] Event {} from {}".format(msg.event, msg.source))
|
||||
if msg.event in ["subscribe", "subscribe_scan"]:
|
||||
reply_text = subscribe_msg()
|
||||
if reply_text:
|
||||
replyPost = create_reply(reply_text, msg)
|
||||
return encrypt_func(replyPost.render())
|
||||
else:
|
||||
return "success"
|
||||
else:
|
||||
logger.info("暂且不处理")
|
||||
return "success"
|
||||
except Exception as exc:
|
||||
logger.exception(exc)
|
||||
return exc
|
||||
27
channel/wechatmp/common.py
Normal file
27
channel/wechatmp/common.py
Normal file
@@ -0,0 +1,27 @@
|
||||
import web
|
||||
from wechatpy.crypto import WeChatCrypto
|
||||
from wechatpy.exceptions import InvalidSignatureException
|
||||
from wechatpy.utils import check_signature
|
||||
|
||||
from config import conf
|
||||
|
||||
MAX_UTF8_LEN = 2048
|
||||
|
||||
|
||||
class WeChatAPIException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def verify_server(data):
|
||||
try:
|
||||
signature = data.signature
|
||||
timestamp = data.timestamp
|
||||
nonce = data.nonce
|
||||
echostr = data.get("echostr", None)
|
||||
token = conf().get("wechatmp_token") # 请按照公众平台官网\基本配置中信息填写
|
||||
check_signature(token, signature, timestamp, nonce)
|
||||
return echostr
|
||||
except InvalidSignatureException:
|
||||
raise web.Forbidden("Invalid signature")
|
||||
except Exception as e:
|
||||
raise web.Forbidden(str(e))
|
||||
211
channel/wechatmp/passive_reply.py
Normal file
211
channel/wechatmp/passive_reply.py
Normal file
@@ -0,0 +1,211 @@
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
import web
|
||||
from wechatpy import parse_message
|
||||
from wechatpy.replies import ImageReply, VoiceReply, create_reply
|
||||
import textwrap
|
||||
from bridge.context import *
|
||||
from bridge.reply import *
|
||||
from channel.wechatmp.common import *
|
||||
from channel.wechatmp.wechatmp_channel import WechatMPChannel
|
||||
from channel.wechatmp.wechatmp_message import WeChatMPMessage
|
||||
from common.log import logger
|
||||
from common.utils import split_string_by_utf8_length
|
||||
from config import conf, subscribe_msg
|
||||
|
||||
|
||||
# This class is instantiated once per query
|
||||
class Query:
|
||||
def GET(self):
|
||||
return verify_server(web.input())
|
||||
|
||||
def POST(self):
|
||||
try:
|
||||
args = web.input()
|
||||
verify_server(args)
|
||||
request_time = time.time()
|
||||
channel = WechatMPChannel()
|
||||
message = web.data()
|
||||
encrypt_func = lambda x: x
|
||||
if args.get("encrypt_type") == "aes":
|
||||
logger.debug("[wechatmp] Receive encrypted post data:\n" + message.decode("utf-8"))
|
||||
if not channel.crypto:
|
||||
raise Exception("Crypto not initialized, Please set wechatmp_aes_key in config.json")
|
||||
message = channel.crypto.decrypt_message(message, args.msg_signature, args.timestamp, args.nonce)
|
||||
encrypt_func = lambda x: channel.crypto.encrypt_message(x, args.nonce, args.timestamp)
|
||||
else:
|
||||
logger.debug("[wechatmp] Receive post data:\n" + message.decode("utf-8"))
|
||||
msg = parse_message(message)
|
||||
if msg.type in ["text", "voice", "image"]:
|
||||
wechatmp_msg = WeChatMPMessage(msg, client=channel.client)
|
||||
from_user = wechatmp_msg.from_user_id
|
||||
content = wechatmp_msg.content
|
||||
message_id = wechatmp_msg.msg_id
|
||||
|
||||
supported = True
|
||||
if "【收到不支持的消息类型,暂无法显示】" in content:
|
||||
supported = False # not supported, used to refresh
|
||||
|
||||
# New request
|
||||
if (
|
||||
channel.cache_dict.get(from_user) is None
|
||||
and from_user not in channel.running
|
||||
or content.startswith("#")
|
||||
and message_id not in channel.request_cnt # insert the godcmd
|
||||
):
|
||||
# The first query begin
|
||||
if msg.type == "voice" and wechatmp_msg.ctype == ContextType.TEXT and conf().get("voice_reply_voice", False):
|
||||
context = channel._compose_context(wechatmp_msg.ctype, content, isgroup=False, desire_rtype=ReplyType.VOICE, msg=wechatmp_msg)
|
||||
else:
|
||||
context = channel._compose_context(wechatmp_msg.ctype, content, isgroup=False, msg=wechatmp_msg)
|
||||
logger.debug("[wechatmp] context: {} {} {}".format(context, wechatmp_msg, supported))
|
||||
|
||||
if supported and context:
|
||||
channel.running.add(from_user)
|
||||
channel.produce(context)
|
||||
else:
|
||||
trigger_prefix = conf().get("single_chat_prefix", [""])[0]
|
||||
if trigger_prefix or not supported:
|
||||
if trigger_prefix:
|
||||
reply_text = textwrap.dedent(
|
||||
f"""\
|
||||
请输入'{trigger_prefix}'接你想说的话跟我说话。
|
||||
例如:
|
||||
{trigger_prefix}你好,很高兴见到你。"""
|
||||
)
|
||||
else:
|
||||
reply_text = textwrap.dedent(
|
||||
"""\
|
||||
你好,很高兴见到你。
|
||||
请跟我说话吧。"""
|
||||
)
|
||||
else:
|
||||
logger.error(f"[wechatmp] unknown error")
|
||||
reply_text = textwrap.dedent(
|
||||
"""\
|
||||
未知错误,请稍后再试"""
|
||||
)
|
||||
|
||||
replyPost = create_reply(reply_text, msg)
|
||||
return encrypt_func(replyPost.render())
|
||||
|
||||
# Wechat official server will request 3 times (5 seconds each), with the same message_id.
|
||||
# Because the interval is 5 seconds, here assumed that do not have multithreading problems.
|
||||
request_cnt = channel.request_cnt.get(message_id, 0) + 1
|
||||
channel.request_cnt[message_id] = request_cnt
|
||||
logger.info(
|
||||
"[wechatmp] Request {} from {} {} {}:{}\n{}".format(
|
||||
request_cnt, from_user, message_id, web.ctx.env.get("REMOTE_ADDR"), web.ctx.env.get("REMOTE_PORT"), content
|
||||
)
|
||||
)
|
||||
|
||||
task_running = True
|
||||
waiting_until = request_time + 4
|
||||
while time.time() < waiting_until:
|
||||
if from_user in channel.running:
|
||||
time.sleep(0.1)
|
||||
else:
|
||||
task_running = False
|
||||
break
|
||||
|
||||
reply_text = ""
|
||||
if task_running:
|
||||
if request_cnt < 3:
|
||||
# waiting for timeout (the POST request will be closed by Wechat official server)
|
||||
time.sleep(2)
|
||||
# and do nothing, waiting for the next request
|
||||
return "success"
|
||||
else: # request_cnt == 3:
|
||||
# return timeout message
|
||||
reply_text = "【正在思考中,回复任意文字尝试获取回复】"
|
||||
replyPost = create_reply(reply_text, msg)
|
||||
return encrypt_func(replyPost.render())
|
||||
|
||||
# reply is ready
|
||||
channel.request_cnt.pop(message_id)
|
||||
|
||||
# no return because of bandwords or other reasons
|
||||
if from_user not in channel.cache_dict and from_user not in channel.running:
|
||||
return "success"
|
||||
|
||||
# Only one request can access to the cached data
|
||||
try:
|
||||
(reply_type, reply_content) = channel.cache_dict[from_user].pop(0)
|
||||
if not channel.cache_dict[from_user]: # If popping the message makes the list empty, delete the user entry from cache
|
||||
del channel.cache_dict[from_user]
|
||||
except IndexError:
|
||||
return "success"
|
||||
|
||||
if reply_type == "text":
|
||||
if len(reply_content.encode("utf8")) <= MAX_UTF8_LEN:
|
||||
reply_text = reply_content
|
||||
else:
|
||||
continue_text = "\n【未完待续,回复任意文字以继续】"
|
||||
splits = split_string_by_utf8_length(
|
||||
reply_content,
|
||||
MAX_UTF8_LEN - len(continue_text.encode("utf-8")),
|
||||
max_split=1,
|
||||
)
|
||||
reply_text = splits[0] + continue_text
|
||||
channel.cache_dict[from_user].append(("text", splits[1]))
|
||||
|
||||
logger.info(
|
||||
"[wechatmp] Request {} do send to {} {}: {}\n{}".format(
|
||||
request_cnt,
|
||||
from_user,
|
||||
message_id,
|
||||
content,
|
||||
reply_text,
|
||||
)
|
||||
)
|
||||
replyPost = create_reply(reply_text, msg)
|
||||
return encrypt_func(replyPost.render())
|
||||
|
||||
elif reply_type == "voice":
|
||||
media_id = reply_content
|
||||
asyncio.run_coroutine_threadsafe(channel.delete_media(media_id), channel.delete_media_loop)
|
||||
logger.info(
|
||||
"[wechatmp] Request {} do send to {} {}: {} voice media_id {}".format(
|
||||
request_cnt,
|
||||
from_user,
|
||||
message_id,
|
||||
content,
|
||||
media_id,
|
||||
)
|
||||
)
|
||||
replyPost = VoiceReply(message=msg)
|
||||
replyPost.media_id = media_id
|
||||
return encrypt_func(replyPost.render())
|
||||
|
||||
elif reply_type == "image":
|
||||
media_id = reply_content
|
||||
asyncio.run_coroutine_threadsafe(channel.delete_media(media_id), channel.delete_media_loop)
|
||||
logger.info(
|
||||
"[wechatmp] Request {} do send to {} {}: {} image media_id {}".format(
|
||||
request_cnt,
|
||||
from_user,
|
||||
message_id,
|
||||
content,
|
||||
media_id,
|
||||
)
|
||||
)
|
||||
replyPost = ImageReply(message=msg)
|
||||
replyPost.media_id = media_id
|
||||
return encrypt_func(replyPost.render())
|
||||
|
||||
elif msg.type == "event":
|
||||
logger.info("[wechatmp] Event {} from {}".format(msg.event, msg.source))
|
||||
if msg.event in ["subscribe", "subscribe_scan"]:
|
||||
reply_text = subscribe_msg()
|
||||
if reply_text:
|
||||
replyPost = create_reply(reply_text, msg)
|
||||
return encrypt_func(replyPost.render())
|
||||
else:
|
||||
return "success"
|
||||
else:
|
||||
logger.info("暂且不处理")
|
||||
return "success"
|
||||
except Exception as exc:
|
||||
logger.exception(exc)
|
||||
return exc
|
||||
236
channel/wechatmp/wechatmp_channel.py
Normal file
236
channel/wechatmp/wechatmp_channel.py
Normal file
@@ -0,0 +1,236 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import asyncio
|
||||
import imghdr
|
||||
import io
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
|
||||
import requests
|
||||
import web
|
||||
from wechatpy.crypto import WeChatCrypto
|
||||
from wechatpy.exceptions import WeChatClientException
|
||||
from collections import defaultdict
|
||||
|
||||
from bridge.context import *
|
||||
from bridge.reply import *
|
||||
from channel.chat_channel import ChatChannel
|
||||
from channel.wechatmp.common import *
|
||||
from channel.wechatmp.wechatmp_client import WechatMPClient
|
||||
from common.log import logger
|
||||
from common.singleton import singleton
|
||||
from common.utils import split_string_by_utf8_length
|
||||
from config import conf
|
||||
from voice.audio_convert import any_to_mp3, split_audio
|
||||
|
||||
# If using SSL, uncomment the following lines, and modify the certificate path.
|
||||
# from cheroot.server import HTTPServer
|
||||
# from cheroot.ssl.builtin import BuiltinSSLAdapter
|
||||
# HTTPServer.ssl_adapter = BuiltinSSLAdapter(
|
||||
# certificate='/ssl/cert.pem',
|
||||
# private_key='/ssl/cert.key')
|
||||
|
||||
|
||||
@singleton
|
||||
class WechatMPChannel(ChatChannel):
|
||||
def __init__(self, passive_reply=True):
|
||||
super().__init__()
|
||||
self.passive_reply = passive_reply
|
||||
self.NOT_SUPPORT_REPLYTYPE = []
|
||||
appid = conf().get("wechatmp_app_id")
|
||||
secret = conf().get("wechatmp_app_secret")
|
||||
token = conf().get("wechatmp_token")
|
||||
aes_key = conf().get("wechatmp_aes_key")
|
||||
self.client = WechatMPClient(appid, secret)
|
||||
self.crypto = None
|
||||
if aes_key:
|
||||
self.crypto = WeChatCrypto(token, aes_key, appid)
|
||||
if self.passive_reply:
|
||||
# Cache the reply to the user's first message
|
||||
self.cache_dict = defaultdict(list)
|
||||
# Record whether the current message is being processed
|
||||
self.running = set()
|
||||
# Count the request from wechat official server by message_id
|
||||
self.request_cnt = dict()
|
||||
# The permanent media need to be deleted to avoid media number limit
|
||||
self.delete_media_loop = asyncio.new_event_loop()
|
||||
t = threading.Thread(target=self.start_loop, args=(self.delete_media_loop,))
|
||||
t.setDaemon(True)
|
||||
t.start()
|
||||
|
||||
def startup(self):
|
||||
if self.passive_reply:
|
||||
urls = ("/wx", "channel.wechatmp.passive_reply.Query")
|
||||
else:
|
||||
urls = ("/wx", "channel.wechatmp.active_reply.Query")
|
||||
app = web.application(urls, globals(), autoreload=False)
|
||||
port = conf().get("wechatmp_port", 8080)
|
||||
web.httpserver.runsimple(app.wsgifunc(), ("0.0.0.0", port))
|
||||
|
||||
def start_loop(self, loop):
|
||||
asyncio.set_event_loop(loop)
|
||||
loop.run_forever()
|
||||
|
||||
async def delete_media(self, media_id):
|
||||
logger.debug("[wechatmp] permanent media {} will be deleted in 10s".format(media_id))
|
||||
await asyncio.sleep(10)
|
||||
self.client.material.delete(media_id)
|
||||
logger.info("[wechatmp] permanent media {} has been deleted".format(media_id))
|
||||
|
||||
def send(self, reply: Reply, context: Context):
|
||||
receiver = context["receiver"]
|
||||
if self.passive_reply:
|
||||
if reply.type == ReplyType.TEXT or reply.type == ReplyType.INFO or reply.type == ReplyType.ERROR:
|
||||
reply_text = reply.content
|
||||
logger.info("[wechatmp] text cached, receiver {}\n{}".format(receiver, reply_text))
|
||||
self.cache_dict[receiver].append(("text", reply_text))
|
||||
elif reply.type == ReplyType.VOICE:
|
||||
voice_file_path = reply.content
|
||||
duration, files = split_audio(voice_file_path, 60 * 1000)
|
||||
if len(files) > 1:
|
||||
logger.info("[wechatmp] voice too long {}s > 60s , split into {} parts".format(duration / 1000.0, len(files)))
|
||||
|
||||
for path in files:
|
||||
# support: <2M, <60s, mp3/wma/wav/amr
|
||||
try:
|
||||
with open(path, "rb") as f:
|
||||
response = self.client.material.add("voice", f)
|
||||
logger.debug("[wechatmp] upload voice response: {}".format(response))
|
||||
f_size = os.fstat(f.fileno()).st_size
|
||||
time.sleep(1.0 + 2 * f_size / 1024 / 1024)
|
||||
# todo check media_id
|
||||
except WeChatClientException as e:
|
||||
logger.error("[wechatmp] upload voice failed: {}".format(e))
|
||||
return
|
||||
media_id = response["media_id"]
|
||||
logger.info("[wechatmp] voice uploaded, receiver {}, media_id {}".format(receiver, media_id))
|
||||
self.cache_dict[receiver].append(("voice", media_id))
|
||||
|
||||
elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片
|
||||
img_url = reply.content
|
||||
pic_res = requests.get(img_url, stream=True)
|
||||
image_storage = io.BytesIO()
|
||||
for block in pic_res.iter_content(1024):
|
||||
image_storage.write(block)
|
||||
image_storage.seek(0)
|
||||
image_type = imghdr.what(image_storage)
|
||||
filename = receiver + "-" + str(context["msg"].msg_id) + "." + image_type
|
||||
content_type = "image/" + image_type
|
||||
try:
|
||||
response = self.client.material.add("image", (filename, image_storage, content_type))
|
||||
logger.debug("[wechatmp] upload image response: {}".format(response))
|
||||
except WeChatClientException as e:
|
||||
logger.error("[wechatmp] upload image failed: {}".format(e))
|
||||
return
|
||||
media_id = response["media_id"]
|
||||
logger.info("[wechatmp] image uploaded, receiver {}, media_id {}".format(receiver, media_id))
|
||||
self.cache_dict[receiver].append(("image", media_id))
|
||||
elif reply.type == ReplyType.IMAGE: # 从文件读取图片
|
||||
image_storage = reply.content
|
||||
image_storage.seek(0)
|
||||
image_type = imghdr.what(image_storage)
|
||||
filename = receiver + "-" + str(context["msg"].msg_id) + "." + image_type
|
||||
content_type = "image/" + image_type
|
||||
try:
|
||||
response = self.client.material.add("image", (filename, image_storage, content_type))
|
||||
logger.debug("[wechatmp] upload image response: {}".format(response))
|
||||
except WeChatClientException as e:
|
||||
logger.error("[wechatmp] upload image failed: {}".format(e))
|
||||
return
|
||||
media_id = response["media_id"]
|
||||
logger.info("[wechatmp] image uploaded, receiver {}, media_id {}".format(receiver, media_id))
|
||||
self.cache_dict[receiver].append(("image", media_id))
|
||||
else:
|
||||
if reply.type == ReplyType.TEXT or reply.type == ReplyType.INFO or reply.type == ReplyType.ERROR:
|
||||
reply_text = reply.content
|
||||
texts = split_string_by_utf8_length(reply_text, MAX_UTF8_LEN)
|
||||
if len(texts) > 1:
|
||||
logger.info("[wechatmp] text too long, split into {} parts".format(len(texts)))
|
||||
for i, text in enumerate(texts):
|
||||
self.client.message.send_text(receiver, text)
|
||||
if i != len(texts) - 1:
|
||||
time.sleep(0.5) # 休眠0.5秒,防止发送过快乱序
|
||||
logger.info("[wechatmp] Do send text to {}: {}".format(receiver, reply_text))
|
||||
elif reply.type == ReplyType.VOICE:
|
||||
try:
|
||||
file_path = reply.content
|
||||
file_name = os.path.basename(file_path)
|
||||
file_type = os.path.splitext(file_name)[1]
|
||||
if file_type == ".mp3":
|
||||
file_type = "audio/mpeg"
|
||||
elif file_type == ".amr":
|
||||
file_type = "audio/amr"
|
||||
else:
|
||||
mp3_file = os.path.splitext(file_path)[0] + ".mp3"
|
||||
any_to_mp3(file_path, mp3_file)
|
||||
file_path = mp3_file
|
||||
file_name = os.path.basename(file_path)
|
||||
file_type = "audio/mpeg"
|
||||
logger.info("[wechatmp] file_name: {}, file_type: {} ".format(file_name, file_type))
|
||||
media_ids = []
|
||||
duration, files = split_audio(file_path, 60 * 1000)
|
||||
if len(files) > 1:
|
||||
logger.info("[wechatmp] voice too long {}s > 60s , split into {} parts".format(duration / 1000.0, len(files)))
|
||||
for path in files:
|
||||
# support: <2M, <60s, AMR\MP3
|
||||
response = self.client.media.upload("voice", (os.path.basename(path), open(path, "rb"), file_type))
|
||||
logger.debug("[wechatcom] upload voice response: {}".format(response))
|
||||
media_ids.append(response["media_id"])
|
||||
os.remove(path)
|
||||
except WeChatClientException as e:
|
||||
logger.error("[wechatmp] upload voice failed: {}".format(e))
|
||||
return
|
||||
|
||||
try:
|
||||
os.remove(file_path)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
for media_id in media_ids:
|
||||
self.client.message.send_voice(receiver, media_id)
|
||||
time.sleep(1)
|
||||
logger.info("[wechatmp] Do send voice to {}".format(receiver))
|
||||
elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片
|
||||
img_url = reply.content
|
||||
pic_res = requests.get(img_url, stream=True)
|
||||
image_storage = io.BytesIO()
|
||||
for block in pic_res.iter_content(1024):
|
||||
image_storage.write(block)
|
||||
image_storage.seek(0)
|
||||
image_type = imghdr.what(image_storage)
|
||||
filename = receiver + "-" + str(context["msg"].msg_id) + "." + image_type
|
||||
content_type = "image/" + image_type
|
||||
try:
|
||||
response = self.client.media.upload("image", (filename, image_storage, content_type))
|
||||
logger.debug("[wechatmp] upload image response: {}".format(response))
|
||||
except WeChatClientException as e:
|
||||
logger.error("[wechatmp] upload image failed: {}".format(e))
|
||||
return
|
||||
self.client.message.send_image(receiver, response["media_id"])
|
||||
logger.info("[wechatmp] Do send image to {}".format(receiver))
|
||||
elif reply.type == ReplyType.IMAGE: # 从文件读取图片
|
||||
image_storage = reply.content
|
||||
image_storage.seek(0)
|
||||
image_type = imghdr.what(image_storage)
|
||||
filename = receiver + "-" + str(context["msg"].msg_id) + "." + image_type
|
||||
content_type = "image/" + image_type
|
||||
try:
|
||||
response = self.client.media.upload("image", (filename, image_storage, content_type))
|
||||
logger.debug("[wechatmp] upload image response: {}".format(response))
|
||||
except WeChatClientException as e:
|
||||
logger.error("[wechatmp] upload image failed: {}".format(e))
|
||||
return
|
||||
self.client.message.send_image(receiver, response["media_id"])
|
||||
logger.info("[wechatmp] Do send image to {}".format(receiver))
|
||||
return
|
||||
|
||||
def _success_callback(self, session_id, context, **kwargs): # 线程异常结束时的回调函数
|
||||
logger.debug("[wechatmp] Success to generate reply, msgId={}".format(context["msg"].msg_id))
|
||||
if self.passive_reply:
|
||||
self.running.remove(session_id)
|
||||
|
||||
def _fail_callback(self, session_id, exception, context, **kwargs): # 线程异常结束时的回调函数
|
||||
logger.exception("[wechatmp] Fail to generate reply to user, msgId={}, exception={}".format(context["msg"].msg_id, exception))
|
||||
if self.passive_reply:
|
||||
assert session_id not in self.cache_dict
|
||||
self.running.remove(session_id)
|
||||
49
channel/wechatmp/wechatmp_client.py
Normal file
49
channel/wechatmp/wechatmp_client.py
Normal file
@@ -0,0 +1,49 @@
|
||||
import threading
|
||||
import time
|
||||
|
||||
from wechatpy.client import WeChatClient
|
||||
from wechatpy.exceptions import APILimitedException
|
||||
|
||||
from channel.wechatmp.common import *
|
||||
from common.log import logger
|
||||
|
||||
|
||||
class WechatMPClient(WeChatClient):
|
||||
def __init__(self, appid, secret, access_token=None, session=None, timeout=None, auto_retry=True):
|
||||
super(WechatMPClient, self).__init__(appid, secret, access_token, session, timeout, auto_retry)
|
||||
self.fetch_access_token_lock = threading.Lock()
|
||||
self.clear_quota_lock = threading.Lock()
|
||||
self.last_clear_quota_time = -1
|
||||
|
||||
def clear_quota(self):
|
||||
return self.post("clear_quota", data={"appid": self.appid})
|
||||
|
||||
def clear_quota_v2(self):
|
||||
return self.post("clear_quota/v2", params={"appid": self.appid, "appsecret": self.secret})
|
||||
|
||||
def fetch_access_token(self): # 重载父类方法,加锁避免多线程重复获取access_token
|
||||
with self.fetch_access_token_lock:
|
||||
access_token = self.session.get(self.access_token_key)
|
||||
if access_token:
|
||||
if not self.expires_at:
|
||||
return access_token
|
||||
timestamp = time.time()
|
||||
if self.expires_at - timestamp > 60:
|
||||
return access_token
|
||||
return super().fetch_access_token()
|
||||
|
||||
def _request(self, method, url_or_endpoint, **kwargs): # 重载父类方法,遇到API限流时,清除quota后重试
|
||||
try:
|
||||
return super()._request(method, url_or_endpoint, **kwargs)
|
||||
except APILimitedException as e:
|
||||
logger.error("[wechatmp] API quata has been used up. {}".format(e))
|
||||
if self.last_clear_quota_time == -1 or time.time() - self.last_clear_quota_time > 60:
|
||||
with self.clear_quota_lock:
|
||||
if self.last_clear_quota_time == -1 or time.time() - self.last_clear_quota_time > 60:
|
||||
self.last_clear_quota_time = time.time()
|
||||
response = self.clear_quota_v2()
|
||||
logger.debug("[wechatmp] API quata has been cleard, {}".format(response))
|
||||
return super()._request(method, url_or_endpoint, **kwargs)
|
||||
else:
|
||||
logger.error("[wechatmp] last clear quota time is {}, less than 60s, skip clear quota")
|
||||
raise e
|
||||
56
channel/wechatmp/wechatmp_message.py
Normal file
56
channel/wechatmp/wechatmp_message.py
Normal file
@@ -0,0 +1,56 @@
|
||||
# -*- coding: utf-8 -*-#
|
||||
|
||||
from bridge.context import ContextType
|
||||
from channel.chat_message import ChatMessage
|
||||
from common.log import logger
|
||||
from common.tmp_dir import TmpDir
|
||||
|
||||
|
||||
class WeChatMPMessage(ChatMessage):
|
||||
def __init__(self, msg, client=None):
|
||||
super().__init__(msg)
|
||||
self.msg_id = msg.id
|
||||
self.create_time = msg.time
|
||||
self.is_group = False
|
||||
|
||||
if msg.type == "text":
|
||||
self.ctype = ContextType.TEXT
|
||||
self.content = msg.content
|
||||
elif msg.type == "voice":
|
||||
if msg.recognition == None:
|
||||
self.ctype = ContextType.VOICE
|
||||
self.content = TmpDir().path() + msg.media_id + "." + msg.format # content直接存临时目录路径
|
||||
|
||||
def download_voice():
|
||||
# 如果响应状态码是200,则将响应内容写入本地文件
|
||||
response = client.media.download(msg.media_id)
|
||||
if response.status_code == 200:
|
||||
with open(self.content, "wb") as f:
|
||||
f.write(response.content)
|
||||
else:
|
||||
logger.info(f"[wechatmp] Failed to download voice file, {response.content}")
|
||||
|
||||
self._prepare_fn = download_voice
|
||||
else:
|
||||
self.ctype = ContextType.TEXT
|
||||
self.content = msg.recognition
|
||||
elif msg.type == "image":
|
||||
self.ctype = ContextType.IMAGE
|
||||
self.content = TmpDir().path() + msg.media_id + ".png" # content直接存临时目录路径
|
||||
|
||||
def download_image():
|
||||
# 如果响应状态码是200,则将响应内容写入本地文件
|
||||
response = client.media.download(msg.media_id)
|
||||
if response.status_code == 200:
|
||||
with open(self.content, "wb") as f:
|
||||
f.write(response.content)
|
||||
else:
|
||||
logger.info(f"[wechatmp] Failed to download image file, {response.content}")
|
||||
|
||||
self._prepare_fn = download_image
|
||||
else:
|
||||
raise NotImplementedError("Unsupported message type: Type:{} ".format(msg.type))
|
||||
|
||||
self.from_user_id = msg.source
|
||||
self.to_user_id = msg.target
|
||||
self.other_user_id = msg.source
|
||||
17
channel/wework/run.py
Normal file
17
channel/wework/run.py
Normal file
@@ -0,0 +1,17 @@
|
||||
import os
|
||||
import time
|
||||
os.environ['ntwork_LOG'] = "ERROR"
|
||||
import ntwork
|
||||
|
||||
wework = ntwork.WeWork()
|
||||
|
||||
|
||||
def forever():
|
||||
try:
|
||||
while True:
|
||||
time.sleep(0.1)
|
||||
except KeyboardInterrupt:
|
||||
ntwork.exit_()
|
||||
os._exit(0)
|
||||
|
||||
|
||||
326
channel/wework/wework_channel.py
Normal file
326
channel/wework/wework_channel.py
Normal file
@@ -0,0 +1,326 @@
|
||||
import io
|
||||
import os
|
||||
import random
|
||||
import tempfile
|
||||
import threading
|
||||
os.environ['ntwork_LOG'] = "ERROR"
|
||||
import ntwork
|
||||
import requests
|
||||
import uuid
|
||||
|
||||
from bridge.context import *
|
||||
from bridge.reply import *
|
||||
from channel.chat_channel import ChatChannel
|
||||
from channel.wework.wework_message import *
|
||||
from channel.wework.wework_message import WeworkMessage
|
||||
from common.singleton import singleton
|
||||
from common.log import logger
|
||||
from common.time_check import time_checker
|
||||
from common.utils import compress_imgfile, fsize
|
||||
from config import conf
|
||||
from channel.wework.run import wework
|
||||
from channel.wework import run
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def get_wxid_by_name(room_members, group_wxid, name):
|
||||
if group_wxid in room_members:
|
||||
for member in room_members[group_wxid]['member_list']:
|
||||
if member['room_nickname'] == name or member['username'] == name:
|
||||
return member['user_id']
|
||||
return None # 如果没有找到对应的group_wxid或name,则返回None
|
||||
|
||||
|
||||
def download_and_compress_image(url, filename, quality=30):
|
||||
# 确定保存图片的目录
|
||||
directory = os.path.join(os.getcwd(), "tmp")
|
||||
# 如果目录不存在,则创建目录
|
||||
if not os.path.exists(directory):
|
||||
os.makedirs(directory)
|
||||
|
||||
# 下载图片
|
||||
pic_res = requests.get(url, stream=True)
|
||||
image_storage = io.BytesIO()
|
||||
for block in pic_res.iter_content(1024):
|
||||
image_storage.write(block)
|
||||
|
||||
# 检查图片大小并可能进行压缩
|
||||
sz = fsize(image_storage)
|
||||
if sz >= 10 * 1024 * 1024: # 如果图片大于 10 MB
|
||||
logger.info("[wework] image too large, ready to compress, sz={}".format(sz))
|
||||
image_storage = compress_imgfile(image_storage, 10 * 1024 * 1024 - 1)
|
||||
logger.info("[wework] image compressed, sz={}".format(fsize(image_storage)))
|
||||
|
||||
# 将内存缓冲区的指针重置到起始位置
|
||||
image_storage.seek(0)
|
||||
|
||||
# 读取并保存图片
|
||||
image = Image.open(image_storage)
|
||||
image_path = os.path.join(directory, f"{filename}.png")
|
||||
image.save(image_path, "png")
|
||||
|
||||
return image_path
|
||||
|
||||
|
||||
def download_video(url, filename):
|
||||
# 确定保存视频的目录
|
||||
directory = os.path.join(os.getcwd(), "tmp")
|
||||
# 如果目录不存在,则创建目录
|
||||
if not os.path.exists(directory):
|
||||
os.makedirs(directory)
|
||||
|
||||
# 下载视频
|
||||
response = requests.get(url, stream=True)
|
||||
total_size = 0
|
||||
|
||||
video_path = os.path.join(directory, f"{filename}.mp4")
|
||||
|
||||
with open(video_path, 'wb') as f:
|
||||
for block in response.iter_content(1024):
|
||||
total_size += len(block)
|
||||
|
||||
# 如果视频的总大小超过30MB (30 * 1024 * 1024 bytes),则停止下载并返回
|
||||
if total_size > 30 * 1024 * 1024:
|
||||
logger.info("[WX] Video is larger than 30MB, skipping...")
|
||||
return None
|
||||
|
||||
f.write(block)
|
||||
|
||||
return video_path
|
||||
|
||||
|
||||
def create_message(wework_instance, message, is_group):
|
||||
logger.debug(f"正在为{'群聊' if is_group else '单聊'}创建 WeworkMessage")
|
||||
cmsg = WeworkMessage(message, wework=wework_instance, is_group=is_group)
|
||||
logger.debug(f"cmsg:{cmsg}")
|
||||
return cmsg
|
||||
|
||||
|
||||
def handle_message(cmsg, is_group):
|
||||
logger.debug(f"准备用 WeworkChannel 处理{'群聊' if is_group else '单聊'}消息")
|
||||
if is_group:
|
||||
WeworkChannel().handle_group(cmsg)
|
||||
else:
|
||||
WeworkChannel().handle_single(cmsg)
|
||||
logger.debug(f"已用 WeworkChannel 处理完{'群聊' if is_group else '单聊'}消息")
|
||||
|
||||
|
||||
def _check(func):
|
||||
def wrapper(self, cmsg: ChatMessage):
|
||||
msgId = cmsg.msg_id
|
||||
create_time = cmsg.create_time # 消息时间戳
|
||||
if create_time is None:
|
||||
return func(self, cmsg)
|
||||
if int(create_time) < int(time.time()) - 60: # 跳过1分钟前的历史消息
|
||||
logger.debug("[WX]history message {} skipped".format(msgId))
|
||||
return
|
||||
return func(self, cmsg)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
@wework.msg_register(
|
||||
[ntwork.MT_RECV_TEXT_MSG, ntwork.MT_RECV_IMAGE_MSG, 11072, ntwork.MT_RECV_LINK_CARD_MSG,ntwork.MT_RECV_FILE_MSG, ntwork.MT_RECV_VOICE_MSG])
|
||||
def all_msg_handler(wework_instance: ntwork.WeWork, message):
|
||||
logger.debug(f"收到消息: {message}")
|
||||
if 'data' in message:
|
||||
# 首先查找conversation_id,如果没有找到,则查找room_conversation_id
|
||||
conversation_id = message['data'].get('conversation_id', message['data'].get('room_conversation_id'))
|
||||
if conversation_id is not None:
|
||||
is_group = "R:" in conversation_id
|
||||
try:
|
||||
cmsg = create_message(wework_instance=wework_instance, message=message, is_group=is_group)
|
||||
except NotImplementedError as e:
|
||||
logger.error(f"[WX]{message.get('MsgId', 'unknown')} 跳过: {e}")
|
||||
return None
|
||||
delay = random.randint(1, 2)
|
||||
timer = threading.Timer(delay, handle_message, args=(cmsg, is_group))
|
||||
timer.start()
|
||||
else:
|
||||
logger.debug("消息数据中无 conversation_id")
|
||||
return None
|
||||
return None
|
||||
|
||||
|
||||
def accept_friend_with_retries(wework_instance, user_id, corp_id):
|
||||
result = wework_instance.accept_friend(user_id, corp_id)
|
||||
logger.debug(f'result:{result}')
|
||||
|
||||
|
||||
# @wework.msg_register(ntwork.MT_RECV_FRIEND_MSG)
|
||||
# def friend(wework_instance: ntwork.WeWork, message):
|
||||
# data = message["data"]
|
||||
# user_id = data["user_id"]
|
||||
# corp_id = data["corp_id"]
|
||||
# logger.info(f"接收到好友请求,消息内容:{data}")
|
||||
# delay = random.randint(1, 180)
|
||||
# threading.Timer(delay, accept_friend_with_retries, args=(wework_instance, user_id, corp_id)).start()
|
||||
#
|
||||
# return None
|
||||
|
||||
|
||||
def get_with_retry(get_func, max_retries=5, delay=5):
|
||||
retries = 0
|
||||
result = None
|
||||
while retries < max_retries:
|
||||
result = get_func()
|
||||
if result:
|
||||
break
|
||||
logger.warning(f"获取数据失败,重试第{retries + 1}次······")
|
||||
retries += 1
|
||||
time.sleep(delay) # 等待一段时间后重试
|
||||
return result
|
||||
|
||||
|
||||
@singleton
|
||||
class WeworkChannel(ChatChannel):
|
||||
NOT_SUPPORT_REPLYTYPE = []
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def startup(self):
|
||||
smart = conf().get("wework_smart", True)
|
||||
wework.open(smart)
|
||||
logger.info("等待登录······")
|
||||
wework.wait_login()
|
||||
login_info = wework.get_login_info()
|
||||
self.user_id = login_info['user_id']
|
||||
self.name = login_info['nickname']
|
||||
logger.info(f"登录信息:>>>user_id:{self.user_id}>>>>>>>>name:{self.name}")
|
||||
logger.info("静默延迟60s,等待客户端刷新数据,请勿进行任何操作······")
|
||||
time.sleep(60)
|
||||
contacts = get_with_retry(wework.get_external_contacts)
|
||||
rooms = get_with_retry(wework.get_rooms)
|
||||
directory = os.path.join(os.getcwd(), "tmp")
|
||||
if not contacts or not rooms:
|
||||
logger.error("获取contacts或rooms失败,程序退出")
|
||||
ntwork.exit_()
|
||||
os.exit(0)
|
||||
if not os.path.exists(directory):
|
||||
os.makedirs(directory)
|
||||
# 将contacts保存到json文件中
|
||||
with open(os.path.join(directory, 'wework_contacts.json'), 'w', encoding='utf-8') as f:
|
||||
json.dump(contacts, f, ensure_ascii=False, indent=4)
|
||||
with open(os.path.join(directory, 'wework_rooms.json'), 'w', encoding='utf-8') as f:
|
||||
json.dump(rooms, f, ensure_ascii=False, indent=4)
|
||||
# 创建一个空字典来保存结果
|
||||
result = {}
|
||||
|
||||
# 遍历列表中的每个字典
|
||||
for room in rooms['room_list']:
|
||||
# 获取聊天室ID
|
||||
room_wxid = room['conversation_id']
|
||||
|
||||
# 获取聊天室成员
|
||||
room_members = wework.get_room_members(room_wxid)
|
||||
|
||||
# 将聊天室成员保存到结果字典中
|
||||
result[room_wxid] = room_members
|
||||
|
||||
# 将结果保存到json文件中
|
||||
with open(os.path.join(directory, 'wework_room_members.json'), 'w', encoding='utf-8') as f:
|
||||
json.dump(result, f, ensure_ascii=False, indent=4)
|
||||
logger.info("wework程序初始化完成········")
|
||||
run.forever()
|
||||
|
||||
@time_checker
|
||||
@_check
|
||||
def handle_single(self, cmsg: ChatMessage):
|
||||
if cmsg.from_user_id == cmsg.to_user_id:
|
||||
# ignore self reply
|
||||
return
|
||||
if cmsg.ctype == ContextType.VOICE:
|
||||
if not conf().get("speech_recognition"):
|
||||
return
|
||||
logger.debug("[WX]receive voice msg: {}".format(cmsg.content))
|
||||
elif cmsg.ctype == ContextType.IMAGE:
|
||||
logger.debug("[WX]receive image msg: {}".format(cmsg.content))
|
||||
elif cmsg.ctype == ContextType.PATPAT:
|
||||
logger.debug("[WX]receive patpat msg: {}".format(cmsg.content))
|
||||
elif cmsg.ctype == ContextType.TEXT:
|
||||
logger.debug("[WX]receive text msg: {}, cmsg={}".format(json.dumps(cmsg._rawmsg, ensure_ascii=False), cmsg))
|
||||
else:
|
||||
logger.debug("[WX]receive msg: {}, cmsg={}".format(cmsg.content, cmsg))
|
||||
context = self._compose_context(cmsg.ctype, cmsg.content, isgroup=False, msg=cmsg)
|
||||
if context:
|
||||
self.produce(context)
|
||||
|
||||
@time_checker
|
||||
@_check
|
||||
def handle_group(self, cmsg: ChatMessage):
|
||||
if cmsg.ctype == ContextType.VOICE:
|
||||
if not conf().get("speech_recognition"):
|
||||
return
|
||||
logger.debug("[WX]receive voice for group msg: {}".format(cmsg.content))
|
||||
elif cmsg.ctype == ContextType.IMAGE:
|
||||
logger.debug("[WX]receive image for group msg: {}".format(cmsg.content))
|
||||
elif cmsg.ctype in [ContextType.JOIN_GROUP, ContextType.PATPAT]:
|
||||
logger.debug("[WX]receive note msg: {}".format(cmsg.content))
|
||||
elif cmsg.ctype == ContextType.TEXT:
|
||||
pass
|
||||
else:
|
||||
logger.debug("[WX]receive group msg: {}".format(cmsg.content))
|
||||
context = self._compose_context(cmsg.ctype, cmsg.content, isgroup=True, msg=cmsg)
|
||||
if context:
|
||||
self.produce(context)
|
||||
|
||||
# 统一的发送函数,每个Channel自行实现,根据reply的type字段发送不同类型的消息
|
||||
def send(self, reply: Reply, context: Context):
|
||||
logger.debug(f"context: {context}")
|
||||
receiver = context["receiver"]
|
||||
actual_user_id = context["msg"].actual_user_id
|
||||
if reply.type == ReplyType.TEXT or reply.type == ReplyType.TEXT_:
|
||||
match = re.search(r"^@(.*?)\n", reply.content)
|
||||
logger.debug(f"match: {match}")
|
||||
if match:
|
||||
new_content = re.sub(r"^@(.*?)\n", "\n", reply.content)
|
||||
at_list = [actual_user_id]
|
||||
logger.debug(f"new_content: {new_content}")
|
||||
wework.send_room_at_msg(receiver, new_content, at_list)
|
||||
else:
|
||||
wework.send_text(receiver, reply.content)
|
||||
logger.info("[WX] sendMsg={}, receiver={}".format(reply, receiver))
|
||||
elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO:
|
||||
wework.send_text(receiver, reply.content)
|
||||
logger.info("[WX] sendMsg={}, receiver={}".format(reply, receiver))
|
||||
elif reply.type == ReplyType.IMAGE: # 从文件读取图片
|
||||
image_storage = reply.content
|
||||
image_storage.seek(0)
|
||||
# Read data from image_storage
|
||||
data = image_storage.read()
|
||||
# Create a temporary file
|
||||
with tempfile.NamedTemporaryFile(delete=False) as temp:
|
||||
temp_path = temp.name
|
||||
temp.write(data)
|
||||
# Send the image
|
||||
wework.send_image(receiver, temp_path)
|
||||
logger.info("[WX] sendImage, receiver={}".format(receiver))
|
||||
# Remove the temporary file
|
||||
os.remove(temp_path)
|
||||
elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片
|
||||
img_url = reply.content
|
||||
filename = str(uuid.uuid4())
|
||||
|
||||
# 调用你的函数,下载图片并保存为本地文件
|
||||
image_path = download_and_compress_image(img_url, filename)
|
||||
|
||||
wework.send_image(receiver, file_path=image_path)
|
||||
logger.info("[WX] sendImage url={}, receiver={}".format(img_url, receiver))
|
||||
elif reply.type == ReplyType.VIDEO_URL:
|
||||
video_url = reply.content
|
||||
filename = str(uuid.uuid4())
|
||||
video_path = download_video(video_url, filename)
|
||||
|
||||
if video_path is None:
|
||||
# 如果视频太大,下载可能会被跳过,此时 video_path 将为 None
|
||||
wework.send_text(receiver, "抱歉,视频太大了!!!")
|
||||
else:
|
||||
wework.send_video(receiver, video_path)
|
||||
logger.info("[WX] sendVideo, receiver={}".format(receiver))
|
||||
elif reply.type == ReplyType.VOICE:
|
||||
current_dir = os.getcwd()
|
||||
voice_file = reply.content.split("/")[-1]
|
||||
reply.content = os.path.join(current_dir, "tmp", voice_file)
|
||||
wework.send_file(receiver, reply.content)
|
||||
logger.info("[WX] sendFile={}, receiver={}".format(reply.content, receiver))
|
||||
224
channel/wework/wework_message.py
Normal file
224
channel/wework/wework_message.py
Normal file
@@ -0,0 +1,224 @@
|
||||
import datetime
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
import pilk
|
||||
|
||||
from bridge.context import ContextType
|
||||
from channel.chat_message import ChatMessage
|
||||
from common.log import logger
|
||||
from ntwork.const import send_type
|
||||
|
||||
|
||||
def get_with_retry(get_func, max_retries=5, delay=5):
|
||||
retries = 0
|
||||
result = None
|
||||
while retries < max_retries:
|
||||
result = get_func()
|
||||
if result:
|
||||
break
|
||||
logger.warning(f"获取数据失败,重试第{retries + 1}次······")
|
||||
retries += 1
|
||||
time.sleep(delay) # 等待一段时间后重试
|
||||
return result
|
||||
|
||||
|
||||
def get_room_info(wework, conversation_id):
|
||||
logger.debug(f"传入的 conversation_id: {conversation_id}")
|
||||
rooms = wework.get_rooms()
|
||||
if not rooms or 'room_list' not in rooms:
|
||||
logger.error(f"获取群聊信息失败: {rooms}")
|
||||
return None
|
||||
time.sleep(1)
|
||||
logger.debug(f"获取到的群聊信息: {rooms}")
|
||||
for room in rooms['room_list']:
|
||||
if room['conversation_id'] == conversation_id:
|
||||
return room
|
||||
return None
|
||||
|
||||
|
||||
def cdn_download(wework, message, file_name):
|
||||
data = message["data"]
|
||||
aes_key = data["cdn"]["aes_key"]
|
||||
file_size = data["cdn"]["size"]
|
||||
|
||||
# 获取当前工作目录,然后与文件名拼接得到保存路径
|
||||
current_dir = os.getcwd()
|
||||
save_path = os.path.join(current_dir, "tmp", file_name)
|
||||
|
||||
# 下载保存图片到本地
|
||||
if "url" in data["cdn"].keys() and "auth_key" in data["cdn"].keys():
|
||||
url = data["cdn"]["url"]
|
||||
auth_key = data["cdn"]["auth_key"]
|
||||
# result = wework.wx_cdn_download(url, auth_key, aes_key, file_size, save_path) # ntwork库本身接口有问题,缺失了aes_key这个参数
|
||||
"""
|
||||
下载wx类型的cdn文件,以https开头
|
||||
"""
|
||||
data = {
|
||||
'url': url,
|
||||
'auth_key': auth_key,
|
||||
'aes_key': aes_key,
|
||||
'size': file_size,
|
||||
'save_path': save_path
|
||||
}
|
||||
result = wework._WeWork__send_sync(send_type.MT_WXCDN_DOWNLOAD_MSG, data) # 直接用wx_cdn_download的接口内部实现来调用
|
||||
elif "file_id" in data["cdn"].keys():
|
||||
file_type = 2
|
||||
file_id = data["cdn"]["file_id"]
|
||||
result = wework.c2c_cdn_download(file_id, aes_key, file_size, file_type, save_path)
|
||||
else:
|
||||
logger.error(f"something is wrong, data: {data}")
|
||||
return
|
||||
|
||||
# 输出下载结果
|
||||
logger.debug(f"result: {result}")
|
||||
|
||||
|
||||
def c2c_download_and_convert(wework, message, file_name):
|
||||
data = message["data"]
|
||||
aes_key = data["cdn"]["aes_key"]
|
||||
file_size = data["cdn"]["size"]
|
||||
file_type = 5
|
||||
file_id = data["cdn"]["file_id"]
|
||||
|
||||
current_dir = os.getcwd()
|
||||
save_path = os.path.join(current_dir, "tmp", file_name)
|
||||
result = wework.c2c_cdn_download(file_id, aes_key, file_size, file_type, save_path)
|
||||
logger.debug(result)
|
||||
|
||||
# 在下载完SILK文件之后,立即将其转换为WAV文件
|
||||
base_name, _ = os.path.splitext(save_path)
|
||||
wav_file = base_name + ".wav"
|
||||
pilk.silk_to_wav(save_path, wav_file, rate=24000)
|
||||
|
||||
# 删除SILK文件
|
||||
try:
|
||||
os.remove(save_path)
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
|
||||
class WeworkMessage(ChatMessage):
|
||||
def __init__(self, wework_msg, wework, is_group=False):
|
||||
try:
|
||||
super().__init__(wework_msg)
|
||||
self.msg_id = wework_msg['data'].get('conversation_id', wework_msg['data'].get('room_conversation_id'))
|
||||
# 使用.get()防止 'send_time' 键不存在时抛出错误
|
||||
self.create_time = wework_msg['data'].get("send_time")
|
||||
self.is_group = is_group
|
||||
self.wework = wework
|
||||
|
||||
if wework_msg["type"] == 11041: # 文本消息类型
|
||||
if any(substring in wework_msg['data']['content'] for substring in ("该消息类型暂不能展示", "不支持的消息类型")):
|
||||
return
|
||||
self.ctype = ContextType.TEXT
|
||||
self.content = wework_msg['data']['content']
|
||||
elif wework_msg["type"] == 11044: # 语音消息类型,需要缓存文件
|
||||
file_name = datetime.datetime.now().strftime('%Y%m%d%H%M%S') + ".silk"
|
||||
base_name, _ = os.path.splitext(file_name)
|
||||
file_name_2 = base_name + ".wav"
|
||||
current_dir = os.getcwd()
|
||||
self.ctype = ContextType.VOICE
|
||||
self.content = os.path.join(current_dir, "tmp", file_name_2)
|
||||
self._prepare_fn = lambda: c2c_download_and_convert(wework, wework_msg, file_name)
|
||||
elif wework_msg["type"] == 11042: # 图片消息类型,需要下载文件
|
||||
file_name = datetime.datetime.now().strftime('%Y%m%d%H%M%S') + ".jpg"
|
||||
current_dir = os.getcwd()
|
||||
self.ctype = ContextType.IMAGE
|
||||
self.content = os.path.join(current_dir, "tmp", file_name)
|
||||
self._prepare_fn = lambda: cdn_download(wework, wework_msg, file_name)
|
||||
elif wework_msg["type"] == 11045: # 文件消息
|
||||
print("文件消息")
|
||||
print(wework_msg)
|
||||
file_name = datetime.datetime.now().strftime('%Y%m%d%H%M%S')
|
||||
file_name = file_name + wework_msg['data']['cdn']['file_name']
|
||||
current_dir = os.getcwd()
|
||||
self.ctype = ContextType.FILE
|
||||
self.content = os.path.join(current_dir, "tmp", file_name)
|
||||
self._prepare_fn = lambda: cdn_download(wework, wework_msg, file_name)
|
||||
elif wework_msg["type"] == 11047: # 链接消息
|
||||
self.ctype = ContextType.SHARING
|
||||
self.content = wework_msg['data']['url']
|
||||
elif wework_msg["type"] == 11072: # 新成员入群通知
|
||||
self.ctype = ContextType.JOIN_GROUP
|
||||
member_list = wework_msg['data']['member_list']
|
||||
self.actual_user_nickname = member_list[0]['name']
|
||||
self.actual_user_id = member_list[0]['user_id']
|
||||
self.content = f"{self.actual_user_nickname}加入了群聊!"
|
||||
directory = os.path.join(os.getcwd(), "tmp")
|
||||
rooms = get_with_retry(wework.get_rooms)
|
||||
if not rooms:
|
||||
logger.error("更新群信息失败···")
|
||||
else:
|
||||
result = {}
|
||||
for room in rooms['room_list']:
|
||||
# 获取聊天室ID
|
||||
room_wxid = room['conversation_id']
|
||||
|
||||
# 获取聊天室成员
|
||||
room_members = wework.get_room_members(room_wxid)
|
||||
|
||||
# 将聊天室成员保存到结果字典中
|
||||
result[room_wxid] = room_members
|
||||
with open(os.path.join(directory, 'wework_room_members.json'), 'w', encoding='utf-8') as f:
|
||||
json.dump(result, f, ensure_ascii=False, indent=4)
|
||||
logger.info("有新成员加入,已自动更新群成员列表缓存!")
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"Unsupported message type: Type:{} MsgType:{}".format(wework_msg["type"], wework_msg["MsgType"]))
|
||||
|
||||
data = wework_msg['data']
|
||||
login_info = self.wework.get_login_info()
|
||||
logger.debug(f"login_info: {login_info}")
|
||||
nickname = f"{login_info['username']}({login_info['nickname']})" if login_info['nickname'] else login_info['username']
|
||||
user_id = login_info['user_id']
|
||||
|
||||
sender_id = data.get('sender')
|
||||
conversation_id = data.get('conversation_id')
|
||||
sender_name = data.get("sender_name")
|
||||
|
||||
self.from_user_id = user_id if sender_id == user_id else conversation_id
|
||||
self.from_user_nickname = nickname if sender_id == user_id else sender_name
|
||||
self.to_user_id = user_id
|
||||
self.to_user_nickname = nickname
|
||||
self.other_user_nickname = sender_name
|
||||
self.other_user_id = conversation_id
|
||||
|
||||
if self.is_group:
|
||||
conversation_id = data.get('conversation_id') or data.get('room_conversation_id')
|
||||
self.other_user_id = conversation_id
|
||||
if conversation_id:
|
||||
room_info = get_room_info(wework=wework, conversation_id=conversation_id)
|
||||
self.other_user_nickname = room_info.get('nickname', None) if room_info else None
|
||||
self.from_user_nickname = room_info.get('nickname', None) if room_info else None
|
||||
at_list = data.get('at_list', [])
|
||||
tmp_list = []
|
||||
for at in at_list:
|
||||
tmp_list.append(at['nickname'])
|
||||
at_list = tmp_list
|
||||
logger.debug(f"at_list: {at_list}")
|
||||
logger.debug(f"nickname: {nickname}")
|
||||
self.is_at = False
|
||||
if nickname in at_list or login_info['nickname'] in at_list or login_info['username'] in at_list:
|
||||
self.is_at = True
|
||||
self.at_list = at_list
|
||||
|
||||
# 检查消息内容是否包含@用户名。处理复制粘贴的消息,这类消息可能不会触发@通知,但内容中可能包含 "@用户名"。
|
||||
content = data.get('content', '')
|
||||
name = nickname
|
||||
pattern = f"@{re.escape(name)}(\u2005|\u0020)"
|
||||
if re.search(pattern, content):
|
||||
logger.debug(f"Wechaty message {self.msg_id} includes at")
|
||||
self.is_at = True
|
||||
|
||||
if not self.actual_user_id:
|
||||
self.actual_user_id = data.get("sender")
|
||||
self.actual_user_nickname = sender_name if self.ctype != ContextType.JOIN_GROUP else self.actual_user_nickname
|
||||
else:
|
||||
logger.error("群聊消息中没有找到 conversation_id 或 room_conversation_id")
|
||||
|
||||
logger.debug(f"WeworkMessage has been successfully instantiated with message id: {self.msg_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"在 WeworkMessage 的初始化过程中出现错误:{e}")
|
||||
raise e
|
||||
@@ -2,4 +2,24 @@
|
||||
OPEN_AI = "openAI"
|
||||
CHATGPT = "chatGPT"
|
||||
BAIDU = "baidu"
|
||||
CHATGPTONAZURE = "chatGPTOnAzure"
|
||||
XUNFEI = "xunfei"
|
||||
CHATGPTONAZURE = "chatGPTOnAzure"
|
||||
LINKAI = "linkai"
|
||||
CLAUDEAI = "claude"
|
||||
QWEN = "qwen"
|
||||
GEMINI = "gemini"
|
||||
|
||||
# model
|
||||
GPT35 = "gpt-3.5-turbo"
|
||||
GPT4 = "gpt-4"
|
||||
GPT4_TURBO_PREVIEW = "gpt-4-1106-preview"
|
||||
GPT4_VISION_PREVIEW = "gpt-4-vision-preview"
|
||||
WHISPER_1 = "whisper-1"
|
||||
TTS_1 = "tts-1"
|
||||
TTS_1_HD = "tts-1-hd"
|
||||
|
||||
MODEL_LIST = ["gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4", "wenxin", "wenxin-4", "xunfei", "claude", "gpt-4-turbo", GPT4_TURBO_PREVIEW, QWEN, GEMINI]
|
||||
|
||||
# channel
|
||||
FEISHU = "feishu"
|
||||
DINGTALK = "dingtalk"
|
||||
|
||||
33
common/dequeue.py
Normal file
33
common/dequeue.py
Normal file
@@ -0,0 +1,33 @@
|
||||
from queue import Full, Queue
|
||||
from time import monotonic as time
|
||||
|
||||
|
||||
# add implementation of putleft to Queue
|
||||
class Dequeue(Queue):
|
||||
def putleft(self, item, block=True, timeout=None):
|
||||
with self.not_full:
|
||||
if self.maxsize > 0:
|
||||
if not block:
|
||||
if self._qsize() >= self.maxsize:
|
||||
raise Full
|
||||
elif timeout is None:
|
||||
while self._qsize() >= self.maxsize:
|
||||
self.not_full.wait()
|
||||
elif timeout < 0:
|
||||
raise ValueError("'timeout' must be a non-negative number")
|
||||
else:
|
||||
endtime = time() + timeout
|
||||
while self._qsize() >= self.maxsize:
|
||||
remaining = endtime - time()
|
||||
if remaining <= 0.0:
|
||||
raise Full
|
||||
self.not_full.wait(remaining)
|
||||
self._putleft(item)
|
||||
self.unfinished_tasks += 1
|
||||
self.not_empty.notify()
|
||||
|
||||
def putleft_nowait(self, item):
|
||||
return self.putleft(item, block=False)
|
||||
|
||||
def _putleft(self, item):
|
||||
self.queue.appendleft(item)
|
||||
@@ -39,4 +39,4 @@ class ExpiredDict(dict):
|
||||
return [(key, self[key]) for key in self.keys()]
|
||||
|
||||
def __iter__(self):
|
||||
return self.keys().__iter__()
|
||||
return self.keys().__iter__()
|
||||
|
||||
28
common/linkai_client.py
Normal file
28
common/linkai_client.py
Normal file
@@ -0,0 +1,28 @@
|
||||
from bridge.context import Context, ContextType
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from common.log import logger
|
||||
from linkai import LinkAIClient, PushMsg
|
||||
from config import conf
|
||||
|
||||
|
||||
class ChatClient(LinkAIClient):
|
||||
def __init__(self, api_key, host, channel):
|
||||
super().__init__(api_key, host)
|
||||
self.channel = channel
|
||||
self.client_type = channel.channel_type
|
||||
|
||||
def on_message(self, push_msg: PushMsg):
|
||||
session_id = push_msg.session_id
|
||||
msg_content = push_msg.msg_content
|
||||
logger.info(f"receive msg push, session_id={session_id}, msg_content={msg_content}")
|
||||
context = Context()
|
||||
context.type = ContextType.TEXT
|
||||
context["receiver"] = session_id
|
||||
context["isgroup"] = push_msg.is_group
|
||||
self.channel.send(Reply(ReplyType.TEXT, content=msg_content), context)
|
||||
|
||||
|
||||
def start(channel):
|
||||
client = ChatClient(api_key=conf().get("linkai_api_key"),
|
||||
host="link-ai.chat", channel=channel)
|
||||
client.start()
|
||||
@@ -2,15 +2,37 @@ import logging
|
||||
import sys
|
||||
|
||||
|
||||
def _get_logger():
|
||||
log = logging.getLogger('log')
|
||||
log.setLevel(logging.INFO)
|
||||
def _reset_logger(log):
|
||||
for handler in log.handlers:
|
||||
handler.close()
|
||||
log.removeHandler(handler)
|
||||
del handler
|
||||
log.handlers.clear()
|
||||
log.propagate = False
|
||||
console_handle = logging.StreamHandler(sys.stdout)
|
||||
console_handle.setFormatter(logging.Formatter('[%(levelname)s][%(asctime)s][%(filename)s:%(lineno)d] - %(message)s',
|
||||
datefmt='%Y-%m-%d %H:%M:%S'))
|
||||
console_handle.setFormatter(
|
||||
logging.Formatter(
|
||||
"[%(levelname)s][%(asctime)s][%(filename)s:%(lineno)d] - %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
)
|
||||
file_handle = logging.FileHandler("run.log", encoding="utf-8")
|
||||
file_handle.setFormatter(
|
||||
logging.Formatter(
|
||||
"[%(levelname)s][%(asctime)s][%(filename)s:%(lineno)d] - %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
)
|
||||
log.addHandler(file_handle)
|
||||
log.addHandler(console_handle)
|
||||
|
||||
|
||||
def _get_logger():
|
||||
log = logging.getLogger("log")
|
||||
_reset_logger(log)
|
||||
log.setLevel(logging.INFO)
|
||||
return log
|
||||
|
||||
|
||||
# 日志句柄
|
||||
logger = _get_logger()
|
||||
logger = _get_logger()
|
||||
|
||||
3
common/memory.py
Normal file
3
common/memory.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from common.expired_dict import ExpiredDict
|
||||
|
||||
USER_IMAGE_CACHE = ExpiredDict(60 * 3)
|
||||
36
common/package_manager.py
Normal file
36
common/package_manager.py
Normal file
@@ -0,0 +1,36 @@
|
||||
import time
|
||||
|
||||
import pip
|
||||
from pip._internal import main as pipmain
|
||||
|
||||
from common.log import _reset_logger, logger
|
||||
|
||||
|
||||
def install(package):
|
||||
pipmain(["install", package])
|
||||
|
||||
|
||||
def install_requirements(file):
|
||||
pipmain(["install", "-r", file, "--upgrade"])
|
||||
_reset_logger(logger)
|
||||
|
||||
|
||||
def check_dulwich():
|
||||
needwait = False
|
||||
for i in range(2):
|
||||
if needwait:
|
||||
time.sleep(3)
|
||||
needwait = False
|
||||
try:
|
||||
import dulwich
|
||||
|
||||
return
|
||||
except ImportError:
|
||||
try:
|
||||
install("dulwich")
|
||||
except:
|
||||
needwait = True
|
||||
try:
|
||||
import dulwich
|
||||
except ImportError:
|
||||
raise ImportError("Unable to import dulwich")
|
||||
@@ -62,4 +62,4 @@ class SortedDict(dict):
|
||||
return iter(self.keys())
|
||||
|
||||
def __repr__(self):
|
||||
return f'{type(self).__name__}({dict(self)}, sort_func={self.sort_func.__name__}, reverse={self.reverse})'
|
||||
return f"{type(self).__name__}({dict(self)}, sort_func={self.sort_func.__name__}, reverse={self.reverse})"
|
||||
|
||||
@@ -1,7 +1,11 @@
|
||||
import time,re,hashlib
|
||||
import hashlib
|
||||
import re
|
||||
import time
|
||||
|
||||
import config
|
||||
from common.log import logger
|
||||
|
||||
|
||||
def time_checker(f):
|
||||
def _time_checker(self, *args, **kwargs):
|
||||
_config = config.conf()
|
||||
@@ -9,17 +13,17 @@ def time_checker(f):
|
||||
if chat_time_module:
|
||||
chat_start_time = _config.get("chat_start_time", "00:00")
|
||||
chat_stopt_time = _config.get("chat_stop_time", "24:00")
|
||||
time_regex = re.compile(r'^([01]?[0-9]|2[0-4])(:)([0-5][0-9])$') #时间匹配,包含24:00
|
||||
time_regex = re.compile(r"^([01]?[0-9]|2[0-4])(:)([0-5][0-9])$") # 时间匹配,包含24:00
|
||||
|
||||
starttime_format_check = time_regex.match(chat_start_time) # 检查停止时间格式
|
||||
stoptime_format_check = time_regex.match(chat_stopt_time) # 检查停止时间格式
|
||||
chat_time_check = chat_start_time < chat_stopt_time # 确定启动时间<停止时间
|
||||
chat_time_check = chat_start_time < chat_stopt_time # 确定启动时间<停止时间
|
||||
|
||||
# 时间格式检查
|
||||
if not (starttime_format_check and stoptime_format_check and chat_time_check):
|
||||
logger.warn('时间格式不正确,请在config.json中修改您的CHAT_START_TIME/CHAT_STOP_TIME,否则可能会影响您正常使用,开始({})-结束({})'.format(starttime_format_check,stoptime_format_check))
|
||||
if chat_start_time>"23:59":
|
||||
logger.error('启动时间可能存在问题,请修改!')
|
||||
logger.warn("时间格式不正确,请在config.json中修改您的CHAT_START_TIME/CHAT_STOP_TIME,否则可能会影响您正常使用,开始({})-结束({})".format(starttime_format_check, stoptime_format_check))
|
||||
if chat_start_time > "23:59":
|
||||
logger.error("启动时间可能存在问题,请修改!")
|
||||
|
||||
# 服务时间检查
|
||||
now_time = time.strftime("%H:%M", time.localtime())
|
||||
@@ -27,12 +31,12 @@ def time_checker(f):
|
||||
f(self, *args, **kwargs)
|
||||
return None
|
||||
else:
|
||||
if args[0]['Content'] == "#更新配置": # 不在服务时间内也可以更新配置
|
||||
if args[0]["Content"] == "#更新配置": # 不在服务时间内也可以更新配置
|
||||
f(self, *args, **kwargs)
|
||||
else:
|
||||
logger.info('非服务时间内,不接受访问')
|
||||
logger.info("非服务时间内,不接受访问")
|
||||
return None
|
||||
else:
|
||||
f(self, *args, **kwargs) # 未开启时间模块则直接回答
|
||||
return _time_checker
|
||||
|
||||
return _time_checker
|
||||
|
||||
@@ -1,20 +1,18 @@
|
||||
|
||||
import os
|
||||
import pathlib
|
||||
|
||||
from config import conf
|
||||
|
||||
|
||||
class TmpDir(object):
|
||||
"""A temporary directory that is deleted when the object is destroyed.
|
||||
"""
|
||||
"""A temporary directory that is deleted when the object is destroyed."""
|
||||
|
||||
tmpFilePath = pathlib.Path("./tmp/")
|
||||
|
||||
tmpFilePath = pathlib.Path('./tmp/')
|
||||
|
||||
def __init__(self):
|
||||
pathExists = os.path.exists(self.tmpFilePath)
|
||||
if not pathExists and conf().get('speech_recognition') == True:
|
||||
if not pathExists:
|
||||
os.makedirs(self.tmpFilePath)
|
||||
|
||||
def path(self):
|
||||
return str(self.tmpFilePath) + '/'
|
||||
|
||||
return str(self.tmpFilePath) + "/"
|
||||
|
||||
56
common/utils.py
Normal file
56
common/utils.py
Normal file
@@ -0,0 +1,56 @@
|
||||
import io
|
||||
import os
|
||||
from urllib.parse import urlparse
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def fsize(file):
|
||||
if isinstance(file, io.BytesIO):
|
||||
return file.getbuffer().nbytes
|
||||
elif isinstance(file, str):
|
||||
return os.path.getsize(file)
|
||||
elif hasattr(file, "seek") and hasattr(file, "tell"):
|
||||
pos = file.tell()
|
||||
file.seek(0, os.SEEK_END)
|
||||
size = file.tell()
|
||||
file.seek(pos)
|
||||
return size
|
||||
else:
|
||||
raise TypeError("Unsupported type")
|
||||
|
||||
|
||||
def compress_imgfile(file, max_size):
|
||||
if fsize(file) <= max_size:
|
||||
return file
|
||||
file.seek(0)
|
||||
img = Image.open(file)
|
||||
rgb_image = img.convert("RGB")
|
||||
quality = 95
|
||||
while True:
|
||||
out_buf = io.BytesIO()
|
||||
rgb_image.save(out_buf, "JPEG", quality=quality)
|
||||
if fsize(out_buf) <= max_size:
|
||||
return out_buf
|
||||
quality -= 5
|
||||
|
||||
|
||||
def split_string_by_utf8_length(string, max_length, max_split=0):
|
||||
encoded = string.encode("utf-8")
|
||||
start, end = 0, 0
|
||||
result = []
|
||||
while end < len(encoded):
|
||||
if max_split > 0 and len(result) >= max_split:
|
||||
result.append(encoded[start:].decode("utf-8"))
|
||||
break
|
||||
end = min(start + max_length, len(encoded))
|
||||
# 如果当前字节不是 UTF-8 编码的开始字节,则向前查找直到找到开始字节为止
|
||||
while end < len(encoded) and (encoded[end] & 0b11000000) == 0b10000000:
|
||||
end -= 1
|
||||
result.append(encoded[start:end].decode("utf-8"))
|
||||
start = end
|
||||
return result
|
||||
|
||||
|
||||
def get_path_suffix(path):
|
||||
path = urlparse(path).path
|
||||
return os.path.splitext(path)[-1].lstrip('.')
|
||||
@@ -1,18 +1,36 @@
|
||||
{
|
||||
"channel_type": "wx",
|
||||
"model": "",
|
||||
"open_ai_api_key": "YOUR API KEY",
|
||||
"model": "gpt-3.5-turbo",
|
||||
"text_to_image": "dall-e-2",
|
||||
"voice_to_text": "openai",
|
||||
"text_to_voice": "openai",
|
||||
"proxy": "",
|
||||
"use_azure_chatgpt": false,
|
||||
"single_chat_prefix": ["bot", "@bot"],
|
||||
"hot_reload": false,
|
||||
"single_chat_prefix": [
|
||||
"bot",
|
||||
"@bot"
|
||||
],
|
||||
"single_chat_reply_prefix": "[bot] ",
|
||||
"group_chat_prefix": ["@bot"],
|
||||
"group_name_white_list": ["ChatGPT测试群", "ChatGPT测试群2"],
|
||||
"group_chat_in_one_session": ["ChatGPT测试群"],
|
||||
"image_create_prefix": ["画", "看", "找"],
|
||||
"speech_recognition": false,
|
||||
"group_chat_prefix": [
|
||||
"@bot"
|
||||
],
|
||||
"group_name_white_list": [
|
||||
"ChatGPT测试群",
|
||||
"ChatGPT测试群2"
|
||||
],
|
||||
"image_create_prefix": [
|
||||
"画"
|
||||
],
|
||||
"speech_recognition": true,
|
||||
"group_speech_recognition": false,
|
||||
"voice_reply_voice": false,
|
||||
"conversation_max_tokens": 1000,
|
||||
"conversation_max_tokens": 2500,
|
||||
"expires_in_seconds": 3600,
|
||||
"character_desc": "你是ChatGPT, 一个由OpenAI训练的大型语言模型, 你旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。"
|
||||
}
|
||||
"character_desc": "你是ChatGPT, 一个由OpenAI训练的大型语言模型, 你旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。",
|
||||
"temperature": 0.7,
|
||||
"subscribe_msg": "感谢您的关注!\n这里是AI智能助手,可以自由对话。\n支持语音对话。\n支持图片输入。\n支持图片输出,画字开头的消息将按要求创作图片。\n支持tool、角色扮演和文字冒险等丰富的插件。\n输入{trigger_prefix}#help 查看详细指令。",
|
||||
"use_linkai": false,
|
||||
"linkai_api_key": "",
|
||||
"linkai_app_code": ""
|
||||
}
|
||||
|
||||
217
config.py
217
config.py
@@ -1,10 +1,14 @@
|
||||
# encoding:utf-8
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import pickle
|
||||
|
||||
from common.log import logger
|
||||
|
||||
# 将所有可用的配置项写在字典里, 请使用小写字母
|
||||
# 此处的配置值无实际意义,程序不会读取此处的配置,仅用于提示格式,请将配置加入到config.json中
|
||||
available_setting = {
|
||||
# openai api配置
|
||||
"open_ai_api_key": "", # openai api key
|
||||
@@ -12,73 +16,156 @@ available_setting = {
|
||||
"open_ai_api_base": "https://api.openai.com/v1",
|
||||
"proxy": "", # openai使用的代理
|
||||
# chatgpt模型, 当use_azure_chatgpt为true时,其名称为Azure上model deployment名称
|
||||
"model": "gpt-3.5-turbo",
|
||||
"model": "gpt-3.5-turbo", # 还支持 gpt-4, gpt-4-turbo, wenxin, xunfei, qwen
|
||||
"use_azure_chatgpt": False, # 是否使用azure的chatgpt
|
||||
|
||||
"azure_deployment_id": "", # azure 模型部署名称
|
||||
"azure_api_version": "", # azure api版本
|
||||
# Bot触发配置
|
||||
"single_chat_prefix": ["bot", "@bot"], # 私聊时文本需要包含该前缀才能触发机器人回复
|
||||
"single_chat_reply_prefix": "[bot] ", # 私聊时自动回复的前缀,用于区分真人
|
||||
"single_chat_reply_suffix": "", # 私聊时自动回复的后缀,\n 可以换行
|
||||
"group_chat_prefix": ["@bot"], # 群聊时包含该前缀则会触发机器人回复
|
||||
"group_chat_reply_prefix": "", # 群聊时自动回复的前缀
|
||||
"group_chat_reply_suffix": "", # 群聊时自动回复的后缀,\n 可以换行
|
||||
"group_chat_keyword": [], # 群聊时包含该关键词则会触发机器人回复
|
||||
"group_at_off": False, # 是否关闭群聊时@bot的触发
|
||||
"group_name_white_list": ["ChatGPT测试群", "ChatGPT测试群2"], # 开启自动回复的群名称列表
|
||||
"group_name_keyword_white_list": [], # 开启自动回复的群名称关键词列表
|
||||
"group_chat_in_one_session": ["ChatGPT测试群"], # 支持会话上下文共享的群名称
|
||||
"nick_name_black_list": [], # 用户昵称黑名单
|
||||
"group_welcome_msg": "", # 配置新人进群固定欢迎语,不配置则使用随机风格欢迎
|
||||
"trigger_by_self": False, # 是否允许机器人触发
|
||||
"text_to_image": "dall-e-2", # 图片生成模型,可选 dall-e-2, dall-e-3
|
||||
"image_proxy": True, # 是否需要图片代理,国内访问LinkAI时需要
|
||||
"image_create_prefix": ["画", "看", "找"], # 开启图片回复的前缀
|
||||
|
||||
"concurrency_in_session": 1, # 同一会话最多有多少条消息在处理中,大于1可能乱序
|
||||
"image_create_size": "256x256", # 图片大小,可选有 256x256, 512x512, 1024x1024 (dall-e-3默认为1024x1024)
|
||||
"group_chat_exit_group": False,
|
||||
# chatgpt会话参数
|
||||
"expires_in_seconds": 3600, # 无操作会话的过期时间
|
||||
"character_desc": "你是ChatGPT, 一个由OpenAI训练的大型语言模型, 你旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。", # 人格描述
|
||||
# 人格描述
|
||||
"character_desc": "你是ChatGPT, 一个由OpenAI训练的大型语言模型, 你旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。",
|
||||
"conversation_max_tokens": 1000, # 支持上下文记忆的最多字符数
|
||||
|
||||
# chatgpt限流配置
|
||||
"rate_limit_chatgpt": 20, # chatgpt的调用频率限制
|
||||
"rate_limit_dalle": 50, # openai dalle的调用频率限制
|
||||
|
||||
|
||||
# chatgpt api参数 参考https://platform.openai.com/docs/api-reference/chat/create
|
||||
"temperature": 0.9,
|
||||
"top_p": 1,
|
||||
"frequency_penalty": 0,
|
||||
"presence_penalty": 0,
|
||||
|
||||
"request_timeout": 180, # chatgpt请求超时时间,openai接口默认设置为600,对于难问题一般需要较长时间
|
||||
"timeout": 120, # chatgpt重试超时时间,在这个时间内,将会自动重试
|
||||
# Baidu 文心一言参数
|
||||
"baidu_wenxin_model": "eb-instant", # 默认使用ERNIE-Bot-turbo模型
|
||||
"baidu_wenxin_api_key": "", # Baidu api key
|
||||
"baidu_wenxin_secret_key": "", # Baidu secret key
|
||||
# 讯飞星火API
|
||||
"xunfei_app_id": "", # 讯飞应用ID
|
||||
"xunfei_api_key": "", # 讯飞 API key
|
||||
"xunfei_api_secret": "", # 讯飞 API secret
|
||||
# claude 配置
|
||||
"claude_api_cookie": "",
|
||||
"claude_uuid": "",
|
||||
# 通义千问API, 获取方式查看文档 https://help.aliyun.com/document_detail/2587494.html
|
||||
"qwen_access_key_id": "",
|
||||
"qwen_access_key_secret": "",
|
||||
"qwen_agent_key": "",
|
||||
"qwen_app_id": "",
|
||||
"qwen_node_id": "", # 流程编排模型用到的id,如果没有用到qwen_node_id,请务必保持为空字符串
|
||||
# Google Gemini Api Key
|
||||
"gemini_api_key": "",
|
||||
# wework的通用配置
|
||||
"wework_smart": True, # 配置wework是否使用已登录的企业微信,False为多开
|
||||
# 语音设置
|
||||
"speech_recognition": False, # 是否开启语音识别
|
||||
"speech_recognition": True, # 是否开启语音识别
|
||||
"group_speech_recognition": False, # 是否开启群组语音识别
|
||||
"voice_reply_voice": False, # 是否使用语音回复语音,需要设置对应语音合成引擎的api key
|
||||
"voice_to_text": "openai", # 语音识别引擎,支持openai,google
|
||||
"text_to_voice": "baidu", # 语音合成引擎,支持baidu,google,pytts(offline)
|
||||
|
||||
# baidu api的配置, 使用百度语音识别和语音合成时需要
|
||||
"always_reply_voice": False, # 是否一直使用语音回复
|
||||
"voice_to_text": "openai", # 语音识别引擎,支持openai,baidu,google,azure
|
||||
"text_to_voice": "openai", # 语音合成引擎,支持openai,baidu,google,pytts(offline),azure,elevenlabs
|
||||
"text_to_voice_model": "tts-1",
|
||||
"tts_voice_id": "alloy",
|
||||
# baidu 语音api配置, 使用百度语音识别和语音合成时需要
|
||||
"baidu_app_id": "",
|
||||
"baidu_api_key": "",
|
||||
"baidu_secret_key": "",
|
||||
# 1536普通话(支持简单的英文识别) 1737英语 1637粤语 1837四川话 1936普通话远场
|
||||
"baidu_dev_pid": "1536",
|
||||
|
||||
# azure 语音api配置, 使用azure语音识别和语音合成时需要
|
||||
"azure_voice_api_key": "",
|
||||
"azure_voice_region": "japaneast",
|
||||
# elevenlabs 语音api配置
|
||||
"xi_api_key": "", #获取ap的方法可以参考https://docs.elevenlabs.io/api-reference/quick-start/authentication
|
||||
"xi_voice_id": "", #ElevenLabs提供了9种英式、美式等英语发音id,分别是“Adam/Antoni/Arnold/Bella/Domi/Elli/Josh/Rachel/Sam”
|
||||
# 服务时间限制,目前支持itchat
|
||||
"chat_time_module": False, # 是否开启服务时间限制
|
||||
"chat_start_time": "00:00", # 服务开始时间
|
||||
"chat_stop_time": "24:00", # 服务结束时间
|
||||
|
||||
# 翻译api
|
||||
"translate": "baidu", # 翻译api,支持baidu
|
||||
# baidu翻译api的配置
|
||||
"baidu_translate_app_id": "", # 百度翻译api的appid
|
||||
"baidu_translate_app_key": "", # 百度翻译api的秘钥
|
||||
# itchat的配置
|
||||
"hot_reload": False, # 是否开启热重载
|
||||
|
||||
# wechaty的配置
|
||||
"wechaty_puppet_service_token": "", # wechaty的token
|
||||
# wechatmp的配置
|
||||
"wechatmp_token": "", # 微信公众平台的Token
|
||||
"wechatmp_port": 8080, # 微信公众平台的端口,需要端口转发到80或443
|
||||
"wechatmp_app_id": "", # 微信公众平台的appID
|
||||
"wechatmp_app_secret": "", # 微信公众平台的appsecret
|
||||
"wechatmp_aes_key": "", # 微信公众平台的EncodingAESKey,加密模式需要
|
||||
# wechatcom的通用配置
|
||||
"wechatcom_corp_id": "", # 企业微信公司的corpID
|
||||
# wechatcomapp的配置
|
||||
"wechatcomapp_token": "", # 企业微信app的token
|
||||
"wechatcomapp_port": 9898, # 企业微信app的服务端口,不需要端口转发
|
||||
"wechatcomapp_secret": "", # 企业微信app的secret
|
||||
"wechatcomapp_agent_id": "", # 企业微信app的agent_id
|
||||
"wechatcomapp_aes_key": "", # 企业微信app的aes_key
|
||||
|
||||
# 飞书配置
|
||||
"feishu_port": 80, # 飞书bot监听端口
|
||||
"feishu_app_id": "", # 飞书机器人应用APP Id
|
||||
"feishu_app_secret": "", # 飞书机器人APP secret
|
||||
"feishu_token": "", # 飞书 verification token
|
||||
"feishu_bot_name": "", # 飞书机器人的名字
|
||||
|
||||
# 钉钉配置
|
||||
"dingtalk_client_id": "", # 钉钉机器人Client ID
|
||||
"dingtalk_client_secret": "", # 钉钉机器人Client Secret
|
||||
|
||||
# chatgpt指令自定义触发词
|
||||
"clear_memory_commands": ['#清除记忆'], # 重置会话指令
|
||||
|
||||
"clear_memory_commands": ["#清除记忆"], # 重置会话指令,必须以#开头
|
||||
# channel配置
|
||||
"channel_type": "wx", # 通道类型,支持wx,wxy和terminal
|
||||
|
||||
|
||||
"channel_type": "wx", # 通道类型,支持:{wx,wxy,terminal,wechatmp,wechatmp_service,wechatcom_app}
|
||||
"subscribe_msg": "", # 订阅消息, 支持: wechatmp, wechatmp_service, wechatcom_app
|
||||
"debug": False, # 是否开启debug模式,开启后会打印更多日志
|
||||
"appdata_dir": "", # 数据目录
|
||||
# 插件配置
|
||||
"plugin_trigger_prefix": "$", # 规范插件提供聊天相关指令的前缀,建议不要和管理员指令前缀"#"冲突
|
||||
# 是否使用全局插件配置
|
||||
"use_global_plugin_config": False,
|
||||
# 知识库平台配置
|
||||
"use_linkai": False,
|
||||
"linkai_api_key": "",
|
||||
"linkai_app_code": "",
|
||||
"linkai_api_base": "https://api.link-ai.chat", # linkAI服务地址,若国内无法访问或延迟较高可改为 https://api.link-ai.tech
|
||||
}
|
||||
|
||||
|
||||
class Config(dict):
|
||||
def __init__(self, d=None):
|
||||
super().__init__()
|
||||
if d is None:
|
||||
d = {}
|
||||
for k, v in d.items():
|
||||
self[k] = v
|
||||
# user_datas: 用户数据,key为用户名,value为用户数据,也是dict
|
||||
self.user_datas = {}
|
||||
|
||||
def __getitem__(self, key):
|
||||
if key not in available_setting:
|
||||
raise Exception("key {} not in available_setting".format(key))
|
||||
@@ -97,6 +184,31 @@ class Config(dict):
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
# Make sure to return a dictionary to ensure atomic
|
||||
def get_user_data(self, user) -> dict:
|
||||
if self.user_datas.get(user) is None:
|
||||
self.user_datas[user] = {}
|
||||
return self.user_datas[user]
|
||||
|
||||
def load_user_datas(self):
|
||||
try:
|
||||
with open(os.path.join(get_appdata_dir(), "user_datas.pkl"), "rb") as f:
|
||||
self.user_datas = pickle.load(f)
|
||||
logger.info("[Config] User datas loaded.")
|
||||
except FileNotFoundError as e:
|
||||
logger.info("[Config] User datas file not found, ignore.")
|
||||
except Exception as e:
|
||||
logger.info("[Config] User datas error: {}".format(e))
|
||||
self.user_datas = {}
|
||||
|
||||
def save_user_datas(self):
|
||||
try:
|
||||
with open(os.path.join(get_appdata_dir(), "user_datas.pkl"), "wb") as f:
|
||||
pickle.dump(self.user_datas, f)
|
||||
logger.info("[Config] User datas saved.")
|
||||
except Exception as e:
|
||||
logger.info("[Config] User datas error: {}".format(e))
|
||||
|
||||
|
||||
config = Config()
|
||||
|
||||
@@ -105,7 +217,7 @@ def load_config():
|
||||
global config
|
||||
config_path = "./config.json"
|
||||
if not os.path.exists(config_path):
|
||||
logger.info('配置文件不存在,将使用config-template.json模板')
|
||||
logger.info("配置文件不存在,将使用config-template.json模板")
|
||||
config_path = "./config-template.json"
|
||||
|
||||
config_str = read_file(config_path)
|
||||
@@ -119,24 +231,77 @@ def load_config():
|
||||
for name, value in os.environ.items():
|
||||
name = name.lower()
|
||||
if name in available_setting:
|
||||
logger.info(
|
||||
"[INIT] override config by environ args: {}={}".format(name, value))
|
||||
logger.info("[INIT] override config by environ args: {}={}".format(name, value))
|
||||
try:
|
||||
config[name] = eval(value)
|
||||
except:
|
||||
config[name] = value
|
||||
if value == "false":
|
||||
config[name] = False
|
||||
elif value == "true":
|
||||
config[name] = True
|
||||
else:
|
||||
config[name] = value
|
||||
|
||||
if config.get("debug", False):
|
||||
logger.setLevel(logging.DEBUG)
|
||||
logger.debug("[INIT] set log level to DEBUG")
|
||||
|
||||
logger.info("[INIT] load config: {}".format(config))
|
||||
|
||||
config.load_user_datas()
|
||||
|
||||
|
||||
def get_root():
|
||||
return os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
|
||||
def read_file(path):
|
||||
with open(path, mode='r', encoding='utf-8') as f:
|
||||
with open(path, mode="r", encoding="utf-8") as f:
|
||||
return f.read()
|
||||
|
||||
|
||||
def conf():
|
||||
return config
|
||||
|
||||
|
||||
def get_appdata_dir():
|
||||
data_path = os.path.join(get_root(), conf().get("appdata_dir", ""))
|
||||
if not os.path.exists(data_path):
|
||||
logger.info("[INIT] data path not exists, create it: {}".format(data_path))
|
||||
os.makedirs(data_path)
|
||||
return data_path
|
||||
|
||||
|
||||
def subscribe_msg():
|
||||
trigger_prefix = conf().get("single_chat_prefix", [""])[0]
|
||||
msg = conf().get("subscribe_msg", "")
|
||||
return msg.format(trigger_prefix=trigger_prefix)
|
||||
|
||||
|
||||
# global plugin config
|
||||
plugin_config = {}
|
||||
|
||||
|
||||
def write_plugin_config(pconf: dict):
|
||||
"""
|
||||
写入插件全局配置
|
||||
:param pconf: 全量插件配置
|
||||
"""
|
||||
global plugin_config
|
||||
for k in pconf:
|
||||
plugin_config[k.lower()] = pconf[k]
|
||||
|
||||
|
||||
def pconf(plugin_name: str) -> dict:
|
||||
"""
|
||||
根据插件名称获取配置
|
||||
:param plugin_name: 插件名称
|
||||
:return: 该插件的配置项
|
||||
"""
|
||||
return plugin_config.get(plugin_name.lower())
|
||||
|
||||
|
||||
# 全局配置,用于存放全局生效的状态
|
||||
global_config = {
|
||||
"admin_users": []
|
||||
}
|
||||
|
||||
@@ -1,38 +0,0 @@
|
||||
FROM python:3.10-alpine
|
||||
|
||||
LABEL maintainer="foo@bar.com"
|
||||
ARG TZ='Asia/Shanghai'
|
||||
|
||||
ARG CHATGPT_ON_WECHAT_VER
|
||||
|
||||
ENV BUILD_PREFIX=/app
|
||||
|
||||
RUN apk add --no-cache \
|
||||
bash \
|
||||
curl \
|
||||
wget \
|
||||
&& export BUILD_GITHUB_TAG=${CHATGPT_ON_WECHAT_VER:-`curl -sL "https://api.github.com/repos/zhayujie/chatgpt-on-wechat/releases/latest" | \
|
||||
grep '"tag_name":' | \
|
||||
sed -E 's/.*"([^"]+)".*/\1/'`} \
|
||||
&& wget -t 3 -T 30 -nv -O chatgpt-on-wechat-${BUILD_GITHUB_TAG}.tar.gz \
|
||||
https://github.com/zhayujie/chatgpt-on-wechat/archive/refs/tags/${BUILD_GITHUB_TAG}.tar.gz \
|
||||
&& tar -xzf chatgpt-on-wechat-${BUILD_GITHUB_TAG}.tar.gz \
|
||||
&& mv chatgpt-on-wechat-${BUILD_GITHUB_TAG} ${BUILD_PREFIX} \
|
||||
&& rm chatgpt-on-wechat-${BUILD_GITHUB_TAG}.tar.gz \
|
||||
&& cd ${BUILD_PREFIX} \
|
||||
&& cp config-template.json ${BUILD_PREFIX}/config.json \
|
||||
&& /usr/local/bin/python -m pip install --no-cache --upgrade pip \
|
||||
&& pip install --no-cache -r requirements.txt \
|
||||
&& apk del curl wget
|
||||
|
||||
WORKDIR ${BUILD_PREFIX}
|
||||
|
||||
ADD ./entrypoint.sh /entrypoint.sh
|
||||
|
||||
RUN chmod +x /entrypoint.sh \
|
||||
&& adduser -D -h /home/noroot -u 1000 -s /bin/bash noroot \
|
||||
&& chown -R noroot:noroot ${BUILD_PREFIX}
|
||||
|
||||
USER noroot
|
||||
|
||||
ENTRYPOINT ["/entrypoint.sh"]
|
||||
@@ -1,39 +0,0 @@
|
||||
FROM python:3.10
|
||||
|
||||
LABEL maintainer="foo@bar.com"
|
||||
ARG TZ='Asia/Shanghai'
|
||||
|
||||
ARG CHATGPT_ON_WECHAT_VER
|
||||
|
||||
ENV BUILD_PREFIX=/app
|
||||
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y --no-install-recommends \
|
||||
wget \
|
||||
curl \
|
||||
&& rm -rf /var/lib/apt/lists/* \
|
||||
&& export BUILD_GITHUB_TAG=${CHATGPT_ON_WECHAT_VER:-`curl -sL "https://api.github.com/repos/zhayujie/chatgpt-on-wechat/releases/latest" | \
|
||||
grep '"tag_name":' | \
|
||||
sed -E 's/.*"([^"]+)".*/\1/'`} \
|
||||
&& wget -t 3 -T 30 -nv -O chatgpt-on-wechat-${BUILD_GITHUB_TAG}.tar.gz \
|
||||
https://github.com/zhayujie/chatgpt-on-wechat/archive/refs/tags/${BUILD_GITHUB_TAG}.tar.gz \
|
||||
&& tar -xzf chatgpt-on-wechat-${BUILD_GITHUB_TAG}.tar.gz \
|
||||
&& mv chatgpt-on-wechat-${BUILD_GITHUB_TAG} ${BUILD_PREFIX} \
|
||||
&& rm chatgpt-on-wechat-${BUILD_GITHUB_TAG}.tar.gz \
|
||||
&& cd ${BUILD_PREFIX} \
|
||||
&& cp config-template.json ${BUILD_PREFIX}/config.json \
|
||||
&& /usr/local/bin/python -m pip install --no-cache --upgrade pip \
|
||||
&& pip install --no-cache -r requirements.txt
|
||||
|
||||
WORKDIR ${BUILD_PREFIX}
|
||||
|
||||
ADD ./entrypoint.sh /entrypoint.sh
|
||||
|
||||
RUN chmod +x /entrypoint.sh \
|
||||
&& groupadd -r noroot \
|
||||
&& useradd -r -g noroot -s /bin/bash -d /home/noroot noroot \
|
||||
&& chown -R noroot:noroot ${BUILD_PREFIX}
|
||||
|
||||
USER noroot
|
||||
|
||||
ENTRYPOINT ["/entrypoint.sh"]
|
||||
@@ -1,28 +1,35 @@
|
||||
FROM python:3.10-alpine
|
||||
FROM python:3.10-slim-bullseye
|
||||
|
||||
LABEL maintainer="foo@bar.com"
|
||||
ARG TZ='Asia/Shanghai'
|
||||
|
||||
ARG CHATGPT_ON_WECHAT_VER
|
||||
|
||||
RUN echo /etc/apt/sources.list
|
||||
# RUN sed -i 's/deb.debian.org/mirrors.tuna.tsinghua.edu.cn/g' /etc/apt/sources.list
|
||||
ENV BUILD_PREFIX=/app
|
||||
|
||||
ADD . ${BUILD_PREFIX}
|
||||
|
||||
RUN apk add --no-cache bash ffmpeg espeak \
|
||||
RUN apt-get update \
|
||||
&&apt-get install -y --no-install-recommends bash ffmpeg espeak libavcodec-extra\
|
||||
&& cd ${BUILD_PREFIX} \
|
||||
&& cp config-template.json config.json \
|
||||
&& /usr/local/bin/python -m pip install --no-cache --upgrade pip \
|
||||
&& pip install --no-cache -r requirements.txt
|
||||
&& pip install --no-cache -r requirements.txt \
|
||||
&& pip install --no-cache -r requirements-optional.txt \
|
||||
&& pip install azure-cognitiveservices-speech
|
||||
|
||||
WORKDIR ${BUILD_PREFIX}
|
||||
|
||||
ADD docker/entrypoint.sh /entrypoint.sh
|
||||
|
||||
RUN chmod +x /entrypoint.sh \
|
||||
&& adduser -D -h /home/noroot -u 1000 -s /bin/bash noroot \
|
||||
&& chown -R noroot:noroot ${BUILD_PREFIX}
|
||||
&& mkdir -p /home/noroot \
|
||||
&& groupadd -r noroot \
|
||||
&& useradd -r -g noroot -s /bin/bash -d /home/noroot noroot \
|
||||
&& chown -R noroot:noroot /home/noroot ${BUILD_PREFIX} /usr/local/lib
|
||||
|
||||
USER noroot
|
||||
|
||||
ENTRYPOINT ["docker/entrypoint.sh"]
|
||||
ENTRYPOINT ["/entrypoint.sh"]
|
||||
|
||||
@@ -1,16 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
# fetch latest release tag
|
||||
CHATGPT_ON_WECHAT_TAG=`curl -sL "https://api.github.com/repos/zhayujie/chatgpt-on-wechat/releases/latest" | \
|
||||
grep '"tag_name":' | \
|
||||
sed -E 's/.*"([^"]+)".*/\1/'`
|
||||
|
||||
# build image
|
||||
docker build -f Dockerfile.alpine \
|
||||
--build-arg CHATGPT_ON_WECHAT_VER=$CHATGPT_ON_WECHAT_TAG \
|
||||
-t zhayujie/chatgpt-on-wechat .
|
||||
|
||||
# tag image
|
||||
docker tag zhayujie/chatgpt-on-wechat zhayujie/chatgpt-on-wechat:alpine
|
||||
docker tag zhayujie/chatgpt-on-wechat zhayujie/chatgpt-on-wechat:$CHATGPT_ON_WECHAT_TAG-alpine
|
||||
|
||||
@@ -1,15 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
# fetch latest release tag
|
||||
CHATGPT_ON_WECHAT_TAG=`curl -sL "https://api.github.com/repos/zhayujie/chatgpt-on-wechat/releases/latest" | \
|
||||
grep '"tag_name":' | \
|
||||
sed -E 's/.*"([^"]+)".*/\1/'`
|
||||
|
||||
# build image
|
||||
docker build -f Dockerfile.debian \
|
||||
--build-arg CHATGPT_ON_WECHAT_VER=$CHATGPT_ON_WECHAT_TAG \
|
||||
-t zhayujie/chatgpt-on-wechat .
|
||||
|
||||
# tag image
|
||||
docker tag zhayujie/chatgpt-on-wechat zhayujie/chatgpt-on-wechat:debian
|
||||
docker tag zhayujie/chatgpt-on-wechat zhayujie/chatgpt-on-wechat:$CHATGPT_ON_WECHAT_TAG-debian
|
||||
@@ -1,4 +1,8 @@
|
||||
#!/bin/bash
|
||||
|
||||
cd .. && docker build -f Dockerfile \
|
||||
-t zhayujie/chatgpt-on-wechat .
|
||||
unset KUBECONFIG
|
||||
|
||||
cd .. && docker build -f docker/Dockerfile.latest \
|
||||
-t zhayujie/chatgpt-on-wechat .
|
||||
|
||||
docker tag zhayujie/chatgpt-on-wechat zhayujie/chatgpt-on-wechat:$(date +%y%m%d)
|
||||
@@ -1,23 +0,0 @@
|
||||
FROM zhayujie/chatgpt-on-wechat:alpine
|
||||
|
||||
LABEL maintainer="foo@bar.com"
|
||||
ARG TZ='Asia/Shanghai'
|
||||
|
||||
USER root
|
||||
|
||||
RUN apk add --no-cache \
|
||||
ffmpeg \
|
||||
espeak \
|
||||
&& pip install --no-cache \
|
||||
baidu-aip \
|
||||
chardet \
|
||||
SpeechRecognition
|
||||
|
||||
# replace entrypoint
|
||||
ADD ./entrypoint.sh /entrypoint.sh
|
||||
|
||||
RUN chmod +x /entrypoint.sh
|
||||
|
||||
USER noroot
|
||||
|
||||
ENTRYPOINT ["/entrypoint.sh"]
|
||||
@@ -1,24 +0,0 @@
|
||||
FROM zhayujie/chatgpt-on-wechat:debian
|
||||
|
||||
LABEL maintainer="foo@bar.com"
|
||||
ARG TZ='Asia/Shanghai'
|
||||
|
||||
USER root
|
||||
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y --no-install-recommends \
|
||||
ffmpeg \
|
||||
espeak \
|
||||
&& pip install --no-cache \
|
||||
baidu-aip \
|
||||
chardet \
|
||||
SpeechRecognition
|
||||
|
||||
# replace entrypoint
|
||||
ADD ./entrypoint.sh /entrypoint.sh
|
||||
|
||||
RUN chmod +x /entrypoint.sh
|
||||
|
||||
USER noroot
|
||||
|
||||
ENTRYPOINT ["/entrypoint.sh"]
|
||||
@@ -1,24 +0,0 @@
|
||||
version: '2.0'
|
||||
services:
|
||||
chatgpt-on-wechat:
|
||||
build:
|
||||
context: ./
|
||||
dockerfile: Dockerfile.alpine
|
||||
image: zhayujie/chatgpt-on-wechat-voice-reply
|
||||
container_name: chatgpt-on-wechat-voice-reply
|
||||
environment:
|
||||
OPEN_AI_API_KEY: 'YOUR API KEY'
|
||||
OPEN_AI_PROXY: ''
|
||||
SINGLE_CHAT_PREFIX: '["bot", "@bot"]'
|
||||
SINGLE_CHAT_REPLY_PREFIX: '"[bot] "'
|
||||
GROUP_CHAT_PREFIX: '["@bot"]'
|
||||
GROUP_NAME_WHITE_LIST: '["ChatGPT测试群", "ChatGPT测试群2"]'
|
||||
IMAGE_CREATE_PREFIX: '["画", "看", "找"]'
|
||||
CONVERSATION_MAX_TOKENS: 1000
|
||||
SPEECH_RECOGNITION: 'true'
|
||||
CHARACTER_DESC: '你是ChatGPT, 一个由OpenAI训练的大型语言模型, 你旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。'
|
||||
EXPIRES_IN_SECONDS: 3600
|
||||
VOICE_REPLY_VOICE: 'true'
|
||||
BAIDU_APP_ID: 'YOUR BAIDU APP ID'
|
||||
BAIDU_API_KEY: 'YOUR BAIDU API KEY'
|
||||
BAIDU_SECRET_KEY: 'YOUR BAIDU SERVICE KEY'
|
||||
@@ -1,117 +0,0 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
# build prefix
|
||||
CHATGPT_ON_WECHAT_PREFIX=${CHATGPT_ON_WECHAT_PREFIX:-""}
|
||||
# path to config.json
|
||||
CHATGPT_ON_WECHAT_CONFIG_PATH=${CHATGPT_ON_WECHAT_CONFIG_PATH:-""}
|
||||
# execution command line
|
||||
CHATGPT_ON_WECHAT_EXEC=${CHATGPT_ON_WECHAT_EXEC:-""}
|
||||
|
||||
OPEN_AI_API_KEY=${OPEN_AI_API_KEY:-""}
|
||||
OPEN_AI_PROXY=${OPEN_AI_PROXY:-""}
|
||||
SINGLE_CHAT_PREFIX=${SINGLE_CHAT_PREFIX:-""}
|
||||
SINGLE_CHAT_REPLY_PREFIX=${SINGLE_CHAT_REPLY_PREFIX:-""}
|
||||
GROUP_CHAT_PREFIX=${GROUP_CHAT_PREFIX:-""}
|
||||
GROUP_NAME_WHITE_LIST=${GROUP_NAME_WHITE_LIST:-""}
|
||||
IMAGE_CREATE_PREFIX=${IMAGE_CREATE_PREFIX:-""}
|
||||
CONVERSATION_MAX_TOKENS=${CONVERSATION_MAX_TOKENS:-""}
|
||||
SPEECH_RECOGNITION=${SPEECH_RECOGNITION:-""}
|
||||
CHARACTER_DESC=${CHARACTER_DESC:-""}
|
||||
EXPIRES_IN_SECONDS=${EXPIRES_IN_SECONDS:-""}
|
||||
|
||||
VOICE_REPLY_VOICE=${VOICE_REPLY_VOICE:-""}
|
||||
BAIDU_APP_ID=${BAIDU_APP_ID:-""}
|
||||
BAIDU_API_KEY=${BAIDU_API_KEY:-""}
|
||||
BAIDU_SECRET_KEY=${BAIDU_SECRET_KEY:-""}
|
||||
|
||||
# CHATGPT_ON_WECHAT_PREFIX is empty, use /app
|
||||
if [ "$CHATGPT_ON_WECHAT_PREFIX" == "" ] ; then
|
||||
CHATGPT_ON_WECHAT_PREFIX=/app
|
||||
fi
|
||||
|
||||
# CHATGPT_ON_WECHAT_CONFIG_PATH is empty, use '/app/config.json'
|
||||
if [ "$CHATGPT_ON_WECHAT_CONFIG_PATH" == "" ] ; then
|
||||
CHATGPT_ON_WECHAT_CONFIG_PATH=$CHATGPT_ON_WECHAT_PREFIX/config.json
|
||||
fi
|
||||
|
||||
# CHATGPT_ON_WECHAT_EXEC is empty, use ‘python app.py’
|
||||
if [ "$CHATGPT_ON_WECHAT_EXEC" == "" ] ; then
|
||||
CHATGPT_ON_WECHAT_EXEC="python app.py"
|
||||
fi
|
||||
|
||||
# modify content in config.json
|
||||
if [ "$OPEN_AI_API_KEY" != "" ] ; then
|
||||
sed -i "s/\"open_ai_api_key\".*,$/\"open_ai_api_key\": \"$OPEN_AI_API_KEY\",/" $CHATGPT_ON_WECHAT_CONFIG_PATH
|
||||
else
|
||||
echo -e "\033[31m[Warning] You need to set OPEN_AI_API_KEY before running!\033[0m"
|
||||
fi
|
||||
|
||||
# use http_proxy as default
|
||||
if [ "$HTTP_PROXY" != "" ] ; then
|
||||
sed -i "s/\"proxy\".*,$/\"proxy\": \"$HTTP_PROXY\",/" $CHATGPT_ON_WECHAT_CONFIG_PATH
|
||||
fi
|
||||
|
||||
if [ "$OPEN_AI_PROXY" != "" ] ; then
|
||||
sed -i "s/\"proxy\".*,$/\"proxy\": \"$OPEN_AI_PROXY\",/" $CHATGPT_ON_WECHAT_CONFIG_PATH
|
||||
fi
|
||||
|
||||
if [ "$SINGLE_CHAT_PREFIX" != "" ] ; then
|
||||
sed -i "s/\"single_chat_prefix\".*,$/\"single_chat_prefix\": $SINGLE_CHAT_PREFIX,/" $CHATGPT_ON_WECHAT_CONFIG_PATH
|
||||
fi
|
||||
|
||||
if [ "$SINGLE_CHAT_REPLY_PREFIX" != "" ] ; then
|
||||
sed -i "s/\"single_chat_reply_prefix\".*,$/\"single_chat_reply_prefix\": $SINGLE_CHAT_REPLY_PREFIX,/" $CHATGPT_ON_WECHAT_CONFIG_PATH
|
||||
fi
|
||||
|
||||
if [ "$GROUP_CHAT_PREFIX" != "" ] ; then
|
||||
sed -i "s/\"group_chat_prefix\".*,$/\"group_chat_prefix\": $GROUP_CHAT_PREFIX,/" $CHATGPT_ON_WECHAT_CONFIG_PATH
|
||||
fi
|
||||
|
||||
if [ "$GROUP_NAME_WHITE_LIST" != "" ] ; then
|
||||
sed -i "s/\"group_name_white_list\".*,$/\"group_name_white_list\": $GROUP_NAME_WHITE_LIST,/" $CHATGPT_ON_WECHAT_CONFIG_PATH
|
||||
fi
|
||||
|
||||
if [ "$IMAGE_CREATE_PREFIX" != "" ] ; then
|
||||
sed -i "s/\"image_create_prefix\".*,$/\"image_create_prefix\": $IMAGE_CREATE_PREFIX,/" $CHATGPT_ON_WECHAT_CONFIG_PATH
|
||||
fi
|
||||
|
||||
if [ "$CONVERSATION_MAX_TOKENS" != "" ] ; then
|
||||
sed -i "s/\"conversation_max_tokens\".*,$/\"conversation_max_tokens\": $CONVERSATION_MAX_TOKENS,/" $CHATGPT_ON_WECHAT_CONFIG_PATH
|
||||
fi
|
||||
|
||||
if [ "$SPEECH_RECOGNITION" != "" ] ; then
|
||||
sed -i "s/\"speech_recognition\".*,$/\"speech_recognition\": $SPEECH_RECOGNITION,/" $CHATGPT_ON_WECHAT_CONFIG_PATH
|
||||
fi
|
||||
|
||||
if [ "$CHARACTER_DESC" != "" ] ; then
|
||||
sed -i "s/\"character_desc\".*,$/\"character_desc\": \"$CHARACTER_DESC\",/" $CHATGPT_ON_WECHAT_CONFIG_PATH
|
||||
fi
|
||||
|
||||
if [ "$EXPIRES_IN_SECONDS" != "" ] ; then
|
||||
sed -i "s/\"expires_in_seconds\".*$/\"expires_in_seconds\": $EXPIRES_IN_SECONDS/" $CHATGPT_ON_WECHAT_CONFIG_PATH
|
||||
fi
|
||||
|
||||
# append
|
||||
if [ "$BAIDU_SECRET_KEY" != "" ] ; then
|
||||
sed -i "1a \ \ \"baidu_secret_key\": \"$BAIDU_SECRET_KEY\"," $CHATGPT_ON_WECHAT_CONFIG_PATH
|
||||
fi
|
||||
|
||||
if [ "$BAIDU_API_KEY" != "" ] ; then
|
||||
sed -i "1a \ \ \"baidu_api_key\": \"$BAIDU_API_KEY\"," $CHATGPT_ON_WECHAT_CONFIG_PATH
|
||||
fi
|
||||
|
||||
if [ "$BAIDU_APP_ID" != "" ] ; then
|
||||
sed -i "1a \ \ \"baidu_app_id\": \"$BAIDU_APP_ID\"," $CHATGPT_ON_WECHAT_CONFIG_PATH
|
||||
fi
|
||||
|
||||
if [ "$VOICE_REPLY_VOICE" != "" ] ; then
|
||||
sed -i "1a \ \ \"voice_reply_voice\": $VOICE_REPLY_VOICE," $CHATGPT_ON_WECHAT_CONFIG_PATH
|
||||
fi
|
||||
|
||||
# go to prefix dir
|
||||
cd $CHATGPT_ON_WECHAT_PREFIX
|
||||
# excute
|
||||
$CHATGPT_ON_WECHAT_EXEC
|
||||
|
||||
|
||||
@@ -1,20 +1,24 @@
|
||||
version: '2.0'
|
||||
services:
|
||||
chatgpt-on-wechat:
|
||||
build:
|
||||
context: ./
|
||||
dockerfile: Dockerfile.alpine
|
||||
image: zhayujie/chatgpt-on-wechat
|
||||
container_name: sample-chatgpt-on-wechat
|
||||
container_name: chatgpt-on-wechat
|
||||
security_opt:
|
||||
- seccomp:unconfined
|
||||
environment:
|
||||
OPEN_AI_API_KEY: 'YOUR API KEY'
|
||||
OPEN_AI_PROXY: ''
|
||||
MODEL: 'gpt-3.5-turbo'
|
||||
PROXY: ''
|
||||
SINGLE_CHAT_PREFIX: '["bot", "@bot"]'
|
||||
SINGLE_CHAT_REPLY_PREFIX: '"[bot] "'
|
||||
GROUP_CHAT_PREFIX: '["@bot"]'
|
||||
GROUP_NAME_WHITE_LIST: '["ChatGPT测试群", "ChatGPT测试群2"]'
|
||||
IMAGE_CREATE_PREFIX: '["画", "看", "找"]'
|
||||
CONVERSATION_MAX_TOKENS: 1000
|
||||
SPEECH_RECOGNITION: "False"
|
||||
SPEECH_RECOGNITION: 'False'
|
||||
CHARACTER_DESC: '你是ChatGPT, 一个由OpenAI训练的大型语言模型, 你旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。'
|
||||
EXPIRES_IN_SECONDS: 3600
|
||||
EXPIRES_IN_SECONDS: 3600
|
||||
USE_GLOBAL_PLUGIN_CONFIG: 'True'
|
||||
USE_LINKAI: 'False'
|
||||
LINKAI_API_KEY: ''
|
||||
LINKAI_APP_CODE: ''
|
||||
@@ -38,9 +38,9 @@ if [ "$CHATGPT_ON_WECHAT_EXEC" == "" ] ; then
|
||||
fi
|
||||
|
||||
# modify content in config.json
|
||||
if [ "$OPEN_AI_API_KEY" == "YOUR API KEY" ] || [ "$OPEN_AI_API_KEY" == "" ]; then
|
||||
echo -e "\033[31m[Warning] You need to set OPEN_AI_API_KEY before running!\033[0m"
|
||||
fi
|
||||
# if [ "$OPEN_AI_API_KEY" == "YOUR API KEY" ] || [ "$OPEN_AI_API_KEY" == "" ]; then
|
||||
# echo -e "\033[31m[Warning] You need to set OPEN_AI_API_KEY before running!\033[0m"
|
||||
# fi
|
||||
|
||||
|
||||
# go to prefix dir
|
||||
|
||||
@@ -1,16 +0,0 @@
|
||||
OPEN_AI_API_KEY=YOUR API KEY
|
||||
OPEN_AI_PROXY=
|
||||
SINGLE_CHAT_PREFIX=["bot", "@bot"]
|
||||
SINGLE_CHAT_REPLY_PREFIX="[bot] "
|
||||
GROUP_CHAT_PREFIX=["@bot"]
|
||||
GROUP_NAME_WHITE_LIST=["ChatGPT测试群", "ChatGPT测试群2"]
|
||||
IMAGE_CREATE_PREFIX=["画", "看", "找"]
|
||||
CONVERSATION_MAX_TOKENS=1000
|
||||
SPEECH_RECOGNITION=false
|
||||
CHARACTER_DESC=你是ChatGPT, 一个由OpenAI训练的大型语言模型, 你旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。
|
||||
EXPIRES_IN_SECONDS=3600
|
||||
|
||||
# Optional
|
||||
#CHATGPT_ON_WECHAT_PREFIX=/app
|
||||
#CHATGPT_ON_WECHAT_CONFIG_PATH=/app/config.json
|
||||
#CHATGPT_ON_WECHAT_EXEC=python app.py
|
||||
@@ -1,26 +0,0 @@
|
||||
IMG:=`cat Name`
|
||||
MOUNT:=
|
||||
PORT_MAP:=
|
||||
DOTENV:=.env
|
||||
CONTAINER_NAME:=sample-chatgpt-on-wechat
|
||||
|
||||
echo:
|
||||
echo $(IMG)
|
||||
|
||||
run_d:
|
||||
docker rm $(CONTAINER_NAME) || echo
|
||||
docker run -dt --name $(CONTAINER_NAME) $(PORT_MAP) \
|
||||
--env-file=$(DOTENV) \
|
||||
$(MOUNT) $(IMG)
|
||||
|
||||
run_i:
|
||||
docker rm $(CONTAINER_NAME) || echo
|
||||
docker run -it --name $(CONTAINER_NAME) $(PORT_MAP) \
|
||||
--env-file=$(DOTENV) \
|
||||
$(MOUNT) $(IMG)
|
||||
|
||||
stop:
|
||||
docker stop $(CONTAINER_NAME)
|
||||
|
||||
rm: stop
|
||||
docker rm $(CONTAINER_NAME)
|
||||
@@ -1 +0,0 @@
|
||||
zhayujie/chatgpt-on-wechat
|
||||
BIN
docs/images/aigcopen.png
Normal file
BIN
docs/images/aigcopen.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 51 KiB |
BIN
docs/images/contact.jpg
Normal file
BIN
docs/images/contact.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 151 KiB |
BIN
docs/images/planet.jpg
Normal file
BIN
docs/images/planet.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 33 KiB |
9
lib/itchat/LICENSE
Normal file
9
lib/itchat/LICENSE
Normal file
@@ -0,0 +1,9 @@
|
||||
**The MIT License (MIT)**
|
||||
|
||||
Copyright (c) 2017 LittleCoder ([littlecodersh@Github](https://github.com/littlecodersh))
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
@@ -43,6 +43,7 @@ def login(self, enableCmdQR=False, picDir=None, qrCallback=None,
|
||||
logger.warning('itchat has already logged in.')
|
||||
return
|
||||
self.isLogging = True
|
||||
logger.info('Ready to login.')
|
||||
while self.isLogging:
|
||||
uuid = push_login(self)
|
||||
if uuid:
|
||||
@@ -84,7 +85,7 @@ def login(self, enableCmdQR=False, picDir=None, qrCallback=None,
|
||||
if hasattr(loginCallback, '__call__'):
|
||||
r = loginCallback()
|
||||
else:
|
||||
utils.clear_screen()
|
||||
# utils.clear_screen()
|
||||
if os.path.exists(picDir or config.DEFAULT_QR):
|
||||
os.remove(picDir or config.DEFAULT_QR)
|
||||
logger.info('Login successfully as %s' % self.storageClass.nickName)
|
||||
@@ -195,13 +196,17 @@ def process_login_info(core, loginContent):
|
||||
core.loginInfo['logintime'] = int(time.time() * 1e3)
|
||||
core.loginInfo['BaseRequest'] = {}
|
||||
cookies = core.s.cookies.get_dict()
|
||||
skey = re.findall('<skey>(.*?)</skey>', r.text, re.S)[0]
|
||||
pass_ticket = re.findall(
|
||||
'<pass_ticket>(.*?)</pass_ticket>', r.text, re.S)[0]
|
||||
core.loginInfo['skey'] = core.loginInfo['BaseRequest']['Skey'] = skey
|
||||
res = re.findall('<skey>(.*?)</skey>', r.text, re.S)
|
||||
skey = res[0] if res else None
|
||||
res = re.findall(
|
||||
'<pass_ticket>(.*?)</pass_ticket>', r.text, re.S)
|
||||
pass_ticket = res[0] if res else None
|
||||
if skey is not None:
|
||||
core.loginInfo['skey'] = core.loginInfo['BaseRequest']['Skey'] = skey
|
||||
core.loginInfo['wxsid'] = core.loginInfo['BaseRequest']['Sid'] = cookies["wxsid"]
|
||||
core.loginInfo['wxuin'] = core.loginInfo['BaseRequest']['Uin'] = cookies["wxuin"]
|
||||
core.loginInfo['pass_ticket'] = pass_ticket
|
||||
if pass_ticket is not None:
|
||||
core.loginInfo['pass_ticket'] = pass_ticket
|
||||
# A question : why pass_ticket == DeviceID ?
|
||||
# deviceID is only a randomly generated number
|
||||
|
||||
@@ -317,6 +322,8 @@ def start_receiving(self, exitCallback=None, getReceivingFnOnly=False):
|
||||
retryCount += 1
|
||||
logger.error(traceback.format_exc())
|
||||
if self.receivingRetryCount < retryCount:
|
||||
logger.error("Having tried %s times, but still failed. " % (
|
||||
retryCount) + "Stop trying...")
|
||||
self.alive = False
|
||||
else:
|
||||
time.sleep(1)
|
||||
@@ -363,7 +370,7 @@ def sync_check(self):
|
||||
regx = r'window.synccheck={retcode:"(\d+)",selector:"(\d+)"}'
|
||||
pm = re.search(regx, r.text)
|
||||
if pm is None or pm.group(1) != '0':
|
||||
logger.debug('Unexpected sync check result: %s' % r.text)
|
||||
logger.error('Unexpected sync check result: %s' % r.text)
|
||||
return None
|
||||
return pm.group(2)
|
||||
|
||||
|
||||
@@ -25,9 +25,12 @@ def auto_login(self, hotReload=False, statusStorageDir='itchat.pkl',
|
||||
self.useHotReload = hotReload
|
||||
self.hotReloadDir = statusStorageDir
|
||||
if hotReload:
|
||||
if self.load_login_status(statusStorageDir,
|
||||
loginCallback=loginCallback, exitCallback=exitCallback):
|
||||
rval=self.load_login_status(statusStorageDir,
|
||||
loginCallback=loginCallback, exitCallback=exitCallback)
|
||||
if rval:
|
||||
return
|
||||
logger.error('Hot reload failed, logging in normally, error={}'.format(rval))
|
||||
self.logout()
|
||||
self.login(enableCmdQR=enableCmdQR, picDir=picDir, qrCallback=qrCallback,
|
||||
loginCallback=loginCallback, exitCallback=exitCallback)
|
||||
self.dump_login_status(statusStorageDir)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
providers = ['python']
|
||||
|
||||
[phases.setup]
|
||||
nixPkgs = ['python310']
|
||||
cmds = ['apt-get update','apt-get install -y --no-install-recommends ffmpeg espeak']
|
||||
cmds = ['apt-get update','apt-get install -y --no-install-recommends ffmpeg espeak libavcodec-extra']
|
||||
[phases.install]
|
||||
cmds = ['python -m venv /opt/venv && . /opt/venv/bin/activate && pip install -r requirements.txt && pip install -r requirements-optional.txt']
|
||||
[start]
|
||||
cmd = "python ./app.py"
|
||||
@@ -1,6 +1,14 @@
|
||||
**Table of Content**
|
||||
|
||||
- [插件化初衷](#插件化初衷)
|
||||
- [插件安装方法](#插件安装方法)
|
||||
- [插件化实现](#插件化实现)
|
||||
- [插件编写示例](#插件编写示例)
|
||||
- [插件设计建议](#插件设计建议)
|
||||
|
||||
## 插件化初衷
|
||||
|
||||
之前未插件化的代码耦合程度高,如果要定制一些个性化功能(如流量控制、接入`NovelAI`画图平台等),需要了解代码主体,避免影响到其他的功能。在实现多个功能后,不但无法调整功能的优先级顺序,功能的配置项也会变得非常混乱。
|
||||
之前未插件化的代码耦合程度高,如果要定制一些个性化功能(如流量控制、接入`NovelAI`画图平台等),需要了解代码主体,避免影响到其他的功能。多个功能同时存在时,无法调整功能的优先级顺序,功能配置项也非常混乱。
|
||||
|
||||
此时插件化应声而出。
|
||||
|
||||
@@ -11,7 +19,23 @@
|
||||
- [x] 插件化能够自由开关和调整优先级。
|
||||
- [x] 每个插件可在插件文件夹内维护独立的配置文件,方便代码的测试和调试,可以在独立的仓库开发插件。
|
||||
|
||||
PS: 插件目前仅支持`itchat`
|
||||
## 插件安装方法
|
||||
|
||||
在本仓库中预置了一些插件,如果要安装其他仓库的插件,有两种方法。
|
||||
|
||||
- 第一种方法是在将下载的插件文件都解压到"plugins"文件夹的一个单独的文件夹,最终插件的代码都位于"plugins/PLUGIN_NAME/*"中。启动程序后,如果插件的目录结构正确,插件会自动被扫描加载。除此以外,注意你还需要安装文件夹中`requirements.txt`中的依赖。
|
||||
|
||||
- 第二种方法是`Godcmd`插件,它是预置的管理员插件,能够让程序在运行时就能安装插件,它能够自动安装依赖。
|
||||
|
||||
安装插件的命令是"#installp [仓库源](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/plugins/source.json)记录的插件名/仓库地址"。这是管理员命令,认证方法在[这里](https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/godcmd)。
|
||||
|
||||
- 安装[仓库源](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/plugins/source.json)记录的插件:#installp sdwebui
|
||||
|
||||
- 安装指定仓库的插件:#installp https://github.com/lanvent/plugin_sdwebui.git
|
||||
|
||||
在安装之后,需要执行"#scanp"命令来扫描加载新安装的插件(或者重新启动程序)。
|
||||
|
||||
安装插件后需要注意有些插件有自己的配置模板,一般要去掉".template"新建一个配置文件。
|
||||
|
||||
## 插件化实现
|
||||
|
||||
@@ -26,7 +50,9 @@ PS: 插件目前仅支持`itchat`
|
||||
1.收到消息 ---> 2.产生回复 ---> 3.包装回复 ---> 4.发送回复
|
||||
```
|
||||
|
||||
以下是它们的默认处理逻辑(太长不看,可跳过):
|
||||
以下是它们的默认处理逻辑(太长不看,可跳到[插件编写示例](#插件编写示例)):
|
||||
|
||||
**注意以下包含的代码是`v1.1.0`中的片段,已过时,只可用于理解事件,最新的默认代码逻辑请参考[chat_channel](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/channel/chat_channel.py)**
|
||||
|
||||
#### 1. 收到消息
|
||||
|
||||
@@ -67,9 +93,9 @@ PS: 插件目前仅支持`itchat`
|
||||
if context.type == ContextType.TEXT or context.type == ContextType.IMAGE_CREATE:
|
||||
reply = super().build_reply_content(context.content, context) #文字跟画图交付给chatgpt
|
||||
elif context.type == ContextType.VOICE: # 声音先进行语音转文字后,修改Context类型为文字后,再交付给chatgpt
|
||||
msg = context['msg']
|
||||
file_name = TmpDir().path() + context.content
|
||||
msg.download(file_name)
|
||||
cmsg = context['msg']
|
||||
cmsg.prepare()
|
||||
file_name = context.content
|
||||
reply = super().build_voice_to_text(file_name)
|
||||
if reply.type != ReplyType.ERROR and reply.type != ReplyType.INFO:
|
||||
context.content = reply.content # 语音转文字后,将文字内容作为新的context
|
||||
@@ -81,14 +107,14 @@ PS: 插件目前仅支持`itchat`
|
||||
```
|
||||
|
||||
回复`Reply`的定义如下所示,它允许Bot可以回复多类不同的消息。同时也加入了`INFO`和`ERROR`消息类型区分系统提示和系统错误。
|
||||
|
||||
|
||||
```python
|
||||
class ReplyType(Enum):
|
||||
TEXT = 1 # 文本
|
||||
VOICE = 2 # 音频文件
|
||||
IMAGE = 3 # 图片文件
|
||||
IMAGE_URL = 4 # 图片URL
|
||||
|
||||
|
||||
INFO = 9
|
||||
ERROR = 10
|
||||
class Reply:
|
||||
@@ -101,7 +127,7 @@ PS: 插件目前仅支持`itchat`
|
||||
|
||||
根据`Context`和回复`Reply`的类型,对回复的内容进行装饰。目前的装饰有以下两种:
|
||||
|
||||
- `TEXT`文本回复,根据是否在群聊中来决定是艾特接收方还是添加回复的前缀。
|
||||
- `TEXT`文本回复:如果这次消息需要的回复是`VOICE`,进行文字转语音回复之后再次装饰。 否则根据是否在群聊中来决定是艾特接收方还是添加回复的前缀。
|
||||
|
||||
- `INFO`或`ERROR`类型,会在消息前添加对应的系统提示字样。
|
||||
|
||||
@@ -110,8 +136,11 @@ PS: 插件目前仅支持`itchat`
|
||||
```python
|
||||
if reply.type == ReplyType.TEXT:
|
||||
reply_text = reply.content
|
||||
if context.get('desire_rtype') == ReplyType.VOICE:
|
||||
reply = super().build_text_to_voice(reply.content)
|
||||
return self._decorate_reply(context, reply)
|
||||
if context['isgroup']:
|
||||
reply_text = '@' + context['msg']['ActualNickName'] + ' ' + reply_text.strip()
|
||||
reply_text = '@' + context['msg'].actual_user_nickname + ' ' + reply_text.strip()
|
||||
reply_text = conf().get("group_chat_reply_prefix", "")+reply_text
|
||||
else:
|
||||
reply_text = conf().get("single_chat_reply_prefix", "")+reply_text
|
||||
@@ -130,12 +159,12 @@ PS: 插件目前仅支持`itchat`
|
||||
|
||||
目前支持三类触发事件:
|
||||
```
|
||||
1.收到消息
|
||||
---> `ON_HANDLE_CONTEXT`
|
||||
2.产生回复
|
||||
---> `ON_DECORATE_REPLY`
|
||||
3.装饰回复
|
||||
---> `ON_SEND_REPLY`
|
||||
1.收到消息
|
||||
---> `ON_HANDLE_CONTEXT`
|
||||
2.产生回复
|
||||
---> `ON_DECORATE_REPLY`
|
||||
3.装饰回复
|
||||
---> `ON_SEND_REPLY`
|
||||
4.发送回复
|
||||
```
|
||||
|
||||
@@ -151,7 +180,8 @@ PS: 插件目前仅支持`itchat`
|
||||
|
||||
### 1. 创建插件
|
||||
|
||||
在`plugins`目录下创建一个插件文件夹`hello`。然后,在该文件夹中创建一个与文件夹同名的`.py`文件`hello.py`。
|
||||
在`plugins`目录下创建一个插件文件夹`hello`。然后,在该文件夹中创建``__init__.py``文件,在``__init__.py``中将其他编写的模块文件导入。在程序启动时,插件管理器会读取``__init__.py``的所有内容。
|
||||
|
||||
```
|
||||
plugins/
|
||||
└── hello
|
||||
@@ -159,6 +189,11 @@ plugins/
|
||||
└── hello.py
|
||||
```
|
||||
|
||||
``__init__.py``的内容:
|
||||
```
|
||||
from .hello import *
|
||||
```
|
||||
|
||||
### 2. 编写插件类
|
||||
|
||||
在`hello.py`文件中,创建插件类,它继承自`Plugin`。
|
||||
@@ -213,11 +248,11 @@ class Hello(Plugin):
|
||||
if content == "Hello":
|
||||
reply = Reply()
|
||||
reply.type = ReplyType.TEXT
|
||||
msg = e_context['context']['msg']
|
||||
msg:ChatMessage = e_context['context']['msg']
|
||||
if e_context['context']['isgroup']:
|
||||
reply.content = "Hello, " + msg['ActualNickName'] + " from " + msg['User'].get('NickName', "Group")
|
||||
reply.content = f"Hello, {msg.actual_user_nickname} from {msg.from_user_nickname}"
|
||||
else:
|
||||
reply.content = "Hello, " + msg['User'].get('NickName', "My friend")
|
||||
reply.content = f"Hello, {msg.from_user_nickname}"
|
||||
e_context['reply'] = reply
|
||||
e_context.action = EventAction.BREAK_PASS # 事件结束,并跳过处理context的默认逻辑
|
||||
if content == "End":
|
||||
@@ -231,5 +266,8 @@ class Hello(Plugin):
|
||||
|
||||
- 尽情将你想要的个性化功能设计为插件。
|
||||
- 一个插件目录建议只注册一个插件类。建议使用单独的仓库维护插件,便于更新。
|
||||
|
||||
在测试调试好后提交`PR`,把自己的仓库加入到[仓库源](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/plugins/source.json)中。
|
||||
|
||||
- 插件的config文件、使用说明`README.md`、`requirement.txt`等放置在插件目录中。
|
||||
- 默认优先级不要超过管理员插件`Godcmd`的优先级(999),`Godcmd`插件提供了配置管理、插件管理等功能。
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
from .plugin_manager import PluginManager
|
||||
from .event import *
|
||||
from .plugin import *
|
||||
from .plugin_manager import PluginManager
|
||||
|
||||
instance = PluginManager()
|
||||
|
||||
register = instance.register
|
||||
register = instance.register
|
||||
# load_plugins = instance.load_plugins
|
||||
# emit_event = instance.emit_event
|
||||
|
||||
@@ -1,9 +1,27 @@
|
||||
|
||||
## 插件描述
|
||||
|
||||
简易的敏感词插件,暂不支持分词,请自行导入词库到插件文件夹中的`banwords.txt`,每行一个词,一个参考词库是[1](https://github.com/cjh0613/tencent-sensitive-words/blob/main/sensitive_words_lines.txt)。
|
||||
|
||||
`config.json`中能够填写默认的处理行为,目前行为有:
|
||||
使用前将`config.json.template`复制为`config.json`,并自行配置。
|
||||
|
||||
目前插件对消息的默认处理行为有如下两种:
|
||||
|
||||
- `ignore` : 无视这条消息。
|
||||
- `replace` : 将消息中的敏感词替换成"*",并回复违规。
|
||||
|
||||
```json
|
||||
"action": "replace",
|
||||
"reply_filter": true,
|
||||
"reply_action": "ignore"
|
||||
```
|
||||
|
||||
在以上配置项中:
|
||||
|
||||
- `action`: 对用户消息的默认处理行为
|
||||
- `reply_filter`: 是否对ChatGPT的回复也进行敏感词过滤
|
||||
- `reply_action`: 如果开启了回复过滤,对回复的默认处理行为
|
||||
|
||||
## 致谢
|
||||
|
||||
搜索功能实现来自https://github.com/toolgood/ToolGood.Words
|
||||
@@ -0,0 +1 @@
|
||||
from .banwords import *
|
||||
|
||||
@@ -2,65 +2,99 @@
|
||||
|
||||
import json
|
||||
import os
|
||||
|
||||
import plugins
|
||||
from bridge.context import ContextType
|
||||
from bridge.reply import Reply, ReplyType
|
||||
import plugins
|
||||
from plugins import *
|
||||
from common.log import logger
|
||||
from .WordsSearch import WordsSearch
|
||||
from plugins import *
|
||||
|
||||
from .lib.WordsSearch import WordsSearch
|
||||
|
||||
|
||||
@plugins.register(name="Banwords", desc="判断消息中是否有敏感词、决定是否回复。", version="1.0", author="lanvent", desire_priority= 100)
|
||||
@plugins.register(
|
||||
name="Banwords",
|
||||
desire_priority=100,
|
||||
hidden=True,
|
||||
desc="判断消息中是否有敏感词、决定是否回复。",
|
||||
version="1.0",
|
||||
author="lanvent",
|
||||
)
|
||||
class Banwords(Plugin):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
try:
|
||||
curdir=os.path.dirname(__file__)
|
||||
config_path=os.path.join(curdir,"config.json")
|
||||
conf=None
|
||||
if not os.path.exists(config_path):
|
||||
conf={"action":"ignore"}
|
||||
with open(config_path,"w") as f:
|
||||
json.dump(conf,f,indent=4)
|
||||
else:
|
||||
with open(config_path,"r") as f:
|
||||
conf=json.load(f)
|
||||
# load config
|
||||
conf = super().load_config()
|
||||
curdir = os.path.dirname(__file__)
|
||||
if not conf:
|
||||
# 配置不存在则写入默认配置
|
||||
config_path = os.path.join(curdir, "config.json")
|
||||
if not os.path.exists(config_path):
|
||||
conf = {"action": "ignore"}
|
||||
with open(config_path, "w") as f:
|
||||
json.dump(conf, f, indent=4)
|
||||
|
||||
self.searchr = WordsSearch()
|
||||
self.action = conf["action"]
|
||||
banwords_path = os.path.join(curdir,"banwords.txt")
|
||||
with open(banwords_path, 'r', encoding='utf-8') as f:
|
||||
words=[]
|
||||
banwords_path = os.path.join(curdir, "banwords.txt")
|
||||
with open(banwords_path, "r", encoding="utf-8") as f:
|
||||
words = []
|
||||
for line in f:
|
||||
word = line.strip()
|
||||
if word:
|
||||
words.append(word)
|
||||
self.searchr.SetKeywords(words)
|
||||
self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context
|
||||
if conf.get("reply_filter", True):
|
||||
self.handlers[Event.ON_DECORATE_REPLY] = self.on_decorate_reply
|
||||
self.reply_action = conf.get("reply_action", "ignore")
|
||||
logger.info("[Banwords] inited")
|
||||
except Exception as e:
|
||||
logger.warn("Banwords init failed: %s, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/banwords ." % e)
|
||||
|
||||
|
||||
logger.warn("[Banwords] init failed, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/banwords .")
|
||||
raise e
|
||||
|
||||
def on_handle_context(self, e_context: EventContext):
|
||||
|
||||
if e_context['context'].type not in [ContextType.TEXT,ContextType.IMAGE_CREATE]:
|
||||
if e_context["context"].type not in [
|
||||
ContextType.TEXT,
|
||||
ContextType.IMAGE_CREATE,
|
||||
]:
|
||||
return
|
||||
|
||||
content = e_context['context'].content
|
||||
|
||||
content = e_context["context"].content
|
||||
logger.debug("[Banwords] on_handle_context. content: %s" % content)
|
||||
if self.action == "ignore":
|
||||
f = self.searchr.FindFirst(content)
|
||||
if f:
|
||||
logger.info("Banwords: %s" % f["Keyword"])
|
||||
logger.info("[Banwords] %s in message" % f["Keyword"])
|
||||
e_context.action = EventAction.BREAK_PASS
|
||||
return
|
||||
elif self.action == "replace":
|
||||
if self.searchr.ContainsAny(content):
|
||||
reply = Reply(ReplyType.INFO, "发言中包含敏感词,请重试: \n"+self.searchr.Replace(content))
|
||||
e_context['reply'] = reply
|
||||
reply = Reply(ReplyType.INFO, "发言中包含敏感词,请重试: \n" + self.searchr.Replace(content))
|
||||
e_context["reply"] = reply
|
||||
e_context.action = EventAction.BREAK_PASS
|
||||
return
|
||||
|
||||
|
||||
def on_decorate_reply(self, e_context: EventContext):
|
||||
if e_context["reply"].type not in [ReplyType.TEXT]:
|
||||
return
|
||||
|
||||
reply = e_context["reply"]
|
||||
content = reply.content
|
||||
if self.reply_action == "ignore":
|
||||
f = self.searchr.FindFirst(content)
|
||||
if f:
|
||||
logger.info("[Banwords] %s in reply" % f["Keyword"])
|
||||
e_context["reply"] = None
|
||||
e_context.action = EventAction.BREAK_PASS
|
||||
return
|
||||
elif self.reply_action == "replace":
|
||||
if self.searchr.ContainsAny(content):
|
||||
reply = Reply(ReplyType.INFO, "已替换回复中的敏感词: \n" + self.searchr.Replace(content))
|
||||
e_context["reply"] = reply
|
||||
e_context.action = EventAction.CONTINUE
|
||||
return
|
||||
|
||||
def get_help_text(self, **kwargs):
|
||||
return Banwords.desc
|
||||
return "过滤消息中的敏感词。"
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
{
|
||||
"action": "ignore"
|
||||
}
|
||||
"action": "replace",
|
||||
"reply_filter": true,
|
||||
"reply_action": "ignore"
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user