mirror of
https://github.com/zhayujie/chatgpt-on-wechat.git
synced 2026-06-02 09:48:22 +08:00
Compare commits
296 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
16324e7283 | ||
|
|
9f7e2e1572 | ||
|
|
857ce1d530 | ||
|
|
be0d72775d | ||
|
|
7832a2495b | ||
|
|
0506b7f735 | ||
|
|
4c0b7942f0 | ||
|
|
651c840c4a | ||
|
|
2a351ca415 | ||
|
|
49b7106d71 | ||
|
|
8bf633f539 | ||
|
|
0f8efcb4b0 | ||
|
|
c567641c5c | ||
|
|
bdc3820382 | ||
|
|
33a69a7907 | ||
|
|
a4d0e9bbc3 | ||
|
|
afc753e1d2 | ||
|
|
e641a41224 | ||
|
|
79305c0632 | ||
|
|
ef2ce3f09d | ||
|
|
71c18c04fc | ||
|
|
cf84e57f81 | ||
|
|
9421d44579 | ||
|
|
5cd2ae8cc8 | ||
|
|
22d67b3a59 | ||
|
|
e102cbb8c4 | ||
|
|
d90eeb7ee4 | ||
|
|
1989d53031 | ||
|
|
04ef0907b4 | ||
|
|
517b43561c | ||
|
|
ccb8c7227f | ||
|
|
9fbfeeb04f | ||
|
|
8b753a5a1f | ||
|
|
d25cab0627 | ||
|
|
84da0a8a35 | ||
|
|
6f665cffba | ||
|
|
aea8ac2e97 | ||
|
|
8418fa7b45 | ||
|
|
9cc4d0ee07 | ||
|
|
da60831c44 | ||
|
|
0773174a20 | ||
|
|
70e007d8ca | ||
|
|
fcc4d02c2f | ||
|
|
f4a5f00593 | ||
|
|
1170ed6566 | ||
|
|
883f0d449b | ||
|
|
f4c62e7844 | ||
|
|
f0d212a9d2 | ||
|
|
76a8974034 | ||
|
|
0614e822f4 | ||
|
|
6f682c9a2e | ||
|
|
a9fdbc31c5 | ||
|
|
086fdb5856 | ||
|
|
63c8ef4f17 | ||
|
|
736f6523c7 | ||
|
|
8b0b360d25 | ||
|
|
80b84e2ee6 | ||
|
|
b5b7d86f7b | ||
|
|
f20d704390 | ||
|
|
e4e1e2e944 | ||
|
|
6bc7eeb4cc | ||
|
|
656ed5de7b | ||
|
|
a11d695c78 | ||
|
|
c4f9acd5c5 | ||
|
|
5ef929dc42 | ||
|
|
c8cf27b544 | ||
|
|
bb5ecfc398 | ||
|
|
c91e7c35bb | ||
|
|
532d56df2d | ||
|
|
111ad44029 | ||
|
|
6b02bae957 | ||
|
|
6831743416 | ||
|
|
63e2f42636 | ||
|
|
f6e6805453 | ||
|
|
ad77ad8f2b | ||
|
|
469524e8ae | ||
|
|
f4f55d5dfd | ||
|
|
c248d0f3f4 | ||
|
|
648a04b513 | ||
|
|
bdc86c16ec | ||
|
|
21efd17c17 | ||
|
|
aaa75e7b62 | ||
|
|
6d0cef3152 | ||
|
|
c18472289f | ||
|
|
02b7c70a81 | ||
|
|
4eaa2b93c6 | ||
|
|
d347905373 | ||
|
|
f495213b2c | ||
|
|
9b125913ae | ||
|
|
da81f05804 | ||
|
|
9a371a4d4d | ||
|
|
1e92828f1a | ||
|
|
7e724b3fa3 | ||
|
|
3f5b976a87 | ||
|
|
49f2339cc2 | ||
|
|
29f1699de8 | ||
|
|
c415485801 | ||
|
|
6937673472 | ||
|
|
c4f10fe876 | ||
|
|
55ca652ad8 | ||
|
|
3effd5afd1 | ||
|
|
000c2029de | ||
|
|
ab88e3af06 | ||
|
|
b544a4c954 | ||
|
|
baff5fafec | ||
|
|
1673de73ba | ||
|
|
e68936e36e | ||
|
|
7dbd195e45 | ||
|
|
3dc22f98bf | ||
|
|
805e870c18 | ||
|
|
de2c031797 | ||
|
|
3aa571aa1b | ||
|
|
3e4969efe6 | ||
|
|
446e94df76 | ||
|
|
5b26066a4c | ||
|
|
8a80de5c3f | ||
|
|
52a490c87e | ||
|
|
29490741fd | ||
|
|
f0e416455f | ||
|
|
f7a2c97943 | ||
|
|
993853757b | ||
|
|
a3abfb987d | ||
|
|
2711fa1b1b | ||
|
|
1f7afaba07 | ||
|
|
e02c8bff81 | ||
|
|
22391ba1a5 | ||
|
|
a05781ec19 | ||
|
|
f898ed6a2a | ||
|
|
e6d0a15b54 | ||
|
|
49cff026e2 | ||
|
|
08f0023cfd | ||
|
|
e311466ee6 | ||
|
|
56789e68d7 | ||
|
|
87525bb383 | ||
|
|
bb2880191a | ||
|
|
4f1acf26d6 | ||
|
|
fc2d6b21ac | ||
|
|
b9e84fefbd | ||
|
|
91f5ffb2d9 | ||
|
|
70ff2341cb | ||
|
|
74eed93497 | ||
|
|
d02e26c014 | ||
|
|
523cade7c3 | ||
|
|
e22c183ca9 | ||
|
|
3afd99da30 | ||
|
|
f44979f983 | ||
|
|
095f9cc108 | ||
|
|
1089076fce | ||
|
|
cad3b691a9 | ||
|
|
bac21426d3 | ||
|
|
c4a35314cd | ||
|
|
7090722565 | ||
|
|
6d972c7c18 | ||
|
|
6961a88feb | ||
|
|
c41ec13984 | ||
|
|
ca8e06e562 | ||
|
|
200cd33a8e | ||
|
|
1da7991c65 | ||
|
|
fdfb7e369a | ||
|
|
c2b01cc957 | ||
|
|
5de8e94bb4 | ||
|
|
7a2c15d912 | ||
|
|
70344dd214 | ||
|
|
405372d1a7 | ||
|
|
b8c5174da5 | ||
|
|
1f6f9103d9 | ||
|
|
6431487c7a | ||
|
|
8b2d1189db | ||
|
|
b777f27cb7 | ||
|
|
b31c3b124a | ||
|
|
fa1e965fba | ||
|
|
91dc8b4d58 | ||
|
|
6d16ea8830 | ||
|
|
7db4253264 | ||
|
|
4d2b7d9bf9 | ||
|
|
8f6f4acb88 | ||
|
|
f20d84cb37 | ||
|
|
afbdf1d5d5 | ||
|
|
bc8364d594 | ||
|
|
c8d388f70f | ||
|
|
be13cc3194 | ||
|
|
a46320e744 | ||
|
|
071709d263 | ||
|
|
93a32ae5ff | ||
|
|
eee96f226f | ||
|
|
e19a8b479c | ||
|
|
9ef459112e | ||
|
|
e96474bd5c | ||
|
|
6fed719e09 | ||
|
|
99aac76618 | ||
|
|
599f458201 | ||
|
|
2f8099059c | ||
|
|
e24f177832 | ||
|
|
48cc143e88 | ||
|
|
b09b46c045 | ||
|
|
2c6583cc9c | ||
|
|
e381d1bfb8 | ||
|
|
eac619d54f | ||
|
|
a6ef3bc0ce | ||
|
|
118122c541 | ||
|
|
bfdf33ac09 | ||
|
|
fa3370df5b | ||
|
|
f1e51672c5 | ||
|
|
91f97b2728 | ||
|
|
2c542e03fe | ||
|
|
71a11b4267 | ||
|
|
ea642757db | ||
|
|
fb72b601aa | ||
|
|
27e507e744 | ||
|
|
4db19f816f | ||
|
|
096d5776d1 | ||
|
|
3d799eb4d9 | ||
|
|
e4ac3afa4d | ||
|
|
d38e4eed5b | ||
|
|
97787fac91 | ||
|
|
b494ee2f1c | ||
|
|
31ac80a074 | ||
|
|
c8896450f6 | ||
|
|
c662fa4c63 | ||
|
|
db2ee802ca | ||
|
|
d40e915e2b | ||
|
|
c0616e7efa | ||
|
|
01660597e3 | ||
|
|
c5b549f450 | ||
|
|
802d8457bb | ||
|
|
c3a3df67b0 | ||
|
|
5798aeb3cd | ||
|
|
cc81dd9172 | ||
|
|
44fdadda08 | ||
|
|
66a014150b | ||
|
|
1da596639f | ||
|
|
76614ae9e5 | ||
|
|
6ddddffc0f | ||
|
|
dd95f849d4 | ||
|
|
22c7f8fe9e | ||
|
|
3d47be1f49 | ||
|
|
5e399c46b1 | ||
|
|
38e1db7a37 | ||
|
|
8309f7cdbe | ||
|
|
b8cc62ae95 | ||
|
|
c0eb433fa2 | ||
|
|
7f857d66f6 | ||
|
|
93b14d38f4 | ||
|
|
21825faab0 | ||
|
|
1fafd39298 | ||
|
|
23b750fc4f | ||
|
|
90581c840d | ||
|
|
cac7a6228a | ||
|
|
674fbc3f69 | ||
|
|
9577bf1cc7 | ||
|
|
654ebe93e7 | ||
|
|
ecb1b3c491 | ||
|
|
c3d1711edc | ||
|
|
c12c7f10f0 | ||
|
|
f71820bf4e | ||
|
|
748c53c774 | ||
|
|
b290a71bfb | ||
|
|
3204c51eca | ||
|
|
2c4b8a44dc | ||
|
|
943aa05eaa | ||
|
|
d0fd36e7e1 | ||
|
|
f45ff5fd0a | ||
|
|
c22c7102d5 | ||
|
|
11ecfd1b41 | ||
|
|
798e30e5ac | ||
|
|
15e0702329 | ||
|
|
a2bc22c37d | ||
|
|
8093fcc64c | ||
|
|
800419e7cc | ||
|
|
a241dc6785 | ||
|
|
805bea0d5f | ||
|
|
9d394adf24 | ||
|
|
2074f27aff | ||
|
|
283ad48b86 | ||
|
|
07e10a7943 | ||
|
|
2812a5026c | ||
|
|
3a20461abf | ||
|
|
64ae3d1e21 | ||
|
|
a25d7ea65b | ||
|
|
74ebbdd761 | ||
|
|
a0427b569e | ||
|
|
5346dfdd8b | ||
|
|
3ee4147285 | ||
|
|
c41e486bfc | ||
|
|
eda3ba92fd | ||
|
|
40255290b0 | ||
|
|
af5bc73dc0 | ||
|
|
0247cd4c45 | ||
|
|
916762cc8c | ||
|
|
d6fdf8ca2a | ||
|
|
95708489c9 | ||
|
|
ced0fa4608 | ||
|
|
7e0fbd600f | ||
|
|
f33e4e0323 | ||
|
|
977d3bc02e | ||
|
|
854d613a81 |
153
README.md
153
README.md
@@ -1,42 +1,69 @@
|
||||
# 简介
|
||||
|
||||
> 本项目是基于大模型的智能对话机器人,支持微信、企业微信、公众号、飞书、钉钉接入,可选择GPT3.5/GPT4.0/Claude/文心一言/讯飞星火/通义千问/Gemini/LinkAI,能处理文本、语音和图片,通过插件访问操作系统和互联网等外部资源,支持基于自有知识库定制企业AI应用。
|
||||
> chatgpt-on-wechat(简称CoW)项目是基于大模型的智能对话机器人,支持微信公众号、企业微信应用、飞书、钉钉接入,可选择GPT3.5/GPT4.0/Claude/Gemini/LinkAI/ChatGLM/KIMI/文心一言/讯飞星火/通义千问/LinkAI,能处理文本、语音和图片,通过插件访问操作系统和互联网等外部资源,支持基于自有知识库定制企业AI应用。
|
||||
|
||||
最新版本支持的功能如下:
|
||||
|
||||
- [x] **多端部署:** 有多种部署方式可选择且功能完备,目前已支持个人微信、微信公众号和、企业微信、飞书、钉钉等部署方式
|
||||
- [x] **基础对话:** 私聊及群聊的消息智能回复,支持多轮会话上下文记忆,支持 GPT-3.5, GPT-4, claude, Gemini, 文心一言, 讯飞星火, 通义千问
|
||||
- [x] **语音能力:** 可识别语音消息,通过文字或语音回复,支持 azure, baidu, google, openai(whisper/tts) 等多种语音模型
|
||||
- [x] **图像能力:** 支持图片生成、图片识别、图生图(如照片修复),可选择 Dall-E-3, stable diffusion, replicate, midjourney, vision模型
|
||||
- [x] **丰富插件:** 支持个性化插件扩展,已实现多角色切换、文字冒险、敏感词过滤、聊天记录总结、文档总结和对话、联网搜索等插件
|
||||
- [x] **知识库:** 通过上传知识库文件自定义专属机器人,可作为数字分身、智能客服、私域助手使用,基于 [LinkAI](https://link-ai.tech) 实现
|
||||
- ✅ **多端部署:** 有多种部署方式可选择且功能完备,目前已支持微信公众号、企业微信应用、飞书、钉钉等部署方式
|
||||
- ✅ **基础对话:** 私聊及群聊的消息智能回复,支持多轮会话上下文记忆,支持 GPT-3.5, GPT-4o-mini, GPT-4o, GPT-4, Claude-3.5, Gemini, 文心一言, 讯飞星火, 通义千问,ChatGLM-4,Kimi(月之暗面), MiniMax
|
||||
- ✅ **语音能力:** 可识别语音消息,通过文字或语音回复,支持 azure, baidu, google, openai(whisper/tts) 等多种语音模型
|
||||
- ✅ **图像能力:** 支持图片生成、图片识别、图生图(如照片修复),可选择 Dall-E-3, stable diffusion, replicate, midjourney, CogView-3, vision模型
|
||||
- ✅ **丰富插件:** 支持个性化插件扩展,已实现多角色切换、文字冒险、敏感词过滤、聊天记录总结、文档总结和对话、联网搜索等插件
|
||||
- ✅ **知识库:** 通过上传知识库文件自定义专属机器人,可作为数字分身、智能客服、私域助手使用,基于 [LinkAI](https://link-ai.tech) 实现
|
||||
|
||||
# 演示
|
||||
## 声明
|
||||
|
||||
https://github.com/zhayujie/chatgpt-on-wechat/assets/26161723/d5154020-36e3-41db-8706-40ce9f3f1b1e
|
||||
1. 本项目遵循 [MIT开源协议](/LICENSE),仅用于技术研究和学习,使用本项目时需遵守所在地法律法规、相关政策以及企业章程,禁止用于任何违法或侵犯他人权益的行为
|
||||
2. 境内使用该项目时,请使用国内厂商的大模型服务,并进行必要的内容安全审核及过滤
|
||||
3. 本项目主要接入协同办公平台,推荐使用公众号、企微自建应用、钉钉、飞书等接入通道,其他通道为历史产物已不维护
|
||||
4. 任何个人、团队和企业,无论以何种方式使用该项目、对何对象提供服务,所产生的一切后果,本项目均不承担任何责任
|
||||
|
||||
Demo made by [Visionn](https://www.wangpc.cc/)
|
||||
## 演示
|
||||
|
||||
# 商业支持
|
||||
DEMO视频:https://cdn.link-ai.tech/doc/cow_demo.mp4
|
||||
|
||||
> 我们还提供企业级的 **AI应用平台**,包含知识库、Agent插件、应用管理等能力,支持多平台聚合的应用接入、客户端管理、对话管理,以及提供
|
||||
SaaS服务、私有化部署、稳定托管接入 等多种模式。
|
||||
>
|
||||
> 目前已在私域运营、智能客服、企业效率助手等场景积累了丰富的 AI 解决方案, 在电商、文教、健康、新消费等各行业沉淀了 AI 落地的最佳实践,致力于打造助力中小企业拥抱 AI 的一站式平台。
|
||||
|
||||
企业服务和商用咨询可联系产品顾问:
|
||||
|
||||
<img width="240" src="https://img-1317903499.cos.ap-guangzhou.myqcloud.com/docs/product-manager-qrcode.jpg">
|
||||
|
||||
# 开源社区
|
||||
## 社区
|
||||
|
||||
添加小助手微信加入开源项目交流群:
|
||||
|
||||
<img width="240" src="./docs/images/contact.jpg">
|
||||
<img width="160" src="https://img-1317903499.cos.ap-guangzhou.myqcloud.com/docs/open-community.png">
|
||||
|
||||
# 更新日志
|
||||
<br>
|
||||
|
||||
>**2023.11.11:** [1.5.3版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.5.3) 和 [1.5.4版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.5.4),新增Google Gemini、通义千问模型
|
||||
# 企业服务
|
||||
|
||||
<a href="https://link-ai.tech" target="_blank"><img width="800" src="https://cdn.link-ai.tech/image/link-ai-intro.jpg"></a>
|
||||
|
||||
> [LinkAI](https://link-ai.tech/) 是面向企业和开发者的一站式AI应用平台,聚合多模态大模型、知识库、Agent 插件、工作流等能力,支持一键接入主流平台并进行管理,支持SaaS、私有化部署多种模式。
|
||||
>
|
||||
> LinkAI 目前 已在私域运营、智能客服、企业效率助手等场景积累了丰富的 AI 解决方案, 在电商、文教、健康、新消费、科技制造等各行业沉淀了大模型落地应用的最佳实践,致力于帮助更多企业和开发者拥抱 AI 生产力。
|
||||
|
||||
**企业服务和产品咨询** 可联系产品顾问:
|
||||
|
||||
<img width="160" src="https://cdn.link-ai.tech/consultant-s.jpg">
|
||||
|
||||
<br>
|
||||
|
||||
# 🏷 更新日志
|
||||
>**2024.10.31:** [1.7.3版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.7.3) 程序稳定性提升、数据库功能、Claude模型优化、linkai插件优化、离线通知
|
||||
|
||||
>**2024.09.26:** [1.7.2版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.7.2) 和 [1.7.1版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.7.1) 文心,讯飞等模型优化、o1 模型、快速安装和管理脚本
|
||||
|
||||
>**2024.08.02:** [1.7.0版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.7.0) 新增 讯飞4.0 模型、知识库引用来源展示、相关插件优化
|
||||
|
||||
>**2024.07.19:** [1.6.9版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.6.9) 新增 gpt-4o-mini 模型、阿里语音识别、企微应用渠道路由优化
|
||||
|
||||
>**2024.07.05:** [1.6.8版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.6.8) 和 [1.6.7版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.6.7),Claude3.5, Gemini 1.5 Pro, MiniMax模型、工作流图片输入、模型列表完善
|
||||
|
||||
>**2024.06.04:** [1.6.6版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.6.6) 和 [1.6.5版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.6.5),gpt-4o模型、钉钉流式卡片、讯飞语音识别/合成
|
||||
|
||||
>**2024.04.26:** [1.6.0版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.6.0),新增 Kimi 接入、gpt-4-turbo版本升级、文件总结和语音识别问题修复
|
||||
|
||||
>**2024.03.26:** [1.5.8版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.5.8) 和 [1.5.7版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.5.7),新增 GLM-4、Claude-3 模型,edge-tts 语音支持
|
||||
|
||||
>**2024.01.26:** [1.5.6版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.5.6) 和 [1.5.5版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.5.5),钉钉接入,tool插件升级,4-turbo模型更新
|
||||
|
||||
>**2023.11.11:** [1.5.3版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.5.3) 和 [1.5.4版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.5.4),新增通义千问模型、Google Gemini
|
||||
|
||||
>**2023.11.10:** [1.5.2版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.5.2),新增飞书通道、图像识别对话、黑名单配置
|
||||
|
||||
@@ -48,25 +75,22 @@ SaaS服务、私有化部署、稳定托管接入 等多种模式。
|
||||
|
||||
>**2023.08.08:** 接入百度文心一言模型,通过 [插件](https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/linkai) 支持 Midjourney 绘图
|
||||
|
||||
>**2023.06.12:** 接入 [LinkAI](https://link-ai.tech/console) 平台,可在线创建领域知识库,并接入微信、公众号及企业微信中,打造专属客服机器人。使用参考 [接入文档](https://link-ai.tech/platform/link-app/wechat)。
|
||||
>**2023.06.12:** 接入 [LinkAI](https://link-ai.tech/console) 平台,可在线创建领域知识库,打造专属客服机器人。使用参考 [接入文档](https://link-ai.tech/platform/link-app/wechat)。
|
||||
|
||||
>**2023.04.26:** 支持企业微信应用号部署,兼容插件,并支持语音图片交互,私人助理理想选择,[使用文档](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/channel/wechatcom/README.md)。(contributed by [@lanvent](https://github.com/lanvent) in [#944](https://github.com/zhayujie/chatgpt-on-wechat/pull/944))
|
||||
更早更新日志查看: [归档日志](/docs/version/old-version.md)
|
||||
|
||||
>**2023.04.05:** 支持微信公众号部署,兼容插件,并支持语音图片交互,[使用文档](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/channel/wechatmp/README.md)。(contributed by [@JS00000](https://github.com/JS00000) in [#686](https://github.com/zhayujie/chatgpt-on-wechat/pull/686))
|
||||
<br>
|
||||
|
||||
>**2023.04.05:** 增加能让ChatGPT使用工具的`tool`插件,[使用文档](https://github.com/goldfishh/chatgpt-on-wechat/blob/master/plugins/tool/README.md)。工具相关issue可反馈至[chatgpt-tool-hub](https://github.com/goldfishh/chatgpt-tool-hub)。(contributed by [@goldfishh](https://github.com/goldfishh) in [#663](https://github.com/zhayujie/chatgpt-on-wechat/pull/663))
|
||||
# 🚀 快速开始
|
||||
|
||||
>**2023.03.25:** 支持插件化开发,目前已实现 多角色切换、文字冒险游戏、管理员指令、Stable Diffusion等插件,使用参考 [#578](https://github.com/zhayujie/chatgpt-on-wechat/issues/578)。(contributed by [@lanvent](https://github.com/lanvent) in [#565](https://github.com/zhayujie/chatgpt-on-wechat/pull/565))
|
||||
- 快速开始详细文档:[项目搭建文档](https://docs.link-ai.tech/cow/quick-start)
|
||||
|
||||
>**2023.03.09:** 基于 `whisper API`(后续已接入更多的语音`API`服务) 实现对微信语音消息的解析和回复,添加配置项 `"speech_recognition":true` 即可启用,使用参考 [#415](https://github.com/zhayujie/chatgpt-on-wechat/issues/415)。(contributed by [wanggang1987](https://github.com/wanggang1987) in [#385](https://github.com/zhayujie/chatgpt-on-wechat/pull/385))
|
||||
|
||||
>**2023.02.09:** 扫码登录存在账号限制风险,请谨慎使用,参考[#58](https://github.com/AutumnWhj/ChatGPT-wechat-bot/issues/158)
|
||||
|
||||
# 快速开始
|
||||
|
||||
快速开始文档:[项目搭建文档](https://docs.link-ai.tech/cow/quick-start)
|
||||
|
||||
## 准备
|
||||
- 快速安装脚本,详细使用指导:[一键安装启动脚本](https://github.com/zhayujie/chatgpt-on-wechat/wiki/%E4%B8%80%E9%94%AE%E5%AE%89%E8%A3%85%E5%90%AF%E5%8A%A8%E8%84%9A%E6%9C%AC)
|
||||
```bash
|
||||
bash <(curl -sS https://cdn.link-ai.tech/code/cow/install.sh)
|
||||
```
|
||||
- 项目管理脚本,详细使用指导:[项目管理脚本](https://github.com/zhayujie/chatgpt-on-wechat/wiki/%E9%A1%B9%E7%9B%AE%E7%AE%A1%E7%90%86%E8%84%9A%E6%9C%AC)
|
||||
## 一、准备
|
||||
|
||||
### 1. 账号注册
|
||||
|
||||
@@ -74,7 +98,7 @@ SaaS服务、私有化部署、稳定托管接入 等多种模式。
|
||||
|
||||
> 默认对话模型是 openai 的 gpt-3.5-turbo,计费方式是约每 1000tokens (约750个英文单词 或 500汉字,包含请求和回复) 消耗 $0.002,图片生成是Dell E模型,每张消耗 $0.016。
|
||||
|
||||
项目同时也支持使用 LinkAI 接口,无需代理,可使用 文心、讯飞、GPT-3、GPT-4 等模型,支持 定制化知识库、联网搜索、MJ绘图、文档总结和对话等能力。修改配置即可一键切换,参考 [接入文档](https://link-ai.tech/platform/link-app/wechat)。
|
||||
项目同时也支持使用 LinkAI 接口,无需代理,可使用 Kimi、文心、讯飞、GPT-3.5、GPT-4o 等模型,支持 定制化知识库、联网搜索、MJ绘图、文档总结、工作流等能力。修改配置即可一键使用,参考 [接入文档](https://link-ai.tech/platform/link-app/wechat)。
|
||||
|
||||
### 2.运行环境
|
||||
|
||||
@@ -105,7 +129,7 @@ pip3 install -r requirements-optional.txt
|
||||
```
|
||||
> 如果某项依赖安装失败可注释掉对应的行再继续
|
||||
|
||||
## 配置
|
||||
## 二、配置
|
||||
|
||||
配置文件的模板在根目录的`config-template.json`中,需复制该模板创建最终生效的 `config.json` 文件:
|
||||
|
||||
@@ -113,13 +137,14 @@ pip3 install -r requirements-optional.txt
|
||||
cp config-template.json config.json
|
||||
```
|
||||
|
||||
然后在`config.json`中填入配置,以下是对默认配置的说明,可根据需要进行自定义修改(请去掉注释):
|
||||
然后在`config.json`中填入配置,以下是对默认配置的说明,可根据需要进行自定义修改(注意实际使用时请去掉注释,保证JSON格式的完整):
|
||||
|
||||
```bash
|
||||
# config.json文件内容示例
|
||||
{
|
||||
"open_ai_api_key": "YOUR API KEY", # 填入上面创建的 OpenAI API KEY
|
||||
"model": "gpt-3.5-turbo", # 模型名称, 支持 gpt-3.5-turbo, gpt-3.5-turbo-16k, gpt-4, wenxin, xunfei
|
||||
"model": "gpt-3.5-turbo", # 模型名称, 支持 gpt-3.5-turbo, gpt-4, gpt-4-turbo, wenxin, xunfei, glm-4, claude-3-haiku, moonshot
|
||||
"open_ai_api_key": "YOUR API KEY", # 如果使用openAI模型则填入上面创建的 OpenAI API KEY
|
||||
"open_ai_api_base": "https://api.openai.com/v1", # OpenAI接口代理地址
|
||||
"proxy": "", # 代理客户端的ip和端口,国内环境开启代理的需要填写该项,如 "127.0.0.1:7890"
|
||||
"single_chat_prefix": ["bot", "@bot"], # 私聊时文本需要包含该前缀才能触发机器人回复
|
||||
"single_chat_reply_prefix": "[bot] ", # 私聊时自动回复的前缀,用于区分真人
|
||||
@@ -130,15 +155,13 @@ pip3 install -r requirements-optional.txt
|
||||
"conversation_max_tokens": 1000, # 支持上下文记忆的最多字符数
|
||||
"speech_recognition": false, # 是否开启语音识别
|
||||
"group_speech_recognition": false, # 是否开启群组语音识别
|
||||
"use_azure_chatgpt": false, # 是否使用Azure ChatGPT service代替openai ChatGPT service. 当设置为true时需要设置 open_ai_api_base,如 https://xxx.openai.azure.com/
|
||||
"azure_deployment_id": "", # 采用Azure ChatGPT时,模型部署名称
|
||||
"azure_api_version": "", # 采用Azure ChatGPT时,API版本
|
||||
"character_desc": "你是ChatGPT, 一个由OpenAI训练的大型语言模型, 你旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。", # 人格描述
|
||||
"voice_reply_voice": false, # 是否使用语音回复语音
|
||||
"character_desc": "你是基于大语言模型的AI智能助手,旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。", # 人格描述
|
||||
# 订阅消息,公众号和企业微信channel中请填写,当被订阅时会自动回复,可使用特殊占位符。目前支持的占位符有{trigger_prefix},在程序中它会自动替换成bot的触发词。
|
||||
"subscribe_msg": "感谢您的关注!\n这里是ChatGPT,可以自由对话。\n支持语音对话。\n支持图片输出,画字开头的消息将按要求创作图片。\n支持角色扮演和文字冒险等丰富插件。\n输入{trigger_prefix}#help 查看详细指令。",
|
||||
"use_linkai": false, # 是否使用LinkAI接口,默认关闭,开启后可国内访问,使用知识库和MJ
|
||||
"linkai_api_key": "", # LinkAI Api Key
|
||||
"linkai_app_code": "" # LinkAI 应用code
|
||||
"linkai_app_code": "" # LinkAI 应用或工作流code
|
||||
}
|
||||
```
|
||||
**配置说明:**
|
||||
@@ -159,11 +182,11 @@ pip3 install -r requirements-optional.txt
|
||||
|
||||
+ 添加 `"speech_recognition": true` 将开启语音识别,默认使用openai的whisper模型识别为文字,同时以文字回复,该参数仅支持私聊 (注意由于语音消息无法匹配前缀,一旦开启将对所有语音自动回复,支持语音触发画图);
|
||||
+ 添加 `"group_speech_recognition": true` 将开启群组语音识别,默认使用openai的whisper模型识别为文字,同时以文字回复,参数仅支持群聊 (会匹配group_chat_prefix和group_chat_keyword, 支持语音触发画图);
|
||||
+ 添加 `"voice_reply_voice": true` 将开启语音回复语音(同时作用于私聊和群聊),但是需要配置对应语音合成平台的key,由于itchat协议的限制,只能发送语音mp3文件,若使用wechaty则回复的是微信语音。
|
||||
+ 添加 `"voice_reply_voice": true` 将开启语音回复语音(同时作用于私聊和群聊)
|
||||
|
||||
**4.其他配置**
|
||||
|
||||
+ `model`: 模型名称,目前支持 `gpt-3.5-turbo`, `text-davinci-003`, `gpt-4`, `gpt-4-32k`, `wenxin` , `claude` , `xunfei`(其中gpt-4 api暂未完全开放,申请通过后可使用)
|
||||
+ `model`: 模型名称,目前支持 `gpt-3.5-turbo`, `gpt-4o-mini`, `gpt-4o`, `gpt-4`, `wenxin` , `claude` , `gemini`, `glm-4`, `xunfei`, `moonshot`等,全部模型名称参考[common/const.py](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/common/const.py)文件
|
||||
+ `temperature`,`frequency_penalty`,`presence_penalty`: Chat API接口参数,详情参考[OpenAI官方文档。](https://platform.openai.com/docs/api-reference/chat)
|
||||
+ `proxy`:由于目前 `openai` 接口国内无法访问,需配置代理客户端的地址,详情参考 [#351](https://github.com/zhayujie/chatgpt-on-wechat/issues/351)
|
||||
+ 对于图像生成,在满足个人或群组触发条件外,还需要额外的关键词前缀来触发,对应配置 `image_create_prefix `
|
||||
@@ -171,7 +194,7 @@ pip3 install -r requirements-optional.txt
|
||||
+ `conversation_max_tokens`:表示能够记忆的上下文最大字数(一问一答为一组对话,如果累积的对话字数超出限制,就会优先移除最早的一组对话)
|
||||
+ `rate_limit_chatgpt`,`rate_limit_dalle`:每分钟最高问答速率、画图速率,超速后排队按序处理。
|
||||
+ `clear_memory_commands`: 对话内指令,主动清空前文记忆,字符串数组可自定义指令别名。
|
||||
+ `hot_reload`: 程序退出后,暂存微信扫码状态,默认关闭。
|
||||
+ `hot_reload`: 程序退出后,暂存等于状态,默认关闭。
|
||||
+ `character_desc` 配置中保存着你对机器人说的一段话,他会记住这段话并作为他的设定,你可以为他定制任何人格 (关于会话上下文的更多内容参考该 [issue](https://github.com/zhayujie/chatgpt-on-wechat/issues/43))
|
||||
+ `subscribe_msg`:订阅消息,公众号和企业微信channel中请填写,当被订阅时会自动回复, 可使用特殊占位符。目前支持的占位符有{trigger_prefix},在程序中它会自动替换成bot的触发词。
|
||||
|
||||
@@ -179,11 +202,11 @@ pip3 install -r requirements-optional.txt
|
||||
|
||||
+ `use_linkai`: 是否使用LinkAI接口,开启后可国内访问,使用知识库和 `Midjourney` 绘画, 参考 [文档](https://link-ai.tech/platform/link-app/wechat)
|
||||
+ `linkai_api_key`: LinkAI Api Key,可在 [控制台](https://link-ai.tech/console/interface) 创建
|
||||
+ `linkai_app_code`: LinkAI 应用code,选填
|
||||
+ `linkai_app_code`: LinkAI 应用或工作流的code,选填
|
||||
|
||||
**本说明文档可能会未及时更新,当前所有可选的配置项均在该[`config.py`](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/config.py)中列出。**
|
||||
|
||||
## 运行
|
||||
## 三、运行
|
||||
|
||||
### 1.本地运行
|
||||
|
||||
@@ -193,7 +216,7 @@ pip3 install -r requirements-optional.txt
|
||||
python3 app.py # windows环境下该命令通常为 python app.py
|
||||
```
|
||||
|
||||
终端输出二维码后,使用微信进行扫码,当输出 "Start auto replying" 时表示自动回复程序已经成功运行了(注意:用于登录的微信需要在支付处已完成实名认证)。扫码登录后你的账号就成为机器人了,可以在微信手机端通过配置的关键词触发自动回复 (任意好友发送消息给你,或是自己发消息给好友),参考[#142](https://github.com/zhayujie/chatgpt-on-wechat/issues/142)。
|
||||
终端输出二维码后,进行扫码登录,当输出 "Start auto replying" 时表示自动回复程序已经成功运行了(注意:用于登录的账号需要在支付处已完成实名认证)。扫码登录后你的账号就成为机器人了,可以在手机端通过配置的关键词触发自动回复 (任意好友发送消息给你,或是自己发消息给好友),参考[#142](https://github.com/zhayujie/chatgpt-on-wechat/issues/142)。
|
||||
|
||||
### 2.服务器部署
|
||||
|
||||
@@ -215,7 +238,7 @@ nohup python3 app.py & tail -f nohup.out # 在后台运行程序并通
|
||||
|
||||
> 前提是需要安装好 `docker` 及 `docker-compose`,安装成功的表现是执行 `docker -v` 和 `docker-compose version` (或 docker compose version) 可以查看到版本号,可前往 [docker官网](https://docs.docker.com/engine/install/) 进行下载。
|
||||
|
||||
#### (1) 下载 docker-compose.yml 文件
|
||||
**(1) 下载 docker-compose.yml 文件**
|
||||
|
||||
```bash
|
||||
wget https://open-1317903499.cos.ap-guangzhou.myqcloud.com/docker-compose.yml
|
||||
@@ -223,7 +246,7 @@ wget https://open-1317903499.cos.ap-guangzhou.myqcloud.com/docker-compose.yml
|
||||
|
||||
下载完成后打开 `docker-compose.yml` 修改所需配置,如 `OPEN_AI_API_KEY` 和 `GROUP_NAME_WHITE_LIST` 等。
|
||||
|
||||
#### (2) 启动容器
|
||||
**(2) 启动容器**
|
||||
|
||||
在 `docker-compose.yml` 所在目录下执行以下命令启动容器:
|
||||
|
||||
@@ -244,7 +267,7 @@ sudo docker compose up -d
|
||||
sudo docker logs -f chatgpt-on-wechat
|
||||
```
|
||||
|
||||
#### (3) 插件使用
|
||||
**(3) 插件使用**
|
||||
|
||||
如果需要在docker容器中修改插件配置,可通过挂载的方式完成,将 [插件配置文件](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/plugins/config.json.template)
|
||||
重命名为 `config.json`,放置于 `docker-compose.yml` 相同目录下,并在 `docker-compose.yml` 中的 `chatgpt-on-wechat` 部分下添加 `volumes` 映射:
|
||||
@@ -253,7 +276,7 @@ sudo docker logs -f chatgpt-on-wechat
|
||||
volumes:
|
||||
- ./config.json:/app/plugins/config.json
|
||||
```
|
||||
|
||||
**注**:采用docker方式部署的详细教程可以参考:[docker部署CoW项目](https://www.wangpc.cc/ai/docker-deploy-cow/)
|
||||
### 4. Railway部署
|
||||
|
||||
> Railway 每月提供5刀和最多500小时的免费额度。 (07.11更新: 目前大部分账号已无法免费部署)
|
||||
@@ -266,16 +289,22 @@ volumes:
|
||||
|
||||
[](https://railway.app/template/qApznZ?referralCode=RC3znh)
|
||||
|
||||
## 常见问题
|
||||
<br>
|
||||
|
||||
# 🔎 常见问题
|
||||
|
||||
FAQs: <https://github.com/zhayujie/chatgpt-on-wechat/wiki/FAQs>
|
||||
|
||||
或直接在线咨询 [项目小助手](https://link-ai.tech/app/Kv2fXJcH) (beta版本,语料完善中,回复仅供参考)
|
||||
或直接在线咨询 [项目小助手](https://link-ai.tech/app/Kv2fXJcH) (语料持续完善中,回复仅供参考)
|
||||
|
||||
## 开发
|
||||
# 🛠️ 开发
|
||||
|
||||
欢迎接入更多应用,参考 [Terminal代码](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/channel/terminal/terminal_channel.py) 实现接收和发送消息逻辑即可接入。 同时欢迎增加新的插件,参考 [插件说明文档](https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins)。
|
||||
|
||||
## 联系
|
||||
# ✉ 联系
|
||||
|
||||
欢迎提交PR、Issues,以及Star支持一下。程序运行遇到问题可以查看 [常见问题列表](https://github.com/zhayujie/chatgpt-on-wechat/wiki/FAQs) ,其次前往 [Issues](https://github.com/zhayujie/chatgpt-on-wechat/issues) 中搜索。个人开发者可加入开源交流群参与更多讨论,企业用户可联系[产品顾问](https://img-1317903499.cos.ap-guangzhou.myqcloud.com/docs/product-manager-qrcode.jpg)咨询。
|
||||
|
||||
# 🌟 贡献者
|
||||
|
||||

|
||||
|
||||
33
app.py
33
app.py
@@ -3,6 +3,7 @@
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
import time
|
||||
|
||||
from channel import channel_factory
|
||||
from common import const
|
||||
@@ -24,6 +25,21 @@ def sigterm_handler_wrap(_signo):
|
||||
signal.signal(_signo, func)
|
||||
|
||||
|
||||
def start_channel(channel_name: str):
|
||||
channel = channel_factory.create_channel(channel_name)
|
||||
if channel_name in ["wx", "wxy", "terminal", "wechatmp","web", "wechatmp_service", "wechatcom_app", "wework",
|
||||
const.FEISHU, const.DINGTALK]:
|
||||
PluginManager().load_plugins()
|
||||
|
||||
if conf().get("use_linkai"):
|
||||
try:
|
||||
from common import linkai_client
|
||||
threading.Thread(target=linkai_client.start, args=(channel,)).start()
|
||||
except Exception as e:
|
||||
pass
|
||||
channel.startup()
|
||||
|
||||
|
||||
def run():
|
||||
try:
|
||||
# load config
|
||||
@@ -41,22 +57,11 @@ def run():
|
||||
|
||||
if channel_name == "wxy":
|
||||
os.environ["WECHATY_LOG"] = "warn"
|
||||
# os.environ['WECHATY_PUPPET_SERVICE_ENDPOINT'] = '127.0.0.1:9001'
|
||||
|
||||
channel = channel_factory.create_channel(channel_name)
|
||||
if channel_name in ["wx", "wxy", "terminal", "wechatmp", "wechatmp_service", "wechatcom_app", "wework", const.FEISHU,const.DINGTALK]:
|
||||
PluginManager().load_plugins()
|
||||
|
||||
if conf().get("use_linkai"):
|
||||
try:
|
||||
from common import linkai_client
|
||||
threading.Thread(target=linkai_client.start, args=(channel, )).start()
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
# startup channel
|
||||
channel.startup()
|
||||
start_channel(channel_name)
|
||||
|
||||
while True:
|
||||
time.sleep(1)
|
||||
except Exception as e:
|
||||
logger.error("App startup failed!")
|
||||
logger.exception(e)
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
# encoding:utf-8
|
||||
|
||||
import requests, json
|
||||
import requests
|
||||
import json
|
||||
from common import const
|
||||
from bot.bot import Bot
|
||||
from bot.session_manager import SessionManager
|
||||
from bridge.context import ContextType
|
||||
@@ -16,9 +18,20 @@ class BaiduWenxinBot(Bot):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
wenxin_model = conf().get("baidu_wenxin_model") or "eb-instant"
|
||||
if conf().get("model") and conf().get("model") == "wenxin-4":
|
||||
wenxin_model = "completions_pro"
|
||||
wenxin_model = conf().get("baidu_wenxin_model")
|
||||
self.prompt_enabled = conf().get("baidu_wenxin_prompt_enabled")
|
||||
if self.prompt_enabled:
|
||||
self.prompt = conf().get("character_desc", "")
|
||||
if self.prompt == "":
|
||||
logger.warn("[BAIDU] Although you enabled model prompt, character_desc is not specified.")
|
||||
if wenxin_model is not None:
|
||||
wenxin_model = conf().get("baidu_wenxin_model") or "eb-instant"
|
||||
else:
|
||||
if conf().get("model") and conf().get("model") == const.WEN_XIN:
|
||||
wenxin_model = "completions"
|
||||
elif conf().get("model") and conf().get("model") == const.WEN_XIN_4:
|
||||
wenxin_model = "completions_pro"
|
||||
|
||||
self.sessions = SessionManager(BaiduWenxinSession, model=wenxin_model)
|
||||
|
||||
def reply(self, query, context=None):
|
||||
@@ -76,7 +89,7 @@ class BaiduWenxinBot(Bot):
|
||||
headers = {
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
payload = {'messages': session.messages}
|
||||
payload = {'messages': session.messages, 'system': self.prompt} if self.prompt_enabled else {'messages': session.messages}
|
||||
response = requests.request("POST", url, headers=headers, data=json.dumps(payload))
|
||||
response_text = json.loads(response.text)
|
||||
logger.info(f"[BAIDU] response text={response_text}")
|
||||
@@ -94,7 +107,7 @@ class BaiduWenxinBot(Bot):
|
||||
logger.warn("[BAIDU] Exception: {}".format(e))
|
||||
need_retry = False
|
||||
self.sessions.clear_session(session.session_id)
|
||||
result = {"completion_tokens": 0, "content": "出错了: {}".format(e)}
|
||||
result = {"total_tokens": 0, "completion_tokens": 0, "content": "出错了: {}".format(e)}
|
||||
return result
|
||||
|
||||
def get_access_token(self):
|
||||
|
||||
@@ -43,13 +43,30 @@ def create_bot(bot_type):
|
||||
elif bot_type == const.CLAUDEAI:
|
||||
from bot.claude.claude_ai_bot import ClaudeAIBot
|
||||
return ClaudeAIBot()
|
||||
|
||||
elif bot_type == const.CLAUDEAPI:
|
||||
from bot.claudeapi.claude_api_bot import ClaudeAPIBot
|
||||
return ClaudeAPIBot()
|
||||
elif bot_type == const.QWEN:
|
||||
from bot.ali.ali_qwen_bot import AliQwenBot
|
||||
return AliQwenBot()
|
||||
|
||||
elif bot_type == const.QWEN_DASHSCOPE:
|
||||
from bot.dashscope.dashscope_bot import DashscopeBot
|
||||
return DashscopeBot()
|
||||
elif bot_type == const.GEMINI:
|
||||
from bot.gemini.google_gemini_bot import GoogleGeminiBot
|
||||
return GoogleGeminiBot()
|
||||
|
||||
elif bot_type == const.ZHIPU_AI:
|
||||
from bot.zhipuai.zhipuai_bot import ZHIPUAIBot
|
||||
return ZHIPUAIBot()
|
||||
|
||||
elif bot_type == const.MOONSHOT:
|
||||
from bot.moonshot.moonshot_bot import MoonshotBot
|
||||
return MoonshotBot()
|
||||
|
||||
elif bot_type == const.MiniMax:
|
||||
from bot.minimax.minimax_bot import MinimaxBot
|
||||
return MinimaxBot()
|
||||
|
||||
|
||||
raise RuntimeError
|
||||
|
||||
@@ -5,7 +5,7 @@ import time
|
||||
import openai
|
||||
import openai.error
|
||||
import requests
|
||||
|
||||
from common import const
|
||||
from bot.bot import Bot
|
||||
from bot.chatgpt.chat_gpt_session import ChatGPTSession
|
||||
from bot.openai.open_ai_image import OpenAIImage
|
||||
@@ -15,7 +15,7 @@ from bridge.reply import Reply, ReplyType
|
||||
from common.log import logger
|
||||
from common.token_bucket import TokenBucket
|
||||
from config import conf, load_config
|
||||
|
||||
from bot.baidu.baidu_wenxin_session import BaiduWenxinSession
|
||||
|
||||
# OpenAI对话模型API (可用)
|
||||
class ChatGPTBot(Bot, OpenAIImage):
|
||||
@@ -30,10 +30,12 @@ class ChatGPTBot(Bot, OpenAIImage):
|
||||
openai.proxy = proxy
|
||||
if conf().get("rate_limit_chatgpt"):
|
||||
self.tb4chatgpt = TokenBucket(conf().get("rate_limit_chatgpt", 20))
|
||||
|
||||
conf_model = conf().get("model") or "gpt-3.5-turbo"
|
||||
self.sessions = SessionManager(ChatGPTSession, model=conf().get("model") or "gpt-3.5-turbo")
|
||||
# o1相关模型不支持system prompt,暂时用文心模型的session
|
||||
|
||||
self.args = {
|
||||
"model": conf().get("model") or "gpt-3.5-turbo", # 对话模型的名称
|
||||
"model": conf_model, # 对话模型的名称
|
||||
"temperature": conf().get("temperature", 0.9), # 值在[0,1]之间,越大表示回复越具有不确定性
|
||||
# "max_tokens":4096, # 回复最大的字符数
|
||||
"top_p": conf().get("top_p", 1),
|
||||
@@ -42,6 +44,12 @@ class ChatGPTBot(Bot, OpenAIImage):
|
||||
"request_timeout": conf().get("request_timeout", None), # 请求超时时间,openai接口默认设置为600,对于难问题一般需要较长时间
|
||||
"timeout": conf().get("request_timeout", None), # 重试超时时间,在这个时间内,将会自动重试
|
||||
}
|
||||
# o1相关模型固定了部分参数,暂时去掉
|
||||
if conf_model in [const.O1, const.O1_MINI]:
|
||||
self.sessions = SessionManager(BaiduWenxinSession, model=conf().get("model") or const.O1_MINI)
|
||||
remove_keys = ["temperature", "top_p", "frequency_penalty", "presence_penalty"]
|
||||
for key in remove_keys:
|
||||
self.args.pop(key, None) # 如果键不存在,使用 None 来避免抛出错误
|
||||
|
||||
def reply(self, query, context=None):
|
||||
# acquire reply content
|
||||
@@ -171,24 +179,70 @@ class AzureChatGPTBot(ChatGPTBot):
|
||||
self.args["deployment_id"] = conf().get("azure_deployment_id")
|
||||
|
||||
def create_img(self, query, retry_count=0, api_key=None):
|
||||
api_version = "2022-08-03-preview"
|
||||
url = "{}dalle/text-to-image?api-version={}".format(openai.api_base, api_version)
|
||||
api_key = api_key or openai.api_key
|
||||
headers = {"api-key": api_key, "Content-Type": "application/json"}
|
||||
try:
|
||||
body = {"caption": query, "resolution": conf().get("image_create_size", "256x256")}
|
||||
submission = requests.post(url, headers=headers, json=body)
|
||||
operation_location = submission.headers["Operation-Location"]
|
||||
retry_after = submission.headers["Retry-after"]
|
||||
status = ""
|
||||
image_url = ""
|
||||
while status != "Succeeded":
|
||||
logger.info("waiting for image create..., " + status + ",retry after " + retry_after + " seconds")
|
||||
time.sleep(int(retry_after))
|
||||
response = requests.get(operation_location, headers=headers)
|
||||
status = response.json()["status"]
|
||||
image_url = response.json()["result"]["contentUrl"]
|
||||
return True, image_url
|
||||
except Exception as e:
|
||||
logger.error("create image error: {}".format(e))
|
||||
return False, "图片生成失败"
|
||||
text_to_image_model = conf().get("text_to_image")
|
||||
if text_to_image_model == "dall-e-2":
|
||||
api_version = "2023-06-01-preview"
|
||||
endpoint = conf().get("azure_openai_dalle_api_base","open_ai_api_base")
|
||||
# 检查endpoint是否以/结尾
|
||||
if not endpoint.endswith("/"):
|
||||
endpoint = endpoint + "/"
|
||||
url = "{}openai/images/generations:submit?api-version={}".format(endpoint, api_version)
|
||||
api_key = conf().get("azure_openai_dalle_api_key","open_ai_api_key")
|
||||
headers = {"api-key": api_key, "Content-Type": "application/json"}
|
||||
try:
|
||||
body = {"prompt": query, "size": conf().get("image_create_size", "256x256"),"n": 1}
|
||||
submission = requests.post(url, headers=headers, json=body)
|
||||
operation_location = submission.headers['operation-location']
|
||||
status = ""
|
||||
while (status != "succeeded"):
|
||||
if retry_count > 3:
|
||||
return False, "图片生成失败"
|
||||
response = requests.get(operation_location, headers=headers)
|
||||
status = response.json()['status']
|
||||
retry_count += 1
|
||||
image_url = response.json()['result']['data'][0]['url']
|
||||
return True, image_url
|
||||
except Exception as e:
|
||||
logger.error("create image error: {}".format(e))
|
||||
return False, "图片生成失败"
|
||||
elif text_to_image_model == "dall-e-3":
|
||||
api_version = conf().get("azure_api_version", "2024-02-15-preview")
|
||||
endpoint = conf().get("azure_openai_dalle_api_base","open_ai_api_base")
|
||||
# 检查endpoint是否以/结尾
|
||||
if not endpoint.endswith("/"):
|
||||
endpoint = endpoint + "/"
|
||||
url = "{}openai/deployments/{}/images/generations?api-version={}".format(endpoint, conf().get("azure_openai_dalle_deployment_id","text_to_image"),api_version)
|
||||
api_key = conf().get("azure_openai_dalle_api_key","open_ai_api_key")
|
||||
headers = {"api-key": api_key, "Content-Type": "application/json"}
|
||||
try:
|
||||
body = {"prompt": query, "size": conf().get("image_create_size", "1024x1024"), "quality": conf().get("dalle3_image_quality", "standard")}
|
||||
response = requests.post(url, headers=headers, json=body)
|
||||
response.raise_for_status() # 检查请求是否成功
|
||||
data = response.json()
|
||||
|
||||
# 检查响应中是否包含图像 URL
|
||||
if 'data' in data and len(data['data']) > 0 and 'url' in data['data'][0]:
|
||||
image_url = data['data'][0]['url']
|
||||
return True, image_url
|
||||
else:
|
||||
error_message = "响应中没有图像 URL"
|
||||
logger.error(error_message)
|
||||
return False, "图片生成失败"
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
# 捕获所有请求相关的异常
|
||||
try:
|
||||
error_detail = response.json().get('error', {}).get('message', str(e))
|
||||
except ValueError:
|
||||
error_detail = str(e)
|
||||
error_message = f"{error_detail}"
|
||||
logger.error(error_message)
|
||||
return False, error_message
|
||||
|
||||
except Exception as e:
|
||||
# 捕获所有其他异常
|
||||
error_message = f"生成图像时发生错误: {e}"
|
||||
logger.error(error_message)
|
||||
return False, "图片生成失败"
|
||||
else:
|
||||
return False, "图片生成失败,未配置text_to_image参数"
|
||||
|
||||
@@ -57,18 +57,20 @@ class ChatGPTSession(Session):
|
||||
def num_tokens_from_messages(messages, model):
|
||||
"""Returns the number of tokens used by a list of messages."""
|
||||
|
||||
if model in ["wenxin", "xunfei", const.GEMINI]:
|
||||
if model in ["wenxin", "xunfei"] or model.startswith(const.GEMINI):
|
||||
return num_tokens_by_character(messages)
|
||||
|
||||
import tiktoken
|
||||
|
||||
if model in ["gpt-3.5-turbo-0301", "gpt-35-turbo", "gpt-3.5-turbo-1106"]:
|
||||
if model in ["gpt-3.5-turbo-0301", "gpt-35-turbo", "gpt-3.5-turbo-1106", "moonshot", const.LINKAI_35]:
|
||||
return num_tokens_from_messages(messages, model="gpt-3.5-turbo")
|
||||
elif model in ["gpt-4-0314", "gpt-4-0613", "gpt-4-32k", "gpt-4-32k-0613", "gpt-3.5-turbo-0613",
|
||||
"gpt-3.5-turbo-16k", "gpt-3.5-turbo-16k-0613", "gpt-35-turbo-16k", "gpt-4-turbo-preview",
|
||||
"gpt-4-1106-preview", const.GPT4_TURBO_PREVIEW, const.GPT4_VISION_PREVIEW]:
|
||||
"gpt-4-1106-preview", const.GPT4_TURBO_PREVIEW, const.GPT4_VISION_PREVIEW, const.GPT4_TURBO_01_25,
|
||||
const.GPT_4o, const.GPT_4O_0806, const.GPT_4o_MINI, const.LINKAI_4o, const.LINKAI_4_TURBO]:
|
||||
return num_tokens_from_messages(messages, model="gpt-4")
|
||||
|
||||
elif model.startswith("claude-3"):
|
||||
return num_tokens_from_messages(messages, model="gpt-3.5-turbo")
|
||||
try:
|
||||
encoding = tiktoken.encoding_for_model(model)
|
||||
except KeyError:
|
||||
|
||||
132
bot/claudeapi/claude_api_bot.py
Normal file
132
bot/claudeapi/claude_api_bot.py
Normal file
@@ -0,0 +1,132 @@
|
||||
# encoding:utf-8
|
||||
|
||||
import time
|
||||
|
||||
import openai
|
||||
import openai.error
|
||||
import anthropic
|
||||
|
||||
from bot.bot import Bot
|
||||
from bot.openai.open_ai_image import OpenAIImage
|
||||
from bot.baidu.baidu_wenxin_session import BaiduWenxinSession
|
||||
from bot.session_manager import SessionManager
|
||||
from bridge.context import ContextType
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from common.log import logger
|
||||
from common import const
|
||||
from config import conf
|
||||
|
||||
user_session = dict()
|
||||
|
||||
|
||||
# OpenAI对话模型API (可用)
|
||||
class ClaudeAPIBot(Bot, OpenAIImage):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
proxy = conf().get("proxy", None)
|
||||
base_url = conf().get("open_ai_api_base", None) # 复用"open_ai_api_base"参数作为base_url
|
||||
self.claudeClient = anthropic.Anthropic(
|
||||
api_key=conf().get("claude_api_key"),
|
||||
proxies=proxy if proxy else None,
|
||||
base_url=base_url if base_url else None
|
||||
)
|
||||
self.sessions = SessionManager(BaiduWenxinSession, model=conf().get("model") or "text-davinci-003")
|
||||
|
||||
def reply(self, query, context=None):
|
||||
# acquire reply content
|
||||
if context and context.type:
|
||||
if context.type == ContextType.TEXT:
|
||||
logger.info("[CLAUDE_API] query={}".format(query))
|
||||
session_id = context["session_id"]
|
||||
reply = None
|
||||
if query == "#清除记忆":
|
||||
self.sessions.clear_session(session_id)
|
||||
reply = Reply(ReplyType.INFO, "记忆已清除")
|
||||
elif query == "#清除所有":
|
||||
self.sessions.clear_all_session()
|
||||
reply = Reply(ReplyType.INFO, "所有人记忆已清除")
|
||||
else:
|
||||
session = self.sessions.session_query(query, session_id)
|
||||
result = self.reply_text(session)
|
||||
logger.info(result)
|
||||
total_tokens, completion_tokens, reply_content = (
|
||||
result["total_tokens"],
|
||||
result["completion_tokens"],
|
||||
result["content"],
|
||||
)
|
||||
logger.debug(
|
||||
"[CLAUDE_API] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(str(session), session_id, reply_content, completion_tokens)
|
||||
)
|
||||
|
||||
if total_tokens == 0:
|
||||
reply = Reply(ReplyType.ERROR, reply_content)
|
||||
else:
|
||||
self.sessions.session_reply(reply_content, session_id, total_tokens)
|
||||
reply = Reply(ReplyType.TEXT, reply_content)
|
||||
return reply
|
||||
elif context.type == ContextType.IMAGE_CREATE:
|
||||
ok, retstring = self.create_img(query, 0)
|
||||
reply = None
|
||||
if ok:
|
||||
reply = Reply(ReplyType.IMAGE_URL, retstring)
|
||||
else:
|
||||
reply = Reply(ReplyType.ERROR, retstring)
|
||||
return reply
|
||||
|
||||
def reply_text(self, session: BaiduWenxinSession, retry_count=0):
|
||||
try:
|
||||
actual_model = self._model_mapping(conf().get("model"))
|
||||
response = self.claudeClient.messages.create(
|
||||
model=actual_model,
|
||||
max_tokens=4096,
|
||||
system=conf().get("character_desc", ""),
|
||||
messages=session.messages
|
||||
)
|
||||
# response = openai.Completion.create(prompt=str(session), **self.args)
|
||||
res_content = response.content[0].text.strip().replace("<|endoftext|>", "")
|
||||
total_tokens = response.usage.input_tokens+response.usage.output_tokens
|
||||
completion_tokens = response.usage.output_tokens
|
||||
logger.info("[CLAUDE_API] reply={}".format(res_content))
|
||||
return {
|
||||
"total_tokens": total_tokens,
|
||||
"completion_tokens": completion_tokens,
|
||||
"content": res_content,
|
||||
}
|
||||
except Exception as e:
|
||||
need_retry = retry_count < 2
|
||||
result = {"total_tokens": 0, "completion_tokens": 0, "content": "我现在有点累了,等会再来吧"}
|
||||
if isinstance(e, openai.error.RateLimitError):
|
||||
logger.warn("[CLAUDE_API] RateLimitError: {}".format(e))
|
||||
result["content"] = "提问太快啦,请休息一下再问我吧"
|
||||
if need_retry:
|
||||
time.sleep(20)
|
||||
elif isinstance(e, openai.error.Timeout):
|
||||
logger.warn("[CLAUDE_API] Timeout: {}".format(e))
|
||||
result["content"] = "我没有收到你的消息"
|
||||
if need_retry:
|
||||
time.sleep(5)
|
||||
elif isinstance(e, openai.error.APIConnectionError):
|
||||
logger.warn("[CLAUDE_API] APIConnectionError: {}".format(e))
|
||||
need_retry = False
|
||||
result["content"] = "我连接不到你的网络"
|
||||
else:
|
||||
logger.warn("[CLAUDE_API] Exception: {}".format(e))
|
||||
need_retry = False
|
||||
self.sessions.clear_session(session.session_id)
|
||||
|
||||
if need_retry:
|
||||
logger.warn("[CLAUDE_API] 第{}次重试".format(retry_count + 1))
|
||||
return self.reply_text(session, retry_count + 1)
|
||||
else:
|
||||
return result
|
||||
|
||||
def _model_mapping(self, model) -> str:
|
||||
if model == "claude-3-opus":
|
||||
return const.CLAUDE_3_OPUS
|
||||
elif model == "claude-3-sonnet":
|
||||
return const.CLAUDE_3_SONNET
|
||||
elif model == "claude-3-haiku":
|
||||
return const.CLAUDE_3_HAIKU
|
||||
elif model == "claude-3.5-sonnet":
|
||||
return const.CLAUDE_35_SONNET
|
||||
return model
|
||||
117
bot/dashscope/dashscope_bot.py
Normal file
117
bot/dashscope/dashscope_bot.py
Normal file
@@ -0,0 +1,117 @@
|
||||
# encoding:utf-8
|
||||
|
||||
from bot.bot import Bot
|
||||
from bot.session_manager import SessionManager
|
||||
from bridge.context import ContextType
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from common.log import logger
|
||||
from config import conf, load_config
|
||||
from .dashscope_session import DashscopeSession
|
||||
import os
|
||||
import dashscope
|
||||
from http import HTTPStatus
|
||||
|
||||
|
||||
|
||||
dashscope_models = {
|
||||
"qwen-turbo": dashscope.Generation.Models.qwen_turbo,
|
||||
"qwen-plus": dashscope.Generation.Models.qwen_plus,
|
||||
"qwen-max": dashscope.Generation.Models.qwen_max,
|
||||
"qwen-bailian-v1": dashscope.Generation.Models.bailian_v1
|
||||
}
|
||||
# ZhipuAI对话模型API
|
||||
class DashscopeBot(Bot):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.sessions = SessionManager(DashscopeSession, model=conf().get("model") or "qwen-plus")
|
||||
self.model_name = conf().get("model") or "qwen-plus"
|
||||
self.api_key = conf().get("dashscope_api_key")
|
||||
os.environ["DASHSCOPE_API_KEY"] = self.api_key
|
||||
self.client = dashscope.Generation
|
||||
|
||||
def reply(self, query, context=None):
|
||||
# acquire reply content
|
||||
if context.type == ContextType.TEXT:
|
||||
logger.info("[DASHSCOPE] query={}".format(query))
|
||||
|
||||
session_id = context["session_id"]
|
||||
reply = None
|
||||
clear_memory_commands = conf().get("clear_memory_commands", ["#清除记忆"])
|
||||
if query in clear_memory_commands:
|
||||
self.sessions.clear_session(session_id)
|
||||
reply = Reply(ReplyType.INFO, "记忆已清除")
|
||||
elif query == "#清除所有":
|
||||
self.sessions.clear_all_session()
|
||||
reply = Reply(ReplyType.INFO, "所有人记忆已清除")
|
||||
elif query == "#更新配置":
|
||||
load_config()
|
||||
reply = Reply(ReplyType.INFO, "配置已更新")
|
||||
if reply:
|
||||
return reply
|
||||
session = self.sessions.session_query(query, session_id)
|
||||
logger.debug("[DASHSCOPE] session query={}".format(session.messages))
|
||||
|
||||
reply_content = self.reply_text(session)
|
||||
logger.debug(
|
||||
"[DASHSCOPE] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(
|
||||
session.messages,
|
||||
session_id,
|
||||
reply_content["content"],
|
||||
reply_content["completion_tokens"],
|
||||
)
|
||||
)
|
||||
if reply_content["completion_tokens"] == 0 and len(reply_content["content"]) > 0:
|
||||
reply = Reply(ReplyType.ERROR, reply_content["content"])
|
||||
elif reply_content["completion_tokens"] > 0:
|
||||
self.sessions.session_reply(reply_content["content"], session_id, reply_content["total_tokens"])
|
||||
reply = Reply(ReplyType.TEXT, reply_content["content"])
|
||||
else:
|
||||
reply = Reply(ReplyType.ERROR, reply_content["content"])
|
||||
logger.debug("[DASHSCOPE] reply {} used 0 tokens.".format(reply_content))
|
||||
return reply
|
||||
else:
|
||||
reply = Reply(ReplyType.ERROR, "Bot不支持处理{}类型的消息".format(context.type))
|
||||
return reply
|
||||
|
||||
def reply_text(self, session: DashscopeSession, retry_count=0) -> dict:
|
||||
"""
|
||||
call openai's ChatCompletion to get the answer
|
||||
:param session: a conversation session
|
||||
:param session_id: session id
|
||||
:param retry_count: retry count
|
||||
:return: {}
|
||||
"""
|
||||
try:
|
||||
dashscope.api_key = self.api_key
|
||||
response = self.client.call(
|
||||
dashscope_models[self.model_name],
|
||||
messages=session.messages,
|
||||
result_format="message"
|
||||
)
|
||||
if response.status_code == HTTPStatus.OK:
|
||||
content = response.output.choices[0]["message"]["content"]
|
||||
return {
|
||||
"total_tokens": response.usage["total_tokens"],
|
||||
"completion_tokens": response.usage["output_tokens"],
|
||||
"content": content,
|
||||
}
|
||||
else:
|
||||
logger.error('Request id: %s, Status code: %s, error code: %s, error message: %s' % (
|
||||
response.request_id, response.status_code,
|
||||
response.code, response.message
|
||||
))
|
||||
result = {"completion_tokens": 0, "content": "我现在有点累了,等会再来吧"}
|
||||
need_retry = retry_count < 2
|
||||
result = {"completion_tokens": 0, "content": "我现在有点累了,等会再来吧"}
|
||||
if need_retry:
|
||||
return self.reply_text(session, retry_count + 1)
|
||||
else:
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
need_retry = retry_count < 2
|
||||
result = {"completion_tokens": 0, "content": "我现在有点累了,等会再来吧"}
|
||||
if need_retry:
|
||||
return self.reply_text(session, retry_count + 1)
|
||||
else:
|
||||
return result
|
||||
51
bot/dashscope/dashscope_session.py
Normal file
51
bot/dashscope/dashscope_session.py
Normal file
@@ -0,0 +1,51 @@
|
||||
from bot.session_manager import Session
|
||||
from common.log import logger
|
||||
|
||||
|
||||
class DashscopeSession(Session):
|
||||
def __init__(self, session_id, system_prompt=None, model="qwen-turbo"):
|
||||
super().__init__(session_id)
|
||||
self.reset()
|
||||
|
||||
def discard_exceeding(self, max_tokens, cur_tokens=None):
|
||||
precise = True
|
||||
try:
|
||||
cur_tokens = self.calc_tokens()
|
||||
except Exception as e:
|
||||
precise = False
|
||||
if cur_tokens is None:
|
||||
raise e
|
||||
logger.debug("Exception when counting tokens precisely for query: {}".format(e))
|
||||
while cur_tokens > max_tokens:
|
||||
if len(self.messages) > 2:
|
||||
self.messages.pop(1)
|
||||
elif len(self.messages) == 2 and self.messages[1]["role"] == "assistant":
|
||||
self.messages.pop(1)
|
||||
if precise:
|
||||
cur_tokens = self.calc_tokens()
|
||||
else:
|
||||
cur_tokens = cur_tokens - max_tokens
|
||||
break
|
||||
elif len(self.messages) == 2 and self.messages[1]["role"] == "user":
|
||||
logger.warn("user message exceed max_tokens. total_tokens={}".format(cur_tokens))
|
||||
break
|
||||
else:
|
||||
logger.debug("max_tokens={}, total_tokens={}, len(messages)={}".format(max_tokens, cur_tokens,
|
||||
len(self.messages)))
|
||||
break
|
||||
if precise:
|
||||
cur_tokens = self.calc_tokens()
|
||||
else:
|
||||
cur_tokens = cur_tokens - max_tokens
|
||||
return cur_tokens
|
||||
|
||||
def calc_tokens(self):
|
||||
return num_tokens_from_messages(self.messages)
|
||||
|
||||
|
||||
def num_tokens_from_messages(messages):
|
||||
# 只是大概,具体计算规则:https://help.aliyun.com/zh/dashscope/developer-reference/token-api?spm=a2c4g.11186623.0.0.4d8b12b0BkP3K9
|
||||
tokens = 0
|
||||
for msg in messages:
|
||||
tokens += len(msg["content"])
|
||||
return tokens
|
||||
@@ -13,7 +13,9 @@ from bridge.context import ContextType, Context
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from common.log import logger
|
||||
from config import conf
|
||||
from bot.chatgpt.chat_gpt_session import ChatGPTSession
|
||||
from bot.baidu.baidu_wenxin_session import BaiduWenxinSession
|
||||
from google.generativeai.types import HarmCategory, HarmBlockThreshold
|
||||
|
||||
|
||||
# OpenAI对话模型API (可用)
|
||||
@@ -22,9 +24,11 @@ class GoogleGeminiBot(Bot):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.api_key = conf().get("gemini_api_key")
|
||||
# 复用文心的token计算方式
|
||||
self.sessions = SessionManager(BaiduWenxinSession, model=conf().get("model") or "gpt-3.5-turbo")
|
||||
|
||||
# 复用chatGPT的token计算方式
|
||||
self.sessions = SessionManager(ChatGPTSession, model=conf().get("model") or "gpt-3.5-turbo")
|
||||
self.model = conf().get("model") or "gemini-pro"
|
||||
if self.model == "gemini":
|
||||
self.model = "gemini-pro"
|
||||
def reply(self, query, context: Context = None) -> Reply:
|
||||
try:
|
||||
if context.type != ContextType.TEXT:
|
||||
@@ -33,18 +37,45 @@ class GoogleGeminiBot(Bot):
|
||||
logger.info(f"[Gemini] query={query}")
|
||||
session_id = context["session_id"]
|
||||
session = self.sessions.session_query(query, session_id)
|
||||
gemini_messages = self._convert_to_gemini_messages(self._filter_messages(session.messages))
|
||||
gemini_messages = self._convert_to_gemini_messages(self.filter_messages(session.messages))
|
||||
logger.debug(f"[Gemini] messages={gemini_messages}")
|
||||
genai.configure(api_key=self.api_key)
|
||||
model = genai.GenerativeModel('gemini-pro')
|
||||
response = model.generate_content(gemini_messages)
|
||||
reply_text = response.text
|
||||
self.sessions.session_reply(reply_text, session_id)
|
||||
logger.info(f"[Gemini] reply={reply_text}")
|
||||
return Reply(ReplyType.TEXT, reply_text)
|
||||
model = genai.GenerativeModel(self.model)
|
||||
|
||||
# 添加安全设置
|
||||
safety_settings = {
|
||||
HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
|
||||
HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
|
||||
HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
|
||||
HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
|
||||
}
|
||||
|
||||
# 生成回复,包含安全设置
|
||||
response = model.generate_content(
|
||||
gemini_messages,
|
||||
safety_settings=safety_settings
|
||||
)
|
||||
if response.candidates and response.candidates[0].content:
|
||||
reply_text = response.candidates[0].content.parts[0].text
|
||||
logger.info(f"[Gemini] reply={reply_text}")
|
||||
self.sessions.session_reply(reply_text, session_id)
|
||||
return Reply(ReplyType.TEXT, reply_text)
|
||||
else:
|
||||
# 没有有效响应内容,可能内容被屏蔽,输出安全评分
|
||||
logger.warning("[Gemini] No valid response generated. Checking safety ratings.")
|
||||
if hasattr(response, 'candidates') and response.candidates:
|
||||
for rating in response.candidates[0].safety_ratings:
|
||||
logger.warning(f"Safety rating: {rating.category} - {rating.probability}")
|
||||
error_message = "No valid response generated due to safety constraints."
|
||||
self.sessions.session_reply(error_message, session_id)
|
||||
return Reply(ReplyType.ERROR, error_message)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("[Gemini] fetch reply error, may contain unsafe content")
|
||||
logger.error(e)
|
||||
|
||||
logger.error(f"[Gemini] Error generating response: {str(e)}", exc_info=True)
|
||||
error_message = "Failed to invoke [Gemini] api!"
|
||||
self.sessions.session_reply(error_message, session_id)
|
||||
return Reply(ReplyType.ERROR, error_message)
|
||||
|
||||
def _convert_to_gemini_messages(self, messages: list):
|
||||
res = []
|
||||
for msg in messages:
|
||||
@@ -52,6 +83,8 @@ class GoogleGeminiBot(Bot):
|
||||
role = "user"
|
||||
elif msg.get("role") == "assistant":
|
||||
role = "model"
|
||||
elif msg.get("role") == "system":
|
||||
role = "user"
|
||||
else:
|
||||
continue
|
||||
res.append({
|
||||
@@ -60,12 +93,19 @@ class GoogleGeminiBot(Bot):
|
||||
})
|
||||
return res
|
||||
|
||||
def _filter_messages(self, messages: list):
|
||||
@staticmethod
|
||||
def filter_messages(messages: list):
|
||||
res = []
|
||||
turn = "user"
|
||||
if not messages:
|
||||
return res
|
||||
for i in range(len(messages) - 1, -1, -1):
|
||||
message = messages[i]
|
||||
if message.get("role") != turn:
|
||||
role = message.get("role")
|
||||
if role == "system":
|
||||
res.insert(0, message)
|
||||
continue
|
||||
if role != turn:
|
||||
continue
|
||||
res.insert(0, message)
|
||||
if turn == "user":
|
||||
|
||||
@@ -92,7 +92,8 @@ class LinkAIBot(Bot):
|
||||
"frequency_penalty": conf().get("frequency_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
|
||||
"presence_penalty": conf().get("presence_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
|
||||
"session_id": session_id,
|
||||
"channel_type": conf().get("channel_type")
|
||||
"sender_id": session_id,
|
||||
"channel_type": conf().get("channel_type", "wx")
|
||||
}
|
||||
try:
|
||||
from linkai import LinkAIClient
|
||||
@@ -107,7 +108,11 @@ class LinkAIBot(Bot):
|
||||
body["group_name"] = context.kwargs.get("msg").from_user_nickname
|
||||
body["sender_name"] = context.kwargs.get("msg").actual_user_nickname
|
||||
else:
|
||||
body["sender_name"] = context.kwargs.get("msg").from_user_nickname
|
||||
if body.get("channel_type") in ["wechatcom_app"]:
|
||||
body["sender_name"] = context.kwargs.get("msg").from_user_id
|
||||
else:
|
||||
body["sender_name"] = context.kwargs.get("msg").from_user_nickname
|
||||
|
||||
except Exception as e:
|
||||
pass
|
||||
file_id = context.kwargs.get("file_id")
|
||||
@@ -117,7 +122,7 @@ class LinkAIBot(Bot):
|
||||
headers = {"Authorization": "Bearer " + linkai_api_key}
|
||||
|
||||
# do http request
|
||||
base_url = conf().get("linkai_api_base", "https://api.link-ai.chat")
|
||||
base_url = conf().get("linkai_api_base", "https://api.link-ai.tech")
|
||||
res = requests.post(url=base_url + "/v1/chat/completions", json=body, headers=headers,
|
||||
timeout=conf().get("request_timeout", 180))
|
||||
if res.status_code == 200:
|
||||
@@ -125,9 +130,12 @@ class LinkAIBot(Bot):
|
||||
response = res.json()
|
||||
reply_content = response["choices"][0]["message"]["content"]
|
||||
total_tokens = response["usage"]["total_tokens"]
|
||||
logger.info(f"[LINKAI] reply={reply_content}, total_tokens={total_tokens}")
|
||||
self.sessions.session_reply(reply_content, session_id, total_tokens, query=query)
|
||||
|
||||
res_code = response.get('code')
|
||||
logger.info(f"[LINKAI] reply={reply_content}, total_tokens={total_tokens}, res_code={res_code}")
|
||||
if res_code == 429:
|
||||
logger.warn(f"[LINKAI] 用户访问超出限流配置,sender_id={body.get('sender_id')}")
|
||||
else:
|
||||
self.sessions.session_reply(reply_content, session_id, total_tokens, query=query)
|
||||
agent_suffix = self._fetch_agent_suffix(response)
|
||||
if agent_suffix:
|
||||
reply_content += agent_suffix
|
||||
@@ -156,7 +164,10 @@ class LinkAIBot(Bot):
|
||||
logger.warn(f"[LINKAI] do retry, times={retry_count}")
|
||||
return self._chat(query, context, retry_count + 1)
|
||||
|
||||
return Reply(ReplyType.TEXT, "提问太快啦,请休息一下再问我吧")
|
||||
error_reply = "提问太快啦,请休息一下再问我吧"
|
||||
if res.status_code == 409:
|
||||
error_reply = "这个问题我还没有学会,请问我其它问题吧"
|
||||
return Reply(ReplyType.TEXT, error_reply)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
@@ -250,7 +261,7 @@ class LinkAIBot(Bot):
|
||||
headers = {"Authorization": "Bearer " + conf().get("linkai_api_key")}
|
||||
|
||||
# do http request
|
||||
base_url = conf().get("linkai_api_base", "https://api.link-ai.chat")
|
||||
base_url = conf().get("linkai_api_base", "https://api.link-ai.tech")
|
||||
res = requests.post(url=base_url + "/v1/chat/completions", json=body, headers=headers,
|
||||
timeout=conf().get("request_timeout", 180))
|
||||
if res.status_code == 200:
|
||||
@@ -293,7 +304,7 @@ class LinkAIBot(Bot):
|
||||
def _fetch_app_info(self, app_code: str):
|
||||
headers = {"Authorization": "Bearer " + conf().get("linkai_api_key")}
|
||||
# do http request
|
||||
base_url = conf().get("linkai_api_base", "https://api.link-ai.chat")
|
||||
base_url = conf().get("linkai_api_base", "https://api.link-ai.tech")
|
||||
params = {"app_code": app_code}
|
||||
res = requests.get(url=base_url + "/v1/app/info", params=params, headers=headers, timeout=(5, 10))
|
||||
if res.status_code == 200:
|
||||
@@ -315,7 +326,7 @@ class LinkAIBot(Bot):
|
||||
"response_format": "url",
|
||||
"img_proxy": conf().get("image_proxy")
|
||||
}
|
||||
url = conf().get("linkai_api_base", "https://api.link-ai.chat") + "/v1/images/generations"
|
||||
url = conf().get("linkai_api_base", "https://api.link-ai.tech") + "/v1/images/generations"
|
||||
res = requests.post(url, headers=headers, json=data, timeout=(5, 90))
|
||||
t2 = time.time()
|
||||
image_url = res.json()["data"][0]["url"]
|
||||
@@ -386,11 +397,18 @@ class LinkAIBot(Bot):
|
||||
def _send_image(self, channel, context, image_urls):
|
||||
if not image_urls:
|
||||
return
|
||||
max_send_num = conf().get("max_media_send_count")
|
||||
send_interval = conf().get("media_send_interval")
|
||||
file_type = (".pdf", ".doc", ".docx", ".csv", ".xls", ".xlsx", ".txt", ".rtf", ".ppt", ".pptx")
|
||||
try:
|
||||
i = 0
|
||||
for url in image_urls:
|
||||
if max_send_num and i >= max_send_num:
|
||||
continue
|
||||
i += 1
|
||||
if url.endswith(".mp4"):
|
||||
reply_type = ReplyType.VIDEO_URL
|
||||
elif url.endswith(".pdf") or url.endswith(".doc") or url.endswith(".docx"):
|
||||
elif url.endswith(file_type):
|
||||
reply_type = ReplyType.FILE
|
||||
url = _download_file(url)
|
||||
if not url:
|
||||
@@ -399,6 +417,8 @@ class LinkAIBot(Bot):
|
||||
reply_type = ReplyType.IMAGE_URL
|
||||
reply = Reply(reply_type, url)
|
||||
channel.send(reply, context)
|
||||
if send_interval:
|
||||
time.sleep(send_interval)
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
|
||||
|
||||
151
bot/minimax/minimax_bot.py
Normal file
151
bot/minimax/minimax_bot.py
Normal file
@@ -0,0 +1,151 @@
|
||||
# encoding:utf-8
|
||||
|
||||
import time
|
||||
|
||||
import openai
|
||||
import openai.error
|
||||
from bot.bot import Bot
|
||||
from bot.minimax.minimax_session import MinimaxSession
|
||||
from bot.session_manager import SessionManager
|
||||
from bridge.context import Context, ContextType
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from common.log import logger
|
||||
from config import conf, load_config
|
||||
from bot.chatgpt.chat_gpt_session import ChatGPTSession
|
||||
import requests
|
||||
from common import const
|
||||
|
||||
|
||||
# ZhipuAI对话模型API
|
||||
class MinimaxBot(Bot):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.args = {
|
||||
"model": conf().get("model") or "abab6.5", # 对话模型的名称
|
||||
"temperature": conf().get("temperature", 0.3), # 如果设置,值域须为 [0, 1] 我们推荐 0.3,以达到较合适的效果。
|
||||
"top_p": conf().get("top_p", 0.95), # 使用默认值
|
||||
}
|
||||
self.api_key = conf().get("Minimax_api_key")
|
||||
self.group_id = conf().get("Minimax_group_id")
|
||||
self.base_url = conf().get("Minimax_base_url", f"https://api.minimax.chat/v1/text/chatcompletion_pro?GroupId={self.group_id}")
|
||||
# tokens_to_generate/bot_setting/reply_constraints可自行修改
|
||||
self.request_body = {
|
||||
"model": self.args["model"],
|
||||
"tokens_to_generate": 2048,
|
||||
"reply_constraints": {"sender_type": "BOT", "sender_name": "MM智能助理"},
|
||||
"messages": [],
|
||||
"bot_setting": [
|
||||
{
|
||||
"bot_name": "MM智能助理",
|
||||
"content": "MM智能助理是一款由MiniMax自研的,没有调用其他产品的接口的大型语言模型。MiniMax是一家中国科技公司,一直致力于进行大模型相关的研究。",
|
||||
}
|
||||
],
|
||||
}
|
||||
self.sessions = SessionManager(MinimaxSession, model=const.MiniMax)
|
||||
|
||||
def reply(self, query, context: Context = None) -> Reply:
|
||||
# acquire reply content
|
||||
logger.info("[Minimax_AI] query={}".format(query))
|
||||
if context.type == ContextType.TEXT:
|
||||
session_id = context["session_id"]
|
||||
reply = None
|
||||
clear_memory_commands = conf().get("clear_memory_commands", ["#清除记忆"])
|
||||
if query in clear_memory_commands:
|
||||
self.sessions.clear_session(session_id)
|
||||
reply = Reply(ReplyType.INFO, "记忆已清除")
|
||||
elif query == "#清除所有":
|
||||
self.sessions.clear_all_session()
|
||||
reply = Reply(ReplyType.INFO, "所有人记忆已清除")
|
||||
elif query == "#更新配置":
|
||||
load_config()
|
||||
reply = Reply(ReplyType.INFO, "配置已更新")
|
||||
if reply:
|
||||
return reply
|
||||
session = self.sessions.session_query(query, session_id)
|
||||
logger.debug("[Minimax_AI] session query={}".format(session))
|
||||
|
||||
model = context.get("Minimax_model")
|
||||
new_args = self.args.copy()
|
||||
if model:
|
||||
new_args["model"] = model
|
||||
# if context.get('stream'):
|
||||
# # reply in stream
|
||||
# return self.reply_text_stream(query, new_query, session_id)
|
||||
|
||||
reply_content = self.reply_text(session, args=new_args)
|
||||
logger.debug(
|
||||
"[Minimax_AI] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(
|
||||
session.messages,
|
||||
session_id,
|
||||
reply_content["content"],
|
||||
reply_content["completion_tokens"],
|
||||
)
|
||||
)
|
||||
if reply_content["completion_tokens"] == 0 and len(reply_content["content"]) > 0:
|
||||
reply = Reply(ReplyType.ERROR, reply_content["content"])
|
||||
elif reply_content["completion_tokens"] > 0:
|
||||
self.sessions.session_reply(reply_content["content"], session_id, reply_content["total_tokens"])
|
||||
reply = Reply(ReplyType.TEXT, reply_content["content"])
|
||||
else:
|
||||
reply = Reply(ReplyType.ERROR, reply_content["content"])
|
||||
logger.debug("[Minimax_AI] reply {} used 0 tokens.".format(reply_content))
|
||||
return reply
|
||||
else:
|
||||
reply = Reply(ReplyType.ERROR, "Bot不支持处理{}类型的消息".format(context.type))
|
||||
return reply
|
||||
|
||||
def reply_text(self, session: MinimaxSession, args=None, retry_count=0) -> dict:
|
||||
"""
|
||||
call openai's ChatCompletion to get the answer
|
||||
:param session: a conversation session
|
||||
:param session_id: session id
|
||||
:param retry_count: retry count
|
||||
:return: {}
|
||||
"""
|
||||
try:
|
||||
headers = {"Content-Type": "application/json", "Authorization": "Bearer " + self.api_key}
|
||||
self.request_body["messages"].extend(session.messages)
|
||||
logger.info("[Minimax_AI] request_body={}".format(self.request_body))
|
||||
# logger.info("[Minimax_AI] reply={}, total_tokens={}".format(response.choices[0]['message']['content'], response["usage"]["total_tokens"]))
|
||||
res = requests.post(self.base_url, headers=headers, json=self.request_body)
|
||||
|
||||
# self.request_body["messages"].extend(response.json()["choices"][0]["messages"])
|
||||
if res.status_code == 200:
|
||||
response = res.json()
|
||||
return {
|
||||
"total_tokens": response["usage"]["total_tokens"],
|
||||
"completion_tokens": response["usage"]["total_tokens"],
|
||||
"content": response["reply"],
|
||||
}
|
||||
else:
|
||||
response = res.json()
|
||||
error = response.get("error")
|
||||
logger.error(f"[Minimax_AI] chat failed, status_code={res.status_code}, " f"msg={error.get('message')}, type={error.get('type')}")
|
||||
|
||||
result = {"completion_tokens": 0, "content": "提问太快啦,请休息一下再问我吧"}
|
||||
need_retry = False
|
||||
if res.status_code >= 500:
|
||||
# server error, need retry
|
||||
logger.warn(f"[Minimax_AI] do retry, times={retry_count}")
|
||||
need_retry = retry_count < 2
|
||||
elif res.status_code == 401:
|
||||
result["content"] = "授权失败,请检查API Key是否正确"
|
||||
elif res.status_code == 429:
|
||||
result["content"] = "请求过于频繁,请稍后再试"
|
||||
need_retry = retry_count < 2
|
||||
else:
|
||||
need_retry = False
|
||||
|
||||
if need_retry:
|
||||
time.sleep(3)
|
||||
return self.reply_text(session, args, retry_count + 1)
|
||||
else:
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
need_retry = retry_count < 2
|
||||
result = {"completion_tokens": 0, "content": "我现在有点累了,等会再来吧"}
|
||||
if need_retry:
|
||||
return self.reply_text(session, args, retry_count + 1)
|
||||
else:
|
||||
return result
|
||||
72
bot/minimax/minimax_session.py
Normal file
72
bot/minimax/minimax_session.py
Normal file
@@ -0,0 +1,72 @@
|
||||
from bot.session_manager import Session
|
||||
from common.log import logger
|
||||
|
||||
"""
|
||||
e.g.
|
||||
[
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "Who won the world series in 2020?"},
|
||||
{"role": "assistant", "content": "The Los Angeles Dodgers won the World Series in 2020."},
|
||||
{"role": "user", "content": "Where was it played?"}
|
||||
]
|
||||
"""
|
||||
|
||||
|
||||
class MinimaxSession(Session):
|
||||
def __init__(self, session_id, system_prompt=None, model="minimax"):
|
||||
super().__init__(session_id, system_prompt)
|
||||
self.model = model
|
||||
# self.reset()
|
||||
|
||||
def add_query(self, query):
|
||||
user_item = {"sender_type": "USER", "sender_name": self.session_id, "text": query}
|
||||
self.messages.append(user_item)
|
||||
|
||||
def add_reply(self, reply):
|
||||
assistant_item = {"sender_type": "BOT", "sender_name": "MM智能助理", "text": reply}
|
||||
self.messages.append(assistant_item)
|
||||
|
||||
def discard_exceeding(self, max_tokens, cur_tokens=None):
|
||||
precise = True
|
||||
try:
|
||||
cur_tokens = self.calc_tokens()
|
||||
except Exception as e:
|
||||
precise = False
|
||||
if cur_tokens is None:
|
||||
raise e
|
||||
logger.debug("Exception when counting tokens precisely for query: {}".format(e))
|
||||
while cur_tokens > max_tokens:
|
||||
if len(self.messages) > 2:
|
||||
self.messages.pop(1)
|
||||
elif len(self.messages) == 2 and self.messages[1]["sender_type"] == "BOT":
|
||||
self.messages.pop(1)
|
||||
if precise:
|
||||
cur_tokens = self.calc_tokens()
|
||||
else:
|
||||
cur_tokens = cur_tokens - max_tokens
|
||||
break
|
||||
elif len(self.messages) == 2 and self.messages[1]["sender_type"] == "USER":
|
||||
logger.warn("user message exceed max_tokens. total_tokens={}".format(cur_tokens))
|
||||
break
|
||||
else:
|
||||
logger.debug("max_tokens={}, total_tokens={}, len(messages)={}".format(max_tokens, cur_tokens, len(self.messages)))
|
||||
break
|
||||
if precise:
|
||||
cur_tokens = self.calc_tokens()
|
||||
else:
|
||||
cur_tokens = cur_tokens - max_tokens
|
||||
return cur_tokens
|
||||
|
||||
def calc_tokens(self):
|
||||
return num_tokens_from_messages(self.messages, self.model)
|
||||
|
||||
|
||||
def num_tokens_from_messages(messages, model):
|
||||
"""Returns the number of tokens used by a list of messages."""
|
||||
# 官方token计算规则:"对于中文文本来说,1个token通常对应一个汉字;对于英文文本来说,1个token通常对应3至4个字母或1个单词"
|
||||
# 详情请产看文档:https://help.aliyun.com/document_detail/2586397.html
|
||||
# 目前根据字符串长度粗略估计token数,不影响正常使用
|
||||
tokens = 0
|
||||
for msg in messages:
|
||||
tokens += len(msg["text"])
|
||||
return tokens
|
||||
146
bot/moonshot/moonshot_bot.py
Normal file
146
bot/moonshot/moonshot_bot.py
Normal file
@@ -0,0 +1,146 @@
|
||||
# encoding:utf-8
|
||||
|
||||
import time
|
||||
|
||||
import openai
|
||||
import openai.error
|
||||
from bot.bot import Bot
|
||||
from bot.session_manager import SessionManager
|
||||
from bridge.context import ContextType
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from common.log import logger
|
||||
from config import conf, load_config
|
||||
from .moonshot_session import MoonshotSession
|
||||
import requests
|
||||
|
||||
|
||||
# ZhipuAI对话模型API
|
||||
class MoonshotBot(Bot):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.sessions = SessionManager(MoonshotSession, model=conf().get("model") or "moonshot-v1-128k")
|
||||
model = conf().get("model") or "moonshot-v1-128k"
|
||||
if model == "moonshot":
|
||||
model = "moonshot-v1-32k"
|
||||
self.args = {
|
||||
"model": model, # 对话模型的名称
|
||||
"temperature": conf().get("temperature", 0.3), # 如果设置,值域须为 [0, 1] 我们推荐 0.3,以达到较合适的效果。
|
||||
"top_p": conf().get("top_p", 1.0), # 使用默认值
|
||||
}
|
||||
self.api_key = conf().get("moonshot_api_key")
|
||||
self.base_url = conf().get("moonshot_base_url", "https://api.moonshot.cn/v1/chat/completions")
|
||||
|
||||
def reply(self, query, context=None):
|
||||
# acquire reply content
|
||||
if context.type == ContextType.TEXT:
|
||||
logger.info("[MOONSHOT_AI] query={}".format(query))
|
||||
|
||||
session_id = context["session_id"]
|
||||
reply = None
|
||||
clear_memory_commands = conf().get("clear_memory_commands", ["#清除记忆"])
|
||||
if query in clear_memory_commands:
|
||||
self.sessions.clear_session(session_id)
|
||||
reply = Reply(ReplyType.INFO, "记忆已清除")
|
||||
elif query == "#清除所有":
|
||||
self.sessions.clear_all_session()
|
||||
reply = Reply(ReplyType.INFO, "所有人记忆已清除")
|
||||
elif query == "#更新配置":
|
||||
load_config()
|
||||
reply = Reply(ReplyType.INFO, "配置已更新")
|
||||
if reply:
|
||||
return reply
|
||||
session = self.sessions.session_query(query, session_id)
|
||||
logger.debug("[MOONSHOT_AI] session query={}".format(session.messages))
|
||||
|
||||
model = context.get("moonshot_model")
|
||||
new_args = self.args.copy()
|
||||
if model:
|
||||
new_args["model"] = model
|
||||
# if context.get('stream'):
|
||||
# # reply in stream
|
||||
# return self.reply_text_stream(query, new_query, session_id)
|
||||
|
||||
reply_content = self.reply_text(session, args=new_args)
|
||||
logger.debug(
|
||||
"[MOONSHOT_AI] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(
|
||||
session.messages,
|
||||
session_id,
|
||||
reply_content["content"],
|
||||
reply_content["completion_tokens"],
|
||||
)
|
||||
)
|
||||
if reply_content["completion_tokens"] == 0 and len(reply_content["content"]) > 0:
|
||||
reply = Reply(ReplyType.ERROR, reply_content["content"])
|
||||
elif reply_content["completion_tokens"] > 0:
|
||||
self.sessions.session_reply(reply_content["content"], session_id, reply_content["total_tokens"])
|
||||
reply = Reply(ReplyType.TEXT, reply_content["content"])
|
||||
else:
|
||||
reply = Reply(ReplyType.ERROR, reply_content["content"])
|
||||
logger.debug("[MOONSHOT_AI] reply {} used 0 tokens.".format(reply_content))
|
||||
return reply
|
||||
else:
|
||||
reply = Reply(ReplyType.ERROR, "Bot不支持处理{}类型的消息".format(context.type))
|
||||
return reply
|
||||
|
||||
def reply_text(self, session: MoonshotSession, args=None, retry_count=0) -> dict:
|
||||
"""
|
||||
call openai's ChatCompletion to get the answer
|
||||
:param session: a conversation session
|
||||
:param session_id: session id
|
||||
:param retry_count: retry count
|
||||
:return: {}
|
||||
"""
|
||||
try:
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": "Bearer " + self.api_key
|
||||
}
|
||||
body = args
|
||||
body["messages"] = session.messages
|
||||
# logger.debug("[MOONSHOT_AI] response={}".format(response))
|
||||
# logger.info("[MOONSHOT_AI] reply={}, total_tokens={}".format(response.choices[0]['message']['content'], response["usage"]["total_tokens"]))
|
||||
res = requests.post(
|
||||
self.base_url,
|
||||
headers=headers,
|
||||
json=body
|
||||
)
|
||||
if res.status_code == 200:
|
||||
response = res.json()
|
||||
return {
|
||||
"total_tokens": response["usage"]["total_tokens"],
|
||||
"completion_tokens": response["usage"]["completion_tokens"],
|
||||
"content": response["choices"][0]["message"]["content"]
|
||||
}
|
||||
else:
|
||||
response = res.json()
|
||||
error = response.get("error")
|
||||
logger.error(f"[MOONSHOT_AI] chat failed, status_code={res.status_code}, "
|
||||
f"msg={error.get('message')}, type={error.get('type')}")
|
||||
|
||||
result = {"completion_tokens": 0, "content": "提问太快啦,请休息一下再问我吧"}
|
||||
need_retry = False
|
||||
if res.status_code >= 500:
|
||||
# server error, need retry
|
||||
logger.warn(f"[MOONSHOT_AI] do retry, times={retry_count}")
|
||||
need_retry = retry_count < 2
|
||||
elif res.status_code == 401:
|
||||
result["content"] = "授权失败,请检查API Key是否正确"
|
||||
elif res.status_code == 429:
|
||||
result["content"] = "请求过于频繁,请稍后再试"
|
||||
need_retry = retry_count < 2
|
||||
else:
|
||||
need_retry = False
|
||||
|
||||
if need_retry:
|
||||
time.sleep(3)
|
||||
return self.reply_text(session, args, retry_count + 1)
|
||||
else:
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
need_retry = retry_count < 2
|
||||
result = {"completion_tokens": 0, "content": "我现在有点累了,等会再来吧"}
|
||||
if need_retry:
|
||||
return self.reply_text(session, args, retry_count + 1)
|
||||
else:
|
||||
return result
|
||||
51
bot/moonshot/moonshot_session.py
Normal file
51
bot/moonshot/moonshot_session.py
Normal file
@@ -0,0 +1,51 @@
|
||||
from bot.session_manager import Session
|
||||
from common.log import logger
|
||||
|
||||
|
||||
class MoonshotSession(Session):
|
||||
def __init__(self, session_id, system_prompt=None, model="moonshot-v1-128k"):
|
||||
super().__init__(session_id, system_prompt)
|
||||
self.model = model
|
||||
self.reset()
|
||||
|
||||
def discard_exceeding(self, max_tokens, cur_tokens=None):
|
||||
precise = True
|
||||
try:
|
||||
cur_tokens = self.calc_tokens()
|
||||
except Exception as e:
|
||||
precise = False
|
||||
if cur_tokens is None:
|
||||
raise e
|
||||
logger.debug("Exception when counting tokens precisely for query: {}".format(e))
|
||||
while cur_tokens > max_tokens:
|
||||
if len(self.messages) > 2:
|
||||
self.messages.pop(1)
|
||||
elif len(self.messages) == 2 and self.messages[1]["role"] == "assistant":
|
||||
self.messages.pop(1)
|
||||
if precise:
|
||||
cur_tokens = self.calc_tokens()
|
||||
else:
|
||||
cur_tokens = cur_tokens - max_tokens
|
||||
break
|
||||
elif len(self.messages) == 2 and self.messages[1]["role"] == "user":
|
||||
logger.warn("user message exceed max_tokens. total_tokens={}".format(cur_tokens))
|
||||
break
|
||||
else:
|
||||
logger.debug("max_tokens={}, total_tokens={}, len(messages)={}".format(max_tokens, cur_tokens,
|
||||
len(self.messages)))
|
||||
break
|
||||
if precise:
|
||||
cur_tokens = self.calc_tokens()
|
||||
else:
|
||||
cur_tokens = cur_tokens - max_tokens
|
||||
return cur_tokens
|
||||
|
||||
def calc_tokens(self):
|
||||
return num_tokens_from_messages(self.messages, self.model)
|
||||
|
||||
|
||||
def num_tokens_from_messages(messages, model):
|
||||
tokens = 0
|
||||
for msg in messages:
|
||||
tokens += len(msg["content"])
|
||||
return tokens
|
||||
@@ -3,7 +3,7 @@
|
||||
import requests, json
|
||||
from bot.bot import Bot
|
||||
from bot.session_manager import SessionManager
|
||||
from bot.baidu.baidu_wenxin_session import BaiduWenxinSession
|
||||
from bot.chatgpt.chat_gpt_session import ChatGPTSession
|
||||
from bridge.context import ContextType, Context
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from common.log import logger
|
||||
@@ -41,17 +41,19 @@ class XunFeiBot(Bot):
|
||||
self.api_key = conf().get("xunfei_api_key")
|
||||
self.api_secret = conf().get("xunfei_api_secret")
|
||||
# 默认使用v2.0版本: "generalv2"
|
||||
# v1.5版本为 "general"
|
||||
# v3.0版本为: "generalv3"
|
||||
self.domain = "generalv3"
|
||||
# 默认使用v2.0版本: "ws://spark-api.xf-yun.com/v2.1/chat"
|
||||
# v1.5版本为: "ws://spark-api.xf-yun.com/v1.1/chat"
|
||||
# v3.0版本为: "ws://spark-api.xf-yun.com/v3.1/chat"
|
||||
self.spark_url = "ws://spark-api.xf-yun.com/v3.1/chat"
|
||||
# Spark Lite请求地址(spark_url): wss://spark-api.xf-yun.com/v1.1/chat, 对应的domain参数为: "general"
|
||||
# Spark V2.0请求地址(spark_url): wss://spark-api.xf-yun.com/v2.1/chat, 对应的domain参数为: "generalv2"
|
||||
# Spark Pro 请求地址(spark_url): wss://spark-api.xf-yun.com/v3.1/chat, 对应的domain参数为: "generalv3"
|
||||
# Spark Pro-128K请求地址(spark_url): wss://spark-api.xf-yun.com/chat/pro-128k, 对应的domain参数为: "pro-128k"
|
||||
# Spark Max 请求地址(spark_url): wss://spark-api.xf-yun.com/v3.5/chat, 对应的domain参数为: "generalv3.5"
|
||||
# Spark4.0 Ultra 请求地址(spark_url): wss://spark-api.xf-yun.com/v4.0/chat, 对应的domain参数为: "4.0Ultra"
|
||||
# 后续模型更新,对应的参数可以参考官网文档获取:https://www.xfyun.cn/doc/spark/Web.html
|
||||
self.domain = conf().get("xunfei_domain", "generalv3.5")
|
||||
self.spark_url = conf().get("xunfei_spark_url", "wss://spark-api.xf-yun.com/v3.5/chat")
|
||||
self.host = urlparse(self.spark_url).netloc
|
||||
self.path = urlparse(self.spark_url).path
|
||||
# 和wenxin使用相同的session机制
|
||||
self.sessions = SessionManager(BaiduWenxinSession, model=const.XUNFEI)
|
||||
self.sessions = SessionManager(ChatGPTSession, model=const.XUNFEI)
|
||||
|
||||
def reply(self, query, context: Context = None) -> Reply:
|
||||
if context.type == ContextType.TEXT:
|
||||
|
||||
29
bot/zhipuai/zhipu_ai_image.py
Normal file
29
bot/zhipuai/zhipu_ai_image.py
Normal file
@@ -0,0 +1,29 @@
|
||||
from common.log import logger
|
||||
from config import conf
|
||||
|
||||
|
||||
# ZhipuAI提供的画图接口
|
||||
|
||||
class ZhipuAIImage(object):
|
||||
def __init__(self):
|
||||
from zhipuai import ZhipuAI
|
||||
self.client = ZhipuAI(api_key=conf().get("zhipu_ai_api_key"))
|
||||
|
||||
def create_img(self, query, retry_count=0, api_key=None, api_base=None):
|
||||
try:
|
||||
if conf().get("rate_limit_dalle"):
|
||||
return False, "请求太快了,请休息一下再问我吧"
|
||||
logger.info("[ZHIPU_AI] image_query={}".format(query))
|
||||
response = self.client.images.generations(
|
||||
prompt=query,
|
||||
n=1, # 每次生成图片的数量
|
||||
model=conf().get("text_to_image") or "cogview-3",
|
||||
size=conf().get("image_create_size", "1024x1024"), # 图片大小,可选有 256x256, 512x512, 1024x1024
|
||||
quality="standard",
|
||||
)
|
||||
image_url = response.data[0].url
|
||||
logger.info("[ZHIPU_AI] image_url={}".format(image_url))
|
||||
return True, image_url
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
return False, "画图出现问题,请休息一下再问我吧"
|
||||
53
bot/zhipuai/zhipu_ai_session.py
Normal file
53
bot/zhipuai/zhipu_ai_session.py
Normal file
@@ -0,0 +1,53 @@
|
||||
from bot.session_manager import Session
|
||||
from common.log import logger
|
||||
|
||||
|
||||
class ZhipuAISession(Session):
|
||||
def __init__(self, session_id, system_prompt=None, model="glm-4"):
|
||||
super().__init__(session_id, system_prompt)
|
||||
self.model = model
|
||||
self.reset()
|
||||
if not system_prompt:
|
||||
logger.warn("[ZhiPu] `character_desc` can not be empty")
|
||||
|
||||
def discard_exceeding(self, max_tokens, cur_tokens=None):
|
||||
precise = True
|
||||
try:
|
||||
cur_tokens = self.calc_tokens()
|
||||
except Exception as e:
|
||||
precise = False
|
||||
if cur_tokens is None:
|
||||
raise e
|
||||
logger.debug("Exception when counting tokens precisely for query: {}".format(e))
|
||||
while cur_tokens > max_tokens:
|
||||
if len(self.messages) > 2:
|
||||
self.messages.pop(1)
|
||||
elif len(self.messages) == 2 and self.messages[1]["role"] == "assistant":
|
||||
self.messages.pop(1)
|
||||
if precise:
|
||||
cur_tokens = self.calc_tokens()
|
||||
else:
|
||||
cur_tokens = cur_tokens - max_tokens
|
||||
break
|
||||
elif len(self.messages) == 2 and self.messages[1]["role"] == "user":
|
||||
logger.warn("user message exceed max_tokens. total_tokens={}".format(cur_tokens))
|
||||
break
|
||||
else:
|
||||
logger.debug("max_tokens={}, total_tokens={}, len(messages)={}".format(max_tokens, cur_tokens,
|
||||
len(self.messages)))
|
||||
break
|
||||
if precise:
|
||||
cur_tokens = self.calc_tokens()
|
||||
else:
|
||||
cur_tokens = cur_tokens - max_tokens
|
||||
return cur_tokens
|
||||
|
||||
def calc_tokens(self):
|
||||
return num_tokens_from_messages(self.messages, self.model)
|
||||
|
||||
|
||||
def num_tokens_from_messages(messages, model):
|
||||
tokens = 0
|
||||
for msg in messages:
|
||||
tokens += len(msg["content"])
|
||||
return tokens
|
||||
149
bot/zhipuai/zhipuai_bot.py
Normal file
149
bot/zhipuai/zhipuai_bot.py
Normal file
@@ -0,0 +1,149 @@
|
||||
# encoding:utf-8
|
||||
|
||||
import time
|
||||
|
||||
import openai
|
||||
import openai.error
|
||||
from bot.bot import Bot
|
||||
from bot.zhipuai.zhipu_ai_session import ZhipuAISession
|
||||
from bot.zhipuai.zhipu_ai_image import ZhipuAIImage
|
||||
from bot.session_manager import SessionManager
|
||||
from bridge.context import ContextType
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from common.log import logger
|
||||
from config import conf, load_config
|
||||
from zhipuai import ZhipuAI
|
||||
|
||||
|
||||
# ZhipuAI对话模型API
|
||||
class ZHIPUAIBot(Bot, ZhipuAIImage):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.sessions = SessionManager(ZhipuAISession, model=conf().get("model") or "ZHIPU_AI")
|
||||
self.args = {
|
||||
"model": conf().get("model") or "glm-4", # 对话模型的名称
|
||||
"temperature": conf().get("temperature", 0.9), # 值在(0,1)之间(智谱AI 的温度不能取 0 或者 1)
|
||||
"top_p": conf().get("top_p", 0.7), # 值在(0,1)之间(智谱AI 的 top_p 不能取 0 或者 1)
|
||||
}
|
||||
self.client = ZhipuAI(api_key=conf().get("zhipu_ai_api_key"))
|
||||
|
||||
def reply(self, query, context=None):
|
||||
# acquire reply content
|
||||
if context.type == ContextType.TEXT:
|
||||
logger.info("[ZHIPU_AI] query={}".format(query))
|
||||
|
||||
session_id = context["session_id"]
|
||||
reply = None
|
||||
clear_memory_commands = conf().get("clear_memory_commands", ["#清除记忆"])
|
||||
if query in clear_memory_commands:
|
||||
self.sessions.clear_session(session_id)
|
||||
reply = Reply(ReplyType.INFO, "记忆已清除")
|
||||
elif query == "#清除所有":
|
||||
self.sessions.clear_all_session()
|
||||
reply = Reply(ReplyType.INFO, "所有人记忆已清除")
|
||||
elif query == "#更新配置":
|
||||
load_config()
|
||||
reply = Reply(ReplyType.INFO, "配置已更新")
|
||||
if reply:
|
||||
return reply
|
||||
session = self.sessions.session_query(query, session_id)
|
||||
logger.debug("[ZHIPU_AI] session query={}".format(session.messages))
|
||||
|
||||
api_key = context.get("openai_api_key") or openai.api_key
|
||||
model = context.get("gpt_model")
|
||||
new_args = None
|
||||
if model:
|
||||
new_args = self.args.copy()
|
||||
new_args["model"] = model
|
||||
# if context.get('stream'):
|
||||
# # reply in stream
|
||||
# return self.reply_text_stream(query, new_query, session_id)
|
||||
|
||||
reply_content = self.reply_text(session, api_key, args=new_args)
|
||||
logger.debug(
|
||||
"[ZHIPU_AI] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(
|
||||
session.messages,
|
||||
session_id,
|
||||
reply_content["content"],
|
||||
reply_content["completion_tokens"],
|
||||
)
|
||||
)
|
||||
if reply_content["completion_tokens"] == 0 and len(reply_content["content"]) > 0:
|
||||
reply = Reply(ReplyType.ERROR, reply_content["content"])
|
||||
elif reply_content["completion_tokens"] > 0:
|
||||
self.sessions.session_reply(reply_content["content"], session_id, reply_content["total_tokens"])
|
||||
reply = Reply(ReplyType.TEXT, reply_content["content"])
|
||||
else:
|
||||
reply = Reply(ReplyType.ERROR, reply_content["content"])
|
||||
logger.debug("[ZHIPU_AI] reply {} used 0 tokens.".format(reply_content))
|
||||
return reply
|
||||
elif context.type == ContextType.IMAGE_CREATE:
|
||||
ok, retstring = self.create_img(query, 0)
|
||||
reply = None
|
||||
if ok:
|
||||
reply = Reply(ReplyType.IMAGE_URL, retstring)
|
||||
else:
|
||||
reply = Reply(ReplyType.ERROR, retstring)
|
||||
return reply
|
||||
|
||||
else:
|
||||
reply = Reply(ReplyType.ERROR, "Bot不支持处理{}类型的消息".format(context.type))
|
||||
return reply
|
||||
|
||||
def reply_text(self, session: ZhipuAISession, api_key=None, args=None, retry_count=0) -> dict:
|
||||
"""
|
||||
call openai's ChatCompletion to get the answer
|
||||
:param session: a conversation session
|
||||
:param session_id: session id
|
||||
:param retry_count: retry count
|
||||
:return: {}
|
||||
"""
|
||||
try:
|
||||
# if conf().get("rate_limit_chatgpt") and not self.tb4chatgpt.get_token():
|
||||
# raise openai.error.RateLimitError("RateLimitError: rate limit exceeded")
|
||||
# if api_key == None, the default openai.api_key will be used
|
||||
if args is None:
|
||||
args = self.args
|
||||
# response = openai.ChatCompletion.create(api_key=api_key, messages=session.messages, **args)
|
||||
response = self.client.chat.completions.create(messages=session.messages, **args)
|
||||
# logger.debug("[ZHIPU_AI] response={}".format(response))
|
||||
# logger.info("[ZHIPU_AI] reply={}, total_tokens={}".format(response.choices[0]['message']['content'], response["usage"]["total_tokens"]))
|
||||
|
||||
return {
|
||||
"total_tokens": response.usage.total_tokens,
|
||||
"completion_tokens": response.usage.completion_tokens,
|
||||
"content": response.choices[0].message.content,
|
||||
}
|
||||
except Exception as e:
|
||||
need_retry = retry_count < 2
|
||||
result = {"completion_tokens": 0, "content": "我现在有点累了,等会再来吧"}
|
||||
if isinstance(e, openai.error.RateLimitError):
|
||||
logger.warn("[ZHIPU_AI] RateLimitError: {}".format(e))
|
||||
result["content"] = "提问太快啦,请休息一下再问我吧"
|
||||
if need_retry:
|
||||
time.sleep(20)
|
||||
elif isinstance(e, openai.error.Timeout):
|
||||
logger.warn("[ZHIPU_AI] Timeout: {}".format(e))
|
||||
result["content"] = "我没有收到你的消息"
|
||||
if need_retry:
|
||||
time.sleep(5)
|
||||
elif isinstance(e, openai.error.APIError):
|
||||
logger.warn("[ZHIPU_AI] Bad Gateway: {}".format(e))
|
||||
result["content"] = "请再问我一次"
|
||||
if need_retry:
|
||||
time.sleep(10)
|
||||
elif isinstance(e, openai.error.APIConnectionError):
|
||||
logger.warn("[ZHIPU_AI] APIConnectionError: {}".format(e))
|
||||
result["content"] = "我连接不到你的网络"
|
||||
if need_retry:
|
||||
time.sleep(5)
|
||||
else:
|
||||
logger.exception("[ZHIPU_AI] Exception: {}".format(e), e)
|
||||
need_retry = False
|
||||
self.sessions.clear_session(session.session_id)
|
||||
|
||||
if need_retry:
|
||||
logger.warn("[ZHIPU_AI] 第{}次重试".format(retry_count + 1))
|
||||
return self.reply_text(session, api_key, args, retry_count + 1)
|
||||
else:
|
||||
return result
|
||||
@@ -18,32 +18,51 @@ class Bridge(object):
|
||||
"text_to_voice": conf().get("text_to_voice", "google"),
|
||||
"translate": conf().get("translate", "baidu"),
|
||||
}
|
||||
model_type = conf().get("model") or const.GPT35
|
||||
if model_type in ["text-davinci-003"]:
|
||||
self.btype["chat"] = const.OPEN_AI
|
||||
if conf().get("use_azure_chatgpt", False):
|
||||
self.btype["chat"] = const.CHATGPTONAZURE
|
||||
if model_type in ["wenxin", "wenxin-4"]:
|
||||
self.btype["chat"] = const.BAIDU
|
||||
if model_type in ["xunfei"]:
|
||||
self.btype["chat"] = const.XUNFEI
|
||||
if model_type in [const.QWEN]:
|
||||
self.btype["chat"] = const.QWEN
|
||||
if model_type in [const.GEMINI]:
|
||||
self.btype["chat"] = const.GEMINI
|
||||
# 这边取配置的模型
|
||||
bot_type = conf().get("bot_type")
|
||||
if bot_type:
|
||||
self.btype["chat"] = bot_type
|
||||
else:
|
||||
model_type = conf().get("model") or const.GPT35
|
||||
if model_type in ["text-davinci-003"]:
|
||||
self.btype["chat"] = const.OPEN_AI
|
||||
if conf().get("use_azure_chatgpt", False):
|
||||
self.btype["chat"] = const.CHATGPTONAZURE
|
||||
if model_type in ["wenxin", "wenxin-4"]:
|
||||
self.btype["chat"] = const.BAIDU
|
||||
if model_type in ["xunfei"]:
|
||||
self.btype["chat"] = const.XUNFEI
|
||||
if model_type in [const.QWEN]:
|
||||
self.btype["chat"] = const.QWEN
|
||||
if model_type in [const.QWEN_TURBO, const.QWEN_PLUS, const.QWEN_MAX]:
|
||||
self.btype["chat"] = const.QWEN_DASHSCOPE
|
||||
if model_type and model_type.startswith("gemini"):
|
||||
self.btype["chat"] = const.GEMINI
|
||||
if model_type and model_type.startswith("glm"):
|
||||
self.btype["chat"] = const.ZHIPU_AI
|
||||
if model_type and model_type.startswith("claude-3"):
|
||||
self.btype["chat"] = const.CLAUDEAPI
|
||||
|
||||
if conf().get("use_linkai") and conf().get("linkai_api_key"):
|
||||
self.btype["chat"] = const.LINKAI
|
||||
if not conf().get("voice_to_text") or conf().get("voice_to_text") in ["openai"]:
|
||||
self.btype["voice_to_text"] = const.LINKAI
|
||||
if not conf().get("text_to_voice") or conf().get("text_to_voice") in ["openai", const.TTS_1, const.TTS_1_HD]:
|
||||
self.btype["text_to_voice"] = const.LINKAI
|
||||
if model_type in ["claude"]:
|
||||
self.btype["chat"] = const.CLAUDEAI
|
||||
|
||||
if model_type in [const.MOONSHOT, "moonshot-v1-8k", "moonshot-v1-32k", "moonshot-v1-128k"]:
|
||||
self.btype["chat"] = const.MOONSHOT
|
||||
|
||||
if model_type in ["abab6.5-chat"]:
|
||||
self.btype["chat"] = const.MiniMax
|
||||
|
||||
if conf().get("use_linkai") and conf().get("linkai_api_key"):
|
||||
self.btype["chat"] = const.LINKAI
|
||||
if not conf().get("voice_to_text") or conf().get("voice_to_text") in ["openai"]:
|
||||
self.btype["voice_to_text"] = const.LINKAI
|
||||
if not conf().get("text_to_voice") or conf().get("text_to_voice") in ["openai", const.TTS_1, const.TTS_1_HD]:
|
||||
self.btype["text_to_voice"] = const.LINKAI
|
||||
|
||||
if model_type in ["claude"]:
|
||||
self.btype["chat"] = const.CLAUDEAI
|
||||
self.bots = {}
|
||||
self.chat_bots = {}
|
||||
|
||||
# 模型对应的接口
|
||||
def get_bot(self, typename):
|
||||
if self.bots.get(typename) is None:
|
||||
logger.info("create bot {} for {}".format(self.btype[typename], typename))
|
||||
|
||||
@@ -11,7 +11,7 @@ class ReplyType(Enum):
|
||||
VIDEO_URL = 5 # 视频URL
|
||||
FILE = 6 # 文件
|
||||
CARD = 7 # 微信名片,仅支持ntchat
|
||||
InviteRoom = 8 # 邀请好友进群
|
||||
INVITE_ROOM = 8 # 邀请好友进群
|
||||
INFO = 9
|
||||
ERROR = 10
|
||||
TEXT_ = 11 # 强制文本
|
||||
|
||||
@@ -21,6 +21,9 @@ def create_channel(channel_type) -> Channel:
|
||||
elif channel_type == "terminal":
|
||||
from channel.terminal.terminal_channel import TerminalChannel
|
||||
ch = TerminalChannel()
|
||||
elif channel_type == 'web':
|
||||
from channel.web.web_channel import WebChannel
|
||||
ch = WebChannel()
|
||||
elif channel_type == "wechatmp":
|
||||
from channel.wechatmp.wechatmp_channel import WechatMPChannel
|
||||
ch = WechatMPChannel(passive_reply=True)
|
||||
|
||||
@@ -17,6 +17,8 @@ try:
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
handler_pool = ThreadPoolExecutor(max_workers=8) # 处理消息的线程池
|
||||
|
||||
|
||||
# 抽象类, 它包含了与消息通道无关的通用处理逻辑
|
||||
class ChatChannel(Channel):
|
||||
@@ -25,7 +27,6 @@ class ChatChannel(Channel):
|
||||
futures = {} # 记录每个session_id提交到线程池的future对象, 用于重置会话时把没执行的future取消掉,正在执行的不会被取消
|
||||
sessions = {} # 用于控制并发,每个session_id同时只能有一个context在处理
|
||||
lock = threading.Lock() # 用于控制对sessions的访问
|
||||
handler_pool = ThreadPoolExecutor(max_workers=8) # 处理消息的线程池
|
||||
|
||||
def __init__(self):
|
||||
_thread = threading.Thread(target=self.consume)
|
||||
@@ -73,6 +74,7 @@ class ChatChannel(Channel):
|
||||
):
|
||||
session_id = group_id
|
||||
else:
|
||||
logger.debug(f"No need reply, groupName not in whitelist, group_name={group_name}")
|
||||
return None
|
||||
context["session_id"] = session_id
|
||||
context["receiver"] = group_id
|
||||
@@ -84,14 +86,14 @@ class ChatChannel(Channel):
|
||||
if e_context.is_pass() or context is None:
|
||||
return context
|
||||
if cmsg.from_user_id == self.user_id and not config.get("trigger_by_self", True):
|
||||
logger.debug("[WX]self message skipped")
|
||||
logger.debug("[chat_channel]self message skipped")
|
||||
return None
|
||||
|
||||
# 消息内容匹配过程,并处理content
|
||||
if ctype == ContextType.TEXT:
|
||||
if first_in and "」\n- - - - - - -" in content: # 初次匹配 过滤引用消息
|
||||
logger.debug(content)
|
||||
logger.debug("[WX]reference query skipped")
|
||||
logger.debug("[chat_channel]reference query skipped")
|
||||
return None
|
||||
|
||||
nick_name_black_list = conf().get("nick_name_black_list", [])
|
||||
@@ -109,12 +111,13 @@ class ChatChannel(Channel):
|
||||
nick_name = context["msg"].actual_user_nickname
|
||||
if nick_name and nick_name in nick_name_black_list:
|
||||
# 黑名单过滤
|
||||
logger.warning(f"[WX] Nickname {nick_name} in In BlackList, ignore")
|
||||
logger.warning(f"[chat_channel] Nickname {nick_name} in In BlackList, ignore")
|
||||
return None
|
||||
|
||||
logger.info("[WX]receive group at")
|
||||
logger.info("[chat_channel]receive group at")
|
||||
if not conf().get("group_at_off", False):
|
||||
flag = True
|
||||
self.name = self.name if self.name is not None else "" # 部分渠道self.name可能没有赋值
|
||||
pattern = f"@{re.escape(self.name)}(\u2005|\u0020)"
|
||||
subtract_res = re.sub(pattern, r"", content)
|
||||
if isinstance(context["msg"].at_list, list):
|
||||
@@ -128,13 +131,13 @@ class ChatChannel(Channel):
|
||||
content = subtract_res
|
||||
if not flag:
|
||||
if context["origin_ctype"] == ContextType.VOICE:
|
||||
logger.info("[WX]receive group voice, but checkprefix didn't match")
|
||||
logger.info("[chat_channel]receive group voice, but checkprefix didn't match")
|
||||
return None
|
||||
else: # 单聊
|
||||
nick_name = context["msg"].from_user_nickname
|
||||
if nick_name and nick_name in nick_name_black_list:
|
||||
# 黑名单过滤
|
||||
logger.warning(f"[WX] Nickname '{nick_name}' in In BlackList, ignore")
|
||||
logger.warning(f"[chat_channel] Nickname '{nick_name}' in In BlackList, ignore")
|
||||
return None
|
||||
|
||||
match_prefix = check_prefix(content, conf().get("single_chat_prefix", [""]))
|
||||
@@ -145,7 +148,7 @@ class ChatChannel(Channel):
|
||||
else:
|
||||
return None
|
||||
content = content.strip()
|
||||
img_match_prefix = check_prefix(content, conf().get("image_create_prefix"))
|
||||
img_match_prefix = check_prefix(content, conf().get("image_create_prefix",[""]))
|
||||
if img_match_prefix:
|
||||
content = content.replace(img_match_prefix, "", 1)
|
||||
context.type = ContextType.IMAGE_CREATE
|
||||
@@ -157,22 +160,23 @@ class ChatChannel(Channel):
|
||||
elif context.type == ContextType.VOICE:
|
||||
if "desire_rtype" not in context and conf().get("voice_reply_voice") and ReplyType.VOICE not in self.NOT_SUPPORT_REPLYTYPE:
|
||||
context["desire_rtype"] = ReplyType.VOICE
|
||||
|
||||
return context
|
||||
|
||||
def _handle(self, context: Context):
|
||||
if context is None or not context.content:
|
||||
return
|
||||
logger.debug("[WX] ready to handle context: {}".format(context))
|
||||
logger.debug("[chat_channel] ready to handle context: {}".format(context))
|
||||
# reply的构建步骤
|
||||
reply = self._generate_reply(context)
|
||||
|
||||
logger.debug("[WX] ready to decorate reply: {}".format(reply))
|
||||
# reply的包装步骤
|
||||
reply = self._decorate_reply(context, reply)
|
||||
logger.debug("[chat_channel] ready to decorate reply: {}".format(reply))
|
||||
|
||||
# reply的发送步骤
|
||||
self._send_reply(context, reply)
|
||||
# reply的包装步骤
|
||||
if reply and reply.content:
|
||||
reply = self._decorate_reply(context, reply)
|
||||
|
||||
# reply的发送步骤
|
||||
self._send_reply(context, reply)
|
||||
|
||||
def _generate_reply(self, context: Context, reply: Reply = Reply()) -> Reply:
|
||||
e_context = PluginManager().emit_event(
|
||||
@@ -183,7 +187,7 @@ class ChatChannel(Channel):
|
||||
)
|
||||
reply = e_context["reply"]
|
||||
if not e_context.is_pass():
|
||||
logger.debug("[WX] ready to handle context: type={}, content={}".format(context.type, context.content))
|
||||
logger.debug("[chat_channel] ready to handle context: type={}, content={}".format(context.type, context.content))
|
||||
if context.type == ContextType.TEXT or context.type == ContextType.IMAGE_CREATE: # 文字和图片消息
|
||||
context["channel"] = e_context["channel"]
|
||||
reply = super().build_reply_content(context.content, context)
|
||||
@@ -195,7 +199,7 @@ class ChatChannel(Channel):
|
||||
try:
|
||||
any_to_wav(file_path, wav_path)
|
||||
except Exception as e: # 转换失败,直接使用mp3,对于某些api,mp3也可以识别
|
||||
logger.warning("[WX]any to wav error, use raw path. " + str(e))
|
||||
logger.warning("[chat_channel]any to wav error, use raw path. " + str(e))
|
||||
wav_path = file_path
|
||||
# 语音识别
|
||||
reply = super().build_voice_to_text(wav_path)
|
||||
@@ -206,7 +210,7 @@ class ChatChannel(Channel):
|
||||
os.remove(wav_path)
|
||||
except Exception as e:
|
||||
pass
|
||||
# logger.warning("[WX]delete temp file error: " + str(e))
|
||||
# logger.warning("[chat_channel]delete temp file error: " + str(e))
|
||||
|
||||
if reply.type == ReplyType.TEXT:
|
||||
new_context = self._compose_context(ContextType.TEXT, reply.content, **context.kwargs)
|
||||
@@ -224,7 +228,7 @@ class ChatChannel(Channel):
|
||||
elif context.type == ContextType.FUNCTION or context.type == ContextType.FILE: # 文件消息及函数调用等,当前无默认逻辑
|
||||
pass
|
||||
else:
|
||||
logger.warning("[WX] unknown context type: {}".format(context.type))
|
||||
logger.warning("[chat_channel] unknown context type: {}".format(context.type))
|
||||
return
|
||||
return reply
|
||||
|
||||
@@ -240,7 +244,7 @@ class ChatChannel(Channel):
|
||||
desire_rtype = context.get("desire_rtype")
|
||||
if not e_context.is_pass() and reply and reply.type:
|
||||
if reply.type in self.NOT_SUPPORT_REPLYTYPE:
|
||||
logger.error("[WX]reply type not support: " + str(reply.type))
|
||||
logger.error("[chat_channel]reply type not support: " + str(reply.type))
|
||||
reply.type = ReplyType.ERROR
|
||||
reply.content = "不支持发送的消息类型: " + str(reply.type)
|
||||
|
||||
@@ -261,10 +265,10 @@ class ChatChannel(Channel):
|
||||
elif reply.type == ReplyType.IMAGE_URL or reply.type == ReplyType.VOICE or reply.type == ReplyType.IMAGE or reply.type == ReplyType.FILE or reply.type == ReplyType.VIDEO or reply.type == ReplyType.VIDEO_URL:
|
||||
pass
|
||||
else:
|
||||
logger.error("[WX] unknown reply type: {}".format(reply.type))
|
||||
logger.error("[chat_channel] unknown reply type: {}".format(reply.type))
|
||||
return
|
||||
if desire_rtype and desire_rtype != reply.type and reply.type not in [ReplyType.ERROR, ReplyType.INFO]:
|
||||
logger.warning("[WX] desire_rtype: {}, but reply type: {}".format(context.get("desire_rtype"), reply.type))
|
||||
logger.warning("[chat_channel] desire_rtype: {}, but reply type: {}".format(context.get("desire_rtype"), reply.type))
|
||||
return reply
|
||||
|
||||
def _send_reply(self, context: Context, reply: Reply):
|
||||
@@ -277,14 +281,14 @@ class ChatChannel(Channel):
|
||||
)
|
||||
reply = e_context["reply"]
|
||||
if not e_context.is_pass() and reply and reply.type:
|
||||
logger.debug("[WX] ready to send reply: {}, context: {}".format(reply, context))
|
||||
logger.debug("[chat_channel] ready to send reply: {}, context: {}".format(reply, context))
|
||||
self._send(reply, context)
|
||||
|
||||
def _send(self, reply: Reply, context: Context, retry_cnt=0):
|
||||
try:
|
||||
self.send(reply, context)
|
||||
except Exception as e:
|
||||
logger.error("[WX] sendMsg error: {}".format(str(e)))
|
||||
logger.error("[chat_channel] sendMsg error: {}".format(str(e)))
|
||||
if isinstance(e, NotImplementedError):
|
||||
return
|
||||
logger.exception(e)
|
||||
@@ -333,24 +337,27 @@ class ChatChannel(Channel):
|
||||
while True:
|
||||
with self.lock:
|
||||
session_ids = list(self.sessions.keys())
|
||||
for session_id in session_ids:
|
||||
for session_id in session_ids:
|
||||
with self.lock:
|
||||
context_queue, semaphore = self.sessions[session_id]
|
||||
if semaphore.acquire(blocking=False): # 等线程处理完毕才能删除
|
||||
if not context_queue.empty():
|
||||
context = context_queue.get()
|
||||
logger.debug("[WX] consume context: {}".format(context))
|
||||
future: Future = self.handler_pool.submit(self._handle, context)
|
||||
future.add_done_callback(self._thread_pool_callback(session_id, context=context))
|
||||
if semaphore.acquire(blocking=False): # 等线程处理完毕才能删除
|
||||
if not context_queue.empty():
|
||||
context = context_queue.get()
|
||||
logger.debug("[chat_channel] consume context: {}".format(context))
|
||||
future: Future = handler_pool.submit(self._handle, context)
|
||||
future.add_done_callback(self._thread_pool_callback(session_id, context=context))
|
||||
with self.lock:
|
||||
if session_id not in self.futures:
|
||||
self.futures[session_id] = []
|
||||
self.futures[session_id].append(future)
|
||||
elif semaphore._initial_value == semaphore._value + 1: # 除了当前,没有任务再申请到信号量,说明所有任务都处理完毕
|
||||
elif semaphore._initial_value == semaphore._value + 1: # 除了当前,没有任务再申请到信号量,说明所有任务都处理完毕
|
||||
with self.lock:
|
||||
self.futures[session_id] = [t for t in self.futures[session_id] if not t.done()]
|
||||
assert len(self.futures[session_id]) == 0, "thread pool error"
|
||||
del self.sessions[session_id]
|
||||
else:
|
||||
semaphore.release()
|
||||
time.sleep(0.1)
|
||||
else:
|
||||
semaphore.release()
|
||||
time.sleep(0.2)
|
||||
|
||||
# 取消session_id对应的所有任务,只能取消排队的消息和已提交线程池但未执行的任务
|
||||
def cancel_session(self, session_id):
|
||||
|
||||
@@ -4,20 +4,81 @@
|
||||
@author huiwen
|
||||
@Date 2023/11/28
|
||||
"""
|
||||
|
||||
import copy
|
||||
import json
|
||||
# -*- coding=utf-8 -*-
|
||||
import logging
|
||||
import time
|
||||
|
||||
import dingtalk_stream
|
||||
from dingtalk_stream import AckMessage
|
||||
from dingtalk_stream.card_replier import AICardReplier
|
||||
from dingtalk_stream.card_replier import AICardStatus
|
||||
from dingtalk_stream.card_replier import CardReplier
|
||||
|
||||
from bridge.context import Context, ContextType
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from channel.chat_channel import ChatChannel
|
||||
from channel.dingtalk.dingtalk_message import DingTalkMessage
|
||||
from bridge.context import Context
|
||||
from bridge.reply import Reply
|
||||
from common.expired_dict import ExpiredDict
|
||||
from common.log import logger
|
||||
from common.singleton import singleton
|
||||
from common.time_check import time_checker
|
||||
from config import conf
|
||||
from common.expired_dict import ExpiredDict
|
||||
from bridge.context import ContextType
|
||||
from channel.chat_channel import ChatChannel
|
||||
import logging
|
||||
from dingtalk_stream import AckMessage
|
||||
import dingtalk_stream
|
||||
|
||||
|
||||
class CustomAICardReplier(CardReplier):
|
||||
def __init__(self, dingtalk_client, incoming_message):
|
||||
super(AICardReplier, self).__init__(dingtalk_client, incoming_message)
|
||||
|
||||
def start(
|
||||
self,
|
||||
card_template_id: str,
|
||||
card_data: dict,
|
||||
recipients: list = None,
|
||||
support_forward: bool = True,
|
||||
) -> str:
|
||||
"""
|
||||
AI卡片的创建接口
|
||||
:param support_forward:
|
||||
:param recipients:
|
||||
:param card_template_id:
|
||||
:param card_data:
|
||||
:return:
|
||||
"""
|
||||
card_data_with_status = copy.deepcopy(card_data)
|
||||
card_data_with_status["flowStatus"] = AICardStatus.PROCESSING
|
||||
return self.create_and_send_card(
|
||||
card_template_id,
|
||||
card_data_with_status,
|
||||
at_sender=True,
|
||||
at_all=False,
|
||||
recipients=recipients,
|
||||
support_forward=support_forward,
|
||||
)
|
||||
|
||||
|
||||
# 对 AICardReplier 进行猴子补丁
|
||||
AICardReplier.start = CustomAICardReplier.start
|
||||
|
||||
|
||||
def _check(func):
|
||||
def wrapper(self, cmsg: DingTalkMessage):
|
||||
msgId = cmsg.msg_id
|
||||
if msgId in self.receivedMsgs:
|
||||
logger.info("DingTalk message {} already received, ignore".format(msgId))
|
||||
return
|
||||
self.receivedMsgs[msgId] = True
|
||||
create_time = cmsg.create_time # 消息时间戳
|
||||
if conf().get("hot_reload") == True and int(create_time) < int(time.time()) - 60: # 跳过1分钟前的历史消息
|
||||
logger.debug("[DingTalk] History message {} skipped".format(msgId))
|
||||
return
|
||||
if cmsg.my_msg and not cmsg.is_group:
|
||||
logger.debug("[DingTalk] My message {} skipped".format(msgId))
|
||||
return
|
||||
return func(self, cmsg)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
@singleton
|
||||
@@ -39,11 +100,13 @@ class DingTalkChanel(ChatChannel, dingtalk_stream.ChatbotHandler):
|
||||
super(dingtalk_stream.ChatbotHandler, self).__init__()
|
||||
self.logger = self.setup_logger()
|
||||
# 历史消息id暂存,用于幂等控制
|
||||
self.receivedMsgs = ExpiredDict(60 * 60 * 7.1)
|
||||
logger.info("[dingtalk] client_id={}, client_secret={} ".format(
|
||||
self.receivedMsgs = ExpiredDict(conf().get("expires_in_seconds", 3600))
|
||||
logger.info("[DingTalk] client_id={}, client_secret={} ".format(
|
||||
self.dingtalk_client_id, self.dingtalk_client_secret))
|
||||
# 无需群校验和前缀
|
||||
conf()["group_name_white_list"] = ["ALL_GROUP"]
|
||||
# 单聊无需前缀
|
||||
conf()["single_chat_prefix"] = [""]
|
||||
|
||||
def startup(self):
|
||||
credential = dingtalk_stream.Credential(self.dingtalk_client_id, self.dingtalk_client_secret)
|
||||
@@ -51,50 +114,112 @@ class DingTalkChanel(ChatChannel, dingtalk_stream.ChatbotHandler):
|
||||
client.register_callback_handler(dingtalk_stream.chatbot.ChatbotMessage.TOPIC, self)
|
||||
client.start_forever()
|
||||
|
||||
async def process(self, callback: dingtalk_stream.CallbackMessage):
|
||||
try:
|
||||
incoming_message = dingtalk_stream.ChatbotMessage.from_dict(callback.data)
|
||||
image_download_handler = self # 传入方法所在的类实例
|
||||
dingtalk_msg = DingTalkMessage(incoming_message, image_download_handler)
|
||||
|
||||
if dingtalk_msg.is_group:
|
||||
self.handle_group(dingtalk_msg)
|
||||
else:
|
||||
self.handle_single(dingtalk_msg)
|
||||
return AckMessage.STATUS_OK, 'OK'
|
||||
except Exception as e:
|
||||
logger.error(f"dingtalk process error={e}")
|
||||
return AckMessage.STATUS_SYSTEM_EXCEPTION, 'ERROR'
|
||||
|
||||
@time_checker
|
||||
@_check
|
||||
def handle_single(self, cmsg: DingTalkMessage):
|
||||
# 处理单聊消息
|
||||
if cmsg.ctype == ContextType.VOICE:
|
||||
logger.debug("[dingtalk]receive voice msg: {}".format(cmsg.content))
|
||||
logger.debug("[DingTalk]receive voice msg: {}".format(cmsg.content))
|
||||
elif cmsg.ctype == ContextType.IMAGE:
|
||||
logger.debug("[dingtalk]receive image msg: {}".format(cmsg.content))
|
||||
logger.debug("[DingTalk]receive image msg: {}".format(cmsg.content))
|
||||
elif cmsg.ctype == ContextType.IMAGE_CREATE:
|
||||
logger.debug("[DingTalk]receive image create msg: {}".format(cmsg.content))
|
||||
elif cmsg.ctype == ContextType.PATPAT:
|
||||
logger.debug("[dingtalk]receive patpat msg: {}".format(cmsg.content))
|
||||
logger.debug("[DingTalk]receive patpat msg: {}".format(cmsg.content))
|
||||
elif cmsg.ctype == ContextType.TEXT:
|
||||
expression = cmsg.my_msg
|
||||
cmsg.content = conf()["single_chat_prefix"][0] + cmsg.content
|
||||
logger.debug("[DingTalk]receive text msg: {}".format(cmsg.content))
|
||||
else:
|
||||
logger.debug("[DingTalk]receive other msg: {}".format(cmsg.content))
|
||||
context = self._compose_context(cmsg.ctype, cmsg.content, isgroup=False, msg=cmsg)
|
||||
if context:
|
||||
self.produce(context)
|
||||
|
||||
|
||||
@time_checker
|
||||
@_check
|
||||
def handle_group(self, cmsg: DingTalkMessage):
|
||||
# 处理群聊消息
|
||||
if cmsg.ctype == ContextType.VOICE:
|
||||
logger.debug("[dingtalk]receive voice msg: {}".format(cmsg.content))
|
||||
logger.debug("[DingTalk]receive voice msg: {}".format(cmsg.content))
|
||||
elif cmsg.ctype == ContextType.IMAGE:
|
||||
logger.debug("[dingtalk]receive image msg: {}".format(cmsg.content))
|
||||
logger.debug("[DingTalk]receive image msg: {}".format(cmsg.content))
|
||||
elif cmsg.ctype == ContextType.IMAGE_CREATE:
|
||||
logger.debug("[DingTalk]receive image create msg: {}".format(cmsg.content))
|
||||
elif cmsg.ctype == ContextType.PATPAT:
|
||||
logger.debug("[dingtalk]receive patpat msg: {}".format(cmsg.content))
|
||||
logger.debug("[DingTalk]receive patpat msg: {}".format(cmsg.content))
|
||||
elif cmsg.ctype == ContextType.TEXT:
|
||||
expression = cmsg.my_msg
|
||||
cmsg.content = conf()["group_chat_prefix"][0] + cmsg.content
|
||||
logger.debug("[DingTalk]receive text msg: {}".format(cmsg.content))
|
||||
else:
|
||||
logger.debug("[DingTalk]receive other msg: {}".format(cmsg.content))
|
||||
context = self._compose_context(cmsg.ctype, cmsg.content, isgroup=True, msg=cmsg)
|
||||
context['no_need_at'] = True
|
||||
if context:
|
||||
self.produce(context)
|
||||
|
||||
async def process(self, callback: dingtalk_stream.CallbackMessage):
|
||||
try:
|
||||
incoming_message = dingtalk_stream.ChatbotMessage.from_dict(callback.data)
|
||||
dingtalk_msg = DingTalkMessage(incoming_message)
|
||||
if incoming_message.conversation_type == '1':
|
||||
self.handle_single(dingtalk_msg)
|
||||
else:
|
||||
self.handle_group(dingtalk_msg)
|
||||
return AckMessage.STATUS_OK, 'OK'
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
return self.FAILED_MSG
|
||||
|
||||
def send(self, reply: Reply, context: Context):
|
||||
receiver = context["receiver"]
|
||||
isgroup = context.kwargs['msg'].is_group
|
||||
incoming_message = context.kwargs['msg'].incoming_message
|
||||
self.reply_text(reply.content, incoming_message)
|
||||
|
||||
if conf().get("dingtalk_card_enabled"):
|
||||
logger.info("[Dingtalk] sendMsg={}, receiver={}".format(reply, receiver))
|
||||
def reply_with_text():
|
||||
self.reply_text(reply.content, incoming_message)
|
||||
def reply_with_at_text():
|
||||
self.reply_text("📢 您有一条新的消息,请查看。", incoming_message)
|
||||
def reply_with_ai_markdown():
|
||||
button_list, markdown_content = self.generate_button_markdown_content(context, reply)
|
||||
self.reply_ai_markdown_button(incoming_message, markdown_content, button_list, "", "📌 内容由AI生成", "",[incoming_message.sender_staff_id])
|
||||
|
||||
if reply.type in [ReplyType.IMAGE_URL, ReplyType.IMAGE, ReplyType.TEXT]:
|
||||
if isgroup:
|
||||
reply_with_ai_markdown()
|
||||
reply_with_at_text()
|
||||
else:
|
||||
reply_with_ai_markdown()
|
||||
else:
|
||||
# 暂不支持其它类型消息回复
|
||||
reply_with_text()
|
||||
else:
|
||||
self.reply_text(reply.content, incoming_message)
|
||||
|
||||
|
||||
def generate_button_markdown_content(self, context, reply):
|
||||
image_url = context.kwargs.get("image_url")
|
||||
promptEn = context.kwargs.get("promptEn")
|
||||
reply_text = reply.content
|
||||
button_list = []
|
||||
markdown_content = f"""
|
||||
{reply.content}
|
||||
"""
|
||||
if image_url is not None and promptEn is not None:
|
||||
button_list = [
|
||||
{"text": "查看原图", "url": image_url, "iosUrl": image_url, "color": "blue"}
|
||||
]
|
||||
markdown_content = f"""
|
||||
{promptEn}
|
||||
|
||||

|
||||
|
||||
{reply_text}
|
||||
|
||||
"""
|
||||
logger.debug(f"[Dingtalk] generate_button_markdown_content, button_list={button_list} , markdown_content={markdown_content}")
|
||||
|
||||
return button_list, markdown_content
|
||||
|
||||
@@ -1,44 +1,84 @@
|
||||
from bridge.context import ContextType
|
||||
from channel.chat_message import ChatMessage
|
||||
import json
|
||||
import os
|
||||
|
||||
import requests
|
||||
from common.log import logger
|
||||
from common.tmp_dir import TmpDir
|
||||
from common import utils
|
||||
from dingtalk_stream import ChatbotMessage
|
||||
|
||||
from bridge.context import ContextType
|
||||
from channel.chat_message import ChatMessage
|
||||
# -*- coding=utf-8 -*-
|
||||
from common.log import logger
|
||||
from common.tmp_dir import TmpDir
|
||||
|
||||
|
||||
class DingTalkMessage(ChatMessage):
|
||||
def __init__(self, event: ChatbotMessage):
|
||||
def __init__(self, event: ChatbotMessage, image_download_handler):
|
||||
super().__init__(event)
|
||||
|
||||
self.image_download_handler = image_download_handler
|
||||
self.msg_id = event.message_id
|
||||
msg_type = event.message_type
|
||||
self.incoming_message =event
|
||||
self.message_type = event.message_type
|
||||
self.incoming_message = event
|
||||
self.sender_staff_id = event.sender_staff_id
|
||||
self.other_user_id = event.conversation_id
|
||||
self.create_time = event.create_at
|
||||
if event.conversation_type=="1":
|
||||
self.image_content = event.image_content
|
||||
self.rich_text_content = event.rich_text_content
|
||||
if event.conversation_type == "1":
|
||||
self.is_group = False
|
||||
else:
|
||||
self.is_group = True
|
||||
|
||||
|
||||
if msg_type == "text":
|
||||
if self.message_type == "text":
|
||||
self.ctype = ContextType.TEXT
|
||||
|
||||
|
||||
self.content = event.text.content.strip()
|
||||
elif msg_type == "audio":
|
||||
|
||||
elif self.message_type == "audio":
|
||||
# 钉钉支持直接识别语音,所以此处将直接提取文字,当文字处理
|
||||
self.content = event.extensions['content']['recognition'].strip()
|
||||
self.ctype = ContextType.TEXT
|
||||
self.from_user_id = event.sender_id
|
||||
elif (self.message_type == 'picture') or (self.message_type == 'richText'):
|
||||
self.ctype = ContextType.IMAGE
|
||||
# 钉钉图片类型或富文本类型消息处理
|
||||
image_list = event.get_image_list()
|
||||
if len(image_list) > 0:
|
||||
download_code = image_list[0]
|
||||
download_url = image_download_handler.get_image_download_url(download_code)
|
||||
self.content = download_image_file(download_url, TmpDir().path())
|
||||
else:
|
||||
logger.debug(f"[Dingtalk] messageType :{self.message_type} , imageList isEmpty")
|
||||
|
||||
if self.is_group:
|
||||
self.from_user_id = event.conversation_id
|
||||
self.actual_user_id = event.sender_id
|
||||
self.is_at = True
|
||||
else:
|
||||
self.from_user_id = event.sender_id
|
||||
self.actual_user_id = event.sender_id
|
||||
self.to_user_id = event.chatbot_user_id
|
||||
self.other_user_nickname = event.conversation_title
|
||||
|
||||
user_id = event.sender_id
|
||||
nickname =event.sender_nick
|
||||
|
||||
|
||||
|
||||
|
||||
def download_image_file(image_url, temp_dir):
|
||||
headers = {
|
||||
'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/112.0.0.0 Safari/537.36'
|
||||
}
|
||||
# 设置代理
|
||||
# self.proxies
|
||||
# , proxies=self.proxies
|
||||
response = requests.get(image_url, headers=headers, stream=True, timeout=60 * 5)
|
||||
if response.status_code == 200:
|
||||
|
||||
# 生成文件名
|
||||
file_name = image_url.split("/")[-1].split("?")[0]
|
||||
|
||||
# 检查临时目录是否存在,如果不存在则创建
|
||||
if not os.path.exists(temp_dir):
|
||||
os.makedirs(temp_dir)
|
||||
|
||||
# 将文件保存到临时目录
|
||||
file_path = os.path.join(temp_dir, file_name)
|
||||
with open(file_path, 'wb') as file:
|
||||
file.write(response.content)
|
||||
return file_path
|
||||
else:
|
||||
logger.info(f"[Dingtalk] Failed to download image file, {response.content}")
|
||||
return None
|
||||
|
||||
@@ -40,7 +40,7 @@ class FeiShuChanel(ChatChannel):
|
||||
self.feishu_app_id, self.feishu_app_secret, self.feishu_token))
|
||||
# 无需群校验和前缀
|
||||
conf()["group_name_white_list"] = ["ALL_GROUP"]
|
||||
conf()["single_chat_prefix"] = []
|
||||
conf()["single_chat_prefix"] = [""]
|
||||
|
||||
def startup(self):
|
||||
urls = (
|
||||
|
||||
@@ -78,6 +78,7 @@ class TerminalChannel(ChatChannel):
|
||||
prompt = trigger_prefixs[0] + prompt # 给没触发的消息加上触发前缀
|
||||
|
||||
context = self._compose_context(ContextType.TEXT, prompt, msg=TerminalMessage(msg_id, prompt))
|
||||
context["isgroup"] = False
|
||||
if context:
|
||||
self.produce(context)
|
||||
else:
|
||||
|
||||
7
channel/web/README.md
Normal file
7
channel/web/README.md
Normal file
@@ -0,0 +1,7 @@
|
||||
# Web channel
|
||||
使用SSE(Server-Sent Events,服务器推送事件)实现,提供了一个默认的网页。也可以自己实现加入api
|
||||
|
||||
#使用方法
|
||||
- 在配置文件中channel_type填入web即可
|
||||
- 访问地址 http://localhost:9899
|
||||
- port可以在配置项 web_port中设置
|
||||
165
channel/web/chat.html
Normal file
165
channel/web/chat.html
Normal file
@@ -0,0 +1,165 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="zh">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>Chat</title>
|
||||
<style>
|
||||
body {
|
||||
font-family: Arial, sans-serif;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
height: 100vh; /* 占据所有高度 */
|
||||
margin: 0;
|
||||
/* background-color: #f8f9fa; */
|
||||
}
|
||||
#chat-container {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
width: 100%;
|
||||
max-width: 500px;
|
||||
margin: auto;
|
||||
border: 1px solid #ccc;
|
||||
border-radius: 5px;
|
||||
overflow: hidden;
|
||||
box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);
|
||||
flex: 1; /* 使聊天容器占据剩余空间 */
|
||||
}
|
||||
#messages {
|
||||
flex-direction: column;
|
||||
display: flex;
|
||||
flex: 1;
|
||||
overflow-y: auto;
|
||||
padding: 10px;
|
||||
overflow-y: auto;
|
||||
border-bottom: 1px solid #ccc;
|
||||
background-color: #ffffff;
|
||||
}
|
||||
|
||||
.message {
|
||||
margin: 5px 0; /* 间隔 */
|
||||
padding: 10px 15px; /* 内边距 */
|
||||
border-radius: 15px; /* 圆角 */
|
||||
max-width: 80%; /* 限制最大宽度 */
|
||||
min-width: 80px; /* 设置最小宽度 */
|
||||
min-height: 40px; /* 设置最小高度 */
|
||||
word-wrap: break-word; /* 自动换行 */
|
||||
position: relative; /* 时间戳定位 */
|
||||
display: inline-block; /* 内容自适应宽度 */
|
||||
box-sizing: border-box; /* 包括内边距和边框 */
|
||||
flex-shrink: 0; /* 禁止高度被压缩 */
|
||||
word-wrap: break-word; /* 自动换行,防止单行过长 */
|
||||
white-space: normal; /* 允许正常换行 */
|
||||
overflow: hidden;
|
||||
}
|
||||
|
||||
.bot {
|
||||
background-color: #f1f1f1; /* 灰色背景 */
|
||||
color: black; /* 黑色字体 */
|
||||
align-self: flex-start; /* 左对齐 */
|
||||
margin-right: auto; /* 确保消息靠左 */
|
||||
text-align: left; /* 内容左对齐 */
|
||||
}
|
||||
|
||||
.user {
|
||||
background-color: #2bc840; /* 蓝色背景 */
|
||||
align-self: flex-end; /* 右对齐 */
|
||||
margin-left: auto; /* 确保消息靠右 */
|
||||
text-align: left; /* 内容左对齐 */
|
||||
}
|
||||
.timestamp {
|
||||
font-size: 0.8em; /* 时间戳字体大小 */
|
||||
color: rgba(0, 0, 0, 0.5); /* 半透明黑色 */
|
||||
margin-bottom: 5px; /* 时间戳下方间距 */
|
||||
display: block; /* 时间戳独占一行 */
|
||||
}
|
||||
#input-container {
|
||||
display: flex;
|
||||
padding: 10px;
|
||||
background-color: #ffffff;
|
||||
border-top: 1px solid #ccc;
|
||||
}
|
||||
#input {
|
||||
flex: 1;
|
||||
padding: 10px;
|
||||
border: 1px solid #ccc;
|
||||
border-radius: 5px;
|
||||
margin-right: 10px;
|
||||
}
|
||||
#send {
|
||||
padding: 10px;
|
||||
border: none;
|
||||
background-color: #007bff;
|
||||
color: white;
|
||||
border-radius: 5px;
|
||||
cursor: pointer;
|
||||
}
|
||||
#send:hover {
|
||||
background-color: #0056b3;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div id="chat-container">
|
||||
<div id="messages"></div>
|
||||
<div id="input-container">
|
||||
<input type="text" id="input" placeholder="输入消息..." />
|
||||
<button id="send">发送</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<script>
|
||||
const messagesDiv = document.getElementById('messages');
|
||||
const input = document.getElementById('input');
|
||||
const sendButton = document.getElementById('send');
|
||||
|
||||
// 生成唯一的 user_id
|
||||
const userId = 'user_' + Math.random().toString(36).substr(2, 9);
|
||||
|
||||
// 连接 SSE
|
||||
const eventSource = new EventSource(`/sse/${userId}`);
|
||||
|
||||
eventSource.onmessage = function(event) {
|
||||
const message = JSON.parse(event.data);
|
||||
const messageDiv = document.createElement('div');
|
||||
messageDiv.className = 'message bot';
|
||||
const timestamp = new Date(message.timestamp).toLocaleTimeString(); // 假设消息中有时间戳
|
||||
messageDiv.innerHTML = `<div class="timestamp">${timestamp}</div>${message.content}`; // 显示时间
|
||||
messagesDiv.appendChild(messageDiv);
|
||||
messagesDiv.scrollTop = messagesDiv.scrollHeight; // 滚动到底部
|
||||
};
|
||||
|
||||
sendButton.onclick = function() {
|
||||
sendMessage();
|
||||
};
|
||||
|
||||
input.addEventListener('keypress', function(event) {
|
||||
if (event.key === 'Enter') {
|
||||
sendMessage();
|
||||
event.preventDefault(); // 防止换行
|
||||
}
|
||||
});
|
||||
|
||||
function sendMessage() {
|
||||
const userMessage = input.value;
|
||||
if (userMessage) {
|
||||
const timestamp = new Date().toISOString(); // 获取当前时间戳
|
||||
fetch('/message', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json'
|
||||
},
|
||||
body: JSON.stringify({ user_id: userId, message: userMessage, timestamp: timestamp }) // 发送时间戳
|
||||
});
|
||||
const messageDiv = document.createElement('div');
|
||||
messageDiv.className = 'message user';
|
||||
const userTimestamp = new Date().toLocaleTimeString(); // 获取当前时间
|
||||
messageDiv.innerHTML = `<div class="timestamp">${userTimestamp}</div>${userMessage}`; // 显示时间
|
||||
messagesDiv.appendChild(messageDiv);
|
||||
messagesDiv.scrollTop = messagesDiv.scrollHeight; // 滚动到底部
|
||||
input.value = ''; // 清空输入框
|
||||
}
|
||||
}
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
204
channel/web/web_channel.py
Normal file
204
channel/web/web_channel.py
Normal file
@@ -0,0 +1,204 @@
|
||||
import sys
|
||||
import time
|
||||
import web
|
||||
import json
|
||||
from queue import Queue
|
||||
from bridge.context import *
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from channel.chat_channel import ChatChannel, check_prefix
|
||||
from channel.chat_message import ChatMessage
|
||||
from common.log import logger
|
||||
from common.singleton import singleton
|
||||
from config import conf
|
||||
import os
|
||||
|
||||
|
||||
class WebMessage(ChatMessage):
|
||||
def __init__(
|
||||
self,
|
||||
msg_id,
|
||||
content,
|
||||
ctype=ContextType.TEXT,
|
||||
from_user_id="User",
|
||||
to_user_id="Chatgpt",
|
||||
other_user_id="Chatgpt",
|
||||
):
|
||||
self.msg_id = msg_id
|
||||
self.ctype = ctype
|
||||
self.content = content
|
||||
self.from_user_id = from_user_id
|
||||
self.to_user_id = to_user_id
|
||||
self.other_user_id = other_user_id
|
||||
|
||||
|
||||
@singleton
|
||||
class WebChannel(ChatChannel):
|
||||
NOT_SUPPORT_REPLYTYPE = [ReplyType.VOICE]
|
||||
_instance = None
|
||||
|
||||
# def __new__(cls):
|
||||
# if cls._instance is None:
|
||||
# cls._instance = super(WebChannel, cls).__new__(cls)
|
||||
# return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.message_queues = {} # 为每个用户存储一个消息队列
|
||||
self.msg_id_counter = 0 # 添加消息ID计数器
|
||||
|
||||
def _generate_msg_id(self):
|
||||
"""生成唯一的消息ID"""
|
||||
self.msg_id_counter += 1
|
||||
return str(int(time.time())) + str(self.msg_id_counter)
|
||||
|
||||
def send(self, reply: Reply, context: Context):
|
||||
try:
|
||||
if reply.type == ReplyType.IMAGE:
|
||||
from PIL import Image
|
||||
|
||||
image_storage = reply.content
|
||||
image_storage.seek(0)
|
||||
img = Image.open(image_storage)
|
||||
print("<IMAGE>")
|
||||
img.show()
|
||||
elif reply.type == ReplyType.IMAGE_URL:
|
||||
import io
|
||||
|
||||
import requests
|
||||
from PIL import Image
|
||||
|
||||
img_url = reply.content
|
||||
pic_res = requests.get(img_url, stream=True)
|
||||
image_storage = io.BytesIO()
|
||||
for block in pic_res.iter_content(1024):
|
||||
image_storage.write(block)
|
||||
image_storage.seek(0)
|
||||
img = Image.open(image_storage)
|
||||
print(img_url)
|
||||
img.show()
|
||||
else:
|
||||
print(reply.content)
|
||||
|
||||
# 获取用户ID,如果没有则使用默认值
|
||||
# user_id = getattr(context.get("session", None), "session_id", "default_user")
|
||||
user_id = context["receiver"]
|
||||
# 确保用户有对应的消息队列
|
||||
if user_id not in self.message_queues:
|
||||
self.message_queues[user_id] = Queue()
|
||||
|
||||
# 将消息放入对应用户的队列
|
||||
message_data = {
|
||||
"type": str(reply.type),
|
||||
"content": reply.content,
|
||||
"timestamp": time.time()
|
||||
}
|
||||
self.message_queues[user_id].put(message_data)
|
||||
logger.debug(f"Message queued for user {user_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in send method: {e}")
|
||||
raise
|
||||
|
||||
def sse_handler(self, user_id):
|
||||
"""
|
||||
Handle Server-Sent Events (SSE) for real-time communication.
|
||||
"""
|
||||
web.header('Content-Type', 'text/event-stream')
|
||||
web.header('Cache-Control', 'no-cache')
|
||||
web.header('Connection', 'keep-alive')
|
||||
|
||||
# 确保用户有消息队列
|
||||
if user_id not in self.message_queues:
|
||||
self.message_queues[user_id] = Queue()
|
||||
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
# 发送心跳
|
||||
yield f": heartbeat\n\n"
|
||||
|
||||
# 非阻塞方式获取消息
|
||||
if not self.message_queues[user_id].empty():
|
||||
message = self.message_queues[user_id].get_nowait()
|
||||
yield f"data: {json.dumps(message)}\n\n"
|
||||
time.sleep(0.5)
|
||||
except Exception as e:
|
||||
logger.error(f"SSE Error: {e}")
|
||||
break
|
||||
finally:
|
||||
# 清理资源
|
||||
if user_id in self.message_queues:
|
||||
# 只有当队列为空时才删除
|
||||
if self.message_queues[user_id].empty():
|
||||
del self.message_queues[user_id]
|
||||
|
||||
def post_message(self):
|
||||
"""
|
||||
Handle incoming messages from users via POST request.
|
||||
"""
|
||||
try:
|
||||
data = web.data() # 获取原始POST数据
|
||||
json_data = json.loads(data)
|
||||
user_id = json_data.get('user_id', 'default_user')
|
||||
prompt = json_data.get('message', '')
|
||||
except json.JSONDecodeError:
|
||||
return json.dumps({"status": "error", "message": "Invalid JSON"})
|
||||
except Exception as e:
|
||||
return json.dumps({"status": "error", "message": str(e)})
|
||||
|
||||
if not prompt:
|
||||
return json.dumps({"status": "error", "message": "No message provided"})
|
||||
|
||||
try:
|
||||
msg_id = self._generate_msg_id()
|
||||
context = self._compose_context(ContextType.TEXT, prompt, msg=WebMessage(msg_id,
|
||||
prompt,
|
||||
from_user_id=user_id,
|
||||
other_user_id = user_id
|
||||
))
|
||||
context["isgroup"] = False
|
||||
# context["session"] = web.storage(session_id=user_id)
|
||||
|
||||
if not context:
|
||||
return json.dumps({"status": "error", "message": "Failed to process message"})
|
||||
|
||||
self.produce(context)
|
||||
return json.dumps({"status": "success", "message": "Message received"})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing message: {e}")
|
||||
return json.dumps({"status": "error", "message": "Internal server error"})
|
||||
|
||||
def chat_page(self):
|
||||
"""Serve the chat HTML page."""
|
||||
file_path = os.path.join(os.path.dirname(__file__), 'chat.html') # 使用绝对路径
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
return f.read()
|
||||
|
||||
def startup(self):
|
||||
logger.setLevel("WARN")
|
||||
print("\nWeb Channel is running. Send POST requests to /message to send messages.")
|
||||
|
||||
urls = (
|
||||
'/sse/(.+)', 'SSEHandler', # 修改路由以接收用户ID
|
||||
'/message', 'MessageHandler',
|
||||
'/chat', 'ChatHandler',
|
||||
)
|
||||
port = conf().get("web_port", 9899)
|
||||
app = web.application(urls, globals(), autoreload=False)
|
||||
web.httpserver.runsimple(app.wsgifunc(), ("0.0.0.0", port))
|
||||
|
||||
|
||||
class SSEHandler:
|
||||
def GET(self, user_id):
|
||||
return WebChannel().sse_handler(user_id)
|
||||
|
||||
|
||||
class MessageHandler:
|
||||
def POST(self):
|
||||
return WebChannel().post_message()
|
||||
|
||||
|
||||
class ChatHandler:
|
||||
def GET(self):
|
||||
return WebChannel().chat_page()
|
||||
@@ -9,17 +9,18 @@ import json
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
|
||||
import requests
|
||||
|
||||
from bridge.context import *
|
||||
from bridge.reply import *
|
||||
from channel.chat_channel import ChatChannel
|
||||
from channel import chat_channel
|
||||
from channel.wechat.wechat_message import *
|
||||
from common.expired_dict import ExpiredDict
|
||||
from common.log import logger
|
||||
from common.singleton import singleton
|
||||
from common.time_check import time_checker
|
||||
from common.utils import convert_webp_to_png, remove_markdown_symbol
|
||||
from config import conf, get_appdata_dir
|
||||
from lib import itchat
|
||||
from lib.itchat.content import *
|
||||
@@ -95,11 +96,14 @@ def qrCallback(uuid, status, qrcode):
|
||||
print(qr_api4)
|
||||
print(qr_api2)
|
||||
print(qr_api1)
|
||||
_send_qr_code([qr_api1, qr_api2, qr_api3, qr_api4])
|
||||
_send_qr_code([qr_api3, qr_api4, qr_api2, qr_api1])
|
||||
qr = qrcode.QRCode(border=1)
|
||||
qr.add_data(url)
|
||||
qr.make(fit=True)
|
||||
qr.print_ascii(invert=True)
|
||||
try:
|
||||
qr.print_ascii(invert=True)
|
||||
except UnicodeEncodeError:
|
||||
print("ASCII QR code printing failed due to encoding issues.")
|
||||
|
||||
|
||||
@singleton
|
||||
@@ -108,34 +112,43 @@ class WechatChannel(ChatChannel):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.receivedMsgs = ExpiredDict(60 * 60)
|
||||
self.receivedMsgs = ExpiredDict(conf().get("expires_in_seconds", 3600))
|
||||
self.auto_login_times = 0
|
||||
|
||||
def startup(self):
|
||||
itchat.instance.receivingRetryCount = 600 # 修改断线超时时间
|
||||
# login by scan QRCode
|
||||
hotReload = conf().get("hot_reload", False)
|
||||
status_path = os.path.join(get_appdata_dir(), "itchat.pkl")
|
||||
itchat.auto_login(
|
||||
enableCmdQR=2,
|
||||
hotReload=hotReload,
|
||||
statusStorageDir=status_path,
|
||||
qrCallback=qrCallback,
|
||||
exitCallback=self.exitCallback,
|
||||
loginCallback=self.loginCallback
|
||||
)
|
||||
self.user_id = itchat.instance.storageClass.userName
|
||||
self.name = itchat.instance.storageClass.nickName
|
||||
logger.info("Wechat login success, user_id: {}, nickname: {}".format(self.user_id, self.name))
|
||||
# start message listener
|
||||
itchat.run()
|
||||
try:
|
||||
itchat.instance.receivingRetryCount = 600 # 修改断线超时时间
|
||||
# login by scan QRCode
|
||||
hotReload = conf().get("hot_reload", False)
|
||||
status_path = os.path.join(get_appdata_dir(), "itchat.pkl")
|
||||
itchat.auto_login(
|
||||
enableCmdQR=2,
|
||||
hotReload=hotReload,
|
||||
statusStorageDir=status_path,
|
||||
qrCallback=qrCallback,
|
||||
exitCallback=self.exitCallback,
|
||||
loginCallback=self.loginCallback
|
||||
)
|
||||
self.user_id = itchat.instance.storageClass.userName
|
||||
self.name = itchat.instance.storageClass.nickName
|
||||
logger.info("Wechat login success, user_id: {}, nickname: {}".format(self.user_id, self.name))
|
||||
# start message listener
|
||||
itchat.run()
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
|
||||
def exitCallback(self):
|
||||
_send_logout()
|
||||
time.sleep(3)
|
||||
self.auto_login_times += 1
|
||||
if self.auto_login_times < 100:
|
||||
self.startup()
|
||||
try:
|
||||
from common.linkai_client import chat_client
|
||||
if chat_client.client_id and conf().get("use_linkai"):
|
||||
_send_logout()
|
||||
time.sleep(2)
|
||||
self.auto_login_times += 1
|
||||
if self.auto_login_times < 100:
|
||||
chat_channel.handler_pool._shutdown = False
|
||||
self.startup()
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
def loginCallback(self):
|
||||
logger.debug("Login success")
|
||||
@@ -192,7 +205,7 @@ class WechatChannel(ChatChannel):
|
||||
logger.debug(f"[WX]receive attachment msg, file_name={cmsg.content}")
|
||||
else:
|
||||
logger.debug("[WX]receive group msg: {}".format(cmsg.content))
|
||||
context = self._compose_context(cmsg.ctype, cmsg.content, isgroup=True, msg=cmsg)
|
||||
context = self._compose_context(cmsg.ctype, cmsg.content, isgroup=True, msg=cmsg, no_need_at=conf().get("no_need_at", False))
|
||||
if context:
|
||||
self.produce(context)
|
||||
|
||||
@@ -200,9 +213,11 @@ class WechatChannel(ChatChannel):
|
||||
def send(self, reply: Reply, context: Context):
|
||||
receiver = context["receiver"]
|
||||
if reply.type == ReplyType.TEXT:
|
||||
reply.content = remove_markdown_symbol(reply.content)
|
||||
itchat.send(reply.content, toUserName=receiver)
|
||||
logger.info("[WX] sendMsg={}, receiver={}".format(reply, receiver))
|
||||
elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO:
|
||||
reply.content = remove_markdown_symbol(reply.content)
|
||||
itchat.send(reply.content, toUserName=receiver)
|
||||
logger.info("[WX] sendMsg={}, receiver={}".format(reply, receiver))
|
||||
elif reply.type == ReplyType.VOICE:
|
||||
@@ -219,6 +234,12 @@ class WechatChannel(ChatChannel):
|
||||
image_storage.write(block)
|
||||
logger.info(f"[WX] download image success, size={size}, img_url={img_url}")
|
||||
image_storage.seek(0)
|
||||
if ".webp" in img_url:
|
||||
try:
|
||||
image_storage = convert_webp_to_png(image_storage)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to convert image: {e}")
|
||||
return
|
||||
itchat.send_image(image_storage, toUserName=receiver)
|
||||
logger.info("[WX] sendImage url={}, receiver={}".format(img_url, receiver))
|
||||
elif reply.type == ReplyType.IMAGE: # 从文件读取图片
|
||||
@@ -251,20 +272,26 @@ class WechatChannel(ChatChannel):
|
||||
def _send_login_success():
|
||||
try:
|
||||
from common.linkai_client import chat_client
|
||||
chat_client.send_login_success()
|
||||
if chat_client.client_id:
|
||||
chat_client.send_login_success()
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
|
||||
def _send_logout():
|
||||
try:
|
||||
from common.linkai_client import chat_client
|
||||
chat_client.send_logout()
|
||||
if chat_client.client_id:
|
||||
chat_client.send_logout()
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
|
||||
def _send_qr_code(qrcode_list: list):
|
||||
try:
|
||||
from common.linkai_client import chat_client
|
||||
chat_client.send_qrcode(qrcode_list)
|
||||
if chat_client.client_id:
|
||||
chat_client.send_qrcode(qrcode_list)
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
|
||||
@@ -14,6 +14,11 @@ class WechatMessage(ChatMessage):
|
||||
self.create_time = itchat_msg["CreateTime"]
|
||||
self.is_group = is_group
|
||||
|
||||
notes_join_group = ["加入群聊", "加入了群聊", "invited", "joined"] # 可通过添加对应语言的加入群聊通知中的关键词适配更多
|
||||
notes_bot_join_group = ["邀请你", "invited you", "You've joined", "你通过扫描"]
|
||||
notes_exit_group = ["移出了群聊", "removed"] # 可通过添加对应语言的踢出群聊通知中的关键词适配更多
|
||||
notes_patpat = ["拍了拍我", "tickled my", "tickled me"] # 可通过添加对应语言的拍一拍通知中的关键词适配更多
|
||||
|
||||
if itchat_msg["Type"] == TEXT:
|
||||
self.ctype = ContextType.TEXT
|
||||
self.content = itchat_msg["Text"]
|
||||
@@ -26,30 +31,47 @@ class WechatMessage(ChatMessage):
|
||||
self.content = TmpDir().path() + itchat_msg["FileName"] # content直接存临时目录路径
|
||||
self._prepare_fn = lambda: itchat_msg.download(self.content)
|
||||
elif itchat_msg["Type"] == NOTE and itchat_msg["MsgType"] == 10000:
|
||||
if is_group and ("加入群聊" in itchat_msg["Content"] or "加入了群聊" in itchat_msg["Content"]):
|
||||
if is_group:
|
||||
if any(note_bot_join_group in itchat_msg["Content"] for note_bot_join_group in notes_bot_join_group): # 邀请机器人加入群聊
|
||||
logger.warn("机器人加入群聊消息,不处理~")
|
||||
pass
|
||||
elif any(note_join_group in itchat_msg["Content"] for note_join_group in notes_join_group): # 若有任何在notes_join_group列表中的字符串出现在NOTE中
|
||||
# 这里只能得到nickname, actual_user_id还是机器人的id
|
||||
if "加入了群聊" in itchat_msg["Content"]:
|
||||
self.ctype = ContextType.JOIN_GROUP
|
||||
self.content = itchat_msg["Content"]
|
||||
self.actual_user_nickname = re.findall(r"\"(.*?)\"", itchat_msg["Content"])[-1]
|
||||
elif "加入群聊" in itchat_msg["Content"]:
|
||||
self.ctype = ContextType.JOIN_GROUP
|
||||
if "加入群聊" not in itchat_msg["Content"]:
|
||||
self.ctype = ContextType.JOIN_GROUP
|
||||
self.content = itchat_msg["Content"]
|
||||
if "invited" in itchat_msg["Content"]: # 匹配英文信息
|
||||
self.actual_user_nickname = re.findall(r'invited\s+(.+?)\s+to\s+the\s+group\s+chat', itchat_msg["Content"])[0]
|
||||
elif "joined" in itchat_msg["Content"]: # 匹配通过二维码加入的英文信息
|
||||
self.actual_user_nickname = re.findall(r'"(.*?)" joined the group chat via the QR Code shared by', itchat_msg["Content"])[0]
|
||||
elif "加入了群聊" in itchat_msg["Content"]:
|
||||
self.actual_user_nickname = re.findall(r"\"(.*?)\"", itchat_msg["Content"])[-1]
|
||||
elif "加入群聊" in itchat_msg["Content"]:
|
||||
self.ctype = ContextType.JOIN_GROUP
|
||||
self.content = itchat_msg["Content"]
|
||||
self.actual_user_nickname = re.findall(r"\"(.*?)\"", itchat_msg["Content"])[0]
|
||||
|
||||
elif any(note_exit_group in itchat_msg["Content"] for note_exit_group in notes_exit_group): # 若有任何在notes_exit_group列表中的字符串出现在NOTE中
|
||||
self.ctype = ContextType.EXIT_GROUP
|
||||
self.content = itchat_msg["Content"]
|
||||
self.actual_user_nickname = re.findall(r"\"(.*?)\"", itchat_msg["Content"])[0]
|
||||
|
||||
elif is_group and ("移出了群聊" in itchat_msg["Content"]):
|
||||
self.ctype = ContextType.EXIT_GROUP
|
||||
self.content = itchat_msg["Content"]
|
||||
self.actual_user_nickname = re.findall(r"\"(.*?)\"", itchat_msg["Content"])[0]
|
||||
elif any(note_patpat in itchat_msg["Content"] for note_patpat in notes_patpat): # 若有任何在notes_patpat列表中的字符串出现在NOTE中:
|
||||
self.ctype = ContextType.PATPAT
|
||||
self.content = itchat_msg["Content"]
|
||||
if "拍了拍我" in itchat_msg["Content"]: # 识别中文
|
||||
self.actual_user_nickname = re.findall(r"\"(.*?)\"", itchat_msg["Content"])[0]
|
||||
elif "tickled my" in itchat_msg["Content"] or "tickled me" in itchat_msg["Content"]:
|
||||
self.actual_user_nickname = re.findall(r'^(.*?)(?:tickled my|tickled me)', itchat_msg["Content"])[0]
|
||||
else:
|
||||
raise NotImplementedError("Unsupported note message: " + itchat_msg["Content"])
|
||||
|
||||
elif "你已添加了" in itchat_msg["Content"]: #通过好友请求
|
||||
self.ctype = ContextType.ACCEPT_FRIEND
|
||||
self.content = itchat_msg["Content"]
|
||||
elif "拍了拍我" in itchat_msg["Content"]:
|
||||
elif any(note_patpat in itchat_msg["Content"] for note_patpat in notes_patpat): # 若有任何在notes_patpat列表中的字符串出现在NOTE中:
|
||||
self.ctype = ContextType.PATPAT
|
||||
self.content = itchat_msg["Content"]
|
||||
if is_group:
|
||||
self.actual_user_nickname = re.findall(r"\"(.*?)\"", itchat_msg["Content"])[0]
|
||||
else:
|
||||
raise NotImplementedError("Unsupported note message: " + itchat_msg["Content"])
|
||||
elif itchat_msg["Type"] == ATTACHMENT:
|
||||
|
||||
@@ -17,7 +17,7 @@ from channel.wechatcom.wechatcomapp_client import WechatComAppClient
|
||||
from channel.wechatcom.wechatcomapp_message import WechatComAppMessage
|
||||
from common.log import logger
|
||||
from common.singleton import singleton
|
||||
from common.utils import compress_imgfile, fsize, split_string_by_utf8_length
|
||||
from common.utils import compress_imgfile, fsize, split_string_by_utf8_length, convert_webp_to_png, remove_markdown_symbol
|
||||
from config import conf, subscribe_msg
|
||||
from voice.audio_convert import any_to_amr, split_audio
|
||||
|
||||
@@ -44,7 +44,7 @@ class WechatComAppChannel(ChatChannel):
|
||||
|
||||
def startup(self):
|
||||
# start message listener
|
||||
urls = ("/wxcomapp", "channel.wechatcom.wechatcomapp_channel.Query")
|
||||
urls = ("/wxcomapp/?", "channel.wechatcom.wechatcomapp_channel.Query")
|
||||
app = web.application(urls, globals(), autoreload=False)
|
||||
port = conf().get("wechatcomapp_port", 9898)
|
||||
web.httpserver.runsimple(app.wsgifunc(), ("0.0.0.0", port))
|
||||
@@ -52,7 +52,7 @@ class WechatComAppChannel(ChatChannel):
|
||||
def send(self, reply: Reply, context: Context):
|
||||
receiver = context["receiver"]
|
||||
if reply.type in [ReplyType.TEXT, ReplyType.ERROR, ReplyType.INFO]:
|
||||
reply_text = reply.content
|
||||
reply_text = remove_markdown_symbol(reply.content)
|
||||
texts = split_string_by_utf8_length(reply_text, MAX_UTF8_LEN)
|
||||
if len(texts) > 1:
|
||||
logger.info("[wechatcom] text too long, split into {} parts".format(len(texts)))
|
||||
@@ -99,6 +99,12 @@ class WechatComAppChannel(ChatChannel):
|
||||
image_storage = compress_imgfile(image_storage, 10 * 1024 * 1024 - 1)
|
||||
logger.info("[wechatcom] image compressed, sz={}".format(fsize(image_storage)))
|
||||
image_storage.seek(0)
|
||||
if ".webp" in img_url:
|
||||
try:
|
||||
image_storage = convert_webp_to_png(image_storage)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to convert image: {e}")
|
||||
return
|
||||
try:
|
||||
response = self.client.media.upload("image", image_storage)
|
||||
logger.debug("[wechatcom] upload image response: {}".format(response))
|
||||
@@ -156,11 +162,12 @@ class Query:
|
||||
logger.debug("[wechatcom] receive message: {}, msg= {}".format(message, msg))
|
||||
if msg.type == "event":
|
||||
if msg.event == "subscribe":
|
||||
reply_content = subscribe_msg()
|
||||
if reply_content:
|
||||
reply = create_reply(reply_content, msg).render()
|
||||
res = channel.crypto.encrypt_message(reply, nonce, timestamp)
|
||||
return res
|
||||
pass
|
||||
# reply_content = subscribe_msg()
|
||||
# if reply_content:
|
||||
# reply = create_reply(reply_content, msg).render()
|
||||
# res = channel.crypto.encrypt_message(reply, nonce, timestamp)
|
||||
# return res
|
||||
else:
|
||||
try:
|
||||
wechatcom_msg = WechatComAppMessage(msg, client=channel.client)
|
||||
|
||||
@@ -19,7 +19,7 @@ from channel.wechatmp.common import *
|
||||
from channel.wechatmp.wechatmp_client import WechatMPClient
|
||||
from common.log import logger
|
||||
from common.singleton import singleton
|
||||
from common.utils import split_string_by_utf8_length
|
||||
from common.utils import split_string_by_utf8_length, remove_markdown_symbol
|
||||
from config import conf
|
||||
from voice.audio_convert import any_to_mp3, split_audio
|
||||
|
||||
@@ -81,7 +81,7 @@ class WechatMPChannel(ChatChannel):
|
||||
receiver = context["receiver"]
|
||||
if self.passive_reply:
|
||||
if reply.type == ReplyType.TEXT or reply.type == ReplyType.INFO or reply.type == ReplyType.ERROR:
|
||||
reply_text = reply.content
|
||||
reply_text = remove_markdown_symbol(reply.content)
|
||||
logger.info("[wechatmp] text cached, receiver {}\n{}".format(receiver, reply_text))
|
||||
self.cache_dict[receiver].append(("text", reply_text))
|
||||
elif reply.type == ReplyType.VOICE:
|
||||
@@ -140,6 +140,42 @@ class WechatMPChannel(ChatChannel):
|
||||
media_id = response["media_id"]
|
||||
logger.info("[wechatmp] image uploaded, receiver {}, media_id {}".format(receiver, media_id))
|
||||
self.cache_dict[receiver].append(("image", media_id))
|
||||
elif reply.type == ReplyType.VIDEO_URL: # 从网络下载视频
|
||||
video_url = reply.content
|
||||
video_res = requests.get(video_url, stream=True)
|
||||
video_storage = io.BytesIO()
|
||||
for block in video_res.iter_content(1024):
|
||||
video_storage.write(block)
|
||||
video_storage.seek(0)
|
||||
video_type = 'mp4'
|
||||
filename = receiver + "-" + str(context["msg"].msg_id) + "." + video_type
|
||||
content_type = "video/" + video_type
|
||||
try:
|
||||
response = self.client.material.add("video", (filename, video_storage, content_type))
|
||||
logger.debug("[wechatmp] upload video response: {}".format(response))
|
||||
except WeChatClientException as e:
|
||||
logger.error("[wechatmp] upload video failed: {}".format(e))
|
||||
return
|
||||
media_id = response["media_id"]
|
||||
logger.info("[wechatmp] video uploaded, receiver {}, media_id {}".format(receiver, media_id))
|
||||
self.cache_dict[receiver].append(("video", media_id))
|
||||
|
||||
elif reply.type == ReplyType.VIDEO: # 从文件读取视频
|
||||
video_storage = reply.content
|
||||
video_storage.seek(0)
|
||||
video_type = 'mp4'
|
||||
filename = receiver + "-" + str(context["msg"].msg_id) + "." + video_type
|
||||
content_type = "video/" + video_type
|
||||
try:
|
||||
response = self.client.material.add("video", (filename, video_storage, content_type))
|
||||
logger.debug("[wechatmp] upload video response: {}".format(response))
|
||||
except WeChatClientException as e:
|
||||
logger.error("[wechatmp] upload video failed: {}".format(e))
|
||||
return
|
||||
media_id = response["media_id"]
|
||||
logger.info("[wechatmp] video uploaded, receiver {}, media_id {}".format(receiver, media_id))
|
||||
self.cache_dict[receiver].append(("video", media_id))
|
||||
|
||||
else:
|
||||
if reply.type == ReplyType.TEXT or reply.type == ReplyType.INFO or reply.type == ReplyType.ERROR:
|
||||
reply_text = reply.content
|
||||
@@ -222,6 +258,38 @@ class WechatMPChannel(ChatChannel):
|
||||
return
|
||||
self.client.message.send_image(receiver, response["media_id"])
|
||||
logger.info("[wechatmp] Do send image to {}".format(receiver))
|
||||
elif reply.type == ReplyType.VIDEO_URL: # 从网络下载视频
|
||||
video_url = reply.content
|
||||
video_res = requests.get(video_url, stream=True)
|
||||
video_storage = io.BytesIO()
|
||||
for block in video_res.iter_content(1024):
|
||||
video_storage.write(block)
|
||||
video_storage.seek(0)
|
||||
video_type = 'mp4'
|
||||
filename = receiver + "-" + str(context["msg"].msg_id) + "." + video_type
|
||||
content_type = "video/" + video_type
|
||||
try:
|
||||
response = self.client.media.upload("video", (filename, video_storage, content_type))
|
||||
logger.debug("[wechatmp] upload video response: {}".format(response))
|
||||
except WeChatClientException as e:
|
||||
logger.error("[wechatmp] upload video failed: {}".format(e))
|
||||
return
|
||||
self.client.message.send_video(receiver, response["media_id"])
|
||||
logger.info("[wechatmp] Do send video to {}".format(receiver))
|
||||
elif reply.type == ReplyType.VIDEO: # 从文件读取视频
|
||||
video_storage = reply.content
|
||||
video_storage.seek(0)
|
||||
video_type = 'mp4'
|
||||
filename = receiver + "-" + str(context["msg"].msg_id) + "." + video_type
|
||||
content_type = "video/" + video_type
|
||||
try:
|
||||
response = self.client.media.upload("video", (filename, video_storage, content_type))
|
||||
logger.debug("[wechatmp] upload video response: {}".format(response))
|
||||
except WeChatClientException as e:
|
||||
logger.error("[wechatmp] upload video failed: {}".format(e))
|
||||
return
|
||||
self.client.message.send_video(receiver, response["media_id"])
|
||||
logger.info("[wechatmp] Do send video to {}".format(receiver))
|
||||
return
|
||||
|
||||
def _success_callback(self, session_id, context, **kwargs): # 线程异常结束时的回调函数
|
||||
|
||||
@@ -1,26 +1,101 @@
|
||||
# bot_type
|
||||
OPEN_AI = "openAI"
|
||||
CHATGPT = "chatGPT"
|
||||
BAIDU = "baidu"
|
||||
BAIDU = "baidu" # 百度文心一言模型
|
||||
XUNFEI = "xunfei"
|
||||
CHATGPTONAZURE = "chatGPTOnAzure"
|
||||
LINKAI = "linkai"
|
||||
CLAUDEAI = "claude"
|
||||
QWEN = "qwen"
|
||||
GEMINI = "gemini"
|
||||
CLAUDEAI = "claude" # 使用cookie的历史模型
|
||||
CLAUDEAPI= "claudeAPI" # 通过Claude api调用模型
|
||||
QWEN = "qwen" # 旧版通义模型
|
||||
QWEN_DASHSCOPE = "dashscope" # 通义新版sdk和api key
|
||||
|
||||
|
||||
GEMINI = "gemini" # gemini-1.0-pro
|
||||
ZHIPU_AI = "glm-4"
|
||||
MOONSHOT = "moonshot"
|
||||
MiniMax = "minimax"
|
||||
|
||||
|
||||
# model
|
||||
CLAUDE3 = "claude-3-opus-20240229"
|
||||
GPT35 = "gpt-3.5-turbo"
|
||||
GPT4 = "gpt-4"
|
||||
GPT4_TURBO_PREVIEW = "gpt-4-0125-preview"
|
||||
GPT35_0125 = "gpt-3.5-turbo-0125"
|
||||
GPT35_1106 = "gpt-3.5-turbo-1106"
|
||||
|
||||
GPT_4o = "gpt-4o"
|
||||
GPT_4O_0806 = "gpt-4o-2024-08-06"
|
||||
GPT4_TURBO = "gpt-4-turbo"
|
||||
GPT4_TURBO_PREVIEW = "gpt-4-turbo-preview"
|
||||
GPT4_TURBO_04_09 = "gpt-4-turbo-2024-04-09"
|
||||
GPT4_TURBO_01_25 = "gpt-4-0125-preview"
|
||||
GPT4_TURBO_11_06 = "gpt-4-1106-preview"
|
||||
GPT4_VISION_PREVIEW = "gpt-4-vision-preview"
|
||||
|
||||
GPT4 = "gpt-4"
|
||||
GPT_4o_MINI = "gpt-4o-mini"
|
||||
GPT4_32k = "gpt-4-32k"
|
||||
GPT4_06_13 = "gpt-4-0613"
|
||||
GPT4_32k_06_13 = "gpt-4-32k-0613"
|
||||
|
||||
O1 = "o1-preview"
|
||||
O1_MINI = "o1-mini"
|
||||
|
||||
WHISPER_1 = "whisper-1"
|
||||
TTS_1 = "tts-1"
|
||||
TTS_1_HD = "tts-1-hd"
|
||||
|
||||
MODEL_LIST = ["gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4", "wenxin", "wenxin-4", "xunfei", "claude", "gpt-4-turbo",
|
||||
"gpt-4-turbo-preview", "gpt-4-1106-preview", GPT4_TURBO_PREVIEW, QWEN, GEMINI]
|
||||
WEN_XIN = "wenxin"
|
||||
WEN_XIN_4 = "wenxin-4"
|
||||
|
||||
QWEN_TURBO = "qwen-turbo"
|
||||
QWEN_PLUS = "qwen-plus"
|
||||
QWEN_MAX = "qwen-max"
|
||||
|
||||
LINKAI_35 = "linkai-3.5"
|
||||
LINKAI_4_TURBO = "linkai-4-turbo"
|
||||
LINKAI_4o = "linkai-4o"
|
||||
|
||||
GEMINI_PRO = "gemini-1.0-pro"
|
||||
GEMINI_15_flash = "gemini-1.5-flash"
|
||||
GEMINI_15_PRO = "gemini-1.5-pro"
|
||||
GEMINI_20_flash_exp = "gemini-2.0-flash-exp"
|
||||
|
||||
|
||||
GLM_4 = "glm-4"
|
||||
GLM_4_PLUS = "glm-4-plus"
|
||||
GLM_4_flash = "glm-4-flash"
|
||||
GLM_4_LONG = "glm-4-long"
|
||||
GLM_4_ALLTOOLS = "glm-4-alltools"
|
||||
GLM_4_0520 = "glm-4-0520"
|
||||
GLM_4_AIR = "glm-4-air"
|
||||
GLM_4_AIRX = "glm-4-airx"
|
||||
|
||||
|
||||
CLAUDE_3_OPUS = "claude-3-opus-latest"
|
||||
CLAUDE_3_OPUS_0229 = "claude-3-opus-20240229"
|
||||
|
||||
CLAUDE_35_SONNET = "claude-3-5-sonnet-latest" # 带 latest 标签的模型名称,会不断更新指向最新发布的模型
|
||||
CLAUDE_35_SONNET_1022 = "claude-3-5-sonnet-20241022" # 带具体日期的模型名称,会固定为该日期发布的模型
|
||||
CLAUDE_35_SONNET_0620 = "claude-3-5-sonnet-20240620"
|
||||
CLAUDE_3_SONNET = "claude-3-sonnet-20240229"
|
||||
|
||||
CLAUDE_3_HAIKU = "claude-3-haiku-20240307"
|
||||
|
||||
MODEL_LIST = [
|
||||
GPT35, GPT35_0125, GPT35_1106, "gpt-3.5-turbo-16k",
|
||||
O1, O1_MINI, GPT_4o, GPT_4O_0806, GPT_4o_MINI, GPT4_TURBO, GPT4_TURBO_PREVIEW, GPT4_TURBO_01_25, GPT4_TURBO_11_06, GPT4, GPT4_32k, GPT4_06_13, GPT4_32k_06_13,
|
||||
WEN_XIN, WEN_XIN_4,
|
||||
XUNFEI,
|
||||
ZHIPU_AI, GLM_4, GLM_4_PLUS, GLM_4_flash, GLM_4_LONG, GLM_4_ALLTOOLS, GLM_4_0520, GLM_4_AIR, GLM_4_AIRX,
|
||||
MOONSHOT, MiniMax,
|
||||
GEMINI, GEMINI_PRO, GEMINI_15_flash, GEMINI_15_PRO,GEMINI_20_flash_exp,
|
||||
CLAUDE_3_OPUS, CLAUDE_3_OPUS_0229, CLAUDE_35_SONNET, CLAUDE_35_SONNET_1022, CLAUDE_35_SONNET_0620, CLAUDE_3_SONNET, CLAUDE_3_HAIKU, "claude", "claude-3-haiku", "claude-3-sonnet", "claude-3-opus", "claude-3.5-sonnet",
|
||||
"moonshot-v1-8k", "moonshot-v1-32k", "moonshot-v1-128k",
|
||||
QWEN, QWEN_TURBO, QWEN_PLUS, QWEN_MAX,
|
||||
LINKAI_35, LINKAI_4_TURBO, LINKAI_4o
|
||||
]
|
||||
|
||||
# channel
|
||||
FEISHU = "feishu"
|
||||
DINGTALK = "dingtalk"
|
||||
DINGTALK = "dingtalk"
|
||||
|
||||
@@ -2,10 +2,14 @@ from bridge.context import Context, ContextType
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from common.log import logger
|
||||
from linkai import LinkAIClient, PushMsg
|
||||
from config import conf
|
||||
from config import conf, pconf, plugin_config, available_setting, write_plugin_config
|
||||
from plugins import PluginManager
|
||||
import time
|
||||
|
||||
|
||||
chat_client: LinkAIClient
|
||||
|
||||
|
||||
class ChatClient(LinkAIClient):
|
||||
def __init__(self, api_key, host, channel):
|
||||
super().__init__(api_key, host)
|
||||
@@ -22,9 +26,85 @@ class ChatClient(LinkAIClient):
|
||||
context["isgroup"] = push_msg.is_group
|
||||
self.channel.send(Reply(ReplyType.TEXT, content=msg_content), context)
|
||||
|
||||
def on_config(self, config: dict):
|
||||
if not self.client_id:
|
||||
return
|
||||
logger.info(f"[LinkAI] 从客户端管理加载远程配置: {config}")
|
||||
if config.get("enabled") != "Y":
|
||||
return
|
||||
|
||||
local_config = conf()
|
||||
for key in config.keys():
|
||||
if key in available_setting and config.get(key) is not None:
|
||||
local_config[key] = config.get(key)
|
||||
# 语音配置
|
||||
reply_voice_mode = config.get("reply_voice_mode")
|
||||
if reply_voice_mode:
|
||||
if reply_voice_mode == "voice_reply_voice":
|
||||
local_config["voice_reply_voice"] = True
|
||||
local_config["always_reply_voice"] = False
|
||||
elif reply_voice_mode == "always_reply_voice":
|
||||
local_config["always_reply_voice"] = True
|
||||
local_config["voice_reply_voice"] = True
|
||||
elif reply_voice_mode == "no_reply_voice":
|
||||
local_config["always_reply_voice"] = False
|
||||
local_config["voice_reply_voice"] = False
|
||||
|
||||
if config.get("admin_password"):
|
||||
if not pconf("Godcmd"):
|
||||
write_plugin_config({"Godcmd": {"password": config.get("admin_password"), "admin_users": []} })
|
||||
else:
|
||||
pconf("Godcmd")["password"] = config.get("admin_password")
|
||||
PluginManager().instances["GODCMD"].reload()
|
||||
|
||||
if config.get("group_app_map") and pconf("linkai"):
|
||||
local_group_map = {}
|
||||
for mapping in config.get("group_app_map"):
|
||||
local_group_map[mapping.get("group_name")] = mapping.get("app_code")
|
||||
pconf("linkai")["group_app_map"] = local_group_map
|
||||
PluginManager().instances["LINKAI"].reload()
|
||||
|
||||
if config.get("text_to_image") and config.get("text_to_image") == "midjourney" and pconf("linkai"):
|
||||
if pconf("linkai")["midjourney"]:
|
||||
pconf("linkai")["midjourney"]["enabled"] = True
|
||||
pconf("linkai")["midjourney"]["use_image_create_prefix"] = True
|
||||
elif config.get("text_to_image") and config.get("text_to_image") in ["dall-e-2", "dall-e-3"]:
|
||||
if pconf("linkai")["midjourney"]:
|
||||
pconf("linkai")["midjourney"]["use_image_create_prefix"] = False
|
||||
|
||||
|
||||
def start(channel):
|
||||
global chat_client
|
||||
chat_client = ChatClient(api_key=conf().get("linkai_api_key"),
|
||||
host="link-ai.chat", channel=channel)
|
||||
chat_client = ChatClient(api_key=conf().get("linkai_api_key"), host="", channel=channel)
|
||||
chat_client.config = _build_config()
|
||||
chat_client.start()
|
||||
time.sleep(1.5)
|
||||
if chat_client.client_id:
|
||||
logger.info("[LinkAI] 可前往控制台进行线上登录和配置:https://link-ai.tech/console/clients")
|
||||
|
||||
|
||||
def _build_config():
|
||||
local_conf = conf()
|
||||
config = {
|
||||
"linkai_app_code": local_conf.get("linkai_app_code"),
|
||||
"single_chat_prefix": local_conf.get("single_chat_prefix"),
|
||||
"single_chat_reply_prefix": local_conf.get("single_chat_reply_prefix"),
|
||||
"single_chat_reply_suffix": local_conf.get("single_chat_reply_suffix"),
|
||||
"group_chat_prefix": local_conf.get("group_chat_prefix"),
|
||||
"group_chat_reply_prefix": local_conf.get("group_chat_reply_prefix"),
|
||||
"group_chat_reply_suffix": local_conf.get("group_chat_reply_suffix"),
|
||||
"group_name_white_list": local_conf.get("group_name_white_list"),
|
||||
"nick_name_black_list": local_conf.get("nick_name_black_list"),
|
||||
"speech_recognition": "Y" if local_conf.get("speech_recognition") else "N",
|
||||
"text_to_image": local_conf.get("text_to_image"),
|
||||
"image_create_prefix": local_conf.get("image_create_prefix")
|
||||
}
|
||||
if local_conf.get("always_reply_voice"):
|
||||
config["reply_voice_mode"] = "always_reply_voice"
|
||||
elif local_conf.get("voice_reply_voice"):
|
||||
config["reply_voice_mode"] = "voice_reply_voice"
|
||||
if pconf("linkai"):
|
||||
config["group_app_map"] = pconf("linkai").get("group_app_map")
|
||||
if plugin_config.get("Godcmd"):
|
||||
config["admin_password"] = plugin_config.get("Godcmd").get("password")
|
||||
return config
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
import hashlib
|
||||
import re
|
||||
import time
|
||||
|
||||
import config
|
||||
from common.log import logger
|
||||
|
||||
@@ -10,31 +8,33 @@ def time_checker(f):
|
||||
def _time_checker(self, *args, **kwargs):
|
||||
_config = config.conf()
|
||||
chat_time_module = _config.get("chat_time_module", False)
|
||||
|
||||
if chat_time_module:
|
||||
chat_start_time = _config.get("chat_start_time", "00:00")
|
||||
chat_stopt_time = _config.get("chat_stop_time", "24:00")
|
||||
time_regex = re.compile(r"^([01]?[0-9]|2[0-4])(:)([0-5][0-9])$") # 时间匹配,包含24:00
|
||||
chat_stop_time = _config.get("chat_stop_time", "24:00")
|
||||
|
||||
starttime_format_check = time_regex.match(chat_start_time) # 检查停止时间格式
|
||||
stoptime_format_check = time_regex.match(chat_stopt_time) # 检查停止时间格式
|
||||
chat_time_check = chat_start_time < chat_stopt_time # 确定启动时间<停止时间
|
||||
time_regex = re.compile(r"^([01]?[0-9]|2[0-4])(:)([0-5][0-9])$")
|
||||
|
||||
# 时间格式检查
|
||||
if not (starttime_format_check and stoptime_format_check and chat_time_check):
|
||||
logger.warn("时间格式不正确,请在config.json中修改您的CHAT_START_TIME/CHAT_STOP_TIME,否则可能会影响您正常使用,开始({})-结束({})".format(starttime_format_check, stoptime_format_check))
|
||||
if chat_start_time > "23:59":
|
||||
logger.error("启动时间可能存在问题,请修改!")
|
||||
|
||||
# 服务时间检查
|
||||
now_time = time.strftime("%H:%M", time.localtime())
|
||||
if chat_start_time <= now_time <= chat_stopt_time: # 服务时间内,正常返回回答
|
||||
f(self, *args, **kwargs)
|
||||
if not (time_regex.match(chat_start_time) and time_regex.match(chat_stop_time)):
|
||||
logger.warning("时间格式不正确,请在config.json中修改CHAT_START_TIME/CHAT_STOP_TIME。")
|
||||
return None
|
||||
|
||||
now_time = time.strptime(time.strftime("%H:%M"), "%H:%M")
|
||||
chat_start_time = time.strptime(chat_start_time, "%H:%M")
|
||||
chat_stop_time = time.strptime(chat_stop_time, "%H:%M")
|
||||
# 结束时间小于开始时间,跨天了
|
||||
if chat_stop_time < chat_start_time and (chat_start_time <= now_time or now_time <= chat_stop_time):
|
||||
f(self, *args, **kwargs)
|
||||
# 结束大于开始时间代表,没有跨天
|
||||
elif chat_start_time < chat_stop_time and chat_start_time <= now_time <= chat_stop_time:
|
||||
f(self, *args, **kwargs)
|
||||
else:
|
||||
if args[0]["Content"] == "#更新配置": # 不在服务时间内也可以更新配置
|
||||
# 定义匹配规则,如果以 #reconf 或者 #更新配置 结尾, 非服务时间可以修改开始/结束时间并重载配置
|
||||
pattern = re.compile(r"^.*#(?:reconf|更新配置)$")
|
||||
if args and pattern.match(args[0].content):
|
||||
f(self, *args, **kwargs)
|
||||
else:
|
||||
logger.info("非服务时间内,不接受访问")
|
||||
logger.info("非服务时间内,不接受访问")
|
||||
return None
|
||||
else:
|
||||
f(self, *args, **kwargs) # 未开启时间模块则直接回答
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
import io
|
||||
import os
|
||||
import re
|
||||
from urllib.parse import urlparse
|
||||
from PIL import Image
|
||||
|
||||
from common.log import logger
|
||||
|
||||
def fsize(file):
|
||||
if isinstance(file, io.BytesIO):
|
||||
@@ -54,3 +55,24 @@ def split_string_by_utf8_length(string, max_length, max_split=0):
|
||||
def get_path_suffix(path):
|
||||
path = urlparse(path).path
|
||||
return os.path.splitext(path)[-1].lstrip('.')
|
||||
|
||||
|
||||
def convert_webp_to_png(webp_image):
|
||||
from PIL import Image
|
||||
try:
|
||||
webp_image.seek(0)
|
||||
img = Image.open(webp_image).convert("RGBA")
|
||||
png_image = io.BytesIO()
|
||||
img.save(png_image, format="PNG")
|
||||
png_image.seek(0)
|
||||
return png_image
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to convert WEBP to PNG: {e}")
|
||||
raise
|
||||
|
||||
|
||||
def remove_markdown_symbol(text: str):
|
||||
# 移除markdown格式,目前先移除**
|
||||
if not text:
|
||||
return text
|
||||
return re.sub(r'\*\*(.*?)\*\*', r'\1', text)
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
"channel_type": "wx",
|
||||
"model": "",
|
||||
"open_ai_api_key": "YOUR API KEY",
|
||||
"claude_api_key": "YOUR API KEY",
|
||||
"text_to_image": "dall-e-2",
|
||||
"voice_to_text": "openai",
|
||||
"text_to_voice": "openai",
|
||||
@@ -27,7 +28,7 @@
|
||||
"voice_reply_voice": false,
|
||||
"conversation_max_tokens": 2500,
|
||||
"expires_in_seconds": 3600,
|
||||
"character_desc": "你是ChatGPT, 一个由OpenAI训练的大型语言模型, 你旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。",
|
||||
"character_desc": "你是基于大语言模型的AI智能助手,旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。",
|
||||
"temperature": 0.7,
|
||||
"subscribe_msg": "感谢您的关注!\n这里是AI智能助手,可以自由对话。\n支持语音对话。\n支持图片输入。\n支持图片输出,画字开头的消息将按要求创作图片。\n支持tool、角色扮演和文字冒险等丰富的插件。\n输入{trigger_prefix}#help 查看详细指令。",
|
||||
"use_linkai": false,
|
||||
|
||||
95
config.py
95
config.py
@@ -4,6 +4,7 @@ import json
|
||||
import logging
|
||||
import os
|
||||
import pickle
|
||||
import copy
|
||||
|
||||
from common.log import logger
|
||||
|
||||
@@ -16,7 +17,8 @@ available_setting = {
|
||||
"open_ai_api_base": "https://api.openai.com/v1",
|
||||
"proxy": "", # openai使用的代理
|
||||
# chatgpt模型, 当use_azure_chatgpt为true时,其名称为Azure上model deployment名称
|
||||
"model": "gpt-3.5-turbo", # 还支持 gpt-4, gpt-4-turbo, wenxin, xunfei, qwen
|
||||
"model": "gpt-3.5-turbo", # 可选择: gpt-4o, pt-4o-mini, gpt-4-turbo, claude-3-sonnet, wenxin, moonshot, qwen-turbo, xunfei, glm-4, minimax, gemini等模型,全部可选模型详见common/const.py文件
|
||||
"bot_type": "", # 可选配置,使用兼容openai格式的三方服务时候,需填"chatGPT"。bot具体名称详见common/const.py文件列出的bot_type,如不填根据model名称判断,
|
||||
"use_azure_chatgpt": False, # 是否使用azure的chatgpt
|
||||
"azure_deployment_id": "", # azure 模型部署名称
|
||||
"azure_api_version": "", # azure api版本
|
||||
@@ -25,6 +27,7 @@ available_setting = {
|
||||
"single_chat_reply_prefix": "[bot] ", # 私聊时自动回复的前缀,用于区分真人
|
||||
"single_chat_reply_suffix": "", # 私聊时自动回复的后缀,\n 可以换行
|
||||
"group_chat_prefix": ["@bot"], # 群聊时包含该前缀则会触发机器人回复
|
||||
"no_need_at": False, # 群聊回复时是否不需要艾特
|
||||
"group_chat_reply_prefix": "", # 群聊时自动回复的前缀
|
||||
"group_chat_reply_suffix": "", # 群聊时自动回复的后缀,\n 可以换行
|
||||
"group_chat_keyword": [], # 群聊时包含该关键词则会触发机器人回复
|
||||
@@ -33,14 +36,21 @@ available_setting = {
|
||||
"group_name_keyword_white_list": [], # 开启自动回复的群名称关键词列表
|
||||
"group_chat_in_one_session": ["ChatGPT测试群"], # 支持会话上下文共享的群名称
|
||||
"nick_name_black_list": [], # 用户昵称黑名单
|
||||
"group_welcome_msg": "", # 配置新人进群固定欢迎语,不配置则使用随机风格欢迎
|
||||
"group_welcome_msg": "", # 配置新人进群固定欢迎语,不配置则使用随机风格欢迎
|
||||
"trigger_by_self": False, # 是否允许机器人触发
|
||||
"text_to_image": "dall-e-2", # 图片生成模型,可选 dall-e-2, dall-e-3
|
||||
# Azure OpenAI dall-e-3 配置
|
||||
"dalle3_image_style": "vivid", # 图片生成dalle3的风格,可选有 vivid, natural
|
||||
"dalle3_image_quality": "hd", # 图片生成dalle3的质量,可选有 standard, hd
|
||||
# Azure OpenAI DALL-E API 配置, 当use_azure_chatgpt为true时,用于将文字回复的资源和Dall-E的资源分开.
|
||||
"azure_openai_dalle_api_base": "", # [可选] azure openai 用于回复图片的资源 endpoint,默认使用 open_ai_api_base
|
||||
"azure_openai_dalle_api_key": "", # [可选] azure openai 用于回复图片的资源 key,默认使用 open_ai_api_key
|
||||
"azure_openai_dalle_deployment_id":"", # [可选] azure openai 用于回复图片的资源 deployment id,默认使用 text_to_image
|
||||
"image_proxy": True, # 是否需要图片代理,国内访问LinkAI时需要
|
||||
"image_create_prefix": ["画", "看", "找"], # 开启图片回复的前缀
|
||||
"concurrency_in_session": 1, # 同一会话最多有多少条消息在处理中,大于1可能乱序
|
||||
"image_create_size": "256x256", # 图片大小,可选有 256x256, 512x512, 1024x1024 (dall-e-3默认为1024x1024)
|
||||
"group_chat_exit_group": False,
|
||||
"group_chat_exit_group": False,
|
||||
# chatgpt会话参数
|
||||
"expires_in_seconds": 3600, # 无操作会话的过期时间
|
||||
# 人格描述
|
||||
@@ -60,19 +70,26 @@ available_setting = {
|
||||
"baidu_wenxin_model": "eb-instant", # 默认使用ERNIE-Bot-turbo模型
|
||||
"baidu_wenxin_api_key": "", # Baidu api key
|
||||
"baidu_wenxin_secret_key": "", # Baidu secret key
|
||||
"baidu_wenxin_prompt_enabled": False, # Enable prompt if you are using ernie character model
|
||||
# 讯飞星火API
|
||||
"xunfei_app_id": "", # 讯飞应用ID
|
||||
"xunfei_api_key": "", # 讯飞 API key
|
||||
"xunfei_api_secret": "", # 讯飞 API secret
|
||||
"xunfei_domain": "", # 讯飞模型对应的domain参数,Spark4.0 Ultra为 4.0Ultra,其他模型详见: https://www.xfyun.cn/doc/spark/Web.html
|
||||
"xunfei_spark_url": "", # 讯飞模型对应的请求地址,Spark4.0 Ultra为 wss://spark-api.xf-yun.com/v4.0/chat,其他模型参考详见: https://www.xfyun.cn/doc/spark/Web.html
|
||||
# claude 配置
|
||||
"claude_api_cookie": "",
|
||||
"claude_uuid": "",
|
||||
# claude api key
|
||||
"claude_api_key": "",
|
||||
# 通义千问API, 获取方式查看文档 https://help.aliyun.com/document_detail/2587494.html
|
||||
"qwen_access_key_id": "",
|
||||
"qwen_access_key_secret": "",
|
||||
"qwen_agent_key": "",
|
||||
"qwen_app_id": "",
|
||||
"qwen_node_id": "", # 流程编排模型用到的id,如果没有用到qwen_node_id,请务必保持为空字符串
|
||||
# 阿里灵积(通义新版sdk)模型api key
|
||||
"dashscope_api_key": "",
|
||||
# Google Gemini Api Key
|
||||
"gemini_api_key": "",
|
||||
# wework的通用配置
|
||||
@@ -82,8 +99,8 @@ available_setting = {
|
||||
"group_speech_recognition": False, # 是否开启群组语音识别
|
||||
"voice_reply_voice": False, # 是否使用语音回复语音,需要设置对应语音合成引擎的api key
|
||||
"always_reply_voice": False, # 是否一直使用语音回复
|
||||
"voice_to_text": "openai", # 语音识别引擎,支持openai,baidu,google,azure
|
||||
"text_to_voice": "openai", # 语音合成引擎,支持openai,baidu,google,pytts(offline),azure,elevenlabs
|
||||
"voice_to_text": "openai", # 语音识别引擎,支持openai,baidu,google,azure,xunfei,ali
|
||||
"text_to_voice": "openai", # 语音合成引擎,支持openai,baidu,google,azure,xunfei,ali,pytts(offline),elevenlabs,edge(online)
|
||||
"text_to_voice_model": "tts-1",
|
||||
"tts_voice_id": "alloy",
|
||||
# baidu 语音api配置, 使用百度语音识别和语音合成时需要
|
||||
@@ -91,13 +108,13 @@ available_setting = {
|
||||
"baidu_api_key": "",
|
||||
"baidu_secret_key": "",
|
||||
# 1536普通话(支持简单的英文识别) 1737英语 1637粤语 1837四川话 1936普通话远场
|
||||
"baidu_dev_pid": "1536",
|
||||
"baidu_dev_pid": 1536,
|
||||
# azure 语音api配置, 使用azure语音识别和语音合成时需要
|
||||
"azure_voice_api_key": "",
|
||||
"azure_voice_region": "japaneast",
|
||||
# elevenlabs 语音api配置
|
||||
"xi_api_key": "", #获取ap的方法可以参考https://docs.elevenlabs.io/api-reference/quick-start/authentication
|
||||
"xi_voice_id": "", #ElevenLabs提供了9种英式、美式等英语发音id,分别是“Adam/Antoni/Arnold/Bella/Domi/Elli/Josh/Rachel/Sam”
|
||||
"xi_api_key": "", # 获取ap的方法可以参考https://docs.elevenlabs.io/api-reference/quick-start/authentication
|
||||
"xi_voice_id": "", # ElevenLabs提供了9种英式、美式等英语发音id,分别是“Adam/Antoni/Arnold/Bella/Domi/Elli/Josh/Rachel/Sam”
|
||||
# 服务时间限制,目前支持itchat
|
||||
"chat_time_module": False, # 是否开启服务时间限制
|
||||
"chat_start_time": "00:00", # 服务开始时间
|
||||
@@ -125,22 +142,21 @@ available_setting = {
|
||||
"wechatcomapp_secret": "", # 企业微信app的secret
|
||||
"wechatcomapp_agent_id": "", # 企业微信app的agent_id
|
||||
"wechatcomapp_aes_key": "", # 企业微信app的aes_key
|
||||
|
||||
# 飞书配置
|
||||
"feishu_port": 80, # 飞书bot监听端口
|
||||
"feishu_app_id": "", # 飞书机器人应用APP Id
|
||||
"feishu_app_secret": "", # 飞书机器人APP secret
|
||||
"feishu_token": "", # 飞书 verification token
|
||||
"feishu_bot_name": "", # 飞书机器人的名字
|
||||
|
||||
# 钉钉配置
|
||||
"dingtalk_client_id": "", # 钉钉机器人Client ID
|
||||
"dingtalk_client_secret": "", # 钉钉机器人Client Secret
|
||||
"dingtalk_client_secret": "", # 钉钉机器人Client Secret
|
||||
"dingtalk_card_enabled": False,
|
||||
|
||||
# chatgpt指令自定义触发词
|
||||
"clear_memory_commands": ["#清除记忆"], # 重置会话指令,必须以#开头
|
||||
# channel配置
|
||||
"channel_type": "wx", # 通道类型,支持:{wx,wxy,terminal,wechatmp,wechatmp_service,wechatcom_app}
|
||||
"channel_type": "", # 通道类型,支持:{wx,wxy,terminal,wechatmp,wechatmp_service,wechatcom_app,dingtalk}
|
||||
"subscribe_msg": "", # 订阅消息, 支持: wechatmp, wechatmp_service, wechatcom_app
|
||||
"debug": False, # 是否开启debug模式,开启后会打印更多日志
|
||||
"appdata_dir": "", # 数据目录
|
||||
@@ -148,11 +164,22 @@ available_setting = {
|
||||
"plugin_trigger_prefix": "$", # 规范插件提供聊天相关指令的前缀,建议不要和管理员指令前缀"#"冲突
|
||||
# 是否使用全局插件配置
|
||||
"use_global_plugin_config": False,
|
||||
# 知识库平台配置
|
||||
"max_media_send_count": 3, # 单次最大发送媒体资源的个数
|
||||
"media_send_interval": 1, # 发送图片的事件间隔,单位秒
|
||||
# 智谱AI 平台配置
|
||||
"zhipu_ai_api_key": "",
|
||||
"zhipu_ai_api_base": "https://open.bigmodel.cn/api/paas/v4",
|
||||
"moonshot_api_key": "",
|
||||
"moonshot_base_url": "https://api.moonshot.cn/v1/chat/completions",
|
||||
# LinkAI平台配置
|
||||
"use_linkai": False,
|
||||
"linkai_api_key": "",
|
||||
"linkai_app_code": "",
|
||||
"linkai_api_base": "https://api.link-ai.chat", # linkAI服务地址,若国内无法访问或延迟较高可改为 https://api.link-ai.tech
|
||||
"linkai_api_base": "https://api.link-ai.tech", # linkAI服务地址
|
||||
"Minimax_api_key": "",
|
||||
"Minimax_group_id": "",
|
||||
"Minimax_base_url": "",
|
||||
"web_port": 9899,
|
||||
}
|
||||
|
||||
|
||||
@@ -213,6 +240,30 @@ class Config(dict):
|
||||
config = Config()
|
||||
|
||||
|
||||
def drag_sensitive(config):
|
||||
try:
|
||||
if isinstance(config, str):
|
||||
conf_dict: dict = json.loads(config)
|
||||
conf_dict_copy = copy.deepcopy(conf_dict)
|
||||
for key in conf_dict_copy:
|
||||
if "key" in key or "secret" in key:
|
||||
if isinstance(conf_dict_copy[key], str):
|
||||
conf_dict_copy[key] = conf_dict_copy[key][0:3] + "*" * 5 + conf_dict_copy[key][-3:]
|
||||
return json.dumps(conf_dict_copy, indent=4)
|
||||
|
||||
elif isinstance(config, dict):
|
||||
config_copy = copy.deepcopy(config)
|
||||
for key in config:
|
||||
if "key" in key or "secret" in key:
|
||||
if isinstance(config_copy[key], str):
|
||||
config_copy[key] = config_copy[key][0:3] + "*" * 5 + config_copy[key][-3:]
|
||||
return config_copy
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
return config
|
||||
return config
|
||||
|
||||
|
||||
def load_config():
|
||||
global config
|
||||
config_path = "./config.json"
|
||||
@@ -221,7 +272,7 @@ def load_config():
|
||||
config_path = "./config-template.json"
|
||||
|
||||
config_str = read_file(config_path)
|
||||
logger.debug("[INIT] config str: {}".format(config_str))
|
||||
logger.debug("[INIT] config str: {}".format(drag_sensitive(config_str)))
|
||||
|
||||
# 将json字符串反序列化为dict类型
|
||||
config = Config(json.loads(config_str))
|
||||
@@ -246,7 +297,7 @@ def load_config():
|
||||
logger.setLevel(logging.DEBUG)
|
||||
logger.debug("[INIT] set log level to DEBUG")
|
||||
|
||||
logger.info("[INIT] load config: {}".format(config))
|
||||
logger.info("[INIT] load config: {}".format(drag_sensitive(config)))
|
||||
|
||||
config.load_user_datas()
|
||||
|
||||
@@ -291,6 +342,14 @@ def write_plugin_config(pconf: dict):
|
||||
for k in pconf:
|
||||
plugin_config[k.lower()] = pconf[k]
|
||||
|
||||
def remove_plugin_config(name: str):
|
||||
"""
|
||||
移除待重新加载的插件全局配置
|
||||
:param name: 待重载的插件名
|
||||
"""
|
||||
global plugin_config
|
||||
plugin_config.pop(name.lower(), None)
|
||||
|
||||
|
||||
def pconf(plugin_name: str) -> dict:
|
||||
"""
|
||||
@@ -302,6 +361,4 @@ def pconf(plugin_name: str) -> dict:
|
||||
|
||||
|
||||
# 全局配置,用于存放全局生效的状态
|
||||
global_config = {
|
||||
"admin_users": []
|
||||
}
|
||||
global_config = {"admin_users": []}
|
||||
|
||||
@@ -6,6 +6,7 @@ services:
|
||||
security_opt:
|
||||
- seccomp:unconfined
|
||||
environment:
|
||||
TZ: 'Asia/Shanghai'
|
||||
OPEN_AI_API_KEY: 'YOUR API KEY'
|
||||
MODEL: 'gpt-3.5-turbo'
|
||||
PROXY: ''
|
||||
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 51 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 326 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 382 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 33 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 180 KiB |
13
docs/version/old-version.md
Normal file
13
docs/version/old-version.md
Normal file
@@ -0,0 +1,13 @@
|
||||
## 归档更新日志
|
||||
|
||||
2023.04.26: 支持企业微信应用号部署,兼容插件,并支持语音图片交互,私人助理理想选择,使用文档。(contributed by @lanvent in #944)
|
||||
|
||||
2023.04.05: 支持微信公众号部署,兼容插件,并支持语音图片交互,使用文档。(contributed by @JS00000 in #686)
|
||||
|
||||
2023.04.05: 增加能让ChatGPT使用工具的tool插件,使用文档。工具相关issue可反馈至chatgpt-tool-hub。(contributed by @goldfishh in #663)
|
||||
|
||||
2023.03.25: 支持插件化开发,目前已实现 多角色切换、文字冒险游戏、管理员指令、Stable Diffusion等插件,使用参考 #578。(contributed by @lanvent in #565)
|
||||
|
||||
2023.03.09: 基于 whisper API(后续已接入更多的语音API服务) 实现对语音消息的解析和回复,添加配置项 "speech_recognition":true 即可启用,使用参考 #415。(contributed by wanggang1987 in #385)
|
||||
|
||||
2023.02.09: 扫码登录存在账号限制风险,请谨慎使用,参考#58
|
||||
@@ -10,9 +10,7 @@
|
||||
},
|
||||
"tool": {
|
||||
"tools": [
|
||||
"python",
|
||||
"url-get",
|
||||
"terminal",
|
||||
"meteo-weather"
|
||||
],
|
||||
"kwargs": {
|
||||
@@ -40,5 +38,22 @@
|
||||
"max_file_size": 5000,
|
||||
"type": ["FILE", "SHARING"]
|
||||
}
|
||||
},
|
||||
"hello": {
|
||||
"group_welc_fixed_msg": {
|
||||
"群聊1": "群聊1的固定欢迎语",
|
||||
"群聊2": "群聊2的固定欢迎语"
|
||||
},
|
||||
"group_welc_prompt": "请你随机使用一种风格说一句问候语来欢迎新用户\"{nickname}\"加入群聊。",
|
||||
|
||||
"group_exit_prompt": "请你随机使用一种风格跟其他群用户说他违反规则\"{nickname}\"退出群聊。",
|
||||
|
||||
"patpat_prompt": "请你随机使用一种风格介绍你自己,并告诉用户输入#help可以查看帮助信息。",
|
||||
|
||||
"use_character_desc": false
|
||||
},
|
||||
"Apilot": {
|
||||
"alapi_token": "xxx",
|
||||
"morning_news_text_enabled": false
|
||||
}
|
||||
}
|
||||
|
||||
@@ -313,7 +313,7 @@ class Godcmd(Plugin):
|
||||
except Exception as e:
|
||||
ok, result = False, "你没有设置私有GPT模型"
|
||||
elif cmd == "reset":
|
||||
if bottype in [const.OPEN_AI, const.CHATGPT, const.CHATGPTONAZURE, const.LINKAI, const.BAIDU, const.XUNFEI, const.QWEN, const.GEMINI]:
|
||||
if bottype in [const.OPEN_AI, const.CHATGPT, const.CHATGPTONAZURE, const.LINKAI, const.BAIDU, const.XUNFEI, const.QWEN, const.GEMINI, const.ZHIPU_AI, const.CLAUDEAPI]:
|
||||
bot.sessions.clear_session(session_id)
|
||||
if Bridge().chat_bots.get(bottype):
|
||||
Bridge().chat_bots.get(bottype).sessions.clear_session(session_id)
|
||||
@@ -339,7 +339,7 @@ class Godcmd(Plugin):
|
||||
ok, result = True, "配置已重载"
|
||||
elif cmd == "resetall":
|
||||
if bottype in [const.OPEN_AI, const.CHATGPT, const.CHATGPTONAZURE, const.LINKAI,
|
||||
const.BAIDU, const.XUNFEI, const.QWEN, const.GEMINI]:
|
||||
const.BAIDU, const.XUNFEI, const.QWEN, const.GEMINI, const.ZHIPU_AI, const.MOONSHOT]:
|
||||
channel.cancel_all_session()
|
||||
bot.sessions.clear_all_session()
|
||||
ok, result = True, "重置所有会话成功"
|
||||
@@ -475,3 +475,11 @@ class Godcmd(Plugin):
|
||||
if model == "gpt-4-turbo":
|
||||
return const.GPT4_TURBO_PREVIEW
|
||||
return model
|
||||
|
||||
def reload(self):
|
||||
gconf = pconf(self.name)
|
||||
if gconf:
|
||||
if gconf.get("password"):
|
||||
self.password = gconf["password"]
|
||||
if gconf.get("admin_users"):
|
||||
self.admin_users = gconf["admin_users"]
|
||||
|
||||
41
plugins/hello/README.md
Normal file
41
plugins/hello/README.md
Normal file
@@ -0,0 +1,41 @@
|
||||
## 插件说明
|
||||
|
||||
可以根据需求设置入群欢迎、群聊拍一拍、退群等消息的自定义提示词,也支持为每个群设置对应的固定欢迎语。
|
||||
|
||||
该插件也是用户根据需求开发自定义插件的示例插件,参考[插件开发说明](https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins)
|
||||
|
||||
## 插件配置
|
||||
|
||||
将 `plugins/hello` 目录下的 `config.json.template` 配置模板复制为最终生效的 `config.json`。 (如果未配置则会默认使用`config.json.template`模板中配置)。
|
||||
|
||||
以下是插件配置项说明:
|
||||
|
||||
```bash
|
||||
{
|
||||
"group_welc_fixed_msg": { ## 这里可以为特定群里配置特定的固定欢迎语
|
||||
"群聊1": "群聊1的固定欢迎语",
|
||||
"群聊2": "群聊2的固定欢迎语"
|
||||
},
|
||||
|
||||
"group_welc_prompt": "请你随机使用一种风格说一句问候语来欢迎新用户\"{nickname}\"加入群聊。", ## 群聊随机欢迎语的提示词
|
||||
|
||||
"group_exit_prompt": "请你随机使用一种风格跟其他群用户说他违反规则\"{nickname}\"退出群聊。", ## 移出群聊的提示词
|
||||
|
||||
"patpat_prompt": "请你随机使用一种风格介绍你自己,并告诉用户输入#help可以查看帮助信息。", ## 群内拍一拍的提示词
|
||||
|
||||
"use_character_desc": false ## 是否在Hello插件中使用LinkAI应用的系统设定
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
注意:
|
||||
|
||||
- 设置全局的用户进群固定欢迎语,可以在***项目根目录下***的`config.json`文件里,可以添加参数`"group_welcome_msg": "" `,参考 [#1482](https://github.com/zhayujie/chatgpt-on-wechat/pull/1482)
|
||||
- 为每个群设置固定的欢迎语,可以在`"group_welc_fixed_msg": {}`配置群聊名和对应的固定欢迎语,优先级高于全局固定欢迎语
|
||||
- 如果没有配置以上两个参数,则使用随机欢迎语,如需设定风格,语言等,修改`"group_welc_prompt": `即可
|
||||
- 如果使用LinkAI的服务,想在随机欢迎中结合LinkAI应用的设定,配置`"use_character_desc": true `
|
||||
- 实际 `config.json` 配置中应保证json格式,不应携带 '#' 及后面的注释
|
||||
- 如果是`docker`部署,可通过映射 `plugins/config.json` 到容器中来完成插件配置,参考[文档](https://github.com/zhayujie/chatgpt-on-wechat#3-%E6%8F%92%E4%BB%B6%E4%BD%BF%E7%94%A8)
|
||||
|
||||
|
||||
|
||||
14
plugins/hello/config.json.template
Normal file
14
plugins/hello/config.json.template
Normal file
@@ -0,0 +1,14 @@
|
||||
{
|
||||
"group_welc_fixed_msg": {
|
||||
"群聊1": "群聊1的固定欢迎语",
|
||||
"群聊2": "群聊2的固定欢迎语"
|
||||
},
|
||||
|
||||
"group_welc_prompt": "请你随机使用一种风格说一句问候语来欢迎新用户\"{nickname}\"加入群聊。",
|
||||
|
||||
"group_exit_prompt": "请你随机使用一种风格跟其他群用户说他违反规则\"{nickname}\"退出群聊。",
|
||||
|
||||
"patpat_prompt": "请你随机使用一种风格介绍你自己,并告诉用户输入#help可以查看帮助信息。",
|
||||
|
||||
"use_character_desc": false
|
||||
}
|
||||
@@ -17,12 +17,29 @@ from config import conf
|
||||
version="0.1",
|
||||
author="lanvent",
|
||||
)
|
||||
|
||||
|
||||
class Hello(Plugin):
|
||||
|
||||
group_welc_prompt = "请你随机使用一种风格说一句问候语来欢迎新用户\"{nickname}\"加入群聊。"
|
||||
group_exit_prompt = "请你随机使用一种风格介绍你自己,并告诉用户输入#help可以查看帮助信息。"
|
||||
patpat_prompt = "请你随机使用一种风格跟其他群用户说他违反规则\"{nickname}\"退出群聊。"
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context
|
||||
logger.info("[Hello] inited")
|
||||
self.config = super().load_config()
|
||||
try:
|
||||
self.config = super().load_config()
|
||||
if not self.config:
|
||||
self.config = self._load_config_template()
|
||||
self.group_welc_fixed_msg = self.config.get("group_welc_fixed_msg", {})
|
||||
self.group_welc_prompt = self.config.get("group_welc_prompt", self.group_welc_prompt)
|
||||
self.group_exit_prompt = self.config.get("group_exit_prompt", self.group_exit_prompt)
|
||||
self.patpat_prompt = self.config.get("patpat_prompt", self.patpat_prompt)
|
||||
logger.info("[Hello] inited")
|
||||
self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context
|
||||
except Exception as e:
|
||||
logger.error(f"[Hello]初始化异常:{e}")
|
||||
raise "[Hello] init failed, ignore "
|
||||
|
||||
def on_handle_context(self, e_context: EventContext):
|
||||
if e_context["context"].type not in [
|
||||
@@ -32,17 +49,21 @@ class Hello(Plugin):
|
||||
ContextType.EXIT_GROUP
|
||||
]:
|
||||
return
|
||||
msg: ChatMessage = e_context["context"]["msg"]
|
||||
group_name = msg.from_user_nickname
|
||||
if e_context["context"].type == ContextType.JOIN_GROUP:
|
||||
if "group_welcome_msg" in conf():
|
||||
if "group_welcome_msg" in conf() or group_name in self.group_welc_fixed_msg:
|
||||
reply = Reply()
|
||||
reply.type = ReplyType.TEXT
|
||||
reply.content = conf().get("group_welcome_msg", "")
|
||||
if group_name in self.group_welc_fixed_msg:
|
||||
reply.content = self.group_welc_fixed_msg.get(group_name, "")
|
||||
else:
|
||||
reply.content = conf().get("group_welcome_msg", "")
|
||||
e_context["reply"] = reply
|
||||
e_context.action = EventAction.BREAK_PASS # 事件结束,并跳过处理context的默认逻辑
|
||||
return
|
||||
e_context["context"].type = ContextType.TEXT
|
||||
msg: ChatMessage = e_context["context"]["msg"]
|
||||
e_context["context"].content = f'请你随机使用一种风格说一句问候语来欢迎新用户"{msg.actual_user_nickname}"加入群聊。'
|
||||
e_context["context"].content = self.group_welc_prompt.format(nickname=msg.actual_user_nickname)
|
||||
e_context.action = EventAction.BREAK # 事件结束,进入默认处理逻辑
|
||||
if not self.config or not self.config.get("use_character_desc"):
|
||||
e_context["context"]["generate_breaked_by"] = EventAction.BREAK
|
||||
@@ -51,8 +72,7 @@ class Hello(Plugin):
|
||||
if e_context["context"].type == ContextType.EXIT_GROUP:
|
||||
if conf().get("group_chat_exit_group"):
|
||||
e_context["context"].type = ContextType.TEXT
|
||||
msg: ChatMessage = e_context["context"]["msg"]
|
||||
e_context["context"].content = f'请你随机使用一种风格跟其他群用户说他违反规则"{msg.actual_user_nickname}"退出群聊。'
|
||||
e_context["context"].content = self.group_exit_prompt.format(nickname=msg.actual_user_nickname)
|
||||
e_context.action = EventAction.BREAK # 事件结束,进入默认处理逻辑
|
||||
return
|
||||
e_context.action = EventAction.BREAK
|
||||
@@ -60,8 +80,7 @@ class Hello(Plugin):
|
||||
|
||||
if e_context["context"].type == ContextType.PATPAT:
|
||||
e_context["context"].type = ContextType.TEXT
|
||||
msg: ChatMessage = e_context["context"]["msg"]
|
||||
e_context["context"].content = f"请你随机使用一种风格介绍你自己,并告诉用户输入#help可以查看帮助信息。"
|
||||
e_context["context"].content = self.patpat_prompt
|
||||
e_context.action = EventAction.BREAK # 事件结束,进入默认处理逻辑
|
||||
if not self.config or not self.config.get("use_character_desc"):
|
||||
e_context["context"]["generate_breaked_by"] = EventAction.BREAK
|
||||
@@ -72,7 +91,6 @@ class Hello(Plugin):
|
||||
if content == "Hello":
|
||||
reply = Reply()
|
||||
reply.type = ReplyType.TEXT
|
||||
msg: ChatMessage = e_context["context"]["msg"]
|
||||
if e_context["context"]["isgroup"]:
|
||||
reply.content = f"Hello, {msg.actual_user_nickname} from {msg.from_user_nickname}"
|
||||
else:
|
||||
@@ -96,3 +114,14 @@ class Hello(Plugin):
|
||||
def get_help_text(self, **kwargs):
|
||||
help_text = "输入Hello,我会回复你的名字\n输入End,我会回复你世界的图片\n"
|
||||
return help_text
|
||||
|
||||
def _load_config_template(self):
|
||||
logger.debug("No Hello plugin config.json, use plugins/hello/config.json.template")
|
||||
try:
|
||||
plugin_config_path = os.path.join(self.path, "config.json.template")
|
||||
if os.path.exists(plugin_config_path):
|
||||
with open(plugin_config_path, "r", encoding="utf-8") as f:
|
||||
plugin_conf = json.load(f)
|
||||
return plugin_conf
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
@@ -55,7 +55,7 @@ class Keyword(Plugin):
|
||||
reply_text = self.keyword[content]
|
||||
|
||||
# 判断匹配内容的类型
|
||||
if (reply_text.startswith("http://") or reply_text.startswith("https://")) and any(reply_text.endswith(ext) for ext in [".jpg", ".jpeg", ".png", ".gif", ".img"]):
|
||||
if (reply_text.startswith("http://") or reply_text.startswith("https://")) and any(reply_text.endswith(ext) for ext in [".jpg", ".webp", ".jpeg", ".png", ".gif", ".img"]):
|
||||
# 如果是以 http:// 或 https:// 开头,且".jpg", ".jpeg", ".png", ".gif", ".img"结尾,则认为是图片 URL。
|
||||
reply = Reply()
|
||||
reply.type = ReplyType.IMAGE_URL
|
||||
|
||||
@@ -98,6 +98,8 @@
|
||||
|
||||
如果不想创建 `plugins/linkai/config.json` 配置,可以直接通过 `$linkai sum open` 指令开启该功能。
|
||||
|
||||
也可以通过私聊(全局 `config.json` 中的 `linkai_app_code`)或者群聊绑定(通过`group_app_map`参数配置)的应用来开启该功能:在LinkAI平台 [应用配置](https://link-ai.tech/console/factory) 里添加并开启**内容总结**插件。
|
||||
|
||||
#### 使用
|
||||
|
||||
功能开启后,向机器人发送 **文件**、 **分享链接卡片**、**图片** 即可生成摘要,进一步可以与文件或链接的内容进行多轮对话。如果需要关闭某种类型的内容总结,设置 `summary`配置中的type字段即可。
|
||||
|
||||
@@ -9,6 +9,8 @@ from common.expired_dict import ExpiredDict
|
||||
from common import const
|
||||
import os
|
||||
from .utils import Util
|
||||
from config import plugin_config, conf
|
||||
|
||||
|
||||
@plugins.register(
|
||||
name="linkai",
|
||||
@@ -26,13 +28,12 @@ class LinkAI(Plugin):
|
||||
# 未加载到配置,使用模板中的配置
|
||||
self.config = self._load_config_template()
|
||||
if self.config:
|
||||
self.mj_bot = MJBot(self.config.get("midjourney"))
|
||||
self.mj_bot = MJBot(self.config.get("midjourney"), self._fetch_group_app_code)
|
||||
self.sum_config = {}
|
||||
if self.config:
|
||||
self.sum_config = self.config.get("summary")
|
||||
logger.info(f"[LinkAI] inited, config={self.config}")
|
||||
|
||||
|
||||
def on_handle_context(self, e_context: EventContext):
|
||||
"""
|
||||
消息处理逻辑
|
||||
@@ -42,7 +43,8 @@ class LinkAI(Plugin):
|
||||
return
|
||||
|
||||
context = e_context['context']
|
||||
if context.type not in [ContextType.TEXT, ContextType.IMAGE, ContextType.IMAGE_CREATE, ContextType.FILE, ContextType.SHARING]:
|
||||
if context.type not in [ContextType.TEXT, ContextType.IMAGE, ContextType.IMAGE_CREATE, ContextType.FILE,
|
||||
ContextType.SHARING]:
|
||||
# filter content no need solve
|
||||
return
|
||||
|
||||
@@ -54,7 +56,8 @@ class LinkAI(Plugin):
|
||||
return
|
||||
if context.type != ContextType.IMAGE:
|
||||
_send_info(e_context, "正在为你加速生成摘要,请稍后")
|
||||
res = LinkSummary().summary_file(file_path)
|
||||
app_code = self._fetch_app_code(context)
|
||||
res = LinkSummary().summary_file(file_path, app_code)
|
||||
if not res:
|
||||
if context.type != ContextType.IMAGE:
|
||||
_set_reply_text("因为神秘力量无法获取内容,请稍后再试吧", e_context, level=ReplyType.TEXT)
|
||||
@@ -68,15 +71,17 @@ class LinkAI(Plugin):
|
||||
return
|
||||
|
||||
if (context.type == ContextType.SHARING and self._is_summary_open(context)) or \
|
||||
(context.type == ContextType.TEXT and LinkSummary().check_url(context.content)):
|
||||
(context.type == ContextType.TEXT and self._is_summary_open(context) and LinkSummary().check_url(context.content)):
|
||||
if not LinkSummary().check_url(context.content):
|
||||
return
|
||||
_send_info(e_context, "正在为你加速生成摘要,请稍后")
|
||||
res = LinkSummary().summary_url(context.content)
|
||||
app_code = self._fetch_app_code(context)
|
||||
res = LinkSummary().summary_url(context.content, app_code)
|
||||
if not res:
|
||||
_set_reply_text("因为神秘力量无法获取文章内容,请稍后再试吧~", e_context, level=ReplyType.TEXT)
|
||||
return
|
||||
_set_reply_text(res.get("summary") + "\n\n💬 发送 \"开启对话\" 可以开启与文章内容的对话", e_context, level=ReplyType.TEXT)
|
||||
_set_reply_text(res.get("summary") + "\n\n💬 发送 \"开启对话\" 可以开启与文章内容的对话", e_context,
|
||||
level=ReplyType.TEXT)
|
||||
USER_FILE_MAP[_find_user_id(context) + "-sum_id"] = res.get("summary_id")
|
||||
return
|
||||
|
||||
@@ -99,7 +104,8 @@ class LinkAI(Plugin):
|
||||
_set_reply_text("开启对话失败,请稍后再试吧", e_context)
|
||||
return
|
||||
USER_FILE_MAP[_find_user_id(context) + "-file_id"] = res.get("file_id")
|
||||
_set_reply_text("💡你可以问我关于这篇文章的任何问题,例如:\n\n" + res.get("questions") + "\n\n发送 \"退出对话\" 可以关闭与文章的对话", e_context, level=ReplyType.TEXT)
|
||||
_set_reply_text("💡你可以问我关于这篇文章的任何问题,例如:\n\n" + res.get(
|
||||
"questions") + "\n\n发送 \"退出对话\" 可以关闭与文章的对话", e_context, level=ReplyType.TEXT)
|
||||
return
|
||||
|
||||
if context.type == ContextType.TEXT and context.content == "退出对话" and _find_file_id(context):
|
||||
@@ -117,12 +123,10 @@ class LinkAI(Plugin):
|
||||
e_context.action = EventAction.BREAK_PASS
|
||||
return
|
||||
|
||||
|
||||
if self._is_chat_task(e_context):
|
||||
# 文本对话任务处理
|
||||
self._process_chat_task(e_context)
|
||||
|
||||
|
||||
# 插件管理功能
|
||||
def _process_admin_cmd(self, e_context: EventContext):
|
||||
context = e_context['context']
|
||||
@@ -167,7 +171,7 @@ class LinkAI(Plugin):
|
||||
return
|
||||
|
||||
if len(cmd) == 3 and cmd[1] == "sum" and (cmd[2] == "open" or cmd[2] == "close"):
|
||||
# 知识库开关指令
|
||||
# 总结对话开关指令
|
||||
if not Util.is_admin(e_context):
|
||||
_set_reply_text("需要管理员权限执行", e_context, level=ReplyType.ERROR)
|
||||
return
|
||||
@@ -177,7 +181,9 @@ class LinkAI(Plugin):
|
||||
tips_text = "关闭"
|
||||
is_open = False
|
||||
if not self.sum_config:
|
||||
_set_reply_text(f"插件未启用summary功能,请参考以下链添加插件配置\n\nhttps://github.com/zhayujie/chatgpt-on-wechat/blob/master/plugins/linkai/README.md", e_context, level=ReplyType.INFO)
|
||||
_set_reply_text(
|
||||
f"插件未启用summary功能,请参考以下链添加插件配置\n\nhttps://github.com/zhayujie/chatgpt-on-wechat/blob/master/plugins/linkai/README.md",
|
||||
e_context, level=ReplyType.INFO)
|
||||
else:
|
||||
self.sum_config["enabled"] = is_open
|
||||
_set_reply_text(f"文章总结功能{tips_text}", e_context, level=ReplyType.INFO)
|
||||
@@ -188,14 +194,36 @@ class LinkAI(Plugin):
|
||||
return
|
||||
|
||||
def _is_summary_open(self, context) -> bool:
|
||||
if not self.sum_config or not self.sum_config.get("enabled"):
|
||||
return False
|
||||
if context.kwargs.get("isgroup") and not self.sum_config.get("group_enabled"):
|
||||
return False
|
||||
support_type = self.sum_config.get("type") or ["FILE", "SHARING"]
|
||||
if context.type.name not in support_type:
|
||||
return False
|
||||
return True
|
||||
# 获取远程应用插件状态
|
||||
remote_enabled = False
|
||||
if context.kwargs.get("isgroup"):
|
||||
# 群聊场景只查询群对应的app_code
|
||||
group_name = context.get("msg").from_user_nickname
|
||||
app_code = self._fetch_group_app_code(group_name)
|
||||
if app_code:
|
||||
if context.type.name in ["FILE", "SHARING"]:
|
||||
remote_enabled = Util.fetch_app_plugin(app_code, "内容总结")
|
||||
else:
|
||||
# 非群聊场景使用全局app_code
|
||||
app_code = conf().get("linkai_app_code")
|
||||
if app_code:
|
||||
if context.type.name in ["FILE", "SHARING"]:
|
||||
remote_enabled = Util.fetch_app_plugin(app_code, "内容总结")
|
||||
|
||||
# 基础条件:总开关开启且消息类型符合要求
|
||||
base_enabled = (
|
||||
self.sum_config
|
||||
and self.sum_config.get("enabled")
|
||||
and (context.type.name in (
|
||||
self.sum_config.get("type") or ["FILE", "SHARING"]) or context.type.name == "TEXT")
|
||||
)
|
||||
|
||||
# 群聊:需要满足(总开关和群开关)或远程插件开启
|
||||
if context.kwargs.get("isgroup"):
|
||||
return (base_enabled and self.sum_config.get("group_enabled")) or remote_enabled
|
||||
|
||||
# 非群聊:只需要满足总开关或远程插件开启
|
||||
return base_enabled or remote_enabled
|
||||
|
||||
# LinkAI 对话任务处理
|
||||
def _is_chat_task(self, e_context: EventContext):
|
||||
@@ -226,6 +254,19 @@ class LinkAI(Plugin):
|
||||
app_code = group_mapping.get(group_name) or group_mapping.get("ALL_GROUP")
|
||||
return app_code
|
||||
|
||||
def _fetch_app_code(self, context) -> str:
|
||||
"""
|
||||
根据主配置或者群聊名称获取对应的应用code,优先获取群聊配置的应用code
|
||||
:param context: 上下文
|
||||
:return: 应用code
|
||||
"""
|
||||
app_code = conf().get("linkai_app_code")
|
||||
if context.kwargs.get("isgroup"):
|
||||
# 群聊场景只查询群对应的app_code
|
||||
group_name = context.get("msg").from_user_nickname
|
||||
app_code = self._fetch_group_app_code(group_name)
|
||||
return app_code
|
||||
|
||||
def get_help_text(self, verbose=False, **kwargs):
|
||||
trigger_prefix = _get_trigger_prefix()
|
||||
help_text = "用于集成 LinkAI 提供的知识库、Midjourney绘画、文档总结、联网搜索等能力。\n\n"
|
||||
@@ -250,10 +291,14 @@ class LinkAI(Plugin):
|
||||
plugin_conf = json.load(f)
|
||||
plugin_conf["midjourney"]["enabled"] = False
|
||||
plugin_conf["summary"]["enabled"] = False
|
||||
write_plugin_config({"linkai": plugin_conf})
|
||||
return plugin_conf
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
|
||||
def reload(self):
|
||||
self.config = super().load_config()
|
||||
|
||||
|
||||
def _send_info(e_context: EventContext, content: str):
|
||||
reply = Reply(ReplyType.TEXT, content)
|
||||
@@ -273,15 +318,19 @@ def _set_reply_text(content: str, e_context: EventContext, level: ReplyType = Re
|
||||
e_context["reply"] = reply
|
||||
e_context.action = EventAction.BREAK_PASS
|
||||
|
||||
|
||||
def _get_trigger_prefix():
|
||||
return conf().get("plugin_trigger_prefix", "$")
|
||||
|
||||
|
||||
def _find_sum_id(context):
|
||||
return USER_FILE_MAP.get(_find_user_id(context) + "-sum_id")
|
||||
|
||||
|
||||
def _find_file_id(context):
|
||||
user_id = _find_user_id(context)
|
||||
if user_id:
|
||||
return USER_FILE_MAP.get(user_id + "-file_id")
|
||||
|
||||
|
||||
USER_FILE_MAP = ExpiredDict(conf().get("expires_in_seconds") or 60 * 30)
|
||||
|
||||
@@ -10,6 +10,7 @@ from bridge.context import ContextType
|
||||
from plugins import EventContext, EventAction
|
||||
from .utils import Util
|
||||
|
||||
|
||||
INVALID_REQUEST = 410
|
||||
NOT_FOUND_ORIGIN_IMAGE = 461
|
||||
NOT_FOUND_TASK = 462
|
||||
@@ -67,10 +68,11 @@ class MJTask:
|
||||
|
||||
# midjourney bot
|
||||
class MJBot:
|
||||
def __init__(self, config):
|
||||
self.base_url = conf().get("linkai_api_base", "https://api.link-ai.chat") + "/v1/img/midjourney"
|
||||
def __init__(self, config, fetch_group_app_code):
|
||||
self.base_url = conf().get("linkai_api_base", "https://api.link-ai.tech") + "/v1/img/midjourney"
|
||||
self.headers = {"Authorization": "Bearer " + conf().get("linkai_api_key")}
|
||||
self.config = config
|
||||
self.fetch_group_app_code = fetch_group_app_code
|
||||
self.tasks = {}
|
||||
self.temp_dict = {}
|
||||
self.tasks_lock = threading.Lock()
|
||||
@@ -98,7 +100,7 @@ class MJBot:
|
||||
return TaskType.VARIATION
|
||||
elif cmd_list[0].lower() == f"{trigger_prefix}mjr":
|
||||
return TaskType.RESET
|
||||
elif context.type == ContextType.IMAGE_CREATE and self.config.get("use_image_create_prefix") and self.config.get("enabled"):
|
||||
elif context.type == ContextType.IMAGE_CREATE and self.config.get("use_image_create_prefix") and self._is_mj_open(context):
|
||||
return TaskType.GENERATE
|
||||
|
||||
def process_mj_task(self, mj_type: TaskType, e_context: EventContext):
|
||||
@@ -129,8 +131,8 @@ class MJBot:
|
||||
self._set_reply_text(f"Midjourney绘画已{tips_text}", e_context, level=ReplyType.INFO)
|
||||
return
|
||||
|
||||
if not self.config.get("enabled"):
|
||||
logger.warn("Midjourney绘画未开启,请查看 plugins/linkai/config.json 中的配置")
|
||||
if not self._is_mj_open(context):
|
||||
logger.warn("Midjourney绘画未开启,请查看 plugins/linkai/config.json 中的配置,或者在LinkAI平台 应用中添加/打开”MJ“插件")
|
||||
self._set_reply_text(f"Midjourney绘画未开启", e_context, level=ReplyType.INFO)
|
||||
return
|
||||
|
||||
@@ -409,6 +411,25 @@ class MJBot:
|
||||
result.append(task)
|
||||
return result
|
||||
|
||||
def _is_mj_open(self, context) -> bool:
|
||||
# 获取远程应用插件状态
|
||||
remote_enabled = False
|
||||
if context.kwargs.get("isgroup"):
|
||||
# 群聊场景只查询群对应的app_code
|
||||
group_name = context.get("msg").from_user_nickname
|
||||
app_code = self.fetch_group_app_code(group_name)
|
||||
if app_code:
|
||||
remote_enabled = Util.fetch_app_plugin(app_code, "Midjourney")
|
||||
else:
|
||||
# 非群聊场景使用全局app_code
|
||||
app_code = conf().get("linkai_app_code")
|
||||
if app_code:
|
||||
remote_enabled = Util.fetch_app_plugin(app_code, "Midjourney")
|
||||
|
||||
# 本地配置
|
||||
base_enabled = self.config.get("enabled")
|
||||
|
||||
return base_enabled or remote_enabled
|
||||
|
||||
def _send(channel, reply: Reply, context, retry_cnt=0):
|
||||
try:
|
||||
|
||||
@@ -2,25 +2,33 @@ import requests
|
||||
from config import conf
|
||||
from common.log import logger
|
||||
import os
|
||||
import html
|
||||
|
||||
|
||||
class LinkSummary:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def summary_file(self, file_path: str):
|
||||
def summary_file(self, file_path: str, app_code: str):
|
||||
file_body = {
|
||||
"file": open(file_path, "rb"),
|
||||
"name": file_path.split("/")[-1],
|
||||
"name": file_path.split("/")[-1]
|
||||
}
|
||||
body = {
|
||||
"app_code": app_code
|
||||
}
|
||||
url = self.base_url() + "/v1/summary/file"
|
||||
res = requests.post(url, headers=self.headers(), files=file_body, timeout=(5, 300))
|
||||
logger.info(f"[LinkSum] file summary, app_code={app_code}")
|
||||
res = requests.post(url, headers=self.headers(), files=file_body, data=body, timeout=(5, 300))
|
||||
return self._parse_summary_res(res)
|
||||
|
||||
def summary_url(self, url: str):
|
||||
def summary_url(self, url: str, app_code: str):
|
||||
url = html.unescape(url)
|
||||
body = {
|
||||
"url": url
|
||||
"url": url,
|
||||
"app_code": app_code
|
||||
}
|
||||
logger.info(f"[LinkSum] url summary, app_code={app_code}")
|
||||
res = requests.post(url=self.base_url() + "/v1/summary/url", headers=self.headers(), json=body, timeout=(5, 180))
|
||||
return self._parse_summary_res(res)
|
||||
|
||||
@@ -46,7 +54,7 @@ class LinkSummary:
|
||||
def _parse_summary_res(self, res):
|
||||
if res.status_code == 200:
|
||||
res = res.json()
|
||||
logger.debug(f"[LinkSum] url summary, res={res}")
|
||||
logger.debug(f"[LinkSum] summary result, res={res}")
|
||||
if res.get("code") == 200:
|
||||
data = res.get("data")
|
||||
return {
|
||||
@@ -59,7 +67,7 @@ class LinkSummary:
|
||||
return None
|
||||
|
||||
def base_url(self):
|
||||
return conf().get("linkai_api_base", "https://api.link-ai.chat")
|
||||
return conf().get("linkai_api_base", "https://api.link-ai.tech")
|
||||
|
||||
def headers(self):
|
||||
return {"Authorization": "Bearer " + conf().get("linkai_api_key")}
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
import requests
|
||||
from common.log import logger
|
||||
from config import global_config
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from plugins.event import EventContext, EventAction
|
||||
|
||||
from config import conf
|
||||
|
||||
class Util:
|
||||
@staticmethod
|
||||
@@ -26,3 +28,23 @@ class Util:
|
||||
reply = Reply(level, content)
|
||||
e_context["reply"] = reply
|
||||
e_context.action = EventAction.BREAK_PASS
|
||||
|
||||
@staticmethod
|
||||
def fetch_app_plugin(app_code: str, plugin_name: str) -> bool:
|
||||
try:
|
||||
headers = {"Authorization": "Bearer " + conf().get("linkai_api_key")}
|
||||
# do http request
|
||||
base_url = conf().get("linkai_api_base", "https://api.link-ai.tech")
|
||||
params = {"app_code": app_code}
|
||||
res = requests.get(url=base_url + "/v1/app/info", params=params, headers=headers, timeout=(5, 10))
|
||||
if res.status_code == 200:
|
||||
plugins = res.json().get("data").get("plugins")
|
||||
for plugin in plugins:
|
||||
if plugin.get("name") and plugin.get("name") == plugin_name:
|
||||
return True
|
||||
return False
|
||||
else:
|
||||
logger.warning(f"[LinkAI] find app info exception, res={res}")
|
||||
return False
|
||||
except Exception as e:
|
||||
return False
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import os
|
||||
import json
|
||||
from config import pconf, plugin_config, conf
|
||||
from config import pconf, plugin_config, conf, write_plugin_config
|
||||
from common.log import logger
|
||||
|
||||
|
||||
@@ -18,18 +18,19 @@ class Plugin:
|
||||
if not plugin_conf:
|
||||
# 全局配置不存在,则获取插件目录下的配置
|
||||
plugin_config_path = os.path.join(self.path, "config.json")
|
||||
logger.debug(f"loading plugin config, plugin_config_path={plugin_config_path}, exist={os.path.exists(plugin_config_path)}")
|
||||
if os.path.exists(plugin_config_path):
|
||||
with open(plugin_config_path, "r", encoding="utf-8") as f:
|
||||
plugin_conf = json.load(f)
|
||||
|
||||
# 写入全局配置内存
|
||||
plugin_config[self.name] = plugin_conf
|
||||
write_plugin_config({self.name: plugin_conf})
|
||||
logger.debug(f"loading plugin config, plugin_name={self.name}, conf={plugin_conf}")
|
||||
return plugin_conf
|
||||
|
||||
def save_config(self, config: dict):
|
||||
try:
|
||||
plugin_config[self.name] = config
|
||||
write_plugin_config({self.name: config})
|
||||
# 写入全局配置
|
||||
global_config_path = "./plugins/config.json"
|
||||
if os.path.exists(global_config_path):
|
||||
@@ -46,3 +47,6 @@ class Plugin:
|
||||
|
||||
def get_help_text(self, **kwargs):
|
||||
return "暂无帮助信息"
|
||||
|
||||
def reload(self):
|
||||
pass
|
||||
|
||||
@@ -9,7 +9,7 @@ import sys
|
||||
from common.log import logger
|
||||
from common.singleton import singleton
|
||||
from common.sorted_dict import SortedDict
|
||||
from config import conf, write_plugin_config
|
||||
from config import conf, remove_plugin_config, write_plugin_config
|
||||
|
||||
from .event import *
|
||||
|
||||
@@ -99,7 +99,7 @@ class PluginManager:
|
||||
try:
|
||||
self.current_plugin_path = plugin_path
|
||||
if plugin_path in self.loaded:
|
||||
if self.loaded[plugin_path] == None:
|
||||
if plugin_name.upper() != 'GODCMD':
|
||||
logger.info("reload module %s" % plugin_name)
|
||||
self.loaded[plugin_path] = importlib.reload(sys.modules[import_path])
|
||||
dependent_module_names = [name for name in sys.modules.keys() if name.startswith(import_path + ".")]
|
||||
@@ -141,28 +141,35 @@ class PluginManager:
|
||||
failed_plugins = []
|
||||
for name, plugincls in self.plugins.items():
|
||||
if plugincls.enabled:
|
||||
if name not in self.instances:
|
||||
try:
|
||||
instance = plugincls()
|
||||
except Exception as e:
|
||||
logger.warn("Failed to init %s, diabled. %s" % (name, e))
|
||||
self.disable_plugin(name)
|
||||
failed_plugins.append(name)
|
||||
continue
|
||||
self.instances[name] = instance
|
||||
for event in instance.handlers:
|
||||
if event not in self.listening_plugins:
|
||||
self.listening_plugins[event] = []
|
||||
self.listening_plugins[event].append(name)
|
||||
if 'GODCMD' in self.instances and name == 'GODCMD':
|
||||
continue
|
||||
# if name not in self.instances:
|
||||
try:
|
||||
instance = plugincls()
|
||||
except Exception as e:
|
||||
logger.warn("Failed to init %s, diabled. %s" % (name, e))
|
||||
self.disable_plugin(name)
|
||||
failed_plugins.append(name)
|
||||
continue
|
||||
if name in self.instances:
|
||||
self.instances[name].handlers.clear()
|
||||
self.instances[name] = instance
|
||||
for event in instance.handlers:
|
||||
if event not in self.listening_plugins:
|
||||
self.listening_plugins[event] = []
|
||||
self.listening_plugins[event].append(name)
|
||||
self.refresh_order()
|
||||
return failed_plugins
|
||||
|
||||
def reload_plugin(self, name: str):
|
||||
name = name.upper()
|
||||
remove_plugin_config(name)
|
||||
if name in self.instances:
|
||||
for event in self.listening_plugins:
|
||||
if name in self.listening_plugins[event]:
|
||||
self.listening_plugins[event].remove(name)
|
||||
if name in self.instances:
|
||||
self.instances[name].handlers.clear()
|
||||
del self.instances[name]
|
||||
self.activate_plugins()
|
||||
return True
|
||||
|
||||
@@ -99,7 +99,8 @@ class Role(Plugin):
|
||||
if e_context["context"].type != ContextType.TEXT:
|
||||
return
|
||||
btype = Bridge().get_bot_type("chat")
|
||||
if btype not in [const.OPEN_AI, const.CHATGPT, const.CHATGPTONAZURE, const.LINKAI]:
|
||||
if btype not in [const.OPEN_AI, const.CHATGPT, const.CHATGPTONAZURE, const.QWEN_DASHSCOPE, const.XUNFEI, const.BAIDU, const.ZHIPU_AI, const.MOONSHOT, const.MiniMax, const.LINKAI]:
|
||||
logger.debug(f'不支持的bot: {btype}')
|
||||
return
|
||||
bot = Bridge().get_bot("chat")
|
||||
content = e_context["context"].content[:]
|
||||
@@ -179,6 +180,7 @@ class Role(Plugin):
|
||||
e_context["reply"] = reply
|
||||
e_context.action = EventAction.BREAK_PASS
|
||||
else:
|
||||
e_context["context"]["generate_breaked_by"] = EventAction.BREAK
|
||||
prompt = self.roleplays[sessionid].action(content)
|
||||
e_context["context"].type = ContextType.TEXT
|
||||
e_context["context"].content = prompt
|
||||
|
||||
@@ -12,13 +12,29 @@
|
||||
"url": "https://github.com/lanvent/plugin_summary.git",
|
||||
"desc": "总结聊天记录的插件"
|
||||
},
|
||||
"timetask": {
|
||||
"url": "https://github.com/haikerapples/timetask.git",
|
||||
"desc": "一款定时任务系统的插件"
|
||||
},
|
||||
"Apilot": {
|
||||
"url": "https://github.com/6vision/Apilot.git",
|
||||
"desc": "通过api直接查询早报、热榜、快递、天气等实用信息的插件"
|
||||
},
|
||||
"pictureChange": {
|
||||
"url": "https://github.com/Yanyutin753/pictureChange.git",
|
||||
"desc": "1. 支持百度AI和Stable Diffusion WebUI进行图像处理,提供多种模型选择,支持图生图、文生图自定义模板。2. 支持Suno音乐AI可将图像和文字转为音乐。3. 支持自定义模型进行文件、图片总结功能。4. 支持管理员控制群聊内容与参数和功能改变。"
|
||||
},
|
||||
"Blackroom": {
|
||||
"url": "https://github.com/dividduang/blackroom.git",
|
||||
"desc": "小黑屋插件,被拉进小黑屋的人将不能使用@bot的功能的插件"
|
||||
},
|
||||
"midjourney": {
|
||||
"url": "https://github.com/baojingyu/midjourney.git",
|
||||
"desc": "利用midjourney实现ai绘图的的插件"
|
||||
},
|
||||
"solitaire": {
|
||||
"url": "https://github.com/Wang-zhechao/solitaire.git",
|
||||
"desc": "机器人微信接龙插件"
|
||||
},
|
||||
"HighSpeedTicket": {
|
||||
"url": "https://github.com/He0607/HighSpeedTicket.git",
|
||||
"desc": "高铁(火车)票查询插件"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,12 +1,10 @@
|
||||
{
|
||||
"tools": [
|
||||
"python",
|
||||
"url-get",
|
||||
"terminal",
|
||||
"meteo"
|
||||
],
|
||||
"kwargs": {
|
||||
"debug": true,
|
||||
"debug": false,
|
||||
"no_default": false,
|
||||
"model_name": "gpt-3.5-turbo"
|
||||
}
|
||||
|
||||
@@ -22,11 +22,13 @@ class Tool(Plugin):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context
|
||||
|
||||
self.app = self._reset_app()
|
||||
|
||||
if not self.tool_config.get("tools"):
|
||||
logger.warn("[tool] init failed, ignore ")
|
||||
raise Exception("config.json not found")
|
||||
logger.info("[tool] inited")
|
||||
|
||||
|
||||
def get_help_text(self, verbose=False, **kwargs):
|
||||
help_text = "这是一个能让chatgpt联网,搜索,数字运算的插件,将赋予强大且丰富的扩展能力。"
|
||||
trigger_prefix = conf().get("plugin_trigger_prefix", "$")
|
||||
@@ -137,7 +139,7 @@ class Tool(Plugin):
|
||||
|
||||
return {
|
||||
# 全局配置相关
|
||||
"log": True, # tool 日志开关
|
||||
"log": False, # tool 日志开关
|
||||
"debug": kwargs.get("debug", False), # 输出更多日志
|
||||
"no_default": kwargs.get("no_default", False), # 不要默认的工具,只加载自己导入的工具
|
||||
"think_depth": kwargs.get("think_depth", 2), # 一个问题最多使用多少次工具
|
||||
|
||||
@@ -7,8 +7,10 @@ gTTS>=2.3.1 # google text to speech
|
||||
pyttsx3>=2.90 # pytsx text to speech
|
||||
baidu_aip>=4.16.10 # baidu voice
|
||||
azure-cognitiveservices-speech # azure voice
|
||||
edge-tts # edge-tts
|
||||
numpy<=1.24.2
|
||||
langid # language detect
|
||||
elevenlabs==1.0.3 # elevenlabs TTS
|
||||
|
||||
#install plugin
|
||||
dulwich
|
||||
@@ -25,6 +27,8 @@ websocket-client==1.2.0
|
||||
|
||||
# claude bot
|
||||
curl_cffi
|
||||
# claude API
|
||||
anthropic
|
||||
|
||||
# tongyi qwen
|
||||
broadscope_bailian
|
||||
@@ -32,8 +36,11 @@ broadscope_bailian
|
||||
# google
|
||||
google-generativeai
|
||||
|
||||
# linkai
|
||||
linkai
|
||||
|
||||
# dingtalk
|
||||
dingtalk_stream
|
||||
|
||||
# zhipuai
|
||||
zhipuai>=2.0.1
|
||||
|
||||
# tongyi qwen new sdk
|
||||
dashscope
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
openai==0.27.8
|
||||
HTMLParser>=0.0.2
|
||||
PyQRCode>=1.2.1
|
||||
qrcode>=7.4.2
|
||||
PyQRCode==1.2.1
|
||||
qrcode==7.4.2
|
||||
requests>=2.28.2
|
||||
chardet>=5.1.0
|
||||
Pillow
|
||||
pre-commit
|
||||
web.py
|
||||
linkai>=0.0.6.0
|
||||
|
||||
192
run.sh
Normal file
192
run.sh
Normal file
@@ -0,0 +1,192 @@
|
||||
#!/usr/bin/env bash
|
||||
set -e
|
||||
|
||||
# 颜色定义
|
||||
RED='\033[0;31m' # 红色
|
||||
GREEN='\033[0;32m' # 绿色
|
||||
YELLOW='\033[0;33m' # 黄色
|
||||
BLUE='\033[0;34m' # 蓝色
|
||||
NC='\033[0m' # 无颜色
|
||||
|
||||
# 获取当前脚本的目录
|
||||
export BASE_DIR=$(cd "$(dirname "$0")"; pwd)
|
||||
echo -e "${GREEN}📁 BASE_DIR: ${BASE_DIR}${NC}"
|
||||
|
||||
# 检查 config.json 文件是否存在
|
||||
check_config_file() {
|
||||
if [ ! -f "${BASE_DIR}/config.json" ]; then
|
||||
echo -e "${RED}❌ 错误:未找到 config.json 文件。请确保 config.json 存在于当前目录。${NC}"
|
||||
exit 1
|
||||
fi
|
||||
}
|
||||
|
||||
# 检查 Python 版本是否大于等于 3.7,并检查 pip 是否可用
|
||||
check_python_version() {
|
||||
if ! command -v python3 &> /dev/null; then
|
||||
echo -e "${RED}❌ 错误:未找到 Python3。请安装 Python 3.7 或以上版本。${NC}"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
PYTHON_VERSION=$(python3 -c 'import sys; print(f"{sys.version_info.major}.{sys.version_info.minor}")')
|
||||
PYTHON_MAJOR=$(echo "$PYTHON_VERSION" | cut -d'.' -f1)
|
||||
PYTHON_MINOR=$(echo "$PYTHON_VERSION" | cut -d'.' -f2)
|
||||
|
||||
if (( PYTHON_MAJOR < 3 || (PYTHON_MAJOR == 3 && PYTHON_MINOR < 7) )); then
|
||||
echo -e "${RED}❌ 错误:Python 版本为 ${PYTHON_VERSION}。请安装 Python 3.7 或以上版本。${NC}"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if ! python3 -m pip --version &> /dev/null; then
|
||||
echo -e "${RED}❌ 错误:未找到 pip。请安装 pip。${NC}"
|
||||
exit 1
|
||||
fi
|
||||
}
|
||||
|
||||
# 检查并安装缺失的依赖
|
||||
install_dependencies() {
|
||||
echo -e "${YELLOW}⏳ 正在安装依赖...${NC}"
|
||||
|
||||
if [ ! -f "${BASE_DIR}/requirements.txt" ]; then
|
||||
echo -e "${RED}❌ 错误:未找到 requirements.txt 文件。${NC}"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# 安装 requirements.txt 中的依赖,使用清华大学的 PyPI 镜像
|
||||
pip3 install -r "${BASE_DIR}/requirements.txt" -i https://pypi.tuna.tsinghua.edu.cn/simple
|
||||
|
||||
# 处理 requirements-optional.txt(如果存在)
|
||||
if [ -f "${BASE_DIR}/requirements-optional.txt" ]; then
|
||||
echo -e "${YELLOW}⏳ 正在安装可选的依赖...${NC}"
|
||||
pip3 install -r "${BASE_DIR}/requirements-optional.txt" -i https://pypi.tuna.tsinghua.edu.cn/simple
|
||||
fi
|
||||
}
|
||||
|
||||
# 启动项目
|
||||
run_project() {
|
||||
echo -e "${GREEN}🚀 准备启动项目...${NC}"
|
||||
cd "${BASE_DIR}"
|
||||
sleep 2
|
||||
|
||||
|
||||
# 判断操作系统类型
|
||||
OS_TYPE=$(uname)
|
||||
|
||||
if [[ "$OS_TYPE" == "Linux" ]]; then
|
||||
# 在 Linux 上使用 setsid
|
||||
setsid python3 "${BASE_DIR}/app.py" > "${BASE_DIR}/nohup.out" 2>&1 &
|
||||
echo -e "${GREEN}🚀 正在启动 ChatGPT-on-WeChat (Linux)...${NC}"
|
||||
elif [[ "$OS_TYPE" == "Darwin" ]]; then
|
||||
# 在 macOS 上直接运行
|
||||
python3 "${BASE_DIR}/app.py" > "${BASE_DIR}/nohup.out" 2>&1 &
|
||||
echo -e "${GREEN}🚀 正在启动 ChatGPT-on-WeChat (macOS)...${NC}"
|
||||
else
|
||||
echo -e "${RED}❌ 错误:不支持的操作系统 ${OS_TYPE}。${NC}"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
sleep 2
|
||||
# 显示日志输出,供用户扫码
|
||||
tail -n 30 -f "${BASE_DIR}/nohup.out"
|
||||
|
||||
}
|
||||
# 更新项目
|
||||
update_project() {
|
||||
echo -e "${GREEN}🔄 准备更新项目,现在停止项目...${NC}"
|
||||
cd "${BASE_DIR}"
|
||||
|
||||
# 停止项目
|
||||
stop_project
|
||||
echo -e "${GREEN}🔄 开始更新项目...${NC}"
|
||||
# 更新代码,从 git 仓库拉取最新代码
|
||||
if [ -d .git ]; then
|
||||
GIT_PULL_OUTPUT=$(git pull)
|
||||
if [ $? -eq 0 ]; then
|
||||
if [[ "$GIT_PULL_OUTPUT" == *"Already up to date."* ]]; then
|
||||
echo -e "${GREEN}✅ 代码已经是最新的。${NC}"
|
||||
else
|
||||
echo -e "${GREEN}✅ 代码更新完成。${NC}"
|
||||
fi
|
||||
else
|
||||
echo -e "${YELLOW}⚠️ 从 GitHub 更新失败,尝试切换到 Gitee 仓库...${NC}"
|
||||
# 更改远程仓库为 Gitee
|
||||
git remote set-url origin https://gitee.com/zhayujie/chatgpt-on-wechat.git
|
||||
GIT_PULL_OUTPUT=$(git pull)
|
||||
if [ $? -eq 0 ]; then
|
||||
if [[ "$GIT_PULL_OUTPUT" == *"Already up to date."* ]]; then
|
||||
echo -e "${GREEN}✅ 代码已经是最新的。${NC}"
|
||||
else
|
||||
echo -e "${GREEN}✅ 从 Gitee 更新成功。${NC}"
|
||||
fi
|
||||
else
|
||||
echo -e "${RED}❌ 错误:从 Gitee 更新仍然失败,请检查网络连接。${NC}"
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
else
|
||||
echo -e "${RED}❌ 错误:当前目录不是 git 仓库,无法更新代码。${NC}"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# 安装依赖
|
||||
install_dependencies
|
||||
|
||||
# 启动项目
|
||||
run_project
|
||||
}
|
||||
|
||||
# 停止项目
|
||||
stop_project() {
|
||||
echo -e "${GREEN}🛑 正在停止项目...${NC}"
|
||||
cd "${BASE_DIR}"
|
||||
pid=$(ps ax | grep -i app.py | grep "${BASE_DIR}" | grep python3 | grep -v grep | awk '{print $1}')
|
||||
if [ -z "$pid" ] ; then
|
||||
echo -e "${YELLOW}⚠️ 未找到正在运行的 ChatGPT-on-WeChat。${NC}"
|
||||
return
|
||||
fi
|
||||
|
||||
echo -e "${GREEN}🛑 正在运行的 ChatGPT-on-WeChat (PID: ${pid})${NC}"
|
||||
|
||||
kill ${pid}
|
||||
sleep 3
|
||||
|
||||
if ps -p $pid > /dev/null; then
|
||||
echo -e "${YELLOW}⚠️ 进程未停止,尝试强制终止...${NC}"
|
||||
kill -9 ${pid}
|
||||
fi
|
||||
|
||||
echo -e "${GREEN}✅ 已停止 ChatGPT-on-WeChat (PID: ${pid})${NC}"
|
||||
}
|
||||
|
||||
# 主函数,根据用户参数执行操作
|
||||
case "$1" in
|
||||
start)
|
||||
check_config_file
|
||||
check_python_version
|
||||
run_project
|
||||
;;
|
||||
stop)
|
||||
stop_project
|
||||
;;
|
||||
restart)
|
||||
stop_project
|
||||
check_config_file
|
||||
check_python_version
|
||||
run_project
|
||||
;;
|
||||
update)
|
||||
check_config_file
|
||||
check_python_version
|
||||
update_project
|
||||
;;
|
||||
*)
|
||||
echo -e "${YELLOW}=========================================${NC}"
|
||||
echo -e "${YELLOW}用法:${GREEN}$0 ${BLUE}{start|stop|restart|update}${NC}"
|
||||
echo -e "${YELLOW}示例:${NC}"
|
||||
echo -e " ${GREEN}$0 ${BLUE}start${NC}"
|
||||
echo -e " ${GREEN}$0 ${BLUE}stop${NC}"
|
||||
echo -e " ${GREEN}$0 ${BLUE}restart${NC}"
|
||||
echo -e " ${GREEN}$0 ${BLUE}update${NC}"
|
||||
echo -e "${YELLOW}=========================================${NC}"
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
@@ -8,6 +8,7 @@ Description:
|
||||
|
||||
"""
|
||||
|
||||
import http.client
|
||||
import json
|
||||
import time
|
||||
import requests
|
||||
@@ -61,6 +62,69 @@ def text_to_speech_aliyun(url, text, appkey, token):
|
||||
|
||||
return output_file
|
||||
|
||||
def speech_to_text_aliyun(url, audioContent, appkey, token):
|
||||
"""
|
||||
使用阿里云的语音识别服务识别音频文件中的语音。
|
||||
|
||||
参数:
|
||||
- url (str): 阿里云语音识别服务的端点URL。
|
||||
- audioContent (byte): pcm音频数据。
|
||||
- appkey (str): 您的阿里云appkey。
|
||||
- token (str): 阿里云API的认证令牌。
|
||||
|
||||
返回值:
|
||||
- str: 成功时输出识别到的文本,否则为None。
|
||||
"""
|
||||
format = 'pcm'
|
||||
sample_rate = 16000
|
||||
enablePunctuationPrediction = True
|
||||
enableInverseTextNormalization = True
|
||||
enableVoiceDetection = False
|
||||
|
||||
# 设置RESTful请求参数
|
||||
request = url + '?appkey=' + appkey
|
||||
request = request + '&format=' + format
|
||||
request = request + '&sample_rate=' + str(sample_rate)
|
||||
|
||||
if enablePunctuationPrediction :
|
||||
request = request + '&enable_punctuation_prediction=' + 'true'
|
||||
|
||||
if enableInverseTextNormalization :
|
||||
request = request + '&enable_inverse_text_normalization=' + 'true'
|
||||
|
||||
if enableVoiceDetection :
|
||||
request = request + '&enable_voice_detection=' + 'true'
|
||||
|
||||
host = 'nls-gateway-cn-shanghai.aliyuncs.com'
|
||||
|
||||
# 设置HTTPS请求头部
|
||||
httpHeaders = {
|
||||
'X-NLS-Token': token,
|
||||
'Content-type': 'application/octet-stream',
|
||||
'Content-Length': len(audioContent)
|
||||
}
|
||||
|
||||
conn = http.client.HTTPSConnection(host)
|
||||
conn.request(method='POST', url=request, body=audioContent, headers=httpHeaders)
|
||||
|
||||
response = conn.getresponse()
|
||||
body = response.read()
|
||||
try:
|
||||
body = json.loads(body)
|
||||
status = body['status']
|
||||
if status == 20000000 :
|
||||
result = body['result']
|
||||
if result :
|
||||
logger.info(f"阿里云语音识别到了:{result}")
|
||||
conn.close()
|
||||
return result
|
||||
else :
|
||||
logger.error(f"语音识别失败,状态码: {status}")
|
||||
except ValueError:
|
||||
logger.error(f"语音识别失败,收到非JSON格式的数据: {body}")
|
||||
conn.close()
|
||||
return None
|
||||
|
||||
|
||||
class AliyunTokenGenerator:
|
||||
"""
|
||||
|
||||
@@ -15,9 +15,9 @@ import time
|
||||
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from common.log import logger
|
||||
from voice.audio_convert import get_pcm_from_wav
|
||||
from voice.voice import Voice
|
||||
from voice.ali.ali_api import AliyunTokenGenerator
|
||||
from voice.ali.ali_api import text_to_speech_aliyun
|
||||
from voice.ali.ali_api import AliyunTokenGenerator, speech_to_text_aliyun, text_to_speech_aliyun
|
||||
from config import conf
|
||||
|
||||
|
||||
@@ -34,7 +34,8 @@ class AliVoice(Voice):
|
||||
self.token = None
|
||||
self.token_expire_time = 0
|
||||
# 默认复用阿里云千问的 access_key 和 access_secret
|
||||
self.api_url = config.get("api_url")
|
||||
self.api_url_voice_to_text = config.get("api_url_voice_to_text")
|
||||
self.api_url_text_to_voice = config.get("api_url_text_to_voice")
|
||||
self.app_key = config.get("app_key")
|
||||
self.access_key_id = conf().get("qwen_access_key_id") or config.get("access_key_id")
|
||||
self.access_key_secret = conf().get("qwen_access_key_secret") or config.get("access_key_secret")
|
||||
@@ -53,7 +54,7 @@ class AliVoice(Voice):
|
||||
r'äöüÄÖÜáéíóúÁÉÍÓÚàèìòùÀÈÌÒÙâêîôûÂÊÎÔÛçÇñÑ,。!?,.]', '', text)
|
||||
# 提取有效的token
|
||||
token_id = self.get_valid_token()
|
||||
fileName = text_to_speech_aliyun(self.api_url, text, self.app_key, token_id)
|
||||
fileName = text_to_speech_aliyun(self.api_url_text_to_voice, text, self.app_key, token_id)
|
||||
if fileName:
|
||||
logger.info("[Ali] textToVoice text={} voice file name={}".format(text, fileName))
|
||||
reply = Reply(ReplyType.VOICE, fileName)
|
||||
@@ -61,6 +62,25 @@ class AliVoice(Voice):
|
||||
reply = Reply(ReplyType.ERROR, "抱歉,语音合成失败")
|
||||
return reply
|
||||
|
||||
def voiceToText(self, voice_file):
|
||||
"""
|
||||
将语音文件转换为文本。
|
||||
|
||||
:param voice_file: 要转换的语音文件。
|
||||
:return: 返回一个Reply对象,其中包含转换得到的文本或错误信息。
|
||||
"""
|
||||
# 提取有效的token
|
||||
token_id = self.get_valid_token()
|
||||
logger.debug("[Ali] voice file name={}".format(voice_file))
|
||||
pcm = get_pcm_from_wav(voice_file)
|
||||
text = speech_to_text_aliyun(self.api_url_voice_to_text, pcm, self.app_key, token_id)
|
||||
if text:
|
||||
logger.info("[Ali] VoicetoText = {}".format(text))
|
||||
reply = Reply(ReplyType.TEXT, text)
|
||||
else:
|
||||
reply = Reply(ReplyType.ERROR, "抱歉,语音识别失败")
|
||||
return reply
|
||||
|
||||
def get_valid_token(self):
|
||||
"""
|
||||
获取有效的阿里云token。
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
{
|
||||
"api_url": "https://nls-gateway-cn-shanghai.aliyuncs.com/stream/v1/tts",
|
||||
"api_url_text_to_voice": "https://nls-gateway-cn-shanghai.aliyuncs.com/stream/v1/tts",
|
||||
"api_url_voice_to_text": "https://nls-gateway.cn-shanghai.aliyuncs.com/stream/v1/asr",
|
||||
"app_key": "",
|
||||
"access_key_id": "",
|
||||
"access_key_secret": ""
|
||||
|
||||
@@ -6,7 +6,7 @@ from common.log import logger
|
||||
try:
|
||||
import pysilk
|
||||
except ImportError:
|
||||
logger.warn("import pysilk failed, wechaty voice message will not be supported.")
|
||||
logger.debug("import pysilk failed, wechaty voice message will not be supported.")
|
||||
|
||||
from pydub import AudioSegment
|
||||
|
||||
@@ -64,7 +64,9 @@ def any_to_wav(any_path, wav_path):
|
||||
if any_path.endswith(".sil") or any_path.endswith(".silk") or any_path.endswith(".slk"):
|
||||
return sil_to_wav(any_path, wav_path)
|
||||
audio = AudioSegment.from_file(any_path)
|
||||
audio.export(wav_path, format="wav")
|
||||
audio.set_frame_rate(8000) # 百度语音转写支持8000采样率, pcm_s16le, 单通道语音识别
|
||||
audio.set_channels(1)
|
||||
audio.export(wav_path, format="wav", codec='pcm_s16le')
|
||||
|
||||
|
||||
def any_to_sil(any_path, sil_path):
|
||||
|
||||
@@ -65,7 +65,7 @@ class AzureVoice(Voice):
|
||||
reply = Reply(ReplyType.TEXT, result.text)
|
||||
else:
|
||||
cancel_details = result.cancellation_details
|
||||
logger.error("[Azure] voiceToText error, result={}, errordetails={}".format(result, cancel_details.error_details))
|
||||
logger.error("[Azure] voiceToText error, result={}, errordetails={}".format(result, cancel_details))
|
||||
reply = Reply(ReplyType.ERROR, "抱歉,语音识别失败")
|
||||
return reply
|
||||
|
||||
|
||||
50
voice/edge/edge_voice.py
Normal file
50
voice/edge/edge_voice.py
Normal file
@@ -0,0 +1,50 @@
|
||||
import time
|
||||
|
||||
import edge_tts
|
||||
import asyncio
|
||||
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from common.log import logger
|
||||
from common.tmp_dir import TmpDir
|
||||
from voice.voice import Voice
|
||||
|
||||
|
||||
class EdgeVoice(Voice):
|
||||
|
||||
def __init__(self):
|
||||
'''
|
||||
# 普通话
|
||||
zh-CN-XiaoxiaoNeural
|
||||
zh-CN-XiaoyiNeural
|
||||
zh-CN-YunjianNeural
|
||||
zh-CN-YunxiNeural
|
||||
zh-CN-YunxiaNeural
|
||||
zh-CN-YunyangNeural
|
||||
# 地方口音
|
||||
zh-CN-liaoning-XiaobeiNeural
|
||||
zh-CN-shaanxi-XiaoniNeural
|
||||
# 粤语
|
||||
zh-HK-HiuGaaiNeural
|
||||
zh-HK-HiuMaanNeural
|
||||
zh-HK-WanLungNeural
|
||||
# 湾湾腔
|
||||
zh-TW-HsiaoChenNeural
|
||||
zh-TW-HsiaoYuNeural
|
||||
zh-TW-YunJheNeural
|
||||
'''
|
||||
self.voice = "zh-CN-YunjianNeural"
|
||||
|
||||
def voiceToText(self, voice_file):
|
||||
pass
|
||||
|
||||
async def gen_voice(self, text, fileName):
|
||||
communicate = edge_tts.Communicate(text, self.voice)
|
||||
await communicate.save(fileName)
|
||||
|
||||
def textToVoice(self, text):
|
||||
fileName = TmpDir().path() + "reply-" + str(int(time.time())) + "-" + str(hash(text) & 0x7FFFFFFF) + ".mp3"
|
||||
|
||||
asyncio.run(self.gen_voice(text, fileName))
|
||||
|
||||
logger.info("[EdgeTTS] textToVoice text={} voice file name={}".format(text, fileName))
|
||||
return Reply(ReplyType.VOICE, fileName)
|
||||
@@ -1,7 +1,7 @@
|
||||
import time
|
||||
|
||||
from elevenlabs import set_api_key,generate
|
||||
|
||||
from elevenlabs.client import ElevenLabs
|
||||
from elevenlabs import save
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from common.log import logger
|
||||
from common.tmp_dir import TmpDir
|
||||
@@ -9,7 +9,7 @@ from voice.voice import Voice
|
||||
from config import conf
|
||||
|
||||
XI_API_KEY = conf().get("xi_api_key")
|
||||
set_api_key(XI_API_KEY)
|
||||
client = ElevenLabs(api_key=XI_API_KEY)
|
||||
name = conf().get("xi_voice_id")
|
||||
|
||||
class ElevenLabsVoice(Voice):
|
||||
@@ -21,13 +21,12 @@ class ElevenLabsVoice(Voice):
|
||||
pass
|
||||
|
||||
def textToVoice(self, text):
|
||||
audio = generate(
|
||||
audio = client.generate(
|
||||
text=text,
|
||||
voice=name,
|
||||
model='eleven_multilingual_v1'
|
||||
model='eleven_multilingual_v2'
|
||||
)
|
||||
fileName = TmpDir().path() + "reply-" + str(int(time.time())) + "-" + str(hash(text) & 0x7FFFFFFF) + ".mp3"
|
||||
with open(fileName, "wb") as f:
|
||||
f.write(audio)
|
||||
save(audio, fileName)
|
||||
logger.info("[ElevenLabs] textToVoice text={} voice file name={}".format(text, fileName))
|
||||
return Reply(ReplyType.VOICE, fileName)
|
||||
@@ -42,4 +42,12 @@ def create_voice(voice_type):
|
||||
from voice.ali.ali_voice import AliVoice
|
||||
|
||||
return AliVoice()
|
||||
elif voice_type == "edge":
|
||||
from voice.edge.edge_voice import EdgeVoice
|
||||
|
||||
return EdgeVoice()
|
||||
elif voice_type == "xunfei":
|
||||
from voice.xunfei.xunfei_voice import XunfeiVoice
|
||||
|
||||
return XunfeiVoice()
|
||||
raise RuntimeError
|
||||
|
||||
@@ -19,7 +19,7 @@ class LinkAIVoice(Voice):
|
||||
def voiceToText(self, voice_file):
|
||||
logger.debug("[LinkVoice] voice file name={}".format(voice_file))
|
||||
try:
|
||||
url = conf().get("linkai_api_base", "https://api.link-ai.chat") + "/v1/audio/transcriptions"
|
||||
url = conf().get("linkai_api_base", "https://api.link-ai.tech") + "/v1/audio/transcriptions"
|
||||
headers = {"Authorization": "Bearer " + conf().get("linkai_api_key")}
|
||||
model = None
|
||||
if not conf().get("text_to_voice") or conf().get("voice_to_text") == "openai":
|
||||
@@ -54,7 +54,7 @@ class LinkAIVoice(Voice):
|
||||
|
||||
def textToVoice(self, text):
|
||||
try:
|
||||
url = conf().get("linkai_api_base", "https://api.link-ai.chat") + "/v1/audio/speech"
|
||||
url = conf().get("linkai_api_base", "https://api.link-ai.tech") + "/v1/audio/speech"
|
||||
headers = {"Authorization": "Bearer " + conf().get("linkai_api_key")}
|
||||
model = const.TTS_1
|
||||
if not conf().get("text_to_voice") or conf().get("text_to_voice") in ["openai", const.TTS_1, const.TTS_1_HD]:
|
||||
|
||||
@@ -21,8 +21,21 @@ class OpenaiVoice(Voice):
|
||||
logger.debug("[Openai] voice file name={}".format(voice_file))
|
||||
try:
|
||||
file = open(voice_file, "rb")
|
||||
result = openai.Audio.transcribe("whisper-1", file)
|
||||
text = result["text"]
|
||||
api_base = conf().get("open_ai_api_base") or "https://api.openai.com/v1"
|
||||
url = f'{api_base}/audio/transcriptions'
|
||||
headers = {
|
||||
'Authorization': 'Bearer ' + conf().get("open_ai_api_key"),
|
||||
# 'Content-Type': 'multipart/form-data' # 加了会报错,不知道什么原因
|
||||
}
|
||||
files = {
|
||||
"file": file,
|
||||
}
|
||||
data = {
|
||||
"model": "whisper-1",
|
||||
}
|
||||
response = requests.post(url, headers=headers, files=files, data=data)
|
||||
response_data = response.json()
|
||||
text = response_data['text']
|
||||
reply = Reply(ReplyType.TEXT, text)
|
||||
logger.info("[Openai] voiceToText text={} voice file name={}".format(text, voice_file))
|
||||
except Exception as e:
|
||||
|
||||
7
voice/xunfei/config.json.template
Normal file
7
voice/xunfei/config.json.template
Normal file
@@ -0,0 +1,7 @@
|
||||
{
|
||||
"APPID":"xxx71xxx",
|
||||
"APIKey":"xxxx69058exxxxxx",
|
||||
"APISecret":"xxxx697f0xxxxxx",
|
||||
"BusinessArgsTTS":{"aue": "lame", "sfl": 1, "auf": "audio/L16;rate=16000", "vcn": "xiaoyan", "tte": "utf8"},
|
||||
"BusinessArgsASR":{"domain": "iat", "language": "zh_cn", "accent": "mandarin", "vad_eos":10000, "dwa": "wpgs"}
|
||||
}
|
||||
209
voice/xunfei/xunfei_asr.py
Normal file
209
voice/xunfei/xunfei_asr.py
Normal file
@@ -0,0 +1,209 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
#
|
||||
# Author: njnuko
|
||||
# Email: njnuko@163.com
|
||||
#
|
||||
# 这个文档是基于官方的demo来改的,固体官方demo文档请参考官网
|
||||
#
|
||||
# 语音听写流式 WebAPI 接口调用示例 接口文档(必看):https://doc.xfyun.cn/rest_api/语音听写(流式版).html
|
||||
# webapi 听写服务参考帖子(必看):http://bbs.xfyun.cn/forum.php?mod=viewthread&tid=38947&extra=
|
||||
# 语音听写流式WebAPI 服务,热词使用方式:登陆开放平台https://www.xfyun.cn/后,找到控制台--我的应用---语音听写(流式)---服务管理--个性化热词,
|
||||
# 设置热词
|
||||
# 注意:热词只能在识别的时候会增加热词的识别权重,需要注意的是增加相应词条的识别率,但并不是绝对的,具体效果以您测试为准。
|
||||
# 语音听写流式WebAPI 服务,方言试用方法:登陆开放平台https://www.xfyun.cn/后,找到控制台--我的应用---语音听写(流式)---服务管理--识别语种列表
|
||||
# 可添加语种或方言,添加后会显示该方言的参数值
|
||||
# 错误码链接:https://www.xfyun.cn/document/error-code (code返回错误码时必看)
|
||||
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
|
||||
|
||||
import websocket
|
||||
import datetime
|
||||
import hashlib
|
||||
import base64
|
||||
import hmac
|
||||
import json
|
||||
from urllib.parse import urlencode
|
||||
import time
|
||||
import ssl
|
||||
from wsgiref.handlers import format_date_time
|
||||
from datetime import datetime
|
||||
from time import mktime
|
||||
import _thread as thread
|
||||
import os
|
||||
import wave
|
||||
|
||||
|
||||
STATUS_FIRST_FRAME = 0 # 第一帧的标识
|
||||
STATUS_CONTINUE_FRAME = 1 # 中间帧标识
|
||||
STATUS_LAST_FRAME = 2 # 最后一帧的标识
|
||||
|
||||
#############
|
||||
#whole_dict 是用来存储返回值的,由于带语音修正,所以用dict来存储,有更新的化pop之前的值,最后再合并
|
||||
global whole_dict
|
||||
#这个文档是官方文档改的,这个参数是用来做函数调用时用的
|
||||
global wsParam
|
||||
##############
|
||||
|
||||
|
||||
class Ws_Param(object):
|
||||
# 初始化
|
||||
def __init__(self, APPID, APIKey, APISecret,BusinessArgs, AudioFile):
|
||||
self.APPID = APPID
|
||||
self.APIKey = APIKey
|
||||
self.APISecret = APISecret
|
||||
self.AudioFile = AudioFile
|
||||
self.BusinessArgs = BusinessArgs
|
||||
# 公共参数(common)
|
||||
self.CommonArgs = {"app_id": self.APPID}
|
||||
# 业务参数(business),更多个性化参数可在官网查看
|
||||
#self.BusinessArgs = {"domain": "iat", "language": "zh_cn", "accent": "mandarin", "vinfo":1,"vad_eos":10000}
|
||||
|
||||
# 生成url
|
||||
def create_url(self):
|
||||
url = 'wss://ws-api.xfyun.cn/v2/iat'
|
||||
# 生成RFC1123格式的时间戳
|
||||
now = datetime.now()
|
||||
date = format_date_time(mktime(now.timetuple()))
|
||||
|
||||
# 拼接字符串
|
||||
signature_origin = "host: " + "ws-api.xfyun.cn" + "\n"
|
||||
signature_origin += "date: " + date + "\n"
|
||||
signature_origin += "GET " + "/v2/iat " + "HTTP/1.1"
|
||||
# 进行hmac-sha256进行加密
|
||||
signature_sha = hmac.new(self.APISecret.encode('utf-8'), signature_origin.encode('utf-8'),
|
||||
digestmod=hashlib.sha256).digest()
|
||||
signature_sha = base64.b64encode(signature_sha).decode(encoding='utf-8')
|
||||
|
||||
authorization_origin = "api_key=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"" % (
|
||||
self.APIKey, "hmac-sha256", "host date request-line", signature_sha)
|
||||
authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')
|
||||
# 将请求的鉴权参数组合为字典
|
||||
v = {
|
||||
"authorization": authorization,
|
||||
"date": date,
|
||||
"host": "ws-api.xfyun.cn"
|
||||
}
|
||||
# 拼接鉴权参数,生成url
|
||||
url = url + '?' + urlencode(v)
|
||||
#print("date: ",date)
|
||||
#print("v: ",v)
|
||||
# 此处打印出建立连接时候的url,参考本demo的时候可取消上方打印的注释,比对相同参数时生成的url与自己代码生成的url是否一致
|
||||
#print('websocket url :', url)
|
||||
return url
|
||||
|
||||
|
||||
# 收到websocket消息的处理
|
||||
def on_message(ws, message):
|
||||
global whole_dict
|
||||
try:
|
||||
code = json.loads(message)["code"]
|
||||
sid = json.loads(message)["sid"]
|
||||
if code != 0:
|
||||
errMsg = json.loads(message)["message"]
|
||||
print("sid:%s call error:%s code is:%s" % (sid, errMsg, code))
|
||||
else:
|
||||
temp1 = json.loads(message)["data"]["result"]
|
||||
data = json.loads(message)["data"]["result"]["ws"]
|
||||
sn = temp1["sn"]
|
||||
if "rg" in temp1.keys():
|
||||
rep = temp1["rg"]
|
||||
rep_start = rep[0]
|
||||
rep_end = rep[1]
|
||||
for sn in range(rep_start,rep_end+1):
|
||||
#print("before pop",whole_dict)
|
||||
#print("sn",sn)
|
||||
whole_dict.pop(sn,None)
|
||||
#print("after pop",whole_dict)
|
||||
results = ""
|
||||
for i in data:
|
||||
for w in i["cw"]:
|
||||
results += w["w"]
|
||||
whole_dict[sn]=results
|
||||
#print("after add",whole_dict)
|
||||
else:
|
||||
results = ""
|
||||
for i in data:
|
||||
for w in i["cw"]:
|
||||
results += w["w"]
|
||||
whole_dict[sn]=results
|
||||
#print("sid:%s call success!,data is:%s" % (sid, json.dumps(data, ensure_ascii=False)))
|
||||
except Exception as e:
|
||||
print("receive msg,but parse exception:", e)
|
||||
|
||||
|
||||
|
||||
# 收到websocket错误的处理
|
||||
def on_error(ws, error):
|
||||
print("### error:", error)
|
||||
|
||||
|
||||
# 收到websocket关闭的处理
|
||||
def on_close(ws,a,b):
|
||||
print("### closed ###")
|
||||
|
||||
|
||||
# 收到websocket连接建立的处理
|
||||
def on_open(ws):
|
||||
global wsParam
|
||||
def run(*args):
|
||||
frameSize = 8000 # 每一帧的音频大小
|
||||
intervel = 0.04 # 发送音频间隔(单位:s)
|
||||
status = STATUS_FIRST_FRAME # 音频的状态信息,标识音频是第一帧,还是中间帧、最后一帧
|
||||
|
||||
with wave.open(wsParam.AudioFile, "rb") as fp:
|
||||
while True:
|
||||
buf = fp.readframes(frameSize)
|
||||
# 文件结束
|
||||
if not buf:
|
||||
status = STATUS_LAST_FRAME
|
||||
# 第一帧处理
|
||||
# 发送第一帧音频,带business 参数
|
||||
# appid 必须带上,只需第一帧发送
|
||||
if status == STATUS_FIRST_FRAME:
|
||||
d = {"common": wsParam.CommonArgs,
|
||||
"business": wsParam.BusinessArgs,
|
||||
"data": {"status": 0, "format": "audio/L16;rate=16000","audio": str(base64.b64encode(buf), 'utf-8'), "encoding": "raw"}}
|
||||
d = json.dumps(d)
|
||||
ws.send(d)
|
||||
status = STATUS_CONTINUE_FRAME
|
||||
# 中间帧处理
|
||||
elif status == STATUS_CONTINUE_FRAME:
|
||||
d = {"data": {"status": 1, "format": "audio/L16;rate=16000",
|
||||
"audio": str(base64.b64encode(buf), 'utf-8'),
|
||||
"encoding": "raw"}}
|
||||
ws.send(json.dumps(d))
|
||||
# 最后一帧处理
|
||||
elif status == STATUS_LAST_FRAME:
|
||||
d = {"data": {"status": 2, "format": "audio/L16;rate=16000",
|
||||
"audio": str(base64.b64encode(buf), 'utf-8'),
|
||||
"encoding": "raw"}}
|
||||
ws.send(json.dumps(d))
|
||||
time.sleep(1)
|
||||
break
|
||||
# 模拟音频采样间隔
|
||||
time.sleep(intervel)
|
||||
ws.close()
|
||||
|
||||
thread.start_new_thread(run, ())
|
||||
|
||||
#提供给xunfei_voice调用的函数
|
||||
def xunfei_asr(APPID,APISecret,APIKey,BusinessArgsASR,AudioFile):
|
||||
global whole_dict
|
||||
global wsParam
|
||||
whole_dict = {}
|
||||
wsParam1 = Ws_Param(APPID=APPID, APISecret=APISecret,
|
||||
APIKey=APIKey,BusinessArgs=BusinessArgsASR,
|
||||
AudioFile=AudioFile)
|
||||
#wsParam是global变量,给上面on_open函数调用使用的
|
||||
wsParam = wsParam1
|
||||
websocket.enableTrace(False)
|
||||
wsUrl = wsParam.create_url()
|
||||
ws = websocket.WebSocketApp(wsUrl, on_message=on_message, on_error=on_error, on_close=on_close)
|
||||
ws.on_open = on_open
|
||||
ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
|
||||
#把字典的值合并起来做最后识别的输出
|
||||
whole_words = ""
|
||||
for i in sorted(whole_dict.keys()):
|
||||
whole_words += whole_dict[i]
|
||||
return whole_words
|
||||
|
||||
|
||||
163
voice/xunfei/xunfei_tts.py
Normal file
163
voice/xunfei/xunfei_tts.py
Normal file
@@ -0,0 +1,163 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
#
|
||||
# Author: njnuko
|
||||
# Email: njnuko@163.com
|
||||
#
|
||||
# 这个文档是基于官方的demo来改的,固体官方demo文档请参考官网
|
||||
#
|
||||
# 语音听写流式 WebAPI 接口调用示例 接口文档(必看):https://doc.xfyun.cn/rest_api/语音听写(流式版).html
|
||||
# webapi 听写服务参考帖子(必看):http://bbs.xfyun.cn/forum.php?mod=viewthread&tid=38947&extra=
|
||||
# 语音听写流式WebAPI 服务,热词使用方式:登陆开放平台https://www.xfyun.cn/后,找到控制台--我的应用---语音听写(流式)---服务管理--个性化热词,
|
||||
# 设置热词
|
||||
# 注意:热词只能在识别的时候会增加热词的识别权重,需要注意的是增加相应词条的识别率,但并不是绝对的,具体效果以您测试为准。
|
||||
# 语音听写流式WebAPI 服务,方言试用方法:登陆开放平台https://www.xfyun.cn/后,找到控制台--我的应用---语音听写(流式)---服务管理--识别语种列表
|
||||
# 可添加语种或方言,添加后会显示该方言的参数值
|
||||
# 错误码链接:https://www.xfyun.cn/document/error-code (code返回错误码时必看)
|
||||
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
|
||||
import websocket
|
||||
import datetime
|
||||
import hashlib
|
||||
import base64
|
||||
import hmac
|
||||
import json
|
||||
from urllib.parse import urlencode
|
||||
import time
|
||||
import ssl
|
||||
from wsgiref.handlers import format_date_time
|
||||
from datetime import datetime
|
||||
from time import mktime
|
||||
import _thread as thread
|
||||
import os
|
||||
|
||||
|
||||
|
||||
STATUS_FIRST_FRAME = 0 # 第一帧的标识
|
||||
STATUS_CONTINUE_FRAME = 1 # 中间帧标识
|
||||
STATUS_LAST_FRAME = 2 # 最后一帧的标识
|
||||
|
||||
#############
|
||||
#这个参数是用来做输出文件路径的
|
||||
global outfile
|
||||
#这个文档是官方文档改的,这个参数是用来做函数调用时用的
|
||||
global wsParam
|
||||
##############
|
||||
|
||||
|
||||
class Ws_Param(object):
|
||||
# 初始化
|
||||
def __init__(self, APPID, APIKey, APISecret,BusinessArgs,Text):
|
||||
self.APPID = APPID
|
||||
self.APIKey = APIKey
|
||||
self.APISecret = APISecret
|
||||
self.BusinessArgs = BusinessArgs
|
||||
self.Text = Text
|
||||
|
||||
# 公共参数(common)
|
||||
self.CommonArgs = {"app_id": self.APPID}
|
||||
# 业务参数(business),更多个性化参数可在官网查看
|
||||
#self.BusinessArgs = {"aue": "raw", "auf": "audio/L16;rate=16000", "vcn": "xiaoyan", "tte": "utf8"}
|
||||
self.Data = {"status": 2, "text": str(base64.b64encode(self.Text.encode('utf-8')), "UTF8")}
|
||||
#使用小语种须使用以下方式,此处的unicode指的是 utf16小端的编码方式,即"UTF-16LE"”
|
||||
#self.Data = {"status": 2, "text": str(base64.b64encode(self.Text.encode('utf-16')), "UTF8")}
|
||||
|
||||
# 生成url
|
||||
def create_url(self):
|
||||
url = 'wss://tts-api.xfyun.cn/v2/tts'
|
||||
# 生成RFC1123格式的时间戳
|
||||
now = datetime.now()
|
||||
date = format_date_time(mktime(now.timetuple()))
|
||||
|
||||
# 拼接字符串
|
||||
signature_origin = "host: " + "ws-api.xfyun.cn" + "\n"
|
||||
signature_origin += "date: " + date + "\n"
|
||||
signature_origin += "GET " + "/v2/tts " + "HTTP/1.1"
|
||||
# 进行hmac-sha256进行加密
|
||||
signature_sha = hmac.new(self.APISecret.encode('utf-8'), signature_origin.encode('utf-8'),
|
||||
digestmod=hashlib.sha256).digest()
|
||||
signature_sha = base64.b64encode(signature_sha).decode(encoding='utf-8')
|
||||
|
||||
authorization_origin = "api_key=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"" % (
|
||||
self.APIKey, "hmac-sha256", "host date request-line", signature_sha)
|
||||
authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')
|
||||
# 将请求的鉴权参数组合为字典
|
||||
v = {
|
||||
"authorization": authorization,
|
||||
"date": date,
|
||||
"host": "ws-api.xfyun.cn"
|
||||
}
|
||||
# 拼接鉴权参数,生成url
|
||||
url = url + '?' + urlencode(v)
|
||||
# print("date: ",date)
|
||||
# print("v: ",v)
|
||||
# 此处打印出建立连接时候的url,参考本demo的时候可取消上方打印的注释,比对相同参数时生成的url与自己代码生成的url是否一致
|
||||
# print('websocket url :', url)
|
||||
return url
|
||||
|
||||
def on_message(ws, message):
|
||||
#输出文件
|
||||
global outfile
|
||||
try:
|
||||
message =json.loads(message)
|
||||
code = message["code"]
|
||||
sid = message["sid"]
|
||||
audio = message["data"]["audio"]
|
||||
audio = base64.b64decode(audio)
|
||||
status = message["data"]["status"]
|
||||
if status == 2:
|
||||
print("ws is closed")
|
||||
ws.close()
|
||||
if code != 0:
|
||||
errMsg = message["message"]
|
||||
print("sid:%s call error:%s code is:%s" % (sid, errMsg, code))
|
||||
else:
|
||||
|
||||
with open(outfile, 'ab') as f:
|
||||
f.write(audio)
|
||||
|
||||
except Exception as e:
|
||||
print("receive msg,but parse exception:", e)
|
||||
|
||||
|
||||
|
||||
# 收到websocket连接建立的处理
|
||||
def on_open(ws):
|
||||
global outfile
|
||||
global wsParam
|
||||
def run(*args):
|
||||
d = {"common": wsParam.CommonArgs,
|
||||
"business": wsParam.BusinessArgs,
|
||||
"data": wsParam.Data,
|
||||
}
|
||||
d = json.dumps(d)
|
||||
# print("------>开始发送文本数据")
|
||||
ws.send(d)
|
||||
if os.path.exists(outfile):
|
||||
os.remove(outfile)
|
||||
|
||||
thread.start_new_thread(run, ())
|
||||
|
||||
# 收到websocket错误的处理
|
||||
def on_error(ws, error):
|
||||
print("### error:", error)
|
||||
|
||||
|
||||
|
||||
# 收到websocket关闭的处理
|
||||
def on_close(ws):
|
||||
print("### closed ###")
|
||||
|
||||
|
||||
|
||||
def xunfei_tts(APPID, APIKey, APISecret,BusinessArgsTTS, Text, OutFile):
|
||||
global outfile
|
||||
global wsParam
|
||||
outfile = OutFile
|
||||
wsParam1 = Ws_Param(APPID,APIKey,APISecret,BusinessArgsTTS,Text)
|
||||
wsParam = wsParam1
|
||||
websocket.enableTrace(False)
|
||||
wsUrl = wsParam.create_url()
|
||||
ws = websocket.WebSocketApp(wsUrl, on_message=on_message, on_error=on_error, on_close=on_close)
|
||||
ws.on_open = on_open
|
||||
ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
|
||||
return outfile
|
||||
|
||||
86
voice/xunfei/xunfei_voice.py
Normal file
86
voice/xunfei/xunfei_voice.py
Normal file
@@ -0,0 +1,86 @@
|
||||
#####################################################################
|
||||
# xunfei voice service
|
||||
# Auth: njnuko
|
||||
# Email: njnuko@163.com
|
||||
#
|
||||
# 要使用本模块, 首先到 xfyun.cn 注册一个开发者账号,
|
||||
# 之后创建一个新应用, 然后在应用管理的语音识别或者语音合同右边可以查看APPID API Key 和 Secret Key
|
||||
# 然后在 config.json 中填入这三个值
|
||||
#
|
||||
# 配置说明:
|
||||
# {
|
||||
# "APPID":"xxx71xxx",
|
||||
# "APIKey":"xxxx69058exxxxxx", #讯飞xfyun.cn控制台语音合成或者听写界面的APIKey
|
||||
# "APISecret":"xxxx697f0xxxxxx", #讯飞xfyun.cn控制台语音合成或者听写界面的APIKey
|
||||
# "BusinessArgsTTS":{"aue": "lame", "sfl": 1, "auf": "audio/L16;rate=16000", "vcn": "xiaoyan", "tte": "utf8"}, #语音合成的参数,具体可以参考xfyun.cn的文档
|
||||
# "BusinessArgsASR":{"domain": "iat", "language": "zh_cn", "accent": "mandarin", "vad_eos":10000, "dwa": "wpgs"} #语音听写的参数,具体可以参考xfyun.cn的文档
|
||||
# }
|
||||
#####################################################################
|
||||
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from common.log import logger
|
||||
from common.tmp_dir import TmpDir
|
||||
from config import conf
|
||||
from voice.voice import Voice
|
||||
from .xunfei_asr import xunfei_asr
|
||||
from .xunfei_tts import xunfei_tts
|
||||
from voice.audio_convert import any_to_mp3
|
||||
import shutil
|
||||
from pydub import AudioSegment
|
||||
|
||||
|
||||
class XunfeiVoice(Voice):
|
||||
def __init__(self):
|
||||
try:
|
||||
curdir = os.path.dirname(__file__)
|
||||
config_path = os.path.join(curdir, "config.json")
|
||||
conf = None
|
||||
with open(config_path, "r") as fr:
|
||||
conf = json.load(fr)
|
||||
print(conf)
|
||||
self.APPID = str(conf.get("APPID"))
|
||||
self.APIKey = str(conf.get("APIKey"))
|
||||
self.APISecret = str(conf.get("APISecret"))
|
||||
self.BusinessArgsTTS = conf.get("BusinessArgsTTS")
|
||||
self.BusinessArgsASR= conf.get("BusinessArgsASR")
|
||||
|
||||
except Exception as e:
|
||||
logger.warn("XunfeiVoice init failed: %s, ignore " % e)
|
||||
|
||||
def voiceToText(self, voice_file):
|
||||
# 识别本地文件
|
||||
try:
|
||||
logger.debug("[Xunfei] voice file name={}".format(voice_file))
|
||||
#print("voice_file===========",voice_file)
|
||||
#print("voice_file_type===========",type(voice_file))
|
||||
#mp3_name, file_extension = os.path.splitext(voice_file)
|
||||
#mp3_file = mp3_name + ".mp3"
|
||||
#pcm_data=get_pcm_from_wav(voice_file)
|
||||
#mp3_name, file_extension = os.path.splitext(voice_file)
|
||||
#AudioSegment.from_wav(voice_file).export(mp3_file, format="mp3")
|
||||
#shutil.copy2(voice_file, 'tmp/test1.wav')
|
||||
#shutil.copy2(mp3_file, 'tmp/test1.mp3')
|
||||
#print("voice and mp3 file",voice_file,mp3_file)
|
||||
text = xunfei_asr(self.APPID,self.APISecret,self.APIKey,self.BusinessArgsASR,voice_file)
|
||||
logger.info("讯飞语音识别到了: {}".format(text))
|
||||
reply = Reply(ReplyType.TEXT, text)
|
||||
except Exception as e:
|
||||
logger.warn("XunfeiVoice init failed: %s, ignore " % e)
|
||||
reply = Reply(ReplyType.ERROR, "讯飞语音识别出错了;{0}")
|
||||
return reply
|
||||
|
||||
def textToVoice(self, text):
|
||||
try:
|
||||
# Avoid the same filename under multithreading
|
||||
fileName = TmpDir().path() + "reply-" + str(int(time.time())) + "-" + str(hash(text) & 0x7FFFFFFF) + ".mp3"
|
||||
return_file = xunfei_tts(self.APPID,self.APIKey,self.APISecret,self.BusinessArgsTTS,text,fileName)
|
||||
logger.info("[Xunfei] textToVoice text={} voice file name={}".format(text, fileName))
|
||||
reply = Reply(ReplyType.VOICE, fileName)
|
||||
except Exception as e:
|
||||
logger.error("[Xunfei] textToVoice error={}".format(fileName))
|
||||
reply = Reply(ReplyType.ERROR, "抱歉,讯飞语音合成失败")
|
||||
return reply
|
||||
Reference in New Issue
Block a user