Compare commits
353 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a77e4bfb7a | ||
|
|
654c177333 | ||
|
|
b92669ba33 | ||
|
|
f2e4f6607d | ||
|
|
5ec909c565 | ||
|
|
a84f31d54a | ||
|
|
e0dd21406d | ||
|
|
72f5f7a0b8 | ||
|
|
e3d20085c5 | ||
|
|
8bf1aef801 | ||
|
|
5f7ade20dc | ||
|
|
70d7e52df0 | ||
|
|
8e6afa5614 | ||
|
|
a1ae3804e3 | ||
|
|
814ce7a43b | ||
|
|
628f75009e | ||
|
|
03fc8c1202 | ||
|
|
8c8e996c87 | ||
|
|
933bb0b1fb | ||
|
|
931fbc3eb5 | ||
|
|
3db5e70a3d | ||
|
|
7b19b70d90 | ||
|
|
99b8103d70 | ||
|
|
7167310ccd | ||
|
|
263667a2d4 | ||
|
|
d5cef291f6 | ||
|
|
c8d166e833 | ||
|
|
6e25782d8b | ||
|
|
c3127f7e84 | ||
|
|
7b90fb018b | ||
|
|
e8bc173cd7 | ||
|
|
4d1cdf5207 | ||
|
|
57a473364e | ||
|
|
40b62e9d38 | ||
|
|
ead5f9926b | ||
|
|
814b6753c2 | ||
|
|
ce505251f8 | ||
|
|
5d2a987aaa | ||
|
|
4d67e08723 | ||
|
|
2e71dd5fe2 | ||
|
|
c3b9643227 | ||
|
|
0aad5dc2b7 | ||
|
|
cec900168f | ||
|
|
f9b1c403d5 | ||
|
|
9024b602f5 | ||
|
|
c139fd9a57 | ||
|
|
e299b68163 | ||
|
|
7777a53a82 | ||
|
|
3e185dbbfe | ||
|
|
e8a32af369 | ||
|
|
7b0ec6687e | ||
|
|
ec1c6c7b92 | ||
|
|
8dfaa86760 | ||
|
|
323aebd1be | ||
|
|
436c038a2f | ||
|
|
ccd50ec6c0 | ||
|
|
a7541c2c0f | ||
|
|
c3a57d756c | ||
|
|
aa300a4c98 | ||
|
|
83ea7352b9 | ||
|
|
9050712cd8 | ||
|
|
8d92fdbb6e | ||
|
|
a2442ec1b9 | ||
|
|
71662c9cd9 | ||
|
|
54ff5dbcc2 | ||
|
|
4ab7bd3b51 | ||
|
|
ef3c61a297 | ||
|
|
abf79bf60c | ||
|
|
5d3cecd926 | ||
|
|
16324e7283 | ||
|
|
9f7e2e1572 | ||
|
|
857ce1d530 | ||
|
|
be0d72775d | ||
|
|
7832a2495b | ||
|
|
0506b7f735 | ||
|
|
4c0b7942f0 | ||
|
|
651c840c4a | ||
|
|
2a351ca415 | ||
|
|
49b7106d71 | ||
|
|
8bf633f539 | ||
|
|
0f8efcb4b0 | ||
|
|
c567641c5c | ||
|
|
bdc3820382 | ||
|
|
33a69a7907 | ||
|
|
a4d0e9bbc3 | ||
|
|
afc753e1d2 | ||
|
|
e641a41224 | ||
|
|
79305c0632 | ||
|
|
ef2ce3f09d | ||
|
|
71c18c04fc | ||
|
|
cf84e57f81 | ||
|
|
9421d44579 | ||
|
|
5cd2ae8cc8 | ||
|
|
22d67b3a59 | ||
|
|
e102cbb8c4 | ||
|
|
d90eeb7ee4 | ||
|
|
1989d53031 | ||
|
|
04ef0907b4 | ||
|
|
517b43561c | ||
|
|
ccb8c7227f | ||
|
|
9fbfeeb04f | ||
|
|
8b753a5a1f | ||
|
|
d25cab0627 | ||
|
|
84da0a8a35 | ||
|
|
6f665cffba | ||
|
|
aea8ac2e97 | ||
|
|
8418fa7b45 | ||
|
|
9cc4d0ee07 | ||
|
|
da60831c44 | ||
|
|
0773174a20 | ||
|
|
70e007d8ca | ||
|
|
fcc4d02c2f | ||
|
|
f4a5f00593 | ||
|
|
1170ed6566 | ||
|
|
883f0d449b | ||
|
|
f4c62e7844 | ||
|
|
f0d212a9d2 | ||
|
|
76a8974034 | ||
|
|
0614e822f4 | ||
|
|
6f682c9a2e | ||
|
|
a9fdbc31c5 | ||
|
|
086fdb5856 | ||
|
|
63c8ef4f17 | ||
|
|
736f6523c7 | ||
|
|
8b0b360d25 | ||
|
|
80b84e2ee6 | ||
|
|
b5b7d86f7b | ||
|
|
f20d704390 | ||
|
|
e4e1e2e944 | ||
|
|
6bc7eeb4cc | ||
|
|
656ed5de7b | ||
|
|
a11d695c78 | ||
|
|
c4f9acd5c5 | ||
|
|
5ef929dc42 | ||
|
|
c8cf27b544 | ||
|
|
bb5ecfc398 | ||
|
|
c91e7c35bb | ||
|
|
532d56df2d | ||
|
|
111ad44029 | ||
|
|
6b02bae957 | ||
|
|
6831743416 | ||
|
|
63e2f42636 | ||
|
|
f6e6805453 | ||
|
|
ad77ad8f2b | ||
|
|
469524e8ae | ||
|
|
f4f55d5dfd | ||
|
|
c248d0f3f4 | ||
|
|
648a04b513 | ||
|
|
bdc86c16ec | ||
|
|
21efd17c17 | ||
|
|
aaa75e7b62 | ||
|
|
6d0cef3152 | ||
|
|
c18472289f | ||
|
|
02b7c70a81 | ||
|
|
4eaa2b93c6 | ||
|
|
d347905373 | ||
|
|
f495213b2c | ||
|
|
9b125913ae | ||
|
|
da81f05804 | ||
|
|
9a371a4d4d | ||
|
|
1e92828f1a | ||
|
|
7e724b3fa3 | ||
|
|
3f5b976a87 | ||
|
|
49f2339cc2 | ||
|
|
29f1699de8 | ||
|
|
c415485801 | ||
|
|
6937673472 | ||
|
|
c4f10fe876 | ||
|
|
55ca652ad8 | ||
|
|
3effd5afd1 | ||
|
|
000c2029de | ||
|
|
ab88e3af06 | ||
|
|
b544a4c954 | ||
|
|
baff5fafec | ||
|
|
1673de73ba | ||
|
|
e68936e36e | ||
|
|
7dbd195e45 | ||
|
|
3dc22f98bf | ||
|
|
805e870c18 | ||
|
|
de2c031797 | ||
|
|
3aa571aa1b | ||
|
|
3e4969efe6 | ||
|
|
446e94df76 | ||
|
|
5b26066a4c | ||
|
|
8a80de5c3f | ||
|
|
52a490c87e | ||
|
|
29490741fd | ||
|
|
f0e416455f | ||
|
|
f7a2c97943 | ||
|
|
993853757b | ||
|
|
a3abfb987d | ||
|
|
2711fa1b1b | ||
|
|
1f7afaba07 | ||
|
|
e02c8bff81 | ||
|
|
22391ba1a5 | ||
|
|
a05781ec19 | ||
|
|
f898ed6a2a | ||
|
|
e6d0a15b54 | ||
|
|
49cff026e2 | ||
|
|
08f0023cfd | ||
|
|
e311466ee6 | ||
|
|
56789e68d7 | ||
|
|
87525bb383 | ||
|
|
bb2880191a | ||
|
|
4f1acf26d6 | ||
|
|
fc2d6b21ac | ||
|
|
b9e84fefbd | ||
|
|
91f5ffb2d9 | ||
|
|
70ff2341cb | ||
|
|
74eed93497 | ||
|
|
d02e26c014 | ||
|
|
523cade7c3 | ||
|
|
e22c183ca9 | ||
|
|
3afd99da30 | ||
|
|
f44979f983 | ||
|
|
095f9cc108 | ||
|
|
1089076fce | ||
|
|
cad3b691a9 | ||
|
|
bac21426d3 | ||
|
|
c4a35314cd | ||
|
|
7090722565 | ||
|
|
6d972c7c18 | ||
|
|
6961a88feb | ||
|
|
c41ec13984 | ||
|
|
ca8e06e562 | ||
|
|
200cd33a8e | ||
|
|
1da7991c65 | ||
|
|
fdfb7e369a | ||
|
|
c2b01cc957 | ||
|
|
5de8e94bb4 | ||
|
|
7a2c15d912 | ||
|
|
70344dd214 | ||
|
|
405372d1a7 | ||
|
|
b8c5174da5 | ||
|
|
1f6f9103d9 | ||
|
|
6431487c7a | ||
|
|
8b2d1189db | ||
|
|
b777f27cb7 | ||
|
|
b31c3b124a | ||
|
|
fa1e965fba | ||
|
|
91dc8b4d58 | ||
|
|
6d16ea8830 | ||
|
|
7db4253264 | ||
|
|
4d2b7d9bf9 | ||
|
|
8f6f4acb88 | ||
|
|
f20d84cb37 | ||
|
|
afbdf1d5d5 | ||
|
|
bc8364d594 | ||
|
|
c8d388f70f | ||
|
|
be13cc3194 | ||
|
|
a46320e744 | ||
|
|
071709d263 | ||
|
|
93a32ae5ff | ||
|
|
eee96f226f | ||
|
|
e19a8b479c | ||
|
|
9ef459112e | ||
|
|
e96474bd5c | ||
|
|
6fed719e09 | ||
|
|
99aac76618 | ||
|
|
599f458201 | ||
|
|
2f8099059c | ||
|
|
e24f177832 | ||
|
|
48cc143e88 | ||
|
|
b09b46c045 | ||
|
|
2c6583cc9c | ||
|
|
e381d1bfb8 | ||
|
|
eac619d54f | ||
|
|
a6ef3bc0ce | ||
|
|
118122c541 | ||
|
|
bfdf33ac09 | ||
|
|
fa3370df5b | ||
|
|
f1e51672c5 | ||
|
|
91f97b2728 | ||
|
|
2c542e03fe | ||
|
|
71a11b4267 | ||
|
|
ea642757db | ||
|
|
fb72b601aa | ||
|
|
27e507e744 | ||
|
|
4db19f816f | ||
|
|
096d5776d1 | ||
|
|
3d799eb4d9 | ||
|
|
e4ac3afa4d | ||
|
|
d38e4eed5b | ||
|
|
97787fac91 | ||
|
|
b494ee2f1c | ||
|
|
31ac80a074 | ||
|
|
c8896450f6 | ||
|
|
c662fa4c63 | ||
|
|
db2ee802ca | ||
|
|
d40e915e2b | ||
|
|
c0616e7efa | ||
|
|
01660597e3 | ||
|
|
c5b549f450 | ||
|
|
802d8457bb | ||
|
|
c3a3df67b0 | ||
|
|
5798aeb3cd | ||
|
|
cc81dd9172 | ||
|
|
44fdadda08 | ||
|
|
66a014150b | ||
|
|
1da596639f | ||
|
|
76614ae9e5 | ||
|
|
6ddddffc0f | ||
|
|
dd95f849d4 | ||
|
|
22c7f8fe9e | ||
|
|
3d47be1f49 | ||
|
|
5e399c46b1 | ||
|
|
38e1db7a37 | ||
|
|
8309f7cdbe | ||
|
|
b8cc62ae95 | ||
|
|
c0eb433fa2 | ||
|
|
7f857d66f6 | ||
|
|
93b14d38f4 | ||
|
|
21825faab0 | ||
|
|
1fafd39298 | ||
|
|
23b750fc4f | ||
|
|
90581c840d | ||
|
|
cac7a6228a | ||
|
|
674fbc3f69 | ||
|
|
9577bf1cc7 | ||
|
|
654ebe93e7 | ||
|
|
ecb1b3c491 | ||
|
|
c3d1711edc | ||
|
|
c12c7f10f0 | ||
|
|
f71820bf4e | ||
|
|
748c53c774 | ||
|
|
b290a71bfb | ||
|
|
3204c51eca | ||
|
|
2c4b8a44dc | ||
|
|
943aa05eaa | ||
|
|
d0fd36e7e1 | ||
|
|
f45ff5fd0a | ||
|
|
c22c7102d5 | ||
|
|
11ecfd1b41 | ||
|
|
798e30e5ac | ||
|
|
15e0702329 | ||
|
|
a2bc22c37d | ||
|
|
8093fcc64c | ||
|
|
800419e7cc | ||
|
|
a241dc6785 | ||
|
|
805bea0d5f | ||
|
|
9d394adf24 | ||
|
|
2074f27aff | ||
|
|
283ad48b86 | ||
|
|
07e10a7943 | ||
|
|
2812a5026c | ||
|
|
3a20461abf | ||
|
|
64ae3d1e21 | ||
|
|
a25d7ea65b | ||
|
|
74ebbdd761 | ||
|
|
a0427b569e | ||
|
|
916762cc8c | ||
|
|
977d3bc02e | ||
|
|
854d613a81 |
4
.gitignore
vendored
@@ -14,6 +14,9 @@ tmp
|
||||
plugins.json
|
||||
itchat.pkl
|
||||
*.log
|
||||
logs/
|
||||
workspace
|
||||
config.yaml
|
||||
user_datas.pkl
|
||||
chatgpt_tool_hub/
|
||||
plugins/**/
|
||||
@@ -30,4 +33,5 @@ plugins/banwords/lib/__pycache__
|
||||
!plugins/role
|
||||
!plugins/keyword
|
||||
!plugins/linkai
|
||||
!plugins/agent
|
||||
client_config.json
|
||||
|
||||
170
README.md
@@ -1,42 +1,84 @@
|
||||
# 简介
|
||||
<p align="center"><img src= "https://github.com/user-attachments/assets/31fb4eab-3be4-477d-aa76-82cf62bfd12c" alt="Chatgpt-on-Wechat" width="600" /></p>
|
||||
|
||||
> 本项目是基于大模型的智能对话机器人,支持微信、企业微信、公众号、飞书、钉钉接入,可选择GPT3.5/GPT4.0/Claude/文心一言/讯飞星火/通义千问/Gemini/LinkAI/ZhipuAI,能处理文本、语音和图片,通过插件访问操作系统和互联网等外部资源,支持基于自有知识库定制企业AI应用。
|
||||
<p align="center">
|
||||
<a href="https://github.com/zhayujie/chatgpt-on-wechat/releases/latest"><img src="https://img.shields.io/github/v/release/zhayujie/chatgpt-on-wechat" alt="Latest release"></a>
|
||||
<a href="https://github.com/zhayujie/chatgpt-on-wechat/blob/master/LICENSE"><img src="https://img.shields.io/github/license/zhayujie/chatgpt-on-wechat" alt="License: MIT"></a>
|
||||
<a href="https://github.com/zhayujie/chatgpt-on-wechat"><img src="https://img.shields.io/github/stars/zhayujie/chatgpt-on-wechat?style=flat-square" alt="Stars"></a> <br/>
|
||||
</p>
|
||||
|
||||
chatgpt-on-wechat(简称CoW)项目是基于大模型的智能对话机器人,支持微信公众号、企业微信应用、飞书、钉钉接入,可选择GPT3.5/GPT4.0/Claude/Gemini/LinkAI/ChatGLM/KIMI/文心一言/讯飞星火/通义千问/LinkAI/ModelScope,能处理文本、语音和图片,通过插件访问操作系统和互联网等外部资源,支持基于自有知识库定制企业AI应用。
|
||||
|
||||
# 简介
|
||||
|
||||
最新版本支持的功能如下:
|
||||
|
||||
- [x] **多端部署:** 有多种部署方式可选择且功能完备,目前已支持个人微信、微信公众号和、企业微信、飞书、钉钉等部署方式
|
||||
- [x] **基础对话:** 私聊及群聊的消息智能回复,支持多轮会话上下文记忆,支持 GPT-3.5, GPT-4, claude, Gemini, 文心一言, 讯飞星火, 通义千问,ChatGLM
|
||||
- [x] **语音能力:** 可识别语音消息,通过文字或语音回复,支持 azure, baidu, google, openai(whisper/tts) 等多种语音模型
|
||||
- [x] **图像能力:** 支持图片生成、图片识别、图生图(如照片修复),可选择 Dall-E-3, stable diffusion, replicate, midjourney, CogView-3, vision模型
|
||||
- [x] **丰富插件:** 支持个性化插件扩展,已实现多角色切换、文字冒险、敏感词过滤、聊天记录总结、文档总结和对话、联网搜索等插件
|
||||
- [x] **知识库:** 通过上传知识库文件自定义专属机器人,可作为数字分身、智能客服、私域助手使用,基于 [LinkAI](https://link-ai.tech) 实现
|
||||
- ✅ **多端部署:** 有多种部署方式可选择且功能完备,目前已支持微信公众号、企业微信应用、飞书、钉钉等部署方式
|
||||
- ✅ **基础对话:** 私聊及群聊的消息智能回复,支持多轮会话上下文记忆,支持 GPT-4o系列, GPT-4.1系列, Claude, Gemini, 文心一言, 讯飞星火, 通义千问,ChatGLM-4,Kimi, MiniMax, GiteeAI, ModelScope
|
||||
- ✅ **语音能力:** 可识别语音消息,通过文字或语音回复,支持 azure, baidu, google, openai(whisper/tts) 等多种语音模型
|
||||
- ✅ **图像能力:** 支持图片生成、图片识别、图生图(如照片修复),可选择 Dall-E-3, stable diffusion, replicate, midjourney, CogView-3, vision模型
|
||||
- ✅ **丰富插件:** 支持个性化插件扩展,已实现多角色切换、文字冒险、敏感词过滤、聊天记录总结、文档总结和对话、联网搜索等插件
|
||||
- ✅ **知识库:** 通过上传知识库文件自定义专属机器人,可作为数字分身、智能客服、私域助手使用,基于 [LinkAI](https://link-ai.tech) 实现
|
||||
|
||||
# 演示
|
||||
## 声明
|
||||
|
||||
https://github.com/zhayujie/chatgpt-on-wechat/assets/26161723/d5154020-36e3-41db-8706-40ce9f3f1b1e
|
||||
1. 本项目遵循 [MIT开源协议](/LICENSE),仅用于技术研究和学习,使用本项目时需遵守所在地法律法规、相关政策以及企业章程,禁止用于任何违法或侵犯他人权益的行为
|
||||
2. 境内使用该项目时,请使用国内厂商的大模型服务,并进行必要的内容安全审核及过滤
|
||||
3. 本项目主要接入协同办公平台,推荐使用公众号、企微自建应用、钉钉、飞书等接入通道,其他通道为历史产物已不维护
|
||||
4. 任何个人、团队和企业,无论以何种方式使用该项目、对何对象提供服务,所产生的一切后果,本项目均不承担任何责任
|
||||
|
||||
Demo made by [Visionn](https://www.wangpc.cc/)
|
||||
## 演示
|
||||
|
||||
# 商业支持
|
||||
DEMO视频:https://cdn.link-ai.tech/doc/cow_demo.mp4
|
||||
|
||||
> 我们还提供企业级的 **AI应用平台**,包含知识库、Agent插件、应用管理等能力,支持多平台聚合的应用接入、客户端管理、对话管理,以及提供
|
||||
SaaS服务、私有化部署、稳定托管接入 等多种模式。
|
||||
>
|
||||
> 目前已在私域运营、智能客服、企业效率助手等场景积累了丰富的 AI 解决方案, 在电商、文教、健康、新消费等各行业沉淀了 AI 落地的最佳实践,致力于打造助力中小企业拥抱 AI 的一站式平台。
|
||||
|
||||
企业服务和商用咨询可联系产品顾问:
|
||||
|
||||
<img width="240" src="https://img-1317903499.cos.ap-guangzhou.myqcloud.com/docs/product-manager-qrcode.jpg">
|
||||
|
||||
# 开源社区
|
||||
## 社区
|
||||
|
||||
添加小助手微信加入开源项目交流群:
|
||||
|
||||
<img width="240" src="./docs/images/contact.jpg">
|
||||
<img width="160" src="https://img-1317903499.cos.ap-guangzhou.myqcloud.com/docs/open-community.png">
|
||||
|
||||
# 更新日志
|
||||
<br>
|
||||
|
||||
>**2023.11.11:** [1.5.3版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.5.3) 和 [1.5.4版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.5.4),新增Google Gemini、通义千问模型
|
||||
# 企业服务
|
||||
|
||||
<a href="https://link-ai.tech" target="_blank"><img width="800" src="https://cdn.link-ai.tech/image/link-ai-intro.jpg"></a>
|
||||
|
||||
> [LinkAI](https://link-ai.tech/) 是面向企业和开发者的一站式AI应用平台,聚合多模态大模型、知识库、Agent 插件、工作流等能力,支持一键接入主流平台并进行管理,支持SaaS、私有化部署多种模式。
|
||||
>
|
||||
> LinkAI 目前 已在私域运营、智能客服、企业效率助手等场景积累了丰富的 AI 解决方案, 在电商、文教、健康、新消费、科技制造等各行业沉淀了大模型落地应用的最佳实践,致力于帮助更多企业和开发者拥抱 AI 生产力。
|
||||
|
||||
**企业服务和产品咨询** 可联系产品顾问:
|
||||
|
||||
<img width="160" src="https://cdn.link-ai.tech/consultant-s.jpg">
|
||||
|
||||
<br>
|
||||
|
||||
# 🏷 更新日志
|
||||
|
||||
>**2025.05.23:** [1.7.6版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.7.6) 优化web网页channel、新增[AgentMesh多智能体插件](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/plugins/agent/README.md)、百度语音合成优化、企微应用`access_token`获取优化、支持`claude-4-sonnet`和`claude-4-opus`模型
|
||||
|
||||
>**2025.04.11:** [1.7.5版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.7.5) 新增支持 [wechatferry](https://github.com/zhayujie/chatgpt-on-wechat/pull/2562) 协议、新增 deepseek 模型、新增支持腾讯云语音能力、新增支持 ModelScope 和 Gitee-AI API接口
|
||||
|
||||
>**2024.12.13:** [1.7.4版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.7.4) 新增 Gemini 2.0 模型、新增web channel、解决内存泄漏问题、解决 `#reloadp` 命令重载不生效问题
|
||||
|
||||
>**2024.10.31:** [1.7.3版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.7.3) 程序稳定性提升、数据库功能、Claude模型优化、linkai插件优化、离线通知
|
||||
|
||||
>**2024.09.26:** [1.7.2版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.7.2) 和 [1.7.1版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.7.1) 文心,讯飞等模型优化、o1 模型、快速安装和管理脚本
|
||||
|
||||
>**2024.08.02:** [1.7.0版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.7.0) 新增 讯飞4.0 模型、知识库引用来源展示、相关插件优化
|
||||
|
||||
>**2024.07.19:** [1.6.9版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.6.9) 新增 gpt-4o-mini 模型、阿里语音识别、企微应用渠道路由优化
|
||||
|
||||
>**2024.07.05:** [1.6.8版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.6.8) 和 [1.6.7版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.6.7),Claude3.5, Gemini 1.5 Pro, MiniMax模型、工作流图片输入、模型列表完善
|
||||
|
||||
>**2024.06.04:** [1.6.6版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.6.6) 和 [1.6.5版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.6.5),gpt-4o模型、钉钉流式卡片、讯飞语音识别/合成
|
||||
|
||||
>**2024.04.26:** [1.6.0版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.6.0),新增 Kimi 接入、gpt-4-turbo版本升级、文件总结和语音识别问题修复
|
||||
|
||||
>**2024.03.26:** [1.5.8版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.5.8) 和 [1.5.7版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.5.7),新增 GLM-4、Claude-3 模型,edge-tts 语音支持
|
||||
|
||||
>**2024.01.26:** [1.5.6版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.5.6) 和 [1.5.5版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.5.5),钉钉接入,tool插件升级,4-turbo模型更新
|
||||
|
||||
>**2023.11.11:** [1.5.3版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.5.3) 和 [1.5.4版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.5.4),新增通义千问模型、Google Gemini
|
||||
|
||||
>**2023.11.10:** [1.5.2版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.5.2),新增飞书通道、图像识别对话、黑名单配置
|
||||
|
||||
@@ -48,25 +90,22 @@ SaaS服务、私有化部署、稳定托管接入 等多种模式。
|
||||
|
||||
>**2023.08.08:** 接入百度文心一言模型,通过 [插件](https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/linkai) 支持 Midjourney 绘图
|
||||
|
||||
>**2023.06.12:** 接入 [LinkAI](https://link-ai.tech/console) 平台,可在线创建领域知识库,并接入微信、公众号及企业微信中,打造专属客服机器人。使用参考 [接入文档](https://link-ai.tech/platform/link-app/wechat)。
|
||||
>**2023.06.12:** 接入 [LinkAI](https://link-ai.tech/console) 平台,可在线创建领域知识库,打造专属客服机器人。使用参考 [接入文档](https://link-ai.tech/platform/link-app/wechat)。
|
||||
|
||||
>**2023.04.26:** 支持企业微信应用号部署,兼容插件,并支持语音图片交互,私人助理理想选择,[使用文档](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/channel/wechatcom/README.md)。(contributed by [@lanvent](https://github.com/lanvent) in [#944](https://github.com/zhayujie/chatgpt-on-wechat/pull/944))
|
||||
更早更新日志查看: [归档日志](/docs/version/old-version.md)
|
||||
|
||||
>**2023.04.05:** 支持微信公众号部署,兼容插件,并支持语音图片交互,[使用文档](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/channel/wechatmp/README.md)。(contributed by [@JS00000](https://github.com/JS00000) in [#686](https://github.com/zhayujie/chatgpt-on-wechat/pull/686))
|
||||
<br>
|
||||
|
||||
>**2023.04.05:** 增加能让ChatGPT使用工具的`tool`插件,[使用文档](https://github.com/goldfishh/chatgpt-on-wechat/blob/master/plugins/tool/README.md)。工具相关issue可反馈至[chatgpt-tool-hub](https://github.com/goldfishh/chatgpt-tool-hub)。(contributed by [@goldfishh](https://github.com/goldfishh) in [#663](https://github.com/zhayujie/chatgpt-on-wechat/pull/663))
|
||||
# 🚀 快速开始
|
||||
|
||||
>**2023.03.25:** 支持插件化开发,目前已实现 多角色切换、文字冒险游戏、管理员指令、Stable Diffusion等插件,使用参考 [#578](https://github.com/zhayujie/chatgpt-on-wechat/issues/578)。(contributed by [@lanvent](https://github.com/lanvent) in [#565](https://github.com/zhayujie/chatgpt-on-wechat/pull/565))
|
||||
- 快速开始详细文档:[项目搭建文档](https://docs.link-ai.tech/cow/quick-start)
|
||||
|
||||
>**2023.03.09:** 基于 `whisper API`(后续已接入更多的语音`API`服务) 实现对微信语音消息的解析和回复,添加配置项 `"speech_recognition":true` 即可启用,使用参考 [#415](https://github.com/zhayujie/chatgpt-on-wechat/issues/415)。(contributed by [wanggang1987](https://github.com/wanggang1987) in [#385](https://github.com/zhayujie/chatgpt-on-wechat/pull/385))
|
||||
|
||||
>**2023.02.09:** 扫码登录存在账号限制风险,请谨慎使用,参考[#58](https://github.com/AutumnWhj/ChatGPT-wechat-bot/issues/158)
|
||||
|
||||
# 快速开始
|
||||
|
||||
快速开始文档:[项目搭建文档](https://docs.link-ai.tech/cow/quick-start)
|
||||
|
||||
## 准备
|
||||
- 快速安装脚本,详细使用指导:[一键安装启动脚本](https://github.com/zhayujie/chatgpt-on-wechat/wiki/%E4%B8%80%E9%94%AE%E5%AE%89%E8%A3%85%E5%90%AF%E5%8A%A8%E8%84%9A%E6%9C%AC)
|
||||
```bash
|
||||
bash <(curl -sS https://cdn.link-ai.tech/code/cow/install.sh)
|
||||
```
|
||||
- 项目管理脚本,详细使用指导:[项目管理脚本](https://github.com/zhayujie/chatgpt-on-wechat/wiki/%E9%A1%B9%E7%9B%AE%E7%AE%A1%E7%90%86%E8%84%9A%E6%9C%AC)
|
||||
## 一、准备
|
||||
|
||||
### 1. 账号注册
|
||||
|
||||
@@ -74,7 +113,7 @@ SaaS服务、私有化部署、稳定托管接入 等多种模式。
|
||||
|
||||
> 默认对话模型是 openai 的 gpt-3.5-turbo,计费方式是约每 1000tokens (约750个英文单词 或 500汉字,包含请求和回复) 消耗 $0.002,图片生成是Dell E模型,每张消耗 $0.016。
|
||||
|
||||
项目同时也支持使用 LinkAI 接口,无需代理,可使用 文心、讯飞、GPT-3、GPT-4 等模型,支持 定制化知识库、联网搜索、MJ绘图、文档总结和对话等能力。修改配置即可一键切换,参考 [接入文档](https://link-ai.tech/platform/link-app/wechat)。
|
||||
项目同时也支持使用 LinkAI 接口,无需代理,可使用 Kimi、文心、讯飞、GPT-3.5、GPT-4o 等模型,支持 定制化知识库、联网搜索、MJ绘图、文档总结、工作流等能力。修改配置即可一键使用,参考 [接入文档](https://link-ai.tech/platform/link-app/wechat)。
|
||||
|
||||
### 2.运行环境
|
||||
|
||||
@@ -105,7 +144,7 @@ pip3 install -r requirements-optional.txt
|
||||
```
|
||||
> 如果某项依赖安装失败可注释掉对应的行再继续
|
||||
|
||||
## 配置
|
||||
## 二、配置
|
||||
|
||||
配置文件的模板在根目录的`config-template.json`中,需复制该模板创建最终生效的 `config.json` 文件:
|
||||
|
||||
@@ -113,13 +152,14 @@ pip3 install -r requirements-optional.txt
|
||||
cp config-template.json config.json
|
||||
```
|
||||
|
||||
然后在`config.json`中填入配置,以下是对默认配置的说明,可根据需要进行自定义修改(请去掉注释):
|
||||
然后在`config.json`中填入配置,以下是对默认配置的说明,可根据需要进行自定义修改(注意实际使用时请去掉注释,保证JSON格式的完整):
|
||||
|
||||
```bash
|
||||
# config.json文件内容示例
|
||||
{
|
||||
"open_ai_api_key": "YOUR API KEY", # 填入上面创建的 OpenAI API KEY
|
||||
"model": "gpt-3.5-turbo", # 模型名称, 支持 gpt-3.5-turbo, gpt-3.5-turbo-16k, gpt-4, wenxin, xunfei
|
||||
"model": "gpt-4o-mini", # 模型名称, 支持 gpt-4o-mini, gpt-4.1, gpt-4o, wenxin, xunfei, glm-4, claude-3-7-sonnet-latest, moonshot等
|
||||
"open_ai_api_key": "YOUR API KEY", # 如果使用openAI模型则填入上面创建的 OpenAI API KEY
|
||||
"open_ai_api_base": "https://api.openai.com/v1", # OpenAI接口代理地址
|
||||
"proxy": "", # 代理客户端的ip和端口,国内环境开启代理的需要填写该项,如 "127.0.0.1:7890"
|
||||
"single_chat_prefix": ["bot", "@bot"], # 私聊时文本需要包含该前缀才能触发机器人回复
|
||||
"single_chat_reply_prefix": "[bot] ", # 私聊时自动回复的前缀,用于区分真人
|
||||
@@ -130,15 +170,13 @@ pip3 install -r requirements-optional.txt
|
||||
"conversation_max_tokens": 1000, # 支持上下文记忆的最多字符数
|
||||
"speech_recognition": false, # 是否开启语音识别
|
||||
"group_speech_recognition": false, # 是否开启群组语音识别
|
||||
"use_azure_chatgpt": false, # 是否使用Azure ChatGPT service代替openai ChatGPT service. 当设置为true时需要设置 open_ai_api_base,如 https://xxx.openai.azure.com/
|
||||
"azure_deployment_id": "", # 采用Azure ChatGPT时,模型部署名称
|
||||
"azure_api_version": "", # 采用Azure ChatGPT时,API版本
|
||||
"character_desc": "你是ChatGPT, 一个由OpenAI训练的大型语言模型, 你旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。", # 人格描述
|
||||
"voice_reply_voice": false, # 是否使用语音回复语音
|
||||
"character_desc": "你是基于大语言模型的AI智能助手,旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。", # 人格描述
|
||||
# 订阅消息,公众号和企业微信channel中请填写,当被订阅时会自动回复,可使用特殊占位符。目前支持的占位符有{trigger_prefix},在程序中它会自动替换成bot的触发词。
|
||||
"subscribe_msg": "感谢您的关注!\n这里是ChatGPT,可以自由对话。\n支持语音对话。\n支持图片输出,画字开头的消息将按要求创作图片。\n支持角色扮演和文字冒险等丰富插件。\n输入{trigger_prefix}#help 查看详细指令。",
|
||||
"use_linkai": false, # 是否使用LinkAI接口,默认关闭,开启后可国内访问,使用知识库和MJ
|
||||
"linkai_api_key": "", # LinkAI Api Key
|
||||
"linkai_app_code": "" # LinkAI 应用code
|
||||
"linkai_app_code": "" # LinkAI 应用或工作流code
|
||||
}
|
||||
```
|
||||
**配置说明:**
|
||||
@@ -159,11 +197,11 @@ pip3 install -r requirements-optional.txt
|
||||
|
||||
+ 添加 `"speech_recognition": true` 将开启语音识别,默认使用openai的whisper模型识别为文字,同时以文字回复,该参数仅支持私聊 (注意由于语音消息无法匹配前缀,一旦开启将对所有语音自动回复,支持语音触发画图);
|
||||
+ 添加 `"group_speech_recognition": true` 将开启群组语音识别,默认使用openai的whisper模型识别为文字,同时以文字回复,参数仅支持群聊 (会匹配group_chat_prefix和group_chat_keyword, 支持语音触发画图);
|
||||
+ 添加 `"voice_reply_voice": true` 将开启语音回复语音(同时作用于私聊和群聊),但是需要配置对应语音合成平台的key,由于itchat协议的限制,只能发送语音mp3文件,若使用wechaty则回复的是微信语音。
|
||||
+ 添加 `"voice_reply_voice": true` 将开启语音回复语音(同时作用于私聊和群聊)
|
||||
|
||||
**4.其他配置**
|
||||
|
||||
+ `model`: 模型名称,目前支持 `gpt-3.5-turbo`, `text-davinci-003`, `gpt-4`, `gpt-4-32k`, `wenxin` , `claude` , `xunfei`(其中gpt-4 api暂未完全开放,申请通过后可使用)
|
||||
+ `model`: 模型名称,目前支持 `gpt-4o-mini`, `gpt-4.1`, `gpt-4o`, `gpt-3.5-turbo`, `wenxin` , `claude` , `gemini`, `glm-4`, `xunfei`, `moonshot`等,全部模型名称参考[common/const.py](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/common/const.py)文件
|
||||
+ `temperature`,`frequency_penalty`,`presence_penalty`: Chat API接口参数,详情参考[OpenAI官方文档。](https://platform.openai.com/docs/api-reference/chat)
|
||||
+ `proxy`:由于目前 `openai` 接口国内无法访问,需配置代理客户端的地址,详情参考 [#351](https://github.com/zhayujie/chatgpt-on-wechat/issues/351)
|
||||
+ 对于图像生成,在满足个人或群组触发条件外,还需要额外的关键词前缀来触发,对应配置 `image_create_prefix `
|
||||
@@ -171,7 +209,7 @@ pip3 install -r requirements-optional.txt
|
||||
+ `conversation_max_tokens`:表示能够记忆的上下文最大字数(一问一答为一组对话,如果累积的对话字数超出限制,就会优先移除最早的一组对话)
|
||||
+ `rate_limit_chatgpt`,`rate_limit_dalle`:每分钟最高问答速率、画图速率,超速后排队按序处理。
|
||||
+ `clear_memory_commands`: 对话内指令,主动清空前文记忆,字符串数组可自定义指令别名。
|
||||
+ `hot_reload`: 程序退出后,暂存微信扫码状态,默认关闭。
|
||||
+ `hot_reload`: 程序退出后,暂存等于状态,默认关闭。
|
||||
+ `character_desc` 配置中保存着你对机器人说的一段话,他会记住这段话并作为他的设定,你可以为他定制任何人格 (关于会话上下文的更多内容参考该 [issue](https://github.com/zhayujie/chatgpt-on-wechat/issues/43))
|
||||
+ `subscribe_msg`:订阅消息,公众号和企业微信channel中请填写,当被订阅时会自动回复, 可使用特殊占位符。目前支持的占位符有{trigger_prefix},在程序中它会自动替换成bot的触发词。
|
||||
|
||||
@@ -179,11 +217,11 @@ pip3 install -r requirements-optional.txt
|
||||
|
||||
+ `use_linkai`: 是否使用LinkAI接口,开启后可国内访问,使用知识库和 `Midjourney` 绘画, 参考 [文档](https://link-ai.tech/platform/link-app/wechat)
|
||||
+ `linkai_api_key`: LinkAI Api Key,可在 [控制台](https://link-ai.tech/console/interface) 创建
|
||||
+ `linkai_app_code`: LinkAI 应用code,选填
|
||||
+ `linkai_app_code`: LinkAI 应用或工作流的code,选填
|
||||
|
||||
**本说明文档可能会未及时更新,当前所有可选的配置项均在该[`config.py`](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/config.py)中列出。**
|
||||
|
||||
## 运行
|
||||
## 三、运行
|
||||
|
||||
### 1.本地运行
|
||||
|
||||
@@ -193,7 +231,7 @@ pip3 install -r requirements-optional.txt
|
||||
python3 app.py # windows环境下该命令通常为 python app.py
|
||||
```
|
||||
|
||||
终端输出二维码后,使用微信进行扫码,当输出 "Start auto replying" 时表示自动回复程序已经成功运行了(注意:用于登录的微信需要在支付处已完成实名认证)。扫码登录后你的账号就成为机器人了,可以在微信手机端通过配置的关键词触发自动回复 (任意好友发送消息给你,或是自己发消息给好友),参考[#142](https://github.com/zhayujie/chatgpt-on-wechat/issues/142)。
|
||||
终端输出二维码后,进行扫码登录,当输出 "Start auto replying" 时表示自动回复程序已经成功运行了(注意:用于登录的账号需要在支付处已完成实名认证)。扫码登录后你的账号就成为机器人了,可以在手机端通过配置的关键词触发自动回复 (任意好友发送消息给你,或是自己发消息给好友),参考[#142](https://github.com/zhayujie/chatgpt-on-wechat/issues/142)。
|
||||
|
||||
### 2.服务器部署
|
||||
|
||||
@@ -215,7 +253,7 @@ nohup python3 app.py & tail -f nohup.out # 在后台运行程序并通
|
||||
|
||||
> 前提是需要安装好 `docker` 及 `docker-compose`,安装成功的表现是执行 `docker -v` 和 `docker-compose version` (或 docker compose version) 可以查看到版本号,可前往 [docker官网](https://docs.docker.com/engine/install/) 进行下载。
|
||||
|
||||
#### (1) 下载 docker-compose.yml 文件
|
||||
**(1) 下载 docker-compose.yml 文件**
|
||||
|
||||
```bash
|
||||
wget https://open-1317903499.cos.ap-guangzhou.myqcloud.com/docker-compose.yml
|
||||
@@ -223,7 +261,7 @@ wget https://open-1317903499.cos.ap-guangzhou.myqcloud.com/docker-compose.yml
|
||||
|
||||
下载完成后打开 `docker-compose.yml` 修改所需配置,如 `OPEN_AI_API_KEY` 和 `GROUP_NAME_WHITE_LIST` 等。
|
||||
|
||||
#### (2) 启动容器
|
||||
**(2) 启动容器**
|
||||
|
||||
在 `docker-compose.yml` 所在目录下执行以下命令启动容器:
|
||||
|
||||
@@ -244,7 +282,7 @@ sudo docker compose up -d
|
||||
sudo docker logs -f chatgpt-on-wechat
|
||||
```
|
||||
|
||||
#### (3) 插件使用
|
||||
**(3) 插件使用**
|
||||
|
||||
如果需要在docker容器中修改插件配置,可通过挂载的方式完成,将 [插件配置文件](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/plugins/config.json.template)
|
||||
重命名为 `config.json`,放置于 `docker-compose.yml` 相同目录下,并在 `docker-compose.yml` 中的 `chatgpt-on-wechat` 部分下添加 `volumes` 映射:
|
||||
@@ -253,7 +291,7 @@ sudo docker logs -f chatgpt-on-wechat
|
||||
volumes:
|
||||
- ./config.json:/app/plugins/config.json
|
||||
```
|
||||
|
||||
**注**:采用docker方式部署的详细教程可以参考:[docker部署CoW项目](https://www.wangpc.cc/ai/docker-deploy-cow/)
|
||||
### 4. Railway部署
|
||||
|
||||
> Railway 每月提供5刀和最多500小时的免费额度。 (07.11更新: 目前大部分账号已无法免费部署)
|
||||
@@ -266,16 +304,22 @@ volumes:
|
||||
|
||||
[](https://railway.app/template/qApznZ?referralCode=RC3znh)
|
||||
|
||||
## 常见问题
|
||||
<br>
|
||||
|
||||
# 🔎 常见问题
|
||||
|
||||
FAQs: <https://github.com/zhayujie/chatgpt-on-wechat/wiki/FAQs>
|
||||
|
||||
或直接在线咨询 [项目小助手](https://link-ai.tech/app/Kv2fXJcH) (beta版本,语料完善中,回复仅供参考)
|
||||
或直接在线咨询 [项目小助手](https://link-ai.tech/app/Kv2fXJcH) (语料持续完善中,回复仅供参考)
|
||||
|
||||
## 开发
|
||||
# 🛠️ 开发
|
||||
|
||||
欢迎接入更多应用,参考 [Terminal代码](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/channel/terminal/terminal_channel.py) 实现接收和发送消息逻辑即可接入。 同时欢迎增加新的插件,参考 [插件说明文档](https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins)。
|
||||
|
||||
## 联系
|
||||
# ✉ 联系
|
||||
|
||||
欢迎提交PR、Issues,以及Star支持一下。程序运行遇到问题可以查看 [常见问题列表](https://github.com/zhayujie/chatgpt-on-wechat/wiki/FAQs) ,其次前往 [Issues](https://github.com/zhayujie/chatgpt-on-wechat/issues) 中搜索。个人开发者可加入开源交流群参与更多讨论,企业用户可联系[产品顾问](https://img-1317903499.cos.ap-guangzhou.myqcloud.com/docs/product-manager-qrcode.jpg)咨询。
|
||||
|
||||
# 🌟 贡献者
|
||||
|
||||

|
||||
|
||||
2
app.py
@@ -27,7 +27,7 @@ def sigterm_handler_wrap(_signo):
|
||||
|
||||
def start_channel(channel_name: str):
|
||||
channel = channel_factory.create_channel(channel_name)
|
||||
if channel_name in ["wx", "wxy", "terminal", "wechatmp", "wechatmp_service", "wechatcom_app", "wework",
|
||||
if channel_name in ["wx", "wxy", "terminal", "wechatmp","web", "wechatmp_service", "wechatcom_app", "wework",
|
||||
const.FEISHU, const.DINGTALK]:
|
||||
PluginManager().load_plugins()
|
||||
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
# encoding:utf-8
|
||||
|
||||
import requests, json
|
||||
import requests
|
||||
import json
|
||||
from common import const
|
||||
from bot.bot import Bot
|
||||
from bot.session_manager import SessionManager
|
||||
from bridge.context import ContextType
|
||||
@@ -16,9 +18,20 @@ class BaiduWenxinBot(Bot):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
wenxin_model = conf().get("baidu_wenxin_model") or "eb-instant"
|
||||
if conf().get("model") and conf().get("model") == "wenxin-4":
|
||||
wenxin_model = "completions_pro"
|
||||
wenxin_model = conf().get("baidu_wenxin_model")
|
||||
self.prompt_enabled = conf().get("baidu_wenxin_prompt_enabled")
|
||||
if self.prompt_enabled:
|
||||
self.prompt = conf().get("character_desc", "")
|
||||
if self.prompt == "":
|
||||
logger.warn("[BAIDU] Although you enabled model prompt, character_desc is not specified.")
|
||||
if wenxin_model is not None:
|
||||
wenxin_model = conf().get("baidu_wenxin_model") or "eb-instant"
|
||||
else:
|
||||
if conf().get("model") and conf().get("model") == const.WEN_XIN:
|
||||
wenxin_model = "completions"
|
||||
elif conf().get("model") and conf().get("model") == const.WEN_XIN_4:
|
||||
wenxin_model = "completions_pro"
|
||||
|
||||
self.sessions = SessionManager(BaiduWenxinSession, model=wenxin_model)
|
||||
|
||||
def reply(self, query, context=None):
|
||||
@@ -76,7 +89,7 @@ class BaiduWenxinBot(Bot):
|
||||
headers = {
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
payload = {'messages': session.messages}
|
||||
payload = {'messages': session.messages, 'system': self.prompt} if self.prompt_enabled else {'messages': session.messages}
|
||||
response = requests.request("POST", url, headers=headers, data=json.dumps(payload))
|
||||
response_text = json.loads(response.text)
|
||||
logger.info(f"[BAIDU] response text={response_text}")
|
||||
@@ -94,7 +107,7 @@ class BaiduWenxinBot(Bot):
|
||||
logger.warn("[BAIDU] Exception: {}".format(e))
|
||||
need_retry = False
|
||||
self.sessions.clear_session(session.session_id)
|
||||
result = {"completion_tokens": 0, "content": "出错了: {}".format(e)}
|
||||
result = {"total_tokens": 0, "completion_tokens": 0, "content": "出错了: {}".format(e)}
|
||||
return result
|
||||
|
||||
def get_access_token(self):
|
||||
|
||||
@@ -43,11 +43,15 @@ def create_bot(bot_type):
|
||||
elif bot_type == const.CLAUDEAI:
|
||||
from bot.claude.claude_ai_bot import ClaudeAIBot
|
||||
return ClaudeAIBot()
|
||||
|
||||
elif bot_type == const.CLAUDEAPI:
|
||||
from bot.claudeapi.claude_api_bot import ClaudeAPIBot
|
||||
return ClaudeAPIBot()
|
||||
elif bot_type == const.QWEN:
|
||||
from bot.ali.ali_qwen_bot import AliQwenBot
|
||||
return AliQwenBot()
|
||||
|
||||
elif bot_type == const.QWEN_DASHSCOPE:
|
||||
from bot.dashscope.dashscope_bot import DashscopeBot
|
||||
return DashscopeBot()
|
||||
elif bot_type == const.GEMINI:
|
||||
from bot.gemini.google_gemini_bot import GoogleGeminiBot
|
||||
return GoogleGeminiBot()
|
||||
@@ -56,5 +60,17 @@ def create_bot(bot_type):
|
||||
from bot.zhipuai.zhipuai_bot import ZHIPUAIBot
|
||||
return ZHIPUAIBot()
|
||||
|
||||
elif bot_type == const.MOONSHOT:
|
||||
from bot.moonshot.moonshot_bot import MoonshotBot
|
||||
return MoonshotBot()
|
||||
|
||||
elif bot_type == const.MiniMax:
|
||||
from bot.minimax.minimax_bot import MinimaxBot
|
||||
return MinimaxBot()
|
||||
|
||||
elif bot_type == const.MODELSCOPE:
|
||||
from bot.modelscope.modelscope_bot import ModelScopeBot
|
||||
return ModelScopeBot()
|
||||
|
||||
|
||||
raise RuntimeError
|
||||
|
||||
@@ -5,7 +5,7 @@ import time
|
||||
import openai
|
||||
import openai.error
|
||||
import requests
|
||||
|
||||
from common import const
|
||||
from bot.bot import Bot
|
||||
from bot.chatgpt.chat_gpt_session import ChatGPTSession
|
||||
from bot.openai.open_ai_image import OpenAIImage
|
||||
@@ -15,7 +15,7 @@ from bridge.reply import Reply, ReplyType
|
||||
from common.log import logger
|
||||
from common.token_bucket import TokenBucket
|
||||
from config import conf, load_config
|
||||
|
||||
from bot.baidu.baidu_wenxin_session import BaiduWenxinSession
|
||||
|
||||
# OpenAI对话模型API (可用)
|
||||
class ChatGPTBot(Bot, OpenAIImage):
|
||||
@@ -30,10 +30,12 @@ class ChatGPTBot(Bot, OpenAIImage):
|
||||
openai.proxy = proxy
|
||||
if conf().get("rate_limit_chatgpt"):
|
||||
self.tb4chatgpt = TokenBucket(conf().get("rate_limit_chatgpt", 20))
|
||||
|
||||
conf_model = conf().get("model") or "gpt-3.5-turbo"
|
||||
self.sessions = SessionManager(ChatGPTSession, model=conf().get("model") or "gpt-3.5-turbo")
|
||||
# o1相关模型不支持system prompt,暂时用文心模型的session
|
||||
|
||||
self.args = {
|
||||
"model": conf().get("model") or "gpt-3.5-turbo", # 对话模型的名称
|
||||
"model": conf_model, # 对话模型的名称
|
||||
"temperature": conf().get("temperature", 0.9), # 值在[0,1]之间,越大表示回复越具有不确定性
|
||||
# "max_tokens":4096, # 回复最大的字符数
|
||||
"top_p": conf().get("top_p", 1),
|
||||
@@ -42,6 +44,12 @@ class ChatGPTBot(Bot, OpenAIImage):
|
||||
"request_timeout": conf().get("request_timeout", None), # 请求超时时间,openai接口默认设置为600,对于难问题一般需要较长时间
|
||||
"timeout": conf().get("request_timeout", None), # 重试超时时间,在这个时间内,将会自动重试
|
||||
}
|
||||
# o1相关模型固定了部分参数,暂时去掉
|
||||
if conf_model in [const.O1, const.O1_MINI]:
|
||||
self.sessions = SessionManager(BaiduWenxinSession, model=conf().get("model") or const.O1_MINI)
|
||||
remove_keys = ["temperature", "top_p", "frequency_penalty", "presence_penalty"]
|
||||
for key in remove_keys:
|
||||
self.args.pop(key, None) # 如果键不存在,使用 None 来避免抛出错误
|
||||
|
||||
def reply(self, query, context=None):
|
||||
# acquire reply content
|
||||
@@ -171,24 +179,70 @@ class AzureChatGPTBot(ChatGPTBot):
|
||||
self.args["deployment_id"] = conf().get("azure_deployment_id")
|
||||
|
||||
def create_img(self, query, retry_count=0, api_key=None):
|
||||
api_version = "2022-08-03-preview"
|
||||
url = "{}dalle/text-to-image?api-version={}".format(openai.api_base, api_version)
|
||||
api_key = api_key or openai.api_key
|
||||
headers = {"api-key": api_key, "Content-Type": "application/json"}
|
||||
try:
|
||||
body = {"caption": query, "resolution": conf().get("image_create_size", "256x256")}
|
||||
submission = requests.post(url, headers=headers, json=body)
|
||||
operation_location = submission.headers["Operation-Location"]
|
||||
retry_after = submission.headers["Retry-after"]
|
||||
status = ""
|
||||
image_url = ""
|
||||
while status != "Succeeded":
|
||||
logger.info("waiting for image create..., " + status + ",retry after " + retry_after + " seconds")
|
||||
time.sleep(int(retry_after))
|
||||
response = requests.get(operation_location, headers=headers)
|
||||
status = response.json()["status"]
|
||||
image_url = response.json()["result"]["contentUrl"]
|
||||
return True, image_url
|
||||
except Exception as e:
|
||||
logger.error("create image error: {}".format(e))
|
||||
return False, "图片生成失败"
|
||||
text_to_image_model = conf().get("text_to_image")
|
||||
if text_to_image_model == "dall-e-2":
|
||||
api_version = "2023-06-01-preview"
|
||||
endpoint = conf().get("azure_openai_dalle_api_base","open_ai_api_base")
|
||||
# 检查endpoint是否以/结尾
|
||||
if not endpoint.endswith("/"):
|
||||
endpoint = endpoint + "/"
|
||||
url = "{}openai/images/generations:submit?api-version={}".format(endpoint, api_version)
|
||||
api_key = conf().get("azure_openai_dalle_api_key","open_ai_api_key")
|
||||
headers = {"api-key": api_key, "Content-Type": "application/json"}
|
||||
try:
|
||||
body = {"prompt": query, "size": conf().get("image_create_size", "256x256"),"n": 1}
|
||||
submission = requests.post(url, headers=headers, json=body)
|
||||
operation_location = submission.headers['operation-location']
|
||||
status = ""
|
||||
while (status != "succeeded"):
|
||||
if retry_count > 3:
|
||||
return False, "图片生成失败"
|
||||
response = requests.get(operation_location, headers=headers)
|
||||
status = response.json()['status']
|
||||
retry_count += 1
|
||||
image_url = response.json()['result']['data'][0]['url']
|
||||
return True, image_url
|
||||
except Exception as e:
|
||||
logger.error("create image error: {}".format(e))
|
||||
return False, "图片生成失败"
|
||||
elif text_to_image_model == "dall-e-3":
|
||||
api_version = conf().get("azure_api_version", "2024-02-15-preview")
|
||||
endpoint = conf().get("azure_openai_dalle_api_base","open_ai_api_base")
|
||||
# 检查endpoint是否以/结尾
|
||||
if not endpoint.endswith("/"):
|
||||
endpoint = endpoint + "/"
|
||||
url = "{}openai/deployments/{}/images/generations?api-version={}".format(endpoint, conf().get("azure_openai_dalle_deployment_id","text_to_image"),api_version)
|
||||
api_key = conf().get("azure_openai_dalle_api_key","open_ai_api_key")
|
||||
headers = {"api-key": api_key, "Content-Type": "application/json"}
|
||||
try:
|
||||
body = {"prompt": query, "size": conf().get("image_create_size", "1024x1024"), "quality": conf().get("dalle3_image_quality", "standard")}
|
||||
response = requests.post(url, headers=headers, json=body)
|
||||
response.raise_for_status() # 检查请求是否成功
|
||||
data = response.json()
|
||||
|
||||
# 检查响应中是否包含图像 URL
|
||||
if 'data' in data and len(data['data']) > 0 and 'url' in data['data'][0]:
|
||||
image_url = data['data'][0]['url']
|
||||
return True, image_url
|
||||
else:
|
||||
error_message = "响应中没有图像 URL"
|
||||
logger.error(error_message)
|
||||
return False, "图片生成失败"
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
# 捕获所有请求相关的异常
|
||||
try:
|
||||
error_detail = response.json().get('error', {}).get('message', str(e))
|
||||
except ValueError:
|
||||
error_detail = str(e)
|
||||
error_message = f"{error_detail}"
|
||||
logger.error(error_message)
|
||||
return False, error_message
|
||||
|
||||
except Exception as e:
|
||||
# 捕获所有其他异常
|
||||
error_message = f"生成图像时发生错误: {e}"
|
||||
logger.error(error_message)
|
||||
return False, "图片生成失败"
|
||||
else:
|
||||
return False, "图片生成失败,未配置text_to_image参数"
|
||||
|
||||
@@ -57,18 +57,20 @@ class ChatGPTSession(Session):
|
||||
def num_tokens_from_messages(messages, model):
|
||||
"""Returns the number of tokens used by a list of messages."""
|
||||
|
||||
if model in ["wenxin", "xunfei", const.GEMINI]:
|
||||
if model in ["wenxin", "xunfei"] or model.startswith(const.GEMINI):
|
||||
return num_tokens_by_character(messages)
|
||||
|
||||
import tiktoken
|
||||
|
||||
if model in ["gpt-3.5-turbo-0301", "gpt-35-turbo", "gpt-3.5-turbo-1106"]:
|
||||
if model in ["gpt-3.5-turbo-0301", "gpt-35-turbo", "gpt-3.5-turbo-1106", "moonshot", const.LINKAI_35]:
|
||||
return num_tokens_from_messages(messages, model="gpt-3.5-turbo")
|
||||
elif model in ["gpt-4-0314", "gpt-4-0613", "gpt-4-32k", "gpt-4-32k-0613", "gpt-3.5-turbo-0613",
|
||||
"gpt-3.5-turbo-16k", "gpt-3.5-turbo-16k-0613", "gpt-35-turbo-16k", "gpt-4-turbo-preview",
|
||||
"gpt-4-1106-preview", const.GPT4_TURBO_PREVIEW, const.GPT4_VISION_PREVIEW]:
|
||||
"gpt-4-1106-preview", const.GPT4_TURBO_PREVIEW, const.GPT4_VISION_PREVIEW, const.GPT4_TURBO_01_25,
|
||||
const.GPT_4o, const.GPT_4O_0806, const.GPT_4o_MINI, const.LINKAI_4o, const.LINKAI_4_TURBO]:
|
||||
return num_tokens_from_messages(messages, model="gpt-4")
|
||||
|
||||
elif model.startswith("claude-3"):
|
||||
return num_tokens_from_messages(messages, model="gpt-3.5-turbo")
|
||||
try:
|
||||
encoding = tiktoken.encoding_for_model(model)
|
||||
except KeyError:
|
||||
@@ -81,7 +83,7 @@ def num_tokens_from_messages(messages, model):
|
||||
tokens_per_message = 3
|
||||
tokens_per_name = 1
|
||||
else:
|
||||
logger.warn(f"num_tokens_from_messages() is not implemented for model {model}. Returning num tokens assuming gpt-3.5-turbo.")
|
||||
logger.debug(f"num_tokens_from_messages() is not implemented for model {model}. Returning num tokens assuming gpt-3.5-turbo.")
|
||||
return num_tokens_from_messages(messages, model="gpt-3.5-turbo")
|
||||
num_tokens = 0
|
||||
for message in messages:
|
||||
|
||||
132
bot/claudeapi/claude_api_bot.py
Normal file
@@ -0,0 +1,132 @@
|
||||
# encoding:utf-8
|
||||
|
||||
import time
|
||||
|
||||
import openai
|
||||
import openai.error
|
||||
import anthropic
|
||||
|
||||
from bot.bot import Bot
|
||||
from bot.openai.open_ai_image import OpenAIImage
|
||||
from bot.baidu.baidu_wenxin_session import BaiduWenxinSession
|
||||
from bot.session_manager import SessionManager
|
||||
from bridge.context import ContextType
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from common.log import logger
|
||||
from common import const
|
||||
from config import conf
|
||||
|
||||
user_session = dict()
|
||||
|
||||
|
||||
# OpenAI对话模型API (可用)
|
||||
class ClaudeAPIBot(Bot, OpenAIImage):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
proxy = conf().get("proxy", None)
|
||||
base_url = conf().get("open_ai_api_base", None) # 复用"open_ai_api_base"参数作为base_url
|
||||
self.claudeClient = anthropic.Anthropic(
|
||||
api_key=conf().get("claude_api_key"),
|
||||
proxies=proxy if proxy else None,
|
||||
base_url=base_url if base_url else None
|
||||
)
|
||||
self.sessions = SessionManager(BaiduWenxinSession, model=conf().get("model") or "text-davinci-003")
|
||||
|
||||
def reply(self, query, context=None):
|
||||
# acquire reply content
|
||||
if context and context.type:
|
||||
if context.type == ContextType.TEXT:
|
||||
logger.info("[CLAUDE_API] query={}".format(query))
|
||||
session_id = context["session_id"]
|
||||
reply = None
|
||||
if query == "#清除记忆":
|
||||
self.sessions.clear_session(session_id)
|
||||
reply = Reply(ReplyType.INFO, "记忆已清除")
|
||||
elif query == "#清除所有":
|
||||
self.sessions.clear_all_session()
|
||||
reply = Reply(ReplyType.INFO, "所有人记忆已清除")
|
||||
else:
|
||||
session = self.sessions.session_query(query, session_id)
|
||||
result = self.reply_text(session)
|
||||
logger.info(result)
|
||||
total_tokens, completion_tokens, reply_content = (
|
||||
result["total_tokens"],
|
||||
result["completion_tokens"],
|
||||
result["content"],
|
||||
)
|
||||
logger.debug(
|
||||
"[CLAUDE_API] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(str(session), session_id, reply_content, completion_tokens)
|
||||
)
|
||||
|
||||
if total_tokens == 0:
|
||||
reply = Reply(ReplyType.ERROR, reply_content)
|
||||
else:
|
||||
self.sessions.session_reply(reply_content, session_id, total_tokens)
|
||||
reply = Reply(ReplyType.TEXT, reply_content)
|
||||
return reply
|
||||
elif context.type == ContextType.IMAGE_CREATE:
|
||||
ok, retstring = self.create_img(query, 0)
|
||||
reply = None
|
||||
if ok:
|
||||
reply = Reply(ReplyType.IMAGE_URL, retstring)
|
||||
else:
|
||||
reply = Reply(ReplyType.ERROR, retstring)
|
||||
return reply
|
||||
|
||||
def reply_text(self, session: BaiduWenxinSession, retry_count=0):
|
||||
try:
|
||||
actual_model = self._model_mapping(conf().get("model"))
|
||||
response = self.claudeClient.messages.create(
|
||||
model=actual_model,
|
||||
max_tokens=4096,
|
||||
system=conf().get("character_desc", ""),
|
||||
messages=session.messages
|
||||
)
|
||||
# response = openai.Completion.create(prompt=str(session), **self.args)
|
||||
res_content = response.content[0].text.strip().replace("<|endoftext|>", "")
|
||||
total_tokens = response.usage.input_tokens+response.usage.output_tokens
|
||||
completion_tokens = response.usage.output_tokens
|
||||
logger.info("[CLAUDE_API] reply={}".format(res_content))
|
||||
return {
|
||||
"total_tokens": total_tokens,
|
||||
"completion_tokens": completion_tokens,
|
||||
"content": res_content,
|
||||
}
|
||||
except Exception as e:
|
||||
need_retry = retry_count < 2
|
||||
result = {"total_tokens": 0, "completion_tokens": 0, "content": "我现在有点累了,等会再来吧"}
|
||||
if isinstance(e, openai.error.RateLimitError):
|
||||
logger.warn("[CLAUDE_API] RateLimitError: {}".format(e))
|
||||
result["content"] = "提问太快啦,请休息一下再问我吧"
|
||||
if need_retry:
|
||||
time.sleep(20)
|
||||
elif isinstance(e, openai.error.Timeout):
|
||||
logger.warn("[CLAUDE_API] Timeout: {}".format(e))
|
||||
result["content"] = "我没有收到你的消息"
|
||||
if need_retry:
|
||||
time.sleep(5)
|
||||
elif isinstance(e, openai.error.APIConnectionError):
|
||||
logger.warn("[CLAUDE_API] APIConnectionError: {}".format(e))
|
||||
need_retry = False
|
||||
result["content"] = "我连接不到你的网络"
|
||||
else:
|
||||
logger.warn("[CLAUDE_API] Exception: {}".format(e))
|
||||
need_retry = False
|
||||
self.sessions.clear_session(session.session_id)
|
||||
|
||||
if need_retry:
|
||||
logger.warn("[CLAUDE_API] 第{}次重试".format(retry_count + 1))
|
||||
return self.reply_text(session, retry_count + 1)
|
||||
else:
|
||||
return result
|
||||
|
||||
def _model_mapping(self, model) -> str:
|
||||
if model == "claude-3-opus":
|
||||
return const.CLAUDE_3_OPUS
|
||||
elif model == "claude-3-sonnet":
|
||||
return const.CLAUDE_3_SONNET
|
||||
elif model == "claude-3-haiku":
|
||||
return const.CLAUDE_3_HAIKU
|
||||
elif model == "claude-3.5-sonnet":
|
||||
return const.CLAUDE_35_SONNET
|
||||
return model
|
||||
117
bot/dashscope/dashscope_bot.py
Normal file
@@ -0,0 +1,117 @@
|
||||
# encoding:utf-8
|
||||
|
||||
from bot.bot import Bot
|
||||
from bot.session_manager import SessionManager
|
||||
from bridge.context import ContextType
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from common.log import logger
|
||||
from config import conf, load_config
|
||||
from .dashscope_session import DashscopeSession
|
||||
import os
|
||||
import dashscope
|
||||
from http import HTTPStatus
|
||||
|
||||
|
||||
|
||||
dashscope_models = {
|
||||
"qwen-turbo": dashscope.Generation.Models.qwen_turbo,
|
||||
"qwen-plus": dashscope.Generation.Models.qwen_plus,
|
||||
"qwen-max": dashscope.Generation.Models.qwen_max,
|
||||
"qwen-bailian-v1": dashscope.Generation.Models.bailian_v1
|
||||
}
|
||||
# ZhipuAI对话模型API
|
||||
class DashscopeBot(Bot):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.sessions = SessionManager(DashscopeSession, model=conf().get("model") or "qwen-plus")
|
||||
self.model_name = conf().get("model") or "qwen-plus"
|
||||
self.api_key = conf().get("dashscope_api_key")
|
||||
os.environ["DASHSCOPE_API_KEY"] = self.api_key
|
||||
self.client = dashscope.Generation
|
||||
|
||||
def reply(self, query, context=None):
|
||||
# acquire reply content
|
||||
if context.type == ContextType.TEXT:
|
||||
logger.info("[DASHSCOPE] query={}".format(query))
|
||||
|
||||
session_id = context["session_id"]
|
||||
reply = None
|
||||
clear_memory_commands = conf().get("clear_memory_commands", ["#清除记忆"])
|
||||
if query in clear_memory_commands:
|
||||
self.sessions.clear_session(session_id)
|
||||
reply = Reply(ReplyType.INFO, "记忆已清除")
|
||||
elif query == "#清除所有":
|
||||
self.sessions.clear_all_session()
|
||||
reply = Reply(ReplyType.INFO, "所有人记忆已清除")
|
||||
elif query == "#更新配置":
|
||||
load_config()
|
||||
reply = Reply(ReplyType.INFO, "配置已更新")
|
||||
if reply:
|
||||
return reply
|
||||
session = self.sessions.session_query(query, session_id)
|
||||
logger.debug("[DASHSCOPE] session query={}".format(session.messages))
|
||||
|
||||
reply_content = self.reply_text(session)
|
||||
logger.debug(
|
||||
"[DASHSCOPE] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(
|
||||
session.messages,
|
||||
session_id,
|
||||
reply_content["content"],
|
||||
reply_content["completion_tokens"],
|
||||
)
|
||||
)
|
||||
if reply_content["completion_tokens"] == 0 and len(reply_content["content"]) > 0:
|
||||
reply = Reply(ReplyType.ERROR, reply_content["content"])
|
||||
elif reply_content["completion_tokens"] > 0:
|
||||
self.sessions.session_reply(reply_content["content"], session_id, reply_content["total_tokens"])
|
||||
reply = Reply(ReplyType.TEXT, reply_content["content"])
|
||||
else:
|
||||
reply = Reply(ReplyType.ERROR, reply_content["content"])
|
||||
logger.debug("[DASHSCOPE] reply {} used 0 tokens.".format(reply_content))
|
||||
return reply
|
||||
else:
|
||||
reply = Reply(ReplyType.ERROR, "Bot不支持处理{}类型的消息".format(context.type))
|
||||
return reply
|
||||
|
||||
def reply_text(self, session: DashscopeSession, retry_count=0) -> dict:
|
||||
"""
|
||||
call openai's ChatCompletion to get the answer
|
||||
:param session: a conversation session
|
||||
:param session_id: session id
|
||||
:param retry_count: retry count
|
||||
:return: {}
|
||||
"""
|
||||
try:
|
||||
dashscope.api_key = self.api_key
|
||||
response = self.client.call(
|
||||
dashscope_models[self.model_name],
|
||||
messages=session.messages,
|
||||
result_format="message"
|
||||
)
|
||||
if response.status_code == HTTPStatus.OK:
|
||||
content = response.output.choices[0]["message"]["content"]
|
||||
return {
|
||||
"total_tokens": response.usage["total_tokens"],
|
||||
"completion_tokens": response.usage["output_tokens"],
|
||||
"content": content,
|
||||
}
|
||||
else:
|
||||
logger.error('Request id: %s, Status code: %s, error code: %s, error message: %s' % (
|
||||
response.request_id, response.status_code,
|
||||
response.code, response.message
|
||||
))
|
||||
result = {"completion_tokens": 0, "content": "我现在有点累了,等会再来吧"}
|
||||
need_retry = retry_count < 2
|
||||
result = {"completion_tokens": 0, "content": "我现在有点累了,等会再来吧"}
|
||||
if need_retry:
|
||||
return self.reply_text(session, retry_count + 1)
|
||||
else:
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
need_retry = retry_count < 2
|
||||
result = {"completion_tokens": 0, "content": "我现在有点累了,等会再来吧"}
|
||||
if need_retry:
|
||||
return self.reply_text(session, retry_count + 1)
|
||||
else:
|
||||
return result
|
||||
51
bot/dashscope/dashscope_session.py
Normal file
@@ -0,0 +1,51 @@
|
||||
from bot.session_manager import Session
|
||||
from common.log import logger
|
||||
|
||||
|
||||
class DashscopeSession(Session):
|
||||
def __init__(self, session_id, system_prompt=None, model="qwen-turbo"):
|
||||
super().__init__(session_id)
|
||||
self.reset()
|
||||
|
||||
def discard_exceeding(self, max_tokens, cur_tokens=None):
|
||||
precise = True
|
||||
try:
|
||||
cur_tokens = self.calc_tokens()
|
||||
except Exception as e:
|
||||
precise = False
|
||||
if cur_tokens is None:
|
||||
raise e
|
||||
logger.debug("Exception when counting tokens precisely for query: {}".format(e))
|
||||
while cur_tokens > max_tokens:
|
||||
if len(self.messages) > 2:
|
||||
self.messages.pop(1)
|
||||
elif len(self.messages) == 2 and self.messages[1]["role"] == "assistant":
|
||||
self.messages.pop(1)
|
||||
if precise:
|
||||
cur_tokens = self.calc_tokens()
|
||||
else:
|
||||
cur_tokens = cur_tokens - max_tokens
|
||||
break
|
||||
elif len(self.messages) == 2 and self.messages[1]["role"] == "user":
|
||||
logger.warn("user message exceed max_tokens. total_tokens={}".format(cur_tokens))
|
||||
break
|
||||
else:
|
||||
logger.debug("max_tokens={}, total_tokens={}, len(messages)={}".format(max_tokens, cur_tokens,
|
||||
len(self.messages)))
|
||||
break
|
||||
if precise:
|
||||
cur_tokens = self.calc_tokens()
|
||||
else:
|
||||
cur_tokens = cur_tokens - max_tokens
|
||||
return cur_tokens
|
||||
|
||||
def calc_tokens(self):
|
||||
return num_tokens_from_messages(self.messages)
|
||||
|
||||
|
||||
def num_tokens_from_messages(messages):
|
||||
# 只是大概,具体计算规则:https://help.aliyun.com/zh/dashscope/developer-reference/token-api?spm=a2c4g.11186623.0.0.4d8b12b0BkP3K9
|
||||
tokens = 0
|
||||
for msg in messages:
|
||||
tokens += len(msg["content"])
|
||||
return tokens
|
||||
@@ -13,7 +13,9 @@ from bridge.context import ContextType, Context
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from common.log import logger
|
||||
from config import conf
|
||||
from bot.chatgpt.chat_gpt_session import ChatGPTSession
|
||||
from bot.baidu.baidu_wenxin_session import BaiduWenxinSession
|
||||
from google.generativeai.types import HarmCategory, HarmBlockThreshold
|
||||
|
||||
|
||||
# OpenAI对话模型API (可用)
|
||||
@@ -22,9 +24,11 @@ class GoogleGeminiBot(Bot):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.api_key = conf().get("gemini_api_key")
|
||||
# 复用文心的token计算方式
|
||||
self.sessions = SessionManager(BaiduWenxinSession, model=conf().get("model") or "gpt-3.5-turbo")
|
||||
|
||||
# 复用chatGPT的token计算方式
|
||||
self.sessions = SessionManager(ChatGPTSession, model=conf().get("model") or "gpt-3.5-turbo")
|
||||
self.model = conf().get("model") or "gemini-pro"
|
||||
if self.model == "gemini":
|
||||
self.model = "gemini-pro"
|
||||
def reply(self, query, context: Context = None) -> Reply:
|
||||
try:
|
||||
if context.type != ContextType.TEXT:
|
||||
@@ -33,18 +37,45 @@ class GoogleGeminiBot(Bot):
|
||||
logger.info(f"[Gemini] query={query}")
|
||||
session_id = context["session_id"]
|
||||
session = self.sessions.session_query(query, session_id)
|
||||
gemini_messages = self._convert_to_gemini_messages(self._filter_messages(session.messages))
|
||||
gemini_messages = self._convert_to_gemini_messages(self.filter_messages(session.messages))
|
||||
logger.debug(f"[Gemini] messages={gemini_messages}")
|
||||
genai.configure(api_key=self.api_key)
|
||||
model = genai.GenerativeModel('gemini-pro')
|
||||
response = model.generate_content(gemini_messages)
|
||||
reply_text = response.text
|
||||
self.sessions.session_reply(reply_text, session_id)
|
||||
logger.info(f"[Gemini] reply={reply_text}")
|
||||
return Reply(ReplyType.TEXT, reply_text)
|
||||
model = genai.GenerativeModel(self.model)
|
||||
|
||||
# 添加安全设置
|
||||
safety_settings = {
|
||||
HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
|
||||
HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
|
||||
HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
|
||||
HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
|
||||
}
|
||||
|
||||
# 生成回复,包含安全设置
|
||||
response = model.generate_content(
|
||||
gemini_messages,
|
||||
safety_settings=safety_settings
|
||||
)
|
||||
if response.candidates and response.candidates[0].content:
|
||||
reply_text = response.candidates[0].content.parts[0].text
|
||||
logger.info(f"[Gemini] reply={reply_text}")
|
||||
self.sessions.session_reply(reply_text, session_id)
|
||||
return Reply(ReplyType.TEXT, reply_text)
|
||||
else:
|
||||
# 没有有效响应内容,可能内容被屏蔽,输出安全评分
|
||||
logger.warning("[Gemini] No valid response generated. Checking safety ratings.")
|
||||
if hasattr(response, 'candidates') and response.candidates:
|
||||
for rating in response.candidates[0].safety_ratings:
|
||||
logger.warning(f"Safety rating: {rating.category} - {rating.probability}")
|
||||
error_message = "No valid response generated due to safety constraints."
|
||||
self.sessions.session_reply(error_message, session_id)
|
||||
return Reply(ReplyType.ERROR, error_message)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("[Gemini] fetch reply error, may contain unsafe content")
|
||||
logger.error(e)
|
||||
|
||||
logger.error(f"[Gemini] Error generating response: {str(e)}", exc_info=True)
|
||||
error_message = "Failed to invoke [Gemini] api!"
|
||||
self.sessions.session_reply(error_message, session_id)
|
||||
return Reply(ReplyType.ERROR, error_message)
|
||||
|
||||
def _convert_to_gemini_messages(self, messages: list):
|
||||
res = []
|
||||
for msg in messages:
|
||||
@@ -52,6 +83,8 @@ class GoogleGeminiBot(Bot):
|
||||
role = "user"
|
||||
elif msg.get("role") == "assistant":
|
||||
role = "model"
|
||||
elif msg.get("role") == "system":
|
||||
role = "user"
|
||||
else:
|
||||
continue
|
||||
res.append({
|
||||
@@ -60,12 +93,19 @@ class GoogleGeminiBot(Bot):
|
||||
})
|
||||
return res
|
||||
|
||||
def _filter_messages(self, messages: list):
|
||||
@staticmethod
|
||||
def filter_messages(messages: list):
|
||||
res = []
|
||||
turn = "user"
|
||||
if not messages:
|
||||
return res
|
||||
for i in range(len(messages) - 1, -1, -1):
|
||||
message = messages[i]
|
||||
if message.get("role") != turn:
|
||||
role = message.get("role")
|
||||
if role == "system":
|
||||
res.insert(0, message)
|
||||
continue
|
||||
if role != turn:
|
||||
continue
|
||||
res.insert(0, message)
|
||||
if turn == "user":
|
||||
|
||||
@@ -92,7 +92,8 @@ class LinkAIBot(Bot):
|
||||
"frequency_penalty": conf().get("frequency_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
|
||||
"presence_penalty": conf().get("presence_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
|
||||
"session_id": session_id,
|
||||
"channel_type": conf().get("channel_type")
|
||||
"sender_id": session_id,
|
||||
"channel_type": conf().get("channel_type", "wx")
|
||||
}
|
||||
try:
|
||||
from linkai import LinkAIClient
|
||||
@@ -121,7 +122,7 @@ class LinkAIBot(Bot):
|
||||
headers = {"Authorization": "Bearer " + linkai_api_key}
|
||||
|
||||
# do http request
|
||||
base_url = conf().get("linkai_api_base", "https://api.link-ai.chat")
|
||||
base_url = conf().get("linkai_api_base", "https://api.link-ai.tech")
|
||||
res = requests.post(url=base_url + "/v1/chat/completions", json=body, headers=headers,
|
||||
timeout=conf().get("request_timeout", 180))
|
||||
if res.status_code == 200:
|
||||
@@ -129,9 +130,12 @@ class LinkAIBot(Bot):
|
||||
response = res.json()
|
||||
reply_content = response["choices"][0]["message"]["content"]
|
||||
total_tokens = response["usage"]["total_tokens"]
|
||||
logger.info(f"[LINKAI] reply={reply_content}, total_tokens={total_tokens}")
|
||||
self.sessions.session_reply(reply_content, session_id, total_tokens, query=query)
|
||||
|
||||
res_code = response.get('code')
|
||||
logger.info(f"[LINKAI] reply={reply_content}, total_tokens={total_tokens}, res_code={res_code}")
|
||||
if res_code == 429:
|
||||
logger.warn(f"[LINKAI] 用户访问超出限流配置,sender_id={body.get('sender_id')}")
|
||||
else:
|
||||
self.sessions.session_reply(reply_content, session_id, total_tokens, query=query)
|
||||
agent_suffix = self._fetch_agent_suffix(response)
|
||||
if agent_suffix:
|
||||
reply_content += agent_suffix
|
||||
@@ -143,9 +147,9 @@ class LinkAIBot(Bot):
|
||||
if response["choices"][0].get("img_urls"):
|
||||
thread = threading.Thread(target=self._send_image, args=(context.get("channel"), context, response["choices"][0].get("img_urls")))
|
||||
thread.start()
|
||||
if response["choices"][0].get("text_content"):
|
||||
reply_content = response["choices"][0].get("text_content")
|
||||
reply_content = self._process_url(reply_content)
|
||||
reply_content = response["choices"][0].get("text_content")
|
||||
if reply_content:
|
||||
reply_content = self._process_url(reply_content)
|
||||
return Reply(ReplyType.TEXT, reply_content)
|
||||
|
||||
else:
|
||||
@@ -160,7 +164,10 @@ class LinkAIBot(Bot):
|
||||
logger.warn(f"[LINKAI] do retry, times={retry_count}")
|
||||
return self._chat(query, context, retry_count + 1)
|
||||
|
||||
return Reply(ReplyType.TEXT, "提问太快啦,请休息一下再问我吧")
|
||||
error_reply = "提问太快啦,请休息一下再问我吧"
|
||||
if res.status_code == 409:
|
||||
error_reply = "这个问题我还没有学会,请问我其它问题吧"
|
||||
return Reply(ReplyType.TEXT, error_reply)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
@@ -254,7 +261,7 @@ class LinkAIBot(Bot):
|
||||
headers = {"Authorization": "Bearer " + conf().get("linkai_api_key")}
|
||||
|
||||
# do http request
|
||||
base_url = conf().get("linkai_api_base", "https://api.link-ai.chat")
|
||||
base_url = conf().get("linkai_api_base", "https://api.link-ai.tech")
|
||||
res = requests.post(url=base_url + "/v1/chat/completions", json=body, headers=headers,
|
||||
timeout=conf().get("request_timeout", 180))
|
||||
if res.status_code == 200:
|
||||
@@ -297,7 +304,7 @@ class LinkAIBot(Bot):
|
||||
def _fetch_app_info(self, app_code: str):
|
||||
headers = {"Authorization": "Bearer " + conf().get("linkai_api_key")}
|
||||
# do http request
|
||||
base_url = conf().get("linkai_api_base", "https://api.link-ai.chat")
|
||||
base_url = conf().get("linkai_api_base", "https://api.link-ai.tech")
|
||||
params = {"app_code": app_code}
|
||||
res = requests.get(url=base_url + "/v1/app/info", params=params, headers=headers, timeout=(5, 10))
|
||||
if res.status_code == 200:
|
||||
@@ -319,7 +326,7 @@ class LinkAIBot(Bot):
|
||||
"response_format": "url",
|
||||
"img_proxy": conf().get("image_proxy")
|
||||
}
|
||||
url = conf().get("linkai_api_base", "https://api.link-ai.chat") + "/v1/images/generations"
|
||||
url = conf().get("linkai_api_base", "https://api.link-ai.tech") + "/v1/images/generations"
|
||||
res = requests.post(url, headers=headers, json=data, timeout=(5, 90))
|
||||
t2 = time.time()
|
||||
image_url = res.json()["data"][0]["url"]
|
||||
@@ -392,6 +399,7 @@ class LinkAIBot(Bot):
|
||||
return
|
||||
max_send_num = conf().get("max_media_send_count")
|
||||
send_interval = conf().get("media_send_interval")
|
||||
file_type = (".pdf", ".doc", ".docx", ".csv", ".xls", ".xlsx", ".txt", ".rtf", ".ppt", ".pptx")
|
||||
try:
|
||||
i = 0
|
||||
for url in image_urls:
|
||||
@@ -400,7 +408,7 @@ class LinkAIBot(Bot):
|
||||
i += 1
|
||||
if url.endswith(".mp4"):
|
||||
reply_type = ReplyType.VIDEO_URL
|
||||
elif url.endswith(".pdf") or url.endswith(".doc") or url.endswith(".docx") or url.endswith(".csv"):
|
||||
elif url.endswith(file_type):
|
||||
reply_type = ReplyType.FILE
|
||||
url = _download_file(url)
|
||||
if not url:
|
||||
|
||||
151
bot/minimax/minimax_bot.py
Normal file
@@ -0,0 +1,151 @@
|
||||
# encoding:utf-8
|
||||
|
||||
import time
|
||||
|
||||
import openai
|
||||
import openai.error
|
||||
from bot.bot import Bot
|
||||
from bot.minimax.minimax_session import MinimaxSession
|
||||
from bot.session_manager import SessionManager
|
||||
from bridge.context import Context, ContextType
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from common.log import logger
|
||||
from config import conf, load_config
|
||||
from bot.chatgpt.chat_gpt_session import ChatGPTSession
|
||||
import requests
|
||||
from common import const
|
||||
|
||||
|
||||
# ZhipuAI对话模型API
|
||||
class MinimaxBot(Bot):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.args = {
|
||||
"model": conf().get("model") or "abab6.5", # 对话模型的名称
|
||||
"temperature": conf().get("temperature", 0.3), # 如果设置,值域须为 [0, 1] 我们推荐 0.3,以达到较合适的效果。
|
||||
"top_p": conf().get("top_p", 0.95), # 使用默认值
|
||||
}
|
||||
self.api_key = conf().get("Minimax_api_key")
|
||||
self.group_id = conf().get("Minimax_group_id")
|
||||
self.base_url = conf().get("Minimax_base_url", f"https://api.minimax.chat/v1/text/chatcompletion_pro?GroupId={self.group_id}")
|
||||
# tokens_to_generate/bot_setting/reply_constraints可自行修改
|
||||
self.request_body = {
|
||||
"model": self.args["model"],
|
||||
"tokens_to_generate": 2048,
|
||||
"reply_constraints": {"sender_type": "BOT", "sender_name": "MM智能助理"},
|
||||
"messages": [],
|
||||
"bot_setting": [
|
||||
{
|
||||
"bot_name": "MM智能助理",
|
||||
"content": "MM智能助理是一款由MiniMax自研的,没有调用其他产品的接口的大型语言模型。MiniMax是一家中国科技公司,一直致力于进行大模型相关的研究。",
|
||||
}
|
||||
],
|
||||
}
|
||||
self.sessions = SessionManager(MinimaxSession, model=const.MiniMax)
|
||||
|
||||
def reply(self, query, context: Context = None) -> Reply:
|
||||
# acquire reply content
|
||||
logger.info("[Minimax_AI] query={}".format(query))
|
||||
if context.type == ContextType.TEXT:
|
||||
session_id = context["session_id"]
|
||||
reply = None
|
||||
clear_memory_commands = conf().get("clear_memory_commands", ["#清除记忆"])
|
||||
if query in clear_memory_commands:
|
||||
self.sessions.clear_session(session_id)
|
||||
reply = Reply(ReplyType.INFO, "记忆已清除")
|
||||
elif query == "#清除所有":
|
||||
self.sessions.clear_all_session()
|
||||
reply = Reply(ReplyType.INFO, "所有人记忆已清除")
|
||||
elif query == "#更新配置":
|
||||
load_config()
|
||||
reply = Reply(ReplyType.INFO, "配置已更新")
|
||||
if reply:
|
||||
return reply
|
||||
session = self.sessions.session_query(query, session_id)
|
||||
logger.debug("[Minimax_AI] session query={}".format(session))
|
||||
|
||||
model = context.get("Minimax_model")
|
||||
new_args = self.args.copy()
|
||||
if model:
|
||||
new_args["model"] = model
|
||||
# if context.get('stream'):
|
||||
# # reply in stream
|
||||
# return self.reply_text_stream(query, new_query, session_id)
|
||||
|
||||
reply_content = self.reply_text(session, args=new_args)
|
||||
logger.debug(
|
||||
"[Minimax_AI] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(
|
||||
session.messages,
|
||||
session_id,
|
||||
reply_content["content"],
|
||||
reply_content["completion_tokens"],
|
||||
)
|
||||
)
|
||||
if reply_content["completion_tokens"] == 0 and len(reply_content["content"]) > 0:
|
||||
reply = Reply(ReplyType.ERROR, reply_content["content"])
|
||||
elif reply_content["completion_tokens"] > 0:
|
||||
self.sessions.session_reply(reply_content["content"], session_id, reply_content["total_tokens"])
|
||||
reply = Reply(ReplyType.TEXT, reply_content["content"])
|
||||
else:
|
||||
reply = Reply(ReplyType.ERROR, reply_content["content"])
|
||||
logger.debug("[Minimax_AI] reply {} used 0 tokens.".format(reply_content))
|
||||
return reply
|
||||
else:
|
||||
reply = Reply(ReplyType.ERROR, "Bot不支持处理{}类型的消息".format(context.type))
|
||||
return reply
|
||||
|
||||
def reply_text(self, session: MinimaxSession, args=None, retry_count=0) -> dict:
|
||||
"""
|
||||
call openai's ChatCompletion to get the answer
|
||||
:param session: a conversation session
|
||||
:param session_id: session id
|
||||
:param retry_count: retry count
|
||||
:return: {}
|
||||
"""
|
||||
try:
|
||||
headers = {"Content-Type": "application/json", "Authorization": "Bearer " + self.api_key}
|
||||
self.request_body["messages"].extend(session.messages)
|
||||
logger.info("[Minimax_AI] request_body={}".format(self.request_body))
|
||||
# logger.info("[Minimax_AI] reply={}, total_tokens={}".format(response.choices[0]['message']['content'], response["usage"]["total_tokens"]))
|
||||
res = requests.post(self.base_url, headers=headers, json=self.request_body)
|
||||
|
||||
# self.request_body["messages"].extend(response.json()["choices"][0]["messages"])
|
||||
if res.status_code == 200:
|
||||
response = res.json()
|
||||
return {
|
||||
"total_tokens": response["usage"]["total_tokens"],
|
||||
"completion_tokens": response["usage"]["total_tokens"],
|
||||
"content": response["reply"],
|
||||
}
|
||||
else:
|
||||
response = res.json()
|
||||
error = response.get("error")
|
||||
logger.error(f"[Minimax_AI] chat failed, status_code={res.status_code}, " f"msg={error.get('message')}, type={error.get('type')}")
|
||||
|
||||
result = {"completion_tokens": 0, "content": "提问太快啦,请休息一下再问我吧"}
|
||||
need_retry = False
|
||||
if res.status_code >= 500:
|
||||
# server error, need retry
|
||||
logger.warn(f"[Minimax_AI] do retry, times={retry_count}")
|
||||
need_retry = retry_count < 2
|
||||
elif res.status_code == 401:
|
||||
result["content"] = "授权失败,请检查API Key是否正确"
|
||||
elif res.status_code == 429:
|
||||
result["content"] = "请求过于频繁,请稍后再试"
|
||||
need_retry = retry_count < 2
|
||||
else:
|
||||
need_retry = False
|
||||
|
||||
if need_retry:
|
||||
time.sleep(3)
|
||||
return self.reply_text(session, args, retry_count + 1)
|
||||
else:
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
need_retry = retry_count < 2
|
||||
result = {"completion_tokens": 0, "content": "我现在有点累了,等会再来吧"}
|
||||
if need_retry:
|
||||
return self.reply_text(session, args, retry_count + 1)
|
||||
else:
|
||||
return result
|
||||
72
bot/minimax/minimax_session.py
Normal file
@@ -0,0 +1,72 @@
|
||||
from bot.session_manager import Session
|
||||
from common.log import logger
|
||||
|
||||
"""
|
||||
e.g.
|
||||
[
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "Who won the world series in 2020?"},
|
||||
{"role": "assistant", "content": "The Los Angeles Dodgers won the World Series in 2020."},
|
||||
{"role": "user", "content": "Where was it played?"}
|
||||
]
|
||||
"""
|
||||
|
||||
|
||||
class MinimaxSession(Session):
|
||||
def __init__(self, session_id, system_prompt=None, model="minimax"):
|
||||
super().__init__(session_id, system_prompt)
|
||||
self.model = model
|
||||
# self.reset()
|
||||
|
||||
def add_query(self, query):
|
||||
user_item = {"sender_type": "USER", "sender_name": self.session_id, "text": query}
|
||||
self.messages.append(user_item)
|
||||
|
||||
def add_reply(self, reply):
|
||||
assistant_item = {"sender_type": "BOT", "sender_name": "MM智能助理", "text": reply}
|
||||
self.messages.append(assistant_item)
|
||||
|
||||
def discard_exceeding(self, max_tokens, cur_tokens=None):
|
||||
precise = True
|
||||
try:
|
||||
cur_tokens = self.calc_tokens()
|
||||
except Exception as e:
|
||||
precise = False
|
||||
if cur_tokens is None:
|
||||
raise e
|
||||
logger.debug("Exception when counting tokens precisely for query: {}".format(e))
|
||||
while cur_tokens > max_tokens:
|
||||
if len(self.messages) > 2:
|
||||
self.messages.pop(1)
|
||||
elif len(self.messages) == 2 and self.messages[1]["sender_type"] == "BOT":
|
||||
self.messages.pop(1)
|
||||
if precise:
|
||||
cur_tokens = self.calc_tokens()
|
||||
else:
|
||||
cur_tokens = cur_tokens - max_tokens
|
||||
break
|
||||
elif len(self.messages) == 2 and self.messages[1]["sender_type"] == "USER":
|
||||
logger.warn("user message exceed max_tokens. total_tokens={}".format(cur_tokens))
|
||||
break
|
||||
else:
|
||||
logger.debug("max_tokens={}, total_tokens={}, len(messages)={}".format(max_tokens, cur_tokens, len(self.messages)))
|
||||
break
|
||||
if precise:
|
||||
cur_tokens = self.calc_tokens()
|
||||
else:
|
||||
cur_tokens = cur_tokens - max_tokens
|
||||
return cur_tokens
|
||||
|
||||
def calc_tokens(self):
|
||||
return num_tokens_from_messages(self.messages, self.model)
|
||||
|
||||
|
||||
def num_tokens_from_messages(messages, model):
|
||||
"""Returns the number of tokens used by a list of messages."""
|
||||
# 官方token计算规则:"对于中文文本来说,1个token通常对应一个汉字;对于英文文本来说,1个token通常对应3至4个字母或1个单词"
|
||||
# 详情请产看文档:https://help.aliyun.com/document_detail/2586397.html
|
||||
# 目前根据字符串长度粗略估计token数,不影响正常使用
|
||||
tokens = 0
|
||||
for msg in messages:
|
||||
tokens += len(msg["text"])
|
||||
return tokens
|
||||
277
bot/modelscope/modelscope_bot.py
Normal file
@@ -0,0 +1,277 @@
|
||||
# encoding:utf-8
|
||||
|
||||
import time
|
||||
import json
|
||||
import openai
|
||||
import openai.error
|
||||
from bot.bot import Bot
|
||||
from bot.session_manager import SessionManager
|
||||
from bridge.context import ContextType
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from common.log import logger
|
||||
from config import conf, load_config
|
||||
from .modelscope_session import ModelScopeSession
|
||||
import requests
|
||||
|
||||
|
||||
# ModelScope对话模型API
|
||||
class ModelScopeBot(Bot):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.sessions = SessionManager(ModelScopeSession, model=conf().get("model") or "Qwen/Qwen2.5-7B-Instruct")
|
||||
model = conf().get("model") or "Qwen/Qwen2.5-7B-Instruct"
|
||||
if model == "modelscope":
|
||||
model = "Qwen/Qwen2.5-7B-Instruct"
|
||||
self.args = {
|
||||
"model": model, # 对话模型的名称
|
||||
"temperature": conf().get("temperature", 0.3), # 如果设置,值域须为 [0, 1] 我们推荐 0.3,以达到较合适的效果。
|
||||
"top_p": conf().get("top_p", 1.0), # 使用默认值
|
||||
}
|
||||
self.api_key = conf().get("modelscope_api_key")
|
||||
self.base_url = conf().get("modelscope_base_url", "https://api-inference.modelscope.cn/v1/chat/completions")
|
||||
"""
|
||||
需要获取ModelScope支持API-inference的模型名称列表,请到魔搭社区官网模型中心查看 https://modelscope.cn/models?filter=inference_type&page=1。
|
||||
或者使用命令 curl https://api-inference.modelscope.cn/v1/models 对模型列表和ID进行获取。查看commend/const.py文件也可以获取模型列表。
|
||||
获取ModelScope的免费API Key,请到魔搭社区官网用户中心查看获取方式 https://modelscope.cn/docs/model-service/API-Inference/intro。
|
||||
"""
|
||||
def reply(self, query, context=None):
|
||||
# acquire reply content
|
||||
if context.type == ContextType.TEXT:
|
||||
logger.info("[MODELSCOPE_AI] query={}".format(query))
|
||||
|
||||
session_id = context["session_id"]
|
||||
reply = None
|
||||
clear_memory_commands = conf().get("clear_memory_commands", ["#清除记忆"])
|
||||
if query in clear_memory_commands:
|
||||
self.sessions.clear_session(session_id)
|
||||
reply = Reply(ReplyType.INFO, "记忆已清除")
|
||||
elif query == "#清除所有":
|
||||
self.sessions.clear_all_session()
|
||||
reply = Reply(ReplyType.INFO, "所有人记忆已清除")
|
||||
elif query == "#更新配置":
|
||||
load_config()
|
||||
reply = Reply(ReplyType.INFO, "配置已更新")
|
||||
if reply:
|
||||
return reply
|
||||
session = self.sessions.session_query(query, session_id)
|
||||
logger.debug("[MODELSCOPE_AI] session query={}".format(session.messages))
|
||||
|
||||
model = context.get("modelscope_model")
|
||||
new_args = self.args.copy()
|
||||
if model:
|
||||
new_args["model"] = model
|
||||
|
||||
if new_args["model"] == "Qwen/QwQ-32B":
|
||||
reply_content = self.reply_text_stream(session, args=new_args)
|
||||
else:
|
||||
reply_content = self.reply_text(session, args=new_args)
|
||||
|
||||
logger.debug(
|
||||
"[MODELSCOPE_AI] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(
|
||||
session.messages,
|
||||
session_id,
|
||||
reply_content["content"],
|
||||
reply_content["completion_tokens"],
|
||||
)
|
||||
)
|
||||
if reply_content["completion_tokens"] == 0 and len(reply_content["content"]) > 0:
|
||||
# 只有当 content 为空且 completion_tokens 为 0 时才标记为错误
|
||||
if len(reply_content["content"]) == 0:
|
||||
reply = Reply(ReplyType.ERROR, reply_content["content"])
|
||||
else:
|
||||
reply = Reply(ReplyType.TEXT, reply_content["content"])
|
||||
elif reply_content["completion_tokens"] > 0:
|
||||
self.sessions.session_reply(reply_content["content"], session_id, reply_content["total_tokens"])
|
||||
reply = Reply(ReplyType.TEXT, reply_content["content"])
|
||||
else:
|
||||
reply = Reply(ReplyType.ERROR, reply_content["content"])
|
||||
logger.debug("[MODELSCOPE_AI] reply {} used 0 tokens.".format(reply_content))
|
||||
return reply
|
||||
elif context.type == ContextType.IMAGE_CREATE:
|
||||
ok, retstring = self.create_img(query, 0)
|
||||
reply = None
|
||||
if ok:
|
||||
reply = Reply(ReplyType.IMAGE_URL, retstring)
|
||||
else:
|
||||
reply = Reply(ReplyType.ERROR, retstring)
|
||||
return reply
|
||||
else:
|
||||
reply = Reply(ReplyType.ERROR, "Bot不支持处理{}类型的消息".format(context.type))
|
||||
return reply
|
||||
|
||||
def reply_text(self, session: ModelScopeSession, args=None, retry_count=0) -> dict:
|
||||
"""
|
||||
call openai's ChatCompletion to get the answer
|
||||
:param session: a conversation session
|
||||
:param session_id: session id
|
||||
:param retry_count: retry count
|
||||
:return: {}
|
||||
"""
|
||||
try:
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": "Bearer " + self.api_key
|
||||
}
|
||||
|
||||
body = args
|
||||
body["messages"] = session.messages
|
||||
res = requests.post(
|
||||
self.base_url,
|
||||
headers=headers,
|
||||
data=json.dumps(body)
|
||||
)
|
||||
|
||||
if res.status_code == 200:
|
||||
response = res.json()
|
||||
return {
|
||||
"total_tokens": response["usage"]["total_tokens"],
|
||||
"completion_tokens": response["usage"]["completion_tokens"],
|
||||
"content": response["choices"][0]["message"]["content"]
|
||||
}
|
||||
else:
|
||||
response = res.json()
|
||||
if "errors" in response:
|
||||
error = response.get("errors")
|
||||
elif "error" in response:
|
||||
error = response.get("error")
|
||||
else:
|
||||
error = "Unknown error"
|
||||
logger.error(f"[MODELSCOPE_AI] chat failed, status_code={res.status_code}, "
|
||||
f"msg={error.get('message')}, type={error.get('type')}")
|
||||
|
||||
result = {"completion_tokens": 0, "content": "提问太快啦,请休息一下再问我吧"}
|
||||
need_retry = False
|
||||
if res.status_code >= 500:
|
||||
# server error, need retry
|
||||
logger.warn(f"[MODELSCOPE_AI] do retry, times={retry_count}")
|
||||
need_retry = retry_count < 2
|
||||
elif res.status_code == 401:
|
||||
result["content"] = "授权失败,请检查API Key是否正确"
|
||||
elif res.status_code == 429:
|
||||
result["content"] = "请求过于频繁,请稍后再试"
|
||||
need_retry = retry_count < 2
|
||||
else:
|
||||
need_retry = False
|
||||
|
||||
if need_retry:
|
||||
time.sleep(3)
|
||||
return self.reply_text(session, args, retry_count + 1)
|
||||
else:
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
need_retry = retry_count < 2
|
||||
result = {"completion_tokens": 0, "content": "我现在有点累了,等会再来吧"}
|
||||
if need_retry:
|
||||
return self.reply_text(session, args, retry_count + 1)
|
||||
else:
|
||||
return result
|
||||
|
||||
def reply_text_stream(self, session: ModelScopeSession, args=None, retry_count=0) -> dict:
|
||||
"""
|
||||
call ModelScope's ChatCompletion to get the answer with stream response
|
||||
:param session: a conversation session
|
||||
:param session_id: session id
|
||||
:param retry_count: retry count
|
||||
:return: {}
|
||||
"""
|
||||
try:
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": "Bearer " + self.api_key
|
||||
}
|
||||
|
||||
body = args
|
||||
body["messages"] = session.messages
|
||||
body["stream"] = True # 启用流式响应
|
||||
|
||||
res = requests.post(
|
||||
self.base_url,
|
||||
headers=headers,
|
||||
data=json.dumps(body),
|
||||
stream=True
|
||||
)
|
||||
if res.status_code == 200:
|
||||
content = ""
|
||||
for line in res.iter_lines():
|
||||
if line:
|
||||
decoded_line = line.decode('utf-8')
|
||||
if decoded_line.startswith("data: "):
|
||||
try:
|
||||
json_data = json.loads(decoded_line[6:])
|
||||
delta_content = json_data.get("choices", [{}])[0].get("delta", {}).get("content", "")
|
||||
if delta_content:
|
||||
content += delta_content
|
||||
except json.JSONDecodeError as e:
|
||||
pass
|
||||
return {
|
||||
"total_tokens": 1, # 流式响应通常不返回token使用情况
|
||||
"completion_tokens": 1,
|
||||
"content": content
|
||||
}
|
||||
else:
|
||||
response = res.json()
|
||||
if "errors" in response:
|
||||
error = response.get("errors")
|
||||
elif "error" in response:
|
||||
error = response.get("error")
|
||||
else:
|
||||
error = "Unknown error"
|
||||
logger.error(f"[MODELSCOPE_AI] chat failed, status_code={res.status_code}, "
|
||||
f"msg={error.get('message')}, type={error.get('type')}")
|
||||
|
||||
result = {"completion_tokens": 0, "content": "提问太快啦,请休息一下再问我吧"}
|
||||
need_retry = False
|
||||
if res.status_code >= 500:
|
||||
# server error, need retry
|
||||
logger.warn(f"[MODELSCOPE_AI] do retry, times={retry_count}")
|
||||
need_retry = retry_count < 2
|
||||
elif res.status_code == 401:
|
||||
result["content"] = "授权失败,请检查API Key是否正确"
|
||||
elif res.status_code == 429:
|
||||
result["content"] = "请求过于频繁,请稍后再试"
|
||||
need_retry = retry_count < 2
|
||||
else:
|
||||
need_retry = False
|
||||
|
||||
if need_retry:
|
||||
time.sleep(3)
|
||||
return self.reply_text_stream(session, args, retry_count + 1)
|
||||
else:
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
need_retry = retry_count < 2
|
||||
result = {"completion_tokens": 0, "content": "我现在有点累了,等会再来吧"}
|
||||
if need_retry:
|
||||
return self.reply_text_stream(session, args, retry_count + 1)
|
||||
else:
|
||||
return result
|
||||
def create_img(self, query, retry_count=0):
|
||||
try:
|
||||
logger.info("[ModelScopeImage] image_query={}".format(query))
|
||||
headers = {
|
||||
"Content-Type": "application/json; charset=utf-8", # 明确指定编码
|
||||
"Authorization": f"Bearer {self.api_key}"
|
||||
}
|
||||
payload = {
|
||||
"prompt": query, # required
|
||||
"n": 1,
|
||||
"model": conf().get("text_to_image"),
|
||||
}
|
||||
url = "https://api-inference.modelscope.cn/v1/images/generations"
|
||||
|
||||
# 手动序列化并保留中文(禁用 ASCII 转义)
|
||||
json_payload = json.dumps(payload, ensure_ascii=False).encode('utf-8')
|
||||
|
||||
# 使用 data 参数发送原始字符串(requests 会自动处理编码)
|
||||
res = requests.post(url, headers=headers, data=json_payload)
|
||||
|
||||
response_data = res.json()
|
||||
image_url = response_data['images'][0]['url']
|
||||
logger.info("[ModelScopeImage] image_url={}".format(image_url))
|
||||
return True, image_url
|
||||
|
||||
except Exception as e:
|
||||
logger.error(format(e))
|
||||
return False, "画图出现问题,请休息一下再问我吧"
|
||||
51
bot/modelscope/modelscope_session.py
Normal file
@@ -0,0 +1,51 @@
|
||||
from bot.session_manager import Session
|
||||
from common.log import logger
|
||||
|
||||
|
||||
class ModelScopeSession(Session):
|
||||
def __init__(self, session_id, system_prompt=None, model="Qwen/Qwen2.5-7B-Instruct"):
|
||||
super().__init__(session_id, system_prompt)
|
||||
self.model = model
|
||||
self.reset()
|
||||
|
||||
def discard_exceeding(self, max_tokens, cur_tokens=None):
|
||||
precise = True
|
||||
try:
|
||||
cur_tokens = self.calc_tokens()
|
||||
except Exception as e:
|
||||
precise = False
|
||||
if cur_tokens is None:
|
||||
raise e
|
||||
logger.debug("Exception when counting tokens precisely for query: {}".format(e))
|
||||
while cur_tokens > max_tokens:
|
||||
if len(self.messages) > 2:
|
||||
self.messages.pop(1)
|
||||
elif len(self.messages) == 2 and self.messages[1]["role"] == "assistant":
|
||||
self.messages.pop(1)
|
||||
if precise:
|
||||
cur_tokens = self.calc_tokens()
|
||||
else:
|
||||
cur_tokens = cur_tokens - max_tokens
|
||||
break
|
||||
elif len(self.messages) == 2 and self.messages[1]["role"] == "user":
|
||||
logger.warn("user message exceed max_tokens. total_tokens={}".format(cur_tokens))
|
||||
break
|
||||
else:
|
||||
logger.debug("max_tokens={}, total_tokens={}, len(messages)={}".format(max_tokens, cur_tokens,
|
||||
len(self.messages)))
|
||||
break
|
||||
if precise:
|
||||
cur_tokens = self.calc_tokens()
|
||||
else:
|
||||
cur_tokens = cur_tokens - max_tokens
|
||||
return cur_tokens
|
||||
|
||||
def calc_tokens(self):
|
||||
return num_tokens_from_messages(self.messages, self.model)
|
||||
|
||||
|
||||
def num_tokens_from_messages(messages, model):
|
||||
tokens = 0
|
||||
for msg in messages:
|
||||
tokens += len(msg["content"])
|
||||
return tokens
|
||||
146
bot/moonshot/moonshot_bot.py
Normal file
@@ -0,0 +1,146 @@
|
||||
# encoding:utf-8
|
||||
|
||||
import time
|
||||
|
||||
import openai
|
||||
import openai.error
|
||||
from bot.bot import Bot
|
||||
from bot.session_manager import SessionManager
|
||||
from bridge.context import ContextType
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from common.log import logger
|
||||
from config import conf, load_config
|
||||
from .moonshot_session import MoonshotSession
|
||||
import requests
|
||||
|
||||
|
||||
# ZhipuAI对话模型API
|
||||
class MoonshotBot(Bot):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.sessions = SessionManager(MoonshotSession, model=conf().get("model") or "moonshot-v1-128k")
|
||||
model = conf().get("model") or "moonshot-v1-128k"
|
||||
if model == "moonshot":
|
||||
model = "moonshot-v1-32k"
|
||||
self.args = {
|
||||
"model": model, # 对话模型的名称
|
||||
"temperature": conf().get("temperature", 0.3), # 如果设置,值域须为 [0, 1] 我们推荐 0.3,以达到较合适的效果。
|
||||
"top_p": conf().get("top_p", 1.0), # 使用默认值
|
||||
}
|
||||
self.api_key = conf().get("moonshot_api_key")
|
||||
self.base_url = conf().get("moonshot_base_url", "https://api.moonshot.cn/v1/chat/completions")
|
||||
|
||||
def reply(self, query, context=None):
|
||||
# acquire reply content
|
||||
if context.type == ContextType.TEXT:
|
||||
logger.info("[MOONSHOT_AI] query={}".format(query))
|
||||
|
||||
session_id = context["session_id"]
|
||||
reply = None
|
||||
clear_memory_commands = conf().get("clear_memory_commands", ["#清除记忆"])
|
||||
if query in clear_memory_commands:
|
||||
self.sessions.clear_session(session_id)
|
||||
reply = Reply(ReplyType.INFO, "记忆已清除")
|
||||
elif query == "#清除所有":
|
||||
self.sessions.clear_all_session()
|
||||
reply = Reply(ReplyType.INFO, "所有人记忆已清除")
|
||||
elif query == "#更新配置":
|
||||
load_config()
|
||||
reply = Reply(ReplyType.INFO, "配置已更新")
|
||||
if reply:
|
||||
return reply
|
||||
session = self.sessions.session_query(query, session_id)
|
||||
logger.debug("[MOONSHOT_AI] session query={}".format(session.messages))
|
||||
|
||||
model = context.get("moonshot_model")
|
||||
new_args = self.args.copy()
|
||||
if model:
|
||||
new_args["model"] = model
|
||||
# if context.get('stream'):
|
||||
# # reply in stream
|
||||
# return self.reply_text_stream(query, new_query, session_id)
|
||||
|
||||
reply_content = self.reply_text(session, args=new_args)
|
||||
logger.debug(
|
||||
"[MOONSHOT_AI] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(
|
||||
session.messages,
|
||||
session_id,
|
||||
reply_content["content"],
|
||||
reply_content["completion_tokens"],
|
||||
)
|
||||
)
|
||||
if reply_content["completion_tokens"] == 0 and len(reply_content["content"]) > 0:
|
||||
reply = Reply(ReplyType.ERROR, reply_content["content"])
|
||||
elif reply_content["completion_tokens"] > 0:
|
||||
self.sessions.session_reply(reply_content["content"], session_id, reply_content["total_tokens"])
|
||||
reply = Reply(ReplyType.TEXT, reply_content["content"])
|
||||
else:
|
||||
reply = Reply(ReplyType.ERROR, reply_content["content"])
|
||||
logger.debug("[MOONSHOT_AI] reply {} used 0 tokens.".format(reply_content))
|
||||
return reply
|
||||
else:
|
||||
reply = Reply(ReplyType.ERROR, "Bot不支持处理{}类型的消息".format(context.type))
|
||||
return reply
|
||||
|
||||
def reply_text(self, session: MoonshotSession, args=None, retry_count=0) -> dict:
|
||||
"""
|
||||
call openai's ChatCompletion to get the answer
|
||||
:param session: a conversation session
|
||||
:param session_id: session id
|
||||
:param retry_count: retry count
|
||||
:return: {}
|
||||
"""
|
||||
try:
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": "Bearer " + self.api_key
|
||||
}
|
||||
body = args
|
||||
body["messages"] = session.messages
|
||||
# logger.debug("[MOONSHOT_AI] response={}".format(response))
|
||||
# logger.info("[MOONSHOT_AI] reply={}, total_tokens={}".format(response.choices[0]['message']['content'], response["usage"]["total_tokens"]))
|
||||
res = requests.post(
|
||||
self.base_url,
|
||||
headers=headers,
|
||||
json=body
|
||||
)
|
||||
if res.status_code == 200:
|
||||
response = res.json()
|
||||
return {
|
||||
"total_tokens": response["usage"]["total_tokens"],
|
||||
"completion_tokens": response["usage"]["completion_tokens"],
|
||||
"content": response["choices"][0]["message"]["content"]
|
||||
}
|
||||
else:
|
||||
response = res.json()
|
||||
error = response.get("error")
|
||||
logger.error(f"[MOONSHOT_AI] chat failed, status_code={res.status_code}, "
|
||||
f"msg={error.get('message')}, type={error.get('type')}")
|
||||
|
||||
result = {"completion_tokens": 0, "content": "提问太快啦,请休息一下再问我吧"}
|
||||
need_retry = False
|
||||
if res.status_code >= 500:
|
||||
# server error, need retry
|
||||
logger.warn(f"[MOONSHOT_AI] do retry, times={retry_count}")
|
||||
need_retry = retry_count < 2
|
||||
elif res.status_code == 401:
|
||||
result["content"] = "授权失败,请检查API Key是否正确"
|
||||
elif res.status_code == 429:
|
||||
result["content"] = "请求过于频繁,请稍后再试"
|
||||
need_retry = retry_count < 2
|
||||
else:
|
||||
need_retry = False
|
||||
|
||||
if need_retry:
|
||||
time.sleep(3)
|
||||
return self.reply_text(session, args, retry_count + 1)
|
||||
else:
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
need_retry = retry_count < 2
|
||||
result = {"completion_tokens": 0, "content": "我现在有点累了,等会再来吧"}
|
||||
if need_retry:
|
||||
return self.reply_text(session, args, retry_count + 1)
|
||||
else:
|
||||
return result
|
||||
51
bot/moonshot/moonshot_session.py
Normal file
@@ -0,0 +1,51 @@
|
||||
from bot.session_manager import Session
|
||||
from common.log import logger
|
||||
|
||||
|
||||
class MoonshotSession(Session):
|
||||
def __init__(self, session_id, system_prompt=None, model="moonshot-v1-128k"):
|
||||
super().__init__(session_id, system_prompt)
|
||||
self.model = model
|
||||
self.reset()
|
||||
|
||||
def discard_exceeding(self, max_tokens, cur_tokens=None):
|
||||
precise = True
|
||||
try:
|
||||
cur_tokens = self.calc_tokens()
|
||||
except Exception as e:
|
||||
precise = False
|
||||
if cur_tokens is None:
|
||||
raise e
|
||||
logger.debug("Exception when counting tokens precisely for query: {}".format(e))
|
||||
while cur_tokens > max_tokens:
|
||||
if len(self.messages) > 2:
|
||||
self.messages.pop(1)
|
||||
elif len(self.messages) == 2 and self.messages[1]["role"] == "assistant":
|
||||
self.messages.pop(1)
|
||||
if precise:
|
||||
cur_tokens = self.calc_tokens()
|
||||
else:
|
||||
cur_tokens = cur_tokens - max_tokens
|
||||
break
|
||||
elif len(self.messages) == 2 and self.messages[1]["role"] == "user":
|
||||
logger.warn("user message exceed max_tokens. total_tokens={}".format(cur_tokens))
|
||||
break
|
||||
else:
|
||||
logger.debug("max_tokens={}, total_tokens={}, len(messages)={}".format(max_tokens, cur_tokens,
|
||||
len(self.messages)))
|
||||
break
|
||||
if precise:
|
||||
cur_tokens = self.calc_tokens()
|
||||
else:
|
||||
cur_tokens = cur_tokens - max_tokens
|
||||
return cur_tokens
|
||||
|
||||
def calc_tokens(self):
|
||||
return num_tokens_from_messages(self.messages, self.model)
|
||||
|
||||
|
||||
def num_tokens_from_messages(messages, model):
|
||||
tokens = 0
|
||||
for msg in messages:
|
||||
tokens += len(msg["content"])
|
||||
return tokens
|
||||
@@ -3,7 +3,7 @@
|
||||
import requests, json
|
||||
from bot.bot import Bot
|
||||
from bot.session_manager import SessionManager
|
||||
from bot.baidu.baidu_wenxin_session import BaiduWenxinSession
|
||||
from bot.chatgpt.chat_gpt_session import ChatGPTSession
|
||||
from bridge.context import ContextType, Context
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from common.log import logger
|
||||
@@ -41,17 +41,19 @@ class XunFeiBot(Bot):
|
||||
self.api_key = conf().get("xunfei_api_key")
|
||||
self.api_secret = conf().get("xunfei_api_secret")
|
||||
# 默认使用v2.0版本: "generalv2"
|
||||
# v1.5版本为 "general"
|
||||
# v3.0版本为: "generalv3"
|
||||
self.domain = "generalv3"
|
||||
# 默认使用v2.0版本: "ws://spark-api.xf-yun.com/v2.1/chat"
|
||||
# v1.5版本为: "ws://spark-api.xf-yun.com/v1.1/chat"
|
||||
# v3.0版本为: "ws://spark-api.xf-yun.com/v3.1/chat"
|
||||
self.spark_url = "ws://spark-api.xf-yun.com/v3.1/chat"
|
||||
# Spark Lite请求地址(spark_url): wss://spark-api.xf-yun.com/v1.1/chat, 对应的domain参数为: "lite"
|
||||
# Spark V2.0请求地址(spark_url): wss://spark-api.xf-yun.com/v2.1/chat, 对应的domain参数为: "generalv2"
|
||||
# Spark Pro 请求地址(spark_url): wss://spark-api.xf-yun.com/v3.1/chat, 对应的domain参数为: "generalv3"
|
||||
# Spark Pro-128K请求地址(spark_url): wss://spark-api.xf-yun.com/chat/pro-128k, 对应的domain参数为: "pro-128k"
|
||||
# Spark Max 请求地址(spark_url): wss://spark-api.xf-yun.com/v3.5/chat, 对应的domain参数为: "generalv3.5"
|
||||
# Spark4.0 Ultra 请求地址(spark_url): wss://spark-api.xf-yun.com/v4.0/chat, 对应的domain参数为: "4.0Ultra"
|
||||
# 后续模型更新,对应的参数可以参考官网文档获取:https://www.xfyun.cn/doc/spark/Web.html
|
||||
self.domain = conf().get("xunfei_domain", "generalv3.5")
|
||||
self.spark_url = conf().get("xunfei_spark_url", "wss://spark-api.xf-yun.com/v3.5/chat")
|
||||
self.host = urlparse(self.spark_url).netloc
|
||||
self.path = urlparse(self.spark_url).path
|
||||
# 和wenxin使用相同的session机制
|
||||
self.sessions = SessionManager(BaiduWenxinSession, model=const.XUNFEI)
|
||||
self.sessions = SessionManager(ChatGPTSession, model=const.XUNFEI)
|
||||
|
||||
def reply(self, query, context: Context = None) -> Reply:
|
||||
if context.type == ContextType.TEXT:
|
||||
|
||||
@@ -7,6 +7,8 @@ class ZhipuAISession(Session):
|
||||
super().__init__(session_id, system_prompt)
|
||||
self.model = model
|
||||
self.reset()
|
||||
if not system_prompt:
|
||||
logger.warn("[ZhiPu] `character_desc` can not be empty")
|
||||
|
||||
def discard_exceeding(self, max_tokens, cur_tokens=None):
|
||||
precise = True
|
||||
|
||||
@@ -18,34 +18,54 @@ class Bridge(object):
|
||||
"text_to_voice": conf().get("text_to_voice", "google"),
|
||||
"translate": conf().get("translate", "baidu"),
|
||||
}
|
||||
model_type = conf().get("model") or const.GPT35
|
||||
if model_type in ["text-davinci-003"]:
|
||||
self.btype["chat"] = const.OPEN_AI
|
||||
if conf().get("use_azure_chatgpt", False):
|
||||
self.btype["chat"] = const.CHATGPTONAZURE
|
||||
if model_type in ["wenxin", "wenxin-4"]:
|
||||
self.btype["chat"] = const.BAIDU
|
||||
if model_type in ["xunfei"]:
|
||||
self.btype["chat"] = const.XUNFEI
|
||||
if model_type in [const.QWEN]:
|
||||
self.btype["chat"] = const.QWEN
|
||||
if model_type in [const.GEMINI]:
|
||||
self.btype["chat"] = const.GEMINI
|
||||
if model_type in [const.ZHIPU_AI]:
|
||||
self.btype["chat"] = const.ZHIPU_AI
|
||||
# 这边取配置的模型
|
||||
bot_type = conf().get("bot_type")
|
||||
if bot_type:
|
||||
self.btype["chat"] = bot_type
|
||||
else:
|
||||
model_type = conf().get("model") or const.GPT35
|
||||
if model_type in ["text-davinci-003"]:
|
||||
self.btype["chat"] = const.OPEN_AI
|
||||
if conf().get("use_azure_chatgpt", False):
|
||||
self.btype["chat"] = const.CHATGPTONAZURE
|
||||
if model_type in ["wenxin", "wenxin-4"]:
|
||||
self.btype["chat"] = const.BAIDU
|
||||
if model_type in ["xunfei"]:
|
||||
self.btype["chat"] = const.XUNFEI
|
||||
if model_type in [const.QWEN]:
|
||||
self.btype["chat"] = const.QWEN
|
||||
if model_type in [const.QWEN_TURBO, const.QWEN_PLUS, const.QWEN_MAX]:
|
||||
self.btype["chat"] = const.QWEN_DASHSCOPE
|
||||
if model_type and model_type.startswith("gemini"):
|
||||
self.btype["chat"] = const.GEMINI
|
||||
if model_type and model_type.startswith("glm"):
|
||||
self.btype["chat"] = const.ZHIPU_AI
|
||||
if model_type and model_type.startswith("claude"):
|
||||
self.btype["chat"] = const.CLAUDEAPI
|
||||
|
||||
if conf().get("use_linkai") and conf().get("linkai_api_key"):
|
||||
self.btype["chat"] = const.LINKAI
|
||||
if not conf().get("voice_to_text") or conf().get("voice_to_text") in ["openai"]:
|
||||
self.btype["voice_to_text"] = const.LINKAI
|
||||
if not conf().get("text_to_voice") or conf().get("text_to_voice") in ["openai", const.TTS_1, const.TTS_1_HD]:
|
||||
self.btype["text_to_voice"] = const.LINKAI
|
||||
if model_type in ["claude"]:
|
||||
self.btype["chat"] = const.CLAUDEAI
|
||||
|
||||
if model_type in [const.MOONSHOT, "moonshot-v1-8k", "moonshot-v1-32k", "moonshot-v1-128k"]:
|
||||
self.btype["chat"] = const.MOONSHOT
|
||||
|
||||
if model_type in [const.MODELSCOPE]:
|
||||
self.btype["chat"] = const.MODELSCOPE
|
||||
|
||||
if model_type in ["abab6.5-chat"]:
|
||||
self.btype["chat"] = const.MiniMax
|
||||
|
||||
if conf().get("use_linkai") and conf().get("linkai_api_key"):
|
||||
self.btype["chat"] = const.LINKAI
|
||||
if not conf().get("voice_to_text") or conf().get("voice_to_text") in ["openai"]:
|
||||
self.btype["voice_to_text"] = const.LINKAI
|
||||
if not conf().get("text_to_voice") or conf().get("text_to_voice") in ["openai", const.TTS_1, const.TTS_1_HD]:
|
||||
self.btype["text_to_voice"] = const.LINKAI
|
||||
|
||||
if model_type in ["claude"]:
|
||||
self.btype["chat"] = const.CLAUDEAI
|
||||
self.bots = {}
|
||||
self.chat_bots = {}
|
||||
|
||||
# 模型对应的接口
|
||||
def get_bot(self, typename):
|
||||
if self.bots.get(typename) is None:
|
||||
logger.info("create bot {} for {}".format(self.btype[typename], typename))
|
||||
|
||||
@@ -11,7 +11,7 @@ class ReplyType(Enum):
|
||||
VIDEO_URL = 5 # 视频URL
|
||||
FILE = 6 # 文件
|
||||
CARD = 7 # 微信名片,仅支持ntchat
|
||||
InviteRoom = 8 # 邀请好友进群
|
||||
INVITE_ROOM = 8 # 邀请好友进群
|
||||
INFO = 9
|
||||
ERROR = 10
|
||||
TEXT_ = 11 # 强制文本
|
||||
|
||||
@@ -18,9 +18,15 @@ def create_channel(channel_type) -> Channel:
|
||||
elif channel_type == "wxy":
|
||||
from channel.wechat.wechaty_channel import WechatyChannel
|
||||
ch = WechatyChannel()
|
||||
elif channel_type == "wcf":
|
||||
from channel.wechat.wcf_channel import WechatfChannel
|
||||
ch = WechatfChannel()
|
||||
elif channel_type == "terminal":
|
||||
from channel.terminal.terminal_channel import TerminalChannel
|
||||
ch = TerminalChannel()
|
||||
elif channel_type == 'web':
|
||||
from channel.web.web_channel import WebChannel
|
||||
ch = WebChannel()
|
||||
elif channel_type == "wechatmp":
|
||||
from channel.wechatmp.wechatmp_channel import WechatMPChannel
|
||||
ch = WechatMPChannel(passive_reply=True)
|
||||
|
||||
@@ -4,7 +4,6 @@ import threading
|
||||
import time
|
||||
from asyncio import CancelledError
|
||||
from concurrent.futures import Future, ThreadPoolExecutor
|
||||
from concurrent import futures
|
||||
|
||||
from bridge.context import *
|
||||
from bridge.reply import *
|
||||
@@ -75,6 +74,7 @@ class ChatChannel(Channel):
|
||||
):
|
||||
session_id = group_id
|
||||
else:
|
||||
logger.debug(f"No need reply, groupName not in whitelist, group_name={group_name}")
|
||||
return None
|
||||
context["session_id"] = session_id
|
||||
context["receiver"] = group_id
|
||||
@@ -86,14 +86,14 @@ class ChatChannel(Channel):
|
||||
if e_context.is_pass() or context is None:
|
||||
return context
|
||||
if cmsg.from_user_id == self.user_id and not config.get("trigger_by_self", True):
|
||||
logger.debug("[WX]self message skipped")
|
||||
logger.debug("[chat_channel]self message skipped")
|
||||
return None
|
||||
|
||||
# 消息内容匹配过程,并处理content
|
||||
if ctype == ContextType.TEXT:
|
||||
if first_in and "」\n- - - - - - -" in content: # 初次匹配 过滤引用消息
|
||||
logger.debug(content)
|
||||
logger.debug("[WX]reference query skipped")
|
||||
logger.debug("[chat_channel]reference query skipped")
|
||||
return None
|
||||
|
||||
nick_name_black_list = conf().get("nick_name_black_list", [])
|
||||
@@ -111,12 +111,13 @@ class ChatChannel(Channel):
|
||||
nick_name = context["msg"].actual_user_nickname
|
||||
if nick_name and nick_name in nick_name_black_list:
|
||||
# 黑名单过滤
|
||||
logger.warning(f"[WX] Nickname {nick_name} in In BlackList, ignore")
|
||||
logger.warning(f"[chat_channel] Nickname {nick_name} in In BlackList, ignore")
|
||||
return None
|
||||
|
||||
logger.info("[WX]receive group at")
|
||||
logger.info("[chat_channel]receive group at")
|
||||
if not conf().get("group_at_off", False):
|
||||
flag = True
|
||||
self.name = self.name if self.name is not None else "" # 部分渠道self.name可能没有赋值
|
||||
pattern = f"@{re.escape(self.name)}(\u2005|\u0020)"
|
||||
subtract_res = re.sub(pattern, r"", content)
|
||||
if isinstance(context["msg"].at_list, list):
|
||||
@@ -130,13 +131,13 @@ class ChatChannel(Channel):
|
||||
content = subtract_res
|
||||
if not flag:
|
||||
if context["origin_ctype"] == ContextType.VOICE:
|
||||
logger.info("[WX]receive group voice, but checkprefix didn't match")
|
||||
logger.info("[chat_channel]receive group voice, but checkprefix didn't match")
|
||||
return None
|
||||
else: # 单聊
|
||||
nick_name = context["msg"].from_user_nickname
|
||||
if nick_name and nick_name in nick_name_black_list:
|
||||
# 黑名单过滤
|
||||
logger.warning(f"[WX] Nickname '{nick_name}' in In BlackList, ignore")
|
||||
logger.warning(f"[chat_channel] Nickname '{nick_name}' in In BlackList, ignore")
|
||||
return None
|
||||
|
||||
match_prefix = check_prefix(content, conf().get("single_chat_prefix", [""]))
|
||||
@@ -145,9 +146,10 @@ class ChatChannel(Channel):
|
||||
elif context["origin_ctype"] == ContextType.VOICE: # 如果源消息是私聊的语音消息,允许不匹配前缀,放宽条件
|
||||
pass
|
||||
else:
|
||||
logger.info("[chat_channel]receive single chat msg, but checkprefix didn't match")
|
||||
return None
|
||||
content = content.strip()
|
||||
img_match_prefix = check_prefix(content, conf().get("image_create_prefix"))
|
||||
img_match_prefix = check_prefix(content, conf().get("image_create_prefix",[""]))
|
||||
if img_match_prefix:
|
||||
content = content.replace(img_match_prefix, "", 1)
|
||||
context.type = ContextType.IMAGE_CREATE
|
||||
@@ -159,22 +161,23 @@ class ChatChannel(Channel):
|
||||
elif context.type == ContextType.VOICE:
|
||||
if "desire_rtype" not in context and conf().get("voice_reply_voice") and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE:
|
||||
context["desire_rtype"] = ReplyType.VOICE
|
||||
|
||||
return context
|
||||
|
||||
def _handle(self, context: Context):
|
||||
if context is None or not context.content:
|
||||
return
|
||||
logger.debug("[WX] ready to handle context: {}".format(context))
|
||||
logger.debug("[chat_channel] ready to handle context: {}".format(context))
|
||||
# reply的构建步骤
|
||||
reply = self._generate_reply(context)
|
||||
|
||||
logger.debug("[WX] ready to decorate reply: {}".format(reply))
|
||||
# reply的包装步骤
|
||||
reply = self._decorate_reply(context, reply)
|
||||
logger.debug("[chat_channel] ready to decorate reply: {}".format(reply))
|
||||
|
||||
# reply的发送步骤
|
||||
self._send_reply(context, reply)
|
||||
# reply的包装步骤
|
||||
if reply and reply.content:
|
||||
reply = self._decorate_reply(context, reply)
|
||||
|
||||
# reply的发送步骤
|
||||
self._send_reply(context, reply)
|
||||
|
||||
def _generate_reply(self, context: Context, reply: Reply = Reply()) -> Reply:
|
||||
e_context = PluginManager().emit_event(
|
||||
@@ -185,7 +188,7 @@ class ChatChannel(Channel):
|
||||
)
|
||||
reply = e_context["reply"]
|
||||
if not e_context.is_pass():
|
||||
logger.debug("[WX] ready to handle context: type={}, content={}".format(context.type, context.content))
|
||||
logger.debug("[chat_channel] ready to handle context: type={}, content={}".format(context.type, context.content))
|
||||
if context.type == ContextType.TEXT or context.type == ContextType.IMAGE_CREATE: # 文字和图片消息
|
||||
context["channel"] = e_context["channel"]
|
||||
reply = super().build_reply_content(context.content, context)
|
||||
@@ -197,7 +200,7 @@ class ChatChannel(Channel):
|
||||
try:
|
||||
any_to_wav(file_path, wav_path)
|
||||
except Exception as e: # 转换失败,直接使用mp3,对于某些api,mp3也可以识别
|
||||
logger.warning("[WX]any to wav error, use raw path. " + str(e))
|
||||
logger.warning("[chat_channel]any to wav error, use raw path. " + str(e))
|
||||
wav_path = file_path
|
||||
# 语音识别
|
||||
reply = super().build_voice_to_text(wav_path)
|
||||
@@ -208,7 +211,7 @@ class ChatChannel(Channel):
|
||||
os.remove(wav_path)
|
||||
except Exception as e:
|
||||
pass
|
||||
# logger.warning("[WX]delete temp file error: " + str(e))
|
||||
# logger.warning("[chat_channel]delete temp file error: " + str(e))
|
||||
|
||||
if reply.type == ReplyType.TEXT:
|
||||
new_context = self._compose_context(ContextType.TEXT, reply.content, **context.kwargs)
|
||||
@@ -226,7 +229,7 @@ class ChatChannel(Channel):
|
||||
elif context.type == ContextType.FUNCTION or context.type == ContextType.FILE: # 文件消息及函数调用等,当前无默认逻辑
|
||||
pass
|
||||
else:
|
||||
logger.warning("[WX] unknown context type: {}".format(context.type))
|
||||
logger.warning("[chat_channel] unknown context type: {}".format(context.type))
|
||||
return
|
||||
return reply
|
||||
|
||||
@@ -242,7 +245,7 @@ class ChatChannel(Channel):
|
||||
desire_rtype = context.get("desire_rtype")
|
||||
if not e_context.is_pass() and reply and reply.type:
|
||||
if reply.type in self.NOT_SUPPORT_REPLYTYPE:
|
||||
logger.error("[WX]reply type not support: " + str(reply.type))
|
||||
logger.error("[chat_channel]reply type not support: " + str(reply.type))
|
||||
reply.type = ReplyType.ERROR
|
||||
reply.content = "不支持发送的消息类型: " + str(reply.type)
|
||||
|
||||
@@ -263,10 +266,10 @@ class ChatChannel(Channel):
|
||||
elif reply.type == ReplyType.IMAGE_URL or reply.type == ReplyType.VOICE or reply.type == ReplyType.IMAGE or reply.type == ReplyType.FILE or reply.type == ReplyType.VIDEO or reply.type == ReplyType.VIDEO_URL:
|
||||
pass
|
||||
else:
|
||||
logger.error("[WX] unknown reply type: {}".format(reply.type))
|
||||
logger.error("[chat_channel] unknown reply type: {}".format(reply.type))
|
||||
return
|
||||
if desire_rtype and desire_rtype != reply.type and reply.type not in [ReplyType.ERROR, ReplyType.INFO]:
|
||||
logger.warning("[WX] desire_rtype: {}, but reply type: {}".format(context.get("desire_rtype"), reply.type))
|
||||
logger.warning("[chat_channel] desire_rtype: {}, but reply type: {}".format(context.get("desire_rtype"), reply.type))
|
||||
return reply
|
||||
|
||||
def _send_reply(self, context: Context, reply: Reply):
|
||||
@@ -279,14 +282,14 @@ class ChatChannel(Channel):
|
||||
)
|
||||
reply = e_context["reply"]
|
||||
if not e_context.is_pass() and reply and reply.type:
|
||||
logger.debug("[WX] ready to send reply: {}, context: {}".format(reply, context))
|
||||
logger.debug("[chat_channel] ready to send reply: {}, context: {}".format(reply, context))
|
||||
self._send(reply, context)
|
||||
|
||||
def _send(self, reply: Reply, context: Context, retry_cnt=0):
|
||||
try:
|
||||
self.send(reply, context)
|
||||
except Exception as e:
|
||||
logger.error("[WX] sendMsg error: {}".format(str(e)))
|
||||
logger.error("[chat_channel] sendMsg error: {}".format(str(e)))
|
||||
if isinstance(e, NotImplementedError):
|
||||
return
|
||||
logger.exception(e)
|
||||
@@ -335,24 +338,27 @@ class ChatChannel(Channel):
|
||||
while True:
|
||||
with self.lock:
|
||||
session_ids = list(self.sessions.keys())
|
||||
for session_id in session_ids:
|
||||
for session_id in session_ids:
|
||||
with self.lock:
|
||||
context_queue, semaphore = self.sessions[session_id]
|
||||
if semaphore.acquire(blocking=False): # 等线程处理完毕才能删除
|
||||
if not context_queue.empty():
|
||||
context = context_queue.get()
|
||||
logger.debug("[WX] consume context: {}".format(context))
|
||||
future: Future = handler_pool.submit(self._handle, context)
|
||||
future.add_done_callback(self._thread_pool_callback(session_id, context=context))
|
||||
if semaphore.acquire(blocking=False): # 等线程处理完毕才能删除
|
||||
if not context_queue.empty():
|
||||
context = context_queue.get()
|
||||
logger.debug("[chat_channel] consume context: {}".format(context))
|
||||
future: Future = handler_pool.submit(self._handle, context)
|
||||
future.add_done_callback(self._thread_pool_callback(session_id, context=context))
|
||||
with self.lock:
|
||||
if session_id not in self.futures:
|
||||
self.futures[session_id] = []
|
||||
self.futures[session_id].append(future)
|
||||
elif semaphore._initial_value == semaphore._value + 1: # 除了当前,没有任务再申请到信号量,说明所有任务都处理完毕
|
||||
elif semaphore._initial_value == semaphore._value + 1: # 除了当前,没有任务再申请到信号量,说明所有任务都处理完毕
|
||||
with self.lock:
|
||||
self.futures[session_id] = [t for t in self.futures[session_id] if not t.done()]
|
||||
assert len(self.futures[session_id]) == 0, "thread pool error"
|
||||
del self.sessions[session_id]
|
||||
else:
|
||||
semaphore.release()
|
||||
time.sleep(0.1)
|
||||
else:
|
||||
semaphore.release()
|
||||
time.sleep(0.2)
|
||||
|
||||
# 取消session_id对应的所有任务,只能取消排队的消息和已提交线程池但未执行的任务
|
||||
def cancel_session(self, session_id):
|
||||
|
||||
@@ -4,20 +4,81 @@
|
||||
@author huiwen
|
||||
@Date 2023/11/28
|
||||
"""
|
||||
|
||||
import copy
|
||||
import json
|
||||
# -*- coding=utf-8 -*-
|
||||
import logging
|
||||
import time
|
||||
|
||||
import dingtalk_stream
|
||||
from dingtalk_stream import AckMessage
|
||||
from dingtalk_stream.card_replier import AICardReplier
|
||||
from dingtalk_stream.card_replier import AICardStatus
|
||||
from dingtalk_stream.card_replier import CardReplier
|
||||
|
||||
from bridge.context import Context, ContextType
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from channel.chat_channel import ChatChannel
|
||||
from channel.dingtalk.dingtalk_message import DingTalkMessage
|
||||
from bridge.context import Context
|
||||
from bridge.reply import Reply
|
||||
from common.expired_dict import ExpiredDict
|
||||
from common.log import logger
|
||||
from common.singleton import singleton
|
||||
from common.time_check import time_checker
|
||||
from config import conf
|
||||
from common.expired_dict import ExpiredDict
|
||||
from bridge.context import ContextType
|
||||
from channel.chat_channel import ChatChannel
|
||||
import logging
|
||||
from dingtalk_stream import AckMessage
|
||||
import dingtalk_stream
|
||||
|
||||
|
||||
class CustomAICardReplier(CardReplier):
|
||||
def __init__(self, dingtalk_client, incoming_message):
|
||||
super(AICardReplier, self).__init__(dingtalk_client, incoming_message)
|
||||
|
||||
def start(
|
||||
self,
|
||||
card_template_id: str,
|
||||
card_data: dict,
|
||||
recipients: list = None,
|
||||
support_forward: bool = True,
|
||||
) -> str:
|
||||
"""
|
||||
AI卡片的创建接口
|
||||
:param support_forward:
|
||||
:param recipients:
|
||||
:param card_template_id:
|
||||
:param card_data:
|
||||
:return:
|
||||
"""
|
||||
card_data_with_status = copy.deepcopy(card_data)
|
||||
card_data_with_status["flowStatus"] = AICardStatus.PROCESSING
|
||||
return self.create_and_send_card(
|
||||
card_template_id,
|
||||
card_data_with_status,
|
||||
at_sender=True,
|
||||
at_all=False,
|
||||
recipients=recipients,
|
||||
support_forward=support_forward,
|
||||
)
|
||||
|
||||
|
||||
# 对 AICardReplier 进行猴子补丁
|
||||
AICardReplier.start = CustomAICardReplier.start
|
||||
|
||||
|
||||
def _check(func):
|
||||
def wrapper(self, cmsg: DingTalkMessage):
|
||||
msgId = cmsg.msg_id
|
||||
if msgId in self.receivedMsgs:
|
||||
logger.info("DingTalk message {} already received, ignore".format(msgId))
|
||||
return
|
||||
self.receivedMsgs[msgId] = True
|
||||
create_time = cmsg.create_time # 消息时间戳
|
||||
if conf().get("hot_reload") == True and int(create_time) < int(time.time()) - 60: # 跳过1分钟前的历史消息
|
||||
logger.debug("[DingTalk] History message {} skipped".format(msgId))
|
||||
return
|
||||
if cmsg.my_msg and not cmsg.is_group:
|
||||
logger.debug("[DingTalk] My message {} skipped".format(msgId))
|
||||
return
|
||||
return func(self, cmsg)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
@singleton
|
||||
@@ -39,11 +100,13 @@ class DingTalkChanel(ChatChannel, dingtalk_stream.ChatbotHandler):
|
||||
super(dingtalk_stream.ChatbotHandler, self).__init__()
|
||||
self.logger = self.setup_logger()
|
||||
# 历史消息id暂存,用于幂等控制
|
||||
self.receivedMsgs = ExpiredDict(60 * 60 * 7.1)
|
||||
logger.info("[dingtalk] client_id={}, client_secret={} ".format(
|
||||
self.receivedMsgs = ExpiredDict(conf().get("expires_in_seconds", 3600))
|
||||
logger.info("[DingTalk] client_id={}, client_secret={} ".format(
|
||||
self.dingtalk_client_id, self.dingtalk_client_secret))
|
||||
# 无需群校验和前缀
|
||||
conf()["group_name_white_list"] = ["ALL_GROUP"]
|
||||
# 单聊无需前缀
|
||||
conf()["single_chat_prefix"] = [""]
|
||||
|
||||
def startup(self):
|
||||
credential = dingtalk_stream.Credential(self.dingtalk_client_id, self.dingtalk_client_secret)
|
||||
@@ -51,50 +114,112 @@ class DingTalkChanel(ChatChannel, dingtalk_stream.ChatbotHandler):
|
||||
client.register_callback_handler(dingtalk_stream.chatbot.ChatbotMessage.TOPIC, self)
|
||||
client.start_forever()
|
||||
|
||||
async def process(self, callback: dingtalk_stream.CallbackMessage):
|
||||
try:
|
||||
incoming_message = dingtalk_stream.ChatbotMessage.from_dict(callback.data)
|
||||
image_download_handler = self # 传入方法所在的类实例
|
||||
dingtalk_msg = DingTalkMessage(incoming_message, image_download_handler)
|
||||
|
||||
if dingtalk_msg.is_group:
|
||||
self.handle_group(dingtalk_msg)
|
||||
else:
|
||||
self.handle_single(dingtalk_msg)
|
||||
return AckMessage.STATUS_OK, 'OK'
|
||||
except Exception as e:
|
||||
logger.error(f"dingtalk process error={e}")
|
||||
return AckMessage.STATUS_SYSTEM_EXCEPTION, 'ERROR'
|
||||
|
||||
@time_checker
|
||||
@_check
|
||||
def handle_single(self, cmsg: DingTalkMessage):
|
||||
# 处理单聊消息
|
||||
if cmsg.ctype == ContextType.VOICE:
|
||||
logger.debug("[dingtalk]receive voice msg: {}".format(cmsg.content))
|
||||
logger.debug("[DingTalk]receive voice msg: {}".format(cmsg.content))
|
||||
elif cmsg.ctype == ContextType.IMAGE:
|
||||
logger.debug("[dingtalk]receive image msg: {}".format(cmsg.content))
|
||||
logger.debug("[DingTalk]receive image msg: {}".format(cmsg.content))
|
||||
elif cmsg.ctype == ContextType.IMAGE_CREATE:
|
||||
logger.debug("[DingTalk]receive image create msg: {}".format(cmsg.content))
|
||||
elif cmsg.ctype == ContextType.PATPAT:
|
||||
logger.debug("[dingtalk]receive patpat msg: {}".format(cmsg.content))
|
||||
logger.debug("[DingTalk]receive patpat msg: {}".format(cmsg.content))
|
||||
elif cmsg.ctype == ContextType.TEXT:
|
||||
expression = cmsg.my_msg
|
||||
cmsg.content = conf()["single_chat_prefix"][0] + cmsg.content
|
||||
logger.debug("[DingTalk]receive text msg: {}".format(cmsg.content))
|
||||
else:
|
||||
logger.debug("[DingTalk]receive other msg: {}".format(cmsg.content))
|
||||
context = self._compose_context(cmsg.ctype, cmsg.content, isgroup=False, msg=cmsg)
|
||||
if context:
|
||||
self.produce(context)
|
||||
|
||||
|
||||
@time_checker
|
||||
@_check
|
||||
def handle_group(self, cmsg: DingTalkMessage):
|
||||
# 处理群聊消息
|
||||
if cmsg.ctype == ContextType.VOICE:
|
||||
logger.debug("[dingtalk]receive voice msg: {}".format(cmsg.content))
|
||||
logger.debug("[DingTalk]receive voice msg: {}".format(cmsg.content))
|
||||
elif cmsg.ctype == ContextType.IMAGE:
|
||||
logger.debug("[dingtalk]receive image msg: {}".format(cmsg.content))
|
||||
logger.debug("[DingTalk]receive image msg: {}".format(cmsg.content))
|
||||
elif cmsg.ctype == ContextType.IMAGE_CREATE:
|
||||
logger.debug("[DingTalk]receive image create msg: {}".format(cmsg.content))
|
||||
elif cmsg.ctype == ContextType.PATPAT:
|
||||
logger.debug("[dingtalk]receive patpat msg: {}".format(cmsg.content))
|
||||
logger.debug("[DingTalk]receive patpat msg: {}".format(cmsg.content))
|
||||
elif cmsg.ctype == ContextType.TEXT:
|
||||
expression = cmsg.my_msg
|
||||
cmsg.content = conf()["group_chat_prefix"][0] + cmsg.content
|
||||
logger.debug("[DingTalk]receive text msg: {}".format(cmsg.content))
|
||||
else:
|
||||
logger.debug("[DingTalk]receive other msg: {}".format(cmsg.content))
|
||||
context = self._compose_context(cmsg.ctype, cmsg.content, isgroup=True, msg=cmsg)
|
||||
context['no_need_at'] = True
|
||||
if context:
|
||||
self.produce(context)
|
||||
|
||||
async def process(self, callback: dingtalk_stream.CallbackMessage):
|
||||
try:
|
||||
incoming_message = dingtalk_stream.ChatbotMessage.from_dict(callback.data)
|
||||
dingtalk_msg = DingTalkMessage(incoming_message)
|
||||
if incoming_message.conversation_type == '1':
|
||||
self.handle_single(dingtalk_msg)
|
||||
else:
|
||||
self.handle_group(dingtalk_msg)
|
||||
return AckMessage.STATUS_OK, 'OK'
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
return self.FAILED_MSG
|
||||
|
||||
def send(self, reply: Reply, context: Context):
|
||||
receiver = context["receiver"]
|
||||
isgroup = context.kwargs['msg'].is_group
|
||||
incoming_message = context.kwargs['msg'].incoming_message
|
||||
self.reply_text(reply.content, incoming_message)
|
||||
|
||||
if conf().get("dingtalk_card_enabled"):
|
||||
logger.info("[Dingtalk] sendMsg={}, receiver={}".format(reply, receiver))
|
||||
def reply_with_text():
|
||||
self.reply_text(reply.content, incoming_message)
|
||||
def reply_with_at_text():
|
||||
self.reply_text("📢 您有一条新的消息,请查看。", incoming_message)
|
||||
def reply_with_ai_markdown():
|
||||
button_list, markdown_content = self.generate_button_markdown_content(context, reply)
|
||||
self.reply_ai_markdown_button(incoming_message, markdown_content, button_list, "", "📌 内容由AI生成", "",[incoming_message.sender_staff_id])
|
||||
|
||||
if reply.type in [ReplyType.IMAGE_URL, ReplyType.IMAGE, ReplyType.TEXT]:
|
||||
if isgroup:
|
||||
reply_with_ai_markdown()
|
||||
reply_with_at_text()
|
||||
else:
|
||||
reply_with_ai_markdown()
|
||||
else:
|
||||
# 暂不支持其它类型消息回复
|
||||
reply_with_text()
|
||||
else:
|
||||
self.reply_text(reply.content, incoming_message)
|
||||
|
||||
|
||||
def generate_button_markdown_content(self, context, reply):
|
||||
image_url = context.kwargs.get("image_url")
|
||||
promptEn = context.kwargs.get("promptEn")
|
||||
reply_text = reply.content
|
||||
button_list = []
|
||||
markdown_content = f"""
|
||||
{reply.content}
|
||||
"""
|
||||
if image_url is not None and promptEn is not None:
|
||||
button_list = [
|
||||
{"text": "查看原图", "url": image_url, "iosUrl": image_url, "color": "blue"}
|
||||
]
|
||||
markdown_content = f"""
|
||||
{promptEn}
|
||||
|
||||

|
||||
|
||||
{reply_text}
|
||||
|
||||
"""
|
||||
logger.debug(f"[Dingtalk] generate_button_markdown_content, button_list={button_list} , markdown_content={markdown_content}")
|
||||
|
||||
return button_list, markdown_content
|
||||
|
||||
@@ -1,44 +1,84 @@
|
||||
from bridge.context import ContextType
|
||||
from channel.chat_message import ChatMessage
|
||||
import json
|
||||
import os
|
||||
|
||||
import requests
|
||||
from common.log import logger
|
||||
from common.tmp_dir import TmpDir
|
||||
from common import utils
|
||||
from dingtalk_stream import ChatbotMessage
|
||||
|
||||
from bridge.context import ContextType
|
||||
from channel.chat_message import ChatMessage
|
||||
# -*- coding=utf-8 -*-
|
||||
from common.log import logger
|
||||
from common.tmp_dir import TmpDir
|
||||
|
||||
|
||||
class DingTalkMessage(ChatMessage):
|
||||
def __init__(self, event: ChatbotMessage):
|
||||
def __init__(self, event: ChatbotMessage, image_download_handler):
|
||||
super().__init__(event)
|
||||
|
||||
self.image_download_handler = image_download_handler
|
||||
self.msg_id = event.message_id
|
||||
msg_type = event.message_type
|
||||
self.incoming_message =event
|
||||
self.message_type = event.message_type
|
||||
self.incoming_message = event
|
||||
self.sender_staff_id = event.sender_staff_id
|
||||
self.other_user_id = event.conversation_id
|
||||
self.create_time = event.create_at
|
||||
if event.conversation_type=="1":
|
||||
self.image_content = event.image_content
|
||||
self.rich_text_content = event.rich_text_content
|
||||
if event.conversation_type == "1":
|
||||
self.is_group = False
|
||||
else:
|
||||
self.is_group = True
|
||||
|
||||
|
||||
if msg_type == "text":
|
||||
if self.message_type == "text":
|
||||
self.ctype = ContextType.TEXT
|
||||
|
||||
|
||||
self.content = event.text.content.strip()
|
||||
elif msg_type == "audio":
|
||||
|
||||
elif self.message_type == "audio":
|
||||
# 钉钉支持直接识别语音,所以此处将直接提取文字,当文字处理
|
||||
self.content = event.extensions['content']['recognition'].strip()
|
||||
self.ctype = ContextType.TEXT
|
||||
self.from_user_id = event.sender_id
|
||||
elif (self.message_type == 'picture') or (self.message_type == 'richText'):
|
||||
self.ctype = ContextType.IMAGE
|
||||
# 钉钉图片类型或富文本类型消息处理
|
||||
image_list = event.get_image_list()
|
||||
if len(image_list) > 0:
|
||||
download_code = image_list[0]
|
||||
download_url = image_download_handler.get_image_download_url(download_code)
|
||||
self.content = download_image_file(download_url, TmpDir().path())
|
||||
else:
|
||||
logger.debug(f"[Dingtalk] messageType :{self.message_type} , imageList isEmpty")
|
||||
|
||||
if self.is_group:
|
||||
self.from_user_id = event.conversation_id
|
||||
self.actual_user_id = event.sender_id
|
||||
self.is_at = True
|
||||
else:
|
||||
self.from_user_id = event.sender_id
|
||||
self.actual_user_id = event.sender_id
|
||||
self.to_user_id = event.chatbot_user_id
|
||||
self.other_user_nickname = event.conversation_title
|
||||
|
||||
user_id = event.sender_id
|
||||
nickname =event.sender_nick
|
||||
|
||||
|
||||
|
||||
|
||||
def download_image_file(image_url, temp_dir):
|
||||
headers = {
|
||||
'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/112.0.0.0 Safari/537.36'
|
||||
}
|
||||
# 设置代理
|
||||
# self.proxies
|
||||
# , proxies=self.proxies
|
||||
response = requests.get(image_url, headers=headers, stream=True, timeout=60 * 5)
|
||||
if response.status_code == 200:
|
||||
|
||||
# 生成文件名
|
||||
file_name = image_url.split("/")[-1].split("?")[0]
|
||||
|
||||
# 检查临时目录是否存在,如果不存在则创建
|
||||
if not os.path.exists(temp_dir):
|
||||
os.makedirs(temp_dir)
|
||||
|
||||
# 将文件保存到临时目录
|
||||
file_path = os.path.join(temp_dir, file_name)
|
||||
with open(file_path, 'wb') as file:
|
||||
file.write(response.content)
|
||||
return file_path
|
||||
else:
|
||||
logger.info(f"[Dingtalk] Failed to download image file, {response.content}")
|
||||
return None
|
||||
|
||||
@@ -40,7 +40,7 @@ class FeiShuChanel(ChatChannel):
|
||||
self.feishu_app_id, self.feishu_app_secret, self.feishu_token))
|
||||
# 无需群校验和前缀
|
||||
conf()["group_name_white_list"] = ["ALL_GROUP"]
|
||||
conf()["single_chat_prefix"] = []
|
||||
conf()["single_chat_prefix"] = [""]
|
||||
|
||||
def startup(self):
|
||||
urls = (
|
||||
|
||||
@@ -78,6 +78,7 @@ class TerminalChannel(ChatChannel):
|
||||
prompt = trigger_prefixs[0] + prompt # 给没触发的消息加上触发前缀
|
||||
|
||||
context = self._compose_context(ContextType.TEXT, prompt, msg=TerminalMessage(msg_id, prompt))
|
||||
context["isgroup"] = False
|
||||
if context:
|
||||
self.produce(context)
|
||||
else:
|
||||
|
||||
10
channel/web/README.md
Normal file
@@ -0,0 +1,10 @@
|
||||
# Web Channel
|
||||
|
||||
提供了一个默认的AI对话页面,可展示文本、图片等消息交互,支持markdown语法渲染,兼容插件执行。
|
||||
|
||||
# 使用说明
|
||||
|
||||
- 在 `config.json` 配置文件中的 `channel_type` 字段填入 `web`
|
||||
- 程序运行后将监听9899端口,浏览器访问 http://localhost:9899/chat 即可使用
|
||||
- 监听端口可以在配置文件 `web_port` 中自定义
|
||||
- 对于Docker运行方式,如果需要外部访问,需要在 `docker-compose.yml` 中通过 ports配置将端口监听映射到宿主机
|
||||
1506
channel/web/chat.html
Normal file
2
channel/web/static/axios.min.js
vendored
Normal file
BIN
channel/web/static/favicon.ico
Normal file
|
After Width: | Height: | Size: 4.2 KiB |
BIN
channel/web/static/github.png
Normal file
|
After Width: | Height: | Size: 3.4 KiB |
BIN
channel/web/static/logo.jpg
Normal file
|
After Width: | Height: | Size: 21 KiB |
291
channel/web/web_channel.py
Normal file
@@ -0,0 +1,291 @@
|
||||
import sys
|
||||
import time
|
||||
import web
|
||||
import json
|
||||
import uuid
|
||||
from queue import Queue, Empty
|
||||
from bridge.context import *
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from channel.chat_channel import ChatChannel, check_prefix
|
||||
from channel.chat_message import ChatMessage
|
||||
from common.log import logger
|
||||
from common.singleton import singleton
|
||||
from config import conf
|
||||
import os
|
||||
import mimetypes # 添加这行来处理MIME类型
|
||||
import threading
|
||||
import logging
|
||||
|
||||
class WebMessage(ChatMessage):
|
||||
def __init__(
|
||||
self,
|
||||
msg_id,
|
||||
content,
|
||||
ctype=ContextType.TEXT,
|
||||
from_user_id="User",
|
||||
to_user_id="Chatgpt",
|
||||
other_user_id="Chatgpt",
|
||||
):
|
||||
self.msg_id = msg_id
|
||||
self.ctype = ctype
|
||||
self.content = content
|
||||
self.from_user_id = from_user_id
|
||||
self.to_user_id = to_user_id
|
||||
self.other_user_id = other_user_id
|
||||
|
||||
|
||||
@singleton
|
||||
class WebChannel(ChatChannel):
|
||||
NOT_SUPPORT_REPLYTYPE = [ReplyType.VOICE]
|
||||
_instance = None
|
||||
|
||||
# def __new__(cls):
|
||||
# if cls._instance is None:
|
||||
# cls._instance = super(WebChannel, cls).__new__(cls)
|
||||
# return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.msg_id_counter = 0 # 添加消息ID计数器
|
||||
self.session_queues = {} # 存储session_id到队列的映射
|
||||
self.request_to_session = {} # 存储request_id到session_id的映射
|
||||
|
||||
def _generate_msg_id(self):
|
||||
"""生成唯一的消息ID"""
|
||||
self.msg_id_counter += 1
|
||||
return str(int(time.time())) + str(self.msg_id_counter)
|
||||
|
||||
def _generate_request_id(self):
|
||||
"""生成唯一的请求ID"""
|
||||
return str(uuid.uuid4())
|
||||
|
||||
def send(self, reply: Reply, context: Context):
|
||||
try:
|
||||
if reply.type in self.NOT_SUPPORT_REPLYTYPE:
|
||||
logger.warning(f"Web channel doesn't support {reply.type} yet")
|
||||
return
|
||||
|
||||
if reply.type == ReplyType.IMAGE_URL:
|
||||
time.sleep(0.5)
|
||||
|
||||
# 获取请求ID和会话ID
|
||||
request_id = context.get("request_id", None)
|
||||
|
||||
if not request_id:
|
||||
logger.error("No request_id found in context, cannot send message")
|
||||
return
|
||||
|
||||
# 通过request_id获取session_id
|
||||
session_id = self.request_to_session.get(request_id)
|
||||
if not session_id:
|
||||
logger.error(f"No session_id found for request {request_id}")
|
||||
return
|
||||
|
||||
# 检查是否有会话队列
|
||||
if session_id in self.session_queues:
|
||||
# 创建响应数据,包含请求ID以区分不同请求的响应
|
||||
response_data = {
|
||||
"type": str(reply.type),
|
||||
"content": reply.content,
|
||||
"timestamp": time.time(),
|
||||
"request_id": request_id
|
||||
}
|
||||
self.session_queues[session_id].put(response_data)
|
||||
logger.debug(f"Response sent to queue for session {session_id}, request {request_id}")
|
||||
else:
|
||||
logger.warning(f"No response queue found for session {session_id}, response dropped")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in send method: {e}")
|
||||
|
||||
def post_message(self):
|
||||
"""
|
||||
Handle incoming messages from users via POST request.
|
||||
Returns a request_id for tracking this specific request.
|
||||
"""
|
||||
try:
|
||||
data = web.data() # 获取原始POST数据
|
||||
json_data = json.loads(data)
|
||||
session_id = json_data.get('session_id', f'session_{int(time.time())}')
|
||||
prompt = json_data.get('message', '')
|
||||
|
||||
# 生成请求ID
|
||||
request_id = self._generate_request_id()
|
||||
|
||||
# 将请求ID与会话ID关联
|
||||
self.request_to_session[request_id] = session_id
|
||||
|
||||
# 确保会话队列存在
|
||||
if session_id not in self.session_queues:
|
||||
self.session_queues[session_id] = Queue()
|
||||
|
||||
# 创建消息对象
|
||||
msg = WebMessage(self._generate_msg_id(), prompt)
|
||||
msg.from_user_id = session_id # 使用会话ID作为用户ID
|
||||
|
||||
# 创建上下文
|
||||
context = self._compose_context(ContextType.TEXT, prompt, msg=msg)
|
||||
|
||||
# 添加必要的字段
|
||||
context["session_id"] = session_id
|
||||
context["request_id"] = request_id
|
||||
context["isgroup"] = False # 添加 isgroup 字段
|
||||
context["receiver"] = session_id # 添加 receiver 字段
|
||||
|
||||
# 异步处理消息 - 只传递上下文
|
||||
threading.Thread(target=self.produce, args=(context,)).start()
|
||||
|
||||
# 返回请求ID
|
||||
return json.dumps({"status": "success", "request_id": request_id})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing message: {e}")
|
||||
return json.dumps({"status": "error", "message": str(e)})
|
||||
|
||||
def poll_response(self):
|
||||
"""
|
||||
Poll for responses using the session_id.
|
||||
"""
|
||||
try:
|
||||
# 不记录轮询请求的日志
|
||||
web.ctx.log_request = False
|
||||
|
||||
data = web.data()
|
||||
json_data = json.loads(data)
|
||||
session_id = json_data.get('session_id')
|
||||
|
||||
if not session_id or session_id not in self.session_queues:
|
||||
return json.dumps({"status": "error", "message": "Invalid session ID"})
|
||||
|
||||
# 尝试从队列获取响应,不等待
|
||||
try:
|
||||
# 使用peek而不是get,这样如果前端没有成功处理,下次还能获取到
|
||||
response = self.session_queues[session_id].get(block=False)
|
||||
|
||||
# 返回响应,包含请求ID以区分不同请求
|
||||
return json.dumps({
|
||||
"status": "success",
|
||||
"has_content": True,
|
||||
"content": response["content"],
|
||||
"request_id": response["request_id"],
|
||||
"timestamp": response["timestamp"]
|
||||
})
|
||||
|
||||
except Empty:
|
||||
# 没有新响应
|
||||
return json.dumps({"status": "success", "has_content": False})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error polling response: {e}")
|
||||
return json.dumps({"status": "error", "message": str(e)})
|
||||
|
||||
def chat_page(self):
|
||||
"""Serve the chat HTML page."""
|
||||
file_path = os.path.join(os.path.dirname(__file__), 'chat.html') # 使用绝对路径
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
return f.read()
|
||||
|
||||
def startup(self):
|
||||
logger.info("""[WebChannel] 当前channel为web,可修改 config.json 配置文件中的 channel_type 字段进行切换。全部可用类型为:
|
||||
1. web: 网页
|
||||
2. terminal: 终端
|
||||
3. wechatmp: 个人公众号
|
||||
4. wechatmp_service: 企业公众号
|
||||
5. wechatcom_app: 企微自建应用
|
||||
6. dingtalk: 钉钉
|
||||
7. feishu: 飞书""")
|
||||
logger.info("Web对话网页已运行, 请使用浏览器访问 http://localhost:9899/chat")
|
||||
|
||||
# 确保静态文件目录存在
|
||||
static_dir = os.path.join(os.path.dirname(__file__), 'static')
|
||||
if not os.path.exists(static_dir):
|
||||
os.makedirs(static_dir)
|
||||
logger.info(f"Created static directory: {static_dir}")
|
||||
|
||||
urls = (
|
||||
'/', 'RootHandler', # 添加根路径处理器
|
||||
'/message', 'MessageHandler',
|
||||
'/poll', 'PollHandler', # 添加轮询处理器
|
||||
'/chat', 'ChatHandler',
|
||||
'/assets/(.*)', 'AssetsHandler', # 匹配 /assets/任何路径
|
||||
)
|
||||
port = conf().get("web_port", 9899)
|
||||
app = web.application(urls, globals(), autoreload=False)
|
||||
|
||||
# 禁用web.py的默认日志输出
|
||||
import io
|
||||
from contextlib import redirect_stdout
|
||||
|
||||
# 配置web.py的日志级别为ERROR,只显示错误
|
||||
logging.getLogger("web").setLevel(logging.ERROR)
|
||||
|
||||
# 禁用web.httpserver的日志
|
||||
logging.getLogger("web.httpserver").setLevel(logging.ERROR)
|
||||
|
||||
# 临时重定向标准输出,捕获web.py的启动消息
|
||||
with redirect_stdout(io.StringIO()):
|
||||
web.httpserver.runsimple(app.wsgifunc(), ("0.0.0.0", port))
|
||||
|
||||
|
||||
class RootHandler:
|
||||
def GET(self):
|
||||
# 重定向到/chat
|
||||
raise web.seeother('/chat')
|
||||
|
||||
|
||||
class MessageHandler:
|
||||
def POST(self):
|
||||
return WebChannel().post_message()
|
||||
|
||||
|
||||
class PollHandler:
|
||||
def POST(self):
|
||||
return WebChannel().poll_response()
|
||||
|
||||
|
||||
class ChatHandler:
|
||||
def GET(self):
|
||||
# 正常返回聊天页面
|
||||
file_path = os.path.join(os.path.dirname(__file__), 'chat.html')
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
return f.read()
|
||||
|
||||
|
||||
class AssetsHandler:
|
||||
def GET(self, file_path): # 修改默认参数
|
||||
try:
|
||||
# 如果请求是/static/,需要处理
|
||||
if file_path == '':
|
||||
# 返回目录列表...
|
||||
pass
|
||||
|
||||
# 获取当前文件的绝对路径
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
static_dir = os.path.join(current_dir, 'static')
|
||||
|
||||
full_path = os.path.normpath(os.path.join(static_dir, file_path))
|
||||
|
||||
# 安全检查:确保请求的文件在static目录内
|
||||
if not os.path.abspath(full_path).startswith(os.path.abspath(static_dir)):
|
||||
logger.error(f"Security check failed for path: {full_path}")
|
||||
raise web.notfound()
|
||||
|
||||
if not os.path.exists(full_path) or not os.path.isfile(full_path):
|
||||
logger.error(f"File not found: {full_path}")
|
||||
raise web.notfound()
|
||||
|
||||
# 设置正确的Content-Type
|
||||
content_type = mimetypes.guess_type(full_path)[0]
|
||||
if content_type:
|
||||
web.header('Content-Type', content_type)
|
||||
else:
|
||||
# 默认为二进制流
|
||||
web.header('Content-Type', 'application/octet-stream')
|
||||
|
||||
# 读取并返回文件内容
|
||||
with open(full_path, 'rb') as f:
|
||||
return f.read()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error serving static file: {e}", exc_info=True) # 添加更详细的错误信息
|
||||
raise web.notfound()
|
||||
179
channel/wechat/wcf_channel.py
Normal file
@@ -0,0 +1,179 @@
|
||||
# encoding:utf-8
|
||||
|
||||
"""
|
||||
wechat channel
|
||||
"""
|
||||
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
from queue import Empty
|
||||
from typing import Any
|
||||
|
||||
from bridge.context import *
|
||||
from bridge.reply import *
|
||||
from channel.chat_channel import ChatChannel
|
||||
from channel.wechat.wcf_message import WechatfMessage
|
||||
from common.log import logger
|
||||
from common.singleton import singleton
|
||||
from common.utils import *
|
||||
from config import conf, get_appdata_dir
|
||||
from wcferry import Wcf, WxMsg
|
||||
|
||||
|
||||
@singleton
|
||||
class WechatfChannel(ChatChannel):
|
||||
NOT_SUPPORT_REPLYTYPE = []
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.NOT_SUPPORT_REPLYTYPE = []
|
||||
# 使用字典存储最近消息,用于去重
|
||||
self.received_msgs = {}
|
||||
# 初始化wcferry客户端
|
||||
self.wcf = Wcf()
|
||||
self.wxid = None # 登录后会被设置为当前登录用户的wxid
|
||||
|
||||
def startup(self):
|
||||
"""
|
||||
启动通道
|
||||
"""
|
||||
try:
|
||||
# wcferry会自动唤起微信并登录
|
||||
self.wxid = self.wcf.get_self_wxid()
|
||||
self.name = self.wcf.get_user_info().get("name")
|
||||
logger.info(f"微信登录成功,当前用户ID: {self.wxid}, 用户名:{self.name}")
|
||||
self.contact_cache = ContactCache(self.wcf)
|
||||
self.contact_cache.update()
|
||||
# 启动消息接收
|
||||
self.wcf.enable_receiving_msg()
|
||||
# 创建消息处理线程
|
||||
t = threading.Thread(target=self._process_messages, name="WeChatThread", daemon=True)
|
||||
t.start()
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"微信通道启动失败: {e}")
|
||||
raise e
|
||||
|
||||
def _process_messages(self):
|
||||
"""
|
||||
处理消息队列
|
||||
"""
|
||||
while True:
|
||||
try:
|
||||
msg = self.wcf.get_msg()
|
||||
if msg:
|
||||
self._handle_message(msg)
|
||||
except Empty:
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.error(f"处理消息失败: {e}")
|
||||
continue
|
||||
|
||||
def _handle_message(self, msg: WxMsg):
|
||||
"""
|
||||
处理单条消息
|
||||
"""
|
||||
try:
|
||||
# 构造消息对象
|
||||
cmsg = WechatfMessage(self, msg)
|
||||
# 消息去重
|
||||
if cmsg.msg_id in self.received_msgs:
|
||||
return
|
||||
self.received_msgs[cmsg.msg_id] = time.time()
|
||||
# 清理过期消息ID
|
||||
self._clean_expired_msgs()
|
||||
|
||||
logger.debug(f"收到消息: {msg}")
|
||||
context = self._compose_context(cmsg.ctype, cmsg.content,
|
||||
isgroup=cmsg.is_group,
|
||||
msg=cmsg)
|
||||
if context:
|
||||
self.produce(context)
|
||||
except Exception as e:
|
||||
logger.error(f"处理消息失败: {e}")
|
||||
|
||||
def _clean_expired_msgs(self, expire_time: float = 60):
|
||||
"""
|
||||
清理过期的消息ID
|
||||
"""
|
||||
now = time.time()
|
||||
for msg_id in list(self.received_msgs.keys()):
|
||||
if now - self.received_msgs[msg_id] > expire_time:
|
||||
del self.received_msgs[msg_id]
|
||||
|
||||
def send(self, reply: Reply, context: Context):
|
||||
"""
|
||||
发送消息
|
||||
"""
|
||||
receiver = context["receiver"]
|
||||
if not receiver:
|
||||
logger.error("receiver is empty")
|
||||
return
|
||||
|
||||
try:
|
||||
if reply.type == ReplyType.TEXT:
|
||||
# 处理@信息
|
||||
at_list = []
|
||||
if context.get("isgroup"):
|
||||
if context["msg"].actual_user_id:
|
||||
at_list = [context["msg"].actual_user_id]
|
||||
at_str = ",".join(at_list) if at_list else ""
|
||||
self.wcf.send_text(reply.content, receiver, at_str)
|
||||
|
||||
elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO:
|
||||
self.wcf.send_text(reply.content, receiver)
|
||||
else:
|
||||
logger.error(f"暂不支持的消息类型: {reply.type}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"发送消息失败: {e}")
|
||||
|
||||
def close(self):
|
||||
"""
|
||||
关闭通道
|
||||
"""
|
||||
try:
|
||||
self.wcf.cleanup()
|
||||
except Exception as e:
|
||||
logger.error(f"关闭通道失败: {e}")
|
||||
|
||||
|
||||
class ContactCache:
|
||||
def __init__(self, wcf):
|
||||
"""
|
||||
wcf: 一个 wcfferry.client.Wcf 实例
|
||||
"""
|
||||
self.wcf = wcf
|
||||
self._contact_map = {} # 形如 {wxid: {完整联系人信息}}
|
||||
|
||||
def update(self):
|
||||
"""
|
||||
更新缓存:调用 get_contacts(),
|
||||
再把 wcf.contacts 构建成 {wxid: {完整信息}} 的字典
|
||||
"""
|
||||
self.wcf.get_contacts()
|
||||
self._contact_map.clear()
|
||||
for item in self.wcf.contacts:
|
||||
wxid = item.get('wxid')
|
||||
if wxid: # 确保有 wxid 字段
|
||||
self._contact_map[wxid] = item
|
||||
|
||||
def get_contact(self, wxid: str) -> dict:
|
||||
"""
|
||||
返回该 wxid 对应的完整联系人 dict,
|
||||
如果没找到就返回 None
|
||||
"""
|
||||
return self._contact_map.get(wxid)
|
||||
|
||||
def get_name_by_wxid(self, wxid: str) -> str:
|
||||
"""
|
||||
通过wxid,获取成员/群名称
|
||||
"""
|
||||
contact = self.get_contact(wxid)
|
||||
if contact:
|
||||
return contact.get('name', '')
|
||||
return ''
|
||||
58
channel/wechat/wcf_message.py
Normal file
@@ -0,0 +1,58 @@
|
||||
# encoding:utf-8
|
||||
|
||||
"""
|
||||
wechat channel message
|
||||
"""
|
||||
|
||||
from bridge.context import ContextType
|
||||
from channel.chat_message import ChatMessage
|
||||
from common.log import logger
|
||||
from wcferry import WxMsg
|
||||
|
||||
|
||||
class WechatfMessage(ChatMessage):
|
||||
"""
|
||||
微信消息封装类
|
||||
"""
|
||||
|
||||
def __init__(self, channel, wcf_msg: WxMsg, is_group=False):
|
||||
"""
|
||||
初始化消息对象
|
||||
:param wcf_msg: wcferry消息对象
|
||||
:param is_group: 是否是群消息
|
||||
"""
|
||||
super().__init__(wcf_msg)
|
||||
self.msg_id = wcf_msg.id
|
||||
self.create_time = wcf_msg.ts # 使用消息时间戳
|
||||
self.is_group = is_group or wcf_msg._is_group
|
||||
self.wxid = channel.wxid
|
||||
self.name = channel.name
|
||||
|
||||
# 解析消息类型
|
||||
if wcf_msg.is_text():
|
||||
self.ctype = ContextType.TEXT
|
||||
self.content = wcf_msg.content
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported message type: {wcf_msg.type}")
|
||||
|
||||
# 设置发送者和接收者信息
|
||||
self.from_user_id = self.wxid if wcf_msg.sender == self.wxid else wcf_msg.sender
|
||||
self.from_user_nickname = self.name if wcf_msg.sender == self.wxid else channel.contact_cache.get_name_by_wxid(wcf_msg.sender)
|
||||
self.to_user_id = self.wxid
|
||||
self.to_user_nickname = self.name
|
||||
self.other_user_id = wcf_msg.sender
|
||||
self.other_user_nickname = channel.contact_cache.get_name_by_wxid(wcf_msg.sender)
|
||||
|
||||
# 群消息特殊处理
|
||||
if self.is_group:
|
||||
self.other_user_id = wcf_msg.roomid
|
||||
self.other_user_nickname = channel.contact_cache.get_name_by_wxid(wcf_msg.roomid)
|
||||
self.actual_user_id = wcf_msg.sender
|
||||
self.actual_user_nickname = channel.wcf.get_alias_in_chatroom(wcf_msg.sender, wcf_msg.roomid)
|
||||
if not self.actual_user_nickname: # 群聊获取不到企微号成员昵称,这里尝试从联系人缓存去获取
|
||||
self.actual_user_nickname = channel.contact_cache.get_name_by_wxid(wcf_msg.sender)
|
||||
self.room_id = wcf_msg.roomid
|
||||
self.is_at = wcf_msg.is_at(self.wxid) # 是否被@当前登录用户
|
||||
|
||||
# 判断是否是自己发送的消息
|
||||
self.my_msg = wcf_msg.from_self()
|
||||
@@ -9,7 +9,6 @@ import json
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
|
||||
import requests
|
||||
|
||||
from bridge.context import *
|
||||
@@ -21,6 +20,7 @@ from common.expired_dict import ExpiredDict
|
||||
from common.log import logger
|
||||
from common.singleton import singleton
|
||||
from common.time_check import time_checker
|
||||
from common.utils import convert_webp_to_png, remove_markdown_symbol
|
||||
from config import conf, get_appdata_dir
|
||||
from lib import itchat
|
||||
from lib.itchat.content import *
|
||||
@@ -96,11 +96,14 @@ def qrCallback(uuid, status, qrcode):
|
||||
print(qr_api4)
|
||||
print(qr_api2)
|
||||
print(qr_api1)
|
||||
_send_qr_code([qr_api1, qr_api2, qr_api3, qr_api4])
|
||||
_send_qr_code([qr_api3, qr_api4, qr_api2, qr_api1])
|
||||
qr = qrcode.QRCode(border=1)
|
||||
qr.add_data(url)
|
||||
qr.make(fit=True)
|
||||
qr.print_ascii(invert=True)
|
||||
try:
|
||||
qr.print_ascii(invert=True)
|
||||
except UnicodeEncodeError:
|
||||
print("ASCII QR code printing failed due to encoding issues.")
|
||||
|
||||
|
||||
@singleton
|
||||
@@ -109,30 +112,42 @@ class WechatChannel(ChatChannel):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.receivedMsgs = ExpiredDict(60 * 60)
|
||||
self.receivedMsgs = ExpiredDict(conf().get("expires_in_seconds", 3600))
|
||||
self.auto_login_times = 0
|
||||
|
||||
def startup(self):
|
||||
try:
|
||||
itchat.instance.receivingRetryCount = 600 # 修改断线超时时间
|
||||
# login by scan QRCode
|
||||
hotReload = conf().get("hot_reload", False)
|
||||
status_path = os.path.join(get_appdata_dir(), "itchat.pkl")
|
||||
itchat.auto_login(
|
||||
enableCmdQR=2,
|
||||
hotReload=hotReload,
|
||||
statusStorageDir=status_path,
|
||||
qrCallback=qrCallback,
|
||||
exitCallback=self.exitCallback,
|
||||
loginCallback=self.loginCallback
|
||||
)
|
||||
self.user_id = itchat.instance.storageClass.userName
|
||||
self.name = itchat.instance.storageClass.nickName
|
||||
logger.info("Wechat login success, user_id: {}, nickname: {}".format(self.user_id, self.name))
|
||||
# start message listener
|
||||
itchat.run()
|
||||
time.sleep(3)
|
||||
logger.error("""[WechatChannel] 当前channel暂不可用,目前支持的channel有:
|
||||
1. terminal: 终端
|
||||
2. wechatmp: 个人公众号
|
||||
3. wechatmp_service: 企业公众号
|
||||
4. wechatcom_app: 企微自建应用
|
||||
5. dingtalk: 钉钉
|
||||
6. feishu: 飞书
|
||||
7. web: 网页
|
||||
8. wcf: wechat (需Windows环境,参考 https://github.com/zhayujie/chatgpt-on-wechat/pull/2562 )
|
||||
可修改 config.json 配置文件的 channel_type 字段进行切换""")
|
||||
|
||||
# itchat.instance.receivingRetryCount = 600 # 修改断线超时时间
|
||||
# # login by scan QRCode
|
||||
# hotReload = conf().get("hot_reload", False)
|
||||
# status_path = os.path.join(get_appdata_dir(), "itchat.pkl")
|
||||
# itchat.auto_login(
|
||||
# enableCmdQR=2,
|
||||
# hotReload=hotReload,
|
||||
# statusStorageDir=status_path,
|
||||
# qrCallback=qrCallback,
|
||||
# exitCallback=self.exitCallback,
|
||||
# loginCallback=self.loginCallback
|
||||
# )
|
||||
# self.user_id = itchat.instance.storageClass.userName
|
||||
# self.name = itchat.instance.storageClass.nickName
|
||||
# logger.info("Wechat login success, user_id: {}, nickname: {}".format(self.user_id, self.name))
|
||||
# # start message listener
|
||||
# itchat.run()
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
logger.exception(e)
|
||||
|
||||
def exitCallback(self):
|
||||
try:
|
||||
@@ -202,7 +217,7 @@ class WechatChannel(ChatChannel):
|
||||
logger.debug(f"[WX]receive attachment msg, file_name={cmsg.content}")
|
||||
else:
|
||||
logger.debug("[WX]receive group msg: {}".format(cmsg.content))
|
||||
context = self._compose_context(cmsg.ctype, cmsg.content, isgroup=True, msg=cmsg)
|
||||
context = self._compose_context(cmsg.ctype, cmsg.content, isgroup=True, msg=cmsg, no_need_at=conf().get("no_need_at", False))
|
||||
if context:
|
||||
self.produce(context)
|
||||
|
||||
@@ -210,9 +225,11 @@ class WechatChannel(ChatChannel):
|
||||
def send(self, reply: Reply, context: Context):
|
||||
receiver = context["receiver"]
|
||||
if reply.type == ReplyType.TEXT:
|
||||
reply.content = remove_markdown_symbol(reply.content)
|
||||
itchat.send(reply.content, toUserName=receiver)
|
||||
logger.info("[WX] sendMsg={}, receiver={}".format(reply, receiver))
|
||||
elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO:
|
||||
reply.content = remove_markdown_symbol(reply.content)
|
||||
itchat.send(reply.content, toUserName=receiver)
|
||||
logger.info("[WX] sendMsg={}, receiver={}".format(reply, receiver))
|
||||
elif reply.type == ReplyType.VOICE:
|
||||
@@ -229,6 +246,12 @@ class WechatChannel(ChatChannel):
|
||||
image_storage.write(block)
|
||||
logger.info(f"[WX] download image success, size={size}, img_url={img_url}")
|
||||
image_storage.seek(0)
|
||||
if ".webp" in img_url:
|
||||
try:
|
||||
image_storage = convert_webp_to_png(image_storage)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to convert image: {e}")
|
||||
return
|
||||
itchat.send_image(image_storage, toUserName=receiver)
|
||||
logger.info("[WX] sendImage url={}, receiver={}".format(img_url, receiver))
|
||||
elif reply.type == ReplyType.IMAGE: # 从文件读取图片
|
||||
@@ -266,6 +289,7 @@ def _send_login_success():
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
|
||||
def _send_logout():
|
||||
try:
|
||||
from common.linkai_client import chat_client
|
||||
@@ -274,6 +298,7 @@ def _send_logout():
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
|
||||
def _send_qr_code(qrcode_list: list):
|
||||
try:
|
||||
from common.linkai_client import chat_client
|
||||
@@ -281,3 +306,4 @@ def _send_qr_code(qrcode_list: list):
|
||||
chat_client.send_qrcode(qrcode_list)
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
|
||||
@@ -14,6 +14,11 @@ class WechatMessage(ChatMessage):
|
||||
self.create_time = itchat_msg["CreateTime"]
|
||||
self.is_group = is_group
|
||||
|
||||
notes_join_group = ["加入群聊", "加入了群聊", "invited", "joined"] # 可通过添加对应语言的加入群聊通知中的关键词适配更多
|
||||
notes_bot_join_group = ["邀请你", "invited you", "You've joined", "你通过扫描"]
|
||||
notes_exit_group = ["移出了群聊", "removed"] # 可通过添加对应语言的踢出群聊通知中的关键词适配更多
|
||||
notes_patpat = ["拍了拍我", "tickled my", "tickled me"] # 可通过添加对应语言的拍一拍通知中的关键词适配更多
|
||||
|
||||
if itchat_msg["Type"] == TEXT:
|
||||
self.ctype = ContextType.TEXT
|
||||
self.content = itchat_msg["Text"]
|
||||
@@ -26,30 +31,47 @@ class WechatMessage(ChatMessage):
|
||||
self.content = TmpDir().path() + itchat_msg["FileName"] # content直接存临时目录路径
|
||||
self._prepare_fn = lambda: itchat_msg.download(self.content)
|
||||
elif itchat_msg["Type"] == NOTE and itchat_msg["MsgType"] == 10000:
|
||||
if is_group and ("加入群聊" in itchat_msg["Content"] or "加入了群聊" in itchat_msg["Content"]):
|
||||
if is_group:
|
||||
if any(note_bot_join_group in itchat_msg["Content"] for note_bot_join_group in notes_bot_join_group): # 邀请机器人加入群聊
|
||||
logger.warn("机器人加入群聊消息,不处理~")
|
||||
pass
|
||||
elif any(note_join_group in itchat_msg["Content"] for note_join_group in notes_join_group): # 若有任何在notes_join_group列表中的字符串出现在NOTE中
|
||||
# 这里只能得到nickname, actual_user_id还是机器人的id
|
||||
if "加入了群聊" in itchat_msg["Content"]:
|
||||
self.ctype = ContextType.JOIN_GROUP
|
||||
self.content = itchat_msg["Content"]
|
||||
self.actual_user_nickname = re.findall(r"\"(.*?)\"", itchat_msg["Content"])[-1]
|
||||
elif "加入群聊" in itchat_msg["Content"]:
|
||||
self.ctype = ContextType.JOIN_GROUP
|
||||
if "加入群聊" not in itchat_msg["Content"]:
|
||||
self.ctype = ContextType.JOIN_GROUP
|
||||
self.content = itchat_msg["Content"]
|
||||
if "invited" in itchat_msg["Content"]: # 匹配英文信息
|
||||
self.actual_user_nickname = re.findall(r'invited\s+(.+?)\s+to\s+the\s+group\s+chat', itchat_msg["Content"])[0]
|
||||
elif "joined" in itchat_msg["Content"]: # 匹配通过二维码加入的英文信息
|
||||
self.actual_user_nickname = re.findall(r'"(.*?)" joined the group chat via the QR Code shared by', itchat_msg["Content"])[0]
|
||||
elif "加入了群聊" in itchat_msg["Content"]:
|
||||
self.actual_user_nickname = re.findall(r"\"(.*?)\"", itchat_msg["Content"])[-1]
|
||||
elif "加入群聊" in itchat_msg["Content"]:
|
||||
self.ctype = ContextType.JOIN_GROUP
|
||||
self.content = itchat_msg["Content"]
|
||||
self.actual_user_nickname = re.findall(r"\"(.*?)\"", itchat_msg["Content"])[0]
|
||||
|
||||
elif any(note_exit_group in itchat_msg["Content"] for note_exit_group in notes_exit_group): # 若有任何在notes_exit_group列表中的字符串出现在NOTE中
|
||||
self.ctype = ContextType.EXIT_GROUP
|
||||
self.content = itchat_msg["Content"]
|
||||
self.actual_user_nickname = re.findall(r"\"(.*?)\"", itchat_msg["Content"])[0]
|
||||
|
||||
elif is_group and ("移出了群聊" in itchat_msg["Content"]):
|
||||
self.ctype = ContextType.EXIT_GROUP
|
||||
self.content = itchat_msg["Content"]
|
||||
self.actual_user_nickname = re.findall(r"\"(.*?)\"", itchat_msg["Content"])[0]
|
||||
elif any(note_patpat in itchat_msg["Content"] for note_patpat in notes_patpat): # 若有任何在notes_patpat列表中的字符串出现在NOTE中:
|
||||
self.ctype = ContextType.PATPAT
|
||||
self.content = itchat_msg["Content"]
|
||||
if "拍了拍我" in itchat_msg["Content"]: # 识别中文
|
||||
self.actual_user_nickname = re.findall(r"\"(.*?)\"", itchat_msg["Content"])[0]
|
||||
elif "tickled my" in itchat_msg["Content"] or "tickled me" in itchat_msg["Content"]:
|
||||
self.actual_user_nickname = re.findall(r'^(.*?)(?:tickled my|tickled me)', itchat_msg["Content"])[0]
|
||||
else:
|
||||
raise NotImplementedError("Unsupported note message: " + itchat_msg["Content"])
|
||||
|
||||
elif "你已添加了" in itchat_msg["Content"]: #通过好友请求
|
||||
self.ctype = ContextType.ACCEPT_FRIEND
|
||||
self.content = itchat_msg["Content"]
|
||||
elif "拍了拍我" in itchat_msg["Content"]:
|
||||
elif any(note_patpat in itchat_msg["Content"] for note_patpat in notes_patpat): # 若有任何在notes_patpat列表中的字符串出现在NOTE中:
|
||||
self.ctype = ContextType.PATPAT
|
||||
self.content = itchat_msg["Content"]
|
||||
if is_group:
|
||||
self.actual_user_nickname = re.findall(r"\"(.*?)\"", itchat_msg["Content"])[0]
|
||||
else:
|
||||
raise NotImplementedError("Unsupported note message: " + itchat_msg["Content"])
|
||||
elif itchat_msg["Type"] == ATTACHMENT:
|
||||
|
||||
@@ -17,7 +17,7 @@ from channel.wechatcom.wechatcomapp_client import WechatComAppClient
|
||||
from channel.wechatcom.wechatcomapp_message import WechatComAppMessage
|
||||
from common.log import logger
|
||||
from common.singleton import singleton
|
||||
from common.utils import compress_imgfile, fsize, split_string_by_utf8_length
|
||||
from common.utils import compress_imgfile, fsize, split_string_by_utf8_length, convert_webp_to_png, remove_markdown_symbol
|
||||
from config import conf, subscribe_msg
|
||||
from voice.audio_convert import any_to_amr, split_audio
|
||||
|
||||
@@ -44,7 +44,7 @@ class WechatComAppChannel(ChatChannel):
|
||||
|
||||
def startup(self):
|
||||
# start message listener
|
||||
urls = ("/wxcomapp", "channel.wechatcom.wechatcomapp_channel.Query")
|
||||
urls = ("/wxcomapp/?", "channel.wechatcom.wechatcomapp_channel.Query")
|
||||
app = web.application(urls, globals(), autoreload=False)
|
||||
port = conf().get("wechatcomapp_port", 9898)
|
||||
web.httpserver.runsimple(app.wsgifunc(), ("0.0.0.0", port))
|
||||
@@ -52,7 +52,7 @@ class WechatComAppChannel(ChatChannel):
|
||||
def send(self, reply: Reply, context: Context):
|
||||
receiver = context["receiver"]
|
||||
if reply.type in [ReplyType.TEXT, ReplyType.ERROR, ReplyType.INFO]:
|
||||
reply_text = reply.content
|
||||
reply_text = remove_markdown_symbol(reply.content)
|
||||
texts = split_string_by_utf8_length(reply_text, MAX_UTF8_LEN)
|
||||
if len(texts) > 1:
|
||||
logger.info("[wechatcom] text too long, split into {} parts".format(len(texts)))
|
||||
@@ -99,6 +99,12 @@ class WechatComAppChannel(ChatChannel):
|
||||
image_storage = compress_imgfile(image_storage, 10 * 1024 * 1024 - 1)
|
||||
logger.info("[wechatcom] image compressed, sz={}".format(fsize(image_storage)))
|
||||
image_storage.seek(0)
|
||||
if ".webp" in img_url:
|
||||
try:
|
||||
image_storage = convert_webp_to_png(image_storage)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to convert image: {e}")
|
||||
return
|
||||
try:
|
||||
response = self.client.media.upload("image", image_storage)
|
||||
logger.debug("[wechatcom] upload image response: {}".format(response))
|
||||
@@ -156,11 +162,12 @@ class Query:
|
||||
logger.debug("[wechatcom] receive message: {}, msg= {}".format(message, msg))
|
||||
if msg.type == "event":
|
||||
if msg.event == "subscribe":
|
||||
reply_content = subscribe_msg()
|
||||
if reply_content:
|
||||
reply = create_reply(reply_content, msg).render()
|
||||
res = channel.crypto.encrypt_message(reply, nonce, timestamp)
|
||||
return res
|
||||
pass
|
||||
# reply_content = subscribe_msg()
|
||||
# if reply_content:
|
||||
# reply = create_reply(reply_content, msg).render()
|
||||
# res = channel.crypto.encrypt_message(reply, nonce, timestamp)
|
||||
# return res
|
||||
else:
|
||||
try:
|
||||
wechatcom_msg = WechatComAppMessage(msg, client=channel.client)
|
||||
|
||||
@@ -1,21 +1,43 @@
|
||||
# wechatcomapp_client.py
|
||||
import threading
|
||||
import time
|
||||
|
||||
from wechatpy.enterprise import WeChatClient
|
||||
|
||||
|
||||
class WechatComAppClient(WeChatClient):
|
||||
def __init__(self, corp_id, secret, access_token=None, session=None, timeout=None, auto_retry=True):
|
||||
super(WechatComAppClient, self).__init__(corp_id, secret, access_token, session, timeout, auto_retry)
|
||||
self.fetch_access_token_lock = threading.Lock()
|
||||
self._active_refresh()
|
||||
|
||||
def _active_refresh(self):
|
||||
"""启动主动刷新的后台线程"""
|
||||
def refresh_loop():
|
||||
while True:
|
||||
now = time.time()
|
||||
expires_at = self.session.get(f"{self.corp_id}_expires_at", 0)
|
||||
|
||||
# 提前10分钟刷新(600秒)
|
||||
if expires_at - now < 600:
|
||||
with self.fetch_access_token_lock:
|
||||
# 双重检查避免重复刷新
|
||||
if self.session.get(f"{self.corp_id}_expires_at", 0) - time.time() < 600:
|
||||
super(WechatComAppClient, self).fetch_access_token()
|
||||
# 每次检查间隔60秒
|
||||
time.sleep(60)
|
||||
|
||||
# 启动守护线程
|
||||
refresh_thread = threading.Thread(
|
||||
target=refresh_loop,
|
||||
daemon=True,
|
||||
name="wechatcom_token_refresh_thread"
|
||||
)
|
||||
refresh_thread.start()
|
||||
|
||||
def fetch_access_token(self): # 重载父类方法,加锁避免多线程重复获取access_token
|
||||
def fetch_access_token(self):
|
||||
with self.fetch_access_token_lock:
|
||||
access_token = self.session.get(self.access_token_key)
|
||||
if access_token:
|
||||
if not self.expires_at:
|
||||
return access_token
|
||||
timestamp = time.time()
|
||||
if self.expires_at - timestamp > 60:
|
||||
return access_token
|
||||
return super().fetch_access_token()
|
||||
expires_at = self.session.get(f"{self.corp_id}_expires_at", 0)
|
||||
|
||||
if access_token and expires_at > time.time() + 60:
|
||||
return access_token
|
||||
return super().fetch_access_token()
|
||||
@@ -19,7 +19,7 @@ from channel.wechatmp.common import *
|
||||
from channel.wechatmp.wechatmp_client import WechatMPClient
|
||||
from common.log import logger
|
||||
from common.singleton import singleton
|
||||
from common.utils import split_string_by_utf8_length
|
||||
from common.utils import split_string_by_utf8_length, remove_markdown_symbol
|
||||
from config import conf
|
||||
from voice.audio_convert import any_to_mp3, split_audio
|
||||
|
||||
@@ -81,7 +81,7 @@ class WechatMPChannel(ChatChannel):
|
||||
receiver = context["receiver"]
|
||||
if self.passive_reply:
|
||||
if reply.type == ReplyType.TEXT or reply.type == ReplyType.INFO or reply.type == ReplyType.ERROR:
|
||||
reply_text = reply.content
|
||||
reply_text = remove_markdown_symbol(reply.content)
|
||||
logger.info("[wechatmp] text cached, receiver {}\n{}".format(receiver, reply_text))
|
||||
self.cache_dict[receiver].append(("text", reply_text))
|
||||
elif reply.type == ReplyType.VOICE:
|
||||
@@ -140,6 +140,42 @@ class WechatMPChannel(ChatChannel):
|
||||
media_id = response["media_id"]
|
||||
logger.info("[wechatmp] image uploaded, receiver {}, media_id {}".format(receiver, media_id))
|
||||
self.cache_dict[receiver].append(("image", media_id))
|
||||
elif reply.type == ReplyType.VIDEO_URL: # 从网络下载视频
|
||||
video_url = reply.content
|
||||
video_res = requests.get(video_url, stream=True)
|
||||
video_storage = io.BytesIO()
|
||||
for block in video_res.iter_content(1024):
|
||||
video_storage.write(block)
|
||||
video_storage.seek(0)
|
||||
video_type = 'mp4'
|
||||
filename = receiver + "-" + str(context["msg"].msg_id) + "." + video_type
|
||||
content_type = "video/" + video_type
|
||||
try:
|
||||
response = self.client.material.add("video", (filename, video_storage, content_type))
|
||||
logger.debug("[wechatmp] upload video response: {}".format(response))
|
||||
except WeChatClientException as e:
|
||||
logger.error("[wechatmp] upload video failed: {}".format(e))
|
||||
return
|
||||
media_id = response["media_id"]
|
||||
logger.info("[wechatmp] video uploaded, receiver {}, media_id {}".format(receiver, media_id))
|
||||
self.cache_dict[receiver].append(("video", media_id))
|
||||
|
||||
elif reply.type == ReplyType.VIDEO: # 从文件读取视频
|
||||
video_storage = reply.content
|
||||
video_storage.seek(0)
|
||||
video_type = 'mp4'
|
||||
filename = receiver + "-" + str(context["msg"].msg_id) + "." + video_type
|
||||
content_type = "video/" + video_type
|
||||
try:
|
||||
response = self.client.material.add("video", (filename, video_storage, content_type))
|
||||
logger.debug("[wechatmp] upload video response: {}".format(response))
|
||||
except WeChatClientException as e:
|
||||
logger.error("[wechatmp] upload video failed: {}".format(e))
|
||||
return
|
||||
media_id = response["media_id"]
|
||||
logger.info("[wechatmp] video uploaded, receiver {}, media_id {}".format(receiver, media_id))
|
||||
self.cache_dict[receiver].append(("video", media_id))
|
||||
|
||||
else:
|
||||
if reply.type == ReplyType.TEXT or reply.type == ReplyType.INFO or reply.type == ReplyType.ERROR:
|
||||
reply_text = reply.content
|
||||
@@ -222,6 +258,38 @@ class WechatMPChannel(ChatChannel):
|
||||
return
|
||||
self.client.message.send_image(receiver, response["media_id"])
|
||||
logger.info("[wechatmp] Do send image to {}".format(receiver))
|
||||
elif reply.type == ReplyType.VIDEO_URL: # 从网络下载视频
|
||||
video_url = reply.content
|
||||
video_res = requests.get(video_url, stream=True)
|
||||
video_storage = io.BytesIO()
|
||||
for block in video_res.iter_content(1024):
|
||||
video_storage.write(block)
|
||||
video_storage.seek(0)
|
||||
video_type = 'mp4'
|
||||
filename = receiver + "-" + str(context["msg"].msg_id) + "." + video_type
|
||||
content_type = "video/" + video_type
|
||||
try:
|
||||
response = self.client.media.upload("video", (filename, video_storage, content_type))
|
||||
logger.debug("[wechatmp] upload video response: {}".format(response))
|
||||
except WeChatClientException as e:
|
||||
logger.error("[wechatmp] upload video failed: {}".format(e))
|
||||
return
|
||||
self.client.message.send_video(receiver, response["media_id"])
|
||||
logger.info("[wechatmp] Do send video to {}".format(receiver))
|
||||
elif reply.type == ReplyType.VIDEO: # 从文件读取视频
|
||||
video_storage = reply.content
|
||||
video_storage.seek(0)
|
||||
video_type = 'mp4'
|
||||
filename = receiver + "-" + str(context["msg"].msg_id) + "." + video_type
|
||||
content_type = "video/" + video_type
|
||||
try:
|
||||
response = self.client.media.upload("video", (filename, video_storage, content_type))
|
||||
logger.debug("[wechatmp] upload video response: {}".format(response))
|
||||
except WeChatClientException as e:
|
||||
logger.error("[wechatmp] upload video failed: {}".format(e))
|
||||
return
|
||||
self.client.message.send_video(receiver, response["media_id"])
|
||||
logger.info("[wechatmp] Do send video to {}".format(receiver))
|
||||
return
|
||||
|
||||
def _success_callback(self, session_id, context, **kwargs): # 线程异常结束时的回调函数
|
||||
|
||||
110
common/const.py
@@ -1,28 +1,118 @@
|
||||
# bot_type
|
||||
OPEN_AI = "openAI"
|
||||
CHATGPT = "chatGPT"
|
||||
BAIDU = "baidu"
|
||||
BAIDU = "baidu" # 百度文心一言模型
|
||||
XUNFEI = "xunfei"
|
||||
CHATGPTONAZURE = "chatGPTOnAzure"
|
||||
LINKAI = "linkai"
|
||||
CLAUDEAI = "claude"
|
||||
QWEN = "qwen"
|
||||
GEMINI = "gemini"
|
||||
ZHIPU_AI = "glm-4"
|
||||
CLAUDEAI = "claude" # 使用cookie的历史模型
|
||||
CLAUDEAPI= "claudeAPI" # 通过Claude api调用模型
|
||||
QWEN = "qwen" # 旧版通义模型
|
||||
QWEN_DASHSCOPE = "dashscope" # 通义新版sdk和api key
|
||||
|
||||
|
||||
GEMINI = "gemini" # gemini-1.0-pro
|
||||
ZHIPU_AI = "glm-4"
|
||||
MOONSHOT = "moonshot"
|
||||
MiniMax = "minimax"
|
||||
MODELSCOPE = "modelscope"
|
||||
|
||||
# model
|
||||
CLAUDE3 = "claude-3-opus-20240229"
|
||||
GPT35 = "gpt-3.5-turbo"
|
||||
GPT4 = "gpt-4"
|
||||
GPT4_TURBO_PREVIEW = "gpt-4-0125-preview"
|
||||
GPT35_0125 = "gpt-3.5-turbo-0125"
|
||||
GPT35_1106 = "gpt-3.5-turbo-1106"
|
||||
|
||||
GPT_4o = "gpt-4o"
|
||||
GPT_4O_0806 = "gpt-4o-2024-08-06"
|
||||
GPT4_TURBO = "gpt-4-turbo"
|
||||
GPT4_TURBO_PREVIEW = "gpt-4-turbo-preview"
|
||||
GPT4_TURBO_04_09 = "gpt-4-turbo-2024-04-09"
|
||||
GPT4_TURBO_01_25 = "gpt-4-0125-preview"
|
||||
GPT4_TURBO_11_06 = "gpt-4-1106-preview"
|
||||
GPT4_VISION_PREVIEW = "gpt-4-vision-preview"
|
||||
|
||||
GPT4 = "gpt-4"
|
||||
GPT_4o_MINI = "gpt-4o-mini"
|
||||
GPT4_32k = "gpt-4-32k"
|
||||
GPT4_06_13 = "gpt-4-0613"
|
||||
GPT4_32k_06_13 = "gpt-4-32k-0613"
|
||||
GPT_41 = "gpt-4.1"
|
||||
GPT_41_MINI = "gpt-4.1-mini"
|
||||
GPT_41_NANO = "gpt-4.1-nano"
|
||||
|
||||
O1 = "o1-preview"
|
||||
O1_MINI = "o1-mini"
|
||||
|
||||
WHISPER_1 = "whisper-1"
|
||||
TTS_1 = "tts-1"
|
||||
TTS_1_HD = "tts-1-hd"
|
||||
|
||||
MODEL_LIST = ["gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4", "wenxin", "wenxin-4", "xunfei", "claude", "gpt-4-turbo",
|
||||
"gpt-4-turbo-preview", "gpt-4-1106-preview", GPT4_TURBO_PREVIEW, QWEN, GEMINI, ZHIPU_AI]
|
||||
WEN_XIN = "wenxin"
|
||||
WEN_XIN_4 = "wenxin-4"
|
||||
|
||||
QWEN_TURBO = "qwen-turbo"
|
||||
QWEN_PLUS = "qwen-plus"
|
||||
QWEN_MAX = "qwen-max"
|
||||
|
||||
LINKAI_35 = "linkai-3.5"
|
||||
LINKAI_4_TURBO = "linkai-4-turbo"
|
||||
LINKAI_4o = "linkai-4o"
|
||||
|
||||
GEMINI_PRO = "gemini-1.0-pro"
|
||||
GEMINI_15_flash = "gemini-1.5-flash"
|
||||
GEMINI_15_PRO = "gemini-1.5-pro"
|
||||
GEMINI_20_flash_exp = "gemini-2.0-flash-exp"
|
||||
|
||||
|
||||
GLM_4 = "glm-4"
|
||||
GLM_4_PLUS = "glm-4-plus"
|
||||
GLM_4_flash = "glm-4-flash"
|
||||
GLM_4_LONG = "glm-4-long"
|
||||
GLM_4_ALLTOOLS = "glm-4-alltools"
|
||||
GLM_4_0520 = "glm-4-0520"
|
||||
GLM_4_AIR = "glm-4-air"
|
||||
GLM_4_AIRX = "glm-4-airx"
|
||||
|
||||
|
||||
CLAUDE_3_OPUS = "claude-3-opus-latest"
|
||||
CLAUDE_3_OPUS_0229 = "claude-3-opus-20240229"
|
||||
CLAUDE_35_SONNET = "claude-3-5-sonnet-latest" # 带 latest 标签的模型名称,会不断更新指向最新发布的模型
|
||||
CLAUDE_35_SONNET_1022 = "claude-3-5-sonnet-20241022" # 带具体日期的模型名称,会固定为该日期发布的模型
|
||||
CLAUDE_35_SONNET_0620 = "claude-3-5-sonnet-20240620"
|
||||
CLAUDE_3_SONNET = "claude-3-sonnet-20240229"
|
||||
CLAUDE_3_HAIKU = "claude-3-haiku-20240307"
|
||||
CLAUDE_4_SONNET = "claude-sonnet-4-0"
|
||||
CLAUDE_4_OPUS = "claude-opus-4-0"
|
||||
|
||||
DEEPSEEK_CHAT = "deepseek-chat" # DeepSeek-V3对话模型
|
||||
DEEPSEEK_REASONER = "deepseek-reasoner" # DeepSeek-R1模型
|
||||
|
||||
GITEE_AI_MODEL_LIST = ["Yi-34B-Chat", "InternVL2-8B", "deepseek-coder-33B-instruct", "InternVL2.5-26B", "Qwen2-VL-72B", "Qwen2.5-32B-Instruct", "glm-4-9b-chat", "codegeex4-all-9b", "Qwen2.5-Coder-32B-Instruct", "Qwen2.5-72B-Instruct", "Qwen2.5-7B-Instruct", "Qwen2-72B-Instruct", "Qwen2-7B-Instruct", "code-raccoon-v1", "Qwen2.5-14B-Instruct"]
|
||||
|
||||
MODELSCOPE_MODEL_LIST = ["LLM-Research/c4ai-command-r-plus-08-2024","mistralai/Mistral-Small-Instruct-2409","mistralai/Ministral-8B-Instruct-2410","mistralai/Mistral-Large-Instruct-2407",
|
||||
"Qwen/Qwen2.5-Coder-32B-Instruct","Qwen/Qwen2.5-Coder-14B-Instruct","Qwen/Qwen2.5-Coder-7B-Instruct","Qwen/Qwen2.5-72B-Instruct","Qwen/Qwen2.5-32B-Instruct","Qwen/Qwen2.5-14B-Instruct","Qwen/Qwen2.5-7B-Instruct","Qwen/QwQ-32B-Preview",
|
||||
"LLM-Research/Llama-3.3-70B-Instruct","opencompass/CompassJudger-1-32B-Instruct","Qwen/QVQ-72B-Preview","LLM-Research/Meta-Llama-3.1-405B-Instruct","LLM-Research/Meta-Llama-3.1-8B-Instruct","Qwen/Qwen2-VL-7B-Instruct","LLM-Research/Meta-Llama-3.1-70B-Instruct",
|
||||
"Qwen/Qwen2.5-14B-Instruct-1M","Qwen/Qwen2.5-7B-Instruct-1M","Qwen/Qwen2.5-VL-3B-Instruct","Qwen/Qwen2.5-VL-7B-Instruct","Qwen/Qwen2.5-VL-72B-Instruct","deepseek-ai/DeepSeek-R1-Distill-Llama-70B","deepseek-ai/DeepSeek-R1-Distill-Llama-8B","deepseek-ai/DeepSeek-R1-Distill-Qwen-32B",
|
||||
"deepseek-ai/DeepSeek-R1-Distill-Qwen-14B","deepseek-ai/DeepSeek-R1-Distill-Qwen-7B","deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B","deepseek-ai/DeepSeek-R1","deepseek-ai/DeepSeek-V3","Qwen/QwQ-32B"]
|
||||
|
||||
MODEL_LIST = [
|
||||
GPT35, GPT35_0125, GPT35_1106, "gpt-3.5-turbo-16k",
|
||||
GPT_41, GPT_41_MINI, GPT_41_NANO, O1, O1_MINI, GPT_4o, GPT_4O_0806, GPT_4o_MINI, GPT4_TURBO, GPT4_TURBO_PREVIEW, GPT4_TURBO_01_25, GPT4_TURBO_11_06, GPT4, GPT4_32k, GPT4_06_13, GPT4_32k_06_13,
|
||||
WEN_XIN, WEN_XIN_4,
|
||||
XUNFEI,
|
||||
ZHIPU_AI, GLM_4, GLM_4_PLUS, GLM_4_flash, GLM_4_LONG, GLM_4_ALLTOOLS, GLM_4_0520, GLM_4_AIR, GLM_4_AIRX,
|
||||
MOONSHOT, MiniMax,
|
||||
GEMINI, GEMINI_PRO, GEMINI_15_flash, GEMINI_15_PRO,GEMINI_20_flash_exp,
|
||||
CLAUDE_4_OPUS, CLAUDE_4_SONNET, CLAUDE_3_OPUS, CLAUDE_3_OPUS_0229, CLAUDE_35_SONNET, CLAUDE_35_SONNET_1022, CLAUDE_35_SONNET_0620, CLAUDE_3_SONNET, CLAUDE_3_HAIKU, "claude", "claude-3-haiku", "claude-3-sonnet", "claude-3-opus", "claude-3.5-sonnet",
|
||||
"moonshot-v1-8k", "moonshot-v1-32k", "moonshot-v1-128k",
|
||||
QWEN, QWEN_TURBO, QWEN_PLUS, QWEN_MAX,
|
||||
LINKAI_35, LINKAI_4_TURBO, LINKAI_4o,
|
||||
DEEPSEEK_CHAT, DEEPSEEK_REASONER,
|
||||
MODELSCOPE
|
||||
]
|
||||
|
||||
MODEL_LIST = MODEL_LIST + GITEE_AI_MODEL_LIST + MODELSCOPE_MODEL_LIST
|
||||
# channel
|
||||
FEISHU = "feishu"
|
||||
DINGTALK = "dingtalk"
|
||||
DINGTALK = "dingtalk"
|
||||
|
||||
@@ -2,12 +2,14 @@ from bridge.context import Context, ContextType
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from common.log import logger
|
||||
from linkai import LinkAIClient, PushMsg
|
||||
from config import conf, pconf, plugin_config
|
||||
from config import conf, pconf, plugin_config, available_setting, write_plugin_config
|
||||
from plugins import PluginManager
|
||||
import time
|
||||
|
||||
|
||||
chat_client: LinkAIClient
|
||||
|
||||
|
||||
class ChatClient(LinkAIClient):
|
||||
def __init__(self, api_key, host, channel):
|
||||
super().__init__(api_key, host)
|
||||
@@ -27,29 +29,82 @@ class ChatClient(LinkAIClient):
|
||||
def on_config(self, config: dict):
|
||||
if not self.client_id:
|
||||
return
|
||||
logger.info(f"从控制台加载配置: {config}")
|
||||
logger.info(f"[LinkAI] 从客户端管理加载远程配置: {config}")
|
||||
if config.get("enabled") != "Y":
|
||||
return
|
||||
|
||||
local_config = conf()
|
||||
for key in local_config.keys():
|
||||
if config.get(key) is not None:
|
||||
for key in config.keys():
|
||||
if key in available_setting and config.get(key) is not None:
|
||||
local_config[key] = config.get(key)
|
||||
if config.get("reply_voice_mode"):
|
||||
if config.get("reply_voice_mode") == "voice_reply_voice":
|
||||
# 语音配置
|
||||
reply_voice_mode = config.get("reply_voice_mode")
|
||||
if reply_voice_mode:
|
||||
if reply_voice_mode == "voice_reply_voice":
|
||||
local_config["voice_reply_voice"] = True
|
||||
elif config.get("reply_voice_mode") == "always_reply_voice":
|
||||
local_config["always_reply_voice"] = False
|
||||
elif reply_voice_mode == "always_reply_voice":
|
||||
local_config["always_reply_voice"] = True
|
||||
# if config.get("admin_password") and plugin_config["Godcmd"]:
|
||||
# plugin_config["Godcmd"]["password"] = config.get("admin_password")
|
||||
# PluginManager().instances["Godcmd"].reload()
|
||||
# if config.get("group_app_map") and pconf("linkai"):
|
||||
# local_group_map = {}
|
||||
# for mapping in config.get("group_app_map"):
|
||||
# local_group_map[mapping.get("group_name")] = mapping.get("app_code")
|
||||
# pconf("linkai")["group_app_map"] = local_group_map
|
||||
# PluginManager().instances["linkai"].reload()
|
||||
local_config["voice_reply_voice"] = True
|
||||
elif reply_voice_mode == "no_reply_voice":
|
||||
local_config["always_reply_voice"] = False
|
||||
local_config["voice_reply_voice"] = False
|
||||
|
||||
if config.get("admin_password"):
|
||||
if not pconf("Godcmd"):
|
||||
write_plugin_config({"Godcmd": {"password": config.get("admin_password"), "admin_users": []} })
|
||||
else:
|
||||
pconf("Godcmd")["password"] = config.get("admin_password")
|
||||
PluginManager().instances["GODCMD"].reload()
|
||||
|
||||
if config.get("group_app_map") and pconf("linkai"):
|
||||
local_group_map = {}
|
||||
for mapping in config.get("group_app_map"):
|
||||
local_group_map[mapping.get("group_name")] = mapping.get("app_code")
|
||||
pconf("linkai")["group_app_map"] = local_group_map
|
||||
PluginManager().instances["LINKAI"].reload()
|
||||
|
||||
if config.get("text_to_image") and config.get("text_to_image") == "midjourney" and pconf("linkai"):
|
||||
if pconf("linkai")["midjourney"]:
|
||||
pconf("linkai")["midjourney"]["enabled"] = True
|
||||
pconf("linkai")["midjourney"]["use_image_create_prefix"] = True
|
||||
elif config.get("text_to_image") and config.get("text_to_image") in ["dall-e-2", "dall-e-3"]:
|
||||
if pconf("linkai")["midjourney"]:
|
||||
pconf("linkai")["midjourney"]["use_image_create_prefix"] = False
|
||||
|
||||
|
||||
def start(channel):
|
||||
global chat_client
|
||||
chat_client = ChatClient(api_key=conf().get("linkai_api_key"),
|
||||
host="link-ai.chat", channel=channel)
|
||||
chat_client = ChatClient(api_key=conf().get("linkai_api_key"), host="", channel=channel)
|
||||
chat_client.config = _build_config()
|
||||
chat_client.start()
|
||||
time.sleep(1.5)
|
||||
if chat_client.client_id:
|
||||
logger.info("[LinkAI] 可前往控制台进行线上登录和配置:https://link-ai.tech/console/clients")
|
||||
|
||||
|
||||
def _build_config():
|
||||
local_conf = conf()
|
||||
config = {
|
||||
"linkai_app_code": local_conf.get("linkai_app_code"),
|
||||
"single_chat_prefix": local_conf.get("single_chat_prefix"),
|
||||
"single_chat_reply_prefix": local_conf.get("single_chat_reply_prefix"),
|
||||
"single_chat_reply_suffix": local_conf.get("single_chat_reply_suffix"),
|
||||
"group_chat_prefix": local_conf.get("group_chat_prefix"),
|
||||
"group_chat_reply_prefix": local_conf.get("group_chat_reply_prefix"),
|
||||
"group_chat_reply_suffix": local_conf.get("group_chat_reply_suffix"),
|
||||
"group_name_white_list": local_conf.get("group_name_white_list"),
|
||||
"nick_name_black_list": local_conf.get("nick_name_black_list"),
|
||||
"speech_recognition": "Y" if local_conf.get("speech_recognition") else "N",
|
||||
"text_to_image": local_conf.get("text_to_image"),
|
||||
"image_create_prefix": local_conf.get("image_create_prefix")
|
||||
}
|
||||
if local_conf.get("always_reply_voice"):
|
||||
config["reply_voice_mode"] = "always_reply_voice"
|
||||
elif local_conf.get("voice_reply_voice"):
|
||||
config["reply_voice_mode"] = "voice_reply_voice"
|
||||
if pconf("linkai"):
|
||||
config["group_app_map"] = pconf("linkai").get("group_app_map")
|
||||
if plugin_config.get("Godcmd"):
|
||||
config["admin_password"] = plugin_config.get("Godcmd").get("password")
|
||||
return config
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
import hashlib
|
||||
import re
|
||||
import time
|
||||
|
||||
import config
|
||||
from common.log import logger
|
||||
|
||||
@@ -10,31 +8,33 @@ def time_checker(f):
|
||||
def _time_checker(self, *args, **kwargs):
|
||||
_config = config.conf()
|
||||
chat_time_module = _config.get("chat_time_module", False)
|
||||
|
||||
if chat_time_module:
|
||||
chat_start_time = _config.get("chat_start_time", "00:00")
|
||||
chat_stopt_time = _config.get("chat_stop_time", "24:00")
|
||||
time_regex = re.compile(r"^([01]?[0-9]|2[0-4])(:)([0-5][0-9])$") # 时间匹配,包含24:00
|
||||
chat_stop_time = _config.get("chat_stop_time", "24:00")
|
||||
|
||||
starttime_format_check = time_regex.match(chat_start_time) # 检查停止时间格式
|
||||
stoptime_format_check = time_regex.match(chat_stopt_time) # 检查停止时间格式
|
||||
chat_time_check = chat_start_time < chat_stopt_time # 确定启动时间<停止时间
|
||||
time_regex = re.compile(r"^([01]?[0-9]|2[0-4])(:)([0-5][0-9])$")
|
||||
|
||||
# 时间格式检查
|
||||
if not (starttime_format_check and stoptime_format_check and chat_time_check):
|
||||
logger.warn("时间格式不正确,请在config.json中修改您的CHAT_START_TIME/CHAT_STOP_TIME,否则可能会影响您正常使用,开始({})-结束({})".format(starttime_format_check, stoptime_format_check))
|
||||
if chat_start_time > "23:59":
|
||||
logger.error("启动时间可能存在问题,请修改!")
|
||||
|
||||
# 服务时间检查
|
||||
now_time = time.strftime("%H:%M", time.localtime())
|
||||
if chat_start_time <= now_time <= chat_stopt_time: # 服务时间内,正常返回回答
|
||||
f(self, *args, **kwargs)
|
||||
if not (time_regex.match(chat_start_time) and time_regex.match(chat_stop_time)):
|
||||
logger.warning("时间格式不正确,请在config.json中修改CHAT_START_TIME/CHAT_STOP_TIME。")
|
||||
return None
|
||||
|
||||
now_time = time.strptime(time.strftime("%H:%M"), "%H:%M")
|
||||
chat_start_time = time.strptime(chat_start_time, "%H:%M")
|
||||
chat_stop_time = time.strptime(chat_stop_time, "%H:%M")
|
||||
# 结束时间小于开始时间,跨天了
|
||||
if chat_stop_time < chat_start_time and (chat_start_time <= now_time or now_time <= chat_stop_time):
|
||||
f(self, *args, **kwargs)
|
||||
# 结束大于开始时间代表,没有跨天
|
||||
elif chat_start_time < chat_stop_time and chat_start_time <= now_time <= chat_stop_time:
|
||||
f(self, *args, **kwargs)
|
||||
else:
|
||||
if args[0]["Content"] == "#更新配置": # 不在服务时间内也可以更新配置
|
||||
# 定义匹配规则,如果以 #reconf 或者 #更新配置 结尾, 非服务时间可以修改开始/结束时间并重载配置
|
||||
pattern = re.compile(r"^.*#(?:reconf|更新配置)$")
|
||||
if args and pattern.match(args[0].content):
|
||||
f(self, *args, **kwargs)
|
||||
else:
|
||||
logger.info("非服务时间内,不接受访问")
|
||||
logger.info("非服务时间内,不接受访问")
|
||||
return None
|
||||
else:
|
||||
f(self, *args, **kwargs) # 未开启时间模块则直接回答
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
import io
|
||||
import os
|
||||
import re
|
||||
from urllib.parse import urlparse
|
||||
from PIL import Image
|
||||
|
||||
from common.log import logger
|
||||
|
||||
def fsize(file):
|
||||
if isinstance(file, io.BytesIO):
|
||||
@@ -54,3 +55,24 @@ def split_string_by_utf8_length(string, max_length, max_split=0):
|
||||
def get_path_suffix(path):
|
||||
path = urlparse(path).path
|
||||
return os.path.splitext(path)[-1].lstrip('.')
|
||||
|
||||
|
||||
def convert_webp_to_png(webp_image):
|
||||
from PIL import Image
|
||||
try:
|
||||
webp_image.seek(0)
|
||||
img = Image.open(webp_image).convert("RGBA")
|
||||
png_image = io.BytesIO()
|
||||
img.save(png_image, format="PNG")
|
||||
png_image.seek(0)
|
||||
return png_image
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to convert WEBP to PNG: {e}")
|
||||
raise
|
||||
|
||||
|
||||
def remove_markdown_symbol(text: str):
|
||||
# 移除markdown格式,目前先移除**
|
||||
if not text:
|
||||
return text
|
||||
return re.sub(r'\*\*(.*?)\*\*', r'\1', text)
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
{
|
||||
"channel_type": "wx",
|
||||
"channel_type": "web",
|
||||
"model": "",
|
||||
"open_ai_api_key": "YOUR API KEY",
|
||||
"claude_api_key": "YOUR API KEY",
|
||||
"text_to_image": "dall-e-2",
|
||||
"voice_to_text": "openai",
|
||||
"text_to_voice": "openai",
|
||||
@@ -27,7 +28,7 @@
|
||||
"voice_reply_voice": false,
|
||||
"conversation_max_tokens": 2500,
|
||||
"expires_in_seconds": 3600,
|
||||
"character_desc": "你是ChatGPT, 一个由OpenAI训练的大型语言模型, 你旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。",
|
||||
"character_desc": "你是基于大语言模型的AI智能助手,旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。",
|
||||
"temperature": 0.7,
|
||||
"subscribe_msg": "感谢您的关注!\n这里是AI智能助手,可以自由对话。\n支持语音对话。\n支持图片输入。\n支持图片输出,画字开头的消息将按要求创作图片。\n支持tool、角色扮演和文字冒险等丰富的插件。\n输入{trigger_prefix}#help 查看详细指令。",
|
||||
"use_linkai": false,
|
||||
|
||||
93
config.py
@@ -4,6 +4,7 @@ import json
|
||||
import logging
|
||||
import os
|
||||
import pickle
|
||||
import copy
|
||||
|
||||
from common.log import logger
|
||||
|
||||
@@ -16,7 +17,8 @@ available_setting = {
|
||||
"open_ai_api_base": "https://api.openai.com/v1",
|
||||
"proxy": "", # openai使用的代理
|
||||
# chatgpt模型, 当use_azure_chatgpt为true时,其名称为Azure上model deployment名称
|
||||
"model": "gpt-3.5-turbo", # 还支持 gpt-4, gpt-4-turbo, wenxin, xunfei, qwen
|
||||
"model": "gpt-3.5-turbo", # 可选择: gpt-4o, pt-4o-mini, gpt-4-turbo, claude-3-sonnet, wenxin, moonshot, qwen-turbo, xunfei, glm-4, minimax, gemini等模型,全部可选模型详见common/const.py文件
|
||||
"bot_type": "", # 可选配置,使用兼容openai格式的三方服务时候,需填"chatGPT"。bot具体名称详见common/const.py文件列出的bot_type,如不填根据model名称判断,
|
||||
"use_azure_chatgpt": False, # 是否使用azure的chatgpt
|
||||
"azure_deployment_id": "", # azure 模型部署名称
|
||||
"azure_api_version": "", # azure api版本
|
||||
@@ -25,6 +27,7 @@ available_setting = {
|
||||
"single_chat_reply_prefix": "[bot] ", # 私聊时自动回复的前缀,用于区分真人
|
||||
"single_chat_reply_suffix": "", # 私聊时自动回复的后缀,\n 可以换行
|
||||
"group_chat_prefix": ["@bot"], # 群聊时包含该前缀则会触发机器人回复
|
||||
"no_need_at": False, # 群聊回复时是否不需要艾特
|
||||
"group_chat_reply_prefix": "", # 群聊时自动回复的前缀
|
||||
"group_chat_reply_suffix": "", # 群聊时自动回复的后缀,\n 可以换行
|
||||
"group_chat_keyword": [], # 群聊时包含该关键词则会触发机器人回复
|
||||
@@ -33,14 +36,21 @@ available_setting = {
|
||||
"group_name_keyword_white_list": [], # 开启自动回复的群名称关键词列表
|
||||
"group_chat_in_one_session": ["ChatGPT测试群"], # 支持会话上下文共享的群名称
|
||||
"nick_name_black_list": [], # 用户昵称黑名单
|
||||
"group_welcome_msg": "", # 配置新人进群固定欢迎语,不配置则使用随机风格欢迎
|
||||
"group_welcome_msg": "", # 配置新人进群固定欢迎语,不配置则使用随机风格欢迎
|
||||
"trigger_by_self": False, # 是否允许机器人触发
|
||||
"text_to_image": "dall-e-2", # 图片生成模型,可选 dall-e-2, dall-e-3
|
||||
# Azure OpenAI dall-e-3 配置
|
||||
"dalle3_image_style": "vivid", # 图片生成dalle3的风格,可选有 vivid, natural
|
||||
"dalle3_image_quality": "hd", # 图片生成dalle3的质量,可选有 standard, hd
|
||||
# Azure OpenAI DALL-E API 配置, 当use_azure_chatgpt为true时,用于将文字回复的资源和Dall-E的资源分开.
|
||||
"azure_openai_dalle_api_base": "", # [可选] azure openai 用于回复图片的资源 endpoint,默认使用 open_ai_api_base
|
||||
"azure_openai_dalle_api_key": "", # [可选] azure openai 用于回复图片的资源 key,默认使用 open_ai_api_key
|
||||
"azure_openai_dalle_deployment_id":"", # [可选] azure openai 用于回复图片的资源 deployment id,默认使用 text_to_image
|
||||
"image_proxy": True, # 是否需要图片代理,国内访问LinkAI时需要
|
||||
"image_create_prefix": ["画", "看", "找"], # 开启图片回复的前缀
|
||||
"concurrency_in_session": 1, # 同一会话最多有多少条消息在处理中,大于1可能乱序
|
||||
"image_create_size": "256x256", # 图片大小,可选有 256x256, 512x512, 1024x1024 (dall-e-3默认为1024x1024)
|
||||
"group_chat_exit_group": False,
|
||||
"group_chat_exit_group": False,
|
||||
# chatgpt会话参数
|
||||
"expires_in_seconds": 3600, # 无操作会话的过期时间
|
||||
# 人格描述
|
||||
@@ -60,19 +70,26 @@ available_setting = {
|
||||
"baidu_wenxin_model": "eb-instant", # 默认使用ERNIE-Bot-turbo模型
|
||||
"baidu_wenxin_api_key": "", # Baidu api key
|
||||
"baidu_wenxin_secret_key": "", # Baidu secret key
|
||||
"baidu_wenxin_prompt_enabled": False, # Enable prompt if you are using ernie character model
|
||||
# 讯飞星火API
|
||||
"xunfei_app_id": "", # 讯飞应用ID
|
||||
"xunfei_api_key": "", # 讯飞 API key
|
||||
"xunfei_api_secret": "", # 讯飞 API secret
|
||||
"xunfei_domain": "", # 讯飞模型对应的domain参数,Spark4.0 Ultra为 4.0Ultra,其他模型详见: https://www.xfyun.cn/doc/spark/Web.html
|
||||
"xunfei_spark_url": "", # 讯飞模型对应的请求地址,Spark4.0 Ultra为 wss://spark-api.xf-yun.com/v4.0/chat,其他模型参考详见: https://www.xfyun.cn/doc/spark/Web.html
|
||||
# claude 配置
|
||||
"claude_api_cookie": "",
|
||||
"claude_uuid": "",
|
||||
# claude api key
|
||||
"claude_api_key": "",
|
||||
# 通义千问API, 获取方式查看文档 https://help.aliyun.com/document_detail/2587494.html
|
||||
"qwen_access_key_id": "",
|
||||
"qwen_access_key_secret": "",
|
||||
"qwen_agent_key": "",
|
||||
"qwen_app_id": "",
|
||||
"qwen_node_id": "", # 流程编排模型用到的id,如果没有用到qwen_node_id,请务必保持为空字符串
|
||||
# 阿里灵积(通义新版sdk)模型api key
|
||||
"dashscope_api_key": "",
|
||||
# Google Gemini Api Key
|
||||
"gemini_api_key": "",
|
||||
# wework的通用配置
|
||||
@@ -82,8 +99,8 @@ available_setting = {
|
||||
"group_speech_recognition": False, # 是否开启群组语音识别
|
||||
"voice_reply_voice": False, # 是否使用语音回复语音,需要设置对应语音合成引擎的api key
|
||||
"always_reply_voice": False, # 是否一直使用语音回复
|
||||
"voice_to_text": "openai", # 语音识别引擎,支持openai,baidu,google,azure
|
||||
"text_to_voice": "openai", # 语音合成引擎,支持openai,baidu,google,pytts(offline),azure,elevenlabs
|
||||
"voice_to_text": "openai", # 语音识别引擎,支持openai,baidu,google,azure,xunfei,ali
|
||||
"text_to_voice": "openai", # 语音合成引擎,支持openai,baidu,google,azure,xunfei,ali,pytts(offline),elevenlabs,edge(online)
|
||||
"text_to_voice_model": "tts-1",
|
||||
"tts_voice_id": "alloy",
|
||||
# baidu 语音api配置, 使用百度语音识别和语音合成时需要
|
||||
@@ -91,13 +108,13 @@ available_setting = {
|
||||
"baidu_api_key": "",
|
||||
"baidu_secret_key": "",
|
||||
# 1536普通话(支持简单的英文识别) 1737英语 1637粤语 1837四川话 1936普通话远场
|
||||
"baidu_dev_pid": "1536",
|
||||
"baidu_dev_pid": 1536,
|
||||
# azure 语音api配置, 使用azure语音识别和语音合成时需要
|
||||
"azure_voice_api_key": "",
|
||||
"azure_voice_region": "japaneast",
|
||||
# elevenlabs 语音api配置
|
||||
"xi_api_key": "", #获取ap的方法可以参考https://docs.elevenlabs.io/api-reference/quick-start/authentication
|
||||
"xi_voice_id": "", #ElevenLabs提供了9种英式、美式等英语发音id,分别是“Adam/Antoni/Arnold/Bella/Domi/Elli/Josh/Rachel/Sam”
|
||||
"xi_api_key": "", # 获取ap的方法可以参考https://docs.elevenlabs.io/api-reference/quick-start/authentication
|
||||
"xi_voice_id": "", # ElevenLabs提供了9种英式、美式等英语发音id,分别是“Adam/Antoni/Arnold/Bella/Domi/Elli/Josh/Rachel/Sam”
|
||||
# 服务时间限制,目前支持itchat
|
||||
"chat_time_module": False, # 是否开启服务时间限制
|
||||
"chat_start_time": "00:00", # 服务开始时间
|
||||
@@ -125,22 +142,21 @@ available_setting = {
|
||||
"wechatcomapp_secret": "", # 企业微信app的secret
|
||||
"wechatcomapp_agent_id": "", # 企业微信app的agent_id
|
||||
"wechatcomapp_aes_key": "", # 企业微信app的aes_key
|
||||
|
||||
# 飞书配置
|
||||
"feishu_port": 80, # 飞书bot监听端口
|
||||
"feishu_app_id": "", # 飞书机器人应用APP Id
|
||||
"feishu_app_secret": "", # 飞书机器人APP secret
|
||||
"feishu_token": "", # 飞书 verification token
|
||||
"feishu_bot_name": "", # 飞书机器人的名字
|
||||
|
||||
# 钉钉配置
|
||||
"dingtalk_client_id": "", # 钉钉机器人Client ID
|
||||
"dingtalk_client_secret": "", # 钉钉机器人Client Secret
|
||||
"dingtalk_client_secret": "", # 钉钉机器人Client Secret
|
||||
"dingtalk_card_enabled": False,
|
||||
|
||||
# chatgpt指令自定义触发词
|
||||
"clear_memory_commands": ["#清除记忆"], # 重置会话指令,必须以#开头
|
||||
# channel配置
|
||||
"channel_type": "wx", # 通道类型,支持:{wx,wxy,terminal,wechatmp,wechatmp_service,wechatcom_app}
|
||||
"channel_type": "", # 通道类型,支持:{wx,wxy,terminal,wechatmp,wechatmp_service,wechatcom_app,dingtalk}
|
||||
"subscribe_msg": "", # 订阅消息, 支持: wechatmp, wechatmp_service, wechatcom_app
|
||||
"debug": False, # 是否开启debug模式,开启后会打印更多日志
|
||||
"appdata_dir": "", # 数据目录
|
||||
@@ -148,16 +164,25 @@ available_setting = {
|
||||
"plugin_trigger_prefix": "$", # 规范插件提供聊天相关指令的前缀,建议不要和管理员指令前缀"#"冲突
|
||||
# 是否使用全局插件配置
|
||||
"use_global_plugin_config": False,
|
||||
"max_media_send_count": 3, # 单次最大发送媒体资源的个数
|
||||
"max_media_send_count": 3, # 单次最大发送媒体资源的个数
|
||||
"media_send_interval": 1, # 发送图片的事件间隔,单位秒
|
||||
# 智谱AI 平台配置
|
||||
"zhipu_ai_api_key": "",
|
||||
"zhipu_ai_api_base": "https://open.bigmodel.cn/api/paas/v4",
|
||||
"moonshot_api_key": "",
|
||||
"moonshot_base_url": "https://api.moonshot.cn/v1/chat/completions",
|
||||
#魔搭社区 平台配置
|
||||
"modelscope_api_key": "",
|
||||
"modelscope_base_url": "https://api-inference.modelscope.cn/v1/chat/completions",
|
||||
# LinkAI平台配置
|
||||
"use_linkai": False,
|
||||
"linkai_api_key": "",
|
||||
"linkai_app_code": "",
|
||||
"linkai_api_base": "https://api.link-ai.chat", # linkAI服务地址,若国内无法访问或延迟较高可改为 https://api.link-ai.tech
|
||||
"linkai_api_base": "https://api.link-ai.tech", # linkAI服务地址
|
||||
"Minimax_api_key": "",
|
||||
"Minimax_group_id": "",
|
||||
"Minimax_base_url": "",
|
||||
"web_port": 9899,
|
||||
}
|
||||
|
||||
|
||||
@@ -218,6 +243,30 @@ class Config(dict):
|
||||
config = Config()
|
||||
|
||||
|
||||
def drag_sensitive(config):
|
||||
try:
|
||||
if isinstance(config, str):
|
||||
conf_dict: dict = json.loads(config)
|
||||
conf_dict_copy = copy.deepcopy(conf_dict)
|
||||
for key in conf_dict_copy:
|
||||
if "key" in key or "secret" in key:
|
||||
if isinstance(conf_dict_copy[key], str):
|
||||
conf_dict_copy[key] = conf_dict_copy[key][0:3] + "*" * 5 + conf_dict_copy[key][-3:]
|
||||
return json.dumps(conf_dict_copy, indent=4)
|
||||
|
||||
elif isinstance(config, dict):
|
||||
config_copy = copy.deepcopy(config)
|
||||
for key in config:
|
||||
if "key" in key or "secret" in key:
|
||||
if isinstance(config_copy[key], str):
|
||||
config_copy[key] = config_copy[key][0:3] + "*" * 5 + config_copy[key][-3:]
|
||||
return config_copy
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
return config
|
||||
return config
|
||||
|
||||
|
||||
def load_config():
|
||||
global config
|
||||
config_path = "./config.json"
|
||||
@@ -226,7 +275,7 @@ def load_config():
|
||||
config_path = "./config-template.json"
|
||||
|
||||
config_str = read_file(config_path)
|
||||
logger.debug("[INIT] config str: {}".format(config_str))
|
||||
logger.debug("[INIT] config str: {}".format(drag_sensitive(config_str)))
|
||||
|
||||
# 将json字符串反序列化为dict类型
|
||||
config = Config(json.loads(config_str))
|
||||
@@ -251,7 +300,7 @@ def load_config():
|
||||
logger.setLevel(logging.DEBUG)
|
||||
logger.debug("[INIT] set log level to DEBUG")
|
||||
|
||||
logger.info("[INIT] load config: {}".format(config))
|
||||
logger.info("[INIT] load config: {}".format(drag_sensitive(config)))
|
||||
|
||||
config.load_user_datas()
|
||||
|
||||
@@ -296,6 +345,14 @@ def write_plugin_config(pconf: dict):
|
||||
for k in pconf:
|
||||
plugin_config[k.lower()] = pconf[k]
|
||||
|
||||
def remove_plugin_config(name: str):
|
||||
"""
|
||||
移除待重新加载的插件全局配置
|
||||
:param name: 待重载的插件名
|
||||
"""
|
||||
global plugin_config
|
||||
plugin_config.pop(name.lower(), None)
|
||||
|
||||
|
||||
def pconf(plugin_name: str) -> dict:
|
||||
"""
|
||||
@@ -307,6 +364,4 @@ def pconf(plugin_name: str) -> dict:
|
||||
|
||||
|
||||
# 全局配置,用于存放全局生效的状态
|
||||
global_config = {
|
||||
"admin_users": []
|
||||
}
|
||||
global_config = {"admin_users": []}
|
||||
|
||||
@@ -6,6 +6,7 @@ services:
|
||||
security_opt:
|
||||
- seccomp:unconfined
|
||||
environment:
|
||||
TZ: 'Asia/Shanghai'
|
||||
OPEN_AI_API_KEY: 'YOUR API KEY'
|
||||
MODEL: 'gpt-3.5-turbo'
|
||||
PROXY: ''
|
||||
|
||||
|
Before Width: | Height: | Size: 51 KiB |
|
Before Width: | Height: | Size: 326 KiB |
|
Before Width: | Height: | Size: 382 KiB |
|
Before Width: | Height: | Size: 33 KiB |
|
Before Width: | Height: | Size: 180 KiB |
13
docs/version/old-version.md
Normal file
@@ -0,0 +1,13 @@
|
||||
## 归档更新日志
|
||||
|
||||
2023.04.26: 支持企业微信应用号部署,兼容插件,并支持语音图片交互,私人助理理想选择,使用文档。(contributed by @lanvent in #944)
|
||||
|
||||
2023.04.05: 支持微信公众号部署,兼容插件,并支持语音图片交互,使用文档。(contributed by @JS00000 in #686)
|
||||
|
||||
2023.04.05: 增加能让ChatGPT使用工具的tool插件,使用文档。工具相关issue可反馈至chatgpt-tool-hub。(contributed by @goldfishh in #663)
|
||||
|
||||
2023.03.25: 支持插件化开发,目前已实现 多角色切换、文字冒险游戏、管理员指令、Stable Diffusion等插件,使用参考 #578。(contributed by @lanvent in #565)
|
||||
|
||||
2023.03.09: 基于 whisper API(后续已接入更多的语音API服务) 实现对语音消息的解析和回复,添加配置项 "speech_recognition":true 即可启用,使用参考 #415。(contributed by wanggang1987 in #385)
|
||||
|
||||
2023.02.09: 扫码登录存在账号限制风险,请谨慎使用,参考#58
|
||||
66
plugins/agent/README.md
Normal file
@@ -0,0 +1,66 @@
|
||||
# Agent插件
|
||||
|
||||
## 插件说明
|
||||
|
||||
基于 [AgentMesh](https://github.com/MinimalFuture/AgentMesh) 多智能体框架实现的Agent插件,可以让机器人快速获得Agent能力,通过自然语言对话来访问 **终端、浏览器、文件系统、搜索引擎** 等各类工具。
|
||||
同时还支持通过 **多智能体协作** 来完成复杂任务,例如多智能体任务分发、多智能体问题讨论、协同处理等。
|
||||
|
||||
AgentMesh项目地址:https://github.com/MinimalFuture/AgentMesh
|
||||
|
||||
## 安装
|
||||
|
||||
1. 确保已安装依赖:
|
||||
|
||||
```bash
|
||||
pip install agentmesh-sdk>=0.1.2
|
||||
```
|
||||
|
||||
2. 如需使用浏览器工具,还需安装:
|
||||
|
||||
```bash
|
||||
pip install browser-use>=0.1.40
|
||||
playwright install
|
||||
```
|
||||
|
||||
## 配置
|
||||
|
||||
插件配置文件是 `plugins/agent`目录下的 `config.yaml`,包含智能体团队的配置以及工具的配置,可以从模板文件 `config-template.yaml`中复制:
|
||||
|
||||
```bash
|
||||
cp config-template.yaml config.yaml
|
||||
```
|
||||
|
||||
说明:
|
||||
|
||||
- `team`配置是默认选中的 agent team
|
||||
- `teams` 下是Agent团队配置,团队的model默认为`gpt-4.1-mini`,可根据需要进行修改,模型对应的 `api_key` 需要在项目根目录的 `config.json` 全局配置中进行配置。例如openai模型需要配置 `open_ai_api_key`
|
||||
- 支持为 `agents` 下面的每个agent添加model字段来设置不同的模型
|
||||
|
||||
|
||||
## 使用方法
|
||||
|
||||
在对机器人发送的消息中使用 `$agent` 前缀来触发插件,支持以下命令:
|
||||
|
||||
- `$agent [task]`: 使用默认团队执行任务 (默认团队可通 config.yaml 中的team配置修改)
|
||||
- `$agent teams`: 列出可用的团队
|
||||
- `$agent use [team_name] [task]`: 使用指定的团队执行任务
|
||||
|
||||
|
||||
### 示例
|
||||
|
||||
```bash
|
||||
$agent 帮我查看当前目录下有哪些文件夹
|
||||
$agent teams
|
||||
$agent use software_team 帮我写一个产品预约体验的表单页面
|
||||
```
|
||||
|
||||
## 工具支持
|
||||
|
||||
目前支持多种内置工具,包括但不限于:
|
||||
|
||||
- `calculator`: 数学计算工具
|
||||
- `current_time`: 获取当前时间
|
||||
- `browser`: 浏览器操作工具,注意需安装`browser-use`依赖
|
||||
- `google_search`: 搜索引擎,注意需在`config.yaml`中配置 `api_key`
|
||||
- `file_save`: 文件保存工具,开启后智能体输出的内容将保存在 `workspace` 目录下
|
||||
- `terminal`: 终端命令执行工具
|
||||
3
plugins/agent/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .agent import AgentPlugin
|
||||
|
||||
__all__ = ["AgentPlugin"]
|
||||
282
plugins/agent/agent.py
Normal file
@@ -0,0 +1,282 @@
|
||||
import os
|
||||
import yaml
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from agentmesh import AgentTeam, Agent, LLMModel
|
||||
from agentmesh.models import ClaudeModel
|
||||
from agentmesh.tools import ToolManager
|
||||
from config import conf
|
||||
|
||||
import plugins
|
||||
from plugins import Plugin, Event, EventContext, EventAction
|
||||
from bridge.context import ContextType
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from common.log import logger
|
||||
|
||||
|
||||
@plugins.register(
|
||||
name="agent",
|
||||
desc="Use AgentMesh framework to process tasks with multi-agent teams",
|
||||
version="0.1.0",
|
||||
author="Saboteur7",
|
||||
desire_priority=1,
|
||||
)
|
||||
class AgentPlugin(Plugin):
|
||||
"""Plugin for integrating AgentMesh framework."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context
|
||||
self.name = "agent"
|
||||
self.description = "Use AgentMesh framework to process tasks with multi-agent teams"
|
||||
self.config = self._load_config()
|
||||
self.tool_manager = ToolManager()
|
||||
self.tool_manager.load_tools(config_dict=self.config.get("tools"))
|
||||
logger.info("[agent] inited")
|
||||
|
||||
def _load_config(self) -> Dict:
|
||||
"""Load configuration from config.yaml file."""
|
||||
config_path = os.path.join(self.path, "config.yaml")
|
||||
if not os.path.exists(config_path):
|
||||
logger.warning(f"Config file not found at {config_path}")
|
||||
return {}
|
||||
|
||||
with open(config_path, 'r', encoding='utf-8') as f:
|
||||
return yaml.safe_load(f)
|
||||
|
||||
def get_help_text(self, verbose=False, **kwargs):
|
||||
"""Return help message for the agent plugin."""
|
||||
help_text = "通过AgentMesh实现对终端、浏览器、文件系统、搜索引擎等工具的执行,并支持多智能体协作。"
|
||||
trigger_prefix = conf().get("plugin_trigger_prefix", "$")
|
||||
|
||||
if not verbose:
|
||||
return help_text
|
||||
|
||||
teams = self.get_available_teams()
|
||||
teams_str = ", ".join(teams) if teams else "未配置任何团队"
|
||||
|
||||
help_text += "\n\n使用说明:\n"
|
||||
help_text += f"{trigger_prefix}agent [task] - 使用默认团队执行任务\n"
|
||||
help_text += f"{trigger_prefix}agent teams - 列出可用的团队\n"
|
||||
help_text += f"{trigger_prefix}agent use [team_name] [task] - 使用特定团队执行任务\n\n"
|
||||
help_text += f"可用团队: \n{teams_str}\n\n"
|
||||
help_text += f"示例:\n"
|
||||
help_text += f"{trigger_prefix}agent 帮我查看当前文件夹路径\n"
|
||||
help_text += f"{trigger_prefix}agent use software_team 帮我写一个产品预约体验的表单页面"
|
||||
return help_text
|
||||
|
||||
def get_available_teams(self) -> List[str]:
|
||||
"""Get list of available teams from configuration."""
|
||||
teams_config = self.config.get("teams", {})
|
||||
return list(teams_config.keys())
|
||||
|
||||
|
||||
def create_team_from_config(self, team_name: str) -> Optional[AgentTeam]:
|
||||
"""Create a team from configuration."""
|
||||
# Get teams configuration
|
||||
teams_config = self.config.get("teams", {})
|
||||
|
||||
# Check if the specified team exists
|
||||
if team_name not in teams_config:
|
||||
logger.error(f"Team '{team_name}' not found in configuration.")
|
||||
available_teams = list(teams_config.keys())
|
||||
logger.info(f"Available teams: {', '.join(available_teams)}")
|
||||
return None
|
||||
|
||||
# Get team configuration
|
||||
team_config = teams_config[team_name]
|
||||
|
||||
# Get team's model
|
||||
team_model_name = team_config.get("model", "gpt-4.1-mini")
|
||||
team_model = self.create_llm_model(team_model_name)
|
||||
|
||||
# Get team's max_steps (default to 20 if not specified)
|
||||
team_max_steps = team_config.get("max_steps", 20)
|
||||
|
||||
# Create team with the model
|
||||
team = AgentTeam(
|
||||
name=team_name,
|
||||
description=team_config.get("description", ""),
|
||||
rule=team_config.get("rule", ""),
|
||||
model=team_model,
|
||||
max_steps=team_max_steps
|
||||
)
|
||||
|
||||
# Create and add agents to the team
|
||||
agents_config = team_config.get("agents", [])
|
||||
for agent_config in agents_config:
|
||||
# Check if agent has a specific model
|
||||
if agent_config.get("model"):
|
||||
agent_model = self.create_llm_model(agent_config.get("model"))
|
||||
else:
|
||||
agent_model = team_model
|
||||
|
||||
# Get agent's max_steps
|
||||
agent_max_steps = agent_config.get("max_steps")
|
||||
|
||||
agent = Agent(
|
||||
name=agent_config.get("name", ""),
|
||||
system_prompt=agent_config.get("system_prompt", ""),
|
||||
model=agent_model, # Use agent's model if specified, otherwise will use team's model
|
||||
description=agent_config.get("description", ""),
|
||||
max_steps=agent_max_steps
|
||||
)
|
||||
|
||||
# Add tools to the agent if specified
|
||||
tool_names = agent_config.get("tools", [])
|
||||
for tool_name in tool_names:
|
||||
tool = self.tool_manager.create_tool(tool_name)
|
||||
if tool:
|
||||
agent.add_tool(tool)
|
||||
else:
|
||||
if tool_name == "browser":
|
||||
logger.warning(
|
||||
"Tool 'Browser' loaded failed, "
|
||||
"please install the required dependency with: \n"
|
||||
"'pip install browser-use>=0.1.40' or 'pip install agentmesh-sdk[full]'\n"
|
||||
)
|
||||
else:
|
||||
logger.warning(f"Tool '{tool_name}' not found for agent '{agent.name}'\n")
|
||||
|
||||
# Add agent to team
|
||||
team.add(agent)
|
||||
|
||||
return team
|
||||
|
||||
def on_handle_context(self, e_context: EventContext):
|
||||
"""Handle the message context."""
|
||||
if e_context['context'].type != ContextType.TEXT:
|
||||
return
|
||||
content = e_context['context'].content
|
||||
trigger_prefix = conf().get("plugin_trigger_prefix", "$")
|
||||
|
||||
if not content.startswith(f"{trigger_prefix}agent "):
|
||||
e_context.action = EventAction.CONTINUE
|
||||
return
|
||||
|
||||
if not self.config:
|
||||
reply = Reply()
|
||||
reply.type = ReplyType.ERROR
|
||||
reply.content = "未找到插件配置,请在 plugins/agent 目录下创建 config.yaml 配置文件,可根据 config-template.yml 模板文件复制"
|
||||
e_context['reply'] = reply
|
||||
e_context.action = EventAction.BREAK_PASS
|
||||
return
|
||||
|
||||
# Extract the actual task
|
||||
task = content[len(f"{trigger_prefix}agent "):].strip()
|
||||
|
||||
# If task is empty, return help message
|
||||
if not task:
|
||||
reply = Reply()
|
||||
reply.type = ReplyType.TEXT
|
||||
reply.content = self.get_help_text(verbose=True)
|
||||
e_context['reply'] = reply
|
||||
e_context.action = EventAction.BREAK_PASS
|
||||
return
|
||||
|
||||
# Check if task is asking for available teams
|
||||
if task.lower() in ["teams", "list teams", "show teams"]:
|
||||
teams = self.get_available_teams()
|
||||
reply = Reply()
|
||||
reply.type = ReplyType.TEXT
|
||||
|
||||
if not teams:
|
||||
reply.content = "未配置任何团队。请检查 config.yaml 文件。"
|
||||
else:
|
||||
reply.content = f"可用团队: {', '.join(teams)}"
|
||||
|
||||
e_context['reply'] = reply
|
||||
e_context.action = EventAction.BREAK_PASS
|
||||
return
|
||||
|
||||
# Check if task specifies a team
|
||||
team_name = None
|
||||
if task.startswith("use "):
|
||||
parts = task[4:].split(" ", 1)
|
||||
if len(parts) > 0:
|
||||
team_name = parts[0]
|
||||
if len(parts) > 1:
|
||||
task = parts[1].strip()
|
||||
else:
|
||||
reply = Reply()
|
||||
reply.type = ReplyType.TEXT
|
||||
reply.content = f"已选择团队 '{team_name}'。请输入您想执行的任务。"
|
||||
e_context['reply'] = reply
|
||||
e_context.action = EventAction.BREAK_PASS
|
||||
return
|
||||
if not team_name:
|
||||
team_name = self.config.get("team")
|
||||
|
||||
# If no team specified, use default or first available
|
||||
if not team_name:
|
||||
teams = self.configself.get_available_teams()
|
||||
if not teams:
|
||||
reply = Reply()
|
||||
reply.type = ReplyType.TEXT
|
||||
reply.content = "未配置任何团队。请检查 config.yaml 文件。"
|
||||
e_context['reply'] = reply
|
||||
e_context.action = EventAction.BREAK_PASS
|
||||
return
|
||||
team_name = teams[0]
|
||||
|
||||
# Create team
|
||||
team = self.create_team_from_config(team_name)
|
||||
if not team:
|
||||
reply = Reply()
|
||||
reply.type = ReplyType.TEXT
|
||||
reply.content = f"创建团队 '{team_name}' 失败。请检查配置。"
|
||||
e_context['reply'] = reply
|
||||
e_context.action = EventAction.BREAK_PASS
|
||||
return
|
||||
|
||||
# Run the task
|
||||
try:
|
||||
logger.info(f"[agent] Running task '{task}' with team '{team_name}', team_model={team.model.model}")
|
||||
result = team.run_async(task=task)
|
||||
for agent_result in result:
|
||||
res_text = f"🤖 {agent_result.get('agent_name')}\n\n{agent_result.get('final_answer')}"
|
||||
_send_text(e_context, content=res_text)
|
||||
|
||||
reply = Reply()
|
||||
reply.type = ReplyType.TEXT
|
||||
reply.content = ""
|
||||
e_context['reply'] = reply
|
||||
e_context.action = EventAction.BREAK_PASS
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Error running task with team '{team_name}'")
|
||||
|
||||
reply = Reply()
|
||||
reply.type = ReplyType.ERROR
|
||||
reply.content = f"执行任务时出错: {str(e)}"
|
||||
e_context['reply'] = reply
|
||||
e_context.action = EventAction.BREAK_PASS
|
||||
return
|
||||
|
||||
def create_llm_model(self, model_name) -> LLMModel:
|
||||
if conf().get("use_linkai"):
|
||||
api_base = "https://api.link-ai.tech/v1"
|
||||
api_key = conf().get("linkai_api_key")
|
||||
elif model_name.startswith(("gpt", "text-davinci", "o1", "o3")):
|
||||
api_base = conf().get("open_ai_api_base") or "https://api.openai.com/v1"
|
||||
api_key = conf().get("open_ai_api_key")
|
||||
elif model_name.startswith("claude"):
|
||||
return ClaudeModel(model=model_name, api_key=conf().get("claude_api_key"))
|
||||
elif model_name.startswith("moonshot"):
|
||||
api_base = "https://api.moonshot.cn/v1"
|
||||
api_key = conf().get("moonshot_api_key")
|
||||
elif model_name.startswith("qwen"):
|
||||
api_base = "https://dashscope.aliyuncs.com/compatible-mode/v1"
|
||||
api_key = conf().get("dashscope_api_key")
|
||||
else:
|
||||
api_base = conf().get("open_ai_api_base") or "https://api.openai.com/v1"
|
||||
api_key = conf().get("open_ai_api_key")
|
||||
|
||||
llm_model = LLMModel(model=model_name, api_key=api_key, api_base=api_base)
|
||||
return llm_model
|
||||
|
||||
|
||||
def _send_text(e_context: EventContext, content: str):
|
||||
reply = Reply(ReplyType.TEXT, content)
|
||||
channel = e_context["channel"]
|
||||
channel.send(reply, e_context["context"])
|
||||
52
plugins/agent/config-template.yaml
Normal file
@@ -0,0 +1,52 @@
|
||||
# 默认选中的Agent Team名称
|
||||
team: general_team
|
||||
|
||||
tools:
|
||||
google_search:
|
||||
# get your apikey from https://serper.dev/
|
||||
api_key: "YOUR API KEY"
|
||||
|
||||
# Agent Team 配置
|
||||
teams:
|
||||
# 通用智能体团队
|
||||
general_team:
|
||||
model: "gpt-4.1-mini" # 团队使用的模型
|
||||
description: "A versatile research and information agent team"
|
||||
max_steps: 5
|
||||
agents:
|
||||
- name: "通用智能助手"
|
||||
description: "Universal assistant specializing in research, information synthesis, and task execution"
|
||||
system_prompt: "You are a versatile assistant who answers questions and completes tasks using available tools. Reply in a clearly structured, attractive and easy to read format."
|
||||
# Agent 支持使用的工具
|
||||
tools:
|
||||
- time
|
||||
- calculator
|
||||
- google_search
|
||||
- browser
|
||||
- terminal
|
||||
|
||||
# 软件开发智能体团队
|
||||
software_team:
|
||||
model: "gpt-4.1-mini"
|
||||
description: "A software development team with product manager, developer and tester."
|
||||
rule: "A normal R&D process should be that Product Manager writes PRD, Developer writes code based on PRD, and Finally, Tester performs testing."
|
||||
max_steps: 10
|
||||
agents:
|
||||
- name: "Product-Manager"
|
||||
description: "Responsible for product requirements and documentation"
|
||||
system_prompt: "You are an experienced product manager who creates concise PRDs, focusing on user needs and feature specifications. You always format your responses in Markdown."
|
||||
tools:
|
||||
- time
|
||||
- file_save
|
||||
- name: "Developer"
|
||||
description: "Implements code based on PRD"
|
||||
system_prompt: "You are a skilled developer. When developing web application, you creates single-page website based on user needs, you deliver HTML files with embedded JavaScript and CSS that are visually appealing, responsive, and user-friendly, featuring a grand layout and beautiful background. The HTML, CSS, and JavaScript code should be well-structured and effectively organized."
|
||||
tools:
|
||||
- file_save
|
||||
- name: "Tester"
|
||||
description: "Tests code and verifies functionality"
|
||||
system_prompt: "You are a tester who validates code against requirements. For HTML applications, use browser tools to test functionality. For Python or other client-side applications, use the terminal tool to run and test. You only need to test a few core cases."
|
||||
tools:
|
||||
- file_save
|
||||
- browser
|
||||
- terminal
|
||||
@@ -10,9 +10,7 @@
|
||||
},
|
||||
"tool": {
|
||||
"tools": [
|
||||
"python",
|
||||
"url-get",
|
||||
"terminal",
|
||||
"meteo-weather"
|
||||
],
|
||||
"kwargs": {
|
||||
@@ -40,5 +38,22 @@
|
||||
"max_file_size": 5000,
|
||||
"type": ["FILE", "SHARING"]
|
||||
}
|
||||
},
|
||||
"hello": {
|
||||
"group_welc_fixed_msg": {
|
||||
"群聊1": "群聊1的固定欢迎语",
|
||||
"群聊2": "群聊2的固定欢迎语"
|
||||
},
|
||||
"group_welc_prompt": "请你随机使用一种风格说一句问候语来欢迎新用户\"{nickname}\"加入群聊。",
|
||||
|
||||
"group_exit_prompt": "请你随机使用一种风格跟其他群用户说他违反规则\"{nickname}\"退出群聊。",
|
||||
|
||||
"patpat_prompt": "请你随机使用一种风格介绍你自己,并告诉用户输入#help可以查看帮助信息。",
|
||||
|
||||
"use_character_desc": false
|
||||
},
|
||||
"Apilot": {
|
||||
"alapi_token": "xxx",
|
||||
"morning_news_text_enabled": false
|
||||
}
|
||||
}
|
||||
|
||||
@@ -155,7 +155,7 @@ def get_help_text(isadmin, isgroup):
|
||||
for plugin in plugins:
|
||||
if plugins[plugin].enabled and not plugins[plugin].hidden:
|
||||
namecn = plugins[plugin].namecn
|
||||
help_text += "\n%s:" % namecn
|
||||
help_text += "\n%s: " % namecn
|
||||
help_text += PluginManager().instances[plugin].get_help_text(verbose=False).strip()
|
||||
|
||||
if ADMIN_COMMANDS and isadmin:
|
||||
@@ -313,7 +313,7 @@ class Godcmd(Plugin):
|
||||
except Exception as e:
|
||||
ok, result = False, "你没有设置私有GPT模型"
|
||||
elif cmd == "reset":
|
||||
if bottype in [const.OPEN_AI, const.CHATGPT, const.CHATGPTONAZURE, const.LINKAI, const.BAIDU, const.XUNFEI, const.QWEN, const.GEMINI]:
|
||||
if bottype in [const.OPEN_AI, const.CHATGPT, const.CHATGPTONAZURE, const.LINKAI, const.BAIDU, const.XUNFEI, const.QWEN, const.GEMINI, const.ZHIPU_AI, const.CLAUDEAPI]:
|
||||
bot.sessions.clear_session(session_id)
|
||||
if Bridge().chat_bots.get(bottype):
|
||||
Bridge().chat_bots.get(bottype).sessions.clear_session(session_id)
|
||||
@@ -339,7 +339,8 @@ class Godcmd(Plugin):
|
||||
ok, result = True, "配置已重载"
|
||||
elif cmd == "resetall":
|
||||
if bottype in [const.OPEN_AI, const.CHATGPT, const.CHATGPTONAZURE, const.LINKAI,
|
||||
const.BAIDU, const.XUNFEI, const.QWEN, const.GEMINI]:
|
||||
const.BAIDU, const.XUNFEI, const.QWEN, const.GEMINI, const.ZHIPU_AI, const.MOONSHOT,
|
||||
const.MODELSCOPE]:
|
||||
channel.cancel_all_session()
|
||||
bot.sessions.clear_all_session()
|
||||
ok, result = True, "重置所有会话成功"
|
||||
@@ -477,7 +478,7 @@ class Godcmd(Plugin):
|
||||
return model
|
||||
|
||||
def reload(self):
|
||||
gconf = plugin_config[self.name]
|
||||
gconf = pconf(self.name)
|
||||
if gconf:
|
||||
if gconf.get("password"):
|
||||
self.password = gconf["password"]
|
||||
|
||||
41
plugins/hello/README.md
Normal file
@@ -0,0 +1,41 @@
|
||||
## 插件说明
|
||||
|
||||
可以根据需求设置入群欢迎、群聊拍一拍、退群等消息的自定义提示词,也支持为每个群设置对应的固定欢迎语。
|
||||
|
||||
该插件也是用户根据需求开发自定义插件的示例插件,参考[插件开发说明](https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins)
|
||||
|
||||
## 插件配置
|
||||
|
||||
将 `plugins/hello` 目录下的 `config.json.template` 配置模板复制为最终生效的 `config.json`。 (如果未配置则会默认使用`config.json.template`模板中配置)。
|
||||
|
||||
以下是插件配置项说明:
|
||||
|
||||
```bash
|
||||
{
|
||||
"group_welc_fixed_msg": { ## 这里可以为特定群里配置特定的固定欢迎语
|
||||
"群聊1": "群聊1的固定欢迎语",
|
||||
"群聊2": "群聊2的固定欢迎语"
|
||||
},
|
||||
|
||||
"group_welc_prompt": "请你随机使用一种风格说一句问候语来欢迎新用户\"{nickname}\"加入群聊。", ## 群聊随机欢迎语的提示词
|
||||
|
||||
"group_exit_prompt": "请你随机使用一种风格跟其他群用户说他违反规则\"{nickname}\"退出群聊。", ## 移出群聊的提示词
|
||||
|
||||
"patpat_prompt": "请你随机使用一种风格介绍你自己,并告诉用户输入#help可以查看帮助信息。", ## 群内拍一拍的提示词
|
||||
|
||||
"use_character_desc": false ## 是否在Hello插件中使用LinkAI应用的系统设定
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
注意:
|
||||
|
||||
- 设置全局的用户进群固定欢迎语,可以在***项目根目录下***的`config.json`文件里,可以添加参数`"group_welcome_msg": "" `,参考 [#1482](https://github.com/zhayujie/chatgpt-on-wechat/pull/1482)
|
||||
- 为每个群设置固定的欢迎语,可以在`"group_welc_fixed_msg": {}`配置群聊名和对应的固定欢迎语,优先级高于全局固定欢迎语
|
||||
- 如果没有配置以上两个参数,则使用随机欢迎语,如需设定风格,语言等,修改`"group_welc_prompt": `即可
|
||||
- 如果使用LinkAI的服务,想在随机欢迎中结合LinkAI应用的设定,配置`"use_character_desc": true `
|
||||
- 实际 `config.json` 配置中应保证json格式,不应携带 '#' 及后面的注释
|
||||
- 如果是`docker`部署,可通过映射 `plugins/config.json` 到容器中来完成插件配置,参考[文档](https://github.com/zhayujie/chatgpt-on-wechat#3-%E6%8F%92%E4%BB%B6%E4%BD%BF%E7%94%A8)
|
||||
|
||||
|
||||
|
||||
14
plugins/hello/config.json.template
Normal file
@@ -0,0 +1,14 @@
|
||||
{
|
||||
"group_welc_fixed_msg": {
|
||||
"群聊1": "群聊1的固定欢迎语",
|
||||
"群聊2": "群聊2的固定欢迎语"
|
||||
},
|
||||
|
||||
"group_welc_prompt": "请你随机使用一种风格说一句问候语来欢迎新用户\"{nickname}\"加入群聊。",
|
||||
|
||||
"group_exit_prompt": "请你随机使用一种风格跟其他群用户说他违反规则\"{nickname}\"退出群聊。",
|
||||
|
||||
"patpat_prompt": "请你随机使用一种风格介绍你自己,并告诉用户输入#help可以查看帮助信息。",
|
||||
|
||||
"use_character_desc": false
|
||||
}
|
||||
@@ -17,12 +17,29 @@ from config import conf
|
||||
version="0.1",
|
||||
author="lanvent",
|
||||
)
|
||||
|
||||
|
||||
class Hello(Plugin):
|
||||
|
||||
group_welc_prompt = "请你随机使用一种风格说一句问候语来欢迎新用户\"{nickname}\"加入群聊。"
|
||||
group_exit_prompt = "请你随机使用一种风格介绍你自己,并告诉用户输入#help可以查看帮助信息。"
|
||||
patpat_prompt = "请你随机使用一种风格跟其他群用户说他违反规则\"{nickname}\"退出群聊。"
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context
|
||||
logger.info("[Hello] inited")
|
||||
self.config = super().load_config()
|
||||
try:
|
||||
self.config = super().load_config()
|
||||
if not self.config:
|
||||
self.config = self._load_config_template()
|
||||
self.group_welc_fixed_msg = self.config.get("group_welc_fixed_msg", {})
|
||||
self.group_welc_prompt = self.config.get("group_welc_prompt", self.group_welc_prompt)
|
||||
self.group_exit_prompt = self.config.get("group_exit_prompt", self.group_exit_prompt)
|
||||
self.patpat_prompt = self.config.get("patpat_prompt", self.patpat_prompt)
|
||||
logger.info("[Hello] inited")
|
||||
self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context
|
||||
except Exception as e:
|
||||
logger.error(f"[Hello]初始化异常:{e}")
|
||||
raise "[Hello] init failed, ignore "
|
||||
|
||||
def on_handle_context(self, e_context: EventContext):
|
||||
if e_context["context"].type not in [
|
||||
@@ -32,17 +49,21 @@ class Hello(Plugin):
|
||||
ContextType.EXIT_GROUP
|
||||
]:
|
||||
return
|
||||
msg: ChatMessage = e_context["context"]["msg"]
|
||||
group_name = msg.from_user_nickname
|
||||
if e_context["context"].type == ContextType.JOIN_GROUP:
|
||||
if "group_welcome_msg" in conf():
|
||||
if "group_welcome_msg" in conf() or group_name in self.group_welc_fixed_msg:
|
||||
reply = Reply()
|
||||
reply.type = ReplyType.TEXT
|
||||
reply.content = conf().get("group_welcome_msg", "")
|
||||
if group_name in self.group_welc_fixed_msg:
|
||||
reply.content = self.group_welc_fixed_msg.get(group_name, "")
|
||||
else:
|
||||
reply.content = conf().get("group_welcome_msg", "")
|
||||
e_context["reply"] = reply
|
||||
e_context.action = EventAction.BREAK_PASS # 事件结束,并跳过处理context的默认逻辑
|
||||
return
|
||||
e_context["context"].type = ContextType.TEXT
|
||||
msg: ChatMessage = e_context["context"]["msg"]
|
||||
e_context["context"].content = f'请你随机使用一种风格说一句问候语来欢迎新用户"{msg.actual_user_nickname}"加入群聊。'
|
||||
e_context["context"].content = self.group_welc_prompt.format(nickname=msg.actual_user_nickname)
|
||||
e_context.action = EventAction.BREAK # 事件结束,进入默认处理逻辑
|
||||
if not self.config or not self.config.get("use_character_desc"):
|
||||
e_context["context"]["generate_breaked_by"] = EventAction.BREAK
|
||||
@@ -51,8 +72,7 @@ class Hello(Plugin):
|
||||
if e_context["context"].type == ContextType.EXIT_GROUP:
|
||||
if conf().get("group_chat_exit_group"):
|
||||
e_context["context"].type = ContextType.TEXT
|
||||
msg: ChatMessage = e_context["context"]["msg"]
|
||||
e_context["context"].content = f'请你随机使用一种风格跟其他群用户说他违反规则"{msg.actual_user_nickname}"退出群聊。'
|
||||
e_context["context"].content = self.group_exit_prompt.format(nickname=msg.actual_user_nickname)
|
||||
e_context.action = EventAction.BREAK # 事件结束,进入默认处理逻辑
|
||||
return
|
||||
e_context.action = EventAction.BREAK
|
||||
@@ -60,8 +80,7 @@ class Hello(Plugin):
|
||||
|
||||
if e_context["context"].type == ContextType.PATPAT:
|
||||
e_context["context"].type = ContextType.TEXT
|
||||
msg: ChatMessage = e_context["context"]["msg"]
|
||||
e_context["context"].content = f"请你随机使用一种风格介绍你自己,并告诉用户输入#help可以查看帮助信息。"
|
||||
e_context["context"].content = self.patpat_prompt
|
||||
e_context.action = EventAction.BREAK # 事件结束,进入默认处理逻辑
|
||||
if not self.config or not self.config.get("use_character_desc"):
|
||||
e_context["context"]["generate_breaked_by"] = EventAction.BREAK
|
||||
@@ -72,7 +91,6 @@ class Hello(Plugin):
|
||||
if content == "Hello":
|
||||
reply = Reply()
|
||||
reply.type = ReplyType.TEXT
|
||||
msg: ChatMessage = e_context["context"]["msg"]
|
||||
if e_context["context"]["isgroup"]:
|
||||
reply.content = f"Hello, {msg.actual_user_nickname} from {msg.from_user_nickname}"
|
||||
else:
|
||||
@@ -96,3 +114,14 @@ class Hello(Plugin):
|
||||
def get_help_text(self, **kwargs):
|
||||
help_text = "输入Hello,我会回复你的名字\n输入End,我会回复你世界的图片\n"
|
||||
return help_text
|
||||
|
||||
def _load_config_template(self):
|
||||
logger.debug("No Hello plugin config.json, use plugins/hello/config.json.template")
|
||||
try:
|
||||
plugin_config_path = os.path.join(self.path, "config.json.template")
|
||||
if os.path.exists(plugin_config_path):
|
||||
with open(plugin_config_path, "r", encoding="utf-8") as f:
|
||||
plugin_conf = json.load(f)
|
||||
return plugin_conf
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
@@ -55,7 +55,7 @@ class Keyword(Plugin):
|
||||
reply_text = self.keyword[content]
|
||||
|
||||
# 判断匹配内容的类型
|
||||
if (reply_text.startswith("http://") or reply_text.startswith("https://")) and any(reply_text.endswith(ext) for ext in [".jpg", ".jpeg", ".png", ".gif", ".img"]):
|
||||
if (reply_text.startswith("http://") or reply_text.startswith("https://")) and any(reply_text.endswith(ext) for ext in [".jpg", ".webp", ".jpeg", ".png", ".gif", ".img"]):
|
||||
# 如果是以 http:// 或 https:// 开头,且".jpg", ".jpeg", ".png", ".gif", ".img"结尾,则认为是图片 URL。
|
||||
reply = Reply()
|
||||
reply.type = ReplyType.IMAGE_URL
|
||||
|
||||
@@ -98,6 +98,8 @@
|
||||
|
||||
如果不想创建 `plugins/linkai/config.json` 配置,可以直接通过 `$linkai sum open` 指令开启该功能。
|
||||
|
||||
也可以通过私聊(全局 `config.json` 中的 `linkai_app_code`)或者群聊绑定(通过`group_app_map`参数配置)的应用来开启该功能:在LinkAI平台 [应用配置](https://link-ai.tech/console/factory) 里添加并开启**内容总结**插件。
|
||||
|
||||
#### 使用
|
||||
|
||||
功能开启后,向机器人发送 **文件**、 **分享链接卡片**、**图片** 即可生成摘要,进一步可以与文件或链接的内容进行多轮对话。如果需要关闭某种类型的内容总结,设置 `summary`配置中的type字段即可。
|
||||
|
||||
@@ -9,6 +9,8 @@ from common.expired_dict import ExpiredDict
|
||||
from common import const
|
||||
import os
|
||||
from .utils import Util
|
||||
from config import plugin_config, conf
|
||||
|
||||
|
||||
@plugins.register(
|
||||
name="linkai",
|
||||
@@ -26,13 +28,12 @@ class LinkAI(Plugin):
|
||||
# 未加载到配置,使用模板中的配置
|
||||
self.config = self._load_config_template()
|
||||
if self.config:
|
||||
self.mj_bot = MJBot(self.config.get("midjourney"))
|
||||
self.mj_bot = MJBot(self.config.get("midjourney"), self._fetch_group_app_code)
|
||||
self.sum_config = {}
|
||||
if self.config:
|
||||
self.sum_config = self.config.get("summary")
|
||||
logger.info(f"[LinkAI] inited, config={self.config}")
|
||||
|
||||
|
||||
def on_handle_context(self, e_context: EventContext):
|
||||
"""
|
||||
消息处理逻辑
|
||||
@@ -42,7 +43,8 @@ class LinkAI(Plugin):
|
||||
return
|
||||
|
||||
context = e_context['context']
|
||||
if context.type not in [ContextType.TEXT, ContextType.IMAGE, ContextType.IMAGE_CREATE, ContextType.FILE, ContextType.SHARING]:
|
||||
if context.type not in [ContextType.TEXT, ContextType.IMAGE, ContextType.IMAGE_CREATE, ContextType.FILE,
|
||||
ContextType.SHARING]:
|
||||
# filter content no need solve
|
||||
return
|
||||
|
||||
@@ -54,7 +56,8 @@ class LinkAI(Plugin):
|
||||
return
|
||||
if context.type != ContextType.IMAGE:
|
||||
_send_info(e_context, "正在为你加速生成摘要,请稍后")
|
||||
res = LinkSummary().summary_file(file_path)
|
||||
app_code = self._fetch_app_code(context)
|
||||
res = LinkSummary().summary_file(file_path, app_code)
|
||||
if not res:
|
||||
if context.type != ContextType.IMAGE:
|
||||
_set_reply_text("因为神秘力量无法获取内容,请稍后再试吧", e_context, level=ReplyType.TEXT)
|
||||
@@ -68,15 +71,17 @@ class LinkAI(Plugin):
|
||||
return
|
||||
|
||||
if (context.type == ContextType.SHARING and self._is_summary_open(context)) or \
|
||||
(context.type == ContextType.TEXT and LinkSummary().check_url(context.content)):
|
||||
(context.type == ContextType.TEXT and self._is_summary_open(context) and LinkSummary().check_url(context.content)):
|
||||
if not LinkSummary().check_url(context.content):
|
||||
return
|
||||
_send_info(e_context, "正在为你加速生成摘要,请稍后")
|
||||
res = LinkSummary().summary_url(context.content)
|
||||
app_code = self._fetch_app_code(context)
|
||||
res = LinkSummary().summary_url(context.content, app_code)
|
||||
if not res:
|
||||
_set_reply_text("因为神秘力量无法获取文章内容,请稍后再试吧~", e_context, level=ReplyType.TEXT)
|
||||
return
|
||||
_set_reply_text(res.get("summary") + "\n\n💬 发送 \"开启对话\" 可以开启与文章内容的对话", e_context, level=ReplyType.TEXT)
|
||||
_set_reply_text(res.get("summary") + "\n\n💬 发送 \"开启对话\" 可以开启与文章内容的对话", e_context,
|
||||
level=ReplyType.TEXT)
|
||||
USER_FILE_MAP[_find_user_id(context) + "-sum_id"] = res.get("summary_id")
|
||||
return
|
||||
|
||||
@@ -99,7 +104,8 @@ class LinkAI(Plugin):
|
||||
_set_reply_text("开启对话失败,请稍后再试吧", e_context)
|
||||
return
|
||||
USER_FILE_MAP[_find_user_id(context) + "-file_id"] = res.get("file_id")
|
||||
_set_reply_text("💡你可以问我关于这篇文章的任何问题,例如:\n\n" + res.get("questions") + "\n\n发送 \"退出对话\" 可以关闭与文章的对话", e_context, level=ReplyType.TEXT)
|
||||
_set_reply_text("💡你可以问我关于这篇文章的任何问题,例如:\n\n" + res.get(
|
||||
"questions") + "\n\n发送 \"退出对话\" 可以关闭与文章的对话", e_context, level=ReplyType.TEXT)
|
||||
return
|
||||
|
||||
if context.type == ContextType.TEXT and context.content == "退出对话" and _find_file_id(context):
|
||||
@@ -117,12 +123,10 @@ class LinkAI(Plugin):
|
||||
e_context.action = EventAction.BREAK_PASS
|
||||
return
|
||||
|
||||
|
||||
if self._is_chat_task(e_context):
|
||||
# 文本对话任务处理
|
||||
self._process_chat_task(e_context)
|
||||
|
||||
|
||||
# 插件管理功能
|
||||
def _process_admin_cmd(self, e_context: EventContext):
|
||||
context = e_context['context']
|
||||
@@ -167,7 +171,7 @@ class LinkAI(Plugin):
|
||||
return
|
||||
|
||||
if len(cmd) == 3 and cmd[1] == "sum" and (cmd[2] == "open" or cmd[2] == "close"):
|
||||
# 知识库开关指令
|
||||
# 总结对话开关指令
|
||||
if not Util.is_admin(e_context):
|
||||
_set_reply_text("需要管理员权限执行", e_context, level=ReplyType.ERROR)
|
||||
return
|
||||
@@ -177,7 +181,9 @@ class LinkAI(Plugin):
|
||||
tips_text = "关闭"
|
||||
is_open = False
|
||||
if not self.sum_config:
|
||||
_set_reply_text(f"插件未启用summary功能,请参考以下链添加插件配置\n\nhttps://github.com/zhayujie/chatgpt-on-wechat/blob/master/plugins/linkai/README.md", e_context, level=ReplyType.INFO)
|
||||
_set_reply_text(
|
||||
f"插件未启用summary功能,请参考以下链添加插件配置\n\nhttps://github.com/zhayujie/chatgpt-on-wechat/blob/master/plugins/linkai/README.md",
|
||||
e_context, level=ReplyType.INFO)
|
||||
else:
|
||||
self.sum_config["enabled"] = is_open
|
||||
_set_reply_text(f"文章总结功能{tips_text}", e_context, level=ReplyType.INFO)
|
||||
@@ -188,14 +194,36 @@ class LinkAI(Plugin):
|
||||
return
|
||||
|
||||
def _is_summary_open(self, context) -> bool:
|
||||
if not self.sum_config or not self.sum_config.get("enabled"):
|
||||
return False
|
||||
if context.kwargs.get("isgroup") and not self.sum_config.get("group_enabled"):
|
||||
return False
|
||||
support_type = self.sum_config.get("type") or ["FILE", "SHARING"]
|
||||
if context.type.name not in support_type:
|
||||
return False
|
||||
return True
|
||||
# 获取远程应用插件状态
|
||||
remote_enabled = False
|
||||
if context.kwargs.get("isgroup"):
|
||||
# 群聊场景只查询群对应的app_code
|
||||
group_name = context.get("msg").from_user_nickname
|
||||
app_code = self._fetch_group_app_code(group_name)
|
||||
if app_code:
|
||||
if context.type.name in ["FILE", "SHARING"]:
|
||||
remote_enabled = Util.fetch_app_plugin(app_code, "内容总结")
|
||||
else:
|
||||
# 非群聊场景使用全局app_code
|
||||
app_code = conf().get("linkai_app_code")
|
||||
if app_code:
|
||||
if context.type.name in ["FILE", "SHARING"]:
|
||||
remote_enabled = Util.fetch_app_plugin(app_code, "内容总结")
|
||||
|
||||
# 基础条件:总开关开启且消息类型符合要求
|
||||
base_enabled = (
|
||||
self.sum_config
|
||||
and self.sum_config.get("enabled")
|
||||
and (context.type.name in (
|
||||
self.sum_config.get("type") or ["FILE", "SHARING"]) or context.type.name == "TEXT")
|
||||
)
|
||||
|
||||
# 群聊:需要满足(总开关和群开关)或远程插件开启
|
||||
if context.kwargs.get("isgroup"):
|
||||
return (base_enabled and self.sum_config.get("group_enabled")) or remote_enabled
|
||||
|
||||
# 非群聊:只需要满足总开关或远程插件开启
|
||||
return base_enabled or remote_enabled
|
||||
|
||||
# LinkAI 对话任务处理
|
||||
def _is_chat_task(self, e_context: EventContext):
|
||||
@@ -226,6 +254,19 @@ class LinkAI(Plugin):
|
||||
app_code = group_mapping.get(group_name) or group_mapping.get("ALL_GROUP")
|
||||
return app_code
|
||||
|
||||
def _fetch_app_code(self, context) -> str:
|
||||
"""
|
||||
根据主配置或者群聊名称获取对应的应用code,优先获取群聊配置的应用code
|
||||
:param context: 上下文
|
||||
:return: 应用code
|
||||
"""
|
||||
app_code = conf().get("linkai_app_code")
|
||||
if context.kwargs.get("isgroup"):
|
||||
# 群聊场景只查询群对应的app_code
|
||||
group_name = context.get("msg").from_user_nickname
|
||||
app_code = self._fetch_group_app_code(group_name)
|
||||
return app_code
|
||||
|
||||
def get_help_text(self, verbose=False, **kwargs):
|
||||
trigger_prefix = _get_trigger_prefix()
|
||||
help_text = "用于集成 LinkAI 提供的知识库、Midjourney绘画、文档总结、联网搜索等能力。\n\n"
|
||||
@@ -250,10 +291,14 @@ class LinkAI(Plugin):
|
||||
plugin_conf = json.load(f)
|
||||
plugin_conf["midjourney"]["enabled"] = False
|
||||
plugin_conf["summary"]["enabled"] = False
|
||||
write_plugin_config({"linkai": plugin_conf})
|
||||
return plugin_conf
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
|
||||
def reload(self):
|
||||
self.config = super().load_config()
|
||||
|
||||
|
||||
def _send_info(e_context: EventContext, content: str):
|
||||
reply = Reply(ReplyType.TEXT, content)
|
||||
@@ -273,15 +318,19 @@ def _set_reply_text(content: str, e_context: EventContext, level: ReplyType = Re
|
||||
e_context["reply"] = reply
|
||||
e_context.action = EventAction.BREAK_PASS
|
||||
|
||||
|
||||
def _get_trigger_prefix():
|
||||
return conf().get("plugin_trigger_prefix", "$")
|
||||
|
||||
|
||||
def _find_sum_id(context):
|
||||
return USER_FILE_MAP.get(_find_user_id(context) + "-sum_id")
|
||||
|
||||
|
||||
def _find_file_id(context):
|
||||
user_id = _find_user_id(context)
|
||||
if user_id:
|
||||
return USER_FILE_MAP.get(user_id + "-file_id")
|
||||
|
||||
|
||||
USER_FILE_MAP = ExpiredDict(conf().get("expires_in_seconds") or 60 * 30)
|
||||
|
||||
@@ -10,6 +10,7 @@ from bridge.context import ContextType
|
||||
from plugins import EventContext, EventAction
|
||||
from .utils import Util
|
||||
|
||||
|
||||
INVALID_REQUEST = 410
|
||||
NOT_FOUND_ORIGIN_IMAGE = 461
|
||||
NOT_FOUND_TASK = 462
|
||||
@@ -67,10 +68,11 @@ class MJTask:
|
||||
|
||||
# midjourney bot
|
||||
class MJBot:
|
||||
def __init__(self, config):
|
||||
self.base_url = conf().get("linkai_api_base", "https://api.link-ai.chat") + "/v1/img/midjourney"
|
||||
def __init__(self, config, fetch_group_app_code):
|
||||
self.base_url = conf().get("linkai_api_base", "https://api.link-ai.tech") + "/v1/img/midjourney"
|
||||
self.headers = {"Authorization": "Bearer " + conf().get("linkai_api_key")}
|
||||
self.config = config
|
||||
self.fetch_group_app_code = fetch_group_app_code
|
||||
self.tasks = {}
|
||||
self.temp_dict = {}
|
||||
self.tasks_lock = threading.Lock()
|
||||
@@ -98,7 +100,7 @@ class MJBot:
|
||||
return TaskType.VARIATION
|
||||
elif cmd_list[0].lower() == f"{trigger_prefix}mjr":
|
||||
return TaskType.RESET
|
||||
elif context.type == ContextType.IMAGE_CREATE and self.config.get("use_image_create_prefix") and self.config.get("enabled"):
|
||||
elif context.type == ContextType.IMAGE_CREATE and self.config.get("use_image_create_prefix") and self._is_mj_open(context):
|
||||
return TaskType.GENERATE
|
||||
|
||||
def process_mj_task(self, mj_type: TaskType, e_context: EventContext):
|
||||
@@ -129,8 +131,8 @@ class MJBot:
|
||||
self._set_reply_text(f"Midjourney绘画已{tips_text}", e_context, level=ReplyType.INFO)
|
||||
return
|
||||
|
||||
if not self.config.get("enabled"):
|
||||
logger.warn("Midjourney绘画未开启,请查看 plugins/linkai/config.json 中的配置")
|
||||
if not self._is_mj_open(context):
|
||||
logger.warn("Midjourney绘画未开启,请查看 plugins/linkai/config.json 中的配置,或者在LinkAI平台 应用中添加/打开”MJ“插件")
|
||||
self._set_reply_text(f"Midjourney绘画未开启", e_context, level=ReplyType.INFO)
|
||||
return
|
||||
|
||||
@@ -409,6 +411,25 @@ class MJBot:
|
||||
result.append(task)
|
||||
return result
|
||||
|
||||
def _is_mj_open(self, context) -> bool:
|
||||
# 获取远程应用插件状态
|
||||
remote_enabled = False
|
||||
if context.kwargs.get("isgroup"):
|
||||
# 群聊场景只查询群对应的app_code
|
||||
group_name = context.get("msg").from_user_nickname
|
||||
app_code = self.fetch_group_app_code(group_name)
|
||||
if app_code:
|
||||
remote_enabled = Util.fetch_app_plugin(app_code, "Midjourney")
|
||||
else:
|
||||
# 非群聊场景使用全局app_code
|
||||
app_code = conf().get("linkai_app_code")
|
||||
if app_code:
|
||||
remote_enabled = Util.fetch_app_plugin(app_code, "Midjourney")
|
||||
|
||||
# 本地配置
|
||||
base_enabled = self.config.get("enabled")
|
||||
|
||||
return base_enabled or remote_enabled
|
||||
|
||||
def _send(channel, reply: Reply, context, retry_cnt=0):
|
||||
try:
|
||||
|
||||
@@ -2,25 +2,33 @@ import requests
|
||||
from config import conf
|
||||
from common.log import logger
|
||||
import os
|
||||
import html
|
||||
|
||||
|
||||
class LinkSummary:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def summary_file(self, file_path: str):
|
||||
def summary_file(self, file_path: str, app_code: str):
|
||||
file_body = {
|
||||
"file": open(file_path, "rb"),
|
||||
"name": file_path.split("/")[-1],
|
||||
"name": file_path.split("/")[-1]
|
||||
}
|
||||
body = {
|
||||
"app_code": app_code
|
||||
}
|
||||
url = self.base_url() + "/v1/summary/file"
|
||||
res = requests.post(url, headers=self.headers(), files=file_body, timeout=(5, 300))
|
||||
logger.info(f"[LinkSum] file summary, app_code={app_code}")
|
||||
res = requests.post(url, headers=self.headers(), files=file_body, data=body, timeout=(5, 300))
|
||||
return self._parse_summary_res(res)
|
||||
|
||||
def summary_url(self, url: str):
|
||||
def summary_url(self, url: str, app_code: str):
|
||||
url = html.unescape(url)
|
||||
body = {
|
||||
"url": url
|
||||
"url": url,
|
||||
"app_code": app_code
|
||||
}
|
||||
logger.info(f"[LinkSum] url summary, app_code={app_code}")
|
||||
res = requests.post(url=self.base_url() + "/v1/summary/url", headers=self.headers(), json=body, timeout=(5, 180))
|
||||
return self._parse_summary_res(res)
|
||||
|
||||
@@ -46,7 +54,7 @@ class LinkSummary:
|
||||
def _parse_summary_res(self, res):
|
||||
if res.status_code == 200:
|
||||
res = res.json()
|
||||
logger.debug(f"[LinkSum] url summary, res={res}")
|
||||
logger.debug(f"[LinkSum] summary result, res={res}")
|
||||
if res.get("code") == 200:
|
||||
data = res.get("data")
|
||||
return {
|
||||
@@ -59,7 +67,7 @@ class LinkSummary:
|
||||
return None
|
||||
|
||||
def base_url(self):
|
||||
return conf().get("linkai_api_base", "https://api.link-ai.chat")
|
||||
return conf().get("linkai_api_base", "https://api.link-ai.tech")
|
||||
|
||||
def headers(self):
|
||||
return {"Authorization": "Bearer " + conf().get("linkai_api_key")}
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
import requests
|
||||
from common.log import logger
|
||||
from config import global_config
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from plugins.event import EventContext, EventAction
|
||||
|
||||
from config import conf
|
||||
|
||||
class Util:
|
||||
@staticmethod
|
||||
@@ -26,3 +28,23 @@ class Util:
|
||||
reply = Reply(level, content)
|
||||
e_context["reply"] = reply
|
||||
e_context.action = EventAction.BREAK_PASS
|
||||
|
||||
@staticmethod
|
||||
def fetch_app_plugin(app_code: str, plugin_name: str) -> bool:
|
||||
try:
|
||||
headers = {"Authorization": "Bearer " + conf().get("linkai_api_key")}
|
||||
# do http request
|
||||
base_url = conf().get("linkai_api_base", "https://api.link-ai.tech")
|
||||
params = {"app_code": app_code}
|
||||
res = requests.get(url=base_url + "/v1/app/info", params=params, headers=headers, timeout=(5, 10))
|
||||
if res.status_code == 200:
|
||||
plugins = res.json().get("data").get("plugins")
|
||||
for plugin in plugins:
|
||||
if plugin.get("name") and plugin.get("name") == plugin_name:
|
||||
return True
|
||||
return False
|
||||
else:
|
||||
logger.warning(f"[LinkAI] find app info exception, res={res}")
|
||||
return False
|
||||
except Exception as e:
|
||||
return False
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import os
|
||||
import json
|
||||
from config import pconf, plugin_config, conf
|
||||
from config import pconf, plugin_config, conf, write_plugin_config
|
||||
from common.log import logger
|
||||
|
||||
|
||||
@@ -18,18 +18,19 @@ class Plugin:
|
||||
if not plugin_conf:
|
||||
# 全局配置不存在,则获取插件目录下的配置
|
||||
plugin_config_path = os.path.join(self.path, "config.json")
|
||||
logger.debug(f"loading plugin config, plugin_config_path={plugin_config_path}, exist={os.path.exists(plugin_config_path)}")
|
||||
if os.path.exists(plugin_config_path):
|
||||
with open(plugin_config_path, "r", encoding="utf-8") as f:
|
||||
plugin_conf = json.load(f)
|
||||
|
||||
# 写入全局配置内存
|
||||
plugin_config[self.name] = plugin_conf
|
||||
write_plugin_config({self.name: plugin_conf})
|
||||
logger.debug(f"loading plugin config, plugin_name={self.name}, conf={plugin_conf}")
|
||||
return plugin_conf
|
||||
|
||||
def save_config(self, config: dict):
|
||||
try:
|
||||
plugin_config[self.name] = config
|
||||
write_plugin_config({self.name: config})
|
||||
# 写入全局配置
|
||||
global_config_path = "./plugins/config.json"
|
||||
if os.path.exists(global_config_path):
|
||||
|
||||
@@ -9,7 +9,7 @@ import sys
|
||||
from common.log import logger
|
||||
from common.singleton import singleton
|
||||
from common.sorted_dict import SortedDict
|
||||
from config import conf, write_plugin_config
|
||||
from config import conf, remove_plugin_config, write_plugin_config
|
||||
|
||||
from .event import *
|
||||
|
||||
@@ -99,7 +99,7 @@ class PluginManager:
|
||||
try:
|
||||
self.current_plugin_path = plugin_path
|
||||
if plugin_path in self.loaded:
|
||||
if self.loaded[plugin_path] == None:
|
||||
if plugin_name.upper() != 'GODCMD':
|
||||
logger.info("reload module %s" % plugin_name)
|
||||
self.loaded[plugin_path] = importlib.reload(sys.modules[import_path])
|
||||
dependent_module_names = [name for name in sys.modules.keys() if name.startswith(import_path + ".")]
|
||||
@@ -141,28 +141,35 @@ class PluginManager:
|
||||
failed_plugins = []
|
||||
for name, plugincls in self.plugins.items():
|
||||
if plugincls.enabled:
|
||||
if name not in self.instances:
|
||||
try:
|
||||
instance = plugincls()
|
||||
except Exception as e:
|
||||
logger.warn("Failed to init %s, diabled. %s" % (name, e))
|
||||
self.disable_plugin(name)
|
||||
failed_plugins.append(name)
|
||||
continue
|
||||
self.instances[name] = instance
|
||||
for event in instance.handlers:
|
||||
if event not in self.listening_plugins:
|
||||
self.listening_plugins[event] = []
|
||||
self.listening_plugins[event].append(name)
|
||||
if 'GODCMD' in self.instances and name == 'GODCMD':
|
||||
continue
|
||||
# if name not in self.instances:
|
||||
try:
|
||||
instance = plugincls()
|
||||
except Exception as e:
|
||||
logger.warn("Failed to init %s, diabled. %s" % (name, e))
|
||||
self.disable_plugin(name)
|
||||
failed_plugins.append(name)
|
||||
continue
|
||||
if name in self.instances:
|
||||
self.instances[name].handlers.clear()
|
||||
self.instances[name] = instance
|
||||
for event in instance.handlers:
|
||||
if event not in self.listening_plugins:
|
||||
self.listening_plugins[event] = []
|
||||
self.listening_plugins[event].append(name)
|
||||
self.refresh_order()
|
||||
return failed_plugins
|
||||
|
||||
def reload_plugin(self, name: str):
|
||||
name = name.upper()
|
||||
remove_plugin_config(name)
|
||||
if name in self.instances:
|
||||
for event in self.listening_plugins:
|
||||
if name in self.listening_plugins[event]:
|
||||
self.listening_plugins[event].remove(name)
|
||||
if name in self.instances:
|
||||
self.instances[name].handlers.clear()
|
||||
del self.instances[name]
|
||||
self.activate_plugins()
|
||||
return True
|
||||
|
||||
@@ -99,7 +99,8 @@ class Role(Plugin):
|
||||
if e_context["context"].type != ContextType.TEXT:
|
||||
return
|
||||
btype = Bridge().get_bot_type("chat")
|
||||
if btype not in [const.OPEN_AI, const.CHATGPT, const.CHATGPTONAZURE, const.LINKAI]:
|
||||
if btype not in [const.OPEN_AI, const.CHATGPT, const.CHATGPTONAZURE, const.QWEN_DASHSCOPE, const.XUNFEI, const.BAIDU, const.ZHIPU_AI, const.MOONSHOT, const.MiniMax, const.LINKAI,const.MODELSCOPE]:
|
||||
logger.debug(f'不支持的bot: {btype}')
|
||||
return
|
||||
bot = Bridge().get_bot("chat")
|
||||
content = e_context["context"].content[:]
|
||||
@@ -179,6 +180,7 @@ class Role(Plugin):
|
||||
e_context["reply"] = reply
|
||||
e_context.action = EventAction.BREAK_PASS
|
||||
else:
|
||||
e_context["context"]["generate_breaked_by"] = EventAction.BREAK
|
||||
prompt = self.roleplays[sessionid].action(content)
|
||||
e_context["context"].type = ContextType.TEXT
|
||||
e_context["context"].content = prompt
|
||||
|
||||
@@ -12,13 +12,29 @@
|
||||
"url": "https://github.com/lanvent/plugin_summary.git",
|
||||
"desc": "总结聊天记录的插件"
|
||||
},
|
||||
"timetask": {
|
||||
"url": "https://github.com/haikerapples/timetask.git",
|
||||
"desc": "一款定时任务系统的插件"
|
||||
},
|
||||
"Apilot": {
|
||||
"url": "https://github.com/6vision/Apilot.git",
|
||||
"desc": "通过api直接查询早报、热榜、快递、天气等实用信息的插件"
|
||||
},
|
||||
"pictureChange": {
|
||||
"url": "https://github.com/Yanyutin753/pictureChange.git",
|
||||
"desc": "1. 支持百度AI和Stable Diffusion WebUI进行图像处理,提供多种模型选择,支持图生图、文生图自定义模板。2. 支持Suno音乐AI可将图像和文字转为音乐。3. 支持自定义模型进行文件、图片总结功能。4. 支持管理员控制群聊内容与参数和功能改变。"
|
||||
},
|
||||
"Blackroom": {
|
||||
"url": "https://github.com/dividduang/blackroom.git",
|
||||
"desc": "小黑屋插件,被拉进小黑屋的人将不能使用@bot的功能的插件"
|
||||
},
|
||||
"midjourney": {
|
||||
"url": "https://github.com/baojingyu/midjourney.git",
|
||||
"desc": "利用midjourney实现ai绘图的的插件"
|
||||
},
|
||||
"solitaire": {
|
||||
"url": "https://github.com/Wang-zhechao/solitaire.git",
|
||||
"desc": "机器人微信接龙插件"
|
||||
},
|
||||
"HighSpeedTicket": {
|
||||
"url": "https://github.com/He0607/HighSpeedTicket.git",
|
||||
"desc": "高铁(火车)票查询插件"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
{
|
||||
"tools": [
|
||||
"python",
|
||||
"url-get",
|
||||
"terminal",
|
||||
"meteo"
|
||||
],
|
||||
"kwargs": {
|
||||
|
||||
@@ -22,11 +22,13 @@ class Tool(Plugin):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context
|
||||
|
||||
self.app = self._reset_app()
|
||||
|
||||
if not self.tool_config.get("tools"):
|
||||
logger.warn("[tool] init failed, ignore ")
|
||||
raise Exception("config.json not found")
|
||||
logger.info("[tool] inited")
|
||||
|
||||
|
||||
def get_help_text(self, verbose=False, **kwargs):
|
||||
help_text = "这是一个能让chatgpt联网,搜索,数字运算的插件,将赋予强大且丰富的扩展能力。"
|
||||
trigger_prefix = conf().get("plugin_trigger_prefix", "$")
|
||||
@@ -137,7 +139,7 @@ class Tool(Plugin):
|
||||
|
||||
return {
|
||||
# 全局配置相关
|
||||
"log": True, # tool 日志开关
|
||||
"log": False, # tool 日志开关
|
||||
"debug": kwargs.get("debug", False), # 输出更多日志
|
||||
"no_default": kwargs.get("no_default", False), # 不要默认的工具,只加载自己导入的工具
|
||||
"think_depth": kwargs.get("think_depth", 2), # 一个问题最多使用多少次工具
|
||||
|
||||
@@ -7,8 +7,10 @@ gTTS>=2.3.1 # google text to speech
|
||||
pyttsx3>=2.90 # pytsx text to speech
|
||||
baidu_aip>=4.16.10 # baidu voice
|
||||
azure-cognitiveservices-speech # azure voice
|
||||
edge-tts # edge-tts
|
||||
numpy<=1.24.2
|
||||
langid # language detect
|
||||
elevenlabs==1.0.3 # elevenlabs TTS
|
||||
|
||||
#install plugin
|
||||
dulwich
|
||||
@@ -25,6 +27,8 @@ websocket-client==1.2.0
|
||||
|
||||
# claude bot
|
||||
curl_cffi
|
||||
# claude API
|
||||
anthropic
|
||||
|
||||
# tongyi qwen
|
||||
broadscope_bailian
|
||||
@@ -32,11 +36,14 @@ broadscope_bailian
|
||||
# google
|
||||
google-generativeai
|
||||
|
||||
# linkai
|
||||
linkai
|
||||
|
||||
# dingtalk
|
||||
dingtalk_stream
|
||||
|
||||
# zhipuai
|
||||
zhipuai>=2.0.1
|
||||
|
||||
# tongyi qwen new sdk
|
||||
dashscope
|
||||
|
||||
# tencentcloud sdk
|
||||
tencentcloud-sdk-python>=3.0.0
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
openai==0.27.8
|
||||
HTMLParser>=0.0.2
|
||||
PyQRCode>=1.2.1
|
||||
qrcode>=7.4.2
|
||||
PyQRCode==1.2.1
|
||||
qrcode==7.4.2
|
||||
requests>=2.28.2
|
||||
chardet>=5.1.0
|
||||
Pillow
|
||||
pre-commit
|
||||
web.py
|
||||
linkai>=0.0.6.0
|
||||
agentmesh-sdk>=0.1.3
|
||||
|
||||
192
run.sh
Normal file
@@ -0,0 +1,192 @@
|
||||
#!/usr/bin/env bash
|
||||
set -e
|
||||
|
||||
# 颜色定义
|
||||
RED='\033[0;31m' # 红色
|
||||
GREEN='\033[0;32m' # 绿色
|
||||
YELLOW='\033[0;33m' # 黄色
|
||||
BLUE='\033[0;34m' # 蓝色
|
||||
NC='\033[0m' # 无颜色
|
||||
|
||||
# 获取当前脚本的目录
|
||||
export BASE_DIR=$(cd "$(dirname "$0")"; pwd)
|
||||
echo -e "${GREEN}📁 BASE_DIR: ${BASE_DIR}${NC}"
|
||||
|
||||
# 检查 config.json 文件是否存在
|
||||
check_config_file() {
|
||||
if [ ! -f "${BASE_DIR}/config.json" ]; then
|
||||
echo -e "${RED}❌ 错误:未找到 config.json 文件。请确保 config.json 存在于当前目录。${NC}"
|
||||
exit 1
|
||||
fi
|
||||
}
|
||||
|
||||
# 检查 Python 版本是否大于等于 3.7,并检查 pip 是否可用
|
||||
check_python_version() {
|
||||
if ! command -v python3 &> /dev/null; then
|
||||
echo -e "${RED}❌ 错误:未找到 Python3。请安装 Python 3.7 或以上版本。${NC}"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
PYTHON_VERSION=$(python3 -c 'import sys; print(f"{sys.version_info.major}.{sys.version_info.minor}")')
|
||||
PYTHON_MAJOR=$(echo "$PYTHON_VERSION" | cut -d'.' -f1)
|
||||
PYTHON_MINOR=$(echo "$PYTHON_VERSION" | cut -d'.' -f2)
|
||||
|
||||
if (( PYTHON_MAJOR < 3 || (PYTHON_MAJOR == 3 && PYTHON_MINOR < 7) )); then
|
||||
echo -e "${RED}❌ 错误:Python 版本为 ${PYTHON_VERSION}。请安装 Python 3.7 或以上版本。${NC}"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if ! python3 -m pip --version &> /dev/null; then
|
||||
echo -e "${RED}❌ 错误:未找到 pip。请安装 pip。${NC}"
|
||||
exit 1
|
||||
fi
|
||||
}
|
||||
|
||||
# 检查并安装缺失的依赖
|
||||
install_dependencies() {
|
||||
echo -e "${YELLOW}⏳ 正在安装依赖...${NC}"
|
||||
|
||||
if [ ! -f "${BASE_DIR}/requirements.txt" ]; then
|
||||
echo -e "${RED}❌ 错误:未找到 requirements.txt 文件。${NC}"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# 安装 requirements.txt 中的依赖,使用清华大学的 PyPI 镜像
|
||||
pip3 install -r "${BASE_DIR}/requirements.txt" -i https://pypi.tuna.tsinghua.edu.cn/simple
|
||||
|
||||
# 处理 requirements-optional.txt(如果存在)
|
||||
if [ -f "${BASE_DIR}/requirements-optional.txt" ]; then
|
||||
echo -e "${YELLOW}⏳ 正在安装可选的依赖...${NC}"
|
||||
pip3 install -r "${BASE_DIR}/requirements-optional.txt" -i https://pypi.tuna.tsinghua.edu.cn/simple
|
||||
fi
|
||||
}
|
||||
|
||||
# 启动项目
|
||||
run_project() {
|
||||
echo -e "${GREEN}🚀 准备启动项目...${NC}"
|
||||
cd "${BASE_DIR}"
|
||||
sleep 2
|
||||
|
||||
|
||||
# 判断操作系统类型
|
||||
OS_TYPE=$(uname)
|
||||
|
||||
if [[ "$OS_TYPE" == "Linux" ]]; then
|
||||
# 在 Linux 上使用 setsid
|
||||
setsid python3 "${BASE_DIR}/app.py" > "${BASE_DIR}/nohup.out" 2>&1 &
|
||||
echo -e "${GREEN}🚀 正在启动 ChatGPT-on-WeChat (Linux)...${NC}"
|
||||
elif [[ "$OS_TYPE" == "Darwin" ]]; then
|
||||
# 在 macOS 上直接运行
|
||||
python3 "${BASE_DIR}/app.py" > "${BASE_DIR}/nohup.out" 2>&1 &
|
||||
echo -e "${GREEN}🚀 正在启动 ChatGPT-on-WeChat (macOS)...${NC}"
|
||||
else
|
||||
echo -e "${RED}❌ 错误:不支持的操作系统 ${OS_TYPE}。${NC}"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
sleep 2
|
||||
# 显示日志输出,供用户扫码
|
||||
tail -n 30 -f "${BASE_DIR}/nohup.out"
|
||||
|
||||
}
|
||||
# 更新项目
|
||||
update_project() {
|
||||
echo -e "${GREEN}🔄 准备更新项目,现在停止项目...${NC}"
|
||||
cd "${BASE_DIR}"
|
||||
|
||||
# 停止项目
|
||||
stop_project
|
||||
echo -e "${GREEN}🔄 开始更新项目...${NC}"
|
||||
# 更新代码,从 git 仓库拉取最新代码
|
||||
if [ -d .git ]; then
|
||||
GIT_PULL_OUTPUT=$(git pull)
|
||||
if [ $? -eq 0 ]; then
|
||||
if [[ "$GIT_PULL_OUTPUT" == *"Already up to date."* ]]; then
|
||||
echo -e "${GREEN}✅ 代码已经是最新的。${NC}"
|
||||
else
|
||||
echo -e "${GREEN}✅ 代码更新完成。${NC}"
|
||||
fi
|
||||
else
|
||||
echo -e "${YELLOW}⚠️ 从 GitHub 更新失败,尝试切换到 Gitee 仓库...${NC}"
|
||||
# 更改远程仓库为 Gitee
|
||||
git remote set-url origin https://gitee.com/zhayujie/chatgpt-on-wechat.git
|
||||
GIT_PULL_OUTPUT=$(git pull)
|
||||
if [ $? -eq 0 ]; then
|
||||
if [[ "$GIT_PULL_OUTPUT" == *"Already up to date."* ]]; then
|
||||
echo -e "${GREEN}✅ 代码已经是最新的。${NC}"
|
||||
else
|
||||
echo -e "${GREEN}✅ 从 Gitee 更新成功。${NC}"
|
||||
fi
|
||||
else
|
||||
echo -e "${RED}❌ 错误:从 Gitee 更新仍然失败,请检查网络连接。${NC}"
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
else
|
||||
echo -e "${RED}❌ 错误:当前目录不是 git 仓库,无法更新代码。${NC}"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# 安装依赖
|
||||
install_dependencies
|
||||
|
||||
# 启动项目
|
||||
run_project
|
||||
}
|
||||
|
||||
# 停止项目
|
||||
stop_project() {
|
||||
echo -e "${GREEN}🛑 正在停止项目...${NC}"
|
||||
cd "${BASE_DIR}"
|
||||
pid=$(ps ax | grep -i app.py | grep "${BASE_DIR}" | grep python3 | grep -v grep | awk '{print $1}')
|
||||
if [ -z "$pid" ] ; then
|
||||
echo -e "${YELLOW}⚠️ 未找到正在运行的 ChatGPT-on-WeChat。${NC}"
|
||||
return
|
||||
fi
|
||||
|
||||
echo -e "${GREEN}🛑 正在运行的 ChatGPT-on-WeChat (PID: ${pid})${NC}"
|
||||
|
||||
kill ${pid}
|
||||
sleep 3
|
||||
|
||||
if ps -p $pid > /dev/null; then
|
||||
echo -e "${YELLOW}⚠️ 进程未停止,尝试强制终止...${NC}"
|
||||
kill -9 ${pid}
|
||||
fi
|
||||
|
||||
echo -e "${GREEN}✅ 已停止 ChatGPT-on-WeChat (PID: ${pid})${NC}"
|
||||
}
|
||||
|
||||
# 主函数,根据用户参数执行操作
|
||||
case "$1" in
|
||||
start)
|
||||
check_config_file
|
||||
check_python_version
|
||||
run_project
|
||||
;;
|
||||
stop)
|
||||
stop_project
|
||||
;;
|
||||
restart)
|
||||
stop_project
|
||||
check_config_file
|
||||
check_python_version
|
||||
run_project
|
||||
;;
|
||||
update)
|
||||
check_config_file
|
||||
check_python_version
|
||||
update_project
|
||||
;;
|
||||
*)
|
||||
echo -e "${YELLOW}=========================================${NC}"
|
||||
echo -e "${YELLOW}用法:${GREEN}$0 ${BLUE}{start|stop|restart|update}${NC}"
|
||||
echo -e "${YELLOW}示例:${NC}"
|
||||
echo -e " ${GREEN}$0 ${BLUE}start${NC}"
|
||||
echo -e " ${GREEN}$0 ${BLUE}stop${NC}"
|
||||
echo -e " ${GREEN}$0 ${BLUE}restart${NC}"
|
||||
echo -e " ${GREEN}$0 ${BLUE}update${NC}"
|
||||
echo -e "${YELLOW}=========================================${NC}"
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
278
simple_login_form_test.html
Normal file
@@ -0,0 +1,278 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="zh-CN">
|
||||
<head>
|
||||
<meta charset="UTF-8" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1" />
|
||||
<title>登录</title>
|
||||
<style>
|
||||
/* Reset and base */
|
||||
* {
|
||||
box-sizing: border-box;
|
||||
}
|
||||
body, html {
|
||||
margin: 0; padding: 0; height: 100%;
|
||||
font-family: "Segoe UI", Tahoma, Geneva, Verdana, sans-serif;
|
||||
background: linear-gradient(135deg, #667eea, #764ba2);
|
||||
display: flex;
|
||||
justify-content: center;
|
||||
align-items: center;
|
||||
}
|
||||
.login-container {
|
||||
background: rgba(255, 255, 255, 0.95);
|
||||
padding: 2.5rem 3rem;
|
||||
border-radius: 12px;
|
||||
box-shadow: 0 8px 24px rgba(0,0,0,0.15);
|
||||
width: 100%;
|
||||
max-width: 400px;
|
||||
}
|
||||
h2 {
|
||||
margin-bottom: 1.5rem;
|
||||
color: #333;
|
||||
text-align: center;
|
||||
}
|
||||
form {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
}
|
||||
label {
|
||||
font-weight: 600;
|
||||
margin-bottom: 0.4rem;
|
||||
color: #444;
|
||||
}
|
||||
input[type="text"],
|
||||
input[type="email"],
|
||||
input[type="password"] {
|
||||
padding: 0.6rem 0.8rem;
|
||||
font-size: 1rem;
|
||||
border: 1.8px solid #ccc;
|
||||
border-radius: 6px;
|
||||
transition: border-color 0.3s ease;
|
||||
outline-offset: 2px;
|
||||
}
|
||||
input[type="text"]:focus,
|
||||
input[type="email"]:focus,
|
||||
input[type="password"]:focus {
|
||||
border-color: #667eea;
|
||||
}
|
||||
.password-wrapper {
|
||||
position: relative;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
}
|
||||
.toggle-password {
|
||||
position: absolute;
|
||||
right: 0.8rem;
|
||||
background: none;
|
||||
border: none;
|
||||
cursor: pointer;
|
||||
font-size: 1rem;
|
||||
color: #667eea;
|
||||
user-select: none;
|
||||
}
|
||||
.login-button {
|
||||
margin-top: 1.5rem;
|
||||
padding: 0.75rem;
|
||||
font-size: 1.1rem;
|
||||
font-weight: 700;
|
||||
background-color: #667eea;
|
||||
color: white;
|
||||
border: none;
|
||||
border-radius: 8px;
|
||||
cursor: pointer;
|
||||
transition: background-color 0.3s ease;
|
||||
}
|
||||
.login-button:disabled {
|
||||
background-color: #a3a9f7;
|
||||
cursor: not-allowed;
|
||||
}
|
||||
.forgot-password {
|
||||
margin-top: 1rem;
|
||||
text-align: right;
|
||||
}
|
||||
.forgot-password a {
|
||||
color: #667eea;
|
||||
text-decoration: none;
|
||||
font-size: 0.9rem;
|
||||
}
|
||||
.forgot-password a:hover {
|
||||
text-decoration: underline;
|
||||
}
|
||||
.error-message {
|
||||
margin-top: 1rem;
|
||||
color: #d93025;
|
||||
font-weight: 600;
|
||||
text-align: center;
|
||||
}
|
||||
.loading-spinner {
|
||||
border: 3px solid #f3f3f3;
|
||||
border-top: 3px solid #667eea;
|
||||
border-radius: 50%;
|
||||
width: 20px;
|
||||
height: 20px;
|
||||
animation: spin 1s linear infinite;
|
||||
display: inline-block;
|
||||
vertical-align: middle;
|
||||
margin-left: 8px;
|
||||
}
|
||||
@keyframes spin {
|
||||
0% { transform: rotate(0deg);}
|
||||
100% { transform: rotate(360deg);}
|
||||
}
|
||||
/* Responsive */
|
||||
@media (max-width: 480px) {
|
||||
.login-container {
|
||||
margin: 1rem;
|
||||
padding: 2rem 1.5rem;
|
||||
}
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="login-container" role="main" aria-label="登录表单">
|
||||
<h2>用户登录</h2>
|
||||
<form id="loginForm" novalidate>
|
||||
<label for="usernameEmail">用户名或邮箱</label>
|
||||
<input type="text" id="usernameEmail" name="usernameEmail" autocomplete="username" placeholder="请输入用户名或邮箱" required aria-describedby="usernameEmailError" />
|
||||
<div id="usernameEmailError" class="error-message" aria-live="polite"></div>
|
||||
|
||||
<label for="password" style="margin-top:1rem;">密码</label>
|
||||
<div class="password-wrapper">
|
||||
<input type="password" id="password" name="password" autocomplete="current-password" placeholder="请输入密码" required minlength="6" aria-describedby="passwordError" />
|
||||
<button type="button" class="toggle-password" aria-label="切换密码可见性" title="切换密码可见性">👁️</button>
|
||||
</div>
|
||||
<div id="passwordError" class="error-message" aria-live="polite"></div>
|
||||
|
||||
<button type="submit" id="loginButton" class="login-button" disabled>登录</button>
|
||||
<div class="forgot-password">
|
||||
<a href="/forgot-password.html" target="_blank" rel="noopener noreferrer">忘记密码?</a>
|
||||
</div>
|
||||
<div id="submitError" class="error-message" aria-live="polite"></div>
|
||||
</form>
|
||||
</div>
|
||||
|
||||
<script>
|
||||
(function(){
|
||||
const usernameEmailInput = document.getElementById('usernameEmail');
|
||||
const passwordInput = document.getElementById('password');
|
||||
const loginButton = document.getElementById('loginButton');
|
||||
const usernameEmailError = document.getElementById('usernameEmailError');
|
||||
const passwordError = document.getElementById('passwordError');
|
||||
const submitError = document.getElementById('submitError');
|
||||
const togglePasswordBtn = document.querySelector('.toggle-password');
|
||||
const form = document.getElementById('loginForm');
|
||||
|
||||
// 校验用户名或邮箱格式
|
||||
function validateUsernameEmail(value) {
|
||||
if (!value.trim()) {
|
||||
return "用户名或邮箱不能为空";
|
||||
}
|
||||
// 简单邮箱正则
|
||||
const emailRegex = /^[^\s@]+@[^\s@]+\.[^\s@]+$/;
|
||||
// 用户名规则:允许字母数字下划线,长度3-20
|
||||
const usernameRegex = /^[a-zA-Z0-9_]{3,20}$/;
|
||||
if (emailRegex.test(value)) {
|
||||
return "";
|
||||
} else if (usernameRegex.test(value)) {
|
||||
return "";
|
||||
} else {
|
||||
return "请输入有效的用户名或邮箱格式";
|
||||
}
|
||||
}
|
||||
|
||||
// 校验密码格式
|
||||
function validatePassword(value) {
|
||||
if (!value) {
|
||||
return "密码不能为空";
|
||||
}
|
||||
if (value.length < 6) {
|
||||
return "密码长度不能少于6位";
|
||||
}
|
||||
return "";
|
||||
}
|
||||
|
||||
// 实时校验并更新错误提示和按钮状态
|
||||
function validateForm() {
|
||||
const usernameEmailVal = usernameEmailInput.value;
|
||||
const passwordVal = passwordInput.value;
|
||||
|
||||
const usernameEmailErrMsg = validateUsernameEmail(usernameEmailVal);
|
||||
const passwordErrMsg = validatePassword(passwordVal);
|
||||
|
||||
usernameEmailError.textContent = usernameEmailErrMsg;
|
||||
passwordError.textContent = passwordErrMsg;
|
||||
submitError.textContent = "";
|
||||
|
||||
const isValid = !usernameEmailErrMsg && !passwordErrMsg;
|
||||
loginButton.disabled = !isValid;
|
||||
return isValid;
|
||||
}
|
||||
|
||||
// 密码可见切换
|
||||
togglePasswordBtn.addEventListener('click', () => {
|
||||
if (passwordInput.type === 'password') {
|
||||
passwordInput.type = 'text';
|
||||
togglePasswordBtn.textContent = '🙈';
|
||||
togglePasswordBtn.setAttribute('aria-label', '隐藏密码');
|
||||
togglePasswordBtn.setAttribute('title', '隐藏密码');
|
||||
} else {
|
||||
passwordInput.type = 'password';
|
||||
togglePasswordBtn.textContent = '👁️';
|
||||
togglePasswordBtn.setAttribute('aria-label', '显示密码');
|
||||
togglePasswordBtn.setAttribute('title', '显示密码');
|
||||
}
|
||||
});
|
||||
|
||||
// 监听输入事件实时校验
|
||||
usernameEmailInput.addEventListener('input', validateForm);
|
||||
passwordInput.addEventListener('input', validateForm);
|
||||
|
||||
// 模拟登录请求
|
||||
function fakeLoginRequest(data) {
|
||||
return new Promise((resolve, reject) => {
|
||||
setTimeout(() => {
|
||||
// 模拟用户名/邮箱为 "user" 或 "user@example.com" 且密码为 "password123" 才成功
|
||||
const validUsers = ["user", "user@example.com"];
|
||||
if (validUsers.includes(data.usernameEmail.toLowerCase()) && data.password === "password123") {
|
||||
resolve();
|
||||
} else {
|
||||
reject(new Error("用户名或密码错误"));
|
||||
}
|
||||
}, 1500);
|
||||
});
|
||||
}
|
||||
|
||||
// 表单提交处理
|
||||
form.addEventListener('submit', async (e) => {
|
||||
e.preventDefault();
|
||||
if (!validateForm()) return;
|
||||
|
||||
loginButton.disabled = true;
|
||||
const originalText = loginButton.textContent;
|
||||
loginButton.textContent = "登录中";
|
||||
const spinner = document.createElement('span');
|
||||
spinner.className = 'loading-spinner';
|
||||
loginButton.appendChild(spinner);
|
||||
submitError.textContent = "";
|
||||
|
||||
try {
|
||||
await fakeLoginRequest({
|
||||
usernameEmail: usernameEmailInput.value.trim(),
|
||||
password: passwordInput.value
|
||||
});
|
||||
// 登录成功跳转(此处用alert模拟)
|
||||
alert("登录成功,跳转到用户主页");
|
||||
// window.location.href = "/user-home.html"; // 实际跳转
|
||||
} catch (err) {
|
||||
submitError.textContent = err.message;
|
||||
} finally {
|
||||
loginButton.disabled = false;
|
||||
loginButton.textContent = originalText;
|
||||
}
|
||||
});
|
||||
|
||||
// 页面加载时校验一次,防止缓存值导致按钮状态异常
|
||||
validateForm();
|
||||
})();
|
||||
<\/script>
|
||||
<\/body>
|
||||
<\/html>
|
||||
@@ -8,6 +8,7 @@ Description:
|
||||
|
||||
"""
|
||||
|
||||
import http.client
|
||||
import json
|
||||
import time
|
||||
import requests
|
||||
@@ -61,6 +62,69 @@ def text_to_speech_aliyun(url, text, appkey, token):
|
||||
|
||||
return output_file
|
||||
|
||||
def speech_to_text_aliyun(url, audioContent, appkey, token):
|
||||
"""
|
||||
使用阿里云的语音识别服务识别音频文件中的语音。
|
||||
|
||||
参数:
|
||||
- url (str): 阿里云语音识别服务的端点URL。
|
||||
- audioContent (byte): pcm音频数据。
|
||||
- appkey (str): 您的阿里云appkey。
|
||||
- token (str): 阿里云API的认证令牌。
|
||||
|
||||
返回值:
|
||||
- str: 成功时输出识别到的文本,否则为None。
|
||||
"""
|
||||
format = 'pcm'
|
||||
sample_rate = 16000
|
||||
enablePunctuationPrediction = True
|
||||
enableInverseTextNormalization = True
|
||||
enableVoiceDetection = False
|
||||
|
||||
# 设置RESTful请求参数
|
||||
request = url + '?appkey=' + appkey
|
||||
request = request + '&format=' + format
|
||||
request = request + '&sample_rate=' + str(sample_rate)
|
||||
|
||||
if enablePunctuationPrediction :
|
||||
request = request + '&enable_punctuation_prediction=' + 'true'
|
||||
|
||||
if enableInverseTextNormalization :
|
||||
request = request + '&enable_inverse_text_normalization=' + 'true'
|
||||
|
||||
if enableVoiceDetection :
|
||||
request = request + '&enable_voice_detection=' + 'true'
|
||||
|
||||
host = 'nls-gateway-cn-shanghai.aliyuncs.com'
|
||||
|
||||
# 设置HTTPS请求头部
|
||||
httpHeaders = {
|
||||
'X-NLS-Token': token,
|
||||
'Content-type': 'application/octet-stream',
|
||||
'Content-Length': len(audioContent)
|
||||
}
|
||||
|
||||
conn = http.client.HTTPSConnection(host)
|
||||
conn.request(method='POST', url=request, body=audioContent, headers=httpHeaders)
|
||||
|
||||
response = conn.getresponse()
|
||||
body = response.read()
|
||||
try:
|
||||
body = json.loads(body)
|
||||
status = body['status']
|
||||
if status == 20000000 :
|
||||
result = body['result']
|
||||
if result :
|
||||
logger.info(f"阿里云语音识别到了:{result}")
|
||||
conn.close()
|
||||
return result
|
||||
else :
|
||||
logger.error(f"语音识别失败,状态码: {status}")
|
||||
except ValueError:
|
||||
logger.error(f"语音识别失败,收到非JSON格式的数据: {body}")
|
||||
conn.close()
|
||||
return None
|
||||
|
||||
|
||||
class AliyunTokenGenerator:
|
||||
"""
|
||||
|
||||
@@ -15,9 +15,9 @@ import time
|
||||
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from common.log import logger
|
||||
from voice.audio_convert import get_pcm_from_wav
|
||||
from voice.voice import Voice
|
||||
from voice.ali.ali_api import AliyunTokenGenerator
|
||||
from voice.ali.ali_api import text_to_speech_aliyun
|
||||
from voice.ali.ali_api import AliyunTokenGenerator, speech_to_text_aliyun, text_to_speech_aliyun
|
||||
from config import conf
|
||||
|
||||
|
||||
@@ -34,7 +34,8 @@ class AliVoice(Voice):
|
||||
self.token = None
|
||||
self.token_expire_time = 0
|
||||
# 默认复用阿里云千问的 access_key 和 access_secret
|
||||
self.api_url = config.get("api_url")
|
||||
self.api_url_voice_to_text = config.get("api_url_voice_to_text")
|
||||
self.api_url_text_to_voice = config.get("api_url_text_to_voice")
|
||||
self.app_key = config.get("app_key")
|
||||
self.access_key_id = conf().get("qwen_access_key_id") or config.get("access_key_id")
|
||||
self.access_key_secret = conf().get("qwen_access_key_secret") or config.get("access_key_secret")
|
||||
@@ -53,7 +54,7 @@ class AliVoice(Voice):
|
||||
r'äöüÄÖÜáéíóúÁÉÍÓÚàèìòùÀÈÌÒÙâêîôûÂÊÎÔÛçÇñÑ,。!?,.]', '', text)
|
||||
# 提取有效的token
|
||||
token_id = self.get_valid_token()
|
||||
fileName = text_to_speech_aliyun(self.api_url, text, self.app_key, token_id)
|
||||
fileName = text_to_speech_aliyun(self.api_url_text_to_voice, text, self.app_key, token_id)
|
||||
if fileName:
|
||||
logger.info("[Ali] textToVoice text={} voice file name={}".format(text, fileName))
|
||||
reply = Reply(ReplyType.VOICE, fileName)
|
||||
@@ -61,6 +62,25 @@ class AliVoice(Voice):
|
||||
reply = Reply(ReplyType.ERROR, "抱歉,语音合成失败")
|
||||
return reply
|
||||
|
||||
def voiceToText(self, voice_file):
|
||||
"""
|
||||
将语音文件转换为文本。
|
||||
|
||||
:param voice_file: 要转换的语音文件。
|
||||
:return: 返回一个Reply对象,其中包含转换得到的文本或错误信息。
|
||||
"""
|
||||
# 提取有效的token
|
||||
token_id = self.get_valid_token()
|
||||
logger.debug("[Ali] voice file name={}".format(voice_file))
|
||||
pcm = get_pcm_from_wav(voice_file)
|
||||
text = speech_to_text_aliyun(self.api_url_voice_to_text, pcm, self.app_key, token_id)
|
||||
if text:
|
||||
logger.info("[Ali] VoicetoText = {}".format(text))
|
||||
reply = Reply(ReplyType.TEXT, text)
|
||||
else:
|
||||
reply = Reply(ReplyType.ERROR, "抱歉,语音识别失败")
|
||||
return reply
|
||||
|
||||
def get_valid_token(self):
|
||||
"""
|
||||
获取有效的阿里云token。
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
{
|
||||
"api_url": "https://nls-gateway-cn-shanghai.aliyuncs.com/stream/v1/tts",
|
||||
"api_url_text_to_voice": "https://nls-gateway-cn-shanghai.aliyuncs.com/stream/v1/tts",
|
||||
"api_url_voice_to_text": "https://nls-gateway.cn-shanghai.aliyuncs.com/stream/v1/asr",
|
||||
"app_key": "",
|
||||
"access_key_id": "",
|
||||
"access_key_secret": ""
|
||||
|
||||
@@ -6,7 +6,7 @@ from common.log import logger
|
||||
try:
|
||||
import pysilk
|
||||
except ImportError:
|
||||
logger.warn("import pysilk failed, wechaty voice message will not be supported.")
|
||||
logger.debug("import pysilk failed, wechaty voice message will not be supported.")
|
||||
|
||||
from pydub import AudioSegment
|
||||
|
||||
@@ -64,7 +64,9 @@ def any_to_wav(any_path, wav_path):
|
||||
if any_path.endswith(".sil") or any_path.endswith(".silk") or any_path.endswith(".slk"):
|
||||
return sil_to_wav(any_path, wav_path)
|
||||
audio = AudioSegment.from_file(any_path)
|
||||
audio.export(wav_path, format="wav")
|
||||
audio.set_frame_rate(8000) # 百度语音转写支持8000采样率, pcm_s16le, 单通道语音识别
|
||||
audio.set_channels(1)
|
||||
audio.export(wav_path, format="wav", codec='pcm_s16le')
|
||||
|
||||
|
||||
def any_to_sil(any_path, sil_path):
|
||||
|
||||
@@ -65,7 +65,7 @@ class AzureVoice(Voice):
|
||||
reply = Reply(ReplyType.TEXT, result.text)
|
||||
else:
|
||||
cancel_details = result.cancellation_details
|
||||
logger.error("[Azure] voiceToText error, result={}, errordetails={}".format(result, cancel_details.error_details))
|
||||
logger.error("[Azure] voiceToText error, result={}, errordetails={}".format(result, cancel_details))
|
||||
reply = Reply(ReplyType.ERROR, "抱歉,语音识别失败")
|
||||
return reply
|
||||
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
"""
|
||||
baidu voice service
|
||||
baidu voice service with thread-safe token caching
|
||||
"""
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
import threading
|
||||
import requests
|
||||
|
||||
from aip import AipSpeech
|
||||
|
||||
@@ -14,28 +16,13 @@ from config import conf
|
||||
from voice.audio_convert import get_pcm_from_wav
|
||||
from voice.voice import Voice
|
||||
|
||||
"""
|
||||
百度的语音识别API.
|
||||
dev_pid:
|
||||
- 1936: 普通话远场
|
||||
- 1536:普通话(支持简单的英文识别)
|
||||
- 1537:普通话(纯中文识别)
|
||||
- 1737:英语
|
||||
- 1637:粤语
|
||||
- 1837:四川话
|
||||
要使用本模块, 首先到 yuyin.baidu.com 注册一个开发者账号,
|
||||
之后创建一个新应用, 然后在应用管理的"查看key"中获得 API Key 和 Secret Key
|
||||
然后在 config.json 中填入这两个值, 以及 app_id, dev_pid
|
||||
"""
|
||||
|
||||
|
||||
class BaiduVoice(Voice):
|
||||
def __init__(self):
|
||||
try:
|
||||
# 读取本地 TTS 参数配置
|
||||
curdir = os.path.dirname(__file__)
|
||||
config_path = os.path.join(curdir, "config.json")
|
||||
bconf = None
|
||||
if not os.path.exists(config_path): # 如果没有配置文件,创建本地配置文件
|
||||
if not os.path.exists(config_path):
|
||||
bconf = {"lang": "zh", "ctp": 1, "spd": 5, "pit": 5, "vol": 5, "per": 0}
|
||||
with open(config_path, "w") as fw:
|
||||
json.dump(bconf, fw, indent=4)
|
||||
@@ -47,48 +34,139 @@ class BaiduVoice(Voice):
|
||||
self.api_key = str(conf().get("baidu_api_key"))
|
||||
self.secret_key = str(conf().get("baidu_secret_key"))
|
||||
self.dev_id = conf().get("baidu_dev_pid")
|
||||
self.lang = bconf["lang"]
|
||||
self.ctp = bconf["ctp"]
|
||||
self.spd = bconf["spd"]
|
||||
self.pit = bconf["pit"]
|
||||
self.vol = bconf["vol"]
|
||||
self.per = bconf["per"]
|
||||
|
||||
self.lang = bconf["lang"]
|
||||
self.ctp = bconf["ctp"]
|
||||
self.spd = bconf["spd"]
|
||||
self.pit = bconf["pit"]
|
||||
self.vol = bconf["vol"]
|
||||
self.per = bconf["per"]
|
||||
|
||||
# 百度 SDK 客户端(短文本合成 & 语音识别)
|
||||
self.client = AipSpeech(self.app_id, self.api_key, self.secret_key)
|
||||
|
||||
# access_token 缓存与锁
|
||||
self._access_token = None
|
||||
self._token_expire_ts = 0
|
||||
self._token_lock = threading.Lock()
|
||||
except Exception as e:
|
||||
logger.warn("BaiduVoice init failed: %s, ignore " % e)
|
||||
logger.warn("BaiduVoice init failed: %s, ignore" % e)
|
||||
|
||||
def _get_access_token(self):
|
||||
# 多线程安全获取 token
|
||||
with self._token_lock:
|
||||
now = time.time()
|
||||
if self._access_token and now < self._token_expire_ts:
|
||||
return self._access_token
|
||||
url = "https://aip.baidubce.com/oauth/2.0/token"
|
||||
params = {
|
||||
"grant_type": "client_credentials",
|
||||
"client_id": self.api_key,
|
||||
"client_secret": self.secret_key,
|
||||
}
|
||||
resp = requests.post(url, params=params).json()
|
||||
token = resp.get("access_token")
|
||||
expires_in = resp.get("expires_in", 2592000)
|
||||
if token:
|
||||
self._access_token = token
|
||||
self._token_expire_ts = now + expires_in - 60 # 提前 1 分钟过期
|
||||
return token
|
||||
else:
|
||||
logger.error("BaiduVoice _get_access_token failed: %s", resp)
|
||||
return None
|
||||
|
||||
def voiceToText(self, voice_file):
|
||||
# 识别本地文件
|
||||
logger.debug("[Baidu] voice file name={}".format(voice_file))
|
||||
logger.debug("[Baidu] recognize voice file=%s", voice_file)
|
||||
pcm = get_pcm_from_wav(voice_file)
|
||||
res = self.client.asr(pcm, "pcm", 16000, {"dev_pid": self.dev_id})
|
||||
if res["err_no"] == 0:
|
||||
logger.info("百度语音识别到了:{}".format(res["result"]))
|
||||
if res.get("err_no") == 0:
|
||||
text = "".join(res["result"])
|
||||
reply = Reply(ReplyType.TEXT, text)
|
||||
logger.info("[Baidu] ASR result: %s", text)
|
||||
return Reply(ReplyType.TEXT, text)
|
||||
else:
|
||||
logger.info("百度语音识别出错了: {}".format(res["err_msg"]))
|
||||
if res["err_msg"] == "request pv too much":
|
||||
logger.info(" 出现这个原因很可能是你的百度语音服务调用量超出限制,或未开通付费")
|
||||
reply = Reply(ReplyType.ERROR, "百度语音识别出错了;{0}".format(res["err_msg"]))
|
||||
return reply
|
||||
err = res.get("err_msg", "")
|
||||
logger.error("[Baidu] ASR error: %s", err)
|
||||
return Reply(ReplyType.ERROR, f"语音识别失败:{err}")
|
||||
|
||||
def _long_text_synthesis(self, text):
|
||||
token = self._get_access_token()
|
||||
if not token:
|
||||
return Reply(ReplyType.ERROR, "获取百度 access_token 失败")
|
||||
|
||||
# 创建合成任务
|
||||
create_url = f"https://aip.baidubce.com/rpc/2.0/tts/v1/create?access_token={token}"
|
||||
payload = {
|
||||
"text": text,
|
||||
"format": "mp3-16k",
|
||||
"voice": 0,
|
||||
"lang": self.lang,
|
||||
"speed": self.spd,
|
||||
"pitch": self.pit,
|
||||
"volume": self.vol,
|
||||
"enable_subtitle": 0,
|
||||
}
|
||||
headers = {"Content-Type": "application/json"}
|
||||
create_resp = requests.post(create_url, headers=headers, json=payload).json()
|
||||
task_id = create_resp.get("task_id")
|
||||
if not task_id:
|
||||
logger.error("[Baidu] 长文本合成创建任务失败: %s", create_resp)
|
||||
return Reply(ReplyType.ERROR, "长文本合成任务提交失败")
|
||||
logger.info("[Baidu] 长文本合成任务已提交 task_id=%s", task_id)
|
||||
|
||||
# 轮询查询任务状态
|
||||
query_url = f"https://aip.baidubce.com/rpc/2.0/tts/v1/query?access_token={token}"
|
||||
for _ in range(100):
|
||||
time.sleep(3)
|
||||
resp = requests.post(query_url, headers=headers, json={"task_ids":[task_id]})
|
||||
result = resp.json()
|
||||
infos = result.get("tasks_info") or result.get("tasks") or []
|
||||
if not infos:
|
||||
continue
|
||||
info = infos[0]
|
||||
status = info.get("task_status")
|
||||
if status == "Success":
|
||||
task_res = info.get("task_result", {})
|
||||
audio_url = task_res.get("audio_address") or task_res.get("speech_url")
|
||||
break
|
||||
elif status == "Running":
|
||||
continue
|
||||
else:
|
||||
logger.error("[Baidu] 长文本合成失败: %s", info)
|
||||
return Reply(ReplyType.ERROR, "长文本合成执行失败")
|
||||
else:
|
||||
return Reply(ReplyType.ERROR, "长文本合成超时,请稍后重试")
|
||||
|
||||
# 下载并保存音频
|
||||
audio_data = requests.get(audio_url).content
|
||||
fn = TmpDir().path() + f"reply-long-{int(time.time())}-{hash(text)&0x7FFFFFFF}.mp3"
|
||||
with open(fn, "wb") as f:
|
||||
f.write(audio_data)
|
||||
logger.info("[Baidu] 长文本合成 success: %s", fn)
|
||||
return Reply(ReplyType.VOICE, fn)
|
||||
|
||||
def textToVoice(self, text):
|
||||
result = self.client.synthesis(
|
||||
text,
|
||||
self.lang,
|
||||
self.ctp,
|
||||
{"spd": self.spd, "pit": self.pit, "vol": self.vol, "per": self.per},
|
||||
)
|
||||
if not isinstance(result, dict):
|
||||
# Avoid the same filename under multithreading
|
||||
fileName = TmpDir().path() + "reply-" + str(int(time.time())) + "-" + str(hash(text) & 0x7FFFFFFF) + ".mp3"
|
||||
with open(fileName, "wb") as f:
|
||||
f.write(result)
|
||||
logger.info("[Baidu] textToVoice text={} voice file name={}".format(text, fileName))
|
||||
reply = Reply(ReplyType.VOICE, fileName)
|
||||
else:
|
||||
logger.error("[Baidu] textToVoice error={}".format(result))
|
||||
reply = Reply(ReplyType.ERROR, "抱歉,语音合成失败")
|
||||
return reply
|
||||
try:
|
||||
# GBK 编码字节长度
|
||||
gbk_len = len(text.encode("gbk", errors="ignore"))
|
||||
if gbk_len <= 1024:
|
||||
# 短文本走 SDK 合成
|
||||
result = self.client.synthesis(
|
||||
text, self.lang, self.ctp,
|
||||
{"spd":self.spd, "pit":self.pit, "vol":self.vol, "per":self.per}
|
||||
)
|
||||
if not isinstance(result, dict):
|
||||
fn = TmpDir().path() + f"reply-{int(time.time())}-{hash(text)&0x7FFFFFFF}.mp3"
|
||||
with open(fn, "wb") as f:
|
||||
f.write(result)
|
||||
logger.info("[Baidu] 短文本合成 success: %s", fn)
|
||||
return Reply(ReplyType.VOICE, fn)
|
||||
else:
|
||||
logger.error("[Baidu] 短文本合成 error: %s", result)
|
||||
return Reply(ReplyType.ERROR, "短文本语音合成失败")
|
||||
else:
|
||||
# 长文本
|
||||
return self._long_text_synthesis(text)
|
||||
except Exception as e:
|
||||
logger.error("BaiduVoice textToVoice exception: %s", e)
|
||||
return Reply(ReplyType.ERROR, f"合成异常:{e}")
|
||||
|
||||
|
||||
50
voice/edge/edge_voice.py
Normal file
@@ -0,0 +1,50 @@
|
||||
import time
|
||||
|
||||
import edge_tts
|
||||
import asyncio
|
||||
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from common.log import logger
|
||||
from common.tmp_dir import TmpDir
|
||||
from voice.voice import Voice
|
||||
|
||||
|
||||
class EdgeVoice(Voice):
|
||||
|
||||
def __init__(self):
|
||||
'''
|
||||
# 普通话
|
||||
zh-CN-XiaoxiaoNeural
|
||||
zh-CN-XiaoyiNeural
|
||||
zh-CN-YunjianNeural
|
||||
zh-CN-YunxiNeural
|
||||
zh-CN-YunxiaNeural
|
||||
zh-CN-YunyangNeural
|
||||
# 地方口音
|
||||
zh-CN-liaoning-XiaobeiNeural
|
||||
zh-CN-shaanxi-XiaoniNeural
|
||||
# 粤语
|
||||
zh-HK-HiuGaaiNeural
|
||||
zh-HK-HiuMaanNeural
|
||||
zh-HK-WanLungNeural
|
||||
# 湾湾腔
|
||||
zh-TW-HsiaoChenNeural
|
||||
zh-TW-HsiaoYuNeural
|
||||
zh-TW-YunJheNeural
|
||||
'''
|
||||
self.voice = "zh-CN-YunjianNeural"
|
||||
|
||||
def voiceToText(self, voice_file):
|
||||
pass
|
||||
|
||||
async def gen_voice(self, text, fileName):
|
||||
communicate = edge_tts.Communicate(text, self.voice)
|
||||
await communicate.save(fileName)
|
||||
|
||||
def textToVoice(self, text):
|
||||
fileName = TmpDir().path() + "reply-" + str(int(time.time())) + "-" + str(hash(text) & 0x7FFFFFFF) + ".mp3"
|
||||
|
||||
asyncio.run(self.gen_voice(text, fileName))
|
||||
|
||||
logger.info("[EdgeTTS] textToVoice text={} voice file name={}".format(text, fileName))
|
||||
return Reply(ReplyType.VOICE, fileName)
|
||||
@@ -1,7 +1,7 @@
|
||||
import time
|
||||
|
||||
from elevenlabs import set_api_key,generate
|
||||
|
||||
from elevenlabs.client import ElevenLabs
|
||||
from elevenlabs import save
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from common.log import logger
|
||||
from common.tmp_dir import TmpDir
|
||||
@@ -9,7 +9,7 @@ from voice.voice import Voice
|
||||
from config import conf
|
||||
|
||||
XI_API_KEY = conf().get("xi_api_key")
|
||||
set_api_key(XI_API_KEY)
|
||||
client = ElevenLabs(api_key=XI_API_KEY)
|
||||
name = conf().get("xi_voice_id")
|
||||
|
||||
class ElevenLabsVoice(Voice):
|
||||
@@ -21,13 +21,12 @@ class ElevenLabsVoice(Voice):
|
||||
pass
|
||||
|
||||
def textToVoice(self, text):
|
||||
audio = generate(
|
||||
audio = client.generate(
|
||||
text=text,
|
||||
voice=name,
|
||||
model='eleven_multilingual_v1'
|
||||
model='eleven_multilingual_v2'
|
||||
)
|
||||
fileName = TmpDir().path() + "reply-" + str(int(time.time())) + "-" + str(hash(text) & 0x7FFFFFFF) + ".mp3"
|
||||
with open(fileName, "wb") as f:
|
||||
f.write(audio)
|
||||
save(audio, fileName)
|
||||
logger.info("[ElevenLabs] textToVoice text={} voice file name={}".format(text, fileName))
|
||||
return Reply(ReplyType.VOICE, fileName)
|
||||
@@ -42,4 +42,16 @@ def create_voice(voice_type):
|
||||
from voice.ali.ali_voice import AliVoice
|
||||
|
||||
return AliVoice()
|
||||
elif voice_type == "edge":
|
||||
from voice.edge.edge_voice import EdgeVoice
|
||||
|
||||
return EdgeVoice()
|
||||
elif voice_type == "xunfei":
|
||||
from voice.xunfei.xunfei_voice import XunfeiVoice
|
||||
|
||||
return XunfeiVoice()
|
||||
elif voice_type == "tencent":
|
||||
from voice.tencent.tencent_voice import TencentVoice
|
||||
|
||||
return TencentVoice()
|
||||
raise RuntimeError
|
||||
|
||||
@@ -19,7 +19,7 @@ class LinkAIVoice(Voice):
|
||||
def voiceToText(self, voice_file):
|
||||
logger.debug("[LinkVoice] voice file name={}".format(voice_file))
|
||||
try:
|
||||
url = conf().get("linkai_api_base", "https://api.link-ai.chat") + "/v1/audio/transcriptions"
|
||||
url = conf().get("linkai_api_base", "https://api.link-ai.tech") + "/v1/audio/transcriptions"
|
||||
headers = {"Authorization": "Bearer " + conf().get("linkai_api_key")}
|
||||
model = None
|
||||
if not conf().get("text_to_voice") or conf().get("voice_to_text") == "openai":
|
||||
@@ -54,7 +54,7 @@ class LinkAIVoice(Voice):
|
||||
|
||||
def textToVoice(self, text):
|
||||
try:
|
||||
url = conf().get("linkai_api_base", "https://api.link-ai.chat") + "/v1/audio/speech"
|
||||
url = conf().get("linkai_api_base", "https://api.link-ai.tech") + "/v1/audio/speech"
|
||||
headers = {"Authorization": "Bearer " + conf().get("linkai_api_key")}
|
||||
model = const.TTS_1
|
||||
if not conf().get("text_to_voice") or conf().get("text_to_voice") in ["openai", const.TTS_1, const.TTS_1_HD]:
|
||||
|
||||
@@ -21,8 +21,21 @@ class OpenaiVoice(Voice):
|
||||
logger.debug("[Openai] voice file name={}".format(voice_file))
|
||||
try:
|
||||
file = open(voice_file, "rb")
|
||||
result = openai.Audio.transcribe("whisper-1", file)
|
||||
text = result["text"]
|
||||
api_base = conf().get("open_ai_api_base") or "https://api.openai.com/v1"
|
||||
url = f'{api_base}/audio/transcriptions'
|
||||
headers = {
|
||||
'Authorization': 'Bearer ' + conf().get("open_ai_api_key"),
|
||||
# 'Content-Type': 'multipart/form-data' # 加了会报错,不知道什么原因
|
||||
}
|
||||
files = {
|
||||
"file": file,
|
||||
}
|
||||
data = {
|
||||
"model": "whisper-1",
|
||||
}
|
||||
response = requests.post(url, headers=headers, files=files, data=data)
|
||||
response_data = response.json()
|
||||
text = response_data['text']
|
||||
reply = Reply(ReplyType.TEXT, text)
|
||||
logger.info("[Openai] voiceToText text={} voice file name={}".format(text, voice_file))
|
||||
except Exception as e:
|
||||
|
||||
5
voice/tencent/config.json.template
Normal file
@@ -0,0 +1,5 @@
|
||||
{
|
||||
"voice_type": 1003,
|
||||
"secret_id": "YOUR_SECRET_ID",
|
||||
"secret_key": "YOUR_SECRET_KEY"
|
||||
}
|
||||
119
voice/tencent/tencent_voice.py
Normal file
@@ -0,0 +1,119 @@
|
||||
import json
|
||||
import base64
|
||||
import os
|
||||
import time
|
||||
from voice.voice import Voice
|
||||
from common.log import logger
|
||||
from tencentcloud.common import credential
|
||||
from tencentcloud.asr.v20190614 import asr_client, models as asr_models
|
||||
from tencentcloud.tts.v20190823 import tts_client, models as tts_models
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from common.tmp_dir import TmpDir
|
||||
|
||||
class TencentVoice(Voice):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.secret_id = None
|
||||
self.secret_key = None
|
||||
self.voice_type = 1003
|
||||
self._load_config()
|
||||
|
||||
def _load_config(self):
|
||||
"""
|
||||
从本地配置文件加载配置
|
||||
"""
|
||||
try:
|
||||
config_path = os.path.join(os.path.dirname(__file__), 'config.json')
|
||||
with open(config_path, 'r') as f:
|
||||
config = json.load(f)
|
||||
self.secret_id = config.get('secret_id')
|
||||
self.secret_key = config.get('secret_key')
|
||||
self.voice_type = config.get('voice_type', self.voice_type)
|
||||
if not self.secret_id or not self.secret_key:
|
||||
logger.error("[Tencent] Missing credentials in config.json")
|
||||
except Exception as e:
|
||||
logger.error(f"[Tencent] Failed to load config: {e}")
|
||||
|
||||
def setup(self, config):
|
||||
"""
|
||||
设置配置信息(保留此方法用于向后兼容)
|
||||
"""
|
||||
pass
|
||||
|
||||
def voiceToText(self, voice_file):
|
||||
"""
|
||||
将语音文件转换为文本
|
||||
"""
|
||||
try:
|
||||
# 实例化认证对象
|
||||
cred = credential.Credential(self.secret_id, self.secret_key)
|
||||
|
||||
# 实例化客户端
|
||||
client = asr_client.AsrClient(cred, "ap-guangzhou")
|
||||
|
||||
# 读取音频文件
|
||||
with open(voice_file, 'rb') as f:
|
||||
audio_data = f.read()
|
||||
|
||||
# 进行base64编码
|
||||
base64_audio = base64.b64encode(audio_data).decode('utf-8')
|
||||
|
||||
# 构造请求对象
|
||||
req = asr_models.SentenceRecognitionRequest()
|
||||
req.ProjectId = 0
|
||||
req.SubServiceType = 2
|
||||
req.EngSerViceType = "16k_zh"
|
||||
req.SourceType = 1
|
||||
req.VoiceFormat = "wav"
|
||||
req.UsrAudioKey = "voice_recognition"
|
||||
req.Data = base64_audio
|
||||
|
||||
# 发起请求
|
||||
resp = client.SentenceRecognition(req)
|
||||
|
||||
# 解析结果
|
||||
if resp.Result:
|
||||
logger.info("[Tencent] Voice to text success: {}".format(resp.Result))
|
||||
return Reply(ReplyType.TEXT, resp.Result)
|
||||
else:
|
||||
logger.warning("[Tencent] Voice to text failed")
|
||||
return Reply(ReplyType.ERROR, "腾讯语音识别失败")
|
||||
|
||||
except Exception as e:
|
||||
logger.error("[Tencent] Voice to text error: {}".format(e))
|
||||
return Reply(ReplyType.ERROR, "腾讯语音识别出错:{}".format(str(e)))
|
||||
|
||||
def textToVoice(self, text):
|
||||
"""
|
||||
将文本转换为语音
|
||||
"""
|
||||
try:
|
||||
cred = credential.Credential(self.secret_id, self.secret_key)
|
||||
client = tts_client.TtsClient(cred, "ap-guangzhou")
|
||||
|
||||
req = tts_models.TextToVoiceRequest()
|
||||
req.Text = text
|
||||
req.SessionId = str(int(time.time()))
|
||||
req.Volume = 5
|
||||
req.Speed = 0
|
||||
req.ProjectId = 0
|
||||
req.ModelType = 1
|
||||
req.PrimaryLanguage = 1
|
||||
req.SampleRate = 16000
|
||||
req.VoiceType = self.voice_type # 客服女声
|
||||
|
||||
response = client.TextToVoice(req)
|
||||
|
||||
if response.Audio:
|
||||
fileName = TmpDir().path() + "reply-" + str(int(time.time())) + "-" + str(hash(text) & 0x7FFFFFFF) + ".mp3"
|
||||
with open(fileName, "wb") as f:
|
||||
f.write(base64.b64decode(response.Audio))
|
||||
logger.info("[Tencent] textToVoice text={} voice file name={}".format(text, fileName))
|
||||
return Reply(ReplyType.VOICE, fileName)
|
||||
else:
|
||||
logger.error("[Tencent] textToVoice failed")
|
||||
return Reply(ReplyType.ERROR, "腾讯语音合成失败")
|
||||
|
||||
except Exception as e:
|
||||
logger.error("[Tencent] Text to voice error: {}".format(e))
|
||||
return Reply(ReplyType.ERROR, "腾讯语音合成出错:{}".format(str(e)))
|
||||
7
voice/xunfei/config.json.template
Normal file
@@ -0,0 +1,7 @@
|
||||
{
|
||||
"APPID":"xxx71xxx",
|
||||
"APIKey":"xxxx69058exxxxxx",
|
||||
"APISecret":"xxxx697f0xxxxxx",
|
||||
"BusinessArgsTTS":{"aue": "lame", "sfl": 1, "auf": "audio/L16;rate=16000", "vcn": "xiaoyan", "tte": "utf8"},
|
||||
"BusinessArgsASR":{"domain": "iat", "language": "zh_cn", "accent": "mandarin", "vad_eos":10000, "dwa": "wpgs"}
|
||||
}
|
||||
209
voice/xunfei/xunfei_asr.py
Normal file
@@ -0,0 +1,209 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
#
|
||||
# Author: njnuko
|
||||
# Email: njnuko@163.com
|
||||
#
|
||||
# 这个文档是基于官方的demo来改的,固体官方demo文档请参考官网
|
||||
#
|
||||
# 语音听写流式 WebAPI 接口调用示例 接口文档(必看):https://doc.xfyun.cn/rest_api/语音听写(流式版).html
|
||||
# webapi 听写服务参考帖子(必看):http://bbs.xfyun.cn/forum.php?mod=viewthread&tid=38947&extra=
|
||||
# 语音听写流式WebAPI 服务,热词使用方式:登陆开放平台https://www.xfyun.cn/后,找到控制台--我的应用---语音听写(流式)---服务管理--个性化热词,
|
||||
# 设置热词
|
||||
# 注意:热词只能在识别的时候会增加热词的识别权重,需要注意的是增加相应词条的识别率,但并不是绝对的,具体效果以您测试为准。
|
||||
# 语音听写流式WebAPI 服务,方言试用方法:登陆开放平台https://www.xfyun.cn/后,找到控制台--我的应用---语音听写(流式)---服务管理--识别语种列表
|
||||
# 可添加语种或方言,添加后会显示该方言的参数值
|
||||
# 错误码链接:https://www.xfyun.cn/document/error-code (code返回错误码时必看)
|
||||
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
|
||||
|
||||
import websocket
|
||||
import datetime
|
||||
import hashlib
|
||||
import base64
|
||||
import hmac
|
||||
import json
|
||||
from urllib.parse import urlencode
|
||||
import time
|
||||
import ssl
|
||||
from wsgiref.handlers import format_date_time
|
||||
from datetime import datetime
|
||||
from time import mktime
|
||||
import _thread as thread
|
||||
import os
|
||||
import wave
|
||||
|
||||
|
||||
STATUS_FIRST_FRAME = 0 # 第一帧的标识
|
||||
STATUS_CONTINUE_FRAME = 1 # 中间帧标识
|
||||
STATUS_LAST_FRAME = 2 # 最后一帧的标识
|
||||
|
||||
#############
|
||||
#whole_dict 是用来存储返回值的,由于带语音修正,所以用dict来存储,有更新的化pop之前的值,最后再合并
|
||||
global whole_dict
|
||||
#这个文档是官方文档改的,这个参数是用来做函数调用时用的
|
||||
global wsParam
|
||||
##############
|
||||
|
||||
|
||||
class Ws_Param(object):
|
||||
# 初始化
|
||||
def __init__(self, APPID, APIKey, APISecret,BusinessArgs, AudioFile):
|
||||
self.APPID = APPID
|
||||
self.APIKey = APIKey
|
||||
self.APISecret = APISecret
|
||||
self.AudioFile = AudioFile
|
||||
self.BusinessArgs = BusinessArgs
|
||||
# 公共参数(common)
|
||||
self.CommonArgs = {"app_id": self.APPID}
|
||||
# 业务参数(business),更多个性化参数可在官网查看
|
||||
#self.BusinessArgs = {"domain": "iat", "language": "zh_cn", "accent": "mandarin", "vinfo":1,"vad_eos":10000}
|
||||
|
||||
# 生成url
|
||||
def create_url(self):
|
||||
url = 'wss://ws-api.xfyun.cn/v2/iat'
|
||||
# 生成RFC1123格式的时间戳
|
||||
now = datetime.now()
|
||||
date = format_date_time(mktime(now.timetuple()))
|
||||
|
||||
# 拼接字符串
|
||||
signature_origin = "host: " + "ws-api.xfyun.cn" + "\n"
|
||||
signature_origin += "date: " + date + "\n"
|
||||
signature_origin += "GET " + "/v2/iat " + "HTTP/1.1"
|
||||
# 进行hmac-sha256进行加密
|
||||
signature_sha = hmac.new(self.APISecret.encode('utf-8'), signature_origin.encode('utf-8'),
|
||||
digestmod=hashlib.sha256).digest()
|
||||
signature_sha = base64.b64encode(signature_sha).decode(encoding='utf-8')
|
||||
|
||||
authorization_origin = "api_key=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"" % (
|
||||
self.APIKey, "hmac-sha256", "host date request-line", signature_sha)
|
||||
authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')
|
||||
# 将请求的鉴权参数组合为字典
|
||||
v = {
|
||||
"authorization": authorization,
|
||||
"date": date,
|
||||
"host": "ws-api.xfyun.cn"
|
||||
}
|
||||
# 拼接鉴权参数,生成url
|
||||
url = url + '?' + urlencode(v)
|
||||
#print("date: ",date)
|
||||
#print("v: ",v)
|
||||
# 此处打印出建立连接时候的url,参考本demo的时候可取消上方打印的注释,比对相同参数时生成的url与自己代码生成的url是否一致
|
||||
#print('websocket url :', url)
|
||||
return url
|
||||
|
||||
|
||||
# 收到websocket消息的处理
|
||||
def on_message(ws, message):
|
||||
global whole_dict
|
||||
try:
|
||||
code = json.loads(message)["code"]
|
||||
sid = json.loads(message)["sid"]
|
||||
if code != 0:
|
||||
errMsg = json.loads(message)["message"]
|
||||
print("sid:%s call error:%s code is:%s" % (sid, errMsg, code))
|
||||
else:
|
||||
temp1 = json.loads(message)["data"]["result"]
|
||||
data = json.loads(message)["data"]["result"]["ws"]
|
||||
sn = temp1["sn"]
|
||||
if "rg" in temp1.keys():
|
||||
rep = temp1["rg"]
|
||||
rep_start = rep[0]
|
||||
rep_end = rep[1]
|
||||
for sn in range(rep_start,rep_end+1):
|
||||
#print("before pop",whole_dict)
|
||||
#print("sn",sn)
|
||||
whole_dict.pop(sn,None)
|
||||
#print("after pop",whole_dict)
|
||||
results = ""
|
||||
for i in data:
|
||||
for w in i["cw"]:
|
||||
results += w["w"]
|
||||
whole_dict[sn]=results
|
||||
#print("after add",whole_dict)
|
||||
else:
|
||||
results = ""
|
||||
for i in data:
|
||||
for w in i["cw"]:
|
||||
results += w["w"]
|
||||
whole_dict[sn]=results
|
||||
#print("sid:%s call success!,data is:%s" % (sid, json.dumps(data, ensure_ascii=False)))
|
||||
except Exception as e:
|
||||
print("receive msg,but parse exception:", e)
|
||||
|
||||
|
||||
|
||||
# 收到websocket错误的处理
|
||||
def on_error(ws, error):
|
||||
print("### error:", error)
|
||||
|
||||
|
||||
# 收到websocket关闭的处理
|
||||
def on_close(ws,a,b):
|
||||
print("### closed ###")
|
||||
|
||||
|
||||
# 收到websocket连接建立的处理
|
||||
def on_open(ws):
|
||||
global wsParam
|
||||
def run(*args):
|
||||
frameSize = 8000 # 每一帧的音频大小
|
||||
intervel = 0.04 # 发送音频间隔(单位:s)
|
||||
status = STATUS_FIRST_FRAME # 音频的状态信息,标识音频是第一帧,还是中间帧、最后一帧
|
||||
|
||||
with wave.open(wsParam.AudioFile, "rb") as fp:
|
||||
while True:
|
||||
buf = fp.readframes(frameSize)
|
||||
# 文件结束
|
||||
if not buf:
|
||||
status = STATUS_LAST_FRAME
|
||||
# 第一帧处理
|
||||
# 发送第一帧音频,带business 参数
|
||||
# appid 必须带上,只需第一帧发送
|
||||
if status == STATUS_FIRST_FRAME:
|
||||
d = {"common": wsParam.CommonArgs,
|
||||
"business": wsParam.BusinessArgs,
|
||||
"data": {"status": 0, "format": "audio/L16;rate=16000","audio": str(base64.b64encode(buf), 'utf-8'), "encoding": "raw"}}
|
||||
d = json.dumps(d)
|
||||
ws.send(d)
|
||||
status = STATUS_CONTINUE_FRAME
|
||||
# 中间帧处理
|
||||
elif status == STATUS_CONTINUE_FRAME:
|
||||
d = {"data": {"status": 1, "format": "audio/L16;rate=16000",
|
||||
"audio": str(base64.b64encode(buf), 'utf-8'),
|
||||
"encoding": "raw"}}
|
||||
ws.send(json.dumps(d))
|
||||
# 最后一帧处理
|
||||
elif status == STATUS_LAST_FRAME:
|
||||
d = {"data": {"status": 2, "format": "audio/L16;rate=16000",
|
||||
"audio": str(base64.b64encode(buf), 'utf-8'),
|
||||
"encoding": "raw"}}
|
||||
ws.send(json.dumps(d))
|
||||
time.sleep(1)
|
||||
break
|
||||
# 模拟音频采样间隔
|
||||
time.sleep(intervel)
|
||||
ws.close()
|
||||
|
||||
thread.start_new_thread(run, ())
|
||||
|
||||
#提供给xunfei_voice调用的函数
|
||||
def xunfei_asr(APPID,APISecret,APIKey,BusinessArgsASR,AudioFile):
|
||||
global whole_dict
|
||||
global wsParam
|
||||
whole_dict = {}
|
||||
wsParam1 = Ws_Param(APPID=APPID, APISecret=APISecret,
|
||||
APIKey=APIKey,BusinessArgs=BusinessArgsASR,
|
||||
AudioFile=AudioFile)
|
||||
#wsParam是global变量,给上面on_open函数调用使用的
|
||||
wsParam = wsParam1
|
||||
websocket.enableTrace(False)
|
||||
wsUrl = wsParam.create_url()
|
||||
ws = websocket.WebSocketApp(wsUrl, on_message=on_message, on_error=on_error, on_close=on_close)
|
||||
ws.on_open = on_open
|
||||
ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
|
||||
#把字典的值合并起来做最后识别的输出
|
||||
whole_words = ""
|
||||
for i in sorted(whole_dict.keys()):
|
||||
whole_words += whole_dict[i]
|
||||
return whole_words
|
||||
|
||||
|
||||
163
voice/xunfei/xunfei_tts.py
Normal file
@@ -0,0 +1,163 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
#
|
||||
# Author: njnuko
|
||||
# Email: njnuko@163.com
|
||||
#
|
||||
# 这个文档是基于官方的demo来改的,固体官方demo文档请参考官网
|
||||
#
|
||||
# 语音听写流式 WebAPI 接口调用示例 接口文档(必看):https://doc.xfyun.cn/rest_api/语音听写(流式版).html
|
||||
# webapi 听写服务参考帖子(必看):http://bbs.xfyun.cn/forum.php?mod=viewthread&tid=38947&extra=
|
||||
# 语音听写流式WebAPI 服务,热词使用方式:登陆开放平台https://www.xfyun.cn/后,找到控制台--我的应用---语音听写(流式)---服务管理--个性化热词,
|
||||
# 设置热词
|
||||
# 注意:热词只能在识别的时候会增加热词的识别权重,需要注意的是增加相应词条的识别率,但并不是绝对的,具体效果以您测试为准。
|
||||
# 语音听写流式WebAPI 服务,方言试用方法:登陆开放平台https://www.xfyun.cn/后,找到控制台--我的应用---语音听写(流式)---服务管理--识别语种列表
|
||||
# 可添加语种或方言,添加后会显示该方言的参数值
|
||||
# 错误码链接:https://www.xfyun.cn/document/error-code (code返回错误码时必看)
|
||||
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
|
||||
import websocket
|
||||
import datetime
|
||||
import hashlib
|
||||
import base64
|
||||
import hmac
|
||||
import json
|
||||
from urllib.parse import urlencode
|
||||
import time
|
||||
import ssl
|
||||
from wsgiref.handlers import format_date_time
|
||||
from datetime import datetime
|
||||
from time import mktime
|
||||
import _thread as thread
|
||||
import os
|
||||
|
||||
|
||||
|
||||
STATUS_FIRST_FRAME = 0 # 第一帧的标识
|
||||
STATUS_CONTINUE_FRAME = 1 # 中间帧标识
|
||||
STATUS_LAST_FRAME = 2 # 最后一帧的标识
|
||||
|
||||
#############
|
||||
#这个参数是用来做输出文件路径的
|
||||
global outfile
|
||||
#这个文档是官方文档改的,这个参数是用来做函数调用时用的
|
||||
global wsParam
|
||||
##############
|
||||
|
||||
|
||||
class Ws_Param(object):
|
||||
# 初始化
|
||||
def __init__(self, APPID, APIKey, APISecret,BusinessArgs,Text):
|
||||
self.APPID = APPID
|
||||
self.APIKey = APIKey
|
||||
self.APISecret = APISecret
|
||||
self.BusinessArgs = BusinessArgs
|
||||
self.Text = Text
|
||||
|
||||
# 公共参数(common)
|
||||
self.CommonArgs = {"app_id": self.APPID}
|
||||
# 业务参数(business),更多个性化参数可在官网查看
|
||||
#self.BusinessArgs = {"aue": "raw", "auf": "audio/L16;rate=16000", "vcn": "xiaoyan", "tte": "utf8"}
|
||||
self.Data = {"status": 2, "text": str(base64.b64encode(self.Text.encode('utf-8')), "UTF8")}
|
||||
#使用小语种须使用以下方式,此处的unicode指的是 utf16小端的编码方式,即"UTF-16LE"”
|
||||
#self.Data = {"status": 2, "text": str(base64.b64encode(self.Text.encode('utf-16')), "UTF8")}
|
||||
|
||||
# 生成url
|
||||
def create_url(self):
|
||||
url = 'wss://tts-api.xfyun.cn/v2/tts'
|
||||
# 生成RFC1123格式的时间戳
|
||||
now = datetime.now()
|
||||
date = format_date_time(mktime(now.timetuple()))
|
||||
|
||||
# 拼接字符串
|
||||
signature_origin = "host: " + "ws-api.xfyun.cn" + "\n"
|
||||
signature_origin += "date: " + date + "\n"
|
||||
signature_origin += "GET " + "/v2/tts " + "HTTP/1.1"
|
||||
# 进行hmac-sha256进行加密
|
||||
signature_sha = hmac.new(self.APISecret.encode('utf-8'), signature_origin.encode('utf-8'),
|
||||
digestmod=hashlib.sha256).digest()
|
||||
signature_sha = base64.b64encode(signature_sha).decode(encoding='utf-8')
|
||||
|
||||
authorization_origin = "api_key=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"" % (
|
||||
self.APIKey, "hmac-sha256", "host date request-line", signature_sha)
|
||||
authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')
|
||||
# 将请求的鉴权参数组合为字典
|
||||
v = {
|
||||
"authorization": authorization,
|
||||
"date": date,
|
||||
"host": "ws-api.xfyun.cn"
|
||||
}
|
||||
# 拼接鉴权参数,生成url
|
||||
url = url + '?' + urlencode(v)
|
||||
# print("date: ",date)
|
||||
# print("v: ",v)
|
||||
# 此处打印出建立连接时候的url,参考本demo的时候可取消上方打印的注释,比对相同参数时生成的url与自己代码生成的url是否一致
|
||||
# print('websocket url :', url)
|
||||
return url
|
||||
|
||||
def on_message(ws, message):
|
||||
#输出文件
|
||||
global outfile
|
||||
try:
|
||||
message =json.loads(message)
|
||||
code = message["code"]
|
||||
sid = message["sid"]
|
||||
audio = message["data"]["audio"]
|
||||
audio = base64.b64decode(audio)
|
||||
status = message["data"]["status"]
|
||||
if status == 2:
|
||||
print("ws is closed")
|
||||
ws.close()
|
||||
if code != 0:
|
||||
errMsg = message["message"]
|
||||
print("sid:%s call error:%s code is:%s" % (sid, errMsg, code))
|
||||
else:
|
||||
|
||||
with open(outfile, 'ab') as f:
|
||||
f.write(audio)
|
||||
|
||||
except Exception as e:
|
||||
print("receive msg,but parse exception:", e)
|
||||
|
||||
|
||||
|
||||
# 收到websocket连接建立的处理
|
||||
def on_open(ws):
|
||||
global outfile
|
||||
global wsParam
|
||||
def run(*args):
|
||||
d = {"common": wsParam.CommonArgs,
|
||||
"business": wsParam.BusinessArgs,
|
||||
"data": wsParam.Data,
|
||||
}
|
||||
d = json.dumps(d)
|
||||
# print("------>开始发送文本数据")
|
||||
ws.send(d)
|
||||
if os.path.exists(outfile):
|
||||
os.remove(outfile)
|
||||
|
||||
thread.start_new_thread(run, ())
|
||||
|
||||
# 收到websocket错误的处理
|
||||
def on_error(ws, error):
|
||||
print("### error:", error)
|
||||
|
||||
|
||||
|
||||
# 收到websocket关闭的处理
|
||||
def on_close(ws):
|
||||
print("### closed ###")
|
||||
|
||||
|
||||
|
||||
def xunfei_tts(APPID, APIKey, APISecret,BusinessArgsTTS, Text, OutFile):
|
||||
global outfile
|
||||
global wsParam
|
||||
outfile = OutFile
|
||||
wsParam1 = Ws_Param(APPID,APIKey,APISecret,BusinessArgsTTS,Text)
|
||||
wsParam = wsParam1
|
||||
websocket.enableTrace(False)
|
||||
wsUrl = wsParam.create_url()
|
||||
ws = websocket.WebSocketApp(wsUrl, on_message=on_message, on_error=on_error, on_close=on_close)
|
||||
ws.on_open = on_open
|
||||
ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
|
||||
return outfile
|
||||
|
||||
86
voice/xunfei/xunfei_voice.py
Normal file
@@ -0,0 +1,86 @@
|
||||
#####################################################################
|
||||
# xunfei voice service
|
||||
# Auth: njnuko
|
||||
# Email: njnuko@163.com
|
||||
#
|
||||
# 要使用本模块, 首先到 xfyun.cn 注册一个开发者账号,
|
||||
# 之后创建一个新应用, 然后在应用管理的语音识别或者语音合同右边可以查看APPID API Key 和 Secret Key
|
||||
# 然后在 config.json 中填入这三个值
|
||||
#
|
||||
# 配置说明:
|
||||
# {
|
||||
# "APPID":"xxx71xxx",
|
||||
# "APIKey":"xxxx69058exxxxxx", #讯飞xfyun.cn控制台语音合成或者听写界面的APIKey
|
||||
# "APISecret":"xxxx697f0xxxxxx", #讯飞xfyun.cn控制台语音合成或者听写界面的APIKey
|
||||
# "BusinessArgsTTS":{"aue": "lame", "sfl": 1, "auf": "audio/L16;rate=16000", "vcn": "xiaoyan", "tte": "utf8"}, #语音合成的参数,具体可以参考xfyun.cn的文档
|
||||
# "BusinessArgsASR":{"domain": "iat", "language": "zh_cn", "accent": "mandarin", "vad_eos":10000, "dwa": "wpgs"} #语音听写的参数,具体可以参考xfyun.cn的文档
|
||||
# }
|
||||
#####################################################################
|
||||
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from common.log import logger
|
||||
from common.tmp_dir import TmpDir
|
||||
from config import conf
|
||||
from voice.voice import Voice
|
||||
from .xunfei_asr import xunfei_asr
|
||||
from .xunfei_tts import xunfei_tts
|
||||
from voice.audio_convert import any_to_mp3
|
||||
import shutil
|
||||
from pydub import AudioSegment
|
||||
|
||||
|
||||
class XunfeiVoice(Voice):
|
||||
def __init__(self):
|
||||
try:
|
||||
curdir = os.path.dirname(__file__)
|
||||
config_path = os.path.join(curdir, "config.json")
|
||||
conf = None
|
||||
with open(config_path, "r") as fr:
|
||||
conf = json.load(fr)
|
||||
print(conf)
|
||||
self.APPID = str(conf.get("APPID"))
|
||||
self.APIKey = str(conf.get("APIKey"))
|
||||
self.APISecret = str(conf.get("APISecret"))
|
||||
self.BusinessArgsTTS = conf.get("BusinessArgsTTS")
|
||||
self.BusinessArgsASR= conf.get("BusinessArgsASR")
|
||||
|
||||
except Exception as e:
|
||||
logger.warn("XunfeiVoice init failed: %s, ignore " % e)
|
||||
|
||||
def voiceToText(self, voice_file):
|
||||
# 识别本地文件
|
||||
try:
|
||||
logger.debug("[Xunfei] voice file name={}".format(voice_file))
|
||||
#print("voice_file===========",voice_file)
|
||||
#print("voice_file_type===========",type(voice_file))
|
||||
#mp3_name, file_extension = os.path.splitext(voice_file)
|
||||
#mp3_file = mp3_name + ".mp3"
|
||||
#pcm_data=get_pcm_from_wav(voice_file)
|
||||
#mp3_name, file_extension = os.path.splitext(voice_file)
|
||||
#AudioSegment.from_wav(voice_file).export(mp3_file, format="mp3")
|
||||
#shutil.copy2(voice_file, 'tmp/test1.wav')
|
||||
#shutil.copy2(mp3_file, 'tmp/test1.mp3')
|
||||
#print("voice and mp3 file",voice_file,mp3_file)
|
||||
text = xunfei_asr(self.APPID,self.APISecret,self.APIKey,self.BusinessArgsASR,voice_file)
|
||||
logger.info("讯飞语音识别到了: {}".format(text))
|
||||
reply = Reply(ReplyType.TEXT, text)
|
||||
except Exception as e:
|
||||
logger.warn("XunfeiVoice init failed: %s, ignore " % e)
|
||||
reply = Reply(ReplyType.ERROR, "讯飞语音识别出错了;{0}")
|
||||
return reply
|
||||
|
||||
def textToVoice(self, text):
|
||||
try:
|
||||
# Avoid the same filename under multithreading
|
||||
fileName = TmpDir().path() + "reply-" + str(int(time.time())) + "-" + str(hash(text) & 0x7FFFFFFF) + ".mp3"
|
||||
return_file = xunfei_tts(self.APPID,self.APIKey,self.APISecret,self.BusinessArgsTTS,text,fileName)
|
||||
logger.info("[Xunfei] textToVoice text={} voice file name={}".format(text, fileName))
|
||||
reply = Reply(ReplyType.VOICE, fileName)
|
||||
except Exception as e:
|
||||
logger.error("[Xunfei] textToVoice error={}".format(fileName))
|
||||
reply = Reply(ReplyType.ERROR, "抱歉,讯飞语音合成失败")
|
||||
return reply
|
||||