mirror of
https://github.com/zhayujie/chatgpt-on-wechat.git
synced 2026-06-03 02:27:09 +08:00
Compare commits
355 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
74ebbdd761 | ||
|
|
5346dfdd8b | ||
|
|
3ee4147285 | ||
|
|
c41e486bfc | ||
|
|
eda3ba92fd | ||
|
|
40255290b0 | ||
|
|
af5bc73dc0 | ||
|
|
0247cd4c45 | ||
|
|
d6fdf8ca2a | ||
|
|
95708489c9 | ||
|
|
ced0fa4608 | ||
|
|
7e0fbd600f | ||
|
|
f33e4e0323 | ||
|
|
d0fd78497d | ||
|
|
8045019603 | ||
|
|
7d92b9435e | ||
|
|
1e0822703a | ||
|
|
0403ff88ef | ||
|
|
78376d591b | ||
|
|
8e23d0df20 | ||
|
|
9e281d20ab | ||
|
|
644bd4a106 | ||
|
|
7729e66a96 | ||
|
|
d67d6b7948 | ||
|
|
4c4a46bfbe | ||
|
|
4536f9c177 | ||
|
|
eae95dfef5 | ||
|
|
b67d4460ca | ||
|
|
3dea8311b1 | ||
|
|
11f6e98874 | ||
|
|
2609e595f4 | ||
|
|
ac6e41abc8 | ||
|
|
9c17e16d0a | ||
|
|
55e9064307 | ||
|
|
91cabd7d49 | ||
|
|
7456950530 | ||
|
|
8fcdda625d | ||
|
|
40a10ee926 | ||
|
|
c3f7e2645c | ||
|
|
b264af1892 | ||
|
|
43e93e8e22 | ||
|
|
d6c4789688 | ||
|
|
cb31ee6f01 | ||
|
|
f7b694ac56 | ||
|
|
eb809055d4 | ||
|
|
78d9be82b2 | ||
|
|
76a95c0226 | ||
|
|
d3ab8fb04a | ||
|
|
f7a0b63a00 | ||
|
|
a21dd97786 | ||
|
|
04943c0bfa | ||
|
|
203d4d8bfb | ||
|
|
c049a619dc | ||
|
|
cc1b14b607 | ||
|
|
e04a12a8f4 | ||
|
|
a2c82bc583 | ||
|
|
b4dc382f7c | ||
|
|
eca1892e2a | ||
|
|
23a237074e | ||
|
|
219e9eca4f | ||
|
|
413e09fb9e | ||
|
|
3514c37e4c | ||
|
|
95260e303c | ||
|
|
0cef34bdfa | ||
|
|
9838979bbd | ||
|
|
c8910b8e14 | ||
|
|
207fa1d019 | ||
|
|
be0bb591e7 | ||
|
|
bfacdb9c3b | ||
|
|
ae4077ed6c | ||
|
|
6eb3c90e18 | ||
|
|
8c2a53a504 | ||
|
|
74db1e0308 | ||
|
|
b9dfdcef3d | ||
|
|
9d4afeac31 | ||
|
|
14ae2f169a | ||
|
|
55df19142f | ||
|
|
40fd545b2c | ||
|
|
95fb07343e | ||
|
|
4d87906559 | ||
|
|
6b30dced43 | ||
|
|
293a03b7c8 | ||
|
|
c010549f17 | ||
|
|
cc0be22026 | ||
|
|
e5ba26febe | ||
|
|
36f9680eec | ||
|
|
f4f5be5b08 | ||
|
|
d89b056886 | ||
|
|
65424c7db9 | ||
|
|
32a8a847fc | ||
|
|
88fb3dbf60 | ||
|
|
f6bee3aa58 | ||
|
|
5f19f37dcb | ||
|
|
dd36d8ce9e | ||
|
|
865e4b5349 | ||
|
|
e70564752b | ||
|
|
6e0d2f9437 | ||
|
|
291f936097 | ||
|
|
0b2ce48586 | ||
|
|
da87fd9e20 | ||
|
|
d4da4d2575 | ||
|
|
bad20ff483 | ||
|
|
21ad51ffbf | ||
|
|
697c6d5fbe | ||
|
|
293c659053 | ||
|
|
a12507abbd | ||
|
|
4e675b84fb | ||
|
|
c1022feab8 | ||
|
|
ddcfcf21fe | ||
|
|
86a58c3d80 | ||
|
|
abf9a9048d | ||
|
|
b1030a527a | ||
|
|
8d07ba6332 | ||
|
|
4ce37f84e4 | ||
|
|
061d8a3a5f | ||
|
|
374cd5dbb8 | ||
|
|
5ad53c2b9c | ||
|
|
a2ec1a063d | ||
|
|
e431dbe2df | ||
|
|
7218463f9e | ||
|
|
aeb09a95b0 | ||
|
|
0c8f292e12 | ||
|
|
f001ac6903 | ||
|
|
db8e506de0 | ||
|
|
099f859dd4 | ||
|
|
b7684c1c2b | ||
|
|
058c167f79 | ||
|
|
49446d4872 | ||
|
|
ced560e1e1 | ||
|
|
339102c3cd | ||
|
|
6331350239 | ||
|
|
34e06fcbf8 | ||
|
|
70aac312ff | ||
|
|
5e00704152 | ||
|
|
1a9edb6907 | ||
|
|
0c18c3a6dd | ||
|
|
847bb51ce4 | ||
|
|
fa60a5dc63 | ||
|
|
aaed3f9839 | ||
|
|
21b956b983 | ||
|
|
792e940279 | ||
|
|
c2477b26c0 | ||
|
|
4b27de809b | ||
|
|
572932d8e8 | ||
|
|
270dd778d9 | ||
|
|
dd04287b0a | ||
|
|
36ac6d005a | ||
|
|
701daedf49 | ||
|
|
238f05f453 | ||
|
|
dd082bd212 | ||
|
|
cfd2f27b0b | ||
|
|
a2160d135e | ||
|
|
16d7836369 | ||
|
|
f3de4dcc5f | ||
|
|
e34523028f | ||
|
|
efe2fbacd6 | ||
|
|
2fa1df29be | ||
|
|
f72cd13fba | ||
|
|
5b552dffbf | ||
|
|
a0ae2d13dc | ||
|
|
f7262a0a3a | ||
|
|
9736f121eb | ||
|
|
7c8fb7eacc | ||
|
|
b45eea5908 | ||
|
|
6babf4ee6c | ||
|
|
576526d4ee | ||
|
|
c03e31b7be | ||
|
|
a1aa925019 | ||
|
|
a5a234ed97 | ||
|
|
5b5dbcd78b | ||
|
|
bd1c6361d3 | ||
|
|
1fc1febf03 | ||
|
|
55cc35efa9 | ||
|
|
5ba8fdc5e7 | ||
|
|
6ea295e227 | ||
|
|
5010c76ef7 | ||
|
|
79c7f0c29f | ||
|
|
2b3e643786 | ||
|
|
90cdff327c | ||
|
|
55c116e727 | ||
|
|
3dd83aa6b7 | ||
|
|
a74aa12641 | ||
|
|
151e8c69f9 | ||
|
|
d8bfa77705 | ||
|
|
6bd286e8d5 | ||
|
|
905532b681 | ||
|
|
04d5c1ab01 | ||
|
|
28be141dc7 | ||
|
|
652b786baf | ||
|
|
ba6c671051 | ||
|
|
ca25d0433f | ||
|
|
5338106dfa | ||
|
|
b6b76be4f6 | ||
|
|
03d94fcfa0 | ||
|
|
b2c5f0d455 | ||
|
|
54f60dd38c | ||
|
|
42f181aca2 | ||
|
|
9c3a27894f | ||
|
|
f7cd348912 | ||
|
|
aeaeb75d3b | ||
|
|
96542b532e | ||
|
|
139295fe0d | ||
|
|
13217b2ce2 | ||
|
|
5cc8b56a7c | ||
|
|
e23e01c95e | ||
|
|
bca8ba12c7 | ||
|
|
3c44bdbe1c | ||
|
|
db93ed025b | ||
|
|
4209e108d0 | ||
|
|
14cbf011af | ||
|
|
03a41ec199 | ||
|
|
125fe2a026 | ||
|
|
ac4adac29e | ||
|
|
ac449d078e | ||
|
|
79be4530d4 | ||
|
|
85ce52d70c | ||
|
|
7ab56b9076 | ||
|
|
dedf976375 | ||
|
|
89f438208a | ||
|
|
ffbc5080ae | ||
|
|
4167f13bac | ||
|
|
6ba0baabb0 | ||
|
|
081003df47 | ||
|
|
559194ffb2 | ||
|
|
97a26d4a46 | ||
|
|
503c6c9b7e | ||
|
|
9a1e10deff | ||
|
|
054f927c05 | ||
|
|
22210747d0 | ||
|
|
53b2deb72c | ||
|
|
6fc158e7d6 | ||
|
|
a23a65c731 | ||
|
|
7dc7105ee2 | ||
|
|
bac70108b2 | ||
|
|
297404b21e | ||
|
|
33a7f8b558 | ||
|
|
4a670b7df7 | ||
|
|
79e4af315e | ||
|
|
c6e31b2fdc | ||
|
|
91dc44df53 | ||
|
|
7e57f8f157 | ||
|
|
15f6b7c6d3 | ||
|
|
b213ba541d | ||
|
|
7c6ed9944e | ||
|
|
a5a825e439 | ||
|
|
a4ab547f77 | ||
|
|
76ed763abe | ||
|
|
b9e3125610 | ||
|
|
8d9d5b7b6f | ||
|
|
187601da1e | ||
|
|
cc3a0fc367 | ||
|
|
44cc4165d1 | ||
|
|
f98b43514e | ||
|
|
3c9b1a14e9 | ||
|
|
827e8eddf8 | ||
|
|
7bc27d6167 | ||
|
|
ba06edd63a | ||
|
|
cacf553a5b | ||
|
|
d89091a8ea | ||
|
|
01a56e1155 | ||
|
|
a64d7c42b1 | ||
|
|
36b6cc58bf | ||
|
|
5ac8a257e7 | ||
|
|
74119d0372 | ||
|
|
4e162c73e5 | ||
|
|
5ff753a492 | ||
|
|
89400630c0 | ||
|
|
3899c0cfe3 | ||
|
|
a086f1989f | ||
|
|
1171b04e93 | ||
|
|
c55d81825a | ||
|
|
2dcd026e9f | ||
|
|
cdf8609d24 | ||
|
|
36580c5f7f | ||
|
|
1cff2521f4 | ||
|
|
db4998a56b | ||
|
|
acbd506568 | ||
|
|
0cf8e3be73 | ||
|
|
2473334dfc | ||
|
|
1ff72d1d37 | ||
|
|
241fad5524 | ||
|
|
1b48cea50a | ||
|
|
88bf345b91 | ||
|
|
ab4ff3d1a3 | ||
|
|
3502e0d643 | ||
|
|
995894d3aa | ||
|
|
4da8714124 | ||
|
|
6b247ae880 | ||
|
|
176941ea3b | ||
|
|
5176b56d3b | ||
|
|
8abf18ab25 | ||
|
|
395edbd9f4 | ||
|
|
2386eb8fc2 | ||
|
|
68208f82a0 | ||
|
|
ca916b7ce5 | ||
|
|
01e02934da | ||
|
|
c81a79f7b9 | ||
|
|
1133648bf6 | ||
|
|
e05bc541d7 | ||
|
|
d689d20482 | ||
|
|
39dd99b272 | ||
|
|
cda21acb43 | ||
|
|
9bd7d09f20 | ||
|
|
b22994c2d2 | ||
|
|
e027286b6d | ||
|
|
d6e16995e0 | ||
|
|
782bff3a51 | ||
|
|
de26dc0597 | ||
|
|
233b24ab0f | ||
|
|
2f9e5b1219 | ||
|
|
dd36b8b150 | ||
|
|
f81ac31fe1 | ||
|
|
24b63bc5bd | ||
|
|
1817a972c6 | ||
|
|
74a253f521 | ||
|
|
41762a1c57 | ||
|
|
a786fa4b75 | ||
|
|
e4c7602c0c | ||
|
|
e0d2e34980 | ||
|
|
9ef8e1be3f | ||
|
|
aae9b64833 | ||
|
|
4bab4299f2 | ||
|
|
954e55f4b4 | ||
|
|
2361e3c28c | ||
|
|
8224c2fc16 | ||
|
|
8aac86f0a9 | ||
|
|
6384e9310b | ||
|
|
7a9205dfba | ||
|
|
94b47a56f4 | ||
|
|
709b5be634 | ||
|
|
f970b2c168 | ||
|
|
973acb37ed | ||
|
|
1c9020a565 | ||
|
|
c5f1d0042c | ||
|
|
fa706e8b1d | ||
|
|
12c170f227 | ||
|
|
db27dfe227 | ||
|
|
2db4673392 | ||
|
|
38619db629 | ||
|
|
930fd436ea | ||
|
|
98b8ff2fc8 | ||
|
|
d0662683f9 | ||
|
|
957f2574a9 | ||
|
|
109b362ebd | ||
|
|
ff3fdfa738 | ||
|
|
e2636ed54a | ||
|
|
dbe2f17e1a | ||
|
|
4dc535673f | ||
|
|
f414b6408e | ||
|
|
3aa2e6a04d | ||
|
|
1963ff273f | ||
|
|
bb737a71d5 | ||
|
|
4dbc54fa15 | ||
|
|
1d4ff796d7 | ||
|
|
44cb54a9ea |
6
.github/ISSUE_TEMPLATE/config.yml
vendored
6
.github/ISSUE_TEMPLATE/config.yml
vendored
@@ -1,6 +0,0 @@
|
||||
blank_issues_enabled: false
|
||||
contact_links:
|
||||
- name: 知识星球
|
||||
url: https://public.zsxq.com/groups/88885848842852.html
|
||||
about: 如果你想了解更多项目细节,并与开发者们交流更多关于AI技术的实践,欢迎加入星球
|
||||
|
||||
72
.github/workflows/deploy-image-arm.yml
vendored
Normal file
72
.github/workflows/deploy-image-arm.yml
vendored
Normal file
@@ -0,0 +1,72 @@
|
||||
# This workflow uses actions that are not certified by GitHub.
|
||||
# They are provided by a third-party and are governed by
|
||||
# separate terms of service, privacy policy, and support
|
||||
# documentation.
|
||||
|
||||
# GitHub recommends pinning actions to a commit SHA.
|
||||
# To get a newer version, you will need to update the SHA.
|
||||
# You can also reference a tag or branch, but the action may change without warning.
|
||||
|
||||
name: Create and publish a Docker image
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: ['master']
|
||||
create:
|
||||
env:
|
||||
REGISTRY: ghcr.io
|
||||
IMAGE_NAME: ${{ github.repository }}
|
||||
|
||||
jobs:
|
||||
build-and-push-image:
|
||||
if: github.repository == 'zhayujie/chatgpt-on-wechat'
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: read
|
||||
packages: write
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@v1
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
id: buildx
|
||||
uses: docker/setup-buildx-action@v1
|
||||
|
||||
- name: Available platforms
|
||||
run: echo ${{ steps.buildx.outputs.platforms }}
|
||||
|
||||
- name: Log in to the Container registry
|
||||
uses: docker/login-action@v2
|
||||
with:
|
||||
registry: ${{ env.REGISTRY }}
|
||||
username: ${{ github.actor }}
|
||||
password: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Extract metadata (tags, labels) for Docker
|
||||
id: meta
|
||||
uses: docker/metadata-action@v4
|
||||
with:
|
||||
images: |
|
||||
${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
|
||||
|
||||
- name: Build and push Docker image
|
||||
uses: docker/build-push-action@v3
|
||||
with:
|
||||
context: .
|
||||
push: true
|
||||
file: ./docker/Dockerfile.latest
|
||||
platforms: linux/arm64
|
||||
tags: ${{ steps.meta.outputs.tags }}-arm64
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
|
||||
- uses: actions/delete-package-versions@v4
|
||||
with:
|
||||
package-name: 'chatgpt-on-wechat'
|
||||
package-type: 'container'
|
||||
min-versions-to-keep: 10
|
||||
delete-only-untagged-versions: 'true'
|
||||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
11
.github/workflows/deploy-image.yml
vendored
11
.github/workflows/deploy-image.yml
vendored
@@ -19,6 +19,7 @@ env:
|
||||
|
||||
jobs:
|
||||
build-and-push-image:
|
||||
if: github.repository == 'zhayujie/chatgpt-on-wechat'
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: read
|
||||
@@ -28,6 +29,12 @@ jobs:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v2
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
|
||||
- name: Log in to the Container registry
|
||||
uses: docker/login-action@v2
|
||||
with:
|
||||
@@ -39,7 +46,9 @@ jobs:
|
||||
id: meta
|
||||
uses: docker/metadata-action@v4
|
||||
with:
|
||||
images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
|
||||
images: |
|
||||
${{ env.IMAGE_NAME }}
|
||||
${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
|
||||
|
||||
- name: Build and push Docker image
|
||||
uses: docker/build-push-action@v3
|
||||
|
||||
8
.gitignore
vendored
8
.gitignore
vendored
@@ -1,6 +1,8 @@
|
||||
.DS_Store
|
||||
.idea
|
||||
.vscode
|
||||
.venv
|
||||
.vs
|
||||
.wechaty/
|
||||
__pycache__/
|
||||
venv*
|
||||
@@ -22,6 +24,10 @@ plugins/**/
|
||||
!plugins/tool
|
||||
!plugins/banwords
|
||||
!plugins/banwords/**/
|
||||
plugins/banwords/__pycache__
|
||||
plugins/banwords/lib/__pycache__
|
||||
!plugins/hello
|
||||
!plugins/role
|
||||
!plugins/keyword
|
||||
!plugins/keyword
|
||||
!plugins/linkai
|
||||
client_config.json
|
||||
|
||||
176
README.md
176
README.md
@@ -1,38 +1,54 @@
|
||||
# 简介
|
||||
|
||||
> ChatGPT近期以强大的对话和信息整合能力风靡全网,可以写代码、改论文、讲故事,几乎无所不能,这让人不禁有个大胆的想法,能否用他的对话模型把我们的微信打造成一个智能机器人,可以在与好友对话中给出意想不到的回应,而且再也不用担心女朋友影响我们 ~~打游戏~~ 工作了。
|
||||
> 本项目是基于大模型的智能对话机器人,支持微信、企业微信、公众号、飞书、钉钉接入,可选择GPT3.5/GPT4.0/Claude/文心一言/讯飞星火/通义千问/Gemini/LinkAI/ZhipuAI,能处理文本、语音和图片,通过插件访问操作系统和互联网等外部资源,支持基于自有知识库定制企业AI应用。
|
||||
|
||||
最新版本支持的功能如下:
|
||||
|
||||
- [x] **多端部署:** 有多种部署方式可选择且功能完备,目前已支持个人微信,微信公众号和企业微信应用等部署方式
|
||||
- [x] **基础对话:** 私聊及群聊的消息智能回复,支持多轮会话上下文记忆,支持 GPT-3,GPT-3.5,GPT-4模型
|
||||
- [x] **语音识别:** 可识别语音消息,通过文字或语音回复,支持 azure, baidu, google, openai等多种语音模型
|
||||
- [x] **图片生成:** 支持图片生成 和 图生图(如照片修复),可选择 Dell-E, stable diffusion, replicate模型
|
||||
- [x] **丰富插件:** 支持个性化插件扩展,已实现多角色切换、文字冒险、敏感词过滤、聊天记录总结等插件
|
||||
- [X] **Tool工具:** 与操作系统和互联网交互,支持最新信息搜索、数学计算、天气和资讯查询、网页总结,基于 [chatgpt-tool-hub](https://github.com/goldfishh/chatgpt-tool-hub) 实现
|
||||
|
||||
> 欢迎接入更多应用,参考 [Terminal代码](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/channel/terminal/terminal_channel.py)实现接收和发送消息逻辑即可接入。 同时欢迎增加新的插件,参考 [插件说明文档](https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins)。
|
||||
|
||||
**一键部署:**
|
||||
- 个人微信
|
||||
|
||||
[](https://railway.app/template/qApznZ?referralCode=RC3znh)
|
||||
- [x] **多端部署:** 有多种部署方式可选择且功能完备,目前已支持个人微信、微信公众号和、企业微信、飞书、钉钉等部署方式
|
||||
- [x] **基础对话:** 私聊及群聊的消息智能回复,支持多轮会话上下文记忆,支持 GPT-3.5, GPT-4, claude, Gemini, 文心一言, 讯飞星火, 通义千问,ChatGLM
|
||||
- [x] **语音能力:** 可识别语音消息,通过文字或语音回复,支持 azure, baidu, google, openai(whisper/tts) 等多种语音模型
|
||||
- [x] **图像能力:** 支持图片生成、图片识别、图生图(如照片修复),可选择 Dall-E-3, stable diffusion, replicate, midjourney, CogView-3, vision模型
|
||||
- [x] **丰富插件:** 支持个性化插件扩展,已实现多角色切换、文字冒险、敏感词过滤、聊天记录总结、文档总结和对话、联网搜索等插件
|
||||
- [x] **知识库:** 通过上传知识库文件自定义专属机器人,可作为数字分身、智能客服、私域助手使用,基于 [LinkAI](https://link-ai.tech) 实现
|
||||
|
||||
# 演示
|
||||
|
||||
https://user-images.githubusercontent.com/26161723/233777277-e3b9928e-b88f-43e2-b0e0-3cbc923bc799.mp4
|
||||
https://github.com/zhayujie/chatgpt-on-wechat/assets/26161723/d5154020-36e3-41db-8706-40ce9f3f1b1e
|
||||
|
||||
Demo made by [Visionn](https://www.wangpc.cc/)
|
||||
|
||||
# 交流群
|
||||
# 商业支持
|
||||
|
||||
添加小助手微信进群:
|
||||
> 我们还提供企业级的 **AI应用平台**,包含知识库、Agent插件、应用管理等能力,支持多平台聚合的应用接入、客户端管理、对话管理,以及提供
|
||||
SaaS服务、私有化部署、稳定托管接入 等多种模式。
|
||||
>
|
||||
> 目前已在私域运营、智能客服、企业效率助手等场景积累了丰富的 AI 解决方案, 在电商、文教、健康、新消费等各行业沉淀了 AI 落地的最佳实践,致力于打造助力中小企业拥抱 AI 的一站式平台。
|
||||
|
||||
企业服务和商用咨询可联系产品顾问:
|
||||
|
||||
<img width="240" src="https://img-1317903499.cos.ap-guangzhou.myqcloud.com/docs/product-manager-qrcode.jpg">
|
||||
|
||||
# 开源社区
|
||||
|
||||
添加小助手微信加入开源项目交流群:
|
||||
|
||||
<img width="240" src="./docs/images/contact.jpg">
|
||||
|
||||
# 更新日志
|
||||
|
||||
>**2023.06.12:** 接入 [LinkAI](https://chat.link-ai.tech/console) 平台,可在线创建 个人知识库,并接入微信中。Beta版本欢迎体验,使用参考 [接入文档](https://link-ai.tech/platform/link-app/wechat)。
|
||||
>**2023.11.11:** [1.5.3版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.5.3) 和 [1.5.4版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.5.4),新增Google Gemini、通义千问模型
|
||||
|
||||
>**2023.11.10:** [1.5.2版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.5.2),新增飞书通道、图像识别对话、黑名单配置
|
||||
|
||||
>**2023.11.10:** [1.5.0版本](https://github.com/zhayujie/chatgpt-on-wechat/releases/tag/1.5.0),新增 `gpt-4-turbo`, `dall-e-3`, `tts` 模型接入,完善图像理解&生成、语音识别&生成的多模态能力
|
||||
|
||||
>**2023.10.16:** 支持通过意图识别使用LinkAI联网搜索、数学计算、网页访问等插件,参考[插件文档](https://docs.link-ai.tech/platform/plugins)
|
||||
|
||||
>**2023.09.26:** 插件增加 文件/文章链接 一键总结和对话的功能,使用参考:[插件说明](https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/linkai#3%E6%96%87%E6%A1%A3%E6%80%BB%E7%BB%93%E5%AF%B9%E8%AF%9D%E5%8A%9F%E8%83%BD)
|
||||
|
||||
>**2023.08.08:** 接入百度文心一言模型,通过 [插件](https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/linkai) 支持 Midjourney 绘图
|
||||
|
||||
>**2023.06.12:** 接入 [LinkAI](https://link-ai.tech/console) 平台,可在线创建领域知识库,并接入微信、公众号及企业微信中,打造专属客服机器人。使用参考 [接入文档](https://link-ai.tech/platform/link-app/wechat)。
|
||||
|
||||
>**2023.04.26:** 支持企业微信应用号部署,兼容插件,并支持语音图片交互,私人助理理想选择,[使用文档](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/channel/wechatcom/README.md)。(contributed by [@lanvent](https://github.com/lanvent) in [#944](https://github.com/zhayujie/chatgpt-on-wechat/pull/944))
|
||||
|
||||
@@ -44,25 +60,29 @@ Demo made by [Visionn](https://www.wangpc.cc/)
|
||||
|
||||
>**2023.03.09:** 基于 `whisper API`(后续已接入更多的语音`API`服务) 实现对微信语音消息的解析和回复,添加配置项 `"speech_recognition":true` 即可启用,使用参考 [#415](https://github.com/zhayujie/chatgpt-on-wechat/issues/415)。(contributed by [wanggang1987](https://github.com/wanggang1987) in [#385](https://github.com/zhayujie/chatgpt-on-wechat/pull/385))
|
||||
|
||||
>**2023.03.02:** 接入[ChatGPT API](https://platform.openai.com/docs/guides/chat) (gpt-3.5-turbo),默认使用该模型进行对话,需升级openai依赖 (`pip3 install --upgrade openai`)。网络问题参考 [#351](https://github.com/zhayujie/chatgpt-on-wechat/issues/351)
|
||||
|
||||
>**2023.02.09:** 扫码登录存在账号限制风险,请谨慎使用,参考[#58](https://github.com/AutumnWhj/ChatGPT-wechat-bot/issues/158)
|
||||
|
||||
# 快速开始
|
||||
|
||||
快速开始文档:[项目搭建文档](https://docs.link-ai.tech/cow/quick-start)
|
||||
|
||||
## 准备
|
||||
|
||||
### 1. OpenAI账号注册
|
||||
### 1. 账号注册
|
||||
|
||||
前往 [OpenAI注册页面](https://beta.openai.com/signup) 创建账号,参考这篇 [教程](https://www.pythonthree.com/register-openai-chatgpt/) 可以通过虚拟手机号来接收验证码。创建完账号则前往 [API管理页面](https://beta.openai.com/account/api-keys) 创建一个 API Key 并保存下来,后面需要在项目中配置这个key。
|
||||
项目默认使用OpenAI接口,需前往 [OpenAI注册页面](https://beta.openai.com/signup) 创建账号,创建完账号则前往 [API管理页面](https://beta.openai.com/account/api-keys) 创建一个 API Key 并保存下来,后面需要在项目中配置这个key。接口需要海外网络访问及绑定信用卡支付。
|
||||
|
||||
> 项目中默认使用的对话模型是 gpt3.5 turbo,计费方式是约每 500 汉字 (包含请求和回复) 消耗 $0.002,图片生成是每张消耗 $0.016。
|
||||
> 默认对话模型是 openai 的 gpt-3.5-turbo,计费方式是约每 1000tokens (约750个英文单词 或 500汉字,包含请求和回复) 消耗 $0.002,图片生成是Dell E模型,每张消耗 $0.016。
|
||||
|
||||
项目同时也支持使用 LinkAI 接口,无需代理,可使用 文心、讯飞、GPT-3、GPT-4 等模型,支持 定制化知识库、联网搜索、MJ绘图、文档总结和对话等能力。修改配置即可一键切换,参考 [接入文档](https://link-ai.tech/platform/link-app/wechat)。
|
||||
|
||||
### 2.运行环境
|
||||
|
||||
支持 Linux、MacOS、Windows 系统(可在Linux服务器上长期运行),同时需安装 `Python`。
|
||||
> 建议Python版本在 3.7.1~3.9.X 之间,推荐3.8版本,3.10及以上版本在 MacOS 可用,其他系统上不确定能否正常运行。
|
||||
|
||||
> 注意:Docker 或 Railway 部署无需安装python环境和下载源码,可直接快进到下一节。
|
||||
|
||||
**(1) 克隆项目代码:**
|
||||
|
||||
```bash
|
||||
@@ -70,6 +90,8 @@ git clone https://github.com/zhayujie/chatgpt-on-wechat
|
||||
cd chatgpt-on-wechat/
|
||||
```
|
||||
|
||||
注: 如遇到网络问题可选择国内镜像 https://gitee.com/zhayujie/chatgpt-on-wechat
|
||||
|
||||
**(2) 安装核心依赖 (必选):**
|
||||
> 能够使用`itchat`创建机器人,并具有文字交流功能所需的最小依赖集合。
|
||||
```bash
|
||||
@@ -81,23 +103,7 @@ pip3 install -r requirements.txt
|
||||
```bash
|
||||
pip3 install -r requirements-optional.txt
|
||||
```
|
||||
> 如果某项依赖安装失败请注释掉对应的行再继续。
|
||||
|
||||
其中`tiktoken`要求`python`版本在3.8以上,它用于精确计算会话使用的tokens数量,强烈建议安装。
|
||||
|
||||
|
||||
使用`google`或`baidu`语音识别需安装`ffmpeg`,
|
||||
|
||||
默认的`openai`语音识别不需要安装`ffmpeg`。
|
||||
|
||||
参考[#415](https://github.com/zhayujie/chatgpt-on-wechat/issues/415)
|
||||
|
||||
使用`azure`语音功能需安装依赖,并参考[文档](https://learn.microsoft.com/en-us/azure/cognitive-services/speech-service/quickstarts/setup-platform?pivots=programming-language-python&tabs=linux%2Cubuntu%2Cdotnet%2Cjre%2Cmaven%2Cnodejs%2Cmac%2Cpypi)的环境要求。
|
||||
:
|
||||
|
||||
```bash
|
||||
pip3 install azure-cognitiveservices-speech
|
||||
```
|
||||
> 如果某项依赖安装失败可注释掉对应的行再继续
|
||||
|
||||
## 配置
|
||||
|
||||
@@ -113,8 +119,8 @@ pip3 install azure-cognitiveservices-speech
|
||||
# config.json文件内容示例
|
||||
{
|
||||
"open_ai_api_key": "YOUR API KEY", # 填入上面创建的 OpenAI API KEY
|
||||
"model": "gpt-3.5-turbo", # 模型名称。当use_azure_chatgpt为true时,其名称为Azure上model deployment名称
|
||||
"proxy": "127.0.0.1:7890", # 代理客户端的ip和端口
|
||||
"model": "gpt-3.5-turbo", # 模型名称, 支持 gpt-3.5-turbo, gpt-3.5-turbo-16k, gpt-4, wenxin, xunfei
|
||||
"proxy": "", # 代理客户端的ip和端口,国内环境开启代理的需要填写该项,如 "127.0.0.1:7890"
|
||||
"single_chat_prefix": ["bot", "@bot"], # 私聊时文本需要包含该前缀才能触发机器人回复
|
||||
"single_chat_reply_prefix": "[bot] ", # 私聊时自动回复的前缀,用于区分真人
|
||||
"group_chat_prefix": ["@bot"], # 群聊时包含该前缀则会触发机器人回复
|
||||
@@ -126,9 +132,13 @@ pip3 install azure-cognitiveservices-speech
|
||||
"group_speech_recognition": false, # 是否开启群组语音识别
|
||||
"use_azure_chatgpt": false, # 是否使用Azure ChatGPT service代替openai ChatGPT service. 当设置为true时需要设置 open_ai_api_base,如 https://xxx.openai.azure.com/
|
||||
"azure_deployment_id": "", # 采用Azure ChatGPT时,模型部署名称
|
||||
"azure_api_version": "", # 采用Azure ChatGPT时,API版本
|
||||
"character_desc": "你是ChatGPT, 一个由OpenAI训练的大型语言模型, 你旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。", # 人格描述
|
||||
# 订阅消息,公众号和企业微信channel中请填写,当被订阅时会自动回复,可使用特殊占位符。目前支持的占位符有{trigger_prefix},在程序中它会自动替换成bot的触发词。
|
||||
"subscribe_msg": "感谢您的关注!\n这里是ChatGPT,可以自由对话。\n支持语音对话。\n支持图片输出,画字开头的消息将按要求创作图片。\n支持角色扮演和文字冒险等丰富插件。\n输入{trigger_prefix}#help 查看详细指令。"
|
||||
"subscribe_msg": "感谢您的关注!\n这里是ChatGPT,可以自由对话。\n支持语音对话。\n支持图片输出,画字开头的消息将按要求创作图片。\n支持角色扮演和文字冒险等丰富插件。\n输入{trigger_prefix}#help 查看详细指令。",
|
||||
"use_linkai": false, # 是否使用LinkAI接口,默认关闭,开启后可国内访问,使用知识库和MJ
|
||||
"linkai_api_key": "", # LinkAI Api Key
|
||||
"linkai_app_code": "" # LinkAI 应用code
|
||||
}
|
||||
```
|
||||
**配置说明:**
|
||||
@@ -153,7 +163,7 @@ pip3 install azure-cognitiveservices-speech
|
||||
|
||||
**4.其他配置**
|
||||
|
||||
+ `model`: 模型名称,目前支持 `gpt-3.5-turbo`, `text-davinci-003`, `gpt-4`, `gpt-4-32k` (其中gpt-4 api暂未完全开放,申请通过后可使用)
|
||||
+ `model`: 模型名称,目前支持 `gpt-3.5-turbo`, `text-davinci-003`, `gpt-4`, `gpt-4-32k`, `wenxin` , `claude` , `xunfei`(其中gpt-4 api暂未完全开放,申请通过后可使用)
|
||||
+ `temperature`,`frequency_penalty`,`presence_penalty`: Chat API接口参数,详情参考[OpenAI官方文档。](https://platform.openai.com/docs/api-reference/chat)
|
||||
+ `proxy`:由于目前 `openai` 接口国内无法访问,需配置代理客户端的地址,详情参考 [#351](https://github.com/zhayujie/chatgpt-on-wechat/issues/351)
|
||||
+ 对于图像生成,在满足个人或群组触发条件外,还需要额外的关键词前缀来触发,对应配置 `image_create_prefix `
|
||||
@@ -165,6 +175,12 @@ pip3 install azure-cognitiveservices-speech
|
||||
+ `character_desc` 配置中保存着你对机器人说的一段话,他会记住这段话并作为他的设定,你可以为他定制任何人格 (关于会话上下文的更多内容参考该 [issue](https://github.com/zhayujie/chatgpt-on-wechat/issues/43))
|
||||
+ `subscribe_msg`:订阅消息,公众号和企业微信channel中请填写,当被订阅时会自动回复, 可使用特殊占位符。目前支持的占位符有{trigger_prefix},在程序中它会自动替换成bot的触发词。
|
||||
|
||||
**5.LinkAI配置 (可选)**
|
||||
|
||||
+ `use_linkai`: 是否使用LinkAI接口,开启后可国内访问,使用知识库和 `Midjourney` 绘画, 参考 [文档](https://link-ai.tech/platform/link-app/wechat)
|
||||
+ `linkai_api_key`: LinkAI Api Key,可在 [控制台](https://link-ai.tech/console/interface) 创建
|
||||
+ `linkai_app_code`: LinkAI 应用code,选填
|
||||
|
||||
**本说明文档可能会未及时更新,当前所有可选的配置项均在该[`config.py`](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/config.py)中列出。**
|
||||
|
||||
## 运行
|
||||
@@ -174,17 +190,16 @@ pip3 install azure-cognitiveservices-speech
|
||||
如果是开发机 **本地运行**,直接在项目根目录下执行:
|
||||
|
||||
```bash
|
||||
python3 app.py
|
||||
python3 app.py # windows环境下该命令通常为 python app.py
|
||||
```
|
||||
终端输出二维码后,使用微信进行扫码,当输出 "Start auto replying" 时表示自动回复程序已经成功运行了(注意:用于登录的微信需要在支付处已完成实名认证)。扫码登录后你的账号就成为机器人了,可以在微信手机端通过配置的关键词触发自动回复 (任意好友发送消息给你,或是自己发消息给好友),参考[#142](https://github.com/zhayujie/chatgpt-on-wechat/issues/142)。
|
||||
|
||||
终端输出二维码后,使用微信进行扫码,当输出 "Start auto replying" 时表示自动回复程序已经成功运行了(注意:用于登录的微信需要在支付处已完成实名认证)。扫码登录后你的账号就成为机器人了,可以在微信手机端通过配置的关键词触发自动回复 (任意好友发送消息给你,或是自己发消息给好友),参考[#142](https://github.com/zhayujie/chatgpt-on-wechat/issues/142)。
|
||||
|
||||
### 2.服务器部署
|
||||
|
||||
使用nohup命令在后台运行程序:
|
||||
|
||||
```bash
|
||||
touch nohup.out # 首次运行需要新建日志文件
|
||||
nohup python3 app.py & tail -f nohup.out # 在后台运行程序并通过日志输出二维码
|
||||
```
|
||||
扫码登录后程序即可运行于服务器后台,此时可通过 `ctrl+c` 关闭日志,不会影响后台程序的运行。使用 `ps -ef | grep app.py | grep -v grep` 命令可查看运行于后台的进程,如果想要重新启动程序可以先 `kill` 掉对应的进程。日志关闭后如果想要再次打开只需输入 `tail -f nohup.out`。此外,`scripts` 目录下有一键运行、关闭程序的脚本供使用。
|
||||
@@ -196,24 +211,71 @@ nohup python3 app.py & tail -f nohup.out # 在后台运行程序并通
|
||||
|
||||
### 3.Docker部署
|
||||
|
||||
参考文档 [Docker部署](https://github.com/limccn/chatgpt-on-wechat/wiki/Docker%E9%83%A8%E7%BD%B2) (Contributed by [limccn](https://github.com/limccn))。
|
||||
> 使用docker部署无需下载源码和安装依赖,只需要获取 docker-compose.yml 配置文件并启动容器即可。
|
||||
|
||||
### 4. Railway部署 (✅推荐)
|
||||
> Railway每月提供5刀和最多500小时的免费额度。
|
||||
1. 进入 [Railway](https://railway.app/template/qApznZ?referralCode=RC3znh)。
|
||||
> 前提是需要安装好 `docker` 及 `docker-compose`,安装成功的表现是执行 `docker -v` 和 `docker-compose version` (或 docker compose version) 可以查看到版本号,可前往 [docker官网](https://docs.docker.com/engine/install/) 进行下载。
|
||||
|
||||
#### (1) 下载 docker-compose.yml 文件
|
||||
|
||||
```bash
|
||||
wget https://open-1317903499.cos.ap-guangzhou.myqcloud.com/docker-compose.yml
|
||||
```
|
||||
|
||||
下载完成后打开 `docker-compose.yml` 修改所需配置,如 `OPEN_AI_API_KEY` 和 `GROUP_NAME_WHITE_LIST` 等。
|
||||
|
||||
#### (2) 启动容器
|
||||
|
||||
在 `docker-compose.yml` 所在目录下执行以下命令启动容器:
|
||||
|
||||
```bash
|
||||
sudo docker compose up -d
|
||||
```
|
||||
|
||||
运行 `sudo docker ps` 能查看到 NAMES 为 chatgpt-on-wechat 的容器即表示运行成功。
|
||||
|
||||
注意:
|
||||
|
||||
- 如果 `docker-compose` 是 1.X 版本 则需要执行 `sudo docker-compose up -d` 来启动容器
|
||||
- 该命令会自动去 [docker hub](https://hub.docker.com/r/zhayujie/chatgpt-on-wechat) 拉取 latest 版本的镜像,latest 镜像会在每次项目 release 新的版本时生成
|
||||
|
||||
最后运行以下命令可查看容器运行日志,扫描日志中的二维码即可完成登录:
|
||||
|
||||
```bash
|
||||
sudo docker logs -f chatgpt-on-wechat
|
||||
```
|
||||
|
||||
#### (3) 插件使用
|
||||
|
||||
如果需要在docker容器中修改插件配置,可通过挂载的方式完成,将 [插件配置文件](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/plugins/config.json.template)
|
||||
重命名为 `config.json`,放置于 `docker-compose.yml` 相同目录下,并在 `docker-compose.yml` 中的 `chatgpt-on-wechat` 部分下添加 `volumes` 映射:
|
||||
|
||||
```
|
||||
volumes:
|
||||
- ./config.json:/app/plugins/config.json
|
||||
```
|
||||
|
||||
### 4. Railway部署
|
||||
|
||||
> Railway 每月提供5刀和最多500小时的免费额度。 (07.11更新: 目前大部分账号已无法免费部署)
|
||||
|
||||
1. 进入 [Railway](https://railway.app/template/qApznZ?referralCode=RC3znh)
|
||||
2. 点击 `Deploy Now` 按钮。
|
||||
3. 设置环境变量来重载程序运行的参数,例如`open_ai_api_key`, `character_desc`。
|
||||
|
||||
**一键部署:**
|
||||
|
||||
[](https://railway.app/template/qApznZ?referralCode=RC3znh)
|
||||
|
||||
## 常见问题
|
||||
|
||||
FAQs: <https://github.com/zhayujie/chatgpt-on-wechat/wiki/FAQs>
|
||||
|
||||
或直接在线咨询 [项目小助手](https://chat.link-ai.tech/app/Kv2fXJcH) (beta版本,语料完善中,回复仅供参考)
|
||||
或直接在线咨询 [项目小助手](https://link-ai.tech/app/Kv2fXJcH) (beta版本,语料完善中,回复仅供参考)
|
||||
|
||||
## 开发
|
||||
|
||||
欢迎接入更多应用,参考 [Terminal代码](https://github.com/zhayujie/chatgpt-on-wechat/blob/master/channel/terminal/terminal_channel.py) 实现接收和发送消息逻辑即可接入。 同时欢迎增加新的插件,参考 [插件说明文档](https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins)。
|
||||
|
||||
## 联系
|
||||
|
||||
欢迎提交PR、Issues,以及Star支持一下。程序运行遇到问题可以查看 [常见问题列表](https://github.com/zhayujie/chatgpt-on-wechat/wiki/FAQs) ,其次前往 [Issues](https://github.com/zhayujie/chatgpt-on-wechat/issues) 中搜索。
|
||||
|
||||
如果你想了解更多项目细节,与开发者们交流更多关于AI技术的实践,欢迎加入星球:
|
||||
|
||||
<a href="https://public.zsxq.com/groups/88885848842852.html"><img width="360" src="./docs/images/planet.jpg"></a>
|
||||
欢迎提交PR、Issues,以及Star支持一下。程序运行遇到问题可以查看 [常见问题列表](https://github.com/zhayujie/chatgpt-on-wechat/wiki/FAQs) ,其次前往 [Issues](https://github.com/zhayujie/chatgpt-on-wechat/issues) 中搜索。个人开发者可加入开源交流群参与更多讨论,企业用户可联系[产品顾问](https://img-1317903499.cos.ap-guangzhou.myqcloud.com/docs/product-manager-qrcode.jpg)咨询。
|
||||
|
||||
30
app.py
30
app.py
@@ -3,11 +3,13 @@
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
import time
|
||||
|
||||
from channel import channel_factory
|
||||
from common.log import logger
|
||||
from config import conf, load_config
|
||||
from common import const
|
||||
from config import load_config
|
||||
from plugins import *
|
||||
import threading
|
||||
|
||||
|
||||
def sigterm_handler_wrap(_signo):
|
||||
@@ -23,6 +25,21 @@ def sigterm_handler_wrap(_signo):
|
||||
signal.signal(_signo, func)
|
||||
|
||||
|
||||
def start_channel(channel_name: str):
|
||||
channel = channel_factory.create_channel(channel_name)
|
||||
if channel_name in ["wx", "wxy", "terminal", "wechatmp", "wechatmp_service", "wechatcom_app", "wework",
|
||||
const.FEISHU, const.DINGTALK]:
|
||||
PluginManager().load_plugins()
|
||||
|
||||
if conf().get("use_linkai"):
|
||||
try:
|
||||
from common import linkai_client
|
||||
threading.Thread(target=linkai_client.start, args=(channel,)).start()
|
||||
except Exception as e:
|
||||
pass
|
||||
channel.startup()
|
||||
|
||||
|
||||
def run():
|
||||
try:
|
||||
# load config
|
||||
@@ -40,14 +57,11 @@ def run():
|
||||
|
||||
if channel_name == "wxy":
|
||||
os.environ["WECHATY_LOG"] = "warn"
|
||||
# os.environ['WECHATY_PUPPET_SERVICE_ENDPOINT'] = '127.0.0.1:9001'
|
||||
|
||||
channel = channel_factory.create_channel(channel_name)
|
||||
if channel_name in ["wx", "wxy", "terminal", "wechatmp", "wechatmp_service", "wechatcom_app"]:
|
||||
PluginManager().load_plugins()
|
||||
start_channel(channel_name)
|
||||
|
||||
# startup channel
|
||||
channel.startup()
|
||||
while True:
|
||||
time.sleep(1)
|
||||
except Exception as e:
|
||||
logger.error("App startup failed!")
|
||||
logger.exception(e)
|
||||
|
||||
214
bot/ali/ali_qwen_bot.py
Normal file
214
bot/ali/ali_qwen_bot.py
Normal file
@@ -0,0 +1,214 @@
|
||||
# encoding:utf-8
|
||||
|
||||
import json
|
||||
import time
|
||||
from typing import List, Tuple
|
||||
|
||||
import openai
|
||||
import openai.error
|
||||
import broadscope_bailian
|
||||
from broadscope_bailian import ChatQaMessage
|
||||
|
||||
from bot.bot import Bot
|
||||
from bot.ali.ali_qwen_session import AliQwenSession
|
||||
from bot.session_manager import SessionManager
|
||||
from bridge.context import ContextType
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from common.log import logger
|
||||
from common import const
|
||||
from config import conf, load_config
|
||||
|
||||
class AliQwenBot(Bot):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.api_key_expired_time = self.set_api_key()
|
||||
self.sessions = SessionManager(AliQwenSession, model=conf().get("model", const.QWEN))
|
||||
|
||||
def api_key_client(self):
|
||||
return broadscope_bailian.AccessTokenClient(access_key_id=self.access_key_id(), access_key_secret=self.access_key_secret())
|
||||
|
||||
def access_key_id(self):
|
||||
return conf().get("qwen_access_key_id")
|
||||
|
||||
def access_key_secret(self):
|
||||
return conf().get("qwen_access_key_secret")
|
||||
|
||||
def agent_key(self):
|
||||
return conf().get("qwen_agent_key")
|
||||
|
||||
def app_id(self):
|
||||
return conf().get("qwen_app_id")
|
||||
|
||||
def node_id(self):
|
||||
return conf().get("qwen_node_id", "")
|
||||
|
||||
def temperature(self):
|
||||
return conf().get("temperature", 0.2 )
|
||||
|
||||
def top_p(self):
|
||||
return conf().get("top_p", 1)
|
||||
|
||||
def reply(self, query, context=None):
|
||||
# acquire reply content
|
||||
if context.type == ContextType.TEXT:
|
||||
logger.info("[QWEN] query={}".format(query))
|
||||
|
||||
session_id = context["session_id"]
|
||||
reply = None
|
||||
clear_memory_commands = conf().get("clear_memory_commands", ["#清除记忆"])
|
||||
if query in clear_memory_commands:
|
||||
self.sessions.clear_session(session_id)
|
||||
reply = Reply(ReplyType.INFO, "记忆已清除")
|
||||
elif query == "#清除所有":
|
||||
self.sessions.clear_all_session()
|
||||
reply = Reply(ReplyType.INFO, "所有人记忆已清除")
|
||||
elif query == "#更新配置":
|
||||
load_config()
|
||||
reply = Reply(ReplyType.INFO, "配置已更新")
|
||||
if reply:
|
||||
return reply
|
||||
session = self.sessions.session_query(query, session_id)
|
||||
logger.debug("[QWEN] session query={}".format(session.messages))
|
||||
|
||||
reply_content = self.reply_text(session)
|
||||
logger.debug(
|
||||
"[QWEN] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(
|
||||
session.messages,
|
||||
session_id,
|
||||
reply_content["content"],
|
||||
reply_content["completion_tokens"],
|
||||
)
|
||||
)
|
||||
if reply_content["completion_tokens"] == 0 and len(reply_content["content"]) > 0:
|
||||
reply = Reply(ReplyType.ERROR, reply_content["content"])
|
||||
elif reply_content["completion_tokens"] > 0:
|
||||
self.sessions.session_reply(reply_content["content"], session_id, reply_content["total_tokens"])
|
||||
reply = Reply(ReplyType.TEXT, reply_content["content"])
|
||||
else:
|
||||
reply = Reply(ReplyType.ERROR, reply_content["content"])
|
||||
logger.debug("[QWEN] reply {} used 0 tokens.".format(reply_content))
|
||||
return reply
|
||||
|
||||
else:
|
||||
reply = Reply(ReplyType.ERROR, "Bot不支持处理{}类型的消息".format(context.type))
|
||||
return reply
|
||||
|
||||
def reply_text(self, session: AliQwenSession, retry_count=0) -> dict:
|
||||
"""
|
||||
call bailian's ChatCompletion to get the answer
|
||||
:param session: a conversation session
|
||||
:param retry_count: retry count
|
||||
:return: {}
|
||||
"""
|
||||
try:
|
||||
prompt, history = self.convert_messages_format(session.messages)
|
||||
self.update_api_key_if_expired()
|
||||
# NOTE 阿里百炼的call()函数未提供temperature参数,考虑到temperature和top_p参数作用相同,取两者较小的值作为top_p参数传入,详情见文档 https://help.aliyun.com/document_detail/2587502.htm
|
||||
response = broadscope_bailian.Completions().call(app_id=self.app_id(), prompt=prompt, history=history, top_p=min(self.temperature(), self.top_p()))
|
||||
completion_content = self.get_completion_content(response, self.node_id())
|
||||
completion_tokens, total_tokens = self.calc_tokens(session.messages, completion_content)
|
||||
return {
|
||||
"total_tokens": total_tokens,
|
||||
"completion_tokens": completion_tokens,
|
||||
"content": completion_content,
|
||||
}
|
||||
except Exception as e:
|
||||
need_retry = retry_count < 2
|
||||
result = {"completion_tokens": 0, "content": "我现在有点累了,等会再来吧"}
|
||||
if isinstance(e, openai.error.RateLimitError):
|
||||
logger.warn("[QWEN] RateLimitError: {}".format(e))
|
||||
result["content"] = "提问太快啦,请休息一下再问我吧"
|
||||
if need_retry:
|
||||
time.sleep(20)
|
||||
elif isinstance(e, openai.error.Timeout):
|
||||
logger.warn("[QWEN] Timeout: {}".format(e))
|
||||
result["content"] = "我没有收到你的消息"
|
||||
if need_retry:
|
||||
time.sleep(5)
|
||||
elif isinstance(e, openai.error.APIError):
|
||||
logger.warn("[QWEN] Bad Gateway: {}".format(e))
|
||||
result["content"] = "请再问我一次"
|
||||
if need_retry:
|
||||
time.sleep(10)
|
||||
elif isinstance(e, openai.error.APIConnectionError):
|
||||
logger.warn("[QWEN] APIConnectionError: {}".format(e))
|
||||
need_retry = False
|
||||
result["content"] = "我连接不到你的网络"
|
||||
else:
|
||||
logger.exception("[QWEN] Exception: {}".format(e))
|
||||
need_retry = False
|
||||
self.sessions.clear_session(session.session_id)
|
||||
|
||||
if need_retry:
|
||||
logger.warn("[QWEN] 第{}次重试".format(retry_count + 1))
|
||||
return self.reply_text(session, retry_count + 1)
|
||||
else:
|
||||
return result
|
||||
|
||||
def set_api_key(self):
|
||||
api_key, expired_time = self.api_key_client().create_token(agent_key=self.agent_key())
|
||||
broadscope_bailian.api_key = api_key
|
||||
return expired_time
|
||||
|
||||
def update_api_key_if_expired(self):
|
||||
if time.time() > self.api_key_expired_time:
|
||||
self.api_key_expired_time = self.set_api_key()
|
||||
|
||||
def convert_messages_format(self, messages) -> Tuple[str, List[ChatQaMessage]]:
|
||||
history = []
|
||||
user_content = ''
|
||||
assistant_content = ''
|
||||
system_content = ''
|
||||
for message in messages:
|
||||
role = message.get('role')
|
||||
if role == 'user':
|
||||
user_content += message.get('content')
|
||||
elif role == 'assistant':
|
||||
assistant_content = message.get('content')
|
||||
history.append(ChatQaMessage(user_content, assistant_content))
|
||||
user_content = ''
|
||||
assistant_content = ''
|
||||
elif role =='system':
|
||||
system_content += message.get('content')
|
||||
if user_content == '':
|
||||
raise Exception('no user message')
|
||||
if system_content != '':
|
||||
# NOTE 模拟系统消息,测试发现人格描述以"你需要扮演ChatGPT"开头能够起作用,而以"你是ChatGPT"开头模型会直接否认
|
||||
system_qa = ChatQaMessage(system_content, '好的,我会严格按照你的设定回答问题')
|
||||
history.insert(0, system_qa)
|
||||
logger.debug("[QWEN] converted qa messages: {}".format([item.to_dict() for item in history]))
|
||||
logger.debug("[QWEN] user content as prompt: {}".format(user_content))
|
||||
return user_content, history
|
||||
|
||||
def get_completion_content(self, response, node_id):
|
||||
if not response['Success']:
|
||||
return f"[ERROR]\n{response['Code']}:{response['Message']}"
|
||||
text = response['Data']['Text']
|
||||
if node_id == '':
|
||||
return text
|
||||
# TODO: 当使用流程编排创建大模型应用时,响应结构如下,最终结果在['finalResult'][node_id]['response']['text']中,暂时先这么写
|
||||
# {
|
||||
# 'Success': True,
|
||||
# 'Code': None,
|
||||
# 'Message': None,
|
||||
# 'Data': {
|
||||
# 'ResponseId': '9822f38dbacf4c9b8daf5ca03a2daf15',
|
||||
# 'SessionId': 'session_id',
|
||||
# 'Text': '{"finalResult":{"LLM_T7islK":{"params":{"modelId":"qwen-plus-v1","prompt":"${systemVars.query}${bizVars.Text}"},"response":{"text":"作为一个AI语言模型,我没有年龄,因为我没有生日。\n我只是一个程序,没有生命和身体。"}}}}',
|
||||
# 'Thoughts': [],
|
||||
# 'Debug': {},
|
||||
# 'DocReferences': []
|
||||
# },
|
||||
# 'RequestId': '8e11d31551ce4c3f83f49e6e0dd998b0',
|
||||
# 'Failed': None
|
||||
# }
|
||||
text_dict = json.loads(text)
|
||||
completion_content = text_dict['finalResult'][node_id]['response']['text']
|
||||
return completion_content
|
||||
|
||||
def calc_tokens(self, messages, completion_content):
|
||||
completion_tokens = len(completion_content)
|
||||
prompt_tokens = 0
|
||||
for message in messages:
|
||||
prompt_tokens += len(message["content"])
|
||||
return completion_tokens, prompt_tokens + completion_tokens
|
||||
62
bot/ali/ali_qwen_session.py
Normal file
62
bot/ali/ali_qwen_session.py
Normal file
@@ -0,0 +1,62 @@
|
||||
from bot.session_manager import Session
|
||||
from common.log import logger
|
||||
|
||||
"""
|
||||
e.g.
|
||||
[
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "Who won the world series in 2020?"},
|
||||
{"role": "assistant", "content": "The Los Angeles Dodgers won the World Series in 2020."},
|
||||
{"role": "user", "content": "Where was it played?"}
|
||||
]
|
||||
"""
|
||||
|
||||
class AliQwenSession(Session):
|
||||
def __init__(self, session_id, system_prompt=None, model="qianwen"):
|
||||
super().__init__(session_id, system_prompt)
|
||||
self.model = model
|
||||
self.reset()
|
||||
|
||||
def discard_exceeding(self, max_tokens, cur_tokens=None):
|
||||
precise = True
|
||||
try:
|
||||
cur_tokens = self.calc_tokens()
|
||||
except Exception as e:
|
||||
precise = False
|
||||
if cur_tokens is None:
|
||||
raise e
|
||||
logger.debug("Exception when counting tokens precisely for query: {}".format(e))
|
||||
while cur_tokens > max_tokens:
|
||||
if len(self.messages) > 2:
|
||||
self.messages.pop(1)
|
||||
elif len(self.messages) == 2 and self.messages[1]["role"] == "assistant":
|
||||
self.messages.pop(1)
|
||||
if precise:
|
||||
cur_tokens = self.calc_tokens()
|
||||
else:
|
||||
cur_tokens = cur_tokens - max_tokens
|
||||
break
|
||||
elif len(self.messages) == 2 and self.messages[1]["role"] == "user":
|
||||
logger.warn("user message exceed max_tokens. total_tokens={}".format(cur_tokens))
|
||||
break
|
||||
else:
|
||||
logger.debug("max_tokens={}, total_tokens={}, len(messages)={}".format(max_tokens, cur_tokens, len(self.messages)))
|
||||
break
|
||||
if precise:
|
||||
cur_tokens = self.calc_tokens()
|
||||
else:
|
||||
cur_tokens = cur_tokens - max_tokens
|
||||
return cur_tokens
|
||||
|
||||
def calc_tokens(self):
|
||||
return num_tokens_from_messages(self.messages, self.model)
|
||||
|
||||
def num_tokens_from_messages(messages, model):
|
||||
"""Returns the number of tokens used by a list of messages."""
|
||||
# 官方token计算规则:"对于中文文本来说,1个token通常对应一个汉字;对于英文文本来说,1个token通常对应3至4个字母或1个单词"
|
||||
# 详情请产看文档:https://help.aliyun.com/document_detail/2586397.html
|
||||
# 目前根据字符串长度粗略估计token数,不影响正常使用
|
||||
tokens = 0
|
||||
for msg in messages:
|
||||
tokens += len(msg["content"])
|
||||
return tokens
|
||||
107
bot/baidu/baidu_wenxin.py
Normal file
107
bot/baidu/baidu_wenxin.py
Normal file
@@ -0,0 +1,107 @@
|
||||
# encoding:utf-8
|
||||
|
||||
import requests, json
|
||||
from bot.bot import Bot
|
||||
from bot.session_manager import SessionManager
|
||||
from bridge.context import ContextType
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from common.log import logger
|
||||
from config import conf
|
||||
from bot.baidu.baidu_wenxin_session import BaiduWenxinSession
|
||||
|
||||
BAIDU_API_KEY = conf().get("baidu_wenxin_api_key")
|
||||
BAIDU_SECRET_KEY = conf().get("baidu_wenxin_secret_key")
|
||||
|
||||
class BaiduWenxinBot(Bot):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
wenxin_model = conf().get("baidu_wenxin_model") or "eb-instant"
|
||||
if conf().get("model") and conf().get("model") == "wenxin-4":
|
||||
wenxin_model = "completions_pro"
|
||||
self.sessions = SessionManager(BaiduWenxinSession, model=wenxin_model)
|
||||
|
||||
def reply(self, query, context=None):
|
||||
# acquire reply content
|
||||
if context and context.type:
|
||||
if context.type == ContextType.TEXT:
|
||||
logger.info("[BAIDU] query={}".format(query))
|
||||
session_id = context["session_id"]
|
||||
reply = None
|
||||
if query == "#清除记忆":
|
||||
self.sessions.clear_session(session_id)
|
||||
reply = Reply(ReplyType.INFO, "记忆已清除")
|
||||
elif query == "#清除所有":
|
||||
self.sessions.clear_all_session()
|
||||
reply = Reply(ReplyType.INFO, "所有人记忆已清除")
|
||||
else:
|
||||
session = self.sessions.session_query(query, session_id)
|
||||
result = self.reply_text(session)
|
||||
total_tokens, completion_tokens, reply_content = (
|
||||
result["total_tokens"],
|
||||
result["completion_tokens"],
|
||||
result["content"],
|
||||
)
|
||||
logger.debug(
|
||||
"[BAIDU] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(session.messages, session_id, reply_content, completion_tokens)
|
||||
)
|
||||
|
||||
if total_tokens == 0:
|
||||
reply = Reply(ReplyType.ERROR, reply_content)
|
||||
else:
|
||||
self.sessions.session_reply(reply_content, session_id, total_tokens)
|
||||
reply = Reply(ReplyType.TEXT, reply_content)
|
||||
return reply
|
||||
elif context.type == ContextType.IMAGE_CREATE:
|
||||
ok, retstring = self.create_img(query, 0)
|
||||
reply = None
|
||||
if ok:
|
||||
reply = Reply(ReplyType.IMAGE_URL, retstring)
|
||||
else:
|
||||
reply = Reply(ReplyType.ERROR, retstring)
|
||||
return reply
|
||||
|
||||
def reply_text(self, session: BaiduWenxinSession, retry_count=0):
|
||||
try:
|
||||
logger.info("[BAIDU] model={}".format(session.model))
|
||||
access_token = self.get_access_token()
|
||||
if access_token == 'None':
|
||||
logger.warn("[BAIDU] access token 获取失败")
|
||||
return {
|
||||
"total_tokens": 0,
|
||||
"completion_tokens": 0,
|
||||
"content": 0,
|
||||
}
|
||||
url = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/" + session.model + "?access_token=" + access_token
|
||||
headers = {
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
payload = {'messages': session.messages}
|
||||
response = requests.request("POST", url, headers=headers, data=json.dumps(payload))
|
||||
response_text = json.loads(response.text)
|
||||
logger.info(f"[BAIDU] response text={response_text}")
|
||||
res_content = response_text["result"]
|
||||
total_tokens = response_text["usage"]["total_tokens"]
|
||||
completion_tokens = response_text["usage"]["completion_tokens"]
|
||||
logger.info("[BAIDU] reply={}".format(res_content))
|
||||
return {
|
||||
"total_tokens": total_tokens,
|
||||
"completion_tokens": completion_tokens,
|
||||
"content": res_content,
|
||||
}
|
||||
except Exception as e:
|
||||
need_retry = retry_count < 2
|
||||
logger.warn("[BAIDU] Exception: {}".format(e))
|
||||
need_retry = False
|
||||
self.sessions.clear_session(session.session_id)
|
||||
result = {"completion_tokens": 0, "content": "出错了: {}".format(e)}
|
||||
return result
|
||||
|
||||
def get_access_token(self):
|
||||
"""
|
||||
使用 AK,SK 生成鉴权签名(Access Token)
|
||||
:return: access_token,或是None(如果错误)
|
||||
"""
|
||||
url = "https://aip.baidubce.com/oauth/2.0/token"
|
||||
params = {"grant_type": "client_credentials", "client_id": BAIDU_API_KEY, "client_secret": BAIDU_SECRET_KEY}
|
||||
return str(requests.post(url, params=params).json().get("access_token"))
|
||||
53
bot/baidu/baidu_wenxin_session.py
Normal file
53
bot/baidu/baidu_wenxin_session.py
Normal file
@@ -0,0 +1,53 @@
|
||||
from bot.session_manager import Session
|
||||
from common.log import logger
|
||||
|
||||
"""
|
||||
e.g. [
|
||||
{"role": "user", "content": "Who won the world series in 2020?"},
|
||||
{"role": "assistant", "content": "The Los Angeles Dodgers won the World Series in 2020."},
|
||||
{"role": "user", "content": "Where was it played?"}
|
||||
]
|
||||
"""
|
||||
|
||||
|
||||
class BaiduWenxinSession(Session):
|
||||
def __init__(self, session_id, system_prompt=None, model="gpt-3.5-turbo"):
|
||||
super().__init__(session_id, system_prompt)
|
||||
self.model = model
|
||||
# 百度文心不支持system prompt
|
||||
# self.reset()
|
||||
|
||||
def discard_exceeding(self, max_tokens, cur_tokens=None):
|
||||
precise = True
|
||||
try:
|
||||
cur_tokens = self.calc_tokens()
|
||||
except Exception as e:
|
||||
precise = False
|
||||
if cur_tokens is None:
|
||||
raise e
|
||||
logger.debug("Exception when counting tokens precisely for query: {}".format(e))
|
||||
while cur_tokens > max_tokens:
|
||||
if len(self.messages) >= 2:
|
||||
self.messages.pop(0)
|
||||
self.messages.pop(0)
|
||||
else:
|
||||
logger.debug("max_tokens={}, total_tokens={}, len(messages)={}".format(max_tokens, cur_tokens, len(self.messages)))
|
||||
break
|
||||
if precise:
|
||||
cur_tokens = self.calc_tokens()
|
||||
else:
|
||||
cur_tokens = cur_tokens - max_tokens
|
||||
return cur_tokens
|
||||
|
||||
def calc_tokens(self):
|
||||
return num_tokens_from_messages(self.messages, self.model)
|
||||
|
||||
|
||||
def num_tokens_from_messages(messages, model):
|
||||
"""Returns the number of tokens used by a list of messages."""
|
||||
tokens = 0
|
||||
for msg in messages:
|
||||
# 官方token计算规则暂不明确: "大约为 token数为 "中文字 + 其他语种单词数 x 1.3"
|
||||
# 这里先直接根据字数粗略估算吧,暂不影响正常使用,仅在判断是否丢弃历史会话的时候会有偏差
|
||||
tokens += len(msg["content"])
|
||||
return tokens
|
||||
@@ -11,31 +11,50 @@ def create_bot(bot_type):
|
||||
:return: bot instance
|
||||
"""
|
||||
if bot_type == const.BAIDU:
|
||||
# Baidu Unit对话接口
|
||||
from bot.baidu.baidu_unit_bot import BaiduUnitBot
|
||||
|
||||
return BaiduUnitBot()
|
||||
# 替换Baidu Unit为Baidu文心千帆对话接口
|
||||
# from bot.baidu.baidu_unit_bot import BaiduUnitBot
|
||||
# return BaiduUnitBot()
|
||||
from bot.baidu.baidu_wenxin import BaiduWenxinBot
|
||||
return BaiduWenxinBot()
|
||||
|
||||
elif bot_type == const.CHATGPT:
|
||||
# ChatGPT 网页端web接口
|
||||
from bot.chatgpt.chat_gpt_bot import ChatGPTBot
|
||||
|
||||
return ChatGPTBot()
|
||||
|
||||
elif bot_type == const.OPEN_AI:
|
||||
# OpenAI 官方对话模型API
|
||||
from bot.openai.open_ai_bot import OpenAIBot
|
||||
|
||||
return OpenAIBot()
|
||||
|
||||
elif bot_type == const.CHATGPTONAZURE:
|
||||
# Azure chatgpt service https://azure.microsoft.com/en-in/products/cognitive-services/openai-service/
|
||||
from bot.chatgpt.chat_gpt_bot import AzureChatGPTBot
|
||||
|
||||
return AzureChatGPTBot()
|
||||
|
||||
elif bot_type == const.XUNFEI:
|
||||
from bot.xunfei.xunfei_spark_bot import XunFeiBot
|
||||
return XunFeiBot()
|
||||
|
||||
elif bot_type == const.LINKAI:
|
||||
from bot.linkai.link_ai_bot import LinkAIBot
|
||||
return LinkAIBot()
|
||||
|
||||
elif bot_type == const.CLAUDEAI:
|
||||
from bot.claude.claude_ai_bot import ClaudeAIBot
|
||||
return ClaudeAIBot()
|
||||
|
||||
elif bot_type == const.QWEN:
|
||||
from bot.ali.ali_qwen_bot import AliQwenBot
|
||||
return AliQwenBot()
|
||||
|
||||
elif bot_type == const.GEMINI:
|
||||
from bot.gemini.google_gemini_bot import GoogleGeminiBot
|
||||
return GoogleGeminiBot()
|
||||
|
||||
elif bot_type == const.ZHIPU_AI:
|
||||
from bot.zhipuai.zhipuai_bot import ZHIPUAIBot
|
||||
return ZHIPUAIBot()
|
||||
|
||||
|
||||
raise RuntimeError
|
||||
|
||||
@@ -121,6 +121,7 @@ class ChatGPTBot(Bot, OpenAIImage):
|
||||
if args is None:
|
||||
args = self.args
|
||||
response = openai.ChatCompletion.create(api_key=api_key, messages=session.messages, **args)
|
||||
# logger.debug("[CHATGPT] response={}".format(response))
|
||||
# logger.info("[ChatGPT] reply={}, total_tokens={}".format(response.choices[0]['message']['content'], response["usage"]["total_tokens"]))
|
||||
return {
|
||||
"total_tokens": response["usage"]["total_tokens"],
|
||||
@@ -147,8 +148,9 @@ class ChatGPTBot(Bot, OpenAIImage):
|
||||
time.sleep(10)
|
||||
elif isinstance(e, openai.error.APIConnectionError):
|
||||
logger.warn("[CHATGPT] APIConnectionError: {}".format(e))
|
||||
need_retry = False
|
||||
result["content"] = "我连接不到你的网络"
|
||||
if need_retry:
|
||||
time.sleep(5)
|
||||
else:
|
||||
logger.exception("[CHATGPT] Exception: {}".format(e))
|
||||
need_retry = False
|
||||
@@ -165,7 +167,7 @@ class AzureChatGPTBot(ChatGPTBot):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
openai.api_type = "azure"
|
||||
openai.api_version = "2023-03-15-preview"
|
||||
openai.api_version = conf().get("azure_api_version", "2023-06-01-preview")
|
||||
self.args["deployment_id"] = conf().get("azure_deployment_id")
|
||||
|
||||
def create_img(self, query, retry_count=0, api_key=None):
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from bot.session_manager import Session
|
||||
from common.log import logger
|
||||
from common import const
|
||||
|
||||
"""
|
||||
e.g. [
|
||||
@@ -55,27 +56,33 @@ class ChatGPTSession(Session):
|
||||
# refer to https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
|
||||
def num_tokens_from_messages(messages, model):
|
||||
"""Returns the number of tokens used by a list of messages."""
|
||||
|
||||
if model in ["wenxin", "xunfei", const.GEMINI]:
|
||||
return num_tokens_by_character(messages)
|
||||
|
||||
import tiktoken
|
||||
|
||||
if model == "gpt-3.5-turbo" or model == "gpt-35-turbo":
|
||||
return num_tokens_from_messages(messages, model="gpt-3.5-turbo-0301")
|
||||
elif model == "gpt-4":
|
||||
return num_tokens_from_messages(messages, model="gpt-4-0314")
|
||||
if model in ["gpt-3.5-turbo-0301", "gpt-35-turbo", "gpt-3.5-turbo-1106"]:
|
||||
return num_tokens_from_messages(messages, model="gpt-3.5-turbo")
|
||||
elif model in ["gpt-4-0314", "gpt-4-0613", "gpt-4-32k", "gpt-4-32k-0613", "gpt-3.5-turbo-0613",
|
||||
"gpt-3.5-turbo-16k", "gpt-3.5-turbo-16k-0613", "gpt-35-turbo-16k", "gpt-4-turbo-preview",
|
||||
"gpt-4-1106-preview", const.GPT4_TURBO_PREVIEW, const.GPT4_VISION_PREVIEW]:
|
||||
return num_tokens_from_messages(messages, model="gpt-4")
|
||||
|
||||
try:
|
||||
encoding = tiktoken.encoding_for_model(model)
|
||||
except KeyError:
|
||||
logger.debug("Warning: model not found. Using cl100k_base encoding.")
|
||||
encoding = tiktoken.get_encoding("cl100k_base")
|
||||
if model == "gpt-3.5-turbo-0301":
|
||||
if model == "gpt-3.5-turbo":
|
||||
tokens_per_message = 4 # every message follows <|start|>{role/name}\n{content}<|end|>\n
|
||||
tokens_per_name = -1 # if there's a name, the role is omitted
|
||||
elif model == "gpt-4-0314":
|
||||
elif model == "gpt-4":
|
||||
tokens_per_message = 3
|
||||
tokens_per_name = 1
|
||||
else:
|
||||
logger.warn(f"num_tokens_from_messages() is not implemented for model {model}. Returning num tokens assuming gpt-3.5-turbo-0301.")
|
||||
return num_tokens_from_messages(messages, model="gpt-3.5-turbo-0301")
|
||||
logger.warn(f"num_tokens_from_messages() is not implemented for model {model}. Returning num tokens assuming gpt-3.5-turbo.")
|
||||
return num_tokens_from_messages(messages, model="gpt-3.5-turbo")
|
||||
num_tokens = 0
|
||||
for message in messages:
|
||||
num_tokens += tokens_per_message
|
||||
@@ -85,3 +92,11 @@ def num_tokens_from_messages(messages, model):
|
||||
num_tokens += tokens_per_name
|
||||
num_tokens += 3 # every reply is primed with <|start|>assistant<|message|>
|
||||
return num_tokens
|
||||
|
||||
|
||||
def num_tokens_by_character(messages):
|
||||
"""Returns the number of tokens used by a list of messages."""
|
||||
tokens = 0
|
||||
for msg in messages:
|
||||
tokens += len(msg["content"])
|
||||
return tokens
|
||||
|
||||
222
bot/claude/claude_ai_bot.py
Normal file
222
bot/claude/claude_ai_bot.py
Normal file
@@ -0,0 +1,222 @@
|
||||
import re
|
||||
import time
|
||||
import json
|
||||
import uuid
|
||||
from curl_cffi import requests
|
||||
from bot.bot import Bot
|
||||
from bot.claude.claude_ai_session import ClaudeAiSession
|
||||
from bot.openai.open_ai_image import OpenAIImage
|
||||
from bot.session_manager import SessionManager
|
||||
from bridge.context import Context, ContextType
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from common.log import logger
|
||||
from config import conf
|
||||
|
||||
|
||||
class ClaudeAIBot(Bot, OpenAIImage):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.sessions = SessionManager(ClaudeAiSession, model=conf().get("model") or "gpt-3.5-turbo")
|
||||
self.claude_api_cookie = conf().get("claude_api_cookie")
|
||||
self.proxy = conf().get("proxy")
|
||||
self.con_uuid_dic = {}
|
||||
if self.proxy:
|
||||
self.proxies = {
|
||||
"http": self.proxy,
|
||||
"https": self.proxy
|
||||
}
|
||||
else:
|
||||
self.proxies = None
|
||||
self.error = ""
|
||||
self.org_uuid = self.get_organization_id()
|
||||
|
||||
def generate_uuid(self):
|
||||
random_uuid = uuid.uuid4()
|
||||
random_uuid_str = str(random_uuid)
|
||||
formatted_uuid = f"{random_uuid_str[0:8]}-{random_uuid_str[9:13]}-{random_uuid_str[14:18]}-{random_uuid_str[19:23]}-{random_uuid_str[24:]}"
|
||||
return formatted_uuid
|
||||
|
||||
def reply(self, query, context: Context = None) -> Reply:
|
||||
if context.type == ContextType.TEXT:
|
||||
return self._chat(query, context)
|
||||
elif context.type == ContextType.IMAGE_CREATE:
|
||||
ok, res = self.create_img(query, 0)
|
||||
if ok:
|
||||
reply = Reply(ReplyType.IMAGE_URL, res)
|
||||
else:
|
||||
reply = Reply(ReplyType.ERROR, res)
|
||||
return reply
|
||||
else:
|
||||
reply = Reply(ReplyType.ERROR, "Bot不支持处理{}类型的消息".format(context.type))
|
||||
return reply
|
||||
|
||||
def get_organization_id(self):
|
||||
url = "https://claude.ai/api/organizations"
|
||||
headers = {
|
||||
'User-Agent':
|
||||
'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:109.0) Gecko/20100101 Firefox/115.0',
|
||||
'Accept-Language': 'en-US,en;q=0.5',
|
||||
'Referer': 'https://claude.ai/chats',
|
||||
'Content-Type': 'application/json',
|
||||
'Sec-Fetch-Dest': 'empty',
|
||||
'Sec-Fetch-Mode': 'cors',
|
||||
'Sec-Fetch-Site': 'same-origin',
|
||||
'Connection': 'keep-alive',
|
||||
'Cookie': f'{self.claude_api_cookie}'
|
||||
}
|
||||
try:
|
||||
response = requests.get(url, headers=headers, impersonate="chrome110", proxies =self.proxies, timeout=400)
|
||||
res = json.loads(response.text)
|
||||
uuid = res[0]['uuid']
|
||||
except:
|
||||
if "App unavailable" in response.text:
|
||||
logger.error("IP error: The IP is not allowed to be used on Claude")
|
||||
self.error = "ip所在地区不被claude支持"
|
||||
elif "Invalid authorization" in response.text:
|
||||
logger.error("Cookie error: Invalid authorization of claude, check cookie please.")
|
||||
self.error = "无法通过claude身份验证,请检查cookie"
|
||||
return None
|
||||
return uuid
|
||||
|
||||
def conversation_share_check(self,session_id):
|
||||
if conf().get("claude_uuid") is not None and conf().get("claude_uuid") != "":
|
||||
con_uuid = conf().get("claude_uuid")
|
||||
return con_uuid
|
||||
if session_id not in self.con_uuid_dic:
|
||||
self.con_uuid_dic[session_id] = self.generate_uuid()
|
||||
self.create_new_chat(self.con_uuid_dic[session_id])
|
||||
return self.con_uuid_dic[session_id]
|
||||
|
||||
def check_cookie(self):
|
||||
flag = self.get_organization_id()
|
||||
return flag
|
||||
|
||||
def create_new_chat(self, con_uuid):
|
||||
"""
|
||||
新建claude对话实体
|
||||
:param con_uuid: 对话id
|
||||
:return:
|
||||
"""
|
||||
url = f"https://claude.ai/api/organizations/{self.org_uuid}/chat_conversations"
|
||||
payload = json.dumps({"uuid": con_uuid, "name": ""})
|
||||
headers = {
|
||||
'User-Agent':
|
||||
'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:109.0) Gecko/20100101 Firefox/115.0',
|
||||
'Accept-Language': 'en-US,en;q=0.5',
|
||||
'Referer': 'https://claude.ai/chats',
|
||||
'Content-Type': 'application/json',
|
||||
'Origin': 'https://claude.ai',
|
||||
'DNT': '1',
|
||||
'Connection': 'keep-alive',
|
||||
'Cookie': self.claude_api_cookie,
|
||||
'Sec-Fetch-Dest': 'empty',
|
||||
'Sec-Fetch-Mode': 'cors',
|
||||
'Sec-Fetch-Site': 'same-origin',
|
||||
'TE': 'trailers'
|
||||
}
|
||||
response = requests.post(url, headers=headers, data=payload, impersonate="chrome110", proxies=self.proxies, timeout=400)
|
||||
# Returns JSON of the newly created conversation information
|
||||
return response.json()
|
||||
|
||||
def _chat(self, query, context, retry_count=0) -> Reply:
|
||||
"""
|
||||
发起对话请求
|
||||
:param query: 请求提示词
|
||||
:param context: 对话上下文
|
||||
:param retry_count: 当前递归重试次数
|
||||
:return: 回复
|
||||
"""
|
||||
if retry_count >= 2:
|
||||
# exit from retry 2 times
|
||||
logger.warn("[CLAUDEAI] failed after maximum number of retry times")
|
||||
return Reply(ReplyType.ERROR, "请再问我一次吧")
|
||||
|
||||
try:
|
||||
session_id = context["session_id"]
|
||||
if self.org_uuid is None:
|
||||
return Reply(ReplyType.ERROR, self.error)
|
||||
|
||||
session = self.sessions.session_query(query, session_id)
|
||||
con_uuid = self.conversation_share_check(session_id)
|
||||
|
||||
model = conf().get("model") or "gpt-3.5-turbo"
|
||||
# remove system message
|
||||
if session.messages[0].get("role") == "system":
|
||||
if model == "wenxin" or model == "claude":
|
||||
session.messages.pop(0)
|
||||
logger.info(f"[CLAUDEAI] query={query}")
|
||||
|
||||
# do http request
|
||||
base_url = "https://claude.ai"
|
||||
payload = json.dumps({
|
||||
"completion": {
|
||||
"prompt": f"{query}",
|
||||
"timezone": "Asia/Kolkata",
|
||||
"model": "claude-2"
|
||||
},
|
||||
"organization_uuid": f"{self.org_uuid}",
|
||||
"conversation_uuid": f"{con_uuid}",
|
||||
"text": f"{query}",
|
||||
"attachments": []
|
||||
})
|
||||
headers = {
|
||||
'User-Agent':
|
||||
'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:109.0) Gecko/20100101 Firefox/115.0',
|
||||
'Accept': 'text/event-stream, text/event-stream',
|
||||
'Accept-Language': 'en-US,en;q=0.5',
|
||||
'Referer': 'https://claude.ai/chats',
|
||||
'Content-Type': 'application/json',
|
||||
'Origin': 'https://claude.ai',
|
||||
'DNT': '1',
|
||||
'Connection': 'keep-alive',
|
||||
'Cookie': f'{self.claude_api_cookie}',
|
||||
'Sec-Fetch-Dest': 'empty',
|
||||
'Sec-Fetch-Mode': 'cors',
|
||||
'Sec-Fetch-Site': 'same-origin',
|
||||
'TE': 'trailers'
|
||||
}
|
||||
|
||||
res = requests.post(base_url + "/api/append_message", headers=headers, data=payload,impersonate="chrome110",proxies= self.proxies,timeout=400)
|
||||
if res.status_code == 200 or "pemission" in res.text:
|
||||
# execute success
|
||||
decoded_data = res.content.decode("utf-8")
|
||||
decoded_data = re.sub('\n+', '\n', decoded_data).strip()
|
||||
data_strings = decoded_data.split('\n')
|
||||
completions = []
|
||||
for data_string in data_strings:
|
||||
json_str = data_string[6:].strip()
|
||||
data = json.loads(json_str)
|
||||
if 'completion' in data:
|
||||
completions.append(data['completion'])
|
||||
|
||||
reply_content = ''.join(completions)
|
||||
|
||||
if "rate limi" in reply_content:
|
||||
logger.error("rate limit error: The conversation has reached the system speed limit and is synchronized with Cladue. Please go to the official website to check the lifting time")
|
||||
return Reply(ReplyType.ERROR, "对话达到系统速率限制,与cladue同步,请进入官网查看解除限制时间")
|
||||
logger.info(f"[CLAUDE] reply={reply_content}, total_tokens=invisible")
|
||||
self.sessions.session_reply(reply_content, session_id, 100)
|
||||
return Reply(ReplyType.TEXT, reply_content)
|
||||
else:
|
||||
flag = self.check_cookie()
|
||||
if flag == None:
|
||||
return Reply(ReplyType.ERROR, self.error)
|
||||
|
||||
response = res.json()
|
||||
error = response.get("error")
|
||||
logger.error(f"[CLAUDE] chat failed, status_code={res.status_code}, "
|
||||
f"msg={error.get('message')}, type={error.get('type')}, detail: {res.text}, uuid: {con_uuid}")
|
||||
|
||||
if res.status_code >= 500:
|
||||
# server error, need retry
|
||||
time.sleep(2)
|
||||
logger.warn(f"[CLAUDE] do retry, times={retry_count}")
|
||||
return self._chat(query, context, retry_count + 1)
|
||||
return Reply(ReplyType.ERROR, "提问太快啦,请休息一下再问我吧")
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
# retry
|
||||
time.sleep(2)
|
||||
logger.warn(f"[CLAUDE] do retry, times={retry_count}")
|
||||
return self._chat(query, context, retry_count + 1)
|
||||
9
bot/claude/claude_ai_session.py
Normal file
9
bot/claude/claude_ai_session.py
Normal file
@@ -0,0 +1,9 @@
|
||||
from bot.session_manager import Session
|
||||
|
||||
|
||||
class ClaudeAiSession(Session):
|
||||
def __init__(self, session_id, system_prompt=None, model="claude"):
|
||||
super().__init__(session_id, system_prompt)
|
||||
self.model = model
|
||||
# claude逆向不支持role prompt
|
||||
# self.reset()
|
||||
75
bot/gemini/google_gemini_bot.py
Normal file
75
bot/gemini/google_gemini_bot.py
Normal file
@@ -0,0 +1,75 @@
|
||||
"""
|
||||
Google gemini bot
|
||||
|
||||
@author zhayujie
|
||||
@Date 2023/12/15
|
||||
"""
|
||||
# encoding:utf-8
|
||||
|
||||
from bot.bot import Bot
|
||||
import google.generativeai as genai
|
||||
from bot.session_manager import SessionManager
|
||||
from bridge.context import ContextType, Context
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from common.log import logger
|
||||
from config import conf
|
||||
from bot.baidu.baidu_wenxin_session import BaiduWenxinSession
|
||||
|
||||
|
||||
# OpenAI对话模型API (可用)
|
||||
class GoogleGeminiBot(Bot):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.api_key = conf().get("gemini_api_key")
|
||||
# 复用文心的token计算方式
|
||||
self.sessions = SessionManager(BaiduWenxinSession, model=conf().get("model") or "gpt-3.5-turbo")
|
||||
|
||||
def reply(self, query, context: Context = None) -> Reply:
|
||||
try:
|
||||
if context.type != ContextType.TEXT:
|
||||
logger.warn(f"[Gemini] Unsupported message type, type={context.type}")
|
||||
return Reply(ReplyType.TEXT, None)
|
||||
logger.info(f"[Gemini] query={query}")
|
||||
session_id = context["session_id"]
|
||||
session = self.sessions.session_query(query, session_id)
|
||||
gemini_messages = self._convert_to_gemini_messages(self._filter_messages(session.messages))
|
||||
genai.configure(api_key=self.api_key)
|
||||
model = genai.GenerativeModel('gemini-pro')
|
||||
response = model.generate_content(gemini_messages)
|
||||
reply_text = response.text
|
||||
self.sessions.session_reply(reply_text, session_id)
|
||||
logger.info(f"[Gemini] reply={reply_text}")
|
||||
return Reply(ReplyType.TEXT, reply_text)
|
||||
except Exception as e:
|
||||
logger.error("[Gemini] fetch reply error, may contain unsafe content")
|
||||
logger.error(e)
|
||||
|
||||
def _convert_to_gemini_messages(self, messages: list):
|
||||
res = []
|
||||
for msg in messages:
|
||||
if msg.get("role") == "user":
|
||||
role = "user"
|
||||
elif msg.get("role") == "assistant":
|
||||
role = "model"
|
||||
else:
|
||||
continue
|
||||
res.append({
|
||||
"role": role,
|
||||
"parts": [{"text": msg.get("content")}]
|
||||
})
|
||||
return res
|
||||
|
||||
def _filter_messages(self, messages: list):
|
||||
res = []
|
||||
turn = "user"
|
||||
for i in range(len(messages) - 1, -1, -1):
|
||||
message = messages[i]
|
||||
if message.get("role") != turn:
|
||||
continue
|
||||
res.insert(0, message)
|
||||
if turn == "user":
|
||||
turn = "assistant"
|
||||
elif turn == "assistant":
|
||||
turn = "user"
|
||||
return res
|
||||
@@ -1,50 +1,61 @@
|
||||
# access LinkAI knowledge base platform
|
||||
# docs: https://link-ai.tech/platform/link-app/wechat
|
||||
|
||||
import re
|
||||
import time
|
||||
|
||||
import requests
|
||||
|
||||
import config
|
||||
from bot.bot import Bot
|
||||
from bot.chatgpt.chat_gpt_session import ChatGPTSession
|
||||
from bot.openai.open_ai_image import OpenAIImage
|
||||
from bot.session_manager import SessionManager
|
||||
from bridge.context import Context, ContextType
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from common.log import logger
|
||||
from config import conf
|
||||
from config import conf, pconf
|
||||
import threading
|
||||
from common import memory, utils
|
||||
import base64
|
||||
import os
|
||||
|
||||
|
||||
class LinkAIBot(Bot, OpenAIImage):
|
||||
class LinkAIBot(Bot):
|
||||
# authentication failed
|
||||
AUTH_FAILED_CODE = 401
|
||||
NO_QUOTA_CODE = 406
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.base_url = "https://api.link-ai.chat/v1"
|
||||
self.sessions = SessionManager(ChatGPTSession, model=conf().get("model") or "gpt-3.5-turbo")
|
||||
self.sessions = LinkAISessionManager(LinkAISession, model=conf().get("model") or "gpt-3.5-turbo")
|
||||
self.args = {}
|
||||
|
||||
def reply(self, query, context: Context = None) -> Reply:
|
||||
if context.type == ContextType.TEXT:
|
||||
return self._chat(query, context)
|
||||
elif context.type == ContextType.IMAGE_CREATE:
|
||||
ok, retstring = self.create_img(query, 0)
|
||||
reply = None
|
||||
if not conf().get("text_to_image"):
|
||||
logger.warn("[LinkAI] text_to_image is not enabled, ignore the IMAGE_CREATE request")
|
||||
return Reply(ReplyType.TEXT, "")
|
||||
ok, res = self.create_img(query, 0)
|
||||
if ok:
|
||||
reply = Reply(ReplyType.IMAGE_URL, retstring)
|
||||
reply = Reply(ReplyType.IMAGE_URL, res)
|
||||
else:
|
||||
reply = Reply(ReplyType.ERROR, retstring)
|
||||
reply = Reply(ReplyType.ERROR, res)
|
||||
return reply
|
||||
else:
|
||||
reply = Reply(ReplyType.ERROR, "Bot不支持处理{}类型的消息".format(context.type))
|
||||
return reply
|
||||
|
||||
def _chat(self, query, context, retry_count=0):
|
||||
if retry_count >= 2:
|
||||
def _chat(self, query, context, retry_count=0) -> Reply:
|
||||
"""
|
||||
发起对话请求
|
||||
:param query: 请求提示词
|
||||
:param context: 对话上下文
|
||||
:param retry_count: 当前递归重试次数
|
||||
:return: 回复
|
||||
"""
|
||||
if retry_count > 2:
|
||||
# exit from retry 2 times
|
||||
logger.warn("[LINKAI] failed after maximum number of retry times")
|
||||
return Reply(ReplyType.ERROR, "请再问我一次吧")
|
||||
return Reply(ReplyType.TEXT, "请再问我一次吧")
|
||||
|
||||
try:
|
||||
# load config
|
||||
@@ -52,53 +63,104 @@ class LinkAIBot(Bot, OpenAIImage):
|
||||
logger.info(f"[LINKAI] won't set appcode because a plugin ({context['generate_breaked_by']}) affected the context")
|
||||
app_code = None
|
||||
else:
|
||||
app_code = conf().get("linkai_app_code")
|
||||
plugin_app_code = self._find_group_mapping_code(context)
|
||||
app_code = context.kwargs.get("app_code") or plugin_app_code or conf().get("linkai_app_code")
|
||||
linkai_api_key = conf().get("linkai_api_key")
|
||||
|
||||
session_id = context["session_id"]
|
||||
session_message = self.sessions.session_msg_query(query, session_id)
|
||||
logger.debug(f"[LinkAI] session={session_message}, session_id={session_id}")
|
||||
|
||||
session = self.sessions.session_query(query, session_id)
|
||||
# image process
|
||||
img_cache = memory.USER_IMAGE_CACHE.get(session_id)
|
||||
if img_cache:
|
||||
messages = self._process_image_msg(app_code=app_code, session_id=session_id, query=query, img_cache=img_cache)
|
||||
if messages:
|
||||
session_message = messages
|
||||
|
||||
model = conf().get("model")
|
||||
# remove system message
|
||||
if app_code and session.messages[0].get("role") == "system":
|
||||
session.messages.pop(0)
|
||||
|
||||
logger.info(f"[LINKAI] query={query}, app_code={app_code}")
|
||||
|
||||
if session_message[0].get("role") == "system":
|
||||
if app_code or model == "wenxin":
|
||||
session_message.pop(0)
|
||||
body = {
|
||||
"appCode": app_code,
|
||||
"messages": session.messages,
|
||||
"model": conf().get("model") or "gpt-3.5-turbo", # 对话模型的名称
|
||||
"app_code": app_code,
|
||||
"messages": session_message,
|
||||
"model": model, # 对话模型的名称, 支持 gpt-3.5-turbo, gpt-3.5-turbo-16k, gpt-4, wenxin, xunfei
|
||||
"temperature": conf().get("temperature"),
|
||||
"top_p": conf().get("top_p", 1),
|
||||
"frequency_penalty": conf().get("frequency_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
|
||||
"presence_penalty": conf().get("presence_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
|
||||
"session_id": session_id,
|
||||
"channel_type": conf().get("channel_type", "wx")
|
||||
}
|
||||
try:
|
||||
from linkai import LinkAIClient
|
||||
client_id = LinkAIClient.fetch_client_id()
|
||||
if client_id:
|
||||
body["client_id"] = client_id
|
||||
# start: client info deliver
|
||||
if context.kwargs.get("msg"):
|
||||
body["session_id"] = context.kwargs.get("msg").from_user_id
|
||||
if context.kwargs.get("msg").is_group:
|
||||
body["is_group"] = True
|
||||
body["group_name"] = context.kwargs.get("msg").from_user_nickname
|
||||
body["sender_name"] = context.kwargs.get("msg").actual_user_nickname
|
||||
else:
|
||||
if body.get("channel_type") in ["wechatcom_app"]:
|
||||
body["sender_name"] = context.kwargs.get("msg").from_user_id
|
||||
else:
|
||||
body["sender_name"] = context.kwargs.get("msg").from_user_nickname
|
||||
|
||||
except Exception as e:
|
||||
pass
|
||||
file_id = context.kwargs.get("file_id")
|
||||
if file_id:
|
||||
body["file_id"] = file_id
|
||||
logger.info(f"[LINKAI] query={query}, app_code={app_code}, model={body.get('model')}, file_id={file_id}")
|
||||
headers = {"Authorization": "Bearer " + linkai_api_key}
|
||||
|
||||
# do http request
|
||||
res = requests.post(url=self.base_url + "/chat/completion", json=body, headers=headers).json()
|
||||
base_url = conf().get("linkai_api_base", "https://api.link-ai.chat")
|
||||
res = requests.post(url=base_url + "/v1/chat/completions", json=body, headers=headers,
|
||||
timeout=conf().get("request_timeout", 180))
|
||||
if res.status_code == 200:
|
||||
# execute success
|
||||
response = res.json()
|
||||
reply_content = response["choices"][0]["message"]["content"]
|
||||
total_tokens = response["usage"]["total_tokens"]
|
||||
logger.info(f"[LINKAI] reply={reply_content}, total_tokens={total_tokens}")
|
||||
self.sessions.session_reply(reply_content, session_id, total_tokens, query=query)
|
||||
|
||||
agent_suffix = self._fetch_agent_suffix(response)
|
||||
if agent_suffix:
|
||||
reply_content += agent_suffix
|
||||
if not agent_suffix:
|
||||
knowledge_suffix = self._fetch_knowledge_search_suffix(response)
|
||||
if knowledge_suffix:
|
||||
reply_content += knowledge_suffix
|
||||
# image process
|
||||
if response["choices"][0].get("img_urls"):
|
||||
thread = threading.Thread(target=self._send_image, args=(context.get("channel"), context, response["choices"][0].get("img_urls")))
|
||||
thread.start()
|
||||
if response["choices"][0].get("text_content"):
|
||||
reply_content = response["choices"][0].get("text_content")
|
||||
reply_content = self._process_url(reply_content)
|
||||
return Reply(ReplyType.TEXT, reply_content)
|
||||
|
||||
if not res or not res["success"]:
|
||||
if res.get("code") == self.AUTH_FAILED_CODE:
|
||||
logger.exception(f"[LINKAI] please check your linkai_api_key, res={res}")
|
||||
return Reply(ReplyType.ERROR, "请再问我一次吧")
|
||||
else:
|
||||
response = res.json()
|
||||
error = response.get("error")
|
||||
logger.error(f"[LINKAI] chat failed, status_code={res.status_code}, "
|
||||
f"msg={error.get('message')}, type={error.get('type')}")
|
||||
|
||||
elif res.get("code") == self.NO_QUOTA_CODE:
|
||||
logger.exception(f"[LINKAI] please check your account quota, https://chat.link-ai.tech/console/account")
|
||||
return Reply(ReplyType.ERROR, "提问太快啦,请休息一下再问我吧")
|
||||
|
||||
else:
|
||||
# retry
|
||||
if res.status_code >= 500:
|
||||
# server error, need retry
|
||||
time.sleep(2)
|
||||
logger.warn(f"[LINKAI] do retry, times={retry_count}")
|
||||
return self._chat(query, context, retry_count + 1)
|
||||
|
||||
# execute success
|
||||
reply_content = res["data"]["content"]
|
||||
logger.info(f"[LINKAI] reply={reply_content}")
|
||||
self.sessions.session_reply(reply_content, session_id)
|
||||
return Reply(ReplyType.TEXT, reply_content)
|
||||
return Reply(ReplyType.TEXT, "提问太快啦,请休息一下再问我吧")
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
@@ -106,3 +168,300 @@ class LinkAIBot(Bot, OpenAIImage):
|
||||
time.sleep(2)
|
||||
logger.warn(f"[LINKAI] do retry, times={retry_count}")
|
||||
return self._chat(query, context, retry_count + 1)
|
||||
|
||||
def _process_image_msg(self, app_code: str, session_id: str, query:str, img_cache: dict):
|
||||
try:
|
||||
enable_image_input = False
|
||||
app_info = self._fetch_app_info(app_code)
|
||||
if not app_info:
|
||||
logger.debug(f"[LinkAI] not found app, can't process images, app_code={app_code}")
|
||||
return None
|
||||
plugins = app_info.get("data").get("plugins")
|
||||
for plugin in plugins:
|
||||
if plugin.get("input_type") and "IMAGE" in plugin.get("input_type"):
|
||||
enable_image_input = True
|
||||
if not enable_image_input:
|
||||
return
|
||||
msg = img_cache.get("msg")
|
||||
path = img_cache.get("path")
|
||||
msg.prepare()
|
||||
logger.info(f"[LinkAI] query with images, path={path}")
|
||||
messages = self._build_vision_msg(query, path)
|
||||
memory.USER_IMAGE_CACHE[session_id] = None
|
||||
return messages
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
|
||||
def _find_group_mapping_code(self, context):
|
||||
try:
|
||||
if context.kwargs.get("isgroup"):
|
||||
group_name = context.kwargs.get("msg").from_user_nickname
|
||||
if config.plugin_config and config.plugin_config.get("linkai"):
|
||||
linkai_config = config.plugin_config.get("linkai")
|
||||
group_mapping = linkai_config.get("group_app_map")
|
||||
if group_mapping and group_name:
|
||||
return group_mapping.get(group_name)
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
return None
|
||||
|
||||
def _build_vision_msg(self, query: str, path: str):
|
||||
try:
|
||||
suffix = utils.get_path_suffix(path)
|
||||
with open(path, "rb") as file:
|
||||
base64_str = base64.b64encode(file.read()).decode('utf-8')
|
||||
messages = [{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": query
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/{suffix};base64,{base64_str}"
|
||||
}
|
||||
}
|
||||
]
|
||||
}]
|
||||
return messages
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
|
||||
def reply_text(self, session: ChatGPTSession, app_code="", retry_count=0) -> dict:
|
||||
if retry_count >= 2:
|
||||
# exit from retry 2 times
|
||||
logger.warn("[LINKAI] failed after maximum number of retry times")
|
||||
return {
|
||||
"total_tokens": 0,
|
||||
"completion_tokens": 0,
|
||||
"content": "请再问我一次吧"
|
||||
}
|
||||
|
||||
try:
|
||||
body = {
|
||||
"app_code": app_code,
|
||||
"messages": session.messages,
|
||||
"model": conf().get("model") or "gpt-3.5-turbo", # 对话模型的名称, 支持 gpt-3.5-turbo, gpt-3.5-turbo-16k, gpt-4, wenxin, xunfei
|
||||
"temperature": conf().get("temperature"),
|
||||
"top_p": conf().get("top_p", 1),
|
||||
"frequency_penalty": conf().get("frequency_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
|
||||
"presence_penalty": conf().get("presence_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
|
||||
}
|
||||
if self.args.get("max_tokens"):
|
||||
body["max_tokens"] = self.args.get("max_tokens")
|
||||
headers = {"Authorization": "Bearer " + conf().get("linkai_api_key")}
|
||||
|
||||
# do http request
|
||||
base_url = conf().get("linkai_api_base", "https://api.link-ai.chat")
|
||||
res = requests.post(url=base_url + "/v1/chat/completions", json=body, headers=headers,
|
||||
timeout=conf().get("request_timeout", 180))
|
||||
if res.status_code == 200:
|
||||
# execute success
|
||||
response = res.json()
|
||||
reply_content = response["choices"][0]["message"]["content"]
|
||||
total_tokens = response["usage"]["total_tokens"]
|
||||
logger.info(f"[LINKAI] reply={reply_content}, total_tokens={total_tokens}")
|
||||
return {
|
||||
"total_tokens": total_tokens,
|
||||
"completion_tokens": response["usage"]["completion_tokens"],
|
||||
"content": reply_content,
|
||||
}
|
||||
|
||||
else:
|
||||
response = res.json()
|
||||
error = response.get("error")
|
||||
logger.error(f"[LINKAI] chat failed, status_code={res.status_code}, "
|
||||
f"msg={error.get('message')}, type={error.get('type')}")
|
||||
|
||||
if res.status_code >= 500:
|
||||
# server error, need retry
|
||||
time.sleep(2)
|
||||
logger.warn(f"[LINKAI] do retry, times={retry_count}")
|
||||
return self.reply_text(session, app_code, retry_count + 1)
|
||||
|
||||
return {
|
||||
"total_tokens": 0,
|
||||
"completion_tokens": 0,
|
||||
"content": "提问太快啦,请休息一下再问我吧"
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
# retry
|
||||
time.sleep(2)
|
||||
logger.warn(f"[LINKAI] do retry, times={retry_count}")
|
||||
return self.reply_text(session, app_code, retry_count + 1)
|
||||
|
||||
def _fetch_app_info(self, app_code: str):
|
||||
headers = {"Authorization": "Bearer " + conf().get("linkai_api_key")}
|
||||
# do http request
|
||||
base_url = conf().get("linkai_api_base", "https://api.link-ai.chat")
|
||||
params = {"app_code": app_code}
|
||||
res = requests.get(url=base_url + "/v1/app/info", params=params, headers=headers, timeout=(5, 10))
|
||||
if res.status_code == 200:
|
||||
return res.json()
|
||||
else:
|
||||
logger.warning(f"[LinkAI] find app info exception, res={res}")
|
||||
|
||||
def create_img(self, query, retry_count=0, api_key=None):
|
||||
try:
|
||||
logger.info("[LinkImage] image_query={}".format(query))
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {conf().get('linkai_api_key')}"
|
||||
}
|
||||
data = {
|
||||
"prompt": query,
|
||||
"n": 1,
|
||||
"model": conf().get("text_to_image") or "dall-e-2",
|
||||
"response_format": "url",
|
||||
"img_proxy": conf().get("image_proxy")
|
||||
}
|
||||
url = conf().get("linkai_api_base", "https://api.link-ai.chat") + "/v1/images/generations"
|
||||
res = requests.post(url, headers=headers, json=data, timeout=(5, 90))
|
||||
t2 = time.time()
|
||||
image_url = res.json()["data"][0]["url"]
|
||||
logger.info("[OPEN_AI] image_url={}".format(image_url))
|
||||
return True, image_url
|
||||
|
||||
except Exception as e:
|
||||
logger.error(format(e))
|
||||
return False, "画图出现问题,请休息一下再问我吧"
|
||||
|
||||
|
||||
def _fetch_knowledge_search_suffix(self, response) -> str:
|
||||
try:
|
||||
if response.get("knowledge_base"):
|
||||
search_hit = response.get("knowledge_base").get("search_hit")
|
||||
first_similarity = response.get("knowledge_base").get("first_similarity")
|
||||
logger.info(f"[LINKAI] knowledge base, search_hit={search_hit}, first_similarity={first_similarity}")
|
||||
plugin_config = pconf("linkai")
|
||||
if plugin_config and plugin_config.get("knowledge_base") and plugin_config.get("knowledge_base").get("search_miss_text_enabled"):
|
||||
search_miss_similarity = plugin_config.get("knowledge_base").get("search_miss_similarity")
|
||||
search_miss_text = plugin_config.get("knowledge_base").get("search_miss_suffix")
|
||||
if not search_hit:
|
||||
return search_miss_text
|
||||
if search_miss_similarity and float(search_miss_similarity) > first_similarity:
|
||||
return search_miss_text
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
|
||||
|
||||
def _fetch_agent_suffix(self, response):
|
||||
try:
|
||||
plugin_list = []
|
||||
logger.debug(f"[LinkAgent] res={response}")
|
||||
if response.get("agent") and response.get("agent").get("chain") and response.get("agent").get("need_show_plugin"):
|
||||
chain = response.get("agent").get("chain")
|
||||
suffix = "\n\n- - - - - - - - - - - -"
|
||||
i = 0
|
||||
for turn in chain:
|
||||
plugin_name = turn.get('plugin_name')
|
||||
suffix += "\n"
|
||||
need_show_thought = response.get("agent").get("need_show_thought")
|
||||
if turn.get("thought") and plugin_name and need_show_thought:
|
||||
suffix += f"{turn.get('thought')}\n"
|
||||
if plugin_name:
|
||||
plugin_list.append(turn.get('plugin_name'))
|
||||
if turn.get('plugin_icon'):
|
||||
suffix += f"{turn.get('plugin_icon')} "
|
||||
suffix += f"{turn.get('plugin_name')}"
|
||||
if turn.get('plugin_input'):
|
||||
suffix += f":{turn.get('plugin_input')}"
|
||||
if i < len(chain) - 1:
|
||||
suffix += "\n"
|
||||
i += 1
|
||||
logger.info(f"[LinkAgent] use plugins: {plugin_list}")
|
||||
return suffix
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
|
||||
def _process_url(self, text):
|
||||
try:
|
||||
url_pattern = re.compile(r'\[(.*?)\]\((http[s]?://.*?)\)')
|
||||
def replace_markdown_url(match):
|
||||
return f"{match.group(2)}"
|
||||
return url_pattern.sub(replace_markdown_url, text)
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
|
||||
def _send_image(self, channel, context, image_urls):
|
||||
if not image_urls:
|
||||
return
|
||||
max_send_num = conf().get("max_media_send_count")
|
||||
send_interval = conf().get("media_send_interval")
|
||||
try:
|
||||
i = 0
|
||||
for url in image_urls:
|
||||
if max_send_num and i >= max_send_num:
|
||||
continue
|
||||
i += 1
|
||||
if url.endswith(".mp4"):
|
||||
reply_type = ReplyType.VIDEO_URL
|
||||
elif url.endswith(".pdf") or url.endswith(".doc") or url.endswith(".docx") or url.endswith(".csv"):
|
||||
reply_type = ReplyType.FILE
|
||||
url = _download_file(url)
|
||||
if not url:
|
||||
continue
|
||||
else:
|
||||
reply_type = ReplyType.IMAGE_URL
|
||||
reply = Reply(reply_type, url)
|
||||
channel.send(reply, context)
|
||||
if send_interval:
|
||||
time.sleep(send_interval)
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
|
||||
|
||||
def _download_file(url: str):
|
||||
try:
|
||||
file_path = "tmp"
|
||||
if not os.path.exists(file_path):
|
||||
os.makedirs(file_path)
|
||||
file_name = url.split("/")[-1] # 获取文件名
|
||||
file_path = os.path.join(file_path, file_name)
|
||||
response = requests.get(url)
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(response.content)
|
||||
return file_path
|
||||
except Exception as e:
|
||||
logger.warn(e)
|
||||
|
||||
|
||||
class LinkAISessionManager(SessionManager):
|
||||
def session_msg_query(self, query, session_id):
|
||||
session = self.build_session(session_id)
|
||||
messages = session.messages + [{"role": "user", "content": query}]
|
||||
return messages
|
||||
|
||||
def session_reply(self, reply, session_id, total_tokens=None, query=None):
|
||||
session = self.build_session(session_id)
|
||||
if query:
|
||||
session.add_query(query)
|
||||
session.add_reply(reply)
|
||||
try:
|
||||
max_tokens = conf().get("conversation_max_tokens", 2500)
|
||||
tokens_cnt = session.discard_exceeding(max_tokens, total_tokens)
|
||||
logger.debug(f"[LinkAI] chat history, before tokens={total_tokens}, now tokens={tokens_cnt}")
|
||||
except Exception as e:
|
||||
logger.warning("Exception when counting tokens precisely for session: {}".format(str(e)))
|
||||
return session
|
||||
|
||||
|
||||
class LinkAISession(ChatGPTSession):
|
||||
def calc_tokens(self):
|
||||
if not self.messages:
|
||||
return 0
|
||||
return len(str(self.messages))
|
||||
|
||||
def discard_exceeding(self, max_tokens, cur_tokens=None):
|
||||
cur_tokens = self.calc_tokens()
|
||||
if cur_tokens > max_tokens:
|
||||
for i in range(0, len(self.messages)):
|
||||
if i > 0 and self.messages[i].get("role") == "assistant" and self.messages[i - 1].get("role") == "user":
|
||||
self.messages.pop(i)
|
||||
self.messages.pop(i - 1)
|
||||
return self.calc_tokens()
|
||||
return cur_tokens
|
||||
|
||||
@@ -15,7 +15,7 @@ class OpenAIImage(object):
|
||||
if conf().get("rate_limit_dalle"):
|
||||
self.tb4dalle = TokenBucket(conf().get("rate_limit_dalle", 50))
|
||||
|
||||
def create_img(self, query, retry_count=0, api_key=None):
|
||||
def create_img(self, query, retry_count=0, api_key=None, api_base=None):
|
||||
try:
|
||||
if conf().get("rate_limit_dalle") and not self.tb4dalle.get_token():
|
||||
return False, "请求太快了,请休息一下再问我吧"
|
||||
@@ -24,7 +24,8 @@ class OpenAIImage(object):
|
||||
api_key=api_key,
|
||||
prompt=query, # 图片描述
|
||||
n=1, # 每次生成图片的数量
|
||||
size=conf().get("image_create_size", "256x256"), # 图片大小,可选有 256x256, 512x512, 1024x1024
|
||||
model=conf().get("text_to_image") or "dall-e-2",
|
||||
# size=conf().get("image_create_size", "256x256"), # 图片大小,可选有 256x256, 512x512, 1024x1024
|
||||
)
|
||||
image_url = response["data"][0]["url"]
|
||||
logger.info("[OPEN_AI] image_url={}".format(image_url))
|
||||
@@ -36,7 +37,7 @@ class OpenAIImage(object):
|
||||
logger.warn("[OPEN_AI] ImgCreate RateLimit exceed, 第{}次重试".format(retry_count + 1))
|
||||
return self.create_img(query, retry_count + 1)
|
||||
else:
|
||||
return False, "提问太快啦,请休息一下再问我吧"
|
||||
return False, "画图出现问题,请休息一下再问我吧"
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
return False, str(e)
|
||||
return False, "画图出现问题,请休息一下再问我吧"
|
||||
|
||||
@@ -69,7 +69,7 @@ class SessionManager(object):
|
||||
total_tokens = session.discard_exceeding(max_tokens, None)
|
||||
logger.debug("prompt tokens used={}".format(total_tokens))
|
||||
except Exception as e:
|
||||
logger.debug("Exception when counting tokens precisely for prompt: {}".format(str(e)))
|
||||
logger.warning("Exception when counting tokens precisely for prompt: {}".format(str(e)))
|
||||
return session
|
||||
|
||||
def session_reply(self, reply, session_id, total_tokens=None):
|
||||
@@ -80,7 +80,7 @@ class SessionManager(object):
|
||||
tokens_cnt = session.discard_exceeding(max_tokens, total_tokens)
|
||||
logger.debug("raw total_tokens={}, savesession tokens={}".format(total_tokens, tokens_cnt))
|
||||
except Exception as e:
|
||||
logger.debug("Exception when counting tokens precisely for session: {}".format(str(e)))
|
||||
logger.warning("Exception when counting tokens precisely for session: {}".format(str(e)))
|
||||
return session
|
||||
|
||||
def clear_session(self, session_id):
|
||||
|
||||
267
bot/xunfei/xunfei_spark_bot.py
Normal file
267
bot/xunfei/xunfei_spark_bot.py
Normal file
@@ -0,0 +1,267 @@
|
||||
# encoding:utf-8
|
||||
|
||||
import requests, json
|
||||
from bot.bot import Bot
|
||||
from bot.session_manager import SessionManager
|
||||
from bot.baidu.baidu_wenxin_session import BaiduWenxinSession
|
||||
from bridge.context import ContextType, Context
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from common.log import logger
|
||||
from config import conf
|
||||
from common import const
|
||||
import time
|
||||
import _thread as thread
|
||||
import datetime
|
||||
from datetime import datetime
|
||||
from wsgiref.handlers import format_date_time
|
||||
from urllib.parse import urlencode
|
||||
import base64
|
||||
import ssl
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
from time import mktime
|
||||
from urllib.parse import urlparse
|
||||
import websocket
|
||||
import queue
|
||||
import threading
|
||||
import random
|
||||
|
||||
# 消息队列 map
|
||||
queue_map = dict()
|
||||
|
||||
# 响应队列 map
|
||||
reply_map = dict()
|
||||
|
||||
|
||||
class XunFeiBot(Bot):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.app_id = conf().get("xunfei_app_id")
|
||||
self.api_key = conf().get("xunfei_api_key")
|
||||
self.api_secret = conf().get("xunfei_api_secret")
|
||||
# 默认使用v2.0版本: "generalv2"
|
||||
# v1.5版本为 "general"
|
||||
# v3.0版本为: "generalv3"
|
||||
self.domain = "generalv3"
|
||||
# 默认使用v2.0版本: "ws://spark-api.xf-yun.com/v2.1/chat"
|
||||
# v1.5版本为: "ws://spark-api.xf-yun.com/v1.1/chat"
|
||||
# v3.0版本为: "ws://spark-api.xf-yun.com/v3.1/chat"
|
||||
self.spark_url = "ws://spark-api.xf-yun.com/v3.1/chat"
|
||||
self.host = urlparse(self.spark_url).netloc
|
||||
self.path = urlparse(self.spark_url).path
|
||||
# 和wenxin使用相同的session机制
|
||||
self.sessions = SessionManager(BaiduWenxinSession, model=const.XUNFEI)
|
||||
|
||||
def reply(self, query, context: Context = None) -> Reply:
|
||||
if context.type == ContextType.TEXT:
|
||||
logger.info("[XunFei] query={}".format(query))
|
||||
session_id = context["session_id"]
|
||||
request_id = self.gen_request_id(session_id)
|
||||
reply_map[request_id] = ""
|
||||
session = self.sessions.session_query(query, session_id)
|
||||
threading.Thread(target=self.create_web_socket,
|
||||
args=(session.messages, request_id)).start()
|
||||
depth = 0
|
||||
time.sleep(0.1)
|
||||
t1 = time.time()
|
||||
usage = {}
|
||||
while depth <= 300:
|
||||
try:
|
||||
data_queue = queue_map.get(request_id)
|
||||
if not data_queue:
|
||||
depth += 1
|
||||
time.sleep(0.1)
|
||||
continue
|
||||
data_item = data_queue.get(block=True, timeout=0.1)
|
||||
if data_item.is_end:
|
||||
# 请求结束
|
||||
del queue_map[request_id]
|
||||
if data_item.reply:
|
||||
reply_map[request_id] += data_item.reply
|
||||
usage = data_item.usage
|
||||
break
|
||||
|
||||
reply_map[request_id] += data_item.reply
|
||||
depth += 1
|
||||
except Exception as e:
|
||||
depth += 1
|
||||
continue
|
||||
t2 = time.time()
|
||||
logger.info(
|
||||
f"[XunFei-API] response={reply_map[request_id]}, time={t2 - t1}s, usage={usage}"
|
||||
)
|
||||
self.sessions.session_reply(reply_map[request_id], session_id,
|
||||
usage.get("total_tokens"))
|
||||
reply = Reply(ReplyType.TEXT, reply_map[request_id])
|
||||
del reply_map[request_id]
|
||||
return reply
|
||||
else:
|
||||
reply = Reply(ReplyType.ERROR,
|
||||
"Bot不支持处理{}类型的消息".format(context.type))
|
||||
return reply
|
||||
|
||||
def create_web_socket(self, prompt, session_id, temperature=0.5):
|
||||
logger.info(f"[XunFei] start connect, prompt={prompt}")
|
||||
websocket.enableTrace(False)
|
||||
wsUrl = self.create_url()
|
||||
ws = websocket.WebSocketApp(wsUrl,
|
||||
on_message=on_message,
|
||||
on_error=on_error,
|
||||
on_close=on_close,
|
||||
on_open=on_open)
|
||||
data_queue = queue.Queue(1000)
|
||||
queue_map[session_id] = data_queue
|
||||
ws.appid = self.app_id
|
||||
ws.question = prompt
|
||||
ws.domain = self.domain
|
||||
ws.session_id = session_id
|
||||
ws.temperature = temperature
|
||||
ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
|
||||
|
||||
def gen_request_id(self, session_id: str):
|
||||
return session_id + "_" + str(int(time.time())) + "" + str(
|
||||
random.randint(0, 100))
|
||||
|
||||
# 生成url
|
||||
def create_url(self):
|
||||
# 生成RFC1123格式的时间戳
|
||||
now = datetime.now()
|
||||
date = format_date_time(mktime(now.timetuple()))
|
||||
|
||||
# 拼接字符串
|
||||
signature_origin = "host: " + self.host + "\n"
|
||||
signature_origin += "date: " + date + "\n"
|
||||
signature_origin += "GET " + self.path + " HTTP/1.1"
|
||||
|
||||
# 进行hmac-sha256进行加密
|
||||
signature_sha = hmac.new(self.api_secret.encode('utf-8'),
|
||||
signature_origin.encode('utf-8'),
|
||||
digestmod=hashlib.sha256).digest()
|
||||
|
||||
signature_sha_base64 = base64.b64encode(signature_sha).decode(
|
||||
encoding='utf-8')
|
||||
|
||||
authorization_origin = f'api_key="{self.api_key}", algorithm="hmac-sha256", headers="host date request-line", ' \
|
||||
f'signature="{signature_sha_base64}"'
|
||||
|
||||
authorization = base64.b64encode(
|
||||
authorization_origin.encode('utf-8')).decode(encoding='utf-8')
|
||||
|
||||
# 将请求的鉴权参数组合为字典
|
||||
v = {"authorization": authorization, "date": date, "host": self.host}
|
||||
# 拼接鉴权参数,生成url
|
||||
url = self.spark_url + '?' + urlencode(v)
|
||||
# 此处打印出建立连接时候的url,参考本demo的时候可取消上方打印的注释,比对相同参数时生成的url与自己代码生成的url是否一致
|
||||
return url
|
||||
|
||||
def gen_params(self, appid, domain, question):
|
||||
"""
|
||||
通过appid和用户的提问来生成请参数
|
||||
"""
|
||||
data = {
|
||||
"header": {
|
||||
"app_id": appid,
|
||||
"uid": "1234"
|
||||
},
|
||||
"parameter": {
|
||||
"chat": {
|
||||
"domain": domain,
|
||||
"random_threshold": 0.5,
|
||||
"max_tokens": 2048,
|
||||
"auditing": "default"
|
||||
}
|
||||
},
|
||||
"payload": {
|
||||
"message": {
|
||||
"text": question
|
||||
}
|
||||
}
|
||||
}
|
||||
return data
|
||||
|
||||
|
||||
class ReplyItem:
|
||||
def __init__(self, reply, usage=None, is_end=False):
|
||||
self.is_end = is_end
|
||||
self.reply = reply
|
||||
self.usage = usage
|
||||
|
||||
|
||||
# 收到websocket错误的处理
|
||||
def on_error(ws, error):
|
||||
logger.error(f"[XunFei] error: {str(error)}")
|
||||
|
||||
|
||||
# 收到websocket关闭的处理
|
||||
def on_close(ws, one, two):
|
||||
data_queue = queue_map.get(ws.session_id)
|
||||
data_queue.put("END")
|
||||
|
||||
|
||||
# 收到websocket连接建立的处理
|
||||
def on_open(ws):
|
||||
logger.info(f"[XunFei] Start websocket, session_id={ws.session_id}")
|
||||
thread.start_new_thread(run, (ws, ))
|
||||
|
||||
|
||||
def run(ws, *args):
|
||||
data = json.dumps(
|
||||
gen_params(appid=ws.appid,
|
||||
domain=ws.domain,
|
||||
question=ws.question,
|
||||
temperature=ws.temperature))
|
||||
ws.send(data)
|
||||
|
||||
|
||||
# Websocket 操作
|
||||
# 收到websocket消息的处理
|
||||
def on_message(ws, message):
|
||||
data = json.loads(message)
|
||||
code = data['header']['code']
|
||||
if code != 0:
|
||||
logger.error(f'请求错误: {code}, {data}')
|
||||
ws.close()
|
||||
else:
|
||||
choices = data["payload"]["choices"]
|
||||
status = choices["status"]
|
||||
content = choices["text"][0]["content"]
|
||||
data_queue = queue_map.get(ws.session_id)
|
||||
if not data_queue:
|
||||
logger.error(
|
||||
f"[XunFei] can't find data queue, session_id={ws.session_id}")
|
||||
return
|
||||
reply_item = ReplyItem(content)
|
||||
if status == 2:
|
||||
usage = data["payload"].get("usage")
|
||||
reply_item = ReplyItem(content, usage)
|
||||
reply_item.is_end = True
|
||||
ws.close()
|
||||
data_queue.put(reply_item)
|
||||
|
||||
|
||||
def gen_params(appid, domain, question, temperature=0.5):
|
||||
"""
|
||||
通过appid和用户的提问来生成请参数
|
||||
"""
|
||||
data = {
|
||||
"header": {
|
||||
"app_id": appid,
|
||||
"uid": "1234"
|
||||
},
|
||||
"parameter": {
|
||||
"chat": {
|
||||
"domain": domain,
|
||||
"temperature": temperature,
|
||||
"random_threshold": 0.5,
|
||||
"max_tokens": 2048,
|
||||
"auditing": "default"
|
||||
}
|
||||
},
|
||||
"payload": {
|
||||
"message": {
|
||||
"text": question
|
||||
}
|
||||
}
|
||||
}
|
||||
return data
|
||||
29
bot/zhipuai/zhipu_ai_image.py
Normal file
29
bot/zhipuai/zhipu_ai_image.py
Normal file
@@ -0,0 +1,29 @@
|
||||
from common.log import logger
|
||||
from config import conf
|
||||
|
||||
|
||||
# ZhipuAI提供的画图接口
|
||||
|
||||
class ZhipuAIImage(object):
|
||||
def __init__(self):
|
||||
from zhipuai import ZhipuAI
|
||||
self.client = ZhipuAI(api_key=conf().get("zhipu_ai_api_key"))
|
||||
|
||||
def create_img(self, query, retry_count=0, api_key=None, api_base=None):
|
||||
try:
|
||||
if conf().get("rate_limit_dalle"):
|
||||
return False, "请求太快了,请休息一下再问我吧"
|
||||
logger.info("[ZHIPU_AI] image_query={}".format(query))
|
||||
response = self.client.images.generations(
|
||||
prompt=query,
|
||||
n=1, # 每次生成图片的数量
|
||||
model=conf().get("text_to_image") or "cogview-3",
|
||||
size=conf().get("image_create_size", "1024x1024"), # 图片大小,可选有 256x256, 512x512, 1024x1024
|
||||
quality="standard",
|
||||
)
|
||||
image_url = response.data[0].url
|
||||
logger.info("[ZHIPU_AI] image_url={}".format(image_url))
|
||||
return True, image_url
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
return False, "画图出现问题,请休息一下再问我吧"
|
||||
51
bot/zhipuai/zhipu_ai_session.py
Normal file
51
bot/zhipuai/zhipu_ai_session.py
Normal file
@@ -0,0 +1,51 @@
|
||||
from bot.session_manager import Session
|
||||
from common.log import logger
|
||||
|
||||
|
||||
class ZhipuAISession(Session):
|
||||
def __init__(self, session_id, system_prompt=None, model="glm-4"):
|
||||
super().__init__(session_id, system_prompt)
|
||||
self.model = model
|
||||
self.reset()
|
||||
|
||||
def discard_exceeding(self, max_tokens, cur_tokens=None):
|
||||
precise = True
|
||||
try:
|
||||
cur_tokens = self.calc_tokens()
|
||||
except Exception as e:
|
||||
precise = False
|
||||
if cur_tokens is None:
|
||||
raise e
|
||||
logger.debug("Exception when counting tokens precisely for query: {}".format(e))
|
||||
while cur_tokens > max_tokens:
|
||||
if len(self.messages) > 2:
|
||||
self.messages.pop(1)
|
||||
elif len(self.messages) == 2 and self.messages[1]["role"] == "assistant":
|
||||
self.messages.pop(1)
|
||||
if precise:
|
||||
cur_tokens = self.calc_tokens()
|
||||
else:
|
||||
cur_tokens = cur_tokens - max_tokens
|
||||
break
|
||||
elif len(self.messages) == 2 and self.messages[1]["role"] == "user":
|
||||
logger.warn("user message exceed max_tokens. total_tokens={}".format(cur_tokens))
|
||||
break
|
||||
else:
|
||||
logger.debug("max_tokens={}, total_tokens={}, len(messages)={}".format(max_tokens, cur_tokens,
|
||||
len(self.messages)))
|
||||
break
|
||||
if precise:
|
||||
cur_tokens = self.calc_tokens()
|
||||
else:
|
||||
cur_tokens = cur_tokens - max_tokens
|
||||
return cur_tokens
|
||||
|
||||
def calc_tokens(self):
|
||||
return num_tokens_from_messages(self.messages, self.model)
|
||||
|
||||
|
||||
def num_tokens_from_messages(messages, model):
|
||||
tokens = 0
|
||||
for msg in messages:
|
||||
tokens += len(msg["content"])
|
||||
return tokens
|
||||
149
bot/zhipuai/zhipuai_bot.py
Normal file
149
bot/zhipuai/zhipuai_bot.py
Normal file
@@ -0,0 +1,149 @@
|
||||
# encoding:utf-8
|
||||
|
||||
import time
|
||||
|
||||
import openai
|
||||
import openai.error
|
||||
from bot.bot import Bot
|
||||
from bot.zhipuai.zhipu_ai_session import ZhipuAISession
|
||||
from bot.zhipuai.zhipu_ai_image import ZhipuAIImage
|
||||
from bot.session_manager import SessionManager
|
||||
from bridge.context import ContextType
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from common.log import logger
|
||||
from config import conf, load_config
|
||||
from zhipuai import ZhipuAI
|
||||
|
||||
|
||||
# ZhipuAI对话模型API
|
||||
class ZHIPUAIBot(Bot, ZhipuAIImage):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.sessions = SessionManager(ZhipuAISession, model=conf().get("model") or "ZHIPU_AI")
|
||||
self.args = {
|
||||
"model": conf().get("model") or "glm-4", # 对话模型的名称
|
||||
"temperature": conf().get("temperature", 0.9), # 值在(0,1)之间(智谱AI 的温度不能取 0 或者 1)
|
||||
"top_p": conf().get("top_p", 0.7), # 值在(0,1)之间(智谱AI 的 top_p 不能取 0 或者 1)
|
||||
}
|
||||
self.client = ZhipuAI(api_key=conf().get("zhipu_ai_api_key"))
|
||||
|
||||
def reply(self, query, context=None):
|
||||
# acquire reply content
|
||||
if context.type == ContextType.TEXT:
|
||||
logger.info("[ZHIPU_AI] query={}".format(query))
|
||||
|
||||
session_id = context["session_id"]
|
||||
reply = None
|
||||
clear_memory_commands = conf().get("clear_memory_commands", ["#清除记忆"])
|
||||
if query in clear_memory_commands:
|
||||
self.sessions.clear_session(session_id)
|
||||
reply = Reply(ReplyType.INFO, "记忆已清除")
|
||||
elif query == "#清除所有":
|
||||
self.sessions.clear_all_session()
|
||||
reply = Reply(ReplyType.INFO, "所有人记忆已清除")
|
||||
elif query == "#更新配置":
|
||||
load_config()
|
||||
reply = Reply(ReplyType.INFO, "配置已更新")
|
||||
if reply:
|
||||
return reply
|
||||
session = self.sessions.session_query(query, session_id)
|
||||
logger.debug("[ZHIPU_AI] session query={}".format(session.messages))
|
||||
|
||||
api_key = context.get("openai_api_key") or openai.api_key
|
||||
model = context.get("gpt_model")
|
||||
new_args = None
|
||||
if model:
|
||||
new_args = self.args.copy()
|
||||
new_args["model"] = model
|
||||
# if context.get('stream'):
|
||||
# # reply in stream
|
||||
# return self.reply_text_stream(query, new_query, session_id)
|
||||
|
||||
reply_content = self.reply_text(session, api_key, args=new_args)
|
||||
logger.debug(
|
||||
"[ZHIPU_AI] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(
|
||||
session.messages,
|
||||
session_id,
|
||||
reply_content["content"],
|
||||
reply_content["completion_tokens"],
|
||||
)
|
||||
)
|
||||
if reply_content["completion_tokens"] == 0 and len(reply_content["content"]) > 0:
|
||||
reply = Reply(ReplyType.ERROR, reply_content["content"])
|
||||
elif reply_content["completion_tokens"] > 0:
|
||||
self.sessions.session_reply(reply_content["content"], session_id, reply_content["total_tokens"])
|
||||
reply = Reply(ReplyType.TEXT, reply_content["content"])
|
||||
else:
|
||||
reply = Reply(ReplyType.ERROR, reply_content["content"])
|
||||
logger.debug("[ZHIPU_AI] reply {} used 0 tokens.".format(reply_content))
|
||||
return reply
|
||||
elif context.type == ContextType.IMAGE_CREATE:
|
||||
ok, retstring = self.create_img(query, 0)
|
||||
reply = None
|
||||
if ok:
|
||||
reply = Reply(ReplyType.IMAGE_URL, retstring)
|
||||
else:
|
||||
reply = Reply(ReplyType.ERROR, retstring)
|
||||
return reply
|
||||
|
||||
else:
|
||||
reply = Reply(ReplyType.ERROR, "Bot不支持处理{}类型的消息".format(context.type))
|
||||
return reply
|
||||
|
||||
def reply_text(self, session: ZhipuAISession, api_key=None, args=None, retry_count=0) -> dict:
|
||||
"""
|
||||
call openai's ChatCompletion to get the answer
|
||||
:param session: a conversation session
|
||||
:param session_id: session id
|
||||
:param retry_count: retry count
|
||||
:return: {}
|
||||
"""
|
||||
try:
|
||||
# if conf().get("rate_limit_chatgpt") and not self.tb4chatgpt.get_token():
|
||||
# raise openai.error.RateLimitError("RateLimitError: rate limit exceeded")
|
||||
# if api_key == None, the default openai.api_key will be used
|
||||
if args is None:
|
||||
args = self.args
|
||||
# response = openai.ChatCompletion.create(api_key=api_key, messages=session.messages, **args)
|
||||
response = self.client.chat.completions.create(messages=session.messages, **args)
|
||||
# logger.debug("[ZHIPU_AI] response={}".format(response))
|
||||
# logger.info("[ZHIPU_AI] reply={}, total_tokens={}".format(response.choices[0]['message']['content'], response["usage"]["total_tokens"]))
|
||||
|
||||
return {
|
||||
"total_tokens": response.usage.total_tokens,
|
||||
"completion_tokens": response.usage.completion_tokens,
|
||||
"content": response.choices[0].message.content,
|
||||
}
|
||||
except Exception as e:
|
||||
need_retry = retry_count < 2
|
||||
result = {"completion_tokens": 0, "content": "我现在有点累了,等会再来吧"}
|
||||
if isinstance(e, openai.error.RateLimitError):
|
||||
logger.warn("[ZHIPU_AI] RateLimitError: {}".format(e))
|
||||
result["content"] = "提问太快啦,请休息一下再问我吧"
|
||||
if need_retry:
|
||||
time.sleep(20)
|
||||
elif isinstance(e, openai.error.Timeout):
|
||||
logger.warn("[ZHIPU_AI] Timeout: {}".format(e))
|
||||
result["content"] = "我没有收到你的消息"
|
||||
if need_retry:
|
||||
time.sleep(5)
|
||||
elif isinstance(e, openai.error.APIError):
|
||||
logger.warn("[ZHIPU_AI] Bad Gateway: {}".format(e))
|
||||
result["content"] = "请再问我一次"
|
||||
if need_retry:
|
||||
time.sleep(10)
|
||||
elif isinstance(e, openai.error.APIConnectionError):
|
||||
logger.warn("[ZHIPU_AI] APIConnectionError: {}".format(e))
|
||||
result["content"] = "我连接不到你的网络"
|
||||
if need_retry:
|
||||
time.sleep(5)
|
||||
else:
|
||||
logger.exception("[ZHIPU_AI] Exception: {}".format(e), e)
|
||||
need_retry = False
|
||||
self.sessions.clear_session(session.session_id)
|
||||
|
||||
if need_retry:
|
||||
logger.warn("[ZHIPU_AI] 第{}次重试".format(retry_count + 1))
|
||||
return self.reply_text(session, api_key, args, retry_count + 1)
|
||||
else:
|
||||
return result
|
||||
@@ -18,14 +18,33 @@ class Bridge(object):
|
||||
"text_to_voice": conf().get("text_to_voice", "google"),
|
||||
"translate": conf().get("translate", "baidu"),
|
||||
}
|
||||
model_type = conf().get("model")
|
||||
model_type = conf().get("model") or const.GPT35
|
||||
if model_type in ["text-davinci-003"]:
|
||||
self.btype["chat"] = const.OPEN_AI
|
||||
if conf().get("use_azure_chatgpt", False):
|
||||
self.btype["chat"] = const.CHATGPTONAZURE
|
||||
if model_type in ["wenxin", "wenxin-4"]:
|
||||
self.btype["chat"] = const.BAIDU
|
||||
if model_type in ["xunfei"]:
|
||||
self.btype["chat"] = const.XUNFEI
|
||||
if model_type in [const.QWEN]:
|
||||
self.btype["chat"] = const.QWEN
|
||||
if model_type in [const.GEMINI]:
|
||||
self.btype["chat"] = const.GEMINI
|
||||
if model_type in [const.ZHIPU_AI]:
|
||||
self.btype["chat"] = const.ZHIPU_AI
|
||||
|
||||
if conf().get("use_linkai") and conf().get("linkai_api_key"):
|
||||
self.btype["chat"] = const.LINKAI
|
||||
if not conf().get("voice_to_text") or conf().get("voice_to_text") in ["openai"]:
|
||||
self.btype["voice_to_text"] = const.LINKAI
|
||||
if not conf().get("text_to_voice") or conf().get("text_to_voice") in ["openai", const.TTS_1, const.TTS_1_HD]:
|
||||
self.btype["text_to_voice"] = const.LINKAI
|
||||
|
||||
if model_type in ["claude"]:
|
||||
self.btype["chat"] = const.CLAUDEAI
|
||||
self.bots = {}
|
||||
self.chat_bots = {}
|
||||
|
||||
def get_bot(self, typename):
|
||||
if self.bots.get(typename) is None:
|
||||
@@ -54,3 +73,14 @@ class Bridge(object):
|
||||
|
||||
def fetch_translate(self, text, from_lang="", to_lang="en") -> Reply:
|
||||
return self.get_bot("translate").translate(text, from_lang, to_lang)
|
||||
|
||||
def find_chat_bot(self, bot_type: str):
|
||||
if self.chat_bots.get(bot_type) is None:
|
||||
self.chat_bots[bot_type] = create_bot(bot_type)
|
||||
return self.chat_bots.get(bot_type)
|
||||
|
||||
def reset_bot(self):
|
||||
"""
|
||||
重置bot路由
|
||||
"""
|
||||
self.__init__()
|
||||
|
||||
@@ -7,9 +7,17 @@ class ContextType(Enum):
|
||||
TEXT = 1 # 文本消息
|
||||
VOICE = 2 # 音频消息
|
||||
IMAGE = 3 # 图片消息
|
||||
FILE = 4 # 文件信息
|
||||
VIDEO = 5 # 视频信息
|
||||
SHARING = 6 # 分享信息
|
||||
|
||||
IMAGE_CREATE = 10 # 创建图片命令
|
||||
ACCEPT_FRIEND = 19 # 同意好友请求
|
||||
JOIN_GROUP = 20 # 加入群聊
|
||||
PATPAT = 21 # 拍了拍
|
||||
FUNCTION = 22 # 函数调用
|
||||
EXIT_GROUP = 23 #退出
|
||||
|
||||
|
||||
def __str__(self):
|
||||
return self.name
|
||||
|
||||
@@ -8,9 +8,15 @@ class ReplyType(Enum):
|
||||
VOICE = 2 # 音频文件
|
||||
IMAGE = 3 # 图片文件
|
||||
IMAGE_URL = 4 # 图片URL
|
||||
|
||||
VIDEO_URL = 5 # 视频URL
|
||||
FILE = 6 # 文件
|
||||
CARD = 7 # 微信名片,仅支持ntchat
|
||||
InviteRoom = 8 # 邀请好友进群
|
||||
INFO = 9
|
||||
ERROR = 10
|
||||
TEXT_ = 11 # 强制文本
|
||||
VIDEO = 12
|
||||
MINIAPP = 13 # 小程序
|
||||
|
||||
def __str__(self):
|
||||
return self.name
|
||||
|
||||
@@ -8,6 +8,7 @@ from bridge.reply import *
|
||||
|
||||
|
||||
class Channel(object):
|
||||
channel_type = ""
|
||||
NOT_SUPPORT_REPLYTYPE = [ReplyType.VOICE, ReplyType.IMAGE]
|
||||
|
||||
def startup(self):
|
||||
|
||||
@@ -1,36 +1,45 @@
|
||||
"""
|
||||
channel factory
|
||||
"""
|
||||
from common import const
|
||||
from .channel import Channel
|
||||
|
||||
|
||||
def create_channel(channel_type):
|
||||
def create_channel(channel_type) -> Channel:
|
||||
"""
|
||||
create a channel instance
|
||||
:param channel_type: channel type code
|
||||
:return: channel instance
|
||||
"""
|
||||
ch = Channel()
|
||||
if channel_type == "wx":
|
||||
from channel.wechat.wechat_channel import WechatChannel
|
||||
|
||||
return WechatChannel()
|
||||
ch = WechatChannel()
|
||||
elif channel_type == "wxy":
|
||||
from channel.wechat.wechaty_channel import WechatyChannel
|
||||
|
||||
return WechatyChannel()
|
||||
ch = WechatyChannel()
|
||||
elif channel_type == "terminal":
|
||||
from channel.terminal.terminal_channel import TerminalChannel
|
||||
|
||||
return TerminalChannel()
|
||||
ch = TerminalChannel()
|
||||
elif channel_type == "wechatmp":
|
||||
from channel.wechatmp.wechatmp_channel import WechatMPChannel
|
||||
|
||||
return WechatMPChannel(passive_reply=True)
|
||||
ch = WechatMPChannel(passive_reply=True)
|
||||
elif channel_type == "wechatmp_service":
|
||||
from channel.wechatmp.wechatmp_channel import WechatMPChannel
|
||||
|
||||
return WechatMPChannel(passive_reply=False)
|
||||
ch = WechatMPChannel(passive_reply=False)
|
||||
elif channel_type == "wechatcom_app":
|
||||
from channel.wechatcom.wechatcomapp_channel import WechatComAppChannel
|
||||
|
||||
return WechatComAppChannel()
|
||||
raise RuntimeError
|
||||
ch = WechatComAppChannel()
|
||||
elif channel_type == "wework":
|
||||
from channel.wework.wework_channel import WeworkChannel
|
||||
ch = WeworkChannel()
|
||||
elif channel_type == const.FEISHU:
|
||||
from channel.feishu.feishu_channel import FeiShuChanel
|
||||
ch = FeiShuChanel()
|
||||
elif channel_type == const.DINGTALK:
|
||||
from channel.dingtalk.dingtalk_channel import DingTalkChanel
|
||||
ch = DingTalkChanel()
|
||||
else:
|
||||
raise RuntimeError
|
||||
ch.channel_type = channel_type
|
||||
return ch
|
||||
|
||||
@@ -4,13 +4,13 @@ import threading
|
||||
import time
|
||||
from asyncio import CancelledError
|
||||
from concurrent.futures import Future, ThreadPoolExecutor
|
||||
from concurrent import futures
|
||||
|
||||
from bridge.context import *
|
||||
from bridge.reply import *
|
||||
from channel.channel import Channel
|
||||
from common.dequeue import Dequeue
|
||||
from common.log import logger
|
||||
from config import conf
|
||||
from common import memory
|
||||
from plugins import *
|
||||
|
||||
try:
|
||||
@@ -18,6 +18,8 @@ try:
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
handler_pool = ThreadPoolExecutor(max_workers=8) # 处理消息的线程池
|
||||
|
||||
|
||||
# 抽象类, 它包含了与消息通道无关的通用处理逻辑
|
||||
class ChatChannel(Channel):
|
||||
@@ -26,7 +28,6 @@ class ChatChannel(Channel):
|
||||
futures = {} # 记录每个session_id提交到线程池的future对象, 用于重置会话时把没执行的future取消掉,正在执行的不会被取消
|
||||
sessions = {} # 用于控制并发,每个session_id同时只能有一个context在处理
|
||||
lock = threading.Lock() # 用于控制对sessions的访问
|
||||
handler_pool = ThreadPoolExecutor(max_workers=8) # 处理消息的线程池
|
||||
|
||||
def __init__(self):
|
||||
_thread = threading.Thread(target=self.consume)
|
||||
@@ -91,30 +92,53 @@ class ChatChannel(Channel):
|
||||
# 消息内容匹配过程,并处理content
|
||||
if ctype == ContextType.TEXT:
|
||||
if first_in and "」\n- - - - - - -" in content: # 初次匹配 过滤引用消息
|
||||
logger.debug(content)
|
||||
logger.debug("[WX]reference query skipped")
|
||||
return None
|
||||
|
||||
nick_name_black_list = conf().get("nick_name_black_list", [])
|
||||
if context.get("isgroup", False): # 群聊
|
||||
# 校验关键字
|
||||
match_prefix = check_prefix(content, conf().get("group_chat_prefix"))
|
||||
match_contain = check_contain(content, conf().get("group_chat_keyword"))
|
||||
flag = False
|
||||
if match_prefix is not None or match_contain is not None:
|
||||
flag = True
|
||||
if match_prefix:
|
||||
content = content.replace(match_prefix, "", 1).strip()
|
||||
if context["msg"].is_at:
|
||||
logger.info("[WX]receive group at")
|
||||
if not conf().get("group_at_off", False):
|
||||
if context["msg"].to_user_id != context["msg"].actual_user_id:
|
||||
if match_prefix is not None or match_contain is not None:
|
||||
flag = True
|
||||
pattern = f"@{re.escape(self.name)}(\u2005|\u0020)"
|
||||
content = re.sub(pattern, r"", content)
|
||||
if match_prefix:
|
||||
content = content.replace(match_prefix, "", 1).strip()
|
||||
if context["msg"].is_at:
|
||||
nick_name = context["msg"].actual_user_nickname
|
||||
if nick_name and nick_name in nick_name_black_list:
|
||||
# 黑名单过滤
|
||||
logger.warning(f"[WX] Nickname {nick_name} in In BlackList, ignore")
|
||||
return None
|
||||
|
||||
logger.info("[WX]receive group at")
|
||||
if not conf().get("group_at_off", False):
|
||||
flag = True
|
||||
pattern = f"@{re.escape(self.name)}(\u2005|\u0020)"
|
||||
subtract_res = re.sub(pattern, r"", content)
|
||||
if isinstance(context["msg"].at_list, list):
|
||||
for at in context["msg"].at_list:
|
||||
pattern = f"@{re.escape(at)}(\u2005|\u0020)"
|
||||
subtract_res = re.sub(pattern, r"", subtract_res)
|
||||
if subtract_res == content and context["msg"].self_display_name:
|
||||
# 前缀移除后没有变化,使用群昵称再次移除
|
||||
pattern = f"@{re.escape(context['msg'].self_display_name)}(\u2005|\u0020)"
|
||||
subtract_res = re.sub(pattern, r"", content)
|
||||
content = subtract_res
|
||||
if not flag:
|
||||
if context["origin_ctype"] == ContextType.VOICE:
|
||||
logger.info("[WX]receive group voice, but checkprefix didn't match")
|
||||
return None
|
||||
else: # 单聊
|
||||
nick_name = context["msg"].from_user_nickname
|
||||
if nick_name and nick_name in nick_name_black_list:
|
||||
# 黑名单过滤
|
||||
logger.warning(f"[WX] Nickname '{nick_name}' in In BlackList, ignore")
|
||||
return None
|
||||
|
||||
match_prefix = check_prefix(content, conf().get("single_chat_prefix", [""]))
|
||||
if match_prefix is not None: # 判断如果匹配到自定义前缀,则返回过滤掉前缀+空格后的内容
|
||||
content = content.replace(match_prefix, "", 1).strip()
|
||||
@@ -162,9 +186,8 @@ class ChatChannel(Channel):
|
||||
reply = e_context["reply"]
|
||||
if not e_context.is_pass():
|
||||
logger.debug("[WX] ready to handle context: type={}, content={}".format(context.type, context.content))
|
||||
if e_context.is_break():
|
||||
context["generate_breaked_by"] = e_context["breaked_by"]
|
||||
if context.type == ContextType.TEXT or context.type == ContextType.IMAGE_CREATE: # 文字和图片消息
|
||||
context["channel"] = e_context["channel"]
|
||||
reply = super().build_reply_content(context.content, context)
|
||||
elif context.type == ContextType.VOICE: # 语音消息
|
||||
cmsg = context["msg"]
|
||||
@@ -193,10 +216,17 @@ class ChatChannel(Channel):
|
||||
reply = self._generate_reply(new_context)
|
||||
else:
|
||||
return
|
||||
elif context.type == ContextType.IMAGE: # 图片消息,当前无默认逻辑
|
||||
elif context.type == ContextType.IMAGE: # 图片消息,当前仅做下载保存到本地的逻辑
|
||||
memory.USER_IMAGE_CACHE[context["session_id"]] = {
|
||||
"path": context.content,
|
||||
"msg": context.get("msg")
|
||||
}
|
||||
elif context.type == ContextType.SHARING: # 分享信息,当前无默认逻辑
|
||||
pass
|
||||
elif context.type == ContextType.FUNCTION or context.type == ContextType.FILE: # 文件消息及函数调用等,当前无默认逻辑
|
||||
pass
|
||||
else:
|
||||
logger.error("[WX] unknown context type: {}".format(context.type))
|
||||
logger.warning("[WX] unknown context type: {}".format(context.type))
|
||||
return
|
||||
return reply
|
||||
|
||||
@@ -222,14 +252,15 @@ class ChatChannel(Channel):
|
||||
reply = super().build_text_to_voice(reply.content)
|
||||
return self._decorate_reply(context, reply)
|
||||
if context.get("isgroup", False):
|
||||
reply_text = "@" + context["msg"].actual_user_nickname + "\n" + reply_text.strip()
|
||||
reply_text = conf().get("group_chat_reply_prefix", "") + reply_text
|
||||
if not context.get("no_need_at", False):
|
||||
reply_text = "@" + context["msg"].actual_user_nickname + "\n" + reply_text.strip()
|
||||
reply_text = conf().get("group_chat_reply_prefix", "") + reply_text + conf().get("group_chat_reply_suffix", "")
|
||||
else:
|
||||
reply_text = conf().get("single_chat_reply_prefix", "") + reply_text
|
||||
reply_text = conf().get("single_chat_reply_prefix", "") + reply_text + conf().get("single_chat_reply_suffix", "")
|
||||
reply.content = reply_text
|
||||
elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO:
|
||||
reply.content = "[" + str(reply.type) + "]\n" + reply.content
|
||||
elif reply.type == ReplyType.IMAGE_URL or reply.type == ReplyType.VOICE or reply.type == ReplyType.IMAGE:
|
||||
elif reply.type == ReplyType.IMAGE_URL or reply.type == ReplyType.VOICE or reply.type == ReplyType.IMAGE or reply.type == ReplyType.FILE or reply.type == ReplyType.VIDEO or reply.type == ReplyType.VIDEO_URL:
|
||||
pass
|
||||
else:
|
||||
logger.error("[WX] unknown reply type: {}".format(reply.type))
|
||||
@@ -310,7 +341,7 @@ class ChatChannel(Channel):
|
||||
if not context_queue.empty():
|
||||
context = context_queue.get()
|
||||
logger.debug("[WX] consume context: {}".format(context))
|
||||
future: Future = self.handler_pool.submit(self._handle, context)
|
||||
future: Future = handler_pool.submit(self._handle, context)
|
||||
future.add_done_callback(self._thread_pool_callback(session_id, context=context))
|
||||
if session_id not in self.futures:
|
||||
self.futures[session_id] = []
|
||||
|
||||
@@ -24,9 +24,7 @@ is_at: 是否被at
|
||||
- (群消息时,一般会存在实际发送者,是群内某个成员的id和昵称,下列项仅在群消息时存在)
|
||||
actual_user_id: 实际发送者id (群聊必填)
|
||||
actual_user_nickname:实际发送者昵称
|
||||
|
||||
|
||||
|
||||
self_display_name: 自身的展示名,设置群昵称时,该字段表示群昵称
|
||||
|
||||
_prepare_fn: 准备函数,用于准备消息的内容,比如下载图片等,
|
||||
_prepared: 是否已经调用过准备函数
|
||||
@@ -48,11 +46,14 @@ class ChatMessage(object):
|
||||
to_user_nickname = None
|
||||
other_user_id = None
|
||||
other_user_nickname = None
|
||||
my_msg = False
|
||||
self_display_name = None
|
||||
|
||||
is_group = False
|
||||
is_at = False
|
||||
actual_user_id = None
|
||||
actual_user_nickname = None
|
||||
at_list = None
|
||||
|
||||
_prepare_fn = None
|
||||
_prepared = False
|
||||
@@ -67,7 +68,7 @@ class ChatMessage(object):
|
||||
self._prepare_fn()
|
||||
|
||||
def __str__(self):
|
||||
return "ChatMessage: id={}, create_time={}, ctype={}, content={}, from_user_id={}, from_user_nickname={}, to_user_id={}, to_user_nickname={}, other_user_id={}, other_user_nickname={}, is_group={}, is_at={}, actual_user_id={}, actual_user_nickname={}".format(
|
||||
return "ChatMessage: id={}, create_time={}, ctype={}, content={}, from_user_id={}, from_user_nickname={}, to_user_id={}, to_user_nickname={}, other_user_id={}, other_user_nickname={}, is_group={}, is_at={}, actual_user_id={}, actual_user_nickname={}, at_list={}".format(
|
||||
self.msg_id,
|
||||
self.create_time,
|
||||
self.ctype,
|
||||
@@ -82,4 +83,5 @@ class ChatMessage(object):
|
||||
self.is_at,
|
||||
self.actual_user_id,
|
||||
self.actual_user_nickname,
|
||||
self.at_list
|
||||
)
|
||||
|
||||
100
channel/dingtalk/dingtalk_channel.py
Normal file
100
channel/dingtalk/dingtalk_channel.py
Normal file
@@ -0,0 +1,100 @@
|
||||
"""
|
||||
钉钉通道接入
|
||||
|
||||
@author huiwen
|
||||
@Date 2023/11/28
|
||||
"""
|
||||
|
||||
# -*- coding=utf-8 -*-
|
||||
from channel.dingtalk.dingtalk_message import DingTalkMessage
|
||||
from bridge.context import Context
|
||||
from bridge.reply import Reply
|
||||
from common.log import logger
|
||||
from common.singleton import singleton
|
||||
from config import conf
|
||||
from common.expired_dict import ExpiredDict
|
||||
from bridge.context import ContextType
|
||||
from channel.chat_channel import ChatChannel
|
||||
import logging
|
||||
from dingtalk_stream import AckMessage
|
||||
import dingtalk_stream
|
||||
|
||||
|
||||
@singleton
|
||||
class DingTalkChanel(ChatChannel, dingtalk_stream.ChatbotHandler):
|
||||
dingtalk_client_id = conf().get('dingtalk_client_id')
|
||||
dingtalk_client_secret = conf().get('dingtalk_client_secret')
|
||||
|
||||
def setup_logger(self):
|
||||
logger = logging.getLogger()
|
||||
handler = logging.StreamHandler()
|
||||
handler.setFormatter(
|
||||
logging.Formatter('%(asctime)s %(name)-8s %(levelname)-8s %(message)s [%(filename)s:%(lineno)d]'))
|
||||
logger.addHandler(handler)
|
||||
logger.setLevel(logging.INFO)
|
||||
return logger
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
super(dingtalk_stream.ChatbotHandler, self).__init__()
|
||||
self.logger = self.setup_logger()
|
||||
# 历史消息id暂存,用于幂等控制
|
||||
self.receivedMsgs = ExpiredDict(60 * 60 * 7.1)
|
||||
logger.info("[dingtalk] client_id={}, client_secret={} ".format(
|
||||
self.dingtalk_client_id, self.dingtalk_client_secret))
|
||||
# 无需群校验和前缀
|
||||
conf()["group_name_white_list"] = ["ALL_GROUP"]
|
||||
|
||||
def startup(self):
|
||||
credential = dingtalk_stream.Credential(self.dingtalk_client_id, self.dingtalk_client_secret)
|
||||
client = dingtalk_stream.DingTalkStreamClient(credential)
|
||||
client.register_callback_handler(dingtalk_stream.chatbot.ChatbotMessage.TOPIC, self)
|
||||
client.start_forever()
|
||||
|
||||
def handle_single(self, cmsg: DingTalkMessage):
|
||||
# 处理单聊消息
|
||||
if cmsg.ctype == ContextType.VOICE:
|
||||
logger.debug("[dingtalk]receive voice msg: {}".format(cmsg.content))
|
||||
elif cmsg.ctype == ContextType.IMAGE:
|
||||
logger.debug("[dingtalk]receive image msg: {}".format(cmsg.content))
|
||||
elif cmsg.ctype == ContextType.PATPAT:
|
||||
logger.debug("[dingtalk]receive patpat msg: {}".format(cmsg.content))
|
||||
elif cmsg.ctype == ContextType.TEXT:
|
||||
expression = cmsg.my_msg
|
||||
cmsg.content = conf()["single_chat_prefix"][0] + cmsg.content
|
||||
context = self._compose_context(cmsg.ctype, cmsg.content, isgroup=False, msg=cmsg)
|
||||
if context:
|
||||
self.produce(context)
|
||||
|
||||
def handle_group(self, cmsg: DingTalkMessage):
|
||||
# 处理群聊消息
|
||||
if cmsg.ctype == ContextType.VOICE:
|
||||
logger.debug("[dingtalk]receive voice msg: {}".format(cmsg.content))
|
||||
elif cmsg.ctype == ContextType.IMAGE:
|
||||
logger.debug("[dingtalk]receive image msg: {}".format(cmsg.content))
|
||||
elif cmsg.ctype == ContextType.PATPAT:
|
||||
logger.debug("[dingtalk]receive patpat msg: {}".format(cmsg.content))
|
||||
elif cmsg.ctype == ContextType.TEXT:
|
||||
expression = cmsg.my_msg
|
||||
cmsg.content = conf()["group_chat_prefix"][0] + cmsg.content
|
||||
context = self._compose_context(cmsg.ctype, cmsg.content, isgroup=True, msg=cmsg)
|
||||
context['no_need_at'] = True
|
||||
if context:
|
||||
self.produce(context)
|
||||
|
||||
async def process(self, callback: dingtalk_stream.CallbackMessage):
|
||||
try:
|
||||
incoming_message = dingtalk_stream.ChatbotMessage.from_dict(callback.data)
|
||||
dingtalk_msg = DingTalkMessage(incoming_message)
|
||||
if incoming_message.conversation_type == '1':
|
||||
self.handle_single(dingtalk_msg)
|
||||
else:
|
||||
self.handle_group(dingtalk_msg)
|
||||
return AckMessage.STATUS_OK, 'OK'
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
return self.FAILED_MSG
|
||||
|
||||
def send(self, reply: Reply, context: Context):
|
||||
incoming_message = context.kwargs['msg'].incoming_message
|
||||
self.reply_text(reply.content, incoming_message)
|
||||
44
channel/dingtalk/dingtalk_message.py
Normal file
44
channel/dingtalk/dingtalk_message.py
Normal file
@@ -0,0 +1,44 @@
|
||||
from bridge.context import ContextType
|
||||
from channel.chat_message import ChatMessage
|
||||
import json
|
||||
import requests
|
||||
from common.log import logger
|
||||
from common.tmp_dir import TmpDir
|
||||
from common import utils
|
||||
from dingtalk_stream import ChatbotMessage
|
||||
|
||||
class DingTalkMessage(ChatMessage):
|
||||
def __init__(self, event: ChatbotMessage):
|
||||
super().__init__(event)
|
||||
|
||||
self.msg_id = event.message_id
|
||||
msg_type = event.message_type
|
||||
self.incoming_message =event
|
||||
self.sender_staff_id = event.sender_staff_id
|
||||
self.other_user_id = event.conversation_id
|
||||
self.create_time = event.create_at
|
||||
if event.conversation_type=="1":
|
||||
self.is_group = False
|
||||
else:
|
||||
self.is_group = True
|
||||
|
||||
|
||||
if msg_type == "text":
|
||||
self.ctype = ContextType.TEXT
|
||||
|
||||
self.content = event.text.content.strip()
|
||||
elif msg_type == "audio":
|
||||
|
||||
# 钉钉支持直接识别语音,所以此处将直接提取文字,当文字处理
|
||||
self.content = event.extensions['content']['recognition'].strip()
|
||||
self.ctype = ContextType.TEXT
|
||||
self.from_user_id = event.sender_id
|
||||
self.to_user_id = event.chatbot_user_id
|
||||
self.other_user_nickname = event.conversation_title
|
||||
|
||||
user_id = event.sender_id
|
||||
nickname =event.sender_nick
|
||||
|
||||
|
||||
|
||||
|
||||
254
channel/feishu/feishu_channel.py
Normal file
254
channel/feishu/feishu_channel.py
Normal file
@@ -0,0 +1,254 @@
|
||||
"""
|
||||
飞书通道接入
|
||||
|
||||
@author Saboteur7
|
||||
@Date 2023/11/19
|
||||
"""
|
||||
|
||||
# -*- coding=utf-8 -*-
|
||||
import uuid
|
||||
|
||||
import requests
|
||||
import web
|
||||
from channel.feishu.feishu_message import FeishuMessage
|
||||
from bridge.context import Context
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from common.log import logger
|
||||
from common.singleton import singleton
|
||||
from config import conf
|
||||
from common.expired_dict import ExpiredDict
|
||||
from bridge.context import ContextType
|
||||
from channel.chat_channel import ChatChannel, check_prefix
|
||||
from common import utils
|
||||
import json
|
||||
import os
|
||||
|
||||
URL_VERIFICATION = "url_verification"
|
||||
|
||||
|
||||
@singleton
|
||||
class FeiShuChanel(ChatChannel):
|
||||
feishu_app_id = conf().get('feishu_app_id')
|
||||
feishu_app_secret = conf().get('feishu_app_secret')
|
||||
feishu_token = conf().get('feishu_token')
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
# 历史消息id暂存,用于幂等控制
|
||||
self.receivedMsgs = ExpiredDict(60 * 60 * 7.1)
|
||||
logger.info("[FeiShu] app_id={}, app_secret={} verification_token={}".format(
|
||||
self.feishu_app_id, self.feishu_app_secret, self.feishu_token))
|
||||
# 无需群校验和前缀
|
||||
conf()["group_name_white_list"] = ["ALL_GROUP"]
|
||||
conf()["single_chat_prefix"] = []
|
||||
|
||||
def startup(self):
|
||||
urls = (
|
||||
'/', 'channel.feishu.feishu_channel.FeishuController'
|
||||
)
|
||||
app = web.application(urls, globals(), autoreload=False)
|
||||
port = conf().get("feishu_port", 9891)
|
||||
web.httpserver.runsimple(app.wsgifunc(), ("0.0.0.0", port))
|
||||
|
||||
def send(self, reply: Reply, context: Context):
|
||||
msg = context.get("msg")
|
||||
is_group = context["isgroup"]
|
||||
if msg:
|
||||
access_token = msg.access_token
|
||||
else:
|
||||
access_token = self.fetch_access_token()
|
||||
headers = {
|
||||
"Authorization": "Bearer " + access_token,
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
msg_type = "text"
|
||||
logger.info(f"[FeiShu] start send reply message, type={context.type}, content={reply.content}")
|
||||
reply_content = reply.content
|
||||
content_key = "text"
|
||||
if reply.type == ReplyType.IMAGE_URL:
|
||||
# 图片上传
|
||||
reply_content = self._upload_image_url(reply.content, access_token)
|
||||
if not reply_content:
|
||||
logger.warning("[FeiShu] upload file failed")
|
||||
return
|
||||
msg_type = "image"
|
||||
content_key = "image_key"
|
||||
if is_group:
|
||||
# 群聊中直接回复
|
||||
url = f"https://open.feishu.cn/open-apis/im/v1/messages/{msg.msg_id}/reply"
|
||||
data = {
|
||||
"msg_type": msg_type,
|
||||
"content": json.dumps({content_key: reply_content})
|
||||
}
|
||||
res = requests.post(url=url, headers=headers, json=data, timeout=(5, 10))
|
||||
else:
|
||||
url = "https://open.feishu.cn/open-apis/im/v1/messages"
|
||||
params = {"receive_id_type": context.get("receive_id_type") or "open_id"}
|
||||
data = {
|
||||
"receive_id": context.get("receiver"),
|
||||
"msg_type": msg_type,
|
||||
"content": json.dumps({content_key: reply_content})
|
||||
}
|
||||
res = requests.post(url=url, headers=headers, params=params, json=data, timeout=(5, 10))
|
||||
res = res.json()
|
||||
if res.get("code") == 0:
|
||||
logger.info(f"[FeiShu] send message success")
|
||||
else:
|
||||
logger.error(f"[FeiShu] send message failed, code={res.get('code')}, msg={res.get('msg')}")
|
||||
|
||||
|
||||
def fetch_access_token(self) -> str:
|
||||
url = "https://open.feishu.cn/open-apis/auth/v3/tenant_access_token/internal/"
|
||||
headers = {
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
req_body = {
|
||||
"app_id": self.feishu_app_id,
|
||||
"app_secret": self.feishu_app_secret
|
||||
}
|
||||
data = bytes(json.dumps(req_body), encoding='utf8')
|
||||
response = requests.post(url=url, data=data, headers=headers)
|
||||
if response.status_code == 200:
|
||||
res = response.json()
|
||||
if res.get("code") != 0:
|
||||
logger.error(f"[FeiShu] get tenant_access_token error, code={res.get('code')}, msg={res.get('msg')}")
|
||||
return ""
|
||||
else:
|
||||
return res.get("tenant_access_token")
|
||||
else:
|
||||
logger.error(f"[FeiShu] fetch token error, res={response}")
|
||||
|
||||
|
||||
def _upload_image_url(self, img_url, access_token):
|
||||
logger.debug(f"[WX] start download image, img_url={img_url}")
|
||||
response = requests.get(img_url)
|
||||
suffix = utils.get_path_suffix(img_url)
|
||||
temp_name = str(uuid.uuid4()) + "." + suffix
|
||||
if response.status_code == 200:
|
||||
# 将图片内容保存为临时文件
|
||||
with open(temp_name, "wb") as file:
|
||||
file.write(response.content)
|
||||
|
||||
# upload
|
||||
upload_url = "https://open.feishu.cn/open-apis/im/v1/images"
|
||||
data = {
|
||||
'image_type': 'message'
|
||||
}
|
||||
headers = {
|
||||
'Authorization': f'Bearer {access_token}',
|
||||
}
|
||||
with open(temp_name, "rb") as file:
|
||||
upload_response = requests.post(upload_url, files={"image": file}, data=data, headers=headers)
|
||||
logger.info(f"[FeiShu] upload file, res={upload_response.content}")
|
||||
os.remove(temp_name)
|
||||
return upload_response.json().get("data").get("image_key")
|
||||
|
||||
|
||||
|
||||
class FeishuController:
|
||||
# 类常量
|
||||
FAILED_MSG = '{"success": false}'
|
||||
SUCCESS_MSG = '{"success": true}'
|
||||
MESSAGE_RECEIVE_TYPE = "im.message.receive_v1"
|
||||
|
||||
def GET(self):
|
||||
return "Feishu service start success!"
|
||||
|
||||
def POST(self):
|
||||
try:
|
||||
channel = FeiShuChanel()
|
||||
|
||||
request = json.loads(web.data().decode("utf-8"))
|
||||
logger.debug(f"[FeiShu] receive request: {request}")
|
||||
|
||||
# 1.事件订阅回调验证
|
||||
if request.get("type") == URL_VERIFICATION:
|
||||
varify_res = {"challenge": request.get("challenge")}
|
||||
return json.dumps(varify_res)
|
||||
|
||||
# 2.消息接收处理
|
||||
# token 校验
|
||||
header = request.get("header")
|
||||
if not header or header.get("token") != channel.feishu_token:
|
||||
return self.FAILED_MSG
|
||||
|
||||
# 处理消息事件
|
||||
event = request.get("event")
|
||||
if header.get("event_type") == self.MESSAGE_RECEIVE_TYPE and event:
|
||||
if not event.get("message") or not event.get("sender"):
|
||||
logger.warning(f"[FeiShu] invalid message, msg={request}")
|
||||
return self.FAILED_MSG
|
||||
msg = event.get("message")
|
||||
|
||||
# 幂等判断
|
||||
if channel.receivedMsgs.get(msg.get("message_id")):
|
||||
logger.warning(f"[FeiShu] repeat msg filtered, event_id={header.get('event_id')}")
|
||||
return self.SUCCESS_MSG
|
||||
channel.receivedMsgs[msg.get("message_id")] = True
|
||||
|
||||
is_group = False
|
||||
chat_type = msg.get("chat_type")
|
||||
if chat_type == "group":
|
||||
if not msg.get("mentions") and msg.get("message_type") == "text":
|
||||
# 群聊中未@不响应
|
||||
return self.SUCCESS_MSG
|
||||
if msg.get("mentions")[0].get("name") != conf().get("feishu_bot_name") and msg.get("message_type") == "text":
|
||||
# 不是@机器人,不响应
|
||||
return self.SUCCESS_MSG
|
||||
# 群聊
|
||||
is_group = True
|
||||
receive_id_type = "chat_id"
|
||||
elif chat_type == "p2p":
|
||||
receive_id_type = "open_id"
|
||||
else:
|
||||
logger.warning("[FeiShu] message ignore")
|
||||
return self.SUCCESS_MSG
|
||||
# 构造飞书消息对象
|
||||
feishu_msg = FeishuMessage(event, is_group=is_group, access_token=channel.fetch_access_token())
|
||||
if not feishu_msg:
|
||||
return self.SUCCESS_MSG
|
||||
|
||||
context = self._compose_context(
|
||||
feishu_msg.ctype,
|
||||
feishu_msg.content,
|
||||
isgroup=is_group,
|
||||
msg=feishu_msg,
|
||||
receive_id_type=receive_id_type,
|
||||
no_need_at=True
|
||||
)
|
||||
if context:
|
||||
channel.produce(context)
|
||||
logger.info(f"[FeiShu] query={feishu_msg.content}, type={feishu_msg.ctype}")
|
||||
return self.SUCCESS_MSG
|
||||
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
return self.FAILED_MSG
|
||||
|
||||
def _compose_context(self, ctype: ContextType, content, **kwargs):
|
||||
context = Context(ctype, content)
|
||||
context.kwargs = kwargs
|
||||
if "origin_ctype" not in context:
|
||||
context["origin_ctype"] = ctype
|
||||
|
||||
cmsg = context["msg"]
|
||||
context["session_id"] = cmsg.from_user_id
|
||||
context["receiver"] = cmsg.other_user_id
|
||||
|
||||
if ctype == ContextType.TEXT:
|
||||
# 1.文本请求
|
||||
# 图片生成处理
|
||||
img_match_prefix = check_prefix(content, conf().get("image_create_prefix"))
|
||||
if img_match_prefix:
|
||||
content = content.replace(img_match_prefix, "", 1)
|
||||
context.type = ContextType.IMAGE_CREATE
|
||||
else:
|
||||
context.type = ContextType.TEXT
|
||||
context.content = content.strip()
|
||||
|
||||
elif context.type == ContextType.VOICE:
|
||||
# 2.语音请求
|
||||
if "desire_rtype" not in context and conf().get("voice_reply_voice"):
|
||||
context["desire_rtype"] = ReplyType.VOICE
|
||||
|
||||
return context
|
||||
63
channel/feishu/feishu_message.py
Normal file
63
channel/feishu/feishu_message.py
Normal file
@@ -0,0 +1,63 @@
|
||||
from bridge.context import ContextType
|
||||
from channel.chat_message import ChatMessage
|
||||
import json
|
||||
import requests
|
||||
from common.log import logger
|
||||
from common.tmp_dir import TmpDir
|
||||
from common import utils
|
||||
|
||||
|
||||
class FeishuMessage(ChatMessage):
|
||||
def __init__(self, event: dict, is_group=False, access_token=None):
|
||||
super().__init__(event)
|
||||
msg = event.get("message")
|
||||
sender = event.get("sender")
|
||||
self.access_token = access_token
|
||||
self.msg_id = msg.get("message_id")
|
||||
self.create_time = msg.get("create_time")
|
||||
self.is_group = is_group
|
||||
msg_type = msg.get("message_type")
|
||||
|
||||
if msg_type == "text":
|
||||
self.ctype = ContextType.TEXT
|
||||
content = json.loads(msg.get('content'))
|
||||
self.content = content.get("text").strip()
|
||||
elif msg_type == "file":
|
||||
self.ctype = ContextType.FILE
|
||||
content = json.loads(msg.get("content"))
|
||||
file_key = content.get("file_key")
|
||||
file_name = content.get("file_name")
|
||||
|
||||
self.content = TmpDir().path() + file_key + "." + utils.get_path_suffix(file_name)
|
||||
|
||||
def _download_file():
|
||||
# 如果响应状态码是200,则将响应内容写入本地文件
|
||||
url = f"https://open.feishu.cn/open-apis/im/v1/messages/{self.msg_id}/resources/{file_key}"
|
||||
headers = {
|
||||
"Authorization": "Bearer " + access_token,
|
||||
}
|
||||
params = {
|
||||
"type": "file"
|
||||
}
|
||||
response = requests.get(url=url, headers=headers, params=params)
|
||||
if response.status_code == 200:
|
||||
with open(self.content, "wb") as f:
|
||||
f.write(response.content)
|
||||
else:
|
||||
logger.info(f"[FeiShu] Failed to download file, key={file_key}, res={response.text}")
|
||||
self._prepare_fn = _download_file
|
||||
else:
|
||||
raise NotImplementedError("Unsupported message type: Type:{} ".format(msg_type))
|
||||
|
||||
self.from_user_id = sender.get("sender_id").get("open_id")
|
||||
self.to_user_id = event.get("app_id")
|
||||
if is_group:
|
||||
# 群聊
|
||||
self.other_user_id = msg.get("chat_id")
|
||||
self.actual_user_id = self.from_user_id
|
||||
self.content = self.content.replace("@_user_1", "").strip()
|
||||
self.actual_user_nickname = ""
|
||||
else:
|
||||
# 私聊
|
||||
self.other_user_id = self.from_user_id
|
||||
self.actual_user_id = self.from_user_id
|
||||
@@ -15,6 +15,7 @@ import requests
|
||||
from bridge.context import *
|
||||
from bridge.reply import *
|
||||
from channel.chat_channel import ChatChannel
|
||||
from channel import chat_channel
|
||||
from channel.wechat.wechat_message import *
|
||||
from common.expired_dict import ExpiredDict
|
||||
from common.log import logger
|
||||
@@ -25,7 +26,7 @@ from lib import itchat
|
||||
from lib.itchat.content import *
|
||||
|
||||
|
||||
@itchat.msg_register([TEXT, VOICE, PICTURE, NOTE])
|
||||
@itchat.msg_register([TEXT, VOICE, PICTURE, NOTE, ATTACHMENT, SHARING])
|
||||
def handler_single_msg(msg):
|
||||
try:
|
||||
cmsg = WechatMessage(msg, False)
|
||||
@@ -36,7 +37,7 @@ def handler_single_msg(msg):
|
||||
return None
|
||||
|
||||
|
||||
@itchat.msg_register([TEXT, VOICE, PICTURE, NOTE], isGroupChat=True)
|
||||
@itchat.msg_register([TEXT, VOICE, PICTURE, NOTE, ATTACHMENT, SHARING], isGroupChat=True)
|
||||
def handler_group_msg(msg):
|
||||
try:
|
||||
cmsg = WechatMessage(msg, True)
|
||||
@@ -53,11 +54,14 @@ def _check(func):
|
||||
if msgId in self.receivedMsgs:
|
||||
logger.info("Wechat message {} already received, ignore".format(msgId))
|
||||
return
|
||||
self.receivedMsgs[msgId] = cmsg
|
||||
self.receivedMsgs[msgId] = True
|
||||
create_time = cmsg.create_time # 消息时间戳
|
||||
if conf().get("hot_reload") == True and int(create_time) < int(time.time()) - 60: # 跳过1分钟前的历史消息
|
||||
logger.debug("[WX]history message {} skipped".format(msgId))
|
||||
return
|
||||
if cmsg.my_msg and not cmsg.is_group:
|
||||
logger.debug("[WX]my message {} skipped".format(msgId))
|
||||
return
|
||||
return func(self, cmsg)
|
||||
|
||||
return wrapper
|
||||
@@ -92,7 +96,7 @@ def qrCallback(uuid, status, qrcode):
|
||||
print(qr_api4)
|
||||
print(qr_api2)
|
||||
print(qr_api1)
|
||||
|
||||
_send_qr_code([qr_api1, qr_api2, qr_api3, qr_api4])
|
||||
qr = qrcode.QRCode(border=1)
|
||||
qr.add_data(url)
|
||||
qr.make(fit=True)
|
||||
@@ -105,24 +109,47 @@ class WechatChannel(ChatChannel):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.receivedMsgs = ExpiredDict(60 * 60 * 24)
|
||||
self.receivedMsgs = ExpiredDict(60 * 60)
|
||||
self.auto_login_times = 0
|
||||
|
||||
def startup(self):
|
||||
itchat.instance.receivingRetryCount = 600 # 修改断线超时时间
|
||||
# login by scan QRCode
|
||||
hotReload = conf().get("hot_reload", False)
|
||||
status_path = os.path.join(get_appdata_dir(), "itchat.pkl")
|
||||
itchat.auto_login(
|
||||
enableCmdQR=2,
|
||||
hotReload=hotReload,
|
||||
statusStorageDir=status_path,
|
||||
qrCallback=qrCallback,
|
||||
)
|
||||
self.user_id = itchat.instance.storageClass.userName
|
||||
self.name = itchat.instance.storageClass.nickName
|
||||
logger.info("Wechat login success, user_id: {}, nickname: {}".format(self.user_id, self.name))
|
||||
# start message listener
|
||||
itchat.run()
|
||||
try:
|
||||
itchat.instance.receivingRetryCount = 600 # 修改断线超时时间
|
||||
# login by scan QRCode
|
||||
hotReload = conf().get("hot_reload", False)
|
||||
status_path = os.path.join(get_appdata_dir(), "itchat.pkl")
|
||||
itchat.auto_login(
|
||||
enableCmdQR=2,
|
||||
hotReload=hotReload,
|
||||
statusStorageDir=status_path,
|
||||
qrCallback=qrCallback,
|
||||
exitCallback=self.exitCallback,
|
||||
loginCallback=self.loginCallback
|
||||
)
|
||||
self.user_id = itchat.instance.storageClass.userName
|
||||
self.name = itchat.instance.storageClass.nickName
|
||||
logger.info("Wechat login success, user_id: {}, nickname: {}".format(self.user_id, self.name))
|
||||
# start message listener
|
||||
itchat.run()
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
|
||||
def exitCallback(self):
|
||||
try:
|
||||
from common.linkai_client import chat_client
|
||||
if chat_client.client_id and conf().get("use_linkai"):
|
||||
_send_logout()
|
||||
time.sleep(2)
|
||||
self.auto_login_times += 1
|
||||
if self.auto_login_times < 100:
|
||||
chat_channel.handler_pool._shutdown = False
|
||||
self.startup()
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
def loginCallback(self):
|
||||
logger.debug("Login success")
|
||||
_send_login_success()
|
||||
|
||||
# handle_* 系列函数处理收到的消息后构造Context,然后传入produce函数中处理Context和发送回复
|
||||
# Context包含了消息的所有信息,包括以下属性
|
||||
@@ -135,10 +162,12 @@ class WechatChannel(ChatChannel):
|
||||
# msg: ChatMessage消息对象
|
||||
# origin_ctype: 原始消息类型,语音转文字后,私聊时如果匹配前缀失败,会根据初始消息是否是语音来放宽触发规则
|
||||
# desire_rtype: 希望回复类型,默认是文本回复,设置为ReplyType.VOICE是语音回复
|
||||
|
||||
@time_checker
|
||||
@_check
|
||||
def handle_single(self, cmsg: ChatMessage):
|
||||
# filter system message
|
||||
if cmsg.other_user_id in ["weixin"]:
|
||||
return
|
||||
if cmsg.ctype == ContextType.VOICE:
|
||||
if conf().get("speech_recognition") != True:
|
||||
return
|
||||
@@ -159,16 +188,18 @@ class WechatChannel(ChatChannel):
|
||||
@_check
|
||||
def handle_group(self, cmsg: ChatMessage):
|
||||
if cmsg.ctype == ContextType.VOICE:
|
||||
if conf().get("speech_recognition") != True:
|
||||
if conf().get("group_speech_recognition") != True:
|
||||
return
|
||||
logger.debug("[WX]receive voice for group msg: {}".format(cmsg.content))
|
||||
elif cmsg.ctype == ContextType.IMAGE:
|
||||
logger.debug("[WX]receive image for group msg: {}".format(cmsg.content))
|
||||
elif cmsg.ctype in [ContextType.JOIN_GROUP, ContextType.PATPAT]:
|
||||
elif cmsg.ctype in [ContextType.JOIN_GROUP, ContextType.PATPAT, ContextType.ACCEPT_FRIEND, ContextType.EXIT_GROUP]:
|
||||
logger.debug("[WX]receive note msg: {}".format(cmsg.content))
|
||||
elif cmsg.ctype == ContextType.TEXT:
|
||||
# logger.debug("[WX]receive group msg: {}, cmsg={}".format(json.dumps(cmsg._rawmsg, ensure_ascii=False), cmsg))
|
||||
pass
|
||||
elif cmsg.ctype == ContextType.FILE:
|
||||
logger.debug(f"[WX]receive attachment msg, file_name={cmsg.content}")
|
||||
else:
|
||||
logger.debug("[WX]receive group msg: {}".format(cmsg.content))
|
||||
context = self._compose_context(cmsg.ctype, cmsg.content, isgroup=True, msg=cmsg)
|
||||
@@ -189,10 +220,14 @@ class WechatChannel(ChatChannel):
|
||||
logger.info("[WX] sendFile={}, receiver={}".format(reply.content, receiver))
|
||||
elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片
|
||||
img_url = reply.content
|
||||
logger.debug(f"[WX] start download image, img_url={img_url}")
|
||||
pic_res = requests.get(img_url, stream=True)
|
||||
image_storage = io.BytesIO()
|
||||
size = 0
|
||||
for block in pic_res.iter_content(1024):
|
||||
size += len(block)
|
||||
image_storage.write(block)
|
||||
logger.info(f"[WX] download image success, size={size}, img_url={img_url}")
|
||||
image_storage.seek(0)
|
||||
itchat.send_image(image_storage, toUserName=receiver)
|
||||
logger.info("[WX] sendImage url={}, receiver={}".format(img_url, receiver))
|
||||
@@ -201,3 +236,48 @@ class WechatChannel(ChatChannel):
|
||||
image_storage.seek(0)
|
||||
itchat.send_image(image_storage, toUserName=receiver)
|
||||
logger.info("[WX] sendImage, receiver={}".format(receiver))
|
||||
elif reply.type == ReplyType.FILE: # 新增文件回复类型
|
||||
file_storage = reply.content
|
||||
itchat.send_file(file_storage, toUserName=receiver)
|
||||
logger.info("[WX] sendFile, receiver={}".format(receiver))
|
||||
elif reply.type == ReplyType.VIDEO: # 新增视频回复类型
|
||||
video_storage = reply.content
|
||||
itchat.send_video(video_storage, toUserName=receiver)
|
||||
logger.info("[WX] sendFile, receiver={}".format(receiver))
|
||||
elif reply.type == ReplyType.VIDEO_URL: # 新增视频URL回复类型
|
||||
video_url = reply.content
|
||||
logger.debug(f"[WX] start download video, video_url={video_url}")
|
||||
video_res = requests.get(video_url, stream=True)
|
||||
video_storage = io.BytesIO()
|
||||
size = 0
|
||||
for block in video_res.iter_content(1024):
|
||||
size += len(block)
|
||||
video_storage.write(block)
|
||||
logger.info(f"[WX] download video success, size={size}, video_url={video_url}")
|
||||
video_storage.seek(0)
|
||||
itchat.send_video(video_storage, toUserName=receiver)
|
||||
logger.info("[WX] sendVideo url={}, receiver={}".format(video_url, receiver))
|
||||
|
||||
def _send_login_success():
|
||||
try:
|
||||
from common.linkai_client import chat_client
|
||||
if chat_client.client_id:
|
||||
chat_client.send_login_success()
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
def _send_logout():
|
||||
try:
|
||||
from common.linkai_client import chat_client
|
||||
if chat_client.client_id:
|
||||
chat_client.send_logout()
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
def _send_qr_code(qrcode_list: list):
|
||||
try:
|
||||
from common.linkai_client import chat_client
|
||||
if chat_client.client_id:
|
||||
chat_client.send_qrcode(qrcode_list)
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
@@ -7,7 +7,6 @@ from common.tmp_dir import TmpDir
|
||||
from lib import itchat
|
||||
from lib.itchat.content import *
|
||||
|
||||
|
||||
class WechatMessage(ChatMessage):
|
||||
def __init__(self, itchat_msg, is_group=False):
|
||||
super().__init__(itchat_msg)
|
||||
@@ -28,13 +27,24 @@ class WechatMessage(ChatMessage):
|
||||
self._prepare_fn = lambda: itchat_msg.download(self.content)
|
||||
elif itchat_msg["Type"] == NOTE and itchat_msg["MsgType"] == 10000:
|
||||
if is_group and ("加入群聊" in itchat_msg["Content"] or "加入了群聊" in itchat_msg["Content"]):
|
||||
self.ctype = ContextType.JOIN_GROUP
|
||||
self.content = itchat_msg["Content"]
|
||||
# 这里只能得到nickname, actual_user_id还是机器人的id
|
||||
if "加入了群聊" in itchat_msg["Content"]:
|
||||
self.ctype = ContextType.JOIN_GROUP
|
||||
self.content = itchat_msg["Content"]
|
||||
self.actual_user_nickname = re.findall(r"\"(.*?)\"", itchat_msg["Content"])[-1]
|
||||
elif "加入群聊" in itchat_msg["Content"]:
|
||||
self.ctype = ContextType.JOIN_GROUP
|
||||
self.content = itchat_msg["Content"]
|
||||
self.actual_user_nickname = re.findall(r"\"(.*?)\"", itchat_msg["Content"])[0]
|
||||
|
||||
elif is_group and ("移出了群聊" in itchat_msg["Content"]):
|
||||
self.ctype = ContextType.EXIT_GROUP
|
||||
self.content = itchat_msg["Content"]
|
||||
self.actual_user_nickname = re.findall(r"\"(.*?)\"", itchat_msg["Content"])[0]
|
||||
|
||||
elif "你已添加了" in itchat_msg["Content"]: #通过好友请求
|
||||
self.ctype = ContextType.ACCEPT_FRIEND
|
||||
self.content = itchat_msg["Content"]
|
||||
elif "拍了拍我" in itchat_msg["Content"]:
|
||||
self.ctype = ContextType.PATPAT
|
||||
self.content = itchat_msg["Content"]
|
||||
@@ -42,6 +52,14 @@ class WechatMessage(ChatMessage):
|
||||
self.actual_user_nickname = re.findall(r"\"(.*?)\"", itchat_msg["Content"])[0]
|
||||
else:
|
||||
raise NotImplementedError("Unsupported note message: " + itchat_msg["Content"])
|
||||
elif itchat_msg["Type"] == ATTACHMENT:
|
||||
self.ctype = ContextType.FILE
|
||||
self.content = TmpDir().path() + itchat_msg["FileName"] # content直接存临时目录路径
|
||||
self._prepare_fn = lambda: itchat_msg.download(self.content)
|
||||
elif itchat_msg["Type"] == SHARING:
|
||||
self.ctype = ContextType.SHARING
|
||||
self.content = itchat_msg.get("Url")
|
||||
|
||||
else:
|
||||
raise NotImplementedError("Unsupported message type: Type:{} MsgType:{}".format(itchat_msg["Type"], itchat_msg["MsgType"]))
|
||||
|
||||
@@ -57,13 +75,19 @@ class WechatMessage(ChatMessage):
|
||||
self.from_user_nickname = nickname
|
||||
if self.to_user_id == user_id:
|
||||
self.to_user_nickname = nickname
|
||||
try: # 陌生人时候, 'User'字段可能不存在
|
||||
try: # 陌生人时候, User字段可能不存在
|
||||
# my_msg 为True是表示是自己发送的消息
|
||||
self.my_msg = itchat_msg["ToUserName"] == itchat_msg["User"]["UserName"] and \
|
||||
itchat_msg["ToUserName"] != itchat_msg["FromUserName"]
|
||||
self.other_user_id = itchat_msg["User"]["UserName"]
|
||||
self.other_user_nickname = itchat_msg["User"]["NickName"]
|
||||
if self.other_user_id == self.from_user_id:
|
||||
self.from_user_nickname = self.other_user_nickname
|
||||
if self.other_user_id == self.to_user_id:
|
||||
self.to_user_nickname = self.other_user_nickname
|
||||
if itchat_msg["User"].get("Self"):
|
||||
# 自身的展示名,当设置了群昵称时,该字段表示群昵称
|
||||
self.self_display_name = itchat_msg["User"].get("Self").get("DisplayName")
|
||||
except KeyError as e: # 处理偶尔没有对方信息的情况
|
||||
logger.warn("[WX]get other_user_id failed: " + str(e))
|
||||
if self.from_user_id == user_id:
|
||||
@@ -74,5 +98,5 @@ class WechatMessage(ChatMessage):
|
||||
if self.is_group:
|
||||
self.is_at = itchat_msg["IsAt"]
|
||||
self.actual_user_id = itchat_msg["ActualUserName"]
|
||||
if self.ctype not in [ContextType.JOIN_GROUP, ContextType.PATPAT]:
|
||||
if self.ctype not in [ContextType.JOIN_GROUP, ContextType.PATPAT, ContextType.EXIT_GROUP]:
|
||||
self.actual_user_nickname = itchat_msg["ActualNickName"]
|
||||
|
||||
@@ -49,7 +49,7 @@ class Query:
|
||||
|
||||
# New request
|
||||
if (
|
||||
from_user not in channel.cache_dict
|
||||
channel.cache_dict.get(from_user) is None
|
||||
and from_user not in channel.running
|
||||
or content.startswith("#")
|
||||
and message_id not in channel.request_cnt # insert the godcmd
|
||||
@@ -131,8 +131,10 @@ class Query:
|
||||
|
||||
# Only one request can access to the cached data
|
||||
try:
|
||||
(reply_type, reply_content) = channel.cache_dict.pop(from_user)
|
||||
except KeyError:
|
||||
(reply_type, reply_content) = channel.cache_dict[from_user].pop(0)
|
||||
if not channel.cache_dict[from_user]: # If popping the message makes the list empty, delete the user entry from cache
|
||||
del channel.cache_dict[from_user]
|
||||
except IndexError:
|
||||
return "success"
|
||||
|
||||
if reply_type == "text":
|
||||
@@ -146,7 +148,7 @@ class Query:
|
||||
max_split=1,
|
||||
)
|
||||
reply_text = splits[0] + continue_text
|
||||
channel.cache_dict[from_user] = ("text", splits[1])
|
||||
channel.cache_dict[from_user].append(("text", splits[1]))
|
||||
|
||||
logger.info(
|
||||
"[wechatmp] Request {} do send to {} {}: {}\n{}".format(
|
||||
|
||||
@@ -10,6 +10,7 @@ import requests
|
||||
import web
|
||||
from wechatpy.crypto import WeChatCrypto
|
||||
from wechatpy.exceptions import WeChatClientException
|
||||
from collections import defaultdict
|
||||
|
||||
from bridge.context import *
|
||||
from bridge.reply import *
|
||||
@@ -20,7 +21,7 @@ from common.log import logger
|
||||
from common.singleton import singleton
|
||||
from common.utils import split_string_by_utf8_length
|
||||
from config import conf
|
||||
from voice.audio_convert import any_to_mp3
|
||||
from voice.audio_convert import any_to_mp3, split_audio
|
||||
|
||||
# If using SSL, uncomment the following lines, and modify the certificate path.
|
||||
# from cheroot.server import HTTPServer
|
||||
@@ -46,7 +47,7 @@ class WechatMPChannel(ChatChannel):
|
||||
self.crypto = WeChatCrypto(token, aes_key, appid)
|
||||
if self.passive_reply:
|
||||
# Cache the reply to the user's first message
|
||||
self.cache_dict = dict()
|
||||
self.cache_dict = defaultdict(list)
|
||||
# Record whether the current message is being processed
|
||||
self.running = set()
|
||||
# Count the request from wechat official server by message_id
|
||||
@@ -82,24 +83,28 @@ class WechatMPChannel(ChatChannel):
|
||||
if reply.type == ReplyType.TEXT or reply.type == ReplyType.INFO or reply.type == ReplyType.ERROR:
|
||||
reply_text = reply.content
|
||||
logger.info("[wechatmp] text cached, receiver {}\n{}".format(receiver, reply_text))
|
||||
self.cache_dict[receiver] = ("text", reply_text)
|
||||
self.cache_dict[receiver].append(("text", reply_text))
|
||||
elif reply.type == ReplyType.VOICE:
|
||||
try:
|
||||
voice_file_path = reply.content
|
||||
with open(voice_file_path, "rb") as f:
|
||||
# support: <2M, <60s, mp3/wma/wav/amr
|
||||
response = self.client.material.add("voice", f)
|
||||
logger.debug("[wechatmp] upload voice response: {}".format(response))
|
||||
# 根据文件大小估计一个微信自动审核的时间,审核结束前返回将会导致语音无法播放,这个估计有待验证
|
||||
f_size = os.fstat(f.fileno()).st_size
|
||||
time.sleep(1.0 + 2 * f_size / 1024 / 1024)
|
||||
# todo check media_id
|
||||
except WeChatClientException as e:
|
||||
logger.error("[wechatmp] upload voice failed: {}".format(e))
|
||||
return
|
||||
media_id = response["media_id"]
|
||||
logger.info("[wechatmp] voice uploaded, receiver {}, media_id {}".format(receiver, media_id))
|
||||
self.cache_dict[receiver] = ("voice", media_id)
|
||||
voice_file_path = reply.content
|
||||
duration, files = split_audio(voice_file_path, 60 * 1000)
|
||||
if len(files) > 1:
|
||||
logger.info("[wechatmp] voice too long {}s > 60s , split into {} parts".format(duration / 1000.0, len(files)))
|
||||
|
||||
for path in files:
|
||||
# support: <2M, <60s, mp3/wma/wav/amr
|
||||
try:
|
||||
with open(path, "rb") as f:
|
||||
response = self.client.material.add("voice", f)
|
||||
logger.debug("[wechatmp] upload voice response: {}".format(response))
|
||||
f_size = os.fstat(f.fileno()).st_size
|
||||
time.sleep(1.0 + 2 * f_size / 1024 / 1024)
|
||||
# todo check media_id
|
||||
except WeChatClientException as e:
|
||||
logger.error("[wechatmp] upload voice failed: {}".format(e))
|
||||
return
|
||||
media_id = response["media_id"]
|
||||
logger.info("[wechatmp] voice uploaded, receiver {}, media_id {}".format(receiver, media_id))
|
||||
self.cache_dict[receiver].append(("voice", media_id))
|
||||
|
||||
elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片
|
||||
img_url = reply.content
|
||||
@@ -119,7 +124,7 @@ class WechatMPChannel(ChatChannel):
|
||||
return
|
||||
media_id = response["media_id"]
|
||||
logger.info("[wechatmp] image uploaded, receiver {}, media_id {}".format(receiver, media_id))
|
||||
self.cache_dict[receiver] = ("image", media_id)
|
||||
self.cache_dict[receiver].append(("image", media_id))
|
||||
elif reply.type == ReplyType.IMAGE: # 从文件读取图片
|
||||
image_storage = reply.content
|
||||
image_storage.seek(0)
|
||||
@@ -134,7 +139,7 @@ class WechatMPChannel(ChatChannel):
|
||||
return
|
||||
media_id = response["media_id"]
|
||||
logger.info("[wechatmp] image uploaded, receiver {}, media_id {}".format(receiver, media_id))
|
||||
self.cache_dict[receiver] = ("image", media_id)
|
||||
self.cache_dict[receiver].append(("image", media_id))
|
||||
else:
|
||||
if reply.type == ReplyType.TEXT or reply.type == ReplyType.INFO or reply.type == ReplyType.ERROR:
|
||||
reply_text = reply.content
|
||||
@@ -162,13 +167,28 @@ class WechatMPChannel(ChatChannel):
|
||||
file_name = os.path.basename(file_path)
|
||||
file_type = "audio/mpeg"
|
||||
logger.info("[wechatmp] file_name: {}, file_type: {} ".format(file_name, file_type))
|
||||
# support: <2M, <60s, AMR\MP3
|
||||
response = self.client.media.upload("voice", (file_name, open(file_path, "rb"), file_type))
|
||||
logger.debug("[wechatmp] upload voice response: {}".format(response))
|
||||
media_ids = []
|
||||
duration, files = split_audio(file_path, 60 * 1000)
|
||||
if len(files) > 1:
|
||||
logger.info("[wechatmp] voice too long {}s > 60s , split into {} parts".format(duration / 1000.0, len(files)))
|
||||
for path in files:
|
||||
# support: <2M, <60s, AMR\MP3
|
||||
response = self.client.media.upload("voice", (os.path.basename(path), open(path, "rb"), file_type))
|
||||
logger.debug("[wechatcom] upload voice response: {}".format(response))
|
||||
media_ids.append(response["media_id"])
|
||||
os.remove(path)
|
||||
except WeChatClientException as e:
|
||||
logger.error("[wechatmp] upload voice failed: {}".format(e))
|
||||
return
|
||||
self.client.message.send_voice(receiver, response["media_id"])
|
||||
|
||||
try:
|
||||
os.remove(file_path)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
for media_id in media_ids:
|
||||
self.client.message.send_voice(receiver, media_id)
|
||||
time.sleep(1)
|
||||
logger.info("[wechatmp] Do send voice to {}".format(receiver))
|
||||
elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片
|
||||
img_url = reply.content
|
||||
|
||||
17
channel/wework/run.py
Normal file
17
channel/wework/run.py
Normal file
@@ -0,0 +1,17 @@
|
||||
import os
|
||||
import time
|
||||
os.environ['ntwork_LOG'] = "ERROR"
|
||||
import ntwork
|
||||
|
||||
wework = ntwork.WeWork()
|
||||
|
||||
|
||||
def forever():
|
||||
try:
|
||||
while True:
|
||||
time.sleep(0.1)
|
||||
except KeyboardInterrupt:
|
||||
ntwork.exit_()
|
||||
os._exit(0)
|
||||
|
||||
|
||||
326
channel/wework/wework_channel.py
Normal file
326
channel/wework/wework_channel.py
Normal file
@@ -0,0 +1,326 @@
|
||||
import io
|
||||
import os
|
||||
import random
|
||||
import tempfile
|
||||
import threading
|
||||
os.environ['ntwork_LOG'] = "ERROR"
|
||||
import ntwork
|
||||
import requests
|
||||
import uuid
|
||||
|
||||
from bridge.context import *
|
||||
from bridge.reply import *
|
||||
from channel.chat_channel import ChatChannel
|
||||
from channel.wework.wework_message import *
|
||||
from channel.wework.wework_message import WeworkMessage
|
||||
from common.singleton import singleton
|
||||
from common.log import logger
|
||||
from common.time_check import time_checker
|
||||
from common.utils import compress_imgfile, fsize
|
||||
from config import conf
|
||||
from channel.wework.run import wework
|
||||
from channel.wework import run
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def get_wxid_by_name(room_members, group_wxid, name):
|
||||
if group_wxid in room_members:
|
||||
for member in room_members[group_wxid]['member_list']:
|
||||
if member['room_nickname'] == name or member['username'] == name:
|
||||
return member['user_id']
|
||||
return None # 如果没有找到对应的group_wxid或name,则返回None
|
||||
|
||||
|
||||
def download_and_compress_image(url, filename, quality=30):
|
||||
# 确定保存图片的目录
|
||||
directory = os.path.join(os.getcwd(), "tmp")
|
||||
# 如果目录不存在,则创建目录
|
||||
if not os.path.exists(directory):
|
||||
os.makedirs(directory)
|
||||
|
||||
# 下载图片
|
||||
pic_res = requests.get(url, stream=True)
|
||||
image_storage = io.BytesIO()
|
||||
for block in pic_res.iter_content(1024):
|
||||
image_storage.write(block)
|
||||
|
||||
# 检查图片大小并可能进行压缩
|
||||
sz = fsize(image_storage)
|
||||
if sz >= 10 * 1024 * 1024: # 如果图片大于 10 MB
|
||||
logger.info("[wework] image too large, ready to compress, sz={}".format(sz))
|
||||
image_storage = compress_imgfile(image_storage, 10 * 1024 * 1024 - 1)
|
||||
logger.info("[wework] image compressed, sz={}".format(fsize(image_storage)))
|
||||
|
||||
# 将内存缓冲区的指针重置到起始位置
|
||||
image_storage.seek(0)
|
||||
|
||||
# 读取并保存图片
|
||||
image = Image.open(image_storage)
|
||||
image_path = os.path.join(directory, f"{filename}.png")
|
||||
image.save(image_path, "png")
|
||||
|
||||
return image_path
|
||||
|
||||
|
||||
def download_video(url, filename):
|
||||
# 确定保存视频的目录
|
||||
directory = os.path.join(os.getcwd(), "tmp")
|
||||
# 如果目录不存在,则创建目录
|
||||
if not os.path.exists(directory):
|
||||
os.makedirs(directory)
|
||||
|
||||
# 下载视频
|
||||
response = requests.get(url, stream=True)
|
||||
total_size = 0
|
||||
|
||||
video_path = os.path.join(directory, f"{filename}.mp4")
|
||||
|
||||
with open(video_path, 'wb') as f:
|
||||
for block in response.iter_content(1024):
|
||||
total_size += len(block)
|
||||
|
||||
# 如果视频的总大小超过30MB (30 * 1024 * 1024 bytes),则停止下载并返回
|
||||
if total_size > 30 * 1024 * 1024:
|
||||
logger.info("[WX] Video is larger than 30MB, skipping...")
|
||||
return None
|
||||
|
||||
f.write(block)
|
||||
|
||||
return video_path
|
||||
|
||||
|
||||
def create_message(wework_instance, message, is_group):
|
||||
logger.debug(f"正在为{'群聊' if is_group else '单聊'}创建 WeworkMessage")
|
||||
cmsg = WeworkMessage(message, wework=wework_instance, is_group=is_group)
|
||||
logger.debug(f"cmsg:{cmsg}")
|
||||
return cmsg
|
||||
|
||||
|
||||
def handle_message(cmsg, is_group):
|
||||
logger.debug(f"准备用 WeworkChannel 处理{'群聊' if is_group else '单聊'}消息")
|
||||
if is_group:
|
||||
WeworkChannel().handle_group(cmsg)
|
||||
else:
|
||||
WeworkChannel().handle_single(cmsg)
|
||||
logger.debug(f"已用 WeworkChannel 处理完{'群聊' if is_group else '单聊'}消息")
|
||||
|
||||
|
||||
def _check(func):
|
||||
def wrapper(self, cmsg: ChatMessage):
|
||||
msgId = cmsg.msg_id
|
||||
create_time = cmsg.create_time # 消息时间戳
|
||||
if create_time is None:
|
||||
return func(self, cmsg)
|
||||
if int(create_time) < int(time.time()) - 60: # 跳过1分钟前的历史消息
|
||||
logger.debug("[WX]history message {} skipped".format(msgId))
|
||||
return
|
||||
return func(self, cmsg)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
@wework.msg_register(
|
||||
[ntwork.MT_RECV_TEXT_MSG, ntwork.MT_RECV_IMAGE_MSG, 11072, ntwork.MT_RECV_LINK_CARD_MSG,ntwork.MT_RECV_FILE_MSG, ntwork.MT_RECV_VOICE_MSG])
|
||||
def all_msg_handler(wework_instance: ntwork.WeWork, message):
|
||||
logger.debug(f"收到消息: {message}")
|
||||
if 'data' in message:
|
||||
# 首先查找conversation_id,如果没有找到,则查找room_conversation_id
|
||||
conversation_id = message['data'].get('conversation_id', message['data'].get('room_conversation_id'))
|
||||
if conversation_id is not None:
|
||||
is_group = "R:" in conversation_id
|
||||
try:
|
||||
cmsg = create_message(wework_instance=wework_instance, message=message, is_group=is_group)
|
||||
except NotImplementedError as e:
|
||||
logger.error(f"[WX]{message.get('MsgId', 'unknown')} 跳过: {e}")
|
||||
return None
|
||||
delay = random.randint(1, 2)
|
||||
timer = threading.Timer(delay, handle_message, args=(cmsg, is_group))
|
||||
timer.start()
|
||||
else:
|
||||
logger.debug("消息数据中无 conversation_id")
|
||||
return None
|
||||
return None
|
||||
|
||||
|
||||
def accept_friend_with_retries(wework_instance, user_id, corp_id):
|
||||
result = wework_instance.accept_friend(user_id, corp_id)
|
||||
logger.debug(f'result:{result}')
|
||||
|
||||
|
||||
# @wework.msg_register(ntwork.MT_RECV_FRIEND_MSG)
|
||||
# def friend(wework_instance: ntwork.WeWork, message):
|
||||
# data = message["data"]
|
||||
# user_id = data["user_id"]
|
||||
# corp_id = data["corp_id"]
|
||||
# logger.info(f"接收到好友请求,消息内容:{data}")
|
||||
# delay = random.randint(1, 180)
|
||||
# threading.Timer(delay, accept_friend_with_retries, args=(wework_instance, user_id, corp_id)).start()
|
||||
#
|
||||
# return None
|
||||
|
||||
|
||||
def get_with_retry(get_func, max_retries=5, delay=5):
|
||||
retries = 0
|
||||
result = None
|
||||
while retries < max_retries:
|
||||
result = get_func()
|
||||
if result:
|
||||
break
|
||||
logger.warning(f"获取数据失败,重试第{retries + 1}次······")
|
||||
retries += 1
|
||||
time.sleep(delay) # 等待一段时间后重试
|
||||
return result
|
||||
|
||||
|
||||
@singleton
|
||||
class WeworkChannel(ChatChannel):
|
||||
NOT_SUPPORT_REPLYTYPE = []
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def startup(self):
|
||||
smart = conf().get("wework_smart", True)
|
||||
wework.open(smart)
|
||||
logger.info("等待登录······")
|
||||
wework.wait_login()
|
||||
login_info = wework.get_login_info()
|
||||
self.user_id = login_info['user_id']
|
||||
self.name = login_info['nickname']
|
||||
logger.info(f"登录信息:>>>user_id:{self.user_id}>>>>>>>>name:{self.name}")
|
||||
logger.info("静默延迟60s,等待客户端刷新数据,请勿进行任何操作······")
|
||||
time.sleep(60)
|
||||
contacts = get_with_retry(wework.get_external_contacts)
|
||||
rooms = get_with_retry(wework.get_rooms)
|
||||
directory = os.path.join(os.getcwd(), "tmp")
|
||||
if not contacts or not rooms:
|
||||
logger.error("获取contacts或rooms失败,程序退出")
|
||||
ntwork.exit_()
|
||||
os.exit(0)
|
||||
if not os.path.exists(directory):
|
||||
os.makedirs(directory)
|
||||
# 将contacts保存到json文件中
|
||||
with open(os.path.join(directory, 'wework_contacts.json'), 'w', encoding='utf-8') as f:
|
||||
json.dump(contacts, f, ensure_ascii=False, indent=4)
|
||||
with open(os.path.join(directory, 'wework_rooms.json'), 'w', encoding='utf-8') as f:
|
||||
json.dump(rooms, f, ensure_ascii=False, indent=4)
|
||||
# 创建一个空字典来保存结果
|
||||
result = {}
|
||||
|
||||
# 遍历列表中的每个字典
|
||||
for room in rooms['room_list']:
|
||||
# 获取聊天室ID
|
||||
room_wxid = room['conversation_id']
|
||||
|
||||
# 获取聊天室成员
|
||||
room_members = wework.get_room_members(room_wxid)
|
||||
|
||||
# 将聊天室成员保存到结果字典中
|
||||
result[room_wxid] = room_members
|
||||
|
||||
# 将结果保存到json文件中
|
||||
with open(os.path.join(directory, 'wework_room_members.json'), 'w', encoding='utf-8') as f:
|
||||
json.dump(result, f, ensure_ascii=False, indent=4)
|
||||
logger.info("wework程序初始化完成········")
|
||||
run.forever()
|
||||
|
||||
@time_checker
|
||||
@_check
|
||||
def handle_single(self, cmsg: ChatMessage):
|
||||
if cmsg.from_user_id == cmsg.to_user_id:
|
||||
# ignore self reply
|
||||
return
|
||||
if cmsg.ctype == ContextType.VOICE:
|
||||
if not conf().get("speech_recognition"):
|
||||
return
|
||||
logger.debug("[WX]receive voice msg: {}".format(cmsg.content))
|
||||
elif cmsg.ctype == ContextType.IMAGE:
|
||||
logger.debug("[WX]receive image msg: {}".format(cmsg.content))
|
||||
elif cmsg.ctype == ContextType.PATPAT:
|
||||
logger.debug("[WX]receive patpat msg: {}".format(cmsg.content))
|
||||
elif cmsg.ctype == ContextType.TEXT:
|
||||
logger.debug("[WX]receive text msg: {}, cmsg={}".format(json.dumps(cmsg._rawmsg, ensure_ascii=False), cmsg))
|
||||
else:
|
||||
logger.debug("[WX]receive msg: {}, cmsg={}".format(cmsg.content, cmsg))
|
||||
context = self._compose_context(cmsg.ctype, cmsg.content, isgroup=False, msg=cmsg)
|
||||
if context:
|
||||
self.produce(context)
|
||||
|
||||
@time_checker
|
||||
@_check
|
||||
def handle_group(self, cmsg: ChatMessage):
|
||||
if cmsg.ctype == ContextType.VOICE:
|
||||
if not conf().get("speech_recognition"):
|
||||
return
|
||||
logger.debug("[WX]receive voice for group msg: {}".format(cmsg.content))
|
||||
elif cmsg.ctype == ContextType.IMAGE:
|
||||
logger.debug("[WX]receive image for group msg: {}".format(cmsg.content))
|
||||
elif cmsg.ctype in [ContextType.JOIN_GROUP, ContextType.PATPAT]:
|
||||
logger.debug("[WX]receive note msg: {}".format(cmsg.content))
|
||||
elif cmsg.ctype == ContextType.TEXT:
|
||||
pass
|
||||
else:
|
||||
logger.debug("[WX]receive group msg: {}".format(cmsg.content))
|
||||
context = self._compose_context(cmsg.ctype, cmsg.content, isgroup=True, msg=cmsg)
|
||||
if context:
|
||||
self.produce(context)
|
||||
|
||||
# 统一的发送函数,每个Channel自行实现,根据reply的type字段发送不同类型的消息
|
||||
def send(self, reply: Reply, context: Context):
|
||||
logger.debug(f"context: {context}")
|
||||
receiver = context["receiver"]
|
||||
actual_user_id = context["msg"].actual_user_id
|
||||
if reply.type == ReplyType.TEXT or reply.type == ReplyType.TEXT_:
|
||||
match = re.search(r"^@(.*?)\n", reply.content)
|
||||
logger.debug(f"match: {match}")
|
||||
if match:
|
||||
new_content = re.sub(r"^@(.*?)\n", "\n", reply.content)
|
||||
at_list = [actual_user_id]
|
||||
logger.debug(f"new_content: {new_content}")
|
||||
wework.send_room_at_msg(receiver, new_content, at_list)
|
||||
else:
|
||||
wework.send_text(receiver, reply.content)
|
||||
logger.info("[WX] sendMsg={}, receiver={}".format(reply, receiver))
|
||||
elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO:
|
||||
wework.send_text(receiver, reply.content)
|
||||
logger.info("[WX] sendMsg={}, receiver={}".format(reply, receiver))
|
||||
elif reply.type == ReplyType.IMAGE: # 从文件读取图片
|
||||
image_storage = reply.content
|
||||
image_storage.seek(0)
|
||||
# Read data from image_storage
|
||||
data = image_storage.read()
|
||||
# Create a temporary file
|
||||
with tempfile.NamedTemporaryFile(delete=False) as temp:
|
||||
temp_path = temp.name
|
||||
temp.write(data)
|
||||
# Send the image
|
||||
wework.send_image(receiver, temp_path)
|
||||
logger.info("[WX] sendImage, receiver={}".format(receiver))
|
||||
# Remove the temporary file
|
||||
os.remove(temp_path)
|
||||
elif reply.type == ReplyType.IMAGE_URL: # 从网络下载图片
|
||||
img_url = reply.content
|
||||
filename = str(uuid.uuid4())
|
||||
|
||||
# 调用你的函数,下载图片并保存为本地文件
|
||||
image_path = download_and_compress_image(img_url, filename)
|
||||
|
||||
wework.send_image(receiver, file_path=image_path)
|
||||
logger.info("[WX] sendImage url={}, receiver={}".format(img_url, receiver))
|
||||
elif reply.type == ReplyType.VIDEO_URL:
|
||||
video_url = reply.content
|
||||
filename = str(uuid.uuid4())
|
||||
video_path = download_video(video_url, filename)
|
||||
|
||||
if video_path is None:
|
||||
# 如果视频太大,下载可能会被跳过,此时 video_path 将为 None
|
||||
wework.send_text(receiver, "抱歉,视频太大了!!!")
|
||||
else:
|
||||
wework.send_video(receiver, video_path)
|
||||
logger.info("[WX] sendVideo, receiver={}".format(receiver))
|
||||
elif reply.type == ReplyType.VOICE:
|
||||
current_dir = os.getcwd()
|
||||
voice_file = reply.content.split("/")[-1]
|
||||
reply.content = os.path.join(current_dir, "tmp", voice_file)
|
||||
wework.send_file(receiver, reply.content)
|
||||
logger.info("[WX] sendFile={}, receiver={}".format(reply.content, receiver))
|
||||
227
channel/wework/wework_message.py
Normal file
227
channel/wework/wework_message.py
Normal file
@@ -0,0 +1,227 @@
|
||||
import datetime
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
import pilk
|
||||
|
||||
from bridge.context import ContextType
|
||||
from channel.chat_message import ChatMessage
|
||||
from common.log import logger
|
||||
from ntwork.const import send_type
|
||||
|
||||
|
||||
def get_with_retry(get_func, max_retries=5, delay=5):
|
||||
retries = 0
|
||||
result = None
|
||||
while retries < max_retries:
|
||||
result = get_func()
|
||||
if result:
|
||||
break
|
||||
logger.warning(f"获取数据失败,重试第{retries + 1}次······")
|
||||
retries += 1
|
||||
time.sleep(delay) # 等待一段时间后重试
|
||||
return result
|
||||
|
||||
|
||||
def get_room_info(wework, conversation_id):
|
||||
logger.debug(f"传入的 conversation_id: {conversation_id}")
|
||||
rooms = wework.get_rooms()
|
||||
if not rooms or 'room_list' not in rooms:
|
||||
logger.error(f"获取群聊信息失败: {rooms}")
|
||||
return None
|
||||
time.sleep(1)
|
||||
logger.debug(f"获取到的群聊信息: {rooms}")
|
||||
for room in rooms['room_list']:
|
||||
if room['conversation_id'] == conversation_id:
|
||||
return room
|
||||
return None
|
||||
|
||||
|
||||
def cdn_download(wework, message, file_name):
|
||||
data = message["data"]
|
||||
aes_key = data["cdn"]["aes_key"]
|
||||
file_size = data["cdn"]["size"]
|
||||
|
||||
# 获取当前工作目录,然后与文件名拼接得到保存路径
|
||||
current_dir = os.getcwd()
|
||||
save_path = os.path.join(current_dir, "tmp", file_name)
|
||||
|
||||
# 下载保存图片到本地
|
||||
if "url" in data["cdn"].keys() and "auth_key" in data["cdn"].keys():
|
||||
url = data["cdn"]["url"]
|
||||
auth_key = data["cdn"]["auth_key"]
|
||||
# result = wework.wx_cdn_download(url, auth_key, aes_key, file_size, save_path) # ntwork库本身接口有问题,缺失了aes_key这个参数
|
||||
"""
|
||||
下载wx类型的cdn文件,以https开头
|
||||
"""
|
||||
data = {
|
||||
'url': url,
|
||||
'auth_key': auth_key,
|
||||
'aes_key': aes_key,
|
||||
'size': file_size,
|
||||
'save_path': save_path
|
||||
}
|
||||
result = wework._WeWork__send_sync(send_type.MT_WXCDN_DOWNLOAD_MSG, data) # 直接用wx_cdn_download的接口内部实现来调用
|
||||
elif "file_id" in data["cdn"].keys():
|
||||
if message["type"] == 11042:
|
||||
file_type = 2
|
||||
elif message["type"] == 11045:
|
||||
file_type = 5
|
||||
file_id = data["cdn"]["file_id"]
|
||||
result = wework.c2c_cdn_download(file_id, aes_key, file_size, file_type, save_path)
|
||||
else:
|
||||
logger.error(f"something is wrong, data: {data}")
|
||||
return
|
||||
|
||||
# 输出下载结果
|
||||
logger.debug(f"result: {result}")
|
||||
|
||||
|
||||
def c2c_download_and_convert(wework, message, file_name):
|
||||
data = message["data"]
|
||||
aes_key = data["cdn"]["aes_key"]
|
||||
file_size = data["cdn"]["size"]
|
||||
file_type = 5
|
||||
file_id = data["cdn"]["file_id"]
|
||||
|
||||
current_dir = os.getcwd()
|
||||
save_path = os.path.join(current_dir, "tmp", file_name)
|
||||
result = wework.c2c_cdn_download(file_id, aes_key, file_size, file_type, save_path)
|
||||
logger.debug(result)
|
||||
|
||||
# 在下载完SILK文件之后,立即将其转换为WAV文件
|
||||
base_name, _ = os.path.splitext(save_path)
|
||||
wav_file = base_name + ".wav"
|
||||
pilk.silk_to_wav(save_path, wav_file, rate=24000)
|
||||
|
||||
# 删除SILK文件
|
||||
try:
|
||||
os.remove(save_path)
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
|
||||
class WeworkMessage(ChatMessage):
|
||||
def __init__(self, wework_msg, wework, is_group=False):
|
||||
try:
|
||||
super().__init__(wework_msg)
|
||||
self.msg_id = wework_msg['data'].get('conversation_id', wework_msg['data'].get('room_conversation_id'))
|
||||
# 使用.get()防止 'send_time' 键不存在时抛出错误
|
||||
self.create_time = wework_msg['data'].get("send_time")
|
||||
self.is_group = is_group
|
||||
self.wework = wework
|
||||
|
||||
if wework_msg["type"] == 11041: # 文本消息类型
|
||||
if any(substring in wework_msg['data']['content'] for substring in ("该消息类型暂不能展示", "不支持的消息类型")):
|
||||
return
|
||||
self.ctype = ContextType.TEXT
|
||||
self.content = wework_msg['data']['content']
|
||||
elif wework_msg["type"] == 11044: # 语音消息类型,需要缓存文件
|
||||
file_name = datetime.datetime.now().strftime('%Y%m%d%H%M%S') + ".silk"
|
||||
base_name, _ = os.path.splitext(file_name)
|
||||
file_name_2 = base_name + ".wav"
|
||||
current_dir = os.getcwd()
|
||||
self.ctype = ContextType.VOICE
|
||||
self.content = os.path.join(current_dir, "tmp", file_name_2)
|
||||
self._prepare_fn = lambda: c2c_download_and_convert(wework, wework_msg, file_name)
|
||||
elif wework_msg["type"] == 11042: # 图片消息类型,需要下载文件
|
||||
file_name = datetime.datetime.now().strftime('%Y%m%d%H%M%S') + ".jpg"
|
||||
current_dir = os.getcwd()
|
||||
self.ctype = ContextType.IMAGE
|
||||
self.content = os.path.join(current_dir, "tmp", file_name)
|
||||
self._prepare_fn = lambda: cdn_download(wework, wework_msg, file_name)
|
||||
elif wework_msg["type"] == 11045: # 文件消息
|
||||
print("文件消息")
|
||||
print(wework_msg)
|
||||
file_name = datetime.datetime.now().strftime('%Y%m%d%H%M%S')
|
||||
file_name = file_name + wework_msg['data']['cdn']['file_name']
|
||||
current_dir = os.getcwd()
|
||||
self.ctype = ContextType.FILE
|
||||
self.content = os.path.join(current_dir, "tmp", file_name)
|
||||
self._prepare_fn = lambda: cdn_download(wework, wework_msg, file_name)
|
||||
elif wework_msg["type"] == 11047: # 链接消息
|
||||
self.ctype = ContextType.SHARING
|
||||
self.content = wework_msg['data']['url']
|
||||
elif wework_msg["type"] == 11072: # 新成员入群通知
|
||||
self.ctype = ContextType.JOIN_GROUP
|
||||
member_list = wework_msg['data']['member_list']
|
||||
self.actual_user_nickname = member_list[0]['name']
|
||||
self.actual_user_id = member_list[0]['user_id']
|
||||
self.content = f"{self.actual_user_nickname}加入了群聊!"
|
||||
directory = os.path.join(os.getcwd(), "tmp")
|
||||
rooms = get_with_retry(wework.get_rooms)
|
||||
if not rooms:
|
||||
logger.error("更新群信息失败···")
|
||||
else:
|
||||
result = {}
|
||||
for room in rooms['room_list']:
|
||||
# 获取聊天室ID
|
||||
room_wxid = room['conversation_id']
|
||||
|
||||
# 获取聊天室成员
|
||||
room_members = wework.get_room_members(room_wxid)
|
||||
|
||||
# 将聊天室成员保存到结果字典中
|
||||
result[room_wxid] = room_members
|
||||
with open(os.path.join(directory, 'wework_room_members.json'), 'w', encoding='utf-8') as f:
|
||||
json.dump(result, f, ensure_ascii=False, indent=4)
|
||||
logger.info("有新成员加入,已自动更新群成员列表缓存!")
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"Unsupported message type: Type:{} MsgType:{}".format(wework_msg["type"], wework_msg["MsgType"]))
|
||||
|
||||
data = wework_msg['data']
|
||||
login_info = self.wework.get_login_info()
|
||||
logger.debug(f"login_info: {login_info}")
|
||||
nickname = f"{login_info['username']}({login_info['nickname']})" if login_info['nickname'] else login_info['username']
|
||||
user_id = login_info['user_id']
|
||||
|
||||
sender_id = data.get('sender')
|
||||
conversation_id = data.get('conversation_id')
|
||||
sender_name = data.get("sender_name")
|
||||
|
||||
self.from_user_id = user_id if sender_id == user_id else conversation_id
|
||||
self.from_user_nickname = nickname if sender_id == user_id else sender_name
|
||||
self.to_user_id = user_id
|
||||
self.to_user_nickname = nickname
|
||||
self.other_user_nickname = sender_name
|
||||
self.other_user_id = conversation_id
|
||||
|
||||
if self.is_group:
|
||||
conversation_id = data.get('conversation_id') or data.get('room_conversation_id')
|
||||
self.other_user_id = conversation_id
|
||||
if conversation_id:
|
||||
room_info = get_room_info(wework=wework, conversation_id=conversation_id)
|
||||
self.other_user_nickname = room_info.get('nickname', None) if room_info else None
|
||||
self.from_user_nickname = room_info.get('nickname', None) if room_info else None
|
||||
at_list = data.get('at_list', [])
|
||||
tmp_list = []
|
||||
for at in at_list:
|
||||
tmp_list.append(at['nickname'])
|
||||
at_list = tmp_list
|
||||
logger.debug(f"at_list: {at_list}")
|
||||
logger.debug(f"nickname: {nickname}")
|
||||
self.is_at = False
|
||||
if nickname in at_list or login_info['nickname'] in at_list or login_info['username'] in at_list:
|
||||
self.is_at = True
|
||||
self.at_list = at_list
|
||||
|
||||
# 检查消息内容是否包含@用户名。处理复制粘贴的消息,这类消息可能不会触发@通知,但内容中可能包含 "@用户名"。
|
||||
content = data.get('content', '')
|
||||
name = nickname
|
||||
pattern = f"@{re.escape(name)}(\u2005|\u0020)"
|
||||
if re.search(pattern, content):
|
||||
logger.debug(f"Wechaty message {self.msg_id} includes at")
|
||||
self.is_at = True
|
||||
|
||||
if not self.actual_user_id:
|
||||
self.actual_user_id = data.get("sender")
|
||||
self.actual_user_nickname = sender_name if self.ctype != ContextType.JOIN_GROUP else self.actual_user_nickname
|
||||
else:
|
||||
logger.error("群聊消息中没有找到 conversation_id 或 room_conversation_id")
|
||||
|
||||
logger.debug(f"WeworkMessage has been successfully instantiated with message id: {self.msg_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"在 WeworkMessage 的初始化过程中出现错误:{e}")
|
||||
raise e
|
||||
@@ -2,7 +2,27 @@
|
||||
OPEN_AI = "openAI"
|
||||
CHATGPT = "chatGPT"
|
||||
BAIDU = "baidu"
|
||||
XUNFEI = "xunfei"
|
||||
CHATGPTONAZURE = "chatGPTOnAzure"
|
||||
LINKAI = "linkai"
|
||||
CLAUDEAI = "claude"
|
||||
QWEN = "qwen"
|
||||
GEMINI = "gemini"
|
||||
ZHIPU_AI = "glm-4"
|
||||
|
||||
VERSION = "1.3.0"
|
||||
|
||||
# model
|
||||
GPT35 = "gpt-3.5-turbo"
|
||||
GPT4 = "gpt-4"
|
||||
GPT4_TURBO_PREVIEW = "gpt-4-0125-preview"
|
||||
GPT4_VISION_PREVIEW = "gpt-4-vision-preview"
|
||||
WHISPER_1 = "whisper-1"
|
||||
TTS_1 = "tts-1"
|
||||
TTS_1_HD = "tts-1-hd"
|
||||
|
||||
MODEL_LIST = ["gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4", "wenxin", "wenxin-4", "xunfei", "claude", "gpt-4-turbo",
|
||||
"gpt-4-turbo-preview", "gpt-4-1106-preview", GPT4_TURBO_PREVIEW, QWEN, GEMINI, ZHIPU_AI]
|
||||
|
||||
# channel
|
||||
FEISHU = "feishu"
|
||||
DINGTALK = "dingtalk"
|
||||
|
||||
55
common/linkai_client.py
Normal file
55
common/linkai_client.py
Normal file
@@ -0,0 +1,55 @@
|
||||
from bridge.context import Context, ContextType
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from common.log import logger
|
||||
from linkai import LinkAIClient, PushMsg
|
||||
from config import conf, pconf, plugin_config
|
||||
from plugins import PluginManager
|
||||
|
||||
|
||||
chat_client: LinkAIClient
|
||||
|
||||
class ChatClient(LinkAIClient):
|
||||
def __init__(self, api_key, host, channel):
|
||||
super().__init__(api_key, host)
|
||||
self.channel = channel
|
||||
self.client_type = channel.channel_type
|
||||
|
||||
def on_message(self, push_msg: PushMsg):
|
||||
session_id = push_msg.session_id
|
||||
msg_content = push_msg.msg_content
|
||||
logger.info(f"receive msg push, session_id={session_id}, msg_content={msg_content}")
|
||||
context = Context()
|
||||
context.type = ContextType.TEXT
|
||||
context["receiver"] = session_id
|
||||
context["isgroup"] = push_msg.is_group
|
||||
self.channel.send(Reply(ReplyType.TEXT, content=msg_content), context)
|
||||
|
||||
def on_config(self, config: dict):
|
||||
if not self.client_id:
|
||||
return
|
||||
logger.info(f"从控制台加载配置: {config}")
|
||||
local_config = conf()
|
||||
for key in local_config.keys():
|
||||
if config.get(key) is not None:
|
||||
local_config[key] = config.get(key)
|
||||
if config.get("reply_voice_mode"):
|
||||
if config.get("reply_voice_mode") == "voice_reply_voice":
|
||||
local_config["voice_reply_voice"] = True
|
||||
elif config.get("reply_voice_mode") == "always_reply_voice":
|
||||
local_config["always_reply_voice"] = True
|
||||
# if config.get("admin_password") and plugin_config["Godcmd"]:
|
||||
# plugin_config["Godcmd"]["password"] = config.get("admin_password")
|
||||
# PluginManager().instances["Godcmd"].reload()
|
||||
# if config.get("group_app_map") and pconf("linkai"):
|
||||
# local_group_map = {}
|
||||
# for mapping in config.get("group_app_map"):
|
||||
# local_group_map[mapping.get("group_name")] = mapping.get("app_code")
|
||||
# pconf("linkai")["group_app_map"] = local_group_map
|
||||
# PluginManager().instances["linkai"].reload()
|
||||
|
||||
|
||||
def start(channel):
|
||||
global chat_client
|
||||
chat_client = ChatClient(api_key=conf().get("linkai_api_key"),
|
||||
host="link-ai.chat", channel=channel)
|
||||
chat_client.start()
|
||||
3
common/memory.py
Normal file
3
common/memory.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from common.expired_dict import ExpiredDict
|
||||
|
||||
USER_IMAGE_CACHE = ExpiredDict(60 * 3)
|
||||
@@ -1,6 +1,6 @@
|
||||
import io
|
||||
import os
|
||||
|
||||
from urllib.parse import urlparse
|
||||
from PIL import Image
|
||||
|
||||
|
||||
@@ -49,3 +49,8 @@ def split_string_by_utf8_length(string, max_length, max_split=0):
|
||||
result.append(encoded[start:end].decode("utf-8"))
|
||||
start = end
|
||||
return result
|
||||
|
||||
|
||||
def get_path_suffix(path):
|
||||
path = urlparse(path).path
|
||||
return os.path.splitext(path)[-1].lstrip('.')
|
||||
|
||||
@@ -1,7 +1,12 @@
|
||||
{
|
||||
"channel_type": "wx",
|
||||
"model": "",
|
||||
"open_ai_api_key": "YOUR API KEY",
|
||||
"model": "gpt-3.5-turbo",
|
||||
"text_to_image": "dall-e-2",
|
||||
"voice_to_text": "openai",
|
||||
"text_to_voice": "openai",
|
||||
"proxy": "",
|
||||
"hot_reload": false,
|
||||
"single_chat_prefix": [
|
||||
"bot",
|
||||
"@bot"
|
||||
@@ -14,21 +19,17 @@
|
||||
"ChatGPT测试群",
|
||||
"ChatGPT测试群2"
|
||||
],
|
||||
"group_chat_in_one_session": [
|
||||
"ChatGPT测试群"
|
||||
],
|
||||
"image_create_prefix": [
|
||||
"画",
|
||||
"看",
|
||||
"找"
|
||||
"画"
|
||||
],
|
||||
"speech_recognition": false,
|
||||
"speech_recognition": true,
|
||||
"group_speech_recognition": false,
|
||||
"voice_reply_voice": false,
|
||||
"conversation_max_tokens": 1000,
|
||||
"conversation_max_tokens": 2500,
|
||||
"expires_in_seconds": 3600,
|
||||
"character_desc": "你是ChatGPT, 一个由OpenAI训练的大型语言模型, 你旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。",
|
||||
"subscribe_msg": "感谢您的关注!\n这里是ChatGPT,可以自由对话。\n支持语音对话。\n支持图片输入。\n支持图片输出,画字开头的消息将按要求创作图片。\n支持tool、角色扮演和文字冒险等丰富的插件。\n输入{trigger_prefix}#help 查看详细指令。",
|
||||
"temperature": 0.7,
|
||||
"subscribe_msg": "感谢您的关注!\n这里是AI智能助手,可以自由对话。\n支持语音对话。\n支持图片输入。\n支持图片输出,画字开头的消息将按要求创作图片。\n支持tool、角色扮演和文字冒险等丰富的插件。\n输入{trigger_prefix}#help 查看详细指令。",
|
||||
"use_linkai": false,
|
||||
"linkai_api_key": "",
|
||||
"linkai_app_code": ""
|
||||
|
||||
100
config.py
100
config.py
@@ -16,26 +16,35 @@ available_setting = {
|
||||
"open_ai_api_base": "https://api.openai.com/v1",
|
||||
"proxy": "", # openai使用的代理
|
||||
# chatgpt模型, 当use_azure_chatgpt为true时,其名称为Azure上model deployment名称
|
||||
"model": "gpt-3.5-turbo",
|
||||
"model": "gpt-3.5-turbo", # 还支持 gpt-4, gpt-4-turbo, wenxin, xunfei, qwen
|
||||
"use_azure_chatgpt": False, # 是否使用azure的chatgpt
|
||||
"azure_deployment_id": "", # azure 模型部署名称
|
||||
"azure_api_version": "", # azure api版本
|
||||
# Bot触发配置
|
||||
"single_chat_prefix": ["bot", "@bot"], # 私聊时文本需要包含该前缀才能触发机器人回复
|
||||
"single_chat_reply_prefix": "[bot] ", # 私聊时自动回复的前缀,用于区分真人
|
||||
"single_chat_reply_suffix": "", # 私聊时自动回复的后缀,\n 可以换行
|
||||
"group_chat_prefix": ["@bot"], # 群聊时包含该前缀则会触发机器人回复
|
||||
"group_chat_reply_prefix": "", # 群聊时自动回复的前缀
|
||||
"group_chat_reply_suffix": "", # 群聊时自动回复的后缀,\n 可以换行
|
||||
"group_chat_keyword": [], # 群聊时包含该关键词则会触发机器人回复
|
||||
"group_at_off": False, # 是否关闭群聊时@bot的触发
|
||||
"group_name_white_list": ["ChatGPT测试群", "ChatGPT测试群2"], # 开启自动回复的群名称列表
|
||||
"group_name_keyword_white_list": [], # 开启自动回复的群名称关键词列表
|
||||
"group_chat_in_one_session": ["ChatGPT测试群"], # 支持会话上下文共享的群名称
|
||||
"nick_name_black_list": [], # 用户昵称黑名单
|
||||
"group_welcome_msg": "", # 配置新人进群固定欢迎语,不配置则使用随机风格欢迎
|
||||
"trigger_by_self": False, # 是否允许机器人触发
|
||||
"text_to_image": "dall-e-2", # 图片生成模型,可选 dall-e-2, dall-e-3
|
||||
"image_proxy": True, # 是否需要图片代理,国内访问LinkAI时需要
|
||||
"image_create_prefix": ["画", "看", "找"], # 开启图片回复的前缀
|
||||
"concurrency_in_session": 1, # 同一会话最多有多少条消息在处理中,大于1可能乱序
|
||||
"image_create_size": "256x256", # 图片大小,可选有 256x256, 512x512, 1024x1024
|
||||
"image_create_size": "256x256", # 图片大小,可选有 256x256, 512x512, 1024x1024 (dall-e-3默认为1024x1024)
|
||||
"group_chat_exit_group": False,
|
||||
# chatgpt会话参数
|
||||
"expires_in_seconds": 3600, # 无操作会话的过期时间
|
||||
"character_desc": "你是ChatGPT, 一个由OpenAI训练的大型语言模型, 你旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。", # 人格描述
|
||||
# 人格描述
|
||||
"character_desc": "你是ChatGPT, 一个由OpenAI训练的大型语言模型, 你旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。",
|
||||
"conversation_max_tokens": 1000, # 支持上下文记忆的最多字符数
|
||||
# chatgpt限流配置
|
||||
"rate_limit_chatgpt": 20, # chatgpt的调用频率限制
|
||||
@@ -45,15 +54,38 @@ available_setting = {
|
||||
"top_p": 1,
|
||||
"frequency_penalty": 0,
|
||||
"presence_penalty": 0,
|
||||
"request_timeout": 60, # chatgpt请求超时时间,openai接口默认设置为600,对于难问题一般需要较长时间
|
||||
"request_timeout": 180, # chatgpt请求超时时间,openai接口默认设置为600,对于难问题一般需要较长时间
|
||||
"timeout": 120, # chatgpt重试超时时间,在这个时间内,将会自动重试
|
||||
# Baidu 文心一言参数
|
||||
"baidu_wenxin_model": "eb-instant", # 默认使用ERNIE-Bot-turbo模型
|
||||
"baidu_wenxin_api_key": "", # Baidu api key
|
||||
"baidu_wenxin_secret_key": "", # Baidu secret key
|
||||
# 讯飞星火API
|
||||
"xunfei_app_id": "", # 讯飞应用ID
|
||||
"xunfei_api_key": "", # 讯飞 API key
|
||||
"xunfei_api_secret": "", # 讯飞 API secret
|
||||
# claude 配置
|
||||
"claude_api_cookie": "",
|
||||
"claude_uuid": "",
|
||||
# 通义千问API, 获取方式查看文档 https://help.aliyun.com/document_detail/2587494.html
|
||||
"qwen_access_key_id": "",
|
||||
"qwen_access_key_secret": "",
|
||||
"qwen_agent_key": "",
|
||||
"qwen_app_id": "",
|
||||
"qwen_node_id": "", # 流程编排模型用到的id,如果没有用到qwen_node_id,请务必保持为空字符串
|
||||
# Google Gemini Api Key
|
||||
"gemini_api_key": "",
|
||||
# wework的通用配置
|
||||
"wework_smart": True, # 配置wework是否使用已登录的企业微信,False为多开
|
||||
# 语音设置
|
||||
"speech_recognition": False, # 是否开启语音识别
|
||||
"speech_recognition": True, # 是否开启语音识别
|
||||
"group_speech_recognition": False, # 是否开启群组语音识别
|
||||
"voice_reply_voice": False, # 是否使用语音回复语音,需要设置对应语音合成引擎的api key
|
||||
"always_reply_voice": False, # 是否一直使用语音回复
|
||||
"voice_to_text": "openai", # 语音识别引擎,支持openai,baidu,google,azure
|
||||
"text_to_voice": "baidu", # 语音合成引擎,支持baidu,google,pytts(offline),azure
|
||||
"text_to_voice": "openai", # 语音合成引擎,支持openai,baidu,google,pytts(offline),azure,elevenlabs
|
||||
"text_to_voice_model": "tts-1",
|
||||
"tts_voice_id": "alloy",
|
||||
# baidu 语音api配置, 使用百度语音识别和语音合成时需要
|
||||
"baidu_app_id": "",
|
||||
"baidu_api_key": "",
|
||||
@@ -63,6 +95,9 @@ available_setting = {
|
||||
# azure 语音api配置, 使用azure语音识别和语音合成时需要
|
||||
"azure_voice_api_key": "",
|
||||
"azure_voice_region": "japaneast",
|
||||
# elevenlabs 语音api配置
|
||||
"xi_api_key": "", #获取ap的方法可以参考https://docs.elevenlabs.io/api-reference/quick-start/authentication
|
||||
"xi_voice_id": "", #ElevenLabs提供了9种英式、美式等英语发音id,分别是“Adam/Antoni/Arnold/Bella/Domi/Elli/Josh/Rachel/Sam”
|
||||
# 服务时间限制,目前支持itchat
|
||||
"chat_time_module": False, # 是否开启服务时间限制
|
||||
"chat_start_time": "00:00", # 服务开始时间
|
||||
@@ -90,6 +125,18 @@ available_setting = {
|
||||
"wechatcomapp_secret": "", # 企业微信app的secret
|
||||
"wechatcomapp_agent_id": "", # 企业微信app的agent_id
|
||||
"wechatcomapp_aes_key": "", # 企业微信app的aes_key
|
||||
|
||||
# 飞书配置
|
||||
"feishu_port": 80, # 飞书bot监听端口
|
||||
"feishu_app_id": "", # 飞书机器人应用APP Id
|
||||
"feishu_app_secret": "", # 飞书机器人APP secret
|
||||
"feishu_token": "", # 飞书 verification token
|
||||
"feishu_bot_name": "", # 飞书机器人的名字
|
||||
|
||||
# 钉钉配置
|
||||
"dingtalk_client_id": "", # 钉钉机器人Client ID
|
||||
"dingtalk_client_secret": "", # 钉钉机器人Client Secret
|
||||
|
||||
# chatgpt指令自定义触发词
|
||||
"clear_memory_commands": ["#清除记忆"], # 重置会话指令,必须以#开头
|
||||
# channel配置
|
||||
@@ -99,10 +146,18 @@ available_setting = {
|
||||
"appdata_dir": "", # 数据目录
|
||||
# 插件配置
|
||||
"plugin_trigger_prefix": "$", # 规范插件提供聊天相关指令的前缀,建议不要和管理员指令前缀"#"冲突
|
||||
# 知识库平台配置
|
||||
# 是否使用全局插件配置
|
||||
"use_global_plugin_config": False,
|
||||
"max_media_send_count": 3, # 单次最大发送媒体资源的个数
|
||||
"media_send_interval": 1, # 发送图片的事件间隔,单位秒
|
||||
# 智谱AI 平台配置
|
||||
"zhipu_ai_api_key": "",
|
||||
"zhipu_ai_api_base": "https://open.bigmodel.cn/api/paas/v4",
|
||||
# LinkAI平台配置
|
||||
"use_linkai": False,
|
||||
"linkai_api_key": "",
|
||||
"linkai_app_code": ""
|
||||
"linkai_app_code": "",
|
||||
"linkai_api_base": "https://api.link-ai.chat", # linkAI服务地址,若国内无法访问或延迟较高可改为 https://api.link-ai.tech
|
||||
}
|
||||
|
||||
|
||||
@@ -226,3 +281,32 @@ def subscribe_msg():
|
||||
trigger_prefix = conf().get("single_chat_prefix", [""])[0]
|
||||
msg = conf().get("subscribe_msg", "")
|
||||
return msg.format(trigger_prefix=trigger_prefix)
|
||||
|
||||
|
||||
# global plugin config
|
||||
plugin_config = {}
|
||||
|
||||
|
||||
def write_plugin_config(pconf: dict):
|
||||
"""
|
||||
写入插件全局配置
|
||||
:param pconf: 全量插件配置
|
||||
"""
|
||||
global plugin_config
|
||||
for k in pconf:
|
||||
plugin_config[k.lower()] = pconf[k]
|
||||
|
||||
|
||||
def pconf(plugin_name: str) -> dict:
|
||||
"""
|
||||
根据插件名称获取配置
|
||||
:param plugin_name: 插件名称
|
||||
:return: 该插件的配置项
|
||||
"""
|
||||
return plugin_config.get(plugin_name.lower())
|
||||
|
||||
|
||||
# 全局配置,用于存放全局生效的状态
|
||||
global_config = {
|
||||
"admin_users": []
|
||||
}
|
||||
|
||||
@@ -1,39 +0,0 @@
|
||||
FROM python:3.10-alpine
|
||||
|
||||
LABEL maintainer="foo@bar.com"
|
||||
ARG TZ='Asia/Shanghai'
|
||||
|
||||
ARG CHATGPT_ON_WECHAT_VER
|
||||
|
||||
ENV BUILD_PREFIX=/app
|
||||
|
||||
RUN apk add --no-cache \
|
||||
bash \
|
||||
curl \
|
||||
wget \
|
||||
&& export BUILD_GITHUB_TAG=${CHATGPT_ON_WECHAT_VER:-`curl -sL "https://api.github.com/repos/zhayujie/chatgpt-on-wechat/releases/latest" | \
|
||||
grep '"tag_name":' | \
|
||||
sed -E 's/.*"([^"]+)".*/\1/'`} \
|
||||
&& wget -t 3 -T 30 -nv -O chatgpt-on-wechat-${BUILD_GITHUB_TAG}.tar.gz \
|
||||
https://github.com/zhayujie/chatgpt-on-wechat/archive/refs/tags/${BUILD_GITHUB_TAG}.tar.gz \
|
||||
&& tar -xzf chatgpt-on-wechat-${BUILD_GITHUB_TAG}.tar.gz \
|
||||
&& mv chatgpt-on-wechat-${BUILD_GITHUB_TAG} ${BUILD_PREFIX} \
|
||||
&& rm chatgpt-on-wechat-${BUILD_GITHUB_TAG}.tar.gz \
|
||||
&& cd ${BUILD_PREFIX} \
|
||||
&& cp config-template.json ${BUILD_PREFIX}/config.json \
|
||||
&& /usr/local/bin/python -m pip install --no-cache --upgrade pip \
|
||||
&& pip install --no-cache -r requirements.txt --extra-index-url https://alpine-wheels.github.io/index\
|
||||
&& pip install --no-cache -r requirements-optional.txt --extra-index-url https://alpine-wheels.github.io/index\
|
||||
&& apk del curl wget
|
||||
|
||||
WORKDIR ${BUILD_PREFIX}
|
||||
|
||||
ADD ./entrypoint.sh /entrypoint.sh
|
||||
|
||||
RUN chmod +x /entrypoint.sh \
|
||||
&& adduser -D -h /home/noroot -u 1000 -s /bin/bash noroot \
|
||||
&& chown -R noroot:noroot ${BUILD_PREFIX}
|
||||
|
||||
USER noroot
|
||||
|
||||
ENTRYPOINT ["/entrypoint.sh"]
|
||||
@@ -1,29 +0,0 @@
|
||||
FROM python:3.10-alpine
|
||||
|
||||
LABEL maintainer="foo@bar.com"
|
||||
ARG TZ='Asia/Shanghai'
|
||||
|
||||
ARG CHATGPT_ON_WECHAT_VER
|
||||
|
||||
ENV BUILD_PREFIX=/app
|
||||
|
||||
ADD . ${BUILD_PREFIX}
|
||||
|
||||
RUN apk add --no-cache bash ffmpeg espeak \
|
||||
&& cd ${BUILD_PREFIX} \
|
||||
&& cp config-template.json config.json \
|
||||
&& /usr/local/bin/python -m pip install --no-cache --upgrade pip \
|
||||
&& pip install --no-cache -r requirements.txt --extra-index-url https://alpine-wheels.github.io/index\
|
||||
&& pip install --no-cache -r requirements-optional.txt --extra-index-url https://alpine-wheels.github.io/index
|
||||
|
||||
WORKDIR ${BUILD_PREFIX}
|
||||
|
||||
ADD docker/entrypoint.sh /entrypoint.sh
|
||||
|
||||
RUN chmod +x /entrypoint.sh \
|
||||
&& adduser -D -h /home/noroot -u 1000 -s /bin/bash noroot \
|
||||
&& chown -R noroot:noroot ${BUILD_PREFIX}
|
||||
|
||||
USER noroot
|
||||
|
||||
ENTRYPOINT ["/entrypoint.sh"]
|
||||
@@ -1,41 +0,0 @@
|
||||
FROM python:3.10
|
||||
|
||||
LABEL maintainer="foo@bar.com"
|
||||
ARG TZ='Asia/Shanghai'
|
||||
|
||||
ARG CHATGPT_ON_WECHAT_VER
|
||||
|
||||
ENV BUILD_PREFIX=/app
|
||||
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y --no-install-recommends \
|
||||
wget \
|
||||
curl \
|
||||
&& rm -rf /var/lib/apt/lists/* \
|
||||
&& export BUILD_GITHUB_TAG=${CHATGPT_ON_WECHAT_VER:-`curl -sL "https://api.github.com/repos/zhayujie/chatgpt-on-wechat/releases/latest" | \
|
||||
grep '"tag_name":' | \
|
||||
sed -E 's/.*"([^"]+)".*/\1/'`} \
|
||||
&& wget -t 3 -T 30 -nv -O chatgpt-on-wechat-${BUILD_GITHUB_TAG}.tar.gz \
|
||||
https://github.com/zhayujie/chatgpt-on-wechat/archive/refs/tags/${BUILD_GITHUB_TAG}.tar.gz \
|
||||
&& tar -xzf chatgpt-on-wechat-${BUILD_GITHUB_TAG}.tar.gz \
|
||||
&& mv chatgpt-on-wechat-${BUILD_GITHUB_TAG} ${BUILD_PREFIX} \
|
||||
&& rm chatgpt-on-wechat-${BUILD_GITHUB_TAG}.tar.gz \
|
||||
&& cd ${BUILD_PREFIX} \
|
||||
&& cp config-template.json ${BUILD_PREFIX}/config.json \
|
||||
&& /usr/local/bin/python -m pip install --no-cache --upgrade pip \
|
||||
&& pip install --no-cache -r requirements.txt \
|
||||
&& pip install --no-cache -r requirements-optional.txt
|
||||
|
||||
WORKDIR ${BUILD_PREFIX}
|
||||
|
||||
ADD ./entrypoint.sh /entrypoint.sh
|
||||
|
||||
RUN chmod +x /entrypoint.sh \
|
||||
&& mkdir -p /home/noroot \
|
||||
&& groupadd -r noroot \
|
||||
&& useradd -r -g noroot -s /bin/bash -d /home/noroot noroot \
|
||||
&& chown -R noroot:noroot /home/noroot ${BUILD_PREFIX} /usr/local/lib
|
||||
|
||||
USER noroot
|
||||
|
||||
ENTRYPOINT ["/entrypoint.sh"]
|
||||
@@ -1,4 +1,4 @@
|
||||
FROM python:3.10-slim
|
||||
FROM python:3.10-slim-bullseye
|
||||
|
||||
LABEL maintainer="foo@bar.com"
|
||||
ARG TZ='Asia/Shanghai'
|
||||
@@ -32,4 +32,4 @@ RUN chmod +x /entrypoint.sh \
|
||||
|
||||
USER noroot
|
||||
|
||||
ENTRYPOINT ["/entrypoint.sh"]
|
||||
ENTRYPOINT ["/entrypoint.sh"]
|
||||
|
||||
@@ -1,15 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
# fetch latest release tag
|
||||
CHATGPT_ON_WECHAT_TAG=`curl -sL "https://api.github.com/repos/zhayujie/chatgpt-on-wechat/releases/latest" | \
|
||||
grep '"tag_name":' | \
|
||||
sed -E 's/.*"([^"]+)".*/\1/'`
|
||||
|
||||
# build image
|
||||
docker build -f Dockerfile.alpine \
|
||||
--build-arg CHATGPT_ON_WECHAT_VER=$CHATGPT_ON_WECHAT_TAG \
|
||||
-t zhayujie/chatgpt-on-wechat .
|
||||
|
||||
# tag image
|
||||
docker tag zhayujie/chatgpt-on-wechat zhayujie/chatgpt-on-wechat:alpine
|
||||
docker tag zhayujie/chatgpt-on-wechat zhayujie/chatgpt-on-wechat:$CHATGPT_ON_WECHAT_TAG-alpine
|
||||
@@ -1,15 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
# fetch latest release tag
|
||||
CHATGPT_ON_WECHAT_TAG=`curl -sL "https://api.github.com/repos/zhayujie/chatgpt-on-wechat/releases/latest" | \
|
||||
grep '"tag_name":' | \
|
||||
sed -E 's/.*"([^"]+)".*/\1/'`
|
||||
|
||||
# build image
|
||||
docker build -f Dockerfile.debian \
|
||||
--build-arg CHATGPT_ON_WECHAT_VER=$CHATGPT_ON_WECHAT_TAG \
|
||||
-t zhayujie/chatgpt-on-wechat .
|
||||
|
||||
# tag image
|
||||
docker tag zhayujie/chatgpt-on-wechat zhayujie/chatgpt-on-wechat:debian
|
||||
docker tag zhayujie/chatgpt-on-wechat zhayujie/chatgpt-on-wechat:$CHATGPT_ON_WECHAT_TAG-debian
|
||||
@@ -1,23 +0,0 @@
|
||||
FROM zhayujie/chatgpt-on-wechat:alpine
|
||||
|
||||
LABEL maintainer="foo@bar.com"
|
||||
ARG TZ='Asia/Shanghai'
|
||||
|
||||
USER root
|
||||
|
||||
RUN apk add --no-cache \
|
||||
ffmpeg \
|
||||
espeak \
|
||||
&& pip install --no-cache \
|
||||
baidu-aip \
|
||||
chardet \
|
||||
SpeechRecognition
|
||||
|
||||
# replace entrypoint
|
||||
ADD ./entrypoint.sh /entrypoint.sh
|
||||
|
||||
RUN chmod +x /entrypoint.sh
|
||||
|
||||
USER noroot
|
||||
|
||||
ENTRYPOINT ["/entrypoint.sh"]
|
||||
@@ -1,24 +0,0 @@
|
||||
FROM zhayujie/chatgpt-on-wechat:debian
|
||||
|
||||
LABEL maintainer="foo@bar.com"
|
||||
ARG TZ='Asia/Shanghai'
|
||||
|
||||
USER root
|
||||
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y --no-install-recommends \
|
||||
ffmpeg \
|
||||
espeak \
|
||||
&& pip install --no-cache \
|
||||
baidu-aip \
|
||||
chardet \
|
||||
SpeechRecognition
|
||||
|
||||
# replace entrypoint
|
||||
ADD ./entrypoint.sh /entrypoint.sh
|
||||
|
||||
RUN chmod +x /entrypoint.sh
|
||||
|
||||
USER noroot
|
||||
|
||||
ENTRYPOINT ["/entrypoint.sh"]
|
||||
@@ -1,24 +0,0 @@
|
||||
version: '2.0'
|
||||
services:
|
||||
chatgpt-on-wechat:
|
||||
build:
|
||||
context: ./
|
||||
dockerfile: Dockerfile.alpine
|
||||
image: zhayujie/chatgpt-on-wechat-voice-reply
|
||||
container_name: chatgpt-on-wechat-voice-reply
|
||||
environment:
|
||||
OPEN_AI_API_KEY: 'YOUR API KEY'
|
||||
OPEN_AI_PROXY: ''
|
||||
SINGLE_CHAT_PREFIX: '["bot", "@bot"]'
|
||||
SINGLE_CHAT_REPLY_PREFIX: '"[bot] "'
|
||||
GROUP_CHAT_PREFIX: '["@bot"]'
|
||||
GROUP_NAME_WHITE_LIST: '["ChatGPT测试群", "ChatGPT测试群2"]'
|
||||
IMAGE_CREATE_PREFIX: '["画", "看", "找"]'
|
||||
CONVERSATION_MAX_TOKENS: 1000
|
||||
SPEECH_RECOGNITION: 'true'
|
||||
CHARACTER_DESC: '你是ChatGPT, 一个由OpenAI训练的大型语言模型, 你旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。'
|
||||
EXPIRES_IN_SECONDS: 3600
|
||||
VOICE_REPLY_VOICE: 'true'
|
||||
BAIDU_APP_ID: 'YOUR BAIDU APP ID'
|
||||
BAIDU_API_KEY: 'YOUR BAIDU API KEY'
|
||||
BAIDU_SECRET_KEY: 'YOUR BAIDU SERVICE KEY'
|
||||
@@ -1,117 +0,0 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
# build prefix
|
||||
CHATGPT_ON_WECHAT_PREFIX=${CHATGPT_ON_WECHAT_PREFIX:-""}
|
||||
# path to config.json
|
||||
CHATGPT_ON_WECHAT_CONFIG_PATH=${CHATGPT_ON_WECHAT_CONFIG_PATH:-""}
|
||||
# execution command line
|
||||
CHATGPT_ON_WECHAT_EXEC=${CHATGPT_ON_WECHAT_EXEC:-""}
|
||||
|
||||
OPEN_AI_API_KEY=${OPEN_AI_API_KEY:-""}
|
||||
OPEN_AI_PROXY=${OPEN_AI_PROXY:-""}
|
||||
SINGLE_CHAT_PREFIX=${SINGLE_CHAT_PREFIX:-""}
|
||||
SINGLE_CHAT_REPLY_PREFIX=${SINGLE_CHAT_REPLY_PREFIX:-""}
|
||||
GROUP_CHAT_PREFIX=${GROUP_CHAT_PREFIX:-""}
|
||||
GROUP_NAME_WHITE_LIST=${GROUP_NAME_WHITE_LIST:-""}
|
||||
IMAGE_CREATE_PREFIX=${IMAGE_CREATE_PREFIX:-""}
|
||||
CONVERSATION_MAX_TOKENS=${CONVERSATION_MAX_TOKENS:-""}
|
||||
SPEECH_RECOGNITION=${SPEECH_RECOGNITION:-""}
|
||||
CHARACTER_DESC=${CHARACTER_DESC:-""}
|
||||
EXPIRES_IN_SECONDS=${EXPIRES_IN_SECONDS:-""}
|
||||
|
||||
VOICE_REPLY_VOICE=${VOICE_REPLY_VOICE:-""}
|
||||
BAIDU_APP_ID=${BAIDU_APP_ID:-""}
|
||||
BAIDU_API_KEY=${BAIDU_API_KEY:-""}
|
||||
BAIDU_SECRET_KEY=${BAIDU_SECRET_KEY:-""}
|
||||
|
||||
# CHATGPT_ON_WECHAT_PREFIX is empty, use /app
|
||||
if [ "$CHATGPT_ON_WECHAT_PREFIX" == "" ] ; then
|
||||
CHATGPT_ON_WECHAT_PREFIX=/app
|
||||
fi
|
||||
|
||||
# CHATGPT_ON_WECHAT_CONFIG_PATH is empty, use '/app/config.json'
|
||||
if [ "$CHATGPT_ON_WECHAT_CONFIG_PATH" == "" ] ; then
|
||||
CHATGPT_ON_WECHAT_CONFIG_PATH=$CHATGPT_ON_WECHAT_PREFIX/config.json
|
||||
fi
|
||||
|
||||
# CHATGPT_ON_WECHAT_EXEC is empty, use ‘python app.py’
|
||||
if [ "$CHATGPT_ON_WECHAT_EXEC" == "" ] ; then
|
||||
CHATGPT_ON_WECHAT_EXEC="python app.py"
|
||||
fi
|
||||
|
||||
# modify content in config.json
|
||||
if [ "$OPEN_AI_API_KEY" != "" ] ; then
|
||||
sed -i "s/\"open_ai_api_key\".*,$/\"open_ai_api_key\": \"$OPEN_AI_API_KEY\",/" $CHATGPT_ON_WECHAT_CONFIG_PATH
|
||||
else
|
||||
echo -e "\033[31m[Warning] You need to set OPEN_AI_API_KEY before running!\033[0m"
|
||||
fi
|
||||
|
||||
# use http_proxy as default
|
||||
if [ "$HTTP_PROXY" != "" ] ; then
|
||||
sed -i "s/\"proxy\".*,$/\"proxy\": \"$HTTP_PROXY\",/" $CHATGPT_ON_WECHAT_CONFIG_PATH
|
||||
fi
|
||||
|
||||
if [ "$OPEN_AI_PROXY" != "" ] ; then
|
||||
sed -i "s/\"proxy\".*,$/\"proxy\": \"$OPEN_AI_PROXY\",/" $CHATGPT_ON_WECHAT_CONFIG_PATH
|
||||
fi
|
||||
|
||||
if [ "$SINGLE_CHAT_PREFIX" != "" ] ; then
|
||||
sed -i "s/\"single_chat_prefix\".*,$/\"single_chat_prefix\": $SINGLE_CHAT_PREFIX,/" $CHATGPT_ON_WECHAT_CONFIG_PATH
|
||||
fi
|
||||
|
||||
if [ "$SINGLE_CHAT_REPLY_PREFIX" != "" ] ; then
|
||||
sed -i "s/\"single_chat_reply_prefix\".*,$/\"single_chat_reply_prefix\": $SINGLE_CHAT_REPLY_PREFIX,/" $CHATGPT_ON_WECHAT_CONFIG_PATH
|
||||
fi
|
||||
|
||||
if [ "$GROUP_CHAT_PREFIX" != "" ] ; then
|
||||
sed -i "s/\"group_chat_prefix\".*,$/\"group_chat_prefix\": $GROUP_CHAT_PREFIX,/" $CHATGPT_ON_WECHAT_CONFIG_PATH
|
||||
fi
|
||||
|
||||
if [ "$GROUP_NAME_WHITE_LIST" != "" ] ; then
|
||||
sed -i "s/\"group_name_white_list\".*,$/\"group_name_white_list\": $GROUP_NAME_WHITE_LIST,/" $CHATGPT_ON_WECHAT_CONFIG_PATH
|
||||
fi
|
||||
|
||||
if [ "$IMAGE_CREATE_PREFIX" != "" ] ; then
|
||||
sed -i "s/\"image_create_prefix\".*,$/\"image_create_prefix\": $IMAGE_CREATE_PREFIX,/" $CHATGPT_ON_WECHAT_CONFIG_PATH
|
||||
fi
|
||||
|
||||
if [ "$CONVERSATION_MAX_TOKENS" != "" ] ; then
|
||||
sed -i "s/\"conversation_max_tokens\".*,$/\"conversation_max_tokens\": $CONVERSATION_MAX_TOKENS,/" $CHATGPT_ON_WECHAT_CONFIG_PATH
|
||||
fi
|
||||
|
||||
if [ "$SPEECH_RECOGNITION" != "" ] ; then
|
||||
sed -i "s/\"speech_recognition\".*,$/\"speech_recognition\": $SPEECH_RECOGNITION,/" $CHATGPT_ON_WECHAT_CONFIG_PATH
|
||||
fi
|
||||
|
||||
if [ "$CHARACTER_DESC" != "" ] ; then
|
||||
sed -i "s/\"character_desc\".*,$/\"character_desc\": \"$CHARACTER_DESC\",/" $CHATGPT_ON_WECHAT_CONFIG_PATH
|
||||
fi
|
||||
|
||||
if [ "$EXPIRES_IN_SECONDS" != "" ] ; then
|
||||
sed -i "s/\"expires_in_seconds\".*$/\"expires_in_seconds\": $EXPIRES_IN_SECONDS/" $CHATGPT_ON_WECHAT_CONFIG_PATH
|
||||
fi
|
||||
|
||||
# append
|
||||
if [ "$BAIDU_SECRET_KEY" != "" ] ; then
|
||||
sed -i "1a \ \ \"baidu_secret_key\": \"$BAIDU_SECRET_KEY\"," $CHATGPT_ON_WECHAT_CONFIG_PATH
|
||||
fi
|
||||
|
||||
if [ "$BAIDU_API_KEY" != "" ] ; then
|
||||
sed -i "1a \ \ \"baidu_api_key\": \"$BAIDU_API_KEY\"," $CHATGPT_ON_WECHAT_CONFIG_PATH
|
||||
fi
|
||||
|
||||
if [ "$BAIDU_APP_ID" != "" ] ; then
|
||||
sed -i "1a \ \ \"baidu_app_id\": \"$BAIDU_APP_ID\"," $CHATGPT_ON_WECHAT_CONFIG_PATH
|
||||
fi
|
||||
|
||||
if [ "$VOICE_REPLY_VOICE" != "" ] ; then
|
||||
sed -i "1a \ \ \"voice_reply_voice\": $VOICE_REPLY_VOICE," $CHATGPT_ON_WECHAT_CONFIG_PATH
|
||||
fi
|
||||
|
||||
# go to prefix dir
|
||||
cd $CHATGPT_ON_WECHAT_PREFIX
|
||||
# excute
|
||||
$CHATGPT_ON_WECHAT_EXEC
|
||||
|
||||
|
||||
@@ -1,20 +1,24 @@
|
||||
version: '2.0'
|
||||
services:
|
||||
chatgpt-on-wechat:
|
||||
build:
|
||||
context: ./
|
||||
dockerfile: Dockerfile.alpine
|
||||
image: zhayujie/chatgpt-on-wechat
|
||||
container_name: sample-chatgpt-on-wechat
|
||||
container_name: chatgpt-on-wechat
|
||||
security_opt:
|
||||
- seccomp:unconfined
|
||||
environment:
|
||||
OPEN_AI_API_KEY: 'YOUR API KEY'
|
||||
OPEN_AI_PROXY: ''
|
||||
MODEL: 'gpt-3.5-turbo'
|
||||
PROXY: ''
|
||||
SINGLE_CHAT_PREFIX: '["bot", "@bot"]'
|
||||
SINGLE_CHAT_REPLY_PREFIX: '"[bot] "'
|
||||
GROUP_CHAT_PREFIX: '["@bot"]'
|
||||
GROUP_NAME_WHITE_LIST: '["ChatGPT测试群", "ChatGPT测试群2"]'
|
||||
IMAGE_CREATE_PREFIX: '["画", "看", "找"]'
|
||||
CONVERSATION_MAX_TOKENS: 1000
|
||||
SPEECH_RECOGNITION: "False"
|
||||
SPEECH_RECOGNITION: 'False'
|
||||
CHARACTER_DESC: '你是ChatGPT, 一个由OpenAI训练的大型语言模型, 你旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。'
|
||||
EXPIRES_IN_SECONDS: 3600
|
||||
EXPIRES_IN_SECONDS: 3600
|
||||
USE_GLOBAL_PLUGIN_CONFIG: 'True'
|
||||
USE_LINKAI: 'False'
|
||||
LINKAI_API_KEY: ''
|
||||
LINKAI_APP_CODE: ''
|
||||
@@ -1,16 +0,0 @@
|
||||
OPEN_AI_API_KEY=YOUR API KEY
|
||||
OPEN_AI_PROXY=
|
||||
SINGLE_CHAT_PREFIX=["bot", "@bot"]
|
||||
SINGLE_CHAT_REPLY_PREFIX="[bot] "
|
||||
GROUP_CHAT_PREFIX=["@bot"]
|
||||
GROUP_NAME_WHITE_LIST=["ChatGPT测试群", "ChatGPT测试群2"]
|
||||
IMAGE_CREATE_PREFIX=["画", "看", "找"]
|
||||
CONVERSATION_MAX_TOKENS=1000
|
||||
SPEECH_RECOGNITION=false
|
||||
CHARACTER_DESC=你是ChatGPT, 一个由OpenAI训练的大型语言模型, 你旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。
|
||||
EXPIRES_IN_SECONDS=3600
|
||||
|
||||
# Optional
|
||||
#CHATGPT_ON_WECHAT_PREFIX=/app
|
||||
#CHATGPT_ON_WECHAT_CONFIG_PATH=/app/config.json
|
||||
#CHATGPT_ON_WECHAT_EXEC=python app.py
|
||||
@@ -1,26 +0,0 @@
|
||||
IMG:=`cat Name`
|
||||
MOUNT:=
|
||||
PORT_MAP:=
|
||||
DOTENV:=.env
|
||||
CONTAINER_NAME:=sample-chatgpt-on-wechat
|
||||
|
||||
echo:
|
||||
echo $(IMG)
|
||||
|
||||
run_d:
|
||||
docker rm $(CONTAINER_NAME) || echo
|
||||
docker run -dt --name $(CONTAINER_NAME) $(PORT_MAP) \
|
||||
--env-file=$(DOTENV) \
|
||||
$(MOUNT) $(IMG)
|
||||
|
||||
run_i:
|
||||
docker rm $(CONTAINER_NAME) || echo
|
||||
docker run -it --name $(CONTAINER_NAME) $(PORT_MAP) \
|
||||
--env-file=$(DOTENV) \
|
||||
$(MOUNT) $(IMG)
|
||||
|
||||
stop:
|
||||
docker stop $(CONTAINER_NAME)
|
||||
|
||||
rm: stop
|
||||
docker rm $(CONTAINER_NAME)
|
||||
@@ -1 +0,0 @@
|
||||
zhayujie/chatgpt-on-wechat
|
||||
9
lib/itchat/LICENSE
Normal file
9
lib/itchat/LICENSE
Normal file
@@ -0,0 +1,9 @@
|
||||
**The MIT License (MIT)**
|
||||
|
||||
Copyright (c) 2017 LittleCoder ([littlecodersh@Github](https://github.com/littlecodersh))
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
@@ -24,16 +24,17 @@ class Banwords(Plugin):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
try:
|
||||
# load config
|
||||
conf = super().load_config()
|
||||
curdir = os.path.dirname(__file__)
|
||||
config_path = os.path.join(curdir, "config.json")
|
||||
conf = None
|
||||
if not os.path.exists(config_path):
|
||||
conf = {"action": "ignore"}
|
||||
with open(config_path, "w") as f:
|
||||
json.dump(conf, f, indent=4)
|
||||
else:
|
||||
with open(config_path, "r") as f:
|
||||
conf = json.load(f)
|
||||
if not conf:
|
||||
# 配置不存在则写入默认配置
|
||||
config_path = os.path.join(curdir, "config.json")
|
||||
if not os.path.exists(config_path):
|
||||
conf = {"action": "ignore"}
|
||||
with open(config_path, "w") as f:
|
||||
json.dump(conf, f, indent=4)
|
||||
|
||||
self.searchr = WordsSearch()
|
||||
self.action = conf["action"]
|
||||
banwords_path = os.path.join(curdir, "banwords.txt")
|
||||
|
||||
@@ -29,14 +29,9 @@ class BDunit(Plugin):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
try:
|
||||
curdir = os.path.dirname(__file__)
|
||||
config_path = os.path.join(curdir, "config.json")
|
||||
conf = None
|
||||
if not os.path.exists(config_path):
|
||||
conf = super().load_config()
|
||||
if not conf:
|
||||
raise Exception("config.json not found")
|
||||
else:
|
||||
with open(config_path, "r") as f:
|
||||
conf = json.load(f)
|
||||
self.service_id = conf["service_id"]
|
||||
self.api_key = conf["api_key"]
|
||||
self.secret_key = conf["secret_key"]
|
||||
|
||||
44
plugins/config.json.template
Normal file
44
plugins/config.json.template
Normal file
@@ -0,0 +1,44 @@
|
||||
{
|
||||
"godcmd": {
|
||||
"password": "",
|
||||
"admin_users": []
|
||||
},
|
||||
"banwords": {
|
||||
"action": "replace",
|
||||
"reply_filter": true,
|
||||
"reply_action": "ignore"
|
||||
},
|
||||
"tool": {
|
||||
"tools": [
|
||||
"python",
|
||||
"url-get",
|
||||
"terminal",
|
||||
"meteo-weather"
|
||||
],
|
||||
"kwargs": {
|
||||
"top_k_results": 2,
|
||||
"no_default": false,
|
||||
"model_name": "gpt-3.5-turbo"
|
||||
}
|
||||
},
|
||||
"linkai": {
|
||||
"group_app_map": {
|
||||
"测试群1": "default",
|
||||
"测试群2": "Kv2fXJcH"
|
||||
},
|
||||
"midjourney": {
|
||||
"enabled": true,
|
||||
"auto_translate": true,
|
||||
"img_proxy": true,
|
||||
"max_tasks": 3,
|
||||
"max_tasks_per_user": 1,
|
||||
"use_image_create_prefix": true
|
||||
},
|
||||
"summary": {
|
||||
"enabled": true,
|
||||
"group_enabled": true,
|
||||
"max_file_size": 5000,
|
||||
"type": ["FILE", "SHARING"]
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -4,16 +4,16 @@ import json
|
||||
import os
|
||||
import random
|
||||
import string
|
||||
import traceback
|
||||
import logging
|
||||
from typing import Tuple
|
||||
|
||||
import bridge.bridge
|
||||
import plugins
|
||||
from bridge.bridge import Bridge
|
||||
from bridge.context import ContextType
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from common import const
|
||||
from common.log import logger
|
||||
from config import conf, load_config
|
||||
from config import conf, load_config, global_config
|
||||
from plugins import *
|
||||
|
||||
# 定义指令集
|
||||
@@ -32,6 +32,10 @@ COMMANDS = {
|
||||
"args": ["口令"],
|
||||
"desc": "管理员认证",
|
||||
},
|
||||
"model": {
|
||||
"alias": ["model", "模型"],
|
||||
"desc": "查看和设置全局模型",
|
||||
},
|
||||
"set_openai_api_key": {
|
||||
"alias": ["set_openai_api_key"],
|
||||
"args": ["api_key"],
|
||||
@@ -132,9 +136,9 @@ ADMIN_COMMANDS = {
|
||||
|
||||
# 定义帮助函数
|
||||
def get_help_text(isadmin, isgroup):
|
||||
help_text = "通用指令:\n"
|
||||
help_text = "通用指令\n"
|
||||
for cmd, info in COMMANDS.items():
|
||||
if cmd == "auth": # 不提示认证指令
|
||||
if cmd in ["auth", "set_openai_api_key", "reset_openai_api_key", "set_gpt_model", "reset_gpt_model", "gpt_model"]: # 不显示帮助指令
|
||||
continue
|
||||
if cmd == "id" and conf().get("channel_type", "wx") not in ["wxy", "wechatmp"]:
|
||||
continue
|
||||
@@ -147,7 +151,7 @@ def get_help_text(isadmin, isgroup):
|
||||
|
||||
# 插件指令
|
||||
plugins = PluginManager().list_plugins()
|
||||
help_text += "\n目前可用插件有:"
|
||||
help_text += "\n可用插件"
|
||||
for plugin in plugins:
|
||||
if plugins[plugin].enabled and not plugins[plugin].hidden:
|
||||
namecn = plugins[plugin].namecn
|
||||
@@ -178,16 +182,13 @@ class Godcmd(Plugin):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
curdir = os.path.dirname(__file__)
|
||||
config_path = os.path.join(curdir, "config.json")
|
||||
gconf = None
|
||||
if not os.path.exists(config_path):
|
||||
gconf = {"password": "", "admin_users": []}
|
||||
with open(config_path, "w") as f:
|
||||
json.dump(gconf, f, indent=4)
|
||||
else:
|
||||
with open(config_path, "r") as f:
|
||||
gconf = json.load(f)
|
||||
config_path = os.path.join(os.path.dirname(__file__), "config.json")
|
||||
gconf = super().load_config()
|
||||
if not gconf:
|
||||
if not os.path.exists(config_path):
|
||||
gconf = {"password": "", "admin_users": []}
|
||||
with open(config_path, "w") as f:
|
||||
json.dump(gconf, f, indent=4)
|
||||
if gconf["password"] == "":
|
||||
self.temp_password = "".join(random.sample(string.digits, 4))
|
||||
logger.info("[Godcmd] 因未设置口令,本次的临时口令为%s。" % self.temp_password)
|
||||
@@ -202,6 +203,7 @@ class Godcmd(Plugin):
|
||||
|
||||
self.password = gconf["password"]
|
||||
self.admin_users = gconf["admin_users"] # 预存的管理员账号,这些账号不需要认证。itchat的用户名每次都会变,不可用
|
||||
global_config["admin_users"] = self.admin_users
|
||||
self.isrunning = True # 机器人是否运行中
|
||||
|
||||
self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context
|
||||
@@ -260,6 +262,20 @@ class Godcmd(Plugin):
|
||||
break
|
||||
if not ok:
|
||||
result = "插件不存在或未启用"
|
||||
elif cmd == "model":
|
||||
if not isadmin and not self.is_admin_in_group(e_context["context"]):
|
||||
ok, result = False, "需要管理员权限执行"
|
||||
elif len(args) == 0:
|
||||
model = conf().get("model") or const.GPT35
|
||||
ok, result = True, "当前模型为: " + str(model)
|
||||
elif len(args) == 1:
|
||||
if args[0] not in const.MODEL_LIST:
|
||||
ok, result = False, "模型名称不存在"
|
||||
else:
|
||||
conf()["model"] = self.model_mapping(args[0])
|
||||
Bridge().reset_bot()
|
||||
model = conf().get("model") or const.GPT35
|
||||
ok, result = True, "模型设置为: " + str(model)
|
||||
elif cmd == "id":
|
||||
ok, result = True, user
|
||||
elif cmd == "set_openai_api_key":
|
||||
@@ -297,8 +313,10 @@ class Godcmd(Plugin):
|
||||
except Exception as e:
|
||||
ok, result = False, "你没有设置私有GPT模型"
|
||||
elif cmd == "reset":
|
||||
if bottype in [const.OPEN_AI, const.CHATGPT, const.CHATGPTONAZURE, const.LINKAI]:
|
||||
if bottype in [const.OPEN_AI, const.CHATGPT, const.CHATGPTONAZURE, const.LINKAI, const.BAIDU, const.XUNFEI, const.QWEN, const.GEMINI]:
|
||||
bot.sessions.clear_session(session_id)
|
||||
if Bridge().chat_bots.get(bottype):
|
||||
Bridge().chat_bots.get(bottype).sessions.clear_session(session_id)
|
||||
channel.cancel_session(session_id)
|
||||
ok, result = True, "会话已重置"
|
||||
else:
|
||||
@@ -320,15 +338,20 @@ class Godcmd(Plugin):
|
||||
load_config()
|
||||
ok, result = True, "配置已重载"
|
||||
elif cmd == "resetall":
|
||||
if bottype in [const.OPEN_AI, const.CHATGPT, const.CHATGPTONAZURE, const.LINKAI]:
|
||||
if bottype in [const.OPEN_AI, const.CHATGPT, const.CHATGPTONAZURE, const.LINKAI,
|
||||
const.BAIDU, const.XUNFEI, const.QWEN, const.GEMINI]:
|
||||
channel.cancel_all_session()
|
||||
bot.sessions.clear_all_session()
|
||||
ok, result = True, "重置所有会话成功"
|
||||
else:
|
||||
ok, result = False, "当前对话机器人不支持重置会话"
|
||||
elif cmd == "debug":
|
||||
logger.setLevel("DEBUG")
|
||||
ok, result = True, "DEBUG模式已开启"
|
||||
if logger.getEffectiveLevel() == logging.DEBUG: # 判断当前日志模式是否DEBUG
|
||||
logger.setLevel(logging.INFO)
|
||||
ok, result = True, "DEBUG模式已关闭"
|
||||
else:
|
||||
logger.setLevel(logging.DEBUG)
|
||||
ok, result = True, "DEBUG模式已开启"
|
||||
elif cmd == "plist":
|
||||
plugins = PluginManager().list_plugins()
|
||||
ok = True
|
||||
@@ -429,12 +452,34 @@ class Godcmd(Plugin):
|
||||
password = args[0]
|
||||
if password == self.password:
|
||||
self.admin_users.append(userid)
|
||||
global_config["admin_users"].append(userid)
|
||||
return True, "认证成功"
|
||||
elif password == self.temp_password:
|
||||
self.admin_users.append(userid)
|
||||
global_config["admin_users"].append(userid)
|
||||
return True, "认证成功,请尽快设置口令"
|
||||
else:
|
||||
return False, "认证失败"
|
||||
|
||||
def get_help_text(self, isadmin=False, isgroup=False, **kwargs):
|
||||
return get_help_text(isadmin, isgroup)
|
||||
|
||||
|
||||
def is_admin_in_group(self, context):
|
||||
if context["isgroup"]:
|
||||
return context.kwargs.get("msg").actual_user_id in global_config["admin_users"]
|
||||
return False
|
||||
|
||||
|
||||
def model_mapping(self, model) -> str:
|
||||
if model == "gpt-4-turbo":
|
||||
return const.GPT4_TURBO_PREVIEW
|
||||
return model
|
||||
|
||||
def reload(self):
|
||||
gconf = plugin_config[self.name]
|
||||
if gconf:
|
||||
if gconf.get("password"):
|
||||
self.password = gconf["password"]
|
||||
if gconf.get("admin_users"):
|
||||
self.admin_users = gconf["admin_users"]
|
||||
|
||||
@@ -6,6 +6,7 @@ from bridge.reply import Reply, ReplyType
|
||||
from channel.chat_message import ChatMessage
|
||||
from common.log import logger
|
||||
from plugins import *
|
||||
from config import conf
|
||||
|
||||
|
||||
@plugins.register(
|
||||
@@ -21,27 +22,49 @@ class Hello(Plugin):
|
||||
super().__init__()
|
||||
self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context
|
||||
logger.info("[Hello] inited")
|
||||
self.config = super().load_config()
|
||||
|
||||
def on_handle_context(self, e_context: EventContext):
|
||||
if e_context["context"].type not in [
|
||||
ContextType.TEXT,
|
||||
ContextType.JOIN_GROUP,
|
||||
ContextType.PATPAT,
|
||||
ContextType.EXIT_GROUP
|
||||
]:
|
||||
return
|
||||
|
||||
if e_context["context"].type == ContextType.JOIN_GROUP:
|
||||
if "group_welcome_msg" in conf():
|
||||
reply = Reply()
|
||||
reply.type = ReplyType.TEXT
|
||||
reply.content = conf().get("group_welcome_msg", "")
|
||||
e_context["reply"] = reply
|
||||
e_context.action = EventAction.BREAK_PASS # 事件结束,并跳过处理context的默认逻辑
|
||||
return
|
||||
e_context["context"].type = ContextType.TEXT
|
||||
msg: ChatMessage = e_context["context"]["msg"]
|
||||
e_context["context"].content = f'请你随机使用一种风格说一句问候语来欢迎新用户"{msg.actual_user_nickname}"加入群聊。'
|
||||
e_context.action = EventAction.CONTINUE # 事件继续,交付给下个插件或默认逻辑
|
||||
e_context.action = EventAction.BREAK # 事件结束,进入默认处理逻辑
|
||||
if not self.config or not self.config.get("use_character_desc"):
|
||||
e_context["context"]["generate_breaked_by"] = EventAction.BREAK
|
||||
return
|
||||
|
||||
|
||||
if e_context["context"].type == ContextType.EXIT_GROUP:
|
||||
if conf().get("group_chat_exit_group"):
|
||||
e_context["context"].type = ContextType.TEXT
|
||||
msg: ChatMessage = e_context["context"]["msg"]
|
||||
e_context["context"].content = f'请你随机使用一种风格跟其他群用户说他违反规则"{msg.actual_user_nickname}"退出群聊。'
|
||||
e_context.action = EventAction.BREAK # 事件结束,进入默认处理逻辑
|
||||
return
|
||||
e_context.action = EventAction.BREAK
|
||||
return
|
||||
|
||||
if e_context["context"].type == ContextType.PATPAT:
|
||||
e_context["context"].type = ContextType.TEXT
|
||||
msg: ChatMessage = e_context["context"]["msg"]
|
||||
e_context["context"].content = f"请你随机使用一种风格介绍你自己,并告诉用户输入#help可以查看帮助信息。"
|
||||
e_context.action = EventAction.CONTINUE # 事件继续,交付给下个插件或默认逻辑
|
||||
e_context.action = EventAction.BREAK # 事件结束,进入默认处理逻辑
|
||||
if not self.config or not self.config.get("use_character_desc"):
|
||||
e_context["context"]["generate_breaked_by"] = EventAction.BREAK
|
||||
return
|
||||
|
||||
content = e_context["context"].content
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
import json
|
||||
import os
|
||||
|
||||
import requests
|
||||
import plugins
|
||||
from bridge.context import ContextType
|
||||
from bridge.reply import Reply, ReplyType
|
||||
@@ -51,15 +51,46 @@ class Keyword(Plugin):
|
||||
content = e_context["context"].content.strip()
|
||||
logger.debug("[keyword] on_handle_context. content: %s" % content)
|
||||
if content in self.keyword:
|
||||
logger.debug(f"[keyword] 匹配到关键字【{content}】")
|
||||
logger.info(f"[keyword] 匹配到关键字【{content}】")
|
||||
reply_text = self.keyword[content]
|
||||
|
||||
reply = Reply()
|
||||
reply.type = ReplyType.TEXT
|
||||
reply.content = reply_text
|
||||
# 判断匹配内容的类型
|
||||
if (reply_text.startswith("http://") or reply_text.startswith("https://")) and any(reply_text.endswith(ext) for ext in [".jpg", ".jpeg", ".png", ".gif", ".img"]):
|
||||
# 如果是以 http:// 或 https:// 开头,且".jpg", ".jpeg", ".png", ".gif", ".img"结尾,则认为是图片 URL。
|
||||
reply = Reply()
|
||||
reply.type = ReplyType.IMAGE_URL
|
||||
reply.content = reply_text
|
||||
|
||||
elif (reply_text.startswith("http://") or reply_text.startswith("https://")) and any(reply_text.endswith(ext) for ext in [".pdf", ".doc", ".docx", ".xls", "xlsx",".zip", ".rar"]):
|
||||
# 如果是以 http:// 或 https:// 开头,且".pdf", ".doc", ".docx", ".xls", "xlsx",".zip", ".rar"结尾,则下载文件到tmp目录并发送给用户
|
||||
file_path = "tmp"
|
||||
if not os.path.exists(file_path):
|
||||
os.makedirs(file_path)
|
||||
file_name = reply_text.split("/")[-1] # 获取文件名
|
||||
file_path = os.path.join(file_path, file_name)
|
||||
response = requests.get(reply_text)
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(response.content)
|
||||
#channel/wechat/wechat_channel.py和channel/wechat_channel.py中缺少ReplyType.FILE类型。
|
||||
reply = Reply()
|
||||
reply.type = ReplyType.FILE
|
||||
reply.content = file_path
|
||||
|
||||
elif (reply_text.startswith("http://") or reply_text.startswith("https://")) and any(reply_text.endswith(ext) for ext in [".mp4"]):
|
||||
# 如果是以 http:// 或 https:// 开头,且".mp4"结尾,则下载视频到tmp目录并发送给用户
|
||||
reply = Reply()
|
||||
reply.type = ReplyType.VIDEO_URL
|
||||
reply.content = reply_text
|
||||
|
||||
else:
|
||||
# 否则认为是普通文本
|
||||
reply = Reply()
|
||||
reply.type = ReplyType.TEXT
|
||||
reply.content = reply_text
|
||||
|
||||
e_context["reply"] = reply
|
||||
e_context.action = EventAction.BREAK_PASS # 事件结束,并跳过处理context的默认逻辑
|
||||
|
||||
|
||||
def get_help_text(self, **kwargs):
|
||||
help_text = "关键词过滤"
|
||||
return help_text
|
||||
|
||||
109
plugins/linkai/README.md
Normal file
109
plugins/linkai/README.md
Normal file
@@ -0,0 +1,109 @@
|
||||
## 插件说明
|
||||
|
||||
基于 LinkAI 提供的知识库、Midjourney绘画、文档对话等能力对机器人的功能进行增强。平台地址: https://link-ai.tech/console
|
||||
|
||||
## 插件配置
|
||||
|
||||
将 `plugins/linkai` 目录下的 `config.json.template` 配置模板复制为最终生效的 `config.json`。 (如果未配置则会默认使用`config.json.template`模板中配置,但功能默认关闭,需要可通过指令进行开启)。
|
||||
|
||||
以下是插件配置项说明:
|
||||
|
||||
```bash
|
||||
{
|
||||
"group_app_map": { # 群聊 和 应用编码 的映射关系
|
||||
"测试群名称1": "default", # 表示在名称为 "测试群名称1" 的群聊中将使用app_code 为 default 的应用
|
||||
"测试群名称2": "Kv2fXJcH"
|
||||
},
|
||||
"midjourney": {
|
||||
"enabled": true, # midjourney 绘画开关
|
||||
"auto_translate": true, # 是否自动将提示词翻译为英文
|
||||
"img_proxy": true, # 是否对生成的图片使用代理,如果你是国外服务器,将这一项设置为false会获得更快的生成速度
|
||||
"max_tasks": 3, # 支持同时提交的总任务个数
|
||||
"max_tasks_per_user": 1, # 支持单个用户同时提交的任务个数
|
||||
"use_image_create_prefix": true # 是否使用全局的绘画触发词,如果开启将同时支持由`config.json`中的 image_create_prefix 配置触发
|
||||
},
|
||||
"summary": {
|
||||
"enabled": true, # 文档总结和对话功能开关
|
||||
"group_enabled": true, # 是否支持群聊开启
|
||||
"max_file_size": 5000, # 文件的大小限制,单位KB,默认为5M,超过该大小直接忽略
|
||||
"type": ["FILE", "SHARING", "IMAGE"] # 支持总结的类型,分别表示 文件、分享链接、图片,其中文件和链接默认打开,图片默认关闭
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
根目录 `config.json` 中配置,`API_KEY` 在 [控制台](https://link-ai.tech/console/interface) 中创建并复制过来:
|
||||
|
||||
```bash
|
||||
"linkai_api_key": "Link_xxxxxxxxx"
|
||||
```
|
||||
|
||||
注意:
|
||||
|
||||
- 配置项中 `group_app_map` 部分是用于映射群聊与LinkAI平台上的应用, `midjourney` 部分是 mj 画图的配置,`summary` 部分是文档总结及对话功能的配置。三部分的配置相互独立,可按需开启
|
||||
- 实际 `config.json` 配置中应保证json格式,不应携带 '#' 及后面的注释
|
||||
- 如果是`docker`部署,可通过映射 `plugins/config.json` 到容器中来完成插件配置,参考[文档](https://github.com/zhayujie/chatgpt-on-wechat#3-%E6%8F%92%E4%BB%B6%E4%BD%BF%E7%94%A8)
|
||||
|
||||
## 插件使用
|
||||
|
||||
> 使用插件中的知识库管理功能需要首先开启`linkai`对话,依赖全局 `config.json` 中的 `use_linkai` 和 `linkai_api_key` 配置;而midjourney绘画 和 summary文档总结对话功能则只需填写 `linkai_api_key` 配置,`use_linkai` 无论是否关闭均可使用。具体可参考 [详细文档](https://link-ai.tech/platform/link-app/wechat)。
|
||||
|
||||
完成配置后运行项目,会自动运行插件,输入 `#help linkai` 可查看插件功能。
|
||||
|
||||
### 1.知识库管理功能
|
||||
|
||||
提供在不同群聊使用不同应用的功能。可以在上述 `group_app_map` 配置中固定映射关系,也可以通过指令在群中快速完成切换。
|
||||
|
||||
应用切换指令需要首先完成管理员 (`godcmd`) 插件的认证,然后按以下格式输入:
|
||||
|
||||
`$linkai app {app_code}`
|
||||
|
||||
例如输入 `$linkai app Kv2fXJcH`,即将当前群聊与 app_code为 Kv2fXJcH 的应用绑定。
|
||||
|
||||
另外,还可以通过 `$linkai close` 来一键关闭linkai对话,此时就会使用默认的openai接口;同理,发送 `$linkai open` 可以再次开启。
|
||||
|
||||
### 2.Midjourney绘画功能
|
||||
|
||||
若未配置 `plugins/linkai/config.json`,默认会关闭画图功能,直接使用 `$mj open` 可基于默认配置直接使用mj画图。
|
||||
|
||||
指令格式:
|
||||
|
||||
```
|
||||
- 图片生成: $mj 描述词1, 描述词2..
|
||||
- 图片放大: $mju 图片ID 图片序号
|
||||
- 图片变换: $mjv 图片ID 图片序号
|
||||
- 重置: $mjr 图片ID
|
||||
```
|
||||
|
||||
例如:
|
||||
|
||||
```
|
||||
"$mj a little cat, white --ar 9:16"
|
||||
"$mju 1105592717188272288 2"
|
||||
"$mjv 11055927171882 2"
|
||||
"$mjr 11055927171882"
|
||||
```
|
||||
|
||||
注意事项:
|
||||
1. 使用 `$mj open` 和 `$mj close` 指令可以快速打开和关闭绘图功能
|
||||
2. 海外环境部署请将 `img_proxy` 设置为 `false`
|
||||
3. 开启 `use_image_create_prefix` 配置后可直接复用全局画图触发词,以"画"开头便可以生成图片。
|
||||
4. 提示词内容中包含敏感词或者参数格式错误可能导致绘画失败,生成失败不消耗积分
|
||||
5. 若未收到图片可能有两种可能,一种是收到了图片但微信发送失败,可以在后台日志查看有没有获取到图片url,一般原因是受到了wx限制,可以稍后重试或更换账号尝试;另一种情况是图片提示词存在疑似违规,mj不会直接提示错误但会在画图后删掉原图导致程序无法获取,这种情况不消耗积分。
|
||||
|
||||
### 3.文档总结对话功能
|
||||
|
||||
#### 配置
|
||||
|
||||
该功能依赖 LinkAI的知识库及对话功能,需要在项目根目录的config.json中设置 `linkai_api_key`, 同时根据上述插件配置说明,在插件config.json添加 `summary` 部分的配置,设置 `enabled` 为 true。
|
||||
|
||||
如果不想创建 `plugins/linkai/config.json` 配置,可以直接通过 `$linkai sum open` 指令开启该功能。
|
||||
|
||||
#### 使用
|
||||
|
||||
功能开启后,向机器人发送 **文件**、 **分享链接卡片**、**图片** 即可生成摘要,进一步可以与文件或链接的内容进行多轮对话。如果需要关闭某种类型的内容总结,设置 `summary`配置中的type字段即可。
|
||||
|
||||
#### 限制
|
||||
|
||||
1. 文件目前 支持 `txt`, `docx`, `pdf`, `md`, `csv`格式,文件大小由 `max_file_size` 限制,最大不超过15M,文件字数最多可支持百万字的文件。但不建议上传字数过多的文件,一是token消耗过大,二是摘要很难覆盖到全部内容,只能通过多轮对话来了解细节。
|
||||
2. 分享链接 目前仅支持 公众号文章,后续会支持更多文章类型及视频链接等
|
||||
3. 总结及对话的 费用与 LinkAI 3.5-4K 模型的计费方式相同,按文档内容的tokens进行计算
|
||||
1
plugins/linkai/__init__.py
Normal file
1
plugins/linkai/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .linkai import *
|
||||
20
plugins/linkai/config.json.template
Normal file
20
plugins/linkai/config.json.template
Normal file
@@ -0,0 +1,20 @@
|
||||
{
|
||||
"group_app_map": {
|
||||
"测试群名1": "default",
|
||||
"测试群名2": "Kv2fXJcH"
|
||||
},
|
||||
"midjourney": {
|
||||
"enabled": true,
|
||||
"auto_translate": true,
|
||||
"img_proxy": true,
|
||||
"max_tasks": 3,
|
||||
"max_tasks_per_user": 1,
|
||||
"use_image_create_prefix": true
|
||||
},
|
||||
"summary": {
|
||||
"enabled": true,
|
||||
"group_enabled": true,
|
||||
"max_file_size": 5000,
|
||||
"type": ["FILE", "SHARING"]
|
||||
}
|
||||
}
|
||||
287
plugins/linkai/linkai.py
Normal file
287
plugins/linkai/linkai.py
Normal file
@@ -0,0 +1,287 @@
|
||||
import plugins
|
||||
from bridge.context import ContextType
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from plugins import *
|
||||
from .midjourney import MJBot
|
||||
from .summary import LinkSummary
|
||||
from bridge import bridge
|
||||
from common.expired_dict import ExpiredDict
|
||||
from common import const
|
||||
import os
|
||||
from .utils import Util
|
||||
|
||||
@plugins.register(
|
||||
name="linkai",
|
||||
desc="A plugin that supports knowledge base and midjourney drawing.",
|
||||
version="0.1.0",
|
||||
author="https://link-ai.tech",
|
||||
desire_priority=99
|
||||
)
|
||||
class LinkAI(Plugin):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context
|
||||
self.config = super().load_config()
|
||||
if not self.config:
|
||||
# 未加载到配置,使用模板中的配置
|
||||
self.config = self._load_config_template()
|
||||
if self.config:
|
||||
self.mj_bot = MJBot(self.config.get("midjourney"))
|
||||
self.sum_config = {}
|
||||
if self.config:
|
||||
self.sum_config = self.config.get("summary")
|
||||
logger.info(f"[LinkAI] inited, config={self.config}")
|
||||
|
||||
|
||||
def on_handle_context(self, e_context: EventContext):
|
||||
"""
|
||||
消息处理逻辑
|
||||
:param e_context: 消息上下文
|
||||
"""
|
||||
if not self.config:
|
||||
return
|
||||
|
||||
context = e_context['context']
|
||||
if context.type not in [ContextType.TEXT, ContextType.IMAGE, ContextType.IMAGE_CREATE, ContextType.FILE, ContextType.SHARING]:
|
||||
# filter content no need solve
|
||||
return
|
||||
|
||||
if context.type in [ContextType.FILE, ContextType.IMAGE] and self._is_summary_open(context):
|
||||
# 文件处理
|
||||
context.get("msg").prepare()
|
||||
file_path = context.content
|
||||
if not LinkSummary().check_file(file_path, self.sum_config):
|
||||
return
|
||||
if context.type != ContextType.IMAGE:
|
||||
_send_info(e_context, "正在为你加速生成摘要,请稍后")
|
||||
res = LinkSummary().summary_file(file_path)
|
||||
if not res:
|
||||
if context.type != ContextType.IMAGE:
|
||||
_set_reply_text("因为神秘力量无法获取内容,请稍后再试吧", e_context, level=ReplyType.TEXT)
|
||||
return
|
||||
summary_text = res.get("summary")
|
||||
if context.type != ContextType.IMAGE:
|
||||
USER_FILE_MAP[_find_user_id(context) + "-sum_id"] = res.get("summary_id")
|
||||
summary_text += "\n\n💬 发送 \"开启对话\" 可以开启与文件内容的对话"
|
||||
_set_reply_text(summary_text, e_context, level=ReplyType.TEXT)
|
||||
os.remove(file_path)
|
||||
return
|
||||
|
||||
if (context.type == ContextType.SHARING and self._is_summary_open(context)) or \
|
||||
(context.type == ContextType.TEXT and LinkSummary().check_url(context.content)):
|
||||
if not LinkSummary().check_url(context.content):
|
||||
return
|
||||
_send_info(e_context, "正在为你加速生成摘要,请稍后")
|
||||
res = LinkSummary().summary_url(context.content)
|
||||
if not res:
|
||||
_set_reply_text("因为神秘力量无法获取文章内容,请稍后再试吧~", e_context, level=ReplyType.TEXT)
|
||||
return
|
||||
_set_reply_text(res.get("summary") + "\n\n💬 发送 \"开启对话\" 可以开启与文章内容的对话", e_context, level=ReplyType.TEXT)
|
||||
USER_FILE_MAP[_find_user_id(context) + "-sum_id"] = res.get("summary_id")
|
||||
return
|
||||
|
||||
mj_type = self.mj_bot.judge_mj_task_type(e_context)
|
||||
if mj_type:
|
||||
# MJ作图任务处理
|
||||
self.mj_bot.process_mj_task(mj_type, e_context)
|
||||
return
|
||||
|
||||
if context.content.startswith(f"{_get_trigger_prefix()}linkai"):
|
||||
# 应用管理功能
|
||||
self._process_admin_cmd(e_context)
|
||||
return
|
||||
|
||||
if context.type == ContextType.TEXT and context.content == "开启对话" and _find_sum_id(context):
|
||||
# 文本对话
|
||||
_send_info(e_context, "正在为你开启对话,请稍后")
|
||||
res = LinkSummary().summary_chat(_find_sum_id(context))
|
||||
if not res:
|
||||
_set_reply_text("开启对话失败,请稍后再试吧", e_context)
|
||||
return
|
||||
USER_FILE_MAP[_find_user_id(context) + "-file_id"] = res.get("file_id")
|
||||
_set_reply_text("💡你可以问我关于这篇文章的任何问题,例如:\n\n" + res.get("questions") + "\n\n发送 \"退出对话\" 可以关闭与文章的对话", e_context, level=ReplyType.TEXT)
|
||||
return
|
||||
|
||||
if context.type == ContextType.TEXT and context.content == "退出对话" and _find_file_id(context):
|
||||
del USER_FILE_MAP[_find_user_id(context) + "-file_id"]
|
||||
bot = bridge.Bridge().find_chat_bot(const.LINKAI)
|
||||
bot.sessions.clear_session(context["session_id"])
|
||||
_set_reply_text("对话已退出", e_context, level=ReplyType.TEXT)
|
||||
return
|
||||
|
||||
if context.type == ContextType.TEXT and _find_file_id(context):
|
||||
bot = bridge.Bridge().find_chat_bot(const.LINKAI)
|
||||
context.kwargs["file_id"] = _find_file_id(context)
|
||||
reply = bot.reply(context.content, context)
|
||||
e_context["reply"] = reply
|
||||
e_context.action = EventAction.BREAK_PASS
|
||||
return
|
||||
|
||||
|
||||
if self._is_chat_task(e_context):
|
||||
# 文本对话任务处理
|
||||
self._process_chat_task(e_context)
|
||||
|
||||
|
||||
# 插件管理功能
|
||||
def _process_admin_cmd(self, e_context: EventContext):
|
||||
context = e_context['context']
|
||||
cmd = context.content.split()
|
||||
if len(cmd) == 1 or (len(cmd) == 2 and cmd[1] == "help"):
|
||||
_set_reply_text(self.get_help_text(verbose=True), e_context, level=ReplyType.INFO)
|
||||
return
|
||||
|
||||
if len(cmd) == 2 and (cmd[1] == "open" or cmd[1] == "close"):
|
||||
# 知识库开关指令
|
||||
if not Util.is_admin(e_context):
|
||||
_set_reply_text("需要管理员权限执行", e_context, level=ReplyType.ERROR)
|
||||
return
|
||||
is_open = True
|
||||
tips_text = "开启"
|
||||
if cmd[1] == "close":
|
||||
tips_text = "关闭"
|
||||
is_open = False
|
||||
conf()["use_linkai"] = is_open
|
||||
bridge.Bridge().reset_bot()
|
||||
_set_reply_text(f"LinkAI对话功能{tips_text}", e_context, level=ReplyType.INFO)
|
||||
return
|
||||
|
||||
if len(cmd) == 3 and cmd[1] == "app":
|
||||
# 知识库应用切换指令
|
||||
if not context.kwargs.get("isgroup"):
|
||||
_set_reply_text("该指令需在群聊中使用", e_context, level=ReplyType.ERROR)
|
||||
return
|
||||
if not Util.is_admin(e_context):
|
||||
_set_reply_text("需要管理员权限执行", e_context, level=ReplyType.ERROR)
|
||||
return
|
||||
app_code = cmd[2]
|
||||
group_name = context.kwargs.get("msg").from_user_nickname
|
||||
group_mapping = self.config.get("group_app_map")
|
||||
if group_mapping:
|
||||
group_mapping[group_name] = app_code
|
||||
else:
|
||||
self.config["group_app_map"] = {group_name: app_code}
|
||||
# 保存插件配置
|
||||
super().save_config(self.config)
|
||||
_set_reply_text(f"应用设置成功: {app_code}", e_context, level=ReplyType.INFO)
|
||||
return
|
||||
|
||||
if len(cmd) == 3 and cmd[1] == "sum" and (cmd[2] == "open" or cmd[2] == "close"):
|
||||
# 知识库开关指令
|
||||
if not Util.is_admin(e_context):
|
||||
_set_reply_text("需要管理员权限执行", e_context, level=ReplyType.ERROR)
|
||||
return
|
||||
is_open = True
|
||||
tips_text = "开启"
|
||||
if cmd[2] == "close":
|
||||
tips_text = "关闭"
|
||||
is_open = False
|
||||
if not self.sum_config:
|
||||
_set_reply_text(f"插件未启用summary功能,请参考以下链添加插件配置\n\nhttps://github.com/zhayujie/chatgpt-on-wechat/blob/master/plugins/linkai/README.md", e_context, level=ReplyType.INFO)
|
||||
else:
|
||||
self.sum_config["enabled"] = is_open
|
||||
_set_reply_text(f"文章总结功能{tips_text}", e_context, level=ReplyType.INFO)
|
||||
return
|
||||
|
||||
_set_reply_text(f"指令错误,请输入{_get_trigger_prefix()}linkai help 获取帮助", e_context,
|
||||
level=ReplyType.INFO)
|
||||
return
|
||||
|
||||
def _is_summary_open(self, context) -> bool:
|
||||
if not self.sum_config or not self.sum_config.get("enabled"):
|
||||
return False
|
||||
if context.kwargs.get("isgroup") and not self.sum_config.get("group_enabled"):
|
||||
return False
|
||||
support_type = self.sum_config.get("type") or ["FILE", "SHARING"]
|
||||
if context.type.name not in support_type:
|
||||
return False
|
||||
return True
|
||||
|
||||
# LinkAI 对话任务处理
|
||||
def _is_chat_task(self, e_context: EventContext):
|
||||
context = e_context['context']
|
||||
# 群聊应用管理
|
||||
return self.config.get("group_app_map") and context.kwargs.get("isgroup")
|
||||
|
||||
def _process_chat_task(self, e_context: EventContext):
|
||||
"""
|
||||
处理LinkAI对话任务
|
||||
:param e_context: 对话上下文
|
||||
"""
|
||||
context = e_context['context']
|
||||
# 群聊应用管理
|
||||
group_name = context.get("msg").from_user_nickname
|
||||
app_code = self._fetch_group_app_code(group_name)
|
||||
if app_code:
|
||||
context.kwargs['app_code'] = app_code
|
||||
|
||||
def _fetch_group_app_code(self, group_name: str) -> str:
|
||||
"""
|
||||
根据群聊名称获取对应的应用code
|
||||
:param group_name: 群聊名称
|
||||
:return: 应用code
|
||||
"""
|
||||
group_mapping = self.config.get("group_app_map")
|
||||
if group_mapping:
|
||||
app_code = group_mapping.get(group_name) or group_mapping.get("ALL_GROUP")
|
||||
return app_code
|
||||
|
||||
def get_help_text(self, verbose=False, **kwargs):
|
||||
trigger_prefix = _get_trigger_prefix()
|
||||
help_text = "用于集成 LinkAI 提供的知识库、Midjourney绘画、文档总结、联网搜索等能力。\n\n"
|
||||
if not verbose:
|
||||
return help_text
|
||||
help_text += f'📖 知识库\n - 群聊中指定应用: {trigger_prefix}linkai app 应用编码\n'
|
||||
help_text += f' - {trigger_prefix}linkai open: 开启对话\n'
|
||||
help_text += f' - {trigger_prefix}linkai close: 关闭对话\n'
|
||||
help_text += f'\n例如: \n"{trigger_prefix}linkai app Kv2fXJcH"\n\n'
|
||||
help_text += f"🎨 绘画\n - 生成: {trigger_prefix}mj 描述词1, 描述词2.. \n - 放大: {trigger_prefix}mju 图片ID 图片序号\n - 变换: {trigger_prefix}mjv 图片ID 图片序号\n - 重置: {trigger_prefix}mjr 图片ID"
|
||||
help_text += f"\n\n例如:\n\"{trigger_prefix}mj a little cat, white --ar 9:16\"\n\"{trigger_prefix}mju 11055927171882 2\""
|
||||
help_text += f"\n\"{trigger_prefix}mjv 11055927171882 2\"\n\"{trigger_prefix}mjr 11055927171882\""
|
||||
help_text += f"\n\n💡 文档总结和对话\n - 开启: {trigger_prefix}linkai sum open\n - 使用: 发送文件、公众号文章等可生成摘要,并与内容对话"
|
||||
return help_text
|
||||
|
||||
def _load_config_template(self):
|
||||
logger.debug("No LinkAI plugin config.json, use plugins/linkai/config.json.template")
|
||||
try:
|
||||
plugin_config_path = os.path.join(self.path, "config.json.template")
|
||||
if os.path.exists(plugin_config_path):
|
||||
with open(plugin_config_path, "r", encoding="utf-8") as f:
|
||||
plugin_conf = json.load(f)
|
||||
plugin_conf["midjourney"]["enabled"] = False
|
||||
plugin_conf["summary"]["enabled"] = False
|
||||
return plugin_conf
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
|
||||
|
||||
def _send_info(e_context: EventContext, content: str):
|
||||
reply = Reply(ReplyType.TEXT, content)
|
||||
channel = e_context["channel"]
|
||||
channel.send(reply, e_context["context"])
|
||||
|
||||
|
||||
def _find_user_id(context):
|
||||
if context["isgroup"]:
|
||||
return context.kwargs.get("msg").actual_user_id
|
||||
else:
|
||||
return context["receiver"]
|
||||
|
||||
|
||||
def _set_reply_text(content: str, e_context: EventContext, level: ReplyType = ReplyType.ERROR):
|
||||
reply = Reply(level, content)
|
||||
e_context["reply"] = reply
|
||||
e_context.action = EventAction.BREAK_PASS
|
||||
|
||||
def _get_trigger_prefix():
|
||||
return conf().get("plugin_trigger_prefix", "$")
|
||||
|
||||
def _find_sum_id(context):
|
||||
return USER_FILE_MAP.get(_find_user_id(context) + "-sum_id")
|
||||
|
||||
def _find_file_id(context):
|
||||
user_id = _find_user_id(context)
|
||||
if user_id:
|
||||
return USER_FILE_MAP.get(user_id + "-file_id")
|
||||
|
||||
USER_FILE_MAP = ExpiredDict(conf().get("expires_in_seconds") or 60 * 30)
|
||||
432
plugins/linkai/midjourney.py
Normal file
432
plugins/linkai/midjourney.py
Normal file
@@ -0,0 +1,432 @@
|
||||
from enum import Enum
|
||||
from config import conf
|
||||
from common.log import logger
|
||||
import requests
|
||||
import threading
|
||||
import time
|
||||
from bridge.reply import Reply, ReplyType
|
||||
import asyncio
|
||||
from bridge.context import ContextType
|
||||
from plugins import EventContext, EventAction
|
||||
from .utils import Util
|
||||
|
||||
INVALID_REQUEST = 410
|
||||
NOT_FOUND_ORIGIN_IMAGE = 461
|
||||
NOT_FOUND_TASK = 462
|
||||
|
||||
|
||||
class TaskType(Enum):
|
||||
GENERATE = "generate"
|
||||
UPSCALE = "upscale"
|
||||
VARIATION = "variation"
|
||||
RESET = "reset"
|
||||
|
||||
def __str__(self):
|
||||
return self.name
|
||||
|
||||
|
||||
class Status(Enum):
|
||||
PENDING = "pending"
|
||||
FINISHED = "finished"
|
||||
EXPIRED = "expired"
|
||||
ABORTED = "aborted"
|
||||
|
||||
def __str__(self):
|
||||
return self.name
|
||||
|
||||
|
||||
class TaskMode(Enum):
|
||||
FAST = "fast"
|
||||
RELAX = "relax"
|
||||
|
||||
|
||||
task_name_mapping = {
|
||||
TaskType.GENERATE.name: "生成",
|
||||
TaskType.UPSCALE.name: "放大",
|
||||
TaskType.VARIATION.name: "变换",
|
||||
TaskType.RESET.name: "重新生成",
|
||||
}
|
||||
|
||||
|
||||
class MJTask:
|
||||
def __init__(self, id, user_id: str, task_type: TaskType, raw_prompt=None, expires: int = 60 * 6,
|
||||
status=Status.PENDING):
|
||||
self.id = id
|
||||
self.user_id = user_id
|
||||
self.task_type = task_type
|
||||
self.raw_prompt = raw_prompt
|
||||
self.send_func = None # send_func(img_url)
|
||||
self.expiry_time = time.time() + expires
|
||||
self.status = status
|
||||
self.img_url = None # url
|
||||
self.img_id = None
|
||||
|
||||
def __str__(self):
|
||||
return f"id={self.id}, user_id={self.user_id}, task_type={self.task_type}, status={self.status}, img_id={self.img_id}"
|
||||
|
||||
|
||||
# midjourney bot
|
||||
class MJBot:
|
||||
def __init__(self, config):
|
||||
self.base_url = conf().get("linkai_api_base", "https://api.link-ai.chat") + "/v1/img/midjourney"
|
||||
self.headers = {"Authorization": "Bearer " + conf().get("linkai_api_key")}
|
||||
self.config = config
|
||||
self.tasks = {}
|
||||
self.temp_dict = {}
|
||||
self.tasks_lock = threading.Lock()
|
||||
self.event_loop = asyncio.new_event_loop()
|
||||
|
||||
def judge_mj_task_type(self, e_context: EventContext):
|
||||
"""
|
||||
判断MJ任务的类型
|
||||
:param e_context: 上下文
|
||||
:return: 任务类型枚举
|
||||
"""
|
||||
if not self.config:
|
||||
return None
|
||||
trigger_prefix = conf().get("plugin_trigger_prefix", "$")
|
||||
context = e_context['context']
|
||||
if context.type == ContextType.TEXT:
|
||||
cmd_list = context.content.split(maxsplit=1)
|
||||
if not cmd_list:
|
||||
return None
|
||||
if cmd_list[0].lower() == f"{trigger_prefix}mj":
|
||||
return TaskType.GENERATE
|
||||
elif cmd_list[0].lower() == f"{trigger_prefix}mju":
|
||||
return TaskType.UPSCALE
|
||||
elif cmd_list[0].lower() == f"{trigger_prefix}mjv":
|
||||
return TaskType.VARIATION
|
||||
elif cmd_list[0].lower() == f"{trigger_prefix}mjr":
|
||||
return TaskType.RESET
|
||||
elif context.type == ContextType.IMAGE_CREATE and self.config.get("use_image_create_prefix") and self.config.get("enabled"):
|
||||
return TaskType.GENERATE
|
||||
|
||||
def process_mj_task(self, mj_type: TaskType, e_context: EventContext):
|
||||
"""
|
||||
处理mj任务
|
||||
:param mj_type: mj任务类型
|
||||
:param e_context: 对话上下文
|
||||
"""
|
||||
context = e_context['context']
|
||||
session_id = context["session_id"]
|
||||
cmd = context.content.split(maxsplit=1)
|
||||
if len(cmd) == 1 and context.type == ContextType.TEXT:
|
||||
# midjourney 帮助指令
|
||||
self._set_reply_text(self.get_help_text(verbose=True), e_context, level=ReplyType.INFO)
|
||||
return
|
||||
|
||||
if len(cmd) == 2 and (cmd[1] == "open" or cmd[1] == "close"):
|
||||
if not Util.is_admin(e_context):
|
||||
Util.set_reply_text("需要管理员权限执行", e_context, level=ReplyType.ERROR)
|
||||
return
|
||||
# midjourney 开关指令
|
||||
is_open = True
|
||||
tips_text = "开启"
|
||||
if cmd[1] == "close":
|
||||
tips_text = "关闭"
|
||||
is_open = False
|
||||
self.config["enabled"] = is_open
|
||||
self._set_reply_text(f"Midjourney绘画已{tips_text}", e_context, level=ReplyType.INFO)
|
||||
return
|
||||
|
||||
if not self.config.get("enabled"):
|
||||
logger.warn("Midjourney绘画未开启,请查看 plugins/linkai/config.json 中的配置")
|
||||
self._set_reply_text(f"Midjourney绘画未开启", e_context, level=ReplyType.INFO)
|
||||
return
|
||||
|
||||
if not self._check_rate_limit(session_id, e_context):
|
||||
logger.warn("[MJ] midjourney task exceed rate limit")
|
||||
return
|
||||
|
||||
if mj_type == TaskType.GENERATE:
|
||||
if context.type == ContextType.IMAGE_CREATE:
|
||||
raw_prompt = context.content
|
||||
else:
|
||||
# 图片生成
|
||||
raw_prompt = cmd[1]
|
||||
reply = self.generate(raw_prompt, session_id, e_context)
|
||||
e_context['reply'] = reply
|
||||
e_context.action = EventAction.BREAK_PASS
|
||||
return
|
||||
|
||||
elif mj_type == TaskType.UPSCALE or mj_type == TaskType.VARIATION:
|
||||
# 图片放大/变换
|
||||
clist = cmd[1].split()
|
||||
if len(clist) < 2:
|
||||
self._set_reply_text(f"{cmd[0]} 命令缺少参数", e_context)
|
||||
return
|
||||
img_id = clist[0]
|
||||
index = int(clist[1])
|
||||
if index < 1 or index > 4:
|
||||
self._set_reply_text(f"图片序号 {index} 错误,应在 1 至 4 之间", e_context)
|
||||
return
|
||||
key = f"{str(mj_type)}_{img_id}_{index}"
|
||||
if self.temp_dict.get(key):
|
||||
self._set_reply_text(f"第 {index} 张图片已经{task_name_mapping.get(str(mj_type))}过了", e_context)
|
||||
return
|
||||
# 执行图片放大/变换操作
|
||||
reply = self.do_operate(mj_type, session_id, img_id, e_context, index)
|
||||
e_context['reply'] = reply
|
||||
e_context.action = EventAction.BREAK_PASS
|
||||
return
|
||||
|
||||
elif mj_type == TaskType.RESET:
|
||||
# 图片重新生成
|
||||
clist = cmd[1].split()
|
||||
if len(clist) < 1:
|
||||
self._set_reply_text(f"{cmd[0]} 命令缺少参数", e_context)
|
||||
return
|
||||
img_id = clist[0]
|
||||
# 图片重新生成
|
||||
reply = self.do_operate(mj_type, session_id, img_id, e_context)
|
||||
e_context['reply'] = reply
|
||||
e_context.action = EventAction.BREAK_PASS
|
||||
else:
|
||||
self._set_reply_text(f"暂不支持该命令", e_context)
|
||||
|
||||
def generate(self, prompt: str, user_id: str, e_context: EventContext) -> Reply:
|
||||
"""
|
||||
图片生成
|
||||
:param prompt: 提示词
|
||||
:param user_id: 用户id
|
||||
:param e_context: 对话上下文
|
||||
:return: 任务ID
|
||||
"""
|
||||
logger.info(f"[MJ] image generate, prompt={prompt}")
|
||||
mode = self._fetch_mode(prompt)
|
||||
body = {"prompt": prompt, "mode": mode, "auto_translate": self.config.get("auto_translate")}
|
||||
if not self.config.get("img_proxy"):
|
||||
body["img_proxy"] = False
|
||||
res = requests.post(url=self.base_url + "/generate", json=body, headers=self.headers, timeout=(5, 40))
|
||||
if res.status_code == 200:
|
||||
res = res.json()
|
||||
logger.debug(f"[MJ] image generate, res={res}")
|
||||
if res.get("code") == 200:
|
||||
task_id = res.get("data").get("task_id")
|
||||
real_prompt = res.get("data").get("real_prompt")
|
||||
if mode == TaskMode.RELAX.value:
|
||||
time_str = "1~10分钟"
|
||||
else:
|
||||
time_str = "1分钟"
|
||||
content = f"🚀您的作品将在{time_str}左右完成,请耐心等待\n- - - - - - - - -\n"
|
||||
if real_prompt:
|
||||
content += f"初始prompt: {prompt}\n转换后prompt: {real_prompt}"
|
||||
else:
|
||||
content += f"prompt: {prompt}"
|
||||
reply = Reply(ReplyType.INFO, content)
|
||||
task = MJTask(id=task_id, status=Status.PENDING, raw_prompt=prompt, user_id=user_id,
|
||||
task_type=TaskType.GENERATE)
|
||||
# put to memory dict
|
||||
self.tasks[task.id] = task
|
||||
# asyncio.run_coroutine_threadsafe(self.check_task(task, e_context), self.event_loop)
|
||||
self._do_check_task(task, e_context)
|
||||
return reply
|
||||
else:
|
||||
res_json = res.json()
|
||||
logger.error(f"[MJ] generate error, msg={res_json.get('message')}, status_code={res.status_code}")
|
||||
if res.status_code == INVALID_REQUEST:
|
||||
reply = Reply(ReplyType.ERROR, "图片生成失败,请检查提示词参数或内容")
|
||||
else:
|
||||
reply = Reply(ReplyType.ERROR, "图片生成失败,请稍后再试")
|
||||
return reply
|
||||
|
||||
def do_operate(self, task_type: TaskType, user_id: str, img_id: str, e_context: EventContext,
|
||||
index: int = None) -> Reply:
|
||||
logger.info(f"[MJ] image operate, task_type={task_type}, img_id={img_id}, index={index}")
|
||||
body = {"type": task_type.name, "img_id": img_id}
|
||||
if index:
|
||||
body["index"] = index
|
||||
if not self.config.get("img_proxy"):
|
||||
body["img_proxy"] = False
|
||||
res = requests.post(url=self.base_url + "/operate", json=body, headers=self.headers, timeout=(5, 40))
|
||||
logger.debug(res)
|
||||
if res.status_code == 200:
|
||||
res = res.json()
|
||||
if res.get("code") == 200:
|
||||
task_id = res.get("data").get("task_id")
|
||||
logger.info(f"[MJ] image operate processing, task_id={task_id}")
|
||||
icon_map = {TaskType.UPSCALE: "🔎", TaskType.VARIATION: "🪄", TaskType.RESET: "🔄"}
|
||||
content = f"{icon_map.get(task_type)}图片正在{task_name_mapping.get(task_type.name)}中,请耐心等待"
|
||||
reply = Reply(ReplyType.INFO, content)
|
||||
task = MJTask(id=task_id, status=Status.PENDING, user_id=user_id, task_type=task_type)
|
||||
# put to memory dict
|
||||
self.tasks[task.id] = task
|
||||
key = f"{task_type.name}_{img_id}_{index}"
|
||||
self.temp_dict[key] = True
|
||||
# asyncio.run_coroutine_threadsafe(self.check_task(task, e_context), self.event_loop)
|
||||
self._do_check_task(task, e_context)
|
||||
return reply
|
||||
else:
|
||||
error_msg = ""
|
||||
if res.status_code == NOT_FOUND_ORIGIN_IMAGE:
|
||||
error_msg = "请输入正确的图片ID"
|
||||
res_json = res.json()
|
||||
logger.error(f"[MJ] operate error, msg={res_json.get('message')}, status_code={res.status_code}")
|
||||
reply = Reply(ReplyType.ERROR, error_msg or "图片生成失败,请稍后再试")
|
||||
return reply
|
||||
|
||||
def check_task_sync(self, task: MJTask, e_context: EventContext):
|
||||
logger.debug(f"[MJ] start check task status, {task}")
|
||||
max_retry_times = 90
|
||||
while max_retry_times > 0:
|
||||
time.sleep(10)
|
||||
url = f"{self.base_url}/tasks/{task.id}"
|
||||
try:
|
||||
res = requests.get(url, headers=self.headers, timeout=8)
|
||||
if res.status_code == 200:
|
||||
res_json = res.json()
|
||||
logger.debug(f"[MJ] task check res sync, task_id={task.id}, status={res.status_code}, "
|
||||
f"data={res_json.get('data')}, thread={threading.current_thread().name}")
|
||||
if res_json.get("data") and res_json.get("data").get("status") == Status.FINISHED.name:
|
||||
# process success res
|
||||
if self.tasks.get(task.id):
|
||||
self.tasks[task.id].status = Status.FINISHED
|
||||
self._process_success_task(task, res_json.get("data"), e_context)
|
||||
return
|
||||
max_retry_times -= 1
|
||||
else:
|
||||
res_json = res.json()
|
||||
logger.warn(f"[MJ] image check error, status_code={res.status_code}, res={res_json}")
|
||||
max_retry_times -= 20
|
||||
except Exception as e:
|
||||
max_retry_times -= 20
|
||||
logger.warn(e)
|
||||
logger.warn("[MJ] end from poll")
|
||||
if self.tasks.get(task.id):
|
||||
self.tasks[task.id].status = Status.EXPIRED
|
||||
|
||||
def _do_check_task(self, task: MJTask, e_context: EventContext):
|
||||
threading.Thread(target=self.check_task_sync, args=(task, e_context)).start()
|
||||
|
||||
def _process_success_task(self, task: MJTask, res: dict, e_context: EventContext):
|
||||
"""
|
||||
处理任务成功的结果
|
||||
:param task: MJ任务
|
||||
:param res: 请求结果
|
||||
:param e_context: 对话上下文
|
||||
"""
|
||||
# channel send img
|
||||
task.status = Status.FINISHED
|
||||
task.img_id = res.get("img_id")
|
||||
task.img_url = res.get("img_url")
|
||||
logger.info(f"[MJ] task success, task_id={task.id}, img_id={task.img_id}, img_url={task.img_url}")
|
||||
|
||||
# send img
|
||||
reply = Reply(ReplyType.IMAGE_URL, task.img_url)
|
||||
channel = e_context["channel"]
|
||||
_send(channel, reply, e_context["context"])
|
||||
|
||||
# send info
|
||||
trigger_prefix = conf().get("plugin_trigger_prefix", "$")
|
||||
text = ""
|
||||
if task.task_type == TaskType.GENERATE or task.task_type == TaskType.VARIATION or task.task_type == TaskType.RESET:
|
||||
text = f"🎨绘画完成!\n"
|
||||
if task.raw_prompt:
|
||||
text += f"prompt: {task.raw_prompt}\n"
|
||||
text += f"- - - - - - - - -\n图片ID: {task.img_id}"
|
||||
text += f"\n\n🔎使用 {trigger_prefix}mju 命令放大图片\n"
|
||||
text += f"例如:\n{trigger_prefix}mju {task.img_id} 1"
|
||||
text += f"\n\n🪄使用 {trigger_prefix}mjv 命令变换图片\n"
|
||||
text += f"例如:\n{trigger_prefix}mjv {task.img_id} 1"
|
||||
text += f"\n\n🔄使用 {trigger_prefix}mjr 命令重新生成图片\n"
|
||||
text += f"例如:\n{trigger_prefix}mjr {task.img_id}"
|
||||
reply = Reply(ReplyType.INFO, text)
|
||||
_send(channel, reply, e_context["context"])
|
||||
|
||||
self._print_tasks()
|
||||
return
|
||||
|
||||
def _check_rate_limit(self, user_id: str, e_context: EventContext) -> bool:
|
||||
"""
|
||||
midjourney任务限流控制
|
||||
:param user_id: 用户id
|
||||
:param e_context: 对话上下文
|
||||
:return: 任务是否能够生成, True:可以生成, False: 被限流
|
||||
"""
|
||||
tasks = self.find_tasks_by_user_id(user_id)
|
||||
task_count = len([t for t in tasks if t.status == Status.PENDING])
|
||||
if task_count >= self.config.get("max_tasks_per_user"):
|
||||
reply = Reply(ReplyType.INFO, "您的Midjourney作图任务数已达上限,请稍后再试")
|
||||
e_context["reply"] = reply
|
||||
e_context.action = EventAction.BREAK_PASS
|
||||
return False
|
||||
task_count = len([t for t in self.tasks.values() if t.status == Status.PENDING])
|
||||
if task_count >= self.config.get("max_tasks"):
|
||||
reply = Reply(ReplyType.INFO, "Midjourney作图任务数已达上限,请稍后再试")
|
||||
e_context["reply"] = reply
|
||||
e_context.action = EventAction.BREAK_PASS
|
||||
return False
|
||||
return True
|
||||
|
||||
def _fetch_mode(self, prompt) -> str:
|
||||
mode = self.config.get("mode")
|
||||
if "--relax" in prompt or mode == TaskMode.RELAX.value:
|
||||
return TaskMode.RELAX.value
|
||||
return mode or TaskMode.FAST.value
|
||||
|
||||
def _run_loop(self, loop: asyncio.BaseEventLoop):
|
||||
"""
|
||||
运行事件循环,用于轮询任务的线程
|
||||
:param loop: 事件循环
|
||||
"""
|
||||
loop.run_forever()
|
||||
loop.stop()
|
||||
|
||||
def _print_tasks(self):
|
||||
for id in self.tasks:
|
||||
logger.debug(f"[MJ] current task: {self.tasks[id]}")
|
||||
|
||||
def _set_reply_text(self, content: str, e_context: EventContext, level: ReplyType = ReplyType.ERROR):
|
||||
"""
|
||||
设置回复文本
|
||||
:param content: 回复内容
|
||||
:param e_context: 对话上下文
|
||||
:param level: 回复等级
|
||||
"""
|
||||
reply = Reply(level, content)
|
||||
e_context["reply"] = reply
|
||||
e_context.action = EventAction.BREAK_PASS
|
||||
|
||||
def get_help_text(self, verbose=False, **kwargs):
|
||||
trigger_prefix = conf().get("plugin_trigger_prefix", "$")
|
||||
help_text = "🎨利用Midjourney进行画图\n\n"
|
||||
if not verbose:
|
||||
return help_text
|
||||
help_text += f" - 生成: {trigger_prefix}mj 描述词1, 描述词2.. \n - 放大: {trigger_prefix}mju 图片ID 图片序号\n - 变换: mjv 图片ID 图片序号\n - 重置: mjr 图片ID"
|
||||
help_text += f"\n\n例如:\n\"{trigger_prefix}mj a little cat, white --ar 9:16\"\n\"{trigger_prefix}mju 11055927171882 2\""
|
||||
help_text += f"\n\"{trigger_prefix}mjv 11055927171882 2\"\n\"{trigger_prefix}mjr 11055927171882\""
|
||||
return help_text
|
||||
|
||||
def find_tasks_by_user_id(self, user_id) -> list:
|
||||
result = []
|
||||
with self.tasks_lock:
|
||||
now = time.time()
|
||||
for task in self.tasks.values():
|
||||
if task.status == Status.PENDING and now > task.expiry_time:
|
||||
task.status = Status.EXPIRED
|
||||
logger.info(f"[MJ] {task} expired")
|
||||
if task.user_id == user_id:
|
||||
result.append(task)
|
||||
return result
|
||||
|
||||
|
||||
def _send(channel, reply: Reply, context, retry_cnt=0):
|
||||
try:
|
||||
channel.send(reply, context)
|
||||
except Exception as e:
|
||||
logger.error("[WX] sendMsg error: {}".format(str(e)))
|
||||
if isinstance(e, NotImplementedError):
|
||||
return
|
||||
logger.exception(e)
|
||||
if retry_cnt < 2:
|
||||
time.sleep(3 + 3 * retry_cnt)
|
||||
channel.send(reply, context, retry_cnt + 1)
|
||||
|
||||
|
||||
def check_prefix(content, prefix_list):
|
||||
if not prefix_list:
|
||||
return None
|
||||
for prefix in prefix_list:
|
||||
if content.startswith(prefix):
|
||||
return prefix
|
||||
return None
|
||||
94
plugins/linkai/summary.py
Normal file
94
plugins/linkai/summary.py
Normal file
@@ -0,0 +1,94 @@
|
||||
import requests
|
||||
from config import conf
|
||||
from common.log import logger
|
||||
import os
|
||||
|
||||
|
||||
class LinkSummary:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def summary_file(self, file_path: str):
|
||||
file_body = {
|
||||
"file": open(file_path, "rb"),
|
||||
"name": file_path.split("/")[-1],
|
||||
}
|
||||
url = self.base_url() + "/v1/summary/file"
|
||||
res = requests.post(url, headers=self.headers(), files=file_body, timeout=(5, 300))
|
||||
return self._parse_summary_res(res)
|
||||
|
||||
def summary_url(self, url: str):
|
||||
body = {
|
||||
"url": url
|
||||
}
|
||||
res = requests.post(url=self.base_url() + "/v1/summary/url", headers=self.headers(), json=body, timeout=(5, 180))
|
||||
return self._parse_summary_res(res)
|
||||
|
||||
def summary_chat(self, summary_id: str):
|
||||
body = {
|
||||
"summary_id": summary_id
|
||||
}
|
||||
res = requests.post(url=self.base_url() + "/v1/summary/chat", headers=self.headers(), json=body, timeout=(5, 180))
|
||||
if res.status_code == 200:
|
||||
res = res.json()
|
||||
logger.debug(f"[LinkSum] chat open, res={res}")
|
||||
if res.get("code") == 200:
|
||||
data = res.get("data")
|
||||
return {
|
||||
"questions": data.get("questions"),
|
||||
"file_id": data.get("file_id")
|
||||
}
|
||||
else:
|
||||
res_json = res.json()
|
||||
logger.error(f"[LinkSum] summary error, status_code={res.status_code}, msg={res_json.get('message')}")
|
||||
return None
|
||||
|
||||
def _parse_summary_res(self, res):
|
||||
if res.status_code == 200:
|
||||
res = res.json()
|
||||
logger.debug(f"[LinkSum] url summary, res={res}")
|
||||
if res.get("code") == 200:
|
||||
data = res.get("data")
|
||||
return {
|
||||
"summary": data.get("summary"),
|
||||
"summary_id": data.get("summary_id")
|
||||
}
|
||||
else:
|
||||
res_json = res.json()
|
||||
logger.error(f"[LinkSum] summary error, status_code={res.status_code}, msg={res_json.get('message')}")
|
||||
return None
|
||||
|
||||
def base_url(self):
|
||||
return conf().get("linkai_api_base", "https://api.link-ai.chat")
|
||||
|
||||
def headers(self):
|
||||
return {"Authorization": "Bearer " + conf().get("linkai_api_key")}
|
||||
|
||||
def check_file(self, file_path: str, sum_config: dict) -> bool:
|
||||
file_size = os.path.getsize(file_path) // 1000
|
||||
|
||||
if (sum_config.get("max_file_size") and file_size > sum_config.get("max_file_size")) or file_size > 15000:
|
||||
logger.warn(f"[LinkSum] file size exceeds limit, No processing, file_size={file_size}KB")
|
||||
return False
|
||||
|
||||
suffix = file_path.split(".")[-1]
|
||||
support_list = ["txt", "csv", "docx", "pdf", "md", "jpg", "jpeg", "png"]
|
||||
if suffix not in support_list:
|
||||
logger.warn(f"[LinkSum] unsupported file, suffix={suffix}, support_list={support_list}")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def check_url(self, url: str):
|
||||
if not url:
|
||||
return False
|
||||
support_list = ["http://mp.weixin.qq.com", "https://mp.weixin.qq.com"]
|
||||
black_support_list = ["https://mp.weixin.qq.com/mp/waerrpage"]
|
||||
for black_url_prefix in black_support_list:
|
||||
if url.strip().startswith(black_url_prefix):
|
||||
logger.warn(f"[LinkSum] unsupported url, no need to process, url={url}")
|
||||
return False
|
||||
for support_url in support_list:
|
||||
if url.strip().startswith(support_url):
|
||||
return True
|
||||
return False
|
||||
28
plugins/linkai/utils.py
Normal file
28
plugins/linkai/utils.py
Normal file
@@ -0,0 +1,28 @@
|
||||
from config import global_config
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from plugins.event import EventContext, EventAction
|
||||
|
||||
|
||||
class Util:
|
||||
@staticmethod
|
||||
def is_admin(e_context: EventContext) -> bool:
|
||||
"""
|
||||
判断消息是否由管理员用户发送
|
||||
:param e_context: 消息上下文
|
||||
:return: True: 是, False: 否
|
||||
"""
|
||||
context = e_context["context"]
|
||||
if context["isgroup"]:
|
||||
actual_user_id = context.kwargs.get("msg").actual_user_id
|
||||
for admin_user in global_config["admin_users"]:
|
||||
if actual_user_id and actual_user_id in admin_user:
|
||||
return True
|
||||
return False
|
||||
else:
|
||||
return context["receiver"] in global_config["admin_users"]
|
||||
|
||||
@staticmethod
|
||||
def set_reply_text(content: str, e_context: EventContext, level: ReplyType = ReplyType.ERROR):
|
||||
reply = Reply(level, content)
|
||||
e_context["reply"] = reply
|
||||
e_context.action = EventAction.BREAK_PASS
|
||||
@@ -1,6 +1,51 @@
|
||||
import os
|
||||
import json
|
||||
from config import pconf, plugin_config, conf
|
||||
from common.log import logger
|
||||
|
||||
|
||||
class Plugin:
|
||||
def __init__(self):
|
||||
self.handlers = {}
|
||||
|
||||
def load_config(self) -> dict:
|
||||
"""
|
||||
加载当前插件配置
|
||||
:return: 插件配置字典
|
||||
"""
|
||||
# 优先获取 plugins/config.json 中的全局配置
|
||||
plugin_conf = pconf(self.name)
|
||||
if not plugin_conf:
|
||||
# 全局配置不存在,则获取插件目录下的配置
|
||||
plugin_config_path = os.path.join(self.path, "config.json")
|
||||
if os.path.exists(plugin_config_path):
|
||||
with open(plugin_config_path, "r", encoding="utf-8") as f:
|
||||
plugin_conf = json.load(f)
|
||||
|
||||
# 写入全局配置内存
|
||||
plugin_config[self.name] = plugin_conf
|
||||
logger.debug(f"loading plugin config, plugin_name={self.name}, conf={plugin_conf}")
|
||||
return plugin_conf
|
||||
|
||||
def save_config(self, config: dict):
|
||||
try:
|
||||
plugin_config[self.name] = config
|
||||
# 写入全局配置
|
||||
global_config_path = "./plugins/config.json"
|
||||
if os.path.exists(global_config_path):
|
||||
with open(global_config_path, "w", encoding='utf-8') as f:
|
||||
json.dump(plugin_config, f, indent=4, ensure_ascii=False)
|
||||
# 写入插件配置
|
||||
plugin_config_path = os.path.join(self.path, "config.json")
|
||||
if os.path.exists(plugin_config_path):
|
||||
with open(plugin_config_path, "w", encoding='utf-8') as f:
|
||||
json.dump(config, f, indent=4, ensure_ascii=False)
|
||||
|
||||
except Exception as e:
|
||||
logger.warn("save plugin config failed: {}".format(e))
|
||||
|
||||
def get_help_text(self, **kwargs):
|
||||
return "暂无帮助信息"
|
||||
|
||||
def reload(self):
|
||||
pass
|
||||
|
||||
@@ -9,7 +9,7 @@ import sys
|
||||
from common.log import logger
|
||||
from common.singleton import singleton
|
||||
from common.sorted_dict import SortedDict
|
||||
from config import conf
|
||||
from config import conf, write_plugin_config
|
||||
|
||||
from .event import *
|
||||
|
||||
@@ -62,6 +62,28 @@ class PluginManager:
|
||||
self.save_config()
|
||||
return pconf
|
||||
|
||||
@staticmethod
|
||||
def _load_all_config():
|
||||
"""
|
||||
背景: 目前插件配置存放于每个插件目录的config.json下,docker运行时不方便进行映射,故增加统一管理的入口,优先
|
||||
加载 plugins/config.json,原插件目录下的config.json 不受影响
|
||||
|
||||
从 plugins/config.json 中加载所有插件的配置并写入 config.py 的全局配置中,供插件中使用
|
||||
插件实例中通过 config.pconf(plugin_name) 即可获取该插件的配置
|
||||
"""
|
||||
all_config_path = "./plugins/config.json"
|
||||
try:
|
||||
if os.path.exists(all_config_path):
|
||||
# read from all plugins config
|
||||
with open(all_config_path, "r", encoding="utf-8") as f:
|
||||
all_conf = json.load(f)
|
||||
logger.info(f"load all config from plugins/config.json: {all_conf}")
|
||||
|
||||
# write to global config
|
||||
write_plugin_config(all_conf)
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
|
||||
def scan_plugins(self):
|
||||
logger.info("Scaning plugins ...")
|
||||
plugins_dir = "./plugins"
|
||||
@@ -88,7 +110,7 @@ class PluginManager:
|
||||
self.loaded[plugin_path] = importlib.import_module(import_path)
|
||||
self.current_plugin_path = None
|
||||
except Exception as e:
|
||||
logger.exception("Failed to import plugin %s: %s" % (plugin_name, e))
|
||||
logger.warn("Failed to import plugin %s: %s" % (plugin_name, e))
|
||||
continue
|
||||
pconf = self.pconf
|
||||
news = [self.plugins[name] for name in self.plugins]
|
||||
@@ -123,7 +145,7 @@ class PluginManager:
|
||||
try:
|
||||
instance = plugincls()
|
||||
except Exception as e:
|
||||
logger.exception("Failed to init %s, diabled. %s" % (name, e))
|
||||
logger.warn("Failed to init %s, diabled. %s" % (name, e))
|
||||
self.disable_plugin(name)
|
||||
failed_plugins.append(name)
|
||||
continue
|
||||
@@ -149,6 +171,8 @@ class PluginManager:
|
||||
def load_plugins(self):
|
||||
self.load_config()
|
||||
self.scan_plugins()
|
||||
# 加载全量插件配置
|
||||
self._load_all_config()
|
||||
pconf = self.pconf
|
||||
logger.debug("plugins.json config={}".format(pconf))
|
||||
for name, plugin in pconf["plugins"].items():
|
||||
|
||||
@@ -11,6 +11,14 @@
|
||||
"summary": {
|
||||
"url": "https://github.com/lanvent/plugin_summary.git",
|
||||
"desc": "总结聊天记录的插件"
|
||||
},
|
||||
"timetask": {
|
||||
"url": "https://github.com/haikerapples/timetask.git",
|
||||
"desc": "一款定时任务系统的插件"
|
||||
},
|
||||
"Apilot": {
|
||||
"url": "https://github.com/6vision/Apilot.git",
|
||||
"desc": "通过api直接查询早报、热榜、快递、天气等实用信息的插件"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,11 +3,19 @@
|
||||
使用说明(默认trigger_prefix为$):
|
||||
```text
|
||||
#help tool: 查看tool帮助信息,可查看已加载工具列表
|
||||
$tool 命令: 根据给出的{命令}使用一些可用工具尽力为你得到结果。
|
||||
$tool 工具名 命令: (pure模式)根据给出的{命令}使用指定 一个 可用工具尽力为你得到结果。
|
||||
$tool 命令: (多工具模式)根据给出的{命令}使用 一些 可用工具尽力为你得到结果。
|
||||
$tool reset: 重置工具。
|
||||
```
|
||||
### 本插件所有工具同步存放至专用仓库:[chatgpt-tool-hub](https://github.com/goldfishh/chatgpt-tool-hub)
|
||||
|
||||
2024.01.16更新
|
||||
1. 新增工具pure模式,支持单个工具调用
|
||||
2. 新增消息转发工具:email, sms, wechat, 可以根据规则向其他平台发送消息
|
||||
3. 替换visual-dl(更名为visual)实现,目前识别图片链接效果较好。
|
||||
4. 修复了0.4版本大部分工具返回结果不可靠问题
|
||||
|
||||
新版本工具名共19个,不一一列举,相应工具需要的环境参数见`tool.py`里的`_build_tool_kwargs`函数
|
||||
|
||||
## 使用说明
|
||||
使用该插件后将默认使用4个工具, 无需额外配置长期生效:
|
||||
@@ -24,7 +32,7 @@ $tool reset: 重置工具。
|
||||
|
||||
> 注1:url-get默认配置、browser需额外配置,browser依赖google-chrome,你需要提前安装好
|
||||
|
||||
> 注2:当检测到长文本时会进入summary tool总结长文本,tokens可能会大量消耗!
|
||||
> 注2:(可通过`browser_use_summary`或 `url_get_use_summary`开关)当检测到长文本时会进入summary tool总结长文本,tokens可能会大量消耗!
|
||||
|
||||
这是debian端安装google-chrome教程,其他系统请自行查找
|
||||
> https://www.linuxjournal.com/content/how-can-you-install-google-browser-debian
|
||||
@@ -34,9 +42,10 @@ $tool reset: 重置工具。
|
||||
|
||||
> terminal调优记录:https://github.com/zhayujie/chatgpt-on-wechat/issues/776#issue-1659347640
|
||||
|
||||
### 4. meteo-weather
|
||||
### 4. meteo
|
||||
###### 回答你有关天气的询问, 需要获取时间、地点上下文信息,本工具使用了[meteo open api](https://open-meteo.com/)
|
||||
注:该工具需要较高的对话技巧,不保证你问的任何问题均能得到满意的回复
|
||||
注2:当前版本可只使用这个工具,返回结果较可控。
|
||||
|
||||
> meteo调优记录:https://github.com/zhayujie/chatgpt-on-wechat/issues/776#issuecomment-1500771334
|
||||
|
||||
@@ -65,18 +74,12 @@ $tool reset: 重置工具。
|
||||
#### 6.2. morning-news *
|
||||
###### 每日60秒早报,每天凌晨一点更新,本工具使用了[alapi-每日60秒早报](https://alapi.cn/api/view/93)
|
||||
|
||||
```text
|
||||
可配置参数:
|
||||
1. morning_news_use_llm: 是否使用LLM润色结果,默认false(可能会慢)
|
||||
```
|
||||
|
||||
> 该tool每天返回内容相同
|
||||
|
||||
#### 6.3. finance-news
|
||||
###### 获取实时的金融财政新闻
|
||||
|
||||
> 该工具需要解决browser tool 的google-chrome依赖安装
|
||||
|
||||
> 该工具需要用到browser工具解决反爬问题
|
||||
|
||||
|
||||
### 7. bing-search *
|
||||
@@ -99,33 +102,49 @@ $tool reset: 重置工具。
|
||||
> 0.4.2更新,例子:帮我找一篇吴恩达写的论文
|
||||
|
||||
### 11. summary
|
||||
###### 总结工具,该工具必须输入一个本地文件的绝对路径
|
||||
###### 总结工具,该工具可以支持输入url
|
||||
|
||||
> 该工具目前是和其他工具配合使用,暂未测试单独使用效果
|
||||
|
||||
### 12. image2text
|
||||
###### 将图片转换成文字,底层调用imageCaption模型,该工具必须输入一个本地文件的绝对路径
|
||||
### 12. visual
|
||||
###### 将图片转换成文字,底层调用ali dashscope `qwen-vl-plus`模型
|
||||
|
||||
### 13. searxng-search *
|
||||
###### 一个私有化的搜索引擎工具
|
||||
|
||||
> 安装教程:https://docs.searxng.org/admin/installation.html
|
||||
|
||||
### 14. email *
|
||||
###### 发送邮件
|
||||
|
||||
### 15. sms *
|
||||
###### 发送短信
|
||||
|
||||
### 16. stt *
|
||||
###### speak to text 语音识别
|
||||
|
||||
### 17. tts *
|
||||
###### text to speak 文生语音
|
||||
|
||||
### 18. wechat *
|
||||
###### 向好友、群组发送微信
|
||||
|
||||
---
|
||||
|
||||
###### 注1:带*工具需要获取api-key才能使用(在config.json内的kwargs添加项),部分工具需要外网支持
|
||||
#### [申请方法](https://github.com/goldfishh/chatgpt-tool-hub/blob/master/docs/apply_optional_tool.md)
|
||||
## [工具的api申请方法](https://github.com/goldfishh/chatgpt-tool-hub/blob/master/docs/apply_optional_tool.md)
|
||||
|
||||
## config.json 配置说明
|
||||
###### 默认工具无需配置,其它工具需手动配置,一个例子:
|
||||
###### 默认工具无需配置,其它工具需手动配置,以增加morning-news和bing-search两个工具为例:
|
||||
```json
|
||||
{
|
||||
"tools": ["wikipedia", "你想要添加的其他工具"], // 填入你想用到的额外工具名
|
||||
"tools": ["bing-search", "morning-news", "你想要添加的其他工具"], // 填入你想用到的额外工具名,这里加入了工具"bing-search"和工具"morning-news"
|
||||
"kwargs": {
|
||||
"debug": true, // 当你遇到问题求助时,需要配置
|
||||
"request_timeout": 120, // openai接口超时时间
|
||||
"no_default": false, // 是否不使用默认的4个工具
|
||||
// 带*工具需要申请api-key,在这里填入,api_name参考前述`申请方法`
|
||||
"bing_subscription_key": "4871f273a4804743",//带*工具需要申请api-key,这里填入了工具bing-search对应的api,api_name参考前述`工具的api申请方法`
|
||||
"morning_news_api_key": "5w1kjNh9VQlUc",// 这里填入了morning-news对应的api,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -136,7 +155,6 @@ $tool reset: 重置工具。
|
||||
- `debug`: 输出chatgpt-tool-hub额外信息用于调试
|
||||
- `request_timeout`: 访问openai接口的超时时间,默认与wechat-on-chatgpt配置一致,可单独配置
|
||||
- `no_default`: 用于配置默认加载4个工具的行为,如果为true则仅使用tools列表工具,不加载默认工具
|
||||
- `top_k_results`: 控制所有有关搜索的工具返回条目数,数字越高则参考信息越多,但无用信息可能干扰判断,该值一般为2
|
||||
- `model_name`: 用于控制tool插件底层使用的llm模型,目前暂未测试3.5以外的模型,一般保持默认
|
||||
|
||||
---
|
||||
|
||||
@@ -3,10 +3,10 @@
|
||||
"python",
|
||||
"url-get",
|
||||
"terminal",
|
||||
"meteo-weather"
|
||||
"meteo"
|
||||
],
|
||||
"kwargs": {
|
||||
"top_k_results": 2,
|
||||
"debug": false,
|
||||
"no_default": false,
|
||||
"model_name": "gpt-3.5-turbo"
|
||||
}
|
||||
|
||||
@@ -1,24 +1,20 @@
|
||||
import json
|
||||
import os
|
||||
|
||||
from chatgpt_tool_hub.apps import AppFactory
|
||||
from chatgpt_tool_hub.apps.app import App
|
||||
from chatgpt_tool_hub.tools.all_tool_list import get_all_tool_names
|
||||
from chatgpt_tool_hub.tools.tool_register import main_tool_register
|
||||
|
||||
import plugins
|
||||
from bridge.bridge import Bridge
|
||||
from bridge.context import ContextType
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from common import const
|
||||
from common.log import logger
|
||||
from config import conf
|
||||
from config import conf, get_appdata_dir
|
||||
from plugins import *
|
||||
|
||||
|
||||
@plugins.register(
|
||||
name="tool",
|
||||
desc="Arming your ChatGPT bot with various tools",
|
||||
version="0.4",
|
||||
version="0.5",
|
||||
author="goldfishh",
|
||||
desire_priority=0,
|
||||
)
|
||||
@@ -37,10 +33,12 @@ class Tool(Plugin):
|
||||
if not verbose:
|
||||
return help_text
|
||||
help_text += "\n使用说明:\n"
|
||||
help_text += f"{trigger_prefix}tool " + "命令: 根据给出的{命令}使用一些可用工具尽力为你得到结果。\n"
|
||||
help_text += f"{trigger_prefix}tool " + "命令: 根据给出的{命令}模型来选择使用哪些工具尽力为你得到结果。\n"
|
||||
help_text += f"{trigger_prefix}tool 工具名 " + "命令: 根据给出的{命令}使用指定工具尽力为你得到结果。\n"
|
||||
help_text += f"{trigger_prefix}tool reset: 重置工具。\n\n"
|
||||
|
||||
help_text += f"已加载工具列表: \n"
|
||||
for idx, tool in enumerate(self.app.get_tool_list()):
|
||||
for idx, tool in enumerate(main_tool_register.get_registered_tool_names()):
|
||||
if idx != 0:
|
||||
help_text += ", "
|
||||
help_text += f"{tool}"
|
||||
@@ -92,17 +90,28 @@ class Tool(Plugin):
|
||||
|
||||
e_context.action = EventAction.BREAK
|
||||
return
|
||||
|
||||
query = content_list[1].strip()
|
||||
|
||||
use_one_tool = False
|
||||
for tool_name in main_tool_register.get_registered_tool_names():
|
||||
if query.startswith(tool_name):
|
||||
use_one_tool = True
|
||||
query = query[len(tool_name):]
|
||||
break
|
||||
|
||||
# Don't modify bot name
|
||||
all_sessions = Bridge().get_bot("chat").sessions
|
||||
user_session = all_sessions.session_query(query, e_context["context"]["session_id"]).messages
|
||||
|
||||
# chatgpt-tool-hub will reply you with many tools
|
||||
logger.debug("[tool]: just-go")
|
||||
try:
|
||||
_reply = self.app.ask(query, user_session)
|
||||
if use_one_tool:
|
||||
_func, _ = main_tool_register.get_registered_tool()[tool_name]
|
||||
tool = _func(**self.app_kwargs)
|
||||
_reply = tool.run(query)
|
||||
else:
|
||||
# chatgpt-tool-hub will reply you with many tools
|
||||
_reply = self.app.ask(query, user_session)
|
||||
e_context.action = EventAction.BREAK_PASS
|
||||
all_sessions.session_reply(_reply, e_context["context"]["session_id"])
|
||||
except Exception as e:
|
||||
@@ -119,67 +128,119 @@ class Tool(Plugin):
|
||||
return
|
||||
|
||||
def _read_json(self) -> dict:
|
||||
curdir = os.path.dirname(__file__)
|
||||
config_path = os.path.join(curdir, "config.json")
|
||||
tool_config = {"tools": [], "kwargs": {}}
|
||||
if not os.path.exists(config_path):
|
||||
return tool_config
|
||||
else:
|
||||
with open(config_path, "r") as f:
|
||||
tool_config = json.load(f)
|
||||
return tool_config
|
||||
default_config = {"tools": [], "kwargs": {}}
|
||||
return super().load_config() or default_config
|
||||
|
||||
def _build_tool_kwargs(self, kwargs: dict):
|
||||
tool_model_name = kwargs.get("model_name")
|
||||
request_timeout = kwargs.get("request_timeout")
|
||||
|
||||
return {
|
||||
"debug": kwargs.get("debug", False),
|
||||
"openai_api_key": conf().get("open_ai_api_key", ""),
|
||||
"open_ai_api_base": conf().get("open_ai_api_base", "https://api.openai.com/v1"),
|
||||
"proxy": conf().get("proxy", ""),
|
||||
# 全局配置相关
|
||||
"log": True, # tool 日志开关
|
||||
"debug": kwargs.get("debug", False), # 输出更多日志
|
||||
"no_default": kwargs.get("no_default", False), # 不要默认的工具,只加载自己导入的工具
|
||||
"think_depth": kwargs.get("think_depth", 2), # 一个问题最多使用多少次工具
|
||||
"proxy": conf().get("proxy", ""), # 科学上网
|
||||
"request_timeout": request_timeout if request_timeout else conf().get("request_timeout", 120),
|
||||
"temperature": kwargs.get("temperature", 0), # llm 温度,建议设置0
|
||||
# LLM配置相关
|
||||
"llm_api_key": conf().get("open_ai_api_key", ""), # 如果llm api用key鉴权,传入这里
|
||||
"llm_api_base_url": conf().get("open_ai_api_base", "https://api.openai.com/v1"), # 支持openai接口的llm服务地址前缀
|
||||
"deployment_id": conf().get("azure_deployment_id", ""), # azure openai会用到
|
||||
# note: 目前tool暂未对其他模型测试,但这里仍对配置来源做了优先级区分,一般插件配置可覆盖全局配置
|
||||
"model_name": tool_model_name if tool_model_name else conf().get("model", "gpt-3.5-turbo"),
|
||||
"no_default": kwargs.get("no_default", False),
|
||||
"top_k_results": kwargs.get("top_k_results", 3),
|
||||
# for news tool
|
||||
"news_api_key": kwargs.get("news_api_key", ""),
|
||||
"model_name": tool_model_name if tool_model_name else conf().get("model", const.GPT35),
|
||||
# 工具配置相关
|
||||
# for arxiv tool
|
||||
"arxiv_simple": kwargs.get("arxiv_simple", True), # 返回内容更精简
|
||||
"arxiv_top_k_results": kwargs.get("arxiv_top_k_results", 2), # 只返回前k个搜索结果
|
||||
"arxiv_sort_by": kwargs.get("arxiv_sort_by", "relevance"), # 搜索排序方式 ["relevance","lastUpdatedDate","submittedDate"]
|
||||
"arxiv_sort_order": kwargs.get("arxiv_sort_order", "descending"), # 搜索排序方式 ["ascending", "descending"]
|
||||
"arxiv_output_type": kwargs.get("arxiv_output_type", "text"), # 搜索结果类型 ["text", "pdf", "all"]
|
||||
# for bing-search tool
|
||||
"bing_subscription_key": kwargs.get("bing_subscription_key", ""),
|
||||
"bing_search_url": kwargs.get("bing_search_url", "https://api.bing.microsoft.com/v7.0/search"), # 必应搜索的endpoint地址,无需修改
|
||||
"bing_search_top_k_results": kwargs.get("bing_search_top_k_results", 2), # 只返回前k个搜索结果
|
||||
"bing_search_simple": kwargs.get("bing_search_simple", True), # 返回内容更精简
|
||||
"bing_search_output_type": kwargs.get("bing_search_output_type", "text"), # 搜索结果类型 ["text", "json"]
|
||||
# for email tool
|
||||
"email_nickname_mapping": kwargs.get("email_nickname_mapping", "{}"), # 关于人的代号对应的邮箱地址,可以不输入邮箱地址发送邮件。键为代号值为邮箱地址
|
||||
"email_smtp_host": kwargs.get("email_smtp_host", ""), # 例如 'smtp.qq.com'
|
||||
"email_smtp_port": kwargs.get("email_smtp_port", ""), # 例如 587
|
||||
"email_sender": kwargs.get("email_sender", ""), # 发送者的邮件地址
|
||||
"email_authorization_code": kwargs.get("email_authorization_code", ""), # 发送者验证秘钥(可能不是登录密码)
|
||||
# for google-search tool
|
||||
"google_api_key": kwargs.get("google_api_key", ""),
|
||||
"google_cse_id": kwargs.get("google_cse_id", ""),
|
||||
"google_simple": kwargs.get("google_simple", True), # 返回内容更精简
|
||||
"google_output_type": kwargs.get("google_output_type", "text"), # 搜索结果类型 ["text", "json"]
|
||||
# for finance-news tool
|
||||
"finance_news_filter": kwargs.get("finance_news_filter", False), # 是否开启过滤
|
||||
"finance_news_filter_list": kwargs.get("finance_news_filter_list", []), # 过滤词列表
|
||||
"finance_news_simple": kwargs.get("finance_news_simple", True), # 返回内容更精简
|
||||
"finance_news_repeat_news": kwargs.get("finance_news_repeat_news", False), # 是否过滤不返回。该tool每次返回约50条新闻,可能有重复新闻
|
||||
# for morning-news tool
|
||||
"morning_news_api_key": kwargs.get("morning_news_api_key", ""), # api-key
|
||||
"morning_news_simple": kwargs.get("morning_news_simple", True), # 返回内容更精简
|
||||
"morning_news_output_type": kwargs.get("morning_news_output_type", "text"), # 搜索结果类型 ["text", "image"]
|
||||
# for news-api tool
|
||||
"news_api_key": kwargs.get("news_api_key", ""),
|
||||
# for searxng-search tool
|
||||
"searx_search_host": kwargs.get("searx_search_host", ""),
|
||||
"searxng_search_host": kwargs.get("searxng_search_host", ""),
|
||||
"searxng_search_top_k_results": kwargs.get("searxng_search_top_k_results", 2), # 只返回前k个搜索结果
|
||||
"searxng_search_output_type": kwargs.get("searxng_search_output_type", "text"), # 搜索结果类型 ["text", "json"]
|
||||
# for sms tool
|
||||
"sms_nickname_mapping": kwargs.get("sms_nickname_mapping", "{}"), # 关于人的代号对应的手机号,可以不输入手机号发送sms。键为代号值为手机号
|
||||
"sms_username": kwargs.get("sms_username", ""), # smsbao用户名
|
||||
"sms_apikey": kwargs.get("sms_apikey", ""), # smsbao
|
||||
# for stt tool
|
||||
"stt_api_key": kwargs.get("stt_api_key", ""), # azure
|
||||
"stt_api_region": kwargs.get("stt_api_region", ""), # azure
|
||||
"stt_recognition_language": kwargs.get("stt_recognition_language", "zh-CN"), # 识别的语言类型 部分:en-US ja-JP ko-KR yue-CN zh-CN
|
||||
# for tts tool
|
||||
"tts_api_key": kwargs.get("tts_api_key", ""), # azure
|
||||
"tts_api_region": kwargs.get("tts_api_region", ""), # azure
|
||||
"tts_auto_detect": kwargs.get("tts_auto_detect", True), # 是否自动检测语音的语言
|
||||
"tts_speech_id": kwargs.get("tts_speech_id", "zh-CN-XiaozhenNeural"), # 输出语音ID
|
||||
# for summary tool
|
||||
"summary_max_segment_length": kwargs.get("summary_max_segment_length", 2500), # 每2500tokens分段,多段触发总结tool
|
||||
# for terminal tool
|
||||
"terminal_nsfc_filter": kwargs.get("terminal_nsfc_filter", True), # 是否过滤llm输出的危险命令
|
||||
"terminal_return_err_output": kwargs.get("terminal_return_err_output", True), # 是否输出错误信息
|
||||
"terminal_timeout": kwargs.get("terminal_timeout", 20), # 允许命令最长执行时间
|
||||
# for visual tool
|
||||
"caption_api_key": kwargs.get("caption_api_key", ""), # ali dashscope apikey
|
||||
# for browser tool
|
||||
"browser_use_summary": kwargs.get("browser_use_summary", True), # 是否对返回结果使用tool功能
|
||||
# for url-get tool
|
||||
"url_get_use_summary": kwargs.get("url_get_use_summary", True), # 是否对返回结果使用tool功能
|
||||
# for wechat tool
|
||||
"wechat_hot_reload": kwargs.get("wechat_hot_reload", True), # 是否使用热重载的方式发送wechat
|
||||
"wechat_cpt_path": kwargs.get("wechat_cpt_path", os.path.join(get_appdata_dir(), "itchat.pkl")), # wechat 配置文件(`itchat.pkl`)
|
||||
"wechat_send_group": kwargs.get("wechat_send_group", False), # 是否向群组发送消息
|
||||
"wechat_nickname_mapping": kwargs.get("wechat_nickname_mapping", "{}"), # 关于人的代号映射关系。键为代号值为微信名(昵称、备注名均可)
|
||||
# for wikipedia tool
|
||||
"wikipedia_top_k_results": kwargs.get("wikipedia_top_k_results", 2), # 只返回前k个搜索结果
|
||||
# for wolfram-alpha tool
|
||||
"wolfram_alpha_appid": kwargs.get("wolfram_alpha_appid", ""),
|
||||
# for morning-news tool
|
||||
"morning_news_api_key": kwargs.get("morning_news_api_key", ""),
|
||||
# for visual_dl tool
|
||||
"cuda_device": kwargs.get("cuda_device", "cpu"),
|
||||
"think_depth": kwargs.get("think_depth", 3),
|
||||
"arxiv_summary": kwargs.get("arxiv_summary", True),
|
||||
"morning_news_use_llm": kwargs.get("morning_news_use_llm", False),
|
||||
}
|
||||
|
||||
def _filter_tool_list(self, tool_list: list):
|
||||
valid_list = []
|
||||
for tool in tool_list:
|
||||
if tool in get_all_tool_names():
|
||||
if tool in main_tool_register.get_registered_tool_names():
|
||||
valid_list.append(tool)
|
||||
else:
|
||||
logger.warning("[tool] filter invalid tool: " + repr(tool))
|
||||
return valid_list
|
||||
|
||||
def _reset_app(self) -> App:
|
||||
tool_config = self._read_json()
|
||||
app_kwargs = self._build_tool_kwargs(tool_config.get("kwargs", {}))
|
||||
self.tool_config = self._read_json()
|
||||
self.app_kwargs = self._build_tool_kwargs(self.tool_config.get("kwargs", {}))
|
||||
|
||||
app = AppFactory()
|
||||
app.init_env(**app_kwargs)
|
||||
|
||||
app.init_env(**self.app_kwargs)
|
||||
# filter not support tool
|
||||
tool_list = self._filter_tool_list(tool_config.get("tools", []))
|
||||
tool_list = self._filter_tool_list(self.tool_config.get("tools", []))
|
||||
|
||||
return app.create_app(tools_list=tool_list, **app_kwargs)
|
||||
return app.create_app(tools_list=tool_list, **self.app_kwargs)
|
||||
|
||||
@@ -13,16 +13,30 @@ langid # language detect
|
||||
#install plugin
|
||||
dulwich
|
||||
|
||||
# wechaty
|
||||
wechaty>=0.10.7
|
||||
wechaty_puppet>=0.4.23
|
||||
pysilk_mod>=1.6.0 # needed by send voice
|
||||
|
||||
# wechatmp wechatcom
|
||||
# wechatmp && wechatcom
|
||||
web.py
|
||||
wechatpy
|
||||
|
||||
# chatgpt-tool-hub plugin
|
||||
chatgpt_tool_hub==0.5.0
|
||||
|
||||
--extra-index-url https://pypi.python.org/simple
|
||||
chatgpt_tool_hub==0.4.4
|
||||
# xunfei spark
|
||||
websocket-client==1.2.0
|
||||
|
||||
# claude bot
|
||||
curl_cffi
|
||||
|
||||
# tongyi qwen
|
||||
broadscope_bailian
|
||||
|
||||
# google
|
||||
google-generativeai
|
||||
|
||||
# linkai
|
||||
linkai>=0.0.3.5
|
||||
|
||||
# dingtalk
|
||||
dingtalk_stream
|
||||
|
||||
# zhipuai
|
||||
zhipuai>=2.0.1
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
openai==0.27.2
|
||||
openai==0.27.8
|
||||
HTMLParser>=0.0.2
|
||||
PyQRCode>=1.2.1
|
||||
qrcode>=7.4.2
|
||||
requests>=2.28.2
|
||||
chardet>=5.1.0
|
||||
Pillow
|
||||
pre-commit
|
||||
pre-commit
|
||||
web.py
|
||||
|
||||
152
voice/ali/ali_api.py
Normal file
152
voice/ali/ali_api.py
Normal file
@@ -0,0 +1,152 @@
|
||||
# coding=utf-8
|
||||
"""
|
||||
Author: chazzjimel
|
||||
Email: chazzjimel@gmail.com
|
||||
wechat:cheung-z-x
|
||||
|
||||
Description:
|
||||
|
||||
"""
|
||||
|
||||
import json
|
||||
import time
|
||||
import requests
|
||||
import datetime
|
||||
import hashlib
|
||||
import hmac
|
||||
import base64
|
||||
import urllib.parse
|
||||
import uuid
|
||||
|
||||
from common.log import logger
|
||||
from common.tmp_dir import TmpDir
|
||||
|
||||
|
||||
def text_to_speech_aliyun(url, text, appkey, token):
|
||||
"""
|
||||
使用阿里云的文本转语音服务将文本转换为语音。
|
||||
|
||||
参数:
|
||||
- url (str): 阿里云文本转语音服务的端点URL。
|
||||
- text (str): 要转换为语音的文本。
|
||||
- appkey (str): 您的阿里云appkey。
|
||||
- token (str): 阿里云API的认证令牌。
|
||||
|
||||
返回值:
|
||||
- str: 成功时输出音频文件的路径,否则为None。
|
||||
"""
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
data = {
|
||||
"text": text,
|
||||
"appkey": appkey,
|
||||
"token": token,
|
||||
"format": "wav"
|
||||
}
|
||||
|
||||
response = requests.post(url, headers=headers, data=json.dumps(data))
|
||||
|
||||
if response.status_code == 200 and response.headers['Content-Type'] == 'audio/mpeg':
|
||||
output_file = TmpDir().path() + "reply-" + str(int(time.time())) + "-" + str(hash(text) & 0x7FFFFFFF) + ".wav"
|
||||
|
||||
with open(output_file, 'wb') as file:
|
||||
file.write(response.content)
|
||||
logger.debug(f"音频文件保存成功,文件名:{output_file}")
|
||||
else:
|
||||
logger.debug("响应状态码: {}".format(response.status_code))
|
||||
logger.debug("响应内容: {}".format(response.text))
|
||||
output_file = None
|
||||
|
||||
return output_file
|
||||
|
||||
|
||||
class AliyunTokenGenerator:
|
||||
"""
|
||||
用于生成阿里云服务认证令牌的类。
|
||||
|
||||
属性:
|
||||
- access_key_id (str): 您的阿里云访问密钥ID。
|
||||
- access_key_secret (str): 您的阿里云访问密钥秘密。
|
||||
"""
|
||||
|
||||
def __init__(self, access_key_id, access_key_secret):
|
||||
self.access_key_id = access_key_id
|
||||
self.access_key_secret = access_key_secret
|
||||
|
||||
def sign_request(self, parameters):
|
||||
"""
|
||||
为阿里云服务签名请求。
|
||||
|
||||
参数:
|
||||
- parameters (dict): 请求的参数字典。
|
||||
|
||||
返回值:
|
||||
- str: 请求的签名签章。
|
||||
"""
|
||||
# 将参数按照字典顺序排序
|
||||
sorted_params = sorted(parameters.items())
|
||||
|
||||
# 构造待签名的查询字符串
|
||||
canonicalized_query_string = ''
|
||||
for (k, v) in sorted_params:
|
||||
canonicalized_query_string += '&' + self.percent_encode(k) + '=' + self.percent_encode(v)
|
||||
|
||||
# 构造用于签名的字符串
|
||||
string_to_sign = 'GET&%2F&' + self.percent_encode(canonicalized_query_string[1:]) # 使用GET方法
|
||||
|
||||
# 使用HMAC算法计算签名
|
||||
h = hmac.new((self.access_key_secret + "&").encode('utf-8'), string_to_sign.encode('utf-8'), hashlib.sha1)
|
||||
signature = base64.encodebytes(h.digest()).strip()
|
||||
|
||||
return signature
|
||||
|
||||
def percent_encode(self, encode_str):
|
||||
"""
|
||||
对字符串进行百分比编码。
|
||||
|
||||
参数:
|
||||
- encode_str (str): 要编码的字符串。
|
||||
|
||||
返回值:
|
||||
- str: 编码后的字符串。
|
||||
"""
|
||||
encode_str = str(encode_str)
|
||||
res = urllib.parse.quote(encode_str, '')
|
||||
res = res.replace('+', '%20')
|
||||
res = res.replace('*', '%2A')
|
||||
res = res.replace('%7E', '~')
|
||||
return res
|
||||
|
||||
def get_token(self):
|
||||
"""
|
||||
获取阿里云服务的令牌。
|
||||
|
||||
返回值:
|
||||
- str: 获取到的令牌。
|
||||
"""
|
||||
# 设置请求参数
|
||||
params = {
|
||||
'Format': 'JSON',
|
||||
'Version': '2019-02-28',
|
||||
'AccessKeyId': self.access_key_id,
|
||||
'SignatureMethod': 'HMAC-SHA1',
|
||||
'Timestamp': datetime.datetime.utcnow().strftime("%Y-%m-%dT%H:%M:%SZ"),
|
||||
'SignatureVersion': '1.0',
|
||||
'SignatureNonce': str(uuid.uuid4()), # 使用uuid生成唯一的随机数
|
||||
'Action': 'CreateToken',
|
||||
'RegionId': 'cn-shanghai'
|
||||
}
|
||||
|
||||
# 计算签名
|
||||
signature = self.sign_request(params)
|
||||
params['Signature'] = signature
|
||||
|
||||
# 构造请求URL
|
||||
url = 'http://nls-meta.cn-shanghai.aliyuncs.com/?' + urllib.parse.urlencode(params)
|
||||
|
||||
# 发送请求
|
||||
response = requests.get(url)
|
||||
|
||||
return response.text
|
||||
81
voice/ali/ali_voice.py
Normal file
81
voice/ali/ali_voice.py
Normal file
@@ -0,0 +1,81 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Author: chazzjimel
|
||||
Email: chazzjimel@gmail.com
|
||||
wechat:cheung-z-x
|
||||
|
||||
Description:
|
||||
ali voice service
|
||||
|
||||
"""
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from common.log import logger
|
||||
from voice.voice import Voice
|
||||
from voice.ali.ali_api import AliyunTokenGenerator
|
||||
from voice.ali.ali_api import text_to_speech_aliyun
|
||||
from config import conf
|
||||
|
||||
|
||||
class AliVoice(Voice):
|
||||
def __init__(self):
|
||||
"""
|
||||
初始化AliVoice类,从配置文件加载必要的配置。
|
||||
"""
|
||||
try:
|
||||
curdir = os.path.dirname(__file__)
|
||||
config_path = os.path.join(curdir, "config.json")
|
||||
with open(config_path, "r") as fr:
|
||||
config = json.load(fr)
|
||||
self.token = None
|
||||
self.token_expire_time = 0
|
||||
# 默认复用阿里云千问的 access_key 和 access_secret
|
||||
self.api_url = config.get("api_url")
|
||||
self.app_key = config.get("app_key")
|
||||
self.access_key_id = conf().get("qwen_access_key_id") or config.get("access_key_id")
|
||||
self.access_key_secret = conf().get("qwen_access_key_secret") or config.get("access_key_secret")
|
||||
except Exception as e:
|
||||
logger.warn("AliVoice init failed: %s, ignore " % e)
|
||||
|
||||
def textToVoice(self, text):
|
||||
"""
|
||||
将文本转换为语音文件。
|
||||
|
||||
:param text: 要转换的文本。
|
||||
:return: 返回一个Reply对象,其中包含转换得到的语音文件或错误信息。
|
||||
"""
|
||||
# 清除文本中的非中文、非英文和非基本字符
|
||||
text = re.sub(r'[^\u4e00-\u9fa5\u3040-\u30FF\uAC00-\uD7AFa-zA-Z0-9'
|
||||
r'äöüÄÖÜáéíóúÁÉÍÓÚàèìòùÀÈÌÒÙâêîôûÂÊÎÔÛçÇñÑ,。!?,.]', '', text)
|
||||
# 提取有效的token
|
||||
token_id = self.get_valid_token()
|
||||
fileName = text_to_speech_aliyun(self.api_url, text, self.app_key, token_id)
|
||||
if fileName:
|
||||
logger.info("[Ali] textToVoice text={} voice file name={}".format(text, fileName))
|
||||
reply = Reply(ReplyType.VOICE, fileName)
|
||||
else:
|
||||
reply = Reply(ReplyType.ERROR, "抱歉,语音合成失败")
|
||||
return reply
|
||||
|
||||
def get_valid_token(self):
|
||||
"""
|
||||
获取有效的阿里云token。
|
||||
|
||||
:return: 返回有效的token字符串。
|
||||
"""
|
||||
current_time = time.time()
|
||||
if self.token is None or current_time >= self.token_expire_time:
|
||||
get_token = AliyunTokenGenerator(self.access_key_id, self.access_key_secret)
|
||||
token_str = get_token.get_token()
|
||||
token_data = json.loads(token_str)
|
||||
self.token = token_data["Token"]["Id"]
|
||||
# 将过期时间减少一小段时间(例如5分钟),以避免在边界条件下的过期
|
||||
self.token_expire_time = token_data["Token"]["ExpireTime"] - 300
|
||||
logger.debug(f"新获取的阿里云token:{self.token}")
|
||||
else:
|
||||
logger.debug("使用缓存的token")
|
||||
return self.token
|
||||
6
voice/ali/config.json.template
Normal file
6
voice/ali/config.json.template
Normal file
@@ -0,0 +1,6 @@
|
||||
{
|
||||
"api_url": "https://nls-gateway-cn-shanghai.aliyuncs.com/stream/v1/tts",
|
||||
"app_key": "",
|
||||
"access_key_id": "",
|
||||
"access_key_secret": ""
|
||||
}
|
||||
33
voice/elevent/elevent_voice.py
Normal file
33
voice/elevent/elevent_voice.py
Normal file
@@ -0,0 +1,33 @@
|
||||
import time
|
||||
|
||||
from elevenlabs import set_api_key,generate
|
||||
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from common.log import logger
|
||||
from common.tmp_dir import TmpDir
|
||||
from voice.voice import Voice
|
||||
from config import conf
|
||||
|
||||
XI_API_KEY = conf().get("xi_api_key")
|
||||
set_api_key(XI_API_KEY)
|
||||
name = conf().get("xi_voice_id")
|
||||
|
||||
class ElevenLabsVoice(Voice):
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def voiceToText(self, voice_file):
|
||||
pass
|
||||
|
||||
def textToVoice(self, text):
|
||||
audio = generate(
|
||||
text=text,
|
||||
voice=name,
|
||||
model='eleven_multilingual_v1'
|
||||
)
|
||||
fileName = TmpDir().path() + "reply-" + str(int(time.time())) + "-" + str(hash(text) & 0x7FFFFFFF) + ".mp3"
|
||||
with open(fileName, "wb") as f:
|
||||
f.write(audio)
|
||||
logger.info("[ElevenLabs] textToVoice text={} voice file name={}".format(text, fileName))
|
||||
return Reply(ReplyType.VOICE, fileName)
|
||||
@@ -29,4 +29,17 @@ def create_voice(voice_type):
|
||||
from voice.azure.azure_voice import AzureVoice
|
||||
|
||||
return AzureVoice()
|
||||
elif voice_type == "elevenlabs":
|
||||
from voice.elevent.elevent_voice import ElevenLabsVoice
|
||||
|
||||
return ElevenLabsVoice()
|
||||
|
||||
elif voice_type == "linkai":
|
||||
from voice.linkai.linkai_voice import LinkAIVoice
|
||||
|
||||
return LinkAIVoice()
|
||||
elif voice_type == "ali":
|
||||
from voice.ali.ali_voice import AliVoice
|
||||
|
||||
return AliVoice()
|
||||
raise RuntimeError
|
||||
|
||||
83
voice/linkai/linkai_voice.py
Normal file
83
voice/linkai/linkai_voice.py
Normal file
@@ -0,0 +1,83 @@
|
||||
"""
|
||||
google voice service
|
||||
"""
|
||||
import random
|
||||
import requests
|
||||
from voice import audio_convert
|
||||
from bridge.reply import Reply, ReplyType
|
||||
from common.log import logger
|
||||
from config import conf
|
||||
from voice.voice import Voice
|
||||
from common import const
|
||||
import os
|
||||
import datetime
|
||||
|
||||
class LinkAIVoice(Voice):
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def voiceToText(self, voice_file):
|
||||
logger.debug("[LinkVoice] voice file name={}".format(voice_file))
|
||||
try:
|
||||
url = conf().get("linkai_api_base", "https://api.link-ai.chat") + "/v1/audio/transcriptions"
|
||||
headers = {"Authorization": "Bearer " + conf().get("linkai_api_key")}
|
||||
model = None
|
||||
if not conf().get("text_to_voice") or conf().get("voice_to_text") == "openai":
|
||||
model = const.WHISPER_1
|
||||
if voice_file.endswith(".amr"):
|
||||
try:
|
||||
mp3_file = os.path.splitext(voice_file)[0] + ".mp3"
|
||||
audio_convert.any_to_mp3(voice_file, mp3_file)
|
||||
voice_file = mp3_file
|
||||
except Exception as e:
|
||||
logger.warn(f"[LinkVoice] amr file transfer failed, directly send amr voice file: {format(e)}")
|
||||
file = open(voice_file, "rb")
|
||||
file_body = {
|
||||
"file": file
|
||||
}
|
||||
data = {
|
||||
"model": model
|
||||
}
|
||||
res = requests.post(url, files=file_body, headers=headers, data=data, timeout=(5, 60))
|
||||
if res.status_code == 200:
|
||||
text = res.json().get("text")
|
||||
else:
|
||||
res_json = res.json()
|
||||
logger.error(f"[LinkVoice] voiceToText error, status_code={res.status_code}, msg={res_json.get('message')}")
|
||||
return None
|
||||
reply = Reply(ReplyType.TEXT, text)
|
||||
logger.info(f"[LinkVoice] voiceToText success, text={text}, file name={voice_file}")
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
return None
|
||||
return reply
|
||||
|
||||
def textToVoice(self, text):
|
||||
try:
|
||||
url = conf().get("linkai_api_base", "https://api.link-ai.chat") + "/v1/audio/speech"
|
||||
headers = {"Authorization": "Bearer " + conf().get("linkai_api_key")}
|
||||
model = const.TTS_1
|
||||
if not conf().get("text_to_voice") or conf().get("text_to_voice") in ["openai", const.TTS_1, const.TTS_1_HD]:
|
||||
model = conf().get("text_to_voice_model") or const.TTS_1
|
||||
data = {
|
||||
"model": model,
|
||||
"input": text,
|
||||
"voice": conf().get("tts_voice_id"),
|
||||
"app_code": conf().get("linkai_app_code")
|
||||
}
|
||||
res = requests.post(url, headers=headers, json=data, timeout=(5, 120))
|
||||
if res.status_code == 200:
|
||||
tmp_file_name = "tmp/" + datetime.datetime.now().strftime('%Y%m%d%H%M%S') + str(random.randint(0, 1000)) + ".mp3"
|
||||
with open(tmp_file_name, 'wb') as f:
|
||||
f.write(res.content)
|
||||
reply = Reply(ReplyType.VOICE, tmp_file_name)
|
||||
logger.info(f"[LinkVoice] textToVoice success, input={text}, model={model}, voice_id={data.get('voice')}")
|
||||
return reply
|
||||
else:
|
||||
res_json = res.json()
|
||||
logger.error(f"[LinkVoice] textToVoice error, status_code={res.status_code}, msg={res_json.get('message')}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
# reply = Reply(ReplyType.ERROR, "遇到了一点小问题,请稍后再问我吧")
|
||||
return None
|
||||
@@ -9,7 +9,9 @@ from bridge.reply import Reply, ReplyType
|
||||
from common.log import logger
|
||||
from config import conf
|
||||
from voice.voice import Voice
|
||||
|
||||
import requests
|
||||
from common import const
|
||||
import datetime, random
|
||||
|
||||
class OpenaiVoice(Voice):
|
||||
def __init__(self):
|
||||
@@ -24,6 +26,32 @@ class OpenaiVoice(Voice):
|
||||
reply = Reply(ReplyType.TEXT, text)
|
||||
logger.info("[Openai] voiceToText text={} voice file name={}".format(text, voice_file))
|
||||
except Exception as e:
|
||||
reply = Reply(ReplyType.ERROR, str(e))
|
||||
reply = Reply(ReplyType.ERROR, "我暂时还无法听清您的语音,请稍后再试吧~")
|
||||
finally:
|
||||
return reply
|
||||
|
||||
|
||||
def textToVoice(self, text):
|
||||
try:
|
||||
api_base = conf().get("open_ai_api_base") or "https://api.openai.com/v1"
|
||||
url = f'{api_base}/audio/speech'
|
||||
headers = {
|
||||
'Authorization': 'Bearer ' + conf().get("open_ai_api_key"),
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
data = {
|
||||
'model': conf().get("text_to_voice_model") or const.TTS_1,
|
||||
'input': text,
|
||||
'voice': conf().get("tts_voice_id") or "alloy"
|
||||
}
|
||||
response = requests.post(url, headers=headers, json=data)
|
||||
file_name = "tmp/" + datetime.datetime.now().strftime('%Y%m%d%H%M%S') + str(random.randint(0, 1000)) + ".mp3"
|
||||
logger.debug(f"[OPENAI] text_to_Voice file_name={file_name}, input={text}")
|
||||
with open(file_name, 'wb') as f:
|
||||
f.write(response.content)
|
||||
logger.info(f"[OPENAI] text_to_Voice success")
|
||||
reply = Reply(ReplyType.VOICE, file_name)
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
reply = Reply(ReplyType.ERROR, "遇到了一点小问题,请稍后再问我吧")
|
||||
return reply
|
||||
|
||||
Reference in New Issue
Block a user